From 51e452532bb76effa0c4295b12215b4c7b0e29da Mon Sep 17 00:00:00 2001 From: amrinfathima-mcw Date: Fri, 8 Nov 2024 11:10:44 +0530 Subject: [PATCH 0001/1324] Adds bf16, f16 type support for tfl.concatenation op (#92) * TfLite concatenation missing datatype support (#41) * added f16/bf16 support for concatenation op --- tensorflow/lite/kernels/BUILD | 1 + tensorflow/lite/kernels/concatenation.cc | 17 +++- tensorflow/lite/kernels/concatenation_test.cc | 78 +++++++++++++++++++ tensorflow/lite/kernels/test_util.h | 4 + 4 files changed, 96 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index d675cdd9907730..e505bd999b8fa4 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -1565,6 +1565,7 @@ cc_test( ":test_util", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", + "@eigen_archive//:eigen3", "@flatbuffers", ], ) diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index 4f116440fd2049..fbf153b90d7327 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" @@ -91,6 +92,12 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, int axis, case kTfLiteFloat32: TF_LITE_CONCATENATION(float); break; + case kTfLiteFloat16: + TF_LITE_CONCATENATION(Eigen::half); + break; + case kTfLiteBFloat16: + TF_LITE_CONCATENATION(Eigen::bfloat16); + break; case kTfLiteInt32: TF_LITE_CONCATENATION(int32); break; @@ -142,10 +149,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || - input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || - input_type == kTfLiteInt32 || input_type == kTfLiteInt64 || - input_type == kTfLiteBool || input_type == kTfLiteUInt32); + input_type == kTfLiteFloat32 || input_type == kTfLiteFloat16 || + input_type == kTfLiteBFloat16 || + input_type == kTfLiteUInt8 || input_type == kTfLiteInt8 || + input_type == kTfLiteInt16 || input_type == kTfLiteInt32 || + input_type == kTfLiteInt64 || input_type == kTfLiteBool || + input_type == kTfLiteUInt32); // Check to see if we can calculate the output now. 
 bool all_inputs_at_prepare = true;
diff --git a/tensorflow/lite/kernels/concatenation_test.cc b/tensorflow/lite/kernels/concatenation_test.cc
index 685abd5d5e7569..28692ae1528dd3 100644
--- a/tensorflow/lite/kernels/concatenation_test.cc
+++ b/tensorflow/lite/kernels/concatenation_test.cc
@@ -108,6 +108,29 @@ TEST(ConcatenationOpTest, ThreeDimensionalOneInput) {
   EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7}));
 }
 
+TEST(ConcatenationOpTest, ThreeDimensionalOneInputBFloat16) {
+  ConcatenationOpModel m({TensorType_BFLOAT16, {2, 1, 2}},
+                         /*axis=*/1,
+                         /*num_inputs=*/1);
+  m.SetInput<Eigen::bfloat16>(
+      0,
+      {static_cast<Eigen::bfloat16>(1.0f), static_cast<Eigen::bfloat16>(3.0f),
+       static_cast<Eigen::bfloat16>(4.0f), static_cast<Eigen::bfloat16>(7.0f)});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.GetOutput<Eigen::bfloat16>(), ElementsAreArray({1, 3, 4, 7}));
+}
+
+TEST(ConcatenationOpTest, ThreeDimensionalOneInputFloat16) {
+  ConcatenationOpModel m({TensorType_FLOAT16, {2, 1, 2}},
+                         /*axis=*/1,
+                         /*num_inputs=*/1);
+  m.SetInput<Eigen::half>(0,
+                          {static_cast<Eigen::half>(1.0f), static_cast<Eigen::half>(3.0f),
+                           static_cast<Eigen::half>(4.0f), static_cast<Eigen::half>(7.0f)});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.GetOutput<Eigen::half>(), ElementsAreArray({1, 3, 4, 7}));
+}
+
 TEST(ConcatenationOpTest, ThreeDimensionalOneInputUInt32) {
   ConcatenationOpModel m0({TensorType_UINT32, {2, 1, 2}}, /*axis=*/1,
                           /*num_inputs=*/1);
@@ -152,6 +175,61 @@ TEST(ConcatenationOpTest, FiveDimensionalTwoInput) {
                         13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}));
 }
 
+TEST(ConcatenationOpTest, FiveDimensionalTwoInputBFloat16) {
+  ConcatenationOpModel m(
+      {TensorType_BFLOAT16, {2, 1, 2, 1, 3}},
+      /*axis=*/0,
+      /*num_inputs=*/2);
+  m.SetInput<Eigen::bfloat16>(
+      0,
+      {static_cast<Eigen::bfloat16>(1.0f), static_cast<Eigen::bfloat16>(2.0f),
+       static_cast<Eigen::bfloat16>(3.0f), static_cast<Eigen::bfloat16>(4.0f),
+       static_cast<Eigen::bfloat16>(5.0f), static_cast<Eigen::bfloat16>(6.0f),
+       static_cast<Eigen::bfloat16>(7.0f), static_cast<Eigen::bfloat16>(8.0f),
+       static_cast<Eigen::bfloat16>(9.0f), static_cast<Eigen::bfloat16>(10.0f),
+       static_cast<Eigen::bfloat16>(11.0f),
+       static_cast<Eigen::bfloat16>(12.0f)});
+  m.SetInput<Eigen::bfloat16>(
+      1,
+      {static_cast<Eigen::bfloat16>(13.0f), static_cast<Eigen::bfloat16>(14.0f),
+       static_cast<Eigen::bfloat16>(15.0f), Eigen::bfloat16{16.0f},
+       static_cast<Eigen::bfloat16>(17.0f), static_cast<Eigen::bfloat16>(18.0f),
+       static_cast<Eigen::bfloat16>(19.0f), static_cast<Eigen::bfloat16>(20.0f),
+       static_cast<Eigen::bfloat16>(21.0f), static_cast<Eigen::bfloat16>(22.0f),
+       static_cast<Eigen::bfloat16>(23.0f),
+       static_cast<Eigen::bfloat16>(24.0f)});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(
+      m.GetOutput<Eigen::bfloat16>(),
+      ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+                        13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}));
+}
+
+TEST(ConcatenationOpTest, FiveDimensionalTwoInputFloat16) {
+  ConcatenationOpModel m({TensorType_FLOAT16, {2, 1, 2, 1, 3}},
+                         /*axis=*/0,
+                         /*num_inputs=*/2);
+  m.SetInput<Eigen::half>(
+      0, {static_cast<Eigen::half>(1.0f), static_cast<Eigen::half>(2.0f),
+          static_cast<Eigen::half>(3.0f), static_cast<Eigen::half>(4.0f),
+          static_cast<Eigen::half>(5.0f), static_cast<Eigen::half>(6.0f),
+          static_cast<Eigen::half>(7.0f), Eigen::half{8.0f},
+          static_cast<Eigen::half>(9.0f), static_cast<Eigen::half>(10.0f),
+          static_cast<Eigen::half>(11.0f), static_cast<Eigen::half>(12.0f)});
+  m.SetInput<Eigen::half>(
+      1, {static_cast<Eigen::half>(13.0f), static_cast<Eigen::half>(14.0f),
+          Eigen::half{15.0f}, static_cast<Eigen::half>(16.0f),
+          Eigen::half{17.0f}, static_cast<Eigen::half>(18.0f),
+          static_cast<Eigen::half>(19.0f), static_cast<Eigen::half>(20.0f),
+          static_cast<Eigen::half>(21.0f), static_cast<Eigen::half>(22.0f),
+          static_cast<Eigen::half>(23.0f), static_cast<Eigen::half>(24.0f)});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(
+      m.GetOutput<Eigen::half>(),
+      ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+                        13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}));
+}
+
 TEST(ConcatenationOpTest, FiveDimensionalTwoInputUInt32) {
   ConcatenationOpModel m0({TensorType_UINT32, {2, 1, 2, 1, 3}},
                           /*axis=*/0,
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index 8cc248a36ed7f9..654c1052ca2558 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -133,6 +133,10 @@ constexpr TfLiteType typeToTfLiteType<Eigen::half>() {
   return kTfLiteFloat16;
 }
+template <>
+constexpr TfLiteType typeToTfLiteType<Eigen::bfloat16>() {
+  return kTfLiteBFloat16;
+}
 
 // A test model that contains a single operator. All operator inputs and
 // output are external to the model, so the tests can directly access them.
 // Typical usage:

From 8563f1f95d10f17fe13dd7a79cf4d52d2f8fc776 Mon Sep 17 00:00:00 2001
From: Sachin Muradi
Date: Thu, 20 Feb 2025 14:31:34 -0800
Subject: [PATCH 0002/1324] Add NonMaxSuppressionV3/V4 in slow ops

---
 tensorflow/compiler/jit/compilability_check_util.cc | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc
index 2b15a4affc76af..50b26371698877 100644
--- a/tensorflow/compiler/jit/compilability_check_util.cc
+++ b/tensorflow/compiler/jit/compilability_check_util.cc
@@ -370,11 +370,16 @@ bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
   // https://github.com/tensorflow/tensorflow/pull/31012:
   // ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
   // create convolutions too large for CuDNN to handle.
+  // NonMaxSuppressionV3/V4 in XLA runs significantly slower than TF kernel in
+  // object detection models, especially when there are a lot of proposed
+  // bounding boxes.
   return node.type_string() == "SelfAdjointEigV2" ||
          node.type_string() == "Svd" || node.type_string() == "Qr" ||
          node.type_string() == "MatrixInverse" ||
          node.type_string() == "MatrixSolve" ||
-         node.type_string() == "ResizeBilinearGrad";
+         node.type_string() == "ResizeBilinearGrad" ||
+         node.type_string() == "NonMaxSuppressionV3" ||
+         node.type_string() == "NonMaxSuppressionV4";
 }
 
 bool RecursiveCompilabilityChecker::IsCompilableNode(

From 208b54d4d650643f37ec1580b225453a96f3325b Mon Sep 17 00:00:00 2001
From: Sachin Muradi
Date: Thu, 27 Feb 2025 13:01:39 -0800
Subject: [PATCH 0003/1324] Add compilable check test for NMSV3/4

---
 .../jit/compilability_check_util_test.cc      | 87 +++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc
index 0fe2d2d2fe96b7..ea24176bb04a4a 100644
--- a/tensorflow/compiler/jit/compilability_check_util_test.cc
+++ b/tensorflow/compiler/jit/compilability_check_util_test.cc
@@ -51,6 +51,7 @@ constexpr char kUncompilableFunctionName[] = "UncompilableFn";
 constexpr char kUncompilableFunctionNodeName[] = "n_c_uncompilable";
 constexpr char kUncompilableFunctionTwoName[] = "UncompilableFnTwo";
 constexpr char kUncompilableFunctionNodeTwoName[] = "n_d_uncompilable";
+constexpr char kNonMaxSuppressionNodeName[] = "NonMaxSuppression";
 
 // A dummy OpKernel for testing.
 class DummyCompilableOp : public XlaOpKernel {
@@ -63,6 +64,7 @@ class DummyCompilableOp : public XlaOpKernel {
 
 // Register the DummyCompilableOp kernel for CPU.
REGISTER_OP("InputFloatOp").Output("o: float"); +REGISTER_OP("InputInt32Op").Output("o: int32"); REGISTER_OP("CompilableOp").Input("i: float").Output("o: float"); REGISTER_XLA_OP(Name("CompilableOp").Device(DEVICE_CPU_XLA_JIT), DummyCompilableOp); @@ -554,5 +556,90 @@ TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) { EXPECT_TRUE(CanTriggerXlaCompilation(graph_def)); } +TEST_F(CompilabilityCheckUtilTest, CheckNonMaxSuppressionV3UncompilableSlowOp) { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + auto opts = builder.opts(); + + Node* boxes = ops::SourceOp("InputFloatOp", opts); + Node* scores = ops::SourceOp("InputFloatOp", opts); + Node* max_output_size = ops::SourceOp("InputInt32Op", opts); + Node* iou_threshold = ops::SourceOp("InputFloatOp", opts); + Node* score_threshold = ops::SourceOp("InputFloatOp", opts); + + NodeBuilder non_max_suppression_builder( + kNonMaxSuppressionNodeName, "NonMaxSuppressionV3", opts.op_registry()); + non_max_suppression_builder.Input(boxes) + .Input(scores) + .Input(max_output_size) + .Input(iou_threshold) + .Input(score_threshold) + .Attr("T", DT_FLOAT); + Node* non_max_suppression; + non_max_suppression = + builder.opts().FinalizeBuilder(&non_max_suppression_builder); + + GraphDef graph_def; + TF_EXPECT_OK(builder.ToGraphDef(&graph_def)); + auto* flib_runtime = GetFunctionLibraryRuntime(); + + EXPECT_FALSE(checker_->IsCompilableNode(*non_max_suppression, flib_runtime)); + + const auto uncompilable_nodes = + checker_->FindUncompilableNodes(*non_max_suppression, flib_runtime); + ASSERT_EQ(1, uncompilable_nodes.size()); + auto node_info_it = + uncompilable_nodes.find(NameAttrList().ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + + const auto& uncompilable_nodes_inside_function = node_info_it->second.second; + ASSERT_EQ(1, uncompilable_nodes_inside_function.size()); + const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0); + EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason, + "slow operation")); +} + +TEST_F(CompilabilityCheckUtilTest, CheckNonMaxSuppressionV4UncompilableSlowOp) { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + auto opts = builder.opts(); + + Node* boxes = ops::SourceOp("InputFloatOp", opts); + Node* scores = ops::SourceOp("InputFloatOp", opts); + Node* max_output_size = ops::SourceOp("InputInt32Op", opts); + Node* iou_threshold = ops::SourceOp("InputFloatOp", opts); + Node* score_threshold = ops::SourceOp("InputFloatOp", opts); + + NodeBuilder non_max_suppression_v4_builder( + kNonMaxSuppressionNodeName, "NonMaxSuppressionV4", opts.op_registry()); + non_max_suppression_v4_builder.Input(boxes) + .Input(scores) + .Input(max_output_size) + .Input(iou_threshold) + .Input(score_threshold) + .Attr("T", DT_FLOAT); + Node* non_max_suppression_v4; + non_max_suppression_v4 = + builder.opts().FinalizeBuilder(&non_max_suppression_v4_builder); + + GraphDef graph_def; + TF_EXPECT_OK(builder.ToGraphDef(&graph_def)); + auto* flib_runtime = GetFunctionLibraryRuntime(); + + EXPECT_FALSE( + checker_->IsCompilableNode(*non_max_suppression_v4, flib_runtime)); + + const auto uncompilable_nodes = + checker_->FindUncompilableNodes(*non_max_suppression_v4, flib_runtime); + ASSERT_EQ(1, uncompilable_nodes.size()); + auto node_info_it = + uncompilable_nodes.find(NameAttrList().ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + + const auto& uncompilable_nodes_inside_function = node_info_it->second.second; + ASSERT_EQ(1, 
uncompilable_nodes_inside_function.size()); + const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0); + EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason, + "slow operation")); +} + } // namespace } // namespace tensorflow From 91bf4e0dd81b226b538f2ac2b723e058602f73b4 Mon Sep 17 00:00:00 2001 From: gaikwadrahul8 <115997457+gaikwadrahul8@users.noreply.github.com> Date: Fri, 7 Mar 2025 13:04:34 +0530 Subject: [PATCH 0004/1324] Fix 03 broken links in object_detection.md --- tensorflow/lite/g3doc/android/tutorials/object_detection.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/g3doc/android/tutorials/object_detection.md b/tensorflow/lite/g3doc/android/tutorials/object_detection.md index e9ecc441d651f3..640b082311970b 100644 --- a/tensorflow/lite/g3doc/android/tutorials/object_detection.md +++ b/tensorflow/lite/g3doc/android/tutorials/object_detection.md @@ -147,7 +147,7 @@ convert data such as images, into a tensor data format that can be processed by the model you are using. The example app uses the TensorFlow Lite -[Task library for vision](../../inference_with_metadata/task_library/overview#supported_tasks) +[Task library for vision](../../inference_with_metadata/task_library/overview.md#supported-tasks) to enable execution of the object detection machine learning model. The following instructions explain how to add the required library dependencies to your own Android app project. @@ -263,7 +263,7 @@ device, such as Graphics Processing Units (GPUs), Tensor Processing Units TensorFlow Lite models is recommended, but not required. The object detector is initialized using the current settings on the thread that -is using it. You can use CPU and [NNAPI](../../android/delegates/nnapi) +is using it. You can use CPU and [NNAPI](../../android/delegates/nnapi.md) delegates with detectors that are created on the main thread and used on a background thread, but the thread that initialized the detector must use the GPU delegate. @@ -290,7 +290,7 @@ when (currentDelegate) { ``` For more information about using hardware acceleration delegates with TensorFlow -Lite, see [TensorFlow Lite Delegates](../../performance/delegates). +Lite, see [TensorFlow Lite Delegates](../../performance/delegates.md). ## Prepare data for the model From 7dc87202aba4812d1f6d72e742d6254282e01e01 Mon Sep 17 00:00:00 2001 From: gaikwadrahul8 <115997457+gaikwadrahul8@users.noreply.github.com> Date: Fri, 21 Mar 2025 00:34:02 +0530 Subject: [PATCH 0005/1324] Update 07 broken links in text_classification.md --- .../g3doc/android/tutorials/text_classification.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/g3doc/android/tutorials/text_classification.md b/tensorflow/lite/g3doc/android/tutorials/text_classification.md index 19baf52bb17ac4..3b0fe1f27d781e 100644 --- a/tensorflow/lite/g3doc/android/tutorials/text_classification.md +++ b/tensorflow/lite/g3doc/android/tutorials/text_classification.md @@ -7,7 +7,7 @@ physical Android device but can also run on a device emulator. 
The [example application](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android) uses TensorFlow Lite to classify text as either positive or negative, using the -[Task library for natural language (NL)](../../inference_with_metadata/task_library/overview#supported_tasks) +[Task library for natural language (NL)](https://ai.google.dev/edge/litert/libraries/task_library/overview) to enable execution of the text classification machine learning models. If you are updating an existing project, you can use the example application as @@ -31,7 +31,7 @@ text being correctly classified as either positive or negative. For more information on how the models in this tutorial are generated, refer to the -[Text classification with TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification) +[Text classification with TensorFlow Lite Model Maker](https://ai.google.dev/edge/litert/libraries/modify/text_classification) tutorial. ## Models and dataset @@ -41,7 +41,7 @@ This tutorial uses models that were trained using the Treebank) dataset. SST-2 contains 67,349 movie reviews for training and 872 movie reviews for testing, with each review categorized as either positive or negative. The models used in this app were trained using the TensorFlow Lite -[Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification) +[Model Maker](https://ai.google.dev/edge/litert/libraries/modify/text_classification) tool. The example application uses the following pre-trained models: @@ -149,10 +149,10 @@ implement text classification features to your production applications: ## How the example app works {:#how_it_works} The application uses the -[Task library for natural language (NL)](../../inference_with_metadata/task_library/overview#supported_tasks) +[Task library for natural language (NL)](https://ai.google.dev/edge/litert/libraries/task_library/overview) package to implement the text classification models. The two models, Average Word Vector and MobileBERT, were trained using the TensorFlow Lite -[Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification). +[Model Maker](https://ai.google.dev/edge/litert/libraries/modify/text_classification). The application runs on CPU by default, with the option of hardware acceleration using the NNAPI delegate. @@ -237,7 +237,7 @@ model with parameters before running predictions with the model. A TensorFlow Lite model is stored as a `*.tflite` file. The model file contains the prediction logic and typically includes -[metadata](../../models/convert/metadata) about how to interpret prediction +[metadata](https://ai.google.dev/edge/litert/models/metadata) about how to interpret prediction results, such as prediction class names. Typically, model files are stored in the `src/main/assets` directory of your development project, as in the code example: @@ -475,7 +475,7 @@ user interface. ## Next steps * Train and implement the models from scratch with the - [Text classification with TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification) + [Text classification with TensorFlow Lite Model Maker](https://ai.google.dev/edge/litert/libraries/modify/text_classification) tutorial. * Explore more [text processing tools for TensorFlow](https://www.tensorflow.org/text). 
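The tutorial pages patched above all funnel into the same runtime flow: bundle a `*.tflite` model, build an interpreter over it, and feed it preprocessed input. As a rough orientation for readers of this series, here is a minimal sketch of that flow using the stock TensorFlow Lite C++ interpreter API; the Task-library wrappers the tutorials actually use sit on top of it, and the model filename below is only a placeholder, not a file shipped with the examples:

```
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

int main() {
  // Load a model; the path is a stand-in for the file a tutorial app would
  // bundle under src/main/assets.
  auto model =
      tflite::FlatBufferModel::BuildFromFile("text_classifier.tflite");
  if (model == nullptr) return 1;

  // Build an interpreter backed by the builtin op resolver.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (interpreter == nullptr || interpreter->AllocateTensors() != kTfLiteOk) {
    return 1;
  }

  // Fill interpreter->typed_input_tensor<...>(0) with preprocessed input
  // (e.g. tokenized text) before running the model.
  if (interpreter->Invoke() != kTfLiteOk) return 1;
  return 0;
}
```

The Task library hides the tensor plumbing behind task-specific types, but its classifiers ultimately drive this same interpreter loop.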
From b15bfbe60e65476100f463569928e9e3e7f6a6a4 Mon Sep 17 00:00:00 2001 From: gaikwadrahul8 <115997457+gaikwadrahul8@users.noreply.github.com> Date: Fri, 28 Mar 2025 15:46:33 +0530 Subject: [PATCH 0006/1324] Update 04 broken links in gpu_native.md --- tensorflow/lite/g3doc/android/delegates/gpu_native.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/g3doc/android/delegates/gpu_native.md b/tensorflow/lite/g3doc/android/delegates/gpu_native.md index 2221c2066f9cb6..75b03603e91d6d 100644 --- a/tensorflow/lite/g3doc/android/delegates/gpu_native.md +++ b/tensorflow/lite/g3doc/android/delegates/gpu_native.md @@ -4,17 +4,17 @@ Using graphics processing units (GPUs) to run your machine learning (ML) models can dramatically improve the performance and the user experience of your ML-enabled applications. On Android devices, you can enable GPU-accelerated execution of your models using a -[*delegate*](../../performance/delegates) and one of the following APIs: +[*delegate*](https://ai.google.dev/edge/litert/performance/delegates) and one of the following APIs: - Interpreter API - [guide](./gpu) -- Task library API - [guide](./gpu_task) +- Task library API - [guide](./gpu_task.md) - Native (C/C++) API - this guide This guide covers advanced uses of the GPU delegate for the C API, C++ API, and use of quantized models. For more information about using the GPU delegate for TensorFlow Lite, including best practices and advanced techniques, see the -[GPU delegates](../../performance/gpu) page. +[GPU delegates](https://ai.google.dev/edge/litert/performance/gpu) page. ## Enable GPU acceleration @@ -65,7 +65,7 @@ thread in which `Interpreter::ModifyGraphWithDelegate()` was called. #### With TensorFlow Lite in Google Play Services: -If you are using TensorFlow Lite in Google Play Services [C API](../native), +If you are using TensorFlow Lite in Google Play Services [C API](https://ai.google.dev/edge/litert/android/native), you’ll need to use the Java/Kotlin API to check if a GPU delegate is available for your device before initializing the TensorFlow Lite runtime. @@ -171,4 +171,4 @@ if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false; For more information about running quantized models with GPU acceleration, -see [GPU delegate](../../performance/gpu#quantized-models) overview. \ No newline at end of file +see [GPU delegate](https://ai.google.dev/edge/litert/performance/gpu#quantized_models) overview. 
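The gpu_native.md page patched above describes one delegate lifecycle: create a GPU delegate, hand the graph to it with `ModifyGraphWithDelegate()` (quoted verbatim in the hunks), and release it only after the interpreter is gone. A compact sketch of that pattern, assuming an interpreter built as in the previous example:

```
#include "tensorflow/lite/delegates/gpu/delegate.h"
#include "tensorflow/lite/interpreter.h"

// Returns the delegate on success so the caller can free it once the
// interpreter has been destroyed, or nullptr if delegation failed.
TfLiteDelegate* EnableGpu(tflite::Interpreter* interpreter) {
  TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
  TfLiteDelegate* delegate = TfLiteGpuDelegateV2Create(&options);
  if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
    TfLiteGpuDelegateV2Delete(delegate);
    return nullptr;  // Caller keeps running on CPU.
  }
  return delegate;  // Free later with TfLiteGpuDelegateV2Delete().
}
```

Deleting the delegate only after the interpreter is destroyed matters because the interpreter holds a pointer into the delegated graph for its whole lifetime.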
From 21c685a571711b6092b42f0eefc76e5e61140372 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Fri, 28 Mar 2025 17:18:08 -0700 Subject: [PATCH 0007/1324] Integrate StableHLO at openxla/stablehlo@be8ce602 PiperOrigin-RevId: 741690690 --- third_party/stablehlo/temporary.patch | 36 ++++++------------- third_party/stablehlo/workspace.bzl | 4 +-- .../xla/third_party/stablehlo/temporary.patch | 36 ++++++------------- .../xla/third_party/stablehlo/workspace.bzl | 4 +-- 4 files changed, 26 insertions(+), 54 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index dd815d64317bad..949ebc772ae60e 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,17 +1,3 @@ -diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir ---- stablehlo/stablehlo/conversions/tosa/tests/unary.mlir -+++ stablehlo/stablehlo/conversions/tosa/tests/unary.mlir -@@ -79,8 +79,8 @@ - - // CHECK-LABEL: @slice - func.func @slice(%arg : tensor<4x3xf32>) -> tensor<2x2xf32> { -- // CHECK: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> -- // CHECK: %[[START:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> -+ // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> -+ // CHECK-DAG: %[[START:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> - // CHECK: tosa.slice %arg0, %[[SIZE]], %[[START]] - %0 = "stablehlo.slice"(%arg) { - start_indices = array, diff --ruN a/stablehlo/stablehlo/dialect/Serialization.cpp b/stablehlo/stablehlo/dialect/Serialization.cpp --- stablehlo/stablehlo/dialect/Serialization.cpp +++ stablehlo/stablehlo/dialect/Serialization.cpp @@ -413,14 +399,15 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable namespace mlir { namespace stablehlo { -@@ -53,19 +56,33 @@ +@@ -53,17 +56,33 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { public: - StablehloToVhloTypeConverter() : vhlo::VhloTypeConverter() { - addConversion([](Type type) -> Type { -- if (type.getDialect().getNamespace() == -- vhlo::VhloDialect::getDialectNamespace()) { +- if (llvm::isa(type.getDialect())) return type; +- +- LLVM_DEBUG(llvm::dbgs() << "Invalid type: " << type << '\n'); + StablehloToVhloTypeConverter(bool allowOtherDialects) + : vhlo::VhloTypeConverter() { + LLVM_DEBUG( @@ -436,9 +423,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + LLVM_DEBUG(llvm::dbgs() + << "[StablehloToVhloTypeConverter] Valid non-VHLO type: " + << type << '\n'); - return type; - } -- LLVM_DEBUG(llvm::dbgs() << "Invalid type: " << type << '\n'); ++ return type; ++ } + + LLVM_DEBUG(llvm::dbgs() << "[StablehloToVhloTypeConverter] Invalid type: " + << type << '\n'); @@ -452,7 +438,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable } Attribute convertEncoding(Attribute attr) const final { -@@ -1026,14 +1043,27 @@ +@@ -1021,14 +1040,27 @@ struct StablehloLegalizeToVhloPass : public impl::StablehloLegalizeToVhloPassBase< StablehloLegalizeToVhloPass> { @@ -481,7 +467,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable patterns = std::move(patterns_); return success(); -@@ -1048,7 +1078,7 @@ +@@ -1043,7 +1075,7 @@ } private: @@ -518,7 +504,7 @@ diff --ruN 
a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable } Attribute convertEncoding(Attribute attr) const final { -@@ -1022,6 +1024,36 @@ +@@ -1021,6 +1023,36 @@ } }; @@ -555,7 +541,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable template void populateVhloToStablehloPatterns(RewritePatternSet* patterns, TypeConverter* converter, -@@ -1044,6 +1076,7 @@ +@@ -1043,6 +1075,7 @@ RewritePatternSet patterns_(context); stablehlo::populateVhloToStablehloPatterns(&patterns_, &converter, context); @@ -563,7 +549,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable patterns = std::move(patterns_); return success(); -@@ -1056,6 +1089,12 @@ +@@ -1055,6 +1088,12 @@ if (failed(applyPartialConversion(getOperation(), *target, patterns))) { return signalPassFailure(); } diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index dad364bbc463a8..a6a1f38b9f4494 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "af8ba04d6682fada77595c69ce5bd36611a919e8" - STABLEHLO_SHA256 = "7793e819d38c53a5d67e54f6988a09e45bd6dc506276b9e29095521ae917e7a3" + STABLEHLO_COMMIT = "be8ce602efbd90fd677247075745bf16eb4b31ac" + STABLEHLO_SHA256 = "81f44a6f4c37599fc600c159a899602590b17f3d3858fe6a400bb5643b0c9ba1" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index dd815d64317bad..949ebc772ae60e 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1,17 +1,3 @@ -diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir ---- stablehlo/stablehlo/conversions/tosa/tests/unary.mlir -+++ stablehlo/stablehlo/conversions/tosa/tests/unary.mlir -@@ -79,8 +79,8 @@ - - // CHECK-LABEL: @slice - func.func @slice(%arg : tensor<4x3xf32>) -> tensor<2x2xf32> { -- // CHECK: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> -- // CHECK: %[[START:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> -+ // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> -+ // CHECK-DAG: %[[START:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> - // CHECK: tosa.slice %arg0, %[[SIZE]], %[[START]] - %0 = "stablehlo.slice"(%arg) { - start_indices = array, diff --ruN a/stablehlo/stablehlo/dialect/Serialization.cpp b/stablehlo/stablehlo/dialect/Serialization.cpp --- stablehlo/stablehlo/dialect/Serialization.cpp +++ stablehlo/stablehlo/dialect/Serialization.cpp @@ -413,14 +399,15 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable namespace mlir { namespace stablehlo { -@@ -53,19 +56,33 @@ +@@ -53,17 +56,33 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { public: - StablehloToVhloTypeConverter() : vhlo::VhloTypeConverter() { - addConversion([](Type type) -> Type { -- if (type.getDialect().getNamespace() == -- vhlo::VhloDialect::getDialectNamespace()) { +- if (llvm::isa(type.getDialect())) return type; +- +- LLVM_DEBUG(llvm::dbgs() << "Invalid type: " << type << '\n'); + 
StablehloToVhloTypeConverter(bool allowOtherDialects) + : vhlo::VhloTypeConverter() { + LLVM_DEBUG( @@ -436,9 +423,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + LLVM_DEBUG(llvm::dbgs() + << "[StablehloToVhloTypeConverter] Valid non-VHLO type: " + << type << '\n'); - return type; - } -- LLVM_DEBUG(llvm::dbgs() << "Invalid type: " << type << '\n'); ++ return type; ++ } + + LLVM_DEBUG(llvm::dbgs() << "[StablehloToVhloTypeConverter] Invalid type: " + << type << '\n'); @@ -452,7 +438,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable } Attribute convertEncoding(Attribute attr) const final { -@@ -1026,14 +1043,27 @@ +@@ -1021,14 +1040,27 @@ struct StablehloLegalizeToVhloPass : public impl::StablehloLegalizeToVhloPassBase< StablehloLegalizeToVhloPass> { @@ -481,7 +467,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable patterns = std::move(patterns_); return success(); -@@ -1048,7 +1078,7 @@ +@@ -1043,7 +1075,7 @@ } private: @@ -518,7 +504,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable } Attribute convertEncoding(Attribute attr) const final { -@@ -1022,6 +1024,36 @@ +@@ -1021,6 +1023,36 @@ } }; @@ -555,7 +541,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable template void populateVhloToStablehloPatterns(RewritePatternSet* patterns, TypeConverter* converter, -@@ -1044,6 +1076,7 @@ +@@ -1043,6 +1075,7 @@ RewritePatternSet patterns_(context); stablehlo::populateVhloToStablehloPatterns(&patterns_, &converter, context); @@ -563,7 +549,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable patterns = std::move(patterns_); return success(); -@@ -1056,6 +1089,12 @@ +@@ -1055,6 +1088,12 @@ if (failed(applyPartialConversion(getOperation(), *target, patterns))) { return signalPassFailure(); } diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index dad364bbc463a8..a6a1f38b9f4494 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "af8ba04d6682fada77595c69ce5bd36611a919e8" - STABLEHLO_SHA256 = "7793e819d38c53a5d67e54f6988a09e45bd6dc506276b9e29095521ae917e7a3" + STABLEHLO_COMMIT = "be8ce602efbd90fd677247075745bf16eb4b31ac" + STABLEHLO_SHA256 = "81f44a6f4c37599fc600c159a899602590b17f3d3858fe6a400bb5643b0c9ba1" # LINT.ThenChange(Google-internal path) tf_http_archive( From 860e3a368fb7e24b115617821481374564212def Mon Sep 17 00:00:00 2001 From: Matt Hurd Date: Fri, 28 Mar 2025 18:55:15 -0700 Subject: [PATCH 0008/1324] Move tensorflow/core/profiler/utils into tensorflow/profiler. 
PiperOrigin-RevId: 741710599 --- tensorflow/core/framework/BUILD | 2 + tensorflow/core/profiler/convert/BUILD | 135 +-- .../convert/hlo_proto_to_graph_view.cc | 2 +- .../core/profiler/convert/inference_stats.cc | 2 +- .../core/profiler/convert/inference_stats.h | 2 +- .../convert/multi_xplanes_to_op_stats.cc | 5 +- .../multi_xspace_to_inference_stats.cc | 3 +- .../convert/multi_xspace_to_inference_stats.h | 3 +- .../profiler/convert/op_metrics_db_combiner.h | 8 +- .../profiler/convert/op_profile_builder.cc | 2 +- .../profiler/convert/op_stats_combiner.cc | 6 +- .../core/profiler/convert/op_stats_combiner.h | 2 +- .../convert/op_stats_combiner_test.cc | 2 +- .../op_stats_to_input_pipeline_analysis.cc | 14 +- .../op_stats_to_input_pipeline_analysis.h | 5 +- ...p_stats_to_input_pipeline_analysis_test.cc | 5 +- .../convert/op_stats_to_op_profile.cc | 4 +- .../convert/op_stats_to_overview_page.cc | 13 +- .../profiler/convert/op_stats_to_pod_stats.cc | 4 +- .../convert/op_stats_to_pod_stats_test.cc | 4 +- .../convert/op_stats_to_pod_viewer.cc | 2 +- .../convert/op_stats_to_pod_viewer_test.cc | 4 +- .../convert/op_stats_to_roofline_model.cc | 2 +- .../profiler/convert/op_stats_to_tf_stats.cc | 4 +- .../convert/preprocess_single_host_xplane.cc | 2 +- tensorflow/core/profiler/convert/repository.h | 2 +- .../convert/step_events_to_steps_db.cc | 4 +- .../convert/step_events_to_steps_db.h | 2 +- .../convert/xplane_to_kernel_stats_db.cc | 4 +- .../convert/xplane_to_kernel_stats_db.h | 6 +- .../convert/xplane_to_kernel_stats_db_test.cc | 2 +- .../convert/xplane_to_op_metrics_db.cc | 6 +- .../convert/xplane_to_op_metrics_db.h | 2 +- .../convert/xplane_to_op_metrics_db_test.cc | 2 + .../profiler/convert/xplane_to_op_stats.cc | 14 +- .../profiler/convert/xplane_to_step_events.cc | 4 +- .../profiler/convert/xplane_to_step_events.h | 2 +- .../convert/xplane_to_step_events_test.cc | 2 + .../profiler/convert/xplane_to_step_stats.cc | 2 + .../convert/xplane_to_tf_data_stats.cc | 2 +- .../profiler/convert/xplane_to_tools_data.cc | 2 +- .../convert/xspace_to_dcn_slack_analysis.cc | 2 +- tensorflow/core/profiler/utils/BUILD | 448 +++------- tensorflow/core/profiler/utils/cost_utils.cc | 131 --- tensorflow/core/profiler/utils/cost_utils.h | 41 +- .../core/profiler/utils/derived_timeline.cc | 772 ------------------ .../core/profiler/utils/derived_timeline.h | 182 +---- .../profiler/utils/derived_timeline_test.cc | 576 ------------- .../core/profiler/utils/device_caps_utils.cc | 90 -- .../core/profiler/utils/device_caps_utils.h | 12 +- tensorflow/core/profiler/utils/diagnostics.cc | 87 -- tensorflow/core/profiler/utils/diagnostics.h | 26 +- tensorflow/core/profiler/utils/event_span.cc | 449 ---------- tensorflow/core/profiler/utils/event_span.h | 250 +----- .../core/profiler/utils/gpu_event_stats.cc | 106 --- .../core/profiler/utils/gpu_event_stats.h | 63 +- .../profiler/utils/hardware_type_utils.cc | 347 -------- .../core/profiler/utils/hardware_type_utils.h | 63 +- .../utils/hardware_type_utils_test.cc | 66 -- .../core/profiler/utils/hlo_module_map.cc | 181 ---- .../core/profiler/utils/hlo_module_map.h | 196 +---- .../core/profiler/utils/hlo_module_utils.h | 99 +-- .../profiler/utils/hlo_module_utils_test.cc | 104 --- .../core/profiler/utils/hlo_proto_map.cc | 172 ---- .../core/profiler/utils/hlo_proto_map.h | 67 +- .../profiler/utils/hlo_proto_to_module.cc | 55 -- .../core/profiler/utils/hlo_proto_to_module.h | 18 +- .../core/profiler/utils/host_offload_utils.cc | 199 ----- 
.../core/profiler/utils/host_offload_utils.h | 72 -- tensorflow/core/profiler/utils/html_utils.h | 17 +- .../core/profiler/utils/kernel_stats_utils.cc | 352 -------- .../core/profiler/utils/kernel_stats_utils.h | 117 +-- .../profiler/utils/kernel_stats_utils_test.cc | 175 ---- .../profiler/utils/op_metrics_db_utils.cc | 370 --------- .../core/profiler/utils/op_metrics_db_utils.h | 132 +-- .../utils/op_metrics_db_utils_test.cc | 220 ----- tensorflow/core/profiler/utils/op_utils.cc | 183 ----- tensorflow/core/profiler/utils/op_utils.h | 90 +- .../core/profiler/utils/step_intersection.cc | 305 ------- .../core/profiler/utils/step_intersection.h | 68 +- .../profiler/utils/step_intersection_test.cc | 260 ------ .../core/profiler/utils/tfstreamz_utils.cc | 132 --- .../core/profiler/utils/tfstreamz_utils.h | 23 +- .../profiler/utils/tpu_step_breakdown_utils.h | 59 +- .../profiler/utils/tpu_step_details_utils.h | 33 +- .../profiler/utils/xprof_gpu_cost_analysis.cc | 147 ---- .../profiler/utils/xprof_gpu_cost_analysis.h | 38 +- .../utils/xprof_gpu_cost_analysis_test.cc | 189 ----- tensorflow/workspace2.bzl | 6 +- third_party/xla/xla/tsl/lib/gtl/BUILD | 3 + third_party/xla/xla/tsl/lib/monitoring/BUILD | 4 +- 91 files changed, 325 insertions(+), 7767 deletions(-) delete mode 100644 tensorflow/core/profiler/utils/cost_utils.cc delete mode 100644 tensorflow/core/profiler/utils/derived_timeline.cc delete mode 100644 tensorflow/core/profiler/utils/derived_timeline_test.cc delete mode 100644 tensorflow/core/profiler/utils/device_caps_utils.cc delete mode 100644 tensorflow/core/profiler/utils/diagnostics.cc delete mode 100644 tensorflow/core/profiler/utils/event_span.cc delete mode 100644 tensorflow/core/profiler/utils/gpu_event_stats.cc delete mode 100644 tensorflow/core/profiler/utils/hardware_type_utils.cc delete mode 100644 tensorflow/core/profiler/utils/hardware_type_utils_test.cc delete mode 100644 tensorflow/core/profiler/utils/hlo_module_map.cc delete mode 100644 tensorflow/core/profiler/utils/hlo_module_utils_test.cc delete mode 100644 tensorflow/core/profiler/utils/hlo_proto_map.cc delete mode 100644 tensorflow/core/profiler/utils/hlo_proto_to_module.cc delete mode 100644 tensorflow/core/profiler/utils/host_offload_utils.cc delete mode 100644 tensorflow/core/profiler/utils/host_offload_utils.h delete mode 100644 tensorflow/core/profiler/utils/kernel_stats_utils.cc delete mode 100644 tensorflow/core/profiler/utils/kernel_stats_utils_test.cc delete mode 100644 tensorflow/core/profiler/utils/op_metrics_db_utils.cc delete mode 100644 tensorflow/core/profiler/utils/op_metrics_db_utils_test.cc delete mode 100644 tensorflow/core/profiler/utils/op_utils.cc delete mode 100644 tensorflow/core/profiler/utils/step_intersection.cc delete mode 100644 tensorflow/core/profiler/utils/step_intersection_test.cc delete mode 100644 tensorflow/core/profiler/utils/tfstreamz_utils.cc delete mode 100644 tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc delete mode 100644 tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 744b1df9e09c74..cc9215059ee06f 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -44,6 +44,8 @@ default_visibility = [ #internal nexus library tests, "//tensorflow/compiler/jit:__subpackages__", #internal library, + # TODO(matthurd): to be removed when summary.proto.h deps moves to TSL + "@org_xprof//xprof:__subpackages__", ] package( diff --git 
a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 5359e934e2bf01..6748b3a5e22e58 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -19,9 +19,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/utils:cost_utils", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "//tensorflow/core/profiler/utils:op_utils", "//tensorflow/core/profiler/utils:trace_utils", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_visitor", @@ -36,6 +33,9 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:timespan", "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@org_xprof//xprof/utils:cost_utils", + "@org_xprof//xprof/utils:op_metrics_db_utils", + "@org_xprof//xprof/utils:op_utils", ], ) @@ -49,6 +49,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:op_metrics_db_utils", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", @@ -57,6 +58,7 @@ tf_cc_test( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) @@ -67,11 +69,10 @@ cc_library( copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", - "//tensorflow/core/platform:protobuf", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) @@ -128,11 +129,11 @@ cc_library( "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:roofline_model_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", "@com_google_absl//absl/log:check", "@local_tsl//tsl/platform:protobuf", "@local_xla//xla/tsl/profiler/convert:xla_op_utils", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//xprof/utils:diagnostics", ], ) @@ -146,10 +147,12 @@ cc_library( "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "//tensorflow/core/profiler/protobuf:op_profile_proto_cc", "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hardware_types_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) @@ -172,17 +175,21 @@ cc_library( "//tensorflow/core/profiler/protobuf:power_metrics_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:hardware_type_utils", - "//tensorflow/core/profiler/utils:html_utils", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", 
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:format_utils", "@local_xla//xla/tsl/profiler/utils:math_utils", "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:input_pipeline_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:overview_page_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:power_metrics_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_function_proto_cc", + "@org_xprof//xprof/utils:diagnostics", + "@org_xprof//xprof/utils:hardware_type_utils", + "@org_xprof//xprof/utils:html_utils", + "@org_xprof//xprof/utils:kernel_stats_utils", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) @@ -197,12 +204,12 @@ cc_library( "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:event_span", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//xprof/utils:diagnostics", + "@org_xprof//xprof/utils:event_span", ], ) @@ -216,9 +223,9 @@ tf_cc_test( "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:event_span", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//xprof/utils:diagnostics", + "@org_xprof//xprof/utils:event_span", ], ) @@ -233,8 +240,8 @@ cc_library( "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", "//tensorflow/core/profiler/protobuf:pod_viewer_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", "@com_google_absl//absl/log:check", + "@org_xprof//xprof/utils:diagnostics", ], ) @@ -249,9 +256,9 @@ tf_cc_test( "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:event_span", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//xprof/utils:diagnostics", + "@org_xprof//xprof/utils:event_span", ], ) @@ -288,12 +295,6 @@ cc_library( "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "//tensorflow/core/profiler/protobuf:tpu_input_pipeline_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:html_utils", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "//tensorflow/core/profiler/utils:tpu_step_breakdown_utils", - "//tensorflow/core/profiler/utils:tpu_step_details_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -305,6 +306,16 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:math_utils", "@local_xla//xla/tsl/profiler/utils:tf_op_utils", "@local_xla//xla/tsl/util:stats_calculator_portable", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hardware_types_proto_cc", + 
"@org_xprof//plugin/tensorboard_plugin_profile/protobuf:input_pipeline_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tpu_input_pipeline_proto_cc", + "@org_xprof//xprof/utils:diagnostics", + "@org_xprof//xprof/utils:event_span", + "@org_xprof//xprof/utils:html_utils", + "@org_xprof//xprof/utils:op_metrics_db_utils", + "@org_xprof//xprof/utils:tpu_step_breakdown_utils", + "@org_xprof//xprof/utils:tpu_step_details_utils", ], ) @@ -316,10 +327,11 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", + "@local_tsl//tsl/platform:protobuf", "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//xprof/utils:event_span", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) @@ -334,9 +346,9 @@ cc_library( "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//xprof/utils:kernel_stats_utils", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) @@ -370,12 +382,12 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//xprof/utils:event_span", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) @@ -401,14 +413,7 @@ cc_library( "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:device_caps_utils", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:gpu_event_stats", - "//tensorflow/core/profiler/utils:hardware_type_utils", - "//tensorflow/core/profiler/utils:hlo_module_map", "//tensorflow/core/profiler/utils:hlo_proto_map", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:op_utils", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", @@ -423,6 +428,13 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:timespan", "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@org_xprof//xprof/utils:device_caps_utils", + "@org_xprof//xprof/utils:event_span", + "@org_xprof//xprof/utils:gpu_event_stats", + "@org_xprof//xprof/utils:hardware_type_utils", + "@org_xprof//xprof/utils:hlo_module_map", + "@org_xprof//xprof/utils:kernel_stats_utils", + "@org_xprof//xprof/utils:op_utils", ], ) @@ -439,11 +451,12 @@ cc_library( "//tensorflow/core:portable_gif_internal", "//tensorflow/core/platform:status", "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/utils:hardware_type_utils", - 
"//tensorflow/core/profiler/utils:step_intersection", "@com_google_absl//absl/status", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/platform:statusor", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//xprof/utils:hardware_type_utils", + "@org_xprof//xprof/utils:step_intersection", ], ) @@ -488,8 +501,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", "//tensorflow/core/profiler/utils:trace_utils", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_visitor", @@ -502,6 +513,8 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@org_xprof//xprof/utils:event_span", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) @@ -513,6 +526,7 @@ tf_cc_test( ":xplane_to_step_events", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:event_span", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", @@ -521,6 +535,7 @@ tf_cc_test( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:group_events", "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@org_xprof//xprof/utils:event_span", ], ) @@ -531,9 +546,6 @@ cc_library( copts = tf_profiler_copts(), deps = [ "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/utils:gpu_event_stats", - "//tensorflow/core/profiler/utils:hlo_module_map", - "//tensorflow/core/profiler/utils:kernel_stats_utils", "//tensorflow/core/profiler/utils:trace_utils", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/log", @@ -541,6 +553,9 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:tf_op_utils", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@org_xprof//xprof/utils:gpu_event_stats", + "@org_xprof//xprof/utils:hlo_module_map", + "@org_xprof//xprof/utils:kernel_stats_utils", ], ) @@ -554,12 +569,12 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/utils:kernel_stats_utils", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@org_xprof//xprof/utils:kernel_stats_utils", ], ) @@ -665,10 +680,10 @@ cc_library( "//tensorflow/core/profiler/protobuf:power_metrics_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "//tensorflow/core/profiler/protobuf:topology_proto_cc", - "//tensorflow/core/profiler/utils:hardware_type_utils", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:step_intersection", "@com_google_absl//absl/container:flat_hash_map", + "@org_xprof//xprof/utils:hardware_type_utils", + "@org_xprof//xprof/utils:kernel_stats_utils", + "@org_xprof//xprof/utils:step_intersection", ], ) @@ -683,8 +698,8 @@ tf_cc_test( 
"//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:step_intersection", "@com_google_absl//absl/container:flat_hash_map", + "@org_xprof//xprof/utils:step_intersection", ], ) @@ -695,13 +710,13 @@ cc_library( copts = tf_profiler_copts(), visibility = ["//tensorflow/core/profiler:internal"], deps = [ - "//tensorflow/core/profiler/utils:derived_timeline", "//tensorflow/core/profiler/utils:xplane_schema", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:group_events", "@local_xla//xla/tsl/profiler/utils:preprocess_xplane", "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@org_xprof//xprof/utils:derived_timeline", ], ) @@ -744,7 +759,6 @@ cc_library( "//tensorflow/core/profiler/protobuf:roofline_model_proto_cc", "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", - "//tensorflow/core/profiler/utils:hardware_type_utils", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/log", @@ -756,6 +770,7 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:timespan", "@org_xprof//xprof/convert/trace_viewer:trace_events_to_json", "@org_xprof//xprof/convert/trace_viewer:trace_viewer_visibility", + "@org_xprof//xprof/utils:hardware_type_utils", ], ) @@ -768,7 +783,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", - "//tensorflow/core/profiler/utils:html_utils", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/algorithm:container", @@ -781,6 +795,7 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:tf_op_utils", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//xprof/utils:html_utils", ], ) @@ -823,6 +838,7 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@org_xprof//xprof/utils:gpu_event_stats", ], ) @@ -916,7 +932,6 @@ cc_library( "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "//tensorflow/core/profiler/protobuf:op_profile_proto_cc", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log", @@ -924,6 +939,7 @@ cc_library( "@com_google_absl//absl/strings", "@local_xla//xla/tsl/profiler/convert:xla_op_utils", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) @@ -946,8 +962,8 @@ cc_library( "@local_xla//xla/tsl/platform:statusor", "//tensorflow/core/platform:statusor", "//tensorflow/core/profiler/utils:hlo_module_utils", - "//tensorflow/core/profiler/utils:hlo_proto_to_module", # copybara:uncomment "@com_github_nlohmann_json//:json", + "@org_xprof//xprof/utils:hlo_proto_to_module", ], ) @@ -984,7 +1000,6 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core/platform:errors", - "//tensorflow/core/profiler/utils:hlo_module_map", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", 
"@com_google_absl//absl/status:statusor", @@ -992,6 +1007,7 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/platform:statusor", "@local_xla//xla/tsl/profiler/utils:file_system_utils", + "@org_xprof//xprof/utils:hlo_module_map", ], ) @@ -1153,7 +1169,6 @@ cc_library( "//tensorflow/core/profiler/protobuf:topology_proto_cc", "//tensorflow/core/profiler/utils:hlo_module_utils", "//tensorflow/core/profiler/utils:hlo_proto_map", - "//tensorflow/core/profiler/utils:hlo_proto_to_module", "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -1173,6 +1188,7 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", "@local_xla//xla/tsl/profiler/utils:xplane_visitor", + "@org_xprof//xprof/utils:hlo_proto_to_module", ], ) @@ -1237,7 +1253,6 @@ cc_library( deps = [ "//tensorflow/core/lib/gtl:map_util", "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/utils:event_span", "//tensorflow/core/profiler/utils:xplane_schema", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -1256,6 +1271,7 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", "@local_xla//xla/tsl/profiler/utils:xplane_visitor", + "@org_xprof//xprof/utils:event_span", ], ) @@ -1312,7 +1328,6 @@ cc_library( ":repository", ":xplane_to_step_events", "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/utils:event_span", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/status", @@ -1324,6 +1339,8 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:inference_stats_proto_cc", + "@org_xprof//xprof/utils:event_span", ], ) diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc index 2f153dd850b2e9..87a2f123e7f917 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc +++ b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc @@ -44,7 +44,7 @@ limitations under the License. #include "xla/tsl/platform/errors.h" #include "tensorflow/core/profiler/convert/tool_options.h" #include "tensorflow/core/profiler/utils/hlo_module_utils.h" -#include "tensorflow/core/profiler/utils/hlo_proto_to_module.h" +#include "xprof/utils/hlo_proto_to_module.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/inference_stats.cc b/tensorflow/core/profiler/convert/inference_stats.cc index 25e87b31dab352..904e97b1615735 100644 --- a/tensorflow/core/profiler/convert/inference_stats.cc +++ b/tensorflow/core/profiler/convert/inference_stats.cc @@ -46,10 +46,10 @@ limitations under the License. 
#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tsl/platform/protobuf.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/inference_stats.h b/tensorflow/core/profiler/convert/inference_stats.h index 36b0aa600a2125..cc291fa9e336f4 100644 --- a/tensorflow/core/profiler/convert/inference_stats.h +++ b/tensorflow/core/profiler/convert/inference_stats.h @@ -25,8 +25,8 @@ limitations under the License. #include "xla/tsl/profiler/utils/group_events.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc index e01b645d3b19b6..38cfb2ea2ffc4e 100644 --- a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc @@ -26,9 +26,10 @@ limitations under the License. #include "tensorflow/core/profiler/convert/repository.h" #include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" -#include "tensorflow/core/profiler/utils/step_intersection.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "xprof/utils/hardware_type_utils.h" // from @org_xprof +#include "xprof/utils/step_intersection.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc b/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc index e1a466665492dc..f5cbb9b62a4b66 100644 --- a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc +++ b/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc @@ -33,10 +33,11 @@ limitations under the License. #include "tensorflow/core/profiler/convert/repository.h" #include "tensorflow/core/profiler/convert/xplane_to_step_events.h" #include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/inference_stats.pb.h" // from @org_xprof +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow::profiler { diff --git a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h b/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h index 8214921600efea..9d8399f2f43f62 100644 --- a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h +++ b/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h @@ -19,7 +19,8 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "tensorflow/core/profiler/convert/repository.h" #include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" +#include "plugin/tensorboard_plugin_profile/protobuf/inference_stats.pb.h" // from @org_xprof +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow::profiler { // Get non overlapped step events from xspace for GPU. diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.h b/tensorflow/core/profiler/convert/op_metrics_db_combiner.h index 76019da86cd467..d538a232e4f41e 100644 --- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.h +++ b/tensorflow/core/profiler/convert/op_metrics_db_combiner.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_ -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" +#include "tsl/platform/protobuf.h" +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { @@ -34,8 +34,8 @@ void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst, // Combines the memory access breakdown. void CombineMemoryAccessedBreakdown( - const protobuf::RepeatedPtrField& src, - protobuf::RepeatedPtrField* dst); + const tsl::protobuf::RepeatedPtrField& src, + tsl::protobuf::RepeatedPtrField* dst); // Helper to combine op metrics databases. class OpMetricsDbCombiner : public OpMetricsDbBuilder { diff --git a/tensorflow/core/profiler/convert/op_profile_builder.cc b/tensorflow/core/profiler/convert/op_profile_builder.cc index 0fde4660dcd496..fcb02c9227a938 100644 --- a/tensorflow/core/profiler/convert/op_profile_builder.cc +++ b/tensorflow/core/profiler/convert/op_profile_builder.cc @@ -34,8 +34,8 @@ limitations under the License. #include "tensorflow/core/profiler/convert/op_metrics_to_record.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/op_profile.pb.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" #include "tsl/platform/protobuf.h" +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.cc b/tensorflow/core/profiler/convert/op_stats_combiner.cc index 8879eac65c4f04..5c4b1bf08abb27 100644 --- a/tensorflow/core/profiler/convert/op_stats_combiner.cc +++ b/tensorflow/core/profiler/convert/op_stats_combiner.cc @@ -31,9 +31,9 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/power_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/protobuf/topology.pb.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/step_intersection.h" +#include "xprof/utils/hardware_type_utils.h" // from @org_xprof +#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof +#include "xprof/utils/step_intersection.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.h b/tensorflow/core/profiler/convert/op_stats_combiner.h index a8cb3c62c4087a..e2a8bf25db0556 100644 --- a/tensorflow/core/profiler/convert/op_stats_combiner.h +++ b/tensorflow/core/profiler/convert/op_stats_combiner.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/utils/step_intersection.h" +#include "xprof/utils/step_intersection.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc b/tensorflow/core/profiler/convert/op_stats_combiner_test.cc index cd5e97fe3c7e18..b4da91a61e1611 100644 --- a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc +++ b/tensorflow/core/profiler/convert/op_stats_combiner_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/step_intersection.h" +#include "xprof/utils/step_intersection.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc index e77a8a21cc9c73..956e7e46c8b34e 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc @@ -50,13 +50,15 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/html_utils.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h" -#include "tensorflow/core/profiler/utils/tpu_step_details_utils.h" #include "tsl/platform/protobuf.h" +#include "plugin/tensorboard_plugin_profile/protobuf/hardware_types.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/input_pipeline.pb.h" // from @org_xprof +#include "xprof/utils/diagnostics.h" // from @org_xprof +#include "xprof/utils/event_span.h" // from @org_xprof +#include "xprof/utils/html_utils.h" // from @org_xprof +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof +#include "xprof/utils/tpu_step_breakdown_utils.h" // from @org_xprof +#include "xprof/utils/tpu_step_details_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h index 53ebe189eaa324..79c874212d8da1 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h @@ -30,8 +30,11 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/protobuf/tpu_input_pipeline.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" #include "tsl/platform/protobuf.h" +#include "plugin/tensorboard_plugin_profile/protobuf/input_pipeline.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/tpu_input_pipeline.pb.h" // from @org_xprof +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc index 3b9cff76410794..663fc62ed80d83 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc @@ -22,8 +22,9 @@ limitations under the License. #include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" +#include "tsl/platform/protobuf.h" +#include "xprof/utils/event_span.h" // from @org_xprof +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc b/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc index 59ac8ca086bd4a..6e3119e1b3931d 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc @@ -26,7 +26,9 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/op_profile.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" +#include "plugin/tensorboard_plugin_profile/protobuf/hardware_types.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc index 73af4c71436627..f582933c782aeb 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc @@ -39,11 +39,14 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/power_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" -#include "tensorflow/core/profiler/utils/html_utils.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" +#include "plugin/tensorboard_plugin_profile/protobuf/overview_page.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/power_metrics.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h" // from @org_xprof +#include "xprof/utils/diagnostics.h" // from @org_xprof +#include "xprof/utils/hardware_type_utils.h" // from @org_xprof +#include "xprof/utils/html_utils.h" // from @org_xprof +#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc index 3735c2a188bc19..13fcef0ca25dec 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc @@ -28,8 +28,8 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/event_span.h" +#include "xprof/utils/diagnostics.h" // from @org_xprof +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc index 899b8ade54ca9e..90909301b11e10 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/event_span.h" +#include "xprof/utils/diagnostics.h" // from @org_xprof +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc index aad1e1ca79fd95..6dc5f8e870b2de 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h" #include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" +#include "xprof/utils/diagnostics.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc index 2273bce70fb228..d01fba32ee1cfb 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/event_span.h" +#include "xprof/utils/diagnostics.h" // from @org_xprof +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc index 58ebbc10ec9571..02066b0720c8ae 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc @@ -28,8 +28,8 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" #include "tsl/platform/protobuf.h" +#include "xprof/utils/diagnostics.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc index 841a7b58be9d4c..c5b88a817e3766 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" +#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc index 7d8f22914421a7..7a1fc581fd3104 100644 --- a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc +++ b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc @@ -20,8 +20,8 @@ limitations under the License. #include "xla/tsl/profiler/utils/group_events.h" #include "xla/tsl/profiler/utils/preprocess_xplane.h" #include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/derived_timeline.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "xprof/utils/derived_timeline.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/repository.h b/tensorflow/core/profiler/convert/repository.h index f6d4f78277d592..df90b16f5a8748 100644 --- a/tensorflow/core/profiler/convert/repository.h +++ b/tensorflow/core/profiler/convert/repository.h @@ -31,9 +31,9 @@ limitations under the License. #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/profiler/utils/file_system_utils.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" #include "tsl/platform/path.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/hlo_module_map.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc index 0d8f90bcbbbb76..9351d1c3b2baa3 100644 --- a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc +++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc @@ -30,8 +30,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" +#include "xprof/utils/event_span.h" // from @org_xprof +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.h b/tensorflow/core/profiler/convert/step_events_to_steps_db.h index 9764c46cfca6de..18c0f6a34819e2 100644 --- a/tensorflow/core/profiler/convert/step_events_to_steps_db.h +++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.h @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc index 733185a2747624..d2360ecefd3924 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc @@ -24,11 +24,11 @@ limitations under the License. #include "xla/tsl/profiler/utils/tf_op_utils.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/gpu_event_stats.h" // from @org_xprof +#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h index 57607d06337b52..ca6f98fd1515d3 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h @@ -21,10 +21,10 @@ limitations under the License. #include "absl/log/log.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/gpu_event_stats.h" // from @org_xprof +#include "xprof/utils/hlo_module_map.h" // from @org_xprof +#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc index a675e69248a81a..03429987be30a9 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index b216f95a4a2fcf..505ef66fb81591 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -39,13 +39,13 @@ limitations under the License. 
#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/convert/op_stack.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/cost_utils.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tensorflow/core/profiler/utils/op_utils.h" #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/cost_utils.h" // from @org_xprof +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof +#include "xprof/utils/op_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h index 06bcec66d136cb..337b4bee2cf27a 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/op_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/op_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc index c877fe8ec8d942..30fd0b1a5f26a7 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc @@ -27,11 +27,13 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index de75de4bc2b77f..3747047a2730a4 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -45,18 +45,18 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/device_caps_utils.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/op_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/device_caps_utils.h" // from @org_xprof +#include "xprof/utils/event_span.h" // from @org_xprof +#include "xprof/utils/gpu_event_stats.h" // from @org_xprof +#include "xprof/utils/hardware_type_utils.h" // from @org_xprof +#include "xprof/utils/hlo_module_map.h" // from @org_xprof +#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof +#include "xprof/utils/op_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc index e7debb44da8996..251dc6b72a9150 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events.cc @@ -34,12 +34,12 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/event_span.h" // from @org_xprof +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.h b/tensorflow/core/profiler/convert/xplane_to_step_events.h index 35580f95281589..7d343b746b9fff 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.h +++ b/tensorflow/core/profiler/convert/xplane_to_step_events.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ -#include "tensorflow/core/profiler/utils/event_span.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc index 2389f619edf3c5..d02c231659e353 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc @@ -22,11 +22,13 @@ limitations under the License. 
#include "xla/tsl/profiler/utils/group_events.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/event_span.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc index b653d7723f2310..dbe625dde2091f 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc @@ -30,11 +30,13 @@ limitations under the License. #include "xla/tsl/profiler/utils/math_utils.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/gpu_event_stats.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/gpu_event_stats.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc index ae7871569e3046..fafdcc386c1295 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc @@ -34,9 +34,9 @@ limitations under the License. #include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" -#include "tensorflow/core/profiler/utils/html_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" +#include "xprof/utils/html_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc index 60476773873064..816267f43362c0 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc @@ -66,13 +66,13 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" #include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" #include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tsl/platform/protobuf.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "xprof/convert/trace_viewer/trace_events_to_json.h" // from @org_xprof #include "xprof/convert/trace_viewer/trace_viewer_visibility.h" // from @org_xprof +#include "xprof/utils/hardware_type_utils.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc index 55977e2ed00833..c59b06d6bc049f 100644 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc +++ b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc @@ -46,10 +46,10 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/topology.pb.h" #include "tensorflow/core/profiler/utils/hlo_module_utils.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/hlo_proto_to_module.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tsl/platform/regexp.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/hlo_proto_to_module.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index 2826a9747c29d5..4122fdefbfe62f 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library") +load("//tensorflow:tensorflow.bzl", "tf_cuda_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") @@ -15,69 +15,61 @@ package_group( ], ) +# DO NOT ADD NEW DEPENDENCIES TO ANY TARGET IN THIS FILE. +# Instead, use //third_party/xprof/utils. 
+ cc_library( name = "diagnostics", - srcs = ["diagnostics.cc"], hdrs = ["diagnostics.h"], copts = tf_profiler_copts(), + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", + "@org_xprof//xprof/utils:diagnostics", ], ) cc_library( name = "event_span", - srcs = ["event_span.cc"], hdrs = ["event_span.h"], copts = tf_profiler_copts(), + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//xprof/utils:event_span", ], ) cc_library( name = "hardware_type_utils", - srcs = ["hardware_type_utils.cc"], hdrs = ["hardware_type_utils.h"], copts = tf_profiler_copts(), - deps = [ - ":xplane_schema", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -tf_cc_test( - name = "hardware_type_utils_test", - srcs = ["hardware_type_utils_test.cc"], deps = [ - ":hardware_type_utils", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//xprof/utils:hardware_type_utils", ], ) cc_library( name = "math_utils", hdrs = ["math_utils.h"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/service:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + "//perftools/accelerators/xprof/xprofilez/integration_tests:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/rpc:__pkg__", + ], deps = [ "@com_google_absl//absl/base:core_headers", "@local_xla//xla/tsl/profiler/utils:math_utils", @@ -87,64 +79,40 @@ cc_library( cc_library( name = "html_utils", hdrs = ["html_utils.h"], + visibility = [ + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - "@com_google_absl//absl/strings", + "@org_xprof//xprof/utils:html_utils", ], ) cc_library( name = "op_metrics_db_utils", - srcs = ["op_metrics_db_utils.cc"], hdrs = ["op_metrics_db_utils.h"], copts = tf_profiler_copts(), - deps = [ - ":xplane_visitor", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - 
"@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -tf_cc_test( - name = "op_metrics_db_utils_test", - srcs = ["op_metrics_db_utils_test.cc"], deps = [ - ":op_metrics_db_utils", - ":xplane_builder", - ":xplane_schema", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) cc_library( name = "op_utils", - srcs = ["op_utils.cc"], hdrs = ["op_utils.h"], copts = tf_profiler_copts(), + visibility = [ + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - ":hlo_module_map", - ":op_metrics_db_utils", - "//tensorflow/core/profiler/convert:op_metrics_db_combiner", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/platform:types", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//xprof/utils:op_utils", ], ) @@ -152,6 +120,12 @@ cc_library( name = "trace_utils", hdrs = ["trace_utils.h"], copts = tf_profiler_copts(), + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xprofilez/nvidia_gpu:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ "@local_xla//xla/tsl/profiler/utils:trace_utils", ], @@ -190,7 +164,11 @@ cc_library( testonly = True, hdrs = ["xplane_test_utils.h"], copts = tf_profiler_copts(), - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/db:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = ["@local_xla//xla/tsl/profiler/utils:xplane_test_utils"], ) @@ -206,275 +184,135 @@ cc_library( cc_library( name = "cost_utils", - srcs = ["cost_utils.cc"], hdrs = ["cost_utils.h"], copts = tf_profiler_copts(), - deps = [ - ":xplane_schema", - ":xplane_visitor", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler/costs:cost_estimator", - "//tensorflow/core/grappler/costs:op_context", - "//tensorflow/core/grappler/costs:op_level_cost_estimator", - "//tensorflow/core/grappler/costs:op_performance_data_cc", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -cc_library( - name = "host_offload_utils", - srcs = ["host_offload_utils.cc"], - hdrs = ["host_offload_utils.h"], - copts = tf_profiler_copts(), deps = [ - ":trace_utils", - ":xplane_builder", - ":xplane_schema", - ":xplane_visitor", - 
"//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_xla//xla:shape_util", - "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//xprof/utils:cost_utils", ], ) cc_library( name = "derived_timeline", - srcs = ["derived_timeline.cc"], hdrs = ["derived_timeline.h"], copts = tf_profiler_copts(), - visibility = [":friends"], - deps = [ - ":gpu_event_stats", - ":hlo_module_map", - ":hlo_proto_map", - ":host_offload_utils", - ":trace_utils", - ":xplane_builder", - ":xplane_schema", - ":xplane_utils", - ":xplane_visitor", - "//tensorflow/core:lib_internal", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:device_utils", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:trace_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - "@local_xla//xla/tsl/util:stats_calculator_portable", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + "//platforms/darwinn/tools/xprof_trace:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -tf_cc_test( - name = "derived_timeline_test", - srcs = ["derived_timeline_test.cc"], deps = [ - ":derived_timeline", - ":trace_utils", - ":xplane_builder", - ":xplane_schema", - ":xplane_test_utils", - ":xplane_visitor", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@org_xprof//xprof/utils:derived_timeline", ], ) cc_library( name = "kernel_stats_utils", - srcs = ["kernel_stats_utils.cc"], hdrs = ["kernel_stats_utils.h"], copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "kernel_stats_utils_test", - srcs = ["kernel_stats_utils_test.cc"], - deps = [ - ":kernel_stats_utils", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/backends/profiler/gpu:cupti_buffer_events", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + 
"//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], + deps = ["@org_xprof//xprof/utils:kernel_stats_utils"], ) cc_library( name = "tfstreamz_utils", - srcs = ["tfstreamz_utils.cc"], hdrs = ["tfstreamz_utils.h"], copts = tf_profiler_copts(), + visibility = ["//perftools/accelerators/xprof/xprofilez/cpu:__pkg__"], deps = [ - ":xplane_builder", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/framework:protos_all_cc", - "//tensorflow/core/profiler/protobuf:tfstreamz_proto_cc", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@org_xprof//xprof/utils:tfstreamz_utils", ], ) cc_library( name = "step_intersection", - srcs = ["step_intersection.cc"], hdrs = ["step_intersection.h"], copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/platform:types", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:timespan", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -tf_cc_test( - name = "step_intersection_test", - srcs = ["step_intersection_test.cc"], deps = [ - ":step_intersection", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/container:flat_hash_map", + "@org_xprof//xprof/utils:step_intersection", ], ) cc_library( name = "device_caps_utils", - srcs = ["device_caps_utils.cc"], hdrs = ["device_caps_utils.h"], copts = tf_profiler_copts(), - visibility = [":friends"], - deps = [ - ":xplane_builder", - ":xplane_schema", - ":xplane_visitor", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + visibility = [ + "//perftools/accelerators/xprof/xplane:__pkg__", + "//platforms/xla/tools:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], + deps = ["@org_xprof//xprof/utils:device_caps_utils"], ) cc_library( name = "gpu_event_stats", - srcs = ["gpu_event_stats.cc"], hdrs = ["gpu_event_stats.h"], copts = tf_profiler_copts(), - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - ":xplane_schema", - ":xplane_visitor", - "@com_google_absl//absl/strings", + "@org_xprof//xprof/utils:gpu_event_stats", ], ) cc_library( name = "hlo_proto_map", - srcs = ["hlo_proto_map.cc"], hdrs = ["hlo_proto_map.h"], - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + "//perftools/accelerators/xprof/xprofilez/integration_tests:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + "@org_xprof//xprof/convert/google:__pkg__", + ], deps = [ - ":xplane_schema", - ":xplane_utils", - ":xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", 
- "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@org_xprof//xprof/utils:hlo_proto_map", ], ) cc_library( name = "hlo_proto_to_module", - srcs = ["hlo_proto_to_module.cc"], hdrs = ["hlo_proto_to_module.h"], - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@local_xla//xla:util", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/platform:statusor", + "@org_xprof//xprof/utils:hlo_proto_to_module", ], ) tf_cuda_library( name = "hlo_module_map", - srcs = ["hlo_module_map.cc"], hdrs = ["hlo_module_map.h"], cuda_deps = [ "@local_xla//xla/service/gpu/model:gpu_hlo_cost_analysis", ], - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - ":hlo_module_utils", - ":hlo_proto_map", - ":hlo_proto_to_module", - "//tensorflow/core/platform:path", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/lib:traceme_encode", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla:shape_util", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_cost_analysis", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", + "@org_xprof//xprof/utils:hlo_module_map", ], ) @@ -482,77 +320,45 @@ cc_library( name = "hlo_module_utils", hdrs = ["hlo_module_utils.h"], visibility = [ - ":friends", - # copybara:uncomment "//tensorflow/compiler/mlir/lite/experimental/google/tooling/google:__subpackages__", + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/compiler/mlir/lite/experimental/google/tooling/hlo_adapter:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], deps = [ - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - ], -) - -tf_cc_test( - name = "hlo_module_utils_test", - srcs = ["hlo_module_utils_test.cc"], - deps = [ - ":hlo_module_utils", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tests:hlo_test_base", + "@org_xprof//xprof/utils:hlo_module_utils", ], ) cc_library( name = "xprof_gpu_cost_analysis", - srcs = ["xprof_gpu_cost_analysis.cc"], hdrs = ["xprof_gpu_cost_analysis.h"], - visibility = [":friends"], + visibility = ["//perftools/accelerators/xprof/convert:__pkg__"], deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla:shape_util", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_cost_analysis", - "@local_xla//xla/service/gpu:cublas_cudnn", - "@local_xla//xla/service/gpu/model:gpu_hlo_cost_analysis", - "@local_xla//xla/tsl/platform:errors", + 
"@org_xprof//xprof/utils:xprof_gpu_cost_analysis", ], ) cc_library( name = "tpu_step_breakdown_utils", hdrs = ["tpu_step_breakdown_utils.h"], - visibility = [":friends"], - deps = ["//tensorflow/core/profiler/protobuf:steps_db_proto_cc"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], + deps = [ + "@org_xprof//xprof/utils:tpu_step_breakdown_utils", + ], ) cc_library( name = "tpu_step_details_utils", hdrs = ["tpu_step_details_utils.h"], - visibility = [":friends"], - deps = ["//tensorflow/core/profiler/protobuf:tpu_input_pipeline_proto_cc"], -) - -tf_cc_test( - name = "xprof_gpu_cost_analysis_test", - srcs = ["xprof_gpu_cost_analysis_test.cc"], + visibility = [ + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - ":xprof_gpu_cost_analysis", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_xla//xla:shape_util", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/hlo/testlib:test_helpers", - "@local_xla//xla/service:hlo_cost_analysis", - "@local_xla//xla/tests:hlo_test_base", - "@local_xla//xla/tests:xla_internal_test_main", - "@local_xla//xla/tsl/platform:statusor", + "@org_xprof//xprof/utils:tpu_step_details_utils", ], ) diff --git a/tensorflow/core/profiler/utils/cost_utils.cc b/tensorflow/core/profiler/utils/cost_utils.cc deleted file mode 100644 index 8d44fd513d6e91..00000000000000 --- a/tensorflow/core/profiler/utils/cost_utils.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/cost_utils.h" - -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/strings/strip.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/grappler/costs/cost_estimator.h" -#include "tensorflow/core/grappler/costs/op_context.h" -#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -// Decode the string that encodes tensor shape and type information and convert -// to TensorProperties. -// Returns an empty TensorProperties if error or input is "". -// See OpKernel::TraceString() to see when the shape is encoded as "". -// Input format is [, ,...] 
-static OpInfo::TensorProperties GetTensorProperties(absl::string_view info) {
-  OpInfo::TensorProperties tensor_prop;
-  std::vector<absl::string_view> parts = absl::StrSplit(info, '[');
-  if (parts.size() != 2) return tensor_prop;
-  DataType data_type = DT_INVALID;
-  if (!DataTypeFromString(parts[0], &data_type)) return tensor_prop;
-  tensor_prop.set_dtype(data_type);
-  absl::ConsumeSuffix(&parts[1], "]");
-  if (parts[1].empty()) {  // Scalar type.
-    tensor_prop.mutable_shape()->add_dim()->set_size(1);
-    return tensor_prop;
-  }
-  std::vector<absl::string_view> dims = absl::StrSplit(parts[1], ',');
-  for (const auto dim : dims) {
-    int size;
-    if (!absl::SimpleAtoi(dim, &size)) return OpInfo::TensorProperties();
-    tensor_prop.mutable_shape()->add_dim()->set_size(size);
-  }
-  return tensor_prop;
-}
-
-}  // namespace
-
-TfOpRoofLineCostEstimator::~TfOpRoofLineCostEstimator() {
-  if (!unsupported_ops_.empty()) {
-    LOG(ERROR) << "Unsupported ops for Roofline Cost Analysis are:"
-               << absl::StrJoin(unsupported_ops_, ",");
-  }
-}
-
-grappler::DeviceInfo TfOpRoofLineCostEstimator::GetDeviceInfo(
-    const DeviceProperties& device) const {
-  // A hypothetical device that is used to measure peak flops and memory bytes
-  // accessed.
-  return grappler::DeviceInfo(/*gigaops=*/1, /*gb_per_sec=*/1);
-}
-
-TfOpRoofLineCostEstimator::OpRoofLineStats TfOpRoofLineCostEstimator::Predict(
-    const XEventVisitor& event) {
-  tsl::profiler::TfOp tf_op;
-  absl::string_view tensor_shapes;
-  event.ForEachStat([&](const XStatVisitor& stat) {
-    if (!stat.Type().has_value()) return;
-    switch (stat.Type().value()) {
-      case StatType::kTfOp:
-        tf_op = tsl::profiler::ParseTfOpFullname(stat.StrOrRefValue());
-        break;
-      case StatType::kTensorShapes:
-        tensor_shapes = stat.StrOrRefValue();
-        break;
-    }
-  });
-
-  // Return empty OpRoofLineStats if the shape is not traced or this is not a
-  // TF op.
-  if (tf_op.type.empty() || tensor_shapes.empty()) {
-    return {0ULL, 0ULL, /*inaccurate=*/true};
-  }
-
-  grappler::OpContext op_context;
-  op_context.name = std::string(tf_op.type);
-  op_context.op_info.set_op(op_context.name);
-  for (absl::string_view tensor :
-       tsl::profiler::ParseTensorShapes(tensor_shapes)) {
-    *op_context.op_info.add_inputs() = GetTensorProperties(tensor);
-  }
-  grappler::Costs costs = PredictCosts(op_context);
-  if (costs.inaccurate) unsupported_ops_.insert(std::string(tf_op.type));
-
-  VLOG(1) << tf_op.type << tensor_shapes
-          << " flops:" << costs.compute_time.count()
-          << " bytes:" << costs.memory_time.count();
-
-  /* The compute_time is measured in nanoseconds, therefore numerically it is
-   * equal to flops because giga ops / second cancel the nanoseconds.
-   * Same for memory_time. */
-  return {/*flops=*/static_cast<uint64>(costs.compute_time.count()),
-          /*bytes_accessed=*/static_cast<uint64>(costs.memory_time.count()),
-          /*inaccurate=*/costs.inaccurate};
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/cost_utils.h b/tensorflow/core/profiler/utils/cost_utils.h
index 01a6540d8145cb..b7f139ffdb5915 100644
--- a/tensorflow/core/profiler/utils/cost_utils.h
+++ b/tensorflow/core/profiler/utils/cost_utils.h
@@ -15,45 +15,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_ -#include - -#include "absl/container/flat_hash_set.h" -#include "tensorflow/core/grappler/costs/cost_estimator.h" -#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -// This is a wrapper of tensorflow::grappler::OpLevelCostEstimator and use -// tracing time information to estimate the roof line stats for each traced -// tensorflow op. -class TfOpRoofLineCostEstimator - : public tensorflow::grappler::OpLevelCostEstimator { - public: - TfOpRoofLineCostEstimator() = default; - ~TfOpRoofLineCostEstimator() override; - - grappler::DeviceInfo GetDeviceInfo( - const DeviceProperties& device) const override; - - struct OpRoofLineStats { - uint64 flops = 0LL; - uint64 bytes_accessed = 0LL; - bool inaccurate = false; - }; - OpRoofLineStats Predict(const XEventVisitor& event); - - private: - absl::flat_hash_set - unsupported_ops_; // summary for unsupported ops. - - TfOpRoofLineCostEstimator(const TfOpRoofLineCostEstimator&) = delete; - void operator=(const TfOpRoofLineCostEstimator&) = delete; -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/cost_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc deleted file mode 100644 index 721c283c7dda7c..00000000000000 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ /dev/null @@ -1,772 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/utils/derived_timeline.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/device_utils.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/trace_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "xla/tsl/util/stats_calculator.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/host_offload_utils.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tsl::profiler::DeviceType; -using ::tsl::profiler::FindMutableTensorCorePlanes; -using ::tsl::profiler::GetDeviceType; - -inline std::string HloModuleEventName(const GpuEventStats& stats) { - return stats.program_id ? tsl::profiler::HloModuleNameWithProgramId( - stats.hlo_module_name, *stats.program_id) - : std::string(stats.hlo_module_name); -} - -// Returns a prefix that uniquely identifies the HLO module. -inline std::string HloOpEventPrefix(const GpuEventStats& stats) { - return stats.program_id ? absl::StrCat(*stats.program_id, "/") - : absl::StrCat(stats.hlo_module_name, "/"); -} - -std::vector GetOrCreateHloOpEventsMetadata( - XPlaneBuilder& xplane, const GpuEventStats& stats, const Symbol symbol) { - DCHECK(stats.IsXlaOp()); - std::vector hlo_op_events_metadata; - hlo_op_events_metadata.reserve(stats.hlo_op_names.size()); - // Prepend an HLO module identifier so HLO operators with the same name but in - // different modules have different metadata. - std::string hlo_op_event_prefix = HloOpEventPrefix(stats); - for (absl::string_view hlo_op_name : stats.hlo_op_names) { - XEventMetadata* hlo_op_event_metadata = xplane.GetOrCreateEventMetadata( - absl::StrCat(hlo_op_event_prefix, hlo_op_name)); - // Display the HLO name without the module name in tools. 
-    if (hlo_op_event_metadata->display_name().empty()) {
-      hlo_op_event_metadata->set_display_name(std::string(hlo_op_name));
-    }
-    hlo_op_events_metadata.push_back(hlo_op_event_metadata);
-    if (!symbol.hlo_text.empty()) {
-      XStatsBuilder<XEventMetadata> event_stats(hlo_op_event_metadata, &xplane);
-      event_stats.SetOrAddStatValue(*xplane.GetOrCreateStatMetadata("hlo_text"),
-                                    symbol.hlo_text);
-    }
-  }
-  return hlo_op_events_metadata;
-}
-
-// Get the derived line id for a given derived line in a group which starts
-// from first_derived_line_id.
-// According to the definition in trace_utils.h, the derived lines are
-// kThreadIdTfNameScope to kThreadIdSource. Keep the line id sequence in each
-// group the same as in the original group.
-inline int64_t GetDerivedLineId(int64_t first_derived_line_id,
-                                int64_t target_line_id) {
-  return first_derived_line_id + (target_line_id - kThreadIdTfNameScope);
-}
-
-// Get the derived line name for a given derived line in a group which starts
-// from first_derived_line_id.
-std::string GetDerivedLineName(int64_t first_derived_line_id,
-                               int64_t target_line_id,
-                               absl::Span<const int64_t> source_line_ids) {
-  int64_t offset = target_line_id - kThreadIdTfNameScope;
-  std::string suffix;
-  if (first_derived_line_id != kThreadIdTfNameScope &&
-      !source_line_ids.empty()) {
-    suffix = absl::StrCat(" - from #", source_line_ids[0]);
-  }
-  switch (offset) {
-    case kThreadIdTfNameScope - kThreadIdTfNameScope:
-      return absl::StrCat(kTensorFlowNameScopeLineName, suffix);
-    case kThreadIdHloOp - kThreadIdTfNameScope:
-      return absl::StrCat(kXlaOpLineName, suffix);
-    case kThreadIdHloModule - kThreadIdTfNameScope:
-      return absl::StrCat(kXlaModuleLineName, suffix);
-    case kThreadIdTfOp - kThreadIdTfNameScope:
-      return absl::StrCat(kTensorFlowOpLineName, suffix);
-    case kThreadIdSource - kThreadIdTfNameScope:
-      return absl::StrCat(kSourceLineName, suffix);
-    default:
-      LOG(ERROR) << "Invalid target line id: " << target_line_id
-                 << " for first_derived_line_id: " << first_derived_line_id;
-      return absl::StrCat("UnknownDerived#", first_derived_line_id + offset);
-  }
-}
-
-// Derive events from the given line ids using annotations.
-// Returns the derived line ids in the order tf_name_scope, tf_op, hlo_module,
-// hlo_op, source, where the derived line id for tf_name_scope is
-// first_derived_line_id.
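-// E.g. (hypothetical ids): if first_derived_line_id == kThreadIdTfNameScope,
-// the returned ids are the standard kThreadIdTfNameScope..kThreadIdSource
-// block; for a second group starting at id N, the hlo_op line gets
-// N + (kThreadIdHloOp - kThreadIdTfNameScope).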
-std::vector DeriveEventsFromAnnotationsForLines( - const SymbolResolver& symbol_resolver, XPlane* device_trace, - absl::Span line_ids, int64_t first_derived_line_id, - const ScopeRangeIdTree* scope_range_id_tree = nullptr) { - XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(device_trace); - - XPlaneBuilder plane_builder(device_trace); - int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace); - DerivedXLineBuilder tf_ops( - &plane_builder, GetDerivedLineId(first_derived_line_id, kThreadIdTfOp), - GetDerivedLineName(first_derived_line_id, kThreadIdTfOp, line_ids), - start_timestamp_ns, {}); - DerivedXLineBuilder tf_name_scope( - &plane_builder, - GetDerivedLineId(first_derived_line_id, kThreadIdTfNameScope), - GetDerivedLineName(first_derived_line_id, kThreadIdTfNameScope, line_ids), - start_timestamp_ns, {&tf_ops}); - DerivedXLineBuilder hlo_ops( - &plane_builder, GetDerivedLineId(first_derived_line_id, kThreadIdHloOp), - GetDerivedLineName(first_derived_line_id, kThreadIdHloOp, line_ids), - start_timestamp_ns, {}); - DerivedXLineBuilder hlo_modules( - &plane_builder, - GetDerivedLineId(first_derived_line_id, kThreadIdHloModule), - GetDerivedLineName(first_derived_line_id, kThreadIdHloModule, line_ids), - start_timestamp_ns, {&tf_name_scope, &hlo_ops}); - DerivedXLineBuilder source( - &plane_builder, GetDerivedLineId(first_derived_line_id, kThreadIdSource), - GetDerivedLineName(first_derived_line_id, kThreadIdSource, line_ids), - start_timestamp_ns, {}); - - // Declare this vector here so that its memory will be reused during the loop, - // instead of being allocated and deallocated for each iteration. - std::vector> level_range_ids; - for (const XEventVisitor& event : - GetSortedEvents(plane_visitor, false, line_ids)) { - GpuEventStats stats(&event); - // For HLO/TF op lines, only use kernel events (i.e. excluding memcpy or - // allocation events). Also CudaGraph executions are also treated as - // kernel events. - if (!stats.IsKernel() && !stats.IsCudaGraphExecution()) continue; - tsl::profiler::Timespan event_span = event.GetTimespan(); - - if ((!stats.hlo_module_name.empty() || stats.IsXlaOp())) { - level_range_ids.clear(); - if (stats.scope_range_id.has_value()) { - level_range_ids.push_back(stats.scope_range_id); - if (scope_range_id_tree) { - for (auto it = scope_range_id_tree->find(*stats.scope_range_id); - it != scope_range_id_tree->end(); - it = scope_range_id_tree->find(it->second)) { - level_range_ids.push_back(it->second); - } - } - } - // Now, level_range_ids looks like: - // [child_level_n, child_level_n-1, ..., child_level_1, root_level] - } - - if (!stats.hlo_module_name.empty()) { - // back() of the level_range_ids, i.e. root_level in above comment, - // is the scope range id of HLO module. - hlo_modules.ExpandOrAddEvent( - *plane_builder.GetOrCreateEventMetadata(HloModuleEventName(stats)), - event_span, stats.group_id, - level_range_ids.empty() ? std::nullopt : level_range_ids.back()); - } - - if (stats.IsXlaOp()) { - auto symbol = symbol_resolver(stats.program_id, stats.hlo_module_name, - stats.hlo_op_names.back()); - auto hlo_events_metadata = - GetOrCreateHloOpEventsMetadata(plane_builder, stats, symbol); - // level_range_ids, if not empty, should be of same size as - // hlo_events_metadata. If not of same size, do not use those ids. 
-      absl::Span<const std::optional<int64_t>> xla_op_level_range_ids = {};
-      if (level_range_ids.size() == hlo_events_metadata.size()) {
-        std::reverse(level_range_ids.begin(), level_range_ids.end());
-        // After the reverse, level_range_ids looks like:
-        // [root_level, child_level_1, ..., child_level_n-1, child_level_n]
-        xla_op_level_range_ids = absl::MakeSpan(level_range_ids);
-      }
-      hlo_ops.ExpandOrAddEvents(hlo_events_metadata, event_span, stats.group_id,
-                                xla_op_level_range_ids);
-
-      // If the kernel event is a node of a CUDA graph or a whole CUDA graph
-      // exec, try to attach extra stats to the corresponding XLA op event here.
-      if (stats.cuda_graph_id_for_inner_node.has_value() &&
-          *stats.cuda_graph_id_for_inner_node != 0) {
-        int level = static_cast<int>(hlo_events_metadata.size()) - 1;
-        if (level >= 0) {
-          hlo_ops.AddStatToLevelEvent(level, *hlo_ops.GetCudaGraphIdMetadata(),
-                                      *stats.cuda_graph_id_for_inner_node);
-          if (stats.correlation_id.has_value()) {
-            hlo_ops.AddStatToLevelEvent(level,
-                                        *hlo_ops.GetCorrelationIdMetadata(),
-                                        *stats.correlation_id);
-          }
-        }
-      }
-
-      if (!symbol.tf_op_name.empty()) {
-        ProcessTfOpEvent(symbol.tf_op_name, event_span, stats.group_id,
-                         plane_builder, tf_name_scope, tf_ops);
-      }
-      if (!symbol.source_info.empty()) {
-        source.ExpandOrAddEvent(
-            *plane_builder.GetOrCreateEventMetadata(symbol.source_info),
-            event_span, stats.group_id);
-      }
-    } else if (stats.IsTfOp()) {
-      ProcessTfOpEvent(stats.tf_op_fullname, event_span, stats.group_id,
-                       plane_builder, tf_name_scope, tf_ops);
-    }
-  }
-  return {tf_name_scope.Line().Id(), tf_ops.Line().Id(),
-          hlo_modules.Line().Id(), hlo_ops.Line().Id(), source.Line().Id()};
-}
-
-}  // namespace
-
-void ProcessTfOpEvent(absl::string_view tf_op_full_name,
-                      tsl::profiler::Timespan event_span,
-                      std::optional<int64_t> group_id,
-                      XPlaneBuilder& plane_builder,
-                      DerivedXLineBuilder& tf_name_scope_line_builder,
-                      DerivedXLineBuilder& tf_op_line_builder) {
-  tsl::profiler::TfOp tf_op = tsl::profiler::ParseTfOpFullname(tf_op_full_name);
-  tsl::profiler::Category category = tf_op.category;
-  if (category == tsl::profiler::Category::kTensorFlow ||
-      category == tsl::profiler::Category::kJax) {
-    tf_name_scope_line_builder.ExpandOrAddEvents(
-        plane_builder.GetOrCreateEventsMetadata(
-            tsl::profiler::ParseTfNameScopes(tf_op)),
-        event_span, group_id);
-  }
-  XEventMetadata* tf_op_event_metadata =
-      plane_builder.GetOrCreateEventMetadata(tf_op_full_name);
-  // Set the display name to op_type so that the events of the same op_type
-  // have the same color in the trace viewer.
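-  // E.g. (from the tests below) the full name "mul:Mul" gets the display name
-  // "Mul", so every Mul event shares one color.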
- if (tf_op_event_metadata->display_name().empty()) { - tf_op_event_metadata->set_display_name(tsl::profiler::TfOpEventName(tf_op)); - } - tf_op_line_builder.ExpandOrAddEvent(*tf_op_event_metadata, event_span, - group_id); -} - -DerivedXEventBuilder::DerivedXEventBuilder( - XEventBuilder event, std::optional group_id, - std::optional scope_range_id) - : event_(std::move(event)), - group_id_(group_id), - scope_range_id_(scope_range_id) {} - -bool DerivedXEventBuilder::ShouldExpand( - const XEventMetadata& event_metadata, std::optional group_id, - std::optional scope_range_id) const { - return event_.MetadataId() == event_metadata.id() && group_id_ == group_id && - (!scope_range_id.has_value() || !scope_range_id_.has_value() || - scope_range_id_ == scope_range_id); -} - -void DerivedXEventBuilder::Expand(tsl::profiler::Timespan event_span) { - tsl::profiler::Timespan timespan = event_.GetTimespan(); - DCHECK_LE(timespan.begin_ps(), event_span.begin_ps()); - timespan.ExpandToInclude(event_span); - event_.SetTimespan(timespan); -} - -DerivedXLineBuilder::DerivedXLineBuilder( - XPlaneBuilder* plane, int64_t line_id, absl::string_view name, - int64_t timestamp_ns, std::vector dependent_lines) - : group_id_stat_metadata_( - plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))), - correlation_id_metadata_(plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCorrelationId))), - cuda_graph_id_metadata_(plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCudaGraphId))), - line_(plane->GetOrCreateLine(line_id)), - dependent_lines_(std::move(dependent_lines)) { - line_.SetName(name); - line_.SetTimestampNs(timestamp_ns); - is_gpu_plane_ = GetDeviceType(plane->Name()) == DeviceType::kGpu; -} - -void DerivedXLineBuilder::ExpandOrAddEvent( - const XEventMetadata& event_metadata, tsl::profiler::Timespan event_span, - std::optional group_id, std::optional scope_range_id) { - ExpandOrAddLevelEvent(event_metadata, event_span, group_id, scope_range_id, - /*level=*/0); -} - -void DerivedXLineBuilder::ExpandOrAddEvents( - const std::vector& events_metadata_per_level, - tsl::profiler::Timespan event_span, std::optional group_id, - absl::Span> scope_range_ids) { - if (events_metadata_per_level.empty()) return; - - size_t current_nested_level = events_metadata_per_level.size(); - for (size_t level = 0; level < current_nested_level; ++level) { - ExpandOrAddLevelEvent( - *events_metadata_per_level[level], event_span, group_id, - level < scope_range_ids.size() ? scope_range_ids[level] : std::nullopt, - level); - } - ResetLastEvents(current_nested_level); -} - -void DerivedXLineBuilder::ExpandOrAddLevelEvent( - const XEventMetadata& event_metadata, tsl::profiler::Timespan event_span, - std::optional group_id, std::optional scope_range_id, - int level) { - auto& last_event = last_event_by_level_[level]; - // If group_id is not set and we still choose to expand, put an extra check: - // Expand only if the gap between the last event and the new event is less - // than 2 * duration of the last event. - // TODO: b/373944719 - add the extra node_id check for GPU profiles. - if (last_event.has_value() && - last_event->ShouldExpand(event_metadata, group_id, scope_range_id) && - (is_gpu_plane_ || group_id.has_value() || - (last_event->GetTimespan().end_ps() + - 2 * last_event->GetTimespan().duration_ps()) >= - event_span.begin_ps())) { - // Expand the last event to cover the given event. 
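-    // (E.g. with the gap rule above, a last event spanning [0ps, 100ps) can
-    // absorb a new event that begins no later than 300ps: end 100 + 2 * 100.)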
-    last_event->Expand(event_span);
-  } else {
-    // Otherwise, reset the last events lower than or equal to the given level.
-    ResetLastEvents(level);
-    // And create a new event for the given level.
-    XEventBuilder event = line_.AddEvent(event_metadata);
-    event.SetTimespan(event_span);
-    if (group_id.has_value()) {
-      event.AddStatValue(*group_id_stat_metadata_, *group_id);
-    }
-    last_event.emplace(std::move(event), group_id, scope_range_id);
-  }
-}
-
-void DerivedXLineBuilder::AddStatToLevelEvent(int level,
-                                              const XStatMetadata& metadata,
-                                              int64_t value) {
-  if (auto it = last_event_by_level_.find(level);
-      it != last_event_by_level_.end() && it->second.has_value()) {
-    it->second->SetOrAddStatValue(metadata, value);
-  }
-}
-
-void DerivedXLineBuilder::AddStatToLevelEvent(int level,
-                                              const XStatMetadata& metadata,
-                                              uint64_t value) {
-  if (auto it = last_event_by_level_.find(level);
-      it != last_event_by_level_.end() && it->second.has_value()) {
-    it->second->SetOrAddStatValue(metadata, value);
-  }
-}
-
-// When deriving a bunch of events with the same timespan, the trace viewer may
-// stack these events nondeterministically. This function shrinks events in
-// such a stack when necessary; an event at the top of the stack may shrink
-// more than an event at the bottom. Because the time unit in the trace viewer
-// is a nanosecond, the minimum difference is 1ns. However, to prevent
-// shrink-induced inconsistency, we cannot shrink by more than the duration of
-// the event at the top of the stack.
-void DerivedXLineBuilder::AdjustDurationForTraceViewer(int level) {
-  if (level >= last_event_by_level_.size() || !last_event_by_level_[level])
-    return;
-
-  int max_level = level;
-  for (; max_level < last_event_by_level_.size(); ++max_level) {
-    if (!last_event_by_level_[max_level].has_value()) {
-      break;
-    }
-  }
-  --max_level;
-  if (max_level <= level) return;
-  auto& event_on_top_stack = *last_event_by_level_[max_level];
-  tsl::profiler::Timespan timespan = event_on_top_stack.GetTimespan();
-  // We will at most shrink the top of the stack to 1ns.
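-  // E.g. (hypothetical): five stacked events that all span [0ps, 5000ps) end
-  // at 5000, 4000, 3000, 2000, and 1000 ps after adjustment; the bottom event
-  // keeps its span and the top event shrinks the most.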
- int64_t max_shrink_ns = timespan.duration_ps() / 1000 - 1; - int64_t shrink_ns = 0; - std::optional last_level_timespan; - for (int i = level; i <= max_level; ++i) { - auto& current_event = *last_event_by_level_[i]; - if (shrink_ns < max_shrink_ns && - last_level_timespan == current_event.GetTimespan()) { - shrink_ns++; - } - last_level_timespan = current_event.GetTimespan(); - if (shrink_ns) { - current_event.SetTimespan(tsl::profiler::Timespan::FromEndPoints( - last_level_timespan->begin_ps(), - last_level_timespan->end_ps() - 1000 * shrink_ns)); - } - } -} - -void DerivedXLineBuilder::ResetLastEvents(int level) { - AdjustDurationForTraceViewer(level); - for (int i = level, end = last_event_by_level_.size(); i < end; ++i) { - last_event_by_level_[i].reset(); - } - if (level == 0) { - for (DerivedXLineBuilder* line : dependent_lines_) { - line->ResetLastEvents(0); - } - } -} - -void DeriveStepEventsFromGroups( - const tsl::profiler::GroupMetadataMap& group_metadata_map, - XPlane* device_trace) { - XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(device_trace); - const XStatMetadata* group_id_stat_metadata = - plane_visitor.GetStatMetadataByType(StatType::kGroupId); - if (group_id_stat_metadata == nullptr) return; - XPlaneBuilder plane_builder(device_trace); - int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace); - DerivedXLineBuilder steps(&plane_builder, kThreadIdStepInfo, kStepLineName, - start_timestamp_ns, {}); - for (const XEventVisitor& event_visitor : - GetSortedEvents(plane_visitor)) { - std::optional group_id_stat = - event_visitor.GetStat(StatType::kGroupId, *group_id_stat_metadata); - if (group_id_stat.has_value()) { - int64_t group_id = group_id_stat->IntValue(); - steps.ExpandOrAddEvent( - *plane_builder.GetOrCreateEventMetadata(absl::StrCat(group_id)), - event_visitor.GetTimespan(), group_id); - } - } - AddGroupMetadataToStepEvents(group_metadata_map, steps.Line()); -} - -void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver, - XPlane* device_trace, - const ScopeRangeIdTree* scope_range_id_tree) { - if (tsl::profiler::GetDeviceType(*device_trace) != - tsl::profiler::DeviceType::kGpu) { - DeriveEventsFromAnnotationsForLines(symbol_resolver, device_trace, {}, - kThreadIdTfNameScope); - } else { - // TODO: Currently we derive events only from the line with the most number - // of events. We should consider deriving events from all lines in the - // future, also then we need to utilize the derived relation provided by - // DeriveEventsFromAnnotationsForLines(), and find solid way to sort all - // lines. 
- int64_t line_id_with_most_events = -1; - int64_t max_num_events_per_line = -1; - { - XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(device_trace); - plane_visitor.ForEachLine([&](const XLineVisitor& line) { - if (IsDerivedThreadId(line.Id())) return; - int num_events = line.NumEvents(); - // make sure strong ordering - if (num_events > max_num_events_per_line || - (num_events == max_num_events_per_line && - line.Id() < line_id_with_most_events)) { - max_num_events_per_line = num_events; - line_id_with_most_events = line.Id(); - } - }); - } - - if (line_id_with_most_events >= 0) { - DeriveEventsFromAnnotationsForLines( - symbol_resolver, device_trace, {line_id_with_most_events}, - kThreadIdTfNameScope, scope_range_id_tree); - } - } - RemoveEmptyLines(device_trace); -} - -void DeriveEventsFromHostTrace( - const XPlane* host_trace, - const tsl::profiler::GroupMetadataMap& group_metadata_map, - std::vector device_traces) { - struct GroupLaunchInfo { // "Group" normally means step. - tsl::profiler::Timespan timespan; - tsl::Stat stat; - - void AddEventTimespan(tsl::profiler::Timespan event_span) { - if (stat.count() == 0) { - timespan = event_span; - } else { - timespan.ExpandToInclude(event_span); - } - stat.UpdateStat(event_span.duration_ps()); - } - }; - using DeviceLaunchInfo = - absl::flat_hash_map; - - const int num_devices = device_traces.size(); - std::vector per_device_launch_info(num_devices); - - XPlaneVisitor host_plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - host_plane.ForEachLine([&](const XLineVisitor& line) { - if (IsDerivedThreadId(line.Id())) return; - line.ForEachEvent([&](const XEventVisitor& event) { - // Filter out API calls for cuEventRecord/cuEventQuery/cuCtxSynchronize - // etc for now. TODO: find a better way to filter out only the memcpy and - // kernel launch events. - if (absl::StartsWith(event.Name(), "cu")) return; - LaunchEventStats stats(&event); - if (stats.group_id.has_value() && stats.IsLaunch() && - 0 <= *stats.device_id && *stats.device_id < num_devices) { - // This is a launch event on a known device. 
- GroupLaunchInfo& group_launch_info = - per_device_launch_info[*stats.device_id][*stats.group_id]; - group_launch_info.AddEventTimespan(event.GetTimespan()); - } - }); - }); - - int64_t host_plane_start = GetStartTimestampNs(*host_trace); - for (int i = 0; i < num_devices; ++i) { - if (per_device_launch_info[i].empty()) continue; - int64_t device_plane_start = GetStartTimestampNs(*device_traces[i]); - - XPlaneBuilder device_plane(device_traces[i]); - const XStatMetadata& group_id_stat_metadata = - *device_plane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kGroupId)); - const XStatMetadata& num_launches_stat_metadata = - *device_plane.GetOrCreateStatMetadata("num_launches"); - const XStatMetadata& max_launch_time_us_stat_metadata = - *device_plane.GetOrCreateStatMetadata("max_launch_time_us"); - const XStatMetadata& avg_launch_time_us_stat_metadata = - *device_plane.GetOrCreateStatMetadata("avg_launch_time_us"); - - XLineBuilder launch_line = - device_plane.GetOrCreateLine(kThreadIdKernelLaunch); - launch_line.SetName(kKernelLaunchLineName); - launch_line.SetTimestampNs(std::min(device_plane_start, host_plane_start)); - for (const auto& kv : per_device_launch_info[i]) { - int64_t group_id = kv.first; - const GroupLaunchInfo& group_info = kv.second; - if (const tsl::profiler::GroupMetadata* group_metadata = - gtl::FindOrNull(group_metadata_map, group_id)) { - XEventBuilder device_event = - launch_line.AddEvent(*device_plane.GetOrCreateEventMetadata( - absl::StrCat("Launch Stats for ", group_metadata->name))); - device_event.SetTimespan(group_info.timespan); - device_event.AddStatValue(group_id_stat_metadata, group_id); - device_event.AddStatValue(num_launches_stat_metadata, - group_info.stat.count()); - device_event.AddStatValue( - max_launch_time_us_stat_metadata, - tsl::profiler::PicoToMicro(group_info.stat.max())); - device_event.AddStatValue( - avg_launch_time_us_stat_metadata, - tsl::profiler::PicoToMicro(group_info.stat.avg())); - } - } - } -} - -void GenerateDerivedTimeLines( - const tsl::profiler::GroupMetadataMap& group_metadata_map, XSpace* space) { - HloModuleMap hlo_module_map; - { - HloProtoMap hlo_proto_map; - hlo_proto_map.AddHloProtosFromXSpace(*space); - for (const auto& [program_id, hlo_proto] : hlo_proto_map) { - AddHloProto(hlo_module_map, program_id, *hlo_proto); - } - } - - auto symbol_resolver = [&](absl::optional program_id, - absl::string_view hlo_module, - absl::string_view hlo_op) -> Symbol { - Symbol output; - const auto* hlo_instruction = - GetHloInstruction(hlo_module_map, program_id, hlo_op); - if (hlo_instruction != nullptr) { - output.tf_op_name = hlo_instruction->op_full_name(); - output.source_info = std::string(hlo_instruction->source_info()); - } - return output; - }; - - ScopeRangeIdTree scope_range_id_tree; - const XPlane* namespace_tree_plane = - FindPlaneWithName(*space, tsl::profiler::kScopeRangeIdTreePlaneName); - if (namespace_tree_plane) { - XPlaneVisitor namespace_tree_visitor = - tsl::profiler::CreateTfXPlaneVisitor(namespace_tree_plane); - namespace_tree_visitor.ForEachStat([&](const XStatVisitor& stat) { - scope_range_id_tree.emplace(stat.Id(), stat.IntValue()); - }); - } - - std::vector device_planes = - FindMutablePlanesWithPrefix(space, kGpuPlanePrefix); - for (XPlane* plane : device_planes) { - DeriveStepEventsFromGroups(group_metadata_map, plane); - DeriveEventsFromAnnotations(symbol_resolver, plane, &scope_range_id_tree); - } - - const XPlane* host_plane = FindPlaneWithName(*space, kHostThreadsPlaneName); - if (host_plane) { - 
DeriveEventsFromHostTrace(host_plane, group_metadata_map, device_planes); - } - for (XPlane* plane : FindMutableTensorCorePlanes(space)) { - DeriveLinesFromStats(plane); - SortXPlane(plane); - } -} - -void DeriveLinesFromStats(XPlane* device_trace) { - XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(device_trace); - XPlaneBuilder plane_builder(device_trace); - int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace); - DerivedXLineBuilder tf_ops( - &plane_builder, tensorflow::profiler::kThreadIdTfOp, - tensorflow::profiler::kTensorFlowOpLineName, start_timestamp_ns, {}); - DerivedXLineBuilder tf_name_scope( - &plane_builder, tensorflow::profiler::kThreadIdTfNameScope, - tensorflow::profiler::kTensorFlowNameScopeLineName, start_timestamp_ns, - {&tf_ops}); - DerivedXLineBuilder source( - &plane_builder, tensorflow::profiler::kThreadIdSource, - tensorflow::profiler::kSourceLineName, start_timestamp_ns, {}); - - HostOffloadEventProcessor host_offload_event_processor(&plane_builder, - start_timestamp_ns); - - for (const XEventVisitor& event : - GetSortedEvents(plane_visitor, true)) { - tsl::profiler::Timespan event_span = event.GetTimespan(); - std::optional tf_op_name; - std::optional source_info; - std::optional group_id; - std::optional is_async; - auto for_each_stat = [&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kTfOp) { - tf_op_name = stat.StrOrRefValue(); - } else if (stat.Type() == StatType::kGroupId) { - group_id = stat.IntOrUintValue(); - } else if (stat.Type() == StatType::kSourceInfo) { - source_info = stat.StrOrRefValue(); - } else if (stat.Type() == StatType::kIsAsync) { - is_async = stat.IntOrUintValue(); - } - }; - event.Metadata().ForEachStat(for_each_stat); - event.ForEachStat(for_each_stat); - - if (is_async && *is_async) continue; // Disregard asynchronous events. 
- - if (tf_op_name && !tf_op_name->empty()) { - ProcessTfOpEvent(*tf_op_name, event_span, group_id, plane_builder, - tf_name_scope, tf_ops); - } - if (source_info && !source_info->empty()) { - source.ExpandOrAddEvent( - *plane_builder.GetOrCreateEventMetadata(*source_info), event_span, - group_id); - } - if (host_offload_event_processor.IsHostOffloadOpName(event)) { - host_offload_event_processor.ProcessHostOffloadOpEvent(event, group_id); - } - } - tf_name_scope.ResetLastEvents(0); - - RemoveEmptyLines(device_trace); -} - -void DeriveLinesForXlaCpuOps(XPlane* host_trace) { - if (host_trace == nullptr || - !absl::StartsWith(host_trace->name(), kHostThreadsPlaneName)) - return; - XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - XPlane destination_plane; - XPlaneBuilder plane_builder(&destination_plane); - int64_t line_id = tsl::profiler::kThreadIdHostXlaRegionStart; - visitor.ForEachLine([&](const XLineVisitor& line) { - int64_t start_timestamp_ns = line.TimestampNs(); - DerivedXLineBuilder tf_ops( - &plane_builder, line_id++, - absl::StrCat(line.Name(), "-", - tensorflow::profiler::kTensorFlowOpLineName), - start_timestamp_ns, {}); - DerivedXLineBuilder tf_name_scope( - &plane_builder, line_id++, - absl::StrCat(line.Name(), "-", - tensorflow::profiler::kTensorFlowNameScopeLineName), - start_timestamp_ns, {&tf_ops}); - DerivedXLineBuilder xla_cpu_ops( - &plane_builder, line_id++, - absl::StrCat(line.Name(), "-", tsl::profiler::kXlaModuleLineName), - start_timestamp_ns, {}); - line.ForEachEvent([&](const XEventVisitor& event) { - std::optional hlo_module_name; - std::optional framework_op_name; - event.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - // TODO: Add additional stats for framework ops. - switch (stat.Type().value()) { - case StatType::kHloModule: - hlo_module_name = stat.StrOrRefValue(); - break; - case StatType::kTfOp: - framework_op_name = stat.StrOrRefValue(); - break; - } - }); - if (hlo_module_name.has_value()) { - xla_cpu_ops.ExpandOrAddEvent( - *plane_builder.GetOrCreateEventMetadata(*hlo_module_name), - event.GetTimespan(), std::nullopt); - if (framework_op_name.has_value()) { - ProcessTfOpEvent(*framework_op_name, event.GetTimespan(), - std::nullopt, plane_builder, tf_name_scope, tf_ops); - } - } - }); - }); - RemoveEmptyLines(&destination_plane); - MergePlanes(destination_plane, host_trace); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h index a152327319ccff..f2d41461fa2f1d 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.h +++ b/tensorflow/core/profiler/utils/derived_timeline.h @@ -15,186 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_ -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Store the mapping from child scope range id to parent scope range id, which -// logically form a scope range call stack tree/forest. -typedef absl::flat_hash_map - ScopeRangeIdTree; - -// Helper for deriving XEvents. 
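-// A DerivedXEventBuilder remembers the group id and scope range id used to
-// create the event, so ShouldExpand() can decide whether a later event with
-// identical metadata may be merged into it rather than starting a new event.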
-class DerivedXEventBuilder { - public: - DerivedXEventBuilder(XEventBuilder event, std::optional group_id, - std::optional scope_range_id = std::nullopt); - - bool ShouldExpand(const XEventMetadata& event_metadata, - std::optional group_id, - std::optional scope_range_id = std::nullopt) const; - - void Expand(tsl::profiler::Timespan event_span); - tsl::profiler::Timespan GetTimespan() const { return event_.GetTimespan(); } - void SetTimespan(tsl::profiler::Timespan event_span) { - event_.SetTimespan(event_span); - } - - template - void SetOrAddStatValue(const XStatMetadata& metadata, ValueT&& value) { - event_.SetOrAddStatValue(metadata, std::forward(value)); - } - - private: - XEventBuilder event_; - std::optional group_id_; - std::optional scope_range_id_; -}; - -// Helper for deriving an XLine from events in another XLine. -class DerivedXLineBuilder { - public: - DerivedXLineBuilder(XPlaneBuilder* plane, int64_t line_id, - absl::string_view name, int64_t timestamp_ns, - std::vector dependent_lines); - - XLineBuilder& Line() { return line_; } - - // Either merges event with the last event or creates a new event on this - // XLine. group_id and low_level_event_name may be passed to separate - // consecutive invocations of the same event, depending on the XEvent type: - // TF-op, TF name scope: both group_id and low_level_event_name are used. - // HLO-op, step: only group_id is used. - // HLO module, source: both group_id and low_level_event_name are NOT used. - // If scope_range_id is provided, it will be compared with the one in the - // event which is to be merged with. If they are different, merging is not - // allowed. - void ExpandOrAddEvent(const XEventMetadata& event_metadata, - tsl::profiler::Timespan event_span, - std::optional group_id, - std::optional scope_range_id = std::nullopt); - - // The multi-level version of ExpandOrAddEvent. Here, the XEvents at different - // levels all share the same group_id and low_level_event_name. - // Conceptually, the scope_range_ids should be of same length as the - // events_metadata_per_level. However, if it is shorter, this function will - // assume the missing elements at the end of scope_range_ids vector with the - // value of std::nullopt; and if it is longer, the extra elements in - // scope_range_ids will be ignored. - void ExpandOrAddEvents( - const std::vector& events_metadata_per_level, - tsl::profiler::Timespan event_span, std::optional group_id, - absl::Span> scope_range_ids = {}); - - // Reset the last events lower than or equal to the given level. - void ResetLastEvents(int level = 0); - - // To avoid using templates while need hide its implementation in .cc file, - // use two functions to set stat value for int64_t and uint64_t here. - void AddStatToLevelEvent(int level, const XStatMetadata& metadata, - int64_t value); - - void AddStatToLevelEvent(int level, const XStatMetadata& metadata, - uint64_t value); - - const XStatMetadata* GetCorrelationIdMetadata() const { - return correlation_id_metadata_; - } - - const XStatMetadata* GetCudaGraphIdMetadata() const { - return cuda_graph_id_metadata_; - } - - private: - // If the last event of the given level has the same metadata, expands it to - // include the time until the given event's end time. - // Otherwise, adds a new event and clears last_event_by_level_ for the levels - // below the given level and all levels of the dependent lines. Clearing - // last_event_by_level_ prevents a nested event from growing larger than the - // parent event(s). 
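-  // E.g. (hypothetical) two consecutive kernels tagged with the same event
-  // metadata and group id expand one derived event at this level instead of
-  // emitting two back-to-back events.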
-  void ExpandOrAddLevelEvent(const XEventMetadata& event_metadata,
-                             tsl::profiler::Timespan event_span,
-                             std::optional<int64_t> group_id,
-                             std::optional<int64_t> scope_range_id, int level);
-  void AdjustDurationForTraceViewer(int level);
-
-  const XStatMetadata* group_id_stat_metadata_ = nullptr;
-  const XStatMetadata* correlation_id_metadata_ = nullptr;
-  const XStatMetadata* cuda_graph_id_metadata_ = nullptr;
-
-  XLineBuilder line_;
-  absl::flat_hash_map<int, std::optional<DerivedXEventBuilder>>
-      last_event_by_level_;
-  std::vector<DerivedXLineBuilder*> dependent_lines_;
-  bool is_gpu_plane_ = false;
-};
-
-struct Symbol {
-  absl::string_view tf_op_name;
-  std::string source_info;
-  std::string hlo_text;
-};
-
-using SymbolResolver = std::function<Symbol(std::optional<uint64_t> program_id,
-                                            absl::string_view hlo_module_name,
-                                            absl::string_view hlo_op)>;
-
-// Derives TF name scope and op events from the TF op's fully qualified name
-// with the name of the originating low-level event.
-void ProcessTfOpEvent(absl::string_view tf_op_full_name,
-                      tsl::profiler::Timespan event_span,
-                      std::optional<int64_t> group_id,
-                      XPlaneBuilder& plane_builder,
-                      DerivedXLineBuilder& tf_name_scope_line_builder,
-                      DerivedXLineBuilder& tf_op_line_builder);
-
-// Derives a "Steps" line from the group_id XStat in XEvents.
-void DeriveStepEventsFromGroups(
-    const tsl::profiler::GroupMetadataMap& group_metadata_map,
-    XPlane* device_trace);
-
-// Derives "TensorFlow Ops", "TensorFlow Name Scope", "XLA Ops" and "XLA Module"
-// lines in an NVIDIA_GPU device trace from data passed as ScopedAnnotations and
-// stored as XStats in XEvents corresponding to GPU kernels. Consecutive
-// annotations with the same value are merged into a single event, except for
-// XLA modules. The device_trace is both input and output.
-void DeriveEventsFromAnnotations(
-    const SymbolResolver& symbol_resolver, XPlane* device_trace,
-    const ScopeRangeIdTree* scope_range_id_tree = nullptr);
-
-// Derives a "Launch Activities Summary" line from the host trace.
-void DeriveEventsFromHostTrace(
-    const XPlane* host_trace,
-    const tsl::profiler::GroupMetadataMap& group_metadata_map,
-    std::vector<XPlane*> device_traces);
-
-// Loops through the XPlanes of the input XSpace and, for each "device" XPlane,
-// generates derived timelines for the plane by calling
-// DeriveEventsFromAnnotations.
-void GenerateDerivedTimeLines(
-    const tsl::profiler::GroupMetadataMap& group_metadata_map, XSpace* space);
-
-// Derives `Tensorflow Ops`, `Tensorflow Name Scope` and `Source Code` lines
-// from device_trace.
-void DeriveLinesFromStats(tensorflow::profiler::XPlane* device_trace);
-
-// Derives Framework Op and Module lines for XLA:CPU ops.
-void DeriveLinesForXlaCpuOps(tensorflow::profiler::XPlane* host_trace);
-
-}  // namespace profiler
-}  // namespace tensorflow
+#include "xprof/utils/derived_timeline.h"  // from @org_xprof  // IWYU pragma: export
 
 #endif  // TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_
diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc
deleted file mode 100644
index 1e728003531fae..00000000000000
--- a/tensorflow/core/profiler/utils/derived_timeline_test.cc
+++ /dev/null
@@ -1,576 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/derived_timeline.h" - -#include - -#include -#include -#include -#include - -#include -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -TEST(DerivedTimelineTest, EmptySpaceTest) { - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - GenerateDerivedTimeLines(group_metadata_map, &space); - EXPECT_EQ(space.planes_size(), 0); -} - -// Checks that HLO module events are expanded. -TEST(DerivedTimelineTest, HloModuleNameTest) { - const absl::string_view kHloModuleName = "hlo_module"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kHloModule, kHloModuleName}, - {StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kHloModule, kHloModuleName}, - {StatType::kKernelDetails, kKernelDetails}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // Only the hlo module line is added and other empty lines are removed at the - // end. - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); - EXPECT_EQ(line_visitor.NumEvents(), 1); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kHloModuleName); - }); - }); -} - -// Checks that HLO module events are expanded, with both same name and scope -// range id. Note that strange XStatValue{int64_t{10}} is to handle different -// compilers behavior. 
-TEST(DerivedTimelineTest, HloModuleNameSameScopeRangeIdTest) { - const absl::string_view kHloModuleName = "hlo_module"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kHloModule, XStatValue{kHloModuleName}}, - {StatType::kKernelDetails, XStatValue{kKernelDetails}}, - {StatType::kScopeRangeId, XStatValue{int64_t{10}}}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kHloModule, XStatValue{kHloModuleName}}, - {StatType::kKernelDetails, XStatValue{kKernelDetails}}, - {StatType::kScopeRangeId, XStatValue{int64_t{10}}}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // Only the hlo module line is added and other empty lines are removed at the - // end. - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); - EXPECT_EQ(line_visitor.NumEvents(), 1); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kHloModuleName); - }); - }); -} - -// Checks that HLO module events are expanded, with same name only, -// but different scope range id. -TEST(DerivedTimelineTest, HloModuleNameDifferentScopeRangeIdTest) { - const absl::string_view kHloModuleName = "hlo_module"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kHloModule, XStatValue{kHloModuleName}}, - {StatType::kKernelDetails, XStatValue{kKernelDetails}}, - {StatType::kScopeRangeId, XStatValue{int64_t{10}}}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kHloModule, XStatValue{kHloModuleName}}, - {StatType::kKernelDetails, XStatValue{kKernelDetails}}, - {StatType::kScopeRangeId, XStatValue{int64_t{20}}}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // Only the hlo module line is added and other empty lines are removed at the - // end. - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); - EXPECT_EQ(line_visitor.NumEvents(), 2); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kHloModuleName); - }); - }); -} - -// Checks that HLO module events are expanded. 
-TEST(DerivedTimelineTest, NoHloModuleNameTest) { - const absl::string_view kKernelDetails = "kernel_details"; - const uint64_t kCudaGraphExecId = 1; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane& plane = *GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(&plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kKernelDetails, kKernelDetails}}); - // Also add a CudaGraph Execution event. - CreateXEvent(&plane_builder, &line_builder, "op3", 500, 100, - {{StatType::kCudaGraphExecId, kCudaGraphExecId}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&plane); - // Only the hlo module line is added and other empty lines are removed at the - // end. - EXPECT_EQ(plane_visitor.NumLines(), 1); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); - EXPECT_EQ(line_visitor.NumEvents(), 0); - }); -} - -// Checks that the TF op events are expanded. -TEST(DerivedTimelineTest, TfOpLineTest) { - const absl::string_view kTfOpName = "mul:Mul"; - const absl::string_view kKernelDetails = "kernel_details"; - const uint64_t kCudaGraphExecId = 1; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - // Also add a CudaGraph Execution event. - CreateXEvent(&plane_builder, &line_builder, "op3", 500, 100, - {{StatType::kTfOp, kTfOpName}, - {StatType::kCudaGraphExecId, kCudaGraphExecId}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // Only the tf op line is added and other empty lines are removed at the end. - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdTfOp); - EXPECT_EQ(line_visitor.NumEvents(), 1); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kTfOpName); - EXPECT_EQ(event_visitor.OffsetPs(), 0); - EXPECT_EQ(event_visitor.DurationPs(), 600); - }); - }); -} - -// Checks that the dependency between the step line and the TF op line prevents -// TF op events from being expanded. 
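-// (The two ops below carry different group ids, so ExpandOrAddEvent starts a
-// new derived event for each instead of merging them.)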
-TEST(DerivedTimelineTest, DependencyTest) { - constexpr int64_t kFirstGroupId = 0; - constexpr int64_t kSecondGroupId = 1; - - const absl::string_view kTfOpName = "mul:Mul"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map( - {{0, {"train 0"}}, {1, {"train 1"}}}); - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kGroupId, kFirstGroupId}, - {StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kGroupId, kSecondGroupId}, - {StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // The step line and the TF op line are added. - EXPECT_EQ(plane_visitor.NumLines(), 3); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_TRUE(line_visitor.Id() == kThreadIdStepInfo || - line_visitor.Id() == kThreadIdTfOp); - EXPECT_EQ(line_visitor.NumEvents(), 2); - }); -} - -// Checks that the TF op events are expanded. -TEST(DerivedTimelineTest, TfOpNameScopeTest) { - const absl::string_view kTfOpName = "scope1/scope2/mul:Mul"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // The TF name scope line and the TF op line are added. - EXPECT_EQ(plane_visitor.NumLines(), 3); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - int64_t line_id = line_visitor.Id(); - if (line_id == 0) { - return; - } else if (line_id == kThreadIdTfNameScope) { - EXPECT_EQ(line_visitor.NumEvents(), 2); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.OffsetPs(), 0); - EXPECT_EQ(event_visitor.DurationPs(), 500); - }); - } else if (line_id == kThreadIdTfOp) { - EXPECT_EQ(line_visitor.NumEvents(), 1); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kTfOpName); - EXPECT_EQ(event_visitor.OffsetPs(), 0); - EXPECT_EQ(event_visitor.DurationPs(), 500); - }); - } - }); -} - -// Checks that the TF op events are expanded. 
-TEST(DerivedTimelineTest, TfNameScopeMaintainsOrder) {
-  const absl::string_view kTfOpName = "scope1/scope2/mul:Mul";
-  const absl::string_view kKernelDetails = "kernel_details";
-  XSpace space;
-  tsl::profiler::GroupMetadataMap group_metadata_map;
-  XPlane* plane =
-      GetOrCreateTpuXPlane(&space, /*device_ordinal=*/0, "TPU V4", 0, 0);
-  XPlaneBuilder plane_builder(plane);
-  auto line_builder = plane_builder.GetOrCreateLine(0);
-  CreateXEvent(&plane_builder, &line_builder, "op1", 0, 10000,
-               {{StatType::kTfOp, kTfOpName},
-                {StatType::kKernelDetails, kKernelDetails}});
-  GenerateDerivedTimeLines(group_metadata_map, &space);
-  XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane);
-  // The TF name scope line and the TF op line are added.
-  EXPECT_EQ(plane_visitor.NumLines(), 3);
-  plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) {
-    if (line_visitor.Name() == tsl::profiler::kTensorFlowNameScopeLineName) {
-      EXPECT_EQ(line_visitor.NumEvents(), 2);
-      uint64_t expected_duration = 10000;
-      line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) {
-        LOG(INFO) << "scope: " << event_visitor.Name();
-        EXPECT_EQ(event_visitor.OffsetPs(), 0);
-        EXPECT_EQ(event_visitor.DurationPs(), expected_duration);
-        expected_duration -= 1000;
-      });
-    }
-  });
-}
-
-// Checks that events are derived only from the line with the most events in
-// a GPU trace.
-TEST(DerivedTimelineTest, OnlyDerivedEventsFromLineWithMostEvents) {
-  const absl::string_view kTfOpName = "scope1/scope2/mul:Mul";
-  const absl::string_view kKernelDetails = "kernel_details";
-  XSpace space;
-  tsl::profiler::GroupMetadataMap group_metadata_map;
-  XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0);
-  XPlaneBuilder plane_builder(plane);
-  auto line_builder = plane_builder.GetOrCreateLine(0);
-  // Add first line with two events.
-  CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100,
-               {{StatType::kTfOp, kTfOpName},
-                {StatType::kKernelDetails, kKernelDetails}});
-  CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300,
-               {{StatType::kTfOp, kTfOpName},
-                {StatType::kKernelDetails, kKernelDetails}});
-  // Add second line with only one event.
-  auto line_builder_2 = plane_builder.GetOrCreateLine(1);
-  CreateXEvent(&plane_builder, &line_builder_2, "op3", 50, 850,
-               {{StatType::kTfOp, kTfOpName},
-                {StatType::kKernelDetails, kKernelDetails}});
-  // Derive lines for the plane.
-  GenerateDerivedTimeLines(group_metadata_map, &space);
-  XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane);
-  // The TF name scope line and the TF op line are added.
-  EXPECT_EQ(plane_visitor.NumLines(), 4);
-  plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) {
-    int64_t line_id = line_visitor.Id();
-    if (line_id == 0 || line_id == 1) {
-      return;
-    } else if (line_id == kThreadIdTfNameScope) {
-      EXPECT_EQ(line_visitor.NumEvents(), 2);
-      line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) {
-        EXPECT_EQ(event_visitor.OffsetPs(), 0);
-        // When derived from the first line only, we get a single event that
-        // starts at op1's start (0) and ends at op2's end (200 + 300), i.e.
-        // a duration of 500.
-        // If derived from both lines, the derived event's duration would be
-        // (50 + 850) - 0 = 900.
-        EXPECT_EQ(event_visitor.DurationPs(), 500);
-      });
-    } else if (line_id == kThreadIdTfOp) {
-      EXPECT_EQ(line_visitor.NumEvents(), 1);
-      line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) {
-        EXPECT_EQ(event_visitor.Name(), kTfOpName);
-        EXPECT_EQ(event_visitor.OffsetPs(), 0);
-        EXPECT_EQ(event_visitor.DurationPs(), 500);
-      });
-    }
-  });
-}
-
-// Checks that the expanded TF name scope events are shrunk when needed.
-TEST(DerivedTimelineTest, TfOpNameScopeShrinkTest) {
-  {
-    // Case 1: shrink is possible.
-    XSpace space;
-    tsl::profiler::GroupMetadataMap group_metadata_map;
-    XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0);
-    XPlaneBuilder plane_builder(plane);
-    auto line_builder = plane_builder.GetOrCreateLine(0);
-    CreateXEvent(&plane_builder, &line_builder, "op1", 0, 10000,
-                 {{StatType::kTfOp, "a/b/c/Add:Add"},
-                  {StatType::kKernelDetails, "blah"}});
-    CreateXEvent(
-        &plane_builder, &line_builder, "op2", 20000, 30000,
-        {{StatType::kTfOp, "a/d/Mul:Mul"}, {StatType::kKernelDetails, "blah"}});
-    GenerateDerivedTimeLines(group_metadata_map, &space);
-    XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane);
-    // The TF name scope line and the TF op line are added.
-    EXPECT_EQ(plane_visitor.NumLines(), 3);
-    plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) {
-      int64_t line_id = line_visitor.Id();
-      if (line_id == 0) {
-        return;
-      } else if (line_id == kThreadIdTfNameScope) {
-        EXPECT_EQ(line_visitor.NumEvents(), 4);
-        std::map durations;
-        line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) {
-          durations[event_visitor.Name()] = event_visitor.DurationPs();
-        });
-        EXPECT_EQ(durations["a"], 50000);
-        EXPECT_EQ(durations["b"], 10000);
-        EXPECT_EQ(durations["c"], 9000);  // shrunk to be distinguishable from b.
-        EXPECT_EQ(durations["d"], 30000);
-      }
-    });
-  }
-  {
-    // Case 2: shrink is impossible because the top event is too small.
-    XSpace space;
-    tsl::profiler::GroupMetadataMap group_metadata_map;
-    XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0);
-    XPlaneBuilder plane_builder(plane);
-    auto line_builder = plane_builder.GetOrCreateLine(0);
-    CreateXEvent(&plane_builder, &line_builder, "op1", 0, 10000,
-                 {{StatType::kTfOp, "a/b/c/d/e/Add:Add"},
-                  {StatType::kKernelDetails, "blah"}});
-    CreateXEvent(&plane_builder, &line_builder, "op2", 10000, 2000,
-                 {{StatType::kTfOp, "a/b/c/d/f/Sub:Sub"},
-                  {StatType::kKernelDetails, "blah"}});
-    CreateXEvent(
-        &plane_builder, &line_builder, "op3", 20000, 30000,
-        {{StatType::kTfOp, "a/g/Mul:Mul"}, {StatType::kKernelDetails, "blah"}});
-    GenerateDerivedTimeLines(group_metadata_map, &space);
-    XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane);
-    // The TF name scope line and the TF op line are added.
-    EXPECT_EQ(plane_visitor.NumLines(), 3);
-    plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) {
-      int64_t line_id = line_visitor.Id();
-      if (line_id == 0) {
-        return;
-      } else if (line_id == kThreadIdTfNameScope) {
-        EXPECT_EQ(line_visitor.NumEvents(), 7);
-        std::map durations;
-        line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) {
-          durations[event_visitor.Name()] = event_visitor.DurationPs();
-        });
-        for (const auto& [name, duration] : durations) {
-          LOG(ERROR) << name << ": " << duration;
-        }
-        EXPECT_EQ(durations["a"], 50000);
-        EXPECT_EQ(durations["b"], 12000);
-        EXPECT_EQ(durations["c"], 11000);  // shrunk to be distinguishable from b.
-        EXPECT_EQ(durations["d"], 11000);  // not shrunk because of f.
- EXPECT_EQ(durations["e"], 10000); - EXPECT_EQ(durations["f"], 1000); - EXPECT_EQ(durations["g"], 30000); - } - }); - } -} - -// Checks that XLA Ops mapping to CudaGraph launch has extra stats. -TEST(DerivedTimelineTest, XloOpHasCudaGraphStats) { - constexpr absl::string_view kModuleName = "module"; - constexpr absl::string_view kHloOpName = "op_level_2"; - constexpr absl::string_view kKernelDetails = "kernel_details"; - constexpr int64_t kGroupIdValue = 1; - constexpr int64_t kCorrelationIdValue = 10000; - const uint64_t kCudaGraphIdValue = 20; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - - // Build Input Plane/Line/Events and derive events from them. - XPlane& plane = *GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(&plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kKernelDetails, kKernelDetails}, - {StatType::kGroupId, kGroupIdValue}, - {StatType::kHloModule, kModuleName}, - {StatType::kHloOp, kHloOpName}, - {StatType::kCorrelationId, kCorrelationIdValue}, - {StatType::kCudaGraphId, kCudaGraphIdValue}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kKernelDetails, kKernelDetails}, - {StatType::kGroupId, kGroupIdValue}, - {StatType::kHloModule, kModuleName}, - {StatType::kHloOp, kHloOpName}, - {StatType::kCorrelationId, kCorrelationIdValue}, - {StatType::kCudaGraphId, kCudaGraphIdValue}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - - // Check that the HLO op line is added and has the extra stats for the first - // derived event. - size_t num_hlo_op_line = 0; - size_t num_events = 0; - std::optional correlation_id; - std::optional cuda_graph_id; - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&plane); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == kThreadIdHloOp) { - num_hlo_op_line++; - if (num_hlo_op_line == 1) { - num_events = line_visitor.NumEvents(); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - correlation_id = event_visitor.GetStat(StatType::kCorrelationId); - cuda_graph_id = event_visitor.GetStat(StatType::kCudaGraphId); - }); - } - } - }); - EXPECT_EQ(num_hlo_op_line, 1); - EXPECT_EQ(num_events, 1); - ASSERT_TRUE(correlation_id.has_value()); - EXPECT_EQ(correlation_id->IntValue(), kCorrelationIdValue); - ASSERT_TRUE(cuda_graph_id.has_value()); - EXPECT_EQ(cuda_graph_id->UintValue(), kCudaGraphIdValue); -} - -TEST(DerivedTimelineTest, DeriveLinesForXlaCpuOps) { - XPlane xplane; - XPlaneBuilder plane_builder(&xplane); - plane_builder.SetName(tsl::profiler::kHostThreadsPlaneName); - - absl::string_view main_line_name = "main"; - auto line_builder = plane_builder.GetOrCreateLine(0); - line_builder.SetName(main_line_name); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kHloModule, "Module1"}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 400, - {{StatType::kHloModule, "Module2"}}); - - DeriveLinesForXlaCpuOps(&xplane); - - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane); - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Name() == main_line_name) return; - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - if (event_visitor.Name() == "Module1") { - EXPECT_EQ(event_visitor.DurationPs(), 100); - EXPECT_EQ(event_visitor.OffsetPs(), 0); 
-      } else if (event_visitor.Name() == "Module2") {
-        EXPECT_EQ(event_visitor.DurationPs(), 400);
-        EXPECT_EQ(event_visitor.OffsetPs(), 200);
-      } else {
-        FAIL() << "Found Event " << event_visitor.Name();
-      }
-    });
-  });
-}
-
-TEST(DerivedTimelineTest, MergeAndNoMerge) {
-  constexpr absl::string_view kHloModuleName = "Framework Ops";
-  static constexpr absl::string_view kTfOpName = "abc:model/layer/MatMul_1";
-  XSpace space;
-  tsl::profiler::GroupMetadataMap group_metadata_map;
-  XPlane* plane =
-      GetOrCreateTpuXPlane(&space, /*device_ordinal=*/0, "DummyTPU", 1.0, 1.0);
-  XPlaneBuilder plane_builder(plane);
-  auto line_builder = plane_builder.GetOrCreateLine(0);
-  CreateXEvent(
-      &plane_builder, &line_builder, "op1", 0, 100,
-      {{StatType::kHloModule, kHloModuleName}, {StatType::kTfOp, kTfOpName}});
-  CreateXEvent(
-      &plane_builder, &line_builder, "op2", 200, 300,
-      {{StatType::kHloModule, kHloModuleName}, {StatType::kTfOp, kTfOpName}});
-  // The above two events are merged into one derived event spanning [0, 500).
-  // The event below will not be merged into it because the gap
-  // (1501 - 500 = 1001) is greater than 2x the merged event's duration
-  // (2 x 500 = 1000).
-  CreateXEvent(
-      &plane_builder, &line_builder, "op3", 1501, 300,
-      {{StatType::kHloModule, kHloModuleName}, {StatType::kTfOp, kTfOpName}});
-  GenerateDerivedTimeLines(group_metadata_map, &space);
-  XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane);
-  // Only the hlo module line is added and other empty lines are removed at the
-  // end.
-  EXPECT_EQ(plane_visitor.NumLines(), 2);
-  plane_visitor.ForEachLine([](const XLineVisitor& line_visitor) {
-    if (line_visitor.Id() == 0) return;
-    EXPECT_EQ(line_visitor.NumEvents(), 2);
-    line_visitor.ForEachEvent([](const XEventVisitor& event_visitor) {
-      EXPECT_EQ(event_visitor.Name(), kTfOpName);
-    });
-  });
-}
-
-}  // namespace
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/device_caps_utils.cc b/tensorflow/core/profiler/utils/device_caps_utils.cc
deleted file mode 100644
index 3b149ad528b654..00000000000000
--- a/tensorflow/core/profiler/utils/device_caps_utils.cc
+++ /dev/null
@@ -1,90 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/device_caps_utils.h" - -#include - -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -void SetDeviceCaps(const DeviceCapabilities& caps, XPlane* plane) { - XPlaneBuilder xplane(plane); - int clock_rate_in_khz = - static_cast(caps.clock_rate_in_ghz() * 1000000.0); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapClockRateKHz)), - clock_rate_in_khz); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapCoreCount)), - caps.num_cores()); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), - caps.memory_bandwidth()); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapMemorySize)), - caps.memory_size_in_bytes()); - if (caps.has_compute_capability()) { - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapComputeCapMajor)), - caps.compute_capability().major()); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapComputeCapMinor)), - caps.compute_capability().minor()); - } - xplane.AddStatValue( - *xplane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kDevVendor)), - caps.device_vendor()); -} - -DeviceCapabilities GetDeviceCaps(const XPlane& plane) { - DeviceCapabilities caps; - XPlaneVisitor xplane = tsl::profiler::CreateTfXPlaneVisitor(&plane); - xplane.ForEachStat([&](const tensorflow::profiler::XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kDevCapClockRateKHz: - caps.set_clock_rate_in_ghz(stat.IntOrUintValue() / 1000000.0); - break; - case StatType::kDevCapCoreCount: - caps.set_num_cores(stat.IntOrUintValue()); - break; - case StatType::kDevCapMemoryBandwidth: - caps.set_memory_bandwidth(stat.IntOrUintValue()); - break; - case StatType::kDevCapMemorySize: - caps.set_memory_size_in_bytes(stat.IntOrUintValue()); - break; - case StatType::kDevCapComputeCapMajor: - caps.mutable_compute_capability()->set_major(stat.IntOrUintValue()); - break; - case StatType::kDevCapComputeCapMinor: - caps.mutable_compute_capability()->set_minor(stat.IntOrUintValue()); - break; - case StatType::kDevVendor: - caps.set_device_vendor(std::string(stat.StrOrRefValue())); - break; - } - }); - return caps; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/device_caps_utils.h b/tensorflow/core/profiler/utils/device_caps_utils.h index c6c84133db3aaf..a500ed1d18acc6 100644 --- a/tensorflow/core/profiler/utils/device_caps_utils.h +++ b/tensorflow/core/profiler/utils/device_caps_utils.h @@ -16,16 +16,6 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAPS_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAPS_UTILS_H_ -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -void SetDeviceCaps(const DeviceCapabilities& caps, XPlane* plane); -DeviceCapabilities GetDeviceCaps(const XPlane& plane); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/device_caps_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAPS_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/diagnostics.cc b/tensorflow/core/profiler/utils/diagnostics.cc deleted file mode 100644 index c4ff0f2069f07a..00000000000000 --- a/tensorflow/core/profiler/utils/diagnostics.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/diagnostics.h" - -#include - -#include "absl/algorithm/container.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" - -namespace tensorflow { -namespace profiler { - -const absl::string_view kErrorIncompleteStep = - "Incomplete step observed and hence the step time is unknown." - "Instead, we use the trace duration as the step time. This may happen" - " if your profiling duration is shorter than the step time. In this" - " case, you may try to profile longer."; - -const absl::string_view kErrorEmptyIntersect = - "Although there are steps observed on some host(s), the intersection of " - "the steps over all hosts is empty (because the differences among " - "individual host's step sequences are too big). Consequently, the overall " - "step time is " - "unknown."; - -const absl::string_view kErrorNoStepMarker = - "No step marker observed and hence the step time is unknown." - " This may happen if (1) training steps are not instrumented (e.g., if" - " you are not using Keras) or (2) the profiling duration is shorter" - " than the step time. For (1), you need to add step instrumentation;" - " for (2), you may try to profile longer."; - -const absl::string_view kNoDeviceTraceCollected = - "No TensorCore device trace was collected. This might happen if your job " - "hadn't been run on the device when sampling was turned on. You could try " - "the sampling again later."; - -const absl::string_view kStepsDropped = - " steps dropped. This might happen when you profile many hosts and/or many " - "steps. 
You could try to profile shorter or reduce the number of hosts " - "you profile."; - -void PopulateStepDiagnostics(const OpStats& op_stats, Diagnostics* diag) { - if (op_stats.step_db().use_incomplete_step()) { - *diag->add_warnings() = std::string(kErrorIncompleteStep); - } else if (op_stats.step_db().step_sequence().empty()) { - *diag->add_warnings() = op_stats.step_db().empty_intersect() - ? std::string(kErrorEmptyIntersect) - : std::string(kErrorNoStepMarker); - } - if (op_stats.step_db().num_steps_dropped()) { - *diag->add_warnings() = - absl::StrCat(op_stats.step_db().num_steps_dropped(), kStepsDropped); - } -} - -void PopulateOverviewDiagnostics(const OpStats& op_stats, Diagnostics* diag) { - *diag->mutable_errors() = op_stats.diagnostics().errors(); - absl::c_sort(*diag->mutable_errors()); - if (diag->errors().empty()) { - // Shows run-environment error only if there is no other existing error. - if (op_stats.run_environment().device_type() != "CPU" && - op_stats.run_environment().device_core_count() <= 0) { - *diag->add_errors() = std::string(kNoDeviceTraceCollected); - } - } - *diag->mutable_warnings() = op_stats.diagnostics().warnings(); - PopulateStepDiagnostics(op_stats, diag); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/diagnostics.h b/tensorflow/core/profiler/utils/diagnostics.h index 25fb16900f2575..67eb4020d54c14 100644 --- a/tensorflow/core/profiler/utils/diagnostics.h +++ b/tensorflow/core/profiler/utils/diagnostics.h @@ -16,30 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_DIAGNOSTICS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_DIAGNOSTICS_H_ -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/macros.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -// Error message that the visualization is based on incomplete step. -TF_CONST_INIT extern const absl::string_view kErrorIncompleteStep; - -// Error message that no step marker is seen and visualization contains no -// step info. -TF_CONST_INIT extern const absl::string_view kErrorNoStepMarker; - -TF_CONST_INIT extern const absl::string_view kNoDeviceTraceCollected; - -TF_CONST_INIT extern const absl::string_view kStepsDropped; - -void PopulateStepDiagnostics(const OpStats& op_stats, Diagnostics* diag); - -void PopulateOverviewDiagnostics(const OpStats& op_stats, Diagnostics* diag); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/diagnostics.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_DIAGNOSTICS_H_ diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc deleted file mode 100644 index b5e9b813a15c01..00000000000000 --- a/tensorflow/core/profiler/utils/event_span.cc +++ /dev/null @@ -1,449 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-#include "tensorflow/core/profiler/utils/event_span.h"
-
-#include
-#include
-#include
-#include
-#include
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/log/check.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/timespan.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-
-namespace {
-
-// Represents a boundary of an event.
-struct EventBoundary {
-  // Time at this boundary.
-  uint64 time_ps;
-  // Type of the event.
-  EventType type;
-  // True if this is the start of the event; False if this is the end.
-  bool is_start;
-  EventBoundary(uint64 time_ps, EventType type, bool is_start)
-      : time_ps(time_ps), type(type), is_start(is_start) {}
-};
-
-// Returns true if EventBoundary a should appear before EventBoundary b.
-bool CmpEventBoundaries(const EventBoundary& a, const EventBoundary& b) {
-  if (a.time_ps == b.time_ps) {
-    if (a.is_start == b.is_start) {
-      // Puts the higher-priority type before the lower-priority type if they
-      // have the same time and same boundary type.
-      return a.type > b.type;
-    } else {
-      // Puts the "end" boundary before the "start" boundary if they have the
-      // same time.
-      return !a.is_start;
-    }
-  }
-  // In ascending order of time.
-  return a.time_ps < b.time_ps;
-}
-
-// Generates vector of event boundaries from the given overlapped_events.
-std::vector GenerateEventBoundaries(
-    const std::vector& overlapped_events) {
-  std::vector boundaries;
-  boundaries.reserve(2 * overlapped_events.size());
-  for (const auto& event : overlapped_events) {
-    boundaries.push_back(
-        {event.span.begin_ps(), event.type, /*is_start=*/true});
-    boundaries.push_back({event.span.end_ps(), event.type, /*is_start=*/false});
-  }
-  absl::c_sort(boundaries, CmpEventBoundaries);
-  return boundaries;
-}
-
-// A class to track the highest priority that an event should be assigned.
-class PriorityTracker {
- private:
-  // The current maximum priority.
-  EventType current_max_priority_;
-  // A count for each possible priority.
-  std::vector priority_count_;
-
- public:
-  PriorityTracker() {
-    current_max_priority_ = UNKNOWN_TIME;
-    priority_count_.resize(LAST_EVENT_TYPE + 1, 0);
-  }
-  // Updates current_max_priority_ and priority_count_[] given the boundary.
-  // Returns the new current_max_priority_.
-  EventType Update(const EventBoundary& boundary) {
-    EventType event_type = boundary.type;
-    bool is_start = boundary.is_start;
-    if (is_start) {
-      priority_count_[event_type]++;
-      if (event_type > current_max_priority_) {
-        current_max_priority_ = event_type;
-      }
-    } else {
-      priority_count_[event_type]--;
-      if (event_type == current_max_priority_ &&
-          priority_count_[event_type] == 0) {
-        // Reduces current_max_priority_ to the first event type (starting from
-        // the highest priority) that has a non-zero count.
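-        // E.g., if a DEVICE_COMPUTE_32 event ends while a HOST_COMPUTE event
-        // is still open, DEVICE_COMPUTE_32's count drops to zero and the scan
-        // below falls back to HOST_COMPUTE as the new maximum priority.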
- bool found = false; - for (int i = event_type - 1; i >= 0; i--) { - if (priority_count_[i] > 0) { - current_max_priority_ = static_cast(i); - found = true; - break; - } - } - if (!found) current_max_priority_ = UNKNOWN_TIME; - } - } - return current_max_priority_; - } -}; - -constexpr int kNumGenericEventTypes = GenericEventType::kLastGenericEventType - - GenericEventType::kFirstGenericEventType + - 1; - -using GenericEventTypeStrMap = - absl::flat_hash_map; - -const GenericEventTypeStrMap& GetGenericEventTypeStrMap() { - static const auto* generic_event_type_str_map = new GenericEventTypeStrMap({ - {kDeviceCompute, "Device compute"}, - {kDeviceToDevice, "Device to device"}, - {kDeviceCollectives, "Device collective communication"}, - {kHostCompute, "Host compute"}, - {kHostPrepare, "Kernel launch"}, - {kInput, "Input"}, - {kOutput, "Output"}, - {kCompile, "Compilation"}, - {kAllOthers, "All others"}, - }); - DCHECK_EQ(generic_event_type_str_map->size(), kNumGenericEventTypes); - return *generic_event_type_str_map; -} - -} // namespace - -absl::string_view GetGenericEventTypeStr(GenericEventType event_type) { - return GetGenericEventTypeStrMap().at(event_type); -} - -std::string PrintEventType(EventType event_type) { - switch (event_type) { - case UNKNOWN_TIME: - return "unknown_time"; - case HOST_COMPUTE: - return "host_compute"; - case HOST_COMPILE: - return "host_compile"; - case HOST_TO_HOST: - return "host_to_host"; - case HOST_TO_DEVICE: - return "host_to_device"; - case HOST_PREPARE: - return "host_prepare"; - case DEVICE_COLLECTIVES: - return "device_collectives"; - case HOST_WAIT_INPUT: - return "host_wait_input"; - case DEVICE_TO_DEVICE: - return "device_to_device"; - case DEVICE_TO_HOST: - return "device_to_host"; - case DEVICE_COMPUTE_32: - return "device_compute_32"; - case DEVICE_COMPUTE_16: - return "device_compute_16"; - case DEVICE_WAIT_DEVICE: - return "device_wait_device"; - case DEVICE_WAIT_HOST: - return "device_wait_host"; - default: - return "unexpected"; - } -} - -std::string PrintEventTypeSpan(const EventTypeSpan& event_type_span) { - return absl::StrCat("(", PrintEventType(event_type_span.type), ", ", - event_type_span.span.DebugString(), ")"); -} - -absl::string_view PrintStepMarkerType(StepMarkerType type) { - switch (type) { - case StepMarkerType::kExplicitHostStepMarker: - return "ExplicitHostStepMarker"; - case StepMarkerType::kImplicitHostStepMarker: - return "ImplicitHostStepMarker"; - case StepMarkerType::kDeviceStepMarker: - return "DeviceStepMarker"; - } -} - -std::string PrintStepMarker(const StepMarker& step_marker) { - return absl::StrCat("(", PrintStepMarkerType(step_marker.type), ", ", - step_marker.event_name, ", ", - step_marker.span.DebugString(), ")"); -} - -std::string PrintStepEvents(const StepEvents& step_events) { - std::vector step_ids; - step_ids.reserve(step_events.size()); - for (const auto& id_details : step_events) { - step_ids.push_back(id_details.first); - } - absl::c_sort(step_ids); - std::string result = "{"; - for (auto id : step_ids) { - absl::StrAppend(&result, "\n"); - auto* details = gtl::FindOrNull(step_events, id); - std::string details_str = details ? 
details->DebugString() : "()"; - absl::StrAppend(&result, id, ":", details_str); - } - return absl::StrCat(result, "\n}"); -} - -void UnionCombineStepEvents(const StepEvents& src, StepEvents* dst) { - for (const auto& step_details : src) { - int64_t step_id = step_details.first; - const StepDetails& src_details = step_details.second; - StepDetails* dst_details = &(*dst)[step_id]; - dst_details->Combine(src_details); - } -} - -void IntersectCombineStepEvents(const StepEvents& src, StepEvents* dst) { - if (dst->empty()) { - *dst = src; - return; - } - auto iter = dst->begin(); - while (iter != dst->end()) { - if (!src.contains(iter->first)) { - // This is safe because the post-increment is sequenced after the full - // expression that contains it. - dst->erase(iter++); - } else { - iter->second.Combine(src.at(iter->first)); - iter++; - } - } -} - -std::vector ToNonOverlappedEvents( - const std::vector& overlapped_events) { - std::vector event_boundaries = - GenerateEventBoundaries(overlapped_events); - std::vector result; - if (event_boundaries.empty()) return result; - result.reserve(event_boundaries.size()); - PriorityTracker priority_tracker; - for (int64_t i = 0, end = (event_boundaries.size() - 1); i < end; i++) { - EventType highest_priority = priority_tracker.Update(event_boundaries[i]); - result.push_back({highest_priority, tsl::profiler::Timespan::FromEndPoints( - event_boundaries[i].time_ps, - event_boundaries[i + 1].time_ps)}); - } - return result; -} - -// Converts from overlapped step-events to non-overlapped step-events. -StepEvents ToNonOverlappedStepEvents(const StepEvents& overlapped_step_events) { - StepEvents non_overlapped_step_events; - for (const auto& step_events : overlapped_step_events) { - const auto& step_id = step_events.first; - const auto& step_details = step_events.second; - non_overlapped_step_events.try_emplace(step_id, - step_details.ToNonOverlapped()); - } - return non_overlapped_step_events; -} - -void StepDetails::AddMarker(const StepMarker& m) { markers_.push_back(m); } - -void StepDetails::AddEvent(const EventTypeSpan& e) { events_.push_back(e); } - -void StepDetails::AggregateDeviceMemoryTransfers( - const std::vector& device_memory_transfers) { - if (device_memory_transfers.size() != device_memory_transfers_.size()) { - return; // Sanity check. 
- } - for (size_t i = 0; i < device_memory_transfers.size(); ++i) { - device_memory_transfers_[i].set_occurrence( - device_memory_transfers_[i].occurrence() + - device_memory_transfers[i].occurrence()); - device_memory_transfers_[i].set_bytes_transferred( - device_memory_transfers_[i].bytes_transferred() + - device_memory_transfers[i].bytes_transferred()); - device_memory_transfers_[i].set_time_us( - device_memory_transfers_[i].time_us() + - device_memory_transfers[i].time_us()); - } -} - -void StepDetails::AddCollectiveOpEvent(uint64 core_id, const AllReduceInfo& e) { - *collectives_[core_id].add_all_reduce_info() = e; -} - -void StepDetails::AddDeviceMemoryTransferEvent( - EventType event_type, const tsl::profiler::Timespan& time_span, - uint64 bytes) { - int index = 0; - switch (event_type) { - case HOST_TO_DEVICE: - index = 0; - break; - case DEVICE_TO_HOST: - index = 1; - break; - case DEVICE_TO_DEVICE: - index = 2; - break; - default: - return; - } - device_memory_transfers_[index].set_occurrence( - device_memory_transfers_[index].occurrence() + 1); - device_memory_transfers_[index].set_time_us( - device_memory_transfers_[index].time_us() + - time_span.duration_ps() / 1000000.0); - device_memory_transfers_[index].set_bytes_transferred( - device_memory_transfers_[index].bytes_transferred() + bytes); -} - -tsl::profiler::Timespan StepDetails::StepTime() const { - tsl::profiler::Timespan max_host_step_time; - tsl::profiler::Timespan max_device_step_time; - for (const auto& marker : markers_) { - tsl::profiler::Timespan& cur_max_step_time = - marker.type == StepMarkerType::kDeviceStepMarker ? max_device_step_time - : max_host_step_time; - const tsl::profiler::Timespan& new_step_time = marker.span; - if (new_step_time.duration_ps() > cur_max_step_time.duration_ps()) - cur_max_step_time = new_step_time; - } - // CPU-only profile. - if (max_device_step_time.Empty()) { - return max_host_step_time; - } - - // If the host step time includes the device step time, use the host step - // time. This covers the case where the device is synchronized at the end of - // each step. 
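-  // E.g., a host step spanning [0, 1000) that contains a device step spanning
-  // [100, 900): the host step time is used.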
- if (max_host_step_time.Includes(max_device_step_time)) { - return max_host_step_time; - } - return max_device_step_time; -} - -StepDetails StepDetails::ToNonOverlapped() const { - StepDetails non_overlapped_step_details; - non_overlapped_step_details.markers_ = markers_; - non_overlapped_step_details.events_ = ToNonOverlappedEvents(events_); - non_overlapped_step_details.collectives_ = collectives_; - non_overlapped_step_details.device_memory_transfers_ = - device_memory_transfers_; - non_overlapped_step_details.step_name_ = step_name_; - non_overlapped_step_details.per_core_op_metrics_db_ = per_core_op_metrics_db_; - return non_overlapped_step_details; -} - -void StepDetails::Combine(const StepDetails& other) { - markers_.insert(markers_.end(), other.markers_.begin(), other.markers_.end()); - events_.insert(events_.end(), other.events_.begin(), other.events_.end()); - collectives_.insert(other.collectives_.begin(), other.collectives_.end()); - AggregateDeviceMemoryTransfers(other.device_memory_transfers_); - for (const auto& [core_id, op_metric_db] : other.per_core_op_metrics_db_) { - per_core_op_metrics_db_[core_id] = op_metric_db; - } - if (step_name_.empty()) step_name_ = other.step_name_; -} - -std::string StepDetails::DebugString() const { - std::string result = "(["; - for (int i = 0, end = markers_.size(); i < end; i++) { - if (i > 0) absl::StrAppend(&result, ", "); - absl::StrAppend(&result, PrintStepMarker(markers_[i])); - } - absl::StrAppend(&result, "], ["); - for (int i = 0, end = events_.size(); i < end; i++) { - if (i > 0) absl::StrAppend(&result, ", "); - absl::StrAppend(&result, PrintEventTypeSpan(events_[i])); - } - return absl::StrCat(result, "])"); -} - -bool StepDetails::operator==(const StepDetails& other) const { - const auto& other_markers = other.Markers(); - if (markers_.size() != other_markers.size()) return false; - for (uint64 i = 0; i < markers_.size(); i++) { - if (markers_[i] != other_markers[i]) return false; - } - const auto& other_events = other.Events(); - if (events_.size() != other_events.size()) return false; - for (uint64 i = 0; i < events_.size(); i++) { - if (events_[i] != other_events[i]) return false; - } - return true; -} - -bool operator==(const StepEvents& a, const StepEvents& b) { - if (a.size() != b.size()) return false; - for (const auto& id_details : a) { - const auto a_id = id_details.first; - const auto& a_details = id_details.second; - const auto* b_details = gtl::FindOrNull(b, a_id); - if (b_details == nullptr) return false; - if (a_details != *b_details) return false; - } - return true; -} - -PrecisionStats ComputePrecisionStats( - const StepEvents& nonoverlapped_step_events) { - int64_t compute_32bit_ps = 0; - int64_t compute_16bit_ps = 0; - for (const auto& id_details : nonoverlapped_step_events) { - for (const auto& event : id_details.second.Events()) { - switch (event.type) { - case DEVICE_COMPUTE_32: - compute_32bit_ps += event.span.duration_ps(); - break; - case DEVICE_COMPUTE_16: - compute_16bit_ps += event.span.duration_ps(); - break; - default: - break; - } - } - } - PrecisionStats precision_stats; - precision_stats.set_compute_32bit_ps(compute_32bit_ps); - precision_stats.set_compute_16bit_ps(compute_16bit_ps); - return precision_stats; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/event_span.h b/tensorflow/core/profiler/utils/event_span.h index 2b7b2c75b2f700..04506b6e6c6811 100644 --- a/tensorflow/core/profiler/utils/event_span.h +++ 
b/tensorflow/core/profiler/utils/event_span.h
@@ -16,254 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_
 
-#include
-#include
-#include
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/timespan.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-
-// The various event types. Enumerations are numbered such that a bigger number
-// has a higher priority than a smaller number when used in execution-time
-// breakdown.
-enum EventType {
-  // No event associated with the time. It could be that the machine was idle
-  // or executing some events which were not traced.
-  UNKNOWN_TIME = 0,
-  // Host is computing.
-  HOST_COMPUTE = 10,
-  // Host is preprocessing the data before the execution on device.
-  HOST_PREPROCESS = 20,
-  // Host is postprocessing the data after the execution on device.
-  HOST_POSTPROCESS = 30,
-  // Host is batching data (for inference).
-  HOST_BATCH_FORMATION = 40,
-  // Host runtime, like memory allocation, etc.
-  HOST_RUNTIME = 50,
-  // Host is compiling.
-  HOST_COMPILE = 60,
-  // Host-to-host communication.
-  HOST_TO_HOST = 70,
-  // Host-to-device communication.
-  HOST_TO_DEVICE = 80,
-  // Host is preparing to launch a computation on device.
-  HOST_PREPARE = 90,
-  // Assigns a smaller priority to DEVICE_COLLECTIVES than HOST_WAIT_INPUT,
-  // because if an all-reduce event overlaps with a host-wait-input event,
-  // we want to count it as waiting for input.
-  // Collective Ops such as All-Reduce.
-  DEVICE_COLLECTIVES = 100,
-  // Host is waiting for input.
-  HOST_WAIT_INPUT = 110,
-  // Device-to-device communication.
-  DEVICE_TO_DEVICE = 120,
-  // Device-to-host communication.
-  DEVICE_TO_HOST = 130,
-  // Device is computing with 32-bit precision.
-  DEVICE_COMPUTE_32 = 140,
-  // Device is computing with 16-bit precision.
-  DEVICE_COMPUTE_16 = 150,
-  // Device is waiting for another device.
-  DEVICE_WAIT_DEVICE = 160,
-  // Device is waiting for host.
-  DEVICE_WAIT_HOST = 170,
-  LAST_EVENT_TYPE = DEVICE_WAIT_HOST
-};
-
-// Generic event types that are shown to the user.
-enum GenericEventType {
-  kFirstGenericEventType = 1,
-  // Device is computing.
-  kDeviceCompute = kFirstGenericEventType,
-  // Device-to-device communication.
-  kDeviceToDevice,
-  // Collective Ops such as All-Reduce and NCCL.
-  kDeviceCollectives,
-  // Host is computing.
-  kHostCompute,
-  // Host is preparing to launch a computation on device.
-  kHostPrepare,
-  // Device waiting for input from the host.
-  kInput,
-  // Device sending output to the host.
-  kOutput,
-  // Host is compiling.
-  kCompile,
-  // No recognized event associated with the time.
-  kAllOthers,
-  kLastGenericEventType = kAllOthers,
-};
-
-// Contains the type and timespan of an event.
-struct EventTypeSpan {
-  EventType type;                // type of this event.
-  tsl::profiler::Timespan span;  // timespan of this event.
-  EventTypeSpan(EventType t, tsl::profiler::Timespan s) : type(t), span(s) {}
-  // Equality test.
-  bool operator==(const EventTypeSpan& other) const {
-    return type == other.type && span == other.span;
-  }
-  // Inequality test.
-  bool operator!=(const EventTypeSpan& other) const {
-    return !(*this == other);
-  }
-};
-
-enum class StepMarkerType {
-  // "TraceContext" TraceMe events.
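-  // (The host marks step boundaries itself with an explicit annotation.)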
- kExplicitHostStepMarker, - // Identified by group_events (e.g., FunctionRun, SessionRun). - kImplicitHostStepMarker, - // Derived from the result of group_events. A device step marker starts with - // the first device event of the group and ends with the last event of the - // group. - kDeviceStepMarker, -}; - -// Record of an event that is used as a step marker. -struct StepMarker { - StepMarkerType type; - std::string event_name; // name of this event. - std::string step_name; - tsl::profiler::Timespan span; // timespan of this event. - StepMarker(StepMarkerType step_marker_type, absl::string_view name, - tsl::profiler::Timespan s) - : type(step_marker_type), event_name(name), span(s) {} - // Equality test. - bool operator==(const StepMarker& other) const { - return type == other.type && event_name == other.event_name && - span == other.span; - } - // Inequality test. - bool operator!=(const StepMarker& other) const { return !(*this == other); } -}; - -// Details of a step. Note that this could be the result of combining the -// StepDetails of the same step executed on different cores. -class StepDetails { - public: - StepDetails() : device_memory_transfers_(3) {} - - const std::vector& Markers() const { return markers_; } - const std::vector& Events() const { return events_; } - - const absl::flat_hash_map& Collectives() const { - return collectives_; - } - const std::vector& DeviceMemoryTransfers() const { - return device_memory_transfers_; - } - - absl::flat_hash_map& PerCoreOpMetricsDb() { - return per_core_op_metrics_db_; - } - // Returns the step time. - tsl::profiler::Timespan StepTime() const; - // Adds a step-marker to this step. - void AddMarker(const StepMarker& m); - // Adds an EventTypeSpan to this step. - void AddEvent(const EventTypeSpan& e); - // Adds a collective op to this step. - void AddCollectiveOpEvent(uint64 core_id, const AllReduceInfo& e); - // Appends device memory transfer events to this step. - // Only event type of HOST_TO_DEVICE/DEVICE_TO_DEVICE/DEVICE_TO_HOST are - // allowed. - void AddDeviceMemoryTransferEvent(EventType event_type, - const tsl::profiler::Timespan& time_span, - uint64 bytes); - // Returns the step name. - std::string StepName() const { return step_name_; } - // Sets the name of this step. - void SetStepName(std::string step_name) { step_name_ = step_name; } - - // Converts from overlapped events to non-overlapped events. - StepDetails ToNonOverlapped() const; - - // Combines other. - void Combine(const StepDetails& other); - - // Equality test. - bool operator==(const StepDetails& other) const; - // Inequality test. - bool operator!=(const StepDetails& other) const { return !(*this == other); } - - // Returns a string that prints the content of this object. - std::string DebugString() const; - - void SetPerCoreOpMetricsDb(OpMetricsDb db, uint32 core_id) { - per_core_op_metrics_db_[core_id] = db; - } - - private: - // Accumulates the device memory transfers from another step to this step. - void AggregateDeviceMemoryTransfers( - const std::vector& device_memory_transfers); - - // All step-markers found for marking this step in the traces. There could be - // multiple step-markers for a single step for different reasons. One such - // reason is that there may be one step-marker for the same step on each core; - // so after combining the StepDetails from multiple cores, there would be - // multiple step-markers for the same step. - std::vector markers_; - // All events belonging to this step. 
- std::vector events_; - // Collective operation related events such as all-reduce etc. - absl::flat_hash_map collectives_; - // Device memory transfers (including time and bytes involved). - // TODO(jiesun): Consider to use IntervalSet instead of just sum up the event - // durations. - std::vector device_memory_transfers_; - std::string step_name_; - - absl::flat_hash_map per_core_op_metrics_db_; -}; - -// Map from step_id to the events happened in that step. -using StepEvents = absl::flat_hash_map; - -// Equality test for StepEvents. -bool operator==(const StepEvents& a, const StepEvents& b); - -// Returns the name of the given EventType. -std::string PrintEventType(EventType event_type); - -// Returns the string of the given GenericEventType. -absl::string_view GetGenericEventTypeStr(GenericEventType event_type); - -// Returns a string that prints the given EventTypeSpan. -std::string PrintEventTypeSpan(const EventTypeSpan& event_type_span); - -// Returns a string that prints the given StepMarker. -std::string PrintStepMarker(const StepMarker& step_marker); - -// Returns a string that prints the given StepEvents. -std::string PrintStepEvents(const StepEvents& step_events); - -// Unions the map of StepEvents and combines the src StepEvents into dst. -void UnionCombineStepEvents(const StepEvents& src, StepEvents* dst); - -// Intersects the map of StepEvents and combines the src StepEvents into dst. -void IntersectCombineStepEvents(const StepEvents& src, StepEvents* dst); - -// Converts from overlapped events to non-overlapped events. -std::vector ToNonOverlappedEvents( - const std::vector& overlapped_events); - -// Converts from overlapped step-events to non-overlapped step events. -StepEvents ToNonOverlappedStepEvents(const StepEvents& overlapped_step_events); - -// Returns the precision stats of the given non-overlapped step events. -PrecisionStats ComputePrecisionStats( - const StepEvents& nonoverlapped_step_events); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/event_span.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_ diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.cc b/tensorflow/core/profiler/utils/gpu_event_stats.cc deleted file mode 100644 index eaa4c6ae17ae9d..00000000000000 --- a/tensorflow/core/profiler/utils/gpu_event_stats.cc +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" - -#include - -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { -namespace { - -const absl::string_view kAnnotationDelimiter = "::"; - -} - -GpuEventStats::GpuEventStats(const XEventVisitor* event) { - event->ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kTfOp: - tf_op_fullname = stat.StrOrRefValue(); - break; - case StatType::kEquation: - equation = stat.StrOrRefValue(); - break; - case StatType::kTensorShapes: - tensor_shapes = stat.StrOrRefValue(); - break; - case StatType::kHloOp: - hlo_op_names = - absl::StrSplit(stat.StrOrRefValue(), kAnnotationDelimiter); - break; - case StatType::kHloModule: - hlo_module_name = stat.StrOrRefValue(); - break; - case StatType::kProgramId: - program_id = stat.IntOrUintValue(); - break; - case StatType::kKernelDetails: - kernel_details = stat.StrOrRefValue(); - break; - case StatType::kMemcpyDetails: - memcpy_details = stat.StrOrRefValue(); - break; - case StatType::kCorrelationId: - correlation_id = static_cast(stat.IntOrUintValue()); - break; - case StatType::kGroupId: - group_id = stat.IntValue(); - break; - case StatType::kIsEager: - is_eager = stat.BoolValue(); - break; - case StatType::kCudaGraphExecId: - cuda_graph_exec_id = stat.UintValue(); - break; - case StatType::kCudaGraphId: - cuda_graph_id_for_inner_node = stat.UintValue(); - break; - case StatType::kScopeRangeId: - scope_range_id = stat.IntValue(); - break; - default: - break; - } - }); -} - -LaunchEventStats::LaunchEventStats(const XEventVisitor* event) { - event->ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kDeviceId: - device_id = stat.IntOrUintValue(); - break; - case StatType::kCorrelationId: - correlation_id = static_cast(stat.IntOrUintValue()); - break; - case StatType::kGroupId: - group_id = stat.IntValue(); - break; - default: - break; - } - }); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.h b/tensorflow/core/profiler/utils/gpu_event_stats.h index 369492dcceef88..574e333ae6784f 100644 --- a/tensorflow/core/profiler/utils/gpu_event_stats.h +++ b/tensorflow/core/profiler/utils/gpu_event_stats.h @@ -16,67 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_GPU_EVENT_STATS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_GPU_EVENT_STATS_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -// Stats from a GPU stream XEvent. -struct GpuEventStats { - explicit GpuEventStats(const XEventVisitor* event); - - bool IsKernel() const { return !kernel_details.empty(); } - bool IsMemCpy() const { return !memcpy_details.empty(); } - bool IsCudaGraphExecution() const { return cuda_graph_exec_id.has_value(); } - - bool IsXlaOp() const { return !hlo_op_names.empty(); } - bool IsTfOp() const { return !tf_op_fullname.empty(); } - - // Stats from TensorFlow. 
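-  // E.g., tf_op_fullname takes the form "op_name:op_type", such as "mul:Mul".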
- absl::string_view tf_op_fullname; - absl::string_view equation; - absl::string_view tensor_shapes; - - // Stats from XLA. - std::vector hlo_op_names; - absl::string_view hlo_module_name; - std::optional program_id; - - // Stats from CUPTI. - absl::string_view kernel_details; - absl::string_view memcpy_details; - std::optional correlation_id; - std::optional scope_range_id; - - // Stats derived by grouping. - std::optional group_id; - bool is_eager = false; - std::optional cuda_graph_exec_id; - std::optional cuda_graph_id_for_inner_node; -}; - -// Stats for a host-side GPU launch XEvent. -struct LaunchEventStats { - explicit LaunchEventStats(const XEventVisitor* event); - - bool IsLaunch() const { - return device_id.has_value() && correlation_id.has_value(); - } - - // Stats from CUPTI. - std::optional device_id; - std::optional correlation_id; - - // Stat derived by grouping. - std::optional group_id; -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/gpu_event_stats.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_GPU_EVENT_STATS_H_ diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.cc b/tensorflow/core/profiler/utils/hardware_type_utils.cc deleted file mode 100644 index 22beb1d51bc860..00000000000000 --- a/tensorflow/core/profiler/utils/hardware_type_utils.cc +++ /dev/null @@ -1,347 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" - -#include - -#include "absl/container/btree_map.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" - -namespace tensorflow { -namespace profiler { -namespace { - -// The calculation methods is referred from Nvidia developer forum: -// https://forums.developer.nvidia.com/t/how-to-calculate-the-tensor-core-fp16-performance-of-h100/244727 -// Below data are calculated from the various NVidia whitepapers/specs. 
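-// The values below are ops per SM per clock cycle (an FMA counts as two ops);
-// peak per-SM throughput is obtained by scaling them with the clock rate in
-// GHz, as GetGpuFlopCapabilitiesPerSM() below does via ScaleWith().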
- -// https://resources.nvidia.com/en-us-tensor-core/gtc22-whitepaper-hopper -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_9_0 = { - .cuda_core = - { - .fp64_tflops = 128, - .fp32_tflops = 256, - .bf16_tflops = 512, - .fp16_tflops = 512, - .int8_tops = 1024, - }, - .tensor_core = - { - .fp64_tflops = 256, - .fp32_tflops = 2048, - .bf16_tflops = 4096, - .fp16_tflops = 4096, - .fp8_tflops = 8192, - .int8_tops = 8192, - }, - .has_tensor_core_sparsity_support = true, -}; - -// https://images.nvidia.com/aem-dam/Solutions/geforce/ada/nvidia-ada-gpu-architecture.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_9 = { - .cuda_core = - { - .fp64_tflops = 128, - .fp32_tflops = 256, - .bf16_tflops = 256, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = - { - .fp32_tflops = 512, - .bf16_tflops = 1024, - .fp16_tflops = 1024, - .fp8_tflops = 2048, - .int8_tops = 2048, - .int4_tops = 4096, - }, - .has_tensor_core_sparsity_support = true, -}; - -// https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_6 = { - .cuda_core = - { - .fp64_tflops = 128, - .fp32_tflops = 256, - .bf16_tflops = 256, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = - { - .fp32_tflops = 256, - .bf16_tflops = 512, - .fp16_tflops = 1024, - .int8_tops = 2048, - .int4_tops = 4096, - }, - .has_tensor_core_sparsity_support = true, -}; - -// https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_0 = { - .cuda_core = - { - .fp64_tflops = 64, - .fp32_tflops = 128, - .bf16_tflops = 256, - .fp16_tflops = 512, - .int8_tops = 512, - }, - .tensor_core = - { - .fp64_tflops = 128, - .fp32_tflops = 1024, - .bf16_tflops = 2048, - .fp16_tflops = 2048, - .int8_tops = 4096, - }, - .has_tensor_core_sparsity_support = true, -}; - -// https://images.nvidia.com/aem-dam/en-zz/Solutions/design-visualization/technologies/turing-architecture/NVIDIA-Turing-Architecture-Whitepaper.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_7_5 = { - .cuda_core = - { - .fp64_tflops = 64, - .fp32_tflops = 128, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = - { - .fp16_tflops = 1024, - .int8_tops = 2048, - .int4_tops = 4096, - }, - .has_tensor_core_sparsity_support = false, -}; - -// https://images.nvidia.com/content/volta-architecture/pdf/volta-architecture-whitepaper.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_7_0 = { - .cuda_core = - { - .fp64_tflops = 64, - .fp32_tflops = 128, - .bf16_tflops = 0.0, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = - { - .fp16_tflops = 1024, - }, - .has_tensor_core_sparsity_support = false, -}; - -// https://images.nvidia.com/content/pdf/tesla/whitepaper/pascal-architecture-whitepaper.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_6_1 = { - .cuda_core = - { - .fp64_tflops = 8, - .fp32_tflops = 256, - .fp16_tflops = 4, - .int8_tops = 1024, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -// https://images.nvidia.com/content/pdf/tesla/whitepaper/pascal-architecture-whitepaper.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_6_0 = { - .cuda_core = - { - .fp64_tflops = 64, - .fp32_tflops = 128, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -// 
https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-product-literature/NVIDIA-Kepler-GK110-GK210-Architecture-Whitepaper.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_5_0 = { - .cuda_core = - { - .fp64_tflops = 4, - .fp32_tflops = 256, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -// https://www.nvidia.com/content/PDF/product-specifications/GeForce_GTX_680_Whitepaper_FINAL.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_3_0 = { - .cuda_core = - { - .fp64_tflops = 128, - .fp32_tflops = 384, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_2_0 = { - .cuda_core = - { - .fp64_tflops = 8, - .fp32_tflops = 64, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -GpuFlopCapabilities GetNvidiaFlopCapsPerSMPerCycle(int major_comp_cap, - int minor_comp_cap) { - static const auto& kPerSMFlopCapsTable = - *new absl::btree_map{ - // TODO: Add incoming blackwell, and other old GPUS - {9000, &kComputeCap_PerSM_PerCycle_9_0}, - {8090, &kComputeCap_PerSM_PerCycle_8_9}, - {8060, &kComputeCap_PerSM_PerCycle_8_6}, - {8000, &kComputeCap_PerSM_PerCycle_8_0}, - {7050, &kComputeCap_PerSM_PerCycle_7_5}, - {7000, &kComputeCap_PerSM_PerCycle_7_0}, - {6010, &kComputeCap_PerSM_PerCycle_6_1}, - {6000, &kComputeCap_PerSM_PerCycle_6_0}, - {5000, &kComputeCap_PerSM_PerCycle_5_0}, - {3000, &kComputeCap_PerSM_PerCycle_3_0}, - {2000, &kComputeCap_PerSM_PerCycle_2_0}, - }; - - const int normalized_compute_cap = - major_comp_cap * 1000 + minor_comp_cap * 10; - GpuFlopCapabilities flops_cap{}; - auto it = kPerSMFlopCapsTable.lower_bound(normalized_compute_cap); - if (it == kPerSMFlopCapsTable.end()) { - LOG(WARNING) << "GPU compute capability " << major_comp_cap << "." - << minor_comp_cap << " is too old to support."; - } else { - flops_cap = *it->second; - if (it->first != normalized_compute_cap) { - LOG(WARNING) << "GPU compute capability " << major_comp_cap << "." - << minor_comp_cap - << " is not found. Use the highest compute cap known " - << (it->first / 1000) << "." << ((it->first % 1000) / 10) - << " instead."; - } - } - return flops_cap; -} - -GpuFlopCapabilities GetGpuFlopCapabilitiesPerSM( - const DeviceCapabilities& device_cap) { - GpuFlopCapabilities flops_cap{}; - if (device_cap.device_vendor() == kDeviceVendorNvidia) { - flops_cap = - GetNvidiaFlopCapsPerSMPerCycle(device_cap.compute_capability().major(), - device_cap.compute_capability().minor()); - } else { - LOG(WARNING) << "Unsupported device vendor " << device_cap.device_vendor(); - } - - flops_cap.ScaleWith(device_cap.clock_rate_in_ghz()); - return flops_cap; -} - -} // namespace - -double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap) { - GpuFlopCapabilities sm_flops = GetGpuFlopCapabilitiesPerSM(device_cap); - double result = std::max( - {sm_flops.cuda_core.fp32_tflops, sm_flops.cuda_core.fp16_tflops, - sm_flops.tensor_core.fp32_tflops, sm_flops.tensor_core.fp16_tflops}); - VLOG(3) << "GetFlopMaxThroughputPerSM get result: " << result << " GFLOPs"; - return result; -} - -double GetSharedMemoryBandwidthPerSM(const DeviceCapabilities& device_cap) { - // https://docs.nvidia.com/gameworks/content/developertools/desktop/analysis/report/cudaexperiments/kernellevel/memorystatisticsshared.htm - // Compute capability 2.0, each bank has bandwidth of 4 bytes per 2 cycles. - // For compute capability 3.0 and above, each bank has bandwidth 8 bytes per - // cycle. 
-  double transaction_bytes_per_cycle =
-      device_cap.compute_capability().major() <= 2 ? (32 * 4 / 2) : (32 * 8);
-  double giga_bytes_per_second =
-      transaction_bytes_per_cycle * device_cap.clock_rate_in_ghz();
-  return tsl::profiler::GigaToUni(giga_bytes_per_second);
-}
-
-absl::string_view GpuModelName(const DeviceCapabilities& device_cap) {
-  if (device_cap.device_vendor() == kDeviceVendorNvidia) {
-    switch (device_cap.compute_capability().major()) {
-      case 2:
-        return "Nvidia GPU (Fermi)";
-      case 3:
-        return "Nvidia GPU (Kepler)";
-      case 5:
-        return "Nvidia GPU (Maxwell)";
-      case 6:
-        return "Nvidia GPU (Pascal)";
-      case 7:
-        if (device_cap.compute_capability().minor() < 5) {
-          return "Nvidia GPU (Volta)";
-        } else {
-          return "Nvidia GPU (Turing)";
-        }
-      case 8:
-        if (device_cap.compute_capability().minor() < 9) {
-          return "Nvidia GPU (Ampere)";
-        } else {
-          return "Nvidia GPU (Ada Lovelace)";
-        }
-      case 9:
-        return "Nvidia GPU (Hopper)";
-      case 10:
-        return "Nvidia GPU (Blackwell)";
-      default:
-        return "Nvidia GPU";
-    }
-  } else if (device_cap.device_vendor() == kDeviceVendorAMD) {
-    switch (device_cap.compute_capability().major()) {
-      case 9:
-        return "AMD GPU - gfx-9XX series";
-      case 10:
-        return "AMD GPU - gfx-10XX series";
-      case 11:
-        return "AMD GPU - gfx-11XX series";
-      default:
-        return "AMD GPU";
-    }
-  } else {
-    LOG(ERROR) << "Unknown device vendor " << device_cap.device_vendor();
-    return "";
-  }
-}
-
-HardwareType ParseHardwareType(absl::string_view device_type) {
-  if (absl::StrContains(device_type, "GPU")) return HardwareType::GPU;
-  if (device_type == "CPU") return HardwareType::CPU_ONLY;
-  if (absl::StrContains(device_type, "TPU")) return HardwareType::TPU;
-  return HardwareType::UNKNOWN_HARDWARE;
-}
-
-bool HasDevice(HardwareType x) { return x > tensorflow::profiler::CPU_ONLY; }
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.h b/tensorflow/core/profiler/utils/hardware_type_utils.h
index 41b1bd4b65471c..c2fc5266bc3778 100644
--- a/tensorflow/core/profiler/utils/hardware_type_utils.h
+++ b/tensorflow/core/profiler/utils/hardware_type_utils.h
@@ -16,67 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HARDWARE_TYPE_UTILS_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_HARDWARE_TYPE_UTILS_H_
-#include "absl/strings/string_view.h"
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-
-struct GpuFlopCapabilities {
-  struct FlopCapabilityOnPrecisions {
-    double fp64_tflops = 0;
-    double fp32_tflops = 0;  // also used for tf32 on Nvidia tensor cores
-    double bf16_tflops = 0;
-    double fp16_tflops = 0;
-    double fp8_tflops = 0;
-    double int8_tops = 0;
-    double fp4_tflops = 0;
-    double int4_tops = 0;
-
-    void ScaleWith(double scale) {
-      fp64_tflops *= scale;
-      fp32_tflops *= scale;
-      bf16_tflops *= scale;
-      fp16_tflops *= scale;
-      fp8_tflops *= scale;
-      int8_tops *= scale;
-      fp4_tflops *= scale;
-      int4_tops *= scale;
-    }
-  };
-
-  FlopCapabilityOnPrecisions cuda_core;
-  FlopCapabilityOnPrecisions tensor_core;
-  bool has_tensor_core_sparsity_support = false;
-
-  void ScaleWith(double scale) {
-    cuda_core.ScaleWith(scale);
-    tensor_core.ScaleWith(scale);
-  }
-};
-
-// Returns the peak FLOP throughput of the GPU in GFLOPs per streaming
-// multiprocessor, taken as the maximum over the CUDA-core and tensor-core
-// fp32/fp16 rates.
-// TODO: Needs a design for how to use the sparsity FLOP capabilities.
-double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap);
-
-// For Nvidia GPUs, returns the shared memory bandwidth in bytes per second
-// for a single SM, given the GPU core frequency in device_cap.
-double GetSharedMemoryBandwidthPerSM(const DeviceCapabilities& device_cap);
-
-// Returns the GPU model name from the given DeviceCapabilities.
-// For Nvidia GPUs, the name is like "Nvidia GPU (Kepler)" or "Nvidia GPU
-// (Turing)". For AMD GPUs, the name is like "AMD GPU - gfx-10XX series".
-// The model name reported here for Nvidia GPUs in fact refers to the
-// microarchitecture name.
-absl::string_view GpuModelName(const DeviceCapabilities& device_cap);
-
-HardwareType ParseHardwareType(absl::string_view device_type);
-
-// Returns true if the given hardware type has a device.
-bool HasDevice(HardwareType x);
-
-}  // namespace profiler
-}  // namespace tensorflow
+#include "xprof/utils/hardware_type_utils.h"  // from @org_xprof  // IWYU pragma: export

 #endif  // TENSORFLOW_CORE_PROFILER_UTILS_HARDWARE_TYPE_UTILS_H_
diff --git a/tensorflow/core/profiler/utils/hardware_type_utils_test.cc b/tensorflow/core/profiler/utils/hardware_type_utils_test.cc
deleted file mode 100644
index 9476848a650dcc..00000000000000
--- a/tensorflow/core/profiler/utils/hardware_type_utils_test.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/utils/hardware_type_utils.h"
-
-#include "xla/tsl/profiler/utils/math_utils.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-TEST(HardwareTypeUtilsTest, H100PeakComputTFlops) {
-  DeviceCapabilities device_cap;
-  // For NVIDIA H100 PCIe 80 GB, according to
-  // https://resources.nvidia.com/en-us-data-center-overview/gtc22-whitepaper-hopper
-  // https://www.techpowerup.com/gpu-specs/h100-pcie-80-gb.c3899
-  device_cap.set_clock_rate_in_ghz(1.620);
-  device_cap.set_num_cores(114);
-  device_cap.set_memory_size_in_bytes(
-      tsl::profiler::GibiToGiga(tsl::profiler::GigaToUni(80)));
-  device_cap.set_memory_bandwidth(tsl::profiler::GigaToUni(2.04 * 1024));
-  device_cap.set_device_vendor("Nvidia");
-  device_cap.mutable_compute_capability()->set_major(9);
-  device_cap.mutable_compute_capability()->set_minor(0);
-
-  // Get target TFLOPS per SM and check.
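-  // Editorial sanity check, not part of the original test, assuming the
-  // compute capability 9.0 table above (4096 fp16 tensor-core FLOPs/SM/cycle):
-  //   per-SM: 4096 FLOPs/cycle * 1.620 GHz ~= 6635.5 GFLOP/s
-  //   device: 6635.5 GFLOP/s * 114 SMs / 1000 ~= 756.4 TFLOPS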
- double peak_tflops = - GetFlopMaxThroughputPerSM(device_cap) * device_cap.num_cores() / 1000.0; - EXPECT_NEAR(peak_tflops, 756, /*abs_error=*/1.0); -} - -TEST(HardwareTypeUtilsTest, A100PeakComputTFlops) { - DeviceCapabilities device_cap; - // For NVIDIA A100 SXM4 80 GB, according to: - // https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf - // https://www.techpowerup.com/gpu-specs/a100-sxm4-80-gb.c3746 - device_cap.set_clock_rate_in_ghz(1.410); - device_cap.set_num_cores(108); - device_cap.set_memory_size_in_bytes( - tsl::profiler::GibiToGiga(tsl::profiler::GigaToUni(80))); - device_cap.set_memory_bandwidth(tsl::profiler::GigaToUni(2.04 * 1024)); - device_cap.set_device_vendor("Nvidia"); - device_cap.mutable_compute_capability()->set_major(8); - device_cap.mutable_compute_capability()->set_minor(0); - - double peak_tflops = - GetFlopMaxThroughputPerSM(device_cap) * device_cap.num_cores() / 1000.0; - EXPECT_NEAR(peak_tflops, 312, /*abs_error=*/1.0); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_module_map.cc b/tensorflow/core/profiler/utils/hlo_module_map.cc deleted file mode 100644 index d4683d22f33efa..00000000000000 --- a/tensorflow/core/profiler/utils/hlo_module_map.cc +++ /dev/null @@ -1,181 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hlo_module_map.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "tsl/profiler/lib/traceme_encode.h" - -#if GOOGLE_CUDA -#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#endif -#include "absl/log/check.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/core/platform/path.h" -#include "tensorflow/core/profiler/utils/hlo_module_utils.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/hlo_proto_to_module.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -#if GOOGLE_CUDA -int64_t ShapeSize(const xla::Shape& shape) { - constexpr int64_t kPointerSize = 8; - return xla::ShapeUtil::ByteSizeOf(shape, kPointerSize); -} -#endif - -} // namespace - -HloInstructionWrapper::HloInstructionWrapper( - const xla::HloInstruction* instr, const xla::HloCostAnalysis* cost_analysis) - : instr_(instr), - op_full_name_( - tsl::profiler::TraceMeOp(Metadata().op_name(), Metadata().op_type())), - tf_op_name_(tsl::profiler::TfOpFullname(Metadata().op_type(), - Metadata().op_name())), - category_(instr_->ToCategory()), - expression_(tensorflow::profiler::UncachedExpression( - instr_, false, tensorflow::profiler::kMaxHlolNameSize)) { - ProcessXlaCostAnalysis(cost_analysis); -} - -HloModuleWrapper::HloModuleWrapper( - const xla::HloProto& hlo_proto, - std::function shape_func) - : HloModuleWrapper(ConvertHloProtoToModuleIgnoringErrors(hlo_proto), - shape_func) {} - -HloModuleWrapper::HloModuleWrapper( - std::unique_ptr module, - std::function shape_func) - : module_(std::move(module)) { - if (module_ == nullptr) return; - - const xla::HloCostAnalysis* cost_analysis = nullptr; -#if GOOGLE_CUDA - if (shape_func == nullptr) shape_func = ShapeSize; - xla::HloCostAnalysis::Options options; - options.shape_size = shape_func; - xla::gpu::GpuHloCostAnalysis gpu_cost_analysis(options); - - const xla::HloComputation* hlo_computation = module_->entry_computation(); - gpu_cost_analysis.ReserveVisitStates(hlo_computation->instruction_count()); - tsl::Status analysis_status = hlo_computation->Accept(&gpu_cost_analysis); - if (analysis_status.ok()) { - // Clear the visit state as it isn't used by anybody and it uses a lot of - // memory. - gpu_cost_analysis.DestroyVisitState(); - } else { - LOG(ERROR) << "Failed to create cost analysis: " << analysis_status; - } - cost_analysis = &gpu_cost_analysis; -#endif - - // Populate instructions_by_name_ with module. - for (const xla::HloComputation* computation : module_->computations()) { - for (const xla::HloInstruction* instr : computation->instructions()) { - instructions_by_name_.try_emplace( - instr->name(), HloInstructionWrapper(instr, cost_analysis)); - } - } - // Gather nested fusion instructions. - for (const xla::HloComputation* computation : module_->computations()) { - // Some modules still seem to have "dead" fusions computations. In this - // case, IsFusionComputation() = true but there is no parent - // FusionInstruction(). 
- if (computation->FusionInstruction() != nullptr) { - GatherFusionInstructions(computation->FusionInstruction()); - } - } -} - -// Function to gather all the instructions in a fusion computation. -void HloModuleWrapper::GatherFusionInstructions(xla::HloInstruction* inst) { - HloInstructionWrapper* fused_inst_wrapper = - GetMutableHloInstruction(inst->name()); - DCHECK(fused_inst_wrapper != nullptr); - if (!fused_inst_wrapper->FusedChildren().empty()) return; - for (auto* fused : inst->fused_instructions()) { - const auto child_inst_wrapper = GetHloInstruction(fused->name()); - DCHECK(child_inst_wrapper != nullptr); - fused_inst_wrapper->AddFusedChild(child_inst_wrapper); - if (fused->opcode() == xla::HloOpcode::kFusion) { - GatherFusionInstructions(fused); - } - } -} - -HloInstructionWrapper* HloModuleWrapper::GetMutableHloInstruction( - absl::string_view hlo_name) { - auto it = instructions_by_name_.find(hlo_name); - if (it != instructions_by_name_.end()) return &it->second; - return nullptr; -} - -const HloInstructionWrapper* HloModuleWrapper::GetHloInstruction( - absl::string_view hlo_name) const { - auto it = instructions_by_name_.find(hlo_name); - if (it != instructions_by_name_.end()) return &it->second; - return nullptr; -} - -std::string HloInstructionWrapper::source_info() const { - if (!Metadata().source_file().empty()) { - return absl::StrCat(io::Basename(Metadata().source_file()), ":", - Metadata().source_line()); - } else { - return std::string(); - } -} - -void AddHloProto(HloModuleMap& hlo_module_map, uint64_t program_id, - const xla::HloProto& hlo_proto) { - auto hlo_module = ConvertHloProtoToModule(hlo_proto); - if (!hlo_module.ok()) { - LOG(ERROR) << hlo_module.status(); - return; - } - hlo_module_map.try_emplace(program_id, - HloModuleWrapper(std::move(hlo_module).value(), - /*shape_func=*/nullptr)); -} - -void ProcessHloModuleMapFromXSpace(HloModuleMap& hlo_module_map, - const XSpace* space) { - for (auto& [program_id, hlo_proto] : ParseHloProtosFromXSpace(*space)) { - AddHloProto(hlo_module_map, program_id, *hlo_proto); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_module_map.h b/tensorflow/core/profiler/utils/hlo_module_map.h index ab6898af72ed84..e6c58633d5f334 100644 --- a/tensorflow/core/profiler/utils/hlo_module_map.h +++ b/tensorflow/core/profiler/utils/hlo_module_map.h @@ -16,200 +16,6 @@ limitations under the License. 
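A minimal consumer-side sketch (editorial) of the map assembled above; `space`, `program_id`, and `hlo_name` are assumed inputs, not names from this patch:

#include "tensorflow/core/profiler/utils/hlo_module_map.h"

void LogInstructionCost(const tensorflow::profiler::XSpace& space,
                        uint64_t program_id, absl::string_view hlo_name) {
  tensorflow::profiler::HloModuleMap hlo_module_map;
  // Parse every HloProto embedded in the XSpace and wrap its instructions.
  tensorflow::profiler::ProcessHloModuleMapFromXSpace(hlo_module_map, &space);
  const tensorflow::profiler::HloInstructionWrapper* instr =
      tensorflow::profiler::GetHloInstruction(hlo_module_map, program_id,
                                              hlo_name);
  if (instr != nullptr) {
    // flops()/bytes_accessed() are populated from GpuHloCostAnalysis only in
    // CUDA builds; otherwise they remain zero.
    LOG(INFO) << instr->op_full_name() << " flops=" << instr->flops();
  }
}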
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_MAP_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_MAP_H_ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo.pb.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "tensorflow/core/profiler/utils/hlo_module_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -class HloInstructionInterface { - public: - virtual ~HloInstructionInterface() = default; - virtual absl::string_view Name() const = 0; - virtual xla::HloOpcode HloOpcode() const = 0; - virtual absl::string_view Category() const = 0; - virtual std::string HloOpcodeString() const = 0; - virtual const xla::OpMetadata& Metadata() const = 0; - virtual size_t flops() const = 0; - virtual size_t bytes_accessed() const = 0; - virtual std::string_view op_full_name() const = 0; - virtual std::string_view TfOpName() const = 0; - virtual std::string source_info() const = 0; - virtual bool isRoot() const = 0; - virtual bool IsFusion() const = 0; - virtual const std::string& Expression() const = 0; - - virtual void ProcessXlaCostAnalysis( - const xla::HloCostAnalysis* cost_analysis) = 0; - virtual std::string OpLocationStack(int32_t frame_id) const = 0; - virtual tsl::profiler::OpSourceInfo SourceInfo() const = 0; -}; - -// This wrapper allows caching the results of HloInstruction methods. -// This wrapper is not thread safe. -class HloInstructionWrapper : public HloInstructionInterface { - public: - explicit HloInstructionWrapper( - const xla::HloInstruction* instr, - const xla::HloCostAnalysis* cost_analysis = nullptr); - - // Non copyable - HloInstructionWrapper(const HloInstructionWrapper&) = delete; - HloInstructionWrapper& operator=(const HloInstructionWrapper&) = delete; - // Movable. 
- HloInstructionWrapper(HloInstructionWrapper&&) = default; - HloInstructionWrapper& operator=(HloInstructionWrapper&&) = default; - - absl::string_view Name() const override { return instr_->name(); } - - xla::HloOpcode HloOpcode() const override { return instr_->opcode(); } - - absl::string_view Category() const override { return category_; } - - std::string HloOpcodeString() const override { - return std::string(xla::HloOpcodeString(instr_->opcode())); - } - - const xla::OpMetadata& Metadata() const override { - return instr_->metadata(); - } - - size_t flops() const override { return flops_; } - size_t bytes_accessed() const override { return bytes_accessed_; } - - std::string_view op_full_name() const override { return op_full_name_; } - std::string_view TfOpName() const override { return tf_op_name_; } - std::string source_info() const override; - - bool isRoot() const override { return instr_->IsRoot(); } - bool IsFusion() const override { return !fused_children_.empty(); }; - - void ProcessXlaCostAnalysis( - const xla::HloCostAnalysis* cost_analysis) override { - if (cost_analysis == nullptr) return; - flops_ = cost_analysis->flop_count(*instr_); - bytes_accessed_ = cost_analysis->bytes_accessed(*instr_); - } - - const std::string& Expression() const override { return expression_; } - - void AddFusedChild(const HloInstructionWrapper* child) { - fused_children_.push_back(child); - }; - - const std::vector& FusedChildren() const { - return fused_children_; - } - - std::string OpLocationStack(int32_t frame_id) const override { - return GetOpLocationStack(frame_id, instr_); - } - - tsl::profiler::OpSourceInfo SourceInfo() const override { - return GetSourceInfo(instr_); - } - - private: - const xla::HloInstruction* instr_; - std::vector fused_children_; - std::string op_full_name_; - std::string tf_op_name_; - size_t flops_ = 0; - size_t bytes_accessed_ = 0; - std::string category_; - std::string expression_; -}; - -// Helper class for accessing HloModule. -class HloModuleInterface { - public: - virtual ~HloModuleInterface() = default; - - // If the module contains no instructions. - virtual bool Empty() const = 0; - virtual absl::string_view Name() const = 0; - // Function to populated nested childs= instructions in a fusion. - virtual void GatherFusionInstructions(xla::HloInstruction* inst) = 0; -}; - -// Wraps HLO module and provides an interface that maps HLO names to -// HloInstructionWrappers. -class HloModuleWrapper : public HloModuleInterface { - public: - explicit HloModuleWrapper( - const xla::HloProto& hlo_proto, - std::function shape_func = nullptr); - - explicit HloModuleWrapper( - std::unique_ptr module, - std::function shape_func); - - const HloInstructionWrapper* GetHloInstruction( - absl::string_view hlo_name) const; - HloInstructionWrapper* GetMutableHloInstruction(absl::string_view hlo_name); - - bool Empty() const override { return instructions_by_name_.empty(); } - - absl::string_view Name() const override { return module_->name(); } - void GatherFusionInstructions(xla::HloInstruction* inst) override; - - private: - std::unique_ptr module_; - - // Map of HloInstructionWrappers by name. - using HloInstructionMap = - absl::flat_hash_map; - HloInstructionMap instructions_by_name_; -}; - -// Map of HloModuleWrappers by program_id. -using HloModuleMap = - absl::flat_hash_map; - -void AddHloProto(HloModuleMap& hlo_module_map, uint64_t program_id, - const xla::HloProto& hlo_proto); - -// Process HloModuleMap from single XSpace. 
-void ProcessHloModuleMapFromXSpace(HloModuleMap& hlo_module_map, - const XSpace* space); - -// WARNING: The returned pointer will be invalidated if HloModuleMap is mutated. -inline const HloModuleWrapper* GetHloModule(const HloModuleMap* hlo_module_map, - uint64_t program_id) { - if (hlo_module_map == nullptr) return nullptr; - auto iter = hlo_module_map->find(program_id); - if (iter == hlo_module_map->end()) return nullptr; - return &iter->second; -} - -inline const HloInstructionWrapper* GetHloInstruction( - const HloModuleMap& hlo_module_map, std::optional program_id, - absl::string_view hlo_name) { - if (!program_id.has_value()) return nullptr; - const auto* hlo_module = GetHloModule(&hlo_module_map, *program_id); - if (hlo_module == nullptr) return nullptr; - return hlo_module->GetHloInstruction(hlo_name); -} - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/hlo_module_map.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_MAP_H_ diff --git a/tensorflow/core/profiler/utils/hlo_module_utils.h b/tensorflow/core/profiler/utils/hlo_module_utils.h index 2de48469253fe9..8b68816a52ebb6 100644 --- a/tensorflow/core/profiler/utils/hlo_module_utils.h +++ b/tensorflow/core/profiler/utils/hlo_module_utils.h @@ -16,103 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_print_options.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" - -namespace tensorflow { -namespace profiler { - -// Sometimes HLO produce a huge string (>100MB). Limit the name size to 1MB. -static constexpr size_t kMaxHlolNameSize = 1000000; - -inline const xla::HloInstruction* FindInstruction(const xla::HloModule& module, - std::string node_name) { - if (absl::StartsWith(node_name, "%")) { - node_name.erase(node_name.begin()); - } - for (const xla::HloComputation* computation : module.computations()) { - auto instrs = computation->instructions(); - auto it = absl::c_find_if(instrs, [&](const xla::HloInstruction* instr) { - // Try with and without "%" at the beginning of the node name. 
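-      // For example, node_name "fusion.1" (leading '%' already stripped
-      // above) matches an instruction named either "fusion.1" or "%fusion.1",
-      // case-insensitively.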
- return absl::EqualsIgnoreCase(instr->name(), node_name) || - absl::EqualsIgnoreCase(instr->name(), - absl::StrCat("%", node_name)); - }); - if (it != instrs.end()) { - return *it; - } - } - return nullptr; -} - -inline const xla::HloComputation* FindComputation( - const xla::HloModule& module, const std::string& comp_name) { - for (const xla::HloComputation* computation : module.computations()) { - if (absl::EqualsIgnoreCase(computation->name(), comp_name)) { - return computation; - } - } - return nullptr; -} - -inline std::string UncachedExpression(const xla::HloInstruction* instr, - bool skip_expression, size_t max_size) { - if (skip_expression) { - return ""; - } - static const auto* hlo_print_options = - new xla::HloPrintOptions(xla::HloPrintOptions() - .set_print_metadata(false) - .set_print_backend_config(false) - .set_print_infeed_outfeed_config(false) - .set_print_operand_shape(true) - .set_print_large_constants(false)); - std::string expression = instr->ToString(*hlo_print_options); - if (expression.size() > max_size) { - expression.resize(max_size); - } - return expression; -} - -inline std::string GetOpLocationStack(int32_t frame_id, - const xla::HloInstruction* instr) { - std::string stack_lines; - xla::HloModule* hlo_module = instr->GetModule(); - while (frame_id != 0) { - xla::HloModule::StackFrame frame = hlo_module->get_stack_frame(frame_id); - if (frame.empty()) { - break; - } - stack_lines.insert(0, absl::StrCat(frame.file_name, ":", frame.line, ":", - frame.column, "\n")); - frame_id = frame.parent_frame_id; - } - - return stack_lines; -}; - -inline tsl::profiler::OpSourceInfo GetSourceInfo( - const xla::HloInstruction* instr) { - if (int32_t stack_frame_id = instr->metadata().stack_frame_id(); - stack_frame_id != 0) { - return {.source_file = instr->metadata().source_file(), - .source_line = instr->metadata().source_line(), - .stack_frame = GetOpLocationStack(stack_frame_id, instr)}; - } - return {.source_file = instr->metadata().source_file(), - .source_line = instr->metadata().source_line()}; -}; -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/hlo_module_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/hlo_module_utils_test.cc b/tensorflow/core/profiler/utils/hlo_module_utils_test.cc deleted file mode 100644 index 18eb2a2cdce7ce..00000000000000 --- a/tensorflow/core/profiler/utils/hlo_module_utils_test.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hlo_module_utils.h" - -#include - -#include -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace profiler { -namespace { - -class HloModuleUtilsTest : public xla::HloTestBase { - protected: - absl::StatusOr> GetModuleWithStackFrames() { - const char file_name[] = "main.py"; - const char function_name[] = "func1"; - const int line_number = 10; - const int column_number = 5; - const int frame_id = 1; - const char text[] = R"( - HloModule a_module - - ENTRY main { - %c = s32[] constant(1) - ROOT %result = s32[] parameter(0) - } - )"; - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(text)); - - auto module_proto = module->ToProto(); - auto index = module_proto.mutable_stack_frame_index(); - index->add_file_names(file_name); - index->add_function_names(function_name); - auto location = index->add_file_locations(); - location->set_file_name_id(frame_id); - location->set_function_name_id(1); - location->set_line(line_number); - location->set_column(column_number); - - auto frame = index->add_stack_frames(); - frame->set_file_location_id(1); - - // Set the stack frame id of the root instruction. - for (auto& computation : *module_proto.mutable_computations()) { - if (computation.id() == module_proto.entry_computation_id()) { - for (auto& instruction : *computation.mutable_instructions()) { - if (instruction.id() == computation.root_id()) { - instruction.mutable_metadata()->set_stack_frame_id(frame_id); - instruction.mutable_metadata()->set_source_file(file_name); - instruction.mutable_metadata()->set_source_line(line_number); - } - } - } - } - - return xla::HloModule::CreateFromProto(module_proto, module->config()); - } -}; - -TEST_F(HloModuleUtilsTest, TestGetLocationStack) { - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module_with_stack_frames, - GetModuleWithStackFrames()); - auto root_instruction = - module_with_stack_frames->entry_computation()->root_instruction(); - EXPECT_EQ(GetOpLocationStack(1, root_instruction), "main.py:10:5\n"); -} - -TEST_F(HloModuleUtilsTest, TestGetSourceInfo) { - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module_with_stack_frames, - GetModuleWithStackFrames()); - auto root_instruction = - module_with_stack_frames->entry_computation()->root_instruction(); - auto source_info = GetSourceInfo(root_instruction); - EXPECT_EQ(source_info.source_file, "main.py"); - EXPECT_EQ(source_info.source_line, 10); - EXPECT_EQ(source_info.stack_frame, "main.py:10:5\n"); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.cc b/tensorflow/core/profiler/utils/hlo_proto_map.cc deleted file mode 100644 index 50d96c49980e74..00000000000000 --- a/tensorflow/core/profiler/utils/hlo_proto_map.cc +++ /dev/null @@ -1,172 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" - -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { -namespace { - -int NumHeapSimulatorTraceEvents(const xla::HloProto* hlo) { - int result = 0; - for (const auto& trace : hlo->buffer_assignment().heap_simulator_traces()) { - result += trace.events_size(); - } - return result; -} - -} // namespace - -absl::flat_hash_map> -ParseHloProtosFromXSpace(const XSpace& space) { - absl::flat_hash_map> hlo_protos; - std::vector planes = - FindPlanesWithNames(space, {kMetadataPlaneName}); - for (const XPlane* raw_plane : planes) { - if (raw_plane != nullptr) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(raw_plane); - - const XStatMetadata* hlo_proto_stat_metadata = - plane.GetStatMetadataByType(StatType::kHloProto); - if (hlo_proto_stat_metadata != nullptr) { - plane.ForEachEventMetadata( - [&](const XEventMetadataVisitor& event_metadata) { - auto hlo_proto_stat = event_metadata.GetStat( - StatType::kHloProto, *hlo_proto_stat_metadata); - if (!hlo_proto_stat) return; - if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return; - auto hlo_proto = std::make_unique(); - absl::string_view byte_value = hlo_proto_stat->BytesValue(); - if (hlo_proto->ParseFromArray(byte_value.data(), - byte_value.size())) { - if (!hlo_protos - .try_emplace(event_metadata.Id(), std::move(hlo_proto)) - .second) { - LOG(WARNING) << "Insert failed for hlo_proto with program_id" - << event_metadata.Id(); - } - } - }); - } - } - } - return hlo_protos; -} - -bool HloProtoMap::AddHloProto(uint64_t program_id, - const xla::HloProto* hlo_proto) { - bool new_program_id = - hlo_protos_by_program_id_.try_emplace(program_id, hlo_proto).second; - absl::string_view hlo_module_name = hlo_proto->hlo_module().name(); - bool new_module_name = - hlo_protos_by_name_ - .try_emplace(tsl::profiler::HloModuleNameWithProgramId( - hlo_module_name, program_id), - hlo_proto) - .second; - return new_program_id || new_module_name; -} - -void HloProtoMap::AddHloProto(uint64_t program_id, - std::unique_ptr hlo_proto) { - if (AddHloProto(program_id, hlo_proto.get())) { - // Only add to if is new to HloProtoMap. 
- owned_hlo_protos_.push_back(std::move(hlo_proto)); - } -} - -void HloProtoMap::AddHloProtosFromXSpace(const XSpace& space) { - for (auto& [program_id, hlo_proto] : ParseHloProtosFromXSpace(space)) { - AddHloProto(program_id, std::move(hlo_proto)); - } -} - -std::vector HloProtoMap::GetModuleList() const { - std::vector module_list; - module_list.reserve(hlo_protos_by_name_.size()); - for (const auto& [name, hlo_proto] : hlo_protos_by_name_) { - module_list.push_back(name); - } - return module_list; -} - -std::vector HloProtoMap::GetSortedModuleList() const { - std::vector module_list = GetModuleList(); - absl::c_sort(module_list); - return module_list; -} - -std::vector HloProtoMap::GetSortedModuleListByHeapTraceSize() - const { - std::vector> hlo_protos( - hlo_protos_by_name_.begin(), hlo_protos_by_name_.end()); - - // Sort the hlo protos by heap trace size and then by hlo module name. - // This way trivial computations will be on the bottom of the list. - absl::c_stable_sort(hlo_protos, [](const auto& a, const auto& b) { - int num_a = tensorflow::profiler::NumHeapSimulatorTraceEvents(a.second); - int num_b = tensorflow::profiler::NumHeapSimulatorTraceEvents(b.second); - return std::tie(num_a, b.first) > std::tie(num_b, a.first); - }); - - std::vector module_list; - module_list.reserve(hlo_protos.size()); - for (const auto& [name, hlo_proto] : hlo_protos) { - module_list.push_back(name); - } - return module_list; -} - -absl::StatusOr HloProtoMap::GetHloProtoByProgramId( - uint64_t program_id) const { - auto iter = hlo_protos_by_program_id_.find(program_id); - if (iter != hlo_protos_by_program_id_.end()) { - return iter->second; - } - return absl::NotFoundError( - absl::StrCat("Program id: ", program_id, " is not found.")); -} - -absl::StatusOr HloProtoMap::GetHloProtoByModuleName( - absl::string_view module_name) const { - auto iter = hlo_protos_by_name_.find(module_name); - if (iter != hlo_protos_by_name_.end()) { - return iter->second; - } - return absl::NotFoundError( - absl::StrCat("Module name: ", module_name, " is not found.")); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.h b/tensorflow/core/profiler/utils/hlo_proto_map.h index 383c3064bc85de..23259adffaedab 100644 --- a/tensorflow/core/profiler/utils/hlo_proto_map.h +++ b/tensorflow/core/profiler/utils/hlo_proto_map.h @@ -16,71 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_MAP_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_MAP_H_ -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -absl::flat_hash_map> -ParseHloProtosFromXSpace(const XSpace& space); - -class HloProtoMap { - public: - void AddHloProtosFromXSpace(const XSpace& space); - - void AddHloProto(uint64_t program_id, - std::unique_ptr hlo_proto); - - size_t size() const { return hlo_protos_by_program_id_.size(); } - - auto begin() const { return hlo_protos_by_program_id_.begin(); } - auto end() const { return hlo_protos_by_program_id_.end(); } - - bool contains(absl::string_view name) const { - return hlo_protos_by_name_.contains(name); - } - - bool contains(uint64_t program_id) const { - return hlo_protos_by_program_id_.contains(program_id); - } - - // Returns a list of module names (not sorted). 
-  std::vector<std::string> GetModuleList() const;
-
-  // Returns a list of module names sorted alphabetically.
-  std::vector<std::string> GetSortedModuleList() const;
-
-  // Returns a list of hlo module names sorted first by heap trace size and then
-  // by hlo module name alphabetically.
-  std::vector<std::string> GetSortedModuleListByHeapTraceSize() const;
-
-  absl::StatusOr<const xla::HloProto*> GetHloProtoByModuleName(
-      absl::string_view module_name) const;
-
-  absl::StatusOr<const xla::HloProto*> GetHloProtoByProgramId(
-      uint64_t program_id) const;
-
- private:
-  absl::flat_hash_map<uint64_t, const xla::HloProto*> hlo_protos_by_program_id_;
-  absl::flat_hash_map<std::string, const xla::HloProto*> hlo_protos_by_name_;
-  std::vector<std::unique_ptr<xla::HloProto>> owned_hlo_protos_;
-
-  // Tries to add the proto to the map and returns true if the addition is
-  // successful (i.e., the proto is new to the map).
-  bool AddHloProto(uint64_t program_id, const xla::HloProto* hlo_proto);
-};
-
-}  // namespace profiler
-}  // namespace tensorflow
+#include "xprof/utils/hlo_proto_map.h"  // from @org_xprof  // IWYU pragma: export

 #endif  // TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_MAP_H_
diff --git a/tensorflow/core/profiler/utils/hlo_proto_to_module.cc b/tensorflow/core/profiler/utils/hlo_proto_to_module.cc
deleted file mode 100644
index 4083bbfe8bbe49..00000000000000
--- a/tensorflow/core/profiler/utils/hlo_proto_to_module.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/utils/hlo_proto_to_module.h"
-
-#include <memory>
-#include <utility>
-
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo.pb.h"
-#include "xla/tsl/platform/statusor.h"
-#include "xla/util.h"
-
-namespace tensorflow {
-namespace profiler {
-
-absl::StatusOr<std::unique_ptr<xla::HloModule>> ConvertHloProtoToModule(
-    const xla::HloProto& hlo_proto) {
-  if (!hlo_proto.has_hlo_module()) {
-    return xla::Internal("No HLO module found in the HLO proto");
-  }
-  const xla::HloModuleProto& module_proto = hlo_proto.hlo_module();
-  TF_ASSIGN_OR_RETURN(auto config, xla::HloModule::CreateModuleConfigFromProto(
-                                       module_proto, xla::DebugOptions()));
-  TF_ASSIGN_OR_RETURN(auto module,
-                      xla::HloModule::CreateFromProto(module_proto, config));
-  return module;
-}
-
-std::unique_ptr<xla::HloModule> ConvertHloProtoToModuleIgnoringErrors(
-    const xla::HloProto& hlo_proto) {
-  auto module = ConvertHloProtoToModule(hlo_proto);
-  if (!module.ok()) {
-    LOG(ERROR) << module.status();
-    return nullptr;
-  }
-  return std::move(module).value();
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/hlo_proto_to_module.h b/tensorflow/core/profiler/utils/hlo_proto_to_module.h
index 4cf3fa6383367d..954ed71345c9bd 100644
--- a/tensorflow/core/profiler/utils/hlo_proto_to_module.h
+++ b/tensorflow/core/profiler/utils/hlo_proto_to_module.h
@@ -16,22 +16,6 @@ limitations under the License.
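A small sketch (editorial) of the conversion helpers whose declarations follow; `proto` is an assumed, already-deserialized xla::HloProto:

void LogEntryComputation(const xla::HloProto& proto) {
  std::unique_ptr<xla::HloModule> module =
      tensorflow::profiler::ConvertHloProtoToModuleIgnoringErrors(proto);
  if (module != nullptr) {  // Conversion errors were logged and swallowed.
    LOG(INFO) << "Entry computation: " << module->entry_computation()->name();
  }
}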
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_TO_MODULE_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_TO_MODULE_H_ -#include - -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo.pb.h" - -namespace tensorflow { -namespace profiler { - -absl::StatusOr> ConvertHloProtoToModule( - const xla::HloProto& hlo_proto); - -std::unique_ptr ConvertHloProtoToModuleIgnoringErrors( - const xla::HloProto& hlo_proto); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/hlo_proto_to_module.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_TO_MODULE_H_ diff --git a/tensorflow/core/profiler/utils/host_offload_utils.cc b/tensorflow/core/profiler/utils/host_offload_utils.cc deleted file mode 100644 index 7f135985d0b1c6..00000000000000 --- a/tensorflow/core/profiler/utils/host_offload_utils.cc +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/utils/host_offload_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -bool HostOffloadEventProcessor::IsHostOffloadOpName( - const XEventVisitor& event) const { - static constexpr absl::string_view keywords[] = {"copy-start", - "copy-done", - "dynamic-slice-start", - "dynamic-slice-done", - "dynamic-update-slice-start", - "dynamic-update-slice-done"}; - - for (const auto& keyword : keywords) { - // The host_memory_label_ S(5) is used by instructions to designate tensors - // that are on the host. - if (absl::StrContains(event.DisplayName(), keyword) && - absl::StrContains(event.Name(), host_memory_label_)) { - return true; - } - } - return false; -} - -std::string HostOffloadEventProcessor::GetOffloadInstructionID( - absl::string_view op_name) const { - std::vector op_name_vec = absl::StrSplit(op_name, '.'); - - // If no dot is found, or it's at the beginning or end of the string, return - // a 0. Hlo opnames are not expected to have a dot followed by 0. - if (op_name_vec.size() < 2) { - return "0"; - } - return op_name_vec.back(); -} - -std::string HostOffloadEventProcessor::GetOffloadInstructionName( - absl::string_view op_name) const { - // TODO(b/342469268): Get the display ID and name from the HloInstruction, not - // just the event name. 
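-  // Illustrative example (editorial): op_name "dynamic-slice-start.2" yields
-  // display_id "2" and display_opname "dynamic-slice", producing
-  // "offload-dynamic-slice.2".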
- std::string display_id = GetOffloadInstructionID(op_name); - - size_t startPos = op_name.find("-start"); - size_t donePos = op_name.find("-done"); - - absl::string_view display_opname; - if (startPos != absl::string_view::npos) { - display_opname = op_name.substr(0, startPos); - } else if (donePos != absl::string_view::npos) { - display_opname = op_name.substr(0, donePos); - } else { - // Invalid input format: neither "-start" nor "-done" found - LOG(WARNING) << "Invalid op name: " << op_name; - display_opname = op_name; - } - return absl::StrCat("offload-", display_opname, ".", display_id); -} - -void HostOffloadEventProcessor::ProcessHostOffloadOpEvent( - const XEventVisitor& event, std::optional group_id) { - std::string display_opname = GetOffloadInstructionName(event.DisplayName()); - - auto [iter, inserted] = seen_events_.try_emplace(display_opname); - std::queue& events = iter->second; - - if (absl::StrContains(event.DisplayName(), "-start")) { - // For start events, just push them into the queue. - events.push(&event); - return; - } else if (absl::StrContains(event.DisplayName(), "-done")) { - // for done events, pop the start event and create the new event. - // Not all start events may be traced. In this case we just skip the - // corresponding done event. - if (events.empty()) { - LOG(INFO) << "No corresponding start event found for " - << event.DisplayName(); - return; - } - const XEventVisitor* start_event = events.front(); - events.pop(); - - // At this point, we have the corresponding start and end event. - // Create the new event. - tsl::profiler::Timespan event_span = tsl::profiler::Timespan::FromEndPoints( - start_event->GetTimespan().begin_ps(), event.GetTimespan().end_ps()); - - // Find the line with the smallest event end time frontier that can fit this - // new event without overlapping with its other events. - int line_builder_index = -1; - uint64_t minimum_end_time_frontier = event_span.begin_ps(); - for (int i = 0; i < host_offload_op_line_builders_.size(); ++i) { - if (host_offload_op_line_builders_[i].event_end_time_frontier_ns <= - minimum_end_time_frontier) { - line_builder_index = i; - minimum_end_time_frontier = - host_offload_op_line_builders_[i].event_end_time_frontier_ns; - } - } - - constexpr int kMaxHostOffloadOpLinesSize = - kThreadIdHostOffloadOpEnd - kThreadIdHostOffloadOpStart + 1; - - // If no existing lines can fit this new event, create a new line. - if (line_builder_index == -1) { - if (host_offload_op_line_builders_.size() < kMaxHostOffloadOpLinesSize) { - XLineBuilder lb = plane_builder_->GetOrCreateLine( - kThreadIdHostOffloadOpStart + - host_offload_op_line_builders_.size()); - lb.SetName(absl::StrFormat("%s row %d", kHostOffloadOpLineName, - host_offload_op_line_builders_.size())); - lb.SetTimestampNs(start_timestamp_ns_); - host_offload_op_line_builders_.push_back( - {std::move(lb), event_span.end_ps()}); - } - // If we have reached the maximum number of lines, just use the last line. - line_builder_index = host_offload_op_line_builders_.size() - 1; - } - - // Update the event end time frontier for the line. 
- host_offload_op_line_builders_[line_builder_index] - .event_end_time_frontier_ns = - std::max(host_offload_op_line_builders_[line_builder_index] - .event_end_time_frontier_ns, - event_span.end_ps()); - - XEventMetadata* host_offload_copy_metadata = - plane_builder_->CreateEventMetadata(); - host_offload_copy_metadata->set_display_name(display_opname); - XEventBuilder event_builder = - host_offload_op_line_builders_[line_builder_index] - .line_builder.AddEvent(*host_offload_copy_metadata); - event_builder.SetTimespan(event_span); - - // We mark the events as async so that they are displayed on new sub-lines - // below other async events. - const XStatMetadata& async_stat = *plane_builder_->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kIsAsync)); - event_builder.AddStatValue(async_stat, 1); - - // Set metadata stats for the event. - const XStatMetadata& raw_bytes_stat = - *plane_builder_->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kRawBytesAccessed)); - event.Metadata().ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kRawBytesAccessed) { - event_builder.AddStatValue(raw_bytes_stat, stat.IntValue()); - } - }); - const XStatMetadata& shape_with_layout_str = - *plane_builder_->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kShapeWithLayout)); - // Use the shape from start_event, since it contains the shape of end event. - start_event->Metadata().ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kShapeWithLayout) { - event_builder.AddStatValue(shape_with_layout_str, stat.StrOrRefValue()); - } - }); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/host_offload_utils.h b/tensorflow/core/profiler/utils/host_offload_utils.h deleted file mode 100644 index dbf308fbfe1e41..00000000000000 --- a/tensorflow/core/profiler/utils/host_offload_utils.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HOST_OFFLOAD_UTILS_H_ -#define TENSORFLOW_CORE_PROFILER_UTILS_HOST_OFFLOAD_UTILS_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/layout.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -struct LineBuilderAndEventEndTimeFrontier { - XLineBuilder line_builder; - uint64_t event_end_time_frontier_ns; -}; - -class HostOffloadEventProcessor { - public: - HostOffloadEventProcessor(XPlaneBuilder* plane_builder, - uint64_t start_timestamp_ns) - : plane_builder_(plane_builder), - start_timestamp_ns_(start_timestamp_ns) {} - ~HostOffloadEventProcessor() = default; - - void ProcessHostOffloadOpEvent(const XEventVisitor& event, - std::optional group_id); - - bool IsHostOffloadOpName(const XEventVisitor& event) const; - - private: - std::string GetOffloadInstructionID(absl::string_view op_name) const; - std::string GetOffloadInstructionName(absl::string_view op_name) const; - - absl::flat_hash_map> - seen_events_; - std::string host_memory_label_ = - absl::StrCat("S(", xla::Layout::kHostMemorySpace, ")"); - - XPlaneBuilder* plane_builder_; - uint64_t start_timestamp_ns_; - - std::vector - host_offload_op_line_builders_; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_UTILS_HOST_OFFLOAD_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/html_utils.h b/tensorflow/core/profiler/utils/html_utils.h index 215d9f51d5bec2..9dbf42507b4321 100644 --- a/tensorflow/core/profiler/utils/html_utils.h +++ b/tensorflow/core/profiler/utils/html_utils.h @@ -16,21 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_ -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" - -namespace tensorflow { -namespace profiler { - -// Creates a html that links to the given url with the given text. -inline std::string AnchorElement(absl::string_view url, - absl::string_view text) { - return absl::StrCat("", text, ""); -} - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/html_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc deleted file mode 100644 index be88b216465220..00000000000000 --- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc +++ /dev/null @@ -1,352 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
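A condensed sketch (editorial) of the row-packing policy implemented by ProcessHostOffloadOpEvent above: among the lines whose end-time frontier is at or before the new event's begin timestamp, pick the one with the earliest frontier; the real code additionally caps the line count and reuses the last line when full.

// Returns the index of the best reusable line, or -1 if none can fit.
int PickLine(const std::vector<uint64_t>& end_frontiers, uint64_t begin_ps) {
  int best = -1;
  uint64_t best_frontier = begin_ps;  // Only frontiers <= begin_ps qualify.
  for (int i = 0; i < static_cast<int>(end_frontiers.size()); ++i) {
    if (end_frontiers[i] <= best_frontier) {
      best = i;
      best_frontier = end_frontiers[i];
    }
  }
  return best;
}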
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" - -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -// The maximum number of Kernels displayed on Kernel Stats page. -const int kMaxNumOfKernels = 1000; - -// A list of patterns to help determine if a kernel uses Tensor Core. -// A kernel uses Tensor Core if its kernel name contains any of these patterns. -// Some examples of kernel names: volta_h884gemm, turing_fp16_s1688cudnn_fp16 -constexpr absl::string_view kTensorCoreKernelNamePatterns[] = { - "16816", - "c1688", - "conv1x1", - "conv2d_c1_k1", - "dgrad_1x1_stride_2x2", - "direct_group", - "first_layer_wgrad_kernel", - "h1688", - "h884", - "hmma", - "i16832", - "i8816", - "s884", - "s1688", - "xmma_gemm", - "xmma_implicit_gemm", - "xmma_sparse_conv", - "xmma_sparse_gemm", - "xmma_warp_specialized_implicit_gemm"}; - -} // namespace - -void ParseKernelLaunchParams(absl::string_view xstat_kernel_details, - KernelReport* kernel) { - const std::vector params = - absl::StrSplit(xstat_kernel_details, absl::ByAnyChar(" \n")); - - constexpr uint32 kNumDimensions = 3; - for (uint32 dim = 0; dim < kNumDimensions; ++dim) { - kernel->add_block_dim(1); - kernel->add_grid_dim(1); - } - - // Process tokens. - for (const auto& param : params) { - const std::vector key_value = absl::StrSplit(param, ':'); - if (key_value.size() != 2) { - // Unrecognized token. - continue; - } - absl::string_view key = key_value[0]; - absl::string_view value_str = key_value[1]; - uint32 value = 0; - double pct = 0.0; - // Cases that consume a pair of tokens "key:value". - if (key == "regs" && absl::SimpleAtoi(value_str, &value)) { - kernel->set_registers_per_thread(value); - } else if (key == "static_shared" && absl::SimpleAtoi(value_str, &value)) { - kernel->set_static_shmem_bytes(value); - } else if (key == "dynamic_shared" && absl::SimpleAtoi(value_str, &value)) { - kernel->set_dynamic_shmem_bytes(value); - } else if (key == "block") { - const std::vector& block = - absl::StrSplit(value_str, ','); - uint32 tmp[3]; - if (block.size() == 3 && absl::SimpleAtoi(block[0], &tmp[0]) && - absl::SimpleAtoi(block[1], &tmp[1]) && - absl::SimpleAtoi(block[2], &tmp[2])) { - std::copy_n(tmp, 3, kernel->mutable_block_dim()->begin()); - } - } else if (key == "grid") { - const std::vector& grid = - absl::StrSplit(value_str, ','); - uint32 tmp[3]; - if (grid.size() == 3 && absl::SimpleAtoi(grid[0], &tmp[0]) && - absl::SimpleAtoi(grid[1], &tmp[1]) && - absl::SimpleAtoi(grid[2], &tmp[2])) { - std::copy_n(tmp, 3, kernel->mutable_grid_dim()->begin()); - } - } else if (key == "occ_pct" && absl::SimpleAtod(value_str, &pct)) { - kernel->set_occupancy_pct(pct); - } - } -} - -bool IsKernelUsingTensorCore(absl::string_view kernel_name) { - VLOG(1) << "kernel name: " << kernel_name; - for (absl::string_view pattern : kTensorCoreKernelNamePatterns) { - if (absl::StrContains(kernel_name, pattern)) { - return true; - } - } - return false; -} - -// This list is not exhaustive. 
-bool IsOpTensorCoreEligible(absl::string_view tf_op_name) { - // Disable formatting to keep inline comments vertically aligned. - // clang-format off - return false - // Using EndsWith to match Fused operations. - || absl::EndsWith(tf_op_name, "Conv2D") - || absl::EndsWith(tf_op_name, "Conv2DBackpropFilter") - || absl::EndsWith(tf_op_name, "Conv2DBackpropInput") - || absl::EndsWith(tf_op_name, "Conv3D") - || absl::EndsWith(tf_op_name, "DepthwiseConv2dNative") - || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropFilter") - || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropInput") - // Using Contains to match V2/V3 suffixes. - || absl::StrContains(tf_op_name, "BatchMatMul") - // MatMul requires exact matching. - || absl::EndsWith(tf_op_name, "/MatMul") - || absl::EndsWith(tf_op_name, "FusedMatMul") - // cuDNN operations. - || absl::EndsWith(tf_op_name, "/CudnnRNN") - || absl::StrContains(tf_op_name, "CudnnRNNV") - || absl::StrContains(tf_op_name, "CudnnRNNForward") - || absl::StrContains(tf_op_name, "CudnnRNNBackprop") - // Special cases. - || absl::EndsWith(tf_op_name, "XlaDot") - || absl::EndsWith(tf_op_name, "XlaDotV2"); - // clang-format on -} - -bool IsEinsumTensorCoreEligible(absl::string_view equation) { - if (equation.empty()) { - return false; - } - const std::vector input_output = - absl::StrSplit(equation, "->"); - if (input_output.size() != 2) { - return false; - } - const std::vector lhs_rhs = - absl::StrSplit(input_output[0], ','); - return lhs_rhs.size() == 2; -} - -bool KernelReportLessThanComparator::operator()(const KernelReport& lhs, - const KernelReport& rhs) const { - // Disable formatting to keep vertical alignment for better readability, - // and make it easier to reorder columns. - // clang-format off - auto lhs_tuple = std::make_tuple( - lhs.name(), - lhs.grid_dim(0), - lhs.grid_dim(1), - lhs.grid_dim(2), - lhs.block_dim(0), - lhs.block_dim(1), - lhs.block_dim(2), - lhs.registers_per_thread(), - lhs.static_shmem_bytes(), - lhs.dynamic_shmem_bytes(), - lhs.is_kernel_using_tensor_core(), - lhs.is_op_tensor_core_eligible(), - lhs.op_name()); - - auto rhs_tuple = std::make_tuple( - rhs.name(), - rhs.grid_dim(0), - rhs.grid_dim(1), - rhs.grid_dim(2), - rhs.block_dim(0), - rhs.block_dim(1), - rhs.block_dim(2), - rhs.registers_per_thread(), - rhs.static_shmem_bytes(), - rhs.dynamic_shmem_bytes(), - rhs.is_kernel_using_tensor_core(), - rhs.is_op_tensor_core_eligible(), - rhs.op_name()); - // clang-format on - return lhs_tuple < rhs_tuple; -} - -bool KernelReportEqualToComparator::operator()(const KernelReport& lhs, - const KernelReport& rhs) const { - // Disable formatting to keep vertical alignment for better readability, - // and make it easier to reorder columns. - // clang-format off - // Put the most expensive string comparisons last. 
-  return (
-      lhs.is_kernel_using_tensor_core() == rhs.is_kernel_using_tensor_core() &&
-      lhs.is_op_tensor_core_eligible() == rhs.is_op_tensor_core_eligible() &&
-      lhs.block_dim(0) == rhs.block_dim(0) &&
-      lhs.block_dim(1) == rhs.block_dim(1) &&
-      lhs.block_dim(2) == rhs.block_dim(2) &&
-      lhs.grid_dim(0) == rhs.grid_dim(0) &&
-      lhs.grid_dim(1) == rhs.grid_dim(1) &&
-      lhs.grid_dim(2) == rhs.grid_dim(2) &&
-      lhs.registers_per_thread() == rhs.registers_per_thread() &&
-      lhs.static_shmem_bytes() == rhs.static_shmem_bytes() &&
-      lhs.dynamic_shmem_bytes() == rhs.dynamic_shmem_bytes() &&
-      lhs.name() == rhs.name() &&
-      lhs.op_name() == rhs.op_name());
-  // clang-format on
-}
-
-void SortAndKeepTopKDurationKernelReportsInDb(KernelStatsDb* kernel_stats_db) {
-  auto comp = [](const KernelReport& lhs, const KernelReport& rhs) {
-    return lhs.total_duration_ns() > rhs.total_duration_ns() ||
-           (lhs.total_duration_ns() == rhs.total_duration_ns() &&
-            KernelReportLessThanComparator()(lhs, rhs));
-  };
-
-  // Sort and keep at most <kMaxNumOfKernels> kernel reports.
-  if (kernel_stats_db->reports_size() > kMaxNumOfKernels) {
-    std::partial_sort(
-        kernel_stats_db->mutable_reports()->begin(),
-        kernel_stats_db->mutable_reports()->begin() + kMaxNumOfKernels,
-        kernel_stats_db->mutable_reports()->end(), comp);
-    kernel_stats_db->mutable_reports()->erase(
-        kernel_stats_db->mutable_reports()->begin() + kMaxNumOfKernels,
-        kernel_stats_db->mutable_reports()->end());
-  } else {
-    std::sort(kernel_stats_db->mutable_reports()->begin(),
-              kernel_stats_db->mutable_reports()->end(), comp);
-  }
-}
-
-void CopyTopKDurationKernelReportsToDb(const KernelReportMap& reports,
-                                       KernelStatsDb* dst) {
-  std::vector<std::pair<const KernelReport*, const KernelReportValue*>>
-      kernels_to_sort;
-  kernels_to_sort.reserve(reports.size());
-  for (const auto& report_value : reports) {
-    kernels_to_sort.push_back(
-        std::make_pair(&report_value.first, &report_value.second));
-  }
-
-  auto comp =
-      [](const std::pair<const KernelReport*, const KernelReportValue*>& lhs,
-         const std::pair<const KernelReport*, const KernelReportValue*>& rhs) {
-        return lhs.second->total_duration_ns > rhs.second->total_duration_ns ||
-               (lhs.second->total_duration_ns ==
-                    rhs.second->total_duration_ns &&
-                KernelReportLessThanComparator()(*lhs.first, *rhs.first));
-      };
-
-  // Sort and copy at most <kMaxNumOfKernels> kernels to <dst>.
-  if (kernels_to_sort.size() > kMaxNumOfKernels) {
-    absl::c_partial_sort(kernels_to_sort,
-                         kernels_to_sort.begin() + kMaxNumOfKernels, comp);
-  } else {
-    absl::c_sort(kernels_to_sort, comp);
-  }
-
-  int copy_size =
-      std::min(kMaxNumOfKernels, static_cast<int>(kernels_to_sort.size()));
-  for (int i = 0; i < copy_size; i++) {
-    KernelReport* report = dst->add_reports();
-    *report = *kernels_to_sort[i].first;
-    const KernelReportValue& kernel_value = *kernels_to_sort[i].second;
-    // Set value using KernelReportValue.
-    report->set_occurrences(kernel_value.occurrences);
-    report->set_min_duration_ns(kernel_value.min_duration_ns);
-    report->set_max_duration_ns(kernel_value.max_duration_ns);
-    report->set_total_duration_ns(kernel_value.total_duration_ns);
-  }
-}
-
-void InsertOrUpdateKernelReport(const KernelReport& kernel,
-                                const KernelReportValue& value,
-                                KernelReportMap* dst) {
-  KernelReportValue& element = (*dst)[kernel];
-  if (element.occurrences == 0) {
-    element = value;
-  } else {
-    element.total_duration_ns += value.total_duration_ns;
-    element.min_duration_ns =
-        std::min(element.min_duration_ns, value.min_duration_ns);
-    element.max_duration_ns =
-        std::max(element.max_duration_ns, value.max_duration_ns);
-    element.occurrences += value.occurrences;
-  }
-}
-
-void MergeKernelReports(const KernelReportMap& reports, KernelReportMap* dst) {
-  for (auto& kernel_value : reports) {
-    InsertOrUpdateKernelReport(kernel_value.first, kernel_value.second, dst);
-  }
-}
-
-KernelStatsByOpName GroupKernelReportsByOpName(
-    const KernelStatsDb& kernel_stats_db) {
-  KernelStatsByOpName op_level_kernel_stats;
-  for (const KernelReport& kernel_report : kernel_stats_db.reports()) {
-    auto ret = op_level_kernel_stats.emplace(kernel_report.op_name(),
-                                             OpLevelKernelStats());
-    if (ret.second) {
-      // Inserted. Add a new op in <op_level_kernel_stats>.
-      OpLevelKernelStats& stats = ret.first->second;
-      stats.is_op_tensor_core_eligible =
-          kernel_report.is_op_tensor_core_eligible();
-      stats.total_duration_ns += kernel_report.total_duration_ns();
-      if (kernel_report.is_kernel_using_tensor_core()) {
-        stats.tensor_core_duration_ns += kernel_report.total_duration_ns();
-      }
-    } else {
-      // Not inserted. Aggregate kernel stats to op level.
-      OpLevelKernelStats& stats = ret.first->second;
-      // Verifies operations with the same name have the same TensorCore
-      // eligibility.
-      DCHECK_EQ(stats.is_op_tensor_core_eligible,
-                kernel_report.is_op_tensor_core_eligible());
-      stats.total_duration_ns += kernel_report.total_duration_ns();
-      if (kernel_report.is_kernel_using_tensor_core()) {
-        stats.tensor_core_duration_ns += kernel_report.total_duration_ns();
-      }
-    }
-  }
-  return op_level_kernel_stats;
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.h b/tensorflow/core/profiler/utils/kernel_stats_utils.h
index 1afecd6d54b1f0..6e625d9835e91f 100644
--- a/tensorflow/core/profiler/utils/kernel_stats_utils.h
+++ b/tensorflow/core/profiler/utils/kernel_stats_utils.h
@@ -16,121 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_KERNEL_STATS_UTILS_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_KERNEL_STATS_UTILS_H_
 
-#include <string>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/hash/hash.h"
-#include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-
-// Populates kernel launch information from a kKernelDetails XStat.
-void ParseKernelLaunchParams(absl::string_view xstat_kernel_details,
-                             KernelReport* kernel);
-
-// Returns true if kernel uses TensorCores.
-bool IsKernelUsingTensorCore(absl::string_view kernel_name);
-
-// Returns true if operation is eligible to use TensorCores.
-bool IsOpTensorCoreEligible(absl::string_view tf_op_name);
-
-// Returns true if Einsum equation is eligible to use TensorCores.
-bool IsEinsumTensorCoreEligible(absl::string_view equation);
-
-// Less than comparator for Kernel Reports.
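-// Orders reports by the tuple of all grouping fields, giving a deterministic
-// total order for sorting.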
-struct KernelReportLessThanComparator {
-  bool operator()(const KernelReport& lhs, const KernelReport& rhs) const;
-};
-
-// Equal to comparator for Kernel Reports.
-struct KernelReportEqualToComparator {
-  bool operator()(const KernelReport& lhs, const KernelReport& rhs) const;
-};
-
-// Sorts kernel reports in descending order of total duration.
-// Keeps only the top kernel reports with long kernel duration in the given
-// KernelStatsDb. Kernel reports with shorter kernel duration are dropped.
-void SortAndKeepTopKDurationKernelReportsInDb(KernelStatsDb* kernel_stats_db);
-
-struct KernelReportValue {
-  uint64 total_duration_ns = 0;
-  uint64 min_duration_ns = 0;
-  uint64 max_duration_ns = 0;
-  uint64 occurrences = 0;
-};
-
-struct KernelKeyWrap {
-  const KernelReport* key;
-  template <typename H>
-  friend H AbslHashValue(H h, KernelKeyWrap wrap) {
-    // Kernel reports are grouped by these fields, hence they are used as
-    // hashing criteria.
-    // clang-format off
-    return H::combine(
-        std::move(h),
-        wrap.key->is_kernel_using_tensor_core(),
-        wrap.key->is_op_tensor_core_eligible(),
-        wrap.key->block_dim(0),
-        wrap.key->block_dim(1),
-        wrap.key->block_dim(2),
-        wrap.key->grid_dim(0),
-        wrap.key->grid_dim(1),
-        wrap.key->grid_dim(2),
-        wrap.key->registers_per_thread(),
-        wrap.key->static_shmem_bytes(),
-        wrap.key->dynamic_shmem_bytes(),
-        wrap.key->name(),
-        wrap.key->op_name());
-    // clang-format on
-  }
-};
-
-struct KernelHash {
-  size_t operator()(const KernelReport& key) const {
-    return absl::Hash<KernelKeyWrap>()(KernelKeyWrap{&key});
-  }
-};
-
-using KernelReportMap =
-    absl::flat_hash_map<KernelReport, KernelReportValue, KernelHash,
-                        KernelReportEqualToComparator>;
-
-// Copies the top kernel reports with long kernel duration into the given
-// KernelStatsDb.
-void CopyTopKDurationKernelReportsToDb(const KernelReportMap& reports,
-                                       KernelStatsDb* dst);
-
-// Inserts or aggregates KernelReports into the given KernelReportMap.
-void InsertOrUpdateKernelReport(const KernelReport& kernel,
-                                const KernelReportValue& value,
-                                KernelReportMap* dst);
-
-// Aggregates values from one KernelReportMap into another.
-void MergeKernelReports(const KernelReportMap& reports, KernelReportMap* dst);
-
-// Kernel stats aggregated at TF operation level.
-struct OpLevelKernelStats {
-  // Whether op is eligible to use TensorCore.
-  bool is_op_tensor_core_eligible = false;
-  // The accumulated duration of all the kernels launched in this op.
-  uint64 total_duration_ns = 0;
-  // The accumulated duration of all the kernels using TensorCore in this op.
-  // If this value is not 0, at least one of the kernels launched by this op
-  // is using TensorCore.
-  uint64 tensor_core_duration_ns = 0;
-};
-
-using KernelStatsByOpName =
-    absl::flat_hash_map<std::string, OpLevelKernelStats>;
-
-// Groups KernelReport in <kernel_stats_db> by tensorflow operation name.
-KernelStatsByOpName GroupKernelReportsByOpName(
-    const KernelStatsDb& kernel_stats_db);
-
-}  // namespace profiler
-}  // namespace tensorflow
+#include "xprof/utils/kernel_stats_utils.h"  // from @org_xprof  // IWYU pragma: export
 
 #endif  // TENSORFLOW_CORE_PROFILER_UTILS_KERNEL_STATS_UTILS_H_
diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc b/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc
deleted file mode 100644
index a8cf90adf62a9b..00000000000000
--- a/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc
+++ /dev/null
@@ -1,175 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/utils/kernel_stats_utils.h"
-
-#include <string>
-
-#include <gmock/gmock.h>
-#include "absl/strings/string_view.h"
-#include "xla/backends/profiler/gpu/cupti_buffer_events.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-using ::testing::FieldsAre;
-
-TEST(KernelStatsUtilsTest, TestGroupKernelReportsByOpName) {
-  KernelStatsDb kernel_stats_db;
-  KernelReport* kernel_report_1 = kernel_stats_db.add_reports();
-  kernel_report_1->set_name("op1_kernel1");
-  kernel_report_1->set_op_name("op1");
-  kernel_report_1->set_total_duration_ns(1000);
-  kernel_report_1->set_is_kernel_using_tensor_core(true);
-  kernel_report_1->set_is_op_tensor_core_eligible(true);
-
-  KernelReport* kernel_report_2 = kernel_stats_db.add_reports();
-  kernel_report_2->set_name("op1_kernel2");
-  kernel_report_2->set_op_name("op1");
-  kernel_report_2->set_total_duration_ns(1000);
-  kernel_report_2->set_is_kernel_using_tensor_core(false);
-  kernel_report_2->set_is_op_tensor_core_eligible(true);
-
-  KernelReport* kernel_report_3 = kernel_stats_db.add_reports();
-  kernel_report_3->set_name("op2_kernel1");
-  kernel_report_3->set_op_name("op2");
-  kernel_report_3->set_total_duration_ns(100);
-  kernel_report_3->set_is_kernel_using_tensor_core(false);
-  kernel_report_3->set_is_op_tensor_core_eligible(false);
-
-  KernelStatsByOpName kernel_stats_by_op_name =
-      GroupKernelReportsByOpName(kernel_stats_db);
-
-  // Verifies there are two OpLevelKernelStats.
-  ASSERT_EQ(kernel_stats_by_op_name.size(), 2);
-  auto iter1 = kernel_stats_by_op_name.find("op1");
-  auto iter2 = kernel_stats_by_op_name.find("op2");
-  ASSERT_NE(iter1, kernel_stats_by_op_name.end());
-  ASSERT_NE(iter2, kernel_stats_by_op_name.end());
-  const OpLevelKernelStats& op1_stats = iter1->second;
-  const OpLevelKernelStats& op2_stats = iter2->second;
-
-  EXPECT_EQ(op1_stats.is_op_tensor_core_eligible, true);
-  EXPECT_EQ(op1_stats.total_duration_ns, 2000);
-  EXPECT_EQ(op1_stats.tensor_core_duration_ns, 1000);
-
-  EXPECT_EQ(op2_stats.is_op_tensor_core_eligible, false);
-  EXPECT_EQ(op2_stats.total_duration_ns, 100);
-  EXPECT_EQ(op2_stats.tensor_core_duration_ns, 0);
-}
-
-TEST(KernelStatsUtilsTest, KernelDetailsXStatParser) {
-  xla::profiler::KernelDetails kernel_info;
-  kernel_info.registers_per_thread = 10;
-  kernel_info.static_shared_memory_usage = 128;
-  kernel_info.dynamic_shared_memory_usage = 256;
-  kernel_info.block_x = 32;
-  kernel_info.block_y = 8;
-  kernel_info.block_z = 4;
-  kernel_info.grid_x = 3;
-  kernel_info.grid_y = 2;
-  kernel_info.grid_z = 1;
-  const double occupancy_pct = 50.0;
-  std::string xstat_kernel_details = ToXStat(kernel_info, occupancy_pct);
-  KernelReport kernel;
-  ParseKernelLaunchParams(xstat_kernel_details, &kernel);
-  // Verifies that the parser can parse kKernelDetails XStat.
- EXPECT_EQ(kernel.registers_per_thread(), 10); - EXPECT_EQ(kernel.static_shmem_bytes(), 128); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 256); - EXPECT_EQ(kernel.block_dim()[0], 32); - EXPECT_EQ(kernel.block_dim()[1], 8); - EXPECT_EQ(kernel.block_dim()[2], 4); - EXPECT_EQ(kernel.grid_dim()[0], 3); - EXPECT_EQ(kernel.grid_dim()[1], 2); - EXPECT_EQ(kernel.grid_dim()[2], 1); -} - -TEST(KernelStatsUtilsTest, KernelDetailsTokenizer) { - KernelReport kernel; - - // Test odd token count (3): { "odd", "grid", "3,2,1" } - absl::string_view kernel_details_0 = "odd grid:3,2,1"; - ParseKernelLaunchParams(kernel_details_0, &kernel); - EXPECT_EQ(kernel.grid_dim()[0], 3); - EXPECT_EQ(kernel.grid_dim()[1], 2); - EXPECT_EQ(kernel.grid_dim()[2], 1); - - // Test odd token count (3): { "block", "6,5,4", "odd" } - absl::string_view kernel_details_1 = "block:6,5,4 odd "; - ParseKernelLaunchParams(kernel_details_1, &kernel); - EXPECT_EQ(kernel.block_dim()[0], 6); - EXPECT_EQ(kernel.block_dim()[1], 5); - EXPECT_EQ(kernel.block_dim()[2], 4); - - // Test odd token count (3): { "block", "1,2,3", "odd", "grid", "4,5,6" } - absl::string_view kernel_details_2 = "block:1,2,3 odd grid:4,5,6"; - ParseKernelLaunchParams(kernel_details_2, &kernel); - EXPECT_EQ(kernel.block_dim()[0], 1); - EXPECT_EQ(kernel.block_dim()[1], 2); - EXPECT_EQ(kernel.block_dim()[2], 3); - EXPECT_EQ(kernel.grid_dim()[0], 4); - EXPECT_EQ(kernel.grid_dim()[1], 5); - EXPECT_EQ(kernel.grid_dim()[2], 6); - - // Test even token count (4): { "static_shared", "7", "dynamic_shared", "8" } - absl::string_view kernel_details_3 = "static_shared:7 dynamic_shared:8"; - ParseKernelLaunchParams(kernel_details_3, &kernel); - EXPECT_EQ(kernel.static_shmem_bytes(), 7); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 8); -} - -TEST(KernelStatsUtilsTest, TestInsertOrUpdateKernelReport) { - KernelReport kr; - kr.set_name("op1_kernel1"); - kr.set_op_name("op1"); - // Must provide dummy dims since KernelReportMap's comparator assumes array of - // size 3; values here were suggested by autocomplete - kr.add_block_dim(32); - kr.add_block_dim(8); - kr.add_block_dim(4); - kr.add_grid_dim(3); - kr.add_grid_dim(2); - kr.add_grid_dim(1); - - KernelReportValue krv1; - krv1.total_duration_ns = 1700; - krv1.min_duration_ns = 500; - krv1.max_duration_ns = 1200; - krv1.occurrences = 2; - - KernelReportValue krv2; - krv2.total_duration_ns = 900; - krv2.min_duration_ns = 900; - krv2.max_duration_ns = 900; - krv2.occurrences = 1; - - KernelReportMap dst1; - InsertOrUpdateKernelReport(kr, krv1, &dst1); - InsertOrUpdateKernelReport(kr, krv2, &dst1); - EXPECT_THAT(dst1[kr], FieldsAre(2600, 500, 1200, 3)); - - KernelReportMap dst2; - InsertOrUpdateKernelReport(kr, krv2, &dst2); - InsertOrUpdateKernelReport(kr, krv1, &dst2); - EXPECT_THAT(dst2[kr], FieldsAre(2600, 500, 1200, 3)); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc deleted file mode 100644 index 7ff1c33c762f80..00000000000000 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ /dev/null @@ -1,370 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <limits>
-#include <optional>
-#include <string>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/log/check.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/math_utils.h"
-#include "xla/tsl/profiler/utils/tf_op_utils.h"
-#include "xla/tsl/profiler/utils/xplane_schema.h"
-#include "xla/tsl/profiler/utils/xplane_visitor.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/utils/xplane_visitor.h"
-
-namespace tensorflow {
-namespace profiler {
-
-const absl::string_view kIdle = "IDLE";
-const uint32_t kSparseCoreIndexStart = 1000000;
-const int64_t kSingleOccurrence = 1;
-
-namespace {
-
-constexpr uint64_t kRootSymbolId = 0;
-
-using tsl::profiler::StatType;
-using tsl::profiler::XEventMetadataVisitor;
-using tsl::profiler::XStatVisitor;
-
-class DeviceTfOpMetricsDbBuilder : public OpMetricsDbBuilder {
- public:
-  explicit DeviceTfOpMetricsDbBuilder(OpMetricsDb* db)
-      : OpMetricsDbBuilder(db) {}
-
-  void UpdateTfOpMetricsWithDeviceOpMetrics(
-      absl::string_view tf_op_name, absl::string_view tf_op_type,
-      const OpMetrics& device_op_metrics) {
-    OpMetrics* tf_op_metrics = OpMetricsDbBuilder::LookupOrInsertNewOpMetrics(
-        /*hlo_module_id=*/0, tf_op_name);
-    if (tf_op_metrics->category().empty()) {
-      tf_op_metrics->set_category(tf_op_type == tsl::profiler::kUnknownOp
-                                      ? "Unknown"
-                                      : std::string(tf_op_type));
-    }
-    tf_op_metrics->set_is_eager(device_op_metrics.is_eager());
-    // The number of occurrences of a TF-op is the maximum among the
-    // occurrences of all device ops that it contains.
-    tf_op_metrics->set_occurrences(std::max(tf_op_metrics->occurrences(),
-                                            device_op_metrics.occurrences()));
-    tf_op_metrics->set_time_ps(tf_op_metrics->time_ps() +
-                               device_op_metrics.time_ps());
-    tf_op_metrics->set_self_time_ps(tf_op_metrics->self_time_ps() +
-                                    device_op_metrics.self_time_ps());
-    tf_op_metrics->set_flops(tf_op_metrics->flops() +
-                             device_op_metrics.flops());
-    tf_op_metrics->set_bytes_accessed(tf_op_metrics->bytes_accessed() +
-                                      device_op_metrics.bytes_accessed());
-  }
-};
-
-void SetOpMetadataFromHloEventMetadata(
-    const XEventMetadataVisitor& hlo_event_metadata, OpMetrics* op_metrics) {
-  if (hlo_event_metadata.HasDisplayName()) {
-    op_metrics->set_name(std::string(hlo_event_metadata.DisplayName()));
-    op_metrics->set_long_name(std::string(hlo_event_metadata.Name()));
-  } else {
-    op_metrics->set_name(std::string(hlo_event_metadata.Name()));
-  }
-  hlo_event_metadata.ForEachStat([&](const XStatVisitor& stat) {
-    if (stat.Type().has_value()) {
-      switch (static_cast<StatType>(*stat.Type())) {
-        case StatType::kProgramId:
-          op_metrics->set_hlo_module_id(stat.IntOrUintValue());
-          break;
-        case StatType::kHloCategory:
-          op_metrics->set_category(std::string(stat.StrOrRefValue()));
-          break;
-        case StatType::kTfOp:
-          op_metrics->set_provenance(std::string(stat.StrOrRefValue()));
-          break;
-        case StatType::kFlops:
-          op_metrics->set_flops(stat.IntOrUintValue());
-          break;
-        case StatType::kModelFlops:
-          op_metrics->set_model_flops(stat.IntOrUintValue());
-          break;
-        case StatType::kBytesAccessed:
-          op_metrics->set_bytes_accessed(stat.IntOrUintValue());
-          break;
-        case StatType::kMemoryAccessBreakdown: {
-          tensorflow::profiler::MemoryAccessBreakdown breakdown;
-          const auto& value = stat.BytesValue();
-          if (breakdown.ParseFromArray(value.data(), value.size())) {
-            *op_metrics->mutable_memory_accessed_breakdown() =
-                breakdown.memory_accessed();
-          }
-          break;
-        }
-        case StatType::kDeduplicatedName:
-          op_metrics->set_deduplicated_name(std::string(stat.StrOrRefValue()));
-          break;
-        default:
-          break;
-      }
-    }
-  });
-  hlo_event_metadata.ForEachChild(
-      [&](const XEventMetadataVisitor& child_hlo_event_metadata) {
-        OpMetrics* child = op_metrics->mutable_children()->add_metrics_db();
-        child->set_occurrences(1);
-        SetOpMetadataFromHloEventMetadata(child_hlo_event_metadata, child);
-      });
-}
-
-void SetOpMetricsFromHloEvent(const tsl::profiler::XEventVisitor& hlo_event,
-                              OpMetrics* op_metrics) {
-  uint64_t duration_ps = hlo_event.DurationPs();
-  uint64_t min_duration_ps = duration_ps;
-  uint64_t self_duration_ps = duration_ps;
-  uint64_t dma_stall_ps = 0;
-  hlo_event.ForEachStat([&](const XStatVisitor& stat) {
-    if (!stat.Type()) return;
-    switch (static_cast<StatType>(*stat.Type())) {
-      case StatType::kMinDurationPs:
-        min_duration_ps = stat.IntValue();
-        break;
-      case StatType::kSelfDurationPs:
-        self_duration_ps = stat.IntValue();
-        break;
-      case StatType::kDmaStallDurationPs:
-        dma_stall_ps = stat.IntValue();
-        break;
-      default:
-        break;
-    }
-  });
-  if (op_metrics->occurrences() == 0) {
-    SetOpMetadataFromHloEventMetadata(hlo_event.Metadata(), op_metrics);
-    op_metrics->set_occurrences(
-        std::max(kSingleOccurrence, hlo_event.NumOccurrences()));
-    op_metrics->set_time_ps(duration_ps);
-    op_metrics->set_min_time_ps(min_duration_ps);
-    op_metrics->set_self_time_ps(self_duration_ps);
-    op_metrics->set_dma_stall_ps(dma_stall_ps);
-    op_metrics->set_num_cores(1);
-  } else {
-    op_metrics->set_occurrences(op_metrics->occurrences() +
-                                hlo_event.NumOccurrences());
-    op_metrics->set_time_ps(op_metrics->time_ps() + duration_ps);
-    op_metrics->set_min_time_ps(
-        std::min(op_metrics->min_time_ps(), min_duration_ps));
-    op_metrics->set_self_time_ps(op_metrics->self_time_ps() +
-                                 self_duration_ps);
-    op_metrics->set_dma_stall_ps(op_metrics->dma_stall_ps() + dma_stall_ps);
-  }
-}
-
-void MergeOpMetrics(const OpMetrics& src, OpMetrics& dst) {
-  if (dst.occurrences() == 0) {
-    dst = src;
-  } else {
-    dst.set_occurrences(src.occurrences() + dst.occurrences());
-    dst.set_time_ps(src.time_ps() + dst.time_ps());
-    dst.set_min_time_ps(
-        std::min<uint64_t>(src.min_time_ps(), dst.min_time_ps()));
-    dst.set_self_time_ps(src.self_time_ps() + dst.self_time_ps());
-    dst.set_dma_stall_ps(src.dma_stall_ps() + dst.dma_stall_ps());
-  }
-}
-
-void AdjustFlopsAndBytesAccessed(OpMetrics& op_metrics) {
-  op_metrics.set_flops(op_metrics.flops() * op_metrics.occurrences());
-  if (op_metrics.model_flops() > 0) {
-    op_metrics.set_model_flops(op_metrics.model_flops() *
-                               op_metrics.occurrences());
-  } else {
-    op_metrics.set_model_flops(op_metrics.flops());
-  }
-  op_metrics.set_bytes_accessed(op_metrics.bytes_accessed() *
-                                op_metrics.occurrences());
-  for (auto& memory_access : *op_metrics.mutable_memory_accessed_breakdown()) {
-    memory_access.set_bytes_accessed(memory_access.bytes_accessed() *
-                                     op_metrics.occurrences());
-  }
-}
-
-}  // namespace
-
-OpMetricsDbBuilder::OpMetricsDbBuilder(OpMetricsDb* db) : db_(db) {
-  DCHECK_NE(db_, nullptr);
-  DCHECK_EQ(db_->metrics_db_size(), 0);
-}
-
-OpMetrics* OpMetricsDbBuilder::LookupOrInsertNewOpMetrics(
-    uint64 hlo_module_id, absl::string_view name) {
-  OpMetrics*& op_metrics = op_metrics_map_[hlo_module_id][name];
-  if (op_metrics == nullptr) {
-    op_metrics = db_->add_metrics_db();
-    op_metrics->set_hlo_module_id(hlo_module_id);
-    op_metrics->set_name(name.data(), name.size());
-  }
-  return op_metrics;
-}
-
-void XEventsOpMetricsDbBuilder::AddOpMetric(
-    const tsl::profiler::XEventVisitor& event) {
-  AddOpMetric(FromXEvent(event), GetOpKeyFromXEvent(event));
-}
-
-void XEventsOpMetricsDbBuilder::AddOpMetric(const OpMetrics& op_metrics,
-                                            const OpKey& key) {
-  if (!key.program_id.has_value() || !key.symbol_id.has_value() ||
-      key.symbol_id == kRootSymbolId)
-    return;
-  MergeOpMetrics(
-      op_metrics,
-      flat_op_metric_[key.program_id.value()][key.symbol_id.value()]);
-}
-
-OpMetricsDb XEventsOpMetricsDbBuilder::Finalize(uint64_t total_time_ps) {
-  OpMetricsDb db = Finalize();
-  SetTotalTimePs(db, total_time_ps);
-  AddIdleOp(db);
-  return db;
-}
-
-OpMetricsDb XEventsOpMetricsDbBuilder::Finalize() {
-  OpMetricsDb db;
-  uint64_t total_op_time_ps = 0;
-  for (auto& [program_id, op_metric_by_symbol] : flat_op_metric_) {
-    for (auto& [symbol_id, op_metrics] : op_metric_by_symbol) {
-      AdjustFlopsAndBytesAccessed(op_metrics);
-      total_op_time_ps += op_metrics.self_time_ps();
-      db.add_metrics_db()->Swap(&op_metrics);
-    }
-  }
-  db.set_total_op_time_ps(total_op_time_ps);
-  return db;
-}
-
-double IdleTimeRatio(const OpMetricsDb& db) {
-  return 1.0 -
-         tsl::profiler::SafeDivide(db.total_op_time_ps(), db.total_time_ps());
-}
-
-uint64 IdleTimePs(const OpMetricsDb& db) {
-  DCHECK_GE(db.total_time_ps(), db.total_op_time_ps());
-  return db.total_time_ps() - db.total_op_time_ps();
-}
-
-void SetIdleOp(uint64_t idle_time_ps, OpMetrics& metrics) {
-  metrics.set_name(std::string(kIdle));
-  metrics.set_category(std::string(kIdle));
-  metrics.set_occurrences(0);
-  metrics.set_time_ps(idle_time_ps);
-  metrics.set_self_time_ps(idle_time_ps);
-}
-
-void AddIdleOp(OpMetricsDb& db) {
-  uint64 idle_time_ps = IdleTimePs(db);
-  SetIdleOp(idle_time_ps, *db.add_metrics_db());
-}
-
-std::optional<double> HostInfeedEnqueueRatio(const OpMetricsDb& db) {
-  if (db.total_host_infeed_enq_start_timestamp_ps_diff() > 0) {
-    // We use total_host_infeed_enq_start_timestamp_ps_diff to approximate the
-    // total host time.
-    return tsl::profiler::SafeDivide(
-        db.total_host_infeed_enq_duration_ps(),
-        db.total_host_infeed_enq_start_timestamp_ps_diff());
-  }
-  return std::nullopt;
-}
-
-OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb(
-    const OpMetricsDb& device_op_metrics_db, bool with_idle) {
-  OpMetricsDb tf_op_metrics_db;
-  DeviceTfOpMetricsDbBuilder builder(&tf_op_metrics_db);
-  for (const auto& device_op_metrics : device_op_metrics_db.metrics_db()) {
-    if (IsIdleOp(device_op_metrics)) {
-      if (with_idle) {
-        builder.UpdateTfOpMetricsWithDeviceOpMetrics(kIdle, kIdle,
-                                                     device_op_metrics);
-      }
-    } else if (device_op_metrics.provenance().empty()) {
-      builder.UpdateTfOpMetricsWithDeviceOpMetrics(device_op_metrics.name(),
-                                                   tsl::profiler::kUnknownOp,
-                                                   device_op_metrics);
-    } else {
-      tsl::profiler::TfOp tf_op =
-          tsl::profiler::ParseTfOpFullname(device_op_metrics.provenance());
-      builder.UpdateTfOpMetricsWithDeviceOpMetrics(tf_op.name, tf_op.type,
-                                                   device_op_metrics);
-    }
-  }
-  tf_op_metrics_db.set_total_op_time_ps(
-      device_op_metrics_db.total_op_time_ps());
-
-  tf_op_metrics_db.set_total_time_ps(
-      with_idle ? device_op_metrics_db.total_time_ps()
-                : device_op_metrics_db.total_op_time_ps());
-
-  return tf_op_metrics_db;
-}
-
-OpMetrics FromXEvent(const tsl::profiler::XEventVisitor& xevent) {
-  OpMetrics op_metrics;
-  std::optional<XStatVisitor> stat = xevent.GetStat(StatType::kStepIdleTimePs);
-  if (stat.has_value()) {
-    // TODO(b/397774568) : Remove this once the SparseCore OpMetricsDb is
-    // implemented.
-    uint64_t idle_time_ps = stat->IntOrUintValue();
-    op_metrics.set_self_time_ps(xevent.DurationPs() - idle_time_ps);
-    op_metrics.set_name("sparse_core_busy_ops");
-    op_metrics.set_category("sparse_core_busy_ops");
-    return op_metrics;
-  }
-  SetOpMetricsFromHloEvent(xevent, &op_metrics);
-  return op_metrics;
-}
-
-XEventsOpMetricsDbBuilder::OpKey GetOpKeyFromXEvent(
-    const XEventVisitor& event) {
-  std::optional<XStatVisitor> stat = event.GetStat(StatType::kStepIdleTimePs);
-  if (stat.has_value()) {
-    return {.program_id = std::numeric_limits<uint64_t>::max(),
-            .symbol_id = std::numeric_limits<uint64_t>::max()};
-  }
-
-  XEventsOpMetricsDbBuilder::OpKey op_key;
-  DCHECK(event.metadata() != nullptr);
-  event.Metadata().ForEachStat([&](const XStatVisitor& stat) {
-    if (stat.Type().has_value()) {
-      switch (static_cast<StatType>(*stat.Type())) {
-        case StatType::kProgramId:
-          op_key.program_id = stat.IntOrUintValue();
-          break;
-        case StatType::kSymbolId:
-          op_key.symbol_id = stat.IntOrUintValue();
-          break;
-        default:
-          break;
-      }
-    }
-  });
-  return op_key;
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.h b/tensorflow/core/profiler/utils/op_metrics_db_utils.h
index 4eca439960b0c2..5ed177ac3780d9 100644
--- a/tensorflow/core/profiler/utils/op_metrics_db_utils.h
+++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.h
@@ -16,136 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_OP_METRICS_DB_UTILS_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_OP_METRICS_DB_UTILS_H_
 
-#include <algorithm>
-#include <cstdint>
-#include <optional>
-#include <string>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/platform/macros.h"
-#include "xla/tsl/profiler/utils/xplane_visitor.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-
-// The name of OpMetrics to represent the idle time.
-TF_CONST_INIT extern const absl::string_view kIdle;
-// The core index to add to sparse core index in op metrics.
-TF_CONST_INIT extern const uint32_t kSparseCoreIndexStart;
-
-// Helps build an op metrics database (borrowed).
-// Enables fast lookup of existing ops and prevents the creation of duplicate
-// ops. It is the user's responsibility to ensure an op metrics database
-// outlives its builder, and that no ops are added to the database outside of
-// the builder.
-class OpMetricsDbBuilder {
- public:
-  // Create with a borrowed op database.
-  // REQUIRED: The op database must be empty.
-  explicit OpMetricsDbBuilder(OpMetricsDb* db);
-
- protected:
-  // Looks up the given OP name. If it is already in the database,
-  // return its OpMetrics; otherwise, insert a new one.
-  OpMetrics* LookupOrInsertNewOpMetrics(uint64 hlo_module_id,
-                                        absl::string_view name);
-
-  OpMetricsDb* db() { return db_; }
-
- private:
-  // Map op (hlo_module_id, name) to the corresponding metrics in the op
-  // database.
-  absl::flat_hash_map<uint64 /*hlo_module_id*/,
-                      absl::flat_hash_map<std::string /*name*/, OpMetrics*>>
-      op_metrics_map_;
-
-  // The op database.
-  OpMetricsDb* db_;
-};
-
-// Helps build an op metrics database (borrowed) from XEvents.
-class XEventsOpMetricsDbBuilder {
- public:
-  struct OpKey {
-    std::optional<uint64_t> program_id;
-    std::optional<uint64_t> symbol_id;
-  };
-  // DEPRECATED: Use the OpKey version below.
-  // Add OpMetric from XEventVisitor.
-  void AddOpMetric(const tsl::profiler::XEventVisitor& xevent);
-
-  // Add an OpMetric to the builder based on the provided key.
-  void AddOpMetric(const OpMetrics& op_metrics, const OpKey& key);
-
-  // Finalize OpMetricsDb and add total time and Idle op.
-  OpMetricsDb Finalize(uint64_t total_time_ps);
-
-  // Finalize OpMetricsDb when the total time is unknown at the moment; the
-  // total time and the Idle op are skipped here and handled by the caller.
-  OpMetricsDb Finalize();
-
- private:
-  using OpMetricBySymbol =
-      absl::flat_hash_map<uint64_t /*symbol_id*/, OpMetrics>;
-  absl::flat_hash_map<uint64_t /*program_id*/, OpMetricBySymbol>
-      flat_op_metric_;
-};
-
-// Constructs an OpMetrics from the provided XEventVisitor.
-OpMetrics FromXEvent(const tsl::profiler::XEventVisitor& xevent);
-
-// Returns the OpKey for the provided XEventVisitor.
-XEventsOpMetricsDbBuilder::OpKey GetOpKeyFromXEvent(
-    const tsl::profiler::XEventVisitor& event);
-
-// Sets the total time for OpMetricsDb, ensuring idle time is not negative.
-inline void SetTotalTimePs(OpMetricsDb& db, uint64_t total_time_ps) {
-  db.set_total_time_ps(std::max(db.total_op_time_ps(), total_time_ps));
-}
-
-// Returns the total time in OpMetricsDb, optionally excluding the idle time.
-inline uint64_t TotalTimePs(const OpMetricsDb& db, bool exclude_idle = false) {
-  return exclude_idle ? db.total_op_time_ps() : db.total_time_ps();
-}
-
-// Returns the ratio of time that is idle (no op execution) over total time.
-double IdleTimeRatio(const OpMetricsDb& db);
-
-// Returns the idle time in picoseconds.
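-// REQUIRED: db.total_time_ps() must be no less than db.total_op_time_ps()
-// (see SetTotalTimePs() above).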
-uint64 IdleTimePs(const OpMetricsDb& db);
-
-// Populates an OpMetrics record representing idle time, i.e., the amount of
-// time spent without any op execution.
-void SetIdleOp(uint64_t idle_time_ps, OpMetrics& metrics);
-
-// Adds an OpMetrics record representing idle time, i.e., the amount of time
-// spent without any op execution.
-// REQUIRED: All ops must have been added to the database and the total time
-// must have been set.
-void AddIdleOp(OpMetricsDb& db);
-
-// Returns true if the given metrics represents idle time.
-inline bool IsIdleOp(const OpMetrics& metrics) {
-  return metrics.category() == kIdle;
-}
-
-// Returns the time spent in children (nested) ops.
-inline uint64_t ChildrenTimePs(const OpMetrics& metrics) {
-  return metrics.time_ps() - metrics.self_time_ps();
-}
-
-// Returns the ratio of time spent sending data from the host to the device
-// relative to the total time the host was active.
-std::optional<double> HostInfeedEnqueueRatio(const OpMetricsDb& db);
-
-// Converts from the device op metrics to Tf-op metrics.
-OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb(
-    const OpMetricsDb& device_op_metrics_db, bool with_idle = true);
-
-}  // namespace profiler
-}  // namespace tensorflow
+#include "xprof/utils/op_metrics_db_utils.h"  // from @org_xprof  // IWYU pragma: export
 
 #endif  // TENSORFLOW_CORE_PROFILER_UTILS_OP_METRICS_DB_UTILS_H_
diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils_test.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils_test.cc
deleted file mode 100644
index 07d85e1411e0a1..00000000000000
--- a/tensorflow/core/profiler/utils/op_metrics_db_utils_test.cc
+++ /dev/null
@@ -1,220 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
-
-#include <string>
-
-#include <gmock/gmock.h>
-#include "xla/tsl/profiler/utils/tf_xplane_visitor.h"
-#include "xla/tsl/profiler/utils/xplane_visitor.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/utils/xplane_builder.h"
-#include "tensorflow/core/profiler/utils/xplane_schema.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-#if defined(PLATFORM_GOOGLE)
-using ::testing::EqualsProto;
-using ::testing::proto::IgnoringRepeatedFieldOrdering;
-#endif
-
-constexpr double kMaxError = 1E-10;
-
-TEST(OpMetricsDbTest, IdleTimeRatio) {
-  OpMetricsDb metrics_db_0;
-  metrics_db_0.set_total_time_ps(100000000);
-  metrics_db_0.set_total_op_time_ps(60000000);
-  EXPECT_NEAR(0.4, IdleTimeRatio(metrics_db_0), kMaxError);
-
-  OpMetricsDb metrics_db_1;
-  metrics_db_1.set_total_time_ps(200000000);
-  metrics_db_1.set_total_op_time_ps(150000000);
-  EXPECT_NEAR(0.25, IdleTimeRatio(metrics_db_1), kMaxError);
-
-  OpMetricsDb metrics_db_2;
-  metrics_db_2.set_total_time_ps(0);
-  metrics_db_2.set_total_op_time_ps(0);
-  EXPECT_NEAR(1.0, IdleTimeRatio(metrics_db_2), kMaxError);
-}
-
-TEST(OpMetricsDbTest, FromXEventHandlesMissingOccurrences) {
-  XPlane raw_plane;
-  XPlaneBuilder plane(&raw_plane);
-  XLineBuilder line = plane.GetOrCreateLine(0);
-  XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("metadata");
-  event_metadata->set_display_name("display_name");
-  XStatsBuilder<XEventMetadata> stats(event_metadata, &plane);
-  stats.AddStatValue(
-      *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)), 1);
-  stats.AddStatValue(
-      *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 2);
-  stats.AddStatValue(*plane.GetOrCreateStatMetadata(
-                         GetStatTypeStr(StatType::kDeduplicatedName)),
-                     "deduplicated_name");
-  stats.AddStatValue(
-      *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), "tf_op");
-  stats.AddStatValue(
-      *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kHloCategory)),
-      "tf_op_category");
-  stats.AddStatValue(
-      *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kFlops)), 3);
-  stats.AddStatValue(
-      *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kModelFlops)), 4);
-  stats.AddStatValue(
-      *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kBytesAccessed)),
-      5);
-  XEventBuilder event = line.AddEvent(*event_metadata);
-  event.SetOffsetPs(0);
-  event.SetDurationPs(100);
-  tsl::profiler::XPlaneVisitor plane_visitor =
-      tsl::profiler::CreateTfXPlaneVisitor(&raw_plane);
-  tsl::profiler::XEventVisitor event_visitor(
-      &plane_visitor, &raw_plane.lines(0), &raw_plane.lines(0).events(0));
-  OpMetrics op_metrics = FromXEvent(event_visitor);
-
-#if defined(PLATFORM_GOOGLE)
-  EXPECT_THAT(op_metrics, EqualsProto(R"pb(
-                occurrences: 1
-                time_ps: 100
-                self_time_ps: 100
-                dma_stall_ps: 0
-                hlo_module_id: 1
-                flops: 3
-                model_flops: 4
-                bytes_accessed: 5
-                name: "display_name"
-                long_name: "metadata"
-                deduplicated_name: "deduplicated_name"
-                category: "tf_op_category"
-                provenance: "tf_op"
-                min_time_ps: 100
-                num_cores: 1
-              )pb"));
-#endif
-}
-
-TEST(OpMetricsDbTest, GetOpKeyFromXEvent) {
-  XPlane raw_plane;
-  XPlaneBuilder plane(&raw_plane);
-  XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("metadata");
-  event_metadata->set_display_name("display_name");
-  XStatsBuilder<XEventMetadata> stats(event_metadata, &plane);
-  stats.AddStatValue(
-      *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)), 1);
-  stats.AddStatValue(
-      *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 2);
-  XLineBuilder line = plane.GetOrCreateLine(0);
-  XEventBuilder event = line.AddEvent(*event_metadata);
-  event.SetOffsetPs(0);
-  event.SetDurationPs(100);
-  tsl::profiler::XPlaneVisitor plane_visitor =
-      tsl::profiler::CreateTfXPlaneVisitor(&raw_plane);
-  tsl::profiler::XEventVisitor event_visitor(
-      &plane_visitor, &raw_plane.lines(0), &raw_plane.lines(0).events(0));
-  XEventsOpMetricsDbBuilder::OpKey op_key = GetOpKeyFromXEvent(event_visitor);
-  EXPECT_EQ(op_key.program_id, 1);
-  EXPECT_EQ(op_key.symbol_id, 2);
-}
-
-TEST(OpMetricsDbTest, XEventsOpMetricsDbBuilder) {
-  XPlane raw_plane;
-  XPlaneBuilder plane(&raw_plane);
-  XLineBuilder line = plane.GetOrCreateLine(0);
-  {
-    XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("m1");
-    event_metadata->set_display_name("display_name1");
-    XStatsBuilder<XEventMetadata> stats(event_metadata, &plane);
-    stats.AddStatValue(
-        *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)),
-        1);
-    stats.AddStatValue(
-        *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 1);
-    XEventBuilder event = line.AddEvent(*event_metadata);
-    event.SetOffsetPs(0);
-    event.SetDurationPs(100);
-    XEventBuilder event2 = line.AddEvent(*event_metadata);
-    event2.SetOffsetPs(100);
-    event2.SetDurationPs(100);
-  }
-  {
-    XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("m2");
-    event_metadata->set_display_name("display_name2");
-    XStatsBuilder<XEventMetadata> stats(event_metadata, &plane);
-    stats.AddStatValue(
-        *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)),
-        1);
-    stats.AddStatValue(
-        *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 2);
-    XEventBuilder event = line.AddEvent(*event_metadata);
-    event.SetOffsetPs(0);
-    event.SetDurationPs(100);
-  }
-  {
-    XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("m3");
-    event_metadata->set_display_name("display_name3");
-    XStatsBuilder<XEventMetadata> stats(event_metadata, &plane);
-    stats.AddStatValue(
-        *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 1);
-    XEventBuilder event = line.AddEvent(*event_metadata);
-    event.SetOffsetPs(0);
-    event.SetDurationPs(100);
-  }
-
-  XEventsOpMetricsDbBuilder builder;
-  XEventsOpMetricsDbBuilder legacy_builder;
-  tsl::profiler::XPlaneVisitor plane_visitor =
-      tsl::profiler::CreateTfXPlaneVisitor(&raw_plane);
-  plane_visitor.ForEachLine([&](const tsl::profiler::XLineVisitor& line) {
-    line.ForEachEvent([&](const tsl::profiler::XEventVisitor& event) {
-      builder.AddOpMetric(FromXEvent(event), GetOpKeyFromXEvent(event));
-      legacy_builder.AddOpMetric(event);
-    });
-  });
-#if defined(PLATFORM_GOOGLE)
-  OpMetricsDb legacy_db = legacy_builder.Finalize();
-  OpMetricsDb db = builder.Finalize();
-  EXPECT_THAT(db, IgnoringRepeatedFieldOrdering(EqualsProto(legacy_db)));
-  EXPECT_THAT(db, IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
-                metrics_db {
-                  hlo_module_id: 1
-                  self_time_ps: 200
-                  occurrences: 2
-                  name: "display_name1"
-                  long_name: "m1"
-                  time_ps: 200
-                  min_time_ps: 100
-                  num_cores: 1
-                }
-                metrics_db {
-                  hlo_module_id: 1
-                  self_time_ps: 100
-                  occurrences: 1
-                  name: "display_name2"
-                  long_name: "m2"
-                  time_ps: 100
-                  min_time_ps: 100
-                  num_cores: 1
-                }
-                total_op_time_ps: 300
-              )pb")));
-#endif
-}
-
-}  // namespace
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/op_utils.cc b/tensorflow/core/profiler/utils/op_utils.cc
deleted file mode 100644
index 72b55ba1a76c9b..00000000000000
--- a/tensorflow/core/profiler/utils/op_utils.cc
+++ /dev/null
@@ -1,183 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/utils/op_utils.h"
-
-#include <cstdint>
-#include <string>
-
-#include "absl/log/check.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/tsl/platform/types.h"
-#include "xla/tsl/profiler/utils/tf_op_utils.h"
-#include "xla/tsl/profiler/utils/timespan.h"
-#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/utils/hlo_module_map.h"
-#include "tsl/platform/protobuf.h"
-
-namespace tensorflow {
-namespace profiler {
-using tsl::uint64;
-
-namespace {}  // namespace
-
-// Annotate the op_metrics with the metadata from the instr_wrapper.
-void EnterOpMetadata(OpMetrics* op_metrics,
-                     const HloInstructionWrapper* instr_wrapper) {
-  if (op_metrics->name().empty() && op_metrics->category().empty() &&
-      op_metrics->provenance().empty()) {
-    op_metrics->set_name(std::string(instr_wrapper->Name()));
-    op_metrics->set_category(std::string(instr_wrapper->Category()));
-    op_metrics->set_deduplicated_name(
-        instr_wrapper->Metadata().deduplicated_name());
-    op_metrics->set_provenance(std::string(instr_wrapper->op_full_name()));
-    op_metrics->set_num_cores(1);
-    op_metrics->set_occurrences(op_metrics->occurrences() + 1);
-    op_metrics->set_flops(op_metrics->flops() + instr_wrapper->flops());
-    op_metrics->set_bytes_accessed(op_metrics->bytes_accessed() +
-                                   instr_wrapper->bytes_accessed());
-    op_metrics->set_long_name(instr_wrapper->Expression());
-  }
-}
-
-void AddFusionChildrenToOpMetricsFromHloInstruction(
-    OpMetrics* op_metrics, const HloInstructionWrapper* instr_wrapper) {
-  if (instr_wrapper->FusedChildren().empty()) return;
-  for (const HloInstructionWrapper* child : instr_wrapper->FusedChildren()) {
-    if (child->HloOpcode() == xla::HloOpcode::kParameter ||
-        child->HloOpcode() == xla::HloOpcode::kTuple)
-      continue;
-    OpMetrics* child_op_metrics =
-        op_metrics->mutable_children()->add_metrics_db();
-    // DeviceOpMetricsDbBuilder children_db_builder(
-    //     op_metrics->mutable_children());
-    EnterOpMetadata(child_op_metrics, child);
-    // children_db_builder.EnterOpMetadata(child_op_metrics, child);
-    AddFusionChildrenToOpMetricsFromHloInstruction(child_op_metrics, child);
-  }
-}
-
-void EnterOpMetadataFromHloModuleMap(OpMetrics* op_metrics,
-                                     const HloModuleMap& hlo_module_map) {
-  const HloInstructionWrapper* instr_wrapper = GetHloInstruction(
-      hlo_module_map, op_metrics->hlo_module_id(), op_metrics->name());
-  if (instr_wrapper != nullptr) {
-    AddFusionChildrenToOpMetricsFromHloInstruction(op_metrics, instr_wrapper);
-  }
-}
-
-void HostOpMetricsDbBuilder::EnterOp(absl::string_view name, - absl::string_view category, bool is_eager, - uint64 time_ps, uint64 children_time_ps) { - uint64 self_time_ps = time_ps - children_time_ps; - DCHECK_GE(time_ps, self_time_ps); - OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(/*hlo_module_id=*/0, name); - if (op_metrics->category().empty()) - op_metrics->set_category(category.data(), category.size()); - op_metrics->set_num_cores(1); - op_metrics->set_is_eager(op_metrics->is_eager() || is_eager); - op_metrics->set_occurrences(op_metrics->occurrences() + 1); - op_metrics->set_time_ps(op_metrics->time_ps() + time_ps); - op_metrics->set_self_time_ps(op_metrics->self_time_ps() + self_time_ps); - db()->set_total_op_time_ps(db()->total_op_time_ps() + self_time_ps); -} - -void HostOpMetricsDbBuilder::EnterHostInfeedEnqueue( - tsl::profiler::Timespan host_infeed_enqueue) { - if (!last_host_infeed_enqueue_.Empty()) { - // Expect non-overlapping InfeedEnqueue timespans sorted by time. - DCHECK_GE(host_infeed_enqueue.end_ps(), - last_host_infeed_enqueue_.begin_ps()); - db()->set_total_host_infeed_enq_duration_ps( - db()->total_host_infeed_enq_duration_ps() + - last_host_infeed_enqueue_.duration_ps()); - db()->set_total_host_infeed_enq_start_timestamp_ps_diff( - db()->total_host_infeed_enq_start_timestamp_ps_diff() + - (host_infeed_enqueue.begin_ps() - - last_host_infeed_enqueue_.begin_ps())); - } - last_host_infeed_enqueue_ = host_infeed_enqueue; -} - -void DeviceOpMetricsDbBuilder::EnterOpMetadataFromHloModuleMap( - uint64 program_id, absl::string_view op_name, - const HloModuleMap& hlo_module_map) { - OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, op_name); - tensorflow::profiler::EnterOpMetadataFromHloModuleMap(op_metrics, - hlo_module_map); -} - -void DeviceOpMetricsDbBuilder::EnterOpMetadata( - uint64 program_id, absl::string_view program_name, - absl::string_view category, absl::string_view provenance, - absl::string_view deduplicated_name, bool is_eager, - absl::string_view long_name) { - // We only need to add xla metadata once to each new op, as they are the - // same across occurrences. - OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, program_name); - if (op_metrics->occurrences() > 0 || !op_metrics->category().empty() || - !op_metrics->provenance().empty()) - return; - op_metrics->set_category(category == tsl::profiler::kUnknownOp - ? 
"unknown" - : std::string(category)); - op_metrics->set_provenance(std::string(provenance)); - if (!deduplicated_name.empty()) { - op_metrics->set_deduplicated_name(std::string(deduplicated_name)); - } - if (!long_name.empty()) { - op_metrics->set_long_name(std::string(long_name)); - } - op_metrics->set_is_eager(op_metrics->is_eager() || is_eager); -} - -void DeviceOpMetricsDbBuilder::EnterOp( - uint64 program_id, absl::string_view name, absl::string_view category, - absl::string_view provenance, absl::string_view deduplicated_name, - bool is_eager, uint64 occurrences, uint64 time_ps, uint64 children_time_ps, - int64_t flops, int64_t bytes_accessed, - // NOLINTNEXTLINE: clang-tidy missing-includes false positive - const tsl::protobuf::RepeatedPtrField& - memory_accessed_breakdown, - int64_t model_flops) { - EnterOpMetadata(program_id, name, category, provenance, deduplicated_name, - is_eager); - uint64 self_time_ps = time_ps - children_time_ps; - DCHECK_GE(time_ps, self_time_ps); - OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, name); - op_metrics->set_num_cores(1); - op_metrics->set_occurrences(op_metrics->occurrences() + occurrences); - op_metrics->set_time_ps(op_metrics->time_ps() + time_ps); - op_metrics->set_self_time_ps(op_metrics->self_time_ps() + self_time_ps); - op_metrics->set_flops(op_metrics->flops() + flops * occurrences); - if (model_flops == 0) { - // If ModelsFlops is 0, use the same value as device flops. - op_metrics->set_model_flops(op_metrics->flops()); - } else { - op_metrics->set_model_flops(op_metrics->model_flops() + - model_flops * occurrences); - } - op_metrics->set_bytes_accessed(op_metrics->bytes_accessed() + - bytes_accessed * occurrences); - CombineMemoryAccessedBreakdown( - memory_accessed_breakdown, - op_metrics->mutable_memory_accessed_breakdown()); - db()->set_total_op_time_ps(db()->total_op_time_ps() + self_time_ps); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/op_utils.h b/tensorflow/core/profiler/utils/op_utils.h index 9363574a474f99..b5edd9288a4652 100644 --- a/tensorflow/core/profiler/utils/op_utils.h +++ b/tensorflow/core/profiler/utils/op_utils.h @@ -16,94 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/types.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { -using tsl::uint64; - -// Annotate the op_metrics with the metadata from the instr_wrapper. -void EnterOpMetadata(OpMetrics* op_metrics, - const HloInstructionWrapper* instr_wrapper); -void EnterOpMetadataFromHloModuleMap(OpMetrics* op_metrics, - const HloModuleMap& hlo_module_map); - -void AddFusionChildrenToOpMetricsFromHloInstruction( - OpMetrics* op_metrics, const HloInstructionWrapper* instr_wrapper); - -class HostOpMetricsDbBuilder : public OpMetricsDbBuilder { - public: - explicit HostOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {} - - // A function that will be called when the end of an OP is - // observed on a trace, where: - // name = the OP name. - // category = the OP category. - // is_eager = whether this OP is eagerly executed. 
-  //   time_ps = the total execution time of the OP in picoseconds, including
-  //             the execution time of its children.
-  //   children_time_ps = the execution time of the children of this OP in
-  //                      picoseconds.
-  void EnterOp(absl::string_view name, absl::string_view category,
-               bool is_eager, uint64 time_ps, uint64 children_time_ps);
-
-  // Updates total_host_infeed_enq_duration_ps_ and
-  // total_host_infeed_enq_start_timestamp_ps_diff_.
-  void EnterHostInfeedEnqueue(tsl::profiler::Timespan host_infeed_enqueue);
-
- private:
-  // The tsl::profiler::Timespan of the last InfeedEnqueue op on this thread.
-  tsl::profiler::Timespan last_host_infeed_enqueue_;
-};
-
-class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder {
- public:
-  explicit DeviceOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {}
-
-  // A function that will be called when the end of an OP is
-  // observed on a trace, where:
-  //   program_id = the ID of the program that contains this OP.
-  //   name = the OP name.
-  //   category = the OP category.
-  //   provenance = the provenance of this OP (e.g. original TF OP).
-  //   is_eager = whether this OP is eagerly executed.
-  //   occurrences = the number of occurrences of this OP.
-  //   time_ps = the total execution time of the OP in picoseconds, including
-  //             the execution time of its children.
-  //   children_time_ps = the execution time of the children of this OP in
-  //                      picoseconds.
-  //   flops = the number of floating-point operations computed.
-  //   bytes_accessed = the sum of bytes read and bytes written by this OP.
-  //   memory_accessed_breakdown = the breakdown of memory accessed by
-  //                               operation type and memory space.
-  void EnterOp(uint64 program_id, absl::string_view name,
-               absl::string_view category, absl::string_view provenance,
-               absl::string_view deduplicated_name, bool is_eager,
-               uint64 occurrences, uint64 time_ps, uint64 children_time_ps,
-               int64_t flops, int64_t bytes_accessed,
-               const tsl::protobuf::RepeatedPtrField<OpMetrics::MemoryAccessed>&
-                   memory_accessed_breakdown = {},
-               int64_t model_flops = 0);
-
-  void EnterOpMetadata(uint64 program_id, absl::string_view program_name,
-                       absl::string_view category,
-                       absl::string_view provenance,
-                       absl::string_view deduplicated_name, bool is_eager,
-                       absl::string_view long_name = "");
-
-  void EnterOpMetadataFromHloModuleMap(uint64 program_id,
-                                       absl::string_view op_name,
-                                       const HloModuleMap& hlo_module_map);
-};
-
-}  // namespace profiler
-}  // namespace tensorflow
+#include "xprof/utils/op_utils.h"  // from @org_xprof  // IWYU pragma: export
 
 #endif  // TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_
diff --git a/tensorflow/core/profiler/utils/step_intersection.cc b/tensorflow/core/profiler/utils/step_intersection.cc
deleted file mode 100644
index 8eb967fafba1e2..00000000000000
--- a/tensorflow/core/profiler/utils/step_intersection.cc
+++ /dev/null
@@ -1,305 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/profiler/utils/step_intersection.h"
-
-#include <algorithm>
-#include <string>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/log/check.h"
-#include "absl/strings/str_cat.h"
-#include "xla/tsl/profiler/utils/timespan.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-namespace profiler {
-
-namespace {
-
-// Returns the timespan in this step (across all cores).
-tsl::profiler::Timespan StepTimespan(const PerCoreStepInfo& percore_stepinfo) {
-  uint64 min_ps = kuint64max;
-  uint64 max_ps = 0;
-  for (const auto& core_stepinfo : percore_stepinfo.step_info_per_core()) {
-    const auto& stepinfo = core_stepinfo.second;
-    uint64 begin_ps = stepinfo.begin_ps();
-    uint64 end_ps = begin_ps + stepinfo.duration_ps();
-    min_ps = std::min(min_ps, begin_ps);
-    max_ps = std::max(max_ps, end_ps);
-  }
-  return (min_ps < max_ps)
-             ? tsl::profiler::Timespan::FromEndPoints(min_ps, max_ps)
-             : tsl::profiler::Timespan();
-}
-
-// Returns the timespan across all steps in the given step_db.
-tsl::profiler::Timespan AllStepsTimespan(const StepDatabaseResult& step_db) {
-  uint64 min_ps = kuint64max;
-  uint64 max_ps = 0;
-  for (const auto& step : step_db.step_sequence()) {
-    tsl::profiler::Timespan timespan = StepTimespan(step);
-    uint64 begin_ps = timespan.begin_ps();
-    uint64 end_ps = timespan.end_ps();
-    min_ps = std::min(min_ps, begin_ps);
-    max_ps = std::max(max_ps, end_ps);
-  }
-  return (min_ps < max_ps)
-             ? tsl::profiler::Timespan::FromEndPoints(min_ps, max_ps)
-             : tsl::profiler::Timespan();
-}
-
-struct AlignmentInfo {
-  StepsAlignment alignment;
-  double similarity;
-};
-
-// Computes the similarity between the given two steps. The closer their
-// timespans are, the larger is the similarity.
-double StepSimilarity(const PerCoreStepInfo& subordinate_step,
-                      const PerCoreStepInfo& chief_step) {
-  tsl::profiler::Timespan subordinate_timespan =
-      StepTimespan(subordinate_step);
-  tsl::profiler::Timespan chief_timespan = StepTimespan(chief_step);
-  return chief_timespan.OverlappedDurationPs(subordinate_timespan);
-}
-
-// If the subordinate steps and the chief steps are aligned at the given anchor
-// points (i.e. at the subordinate_anchor step on the subordinate sequence, at
-// the chief_anchor step on the chief sequence), returns the corresponding
-// AlignmentInfo.
-AlignmentInfo ComputeAlignmentInfo(const StepDatabaseResult& subordinate,
-                                   uint32 subordinate_anchor,
-                                   const StepDatabaseResult& chief,
-                                   uint32 chief_anchor) {
-  // Assumes that the step at subordinate_anchor on the subordinate sequence is
-  // aligned with the step at the chief_anchor on the chief sequence. Then the
-  // number of steps before the anchor is the minimum of the number of steps
-  // before the anchor in the subordinate and that before the anchor in the
-  // chief. Similarly, the number of steps after the anchor is the minimum of
-  // the number of steps after the anchor in the subordinate and that after the
-  // anchor in the chief.
-  uint32 pre_anchor_steps = std::min(subordinate_anchor, chief_anchor);
-  uint32 post_anchor_steps =
-      std::min(subordinate.step_sequence_size() - subordinate_anchor,
-               chief.step_sequence_size() - chief_anchor);
-  // total number of steps aligned = pre_anchor_steps + post_anchor_steps.
- uint32 alignment_steps = pre_anchor_steps + post_anchor_steps; - - double similarity = 0; - // Where the aligned steps begin on the subordinate sequence. - uint32 begin_subordinate_idx = subordinate_anchor - pre_anchor_steps; - // Where the aligned steps begin on the chief sequence. - uint32 begin_chief_idx = chief_anchor - pre_anchor_steps; - - for (uint32 i = 0; i < alignment_steps; i++) { - // Accumulates the similarity at each step. - similarity += - StepSimilarity(subordinate.step_sequence(begin_subordinate_idx + i), - chief.step_sequence(begin_chief_idx + i)); - } - StepsAlignment alignment = {begin_subordinate_idx, begin_chief_idx, - alignment_steps}; - return {alignment, similarity}; -} - -// Returns the best alignment for aligning subordinate against chief. -StepsAlignment FindStepsAlignment(const StepDatabaseResult& subordinate, - const StepDatabaseResult& chief) { - double max_similarity = -1; - StepsAlignment alignment = {0, 0, 0}; - if (subordinate.step_sequence_size() == 0 || chief.step_sequence_size() == 0) - return alignment; - for (auto c = 0; c < chief.step_sequence_size(); c++) { - AlignmentInfo info = - ComputeAlignmentInfo(subordinate, /*subordinate_anchor=*/0, chief, c); - if (info.similarity <= max_similarity) continue; - max_similarity = info.similarity; - alignment = info.alignment; - } - for (auto s = 1; s < subordinate.step_sequence_size(); s++) { - // s starts at 1 instead of 0, because the loop above already considers - // (s=0, c=0). - AlignmentInfo info = - ComputeAlignmentInfo(subordinate, s, chief, /*chief_anchor=*/0); - if (info.similarity <= max_similarity) continue; - max_similarity = info.similarity; - alignment = info.alignment; - } - return alignment; -} - -std::string StringStepsAlignment(const StepsAlignment& alignment) { - return absl::StrCat( - "[begin_subordinate_idx: ", alignment.begin_subordinate_idx, - ", begin_chief_idx: ", alignment.begin_chief_idx, - ", num_steps: ", alignment.num_steps, "]"); -} - -std::string StringDstStepNumbers(const std::vector& step_numbers) { - std::string str; - absl::StrAppend(&str, "["); - for (auto i = 0; i < step_numbers.size(); i++) { - if (i > 0) absl::StrAppend(&str, ", "); - absl::StrAppend(&str, step_numbers[i]); - } - absl::StrAppend(&str, "]"); - return str; -} - -std::string StringSrcToDstIndexMap(uint32 src_first_step_idx, - uint32 num_steps) { - std::string str; - absl::StrAppend(&str, "["); - for (auto i = 0; i < num_steps; i++) { - if (i > 0) absl::StrAppend(&str, ", "); - absl::StrAppend(&str, src_first_step_idx + i, ":", i); - } - absl::StrAppend(&str, "]"); - return str; -} - -} // namespace - -StepIntersection::StepIntersection( - uint32 max_steps, - const absl::flat_hash_map& - perhost_stepdb) { - empty_intersect_ = false; - - // Figures out the host with the shortest timespan among their steps (called - // this host the "chief"). - chief_host_id_ = kuint32max; - uint64 min_duration_ps = kuint64max; - const StepDatabaseResult* chief_step_db = nullptr; - for (const auto& hostid_stepdb : perhost_stepdb) { - auto host_id = hostid_stepdb.first; - const auto& step_db = hostid_stepdb.second; - tsl::profiler::Timespan timespan = AllStepsTimespan(*step_db); - if (timespan.duration_ps() < min_duration_ps) { - chief_host_id_ = host_id; - chief_step_db = step_db; - min_duration_ps = timespan.duration_ps(); - } - } - if (chief_host_id_ == kuint32max) { - // There is no step at all on any host. 
- steps_dropped_ = 0; - begin_chief_idx_ = 0; - end_chief_idx_ = 0; - return; - } - - uint32 max_begin_chief_idx = 0; - uint32 min_end_chief_idx = kuint32max; - // Aligns the steps in all hosts with those in the chief. - for (const auto& hostid_stepdb : perhost_stepdb) { - auto host_id = hostid_stepdb.first; - const auto& step_db = hostid_stepdb.second; - if (host_id == chief_host_id_) { - // Simply aligns with itself. - perhost_alignment_[host_id] = { - /*begin_subordinate_idx=*/0, /*begin_chief_idx=*/0, - static_cast(step_db->step_sequence_size())}; - } else { - perhost_alignment_[host_id] = - FindStepsAlignment(*step_db, *chief_step_db); - } - // Intersects this host's alignment with other hosts' alignments. - uint32 host_begin_chief_idx = perhost_alignment_[host_id].begin_chief_idx; - max_begin_chief_idx = std::max(max_begin_chief_idx, host_begin_chief_idx); - uint32 host_end_chief_idx = perhost_alignment_[host_id].begin_chief_idx + - perhost_alignment_[host_id].num_steps; - min_end_chief_idx = std::min(min_end_chief_idx, host_end_chief_idx); - } - if (max_begin_chief_idx > min_end_chief_idx) { - // The intersection is empty. - steps_dropped_ = 0; - begin_chief_idx_ = 0; - end_chief_idx_ = 0; - empty_intersect_ = true; - return; - } - - begin_chief_idx_ = max_begin_chief_idx; - - // Takes max_steps into account. - uint32 num_steps = min_end_chief_idx - max_begin_chief_idx; - if (num_steps > max_steps) { - steps_dropped_ = num_steps - max_steps; - // TODO(ckluk): Drops from both ends to avoid incomplete steps at the - // beginning and end of the profile. - end_chief_idx_ = max_begin_chief_idx + max_steps; - } else { - steps_dropped_ = 0; - end_chief_idx_ = min_end_chief_idx; - } -} - -std::vector StepIntersection::DstStepNumbers() const { - // TODO(ckluk): Honors training-loop boundaries (if more than one loop - // sampled). 
- std::vector result; - result.reserve(NumSteps()); - for (uint32 i = 0; i < NumSteps(); i++) { - result.push_back(i); - } - return result; -} - -uint32 StepIntersection::FirstStepIndex(uint32 host_id) const { - const auto* alignment = gtl::FindOrNull(perhost_alignment_, host_id); - if (alignment == nullptr) return 0; - DCHECK(alignment->begin_chief_idx <= begin_chief_idx_); - uint32 shift = begin_chief_idx_ - alignment->begin_chief_idx; - uint32 begin_subordinate_idx = alignment->begin_subordinate_idx + shift; - return begin_subordinate_idx; -} - -std::string StepIntersection::DebugString() const { - std::string str; - absl::StrAppend(&str, "chief host id_: ", chief_host_id_, "\n"); - absl::StrAppend(&str, "begin_chief_idx_: ", begin_chief_idx_, - ", num_steps: ", NumSteps(), "\n"); - absl::StrAppend( - &str, "DstStepNumbers(): ", StringDstStepNumbers(DstStepNumbers()), "\n"); - - std::vector host_ids; - host_ids.reserve(perhost_alignment_.size()); - for (const auto& hostid_alignment : perhost_alignment_) { - auto host_id = hostid_alignment.first; - host_ids.push_back(host_id); - } - absl::c_sort(host_ids); - - absl::StrAppend(&str, "perhost_alignment:\n"); - for (const auto host_id : host_ids) { - const auto* ptr = gtl::FindOrNull(perhost_alignment_, host_id); - if (ptr == nullptr) continue; - absl::StrAppend(&str, "host: ", host_id, - ", step-alignment: ", StringStepsAlignment(*ptr), "\n"); - } - absl::StrAppend(&str, "SrcToDstIndexMap():\n"); - for (const auto host_id : host_ids) { - absl::StrAppend(&str, "host: ", host_id, ", src-to-dst-index-map: ", - StringSrcToDstIndexMap(FirstStepIndex(host_id), NumSteps()), - "\n"); - } - return str; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/step_intersection.h b/tensorflow/core/profiler/utils/step_intersection.h index 777b0528c30a05..d1932a5c5e43be 100644 --- a/tensorflow/core/profiler/utils/step_intersection.h +++ b/tensorflow/core/profiler/utils/step_intersection.h @@ -16,72 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_STEP_INTERSECTION_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_STEP_INTERSECTION_H_ -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" - -namespace tensorflow { -namespace profiler { - -// Description of how two step sequences are aligned. -struct StepsAlignment { - uint32 begin_subordinate_idx; // where the alignment begins on the - // subordinate steps. - uint32 begin_chief_idx; // where the alignment begins on the chief steps. - uint32 num_steps; // aligned for how many steps. -}; - -class StepIntersection { - public: - StepIntersection( - uint32 max_steps, - const absl::flat_hash_map& - perhost_stepdb); - - // Returns the number of steps in the intersection. - uint32 NumSteps() const { return end_chief_idx_ - begin_chief_idx_; } - - // Returns the value of empty_intersect_ (see the explanation of - // empty_intersect_ below). - bool EmptyIntersect() const { return empty_intersect_; } - - // Returns the step numbers for the destination (i.e. the intersection - // result). - std::vector DstStepNumbers() const; - - // Returns the index to the step in the given host that corresponds to the - // first step in the intersection. - uint32 FirstStepIndex(uint32 host_id) const; - - // Returns the number of steps dropped due to the max_steps constraint - // specified in the constructor. 
- uint32 StepsDropped() const { return steps_dropped_; } - - std::string DebugString() const; - - private: - absl::flat_hash_map perhost_alignment_; - uint32 - chief_host_id_; // the host whose step sequence is selected as the chief. - uint32 steps_dropped_; // number of steps dropped. - // If NumSteps() is 0, empty_intersect indicates one of two possible reasons: - // (i) At least one host has some steps, but the intersection over all hosts - // is empty. In this case, empty_intersect is true; - // (ii) None of the hosts has any steps. In this case, empty_intersect is - // false. - // If NumSteps() > 0, empty_intersect is a don't-care. - bool empty_intersect_; - // The begin and end indices to the chief step sequence for this step - // intersection. Note that the begin index is inclusive but the end index is - // exclusive. - uint32 begin_chief_idx_; - uint32 end_chief_idx_; -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/step_intersection.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_STEP_INTERSECTION_H_ diff --git a/tensorflow/core/profiler/utils/step_intersection_test.cc b/tensorflow/core/profiler/utils/step_intersection_test.cc deleted file mode 100644 index 2115581ff1a270..00000000000000 --- a/tensorflow/core/profiler/utils/step_intersection_test.cc +++ /dev/null @@ -1,260 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/step_intersection.h" - -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using PerHostStepDb = - absl::flat_hash_map; - -constexpr uint64 kStepDurationPs = 2000000000; -constexpr uint32 kNumStepsPerHost = 10; -constexpr uint64 kStepGapPs = 0; -constexpr uint32 kNumCoresPerHost = 8; - -PerCoreStepInfo CreateOneTestStep(uint32 host_id, uint32 num_steps, - uint32 step_idx, uint64 step_begin_ps) { - PerCoreStepInfo result; - uint32 step_num = - step_idx * host_id; // creates the situation where each host has a - // different step number for the same step. - result.set_step_num(step_num); - StepInfoResult info; - info.set_step_num(step_num); - if (host_id == 0 && step_idx == (num_steps - 1)) { - // Makes the last step on host 0 a little bit shorter so that host-0 will - // be chosen as the chief. - info.set_duration_ps(kStepDurationPs - 1); - } else { - info.set_duration_ps(kStepDurationPs); - } - info.set_begin_ps(step_begin_ps); - // Don't care about the rest of the fields in StepInfoResult. - for (uint32 core_id = 0; core_id < kNumCoresPerHost; core_id++) { - (*result.mutable_step_info_per_core())[core_id] = info; - // Don't care about the rest of the fields in PerCoreStepInfo.
- } - return result; -} - -PerHostStepDb CreateTestSteps(uint32 num_hosts, uint64 shift_ps) { - PerHostStepDb result; - uint64 first_step_begin_ps = 0; - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - StepDatabaseResult step_db; - uint64 step_begin_ps = first_step_begin_ps; - for (uint32 step_idx = 0; step_idx < kNumStepsPerHost; step_idx++) { - *step_db.add_step_sequence() = - CreateOneTestStep(host_id, kNumStepsPerHost, step_idx, step_begin_ps); - step_begin_ps += (kStepDurationPs + kStepGapPs); - } - result[host_id] = step_db; - first_step_begin_ps += shift_ps; - } - return result; -} - -PerHostStepDb CreateEmptyIntersectTestSteps() { - PerHostStepDb result; - - uint64 step_begin_ps; - uint32 host_id; - - // Host-0 - host_id = 0; - step_begin_ps = 0; - uint64 host_0_num_steps = 10; - StepDatabaseResult step_db_0; - for (uint32 step_idx = 0; step_idx < host_0_num_steps; step_idx++) { - *step_db_0.add_step_sequence() = - CreateOneTestStep(host_id, host_0_num_steps, step_idx, step_begin_ps); - step_begin_ps += (kStepDurationPs + kStepGapPs); - } - result[host_id] = step_db_0; - - // Host-1 - host_id = 1; - step_begin_ps = (host_0_num_steps - 2) * (kStepDurationPs + kStepGapPs); - uint64 host_1_num_steps = 5; - StepDatabaseResult step_db_1; - for (uint32 step_idx = 0; step_idx < host_1_num_steps; step_idx++) { - *step_db_1.add_step_sequence() = - CreateOneTestStep(host_id, host_1_num_steps, step_idx, step_begin_ps); - step_begin_ps += (kStepDurationPs + kStepGapPs); - } - result[host_id] = step_db_1; - - // Host-2 - host_id = 2; - step_begin_ps = (host_0_num_steps + host_1_num_steps - 4) * - (kStepDurationPs + kStepGapPs); - uint64 host_2_num_steps = 10; - StepDatabaseResult step_db_2; - for (uint32 step_idx = 0; step_idx < host_2_num_steps; step_idx++) { - *step_db_2.add_step_sequence() = - CreateOneTestStep(host_id, host_2_num_steps, step_idx, step_begin_ps); - step_begin_ps += (kStepDurationPs + kStepGapPs); - } - result[host_id] = step_db_2; - - return result; -} - -PerHostStepDb CreateNoStep(uint32 num_hosts) { - PerHostStepDb result; - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - StepDatabaseResult step_db; - result[host_id] = step_db; - } - return result; -} - -absl::flat_hash_map Convert( - const PerHostStepDb& perhost_stepdb) { - absl::flat_hash_map result; - for (const auto& hostid_stepdb : perhost_stepdb) { - auto host_id = hostid_stepdb.first; - const auto& step_db = hostid_stepdb.second; - result[host_id] = &step_db; - } - return result; -} - -TEST(StepIntersectionTest, EachHostShiftedBy1StepDuration) { - uint32 num_hosts = 4; - uint64 shift_ps = kStepDurationPs; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(kNumStepsPerHost, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), 0); - uint32 dst_num_steps = kNumStepsPerHost - num_hosts + 1; - EXPECT_EQ(intersection.NumSteps(), dst_num_steps); - - uint32 src_first_step_index = intersection.FirstStepIndex(0); - EXPECT_EQ(src_first_step_index, num_hosts - 1); - std::vector dst_step_numbers = intersection.DstStepNumbers(); - for (uint32 i = 0; i < dst_num_steps; i++) { - EXPECT_EQ(dst_step_numbers[i], i); - } -} - -TEST(StepIntersectionTest, ExactlyNoShift) { - uint32 num_hosts = 4; - uint64 shift_ps = 0; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(kNumStepsPerHost, Convert(perhost_stepdb)); - 
EXPECT_EQ(intersection.StepsDropped(), 0); - uint32 dst_num_steps = kNumStepsPerHost; - EXPECT_EQ(intersection.NumSteps(), dst_num_steps); - - std::vector dst_step_numbers = intersection.DstStepNumbers(); - for (uint32 i = 0; i < dst_num_steps; i++) { - EXPECT_EQ(dst_step_numbers[i], i); - } - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - uint32 src_first_step_index = intersection.FirstStepIndex(host_id); - EXPECT_EQ(src_first_step_index, 0); - } -} - -TEST(StepIntersectionTest, EachHostShiftedByJustABit) { - uint32 num_hosts = 4; - uint64 shift_ps = 100; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(kNumStepsPerHost, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), 0); - uint32 dst_num_steps = kNumStepsPerHost; - EXPECT_EQ(intersection.NumSteps(), dst_num_steps); - - std::vector dst_step_numbers = intersection.DstStepNumbers(); - for (uint32 i = 0; i < dst_num_steps; i++) { - EXPECT_EQ(dst_step_numbers[i], i); - } - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - uint32 src_first_step_index = intersection.FirstStepIndex(host_id); - EXPECT_EQ(src_first_step_index, 0); - } -} - -TEST(StepIntersectionTest, SingleHost) { - uint32 num_hosts = 1; - uint64 shift_ps = 0; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(kNumStepsPerHost, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), 0); - uint32 dst_num_steps = kNumStepsPerHost; - EXPECT_EQ(intersection.NumSteps(), dst_num_steps); - - std::vector dst_step_numbers = intersection.DstStepNumbers(); - for (uint32 i = 0; i < dst_num_steps; i++) { - EXPECT_EQ(dst_step_numbers[i], i); - } - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - uint32 src_first_step_index = intersection.FirstStepIndex(host_id); - EXPECT_EQ(src_first_step_index, 0); - } -} - -TEST(StepIntersectionTest, WithMaxSteps) { - uint32 num_hosts = 4; - uint64 shift_ps = 0; - uint32 max_steps = 3; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(max_steps, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), kNumStepsPerHost - max_steps); - EXPECT_EQ(intersection.NumSteps(), max_steps); -} - -TEST(StepIntersectionTest, NoStep) { - uint32 num_hosts = 4; - uint32 max_steps = 100; - PerHostStepDb perhost_stepdb = CreateNoStep(num_hosts); - StepIntersection intersection = - StepIntersection(max_steps, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.NumSteps(), 0); - EXPECT_FALSE(intersection.EmptyIntersect()); -} - -TEST(StepIntersectionTest, EmptyIntersection) { - uint32 max_steps = 100; - PerHostStepDb perhost_stepdb = CreateEmptyIntersectTestSteps(); - StepIntersection intersection = - StepIntersection(max_steps, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), 0); - EXPECT_EQ(intersection.NumSteps(), 0); - EXPECT_TRUE(intersection.EmptyIntersect()); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.cc b/tensorflow/core/profiler/utils/tfstreamz_utils.cc deleted file mode 100644 index 2d3b5fa4a1bc8e..00000000000000 --- a/tensorflow/core/profiler/utils/tfstreamz_utils.cc +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/utils/tfstreamz_utils.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/strings/substitute.h" -#include "tensorflow/core/framework/summary.pb.h" -#include "tensorflow/core/lib/monitoring/collected_metrics.h" -#include "tensorflow/core/lib/monitoring/metric_def.h" -#include "tensorflow/core/lib/monitoring/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/tfstreamz.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -std::string ConstructXStatName(absl::string_view name, - const monitoring::Point& point) { - if (point.labels.empty()) { - return std::string(name); - } - return absl::Substitute( - "$0{$1}", name, - absl::StrJoin( - point.labels, ", ", - [](std::string* out, const monitoring::Point::Label& label) { - absl::StrAppend(out, label.name, "=", label.value); - })); -} - -tfstreamz::Percentiles ToProto(const monitoring::Percentiles& percentiles) { - tfstreamz::Percentiles output; - output.set_unit_of_measure( - static_cast(percentiles.unit_of_measure)); - output.set_start_nstime(percentiles.start_nstime); - output.set_end_nstime(percentiles.end_nstime); - output.set_min_value(percentiles.min_value); - output.set_max_value(percentiles.max_value); - output.set_mean(percentiles.mean); - output.set_stddev(percentiles.stddev); - output.set_num_samples(percentiles.num_samples); - output.set_total_samples(percentiles.total_samples); - output.set_accumulator(percentiles.accumulator); - for (const auto& pp : percentiles.points) { - auto* percentile_point = output.add_points(); - percentile_point->set_percentile(pp.percentile); - percentile_point->set_value(pp.value); - } - return output; -} - -} // namespace - -absl::Status SerializeToXPlane(const std::vector& snapshots, - XPlane* plane, uint64 line_start_time_ns) { - XPlaneBuilder xplane(plane); - XLineBuilder line = xplane.GetOrCreateLine(0); // This plane has a single line. - line.SetTimestampNs(line_start_time_ns); - - // For each snapshot, create a virtual event. - for (const auto& snapshot : snapshots) { - XEventMetadata* event_metadata = - xplane.GetOrCreateEventMetadata("TFStreamz Snapshot"); - XEventBuilder xevent = line.AddEvent(*event_metadata); - xevent.SetTimestampNs(snapshot.start_time_ns); - xevent.SetEndTimestampNs(snapshot.end_time_ns); - auto& metric_descriptor_map = snapshot.metrics->metric_descriptor_map; - for (const auto& point_set : snapshot.metrics->point_set_map) { - const std::string& metric_name = point_set.first; - // Each metric has multiple points corresponding to different labels.
- for (const auto& point : point_set.second->points) { - // Generates one KPI metric for each point. - std::string stat_name = ConstructXStatName(metric_name, *point); - auto* metadata = xplane.GetOrCreateStatMetadata(stat_name); - auto it = metric_descriptor_map.find(metric_name); - if (it != metric_descriptor_map.end()) { - metadata->set_description(it->second->description); - } - switch (point->value_type) { - case monitoring::ValueType::kInt64: - xevent.AddStatValue(*metadata, point->int64_value); - break; - case monitoring::ValueType::kBool: - xevent.AddStatValue(*metadata, point->bool_value); - break; - case monitoring::ValueType::kString: - xevent.AddStatValue(*metadata, *xplane.GetOrCreateStatMetadata( - point->string_value)); - break; - case monitoring::ValueType::kDouble: - xevent.AddStatValue(*metadata, point->double_value); - break; - case monitoring::ValueType::kHistogram: - xevent.AddStatValue(*metadata, point->histogram_value); - break; - case monitoring::ValueType::kPercentiles: - xevent.AddStatValue(*metadata, ToProto(point->percentiles_value)); - break; - } - } - } - } - return absl::OkStatus(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.h b/tensorflow/core/profiler/utils/tfstreamz_utils.h index abaafbc6e3c990..dffca153ac07b0 100644 --- a/tensorflow/core/profiler/utils/tfstreamz_utils.h +++ b/tensorflow/core/profiler/utils/tfstreamz_utils.h @@ -15,27 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_ -#include -#include - -#include "absl/status/status.h" -#include "tensorflow/core/lib/monitoring/collected_metrics.h" -#include "tensorflow/core/platform/types.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -struct TfStreamzSnapshot { - std::unique_ptr metrics; - uint64 start_time_ns; // time before collection. - uint64 end_time_ns; // time after collection. -}; - -absl::Status SerializeToXPlane(const std::vector& snapshots, - XPlane* plane, uint64 line_start_time_ns); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/tfstreamz_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h b/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h index 731481a4da8612..e803bbc1b41244 100644 --- a/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h +++ b/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The OpenXLA Authors. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,61 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ -#include - -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" - -namespace tensorflow { -namespace profiler { - -// Total duration of infeed from host or SparseCoreV0 to TensorCore. -inline uint64_t InfeedDurationPs(const TpuStepBreakdown& tpu) { - return tpu.infeed_duration_ps() + tpu.wait_for_scv0_duration_ps() + - tpu.scv0_infeed_transform_ps(); -} - -// Total duration of outfeed from TensorCore to host or SparseCoreV0. 
-inline uint64_t OutfeedDurationPs(const TpuStepBreakdown& tpu) { - return tpu.host_outfeed_ps() + tpu.scv0_outfeed_ps(); -} - -// Total duration of infeed from host to SparseCoreV0. -inline uint64_t ScV0InfeedDurationPs(const TpuStepBreakdown& tpu) { - return tpu.wait_for_scv0_duration_ps() * tpu.scv0_infeed_percent() / 100.0; -} - -// Total duration of SparseCoreV0 compute. -inline uint64_t ScV0ComputeDurationPs(const TpuStepBreakdown& tpu) { - return tpu.wait_for_scv0_duration_ps() - ScV0InfeedDurationPs(tpu); -} - -// Total duration of infeed from host to TensorCore or SparseCoreV0. -inline uint64_t TcPlusScV0InfeedDurationPs(const TpuStepBreakdown& tpu) { - return tpu.infeed_duration_ps() + ScV0InfeedDurationPs(tpu); -} - -// Total duration of send and recv ops. -inline uint64_t SendRecvDurationPs(const TpuStepBreakdown& tpu) { - return tpu.send_duration_ps() + tpu.recv_duration_ps(); -} - -// Total duration of host send and host recv ops. -inline uint64_t HostSendRecvDurationPs(const TpuStepBreakdown& tpu) { - return tpu.host_send_duration_ps() + tpu.host_recv_duration_ps(); -} - -// Total duration TensorCore spends waiting for host. -inline uint64_t WaitForHostDurationPs(const TpuStepBreakdown& tpu) { - return tpu.infeed_duration_ps() + tpu.host_outfeed_ps() + - HostSendRecvDurationPs(tpu) + tpu.tc_idle_ps(); -} - -// Total duration TensorCore spends waiting for host or SparseCoreV0. -inline uint64_t WaitForHostOrScV0DurationPs(const TpuStepBreakdown& tpu) { - return WaitForHostDurationPs(tpu) + tpu.wait_for_scv0_duration_ps(); -} - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/tpu_step_breakdown_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/tpu_step_details_utils.h b/tensorflow/core/profiler/utils/tpu_step_details_utils.h index 23c1609dc797b7..8ce4f3a2bef490 100644 --- a/tensorflow/core/profiler/utils/tpu_step_details_utils.h +++ b/tensorflow/core/profiler/utils/tpu_step_details_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The OpenXLA Authors. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,35 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ -#include "tensorflow/core/profiler/protobuf/tpu_input_pipeline.pb.h" - -namespace tensorflow { -namespace profiler { - -inline double ComputeTimeMs(const PerTpuStepDetails& details) { - return details.tc_compute_time_ms() + details.scv0_compute_time_ms(); -} - -inline double InfeedTimeMs(const PerTpuStepDetails& details) { - return details.tc_infeed_time_ms() + details.scv0_infeed_time_ms(); -} - -inline double AllReduceTimeMs(const PerTpuStepDetails& details) { - return details.all_reduce_compute_time_ms() + - details.all_reduce_sync_time_ms(); -} - -inline double NonIdleTimeMs(const PerTpuStepDetails& details) { - return ComputeTimeMs(details) + InfeedTimeMs(details) + - AllReduceTimeMs(details) + details.tc_outfeed_time_ms(); -} - -// Time spent by a training step on TPU. 
-inline double StepTimeMs(const PerTpuStepDetails& details) { - return NonIdleTimeMs(details) + details.tc_idle_time_ms(); -} - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/tpu_step_details_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc deleted file mode 100644 index 321cf041502c7b..00000000000000 --- a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h" - -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/primitive_util.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/tsl/platform/errors.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -std::vector GetInputBitwidths(const xla::HloInstruction& hlo) { - std::vector input_bitwidths; - for (const auto& operand : hlo.operands()) { - switch (operand->shape().element_type()) { - case xla::PRIMITIVE_TYPE_INVALID: - case xla::TUPLE: - case xla::OPAQUE_TYPE: - case xla::TOKEN: - break; - default: - input_bitwidths.push_back( - xla::primitive_util::BitWidth(operand->shape().element_type())); - } - } - return input_bitwidths; -} - -} // namespace - -absl::Status XProfGpuCostAnalysis::HandleCustomCall( - const xla::HloInstruction* hlo) { - TF_RETURN_IF_ERROR(xla::gpu::GpuHloCostAnalysis::HandleCustomCall(hlo)); - - if (xla::gpu::IsCublasGemm(*hlo)) { - // The naming conventions and meanings of gemm parameters are documented at: - // https://docs.nvidia.com/cuda/cublas/index.html#using-the-cublaslt-api - // As inherited from GpuHloCostAnalysis, we only normalize the flops based - // on the datatypes of A and B, which are assumed to have the same bitwidth. - int dot_operands_bitwidth = - xla::primitive_util::BitWidth(hlo->operand(0)->shape().element_type()); - uint32_t flop_rate_adjustment = 1; - switch (dot_operands_bitwidth) { - case 8: - flop_rate_adjustment = 2; - break; - case 4: - flop_rate_adjustment = 4; - break; - default: - break; - } - float model_flops = current_properties_[kFlopsKey]; - current_properties_[kDeviceFlopsAdjustment] = - model_flops - model_flops / flop_rate_adjustment; - } - return absl::OkStatus(); -} - -absl::Status XProfGpuCostAnalysis::DefaultPostprocess( - const xla::HloInstruction* hlo) { - uint32_t flop_rate_adjustment = 1; - float model_flops = current_properties_[kFlopsKey]; - - // Calculates the adjustment of device flops based on input bit widths.
- // This provides the most general adjustment for all ops and all GPUs. - std::vector input_bitwidths = GetInputBitwidths(*hlo); - if (!input_bitwidths.empty()) { - int max_input_bitwidth = - *std::max_element(input_bitwidths.begin(), input_bitwidths.end()); - if (model_flops) { - // For int8/fp8, 2x flops is assumed relative to fp16 flops (as on most - // recent GPU models); for int4, 4x of model flops is assumed relative - // to fp16 flops (as on Nvidia T4, 3090). It will be more precise - // after adjustment based on the specific GPUs mentioned above. - switch (max_input_bitwidth) { - case 8: - flop_rate_adjustment = 2; - break; - case 4: - flop_rate_adjustment = 4; - break; - default: - break; - } - } - } - current_properties_[kDeviceFlopsAdjustment] = - model_flops - model_flops / flop_rate_adjustment; - return absl::OkStatus(); -} - -absl::Status XProfGpuCostAnalysis::Postprocess(const xla::HloInstruction* hlo) { - if (hlo == nullptr) { - return absl::OkStatus(); - } - - switch (hlo->opcode()) { - case xla::HloOpcode::kCustomCall: - // Already handled specially in HandleCustomCall(), skip here. - // Add more opcodes here if they are handled specially in the future. - break; - default: - DefaultPostprocess(hlo).IgnoreError(); - break; - } - - return xla::gpu::GpuHloCostAnalysis::Postprocess(hlo); -} - -std::unique_ptr -XProfGpuCostAnalysis::CreateNestedCostAnalysis() { - return std::make_unique(options_); -} - -int64_t XProfGpuCostAnalysis::GetDeviceFlopsAdjustment( - const xla::HloInstruction& hlo) { - return GetPropertyForHlo(hlo, kDeviceFlopsAdjustment, hlo_properties_); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h index 76b50f5997d9c6..3814be42d65646 100644 --- a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h +++ b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h @@ -16,42 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/hlo_cost_analysis.h" - -namespace tensorflow { -namespace profiler { - -// XProfGpuCostAnalysis provides additional cost analysis for XProf, which -// normalizes the flops to the device flops based on input bit widths.
-class XProfGpuCostAnalysis : public xla::gpu::GpuHloCostAnalysis { - public: - explicit XProfGpuCostAnalysis(const xla::HloCostAnalysis::Options& options) - : xla::gpu::GpuHloCostAnalysis(options) {} - - absl::Status HandleCustomCall(const xla::HloInstruction* hlo) override; - - absl::Status Postprocess(const xla::HloInstruction* hlo) override; - - int64_t GetDeviceFlopsAdjustment(const xla::HloInstruction& hlo); - - protected: - std::unique_ptr CreateNestedCostAnalysis() override; - - absl::Status DefaultPostprocess(const xla::HloInstruction* hlo); - - private: - static inline constexpr absl::string_view kDeviceFlopsAdjustment = - "device_flops_adjustment"; -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/xprof_gpu_cost_analysis.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc deleted file mode 100644 index c71d1a9dfb5730..00000000000000 --- a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h" - -#include - -#include -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/testlib/test_helpers.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/xla_data.pb.h" - -namespace tensorflow { -namespace profiler { - -class XprofGpuHloCostAnalysisTest : public xla::HloTestBase { - xla::HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const xla::Shape& shape) { - constexpr int64_t kPointerSize = 8; - return xla::ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - - public: - xla::HloCostAnalysis::Options options_{ - ShapeSizeBytesFunction(), - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}; - XProfGpuCostAnalysis analysis_{options_}; - XprofGpuHloCostAnalysisTest() : xla::HloTestBase() {} -}; - -TEST_F(XprofGpuHloCostAnalysisTest, Fp16GemmNoAdjustment) { - absl::string_view hlo_string = R"( -HloModule r - -ENTRY e { - arg0 = f16[65536,32800] parameter(0) - arg1 = f16[32800,32] parameter(1) - gemm = (f16[65536,32], s8[0]) custom-call(arg0, arg1), - custom_call_target="__cublas$gemm", - backend_config="{ - \"gemm_backend_config\": { - \"alpha_real\":1, - \"beta\":0, - \"dot_dimension_numbers\":{ - \"lhs_contracting_dimensions\":[\"1\"], - \"rhs_contracting_dimensions\":[\"0\"], - \"lhs_batch_dimensions\":[], - \"rhs_batch_dimensions\":[] - }, - \"alpha_imag\":0, - \"precision_config\":{ - \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] - }, - \"epilogue\":\"DEFAULT\" - } - }" - ROOT get-tuple-element = f16[65536,32] - get-tuple-element((f16[65536,32], s8[0]) gemm), index=0 -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); - xla::HloComputation* comp = module->entry_computation(); - const xla::HloInstruction* fp16gemm = comp->GetInstructionWithName("gemm"); - // flops of gemm A * B = rows(A) * cols(B) * cols(A) * 2 - // where 2 is for the add and multiply - int64_t gold_flops = 65536LL * 32800 * 32 * 2; - EXPECT_EQ(analysis_.flop_count(*fp16gemm), gold_flops); - EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*fp16gemm), 0); -} - -TEST_F(XprofGpuHloCostAnalysisTest, S8GemmAdjustment) { - absl::string_view hlo_string = R"( -HloModule r - -ENTRY e { - arg0 = s8[65536,32800] parameter(0) - arg1 = s8[32800,32] parameter(1) - gemm = (s32[65536,32], s8[0]) custom-call(arg0, arg1), - custom_call_target="__cublas$gemm", - backend_config="{ - \"gemm_backend_config\": { - \"alpha_real\":1, - \"beta\":0, - \"dot_dimension_numbers\":{ - \"lhs_contracting_dimensions\":[\"1\"], - \"rhs_contracting_dimensions\":[\"0\"], - \"lhs_batch_dimensions\":[], - \"rhs_batch_dimensions\":[] - }, - \"alpha_imag\":0, - \"precision_config\":{ - \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] - }, - \"epilogue\":\"DEFAULT\" - } - }" - ROOT get-tuple-element = s32[65536,32] - get-tuple-element((s32[65536,32], s8[0]) gemm), index=0 -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); - xla::HloComputation* comp = module->entry_computation(); - const 
xla::HloInstruction* s8gemm = comp->GetInstructionWithName("gemm"); - int64_t gold_flops = 65536LL * 32800 * 32 * 2; - EXPECT_EQ(analysis_.flop_count(*s8gemm), gold_flops); - // Matmul of int8 * int8 -> int32, normalized to equivalent fp16 flops by - // dividing by 2 as all inputs are 8 bits. - EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*s8gemm), gold_flops / 2); -} - -// Tests the special handling logic when an fp32 parameter is also used. -TEST_F(XprofGpuHloCostAnalysisTest, Fp8GemmWithFp32ParameterAdjustment) { - absl::string_view hlo_string = R"( -HloModule r - -ENTRY e { - arg0 = f8e4m3fn[2048,5120]{1,0} parameter(0) - arg1 = f8e4m3fn[5120,5120]{0,1} parameter(1) - arg2 = f32[] parameter(2) - arg3 = f32[] parameter(3) - gemm = (bf16[2048,5120]{1,0}, s8[33554432]{0}) - custom-call(arg0, arg1, arg2, arg3), - custom_call_target="__cublas$lt$matmul$f8", - backend_config="{ - \"gemm_backend_config\": { - \"alpha_real\":1, - \"beta\":0, - \"dot_dimension_numbers\":{ - \"lhs_contracting_dimensions\":[\"1\"], - \"rhs_contracting_dimensions\":[\"0\"], - \"lhs_batch_dimensions\":[], - \"rhs_batch_dimensions\":[] - }, - \"alpha_imag\":0, - \"precision_config\":{ - \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] - }, - \"epilogue\":\"DEFAULT\", - \"lhs_stride\":\"10485760\", - \"rhs_stride\":\"26214400\", - \"grad_x\":false, - \"grad_y\":false, - \"damax_output\":false - } - }" - ROOT get-tuple-element = bf16[2048,5120]{1,0} - get-tuple-element((bf16[2048,5120]{1,0}, s8[33554432]{0}) gemm), index=0 -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); - xla::HloComputation* comp = module->entry_computation(); - const xla::HloInstruction* fp8_gemm = comp->GetInstructionWithName("gemm"); - int64_t gold_flops = 2048LL * 5120 * 5120 * 2; - EXPECT_EQ(analysis_.flop_count(*fp8_gemm), gold_flops); - // Matmul of fp8 * fp8 -> bf16, normalized to equivalent fp16 flops by - // dividing by 2 as both dot operands are 8 bits; the fp32 scale - // parameters do not affect the adjustment. - EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*fp8_gemm), gold_flops / 2); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 4daa7388dd0dd8..e8570a3754381b 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -937,9 +937,9 @@ def _tf_repositories(): tf_http_archive( name = "org_xprof", - sha256 = "11ce459e0ca28779f09065b9376c4ed141de4ed3cbabc309f8f5a8a5119a7c2f", - strip_prefix = "profiler-eaa0840079b0f2815a27cef8d250c523310631e4", - urls = tf_mirror_urls("https://github.com/tensorflow/profiler/archive/eaa0840079b0f2815a27cef8d250c523310631e4.zip"), + sha256 = "88bc65694f79f266e16269da73b5b9238db1552175d1cd75bc08c7337377ab0d", + strip_prefix = "profiler-5d90906294ecbd83639b583fc926cbedc06e60dc", + urls = tf_mirror_urls("https://github.com/tensorflow/profiler/archive/5d90906294ecbd83639b583fc926cbedc06e60dc.zip"), ) # used for adding androidx.annotation dependencies in tflite android jni.
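A note on the flops normalization exercised by the two cost-analysis tests above: the adjustment is recorded as model_flops - model_flops / rate, so with the 2x rate assumed for 8-bit inputs the adjustment is exactly half of the model flops, which is what both tests assert. The fragment below is a minimal standalone sketch of that arithmetic; it is illustrative only and not part of this patch, and the function name and rate table are restated here purely for the example.

// Standalone sketch (not part of the patch): reproduces the device-flops
// adjustment arithmetic from XProfGpuCostAnalysis, assuming the same rate
// table as above (8-bit inputs -> 2x fp16 throughput, 4-bit -> 4x).
#include <cstdint>
#include <iostream>

int64_t DeviceFlopsAdjustment(int64_t model_flops, int max_input_bitwidth) {
  int64_t rate = 1;
  if (max_input_bitwidth == 8) rate = 2;
  if (max_input_bitwidth == 4) rate = 4;
  // Device flops = model_flops / rate; the adjustment is the difference.
  return model_flops - model_flops / rate;
}

int main() {
  // Same shape as the S8GemmAdjustment test: a [65536,32800] x [32800,32] GEMM.
  const int64_t gold_flops = 65536LL * 32800 * 32 * 2;
  std::cout << DeviceFlopsAdjustment(gold_flops, 8) << "\n";  // gold_flops / 2
  return 0;
}

The same arithmetic explains the fp8 test: the custom-call handler looks only at the 8-bit dot operands, so the extra fp32 scale parameters do not pull the rate back down to 1x.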
diff --git a/third_party/xla/xla/tsl/lib/gtl/BUILD b/third_party/xla/xla/tsl/lib/gtl/BUILD index 9bb44b99691f43..de5643bee7b310 100644 --- a/third_party/xla/xla/tsl/lib/gtl/BUILD +++ b/third_party/xla/xla/tsl/lib/gtl/BUILD @@ -32,6 +32,9 @@ package( "//xla/tsl/distributed_runtime/rpc:__pkg__", "//xla/tsl/profiler/utils:__pkg__", "//tensorflow/core/profiler/convert:__pkg__", + # xprof /convert and /utils uses map_util and top_n + "//third_party/xprof/convert:__pkg__", + "//third_party/xprof/utils:__pkg__", ]), licenses = ["notice"], ) diff --git a/third_party/xla/xla/tsl/lib/monitoring/BUILD b/third_party/xla/xla/tsl/lib/monitoring/BUILD index c5d20521f13b0a..cee06b4f516f9d 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/BUILD +++ b/third_party/xla/xla/tsl/lib/monitoring/BUILD @@ -28,8 +28,8 @@ package( "//xla/tsl/distributed_runtime:__subpackages__", "//tensorflow/compiler/mlir/tf2xla:__subpackages__", "//tensorflow_serving/model_servers:__subpackages__", - # xprof/pywrap depends on this package - "//third_party/xprof/pywrap:__pkg__", + # xprof depends on this package + "//third_party/xprof:__subpackages__", "//tensorflow/python/profiler/internal:__pkg__", ]), licenses = ["notice"], From 02ebe2bc63a8d5b93ee80f3c945376eadc819f07 Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Fri, 28 Mar 2025 20:13:41 -0700 Subject: [PATCH 0009/1324] Fix IfrtLoadVariable is deleted while still in use PiperOrigin-RevId: 741724411 --- .../mlrt/rewrite_ifrt_load_variable.mlir | 23 ++++++++++++++++++ .../mlrt/rewrite_ifrt_load_variable.cc | 24 ++++++++++++++----- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir index a862e6abf7274f..fa2ec0b14c8166 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir @@ -18,3 +18,26 @@ %2 = "tf.IfrtCall"(%arg0, %array_key) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> {__tpu_compile_metadata_text = "retvals { sharding { } }"} : (tensor<1x3xf32>, tensor) -> tensor<1x1xf32> return %2 : tensor<1x1xf32> } + + +// ----- +// Variable is used by two CPU ops +// +// CHECK-LABEL: func @serving_default +// CHECK-NEXT: [[HANDLE:%.*]] = "tf.VarHandleOp"() +// CHECK-NEXT: [[ARRAYKEY:%.*]], [[FURTURE:%.*]] = "tf_mlrt.tf_ifrt_load_variable"([[HANDLE]]) +// CHECK-SAME: <{used_by_host = true}> : (tensor>>) -> (tensor, !mlrt.future) +// CHECK: [[TENSOR:%.*]] = "tf_mlrt.tf_await"([[FURTURE]]) : (!mlrt.future) -> tensor<3x1xf32> +// CHECK-NEXT: "tf.AddV2"([[TENSOR]], %cst) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> +// CHECK-NEXT: "tf.Sub"([[TENSOR]], %cst) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> +// CHECK-NEXT: return +// + func.func @serving_default() { + %0 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %array_key, %tensor = "tf.IfrtLoadVariable"(%0) <{used_by_host = true}> : (tensor>>) -> (tensor, tensor<3x1xf32>) + %cst_24 = "tf.Const"() <{value = dense<[[0.0], [1.0], [2.0]]> : tensor<3x1xf32>}> : () -> tensor<3x1xf32> + %1 = "tf.AddV2"(%tensor, %cst_24) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> + %2 = "tf.Sub"(%tensor, %cst_24) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> + + return + } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc 
b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc index 368a91ac54f955..98058a3b32028c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -75,16 +76,27 @@ class RewriteIfrtLoadVariablePass builder.create( load_variable_op->getLoc(), result_types, load_variable_op->getOperands(), load_variable_op->getAttrs()); - for (auto user : load_variable_op.getTensorFuture().getUsers()) { - builder.setInsertionPoint(user); - auto await_op = builder.create( - user->getLoc(), load_variable_op.getTensorFuture().getType(), - mlrt_load_variable_op.getTensorFuture()); + tf_mlrt::TFAwaitOp await_op; + for (auto user : llvm::make_early_inc_range( + load_variable_op.getTensorFuture().getUsers())) { + // Materialize the future for the first use. Reuse it for the rest of + // the uses. + if (!await_op) { + builder.setInsertionPoint(user); + await_op = builder.create( + user->getLoc(), load_variable_op.getTensorFuture().getType(), + mlrt_load_variable_op.getTensorFuture()); + } else { + if (user->isBeforeInBlock(await_op)) { + await_op->moveBefore(user); + } + } user->replaceUsesOfWith(load_variable_op.getTensorFuture(), await_op.getResult()); } - for (auto user : load_variable_op.getArrayKey().getUsers()) { + for (auto user : llvm::make_early_inc_range( + load_variable_op.getArrayKey().getUsers())) { user->replaceUsesOfWith(load_variable_op.getArrayKey(), mlrt_load_variable_op.getArrayKey()); } From 1610bd6b4d0a62598155f1a1b14bcb36514f654b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 28 Mar 2025 20:38:59 -0700 Subject: [PATCH 0010/1324] Automated Code Change PiperOrigin-RevId: 741729560 --- tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc | 4 ++++ tensorflow/lite/kernels/ctc/ctc_beam_search_decoder_test.cc | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc index 05f242397f44d9..5046ace8c07eb7 100644 --- a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc +++ b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include +#include +#include +#include #include #include "flatbuffers/flexbuffers.h" // from @flatbuffers diff --git a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder_test.cc b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder_test.cc index 0ffacfbae80bc8..3e3591f6471f02 100644 --- a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder_test.cc +++ b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder_test.cc @@ -14,9 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include +#include #include +#include #include #include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/core/interpreter.h" From c766545c086b2926c28f44d55c6b67285b92212c Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 28 Mar 2025 21:45:42 -0700 Subject: [PATCH 0011/1324] Automated Code Change PiperOrigin-RevId: 741741192 --- .../experimental/acceleration/mini_benchmark/c/c_api.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc index 51c927836135dc..1e8401aeb6f35e 100644 --- a/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h" -#include +#include +#include #include #include -#include #include #include @@ -25,9 +25,9 @@ limitations under the License. #include "flatbuffers/verifier.h" // from @flatbuffers #include "tensorflow/lite/acceleration/configuration/c/delegate_plugin.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/benchmark_result_evaluator.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/blocking_validator_runner.h" -#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_options.h" From 5475b2061a1150fe0a82513577e8ca0e73ccba83 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 29 Mar 2025 02:02:39 -0700 Subject: [PATCH 0012/1324] compat: Update forward compatibility horizon to 2025-03-29 PiperOrigin-RevId: 741780941 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index c38d60326e1ef6..c18dd57b46e9c3 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 3, 28) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 3, 29) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From af882fe7414c42fa78c14c9779e23b405cc3938d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 29 Mar 2025 02:02:40 -0700 Subject: [PATCH 0013/1324] Update GraphDef version to 2181. PiperOrigin-RevId: 741780944 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index ddcfd2d0bec899..2e0cb08f8b384a 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. 
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2180 // Updated: 2025/3/28 +#define TF_GRAPH_DEF_VERSION 2181 // Updated: 2025/3/29 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 3c4f2c5e71aec24dd12f21985ce65ade3dce8f02 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 29 Mar 2025 10:44:37 -0700 Subject: [PATCH 0014/1324] Integrate LLVM at llvm/llvm-project@c0952a931c7d Updates LLVM usage to match [c0952a931c7d](https://github.com/llvm/llvm-project/commit/c0952a931c7d) PiperOrigin-RevId: 741850272 --- third_party/llvm/generated.patch | 545 ++++++++ third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 1149 ++++++++--------- third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/shardy/temporary.patch | 1149 ++++++++--------- .../xla/third_party/shardy/workspace.bzl | 4 +- 6 files changed, 1655 insertions(+), 1200 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..99ef3cb5cdd7c0 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,546 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp +--- a/clang/lib/Driver/ToolChains/Clang.cpp ++++ b/clang/lib/Driver/ToolChains/Clang.cpp +@@ -6397,7 +6397,9 @@ + Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, + options::OPT_fno_convergent_functions); + +- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); ++ // NVPTX doesn't support PGO or coverage ++ if (!Triple.isNVPTX()) ++ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); + + Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); + +diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu +--- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu ++++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu +@@ -0,0 +1,33 @@ ++// Check that profiling/coverage arguments doen't get passed down to device-side ++// compilation. 
++//
++//
++// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
++// XRUN:   -fprofile-generate %s 2>&1 | \
++// XRUN:   FileCheck --check-prefixes=CHECK,PROF %s
++//
++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
++// RUN:   -fprofile-instr-generate %s 2>&1 | \
++// RUN:   FileCheck --check-prefixes=CHECK,PROF %s
++//
++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
++// RUN:   -coverage %s 2>&1 | \
++// RUN:   FileCheck --check-prefixes=CHECK,GCOV %s
++//
++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
++// RUN:   -ftest-coverage %s 2>&1 | \
++// RUN:   FileCheck --check-prefixes=CHECK,GCOV %s
++//
++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
++// RUN:   -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \
++// RUN:   FileCheck --check-prefixes=CHECK,PROF %s
++//
++//
++// CHECK-NOT: error: unsupported option '-fprofile
++// CHECK-NOT: error: invalid argument
++// CHECK-DAG: "-fcuda-is-device"
++// CHECK-NOT: "-f{{[^"/]*coverage.*}}"
++// CHECK-NOT: "-fprofile{{[^"]*}}"
++// CHECK: "-triple" "x86_64-unknown-linux-gnu"
++// PROF: "-fprofile{{.*}}"
++// GCOV: "-coverage-notes-file=
+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp
+--- a/lldb/tools/lldb-dap/DAP.cpp
++++ b/lldb/tools/lldb-dap/DAP.cpp
+@@ -711,12 +711,12 @@
+             [](const std::string &message) -> llvm::StringRef {
+               return message;
+             },
+-            [](const protocol::Response::Message &message)
++            [](const protocol::ResponseMessage &message)
+                 -> llvm::StringRef {
+               switch (message) {
+-              case protocol::Response::Message::cancelled:
++              case protocol::eResponseMessageCancelled:
+                 return "cancelled";
+-              case protocol::Response::Message::notStopped:
++              case protocol::eResponseMessageNotStopped:
+                 return "notStopped";
+               }
+             }),
+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp
+--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp
++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp
+@@ -7,6 +7,7 @@
+ //===----------------------------------------------------------------------===//
+ 
+ #include "Protocol/ProtocolBase.h"
++#include "lldb/lldb-enumerations.h"
+ #include "llvm/ADT/StringRef.h"
+ #include "llvm/ADT/StringSwitch.h"
+ #include "llvm/Support/ErrorHandling.h"
+@@ -31,11 +32,8 @@
+ 
+ namespace lldb_dap::protocol {
+ 
+-enum MessageType {
+-  eMessageTypeRequest,
+-  eMessageTypeResponse,
+-  eMessageTypeEvent
+-};
++FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse,
++                        eMessageTypeEvent};
+ 
+ bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) {
+   auto rawType = Params.getAsString();
+@@ -107,12 +105,12 @@
+ 
+   if (R.message) {
+     assert(!R.success && "message can only be used if success is false");
+-    if (const auto *messageEnum = std::get_if<Response::Message>(&*R.message)) {
++    if (const auto *messageEnum = std::get_if<ResponseMessage>(&*R.message)) {
+       switch (*messageEnum) {
+-      case Response::Message::cancelled:
++      case eResponseMessageCancelled:
+         Result.insert({"message", "cancelled"});
+         break;
+-      case Response::Message::notStopped:
++      case eResponseMessageNotStopped:
+         Result.insert({"message", "notStopped"});
+         break;
+       }
+@@ -129,16 +127,16 @@
+ }
+ 
+ bool fromJSON(json::Value const &Params,
+-              std::variant<std::string, Response::Message> &M, json::Path P) {
++              std::variant<std::string, ResponseMessage> &M, json::Path P) {
+   auto rawMessage = Params.getAsString();
+   if (!rawMessage) {
+     P.report("expected a string");
+     return false;
+   }
+-  std::optional<Response::Message> message =
+-      StringSwitch<std::optional<Response::Message>>(*rawMessage)
+-          .Case("cancelled", Response::Message::cancelled)
+-          .Case("notStopped", Response::Message::notStopped)
++  std::optional<ResponseMessage> message =
++      StringSwitch<std::optional<ResponseMessage>>(*rawMessage)
++          .Case("cancelled", eResponseMessageCancelled)
++          .Case("notStopped", eResponseMessageNotStopped)
+           .Default(std::nullopt);
+   if (message)
+     M = *message;
+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h
+--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h
++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h
+@@ -20,6 +20,7 @@
+ #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H
+ #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H
+ 
++#include "lldb/lldb-enumerations.h"
+ #include "llvm/Support/JSON.h"
+ #include
+ #include
+@@ -64,15 +65,15 @@
+ llvm::json::Value toJSON(const Event &);
+ bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path);
+ 
+-/// Response for a request.
+-struct Response {
+-  enum class Message {
++FLAGS_ENUM(ResponseMessage){
+     /// The request was cancelled
+-    cancelled,
++    eResponseMessageCancelled,
+     /// The request may be retried once the adapter is in a 'stopped' state
+-    notStopped,
+-  };
++    eResponseMessageNotStopped,
++};
+ 
++/// Response for a request.
++struct Response {
+   /// Sequence number of the corresponding request.
+   int64_t request_seq;
+ 
+@@ -90,7 +91,7 @@
+   /// Contains the raw error in short form if `success` is false. This raw error
+   /// might be interpreted by the client and is not shown in the UI. Some
+   /// predefined values exist.
+-  std::optional<std::variant<std::string, Message>> message;
++  std::optional<std::variant<std::string, ResponseMessage>> message;
+ 
+   /// Contains request result if success is true and error details if success is
+   /// false.
+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h
+--- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h
++++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h
+@@ -22,6 +22,8 @@
+ 
+ #include "Protocol/ProtocolBase.h"
+ #include "Protocol/ProtocolTypes.h"
++#include "lldb/lldb-enumerations.h"
++#include "llvm/ADT/DenseSet.h"
+ #include "llvm/Support/JSON.h"
+ #include
+ #include
+@@ -55,26 +57,26 @@
+ using DisconnectResponse = VoidResponse;
+ 
+ /// Features supported by DAP clients.
+-enum ClientFeature {
+-  eClientFeatureVariableType,
+-  eClientFeatureVariablePaging,
+-  eClientFeatureRunInTerminalRequest,
+-  eClientFeatureMemoryReferences,
+-  eClientFeatureProgressReporting,
+-  eClientFeatureInvalidatedEvent,
+-  eClientFeatureMemoryEvent,
+-  /// Client supports the `argsCanBeInterpretedByShell` attribute on the
+-  /// `runInTerminal` request.
+-  eClientFeatureArgsCanBeInterpretedByShell,
+-  eClientFeatureStartDebuggingRequest,
+-  /// The client will interpret ANSI escape sequences in the display of
+-  /// `OutputEvent.output` and `Variable.value` fields when
+-  /// `Capabilities.supportsANSIStyling` is also enabled.
+-  eClientFeatureANSIStyling,
++FLAGS_ENUM(ClientFeature){
++    eClientFeatureVariableType,
++    eClientFeatureVariablePaging,
++    eClientFeatureRunInTerminalRequest,
++    eClientFeatureMemoryReferences,
++    eClientFeatureProgressReporting,
++    eClientFeatureInvalidatedEvent,
++    eClientFeatureMemoryEvent,
++    /// Client supports the `argsCanBeInterpretedByShell` attribute on the
++    /// `runInTerminal` request.
++ eClientFeatureArgsCanBeInterpretedByShell, ++ eClientFeatureStartDebuggingRequest, ++ /// The client will interpret ANSI escape sequences in the display of ++ /// `OutputEvent.output` and `Variable.value` fields when ++ /// `Capabilities.supportsANSIStyling` is also enabled. ++ eClientFeatureANSIStyling, + }; + + /// Format of paths reported by the debug adapter. +-enum PathFormat { ePatFormatPath, ePathFormatURI }; ++FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; + + /// Arguments for `initialize` request. + struct InitializeRequestArguments { +diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +--- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h ++++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +@@ -20,6 +20,7 @@ + #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H + #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H + ++#include "lldb/lldb-enumerations.h" + #include "llvm/ADT/DenseSet.h" + #include "llvm/Support/JSON.h" + #include +@@ -56,12 +57,8 @@ + }; + llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); + +-enum ColumnType { +- eColumnTypeString, +- eColumnTypeNumber, +- eColumnTypeBoolean, +- eColumnTypeTimestamp +-}; ++FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, ++ eColumnTypeTimestamp}; + + /// A ColumnDescriptor specifies what module attribute to show in a column of + /// the modules view, how to format it, and what the column’s label should be. +@@ -90,27 +87,23 @@ + + /// Names of checksum algorithms that may be supported by a debug adapter. + /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. +-enum ChecksumAlgorithm { +- eChecksumAlgorithmMD5, +- eChecksumAlgorithmSHA1, +- eChecksumAlgorithmSHA256, +- eChecksumAlgorithmTimestamp +-}; ++FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, ++ eChecksumAlgorithmSHA256, ++ eChecksumAlgorithmTimestamp}; + llvm::json::Value toJSON(const ChecksumAlgorithm &); + + /// Describes one or more type of breakpoint a BreakpointMode applies to. This + /// is a non-exhaustive enumeration and may expand as future breakpoint types + /// are added. +-enum BreakpointModeApplicability { +- /// In `SourceBreakpoint`'s. +- eBreakpointModeApplicabilitySource, +- /// In exception breakpoints applied in the `ExceptionFilterOptions`. +- eBreakpointModeApplicabilityException, +- /// In data breakpoints requested in the `DataBreakpointInfo` request. +- eBreakpointModeApplicabilityData, +- /// In `InstructionBreakpoint`'s. +- eBreakpointModeApplicabilityInstruction +-}; ++FLAGS_ENUM(BreakpointModeApplicability){ ++ /// In `SourceBreakpoint`'s. ++ eBreakpointModeApplicabilitySource, ++ /// In exception breakpoints applied in the `ExceptionFilterOptions`. ++ eBreakpointModeApplicabilityException, ++ /// In data breakpoints requested in the `DataBreakpointInfo` request. ++ eBreakpointModeApplicabilityData, ++ /// In `InstructionBreakpoint`'s. ++ eBreakpointModeApplicabilityInstruction}; + llvm::json::Value toJSON(const BreakpointModeApplicability &); + + /// A `BreakpointMode` is provided as a option when setting breakpoints on +@@ -133,101 +126,101 @@ + llvm::json::Value toJSON(const BreakpointMode &); + + /// Debug Adapter Features flags supported by lldb-dap. +-enum AdapterFeature { +- /// The debug adapter supports ANSI escape sequences in styling of +- /// `OutputEvent.output` and `Variable.value` fields. 
+- eAdapterFeatureANSIStyling, +- /// The debug adapter supports the `breakpointLocations` request. +- eAdapterFeatureBreakpointLocationsRequest, +- /// The debug adapter supports the `cancel` request. +- eAdapterFeatureCancelRequest, +- /// The debug adapter supports the `clipboard` context value in the +- /// `evaluate` request. +- eAdapterFeatureClipboardContext, +- /// The debug adapter supports the `completions` request. +- eAdapterFeatureCompletionsRequest, +- /// The debug adapter supports conditional breakpoints. +- eAdapterFeatureConditionalBreakpoints, +- /// The debug adapter supports the `configurationDone` request. +- eAdapterFeatureConfigurationDoneRequest, +- /// The debug adapter supports the `asAddress` and `bytes` fields in the +- /// `dataBreakpointInfo` request. +- eAdapterFeatureDataBreakpointBytes, +- /// The debug adapter supports data breakpoints. +- eAdapterFeatureDataBreakpoints, +- /// The debug adapter supports the delayed loading of parts of the stack, +- /// which requires that both the `startFrame` and `levels` arguments and the +- /// `totalFrames` result of the `stackTrace` request are supported. +- eAdapterFeatureDelayedStackTraceLoading, +- /// The debug adapter supports the `disassemble` request. +- eAdapterFeatureDisassembleRequest, +- /// The debug adapter supports a (side effect free) `evaluate` request for +- /// data hovers. +- eAdapterFeatureEvaluateForHovers, +- /// The debug adapter supports `filterOptions` as an argument on the +- /// `setExceptionBreakpoints` request. +- eAdapterFeatureExceptionFilterOptions, +- /// The debug adapter supports the `exceptionInfo` request. +- eAdapterFeatureExceptionInfoRequest, +- /// The debug adapter supports `exceptionOptions` on the +- /// `setExceptionBreakpoints` request. +- eAdapterFeatureExceptionOptions, +- /// The debug adapter supports function breakpoints. +- eAdapterFeatureFunctionBreakpoints, +- /// The debug adapter supports the `gotoTargets` request. +- eAdapterFeatureGotoTargetsRequest, +- /// The debug adapter supports breakpoints that break execution after a +- /// specified number of hits. +- eAdapterFeatureHitConditionalBreakpoints, +- /// The debug adapter supports adding breakpoints based on instruction +- /// references. +- eAdapterFeatureInstructionBreakpoints, +- /// The debug adapter supports the `loadedSources` request. +- eAdapterFeatureLoadedSourcesRequest, +- /// The debug adapter supports log points by interpreting the `logMessage` +- /// attribute of the `SourceBreakpoint`. +- eAdapterFeatureLogPoints, +- /// The debug adapter supports the `modules` request. +- eAdapterFeatureModulesRequest, +- /// The debug adapter supports the `readMemory` request. +- eAdapterFeatureReadMemoryRequest, +- /// The debug adapter supports restarting a frame. +- eAdapterFeatureRestartFrame, +- /// The debug adapter supports the `restart` request. In this case a client +- /// should not implement `restart` by terminating and relaunching the +- /// adapter but by calling the `restart` request. +- eAdapterFeatureRestartRequest, +- /// The debug adapter supports the `setExpression` request. +- eAdapterFeatureSetExpression, +- /// The debug adapter supports setting a variable to a value. +- eAdapterFeatureSetVariable, +- /// The debug adapter supports the `singleThread` property on the execution +- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, +- /// `stepBack`). 
+- eAdapterFeatureSingleThreadExecutionRequests, +- /// The debug adapter supports stepping back via the `stepBack` and +- /// `reverseContinue` requests. +- eAdapterFeatureStepBack, +- /// The debug adapter supports the `stepInTargets` request. +- eAdapterFeatureStepInTargetsRequest, +- /// The debug adapter supports stepping granularities (argument +- /// `granularity`) for the stepping requests. +- eAdapterFeatureSteppingGranularity, +- /// The debug adapter supports the `terminate` request. +- eAdapterFeatureTerminateRequest, +- /// The debug adapter supports the `terminateThreads` request. +- eAdapterFeatureTerminateThreadsRequest, +- /// The debug adapter supports the `suspendDebuggee` attribute on the +- /// `disconnect` request. +- eAdapterFeatureSuspendDebuggee, +- /// The debug adapter supports a `format` attribute on the `stackTrace`, +- /// `variables`, and `evaluate` requests. +- eAdapterFeatureValueFormattingOptions, +- /// The debug adapter supports the `writeMemory` request. +- eAdapterFeatureWriteMemoryRequest, +- /// The debug adapter supports the `terminateDebuggee` attribute on the +- /// `disconnect` request. +- eAdapterFeatureTerminateDebuggee, ++FLAGS_ENUM(AdapterFeature){ ++ /// The debug adapter supports ANSI escape sequences in styling of ++ /// `OutputEvent.output` and `Variable.value` fields. ++ eAdapterFeatureANSIStyling, ++ /// The debug adapter supports the `breakpointLocations` request. ++ eAdapterFeatureBreakpointLocationsRequest, ++ /// The debug adapter supports the `cancel` request. ++ eAdapterFeatureCancelRequest, ++ /// The debug adapter supports the `clipboard` context value in the ++ /// `evaluate` request. ++ eAdapterFeatureClipboardContext, ++ /// The debug adapter supports the `completions` request. ++ eAdapterFeatureCompletionsRequest, ++ /// The debug adapter supports conditional breakpoints. ++ eAdapterFeatureConditionalBreakpoints, ++ /// The debug adapter supports the `configurationDone` request. ++ eAdapterFeatureConfigurationDoneRequest, ++ /// The debug adapter supports the `asAddress` and `bytes` fields in the ++ /// `dataBreakpointInfo` request. ++ eAdapterFeatureDataBreakpointBytes, ++ /// The debug adapter supports data breakpoints. ++ eAdapterFeatureDataBreakpoints, ++ /// The debug adapter supports the delayed loading of parts of the stack, ++ /// which requires that both the `startFrame` and `levels` arguments and the ++ /// `totalFrames` result of the `stackTrace` request are supported. ++ eAdapterFeatureDelayedStackTraceLoading, ++ /// The debug adapter supports the `disassemble` request. ++ eAdapterFeatureDisassembleRequest, ++ /// The debug adapter supports a (side effect free) `evaluate` request for ++ /// data hovers. ++ eAdapterFeatureEvaluateForHovers, ++ /// The debug adapter supports `filterOptions` as an argument on the ++ /// `setExceptionBreakpoints` request. ++ eAdapterFeatureExceptionFilterOptions, ++ /// The debug adapter supports the `exceptionInfo` request. ++ eAdapterFeatureExceptionInfoRequest, ++ /// The debug adapter supports `exceptionOptions` on the ++ /// `setExceptionBreakpoints` request. ++ eAdapterFeatureExceptionOptions, ++ /// The debug adapter supports function breakpoints. ++ eAdapterFeatureFunctionBreakpoints, ++ /// The debug adapter supports the `gotoTargets` request. ++ eAdapterFeatureGotoTargetsRequest, ++ /// The debug adapter supports breakpoints that break execution after a ++ /// specified number of hits. 
++ eAdapterFeatureHitConditionalBreakpoints, ++ /// The debug adapter supports adding breakpoints based on instruction ++ /// references. ++ eAdapterFeatureInstructionBreakpoints, ++ /// The debug adapter supports the `loadedSources` request. ++ eAdapterFeatureLoadedSourcesRequest, ++ /// The debug adapter supports log points by interpreting the `logMessage` ++ /// attribute of the `SourceBreakpoint`. ++ eAdapterFeatureLogPoints, ++ /// The debug adapter supports the `modules` request. ++ eAdapterFeatureModulesRequest, ++ /// The debug adapter supports the `readMemory` request. ++ eAdapterFeatureReadMemoryRequest, ++ /// The debug adapter supports restarting a frame. ++ eAdapterFeatureRestartFrame, ++ /// The debug adapter supports the `restart` request. In this case a client ++ /// should not implement `restart` by terminating and relaunching the ++ /// adapter but by calling the `restart` request. ++ eAdapterFeatureRestartRequest, ++ /// The debug adapter supports the `setExpression` request. ++ eAdapterFeatureSetExpression, ++ /// The debug adapter supports setting a variable to a value. ++ eAdapterFeatureSetVariable, ++ /// The debug adapter supports the `singleThread` property on the execution ++ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, ++ /// `stepBack`). ++ eAdapterFeatureSingleThreadExecutionRequests, ++ /// The debug adapter supports stepping back via the `stepBack` and ++ /// `reverseContinue` requests. ++ eAdapterFeatureStepBack, ++ /// The debug adapter supports the `stepInTargets` request. ++ eAdapterFeatureStepInTargetsRequest, ++ /// The debug adapter supports stepping granularities (argument ++ /// `granularity`) for the stepping requests. ++ eAdapterFeatureSteppingGranularity, ++ /// The debug adapter supports the `terminate` request. ++ eAdapterFeatureTerminateRequest, ++ /// The debug adapter supports the `terminateThreads` request. ++ eAdapterFeatureTerminateThreadsRequest, ++ /// The debug adapter supports the `suspendDebuggee` attribute on the ++ /// `disconnect` request. ++ eAdapterFeatureSuspendDebuggee, ++ /// The debug adapter supports a `format` attribute on the `stackTrace`, ++ /// `variables`, and `evaluate` requests. ++ eAdapterFeatureValueFormattingOptions, ++ /// The debug adapter supports the `writeMemory` request. ++ eAdapterFeatureWriteMemoryRequest, ++ /// The debug adapter supports the `terminateDebuggee` attribute on the ++ /// `disconnect` request. ++ eAdapterFeatureTerminateDebuggee, + }; + + /// Information about the capabilities of a debug adapter. +@@ -268,10 +261,10 @@ + }; + llvm::json::Value toJSON(const Capabilities &); + +-enum PresentationHint { +- ePresentationHintNormal, +- ePresentationHintEmphasize, +- ePresentationHintDeemphasize, ++FLAGS_ENUM(PresentationHint){ ++ ePresentationHintNormal, ++ ePresentationHintEmphasize, ++ ePresentationHintDeemphasize, + }; + + /// A `Source` is a descriptor for source code. 
It is returned from the debug +diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +--- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test ++++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +@@ -1,7 +1,7 @@ + // Header + // + // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) +-// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) ++// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) + // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) + // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) + // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) +diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c +--- a/offload/test/offloading/gpupgo/pgo1.c ++++ b/offload/test/offloading/gpupgo/pgo1.c +@@ -14,7 +14,7 @@ + // RUN: %target_triple.%basename_t.clang.profraw | \ + // RUN: %fcheck-generic --check-prefix="CLANG-PGO" + +-// REQUIRES: gpu ++// REQUIRES: amdgpu + // REQUIRES: pgo + + int test1(int a) { return a / 2; } +diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c +--- a/offload/test/offloading/gpupgo/pgo2.c ++++ b/offload/test/offloading/gpupgo/pgo2.c +@@ -48,7 +48,7 @@ + // RUN: %target_triple.%basename_t.hfdi.profraw \ + // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" + +-// REQUIRES: gpu ++// REQUIRES: amdgpu + // REQUIRES: pgo + + int main() { diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 725480b45c06df..005737af0dd2ac 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "71a977d0d611f3e9f6137a6b8a26b730b2886ce9" - LLVM_SHA256 = "9bdf3ddf45c069248af36080a78b56d839d3aad6f9b727ec1ee1be72682888cc" + LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" + LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 5d85aa5a6d2274..5a732df12a541f 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,611 +1,566 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 2cd88ea..509398d 100644 +index 509398d..99ef3cb 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,591 +1 @@ +@@ -1 +1,546 @@ Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp ----- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --@@ -25557,31 +25557,8 @@ -- if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations)) -- return NarrowBOp; -- --- // If only EXTRACT_SUBVECTOR nodes use the source vector we can --- // simplify it based on the (valid) extractions. 
--- if (!V.getValueType().isScalableVector() && --- llvm::all_of(V->users(), [&](SDNode *Use) { --- return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR && --- Use->getOperand(0) == V; --- })) { --- unsigned NumElts = V.getValueType().getVectorNumElements(); --- APInt DemandedElts = APInt::getZero(NumElts); --- for (SDNode *User : V->users()) { --- unsigned ExtIdx = User->getConstantOperandVal(1); --- unsigned NumSubElts = User->getValueType(0).getVectorNumElements(); --- DemandedElts.setBits(ExtIdx, ExtIdx + NumSubElts); --- } --- if (SimplifyDemandedVectorElts(V, DemandedElts, /*AssumeSingleUse=*/true)) { --- // We simplified the vector operand of this extract subvector. If this --- // extract is not dead, visit it again so it is folded properly. --- if (N->getOpcode() != ISD::DELETED_NODE) --- AddToWorklist(N); --- return SDValue(N, 0); --- } --- } else { --- if (SimplifyDemandedVectorElts(SDValue(N, 0))) --- return SDValue(N, 0); --- } --+ if (SimplifyDemandedVectorElts(SDValue(N, 0))) --+ return SDValue(N, 0); -- -- return SDValue(); -- } --diff -ruN --strip-trailing-cr a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp ----- a/llvm/lib/Target/X86/X86ISelLowering.cpp --+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp --@@ -58823,8 +58823,6 @@ -- -- uint64_t IdxVal = N->getConstantOperandVal(2); -- MVT SubVecVT = SubVec.getSimpleValueType(); --- int VecNumElts = OpVT.getVectorNumElements(); --- int SubVecNumElts = SubVecVT.getVectorNumElements(); -- -- if (Vec.isUndef() && SubVec.isUndef()) -- return DAG.getUNDEF(OpVT); --@@ -58884,9 +58882,10 @@ -- SubVec.getOperand(0).getSimpleValueType() == OpVT && -- (IdxVal != 0 || -- !(Vec.isUndef() || ISD::isBuildVectorAllZeros(Vec.getNode())))) { --- SDValue ExtSrc = SubVec.getOperand(0); -- int ExtIdxVal = SubVec.getConstantOperandVal(1); -- if (ExtIdxVal != 0) { --+ int VecNumElts = OpVT.getVectorNumElements(); --+ int SubVecNumElts = SubVecVT.getVectorNumElements(); -- SmallVector Mask(VecNumElts); -- // First create an identity shuffle mask. -- for (int i = 0; i != VecNumElts; ++i) --@@ -58894,24 +58893,8 @@ -- // Now insert the extracted portion. -- for (int i = 0; i != SubVecNumElts; ++i) -- Mask[i + IdxVal] = i + ExtIdxVal + VecNumElts; --- return DAG.getVectorShuffle(OpVT, dl, Vec, ExtSrc, Mask); --- } --- // If we're broadcasting, see if we can use a blend instead of --- // extract/insert pair. For subvector broadcasts, we must ensure that the --- // subvector is aligned with the insertion/extractions. --- if (ExtSrc.getOpcode() == X86ISD::VBROADCAST || --- ExtSrc.getOpcode() == X86ISD::VBROADCAST_LOAD || --- (ExtSrc.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD && --- (ExtIdxVal % SubVecNumElts) == 0 && (IdxVal % SubVecNumElts) == 0 && --- cast(ExtSrc)->getMemoryVT() == SubVecVT)) { --- SmallVector Mask(VecNumElts); --- // First create an identity shuffle mask. --- for (int i = 0; i != VecNumElts; ++i) --- Mask[i] = i; --- // Now blend the broadcast. --- for (int i = 0; i != SubVecNumElts; ++i) --- Mask[i + IdxVal] = i + IdxVal + VecNumElts; --- return DAG.getVectorShuffle(OpVT, dl, Vec, ExtSrc, Mask); --+ --+ return DAG.getVectorShuffle(OpVT, dl, Vec, SubVec.getOperand(0), Mask); -- } -- } -- --@@ -58959,7 +58942,7 @@ -- // If we're splatting the lower half subvector of a full vector load into the -- // upper half, attempt to create a subvector broadcast. -- // TODO: Drop hasOneUse checks. 
--- if ((int)IdxVal == (VecNumElts / 2) && --+ if (IdxVal == (OpVT.getVectorNumElements() / 2) && -- Vec.getValueSizeInBits() == (2 * SubVec.getValueSizeInBits()) && -- (Vec.hasOneUse() || SubVec.hasOneUse())) { -- auto *VecLd = dyn_cast(Vec); --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast_from_memory.ll b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast_from_memory.ll ----- a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast_from_memory.ll --+++ b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast_from_memory.ll --@@ -2239,7 +2239,7 @@ -- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14] -- ; AVX512F-NEXT: vpbroadcastb (%rdi), %ymm1 ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7] --+; AVX512F-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0 -- ; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0 -- ; AVX512F-NEXT: vpaddb 32(%rsi), %ymm1, %ymm1 -- ; AVX512F-NEXT: vmovdqa %ymm1, 32(%rdx) --@@ -2253,7 +2253,7 @@ -- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14] -- ; AVX512DQ-NEXT: vpbroadcastb (%rdi), %ymm1 ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7] --+; AVX512DQ-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vpaddb 32(%rsi), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vmovdqa %ymm1, 32(%rdx) --@@ -2267,7 +2267,7 @@ -- ; AVX512BW-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512BW-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14] -- ; AVX512BW-NEXT: vpbroadcastb (%rdi), %ymm1 ---; AVX512BW-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7] --+; AVX512BW-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0 -- ; AVX512BW-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vpaddb (%rsi), %zmm0, %zmm0 -- ; AVX512BW-NEXT: vmovdqa64 %zmm0, (%rdx) --@@ -2458,7 +2458,7 @@ -- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512F-NEXT: vpbroadcastb (%rdi), %ymm1 ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7] --+; AVX512F-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0 -- ; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0 -- ; AVX512F-NEXT: vpaddb 32(%rsi), %ymm1, %ymm1 -- ; AVX512F-NEXT: vmovdqa %ymm1, 32(%rdx) --@@ -2472,7 +2472,7 @@ -- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512DQ-NEXT: vpbroadcastb (%rdi), %ymm1 ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7] --+; AVX512DQ-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vpaddb 32(%rsi), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vmovdqa %ymm1, 32(%rdx) --@@ -2486,7 +2486,7 @@ -- ; AVX512BW-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512BW-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512BW-NEXT: vpbroadcastb (%rdi), %ymm1 ---; AVX512BW-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7] --+; AVX512BW-NEXT: 
vinserti128 $1, %xmm1, %ymm0, %ymm0 -- ; AVX512BW-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vpaddb (%rsi), %zmm0, %zmm0 -- ; AVX512BW-NEXT: vmovdqa64 %zmm0, (%rdx) --@@ -3095,7 +3095,7 @@ -- ; AVX512F: # %bb.0: -- ; AVX512F-NEXT: vpbroadcastw (%rdi), %ymm0 -- ; AVX512F-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],mem[1,2,3,4,5],xmm0[6],mem[7] ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rsi), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rsi), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rdx) --@@ -3107,7 +3107,7 @@ -- ; AVX512DQ: # %bb.0: -- ; AVX512DQ-NEXT: vpbroadcastw (%rdi), %ymm0 -- ; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],mem[1,2,3,4,5],xmm0[6],mem[7] ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rsi), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rsi), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rdx) --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll ----- a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll --+++ b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll --@@ -2573,7 +2573,8 @@ -- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,15,3,4,15,6,7,15,9,10,15,12,13,15] -- ; AVX512F-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -2590,7 +2591,8 @@ -- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,15,3,4,15,6,7,15,9,10,15,12,13,15] -- ; AVX512DQ-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -2835,7 +2837,8 @@ -- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14] -- ; AVX512F-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -2852,7 +2855,8 @@ -- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14] -- ; AVX512DQ-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; 
AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -2868,7 +2872,7 @@ -- ; AVX512BW-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512BW-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14] -- ; AVX512BW-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512BW-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512BW-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512BW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -- ; AVX512BW-NEXT: vpaddb (%rdx), %zmm0, %zmm0 -- ; AVX512BW-NEXT: vmovdqa64 %zmm0, (%rcx) --@@ -3096,7 +3100,8 @@ -- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512F-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3113,7 +3118,8 @@ -- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512DQ-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3129,7 +3135,7 @@ -- ; AVX512BW-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512BW-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512BW-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512BW-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512BW-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512BW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -- ; AVX512BW-NEXT: vpaddb (%rdx), %zmm0, %zmm0 -- ; AVX512BW-NEXT: vmovdqa64 %zmm0, (%rcx) --@@ -3608,11 +3614,12 @@ -- ; AVX512F: # %bb.0: -- ; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 -- ; AVX512F-NEXT: vmovdqa 48(%rdi), %xmm1 ---; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -- ; AVX512F-NEXT: vpbroadcastw %xmm0, %ymm0 --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512F-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2],xmm0[3],xmm1[4,5],xmm0[6],xmm1[7] ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3624,11 +3631,12 @@ -- ; AVX512DQ: # %bb.0: -- ; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 -- ; AVX512DQ-NEXT: vmovdqa 48(%rdi), %xmm1 ---; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -- ; AVX512DQ-NEXT: vpbroadcastw %xmm0, %ymm0 --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = 
xmm0[0],xmm1[1,2],xmm0[3],xmm1[4,5],xmm0[6],xmm1[7] ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3860,11 +3868,12 @@ -- ; AVX512F: # %bb.0: -- ; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 -- ; AVX512F-NEXT: vmovdqa 48(%rdi), %xmm1 ---; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -- ; AVX512F-NEXT: vpbroadcastw %xmm0, %ymm0 --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512F-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2,3,4,5],xmm0[6],xmm1[7] ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3876,11 +3885,12 @@ -- ; AVX512DQ: # %bb.0: -- ; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 -- ; AVX512DQ-NEXT: vmovdqa 48(%rdi), %xmm1 ---; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -- ; AVX512DQ-NEXT: vpbroadcastw %xmm0, %ymm0 --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2,3,4,5],xmm0[6],xmm1[7] ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/pr42905.ll b/llvm/test/CodeGen/X86/pr42905.ll ----- a/llvm/test/CodeGen/X86/pr42905.ll --+++ b/llvm/test/CodeGen/X86/pr42905.ll --@@ -4,10 +4,16 @@ -- define <4 x double> @autogen_SD30452(i1 %L230) { -- ; CHECK-LABEL: autogen_SD30452: -- ; CHECK: # %bb.0: # %BB ---; CHECK-NEXT: movdqa {{.*#+}} xmm0 = [151829,151829] ---; CHECK-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3] ---; CHECK-NEXT: cvtdq2pd %xmm0, %xmm0 ---; CHECK-NEXT: movaps %xmm0, %xmm1 --+; CHECK-NEXT: movdqa {{.*#+}} xmm1 = [151829,151829] --+; CHECK-NEXT: movq %xmm0, %rax --+; CHECK-NEXT: cvtsi2sd %rax, %xmm0 --+; CHECK-NEXT: pshufd {{.*#+}} xmm2 = xmm0[2,3,2,3] --+; CHECK-NEXT: movq %xmm2, %rax --+; CHECK-NEXT: xorps %xmm2, %xmm2 --+; CHECK-NEXT: cvtsi2sd %rax, %xmm2 --+; CHECK-NEXT: unpcklpd {{.*#+}} xmm0 = xmm0[0],xmm2[0] --+; CHECK-NEXT: pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3] --+; CHECK-NEXT: cvtdq2pd %xmm1, %xmm1 -- ; CHECK-NEXT: retq -- BB: -- %I = insertelement <4 x i64> zeroinitializer, i64 151829, i32 3 --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/sad.ll b/llvm/test/CodeGen/X86/sad.ll ----- a/llvm/test/CodeGen/X86/sad.ll --+++ b/llvm/test/CodeGen/X86/sad.ll --@@ -927,7 +927,8 @@ -- ; AVX512F-NEXT: vmovdqu 32(%rdi), %ymm1 -- ; AVX512F-NEXT: vpsadbw 32(%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpsadbw (%rdx), %ymm0, %ymm0 ---; AVX512F-NEXT: vpaddq %ymm1, %ymm0, %ymm0 --+; AVX512F-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0 --+; AVX512F-NEXT: vpaddq %zmm1, %zmm0, %zmm0 -- ; AVX512F-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512F-NEXT: vpaddq %xmm1, %xmm0, %xmm0 -- ; AVX512F-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] --diff -ruN --strip-trailing-cr 
a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll ----- a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll --+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll --@@ -2079,7 +2079,7 @@ -- ; AVX-NEXT: vpsrld $16, %xmm8, %xmm10 -- ; AVX-NEXT: vpunpckhdq {{.*#+}} xmm10 = xmm3[2],xmm10[2],xmm3[3],xmm10[3] -- ; AVX-NEXT: vpunpckhwd {{.*#+}} xmm12 = xmm8[4],xmm3[4],xmm8[5],xmm3[5],xmm8[6],xmm3[6],xmm8[7],xmm3[7] ---; AVX-NEXT: vpshufd {{.*#+}} xmm12 = xmm12[1,1,2,3] --+; AVX-NEXT: vpshuflw {{.*#+}} xmm12 = xmm12[2,2,2,2,4,5,6,7] -- ; AVX-NEXT: vpshufhw {{.*#+}} xmm12 = xmm12[0,1,2,3,4,5,5,4] -- ; AVX-NEXT: vinsertf128 $1, %xmm10, %ymm12, %ymm10 -- ; AVX-NEXT: vandnps %ymm10, %ymm6, %ymm6 --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i32-stride-5.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i32-stride-5.ll ----- a/llvm/test/CodeGen/X86/vector-interleaved-store-i32-stride-5.ll --+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i32-stride-5.ll --@@ -350,7 +350,7 @@ -- ; AVX-NEXT: vshufpd {{.*#+}} ymm4 = ymm4[0,0,3,3] -- ; AVX-NEXT: vblendps {{.*#+}} ymm4 = ymm4[0,1],ymm5[2,3],ymm4[4,5,6],ymm5[7] -- ; AVX-NEXT: vbroadcastf128 {{.*#+}} ymm5 = mem[0,1,0,1] ---; AVX-NEXT: vblendps {{.*#+}} ymm7 = ymm0[0,1,2,3],ymm5[4,5,6,7] --+; AVX-NEXT: vinsertf128 $1, %xmm5, %ymm0, %ymm7 -- ; AVX-NEXT: vblendps {{.*#+}} ymm4 = ymm7[0],ymm4[1,2,3],ymm7[4],ymm4[5,6,7] -- ; AVX-NEXT: vpermilps {{.*#+}} ymm1 = ymm1[u,u,u,2,u,u,u,7] -- ; AVX-NEXT: vblendps {{.*#+}} ymm0 = ymm1[0,1],ymm0[2],ymm1[3,4,5,6,7] --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-reduce-fmax-fmin-fast.ll b/llvm/test/CodeGen/X86/vector-reduce-fmax-fmin-fast.ll ----- a/llvm/test/CodeGen/X86/vector-reduce-fmax-fmin-fast.ll --+++ b/llvm/test/CodeGen/X86/vector-reduce-fmax-fmin-fast.ll --@@ -170,7 +170,7 @@ -- ; AVX512-LABEL: test_v16f32: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxps %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxps %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxps %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -264,7 +264,7 @@ -- ; AVX512-LABEL: test_v8f64: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vminpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vminpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vminpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -306,7 +306,7 @@ -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-reduce-fmax-nnan.ll b/llvm/test/CodeGen/X86/vector-reduce-fmax-nnan.ll ----- a/llvm/test/CodeGen/X86/vector-reduce-fmax-nnan.ll --+++ b/llvm/test/CodeGen/X86/vector-reduce-fmax-nnan.ll --@@ -175,7 +175,7 @@ -- ; AVX512-LABEL: test_v16f32: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxps %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxps %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxps 
%xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -311,7 +311,7 @@ -- ; AVX512-LABEL: test_v8f64: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -353,7 +353,7 @@ -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-reduce-fmin-nnan.ll b/llvm/test/CodeGen/X86/vector-reduce-fmin-nnan.ll ----- a/llvm/test/CodeGen/X86/vector-reduce-fmin-nnan.ll --+++ b/llvm/test/CodeGen/X86/vector-reduce-fmin-nnan.ll --@@ -216,7 +216,7 @@ -- ; AVX512-LABEL: test_v16f32: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vminps %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vminps %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vminps %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -310,7 +310,7 @@ -- ; AVX512-LABEL: test_v8f64: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vminpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vminpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vminpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -352,7 +352,7 @@ -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vminpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vminpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vminpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vminpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-reduce-mul.ll b/llvm/test/CodeGen/X86/vector-reduce-mul.ll ----- a/llvm/test/CodeGen/X86/vector-reduce-mul.ll --+++ b/llvm/test/CodeGen/X86/vector-reduce-mul.ll --@@ -357,14 +357,14 @@ -- ; AVX512BW-LABEL: test_v8i64: -- ; AVX512BW: # %bb.0: -- ; AVX512BW-NEXT: vextracti64x4 $1, %zmm0, %ymm1 ---; AVX512BW-NEXT: vpsrlq $32, %ymm0, %ymm2 ---; AVX512BW-NEXT: vpmuludq %ymm1, %ymm2, %ymm2 ---; AVX512BW-NEXT: vpsrlq $32, %ymm1, %ymm3 ---; AVX512BW-NEXT: vpmuludq %ymm3, %ymm0, %ymm3 ---; AVX512BW-NEXT: vpaddq %ymm2, %ymm3, %ymm2 ---; AVX512BW-NEXT: vpsllq $32, %ymm2, %ymm2 ---; AVX512BW-NEXT: vpmuludq %ymm1, %ymm0, %ymm0 ---; AVX512BW-NEXT: vpaddq %ymm2, %ymm0, %ymm0 --+; AVX512BW-NEXT: vpsrlq $32, %zmm0, %zmm2 --+; AVX512BW-NEXT: vpmuludq %zmm1, %zmm2, %zmm2 --+; AVX512BW-NEXT: vpsrlq $32, %zmm1, %zmm3 --+; AVX512BW-NEXT: vpmuludq %zmm3, %zmm0, %zmm3 --+; AVX512BW-NEXT: vpaddq %zmm2, %zmm3, %zmm2 --+; AVX512BW-NEXT: vpsllq $32, %zmm2, %zmm2 --+; AVX512BW-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 --+; AVX512BW-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512BW-NEXT: vpsrlq $32, %xmm0, %xmm2 -- ; AVX512BW-NEXT: vpmuludq %xmm1, %xmm2, %xmm2 --@@ -390,14 +390,14 @@ -- ; AVX512BWVL-LABEL: test_v8i64: -- ; AVX512BWVL: # %bb.0: -- ; AVX512BWVL-NEXT: vextracti64x4 $1, %zmm0, %ymm1 ---; 
AVX512BWVL-NEXT: vpsrlq $32, %ymm0, %ymm2 ---; AVX512BWVL-NEXT: vpmuludq %ymm1, %ymm2, %ymm2 ---; AVX512BWVL-NEXT: vpsrlq $32, %ymm1, %ymm3 ---; AVX512BWVL-NEXT: vpmuludq %ymm3, %ymm0, %ymm3 ---; AVX512BWVL-NEXT: vpaddq %ymm2, %ymm3, %ymm2 ---; AVX512BWVL-NEXT: vpsllq $32, %ymm2, %ymm2 ---; AVX512BWVL-NEXT: vpmuludq %ymm1, %ymm0, %ymm0 ---; AVX512BWVL-NEXT: vpaddq %ymm2, %ymm0, %ymm0 --+; AVX512BWVL-NEXT: vpsrlq $32, %zmm0, %zmm2 --+; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm2, %zmm2 --+; AVX512BWVL-NEXT: vpsrlq $32, %zmm1, %zmm3 --+; AVX512BWVL-NEXT: vpmuludq %zmm3, %zmm0, %zmm3 --+; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm3, %zmm2 --+; AVX512BWVL-NEXT: vpsllq $32, %zmm2, %zmm2 --+; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 --+; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BWVL-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512BWVL-NEXT: vpsrlq $32, %xmm0, %xmm2 -- ; AVX512BWVL-NEXT: vpmuludq %xmm1, %xmm2, %xmm2 --@@ -667,14 +667,14 @@ -- ; AVX512BW-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vextracti64x4 $1, %zmm0, %ymm1 ---; AVX512BW-NEXT: vpsrlq $32, %ymm0, %ymm2 ---; AVX512BW-NEXT: vpmuludq %ymm1, %ymm2, %ymm2 ---; AVX512BW-NEXT: vpsrlq $32, %ymm1, %ymm3 ---; AVX512BW-NEXT: vpmuludq %ymm3, %ymm0, %ymm3 ---; AVX512BW-NEXT: vpaddq %ymm2, %ymm3, %ymm2 ---; AVX512BW-NEXT: vpsllq $32, %ymm2, %ymm2 ---; AVX512BW-NEXT: vpmuludq %ymm1, %ymm0, %ymm0 ---; AVX512BW-NEXT: vpaddq %ymm2, %ymm0, %ymm0 --+; AVX512BW-NEXT: vpsrlq $32, %zmm0, %zmm2 --+; AVX512BW-NEXT: vpmuludq %zmm1, %zmm2, %zmm2 --+; AVX512BW-NEXT: vpsrlq $32, %zmm1, %zmm3 --+; AVX512BW-NEXT: vpmuludq %zmm3, %zmm0, %zmm3 --+; AVX512BW-NEXT: vpaddq %zmm2, %zmm3, %zmm2 --+; AVX512BW-NEXT: vpsllq $32, %zmm2, %zmm2 --+; AVX512BW-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 --+; AVX512BW-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512BW-NEXT: vpsrlq $32, %xmm0, %xmm2 -- ; AVX512BW-NEXT: vpmuludq %xmm1, %xmm2, %xmm2 --@@ -708,14 +708,14 @@ -- ; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 -- ; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BWVL-NEXT: vextracti64x4 $1, %zmm0, %ymm1 ---; AVX512BWVL-NEXT: vpsrlq $32, %ymm0, %ymm2 ---; AVX512BWVL-NEXT: vpmuludq %ymm1, %ymm2, %ymm2 ---; AVX512BWVL-NEXT: vpsrlq $32, %ymm1, %ymm3 ---; AVX512BWVL-NEXT: vpmuludq %ymm3, %ymm0, %ymm3 ---; AVX512BWVL-NEXT: vpaddq %ymm2, %ymm3, %ymm2 ---; AVX512BWVL-NEXT: vpsllq $32, %ymm2, %ymm2 ---; AVX512BWVL-NEXT: vpmuludq %ymm1, %ymm0, %ymm0 ---; AVX512BWVL-NEXT: vpaddq %ymm2, %ymm0, %ymm0 --+; AVX512BWVL-NEXT: vpsrlq $32, %zmm0, %zmm2 --+; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm2, %zmm2 --+; AVX512BWVL-NEXT: vpsrlq $32, %zmm1, %zmm3 --+; AVX512BWVL-NEXT: vpmuludq %zmm3, %zmm0, %zmm3 --+; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm3, %zmm2 --+; AVX512BWVL-NEXT: vpsllq $32, %zmm2, %zmm2 --+; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 --+; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BWVL-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512BWVL-NEXT: vpsrlq $32, %xmm0, %xmm2 -- ; AVX512BWVL-NEXT: vpmuludq %xmm1, %xmm2, %xmm2 --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast_from_memory.ll b/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast_from_memory.ll ----- a/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast_from_memory.ll --+++ b/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast_from_memory.ll --@@ -3862,14 +3862,15 @@ -- ; AVX-NEXT: vpxor %xmm2, %xmm2, %xmm2 -- ; AVX-NEXT: vpblendw 
{{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3,4,5,6,7] -- ; AVX-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,0,1,1] ---; AVX-NEXT: vbroadcastss (%rdi), %xmm3 --+; AVX-NEXT: vbroadcastss (%rdi), %ymm3 -- ; AVX-NEXT: vpaddb 32(%rsi), %xmm0, %xmm0 ---; AVX-NEXT: vpaddb (%rsi), %xmm1, %xmm1 -- ; AVX-NEXT: vpblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm3[4,5],xmm2[6,7] -- ; AVX-NEXT: vpaddb 16(%rsi), %xmm2, %xmm2 ---; AVX-NEXT: vmovdqa %xmm2, 16(%rdx) --+; AVX-NEXT: vpaddb (%rsi), %xmm1, %xmm1 -- ; AVX-NEXT: vmovdqa %xmm1, (%rdx) --+; AVX-NEXT: vmovdqa %xmm2, 16(%rdx) -- ; AVX-NEXT: vmovdqa %xmm0, 32(%rdx) --+; AVX-NEXT: vzeroupper -- ; AVX-NEXT: retq -- ; -- ; AVX2-SLOW-LABEL: vec384_i32_widen_to_i96_factor3_broadcast_to_v4i96_factor4: --@@ -4115,7 +4116,7 @@ -- ; AVX: # %bb.0: -- ; AVX-NEXT: vmovdqa 48(%rdi), %xmm0 -- ; AVX-NEXT: vpblendw {{.*#+}} xmm0 = mem[0,1],xmm0[2,3,4,5,6,7] ---; AVX-NEXT: vbroadcastss (%rdi), %xmm1 --+; AVX-NEXT: vbroadcastss (%rdi), %ymm1 -- ; AVX-NEXT: vmovaps 32(%rsi), %ymm2 -- ; AVX-NEXT: vxorps %xmm3, %xmm3, %xmm3 -- ; AVX-NEXT: vblendps {{.*#+}} xmm1 = xmm3[0,1],xmm1[2],xmm3[3] ++diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp ++--- a/clang/lib/Driver/ToolChains/Clang.cpp +++++ b/clang/lib/Driver/ToolChains/Clang.cpp ++@@ -6397,7 +6397,9 @@ ++ Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, ++ options::OPT_fno_convergent_functions); ++ ++- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); +++ // NVPTX doesn't support PGO or coverage +++ if (!Triple.isNVPTX()) +++ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); ++ ++ Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); ++ ++diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu ++--- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu +++++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu ++@@ -0,0 +1,33 @@ +++// Check that profiling/coverage arguments doen't get passed down to device-side +++// compilation. 
+++//
+++//
+++// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
+++// XRUN:   -fprofile-generate %s 2>&1 | \
+++// XRUN:   FileCheck --check-prefixes=CHECK,PROF %s
+++//
+++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
+++// RUN:   -fprofile-instr-generate %s 2>&1 | \
+++// RUN:   FileCheck --check-prefixes=CHECK,PROF %s
+++//
+++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
+++// RUN:   -coverage %s 2>&1 | \
+++// RUN:   FileCheck --check-prefixes=CHECK,GCOV %s
+++//
+++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
+++// RUN:   -ftest-coverage %s 2>&1 | \
+++// RUN:   FileCheck --check-prefixes=CHECK,GCOV %s
+++//
+++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \
+++// RUN:   -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \
+++// RUN:   FileCheck --check-prefixes=CHECK,PROF %s
+++//
+++//
+++// CHECK-NOT: error: unsupported option '-fprofile
+++// CHECK-NOT: error: invalid argument
+++// CHECK-DAG: "-fcuda-is-device"
+++// CHECK-NOT: "-f{{[^"/]*coverage.*}}"
+++// CHECK-NOT: "-fprofile{{[^"]*}}"
+++// CHECK: "-triple" "x86_64-unknown-linux-gnu"
+++// PROF: "-fprofile{{.*}}"
+++// GCOV: "-coverage-notes-file=
++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp
++--- a/lldb/tools/lldb-dap/DAP.cpp
+++++ b/lldb/tools/lldb-dap/DAP.cpp
++@@ -711,12 +711,12 @@
++             [](const std::string &message) -> llvm::StringRef {
++               return message;
++             },
++-            [](const protocol::Response::Message &message)
+++            [](const protocol::ResponseMessage &message)
++                 -> llvm::StringRef {
++               switch (message) {
++-              case protocol::Response::Message::cancelled:
+++              case protocol::eResponseMessageCancelled:
++                 return "cancelled";
++-              case protocol::Response::Message::notStopped:
+++              case protocol::eResponseMessageNotStopped:
++                 return "notStopped";
++               }
++             }),
++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp
++--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp
+++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp
++@@ -7,6 +7,7 @@
++ //===----------------------------------------------------------------------===//
++ 
++ #include "Protocol/ProtocolBase.h"
+++#include "lldb/lldb-enumerations.h"
++ #include "llvm/ADT/StringRef.h"
++ #include "llvm/ADT/StringSwitch.h"
++ #include "llvm/Support/ErrorHandling.h"
++@@ -31,11 +32,8 @@
++ 
++ namespace lldb_dap::protocol {
++ 
++-enum MessageType {
++-  eMessageTypeRequest,
++-  eMessageTypeResponse,
++-  eMessageTypeEvent
++-};
+++FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse,
+++                        eMessageTypeEvent};
++ 
++ bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) {
++   auto rawType = Params.getAsString();
++@@ -107,12 +105,12 @@
++ 
++   if (R.message) {
++     assert(!R.success && "message can only be used if success is false");
++-    if (const auto *messageEnum = std::get_if<Response::Message>(&*R.message)) {
+++    if (const auto *messageEnum = std::get_if<ResponseMessage>(&*R.message)) {
++       switch (*messageEnum) {
++-      case Response::Message::cancelled:
+++      case eResponseMessageCancelled:
++         Result.insert({"message", "cancelled"});
++         break;
++-      case Response::Message::notStopped:
+++      case eResponseMessageNotStopped:
++         Result.insert({"message", "notStopped"});
++         break;
++       }
++@@ -129,16 +127,16 @@
++ }
++ 
++ bool fromJSON(json::Value const &Params,
++-              std::variant<std::string, Response::Message> &M, json::Path P) {
+++              std::variant<std::string, ResponseMessage> &M, json::Path P) {
++   auto rawMessage = Params.getAsString();
++   if (!rawMessage) {
++     P.report("expected a string");
++     return false;
++   }
++-  std::optional<Response::Message> message =
++-      StringSwitch<std::optional<Response::Message>>(*rawMessage)
++-          .Case("cancelled", Response::Message::cancelled)
++-          .Case("notStopped", Response::Message::notStopped)
+++  std::optional<ResponseMessage> message =
+++      StringSwitch<std::optional<ResponseMessage>>(*rawMessage)
+++          .Case("cancelled", eResponseMessageCancelled)
+++          .Case("notStopped", eResponseMessageNotStopped)
++           .Default(std::nullopt);
++   if (message)
++     M = *message;
++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h
++--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h
+++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h
++@@ -20,6 +20,7 @@
++ #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H
++ #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H
++ 
+++#include "lldb/lldb-enumerations.h"
++ #include "llvm/Support/JSON.h"
++ #include
++ #include
++@@ -64,15 +65,15 @@
++ llvm::json::Value toJSON(const Event &);
++ bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path);
++ 
++-/// Response for a request.
++-struct Response {
++-  enum class Message {
+++FLAGS_ENUM(ResponseMessage){
++     /// The request was cancelled
++-    cancelled,
+++    eResponseMessageCancelled,
++     /// The request may be retried once the adapter is in a 'stopped' state
++-    notStopped,
++-  };
+++    eResponseMessageNotStopped,
+++};
++ 
+++/// Response for a request.
+++struct Response {
++   /// Sequence number of the corresponding request.
++   int64_t request_seq;
++ 
++@@ -90,7 +91,7 @@
++   /// Contains the raw error in short form if `success` is false. This raw error
++   /// might be interpreted by the client and is not shown in the UI. Some
++   /// predefined values exist.
++-  std::optional<std::variant<std::string, Message>> message;
+++  std::optional<std::variant<std::string, ResponseMessage>> message;
++ 
++   /// Contains request result if success is true and error details if success is
++   /// false.
++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h
++--- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h
+++++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h
++@@ -22,6 +22,8 @@
++ 
++ #include "Protocol/ProtocolBase.h"
++ #include "Protocol/ProtocolTypes.h"
+++#include "lldb/lldb-enumerations.h"
+++#include "llvm/ADT/DenseSet.h"
++ #include "llvm/Support/JSON.h"
++ #include
++ #include
++@@ -55,26 +57,26 @@
++ using DisconnectResponse = VoidResponse;
++ 
++ /// Features supported by DAP clients.
++-enum ClientFeature {
++-  eClientFeatureVariableType,
++-  eClientFeatureVariablePaging,
++-  eClientFeatureRunInTerminalRequest,
++-  eClientFeatureMemoryReferences,
++-  eClientFeatureProgressReporting,
++-  eClientFeatureInvalidatedEvent,
++-  eClientFeatureMemoryEvent,
++-  /// Client supports the `argsCanBeInterpretedByShell` attribute on the
++-  /// `runInTerminal` request.
++-  eClientFeatureArgsCanBeInterpretedByShell,
++-  eClientFeatureStartDebuggingRequest,
++-  /// The client will interpret ANSI escape sequences in the display of
++-  /// `OutputEvent.output` and `Variable.value` fields when
++-  /// `Capabilities.supportsANSIStyling` is also enabled.
++-  eClientFeatureANSIStyling,
+++FLAGS_ENUM(ClientFeature){
+++    eClientFeatureVariableType,
+++    eClientFeatureVariablePaging,
+++    eClientFeatureRunInTerminalRequest,
+++    eClientFeatureMemoryReferences,
+++    eClientFeatureProgressReporting,
+++    eClientFeatureInvalidatedEvent,
+++    eClientFeatureMemoryEvent,
+++    /// Client supports the `argsCanBeInterpretedByShell` attribute on the
+++    /// `runInTerminal` request.
+++ eClientFeatureArgsCanBeInterpretedByShell, +++ eClientFeatureStartDebuggingRequest, +++ /// The client will interpret ANSI escape sequences in the display of +++ /// `OutputEvent.output` and `Variable.value` fields when +++ /// `Capabilities.supportsANSIStyling` is also enabled. +++ eClientFeatureANSIStyling, ++ }; ++ ++ /// Format of paths reported by the debug adapter. ++-enum PathFormat { ePatFormatPath, ePathFormatURI }; +++FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; ++ ++ /// Arguments for `initialize` request. ++ struct InitializeRequestArguments { ++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h ++--- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +++++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h ++@@ -20,6 +20,7 @@ ++ #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H ++ #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H ++ +++#include "lldb/lldb-enumerations.h" ++ #include "llvm/ADT/DenseSet.h" ++ #include "llvm/Support/JSON.h" ++ #include ++@@ -56,12 +57,8 @@ ++ }; ++ llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); ++ ++-enum ColumnType { ++- eColumnTypeString, ++- eColumnTypeNumber, ++- eColumnTypeBoolean, ++- eColumnTypeTimestamp ++-}; +++FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, +++ eColumnTypeTimestamp}; ++ ++ /// A ColumnDescriptor specifies what module attribute to show in a column of ++ /// the modules view, how to format it, and what the column’s label should be. ++@@ -90,27 +87,23 @@ ++ ++ /// Names of checksum algorithms that may be supported by a debug adapter. ++ /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. ++-enum ChecksumAlgorithm { ++- eChecksumAlgorithmMD5, ++- eChecksumAlgorithmSHA1, ++- eChecksumAlgorithmSHA256, ++- eChecksumAlgorithmTimestamp ++-}; +++FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, +++ eChecksumAlgorithmSHA256, +++ eChecksumAlgorithmTimestamp}; ++ llvm::json::Value toJSON(const ChecksumAlgorithm &); ++ ++ /// Describes one or more type of breakpoint a BreakpointMode applies to. This ++ /// is a non-exhaustive enumeration and may expand as future breakpoint types ++ /// are added. ++-enum BreakpointModeApplicability { ++- /// In `SourceBreakpoint`'s. ++- eBreakpointModeApplicabilitySource, ++- /// In exception breakpoints applied in the `ExceptionFilterOptions`. ++- eBreakpointModeApplicabilityException, ++- /// In data breakpoints requested in the `DataBreakpointInfo` request. ++- eBreakpointModeApplicabilityData, ++- /// In `InstructionBreakpoint`'s. ++- eBreakpointModeApplicabilityInstruction ++-}; +++FLAGS_ENUM(BreakpointModeApplicability){ +++ /// In `SourceBreakpoint`'s. +++ eBreakpointModeApplicabilitySource, +++ /// In exception breakpoints applied in the `ExceptionFilterOptions`. +++ eBreakpointModeApplicabilityException, +++ /// In data breakpoints requested in the `DataBreakpointInfo` request. +++ eBreakpointModeApplicabilityData, +++ /// In `InstructionBreakpoint`'s. +++ eBreakpointModeApplicabilityInstruction}; ++ llvm::json::Value toJSON(const BreakpointModeApplicability &); ++ ++ /// A `BreakpointMode` is provided as a option when setting breakpoints on ++@@ -133,101 +126,101 @@ ++ llvm::json::Value toJSON(const BreakpointMode &); ++ ++ /// Debug Adapter Features flags supported by lldb-dap. ++-enum AdapterFeature { ++- /// The debug adapter supports ANSI escape sequences in styling of ++- /// `OutputEvent.output` and `Variable.value` fields. 
++- eAdapterFeatureANSIStyling, ++- /// The debug adapter supports the `breakpointLocations` request. ++- eAdapterFeatureBreakpointLocationsRequest, ++- /// The debug adapter supports the `cancel` request. ++- eAdapterFeatureCancelRequest, ++- /// The debug adapter supports the `clipboard` context value in the ++- /// `evaluate` request. ++- eAdapterFeatureClipboardContext, ++- /// The debug adapter supports the `completions` request. ++- eAdapterFeatureCompletionsRequest, ++- /// The debug adapter supports conditional breakpoints. ++- eAdapterFeatureConditionalBreakpoints, ++- /// The debug adapter supports the `configurationDone` request. ++- eAdapterFeatureConfigurationDoneRequest, ++- /// The debug adapter supports the `asAddress` and `bytes` fields in the ++- /// `dataBreakpointInfo` request. ++- eAdapterFeatureDataBreakpointBytes, ++- /// The debug adapter supports data breakpoints. ++- eAdapterFeatureDataBreakpoints, ++- /// The debug adapter supports the delayed loading of parts of the stack, ++- /// which requires that both the `startFrame` and `levels` arguments and the ++- /// `totalFrames` result of the `stackTrace` request are supported. ++- eAdapterFeatureDelayedStackTraceLoading, ++- /// The debug adapter supports the `disassemble` request. ++- eAdapterFeatureDisassembleRequest, ++- /// The debug adapter supports a (side effect free) `evaluate` request for ++- /// data hovers. ++- eAdapterFeatureEvaluateForHovers, ++- /// The debug adapter supports `filterOptions` as an argument on the ++- /// `setExceptionBreakpoints` request. ++- eAdapterFeatureExceptionFilterOptions, ++- /// The debug adapter supports the `exceptionInfo` request. ++- eAdapterFeatureExceptionInfoRequest, ++- /// The debug adapter supports `exceptionOptions` on the ++- /// `setExceptionBreakpoints` request. ++- eAdapterFeatureExceptionOptions, ++- /// The debug adapter supports function breakpoints. ++- eAdapterFeatureFunctionBreakpoints, ++- /// The debug adapter supports the `gotoTargets` request. ++- eAdapterFeatureGotoTargetsRequest, ++- /// The debug adapter supports breakpoints that break execution after a ++- /// specified number of hits. ++- eAdapterFeatureHitConditionalBreakpoints, ++- /// The debug adapter supports adding breakpoints based on instruction ++- /// references. ++- eAdapterFeatureInstructionBreakpoints, ++- /// The debug adapter supports the `loadedSources` request. ++- eAdapterFeatureLoadedSourcesRequest, ++- /// The debug adapter supports log points by interpreting the `logMessage` ++- /// attribute of the `SourceBreakpoint`. ++- eAdapterFeatureLogPoints, ++- /// The debug adapter supports the `modules` request. ++- eAdapterFeatureModulesRequest, ++- /// The debug adapter supports the `readMemory` request. ++- eAdapterFeatureReadMemoryRequest, ++- /// The debug adapter supports restarting a frame. ++- eAdapterFeatureRestartFrame, ++- /// The debug adapter supports the `restart` request. In this case a client ++- /// should not implement `restart` by terminating and relaunching the ++- /// adapter but by calling the `restart` request. ++- eAdapterFeatureRestartRequest, ++- /// The debug adapter supports the `setExpression` request. ++- eAdapterFeatureSetExpression, ++- /// The debug adapter supports setting a variable to a value. ++- eAdapterFeatureSetVariable, ++- /// The debug adapter supports the `singleThread` property on the execution ++- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, ++- /// `stepBack`). 
++- eAdapterFeatureSingleThreadExecutionRequests, ++- /// The debug adapter supports stepping back via the `stepBack` and ++- /// `reverseContinue` requests. ++- eAdapterFeatureStepBack, ++- /// The debug adapter supports the `stepInTargets` request. ++- eAdapterFeatureStepInTargetsRequest, ++- /// The debug adapter supports stepping granularities (argument ++- /// `granularity`) for the stepping requests. ++- eAdapterFeatureSteppingGranularity, ++- /// The debug adapter supports the `terminate` request. ++- eAdapterFeatureTerminateRequest, ++- /// The debug adapter supports the `terminateThreads` request. ++- eAdapterFeatureTerminateThreadsRequest, ++- /// The debug adapter supports the `suspendDebuggee` attribute on the ++- /// `disconnect` request. ++- eAdapterFeatureSuspendDebuggee, ++- /// The debug adapter supports a `format` attribute on the `stackTrace`, ++- /// `variables`, and `evaluate` requests. ++- eAdapterFeatureValueFormattingOptions, ++- /// The debug adapter supports the `writeMemory` request. ++- eAdapterFeatureWriteMemoryRequest, ++- /// The debug adapter supports the `terminateDebuggee` attribute on the ++- /// `disconnect` request. ++- eAdapterFeatureTerminateDebuggee, +++FLAGS_ENUM(AdapterFeature){ +++ /// The debug adapter supports ANSI escape sequences in styling of +++ /// `OutputEvent.output` and `Variable.value` fields. +++ eAdapterFeatureANSIStyling, +++ /// The debug adapter supports the `breakpointLocations` request. +++ eAdapterFeatureBreakpointLocationsRequest, +++ /// The debug adapter supports the `cancel` request. +++ eAdapterFeatureCancelRequest, +++ /// The debug adapter supports the `clipboard` context value in the +++ /// `evaluate` request. +++ eAdapterFeatureClipboardContext, +++ /// The debug adapter supports the `completions` request. +++ eAdapterFeatureCompletionsRequest, +++ /// The debug adapter supports conditional breakpoints. +++ eAdapterFeatureConditionalBreakpoints, +++ /// The debug adapter supports the `configurationDone` request. +++ eAdapterFeatureConfigurationDoneRequest, +++ /// The debug adapter supports the `asAddress` and `bytes` fields in the +++ /// `dataBreakpointInfo` request. +++ eAdapterFeatureDataBreakpointBytes, +++ /// The debug adapter supports data breakpoints. +++ eAdapterFeatureDataBreakpoints, +++ /// The debug adapter supports the delayed loading of parts of the stack, +++ /// which requires that both the `startFrame` and `levels` arguments and the +++ /// `totalFrames` result of the `stackTrace` request are supported. +++ eAdapterFeatureDelayedStackTraceLoading, +++ /// The debug adapter supports the `disassemble` request. +++ eAdapterFeatureDisassembleRequest, +++ /// The debug adapter supports a (side effect free) `evaluate` request for +++ /// data hovers. +++ eAdapterFeatureEvaluateForHovers, +++ /// The debug adapter supports `filterOptions` as an argument on the +++ /// `setExceptionBreakpoints` request. +++ eAdapterFeatureExceptionFilterOptions, +++ /// The debug adapter supports the `exceptionInfo` request. +++ eAdapterFeatureExceptionInfoRequest, +++ /// The debug adapter supports `exceptionOptions` on the +++ /// `setExceptionBreakpoints` request. +++ eAdapterFeatureExceptionOptions, +++ /// The debug adapter supports function breakpoints. +++ eAdapterFeatureFunctionBreakpoints, +++ /// The debug adapter supports the `gotoTargets` request. +++ eAdapterFeatureGotoTargetsRequest, +++ /// The debug adapter supports breakpoints that break execution after a +++ /// specified number of hits. 
+++ eAdapterFeatureHitConditionalBreakpoints, +++ /// The debug adapter supports adding breakpoints based on instruction +++ /// references. +++ eAdapterFeatureInstructionBreakpoints, +++ /// The debug adapter supports the `loadedSources` request. +++ eAdapterFeatureLoadedSourcesRequest, +++ /// The debug adapter supports log points by interpreting the `logMessage` +++ /// attribute of the `SourceBreakpoint`. +++ eAdapterFeatureLogPoints, +++ /// The debug adapter supports the `modules` request. +++ eAdapterFeatureModulesRequest, +++ /// The debug adapter supports the `readMemory` request. +++ eAdapterFeatureReadMemoryRequest, +++ /// The debug adapter supports restarting a frame. +++ eAdapterFeatureRestartFrame, +++ /// The debug adapter supports the `restart` request. In this case a client +++ /// should not implement `restart` by terminating and relaunching the +++ /// adapter but by calling the `restart` request. +++ eAdapterFeatureRestartRequest, +++ /// The debug adapter supports the `setExpression` request. +++ eAdapterFeatureSetExpression, +++ /// The debug adapter supports setting a variable to a value. +++ eAdapterFeatureSetVariable, +++ /// The debug adapter supports the `singleThread` property on the execution +++ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, +++ /// `stepBack`). +++ eAdapterFeatureSingleThreadExecutionRequests, +++ /// The debug adapter supports stepping back via the `stepBack` and +++ /// `reverseContinue` requests. +++ eAdapterFeatureStepBack, +++ /// The debug adapter supports the `stepInTargets` request. +++ eAdapterFeatureStepInTargetsRequest, +++ /// The debug adapter supports stepping granularities (argument +++ /// `granularity`) for the stepping requests. +++ eAdapterFeatureSteppingGranularity, +++ /// The debug adapter supports the `terminate` request. +++ eAdapterFeatureTerminateRequest, +++ /// The debug adapter supports the `terminateThreads` request. +++ eAdapterFeatureTerminateThreadsRequest, +++ /// The debug adapter supports the `suspendDebuggee` attribute on the +++ /// `disconnect` request. +++ eAdapterFeatureSuspendDebuggee, +++ /// The debug adapter supports a `format` attribute on the `stackTrace`, +++ /// `variables`, and `evaluate` requests. +++ eAdapterFeatureValueFormattingOptions, +++ /// The debug adapter supports the `writeMemory` request. +++ eAdapterFeatureWriteMemoryRequest, +++ /// The debug adapter supports the `terminateDebuggee` attribute on the +++ /// `disconnect` request. +++ eAdapterFeatureTerminateDebuggee, ++ }; ++ ++ /// Information about the capabilities of a debug adapter. ++@@ -268,10 +261,10 @@ ++ }; ++ llvm::json::Value toJSON(const Capabilities &); ++ ++-enum PresentationHint { ++- ePresentationHintNormal, ++- ePresentationHintEmphasize, ++- ePresentationHintDeemphasize, +++FLAGS_ENUM(PresentationHint){ +++ ePresentationHintNormal, +++ ePresentationHintEmphasize, +++ ePresentationHintDeemphasize, ++ }; ++ ++ /// A `Source` is a descriptor for source code. 
It is returned from the debug ++diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test ++--- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +++++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test ++@@ -1,7 +1,7 @@ ++ // Header ++ // ++ // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) ++-// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) +++// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) ++ // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) ++ // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) ++ // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) ++diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c ++--- a/offload/test/offloading/gpupgo/pgo1.c +++++ b/offload/test/offloading/gpupgo/pgo1.c ++@@ -14,7 +14,7 @@ ++ // RUN: %target_triple.%basename_t.clang.profraw | \ ++ // RUN: %fcheck-generic --check-prefix="CLANG-PGO" ++ ++-// REQUIRES: gpu +++// REQUIRES: amdgpu ++ // REQUIRES: pgo ++ ++ int test1(int a) { return a / 2; } ++diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c ++--- a/offload/test/offloading/gpupgo/pgo2.c +++++ b/offload/test/offloading/gpupgo/pgo2.c ++@@ -48,7 +48,7 @@ ++ // RUN: %target_triple.%basename_t.hfdi.profraw \ ++ // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" ++ ++-// REQUIRES: gpu +++// REQUIRES: amdgpu ++ // REQUIRES: pgo ++ ++ int main() { diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 5ceac99..725480b 100644 +index 725480b..005737a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "5eccd71ce4f852c7b2f06ecd1976d9e34040fcaa" -- LLVM_SHA256 = "fd100fd69425ebac40ed58ff0558e0064b74bd97b6023d5e65e4c706c726a483" -+ LLVM_COMMIT = "71a977d0d611f3e9f6137a6b8a26b730b2886ce9" -+ LLVM_SHA256 = "9bdf3ddf45c069248af36080a78b56d839d3aad6f9b727ec1ee1be72682888cc" +- LLVM_COMMIT = "71a977d0d611f3e9f6137a6b8a26b730b2886ce9" +- LLVM_SHA256 = "9bdf3ddf45c069248af36080a78b56d839d3aad6f9b727ec1ee1be72682888cc" ++ LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" ++ LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 9110bac3105338..7b1a0496a0965c 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "4639651b8267094807acab8b2709ed57cb924888" - SHARDY_SHA256 = "f3fd113eef0bffe8fbab9435b22022598b56dd710999fd5cafd5c1dfdc7c91ce" + SHARDY_COMMIT = "9435b34df0279d473240f5bcc2a829d0589ae372" + SHARDY_SHA256 = "5f2a037d3301a1407e5c94778dd56d855f5abe26999cce448ccfa1923cf9559f" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 5d85aa5a6d2274..5a732df12a541f 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,611 +1,566 @@ diff --git a/third_party/llvm/generated.patch 
b/third_party/llvm/generated.patch
-index 2cd88ea..509398d 100644
+index 509398d..99ef3cb 100644
--- a/third_party/llvm/generated.patch
+++ b/third_party/llvm/generated.patch
-@@ -1,591 +1 @@
+@@ -1 +1,546 @@
 Auto generated patch. Do not edit or delete it, even if empty.
--diff -ruN --strip-trailing-cr a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
----- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--@@ -25557,31 +25557,8 @@
-- if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
-- return NarrowBOp;
--
--- // If only EXTRACT_SUBVECTOR nodes use the source vector we can
--- // simplify it based on the (valid) extractions.
--- if (!V.getValueType().isScalableVector() &&
--- llvm::all_of(V->users(), [&](SDNode *Use) {
--- return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
--- Use->getOperand(0) == V;
--- })) {
--- unsigned NumElts = V.getValueType().getVectorNumElements();
--- APInt DemandedElts = APInt::getZero(NumElts);
--- for (SDNode *User : V->users()) {
--- unsigned ExtIdx = User->getConstantOperandVal(1);
--- unsigned NumSubElts = User->getValueType(0).getVectorNumElements();
--- DemandedElts.setBits(ExtIdx, ExtIdx + NumSubElts);
--- }
--- if (SimplifyDemandedVectorElts(V, DemandedElts, /*AssumeSingleUse=*/true)) {
--- // We simplified the vector operand of this extract subvector. If this
--- // extract is not dead, visit it again so it is folded properly.
--- if (N->getOpcode() != ISD::DELETED_NODE)
--- AddToWorklist(N);
--- return SDValue(N, 0);
--- }
--- } else {
--- if (SimplifyDemandedVectorElts(SDValue(N, 0)))
--- return SDValue(N, 0);
--- }
--+ if (SimplifyDemandedVectorElts(SDValue(N, 0)))
--+ return SDValue(N, 0);
--
-- return SDValue();
-- }
--diff -ruN --strip-trailing-cr a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
----- a/llvm/lib/Target/X86/X86ISelLowering.cpp
--+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
--@@ -58823,8 +58823,6 @@
--
-- uint64_t IdxVal = N->getConstantOperandVal(2);
-- MVT SubVecVT = SubVec.getSimpleValueType();
--- int VecNumElts = OpVT.getVectorNumElements();
--- int SubVecNumElts = SubVecVT.getVectorNumElements();
--
-- if (Vec.isUndef() && SubVec.isUndef())
-- return DAG.getUNDEF(OpVT);
--@@ -58884,9 +58882,10 @@
-- SubVec.getOperand(0).getSimpleValueType() == OpVT &&
-- (IdxVal != 0 ||
-- !(Vec.isUndef() || ISD::isBuildVectorAllZeros(Vec.getNode())))) {
--- SDValue ExtSrc = SubVec.getOperand(0);
-- int ExtIdxVal = SubVec.getConstantOperandVal(1);
-- if (ExtIdxVal != 0) {
--+ int VecNumElts = OpVT.getVectorNumElements();
--+ int SubVecNumElts = SubVecVT.getVectorNumElements();
-- SmallVector<int, 64> Mask(VecNumElts);
-- // First create an identity shuffle mask.
-- for (int i = 0; i != VecNumElts; ++i)
--@@ -58894,24 +58893,8 @@
-- // Now insert the extracted portion.
-- for (int i = 0; i != SubVecNumElts; ++i)
-- Mask[i + IdxVal] = i + ExtIdxVal + VecNumElts;
--- return DAG.getVectorShuffle(OpVT, dl, Vec, ExtSrc, Mask);
--- }
--- // If we're broadcasting, see if we can use a blend instead of
--- // extract/insert pair. For subvector broadcasts, we must ensure that the
--- // subvector is aligned with the insertion/extractions.
--- if (ExtSrc.getOpcode() == X86ISD::VBROADCAST ||
--- ExtSrc.getOpcode() == X86ISD::VBROADCAST_LOAD ||
--- (ExtSrc.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
--- (ExtIdxVal % SubVecNumElts) == 0 && (IdxVal % SubVecNumElts) == 0 &&
--- cast<MemIntrinsicSDNode>(ExtSrc)->getMemoryVT() == SubVecVT)) {
--- SmallVector<int, 64> Mask(VecNumElts);
--- // First create an identity shuffle mask.
--- for (int i = 0; i != VecNumElts; ++i)
--- Mask[i] = i;
--- // Now blend the broadcast.
--- for (int i = 0; i != SubVecNumElts; ++i)
--- Mask[i + IdxVal] = i + IdxVal + VecNumElts;
--- return DAG.getVectorShuffle(OpVT, dl, Vec, ExtSrc, Mask);
--+
--+ return DAG.getVectorShuffle(OpVT, dl, Vec, SubVec.getOperand(0), Mask);
-- }
-- }
--
--@@ -58959,7 +58942,7 @@
-- // If we're splatting the lower half subvector of a full vector load into the
-- // upper half, attempt to create a subvector broadcast.
-- // TODO: Drop hasOneUse checks.
--- if ((int)IdxVal == (VecNumElts / 2) &&
--+ if (IdxVal == (OpVT.getVectorNumElements() / 2) &&
-- Vec.getValueSizeInBits() == (2 * SubVec.getValueSizeInBits()) &&
-- (Vec.hasOneUse() || SubVec.hasOneUse())) {
-- auto *VecLd = dyn_cast<LoadSDNode>(Vec);
--diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast_from_memory.ll b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast_from_memory.ll
----- a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast_from_memory.ll
--+++ b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast_from_memory.ll
--@@ -2239,7 +2239,7 @@
-- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0]
-- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14]
-- ; AVX512F-NEXT: vpbroadcastb (%rdi), %ymm1
---; AVX512F-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7]
--+; AVX512F-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0
-- ; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0
-- ; AVX512F-NEXT: vpaddb 32(%rsi), %ymm1, %ymm1
-- ; AVX512F-NEXT: vmovdqa %ymm1, 32(%rdx)
--@@ -2253,7 +2253,7 @@
-- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0]
-- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14]
-- ; AVX512DQ-NEXT: vpbroadcastb (%rdi), %ymm1
---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7]
--+; AVX512DQ-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0
-- ; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0
-- ; AVX512DQ-NEXT: vpaddb 32(%rsi), %ymm1, %ymm1
-- ; AVX512DQ-NEXT: vmovdqa %ymm1, 32(%rdx)
--@@ -2267,7 +2267,7 @@
-- ; AVX512BW-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0]
-- ; AVX512BW-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14]
-- ; AVX512BW-NEXT: vpbroadcastb (%rdi), %ymm1
---; AVX512BW-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7]
--+; AVX512BW-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0
-- ; AVX512BW-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0
-- ; AVX512BW-NEXT: vpaddb (%rsi), %zmm0, %zmm0
-- ; AVX512BW-NEXT: vmovdqa64 %zmm0, (%rdx)
--@@ -2458,7 +2458,7 @@
-- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0]
-- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14]
-- ; AVX512F-NEXT: vpbroadcastb (%rdi), %ymm1
---; AVX512F-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7]
--+; AVX512F-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0
-- ; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0
-- ; AVX512F-NEXT: vpaddb 32(%rsi), %ymm1, %ymm1
-- ; 
AVX512F-NEXT: vmovdqa %ymm1, 32(%rdx) --@@ -2472,7 +2472,7 @@ -- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512DQ-NEXT: vpbroadcastb (%rdi), %ymm1 ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7] --+; AVX512DQ-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vpaddb 32(%rsi), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vmovdqa %ymm1, 32(%rdx) --@@ -2486,7 +2486,7 @@ -- ; AVX512BW-NEXT: vpalignr {{.*#+}} xmm0 = mem[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512BW-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512BW-NEXT: vpbroadcastb (%rdi), %ymm1 ---; AVX512BW-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7] --+; AVX512BW-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0 -- ; AVX512BW-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vpaddb (%rsi), %zmm0, %zmm0 -- ; AVX512BW-NEXT: vmovdqa64 %zmm0, (%rdx) --@@ -3095,7 +3095,7 @@ -- ; AVX512F: # %bb.0: -- ; AVX512F-NEXT: vpbroadcastw (%rdi), %ymm0 -- ; AVX512F-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],mem[1,2,3,4,5],xmm0[6],mem[7] ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rsi), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rsi), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rdx) --@@ -3107,7 +3107,7 @@ -- ; AVX512DQ: # %bb.0: -- ; AVX512DQ-NEXT: vpbroadcastw (%rdi), %ymm0 -- ; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],mem[1,2,3,4,5],xmm0[6],mem[7] ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rsi), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rsi), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rdx) --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll ----- a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll --+++ b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll --@@ -2573,7 +2573,8 @@ -- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,15,3,4,15,6,7,15,9,10,15,12,13,15] -- ; AVX512F-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -2590,7 +2591,8 @@ -- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,15,3,4,15,6,7,15,9,10,15,12,13,15] -- ; AVX512DQ-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -2835,7 +2837,8 @@ -- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm1 = 
xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14] -- ; AVX512F-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -2852,7 +2855,8 @@ -- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14] -- ; AVX512DQ-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -2868,7 +2872,7 @@ -- ; AVX512BW-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512BW-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,15,6,7,8,9,10,15,12,13,14] -- ; AVX512BW-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512BW-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512BW-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512BW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -- ; AVX512BW-NEXT: vpaddb (%rdx), %zmm0, %zmm0 -- ; AVX512BW-NEXT: vmovdqa64 %zmm0, (%rcx) --@@ -3096,7 +3100,8 @@ -- ; AVX512F-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512F-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512F-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3113,7 +3118,8 @@ -- ; AVX512DQ-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512DQ-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3129,7 +3135,7 @@ -- ; AVX512BW-NEXT: vpalignr {{.*#+}} xmm1 = xmm1[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],xmm0[0] -- ; AVX512BW-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[15,0,1,2,3,4,5,6,7,8,9,10,15,12,13,14] -- ; AVX512BW-NEXT: vpbroadcastb %xmm0, %ymm0 ---; AVX512BW-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512BW-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512BW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -- ; AVX512BW-NEXT: vpaddb (%rdx), %zmm0, %zmm0 -- ; AVX512BW-NEXT: vmovdqa64 %zmm0, (%rcx) --@@ -3608,11 +3614,12 @@ -- ; AVX512F: # %bb.0: -- ; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 -- ; AVX512F-NEXT: vmovdqa 48(%rdi), %xmm1 ---; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -- ; AVX512F-NEXT: 
vpbroadcastw %xmm0, %ymm0 --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512F-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2],xmm0[3],xmm1[4,5],xmm0[6],xmm1[7] ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3624,11 +3631,12 @@ -- ; AVX512DQ: # %bb.0: -- ; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 -- ; AVX512DQ-NEXT: vmovdqa 48(%rdi), %xmm1 ---; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -- ; AVX512DQ-NEXT: vpbroadcastw %xmm0, %ymm0 --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2],xmm0[3],xmm1[4,5],xmm0[6],xmm1[7] ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3860,11 +3868,12 @@ -- ; AVX512F: # %bb.0: -- ; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 -- ; AVX512F-NEXT: vmovdqa 48(%rdi), %xmm1 ---; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -- ; AVX512F-NEXT: vpbroadcastw %xmm0, %ymm0 --+; AVX512F-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512F-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2,3,4,5],xmm0[6],xmm1[7] ---; AVX512F-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512F-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) --@@ -3876,11 +3885,12 @@ -- ; AVX512DQ: # %bb.0: -- ; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 -- ; AVX512DQ-NEXT: vmovdqa 48(%rdi), %xmm1 ---; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -- ; AVX512DQ-NEXT: vpbroadcastw %xmm0, %ymm0 --+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm0, %zmm0 --+; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 -- ; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2,3,4,5],xmm0[6],xmm1[7] ---; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm0[4,5,6,7] --+; AVX512DQ-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 -- ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 -- ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/pr42905.ll b/llvm/test/CodeGen/X86/pr42905.ll ----- a/llvm/test/CodeGen/X86/pr42905.ll --+++ b/llvm/test/CodeGen/X86/pr42905.ll --@@ -4,10 +4,16 @@ -- define <4 x double> @autogen_SD30452(i1 %L230) { -- ; CHECK-LABEL: autogen_SD30452: -- ; CHECK: # %bb.0: # %BB ---; CHECK-NEXT: movdqa {{.*#+}} xmm0 = [151829,151829] ---; CHECK-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3] ---; CHECK-NEXT: cvtdq2pd %xmm0, %xmm0 ---; CHECK-NEXT: movaps %xmm0, %xmm1 --+; CHECK-NEXT: movdqa {{.*#+}} xmm1 = [151829,151829] --+; CHECK-NEXT: movq %xmm0, %rax --+; CHECK-NEXT: cvtsi2sd %rax, %xmm0 --+; CHECK-NEXT: pshufd {{.*#+}} xmm2 = xmm0[2,3,2,3] --+; CHECK-NEXT: movq %xmm2, %rax --+; CHECK-NEXT: xorps %xmm2, %xmm2 --+; CHECK-NEXT: cvtsi2sd %rax, %xmm2 --+; CHECK-NEXT: unpcklpd {{.*#+}} xmm0 
= xmm0[0],xmm2[0] --+; CHECK-NEXT: pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3] --+; CHECK-NEXT: cvtdq2pd %xmm1, %xmm1 -- ; CHECK-NEXT: retq -- BB: -- %I = insertelement <4 x i64> zeroinitializer, i64 151829, i32 3 --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/sad.ll b/llvm/test/CodeGen/X86/sad.ll ----- a/llvm/test/CodeGen/X86/sad.ll --+++ b/llvm/test/CodeGen/X86/sad.ll --@@ -927,7 +927,8 @@ -- ; AVX512F-NEXT: vmovdqu 32(%rdi), %ymm1 -- ; AVX512F-NEXT: vpsadbw 32(%rdx), %ymm1, %ymm1 -- ; AVX512F-NEXT: vpsadbw (%rdx), %ymm0, %ymm0 ---; AVX512F-NEXT: vpaddq %ymm1, %ymm0, %ymm0 --+; AVX512F-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0 --+; AVX512F-NEXT: vpaddq %zmm1, %zmm0, %zmm0 -- ; AVX512F-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512F-NEXT: vpaddq %xmm1, %xmm0, %xmm0 -- ; AVX512F-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll ----- a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll --+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll --@@ -2079,7 +2079,7 @@ -- ; AVX-NEXT: vpsrld $16, %xmm8, %xmm10 -- ; AVX-NEXT: vpunpckhdq {{.*#+}} xmm10 = xmm3[2],xmm10[2],xmm3[3],xmm10[3] -- ; AVX-NEXT: vpunpckhwd {{.*#+}} xmm12 = xmm8[4],xmm3[4],xmm8[5],xmm3[5],xmm8[6],xmm3[6],xmm8[7],xmm3[7] ---; AVX-NEXT: vpshufd {{.*#+}} xmm12 = xmm12[1,1,2,3] --+; AVX-NEXT: vpshuflw {{.*#+}} xmm12 = xmm12[2,2,2,2,4,5,6,7] -- ; AVX-NEXT: vpshufhw {{.*#+}} xmm12 = xmm12[0,1,2,3,4,5,5,4] -- ; AVX-NEXT: vinsertf128 $1, %xmm10, %ymm12, %ymm10 -- ; AVX-NEXT: vandnps %ymm10, %ymm6, %ymm6 --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i32-stride-5.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i32-stride-5.ll ----- a/llvm/test/CodeGen/X86/vector-interleaved-store-i32-stride-5.ll --+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i32-stride-5.ll --@@ -350,7 +350,7 @@ -- ; AVX-NEXT: vshufpd {{.*#+}} ymm4 = ymm4[0,0,3,3] -- ; AVX-NEXT: vblendps {{.*#+}} ymm4 = ymm4[0,1],ymm5[2,3],ymm4[4,5,6],ymm5[7] -- ; AVX-NEXT: vbroadcastf128 {{.*#+}} ymm5 = mem[0,1,0,1] ---; AVX-NEXT: vblendps {{.*#+}} ymm7 = ymm0[0,1,2,3],ymm5[4,5,6,7] --+; AVX-NEXT: vinsertf128 $1, %xmm5, %ymm0, %ymm7 -- ; AVX-NEXT: vblendps {{.*#+}} ymm4 = ymm7[0],ymm4[1,2,3],ymm7[4],ymm4[5,6,7] -- ; AVX-NEXT: vpermilps {{.*#+}} ymm1 = ymm1[u,u,u,2,u,u,u,7] -- ; AVX-NEXT: vblendps {{.*#+}} ymm0 = ymm1[0,1],ymm0[2],ymm1[3,4,5,6,7] --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-reduce-fmax-fmin-fast.ll b/llvm/test/CodeGen/X86/vector-reduce-fmax-fmin-fast.ll ----- a/llvm/test/CodeGen/X86/vector-reduce-fmax-fmin-fast.ll --+++ b/llvm/test/CodeGen/X86/vector-reduce-fmax-fmin-fast.ll --@@ -170,7 +170,7 @@ -- ; AVX512-LABEL: test_v16f32: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxps %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxps %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxps %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -264,7 +264,7 @@ -- ; AVX512-LABEL: test_v8f64: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vminpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vminpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vminpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -306,7 +306,7 @@ -- ; AVX512: # 
%bb.0: -- ; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-reduce-fmax-nnan.ll b/llvm/test/CodeGen/X86/vector-reduce-fmax-nnan.ll ----- a/llvm/test/CodeGen/X86/vector-reduce-fmax-nnan.ll --+++ b/llvm/test/CodeGen/X86/vector-reduce-fmax-nnan.ll --@@ -175,7 +175,7 @@ -- ; AVX512-LABEL: test_v16f32: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxps %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxps %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxps %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -311,7 +311,7 @@ -- ; AVX512-LABEL: test_v8f64: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -353,7 +353,7 @@ -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vmaxpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vmaxpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vmaxpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-reduce-fmin-nnan.ll b/llvm/test/CodeGen/X86/vector-reduce-fmin-nnan.ll ----- a/llvm/test/CodeGen/X86/vector-reduce-fmin-nnan.ll --+++ b/llvm/test/CodeGen/X86/vector-reduce-fmin-nnan.ll --@@ -216,7 +216,7 @@ -- ; AVX512-LABEL: test_v16f32: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vminps %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vminps %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vminps %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -310,7 +310,7 @@ -- ; AVX512-LABEL: test_v8f64: -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vminpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vminpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vminpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --@@ -352,7 +352,7 @@ -- ; AVX512: # %bb.0: -- ; AVX512-NEXT: vminpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf64x4 $1, %zmm0, %ymm1 ---; AVX512-NEXT: vminpd %ymm1, %ymm0, %ymm0 --+; AVX512-NEXT: vminpd %zmm1, %zmm0, %zmm0 -- ; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm1 -- ; AVX512-NEXT: vminpd %xmm1, %xmm0, %xmm0 -- ; AVX512-NEXT: vshufpd {{.*#+}} xmm1 = xmm0[1,0] --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-reduce-mul.ll b/llvm/test/CodeGen/X86/vector-reduce-mul.ll ----- a/llvm/test/CodeGen/X86/vector-reduce-mul.ll --+++ b/llvm/test/CodeGen/X86/vector-reduce-mul.ll --@@ -357,14 +357,14 @@ -- ; AVX512BW-LABEL: test_v8i64: -- ; AVX512BW: # %bb.0: -- ; AVX512BW-NEXT: vextracti64x4 $1, %zmm0, %ymm1 ---; AVX512BW-NEXT: vpsrlq $32, %ymm0, %ymm2 ---; AVX512BW-NEXT: vpmuludq %ymm1, %ymm2, %ymm2 ---; AVX512BW-NEXT: vpsrlq $32, %ymm1, %ymm3 ---; AVX512BW-NEXT: vpmuludq %ymm3, %ymm0, 
%ymm3 ---; AVX512BW-NEXT: vpaddq %ymm2, %ymm3, %ymm2 ---; AVX512BW-NEXT: vpsllq $32, %ymm2, %ymm2 ---; AVX512BW-NEXT: vpmuludq %ymm1, %ymm0, %ymm0 ---; AVX512BW-NEXT: vpaddq %ymm2, %ymm0, %ymm0 --+; AVX512BW-NEXT: vpsrlq $32, %zmm0, %zmm2 --+; AVX512BW-NEXT: vpmuludq %zmm1, %zmm2, %zmm2 --+; AVX512BW-NEXT: vpsrlq $32, %zmm1, %zmm3 --+; AVX512BW-NEXT: vpmuludq %zmm3, %zmm0, %zmm3 --+; AVX512BW-NEXT: vpaddq %zmm2, %zmm3, %zmm2 --+; AVX512BW-NEXT: vpsllq $32, %zmm2, %zmm2 --+; AVX512BW-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 --+; AVX512BW-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512BW-NEXT: vpsrlq $32, %xmm0, %xmm2 -- ; AVX512BW-NEXT: vpmuludq %xmm1, %xmm2, %xmm2 --@@ -390,14 +390,14 @@ -- ; AVX512BWVL-LABEL: test_v8i64: -- ; AVX512BWVL: # %bb.0: -- ; AVX512BWVL-NEXT: vextracti64x4 $1, %zmm0, %ymm1 ---; AVX512BWVL-NEXT: vpsrlq $32, %ymm0, %ymm2 ---; AVX512BWVL-NEXT: vpmuludq %ymm1, %ymm2, %ymm2 ---; AVX512BWVL-NEXT: vpsrlq $32, %ymm1, %ymm3 ---; AVX512BWVL-NEXT: vpmuludq %ymm3, %ymm0, %ymm3 ---; AVX512BWVL-NEXT: vpaddq %ymm2, %ymm3, %ymm2 ---; AVX512BWVL-NEXT: vpsllq $32, %ymm2, %ymm2 ---; AVX512BWVL-NEXT: vpmuludq %ymm1, %ymm0, %ymm0 ---; AVX512BWVL-NEXT: vpaddq %ymm2, %ymm0, %ymm0 --+; AVX512BWVL-NEXT: vpsrlq $32, %zmm0, %zmm2 --+; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm2, %zmm2 --+; AVX512BWVL-NEXT: vpsrlq $32, %zmm1, %zmm3 --+; AVX512BWVL-NEXT: vpmuludq %zmm3, %zmm0, %zmm3 --+; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm3, %zmm2 --+; AVX512BWVL-NEXT: vpsllq $32, %zmm2, %zmm2 --+; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 --+; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BWVL-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512BWVL-NEXT: vpsrlq $32, %xmm0, %xmm2 -- ; AVX512BWVL-NEXT: vpmuludq %xmm1, %xmm2, %xmm2 --@@ -667,14 +667,14 @@ -- ; AVX512BW-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vextracti64x4 $1, %zmm0, %ymm1 ---; AVX512BW-NEXT: vpsrlq $32, %ymm0, %ymm2 ---; AVX512BW-NEXT: vpmuludq %ymm1, %ymm2, %ymm2 ---; AVX512BW-NEXT: vpsrlq $32, %ymm1, %ymm3 ---; AVX512BW-NEXT: vpmuludq %ymm3, %ymm0, %ymm3 ---; AVX512BW-NEXT: vpaddq %ymm2, %ymm3, %ymm2 ---; AVX512BW-NEXT: vpsllq $32, %ymm2, %ymm2 ---; AVX512BW-NEXT: vpmuludq %ymm1, %ymm0, %ymm0 ---; AVX512BW-NEXT: vpaddq %ymm2, %ymm0, %ymm0 --+; AVX512BW-NEXT: vpsrlq $32, %zmm0, %zmm2 --+; AVX512BW-NEXT: vpmuludq %zmm1, %zmm2, %zmm2 --+; AVX512BW-NEXT: vpsrlq $32, %zmm1, %zmm3 --+; AVX512BW-NEXT: vpmuludq %zmm3, %zmm0, %zmm3 --+; AVX512BW-NEXT: vpaddq %zmm2, %zmm3, %zmm2 --+; AVX512BW-NEXT: vpsllq $32, %zmm2, %zmm2 --+; AVX512BW-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 --+; AVX512BW-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BW-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512BW-NEXT: vpsrlq $32, %xmm0, %xmm2 -- ; AVX512BW-NEXT: vpmuludq %xmm1, %xmm2, %xmm2 --@@ -708,14 +708,14 @@ -- ; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 -- ; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BWVL-NEXT: vextracti64x4 $1, %zmm0, %ymm1 ---; AVX512BWVL-NEXT: vpsrlq $32, %ymm0, %ymm2 ---; AVX512BWVL-NEXT: vpmuludq %ymm1, %ymm2, %ymm2 ---; AVX512BWVL-NEXT: vpsrlq $32, %ymm1, %ymm3 ---; AVX512BWVL-NEXT: vpmuludq %ymm3, %ymm0, %ymm3 ---; AVX512BWVL-NEXT: vpaddq %ymm2, %ymm3, %ymm2 ---; AVX512BWVL-NEXT: vpsllq $32, %ymm2, %ymm2 ---; AVX512BWVL-NEXT: vpmuludq %ymm1, %ymm0, %ymm0 ---; AVX512BWVL-NEXT: vpaddq %ymm2, %ymm0, %ymm0 --+; AVX512BWVL-NEXT: vpsrlq $32, %zmm0, %zmm2 --+; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm2, %zmm2 --+; AVX512BWVL-NEXT: vpsrlq $32, 
%zmm1, %zmm3 --+; AVX512BWVL-NEXT: vpmuludq %zmm3, %zmm0, %zmm3 --+; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm3, %zmm2 --+; AVX512BWVL-NEXT: vpsllq $32, %zmm2, %zmm2 --+; AVX512BWVL-NEXT: vpmuludq %zmm1, %zmm0, %zmm0 --+; AVX512BWVL-NEXT: vpaddq %zmm2, %zmm0, %zmm0 -- ; AVX512BWVL-NEXT: vextracti128 $1, %ymm0, %xmm1 -- ; AVX512BWVL-NEXT: vpsrlq $32, %xmm0, %xmm2 -- ; AVX512BWVL-NEXT: vpmuludq %xmm1, %xmm2, %xmm2 --diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast_from_memory.ll b/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast_from_memory.ll ----- a/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast_from_memory.ll --+++ b/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast_from_memory.ll --@@ -3862,14 +3862,15 @@ -- ; AVX-NEXT: vpxor %xmm2, %xmm2, %xmm2 -- ; AVX-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3,4,5,6,7] -- ; AVX-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,0,1,1] ---; AVX-NEXT: vbroadcastss (%rdi), %xmm3 --+; AVX-NEXT: vbroadcastss (%rdi), %ymm3 -- ; AVX-NEXT: vpaddb 32(%rsi), %xmm0, %xmm0 ---; AVX-NEXT: vpaddb (%rsi), %xmm1, %xmm1 -- ; AVX-NEXT: vpblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm3[4,5],xmm2[6,7] -- ; AVX-NEXT: vpaddb 16(%rsi), %xmm2, %xmm2 ---; AVX-NEXT: vmovdqa %xmm2, 16(%rdx) --+; AVX-NEXT: vpaddb (%rsi), %xmm1, %xmm1 -- ; AVX-NEXT: vmovdqa %xmm1, (%rdx) --+; AVX-NEXT: vmovdqa %xmm2, 16(%rdx) -- ; AVX-NEXT: vmovdqa %xmm0, 32(%rdx) --+; AVX-NEXT: vzeroupper -- ; AVX-NEXT: retq -- ; -- ; AVX2-SLOW-LABEL: vec384_i32_widen_to_i96_factor3_broadcast_to_v4i96_factor4: --@@ -4115,7 +4116,7 @@ -- ; AVX: # %bb.0: -- ; AVX-NEXT: vmovdqa 48(%rdi), %xmm0 -- ; AVX-NEXT: vpblendw {{.*#+}} xmm0 = mem[0,1],xmm0[2,3,4,5,6,7] ---; AVX-NEXT: vbroadcastss (%rdi), %xmm1 --+; AVX-NEXT: vbroadcastss (%rdi), %ymm1 -- ; AVX-NEXT: vmovaps 32(%rsi), %ymm2 -- ; AVX-NEXT: vxorps %xmm3, %xmm3, %xmm3 -- ; AVX-NEXT: vblendps {{.*#+}} xmm1 = xmm3[0,1],xmm1[2],xmm3[3] ++diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp ++--- a/clang/lib/Driver/ToolChains/Clang.cpp +++++ b/clang/lib/Driver/ToolChains/Clang.cpp ++@@ -6397,7 +6397,9 @@ ++ Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, ++ options::OPT_fno_convergent_functions); ++ ++- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); +++ // NVPTX doesn't support PGO or coverage +++ if (!Triple.isNVPTX()) +++ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); ++ ++ Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); ++ ++diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu ++--- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu +++++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu ++@@ -0,0 +1,33 @@ +++// Check that profiling/coverage arguments doen't get passed down to device-side +++// compilation. 
+++// +++// +++// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +++// XRUN: -fprofile-generate %s 2>&1 | \ +++// XRUN: FileCheck --check-prefixes=CHECK,PROF %s +++// +++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +++// RUN: -fprofile-instr-generate %s 2>&1 | \ +++// RUN: FileCheck --check-prefixes=CHECK,PROF %s +++// +++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +++// RUN: -coverage %s 2>&1 | \ +++// RUN: FileCheck --check-prefixes=CHECK,GCOV %s +++// +++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +++// RUN: -ftest-coverage %s 2>&1 | \ +++// RUN: FileCheck --check-prefixes=CHECK,GCOV %s +++// +++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +++// RUN: -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \ +++// RUN: FileCheck --check-prefixes=CHECK,PROF %s +++// +++// +++// CHECK-NOT: error: unsupported option '-fprofile +++// CHECK-NOT: error: invalid argument +++// CHECK-DAG: "-fcuda-is-device" +++// CHECK-NOT: "-f{{[^"/]*coverage.*}}" +++// CHECK-NOT: "-fprofile{{[^"]*}}" +++// CHECK: "-triple" "x86_64-unknown-linux-gnu" +++// PROF: "-fprofile{{.*}}" +++// GCOV: "-coverage-notes-file= ++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp ++--- a/lldb/tools/lldb-dap/DAP.cpp +++++ b/lldb/tools/lldb-dap/DAP.cpp ++@@ -711,12 +711,12 @@ ++ [](const std::string &message) -> llvm::StringRef { ++ return message; ++ }, ++- [](const protocol::Response::Message &message) +++ [](const protocol::ResponseMessage &message) ++ -> llvm::StringRef { ++ switch (message) { ++- case protocol::Response::Message::cancelled: +++ case protocol::eResponseMessageCancelled: ++ return "cancelled"; ++- case protocol::Response::Message::notStopped: +++ case protocol::eResponseMessageNotStopped: ++ return "notStopped"; ++ } ++ }), ++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp ++--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp ++@@ -7,6 +7,7 @@ ++ //===----------------------------------------------------------------------===// ++ ++ #include "Protocol/ProtocolBase.h" +++#include "lldb/lldb-enumerations.h" ++ #include "llvm/ADT/StringRef.h" ++ #include "llvm/ADT/StringSwitch.h" ++ #include "llvm/Support/ErrorHandling.h" ++@@ -31,11 +32,8 @@ ++ ++ namespace lldb_dap::protocol { ++ ++-enum MessageType { ++- eMessageTypeRequest, ++- eMessageTypeResponse, ++- eMessageTypeEvent ++-}; +++FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse, +++ eMessageTypeEvent}; ++ ++ bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) { ++ auto rawType = Params.getAsString(); ++@@ -107,12 +105,12 @@ ++ ++ if (R.message) { ++ assert(!R.success && "message can only be used if success is false"); ++- if (const auto *messageEnum = std::get_if(&*R.message)) { +++ if (const auto *messageEnum = std::get_if(&*R.message)) { ++ switch (*messageEnum) { ++- case Response::Message::cancelled: +++ case eResponseMessageCancelled: ++ Result.insert({"message", "cancelled"}); ++ break; ++- case Response::Message::notStopped: +++ case eResponseMessageNotStopped: ++ Result.insert({"message", "notStopped"}); ++ break; ++ } ++@@ -129,16 +127,16 @@ ++ } ++ ++ bool fromJSON(json::Value const &Params, ++- std::variant &M, json::Path P) { +++ std::variant &M, json::Path P) { ++ auto rawMessage = 
Params.getAsString(); ++ if (!rawMessage) { ++ P.report("expected a string"); ++ return false; ++ } ++- std::optional message = ++- StringSwitch>(*rawMessage) ++- .Case("cancelled", Response::Message::cancelled) ++- .Case("notStopped", Response::Message::notStopped) +++ std::optional message = +++ StringSwitch>(*rawMessage) +++ .Case("cancelled", eResponseMessageCancelled) +++ .Case("notStopped", eResponseMessageNotStopped) ++ .Default(std::nullopt); ++ if (message) ++ M = *message; ++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h ++--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h ++@@ -20,6 +20,7 @@ ++ #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H ++ #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H ++ +++#include "lldb/lldb-enumerations.h" ++ #include "llvm/Support/JSON.h" ++ #include ++ #include ++@@ -64,15 +65,15 @@ ++ llvm::json::Value toJSON(const Event &); ++ bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); ++ ++-/// Response for a request. ++-struct Response { ++- enum class Message { +++FLAGS_ENUM(ResponseMessage){ ++ /// The request was cancelled ++- cancelled, +++ eResponseMessageCancelled, ++ /// The request may be retried once the adapter is in a 'stopped' state ++- notStopped, ++- }; +++ eResponseMessageNotStopped, +++}; ++ +++/// Response for a request. +++struct Response { ++ /// Sequence number of the corresponding request. ++ int64_t request_seq; ++ ++@@ -90,7 +91,7 @@ ++ /// Contains the raw error in short form if `success` is false. This raw error ++ /// might be interpreted by the client and is not shown in the UI. Some ++ /// predefined values exist. ++- std::optional> message; +++ std::optional> message; ++ ++ /// Contains request result if success is true and error details if success is ++ /// false. ++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h ++--- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h +++++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h ++@@ -22,6 +22,8 @@ ++ ++ #include "Protocol/ProtocolBase.h" ++ #include "Protocol/ProtocolTypes.h" +++#include "lldb/lldb-enumerations.h" +++#include "llvm/ADT/DenseSet.h" ++ #include "llvm/Support/JSON.h" ++ #include ++ #include ++@@ -55,26 +57,26 @@ ++ using DisconnectResponse = VoidResponse; ++ ++ /// Features supported by DAP clients. ++-enum ClientFeature { ++- eClientFeatureVariableType, ++- eClientFeatureVariablePaging, ++- eClientFeatureRunInTerminalRequest, ++- eClientFeatureMemoryReferences, ++- eClientFeatureProgressReporting, ++- eClientFeatureInvalidatedEvent, ++- eClientFeatureMemoryEvent, ++- /// Client supports the `argsCanBeInterpretedByShell` attribute on the ++- /// `runInTerminal` request. ++- eClientFeatureArgsCanBeInterpretedByShell, ++- eClientFeatureStartDebuggingRequest, ++- /// The client will interpret ANSI escape sequences in the display of ++- /// `OutputEvent.output` and `Variable.value` fields when ++- /// `Capabilities.supportsANSIStyling` is also enabled. ++- eClientFeatureANSIStyling, +++FLAGS_ENUM(ClientFeature){ +++ eClientFeatureVariableType, +++ eClientFeatureVariablePaging, +++ eClientFeatureRunInTerminalRequest, +++ eClientFeatureMemoryReferences, +++ eClientFeatureProgressReporting, +++ eClientFeatureInvalidatedEvent, +++ eClientFeatureMemoryEvent, +++ /// Client supports the `argsCanBeInterpretedByShell` attribute on the +++ /// `runInTerminal` request. 
+++ eClientFeatureArgsCanBeInterpretedByShell, +++ eClientFeatureStartDebuggingRequest, +++ /// The client will interpret ANSI escape sequences in the display of +++ /// `OutputEvent.output` and `Variable.value` fields when +++ /// `Capabilities.supportsANSIStyling` is also enabled. +++ eClientFeatureANSIStyling, ++ }; ++ ++ /// Format of paths reported by the debug adapter. ++-enum PathFormat { ePatFormatPath, ePathFormatURI }; +++FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; ++ ++ /// Arguments for `initialize` request. ++ struct InitializeRequestArguments { ++diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h ++--- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +++++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h ++@@ -20,6 +20,7 @@ ++ #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H ++ #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H ++ +++#include "lldb/lldb-enumerations.h" ++ #include "llvm/ADT/DenseSet.h" ++ #include "llvm/Support/JSON.h" ++ #include ++@@ -56,12 +57,8 @@ ++ }; ++ llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); ++ ++-enum ColumnType { ++- eColumnTypeString, ++- eColumnTypeNumber, ++- eColumnTypeBoolean, ++- eColumnTypeTimestamp ++-}; +++FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, +++ eColumnTypeTimestamp}; ++ ++ /// A ColumnDescriptor specifies what module attribute to show in a column of ++ /// the modules view, how to format it, and what the column’s label should be. ++@@ -90,27 +87,23 @@ ++ ++ /// Names of checksum algorithms that may be supported by a debug adapter. ++ /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. ++-enum ChecksumAlgorithm { ++- eChecksumAlgorithmMD5, ++- eChecksumAlgorithmSHA1, ++- eChecksumAlgorithmSHA256, ++- eChecksumAlgorithmTimestamp ++-}; +++FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, +++ eChecksumAlgorithmSHA256, +++ eChecksumAlgorithmTimestamp}; ++ llvm::json::Value toJSON(const ChecksumAlgorithm &); ++ ++ /// Describes one or more type of breakpoint a BreakpointMode applies to. This ++ /// is a non-exhaustive enumeration and may expand as future breakpoint types ++ /// are added. ++-enum BreakpointModeApplicability { ++- /// In `SourceBreakpoint`'s. ++- eBreakpointModeApplicabilitySource, ++- /// In exception breakpoints applied in the `ExceptionFilterOptions`. ++- eBreakpointModeApplicabilityException, ++- /// In data breakpoints requested in the `DataBreakpointInfo` request. ++- eBreakpointModeApplicabilityData, ++- /// In `InstructionBreakpoint`'s. ++- eBreakpointModeApplicabilityInstruction ++-}; +++FLAGS_ENUM(BreakpointModeApplicability){ +++ /// In `SourceBreakpoint`'s. +++ eBreakpointModeApplicabilitySource, +++ /// In exception breakpoints applied in the `ExceptionFilterOptions`. +++ eBreakpointModeApplicabilityException, +++ /// In data breakpoints requested in the `DataBreakpointInfo` request. +++ eBreakpointModeApplicabilityData, +++ /// In `InstructionBreakpoint`'s. +++ eBreakpointModeApplicabilityInstruction}; ++ llvm::json::Value toJSON(const BreakpointModeApplicability &); ++ ++ /// A `BreakpointMode` is provided as a option when setting breakpoints on ++@@ -133,101 +126,101 @@ ++ llvm::json::Value toJSON(const BreakpointMode &); ++ ++ /// Debug Adapter Features flags supported by lldb-dap. ++-enum AdapterFeature { ++- /// The debug adapter supports ANSI escape sequences in styling of ++- /// `OutputEvent.output` and `Variable.value` fields. 
++- eAdapterFeatureANSIStyling, ++- /// The debug adapter supports the `breakpointLocations` request. ++- eAdapterFeatureBreakpointLocationsRequest, ++- /// The debug adapter supports the `cancel` request. ++- eAdapterFeatureCancelRequest, ++- /// The debug adapter supports the `clipboard` context value in the ++- /// `evaluate` request. ++- eAdapterFeatureClipboardContext, ++- /// The debug adapter supports the `completions` request. ++- eAdapterFeatureCompletionsRequest, ++- /// The debug adapter supports conditional breakpoints. ++- eAdapterFeatureConditionalBreakpoints, ++- /// The debug adapter supports the `configurationDone` request. ++- eAdapterFeatureConfigurationDoneRequest, ++- /// The debug adapter supports the `asAddress` and `bytes` fields in the ++- /// `dataBreakpointInfo` request. ++- eAdapterFeatureDataBreakpointBytes, ++- /// The debug adapter supports data breakpoints. ++- eAdapterFeatureDataBreakpoints, ++- /// The debug adapter supports the delayed loading of parts of the stack, ++- /// which requires that both the `startFrame` and `levels` arguments and the ++- /// `totalFrames` result of the `stackTrace` request are supported. ++- eAdapterFeatureDelayedStackTraceLoading, ++- /// The debug adapter supports the `disassemble` request. ++- eAdapterFeatureDisassembleRequest, ++- /// The debug adapter supports a (side effect free) `evaluate` request for ++- /// data hovers. ++- eAdapterFeatureEvaluateForHovers, ++- /// The debug adapter supports `filterOptions` as an argument on the ++- /// `setExceptionBreakpoints` request. ++- eAdapterFeatureExceptionFilterOptions, ++- /// The debug adapter supports the `exceptionInfo` request. ++- eAdapterFeatureExceptionInfoRequest, ++- /// The debug adapter supports `exceptionOptions` on the ++- /// `setExceptionBreakpoints` request. ++- eAdapterFeatureExceptionOptions, ++- /// The debug adapter supports function breakpoints. ++- eAdapterFeatureFunctionBreakpoints, ++- /// The debug adapter supports the `gotoTargets` request. ++- eAdapterFeatureGotoTargetsRequest, ++- /// The debug adapter supports breakpoints that break execution after a ++- /// specified number of hits. ++- eAdapterFeatureHitConditionalBreakpoints, ++- /// The debug adapter supports adding breakpoints based on instruction ++- /// references. ++- eAdapterFeatureInstructionBreakpoints, ++- /// The debug adapter supports the `loadedSources` request. ++- eAdapterFeatureLoadedSourcesRequest, ++- /// The debug adapter supports log points by interpreting the `logMessage` ++- /// attribute of the `SourceBreakpoint`. ++- eAdapterFeatureLogPoints, ++- /// The debug adapter supports the `modules` request. ++- eAdapterFeatureModulesRequest, ++- /// The debug adapter supports the `readMemory` request. ++- eAdapterFeatureReadMemoryRequest, ++- /// The debug adapter supports restarting a frame. ++- eAdapterFeatureRestartFrame, ++- /// The debug adapter supports the `restart` request. In this case a client ++- /// should not implement `restart` by terminating and relaunching the ++- /// adapter but by calling the `restart` request. ++- eAdapterFeatureRestartRequest, ++- /// The debug adapter supports the `setExpression` request. ++- eAdapterFeatureSetExpression, ++- /// The debug adapter supports setting a variable to a value. ++- eAdapterFeatureSetVariable, ++- /// The debug adapter supports the `singleThread` property on the execution ++- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, ++- /// `stepBack`). 
++- eAdapterFeatureSingleThreadExecutionRequests, ++- /// The debug adapter supports stepping back via the `stepBack` and ++- /// `reverseContinue` requests. ++- eAdapterFeatureStepBack, ++- /// The debug adapter supports the `stepInTargets` request. ++- eAdapterFeatureStepInTargetsRequest, ++- /// The debug adapter supports stepping granularities (argument ++- /// `granularity`) for the stepping requests. ++- eAdapterFeatureSteppingGranularity, ++- /// The debug adapter supports the `terminate` request. ++- eAdapterFeatureTerminateRequest, ++- /// The debug adapter supports the `terminateThreads` request. ++- eAdapterFeatureTerminateThreadsRequest, ++- /// The debug adapter supports the `suspendDebuggee` attribute on the ++- /// `disconnect` request. ++- eAdapterFeatureSuspendDebuggee, ++- /// The debug adapter supports a `format` attribute on the `stackTrace`, ++- /// `variables`, and `evaluate` requests. ++- eAdapterFeatureValueFormattingOptions, ++- /// The debug adapter supports the `writeMemory` request. ++- eAdapterFeatureWriteMemoryRequest, ++- /// The debug adapter supports the `terminateDebuggee` attribute on the ++- /// `disconnect` request. ++- eAdapterFeatureTerminateDebuggee, +++FLAGS_ENUM(AdapterFeature){ +++ /// The debug adapter supports ANSI escape sequences in styling of +++ /// `OutputEvent.output` and `Variable.value` fields. +++ eAdapterFeatureANSIStyling, +++ /// The debug adapter supports the `breakpointLocations` request. +++ eAdapterFeatureBreakpointLocationsRequest, +++ /// The debug adapter supports the `cancel` request. +++ eAdapterFeatureCancelRequest, +++ /// The debug adapter supports the `clipboard` context value in the +++ /// `evaluate` request. +++ eAdapterFeatureClipboardContext, +++ /// The debug adapter supports the `completions` request. +++ eAdapterFeatureCompletionsRequest, +++ /// The debug adapter supports conditional breakpoints. +++ eAdapterFeatureConditionalBreakpoints, +++ /// The debug adapter supports the `configurationDone` request. +++ eAdapterFeatureConfigurationDoneRequest, +++ /// The debug adapter supports the `asAddress` and `bytes` fields in the +++ /// `dataBreakpointInfo` request. +++ eAdapterFeatureDataBreakpointBytes, +++ /// The debug adapter supports data breakpoints. +++ eAdapterFeatureDataBreakpoints, +++ /// The debug adapter supports the delayed loading of parts of the stack, +++ /// which requires that both the `startFrame` and `levels` arguments and the +++ /// `totalFrames` result of the `stackTrace` request are supported. +++ eAdapterFeatureDelayedStackTraceLoading, +++ /// The debug adapter supports the `disassemble` request. +++ eAdapterFeatureDisassembleRequest, +++ /// The debug adapter supports a (side effect free) `evaluate` request for +++ /// data hovers. +++ eAdapterFeatureEvaluateForHovers, +++ /// The debug adapter supports `filterOptions` as an argument on the +++ /// `setExceptionBreakpoints` request. +++ eAdapterFeatureExceptionFilterOptions, +++ /// The debug adapter supports the `exceptionInfo` request. +++ eAdapterFeatureExceptionInfoRequest, +++ /// The debug adapter supports `exceptionOptions` on the +++ /// `setExceptionBreakpoints` request. +++ eAdapterFeatureExceptionOptions, +++ /// The debug adapter supports function breakpoints. +++ eAdapterFeatureFunctionBreakpoints, +++ /// The debug adapter supports the `gotoTargets` request. +++ eAdapterFeatureGotoTargetsRequest, +++ /// The debug adapter supports breakpoints that break execution after a +++ /// specified number of hits. 
+++ eAdapterFeatureHitConditionalBreakpoints, +++ /// The debug adapter supports adding breakpoints based on instruction +++ /// references. +++ eAdapterFeatureInstructionBreakpoints, +++ /// The debug adapter supports the `loadedSources` request. +++ eAdapterFeatureLoadedSourcesRequest, +++ /// The debug adapter supports log points by interpreting the `logMessage` +++ /// attribute of the `SourceBreakpoint`. +++ eAdapterFeatureLogPoints, +++ /// The debug adapter supports the `modules` request. +++ eAdapterFeatureModulesRequest, +++ /// The debug adapter supports the `readMemory` request. +++ eAdapterFeatureReadMemoryRequest, +++ /// The debug adapter supports restarting a frame. +++ eAdapterFeatureRestartFrame, +++ /// The debug adapter supports the `restart` request. In this case a client +++ /// should not implement `restart` by terminating and relaunching the +++ /// adapter but by calling the `restart` request. +++ eAdapterFeatureRestartRequest, +++ /// The debug adapter supports the `setExpression` request. +++ eAdapterFeatureSetExpression, +++ /// The debug adapter supports setting a variable to a value. +++ eAdapterFeatureSetVariable, +++ /// The debug adapter supports the `singleThread` property on the execution +++ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, +++ /// `stepBack`). +++ eAdapterFeatureSingleThreadExecutionRequests, +++ /// The debug adapter supports stepping back via the `stepBack` and +++ /// `reverseContinue` requests. +++ eAdapterFeatureStepBack, +++ /// The debug adapter supports the `stepInTargets` request. +++ eAdapterFeatureStepInTargetsRequest, +++ /// The debug adapter supports stepping granularities (argument +++ /// `granularity`) for the stepping requests. +++ eAdapterFeatureSteppingGranularity, +++ /// The debug adapter supports the `terminate` request. +++ eAdapterFeatureTerminateRequest, +++ /// The debug adapter supports the `terminateThreads` request. +++ eAdapterFeatureTerminateThreadsRequest, +++ /// The debug adapter supports the `suspendDebuggee` attribute on the +++ /// `disconnect` request. +++ eAdapterFeatureSuspendDebuggee, +++ /// The debug adapter supports a `format` attribute on the `stackTrace`, +++ /// `variables`, and `evaluate` requests. +++ eAdapterFeatureValueFormattingOptions, +++ /// The debug adapter supports the `writeMemory` request. +++ eAdapterFeatureWriteMemoryRequest, +++ /// The debug adapter supports the `terminateDebuggee` attribute on the +++ /// `disconnect` request. +++ eAdapterFeatureTerminateDebuggee, ++ }; ++ ++ /// Information about the capabilities of a debug adapter. ++@@ -268,10 +261,10 @@ ++ }; ++ llvm::json::Value toJSON(const Capabilities &); ++ ++-enum PresentationHint { ++- ePresentationHintNormal, ++- ePresentationHintEmphasize, ++- ePresentationHintDeemphasize, +++FLAGS_ENUM(PresentationHint){ +++ ePresentationHintNormal, +++ ePresentationHintEmphasize, +++ ePresentationHintDeemphasize, ++ }; ++ ++ /// A `Source` is a descriptor for source code. 
It is returned from the debug ++diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test ++--- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +++++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test ++@@ -1,7 +1,7 @@ ++ // Header ++ // ++ // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) ++-// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) +++// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) ++ // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) ++ // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) ++ // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) ++diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c ++--- a/offload/test/offloading/gpupgo/pgo1.c +++++ b/offload/test/offloading/gpupgo/pgo1.c ++@@ -14,7 +14,7 @@ ++ // RUN: %target_triple.%basename_t.clang.profraw | \ ++ // RUN: %fcheck-generic --check-prefix="CLANG-PGO" ++ ++-// REQUIRES: gpu +++// REQUIRES: amdgpu ++ // REQUIRES: pgo ++ ++ int test1(int a) { return a / 2; } ++diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c ++--- a/offload/test/offloading/gpupgo/pgo2.c +++++ b/offload/test/offloading/gpupgo/pgo2.c ++@@ -48,7 +48,7 @@ ++ // RUN: %target_triple.%basename_t.hfdi.profraw \ ++ // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" ++ ++-// REQUIRES: gpu +++// REQUIRES: amdgpu ++ // REQUIRES: pgo ++ ++ int main() { diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 5ceac99..725480b 100644 +index 725480b..005737a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "5eccd71ce4f852c7b2f06ecd1976d9e34040fcaa" -- LLVM_SHA256 = "fd100fd69425ebac40ed58ff0558e0064b74bd97b6023d5e65e4c706c726a483" -+ LLVM_COMMIT = "71a977d0d611f3e9f6137a6b8a26b730b2886ce9" -+ LLVM_SHA256 = "9bdf3ddf45c069248af36080a78b56d839d3aad6f9b727ec1ee1be72682888cc" +- LLVM_COMMIT = "71a977d0d611f3e9f6137a6b8a26b730b2886ce9" +- LLVM_SHA256 = "9bdf3ddf45c069248af36080a78b56d839d3aad6f9b727ec1ee1be72682888cc" ++ LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" ++ LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index 9110bac3105338..7b1a0496a0965c 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "4639651b8267094807acab8b2709ed57cb924888" - SHARDY_SHA256 = "f3fd113eef0bffe8fbab9435b22022598b56dd710999fd5cafd5c1dfdc7c91ce" + SHARDY_COMMIT = "9435b34df0279d473240f5bcc2a829d0589ae372" + SHARDY_SHA256 = "5f2a037d3301a1407e5c94778dd56d855f5abe26999cce448ccfa1923cf9559f" tf_http_archive( name = "shardy", From a985e1ea4c20be02fcd0eaa515870020b6c47709 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Sat, 29 Mar 2025 15:06:34 -0700 Subject: [PATCH 0015/1324] Switch `HloProgram` SerDes to deserialize into StableHLO and assume the input program to be StableHLO JAX has been emitting 
StableHLO for a while, so it makes sense to assume StableHLO as the only program. This avoids unnecessary MHLO conversion at deserialization time and makes sure that StableHLO is passed to the actual runtime behind IFRT Proxy. This does not change serialization format since the serialized bytes are still in VHLO. PiperOrigin-RevId: 741883550 --- third_party/xla/xla/python/ifrt/hlo/BUILD | 4 +--- .../xla/python/ifrt/hlo/hlo_program_serdes.cc | 19 ++----------------- .../ifrt/hlo/hlo_program_serdes_test.cc | 17 ++++++++--------- .../pjrt_ifrt/xla_executable_impl_test_lib.cc | 14 +++++++------- 4 files changed, 18 insertions(+), 36 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/hlo/BUILD b/third_party/xla/xla/python/ifrt/hlo/BUILD index 80b2065c2bbe70..1f477aa726854c 100644 --- a/third_party/xla/xla/python/ifrt/hlo/BUILD +++ b/third_party/xla/xla/python/ifrt/hlo/BUILD @@ -33,14 +33,12 @@ cc_library( "//xla/mlir_hlo:mhlo_passes", "//xla/pjrt:mlir_to_hlo", "//xla/python/ifrt:serdes", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", "@stablehlo//:stablehlo_serialization", ], alwayslink = True, @@ -52,7 +50,6 @@ xla_cc_test( deps = [ ":hlo_program", ":hlo_program_serdes", - "//xla/mlir_hlo", "//xla/pjrt:mlir_to_hlo", "//xla/python/ifrt:serdes", "//xla/python/ifrt:serdes_proto_cc", @@ -66,5 +63,6 @@ xla_cc_test( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/python/ifrt/hlo/hlo_program_serdes.cc b/third_party/xla/xla/python/ifrt/hlo/hlo_program_serdes.cc index b620ecc4810207..aadf22c0e171d4 100644 --- a/third_party/xla/xla/python/ifrt/hlo/hlo_program_serdes.cc +++ b/third_party/xla/xla/python/ifrt/hlo/hlo_program_serdes.cc @@ -26,14 +26,12 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" #include "stablehlo/dialect/Serialization.h" #include "xla/mlir/utils/error_util.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/serdes.h" -#include "xla/tsl/platform/statusor.h" namespace xla { namespace ifrt { @@ -78,10 +76,8 @@ class HloProgramSerDes : public llvm::RTTIExtends { llvm::cast(program.mlir_module->clone())); // Serialize portable artifact. - TF_ASSIGN_OR_RETURN(std::string serialized, - xla::SerializeUsingVersionedStablehlo( - *module, xla::GetDefaultStablehloVersion())); - return serialized; + return xla::SerializeUsingVersionedStablehlo( + *module, xla::GetDefaultStablehloVersion()); } absl::StatusOr> Deserialize( @@ -103,17 +99,6 @@ class HloProgramSerDes : public llvm::RTTIExtends { status.message())); } - // Convert StableHLO back to MHLO to keep the contract the same before and - // after a serialization/deserialization round trip. 
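// The two removed comment lines above described the old contract. The new
// contract, in miniature, using only helpers already visible in this file
// (a sketch, not the full implementation):
//
//   // Serialize: still emits versioned VHLO bytes, so the serialized
//   // format is unchanged.
//   absl::StatusOr<std::string> bytes = xla::SerializeUsingVersionedStablehlo(
//       *module, xla::GetDefaultStablehloVersion());
//
//   // Deserialize: the artifact is materialized as StableHLO and returned
//   // as-is, so callers now receive stablehlo.* ops instead of mhlo.* ops.
//
// The legalization pipeline that previously restored MHLO is what gets
// removed next: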
- mlir::PassManager pm(context.get()); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - if (!mlir::succeeded(pm.run(*module))) { - const absl::Status status = diagnostic_handler.ConsumeStatus(); - return absl::InvalidArgumentError(absl::StrCat( - "Failed to legalize StableHLO to MHLO;\n\nDetailed error from MLIR: ", - status.message())); - } - return std::make_unique(std::move(context), std::move(module)); } diff --git a/third_party/xla/xla/python/ifrt/hlo/hlo_program_serdes_test.cc b/third_party/xla/xla/python/ifrt/hlo/hlo_program_serdes_test.cc index 859ba3715261a9..e8991eecc577da 100644 --- a/third_party/xla/xla/python/ifrt/hlo/hlo_program_serdes_test.cc +++ b/third_party/xla/xla/python/ifrt/hlo/hlo_program_serdes_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/Support/DebugStringHelper.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/serdes.h" @@ -48,11 +48,10 @@ TEST(HloProgramSerDesTest, RoundTrip) { static constexpr absl::string_view kMlirModuleStr = R"( module { func.func @main(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - %0 = "mhlo.copy"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> - %1 = mhlo.constant dense<1.000000e+00> : tensor - %2 = "mhlo.broadcast"(%1) {broadcast_sizes = dense<[2, 3]> : tensor<2xi64>} : (tensor) -> tensor<2x3xf32> - %3 = mhlo.add %0, %2 : tensor<2x3xf32> - return %3 : tensor<2x3xf32> + %0 = stablehlo.constant dense<1.000000e+00> : tensor + %1 = "stablehlo.broadcast_in_dim"(%0) {broadcast_dimensions = array} : (tensor) -> tensor<2x3xf32> + %2 = stablehlo.add %arg0, %1 : tensor<2x3xf32> + return %2 : tensor<2x3xf32> } })"; @@ -72,13 +71,13 @@ module { std::unique_ptr xla_program, Deserialize(serialized, /*options=*/nullptr)); - // Verify that the deserialized program has no StableHLO ops. + // Verify that the deserialized program has no MHLO ops. bool has_unsupported_dialect = false; xla_program->mlir_module->walk([&](mlir::Operation *op) { if (!llvm::isa(op->getDialect())) { + mlir::stablehlo::StablehloDialect>(op->getDialect())) { LOG(ERROR) << "Found an op with an unsupported dialect: " - << mlir::debugString(op); + << mlir::debugString(*op); has_unsupported_dialect = true; } }); diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc index 4e577b28d2a728..5f2e0239125156 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc @@ -58,13 +58,13 @@ using ::tsl::testing::IsOkAndHolds; // Serialized `ModuleOp` that does add 1. 
static const char* const module_add_one = R"(module { -func.func @main(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - %0 = "mhlo.copy"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> - %1 = mhlo.constant dense<1.000000e+00> : tensor - %2 = "mhlo.broadcast"(%1) {broadcast_sizes = dense<[2, 3]> : tensor<2xi64>} : (tensor) -> tensor<2x3xf32> - %3 = mhlo.add %0, %2 : tensor<2x3xf32> - return %3 : tensor<2x3xf32> -}})"; + func.func @main(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %0 = stablehlo.constant dense<1.000000e+00> : tensor + %1 = "stablehlo.broadcast_in_dim"(%0) {broadcast_dimensions = array} : (tensor) -> tensor<2x3xf32> + %2 = stablehlo.add %arg0, %1 : tensor<2x3xf32> + return %2 : tensor<2x3xf32> + } +})"; // Compiles an MLIR module on specified devices. If devices is empty, compiles // it as a portable executable. From 6f86be0bb60781b92da99b89b42126984bed93da Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 29 Mar 2025 16:18:51 -0700 Subject: [PATCH 0016/1324] Automated Code Change PiperOrigin-RevId: 741892340 --- third_party/xla/xla/backends/cpu/runtime/BUILD | 3 +++ third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc | 1 + third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h | 1 + third_party/xla/xla/backends/cpu/runtime/dot_lib.cc | 1 + third_party/xla/xla/backends/cpu/runtime/dot_lib.h | 1 + third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc | 1 + 6 files changed, 8 insertions(+) diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index db57c57f9bc8ce..5759e3657150a4 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -744,6 +744,7 @@ cc_library( "//xla/service:custom_call_status", "//xla/service:custom_call_status_internal", "//xla/service:custom_call_target_registry", + "//xla/service:hlo_proto_cc", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", @@ -773,6 +774,7 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "@com_google_absl//absl/algorithm:container", @@ -1165,6 +1167,7 @@ cc_library( ":thunk", "//xla:shape_util", "//xla:status_macros", + "//xla:xla_data_proto_cc", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service/cpu:runtime_fft", diff --git a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc index 606ccc226cd0ae..9c949bee4f07ce 100644 --- a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc @@ -50,6 +50,7 @@ limitations under the License. #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/service/hlo.pb.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" diff --git a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h index 81d2b504e5992f..e52c1e8126b50a 100644 --- a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h @@ -30,6 +30,7 @@ limitations under the License. 
#include "xla/ffi/execution_state.h" #include "xla/service/buffer_assignment.h" #include "xla/service/custom_call_status.h" +#include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc b/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc index 64d4e917fd8946..9d8c96bf80ba1b 100644 --- a/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla::cpu { diff --git a/third_party/xla/xla/backends/cpu/runtime/dot_lib.h b/third_party/xla/xla/backends/cpu/runtime/dot_lib.h index 1c7bbf9727a260..00a29d7d2fc20f 100644 --- a/third_party/xla/xla/backends/cpu/runtime/dot_lib.h +++ b/third_party/xla/xla/backends/cpu/runtime/dot_lib.h @@ -23,6 +23,7 @@ limitations under the License. #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/shape.h" +#include "xla/xla_data.pb.h" namespace xla::cpu { diff --git a/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc index b2bf851485ffe3..a19e5cb49a7e17 100644 --- a/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" namespace xla::cpu { From d36cf74b3b23c201b6170d587be8818c986c8f52 Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Sat, 29 Mar 2025 22:15:27 -0700 Subject: [PATCH 0017/1324] Add send / recv callback support PiperOrigin-RevId: 741943284 --- third_party/xla/xla/pjrt/gpu/tfrt/BUILD | 5 +- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 294 +++++++++++++++++- .../xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc | 97 ++++++ 3 files changed, 379 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD index 86599720689473..3f863f27d1dfc3 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD +++ b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD @@ -36,6 +36,7 @@ cc_library( "//xla/client:local_client", "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", + "//xla/pjrt:host_callback", "//xla/pjrt:host_memory_spaces", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", @@ -133,6 +134,7 @@ xla_cc_test( "//xla:literal", "//xla:literal_util", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla/hlo/builder:xla_computation", "//xla/hlo/parser:hlo_parser", @@ -147,8 +149,9 @@ xla_cc_test( "//xla/tsl/concurrency:async_value", # copybara:uncomment "//xla/tsl/framework:allocator", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", "//xla/tsl/platform:statusor", - # copybara:uncomment "//tensorflow/core:framework", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:platform_port", ], diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc index 4ef8e7fc67c88d..65b8ffec8b34f7 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc @@ -61,6 +61,7 @@ limitations under the License. 
#include "xla/pjrt/gpu/tfrt/host_memory_allocator.h" #include "xla/pjrt/gpu/tfrt/stream_pool.h" #include "xla/pjrt/gpu/tfrt/tracked_tfrt_gpu_device_buffer.h" +#include "xla/pjrt/host_callback.h" #include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" @@ -608,6 +609,239 @@ absl::flat_hash_map GetAttrsForDevices( return attrs; } +class TfrtGpuCopyToDeviceStream : public CopyToDeviceStream { + public: + TfrtGpuCopyToDeviceStream(int64_t channel_id, se::Stream* stream, + se::DeviceMemoryBase dst, + tsl::AsyncValueRef> done) + : CopyToDeviceStream(dst.size(), /*granule_bytes=*/1), + channel_id_(channel_id), + stream_(stream), + dst_(dst), + done_(std::move(done)) {} + + PjRtFuture<> AddChunk(PjRtChunk chunk) final { + tsl::profiler::TraceMe trace([&] { + return tsl::profiler::TraceMeEncode("TfrtGpuCopyToDeviceStream::AddChunk", + {{"channel_id", channel_id_}}); + }); + + absl::ReleasableMutexLock lock(&mu_); + + VLOG(3) << "Add chunk to a H2D channel #" << channel_id_ << ": " + << "size=" << chunk.size() << ", " + << "current_bytes=" << current_bytes_ << ", " + << "total_bytes=" << total_bytes_; + + if (chunk.size() % granule_size_in_bytes() != 0) { + done_.SetError(absl::InvalidArgumentError(absl::StrFormat( + "Chunk size (%d) was not a multiple of the granule size (%d)", + chunk.size(), granule_size_in_bytes()))); + return PjRtFuture<>(done_.GetError()); + } + + if (current_bytes_ + chunk.size() > total_bytes_) { + done_.SetError(absl::InvalidArgumentError( + absl::StrFormat("Adding chunk of size %d would overflow buffer of " + "size %d (%d already transferred)", + chunk.size(), total_bytes_, current_bytes_))); + return PjRtFuture<>(done_.GetError()); + } + + se::DeviceMemoryBase dst( + reinterpret_cast(dst_.opaque()) + current_bytes_, + dst_.size() - current_bytes_); + + current_bytes_ += chunk.size(); + bool complete = IsCompleteLocked(); + lock.Release(); + + auto copied = stream_->Memcpy(&dst, chunk.data(), chunk.size()); + if (!copied.ok()) { + done_.SetError(copied); + return PjRtFuture<>(done_.GetError()); + } + + // Delete chunk once the memcpy operation completes. + auto* chunk_ptr = std::make_unique(std::move(chunk)).release(); + auto deleted = stream_->DoHostCallback([chunk_ptr]() { delete chunk_ptr; }); + if (!deleted.ok()) { + done_.SetError(deleted); + return PjRtFuture<>(done_.GetError()); + } + + // Record done event once processed the last chunk. It is the caller + // responsibility to synchronize with this event before submitting any new + // computations to the stream. + if (complete) { + auto recorded = stream_->RecordEvent(done_.get().get()); + if (!recorded.ok()) { + done_.SetError(recorded); + return PjRtFuture<>(done_.GetError()); + } + done_.SetStateConcrete(); + } + + return PjRtFuture<>(absl::OkStatus()); + } + + private: + int64_t channel_id_; + se::Stream* stream_; + se::DeviceMemoryBase dst_; + + // Async value will become available after we'll submit the last memcpy + // operation, and the event will be recorded on the stream. + tsl::AsyncValueRef> done_; +}; + +template +const T* FindCallback(int channel_id, absl::Span callbacks) { + // TODO(ezhulenev): Can we use binary search here assuming that callbacks + // are sorted by channel id? Are they always sorted? + auto it = absl::c_find_if(callbacks, [&](const T& callback) { + return callback.channel_id == channel_id; + }); + return it == callbacks.end() ? nullptr : &*it; +} + +// Converts PjRt SendCallbacks to an XLA StreamExecutor send function. 
+SendDeviceMemoryFunction ConvertSendCallbacksToSendFunction( + int replica, const ExecuteOptions& options, + tsl::thread::ThreadPool* thread_pool) { + // Check if we have callbacks registered for the given replica. + if (replica >= options.send_callbacks.size()) { + return [replica](int64_t channel_id, se::Stream*, const Shape&, + const se::DeviceMemoryBase&, + const absl::flat_hash_map&) { + return Internal( + "Don't send a buffer to the channel_id=%d, there was no send " + "callbacks registered for the replica=%d", + channel_id, replica); + }; + } + + // SendCallbacks registered for a device ordinal. Can be empty. + absl::Span callbacks = options.send_callbacks[replica]; + + return [callbacks, thread_pool]( + int64_t channel_id, se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& src, + const absl::flat_hash_map&) + -> absl::StatusOr>> { + VLOG(3) << "Send " << src.size() << " bytes to channel #" << channel_id + << " (shape=" << shape.ToString() << ")"; + + const SendCallback* send = FindCallback(channel_id, callbacks); + if (!send) { + return InvalidArgument( + "Failed to send a buffer to the channel_id=%d, callback not found", + channel_id); + } + + // Allocate event that will signal completion of send operation. We do not + // actually track the completion of the send callback, we only have to keep + // the device memory long enough to complete the memcpy command. + TF_ASSIGN_OR_RETURN(auto se_event, stream->parent()->CreateEvent()); + auto done_event = + tsl::MakeConstructedAsyncValueRef>( + std::move(se_event)); + + thread_pool->Schedule([done_event, stream, src, channel_id, shape, send] { + tsl::profiler::TraceMe trace([&] { + return tsl::profiler::TraceMeEncode("TfrtGpuExecutable::Send", + {{"channel_id", channel_id}}); + }); + + // Allocate chunk on the host for copying data from device. + PjRtChunk chunk = PjRtChunk::AllocateDefault(src.size()); + + auto status = stream->Memcpy(chunk.data(), src, src.size()); + if (!status.ok()) { + done_event.SetError(status); + return; + } + status = stream->RecordEvent(done_event.get().get()); + if (!status.ok()) { + done_event.SetError(status); + return; + } + + // Wait for the data to be available on the host. + if (auto st = stream->BlockHostUntilDone(); !st.ok()) { + done_event.SetError(absl::InternalError(absl::StrFormat( + "failed to synchronize send operation with a stream: %s", + st.message()))); + return; + } + + // Pass chunk to the registered callback. + auto sent = send->callback({shape}, std::move(chunk), + /*total_size_in_bytes=*/src.size(), + /*done=*/true); + + if (!sent.ok()) { + done_event.SetError(sent); + } else { + done_event.SetStateConcrete(); + } + }); + + return std::move(done_event); + }; +} + +RecvDeviceMemoryFunction ConvertRecvCallbacksToRecvFunction( + int replica, const ExecuteOptions& options) { + // Check if we have callbacks registered for the given replica. + if (replica >= options.send_callbacks.size()) { + return [replica](int64_t channel_id, se::Stream*, const Shape&, + se::DeviceMemoryBase*, + const absl::flat_hash_map&) { + return InvalidArgument( + "Failed to receive a buffer from the channel_id=%d, there was no " + "recv callbacks registered for the replica=%d", + channel_id, replica); + }; + } + + // RecvCallbacks registered for a device ordinal. Can be empty. 
+ absl::Span callbacks = options.recv_callbacks[replica]; + + return [callbacks](int64_t channel_id, se::Stream* stream, const Shape& shape, + se::DeviceMemoryBase* dst, + const absl::flat_hash_map&) + -> absl::StatusOr>> { + VLOG(3) << "Recv from channel #" << channel_id + << " (shape=" << shape.ToString() << ")"; + + tsl::profiler::TraceMe trace([&] { + return tsl::profiler::TraceMeEncode("TfrtGpuExecutable::Recv", + {{"channel_id", channel_id}}); + }); + + const RecvCallback* recv = FindCallback(channel_id, callbacks); + if (!recv) { + return InvalidArgument( + "Failed to recv a buffer from the channel_id=%d, callback not found", + channel_id); + } + + // Allocate event that will signal completion of recv operation. We record + // it on a stream after submitting the memcpy for the last chunk (see + // `TfrtGpuCopyToDeviceStream` implementation above). + TF_ASSIGN_OR_RETURN(auto event, stream->parent()->CreateEvent()); + auto done_event = + tsl::MakeConstructedAsyncValueRef>( + std::move(event)); + + recv->callback({shape}, std::make_unique( + channel_id, stream, *dst, done_event)); + + return std::move(done_event); + }; +} + } // namespace TfrtGpuMemorySpace::TfrtGpuMemorySpace(int id, PjRtDevice* device, @@ -2198,8 +2432,11 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( tsl::AsyncValueRef last_collective_launch_event, bool fill_future, TfrtGpuDevice* device) { tsl::profiler::TraceMe traceme("TfrtGpuExecutable::ExecuteHelper"); - VLOG(2) << "ExecuteHelper " << name() << ": " << options.launch_id - << "; replica: " << replica; + if (VLOG_IS_ON(2)) { + LOG(INFO) << "ExecuteHelper " << name() << ": " << options.launch_id + << "; replica: " << replica << "; partition: " << partition + << "; mapped to device ordinal for execution: " << device->id(); + } std::shared_ptr device_assignment; if (device == nullptr) { @@ -2235,6 +2472,9 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( } } + // SPMD sharding produces a single executable for multiple partitions. + int executable_idx = executables_.size() > 1 ? partition : 0; + // `execute_event` indicates whether gpu computation is complete and whether // there was an error. auto execute_event = tsl::MakeConstructedAsyncValueRef(); @@ -2255,9 +2495,9 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( std::vector> input_deps; input_deps.reserve(argument_handles.size() + 1); - // TODO(b/382117736): Support multiple devices. 
- CHECK_EQ(parameters_that_must_be_donated_.size(), 1); - auto donate_it = parameters_that_must_be_donated_[0].begin(); + absl::Span donated_params = + parameters_that_must_be_donated_[executable_idx]; + auto donate_it = donated_params.begin(); for (int i = 0; i < argument_handles.size(); ++i) { PjRtBuffer* handle = argument_handles[i]; @@ -2269,9 +2509,10 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( i, replica, tfrt_buffer->device()->DebugString(), device->DebugString()); } - - bool must_donate = donate_it != parameters_that_must_be_donated_[0].end() && - *donate_it == i; + bool donation_denied_at_runtime = + options.non_donatable_input_indices.contains(i); + bool must_donate = donate_it != donated_params.end() && *donate_it == i && + !donation_denied_at_runtime; TrackedTfrtGpuDeviceBuffer* tracked_buffer = nullptr; if (must_donate) { ++donate_it; @@ -2322,7 +2563,8 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( std::vector> output_buffers; std::vector> outputs; - Executable* executable = executables_[0]->executable(); + auto gpu_executable = executables_[executable_idx]; + Executable* executable = gpu_executable->executable(); const Shape& result_shape = executable->result_shape(); bool untuple_result = options.untuple_result; bool result_is_tuple = result_shape.IsTuple(); @@ -2391,16 +2633,28 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( << options.launch_id << "; replica: " << replica; auto ffi_context = options.context != nullptr ? &options.context->ffi_context() : nullptr; + + // Create a PjRt<->StreamExecutor adaptors to send/recv device memory as + // PjRt chunks via the user-provided callbacks. + SendDeviceMemoryFunction send_device_memory = + ConvertSendCallbacksToSendFunction(replica, options, + client_->non_blocking_thread_pool()); + RecvDeviceMemoryFunction recv_device_memory = + ConvertRecvCallbacksToRecvFunction(replica, options); + auto execute_fn = [replica, partition, device, launch_id(options.launch_id), run_id(run_id), output_buffers(output_buffers), execute_event(execute_event.CopyRef()), untuple_result(untuple_result), result_is_tuple(result_is_tuple), compute_reservation(std::move(compute_reservation)), donation_transactions(std::move(donation_transactions)), - parameter_shapes(on_device_executable_parameter_shapes_[0]), - gpu_executable(executables_[0]), device_assignment(device_assignment), - executable_name(name()), ffi_context(ffi_context), - inputs_avs(CopyAsyncValues(input_deps)), + parameter_shapes(on_device_executable_parameter_shapes_[executable_idx]), + gpu_executable(std::move(gpu_executable)), + device_assignment(device_assignment), executable_name(name()), + ffi_context(ffi_context), inputs_avs(CopyAsyncValues(input_deps)), + execution_profile(options.execution_profile), + send_device_memory(std::move(send_device_memory)), + recv_device_memory(std::move(recv_device_memory)), client = client_](std::vector execution_inputs) mutable { VLOG(0) << "execute_fn for " << executable_name << ": " << launch_id << "; replica: " << replica; @@ -2447,6 +2701,13 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( run_options.set_physical_device_ordinal( device->local_hardware_id().value()); run_options.set_ffi_execution_context(ffi_context); + run_options.set_intra_op_thread_pool( + client->xla_client() + ->backend() + .eigen_intra_op_thread_pool_device()); + run_options.set_send_device_memory_function(&send_device_memory); + run_options.set_recv_device_memory_function(&recv_device_memory); + run_options.set_execution_profile(execution_profile); 
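          // With send_device_memory_function and recv_device_memory_function
          // installed above, the host-transfer hooks are in place: when the
          // compiled program reaches a host Send/Recv, the StreamExecutor
          // runtime looks these functions up on the run options and calls
          // back into the adaptors built by
          // ConvertSendCallbacksToSendFunction and
          // ConvertRecvCallbacksToRecvFunction, which in turn invoke the
          // user-supplied PjRt callbacks.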
VLOG(2) << "launch id for " << executable_name << ": " << run_options.launch_id(); @@ -2504,10 +2765,11 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( execute_event(execute_event.CopyRef()), output_buffers(std::move(output_buffers)), execute_fn(std::move(execute_fn)), input_deps(std::move(input_deps)), - parameter_shapes(on_device_executable_parameter_shapes_[0]), + parameter_shapes(on_device_executable_parameter_shapes_[executable_idx]), parameter_is_tupled_arguments(parameter_is_tupled_arguments_), arguments_are_tupled(options.arguments_are_tupled), - input_buffer_sizes_in_bytes(input_buffer_sizes_in_bytes_[0])]() mutable { + input_buffer_sizes_in_bytes( + input_buffer_sizes_in_bytes_[executable_idx])]() mutable { tsl::profiler::TraceMe traceme("prepare_inputs"); VLOG(2) << "prepare_inputs"; @@ -2635,7 +2897,7 @@ TfrtGpuExecutable::Execute( if (returned_futures.has_value()) { returned_futures->resize(num_addressable_devices); } - if (num_addressable_devices == 1) { + if (num_addressable_devices == 1 && !ThisThreadIsInsideHostCallback()) { // Fast-path if there is only one device — run the computation on the // current thread. const int replica = addressable_device_logical_ids_[0].replica; diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc index c9655e2dec6530..5af05acbd64c5f 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include #include @@ -51,10 +52,13 @@ limitations under the License. #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/framework/allocator.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "tsl/platform/casts.h" @@ -90,6 +94,48 @@ absl::StatusOr> CompileExecutable( return client.CompileAndLoad(xla_computation, compile_options); } +// Given the result of a PjrtExecutable::Execute call (TF-status of vectors of +// vectors), extract the zeroth result from the zeroth device. 
+absl::StatusOr> ExtractSingleResult( + absl::StatusOr>>>& + result) { + TF_RETURN_IF_ERROR(result.status()); + TF_RET_CHECK(result->size() == 1); + std::vector>& result_buffers = (*result)[0]; + TF_RET_CHECK(result_buffers.size() == 1); + auto literal_or = result_buffers[0]->ToLiteralSync(); + if (!literal_or.status().ok()) return literal_or.status(); + return *literal_or; +} + +static constexpr char const* kProgram = R"(HloModule HostTransfer +ENTRY SendRecvSynchronous() -> f32[2] { + in_chain = token[] after-all() + + data = f32[2] constant({2, 3}) + send = (f32[2], u32[], token[]) send(data, in_chain), + channel_id=1, + is_host_transfer=true, + frontend_attributes={ + _xla_host_transfer_handler_name="undef", + _xla_host_transfer_rendezvous="undef" + } + send-done = token[] send-done(send), + channel_id=1, is_host_transfer=true + + recv = (f32[2], u32[], token[]) recv(send-done), + channel_id=2, + is_host_transfer=true, + frontend_attributes={ + _xla_host_transfer_handler_name="undef", + _xla_host_transfer_rendezvous="undef" + } + recv-done = (f32[2], token[]) recv-done(recv), + channel_id=2, is_host_transfer=true + + ROOT result = f32[2] get-tuple-element(recv-done), index=0 +})"; + TEST(TfrtGpuClientTest, GpuClientOptions) { GpuClientOptions options; options.platform_name = "cuda"; @@ -167,6 +213,57 @@ ENTRY %Add.6 (a.1: f32[], b.2: f32[]) -> (f32[], f32[]) { EXPECT_EQ(result[0][0]->GetReadyFuture().Await(), input_error); } +TEST(TfrtGpuClientTest, SendRecvChunked) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kProgram, *client)); + + std::array sent_value = {0.0f, 0.0f}; + + // Send buffer to host. + SendCallback send_callback = { + /*channel_id=*/1, [&](const PjRtTransferMetadata& m, PjRtChunk chunk, + int64_t total_size_in_bytes, bool done) { + float* data = reinterpret_cast(chunk.data()); + sent_value[0] = data[0]; + sent_value[1] = data[1]; + return absl::OkStatus(); + }}; + + // Recv buffer from host. + RecvCallback recv_callback = { + /*channel_id=*/2, [&](const PjRtTransferMetadata& m, + std::unique_ptr stream) { + auto chunk0 = PjRtChunk::AllocateDefault(sizeof(float)); + *reinterpret_cast(chunk0.data()) = 5.0f; + TF_CHECK_OK(stream->AddChunk(std::move(chunk0)).Await()); + + auto chunk1 = PjRtChunk::AllocateDefault(sizeof(float)); + *reinterpret_cast(chunk1.data()) = 6.0f; + TF_CHECK_OK(stream->AddChunk(std::move(chunk1)).Await()); + + return absl::OkStatus(); + }}; + + // Callbacks for point-to-point communication ops. + std::vector> send_callbacks = {{send_callback}}; + std::vector> recv_callbacks = {{recv_callback}}; + + ExecuteOptions opts; + opts.send_callbacks = send_callbacks; + opts.recv_callbacks = recv_callbacks; + + auto result = executable->Execute(/*argument_handles=*/{{}}, opts); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr result_literal, + ExtractSingleResult(result)); + EXPECT_EQ(sent_value[0], 2.0f); + EXPECT_EQ(sent_value[1], 3.0f); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({5.0f, 6.0f}), + *result_literal)); +} + TEST(TfrtGpuClientTest, AcquireDonation) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); ASSERT_GE(client->devices().size(), 1); From eef216dc0e89d03eb4ae409ae873a0db42c10c95 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Sun, 30 Mar 2025 00:06:05 -0700 Subject: [PATCH 0018/1324] Remove unused hlo_module_util include from hlo_runner_interface.h. 
PiperOrigin-RevId: 741957652 --- third_party/xla/xla/service/BUILD | 1 - third_party/xla/xla/service/hlo_runner_interface.h | 1 - 2 files changed, 2 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 456dec8e8a7587..fa73abdc3a8c95 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4734,7 +4734,6 @@ cc_library( hdrs = ["hlo_runner_interface.h"], deps = [ ":computation_placer", - ":hlo_module_util", "//xla:literal", "//xla:shape_util", "//xla:util", diff --git a/third_party/xla/xla/service/hlo_runner_interface.h b/third_party/xla/xla/service/hlo_runner_interface.h index d4efe55e24284d..f58290dfd2cb9a 100644 --- a/third_party/xla/xla/service/hlo_runner_interface.h +++ b/third_party/xla/xla/service/hlo_runner_interface.h @@ -33,7 +33,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/service/computation_placer.h" -#include "xla/service/hlo_module_util.h" #include "xla/shape.h" #include "xla/util.h" #include "xla/xla_data.pb.h" From 75f927670d1df6f970ac228c9362385deafa600c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 30 Mar 2025 02:02:39 -0700 Subject: [PATCH 0019/1324] compat: Update forward compatibility horizon to 2025-03-30 PiperOrigin-RevId: 741975469 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index c18dd57b46e9c3..8d378effc9366b 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 3, 29) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 3, 30) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 6557278b9b372e479f1171721eda34857c32cf4f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 30 Mar 2025 02:02:45 -0700 Subject: [PATCH 0020/1324] Update GraphDef version to 2182. PiperOrigin-RevId: 741975488 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 2e0cb08f8b384a..6455563109606c 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2181 // Updated: 2025/3/29 +#define TF_GRAPH_DEF_VERSION 2182 // Updated: 2025/3/30 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 874e7a64680a99584e208ae27a6916c99d2e3c93 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 30 Mar 2025 14:52:29 -0700 Subject: [PATCH 0021/1324] Automated Code Change PiperOrigin-RevId: 742074897 --- third_party/xla/xla/service/gpu/resource_requests.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/resource_requests.cc b/third_party/xla/xla/service/gpu/resource_requests.cc index ff9b4039282aa6..d864d439d67cbe 100644 --- a/third_party/xla/xla/service/gpu/resource_requests.cc +++ b/third_party/xla/xla/service/gpu/resource_requests.cc @@ -28,7 +28,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/backends/gpu/collectives/gpu_clique.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" From 2092694c873aebaab4377749d85846feace8e670 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 30 Mar 2025 15:45:14 -0700 Subject: [PATCH 0022/1324] Conceptually, a shape is a union of several cases: an invalid shape, a token, an opaque, an array, or a tuple. These cases are mutually exclusive (a shape must be exactly one of these cases.) Currently, the exclusivity of the shape cases is not enforced, allowing bugs where a shape is used as the wrong case (e.g. accessing the dimension fields of a tuple shape). This change makes `Shape` safer by representing its state as an `std::variant<>`. By construction, this guarantees that it can hold the state for just one case, and trying to access the state for another case would result in a crash. Given the large number of existing bugs uncovered by the intended change, this change is deliberately conservative: in some cases where the caller violates the precondition of a `Shape` method, we choose to return a default value instead of crashing (e.g. calling `dimensions()` on a non-array shape would return an empty span even though it's a programmer error). We will tighten the enforcement of the preconditions in future CLs. PiperOrigin-RevId: 742081360 --- .../client/executable_build_options_test.cc | 4 + .../xla/xla/hlo/builder/xla_builder.cc | 8 +- third_party/xla/xla/layout_util.cc | 31 +- third_party/xla/xla/service/hlo_cse.cc | 6 +- third_party/xla/xla/shape.cc | 265 +++++++++---- third_party/xla/xla/shape.h | 371 ++++++++++++++---- third_party/xla/xla/shape_util.cc | 89 +++-- third_party/xla/xla/shape_util.h | 1 + 8 files changed, 555 insertions(+), 220 deletions(-) diff --git a/third_party/xla/xla/client/executable_build_options_test.cc b/third_party/xla/xla/client/executable_build_options_test.cc index d7529ccd7eeaa1..34cafb8c81af2d 100644 --- a/third_party/xla/xla/client/executable_build_options_test.cc +++ b/third_party/xla/xla/client/executable_build_options_test.cc @@ -52,8 +52,12 @@ std::unique_ptr ProcessNewEnv( TEST(ExecutableBuildOptionsTest, ProtoRoundTripWorks) { ExecutableBuildOptionsProto p; p.set_device_ordinal(1); + + // Set result_layout to an array shape. 
+ p.mutable_result_layout()->set_element_type(PrimitiveType::F32); p.mutable_result_layout()->add_dimensions(2); p.mutable_result_layout()->add_is_dynamic_dimension(true); + { CompilationEnvironments::RegisterProcessNewEnvFn( test::TestCompilationEnvironment1::descriptor(), ProcessNewEnv); diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index 9c7ee050eda4b8..f4f94586a89eda 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -3079,9 +3079,11 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state)); - Shape output_shape = shape; - output_shape.set_element_type(PRIMITIVE_TYPE_INVALID); - if (primitive_util::IsArrayType(shape.element_type())) { + Shape output_shape; // An invalid shape by default. + if (shape.IsArray()) { + // Make output_shape the same as the input shape, but with an unsigned + // integral type. + output_shape = shape; output_shape.set_element_type( primitive_util::UnsignedIntegralTypeForBitWidth( primitive_util::BitWidth(shape.element_type()))); diff --git a/third_party/xla/xla/layout_util.cc b/third_party/xla/xla/layout_util.cc index 6e4ad8f06f9ff9..e75110a93f7506 100644 --- a/third_party/xla/xla/layout_util.cc +++ b/third_party/xla/xla/layout_util.cc @@ -180,14 +180,12 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { for (auto& element_shape : *shape->mutable_tuple_shapes()) { SetToDefaultLayout(&element_shape); } - shape->clear_layout(); } else if (shape->IsArray()) { auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); } else { // Opaque, token types etc. have no layout. - shape->clear_layout(); } } @@ -207,10 +205,6 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { /* static */ absl::Status LayoutUtil::ValidateLayoutInShape( const Shape& shape, bool allow_missing_layouts) { if (shape.IsTuple()) { - // Tuple shape. - if (shape.has_layout()) { - return InvalidArgument("tuple should not have a layout field"); - } for (auto& element_shape : shape.tuple_shapes()) { TF_RETURN_IF_ERROR( ValidateLayoutInShape(element_shape, allow_missing_layouts)); @@ -227,11 +221,6 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { return ValidateLayoutForShape(shape.layout(), shape); } else { // Token, opaque, etc. shape. 
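// Under the variant-based representation (see the shape.cc diff below), a
// token or opaque shape cannot even store a layout, so validation of the
// kind removed next is structurally impossible to violate. The idea in
// miniature; the accessor names match this change, while the struct layout
// and variant declaration are a sketch:
//
//   // Sketched per-case state; field types are assumptions.
//   struct ArrayState {
//     absl::InlinedVector<int64_t, 6> dimensions;
//     absl::InlinedVector<bool, 6> dynamic_dimensions;
//     std::optional<Layout> layout;
//   };
//   struct TupleState {
//     std::vector<Shape> tuple_shapes;
//   };
//   // One alternative per mutually exclusive case; invalid, token, and
//   // opaque shapes carry no per-case data.
//   std::variant<std::monostate, ArrayState, TupleState> state_;
//
//   const ArrayState* if_array_state() const {
//     return std::get_if<ArrayState>(&state_);
//   }
//   ArrayState& array_state() {
//     // Dies if the shape is not an array, enforcing case exclusivity.
//     return std::get<ArrayState>(state_);
//   }
//
// The check that the variant makes redundant: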
- if (shape.has_layout()) { - return InvalidArgument( - "shape of primitive type %s should not have a layout", - PrimitiveType_Name(shape.element_type())); - } return absl::OkStatus(); } } @@ -242,14 +231,7 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (!shape.IsArray()) { - if (layout.minor_to_major_size() != 0) { - return InvalidArgument( - "shape of primitive type %s should not have a non-trivial layout", - PrimitiveType_Name(shape.element_type())); - } - return absl::OkStatus(); - } + if (!shape.IsArray()) return absl::OkStatus(); if (layout.minor_to_major_size() != shape.dimensions_size()) { return InvalidArgument( @@ -423,9 +405,12 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { } /* static */ void LayoutUtil::ClearLayout(Shape* shape) { - shape->clear_layout(); - for (auto& element_shape : *shape->mutable_tuple_shapes()) { - ClearLayout(&element_shape); + if (shape->IsArray()) { + shape->clear_layout(); + } else if (shape->IsTuple()) { + for (auto& element_shape : *shape->mutable_tuple_shapes()) { + ClearLayout(&element_shape); + } } } @@ -601,7 +586,7 @@ absl::Status CopyLayoutInternal(const Shape& src, Shape* dst) { TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i), dst->mutable_tuple_shapes(i))); } - } else { + } else if (src.IsArray()) { if (src.has_layout()) { if (src.dimensions_size() != dst->dimensions_size()) { return InvalidArgument("cannot copy layout from shape: ranks differs"); diff --git a/third_party/xla/xla/service/hlo_cse.cc b/third_party/xla/xla/service/hlo_cse.cc index 2b98151580efca..570ffa03de0edf 100644 --- a/third_party/xla/xla/service/hlo_cse.cc +++ b/third_party/xla/xla/service/hlo_cse.cc @@ -120,8 +120,10 @@ struct CseKey { template friend H AbslHashValue(H h, const CseKey& key) { auto instruction = key.hlo; - h = H::combine(std::move(h), instruction->opcode(), - instruction->shape().dimensions()); + h = instruction->shape().IsArray() + ? H::combine(std::move(h), instruction->opcode(), + instruction->shape().dimensions()) + : H::combine(std::move(h), instruction->opcode()); auto window_hash = [](H h, const Window& window) { const auto& window_dims = window.dimensions(); for (const auto& window_dim : window_dims) { diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index e455e6a96e6892..0d9feb7ad2f556 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/types/span.h" @@ -45,61 +46,71 @@ Shape::Shape(Shape&&) noexcept = default; Shape& Shape::operator=(const Shape&) = default; Shape& Shape::operator=(Shape&&) noexcept = default; -Shape::Shape(const PrimitiveType element_type) : element_type_(element_type) { - CHECK(element_type_ == TOKEN || element_type_ == OPAQUE_TYPE) +Shape::Shape(const PrimitiveType element_type) { + CHECK(element_type == TOKEN || element_type == OPAQUE_TYPE) << "Invalid element type for token or opaque shape: " << element_type_; + set_element_type(element_type); } Shape::Shape(const PrimitiveType element_type, const absl::Span dimensions, - const absl::Span dynamic_dimensions) - : element_type_(element_type), - dimensions_(dimensions.begin(), dimensions.end()), - dynamic_dimensions_(dynamic_dimensions.begin(), - dynamic_dimensions.end()) { - CHECK(primitive_util::IsArrayType(element_type_)) - << "Invalid element type for array shape: " << element_type_; - if (dynamic_dimensions_.empty()) { - // Assume all dimensions are static. - dynamic_dimensions_.resize(dimensions_.size(), false); - } else { - CHECK_EQ(dimensions_.size(), dynamic_dimensions_.size()) + const absl::Span dynamic_dimensions) { + CHECK(primitive_util::IsArrayType(element_type)) + << "Invalid element type for array shape: " << element_type; + if (!dynamic_dimensions.empty()) { + CHECK_EQ(dimensions.size(), dynamic_dimensions.size()) << "If dynamic_dimensions is provided, it must have the same size as " "dimensions."; } + + set_element_type(element_type); + auto& state = array_state(); + state.dimensions = {dimensions.begin(), dimensions.end()}; + if (dynamic_dimensions.empty()) { + // Assume all dimensions are static. + state.dynamic_dimensions.resize(dimensions.size(), false); + } else { + state.dynamic_dimensions = absl::InlinedVector( + dynamic_dimensions.begin(), dynamic_dimensions.end()); + } } -Shape::Shape(std::vector tuple_shapes) - : element_type_(TUPLE), tuple_shapes_(std::move(tuple_shapes)) {} +Shape::Shape(std::vector tuple_shapes) { + set_element_type(TUPLE); + tuple_state().tuple_shapes = std::move(tuple_shapes); +} Shape::Shape(const ShapeProto& shape_proto) { set_element_type(shape_proto.element_type()); - dimensions_.reserve(shape_proto.dimensions_size()); - for (const int64_t dimension : shape_proto.dimensions()) { - add_dimensions(dimension); - } - // A malformed proto may have different is_dynamic_dimension_size and - // dimensions_size. Since C++ is evil, and we have no good way of bailing out - // in a constructor, conservatively trim the is_dynamic_dimension size. - // TODO(b/120111794): Make this a hard error when we have a factory method - // instead of a constructor. - if (shape_proto.dimensions_size() != - shape_proto.is_dynamic_dimension_size()) { - if (shape_proto.is_dynamic_dimension_size() != 0) { - LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " - "fields does not match number of dimension fields"; - } else { - LOG(WARNING) << "Malformed shape proto: is_dynamic_dimension is empty"; + if (auto* const state = if_array_state()) { + state->dimensions.reserve(shape_proto.dimensions_size()); + for (const int64_t dimension : shape_proto.dimensions()) { + add_dimensions(dimension); + } + // A malformed proto may have different is_dynamic_dimension_size and + // dimensions_size. 
Since C++ is evil, and we have no good way of bailing + // out in a constructor, conservatively trim the is_dynamic_dimension size. + // TODO(b/120111794): Make this a hard error when we have a factory method + // instead of a constructor. + if (shape_proto.dimensions_size() != + shape_proto.is_dynamic_dimension_size()) { + if (shape_proto.is_dynamic_dimension_size() != 0) { + LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " + "fields does not match number of dimension fields"; + } else { + LOG(WARNING) << "Malformed shape proto: is_dynamic_dimension is empty"; + } + } + const int64_t num_dynamic_dimension_fields = std::min( + shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size()); + for (int i = 0; i < num_dynamic_dimension_fields; i++) { + state->dynamic_dimensions[i] = shape_proto.is_dynamic_dimension(i); + } + } else if (auto* const state = if_tuple_state()) { + state->tuple_shapes.reserve(shape_proto.tuple_shapes_size()); + for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) { + state->tuple_shapes.emplace_back(element_shape); } - } - const int64_t num_dynamic_dimension_fields = std::min( - shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size()); - for (int i = 0; i < num_dynamic_dimension_fields; i++) { - dynamic_dimensions_[i] = shape_proto.is_dynamic_dimension(i); - } - tuple_shapes_.reserve(shape_proto.tuple_shapes_size()); - for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) { - tuple_shapes_.emplace_back(element_shape); } if (shape_proto.has_layout()) { if (!IsArray()) { @@ -115,19 +126,23 @@ Shape::Shape(const ShapeProto& shape_proto) { void Shape::SetProto(ShapeProto& proto) const { proto.Clear(); proto.set_element_type(element_type_); - proto.mutable_dimensions()->Reserve(dimensions_size()); - for (const int64_t dimension : dimensions()) { - proto.add_dimensions(dimension); - } - for (const bool dynamic : dynamic_dimensions_) { - proto.add_is_dynamic_dimension(dynamic); - } - proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size()); - for (const Shape& shape : tuple_shapes()) { - shape.SetProto(*proto.add_tuple_shapes()); - } - if (has_layout()) { - layout().SetProto(*proto.mutable_layout()); + + if (const auto* const state = if_array_state()) { + proto.mutable_dimensions()->Reserve(state->dimensions.size()); + for (const int64_t dimension : state->dimensions) { + proto.add_dimensions(dimension); + } + for (const bool dynamic : state->dynamic_dimensions) { + proto.add_is_dynamic_dimension(dynamic); + } + if (state->layout.has_value()) { + state->layout->SetProto(*proto.mutable_layout()); + } + } else if (const auto* const state = if_tuple_state()) { + proto.mutable_tuple_shapes()->Reserve(state->tuple_shapes.size()); + for (const Shape& shape : state->tuple_shapes) { + shape.SetProto(*proto.add_tuple_shapes()); + } } } @@ -154,75 +169,163 @@ std::string Shape::ToString(bool print_layout) const { } bool Shape::AreAllLeavesIntegers() const { - if (IsTuple()) { - return absl::c_all_of( - tuple_shapes_, [](const Shape& s) { return s.AreAllLeavesIntegers(); }); + if (const auto* const state = if_tuple_state()) { + return absl::c_all_of(state->tuple_shapes, [](const Shape& s) { + return s.AreAllLeavesIntegers(); + }); } return primitive_util::IsIntegralType(element_type()); } bool Shape::is_static() const { - if (IsTuple()) { - return absl::c_all_of(tuple_shapes_, + if (const auto* const state = if_tuple_state()) { + return absl::c_all_of(state->tuple_shapes, [](const Shape& s) { return 
s.is_static(); }); } - return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); + if (const auto* const state = if_array_state()) { + return !absl::c_any_of(state->dynamic_dimensions, [](bool b) { return b; }); + } + return true; } bool Shape::is_unbounded_dynamic() const { - if (IsTuple()) { - return absl::c_any_of(tuple_shapes_, [](const Shape& subshape) { + if (const auto* const state = if_tuple_state()) { + return absl::c_any_of(state->tuple_shapes, [](const Shape& subshape) { return subshape.is_unbounded_dynamic(); }); } - return absl::c_any_of(dimensions_, - [](int64_t dim) { return dim == kUnboundedSize; }); + if (const auto* const state = if_array_state()) { + return absl::c_any_of(state->dimensions, + [](int64_t dim) { return dim == kUnboundedSize; }); + } + return false; } bool Shape::is_bounded_dynamic() const { - if (IsTuple()) { - return absl::c_any_of(tuple_shapes_, [](const Shape& subshape) { + if (const auto* const state = if_tuple_state()) { + return absl::c_any_of(state->tuple_shapes, [](const Shape& subshape) { return subshape.is_bounded_dynamic(); }); } - for (auto i = 0; i < dimensions_.size(); ++i) { - if (is_bounded_dynamic_dimension(i)) return true; + if (const auto* const state = if_array_state()) { + for (auto i = 0; i < state->dimensions.size(); ++i) { + if (is_bounded_dynamic_dimension(i)) return true; + } + return false; } return false; } void Shape::DeleteDimension(int64_t dim_to_delete) { - CHECK(IsArray()); + auto& state = array_state(); CHECK_GE(dim_to_delete, 0); - CHECK_LT(dim_to_delete, dimensions_.size()); - dimensions_.erase(dimensions_.begin() + dim_to_delete); - dynamic_dimensions_.erase(dynamic_dimensions_.begin() + dim_to_delete); + CHECK_LT(dim_to_delete, state.dimensions.size()); + state.dimensions.erase(state.dimensions.begin() + dim_to_delete); + state.dynamic_dimensions.erase(state.dynamic_dimensions.begin() + + dim_to_delete); if (LayoutUtil::HasLayout(*this)) { - layout_->DeleteDimension(dim_to_delete); // NOLINT: optional-access + state.layout->DeleteDimension(dim_to_delete); // NOLINT: optional-access } } void Shape::DeleteDimensions(absl::Span sorted_dims_to_delete) { - CHECK(IsArray()); + auto& state = array_state(); CHECK(absl::c_is_sorted(sorted_dims_to_delete)); - dimensions_ = RemoveElements(sorted_dims_to_delete, dimensions_); - dynamic_dimensions_ = - RemoveElements(sorted_dims_to_delete, dynamic_dimensions_); + state.dimensions = RemoveElements(sorted_dims_to_delete, state.dimensions); + state.dynamic_dimensions = + RemoveElements(sorted_dims_to_delete, state.dynamic_dimensions); if (LayoutUtil::HasLayout(*this)) { for (auto it = sorted_dims_to_delete.rbegin(); it != sorted_dims_to_delete.rend(); ++it) { - layout_->DeleteDimension(*it); // NOLINT: optional-access + state.layout->DeleteDimension(*it); // NOLINT: optional-access } } } +void Shape::CheckStateIsEmpty() const { + if (const auto* const state = if_array_state()) { + CHECK(state->dimensions.empty()) << ToString(); + CHECK(state->dynamic_dimensions.empty()) << ToString(); + CHECK(!state->layout.has_value()) << ToString(); + } else if (const auto* const state = if_tuple_state()) { + CHECK(state->tuple_shapes.empty()) << ToString(); + } +} + +const std::vector& Shape::tuple_shapes() const { + if (const auto* const state = if_tuple_state()) { + return state->tuple_shapes; + } + // TODO(b/404276923): ensure that this is never called on non-tuple shapes. 
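Until that TODO is resolved, the accessor falls back to the shared empty vector allocated just below. The function-local static is heap-allocated and deliberately never freed, so the returned reference stays valid for the life of the program and no destructor runs at shutdown. The same idiom in isolation, as a hedged sketch:

  #include <vector>

  const std::vector<int>& EmptyInts() {
    // Allocated once and intentionally leaked: no static-destruction-order
    // hazard, and callers may hold the returned reference indefinitely.
    static const auto* const kEmpty = new std::vector<int>();
    return *kEmpty;
  }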
+  static const auto* const kEmpty = new std::vector<Shape>();
+  return *kEmpty;
+}
+
+void Shape::Clear() {
+  // Before setting the element type to invalid, we need to clear the state
+  // because the state may be non-empty if the shape was previously valid.
+  // Without this step, set_element_type() may CHECK-fail.
+  if (auto* const state = if_array_state()) {
+    *state = ArrayState();
+  } else if (auto* const state = if_tuple_state()) {
+    *state = TupleState();
+  }
+  set_element_type(PRIMITIVE_TYPE_INVALID);
+}
+
+void Shape::set_element_type(const PrimitiveType value) {
+  element_type_ = value;
+
+  // Make sure the variant state matches the element type.
+  // If we have to change the case of the variant, and the current case is not
+  // empty, it's likely a programmer error - we CHECK-fail to catch it.
+  if (element_type_ == TOKEN) {
+    if (!if_token_state()) {
+      CheckStateIsEmpty();
+      state_ = TokenState();
+    }
+    return;
+  }
+  if (element_type_ == OPAQUE_TYPE) {
+    if (!if_opaque_state()) {
+      CheckStateIsEmpty();
+      state_ = OpaqueState();
+    }
+    return;
+  }
+  if (element_type_ == TUPLE) {
+    if (!if_tuple_state()) {
+      CheckStateIsEmpty();
+      state_ = TupleState();
+    }
+    return;
+  }
+  if (primitive_util::IsArrayType(element_type_)) {
+    if (!if_array_state()) {
+      CheckStateIsEmpty();
+      state_ = ArrayState();
+    }
+    return;
+  }
+  // Treat all other types as invalid.
+  if (element_type_ != PRIMITIVE_TYPE_INVALID) {
+    LOG(ERROR) << "Unsupported element type: " << element_type_;
+    element_type_ = PRIMITIVE_TYPE_INVALID;
+  }
+  if (!if_invalid_state()) {
+    CheckStateIsEmpty();
+    state_ = InvalidState();
+  }
+}
+
 const Shape& Shape::tuple_shapes(int index) const {
-  return tuple_shapes_[index];
+  return tuple_state().tuple_shapes[index];
 }

 Shape* Shape::add_tuple_shapes() {
-  tuple_shapes_.push_back(Shape());
-  return &tuple_shapes_.back();
+  auto& state = tuple_state();
+  state.tuple_shapes.push_back(Shape());
+  return &state.tuple_shapes.back();
 }

 bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h
index f5882bdd5bf953..e8b14898f39867 100644
--- a/third_party/xla/xla/shape.h
+++ b/third_party/xla/xla/shape.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include
 #include
 #include
+#include <variant>
 #include

 #include "absl/base/attributes.h"
@@ -41,6 +42,18 @@ namespace xla {
 // A shape describes the number of dimensions in an array, the bounds of each
 // dimension, and the primitive component type. For tuples, shape describes the
 // structure (number of elements and nesting).
+//
+// Depending on the element type, the shape falls into one of the following
+// categories:
+//
+// - Invalid: element_type == PRIMITIVE_TYPE_INVALID
+// - Token: element_type == TOKEN
+// - Opaque: element_type == OPAQUE_TYPE
+// - Array: element_type is an array type
+// - Tuple: element_type == TUPLE
+//
+// These categories are mutually exclusive, i.e. a shape can only be one of
+// them.
 class Shape {
  public:
   // Creates an invalid shape, with element type PRIMITIVE_TYPE_INVALID and the
@@ -89,7 +102,8 @@ class Shape {
   // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
   std::string ToString(bool print_layout = false) const;

-  // Returns whether the shape is of the specified type (array, tuple, etc).
+  // Returns whether the shape is of the specified category (array, tuple, etc).
+  // TODO(b/404276923): check that element_type() and the state_ are in sync.
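Given the five mutually exclusive categories documented above, callers can dispatch on a shape with the predicates that follow. A short usage sketch (CategoryName is illustrative, not part of the API):

  const char* CategoryName(const xla::Shape& shape) {
    if (shape.IsArray()) return "array";
    if (shape.IsTuple()) return "tuple";
    if (shape.IsToken()) return "token";
    if (shape.IsOpaque()) return "opaque";
    return "invalid";  // element_type == PRIMITIVE_TYPE_INVALID
  }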
bool IsArray() const { return primitive_util::IsArrayType(element_type()); } bool IsTuple() const { return element_type() == TUPLE; } bool IsToken() const { return element_type() == TOKEN; } @@ -103,6 +117,8 @@ class Shape { // shapes are traversed recursively. bool is_static() const; + // Returns true if the shape contains at least one dynamic dimension. Tuple + // shapes are traversed recursively. bool is_dynamic() const { return !is_static(); } // Unbounded dynamism. @@ -118,14 +134,19 @@ class Shape { bool is_unbounded_dynamic() const; // Returns true if the given dimension is unbounded dynamic. + // Precondition: this is an array shape and `dimension` is a valid dimension + // index. bool is_unbounded_dynamic_dimension(int dimension) const { - return dimensions_[dimension] == kUnboundedSize; + return array_state().dimensions[dimension] == kUnboundedSize; } // Sets a given dimension as unbounded dynamic. + // Precondition: this is an array shape and `dimension` is a valid dimension + // index. void set_unbounded_dynamic_dimension(int dimension) { - dynamic_dimensions_[dimension] = true; - dimensions_[dimension] = kUnboundedSize; + auto& state = array_state(); + state.dynamic_dimensions[dimension] = true; + state.dimensions[dimension] = kUnboundedSize; } // Returns true if the shape has one or more dimensions with bounded sizes. @@ -134,118 +155,225 @@ class Shape { bool is_bounded_dynamic() const; // Returns true if the given dimension is bounded dynamic. + // Precondition: this is an array shape and `dimension` is a valid dimension + // index. bool is_bounded_dynamic_dimension(int dimension) const { return is_dynamic_dimension(dimension) && !is_unbounded_dynamic_dimension(dimension); } // Returns true if the given dimension is dynamically-sized. + // Precondition: this is an array shape and `dimension` is a valid dimension + // index. bool is_dynamic_dimension(int dimension) const { - return dynamic_dimensions_[dimension]; + return array_state().dynamic_dimensions[dimension]; } // Returns true if the given dimension is statically-sized. + // Precondition: this is an array shape and `dimension` is a valid dimension + // index. bool is_static_dimension(int dimension) const { - return !dynamic_dimensions_[dimension]; + return !array_state().dynamic_dimensions[dimension]; } // Sets whether or not the given dimension is dynamically-sized. + // Precondition: this is an array shape and `dimension` is a valid dimension + // index. void set_dynamic_dimension(int dimension, bool is_dynamic) { - dynamic_dimensions_[dimension] = is_dynamic; + array_state().dynamic_dimensions[dimension] = is_dynamic; } + // Returns a span to indicate whether each dimension is dynamic. + // Precondition: this is an array shape. absl::Span dynamic_dimensions() const { - return dynamic_dimensions_; + if (auto* const state = if_array_state()) { + return state->dynamic_dimensions; + } + // TODO(b/404276923): ensure that this is never called on non-array shapes. + return {}; } - absl::Span mutable_dynamic_dimensions() { - return absl::MakeSpan(dynamic_dimensions_); + return absl::MakeSpan(array_state().dynamic_dimensions); } // Removes the given dimension from the shape. Layout, if it exists, is // adjusted to match the modified shape. + // Precondition: this is an array shape, and the input dimension indices are + // valid. void DeleteDimension(int64_t dim_to_delete); void DeleteDimensions(absl::Span sorted_dims_to_delete); - // Methods for accessing the primitive type. + // Returns the primitive type of the shape. 
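All of these dynamic-dimension accessors presuppose an array shape. A hedged usage sketch, assuming the existing ShapeUtil::MakeShape factory for static array shapes; the lines that follow continue with the element-type accessors:

  xla::Shape s = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  s.set_dynamic_dimension(/*dimension=*/0, /*is_dynamic=*/true);
  CHECK(s.is_dynamic_dimension(0));  // now bounded dynamic, with bound 2
  CHECK(s.is_static_dimension(1));   // dimension 1 stays static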
PrimitiveType element_type() const { return element_type_; } - void set_element_type(PrimitiveType value) { element_type_ = value; } - // Methods for accessing the dimensions array. + // Sets the primitive type of the shape. If the new type and the old type + // are in different categories (e.g. array vs. tuple), the state is reset + // to the default (empty) state for the new type; otherwise, the state is + // preserved. This behavior ensures that the state is always consistent with + // the element type. + void set_element_type(PrimitiveType value); + + // Returns the number of dimensions in the shape. + // Precondition: this is an array shape. ABSL_DEPRECATE_AND_INLINE() inline int dimensions_size() const { return dimensions().size(); } - int64_t dimensions(int index) const { return dimensions_[index]; } + // Returns the size of the given dimension if it's static, or the upper bound + // of the dimension size if it's dynamic. + // Precondition: this is an array shape and `index` is a valid dimension + // index. + int64_t dimensions(int index) const { + return array_state().dimensions[index]; + } + + // Returns the physical dimension index of the index-th minor dimension. + // Precondition: this is an array shape, `index` is a valid dimension + // index, and the shape has a layout. int64_t dimensions_minor(int index) const { CHECK(has_layout()); - return dimensions_[layout_->minor_to_major(index)]; + const auto& state = array_state(); + return state.dimensions[state.layout->minor_to_major(index)]; + } + + // Sets the size of the given dimension if it's static, or sets the upper + // bound of the dimension size if it's dynamic. + // Precondition: this is an array shape, `index` is a valid dimension + // index, and value is either >= 0 or kUnboundedSize. + void set_dimensions(int index, int64_t value) { + array_state().dimensions[index] = value; } - void set_dimensions(int index, int64_t value) { dimensions_[index] = value; } + + // Sets the physical dimension index of the index-th minor dimension. + // Precondition: this is an array shape, `index` and `value` are valid + // dimension indices, and the shape has a layout. void set_dimensions_minor(int index, int64_t value) { CHECK(has_layout()); - dimensions_[layout_->minor_to_major(index)] = value; + auto& state = array_state(); + state.dimensions[state.layout->minor_to_major(index)] = value; } + + // Appends a new dimension with the given fixed size. + // Precondition: this is an array shape, and `value` is >= 0. void add_dimensions(int64_t value) { - dimensions_.push_back(value); - dynamic_dimensions_.push_back(false); + auto& state = array_state(); + state.dimensions.push_back(value); + state.dynamic_dimensions.push_back(false); } + + // Clears all dimensions (i.e. makes this shape a scalar). + // Precondition: this is an array shape. void clear_dimensions() { - dimensions_.clear(); - dynamic_dimensions_.clear(); + auto& state = array_state(); + state.dimensions.clear(); + state.dynamic_dimensions.clear(); + } + + // Returns a span to indicate the size of each dimension. + // Precondition: this is an array shape. + absl::Span dimensions() const { + if (const auto* const state = if_array_state()) { + return state->dimensions; + } + // TODO(b/404276923): ensure that this is never called on non-array shapes. 
+ return {}; } - absl::Span dimensions() const { return dimensions_; } absl::Span mutable_dimensions() { - return absl::MakeSpan(dimensions_); + return absl::MakeSpan(array_state().dimensions); } - // Methods for accessing the tuple subshapes. This field only non-empty for - // tuple shapes. - int tuple_shapes_size() const { return tuple_shapes_.size(); } + // Returns the number of top-level tuple components in this shape. + // Precondition: this is a tuple shape. + int tuple_shapes_size() const { + if (const auto* const state = if_tuple_state()) { + return state->tuple_shapes.size(); + } + // TODO(b/404276923): ensure that this is never called on non-tuple shapes. + return 0; + } + + // Returns the shape of the i-th tuple component. + // Precondition: this is a tuple shape and `index` is a valid tuple component + // index. const Shape& tuple_shapes(int index) const; - Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_[index]; } + Shape* mutable_tuple_shapes(int index) { + return &tuple_state().tuple_shapes[index]; + } + + // Appends a new invalid shape to the tuple and returns a pointer to it. + // Precondition: this is a tuple shape. + // Postcondition: the returned pointer is not null, and the pointee is owned + // by this shape. Shape* add_tuple_shapes(); - void clear_tuple_shapes() { tuple_shapes_.clear(); } - const std::vector& tuple_shapes() const { return tuple_shapes_; } - std::vector* mutable_tuple_shapes() { return &tuple_shapes_; } - // Methods for accessing the layout field. - bool has_layout() const { return layout_ != std::nullopt; } + // Clears all tuple components (i.e. makes this shape a 0-tuple). + // Precondition: this is a tuple shape. + void clear_tuple_shapes() { tuple_state().tuple_shapes.clear(); } + + // Returns a vector of all tuple component shapes. + // Precondition: this is a tuple shape. + const std::vector& tuple_shapes() const; + std::vector* mutable_tuple_shapes() { + return &tuple_state().tuple_shapes; + } + + // Returns true if the shape is an array and has a layout. + bool has_layout() const { + const auto* const state = if_array_state(); + return state != nullptr && state->layout != std::nullopt; + } + + // Returns the layout of the shape. + // Precondition: this is an array shape and has a layout. const Layout& layout() const { - CHECK(has_layout()) << ShortDebugString(); - return *layout_; + CHECK(has_layout()) << ToString(); + return *array_state().layout; } + + // Returns a pointer to the layout of the shape. If the shape does not have a + // layout, an empty layout is created. + // Precondition: this is an array shape. + // Postcondition: the returned pointer is not null, and the pointee is owned + // by this shape. Layout* mutable_layout() { - CHECK(IsArray()) << ShortDebugString(); - if (layout_ == std::nullopt) { - layout_.emplace(); + auto& state = array_state(); + if (state.layout == std::nullopt) { + state.layout.emplace(); + } + return &(*state.layout); + } + + // Removes the layout of the shape, if any. + // Precondition: this is an array shape. + void clear_layout() { + // TODO(b/404276923): ensure that this is never called on non-array shapes. + if (auto* const state = if_array_state()) { + state->layout = std::nullopt; } - return &(*layout_); } - void clear_layout() { layout_ = std::nullopt; } // Recursively clear all dynamic dimension of a shape, including bounded and - // unbounded dynamic dimensions. + // unbounded dynamic dimensions. 
Clearing a dynamic dimension means + // changing the dimension to static and setting its size as the dynamic + // dimension's size upper bound. void clear_dynamic_dimensions() { - if (!IsTuple()) { + if (auto* const state = if_array_state()) { if (is_dynamic()) { mutable_layout()->set_dynamic_shape_metadata_prefix_bytes(0); } - for (int64_t i = 0; i < dynamic_dimensions_.size(); ++i) { - dynamic_dimensions_[i] = false; + for (int64_t i = 0; i < state->dynamic_dimensions.size(); ++i) { + state->dynamic_dimensions[i] = false; } return; } - for (auto& subshape : tuple_shapes_) { - subshape.clear_dynamic_dimensions(); + if (auto* const state = if_tuple_state()) { + for (auto& subshape : state->tuple_shapes) { + subshape.clear_dynamic_dimensions(); + } } } - void Clear() { - element_type_ = PRIMITIVE_TYPE_INVALID; - clear_dimensions(); - tuple_shapes_.clear(); - clear_layout(); - } + // Resets this to the default state (an invalid shape). + void Clear(); std::string SerializeAsString() const { return ToProto().SerializeAsString(); @@ -336,18 +464,21 @@ class Shape { template static H Hash(H h, const Shape& s) { - if (s.IsTuple()) { - for (const Shape& subshape : s.tuple_shapes_) { + if (const auto* const state = s.if_tuple_state()) { + for (const Shape& subshape : state->tuple_shapes) { h = Shape::Hash(std::move(h), subshape); } - return H::combine(std::move(h), s.tuple_shapes_size()); + return H::combine(std::move(h), state->tuple_shapes.size()); } - h = H::combine(std::move(h), s.element_type_, s.dimensions_, - s.dynamic_dimensions_); - if (kIsLayoutSensitive) { - h = H::combine(std::move(h), s.layout_); + if (const auto* const state = s.if_array_state()) { + h = H::combine(std::move(h), s.element_type_, state->dimensions, + state->dynamic_dimensions); + if (kIsLayoutSensitive) { + h = H::combine(std::move(h), state->layout); + } + return std::move(h); } - return std::move(h); + return H::combine(std::move(h), s.element_type_); } template @@ -356,23 +487,119 @@ class Shape { } private: - // The element type of this shape (tuple, array, etc). - PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; + friend absl::Status ValidateNonLayoutProperties(const Shape& shape); + + // Define one state struct for each shape category. Depending on the element + // type, the state_ variant will be set to exactly one of these structs. + // This design has several benefits: + // - It prevents (by construction) bugs where the shape's state has + // non-empty fields that don't match the shape's element type. + // - It prevents (by construction) bugs where the code accesses a field + // of a shape's state that doesn't match the shape's element type (e.g. + // accessing the tuple_shapes field of an array shape). + // - It simplifies the code by eliminating the need for runtime handling of + // fields that are irrelevant to the shape's category. + // - It reduces the size of the Shape class as the variant doesn't need to + // store the fields for all shape categories at once. + struct InvalidState {}; + struct TokenState {}; + struct OpaqueState {}; + struct ArrayState { + // The array bounds of the dimensions. For a dynamically-sized dimension, + // the respective value in this vector is an inclusive upper limit of the + // array bound. + DimensionVector dimensions; + + // This vector has the same size as 'dimensions' and indicates whether the + // respective dimension is dynamically sized. + absl::InlinedVector dynamic_dimensions; + + // The layout of the shape. 
+    std::optional<Layout> layout;
+  };
+  struct TupleState {
+    // The tuple element subshapes.
+    std::vector<Shape> tuple_shapes;
+  };
+
+  using State = std::variant<InvalidState, TokenState, OpaqueState,
+                             ArrayState, TupleState>;
+
+  // Convenience accessors for the state_ variant. Each if_*_state() accessor
+  // returns a pointer to the corresponding state struct, or nullptr if the
+  // shape is not of the corresponding category. The version without the `if_`
+  // prefix is similar, but will CHECK-fail if the shape is not of the
+  // corresponding category. I.e. if_foo_state() vs foo_state() is analogous to
+  // std::get_if() vs std::get().
+  //
+  // In general, prefer foo_state() over if_foo_state() as the former catches
+  // programmer errors earlier and generates a more informative error message.
+  // However, if_foo_state() is useful in cases where it's not a programmer
+  // error if the shape is not of the corresponding category.
+
+  const InvalidState* if_invalid_state() const {
+    return std::get_if<InvalidState>(&state_);
+  }
+  const TokenState* if_token_state() const {
+    return std::get_if<TokenState>(&state_);
+  }
+  const OpaqueState* if_opaque_state() const {
+    return std::get_if<OpaqueState>(&state_);
+  }
+  const ArrayState* if_array_state() const {
+    return std::get_if<ArrayState>(&state_);
+  }
+  ArrayState* if_array_state() { return std::get_if<ArrayState>(&state_); }
+  const TupleState* if_tuple_state() const {
+    return std::get_if<TupleState>(&state_);
+  }
+  TupleState* if_tuple_state() { return std::get_if<TupleState>(&state_); }

-  // The array bounds of the dimensions. This is nonempty only for array
-  // shapes. For a dynamically-sized dimension, the respective value in this
-  // vector is an inclusive upper limit of the array bound.
-  DimensionVector dimensions_;
+  const InvalidState& invalid_state() const {
+    const auto* const state = if_invalid_state();
+    CHECK(state) << "Expected an invalid shape. Got " << ToString();
+    return *state;
+  }
+  const TokenState& token_state() const {
+    const auto* const state = if_token_state();
+    CHECK(state) << "Expected a token shape. Got " << ToString();
+    return *state;
+  }
+  const OpaqueState& opaque_state() const {
+    const auto* const state = if_opaque_state();
+    CHECK(state) << "Expected an opaque shape. Got " << ToString();
+    return *state;
+  }
+  const ArrayState& array_state() const {
+    const auto* const state = if_array_state();
+    CHECK(state) << "Expected an array shape. Got " << ToString();
+    return *state;
+  }
+  ArrayState& array_state() {
+    auto* const state = if_array_state();
+    CHECK(state) << "Expected an array shape. Got " << ToString();
+    return *state;
+  }
+  const TupleState& tuple_state() const {
+    const auto* const state = if_tuple_state();
+    CHECK(state) << "Expected a tuple shape. Got " << ToString();
+    return *state;
+  }
+  TupleState& tuple_state() {
+    auto* const state = if_tuple_state();
+    CHECK(state) << "Expected a tuple shape. Got " << ToString();
+    return *state;
+  }

-  // This vector is the same size as 'dimensions_' and indicates whether the
-  // respective dimension is dynamically sized.
-  absl::InlinedVector dynamic_dimensions_;
+  // CHECK-fails if this shape's state is not empty.
+  void CheckStateIsEmpty() const;

-  // The tuple element subshapes. This is nonempty only for tuple shapes.
-  std::vector<Shape> tuple_shapes_;
+  // The element type of this shape (tuple, array, etc).
+  PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;

-  // The layout of the shape. Only relevant for arrays.
-  std::optional<Layout> layout_;
+  // The state of this shape.
+  // Invariant: element_type_ always matches the type held in this variant.
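The if_*_state() accessors above are thin wrappers over std::get_if, which returns nullptr on a case mismatch instead of throwing; that is what makes them cheap to call speculatively. A self-contained sketch of the pattern with placeholder state types (not the real Shape states):

  #include <cassert>
  #include <variant>

  struct ArrayLike { int rank = 0; };
  struct TupleLike {};
  using State = std::variant<ArrayLike, TupleLike>;

  int main() {
    State state = ArrayLike{2};
    if (auto* array = std::get_if<ArrayLike>(&state)) {
      assert(array->rank == 2);  // active case: pointer is non-null
    }
    assert(std::get_if<TupleLike>(&state) == nullptr);  // inactive case
    return 0;
  }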
+ State state_; }; // Shape of the parameters and output of an XLA computation. This is analogous diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index fc5f723623e844..9a5db0e9bc399a 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -40,7 +40,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/index_util.h" @@ -160,8 +159,7 @@ const T& Deref(const T& ref) { template Shape MakeTupleShapeImpl(absl::Span shapes) { - Shape result; - result.set_element_type(TUPLE); + Shape result(std::vector{}); result.mutable_tuple_shapes()->reserve(shapes.size()); for (const auto& shape : shapes) { ShapeUtil::AppendShapeToTuple(Deref(shape), &result); @@ -228,8 +226,13 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { } /* static */ int64_t ShapeUtil::TrueNumDimensions(const Shape& shape) { + if (!shape.IsArray()) { + // TODO(b/404276923): enforce that this is never called on non-array shapes. + return 0; + } + int64_t accum = 0; - for (int64_t dimension : shape.dimensions()) { + for (const int64_t dimension : shape.dimensions()) { // We do not count unit dimensions. if (dimension != 1) { accum += 1; @@ -450,10 +453,14 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( Shape* shape) { shape->Clear(); shape->set_element_type(element_type); - for (int64_t dimension : dimensions) { - shape->add_dimensions(dimension); + // TODO(b/404276923): ensure that dimensions is empty if this is a non-array + // shape. + if (shape->IsArray()) { + for (int64_t dimension : dimensions) { + shape->add_dimensions(dimension); + } + LayoutUtil::SetToDefaultLayout(shape); } - LayoutUtil::SetToDefaultLayout(shape); return ValidateShape(*shape); } @@ -483,19 +490,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return MakeTupleShape(shapes); } -/* static */ Shape ShapeUtil::MakeOpaqueShape() { - Shape result; - result.set_element_type(OPAQUE_TYPE); - TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); - return result; -} +/* static */ Shape ShapeUtil::MakeOpaqueShape() { return Shape(OPAQUE_TYPE); } -/* static */ Shape ShapeUtil::MakeTokenShape() { - Shape result; - result.set_element_type(TOKEN); - TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); - return result; -} +/* static */ Shape ShapeUtil::MakeTokenShape() { return Shape(TOKEN); } /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, Shape* tuple_shape) { @@ -688,9 +685,11 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { if (shape.element_type() == primitive_type) { return true; } - for (const Shape& element_shape : shape.tuple_shapes()) { - if (HasPrimitiveType(element_shape, primitive_type)) { - return true; + if (shape.IsTuple()) { + for (const Shape& element_shape : shape.tuple_shapes()) { + if (HasPrimitiveType(element_shape, primitive_type)) { + return true; + } } } return false; @@ -713,7 +712,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } printer->Append( primitive_util::LowercasePrimitiveTypeName(shape.element_type())); - if (shape.dimensions().empty()) { + if (!shape.IsArray() || shape.dimensions().empty()) { printer->Append("[]"); return; } @@ -744,6 +743,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { 
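After this change the token and opaque factories above are one-line wrappers over the corresponding Shape constructor. A hedged usage sketch:

  xla::Shape token = xla::ShapeUtil::MakeTokenShape();    // wraps Shape(TOKEN)
  xla::Shape opaque = xla::ShapeUtil::MakeOpaqueShape();  // wraps Shape(OPAQUE_TYPE)
  CHECK(token.IsToken() && !token.IsArray());
  CHECK(opaque.IsOpaque() && !opaque.IsTuple());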
return; } PrintHumanString(printer, shape); + if (!shape.IsArray()) return; if (!shape.has_layout()) return; if (IsScalar(shape)) { std::string layout_str = LayoutUtil::HumanString(shape.layout()); @@ -751,7 +751,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { if (layout_str != "{}") { printer->Append(layout_str); } - } else if (shape.IsArray()) { + } else { LayoutUtil::PrintHumanString(printer, shape.layout()); } } @@ -1002,17 +1002,21 @@ absl::Status ValidateDimensions(const Shape& shape) { } return absl::OkStatus(); } +} // namespace // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. absl::Status ValidateNonLayoutProperties(const Shape& shape) { + // Make sure the element type is valid. if (shape.element_type() == PRIMITIVE_TYPE_INVALID || !PrimitiveType_IsValid(shape.element_type())) { return ShapeError(shape, "Invalid element type."); } + + // Validate tuple shapes. if (shape.element_type() == TUPLE) { - if (shape.dimensions_size() != 0) { - return ShapeError(shape, "This type cannot have dimensions."); + if (!shape.if_tuple_state()) { + return ShapeError(shape, "This type must have a tuple state."); } for (auto& element_shape : shape.tuple_shapes()) { TF_RETURN_IF_ERROR(ValidateNonLayoutProperties(element_shape)); @@ -1020,27 +1024,34 @@ absl::Status ValidateNonLayoutProperties(const Shape& shape) { return absl::OkStatus(); } - // Non-tuple shape. - if (shape.tuple_shapes_size() > 0) { - return ShapeError(shape, "Non-tuple type contains tuple_shapes."); + // Validate token shapes. + if (shape.element_type() == TOKEN) { + if (!shape.if_token_state()) { + return ShapeError(shape, "This type must have a token state."); + } + return absl::OkStatus(); } - // Tokens and opaques should not have layout or dimensions. - if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) { - if (shape.dimensions_size() != 0) { - return ShapeError(shape, "This type cannot have dimensions."); + // Validate opaque shapes. + if (shape.element_type() == OPAQUE_TYPE) { + if (!shape.if_opaque_state()) { + return ShapeError(shape, "This type must have an opaque state."); } - if (shape.has_layout()) { - return ShapeError(shape, "This type cannot have a layout."); + return absl::OkStatus(); + } + + // Validate array shapes. + if (primitive_util::IsArrayType(shape.element_type())) { + if (!shape.if_array_state()) { + return ShapeError(shape, "This type must have an array state."); } + TF_RETURN_IF_ERROR(ValidateDimensions(shape)); + TF_RETURN_IF_ERROR(ValidateShapeSize(shape)); return absl::OkStatus(); } - TF_RETURN_IF_ERROR(ValidateDimensions(shape)); - TF_RETURN_IF_ERROR(ValidateShapeSize(shape)); - return absl::OkStatus(); + return ShapeError(shape, "Unsupported element type."); } -} // namespace /* static */ absl::Status ShapeUtil::ValidateShapeWithOptionalLayout( const Shape& shape) { diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index 9bb20af7f62ff9..b132c79f752b3d 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -294,6 +294,7 @@ class ShapeUtil { // 1. e.g., f32[2x1x1] has a true dimensionality of 1D, the other dimensions // are just fluff. Note that zero dimensions are included in the true // dimensionality, e.g., f32[3,0,1] has a true dimensionality of 2D. + // Precondition: shape.IsArray(). 
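TrueNumDimensions, declared next, counts only non-unit dimensions, while zero-sized dimensions still count. A standalone sketch of just the counting rule (TrueDimensionality is a hypothetical helper, not the XLA function):

  #include <cstdint>
  #include "absl/types/span.h"

  int64_t TrueDimensionality(absl::Span<const int64_t> dims) {
    int64_t n = 0;
    for (const int64_t d : dims) {
      if (d != 1) ++n;  // unit dimensions are fluff; zero-sized ones count
    }
    return n;
  }
  // TrueDimensionality({2, 1, 1}) == 1;  TrueDimensionality({3, 0, 1}) == 2.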
static int64_t TrueNumDimensions(const Shape& shape); static ProgramShape MakeProgramShape(std::initializer_list parameters, From 39be55828ff7eeda360dd507cf942d8a29b81ee2 Mon Sep 17 00:00:00 2001 From: Weiyi Wang Date: Sun, 30 Mar 2025 17:39:48 -0700 Subject: [PATCH 0023/1324] Fix the breakage for OSS preparation: * LLVM * Py Test * TQDM * TFLite headers * TFLite deps * Py APIs PiperOrigin-RevId: 742095838 --- tensorflow/lite/core/BUILD | 13 +++++++++++++ tensorflow/lite/core/c/BUILD | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD index f52c4cc77b46fd..5e0ded435cd945 100644 --- a/tensorflow/lite/core/BUILD +++ b/tensorflow/lite/core/BUILD @@ -148,6 +148,19 @@ cc_library( alwayslink = 1, # TODO(b/161243354): eliminate this. ) +# This is a private target, its visibility is set to public only to be +# used by LiteRT dependencies. +# Do not use this target directly and don't consider it as a part of the public API. +# TODO(weiyiw): Refactor LiteRT deps from TFLite. +alias( + name = "private_cc_api_stable", + actual = ":cc_api_stable", + tags = ["avoid_dep"], + visibility = [ + "//visibility:public", + ], +) + # TODO(b/242310498): move logger.cc from tensorflow/lite/ to here. cc_library( name = "cc_api_stable", diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index a2b03389d68673..b5e2bf493757e5 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -277,7 +277,7 @@ cc_test( ) # This is a private target, its visibility is set to public only to be -# used by "tflite_custom_c_library". +# used by "tflite_custom_c_library" and LiteRT dependencies. # Do not use this target directly and don't consider it as a part of the public API. alias( name = "private_c_api_types", @@ -552,7 +552,7 @@ tflite_cc_library_with_c_headers_test( ) # This is a private target, its visibility is set to public only to be -# used by "custom_c_library_with_tflite". +# used by "custom_c_library_with_tflite" and LiteRT dependencies. # Do not use this target directly and don't consider it as a part of the public API. alias( name = "private_c_api_opaque_without_op_resolver", From 1499eafe1aea48e149a1c0d0e3ed9e43264bf699 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 30 Mar 2025 21:34:14 -0700 Subject: [PATCH 0024/1324] Automated Code Change PiperOrigin-RevId: 742136280 --- tensorflow/lite/experimental/microfrontend/lib/fft.cc | 2 ++ tensorflow/lite/experimental/microfrontend/lib/fft_util.cc | 3 +++ 2 files changed, 5 insertions(+) diff --git a/tensorflow/lite/experimental/microfrontend/lib/fft.cc b/tensorflow/lite/experimental/microfrontend/lib/fft.cc index 8a107e2b492ef5..0f30eec49167f5 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/fft.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/fft.cc @@ -16,6 +16,8 @@ limitations under the License. #include +#include + #define FIXED_POINT 16 #include "kiss_fft.h" #include "kiss_fftr.h" diff --git a/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc b/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc index b913f3c0365eb5..18e0d36b53d7d0 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc @@ -16,6 +16,9 @@ limitations under the License. #include +#include +#include + #define FIXED_POINT 16 #include "kiss_fft.h" #include "kiss_fftr.h" From c2b62089457e12e0856dbcc937e8acd6c52855b4 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 30 Mar 2025 22:57:19 -0700 Subject: [PATCH 0025/1324] Ensure that in `ShapeUtil::PopulateShape()`, `dimensions` is empty if the shape is non-array. PiperOrigin-RevId: 742148212 --- third_party/xla/xla/shape_util.cc | 5 +++-- third_party/xla/xla/shape_util.h | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 9a5db0e9bc399a..1c3fda998628dd 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -453,13 +453,14 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( Shape* shape) { shape->Clear(); shape->set_element_type(element_type); - // TODO(b/404276923): ensure that dimensions is empty if this is a non-array - // shape. if (shape->IsArray()) { for (int64_t dimension : dimensions) { shape->add_dimensions(dimension); } LayoutUtil::SetToDefaultLayout(shape); + } else { + CHECK(dimensions.empty()) << "Non-array shape " << shape->ToString() + << " cannot have dimensions."; } return ValidateShape(*shape); } diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index b132c79f752b3d..b6ccf195b6047b 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -477,6 +477,9 @@ class ShapeUtil { const Shape& shape); // As MakeShape, but the object to write to is passed in. + // Precondition: + // - if element_type is a non-array type, dimensions must be empty. + // - shape must not be null. static absl::Status PopulateShape(PrimitiveType element_type, absl::Span dimensions, Shape* shape); From 6f4541f5f0745c5786bb960e911ec782981443ca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 30 Mar 2025 23:00:51 -0700 Subject: [PATCH 0026/1324] In `Shape::Is*()`, enforce that the shape's element type and state are consistent. PiperOrigin-RevId: 742148588 --- third_party/xla/xla/shape.h | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index e8b14898f39867..6031146c1146c3 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -103,11 +103,38 @@ class Shape { std::string ToString(bool print_layout = false) const; // Returns whether the shape is of the specified category (array, tuple, etc). - // TODO(b/404276923): check that element_type() and the state_ are in sync. - bool IsArray() const { return primitive_util::IsArrayType(element_type()); } - bool IsTuple() const { return element_type() == TUPLE; } - bool IsToken() const { return element_type() == TOKEN; } - bool IsOpaque() const { return element_type() == OPAQUE_TYPE; } + bool IsArray() const { + const bool result = primitive_util::IsArrayType(element_type()); + // We do this check in debug mode only to avoid performance regressions. + DCHECK_EQ(result, if_array_state() != nullptr) + << "Shape " << ToString() + << " has inconsistent element_type and state."; + return result; + } + bool IsTuple() const { + const bool result = element_type() == TUPLE; + // We do this check in debug mode only to avoid performance regressions. + DCHECK_EQ(result, if_tuple_state() != nullptr) + << "Shape " << ToString() + << " has inconsistent element_type and state."; + return result; + } + bool IsToken() const { + const bool result = element_type() == TOKEN; + // We do this check in debug mode only to avoid performance regressions. 
+ DCHECK_EQ(result, if_token_state() != nullptr) + << "Shape " << ToString() + << " has inconsistent element_type and state."; + return result; + } + bool IsOpaque() const { + const bool result = element_type() == OPAQUE_TYPE; + // We do this check in debug mode only to avoid performance regressions. + DCHECK_EQ(result, if_opaque_state() != nullptr) + << "Shape " << ToString() + << " has inconsistent element_type and state."; + return result; + } // Returns whether all elements in the shape are integers. // Tuple shapes are traversed recursively. From 8b432f90052e776bfa118fe325b97b742bf921f1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 30 Mar 2025 23:33:26 -0700 Subject: [PATCH 0027/1324] Automated Code Change PiperOrigin-RevId: 742153252 --- third_party/xla/xla/stream_executor/cuda/BUILD | 9 +++++++++ .../cuda/assemble_compilation_provider.cc | 1 + .../stream_executor/cuda/caching_compilation_provider.h | 1 + .../stream_executor/cuda/compilation_provider_options.cc | 1 + .../cuda/compilation_provider_options_test.cc | 1 + .../xla/stream_executor/cuda/compilation_provider_test.h | 1 + .../cuda/composite_compilation_provider.cc | 1 + .../cuda/composite_compilation_provider.h | 1 + .../xla/xla/stream_executor/cuda/cuda_asm_compiler.h | 1 + .../xla/xla/stream_executor/cuda/cuda_collectives.cc | 1 + .../stream_executor/cuda/cuda_compute_capability_test.cc | 1 + .../xla/stream_executor/cuda/cuda_diagnostics_test.cc | 1 - third_party/xla/xla/stream_executor/cuda/cuda_dnn.h | 1 + .../xla/xla/stream_executor/cuda/cuda_executor.cc | 1 + third_party/xla/xla/stream_executor/cuda/cuda_fft.cc | 2 ++ third_party/xla/xla/stream_executor/cuda/cuda_kernel.h | 1 + .../xla/xla/stream_executor/cuda/cuda_kernel_test.cc | 1 + .../xla/xla/stream_executor/cuda/cuda_platform_test.cc | 1 + .../xla/xla/stream_executor/cuda/cuda_solver_context.cc | 3 +++ third_party/xla/xla/stream_executor/cuda/cuda_stream.h | 1 + .../defer_relocatable_compilation_compilation_provider.h | 1 + .../xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc | 3 ++- .../stream_executor/cuda/driver_compilation_provider.h | 1 + .../cuda/nvjitlink_compilation_provider.cc | 1 + .../cuda/nvjitlink_compilation_provider.h | 1 + .../xla/xla/stream_executor/cuda/nvjitlink_impl.cc | 1 + .../cuda/nvptxcompiler_compilation_provider.cc | 1 + .../cuda/nvptxcompiler_compilation_provider.h | 1 + .../xla/xla/stream_executor/cuda/ptx_compiler_helpers.h | 1 + .../xla/stream_executor/cuda/subprocess_compilation.h | 1 + .../cuda/subprocess_compilation_provider.cc | 1 + .../cuda/subprocess_compilation_provider.h | 1 + 32 files changed, 43 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 500d7a5500f6ac..7253ec4b295d6c 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -238,6 +238,7 @@ cc_library( deps = [ "//xla/stream_executor:activate_context", "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -410,6 +411,7 @@ cc_library( "//xla/stream_executor:gpu_solver_context", "//xla/stream_executor:stream", "//xla/tsl/cuda:cusolver", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -463,6 +465,7 @@ cc_library( "//xla/tsl/cuda:cufft", 
"@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:logging", @@ -783,6 +786,7 @@ xla_test( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -1551,6 +1555,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", @@ -1652,6 +1657,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_asm_opts", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:statusor", ], @@ -1669,6 +1675,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_asm_opts", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:statusor", ], @@ -1827,6 +1834,7 @@ xla_cc_test( srcs = ["compilation_provider_options_test.cc"], deps = [ ":compilation_provider_options", + "//xla:xla_proto_cc", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", @@ -1950,6 +1958,7 @@ xla_cc_test( srcs = ["cuda_compute_capability_test.cc"], deps = [ ":cuda_compute_capability", + ":cuda_compute_capability_proto_cc", "//xla/tsl/platform:status_matchers", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc index 782651f35a6094..016384a72ce097 100644 --- a/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/compilation_provider_options.h" #include "xla/stream_executor/cuda/composite_compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.h index 264b0384d99d46..786105aa8c097e 100644 --- a/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.h @@ -25,6 +25,7 @@ limitations under the License. 
#include "absl/base/thread_annotations.h" #include "absl/container/node_hash_map.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" diff --git a/third_party/xla/xla/stream_executor/cuda/compilation_provider_options.cc b/third_party/xla/xla/stream_executor/cuda/compilation_provider_options.cc index 43cdddbc52d3d9..eb349b746e78ca 100644 --- a/third_party/xla/xla/stream_executor/cuda/compilation_provider_options.cc +++ b/third_party/xla/xla/stream_executor/cuda/compilation_provider_options.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/str_format.h" +#include "xla/xla.pb.h" namespace stream_executor::cuda { diff --git a/third_party/xla/xla/stream_executor/cuda/compilation_provider_options_test.cc b/third_party/xla/xla/stream_executor/cuda/compilation_provider_options_test.cc index 903b253cd53aef..a096c48071f085 100644 --- a/third_party/xla/xla/stream_executor/cuda/compilation_provider_options_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/compilation_provider_options_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/hash/hash_testing.h" #include "absl/strings/str_cat.h" +#include "xla/xla.pb.h" namespace stream_executor::cuda { namespace { diff --git a/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.h b/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.h index 118d2c8389fe2e..a7cb46b9770514 100644 --- a/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.h +++ b/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/cuda/compilation_provider.h" namespace stream_executor::cuda { diff --git a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc index c9e665aa514600..3cdb44e2f0410a 100644 --- a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.h index 131d80d30b3aef..59475d721e6322 100644 --- a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h index caf2af501526e8..bc1a78170eea45 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/base/macros.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/stream_executor/cuda/cubin_or_ptx_image.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc index 382fd7dc3fba10..8721ddef66f2d6 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability_test.cc index a2740268052cfd..c367ae2a21759b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/hash/hash_testing.h" #include "absl/status/status.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.pb.h" #include "xla/tsl/platform/status_matchers.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics_test.cc index bc93ab86d04d5e..cdd70c3b53d26f 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "absl/debugging/leak_check.h" #include "absl/log/check.h" -#include "absl/log/globals.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index c25ed37364cde6..e3a45ef761e37a 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "third_party/gpus/cudnn/cudnn_version.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index fd00b8e639868e..e50d8f791476cf 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "xla/stream_executor/cuda/cuda_executor.h" +#include #include #include #include diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc b/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc index c84adac268d956..8c72d127109a1f 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc @@ -23,7 +23,9 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cufft.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h index c2e0b990d999a6..18bb2ca5772a63 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/kernel.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc index cf790ddec09cb2..0ca67b2fa51e73 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc index b9621f76aee349..ce3a4715810080 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_solver_context.cc b/third_party/xla/xla/stream_executor/cuda/cuda_solver_context.cc index 55e7861e015421..2b11cb67e8e90b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_solver_context.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_solver_context.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "third_party/gpus/cuda/include/cuComplex.h" #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolver_common.h" @@ -33,6 +35,7 @@ limitations under the License. 
#include "xla/stream_executor/gpu_solver_context.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_stream.h b/third_party/xla/xla/stream_executor/cuda/cuda_stream.h index 68c909c5e59ba6..671e0f5a2258f0 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_stream.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_stream.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_event.h" diff --git a/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.h index 4451ea7255fc86..9992ec3dc207e7 100644 --- a/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc b/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc index 93dfa1053d8a68..10ba837934bcf1 100644 --- a/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include "absl/status/statusor.h" #include "xla/stream_executor/cuda/delay_kernel.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/typed_kernel_factory.h" diff --git a/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.h index e73db347c69e1b..ccfa64bec0c933 100644 --- a/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.cc index 6dd1f1e215b694..01d9f8da87657a 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.h index b680e0882a1729..ba099b36e7e41f 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc index 4cee025615cbbe..6d8f298fe25622 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "third_party/gpus/cuda/include/nvJitLink.h" #include "xla/stream_executor/cuda/nvjitlink.h" diff --git a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc index 3cebebf368077d..f0b344f4884e87 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h index 5ffdee124c19fe..a7964214885069 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h index 0bf6beed34b482..9a832d95bf40eb 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h @@ -16,6 +16,7 @@ limitations under the License. 
#define XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_ #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/semantic_version.h" diff --git a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.h b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.h index f052da91069ca8..99d1da03067d83 100644 --- a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.h +++ b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/cubin_or_ptx_image.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc index 52ac2f2cecaa0e..07d132fc8f637a 100644 --- a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" diff --git a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.h index 2960b3c657476f..052fefc4a9ebb6 100644 --- a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" From ef8e20cfe410dca7b50efceac016f1c0a82b3d0d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 31 Mar 2025 00:03:40 -0700 Subject: [PATCH 0028/1324] Automated Code Change PiperOrigin-RevId: 742158882 --- tensorflow/lite/delegates/utils/async_type_helpers.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/lite/delegates/utils/async_type_helpers.cc b/tensorflow/lite/delegates/utils/async_type_helpers.cc index 4f6904c45bfe27..2d8bf0b79fc325 100644 --- a/tensorflow/lite/delegates/utils/async_type_helpers.cc +++ b/tensorflow/lite/delegates/utils/async_type_helpers.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/lite/delegates/utils/async_type_helpers.h" #include -#include #include "tensorflow/lite/async/interop/c/attribute_map.h" #include "tensorflow/lite/async/interop/c/constants.h" From 1f09b81cc4a47312472b1c59e38d42624aff1228 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 31 Mar 2025 00:29:24 -0700 Subject: [PATCH 0029/1324] Adjust PriorityFusion to allow forming simple multi-output Triton fusions. With simple multi-output fusion we mean fusions that have only one root without users. We only allow fusions which are supported by the current Triton codegen. 
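For intuition, the restriction described below (fuse only when there is a
single fusible consumer) can be sketched roughly as follows. This is a
simplified illustration assuming the HloInstruction API;
`ShouldTryTritonMultiOutputFusion` is a made-up helper, not the code added in
priority_fusion.cc:

```
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"

// Hypothetical helper: only attempt multi-output fusion when the producer
// has exactly one fusible (non-bitcast) user, so we never have to rank
// candidate consumers.
bool ShouldTryTritonMultiOutputFusion(const xla::HloInstruction* producer) {
  int fusible_users = 0;
  for (const xla::HloInstruction* user : producer->users()) {
    if (user->opcode() != xla::HloOpcode::kBitcast) ++fusible_users;
  }
  return fusible_users == 1;
}
```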
We further restrict fusion to the case where there is a single user to fuse
with, as otherwise we would also need to detect which is the best user to
fuse with, and the priority updates become more complicated. This can be done
in a later step.

PiperOrigin-RevId: 742164221
---
 third_party/xla/xla/debug_options_flags.cc    |   9 +
 .../model/gpu_indexing_performance_model.cc   |   9 +-
 .../gpu/model/symbolic_tile_analysis.h        |   3 +
 .../xla/xla/service/gpu/transforms/BUILD      |   1 +
 .../service/gpu/transforms/priority_fusion.cc | 218 +++++++++++++++---
 .../service/gpu/transforms/priority_fusion.h  |   3 +-
 .../gpu/transforms/priority_fusion_test.cc    | 103 +++++++++
 third_party/xla/xla/xla.proto                 |   6 +-
 8 files changed, 314 insertions(+), 38 deletions(-)

diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index c9df97529ff92b..9df7ead471c190 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -227,6 +227,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_triton_gemm(true);
   opts.set_xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms(false);
+  opts.set_xla_gpu_unsupported_enable_triton_multi_output_fusion(false);
   opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
   opts.set_xla_gpu_triton_gemm_any(true);
   opts.set_xla_gpu_unsupported_force_triton_gemm(false);
@@ -1774,6 +1775,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list,
           ->xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms(),
       "Enable lowering Triton GEMM fusions through the generic Triton "
       "emitter."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_unsupported_enable_triton_multi_output_fusion",
+      bool_setter_for(
+          &DebugOptions::
+              set_xla_gpu_unsupported_enable_triton_multi_output_fusion),
+      debug_options
+          ->xla_gpu_unsupported_enable_triton_multi_output_fusion(),
+      "Enable Triton multi-output fusions."));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_verify_triton_fusion_numerics",
       bool_setter_for(&DebugOptions::set_xla_gpu_verify_triton_fusion_numerics),
diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
index 8f22c3e347f61a..ce41097da44677 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -512,7 +512,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation(
   }

   absl::Duration compute_time =
-      ComputeTime(*device_info_, flops, launch_dimensions.num_blocks(),
+      ComputeTime(*device_info_, flops, num_blocks,
                   launch_dimensions.num_threads_per_block());

   absl::Duration memory_access_time = read_time + write_time;
@@ -543,15 +543,12 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion(
     return absl::FailedPreconditionError(absl::StrCat(
         "SymbolicTileAnalysis failed. ", fusion_decision->Explain()));
   }
-  // TODO(b/390559452): Add support for more than one fusion root.
- if (tile_sizes.size() != 1) { - return absl::UnimplementedError("Only 1 root is supported right now"); - } SymbolicTileAnalysis analysis = std::get(std::move(analysis_or_error)); TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, - analysis.ComputeTiledHloInstructions(tile_sizes[0])); + analysis.ComputeTiledHloInstructions( + tile_sizes[analysis.real_root_index()])); return EstimateRunTimeForTiledHloComputation( fusion_adaptor, tiled_hlo_computation, launch_dimensions); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index 17bfd837ad38c1..07243bb4904ef4 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -142,6 +142,9 @@ class SymbolicTileAnalysis { return root_indexing_.roots[idx]; } + // Returns the output index of the real root. + int64_t real_root_index() const { return root_indexing_.real_root_index; } + // Returns the number of tile parameters in this symbolic analysis. // TODO(b/390569102): This assumes that there is only one root that matters // for computing the tiling, and that it is the last symbolic tiled hlo diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index ba26a805777e15..68d8af6460438d 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2438,6 +2438,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/gpu/codegen/triton:support", + "//xla/hlo/analysis:hlo_dfs_reachability", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_traversal", diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index 1850ff499126da..11f3435fc59d1f 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -42,6 +42,7 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" #include "xla/backends/gpu/codegen/triton/support.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/analysis/hlo_dfs_reachability.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -162,6 +163,7 @@ class PriorityFusionQueue { fusion_analysis_cache_(fusion_analysis_cache), fusion_deduplication_cache_(fusion_deduplication_cache), fusion_info_cache_(*device_info_), + reachability_(HloDfsReachability::Build(computation)), triton_heroless_fusion_enabled_(triton_heroless_fusion_enabled) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); @@ -253,6 +255,10 @@ class PriorityFusionQueue { reverse_map_.erase(current_producer_); current_consumers_ = current_producer_->users(); + auto preferred_consumer = GetPreferredConsumer(current_producer_); + if (preferred_consumer) { + current_consumers_ = {*preferred_consumer}; + } if (HloPredicateIsOp(current_producer_)) { // We don't check if bitcasts can be fused with all consumers, so we @@ -266,6 +272,15 @@ class PriorityFusionQueue { return !current_consumers_.empty(); } + std::optional GetPreferredConsumer( + HloInstruction* producer) { + auto it = preferred_consumer_.find(producer); + if (it == preferred_consumer_.end()) { + return std::nullopt; + } + return it->second; + } + absl::Status UpdatePerformanceModelCache(HloInstruction* producer) { if (!IsFusible(*producer)) { return absl::OkStatus(); @@ -390,9 +405,19 @@ class PriorityFusionQueue { } // Updates data for the new fusion instruction and its users and operands. + // Both `original_producer` and `original_consumer` could have been removed + // already from the computation, waiting for deletion. We can still + // dereference them though. void OnFusingInstruction(HloInstruction* fusion, HloInstruction* original_producer, - HloInstruction* original_consumer) { + HloInstruction* original_consumer, + int64_t original_consumer_operand_index) { + bool creates_multi_output_fusion = + preferred_consumer_.contains(original_producer); + fusion_deduplication_cache_.UpdateFusedInstructionId( + fusion, original_producer, original_consumer, + original_consumer_operand_index, creates_multi_output_fusion); + if (fusion_process_dump_) { auto* fusion_step = fusion_process_dump_->add_fusion_steps()->mutable_fusion(); @@ -411,13 +436,24 @@ class PriorityFusionQueue { *fusion); } - // The original consumer was replaced with the fusion, but it's pointer can - // still be referenced somewhere, for example, in to_update_priority_. - // Priority recomputation is called before DCE. Remove all references to - // the original consumer here. - if (fusion != original_consumer) { + if (fusion == original_consumer) { + // We need to check again whether we can use `original_consumer` as a + // producer for a ProducerConsumer multi-output fusion. + preferred_consumer_.erase(original_consumer); + } else { + // The original consumer was replaced with the fusion, but it's pointer + // can still be referenced somewhere, for example, in to_update_priority_. + // Priority recomputation is called before DCE. Remove all references to + // the original consumer here. + reachability_->OnInstructionReplaced(/*previous=*/original_consumer, + /*now=*/fusion); RemoveInstruction(original_consumer); } + if (creates_multi_output_fusion) { + // After a multi-output fusion was created, we need to rebuild the + // HloDfsReachability data structure. 
+ reachability_ = HloDfsReachability::Build(computation_); + } // Collect the instructions whose priorities need to be updated. for (HloInstruction* operand : fusion->operands()) { @@ -433,10 +469,21 @@ class PriorityFusionQueue { } to_update_priority_.insert(operand); - // update the consumers of this operand that we care about, - // so we can do incremental update of the operand + // Update the consumers of this operand that we care about, + // so we can do incremental update of the operand. operands_to_new_consumers_[operand].push_back(fusion); + + // We may need to reset `preferred_consumer_`, as we don't know yet + // whether that fusion would still be valid. + auto it = preferred_consumer_.find(operand); + if (it != preferred_consumer_.end() && it->second == original_consumer) { + preferred_consumer_.erase(it); + } } + // TODO(b/390559452): For multi-output fusion, we would also need to update + // the priorities of the other consumers of `producer` with which we did not + // fuse. For now, as we only allow multi-output fusion if there is just a + // single fusible consumer, this is not needed. to_update_priority_.insert(fusion); } @@ -451,6 +498,7 @@ class PriorityFusionQueue { } producer_priority_queue_.erase(reverse_it->second); reverse_map_.erase(reverse_it); + preferred_consumer_.erase(instruction); } // Returns a map from consumer to BlockLevelParameters. This is used to @@ -485,9 +533,24 @@ class PriorityFusionQueue { return -absl::InfiniteDuration(); } - // Don't fuse if we can't fuse in all users. if (auto fusion_decision = CanFuseWithAllNonBitcastUsers(producer); !fusion_decision) { + // If we cannot fuse `producer` into all non-bitcast consumers, try + // Triton multi-output fusion next. + std::vector possible_consumers = + FindPossibleConsumersForTritonMultiOutputFusion(producer); + if (CanFuseTritonMultiOutputWithSingleUser(producer, + possible_consumers)) { + GpuPerformanceModel::RunTimes run_times = + GpuPerformanceModel::EstimateRunTimes( + producer, *device_info_, &cost_analysis_, + GpuPerformanceModelOptions::Default( + &fusion_analysis_cache_, &gpu_performance_model_cache_), + /*fused_consumers=*/possible_consumers); + preferred_consumer_[producer] = possible_consumers[0]; + return run_times.time_unfused - run_times.time_fused; + } + // Don't fuse if we can't fuse in all users. 
if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); auto* step = fusion_process_dump_->add_fusion_steps() @@ -564,10 +627,12 @@ class PriorityFusionQueue { } TiledRunTimeDataOrError GetTiledRunTimeDataCached( - const HloInstruction* producer, const HloInstruction* consumer) { + const HloInstruction* producer, const HloInstruction* consumer, + bool use_multi_output_fusion = false) { FusionDeduplicationCache::FusionId fusion_id = [&]() { absl::MutexLock lock(&fusion_deduplication_cache_mutex_); - return fusion_deduplication_cache_.GetFusionId(producer, consumer); + return fusion_deduplication_cache_.GetFusionId(producer, consumer, + use_multi_output_fusion); }(); { @@ -579,7 +644,8 @@ class PriorityFusionQueue { } } - auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer); + auto fusion = HloFusionAdaptor::ForProducerConsumer( + producer, consumer, use_multi_output_fusion); absl::StatusOr result_or_status = gpu_indexing_performance_model_.TryFindBestTilingForFusion(*fusion); @@ -611,7 +677,8 @@ class PriorityFusionQueue { } FusionDecision CanFuseTriton(HloInstruction* producer, - HloInstruction* consumer) { + HloInstruction* consumer, + bool use_multi_output_fusion = false) { if (!IsGenericTritonFusion(*producer) && !IsGenericTritonFusion(*consumer) && !triton_heroless_fusion_enabled_) { return FusionDecision::Forbid("triton heroless fusion is not enabled"); @@ -626,7 +693,7 @@ class PriorityFusionQueue { } TiledRunTimeDataOrError tiled_run_time_data_or_error = - GetTiledRunTimeDataCached(producer, consumer); + GetTiledRunTimeDataCached(producer, consumer, use_multi_output_fusion); if (const auto* fusion_decision = std::get_if(&tiled_run_time_data_or_error)) { @@ -638,9 +705,14 @@ class PriorityFusionQueue { // This is our way to pass the runtime estimate to the CalculatePriorities() // function. + // This is somewhat brittle as we currently don't distinguish between + // ProducerConsumer fusion where we allow multi-output fusions to be formed, + // and ProducerConsumer fusion where we don't allow it. Same for the + // `block_level_parameters_cache_` down below. Currently we only try out + // multi-output fusion if we cannot fuse into all consumers, and it is tried + // last, so the final cached value should be what we want. gpu_performance_model_cache_.Set( *producer, *consumer, tiled_run_time_data.runtime_data.exec_time); - { absl::MutexLock lock(&block_level_parameters_cache_mutex_); block_level_parameters_cache_[producer][consumer] = @@ -780,6 +852,63 @@ class PriorityFusionQueue { return fusion_decision; } + // Checks whether any operand of `consumer` is reachable from `producer` + // following user edges in the HLO graph. If that is the case, we would + // introduce a cycle by fusing `producer` into `consumer`. + bool OperandReachableFromProducer(const HloInstruction* producer, + const HloInstruction* consumer) { + for (const auto* consumer_operand : consumer->operands()) { + CHECK(reachability_->IsPresent(consumer_operand) && + reachability_->IsPresent(producer)) + << "Reachability map is incomplete. 
This should never " + "happen."; + if (producer != consumer_operand && + reachability_->IsReachable(producer, consumer_operand)) { + return true; + } + } + return false; + } + + std::vector FindPossibleConsumersForTritonMultiOutputFusion( + HloInstruction* producer) { + bool triton_multi_output_fusion_enabled = + producer->GetModule() + ->config() + .debug_options() + .xla_gpu_unsupported_enable_triton_multi_output_fusion(); + if (!triton_multi_output_fusion_enabled) { + return {}; + } + std::vector possible_consumers; + for (const auto& user : producer->users()) { + if (HloPredicateIsOp(user)) { + continue; + } + if (CanFuseTriton(producer, user, /*use_multi_output_fusion=*/true) && + !OperandReachableFromProducer(producer, user)) { + possible_consumers.push_back(user); + } + } + return possible_consumers; + } + + FusionDecision CanFuseTritonMultiOutputWithSingleUser( + HloInstruction* producer, + const std::vector& possible_consumers) { + if (possible_consumers.empty()) { + return FusionDecision::Forbid("No users to fuse"); + } + + if (possible_consumers.size() != 1) { + // TODO(b/390559452): If there are several possible consumers to fuse + // with, decide which one is best. Also depends on what further fusions + // might be possible, needs checking the reachability graph. + return FusionDecision::Forbid("more than one consumer to fuse with"); + } + return FusionDecision::Allow(); + } + FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) { if (producer->users().empty()) { return FusionDecision::Forbid("No users to fuse"); @@ -825,6 +954,11 @@ class PriorityFusionQueue { // A reverse map that helps find an instruction in the priority queue. absl::flat_hash_map reverse_map_; + // Stores a mapping from the producer to the preferred consumer to fuse into. + // This is only used in case that we want to use ProducerConsumer multi-output + // fusion. + absl::flat_hash_map preferred_consumer_; + // The current producer being visited. HloInstruction* current_producer_; @@ -880,6 +1014,10 @@ class PriorityFusionQueue { // like shared memory usage or number of unnested reductions of fusion nodes. FusionInfoCache fusion_info_cache_; + // Allows evaluation of whether an HloInstruction is an ancestor of another + // HloInstruction. + std::unique_ptr reachability_; + // If true, redirect all fusion decisions to Triton fusion. bool triton_heroless_fusion_enabled_; @@ -984,7 +1122,12 @@ absl::StatusOr PriorityFusion::Run( block_level_parameters_map = fusion_queue->GetBlockLevelParametersMap(producer); - for (auto* consumer : fusion_queue->current_consumers()) { + auto preferred_consumer = fusion_queue->GetPreferredConsumer(producer); + std::vector consumers = + fusion_queue->current_consumers(); + bool use_multi_output_fusion = preferred_consumer.has_value(); + + for (auto* consumer : consumers) { // Don't fuse into single bitcasts. We ignore them in the check // CanFuseWithAllNonBitcastUsers(), so we need to check it here. 
if (HloPredicateIsOp(consumer)) { @@ -998,12 +1141,8 @@ absl::StatusOr PriorityFusion::Run( int64_t consumer_operand_index = consumer->operand_index(producer); fusion_queue->PreFusion(producer, consumer); - auto fusion_instruction = Fuse(producer, consumer); - fusion_deduplication_cache.UpdateFusedInstructionId( - fusion_instruction, producer, consumer, consumer_operand_index); - fusion_queue->OnFusingInstruction(fusion_instruction, producer, - consumer); - + auto fusion_instruction = + Fuse(producer, consumer, use_multi_output_fusion); auto backend_config_it = block_level_parameters_map.find(consumer); if (backend_config_it != block_level_parameters_map.end()) { TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config( @@ -1011,20 +1150,25 @@ absl::StatusOr PriorityFusion::Run( fusion_instruction->set_fusion_kind( HloInstruction::FusionKind::kCustom); } + fusion_queue->OnFusingInstruction(fusion_instruction, producer, + consumer, consumer_operand_index); changed = true; } fusion_queue->ComputeRuntimesOfRemovedConsumers(); - if (producer->user_count() == 0) { + if (use_multi_output_fusion || producer->user_count() == 0) { fusion_queue->InvalidateCaches(producer); - producer->DetachFromOperandsAndUsers(); fusion_queue->RemoveInstruction(producer); - // Remove from computation. - TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); + // When we use ProducerConsumer multi-output fusion, `producer` will + // have been removed already. + if (!use_multi_output_fusion) { + producer->DetachFromOperandsAndUsers(); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); + } } - for (auto* consumer : fusion_queue->current_consumers()) { + for (auto* consumer : consumers) { fusion_queue->InvalidateCaches(consumer); } TF_RETURN_IF_ERROR(fusion_queue->UpdatePriorities()); @@ -1094,7 +1238,8 @@ HloInstruction::FusionKind PriorityFusion::ChooseKind( } HloInstruction* PriorityFusion::Fuse(HloInstruction* producer, - HloInstruction* consumer) { + HloInstruction* consumer, + bool use_multi_output_fusion) { VLOG(2) << "Fusing " << producer->ToString() << " into " << consumer->ToString(); @@ -1115,9 +1260,22 @@ HloInstruction* PriorityFusion::Fuse(HloInstruction* producer, /*skip_async_execution_thread_overwrite=*/false); if (HloPredicateIsOp(producer)) { - fusion_instruction->MergeFusionInstruction(producer); + if (use_multi_output_fusion) { + fusion_instruction->MergeFusionInstructionIntoMultiOutput(producer); + } else { + fusion_instruction->MergeFusionInstruction(producer); + } } else { - fusion_instruction->FuseInstruction(producer); + if (use_multi_output_fusion) { + fusion_instruction->FuseInstructionIntoMultiOutput(producer); + // MergeFusionInstructionIntoMultiOutput already removes `producer` from + // the computation. Do the same here, so that we have the invariant that + // the producer has been cleaned up when multi-output fusion is used. 
+ CHECK_EQ(0, producer->user_count()); + TF_CHECK_OK(producer->parent()->RemoveInstruction(producer)); + } else { + fusion_instruction->FuseInstruction(producer); + } } if (fusion_instruction != consumer) { diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h index f1d19532a755f5..c9a29a5c056f35 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h @@ -60,7 +60,8 @@ class PriorityFusion : public HloModulePass { HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, const HloInstruction* consumer); - HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); + HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer, + bool use_multi_output_fusion = false); private: // Consumes a unit of compiler fuel and returns true if we should diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc index b534f46150b382..463b4d45c475b5 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc @@ -1026,6 +1026,109 @@ ENTRY main { 2); } +TEST_F(PriorityFusionTest, + FuseTritonProducerWithTwoConsumersUsingMultiOutputFusion) { + const std::string kHloText = R"( +HloModule t + +producer_computation { + parameter_0 = f32[125]{0} parameter(0) + ROOT broadcast = f32[125,127] broadcast(parameter_0), dimensions={0} +} + +consumer_computation { + parameter_0 = f32[125,127] parameter(0) + ROOT log = f32[125,127] log(parameter_0) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + producer_fusion = f32[125,127] fusion(param_0), kind=kCustom, calls=producer_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tiles":[{"sizes":["1","127"]}],"num_warps":"1"}}} + consumer_fusion = f32[125,127] fusion(producer_fusion), kind=kLoop, calls=consumer_computation + ROOT tuple = (f32[125,127], f32[125,127]) tuple(consumer_fusion, producer_fusion) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(false); + EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); + + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(true); + EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction *fusion1, *fusion2; + EXPECT_THAT(root, + GmockMatch(m::Tuple( + m::GetTupleElement(m::Fusion(&fusion1, m::Parameter()), 0), + m::GetTupleElement(m::Fusion(&fusion2, m::Parameter()), 1)))); + EXPECT_EQ(fusion1, fusion2); + EXPECT_TRUE(IsGenericTritonFusion(*fusion1)); + TF_ASSERT_OK_AND_ASSIGN(auto backend_config1, + fusion1->backend_config()); + EXPECT_TRUE( + backend_config1.fusion_backend_config().has_block_level_fusion_config()); + EXPECT_EQ(backend_config1.fusion_backend_config() + .block_level_fusion_config() + .output_tiles(0) + .sizes_size(), + 2); +} + +TEST_F(PriorityFusionTest, + FuseProducerWithTritonConsumerUsingMultiOutputFusion) { + const std::string kHloText = R"( +HloModule t + +consumer_computation { + parameter_0 
= f32[125,127] parameter(0) + ROOT log = f32[125,127] log(parameter_0) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + producer = f32[125,127] broadcast(param_0), dimensions={0} + consumer_fusion = f32[125,127] fusion(producer), kind=kCustom, calls=consumer_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tiles":[{"sizes":["1","127"]}],"num_warps":"1"}}} + ROOT tuple = (f32[125,127], f32[125,127]) tuple(consumer_fusion, producer) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(false); + EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); + + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(true); + EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction *fusion1, *fusion2; + EXPECT_THAT(root, + GmockMatch(m::Tuple( + m::GetTupleElement(m::Fusion(&fusion1, m::Parameter()), 0), + m::GetTupleElement(m::Fusion(&fusion2, m::Parameter()), 1)))); + EXPECT_EQ(fusion1, fusion2); + EXPECT_TRUE(IsGenericTritonFusion(*fusion1)); + TF_ASSERT_OK_AND_ASSIGN(auto backend_config1, + fusion1->backend_config()); + EXPECT_TRUE( + backend_config1.fusion_backend_config().has_block_level_fusion_config()); + EXPECT_EQ(backend_config1.fusion_backend_config() + .block_level_fusion_config() + .output_tiles(0) + .sizes_size(), + 2); +} + TEST_F(PriorityFusionTest, TritonProducerNotSupported_DoNotFuse) { const std::string kHloText = R"( HloModule t diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 60fbfe8c8472e5..eb20e72bd5d53e 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -775,6 +775,10 @@ message DebugOptions { // Internal debug/testing flag to switch Triton GEMM fusions on or off. bool xla_gpu_unsupported_enable_triton_gemm = 322; + // Enable experimental triton multi-output fusion. + // TODO(b/390559452): Remove the flag once the feature is stable. + bool xla_gpu_unsupported_enable_triton_multi_output_fusion = 382; + // Internal debug/testing flag to force all GEMMs to use Triton, independently // of known issues. // TODO(b/395903738): use to make specific tests pass on A100 while working @@ -1196,7 +1200,7 @@ message DebugOptions { // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 382 + // Next id: 383 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 6c3dd7fe6775371ccbfb1fd2e69ec4c6c8ccdc5f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 31 Mar 2025 00:30:56 -0700 Subject: [PATCH 0030/1324] Automated Code Change PiperOrigin-RevId: 742164515 --- tensorflow/core/ir/utility.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/core/ir/utility.cc b/tensorflow/core/ir/utility.cc index e04b4df00c98dd..34ab5dc5e44f95 100644 --- a/tensorflow/core/ir/utility.cc +++ b/tensorflow/core/ir/utility.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/core/ir/utility.h" +#include +#include #include #include "mlir/IR/Block.h" // from @llvm-project From 44919f1b593683fcbed963cd0a8ab783e92def2a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 31 Mar 2025 00:33:29 -0700 Subject: [PATCH 0031/1324] [xla:gpu] Add an HLO test for testing conditional (case) command PiperOrigin-RevId: 742165114 --- .../service/gpu/tests/command_buffer_test.cc | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc index e2cf02b8871e7f..377daaffe0c27c 100644 --- a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc +++ b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include @@ -69,5 +70,131 @@ TEST_F(CommandBufferTest, Fusions) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_F(CommandBufferTest, TrueFalseConditional) { + constexpr absl::string_view hlo_text = R"( + HloModule m, is_scheduled=true + + double { + p0 = f32[2,2] parameter(0) + ROOT add = f32[2,2] add(p0, p0) + } + + square { + p0 = f32[2,2] parameter(0) + ROOT add = f32[2,2] multiply(p0, p0) + } + + double_computation { + p0 = f32[2,2] parameter(0) + ROOT double = f32[2,2] fusion(p0), kind=kLoop, calls=double + } + + square_computation { + p0 = f32[2,2] parameter(0) + ROOT square = f32[2,2] fusion(p0), kind=kLoop, calls=square + } + + command_buffer { + p0 = pred[] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT conditional = f32[2,2] conditional(p0, p1, p1), + true_computation=double_computation, + false_computation=square_computation + } + + ENTRY main { + p0 = pred[] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT call = f32[2,2] call(p0, p1), to_apply=command_buffer + })"; + + Literal p1 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + + { // Execute `true` branch. + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); + + Literal pred = LiteralUtil::CreateR0(true); + Literal expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + Literal result = ExecuteNoHloPasses(std::move(m), {&pred, &p1}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); + } + + { // Execute `false` branch. 
+    TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
+
+    Literal pred = LiteralUtil::CreateR0(false);
+    Literal expected = LiteralUtil::CreateR2({{1.0, 4.0}, {9.0, 16.0}});
+    Literal result = ExecuteNoHloPasses(std::move(m), {&pred, &p1});
+    EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+  }
+}
+
+TEST_F(CommandBufferTest, IndexConditional) {
+  constexpr absl::string_view hlo_text = R"(
+    HloModule m, is_scheduled=true
+
+    double {
+      p0 = f32[2,2] parameter(0)
+      ROOT add = f32[2,2] add(p0, p0)
+    }
+
+    square {
+      p0 = f32[2,2] parameter(0)
+      ROOT add = f32[2,2] multiply(p0, p0)
+    }
+
+    double_computation {
+      p0 = f32[2,2] parameter(0)
+      ROOT double = f32[2,2] fusion(p0), kind=kLoop, calls=double
+    }
+
+    square_computation {
+      p0 = f32[2,2] parameter(0)
+      ROOT square = f32[2,2] fusion(p0), kind=kLoop, calls=square
+    }
+
+    command_buffer {
+      p0 = s32[] parameter(0)
+      p1 = f32[2,2] parameter(1)
+      ROOT conditional = f32[2,2] conditional(p0, p1, p1),
+        branch_computations={double_computation, square_computation}
+    }
+
+    ENTRY main {
+      p0 = s32[] parameter(0)
+      p1 = f32[2,2] parameter(1)
+      ROOT call = f32[2,2] call(p0, p1), to_apply=command_buffer
+    })";
+
+  Literal p1 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}});
+
+  {  // Execute `0` branch.
+    TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
+
+    Literal index = LiteralUtil::CreateR0(0);
+    Literal expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}});
+    Literal result = ExecuteNoHloPasses(std::move(m), {&index, &p1});
+    EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+  }
+
+  {  // Execute `1` branch.
+    TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
+
+    Literal index = LiteralUtil::CreateR0(1);
+    Literal expected = LiteralUtil::CreateR2({{1.0, 4.0}, {9.0, 16.0}});
+    Literal result = ExecuteNoHloPasses(std::move(m), {&index, &p1});
+    EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+  }
+
+  {  // Execute `1024` branch (an out-of-bounds index executes the N-1 branch).
+    TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
+
+    Literal index = LiteralUtil::CreateR0(1024);
+    Literal expected = LiteralUtil::CreateR2({{1.0, 4.0}, {9.0, 16.0}});
+    Literal result = ExecuteNoHloPasses(std::move(m), {&index, &p1});
+    EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+  }
+}
+
 }  // namespace
 }  // namespace xla::gpu
From 4a8276d71f3312f585881901dc7f69cda65620e2 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 31 Mar 2025 01:05:20 -0700
Subject: [PATCH 0032/1324] Automated Code Change

PiperOrigin-RevId: 742171297
---
 .../xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc
index 19cddf67432994..0d4fe39f6c8741 100644
--- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc
@@ -428,7 +428,7 @@ class GraphString {

   std::string Graph() const {
     std::string graph;
-    for (OpDescriptor op : graph_) {
+    for (const OpDescriptor& op : graph_) {
       std::vector operand_uids_in_graph;
       for (HloInstruction* operand : op.operands) {
         if (OpInGraph(operand)) {
From 802e577d611bfd75d3f9ed861aca0f09e1cef3e3 Mon Sep 17 00:00:00 2001
From: Mikhail Goncharov
Date: Mon, 31 Mar 2025 01:21:52 -0700
Subject: [PATCH 0033/1324] [XLA:GPU] Update checks of having the generic
 triton emitter on

Simplifies the testing setup and makes sure that all legacy dots go through
the generic emitter.

PiperOrigin-RevId: 742174830
---
 .../gpu/codegen/triton/fusion_emitter.cc      |  4 ++-
 .../triton/fusion_emitter_device_test.cc      | 26 +++++++------------
 .../gpu/codegen/triton/support_test.cc        |  9 +------
 3 files changed, 13 insertions(+), 26 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc
index 73d5aec5143c8b..3640e3c7e5b283 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc
@@ -721,7 +721,6 @@ absl::StatusOr EmitDot(EmitterLocOpBuilder& b,
                        const TiledHloInstruction& tiled_hlo_dot,
                        mlir::triton::FuncOp fn,
                        ValueRange tile_multi_index) {
-  QCHECK(UseGenericTritonEmitterForGemms(tiled_hlo_dot.hlo()));
   // We expect to get a tiled HLO in form:
   //
   //   left { ... }
@@ -1569,6 +1568,9 @@ absl::StatusOr CreateTritonModule(
   // explicitly.
   std::optional tma_metadata = std::nullopt;
   if (fusion_kind == kTritonGemmFusionKind) {
+    // If the generic Triton emitter is enabled, we should never go through the
+    // legacy MatMul emitter.
+    QCHECK(!UseGenericTritonEmitterForGemms(fusion));
     TF_ASSIGN_OR_RETURN(tma_metadata,
                         EmitMatMul(b, libdevice_path, device_info, fusion, fn,
                                    block_level_parameters));
diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc
index b74d67f1bd8afa..f25d7ed31903a8 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc
@@ -69,14 +69,6 @@ class TritonEmitterTest : public GpuCodegenTest {
                      ->GetDeviceDescription()
                      .gpu_compute_capability();
   }
-
-  DebugOptions GetDebugOptionsForTest() const override {
-    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    // TODO(b/393299275): Remove when flag is enabled by default.
- debug_options - .set_xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms(true); - return debug_options; - } }; TEST_F(TritonEmitterTest, ReductionOnMinormostAxisIsEmittedCorrectly) { @@ -683,7 +675,7 @@ triton_softmax_computation { ENTRY main { parameter_1 = f32[32]{0} parameter(1) parameter_0 = f32[32,16]{1,0} parameter(0) - ROOT _ = f32[32,16]{1,0} fusion(parameter_0, parameter_1), kind=kCustom, + ROOT _ = f32[32,16]{1,0} fusion(parameter_0, parameter_1), kind=kCustom, calls=triton_softmax_computation, backend_config={ "fusion_backend_config":{ @@ -777,8 +769,8 @@ triton_softmax_computation { ENTRY main { parameter_0 = f32[16,32]{1,0} parameter(0) parameter_1 = f32[32]{0} parameter(1) - ROOT _ = f32[16,32]{1,0} fusion(parameter_0,parameter_1), kind=kCustom, - calls=triton_softmax_computation, + ROOT _ = f32[16,32]{1,0} fusion(parameter_0,parameter_1), kind=kCustom, + calls=triton_softmax_computation, backend_config={ "fusion_backend_config":{ "kind":"__triton", @@ -876,7 +868,7 @@ triton_softmax_computation { ENTRY main { parameter_1 = f32[64,32,16]{2,1,0} parameter(1) parameter_0 = f32[16]{0} parameter(0) - ROOT _ = f32[64,32,16]{2,1,0} fusion(f32[64,32,16]{2,1,0} parameter_1, f32[16]{0} parameter_0), kind=kCustom, + ROOT _ = f32[64,32,16]{2,1,0} fusion(f32[64,32,16]{2,1,0} parameter_1, f32[16]{0} parameter_0), kind=kCustom, calls=triton_softmax_computation, backend_config={ "fusion_backend_config":{ @@ -1158,8 +1150,8 @@ fused_computation { ENTRY entry_computation { param_0.2 = f32[16,16,32] parameter(0) - ROOT fusion = f32[4,4,8] fusion(param_0.2), kind=kCustom, - calls=fused_computation, + ROOT fusion = f32[4,4,8] fusion(param_0.2), kind=kCustom, + calls=fused_computation, backend_config={ "fusion_backend_config":{ "kind":"__triton", @@ -2264,7 +2256,7 @@ ENTRY main { "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["32"]}], - "num_warps":"1", + "num_warps":"1", "num_ctas":"1", "num_stages":"1"}}} })"; @@ -2354,7 +2346,7 @@ ENTRY entry { kind=kCustom, calls=dot, backend_config={ "fusion_backend_config":{ "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ - "output_tiles":[{"sizes":["16", "64"]}], "num_warps":"1", + "output_tiles":[{"sizes":["16", "64"]}], "num_warps":"1", "num_ctas":"1", "num_stages":"1" } } @@ -2392,7 +2384,7 @@ ENTRY e (p0.1: f32[11,1,24,1], p1.1: f32[128,8]) -> f32[264,8] { p0.1 = f32[11,1,24,1]{3,2,1,0} parameter(0) bitcast = f32[264]{0} bitcast(p0.1) p1.1 = f32[128,8]{1,0} parameter(1) - ROOT result.1 = f32[264,8]{1,0} fusion(bitcast, p1.1), kind=kCustom, + ROOT result.1 = f32[264,8]{1,0} fusion(bitcast, p1.1), kind=kCustom, calls=triton_dot, backend_config={ "fusion_backend_config":{ "kind":"__triton_nested_gemm_fusion", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index ea14452a99134a..5e28de6292fc30 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -1497,14 +1497,7 @@ INSTANTIATE_TEST_SUITE_P(ComplexTestSuite, ComplexTest, AllTestCombinationsForOpcodes(kTestedOpsComplex), TritonSupportTestTypeAndOpcodeAndDeviceToString); -class DotTest : public TritonSupportTest { - public: - DebugOptions GetDebugOptionsForTest() const override { - DebugOptions opts = TritonSupportTest::GetDebugOptionsForTest(); - opts.set_xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms(true); - return opts; - } -}; 
+using DotTest = TritonSupportTest;

 class DotTypesTest : public DotTest,
                      public ::testing::WithParamInterface<
From 7ede3fbb1352756b45211054558d806c05eac8cb Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 31 Mar 2025 02:02:46 -0700
Subject: [PATCH 0034/1324] compat: Update forward compatibility horizon to
 2025-03-31

PiperOrigin-RevId: 742184186
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 8d378effc9366b..4fae6576e28ebf 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -29,7 +29,7 @@
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 3, 30)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 3, 31)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None

From 7c9fe0897c6b1efce2847d445ecfa22d79cc3c1c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 31 Mar 2025 02:02:55 -0700
Subject: [PATCH 0035/1324] Update GraphDef version to 2183.

PiperOrigin-RevId: 742184255
---
 tensorflow/core/public/version.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 6455563109606c..14905975e0e21b 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -93,7 +93,7 @@ limitations under the License.
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0

-#define TF_GRAPH_DEF_VERSION 2182  // Updated: 2025/3/30
+#define TF_GRAPH_DEF_VERSION 2183  // Updated: 2025/3/31

 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
From 0ebf4bbfe7b943e80b1e5e2a020ee910f1b375e3 Mon Sep 17 00:00:00 2001
From: Jaroslav Sevcik
Date: Mon, 31 Mar 2025 02:05:28 -0700
Subject: [PATCH 0036/1324] PR #24248: Move-to-memory-space custom calls use
 default layout

Imported from GitHub PR https://github.com/openxla/xla/pull/24248

Using non-default layout for MoveToHost makes the compiler insert a transpose
operation on the host value if that value flows into the root of the entry
computation. Such transposes cause the host offloader to emit a slow on-host
transpose (and a warning on the console). We see those warnings on maxtext
llama2-7b with optimizer state offloading with fsdp=2. This patch enforces the
default layout for on-host values so that no transpose is necessary (as long
as there is no override of the layout for host values in the entry
computation).

Note that such transposes cannot be sunk into the uses by the offloading
legalizer because there is nowhere to sink to - the value is returned from the
computation.
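A minimal sketch of the enforced constraint (illustrative only; it assumes
the LayoutUtil API, and `WithDefaultLayout` is a made-up helper, while the
actual change lives in GpuLayoutAssignment::AddBackendConstraints below):

```
#include "xla/layout_util.h"
#include "xla/shape.h"

// Reset a shape to the default (descending minor-to-major) layout. The pass
// applies the resulting shape to both the operand and the result of the
// MoveToHost/MoveToDevice custom calls, so no transpose has to be
// materialized on host values.
xla::Shape WithDefaultLayout(xla::Shape shape) {
  xla::LayoutUtil::SetToDefaultLayout(&shape);
  return shape;
}
```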
Copybara import of the project: -- ad192bbba93040a88bda9e4e1074dd1de7153b32 by Jaroslav Sevcik : Use default layout for offloading ops -- f8f0ddcf73c9a4792cbbfd99c2d465f4de80c418 by Jaroslav Sevcik : Address reviewer comments Merging this change closes #24248 PiperOrigin-RevId: 742185098 --- .../gpu/transforms/layout_assignment.cc | 25 +++++++++++++------ .../gpu/transforms/layout_assignment_test.cc | 20 +++++++-------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index 7d83bffe5b085a..b5fbe5c093f104 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -340,6 +340,15 @@ bool IsPackedInstruction(const HloInstruction* instruction) { instruction->operand(0)->shape().element_type())); } +bool IsCustomCallToMemoryPlacement(const HloInstruction* hlo) { + if (hlo->opcode() != HloOpcode::kCustomCall) { + return false; + } + const std::string& target = hlo->custom_call_target(); + return target == memory_annotations::kMoveToDeviceCustomCallTarget || + target == memory_annotations::kMoveToHostCustomCallTarget; +} + } // namespace absl::Status GpuLayoutAssignment::AddDotBackendConstraints( @@ -571,6 +580,13 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( LayoutUtil::SetToDefaultLayout(subshape); }); TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction)); + } else if (IsCustomCallToMemoryPlacement(instruction)) { + // Make sure that host memory buffers use the default layout so that + // the compiler does not insert transposes on host memory buffers. + Shape operand_shape = instruction->operand(0)->shape(); + LayoutUtil::SetToDefaultLayout(&operand_shape); + TF_RETURN_IF_ERROR(SetOperandLayout(operand_shape, instruction, 0)); + TF_RETURN_IF_ERROR(SetInstructionLayout(operand_shape, instruction)); } } return absl::OkStatus(); @@ -691,19 +707,12 @@ bool GpuLayoutAssignment::PropagateReductionLayoutToOperand( bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance( const HloInstruction* instruction) { - // The host offloading custom calls will be eventually removed - // by the offloader, so we need to make sure that the calls do not change - // the layout and thus cause layout mismatches after the removal. // The TopK custom call cannot handle the case if the operand has a different // layout. 
const HloCustomCallInstruction* custom_call = DynCast(instruction); if (custom_call != nullptr && - (custom_call->custom_call_target() == - memory_annotations::kMoveToHostCustomCallTarget || - custom_call->custom_call_target() == - memory_annotations::kMoveToDeviceCustomCallTarget || - custom_call->custom_call_target() == kTopKCustomCallTarget)) { + custom_call->custom_call_target() == kTopKCustomCallTarget) { return false; } diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc index 2e54f9ec373ad5..8650c3c4a52c22 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc @@ -517,9 +517,10 @@ TEST_F(LayoutAssignmentTest, MoveToHostCustomCallConstrained) { HloModule TestModule ENTRY entry { - Arg_0 = f32[2,5,5]{2,1,0} parameter(0) + Arg_0 = f32[2,5,5]{0,1,2} parameter(0) custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToHost" - ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}} + ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), + custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, @@ -536,9 +537,8 @@ ENTRY entry { const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0"); const Layout input_layout = call_0->operand(0)->shape().layout(); const Layout output_layout = call_0->shape().layout(); - EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout)) - << "Expected the same input/output layouts. Input: " << input_layout - << ". Output: " << output_layout; + EXPECT_EQ(input_layout, LayoutUtil::GetDefaultLayoutForR3()); + EXPECT_EQ(output_layout, LayoutUtil::GetDefaultLayoutForR3()); } TEST_F(LayoutAssignmentTest, MoveToDeviceCustomCallConstrained) { @@ -546,9 +546,10 @@ TEST_F(LayoutAssignmentTest, MoveToDeviceCustomCallConstrained) { HloModule TestModule ENTRY entry { - Arg_0 = f32[2,5,5]{2,1,0} parameter(0) + Arg_0 = f32[2,5,5]{1,2,0} parameter(0) custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToDevice" - ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}} + ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), + custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{0,1,2}} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, @@ -565,9 +566,8 @@ ENTRY entry { const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0"); const Layout input_layout = call_0->operand(0)->shape().layout(); const Layout output_layout = call_0->shape().layout(); - EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout)) - << "Expected the same input/output layouts. Input: " << input_layout - << ". Output: " << output_layout; + EXPECT_EQ(input_layout, LayoutUtil::GetDefaultLayoutForR3()); + EXPECT_EQ(output_layout, LayoutUtil::GetDefaultLayoutForR3()); } TEST_F(LayoutAssignmentTest, CuDNNConvolutionHasNHWCLayoutPostHopper) { From da6aac347655159f306e6e95addec4ed19c0cfd1 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 31 Mar 2025 02:17:19 -0700 Subject: [PATCH 0037/1324] Automated Code Change PiperOrigin-RevId: 742187650 --- third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 3a870add02f400..6f99c125a54040 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -4297,6 +4297,7 @@ GetGenericCudnnOperationGraph( for (int op_index = 1; op_index < op_graph.Size(); ++op_index) { TF_ASSIGN_OR_RETURN(op_descriptor, op_graph.OpDescriptorAt(op_index)); std::vector preceding_ops; + preceding_ops.reserve(op_descriptor.operand_uids.size()); for (int operand_uid : op_descriptor.operand_uids) { preceding_ops.emplace_back( op_graph.FindOpDescriptor(operand_uid).value()); From af32e0cf0bff88792dcf6f9f21c416acba1facd2 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 31 Mar 2025 04:14:53 -0700 Subject: [PATCH 0038/1324] PR #23827: [DOC] Fix multihost HLO runner doc. Imported from GitHub PR https://github.com/openxla/xla/pull/23827 - Omit -c opt which is the default. - Omit dump_hlo_as_text which is the default. - Omit --config=cuda which is nowadays handled by configure.py. - Remove --xla_disable_all_hlo_passes which has no effect in presence of --run_xla_backend_only. - Leave single mention of dynamic_mode=off which usually isn't needed. - Other minor fixes. Copybara import of the project: -- bc64c89d3cb4db37945331d94faf9acc573ed6d9 by Ilia Sergachev : [DOC] Fix multihost HLO runner doc. - Omit -c opt which is the default. - Omit dump_hlo_as_text which is the default. - Omit --config=cuda which is nowadays handled by configure.py. - Remove --xla_disable_all_hlo_passes which has no effect in presence of --run_xla_backend_only. - Leave single mention of dynamic_mode=off which usually isn't needed. - Other minor fixes. Merging this change closes #23827 PiperOrigin-RevId: 742214204 --- .../xla/docs/tools_multihost_hlo_runner.md | 47 +++++++++---------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/third_party/xla/docs/tools_multihost_hlo_runner.md b/third_party/xla/docs/tools_multihost_hlo_runner.md index 17d9bcbabd47b6..783eeaa75e9c9e 100644 --- a/third_party/xla/docs/tools_multihost_hlo_runner.md +++ b/third_party/xla/docs/tools_multihost_hlo_runner.md @@ -11,26 +11,23 @@ We can identify these HLOs by seeing `sharding=` annotations. For example `sharding={devices=[1,1,2,1]0,1}` means that the annotated tensor should be sharded to 2 GPUs (GPU0 and GPU1) along the 3rd dimension. -The following instructions assume the working directory is the xla Git +The following instructions assume the working directory is the XLA Git repository and that `./configure.py` has been run. If we have enough GPUs, we can replay these HLOs like this: ``` -bazel run -c opt --config=cuda --dynamic_mode=off \ - //xla/tools/multihost_hlo_runner:hlo_runner_main -- my-hlo.txt +bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- my-hlo.txt ``` Tip: If the input generation takes too long or uses too much host memory, consider using `--hlo_argument_mode=uninitialized`. 
It is also possible to compile the same HLO without running it by setting -`--run=false` +`--run=false`: ``` -bazel run -c opt --config=cuda --dynamic_mode=off \ - //xla/tools/multihost_hlo_runner:hlo_runner_main \ - -- --run=false my-hlo.txt +bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- --run=false my-hlo.txt ``` In that case, a single GPU is necessary, unless the @@ -63,44 +60,43 @@ Note, those instructions can be outdated more quickly. Adjust as needed. ``` # The 8 below is the number of GPUs you have. # test-pax.sh --help for more details on the parallelization options -(export XLA_FLAGS="--xla_dump_to=/tmp/dump --xla_dump_hlo_as_text"; test-pax.sh --fsdp 8 --batch-per-gpu 1) +(export XLA_FLAGS="--xla_dump_to=/tmp/dump"; test-pax.sh --fsdp 8 --batch-per-gpu 1) ls -lSh /tmp/dump/*before_optimizations.txt # The biggest file one is normally the one you care about. # I picked one, for the rest of the scripts, but the name could change when you change the JAX or XLA version. ``` -### Build XLA multinode runner +### Build XLA multihost runner ``` cd /opt/xla/ ./configure.py --backend CUDA --nccl -bazel build -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main +bazel build //xla/tools/multihost_hlo_runner:hlo_runner_main ``` ### Single process example: Before optimization graph replay ``` -bazel run -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main -- /tmp/dump/module_0023.pjit__wrapped_step_fn.before_optimizations.txt +bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- \ + /tmp/dump/module_0023.pjit__wrapped_step_fn.before_optimizations.txt ``` ### Single process example: After optimization graph replay -To replay an optimized HLO, you must use those two parameters -`--run_xla_backend_only=true --xla_disable_all_hlo_passes=true`. Otherwise, it -will try to recompile the HLO and this isn't supported. So it will give you many -strange errors. +To replay an optimized HLO, you must use either `--xla_disable_all_hlo_passes` +or `--run_xla_backend_only`. Otherwise, XLA will try to recompile the HLO and +this isn't supported. So it will give you many strange errors. -Full command: `bazel run -c opt --config=cuda --dynamic_mode=off -//xla/tools/multihost_hlo_runner:hlo_runner_main -- --run_xla_backend_only=true ---xla_disable_all_hlo_passes=true +Full command: `bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- +--run_xla_backend_only /tmp/dump/module_0023.pjit__wrapped_step_fn.sm_8.0_gpu_after_optimizations.txt` ## Multi-processes, single-node ### Launch container -Also install some missing librairies. (Note, that can be outdated more quickly. +Also install some missing libraries. (Note, that can be outdated more quickly. Adjust as needed.) ``` @@ -108,14 +104,13 @@ docker run -it --shm-size=1g --gpus all ghcr.io/nvidia/jax:pax-2024-06-03 apt-get update && apt-get install -y openmpi-bin openmpi-common libopenmpi-dev ``` -### Run original model and dump HLO. +### Run original model and dump HLO For this example, we will use an 8-GPU PAXML model from `test-pax.sh`. (Note this will be the same dump as the single process case. So you can do `cp -r /tmp/dump /tmp/dump_multi_process` if you already have it. 
`export -XLA_FLAGS="--xla_dump_to=/tmp/dump_multi_process --xla_dump_hlo_as_text" mpirun ---allow-run-as-root -np 8 test-pax.sh --fsdp 8 --batch-per-gpu 1 -o -/tmp/checkpoint --multiprocess` +XLA_FLAGS="--xla_dump_to=/tmp/dump_multi_process" mpirun --allow-run-as-root -np +8 test-pax.sh --fsdp 8 --batch-per-gpu 1 -o /tmp/checkpoint --multiprocess` The HLO dump will be saved to `/tmp/dump_multi_process/`. For PAX specifically, the main module will have "pjit__wrapped_step_fn" in the name. For this example @@ -129,7 +124,7 @@ Create a bash script called `run.sh`: ``` #!/bin/bash export CUDA_VISIBLE_DEVICES=${OMPI_COMM_WORLD_LOCAL_RANK} -bazel run -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main -- \ +bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- \ --task_id=${OMPI_COMM_WORLD_RANK} \ --num_nodes=${OMPI_COMM_WORLD_SIZE} \ --address=127.0.0.1:12345 \ @@ -146,10 +141,10 @@ mpirun --allow-run-as-root -np 8 run.sh ### Run on multiple nodes with SLURM When running on multiple nodes using SLURM, you can forward the SLURM env -variables to the hlo runner like so in your slurm job: +variables to the HLO runner like so in your SLURM job: ``` -bazel run -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main -- \ +bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- \ --task_id=${SLURM_PROCID} \ --num_nodes=${SLURM_NTASKS} \ --address="${SLURM_LAUNCH_NODE_IPADDR}:12345" \ From 45684e944ded140dda906e486a7afd0d9610d026 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 31 Mar 2025 04:20:07 -0700 Subject: [PATCH 0039/1324] PR #24008: [GPU] Fix TraceMe annotation. Imported from GitHub PR https://github.com/openxla/xla/pull/24008 Copybara import of the project: -- 829c15522e8ebebb11079913e51fac965f594cb2 by Ilia Sergachev : [GPU] Fix TraceMe annotation. Merging this change closes #24008 PiperOrigin-RevId: 742215534 --- third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index 1a183620fdd435..c050fad9b8d68a 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -746,7 +746,7 @@ absl::Status RunAsyncCollectivesConversionPasses(HloModule* module) { absl::StatusOr ScheduleGpuModule( HloModule* module, int64_t pointer_size, const se::DeviceDescription& gpu_device_info) { - tsl::profiler::TraceMe traceme("GpuCompiler::CompileToBackendResult"); + tsl::profiler::TraceMe traceme("ScheduleGpuModule"); // Tag the module with its 128 bit fingerprint. The fingerprint should include // instruction name with ids. 
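The patch above only renames a profiler annotation, but the mechanism is worth spelling out: `tsl::profiler::TraceMe` is an RAII object whose constructor string becomes the label of the enclosing region on the profiler timeline, so the label copy-pasted from `GpuCompiler::CompileToBackendResult` made scheduling time show up under the wrong name. A minimal sketch of the idiom, assuming only the `tsl` profiler header already used by this file; the function name is hypothetical:

```cpp
#include "tsl/profiler/lib/traceme.h"

// Hypothetical illustration, not part of the patch: everything that runs
// while `traceme` is alive is attributed to the "ScheduleGpuModule" label.
void ProfiledScheduling() {
  tsl::profiler::TraceMe traceme("ScheduleGpuModule");
  // ... scheduling work measured under this label ...
}
```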
From b2a9832ea4c43a5fc5ee367ab0b59e10c03659ff Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Mon, 31 Mar 2025 05:07:51 -0700
Subject: [PATCH 0040/1324] [xla:gpu] CommandBuffer: switch CommandBuffer::Case
 to explicit command update API

PiperOrigin-RevId: 742226471
---
 .../gpu/runtime/command_buffer_cmd.cc         |  20 +--
 .../xla/xla/stream_executor/command_buffer.h  |  13 +-
 .../stream_executor/gpu/gpu_command_buffer.cc | 120 +++++++++++-------
 .../stream_executor/gpu/gpu_command_buffer.h  |  21 ++-
 .../gpu/gpu_command_buffer_test.cc            |  14 +-
 5 files changed, 124 insertions(+), 64 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc
index 08ac30a4ed54ab..e082c296b9fce3 100644
--- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc
+++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc
@@ -790,15 +790,19 @@ absl::Status CaseCmd::Record(const Thunk::ExecuteParams& execute_params,
   VLOG(5) << "  index: " << index_ << " (" << index.opaque() << ")";
 
   if (index_is_bool_) {
-    return command_buffer->Case(
-        se::DeviceMemory<bool>(index),
-        CreateBuilders(absl::MakeSpan(branches_commands_), &execute_params,
-                       &record_params));
+    return command_buffer
+        ->Case(se::DeviceMemory<bool>(index),
+               CreateBuilders(absl::MakeSpan(branches_commands_),
+                              &execute_params, &record_params),
+               {})
+        .status();
   } else {
-    return command_buffer->Case(
-        se::DeviceMemory<int32_t>(index),
-        CreateBuilders(absl::MakeSpan(branches_commands_), &execute_params,
-                       &record_params));
+    return command_buffer
+        ->Case(se::DeviceMemory<int32_t>(index),
+               CreateBuilders(absl::MakeSpan(branches_commands_),
+                              &execute_params, &record_params),
+               {})
+        .status();
   }
 }
 
diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h
index 1e766a28e86c4a..ffb86bbeb34525 100644
--- a/third_party/xla/xla/stream_executor/command_buffer.h
+++ b/third_party/xla/xla/stream_executor/command_buffer.h
@@ -174,10 +174,19 @@ class CommandBuffer {
   // will run a conditional command buffer constructed by the last builder.
   //
   // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case
-  virtual absl::Status Case(DeviceMemory<int32_t> index,
+  virtual absl::StatusOr<const Command*> Case(
+      DeviceMemory<int32_t> index, std::vector<Builder> branches,
+      absl::Span<const Command* const> dependencies) = 0;
+
+  virtual absl::StatusOr<const Command*> Case(
+      DeviceMemory<bool> index, std::vector<Builder> branches,
+      absl::Span<const Command* const> dependencies) = 0;
+
+  // Updates a Case operation.
+  virtual absl::Status Case(const Command* command,
+                            DeviceMemory<int32_t> index,
                             std::vector<Builder> branches) = 0;
 
-  virtual absl::Status Case(DeviceMemory<bool> index,
+  virtual absl::Status Case(const Command* command, DeviceMemory<bool> index,
                             std::vector<Builder> branches) = 0;
 
   // Adds a conditional operation that will execute a command buffer constructed
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
index 34bea9d91c6ea3..c007e183c3426f 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
@@ -350,15 +350,18 @@ GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) {
   return handles;
 }
 
-absl::Status GpuCommandBuffer::Case(DeviceMemory<int32_t> index,
-                                    bool index_is_bool,
-                                    std::vector<Builder> branches) {
+absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::Case(
+    DeviceMemory<int32_t> index, bool index_is_bool,
+    std::vector<Builder> branches,
+    absl::Span<const Command* const> dependencies) {
   constexpr size_t kBranchBatchSize = 8;
 
   if (state_ == State::kCreate) {
     GpuCaseCommand command = {};
 
-    auto dependencies = GetAutoDependencies();
+    Dependencies barrier = dependencies.empty()
+                               ? GetAutoDependencies()
+                               : ToGraphNodeDependencies(dependencies);
 
     int32_t batch_offset = 0;
     while (batch_offset < branches.size()) {
@@ -384,7 +387,7 @@ absl::Status GpuCommandBuffer::Case(DeviceMemory<int32_t> index,
       TF_ASSIGN_OR_RETURN(auto set_condition_node,
                           CreateSetCaseConditionNode(
                               conditionals, index, index_is_bool, batch_offset,
-                              enable_conditional_default, dependencies));
+                              enable_conditional_default, barrier));
 
       std::vector<GraphConditionalNodeHandle> conditional_nodes;
       for (int z = 0; z < batch_size; ++z) {
@@ -412,69 +415,98 @@ absl::Status GpuCommandBuffer::Case(DeviceMemory<int32_t> index,
       batch_offset += batch_size;
     }
 
-    AppendCommand(std::move(command));
-    return absl::OkStatus();
+    return AppendCommand(std::move(command));
   }
 
   if (state_ == State::kUpdate) {
     Command& command = *commands_[update_state_.command_idx++];
-    auto* gpu_command = tsl::down_cast<GpuCaseCommand*>(&command);
+    TF_RETURN_IF_ERROR(Case(&command, index, index_is_bool, branches));
+    return &command;
+  }
 
-    // Update branch conditionals.
-    size_t batch_index = 0;
-    int32_t batch_offset = 0;
-    while (batch_offset < branches.size()) {
-      int32_t remaining_branches = branches.size() - batch_offset;
-      int32_t batch_size;
-      bool enable_conditional_default;
-      if (remaining_branches <= kBranchBatchSize) {
-        batch_size = remaining_branches;
-        enable_conditional_default = true;
-      } else {
-        batch_size = kBranchBatchSize;
-        enable_conditional_default = false;
-      }
+  return UnsupportedStateError(state_);
+}
 
-      TF_RETURN_IF_ERROR(UpdateSetCaseConditionNode(
-          gpu_command->set_condition_nodes[batch_index],
-          absl::MakeSpan(gpu_command->conditionals)
-              .subspan(batch_offset, batch_size),
-          index, index_is_bool, batch_offset, enable_conditional_default));
+absl::Status GpuCommandBuffer::Case(const Command* command,
+                                    DeviceMemory<int32_t> index,
+                                    bool index_is_bool,
+                                    std::vector<Builder> branches) {
+  constexpr size_t kBranchBatchSize = 8;
 
-      batch_offset += batch_size;
-      batch_index += 1;
+  auto* gpu_command = tsl::down_cast<const GpuCaseCommand*>(command);
+
+  // Update branch conditionals.
+  size_t batch_index = 0;
+  int32_t batch_offset = 0;
+  while (batch_offset < branches.size()) {
+    int32_t remaining_branches = branches.size() - batch_offset;
+    int32_t batch_size;
+    bool enable_conditional_default;
+    if (remaining_branches <= kBranchBatchSize) {
+      batch_size = remaining_branches;
+      enable_conditional_default = true;
+    } else {
+      batch_size = kBranchBatchSize;
+      enable_conditional_default = false;
    }

-    // Update branch command buffers.
-    for (size_t i = 0; i < gpu_command->conditional_nodes.size(); ++i) {
-      GpuCommandBuffer* case_command_buffer =
-          gpu_command->conditional_nodes[i].command_buffer.get();
-      auto scoped_update_mode = ActivateUpdateMode(case_command_buffer);
-      TF_RETURN_IF_ERROR(case_command_buffer->Update());
-      TF_RETURN_IF_ERROR(branches[i](case_command_buffer));
-      TF_RETURN_IF_ERROR(case_command_buffer->Finalize());
-    }
+    TF_RETURN_IF_ERROR(UpdateSetCaseConditionNode(
+        gpu_command->set_condition_nodes[batch_index],
+        absl::MakeSpan(gpu_command->conditionals)
+            .subspan(batch_offset, batch_size),
+        index, index_is_bool, batch_offset, enable_conditional_default));
 
-    return absl::OkStatus();
+    batch_offset += batch_size;
+    batch_index += 1;
  }

-  return UnsupportedStateError(state_);
+  // Update branch command buffers.
+  for (size_t i = 0; i < gpu_command->conditional_nodes.size(); ++i) {
+    GpuCommandBuffer* case_command_buffer =
+        gpu_command->conditional_nodes[i].command_buffer.get();
+    auto scoped_update_mode = ActivateUpdateMode(case_command_buffer);
+    TF_RETURN_IF_ERROR(case_command_buffer->Update());
+    TF_RETURN_IF_ERROR(branches[i](case_command_buffer));
+    TF_RETURN_IF_ERROR(case_command_buffer->Finalize());
+  }
+
+  return absl::OkStatus();
 }
 
-absl::Status GpuCommandBuffer::Case(DeviceMemory<bool> index,
-                                    std::vector<Builder> branches) {
+absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::Case(
+    DeviceMemory<int32_t> index, std::vector<Builder> branches,
+    absl::Span<const Command* const> dependencies) {
   return Case(
       DeviceMemory<int32_t>::MakeFromByteSize(index.opaque(), index.size()),
-      /*index_is_bool=*/true, branches);
+      /*index_is_bool=*/false, branches, dependencies);
+}
+
+absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::Case(
+    DeviceMemory<bool> index, std::vector<Builder> branches,
+    absl::Span<const Command* const> dependencies) {
+  return Case(
+      DeviceMemory<int32_t>::MakeFromByteSize(index.opaque(), index.size()),
+      /*index_is_bool=*/true, branches, dependencies);
 }
 
-absl::Status GpuCommandBuffer::Case(DeviceMemory<int32_t> index,
+absl::Status GpuCommandBuffer::Case(const Command* command,
+                                    DeviceMemory<int32_t> index,
                                     std::vector<Builder> branches) {
   return Case(
+      command,
       DeviceMemory<int32_t>::MakeFromByteSize(index.opaque(), index.size()),
       /*index_is_bool=*/false, branches);
 }
 
+absl::Status GpuCommandBuffer::Case(const Command* command,
+                                    DeviceMemory<bool> index,
+                                    std::vector<Builder> branches) {
+  return Case(
+      command,
+      DeviceMemory<int32_t>::MakeFromByteSize(index.opaque(), index.size()),
+      /*index_is_bool=*/true, branches);
+}
+
 absl::Status GpuCommandBuffer::For(int32_t num_iteration,
                                    DeviceMemory<int32_t> loop_counter,
                                    Builder body_builder) {
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
index b2acd0ea12f72b..0a46b2baab2ed1 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
@@ -166,10 +166,18 @@ class GpuCommandBuffer : public CommandBuffer {
                          const BitPattern& bit_pattern,
                          size_t num_elements) override;
 
-  absl::Status Case(DeviceMemory<int32_t> index,
+  absl::StatusOr<const Command*> Case(
+      DeviceMemory<int32_t> index, std::vector<Builder> branches,
+      absl::Span<const Command* const> dependencies) override;
+
+  absl::StatusOr<const Command*> Case(
+      DeviceMemory<bool> index, std::vector<Builder> branches,
+      absl::Span<const Command* const> dependencies) override;
+
+  absl::Status Case(const Command* command, DeviceMemory<int32_t> index,
                     std::vector<Builder> branches) override;
 
-  absl::Status Case(DeviceMemory<bool> index,
+  absl::Status Case(const Command* command, DeviceMemory<bool> index,
                     std::vector<Builder> branches) override;
 
   absl::Status For(int32_t num_iteration, DeviceMemory<int32_t> loop_counter,
@@ -283,8 +291,13 @@ class GpuCommandBuffer : public CommandBuffer {
   virtual absl::Status CheckCanBeUpdated() = 0;
 
  private:
-  absl::Status Case(DeviceMemory<int32_t> index, bool index_is_bool,
-                    std::vector<Builder> branches);
+  absl::StatusOr<const Command*> Case(
+      DeviceMemory<int32_t> index, bool index_is_bool,
+      std::vector<Builder> branches,
+      absl::Span<const Command* const> dependencies);
+
+  absl::Status Case(const Command* command, DeviceMemory<int32_t> index,
+                    bool index_is_bool, std::vector<Builder> branches);
 
   // Constructs a new command for the given graph node handle and appends it to
   // the command buffer.
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
index 99351cbfc9c00e..9dd6d363fce79b 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
@@ -397,8 +397,9 @@ TEST(GpuCommandBufferTest, ConditionalCaseEmptyGraph) {
   };
 
   // Create a command buffer with a single conditional operation.
-  auto cmd_buffer = executor->CreateCommandBuffer(primary).value();
-  TF_ASSERT_OK(cmd_buffer->Case(index, {branch0, branch1}));
+  TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer,
+                          executor->CreateCommandBuffer(primary));
+  TF_ASSERT_OK(cmd_buffer->Case(index, {branch0, branch1}, {}));
 
   TF_ASSERT_OK(cmd_buffer->Finalize());
   TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
@@ -494,8 +495,9 @@ TEST_P(GpuCommandBufferCaseTest, ConditionalMultiCase) {
   }
 
   // Create a command buffer with a single conditional operation.
-  auto cmd_buffer = executor->CreateCommandBuffer(primary).value();
-  TF_ASSERT_OK(cmd_buffer->Case(index, branches));
+  TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer,
+                          executor->CreateCommandBuffer(primary));
+  TF_ASSERT_OK(cmd_buffer->Case(index, branches, {}));
 
   TF_ASSERT_OK(cmd_buffer->Finalize());
 
   // We test the out of bounds cases as well ( i < 0, i >= kNumCases).
@@ -584,7 +586,7 @@ TEST(GpuCommandBufferTest, ConditionalCase) {
 
   // Create a command buffer with a single conditional operation.
   auto cmd_buffer = executor->CreateCommandBuffer(primary).value();
-  TF_ASSERT_OK(cmd_buffer->Case(index, {branch0, branch1}));
+  TF_ASSERT_OK(cmd_buffer->Case(index, {branch0, branch1}, {}));
 
   TF_ASSERT_OK(cmd_buffer->Finalize());
   TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
@@ -800,7 +802,7 @@ TEST(GpuCommandBufferTest, DISABLED_WhileNestedConditional) {
   auto nested_cmd = executor->CreateCommandBuffer(nested).value();
   // TODO(b/339653343): Adding this Case condition causes AddNestedCommandBuffer
   // to fail.
- TF_ASSERT_OK(nested_cmd->Case(pred_then, {then_builder, then_builder})); + TF_ASSERT_OK(nested_cmd->Case(pred_then, {then_builder, then_builder}, {})); // Loop cond: loop_counter++ < num_iters; CommandBuffer::Builder cond_builder = [&](CommandBuffer* cond_cmd) { From e273d6643d6b803868b9d64dcf37489c07af9b91 Mon Sep 17 00:00:00 2001 From: Theotime Combes Date: Mon, 31 Mar 2025 05:17:50 -0700 Subject: [PATCH 0041/1324] [XLA:GPU] Add more op codes to triton support test 1/n PiperOrigin-RevId: 742228670 --- .../backends/gpu/codegen/triton/support.cc | 9 - .../gpu/codegen/triton/support_test.cc | 268 ++++++++++++++++-- 2 files changed, 246 insertions(+), 31 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index 67e75c1ae5616b..705d5c859cba58 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -481,25 +481,18 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { switch (opcode) { case HloOpcode::kAddDependency: case HloOpcode::kAfterAll: - case HloOpcode::kBatchNormGrad: - case HloOpcode::kBatchNormInference: - case HloOpcode::kBatchNormTraining: case HloOpcode::kBitcastConvert: - case HloOpcode::kCall: case HloOpcode::kCholesky: - case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCopyDone: case HloOpcode::kCopyStart: case HloOpcode::kCustomCall: - case HloOpcode::kDomain: case HloOpcode::kDynamicReshape: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: - case HloOpcode::kGetDimensionSize: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: case HloOpcode::kMap: @@ -510,7 +503,6 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReduceWindow: - case HloOpcode::kReverse: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kSend: @@ -521,7 +513,6 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { case HloOpcode::kTopK: case HloOpcode::kTriangularSolve: case HloOpcode::kTuple: - case HloOpcode::kWhile: return true; default: return false; diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index 5e28de6292fc30..48f66a2280aeb8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -138,6 +138,10 @@ bool DoesOpSupportType(HloOpcode opcode, PrimitiveType type) { return type == F32 || type == F64; case HloOpcode::kDot: return type != PRED; + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormGrad: + return pu::IsFloatingPointType(type); default: // Returning true by default ensures that newly added ops are not // skipped. 
@@ -1435,11 +1439,9 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc); } -constexpr std::array kTestedOpsRngBitGenerator = {HloOpcode::kRngBitGenerator}; - INSTANTIATE_TEST_SUITE_P( RngBitGeneratorTestSuite, RngBitGeneratorTest, - AllTestCombinationsForOpcodes(kTestedOpsRngBitGenerator), + AllTestCombinationsForOpcodes({HloOpcode::kRngBitGenerator}), TritonSupportTestTypeAndOpcodeAndDeviceToString); using RngGetAndUpdateStateTest = TritonSupportTestWithDeviceParam; @@ -1457,8 +1459,6 @@ TEST_P(RngGetAndUpdateStateTest, RngGetAndUpdateState) { HloOpcode::kRngGetAndUpdateState)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } -constexpr std::array kTestedOpsRngGetAndUpdateState = { - HloOpcode::kRngGetAndUpdateState}; INSTANTIATE_TEST_SUITE_P(RngGetAndUpdateStateTestSuite, RngGetAndUpdateStateTest, @@ -1491,10 +1491,233 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc); } -constexpr std::array kTestedOpsComplex = {HloOpcode::kComplex}; - INSTANTIATE_TEST_SUITE_P(ComplexTestSuite, ComplexTest, - AllTestCombinationsForOpcodes(kTestedOpsComplex), + AllTestCombinationsForOpcodes({HloOpcode::kComplex}), + TritonSupportTestTypeAndOpcodeAndDeviceToString); + +using ConditionalTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam; + +TEST_P(ConditionalTest, Conditional) { + auto [data_type, opcode, cc] = GetParam(); + const std::string kHloTestTemplate = R"( +true_branch { + p_true = $0[10] parameter(0) + ROOT add = $0[10] add(p_true, p_true) +} +false_branch { + p_false = $0[10] parameter(0) + ROOT mul = $0[10] multiply(p_false, p_false) +} +ENTRY triton_computation { + cond = pred[] parameter(0) + operand = $0[10] parameter(1) + ROOT conditional_op = $0[10] conditional(cond, operand, operand), + true_computation=true_branch, + false_computation=false_branch +})"; + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); +} + +INSTANTIATE_TEST_SUITE_P( + ConditionalTestSuite, ConditionalTest, + AllTestCombinationsForOpcodes({HloOpcode::kConditional}), + TritonSupportTestTypeAndOpcodeAndDeviceToString); + +using WhileTest = TritonSupportTestWithDeviceParam; +// TODO: b/363981282 - Add tests for more data types. 
+TEST_P(WhileTest, While) {
+  auto cc = GetParam();
+  const std::string kHloTestTemplate = R"(
+body {
+  constant = s32[] constant(1)
+  prev = s32[] parameter(0)
+  ROOT add = s32[] add(constant, prev)
+}
+condition {
+  constant = s32[] constant(5)
+  prev = s32[] parameter(0)
+  ROOT greater-than = pred[] compare(constant, prev), direction=GT
+}
+ENTRY triton_computation {
+  constant = s32[] constant(0)
+  ROOT while = s32[] while(constant), condition=condition, body=body
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti,
+      ParseTemplateAndGetInstruction(kHloTestTemplate,
+                                     F16,  // data_type doesn't matter here
+                                     HloOpcode::kWhile));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{}, cc);
+}
+
+INSTANTIATE_TEST_SUITE_P(WhileTestSuite, WhileTest,
+                         ::testing::ValuesIn(AllDevicesToTest()),
+                         TritonSupportTestDeviceToString);
+
+using CallTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
+
+TEST_P(CallTest, Call) {
+  auto [data_type, opcode, cc] = GetParam();
+  const std::string kHloTestTemplate = R"(
+called_computation {
+  p = $0[10] parameter(0)
+  ROOT add = $0[10] add(p, p)
+}
+
+ENTRY triton_computation {
+  operand = $0[10] parameter(0)
+  ROOT call_op = $0[10] call(operand), to_apply=called_computation
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti,
+      ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc);
+}
+
+INSTANTIATE_TEST_SUITE_P(CallTestSuite, CallTest,
+                         AllTestCombinationsForOpcodes({HloOpcode::kCall}),
+                         TritonSupportTestTypeAndOpcodeAndDeviceToString);
+
+using BatchNormInferenceTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
+
+TEST_P(BatchNormInferenceTest, BatchNormInference) {
+  auto [data_type, opcode, cc] = GetParam();
+  const std::string kHloTestTemplate = R"(
+ENTRY triton_computation {
+  operand = $0[4,8,16,32] parameter(0)
+  scale = $0[32] parameter(1)
+  offset = $0[32] parameter(2)
+  mean = $0[32] parameter(3)
+  variance = $0[32] parameter(4)
+  ROOT bn_inf = $0[4,8,16,32] batch-norm-inference(operand, scale, offset, mean, variance),
+      epsilon=0.001, feature_index=3
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti,
+      ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1, 4, 8}, cc);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    BatchNormInferenceSuite, BatchNormInferenceTest,
+    AllTestCombinationsForOpcodes({HloOpcode::kBatchNormInference}),
+    TritonSupportTestTypeAndOpcodeAndDeviceToString);
+
+using BatchNormTrainingTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
+
+// TODO: b/363981282 - Get rid of get-tuple-element by adding multiple output
+// tiles support to RunSupportTest.
+TEST_P(BatchNormTrainingTest, BatchNormTraining) {
+  auto [data_type, opcode, cc] = GetParam();
+  const std::string kHloTestTemplate = R"(
+ENTRY triton_computation {
+  operand = $0[4,8,16,32] parameter(0)
+  scale = $0[32] parameter(1)
+  offset = $0[32] parameter(2)
+  bn_train = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-training(operand, scale, offset),
+      epsilon=0.001, feature_index=3
+  ROOT gte = $0[4,8,16,32] get-tuple-element(bn_train), index=0
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti,
+      ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1, 4, 8}, cc);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    BatchNormTrainingSuite, BatchNormTrainingTest,
+    AllTestCombinationsForOpcodes({HloOpcode::kBatchNormTraining}),
+    TritonSupportTestTypeAndOpcodeAndDeviceToString);
+
+using BatchNormGradTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
+
+// TODO: b/363981282 - Get rid of get-tuple-element by adding multiple output
+// tiles support to RunSupportTest.
+TEST_P(BatchNormGradTest, BatchNormGrad) {
+  auto [data_type, opcode, cc] = GetParam();
+  const std::string kHloTestTemplate = R"(
+ENTRY triton_computation {
+  operand = $0[4,8,16,32] parameter(0)
+  scale = $0[32] parameter(1)
+  mean = $0[32] parameter(2)
+  variance = $0[32] parameter(3)
+  grad_output = $0[4,8,16,32] parameter(4)
+  bn_grad = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-grad(operand, scale, mean, variance, grad_output),
+      epsilon=0.001, feature_index=3
+  ROOT gte = $0[4,8,16,32] get-tuple-element(bn_grad), index=0
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti,
+      ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1, 4, 8}, cc);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    BatchNormGradSuite, BatchNormGradTest,
+    AllTestCombinationsForOpcodes({HloOpcode::kBatchNormGrad}),
+    TritonSupportTestTypeAndOpcodeAndDeviceToString);
+
+using DomainTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
+
+TEST_P(DomainTest, Domain) {
+  auto [data_type, opcode, cc] = GetParam();
+  const std::string kHloTestTemplate = R"(
+ENTRY triton_computation {
+  operand = $0[] parameter(0)
+  ROOT domain_op = $0[] domain(operand), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti,
+      ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{}, cc);
+}
+
+INSTANTIATE_TEST_SUITE_P(DomainSuite, DomainTest,
+                         AllTestCombinationsForOpcodes({HloOpcode::kDomain}),
+                         TritonSupportTestTypeAndOpcodeAndDeviceToString);
+
+using GetDimensionSizeTest = TritonSupportTestWithDeviceParam;
+
+TEST_P(GetDimensionSizeTest, GetDimensionSize) {
+  const auto cc = GetParam();
+
+  const std::string kHloTestTemplate = R"(
+ENTRY triton_computation {
+  operand = s32[16, 32] parameter(0)
+  ROOT get_dim_size = s32[] get-dimension-size(operand), dimensions={1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti,
+      ParseTemplateAndGetInstruction(kHloTestTemplate,
+                                     F16,  // data_type doesn't matter here
+                                     HloOpcode::kGetDimensionSize));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{}, cc);
+}
+
+INSTANTIATE_TEST_SUITE_P(GetDimensionSizeSuite, GetDimensionSizeTest,
+                         ::testing::ValuesIn(AllDevicesToTest()),
+                         TritonSupportTestDeviceToString);
+
+using ReverseTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
+
+TEST_P(ReverseTest, Reverse) {
+  auto
[data_type, opcode, cc] = GetParam(); + const std::string kHloTestTemplate = R"( +ENTRY triton_computation { + operand = $0[16,32] parameter(0) + ROOT reverse_op = $0[16,32] reverse(operand), dimensions={0, 1} +})"; + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{4, 8}, cc); +} + +INSTANTIATE_TEST_SUITE_P(ReverseSuite, ReverseTest, + AllTestCombinationsForOpcodes({HloOpcode::kReverse}), TritonSupportTestTypeAndOpcodeAndDeviceToString); using DotTest = TritonSupportTest; @@ -2115,25 +2338,18 @@ constexpr std::array kUnsupportedOps = { // go/keep-sorted start HloOpcode::kAddDependency, HloOpcode::kAfterAll, - HloOpcode::kBatchNormGrad, - HloOpcode::kBatchNormInference, - HloOpcode::kBatchNormTraining, HloOpcode::kBitcastConvert, - HloOpcode::kCall, HloOpcode::kCholesky, - HloOpcode::kConditional, HloOpcode::kConvolution, HloOpcode::kCopyDone, HloOpcode::kCopyStart, HloOpcode::kCustomCall, - HloOpcode::kDomain, HloOpcode::kDynamicReshape, HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, HloOpcode::kFft, HloOpcode::kFusion, HloOpcode::kGather, - HloOpcode::kGetDimensionSize, HloOpcode::kGetTupleElement, HloOpcode::kInfeed, HloOpcode::kMap, @@ -2144,7 +2360,6 @@ constexpr std::array kUnsupportedOps = { HloOpcode::kRecv, HloOpcode::kRecvDone, HloOpcode::kReduceWindow, - HloOpcode::kReverse, HloOpcode::kScatter, HloOpcode::kSelectAndScatter, HloOpcode::kSend, @@ -2155,7 +2370,6 @@ constexpr std::array kUnsupportedOps = { HloOpcode::kTopK, HloOpcode::kTriangularSolve, HloOpcode::kTuple, - HloOpcode::kWhile // go/keep-sorted end // clang-format on }; @@ -2180,13 +2394,23 @@ absl::flat_hash_set AllTestedOpcodes() { ret.insert(kTestedOpsConstant.begin(), kTestedOpsConstant.end()); ret.insert(kTestedOpsIota.begin(), kTestedOpsIota.end()); ret.insert(kTestedOpsRng.begin(), kTestedOpsRng.end()); - ret.insert(kTestedOpsRngBitGenerator.begin(), - kTestedOpsRngBitGenerator.end()); - ret.insert(kTestedOpsRngGetAndUpdateState.begin(), - kTestedOpsRngGetAndUpdateState.end()); - ret.insert(kTestedOpsComplex.begin(), kTestedOpsComplex.end()); + + ret.emplace(HloOpcode::kBatchNormGrad); + ret.emplace(HloOpcode::kBatchNormInference); + ret.emplace(HloOpcode::kBatchNormTraining); + ret.emplace(HloOpcode::kCall); + ret.emplace(HloOpcode::kComplex); + ret.emplace(HloOpcode::kConditional); + ret.emplace(HloOpcode::kDomain); ret.emplace(HloOpcode::kDot); + ret.emplace(HloOpcode::kGetDimensionSize); + ret.emplace(HloOpcode::kReverse); + ret.emplace(HloOpcode::kRngBitGenerator); + ret.emplace(HloOpcode::kRngGetAndUpdateState); + ret.emplace(HloOpcode::kWhile); + ret.insert(kUnsupportedOps.begin(), kUnsupportedOps.end()); + return ret; } From c2be855d368f13d071b1f3a04bc9f2f73c0be929 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Mon, 31 Mar 2025 05:30:41 -0700 Subject: [PATCH 0042/1324] [XLA:GPU] Collect matmul perf table data for further analysis. 
PiperOrigin-RevId: 742231191 --- third_party/xla/xla/debug_options_flags.cc | 9 + third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/backend_configs.proto | 2 + .../xla/xla/service/gpu/gpu_compiler.cc | 11 + third_party/xla/xla/service/gpu/model/BUILD | 49 ++++ .../model/matmul_ptable_stats_collection.cc | 114 ++++++++++ .../model/matmul_ptable_stats_collection.h | 54 +++++ .../matmul_ptable_stats_collection_test.cc | 210 ++++++++++++++++++ third_party/xla/xla/xla.proto | 5 +- 9 files changed, 454 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection.cc create mode 100644 third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection.h create mode 100644 third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection_test.cc diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 9df7ead471c190..ce53378a85c5c7 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -322,6 +322,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10); opts.set_xla_gpu_executable_terminate_timeout_seconds(30); opts.set_xla_gpu_experimental_collective_perf_table_path(""); + opts.set_xla_gpu_experimental_matmul_perf_table_path(""); opts.set_xla_gpu_experimental_disable_binary_libraries(false); // --xla_ignore_channel_id should be kept false by default while channel ids // are load-bearing. @@ -2339,6 +2340,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_unsupported_crash_on_hlo_pass_noop_change(), "Crash if a pass reports that it did change the HLO but in fact it " "did not.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_experimental_matmul_perf_table_path", + string_setter_for( + &DebugOptions::set_xla_gpu_experimental_matmul_perf_table_path), + debug_options->xla_gpu_experimental_matmul_perf_table_path(), + "If non empty will interpret this variable as a path for performance " + "tables for matmuls. Expects `xla.gpu.DeviceHloInstructionProfiles` " + "proto.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 0f210aef19db1f..2d01c910f845e1 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1500,6 +1500,7 @@ cc_library( "//xla/service/gpu/model:collective_ptable_stats_collection", "//xla/service/gpu/model:gpu_cost_model_stats_collection", "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/model:matmul_ptable_stats_collection", "//xla/service/gpu/model:sol_gpu_cost_model_stats_collection", "//xla/service/gpu/transforms/collectives:convert_async_collectives_to_sync", "//xla/service/gpu/transforms/collectives:gpu_all_gather_combiner", diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 84c4da396a2444..515f672acbf3a9 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -105,6 +105,8 @@ message GemmBackendConfig { optional bool grad_x = 16; optional bool grad_y = 17; bool damax_output = 18; + + repeated ReificationCost reification_cost = 19; } // Backend config for bitcast operation generated from MLIR MHLO dialect. 
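Before the compiler changes below, a note on how the new proto field above is meant to be consumed: `GemmBackendConfig` now carries a repeated `reification_cost`, which the `MatmulPerfTableStatsCollection` pass added later in this patch populates with interpolated runtimes. A hedged sketch of reading it back; the helper name is hypothetical, and only the fields shown above plus the standard `HloInstruction::backend_config<T>()` accessor are assumed:

```cpp
// Hypothetical consumer, not part of the patch.
absl::StatusOr<double> FirstReifiedExecTimeUs(const HloInstruction& instr) {
  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
                      instr.backend_config<GpuBackendConfig>());
  const GemmBackendConfig& gemm = gpu_config.gemm_backend_config();
  if (gemm.reification_cost().empty()) {
    return absl::NotFoundError("no reification cost recorded");
  }
  // exec_time_us is the field MatmulPerfTableStatsCollection fills in.
  return gemm.reification_cost(0).exec_time_us();
}
```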
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index bd8a44bc4f32ac..e2e2a1f862cce3 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -174,6 +174,7 @@ limitations under the License. #include "xla/service/gpu/model/collective_ptable_stats_collection.h" #include "xla/service/gpu/model/gpu_cost_model_stats_collection.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/matmul_ptable_stats_collection.h" #include "xla/service/gpu/model/sol_gpu_cost_model_stats_collection.h" #include "xla/service/gpu/pre_scheduling_copy_insertion_pipeline.h" #include "xla/service/gpu/reduction_utils.h" @@ -2611,6 +2612,16 @@ absl::Status GpuCompiler::RunPreSchedulingPasses( pipeline.AddPass( collective_perf_table_path, gpu_device_info); } + + // Perf tables model analysis for matmuls. + if (std::string matmul_perf_table_path = + module->config() + .debug_options() + .xla_gpu_experimental_matmul_perf_table_path(); + !matmul_perf_table_path.empty()) { + pipeline.AddPass(matmul_perf_table_path, + gpu_device_info); + } } return pipeline.Run(module).status(); } diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index ccf54158aa4f44..6d4514555f800b 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -1202,3 +1202,52 @@ xla_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "matmul_ptable_stats_collection", + srcs = ["matmul_ptable_stats_collection.cc"], + hdrs = ["matmul_ptable_stats_collection.h"], + deps = [ + ":hlo_op_profile_proto_cc", + ":hlo_op_profiles", + ":matmul_interpolator", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service/gpu:backend_configs_cc", + "//xla/stream_executor:device_description", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + ], +) + +xla_cc_test( + name = "matmul_ptable_stats_collection_test", + srcs = ["matmul_ptable_stats_collection_test.cc"], + deps = [ + ":hlo_op_profile_proto_cc", + ":matmul_ptable_stats_collection", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:filecheck", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", + "//xla/tsl/platform:env", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:path", + ], +) diff --git a/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection.cc b/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection.cc new file mode 100644 index 00000000000000..6ddfc0cf4a999f --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection.cc @@ -0,0 +1,114 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/model/matmul_ptable_stats_collection.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/model/hlo_op_profile.pb.h"
+#include "xla/service/gpu/model/hlo_op_profiles.h"
+#include "xla/service/gpu/model/matmul_interpolator.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tsl/platform/env.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/status.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/util.h"
+
+namespace xla::gpu {
+
+namespace {
+
+absl::StatusOr<HloInstructionProfileList> CollectProfiles(
+    const std::string& perf_table_path,
+    const se::DeviceDescription& device_info) {
+  DeviceHloInstructionProfiles profile;
+
+  TF_RETURN_IF_ERROR(tsl::Env::Default()->FileExists(perf_table_path));
+  TF_RETURN_IF_ERROR(tsl::ReadTextOrBinaryProto(tsl::Env::Default(),
+                                                perf_table_path, &profile));
+  std::string key = HloOpProfiles::GetProfileName(device_info);
+
+  if (!profile.entries().contains(key)) {
+    return absl::NotFoundError(absl::StrCat("Cannot find key: ", key));
+  }
+  return profile.entries().at(key);
+}
+
+ReificationCost* GetReificationCost(HloOpcode opcode,
+                                    GpuBackendConfig& config) {
+  if (opcode == HloOpcode::kCustomCall) {
+    return config.mutable_gemm_backend_config()->add_reification_cost();
+  }
+  if (opcode == HloOpcode::kFusion) {
+    return config.mutable_fusion_backend_config()->mutable_reification_cost();
+  }
+  return nullptr;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> MatmulPerfTableStatsCollection::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  TF_ASSIGN_OR_RETURN(HloInstructionProfileList profiles,
+                      CollectProfiles(perf_table_path_, device_info_));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<MatmulInterpolator> interpolator,
+                      MatmulInterpolator::Create(profiles, device_info_));
+
+  hlo_query::ForEachInstructionWithPred(
+      *module, HloPredicateIsOp<HloOpcode::kCustomCall, HloOpcode::kFusion>,
+      [&](HloInstruction* instr) {
+        // Generate exec time for a matmul.
+        auto estimation = interpolator->EstimatedRuntime(*instr);
+        if (!estimation.has_value()) {
+          VLOG(1) << "No estimation for: " << instr->ToString();
+          return;
+        }
+        absl::Duration exec_time = *estimation;
+
+        // Set it in the `GpuBackendConfig`.
+        auto gpu_config = instr->backend_config<GpuBackendConfig>();
+        TF_CHECK_OK(gpu_config.status())
+            << "Cannot parse backend config: " << instr->ToString();
+        ReificationCost* reification_cost =
+            GetReificationCost(instr->opcode(), *gpu_config);
+        if (reification_cost == nullptr) {
+          VLOG(1) << "No reification cost for: " << instr->ToString();
+          return;
+        }
+        reification_cost->set_exec_time_us(
+            absl::ToDoubleMicroseconds(exec_time));
+        *reification_cost->mutable_name() = name();
+        TF_CHECK_OK(instr->set_backend_config(*gpu_config));
+      });
+
+  return false;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection.h b/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection.h
new file mode 100644
index 00000000000000..7d7d254c5ef53a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection.h
@@ -0,0 +1,54 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_MODEL_MATMUL_PTABLE_STATS_COLLECTION_H_
+#define XLA_SERVICE_GPU_MODEL_MATMUL_PTABLE_STATS_COLLECTION_H_
+
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/pass/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla::gpu {
+
+class MatmulPerfTableStatsCollection : public HloModulePass {
+ public:
+  explicit MatmulPerfTableStatsCollection(
+      absl::string_view perf_table_path,
+      const se::DeviceDescription& device_info)
+      : perf_table_path_(perf_table_path), device_info_(device_info) {}
+
+  absl::string_view name() const override {
+    return "matmul-perf-table-stats-collection";
+  }
+
+  using HloPassInterface::Run;
+
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const std::string perf_table_path_;
+  const se::DeviceDescription& device_info_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_MODEL_MATMUL_PTABLE_STATS_COLLECTION_H_
diff --git a/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection_test.cc b/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection_test.cc
new file mode 100644
index 00000000000000..b37542391b03ba
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection_test.cc
@@ -0,0 +1,210 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/matmul_ptable_stats_collection.h" + +#include +#include + +#include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/hlo_op_profile.pb.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "tsl/platform/path.h" + +namespace xla::gpu { +namespace { + +constexpr const char* kFile = "profiles.pbtxt"; + +using ::testing::Test; + +DeviceHloInstructionProfiles TestProfiles( + const se::DeviceDescription& device_info) { + constexpr char perf_table[] = R"pb( + entries { + key: "sm_89" + value { + entries { + instruction { + opcode: "dot" + shape { + element_type: BF16 + dimensions: 1 + dimensions: 1024 + dimensions: 1024 + } + dot_dimension_numbers { + lhs_contracting_dimensions: 2 + rhs_contracting_dimensions: 1 + lhs_batch_dimensions: 0 + rhs_batch_dimensions: 0 + } + id: 2 + operand_ids: 0 + operand_ids: 1 + } + operands { + name: "lhs" + opcode: "parameter" + shape { + element_type: BF16 + dimensions: 1 + dimensions: 1024 + dimensions: 1024 + } + } + operands { + name: "rhs" + opcode: "parameter" + shape { + element_type: BF16 + dimensions: 1 + dimensions: 1024 + dimensions: 1024 + } + parameter_number: 1 + id: 1 + } + clock_cycles: 1410000000 + } + } + } + )pb"; + DeviceHloInstructionProfiles profiles; + CHECK(tsl::protobuf::TextFormat::ParseFromString(perf_table, &profiles)); + return profiles; +} + +class MatmulPerfTableStatsCollectionTest : public Test { + public: + explicit MatmulPerfTableStatsCollectionTest() + : device_info_(TestGpuDeviceInfo::RTXA6000DeviceInfo()), + profiles_path_(tsl::io::JoinPath(tsl::testing::TmpDir(), kFile)) {} + + void SetUp() override { + CHECK_OK(tsl::WriteTextProto(tsl::Env::Default(), profiles_path_, + TestProfiles(device_info_))); + } + + protected: + const se::DeviceDescription device_info_; + const std::string profiles_path_; +}; + +TEST_F(MatmulPerfTableStatsCollectionTest, + CollectsMatmulPerfTableDataForGemmCustomCalls) { + absl::string_view hlo = R"( + HloModule m + + ENTRY e { + p0 = bf16[1024,1024] parameter(0) + p1 = bf16[1024,1024] parameter(1) + ROOT dot = (bf16[1024,1024], s8[2097152]{0}) custom-call(p0,p1), + custom_call_target="__cublas$gemm", + backend_config={ + "operation_queue_id":"0", + "wait_on_operation_queues":[], + "gemm_backend_config":{ + "alpha_real":1, + "beta":1, + "dot_dimension_numbers": { + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["1"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + } + } + } + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + + TF_ASSERT_OK_AND_ASSIGN( + bool changed, MatmulPerfTableStatsCollection(profiles_path_, device_info_) + .Run(module.get())); + + VLOG(1) << module->ToString(); + + EXPECT_FALSE(changed); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( + CHECK: dot + CHECK-SAME: gemm_backend_config + CHECK-SAME: "exec_time_us":1000000 + )")); +} + +TEST_F(MatmulPerfTableStatsCollectionTest, + CollectsMatmulPerfTableDataForTritonFusionConfig) { + absl::string_view hlo = R"( + 
HloModule m + + comp { + p0 = bf16[1024,1024] parameter(0) + p1 = bf16[1024,1024] parameter(1) + ROOT _ = bf16[1024,1024] dot(p0,p1), + lhs_contracting_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY e { + p0 = bf16[1024,1024] parameter(0) + p1 = bf16[1024,1024] parameter(1) + ROOT triton_gemm = bf16[1024,1024] fusion(p0,p1), + kind=kCustom, + calls=comp, + backend_config={ + "operation_queue_id":"0", + "wait_on_operation_queues":[], + "fusion_backend_config": { + "kind":"__triton_gemm", + "triton_gemm_config":{ + "block_m":"128", + "block_n":"128", + "block_k":"64", + "split_k":"1", + "num_stages":"1", + "num_warps":"8", + "num_ctas":"1" + } + }, + } + } +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, MatmulPerfTableStatsCollection(profiles_path_, device_info_) + .Run(module.get())); + + VLOG(1) << module->ToString(); + + EXPECT_FALSE(changed); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( + CHECK: triton_gemm + CHECK-SAME: fusion_backend_config + CHECK-SAME: "exec_time_us":1000000 + )")); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index eb20e72bd5d53e..e1e951b3f8bded 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -1197,10 +1197,13 @@ message DebugOptions { // use command buffer mode to run. bool xla_test_add_command_buffer_mode = 373; + // Path to experimental collective perf tables. + string xla_gpu_experimental_matmul_perf_table_path = 383; + // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 383 + // Next id: 384 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 247c1a13005a8baa7cd70c8a5532fd0585ca6df6 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 31 Mar 2025 06:11:05 -0700 Subject: [PATCH 0043/1324] [XLA:GPU] Add a pattern to fold vector.insert(vector.extract). When there is nothing fused in the transpose, we read the input and then write it to shmem. We can fold away the packing/unpacking of the vectors. 
```
%0 = vector.transfer_read
scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %cst) -> (vector<2xf32>) {
  %3 = vector.extract %0[%arg2] : f32 from vector<2xf32>
  %4 = vector.insert %3, %arg3 [%arg2] : f32 into vector<2xf32>
  scf.yield %4 : vector<2xf32>
}
vector.transfer_write
```

PiperOrigin-RevId: 742240164
---
 .../tests/vectorize_loads_stores.mlir         | 103 ++++++++++++++++
 .../transforms/vectorize_loads_stores.cc      | 116 +++++++++++++++---
 2 files changed, 202 insertions(+), 17 deletions(-)

diff --git a/third_party/xla/xla/codegen/emitters/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/codegen/emitters/transforms/tests/vectorize_loads_stores.mlir
index 1335419c32e879..35cec5a54149bb 100644
--- a/third_party/xla/xla/codegen/emitters/transforms/tests/vectorize_loads_stores.mlir
+++ b/third_party/xla/xla/codegen/emitters/transforms/tests/vectorize_loads_stores.mlir
@@ -531,3 +531,106 @@ func.func @simple_atomic_rmw(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 // CHECK-HOPPER: xla.atomic_rmw %[[ARG0]]
 // CHECK-HOPPER-NEXT: ^bb0(%[[CURRENT:.*]]: vector<2xf32>):
 // CHECK-HOPPER-NEXT: arith.addf %[[CURRENT]], %[[LOOP]]
+
+// -----
+
+func.func @fold_insert_extract(%in: tensor<64xf32>, %out: tensor<64xf32>)
+    -> tensor<64xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 2 : index
+  %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%out_ = %out) -> tensor<64xf32> {
+    %extracted = tensor.extract %in[%j] : tensor<64xf32>
+    %inserted = tensor.insert %extracted into %out_[%j] : tensor<64xf32>
+    scf.yield %inserted : tensor<64xf32>
+  }
+  return %loop : tensor<64xf32>
+}
+// CHECK-LABEL: @fold_insert_extract
+// CHECK-NOT: scf.for
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_write
+
+
+// -----
+
+func.func @fold_insert_extract_two_results(
+    %arg0: tensor<8xf64>, %arg1: tensor<8xf64>,
+    %arg2: tensor<8xf64>, %arg3: tensor<8xf64>, %arg4: tensor<8xf64>)
+    -> (tensor<8xf64>, tensor<8xf64>) {
+  %cst = arith.constant 0.00e+00 : f64
+  %cst_0 = arith.constant dense<0.00e+00> : vector<4xf64>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %18 = vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]}
+    : tensor<8xf64>, vector<4xf64>
+  %16 = vector.transfer_read %arg2[%c0], %cst {in_bounds = [true]}
+    : tensor<8xf64>, vector<4xf64>
+  %20 = vector.transfer_read %arg0[%c0], %cst {in_bounds = [true]}
+    : tensor<8xf64>, vector<4xf64>
+  %21:4 = scf.for %i = %c0 to %c4 step %c1
+      iter_args(%arg6 = %arg3, %arg7 = %arg4, %arg8 = %cst_0, %arg9 = %cst_0)
+      -> (tensor<8xf64>, tensor<8xf64>, vector<4xf64>, vector<4xf64>) {
+    %24 = vector.extract %20[%i] : f64 from vector<4xf64>
+    %25 = vector.extract %18[%i] : f64 from vector<4xf64>
+    %26 = arith.addf %24, %25 : f64
+    %27 = vector.extract %16[%i] : f64 from vector<4xf64>
+    %28 = vector.insert %27, %arg8 [%i] : f64 into vector<4xf64>
+    %29 = vector.insert %26, %arg9 [%i] : f64 into vector<4xf64>
+    scf.yield %arg6, %arg7, %28, %29
+      : tensor<8xf64>, tensor<8xf64>, vector<4xf64>, vector<4xf64>
+  }
+  %22 = vector.transfer_write %21#3, %arg3[%c0] {in_bounds = [true]}
+    : vector<4xf64>, tensor<8xf64>
+  %23 = vector.transfer_write %21#2, %arg4[%c0] {in_bounds = [true]}
+    : vector<4xf64>, tensor<8xf64>
+  return %22, %23 : tensor<8xf64>, tensor<8xf64>
+}
+// CHECK-LABEL: func.func @fold_insert_extract_two_results(
+// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]]: tensor<8xf64>,
+// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]]: tensor<8xf64>,
+// CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]]: tensor<8xf64>,
+// CHECK-SAME: %[[VAL_3:[a-zA-Z0-9_]*]]: tensor<8xf64>,
+// CHECK-SAME: %[[VAL_4:[a-zA-Z0-9_]*]]: tensor<8xf64>) -> (tensor<8xf64>, tensor<8xf64>) {
+
+// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[C0_VEC:.*]] = arith.constant dense<0.000000e+00> : vector<4xf64>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VAL_0_VEC:.*]] = vector.transfer_read %[[VAL_0]]
+// CHECK-DAG: %[[VAL_1_VEC:.*]] = vector.transfer_read %[[VAL_1]]
+// CHECK-DAG: %[[VAL_2_VEC:.*]] = vector.transfer_read %[[VAL_2]]
+
+// CHECK: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[INIT:.*]] = %[[C0_VEC]]) -> (vector<4xf64>) {
+
+// CHECK: %[[VEC_0_ELEM:.*]] = vector.extract %[[VAL_0_VEC]][%[[I]]]
+// CHECK: %[[VEC_1_ELEM:.*]] = vector.extract %[[VAL_1_VEC]][%[[I]]]
+// CHECK: %[[ADD:.*]] = arith.addf %[[VEC_0_ELEM]], %[[VEC_1_ELEM]] : f64
+// CHECK: %[[INSERT:.*]] = vector.insert %[[ADD]], %[[INIT]] [%[[I]]]
+// CHECK: scf.yield %[[INSERT]] : vector<4xf64>
+// CHECK: }
+// CHECK: %[[RES0:.*]] = vector.transfer_write %[[FOR]], %[[VAL_3]][%[[C0]]]
+// CHECK: %[[RES1:.*]] = vector.transfer_write %[[VAL_2_VEC]], %[[VAL_4]][%[[C0]]]
+// CHECK: return %[[RES0]], %[[RES1]] : tensor<8xf64>, tensor<8xf64>
+// CHECK: }
+
+// -----
+
+func.func @avoid_folding_small_tensors(%arg0: tensor<2xi4>, %arg1: tensor<2xi4>)
+    -> tensor<2xi4> {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %0 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %arg1)
+      -> (tensor<2xi4>) {
+    %extracted = tensor.extract %arg0[%arg2] : tensor<2xi4>
+    %inserted = tensor.insert %extracted into %arg3[%arg2] : tensor<2xi4>
+    scf.yield %inserted : tensor<2xi4>
+  }
+  return %0 : tensor<2xi4>
+}
+// CHECK-LABEL: func.func @avoid_folding_small_tensors
+// CHECK: scf.for
diff --git a/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc b/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc
index f90152fdac0b61..9fbfcca9e47c44 100644
--- a/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc
+++ b/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc
@@ -297,26 +297,14 @@ struct VectorizeLoad : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
 
 // Verifies that the insertions happening in the loop can all safely be batched
 // in the end.
-bool IsConflictFree(mlir::tensor::InsertOp op) {
+bool IsConflictFree(mlir::Operation* op, Value destination) {
   // The insertion's only use must be the yield.
   if (!op->hasOneUse() || !mlir::isa<mlir::scf::YieldOp>(*op->user_begin())) {
     return false;
   }
   // The destination must be one of the loop's block arguments, and the
   // destination must be the argument's only use.
-  auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(op.getDest());
-  return bbarg && bbarg.hasOneUse() &&
-         bbarg.getOwner()->getParentOp() == op->getParentOp();
-}
-
-bool IsConflictFree(AtomicRMWOp op) {
-  // The insertion's only use must be the yield.
-  if (!op->hasOneUse() || !mlir::isa<mlir::scf::YieldOp>(*op->user_begin())) {
-    return false;
-  }
-  // The destination must be one of the loop's block arguments, and the
-  // destination must be the argument's only use.
-  auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(op->getOpOperand(0).get());
+  auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(destination);
   return bbarg && bbarg.hasOneUse() &&
          bbarg.getOwner()->getParentOp() == op->getParentOp();
 }
@@ -340,7 +328,7 @@ class VectorizeAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
       return rewriter.notifyMatchFailure(op, "no loop found");
     }
 
-    if (!IsConflictFree(op)) {
+    if (!IsConflictFree(op, op.getOperand(0))) {
       return rewriter.notifyMatchFailure(op, "write may be read back by loop");
     }
 
@@ -415,7 +403,7 @@ struct VectorizeStore : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
     if (!loop) {
       return rewriter.notifyMatchFailure(op, "no loop found");
    }
-    if (!IsConflictFree(op)) {
+    if (!IsConflictFree(op, op.getDest())) {
      return rewriter.notifyMatchFailure(op, "write may be read back by loop");
    }
     auto vector_type = GetVectorType(op.getDest().getType(), loop);
@@ -462,6 +450,99 @@ struct VectorizeStore : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
   }
 };
 
+struct FoldVectorInsertExtractPairs
+    : mlir::OpRewritePattern<mlir::vector::InsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::vector::InsertOp insert,
+      mlir::PatternRewriter& rewriter) const override {
+    // Check that the vector is 1D and the index is dynamic.
+    auto vector_type = insert.getDest().getType();
+    if (vector_type.getRank() != 1 || !insert.hasDynamicPosition()) {
+      return rewriter.notifyMatchFailure(insert, "the vector should be 1D");
+    }
+    Value index = insert.getDynamicPosition().front();
+
+    // Check that the value that we insert is produced by a vector.extract.
+    auto extract = mlir::dyn_cast_or_null<mlir::vector::ExtractOp>(
+        insert.getSource().getDefiningOp());
+    if (!extract || !extract.hasDynamicPosition() || !extract->hasOneUse()) {
+      return rewriter.notifyMatchFailure(insert,
+                                         "no single-use vector.extract found");
+    }
+
+    // Check that the insert is in the loop and is used only by the yield.
+    auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(insert->getParentOp());
+    if (!loop) {
+      return rewriter.notifyMatchFailure(insert, "no scf.for loop found");
+    }
+    if (!IsConflictFree(insert, insert.getDest())) {
+      return rewriter.notifyMatchFailure(insert,
+                                         "write may be read back by loop");
+    }
+
+    // Avoid folding the vector.transfer_read from the vector-sized tensors.
+    // Otherwise, the following function:
+    //
+    // func.func @example(%arg0: tensor<2xi4>, %arg1: tensor<2xi4>)
+    //     -> tensor<2xi4> {
+    //   %0 = scf.for %arg2 = %c0 to %c2 step %c1
+    //       iter_args(%arg3 = %arg1) -> (tensor<2xi4>) {
+    //     %extracted = tensor.extract %arg0[%arg2] : tensor<2xi4>
+    //     %inserted = tensor.insert %extracted into %arg3[%arg2] : tensor<2xi4>
+    //     scf.yield %inserted : tensor<2xi4>
+    //   }
+    //   return %0 : tensor<2xi4>
+    // }
+    //
+    // will be folded into:
+    //
+    // func.func @example(%arg0: tensor<2xi4>,
+    //     %arg1: tensor<2xi4>) -> tensor<2xi4> {
+    //   return %arg0 : tensor<2xi4>
+    // }
+    // and the data won't be copied.
+    auto bbarg = mlir::cast<mlir::BlockArgument>(insert.getDest());
+    int64_t result_index = bbarg.getArgNumber() - 1;
+    if (auto transfer_read =
+            extract.getVector().getDefiningOp<mlir::vector::TransferReadOp>()) {
+      if (transfer_read.getSource().getType().getNumElements() ==
+          vector_type.getNumElements()) {
+        return rewriter.notifyMatchFailure(
+            insert,
+            "do not fold the vector.transfer_read from the vector-sized "
+            "tensors.");
+      }
+    }
+
+    // Check that the extract and insert use the same IV.
+    if (extract.getDynamicPosition().front() != index ||
+        index != loop.getInductionVar()) {
+      return rewriter.notifyMatchFailure(
+          insert,
+          "both insert and extract should use the IV of the parent loop");
+    }
+    // Check the loop spans the whole vector.
+    if (mlir::getConstantIntValue(loop.getUpperBound()) !=
+            vector_type.getDimSize(0) ||
+        mlir::getConstantIntValue(loop.getStep()) != 1 ||
+        mlir::getConstantIntValue(loop.getLowerBound()) != 0) {
+      return rewriter.notifyMatchFailure(
+          insert, "loop bounds don't match the vector type");
+    }
+
+    // Replace the loop result with the corresponding init.
+    auto yield_op = loop.getBody()->getTerminator();
+    rewriter.modifyOpInPlace(yield_op, [&]() {
+      yield_op->setOperand(result_index, insert.getDest());
+    });
+    rewriter.replaceAllUsesWith(loop->getResult(result_index),
+                                extract.getVector());
+    return mlir::success();
+  }
+};
+
 class VectorizeLoadsAndStoresPass
     : public impl::VectorizeLoadsAndStoresPassBase<
           VectorizeLoadsAndStoresPass> {
@@ -486,7 +567,8 @@ class VectorizeLoadsAndStoresPass
     }
     mlir::MLIRContext* mlir_context = &getContext();
     mlir::RewritePatternSet patterns(mlir_context);
-    patterns.add<VectorizeLoad, VectorizeStore>(mlir_context);
+    patterns.add<VectorizeLoad, VectorizeStore, FoldVectorInsertExtractPairs>(
+        mlir_context);
     patterns.add<VectorizeAtomicRMW>(mlir_context, device_spec_);
     if (mlir::failed(
             mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {

From 7e4367bede5828c5831f7aab068e9e898c92ff43 Mon Sep 17 00:00:00 2001
From: Daniel Suo
Date: Mon, 31 Mar 2025 06:14:50 -0700
Subject: [PATCH 0044/1324] [xla:util] Add overloads to PackIntN and UnpackIntN
 to clean up jaxlib logic.

PiperOrigin-RevId: 742241029
---
 third_party/xla/xla/util.h | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/third_party/xla/xla/util.h b/third_party/xla/xla/util.h
index d80218b8d22bf0..2a5f11b497b007 100644
--- a/third_party/xla/xla/util.h
+++ b/third_party/xla/xla/util.h
@@ -892,6 +892,19 @@ inline void PackIntN(int bits_per_element, absl::Span<const char> input,
   }
 }
 
+// Same as above, but takes the number of bits per element, a pointer to the
+// source data, and the size of the data in bytes. Returns a unique pointer to
+// the packed data.
+inline std::unique_ptr<char[]> PackIntN(int bits_per_element, const char* data,
+                                        size_t size) {
+  size_t packed_size = size * bits_per_element / 8;
+  auto buffer = std::make_unique<char[]>(packed_size);
+  auto src = absl::MakeSpan(data, size);
+  auto dst = absl::MakeSpan(buffer.get(), packed_size);
+  PackIntN(bits_per_element, src, dst);
+  return buffer;
+}
+
 // Takes a sequence of packed values, such that every byte stores multiple
 // values, and unpacks them so every byte stores one value in the low-order
 // bits. `input` should have
@@ -932,6 +945,19 @@ inline void UnpackIntN(int bits_per_element, absl::Span<const char> input,
   }
 }
 
+// Same as above, but takes the number of bits per element, a pointer to the
+// source data, and the size of the data in bytes. Returns a unique pointer to
+// the unpacked data.
+inline std::unique_ptr<char[]> UnpackIntN(int bits_per_element,
+                                          const char* data, size_t size) {
+  size_t unpacked_size = size * 8 / bits_per_element;
+  auto buffer = std::make_unique<char[]>(unpacked_size);
+  auto src = absl::MakeSpan(data, size);
+  auto dst = absl::MakeSpan(buffer.get(), unpacked_size);
+  UnpackIntN(bits_per_element, src, dst);
+  return buffer;
+}
+
 // Returns a container with `sorted_ids_to_remove` elements removed.
template static T RemoveElements(absl::Span sorted_ids_to_remove, From 46491f9da68b153ab9768896e9e7cb69a18f9a7a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 31 Mar 2025 06:28:48 -0700 Subject: [PATCH 0045/1324] [xla:gpu] CommandBuffer: add e2e tests for For and While loops PiperOrigin-RevId: 742244088 --- .../gpu/runtime/command_buffer_cmd_emitter.cc | 12 +- .../xla/backends/gpu/runtime/while_thunk.h | 2 + .../service/gpu/tests/command_buffer_test.cc | 126 ++++++++++++++++++ 3 files changed, 138 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc index 271e355bbd6b79..7fabd03e6707bb 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/command_buffer_cmd_emitter.h" +#include #include #include #include @@ -115,9 +116,16 @@ static absl::StatusOr Convert( TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence body_cmds, ConvertToCommands(thunk.body_thunk_sequence()->thunks(), synchronization_mode)); - return std::make_unique(thunk.execution_stream_id(), + + if (std::optional trip_count = thunk.trip_count()) { + return std::make_unique(thunk.execution_stream_id(), *trip_count, thunk.condition_result_buffer(), - std::move(cond_cmds), std::move(body_cmds)); + std::move(body_cmds)); + } else { + return std::make_unique( + thunk.execution_stream_id(), thunk.condition_result_buffer(), + std::move(cond_cmds), std::move(body_cmds)); + } } static absl::StatusOr Convert(const GemmThunk& thunk) { diff --git a/third_party/xla/xla/backends/gpu/runtime/while_thunk.h b/third_party/xla/xla/backends/gpu/runtime/while_thunk.h index daecdf5c4728f5..dd9bb83bb9163c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/while_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/while_thunk.h @@ -78,6 +78,8 @@ class WhileThunk : public Thunk { return condition_result_buffer_index_; } + std::optional trip_count() const { return trip_count_; } + // Returns the current loop iteration if the caller is inside a while loop(s). 
// // Implementation relies on thread local storage, be careful when call it from diff --git a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc index 377daaffe0c27c..f16683b86629d9 100644 --- a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc +++ b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc @@ -196,5 +196,131 @@ TEST_F(CommandBufferTest, IndexConditional) { } } +TEST_F(CommandBufferTest, DISABLED_ForLoop) { + constexpr absl::string_view hlo_text = R"( + HloModule m, is_scheduled=true + + compare_fusion { + p0 = s32[] parameter(0) + ten = s32[] constant(10) + ROOT compare = compare(p0, ten), direction=LT + } + + add_one { + p0 = s32[] parameter(0) + one = s32[] constant(1) + ROOT add = add(p0, one) + } + + add_two { + p0 = f32[] parameter(0) + two = f32[] constant(2.0) + ROOT add = add(p0, two) + } + + body { + p0 = (s32[], f32[]) parameter(0) + cnt = get-tuple-element(p0), index=0 + val = get-tuple-element(p0), index=1 + add_cnt = s32[] fusion(cnt), kind=kLoop, calls=add_one + add_val = f32[] fusion(val), kind=kLoop, calls=add_two + ROOT tuple = (s32[], f32[]) tuple(add_cnt, add_val) + } + + cond { + p0 = (s32[], f32[]) parameter(0) + cnt = get-tuple-element(p0), index=0 + ROOT compare = pred[] fusion(cnt), kind=kLoop, calls=compare_fusion + } + + command_buffer { + p0 = (s32[], f32[]) parameter(0) + ROOT while = while(p0), condition=cond, body=body, + backend_config={"known_trip_count":{"n":"20"}} + } + + ENTRY main { + p0 = (s32[], f32[]) parameter(0) + ROOT call = (s32[], f32[]) call(p0), to_apply=command_buffer + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + + Literal cnt = LiteralUtil::CreateR0(0); + Literal value = LiteralUtil::CreateR0(0.0); + Literal argument = LiteralUtil::MakeTuple({&cnt, &value}); + + // Because we set the known trip count to 20, the loop will execute 20 times, + // and it will ignore the `cond` result, which would terminate the loop after + // 10 iterations (see WhileLoop test below that runs the real while loop). 
+ Literal expected_cnt = LiteralUtil::CreateR0(20); + Literal expected_value = LiteralUtil::CreateR0(40.0); + Literal expected = LiteralUtil::MakeTuple({&expected_cnt, &expected_value}); + + Literal result = ExecuteNoHloPasses(std::move(module), {&argument}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_F(CommandBufferTest, WhileLoop) { + constexpr absl::string_view hlo_text = R"( + HloModule m, is_scheduled=true + + compare_fusion { + p0 = s32[] parameter(0) + ten = s32[] constant(10) + ROOT compare = compare(p0, ten), direction=LT + } + + add_one { + p0 = s32[] parameter(0) + one = s32[] constant(1) + ROOT add = add(p0, one) + } + + add_two { + p0 = f32[] parameter(0) + two = f32[] constant(2.0) + ROOT add = add(p0, two) + } + + body { + p0 = (s32[], f32[]) parameter(0) + cnt = get-tuple-element(p0), index=0 + val = get-tuple-element(p0), index=1 + add_cnt = s32[] fusion(cnt), kind=kLoop, calls=add_one + add_val = f32[] fusion(val), kind=kLoop, calls=add_two + ROOT tuple = (s32[], f32[]) tuple(add_cnt, add_val) + } + + cond { + p0 = (s32[], f32[]) parameter(0) + cnt = get-tuple-element(p0), index=0 + ROOT compare = pred[] fusion(cnt), kind=kLoop, calls=compare_fusion + } + + command_buffer { + p0 = (s32[], f32[]) parameter(0) + ROOT while = while(p0), condition=cond, body=body + } + + ENTRY main { + p0 = (s32[], f32[]) parameter(0) + ROOT call = (s32[], f32[]) call(p0), to_apply=command_buffer + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + + Literal cnt = LiteralUtil::CreateR0(0); + Literal value = LiteralUtil::CreateR0(0.0); + Literal argument = LiteralUtil::MakeTuple({&cnt, &value}); + + Literal expected_cnt = LiteralUtil::CreateR0(10); + Literal expected_value = LiteralUtil::CreateR0(20.0); + Literal expected = LiteralUtil::MakeTuple({&expected_cnt, &expected_value}); + + Literal result = ExecuteNoHloPasses(std::move(module), {&argument}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + } // namespace } // namespace xla::gpu From 1848bfc798b26c0e4618ca91518c947b3c9946e2 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 31 Mar 2025 06:39:39 -0700 Subject: [PATCH 0046/1324] [XLA:GPU] Change AtomicOrdering from seq_cst to monotonic. The ordering was too strict and after https://github.com/llvm/llvm-project/commit/9638d08af96c4cb8cf16785eed92179b2658bdfe a memory barrier was inserted in PTX. 
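
For intuition: LLVM's `monotonic` ordering corresponds to C++'s
`std::memory_order_relaxed`. Each atomic read-modify-write stays atomic, but
no ordering of the surrounding loads and stores is imposed, so no fence has to
be emitted. A minimal C++ sketch of why that is enough for a pure accumulation
(an illustration only, not part of this change; `std::atomic<float>::fetch_add`
requires C++20):

```
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  std::atomic<float> sum{0.0f};
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&sum] {
      // Relaxed ordering: the addition itself is still atomic, so no
      // contribution is lost, but no memory barrier is required around
      // it -- which is all an independent accumulation needs.
      sum.fetch_add(1.0f, std::memory_order_relaxed);
    });
  }
  for (auto& t : workers) t.join();
  std::printf("%f\n", sum.load());  // Prints 4.000000 on every run.
}
```
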
PiperOrigin-RevId: 742246460 --- .../emitters/transforms/lower_tensors.cc | 14 +++---- .../transforms/tests/lower_tensors.mlir | 40 +++++++++---------- .../xla/service/gpu/tests/gpu_atomic_test.cc | 6 +-- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc b/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc index 72aed5c5506c03..d122f70533dc5c 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc +++ b/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc @@ -846,7 +846,7 @@ class RewriteAtomicRMW : public OpRewritePattern { case ml::AtomicBinOp::umax: case ml::AtomicBinOp::umin: { rewriter.create(loc, atomic_bin_op, addr, modifier_arg, - ml::AtomicOrdering::seq_cst, + ml::AtomicOrdering::monotonic, sync_scope); return success(); } @@ -897,7 +897,7 @@ class RewriteAtomicRMW : public OpRewritePattern { } b.create(loc, ml::AtomicBinOp::fadd, addr, modifier_arg, - ml::AtomicOrdering::seq_cst, sync_scope); + ml::AtomicOrdering::monotonic, sync_scope); return success(); } @@ -971,7 +971,7 @@ class RewriteAtomicRMW : public OpRewritePattern { loc, ml::LLVMPointerType::get(b.getContext(), kGlobalMemory), addr); } b.create(loc, ml::AtomicBinOp::fadd, addr, modifier_arg, - ml::AtomicOrdering::seq_cst, sync_scope); + ml::AtomicOrdering::monotonic, sync_scope); return success(); } @@ -1034,14 +1034,14 @@ class RewriteAtomicRMW : public OpRewritePattern { // atomicMax((int *)address, __float_as_int(val)) nested_b.create( loc, ml::AtomicBinOp::max, addr, source_float_as_int, - ml::AtomicOrdering::seq_cst, sync_scope); + ml::AtomicOrdering::monotonic, sync_scope); nested_b.create(nested_loc); }, [&](OpBuilder& nested_b, Location nested_loc) { // atomicMax((int *)address, __float_as_int(val)) nested_b.create( loc, ml::AtomicBinOp::umin, addr, source_float_as_int, - ml::AtomicOrdering::seq_cst, sync_scope); + ml::AtomicOrdering::monotonic, sync_scope); nested_b.create(nested_loc); }); then_builder.create(loc); @@ -1185,8 +1185,8 @@ class RewriteAtomicRMW : public OpRewritePattern { // Try saving the result atomically, retry if failed. 
Value cmpxchg = b.create( loc, addr, old_value, new_value, - /*success_ordering=*/ml::AtomicOrdering::seq_cst, - /*failure_ordering=*/ml::AtomicOrdering::seq_cst); + /*success_ordering=*/ml::AtomicOrdering::monotonic, + /*failure_ordering=*/ml::AtomicOrdering::monotonic); Value next = b.create(cmpxchg, 0); Value ok = b.create(cmpxchg, 1); Value low_bit = b.create(b.getOneAttr(b.getI1Type())); diff --git a/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir b/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir index 8feb7354a96bae..83be709942d973 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir +++ b/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir @@ -409,7 +409,7 @@ func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) // CHECK: %[[MASKED_INIT:.*]] = llvm.and %[[INIT]] // CHECK: %[[NEW_VALUE_SHIFTED:.*]] = llvm.shl %[[NEW_VALUE_I32]] // CHECK: %[[NEW_INIT:.*]] = llvm.or %[[MASKED_INIT]], %[[NEW_VALUE_SHIFTED]] -// CHECK: llvm.cmpxchg %{{.*}}, %[[INIT]], %[[NEW_INIT]] seq_cst seq_cst +// CHECK: llvm.cmpxchg %{{.*}}, %[[INIT]], %[[NEW_INIT]] monotonic monotonic // CHECK: scf.condition // ----- @@ -443,7 +443,7 @@ func.func @direct_atomic_rmw_addi(%in: tensor<8xi32>, // CHECK-PASCAL-LABEL: @direct_atomic_rmw_addi // CHECK-PASCAL: %[[C2:.*]] = arith.constant 2 // CHECK-PASCAL: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-PASCAL: llvm.atomicrmw add %[[ADDR]], %[[C2]] seq_cst +// CHECK-PASCAL: llvm.atomicrmw add %[[ADDR]], %[[C2]] monotonic // ----- @@ -460,7 +460,7 @@ func.func @direct_atomic_rmw_maxsi(%in: tensor<8xi32>, // CHECK-PASCAL-LABEL: @direct_atomic_rmw_maxsi // CHECK-PASCAL: %[[C2:.*]] = arith.constant 2 // CHECK-PASCAL: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-PASCAL: llvm.atomicrmw max %[[ADDR]], %[[C2]] seq_cst +// CHECK-PASCAL: llvm.atomicrmw max %[[ADDR]], %[[C2]] monotonic // ----- @@ -477,7 +477,7 @@ func.func @direct_atomic_rmw_maxui(%in: tensor<8xi32>, // CHECK-PASCAL-LABEL: @direct_atomic_rmw_maxui // CHECK-PASCAL: %[[C2:.*]] = arith.constant 2 // CHECK-PASCAL: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-PASCAL: llvm.atomicrmw umax %[[ADDR]], %[[C2]] seq_cst +// CHECK-PASCAL: llvm.atomicrmw umax %[[ADDR]], %[[C2]] monotonic // ----- @@ -494,7 +494,7 @@ func.func @direct_atomic_rmw_minsi(%in: tensor<8xi32>, // CHECK-PASCAL-LABEL: @direct_atomic_rmw_minsi // CHECK-PASCAL: %[[C2:.*]] = arith.constant 2 // CHECK-PASCAL: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-PASCAL: llvm.atomicrmw min %[[ADDR]], %[[C2]] seq_cst +// CHECK-PASCAL: llvm.atomicrmw min %[[ADDR]], %[[C2]] monotonic // ----- @@ -511,7 +511,7 @@ func.func @direct_atomic_rmw_minui(%in: tensor<8xi32>, // CHECK-PASCAL-LABEL: @direct_atomic_rmw_minui // CHECK-PASCAL: %[[C2:.*]] = arith.constant 2 // CHECK-PASCAL: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-PASCAL: llvm.atomicrmw umin %[[ADDR]], %[[C2]] seq_cst +// CHECK-PASCAL: llvm.atomicrmw umin %[[ADDR]], %[[C2]] monotonic // ----- @@ -528,29 +528,29 @@ func.func @direct_atomic_rmw_fadd_f32(%in: tensor<8xf32>, // CHECK-PASCAL-LABEL: @direct_atomic_rmw_fadd_f32 // CHECK-PASCAL: %[[C2:.*]] = arith.constant 2 // CHECK-PASCAL: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-PASCAL: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst +// CHECK-PASCAL: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f32 // CHECK-VOLTA: %[[C2:.*]] = arith.constant 2 // CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-VOLTA: 
llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst +// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f32 // CHECK-AMPERE: %[[C2:.*]] = arith.constant 2 // CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst +// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f32 // CHECK-GFX908-MI100: %[[C2:.*]] = arith.constant 2 // CHECK-GFX908-MI100: %[[ADDR:.*]] = llvm.getelementptr // CHECK-GFX908-MI100: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1> -// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst +// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") monotonic // CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f32 // CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2 // CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr // CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1> -// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst +// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") monotonic // ----- @@ -570,12 +570,12 @@ func.func @direct_atomic_rmw_fadd_f16(%in: tensor<8xf16>, // CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f16 // CHECK-VOLTA: %[[C2:.*]] = arith.constant 2 // CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst +// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f16 // CHECK-AMPERE: %[[C2:.*]] = arith.constant 2 // CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst +// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f16 // CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd @@ -584,7 +584,7 @@ func.func @direct_atomic_rmw_fadd_f16(%in: tensor<8xf16>, // CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2 // CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr // CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1> -// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst +// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") monotonic // ----- @@ -604,7 +604,7 @@ func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<8xbf16>, // CHECK-HOPPER-LABEL: @direct_atomic_rmw_fadd_bf16 // CHECK-HOPPER: %[[C2:.*]] = arith.constant 2 // CHECK-HOPPER: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-HOPPER: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst +// CHECK-HOPPER: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // ----- @@ -621,17 +621,17 @@ func.func @direct_atomic_rmw_fadd_f64(%in: tensor<8xf64>, // CHECK-PASCAL-LABEL: @direct_atomic_rmw_fadd_f64 // CHECK-PASCAL: %[[C2:.*]] = arith.constant 2 // CHECK-PASCAL: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-PASCAL: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst +// CHECK-PASCAL: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f64 // CHECK-VOLTA: %[[C2:.*]] = arith.constant 2 // CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst +// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] 
monotonic // CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f64 // CHECK-AMPERE: %[[C2:.*]] = arith.constant 2 // CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst +// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f64 // CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd @@ -668,9 +668,9 @@ func.func @direct_atomic_rmw_maximumf(%in: tensor<8xf32>, // CHECK-PASCAL: %[[INT_MODIFIER_OR_NAN:.*]] = llvm.bitcast %[[MODIFIER_OR_NAN]] : f32 to i32 // CHECK-PASCAL: %[[IS_POSITIVE:.*]] = llvm.icmp "sge" %[[INT_MODIFIER_OR_NAN]], %[[C0]] : i32 // CHECK-PASCAL: scf.if %[[IS_POSITIVE]] { -// CHECK-PASCAL: llvm.atomicrmw max %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst +// CHECK-PASCAL: llvm.atomicrmw max %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] monotonic // CHECK-PASCAL: } else { -// CHECK-PASCAL: llvm.atomicrmw umin %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst +// CHECK-PASCAL: llvm.atomicrmw umin %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] monotonic // CHECK-PASCAL: } // CHECK-PASCAL: } // CHECK-PASCAL: } diff --git a/third_party/xla/xla/service/gpu/tests/gpu_atomic_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_atomic_test.cc index 6897b9fa850e35..53cdf60e56c99a 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_atomic_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_atomic_test.cc @@ -101,10 +101,10 @@ TEST_F(GpuAtomicTest, TestAddAtomicF32) { )"; CompileAndVerifyIr(hlo_string, is_built_with_rocm_ ? R"( -CHECK: atomicrmw fadd ptr addrspace(1) %[[ADDR:.*]], float %[[VALUE:.*]] syncscope("agent") seq_cst +CHECK: atomicrmw fadd ptr addrspace(1) %[[ADDR:.*]], float %[[VALUE:.*]] syncscope("agent") monotonic )" : R"( -CHECK: atomicrmw fadd ptr %[[ADDR:.*]], float %[[VALUE:.*]] seq_cst +CHECK: atomicrmw fadd ptr %[[ADDR:.*]], float %[[VALUE:.*]] monotonic )"); } @@ -141,7 +141,7 @@ TEST_F(GpuAtomicTest, TestAddAtomicF64) { )"; CompileAndVerifyIr(hlo_string, R"( -CHECK: atomicrmw fadd ptr %[[ADDR:.*]], double %[[VALUE:.*]] seq_cst +CHECK: atomicrmw fadd ptr %[[ADDR:.*]], double %[[VALUE:.*]] monotonic )"); } From 5b1be492f194aac5cfa94d6ac87ed0dc2b018ded Mon Sep 17 00:00:00 2001 From: Karlo Basioli Date: Mon, 31 Mar 2025 07:40:39 -0700 Subject: [PATCH 0047/1324] [XLA:CPU] Ensure benchmark name uniqueness in `multi_benchmark_config` and pass benchmark options by value PiperOrigin-RevId: 742261220 --- .../cpu/benchmarks/multi_benchmark_config.h | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/multi_benchmark_config.h b/third_party/xla/xla/backends/cpu/benchmarks/multi_benchmark_config.h index 1358314a47ce6b..b03e71709500c6 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/multi_benchmark_config.h +++ b/third_party/xla/xla/backends/cpu/benchmarks/multi_benchmark_config.h @@ -254,18 +254,17 @@ class MultiBenchmarkConfig { // Benchmarks 'fn' in JIT and AOT modes. The JIT benchmark // keeps the given 'name'; AOT is suffixed with '_Aot'. 
 inline MultiBenchmarkConfig* RegisterJitAndAotBenchmarks(
-    absl::string_view name,
-    void(fn)(benchmark::State&, const HloBenchmarkOptions&)) {
+    absl::string_view name, void(fn)(benchmark::State&, HloBenchmarkOptions)) {
   std::string jit_name(name);
   std::string aot_name = jit_name + "_Aot";
   auto jit_fn = [fn](benchmark::State& state) {
     HloBenchmarkOptions options;
-    fn(state, options);
+    fn(state, std::move(options));
   };
   auto aot_fn = [fn](benchmark::State& state) {
     HloBenchmarkOptions options;
     options.aot_options = GetAotCompilationOptions();
-    fn(state, options);
+    fn(state, std::move(options));
   };
   benchmark::internal::Benchmark* jit =
       benchmark::RegisterBenchmark(jit_name, jit_fn);
@@ -277,8 +276,17 @@ inline MultiBenchmarkConfig* RegisterJitAndAotBenchmarks(
 // Registers the given benchmark in both JIT and AOT modes.
 // The benchmark's function signature must be as follows:
 // `void BenchmarkFunc(benchmark::State&, const HloBenchmarkOptions&)`.
-#define XLA_CPU_BENCHMARK(n) \
-  static MultiBenchmarkConfig* n##_ptr = RegisterJitAndAotBenchmarks(#n, n)
+#define XLA_CPU_BENCHMARK(n) XLA_CPU_BENCHMARK_HELPER(__COUNTER__, n)
+
+// Helper for implementing macros above. Do not use directly.
+//
+// Forces the evaluation of "counter", which we expect is equal to __COUNTER__.
+#define XLA_CPU_BENCHMARK_HELPER(ctr, n) XLA_CPU_BENCHMARK_HELPER2(ctr, n)
+
+// Helper for macros above. Don't use directly.
+#define XLA_CPU_BENCHMARK_HELPER2(ctr, n) \
+  static MultiBenchmarkConfig* xla_cpu_benchmark_ptr_##ctr = \
+      RegisterJitAndAotBenchmarks(#n, n)
 
 }  // namespace xla::cpu
 
From ad9b84ae30820934dfb9845bc1ae512a0607e780 Mon Sep 17 00:00:00 2001
From: Frederic Rechtenstein
Date: Mon, 31 Mar 2025 08:35:48 -0700
Subject: [PATCH 0048/1324] Make TFLite interpreter `model_path` argument PEP
 519 compliant.

Allows passing `pathlib.Path` instances for `model_path`.
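
For example, a minimal usage sketch (the model path below is hypothetical, not
taken from this change):

```
import pathlib

import tensorflow as tf

# Any os.PathLike object is now accepted, not just str.
model_path = pathlib.Path("/tmp/models") / "permute_float.tflite"  # hypothetical
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
```
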
PiperOrigin-RevId: 742276190 --- tensorflow/lite/python/interpreter.py | 2 +- tensorflow/lite/python/interpreter_test.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index e4db4bfcba3b7a..ef102d6c8c5ab0 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -488,7 +488,7 @@ def __init__( x for x in self._custom_op_registerers if not isinstance(x, str) ] self._interpreter = _interpreter_wrapper.CreateWrapperFromFile( - model_path, + os.fspath(model_path), op_resolver_id, custom_op_registerers_by_name, custom_op_registerers_by_func, diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index f216ce44791102..ad7ec933c8ec3a 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -15,6 +15,7 @@ """TensorFlow Lite Python Interface: Sanity check.""" import ctypes import io +import pathlib import sys from unittest import mock @@ -96,6 +97,16 @@ def assertQuantizationParamsEqual(self, scales, zero_points, self.assertAllEqual(zero_points, params['zero_points']) self.assertEqual(quantized_dimension, params['quantized_dimension']) + def testPathLikeModel(self): + interpreter = interpreter_wrapper.Interpreter( + model_path=pathlib.Path( + resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite' + ) + ), + ) + interpreter.allocate_tensors() + def testThreads_NegativeValue(self): with self.assertRaisesRegex(ValueError, 'num_threads should >= 1'): interpreter_wrapper.Interpreter( From 2a09754e60aaf3b001b6b10c383b0660e4cdf2e0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 31 Mar 2025 08:50:44 -0700 Subject: [PATCH 0049/1324] Reduce compile time when using large sharded shapes by constructing sharding_tree at most once per parameter. PiperOrigin-RevId: 742280368 --- third_party/xla/xla/hlo/analysis/BUILD | 1 + .../xla/hlo/analysis/hlo_replication_analysis.cc | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/BUILD b/third_party/xla/xla/hlo/analysis/BUILD index 6deb0895e11fe5..e0157bf490a36b 100644 --- a/third_party/xla/xla/hlo/analysis/BUILD +++ b/third_party/xla/xla/hlo/analysis/BUILD @@ -300,6 +300,7 @@ cc_library( srcs = ["hlo_replication_analysis.cc"], hdrs = ["hlo_replication_analysis.h"], deps = [ + "//xla:shape_tree", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc index 632b8ac43cc427..64e38cb469fff9 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc @@ -41,7 +41,9 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/map_util.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" @@ -468,6 +470,15 @@ absl::Status HloReplicationAnalysis::ComputeHloReplication() { auto param = entry->parameter_instruction(i); ShapeTree shape_tree(param->shape(), HloReplication::UniqueOnAllDevices()); + + std::unique_ptr> sharding_tree = nullptr; + if (cross_partition_spmd_ && param->has_sharding()) { + TF_ASSIGN_OR_RETURN(auto result, + param->sharding().AsShapeTree(param->shape())); + sharding_tree = + std::make_unique>(std::move(result)); + } + const auto& replication = param->parameter_replicated_at_leaf_buffers(); int leaf_index = 0; absl::Status status = ShapeUtil::ForEachSubshapeWithStatus( @@ -478,10 +489,8 @@ absl::Status HloReplicationAnalysis::ComputeHloReplication() { if (cross_partition_spmd_ && param->has_sharding()) { // In cross-partition spmd mode, set parameter replication status // based on the parameter's sharding. - TF_ASSIGN_OR_RETURN(auto sharding_tree, - param->sharding().AsShapeTree(param->shape())); *shape_tree.mutable_element(index) = - sharding_tree.element(index).IsReplicated() + sharding_tree->element(index).IsReplicated() ? HloReplication::ReplicatedOnAllDevices() : HloReplication::UniqueOnAllDevices(); } From d228a2ffc55935fc9df88cc2f9bb3293c265545b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 31 Mar 2025 09:54:57 -0700 Subject: [PATCH 0050/1324] Removing PACKED_NIBBLES as we have native int4 support in the compiler PiperOrigin-RevId: 742300141 --- .../replace_cast_hacks_with_tf_xla_ops.cc | 24 +----- .../gpu/codegen/triton/support_test.cc | 9 --- .../evaluator/hlo_evaluator_typed_visitor.h | 76 ++++++------------- .../xla/hlo/transforms/operand_upcaster.cc | 54 ------------- .../simplifiers/algebraic_simplifier.cc | 11 +-- .../hlo/transforms/tests/operand_upcaster.hlo | 60 --------------- .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td | 3 +- .../hlo_legalize_to_stablehlo.cc | 43 +---------- .../stablehlo_legalize_to_hlo.cc | 6 +- ...lo-legalize-to-stablehlo-experimental.mlir | 13 ---- .../mhlo/hlo-legalize-to-stablehlo.mlir | 22 ------ third_party/xla/xla/service/hlo_verifier.cc | 59 +------------- third_party/xla/xla/tests/convolution_test.cc | 16 ---- .../xla/xla/tests/dot_operation_test.cc | 30 -------- third_party/xla/xla/xla_data.proto | 4 +- 15 files changed, 40 insertions(+), 390 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc index d1e46b4eb56031..ec5adb87d88c8c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc @@ -628,8 +628,7 @@ Value CreateXlaConvOp(OpBuilder &builder, Location loc, Value input, Value filter, Value input_zp, Value conv_output, ArrayAttr strides, ArrayAttr dilations, StringAttr conv_padding, ArrayAttr explicit_paddings, - int feature_group_cnt, bool four_bit = false, - int num_dims = 4) { + int feature_group_cnt, int num_dims = 4) { int32_t input_zp_value; if (!GetSplatValue(input_zp, input_zp_value)) { emitError(loc, @@ -675,14 +674,6 @@ Value CreateXlaConvOp(OpBuilder &builder, Location loc, Value input, conv_padding, 
explicit_paddings, padding, num_dims); std::string precision_config_str; - if (four_bit) { - input = PackOperand(builder, loc, input, /*pack_dim=*/num_dims - 1); - filter = PackOperand(builder, loc, filter, /*pack_dim=*/num_dims - 2); - xla::PrecisionConfig precision_config; - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config_str = precision_config.SerializeAsString(); - } Value xla_conv_output = builder .create( @@ -774,14 +765,13 @@ Value CreateXlaConvOpFromTfConv3dOp(OpBuilder &builder, Location loc, return CreateXlaConvOp(builder, loc, input, filter, input_zp, conv_output, strides, dilations, conv_padding, /*explicit_paddings=*/nullptr, feature_group_cnt, - /*four_bit=*/false, /*num_dims=*/5); + /*num_dims=*/5); } // Helper function to create an XlaDotV2Op. Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, Value weight, Value input_zp, Value weight_zp, - Value output, const xla::DotDimensionNumbers &dnums, - bool four_bit = false) { + Value output, const xla::DotDimensionNumbers &dnums) { int32_t input_zp_value = 0; int32_t weight_zp_value = 0; if (input_zp != nullptr && !GetSplatValue(input_zp, input_zp_value)) { @@ -797,14 +787,6 @@ Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, } std::string precision_config_str; - if (four_bit) { - input = PackOperand(builder, loc, input, /*pack_dim=*/1); - weight = PackOperand(builder, loc, weight, /*pack_dim=*/0); - xla::PrecisionConfig precision_config; - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config_str = precision_config.SerializeAsString(); - } Value dot_result = builder diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index 48f66a2280aeb8..9d062876e10364 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -2243,15 +2243,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(AllDevicesToTest())), DotPrecisionTestName); -INSTANTIATE_TEST_SUITE_P( - DotPackedNibblePrecisionTestSuite, DotPrecisionTest, - ::testing::Combine(::testing::ValuesIn({PrimitiveType::S8, - PrimitiveType::U8}), - ::testing::ValuesIn({PrecisionConfig::PACKED_NIBBLE}), - ::testing::ValuesIn({PrecisionConfig::PACKED_NIBBLE}), - ::testing::ValuesIn(AllDevicesToTest())), - DotPrecisionTestName); - class DotPrecisionAlgorithmTest : public DotTest, public ::testing::WithParamInterface< diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 063ca3ca6a1131..0ba235098e8886 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -816,11 +816,6 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { const Shape& result_shape = conv->shape(); const Shape& lhs_shape = lhs_literal.shape(); const Shape& rhs_shape = rhs_literal.shape(); - const auto packed_nibble_count = - absl::c_count(conv->precision_config().operand_precision(), - PrecisionConfig::PACKED_NIBBLE); - CHECK_NE(packed_nibble_count, 1); - const bool is_packed_nibble = packed_nibble_count == 2; TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); 
TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); @@ -858,9 +853,8 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, rhs_literal_data, feature_group_count, batch_group_count, - is_packed_nibble, result_shape, - this](const absl::Span out_index, - int /*thread_id*/) { + result_shape, this](const absl::Span out_index, + int /*thread_id*/) { // Dimension number applicable for input (lhs). const int64_t input_batch_dim = dnums.input_batch_dimension(); const int64_t input_z_dim = dnums.input_feature_dimension(); @@ -982,23 +976,15 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { static_cast(lhs_literal_data[lhs_linear_index]); auto rhs = static_cast(rhs_literal_data[rhs_linear_index]); - if (is_packed_nibble) { - auto lhs_n0 = ToArithmeticSafeType(Nibble0(lhs)); - auto lhs_n1 = ToArithmeticSafeType(Nibble1(lhs)); - auto rhs_n0 = ToArithmeticSafeType(Nibble0(rhs)); - auto rhs_n1 = ToArithmeticSafeType(Nibble1(rhs)); - result_val += (lhs_n0 * rhs_n0) + (lhs_n1 * rhs_n1); - } else { - result_val += ToArithmeticSafeType(lhs) * ToArithmeticSafeType(rhs); + result_val += ToArithmeticSafeType(lhs) * ToArithmeticSafeType(rhs); - if (parent_->trace_mac_handler_ != nullptr) { - const int64_t result_linear_index = - IndexUtil::MultidimensionalIndexToLinearIndex(result_shape, - out_index); + if (parent_->trace_mac_handler_ != nullptr) { + const int64_t result_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex(result_shape, + out_index); - parent_->trace_mac_handler_(result_linear_index, lhs_linear_index, - rhs_linear_index); - } + parent_->trace_mac_handler_(result_linear_index, lhs_linear_index, + rhs_linear_index); } } cnt: {} @@ -1171,11 +1157,6 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { CHECK(ShapeUtil::SameElementType(lhs_literal.shape(), rhs_literal.shape())); CHECK(ShapeUtil::SameElementType(lhs_literal.shape(), dot->shape())); - const auto packed_nibble_count = - absl::c_count(dot->precision_config().operand_precision(), - PrecisionConfig::PACKED_NIBBLE); - CHECK_NE(packed_nibble_count, 1); - const bool is_packed_nibble = packed_nibble_count == 2; CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); @@ -1229,30 +1210,21 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { static_cast(lhs_literal.Get(lhs_index)); const auto rhs = static_cast(rhs_literal.Get(rhs_index)); - if (is_packed_nibble) { - auto lhs_n0 = ToArithmeticSafeType(Nibble0(lhs)); - auto lhs_n1 = ToArithmeticSafeType(Nibble1(lhs)); - auto rhs_n0 = ToArithmeticSafeType(Nibble0(rhs)); - auto rhs_n1 = ToArithmeticSafeType(Nibble1(rhs)); - result_val += (lhs_n0 * rhs_n0) + (lhs_n1 * rhs_n1); - } else { - result_val += - ToArithmeticSafeType(lhs) * ToArithmeticSafeType(rhs); - - if (parent_->trace_mac_handler_ != nullptr) { - const int64_t result_linear_index = - IndexUtil::MultidimensionalIndexToLinearIndex(dot->shape(), - result_index); - const int64_t lhs_linear_index = - IndexUtil::MultidimensionalIndexToLinearIndex( - lhs_literal.shape(), lhs_index); - const int64_t rhs_linear_index = - IndexUtil::MultidimensionalIndexToLinearIndex( - rhs_literal.shape(), rhs_index); - - parent_->trace_mac_handler_(result_linear_index, - lhs_linear_index, rhs_linear_index); - } + result_val += ToArithmeticSafeType(lhs) * ToArithmeticSafeType(rhs); + + if 
(parent_->trace_mac_handler_ != nullptr) { + const int64_t result_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex(dot->shape(), + result_index); + const int64_t lhs_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex( + lhs_literal.shape(), lhs_index); + const int64_t rhs_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex( + rhs_literal.shape(), rhs_index); + + parent_->trace_mac_handler_(result_linear_index, lhs_linear_index, + rhs_linear_index); } // If there are no contracting dimensions, do not try to count down diff --git a/third_party/xla/xla/hlo/transforms/operand_upcaster.cc b/third_party/xla/xla/hlo/transforms/operand_upcaster.cc index ed6b4d41ff443a..4fe0df75d32590 100644 --- a/third_party/xla/xla/hlo/transforms/operand_upcaster.cc +++ b/third_party/xla/xla/hlo/transforms/operand_upcaster.cc @@ -63,12 +63,6 @@ bool OperandUpcaster::InstructionMatchesPattern(HloInstruction* instruction) { return false; } - // Always expand packed nibble precision mode. - if (absl::c_count(instruction->precision_config().operand_precision(), - PrecisionConfig::PACKED_NIBBLE) == 2) { - return true; - } - PrimitiveType inferred_type = (*status_or_inferred_shape)->element_type(); if (instruction->shape().element_type() == inferred_type && instruction->operand(0)->shape().element_type() == inferred_type && @@ -81,56 +75,8 @@ bool OperandUpcaster::InstructionMatchesPattern(HloInstruction* instruction) { absl::StatusOr OperandUpcaster::ExpandInstruction( HloInstruction* instruction) { - const bool packed_nibble = - absl::c_count(instruction->precision_config().operand_precision(), - PrecisionConfig::PACKED_NIBBLE) == 2; auto type = instruction->shape().element_type(); - // If the precision is packed nibble create clone the linear op for each - // nibble of lhs and rhs. - if (packed_nibble) { - HloInstruction *lhs_n0 = instruction->mutable_operand(0), *lhs_n1 = lhs_n0, - *rhs_n0 = instruction->mutable_operand(1), *rhs_n1 = rhs_n0; - - TF_ASSIGN_OR_RETURN(lhs_n0, MakeBinaryHlo(HloOpcode::kShiftLeft, lhs_n0, - MakeScalarLike(lhs_n0, 4))); - HloOpcode lhs_shift = ShapeUtil::ElementIsSigned(lhs_n0->shape()) - ? HloOpcode::kShiftRightArithmetic - : HloOpcode::kShiftRightLogical; - TF_ASSIGN_OR_RETURN( - lhs_n0, MakeBinaryHlo(lhs_shift, lhs_n0, MakeScalarLike(lhs_n0, 4))); - lhs_n0 = MakeConvertToHlo(lhs_n0, type); - - TF_ASSIGN_OR_RETURN( - lhs_n1, MakeBinaryHlo(lhs_shift, lhs_n1, MakeScalarLike(lhs_n1, 4))); - lhs_n1 = MakeConvertToHlo(lhs_n1, type); - - TF_ASSIGN_OR_RETURN(rhs_n0, MakeBinaryHlo(HloOpcode::kShiftLeft, rhs_n0, - MakeScalarLike(rhs_n0, 4))); - HloOpcode rhs_shift = ShapeUtil::ElementIsSigned(rhs_n0->shape()) - ? 
HloOpcode::kShiftRightArithmetic - : HloOpcode::kShiftRightLogical; - TF_ASSIGN_OR_RETURN( - rhs_n0, MakeBinaryHlo(rhs_shift, rhs_n0, MakeScalarLike(rhs_n0, 4))); - rhs_n0 = MakeConvertToHlo(rhs_n0, type); - - TF_ASSIGN_OR_RETURN( - rhs_n1, MakeBinaryHlo(rhs_shift, rhs_n1, MakeScalarLike(rhs_n1, 4))); - rhs_n1 = MakeConvertToHlo(rhs_n1, type); - - HloInstruction* linear_n0 = - instruction->parent()->AddInstruction(instruction->CloneWithNewOperands( - instruction->shape(), {lhs_n0, rhs_n0})); - linear_n0->mutable_precision_config()->mutable_operand_precision()->Set( - 0, PrecisionConfig::DEFAULT); - linear_n0->mutable_precision_config()->mutable_operand_precision()->Set( - 1, PrecisionConfig::DEFAULT); - HloInstruction* linear_n1 = - instruction->parent()->AddInstruction(linear_n0->CloneWithNewOperands( - instruction->shape(), {lhs_n1, rhs_n1})); - return MakeBinaryHlo(HloOpcode::kAdd, linear_n0, linear_n1); - } - for (int i = 0; i < HloDotInstruction::kOperands; ++i) { auto* operand = instruction->mutable_operand(i); if (operand->shape().element_type() == type) { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index a2302a7299d7d9..59a0d5bd69cf0f 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -3995,14 +3995,11 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } - const bool is_packed_nibble = - absl::c_linear_search(dot->precision_config().operand_precision(), - PrecisionConfig::PACKED_NIBBLE); const bool can_rewrite_dot_with_precision_config_algorithm = SupportedDotPrecisionConfig(dot->precision_config()); // If there are no contracting dimensions, a dot can be rewritten as // mul(broadcast(transpose(x)),broadcast(transpose(y))) - if (!is_packed_nibble && can_rewrite_dot_with_precision_config_algorithm && + if (can_rewrite_dot_with_precision_config_algorithm && options_.enable_dot_to_multiply_rewrite() && dnums.lhs_contracting_dimensions_size() == 0) { return RewriteAsMultiplyDotWithZeroLhsContractingDim(dot, lhs, rhs, dnums); @@ -4029,7 +4026,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // If the lhs or rhs have only batch and contracting dimensions, a dot can be // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) - if (!is_packed_nibble && can_rewrite_dot_with_precision_config_algorithm && + if (can_rewrite_dot_with_precision_config_algorithm && options_.enable_dot_strength_reduction() && DotHasOnlyBatchAndContractingOnOneOperand(lhs->shape().dimensions_size(), rhs->shape().dimensions_size(), @@ -9655,9 +9652,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToMultiply( HloInstruction* convolution) { - if (options_.is_layout_sensitive() || - absl::c_linear_search(convolution->precision_config().operand_precision(), - PrecisionConfig::PACKED_NIBBLE)) { + if (options_.is_layout_sensitive()) { return false; } diff --git a/third_party/xla/xla/hlo/transforms/tests/operand_upcaster.hlo b/third_party/xla/xla/hlo/transforms/tests/operand_upcaster.hlo index 61a281e1f9a42e..364de426a3b53f 100644 --- a/third_party/xla/xla/hlo/transforms/tests/operand_upcaster.hlo +++ b/third_party/xla/xla/hlo/transforms/tests/operand_upcaster.hlo @@ -17,63 +17,3 @@ ENTRY 
test_dot { b = s16[8] parameter(1) ROOT result = s32[8] dot(a, b) } - -// ----- - -// CHECK-LABEL: HloModule TestDotPackedNibble, entry_computation_layout={(f16[8,16]{1,0}, f16[16,8]{1,0})->f32[8,8]{1,0}} - -// CHECK-LABEL: ENTRY %test_dot_packed_nibble -// CHECK-NEXT: %[[arg_0:[^ ]+]] = f16[8,16]{1,0} parameter(0) -// CHECK-NEXT: %[[convert:[^ ]+]] = f32[8,16]{1,0} convert(%[[arg_0]]) -// CHECK-NEXT: %[[arg_1:[^ ]+]] = f16[16,8]{1,0} parameter(1) -// CHECK-NEXT: %[[convert_1:[^ ]+]] = f32[16,8]{1,0} convert(%[[arg_1]]) -// CHECK-NEXT: ROOT %[[dot:[^ ]+]] = f32[8,8]{1,0} dot(%[[convert]], %[[convert_1]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={packed_nibble,default} - -HloModule TestDotPackedNibble - -ENTRY test_dot_packed_nibble { - arg_0 = f16[8,16] parameter(0) - arg_1 = f16[16,8] parameter(1) - ROOT dot = f32[8,8] dot(arg_0, arg_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={packed_nibble,default} -} - -// ----- - -// CHECK-LABEL: HloModule TestConvolutionPackedNibble, entry_computation_layout={(s8[3,3,7,7]{3,2,1,0}, s8[5,11,11,7]{3,2,1,0})->s32[5,11,11,7]{3,2,1,0}} - -// CHECK-LABEL: ENTRY %test_convolution_packed_nibble -// CHECK-NEXT: %[[lhs:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} parameter(1) -// CHECK-NEXT: %[[constant:[^ ]+]] = s8[] constant(4) -// CHECK-NEXT: %[[broadcast:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} broadcast(%[[constant]]), dimensions={} -// CHECK-NEXT: %[[shift_left:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} shift-left(%[[lhs]], %[[broadcast]]) -// CHECK-NEXT: %[[constant_1:[^ ]+]] = s8[] constant(4) -// CHECK-NEXT: %[[broadcast_1:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} broadcast(%[[constant_1]]), dimensions={} -// CHECK-NEXT: %[[shift_right_arithmetic:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} shift-right-arithmetic(%[[shift_left]], %[[broadcast_1]]) -// CHECK-NEXT: %[[convert:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} convert(%[[shift_right_arithmetic]]) -// CHECK-NEXT: %[[rhs:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} parameter(0) -// CHECK-NEXT: %[[constant_3:[^ ]+]] = s8[] constant(4) -// CHECK-NEXT: %[[broadcast_3:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} broadcast(%[[constant_3]]), dimensions={} -// CHECK-NEXT: %[[shift_left_1:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} shift-left(%[[rhs]], %[[broadcast_3]]) -// CHECK-NEXT: %[[constant_4:[^ ]+]] = s8[] constant(4) -// CHECK-NEXT: %[[broadcast_4:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} broadcast(%[[constant_4]]), dimensions={} -// CHECK-NEXT: %[[shift_right_arithmetic_2:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} shift-right-arithmetic(%[[shift_left_1]], %[[broadcast_4]]) -// CHECK-NEXT: %[[convert_2:[^ ]+]] = s32[3,3,7,7]{3,2,1,0} convert(%[[shift_right_arithmetic_2]]) -// CHECK-NEXT: %[[convolution_1:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} convolution(%[[convert]], %[[convert_2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f -// CHECK-NEXT: %[[constant_2:[^ ]+]] = s8[] constant(4) -// CHECK-NEXT: %[[broadcast_2:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} broadcast(%[[constant_2]]), dimensions={} -// CHECK-NEXT: %[[shift_right_arithmetic_1:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} shift-right-arithmetic(%[[lhs]], %[[broadcast_2]]) -// CHECK-NEXT: %[[convert_1:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} convert(%[[shift_right_arithmetic_1]]) -// CHECK-NEXT: %[[constant_5:[^ ]+]] = s8[] constant(4) -// CHECK-NEXT: %[[broadcast_5:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} broadcast(%[[constant_5]]), dimensions={} -// CHECK-NEXT: %[[shift_right_arithmetic_3:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} shift-right-arithmetic(%[[rhs]], %[[broadcast_5]]) -// CHECK-NEXT: %[[convert_3:[^ ]+]] = 
s32[3,3,7,7]{3,2,1,0} convert(%[[shift_right_arithmetic_3]]) -// CHECK-NEXT: %[[convolution_2:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} convolution(%[[convert_1]], %[[convert_3]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f -// CHECK-NEXT: ROOT %[[add:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} add(%[[convolution_1]], %[[convolution_2]]) - -HloModule TestConvolutionPackedNibble - -ENTRY test_convolution_packed_nibble { - lhs = s8[5,11,11,7] parameter(1) - rhs = s8[3,3,7,7] parameter(0) - ROOT convolution = s32[5,11,11,7] convolution(lhs, rhs), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, operand_precision={PACKED_NIBBLE,PACKED_NIBBLE} -} diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td index 53903a874fde86..b12855091ae25f 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td @@ -27,11 +27,10 @@ include "mlir/IR/PatternBase.td" def MHLO_PRECISION_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; def MHLO_PRECISION_HIGH : I32EnumAttrCase<"HIGH", 1>; def MHLO_PRECISION_HIGHEST : I32EnumAttrCase<"HIGHEST", 2>; -def MHLO_PRECISION_PACKED_NIBBLE : I32EnumAttrCase<"PACKED_NIBBLE", 3>; def MHLO_Precision : I32EnumAttr<"Precision", "XLA precision for an operand. Has backend specific meaning.", - [MHLO_PRECISION_DEFAULT, MHLO_PRECISION_HIGH, MHLO_PRECISION_HIGHEST, MHLO_PRECISION_PACKED_NIBBLE]> { + [MHLO_PRECISION_DEFAULT, MHLO_PRECISION_HIGH, MHLO_PRECISION_HIGHEST]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::mhlo"; } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index fcb185a4d3110a..7b4f0458f95210 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -80,14 +80,6 @@ bool hasPrivateFeaturesNotInStablehlo(HloOpTy hloOp) { return false; } -bool hasPackedNibble(std::optional precisionConfigAttr) { - if (!precisionConfigAttr) return false; - return llvm::any_of(*precisionConfigAttr, [&](Attribute attr) { - auto precisionAttr = mlir::cast(attr); - return precisionAttr.getValue() == mhlo::Precision::PACKED_NIBBLE; - }); -} - // EXPERIMENTAL MHLO features are being explored by ML frontends but do not have // any agreed upon compatibility guarantees. By default, these features cannot // be converted to StableHLO, although the allow-experimental-features flag can @@ -105,21 +97,6 @@ bool hasExperimentalFeaturesNotInStablehlo(HloOpTy hloOp) { // Proposal: https://github.com/openxla/stablehlo/issues/574. if (hloOp.getNumOperands() != 1) return true; } - if constexpr (std::is_same::value) { - // StableHLO ConvolutionOp doesn't support PACKED_NIBBLE yet. - // Proposal: https://github.com/openxla/stablehlo/issues/742. - if (hasPackedNibble(hloOp.getPrecisionConfig())) return true; - } - if constexpr (std::is_same::value) { - // StableHLO DotGeneral doesn't support PACKED_NIBBLE yet. - // Proposal: https://github.com/openxla/stablehlo/issues/742. - if (hasPackedNibble(hloOp.getPrecisionConfig())) return true; - } - if constexpr (std::is_same::value) { - // StableHLO Dot doesn't support PACKED_NIBBLE yet. - // Proposal: https://github.com/openxla/stablehlo/issues/742. 
-    if (hasPackedNibble(hloOp.getPrecisionConfig())) return true;
-  }
   return false;
 }
 
@@ -294,9 +271,6 @@ Attribute convertAttr(Attribute hloAttr) {
                                             attr.getOperandTupleIndices());
   }
   if (auto attr = mlir::dyn_cast<mhlo::PrecisionAttr>(hloAttr)) {
-    // StableHLO Precision doesn't support PACKED_NIBBLE yet.
-    // Proposal: https://github.com/openxla/stablehlo/issues/742.
-    if (attr.getValue() == mhlo::Precision::PACKED_NIBBLE) return {};
     RETURN_CONVERTED_ENUM_ATTR(Precision);
   }
   if (auto attr = mlir::dyn_cast(hloAttr)) {
@@ -364,7 +338,7 @@ Attribute convertAttr(Attribute hloAttr) {
 #undef RETURN_CONVERTED_ENUM_ATTR
 
 // Convert array of enum attrs to an array of enum strings
-// [#mhlo<precision PACKED_NIBBLE>] -> ["PACKED_NIBBLE"]
+// [#mhlo<precision HIGHEST>] -> ["HIGHEST"]
 //
 // This is stable as long as enum names are not changed. This is needed to avoid
 // a dependency on upstream printing / parsing. If an attribute name is changed,
@@ -469,17 +443,6 @@ LogicalResult convertAttributes(ConversionPatternRewriter& rewriter,
       continue;
     }
 
-    // If PACKED_NIBBLE enum support enabled, convert to string "PACKED_NIBBLE"
-    if constexpr (std::is_same<HloOpTy, mhlo::ConvolutionOp>::value ||
-                  std::is_same<HloOpTy, mhlo::DotGeneralOp>::value ||
-                  std::is_same<HloOpTy, mhlo::DotOp>::value) {
-      if (hloAttr.getName() == "precision_config" &&
-          hasPackedNibble(hloOp.getPrecisionConfig())) {
-        stablehloAttr =
-            encodePrecisionConfig(hloOp.getPrecisionConfig().value());
-      }
-    }
-
     // Handle DenseElements --> DenseArray for certain StableHLO ops
     if constexpr (!std::is_same::value &&
                   !std::is_same::value) {
@@ -504,10 +467,10 @@ LogicalResult convertAttributes(ConversionPatternRewriter& rewriter,
 //
 // Example:
 //  %0 = "mhlo.dot"(%arg0, %arg1) {
-//    precision_config = [#mhlo<precision PACKED_NIBBLE>] } ...
+//    precision_config = [#mhlo<precision HIGHEST>] } ...
 //  ==>
 //  %0 = stablehlo.custom_call @mhlo.dot {
-//    mhlo.attributes = {precision_config = ["PACKED_NIBBLE"]}}
+//    mhlo.attributes = {precision_config = ["HIGHEST"]}}
 template <typename HloOpTy>
 LogicalResult rewriteMhloOpAsCustomCall(HloOpTy hloOp,
                                         ConversionPatternRewriter& rewriter,
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc
index fa3186ce17a942..cb39eb52ca7f29 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc
@@ -193,7 +193,7 @@ Attribute convertAttr(Attribute stablehloAttr) {
 #undef RETURN_CONVERTED_ENUM_ATTR
 
 // Convert array of enum strings to array of enum attrs
-// ["PACKED_NIBBLE"] --> [#mhlo<precision PACKED_NIBBLE>]
+// ["HIGHEST"] --> [#mhlo<precision HIGHEST>]
 Attribute decodePrecisionConfig(Attribute stablehloAttr) {
   auto arrayAttr = mlir::dyn_cast<ArrayAttr>(stablehloAttr);
   if (!arrayAttr) return {};
@@ -248,10 +248,10 @@ LogicalResult convertFuncToStablehloRegion(Operation* op, func::FuncOp funcOp,
 //
 // Example:
 //  %0 = stablehlo.custom_call @mhlo.dot {
-//    mhlo.attributes = {precision_config = ["PACKED_NIBBLE"]}}
+//    mhlo.attributes = {precision_config = ["HIGHEST"]}}
 //  ==>
 //  %0 = "mhlo.dot"(%arg0, %arg1) {
-//    precision_config = [#mhlo<precision PACKED_NIBBLE>] } ...
+//    precision_config = [#mhlo<precision HIGHEST>] } ...
LogicalResult rewriteCustomCallAsMhloOp(stablehlo::CustomCallOp stablehloOp, ConversionPatternRewriter& rewriter, const TypeConverter* typeConverter, diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir index 706c9cdd58628a..38056bdb3d0f67 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir @@ -39,16 +39,3 @@ func.func @op_all_to_all_tuple(%arg0: tensor<128x4xf32>, %arg1: tensor<128x4xf32 return %0#0, %0#1 : tensor<128x4xf32>, tensor<128x4xf32> } -// ----- - -// CHECK-LABEL: "attr_precision_packed_nibble" -func.func @attr_precision_packed_nibble(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { - // CHECK: "stablehlo.custom_call"(%arg0, %arg1) <{call_target_name = "mhlo.dot"}> { - // CHECK-SAME: mhlo.attributes = {precision_config = ["PACKED_NIBBLE"]} - // CHECK-SAME: } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> - // expected-error@+1 {{failed to legalize operation 'mhlo.dot' that was explicitly marked illegal}} - %0 = "mhlo.dot"(%arg0, %arg1) { - precision_config = [#mhlo] - } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> - func.return %0 : tensor<8x8xf32> -} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 55d1bfe736a4da..99fc79ea16ab85 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -2085,28 +2085,6 @@ func.func @type_tuple(%arg0: tuple>) -> tuple { // ----- -func.func @attr_precision_config_invalid() -> tensor<8x8xf32> { - // expected-error@+1 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}} - %0 = "mhlo.custom_call"() { - call_target_name = "foo", - precision_config = [#mhlo, 1 : i32] - } : () -> tensor<8x8xf32> - func.return %0 : tensor<8x8xf32> -} - -// ----- - -func.func @attr_invalid_nested_in_dictionary() -> tensor<8x8xf32> { - // expected-error@+1 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}} - %0 = "mhlo.custom_call"() { - call_target_name = "foo", - precision_config = {config = #mhlo} - } : () -> tensor<8x8xf32> - func.return %0 : tensor<8x8xf32> -} - -// ----- - func.func @op_add_dependency(%arg0: tensor<16xf32>, %arg1: !mhlo.token) -> tensor<16xf32> { // expected-error@+1 {{failed to legalize operation 'mhlo.add_dependency' that was explicitly marked illegal}} %0 = "mhlo.add_dependency"(%arg0, %arg1) : (tensor<16xf32>, !mhlo.token) -> tensor<16xf32> diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 761ba2d0a62945..f754f8e8926c50 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -218,27 +218,7 @@ absl::Status ShapeVerifier::HandleDot(HloInstruction* dot) { dot->operand(0)->shape(), dot->operand(1)->shape(), dot->dot_dimension_numbers(), /*preferred_element_type=*/dot->shape().element_type(), sparsity)); - if (auto nibble_count = - absl::c_count(dot->precision_config().operand_precision(), - PrecisionConfig::PACKED_NIBBLE)) { - if (nibble_count == 1) { - return InvalidArgument("Dot 
cannot have a single packed nibble argument"); - } - if (nibble_count == 2) { - if (!ShapeUtil::ElementIsIntegralWithBits(dot->operand(0)->shape(), 8)) { - return InvalidArgument( - "Packed nibble precision can only apply to 8 bit integers. LHS is " - "%s.", - dot->operand(0)->ToString()); - } - if (!ShapeUtil::ElementIsIntegralWithBits(dot->operand(1)->shape(), 8)) { - return InvalidArgument( - "Packed nibble precision can only apply to 8 bit integers. RHS is " - "%s.", - dot->operand(1)->ToString()); - } - } - } + for (int i = 0; i < sparsity.size(); ++i) { const SparsityDescriptor& descriptor = sparsity[i]; TF_RET_CHECK(descriptor.index() == 0 || descriptor.index() == 1); @@ -281,42 +261,7 @@ absl::Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { convolution->feature_group_count(), convolution->batch_group_count(), convolution->window(), convolution->convolution_dimension_numbers(), /*preferred_element_type=*/convolution->shape().element_type())); - if (auto nibble_count = - absl::c_count(convolution->precision_config().operand_precision(), - PrecisionConfig::PACKED_NIBBLE)) { - if (nibble_count == 1) { - return InvalidArgument( - "Convolution cannot have a single packed nibble argument"); - } - if (nibble_count == 2) { - if (convolution->feature_group_count() != 1) { - return InvalidArgument( - "Packed nibble precision does not support feature group count " - "%s.", - convolution->ToString()); - } - if (convolution->batch_group_count() != 1) { - return InvalidArgument( - "Packed nibble precision does not support batch group count " - "%s.", - convolution->ToString()); - } - if (!ShapeUtil::ElementIsIntegralWithBits( - convolution->operand(0)->shape(), 8)) { - return InvalidArgument( - "Packed nibble precision can only apply to 8 bit integers. LHS is " - "%s.", - convolution->operand(0)->ToString()); - } - if (!ShapeUtil::ElementIsIntegralWithBits( - convolution->operand(1)->shape(), 8)) { - return InvalidArgument( - "Packed nibble precision can only apply to 8 bit integers. 
RHS is " - "%s.", - convolution->operand(1)->ToString()); - } - } - } + return CheckShape(convolution, expected); } diff --git a/third_party/xla/xla/tests/convolution_test.cc b/third_party/xla/xla/tests/convolution_test.cc index e77a50643f11a4..f335aaf40fab61 100644 --- a/third_party/xla/xla/tests/convolution_test.cc +++ b/third_party/xla/xla/tests/convolution_test.cc @@ -1767,22 +1767,6 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01})); } -// CUDNN does not support s8->s32 convs -XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU(PackedNibbleConvolve)) { - constexpr char kHlo[] = R"( -HloModule TestModule - -ENTRY Test { - %lhs = s8[5,11,11,7] parameter(1) - %rhs = s8[3,3,7,7] parameter(0) - ROOT %convolution = s32[5,11,11,7] convolution(lhs, rhs), - window={size=3x3 pad=1_1x1_1}, - dim_labels=b01f_01io->b01f, - operand_precision={PACKED_NIBBLE,PACKED_NIBBLE} -})"; - EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0, 0})); -} - XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolveWithStride) { constexpr char kHlo[] = R"( HloModule TestModule diff --git a/third_party/xla/xla/tests/dot_operation_test.cc b/third_party/xla/xla/tests/dot_operation_test.cc index 932639ca384706..64c1bdde5d8bc9 100644 --- a/third_party/xla/xla/tests/dot_operation_test.cc +++ b/third_party/xla/xla/tests/dot_operation_test.cc @@ -2053,36 +2053,6 @@ ENTRY SmallIntegerDot { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); } -XLA_TEST_F(DotOperationTextTest, DISABLED_ON_GPU(PackedNibbleDot)) { - absl::string_view hlo_string = - R"( -HloModule SmallIntegerDot - -ENTRY SmallIntegerDot { - arg0 = s8[20,55] parameter(0) - arg1 = s8[55,20] parameter(1) - ROOT dot = s32[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={PACKED_NIBBLE, PACKED_NIBBLE} -} -)"; - - EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); -} - -XLA_TEST_F(DotOperationTextTest, UnsignedPackedNibbleDot) { - absl::string_view hlo_string = - R"( -HloModule SmallIntegerDot - -ENTRY SmallIntegerDot { - arg0 = u8[3,11,21] parameter(0) - arg1 = u8[55,21,3] parameter(1) - ROOT dot = u32[3,11,55] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={2}, rhs_contracting_dims={1}, operand_precision={PACKED_NIBBLE, PACKED_NIBBLE} -} -)"; - - EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); -} - XLA_TEST_F(DotOperationTextTest, S32Dot) { absl::string_view hlo_string = R"( diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index fbed278049cfad..d3178ba8ef87f2 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -1071,10 +1071,8 @@ message PrecisionConfig { DEFAULT = 0; HIGH = 1; HIGHEST = 2; - // Each U8/S8 value in a tensor actually represents 2 nibble values. - PACKED_NIBBLE = 3; - // Next: 4 + // Next: 3 } // The algorithm used to evaluate the instruction. From 224eae14a13a6cb66cf5f93fea8aaec62b6735b3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 31 Mar 2025 10:46:31 -0700 Subject: [PATCH 0051/1324] Support assigning a multiple values encoded as a comma separated string into a repeated field in DebugOptions: - each occurrence of field, overwrites all previous values in the field, - empty string clears all values from the field, - splitting comma separated string skips empty values. 
PiperOrigin-RevId: 742318042 --- third_party/xla/xla/pjrt/BUILD | 4 +- third_party/xla/xla/pjrt/pjrt_executable.cc | 318 ++++++++++++------ third_party/xla/xla/pjrt/pjrt_executable.h | 4 +- .../xla/xla/pjrt/pjrt_executable_test.cc | 40 ++- 4 files changed, 241 insertions(+), 125 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index c2d45169eb10fb..2ee156f7faefe0 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -295,7 +295,6 @@ cc_library( ":execute_options_proto_cc", ":pjrt_common", ":pjrt_layout", - "//xla:shape_layout", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", @@ -316,8 +315,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) @@ -330,6 +327,7 @@ xla_cc_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/client:executable_build_options", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status_matchers", ], diff --git a/third_party/xla/xla/pjrt/pjrt_executable.cc b/third_party/xla/xla/pjrt/pjrt_executable.cc index d37c69999ac0b2..81d5d7d6207b75 100644 --- a/third_party/xla/xla/pjrt/pjrt_executable.cc +++ b/third_party/xla/xla/pjrt/pjrt_executable.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include +#include #include #include #include @@ -29,6 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/client/executable_build_options.h" @@ -42,13 +44,10 @@ limitations under the License. 
#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/shape.h" -#include "xla/shape_layout.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -519,8 +518,7 @@ PjRtExecutableUtil::RunHloCostAnalysis( return ret; } -absl::StatusOr< - std::vector>> +absl::StatusOr CompileOptions::LoadEnvOptionOverrides( const google::protobuf::Map& env_option_overrides) { @@ -554,74 +552,161 @@ CompileOptions::LoadEnvOptionOverrides( return result; } +absl::Status ApplyStringOption(const tsl::protobuf::FieldDescriptor* field, + const std::string& value, + xla::DebugOptions& debug_options) { + if (field->is_repeated()) { + debug_options.GetReflection()->AddString(&debug_options, field, value); + } else { + debug_options.GetReflection()->SetString(&debug_options, field, value); + } + return absl::OkStatus(); +} + +absl::Status ApplyInt32Option(const tsl::protobuf::FieldDescriptor* field, + int32_t value, xla::DebugOptions& debug_options) { + if (field->is_repeated()) { + debug_options.GetReflection()->AddInt32(&debug_options, field, value); + } else { + debug_options.GetReflection()->SetInt32(&debug_options, field, value); + } + return absl::OkStatus(); +} +absl::Status ApplyInt64Option(const tsl::protobuf::FieldDescriptor* field, + int64_t value, xla::DebugOptions& debug_options) { + if (field->is_repeated()) { + debug_options.GetReflection()->AddInt64(&debug_options, field, value); + } else { + debug_options.GetReflection()->SetInt64(&debug_options, field, value); + } + return absl::OkStatus(); +} + +absl::Status ApplyFloatOption(const tsl::protobuf::FieldDescriptor* field, + float value, xla::DebugOptions& debug_options) { + if (field->is_repeated()) { + debug_options.GetReflection()->AddFloat(&debug_options, field, value); + } else { + debug_options.GetReflection()->SetFloat(&debug_options, field, value); + } + return absl::OkStatus(); +} + +absl::Status ApplyDoubleOption(const tsl::protobuf::FieldDescriptor* field, + double value, xla::DebugOptions& debug_options) { + if (field->is_repeated()) { + debug_options.GetReflection()->AddDouble(&debug_options, field, value); + } else { + debug_options.GetReflection()->SetDouble(&debug_options, field, value); + } + return absl::OkStatus(); +} + +absl::Status ApplyBoolOption(const tsl::protobuf::FieldDescriptor* field, + bool value, xla::DebugOptions& debug_options) { + if (field->is_repeated()) { + debug_options.GetReflection()->AddBool(&debug_options, field, value); + } else { + debug_options.GetReflection()->SetBool(&debug_options, field, value); + } + return absl::OkStatus(); +} +absl::Status ApplyEnumOption(const tsl::protobuf::FieldDescriptor* field, + int value, xla::DebugOptions& debug_options) { + if (field->is_repeated()) { + debug_options.GetReflection()->AddEnumValue(&debug_options, field, value); + } else { + debug_options.GetReflection()->SetEnumValue(&debug_options, field, value); + } + return absl::OkStatus(); +} +absl::Status ApplyEnumOption(const tsl::protobuf::FieldDescriptor* field, + const tsl::protobuf::EnumValueDescriptor* value, + xla::DebugOptions& debug_options) { + if (field->is_repeated()) { + debug_options.GetReflection()->AddEnum(&debug_options, field, value); + } else { + debug_options.GetReflection()->SetEnum(&debug_options, field, value); + } + return absl::OkStatus(); +} + absl::Status CompileOptions::ApplyOption(const std::string& key, 
                                         const OptionOverride& value) {
-  if (auto* xla_field = xla::DebugOptions::descriptor()->FindFieldByName(key)) {
-    xla::DebugOptions& debug_options =
-        *executable_build_options.mutable_debug_options();
-    const tsl::protobuf::Reflection* reflection = debug_options.GetReflection();
-    if (!reflection) {
-      return InvalidArgument(
-          "No reflection object associated with xla::DebugOptions.");
+  auto* xla_field = xla::DebugOptions::descriptor()->FindFieldByName(key);
+  if (xla_field == nullptr) {
+    return InvalidArgument("No such compile option: '%s'", key);
+  }
+  xla::DebugOptions& debug_options =
+      *executable_build_options.mutable_debug_options();
+  const tsl::protobuf::Reflection* reflection = debug_options.GetReflection();
+  if (reflection == nullptr) {
+    return InvalidArgument(
+        "No reflection object associated with xla::DebugOptions.");
+  }
+  if (xla_field->is_repeated()) {
+    debug_options.GetReflection()->ClearField(&debug_options, xla_field);
+  }
+  if (std::holds_alternative<std::string>(value)) {
+    return ApplyOptionFromString(xla_field, std::get<std::string>(value));
+  }
+  switch (xla_field->type()) {
+    case tsl::protobuf::FieldDescriptor::TYPE_BOOL: {
+      if (std::holds_alternative<bool>(value)) {
+        return ApplyBoolOption(xla_field, std::get<bool>(value), debug_options);
+      }
+      break;
     }
-    if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_BOOL &&
-        std::holds_alternative<bool>(value)) {
-      reflection->SetBool(&debug_options, xla_field, std::get<bool>(value));
-      return absl::OkStatus();
-    } else if (std::holds_alternative<std::string>(value)) {
-      TF_RETURN_IF_ERROR(
-          ApplyOptionFromString(xla_field, std::get<std::string>(value)));
-      return absl::OkStatus();
-    } else if (xla_field->type() ==
-                   tsl::protobuf::FieldDescriptor::TYPE_INT32 &&
-               std::holds_alternative<int64_t>(value)) {
-      reflection->SetInt32(&debug_options, xla_field, std::get<int64_t>(value));
-      return absl::OkStatus();
-    } else if (xla_field->type() ==
-                   tsl::protobuf::FieldDescriptor::TYPE_INT64 &&
-               std::holds_alternative<int64_t>(value)) {
-      reflection->SetInt64(&debug_options, xla_field, std::get<int64_t>(value));
-      return absl::OkStatus();
-    } else if (xla_field->type() ==
-                   tsl::protobuf::FieldDescriptor::TYPE_FLOAT &&
-               std::holds_alternative<double>(value)) {
-      reflection->SetFloat(&debug_options, xla_field, std::get<double>(value));
-      return absl::OkStatus();
-    } else if (xla_field->type() ==
-                   tsl::protobuf::FieldDescriptor::TYPE_DOUBLE &&
-               std::holds_alternative<double>(value)) {
-      reflection->SetDouble(&debug_options, xla_field, std::get<double>(value));
-      return absl::OkStatus();
-    } else if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) {
+    case tsl::protobuf::FieldDescriptor::TYPE_INT32: {
       if (std::holds_alternative<int64_t>(value)) {
-        if (xla_field->is_repeated()) {
-          reflection->AddEnumValue(&debug_options, xla_field,
-                                   std::get<int64_t>(value));
-        } else {
-          reflection->SetEnumValue(&debug_options, xla_field,
-                                   std::get<int64_t>(value));
+        int64_t int64_value = std::get<int64_t>(value);
+        if (int64_value >= std::numeric_limits<int32_t>::min() &&
+            int64_value <= std::numeric_limits<int32_t>::max()) {
+          return ApplyInt32Option(xla_field, static_cast<int32_t>(int64_value),
+                                  debug_options);
         }
-      } else {
-        auto enum_desc = xla_field->enum_type()->FindValueByName(
-            std::get<std::string>(value));
-        if (enum_desc != nullptr) {
-          if (xla_field->is_repeated()) {
-            reflection->AddEnum(&debug_options, xla_field, enum_desc);
-          } else {
-            reflection->SetEnum(&debug_options, xla_field, enum_desc);
-          }
+      }
+      break;
+    }
+    case tsl::protobuf::FieldDescriptor::TYPE_INT64: {
+      if (std::holds_alternative<int64_t>(value)) {
+        return ApplyInt64Option(xla_field, std::get<int64_t>(value),
+                                debug_options);
+      }
+      break;
+    }
+    case tsl::protobuf::FieldDescriptor::TYPE_FLOAT: {
+      if (std::holds_alternative<double>(value)) {
+        double double_value = std::get<double>(value);
+        if (double_value >= std::numeric_limits<float>::min() &&
+            double_value <= std::numeric_limits<float>::max()) {
+          return ApplyFloatOption(xla_field, static_cast<float>(double_value),
+                                  debug_options);
         }
       }
-      return absl::OkStatus();
-    } else {
-      return InvalidArgument(
-          "While setting option %s, '%s' is not a valid %s value.", key,
-          std::visit([](auto&& arg) { return absl::StrCat(arg); }, value),
-          xla_field->type_name());
+      break;
     }
-  } else {
-    return InvalidArgument("No such compile option: '%s'", key);
+    case tsl::protobuf::FieldDescriptor::TYPE_DOUBLE: {
+      if (std::holds_alternative<double>(value)) {
+        return ApplyDoubleOption(xla_field, std::get<double>(value),
+                                 debug_options);
+      }
+      break;
+    }
+    case tsl::protobuf::FieldDescriptor::TYPE_ENUM: {
+      if (std::holds_alternative<int64_t>(value)) {
+        return ApplyEnumOption(xla_field, std::get<int64_t>(value),
+                               debug_options);
+      }
+      break;
+    }
+    default:
+      break;
   }
+  return InvalidArgument(
+      "While setting option %s, '%s' is not a valid %s value.", key,
+      std::visit([](auto&& arg) { return absl::StrCat(arg); }, value),
+      xla_field->type_name());
 }
 
 absl::Status CompileOptions::ApplyAllOptionOverrides() {
@@ -631,65 +716,80 @@ absl::Status CompileOptions::ApplyAllOptionOverrides() {
   return absl::OkStatus();
 }
 
-absl::Status CompileOptions::ApplyOptionFromString(
-    const tsl::protobuf::FieldDescriptor* field, const std::string& value) {
-  xla::DebugOptions& debug_options =
-      *executable_build_options.mutable_debug_options();
-  const tsl::protobuf::Reflection* reflection = debug_options.GetReflection();
-  if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_STRING) {
-    reflection->SetString(&debug_options, field, value);
-    return absl::OkStatus();
-  } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_INT32) {
-    int int_value;
-    if (absl::SimpleAtoi(value, &int_value)) {
-      reflection->SetInt32(&debug_options, field, int_value);
-      return absl::OkStatus();
+absl::Status ApplyOptionFromSingleString(
+    const tsl::protobuf::FieldDescriptor* field, const std::string& value,
+    xla::DebugOptions& debug_options) {
+  switch (field->type()) {
+    case tsl::protobuf::FieldDescriptor::TYPE_STRING:
+      return ApplyStringOption(field, value, debug_options);
+    case tsl::protobuf::FieldDescriptor::TYPE_INT32: {
+      int32_t int_value;
+      if (absl::SimpleAtoi(value, &int_value)) {
+        return ApplyInt32Option(field, int_value, debug_options);
+      }
+      break;
     }
-  } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_INT64) {
-    int int_value;
-    if (absl::SimpleAtoi(value, &int_value)) {
-      reflection->SetInt64(&debug_options, field, int_value);
-      return absl::OkStatus();
+    case tsl::protobuf::FieldDescriptor::TYPE_INT64: {
+      int64_t int_value;
+      if (absl::SimpleAtoi(value, &int_value)) {
+        return ApplyInt64Option(field, int_value, debug_options);
+      }
+      break;
     }
-  } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_FLOAT) {
-    float float_value;
-    if (absl::SimpleAtof(value, &float_value)) {
-      reflection->SetFloat(&debug_options, field, float_value);
-      return absl::OkStatus();
+    case tsl::protobuf::FieldDescriptor::TYPE_FLOAT: {
+      float float_value;
+      if (absl::SimpleAtof(value, &float_value)) {
+        return ApplyFloatOption(field, float_value, debug_options);
+      }
+      break;
     }
-  } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_BOOL) {
-    bool bvalue = value == "True";
-    if (value == "True" || value == "False") {
-      reflection->SetBool(&debug_options, field,
bvalue); - return absl::OkStatus(); + case tsl::protobuf::FieldDescriptor::TYPE_DOUBLE: { + double double_value; + if (absl::SimpleAtod(value, &double_value)) { + return ApplyDoubleOption(field, double_value, debug_options); + } + break; } - } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) { - int int_value; - if (absl::SimpleAtoi(value, &int_value)) { - if (field->is_repeated()) { - reflection->AddEnumValue(&debug_options, field, int_value); - } else { - reflection->SetEnumValue(&debug_options, field, int_value); + case tsl::protobuf::FieldDescriptor::TYPE_BOOL: { + if (value == "True" || value == "False") { + return ApplyBoolOption(field, value == "True", debug_options); } - return absl::OkStatus(); - } else { - if (value.empty() && field->is_repeated()) { - reflection->ClearField(&debug_options, field); - return absl::OkStatus(); + break; + } + case tsl::protobuf::FieldDescriptor::TYPE_ENUM: { + int int_value; + if (absl::SimpleAtoi(value, &int_value)) { + return ApplyEnumOption(field, int_value, debug_options); } auto enum_desc = field->enum_type()->FindValueByName(value); if (enum_desc != nullptr) { - if (field->is_repeated()) { - reflection->AddEnum(&debug_options, field, enum_desc); - } else { - reflection->SetEnum(&debug_options, field, enum_desc); - } + return ApplyEnumOption(field, enum_desc, debug_options); } + break; } + default: + break; } return InvalidArgument( "While setting option %s, '%s' is not a valid %s value.", field->name(), value, field->type_name()); } +absl::Status CompileOptions::ApplyOptionFromString( + const tsl::protobuf::FieldDescriptor* field, const std::string& value) { + if (!field->is_repeated()) { + return ApplyOptionFromSingleString( + field, value, *executable_build_options.mutable_debug_options()); + } + if (value.empty()) { + return absl::OkStatus(); + } + for (const auto& v : absl::StrSplit(value, ',')) { + TF_RETURN_IF_ERROR(ApplyOptionFromSingleString( + field, std::string(v), + *executable_build_options.mutable_debug_options())); + } + return absl::OkStatus(); +} + } // namespace xla diff --git a/third_party/xla/xla/pjrt/pjrt_executable.h b/third_party/xla/xla/pjrt/pjrt_executable.h index 24cef9f8a69d2f..e4ce20ddcac26e 100644 --- a/third_party/xla/xla/pjrt/pjrt_executable.h +++ b/third_party/xla/xla/pjrt/pjrt_executable.h @@ -120,9 +120,7 @@ struct CompileOptions { absl::Status ApplyOptionFromString( const tsl::protobuf::FieldDescriptor* field, const std::string& value); - static absl::StatusOr< - std::vector>> - LoadEnvOptionOverrides( + static absl::StatusOr LoadEnvOptionOverrides( const google::protobuf::Map& env_option_overrides); diff --git a/third_party/xla/xla/pjrt/pjrt_executable_test.cc b/third_party/xla/xla/pjrt/pjrt_executable_test.cc index 72c0da6f04bda0..1e446870f78181 100644 --- a/third_party/xla/xla/pjrt/pjrt_executable_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_executable_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "xla/client/executable_build_options.h" #include "xla/pjrt/compile_options.pb.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/status_matchers.h" @@ -94,18 +95,24 @@ TEST(ExecuteOptionsTest, SendRecvNotSupported) { } TEST(ExecuteOptionsTest, ApplyOptionsCanParseStringsAndEnums) { - using OptionOverride = std::variant; - std::vector> env_override_options; - env_override_options = { + CompileOptions src; + src.env_option_overrides = { {"xla_gpu_use_runtime_fusion", std::string("True")}, {"xla_gpu_graph_min_graph_size", std::string("2")}, - {"xla_gpu_disable_async_collectives", std::string("2")}, {"xla_gpu_redzone_scratch_max_megabytes", std::string("3400")}, {"xla_gpu_auto_spmd_partitioning_memory_budget_ratio", 0.9}, - {"xla_gpu_pgle_profile_file_or_directory_path", std::string("abc")}}; - CompileOptions src; - src.env_option_overrides = env_override_options; - auto s = src.ApplyAllOptionOverrides(); + {"xla_gpu_pgle_profile_file_or_directory_path", std::string("abc")}, + // Repeated fields. + {"xla_gpu_disable_async_collectives", std::string("2,REDUCESCATTER")}, + {"xla_disable_hlo_passes", + std::string("rematerialization,something else")}, + // Repeated fields provided twice. The last one wins. + {"xla_enable_hlo_passes_only", std::string("one, two, three")}, + {"xla_enable_hlo_passes_only", std::string(",,second, , third,")}, + {"xla_gpu_enable_command_buffer", std::string("CUSTOM_CALL,COLLECTIVES")}, + {"xla_gpu_enable_command_buffer", + static_cast(DebugOptions::CUSTOM_CALL)}}; + TF_EXPECT_OK(src.ApplyAllOptionOverrides()); auto& debug_options = src.executable_build_options.debug_options(); EXPECT_EQ(debug_options.xla_gpu_use_runtime_fusion(), true); EXPECT_EQ(debug_options.xla_gpu_graph_min_graph_size(), 2); @@ -113,8 +120,21 @@ TEST(ExecuteOptionsTest, ApplyOptionsCanParseStringsAndEnums) { EXPECT_FLOAT_EQ( debug_options.xla_gpu_auto_spmd_partitioning_memory_budget_ratio(), 0.9); EXPECT_EQ(debug_options.xla_gpu_pgle_profile_file_or_directory_path(), "abc"); - EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives().size(), 1); - EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives()[0], 2); + EXPECT_THAT(debug_options.xla_gpu_disable_async_collectives(), + testing::ElementsAre(xla::DebugOptions::ALLGATHER, + xla::DebugOptions::REDUCESCATTER)); + EXPECT_THAT(debug_options.xla_disable_hlo_passes(), + testing::ElementsAre("rematerialization", "something else")); + EXPECT_THAT(debug_options.xla_enable_hlo_passes_only(), + testing::ElementsAre("", "", "second", " ", " third", "")); + EXPECT_THAT(debug_options.xla_gpu_enable_command_buffer(), + testing::ElementsAre(DebugOptions::CUSTOM_CALL)); + + // Test that repeated fields are cleared when empty string is provided. + src.env_option_overrides = { + {"xla_gpu_enable_command_buffer", std::string("")}}; + TF_EXPECT_OK(src.ApplyAllOptionOverrides()); + EXPECT_TRUE(debug_options.xla_gpu_enable_command_buffer().empty()); } TEST(CompiledMemoryStatsTest, Serialization) { From 1936090f58f4cac70e1547097195a82f5592ea52 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Mon, 31 Mar 2025 11:08:57 -0700 Subject: [PATCH 0052/1324] #sdy add a sharding rule for mhlo::CopyOp as this op remains when converting from HLO to StableHLO. 
PiperOrigin-RevId: 742326174 --- .../spmd/shardy/extensions/mhlo_extensions.cc | 12 +++++++++++- .../spmd/shardy/test/mhlo_extensions_test.mlir | 7 +++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/spmd/shardy/extensions/mhlo_extensions.cc b/third_party/xla/xla/service/spmd/shardy/extensions/mhlo_extensions.cc index 5c6adeea8da003..b0c0b3762eb390 100644 --- a/third_party/xla/xla/service/spmd/shardy/extensions/mhlo_extensions.cc +++ b/third_party/xla/xla/service/spmd/shardy/extensions/mhlo_extensions.cc @@ -39,6 +39,15 @@ namespace { using ::mlir::ArrayRef; using ::mlir::sdy::FactorType; using ::mlir::sdy::kNullDim; +using ::mlir::sdy::OpShardingRuleBuilder; + +struct CopyShardingRuleOpInterface + : public mlir::sdy::ShardingRuleOpInterface::ExternalModel< + CopyShardingRuleOpInterface, mhlo::CopyOp> { + mlir::sdy::OpShardingRuleAttr getShardingRule(mlir::Operation* op) const { + return OpShardingRuleBuilder::buildPointwise(op); + } +}; enum RaggedDotMode { // Ragged non-contracting (m): [b,m,k], [g,b,k,n], [b,g] -> [b,m,n]. @@ -87,7 +96,7 @@ struct RaggedDotShardingRuleOpInterface mode = RaggedDotMode::kNonContracting; } - mlir::sdy::OpShardingRuleBuilder builder(raggedDot); + OpShardingRuleBuilder builder(raggedDot); mlir::RankedTensorType lhsType = raggedDot.getLhs().getType(); mlir::RankedTensorType rhsType = raggedDot.getRhs().getType(); @@ -181,6 +190,7 @@ struct RaggedDotShardingRuleOpInterface void registerMhloExtensions(mlir::DialectRegistry& registry) { registry.addExtension(+[](mlir::MLIRContext* ctx, mhlo::MhloDialect*) { + mhlo::CopyOp::attachInterface(*ctx); mhlo::RaggedDotOp::attachInterface(*ctx); }); } diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_extensions_test.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_extensions_test.mlir index 73f76ed180897b..bc71b025823215 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_extensions_test.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_extensions_test.mlir @@ -1,5 +1,12 @@ // RUN: sdy_opt %s -sdy-populate-op-sharding-rules -verify-diagnostics 2>&1 | FileCheck %s +// CHECK-LABEL: func @copy +func.func @copy(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=16, j=8}> + %0 = mhlo.copy %arg0 : tensor<16x8xf32> + return %0 : tensor<16x8xf32> +} + // CHECK-LABEL: func @ragged_dot_mode_non_contracting func.func @ragged_dot_mode_non_contracting(%arg0: tensor<16x32x64xf32>, %arg1: tensor<4x16x64x8xf32>, %arg2: tensor<16x4xi32>) -> tensor<16x32x8xf32> { // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [m, i, l, k], [i, m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=4} reduction={l} need_replication={j, m}> From fde75f0fd00689b0731864a96efd81955897e9ba Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 31 Mar 2025 11:13:33 -0700 Subject: [PATCH 0053/1324] [XLA:GPU] Split out and refactor `dot_algorithms` in the Triton emitter. The idea is to have a single entry point for the emitters to emit the matmul, and this is a step towards enabling dot algorithms in the generic Triton emitter path. 
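
For intuition, the BF16xN algorithms in the new file rebuild an f32 product
from bf16 pieces. Below is a standalone scalar sketch of the splitting step,
in plain C++ for illustration only; it mirrors MaskToBF16 in the new file but
is not emitter code:

  #include <cstdint>
  #include <cstring>
  #include <iostream>

  // Truncates an f32 toward zero to the nearest bf16-representable value by
  // masking away the low 16 bits of its bit pattern.
  float MaskToBf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFF0000u;
    std::memcpy(&x, &bits, sizeof(x));
    return x;
  }

  int main() {
    float x = 0.1f;                  // Not exactly representable in bf16.
    float hi = MaskToBf16(x);        // Top 16 bits.
    float mid = MaskToBf16(x - hi);  // Next 16 bits of the remainder.
    float lo = x - hi - mid;         // Residual (rounded to bf16 in the
                                     // real emitter).
    // The pieces reconstruct x, so dot(lhs, rhs) can be expanded into
    // several bf16 dots (hi*hi, hi*mid, ...) accumulated in f32.
    std::cout << (hi + mid + lo) - x << "\n";  // Prints 0.
    return 0;
  }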
PiperOrigin-RevId: 742327760 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 55 ++ .../gpu/codegen/triton/dot_algorithms.cc | 508 ++++++++++++++++++ .../gpu/codegen/triton/dot_algorithms.h | 47 ++ .../triton/fusion_emitter_legacy_matmul.cc | 259 +-------- 4 files changed, 631 insertions(+), 238 deletions(-) create mode 100644 third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc create mode 100644 third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.h diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 521d8ce467ec59..6c46dd86bcc042 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -289,6 +289,61 @@ cc_library( ["fusion_emitter_legacy_matmul_stub.cc"], ), hdrs = ["fusion_emitter_legacy_matmul.h"], + deps = [ + ":dot_algorithms", + ":emitter_helpers", + "//xla:comparison_util", + "//xla:literal", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/codegen:emitter_loc_op_builder", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/hlo/utils:hlo_traversal", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/mlir_hlo:transformation_helpers", + "//xla/service:algorithm_util", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:matmul_indexing_utils", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:triton_fusion_analysis", + "//xla/service/gpu:triton_tiling_propagation", + "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", + "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor:device_description", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor/gpu:tma_metadata", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@triton//:TritonDialects", + ], +) + +cc_library( + name = "dot_algorithms", + srcs = ["dot_algorithms.cc"], + hdrs = ["dot_algorithms.h"], deps = [ ":emitter_helpers", "//xla:comparison_util", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc new file mode 100644 index 00000000000000..03654935803618 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc @@ -0,0 +1,508 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/gpu/codegen/triton/dot_algorithms.h"
+
+#include <cstdint>
+#include <limits>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/backends/gpu/codegen/triton/emitter_helpers.h"
+#include "xla/codegen/emitter_loc_op_builder.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_traversal.h"
+#include "xla/primitive_util.h"
+#include "xla/service/algorithm_util.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/tensor_float_32_utils.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace xla {
+namespace gpu {
+namespace triton {
+
+namespace {
+
+namespace arith = ::mlir::arith;
+namespace math = ::mlir::math;
+namespace ttir = ::mlir::triton;
+
+using ::mlir::ShapedType;
+using ::mlir::Type;
+using ::mlir::Value;
+
+Type ElementType(Value v) { return mlir::getElementTypeOrSelf(v); }
+
+// Precision-relevant configuration bits for `dot`s.
+struct PrecisionSpec {
+  PrecisionConfig::Algorithm algorithm;
+  // TODO(bchetioui): we hope to get rid of operand precisions eventually, they
+  // are currently a (XLA-wide) bridge to work with ALG_UNSET.
+  PrecisionConfig::Precision lhs_operand_precision;
+  PrecisionConfig::Precision rhs_operand_precision;
+  // Encodes `tt.dot`'s `inputPrecision` attribute.
+  ttir::InputPrecision ttir_input_precision;
+};
+
+using AlgorithmEmitter = absl::StatusOr<Value> (*)(EmitterLocOpBuilder&,
+                                                   const DotOperands&,
+                                                   const PrecisionSpec&);
+
+Value RoundToBF16(EmitterLocOpBuilder b, Value input) {
+  return Cast(b, input, b.getBF16Type());
+}
+
+// Truncates |input| of F32 type to the number representable in Bf16 toward
+// zero.
+Value MaskToBF16(EmitterLocOpBuilder& b, Value input) {
+  ShapedType input_type = mlir::dyn_cast<ShapedType>(input.getType());
+  Type input_type_as_i32 = input_type.clone(b.getI32Type());
+  Value input_as_i32 = b.create<ttir::BitcastOp>(input_type_as_i32, input);
+  Value mask = triton::CreateConst(b, b.getI32Type(), 0xFFFF0000u,
+                                   input_type.getShape())
+                   .UnwrapTensor();
+  Value high_bits =
+      b.create<arith::AndIOp>(input_type_as_i32, input_as_i32, mask);
+
+  return b.create<ttir::BitcastOp>(input_type, high_bits);
+}
+
+// If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0.
+// If rhs is +infinity, we will have:
+// +infinity * 1.0 = +infinity
+// +infinity * 0.0 = NaN
+// We would get the wrong result if we sum these partial products. Instead, we
+// must override any accumulated result if the last partial product is
+// non-finite. See b/115844437.
+Value ZeroNaNs(EmitterLocOpBuilder& b, Value input) {
+  Value positive_inf =
+      CreateConst(b, b.getF32Type(), std::numeric_limits<float>::infinity(),
+                  mlir::cast<ShapedType>(input.getType()).getShape())
+          .UnwrapTensor();
+  Value abs_input = b.create<math::AbsFOp>(input);
+  Value is_finite = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
+                                            positive_inf, abs_input);
+  return b.create<arith::SelectOp>(is_finite, input, ZerosLike(b, input));
+}
+
+absl::Status ExpectType(Value v, Type expected_type) {
+  if (ElementType(v) != expected_type) {
+    std::string expected_type_str, actual_type_str;
+    {
+      llvm::raw_string_ostream os_expected(expected_type_str);
+      llvm::raw_string_ostream os_actual(actual_type_str);
+      expected_type.print(os_expected);
+      ElementType(v).print(os_actual);
+    }
+    return absl::FailedPreconditionError(absl::StrCat(
+        "Expected type ", expected_type_str, " but got ", actual_type_str));
+  }
+  return absl::OkStatus();
+}
+
+std::vector<Value> SplitF32(EmitterLocOpBuilder b, Value input,
+                            int split_count) {
+  std::vector<Value> split_inputs;
+  split_inputs.reserve(split_count);
+  for (int i = 0; i < split_count; ++i) {
+    if (i != split_count - 1) {
+      Value masked = MaskToBF16(b, input);
+      input = b.create<arith::SubFOp>(input, masked);
+      split_inputs.push_back(RoundToBF16(b, masked));
+    } else {
+      split_inputs.push_back(RoundToBF16(b, input));
+    }
+  }
+  return split_inputs;
+}
+
+Value IEEEDot(EmitterLocOpBuilder b, Value lhs, Value rhs, Value acc) {
+  return b.create<ttir::DotOp>(lhs, rhs, acc,
+                               /*inputPrecision=*/ttir::InputPrecision::IEEE,
+                               /*maxNumImpreciseAcc=*/0);
+}
+
+// Leverages BF16 datatype for F32 matmul computation. It follows the guidance
+// from https://arxiv.org/pdf/1904.06376.pdf.
+absl::StatusOr<Value> EmitBF16x9Matmul(EmitterLocOpBuilder& b,
+                                       const DotOperands& dot_operands,
+                                       const PrecisionSpec& precision_spec) {
+  Type f32 = b.getF32Type();
+  TF_RETURN_IF_ERROR(ExpectType(dot_operands.lhs, f32));
+  TF_RETURN_IF_ERROR(ExpectType(dot_operands.rhs, f32));
+  TF_RETURN_IF_ERROR(ExpectType(dot_operands.accumulator, f32));
+
+  std::vector<Value> lhs_parts = SplitF32(b, dot_operands.lhs, 3);
+  std::vector<Value> rhs_parts = SplitF32(b, dot_operands.rhs, 3);
+
+  Value local_acc = triton::ZerosLike(b, dot_operands.accumulator);
+  Value result;
+
+  // low @ low + low @ mid + mid @ low
+  result = IEEEDot(b, lhs_parts[2], rhs_parts[2], local_acc);
+  result = IEEEDot(b, lhs_parts[1], rhs_parts[2], result);
+  result = IEEEDot(b, lhs_parts[2], rhs_parts[1], result);
+
+  // mid @ mid
+  result = IEEEDot(b, lhs_parts[1], rhs_parts[1], result);
+
+  // high @ low + low @ high
+  result = IEEEDot(b, lhs_parts[2], rhs_parts[0], result);
+  result = IEEEDot(b, lhs_parts[0], rhs_parts[2], result);
+
+  // high @ mid + mid @ high
+  result = IEEEDot(b, lhs_parts[1], rhs_parts[0], result);
+  result = IEEEDot(b, lhs_parts[0], rhs_parts[1], result);
+
+  result = ZeroNaNs(b, result);
+  result = IEEEDot(b, lhs_parts[0], rhs_parts[0], result);
+  result = b.create<arith::AddFOp>(dot_operands.accumulator, result);
+  return result;
+}
+
+// Leverages BF16 datatype for F32 matmul computation. It follows the guidance
+// from https://arxiv.org/pdf/1904.06376.pdf.
+absl::StatusOr<Value> EmitBF16x6Matmul(EmitterLocOpBuilder& b,
+                                       const DotOperands& dot_operands,
+                                       const PrecisionSpec& precision_spec) {
+  Type f32 = b.getF32Type();
+  TF_RETURN_IF_ERROR(ExpectType(dot_operands.lhs, f32));
+  TF_RETURN_IF_ERROR(ExpectType(dot_operands.rhs, f32));
+  TF_RETURN_IF_ERROR(ExpectType(dot_operands.accumulator, f32));
+
+  std::vector<Value> lhs_parts = SplitF32(b, dot_operands.lhs, 3);
+  std::vector<Value> rhs_parts = SplitF32(b, dot_operands.rhs, 3);
+
+  Value local_acc = triton::ZerosLike(b, dot_operands.accumulator);
+  Value result = IEEEDot(b, lhs_parts[1], rhs_parts[1], local_acc);
+  // high @ low + low @ high
+  result = IEEEDot(b, lhs_parts[2], rhs_parts[0], result);
+  result = IEEEDot(b, lhs_parts[0], rhs_parts[2], result);
+
+  // high @ mid + mid @ high
+  result = IEEEDot(b, lhs_parts[1], rhs_parts[0], result);
+  result = IEEEDot(b, lhs_parts[0], rhs_parts[1], result);
+
+  result = ZeroNaNs(b, result);
+  result = IEEEDot(b, lhs_parts[0], rhs_parts[0], result);
+  result = b.create<arith::AddFOp>(dot_operands.accumulator, result);
+  return result;
+}
+
+// Compute F32 matmul with 3 BF16 dots. It is less accurate than
+// EmitBF16x6Matmul.
+absl::StatusOr<Value> EmitBF16x3Matmul(EmitterLocOpBuilder& b,
+                                       const DotOperands& dot_operands,
+                                       const PrecisionSpec& precision_spec) {
+  Type f32 = b.getF32Type();
+  TF_RETURN_IF_ERROR(ExpectType(dot_operands.lhs, f32));
+  TF_RETURN_IF_ERROR(ExpectType(dot_operands.rhs, f32));
+  TF_RETURN_IF_ERROR(ExpectType(dot_operands.accumulator, f32));
+
+  std::vector<Value> lhs_bf16 = SplitF32(b, dot_operands.lhs, 2);
+  std::vector<Value> rhs_bf16 = SplitF32(b, dot_operands.rhs, 2);
+
+  Value local_acc = triton::ZerosLike(b, dot_operands.accumulator);
+  Value result = IEEEDot(b, lhs_bf16[1], rhs_bf16[0], local_acc);
+  result = IEEEDot(b, lhs_bf16[0], rhs_bf16[1], result);
+  result = ZeroNaNs(b, result);
+  result = IEEEDot(b, lhs_bf16[0], rhs_bf16[0], result);
+  result = b.create<arith::AddFOp>(dot_operands.accumulator, result);
+  return result;
+}
+
+bool IsTf32Allowed(const HloDotInstruction& dot) {
+  auto precision_config = dot.precision_config();
+  if (precision_config.algorithm() == PrecisionConfig::ALG_UNSET) {
+    return tsl::tensor_float_32_execution_enabled() &&
+           precision_config.operand_precision(0) == PrecisionConfig::DEFAULT &&
+           precision_config.operand_precision(1) == PrecisionConfig::DEFAULT;
+  }
+  return algorithm_util::HasTf32InputType(precision_config.algorithm());
+}
+
+bool DotDependsOnConvertFromByteWideOrSmallerTypeToF32(
+    const HloDotInstruction* dot) {
+  return HloBfsAnyOf({dot}, [&](const HloInstruction* node) {
+    if (node->opcode() != HloOpcode::kConvert) {
+      return false;
+    }
+    int in_width =
+        primitive_util::BitWidth(node->operand(0)->shape().element_type());
+    return in_width <= 8 && node->shape().element_type() == F32;
+  });
+}
+
+ttir::InputPrecision InferDotPrecision(const HloDotInstruction& dot) {
+  if (dot.precision_config().algorithm() ==
+      PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) {
+    return ttir::InputPrecision::TF32x3;
+  }
+
+  bool use_tf32 = IsTf32Allowed(dot);
+  // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32.
+  use_tf32 =
+      use_tf32 && !DotDependsOnConvertFromByteWideOrSmallerTypeToF32(&dot);
+
+  return use_tf32 ? ttir::InputPrecision::TF32 : ttir::InputPrecision::IEEE;
+}
+
+Type GetAlgUnsetAccumulatorType(EmitterLocOpBuilder& b,
+                                const DotOperands& dot_operands) {
+  Type lhs_type = ElementType(dot_operands.lhs);
+  Type rhs_type = ElementType(dot_operands.rhs);
+  Type accumulator_type = ElementType(dot_operands.accumulator);
+
+  // The code below assumes that lhs and rhs have the same type. However
+  // this may not always be the case with f8 matmuls, e.g. e4m3×e5m2 is
+  // supported at the hardware level. NVIDIA GPUs currently only support f32
+  // accumulators for such matmuls.
+  if (lhs_type.isFloat(8) && rhs_type.isFloat(8)) {
+    return b.getF32Type();
+  }
+
+  CHECK(lhs_type == rhs_type);
+
+  // Currently allowing 8x8-bit ints -> i32.
+  if (lhs_type == b.getIntegerType(8) && accumulator_type.isInteger(32)) {
+    return b.getI32Type();
+  }
+  return (accumulator_type.isF64() && lhs_type.isF64()) ? b.getF64Type()
+                                                        : b.getF32Type();
+}
+
+absl::StatusOr<Value> EmitDotAlgUnset(EmitterLocOpBuilder& b,
+                                      const DotOperands& dot_operands,
+                                      const PrecisionSpec& precision_spec) {
+  // Execute matrix multiplication of input tiles and pass the accumulator.
+  // TODO(manany): Should be looked into once we enable Hopper workloads.
+  // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a
+  // lower precision than the output type. The change was introduced here:
+  // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a
+  Value lhs = dot_operands.lhs;
+  Value rhs = dot_operands.rhs;
+  Value acc = dot_operands.accumulator;
+
+  Type expected_acc_type = GetAlgUnsetAccumulatorType(b, dot_operands);
+  if (ElementType(acc) != expected_acc_type) {
+    return absl::FailedPreconditionError(
+        "Given accumulator type for unset dot does not match expected type.");
+  }
+
+  int max_num_imprecise_acc = 0;
+  if (ElementType(lhs).isFloat(8) || ElementType(rhs).isFloat(8)) {
+    // For fp8 dots, disable accumulator promotion to mimic cuBLAS. It may make
+    // sense to enable frequent accumulator promotion at higher matmul
+    // precisions set in the config.
+    max_num_imprecise_acc = std::numeric_limits<int>::max();
+  }
+
+  return b.create<ttir::DotOp>(
+      lhs, rhs, acc,
+      /*inputPrecision=*/precision_spec.ttir_input_precision,
+      /*maxNumImpreciseAcc=*/max_num_imprecise_acc);
+}
+
+absl::StatusOr<Value> EmitRegularDot(EmitterLocOpBuilder& b,
+                                     const DotOperands& dot_operands,
+                                     const PrecisionSpec& precision_spec) {
+  Value lhs = dot_operands.lhs;
+  Value rhs = dot_operands.rhs;
+
+  int max_num_imprecise_acc = 0;
+  if (ElementType(lhs).isFloat(8) || ElementType(rhs).isFloat(8)) {
+    // For fp8 dots, disable accumulator promotion to mimic cuBLAS. It may make
+    // sense to enable frequent accumulator promotion at higher matmul
+    // precisions set in the config.
+    max_num_imprecise_acc = std::numeric_limits<int>::max();
+  }
+
+  // Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32.
+  // TODO(bchetioui): abstract this.
+  if (precision_spec.algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32) {
+    if (ElementType(lhs).isF32()) {
+      lhs = Cast(b, lhs, b.getBF16Type());
+    }
+
+    if (ElementType(rhs).isF32()) {
+      rhs = Cast(b, rhs, b.getBF16Type());
+    }
+  }
+
+  return b.create<ttir::DotOp>(
+      lhs, rhs, dot_operands.accumulator,
+      /*inputPrecision=*/precision_spec.ttir_input_precision,
+      /*maxNumImpreciseAcc=*/max_num_imprecise_acc);
+}
+
+}  // namespace
+
+absl::StatusOr<Value> EmitSingleTileDot(EmitterLocOpBuilder& b,
+                                        const HloDotInstruction& dot,
+                                        DotOperands dot_operands) {
+  AlgorithmEmitter algorithm_emitter = nullptr;
+  PrecisionSpec precision_spec{dot.precision_config().algorithm(),
+                               dot.precision_config().operand_precision(0),
+                               dot.precision_config().operand_precision(1),
+                               InferDotPrecision(dot)};
+
+  // Algorithms mostly expect that their input and output types correspond to
+  // what the algorithm describes. This is not always the case though, e.g.
+  // for BF16_BF16_F32_X9, working from inputs casted to BF16 makes no sense;
+  // this algorithm instead expects F32 inputs, and performs splits into BF16
+  // sub-values under the hood.
+  std::optional<Type> force_operands_type;
+  std::optional<Type> force_accumulator_type;
+
+  PrecisionConfig::Algorithm algorithm = precision_spec.algorithm;
+
+  Type bf16 = b.getBF16Type();
+  Type f16 = b.getF16Type();
+  Type f32 = b.getF32Type();
+  Type f64 = b.getF64Type();
+
+  switch (algorithm) {
+    case PrecisionConfig::ALG_UNSET:
+      algorithm_emitter = EmitDotAlgUnset;
+      break;
+    case PrecisionConfig::ALG_DOT_F16_F16_F16:
+      force_operands_type = f16;
+      force_accumulator_type = f16;
+      algorithm_emitter = EmitRegularDot;
+      break;
+    case PrecisionConfig::ALG_DOT_F32_F32_F32:
+      force_operands_type = f32;
+      force_accumulator_type = f32;
+      algorithm_emitter = EmitRegularDot;
+      break;
+    case PrecisionConfig::ALG_DOT_F64_F64_F64:
+      force_operands_type = f64;
+      force_accumulator_type = f64;
+      algorithm_emitter = EmitRegularDot;
+      break;
+    case PrecisionConfig::ALG_DOT_F16_F16_F32:
+      force_operands_type = f16;
+      force_accumulator_type = f32;
+      algorithm_emitter = EmitRegularDot;
+      break;
+    case PrecisionConfig::ALG_DOT_BF16_BF16_BF16:
+      force_operands_type = bf16;
+      force_accumulator_type = bf16;
+      algorithm_emitter = EmitRegularDot;
+      break;
+    case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
+      force_operands_type = bf16;
+      force_accumulator_type = f32;
+      algorithm_emitter = EmitRegularDot;
+      break;
+    case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
+      force_operands_type = f32;  // This is not a typo.
+      force_accumulator_type = f32;
+      algorithm_emitter = EmitBF16x3Matmul;
+      break;
+    case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
+      force_operands_type = f32;  // This is not a typo.
+      force_accumulator_type = f32;
+      algorithm_emitter = EmitBF16x6Matmul;
+      break;
+    case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
+      // TODO(bchetioui): pass around tf32 matmul config.
+      force_operands_type = f32;
+      force_accumulator_type = f32;
+      // TODO(bchetioui): this should be factored out of EmitRegularDot.
+      algorithm_emitter = EmitRegularDot;
+      break;
+    case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
+      // TODO(bchetioui): pass around tf32 matmul config.
+      force_operands_type = f32;
+      force_accumulator_type = f32;
+      // TODO(bchetioui): this should be factored out of EmitRegularDot.
+      algorithm_emitter = EmitRegularDot;
+      break;
+    case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9:
+      force_operands_type = f32;  // This is not a typo.
+ force_accumulator_type = f32; + algorithm_emitter = EmitBF16x9Matmul; + break; + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: + // TODO(bchetioui): How to enforce "any f8"? + force_accumulator_type = f32; + break; + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: + // TODO(bchetioui): How to enforce "any f8"? + force_accumulator_type = f32; + break; + default: + break; + } + + // Couldn't find an algorithm emitter for this algorithm. Raise an error. + if (algorithm_emitter == nullptr) { + return absl::UnimplementedError( + absl::StrCat("This algorithm is not supported yet: ", + PrecisionConfig::Algorithm_Name(algorithm))); + } + + if (force_operands_type.has_value()) { + if (ElementType(dot_operands.lhs) != *force_operands_type) { + dot_operands.lhs = Cast(b, dot_operands.lhs, *force_operands_type); + } + + if (ElementType(dot_operands.rhs) != *force_operands_type) { + dot_operands.rhs = Cast(b, dot_operands.rhs, *force_operands_type); + } + } + + if (force_accumulator_type.has_value()) { + if (ElementType(dot_operands.accumulator) != *force_accumulator_type) { + dot_operands.accumulator = + Cast(b, dot_operands.accumulator, *force_accumulator_type); + } + } + + TF_ASSIGN_OR_RETURN(Value result, + algorithm_emitter(b, dot_operands, precision_spec)); + + // TODO(b/393299275): once we've moved on from the legacy emitter, we should + // make sure that this accumulator type is equal to the one derived here. + Type outer_accumulator_type = ElementType(dot_operands.accumulator); + if (ElementType(result) != outer_accumulator_type) { + result = Cast(b, result, outer_accumulator_type); + } + + return result; +} + +} // namespace triton +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.h b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.h new file mode 100644 index 00000000000000..2ee2bd2c01c5b3 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.h @@ -0,0 +1,47 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_DOT_ALGORITHMS_H_ +#define XLA_BACKENDS_GPU_CODEGEN_TRITON_DOT_ALGORITHMS_H_ + +#include "absl/status/statusor.h" +#include "mlir/IR/Value.h" +#include "xla/codegen/emitter_loc_op_builder.h" +#include "xla/hlo/ir/hlo_instructions.h" + +namespace xla { +namespace gpu { +namespace triton { + +// Carries named `Value`s corresponding to `dot` operands. This includes an +// accumulator. +struct DotOperands { + ::mlir::Value lhs; + ::mlir::Value rhs; + ::mlir::Value accumulator; +}; + +// Emits a single-tile dot, considering the given `dot` instruction's algorithm +// and operand precisions. Raises an `UnimplementedError` if the algorithm is +// not supported. 
+absl::StatusOr<::mlir::Value> EmitSingleTileDot(EmitterLocOpBuilder& b, + const HloDotInstruction& dot, + DotOperands dot_operands); + +} // namespace triton +} // namespace gpu +} // namespace xla + +#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_DOT_ALGORITHMS_H_ diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc index 63f99db6e840f9..2e590bc77f8d06 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -39,6 +38,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -52,20 +52,19 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/codegen/triton/dot_algorithms.h" #include "xla/backends/gpu/codegen/triton/emitter_helpers.h" -#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/hlo/utils/hlo_traversal.h" #include "xla/layout.h" -#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" @@ -87,14 +86,9 @@ limitations under the License. 
#include "xla/stream_executor/gpu/tma_metadata.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/tensor_float_32_utils.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -701,10 +695,6 @@ ma::ConstantOp Cst64(EmitterLocOpBuilder b, int64_t v) { return CreateConst(b, b.getI64Type(), v); } -Value RoundToBF16(EmitterLocOpBuilder b, Value input) { - return triton::Cast(b, input, b.getBF16Type()); -}; - /*static*/ absl::StatusOr MatMulDims::Create( const TritonGemmConfig& config, const HloDotInstruction& dot, const TritonFusionAnalysis& analysis) { @@ -891,11 +881,9 @@ absl::Status UncompilableMatmul(absl::string_view explanation) { return s; } -bool IsFp8Matmul(const HloDotInstruction* dot_instr) { - return absl::c_all_of(std::array{0, 1}, [&](int idx) { - return primitive_util::IsF8Type( - dot_instr->operand(idx)->shape().element_type()); - }); +bool IsFp8Matmul(const HloDotInstruction* dot) { + return primitive_util::IsF8Type(dot->operand(0)->shape().element_type()) && + primitive_util::IsF8Type(dot->operand(1)->shape().element_type()); } class MatMulEmitterHelper { @@ -1579,173 +1567,6 @@ ConstHloInstructionSet ScopeInputs(const TritonFusionAnalysis& analysis, return result; } -// Truncates |input| of F32 type to the number representable in Bf16 toward -// zero. -Value MaskToBF16(EmitterLocOpBuilder& b, Value input) { - ShapedType input_type = mlir::dyn_cast(input.getType()); - Type input_type_as_i32 = input_type.clone(b.getI32Type()); - Value input_as_i32 = b.create(input_type_as_i32, input); - Value mask = CreateConst(b, b.getI32Type(), 0xFFFF0000u, - input_type.getShape()); - Value high_bits = b.create(input_type_as_i32, input_as_i32, mask); - - return b.create(input_type, high_bits); -} - -// If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. -// If rhs is +infinity, we will have: -// +infinity * 1.0 = +infinity -// +infinity * 0.0 = NaN -// We would get the wrong result if we sum these partial products. Instead, we -// must override any accumulated result if the last partial product is -// non-finite. See b/115844437. 
-Value ZeroNaNs(EmitterLocOpBuilder& b, Value input) { - Value positive_inf = CreateConst( - b, b.getF32Type(), std::numeric_limits::infinity(), - mlir::cast(input.getType()).getShape()); - Value abs_input = b.create(input); - Value is_finite = - b.create(ma::CmpFPredicate::OGT, positive_inf, abs_input); - return b.create(is_finite, input, ZerosLike(b, input)); -} - -absl::Status CheckF32Type(EmitterLocOpBuilder& b, Value lhs, Value rhs, - Value acc) { - Type f32 = b.getF32Type(); - TF_RET_CHECK(mlir::cast(lhs.getType()).getElementType() == f32); - TF_RET_CHECK(mlir::cast(rhs.getType()).getElementType() == f32); - TF_RET_CHECK(mlir::cast(acc.getType()).getElementType() == f32); - return absl::OkStatus(); -} - -std::vector SplitF32(EmitterLocOpBuilder b, Value input, - int split_count) { - std::vector split_inputs; - split_inputs.reserve(split_count); - for (int i = 0; i < split_count; ++i) { - if (i != split_count - 1) { - Value masked = MaskToBF16(b, input); - input = b.create(input, masked); - split_inputs.push_back(RoundToBF16(b, masked)); - } else { - split_inputs.push_back(RoundToBF16(b, input)); - } - } - return split_inputs; -} - -Value IEEEDot(EmitterLocOpBuilder b, Value lhs, Value rhs, Value acc) { - return b.create(lhs, rhs, acc, - /*inputPrecision=*/mt::InputPrecision::IEEE, - /*maxNumImpreciseAcc=*/0); -} - -// Leverages BF16 datatype for F32 matmul computation. It follows the guidance -// from https://arxiv.org/pdf/1904.06376.pdf. -Value EmitBF16x9Matmul(EmitterLocOpBuilder& b, Value lhs, Value rhs, - Value acc) { - std::vector lhs_parts = SplitF32(b, lhs, 3); - std::vector rhs_parts = SplitF32(b, rhs, 3); - - Value local_acc = ZerosLike(b, acc); - Value result; - - // low @ low + low @ mid + mid @ low - result = IEEEDot(b, lhs_parts[2], rhs_parts[2], local_acc); - result = IEEEDot(b, lhs_parts[1], rhs_parts[2], result); - result = IEEEDot(b, lhs_parts[2], rhs_parts[1], result); - - // mid @ mid - result = IEEEDot(b, lhs_parts[1], rhs_parts[1], result); - - // high @ low + low @ high - result = IEEEDot(b, lhs_parts[2], rhs_parts[0], result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[2], result); - - // high @ mid + mid @ high - result = IEEEDot(b, lhs_parts[1], rhs_parts[0], result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[1], result); - - result = ZeroNaNs(b, result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[0], result); - result = b.create(acc, result); - return result; -} - -// Leverages BF16 datatype for F32 matmul computation. It follows the guidance -// from https://arxiv.org/pdf/1904.06376.pdf. -Value EmitBF16x6Matmul(EmitterLocOpBuilder& b, Value lhs, Value rhs, - Value acc) { - std::vector lhs_parts = SplitF32(b, lhs, 3); - std::vector rhs_parts = SplitF32(b, rhs, 3); - - Value local_acc = ZerosLike(b, acc); - Value result = IEEEDot(b, lhs_parts[1], rhs_parts[1], local_acc); - // high @ low + low @ high - result = IEEEDot(b, lhs_parts[2], rhs_parts[0], result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[2], result); - - // high @ mid + mid @ high - result = IEEEDot(b, lhs_parts[1], rhs_parts[0], result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[1], result); - - result = ZeroNaNs(b, result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[0], result); - result = b.create(acc, result); - return result; -} - -// Compute F32 matmul with 3 BF16 dots. It is less accurate than -// EmitBF16x6Matmul. 
-Value EmitBF16x3Matmul(EmitterLocOpBuilder& b, Value lhs, Value rhs, - Value acc) { - std::vector lhs_bf16 = SplitF32(b, lhs, 2); - std::vector rhs_bf16 = SplitF32(b, rhs, 2); - - Value local_acc = ZerosLike(b, acc); - Value result = IEEEDot(b, lhs_bf16[1], rhs_bf16[0], local_acc); - result = IEEEDot(b, lhs_bf16[0], rhs_bf16[1], result); - result = ZeroNaNs(b, result); - result = IEEEDot(b, lhs_bf16[0], rhs_bf16[0], result); - result = b.create(acc, result); - return result; -} - -bool IsTf32Allowed(const HloDotInstruction* dot_instr) { - const PrecisionConfig::Algorithm algorithm = - dot_instr->precision_config().algorithm(); - - if (algorithm == PrecisionConfig::ALG_UNSET) { - return tsl::tensor_float_32_execution_enabled() && - absl::c_none_of(dot_instr->precision_config().operand_precision(), - [](const int precision) { - return precision != PrecisionConfig::DEFAULT; - }); - } - - return algorithm_util::HasTf32InputType(algorithm); -} - -mt::InputPrecision InferDotPrecision(const HloDotInstruction* dot_instr) { - auto algorithm = dot_instr->precision_config().algorithm(); - if (algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) { - return mt::InputPrecision::TF32x3; - } - // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. - bool is_unsupported_bitwidth = - HloBfsAnyOf({dot_instr}, [&](const HloInstruction* node) { - if (node->opcode() != HloOpcode::kConvert) { - return false; - } - int in_width = - primitive_util::BitWidth(node->operand(0)->shape().element_type()); - return in_width <= 8 && node->shape().element_type() == F32; - }); - - return IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth - ? mt::InputPrecision::TF32 - : mt::InputPrecision::IEEE; -} // This is a heuristic that serves as a proxy for register usage and code size. // @@ -2041,36 +1862,6 @@ Value EmitMaskOnInput(EmitterLocOpBuilder& b, return if_op.getResult(0); } -Value EmitRegularMatmul(EmitterLocOpBuilder& b, Value lhs, Value rhs, Value acc, - const HloDotInstruction* dot_instr) { - // Execute matrix multiplication of input tiles and pass the accumulator. - // TODO(manany): Should be looked into once we enable Hopper workloads. - // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a - // lower precision than the output type. The change was introduced here: - // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a - auto dot_precision = InferDotPrecision(dot_instr); - - // Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32. - if (dot_instr->precision_config().algorithm() == - PrecisionConfig::ALG_DOT_BF16_BF16_F32) { - if (dot_instr->operand(0)->shape().element_type() == F32) { - lhs = triton::Cast(b, lhs, b.getBF16Type()); - } - if (dot_instr->operand(1)->shape().element_type() == F32) { - rhs = triton::Cast(b, rhs, b.getBF16Type()); - } - } - - // For fp8 matmuls, disable accumulator promotion, as it's what cublas - // does. It may make sense to enable frequent accumulator promotion at - // higher matmul precisions set in the config. - int max_num_imprecise_acc = - IsFp8Matmul(dot_instr) ? std::numeric_limits::max() : 0; - return b.create(lhs, rhs, acc, - /*inputPrecision=*/dot_precision, - /*maxNumImpreciseAcc=*/max_num_imprecise_acc); -} - absl::StatusOr GetTritonGemmConfig( const HloFusionInstruction* fusion) { auto backend_config = @@ -2099,11 +1890,12 @@ Type GetIndexType(EmitterLocOpBuilder& b, const HloDotInstruction& dot_instr, return b.getIntegerType(use_64bit_indexing ? 
64 : 32); } -void EmitForLoopBody(EmitterLocOpBuilder& b, MatMulEmitterHelper& emitter, - const Scopes& scopes, const HloDotInstruction* dot_instr, - const MatMulDims& dims, - const llvm::SmallVector& inputs, Value ki, - ValueRange iter_args) { +absl::Status EmitForLoopBody(EmitterLocOpBuilder& b, + MatMulEmitterHelper& emitter, const Scopes& scopes, + const HloDotInstruction* dot_instr, + const MatMulDims& dims, + const llvm::SmallVector& inputs, + Value ki, ValueRange iter_args) { SmallVector args_for_yield; std::array, 3> values; @@ -2140,25 +1932,15 @@ void EmitForLoopBody(EmitterLocOpBuilder& b, MatMulEmitterHelper& emitter, // (i.e. zeroed out), so the padded metadata can hold any values. } - Value acc_next; - auto algorithm = dot_instr->precision_config().algorithm(); - - if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9) { - TF_CHECK_OK(CheckF32Type(b, dot_lhs, dot_rhs, iter_args.back())); - acc_next = EmitBF16x9Matmul(b, dot_lhs, dot_rhs, iter_args.back()); - } else if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6) { - TF_CHECK_OK(CheckF32Type(b, dot_lhs, dot_rhs, iter_args.back())); - acc_next = EmitBF16x6Matmul(b, dot_lhs, dot_rhs, iter_args.back()); - } else if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3) { - TF_CHECK_OK(CheckF32Type(b, dot_lhs, dot_rhs, iter_args.back())); - acc_next = EmitBF16x3Matmul(b, dot_lhs, dot_rhs, iter_args.back()); - } else { - acc_next = - EmitRegularMatmul(b, dot_lhs, dot_rhs, iter_args.back(), dot_instr); - } + TF_ASSIGN_OR_RETURN( + Value acc_next, + triton::EmitSingleTileDot( + b, *dot_instr, + triton::DotOperands{dot_lhs, dot_rhs, iter_args.back()})); args_for_yield.push_back(acc_next); - b.create(args_for_yield); + + return absl::OkStatus(); } } // namespace @@ -2238,7 +2020,8 @@ absl::StatusOr> EmitMatMul( auto body_builder_callback = [&](mlir::OpBuilder&, mlir::Location, Value ki, ValueRange iter_args) -> void { - EmitForLoopBody(b, emitter, scopes, dot_instr, dims, inputs, ki, iter_args); + CHECK_OK(EmitForLoopBody(b, emitter, scopes, dot_instr, dims, inputs, ki, + iter_args)); }; iter_args.push_back(accumulator_init); From 59f2d850d14539a80060544a81eedcd3ed3b0067 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Mon, 31 Mar 2025 13:02:28 -0700 Subject: [PATCH 0054/1324] Implement `Compile()` and `DeserializeExecutable()` that return an unloaded executable for PJRT stream executor. Also refactor `LoadSerializedExecutable()` and `Load()` accordingly. 
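For context, a minimal sketch of the end-to-end flow these APIs enable. This is illustrative only, not part of the change: `client`, `computation`, and `options` are assumed to exist in the caller, and error handling is elided.

    // Sketch only: `client` is a PjRtStreamExecutorClient, `computation` is a
    // valid XlaComputation, `options` is a CompileOptions.
    // Compile() now returns an *unloaded* PjRtExecutable...
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> unloaded,
                        client->Compile(computation, options));
    TF_ASSIGN_OR_RETURN(std::string serialized,
                        unloaded->SerializeExecutable());

    // ...which round-trips through DeserializeExecutable() and becomes
    // runnable only after an explicit Load().
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtExecutable> deserialized,
        client->DeserializeExecutable(serialized,
                                      /*compile_options=*/std::nullopt));
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> loaded,
                        client->Load(std::move(deserialized), LoadOptions()));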
PiperOrigin-RevId: 742367054 --- third_party/xla/xla/client/local_client.cc | 25 +- third_party/xla/xla/client/local_client.h | 9 + third_party/xla/xla/pjrt/BUILD | 5 + .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 34 --- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 10 - .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 5 +- .../xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc | 12 +- .../pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc | 9 +- .../xla/pjrt/pjrt_stream_executor_client.cc | 232 ++++++++++++------ .../xla/pjrt/pjrt_stream_executor_client.h | 27 +- .../xla/pjrt/stream_executor_executable.cc | 57 +++++ .../xla/xla/pjrt/stream_executor_executable.h | 75 ++++-- .../xla/xla/service/buffer_assignment.cc | 5 +- 13 files changed, 352 insertions(+), 153 deletions(-) diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index 2a1a6d74448b10..fd1bbae81e4503 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -473,17 +473,30 @@ LocalClient::CompileAheadOfTime( absl::StatusOr> LocalClient::Load( const std::string& serialized_aot_result, const ExecutableBuildOptions& options) { - TF_ASSIGN_OR_RETURN(ExecutableBuildOptions updated_options, - UpdateBuildOptions(options, default_device_ordinal())); - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend().stream_executor(updated_options.device_ordinal())); - TF_ASSIGN_OR_RETURN(Compiler * compiler, Compiler::GetForPlatform(platform())); TF_ASSIGN_OR_RETURN( std::unique_ptr aot_result, compiler->LoadAotCompilationResult(serialized_aot_result)); + return LoadInternal(std::move(aot_result), compiler, options); +} + +absl::StatusOr> LocalClient::Load( + std::unique_ptr aot_result, + const ExecutableBuildOptions& options) { + TF_ASSIGN_OR_RETURN(Compiler * compiler, + Compiler::GetForPlatform(platform())); + return LoadInternal(std::move(aot_result), compiler, options); +} + +absl::StatusOr> LocalClient::LoadInternal( + std::unique_ptr aot_result, Compiler* compiler, + const ExecutableBuildOptions& options) { + TF_ASSIGN_OR_RETURN(ExecutableBuildOptions updated_options, + UpdateBuildOptions(options, default_device_ordinal())); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + backend().stream_executor(updated_options.device_ordinal())); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index c9ee317bc42e5a..c687766fcc37b8 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -174,6 +174,11 @@ class LocalClient : public Client { const std::string& serialized_aot_result, const ExecutableBuildOptions& options); + // Variant of `Load()` that accepts an AotCompilationResult. + absl::StatusOr> Load( + std::unique_ptr aot_result, + const ExecutableBuildOptions& options); + // Copy the literal data to the device with the given ordinal and return as a // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. 
If null, the default memory allocator for the @@ -244,6 +249,10 @@ class LocalClient : public Client { private: LocalService* local_service_; + + absl::StatusOr> LoadInternal( + std::unique_ptr aot_result, Compiler* compiler, + const ExecutableBuildOptions& options); }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 2ee156f7faefe0..b03ab7cb0020d5 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -483,9 +483,13 @@ cc_library( srcs = ["stream_executor_executable.cc"], hdrs = ["stream_executor_executable.h"], deps = [ + ":host_memory_spaces", ":pjrt_common", ":pjrt_executable", ":stream_executor_executable_proto_cc", + "//xla:shape_util", + "//xla:util", + "//xla/client:local_client", "//xla/hlo/ir:hlo", "//xla/service:compiler", "@com_google_absl//absl/container:flat_hash_map", @@ -517,6 +521,7 @@ cc_library( ":pjrt_future", ":pjrt_stream_executor_device_description", ":semaphore", + ":stream_executor_executable", ":tracked_device_buffer", ":transpose", ":utils", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 0c859c6e49fa34..e6d1ff8645a312 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -906,40 +906,6 @@ StreamExecutorGpuClient::LoadSerialized(absl::string_view serialized, load_options); } -absl::StatusOr> -StreamExecutorGpuClient::Load(std::unique_ptr executable) { - auto se_executable = absl::WrapUnique( - tensorflow::down_cast(executable.release())); - - CompileOptions compile_options = se_executable->compile_options(); - CompileOptions input_options = compile_options; - TF_RETURN_IF_ERROR(compile_options.ApplyAllOptionOverrides()); - TF_ASSIGN_OR_RETURN(ExecutableExtras extras, - GetExecutableExtras(&compile_options)); - - // Load Executable from AOT compilation result. 
- std::vector> local_executables; - local_executables.reserve(se_executable->aot_executables().size()); - for (std::unique_ptr& aot_executable : - se_executable->aot_executables()) { - TF_ASSIGN_OR_RETURN(std::string serialized, - aot_executable->SerializeAsString()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr local_executable, - client()->Load(serialized, compile_options.executable_build_options)); - local_executables.push_back(std::move(local_executable)); - } - bool parameter_is_tupled_arguments = - compile_options.parameter_is_tupled_arguments; - auto ret = std::make_unique( - std::move(local_executables), parameter_is_tupled_arguments, - std::move(extras.device_assignment), std::move(input_options), - std::move(extras.addressable_device_logical_ids), - std::move(extras.addressable_devices), this); - TF_RETURN_IF_ERROR(ret->SetUpDonation(parameter_is_tupled_arguments)); - return std::unique_ptr(std::move(ret)); -} - namespace { #if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index 7d40b04535a839..a78bf13e641fb2 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -140,16 +140,6 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { return &topology_; } - absl::StatusOr> Load( - std::unique_ptr executable, - const LoadOptions& load_options) override { - return absl::WrapUnique( - tensorflow::down_cast(executable.release())); - } - - absl::StatusOr> Load( - std::unique_ptr executable); - absl::StatusOr> LoadSerialized( absl::string_view serialized, std::optional options, const LoadOptions& load_options); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 78bb3c249c085c..9737ae193e0db5 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -1154,7 +1154,7 @@ TEST(StreamExecutorGpuClientTest, GetDeviceFabricInfo) { &executor->GetDeviceDescription())) == 9) { auto fabric_info = GetDeviceFabricInfo(executor->device_ordinal()); if (fabric_info.ok()) { - EXPECT_FALSE(true); + ADD_FAILURE(); } } } @@ -1924,8 +1924,7 @@ TEST(StreamExecutorGpuClientTest, MlirParameterLayoutFromOptionsIsSetInHlo) { xla::CompileOptions options; options.argument_layouts = { {ShapeUtil::MakeShapeWithDenseLayout(S32, {2, 2, 2}, {0, 2, 1})}}; - TF_ASSERT_OK_AND_ASSIGN(auto executable, - client->CompileAndLoad(*module, options)); + TF_ASSERT_OK_AND_ASSIGN(auto executable, client->Compile(*module, options)); TF_ASSERT_OK_AND_ASSIGN(auto modules, executable->GetHloModules()); auto first_param_layout = diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index ace0900148d075..375748f4fad3bc 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -131,7 +131,7 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, if (!options.target_config) { if (client != nullptr) { TF_RETURN_IF_ERROR(IsValidTopologyAndClientForCompile(topology, client)); - return client->CompileAndLoad(computation, options); + return client->Compile(computation, options); } const auto& gpu_topology = tensorflow::down_cast( @@ -177,21 +177,17 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, const int num_partitions = 
hlo_module->config().num_partitions(); const std::string name = hlo_module->name(); const std::string fingerprint = hlo_module->GetFingerprint128(); - const int num_outputs = hlo_module->result_shape().IsTuple() - ? hlo_module->result_shape().tuple_shapes_size() - : 1; auto unique_module_group = std::make_unique(std::move(hlo_module)); TF_ASSIGN_OR_RETURN( std::vector> aot_results, gpu_compiler->CompileAheadOfTime(std::move(unique_module_group), aot_options)); - std::vector> output_memory_kinds(1); - output_memory_kinds[0].resize(num_outputs, - StreamExecutorGpuHbmMemorySpace::kKind); return std::make_unique( std::move(input_options), std::move(aot_results), num_replicas, - num_partitions, name, fingerprint, std::move(output_memory_kinds)); + num_partitions, name, fingerprint, + /*default_memory_kind=*/StreamExecutorGpuHbmMemorySpace::kKind, + /*local_executables=*/std::nullopt); } absl::StatusOr> diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc index 8c0b9bc6d3a182..2ae9ee769fc260 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc @@ -99,8 +99,9 @@ TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileMlirAndLoad) { TF_ASSERT_OK_AND_ASSIGN(auto executable, compiler.Compile(opts, mlir_module.get(), *topology, /*client=*/nullptr)); - TF_ASSERT_OK_AND_ASSIGN(auto loaded_executable, - se_client->Load(std::move(executable))); + TF_ASSERT_OK_AND_ASSIGN( + auto loaded_executable, + se_client->Load(std::move(executable), LoadOptions())); TF_ASSERT_OK_AND_ASSIGN( std::vector>> result, @@ -129,7 +130,7 @@ TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileXlaAndLoad) { compiler.Compile(opts, computation, *topology, /*client=*/nullptr)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr loaded_executable, - se_client->Load(std::move(executable))); + se_client->Load(std::move(executable), LoadOptions())); TF_ASSERT_OK_AND_ASSIGN( std::vector>> result, loaded_executable->Execute(/*argument_handles=*/{{}}, {})); @@ -192,7 +193,7 @@ TEST(StreamExecutorGpuCompilerTest, SuccessSerializeDeserialize) { compiler.Compile(opts, computation, *topology, /*client=*/nullptr)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr loaded_executable, - se_client->Load(std::move(executable))); + se_client->Load(std::move(executable), LoadOptions())); // Serialize the executable and deserialize it without failure. TF_ASSERT_OK_AND_ASSIGN(std::string serialized_executable, diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 097bc9071a85a0..20eb79146e2c6d 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -117,6 +117,7 @@ limitations under the License. 
#include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/profiling/device_time_measurement.h" #include "xla/pjrt/semaphore.h" +#include "xla/pjrt/stream_executor_executable.h" #include "xla/pjrt/tracked_device_buffer.h" #include "xla/pjrt/transpose.h" #include "xla/pjrt/utils.h" @@ -3519,14 +3520,14 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) { return extras; } -absl::StatusOr> +absl::StatusOr> PjRtStreamExecutorClient::CompileInternal( const XlaComputation& computation, const std::vector& argument_layout_pointers, LayoutCanonicalizationCallback layout_canonicalization_callback, CompileOptions options) { - tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile"); - VLOG(1) << "PjRtStreamExecutorClient::Compile"; + tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::CompileInternal"); + VLOG(1) << "PjRtStreamExecutorClient::CompileInternal"; if (key_value_store().has_value() && !options.executable_build_options.key_value_store()) { options.executable_build_options.set_key_value_store(*key_value_store()); @@ -3535,13 +3536,6 @@ PjRtStreamExecutorClient::CompileInternal( TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides()); - TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&options)); - std::shared_ptr& device_assignment = - extras.device_assignment; - std::vector& - addressable_device_logical_ids = extras.addressable_device_logical_ids; - std::vector& addressable_devices = extras.addressable_devices; - // It is important to set the canonicalization callback after creating // a copy of the options so that the executable's options remain without // the callback - the callback would break the executable's serializability. @@ -3555,26 +3549,39 @@ PjRtStreamExecutorClient::CompileInternal( client()->Compile(computation, argument_layout_pointers, options.executable_build_options)); - auto executable = std::make_unique( - std::move(local_executables), options.parameter_is_tupled_arguments, - std::move(device_assignment), std::move(input_options), - std::move(addressable_device_logical_ids), std::move(addressable_devices), - this); + return BuildPjRtExecutable(std::move(local_executables), input_options); +} - TF_RETURN_IF_ERROR( - executable->SetUpDonation(options.parameter_is_tupled_arguments)); - const auto& ex_options = options.executable_build_options; - if (ex_options.has_debug_options() && - ex_options.debug_options().xla_gpu_dump_hlo_unoptimized_snapshots()) { - executable->SetInputHloSnapshotBits( - computation.proto(), options.executable_build_options.debug_options()); - } - return std::unique_ptr(std::move(executable)); +absl::StatusOr> +PjRtStreamExecutorClient::Compile(const XlaComputation& computation, + CompileOptions options) { + std::vector argument_layout_pointers; + const ExecutableBuildOptions& build_options = + options.executable_build_options; + const bool allow_auto_layout = + build_options.has_debug_options() && + build_options.debug_options().xla_pjrt_allow_auto_layout_in_hlo(); + TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( + computation, + [local_client = client(), + allow_auto_layout](Shape shape) -> absl::StatusOr { + if (allow_auto_layout && !shape.has_layout()) { + return shape; + } + return local_client->backend() + .transfer_manager() + ->ChooseCompactLayoutForShape(shape); + }, + options.argument_layouts, &options.executable_build_options, + &argument_layout_pointers)); + return CompileInternal(computation, argument_layout_pointers, + /* layout_canonicalization_callback = */ nullptr, + 
options); } -absl::StatusOr> -PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, - CompileOptions options) { +absl::StatusOr> +PjRtStreamExecutorClient::Compile(mlir::ModuleOp module, + CompileOptions options) { XlaComputation xla_computation; const ExecutableBuildOptions& exec_build_options = options.executable_build_options; @@ -3586,7 +3593,7 @@ PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, // If the compile options specify argument layout, then let's // fall back to using the options to determine layouts. if (options.argument_layouts) { - return CompileAndLoad(xla_computation, options); + return Compile(xla_computation, options); } TF_ASSIGN_OR_RETURN(std::vector arg_layout_modes, @@ -3636,28 +3643,17 @@ PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, absl::StatusOr> PjRtStreamExecutorClient::CompileAndLoad(const XlaComputation& computation, CompileOptions options) { - std::vector argument_layout_pointers; - const ExecutableBuildOptions& build_options = - options.executable_build_options; - const bool allow_auto_layout = - build_options.has_debug_options() && - build_options.debug_options().xla_pjrt_allow_auto_layout_in_hlo(); - TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( - computation, - [local_client = client(), - allow_auto_layout](Shape shape) -> absl::StatusOr { - if (allow_auto_layout && !shape.has_layout()) { - return shape; - } - return local_client->backend() - .transfer_manager() - ->ChooseCompactLayoutForShape(shape); - }, - options.argument_layouts, &options.executable_build_options, - &argument_layout_pointers)); - return CompileInternal(computation, argument_layout_pointers, - /* layout_canonicalization_callback = */ nullptr, - options); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + Compile(computation, options)); + return Load(std::move(executable), LoadOptions()); +} + +absl::StatusOr> +PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, + CompileOptions options) { + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + Compile(module, options)); + return Load(std::move(executable), LoadOptions()); } absl::StatusOr PjRtStreamExecutorClient::SerializeExecutable( @@ -3694,21 +3690,61 @@ absl::StatusOr PjRtStreamExecutorClient::SerializeExecutable( return proto.SerializeAsString(); } -absl::StatusOr> -PjRtStreamExecutorClient::LoadSerializedExecutable( - absl::string_view serialized, std::optional options, - const LoadOptions& load_options) { +absl::StatusOr> +PjRtStreamExecutorClient::BuildPjRtExecutable( + std::vector> local_executables, + CompileOptions compile_options) { + if (local_executables.empty()) { + return Internal("No local executable"); + } + if (local_executables.size() != 1) { + return Unimplemented("Multiple executables are not supported"); + } + Executable* built_executable = local_executables[0]->executable(); + Compiler* compiler = client_->backend().compiler(); + TF_ASSIGN_OR_RETURN(std::unique_ptr aot_result, + compiler->Export(built_executable)); + std::vector> aot_results; + aot_results.push_back(std::move(aot_result)); + + if (!built_executable->has_module()) { + return absl::InternalError("Executable does not have HLO modules."); + } + const auto& hlo_module = built_executable->module(); + + const int num_replicas = hlo_module.config().replica_count(); + const int num_partitions = hlo_module.config().num_partitions(); + const std::string name = hlo_module.name(); + const std::string fingerprint = hlo_module.GetFingerprint128(); + + return std::make_unique( + 
std::move(compile_options), std::move(aot_results), num_replicas, + num_partitions, name, fingerprint, memory_spaces()[0]->kind(), + std::move(local_executables)); +} + +absl::StatusOr> +PjRtStreamExecutorClient::DeserializeExecutable( + absl::string_view serialized, + std::optional compile_options) { + TF_ASSIGN_OR_RETURN( + auto local_executables_and_options, + DeserializeToLocalExecutable(serialized, compile_options)); + + return BuildPjRtExecutable(std::move(local_executables_and_options.first), + local_executables_and_options.second); +} + +absl::StatusOr< + std::pair>, CompileOptions>> +PjRtStreamExecutorClient::DeserializeToLocalExecutable( + absl::string_view serialized, std::optional options) { ExecutableAndOptionsProto proto; if (serialized.size() > std::numeric_limits::max()) { - return Internal( - "PjRtStreamExecutorClient::DeserializeExecutable proto too large " - "(>2GB)"); + return Internal("Proto is too large (>2GB)"); } if (!proto.ParseFromArray(serialized.data(), serialized.size())) { - return Internal( - "PjRtStreamExecutorClient::DeserializeExecutable proto " - "deserialization " - "failed"); + return Internal("Proto deserialization failed"); } CompileOptions compile_options; @@ -3718,11 +3754,39 @@ PjRtStreamExecutorClient::LoadSerializedExecutable( TF_ASSIGN_OR_RETURN(compile_options, CompileOptions::FromProto(proto.compile_options())); } - auto input_options = compile_options; tsl::profiler::TraceMe traceme( - "PjRtStreamExecutorClient::DeserializeExecutable"); - VLOG(1) << "PjRtStreamExecutorClient::DeserializeExecutable"; + "PjRtStreamExecutorClient::DeserializeToLocalExecutable"); + VLOG(1) << "PjRtStreamExecutorClient::DeserializeToLocalExecutable"; + + std::string str = std::move(*proto.mutable_serialized_executable()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr loaded, + client()->Load(str, compile_options.executable_build_options)); + + std::vector> local_executables; + local_executables.push_back(std::move(loaded)); + + return std::make_pair(std::move(local_executables), compile_options); +} + +absl::StatusOr> +PjRtStreamExecutorClient::LoadSerializedExecutable( + absl::string_view serialized, std::optional options, + const LoadOptions& load_options) { + TF_ASSIGN_OR_RETURN(auto local_executables_and_options, + DeserializeToLocalExecutable(serialized, options)); + return LoadInternal(std::move(local_executables_and_options.first), + local_executables_and_options.second); +} + +absl::StatusOr> +PjRtStreamExecutorClient::LoadInternal( + std::vector> local_executables, + CompileOptions compile_options) { + auto input_options = compile_options; + + TF_RETURN_IF_ERROR(compile_options.ApplyAllOptionOverrides()); TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&compile_options)); @@ -3732,13 +3796,8 @@ PjRtStreamExecutorClient::LoadSerializedExecutable( addressable_device_logical_ids = extras.addressable_device_logical_ids; std::vector& addressable_devices = extras.addressable_devices; - std::string str = std::move(*proto.mutable_serialized_executable()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr loaded, - client()->Load(str, compile_options.executable_build_options)); - - std::vector> local_executables; - local_executables.push_back(std::move(loaded)); + HloModuleProto hlo_module_proto = + local_executables[0]->executable()->module().ToProto(); auto executable = std::make_unique( std::move(local_executables), @@ -3749,9 +3808,40 @@ PjRtStreamExecutorClient::LoadSerializedExecutable( TF_RETURN_IF_ERROR( 
executable->SetUpDonation(compile_options.parameter_is_tupled_arguments)); + const auto& ex_options = compile_options.executable_build_options; + if (ex_options.has_debug_options() && + ex_options.debug_options().xla_gpu_dump_hlo_unoptimized_snapshots()) { + executable->SetInputHloSnapshotBits( + std::move(hlo_module_proto), + compile_options.executable_build_options.debug_options()); + } return std::unique_ptr(std::move(executable)); } +absl::StatusOr> +PjRtStreamExecutorClient::Load(std::unique_ptr executable, + const LoadOptions& load_options) { + auto se_executable = absl::WrapUnique( + tensorflow::down_cast(executable.release())); + CompileOptions compile_options = se_executable->compile_options(); + + tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Load"); + VLOG(1) << "PjRtStreamExecutorClient::Load"; + + // Load Executables from AOT compilation results. + std::vector> local_executables; + local_executables.reserve(se_executable->aot_executables().size()); + for (int i = 0; i < se_executable->aot_executables().size(); ++i) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr local_executable, + client()->Load(std::move(se_executable->aot_executables()[i]), + compile_options.executable_build_options)); + local_executables.push_back(std::move(local_executable)); + } + + return LoadInternal(std::move(local_executables), compile_options); +} + bool PjRtStreamExecutorClient::IsDmaMapped(const void* data_start, int64_t transfer_size) { absl::MutexLock lock(&dma_maps_mutex_); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 1e5c8c7cc3b001..b2e0f804f7d2bd 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -276,14 +276,22 @@ class PjRtStreamExecutorClient : public PjRtClient { absl::StatusOr GetDefaultLayout( PrimitiveType element_type, absl::Span dims) override; + absl::StatusOr> Compile( + const XlaComputation& computation, CompileOptions options) override; absl::StatusOr> CompileAndLoad( const XlaComputation& computation, CompileOptions options) override; + absl::StatusOr> Compile( + mlir::ModuleOp mlir_module, CompileOptions options) override; absl::StatusOr> CompileAndLoad( mlir::ModuleOp mlir_module, CompileOptions options) override; virtual absl::StatusOr SerializeExecutable( const PjRtLoadedExecutable& executable) const; + absl::StatusOr> DeserializeExecutable( + absl::string_view serialized, + std::optional options) override; + // For PjRtStreamExecutorClient, `options` is mandatory. // This function returns an InvalidArgument error if `std::nullopt` is passed. 
// TODO(b/237720161): make it actually optional @@ -292,6 +300,10 @@ class PjRtStreamExecutorClient : public PjRtClient { std::optional options, const LoadOptions& load_options) override; + absl::StatusOr> Load( + std::unique_ptr executable, + const LoadOptions& load_options) override; + absl::StatusOr> GetHloCostAnalysis() const override; @@ -417,12 +429,25 @@ class PjRtStreamExecutorClient : public PjRtClient { }; absl::StatusOr GetExecutableExtras(CompileOptions* options); - absl::StatusOr> CompileInternal( + absl::StatusOr> CompileInternal( const XlaComputation& computation, const std::vector& argument_layout_pointers, LayoutCanonicalizationCallback layout_canonicalization_callback, CompileOptions options); + absl::StatusOr> BuildPjRtExecutable( + std::vector> local_executables, + CompileOptions compile_options); + + absl::StatusOr< + std::pair>, CompileOptions>> + DeserializeToLocalExecutable(absl::string_view serialized, + std::optional options); + + absl::StatusOr> LoadInternal( + std::vector> local_executables, + CompileOptions compile_options); + absl::StatusOr> BufferFromHostBufferInternal( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, diff --git a/third_party/xla/xla/pjrt/stream_executor_executable.cc b/third_party/xla/xla/pjrt/stream_executor_executable.cc index ab82fdaf0c2ec1..a500bcb07d33c3 100644 --- a/third_party/xla/xla/pjrt/stream_executor_executable.cc +++ b/third_party/xla/xla/pjrt/stream_executor_executable.cc @@ -18,11 +18,15 @@ limitations under the License. #include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/stream_executor_executable.pb.h" #include "xla/service/compiler.h" +#include "xla/shape.h" +#include "xla/util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -50,4 +54,57 @@ absl::StatusOr StreamExecutorExecutable::SerializeExecutable() compile_options_.ToProto()); return proto.SerializeAsString(); } + +namespace { + +absl::StatusOr MemoryKindFromSimpleShape( + const Shape& shape, absl::string_view default_memory_kind) { + if (!shape.has_layout()) { + return default_memory_kind; + } + switch (shape.layout().memory_space()) { + case Layout::kHostMemorySpace: + return PinnedHostMemorySpace::kKind; + case Layout::kGenericFastMemorySpace: + case Layout::kDefaultMemorySpace: + return default_memory_kind; + default: + return InvalidArgument("Unexpected memory space %d in output layout", + shape.layout().memory_space()); + } +} + +absl::StatusOr> MemoryKindsFromShape( + const Shape& shape, absl::string_view default_memory_kind) { + if (!shape.IsTuple()) { + TF_ASSIGN_OR_RETURN(absl::string_view memory_kind, + MemoryKindFromSimpleShape(shape, default_memory_kind)); + return {{memory_kind}}; + } + std::vector result; + result.reserve(shape.tuple_shapes_size()); + for (const auto& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN( + absl::string_view element_memory_kind, + MemoryKindFromSimpleShape(element_shape, default_memory_kind)); + result.push_back(element_memory_kind); + } + return result; +} + +} // namespace + +absl::StatusOr>> +StreamExecutorExecutable::GetOutputMemoryKinds() const { + TF_ASSIGN_OR_RETURN(auto shapes, GetOutputShapes()); + std::vector> out; + out.reserve(shapes.size()); + for (const auto& shape : shapes) { + TF_ASSIGN_OR_RETURN(std::vector memory_kind, + MemoryKindsFromShape(shape, default_memory_kind_)); + out.push_back(memory_kind); + } + return out; +} + } // namespace xla 
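As a reading aid, a worked example of the memory-kind mapping added above. The tuple shape and the "device" default kind are hypothetical; PinnedHostMemorySpace::kKind and the Layout constants are the ones this file references.

    // Sketch: suppose default_memory_kind_ == "device" and the executable
    // has one output whose shape is a two-element tuple, where element 0 is
    // laid out in Layout::kHostMemorySpace and element 1 in
    // Layout::kDefaultMemorySpace. MemoryKindFromSimpleShape then yields
    // PinnedHostMemorySpace::kKind for element 0 and "device" for element 1,
    // so GetOutputMemoryKinds() returns
    //   {{PinnedHostMemorySpace::kKind, "device"}}
    // (one vector of kinds per output shape).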
diff --git a/third_party/xla/xla/pjrt/stream_executor_executable.h b/third_party/xla/xla/pjrt/stream_executor_executable.h index 826e4f2912f176..aab85c2552924b 100644 --- a/third_party/xla/xla/pjrt/stream_executor_executable.h +++ b/third_party/xla/xla/pjrt/stream_executor_executable.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/client/local_client.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" @@ -35,19 +36,32 @@ limitations under the License. namespace xla { class StreamExecutorExecutable : public PjRtExecutable { public: + // TODO(b/407470731): Make `xla::AotCompilationResult` provide APIs for + // getting code size, memory stats, etc, so that we do not need to rely on + // `LocalExecutable`s for such information. StreamExecutorExecutable( const CompileOptions& compile_options, - std::vector> executables, + std::vector> aot_executables, int num_replicas, int num_partitions, absl::string_view name, - absl::string_view fingerprint, - std::optional>> - output_memory_kinds) + absl::string_view fingerprint, absl::string_view default_memory_kind, + std::optional>> + local_executables) : compile_options_(compile_options), - aot_executables_(std::move(executables)), + aot_executables_(std::move(aot_executables)), num_replicas_(num_replicas), num_partitions_(num_partitions), name_(name), - fingerprint_(fingerprint) {} + fingerprint_(fingerprint), + default_memory_kind_(default_memory_kind), + local_executables_(std::move(local_executables)) { + if (local_executables_.has_value()) { + std::vector> hlo_modules; + for (const auto& local_executable : *local_executables_) { + hlo_modules.push_back(local_executable->executable()->shared_module()); + } + hlo_modules_ = std::move(hlo_modules); + } + } absl::StatusOr SerializeExecutable() const override; @@ -59,22 +73,51 @@ class StreamExecutorExecutable : public PjRtExecutable { } absl::StatusOr>> GetHloModules() const override { - return absl::UnimplementedError("GetHloModules is not supported."); + if (!hlo_modules_.has_value()) { + return absl::UnimplementedError("GetHloModules is not supported."); + } + return *hlo_modules_; } - absl::StatusOr>> - GetOutputMemoryKinds() const override { - if (output_memory_kinds_.has_value()) { - return *output_memory_kinds_; + absl::StatusOr GetCompiledMemoryStats() const override { + if (!local_executables_.has_value()) { + return absl::UnimplementedError( + "Retrieving CompiledMemoryStats is not supported."); + } + if (local_executables_->size() != 1) { + return absl::UnimplementedError( + "Retrieving CompiledMemoryStats is not supported for multiple " + "executables."); + } + CompiledMemoryStats memory_stats = CompiledMemoryStats(); + memory_stats.generated_code_size_in_bytes = SizeOfGeneratedCodeInBytes(); + const HloProto* proto = (*local_executables_)[0]->executable()->hlo_proto(); + if (proto != nullptr) { + memory_stats.serialized_hlo_proto = proto->SerializeAsString(); } - return absl::UnimplementedError("GetOutputMemoryKinds is not supported."); + memory_stats.PopulateBufferStatsFromAllocations( + (*local_executables_)[0]->executable()->GetAllocations()); + return memory_stats; } + + absl::StatusOr>> + GetOutputMemoryKinds() const override; + absl::StatusOr> GetCostAnalysis() const override { return absl::UnimplementedError("GetCostAnalysis is not supported."); } - int64_t SizeOfGeneratedCodeInBytes() const override { 
return 0; } + int64_t SizeOfGeneratedCodeInBytes() const override { + if (!local_executables_.has_value()) { + return 0; + } + int64_t size = 0; + for (auto& executable : *local_executables_) { + size += executable->executable()->SizeOfGeneratedCodeInBytes(); + } + return size; + } const CompileOptions& compile_options() const { return compile_options_; } std::vector>& aot_executables() { @@ -88,12 +131,14 @@ class StreamExecutorExecutable : public PjRtExecutable { private: CompileOptions compile_options_; std::vector> aot_executables_; + std::optional>> hlo_modules_; int num_replicas_; int num_partitions_; std::string name_; std::string fingerprint_; - std::optional>> - output_memory_kinds_; + absl::string_view default_memory_kind_; + std::optional>> + local_executables_; }; } // namespace xla diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index c809be5b0971f6..1e6ea95df721fa 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -1225,8 +1225,11 @@ absl::StatusOr> BufferAssignment::FromProto( absl::c_copy(alloc_proto.parameter_shape_index(), std::back_inserter(shape_idx_vals)); ShapeIndex shape_index(shape_idx_vals); + const bool parameter_has_alias = + module->input_output_alias_config().ParameterHasAlias( + alloc_proto.parameter_number(), shape_index); allocation->set_entry_computation_parameter( - alloc_proto.parameter_number(), shape_index, false); + alloc_proto.parameter_number(), shape_index, parameter_has_alias); } // Process each logical buffer assigned to the current allocation and create From 2cf7bcd1c2dd81262961d523e012e1464ddfdca7 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 31 Mar 2025 13:13:05 -0700 Subject: [PATCH 0055/1324] Explicitly set untuple results for tuplized results. 
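In sketch form, the calling convention this change makes explicit; `executable` and `argument_handles` are assumed to come from the surrounding test code:

    // Sketch only: for a tuple-shaped result, PJRT callers now request
    // untupling explicitly. With untuple_result set, Execute() returns one
    // PjRtBuffer per top-level tuple element rather than a single
    // tuple-shaped buffer.
    xla::ExecuteOptions options;
    options.untuple_result = true;
    auto result = executable->Execute(argument_handles, options);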
PiperOrigin-RevId: 742371495 --- third_party/xla/xla/pjrt/pjrt_client_test.cc | 8 +- .../xla/xla/service/hlo_runner_pjrt.cc | 77 ++++++------------- third_party/xla/xla/service/hlo_runner_pjrt.h | 5 +- .../functional_hlo_runner.cc | 8 +- 4 files changed, 33 insertions(+), 65 deletions(-) diff --git a/third_party/xla/xla/pjrt/pjrt_client_test.cc b/third_party/xla/xla/pjrt/pjrt_client_test.cc index e536f267fec45d..1552fdf495206d 100644 --- a/third_party/xla/xla/pjrt/pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_client_test.cc @@ -568,6 +568,8 @@ ENTRY DuplicateDonationError() -> (f32[2, 2], f32[2, 2]) { TF_ASSERT_OK_AND_ASSIGN(auto buffer2, MakeFloatBuffer(client.get(), data, {2, 2})); + xla::ExecuteOptions options; + options.untuple_result = true; { auto result = pjrt_executable->Execute(/*argument_handles=*/{{ buffer0.get(), @@ -575,7 +577,7 @@ ENTRY DuplicateDonationError() -> (f32[2, 2], f32[2, 2]) { buffer1.get(), buffer0.get(), }}, - /*options=*/{}); + /*options=*/options); ASSERT_FALSE(result.ok()); EXPECT_THAT(result.status().message(), ::testing::HasSubstr("f(donate(a), donate(a))")); @@ -587,7 +589,7 @@ ENTRY DuplicateDonationError() -> (f32[2, 2], f32[2, 2]) { buffer2.get(), buffer0.get(), }}, - /*options=*/{}); + /*options=*/options); ASSERT_FALSE(result.ok()); EXPECT_THAT(result.status().message(), ::testing::HasSubstr("f(a, donate(a))")); @@ -599,7 +601,7 @@ ENTRY DuplicateDonationError() -> (f32[2, 2], f32[2, 2]) { buffer2.get(), buffer2.get(), }}, - /*options=*/{}); + /*options=*/options); ASSERT_FALSE(result.ok()); EXPECT_THAT(result.status().message(), ::testing::HasSubstr("f(donate(a), a)")); diff --git a/third_party/xla/xla/service/hlo_runner_pjrt.cc b/third_party/xla/xla/service/hlo_runner_pjrt.cc index 73af1871aae951..67ffa302bd5790 100644 --- a/third_party/xla/xla/service/hlo_runner_pjrt.cc +++ b/third_party/xla/xla/service/hlo_runner_pjrt.cc @@ -131,35 +131,9 @@ absl::StatusOr> FlattenedParameterLayouts( absl::StatusOr GenerateExecuteOptions(const HloModule& module) { ExecuteOptions execute_options; - // PjRt requires untuple_result if any output leaf buffer is in host memory, - // or if any output leaf buffer is not an array. + // PjRt requires untuple_result if the output is a tuple. if (module.result_shape().IsTuple()) { - bool has_array_output_in_host_memory = false; - bool has_non_array_output = false; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - module.entry_computation_layout().result_shape(), - [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { - if (!subshape.IsArray()) { - if (!subshape.IsTuple()) { - has_non_array_output = true; - } - // Skip token, opaque, and tuple outputs. - return absl::OkStatus(); - } - // Arrays require a layout. 
- if (!subshape.has_layout()) { - return absl::InvalidArgumentError( - "GenerateExecuteOptions requires that all array subshapes of " - "the result shape have layouts."); - } - if (subshape.layout().memory_space() == Layout::kHostMemorySpace) { - has_array_output_in_host_memory = true; - } - return absl::OkStatus(); - })); - - execute_options.untuple_result = - has_array_output_in_host_memory || has_non_array_output; + execute_options.untuple_result = true; } return execute_options; } @@ -545,22 +519,15 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( TF_ASSIGN_OR_RETURN(HloRunnerPjRtExecutable* const wrapped_executable, HloRunnerPjRtExecutable::TryUnwrap(*this, executable)); + xla::ExecuteOptions execute_options; + execute_options.untuple_result = true; return ExecuteReplicatedImpl( [&](absl::Span> argument_buffer_slices) - -> absl::StatusOr>> { + -> absl::StatusOr< + std::vector>>> { TF_ASSIGN_OR_RETURN( - auto execution_results, - wrapped_executable->pjrt_loaded_executable()->Execute( - argument_buffer_slices, {})); - - std::vector> results; - - for (auto& device_execution_result : execution_results) { - for (auto& device_buffer : device_execution_result) { - results.push_back(std::move(device_buffer)); - } - } - + auto results, wrapped_executable->pjrt_loaded_executable()->Execute( + argument_buffer_slices, execute_options)); return results; }, [&](int64_t replica) { return options.arguments.size(); }, @@ -578,7 +545,8 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( << "Only single-computation execution is supported."; return ExecuteReplicatedImpl( [&](absl::Span> argument_buffer_slices) - -> absl::StatusOr>> { + -> absl::StatusOr< + std::vector>>> { TF_RET_CHECK(options.use_threads); // The underlying data is modified concurrently. We don't need to @@ -606,9 +574,12 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( pool.Schedule([&per_replica_results, i, executable, args = argument_buffer_slices[i], device_ptr]() { std::optional> returned_future = {}; + xla::ExecuteOptions options; + options.untuple_result = true; per_replica_results[i] = executable->pjrt_loaded_executable()->ExecuteSharded( - args, device_ptr, {}, /*returned_future=*/returned_future, + args, device_ptr, options, + /*returned_future=*/returned_future, /*fill_future=*/true); if (returned_future.has_value()) { if (const absl::Status& status = returned_future->Await(); @@ -620,19 +591,14 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( } } // Aggregate results. 
- std::vector> results; + std::vector>> results; for (int64_t i = 0; i < options.num_replicas; ++i) { absl::StatusOr>>& replica_result = per_replica_results[i]; if (!replica_result.ok()) { return replica_result.status(); } - if (replica_result->size() != 1) { - return absl::InternalError(absl::StrFormat( - "Expected a single result for replica %d, got %d results.", i, - replica_result->size())); - } - results.push_back(std::move(std::move(replica_result)->front())); + results.push_back(*std::move(replica_result)); } return results; }, @@ -640,8 +606,9 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( } absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( - std::function>>( - absl::Span>)> + std::function< + absl::StatusOr>>>( + absl::Span>)> execution_helper, std::function argument_count_provider, std::function argument_provider, @@ -744,7 +711,8 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( VLOG(1) << "Replicated execution started"; TF_ASSIGN_OR_RETURN( - const std::vector> result_buffers, + const std::vector>> + result_buffers, execution_helper(BufferMatToPointerMat(argument_buffer_slices))); VLOG(1) << "Replicated execution terminated"; @@ -753,7 +721,8 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( result_literals.reserve(options.num_replicas); for (int64_t i = 0; i < options.num_replicas; ++i) { TF_ASSIGN_OR_RETURN(Literal literal, - TransferLiteralFromDevice(*result_buffers[i])); + TransferLiteralsFromDevice( + result_buffers[i], result_buffers[i].size() != 1)); result_literals.push_back(std::move(literal)); } diff --git a/third_party/xla/xla/service/hlo_runner_pjrt.h b/third_party/xla/xla/service/hlo_runner_pjrt.h index 33936db01f7f4f..9987466b133663 100644 --- a/third_party/xla/xla/service/hlo_runner_pjrt.h +++ b/third_party/xla/xla/service/hlo_runner_pjrt.h @@ -163,8 +163,9 @@ class HloRunnerPjRt : public HloRunnerInterface { HloModule* module, bool run_hlo_passes); absl::StatusOr> ExecuteReplicatedImpl( - std::function>>( - absl::Span>)> + std::function< + absl::StatusOr>>>( + absl::Span>)> execution_helper, std::function argument_count_provider, std::function argument_provider, diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index a9ffe451ae9c05..b308cd9ac5ff56 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -1044,13 +1044,9 @@ FunctionalHloRunner::RunInternal( if (!module.result_shape().IsTuple()) { return false; } - return absl::c_any_of( - module.result_shape().tuple_shapes(), [](const Shape& shape) { - return shape.has_layout() && - shape.layout().memory_space() == Layout::kHostMemorySpace; - }); + return true; }; - // If any output leaf buffer is in host memory, PJRT requires untuple_result. + // If any output leaf buffer is a tuple, PJRT requires untuple_result. bool must_untuple_result = output_has_tuple_leaf_on_host_memory_space(); bool default_untuple_result = must_untuple_result || execute_options.untuple_result; From c364bdd75c83b37968da123ec92aa36b3f69940e Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 31 Mar 2025 14:06:24 -0700 Subject: [PATCH 0056/1324] Run lightweight CSE after constant splitting to reduce compilation time. 
CSE does not invoke any folders and is annotation-aware.

PiperOrigin-RevId: 742392986
---
 .../xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc
index 76ce4f3fd84bf2..d2fe76322c5e1a 100644
--- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc
+++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc
@@ -155,7 +155,10 @@ TEST_F(ShardyXLATest, CostantSplitter) {
   EXPECT_EQ(dot->operand(0)->operand(0)->opcode(), HloOpcode::kConstant);
   EXPECT_EQ(dot->operand(1)->operand(0)->opcode(), HloOpcode::kConstant);
-  EXPECT_NE(dot->operand(0)->operand(0), dot->operand(1)->operand(0));
+
+  // Constants with identical shardings are expected to be merged.
+  // TODO(tomnatan): Uncomment this test once the sdy pin is bumped (3/31/25).
+  // EXPECT_EQ(dot->operand(0)->operand(0), dot->operand(1)->operand(0));
 }
 
 TEST_F(ShardyXLATest, Dot) {

From e6de31cd717fa40e32b375445e2c1ae3d9b97a69 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Mon, 31 Mar 2025 14:09:11 -0700
Subject: [PATCH 0057/1324] [xla:gpu] CommandBuffer: remove For command

When constructing command buffers from thunks, we can't replace a while loop
with a for loop: a for loop needs an allocation for the loop counter, but for
a while loop we only get a buffer for pred[], which is not large enough to
hold an int32_t counter. Furthermore, the loop induction variable is updated
by the body computation anyway, and the condition computation is almost
always a trivial compare, so replacing the compare fusion with an AddI32
kernel gains us nothing.

XLA:GPU needs only two conditional commands to represent control flow: Case
and While.
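As a sketch of why ForCmd is redundant (pseudocode, not part of the change):

    // A counted loop needs its own int32_t counter allocation:
    //   for (int32_t i = 0; i < n; ++i) { body(); }
    // With only a pred[] buffer available, the same loop is already
    // expressible as the While command the thunks produce:
    //   pred = cond();      // condition fusion writes the predicate
    //   while (pred) {      // WhileCmd re-runs the condition commands
    //     body();           // the body also updates the induction variable
    //     pred = cond();
    //   }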
PiperOrigin-RevId: 742394195 --- .../gpu/runtime/command_buffer_cmd.cc | 43 --- .../backends/gpu/runtime/command_buffer_cmd.h | 28 -- .../gpu/runtime/command_buffer_cmd_emitter.cc | 10 +- .../gpu/runtime/command_buffer_thunk_test.cc | 71 ---- .../service/gpu/tests/command_buffer_test.cc | 65 ---- .../xla/xla/stream_executor/command_buffer.h | 8 - .../cuda/command_buffer_kernels.cc | 365 ++---------------- .../cuda/command_buffer_kernels.h | 3 - .../cuda/cuda_command_buffer.cc | 26 -- .../cuda/cuda_command_buffer.h | 14 - .../stream_executor/gpu/gpu_command_buffer.cc | 63 --- .../stream_executor/gpu/gpu_command_buffer.h | 24 -- .../gpu/gpu_command_buffer_test.cc | 50 --- .../rocm/rocm_command_buffer.cc | 12 - .../rocm/rocm_command_buffer.h | 10 - 15 files changed, 34 insertions(+), 758 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index e082c296b9fce3..07e3580e534ac8 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -820,49 +820,6 @@ CommandBufferCmd::BufferUseVector CaseCmd::buffers() { return {buffers.begin(), buffers.end()}; } -//===----------------------------------------------------------------------===// -// ForCmd -//===----------------------------------------------------------------------===// - -ForCmd::ForCmd(ExecutionStreamId execution_stream_id, int32_t num_iterations, - BufferAllocation::Slice loop_counter, - CommandBufferCmdSequence body_commands) - : CommandBufferCmd(CommandBufferCmdType::kForCmd, execution_stream_id), - num_iterations_(num_iterations), - loop_counter_(loop_counter), - body_commands_(std::move(body_commands)) {} - -absl::Status ForCmd::Initialize(const Thunk::InitializeParams& params, - StateManager& state) { - return body_commands_.Initialize(params, state); -} - -absl::Status ForCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { - se::DeviceMemoryBase loop_counter = - execute_params.buffer_allocations->GetDeviceAddress(loop_counter_); - - VLOG(5) << "ForCmd: num_iterations=" << num_iterations_ - << "; body_commands=" << body_commands_.size(); - VLOG(5) << " loop_counter: " << loop_counter_ << " (" - << loop_counter.opaque() << ")"; - - return command_buffer->For( - num_iterations_, se::DeviceMemory(loop_counter), - CreateBuilder(&body_commands_, &execute_params, &record_params)); -} - -bool ForCmd::force_update() { return body_commands_.force_update(); } - -CommandBufferCmd::BufferUseVector ForCmd::buffers() { - absl::flat_hash_set buffers; - buffers.emplace(loop_counter_, MemoryAccess::kWrite); - buffers.insert(body_commands_.buffers().begin(), - body_commands_.buffers().end()); - return {buffers.begin(), buffers.end()}; -} - //===----------------------------------------------------------------------===// // WhileCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index 292f2787aadae9..7454725e13431b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -74,7 +74,6 @@ namespace xla::gpu { V(kMemzeroCmd, "MemzeroCmd") \ V(kMemset32Cmd, "Memset32Cmd") \ V(kCaseCmd, "CaseCmd") \ - V(kForCmd, "ForCmd") \ V(kWhileCmd, 
"WhileCmd") \ V(kCustomCallCmd, "CustomCallCmd") \ V(kBarrierCmd, "BarrierCmd") \ @@ -576,33 +575,6 @@ class CaseCmd : public CommandBufferCmd { std::vector branches_commands_; }; -//===----------------------------------------------------------------------===// -// ForCmd -//===----------------------------------------------------------------------===// - -class ForCmd : public CommandBufferCmd { - public: - ForCmd(ExecutionStreamId execution_stream_id, int32_t num_iterations, - BufferAllocation::Slice loop_counter, - CommandBufferCmdSequence body_commands); - - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; - - bool force_update() override; - - BufferUseVector buffers() override; - - private: - int32_t num_iterations_; - BufferAllocation::Slice loop_counter_; - CommandBufferCmdSequence body_commands_; -}; - //===----------------------------------------------------------------------===// // WhileCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc index 7fabd03e6707bb..653649f186ae76 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc @@ -117,15 +117,9 @@ static absl::StatusOr Convert( ConvertToCommands(thunk.body_thunk_sequence()->thunks(), synchronization_mode)); - if (std::optional trip_count = thunk.trip_count()) { - return std::make_unique(thunk.execution_stream_id(), *trip_count, + return std::make_unique(thunk.execution_stream_id(), thunk.condition_result_buffer(), - std::move(body_cmds)); - } else { - return std::make_unique( - thunk.execution_stream_id(), thunk.condition_result_buffer(), - std::move(cond_cmds), std::move(body_cmds)); - } + std::move(cond_cmds), std::move(body_cmds)); } static absl::StatusOr Convert(const GemmThunk& thunk) { diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index 4f5ef88c899069..4b8df69791c71e 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -1223,77 +1223,6 @@ TEST(CommandBufferThunkTest, CaseCmd) { ASSERT_EQ(dst, std::vector(4, 2 * (42 + 42))); } -TEST(CommandBufferThunkTest, ForCmd) { - se::StreamExecutor* executor = GpuExecutor(); - - if (!IsAtLeastCuda12300(executor)) { - GTEST_SKIP() << "CUDA graph conditionals are not supported"; - } - - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: loop_cnt=0, a=1, b=0 - se::DeviceMemory loop_cnt = executor->AllocateArray(1, 0); - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); - - TF_ASSERT_OK(stream->Memset32(&loop_cnt, 0, sizeof(int32_t))); - TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); - TF_ASSERT_OK(stream->MemZero(&b, byte_length)); - - // Prepare buffer allocations for recording command buffer. 
- BufferAllocation alloc_cnt(/*index=*/0, 1, /*color=*/0); - BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); - BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); - - BufferAllocation::Slice slice_cnt(&alloc_cnt, 0, sizeof(int32_t)); - BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); - BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); - - auto args = {slice_a, slice_b, slice_b}; // b = a + b - auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, - MemoryAccess::kWrite}; - - // Prepare commands sequence for loop `body`. - CommandBufferCmdSequence body_commands; - body_commands.Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - - // Prepare commands sequence for thunk. - CommandBufferCmdSequence commands; - commands.Emplace(s0, /*num_iterations=*/10, slice_cnt, - std::move(body_commands)); - - // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); - - ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); - BufferAllocations allocations({loop_cnt, a, b}, 0, &allocator); - - Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( - run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - - TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); - TF_ASSERT_OK( - thunk.Initialize({executor, static_cast(source), - &allocations, stream.get()})); - - // Execute command buffer thunk and verify that it added the value 10 times. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - TF_ASSERT_OK(stream->BlockHostUntilDone()); - - // Copy `b` data back to host. - std::vector dst(4, 0); - TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); - - ASSERT_EQ(dst, std::vector(4, 10)); -} - TEST(CommandBufferThunkTest, WhileCmd) { // TODO(ezhulenev): Find a way to test WhileCmd: add a test only TraceCmd that // could allow us trace custom kernels to update while loop iterations. 
Or diff --git a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc index f16683b86629d9..f7c6f95c016f54 100644 --- a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc +++ b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc @@ -196,71 +196,6 @@ TEST_F(CommandBufferTest, IndexConditional) { } } -TEST_F(CommandBufferTest, DISABLED_ForLoop) { - constexpr absl::string_view hlo_text = R"( - HloModule m, is_scheduled=true - - compare_fusion { - p0 = s32[] parameter(0) - ten = s32[] constant(10) - ROOT compare = compare(p0, ten), direction=LT - } - - add_one { - p0 = s32[] parameter(0) - one = s32[] constant(1) - ROOT add = add(p0, one) - } - - add_two { - p0 = f32[] parameter(0) - two = f32[] constant(2.0) - ROOT add = add(p0, two) - } - - body { - p0 = (s32[], f32[]) parameter(0) - cnt = get-tuple-element(p0), index=0 - val = get-tuple-element(p0), index=1 - add_cnt = s32[] fusion(cnt), kind=kLoop, calls=add_one - add_val = f32[] fusion(val), kind=kLoop, calls=add_two - ROOT tuple = (s32[], f32[]) tuple(add_cnt, add_val) - } - - cond { - p0 = (s32[], f32[]) parameter(0) - cnt = get-tuple-element(p0), index=0 - ROOT compare = pred[] fusion(cnt), kind=kLoop, calls=compare_fusion - } - - command_buffer { - p0 = (s32[], f32[]) parameter(0) - ROOT while = while(p0), condition=cond, body=body, - backend_config={"known_trip_count":{"n":"20"}} - } - - ENTRY main { - p0 = (s32[], f32[]) parameter(0) - ROOT call = (s32[], f32[]) call(p0), to_apply=command_buffer - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); - - Literal cnt = LiteralUtil::CreateR0(0); - Literal value = LiteralUtil::CreateR0(0.0); - Literal argument = LiteralUtil::MakeTuple({&cnt, &value}); - - // Because we set the known trip count to 20, the loop will execute 20 times, - // and it will ignore the `cond` result, which would terminate the loop after - // 10 iterations (see WhileLoop test below that runs the real while loop). - Literal expected_cnt = LiteralUtil::CreateR0(20); - Literal expected_value = LiteralUtil::CreateR0(40.0); - Literal expected = LiteralUtil::MakeTuple({&expected_cnt, &expected_value}); - - Literal result = ExecuteNoHloPasses(std::move(module), {&argument}); - EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); -} - TEST_F(CommandBufferTest, WhileLoop) { constexpr absl::string_view hlo_text = R"( HloModule m, is_scheduled=true diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index ffb86bbeb34525..36ead8cc30de0a 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -189,14 +189,6 @@ class CommandBuffer { virtual absl::Status Case(const Command* command, DeviceMemory index, std::vector branches) = 0; - // Adds a conditional operation that will execute a command buffer constructed - // by the `body_builder` exactly `num_iteration` times. This means the - // condition is known at compile time (`num_iteration` < `loop_counter`), and - // does not require a `cond_builder`. 
- virtual absl::Status For(int32_t num_iteration, - DeviceMemory loop_counter, - Builder body_builder) = 0; - // Adds a conditional operation that will execute a command buffer constructed // by the `cond_builder` that must update `pred` value, and then depending on // the value might execute command buffer constructed by `body_builder` and diff --git a/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc index cb5ca81fa58f67..5e44911065b919 100644 --- a/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc +++ b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc @@ -19,8 +19,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/stream_executor/kernel_spec.h" -namespace stream_executor { -namespace cuda { +namespace stream_executor::cuda { namespace { // Collection of helper kernels required by command buffers on CUDA backends. We @@ -35,197 +34,6 @@ namespace { // want to execute a CUDA graph tied to it, and to `0` otherwise. For loops, the // graph will keep being executed until the conditional handle becomes `0`. -// PTX kernel compiled from: -// -// __global__ void SetIfCondition(cudaGraphConditionalHandle then_handle, -// bool* predicate) { -// if (*predicate) { -// cudaGraphSetConditional(then_handle, 1); -// } else { -// cudaGraphSetConditional(then_handle, 0); -// } -// } -// -// Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr absl::string_view kSetIfConditionKernel = R"( -.version 4.0 -.target sm_50 -.address_size 64 - -.extern .func cudaGraphSetConditional -( - .param .b64 cudaGraphSetConditional_param_0, - .param .b32 cudaGraphSetConditional_param_1 -) - -.visible .entry set_if_condition( - .param .u64 set_if_condition_param_0, - .param .u64 set_if_condition_param_1 -) -{ - .reg .pred %p<2>; - .reg .b16 %rs<2>; - .reg .b64 %rd<4>; - .loc 1 1 0 - - ld.param.u64 %rd1, [set_if_condition_param_0]; - ld.param.u64 %rd2, [set_if_condition_param_1]; - .loc 1 3 3 - cvta.to.global.u64 %rd3, %rd2; - ld.global.u8 %rs1, [%rd3]; - setp.eq.s16 %p1, %rs1, 0; - @%p1 bra $L__BB0_2; - - .loc 1 4 5 - { // callseq 0, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd1; - .param .b32 param1; - st.param.b32 [param1+0], 1; - call.uni - cudaGraphSetConditional, - ( - param0, - param1 - ); - } // callseq 0 - bra.uni $L__BB0_3; - -$L__BB0_2: - .loc 1 6 5 - { // callseq 1, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd1; - .param .b32 param1; - st.param.b32 [param1+0], 0; - call.uni - cudaGraphSetConditional, - ( - param0, - param1 - ); - } // callseq 1 - -$L__BB0_3: - .loc 1 8 1 - ret; - -})"; - -// PTX kernel compiled from: -// -// __global__ void SetIfElseCondition(cudaGraphConditionalHandle then_handle, -// cudaGraphConditionalHandle else_handle, -// bool* predicate) { -// if (*predicate) { -// cudaGraphSetConditional(then_handle, 1); -// cudaGraphSetConditional(else_handle, 0); -// } else { -// cudaGraphSetConditional(then_handle, 0); -// cudaGraphSetConditional(else_handle, 1); -// } -// } -// -// Easiest way to get PTX from C++ is to use https://godbolt.org. 
-inline constexpr absl::string_view kSetIfElseConditionKernel = R"( -.version 4.0 -.target sm_50 -.address_size 64 - -.extern .func cudaGraphSetConditional -( - .param .b64 cudaGraphSetConditional_param_0, - .param .b32 cudaGraphSetConditional_param_1 -) - -.visible .entry set_if_else_condition( - .param .u64 set_if_else_condition_param_0, - .param .u64 set_if_else_condition_param_1, - .param .u64 set_if_else_condition_param_2 -) -{ - .reg .pred %p<2>; - .reg .b16 %rs<2>; - .reg .b64 %rd<5>; - .loc 1 1 0 - - ld.param.u64 %rd1, [set_if_else_condition_param_0]; - ld.param.u64 %rd2, [set_if_else_condition_param_1]; - ld.param.u64 %rd3, [set_if_else_condition_param_2]; - .loc 1 4 3 - cvta.to.global.u64 %rd4, %rd3; - ld.global.u8 %rs1, [%rd4]; - setp.eq.s16 %p1, %rs1, 0; - @%p1 bra $L__BB0_2; - - .loc 1 5 5 - { // callseq 0, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd1; - .param .b32 param1; - st.param.b32 [param1+0], 1; - call.uni - cudaGraphSetConditional, - ( - param0, - param1 - ); - } // callseq 0 - .loc 1 6 5 - { // callseq 1, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd2; - .param .b32 param1; - st.param.b32 [param1+0], 0; - call.uni - cudaGraphSetConditional, - ( - param0, - param1 - ); - } // callseq 1 - bra.uni $L__BB0_3; - -$L__BB0_2: - .loc 1 8 5 - { // callseq 2, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd1; - .param .b32 param1; - st.param.b32 [param1+0], 0; - call.uni - cudaGraphSetConditional, - ( - param0, - param1 - ); - } // callseq 2 - .loc 1 9 5 - { // callseq 3, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd2; - .param .b32 param1; - st.param.b32 [param1+0], 1; - call.uni - cudaGraphSetConditional, - ( - param0, - param1 - ); - } // callseq 3 - -$L__BB0_3: - .loc 1 11 1 - ret; - -})"; - // clang-format off // PTX kernel compiled from: // @@ -390,10 +198,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd5; .param .b32 param1; st.param.b32 [param1+0], 0; - call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 0 @@ -406,10 +214,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd5; .param .b32 param1; st.param.b32 [param1+0], 1; - call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 1 @@ -428,10 +236,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd6; .param .b32 param1; st.param.b32 [param1+0], 1; - call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 3 @@ -444,10 +252,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd6; .param .b32 param1; st.param.b32 [param1+0], 0; - call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 2 @@ -466,10 +274,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd7; .param .b32 param1; st.param.b32 [param1+0], 1; - call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 5 @@ -482,10 +290,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd7; .param .b32 param1; st.param.b32 [param1+0], 0; - 
call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 4 @@ -504,10 +312,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd8; .param .b32 param1; st.param.b32 [param1+0], 1; - call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 7 @@ -520,10 +328,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd8; .param .b32 param1; st.param.b32 [param1+0], 0; - call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 6 @@ -556,10 +364,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd12; .param .b32 param1; st.param.b32 [param1+0], 0; - call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 8 @@ -572,10 +380,10 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( st.param.b64 [param0+0], %rd12; .param .b32 param1; st.param.b32 [param1+0], 1; - call.uni - cudaGraphSetConditional, + call.uni + cudaGraphSetConditional, ( - param0, + param0, param1 ); } // callseq 9 @@ -592,95 +400,6 @@ inline constexpr absl::string_view kSetCaseConditionKernel = R"( })"; -// PTX kernel compiled from: -// -// __global__ void SetForCondition(cudaGraphConditionalHandle handle, -// int32_t* loop_index, -// int32_t num_iterations) { -// if (*loop_index < num_iterations) { -// cudaGraphSetConditional(handle, 1); -// } else { -// cudaGraphSetConditional(handle, 0); -// } -// *loop_index += 1; -// } -// -// Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr absl::string_view kSetForConditionKernel = R"( -.version 4.0 -.target sm_50 -.address_size 64 - -.extern .func cudaGraphSetConditional -( - .param .b64 cudaGraphSetConditional_param_0, - .param .b32 cudaGraphSetConditional_param_1 -) - -.visible .entry set_for_condition( - .param .u64 set_for_condition_param_0, - .param .u64 set_for_condition_param_1, - .param .u32 set_for_condition_param_2 -) -{ - .reg .pred %p<2>; - .reg .b32 %r<5>; - .reg .b64 %rd<4>; - .loc 1 1 0 - - ld.param.u64 %rd2, [set_for_condition_param_0]; - ld.param.u64 %rd3, [set_for_condition_param_1]; - ld.param.u32 %r1, [set_for_condition_param_2]; - .loc 1 3 3 - cvta.to.global.u64 %rd1, %rd3; - ld.global.u32 %r2, [%rd1]; - setp.lt.s32 %p1, %r2, %r1; - @%p1 bra $L__BB0_2; - bra.uni $L__BB0_1; - -$L__BB0_2: - .loc 1 4 5 - { // callseq 1, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd2; - .param .b32 param1; - st.param.b32 [param1+0], 1; - call.uni - cudaGraphSetConditional, - ( - param0, - param1 - ); - } // callseq 1 - bra.uni $L__BB0_3; - -$L__BB0_1: - .loc 1 6 5 - { // callseq 0, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd2; - .param .b32 param1; - st.param.b32 [param1+0], 0; - call.uni - cudaGraphSetConditional, - ( - param0, - param1 - ); - } // callseq 0 - -$L__BB0_3: - .loc 1 8 3 - ld.global.u32 %r3, [%rd1]; - add.s32 %r4, %r3, 1; - st.global.u32 [%rd1], %r4; - .loc 1 9 1 - ret; - -})"; - // While condition kernel is the same as an `If` with a single branch. 
inline constexpr absl::string_view kSetWhileConditionKernel = R"( .version 4.0 @@ -771,31 +490,12 @@ inline constexpr absl::string_view kNoOpKernel = R"( } // namespace -absl::StatusOr GetSetIfConditionKernelLoaderSpec() { - MultiKernelLoaderSpec spec(/*arity=*/2); - spec.AddCudaPtxInMemory(cuda::kSetIfConditionKernel, "set_if_condition"); - return spec; -} - -absl::StatusOr GetSetIfElseConditionKernelLoaderSpec() { - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddCudaPtxInMemory(cuda::kSetIfElseConditionKernel, - "set_if_else_condition"); - return spec; -} - absl::StatusOr GetSetCaseConditionKernelLoaderSpec() { MultiKernelLoaderSpec spec(/*arity=*/13); spec.AddCudaPtxInMemory(cuda::kSetCaseConditionKernel, "set_case_condition"); return spec; } -absl::StatusOr GetSetForConditionKernelLoaderSpec() { - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddCudaPtxInMemory(cuda::kSetForConditionKernel, "set_for_condition"); - return spec; -} - absl::StatusOr GetSetWhileConditionKernelLoaderSpec() { MultiKernelLoaderSpec spec(/*arity=*/2); spec.AddCudaPtxInMemory(cuda::kSetWhileConditionKernel, @@ -809,5 +509,4 @@ absl::StatusOr GetNoOpKernelLoaderSpec() { return spec; } -} // namespace cuda -} // namespace stream_executor +} // namespace stream_executor::cuda diff --git a/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.h b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.h index a610b3ebf4f4be..818b2d96cbb8ae 100644 --- a/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.h +++ b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.h @@ -24,10 +24,7 @@ namespace stream_executor::cuda { // These are various kernels that update Gpu conditionals based on the device // memory values, and allow implementing on-device control flow via conditional // command buffers. -absl::StatusOr GetSetIfConditionKernelLoaderSpec(); -absl::StatusOr GetSetIfElseConditionKernelLoaderSpec(); absl::StatusOr GetSetCaseConditionKernelLoaderSpec(); -absl::StatusOr GetSetForConditionKernelLoaderSpec(); absl::StatusOr GetSetWhileConditionKernelLoaderSpec(); absl::StatusOr GetNoOpKernelLoaderSpec(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc index 6b955a956afab3..550e8732859010 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc @@ -146,32 +146,6 @@ absl::StatusOr> CudaCommandBuffer::Create( // APIs for launching kernels to update conditional handles. 
//===----------------------------------------------------------------------===// -absl::StatusOr CudaCommandBuffer::CreateSetForConditionNode( - GraphConditionalHandle conditional, DeviceMemory loop_counter, - int32_t iterations, absl::Span dependencies) { - if (!set_for_condition_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, cuda::GetSetForConditionKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN( - set_for_condition_kernel_, - SetForConditionKernel::FactoryType::Create(parent_, spec)); - } - auto kernel_args = - PackKernelArgs(set_for_condition_kernel_, ToCudaGraphHandle(conditional), - loop_counter, iterations); - return CreateKernelNode(dependencies, ThreadDim(), BlockDim(), - *set_for_condition_kernel_, *kernel_args); -} - -absl::Status CudaCommandBuffer::UpdateSetForConditionNode( - GraphNodeHandle handle, GraphConditionalHandle conditional, - DeviceMemory loop_counter, int32_t iterations) { - auto kernel_args = - PackKernelArgs(set_for_condition_kernel_, ToCudaGraphHandle(conditional), - loop_counter, iterations); - return UpdateKernelNode(handle, ThreadDim(), BlockDim(), - *set_for_condition_kernel_, *kernel_args); -} - absl::StatusOr CudaCommandBuffer::CreateSetWhileConditionNode( GraphConditionalHandle conditional, DeviceMemory predicate, absl::Span dependencies) { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h index 2d0f6871046fec..5dc8e494e8c6e8 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h @@ -85,16 +85,6 @@ class CudaCommandBuffer final : public GpuCommandBuffer { DeviceMemory index, bool index_is_bool, int32_t batch_offset, bool enable_conditional_default) override; - absl::StatusOr CreateSetForConditionNode( - GraphConditionalHandle conditional, DeviceMemory loop_counter, - int32_t iterations, - absl::Span dependencies) override; - - absl::Status UpdateSetForConditionNode(GraphNodeHandle handle, - GraphConditionalHandle conditional, - DeviceMemory loop_counter, - int32_t iterations) override; - absl::StatusOr CreateSetWhileConditionNode( GraphConditionalHandle conditional, DeviceMemory predicate, absl::Span dependencies) override; @@ -188,9 +178,6 @@ class CudaCommandBuffer final : public GpuCommandBuffer { CUgraphConditionalHandle, CUgraphConditionalHandle, DeviceMemory, bool, int32_t, int32_t, bool>; - using SetForConditionKernel = - TypedKernel, int32_t>; - using SetWhileConditionKernel = TypedKernel>; @@ -198,7 +185,6 @@ class CudaCommandBuffer final : public GpuCommandBuffer { // barriers, updating conditional handles, etc.). 
NoOpKernel noop_kernel_; SetCaseConditionKernel set_case_condition_kernel_; - SetForConditionKernel set_for_condition_kernel_; SetWhileConditionKernel set_while_condition_kernel_; StreamExecutor* parent_; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index c007e183c3426f..06c4e3643ff11f 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -129,10 +129,6 @@ GpuCommandBuffer::Dependencies GpuCommandBuffer::GetAutoDependencies() const { return dependencies; } - if (auto* gpu_command = dynamic_cast(command)) { - return Dependencies{gpu_command->conditional_node.handle}; - } - if (auto* gpu_command = dynamic_cast(command)) { return Dependencies{gpu_command->conditional_node.handle}; } @@ -507,65 +503,6 @@ absl::Status GpuCommandBuffer::Case(const Command* command, /*index_is_bool=*/true, branches); } -absl::Status GpuCommandBuffer::For(int32_t num_iteration, - DeviceMemory loop_counter, - Builder body_builder) { - if (state_ == State::kCreate) { - GpuForCommand command = {}; - - // Reset loop counter to zero. - TF_ASSIGN_OR_RETURN( - command.memset_node, - CreateMemsetNode(GetAutoDependencies(), loop_counter, uint32_t{0}, 1)); - - TF_ASSIGN_OR_RETURN(command.conditional, CreateConditionalHandle()); - TF_ASSIGN_OR_RETURN( - command.set_init_condition_node, - CreateSetForConditionNode(command.conditional, loop_counter, - num_iteration, {command.memset_node})); - TF_ASSIGN_OR_RETURN( - command.conditional_node, - CreateConditionalNode({command.set_init_condition_node}, - command.conditional, ConditionType::kWhile)); - - GpuCommandBuffer* body = command.conditional_node.command_buffer.get(); - TF_RETURN_IF_ERROR(body_builder(body)); - TF_ASSIGN_OR_RETURN(command.set_body_condition_node, - body->CreateSetForConditionNode( - command.conditional, loop_counter, num_iteration, - body->GetAutoDependencies())); - TF_RETURN_IF_ERROR(command.conditional_node.command_buffer->Finalize()); - - AppendCommand(std::move(command)); - return absl::OkStatus(); - } - - if (state_ == State::kUpdate) { - Command& command = *commands_[update_state_.command_idx++]; - auto* gpu_command = tsl::down_cast(&command); - - // Reset loop counter to zero. - TF_RETURN_IF_ERROR(UpdateMemsetNode(gpu_command->memset_node, loop_counter, - uint32_t{0}, 1)); - TF_RETURN_IF_ERROR(UpdateSetForConditionNode( - gpu_command->set_init_condition_node, gpu_command->conditional, - loop_counter, num_iteration)); - - GpuCommandBuffer* body = gpu_command->conditional_node.command_buffer.get(); - auto body_update_mode = ActivateUpdateMode(body); - - // Update command buffer using user-provided builder callback. 
- TF_RETURN_IF_ERROR(body->Update()); - TF_RETURN_IF_ERROR(body_builder(body)); - TF_RETURN_IF_ERROR(body->UpdateSetForConditionNode( - gpu_command->set_body_condition_node, gpu_command->conditional, - loop_counter, num_iteration)); - TF_RETURN_IF_ERROR(body->Finalize()); - } - - return UnsupportedStateError(state_); -} - absl::Status GpuCommandBuffer::While(DeviceMemory pred, Builder cond_builder, Builder body_builder) { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 0a46b2baab2ed1..44c3538724ed85 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -112,15 +112,6 @@ class GpuCommandBuffer : public CommandBuffer { GraphNodeHandle barrier_node; }; - // A GPU command recorded for the For operation. - struct GpuForCommand : public CommandBuffer::Command { - GraphConditionalHandle conditional; - GraphNodeHandle memset_node; - GraphNodeHandle set_init_condition_node; - GraphNodeHandle set_body_condition_node; - GraphConditionalNodeHandle conditional_node; - }; - // A GPU command recorded for the While operation. struct GpuWhileCommand : public CommandBuffer::Command { GraphConditionalHandle conditional; @@ -180,9 +171,6 @@ class GpuCommandBuffer : public CommandBuffer { absl::Status Case(const Command* command, DeviceMemory index, std::vector branches) override; - absl::Status For(int32_t num_iteration, DeviceMemory loop_counter, - Builder body_builder) override; - absl::Status While(DeviceMemory pred, Builder cond_builder, Builder body_builder) override; @@ -246,18 +234,6 @@ class GpuCommandBuffer : public CommandBuffer { DeviceMemory index, bool index_is_bool, int32_t batch_offset, bool enable_conditional_default) = 0; - // Launches a kernel that updates the state of the given graph conditional - // based on the loop counter and the total number of iterations. If the loop - // counter is less than the number of iterations, `conditional` is set to 1, - // otherwise to 0. The loop counter is also incremented by 1. - virtual absl::StatusOr CreateSetForConditionNode( - GraphConditionalHandle conditional, DeviceMemory loop_counter, - int32_t iterations, absl::Span dependencies) = 0; - - virtual absl::Status UpdateSetForConditionNode( - GraphNodeHandle handle, GraphConditionalHandle conditional, - DeviceMemory loop_counter, int32_t iterations) = 0; - // Launches a kernel that updates the state of the given graph conditional // based on the predicate. If the predicate is true, `conditional` is set to // 1, otherwise to 0. 
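As context for the command buffer tests below, the conditional-handle protocol
that the remaining While command relies on can be modeled in plain host-side
C++ (hypothetical stand-in types; no GPU required): a set-condition kernel
writes 0 or 1 into the handle, and a conditional node of kind kWhile replays
its child graph while the handle stays non-zero.

#include <cstdint>
#include <cstdio>

// Stand-in for a CUDA graph conditional handle (hypothetical).
struct ConditionalHandle {
  uint32_t value = 0;
};

// Mirrors the effect of cudaGraphSetConditional(handle, predicate ? 1 : 0).
void SetWhileCondition(ConditionalHandle& handle, bool predicate) {
  handle.value = predicate ? 1u : 0u;
}

int main() {
  ConditionalHandle handle;
  int32_t counter = 0;

  // The cond commands run once up front and seed the handle.
  SetWhileCondition(handle, counter < 3);

  // The kWhile conditional node replays its child graph while the handle is
  // non-zero; the body re-runs the condition commands at its tail.
  while (handle.value != 0) {
    ++counter;  // body_builder commands
    SetWhileCondition(handle, counter < 3);
  }
  std::printf("ran body %d times\n", static_cast<int>(counter));
  return 0;
}

This is also why WhileCmd needs a cond_builder while the removed ForCmd did
not: the condition computation is the only thing that can stop the replay.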
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index 9dd6d363fce79b..7945a424c6b7fb 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -629,56 +629,6 @@ TEST(GpuCommandBufferTest, ConditionalCase) { ASSERT_EQ(dst, expected_mul); } -TEST(GpuCommandBufferTest, ConditionalFor) { - Platform* platform = GpuPlatform(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - if (!IsAtLeastCuda12300(executor)) { - GTEST_SKIP() << "CUDA graph conditionals are not supported"; - } - - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); - TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=0, loop_counter=100 - DeviceMemory loop_counter = executor->AllocateArray(1, 0); - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - - // Set loop counter to 100 to check that command buffer resets it. - TF_ASSERT_OK(stream->Memset32(&loop_counter, 100, sizeof(int32_t))); - TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); - TF_ASSERT_OK(stream->MemZero(&b, byte_length)); - - // Loop body: b = a + b - CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { - return body_cmd->Launch(add, ThreadDim(), BlockDim(4), {}, a, b, b) - .status(); - }; - - int32_t num_iters = 10; - - // Create a command buffer with a single conditional operation. - auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); - TF_ASSERT_OK(cmd_buffer->For(num_iters, loop_counter, body_builder)); - TF_ASSERT_OK(cmd_buffer->Finalize()); - - TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); - - // Copy `b` data back to host. 
- std::vector dst(4, 42); - TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); - - std::vector expected = {10, 10, 10, 10}; - ASSERT_EQ(dst, expected); -} - TEST(GpuCommandBufferTest, ConditionalWhile) { Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc index efba55a6d439b7..a989c2f5797ca3 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc @@ -124,18 +124,6 @@ absl::Status RocmCommandBuffer::UpdateSetCaseConditionNode( return absl::UnimplementedError("Conditionals are not supported on ROCM."); } -absl::StatusOr RocmCommandBuffer::CreateSetForConditionNode( - GraphConditionalHandle conditional, DeviceMemory loop_counter, - int32_t iterations, absl::Span dependencies) { - return absl::UnimplementedError("Conditionals are not supported on ROCM."); -} - -absl::Status RocmCommandBuffer::UpdateSetForConditionNode( - GraphNodeHandle handle, GraphConditionalHandle conditional, - DeviceMemory loop_counter, int32_t iterations) { - return absl::UnimplementedError("Conditionals are not supported on ROCM."); -} - absl::StatusOr RocmCommandBuffer::CreateSetWhileConditionNode( GraphConditionalHandle conditional, DeviceMemory predicate, absl::Span dependencies) { diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h index 6ea1cde66b2275..299fcf93d66fdc 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h @@ -72,16 +72,6 @@ class RocmCommandBuffer : public GpuCommandBuffer { DeviceMemory index, bool index_is_bool, int32_t batch_offset, bool enable_conditional_default) override; - absl::StatusOr CreateSetForConditionNode( - GraphConditionalHandle conditional, DeviceMemory loop_counter, - int32_t iterations, - absl::Span dependencies) override; - - absl::Status UpdateSetForConditionNode(GraphNodeHandle handle, - GraphConditionalHandle conditional, - DeviceMemory loop_counter, - int32_t iterations) override; - absl::StatusOr CreateSetWhileConditionNode( GraphConditionalHandle conditional, DeviceMemory predicate, absl::Span dependencies) override; From 1dace42ac70daa1661563137a45da88badb915b8 Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Mon, 31 Mar 2025 14:22:01 -0700 Subject: [PATCH 0058/1324] Add embedding_lookup composite dynamic lowering pattern PiperOrigin-RevId: 742398819 --- .../stablehlo/transforms/composite_lowering_patterns.td | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index b7e1f252507035..47a3693ab3961b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -152,6 +152,15 @@ def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped : Pat< (HasRankAtLeast<2> $table)]>; def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped2 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $indices, $table), + ConstantStrAttr, $attrs, $_, $_), + (TFL_EmbeddingLookupOp $indices, $table), + [(HasRank<1> $indices), + (I32ElementsVal 
$indices), + (HasRankAtLeast<2> $table)]>; + +def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped3 : Pat< (MHLO_CompositeOp:$composite (variadic $_, $indices, $table), ConstantStrAttr, $attrs, $_, $_), From 533386299e701049bc03b7714e62d628e5cb4563 Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Mon, 31 Mar 2025 15:02:34 -0700 Subject: [PATCH 0059/1324] add SparseCore Reshard PiperOrigin-RevId: 742413566 --- .../python/checkpoint/checkpoint_adapter.py | 2 +- tensorflow/python/tpu/BUILD | 5 - .../tpu_embedding_v3_checkpoint_adapter.py | 337 ++++++++++++++---- ...pu_embedding_v3_checkpoint_adapter_test.py | 170 ++++++++- 4 files changed, 434 insertions(+), 80 deletions(-) diff --git a/tensorflow/python/checkpoint/checkpoint_adapter.py b/tensorflow/python/checkpoint/checkpoint_adapter.py index b0e8b02beeff1b..f599d11ed2567b 100644 --- a/tensorflow/python/checkpoint/checkpoint_adapter.py +++ b/tensorflow/python/checkpoint/checkpoint_adapter.py @@ -65,7 +65,7 @@ def update_restore_inputs( Override this method if the arguments to restore op need to be updated as per the resharding required. Args: - checkpoint_key: The cehckpopoint key as requested by the caller + checkpoint_key: The checkpoint key as requested by the caller shape_and_slice_spec: The shape and slice spec as requested by caller Returns: diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index 4ea628c24ba86d..9a19ae4eaeaf1c 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -810,13 +810,8 @@ pytype_strict_library( ":tpu_embedding_v3_utils", "//tensorflow/core/tpu/kernels:sparse_core_layout_proto_py", "//tensorflow/python/checkpoint:checkpoint_adapter", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:manip_ops", - "//tensorflow/python/ops:variables", "//tensorflow/python/trackable:base", "//tensorflow/python/training:py_checkpoint_reader", "//tensorflow/python/util/protobuf", diff --git a/tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py b/tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py index bcf545b2674100..77fe6df57bb859 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +++ b/tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py @@ -15,8 +15,8 @@ """Checkpoint adapter for TPUEmbedding.""" import collections - -from typing import Mapping, Sequence, Optional +import time +from typing import Mapping, Optional, Sequence from absl import logging @@ -52,7 +52,131 @@ def _shard_info_str(shape, shard_info) -> str: return full_shape_str + slice_spec -class EmbeddingReshardCallback(checkpoint_adapter.ReshardCallback): +def _shard_from_cpu_to_sc( + feature_values: tensor.Tensor, + shape_and_slice: str, + to_shard_layout: Sequence[sparse_core_layout_pb2.SparseCoreTableLayout], +) -> tensor.Tensor: + """Shards the feature tables from CPU to SparseCore.""" + + def pad_value(value, variable_shape, table_shape): + return array_ops.pad( + value, + [ + [0, variable_shape[0] - table_shape[0]], + [0, variable_shape[1] - table_shape[1]], + ], + "CONSTANT", + ) + + var_full_shape, shard_info = _parse_shard_info_str(shape_and_slice) + if shard_info.offset > var_full_shape: + raise ValueError( + "Invalid shard offset: {}. 
Offset should be less than the full shape" + " of the variable: {}".format( + shard_info.offset, + var_full_shape, + ) + ) + num_sc_per_partition = ( + to_shard_layout[0].num_sparse_cores // to_shard_layout[0].num_partitions + ) + + total_rows_per_sc = to_shard_layout[0].total_rows_per_sparse_core_shard + total_rows_per_partition = total_rows_per_sc * num_sc_per_partition + full_values = {} + if (shard_info.shape[0] % total_rows_per_partition) != 0: + raise ValueError( + "Invalid shard shape: {}. Number of rows in input shard slice should" + " be multiple of number of rows in a partition({})".format( + shard_info.shape, + total_rows_per_partition, + ) + ) + # From the shard info, get the row offsets corresponding to the slice + # being looked up. + required_shard_offsets = range( + shard_info.offset[0], + shard_info.offset[0] + shard_info.shape[0], + total_rows_per_partition, + ) + output_shards = [] + for required_shard_offset in required_shard_offsets: + sharded_tensors = [] + for i in range(num_sc_per_partition): + shard_idx = (required_shard_offset // total_rows_per_sc) + i + for table_idx, layout in enumerate(to_shard_layout): + if table_idx not in full_values: + full_values[table_idx] = pad_value( + feature_values[table_idx], + layout.unsharded_padded_shape, + layout.unsharded_shape, + ) + + table_value = full_values[table_idx] + # Apply rotation to get this table's shard index + table_shard_offset = ( + shard_idx + + (layout.num_sparse_cores - layout.sparse_core_shard_rotation) + ) % layout.num_sparse_cores + sharded_tensors.append( + table_value[ + table_shard_offset :: layout.num_sparse_cores, + :, + ] + ) + output_shards.append(array_ops.concat(sharded_tensors, axis=0)) + logging.vlog( + 1, + "_shard_from_cpu_to_sc: last output_shards.shape: %s", + output_shards[-1].shape, + ) + return array_ops.concat(output_shards, axis=0) + + +def _unshard_from_sc_to_cpu( + stacked_table: tensor.Tensor, + from_shard_layouts: Sequence[sparse_core_layout_pb2.SparseCoreTableLayout], +) -> Sequence[tensor.Tensor]: + """Undo the shard the feature tables into SparseCore stacked table. + + Args: + stacked_table: The value of a SparseCore stacked and sharded table. + from_shard_layouts: The target layouts for the target hardware. + + Returns: + The unsharded feature tables. + """ + logging.vlog( + 1, + "To unshuffle_from_sc_to_cpu on stacked_table.shape: %s", + stacked_table[0].shape, + ) + ret_tensors = [] + + for layout in from_shard_layouts: + padded_table = tpu_embedding_v3_utils.unshuffle_from_sc_to_cpu( + stacked_table[0], + num_sparse_cores=layout.num_sparse_cores, + offset_in_shard=layout.sparse_core_shard_row_offset, + size_in_shard=layout.unsharded_padded_shape[0] + // layout.num_sparse_cores, + shard_rotation=layout.sparse_core_shard_rotation, + ) + + orig_table = tpu_embedding_v3_utils.remove_padding_from_sc( + padded_table, layout.unsharded_shape + ) + + logging.vlog( + 1, "orig_tensors.shape[%s]: %s", layout.table_name, orig_table.shape + ) + ret_tensors.append(orig_table) + + return ret_tensors + + +class EmbeddingUnshardToShardCallback(checkpoint_adapter.ReshardCallback): """Reshard callback for embeddings.""" def __init__( @@ -144,76 +268,135 @@ def reshard( Returns: The resharded tensor slice. 
""" - def pad_value(value, variable_shape, table_shape): - return array_ops.pad( - value, - [ - [0, variable_shape[0] - table_shape[0]], - [0, variable_shape[1] - table_shape[1]], - ], - "CONSTANT", - ) + return _shard_from_cpu_to_sc( + checkpoint_values, shape_and_slice, self._to_shard_layout + ) - var_full_shape, shard_info = _parse_shard_info_str(shape_and_slice) - if shard_info.offset > var_full_shape: - raise ValueError( - "Invalid shard offset: {}. Offset should be less than the full shape" - " of the variable: {}".format( - shard_info.offset, - var_full_shape, - ) - ) - num_sc_per_partition = ( - self._to_shard_layout[0].num_sparse_cores - // self._to_shard_layout[0].num_partitions + +class EmbeddingReshardCallback(checkpoint_adapter.ReshardCallback): + """Reshard callback for embeddings.""" + + def __init__( + self, + object_local_name: str, + from_shard_layouts: Sequence[ + sparse_core_layout_pb2.SparseCoreTableLayout + ], # table name to layout + to_shard_layouts: Sequence[ + sparse_core_layout_pb2.SparseCoreTableLayout + ], # table name to layout + ): + """Initializes Reshard callback. + + Args: + object_local_name: The local name of the object being restored. + from_shard_layouts: layouts as in checkpoint being restored from. + to_shard_layouts: target layouts as specified in the embedding being + restored. + """ + logging.info("Creating EmbeddingReshardCallback for %s", object_local_name) + self._object_local_name = object_local_name + self._from_shard_layouts = from_shard_layouts + self._to_shard_layouts = to_shard_layouts + + def object_name(self) -> str: + return self._object_local_name + + def update_restore_inputs( + self, checkpoint_key: str, shape_and_slice_spec: str + ) -> tuple[Sequence[str], Sequence[str]]: + """Return the full shape of the stacked that is passed into restore_v2. + + This shape information is required by the restore_v2 process to ensure it + loads the complete tensor from the checkpoint. The full tensor is required + to perform resharding operations. + + Args: + checkpoint_key: The input checkpoint key to be read. + shape_and_slice_spec: The shape and slice spec of the checkpoint key to be + read. + + Returns: + A tuple of (keys, slices) that should be passed to restore_v2 in order to + reshard according to the resharding plan. The restored tensors from + restore_v2 op will usually be passed to reshard method of this class to + get the final resharded value. + """ + logging.vlog( + 1, + "Updating restore v2 inputs for %s[%s]: %s", + checkpoint_key, + self._object_local_name, + shape_and_slice_spec, ) - total_rows_per_sc = self._to_shard_layout[ - 0 - ].total_rows_per_sparse_core_shard - total_rows_per_parition = total_rows_per_sc * num_sc_per_partition - full_values = {} - if (shard_info.shape[0] % total_rows_per_parition) != 0: - raise ValueError( - "Invalid shard shape: {}. Number of rows in input shard slice should" - " be multiple of number of rows in a partition({})".format( - shard_info.shape, - total_rows_per_parition, - ) - ) - # From the shard info, get the row offsets corresponding to the slice - # being looked up. 
- required_shard_offsets = range( - shard_info.offset[0], - shard_info.offset[0] + shard_info.shape[0], - total_rows_per_parition, + slices = [] + + # use the first layout get the full shape of the stacked table + first_layout = self._from_shard_layouts[0] + full_vocab_size = ( + first_layout.total_rows_per_sparse_core_shard + * first_layout.num_sparse_cores + ) + stack_dim = first_layout.unsharded_padded_shape[1] + full_shape = [full_vocab_size, stack_dim] + logging.vlog( + 1, + "Read checkpoint_key %s: %s", + checkpoint_key, + full_shape, ) - output_shards = [] - for required_shard_offset in required_shard_offsets: - sharded_tensors = [] - for i in range(num_sc_per_partition): - shard_idx = (required_shard_offset // total_rows_per_sc) + i - for table_idx, layout in enumerate(self._to_shard_layout): - if table_idx not in full_values: - full_values[table_idx] = pad_value( - checkpoint_values[table_idx], - layout.unsharded_padded_shape, - layout.unsharded_shape, - ) - table_value = full_values[table_idx] - # Apply rotation to get this table's shard index - table_shard_offset = ( - shard_idx - + (layout.num_sparse_cores - layout.sparse_core_shard_rotation) - ) % layout.num_sparse_cores - sharded_tensors.append( - table_value[ - table_shard_offset :: layout.num_sparse_cores, - :, - ] - ) - output_shards.append(array_ops.concat(sharded_tensors, axis=0)) - return array_ops.concat(output_shards, axis=0) + + slices.append( + _shard_info_str( + full_shape, + trackable_base.ShardInfo(offset=[0, 0], shape=full_shape), + ) + ) + return ([checkpoint_key], slices) + + def reshard( + self, checkpoint_values: tensor.Tensor, shape_and_slice: str + ) -> tensor.Tensor: + # unshard + stime = time.time() + logging.vlog( + 1, + "EmbeddingReshardCallback: starting to reshard [%s]", + self._object_local_name, + ) + unsharded_tensors = _unshard_from_sc_to_cpu( + checkpoint_values, self._from_shard_layouts + ) + + ret = _shard_from_cpu_to_sc( + unsharded_tensors, shape_and_slice, self._to_shard_layouts + ) + + etime = time.time() + logging.info( + "EmbeddingReshardCallback: reshard [%s] took %s", + self._object_local_name, + etime - stime, + ) + return ret + + +def _reorg_layouts( + layouts: Sequence[sparse_core_layout_pb2.SparseCoreTableLayout], +) -> Mapping[str, Sequence[sparse_core_layout_pb2.SparseCoreTableLayout]]: + """Reorg the layouts to be in the order of the logical table.""" + stacked_name_to_table_names = collections.defaultdict(list) + for layout in layouts: + stacked_name_to_table_names[layout.stacked_table_name].append(layout) + for stacked_name in stacked_name_to_table_names.keys(): + sorted_layouts = sorted( + stacked_name_to_table_names[stacked_name], + key=lambda layout: layout.sparse_core_shard_row_offset, + ) + stacked_name_to_table_names[stacked_name] = sorted_layouts + + return stacked_name_to_table_names class TpuEmbeddingV3CheckpointAdapter( @@ -273,7 +456,7 @@ def initialize_reshard_callbacks( ) logging.info("Creating resharding plan for %s", stacked_name) self._checkpoint_to_reshard_callback[sorted_layouts[0].table_name] = ( - EmbeddingReshardCallback( + EmbeddingUnshardToShardCallback( stacked_name, [l.table_name for l in sorted_layouts], sorted_layouts, @@ -284,8 +467,18 @@ def initialize_reshard_callbacks( if not embedding_layouts: # TODO(b/326644306): From sharded to unsharded raise NotImplementedError("Sharded to Unsharded is not implemented yet.") - # TODO(b/326644391): First unshard then shard. 
- raise NotImplementedError("Changing topology is not implemented yet.") + # Reshard to different SC Layout + from_layouts = _reorg_layouts(list(self._checkpoint_layouts.values())) + to_layouts = _reorg_layouts(list(embedding_layouts.values())) + for stacked_name in from_layouts.keys(): + logging.info("Creating resharding plan for %s", stacked_name) + self._checkpoint_to_reshard_callback[stacked_name] = ( + EmbeddingReshardCallback( + object_local_name=stacked_name, + from_shard_layouts=from_layouts[stacked_name], + to_shard_layouts=to_layouts[stacked_name], + ) + ) def is_layouts_same(self, embedding_layouts) -> bool: """Returns True if the all the embedding and checkpoint layouts are the same. diff --git a/tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter_test.py b/tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter_test.py index 40f72c7b77b157..c1320148b5afe0 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter_test.py @@ -14,16 +14,17 @@ # ============================================================================== """Tests for tpu_embedding_v3_checkpoint_adapter.""" - from tensorflow.core.tpu.kernels import sparse_core_layout_pb2 from tensorflow.python.compat import v2_compat +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework.constant_op import constant as tf_constant from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.tpu import tpu_embedding_v3_checkpoint_adapter +tf_constant = constant_op.constant + def create_layout( tables_name: str, @@ -288,6 +289,171 @@ def test_is_layouts_same_works(self): layout.num_sparse_cores = 3 self.assertFalse(adapter.is_layouts_same({layout.table_name: layout})) + def test_adapt_to_different_sharded_stacked(self): + + source_layouts = { + "one": create_layout( + tables_name="one", + stacked_table_name="one_two_three", + num_sparse_cores=4, + num_partitions=2, + unsharded_shape=(6, 5), + unsharded_padded_shape=(8, 8), + row_offset=0, + shard_rotation=0, + total_rows_per_sparse_core_shard=6, + ), + "two": create_layout( + tables_name="two", + stacked_table_name="one_two_three", + num_sparse_cores=4, + num_partitions=2, + unsharded_shape=(7, 4), + unsharded_padded_shape=(8, 8), + row_offset=2, + shard_rotation=1, + total_rows_per_sparse_core_shard=6, + ), + "three": create_layout( + tables_name="three", + stacked_table_name="one_two_three", + num_sparse_cores=4, + num_partitions=2, + unsharded_shape=(15, 3), + unsharded_padded_shape=(16, 8), + row_offset=4, + shard_rotation=2, + total_rows_per_sparse_core_shard=6, + ), + } + src_layouts_pb = sparse_core_layout_pb2.SparseCoreTableLayouts() + src_layouts_pb.tables.extend(source_layouts.values()) + + sc_to_sc_adapter = ( + tpu_embedding_v3_checkpoint_adapter.TpuEmbeddingV3CheckpointAdapter( + layouts=src_layouts_pb + ) + ) + + target_layouts = { + "one": create_layout( + tables_name="one", + stacked_table_name="one_two_three", + num_sparse_cores=8, + num_partitions=4, + unsharded_shape=(6, 5), + unsharded_padded_shape=(8, 8), + row_offset=0, + shard_rotation=0, + total_rows_per_sparse_core_shard=4, + ), + "two": create_layout( + tables_name="two", + stacked_table_name="one_two_three", + num_sparse_cores=8, + num_partitions=4, + unsharded_shape=(7, 4), + unsharded_padded_shape=(8, 8), + row_offset=1, + shard_rotation=1, + 
total_rows_per_sparse_core_shard=4,
+        ),
+        "three": create_layout(
+            tables_name="three",
+            stacked_table_name="one_two_three",
+            num_sparse_cores=8,
+            num_partitions=4,
+            unsharded_shape=(15, 3),
+            unsharded_padded_shape=(16, 8),
+            row_offset=2,
+            shard_rotation=2,
+            total_rows_per_sparse_core_shard=4,
+        ),
+    }
+
+    # this takes a mapping[str, sparse_core_layout_pb2.SparseCoreTableLayout]
+    sc_to_sc_adapter.initialize_reshard_callbacks(target_layouts)
+    callback = sc_to_sc_adapter.get_reshard_callback("one_two_three")
+    self.assertEqual(callback.object_name(), "one_two_three")
+    updated_keys, updated_slices = callback.update_restore_inputs(
+        "path/to/embedding/one_two/in/checkpoint", "24 8 6,12:0,8"
+    )
+    self.assertAllEqual(
+        updated_keys,
+        [
+            "path/to/embedding/one_two/in/checkpoint",
+        ],
+    )
+    self.assertAllEqual(
+        updated_slices,
+        ["24 8 0,24:0,8"],
+    )
+
+    one_two_three = tf_constant([
+        # table one shard 0
+        [0, 0, 0, 0, 0, 0, 0, 0],
+        [4, 4, 4, 4, 4, 0, 0, 0],
+        [13, 13, 13, 13, 0, 0, 0, 0],
+        [0, 0, 0, 0, 0, 0, 0, 0],
+        [102, 102, 102, 0, 0, 0, 0, 0],
+        [106, 106, 106, 0, 0, 0, 0, 0],
+        [110, 110, 110, 0, 0, 0, 0, 0],
+        [114, 114, 114, 0, 0, 0, 0, 0],
+        # table one shard 1
+        [1, 1, 1, 1, 1, 0, 0, 0],
+        [5, 5, 5, 5, 5, 0, 0, 0],
+        [10, 10, 10, 10, 0, 0, 0, 0],
+        [14, 14, 14, 14, 0, 0, 0, 0],
+        [103, 103, 103, 0, 0, 0, 0, 0],
+        [107, 107, 107, 0, 0, 0, 0, 0],
+        [111, 111, 111, 0, 0, 0, 0, 0],
+        [0, 0, 0, 0, 0, 0, 0, 0],
+        # table one shard 2
+        [2, 2, 2, 2, 2, 0, 0, 0],
+        [6, 6, 6, 6, 6, 0, 0, 0],
+        [11, 11, 11, 11, 0, 0, 0, 0],
+        [15, 15, 15, 15, 0, 0, 0, 0],
+        [100, 100, 100, 0, 0, 0, 0, 0],
+        [104, 104, 104, 0, 0, 0, 0, 0],
+        [108, 108, 108, 0, 0, 0, 0, 0],
+        [112, 112, 112, 0, 0, 0, 0, 0],
+        # table one shard 3
+        [3, 3, 3, 3, 3, 0, 0, 0],
+        [7, 7, 7, 7, 7, 0, 0, 0],
+        [12, 12, 12, 12, 0, 0, 0, 0],
+        [16, 16, 16, 16, 0, 0, 0, 0],
+        [101, 101, 101, 0, 0, 0, 0, 0],
+        [105, 105, 105, 0, 0, 0, 0, 0],
+        [109, 109, 109, 0, 0, 0, 0, 0],
+        [113, 113, 113, 0, 0, 0, 0, 0],
+    ])
+
+    self.assertAllEqual(
+        tf_constant([
+            # shard 2
+            [2, 2, 2, 2, 2, 0, 0, 0],
+            [11, 11, 11, 11, 0, 0, 0, 0],
+            [100, 100, 100, 0, 0, 0, 0, 0],
+            [108, 108, 108, 0, 0, 0, 0, 0],
+            # shard 3
+            [3, 3, 3, 3, 3, 0, 0, 0],
+            [12, 12, 12, 12, 0, 0, 0, 0],
+            [101, 101, 101, 0, 0, 0, 0, 0],
+            [109, 109, 109, 0, 0, 0, 0, 0],
+            # shard 4
+            [4, 4, 4, 4, 4, 0, 0, 0],
+            [13, 13, 13, 13, 0, 0, 0, 0],
+            [102, 102, 102, 0, 0, 0, 0, 0],
+            [110, 110, 110, 0, 0, 0, 0, 0],
+            # shard 5
+            [5, 5, 5, 5, 5, 0, 0, 0],
+            [14, 14, 14, 14, 0, 0, 0, 0],
+            [103, 103, 103, 0, 0, 0, 0, 0],
+            [111, 111, 111, 0, 0, 0, 0, 0],
+        ]),
+        callback.reshard([one_two_three], "32 8 8,16:0,8"),
+    )
+

 if __name__ == "__main__":
   v2_compat.enable_v2_behavior()

From 8e4e6b8da65dc77f23941a3a7520ca04e103f51b Mon Sep 17 00:00:00 2001
From: "Ryan M. Lefever"
Date: Mon, 31 Mar 2025 15:08:41 -0700
Subject: [PATCH 0060/1324] Fix a bug when we have 2 valid live
 AllocationValues for an HloValue. This comes up with asynchronous operations.
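The core of the fix is a no-copy chunk start-time adjustment: when earlier
AllocationValues of the same HloValue already hold the required
alternate-memory chunk over part of the new allocation's live range, the new
contiguous allocation only needs to claim the chunk from one past the latest
overlapping end time. A minimal C++ sketch of that arithmetic (simplified
types; the real code also filters overlapping allocations by memory space and
chunk equality and handles offset aliasing):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified stand-in for an MSA allocation's live range.
struct Allocation {
  int64_t start_time;
  int64_t end_time;
};

// Claim the chunk only from one past the latest overlapping end time.
int64_t NoCopyChunkStartTime(int64_t inclusive_start_time,
                             const std::vector<Allocation>& overlapping) {
  int64_t chunk_start_time = inclusive_start_time;
  for (const Allocation& allocation : overlapping) {
    chunk_start_time = std::max(chunk_start_time, allocation.end_time + 1);
  }
  return chunk_start_time;
}

int main() {
  // Two allocations from earlier AllocationValues (e.g., spanning an async
  // start/done pair) are still live at time 10 on the same chunk.
  std::vector<Allocation> overlapping = {{4, 12}, {8, 15}};
  std::printf("%lld\n", static_cast<long long>(
                            NoCopyChunkStartTime(10, overlapping)));
  // Prints 16: the chunk is claimed from time 16 onward, and the earlier
  // part of the live range is covered by aliasing with the live allocations.
  return 0;
}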
PiperOrigin-RevId: 742416086 --- .../xla/service/memory_space_assignment/BUILD | 2 + .../memory_space_assignment/algorithm.cc | 78 +++++- .../memory_space_assignment/algorithm.h | 21 +- .../allocation_value.h | 11 + .../memory_space_assignment_test.cc | 244 +++++++++++++++--- .../memory_space_assignment_test_base.h | 38 +++ 6 files changed, 356 insertions(+), 38 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index 156522a6ed6a4b..0a3650e19da5b6 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -188,7 +188,9 @@ cc_library( "//xla/service/cost_modelling:op_cost", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index 46d5360f847dda..914dea385720e4 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -627,6 +627,17 @@ void MsaAlgorithm::FindAliases( } } +std::string MsaAlgorithm::RequiredMemoryAssignment::ToString() const { + std::string memory_space_str = + memory_space == MemorySpace::kDefault ? "def" : "alt"; + std::string offset_str = + offset == nullptr ? "null" : absl::StrCat(offset->offset); + + return absl::StrCat( + "RequiredMemoryAssignment(memory_space=", memory_space_str, + ", time=", time, ", offset=", offset_str, ")"); +} + std::vector MsaAlgorithm::GetSortedColocatedIntervals( const MsaBufferInterval& interval) const { std::vector colocated_intervals; @@ -2646,7 +2657,8 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( preferred_offset_for_allocation_value.at(&allocation_value_to_update), definition_time_for_allocation_value.at(&allocation_value_to_update), RequiresNoCopyAlternateMemAllocation(allocation_value_to_update), - all_use_times, entry.only_extend_existing_allocation); + all_use_times, entry.only_extend_existing_allocation, + allocation_values.subspan(0, alloc_value_idx)); if (options_.allocation_request_modifier_testing_fn) { options_.allocation_request_modifier_testing_fn(request); } @@ -2822,7 +2834,8 @@ AllocationRequest MsaAlgorithm::CreateAllocationRequest( AliasedOffset* preferred_offset, int64_t definition_time, bool require_no_copy_alternate_mem_allocation, const std::vector& all_use_times, - bool only_extend_existing_allocation) { + bool only_extend_existing_allocation, + absl::Span processed_allocation_values) { const HloUse& hlo_use = use.hlo_use; const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); bool require_copy_allocation = false; @@ -3070,6 +3083,7 @@ AllocationRequest MsaAlgorithm::CreateAllocationRequest( request.end_time = use_time; request.only_extend_existing_allocation = only_extend_existing_allocation; + request.processed_allocation_values = processed_allocation_values; return request; } @@ -4598,6 +4612,60 @@ std::string MsaAlgorithm::ResultToString(const AllocationResult& result) { return result_str; } +void MsaAlgorithm::CheckAndUpdateForDualLiveAllocationValues( + const std::optional& + required_memory_assignment_at_start, + 
AllocationRequest& request) { + if (!request.allocation_value->requires_contiguous_allocation()) { + return; + } + if (!required_memory_assignment_at_start.has_value()) { + return; + } + if (required_memory_assignment_at_start->memory_space != + MemorySpace::kAlternate) { + return; + } + // Go through previous allocations, for the same HloValue, and check if they + // have already allocated alternate memory at the beginning of the current + // AllocationValue, such that we are required to use the same heap offset. + std::vector overlapping_allocations; + Chunk required_chunk = Chunk::FromOffsetSize( + required_memory_assignment_at_start->offset->offset, request.size); + for (const AllocationValue& processed_allocation_value : + request.processed_allocation_values) { + for (const std::unique_ptr& allocation : + *processed_allocation_value.allocation_sequence()) { + if (allocation->is_in_alternate_mem() && + allocation->start_time() <= request.inclusive_start_time && + request.inclusive_start_time <= allocation->end_time() && + allocation->chunk() == required_chunk) { + overlapping_allocations.push_back(allocation.get()); + } + } + } + absl::c_sort(overlapping_allocations, + [](const Allocation* a, const Allocation* b) { + return a->start_time() < b->start_time(); + }); + int64_t chunk_start_time = request.inclusive_start_time; + for (const Allocation* allocation : overlapping_allocations) { + chunk_start_time = std::max(chunk_start_time, allocation->end_time() + 1); + } + + // Note, we don't have to set request.preferred_offset, or do anything special + // to handle aliasing. This is done for us. Specifically, before calling + // CheckAndUpdateForDualLiveAllocationValues(), AllocateSegment() inserts a + // PinnedAllocation with no associated heap chunk, at the beginning of + // request.allocation_value. It aliases that PinnedAllocation with any + // overlapping allocations calculated above. In + // AllocateInAlternateMemoryNoCopy(), we will find that PinnedAllocation and + // realize we need to use the same alternate memory offset. + request.no_copy_chunk_inclusive_start_time = chunk_start_time; + VLOG(3) << "Setting the no-copy chunk (inc) start time to " + << chunk_start_time; +} + AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) { auto allocation_sequence = request.allocation_value->mutable_allocation_sequence(); @@ -4723,6 +4791,8 @@ AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) { required_memory_space_at_end != MemorySpace::kDefault && request.allow_no_copy_alternate_mem_allocation && !request.require_copy_allocation) { + CheckAndUpdateForDualLiveAllocationValues(required_assignment_at_start, + request); allocation_result = AllocateInAlternateMemoryNoCopy(request); if (allocation_result == AllocationResult::kSuccess) { return AllocationResult::kSuccess; @@ -5119,6 +5189,10 @@ AllocationResult MsaAlgorithm::AllocateInAlternateMemoryNoCopy( // If there is a previous allocation, set the start time one after the end // of the previous allocation's end. 
alternate_mem_interval.start = prev_allocation->end_time() + 1; + if (request.no_copy_chunk_inclusive_start_time.has_value()) { + alternate_mem_interval.start = + *request.no_copy_chunk_inclusive_start_time; + } } if (request.preferred_offset) { diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h index 5fbb198394f083..78ebbca77c0112 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -354,6 +355,8 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { bool operator!=(const RequiredMemoryAssignment& other) const { return !(*this == other); } + + std::string ToString() const; }; // A struct that contains a pointer to loop-optimized allocation along with @@ -627,6 +630,9 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // only_extend_existing_allocation is true, no new Allocations will be created // while processing the resulting AllocationRequest, and we only need to // extend an existing Allocation's end_time. + // + // * processed_allocation_values: The AllocationValues that have already been + // processed for the same parent HloValue as is used in the request. AllocationRequest CreateAllocationRequest( AllocationValue& allocation_value, AllocationValue& allocation_value_to_update, @@ -634,7 +640,8 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { AliasedOffset* preferred_offset, int64_t definition_time, bool require_no_copy_alternate_mem_allocation, const std::vector& all_use_times, - bool only_extend_existing_allocation); + bool only_extend_existing_allocation, + absl::Span processed_allocation_values); // Returns true, if the allocation value requires a pinned allocation in the // alternate memory space. @@ -663,6 +670,18 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { absl::StatusOr AllocateAllocationValues( absl::Span allocation_values); + // Checks for a situation in which an HloValue has more than one live + // AllocationValue at the same time, and the already processed AllocationValue + // has been given alternate memory at the start of the second AllocationValue. + // If such a case is detected, we set + // request.no_copy_chunk_inclusive_start_time with the time where the first + // AllocationValue left off. AllocateInAlternateMemoryNoCopy() takes advantage + // of that information. + void CheckAndUpdateForDualLiveAllocationValues( + const std::optional& + required_memory_assignment_at_start, + AllocationRequest& request); + // Finds an allocation for an allocation request for a segment (see the // documentation for AllocationRequest above how a segment is defined). // diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation_value.h b/third_party/xla/xla/service/memory_space_assignment/allocation_value.h index ad4c9a4e22aa47..0d79a6cb10c0e5 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation_value.h +++ b/third_party/xla/xla/service/memory_space_assignment/allocation_value.h @@ -264,6 +264,17 @@ struct AllocationRequest { // Data structure that contains the options for making window prefetched // allocations. const WindowPrefetchedAllocation::Options* window_prefetch_options = nullptr; + // Previously processed AllocationValues, with the same parent HloValue as the + // request. 
+ absl::Span processed_allocation_values; + // An optional override starting time for the placement of a chunk on the MSA + // heap, for a no-copy allocation (see + // MsaAlgorithm::AllocateInAlternateMemoryNoCopy() for more details). + // + // Note, this override is used when an aliased AllocationValue has already + // done some of the heap allocation for us. So this request picks up where it + // left off. + std::optional no_copy_chunk_inclusive_start_time; }; // Result of an allocation, prefetch, eviction etc. request. The result is diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 5cbb4a12ad2c5b..6e87950496f0b9 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -6269,6 +6269,215 @@ TEST_F(MemorySpaceAssignmentTest, DisallowedUseBugInWhile) { AssignMemorySpace(module.get(), options); } +TEST_F(MemorySpaceAssignmentTest, TwoLiveAllocationValuesBase) { + // In this example, we have enough space to give negate.0 alternate memory, + // and we put negate.0 at the top of MSA's sort order. So, we expect that + // it will get alternate memory. + // + // We are testing a fix for dual live AllocationValues, with the following + // setup: + // - HloValue H containing the following positions: negate.0, cp-start.0{0} + // - AllocationValue A0 defined at negate.0 + // - Segment A0.S0 defined during [negate.0, cp-start.0] + // - Segment A0.S1 defined during [cp-start.0, add.0] + // - AllocationValue A1 defined at cp-start.0{0} + // - Segment A1.S0 defined during [cp-start.0, cp-done.0] + // + // A0 and A1 are both live for more than 1 instruction.
+ absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + /*00:*/ p.0 = f32[10,10,10,10] parameter(0) + /*01:*/ p.1 = f32[10,10,10,10] parameter(1) + /*02:*/ v.0 = f32[10,10,10,10] add(p.1, p.1) + /*03:*/ negate.0 = f32[10,10,10,10] negate(p.0) + /*04:*/ cp-start.0 = (f32[10,10,10,10], f32[10,10,10,10], u32[], u32[]) collective-permute-start(negate.0), source_target_pairs={{0,1},{2,3}} + /*05:*/ v.1 = f32[10,10,10,10] add(v.0, v.0) + /*06:*/ add.0 = f32[10,10,10,10] add(negate.0, negate.0) + /*07:*/ v.2 = f32[10,10,10,10] add(v.1, v.1) + /*08:*/ cp-done.0 = f32[10,10,10,10] collective-permute-done(cp-start.0) + /*09:*/ v.3 = f32[10,10,10,10] add(v.2, v.2) + /*10:*/ ROOT tuple.0 = (f32[10,10,10,10], f32[10,10,10,10], f32[10,10,10,10]) tuple(add.0, cp-done.0, v.3) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 4 * 10 * 10 * 10 * 10; + MsaBufferIntervalCompare buffer_interval_compare = + CreateBufferIntervalCompareFnFromInstructionNames({"negate.0"}); + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(1, 10); + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + VLOG(1) << "Module after MSA:\n" << module->ToString(); + + HloInstruction* copy0 = FindInstruction(module.get(), "negate.0"); + ASSERT_NE(copy0, nullptr); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_F(MemorySpaceAssignmentTest, + TwoLiveAllocationValuesTwoInstructionOverlap) { + // In this example, we have enough space to give negate.0 alternate memory, + // and we put negate.0 at the top of MSA's sort order. So, we expect that + // it will get alternate memory. + // + // We are testing a fix for dual live AllocationValues, with the following + // setup: + // - HloValue H containing the following positions: negate.0, cp-start.0{0} + // - AllocationValue A0 defined at negate.0 + // - Segment A0.S0 defined during [negate.0, cp-start.0] + //
- Segment A0.S1 defined during [cp-start.0, add.0] + // - AllocationValue A1 defined at cp-start.0{0} + // - Segment A1.S0 defined during [cp-start.0, cp-done.0] + // + // A0 and A1 are both live for 2 instructions. + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + /*00:*/ p.0 = f32[10,10,10,10] parameter(0) + /*01:*/ p.1 = f32[10,10,10,10] parameter(1) + /*02:*/ v.0 = f32[10,10,10,10] add(p.1, p.1) + /*03:*/ negate.0 = f32[10,10,10,10] negate(p.0) + /*04:*/ cp-start.0 = (f32[10,10,10,10], f32[10,10,10,10], u32[], u32[]) collective-permute-start(negate.0), source_target_pairs={{0,1},{2,3}} + /*05:*/ add.0 = f32[10,10,10,10] add(negate.0, negate.0) + /*06:*/ v.1 = f32[10,10,10,10] add(v.0, v.0) + /*07:*/ v.2 = f32[10,10,10,10] add(v.1, v.1) + /*08:*/ cp-done.0 = f32[10,10,10,10] collective-permute-done(cp-start.0) + /*09:*/ v.3 = f32[10,10,10,10] add(v.2, v.2) + /*10:*/ ROOT tuple.0 = (f32[10,10,10,10], f32[10,10,10,10], f32[10,10,10,10]) tuple(add.0, cp-done.0, v.3) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 4 * 10 * 10 * 10 * 10; + MsaBufferIntervalCompare buffer_interval_compare = + CreateBufferIntervalCompareFnFromInstructionNames({"negate.0"}); + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(1, 10); + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + VLOG(1) << "Module after MSA:\n" << module->ToString(); + + HloInstruction* copy0 = FindInstruction(module.get(), "negate.0"); + ASSERT_NE(copy0, nullptr); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_F(MemorySpaceAssignmentTest, + TwoLiveAllocationValuesFirstLiveAllocationValueOutlastsSecond) { + // In this example, we have enough space to give negate.0 alternate memory, + // and we put negate.0 at the top of MSA's sort order. So, we expect that + // it will get alternate memory. + // + // We are testing a fix for dual live AllocationValues, with the following + // setup: + // - HloValue H containing the following positions: negate.0, cp-start.0{0} + // - AllocationValue A0 defined at negate.0 + // - Segment A0.S0 defined during [negate.0, cp-start.0] + // - Segment A0.S1 defined during [cp-start.0, add.0] + // - Segment A0.S2 defined during [add.0, add.1] + // - AllocationValue A1 defined at cp-start.0{0} + // - Segment A1.S0 defined during [cp-start.0, cp-done.0] + // + // A0 and A1 are both live for more than 1 instruction. A0 is live beyond the + // end of A1.
+ absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + /*00:*/ p.0 = f32[10,10,10,10] parameter(0) + /*01:*/ p.1 = f32[10,10,10,10] parameter(1) + /*02:*/ v.0 = f32[10,10,10,10] add(p.1, p.1) + /*03:*/ negate.0 = f32[10,10,10,10] negate(p.0) + /*04:*/ cp-start.0 = (f32[10,10,10,10], f32[10,10,10,10], u32[], u32[]) collective-permute-start(negate.0), source_target_pairs={{0,1},{2,3}} + /*05:*/ v.1 = f32[10,10,10,10] add(v.0, v.0) + /*06:*/ add.0 = f32[10,10,10,10] add(negate.0, negate.0) + /*07:*/ v.2 = f32[10,10,10,10] add(v.1, v.1) + /*08:*/ cp-done.0 = f32[10,10,10,10] collective-permute-done(cp-start.0) + /*09:*/ v.3 = f32[10,10,10,10] add(v.2, v.2) + /*10:*/ add.1 = f32[10,10,10,10] add(add.0, negate.0) + /*11:*/ ROOT tuple.0 = (f32[10,10,10,10], f32[10,10,10,10], f32[10,10,10,10]) tuple(add.1, cp-done.0, v.3) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 4 * 10 * 10 * 10 * 10; + MsaBufferIntervalCompare buffer_interval_compare = + CreateBufferIntervalCompareFnFromInstructionNames({"negate.0"}); + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(1, 10); + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + VLOG(1) << "Module after MSA:\n" << module->ToString(); + + HloInstruction* copy0 = FindInstruction(module.get(), "negate.0"); + ASSERT_NE(copy0, nullptr); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_F(MemorySpaceAssignmentTest, + TwoLiveAllocationValuesUnableToAllocateContiguousAltMem) { + // In this example, we have enough space to give v.2 alternate memory, + // and we put v.2 at the top of MSA's sort order. So, we expect that + // it will get alternate memory. Second, we try to give negate.0 alternate + // memory, but we can't. In order to give negate.0 alternate memory, we need + // to give it contiguous alternate memory during cp-start.0 to cp-done.0. + // (negate.0 and cp-start.0 {0} alias.) However, v.2 is taking too much + // alternate memory to accommodate that request. + // + // We are testing a fix for dual live AllocationValues, with the following + // setup: + // - HloValue H containing the following positions: negate.0, cp-start.0{0} + // - AllocationValue A0 defined at negate.0 + // - Segment A0.S0 defined during [negate.0, cp-start.0] + // - Segment A0.S1 defined during [cp-start.0, add.0] + // - AllocationValue A1 defined at cp-start.0{0} + // - Segment A1.S0 defined during [cp-start.0, cp-done.0] + // + // A0 and A1 are both live for more than 1 instruction.
+ absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + /*00:*/ p.0 = f32[10,10,10,10] parameter(0) + /*01:*/ p.1 = f32[10,10,10,10] parameter(1) + /*02:*/ v.0 = f32[10,10,10,10] add(p.1, p.1) + /*03:*/ negate.0 = f32[10,10,10,10] negate(p.0) + /*04:*/ cp-start.0 = (f32[10,10,10,10], f32[10,10,10,10], u32[], u32[]) collective-permute-start(negate.0), source_target_pairs={{0,1},{2,3}} + /*05:*/ v.1 = f32[10,10,10,10] add(v.0, v.0) + /*06:*/ add.0 = f32[10,10,10,10] add(negate.0, negate.0) + /*07:*/ v.2 = f32[10,10,10,10] add(v.1, v.1) + /*08:*/ cp-done.0 = f32[10,10,10,10] collective-permute-done(cp-start.0) + /*09:*/ v.3 = f32[10,10,10,10] add(v.2, v.2) + /*10:*/ ROOT tuple.0 = (f32[10,10,10,10], f32[10,10,10,10], f32[10,10,10,10]) tuple(add.0, cp-done.0, v.3) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 4 * 10 * 10 * 10 * 10; + MsaBufferIntervalCompare buffer_interval_compare = + CreateBufferIntervalCompareFnFromInstructionNames({"v.2", "negate.0"}); + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(1, 10); + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + VLOG(1) << "Module after MSA:\n" << module->ToString(); + + HloInstruction* v2 = FindInstruction(module.get(), "v.2"); + ASSERT_NE(v2, nullptr); + EXPECT_EQ(v2->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* copy0 = FindInstruction(module.get(), "negate.0"); + ASSERT_NE(copy0, nullptr); + EXPECT_NE(copy0->shape().layout().memory_space(), kAlternateMemorySpace); +} + TEST_F(MemorySpaceAssignmentTest, AvoidRedundantEvictionInWhile) { absl::string_view hlo_string = R"( HloModule module, is_scheduled=true @@ -7185,41 +7394,6 @@ ENTRY entry { .memory_space() == kAlternateMemorySpace); } -TEST_F(MemorySpaceAssignmentTest, AsyncOpShortLiveRangeInputBufferConsumer) { - absl::string_view hlo_string = R"( -HloModule module, is_scheduled=true - -ENTRY entry { - param = bf16[4]{0} parameter(0) - negate0 = bf16[4]{0} negate(param) - collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0), source_target_pairs={{0,1},{1,2},{2,3}} - negate1 = bf16[4]{0} negate(negate0) - negate2 = bf16[4]{0} negate(negate1) - negate3 = bf16[4]{0} negate(negate2) - collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start) - ROOT add = add(collective-permute-done, negate3) -} - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - AssignMemorySpace(module.get()); - - // Expect only the destination buffer to get alternate memory allocation - // because negate0 is also used by negate1. 
- HloInstruction* collective_permute_start = - module->entry_computation()->GetInstructionWithName( - "collective-permute-start"); - EXPECT_TRUE(collective_permute_start->shape() - .tuple_shapes(0) - .layout() - .memory_space() == kDefaultMemorySpace); - EXPECT_TRUE(collective_permute_start->shape() - .tuple_shapes(1) - .layout() - .memory_space() == kAlternateMemorySpace); -} - TEST_F(MemorySpaceAssignmentTest, AsyncOpLongLiveRange) { absl::string_view hlo_string = R"( HloModule module, is_scheduled=true diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h index 90296f30644c00..d031f9ea89bc3d 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h @@ -17,14 +17,19 @@ limitations under the License. #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_TEST_BASE_H_ #include +#include #include #include #include #include #include +#include #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -133,6 +138,39 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { return options; } + // Creates an MsaBufferIntervalCompare function that prioritizes the + // instructions named in prioritized_instruction_names, in the order + // specified. We default to alphabetical instruction name order for the + // remaining instructions. + static MsaBufferIntervalCompare + CreateBufferIntervalCompareFnFromInstructionNames( + std::vector prioritized_instruction_names) { + absl::flat_hash_map instruction_name_to_priority; + // A lower priority value means it's higher on the MSA sort list. + for (size_t i = 0; i < prioritized_instruction_names.size(); ++i) { + instruction_name_to_priority[prioritized_instruction_names[i]] = i; + } + return [instruction_name_to_priority = + std::move(instruction_name_to_priority)]( + const MsaBufferInterval& a, const MsaBufferInterval& b) { + auto get_sort_tuple = [&instruction_name_to_priority]( + const MsaBufferInterval& buffer_interval) { + auto it = instruction_name_to_priority.find( + buffer_interval.buffer->defining_instruction()->name()); + if (it != instruction_name_to_priority.end()) { + return std::make_tuple( + it->second, + buffer_interval.buffer->defining_instruction()->name()); + } + return std::make_tuple( + instruction_name_to_priority.size(), + buffer_interval.buffer->defining_instruction()->name()); + }; + + return get_sort_tuple(a) < get_sort_tuple(b); + }; + } + std::unique_ptr AssignMemorySpaceUsingCostAnalysis( HloModule* module, std::optional memory_space_options_override = std::nullopt, From 822a0da45fc6ec853bbdc05e510deedd8b06af98 Mon Sep 17 00:00:00 2001 From: Vamsi Manchala Date: Mon, 31 Mar 2025 15:45:22 -0700 Subject: [PATCH 0061/1324] Add dynamic input shape support for rfft2d conversion. Removing the additional constraint that fft_lengths need to be powers of 2. This doesn't seem to be a requirement on the runtime side.
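As a rough sketch of the reshaping this conversion relies on (illustrative only; ExpandedFftShape is a hypothetical helper, and -1 marks a dynamic dimension that the pass resolves at runtime with get_dimension_size ops): a size-1 batch dimension is inserted in front of the trailing FFT dimension so that a 1D FFT can be expressed as the 2D FFT that tfl.rfft2d implements.

  #include <cstdint>
  #include <vector>

  // [a, b, ..., e] -> [a, b, ..., 1, e]: splice a trivial batch dimension in
  // front of the last (FFT) dimension; all other dimensions carry over.
  std::vector<int64_t> ExpandedFftShape(
      const std::vector<int64_t>& input_shape) {
    std::vector<int64_t> expanded(input_shape.begin(), input_shape.end() - 1);
    expanded.push_back(1);                    // trivial batch dim for rfft2d
    expanded.push_back(input_shape.back());   // FFT dimension stays last
    return expanded;
  }

With a dynamic input shape such as {-1, -1, 2560}, this yields {-1, -1, 1, 2560}; after the FFT, the size-1 dimension is squeezed back out with a dynamic reshape built from the runtime dimension sizes.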
PiperOrigin-RevId: 742428567 --- .../lite/stablehlo/tests/prepare_hlo.mlir | 48 ++++++++ .../stablehlo/tests/tfl_legalize_hlo.mlir | 20 ++++ .../transforms/legalize_hlo_conversions/BUILD | 2 +- .../legalize_hlo_conversions/fft.cc | 103 ++++++++++++++---- 4 files changed, 150 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir index f363b369d76373..2fa440eee1a3f3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir @@ -845,3 +845,51 @@ func.func @mhlo_nd_fft(%arg0: tensor<2x3x345x256xf32>) -> tensor<2x3x345x129xcom // CHECK: return %2 : tensor<2x3x345x129xcomplex> // ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_1 +func.func @mhlo_dynamic_fft_1(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %4 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor) -> tensor + // CHECK: %5 = mhlo.reshape %4 : (tensor) -> tensor<1xi32> + // CHECK: %6 = "mhlo.concatenate"(%5, %3, %2, %1) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK: %7 = mhlo.dynamic_reshape %arg0, %6 : (tensor, tensor<4xi32>) -> tensor + // CHECK: %8 = "mhlo.fft"(%7) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + // CHECK: %9 = "mhlo.get_dimension_size"(%8) <{dimension = 0 : i64}> : (tensor>) -> tensor + // CHECK: %10 = mhlo.reshape %9 : (tensor) -> tensor<1xi32> + // CHECK: %11 = "mhlo.concatenate"(%10, %3, %0) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + // CHECK: %12 = mhlo.dynamic_reshape %8, %11 : (tensor>, tensor<3xi32>) -> tensor> + // CHECK: return %12 : tensor> +} + +// ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_2 +func.func @mhlo_dynamic_fft_2(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %3 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor) -> tensor + // CHECK: %4 = mhlo.reshape %3 : (tensor) -> tensor<1xi32> + // CHECK: %5 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor) -> tensor + // CHECK: %6 = mhlo.reshape %5 : (tensor) -> tensor<1xi32> + // CHECK: %7 = "mhlo.concatenate"(%4, %6, %2, %1) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK: %8 = mhlo.dynamic_reshape %arg0, %7 : (tensor, tensor<4xi32>) -> tensor + // CHECK: %9 = "mhlo.fft"(%8) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + // CHECK: %10 = "mhlo.get_dimension_size"(%9) <{dimension = 0 : i64}> : (tensor>) -> tensor + // CHECK: %11 = mhlo.reshape %10 : (tensor) -> tensor<1xi32> + // CHECK: %12 = "mhlo.get_dimension_size"(%9) <{dimension = 1 : i64}> : (tensor>) -> tensor + // CHECK: %13 = mhlo.reshape %12 : (tensor) -> tensor<1xi32> + // CHECK: %14 = "mhlo.concatenate"(%11, %13, %0) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + // CHECK: %15 = mhlo.dynamic_reshape %9, %14 : (tensor>, tensor<3xi32>) -> tensor> + // CHECK: return %15 : tensor> +} + +// ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_2_neg +func.func @mhlo_dynamic_fft_2_neg(%arg0: tensor) -> 
tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + // CHECK: return %0 : tensor> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index 6d48cfee7e5438..a77d02e78c1dce 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -3801,6 +3801,26 @@ func.func @mhlo_nd_fft_1(%arg0: tensor<2x3x345x4x256xf32>) -> tensor<2x3x345x4x1 // ----- +// CHECK-LABEL: @mhlo_dynamic_fft_1 +func.func @mhlo_dynamic_fft_1(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %cst = arith.constant dense<[1, 2560]> : tensor<2xi32> + // CHECK: %0 = "tfl.rfft2d"(%arg0, %cst) : (tensor, tensor<2xi32>) -> tensor> + // CHECK: return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_2 +func.func @mhlo_dynamic_fft_2(%arg0: tensor) -> tensor> { + %9 = "mhlo.fft"(%arg0) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %9 : tensor> + // CHECK: %cst = arith.constant dense<[1, 2560]> : tensor<2xi32> + // CHECK: %0 = "tfl.rfft2d"(%arg0, %cst) : (tensor, tensor<2xi32>) -> tensor> + // CHECK: return %0 : tensor> +} + //===----------------------------------------------------------------------===// // mhlo.imag //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD index c58fca93e6e53d..9e2f1cf33f495f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -341,8 +341,8 @@ cc_library( srcs = ["fft.cc"], hdrs = ["fft.h"], deps = [ - "//tensorflow/compiler/mlir/lite:const_tensor_utils", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc index 8f08a0f8a2b1c0..f2d29774c31c89 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include #include #include +#include "mhlo/IR/hlo_ops.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -32,7 +32,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir::odml { @@ -62,14 +62,6 @@ bool IsSupportedRfftOp(mhlo::FftOp fft_op) { if (fft_lengths.size() > 2) return false; // Only support 2D FFT. - // TFLite RFFT2d supports only int32 fft_lengths that are powers of 2. - for (int64_t fft_length : fft_lengths) { - if (fft_length != 1 && (!TFL::IsPowerOfTwo(fft_length) || - fft_length > std::numeric_limits::max())) { - return false; - } - } - // Check if the trailing input shape matches the fft_lengths. const std::vector input_shape = mlir::cast(fft_op.getOperand().getType()).getShape(); @@ -77,6 +69,16 @@ bool IsSupportedRfftOp(mhlo::FftOp fft_op) { fft_lengths.begin(), fft_lengths.end()); } +// Returns a tensor of the dimension size of the input tensor. Result of +// mhlo::GetDimensionSizeOp is always a scalar value, but we need a tensor to +// concatenate with other dimension sizes. +Value GetDimensionSizeTensor(OpBuilder& rewriter, Location loc, Value input, + int64_t dim) { + auto size_scalar = rewriter.create(loc, input, dim); + return rewriter.create( + loc, RankedTensorType::get({1}, rewriter.getI32Type()), size_scalar); +} + // Convert rfft to rfft2d. // The transformation pattern looks like below: // @@ -114,18 +116,22 @@ class ConvertNDFftTo2DFftOp : public OpRewritePattern { auto input_type = mlir::dyn_cast_or_null(fft_op.getOperand().getType()); const std::vector input_shape = - mlir::cast(fft_op.getOperand().getType()).getShape(); + input_type + ? input_type.getShape() + : mlir::cast(fft_op.getOperand().getType()).getShape(); - auto fft_operand = fft_op.getOperand(); + Value fft_operand = fft_op.getOperand(); auto output_type = mlir::cast(fft_op.getResult().getType()); // Create a new fft_length attribute for the 2D FFT. SmallVector new_fft_lengths = {1, fft_lengths.back()}; auto new_fft_lengths_attr = rewriter.getI64TensorAttr(new_fft_lengths); + bool is_dynamic_shape = !input_type || !input_type.hasStaticShape(); + // Input can have a single trivial batch dim next to the fft dimension, in // which case we don't need to expand the input. - if (input_type && (input_shape[input_shape.size() - 2] != 1)) { + if (input_shape[input_shape.size() - 2] != 1) { const std::vector output_shape = output_type.getShape(); // [a, b, c, d, e] -> [a, b, c, d, 1, e] @@ -133,11 +139,42 @@ class ConvertNDFftTo2DFftOp : public OpRewritePattern { input_shape.end() - 1}; expanded_input_shape.push_back(1); expanded_input_shape.push_back(input_shape.back()); - // Replace the expand_dims op with a reshape op: - auto expanded_input_type = mlir::RankedTensorType::get( + auto expanded_input_type = tensorflow::GetTypeFromTFTensorShape( expanded_input_shape, input_type.getElementType()); - fft_operand = rewriter.create( - fft_op.getLoc(), expanded_input_type, fft_operand); + + // Dynamic shape needs to be handled separately as mhlo::ReshapeOp does + // not support dynamic shape. + if (is_dynamic_shape) { + // Programmatically- + // 1. Get the dimensions of the input tensor and create shape vector. + // 2. Insert a 1 as the penultimate dimension size. + // 3. Concatenate the dimension sizes to create a new SHAPE tensor. 
+ SmallVector expanded_input_shape_values; + for (int i = 0; i < input_shape.size() - 1; ++i) { + expanded_input_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), fft_operand, i)); + } + expanded_input_shape_values.push_back(rewriter.create( + fft_op.getLoc(), rewriter.getI32TensorAttr({1}))); + expanded_input_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), fft_operand, input_shape.size() - 1)); + + auto expanded_input_shape_tensor = rewriter.create( + fft_op.getLoc(), + RankedTensorType::get( + {static_cast(expanded_input_shape_values.size())}, + rewriter.getI32Type()), + expanded_input_shape_values, 0); + + // Create a new mhlo.dynamic_reshape op with the expanded input and + // expanded input shape. SHAPE tensor is created in the previous step. + fft_operand = rewriter.create( + fft_op.getLoc(), expanded_input_type, fft_operand, + expanded_input_shape_tensor); + } else { + fft_operand = rewriter.create( + fft_op.getLoc(), expanded_input_type, fft_operand); + } SmallVector new_output_shape = {output_shape.begin(), output_shape.end() - 1}; @@ -152,12 +189,34 @@ class ConvertNDFftTo2DFftOp : public OpRewritePattern { rewriter.create(fft_op.getLoc(), output_type, fft_operand, fft_op.getFftType(), new_fft_lengths_attr); - if (input_type && (input_shape[input_shape.size() - 2] != 1)) { + if (input_shape[input_shape.size() - 2] != 1) { // Squeeze the output dimensions back to 2D. - auto squeeze_op = rewriter.create( - fft_op.getLoc(), fft_op.getResult().getType(), new_fft.getResult()); - - rewriter.replaceOp(fft_op, squeeze_op.getResult()); + if (is_dynamic_shape) { + SmallVector output_shape_values; + for (int i = 0; i < new_fft.getResult().getType().getShape().size() - 2; + ++i) { + output_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), new_fft.getResult(), i)); + } + output_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), new_fft.getResult(), + new_fft.getResult().getType().getShape().size() - 1)); + + auto shape_tensor = rewriter.create( + fft_op.getLoc(), + RankedTensorType::get( + {static_cast(output_shape_values.size())}, + rewriter.getI32Type()), + output_shape_values, 0); + auto squeeze_op = rewriter.create( + fft_op.getLoc(), fft_op.getResult().getType(), new_fft.getResult(), + shape_tensor); + rewriter.replaceOp(fft_op, squeeze_op.getResult()); + } else { + auto squeeze_op = rewriter.create( + fft_op.getLoc(), fft_op.getResult().getType(), new_fft.getResult()); + rewriter.replaceOp(fft_op, squeeze_op.getResult()); + } } else { rewriter.replaceOp(fft_op, new_fft.getResult()); } From a8b000ca8addb2b51f83e8083fb03135d7f6b04a Mon Sep 17 00:00:00 2001 From: Matthias Guenther Date: Mon, 31 Mar 2025 15:48:11 -0700 Subject: [PATCH 0062/1324] Clean up BUILD rules for `generate_hlo_opt_test_checks.py`. Suffix the `py_strict_library` target's name with `_lib`, and add a `py_strict_binary` target with the old name. PiperOrigin-RevId: 742429393 --- third_party/xla/xla/hlo/tools/BUILD | 8 ++++++-- third_party/xla/xla/hlo/tools/tests/BUILD | 3 +-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/tools/BUILD b/third_party/xla/xla/hlo/tools/BUILD index 631f0ae290f355..2992a203790bda 100644 --- a/third_party/xla/xla/hlo/tools/BUILD +++ b/third_party/xla/xla/hlo/tools/BUILD @@ -1,6 +1,6 @@ # Tools and utilities that aid in XLA development and usage. 
-load("//xla:py_strict.bzl", "py_strict_library") +load("//xla:py_strict.bzl", "py_strict_binary", "py_strict_library") load( "//xla:xla.default.bzl", "xla_cc_binary", @@ -153,7 +153,11 @@ xla_cc_binary( ) py_strict_library( + name = "generate_hlo_test_checks_lib", + srcs = ["generate_hlo_test_checks.py"], +) + +py_strict_binary( name = "generate_hlo_test_checks", srcs = ["generate_hlo_test_checks.py"], - srcs_version = "PY3", ) diff --git a/third_party/xla/xla/hlo/tools/tests/BUILD b/third_party/xla/xla/hlo/tools/tests/BUILD index a9d3b300e83b70..814872908f243d 100644 --- a/third_party/xla/xla/hlo/tools/tests/BUILD +++ b/third_party/xla/xla/hlo/tools/tests/BUILD @@ -59,9 +59,8 @@ py_strict_test( "generate_hlo_test_checks_test_output.hlo", "//xla/hlo/tools:hlo-opt", ], - python_version = "PY3", deps = [ - "//xla/hlo/tools:generate_hlo_test_checks", + "//xla/hlo/tools:generate_hlo_test_checks_lib", "@absl_py//absl/testing:absltest", ], ) From 6d11978995b546c3f676e303e2ad6bfce93cb171 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 31 Mar 2025 15:48:17 -0700 Subject: [PATCH 0063/1324] [MHLO] Migrate shape analysis passes for pre-HLO lowering to StableHLO PiperOrigin-RevId: 742429425 --- .../xla/xla/hlo/translate/mhlo_to_hlo/BUILD | 1 + .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 11 +- third_party/xla/xla/mlir_hlo/BUILD | 13 +- .../xla/xla/mlir_hlo/mhlo/CMakeLists.txt | 1 - .../mlir_hlo/mhlo/transforms/CMakeLists.txt | 1 - .../mlir_hlo/mhlo/transforms/mhlo_passes.td | 5 - .../xla/xla/mlir_hlo/mhlo/transforms/passes.h | 5 - .../analysis/CMakeLists.txt | 0 .../analysis/shape_component_analysis.cpp} | 104 ++++----- .../analysis/shape_component_analysis.h | 37 ++-- .../stablehlo_ext/transforms/passes.td | 9 + .../symbolic_shape_optimization.cpp} | 50 ++--- .../symbolic-shape-optimization.mlir | 198 +++++++++--------- 13 files changed, 225 insertions(+), 210 deletions(-) rename third_party/xla/xla/mlir_hlo/{mhlo => stablehlo_ext}/analysis/CMakeLists.txt (100%) rename third_party/xla/xla/mlir_hlo/{mhlo/analysis/shape_component_analysis.cc => stablehlo_ext/analysis/shape_component_analysis.cpp} (90%) rename third_party/xla/xla/mlir_hlo/{mhlo => stablehlo_ext}/analysis/shape_component_analysis.h (82%) rename third_party/xla/xla/mlir_hlo/{mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc => stablehlo_ext/transforms/symbolic_shape_optimization.cpp} (95%) rename third_party/xla/xla/mlir_hlo/tests/{Dialect/mhlo => stablehlo_ext}/symbolic-shape-optimization.mlir (79%) diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD index 376d19e2052f30..0a471554a7acc5 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD @@ -154,6 +154,7 @@ cc_library( "//xla/mlir/utils:type_util", "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", + "//xla/mlir_hlo:stablehlo_extension_passes", "//xla/service:computation_layout", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 4966a9b5ba814d..26ae589ac5e008 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -92,6 +92,7 @@ limitations under the License. 
#include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/mlir_hlo/stablehlo_ext/transforms/passes.h" #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/hlo.pb.h" @@ -4100,8 +4101,16 @@ absl::Status PrepareForExport(mlir::ModuleOp module) { // Experimental support for exporting dynamic MHLO programs to HLO. // Only bounded dynamism is planned to be supported; unbounded dynamism // is out of scope for now. + // + // Currently takes overhead if input is MHLO for MHLO->StableHLO, can + // be deleted once conversion can assume StableHLO input. + mlir::mhlo::HloLegalizeToStablehloPassOptions options; + options.allow_xla_features_ = true; + pm.addPass(mhlo::createHloLegalizeToStablehloPass(options)); pm.addNestedPass( - mhlo::createSymbolicShapeOptimizationPass()); + stablehlo_ext::createSymbolicShapeOptimizationPass()); + pm.addPass(mhlo::createStablehloLegalizeToHloPass()); + pm.addNestedPass(mhlo::createShapeLegalizeToHloPass()); } diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 71f02aee9b56fc..6ced829ae19bd5 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -387,7 +387,6 @@ cc_library( "mhlo/transforms/shape_simplification/shape_simplification.cc", "mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc", "mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc", - "mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc", "mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc", "mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc", ], @@ -895,11 +894,10 @@ gentbl_cc_library( cc_library( name = "shape_component_analysis", - srcs = ["mhlo/analysis/shape_component_analysis.cc"], - hdrs = ["mhlo/analysis/shape_component_analysis.h"], + srcs = ["stablehlo_ext/analysis/shape_component_analysis.cpp"], + hdrs = ["stablehlo_ext/analysis/shape_component_analysis.h"], strip_include_prefix = ".", deps = [ - ":mlir_hlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", @@ -907,6 +905,7 @@ cc_library( "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -1093,6 +1092,7 @@ cc_library( "stablehlo_ext/transforms/stablehlo_legalize_quant_composite.cpp", "stablehlo_ext/transforms/stablehlo_prepare_for_hlo_export.cpp", "stablehlo_ext/transforms/stablehlo_refine_shapes.cpp", + "stablehlo_ext/transforms/symbolic_shape_optimization.cpp", ], hdrs = [ "stablehlo_ext/transforms/passes.h", @@ -1102,17 +1102,22 @@ cc_library( compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", deps = [ + ":shape_component_analysis", ":stablehlo_extension_base", ":stablehlo_extension_ops", ":stablehlo_extension_pass_inc_gen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@shardy//shardy/dialect/sdy/ir:dialect", diff --git 
a/third_party/xla/xla/mlir_hlo/mhlo/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/mhlo/CMakeLists.txt index 347117c8bcb1c7..080e665af12953 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/mhlo/CMakeLists.txt @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -add_subdirectory(analysis) add_subdirectory(IR) add_subdirectory(transforms) add_subdirectory(utils) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index ac34a3e707454d..249531d74f7e31 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -46,7 +46,6 @@ add_mlir_library(MhloPasses shape_legalize_to_hlo/shape_legalize_to_hlo.cc shape_simplification/shape_simplification.cc sink_constants_to_control_flow/sink_constants_to_control_flow.cc - symbolic_shape_optimization/symbolic_shape_optimization.cc test_infer_shaped_type/test_infer_shaped_type_pass.cc unfuse_batch_norm/unfuse_batch_norm.cc unfuse_batch_norm/unfuse_batch_norm_pass.cc diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td index 48b6815ceadefb..f3fc6cdec3a579 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td @@ -143,11 +143,6 @@ def MergeAssumingOpsPass : Pass<"mhlo-merge-assuming-ops", "func::FuncOp"> { let constructor = "createMergeAssumingOpsPass()"; } -def SymbolicShapeOptimization : Pass<"symbolic-shape-optimization", "func::FuncOp"> { - let summary = "Analyzes shapes and performs shape-related optimizations"; - let constructor = "createSymbolicShapeOptimizationPass()"; -} - def ShapeSimplification : Pass<"shape-simplification", "mlir::func::FuncOp"> { let summary = "Simplify shape ops"; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h index 81b35d2f901731..271f7299e6f04b 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h @@ -64,11 +64,6 @@ std::unique_ptr> createBroadcastPropagationPass(); // larger fusions. std::unique_ptr> createMergeAssumingOpsPass(); -/// Creates a pass to analyze shapes and to use that information for -/// shape-related optimizations. -std::unique_ptr> -createSymbolicShapeOptimizationPass(); - // Pass to simplify shape ops. 
std::unique_ptr> createShapeSimplification(); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/stablehlo_ext/analysis/CMakeLists.txt similarity index 100% rename from third_party/xla/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt rename to third_party/xla/xla/mlir_hlo/stablehlo_ext/analysis/CMakeLists.txt diff --git a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc b/third_party/xla/xla/mlir_hlo/stablehlo_ext/analysis/shape_component_analysis.cpp similarity index 90% rename from third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc rename to third_party/xla/xla/mlir_hlo/stablehlo_ext/analysis/shape_component_analysis.cpp index 19891ffc24e679..a86e24602a5e57 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/analysis/shape_component_analysis.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mhlo/analysis/shape_component_analysis.h" +#include "stablehlo_ext/analysis/shape_component_analysis.h" #include #include @@ -22,17 +22,20 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/ErrorHandling.h" -#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/StablehloOps.h" -using namespace mlir; +namespace mlir { +namespace stablehlo_ext { using SymbolicShapeConstraintsMap = ShapeComponentAnalysis::SymbolicShapeConstraintsMap; @@ -90,17 +93,18 @@ struct ShapeVisitor { if (transitivelyRequestedInfo.isShapeInfo()) { if (value.getDefiningOp()) { backwardAssumingShape(value); - } else if (auto bcast = - value.getDefiningOp()) { + } else if (auto bcast = value.getDefiningOp< + stablehlo::DynamicBroadcastInDimOp>()) { backwardDynamicBroadcastInDimShape(bcast); } else if (auto reshape = - value.getDefiningOp()) { + value.getDefiningOp()) { backwardDynamicReshapeShape(reshape); - } else if (value.getDefiningOp()) { + } else if (value.getDefiningOp()) { backwardReduceShape(value); - } else if (auto transpose = value.getDefiningOp()) { + } else if (auto transpose = + value.getDefiningOp()) { backwardTransposeShape(transpose); - } else if (auto select = value.getDefiningOp()) { + } else if (auto select = value.getDefiningOp()) { backwardSelectShape(select); } else if (auto arg = mlir::dyn_cast(value)) { backwardBlockArgumentShape(arg); @@ -142,19 +146,20 @@ struct ShapeVisitor { backwardTensorFromElements(fromElements); } else if (auto extract = value.getDefiningOp()) { backwardTensorExtract(extract); - } else if (auto add = value.getDefiningOp()) { + } else if (auto add = value.getDefiningOp()) { backwardBinOp(add); - } else if (auto mul = value.getDefiningOp()) { + } else if (auto mul = value.getDefiningOp()) { backwardBinOp(mul); } else if (auto add = value.getDefiningOp()) { backwardBinOp(add); } else if (auto mul = value.getDefiningOp()) { backwardBinOp(mul); - } else if (auto concat = value.getDefiningOp()) { + } else if (auto concat = + value.getDefiningOp()) { 
backwardConcatenate(concat); - } else if (auto reshape = value.getDefiningOp()) { + } else if (auto reshape = value.getDefiningOp()) { backwardReshape(reshape); - } else if (auto slice = value.getDefiningOp()) { + } else if (auto slice = value.getDefiningOp()) { backwardSlice(slice); } else if (matchPattern(value, m_Constant())) { backwardConstant(value); @@ -176,17 +181,18 @@ struct ShapeVisitor { if (!transitivelyRequestedInfo.isValueInfo()) { if (value.getDefiningOp()) { forwardAssumingShape(value); - } else if (auto broadcast = - value.getDefiningOp()) { + } else if (auto broadcast = value.getDefiningOp< + stablehlo::DynamicBroadcastInDimOp>()) { forwardDynamicBroadcastInDimShape(broadcast); } else if (auto reshape = - value.getDefiningOp()) { + value.getDefiningOp()) { forwardDynamicReshapeShape(reshape); - } else if (value.getDefiningOp()) { + } else if (value.getDefiningOp()) { forwardReduceShape(value); - } else if (auto transpose = value.getDefiningOp()) { + } else if (auto transpose = + value.getDefiningOp()) { forwardTransposeShape(transpose); - } else if (auto select = value.getDefiningOp()) { + } else if (auto select = value.getDefiningOp()) { forwardSelectShape(select); } else if (value.getDefiningOp() && value.getDefiningOp() @@ -217,19 +223,20 @@ struct ShapeVisitor { forwardTensorFromElements(fromElements); } else if (auto extract = value.getDefiningOp()) { forwardTensorExtract(extract); - } else if (auto add = value.getDefiningOp()) { + } else if (auto add = value.getDefiningOp()) { forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; }); - } else if (auto mul = value.getDefiningOp()) { + } else if (auto mul = value.getDefiningOp()) { forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; }); } else if (auto add = value.getDefiningOp()) { forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; }); } else if (auto mul = value.getDefiningOp()) { forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; }); - } else if (auto concat = value.getDefiningOp()) { + } else if (auto concat = + value.getDefiningOp()) { forwardConcatenate(concat); - } else if (auto reshape = value.getDefiningOp()) { + } else if (auto reshape = value.getDefiningOp()) { forwardReshape(reshape); - } else if (auto slice = value.getDefiningOp()) { + } else if (auto slice = value.getDefiningOp()) { forwardSlice(slice); } else if (matchPattern(value, m_Constant())) { forwardConstant(value); @@ -324,21 +331,23 @@ struct ShapeVisitor { } assert(dims.size() == rank && "expect one expression per dimension"); } - void backwardDynamicBroadcastInDimShape(mhlo::DynamicBroadcastInDimOp op) { + void backwardDynamicBroadcastInDimShape( + stablehlo::DynamicBroadcastInDimOp op) { forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); backwardsWorklist.push_back( ShapeOrValueInfo::getValueInfoOf(op.getOutputDimensions())); } - void forwardDynamicBroadcastInDimShape(mhlo::DynamicBroadcastInDimOp op) { + void forwardDynamicBroadcastInDimShape( + stablehlo::DynamicBroadcastInDimOp op) { auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); dims = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOutputDimensions())); } - void backwardDynamicReshapeShape(mhlo::DynamicReshapeOp op) { + void backwardDynamicReshapeShape(stablehlo::DynamicReshapeOp op) { forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); backwardsWorklist.push_back( ShapeOrValueInfo::getValueInfoOf(op.getOutputShape())); } - void forwardDynamicReshapeShape(mhlo::DynamicReshapeOp op) { + void 
forwardDynamicReshapeShape(stablehlo::DynamicReshapeOp op) { auto rankedTy = mlir::cast(op.getResult().getType()); auto shapeDims = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOutputShape())); @@ -347,14 +356,14 @@ struct ShapeVisitor { } void backwardReduceShape(Value op) { forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); - auto reduceOp = op.getDefiningOp(); + auto reduceOp = op.getDefiningOp(); if (reduceOp.getInputs().size() == 1) { backwardsWorklist.push_back( ShapeOrValueInfo::getShapeInfoOf(reduceOp.getInputs().back())); } } void forwardReduceShape(Value op) { - auto reduceOp = op.getDefiningOp(); + auto reduceOp = op.getDefiningOp(); if (reduceOp.getInputs().size() != 1) return forwardUnknownShape(op); auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); for (const auto &dim : llvm::enumerate(lookup( @@ -363,23 +372,22 @@ struct ShapeVisitor { dims.push_back(dim.value()); } } - void backwardTransposeShape(mhlo::TransposeOp op) { + void backwardTransposeShape(stablehlo::TransposeOp op) { forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); backwardsWorklist.push_back( ShapeOrValueInfo::getShapeInfoOf(op.getOperand())); } - void forwardTransposeShape(mhlo::TransposeOp op) { + void forwardTransposeShape(stablehlo::TransposeOp op) { auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getOperand())); - auto elem = mlir::cast(op.getPermutation()); - for (const auto &val : elem) dims.push_back(in[val.getZExtValue()]); + for (const auto &val : op.getPermutation()) dims.push_back(in[val]); } - void backwardSelectShape(mhlo::SelectOp op) { + void backwardSelectShape(stablehlo::SelectOp op) { forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); backwardsWorklist.push_back( ShapeOrValueInfo::getShapeInfoOf(op.getOnTrue())); } - void forwardSelectShape(mhlo::SelectOp op) { + void forwardSelectShape(stablehlo::SelectOp op) { auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); // Forward the `on_true` operand, it has the same shape as the output. 
dims = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getOnTrue())); @@ -623,12 +631,12 @@ struct ShapeVisitor { forwardUnknown(v); } } - void backwardConcatenate(mhlo::ConcatenateOp op) { + void backwardConcatenate(stablehlo::ConcatenateOp op) { forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); for (auto operand : op.getOperands()) backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(operand)); } - void forwardConcatenate(mhlo::ConcatenateOp op) { + void forwardConcatenate(stablehlo::ConcatenateOp op) { for (auto operand : op.getOperands()) { auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand)); if (in.size() != 1) return forwardUnknown(op); @@ -639,35 +647,34 @@ struct ShapeVisitor { dims.push_back({in[0].symbols, in[0].expr}); } } - void backwardReshape(mhlo::ReshapeOp op) { + void backwardReshape(stablehlo::ReshapeOp op) { forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); backwardsWorklist.push_back( ShapeOrValueInfo::getValueInfoOf(op.getOperand())); } - void forwardReshape(mhlo::ReshapeOp op) { + void forwardReshape(stablehlo::ReshapeOp op) { auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOperand())); if (in.size() != 1) return forwardUnknown(op); auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); dims.push_back({in[0].symbols, in[0].expr}); } - void backwardSlice(mhlo::SliceOp op) { + void backwardSlice(stablehlo::SliceOp op) { forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); backwardsWorklist.push_back( ShapeOrValueInfo::getValueInfoOf(op.getOperand())); } - void forwardSlice(mhlo::SliceOp op) { + void forwardSlice(stablehlo::SliceOp op) { // Only handle slices equivalent to an extract. if (!op.getType().hasStaticShape({1})) { return forwardUnknown(op); } auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOperand())); - auto elem = mlir::cast(op.getStartIndices()); - auto i = (*elem.begin()).getZExtValue(); - if (i >= in.size()) { // Bounds check. + auto first = op.getStartIndices().front(); + if (first >= in.size()) { // Bounds check. return forwardUnknown(op); } - dims.push_back({in[i].symbols, in[i].expr}); + dims.push_back({in[first].symbols, in[first].expr}); } void backwardUnknown(Value v) { forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(v)); @@ -839,3 +846,6 @@ void SymbolicExpr::dump(llvm::raw_ostream &os) const { os << '[' << sym.value().index << "]\n"; } } + +} // namespace stablehlo_ext +} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h b/third_party/xla/xla/mlir_hlo/stablehlo_ext/analysis/shape_component_analysis.h similarity index 82% rename from third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h rename to third_party/xla/xla/mlir_hlo/stablehlo_ext/analysis/shape_component_analysis.h index 27d3a643de417b..2d49f41e7774e7 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/analysis/shape_component_analysis.h @@ -13,17 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_MHLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H -#define MLIR_HLO_MHLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H +#ifndef STABLEHLO_EXT_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H_ +#define STABLEHLO_EXT_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H_ #include #include "llvm/Support/raw_ostream.h" -#include "mhlo/IR/hlo_ops.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Value.h" +#include "stablehlo/dialect/StablehloOps.h" namespace mlir { +namespace stablehlo_ext { // Analysis to infer shape information. // @@ -137,34 +138,38 @@ class ShapeComponentAnalysis { // Clear analysis data structures. void reset(); }; +} // namespace stablehlo_ext } // namespace mlir namespace llvm { template <> -struct DenseMapInfo { - static inline mlir::ShapeComponentAnalysis::Symbol getEmptyKey() { - return {mlir::ShapeComponentAnalysis::ShapeOrValueInfo::DenseMapInfo:: - getEmptyKey(), +struct DenseMapInfo { + static inline mlir::stablehlo_ext::ShapeComponentAnalysis::Symbol + getEmptyKey() { + return {mlir::stablehlo_ext::ShapeComponentAnalysis::ShapeOrValueInfo:: + DenseMapInfo::getEmptyKey(), llvm::DenseMapInfo::getEmptyKey()}; } - static inline mlir::ShapeComponentAnalysis::Symbol getTombstoneKey() { - return {mlir::ShapeComponentAnalysis::ShapeOrValueInfo::DenseMapInfo:: - getTombstoneKey(), + static inline mlir::stablehlo_ext::ShapeComponentAnalysis::Symbol + getTombstoneKey() { + return {mlir::stablehlo_ext::ShapeComponentAnalysis::ShapeOrValueInfo:: + DenseMapInfo::getTombstoneKey(), llvm::DenseMapInfo::getTombstoneKey()}; } - static unsigned getHashValue(mlir::ShapeComponentAnalysis::Symbol symbol) { + static unsigned getHashValue( + mlir::stablehlo_ext::ShapeComponentAnalysis::Symbol symbol) { return llvm::hash_combine( - mlir::ShapeComponentAnalysis::ShapeOrValueInfo::DenseMapInfo:: - getHashValue(symbol.source), + mlir::stablehlo_ext::ShapeComponentAnalysis::ShapeOrValueInfo:: + DenseMapInfo::getHashValue(symbol.source), llvm::DenseMapInfo::getHashValue(symbol.index)); } - static bool isEqual(mlir::ShapeComponentAnalysis::Symbol lhs, - mlir::ShapeComponentAnalysis::Symbol rhs) { + static bool isEqual(mlir::stablehlo_ext::ShapeComponentAnalysis::Symbol lhs, + mlir::stablehlo_ext::ShapeComponentAnalysis::Symbol rhs) { return lhs == rhs; } }; } // namespace llvm -#endif // MLIR_HLO_MHLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H +#endif // STABLEHLO_EXT_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H_ diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.td b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.td index 8fddce26997b20..05e52c740af0d4 100644 --- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.td +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.td @@ -208,3 +208,12 @@ def SinkConstantsToControlFlowPass : Pass<"stablehlo-ext-sink-constants-to-contr }]; } +def SymbolicShapeOptimizationPass : Pass<"stablehlo-ext-symbolic-shape-optimization", "func::FuncOp"> { + let summary = "Analyzes shapes and performs shape-related optimizations"; + let description = [{ + This pass analyzes shapes and performs shape-related optimizations, mostly + only used from programs resulting from TF compilation. This pass is largely + unmaintained otherwise. 
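+
+    For example (an illustrative summary of the patterns below, not part of
+    the original description), it can rewrite `stablehlo.dynamic_reshape`
+    ops into `tensor.expand_shape`/`tensor.collapse_shape` ops and annotate
+    `stablehlo.dynamic_broadcast_in_dim` ops with statically known expanding
+    and non-expanding dimensions.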
+ }]; +} + diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/symbolic_shape_optimization.cpp similarity index 95% rename from third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc rename to third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/symbolic_shape_optimization.cpp index 961e512d239686..8ef8621a531bdf 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/symbolic_shape_optimization.cpp @@ -24,26 +24,27 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/analysis/shape_component_analysis.h" -#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo_ext/analysis/shape_component_analysis.h" +#include "stablehlo_ext/transforms/passes.h" // IWYU pragma: keep, passes.h.inc namespace mlir { -namespace mhlo { +namespace stablehlo_ext { -#define GEN_PASS_DEF_SYMBOLICSHAPEOPTIMIZATION -#include "mhlo/transforms/mhlo_passes.h.inc" +#define GEN_PASS_DEF_SYMBOLICSHAPEOPTIMIZATIONPASS +#include "stablehlo_ext/transforms/passes.h.inc" using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo; using Symbol = ShapeComponentAnalysis::Symbol; @@ -170,10 +171,10 @@ LogicalResult analyzeDynamicBroadcastInDimExpandingBehavior( // Analyze `mhlo.dynamic_broadcast_in_dim` op and populate attributes for // statically known expanding and non-expanding dimensions. struct AnnotateExpandingDimensionsInDynamicBroadcastInDim - : public mlir::OpRewritePattern { + : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( - mhlo::DynamicBroadcastInDimOp op, + stablehlo::DynamicBroadcastInDimOp op, mlir::PatternRewriter &rewriter) const override { // Analyze shapes and identify expanding and non-expanding dims. ShapeComponentAnalysis analysis; @@ -186,20 +187,18 @@ struct AnnotateExpandingDimensionsInDynamicBroadcastInDim // Collect possibly already annotated info. auto insertAll = [](llvm::SmallSetVector &dst, - std::optional src) { + std::optional> src) { if (!src) return; - for (auto it : *src) dst.insert(it.getLimitedValue()); + for (auto it : *src) dst.insert(it); }; insertAll(knownExpandingDims, op.getKnownExpandingDimensions()); insertAll(knownNonexpandingDims, op.getKnownNonexpandingDimensions()); // Fail pattern application if there is nothing new to annotate. 
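     // (The annotations take the form known_expanding_dimensions =
     // array<i64: 1>, known_nonexpanding_dimensions = array<i64: 0>;
     // illustrative values, see the lit tests below.)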
auto isEqual = [](llvm::SmallSetVector &set, - DenseIntElementsAttr attr) { + ArrayRef attr) { return static_cast(set.size()) == attr.size() && - llvm::all_of(attr, [&](auto it) { - return set.count(it.getLimitedValue()); - }); + llvm::all_of(attr, [&](auto it) { return set.count(it); }); }; if (op.getKnownExpandingDimensions() && op.getKnownNonexpandingDimensions() && @@ -211,9 +210,9 @@ struct AnnotateExpandingDimensionsInDynamicBroadcastInDim // Annotate op in place. rewriter.startOpModification(op); op.setKnownExpandingDimensionsAttr( - rewriter.getI64TensorAttr(knownExpandingDims.takeVector())); + rewriter.getDenseI64ArrayAttr(knownExpandingDims.takeVector())); op.setKnownNonexpandingDimensionsAttr( - rewriter.getI64TensorAttr(knownNonexpandingDims.takeVector())); + rewriter.getDenseI64ArrayAttr(knownNonexpandingDims.takeVector())); rewriter.finalizeOpModification(op); return success(); } @@ -268,7 +267,7 @@ bool isSymbolicProduct(const SymbolicExpr &symbolicExpr, LogicalResult materializeReshapeAsScalarExpand(RankedTensorType operandTy, RankedTensorType resultTy, - mhlo::DynamicReshapeOp op, + stablehlo::DynamicReshapeOp op, PatternRewriter &rewriter) { assert(operandTy.getRank() == 0 && "expect scalar operand"); auto loc = op.getLoc(); @@ -286,7 +285,7 @@ LogicalResult materializeReshapeAsScalarExpand(RankedTensorType operandTy, LogicalResult materializeReshapeAsScalarCollapse(RankedTensorType operandTy, RankedTensorType resultTy, - mhlo::DynamicReshapeOp op, + stablehlo::DynamicReshapeOp op, PatternRewriter &rewriter) { assert(resultTy.getRank() == 0 && "expect scalar result"); auto loc = op.getLoc(); @@ -559,7 +558,7 @@ std::optional> requiresReassociationOfKind( LogicalResult materializeReshapeAsExpandAndCollapse( ShapeComponentAnalysis &shapeAnalysis, RankedTensorType operandTy, - RankedTensorType resultTy, mhlo::DynamicReshapeOp op, + RankedTensorType resultTy, stablehlo::DynamicReshapeOp op, PatternRewriter &rewriter) { // Require sucessful shape analysis for operand and result shape. auto operandShapeInfo = shapeAnalysis.GetShapeInfo(op.getOperand()); @@ -607,9 +606,9 @@ LogicalResult materializeReshapeAsExpandAndCollapse( // Tries to express `dynamic_reshape` ops through `expand_shape` and // `collapse_shape` ops. 
struct DynamicReshapeToExpandAndCollapseShape final - : public OpRewritePattern { + : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mhlo::DynamicReshapeOp op, + LogicalResult matchAndRewrite(stablehlo::DynamicReshapeOp op, PatternRewriter &rewriter) const override { auto operandTy = mlir::dyn_cast(op.getOperand().getType()); @@ -770,7 +769,7 @@ struct BroadcastOpLowering final }; class SymbolicShapeOptimizationPass final - : public impl::SymbolicShapeOptimizationBase< + : public impl::SymbolicShapeOptimizationPassBase< SymbolicShapeOptimizationPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -802,10 +801,5 @@ class SymbolicShapeOptimizationPass final } // end namespace -std::unique_ptr> -createSymbolicShapeOptimizationPass() { - return std::make_unique(); -} - -} // end namespace mhlo +} // namespace stablehlo_ext } // end namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir b/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/symbolic-shape-optimization.mlir similarity index 79% rename from third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir rename to third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/symbolic-shape-optimization.mlir index 17f27e2c67f9aa..f5fb9d2395d2b5 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/symbolic-shape-optimization.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s --split-input-file --symbolic-shape-optimization | \ +// RUN: mlir-hlo-opt %s --split-input-file --stablehlo-ext-symbolic-shape-optimization | \ // RUN: FileCheck %s // CHECK-LABEL: func @reshape_expand_front @@ -8,7 +8,7 @@ func.func @reshape_expand_front(%arg0: tensor) -> tensor<1x?x?xf32> { %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg0, %c1 : tensor %shape = tensor.from_elements %c1, %d0, %d1 : tensor<3xindex> - %reshape = "mhlo.dynamic_reshape"(%arg0, %shape) + %reshape = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<3xindex>) -> tensor<1x?x?xf32> // CHECK: tensor.expand_shape %arg0 [ // CHECK-SAME: [0, 1], [2]] {{.*}} : tensor into tensor<1x?x?xf32> @@ -22,7 +22,7 @@ func.func @reshape_expand_front_static(%arg0: tensor<2x?xf32>) -> tensor<1x2x?xf %d0 = tensor.dim %arg0, %c0 : tensor<2x?xf32> %d1 = tensor.dim %arg0, %c1 : tensor<2x?xf32> %shape = tensor.from_elements %c1, %d0, %d1 : tensor<3xindex> - %reshape = "mhlo.dynamic_reshape"(%arg0, %shape) + %reshape = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<2x?xf32>, tensor<3xindex>) -> tensor<1x2x?xf32> // CHECK: tensor.expand_shape %arg0 [ // CHECK-SAME: [0, 1], [2]] {{.*}} : tensor<2x?xf32> into tensor<1x2x?xf32> @@ -38,7 +38,7 @@ func.func @reshape_expand_back(%arg0: tensor) -> tensor { %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg0, %c1 : tensor %shape = tensor.from_elements %d0, %d1, %c1, %c1 : tensor<4xindex> - %reshape = "mhlo.dynamic_reshape"(%arg0, %shape) + %reshape = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<4xindex>) -> tensor // CHECK: tensor.expand_shape %arg0 [ // CHECK-SAME: [0], [1, 2, 3]] {{.*}} : tensor into tensor @@ -53,8 +53,8 @@ func.func @reshape_expand_scalar(%arg0: tensor) -> tensor { // CHECK-DAG: %[[EXPAND:.*]] = tensor.expand_shape %[[ARG]] [] {{.*}} : tensor into tensor<1x1xf32> // CHECK-DAG: %[[RES:.*]] = tensor.cast %[[EXPAND]] : tensor<1x1xf32> to tensor // CHECK: return 
%[[RES]] - %shape = mhlo.constant dense<1> : tensor<2xi32> - %reshape = "mhlo.dynamic_reshape"(%arg0, %shape) + %shape = stablehlo.constant dense<1> : tensor<2xi32> + %reshape = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xi32>) -> tensor func.return %reshape : tensor } @@ -64,11 +64,11 @@ func.func @reshape_expand_scalar(%arg0: tensor) -> tensor { // CHECK-LABEL: @reshape_collapse_scalar // CHECK-SAME: %[[ARG:.*]]: tensor func.func @reshape_collapse_scalar(%arg0 : tensor) -> tensor { - %shape = mhlo.constant dense<1> : tensor<0xi32> + %shape = stablehlo.constant dense<1> : tensor<0xi32> // CHECK-DAG: %[[CASTED_ARG:.*]] = tensor.cast %[[ARG]] : tensor to tensor<1x1xf32> // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED_ARG]] [] : tensor<1x1xf32> into tensor // CHECK: return %[[COLLAPSED]] - %reshape = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<0xi32>) -> tensor + %reshape = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<0xi32>) -> tensor func.return %reshape : tensor } @@ -76,10 +76,10 @@ func.func @reshape_collapse_scalar(%arg0 : tensor) -> tensor { // CHECK-LABEL: func @reshape_undefined func.func @reshape_undefined(%arg0: tensor) -> tensor<1x1x1xf32> { - // CHECK: mhlo.dynamic_reshape + // CHECK: stablehlo.dynamic_reshape %c1 = arith.constant 1 : index %shape = tensor.from_elements %c1, %c1, %c1 : tensor<3xindex> - %reshape = "mhlo.dynamic_reshape"(%arg0, %shape) + %reshape = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<3xindex>) -> tensor<1x1x1xf32> func.return %reshape : tensor<1x1x1xf32> } @@ -95,7 +95,7 @@ func.func @shape_expansion(%arg : tensor) -> tensor { %c1 = arith.constant 1 : index %d0 = tensor.dim %arg, %c0 : tensor %shape = tensor.from_elements %d0, %c1, %c1 : tensor<3xindex> - %result = "mhlo.dynamic_reshape"(%arg, %shape) + %result = "stablehlo.dynamic_reshape"(%arg, %shape) : (tensor, tensor<3xindex>) -> tensor func.return %result : tensor } @@ -114,7 +114,7 @@ func.func @shape_collapse_and_expansion(%arg : tensor<3x?x1xi64>) %d1 = tensor.dim %arg, %c1 : tensor<3x?x1xi64> %three_d1 = arith.muli %c3, %d1 : index %15 = tensor.from_elements %three_d1, %c1, %c1 : tensor<3xindex> - %16 = "mhlo.dynamic_reshape"(%arg, %15) + %16 = "stablehlo.dynamic_reshape"(%arg, %15) : (tensor<3x?x1xi64>, tensor<3xindex>) -> tensor func.return %16 : tensor } @@ -137,7 +137,7 @@ func.func @shape_collapse_and_expansion_w_cast(%arg0: tensor<16x8x?x?xf32>) -> t %2 = tensor.dim %arg0, %c3 : tensor<16x8x?x?xf32> %4 = arith.muli %1, %2 : index %5 = tensor.from_elements %c16, %c4, %c2, %4 : tensor<4xindex> - %6 = "mhlo.dynamic_reshape"(%arg0, %5) : (tensor<16x8x?x?xf32>, tensor<4xindex>) -> tensor<16x4x?x?xf32> + %6 = "stablehlo.dynamic_reshape"(%arg0, %5) : (tensor<16x8x?x?xf32>, tensor<4xindex>) -> tensor<16x4x?x?xf32> func.return %6 : tensor<16x4x?x?xf32> } @@ -161,7 +161,7 @@ func.func @dynamic_reshape_to_collapse_shape(%arg0 : tensor<1x4x?x64x?x8x1x1xf32 %s0 = arith.muli %c4_i32, %d2_i32 : i32 %s1 = arith.muli %c64_i32, %d4_i32 : i32 %shape = tensor.from_elements %s0, %s1, %c8_i32 : tensor<3xi32> - %result = "mhlo.dynamic_reshape"(%arg0, %shape) + %result = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<1x4x?x64x?x8x1x1xf32>, tensor<3xi32>) -> tensor func.return %result : tensor } @@ -177,7 +177,7 @@ func.func @expansion_unit_dims(%arg0: tensor<1x?x1xi64>) -> tensor<1x1x?x1xi64> %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c1 : tensor<1x?x1xi64> %1 = tensor.from_elements %c1, %c1, %0, %c1 : 
tensor<4xindex> - %2 = "mhlo.dynamic_reshape"(%arg0, %1) + %2 = "stablehlo.dynamic_reshape"(%arg0, %1) : (tensor<1x?x1xi64>, tensor<4xindex>) -> tensor<1x1x?x1xi64> func.return %2 : tensor<1x1x?x1xi64> } @@ -187,43 +187,43 @@ func.func @expansion_unit_dims(%arg0: tensor<1x?x1xi64>) -> tensor<1x1x?x1xi64> // CHECK-LABEL: @multiple_reductions_and_reshape // CHECK-SAME: %[[ARG:.*]]: tensor func.func @multiple_reductions_and_reshape(%arg0: tensor) -> tensor<1x1x1x1xi64> { - // CHECK: %[[RED0:.*]] = mhlo.reduce(%[[ARG]] + // CHECK: %[[RED0:.*]] = stablehlo.reduce(%[[ARG]] // CHECK: %[[RED0_:.*]] = tensor.expand_shape %[[RED0]] {{\[}}[0], [1], [2, 3]{{\]}} {{.*}} : tensor into tensor - // CHECK: %[[RED1:.*]] = mhlo.reduce(%[[RED0_]] + // CHECK: %[[RED1:.*]] = stablehlo.reduce(%[[RED0_]] // CHECK: %[[RED1_:.*]] = tensor.expand_shape %[[RED1]] {{\[}}[0, 1, 2], [3]{{\]}} {{.*}} : tensor into tensor<1x1x?x1xi64> - // CHECK: %[[RED2:.*]] = mhlo.reduce(%[[RED1_]] + // CHECK: %[[RED2:.*]] = stablehlo.reduce(%[[RED1_]] // TODO(b/225204462): This should also become a shape expansion. - // CHECK: %[[RED2_:.*]] = mhlo.reshape %[[RED2]] : (tensor<1xi64>) -> tensor<1x1x1x1xi64> + // CHECK: %[[RED2_:.*]] = stablehlo.reshape %[[RED2]] : (tensor<1xi64>) -> tensor<1x1x1x1xi64> // CHECK: return %[[RED2_]] - %0 = mhlo.constant dense<9223372036854775807> : tensor + %0 = stablehlo.constant dense<9223372036854775807> : tensor %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index - %1 = mhlo.constant dense<1> : tensor - %2 = mhlo.reduce(%arg0 init: %0) - applies mhlo.minimum across dimensions = [3] + %1 = stablehlo.constant dense<1> : tensor + %2 = stablehlo.reduce(%arg0 init: %0) + applies stablehlo.minimum across dimensions = [3] : (tensor, tensor) -> tensor %3 = tensor.dim %2, %c0 : tensor %4 = tensor.dim %2, %c1 : tensor %5 = tensor.dim %2, %c2 : tensor %6 = tensor.from_elements %3, %4, %5, %c1 : tensor<4xindex> - %7 = "mhlo.dynamic_reshape"(%2, %6) + %7 = "stablehlo.dynamic_reshape"(%2, %6) : (tensor, tensor<4xindex>) -> tensor - %8 = mhlo.reduce(%7 init: %0) - applies mhlo.minimum across dimensions = [0, 1] + %8 = stablehlo.reduce(%7 init: %0) + applies stablehlo.minimum across dimensions = [0, 1] : (tensor, tensor) -> tensor %9 = tensor.dim %8, %c0 : tensor %10 = tensor.from_elements %c1, %9, %c1 : tensor<3xindex> - %11 = "mhlo.dynamic_reshape"(%8, %10) + %11 = "stablehlo.dynamic_reshape"(%8, %10) : (tensor, tensor<3xindex>) -> tensor<1x?x1xi64> %12 = tensor.dim %11, %c1 : tensor<1x?x1xi64> %13 = tensor.from_elements %c1, %c1, %12, %c1 : tensor<4xindex> - %14 = "mhlo.dynamic_reshape"(%8, %13) + %14 = "stablehlo.dynamic_reshape"(%8, %13) : (tensor, tensor<4xindex>) -> tensor<1x1x?x1xi64> - %15 = mhlo.reduce(%14 init: %1) - applies mhlo.multiply across dimensions = [0, 1, 2] + %15 = stablehlo.reduce(%14 init: %1) + applies stablehlo.multiply across dimensions = [0, 1, 2] : (tensor<1x1x?x1xi64>, tensor) -> tensor<1xi64> - %16 = "mhlo.reshape"(%15) : (tensor<1xi64>) -> tensor<1x1x1x1xi64> + %16 = "stablehlo.reshape"(%15) : (tensor<1xi64>) -> tensor<1x1x1x1xi64> func.return %16 : tensor<1x1x1x1xi64> } @@ -284,11 +284,11 @@ func.func @optimize_1dx1d_bcast( %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<1xindex> -> tensor<1xindex> - // CHECK: mhlo.dynamic_broadcast_in_dim - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<0> - %3 = 
"mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array + %3 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = array} : (tensor, tensor<1xindex>) -> tensor func.return %3: tensor } @@ -305,11 +305,11 @@ func.func @optimize_1dx2d_bcast_const_shape( %1 = shape.shape_of %arg1 : tensor -> tensor<2xindex> %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<2xindex> -> tensor<2xindex> - // CHECK: mhlo.dynamic_broadcast_in_dim - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<0> - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[1]> : tensor<1xi64>} + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array + %3 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = array} : (tensor<512xf32>, tensor<2xindex>) -> tensor func.return %3: tensor } @@ -331,11 +331,11 @@ func.func @optimize_1dx1dx1d_bcast( -> tensor<1xindex> %4 = shape.broadcast %3, %2 : tensor<1xindex>, tensor<1xindex> -> tensor<1xindex> - // CHECK: mhlo.dynamic_broadcast_in_dim - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<0> - %5 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) - {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array + %5 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %4) + {broadcast_dimensions = array} : (tensor, tensor<1xindex>) -> tensor func.return %5: tensor } @@ -353,17 +353,17 @@ func.func @optimize_2dx1d_bcast( %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<1xindex> -> tensor<2xindex> - // CHECK: mhlo.dynamic_broadcast_in_dim - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<[0, 1]> - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array + %3 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = array} : (tensor<10x?xf32>, tensor<2xindex>) -> tensor<10x?xf32> - // CHECK: mhlo.dynamic_broadcast_in_dim - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<0> - %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %2) - {broadcast_dimensions = dense<[1]> : tensor<1xi64>} + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array + %4 = "stablehlo.dynamic_broadcast_in_dim"(%arg1, %2) + {broadcast_dimensions = array} : (tensor, tensor<2xindex>) -> tensor<10x?xf32> func.return %3, %4: tensor<10x?xf32>, tensor<10x?xf32> } @@ -381,17 +381,17 @@ func.func @optimize_3dx3d_bcast( %1 = shape.shape_of %arg1 : tensor<1x?x1xf32> -> tensor<3xindex> %2 = shape.broadcast %0, %1 : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> - // CHECK: mhlo.dynamic_broadcast_in_dim - // CHECK-SAME: known_expanding_dimensions = 
dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<[0, 2]> - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array + %3 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = array} : (tensor, tensor<3xindex>) -> tensor - // CHECK: mhlo.dynamic_broadcast_in_dim - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<1> - %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %2) - {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array + %4 = "stablehlo.dynamic_broadcast_in_dim"(%arg1, %2) + {broadcast_dimensions = array} : (tensor<1x?x1xf32>, tensor<3xindex>) -> tensor func.return %3, %4: tensor, tensor } @@ -412,12 +412,11 @@ func.func @optimize_10d_all_cases( -> tensor<10xindex> %2 = shape.broadcast %0, %1 : tensor<10xindex>, tensor<10xindex> -> tensor<10xindex> - // CHECK: mhlo.dynamic_broadcast_in_dim - // CHECK-SAME: known_expanding_dimensions = dense<1> - // CHECK-SAME: known_nonexpanding_dimensions = dense<[0, 3, 4, 5, 6, 9]> - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> - : tensor<10xi64>} + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array + %3 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = array} : (tensor<1x1x1x8x8x8x?x?x?x?xf32>, tensor<10xindex>) -> tensor func.return %3: tensor @@ -519,17 +518,16 @@ func.func @optimize_1dx1d_bcast( {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>} ) -> tensor { // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] - // CHECK: %[[DYNAMIC:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]]) - // CHECK-SAME: broadcast_dimensions = dense<0> - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<0> + // CHECK: %[[DYNAMIC:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[SHAPE]], dims = [0] + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array // CHECK: return %[[DYNAMIC]] %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<1xindex> -> tensor<1xindex> - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + %3 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = array} : (tensor, tensor<1xindex>) -> tensor func.return %3: tensor } @@ -544,17 +542,16 @@ func.func @optimize_1dx2d_bcast_const_shape( {rt.symbolic_shape = dense<[-2, 512]> : tensor<2xi64>} ) -> tensor { // CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG1_0]] - // CHECK: %[[DYNAMIC_0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0_0]], %[[SHAPE_0]]) - // CHECK-SAME: broadcast_dimensions = dense<1> - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<0> + // CHECK: %[[DYNAMIC_0:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0_0]], %[[SHAPE_0]], dims = [1] + // CHECK-SAME: known_expanding_dimensions = array + // 
CHECK-SAME: known_nonexpanding_dimensions = array // CHECK: return %[[DYNAMIC_0]] %0 = shape.const_shape [512] : tensor<1xindex> %1 = shape.shape_of %arg1 : tensor -> tensor<2xindex> %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<2xindex> -> tensor<2xindex> - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[1]> : tensor<1xi64>} + %3 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = array} : (tensor<512xf32>, tensor<2xindex>) -> tensor func.return %3: tensor } @@ -572,10 +569,9 @@ func.func @optimize_1dx1dx1d_bcast( {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>} ) -> tensor { // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG0_1]] - // CHECK: %[[DYNAMIC_1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0_1]], %[[SHAPE_1]]) - // CHECK-SAME: broadcast_dimensions = dense<0> - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<0> + // CHECK: %[[DYNAMIC_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0_1]], %[[SHAPE_1]], dims = [0] + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array // CHECK: return %[[DYNAMIC_1]] %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> @@ -584,8 +580,8 @@ func.func @optimize_1dx1dx1d_bcast( -> tensor<1xindex> %4 = shape.broadcast %3, %2 : tensor<1xindex>, tensor<1xindex> -> tensor<1xindex> - %5 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) - {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + %5 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %4) + {broadcast_dimensions = array} : (tensor, tensor<1xindex>) -> tensor func.return %5: tensor } @@ -601,24 +597,22 @@ func.func @optimize_2dx1d_bcast( {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>} ) -> (tensor<10x?xf32>, tensor<10x?xf32>) { // CHECK: %[[SHAPE_2:.*]] = shape.shape_of %[[ARG0_2]] - // CHECK: %[[DYNAMIC_2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0_2]], %[[SHAPE_2]]) - // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<[0, 1]> - // CHECK: %[[DYNAMIC_3:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1_2]], %[[SHAPE_2]]) - // CHECK-SAME: broadcast_dimensions = dense<1> - // CHECK-SAME: known_expanding_dimensions = dense<> - // CHECK-SAME: known_nonexpanding_dimensions = dense<0> + // CHECK: %[[DYNAMIC_2:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0_2]], %[[SHAPE_2]], dims = [0, 1] + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array + // CHECK: %[[DYNAMIC_3:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1_2]], %[[SHAPE_2]], dims = [1] + // CHECK-SAME: known_expanding_dimensions = array + // CHECK-SAME: known_nonexpanding_dimensions = array // CHECK: return %[[DYNAMIC_2]], %[[DYNAMIC_3]] %0 = shape.shape_of %arg0 : tensor<10x?xf32> -> tensor<2xindex> %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<1xindex> -> tensor<2xindex> - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + %3 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = array} : (tensor<10x?xf32>, tensor<2xindex>) -> tensor<10x?xf32> - %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %2) - {broadcast_dimensions = dense<[1]> : tensor<1xi64>} + %4 = "stablehlo.dynamic_broadcast_in_dim"(%arg1, %2) + {broadcast_dimensions = 
array<i64: 1>}
       : (tensor<?xf32>, tensor<2xindex>) -> tensor<10x?xf32>
   func.return %3, %4: tensor<10x?xf32>, tensor<10x?xf32>
 }

From 11f104e625a4d20d3aced2d30c8ca7847fecde03 Mon Sep 17 00:00:00 2001
From: Vamsi Manchala
Date: Mon, 31 Mar 2025 19:45:14 -0700
Subject: [PATCH 0064/1324] Create better infrastructure for BatchMatMul
 optimizations. And add a new optimization pattern.

Add infrastructure and utilities so it's easy and intuitive to create and
maintain BatchMatMul optimizations.

The optimization pattern introduced in this CL will eliminate a transpose
from Transpose->reshape->bmm_rhs if the transpose is used exclusively to
transpose the unflattened contracting and output dims.

Consider the following transpose->reshape->bmm pattern.
```
%37 = "tfl.transpose"(%28, %36) : (tensor<2048x32x128xf32>, tensor<3xi32>) -> tensor<128x2048x32xf32>
%38 = "tfl.pseudo_const"() <{value = dense<[128, 65536]> : tensor<2xi32>}> : () -> tensor<2xi32>
%39 = "tfl.reshape"(%37, %38) : (tensor<128x2048x32xf32>, tensor<2xi32>) -> tensor<128x65536xf32>
%40 = "tfl.pseudo_const"() <{value = dense_resource<__elided__> : tensor<4x128xf32>}> : () -> tensor<4x128xf32>
%41 = "tfl.batch_matmul"(%40, %39) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x128xf32>, tensor<128x65536xf32>) -> tensor<4x65536xf32>
```

This can be re-written to use fully_connected instead of bmm, and we can
avoid the transpose:
```
%39 = "tfl.reshape"(%28, %abc) : (tensor<2048x32x128xf32>, tensor<2xi32>) -> tensor<65536x128xf32>
%40 = "tfl.pseudo_const"() <{value = dense_resource<__elided__> : tensor<4x128xf32>}> : () -> tensor<4x128xf32>
%41 = "tfl.fully_connected"(%40, %39) : (tensor<4x128xf32>, tensor<65536x128xf32>) -> tensor<4x65536xf32>
```

PiperOrigin-RevId: 742494098
---
 tensorflow/compiler/mlir/lite/BUILD           |  28 ++
 .../compiler/mlir/lite/ir/tfl_canonicalize.td |   4 +-
 .../lite/tests/optimize_batch_matmul.mlir     |  33 ++
 .../transforms/optimize_batch_matmul_pass.cc  | 130 +++++++-
 .../mlir/lite/transforms/optimize_pass.cc     |  15 -
 .../mlir/lite/transforms/optimize_patterns.td |  12 +-
 .../optimize_batch_matmul_utils.cc            | 303 ++++++++++++++++++
 .../optimize_batch_matmul_utils.h             | 141 ++++++++
 .../optimize_batch_matmul_utils_test.cc       | 168 ++++++++++
 tensorflow/compiler/mlir/lite/utils/utils.h   |  75 ++++-
 tensorflow/compiler/mlir/lite/utils/utils.td  |   2 +-
 11 files changed, 868 insertions(+), 43 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc
 create mode 100644 tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h
 create mode 100644 tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc

diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index b5cac23baa56b4..26e436cc519c72 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -395,6 +395,8 @@ cc_library(
     name = "utils",
     hdrs = ["utils/utils.h"],
     deps = [
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:IR",
@@ -642,6 +644,8 @@ cc_library(
         "//tensorflow/core:framework",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@eigen_archive//:eigen3",
         "@llvm-project//llvm:Support",
@@ -1213,6 +1217,29 @@
cc_library( ], ) +cc_library( + name = "optimize_batch_matmul_utils", + srcs = ["transforms/tflite_passes/optimize_batch_matmul_utils.cc"], + hdrs = ["transforms/tflite_passes/optimize_batch_matmul_utils.h"], + deps = [ + ":utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "optimize_batch_matmul_utils_test", + srcs = ["transforms/tflite_passes/optimize_batch_matmul_utils_test.cc"], + deps = [ + ":optimize_batch_matmul_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "tensorflow_lite_optimize_batch_matmul", srcs = [ @@ -1224,6 +1251,7 @@ cc_library( ], deps = [ ":convert_type", + ":optimize_batch_matmul_utils", ":pass", ":pass_options", ":tensorflow_lite_ops", diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td b/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td index a359fb9506b2b3..3881a1e291770d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td @@ -32,7 +32,7 @@ def GetSqueezedPermutation: NativeCodeCall<"GetSqueezedPermutation($0, $1)">; // Check to see if the tensor dimensions can be Squeezed by eliminating 1s' def CanSqueezeTensor : Constraint GetSqueezedShape($0).getNumElements()">>; + "GetShapeAttr($0).getNumElements() > GetSqueezedShape($0).getNumElements()">>; // Pattern to convert TFL_TransposeOp with rank>6 to rank<=6 if there are @@ -50,7 +50,7 @@ def ConvertTransposeToDecreaseRank : Pat< (TFL_TransposeOp (TFL_ReshapeOp $input, (Arith_ConstantOp (GetSqueezedShape $input))), (Arith_ConstantOp (GetSqueezedPermutation $input, $permutation))), - (Arith_ConstantOp (GetShape $output_transpose))), + (Arith_ConstantOp (GetShapeAttr $output_transpose))), [(AnyStaticShapeTensor $input), (HasRankAtLeast<7> $input), (CanSqueezeTensor $input)]>; diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir index 79f50aaaadab3d..39b1346bcf93d6 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir @@ -170,3 +170,36 @@ func.func @BatchmatmulToReduceSumF32(%arg0: tensor<1x16384x257xf32>) -> (tensor< // CHECK: %[[CONST_DIM:.*]] = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> // CHECK: %[[RED:.*]] = "tfl.sum"(%arg0, %[[CONST_DIM]]) <{keep_dims = true}> : (tensor<1x16384x257xf32>, tensor<1xi32>) -> tensor<1x1x257xf32> } + +// CHECK-LABEL: FuseBatchMatmulToTransposeNoBatchDims +func.func @FuseBatchMatmulToTransposeNoBatchDims(%arg0: tensor<2048x32x128xf32>, %arg1: tensor<4x128xf32>) -> tensor<4x65536xf32> { + %36 = "tfl.pseudo_const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %37 = "tfl.transpose"(%arg0, %36) : (tensor<2048x32x128xf32>, tensor<3xi32>) -> tensor<128x2048x32xf32> + %38 = "tfl.pseudo_const"() <{value = dense<[128, 65536]> : tensor<2xi32>}> : () -> tensor<2xi32> + %39 = "tfl.reshape"(%37, %38) : (tensor<128x2048x32xf32>, tensor<2xi32>) -> tensor<128x65536xf32> + %41 = "tfl.batch_matmul"(%arg1, %39) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x128xf32>, tensor<128x65536xf32>) -> tensor<4x65536xf32> + return %41 : tensor<4x65536xf32> + // CHECK-NOT: "tfl.transpose" +} + +// CHECK-LABEL: FuseBatchMatmulToTransposeWithBatchDims +func.func 
@FuseBatchMatmulToTransposeWithBatchDims(%arg0: tensor<2048x1x8x32x32xf32>, %arg1: tensor<2048x1x2x32xf32>) -> tensor<2048x1x2x256xf32> { + %104 = "tfl.pseudo_const"() <{value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>}> : () -> tensor<5xi32> + %106 = "tfl.pseudo_const"() <{value = dense<[2048, 1, 32, 256]> : tensor<4xi32>}> : () -> tensor<4xi32> + %202 = "tfl.transpose"(%arg0, %104) : (tensor<2048x1x8x32x32xf32>, tensor<5xi32>) -> tensor<2048x1x32x8x32xf32> + %203 = "tfl.reshape"(%202, %106) : (tensor<2048x1x32x8x32xf32>, tensor<4xi32>) -> tensor<2048x1x32x256xf32> + %204 = "tfl.batch_matmul"(%arg1, %203) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2048x1x2x32xf32>, tensor<2048x1x32x256xf32>) -> tensor<2048x1x2x256xf32> + return %204 : tensor<2048x1x2x256xf32> + // CHECK-NOT: "tfl.transpose" +} + +// CHECK-LABEL: FuseBatchMatmulToTransposeNegative +func.func @FuseBatchMatmulToTransposeNegative(%arg0: tensor<2048x32x1x8x2xf32>, %arg1: tensor<2048x1x32x2xf32>) -> tensor<2048x1x32x256xf32> { + %88 = "tfl.pseudo_const"() <{value = dense<[0, 2, 4, 1, 3]> : tensor<5xi32>}> : () -> tensor<5xi32> + %90 = "tfl.pseudo_const"() <{value = dense<[2048, 1, 2, 256]> : tensor<4xi32>}> : () -> tensor<4xi32> + %194 = "tfl.transpose"(%arg0, %88) : (tensor<2048x32x1x8x2xf32>, tensor<5xi32>) -> tensor<2048x1x2x32x8xf32> + %195 = "tfl.reshape"(%194, %90) : (tensor<2048x1x2x32x8xf32>, tensor<4xi32>) -> tensor<2048x1x2x256xf32> + %196 = "tfl.batch_matmul"(%arg1, %195) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2048x1x32x2xf32>, tensor<2048x1x2x256xf32>) -> tensor<2048x1x32x256xf32> + return %196 : tensor<2048x1x32x256xf32> + // CHECK: "tfl.transpose" +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc index 2451089517c549..71ebbab92c1a71 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" namespace mlir { @@ -56,7 +57,7 @@ bool NotFromDequant(mlir::Value value) { // Converts batch_matmul operation to fully_connected if rhs is a // constant tensor with rank 2 -struct ConvertBatchMatMulOp2FullyConnectedOp +struct ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::BatchMatMulOp bmm_op, @@ -263,6 +264,127 @@ struct ConvertBatchMatMulOpToReduceSum return false; } }; + +// Pattern to fuse transpose op into RHS of batch_matmul op if the transpose and +// batch_matmul are separated by a reshape op; and the transpose op is used +// exclusively to transpose the contracting dimension and the LHS-Output +// dimension. +// Converts batch_matmul operation to fully_connected if rhs is rank-2 +// else converts it to a BatchMatMul op with adj_y = true and transpose fused +// into RHS. 
+//
+// Example:
+// % 0 = "tfl.transpose" // Input: [2048, 32, 128] -> [128, 2048, 32]
+// % 1 = "tfl.reshape"(%0) // reshaped [128, 2048, 32] -> [128, 65536]
+// % 2 = "tfl.batch_matmul" // LHS: [4, 128], RHS: [128, 65536] -> [4, 65536]
+struct FuseRhsTransposeIntoBatchMatMulOp
+    : public OpRewritePattern<TFL::BatchMatMulOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(TFL::BatchMatMulOp bmm_op,
+                                PatternRewriter& rewriter) const override {
+    // Exit the pattern if adj_y is true.
+    if (bmm_op.getAdjY()) {
+      return rewriter.notifyMatchFailure(
+          bmm_op, "Pattern does not apply when adj_y is true.");
+    }
+
+    // Exit the pattern if the RHS of BatchMatMulOp is not originated from a
+    // TFL::TransposeOp->TFL::ReshapeOp.
+    auto reshape_op = bmm_op.getY().getDefiningOp<TFL::ReshapeOp>();
+    if (!reshape_op) {
+      return rewriter.notifyMatchFailure(
+          bmm_op,
+          "RHS is not originated from a transpose->reshape op pattern.");
+    }
+
+    auto transpose_op =
+        reshape_op.getInput().getDefiningOp<TFL::TransposeOp>();
+    if (!transpose_op) {
+      return rewriter.notifyMatchFailure(
+          bmm_op,
+          "RHS is not originated from a transpose->reshape op pattern.");
+    }
+
+    // Get the dimensions info of the RHS of BatchMatMulOp.
+    auto rhs_dimensions_info = GetBatchMatMulRhsDimensionsInfo(
+        mlir::cast<ShapedType>(bmm_op.getY().getType()));
+
+    // Make sure that the reshape op is flattening either the contracting
+    // dimension or the output dimension.
+    auto reshape_input_shape = GetShape(reshape_op.getInput());
+    if (!HasFlattenedContractingDims(reshape_input_shape,
+                                     rhs_dimensions_info) &&
+        !HasFlattenedOutDims(reshape_input_shape, rhs_dimensions_info)) {
+      return rewriter.notifyMatchFailure(
+          bmm_op,
+          "Reshape op is not flattening the contracting dimension or the "
+          "output dimension.");
+    }
+
+    // Make sure that the transpose op is only transposing the contracting
+    // dimensions and the output dimensions.
+    auto transpose_perm_status_or_value =
+        GetValueAsIntArray(transpose_op.getPerm());
+    auto transpose_input_shape = GetShape(transpose_op.getInput());
+    if (transpose_perm_status_or_value.ok() &&
+        !HasTransposedContractingAndOutDims(
+            transpose_input_shape, transpose_perm_status_or_value.value(),
+            rhs_dimensions_info)) {
+      return rewriter.notifyMatchFailure(
+          bmm_op,
+          "Transpose op is not transposing the contracting dimension and the "
+          "output dimension.");
+    }
+
+    auto rhs_contracting_dimensions =
+        rhs_dimensions_info.contracting_dimensions();
+    auto rhs_out_dimensions = rhs_dimensions_info.out_dimensions();
+    auto rhs_batch_dimensions = rhs_dimensions_info.batch_dimensions();
+
+    // Create a new ReshapeOp, without the TransposeOp, to flatten the
+    // contracting dimension and the output dimension, as needed.
+    llvm::SmallVector<int32_t> new_reshape_input_shape;
+    if (!rhs_dimensions_info.batch_dimensions().AxesArray().empty()) {
+      for (auto dim_size : rhs_batch_dimensions.SizesArray()) {
+        new_reshape_input_shape.push_back(dim_size);
+      }
+    }
+    new_reshape_input_shape.push_back(rhs_out_dimensions.SizesArray().front());
+    new_reshape_input_shape.push_back(
+        rhs_contracting_dimensions.SizesArray().front());
+
+    Value new_reshape_shape_value = rewriter.create<arith::ConstantOp>(
+        bmm_op->getLoc(),
+        GetI32ElementsAttr(new_reshape_input_shape, &rewriter));
+    auto new_reshape_value = rewriter.create<TFL::ReshapeOp>(
+        bmm_op->getLoc(), transpose_op.getInput(), new_reshape_shape_value);
+
+    // Replace the BatchMatMulOp with a FullyConnectedOp, if the RHS of BMM has
+    // no broadcasting dimensions. I.e. RHS of BMM is of Rank 2.
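+    // e.g. with an LHS of type tensor<4x128xf32> and an RHS of type
+    // tensor<128x65536xf32>, the op becomes a tfl.fully_connected with a
+    // filter of type tensor<65536x128xf32> (shapes from the example above).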
+    if (rhs_dimensions_info.batch_dimensions().AxesArray().empty()) {
+      auto no_input = rewriter.create<TFL::NoValueOp>(
+          bmm_op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
+      auto fc_op = rewriter.create<TFL::FullyConnectedOp>(
+          bmm_op->getLoc(), ArrayRef<Type>{bmm_op.getType()},
+          /*input=*/bmm_op.getX(), /*filter=*/new_reshape_value,
+          /*bias=*/no_input,
+          /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
+          /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
+          /*keep_num_dims=*/rewriter.getBoolAttr(true),
+          /*asymmetric_quantize_inputs=*/mlir::BoolAttr());
+      rewriter.replaceOp(bmm_op, {fc_op.getResult(0)});
+    } else {
+      // Replace the BatchMatMulOp with a BatchMatMulOp with adj_y = true and
+      // transpose fused into RHS.
+      auto bmm_op_with_adj_y = rewriter.create<TFL::BatchMatMulOp>(
+          bmm_op->getLoc(), bmm_op.getType(), bmm_op.getX(), new_reshape_value,
+          bmm_op.getAdjX(), /*adj_y=*/true, mlir::BoolAttr());
+      rewriter.replaceOp(bmm_op, {bmm_op_with_adj_y.getResult()});
+    }
+
+    return success();
+  }
+};
+
 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize_batch_matmul.inc"
 }  // namespace

@@ -271,8 +393,10 @@ void OptimizeBatchMatmulPass::runOnOperation() {
   auto* ctx = &getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<ConvertBatchMatMulOp2FullyConnectedOp,
-               ConvertBatchMatMulOpToReduceSum>(ctx);
+  patterns
+      .add<ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs,
+           ConvertBatchMatMulOpToReduceSum, FuseRhsTransposeIntoBatchMatMulOp>(
+          ctx);
   TFL::populateWithGenerated(patterns);
   (void)applyPatternsGreedily(func, std::move(patterns));
 }
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc
index 994762b7641ebb..cb2702886dc719 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc
@@ -824,21 +824,6 @@ bool IsPermutationNCHW(Value perm) {

 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"

-// Returns 1D 32-bit dense elements attribute with the given values.
-static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
-                                               Builder *builder) {
-  RankedTensorType ty = mlir::RankedTensorType::get(
-      {static_cast<int32_t>(values.size())}, builder->getIntegerType(32));
-  return DenseIntElementsAttr::get(ty, values);
-}
-
-DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
-                                        Builder *builder) {
-  RankedTensorType ty = RankedTensorType::get(
-      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
-  return DenseIntElementsAttr::get(ty, values);
-}
-
 // Get the number of leading 1s in the shape of the given input.
 // Ex. input_shape = [1 x 1 x 1 x 1 x 2 x 1] => 4
 // returns 0 if the input shape is not static.
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index cfa1e21619d203..2e08d0f89959a9 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -725,7 +725,7 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,

 // Returns truncated shape of a ranked-tensor.
 // Prefix-Truncated, here, means eliminating any contiguous 1s in the lower
 // dimensions of the tensor
-def GetPrefixTruncatedShape: NativeCodeCall<"GetShape($0, true)">;
+def GetPrefixTruncatedShape: NativeCodeCall<"GetShapeAttr($0, true)">;

 // Returns True if the operand type is RankedTensorType and valid.
def HasValidRankedTensor : Constraint>; + "GetShapeAttr($0, true) == GetShapeAttr($1)">>; def ConvertSqueezeToReshape : Pat< (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $squeeze_op))), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShapeAttr $squeeze_op))), [(HasValidRankedTensor $squeeze_op)]>; // Pattern to perform the following optimization @@ -793,7 +793,7 @@ def UndoBroadcastConvBiasAdd : Pat< // Pattern to convert a trivial transpose op to a reshape op. def ConvertTrivialTransposeOpToReshapeOp : Pat< (TFL_TransposeOp:$transpose_op $input, (Arith_ConstantOp:$permutation $p1)), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $transpose_op))), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShapeAttr $transpose_op))), [(IsTransposeTrivial $input, $permutation), (AnyStaticShapeTensor $input), (AnyStaticShapeTensor $transpose_op)]>; @@ -810,7 +810,7 @@ def FoldDoubleTranspose : Pat< // Convert expand_dims to reshape if possible. def ConvertExpandDimsToReshape : Pat< (TFL_ExpandDimsOp:$expand_dims_op $input, $dim), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $expand_dims_op))), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShapeAttr $expand_dims_op))), [(AnyStaticShapeTensor $expand_dims_op)]>; // Here, the element type can be any integer or float type. @@ -1324,7 +1324,7 @@ def ReplaceOneHotFullyConnectedWithLookup : Pat< (Arith_ConstantOp ConstantAttr, "{1,0}">)), (returnType (GetEmbeddingLookupShape $indices, $filter)) ), - (Arith_ConstantOp (GetShape (GetIthValue<0> $outputs)))), + (Arith_ConstantOp (GetShapeAttr (GetIthValue<0> $outputs)))), [(I32ElementsVal $indices), // lookup is not implemented for i64 (IsNoneType $bias)]>; // Maybe folded into the lookup matrix later diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc new file mode 100644 index 00000000000000..e40fb1a85d4e88 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc @@ -0,0 +1,303 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +BatchMatMulDimensionsInfo::BatchMatMulDimensionsInfo(mlir::ShapedType type, + bool is_lhs) + : is_lhs_(is_lhs) { + // BatchMatMulOp has the following shape pattern: B0,...,Bn,L,C and + // B0,...,Bn,C,R. So, there is only one Contracting dimension and one + // output dimension. 
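+  // e.g. for an RHS of type tensor<2048x1x32x256xf32>, the batch dims are
+  // axes {0, 1} with sizes {2048, 1}, the contracting dim is axis 2 with
+  // size 32, and the out dim is axis 3 with size 256 (illustrative example,
+  // derived from the rules above).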
+ const int64_t rank = type.getRank(); + + if (is_lhs) { + contracting_dimensions_.axes.push_back(rank - 1); + contracting_dimensions_.sizes.push_back(type.getDimSize(rank - 1)); + out_dimensions_.axes.push_back(rank - 2); + out_dimensions_.sizes.push_back(type.getDimSize(rank - 2)); + } else { + contracting_dimensions_.axes.push_back(rank - 2); + contracting_dimensions_.sizes.push_back(type.getDimSize(rank - 2)); + out_dimensions_.axes.push_back(rank - 1); + out_dimensions_.sizes.push_back(type.getDimSize(rank - 1)); + } + // Dims 0 and 1 are contracting and output dimensions, hence skipped. + for (int64_t dim = 0; dim < rank - 2; ++dim) { + batch_dimensions_.axes.push_back(dim); + batch_dimensions_.sizes.push_back(type.getDimSize(dim)); + } +} + +const DimensionVector& BatchMatMulDimensionsInfo::batch_dimensions() const { + return batch_dimensions_; +} +const DimensionVector& BatchMatMulDimensionsInfo::contracting_dimensions() + const { + return contracting_dimensions_; +} + +const DimensionVector& BatchMatMulDimensionsInfo::out_dimensions() const { + return out_dimensions_; +} + +bool BatchMatMulDimensionsInfo::is_lhs() const { return is_lhs_; } + +BatchMatMulDimensionsInfo GetBatchMatMulLhsDimensionsInfo( + mlir::ShapedType type) { + return BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); +} + +BatchMatMulDimensionsInfo GetBatchMatMulRhsDimensionsInfo( + mlir::ShapedType type) { + return BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); +} + +bool HasFlattenedContractingDims( + llvm::ArrayRef reshape_input_shape, + const BatchMatMulDimensionsInfo& bmm_dimensions_info) { + // Batch dimensions are not flattened and need to match the LHS/RHS of + // BatchMatMulOp. + auto batch_dimensions = bmm_dimensions_info.batch_dimensions().SizesArray(); + // The batch dimensions are at the front of the input shape. + auto reshape_input_shape_batch_dims = + reshape_input_shape.take_front(batch_dimensions.size()); + + if (!llvm::all_of( + llvm::zip(batch_dimensions, reshape_input_shape_batch_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + // Out dimensions are assumed to be unflattened and need to match the LHS/RHS + // of BatchMatMulOp. + auto out_dimensions = bmm_dimensions_info.out_dimensions().SizesArray(); + llvm::ArrayRef reshape_input_shape_out_dims; + // The out dimensions are at the end of the input shape for LHS and + // at the front for RHS. + if (bmm_dimensions_info.is_lhs()) { + reshape_input_shape_out_dims = + reshape_input_shape.slice(batch_dimensions.size(), 1); + } else { + reshape_input_shape_out_dims = + reshape_input_shape.take_back(out_dimensions.size()); + } + if (!llvm::all_of( + llvm::zip(out_dimensions, reshape_input_shape_out_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + auto contracting_dimensions = + bmm_dimensions_info.contracting_dimensions().SizesArray(); + // The contracting dimensions are at the end of the input shape for + // LHS and at the front for RHS. 
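+  // e.g. for an RHS of type tensor<6x10xf32> (contracting = 6, out = 10),
+  // a reshape input shape of [2, 3, 10] counts as flattened contracting
+  // dims, since 2 * 3 == 6 and the out dim 10 is untouched (illustrative).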
+  llvm::ArrayRef<int64_t> reshape_input_shape_contracting_dims;
+  size_t num_contracting_dims = reshape_input_shape.size() -
+                                batch_dimensions.size() - out_dimensions.size();
+  if (bmm_dimensions_info.is_lhs()) {
+    reshape_input_shape_contracting_dims =
+        reshape_input_shape.take_back(num_contracting_dims);
+  } else {
+    reshape_input_shape_contracting_dims = reshape_input_shape.slice(
+        batch_dimensions.size(), num_contracting_dims);
+  }
+
+  return (std::accumulate(reshape_input_shape_contracting_dims.begin(),
+                          reshape_input_shape_contracting_dims.end(), 1,
+                          std::multiplies<int64_t>()) ==
+          contracting_dimensions[0]);
+}
+
+bool HasFlattenedOutDims(llvm::ArrayRef<int64_t> reshape_input_shape,
+                         const BatchMatMulDimensionsInfo& bmm_dimensions_info) {
+  // Batch dimensions are not flattened and need to match the LHS/RHS of
+  // BatchMatMulOp.
+  auto batch_dimensions = bmm_dimensions_info.batch_dimensions().SizesArray();
+  // The batch dimensions are at the front of the input shape.
+  auto reshape_input_shape_batch_dims =
+      reshape_input_shape.take_front(batch_dimensions.size());
+  if (!llvm::all_of(
+          llvm::zip(batch_dimensions, reshape_input_shape_batch_dims),
+          [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) {
+    return false;
+  }
+
+  auto contracting_dimensions =
+      bmm_dimensions_info.contracting_dimensions().SizesArray();
+  // The contracting dimensions are at the end of the input shape for LHS and
+  // come right after the batch dimensions for RHS.
+  llvm::ArrayRef<int64_t> reshape_input_shape_contracting_dims;
+  if (bmm_dimensions_info.is_lhs()) {
+    reshape_input_shape_contracting_dims =
+        reshape_input_shape.take_back(contracting_dimensions.size());
+  } else {
+    reshape_input_shape_contracting_dims =
+        reshape_input_shape.slice(batch_dimensions.size(), 1);
+  }
+  if (!llvm::all_of(
+          llvm::zip(contracting_dimensions,
+                    reshape_input_shape_contracting_dims),
+          [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) {
+    return false;
+  }
+
+  auto out_dimensions = bmm_dimensions_info.out_dimensions().SizesArray();
+  // The out dimensions come right after the batch dimensions for LHS and are
+  // at the end of the input shape for RHS.
+  llvm::ArrayRef<int64_t> reshape_input_shape_out_dims;
+  size_t num_out_dims = reshape_input_shape.size() - batch_dimensions.size() -
+                        contracting_dimensions.size();
+  if (bmm_dimensions_info.is_lhs()) {
+    reshape_input_shape_out_dims =
+        reshape_input_shape.slice(batch_dimensions.size(), num_out_dims);
+  } else {
+    reshape_input_shape_out_dims = reshape_input_shape.take_back(num_out_dims);
+  }
+
+  return (std::accumulate(reshape_input_shape_out_dims.begin(),
+                          reshape_input_shape_out_dims.end(), 1,
+                          std::multiplies<int64_t>()) == out_dimensions[0]);
+}
+
+std::tuple<std::pair<int, int>, std::pair<int, int>>
+GetTransposedGroupsIndexRange(llvm::ArrayRef<int32_t> transpose_permutation) {
+  // If the input vector is empty, return None for both pairs.
+  if (transpose_permutation.empty()) {
+    return {{-1, -1}, {-1, -1}};  // Use -1 to indicate None
+  }
+
+  int group_one_end_idx = -1;
+  for (int i = 0; i < transpose_permutation.size(); ++i) {
+    if (transpose_permutation[i] == i) {
+      group_one_end_idx = i;
+    } else {
+      break;
+    }
+  }
+
+  // If all dimensions are batch dimensions, i.e. the first group is a
+  // monotonically increasing sequence, return None for both remaining groups.
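+  //
+  // For example (hypothetical permutation): for {0, 1, 2, 6, 7, 8, 3, 4, 5},
+  // group one covers indices [0, 2], group two [3, 5], and group three
+  // [6, 8], so the result is {{3, 5}, {6, 8}}.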
+  if (group_one_end_idx == transpose_permutation.size() - 1) {
+    return {{-1, -1}, {-1, -1}};
+  }
+
+  int group_two_start_idx = group_one_end_idx + 1;
+  int group_two_end_idx = group_two_start_idx;
+  int group_three_start_idx = -1;
+  int group_three_end_idx = -1;
+
+  int group_two_end_idx_value = transpose_permutation.size() - 1;
+  int group_three_start_idx_value = group_one_end_idx + 1;
+
+  for (int i = group_two_start_idx + 1; i < transpose_permutation.size(); ++i) {
+    if (transpose_permutation[i] > group_two_end_idx_value ||
+        transpose_permutation[i] <= group_three_start_idx_value ||
+        (transpose_permutation[i] != transpose_permutation[i - 1] + 1)) {
+      break;
+    }
+    group_two_end_idx = i;
+  }
+
+  group_three_start_idx = group_two_end_idx + 1;
+  group_three_end_idx = transpose_permutation.size() - 1;
+  // Fail if the last group is not a monotonically increasing sequence.
+  for (int i = group_three_start_idx + 1; i < transpose_permutation.size();
+       ++i) {
+    if (transpose_permutation[i] != transpose_permutation[i - 1] + 1) {
+      return {{-1, -1}, {-1, -1}};
+    }
+  }
+
+  // Handle edge cases where start index might be greater than end index.
+  if (group_two_start_idx > group_two_end_idx) {
+    group_two_start_idx = group_two_end_idx;
+  }
+
+  if (group_three_start_idx > group_three_end_idx) {
+    group_three_start_idx = group_three_end_idx;
+  }
+  if (group_three_start_idx >= transpose_permutation.size()) {
+    group_three_start_idx = -1;
+    group_three_end_idx = -1;
+  }
+
+  return {{group_two_start_idx, group_two_end_idx},
+          {group_three_start_idx, group_three_end_idx}};
+}
+
+bool HasTransposedContractingAndOutDims(
+    llvm::ArrayRef<int64_t> transpose_input_shape,
+    llvm::ArrayRef<int32_t> transpose_permutation,
+    const BatchMatMulDimensionsInfo& bmm_dimensions_info) {
+  std::tuple<std::pair<int, int>, std::pair<int, int>>
+      transposed_groups_index_range =
+          GetTransposedGroupsIndexRange(transpose_permutation);
+  // Return false if the transpose_permutation is not valid.
+  if (std::get<0>(transposed_groups_index_range).first == -1 ||
+      std::get<0>(transposed_groups_index_range).second == -1 ||
+      std::get<1>(transposed_groups_index_range).first == -1 ||
+      std::get<1>(transposed_groups_index_range).second == -1) {
+    return false;
+  }
+
+  // Check that the leading (batch) group lines up with the batch dimensions
+  // of BatchMatMulOp.
+  if (!bmm_dimensions_info.batch_dimensions().AxesArray().empty() &&
+      bmm_dimensions_info.batch_dimensions().AxesArray().back() !=
+          std::get<0>(transposed_groups_index_range).first - 1) {
+    return false;
+  }
+
+  // The accumulated sizes of the transposed groups should match the sizes of
+  // the contracting and out dimensions of BatchMatMulOp.
+  int64_t group_two_dims_size = 1;
+  int64_t group_three_dims_size = 1;
+  for (int i = std::get<0>(transposed_groups_index_range).first;
+       i <= std::get<0>(transposed_groups_index_range).second; ++i) {
+    group_two_dims_size *= transpose_input_shape[transpose_permutation[i]];
+  }
+  for (int i = std::get<1>(transposed_groups_index_range).first;
+       i <= std::get<1>(transposed_groups_index_range).second; ++i) {
+    group_three_dims_size *= transpose_input_shape[transpose_permutation[i]];
+  }
+
+  const auto& out_dims = bmm_dimensions_info.out_dimensions().SizesArray()[0];
+  const auto& contracting_dims =
+      bmm_dimensions_info.contracting_dimensions().SizesArray()[0];
+
+  return bmm_dimensions_info.is_lhs()
+             ?
(group_two_dims_size == out_dims &&
+                group_three_dims_size == contracting_dims)
+             : (group_two_dims_size == contracting_dims &&
+                group_three_dims_size == out_dims);
+}
+}  // namespace TFL
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h
new file mode 100644
index 00000000000000..3eb3de702e1f4a
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h
@@ -0,0 +1,141 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_OPTIMIZE_BATCH_MATMUL_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_OPTIMIZE_BATCH_MATMUL_UTILS_H_
+
+#include <cstdint>
+#include <tuple>
+#include <utility>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
+namespace mlir {
+namespace TFL {
+
+// LHS and RHS of BatchMatMulOp have shapes following the pattern:
+// B0,...,Bn,L,C and B0,...,Bn,C,R. The output shape of BatchMatMulOp is:
+// B0,...,Bn,L,R.
+//
+// LHS and RHS of FullyConnectedOp have shapes following the pattern:
+// B0,...,Bn,L,C and R,C. The output shape of FullyConnectedOp is:
+// B0,...,Bn,L,R.
+//
+// The fundamental idea behind the transposes and reshapes seen around
+// BatchMatMulOp is that:
+// -- BatchMatMulOp is often created as a result of lowering einsum or
+//    dot_general ops.
+// -- einsum and dot_general ops have multiple contracting and output
+//    dimensions that need to be reshaped and transposed to match the
+//    BatchMatMulOp's LHS and RHS restrictions.
+//
+// This file contains utility functions to identify the reshapes and transposes
+// around BatchMatMulOp and see if they can be fused.
+
+// A struct to hold axes and sizes for a set of dimensions.
+struct DimensionVector {
+  llvm::ArrayRef<int64_t> AxesArray() const { return axes; }
+  llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }
+
+  llvm::SmallVector<int64_t> axes;
+  llvm::SmallVector<int64_t> sizes;
+};
+
+// A class to hold information about the dimensions of BatchMatMulOp operands.
+class BatchMatMulDimensionsInfo {
+ public:
+  BatchMatMulDimensionsInfo(mlir::ShapedType type, bool is_lhs);
+  const DimensionVector& batch_dimensions() const;
+  const DimensionVector& contracting_dimensions() const;
+  // Out dimensions are any dimensions that are neither batch nor contracting
+  // dimensions, hence will be propagated to output shape.
+  const DimensionVector& out_dimensions() const;
+  bool is_lhs() const;
+
+ private:
+  DimensionVector batch_dimensions_;
+  DimensionVector contracting_dimensions_;
+  // Out dimensions are any dimensions that are neither batch nor contracting
+  // dimensions, hence will be propagated to output shape.
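+  // e.g. for an LHS of shape B0,...,Bn,L,C, this holds just the L axis.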
+  DimensionVector out_dimensions_;
+  bool is_lhs_;
+};
+
+// Returns the dimensions info of the LHS of BatchMatMulOp.
+BatchMatMulDimensionsInfo GetBatchMatMulLhsDimensionsInfo(
+    mlir::ShapedType type);
+
+// Returns the dimensions info of the RHS of BatchMatMulOp.
+BatchMatMulDimensionsInfo GetBatchMatMulRhsDimensionsInfo(
+    mlir::ShapedType type);
+
+// Returns true if the product of the last few dimensions in the
+// `reshape_input_shape` is equal to the contracting dimension of the
+// `bmm_dimensions_info`.
+bool HasFlattenedContractingDims(
+    llvm::ArrayRef<int64_t> reshape_input_shape,
+    const BatchMatMulDimensionsInfo& bmm_dimensions_info);
+
+// Returns true if the product of the first few dimensions in the
+// `reshape_input_shape` is equal to the output dimension of the
+// `bmm_dimensions_info`.
+bool HasFlattenedOutDims(llvm::ArrayRef<int64_t> reshape_input_shape,
+                         const BatchMatMulDimensionsInfo& bmm_dimensions_info);
+
+// Returns true if the contracting and output dimensions are transposed in the
+// `transpose_permutation`.
+bool HasTransposedContractingAndOutDims(
+    llvm::ArrayRef<int64_t> transpose_input_shape,
+    llvm::ArrayRef<int32_t> transpose_permutation,
+    const BatchMatMulDimensionsInfo& bmm_dimensions_info);
+
+// `transpose_permutation` is the permutation of the input shape of the
+// transpose op. `transpose_input_shape` is the shape of the input of the
+// transpose op. `bmm_dimensions_info` is the dimensions info of the
+// BatchMatMulOp.
+//
+// The dimensions in the transpose_permutation can be split into three groups:
+// 1. Batch dimensions
+// 2. Contracting dimensions
+// 3. Output dimensions
+//
+// - The number of dimensions and the order of the dimensions in the
+//   batch-dimensions group are expected to match the batch dimensions of the
+//   BatchMatMulOp.
+// - The number of dimensions in the contracting-dimensions and
+//   output-dimensions groups can be more than 1.
+// - The dimensions in group 1 are expected to be a monotonically increasing
+//   sequence.
+// - The dimensions in groups 2 and 3 need not be monotonically increasing
+//   sequences.
+// - In this function, we only care whether groups 2 and 3 are transposed.
+//
+// For example, consider the following transpose_permutation-
+// [0, 1, 2, 6, 7, 8, 3, 4, 5]. Here all three groups are monotonically
+// increasing. But other permutations like [0, 1, 2, 8, 7, 6, 4, 5, 3] and [0,
+// 1, 2, 6, 7, 8, 3, 5, 4] are also valid.
+//
+// NOTE: The first version of this function only supports the case where all
+// three groups are monotonically increasing.
+std::tuple<std::pair<int, int>, std::pair<int, int>>
+GetTransposedGroupsIndexRange(llvm::ArrayRef<int32_t> transpose_permutation);
+
+}  // namespace TFL
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_OPTIMIZE_BATCH_MATMUL_UTILS_H_
diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc
new file mode 100644
index 00000000000000..cf026d8c8169e2
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc
@@ -0,0 +1,168 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h"
+
+#include <cstdint>
+#include <tuple>
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+
+namespace mlir {
+namespace TFL {
+namespace {
+
+TEST(OptimizeBatchMatmulUtilsTest, BatchMatMulDimensionsInfo) {
+  mlir::MLIRContext context;
+  mlir::ShapedType type = mlir::RankedTensorType::get(
+      {1, 2, 3, 4, 5}, mlir::Float32Type::get(&context));
+  BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true);
+  EXPECT_EQ(lhs_info.batch_dimensions().AxesArray(),
+            llvm::ArrayRef<int64_t>({0, 1, 2}));
+  EXPECT_EQ(lhs_info.batch_dimensions().SizesArray(),
+            llvm::ArrayRef<int64_t>({1, 2, 3}));
+  EXPECT_EQ(lhs_info.contracting_dimensions().AxesArray(),
+            llvm::ArrayRef<int64_t>({4}));
+  EXPECT_EQ(lhs_info.contracting_dimensions().SizesArray(),
+            llvm::ArrayRef<int64_t>({5}));
+  EXPECT_EQ(lhs_info.out_dimensions().AxesArray(),
+            llvm::ArrayRef<int64_t>({3}));
+  EXPECT_EQ(lhs_info.out_dimensions().SizesArray(),
+            llvm::ArrayRef<int64_t>({4}));
+  EXPECT_TRUE(lhs_info.is_lhs());
+
+  BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false);
+  EXPECT_EQ(rhs_info.batch_dimensions().AxesArray(),
+            llvm::ArrayRef<int64_t>({0, 1, 2}));
+  EXPECT_EQ(rhs_info.batch_dimensions().SizesArray(),
+            llvm::ArrayRef<int64_t>({1, 2, 3}));
+  EXPECT_EQ(rhs_info.contracting_dimensions().AxesArray(),
+            llvm::ArrayRef<int64_t>({3}));
+  EXPECT_EQ(rhs_info.contracting_dimensions().SizesArray(),
+            llvm::ArrayRef<int64_t>({4}));
+  EXPECT_EQ(rhs_info.out_dimensions().AxesArray(),
+            llvm::ArrayRef<int64_t>({4}));
+  EXPECT_EQ(rhs_info.out_dimensions().SizesArray(),
+            llvm::ArrayRef<int64_t>({5}));
+  EXPECT_FALSE(rhs_info.is_lhs());
+}
+
+TEST(OptimizeBatchMatmulUtilsTest, HasFlattenedContractingDims) {
+  mlir::MLIRContext context;
+  mlir::ShapedType type = mlir::RankedTensorType::get(
+      {1, 2, 3, 4, 50}, mlir::Float32Type::get(&context));
+  BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true);
+  EXPECT_TRUE(HasFlattenedContractingDims({1, 2, 3, 4, 5, 10}, lhs_info));
+  EXPECT_FALSE(HasFlattenedContractingDims({1, 2, 3, 4, 10}, lhs_info));
+
+  type = mlir::RankedTensorType::get({1, 2, 12, 5},
+                                     mlir::Float32Type::get(&context));
+  BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false);
+  EXPECT_TRUE(HasFlattenedContractingDims({1, 2, 3, 4, 5}, rhs_info));
+  EXPECT_FALSE(HasFlattenedContractingDims({1, 2, 3, 4, 10}, rhs_info));
+
+  type = mlir::RankedTensorType::get({4, 50}, mlir::Float32Type::get(&context));
+  lhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/true);
+  EXPECT_TRUE(HasFlattenedContractingDims({4, 5, 10}, lhs_info));
+  EXPECT_FALSE(HasFlattenedContractingDims({4, 10}, lhs_info));
+
+  type = mlir::RankedTensorType::get({12, 5}, mlir::Float32Type::get(&context));
+  rhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/false);
+  EXPECT_TRUE(HasFlattenedContractingDims({3, 4, 5}, rhs_info));
+  EXPECT_FALSE(HasFlattenedContractingDims({3, 4, 10}, rhs_info));
+}
+
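+// A note on the values above (hypothetical shapes): a contracting dimension
+// of 50 is matched by trailing reshape dims {5, 10} since 5 * 10 == 50,
+// while a trailing {10} alone fails because 10 != 50.
+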
+TEST(OptimizeBatchMatmulUtilsTest, HasFlattenedOutDims) { + mlir::MLIRContext context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 12, 5}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + EXPECT_TRUE(HasFlattenedOutDims({1, 2, 3, 4, 5}, lhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({1, 2, 3, 4, 10}, lhs_info)); + + type = mlir::RankedTensorType::get({1, 2, 12, 10}, + mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedOutDims({1, 2, 12, 5, 2}, rhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({1, 2, 3, 4, 10}, rhs_info)); + + type = mlir::RankedTensorType::get({12, 5}, mlir::Float32Type::get(&context)); + lhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); + EXPECT_TRUE(HasFlattenedOutDims({3, 4, 5}, lhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({3, 4, 10}, lhs_info)); + + type = + mlir::RankedTensorType::get({12, 10}, mlir::Float32Type::get(&context)); + rhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedOutDims({12, 5, 2}, rhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({3, 4, 10}, rhs_info)); +} + +TEST(OptimizeBatchMatmulUtilsTest, GetTransposedGroupsIndexRange) { + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2, 6, 7, 8, 3, 4, 5}), + std::make_tuple(std::make_pair(3, 5), std::make_pair(6, 8))); + EXPECT_EQ(GetTransposedGroupsIndexRange({2, 0, 1}), + std::make_tuple(std::make_pair(0, 0), std::make_pair(1, 2))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2, 3, 7, 8, 4, 5, 6}), + std::make_tuple(std::make_pair(4, 5), std::make_pair(6, 8))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2, 3, 8, 7, 4, 5, 6}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); + EXPECT_EQ(GetTransposedGroupsIndexRange({}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); +} + +TEST(OptimizeBatchMatmulUtilsTest, HasTransposedContractingAndOutDims) { + mlir::MLIRContext context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 3, 504, 120}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + EXPECT_TRUE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 6, 7, 8, 3, 4, 5}, lhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 8, 7, 6, 4, 5, 3}, lhs_info)); + + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_TRUE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 6, 7, 8, 3, 4, 5}, rhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 8, 7, 6, 4, 5, 3}, rhs_info)); + + type = + mlir::RankedTensorType::get({504, 120}, mlir::Float32Type::get(&context)); + lhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); + EXPECT_TRUE(HasTransposedContractingAndOutDims({4, 5, 6, 7, 8, 9}, + {3, 4, 5, 0, 1, 2}, lhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {4, 5, 6, 7, 8, 9}, {5, 4, 3, 1, 2, 0}, lhs_info)); + + rhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); + EXPECT_TRUE(HasTransposedContractingAndOutDims({4, 5, 6, 7, 8, 9}, + {3, 4, 5, 0, 1, 2}, rhs_info)); + 
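+  // The permutation {5, 4, 3, 1, 2, 0} below is rejected: within each group
+  // the entries do not form a monotonically increasing run, so no valid
+  // contracting/out group split exists.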
+  EXPECT_FALSE(HasTransposedContractingAndOutDims(
+      {4, 5, 6, 7, 8, 9}, {5, 4, 3, 1, 2, 0}, rhs_info));
+}
+
+}  // namespace
+}  // namespace TFL
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h
index 6779dac5ad5e8c..3cf71da1d31ac7 100644
--- a/tensorflow/compiler/mlir/lite/utils/utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/utils.h
@@ -23,12 +23,15 @@ limitations under the License.
 
 #include <cstdint>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
@@ -59,6 +62,21 @@ inline bool IsPosInfiniteValue(APFloat value) { return value.isInfinity(); }
 
+// Returns 1D 32-bit dense elements attribute with the given values.
+inline DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
+                                               Builder* builder) {
+  RankedTensorType ty = mlir::RankedTensorType::get(
+      {static_cast<int64_t>(values.size())}, builder->getIntegerType(32));
+  return DenseIntElementsAttr::get(ty, values);
+}
+
+inline DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
+                                               Builder* builder) {
+  RankedTensorType ty = RankedTensorType::get(
+      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
+  return DenseIntElementsAttr::get(ty, values);
+}
+
 // Returns true if all tensor value in `values` has static shape and same shape.
 inline bool OpHasSameStaticShapes(Operation* op) {
   auto values = op->getOperands();
@@ -165,7 +183,7 @@ inline bool IsTransposeTrivial(llvm::ArrayRef<int64_t> input_shape,
 // Returns the permutation that maps the input shape to the output shape.
 // This is only valid for trivial reshape ops.
 inline DenseElementsAttr GetPermutationFromTrivialReshape(
-    ShapedType input_type, ShapedType output_type) {
+    mlir::ShapedType input_type, mlir::ShapedType output_type) {
   ArrayRef<int64_t> in_shape = input_type.getShape();
   ArrayRef<int64_t> out_shape = output_type.getShape();
 
@@ -209,8 +227,8 @@ inline DenseElementsAttr GetPermutationFromTrivialReshape(
 // Returns true if the reshape op is equivalent to a transpose op.
 // This is true if the reshape op is a trivial reshape op, meaning no change in
 // the order of non-identity dimensions.
-inline bool IsReshapeEquivalentToTranspose(ShapedType input_type,
-                                           ShapedType output_type) {
+inline bool IsReshapeEquivalentToTranspose(mlir::ShapedType input_type,
+                                           mlir::ShapedType output_type) {
   std::vector<int64_t> in_shape{input_type.getShape().vec()};
   std::vector<int64_t> out_shape{output_type.getShape().vec()};
 
@@ -299,8 +317,8 @@ inline Type TransposeLastTwoDims(Type type) {
 
 // Returns a ShapedType for a permutation and the shape of input after
 // applying the permutation to the given shape through a transpose.
-inline ShapedType GetTransposedType(Value input,
-                                    llvm::ArrayRef<int64_t> permutation_array) {
+inline mlir::ShapedType GetTransposedType(
+    Value input, llvm::ArrayRef<int64_t> permutation_array) {
   auto input_type = input.getType().cast<ShapedType>();
   if (permutation_array.size() != input_type.getRank()) {
     return nullptr;
   }
@@ -341,41 +359,66 @@ inline DenseElementsAttr GetExpandedShapeAttr(Value input_val, int n) {
 
 // Return the resultant shape type if the shape of the supplied attribute/value
 // is expanded by n leading 1s'.
-inline ShapedType GetExpandedShapeType(Value input_val, int n) {
+inline mlir::ShapedType GetExpandedShapeType(Value input_val, int n) {
   auto expanded_shape = GetExpandedShape(input_val, n);
   return RankedTensorType::get(
       SmallVector<int64_t>{expanded_shape.begin(), expanded_shape.end()},
      mlir::cast<ShapedType>(input_val.getType()).getElementType());
 }
 
-// Returns shape of a ranked tensor.
-// Precondition: output_val's is ranked tensor.
-// Returns a truncated shape when `truncate` is set to true.
-inline DenseElementsAttr GetShape(Value output_val, bool truncate = false) {
-  auto output_shape = output_val.getType().dyn_cast<RankedTensorType>().getShape();
+// Returns shape of a ranked tensor as a SmallVector.
+// Precondition: input_value is a ranked tensor.
+// Returns a squeezed shape when `squeeze_leading_ones` is set to true.
+inline SmallVector<int32_t> GetShape(Value input_value,
+                                     bool squeeze_leading_ones = false) {
+  auto output_shape = input_value.getType().dyn_cast<RankedTensorType>().getShape();
   SmallVector<int32_t> shape;
   shape.reserve(output_shape.size());
-  bool needs_truncation = true;
+  bool can_squeeze = true;
   for (size_t dim_idx = 0; dim_idx < output_shape.size(); ++dim_idx) {
     int64_t dim = output_shape[dim_idx];
-    if (truncate && needs_truncation && dim == 1) {
+    if (squeeze_leading_ones && can_squeeze && dim == 1) {
       continue;
-    } else if (needs_truncation && dim != 1) {
-      needs_truncation = false;
+    } else if (can_squeeze && dim != 1) {
+      can_squeeze = false;
     }
     shape.push_back(ShapedType::isDynamic(dim) ? -1 : static_cast<int32_t>(dim));
   }
+  return shape;
+}
+
+// Returns shape of a ranked tensor as a DenseElementsAttr.
+// Precondition: input_value is a ranked tensor.
+// Returns a squeezed shape when `squeeze_leading_ones` is set to true.
+inline DenseElementsAttr GetShapeAttr(Value input_value,
+                                      bool squeeze_leading_ones = false) {
+  SmallVector<int32_t> shape = GetShape(input_value, squeeze_leading_ones);
   return mlir::DenseElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int64_t>(shape.size())},
-          mlir::IntegerType::get(output_val.getContext(), 32)),
+          mlir::IntegerType::get(input_value.getContext(), 32)),
       llvm::ArrayRef(shape));
 }
 
+// Returns the value of a constant attribute as an int array. If the value is
+// not a constant, returns an error status.
+inline absl::StatusOr<SmallVector<int32_t>> GetValueAsIntArray(Value value) {
+  DenseElementsAttr values_const_attr;
+  if (!matchPattern(value, m_Constant(&values_const_attr))) {
+    return absl::InvalidArgumentError("Value is not a constant.");
+  }
+
+  SmallVector<int32_t> values;
+  for (const auto& value : values_const_attr.getValues<APInt>()) {
+    values.push_back(value.getSExtValue());
+  }
+  return values;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 ///////////////// OP BROADCASTING UTILITIES ////////////////////////////////////
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td
index d7029fe5ca7939..b1dce77f71b001 100644
--- a/tensorflow/compiler/mlir/lite/utils/utils.td
+++ b/tensorflow/compiler/mlir/lite/utils/utils.td
@@ -140,7 +140,7 @@ def CreateNoneValue : NativeCodeCall<
 
 // Returns shape of a ranked tensor.
 // if called without a ranked tensor it will fail.
-def GetShape: NativeCodeCall<"GetShape($0)">;
+def GetShapeAttr: NativeCodeCall<"GetShapeAttr($0)">;
 
 // Return the resultant shape if the shape of the supplied attribute/value is
 // expanded by n leading 1s'.

From 51bd0497d02b2b916ccf6294f0a1488a00075db2 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 31 Mar 2025 21:29:22 -0700
Subject: [PATCH 0065/1324] Automated Code Change

PiperOrigin-RevId: 742523499
---
 .../compiler/mlir/tensorflow/transforms/host_runtime/BUILD | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD
index abeec429e2fe1d..fa91275c392432 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD
@@ -6,7 +6,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
     default_visibility = [
-        "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__",
        "//tensorflow/compiler/mlir:__pkg__",
         "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__",
         "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__",

From 0cd83a45e20fcd9c1a90a2e7bff3835f37d31af7 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Mon, 31 Mar 2025 22:11:26 -0700
Subject: [PATCH 0066/1324] Automated Code Change

PiperOrigin-RevId: 742534687
---
 tensorflow/python/framework/experimental/unified_api.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorflow/python/framework/experimental/unified_api.cc b/tensorflow/python/framework/experimental/unified_api.cc
index ea1047ff8d9032..49170e48d0f18d 100644
--- a/tensorflow/python/framework/experimental/unified_api.cc
+++ b/tensorflow/python/framework/experimental/unified_api.cc
@@ -78,7 +78,6 @@ using tensorflow::Safe_TF_StatusPtr;
 using tensorflow::Status;
 using tensorflow::string;
 using tensorflow::TFE_TensorHandleToNumpy;
-using tensorflow::core::RefCountPtr;
 using tensorflow::errors::Internal;
 using tensorflow::errors::InvalidArgument;

From 86091aca45ea41bdeda17a95efd6433bbef67e0b Mon Sep 17 00:00:00 2001
From: Weiyi Wang 
Date: Mon, 31 Mar 2025 22:19:28 -0700
Subject: [PATCH 0067/1324] Lower jnp.unstack to tfl.unpack. Without the
 rewrite, the op is lowered to a number of slice ops.
PiperOrigin-RevId: 742536706
---
 .../lite/stablehlo/tests/composite-lowering.mlir   | 12 +++++++++++-
 .../transforms/composite_lowering_patterns.td      | 10 +++++++++-
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir
index a7095618ab0901..f9c8c4953fb931 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir
@@ -466,4 +466,14 @@ func.func private @XlaCallModule_odml.random_standard_normal.impl_0(%arg0: tensor<3xi32>)
 }
 // CHECK-LABEL func.func @random_standard_normal
 // CHECK: %0 = "tfl.random_standard_normal"(%arg0) <{seed = 0 : i64, seed2 = 1 : i64}> : (tensor<3xi32>) -> tensor<1x2x3xf32>
-// CHECK: return %0 : tensor<1x2x3xf32>
\ No newline at end of file
+// CHECK: return %0 : tensor<1x2x3xf32>
+
+
+func.func private @XlaCallModule_tfl.unpack.impl_0(%arg0: tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>)
+func.func @jax_unstack(%arg0: tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>) {
+  %0:3 = mhlo.composite "tfl.unpack" %arg0 {composite_attributes = {num = 3 : i32, axis = 1 : i32}, decomposition = @XlaCallModule_tfl.unpack.impl_0} : (tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>)
+  return %0#0, %0#1, %0#2 : tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>
+}
+
+// CHECK-LABEL: jax_unstack
+// CHECK: %0:3 = "tfl.unpack"(%arg0) <{axis = 1 : i32, num = 3 : i32}> : (tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>)
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td
index 47a3693ab3961b..1beff56b89b40c 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td
@@ -183,4 +183,12 @@ def LegalizeCompositeOdmlRandomStandardNormal : Pat<
   (MHLO_CompositeOp:$composite
     (variadic $inputs),
    ConstantStrAttr<StrAttr, "odml.random_standard_normal">, $attrs, $_, $_),
   (TFL_RandomStandardNormalOp $shape,
     (GetCompositeAttributeAs<"seed", "IntegerAttr"> $attrs),
-    (GetCompositeAttributeAs<"seed2", "IntegerAttr"> $attrs))>;
\ No newline at end of file
+    (GetCompositeAttributeAs<"seed2", "IntegerAttr"> $attrs))>;
+
+def LegalizeCompositeUnpack : Pat<
+  (MHLO_CompositeOp:$composite
+    (variadic $inputs),
+    ConstantStrAttr<StrAttr, "tfl.unpack">, $attrs, $_, $_),
+  (TFL_UnpackOp $inputs,
+    (GetCompositeAttributeAs<"num", "IntegerAttr"> $attrs),
+    (GetCompositeAttributeAs<"axis", "IntegerAttr"> $attrs))>;
\ No newline at end of file

From 084b10464b4b401c7e14f96dc737b968bd1ae9fd Mon Sep 17 00:00:00 2001
From: "A.
Unique TensorFlower" Date: Mon, 31 Mar 2025 22:29:36 -0700 Subject: [PATCH 0068/1324] Automated Code Change PiperOrigin-RevId: 742538829 --- tensorflow/java/BUILD | 2 ++ tensorflow/java/src/gen/cc/op_gen_main.cc | 1 - tensorflow/java/src/gen/cc/op_generator.cc | 2 ++ tensorflow/java/src/gen/cc/op_generator.h | 1 + tensorflow/java/src/gen/cc/op_specs.cc | 3 +++ tensorflow/java/src/gen/cc/source_writer.cc | 1 + 6 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 741e089eddf082..2a6cb89cb485ad 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -137,7 +137,9 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", ], diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index 751289797875dc..814bf035d0e76b 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -13,7 +13,6 @@ limitations under the License. ==============================================================================*/ -#include #include #include "absl/log/check.h" diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 11798ad56641d8..61ccbd124eaaab 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -19,10 +19,12 @@ limitations under the License. #include #include #include +#include #include #include #include +#include "absl/status/status.h" #include "absl/strings/ascii.h" #include "xla/tsl/platform/status.h" #include "tensorflow/core/framework/api_def.pb.h" diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 048c05193d9b6a..e59c706e7355b6 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index d693a2b2a4ad08..bffb769004b56e 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -15,11 +15,14 @@ limitations under the License. #include "tensorflow/java/src/gen/cc/op_specs.h" +#include #include +#include #include #include #include +#include "absl/log/log.h" #include "absl/strings/match.h" #include "absl/strings/str_join.h" #include "absl/strings/strip.h" diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index a58746774996a5..b3878e85c6b132 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/java/src/gen/cc/source_writer.h" #include +#include #include #include From 2f446d41370bba5ef4d84caee49a3fbf649e6c0f Mon Sep 17 00:00:00 2001 From: pizzud Date: Mon, 31 Mar 2025 22:54:22 -0700 Subject: [PATCH 0069/1324] autotuning: Ensure entry versions only need to be updated in one place. 
These versions update frequently, so let's ensure that the process of doing so doesn't require touching several tests and textprotos. While here, add a missing license header. PiperOrigin-RevId: 742545052 --- .../service/gpu/autotuning/autotuner_util.cc | 11 ++++++++ .../service/gpu/autotuning/autotuner_util.h | 7 ++++++ .../gpu/autotuning/autotuner_util_test.cc | 2 +- .../autotuning/gemm_fusion_autotuner_test.cc | 5 ++-- .../gpu_compiler_test_autotune_db.textproto | 25 +++++++++++-------- ...aot_compile_test_autotune_results.prototxt | 2 -- 6 files changed, 35 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc index 81ebd3c743bd51..4ecb2c9a0e1998 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc @@ -495,6 +495,7 @@ bool IsTextProtoPath(absl::string_view file_path) { kVersion, results.version())); } + AddVersionToAutotuneResults(results); TF_RETURN_IF_ERROR(LoadAutotuneResults(results, allow_override)); return absl::OkStatus(); } @@ -569,5 +570,15 @@ bool IsTextProtoPath(absl::string_view file_path) { autotune_cache_stats = CacheStats(); } +void AddVersionToAutotuneResults(AutotuneResults& results) { + for (auto& result : *results.mutable_results()) { + if (result.version() == 0) { + // Create a dummy key and pull its version if we don't have one specified. + AutotuneCacheKey key("foo", "canonical_foo"); + result.set_version(key.GetVersion()); + } + } +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h index ad4d738ecfc6e1..1ac55aff05ed7b 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h @@ -362,6 +362,13 @@ absl::StatusOr GetBase64EncodedSha256Hash(absl::string_view s); std::string ToCanonicalString(const HloInstruction* instr); +// Adds version information to each entry in AutotuneResults. Useful for unit +// tests involving hard-coded AutotuneResults (including those read from files, +// which happens automatically), as the entry version changes much more often +// than the overall structure version of the AutotuneResults itself, so it's +// nice to only have to change one place to update it. 
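+//
+// A typical use (hypothetical sketch): parse a hand-written AutotuneResults
+// textproto whose entries omit `version`, call AddVersionToAutotuneResults()
+// on it, and then pass it to LoadAutotuneResults().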
+void AddVersionToAutotuneResults(AutotuneResults& results);
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc
index 9da36356e30830..97838f143b1ce0 100644
--- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc
@@ -101,7 +101,6 @@ results {
       num_ctas: 1
     }
   }
-  version: 1
 })";

   void SetUp() override {
@@ -185,6 +184,7 @@ TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_TextProto1) {
   EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString(
       std::string(kResultText), &results));
   ASSERT_GT(results.results().size(), 0);
+  AddVersionToAutotuneResults(results);
   AutotuneCacheKey key(results.results(0).device(), results.results(0).hlo(),
                        results.results(0).version());
   auto options = DebugOptions();
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
index e8b5519c286522..ccdccc0c5f3fa0 100644
--- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
@@ -762,7 +762,6 @@ ENTRY main {
   TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override,
                           ParseTextProto<AutotuneResults>(R"pb(
-                            version: 3
                             results {
                               device: "..."
                               hlo: "..."
                               result {
                                 gemm { algorithm: -1 }
                                 run_time { nanos: 14 }
                               }
-                              version: 1
                             })pb"));
+  AddVersionToAutotuneResults(autotune_results_override);
   autotune_results_override.mutable_results(0)->set_device(
       std::string(cache_key.GetModelStr()));
   autotune_results_override.mutable_results(0)->set_hlo(
@@ -1496,8 +1495,8 @@ absl::StatusOr<AutotuneResults> GetDummyAutotuneResultsForCacheKey(
         custom_kernel_fusion { kernel_index: 1 }
         run_time { nanos: 14 }
       }
-      version: 1
     })pb"));
+  AddVersionToAutotuneResults(autotune_results);
   autotune_results.mutable_results(0)->set_device(
       std::string(cache_key.GetModelStr()));
   autotune_results.mutable_results(0)->set_hlo(std::string(cache_key.GetHlo()));
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
index 8e9071858f0aeb..38b5e65a590572 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
+++ b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
@@ -1,3 +1,17 @@
+# Copyright 2023 The OpenXLA Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
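+#
+# Note: per-entry `version` fields are intentionally absent from the entries
+# below; they are stamped at load time by AddVersionToAutotuneResults().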
+ version: 3 results { device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB" @@ -10,7 +24,6 @@ results { nanos: 1 } } - version: 1 } results { device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB" @@ -23,7 +36,6 @@ results { algorithm: -1 } } - version: 1 } results { device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 40 MB" @@ -36,7 +48,6 @@ results { nanos: 1 } } - version: 1 } results { device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 40 MB" @@ -49,7 +60,6 @@ results { algorithm: -1 } } - version: 1 } results { device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" @@ -62,7 +72,6 @@ results { nanos: 1 } } - version: 1 } results { device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" @@ -75,7 +84,6 @@ results { nanos: 1 } } - version: 1 } results { device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" @@ -87,7 +95,6 @@ results { nanos: 1 } } - version: 1 } results { device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" @@ -100,7 +107,6 @@ results { nanos: 1 } } - version: 1 } results { device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" @@ -113,7 +119,6 @@ results { nanos: 1 } } - version: 1 } results { device: "CUDA: 10.0, Cores: 148, GPU clock: 1.65 GHz, Memory bandwidth: 8192 GB/s, L2 cache: 126.5 MB" @@ -126,7 +131,6 @@ results { nanos: 1 } } - version: 1 } results { device: "CUDA: 10.0, Cores: 148, GPU clock: 1.65 GHz, Memory bandwidth: 8192 GB/s, L2 cache: 126.5 MB", @@ -139,5 +143,4 @@ results { nanos: 1 } } - version: 1 } diff --git a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt index eea5608edad2b8..0cae8b48151a6a 100644 --- a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt +++ b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt @@ -21,7 +21,6 @@ results { algorithm: 13 } } - version: 1 } results { device: "CUDA: 6.0, Cores: 56, GPU clock: 1.4805 GHz, Memory bandwidth: 732 GB/s, L2 cache: 4 MB" @@ -45,5 +44,4 @@ results { } } } - version: 1 } From 13a22fd5a3e90d1a795f45b1bc393c003cce0206 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 31 Mar 2025 23:00:33 -0700 Subject: [PATCH 0070/1324] Introduce `Client::MakeErrorArrays` for creating poisoned IFRT arrays This is useful primarily for tests that exercise error behavior. The API assumes that the sharding at least supports `Sharding::GetShardShape()`, which seems to be a reasonable assumption for the target use case and aligns with the general direction of IFRT sharding. 
PiperOrigin-RevId: 742546412
---
 .../xla/backends/cpu/nanort/ifrt_client.cc    | 10 +++
 .../xla/xla/backends/cpu/nanort/ifrt_client.h |  8 +++
 .../backends/cpu/nanort/ifrt_client_test.cc   |  8 ++-
 .../xla/xla/python/compile_only_ifrt/BUILD    |  1 +
 .../xla/xla/python/compile_only_ifrt/client.h |  9 +++
 third_party/xla/xla/python/ifrt/BUILD         |  3 +-
 .../xla/python/ifrt/array_impl_test_lib.cc    | 31 +++++++++
 third_party/xla/xla/python/ifrt/client.h      |  7 ++
 third_party/xla/xla/python/ifrt/mock.cc       |  9 +++
 third_party/xla/xla/python/ifrt/mock.h        |  8 +++
 .../xla/xla/python/ifrt_proxy/client/BUILD    |  1 +
 .../xla/xla/python/ifrt_proxy/client/array.cc | 43 ++++++++++++
 .../xla/xla/python/ifrt_proxy/client/array.h  |  9 +++
 .../xla/python/ifrt_proxy/client/client.cc    |  9 +++
 .../xla/xla/python/ifrt_proxy/client/client.h |  5 ++
 .../python/ifrt_proxy/client/rpc_helper.cc    |  1 +
 .../xla/python/ifrt_proxy/client/rpc_helper.h |  2 +
 .../ifrt_proxy/common/ifrt_service.proto      | 13 ++++
 .../python/ifrt_proxy/server/ifrt_backend.cc  | 35 ++++++++++
 .../python/ifrt_proxy/server/ifrt_backend.h   |  2 +
 .../xla/xla/python/pjrt_ifrt/pjrt_client.cc   | 66 +++++++++++++++++++
 .../xla/xla/python/pjrt_ifrt/pjrt_client.h    |  5 ++
 22 files changed, 280 insertions(+), 5 deletions(-)

diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc
index c19c5a02bc0759..7de59f236a215e 100644
--- a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc
+++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc
@@ -57,6 +57,7 @@ limitations under the License.
 #include "xla/pjrt/pjrt_layout.h"
 #include "xla/pjrt/utils.h"
 #include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/array_spec.h"
 #include "xla/python/ifrt/attribute_map.h"
 #include "xla/python/ifrt/basic_device_list.h"
 #include "xla/python/ifrt/client.h"
@@ -1272,6 +1273,15 @@ NanoIfrtClient::MakeArraysFromHostBufferShards(
       std::move(user_context));
 }
 
+absl::StatusOr<std::vector<tsl::RCReference<ifrt::Array>>>
+NanoIfrtClient::MakeErrorArrays(
+    const absl::Status& error,
+    absl::Span<const ifrt::ArraySpec> array_specs,
+    tsl::RCReference<ifrt::UserContext> user_context) {
+  return absl::UnimplementedError(
+      "NanoIfrtClient does not support MakeErrorArrays.");
+}
+
 absl::StatusOr<tsl::RCReference<ifrt::Array>>
 NanoIfrtClient::AssembleArrayFromSingleDeviceArrays(
     ifrt::DType dtype, ifrt::Shape shape,
diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h
index 0a908f5fb0749e..d5c5fe7e226e08 100644
--- a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h
+++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include <memory>
 
 #include "absl/base/nullability.h"
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
@@ -30,6 +31,7 @@ limitations under the License.
 #include "xla/backends/cpu/nanort/nanort_client.h"
 #include "xla/pjrt/pjrt_layout.h"
 #include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/array_spec.h"
 #include "xla/python/ifrt/attribute_map.h"
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/compiler.h"
@@ -103,6 +105,12 @@ class NanoIfrtClient : public llvm::RTTIExtends<NanoIfrtClient, ifrt::Client> {
       HostBufferSemantics semantics,
       tsl::RCReference<ifrt::UserContext> user_context) override;
 
+  absl::StatusOr<std::vector<tsl::RCReference<ifrt::Array>>>
+  MakeErrorArrays(
+      const absl::Status& error,
+      absl::Span<const ifrt::ArraySpec> array_specs,
+      tsl::RCReference<ifrt::UserContext> user_context) override;
+
   // Assembles a sharded array from a list of single device arrays.
If the // provided sharding is specific enough to assemble a dense array, this method // will actually return an assembled array that pretends it is sharded. diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client_test.cc b/third_party/xla/xla/backends/cpu/nanort/ifrt_client_test.cc index b5b5ab55640bec..29f1ddd6938166 100644 --- a/third_party/xla/xla/backends/cpu/nanort/ifrt_client_test.cc +++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client_test.cc @@ -239,10 +239,12 @@ BENCHMARK(BM_IfRtAddManyScalars); } // namespace xla::cpu int main(int argc, char** argv) { - // This test expects copies to multiple devices to fail, but we only have one - // device and it doesn't seem worth pretending that we have more. static constexpr absl::string_view kFilter = - "-ArrayImplTest.CopyMixedSourceDevices"; + // This test expects copies to multiple devices to fail, but we only have + // one device and it doesn't seem worth pretending that we have more. + "-ArrayImplTest.CopyMixedSourceDevices:" + // `MakeErrorArrays` is not supported in NanoIfrtClient. + "ArrayImplTest.MakeErrorArrays"; xla::ifrt::test_util::SetTestFilterIfNotUserSpecified(kFilter); for (int i = 1; i < argc; i++) { diff --git a/third_party/xla/xla/python/compile_only_ifrt/BUILD b/third_party/xla/xla/python/compile_only_ifrt/BUILD index 0a57a9d008c3a2..ce855ffa938904 100644 --- a/third_party/xla/xla/python/compile_only_ifrt/BUILD +++ b/third_party/xla/xla/python/compile_only_ifrt/BUILD @@ -29,6 +29,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/python/compile_only_ifrt/client.h b/third_party/xla/xla/python/compile_only_ifrt/client.h index 92539fe1345bb4..8f4dc1cb01d731 100644 --- a/third_party/xla/xla/python/compile_only_ifrt/client.h +++ b/third_party/xla/xla/python/compile_only_ifrt/client.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/base/nullability.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -36,6 +37,7 @@ limitations under the License. 
#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/client.h" @@ -225,6 +227,13 @@ class CompileOnlyIfRtClient final "client."); } + absl::StatusOr>> MakeErrorArrays( + const absl::Status& error, absl::Span array_specs, + tsl::RCReference user_context) override { + return Unimplemented( + "MakeErrorArrays not available with compile-only client."); + } + absl::StatusOr> AssembleArrayFromSingleDeviceArrays( ifrt::DType dtype, ifrt::Shape shape, diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 211c8522b2c9d3..2ddee003c5a771 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -131,7 +131,6 @@ xla_cc_test( ":mock", "//xla/tsl/concurrency:ref_count", "@com_google_googletest//:gtest_main", - "@llvm-project//llvm:Support", ], ) @@ -512,6 +511,7 @@ cc_library( "//xla/tsl/framework:allocator", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -759,7 +759,6 @@ xla_cc_test( "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", - "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", diff --git a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc index 1745c23a8a0c0b..0c445013ea04d4 100644 --- a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "absl/time/time.h" #include "absl/types/span.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" @@ -49,8 +50,10 @@ namespace xla { namespace ifrt { namespace { +using ::testing::_; using ::testing::ElementsAre; using ::testing::ElementsAreArray; +using ::testing::HasSubstr; using ::testing::SizeIs; using ::tsl::testing::StatusIs; @@ -434,6 +437,34 @@ TEST(ArrayImplTest, MakeArraysFromHostBufferShardsAndCopyToHostBuffer) { } } +TEST(ArrayImplTest, MakeErrorArrays) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + xla::ifrt::DeviceListRef device_list = + client->MakeDeviceList(client->addressable_devices()); + + Shape shape({2, 2}); + ArraySpec array_spec = { + /*dtype=*/xla::ifrt::DType(xla::ifrt::DType::kS8), + /*shape=*/shape, + /*sharding=*/ + xla::ifrt::ConcreteEvenSharding::Create( + device_list, xla::ifrt::MemoryKind(), shape, /*shard_shape=*/shape, + /*is_fully_replicated=*/true), + }; + + const absl::Status error = absl::InternalError("injected error"); + TF_ASSERT_OK_AND_ASSIGN( + const std::vector> arrays, + client->MakeErrorArrays(error, {array_spec, array_spec}, + client->CreateUserContext())); + ASSERT_EQ(arrays.size(), 2); + + EXPECT_THAT(arrays[0]->GetReadyFuture().Await(), + StatusIs(_, HasSubstr("injected error"))); + EXPECT_THAT(arrays[1]->GetReadyFuture().Await(), + StatusIs(_, HasSubstr("injected error"))); +} + TEST(ArrayImplTest, AssembleArray) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); diff --git a/third_party/xla/xla/python/ifrt/client.h b/third_party/xla/xla/python/ifrt/client.h index ca365ba8103b01..d90adf46b272df 100644 --- a/third_party/xla/xla/python/ifrt/client.h +++ b/third_party/xla/xla/python/ifrt/client.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/base/macros.h" #include "absl/base/nullability.h" #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -208,6 +209,12 @@ class Client : public llvm::RTTIExtends { HostBufferSemantics semantics, tsl::RCReference user_context) = 0; + // Creates new arrays that will be fulfilled with the given error status. The + // status must not be OK. + virtual absl::StatusOr>> MakeErrorArrays( + const absl::Status& error, absl::Span array_specs, + tsl::RCReference user_context) = 0; + // Builds a larger array out of individual per-device shards. // TODO(hyeontaek): Replace this API with the version that takes // `SingleDeviceShardSemantics` and `dtype`. diff --git a/third_party/xla/xla/python/ifrt/mock.cc b/third_party/xla/xla/python/ifrt/mock.cc index 6bc6c0e3b39ce3..4d000b10340e11 100644 --- a/third_party/xla/xla/python/ifrt/mock.cc +++ b/third_party/xla/xla/python/ifrt/mock.cc @@ -23,10 +23,12 @@ limitations under the License. 
 #include <memory>
 
 #include "absl/base/nullability.h"
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "xla/pjrt/pjrt_layout.h"
 #include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/array_spec.h"
 #include "xla/python/ifrt/attribute_map.h"
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/device.h"
@@ -129,6 +131,13 @@ MockClient::MockClient(std::unique_ptr<Client> delegated)
         return delegated_->MakeArraysFromHostBufferShards(
             specs, semantics, std::move(user_context));
       });
+  ON_CALL(*this, MakeErrorArrays)
+      .WillByDefault([this](const absl::Status& error,
+                            absl::Span<const ArraySpec> array_specs,
+                            tsl::RCReference<UserContext> user_context) {
+        return delegated_->MakeErrorArrays(error, array_specs,
+                                           std::move(user_context));
+      });
   ON_CALL(*this, AssembleArrayFromSingleDeviceArrays(_, _, _, _, _, _))
       .WillByDefault(
          [this](DType dtype, Shape shape,
diff --git a/third_party/xla/xla/python/ifrt/mock.h b/third_party/xla/xla/python/ifrt/mock.h
index 55385537a794e1..c529549db13e84 100644
--- a/third_party/xla/xla/python/ifrt/mock.h
+++ b/third_party/xla/xla/python/ifrt/mock.h
@@ -26,6 +26,7 @@ limitations under the License.
 
 #include "absl/base/nullability.h"
 #include "absl/hash/hash.h"
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
@@ -35,6 +36,7 @@ limitations under the License.
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/pjrt_layout.h"
 #include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/array_spec.h"
 #include "xla/python/ifrt/attribute_map.h"
 #include "xla/python/ifrt/basic_device_list.h"
 #include "xla/python/ifrt/client.h"
@@ -128,6 +130,12 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
                HostBufferSemantics semantics,
                tsl::RCReference<UserContext> user_context),
               (final));
+  MOCK_METHOD(absl::StatusOr<std::vector<tsl::RCReference<Array>>>,
+              MakeErrorArrays,
+              (const absl::Status& error,
+               absl::Span<const ArraySpec> array_specs,
+               tsl::RCReference<UserContext> user_context),
+              (final));
   MOCK_METHOD(absl::StatusOr<tsl::RCReference<Array>>,
               AssembleArrayFromSingleDeviceArrays,
               (DType dtype, Shape shape,
diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD
index 23b73b96ed7a4b..4180533cff7215 100644
--- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD
+++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD
@@ -250,6 +250,7 @@ cc_library(
         "//xla/python/ifrt_proxy/common:versions",
         "//xla/tsl/concurrency:ref_count",
         "//xla/tsl/platform:errors",
+        "//xla/tsl/platform:status_to_from_proto",
        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.cc b/third_party/xla/xla/python/ifrt_proxy/client/array.cc
index 2e8783b43677a3..da1422b393f857 100644
--- a/third_party/xla/xla/python/ifrt_proxy/client/array.cc
+++ b/third_party/xla/xla/python/ifrt_proxy/client/array.cc
@@ -40,6 +40,7 @@
 #include "absl/types/span.h"
 #include "llvm/Support/Casting.h"
 #include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/array_spec.h"
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/client_impl_util.h"
 #include "xla/python/ifrt/dtype.h"
@@ -58,6 +59,7 @@
 #include "xla/status_macros.h"
 #include "xla/tsl/concurrency/ref_count.h"
 #include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/status_to_from_proto.h"
 #include "xla/tsl/platform/statusor.h"
 #include "tsl/profiler/lib/traceme.h"
 
@@ -374,6 +376,47 @@ Array::MakeArraysFromHostBufferShards(
   return arrays;
 }
 
+absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
+Array::MakeErrorArrays(xla::ifrt::Client* client, + std::shared_ptr rpc_helper, + const absl::Status& error, + absl::Span array_specs, + tsl::RCReference user_context) { + auto req = std::make_unique(); + *req->mutable_error() = tsl::StatusToProto(error); + + std::vector arr_handles; + arr_handles.reserve(array_specs.size()); + + for (const ArraySpec& array_spec : array_specs) { + const uint64_t array_handle = rpc_helper->NextHandle(); + req->add_array_handles(array_handle); + TF_ASSIGN_OR_RETURN(*req->add_array_specs(), array_spec.ToProto()); + arr_handles.push_back(ArrayHandle{array_handle}); + } + + if (rpc_helper->version().protocol_version() < 10) { + TF_ASSIGN_OR_RETURN(auto resp, + rpc_helper->MakeErrorArrays(std::move(req)).Await()); + for (const uint64_t array_handle : resp->array_handles()) { + arr_handles.push_back(ArrayHandle{array_handle}); + } + } else { + CheckResponseAfterAsyncCall(rpc_helper->MakeErrorArrays(std::move(req)), + arr_handles); + } + + std::vector> arrays; + arrays.reserve(array_specs.size()); + for (int i = 0; i < array_specs.size(); ++i) { + const xla::ifrt::ArraySpec& array_spec = array_specs[i]; + arrays.push_back(tsl::MakeRef(client, rpc_helper, array_spec.dtype, + array_spec.shape, array_spec.sharding, + arr_handles[i])); + } + return arrays; +} + void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) { if (rpc_helper->version().protocol_version() >= 5) { rpc_helper->Batch(RpcHelper::kDestructArray, handle); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.h b/third_party/xla/xla/python/ifrt_proxy/client/array.h index 752420dca57037..9068fb6e4864b3 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.h @@ -34,6 +34,7 @@ #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/future.h" @@ -41,6 +42,7 @@ #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/user_context.h" #include "xla/python/ifrt/value.h" #include "xla/python/ifrt_proxy/client/rpc_helper.h" #include "xla/python/ifrt_proxy/common/types.h" @@ -75,6 +77,13 @@ class Array final : public llvm::RTTIExtends { xla::ifrt::Client::HostBufferSemantics semantics, tsl::RCReference user_context); + static absl::StatusOr>> + MakeErrorArrays(xla::ifrt::Client* client, + std::shared_ptr rpc_helper, + const absl::Status& error, + absl::Span array_specs, + tsl::RCReference user_context); + // `Array::AssembleArrayFromSingleDeviceArrays()` implements // `Client::AssembleArrayFromSingleDeviceArrays()`. // TODO(b/261226026): Implement logic directly in client.cc. 
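For reference, a minimal caller-side sketch of the new API added above. The wrapper function below and the reconstructed template arguments (which this patch excerpt elides) are illustrative assumptions; only the `MakeErrorArrays(error, array_specs, user_context)` shape and `Client::CreateUserContext()` come from the patch itself.

```
// Sketch: create arrays that carry an error instead of data. Reading any
// shard of a returned array is expected to surface `error` to the caller.
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/array_spec.h"
#include "xla/python/ifrt/client.h"
#include "xla/tsl/concurrency/ref_count.h"

absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
MakeFailedInputs(xla::ifrt::Client* client,
                 absl::Span<const xla::ifrt::ArraySpec> specs) {
  // One error array is created per ArraySpec, reusing that spec's dtype,
  // shape, and sharding.
  const absl::Status error = absl::InternalError("input pipeline failed");
  return client->MakeErrorArrays(error, specs, client->CreateUserContext());
}
```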
diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.cc b/third_party/xla/xla/python/ifrt_proxy/client/client.cc index f3f59e64a27c77..67072212b9c5e1 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.cc @@ -34,6 +34,7 @@ #include "llvm/Support/Casting.h" #include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/client.h" @@ -243,6 +244,14 @@ Client::MakeArraysFromHostBufferShards( this, rpc_helper_, specs, semantics, std::move(user_context)); } +absl::StatusOr>> +Client::MakeErrorArrays(const absl::Status& error, + absl::Span array_specs, + tsl::RCReference user_context) { + return Array::MakeErrorArrays(this, rpc_helper_, error, array_specs, + std::move(user_context)); +} + absl::StatusOr> Client::AssembleArrayFromSingleDeviceArrays( DType dtype, Shape shape, std::shared_ptr sharding, diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.h b/third_party/xla/xla/python/ifrt_proxy/client/client.h index 4d7e752826bbf8..4c41d468fede13 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.h @@ -32,6 +32,7 @@ #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" @@ -78,6 +79,10 @@ class Client final : public llvm::RTTIExtends { absl::Span specs, HostBufferSemantics semantics, tsl::RCReference user_context) override; + absl::StatusOr>> + MakeErrorArrays(const absl::Status& error, + absl::Span array_specs, + tsl::RCReference user_context) override; absl::StatusOr> AssembleArrayFromSingleDeviceArrays( DType dtype, Shape shape, std::shared_ptr sharding, diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc index 9ebceec10ce679..57f6cd4daec490 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -323,6 +323,7 @@ RPC(CheckFuture, check_future); RPC(CheckValueReady, check_value_ready); RPC(MakeArrayFromHostBuffer, make_array_from_host_buffer); RPC(MakeArraysFromHostBufferShards, make_arrays_from_host_buffer_shards); +RPC(MakeErrorArrays, make_error_arrays); RPC(AssembleArrayFromSingleDeviceArrays, assemble_array_from_single_device_arrays); RPC(RemapArrays, remap_arrays); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h index c80b3be8cf2b3c..5fc0440b44e37c 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h @@ -103,6 +103,8 @@ class RpcHelper { ResponseFuture MakeArraysFromHostBufferShards( std::unique_ptr req); + ResponseFuture MakeErrorArrays( + std::unique_ptr req); ResponseFuture AssembleArrayFromSingleDeviceArrays( std::unique_ptr req); diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto index 85ce1d647e8571..3d7ede383ca30e 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto +++ 
b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -51,6 +51,7 @@ message IfrtRequest { MakeArrayFromHostBufferRequest make_array_from_host_buffer_request = 4; MakeArraysFromHostBufferShardsRequest make_arrays_from_host_buffer_shards_request = 25; + MakeErrorArraysRequest make_error_arrays_request = 26; AssembleArrayFromSingleDeviceArraysRequest assemble_array_from_single_device_arrays_request = 5; RemapArraysRequest remap_arrays_request = 23; @@ -101,6 +102,7 @@ message IfrtResponse { MakeArrayFromHostBufferResponse make_array_from_host_buffer_response = 4; MakeArraysFromHostBufferShardsResponse make_arrays_from_host_buffer_shards_response = 25; + MakeErrorArraysResponse make_error_arrays_response = 26; AssembleArrayFromSingleDeviceArraysResponse assemble_array_from_single_device_arrays_response = 5; RemapArraysResponse remap_arrays_response = 23; @@ -304,6 +306,17 @@ message MakeArraysFromHostBufferShardsResponse { repeated fixed64 array_handles = 1; } +message MakeErrorArraysRequest { + tensorflow.StatusProto error = 1; + repeated xla.ifrt.ArraySpecProto array_specs = 2; + // If array_handles is provided, the server will either respond with the same + // handles in `MakeErrorArraysResponse` or return an error. + repeated fixed64 array_handles = 3; +} +message MakeErrorArraysResponse { + repeated fixed64 array_handles = 1; +} + // Makes an IFRT Array from a set of single-device Arrays. // Equivalent to ifrt::Client::AssembleArrayFromSingleDeviceArrays. message AssembleArrayFromSingleDeviceArraysRequest { diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc index 2e30a2db2e56df..a58781905c37b4 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -528,6 +528,11 @@ Future IfrtBackend::ProcessInternal( &array_store_); return Future(HandleMakeArraysFromHostBufferShardsRequest( *asr, std::move(request))); + case IfrtRequest::RequestCase::kMakeErrorArraysRequest: + asr.emplace(request->make_error_arrays_request().array_handles(), + &array_store_); + return Future( + HandleMakeErrorArraysRequest(*asr, std::move(request))); case IfrtRequest::RequestCase::kAssembleArrayFromSingleDeviceArraysRequest: asr.emplace(request->assemble_array_from_single_device_arrays_request() .result_handle(), @@ -946,6 +951,36 @@ IfrtBackend::HandleMakeArraysFromHostBufferShardsRequest( return response; } +absl::StatusOr +IfrtBackend::HandleMakeErrorArraysRequest( + ArrayStore::Reservation& asr, std::unique_ptr request) { + CHECK(request->has_make_error_arrays_request()); + auto* make_array_request = request->mutable_make_error_arrays_request(); + + const absl::Status error = tsl::StatusFromProto(make_array_request->error()); + + std::vector array_specs; + array_specs.reserve(make_array_request->array_specs_size()); + for (const auto& array_spec_proto : make_array_request->array_specs()) { + TF_ASSIGN_OR_RETURN(auto array_spec, + ArraySpec::FromProto(client_.get(), array_spec_proto)); + array_specs.push_back(std::move(array_spec)); + } + + TF_ASSIGN_OR_RETURN(std::vector arrays, + client_->MakeErrorArrays(error, array_specs, + client_->CreateUserContext())); + + std::unique_ptr response = + NewIfrtResponse(request->request_metadata().op_id()); + auto* make_array_resp = response->mutable_make_error_arrays_response(); + for (uint64_t handle : asr.Fill(arrays)) { + make_array_resp->add_array_handles(handle); + } + + return response; 
+} + absl::StatusOr IfrtBackend::HandleAssembleArrayFromSingleDeviceArraysRequest( ArrayStore::Reservation& asr, std::unique_ptr request) { diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h index 9497d5f3c634af..166f6123f73741 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h @@ -179,6 +179,8 @@ class IfrtBackend final : public BackendInterface { ArrayStore::Reservation& asr, std::unique_ptr request); absl::StatusOr HandleMakeArraysFromHostBufferShardsRequest( ArrayStore::Reservation& asr, std::unique_ptr request); + absl::StatusOr HandleMakeErrorArraysRequest( + ArrayStore::Reservation& asr, std::unique_ptr request); absl::StatusOr HandleAssembleArrayFromSingleDeviceArraysRequest( ArrayStore::Reservation& asr, std::unique_ptr request); absl::StatusOr HandleRemapArraysRequest( diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc index 51e5893f2896db..3df85dbc64ea25 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc @@ -56,6 +56,7 @@ limitations under the License. #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/client.h" @@ -82,6 +83,8 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/pjrt_tuple.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/errors.h" @@ -983,6 +986,69 @@ PjRtClient::MakeArraysFromHostBufferShards( std::move(user_context)); } +absl::StatusOr>> +PjRtClient::MakeErrorArrays(const absl::Status& error, + absl::Span array_specs, + tsl::RCReference user_context) { + DCHECK(this); + std::vector> arrays; + arrays.reserve(array_specs.size()); + for (const auto& array_spec : array_specs) { + if (array_spec.dtype.kind() == DType::kString) { + TF_ASSIGN_OR_RETURN( + arrays.emplace_back(), + BasicStringArray::Create(this, array_spec.shape, array_spec.sharding, + Future(error), + /*on_done_with_buffer=*/[]() {})); + continue; + } + + TF_ASSIGN_OR_RETURN(auto primitive_type, ToPrimitiveType(array_spec.dtype)); + TF_ASSIGN_OR_RETURN(Shape shard_shape, + array_spec.sharding->GetShardShape(array_spec.shape)); + xla::Shape xla_shape = + xla::ShapeUtil::MakeShape(primitive_type, shard_shape.dims()); + + PjRtArray::PjRtBuffers buffers; + buffers.reserve(array_spec.sharding->devices()->size()); + for (xla::ifrt::Device* const device : + array_spec.sharding->devices()->devices()) { + std::unique_ptr buffer; + // Find `PjRtMemorySpace` that is associated with the sharding's device + // and matches the sharding's memory_kind. 
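+      // If no attached memory space matches that kind, `memory` stays null
+      // and the check below returns an InvalidArgumentError instead of
+      // silently falling back to the device's default memory.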
+ Memory* memory = nullptr; + for (Memory* ms : device->Memories()) { + if (ms->Kind() == array_spec.sharding->memory_kind()) { + memory = ms; + break; + } + } + if (memory == nullptr) { + return absl::InvalidArgumentError(absl::StrFormat( + "Invalid memory kind: %s; available memory kinds: %s", + *array_spec.sharding->memory_kind().memory_kind(), + absl::StrJoin( + array_spec.sharding->devices()->devices().front()->Memories(), + ", ", [](std::string* out, Memory* ms) { + absl::StrAppend(out, *ms->Kind().memory_kind()); + }))); + } + TF_ASSIGN_OR_RETURN( + buffers.emplace_back(), + pjrt_client_->CreateErrorBuffer( + error, xla_shape, + tensorflow::down_cast(memory)->pjrt_memory())); + } + auto layout = buffers.front()->layout(); + TF_ASSIGN_OR_RETURN( + arrays.emplace_back(), + PjRtArray::Create(this, array_spec.dtype, std::move(shard_shape), + array_spec.sharding, std::move(buffers), + std::move(layout))); + } + return arrays; +} + absl::StatusOr> PjRtClient::AssembleArrayFromSingleDeviceArrays( DType dtype, Shape shape, std::shared_ptr sharding, diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h index e30a8f9c330ff9..3a904b84287b2b 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h @@ -40,6 +40,7 @@ limitations under the License. #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" @@ -183,6 +184,10 @@ class PjRtClient final HostBufferSemantics semantics, tsl::RCReference user_context) override; + absl::StatusOr>> MakeErrorArrays( + const absl::Status& error, absl::Span array_specs, + tsl::RCReference user_context) override; + absl::StatusOr> AssembleArrayFromSingleDeviceArrays( DType dtype, Shape shape, std::shared_ptr sharding, absl::Span> arrays, From d898facd6c3e34318c1fdb3a5ee01f34dd16ff7b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 08:48:50 +0000 Subject: [PATCH 0071/1324] Bump the github-actions group with 5 updates Bumps the github-actions group with 5 updates: | Package | From | To | | --- | --- | --- | | [google/osv-scanner-action](https://github.com/google/osv-scanner-action) | `1.9.2` | `2.0.0` | | [actions/setup-python](https://github.com/actions/setup-python) | `5.4.0` | `5.5.0` | | [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) | `7.0.7` | `7.0.8` | | [actions/upload-artifact](https://github.com/actions/upload-artifact) | `4.6.1` | `4.6.2` | | [github/codeql-action](https://github.com/github/codeql-action) | `3.28.10` | `3.28.13` | Updates `google/osv-scanner-action` from 1.9.2 to 2.0.0 - [Release notes](https://github.com/google/osv-scanner-action/releases) - [Commits](https://github.com/google/osv-scanner-action/compare/v1.9.2...v2.0.0) Updates `actions/setup-python` from 5.4.0 to 5.5.0 - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/42375524e23c412d93fb67b49958b491fce71c38...8d9ed9ac5c53483de85588cdf95a591a75ab9f55) Updates `peter-evans/create-pull-request` from 7.0.7 to 7.0.8 - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - 
[Commits](https://github.com/peter-evans/create-pull-request/compare/dd2324fc52d5d43c699a5636bcf19fceaa70c284...271a8d0340265f705b14b6d32b9829c1cb33d45e) Updates `actions/upload-artifact` from 4.6.1 to 4.6.2 - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1...ea165f8d65b6e75b540449e92b4886f43607fa02) Updates `github/codeql-action` from 3.28.10 to 3.28.13 - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/github/codeql-action/compare/b56ba49b26e50535fa1e7f7db0f4f7b4bf65d80d...1b549b9259bda1cb5ddde3b41741a82a2d15a841) --- updated-dependencies: - dependency-name: google/osv-scanner-action dependency-version: 2.0.0 dependency-type: direct:production update-type: version-update:semver-major dependency-group: github-actions - dependency-name: actions/setup-python dependency-version: 5.5.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: github-actions - dependency-name: peter-evans/create-pull-request dependency-version: 7.0.8 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: github-actions - dependency-name: actions/upload-artifact dependency-version: 4.6.2 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: github-actions - dependency-name: github/codeql-action dependency-version: 3.28.13 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: github-actions ... Signed-off-by: dependabot[bot] --- .github/workflows/osv-scanner-scheduled.yml | 2 +- .github/workflows/pylint-presubmit.yml | 2 +- .github/workflows/release-branch-cherrypick.yml | 2 +- .github/workflows/scorecards-analysis.yml | 4 ++-- .github/workflows/update-rbe.yml | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index e612b642fb1959..c1d680fcdfcea6 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -28,7 +28,7 @@ permissions: jobs: scan-scheduled: if: github.repository == 'tensorflow/tensorflow' - uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.9.2" + uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v2.0.0" with: scan-args: |- --lockfile=requirements.txt:./requirements_lock_3_9.txt diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml index 09801d29b69797..97a7b7a5f8285f 100644 --- a/.github/workflows/pylint-presubmit.yml +++ b/.github/workflows/pylint-presubmit.yml @@ -38,7 +38,7 @@ jobs: run: | echo Changed files: ${{ steps.get_file_changes.outputs.files }} - name: Set up Python 3.9 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: "3.9" - name: Install Python dependencies diff --git a/.github/workflows/release-branch-cherrypick.yml b/.github/workflows/release-branch-cherrypick.yml index 6587769b85b868..4fa4f8d5b9435a 100644 --- a/.github/workflows/release-branch-cherrypick.yml +++ b/.github/workflows/release-branch-cherrypick.yml @@ -58,7 +58,7 @@ jobs: echo "SHORTSHA=$(git log -1 ${{ github.event.inputs.git_commit }} 
--format="%h")" >> "$GITHUB_OUTPUT" echo "TITLE=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%s")" >> "$GITHUB_OUTPUT" - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@dd2324fc52d5d43c699a5636bcf19fceaa70c284 # v7.0.7 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 with: title: '${{ github.event.inputs.release_branch }} cherry-pick: ${{ steps.cherrypick.outputs.SHORTSHA }} "${{ steps.cherrypick.outputs.TITLE }}"' committer: TensorFlow Release Automation diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml index 6adc36c3749df4..c68351d3bd3a23 100644 --- a/.github/workflows/scorecards-analysis.yml +++ b/.github/workflows/scorecards-analysis.yml @@ -55,7 +55,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4.6.1 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: SARIF file path: results.sarif @@ -64,6 +64,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@b56ba49b26e50535fa1e7f7db0f4f7b4bf65d80d # v3.28.10 + uses: github/codeql-action/upload-sarif@1b549b9259bda1cb5ddde3b41741a82a2d15a841 # v3.28.13 with: sarif_file: results.sarif diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index 11b83f43e70882..a06d2e0125f6b9 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -130,7 +130,7 @@ jobs: map sigbuild-r2.17-clang-python3.11 2.17-python3.11 map sigbuild-r2.17-clang-python3.12 2.17-python3.12 - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@dd2324fc52d5d43c699a5636bcf19fceaa70c284 # v7.0.7 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 with: title: Update the RBE images to the latest container versions committer: TensorFlow Release Automation From 745f19bf350b9892233984402486fc9680f7bdb1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 01:42:12 -0700 Subject: [PATCH 0072/1324] Adding executable location to KubernetesClusterResolver. 
PiperOrigin-RevId: 742591430
---
 .../distribute/cluster_resolver/__init__.py   |  1 +
 .../kubernetes_cluster_resolver.py            | 47 ++++++++++++++++---
 .../kubernetes_cluster_resolver_test.py       | 33 +++++++++++++
 ...esolver.-kubernetes-cluster-resolver.pbtxt |  2 +-
 ...lver.-kubernetes-executable-location.pbtxt | 12 +++++
 ...nsorflow.distribute.cluster_resolver.pbtxt |  4 ++
 ...esolver.-kubernetes-cluster-resolver.pbtxt |  2 +-
 ...lver.-kubernetes-executable-location.pbtxt | 12 +++++
 ...nsorflow.distribute.cluster_resolver.pbtxt |  4 ++
 9 files changed, 108 insertions(+), 9 deletions(-)
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-kubernetes-executable-location.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-kubernetes-executable-location.pbtxt

diff --git a/tensorflow/python/distribute/cluster_resolver/__init__.py b/tensorflow/python/distribute/cluster_resolver/__init__.py
index 35903dc7b254f7..405cd28ed08230 100644
--- a/tensorflow/python/distribute/cluster_resolver/__init__.py
+++ b/tensorflow/python/distribute/cluster_resolver/__init__.py
@@ -25,6 +25,7 @@
 from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
 from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver
+from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import ExecutableLocation
 from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
 from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver
 from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver
diff --git a/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver.py
index f74089ed0415b6..2eb31344e34268 100644
--- a/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver.py
+++ b/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver.py
@@ -14,12 +14,30 @@
 # ==============================================================================
 """Implementation of Cluster Resolvers for Kubernetes."""

+import enum
+
 from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
 from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
 from tensorflow.python.training import server_lib
 from tensorflow.python.util.tf_export import tf_export


+@tf_export('distribute.cluster_resolver.KubernetesExecutableLocation')
+class ExecutableLocation(enum.Enum):
+  """Defines where the executable runs.
+
+  This is used to determine how to resolve the configuration needed
+  to talk to the Kubernetes API server.
+
+  `WITHIN_CLUSTER` means that the TensorFlow code is running in a pod
+  within the cluster itself.
+  `OFF_CLUSTER` means any other environment outside the cluster.
+  """
+
+  WITHIN_CLUSTER = 0
+  OFF_CLUSTER = 1
+
+
 @tf_export('distribute.cluster_resolver.KubernetesClusterResolver')
 class KubernetesClusterResolver(ClusterResolver):
   """ClusterResolver for Kubernetes.
@@ -55,11 +73,14 @@ class KubernetesClusterResolver(ClusterResolver):
   ```
   """

-  def __init__(self,
-               job_to_label_mapping=None,
-               tf_server_port=8470,
-               rpc_layer='grpc',
-               override_client=None):
+  def __init__(
+      self,
+      job_to_label_mapping=None,
+      tf_server_port=8470,
+      rpc_layer='grpc',
+      override_client=None,
+      executable_location=ExecutableLocation.WITHIN_CLUSTER,
+  ):
     """Initializes a new KubernetesClusterResolver.

     This initializes a new Kubernetes ClusterResolver. The ClusterResolver
@@ -80,17 +101,29 @@ def __init__(self,
         between tasks in Kubernetes. Defaults to 'grpc'.
       override_client: The Kubernetes client (usually automatically retrieved
         using `from kubernetes import client as k8sclient`). If you pass this
-        in, you are responsible for setting Kubernetes credentials manually.
+        in, you are responsible for setting Kubernetes credentials manually and
+        calling `k8sconfig.load_kube_config()` or
+        `k8sconfig.load_incluster_config()` before using this ClusterResolver.
+      executable_location: Specifies whether or not this TensorFlow code is
+        running from within a K8S cluster.

     Raises:
       ImportError: If the Kubernetes Python client is not installed and no
         `override_client` is passed in.
       RuntimeError: If autoresolve_task is not a boolean or a callable.
+      ValueError: If `executable_location` is not a valid value.
     """
     try:
       from kubernetes import config as k8sconfig  # pylint: disable=g-import-not-at-top

-      k8sconfig.load_kube_config()
+      if not override_client:
+        if executable_location == ExecutableLocation.OFF_CLUSTER:
+          k8sconfig.load_kube_config()
+        elif executable_location == ExecutableLocation.WITHIN_CLUSTER:
+          k8sconfig.load_incluster_config()
+        else:
+          raise ValueError('The executable location provided is invalid.')
+
     except ImportError:
       if not override_client:
         raise ImportError('The Kubernetes Python client must be installed '
diff --git a/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver_test.py
index 9e663728c49d72..5e359f0eaa05cb 100644
--- a/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver_test.py
+++ b/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver_test.py
@@ -14,6 +14,9 @@
 # ==============================================================================
 """Tests for K8sClusterResolver."""

+import sys
+
+from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import ExecutableLocation
 from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
 from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
@@ -21,6 +24,10 @@

 mock = test.mock

+def _mock_kubernetes_module():
+  sys.modules['kubernetes'] = mock.MagicMock()
+
+
 def _mock_kubernetes_client(ret):
   mock_client = mock.MagicMock()
   mock_client.list_pod_for_all_namespaces.side_effect = (
@@ -68,6 +75,32 @@ def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
         server_lib.ClusterSpec(
             cluster_spec.as_dict()).as_cluster_def())

+  def testSingleItemSuccessfulRetrievalInCluster(self):
+    ret = _create_pod_list(
+        ('tensorflow-abc123', 'Running', '10.1.2.3'),
+    )
+
+    cluster_resolver = KubernetesClusterResolver(
+        override_client=_mock_kubernetes_client({'job-name=tensorflow': ret}),
+        executable_location=ExecutableLocation.WITHIN_CLUSTER,
+    )
+
+    actual_cluster_spec = cluster_resolver.cluster_spec()
+    expected_proto = """
+    job {
+      name:
'worker' + tasks { key: 0 value: '10.1.2.3:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) + + def testValueErrorRaisedOnInvalidExecutableLocation(self): + + _mock_kubernetes_module() + + with self.assertRaisesRegexp(ValueError, '.*'): + KubernetesClusterResolver(executable_location=None) + def testSingleItemSuccessfulRetrieval(self): ret = _create_pod_list(('tensorflow-abc123', 'Running', '10.1.2.3'),) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-kubernetes-cluster-resolver.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-kubernetes-cluster-resolver.pbtxt index 2819ca85612ea9..5fde5f4d764790 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-kubernetes-cluster-resolver.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-kubernetes-cluster-resolver.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', \'override_client\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\'], " + argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', \'override_client\', \'executable_location\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\', \'ExecutableLocation.WITHIN_CLUSTER\'], " } member_method { name: "cluster_spec" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-kubernetes-executable-location.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-kubernetes-executable-location.pbtxt new file mode 100644 index 00000000000000..ba6735c2f15bbf --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-kubernetes-executable-location.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.distribute.cluster_resolver.KubernetesExecutableLocation" +tf_class { + is_instance: "" + member { + name: "OFF_CLUSTER" + mtype: "" + } + member { + name: "WITHIN_CLUSTER" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.pbtxt index 5906ffa850a360..797e01a3a57892 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "KubernetesClusterResolver" mtype: "" } + member { + name: "KubernetesExecutableLocation" + mtype: "" + } member { name: "SimpleClusterResolver" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-kubernetes-cluster-resolver.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-kubernetes-cluster-resolver.pbtxt index 2819ca85612ea9..5fde5f4d764790 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-kubernetes-cluster-resolver.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-kubernetes-cluster-resolver.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', \'override_client\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\'], " + argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', 
\'override_client\', \'executable_location\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\', \'ExecutableLocation.WITHIN_CLUSTER\'], " } member_method { name: "cluster_spec" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-kubernetes-executable-location.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-kubernetes-executable-location.pbtxt new file mode 100644 index 00000000000000..ba6735c2f15bbf --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-kubernetes-executable-location.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.distribute.cluster_resolver.KubernetesExecutableLocation" +tf_class { + is_instance: "" + member { + name: "OFF_CLUSTER" + mtype: "" + } + member { + name: "WITHIN_CLUSTER" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.pbtxt index 5906ffa850a360..797e01a3a57892 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "KubernetesClusterResolver" mtype: "" } + member { + name: "KubernetesExecutableLocation" + mtype: "" + } member { name: "SimpleClusterResolver" mtype: "" From d22d13ee1cf390d8450b6b78507b77d30092302c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 01:42:14 -0700 Subject: [PATCH 0073/1324] Automated Code Change PiperOrigin-RevId: 742591443 --- third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD | 2 ++ third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc | 1 + third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h | 1 + .../xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc | 1 - 4 files changed, 4 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD index 644c2dbd815d36..d7a4ecc61d5e98 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -181,7 +181,9 @@ cc_library( "//xla/mlir/utils:type_util", "//xla/mlir_hlo", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SparseTensorDialect", diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index 809d8683aadba2..eb5744bdc6109f 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h index 23737ce07d07f0..8c85eec6d711c3 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h @@ -24,6 +24,7 @@ limitations under the License. 
#include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc index 3751bd31a90ed2..821b2317997206 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc @@ -23,7 +23,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/status/status.h" -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" From a7ace96a10521859b80d9e3bbe041538b9fcc333 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 1 Apr 2025 01:45:33 -0700 Subject: [PATCH 0074/1324] Reverts 1f09b81cc4a47312472b1c59e38d42624aff1228 PiperOrigin-RevId: 742592480 --- .../model/gpu_indexing_performance_model.cc | 9 +- .../gpu/model/symbolic_tile_analysis.h | 3 - .../xla/xla/service/gpu/transforms/BUILD | 1 - .../service/gpu/transforms/priority_fusion.cc | 218 +++--------------- .../service/gpu/transforms/priority_fusion.h | 3 +- .../gpu/transforms/priority_fusion_test.cc | 103 --------- 6 files changed, 37 insertions(+), 300 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index ce41097da44677..8f22c3e347f61a 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -512,7 +512,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( } absl::Duration compute_time = - ComputeTime(*device_info_, flops, num_blocks, + ComputeTime(*device_info_, flops, launch_dimensions.num_blocks(), launch_dimensions.num_threads_per_block()); absl::Duration memory_access_time = read_time + write_time; @@ -543,12 +543,15 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( return absl::FailedPreconditionError(absl::StrCat( "SymbolicTileAnalysis failed. ", fusion_decision->Explain())); } + // TODO(b/390559452): Add support for more than one fusion root. + if (tile_sizes.size() != 1) { + return absl::UnimplementedError("Only 1 root is supported right now"); + } SymbolicTileAnalysis analysis = std::get(std::move(analysis_or_error)); TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, - analysis.ComputeTiledHloInstructions( - tile_sizes[analysis.real_root_index()])); + analysis.ComputeTiledHloInstructions(tile_sizes[0])); return EstimateRunTimeForTiledHloComputation( fusion_adaptor, tiled_hlo_computation, launch_dimensions); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index 07243bb4904ef4..17bfd837ad38c1 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -142,9 +142,6 @@ class SymbolicTileAnalysis { return root_indexing_.roots[idx]; } - // Returns the output index of the real root. - int64_t real_root_index() const { return root_indexing_.real_root_index; } - // Returns the number of tile parameters in this symbolic analysis. 
// TODO(b/390569102): This assumes that there is only one root that matters // for computing the tiling, and that it is the last symbolic tiled hlo diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 68d8af6460438d..ba26a805777e15 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2438,7 +2438,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/gpu/codegen/triton:support", - "//xla/hlo/analysis:hlo_dfs_reachability", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_traversal", diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index 11f3435fc59d1f..1850ff499126da 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -42,7 +42,6 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "xla/backends/gpu/codegen/triton/support.h" #include "xla/debug_options_flags.h" -#include "xla/hlo/analysis/hlo_dfs_reachability.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -163,7 +162,6 @@ class PriorityFusionQueue { fusion_analysis_cache_(fusion_analysis_cache), fusion_deduplication_cache_(fusion_deduplication_cache), fusion_info_cache_(*device_info_), - reachability_(HloDfsReachability::Build(computation)), triton_heroless_fusion_enabled_(triton_heroless_fusion_enabled) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); @@ -255,10 +253,6 @@ class PriorityFusionQueue { reverse_map_.erase(current_producer_); current_consumers_ = current_producer_->users(); - auto preferred_consumer = GetPreferredConsumer(current_producer_); - if (preferred_consumer) { - current_consumers_ = {*preferred_consumer}; - } if (HloPredicateIsOp(current_producer_)) { // We don't check if bitcasts can be fused with all consumers, so we @@ -272,15 +266,6 @@ class PriorityFusionQueue { return !current_consumers_.empty(); } - std::optional GetPreferredConsumer( - HloInstruction* producer) { - auto it = preferred_consumer_.find(producer); - if (it == preferred_consumer_.end()) { - return std::nullopt; - } - return it->second; - } - absl::Status UpdatePerformanceModelCache(HloInstruction* producer) { if (!IsFusible(*producer)) { return absl::OkStatus(); @@ -405,19 +390,9 @@ class PriorityFusionQueue { } // Updates data for the new fusion instruction and its users and operands. - // Both `original_producer` and `original_consumer` could have been removed - // already from the computation, waiting for deletion. We can still - // dereference them though. 
void OnFusingInstruction(HloInstruction* fusion, HloInstruction* original_producer, - HloInstruction* original_consumer, - int64_t original_consumer_operand_index) { - bool creates_multi_output_fusion = - preferred_consumer_.contains(original_producer); - fusion_deduplication_cache_.UpdateFusedInstructionId( - fusion, original_producer, original_consumer, - original_consumer_operand_index, creates_multi_output_fusion); - + HloInstruction* original_consumer) { if (fusion_process_dump_) { auto* fusion_step = fusion_process_dump_->add_fusion_steps()->mutable_fusion(); @@ -436,24 +411,13 @@ class PriorityFusionQueue { *fusion); } - if (fusion == original_consumer) { - // We need to check again whether we can use `original_consumer` as a - // producer for a ProducerConsumer multi-output fusion. - preferred_consumer_.erase(original_consumer); - } else { - // The original consumer was replaced with the fusion, but it's pointer - // can still be referenced somewhere, for example, in to_update_priority_. - // Priority recomputation is called before DCE. Remove all references to - // the original consumer here. - reachability_->OnInstructionReplaced(/*previous=*/original_consumer, - /*now=*/fusion); + // The original consumer was replaced with the fusion, but it's pointer can + // still be referenced somewhere, for example, in to_update_priority_. + // Priority recomputation is called before DCE. Remove all references to + // the original consumer here. + if (fusion != original_consumer) { RemoveInstruction(original_consumer); } - if (creates_multi_output_fusion) { - // After a multi-output fusion was created, we need to rebuild the - // HloDfsReachability data structure. - reachability_ = HloDfsReachability::Build(computation_); - } // Collect the instructions whose priorities need to be updated. for (HloInstruction* operand : fusion->operands()) { @@ -469,21 +433,10 @@ class PriorityFusionQueue { } to_update_priority_.insert(operand); - // Update the consumers of this operand that we care about, - // so we can do incremental update of the operand. + // update the consumers of this operand that we care about, + // so we can do incremental update of the operand operands_to_new_consumers_[operand].push_back(fusion); - - // We may need to reset `preferred_consumer_`, as we don't know yet - // whether that fusion would still be valid. - auto it = preferred_consumer_.find(operand); - if (it != preferred_consumer_.end() && it->second == original_consumer) { - preferred_consumer_.erase(it); - } } - // TODO(b/390559452): For multi-output fusion, we would also need to update - // the priorities of the other consumers of `producer` with which we did not - // fuse. For now, as we only allow multi-output fusion if there is just a - // single fusible consumer, this is not needed. to_update_priority_.insert(fusion); } @@ -498,7 +451,6 @@ class PriorityFusionQueue { } producer_priority_queue_.erase(reverse_it->second); reverse_map_.erase(reverse_it); - preferred_consumer_.erase(instruction); } // Returns a map from consumer to BlockLevelParameters. This is used to @@ -533,24 +485,9 @@ class PriorityFusionQueue { return -absl::InfiniteDuration(); } + // Don't fuse if we can't fuse in all users. if (auto fusion_decision = CanFuseWithAllNonBitcastUsers(producer); !fusion_decision) { - // If we cannot fuse `producer` into all non-bitcast consumers, try - // Triton multi-output fusion next. 
- std::vector possible_consumers = - FindPossibleConsumersForTritonMultiOutputFusion(producer); - if (CanFuseTritonMultiOutputWithSingleUser(producer, - possible_consumers)) { - GpuPerformanceModel::RunTimes run_times = - GpuPerformanceModel::EstimateRunTimes( - producer, *device_info_, &cost_analysis_, - GpuPerformanceModelOptions::Default( - &fusion_analysis_cache_, &gpu_performance_model_cache_), - /*fused_consumers=*/possible_consumers); - preferred_consumer_[producer] = possible_consumers[0]; - return run_times.time_unfused - run_times.time_fused; - } - // Don't fuse if we can't fuse in all users. if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); auto* step = fusion_process_dump_->add_fusion_steps() @@ -627,12 +564,10 @@ class PriorityFusionQueue { } TiledRunTimeDataOrError GetTiledRunTimeDataCached( - const HloInstruction* producer, const HloInstruction* consumer, - bool use_multi_output_fusion = false) { + const HloInstruction* producer, const HloInstruction* consumer) { FusionDeduplicationCache::FusionId fusion_id = [&]() { absl::MutexLock lock(&fusion_deduplication_cache_mutex_); - return fusion_deduplication_cache_.GetFusionId(producer, consumer, - use_multi_output_fusion); + return fusion_deduplication_cache_.GetFusionId(producer, consumer); }(); { @@ -644,8 +579,7 @@ class PriorityFusionQueue { } } - auto fusion = HloFusionAdaptor::ForProducerConsumer( - producer, consumer, use_multi_output_fusion); + auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer); absl::StatusOr result_or_status = gpu_indexing_performance_model_.TryFindBestTilingForFusion(*fusion); @@ -677,8 +611,7 @@ class PriorityFusionQueue { } FusionDecision CanFuseTriton(HloInstruction* producer, - HloInstruction* consumer, - bool use_multi_output_fusion = false) { + HloInstruction* consumer) { if (!IsGenericTritonFusion(*producer) && !IsGenericTritonFusion(*consumer) && !triton_heroless_fusion_enabled_) { return FusionDecision::Forbid("triton heroless fusion is not enabled"); @@ -693,7 +626,7 @@ class PriorityFusionQueue { } TiledRunTimeDataOrError tiled_run_time_data_or_error = - GetTiledRunTimeDataCached(producer, consumer, use_multi_output_fusion); + GetTiledRunTimeDataCached(producer, consumer); if (const auto* fusion_decision = std::get_if(&tiled_run_time_data_or_error)) { @@ -705,14 +638,9 @@ class PriorityFusionQueue { // This is our way to pass the runtime estimate to the CalculatePriorities() // function. - // This is somewhat brittle as we currently don't distinguish between - // ProducerConsumer fusion where we allow multi-output fusions to be formed, - // and ProducerConsumer fusion where we don't allow it. Same for the - // `block_level_parameters_cache_` down below. Currently we only try out - // multi-output fusion if we cannot fuse into all consumers, and it is tried - // last, so the final cached value should be what we want. gpu_performance_model_cache_.Set( *producer, *consumer, tiled_run_time_data.runtime_data.exec_time); + { absl::MutexLock lock(&block_level_parameters_cache_mutex_); block_level_parameters_cache_[producer][consumer] = @@ -852,63 +780,6 @@ class PriorityFusionQueue { return fusion_decision; } - // Checks whether any operand of `consumer` is reachable from `producer` - // following user edges in the HLO graph. If that is the case, we would - // introduce a cycle by fusing `producer` into `consumer`. 
- bool OperandReachableFromProducer(const HloInstruction* producer, - const HloInstruction* consumer) { - for (const auto* consumer_operand : consumer->operands()) { - CHECK(reachability_->IsPresent(consumer_operand) && - reachability_->IsPresent(producer)) - << "Reachability map is incomplete. This should never " - "happen."; - if (producer != consumer_operand && - reachability_->IsReachable(producer, consumer_operand)) { - return true; - } - } - return false; - } - - std::vector FindPossibleConsumersForTritonMultiOutputFusion( - HloInstruction* producer) { - bool triton_multi_output_fusion_enabled = - producer->GetModule() - ->config() - .debug_options() - .xla_gpu_unsupported_enable_triton_multi_output_fusion(); - if (!triton_multi_output_fusion_enabled) { - return {}; - } - std::vector possible_consumers; - for (const auto& user : producer->users()) { - if (HloPredicateIsOp(user)) { - continue; - } - if (CanFuseTriton(producer, user, /*use_multi_output_fusion=*/true) && - !OperandReachableFromProducer(producer, user)) { - possible_consumers.push_back(user); - } - } - return possible_consumers; - } - - FusionDecision CanFuseTritonMultiOutputWithSingleUser( - HloInstruction* producer, - const std::vector& possible_consumers) { - if (possible_consumers.empty()) { - return FusionDecision::Forbid("No users to fuse"); - } - - if (possible_consumers.size() != 1) { - // TODO(b/390559452): If there are several possible consumers to fuse - // with, decide which one is best. Also depends on what further fusions - // might be possible, needs checking the reachability graph. - return FusionDecision::Forbid("more than one consumer to fuse with"); - } - return FusionDecision::Allow(); - } - FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) { if (producer->users().empty()) { return FusionDecision::Forbid("No users to fuse"); @@ -954,11 +825,6 @@ class PriorityFusionQueue { // A reverse map that helps find an instruction in the priority queue. absl::flat_hash_map reverse_map_; - // Stores a mapping from the producer to the preferred consumer to fuse into. - // This is only used in case that we want to use ProducerConsumer multi-output - // fusion. - absl::flat_hash_map preferred_consumer_; - // The current producer being visited. HloInstruction* current_producer_; @@ -1014,10 +880,6 @@ class PriorityFusionQueue { // like shared memory usage or number of unnested reductions of fusion nodes. FusionInfoCache fusion_info_cache_; - // Allows evaluation of whether an HloInstruction is an ancestor of another - // HloInstruction. - std::unique_ptr reachability_; - // If true, redirect all fusion decisions to Triton fusion. bool triton_heroless_fusion_enabled_; @@ -1122,12 +984,7 @@ absl::StatusOr PriorityFusion::Run( block_level_parameters_map = fusion_queue->GetBlockLevelParametersMap(producer); - auto preferred_consumer = fusion_queue->GetPreferredConsumer(producer); - std::vector consumers = - fusion_queue->current_consumers(); - bool use_multi_output_fusion = preferred_consumer.has_value(); - - for (auto* consumer : consumers) { + for (auto* consumer : fusion_queue->current_consumers()) { // Don't fuse into single bitcasts. We ignore them in the check // CanFuseWithAllNonBitcastUsers(), so we need to check it here. 
if (HloPredicateIsOp(consumer)) { @@ -1141,8 +998,12 @@ absl::StatusOr PriorityFusion::Run( int64_t consumer_operand_index = consumer->operand_index(producer); fusion_queue->PreFusion(producer, consumer); - auto fusion_instruction = - Fuse(producer, consumer, use_multi_output_fusion); + auto fusion_instruction = Fuse(producer, consumer); + fusion_deduplication_cache.UpdateFusedInstructionId( + fusion_instruction, producer, consumer, consumer_operand_index); + fusion_queue->OnFusingInstruction(fusion_instruction, producer, + consumer); + auto backend_config_it = block_level_parameters_map.find(consumer); if (backend_config_it != block_level_parameters_map.end()) { TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config( @@ -1150,25 +1011,20 @@ absl::StatusOr PriorityFusion::Run( fusion_instruction->set_fusion_kind( HloInstruction::FusionKind::kCustom); } - fusion_queue->OnFusingInstruction(fusion_instruction, producer, - consumer, consumer_operand_index); changed = true; } fusion_queue->ComputeRuntimesOfRemovedConsumers(); - if (use_multi_output_fusion || producer->user_count() == 0) { + if (producer->user_count() == 0) { fusion_queue->InvalidateCaches(producer); + producer->DetachFromOperandsAndUsers(); fusion_queue->RemoveInstruction(producer); - // When we use ProducerConsumer multi-output fusion, `producer` will - // have been removed already. - if (!use_multi_output_fusion) { - producer->DetachFromOperandsAndUsers(); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); - } + // Remove from computation. + TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); } - for (auto* consumer : consumers) { + for (auto* consumer : fusion_queue->current_consumers()) { fusion_queue->InvalidateCaches(consumer); } TF_RETURN_IF_ERROR(fusion_queue->UpdatePriorities()); @@ -1238,8 +1094,7 @@ HloInstruction::FusionKind PriorityFusion::ChooseKind( } HloInstruction* PriorityFusion::Fuse(HloInstruction* producer, - HloInstruction* consumer, - bool use_multi_output_fusion) { + HloInstruction* consumer) { VLOG(2) << "Fusing " << producer->ToString() << " into " << consumer->ToString(); @@ -1260,22 +1115,9 @@ HloInstruction* PriorityFusion::Fuse(HloInstruction* producer, /*skip_async_execution_thread_overwrite=*/false); if (HloPredicateIsOp(producer)) { - if (use_multi_output_fusion) { - fusion_instruction->MergeFusionInstructionIntoMultiOutput(producer); - } else { - fusion_instruction->MergeFusionInstruction(producer); - } + fusion_instruction->MergeFusionInstruction(producer); } else { - if (use_multi_output_fusion) { - fusion_instruction->FuseInstructionIntoMultiOutput(producer); - // MergeFusionInstructionIntoMultiOutput already removes `producer` from - // the computation. Do the same here, so that we have the invariant that - // the producer has been cleaned up when multi-output fusion is used. 
- CHECK_EQ(0, producer->user_count()); - TF_CHECK_OK(producer->parent()->RemoveInstruction(producer)); - } else { - fusion_instruction->FuseInstruction(producer); - } + fusion_instruction->FuseInstruction(producer); } if (fusion_instruction != consumer) { diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h index c9a29a5c056f35..f1d19532a755f5 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h @@ -60,8 +60,7 @@ class PriorityFusion : public HloModulePass { HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, const HloInstruction* consumer); - HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer, - bool use_multi_output_fusion = false); + HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); private: // Consumes a unit of compiler fuel and returns true if we should diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc index 463b4d45c475b5..b534f46150b382 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc @@ -1026,109 +1026,6 @@ ENTRY main { 2); } -TEST_F(PriorityFusionTest, - FuseTritonProducerWithTwoConsumersUsingMultiOutputFusion) { - const std::string kHloText = R"( -HloModule t - -producer_computation { - parameter_0 = f32[125]{0} parameter(0) - ROOT broadcast = f32[125,127] broadcast(parameter_0), dimensions={0} -} - -consumer_computation { - parameter_0 = f32[125,127] parameter(0) - ROOT log = f32[125,127] log(parameter_0) -} - -ENTRY main { - param_0 = f32[125]{0} parameter(0) - producer_fusion = f32[125,127] fusion(param_0), kind=kCustom, calls=producer_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tiles":[{"sizes":["1","127"]}],"num_warps":"1"}}} - consumer_fusion = f32[125,127] fusion(producer_fusion), kind=kLoop, calls=consumer_computation - ROOT tuple = (f32[125,127], f32[125,127]) tuple(consumer_fusion, producer_fusion) -})"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(false); - EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); - - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(true); - EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - - HloInstruction* root = module->entry_computation()->root_instruction(); - HloInstruction *fusion1, *fusion2; - EXPECT_THAT(root, - GmockMatch(m::Tuple( - m::GetTupleElement(m::Fusion(&fusion1, m::Parameter()), 0), - m::GetTupleElement(m::Fusion(&fusion2, m::Parameter()), 1)))); - EXPECT_EQ(fusion1, fusion2); - EXPECT_TRUE(IsGenericTritonFusion(*fusion1)); - TF_ASSERT_OK_AND_ASSIGN(auto backend_config1, - fusion1->backend_config()); - EXPECT_TRUE( - backend_config1.fusion_backend_config().has_block_level_fusion_config()); - EXPECT_EQ(backend_config1.fusion_backend_config() - .block_level_fusion_config() - .output_tiles(0) - .sizes_size(), - 2); -} - -TEST_F(PriorityFusionTest, - FuseProducerWithTritonConsumerUsingMultiOutputFusion) { - const std::string kHloText = R"( 
-HloModule t
-
-consumer_computation {
-  parameter_0 = f32[125,127] parameter(0)
-  ROOT log = f32[125,127] log(parameter_0)
-}
-
-ENTRY main {
-  param_0 = f32[125]{0} parameter(0)
-  producer = f32[125,127] broadcast(param_0), dimensions={0}
-  consumer_fusion = f32[125,127] fusion(producer), kind=kCustom, calls=consumer_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tiles":[{"sizes":["1","127"]}],"num_warps":"1"}}}
-  ROOT tuple = (f32[125,127], f32[125,127]) tuple(consumer_fusion, producer)
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-
-  module->mutable_config()
-      .mutable_debug_options()
-      .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(false);
-  EXPECT_FALSE(priority_fusion_.Run(module.get()).value());
-
-  module->mutable_config()
-      .mutable_debug_options()
-      .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(true);
-  EXPECT_TRUE(priority_fusion_.Run(module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  HloInstruction *fusion1, *fusion2;
-  EXPECT_THAT(root,
-              GmockMatch(m::Tuple(
-                  m::GetTupleElement(m::Fusion(&fusion1, m::Parameter()), 0),
-                  m::GetTupleElement(m::Fusion(&fusion2, m::Parameter()), 1))));
-  EXPECT_EQ(fusion1, fusion2);
-  EXPECT_TRUE(IsGenericTritonFusion(*fusion1));
-  TF_ASSERT_OK_AND_ASSIGN(auto backend_config1,
-                          fusion1->backend_config<GpuBackendConfig>());
-  EXPECT_TRUE(
-      backend_config1.fusion_backend_config().has_block_level_fusion_config());
-  EXPECT_EQ(backend_config1.fusion_backend_config()
-                .block_level_fusion_config()
-                .output_tiles(0)
-                .sizes_size(),
-            2);
-}
-
 TEST_F(PriorityFusionTest, TritonProducerNotSupported_DoNotFuse) {
   const std::string kHloText = R"(
 HloModule t

From 03c2386ef906c924d00e263163cca6e2e062d79a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 1 Apr 2025 01:51:18 -0700
Subject: [PATCH 0075/1324] [XLA:GPU] Allow extracting collectives from the
 post-optimized module.

PiperOrigin-RevId: 742594387
---
 third_party/xla/xla/tools/BUILD               |  1 +
 .../tools/extract_collective_operations.cc    | 52 +++++++++++++++----
 third_party/xla/xla/tools/hlo_decomposer.cc   | 50 ++++++++++++------
 third_party/xla/xla/tools/hlo_decomposer.h    | 13 ++++-
 4 files changed, 90 insertions(+), 26 deletions(-)

diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD
index f8df429f683538..4a472ea029f6de 100644
--- a/third_party/xla/xla/tools/BUILD
+++ b/third_party/xla/xla/tools/BUILD
@@ -881,6 +881,7 @@ xla_cc_binary(
         "//xla/hlo/ir:hlo",
         "//xla/service:hlo_proto_cc",
         "//xla/tsl/util:command_line_flags",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
diff --git a/third_party/xla/xla/tools/extract_collective_operations.cc b/third_party/xla/xla/tools/extract_collective_operations.cc
index 2f484059d3d3b1..f5c3ca7c1d2a21 100644
--- a/third_party/xla/xla/tools/extract_collective_operations.cc
+++ b/third_party/xla/xla/tools/extract_collective_operations.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <memory>
 #include <string>
+#include "absl/container/flat_hash_set.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
@@ -25,6 +26,7 @@
 #include "absl/strings/str_cat.h"
 #include "xla/debug_options_flags.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/tools/hlo_decomposer.h"
 #include "xla/tools/hlo_module_loader.h"
@@ -37,27 +39,48 @@ limitations under the License.
 namespace {
 const char* const kUsage = R"(
-This tool extracts collective operations from HLO module and saves them together
+This tool extracts collective operations (all-reduce and all-gather) from an HLO module and saves them together
 to the separate module.

 Usage:
   bazel run extract_collective_operations -- --input=path/to/hlo_module
-  --output=path/to/hlo_module
+  --output=path/to/hlo_module --operations=all-reduce,all-gather
 )";
 }  // namespace

 namespace xla {
-absl::Status ExtractCollectiveOperations(const std::string& input,
-                                         const std::string& output) {
+
+absl::Status ExtractCollectiveOperations(
+    const std::string& input, const std::string& output,
+    const absl::flat_hash_set<HloOpcode>& operation_types) {
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<HloModule> test_module,
      LoadModuleFromFile(input, std::string(tsl::io::Extension(input)),
                         hlo_module_loader_details::Config(), nullptr));
+  absl::flat_hash_set<HloOpcode> done_ops;
+  absl::flat_hash_set<HloOpcode> non_optimized_ops;
+  if (operation_types.contains(HloOpcode::kAllReduce)) {
+    non_optimized_ops.insert(HloOpcode::kAllReduce);
+    done_ops.insert(HloOpcode::kAllReduceDone);
+  }
+  if (operation_types.contains(HloOpcode::kAllGather)) {
+    non_optimized_ops.insert(HloOpcode::kAllGather);
+    done_ops.insert(HloOpcode::kAllGatherDone);
+  }
+
   std::vector<HloInstruction*> collective_instructions;
   for (const auto& op : test_module->computations()) {
     for (const auto& instr : op->instructions()) {
-      if (absl::StartsWith(instr->name(), "all-")) {
+      if (operation_types.contains(HloOpcode::kAllReduce) &&
+          HloPredicateIsOp<HloOpcode::kAllReduce, HloOpcode::kAllReduceStart,
+                           HloOpcode::kAllReduceDone>(instr)) {
+        collective_instructions.push_back(instr);
+      }
+
+      if (operation_types.contains(HloOpcode::kAllGather) &&
+          HloPredicateIsOp<HloOpcode::kAllGather, HloOpcode::kAllGatherStart,
+                           HloOpcode::kAllGatherDone>(instr)) {
         collective_instructions.push_back(instr);
       }
     }
@@ -66,8 +89,8 @@ absl::Status ExtractCollectiveOperations(const std::string& input,
   if (collective_instructions.empty()) {
     return absl::InternalError("No collective instructions found.");
   }
-  auto collectives_module =
-      ExtractInstructionIntoNewModule(collective_instructions);
+  auto collectives_module = ExtractCollectiveOperationsIntoNewModule(
+      collective_instructions, done_ops, non_optimized_ops);

   QCHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(), output,
                                    collectives_module->ToString()))
@@ -79,9 +102,12 @@ absl::Status ExtractCollectiveOperations(const std::string& input,
 int main(int argc, char** argv) {
   std::string input;
   std::string output;
+  std::string operations;
   std::vector<tsl::Flag> flag_list = {
       tsl::Flag("input", &input, "input file"),
-      tsl::Flag("output", &output, "output file")};
+      tsl::Flag("output", &output, "output file"),
+      tsl::Flag("operations", &operations,
+                "Collective operations to extract. Possible values: "
+                "all-reduce, all-gather.")};
   xla::AppendDebugOptionsFlags(&flag_list);
   const std::string kUsageString =
       absl::StrCat(kUsage, "\n\n", tsl::Flags::Usage(argv[0], flag_list));
@@ -90,6 +116,14 @@ int main(int argc, char** argv) {
   if (!parse_ok) {
     LOG(QFATAL) << kUsageString;
   }
-  TF_CHECK_OK(xla::ExtractCollectiveOperations(input, output));
+
+  absl::flat_hash_set<xla::HloOpcode> operation_types;
+  if (absl::StrContains(operations, "all-reduce")) {
+    operation_types.insert(xla::HloOpcode::kAllReduce);
+  }
+  if (absl::StrContains(operations, "all-gather")) {
+    operation_types.insert(xla::HloOpcode::kAllGather);
+  }
+  TF_CHECK_OK(xla::ExtractCollectiveOperations(input, output, operation_types));
   return 0;
 }
diff --git a/third_party/xla/xla/tools/hlo_decomposer.cc b/third_party/xla/xla/tools/hlo_decomposer.cc
index e083dc798a5c1e..f8ed1102a869ba 100644
--- a/third_party/xla/xla/tools/hlo_decomposer.cc
+++ b/third_party/xla/xla/tools/hlo_decomposer.cc
@@ -116,8 +116,10 @@ absl::StatusOr<std::vector<std::unique_ptr<HloModule>>> DecomposeHloModule(
   return modules;
 }

-std::unique_ptr<HloModule> ExtractInstructionIntoNewModule(
-    const std::vector<HloInstruction*>& instructions) {
+std::unique_ptr<HloModule> ExtractCollectiveOperationsIntoNewModule(
+    const std::vector<HloInstruction*>& instructions,
+    const absl::flat_hash_set<HloOpcode>& done_ops,
+    const absl::flat_hash_set<HloOpcode>& non_optimized_ops) {
   CHECK(!instructions.empty());
   HloInstruction& first_instruction = *instructions[0];
   auto new_hlo_module = std::make_unique<HloModule>(
@@ -128,24 +130,42 @@ std::unique_ptr<HloModule> ExtractInstructionIntoNewModule(
   int parameter_number = 0;
   HloComputation::Builder builder("entry_computation");
   HloCloneContext clone_context(new_hlo_module.get());
-  std::vector<HloInstruction*> new_instructions;
+  std::vector<HloInstruction*> result_instructions;
+  absl::flat_hash_map<std::string, HloInstruction*> start_op_map;
   for (auto* hlo : instructions) {
+    if (done_ops.contains(hlo->opcode())) {
+      std::vector<HloInstruction*> new_operands;
+      for (const HloInstruction* operand : hlo->operands()) {
+        if (start_op_map.contains(operand->name())) {
+          new_operands.push_back(start_op_map[operand->name()]);
+        }
+      }
+      result_instructions.push_back(
+          builder.AddInstruction(hlo->CloneWithNewOperands(
+              hlo->shape(), new_operands, &clone_context)));
+    } else {
+      std::vector<HloInstruction*> new_operands;
+      for (const HloInstruction* operand : hlo->operands()) {
+        std::unique_ptr<HloInstruction> new_parameter =
+            HloInstruction::CreateParameter(parameter_number, operand->shape(),
+                                            operand->name());
+        ++parameter_number;
+        new_operands.push_back(
+            builder.AddInstruction(std::move(new_parameter)));
+      }
+      std::unique_ptr<HloInstruction> new_instruction =
+          hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context);
+      HloInstruction* new_instr_ptr =
+          builder.AddInstruction(std::move(new_instruction));
+      if (non_optimized_ops.contains(hlo->opcode())) {
+        result_instructions.push_back(new_instr_ptr);
+      }
+      start_op_map[hlo->name()] = new_instr_ptr;
+    }
-    std::vector<HloInstruction*> new_operands;
-    for (const HloInstruction* operand : hlo->operands()) {
-      std::unique_ptr<HloInstruction> new_parameter =
-          HloInstruction::CreateParameter(parameter_number, operand->shape(),
-                                          operand->name());
-      ++parameter_number;
-      new_operands.push_back(builder.AddInstruction(std::move(new_parameter)));
-    }
-    std::unique_ptr<HloInstruction> new_instruction =
-        hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context);
-    new_instructions.push_back(
-        builder.AddInstruction(std::move(new_instruction)));
   }
   std::unique_ptr<HloInstruction> tuple_instruction =
-      HloInstruction::CreateTuple(new_instructions);
+      HloInstruction::CreateTuple(result_instructions);
   builder.AddInstruction(std::move(tuple_instruction));
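   // Editorial sketch, not part of the original change: given a post-optimized
   // module containing an async pair such as
   //
   //   ars = f32[8] all-reduce-start(p0), to_apply=add
   //   ard = f32[8] all-reduce-done(ars)
   //
   // the loop above gives the cloned start op fresh parameters and records it
   // in start_op_map; the cloned done op then takes that map entry as its
   // operand instead of a new parameter, so start/done pairs stay connected in
   // the extracted module, and only clones of done_ops or non_optimized_ops
   // feed the root tuple.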
   new_hlo_module->AddEntryComputationWithLayouts(builder.Build());
   return new_hlo_module;
diff --git a/third_party/xla/xla/tools/hlo_decomposer.h b/third_party/xla/xla/tools/hlo_decomposer.h
index d12b4d82216d1e..7fe6629bdab2c9 100644
--- a/third_party/xla/xla/tools/hlo_decomposer.h
+++ b/third_party/xla/xla/tools/hlo_decomposer.h
@@ -19,10 +19,12 @@ limitations under the License.
 #include <memory>
 #include <vector>
+#include "absl/container/flat_hash_set.h"
 #include "absl/status/statusor.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"

 namespace xla {

@@ -42,8 +44,15 @@ std::unique_ptr<HloModule> ExtractInstructionIntoNewModule(
 // with parameter instructions even if the result of one instruction is used
 // as a parameter to another. Combines results of all operations into the
 // tuple and adds this tuple as a root instruction of the new module.
-std::unique_ptr<HloModule> ExtractInstructionIntoNewModule(
-    const std::vector<HloInstruction*>& instructions);
+// Parameters:
+// instructions: HLO instructions to be extracted.
+// done_ops: Set of HLO opcodes that are done operations (e.g. AllReduceDone).
+// non_optimized_ops: Set of HLO opcodes that are not optimized (e.g.
+// AllReduce).
+std::unique_ptr<HloModule> ExtractCollectiveOperationsIntoNewModule(
+    const std::vector<HloInstruction*>& instructions,
+    const absl::flat_hash_set<HloOpcode>& done_ops,
+    const absl::flat_hash_set<HloOpcode>& non_optimized_ops);

 // Extracts producer and consumer HLO instruction into a new HLO module
 // replacing its operands with parameter instructions.

From ddb35866f69e2f10f6207a75a98f75108cab83f7 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 1 Apr 2025 02:02:37 -0700
Subject: [PATCH 0076/1324] compat: Update forward compatibility horizon to
 2025-04-01

PiperOrigin-RevId: 742597851
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 4fae6576e28ebf..fa2918c0912f24 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -29,7 +29,7 @@
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 3, 31)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 1)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None

From 61c14f6b57d9d26d43c841f1c77a115975ea915c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 1 Apr 2025 02:02:47 -0700
Subject: [PATCH 0077/1324] Update GraphDef version to 2184.

PiperOrigin-RevId: 742597930
---
 tensorflow/core/public/version.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 14905975e0e21b..970cc071347b35 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -93,7 +93,7 @@ limitations under the License.
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 2183  // Updated: 2025/3/31
+#define TF_GRAPH_DEF_VERSION 2184  // Updated: 2025/4/1

 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //

From 8fec0544eeab74f2d0a575c471c4366f76eb112b Mon Sep 17 00:00:00 2001
From: "A.
Unique TensorFlower" Date: Tue, 1 Apr 2025 02:09:59 -0700 Subject: [PATCH 0078/1324] Automated Code Change PiperOrigin-RevId: 742600442 --- .../xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/xla/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc b/third_party/xla/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc index 798d4ceeae086d..7c45961263df1e 100644 --- a/third_party/xla/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc +++ b/third_party/xla/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #if GOOGLE_CUDA #include "xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h" From e43dce09028b34f146a86685fad6a5e3fe35ae82 Mon Sep 17 00:00:00 2001 From: Michael Platings Date: Tue, 1 Apr 2025 10:56:36 +0100 Subject: [PATCH 0079/1324] [TOSA] Fix legalizing CONV bias (#90118) The logic to handle bias already existed in varying degrees of completeness for the various CONV operations. This change unifies that logic in a single function and reuses it for operations from which it was missing. The TOSA 1.0 specification permits that bias may be of shape [1]. This change takes advantage of that fact to simplify the logic. The change in behaviour is reflected in the tests. Change-Id: I97371fae425ca8e45b18705c71bd106d08f203d6 Signed-off-by: Michael Platings --- .../tests/tfl-to-tosa-pipeline-filtered.mlir | 3 +- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 52 +++- .../mlir/tosa/transforms/legalize_tfl.cc | 284 ++++++++---------- 3 files changed, 164 insertions(+), 175 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir index a14fe7e43f4bdc..77919da76fd497 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir @@ -30,11 +30,10 @@ func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[28, 1, 1, 19]> : tensor<4xindex>} // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[14, 28]> : tensor<2xindex>} // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<28xf32>}> // CHECK: %[[VAR1:.*]] = tosa.transpose %arg1 {perms = array} // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR1]], %[[CONST1]] -// CHECK-DAG: %[[VAR4:.*]] = tosa.conv2d %[[VAR2]], %[[VAR3]], %[[VAR0]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK-DAG: %[[VAR4:.*]] = tosa.conv2d %[[VAR2]], %[[VAR3]], %[[CONST3]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK: %[[VAR5:.*]] = tosa.reshape %[[VAR4]], %[[CONST2]] func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 0]> : tensor<2xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir 
b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 2384330ea0b236..9bfc60e593f67f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -61,9 +61,8 @@ func.func @test_conv2d_slicing(%arg0: tensor<2x32x32x8xf32>, %arg1: tensor<16x3x // ----- // CHECK-LABEL: test_transpose_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} +// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR1]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> %cst_1 = "tfl.no_value"() {value = unit} : () -> none @@ -74,9 +73,8 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16 // ----- // CHECK-LABEL: test_transpose_conv2d_relu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} +// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR1]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} // CHECK: %[[VAR3:.*]] = tosa.clamp %[[VAR2]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> @@ -284,10 +282,9 @@ func.func @test_depthwise_conv2d_slicing(%arg0: tensor<1x32x32x8xf32>, %arg1: te // CHECK-LABEL: test_conv3d // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x2x7x7x2xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x2x4xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<4xf32>}> // CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} -// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_2]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_4]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv3d(%arg0: tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32>) -> tensor<2x2x7x7x4xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<2x2x7x7x2xf32>, tensor<2x3x3x2x4xf32>, none) -> tensor<2x2x7x7x4xf32> @@ -299,10 +296,9 @@ func.func @test_conv3d(%arg0: tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32 // CHECK-LABEL: test_conv3d_dynamic // CHECK-SAME: %[[VAL_0:.*]]: tensor // CHECK-SAME: %[[VAL_1:.*]]: tensor<3x1x1x8x16xf32> -// 
CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> // CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} -// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_2]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_4]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv3d_dynamic(%arg0: tensor, %arg1: tensor<3x1x1x8x16xf32>) -> tensor<*xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor, tensor<3x1x1x8x16xf32>, none) -> tensor<*xf32> @@ -345,13 +341,12 @@ func.func @test_conv3d_slicing(%arg0: tensor<1x32x32x32x8xf32>, %arg1: tensor<3x // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.0156862643> : tensor<1x1x1x1x1xf32>} // CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.11982894> : tensor<1x1x1x1x1xf32>} // CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<-4> : tensor<1x1x1x1x1xi32>} -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<34xf32>} // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[BIAS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_0]] // CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]], %[[SHIFT]] // CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array} -// CHECK: %[[VAL_12:.*]] = tosa.conv3d %[[VAL_10]], %[[VAL_11]], %[[VAL_6]], %[[ZP]], %[[ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_12:.*]] = tosa.conv3d %[[VAL_10]], %[[VAL_11]], %[[BIAS_ZP]], %[[BIAS_ZP]], %[[BIAS_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_4]], %[[SHIFT]] // CHECK: %[[VAL_14:.*]] = tosa.cast %[[VAL_13]] // CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_14]], %[[VAL_5]] @@ -367,6 +362,17 @@ func.func @test_conv3d_qi8(%arg0: tensor<1x4x8x21x17x!quant.uniform : tensor<16xi48>}> : () -> tensor<16xi48> +// CHECK: tosa.conv3d {{.+}}, %[[BIAS]], %{{.+}} {acc_type = i48, {{.+}}} : {{.+}} -> tensor<1x15x15x15x16xi48> +func.func @test_conv3d_qi16(%input: tensor<1x32x32x32x8x!quant.uniform>, %filter: tensor<3x3x3x8x16x!quant.uniform>) -> tensor<1x15x15x15x16x!quant.uniform> { + %bias = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<123> : tensor<16xi16>} : () -> tensor<16x!quant.uniform> + %0 = "tfl.conv_3d"(%input, %filter, %bias) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x32x32x32x8x!quant.uniform>, tensor<3x3x3x8x16x!quant.uniform>, tensor<16x!quant.uniform>) -> tensor<1x15x15x15x16x!quant.uniform> + func.return %0 : tensor<1x15x15x15x16x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_add // CHECK: %[[VAR0:.*]] = tosa.add %arg0, %arg1 
func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { @@ -1758,7 +1764,6 @@ func.func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_matmul -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<28xf32>}> // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[14, 1, 1, 19]> : tensor<4xindex>} // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[28, 1, 1, 19]> : tensor<4xindex>} // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[14, 28]> : tensor<2xindex>} @@ -1766,7 +1771,7 @@ func.func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK: %[[VAR2:.*]] = tosa.transpose %arg1 {perms = array} // CHECK: %[[VAR3:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK: %[[VAR4:.*]] = tosa.reshape %[[VAR2]], %[[CONST1]] -// CHECK: %[[VAR5:.*]] = tosa.conv2d %[[VAR3]], %[[VAR4]], %[[VAR1]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAR5:.*]] = tosa.conv2d %[[VAR3]], %[[VAR4]], %[[CONST3]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[CONST2]] func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 0]> : tensor<2xi32> @@ -2860,6 +2865,18 @@ func.func @test_resize_nearest_align_qi8_scalar_input(%arg0: tensor<3x1x1x7x!qua } // ----- + +// CHECK-LABEL: test_fullyconnected_qi16 +// CHECK: %[[BIAS:.+]] = "tosa.const"() <{values = dense<123> : tensor<3xi48>}> : () -> tensor<3xi48> +// CHECK: tosa.conv2d {{.+}}, %[[BIAS]], %{{.+}} {acc_type = i48, {{.+}}} : {{.+}} -> tensor<1x1x1x3xi48> +func.func @test_fullyconnected_qi16(%input: tensor<1x7x!quant.uniform>, %filter: tensor<3x7x!quant.uniform>) -> tensor<1x3x!quant.uniform> { + %bias = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<123> : tensor<3xi32>} : () -> tensor<3x!quant.uniform> + %0 = "tfl.fully_connected"(%input, %filter, %bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x7x!quant.uniform>, tensor<3x7x!quant.uniform>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_gather // CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 13, 63]> : tensor<3xindex>} // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg0, %[[VAR10]] @@ -3113,6 +3130,15 @@ func.func @test_conv2d_infer(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x // ----- +// CHECK-LABEL: @test_conv2d_no_bias +func.func @test_conv2d_no_bias(%input: tensor<1x32x32x8x!quant.uniform>, %filter: tensor<3x3x8x16x!quant.uniform>) -> tensor<1x32x32x3x!quant.uniform> { + %bias = "tfl.no_value"() {value} : () -> none + %0 = "tfl.conv_2d"(%input, %filter, %bias) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform>, tensor<3x3x8x16x!quant.uniform>, none) -> tensor<1x32x32x3x!quant.uniform> + return %0 : tensor<1x32x32x3x!quant.uniform> +} + +// ----- + // CHECK-LABEL: @test_squeeze func.func @test_squeeze(%arg0: tensor<2x1x3x1xf32>) -> tensor<2x3x1xf32> { // CHECK: tosa.reshape diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc 
b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
index a11ed5e33b465e..bea9dc46b5fd75 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -1511,6 +1511,102 @@ Value lowerGroupedConvolution(TFL::Conv2DOp op, PatternRewriter& rewriter) {
                              convolutions, output_slice_dim);
 }

+/* Ensure bias is of the correct type.
+TOSA requires that the bias be of the same element type as the output, and the
+output element type is in turn constrained by the input type.
+*/
+static FailureOr<std::pair<Type, Value>> getTosaBias(
+    Operation* op, PatternRewriter& rewriter, ShapedType input_type,
+    ShapedType output_type, bool output_is_qtype, Value bias) {
+  Type bias_ety;
+
+  int bias_bits;
+  if (output_is_qtype) {
+    auto input_qtype = dyn_cast<mlir::quant::UniformQuantizedType>(
+        input_type.getElementType());
+    if (!input_qtype) {
+      return rewriter.notifyMatchFailure(op,
+                                         "output is qtype but input is not");
+    }
+    int input_bits = input_qtype.getStorageTypeIntegralWidth();
+    // For signed int8/int16 input tensor, int32/int48 bias and output
+    // tensor are generated.
+    bias_bits = input_bits == 16 ? 48 : 32;
+    bias_ety = rewriter.getIntegerType(bias_bits);
+  } else {
+    bias_ety = output_type.getElementType();
+    bias_bits = bias_ety.getIntOrFloatBitWidth();
+  }
+
+  if (!bias || !dyn_cast<RankedTensorType>(bias.getType())) {
+    // The bias may actually be typed "None" which has no value. TOSA 1.0
+    // permits a bias of shape [1] that broadcasts, so create a single-element
+    // constant of the appropriate type of zeros.
+    RankedTensorType bias_type = RankedTensorType::get({1}, bias_ety);
+    auto bias_attr = rewriter.getZeroAttr(bias_type);
+    bias = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), bias_type,
+                                           bias_attr.cast<ElementsAttr>());
+  }
+
+  auto prev_bias_type = dyn_cast<RankedTensorType>(bias.getType());
+  if (!prev_bias_type) {
+    return rewriter.notifyMatchFailure(op, "bias not a ranked tensor");
+  }
+
+  auto prev_bias_etype = prev_bias_type.getElementType();
+
+  int prev_bias_bits;
+  if (auto prev_bias_eqtype =
+          dyn_cast<mlir::quant::UniformQuantizedType>(prev_bias_etype)) {
+    prev_bias_bits = prev_bias_eqtype.getStorageTypeIntegralWidth();
+  } else {
+    prev_bias_bits = prev_bias_etype.getIntOrFloatBitWidth();
+  }
+
+  if (prev_bias_bits == bias_bits) {
+    return std::pair(bias_ety, bias);
+  }
+
+  auto const_op = bias.getDefiningOp<tosa::ConstOp>();
+  if (!const_op) {
+    return rewriter.notifyMatchFailure(op, "bias not a ConstOp");
+  }
+
+  DenseElementsAttr bias_attr;
+  {
+    auto prev_bias_attr =
+        dyn_cast<DenseIntElementsAttr>(const_op.getValuesAttr());
+    if (!prev_bias_attr) {
+      return rewriter.notifyMatchFailure(
+          op, "bias values not DenseIntElementsAttr");
+    }
+    // Promote to int32/int48 if necessary.
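+    // (Editorial note, not in the original patch: sign-extension is the
+    // assumed-correct widening here because quantized biases are stored as
+    // signed integers, and the TOSA accumulator type being promoted to, i32
+    // for i8 inputs and i48 for i16 inputs, is strictly wider, so the bias
+    // value is preserved exactly.)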
+    bias_attr = prev_bias_attr.mapValues(
+        bias_ety,
+        [bias_bits = bias_ety.getIntOrFloatBitWidth()](
+            const APInt& x) -> APInt { return x.sext(bias_bits); });
+  }
+
+  ShapedType bias_output_type;
+  if (auto bias_attr_type = dyn_cast<ShapedType>(bias_attr.getType())) {
+    bias_output_type = bias_attr_type.clone(bias_ety);
+  } else {
+    bias_output_type = dyn_cast<ShapedType>(const_op.getResult().getType());
+    if (!bias_output_type) {
+      return rewriter.notifyMatchFailure(
+          op, "bias defining op result not ShapedType");
+    }
+    bias_output_type = bias_output_type.clone(bias_ety);
+  }
+
+  auto new_const_op = rewriter.create<tosa::ConstOp>(op->getLoc(),
+                                                     bias_output_type, bias_attr);
+  Value new_bias = new_const_op.getResult();
+  rewriter.replaceOp(const_op, new_bias);
+
+  return std::make_pair(bias_ety, new_bias);
+}
+
 LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_conv2d_op = cast<TFL::Conv2DOp>(op);
@@ -1583,19 +1679,10 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
     return failure();
   }

-  Value unquantized_bias = tfl_conv2d_op.getBias();
-  Type bias_ety =
-      output_is_qtype ? rewriter.getI32Type() : output_type.getElementType();
-  if (unquantized_bias) {
-    Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType());
-    if (auto qtype = mlir::dyn_cast<mlir::quant::QuantizedType>(new_bias_ety)) {
-      new_bias_ety = qtype.getStorageType();
-    }
-    if (new_bias_ety.getIntOrFloatBitWidth() >
-        bias_ety.getIntOrFloatBitWidth()) {
-      bias_ety = new_bias_ety;
-    }
-  }
+  auto bias_result = getTosaBias(op, rewriter, input_type, output_type,
+                                 output_is_qtype, tfl_conv2d_op.getBias());
+  if (failed(bias_result)) return failure();
+  auto [bias_ety, bias_val] = bias_result.value();

   // TFLite only supports NHWC format
   Value conv2d_input = getInputSlicedToItsUsedSize(
@@ -1609,8 +1696,7 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite(

   auto a1_conv2d_op = CreateOpAndInfer<tosa::Conv2DOp>(
       rewriter, op->getLoc(), output_type.clone(bias_ety), conv2d_input,
-      tfl_conv2d_op.getFilter(), unquantized_bias, pad, stride, dilation,
-      acc_type);
+      tfl_conv2d_op.getFilter(), bias_val, pad, stride, dilation, acc_type);

   Value conv2d_output;
   if (input_is_qtype) {
@@ -1710,37 +1796,26 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite(
     }
   }

-  Value unquantized_bias = tfl_conv3d_op.getBias();
-  if (!dyn_cast<RankedTensorType>(unquantized_bias.getType())) {
-    // The bias may actually be typed "None" which has no value. TOSA requires
-    // bias to be an array of output_channel_count values, so create a constant
-    // of the appropriate number and type of zeros.
- auto bias_dim = filter_type.getShape().back(); - RankedTensorType bias_type = - RankedTensorType::get({bias_dim}, filter_type.getElementType()); - auto bias_attr = rewriter.getZeroAttr(bias_type); - unquantized_bias = CreateOpAndInfer( - rewriter, op->getLoc(), bias_type, bias_attr.cast()); - } - // TFLite only supports NDHWC format, tensorflow::FORMAT_NHWC is used for both // rank 4 and rank 5 tensors Value conv3d_input = getInputSlicedToItsUsedSize( rewriter, op, tensorflow::FORMAT_NHWC, input_type, tfl_conv3d_op.getInput(), kernel_size, pad, stride, dilation); - Type bias_ety = - unquantized_bias.getType().cast().getElementType(); + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv3d_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); auto acc_type = getConvAccTypeAttr(rewriter, /* input_etype = */ input_type.getElementType(), /* output_etype = */ bias_ety); - std::optional a1_conv3d_op = convertConv3DCommon( - rewriter, op, output_type.clone(bias_ety), conv3d_input, - tfl_conv3d_op.getFilter(), unquantized_bias, pad, stride, dilation, - acc_type, StringRef("NDHWC")); + std::optional a1_conv3d_op = + convertConv3DCommon(rewriter, op, output_type.clone(bias_ety), + conv3d_input, tfl_conv3d_op.getFilter(), bias_val, + pad, stride, dilation, acc_type, StringRef("NDHWC")); if (!a1_conv3d_op) return failure(); @@ -1789,23 +1864,6 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( bool output_is_qtype = mlir::isa(output_type.getElementType()); - const bool has_bias = - tfl_conv_op.getBias() && !isa(tfl_conv_op.getBias().getType()); - - if (has_bias) { - RankedTensorType bias_type = - dyn_cast(tfl_conv_op.getBias().getType()); - bool bias_is_qtype = - isa(bias_type.getElementType()); - - if (input_is_qtype != bias_is_qtype) { - return rewriter.notifyMatchFailure( - op, - "input/bias tensor should " - "be all quantized or all floating-point"); - } - } - if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { return rewriter.notifyMatchFailure( @@ -1835,49 +1893,10 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( return failure(); } - int output_channel = 0; - // TODO(suderman): We need to figure out how to guarantee output channel - // propagation. 
- if (output_type.hasRank()) { - output_channel = output_type.getDimSize(3); - } else if (filter_type.hasRank()) { - output_channel = filter_type.getDimSize(0); - } else { - return failure(); - } - - Value bias_val; - if (has_bias) { - bias_val = tfl_conv_op.getBias(); - } else { - std::optional zero_bias; - if (input_is_qtype) { - uint32_t input_bits = - cast(input_type.getElementType()) - .getStorageTypeIntegralWidth(); - uint32_t weight_bits = - cast(filter_type.getElementType()) - .getStorageTypeIntegralWidth(); - - if (input_bits == 16 && weight_bits == 8) { - // For signed 16x8, the output is accumulated into int48 - SmallVector vec(output_channel, APInt(48, 0, true)); - zero_bias = getConstTensor(rewriter, op, vec, {output_channel}); - } else { - SmallVector vec(output_channel, 0); - zero_bias = - getConstTensor(rewriter, op, vec, {output_channel}); - } - } else { - SmallVector vec(output_channel, 0.0f); - zero_bias = getConstTensor(rewriter, op, vec, {output_channel}); - } - - if (!zero_bias) return failure(); - bias_val = zero_bias.value(); - } - - Type bias_ety = cast(bias_val.getType()).getElementType(); + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); auto acc_type = getConvAccTypeAttr(rewriter, @@ -1886,8 +1905,8 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( auto a1_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), - tfl_conv_op.getInput(), tfl_conv_op.getWeights(), bias_val, - outpad, stride, acc_type); + tfl_conv_op.getInput(), tfl_conv_op.getWeights(), bias_val, outpad, + stride, acc_type); Value conv2d_output; if (input_is_qtype) { @@ -2020,20 +2039,10 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( filter_type.getElementType()), a1_filter_transpose_op.getResult(), a2_reshape_dims_value); - Type bias_ety = - output_is_qtype ? 
rewriter.getI32Type() : output_type.getElementType(); - - Value unquantized_bias = tfl_conv2d_op.getBias(); - if (unquantized_bias) { - Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType()); - if (auto qtype = new_bias_ety.dyn_cast()) { - new_bias_ety = qtype.getStorageType(); - } - if (new_bias_ety.getIntOrFloatBitWidth() > - bias_ety.getIntOrFloatBitWidth()) { - bias_ety = new_bias_ety; - } - } + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv2d_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); // TFLite only supports NHWC format Value conv2d_input = getInputSlicedToItsUsedSize( @@ -2047,7 +2056,7 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( auto a3_depthwise_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), conv2d_input, - a2_filter_reshape_op.getResult(), unquantized_bias, pad, stride, dilation, + a2_filter_reshape_op.getResult(), bias_val, pad, stride, dilation, acc_type); Value conv2d_output; @@ -2231,8 +2240,6 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( dyn_cast(tfl_fc_op.getInput().getType()); RankedTensorType filter_type = dyn_cast(tfl_fc_op.getFilter().getType()); - RankedTensorType bias_type = - dyn_cast(tfl_fc_op.getBias().getType()); if (!input_type || !filter_type) return failure(); bool input_is_qtype = @@ -2306,53 +2313,10 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( filter_val, new_filter_shape_value); filter_type = cast(filter_val.getType()); - Value bias_val; - if (!bias_type) { - // For some matmuls, the bias may actually be a "UnitType" which has no - // value. TOSA requires bias to be an array of output_channel_count values, - // so create a constant of the appropriate number and type of zeros. - SmallVector bias_shape({filter_type.getShape()[0]}); - RankedTensorType new_bias_type; - - DenseElementsAttr bias_attr; - if (mlir::isa(input_type.getElementType())) { - SmallVector bias_arr(bias_shape[0]); - - for (int i = 0; i < bias_shape[0]; i++) { - bias_arr[i] = 0.0; - } - new_bias_type = - RankedTensorType::get(bias_shape, input_type.getElementType()); - bias_attr = - DenseElementsAttr::get(new_bias_type, llvm::ArrayRef(bias_arr)); - } else { - SmallVector bias_arr(bias_shape[0]); - - for (int i = 0; i < bias_shape[0]; i++) { - bias_arr[i] = 0; - } - if (!input_is_qtype) { - return rewriter.notifyMatchFailure( - op, "input must be quantized type if it's not float type"); - } - auto input_qtype = - mlir::cast(input_type.getElementType()); - Type new_bias_ety = input_qtype.getStorageTypeIntegralWidth() == 16 - ? 
rewriter.getIntegerType(48) - : rewriter.getI32Type(); - new_bias_type = RankedTensorType::get(bias_shape, new_bias_ety); - bias_attr = - DenseElementsAttr::get(new_bias_type, llvm::ArrayRef(bias_arr)); - } - auto bias_op = CreateOpAndInfer(rewriter, op->getLoc(), - new_bias_type, bias_attr); - bias_val = bias_op.getResult(); - bias_type = new_bias_type; - } else { - bias_val = tfl_fc_op.getBias(); - } - - Type bias_ety = mlir::cast(bias_val.getType()).getElementType(); + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_fc_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); auto acc_type = getConvAccTypeAttr(rewriter, From d64691c0deba7f80cb02f307b9d3c60dad1813ee Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Tue, 1 Apr 2025 02:52:48 -0700 Subject: [PATCH 0080/1324] [XLA:GPU] Define HS optimization at O1. PiperOrigin-RevId: 742612839 --- third_party/xla/xla/service/gpu/BUILD | 2 +- third_party/xla/xla/service/gpu/flag_utils.h | 2 +- .../xla/xla/service/gpu/flag_utils_test.cc | 28 ++++++++++++++++++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2d01c910f845e1..44d8885928510d 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3173,6 +3173,7 @@ xla_cc_test( srcs = ["flag_utils_test.cc"], deps = [ ":flag_utils", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:collective_pipeliner", @@ -3180,7 +3181,6 @@ xla_cc_test( "//xla/service:latency_hiding_scheduler", "//xla/service/gpu/transforms:double_buffer_loop_unrolling", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/gpu/flag_utils.h b/third_party/xla/xla/service/gpu/flag_utils.h index c51c09fee530e7..b9d2ef0d5e2f1a 100644 --- a/third_party/xla/xla/service/gpu/flag_utils.h +++ b/third_party/xla/xla/service/gpu/flag_utils.h @@ -43,7 +43,7 @@ bool IsPassEnabledAtOptimizationEffort(const HloModule& module) { if (is_collective_optimization_pass) { ExecutionOptions::EffortLevel opt_level = module.config().optimization_level(); - return (opt_level == ExecutionOptions::EFFORT_O3) || + return (opt_level == ExecutionOptions::EFFORT_O1) || (opt_level == ExecutionOptions::EFFORT_UNKNOWN && exec_effort >= kExtraCollectiveOptimizations); } diff --git a/third_party/xla/xla/service/gpu/flag_utils_test.cc b/third_party/xla/xla/service/gpu/flag_utils_test.cc index fbad690eac3a10..29c21cdd2c42a7 100644 --- a/third_party/xla/xla/service/gpu/flag_utils_test.cc +++ b/third_party/xla/xla/service/gpu/flag_utils_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h" #include "xla/service/hlo_module_config.h" #include "xla/service/latency_hiding_scheduler.h" -#include "tsl/platform/test.h" +#include "xla/xla.pb.h" namespace xla { namespace gpu { @@ -54,6 +54,32 @@ TEST(FlagUtilsTest, IsPassEnabledAtOptimizationEffort) { IsPassEnabledAtOptimizationEffort(module)); } +TEST(FlagUtilsTest, IsPassEnabledAtOptimizationLevel) { + HloModuleConfig config; + config.set_optimization_level(ExecutionOptions::EFFORT_O1); + HloModule module("test_module", config); + + // Collective optimization passes. 
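+  // (Editorial note, not in the original change: flag_utils.h now keys these
+  // collective passes to ExecutionOptions::EFFORT_O1 instead of EFFORT_O3,
+  // which is what the checks below pin down; the EFFORT_O0 block further on
+  // verifies they stay disabled at the lowest effort level.)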
+  EXPECT_TRUE(IsPassEnabledAtOptimizationEffort<CollectivePipeliner>(module));
+  EXPECT_TRUE(
+      IsPassEnabledAtOptimizationEffort<DoubleBufferLoopUnrolling>(module));
+  EXPECT_TRUE(
+      IsPassEnabledAtOptimizationEffort<LatencyHidingScheduler>(module));
+
+  // Other passes.
+  EXPECT_TRUE(IsPassEnabledAtOptimizationEffort<HloDCE>(module));
+
+  config.set_optimization_level(ExecutionOptions::EFFORT_O0);
+  module.set_config(config);
+
+  // Collective optimization passes.
+  EXPECT_FALSE(IsPassEnabledAtOptimizationEffort<CollectivePipeliner>(module));
+  EXPECT_FALSE(
+      IsPassEnabledAtOptimizationEffort<DoubleBufferLoopUnrolling>(module));
+  EXPECT_FALSE(
+      IsPassEnabledAtOptimizationEffort<LatencyHidingScheduler>(module));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla

From e71d14f4939a1aa9f9ba11529ddd25198cfc4304 Mon Sep 17 00:00:00 2001
From: Henning Becker
Date: Tue, 1 Apr 2025 03:11:08 -0700
Subject: [PATCH 0081/1324] Remove usage of AsGpuStreamValue from HipBlasLt

AsGpuStreamValue is deprecated, so this change is inlining the call.

PiperOrigin-RevId: 742618006
---
 third_party/xla/xla/stream_executor/rocm/BUILD          | 2 +-
 third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD
index d90b09404bd2c2..97c0f2bcbe912a 100644
--- a/third_party/xla/xla/stream_executor/rocm/BUILD
+++ b/third_party/xla/xla/stream_executor/rocm/BUILD
@@ -690,11 +690,11 @@ cc_library(
         "//xla/stream_executor:stream",
         "//xla/stream_executor/gpu:gpu_blas_lt",
         "//xla/stream_executor/gpu:gpu_helpers_header",
-        "//xla/stream_executor/gpu:gpu_stream",
         "//xla/tsl/platform:env",
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:status",
         "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
index e0c5a9288c9dfb..c33a463307e853 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include
 #include
+#include "absl/base/casts.h"
 #include "rocm/rocm_config.h"
 #if TF_HIPBLASLT
@@ -45,7 +46,6 @@ limitations under the License.
 #include "xla/stream_executor/event_based_timer.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.h"
 #include "xla/stream_executor/gpu/gpu_helpers.h"
-#include "xla/stream_executor/gpu/gpu_stream.h"
 #include "xla/stream_executor/rocm/hip_blas_utils.h"
 #include "xla/stream_executor/rocm/hipblaslt_wrapper.h"
 #include "xla/stream_executor/rocm/rocm_blas.h"
@@ -446,7 +446,9 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
         blas_lt->blas_lt_.get(), op_desc_.get(), alpha, a.opaque(),
         a_desc_.get(), b.opaque(), b_desc_.get(), beta, args.c.opaque(),
         c_desc_.get(), args.d.opaque(), d_desc_.get(), palgo, workspace_addr,
-        workspace_size, gpu::AsGpuStreamValue(stream)));
+        workspace_size,
+        absl::bit_cast<hipStream_t>(
+            stream->platform_specific_handle().stream)));
   } else {
     return absl::InternalError("hipblaslt: Invalid algorithm type");
   }

From be3ecc2f30fd619aae614a04bafe5f99ed042434 Mon Sep 17 00:00:00 2001
From: Henning Becker
Date: Tue, 1 Apr 2025 03:14:12 -0700
Subject: [PATCH 0082/1324] Remove AsGpuStreamValue usage from the CUDA
 blas_plugin

`AsGpuStreamValue` is deprecated and needs to be inlined.
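[Editor's note: a minimal free-standing sketch of the inlining pattern this
commit and the previous one apply, assuming only that a stream exposes its
driver handle as a type-erased void*; MyStreamHandle and ToDriverHandle are
illustrative names, not part of the patch.]

#include "absl/base/casts.h"

// Hypothetical opaque driver handle standing in for cudaStream_t/hipStream_t.
struct MyStream;
using MyStreamHandle = MyStream*;

MyStreamHandle ToDriverHandle(void* platform_specific_stream) {
  // absl::bit_cast reinterprets the type-erased pointer as the driver's
  // handle type, which is all the deprecated AsGpuStreamValue helper did.
  return absl::bit_cast<MyStreamHandle>(platform_specific_stream);
}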
PiperOrigin-RevId: 742618884 --- third_party/xla/xla/stream_executor/cuda/BUILD | 7 ++++++- .../xla/xla/stream_executor/cuda/cuda_blas.cc | 18 ++++++++++++------ .../xla/stream_executor/cuda/cuda_blas_lt.cc | 11 +++++------ .../xla/stream_executor/cuda/cuda_blas_lt.h | 2 -- 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 7253ec4b295d6c..1344ae75dfd015 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -346,6 +346,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":cuda_blas_utils", + ":cuda_compute_capability", ":cuda_executor", ":cuda_helpers", ":cuda_platform_id", @@ -367,13 +368,17 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/gpu:gpu_helpers_header", - "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/platform:initialize", "//xla/tsl/cuda:cublas", "//xla/tsl/cuda:cublas_lt", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:dnn_proto_cc", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc index ba5b805247bdf0..75ebc07d2bda01 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc @@ -22,7 +22,11 @@ limitations under the License. #include #include +#include "absl/base/casts.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -40,22 +44,21 @@ limitations under the License. #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas_utils.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/cuda/cuda_helpers.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" -#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_helpers.h" -#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/protobuf/dnn.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" namespace stream_executor { @@ -229,7 +232,10 @@ bool CUDABlas::SetStream(Stream *stream) { CHECK(blas_ != nullptr); std::unique_ptr activation = parent_->Activate(); - auto handle = (stream != nullptr) ? gpu::AsGpuStreamValue(stream) : nullptr; + auto handle = + (stream != nullptr) + ? 
absl::bit_cast<cudaStream_t>(stream->platform_specific_handle().stream)
+          : nullptr;
   if (auto ret = cublasSetStream(blas_, handle); ret != CUBLAS_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
     return false;
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
index 8e2aef4d7217c3..17e7d6dc9fca89 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include
 #include
+#include "absl/base/casts.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -40,20 +41,17 @@ limitations under the License.
 #include "xla/status_macros.h"
 #include "xla/stream_executor/activate_context.h"
 #include "xla/stream_executor/blas.h"
-#include "xla/stream_executor/cuda/cuda_blas.h"
 #include "xla/stream_executor/cuda/cuda_blas_utils.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/event_based_timer.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.h"
 #include "xla/stream_executor/gpu/gpu_helpers.h"
-#include "xla/stream_executor/gpu/gpu_stream.h"
 #include "xla/stream_executor/stream.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/types.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/ml_dtypes.h"
-#include "tsl/platform/statusor.h"

 #define SET_ATTR(setter, handle, attr, value) \
   ToStatus(setter(handle, attr, &value, sizeof(decltype(value))), #setter)
@@ -462,7 +460,8 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
         blas_lt->blas_lt_.get(), op_desc_.get(), alpha, a.opaque(),
         a_desc_.get(), b.opaque(), b_desc_.get(), beta, args.c.opaque(),
         c_desc_.get(), args.d.opaque(), d_desc_.get(), palgo, workspace_addr,
-        workspace_size, gpu::AsGpuStreamValue(stream)));
+        workspace_size,
+        absl::bit_cast<cudaStream_t>(stream->platform_specific_handle().stream)));
   } else {
     return absl::InternalError("cublaslt: Invalid algorithm type");
   }
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
index f1b9ee30ca1ca5..8909c6ac7ed0e1 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
@@ -18,7 +18,6 @@ limitations under the License.
 #include
 #include
-#include
 #include
 #include
 #include
@@ -31,7 +30,6 @@ limitations under the License.
 #include "third_party/gpus/cuda/include/cublas_v2.h"
 #include "third_party/gpus/cuda/include/library_types.h"
 #include "xla/stream_executor/blas.h"
-#include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.h"
 #include "xla/stream_executor/scratch_allocator.h"
 #include "xla/stream_executor/stream_executor.h"

From 5e4dfa334faca89b4db2ecd2af1ff1254b1ee72c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 1 Apr 2025 03:15:25 -0700
Subject: [PATCH 0083/1324] Enable `sin` and `cos` delegation for `float` and
 `float16`.

The XNNPACK microkernels are 10-20x faster than the TFLite builtins.
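[Editor's note: for readers unfamiliar with the XNNPACK subgraph API used in
the diff below, a minimal sketch of defining one of these unary nodes; it
assumes `subgraph` came from xnn_create_subgraph and that the two value ids
were returned by earlier xnn_define_tensor_value calls, with error handling
left to the caller.]

#include <cstdint>
#include "xnnpack.h"

// Attaches an elementwise sine node to an existing XNNPACK subgraph.
// xnn_unary_sine takes no extra parameters, hence the nullptr params.
xnn_status DefineSine(xnn_subgraph_t subgraph, uint32_t input_id,
                      uint32_t output_id) {
  return xnn_define_unary(subgraph, xnn_unary_sine, /*params=*/nullptr,
                          input_id, output_id, /*flags=*/0);
}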
PiperOrigin-RevId: 742619195 --- .../lite/cmake/DownloadPThreadPool.cmake | 4 +- .../xnnpack/unary_elementwise_test.cc | 18 ++-- .../delegates/xnnpack/xnnpack_delegate.cc | 88 +++++++++++++++++++ .../lite/tools/cmake/modules/cpuinfo.cmake | 2 +- .../lite/tools/cmake/modules/xnnpack.cmake | 2 +- tensorflow/workspace2.bzl | 18 ++-- 6 files changed, 113 insertions(+), 19 deletions(-) diff --git a/tensorflow/lite/cmake/DownloadPThreadPool.cmake b/tensorflow/lite/cmake/DownloadPThreadPool.cmake index 08441656f177de..e12799e3231a31 100644 --- a/tensorflow/lite/cmake/DownloadPThreadPool.cmake +++ b/tensorflow/lite/cmake/DownloadPThreadPool.cmake @@ -19,8 +19,8 @@ PROJECT(pthreadpool-download NONE) INCLUDE(ExternalProject) ExternalProject_Add(pthreadpool - URL https://github.com/google/pthreadpool/archive/b1aee199d54003fb557076a201bcac3398af580b.zip - URL_HASH SHA256=215724985c4845cdcadcb5f26a2a8777943927bb5a172a00e7716fe16a6f3c1b + URL https://github.com/google/pthreadpool/archive/b92447772365661680f486e39a91dfe6675adafc.zip + URL_HASH SHA256=745e56516d6a58d183eb33d9017732d87cff43ce9f78908906f9faa52633e421 SOURCE_DIR "${CMAKE_BINARY_DIR}/pthreadpool-source" BINARY_DIR "${CMAKE_BINARY_DIR}/pthreadpool" CONFIGURE_COMMAND "" diff --git a/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc b/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc index 97cea8e5294ae1..4986ad9707b302 100644 --- a/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc +++ b/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc @@ -34,6 +34,9 @@ ToleranceInfo GetTolerance(BuiltinOperator op) { return ToleranceInfo{.relative = 1.0e+4f}; case BuiltinOperator_GELU: return ToleranceInfo{.relative = 5.0f, .absolute = 10.0f}; + case BuiltinOperator_COS: + case BuiltinOperator_SIN: + return ToleranceInfo{.relative = 5.0f, .absolute = 3.0f}; default: return ToleranceInfo{}; } @@ -139,12 +142,15 @@ TEST_P(UnaryTest, MultiThreading) { } BuiltinOperator all_unary_ops[] = { - BuiltinOperator_ABS, BuiltinOperator_CEIL, BuiltinOperator_ELU, - BuiltinOperator_FLOOR, BuiltinOperator_GELU, BuiltinOperator_NEG, - BuiltinOperator_HARD_SWISH, BuiltinOperator_RELU, BuiltinOperator_RELU6, - BuiltinOperator_RELU_N1_TO_1, BuiltinOperator_ROUND, BuiltinOperator_RSQRT, - BuiltinOperator_SQRT, BuiltinOperator_SQUARE, BuiltinOperator_TANH, - BuiltinOperator_LOGISTIC, + BuiltinOperator_ABS, BuiltinOperator_CEIL, + BuiltinOperator_COS, BuiltinOperator_ELU, + BuiltinOperator_FLOOR, BuiltinOperator_GELU, + BuiltinOperator_NEG, BuiltinOperator_HARD_SWISH, + BuiltinOperator_RELU, BuiltinOperator_RELU6, + BuiltinOperator_RELU_N1_TO_1, BuiltinOperator_ROUND, + BuiltinOperator_RSQRT, BuiltinOperator_SIN, + BuiltinOperator_SQRT, BuiltinOperator_SQUARE, + BuiltinOperator_TANH, BuiltinOperator_LOGISTIC, }; INSTANTIATE_TEST_SUITE_P( diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index ed9bcf7a47168b..7cb451e956256c 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -1901,6 +1901,22 @@ class Subgraph { node_index); } + static TfLiteStatus CheckTensorFloatType(TfLiteContext* context, + const TfLiteTensor& tensor, + int tensor_index, int node_index) { + switch (tensor.type) { + case kTfLiteFloat32: + case kTfLiteFloat16: + return kTfLiteOk; + default: + TF_LITE_MAYBE_KERNEL_LOG( + context, "%s: unsupported type %s in tensor #%d in node #%d", + __FUNCTION__, TfLiteTypeGetName(tensor.type), 
tensor_index, + node_index); + return kTfLiteError; + } + } + static TfLiteStatus CheckTensorFloat32OrQInt8Type(const Delegate& delegate, TfLiteContext* context, const TfLiteTensor& tensor, @@ -2744,6 +2760,9 @@ class Subgraph { node, context->tensors, conv_params, quasi_static_tensors, input_output_tensors); } + case kTfLiteBuiltinCos: + return VisitCosNode(subgraph, delegate, logging_context, node_index, + node, context->tensors, input_output_tensors); case kTfLiteBuiltinDepthwiseConv2d: { const TfLiteDepthwiseConvParams* dwconv_params = static_cast(node->builtin_data); @@ -2912,6 +2931,9 @@ class Subgraph { case kTfLiteBuiltinRsqrt: return VisitRsqrtNode(subgraph, delegate, logging_context, node_index, node, context->tensors, input_output_tensors); + case kTfLiteBuiltinSin: + return VisitSinNode(subgraph, delegate, logging_context, node_index, + node, context->tensors, input_output_tensors); case kTfLiteBuiltinSlice: return VisitSliceNode(subgraph, delegate, logging_context, node_index, node, context->tensors, input_output_tensors); @@ -3794,6 +3816,39 @@ class Subgraph { return kTfLiteOk; } + static TfLiteStatus VisitCosNode( + xnn_subgraph_t subgraph, const Delegate& delegate, + TfLiteContext* logging_context, int node_index, TfLiteNode* node, + const TfLiteTensor* tensors, + const std::unordered_map& input_output_tensors) { + TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( + logging_context, node, 1, 1, BuiltinOperator_COS, node_index)); + + const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; + TF_LITE_ENSURE_STATUS(CheckTensorFloatType( + logging_context, input_tensor, node->inputs->data[0], node_index)); + + const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; + TF_LITE_ENSURE_STATUS(CheckTensorFloatType( + logging_context, output_tensor, node->outputs->data[0], node_index)); + + if (subgraph != nullptr) { + const xnn_status status = xnn_define_unary( + subgraph, xnn_unary_cosine, /*params=*/nullptr, + /*input_id=*/input_output_tensors.at(node->inputs->data[0]), + /*output_id=*/input_output_tensors.at(node->outputs->data[0]), + /*flags=*/0); + if (status != xnn_status_success) { + TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", + EnumNameBuiltinOperator(BuiltinOperator_COS), + node_index); + return kTfLiteError; + } + } + + return kTfLiteOk; + } + static TfLiteStatus VisitDepthwiseConv2DNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -5609,6 +5664,39 @@ class Subgraph { return kTfLiteOk; } + static TfLiteStatus VisitSinNode( + xnn_subgraph_t subgraph, const Delegate& delegate, + TfLiteContext* logging_context, int node_index, TfLiteNode* node, + const TfLiteTensor* tensors, + const std::unordered_map& input_output_tensors) { + TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( + logging_context, node, 1, 1, BuiltinOperator_SIN, node_index)); + + const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; + TF_LITE_ENSURE_STATUS(CheckTensorFloatType( + logging_context, input_tensor, node->inputs->data[0], node_index)); + + const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; + TF_LITE_ENSURE_STATUS(CheckTensorFloatType( + logging_context, output_tensor, node->outputs->data[0], node_index)); + + if (subgraph != nullptr) { + const xnn_status status = xnn_define_unary( + subgraph, xnn_unary_sine, /*params=*/nullptr, + /*input_id=*/input_output_tensors.at(node->inputs->data[0]), + 
/*output_id=*/input_output_tensors.at(node->outputs->data[0]), + /*flags=*/0); + if (status != xnn_status_success) { + TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", + EnumNameBuiltinOperator(BuiltinOperator_SIN), + node_index); + return kTfLiteError; + } + } + + return kTfLiteOk; + } + static TfLiteStatus VisitSliceNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, diff --git a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake index e388c5f69e258e..852f272e6bba73 100644 --- a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake +++ b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( cpuinfo GIT_REPOSITORY https://github.com/pytorch/cpuinfo # Sync with tensorflow/workspace2.bzl - GIT_TAG 8a1772a0c5c447df2d18edf33ec4603a8c9c04a6 + GIT_TAG b73ae6ce38d5dd0b7fe46dbe0a4b5f4bab91c7ea GIT_PROGRESS TRUE SOURCE_DIR "${CMAKE_BINARY_DIR}/cpuinfo" ) diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index d4ab4b509305c6..18b9e115775c37 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG 5b4978cae19292232a27bdf0f495819bf5297167 + GIT_TAG e67c0fbc360903f921ff286a235c18d9e12c6df6 GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index e8570a3754381b..79d7ac7291db83 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -160,9 +160,9 @@ def _tf_repositories(): # LINT.IfChange(xnnpack) tf_http_archive( name = "XNNPACK", - sha256 = "04291b4c49693988f8c95d07968f6f3da3fd89d85bd9e4e26f73abbdfd7a8a45", - strip_prefix = "XNNPACK-24794834234a7926d2f553d34e84204c8ac99dfd", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/24794834234a7926d2f553d34e84204c8ac99dfd.zip"), + sha256 = "72e4368ff3e7bdefd8b43fc6e5708b8e9fada7a8302ba2362028832df6262c13", + strip_prefix = "XNNPACK-e67c0fbc360903f921ff286a235c18d9e12c6df6", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/e67c0fbc360903f921ff286a235c18d9e12c6df6.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) @@ -184,18 +184,18 @@ def _tf_repositories(): # LINT.IfChange(pthreadpool) tf_http_archive( name = "pthreadpool", - sha256 = "215724985c4845cdcadcb5f26a2a8777943927bb5a172a00e7716fe16a6f3c1b", - strip_prefix = "pthreadpool-b1aee199d54003fb557076a201bcac3398af580b", - urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/b1aee199d54003fb557076a201bcac3398af580b.zip"), + sha256 = "745e56516d6a58d183eb33d9017732d87cff43ce9f78908906f9faa52633e421", + strip_prefix = "pthreadpool-b92447772365661680f486e39a91dfe6675adafc", + urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/b92447772365661680f486e39a91dfe6675adafc.zip"), ) # LINT.ThenChange(//tensorflow/lite/cmake/DownloadPThreadPool.cmake) tf_http_archive( name = "cpuinfo", - sha256 = "4bf314b3f04db2fd984fef38a7e278e702b74297ef0af592b73296edba02b9d4", - strip_prefix = "cpuinfo-8a1772a0c5c447df2d18edf33ec4603a8c9c04a6", + sha256 = "593ac799e8c9382362e7b29a58917053299fa906e271185204bb571465bb2f79", + strip_prefix = 
"cpuinfo-b73ae6ce38d5dd0b7fe46dbe0a4b5f4bab91c7ea", patch_file = ["//third_party/cpuinfo:cpuinfo_ppc64le_support.patch"], - urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/8a1772a0c5c447df2d18edf33ec4603a8c9c04a6.zip"), + urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/b73ae6ce38d5dd0b7fe46dbe0a4b5f4bab91c7ea.zip"), ) tf_http_archive( From d5bf4fdd48ce88bbede71f73fafb3c31f4a19835 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 1 Apr 2025 03:50:45 -0700 Subject: [PATCH 0084/1324] [XLA:GPU][NFC] Extract the logic to deduce the type of a `dot`'s accumulator type into `dot_algorithms.h`. Take the opportunity to split up the logic deriving the required operands type, the required accumulator type, and the algorithm emitter and hoist it out of `EmitSingleTileDot`. The purpose of this change is to make this logic available to the generic emitter as well. PiperOrigin-RevId: 742628334 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 3 +- .../gpu/codegen/triton/dot_algorithms.cc | 214 ++++++++++-------- .../gpu/codegen/triton/dot_algorithms.h | 6 + .../gpu/codegen/triton/emitter_helpers.h | 12 + .../triton/fusion_emitter_legacy_matmul.cc | 52 +---- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/algorithm_util.cc | 45 ++++ third_party/xla/xla/service/algorithm_util.h | 17 ++ 8 files changed, 198 insertions(+), 152 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 6c46dd86bcc042..c639c7e3e318dc 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -101,13 +101,13 @@ cc_library( "//xla:xla_proto_cc", "//xla/codegen:emitter_loc_op_builder", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/mlir_hlo:transformation_helpers", "//xla/service/gpu:target_util", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -119,7 +119,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@triton//:TritonDialects", ], diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc index 03654935803618..6ffdbd24aae936 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc @@ -21,10 +21,12 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -42,6 +44,7 @@ limitations under the License. 
#include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" #include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" #include "tsl/platform/tensor_float_32_utils.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -277,11 +280,14 @@ ttir::InputPrecision InferDotPrecision(const HloDotInstruction& dot) { return use_tf32 ? ttir::InputPrecision::TF32 : ttir::InputPrecision::IEEE; } -Type GetAlgUnsetAccumulatorType(EmitterLocOpBuilder& b, - const DotOperands& dot_operands) { - Type lhs_type = ElementType(dot_operands.lhs); - Type rhs_type = ElementType(dot_operands.rhs); - Type accumulator_type = ElementType(dot_operands.accumulator); +absl::StatusOr GetAlgUnsetAccumulatorType(EmitterLocOpBuilder& b, + const HloDotInstruction& dot) { + TF_ASSIGN_OR_RETURN(Type lhs_type, + TritonType(b, dot.operand(0)->shape().element_type())); + TF_ASSIGN_OR_RETURN(Type rhs_type, + TritonType(b, dot.operand(1)->shape().element_type())); + TF_ASSIGN_OR_RETURN(Type accumulator_type, + TritonType(b, dot.shape().element_type())); // The code below assumes that lhs and rhs have the same type. However // this may not always be the case with f8 matmuls, e.g. e4m3×e5m2 is @@ -313,12 +319,6 @@ absl::StatusOr EmitDotAlgUnset(EmitterLocOpBuilder& b, Value rhs = dot_operands.rhs; Value acc = dot_operands.accumulator; - Type expected_acc_type = GetAlgUnsetAccumulatorType(b, dot_operands); - if (ElementType(acc) != expected_acc_type) { - return absl::FailedPreconditionError( - "Given accumulator type for unset dot does not match expected type."); - } - int max_num_imprecise_acc = 0; if (ElementType(lhs).isFloat(8) || ElementType(rhs).isFloat(8)) { // For fp8 dots, disable accumulator promotion to mimick cuBLAS. It may make @@ -365,114 +365,130 @@ absl::StatusOr EmitRegularDot(EmitterLocOpBuilder& b, /*maxNumImpreciseAcc=*/max_num_imprecise_acc); } -} // namespace - -absl::StatusOr EmitSingleTileDot(EmitterLocOpBuilder& b, - const HloDotInstruction& dot, - DotOperands dot_operands) { - AlgorithmEmitter algorithm_emitter = nullptr; - PrecisionSpec precision_spec{dot.precision_config().algorithm(), - dot.precision_config().operand_precision(0), - dot.precision_config().operand_precision(1), - InferDotPrecision(dot)}; - - // Algorithms mostly expect that their input and output types correspond to - // what the algorithm describes. This is not always the case though, e.g. - // for BF16_BF16_F32_X9, working from inputs casted to BF16 makes no sense; - // this algorithm instead expects F32 inputs, and performs splits into BF16 - // sub-values under the hood. - std::optional force_operands_type; - std::optional force_accumulator_type; - - PrecisionConfig::Algorithm algorithm = precision_spec.algorithm; - - Type bf16 = b.getBF16Type(); - Type f16 = b.getF16Type(); - Type f32 = b.getF32Type(); - Type f64 = b.getF64Type(); - +// Returns an emitter for the given dot algorithm. Raises an +// `UnimplementedError` if the algorithm is not supported. 
+absl::StatusOr GetAlgorithmEmitter( + const PrecisionConfig::Algorithm algorithm) { switch (algorithm) { case PrecisionConfig::ALG_UNSET: - algorithm_emitter = EmitDotAlgUnset; - break; + return EmitDotAlgUnset; case PrecisionConfig::ALG_DOT_F16_F16_F16: - force_operands_type = f16; - force_accumulator_type = f16; - algorithm_emitter = EmitRegularDot; - break; case PrecisionConfig::ALG_DOT_F32_F32_F32: - force_operands_type = f32; - force_accumulator_type = f32; - algorithm_emitter = EmitRegularDot; - break; case PrecisionConfig::ALG_DOT_F64_F64_F64: - force_operands_type = f64; - force_accumulator_type = f64; - algorithm_emitter = EmitRegularDot; - break; case PrecisionConfig::ALG_DOT_F16_F16_F32: - force_operands_type = f16; - force_accumulator_type = f32; - algorithm_emitter = EmitRegularDot; - break; case PrecisionConfig::ALG_DOT_BF16_BF16_BF16: - force_operands_type = bf16; - force_accumulator_type = bf16; - algorithm_emitter = EmitRegularDot; - break; case PrecisionConfig::ALG_DOT_BF16_BF16_F32: - force_operands_type = bf16; - force_accumulator_type = f32; - algorithm_emitter = EmitRegularDot; - break; + return EmitRegularDot; case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: - force_operands_type = f32; // This is not a typo. - force_accumulator_type = f32; - algorithm_emitter = EmitBF16x3Matmul; - break; + return EmitBF16x3Matmul; case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: - force_operands_type = f32; // This is not a typo. - force_accumulator_type = f32; - algorithm_emitter = EmitBF16x6Matmul; - break; + return EmitBF16x6Matmul; case PrecisionConfig::ALG_DOT_TF32_TF32_F32: - // TODO(bchetioui): pass around tf32 matmul config. - force_operands_type = f32; - force_accumulator_type = f32; // TODO(bchetioui): this should be factored out of EmitRegularDot. - algorithm_emitter = EmitRegularDot; - break; + return EmitRegularDot; case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: - // TODO(bchetioui): pass around tf32 matmul config. - force_operands_type = f32; - force_accumulator_type = f32; // TODO(bchetioui): this should be factored out of EmitRegularDot. - algorithm_emitter = EmitRegularDot; - break; + return EmitRegularDot; case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9: - force_operands_type = f32; // This is not a typo. - force_accumulator_type = f32; - algorithm_emitter = EmitBF16x9Matmul; - break; + return EmitBF16x9Matmul; case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: - // TODO(bchetioui): How to enforce "any f8"? - force_accumulator_type = f32; - break; case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: - // TODO(bchetioui): How to enforce "any f8"? - force_accumulator_type = f32; - break; default: break; } // Couldn't find an algorithm emitter for this algorithm. Raise an error. - if (algorithm_emitter == nullptr) { - return absl::UnimplementedError( - absl::StrCat("This algorithm is not supported yet: ", - PrecisionConfig::Algorithm_Name(algorithm))); + return absl::UnimplementedError( + absl::StrCat("This algorithm is not supported yet: ", + PrecisionConfig::Algorithm_Name(algorithm))); +} + +// Returns the `Type` that the dot operands should be casted to if there is a +// clear candidate. Raises an error if there are multiple allowed choices but +// the operands do not already conform to any of them. Returns `std::nullopt` if +// no casting is a priori needed. 
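// For instance (reading from GetAllowedOperandsTypeForAlgorithm further down
// in this patch): ALG_DOT_BF16_BF16_F32 admits only BF16 operands, so both
// sides get force-cast to bf16, while ALG_DOT_ANY_F8_ANY_F8_F32 admits every
// f8 flavour, so nothing is forced and the operands are merely checked to
// already share one of those types.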
+absl::StatusOr> GetForceOperandsType( + EmitterLocOpBuilder& b, const HloDotInstruction& dot, + const DotOperands& dot_operands) { + PrecisionConfig::Algorithm algorithm = dot.precision_config().algorithm(); + if (algorithm == PrecisionConfig::ALG_UNSET) { + return std::nullopt; + } + + TF_ASSIGN_OR_RETURN( + std::vector allowed_operands_primitive_types, + algorithm_util::GetAllowedOperandsTypeForAlgorithm(algorithm)); + CHECK(!allowed_operands_primitive_types.empty()); + + std::vector allowed_operands_types; + allowed_operands_types.reserve(allowed_operands_primitive_types.size()); + for (PrimitiveType primitive_type : allowed_operands_primitive_types) { + TF_ASSIGN_OR_RETURN(Type type, TritonType(b, primitive_type)); + allowed_operands_types.push_back(type); + } + + Type lhs_type = ElementType(dot_operands.lhs); + Type rhs_type = ElementType(dot_operands.rhs); + if (allowed_operands_types.size() == 1) { + // If there is a single allowed operand type, we force the operands to use + // this type. + return allowed_operands_types.front(); + + } else { + // If there are several allowed operand types, we just check that the + // operands have the same type, and that this type is one of the allowed + // ones. Raise an error otherwise. + if (lhs_type != rhs_type || + !absl::c_linear_search(allowed_operands_types, lhs_type)) { + std::string allowed_operands_types_str = absl::StrJoin( + allowed_operands_types, ", ", [&](std::string* out, Type type) { + absl::StrAppend(out, MlirToString(type)); + }); + return absl::FailedPreconditionError(absl::StrCat( + "Expected dot operands to both have the same type, and for this type " + "to be one of the following types: ", + allowed_operands_types_str, " but got ", MlirToString(lhs_type), + " and ", MlirToString(rhs_type))); + } + } + + return std::nullopt; +} + +} // namespace + +// TODO(b/266862493): Add support for more types as needed. 
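// For instance: the F16_F16_F32, TF32_TF32_F32 and BF16_BF16_F32_X{3,6,9}
// algorithms all accumulate in f32 (the accumulator is the last component of
// the algorithm name), ALG_DOT_BF16_BF16_BF16 accumulates in bf16, and
// ALG_UNSET falls back to the fp8/int8/f64 special-casing of
// GetAlgUnsetAccumulatorType above.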
+absl::StatusOr GetDotAccumulatorType(EmitterLocOpBuilder& b, + const HloDotInstruction& dot) { + const PrecisionConfig::Algorithm algorithm = + dot.precision_config().algorithm(); + + if (algorithm == PrecisionConfig::ALG_UNSET) { + return GetAlgUnsetAccumulatorType(b, dot); } + TF_ASSIGN_OR_RETURN(PrimitiveType accumulator_type, + algorithm_util::GetDotAccumulatorType(algorithm)); + return TritonType(b, accumulator_type); +} + +absl::StatusOr EmitSingleTileDot(EmitterLocOpBuilder& b, + const HloDotInstruction& dot, + DotOperands dot_operands) { + PrecisionConfig::Algorithm algorithm = dot.precision_config().algorithm(); + PrecisionSpec precision_spec{ + algorithm, dot.precision_config().operand_precision(0), + dot.precision_config().operand_precision(1), InferDotPrecision(dot)}; + + TF_ASSIGN_OR_RETURN(AlgorithmEmitter algorithm_emitter, + GetAlgorithmEmitter(algorithm)); + + TF_ASSIGN_OR_RETURN(std::optional force_operands_type, + GetForceOperandsType(b, dot, dot_operands)); + + TF_ASSIGN_OR_RETURN(Type force_accumulator_type, + GetDotAccumulatorType(b, dot)); + if (force_operands_type.has_value()) { if (ElementType(dot_operands.lhs) != *force_operands_type) { dot_operands.lhs = Cast(b, dot_operands.lhs, *force_operands_type); @@ -483,11 +499,9 @@ absl::StatusOr EmitSingleTileDot(EmitterLocOpBuilder& b, } } - if (force_accumulator_type.has_value()) { - if (ElementType(dot_operands.accumulator) != *force_accumulator_type) { - dot_operands.accumulator = - Cast(b, dot_operands.accumulator, *force_accumulator_type); - } + if (ElementType(dot_operands.accumulator) != force_accumulator_type) { + dot_operands.accumulator = + Cast(b, dot_operands.accumulator, force_accumulator_type); } TF_ASSIGN_OR_RETURN(Value result, diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.h b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.h index 2ee2bd2c01c5b3..f04eb1d2d7ed8c 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_BACKENDS_GPU_CODEGEN_TRITON_DOT_ALGORITHMS_H_ #include "absl/status/statusor.h" +#include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -33,6 +34,11 @@ struct DotOperands { ::mlir::Value accumulator; }; +// Returns the type to use for accumulation for the given `dot` instruction. +// This also handles the case where the algorithm is `ALG_UNSET`. +absl::StatusOr<::mlir::Type> GetDotAccumulatorType( + EmitterLocOpBuilder& b, const HloDotInstruction& dot); + // Emits a single-tile dot, considering the given `dot` instruction's algorithm // and operand precisions. Raises an `UnimplementedError` if the algorithm is // not supported. diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h index 24bce99a8781f9..0ab584d5f7bb6e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h @@ -17,12 +17,14 @@ limitations under the License. 
#define XLA_BACKENDS_GPU_CODEGEN_TRITON_EMITTER_HELPERS_H_ #include +#include #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -37,10 +39,20 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/platform/status.h" #include "xla/xla.pb.h" namespace xla::gpu::triton { +// Returns a string representation of the given MLIR entity. +template +std::string MlirToString(T&& value) { + std::string result; + llvm::raw_string_ostream os(result); + value.print(os); + return result; +} + // This is a wrapper around mlir::Value that can hold either a scalar or a // non-0D tensor. An attempt to use this class with 0D tensors will CHECK-fail // because 0D tensors are not supported by Triton. diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc index 2e590bc77f8d06..904f49ee7acc9c 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc @@ -901,55 +901,6 @@ class MatMulEmitterHelper { dims_(dims), launch_config_(launch_config) {} - // TODO(b/266862493): Add support for more types as needed. - absl::StatusOr GetDotAccumulatorType(EmitterLocOpBuilder& b) { - const PrecisionConfig::Algorithm algorithm = - dot_instr_->precision_config().algorithm(); - - if (algorithm == PrecisionConfig::ALG_UNSET) { - TF_ASSIGN_OR_RETURN(Type dot_output_ty, - TritonType(b, dot_instr_->shape().element_type())); - // The code below assumes that lhs and rhs have the same type. However - // it's not always the case with fp8 matmuls, e.g. e4m3×e5m2 is supported - // at the hardware level. NVidia GPU currently only supports f32 - // accumulator for such matmuls. - if (IsFp8Matmul(dot_instr_)) { - return b.getF32Type(); - } - - // Data type of dot() immediate inputs. - TF_ASSIGN_OR_RETURN( - const Type lhs_ty, - TritonType(b, dot_instr_->operand(0)->shape().element_type())); - TF_ASSIGN_OR_RETURN( - const Type rhs_ty, - TritonType(b, dot_instr_->operand(1)->shape().element_type())); - TF_RET_CHECK(lhs_ty == rhs_ty); - Type dot_input_ty = lhs_ty; - - // Currently allowing 8x8-bit ints -> i32. - if (dot_input_ty == b.getIntegerType(8) && dot_output_ty.isInteger(32)) { - return b.getI32Type(); - } - return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b.getF64Type() - : b.getF32Type(); - } - - absl::StatusOr accum_type = - algorithm_util::GetDotAccumulatorType(algorithm); - CHECK(accum_type.ok()) << "Unexpected algorithm: " - << PrecisionConfig::Algorithm_Name(algorithm); - TF_ASSIGN_OR_RETURN(Type mlir_accum_type, - TritonType(b, accum_type.value())); - if (auto float_accum_type = - mlir::dyn_cast(mlir_accum_type)) { - return float_accum_type; - } - LOG(FATAL) << "Only floating point accumulator types are supported for " - "now, but we got: " - << llvm_ir::DumpToString(mlir_accum_type); - } - std::vector EpiloguePostOrderTransitiveOperands( const HloInstruction* root) { // Collect all instructions of the dot's output scope. 
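The member function deleted above was the legacy home of the ALG_UNSET accumulator rules that now live in `triton::GetDotAccumulatorType`; the hunk below switches the one call site over. A self-contained sketch of those rules, distilled from the deleted body (the enum and helper are illustrative stand-ins, not part of the patch, and assume lhs and rhs share an element type, as the original TF_RET_CHECK enforced):

enum class Prim { kS8, kS32, kF8, kF32, kF64, kOther };

// fp8 dots accumulate in f32; 8-bit integer dots with an i32 output
// accumulate in i32; f64 dots keep f64; everything else falls back to f32.
Prim UnsetAccumulatorSketch(Prim input, Prim output) {
  if (input == Prim::kF8) return Prim::kF32;
  if (input == Prim::kS8 && output == Prim::kS32) return Prim::kS32;
  if (input == Prim::kF64 && output == Prim::kF64) return Prim::kF64;
  return Prim::kF32;
}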
@@ -1987,7 +1938,8 @@ absl::StatusOr> EmitMatMul( MatMulEmitterHelper emitter(libdevice_path, device_info, dot_instr, index_ty, dims, launch_config, analysis); - TF_ASSIGN_OR_RETURN(mlir::Type acc_ty, emitter.GetDotAccumulatorType(b)); + TF_ASSIGN_OR_RETURN(mlir::Type acc_ty, + triton::GetDotAccumulatorType(b, *dot_instr)); ma::ConstantOp accumulator_init = CreateConst(b, acc_ty, 0, {block_m, block_n}); diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index fa73abdc3a8c95..d3a6d7fe4d8d74 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -6485,6 +6485,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/third_party/xla/xla/service/algorithm_util.cc b/third_party/xla/xla/service/algorithm_util.cc index 7ea65f61caa7be..dd7f97c5ca9197 100644 --- a/third_party/xla/xla/service/algorithm_util.cc +++ b/third_party/xla/xla/service/algorithm_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/algorithm_util.h" #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -24,6 +25,7 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/protobuf.h" namespace xla { namespace algorithm_util { @@ -65,6 +67,49 @@ absl::StatusOr GetBlasComputationType( } } +absl::StatusOr> GetAllowedOperandsTypeForAlgorithm( + PrecisionConfig::Algorithm algorithm) { + switch (algorithm) { + case PrecisionConfig::ALG_UNSET: + break; + case PrecisionConfig::ALG_DOT_F16_F16_F16: + case PrecisionConfig::ALG_DOT_F16_F16_F32: + return std::vector{F16}; + case PrecisionConfig::ALG_DOT_F32_F32_F32: + return std::vector{F32}; + case PrecisionConfig::ALG_DOT_F64_F64_F64: + return std::vector{F64}; + case PrecisionConfig::ALG_DOT_BF16_BF16_BF16: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + return std::vector{BF16}; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9: + return std::vector{F32}; // This is not a typo. + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: + return std::vector{F32}; + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: { + std::vector f8_types; + const tsl::protobuf::EnumDescriptor* desc = + tsl::protobuf::GetEnumDescriptor(); + for (int i = 0; i < desc->value_count(); ++i) { + PrimitiveType ty = static_cast(desc->value(i)->number()); + if (primitive_util::IsF8Type(ty)) { + f8_types.push_back(ty); + } + } + return f8_types; + } + default: + break; + } + return absl::InternalError( + absl::StrFormat("GetDotAccumulatorType: unsupported algorithm %s", + xla::PrecisionConfig::Algorithm_Name(algorithm))); +} + absl::StatusOr GetDotAccumulatorType( PrecisionConfig::Algorithm algorithm) { // All dot algorithms should be listed here. diff --git a/third_party/xla/xla/service/algorithm_util.h b/third_party/xla/xla/service/algorithm_util.h index 391b3fee5e937e..293fc1e47240b8 100644 --- a/third_party/xla/xla/service/algorithm_util.h +++ b/third_party/xla/xla/service/algorithm_util.h @@ -17,6 +17,7 @@ limitations under the License. 
#define XLA_SERVICE_ALGORITHM_UTIL_H_ #include +#include #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -37,6 +38,22 @@ namespace algorithm_util { absl::StatusOr GetBlasComputationType( PrecisionConfig::Algorithm algorithm); +// Returns the list of types that are allowed for the dot operands of the given +// algorithm. The expectation is always that both dot operands use the same +// type. +// +// Algorithms mostly expect that their input and output types correspond to +// what the algorithm describes. This is not always the case though, e.g. +// for BF16_BF16_F32_X9, working from inputs casted to BF16 makes no sense; +// this algorithm instead expects F32 inputs, and performs splits into BF16 +// sub-values under the hood. +// +// Another exception (and why we can't return a single type) are algorithms +// working on F8 types, where we sometimes allow any flavour of F8 type to be +// used. +absl::StatusOr> GetAllowedOperandsTypeForAlgorithm( + PrecisionConfig::Algorithm algorithm); + // Get the accumulator type of an algorithm. absl::StatusOr GetDotAccumulatorType( PrecisionConfig::Algorithm algorithm); From e196b070041877970de4fd4c746ec27dd7c22dcb Mon Sep 17 00:00:00 2001 From: Theotime Combes Date: Tue, 1 Apr 2025 04:01:17 -0700 Subject: [PATCH 0085/1324] [XLA:GPU] Add support for multiple output tiles in triton_support_test + removes dependency on `get-tuple-element` for Reduce, BatchNormGrad & BatchNormTraining tests PiperOrigin-RevId: 742631101 --- .../gpu/codegen/triton/support_test.cc | 67 +++++++++++++------ .../backends/gpu/codegen/triton/test_utils.h | 4 +- 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index 9d062876e10364..1cfc6d48440e77 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -230,14 +230,36 @@ class TritonSupportTest : public TritonSupportTestBase { std::vector output_tile_sizes, se::GpuComputeCapability cc, ExpectedFailMode failure_mode = ExpectedFailMode::kFail) { + // output_tile_sizes is embedded in a vector of 1 element to share the logic + // with the multiple output tiles case. + RunSupportTestMultipleOutputTiles( + std::move(ti), {std::move(output_tile_sizes)}, cc, failure_mode); + } + + void RunSupportTestMultipleOutputTiles( + TestedInstruction ti, std::vector> output_tile_sizes, + se::GpuComputeCapability cc, + ExpectedFailMode failure_mode = ExpectedFailMode::kFail) { // Ensure that the caller provided the right number of output tile sizes. // If that is not the case, codegen could fail for that reason---which - // wouldn't give any valuable signal here. We skip the check for non-array - // output shapes, since we have no meaningful way of providing tile sizes - // for them at the moment. + // wouldn't give any valuable signal here. The check is only done for array + // and tuple shapes (only one layer of nesting is supported for tuples). 
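    // For instance (echoing the batch-norm tests below): a
    // ($0[4,8,16,32], $0[32], $0[32]) tuple root takes output_tile_sizes of
    // {{1, 1, 4, 8}, {1}, {1}}: one inner vector per tuple element, each
    // holding one tile size per dimension of that element.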
if (ti.Instruction().shape().IsArray()) { + ASSERT_EQ(output_tile_sizes.size(), 1); + ASSERT_EQ(output_tile_sizes[0].size(), + ti.Instruction().shape().dimensions().size()); + } else if (ti.Instruction().shape().IsTuple()) { ASSERT_EQ(output_tile_sizes.size(), - ti.Instruction().shape().dimensions_size()); + ti.Instruction().shape().tuple_shapes_size()); + for (int64_t i = 0; i < output_tile_sizes.size(); ++i) { + const auto& shape = ti.Instruction().shape().tuple_shapes(i); + if (shape.IsTuple()) { + continue; // No validation for nested tuples, as there is no way to + // specify output tile sizes for them. + } + ASSERT_TRUE(shape.IsArray()); + ASSERT_EQ(shape.dimensions().size(), output_tile_sizes[i].size()); + } } BlockLevelParameters block_level_parameters = FromOutputTileSizes(std::move(output_tile_sizes)); @@ -726,16 +748,16 @@ add { ENTRY triton_computation { parameter_0 = $$0[125,127] parameter(0) constant_0 = $$0[] constant($0) - tuple = ($$0[125], $$0[125]) reduce( + ROOT reduce = ($$0[125], $$0[125]) reduce( parameter_0, parameter_0, constant_0, constant_0), dimensions={1}, to_apply=add - ROOT reduce = $$0[125] get-tuple-element(tuple), index=0 })", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); + RunSupportTestMultipleOutputTiles(std::move(ti), + /*output_tile_sizes=*/{{1}, {1}}, cc); } TEST_F(ReduceTest, ReduceWithNonConstReduceValueIsSupportedWithTriton) { @@ -1025,7 +1047,8 @@ ENTRY triton_computation { TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, HloOpcode::kAllGatherStart)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 2}, cc); + RunSupportTestMultipleOutputTiles(std::move(ti), + /*output_tile_sizes=*/{{2, 2}, {2, 2}}, cc); } TEST_P(CollectiveTest, UnsupportedAllGatherDoneFailsGracefullyWithTriton) { @@ -1142,7 +1165,8 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, HloOpcode::kCollectivePermuteDone)); - RunSupportTest(std::move(ti_start), /*output_tile_sizes=*/{2, 2}, cc); + RunSupportTestMultipleOutputTiles(std::move(ti_start), + /*output_tile_sizes=*/{{2, 2}, {2, 2}}, cc); RunSupportTest(std::move(ti_done), /*output_tile_sizes=*/{2, 2}, cc); } @@ -1197,8 +1221,10 @@ ENTRY triton_computation { TestedInstruction ti_done, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, HloOpcode::kAsyncDone)); - RunSupportTest(std::move(ti_start), /*output_tile_sizes=*/{1}, cc); - RunSupportTest(std::move(ti_update), /*output_tile_sizes=*/{1}, cc); + RunSupportTestMultipleOutputTiles(std::move(ti_start), + /*output_tile_sizes=*/{{1}, {1}}, cc); + RunSupportTestMultipleOutputTiles(std::move(ti_update), + /*output_tile_sizes=*/{{1}, {1}}, cc); RunSupportTest(std::move(ti_done), /*output_tile_sizes=*/{1}, cc); } @@ -1436,7 +1462,8 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc); + RunSupportTestMultipleOutputTiles(std::move(ti), + /*output_tile_sizes=*/{{1}, {16, 32}}, cc); } INSTANTIATE_TEST_SUITE_P( @@ -1608,8 +1635,6 @@ INSTANTIATE_TEST_SUITE_P( using BatchNormTrainingTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam; -// TODO: b/363981282 - Get rid of get-tuple-element by adding multiple output -// tikes support to RunSupportTest. 
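// With multi-output tile support in place, the tests below can make the
// batch-norm tuple itself the ROOT instead of extracting element 0 via
// get-tuple-element, which is what the now-deleted TODO above asked for.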
TEST_P(BatchNormTrainingTest, BatchNormTraining) { auto [data_type, opcode, cc] = GetParam(); const std::string kHloTestTemplate = R"( @@ -1617,14 +1642,14 @@ ENTRY triton_computation { operand = $0[4,8,16,32] parameter(0) scale = $0[32] parameter(1) offset = $0[32] parameter(2) - bn_train = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-training(operand, scale, offset), + ROOT bn_train = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-training(operand, scale, offset), epsilon=0.001, feature_index=3 - ROOT gte = $0[4,8,16,32] get-tuple-element(bn_train), index=0 })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1, 4, 8}, cc); + RunSupportTestMultipleOutputTiles( + std::move(ti), /*output_tile_sizes=*/{{1, 1, 4, 8}, {1}, {1}}, cc); } INSTANTIATE_TEST_SUITE_P( @@ -1634,8 +1659,6 @@ INSTANTIATE_TEST_SUITE_P( using BatchNormGradTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam; -// TODO: b/363981282 - Get rid of get-tuple-element by adding multiple output -// tikes support to RunSupportTest. TEST_P(BatchNormGradTest, BatchNormGrad) { auto [data_type, opcode, cc] = GetParam(); const std::string kHloTestTemplate = R"( @@ -1645,14 +1668,14 @@ ENTRY triton_computation { mean = $0[32] parameter(2) variance = $0[32] parameter(3) grad_output = $0[4,8,16,32] parameter(4) - bn_grad = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-grad(operand, scale, mean, variance, grad_output), + ROOT bn_grad = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-grad(operand, scale, mean, variance, grad_output), epsilon=0.001, feature_index=3 - ROOT gte = $0[4,8,16,32] get-tuple-element(bn_grad), index=0 })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1, 4, 8}, cc); + RunSupportTestMultipleOutputTiles( + std::move(ti), /*output_tile_sizes=*/{{1, 1, 4, 8}, {1}, {1}}, cc); } INSTANTIATE_TEST_SUITE_P( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h index db2989801a1d6f..bbd640511ba3dc 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h @@ -65,9 +65,9 @@ absl::Status CreateTritonIrAndFileCheckForDot( const HloComputation& computation, absl::string_view filecheck_pattern); inline BlockLevelParameters FromOutputTileSizes( - std::vector output_tile_sizes) { + std::vector> output_tile_sizes) { BlockLevelParameters block_level_parameters; - block_level_parameters.output_tile_sizes.push_back(output_tile_sizes); + block_level_parameters.output_tile_sizes = std::move(output_tile_sizes); return block_level_parameters; } From ce134e0cd3c1ec21d3089125db77baccc3159c1b Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Tue, 1 Apr 2025 12:44:14 +0100 Subject: [PATCH 0086/1324] [TOSA] Add int8 and int16 legalization of the EXP op (#61989) * Align getTosaConst8bitTable and getTosaConst16bitTable behaviour to the TFL LUTPopulate function * Add int8 and int16 legalization of the EXP op --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 28 +++++- .../mlir/tosa/transforms/legalize_common.cc | 9 +- .../mlir/tosa/transforms/legalize_tfl.cc | 77 +++++++++++++--- .../mlir/tosa/transforms/legalize_utils.cc | 87 +++++++++++++------ .../mlir/tosa/transforms/legalize_utils.h | 14 +-- 
.../tosa/transforms/tfl_legalize_patterns.td | 1 - 6 files changed, 162 insertions(+), 54 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 9bfc60e593f67f..080db037fc3215 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -431,9 +431,31 @@ func.func @test_mul_unranked(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x1x1xf32 // CHECK-LABEL: test_exp // CHECK: %[[VAR0:.*]] = tosa.exp %arg0 -func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { - %0 = "tfl.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> +func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tfl.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_exp_qi8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<256xi8>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_exp_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.exp"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_exp_qi16 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_exp_qi16(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.exp"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> } // ----- diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 100a429af4a96a..5881990eb0d955 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -1839,8 +1839,8 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, auto exp_func = [](double x) -> double { return std::exp(x); }; // Follow TFLite reference: tensorflow/lite/kernels/activations.cc - Value exp_table_const = - getTosaConst16bitTable(rewriter, op, exp_func, -10.0, 0); + Value exp_table_const = getTosaConst16bitTable( + rewriter, op, 10.0 / 65535.0, 32767, 2.0 / 65535.0, 0, exp_func); double input_diff_scale = in_quant_type.getScale() / (10.0 / 65535.0); @@ -1913,8 +1913,9 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, return 1.0 / (1.0 + x); }; - Value one_over_one_plus_x_table_const = getTosaConst16bitTable( - rewriter, op, one_over_one_plus_x_func, 0.0, 1.0); + Value one_over_one_plus_x_table_const = getTosaConst16bitTable( + rewriter, op, 1.0 / 65535.0, -32768, 2.0 / 65535.0, 0, + one_over_one_plus_x_func); // Get (1 / sum(exp(x))) result as 23 bits (including sign bit) auto op17_table_op16 = CreateOpAndInfer( diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index bea9dc46b5fd75..acb17d1e8e450e 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ 
b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -209,6 +209,7 @@ DECL_CONVERT_OP(LogicalAnd); DECL_CONVERT_OP(LogicalOr); DECL_CONVERT_OP(Pow); DECL_CONVERT_OP(BroadcastTo); +DECL_CONVERT_OP(Exp); #undef DECL_CONVERT_OP @@ -3600,7 +3601,8 @@ LogicalResult ConvertTFLAtan2Op::matchAndRewrite( // Note: the implementation of std::atan2 may be different on // different machines, so may result in varying numerical results. auto atan_func = [](double x) -> double { return std::atan(x); }; - Value table_const = getTosaConst16bitTable(rewriter, op, atan_func, 0.0, 1.0); + Value table_const = getTosaConst16bitTable( + rewriter, op, 1.0 / 65535.0, -32768, 2.0 / 65535.0, 0, atan_func); auto table_result = CreateOpAndInfer( rewriter, loc, output_ty.clone(rewriter.getIntegerType(32)), casted, table_const); @@ -3693,13 +3695,10 @@ LogicalResult ConvertTFLLogisticOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "input/output zeropoint should be 0 in 16-bit mode"); } - double input_min = -32768 * input_qtype.getScale(); - double input_max = 32767 * input_qtype.getScale(); - // Generate table with gen_lut() in - // tensorflow/lite/kernels/internal/common.h - Value table_const = getTosaConst16bitTable(rewriter, op, sigmoid_func, - input_min, input_max); + Value table_const = + getTosaConst16bitTable(rewriter, op, input_qtype.getScale(), + 0, 2.0 / 65535.0, 0, sigmoid_func); auto op1_table_in = CreateOpAndInfer(rewriter, op->getLoc(), int32_type, @@ -3765,13 +3764,9 @@ LogicalResult ConvertTFLTanhOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "input/output zeropoint should be 0 in 16-bit mode"); } - double input_min = -32768 * input_qtype.getScale(); - double input_max = 32767 * input_qtype.getScale(); - // Generate table with gen_lut() in - // tensorflow/lite/kernels/internal/common.h - Value table_const = - getTosaConst16bitTable(rewriter, op, tanh_func, input_min, input_max); + Value table_const = getTosaConst16bitTable( + rewriter, op, input_qtype.getScale(), 0, 2.0 / 65535.0, 0, tanh_func); auto op1_table_in = CreateOpAndInfer(rewriter, op->getLoc(), int32_type, @@ -4821,6 +4816,62 @@ LogicalResult ConvertTFLBroadcastToOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFLExpOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_exp_op = cast(op); + + RankedTensorType input_type = + dyn_cast(tfl_exp_op.getX().getType()); + RankedTensorType output_type = + dyn_cast(tfl_exp_op.getResult().getType()); + + if (!input_type || !output_type) { + return rewriter.notifyMatchFailure( + op, "input/output are not all a ranked tensor"); + } + + mlir::quant::UniformQuantizedType input_qtype = + dyn_cast_or_null( + input_type.getElementType()); + mlir::quant::UniformQuantizedType output_qtype = + dyn_cast_or_null( + output_type.getElementType()); + + if ((input_qtype == nullptr) != (output_qtype == nullptr)) { + return rewriter.notifyMatchFailure( + op, + "input/output tensor should be all quantized or all floating-point"); + } + + // Quantization case + if (input_qtype && output_qtype) { + auto exp_func = [](float x) -> float { return std::exp(x); }; + + Value table_const; + if (input_qtype.getStorageTypeIntegralWidth() == 8) { + table_const = getTosaConst8bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), exp_func); + } else if (input_qtype.getStorageTypeIntegralWidth() == 16) { + table_const = getTosaConst16bitTable( + rewriter, op, 
input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), exp_func); + } else { + return rewriter.notifyMatchFailure( + op, "only quantized int8 and int16 are supported"); + } + + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_exp_op.getX(), table_const); + return success(); + } + + CreateReplaceOpAndInfer(rewriter, op, tfl_exp_op.getType(), + tfl_exp_op.getX()); + + return success(); +} + LogicalResult LegalizeTFL::initialize(MLIRContext* context) { RewritePatternSet patterns(context); mlir::tosa::populateLegalizeTFLPatterns(context, patterns); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index d3be58ce5a7d51..7b688ea3adf8d2 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -637,24 +637,25 @@ Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, } // Create a 8-bit TOSA TABLE constant tensor with int8[256] array. -// Follow PopulateLookupTable() tensorflow/lite/kernels/activations.cc +// Follow LUTPopulateInt8() tensorflow/lite/kernels/internal/common.h Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, - double input_scale, int32_t input_zp, - double output_scale, int32_t output_zp, - std::function func) { + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp, + std::function func) { SmallVector table; + float inverse_scale = 1.0f / output_scale; for (int32_t i = -128; i < 128; i++) { - double dequantized = input_scale * (i - input_zp); - double transformed = func(dequantized); + float dequantized = input_scale * (i - input_zp); + float transformed = func(dequantized); - double max = (output_scale > 1.0) ? DBL_MAX : (DBL_MAX * output_scale); + float max = (output_scale > 1.0) ? FLT_MAX : (FLT_MAX * output_scale); if (transformed >= max) { table.push_back(INT8_MAX); continue; } - int32_t rescaled = std::llround(transformed / output_scale); + int32_t rescaled = std::round(transformed * inverse_scale); int32_t quantized = static_cast(rescaled + output_zp); table.push_back( static_cast(std::min(std::max(quantized, -128), 127))); @@ -673,34 +674,52 @@ Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, return const_op.getResult(); } -// Create a 16-bit TOSA TABLE constant tensor with int16[513] array. -// Output is restricted to [-1.0, 1.0]. -// Follow gen_lut() tensorflow/lite/kernels/internal/common.h +// Create a 16-bit TOSA TABLE constant tensor. +// A float should be used by default for FloatT except if a double is required +// for backward compatibility. 
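// The table holds 513 entries: 512 evenly spaced base samples plus the
// upper endpoint. Each base sample is biased by half the interpolation
// error at its interval midpoint, which assumes (as TFLite's 16-bit LUT
// lookup does) that the table is read back with linear interpolation.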
+// Follow LUTPopulateInt16() tensorflow/lite/kernels/internal/common.h +template Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, - std::function func, double min, - double max) { + FloatT input_scale, int32_t input_zp, + FloatT output_scale, int32_t output_zp, + std::function func) { + static_assert(std::is_floating_point::value, + "FloatT must be a floating-point type."); + SmallVector table; - double step = (max - min) / 512.0f; - double half_step = step / 2.0f; + FloatT input_min = + input_scale * (std::numeric_limits::min() - input_zp); + FloatT input_max = + input_scale * (std::numeric_limits::max() - input_zp); + FloatT output_min = + output_scale * (std::numeric_limits::min() - output_zp); + FloatT output_max = + output_scale * (std::numeric_limits::max() - output_zp); + + FloatT step = (input_max - input_min) / 512; + FloatT half_step = step / 2; + FloatT output_scaling_inv = 65536 / (output_max - output_min); + for (int32_t i = 0; i < 512; i++) { - int32_t sample_val = std::llround(func(min + (i * step)) * 32768.0); - double midpoint_interp_val = - std::round(((func(min + (i + 1) * step) * 32768.0) + - std::round(func(min + (i * step)) * 32768.0)) / - 2.0); - double midpoint_val = - std::round(func(min + (i * step) + half_step) * 32768.0); - double midpoint_err = midpoint_interp_val - midpoint_val; - int32_t bias = std::llround(midpoint_err / 2.0); + FloatT sample_val = + std::round(func(input_min + (i * step)) * output_scaling_inv); + FloatT midpoint_interp_val = std::round( + ((func(input_min + (i + 1) * step) * output_scaling_inv) + + std::round(func(input_min + (i * step)) * output_scaling_inv)) / + 2); + FloatT midpoint_val = std::round(func(input_min + (i * step) + half_step) * + output_scaling_inv); + FloatT midpoint_err = midpoint_interp_val - midpoint_val; + FloatT bias = std::round(midpoint_err / 2); table.push_back(static_cast( - std::min(std::max(sample_val - bias, -32768), 32767))); + std::min(std::max(sample_val - bias, -32768), 32767))); } - int32_t max_val = std::llround(func(max) * 32768.0); - table.push_back( - static_cast(std::min(std::max(max_val, -32768), 32767))); + FloatT max_val = std::round(func(input_max) * output_scaling_inv); + table.push_back(static_cast( + std::min(std::max(max_val, -32768), 32767))); auto const_type = tensorflow::GetTypeFromTFTensorShape({513}, rewriter.getIntegerType(16)); @@ -711,6 +730,18 @@ Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, return const_op.getResult(); } +template Value getTosaConst16bitTable(PatternRewriter& rewriter, + Operation* op, float input_scale, + int32_t input_zp, + float output_scale, + int32_t output_zp, + std::function func); + +template Value getTosaConst16bitTable( + PatternRewriter& rewriter, Operation* op, double input_scale, + int32_t input_zp, double output_scale, int32_t output_zp, + std::function func); + // Create a 32-bit TOSA TABLE for Softmax Exp void getTosaConst32bitSoftmaxExpTable(PatternRewriter& rewriter, Operation* op, double beta, double input_scale, diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index 8dc618d2bf2608..40d3e9f974e7a0 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -102,14 +102,18 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, // Create a 8-bit TOSA TABLE constant tensor Value getTosaConst8bitTable(PatternRewriter& rewriter, 
Operation* op, - double input_scale, int32_t input_zp, - double output_scale, int32_t output_zp, - std::function func); + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp, + std::function func); // Create a 16-bit TOSA TABLE constant tensor +// A float should be used by default for FloatT except if a double is required +// for backward compatibility +template Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, - std::function func, double min, - double max); + FloatT input_scale, int32_t input_zp, + FloatT output_scale, int32_t output_zp, + std::function func); // Create a 32-bit TOSA TABLE for Softmax Exp void getTosaConst32bitSoftmaxExpTable(PatternRewriter& rewriter, Operation* op, diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td index a7230ccf901399..1ed4d67a5f2bb7 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td +++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td @@ -29,7 +29,6 @@ include "mlir/Dialect/Tosa/IR/TosaOps.td" def ConvertTFLAbsOp : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>; def ConvertTFLCeilOp : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>; def ConvertTFLFloorOp : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>; -def ConvertTFLExpOp : Pat<(TFL_ExpOp $arg), (Tosa_ExpOp $arg)>; def ConvertTFLLogOp : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>; def ConvertTFLLogicalNotOp : Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>; From 0ef1f026976c533aac524836f0d120a597e74081 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 1 Apr 2025 04:59:59 -0700 Subject: [PATCH 0087/1324] Reland Adjust PriorityFusion to allow forming simple multi-output Triton fusions. 
Reverts a7ace96a10521859b80d9e3bbe041538b9fcc333 PiperOrigin-RevId: 742645835 --- .../xla/hlo/analysis/hlo_dfs_reachability.cc | 5 +- .../model/gpu_indexing_performance_model.cc | 9 +- .../gpu/model/symbolic_tile_analysis.h | 3 + .../xla/xla/service/gpu/transforms/BUILD | 1 + .../service/gpu/transforms/priority_fusion.cc | 218 +++++++++++++++--- .../service/gpu/transforms/priority_fusion.h | 3 +- .../gpu/transforms/priority_fusion_test.cc | 103 +++++++++ 7 files changed, 303 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.cc b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.cc index e26460630721f8..e9593d4ce57baf 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.cc @@ -115,9 +115,10 @@ void HloDfsReachability::OnInstructionReplaced(const HloInstruction* previous, const HloInstruction* now) { auto it = instruction_to_idx_.find(previous); CHECK(it != instruction_to_idx_.end()); - auto inserted = instruction_to_idx_.insert({now, it->second}).second; - CHECK(inserted); + auto idx = it->second; instruction_to_idx_.erase(it); + auto inserted = instruction_to_idx_.insert({now, idx}).second; + CHECK(inserted); } } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 8f22c3e347f61a..ce41097da44677 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -512,7 +512,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( } absl::Duration compute_time = - ComputeTime(*device_info_, flops, launch_dimensions.num_blocks(), + ComputeTime(*device_info_, flops, num_blocks, launch_dimensions.num_threads_per_block()); absl::Duration memory_access_time = read_time + write_time; @@ -543,15 +543,12 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( return absl::FailedPreconditionError(absl::StrCat( "SymbolicTileAnalysis failed. ", fusion_decision->Explain())); } - // TODO(b/390559452): Add support for more than one fusion root. - if (tile_sizes.size() != 1) { - return absl::UnimplementedError("Only 1 root is supported right now"); - } SymbolicTileAnalysis analysis = std::get(std::move(analysis_or_error)); TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, - analysis.ComputeTiledHloInstructions(tile_sizes[0])); + analysis.ComputeTiledHloInstructions( + tile_sizes[analysis.real_root_index()])); return EstimateRunTimeForTiledHloComputation( fusion_adaptor, tiled_hlo_computation, launch_dimensions); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index 17bfd837ad38c1..07243bb4904ef4 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -142,6 +142,9 @@ class SymbolicTileAnalysis { return root_indexing_.roots[idx]; } + // Returns the output index of the real root. + int64_t real_root_index() const { return root_indexing_.real_root_index; } + // Returns the number of tile parameters in this symbolic analysis. 
// TODO(b/390569102): This assumes that there is only one root that matters // for computing the tiling, and that it is the last symbolic tiled hlo diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index ba26a805777e15..68d8af6460438d 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2438,6 +2438,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/gpu/codegen/triton:support", + "//xla/hlo/analysis:hlo_dfs_reachability", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_traversal", diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index 1850ff499126da..11f3435fc59d1f 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -42,6 +42,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "xla/backends/gpu/codegen/triton/support.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/analysis/hlo_dfs_reachability.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -162,6 +163,7 @@ class PriorityFusionQueue { fusion_analysis_cache_(fusion_analysis_cache), fusion_deduplication_cache_(fusion_deduplication_cache), fusion_info_cache_(*device_info_), + reachability_(HloDfsReachability::Build(computation)), triton_heroless_fusion_enabled_(triton_heroless_fusion_enabled) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); @@ -253,6 +255,10 @@ class PriorityFusionQueue { reverse_map_.erase(current_producer_); current_consumers_ = current_producer_->users(); + auto preferred_consumer = GetPreferredConsumer(current_producer_); + if (preferred_consumer) { + current_consumers_ = {*preferred_consumer}; + } if (HloPredicateIsOp(current_producer_)) { // We don't check if bitcasts can be fused with all consumers, so we @@ -266,6 +272,15 @@ class PriorityFusionQueue { return !current_consumers_.empty(); } + std::optional GetPreferredConsumer( + HloInstruction* producer) { + auto it = preferred_consumer_.find(producer); + if (it == preferred_consumer_.end()) { + return std::nullopt; + } + return it->second; + } + absl::Status UpdatePerformanceModelCache(HloInstruction* producer) { if (!IsFusible(*producer)) { return absl::OkStatus(); @@ -390,9 +405,19 @@ class PriorityFusionQueue { } // Updates data for the new fusion instruction and its users and operands. + // Both `original_producer` and `original_consumer` could have been removed + // already from the computation, waiting for deletion. We can still + // dereference them though. 
   void OnFusingInstruction(HloInstruction* fusion,
                            HloInstruction* original_producer,
-                           HloInstruction* original_consumer) {
+                           HloInstruction* original_consumer,
+                           int64_t original_consumer_operand_index) {
+    bool creates_multi_output_fusion =
+        preferred_consumer_.contains(original_producer);
+    fusion_deduplication_cache_.UpdateFusedInstructionId(
+        fusion, original_producer, original_consumer,
+        original_consumer_operand_index, creates_multi_output_fusion);
+
     if (fusion_process_dump_) {
       auto* fusion_step =
           fusion_process_dump_->add_fusion_steps()->mutable_fusion();
@@ -411,13 +436,24 @@
                                            *fusion);
     }
 
-    // The original consumer was replaced with the fusion, but it's pointer can
-    // still be referenced somewhere, for example, in to_update_priority_.
-    // Priority recomputation is called before DCE. Remove all references to
-    // the original consumer here.
-    if (fusion != original_consumer) {
+    if (fusion == original_consumer) {
+      // We need to check again whether we can use `original_consumer` as a
+      // producer for a ProducerConsumer multi-output fusion.
+      preferred_consumer_.erase(original_consumer);
+    } else {
+      // The original consumer was replaced with the fusion, but its pointer
+      // can still be referenced somewhere, for example, in to_update_priority_.
+      // Priority recomputation is called before DCE. Remove all references to
+      // the original consumer here.
+      reachability_->OnInstructionReplaced(/*previous=*/original_consumer,
+                                           /*now=*/fusion);
       RemoveInstruction(original_consumer);
     }
+    if (creates_multi_output_fusion) {
+      // After a multi-output fusion was created, we need to rebuild the
+      // HloDfsReachability data structure.
+      reachability_ = HloDfsReachability::Build(computation_);
+    }
 
     // Collect the instructions whose priorities need to be updated.
     for (HloInstruction* operand : fusion->operands()) {
@@ -433,10 +469,21 @@
       }
 
       to_update_priority_.insert(operand);
-      // update the consumers of this operand that we care about,
-      // so we can do incremental update of the operand
+      // Update the consumers of this operand that we care about,
+      // so we can do incremental update of the operand.
       operands_to_new_consumers_[operand].push_back(fusion);
+
+      // We may need to reset `preferred_consumer_`, as we don't know yet
+      // whether that fusion would still be valid.
+      auto it = preferred_consumer_.find(operand);
+      if (it != preferred_consumer_.end() && it->second == original_consumer) {
+        preferred_consumer_.erase(it);
+      }
     }
+    // TODO(b/390559452): For multi-output fusion, we would also need to update
+    // the priorities of the other consumers of `producer` with which we did not
+    // fuse. For now, as we only allow multi-output fusion if there is just a
+    // single fusible consumer, this is not needed.
     to_update_priority_.insert(fusion);
   }
 
@@ -451,6 +498,7 @@
     }
     producer_priority_queue_.erase(reverse_it->second);
     reverse_map_.erase(reverse_it);
+    preferred_consumer_.erase(instruction);
   }
 
   // Returns a map from consumer to BlockLevelParameters. This is used to
@@ -485,9 +533,24 @@
       return -absl::InfiniteDuration();
     }
 
-    // Don't fuse if we can't fuse in all users.
     if (auto fusion_decision = CanFuseWithAllNonBitcastUsers(producer);
         !fusion_decision) {
+      // If we cannot fuse `producer` into all non-bitcast consumers, try
+      // Triton multi-output fusion next.
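      // Illustrative aside (hedged; derived from the code just below, not an
      // addition to the patch): the value returned is the modeled time saved
      // by fusing,
      //     priority = run_times.time_unfused - run_times.time_fused
      // so a multi-output candidate competes on the same scale as the
      // ordinary producer-consumer fusions already in the priority queue.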
+      std::vector<HloInstruction*> possible_consumers =
+          FindPossibleConsumersForTritonMultiOutputFusion(producer);
+      if (CanFuseTritonMultiOutputWithSingleUser(producer,
+                                                 possible_consumers)) {
+        GpuPerformanceModel::RunTimes run_times =
+            GpuPerformanceModel::EstimateRunTimes(
+                producer, *device_info_, &cost_analysis_,
+                GpuPerformanceModelOptions::Default(
+                    &fusion_analysis_cache_, &gpu_performance_model_cache_),
+                /*fused_consumers=*/possible_consumers);
+        preferred_consumer_[producer] = possible_consumers[0];
+        return run_times.time_unfused - run_times.time_fused;
+      }
+      // Don't fuse if we can't fuse in all users.
       if (fusion_process_dump_) {
         absl::MutexLock lock(&fusion_process_dump_mutex_);
         auto* step = fusion_process_dump_->add_fusion_steps()
@@ -564,10 +627,12 @@
   }
 
   TiledRunTimeDataOrError GetTiledRunTimeDataCached(
-      const HloInstruction* producer, const HloInstruction* consumer) {
+      const HloInstruction* producer, const HloInstruction* consumer,
+      bool use_multi_output_fusion = false) {
     FusionDeduplicationCache::FusionId fusion_id = [&]() {
       absl::MutexLock lock(&fusion_deduplication_cache_mutex_);
-      return fusion_deduplication_cache_.GetFusionId(producer, consumer);
+      return fusion_deduplication_cache_.GetFusionId(producer, consumer,
+                                                     use_multi_output_fusion);
     }();
 
     {
@@ -579,7 +644,8 @@
       }
     }
 
-    auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer);
+    auto fusion = HloFusionAdaptor::ForProducerConsumer(
+        producer, consumer, use_multi_output_fusion);
 
     absl::StatusOr<TiledRunTimeData> result_or_status =
         gpu_indexing_performance_model_.TryFindBestTilingForFusion(*fusion);
@@ -611,7 +677,8 @@
   }
 
   FusionDecision CanFuseTriton(HloInstruction* producer,
-                               HloInstruction* consumer) {
+                               HloInstruction* consumer,
+                               bool use_multi_output_fusion = false) {
     if (!IsGenericTritonFusion(*producer) &&
         !IsGenericTritonFusion(*consumer) &&
         !triton_heroless_fusion_enabled_) {
      return FusionDecision::Forbid("triton heroless fusion is not enabled");
    }
@@ -626,7 +693,7 @@
     }
 
     TiledRunTimeDataOrError tiled_run_time_data_or_error =
-        GetTiledRunTimeDataCached(producer, consumer);
+        GetTiledRunTimeDataCached(producer, consumer, use_multi_output_fusion);
 
     if (const auto* fusion_decision =
            std::get_if<FusionDecision>(&tiled_run_time_data_or_error)) {
@@ -638,9 +705,14 @@
 
     // This is our way to pass the runtime estimate to the CalculatePriorities()
     // function.
+    // This is somewhat brittle as we currently don't distinguish between
+    // ProducerConsumer fusion where we allow multi-output fusions to be formed,
+    // and ProducerConsumer fusion where we don't allow it. Same for the
+    // `block_level_parameters_cache_` down below. Currently we only try out
+    // multi-output fusion if we cannot fuse into all consumers, and it is tried
+    // last, so the final cached value should be what we want.
     gpu_performance_model_cache_.Set(
         *producer, *consumer, tiled_run_time_data.runtime_data.exec_time);
-
     {
       absl::MutexLock lock(&block_level_parameters_cache_mutex_);
       block_level_parameters_cache_[producer][consumer] =
@@ -780,6 +852,63 @@
     return fusion_decision;
   }
 
+  // Checks whether any operand of `consumer` is reachable from `producer`
+  // following user edges in the HLO graph. If that is the case, we would
+  // introduce a cycle by fusing `producer` into `consumer`.
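  // Illustrative aside (hedged; names are schematic, not from the patch):
  // the cycle this check guards against looks like
  //
  //   producer --> ... --> consumer_operand --> consumer
  //       \_____________(proposed fusion)__________/
  //
  // If `producer` already reaches another operand of `consumer`, merging the
  // two into a single fusion node would make that node both a predecessor
  // and a successor of the path through `consumer_operand`, i.e. a cycle.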
+ bool OperandReachableFromProducer(const HloInstruction* producer, + const HloInstruction* consumer) { + for (const auto* consumer_operand : consumer->operands()) { + CHECK(reachability_->IsPresent(consumer_operand) && + reachability_->IsPresent(producer)) + << "Reachability map is incomplete. This should never " + "happen."; + if (producer != consumer_operand && + reachability_->IsReachable(producer, consumer_operand)) { + return true; + } + } + return false; + } + + std::vector FindPossibleConsumersForTritonMultiOutputFusion( + HloInstruction* producer) { + bool triton_multi_output_fusion_enabled = + producer->GetModule() + ->config() + .debug_options() + .xla_gpu_unsupported_enable_triton_multi_output_fusion(); + if (!triton_multi_output_fusion_enabled) { + return {}; + } + std::vector possible_consumers; + for (const auto& user : producer->users()) { + if (HloPredicateIsOp(user)) { + continue; + } + if (CanFuseTriton(producer, user, /*use_multi_output_fusion=*/true) && + !OperandReachableFromProducer(producer, user)) { + possible_consumers.push_back(user); + } + } + return possible_consumers; + } + + FusionDecision CanFuseTritonMultiOutputWithSingleUser( + HloInstruction* producer, + const std::vector& possible_consumers) { + if (possible_consumers.empty()) { + return FusionDecision::Forbid("No users to fuse"); + } + + if (possible_consumers.size() != 1) { + // TODO(b/390559452): If there are several possible consumers to fuse + // with, decide which one is best. Also depends on what further fusions + // might be possible, needs checking the reachability graph. + return FusionDecision::Forbid("more than one consumer to fuse with"); + } + return FusionDecision::Allow(); + } + FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) { if (producer->users().empty()) { return FusionDecision::Forbid("No users to fuse"); @@ -825,6 +954,11 @@ class PriorityFusionQueue { // A reverse map that helps find an instruction in the priority queue. absl::flat_hash_map reverse_map_; + // Stores a mapping from the producer to the preferred consumer to fuse into. + // This is only used in case that we want to use ProducerConsumer multi-output + // fusion. + absl::flat_hash_map preferred_consumer_; + // The current producer being visited. HloInstruction* current_producer_; @@ -880,6 +1014,10 @@ class PriorityFusionQueue { // like shared memory usage or number of unnested reductions of fusion nodes. FusionInfoCache fusion_info_cache_; + // Allows evaluation of whether an HloInstruction is an ancestor of another + // HloInstruction. + std::unique_ptr reachability_; + // If true, redirect all fusion decisions to Triton fusion. bool triton_heroless_fusion_enabled_; @@ -984,7 +1122,12 @@ absl::StatusOr PriorityFusion::Run( block_level_parameters_map = fusion_queue->GetBlockLevelParametersMap(producer); - for (auto* consumer : fusion_queue->current_consumers()) { + auto preferred_consumer = fusion_queue->GetPreferredConsumer(producer); + std::vector consumers = + fusion_queue->current_consumers(); + bool use_multi_output_fusion = preferred_consumer.has_value(); + + for (auto* consumer : consumers) { // Don't fuse into single bitcasts. We ignore them in the check // CanFuseWithAllNonBitcastUsers(), so we need to check it here. 
if (HloPredicateIsOp(consumer)) { @@ -998,12 +1141,8 @@ absl::StatusOr PriorityFusion::Run( int64_t consumer_operand_index = consumer->operand_index(producer); fusion_queue->PreFusion(producer, consumer); - auto fusion_instruction = Fuse(producer, consumer); - fusion_deduplication_cache.UpdateFusedInstructionId( - fusion_instruction, producer, consumer, consumer_operand_index); - fusion_queue->OnFusingInstruction(fusion_instruction, producer, - consumer); - + auto fusion_instruction = + Fuse(producer, consumer, use_multi_output_fusion); auto backend_config_it = block_level_parameters_map.find(consumer); if (backend_config_it != block_level_parameters_map.end()) { TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config( @@ -1011,20 +1150,25 @@ absl::StatusOr PriorityFusion::Run( fusion_instruction->set_fusion_kind( HloInstruction::FusionKind::kCustom); } + fusion_queue->OnFusingInstruction(fusion_instruction, producer, + consumer, consumer_operand_index); changed = true; } fusion_queue->ComputeRuntimesOfRemovedConsumers(); - if (producer->user_count() == 0) { + if (use_multi_output_fusion || producer->user_count() == 0) { fusion_queue->InvalidateCaches(producer); - producer->DetachFromOperandsAndUsers(); fusion_queue->RemoveInstruction(producer); - // Remove from computation. - TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); + // When we use ProducerConsumer multi-output fusion, `producer` will + // have been removed already. + if (!use_multi_output_fusion) { + producer->DetachFromOperandsAndUsers(); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); + } } - for (auto* consumer : fusion_queue->current_consumers()) { + for (auto* consumer : consumers) { fusion_queue->InvalidateCaches(consumer); } TF_RETURN_IF_ERROR(fusion_queue->UpdatePriorities()); @@ -1094,7 +1238,8 @@ HloInstruction::FusionKind PriorityFusion::ChooseKind( } HloInstruction* PriorityFusion::Fuse(HloInstruction* producer, - HloInstruction* consumer) { + HloInstruction* consumer, + bool use_multi_output_fusion) { VLOG(2) << "Fusing " << producer->ToString() << " into " << consumer->ToString(); @@ -1115,9 +1260,22 @@ HloInstruction* PriorityFusion::Fuse(HloInstruction* producer, /*skip_async_execution_thread_overwrite=*/false); if (HloPredicateIsOp(producer)) { - fusion_instruction->MergeFusionInstruction(producer); + if (use_multi_output_fusion) { + fusion_instruction->MergeFusionInstructionIntoMultiOutput(producer); + } else { + fusion_instruction->MergeFusionInstruction(producer); + } } else { - fusion_instruction->FuseInstruction(producer); + if (use_multi_output_fusion) { + fusion_instruction->FuseInstructionIntoMultiOutput(producer); + // MergeFusionInstructionIntoMultiOutput already removes `producer` from + // the computation. Do the same here, so that we have the invariant that + // the producer has been cleaned up when multi-output fusion is used. 
+ CHECK_EQ(0, producer->user_count()); + TF_CHECK_OK(producer->parent()->RemoveInstruction(producer)); + } else { + fusion_instruction->FuseInstruction(producer); + } } if (fusion_instruction != consumer) { diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h index f1d19532a755f5..c9a29a5c056f35 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h @@ -60,7 +60,8 @@ class PriorityFusion : public HloModulePass { HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, const HloInstruction* consumer); - HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); + HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer, + bool use_multi_output_fusion = false); private: // Consumes a unit of compiler fuel and returns true if we should diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc index b534f46150b382..463b4d45c475b5 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc @@ -1026,6 +1026,109 @@ ENTRY main { 2); } +TEST_F(PriorityFusionTest, + FuseTritonProducerWithTwoConsumersUsingMultiOutputFusion) { + const std::string kHloText = R"( +HloModule t + +producer_computation { + parameter_0 = f32[125]{0} parameter(0) + ROOT broadcast = f32[125,127] broadcast(parameter_0), dimensions={0} +} + +consumer_computation { + parameter_0 = f32[125,127] parameter(0) + ROOT log = f32[125,127] log(parameter_0) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + producer_fusion = f32[125,127] fusion(param_0), kind=kCustom, calls=producer_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tiles":[{"sizes":["1","127"]}],"num_warps":"1"}}} + consumer_fusion = f32[125,127] fusion(producer_fusion), kind=kLoop, calls=consumer_computation + ROOT tuple = (f32[125,127], f32[125,127]) tuple(consumer_fusion, producer_fusion) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(false); + EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); + + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(true); + EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction *fusion1, *fusion2; + EXPECT_THAT(root, + GmockMatch(m::Tuple( + m::GetTupleElement(m::Fusion(&fusion1, m::Parameter()), 0), + m::GetTupleElement(m::Fusion(&fusion2, m::Parameter()), 1)))); + EXPECT_EQ(fusion1, fusion2); + EXPECT_TRUE(IsGenericTritonFusion(*fusion1)); + TF_ASSERT_OK_AND_ASSIGN(auto backend_config1, + fusion1->backend_config()); + EXPECT_TRUE( + backend_config1.fusion_backend_config().has_block_level_fusion_config()); + EXPECT_EQ(backend_config1.fusion_backend_config() + .block_level_fusion_config() + .output_tiles(0) + .sizes_size(), + 2); +} + +TEST_F(PriorityFusionTest, + FuseProducerWithTritonConsumerUsingMultiOutputFusion) { + const std::string kHloText = R"( +HloModule t + +consumer_computation { + parameter_0 
= f32[125,127] parameter(0) + ROOT log = f32[125,127] log(parameter_0) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + producer = f32[125,127] broadcast(param_0), dimensions={0} + consumer_fusion = f32[125,127] fusion(producer), kind=kCustom, calls=consumer_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tiles":[{"sizes":["1","127"]}],"num_warps":"1"}}} + ROOT tuple = (f32[125,127], f32[125,127]) tuple(consumer_fusion, producer) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(false); + EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); + + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_unsupported_enable_triton_multi_output_fusion(true); + EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction *fusion1, *fusion2; + EXPECT_THAT(root, + GmockMatch(m::Tuple( + m::GetTupleElement(m::Fusion(&fusion1, m::Parameter()), 0), + m::GetTupleElement(m::Fusion(&fusion2, m::Parameter()), 1)))); + EXPECT_EQ(fusion1, fusion2); + EXPECT_TRUE(IsGenericTritonFusion(*fusion1)); + TF_ASSERT_OK_AND_ASSIGN(auto backend_config1, + fusion1->backend_config()); + EXPECT_TRUE( + backend_config1.fusion_backend_config().has_block_level_fusion_config()); + EXPECT_EQ(backend_config1.fusion_backend_config() + .block_level_fusion_config() + .output_tiles(0) + .sizes_size(), + 2); +} + TEST_F(PriorityFusionTest, TritonProducerNotSupported_DoNotFuse) { const std::string kHloText = R"( HloModule t From 00a6963b5ea10d9c336a72d02e525279f46c4c95 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 1 Apr 2025 07:31:47 -0700 Subject: [PATCH 0088/1324] [XLA:GPU][NFC] Remove dead `algorithm_util::IsAmpere` util. PiperOrigin-RevId: 742690841 --- third_party/xla/xla/service/algorithm_util.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/third_party/xla/xla/service/algorithm_util.cc b/third_party/xla/xla/service/algorithm_util.cc index dd7f97c5ca9197..e019effee6de16 100644 --- a/third_party/xla/xla/service/algorithm_util.cc +++ b/third_party/xla/xla/service/algorithm_util.cc @@ -148,13 +148,6 @@ bool HasFastAccum(PrecisionConfig::Algorithm algorithm) { return algorithm == PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM; } -bool IsAmpere(stream_executor::GpuComputeCapability gpu_compute_capability) { - return std::holds_alternative( - gpu_compute_capability) && - std::get(gpu_compute_capability).major == - stream_executor::CudaComputeCapability::kAmpere; -} - // It's clear that those libraries could support more, but we only list the ones // which we explicitly test for now. bool IsSupportedByCublasOrCublasLt( From a47955a149817ddd23af924ea700b3bb3f25f811 Mon Sep 17 00:00:00 2001 From: Mikhail Goncharov Date: Tue, 1 Apr 2025 07:32:20 -0700 Subject: [PATCH 0089/1324] [XLA:GPU] add generic triton emitter support checks for fusions we now allow nested fusions with a specific backend config (at the moment they are operands of dots and concats but no special checks are performed for that). Note that '__triton' will not appear in practice right but we declare it supported anyway. 
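A hedged sketch of the contract this patch enforces for nested fusions; the
signature and error strings are simplified here, and the real check is
VerifyNestedFusion in fusion_emitter.cc below:

  absl::Status CheckNestedFusion(const HloInstruction& fusion,
                                 absl::string_view backend_kind) {
    if (backend_kind != "__triton_nested_gemm_fusion") {
      return absl::FailedPreconditionError("unexpected fusion backend kind");
    }
    if (fusion.user_count() != 1) {
      return absl::FailedPreconditionError("expected exactly one user");
    }
    const HloOpcode user_opcode = fusion.users().front()->opcode();
    if (user_opcode != HloOpcode::kDot &&
        user_opcode != HloOpcode::kConcatenate) {
      return absl::FailedPreconditionError("user must be dot or concatenate");
    }
    return absl::OkStatus();
  }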
PiperOrigin-RevId: 742691038 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 2 + .../gpu/codegen/triton/fusion_emitter.cc | 57 ++-- .../triton/fusion_emitter_device_test.cc | 62 ++++- .../backends/gpu/codegen/triton/support.cc | 51 ++-- .../gpu/codegen/triton/support_test.cc | 254 +++++++++--------- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/instruction_fusion.h | 7 + 7 files changed, 266 insertions(+), 168 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index c639c7e3e318dc..71accfea7ae094 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -859,6 +859,7 @@ cc_library( "//xla/service/gpu:variant_visitor", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", + "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -888,6 +889,7 @@ xla_cc_test( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 3640e3c7e5b283..4b9271dfc49809 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -1096,6 +1096,38 @@ absl::StatusOr EmitTiledHloInstruction( absl::StrCat("Unsupported operation ", hlo->ToString())); } +// Verifies that the nested fusion instruction conforms to the assumptions of +// the emitter. Currently, we expect nested fusions: +// - of kind `__triton_nested_gemm_fusion` +// - to have a single user that is either a `dot` or a `concatenate`. +absl::Status VerifyNestedFusion(const HloInstruction& hlo) { + // TODO(b/393299275): test cases when there are multiple dot users of the + // same fusion. + if (hlo.user_count() != 1) { + return absl::FailedPreconditionError( + absl::StrCat("Expected only one user for fusion ", hlo.ToString(), + " but got ", hlo.user_count())); + } + TF_ASSIGN_OR_RETURN(GpuBackendConfig backend_config, + hlo.backend_config()); + if (const std::string& kind = backend_config.fusion_backend_config().kind(); + kind != kTritonNestedGemmFusionKind) { + return absl::FailedPreconditionError(absl::StrCat( + "Expected ", hlo.ToString(), + " with fusion backend kind __triton_nested_gemm_fusion, got ", kind)); + } + const HloInstruction* user = hlo.users().front(); + switch (user->opcode()) { + case HloOpcode::kDot: + case HloOpcode::kConcatenate: + return absl::OkStatus(); + default: + return absl::FailedPreconditionError( + absl::StrCat("Unexpected user ", user->ToString(), + " of nested fusion ", hlo.ToString())); + } +} + // Emit a sequence of instructions using compatible tiling with producers // ordered before consumers in `tiled_computation`. Returns the results for the // roots of `tiled_computation`. 
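Every nested fusion in the tests below now carries an explicit backend
config. For reference, a sketch of the config shape; the field names come
from the diffs in this patch, the tile sizes and counts are purely
illustrative:

  constexpr const char* kNestedFusionBackendConfig = R"json({
    "fusion_backend_config": {
      "kind": "__triton_nested_gemm_fusion",
      "block_level_fusion_config": {
        "output_tiles": [{"sizes": ["32", "64"]}],
        "num_warps": "1",
        "num_ctas": "1",
        "num_stages": "1"
      }
    }
  })json";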
@@ -1110,31 +1142,10 @@ absl::StatusOr> EmitTiledComputation( for (const TiledHloInstruction* tiled_hlo : tiled_computation.instructions()) { const HloInstruction* hlo = tiled_hlo->hlo(); - // We skip generating code for nested fusions, since they are always - // generated by their consumer. + // Skip generating nested fusions, they are emitted by their consumer. if (hlo->parent()->IsFusionComputation() && hlo->opcode() == HloOpcode::kFusion) { - // Currently, we expect nested fusions to have a single user that is - // either a `dot`, or a `concatenate`. Later, we will also support - // reductions here. - // - // TODO(b/393299275): test cases when there are multiple dot users of the - // same fusion. - if (hlo->user_count() != 1) { - return absl::FailedPreconditionError( - absl::StrCat("Expected only one user for fusion ", hlo->ToString(), - " but got ", hlo->user_count())); - } - const HloInstruction* user = hlo->users().front(); - switch (user->opcode()) { - case HloOpcode::kDot: - case HloOpcode::kConcatenate: - break; - default: - return absl::FailedPreconditionError(absl::StrCat( - "Expected only a single dot or concatenate user for fusion ", - hlo->ToString(), " but got ", user->ToString())); - } + TF_RETURN_IF_ERROR(VerifyNestedFusion(*hlo)); VLOG(1) << "Skipping nested fusion: " << hlo->ToString(); continue; } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index f25d7ed31903a8..8b59cc52b088a1 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -2240,9 +2240,30 @@ concatenate_fusion { p1 = s32[128] parameter(1) p2 = s32[128] parameter(2) - fusion0 = s32[128] fusion(p0), kind=kCustom, calls=nest0 - fusion1 = s32[128] fusion(p1), kind=kCustom, calls=nest1 - fusion2 = s32[128] fusion(p2), kind=kCustom, calls=nest2 + fusion0 = s32[128] fusion(p0), kind=kCustom, calls=nest0, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} + fusion1 = s32[128] fusion(p1), kind=kCustom, calls=nest1, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} + fusion2 = s32[128] fusion(p2), kind=kCustom, calls=nest2, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} ROOT concatenate = s32[384] concatenate(fusion0, fusion1, fusion2), dimensions={0} } @@ -2251,7 +2272,8 @@ ENTRY main { p0 = s32[128] parameter(0) p1 = s32[128] parameter(1) p2 = s32[128] parameter(2) - ROOT fusion = s32[384] fusion(p0, p1, p2), kind=kCustom, calls=concatenate_fusion, backend_config={ + ROOT fusion = s32[384] fusion(p0, p1, p2), kind=kCustom, + calls=concatenate_fusion, backend_config={ "fusion_backend_config":{ "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ @@ -2303,10 +2325,34 @@ rhs { p2 = f32[299,128] parameter(2) p3 = f32[299,128] parameter(3) - fusion0 = f32[299,128] fusion(p0), kind=kCustom, calls=nest0 - fusion1 = f32[299,128] fusion(p1), kind=kCustom, calls=nest1 - fusion2 = 
f32[299,128] fusion(p2), kind=kCustom, calls=nest2 - fusion3 = f32[299,128] fusion(p3), kind=kCustom, calls=nest3 + fusion0 = f32[299,128] fusion(p0), kind=kCustom, calls=nest0, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32", "64"]}] + } + } + } + fusion1 = f32[299,128] fusion(p1), kind=kCustom, calls=nest1, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32", "64"]}] + } + } + } + fusion2 = f32[299,128] fusion(p2), kind=kCustom, calls=nest2, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32", "64"]}] + } + } + } + fusion3 = f32[299,128] fusion(p3), kind=kCustom, calls=nest3, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32", "64"]}] + } + } + } concatenate = f32[299,512] concatenate(fusion0, fusion1, fusion2, fusion3), dimensions={1} ROOT cos = f32[299,512] cosine(concatenate) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index 705d5c859cba58..e8771a2e767b65 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -367,6 +367,22 @@ CodegenDecision IsTritonSupportedDot( return CodegenDecision::Allow(); } +CodegenDecision IsSupportedFusion(const HloFusionInstruction& fusion) { + absl::StatusOr backend_config = + fusion.backend_config(); + if (!backend_config.ok()) { + return CodegenDecision(backend_config.status()); + } + absl::string_view fusion_kind = + backend_config.value().fusion_backend_config().kind(); + // Note: kTritonFusionKind is NOT expected to be set for nested fusions. + if (fusion_kind != kTritonNestedGemmFusionKind) { + return CodegenDecision::Forbid( + absl::StrCat("Unsupported fusion kind: ", fusion_kind)); + } + return CodegenDecision::Allow(); +} + CodegenDecision IsTritonSupportedConcatenate(const HloInstruction& hlo) { CHECK(hlo.opcode() == HloOpcode::kConcatenate); if (!IsInTritonNestedGemmFusion(hlo)) { @@ -439,19 +455,6 @@ CodegenDecision IsTritonSupportedInstructionImpl( "F8E4M3FN and F8E5M2 are not supported for iota."); } - if (instr.IsElementwise()) { - if (!IsTritonSupportedElementwise( - instr.opcode(), - // Use the last operand below in order to support both `compare` - // and `select` which have a fixed PRED type in the output and first - // operand. - instr.operand(instr.operand_count() - 1)->shape().element_type(), - gpu_version)) { - return CodegenDecision::Forbid("Unsupported elementwise operation."); - } - return CodegenDecision::Allow(); - } - switch (instr.opcode()) { case HloOpcode::kReduce: { return CanTritonHandleReduce(*Cast(&instr), @@ -467,11 +470,27 @@ CodegenDecision IsTritonSupportedInstructionImpl( case HloOpcode::kDot: return IsTritonSupportedDot(*Cast(&instr), gpu_version); + case HloOpcode::kFusion: + return IsSupportedFusion(*Cast(&instr)); default: - VLOG(2) << "Unsupported instruction: " << instr.ToString(); + // Not all instructions have a special handling. 
break; } - return CodegenDecision::Forbid("Unsupported instruction."); + + if (instr.IsElementwise()) { + if (!IsTritonSupportedElementwise( + instr.opcode(), + // Use the last operand below in order to support both `compare` + // and `select` which have a fixed PRED type in the output and first + // operand. + instr.operand(instr.operand_count() - 1)->shape().element_type(), + gpu_version)) { + return CodegenDecision::Forbid("Unsupported elementwise operation."); + } + return CodegenDecision::Allow(); + } + return CodegenDecision::Forbid(absl::StrCat("Unsupported instruction opcode ", + HloOpcodeString(instr.opcode()))); } } // namespace @@ -491,7 +510,6 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kFft: - case HloOpcode::kFusion: case HloOpcode::kGather: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: @@ -555,6 +573,7 @@ CodegenDecision IsTritonSupportedInstruction( CodegenDecision IsTritonSupportedComputation( const HloComputation& computation, const se::GpuComputeCapability& gpu_compute_capability) { + VLOG(3) << "IsTritonSupportedComputation: " << computation.ToString(); for (const auto* instruction : computation.instructions()) { if (CodegenDecision can_codegen = IsTritonSupportedInstruction(*instruction, gpu_compute_capability); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index 1cfc6d48440e77..4a9b9d401a7d22 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" @@ -998,11 +999,16 @@ ENTRY triton_computation { p1 = $0[128] parameter(1) p2 = $0[128] parameter(2) - fusion0 = $0[128] fusion(p0), kind=kCustom, calls=nest0 - fusion1 = $0[128] fusion(p1), kind=kCustom, calls=nest1 - fusion2 = $0[128] fusion(p2), kind=kCustom, calls=nest2 - - ROOT concatenate = $0[384] concatenate(fusion0, fusion1, fusion2), dimensions={0} + fusion0 = $0[128] fusion(p0), kind=kCustom, calls=nest0, backend_config={ + "fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{"output_tiles":[{"sizes":["64"]}]}}} + fusion1 = $0[128] fusion(p1), kind=kCustom, calls=nest1, backend_config={ + "fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{"output_tiles":[{"sizes":["64"]}]}}} + fusion2 = $0[128] fusion(p2), kind=kCustom, calls=nest2, backend_config={ + "fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{"output_tiles":[{"sizes":["64"]}]}}} + ROOT result = $0[384] concatenate(fusion0, fusion1, fusion2), dimensions={0} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, @@ -1765,7 +1771,7 @@ frhs { ROOT result = $0[256,512] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[128,256] parameter(0) p1 = $0[256,512] parameter(1) lhs = $0[128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ @@ -1785,14 +1791,6 @@ triton_computation { 
ROOT result = $0[128,512]{1,0} dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} } - -ENTRY e { - p0 = $0[128,256] parameter(0) - p1 = $0[256,512] parameter(1) - ROOT result = $0[128,512]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", - "block_level_fusion_config":{"output_tiles":[{"sizes":["16", "32"]}]}}} -} )"; ExpectedFailMode fail_mode = ExpectedFailMode::kFail; @@ -1802,7 +1800,8 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(hlo_text, type, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(hlo_text, type, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc, fail_mode); } @@ -1819,7 +1818,7 @@ flhs { ROOT result = $0[128,256] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[128,256] parameter(0) p1 = $0[256,512] parameter(1) lhs = $0[128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ @@ -1832,18 +1831,11 @@ triton_computation { ROOT result = $0[128,512] dot(lhs, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } - -ENTRY e { - p0 = $0[128,256] parameter(0) - p1 = $0[256,512] parameter(1) - ROOT result = $0[128,512] fusion(p0, p1), kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", - "block_level_fusion_config":{"output_tiles":[{"sizes":["16", "32"]}]}}} -} )"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, se::CudaComputeCapability::Ampere()); } @@ -1854,7 +1846,7 @@ flhs { ROOT result = $0[256,512] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[128,256] parameter(0) p1 = $0[256,512] parameter(1) rhs = $0[256,512] fusion(p1), kind=kCustom, calls=flhs, backend_config={ @@ -1867,18 +1859,11 @@ triton_computation { ROOT result = $0[128,512] dot(p0, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} } - -ENTRY e { - p0 = $0[128,256] parameter(0) - p1 = $0[256,512] parameter(1) - ROOT result = $0[128,512] fusion(p0, p1), kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", - "block_level_fusion_config":{"output_tiles":[{"sizes":["16", "32"]}]}}} -} )"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, se::CudaComputeCapability::Ampere()); } @@ -1893,7 +1878,7 @@ frhs { ROOT result = $0[16,256,512] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[16,128,256] parameter(0) p1 = $0[16,256,512] parameter(1) lhs = $0[16,128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ @@ -1914,18 +1899,11 @@ triton_computation { lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} } - -ENTRY e { - p0 = $0[16,128,256] parameter(0) - p1 = $0[16,256,512] parameter(1) - ROOT result = $0[16,128,512] fusion(p0, p1), kind=kCustom, calls=triton_computation, - 
backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", - "block_level_fusion_config":{"output_tiles":[{"sizes":["16", "16", "32"]}]}}} -} )"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 16, 32}, se::CudaComputeCapability::Ampere()); } @@ -1940,7 +1918,7 @@ frhs { ROOT result = $0[16,256,512] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[16,128,256] parameter(0) p1 = $0[16,256,512] parameter(1) lhs = $0[16,128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ @@ -1960,19 +1938,11 @@ triton_computation { ROOT result = $0[16,128,16,512] dot(lhs, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={1} } - -ENTRY e { - p0 = $0[16,128,256] parameter(0) - p1 = $0[16,256,512] parameter(1) - ROOT result = $0[16,128,16,512] fusion(p0, p1), kind=kCustom, - calls=triton_computation, backend_config={"fusion_backend_config": - {"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config": - {"output_tiles":[{"sizes":["4", "16", "4", "32"]}]}}} -} )"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{4, 16, 4, 32}, se::CudaComputeCapability::Ampere()); } @@ -1987,7 +1957,7 @@ frhs { ROOT result = $0[16,256,512] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[128,16,256] parameter(0) lhs = $0[128,16,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ "fusion_backend_config":{ @@ -2008,19 +1978,11 @@ triton_computation { lhs_contracting_dims={1, 2}, rhs_contracting_dims={0, 1} } - -ENTRY e { - p0 = $0[128,16,256] parameter(0) - p1 = $0[16,256,512] parameter(1) - ROOT result = $0[128,512] fusion(p0, p1), kind=kCustom, - calls=triton_computation, backend_config={"fusion_backend_config": - {"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config": - {"output_tiles":[{"sizes":["16", "32"]}]}}} -} )"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, se::CudaComputeCapability::Ampere()); } @@ -2036,7 +1998,7 @@ frhs { ROOT result = $0[256,512] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[256,128] parameter(0) p1 = $0[256,512] parameter(1) lhs = $0[256,128] fusion(p0), kind=kCustom, calls=flhs, backend_config={ @@ -2057,18 +2019,11 @@ triton_computation { lhs_contracting_dims={0}, rhs_contracting_dims={0} } - -ENTRY e { - p0 = $0[256,128] parameter(0) - p1 = $0[256,512] parameter(1) - ROOT result = $0[128,512] fusion(p0, p1), kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", - "block_level_fusion_config":{"output_tiles":[{"sizes":["16", "32"]}]}}} -} )"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); 
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, se::CudaComputeCapability::Ampere()); } @@ -2084,7 +2039,7 @@ frhs { ROOT result = $0[512,256] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[128,256] parameter(0) p1 = $0[512,256] parameter(1) lhs = $0[128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ @@ -2105,18 +2060,11 @@ triton_computation { lhs_contracting_dims={1}, rhs_contracting_dims={1} } - -ENTRY e { - p0 = $0[128,256] parameter(0) - p1 = $0[512,256] parameter(1) - ROOT result = $0[128,512] fusion(p0, p1), kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", - "block_level_fusion_config":{"output_tiles":[{"sizes":["16", "32"]}]}}} -} )"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, se::CudaComputeCapability::Ampere()); } @@ -2132,7 +2080,7 @@ frhs { ROOT result = $0[256,512] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[128,128] parameter(0) p1 = $0[256,512] parameter(1) lhs = $0[128,128] fusion(p0), kind=kCustom, calls=flhs, backend_config={ @@ -2155,20 +2103,11 @@ triton_computation { rhs_contracting_dims={0}, sparsity=L.1@2:4 } - -ENTRY e { - p0 = $0[128,128] parameter(0) - p1 = $0[256,512] parameter(1) - p2 = u16[128,16] parameter(2) - ROOT result = $0[128,512]{1,0} fusion(p0, p1, p2), kind=kCustom, - calls=triton_computation, backend_config={"fusion_backend_config": - {"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config": - {"output_tiles":[{"sizes":["16", "32"]}]}}} -} )"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, se::CudaComputeCapability::Ampere()); } @@ -2203,7 +2142,7 @@ frhs { ROOT result = $0[256,512] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[128,256] parameter(0) p1 = $0[256,512] parameter(1) lhs = $0[128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ @@ -2225,14 +2164,6 @@ triton_computation { rhs_contracting_dims={0}, operand_precision={$1, $2} } - -ENTRY e { - p0 = $0[128,256] parameter(0) - p1 = $0[256,512] parameter(1) - ROOT result = $0[128,512]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", - "block_level_fusion_config":{"output_tiles":[{"sizes":["16", "32"]}]}}} -} )", primitive_util::LowercasePrimitiveTypeName(data_type), PrecisionToString(lhs_precision), PrecisionToString(rhs_precision)); @@ -2246,7 +2177,8 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction( - hlo_text, PrimitiveType::PRIMITIVE_TYPE_INVALID, HloOpcode::kDot)); + hlo_text, PrimitiveType::PRIMITIVE_TYPE_INVALID, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc, fail_mode); } @@ -2294,7 +2226,7 @@ frhs { ROOT result = $0[256,512] parameter(0) } -triton_computation { +ENTRY triton_computation { p0 = $0[128,256] parameter(0) p1 = $0[256,512] parameter(1) lhs = $0[128,256] fusion(p0), kind=kCustom, 
calls=flhs, backend_config={ @@ -2316,20 +2248,13 @@ triton_computation { rhs_contracting_dims={0}, algorithm=$1 } - -ENTRY e { - p0 = $0[128,256] parameter(0) - p1 = $0[256,512] parameter(1) - ROOT result = $0[128,512]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", - "block_level_fusion_config":{"output_tiles":[{"sizes":["16", "32"]}]}}} -} )", primitive_util::LowercasePrimitiveTypeName(data_type), AlgorithmToString(algorithm)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, - ParseTemplateAndGetInstruction(hlo_text, F32, HloOpcode::kDot)); + ParseTemplateAndGetInstruction(hlo_text, F32, HloOpcode::kDot, + /* use_nested_gemm_fusions=*/true)); ExpectedFailMode fail_mode = ExpectedFailMode::kFail; if (absl::c_linear_search(std::vector{F8E5M2, F8E4M3FN, S8}, data_type) && @@ -2347,6 +2272,94 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(AllDevicesToTest())), DotPrecisionAlgorithmTestName); +class FusionKindsTest + : public TritonSupportTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(FusionKindsTest, OperandOfDot) { + auto [kind, cc] = GetParam(); + const std::string hlo_text = absl::Substitute(R"( +flhs { + ROOT result = f32[128,256] parameter(0) +} + +frhs { + ROOT result = f32[256,512] parameter(0) +} + +ENTRY triton_computation { + p0 = f32[128,256] parameter(0) + p1 = f32[256,512] parameter(1) + lhs = f32[128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ + "fusion_backend_config":{"kind":"$0", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["16", "64"]}]}}} + rhs = f32[256,512]{1,0} fusion(p1), kind=kCustom, calls=frhs, + backend_config={ "fusion_backend_config":{ "kind":"$0", + "block_level_fusion_config": {"output_tiles":[{"sizes":["64", "32"]}]}}} + ROOT result = f32[128,512]{1,0} dot(lhs, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)", + kind); + + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(hlo_text, F32, HloOpcode::kFusion, + /* use_nested_gemm_fusions=*/true)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc); +} + +std::string FusionKindsTestName( + const ::testing::TestParamInfo< + std::tuple>& data) { + auto [kind, cc] = data.param; + return absl::StrCat(kind, "_", ComputeCapabilityToString(cc)); +} + +TEST_P(FusionKindsTest, OperandOfConcatenate) { + auto [kind, cc] = GetParam(); + const std::string hlo_text = absl::Substitute( + R"( +nest0 { + ROOT p0 = f32[128] parameter(0) +} + +nest1 { + ROOT p0 = f32[128] parameter(0) +} + +ENTRY triton_computation { + p0 = f32[128] parameter(0) + p1 = f32[128] parameter(1) + + fusion0 = f32[128] fusion(p0), kind=kCustom, calls=nest0, backend_config={ + "fusion_backend_config":{"kind":"$0", + "block_level_fusion_config":{"output_tiles":[{"sizes":["64"]}]}}} + fusion1 = f32[128] fusion(p1), kind=kCustom, calls=nest1, backend_config={ + "fusion_backend_config":{"kind":"$0", + "block_level_fusion_config":{"output_tiles":[{"sizes":["64"]}]}}} + ROOT result = f32[256] concatenate(fusion0, fusion1), dimensions={0} +} +)", + kind); + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(hlo_text, F32, HloOpcode::kFusion, + /* use_nested_gemm_fusions=*/true)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{64}, cc); +} + +std::vector FusionKindsForTest() { + return {kTritonFusionKind, kTritonNestedGemmFusionKind, "__invalid"}; +} + +INSTANTIATE_TEST_SUITE_P( + FusionTestSuite, 
FusionKindsTest, + ::testing::Combine(::testing::ValuesIn(FusionKindsForTest()), + ::testing::ValuesIn(AllDevicesToTest())), + FusionKindsTestName); + constexpr std::array kUnsupportedOps = { // clang-format off // go/keep-sorted start @@ -2362,7 +2375,6 @@ constexpr std::array kUnsupportedOps = { HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, HloOpcode::kFft, - HloOpcode::kFusion, HloOpcode::kGather, HloOpcode::kGetTupleElement, HloOpcode::kInfeed, @@ -2422,7 +2434,7 @@ absl::flat_hash_set AllTestedOpcodes() { ret.emplace(HloOpcode::kRngBitGenerator); ret.emplace(HloOpcode::kRngGetAndUpdateState); ret.emplace(HloOpcode::kWhile); - + ret.emplace(HloOpcode::kFusion); ret.insert(kUnsupportedOps.begin(), kUnsupportedOps.end()); return ret; diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index d3a6d7fe4d8d74..4e3132117bacbf 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -2022,6 +2022,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/service/instruction_fusion.h b/third_party/xla/xla/service/instruction_fusion.h index b98517d0ee3299..b02878b2b842d5 100644 --- a/third_party/xla/xla/service/instruction_fusion.h +++ b/third_party/xla/xla/service/instruction_fusion.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -71,6 +72,12 @@ class FusionDecision { } } + explicit FusionDecision(absl::Status status) { + if (!status.ok()) { + explanation_ = status.message(); + } + } + #if defined(PLATFORM_GOOGLE) // We can fuse iff. the decision is `true`. The source location indicates // where an instance was created, making debugging easier without a need to From 0afeed9399bc37db6daeb0b6158ad74e19536a91 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 1 Apr 2025 07:36:36 -0700 Subject: [PATCH 0090/1324] [XLA:FFI] Add DeviceOrdinal context decoding to external FFI. This was already available in the internal version of the API, but it may be useful to be able to access it externally. For example, I found myself wanting this when constructing a DLPack wrapper for FFI buffers. PiperOrigin-RevId: 742692505 --- third_party/xla/xla/ffi/api/c_api.h | 21 ++++++++++++++- third_party/xla/xla/ffi/api/ffi.h | 34 +++++++++++++++++++++++++ third_party/xla/xla/ffi/api/ffi_test.cc | 18 +++++++++++++ third_party/xla/xla/ffi/ffi_api.cc | 10 ++++++++ 4 files changed, 82 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index b7278a754bb185..44a4e21281cceb 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -672,6 +672,24 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_RunId_Get_Args, run_id); // Returns a unique identifier for the current logical execution. 
 typedef XLA_FFI_Error* XLA_FFI_RunId_Get(XLA_FFI_RunId_Get_Args* args);
 
+//===----------------------------------------------------------------------===//
+// DeviceOrdinal
+//===----------------------------------------------------------------------===//
+
+struct XLA_FFI_DeviceOrdinal_Get_Args {
+  size_t struct_size;
+  XLA_FFI_Extension_Base* extension_start;
+
+  XLA_FFI_ExecutionContext* ctx;
+  int32_t device_ordinal;  // out
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_DeviceOrdinal_Get_Args, device_ordinal);
+
+// Returns the device ordinal of the device executing the current handler.
+typedef XLA_FFI_Error* XLA_FFI_DeviceOrdinal_Get(
+    XLA_FFI_DeviceOrdinal_Get_Args* args);
+
 //===----------------------------------------------------------------------===//
 // Metadata extension
 //===----------------------------------------------------------------------===//
@@ -721,11 +739,12 @@ struct XLA_FFI_Api {
   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetAvailable);
   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_RunId_Get);
+  _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceOrdinal_Get);
 };
 
 #undef _XLA_FFI_API_STRUCT_FIELD
 
-XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Api, XLA_FFI_Stream_Get);
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Api, XLA_FFI_DeviceOrdinal_Get);
 
 const XLA_FFI_Api* XLA_FFI_GetApi();
 
diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h
index 18070932fce069..99573059ad8003 100644
--- a/third_party/xla/xla/ffi/api/ffi.h
+++ b/third_party/xla/xla/ffi/api/ffi.h
@@ -1590,6 +1590,40 @@ struct CtxDecoding<RunId> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// DeviceOrdinal
+//===----------------------------------------------------------------------===//
+
+struct DeviceOrdinal {};
+
+// Context decoding for DeviceOrdinal.
+//
+// Example: Ffi::Bind().Ctx<DeviceOrdinal>()
+//                     .To([](int32_t device_ordinal) { ... });
+template <>
+struct CtxDecoding<DeviceOrdinal> {
+  using Type = int32_t;
+
+  static std::optional<Type> Decode(const XLA_FFI_Api* api,
+                                    XLA_FFI_ExecutionContext* ctx,
+                                    DiagnosticEngine& diagnostic) {
+    XLA_FFI_DeviceOrdinal_Get_Args args;
+    args.struct_size = XLA_FFI_DeviceOrdinal_Get_Args_STRUCT_SIZE;
+    args.extension_start = nullptr;
+    args.ctx = ctx;
+    args.device_ordinal = 0;
+
+    if (XLA_FFI_Error* err = api->XLA_FFI_DeviceOrdinal_Get(&args); err) {
+      diagnostic.Emit("Failed to get device ordinal from execution context: ")
+          << internal::GetErrorMessage(api, err);
+      internal::DestroyError(api, err);
+      return std::nullopt;
+    }
+
+    return args.device_ordinal;
+  }
+};
+
 }  // namespace xla::ffi
 
 #endif  // XLA_FFI_API_FFI_H_
diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc
index e76d99cf18f332..e38d61dee89c8f 100644
--- a/third_party/xla/xla/ffi/api/ffi_test.cc
+++ b/third_party/xla/xla/ffi/api/ffi_test.cc
@@ -450,6 +450,24 @@ TEST(FfiTest, RunId) {
   TF_ASSERT_OK(status);
 }
 
+TEST(FfiTest, DeviceOrdinal) {
+  CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
+  auto call_frame = builder.Build();
+
+  auto handler =
+      Ffi::Bind().Ctx<DeviceOrdinal>().To([&](int32_t device_ordinal) {
+        EXPECT_EQ(device_ordinal, 42);
+        return Error::Success();
+      });
+
+  CallOptions options;
+  options.device_ordinal = 42;
+
+  auto status = Call(*handler, call_frame, options);
+
+  TF_ASSERT_OK(status);
+}
+
 TEST(FfiTest, AnyBufferArgument) {
   std::vector<float> storage(4, 0.0f);
   se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc
index bc6dbd5e3cd120..e88665c48ef71b 100644
--- a/third_party/xla/xla/ffi/ffi_api.cc
+++ b/third_party/xla/xla/ffi/ffi_api.cc
@@ -630,6 +630,15 @@ static XLA_FFI_Error* XLA_FFI_RunId_Get(XLA_FFI_RunId_Get_Args* args) {
   return nullptr;
 }
 
+static XLA_FFI_Error* XLA_FFI_DeviceOrdinal_Get(
+    XLA_FFI_DeviceOrdinal_Get_Args* args) {
+  XLA_FFI_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+      "XLA_FFI_DeviceOrdinal_Get", XLA_FFI_DeviceOrdinal_Get_Args_STRUCT_SIZE,
+      args->struct_size));
+  args->device_ordinal = args->ctx->device_ordinal;
+  return nullptr;
+}
+
 static XLA_FFI_Error* XLA_FFI_TypeId_Register(
     XLA_FFI_TypeId_Register_Args* args) {
   XLA_FFI_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
@@ -958,6 +967,7 @@ static XLA_FFI_Api api = {
     XLA_FFI_Future_SetAvailable,
     XLA_FFI_Future_SetError,
     XLA_FFI_RunId_Get,
+    XLA_FFI_DeviceOrdinal_Get,
 };
 
 const XLA_FFI_Api* GetXlaFfiApi() { return &api; }

From 502c645695129479608b16132d7eb42e8cfba7b2 Mon Sep 17 00:00:00 2001
From: Christian Sigg
Date: Tue, 1 Apr 2025 07:43:34 -0700
Subject: [PATCH 0091/1324] [XLA:GPU] NFC: Expose kDefaultVersion so that we
 don't need to create an AutotuneCacheKey instance to retrieve it.
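The cleanup below amounts to exposing the version as a static class constant
so callers no longer construct a throwaway key just to read it. A minimal
sketch with a hypothetical class, not the actual XLA type:

  #include <cstdint>

  class CacheKey {
   public:
    static constexpr int32_t kCurrentVersion = 1;  // Readable with no instance.
    int32_t GetVersion() const { return version_; }

   private:
    int32_t version_ = kCurrentVersion;
  };

  // Before: CacheKey("model", "canonical_hlo").GetVersion();
  // After, no dummy object is needed:
  static_assert(CacheKey::kCurrentVersion == 1);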
PiperOrigin-RevId: 742694721
---
 .../xla/xla/service/gpu/autotuning/autotuner_util.cc |  5 ++---
 .../xla/xla/service/gpu/autotuning/autotuner_util.h  | 10 ++++++----
 2 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc
index 4ecb2c9a0e1998..f87bf7f6e82a46 100644
--- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc
@@ -573,9 +573,8 @@ bool IsTextProtoPath(absl::string_view file_path) {
 void AddVersionToAutotuneResults(AutotuneResults& results) {
   for (auto& result : *results.mutable_results()) {
     if (result.version() == 0) {
-      // Create a dummy key and pull its version if we don't have one specified.
-      AutotuneCacheKey key("foo", "canonical_foo");
-      result.set_version(key.GetVersion());
+      // Set to current version if we don't have one specified.
+      result.set_version(AutotuneCacheKey::kCurrentVersion);
     }
   }
 }
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
index 1ac55aff05ed7b..76eecaf0c1125d 100644
--- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
@@ -58,6 +58,11 @@ struct DevicelessConfig {
 
 class AutotuneCacheKey {
  public:
+  // Tie a version to the cache key in order to invalidate the cache when
+  // necessary. This should be incremented on triton upgrades or any other
+  // changes that may affect the autotuning results.
+  static constexpr int kCurrentVersion = 1;
+
   AutotuneCacheKey(const se::DeviceDescription& device_description,
                    const HloInstruction& instruction)
       : AutotuneCacheKey(DeviceDescriptionToCacheKey(device_description),
@@ -103,10 +108,7 @@ class AutotuneCacheKey {
  private:
   std::string model_str_;
   std::string hlo_canonical_;
-  // Tie a version to the cache key in order to invalidate the cache when
-  // necessary. This should be done on triton upgrades or any other changes
-  // that may affect the autotuning results.
-  int version_ = 1;
+  int version_ = kCurrentVersion;
 };
 
 using AutotuneCacheKeySet = absl::flat_hash_set<AutotuneCacheKey>;

From e2e7bbe3cd8b7d7b19741d7ca6fa8ebebf3e823f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 1 Apr 2025 07:51:35 -0700
Subject: [PATCH 0092/1324] Simplify tests using EqualsProto().

PiperOrigin-RevId: 742697107
---
 .../xla/xla/hlo/parser/hlo_parser_test.cc     | 11 +++++-----
 .../large_hlo_snapshot_serialization/BUILD    |  1 +
 .../serialization_test.cc                     | 10 +++++----
 .../xla/xla/service/hlo_instruction_test.cc   |  5 ++---
 .../xla/xla/service/hlo_module_test.cc        |  3 +--
 .../coordination_service_agent_test.cc        | 22 -------------------
 .../xla/xla/tsl/distributed_runtime/rpc/BUILD |  1 +
 .../distributed_runtime/rpc/grpc_util_test.cc | 12 +++++-----
 .../xla/tsl/lib/histogram/histogram_test.cc   | 10 ++------
 9 files changed, 25 insertions(+), 50 deletions(-)
diff --git a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc
index 4480becf0a0942..52afbbab8b92a1 100644
--- a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc
+++ b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc
@@ -5972,14 +5972,13 @@ TEST_F(HloParserTest,
   // Check the async-start and async-done instructions.
HloInstruction* async_done = m->entry_computation()->root_instruction(); HloInstruction* async_start = async_done->async_chain_start(); - EXPECT_EQ(async_start->metadata().DebugString(), - wrapped_instr->metadata().DebugString()); + EXPECT_THAT(async_start->metadata(), EqualsProto(wrapped_instr->metadata())); EXPECT_EQ(async_start->raw_backend_config_string(), wrapped_instr->raw_backend_config_string()); - EXPECT_EQ(async_start->frontend_attributes().DebugString(), - wrapped_instr->frontend_attributes().DebugString()); - EXPECT_EQ(async_start->statistics_viz().DebugString(), - wrapped_instr->statistics_viz().DebugString()); + EXPECT_THAT(async_start->frontend_attributes(), + EqualsProto(wrapped_instr->frontend_attributes())); + EXPECT_THAT(async_start->statistics_viz(), + EqualsProto(wrapped_instr->statistics_viz())); EXPECT_EQ(OriginalValueToString(*async_done->original_value()), OriginalValueToString(*wrapped_instr->original_value())); } diff --git a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD index 79c5be30701ad6..97172f3aec2f81 100644 --- a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD +++ b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD @@ -66,6 +66,7 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", + "//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization_test.cc b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization_test.cc index 0c07d4e764b045..a407bc38f662bb 100644 --- a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization_test.cc +++ b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization_test.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" +#include "xla/tsl/util/proto/proto_matchers.h" #include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" @@ -37,6 +38,7 @@ namespace xla { namespace { using ::testing::HasSubstr; +using ::tsl::proto_testing::EqualsProto; HloUnoptimizedSnapshot CreateSnapshot() { HloUnoptimizedSnapshot snapshot; @@ -69,7 +71,7 @@ TEST(LargeHloSnapshotSerializationTest, SerializeAndDeserialize) { TF_ASSERT_OK_AND_ASSIGN(HloUnoptimizedSnapshot deserialized_snapshot, SerializeAndDeserialize(snapshot)); - EXPECT_EQ(deserialized_snapshot.DebugString(), snapshot.DebugString()); + EXPECT_THAT(deserialized_snapshot, EqualsProto(snapshot)); } TEST(LargeHloSnapshotSerializationTest, SerializeAndDeserializeEmptyModule) { @@ -79,7 +81,7 @@ TEST(LargeHloSnapshotSerializationTest, SerializeAndDeserializeEmptyModule) { TF_ASSERT_OK_AND_ASSIGN(HloUnoptimizedSnapshot deserialized_snapshot, SerializeAndDeserialize(snapshot)); - EXPECT_EQ(deserialized_snapshot.DebugString(), snapshot.DebugString()); + EXPECT_THAT(deserialized_snapshot, EqualsProto(snapshot)); } TEST(LargeHloSnapshotSerializationTest, SerializeAndDeserializeEmptyPartition) { @@ -89,7 +91,7 @@ TEST(LargeHloSnapshotSerializationTest, SerializeAndDeserializeEmptyPartition) { TF_ASSERT_OK_AND_ASSIGN(HloUnoptimizedSnapshot deserialized_snapshot, SerializeAndDeserialize(snapshot)); - EXPECT_EQ(deserialized_snapshot.DebugString(), snapshot.DebugString()); + EXPECT_THAT(deserialized_snapshot, EqualsProto(snapshot)); } TEST(LargeHloSnapshotSerializationTest, SerializeAndDeserializeBrokenSnapshot) { @@ -154,7 +156,7 @@ TEST(LargeHloSnapshotSerializationTest, TF_ASSERT_OK_AND_ASSIGN(HloUnoptimizedSnapshot deserialized_snapshot, SerializeAndDeserialize(snapshot)); - EXPECT_EQ(deserialized_snapshot.DebugString(), snapshot.DebugString()); + EXPECT_THAT(deserialized_snapshot, EqualsProto(snapshot)); } } // namespace diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index bff3fd1cf17ced..f526838ecd2d5f 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -2404,7 +2404,7 @@ TEST_F(HloInstructionTest, CloneWindowOnCustomCall) { Window w = window_util::MakeWindow({1, 2, 3}); instr->set_window(w); auto clone = instr->Clone(); - EXPECT_THAT(clone->window(), EqualsProto(w)) << clone->window().DebugString(); + EXPECT_THAT(clone->window(), EqualsProto(w)); } TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { @@ -2415,8 +2415,7 @@ TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { dnums.set_output_batch_dimension(42); instr->set_convolution_dimension_numbers(dnums); auto clone = instr->Clone(); - EXPECT_THAT(clone->convolution_dimension_numbers(), EqualsProto(dnums)) - << clone->convolution_dimension_numbers().DebugString(); + EXPECT_THAT(clone->convolution_dimension_numbers(), EqualsProto(dnums)); } TEST_F(HloInstructionTest, CloneHasSideEffectOnCustomCall) { diff --git a/third_party/xla/xla/service/hlo_module_test.cc b/third_party/xla/xla/service/hlo_module_test.cc index 28f8d8d3930490..6c673f56f914c7 100644 --- a/third_party/xla/xla/service/hlo_module_test.cc +++ b/third_party/xla/xla/service/hlo_module_test.cc @@ -765,8 +765,7 @@ ENTRY ReduceR3ToR2.v3 { xla::HloModuleProtoWithConfig proto = module->ToProtoWithConfig(); std::string serialized_module; ASSERT_TRUE(tsl::SerializeToStringDeterministic(proto, 
&serialized_module)); - std::string original_debug_str = proto.DebugString(); - RecordProperty("serialized_module", original_debug_str); + RecordProperty("serialized_module", proto.DebugString()); // Verify that we can create a module from our parsed proto copy TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr reconstructed_module, diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc index 706dcc589142d6..2d9f72525e8cf6 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc @@ -52,28 +52,6 @@ using ::testing::UnorderedPointwise; using ::testing::WithArgs; using ::tsl::testing::StatusIs; -// TODO(b/229726259) Switch to OSS version after it's available. -// Simple implementation of a proto matcher comparing string representations. -class ProtoStringMatcher { - public: - explicit ProtoStringMatcher(const tsl::protobuf::Message& expected) - : expected_(expected.DebugString()) {} - - template - bool MatchAndExplain(const Message& p, - ::testing::MatchResultListener*) const { - return p.DebugString() == expected_; - } - - void DescribeTo(std::ostream* os) const { *os << expected_; } - void DescribeNegationTo(std::ostream* os) const { - *os << "not equal to expected message: " << expected_; - } - - private: - const std::string expected_; -}; - MATCHER(KvEq, "simple KeyValueEntry matcher") { const KeyValueEntry& kv0 = std::get<0>(arg); const KeyValueEntry& kv1 = std::get<1>(arg); diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD index 0ba36f59f9350e..b522a338938297 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD @@ -62,6 +62,7 @@ tsl_cc_test( "//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", + "//xla/tsl/util/proto:proto_matchers", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util_test.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util_test.cc index 182b6d02343bd9..af596db53a576d 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util_test.cc @@ -24,10 +24,12 @@ limitations under the License. 
#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/util/proto/proto_matchers.h" namespace tsl { - namespace { + +using tsl::proto_testing::EqualsProto; using tsl::test::TestRequest; string ToString(const grpc::ByteBuffer& buf) { @@ -96,7 +98,7 @@ TEST(GrpcProto, Unparse) { ASSERT_TRUE(GrpcMaybeUnparseProto(proto, &buf).ok()); TestRequest parsed; ASSERT_TRUE(parsed.ParseFromString(ToString(buf))); - ASSERT_EQ(proto.DebugString(), parsed.DebugString()); + ASSERT_THAT(parsed, EqualsProto(proto)); } TEST(GrpcProto, UnparseToString) { @@ -109,7 +111,7 @@ TEST(GrpcProto, UnparseToString) { ASSERT_TRUE(GrpcMaybeUnparseProto(str, &buf).ok()); TestRequest parsed; ASSERT_TRUE(parsed.ParseFromString(ToString(buf))); - ASSERT_EQ(proto.DebugString(), parsed.DebugString()); + ASSERT_THAT(parsed, EqualsProto(proto)); } TEST(GrpcProto, Parse) { @@ -131,7 +133,7 @@ TEST(GrpcProto, Parse) { TestRequest parsed; ASSERT_TRUE(GrpcMaybeParseProto(&src, &parsed)) << c.length << " " << c.slices; - ASSERT_EQ(proto.DebugString(), parsed.DebugString()); + ASSERT_THAT(parsed, EqualsProto(proto)); } } @@ -156,7 +158,7 @@ TEST(GrpcProto, ParseFromString) { ASSERT_TRUE(GrpcMaybeParseProto(&src, &parsed_str)) << c.length << " " << c.slices; ASSERT_TRUE(parsed.ParseFromString(parsed_str)); - ASSERT_EQ(proto.DebugString(), parsed.DebugString()); + ASSERT_THAT(parsed, EqualsProto(proto)); } } diff --git a/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc index 42268a44b0cce5..3ac2e3dd41213b 100644 --- a/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc @@ -25,27 +25,21 @@ namespace tsl { namespace histogram { static void Validate(const Histogram& h) { - string s1 = h.ToString(); - LOG(ERROR) << s1; - HistogramProto proto_with_zeroes; h.EncodeToProto(&proto_with_zeroes, true); Histogram h2; EXPECT_TRUE(h2.DecodeFromProto(proto_with_zeroes)); - string s2 = h2.ToString(); - LOG(ERROR) << s2; - EXPECT_EQ(s1, s2); + EXPECT_EQ(h2.ToString(), h.ToString()); HistogramProto proto_no_zeroes; h.EncodeToProto(&proto_no_zeroes, false); - LOG(ERROR) << proto_no_zeroes.DebugString(); Histogram h3; EXPECT_TRUE(h3.DecodeFromProto(proto_no_zeroes)); string s3 = h3.ToString(); LOG(ERROR) << s3; - EXPECT_EQ(s1, s3); + EXPECT_EQ(h3.ToString(), h.ToString()); } TEST(Histogram, Empty) { From d8fd215ba9a96d3f9a72efb231044db14cb3c3ee Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 1 Apr 2025 08:04:09 -0700 Subject: [PATCH 0093/1324] [XLA:GPU] Add AllReduce decomposer. The pass rewrites small `all-reduce` as `all-gather` + `reduce`. The expectation is that once we have a memcpy-based `all-gather`, this combination will be faster on a single host compared to the default NCCL `all-reduce` that we use right now. 
PiperOrigin-RevId: 742701693 --- third_party/xla/xla/debug_options_flags.cc | 9 + third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/gpu_compiler.cc | 4 + .../xla/xla/service/gpu/transforms/BUILD | 42 +++++ .../gpu/transforms/all_reduce_decomposer.cc | 160 ++++++++++++++++++ .../gpu/transforms/all_reduce_decomposer.h | 41 +++++ .../transforms/all_reduce_decomposer_test.cc | 95 +++++++++++ third_party/xla/xla/xla.proto | 5 +- 8 files changed, 356 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer.cc create mode 100644 third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer.h create mode 100644 third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer_test.cc diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index ce53378a85c5c7..0e7d84010421a4 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -334,6 +334,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_scatter_determinism_expander(true); opts.set_xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(false); opts.set_xla_gpu_unsupported_use_ragged_all_to_all_one_shot_kernel(true); + opts.set_xla_gpu_unsupported_enable_all_reduce_decomposer(false); opts.set_xla_gpu_experimental_pack_dot_operands_along_k_dimension(true); opts.set_xla_unsupported_crash_on_hlo_pass_fix_max_iterations(false); opts.set_xla_hlo_pass_fix_detect_cycles(false); @@ -2252,6 +2253,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "performance." "Note that even when this flag is disabled, scatter operations may still " "be deterministic, although with additional overhead.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_unsupported_enable_all_reduce_decomposer", + bool_setter_for( + &DebugOptions::set_xla_gpu_unsupported_enable_all_reduce_decomposer), + debug_options->xla_gpu_unsupported_enable_all_reduce_decomposer(), + "Internal: Enable the AllReduceDecomposer, an unsupported pass that " + "rewrites small all-reduce operations as a sequence of all-gather and " + "reduce operations.")); flag_list->push_back(tsl::Flag( "xla_gpu_unsupported_enable_ragged_all_to_all_decomposer", bool_setter_for( diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 44d8885928510d..2a2646ea8691d1 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1513,6 +1513,7 @@ cc_library( "//xla/service/gpu/transforms:all_gather_dynamic_slice_simplifier", "//xla/service/gpu/transforms:all_gather_optimizer", "//xla/service/gpu/transforms:all_reduce_blueconnect", + "//xla/service/gpu/transforms:all_reduce_decomposer", "//xla/service/gpu/transforms:all_reduce_splitter", "//xla/service/gpu/transforms:async_collective_annotator", "//xla/service/gpu/transforms:async_wrapper", diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index e2e2a1f862cce3..f73d785c71b21c 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -186,6 +186,7 @@ limitations under the License. 
#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" #include "xla/service/gpu/transforms/all_gather_optimizer.h" #include "xla/service/gpu/transforms/all_reduce_blueconnect.h" +#include "xla/service/gpu/transforms/all_reduce_decomposer.h" #include "xla/service/gpu/transforms/all_reduce_splitter.h" #include "xla/service/gpu/transforms/async_wrapper.h" #include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h" @@ -880,6 +881,9 @@ absl::Status RunCollectiveOptimizationPasses( collectives_pipeline.AddPass(); collectives_pipeline.AddPass(); collectives_pipeline.AddPass(); + if (debug_options.xla_gpu_unsupported_enable_all_reduce_decomposer()) { + collectives_pipeline.AddPass(); + } collectives_pipeline.AddPass(); collectives_pipeline.AddPass(); collectives_pipeline.AddPass( diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 68d8af6460438d..03f25d09842259 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -3558,3 +3558,45 @@ xla_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "all_reduce_decomposer", + srcs = ["all_reduce_decomposer.cc"], + hdrs = ["all_reduce_decomposer.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:collective_ops_utils", + "//xla/service:shape_inference", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "all_reduce_decomposer_test", + srcs = ["all_reduce_decomposer_test.cc"], + deps = [ + ":all_reduce_decomposer", + "//xla/hlo/testlib:filecheck", + "//xla/service:hlo_cse", + "//xla/service:hlo_runner", + "//xla/service:platform_util", + "//xla/tests:hlo_runner_agnostic_test_base", + "//xla/tests:test_utils", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "@com_google_absl//absl/log", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer.cc new file mode 100644 index 00000000000000..5e5338e7bf854f --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer.cc @@ -0,0 +1,160 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_decomposer.h"
+
+#include <cstdint>
+#include <iterator>
+#include <optional>
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/shape_inference.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+// The threshold is the upper limit of the number of elements in the input to
+// an all-reduce operation for it to be decomposed. The value is chosen
+// empirically in Feb 2025 to be a reasonable trade-off between performance and
+// memory usage.
+constexpr int64_t kOneShotAllReduceThreshold = 256 * 1024;
+
+bool IsSmallAllReduce(const HloInstruction* hlo) {
+  return HloPredicateIsOp<HloOpcode::kAllReduce>(hlo) &&
+         ShapeUtil::ElementsInRecursive(hlo->shape()) <=
+             kOneShotAllReduceThreshold;
+}
+
+std::optional<Literal> CreateReductionInitLiteral(
+    HloAllReduceInstruction* all_reduce, HloComputation* computation) {
+  std::optional<ReductionKind> reduction_kind =
+      MatchReductionComputation(all_reduce->to_apply());
+  if (!reduction_kind.has_value()) {
+    return std::nullopt;
+  }
+
+  return GetReductionIdentity(reduction_kind.value(),
+                              all_reduce->shape().element_type());
+}
+
+// Adds a size-1 major dimension to the given HLO instruction.
+HloInstruction* PrependSize1MajorDimension(HloInstruction* hlo,
+                                           HloComputation* computation) {
+  absl::InlinedVector<int64_t, 4> reshape_dimensions;
+  reshape_dimensions.reserve(hlo->shape().dimensions().size() + 1);
+  reshape_dimensions.push_back(1);
+  absl::c_copy(hlo->shape().dimensions(),
+               std::back_inserter(reshape_dimensions));
+
+  Shape reshape_shape =
+      ShapeUtil::MakeShape(hlo->shape().element_type(), reshape_dimensions);
+  return computation->AddInstruction(
+      HloInstruction::CreateReshape(reshape_shape, hlo));
+}
+
+// Decomposes the given all-reduce operation into an all-gather and a reduce
+// operation.
+absl::StatusOr<bool> DecomposeAllReduce(HloInstruction* hlo,
+                                        HloComputation* computation,
+                                        HloModule* module) {
+  HloAllReduceInstruction* all_reduce = Cast<HloAllReduceInstruction>(hlo);
+
+  HloInstruction* input = all_reduce->mutable_operand(0);
+
+  std::optional<Literal> reduction_init_literal =
+      CreateReductionInitLiteral(all_reduce, computation);
+  if (!reduction_init_literal.has_value()) {
+    // Unsupported reduction type.
+    return false;
+  }
+
+  TF_ASSIGN_OR_RETURN(auto replica_group_count_and_size,
+                      GetReplicaGroupCountAndSize(all_reduce));
+
+  if (!replica_group_count_and_size.has_value()) {
+    // Could not determine the number of participating devices at compilation.
+    return false;
+  }
+
+  int64_t num_participating_devices = replica_group_count_and_size->second;
+
+  // Add a size-1 major dimension to the input that will be used as the
+  // all-gather and reduction dimension.
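+  // The gathered per-device copies land in this new leading dimension, and
+  // the reduce below folds them back into the original input shape.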
+  HloInstruction* reshape = PrependSize1MajorDimension(input, computation);
+
+  TF_ASSIGN_OR_RETURN(Shape all_gather_shape,
+                      ShapeInference::InferAllGatherShape(
+                          {&reshape->shape()}, /*all_gather_dimension=*/0,
+                          num_participating_devices));
+
+  HloInstruction* all_gather =
+      computation->AddInstruction(HloInstruction::CreateAllGather(
+          all_gather_shape, {reshape}, /*all_gather_dimension=*/0,
+          all_reduce->device_list(), all_reduce->constrain_layout(),
+          all_reduce->channel_id(), all_reduce->use_global_device_ids()));
+
+  HloInstruction* init = computation->AddInstruction(
+      HloInstruction::CreateConstant(*std::move(reduction_init_literal)));
+
+  HloInstruction* reduce =
+      computation->AddInstruction(HloInstruction::CreateReduce(
+          input->shape(), all_gather, init,
+          /*dimensions_to_reduce=*/{0}, all_reduce->to_apply()));
+
+  TF_RETURN_IF_ERROR(all_reduce->ReplaceAllUsesWith(reduce));
+  TF_RETURN_IF_ERROR(
+      computation->RemoveInstructionAndUnusedOperands(all_reduce));
+
+  return true;
+}
+
+absl::StatusOr<bool> AllReduceDecomposer::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+
+  for (auto computation : module->computations(execution_threads)) {
+    for (auto hlo : computation->MakeInstructionPostOrder()) {
+      if (!IsSmallAllReduce(hlo)) {
+        continue;
+      }
+      TF_ASSIGN_OR_RETURN(bool decomposed,
+                          DecomposeAllReduce(hlo, computation, module));
+      changed |= decomposed;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer.h b/third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer.h
new file mode 100644
index 00000000000000..5de352a1b4efbd
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer.h
@@ -0,0 +1,41 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_DECOMPOSER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_DECOMPOSER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/pass/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites an `all-reduce` as `all-gather` and `reduce`.
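+//
+// The rewrite applies only to all-reduces that are small (at most 256K
+// elements), use a reduction recognized by MatchReductionComputation, and
+// have a participant count known at compile time; other all-reduces are
+// left unchanged.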
+class AllReduceDecomposer : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "all-reduce-decomposer"; }
+
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_DECOMPOSER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer_test.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer_test.cc
new file mode 100644
index 00000000000000..0ca38f5e8fe3e2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_decomposer_test.cc
@@ -0,0 +1,95 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_decomposer.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/log/log.h"
+#include "xla/hlo/testlib/filecheck.h"
+#include "xla/service/hlo_cse.h"
+#include "xla/service/hlo_runner.h"
+#include "xla/service/platform_util.h"
+#include "xla/tests/hlo_runner_agnostic_test_base.h"
+#include "xla/tests/test_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class AllReduceDecomposerTest : public HloRunnerAgnosticTestBase {
+ public:
+  AllReduceDecomposerTest()
+      : HloRunnerAgnosticTestBase(std::make_unique<HloRunner>(
+            PlatformUtil::GetDefaultPlatform().value())) {}
+};
+
+TEST_F(AllReduceDecomposerTest, SmallAllReduceIsDecomposed) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  input = f32[16] parameter(0)
+  ROOT all-reduce = f32[16] all-reduce(input), replica_groups={{0,1}}, to_apply=add
+}
+)"));
+
+  AllReduceDecomposer decomposer;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get(), {}));
+  EXPECT_TRUE(changed);
+  TF_EXPECT_OK(VerifyHloModule(module.get(), true, true));
+  TF_EXPECT_OK(HloCSE(true).Run(module.get()));
+
+  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+    // CHECK: f32[1,16]{1,0} reshape
+    // CHECK: f32[2,16]{1,0} all-gather
+    // CHECK: f32[16]{0} reduce
+  )"));
+}
+
+TEST_F(AllReduceDecomposerTest, LargeAllReduceIsNotDecomposed) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  input = f32[16777216] parameter(0)
+  ROOT all-reduce = f32[16777216] all-reduce(input), replica_groups={{0,1}}, to_apply=add
+}
+)"));
+
+  AllReduceDecomposer decomposer;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get(), {}));
+  EXPECT_FALSE(changed);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/xla.proto
b/third_party/xla/xla/xla.proto index e1e951b3f8bded..5e0d41c5cb9a5b 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -765,6 +765,9 @@ message DebugOptions { // the code that emits a particular instruction. bool xla_gpu_unsupported_annotate_with_emitter_loc = 358; + // Internal testing flag to switch AllReduceDecomposer on or off. + bool xla_gpu_unsupported_enable_all_reduce_decomposer = 384; + // Enable experimental tiling for Triton GEMM fusions. // TODO(b/393299275): remove the flag once the feature is stable. bool xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms = 367; @@ -1203,7 +1206,7 @@ message DebugOptions { // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 384 + // Next id: 385 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 0e029807b782df9f21fd52cc7b7fd9cb22023c69 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 1 Apr 2025 08:51:49 -0700 Subject: [PATCH 0094/1324] [xla:gpu] CommandBuffer: switch CommandBuffer::While to explicit command update API Also add test for WhileCmd update PiperOrigin-RevId: 742717675 --- .../xla/xla/backends/gpu/runtime/BUILD | 8 +- .../gpu/runtime/command_buffer_cmd.cc | 12 +- .../gpu/runtime/command_buffer_thunk_test.cc | 107 +++++++++++++++++- .../xla/xla/stream_executor/command_buffer.h | 9 +- .../stream_executor/gpu/gpu_command_buffer.cc | 56 +++++---- .../stream_executor/gpu/gpu_command_buffer.h | 8 +- .../gpu/gpu_command_buffer_test.cc | 15 +-- .../gpu/gpu_test_kernels.cu.cc | 4 +- 8 files changed, 174 insertions(+), 45 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 71e11f25c3f353..eb11475a5f51f0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -354,6 +354,7 @@ xla_test( deps = [ ":command_buffer_cmd", ":command_buffer_thunk", + ":dynamic_slice_thunk", ":memset_thunk", ":sequential_thunk", ":thunk", @@ -367,6 +368,7 @@ xla_test( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/kernels:custom_kernel", "//xla/stream_executor:blas", "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_description", @@ -374,22 +376,24 @@ xla_test( "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", "@local_tsl//tsl/profiler/lib:profiler_lock", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", 
diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 07e3580e534ac8..010d66f6d7990d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -849,11 +849,13 @@ absl::Status WhileCmd::Record(const Thunk::ExecuteParams& execute_params, << " body_commands=" << body_commands_.size(); VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")"; - return command_buffer->While( - se::DeviceMemory(pred), - CreateExecutionScopeBuilder(&cond_commands_, &execute_params, - &record_params), - CreateBuilder(&body_commands_, &execute_params, &record_params)); + return command_buffer + ->While(se::DeviceMemory(pred), + CreateExecutionScopeBuilder(&cond_commands_, &execute_params, + &record_params), + CreateBuilder(&body_commands_, &execute_params, &record_params), + {}) + .status(); } bool WhileCmd::force_update() { diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index 4b8df69791c71e..5d5324815fba2a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -30,19 +31,23 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/types/span.h" #include "xla/backends/gpu/runtime/command_buffer_cmd.h" +#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/memset_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" @@ -51,6 +56,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/semantic_version.h" @@ -58,10 +64,9 @@ limitations under the License. 
#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" #include "tsl/profiler/lib/profiler_lock.h" #ifdef GOOGLE_CUDA @@ -1224,9 +1229,101 @@ TEST(CommandBufferThunkTest, CaseCmd) { } TEST(CommandBufferThunkTest, WhileCmd) { - // TODO(ezhulenev): Find a way to test WhileCmd: add a test only TraceCmd that - // could allow us trace custom kernels to update while loop iterations. Or - // maybe add a CustomLaunchCmd and wrap loop update into custom kernel. + se::StreamExecutor* executor = GpuExecutor(); + + if (!IsAtLeastCuda12300(executor)) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: loop_cnt=0, num_iters=10, a=1, b=0 + se::DeviceMemory pred = executor->AllocateArray(1, 0); + se::DeviceMemory loop_cnt = executor->AllocateArray(1, 0); + se::DeviceMemory num_iters = executor->AllocateArray(1, 0); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&loop_cnt, 0, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memset32(&num_iters, 10, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_pred(/*index=*/0, sizeof(bool), /*color=*/0); + BufferAllocation alloc_loop_cnt(/*index=*/1, sizeof(int32_t), /*color=*/0); + BufferAllocation alloc_num_iters(/*index=*/2, sizeof(int32_t), /*color=*/0); + BufferAllocation alloc_a(/*index=*/3, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/4, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_pred(&alloc_pred, 0, sizeof(bool)); + BufferAllocation::Slice slice_loop_cnt(&alloc_loop_cnt, 0, sizeof(int32_t)); + BufferAllocation::Slice slice_num_iters(&alloc_num_iters, 0, sizeof(int32_t)); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + auto cond_args = {slice_loop_cnt, slice_pred, slice_num_iters}; + auto cond_args_access = {MemoryAccess::kWrite, MemoryAccess::kWrite, + MemoryAccess::kRead}; + + auto body_args = {slice_a, slice_b, slice_b}; // b = a + b + auto body_args_access = {MemoryAccess::kRead, MemoryAccess::kRead, + MemoryAccess::kWrite}; + + // Prepare commands sequence for loop `cond`. + CommandBufferCmdSequence cond_commands; + cond_commands.Emplace(s0, "IncAndCmp", cond_args, cond_args_access, + LaunchDimensions(1, 1), + /*shmem_bytes=*/0); + + // Prepare commands sequence for loop `body`. + CommandBufferCmdSequence body_commands; + body_commands.Emplace(s0, "AddI32", body_args, body_args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Prepare commands sequence for thunk. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_pred, std::move(cond_commands), + std::move(body_commands)); + + // Construct a thunk with command sequence. 
+  CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo());
+
+  ServiceExecutableRunOptions run_options;
+  se::StreamExecutorMemoryAllocator allocator(executor);
+  BufferAllocations allocations({pred, loop_cnt, num_iters, a, b}, 0,
+                                &allocator);
+
+  Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(
+      run_options, allocations, stream.get(), stream.get(), nullptr, nullptr);
+
+  TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource());
+  TF_ASSERT_OK(
+      thunk.Initialize({executor, static_cast<Thunk::ExecutableSource>(source),
+                        &allocations, stream.get()}));
+
+  // Execute command buffer thunk and verify that it added the value 10 times.
+  TF_ASSERT_OK(thunk.ExecuteOnStream(params));
+  TF_ASSERT_OK(stream->BlockHostUntilDone());
+
+  // Copy `b` data back to host.
+  std::vector<int32_t> dst(4, 0);
+  TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length));
+
+  ASSERT_EQ(dst, std::vector<int32_t>(4, 10));
+
+  // Initialize `loop_cnt` to `5` and check that we run only 5 iterations.
+  TF_ASSERT_OK(stream->Memset32(&loop_cnt, 5, sizeof(int32_t)));
+
+  TF_ASSERT_OK(thunk.ExecuteOnStream(params));
+  TF_ASSERT_OK(stream->BlockHostUntilDone());
+
+  TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length));
+  ASSERT_EQ(dst, std::vector<int32_t>(4, 15));
 }
 
 class CmdBufferTest : public HloTestBase {
diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h
index 36ead8cc30de0a..d7f4660d05c653 100644
--- a/third_party/xla/xla/stream_executor/command_buffer.h
+++ b/third_party/xla/xla/stream_executor/command_buffer.h
@@ -202,8 +202,13 @@ class CommandBuffer {
   //   body_builder()
   //   cond_builder()
   //
-  virtual absl::Status While(DeviceMemory<bool> pred, Builder cond_builder,
-                             Builder body_builder) = 0;
+  virtual absl::StatusOr<const Command*> While(
+      DeviceMemory<bool> pred, Builder cond_builder, Builder body_builder,
+      absl::Span<const Command* const> dependencies) = 0;
+
+  // Updates a While operation.
+  virtual absl::Status While(const Command* command, DeviceMemory<bool> pred,
+                             Builder cond_builder, Builder body_builder) = 0;
 
   // Submits the command buffer for execution.
   virtual absl::Status Submit(Stream* stream) {
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
index 06c4e3643ff11f..76d19bef40cfcc 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
@@ -503,12 +503,19 @@ absl::Status GpuCommandBuffer::Case(const Command* command,
                /*index_is_bool=*/true, branches);
 }
 
-absl::Status GpuCommandBuffer::While(DeviceMemory<bool> pred,
-                                     Builder cond_builder,
-                                     Builder body_builder) {
+absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::While(
+    DeviceMemory<bool> pred, Builder cond_builder, Builder body_builder,
+    absl::Span<const Command* const> dependencies) {
   if (state_ == State::kCreate) {
     GpuWhileCommand command = {};
 
+    Dependencies barrier = dependencies.empty()
+                               ? GetAutoDependencies()
+                               : ToGraphNodeDependencies(dependencies);
+
+    // TODO(ezhulenev): cond_builder should be able to take dependencies.
+    (void)barrier;
+
     TF_RETURN_IF_ERROR(cond_builder(this));
 
     TF_ASSIGN_OR_RETURN(command.conditional, CreateConditionalHandle());
@@ -529,32 +536,41 @@ absl::Status GpuCommandBuffer::While(DeviceMemory<bool> pred,
                                      body->GetAutoDependencies()));
     TF_RETURN_IF_ERROR(command.conditional_node.command_buffer->Finalize());
 
-    AppendCommand(std::move(command));
-    return absl::OkStatus();
+    return AppendCommand(std::move(command));
   }
 
   if (state_ == State::kUpdate) {
     Command& command = *commands_[update_state_.command_idx++];
-    auto* gpu_command = tsl::down_cast<GpuWhileCommand*>(&command);
+    TF_RETURN_IF_ERROR(While(&command, pred, cond_builder, body_builder));
+    return &command;
+  }
 
-    TF_RETURN_IF_ERROR(cond_builder(this));
+  return UnsupportedStateError(state_);
+}
 
-    TF_RETURN_IF_ERROR(UpdateSetWhileConditionNode(
-        gpu_command->set_init_condition_node, gpu_command->conditional, pred));
+absl::Status GpuCommandBuffer::While(const Command* command,
+                                     DeviceMemory<bool> pred,
+                                     Builder cond_builder,
+                                     Builder body_builder) {
+  auto* gpu_command = tsl::down_cast<const GpuWhileCommand*>(command);
 
-    GpuCommandBuffer* body = gpu_command->conditional_node.command_buffer.get();
-    auto body_update_mode = ActivateUpdateMode(body);
+  TF_RETURN_IF_ERROR(cond_builder(this));
 
-    // Update command buffer using user-provided builder callback.
-    TF_RETURN_IF_ERROR(body->Update());
-    TF_RETURN_IF_ERROR(body_builder(body));
-    TF_RETURN_IF_ERROR(cond_builder(body));
-    TF_RETURN_IF_ERROR(body->UpdateSetWhileConditionNode(
-        gpu_command->set_body_condition_node, gpu_command->conditional, pred));
-    TF_RETURN_IF_ERROR(body->Finalize());
-  }
+  TF_RETURN_IF_ERROR(UpdateSetWhileConditionNode(
+      gpu_command->set_init_condition_node, gpu_command->conditional, pred));
 
-  return UnsupportedStateError(state_);
+  GpuCommandBuffer* body = gpu_command->conditional_node.command_buffer.get();
+  auto body_update_mode = ActivateUpdateMode(body);
+
+  // Update command buffer using user-provided builder callback.
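+  // The update mirrors the creation-time recording order (body commands,
+  // then the loop condition, then the set-while-condition node) so that the
+  // recorded graph nodes are revisited in a stable order.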
+  TF_RETURN_IF_ERROR(body->Update());
+  TF_RETURN_IF_ERROR(body_builder(body));
+  TF_RETURN_IF_ERROR(cond_builder(body));
+  TF_RETURN_IF_ERROR(body->UpdateSetWhileConditionNode(
+      gpu_command->set_body_condition_node, gpu_command->conditional, pred));
+  TF_RETURN_IF_ERROR(body->Finalize());
+
+  return absl::OkStatus();
 }
 
 absl::Status GpuCommandBuffer::Finalize() {
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
index 44c3538724ed85..3feef005f981da 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
@@ -171,8 +171,12 @@ class GpuCommandBuffer : public CommandBuffer {
   absl::Status Case(const Command* command, DeviceMemory<int32_t> index,
                     std::vector<Builder> branches) override;
 
-  absl::Status While(DeviceMemory<bool> pred, Builder cond_builder,
-                     Builder body_builder) override;
+  absl::StatusOr<const Command*> While(
+      DeviceMemory<bool> pred, Builder cond_builder, Builder body_builder,
+      absl::Span<const Command* const> dependencies) override;
+
+  absl::Status While(const Command* command, DeviceMemory<bool> pred,
+                     Builder cond_builder, Builder body_builder) override;
 
   absl::Status Finalize() override;
   absl::Status Update() override;
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
index 7945a424c6b7fb..352d08a5aabc12 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
@@ -60,7 +60,8 @@ using MulI32Kernel =
     TypedKernelFactory<DeviceMemory<int32_t>, DeviceMemory<int32_t>,
                        DeviceMemory<int32_t>>;
 using IncAndCmpKernel =
-    TypedKernelFactory<DeviceMemory<int32_t>, DeviceMemory<bool>, int32_t>;
+    TypedKernelFactory<DeviceMemory<int32_t>, DeviceMemory<bool>,
+                       DeviceMemory<int32_t>>;
 
 using AddI32Ptrs3 = TypedKernelFactory<internal::Ptrs3<int32_t>>;
 
@@ -658,17 +659,17 @@ TEST(GpuCommandBufferTest, ConditionalWhile) {
   // below.
   DeviceMemory<bool> pred = executor->AllocateArray<bool>(1, 0);
   DeviceMemory<int32_t> loop_counter = executor->AllocateArray<int32_t>(1, 0);
+  DeviceMemory<int32_t> num_iters = executor->AllocateArray<int32_t>(1, 0);
   DeviceMemory<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
   DeviceMemory<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
 
   static constexpr bool kFalse = false;
   TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1));
   TF_ASSERT_OK(stream->Memset32(&loop_counter, 0, sizeof(int32_t)));
+  TF_ASSERT_OK(stream->Memset32(&num_iters, 10, sizeof(int32_t)));
   TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length));
   TF_ASSERT_OK(stream->MemZero(&b, byte_length));
 
-  int32_t num_iters = 10;
-
   // Loop cond: loop_counter++ < num_iters;
   CommandBuffer::Builder cond_builder = [&](CommandBuffer* cond_cmd) {
     return cond_cmd
@@ -685,7 +686,7 @@ TEST(GpuCommandBufferTest, ConditionalWhile) {
 
   // Create a command buffer with a single conditional operation.
   auto cmd_buffer = executor->CreateCommandBuffer(primary).value();
-  TF_ASSERT_OK(cmd_buffer->While(pred, cond_builder, body_builder));
+  TF_ASSERT_OK(cmd_buffer->While(pred, cond_builder, body_builder, {}));
   TF_ASSERT_OK(cmd_buffer->Finalize());
 
   TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
@@ -729,6 +730,7 @@ TEST(GpuCommandBufferTest, DISABLED_WhileNestedConditional) {
   DeviceMemory<bool> pred = executor->AllocateArray<bool>(1, 0);
   DeviceMemory<bool> pred_then = executor->AllocateArray<bool>(1, 0);
   DeviceMemory<int32_t> loop_counter = executor->AllocateArray<int32_t>(1, 0);
+  DeviceMemory<int32_t> num_iters = executor->AllocateArray<int32_t>(1, 0);
   DeviceMemory<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
   DeviceMemory<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
 
@@ -737,11 +739,10 @@ TEST(GpuCommandBufferTest, DISABLED_WhileNestedConditional) {
   TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1));
   TF_ASSERT_OK(stream->Memcpy(&pred_then, &kTrue, 1));
   TF_ASSERT_OK(stream->Memset32(&loop_counter, 0, sizeof(int32_t)));
+  TF_ASSERT_OK(stream->Memset32(&num_iters, 10, sizeof(int32_t)));
   TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length));
   TF_ASSERT_OK(stream->MemZero(&b, byte_length));
 
-  int32_t num_iters = 10;
-
   CommandBuffer::Builder then_builder =
       // Then body: b = a + b
       [&](CommandBuffer* then_cmd) {
@@ -770,7 +771,7 @@ TEST(GpuCommandBufferTest, DISABLED_WhileNestedConditional) {
 
   // Create a command buffer with a single conditional operation.
   auto cmd_buffer = executor->CreateCommandBuffer(primary).value();
-  TF_ASSERT_OK(cmd_buffer->While(pred, cond_builder, body_builder));
+  TF_ASSERT_OK(cmd_buffer->While(pred, cond_builder, body_builder, {}));
   TF_ASSERT_OK(cmd_buffer->Finalize());
 
   TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc
index 659a8066a21a22..94a03f569feea6 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc
@@ -36,9 +36,9 @@ __global__ void MulI32(int32_t* a, int32_t* b, int32_t* c) {
   c[index] = a[index] * b[index];
 }
 
-__global__ void IncAndCmp(int32_t* counter, bool* pred, int32_t value) {
+__global__ void IncAndCmp(int32_t* counter, bool* pred, int32_t* value) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
-  pred[index] = counter[index] < value;
+  pred[index] = counter[index] < *value;
   counter[index] += 1;
 }
 

From 47f2afdfb063d3b0fd2898d4b195329c32270c0e Mon Sep 17 00:00:00 2001
From: Kevin Gleason
Date: Tue, 1 Apr 2025 09:34:43 -0700
Subject: [PATCH 0095/1324] Split up shape passes to remove the need for
 Shape->MHLO pass, use Shape->StableHLO

PiperOrigin-RevId: 742732024
---
 .../mlir/quantization/stablehlo/BUILD         |   3 +
 .../stablehlo/cc/pass_pipeline.cc             |  10 +-
 .../convert_shape_constraint_to_assert.cc     | 218 ++++++
 .../quantization/stablehlo/passes/passes.td   |   9 +
 .../passes/shape_cstr_legalize_to_hlo.mlir    | 110 +++
 .../xla/xla/hlo/translate/mhlo_to_hlo/BUILD   |   1 +
 .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc  |   9 +-
 third_party/xla/xla/mlir_hlo/BUILD            |   1 -
 .../mlir_hlo/mhlo/transforms/CMakeLists.txt   |   1 -
 .../mlir_hlo/mhlo/transforms/mhlo_passes.td   |  17 -
 .../xla/xla/mlir_hlo/mhlo/transforms/passes.h |   4 -
 .../shape_legalize_to_hlo.cc                  | 705 ------------------
 .../mhlo/shape_cstr_legalize_to_hlo.mlir      | 110 ---
 .../Dialect/mhlo/shape_legalize_to_hlo.mlir   | 372 ---------
 14 files changed, 351 insertions(+), 1219 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc
 create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir
 delete mode 100644 third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc
 delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir
 delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD
index ec79c4f83f5d26..f40c1371c9df19 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD
@@ -48,6 +48,7 @@ cc_library(
     name = "passes",
     srcs = [
         "passes/convert_func_to_bfloat16.cc",
+        "passes/convert_shape_constraint_to_assert.cc",
         "passes/convert_xla_call_module_op_to_bfloat16.cc",
         "passes/defer_activation_transpose.cc",
        "passes/fold_constant_transpose.cc",
@@ -138,6 +139,7 @@ cc_library(
         "@llvm-project//mlir:Rewrite",
         "@llvm-project//mlir:ShapeDialect",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
         "@llvm-project//mlir:Transforms",
         "@local_tsl//tsl/platform:path",
@@ -150,6 +152,7 @@ cc_library(
         "@local_xla//xla/tsl/protobuf:protos_all_cc",
         "@stablehlo//:chlo_ops",
         "@stablehlo//:stablehlo_ops",
+        "@stablehlo//:stablehlo_passes",
        "@stablehlo//:stablehlo_portable_api",
        "@stablehlo//:stablehlo_serialization",
     ],
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
index 1bbf67389366f5..c5fc8b5b3d8d8e 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
@@ -116,15 +116,13 @@ void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm) {
 }
 
 void AddShapeLegalizationPasses(OpPassManager& pm) {
-  pm.addPass(mhlo::createStablehloLegalizeToHloPass());
+  // TODO: We may need to make a parent pass here that does
+  // shape->StableHLO+cstr because the stablehlo pass requires that the ops made
+  // by cstr are legal.
   pm.addNestedPass<func::FuncOp>(
-      mhlo::createShapeLegalizeToHloPass(/*legalizeConstraints=*/true));
-  // The following 2 passes are used to clean up the spurious UnrealizedCast ops
-  // and shape.assuming regions leftover from the ShapeLegalizeToHlo pass. See
-  // pass definition for details.
+      createConvertShapeToStablehloWithConstraintsPass());
   pm.addPass(createReconcileUnrealizedCastsPass());
   pm.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
-  pm.addPass(mhlo::createHloLegalizeToStablehloPass());
 }
 
 void AddStablehloQuantToIntPasses(OpPassManager& pm) {
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc
new file mode 100644
index 00000000000000..d63dfdeaec7514
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc
@@ -0,0 +1,218 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
+#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/TypeRange.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
+#include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
+#include "stablehlo/transforms/Passes.h"  // from @stablehlo
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h"  // IWYU pragma: keep
+
+namespace mlir::quant::stablehlo {
+
+#define GEN_PASS_DEF_CONVERTSHAPETOSTABLEHLOWITHCONSTRAINTSPASS
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc"
+
+namespace {
+using ::mlir::stablehlo::AndOp;
+using ::mlir::stablehlo::CompareOp;
+using ::mlir::stablehlo::ComparisonDirection;
+using ::mlir::stablehlo::ConcatenateOp;
+using ::mlir::stablehlo::ConstantOp;
+using ::mlir::stablehlo::CustomCallOp;
+using ::mlir::stablehlo::OrOp;
+using ::mlir::stablehlo::ReshapeOp;
+using ::mlir::stablehlo::SliceOp;
+
+// Cast from index-based shape representation used in the Shape dialect to the
+// i32-based representation used in HLO:
+//   * index => tensor<i32>.
+//   * tensor<Nxindex> => tensor<Nxi32>.
+//   * All i32-based types from above => themselves.
+// There is no convenient op that can express this, so we're using
+// unrealized_conversion_cast (with the idea that all these casts will
+// annihilate at the end of the pass).
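+// For example (illustrative IR):
+//   %0 = builtin.unrealized_conversion_cast %dim : index to tensor<i32>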
+Value castToI32(PatternRewriter& rewriter, Location loc, Value value) {
+  Type resultType;
+  if (value.getType().isIndex())
+    resultType = RankedTensorType::get({}, rewriter.getI32Type());
+  if (auto valueType = mlir::dyn_cast<RankedTensorType>(value.getType())) {
+    if (!valueType.hasStaticShape()) return {};
+    if (valueType.getElementType().isInteger(32)) return value;
+    if (valueType.getElementType().isIndex())
+      resultType =
+          RankedTensorType::get(valueType.getShape(), rewriter.getI32Type());
+  }
+  if (!resultType) return {};
+  auto cast =
+      rewriter.create<UnrealizedConversionCastOp>(loc, resultType, value);
+  return cast.getResult(0);
+}
+
+// Pads input tensor<Nxi32> by X ones from the left. The number X is
+// determined by input pad. Result is tensor<(X+N) x i32>, where the first X
+// elements are ones.
+Value padFromLeft(PatternRewriter& rewriter, Location loc, Value input,
+                  int64_t pad) {
+  Value padI32 = rewriter.create<ConstantOp>(
+      loc, DenseIntElementsAttr::get(
+               RankedTensorType::get({pad}, rewriter.getI32Type()), 1));
+  return rewriter.create<ConcatenateOp>(loc, ValueRange{padI32, input},
+                                        /*dimension=*/0);
+}
+
+void insertShapeAssertionCustomCall(OpBuilder builder, Location loc,
+                                    Value assert) {
+  auto customCall =
+      builder.create<CustomCallOp>(loc, TypeRange{}, ValueRange{assert});
+  customCall.setCallTargetName("shape_assertion");
+  customCall.setHasSideEffect(true);
+  customCall->setAttr("error_message",
+                      builder.getStringAttr("Shape assertion failed"));
+}
+
+struct ConvertCstrBroadcastableOp
+    : public OpRewritePattern<shape::CstrBroadcastableOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
+                                PatternRewriter& rewriter) const override {
+    // As defined, op inputs must be 1D tensor or !shape.shape.
+    // We only support inputs of two 1D tensors.
+    if (op.getShapes().size() != 2) return failure();
+    auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front());
+    auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back());
+    if (!shape1 || !shape2) return failure();
+    auto tensorType1 = mlir::dyn_cast<RankedTensorType>(shape1.getType());
+    auto tensorType2 = mlir::dyn_cast<RankedTensorType>(shape2.getType());
+    if (!tensorType1 || !tensorType2) return failure();
+
+    // If the two operand shapes are of different sizes, the smaller one is
+    // padded with 1's from the left.
+    int32_t rank =
+        std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0));
+    if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) {
+      shape1 =
+          padFromLeft(rewriter, op.getLoc(), shape1,
+                      tensorType2.getDimSize(0) - tensorType1.getDimSize(0));
+    } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) {
+      shape2 =
+          padFromLeft(rewriter, op.getLoc(), shape2,
+                      tensorType1.getDimSize(0) - tensorType2.getDimSize(0));
+    }
+
+    // Compute if each dim is broadcastable. A dim is broadcastable iff
+    // dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1
+    auto allOne = rewriter.create<ConstantOp>(
+        op.getLoc(), DenseIntElementsAttr::get(
+                         RankedTensorType::get({rank}, rewriter.getI32Type()),
+                         static_cast<int32_t>(1)));
+    Value dimSize1Is1 = rewriter.create<CompareOp>(op.getLoc(), shape1, allOne,
+                                                   ComparisonDirection::EQ);
+    Value dimSize2Is1 = rewriter.create<CompareOp>(op.getLoc(), shape2, allOne,
+                                                   ComparisonDirection::EQ);
+    Value eitherDimSizeIs1 =
+        rewriter.create<OrOp>(op.getLoc(), dimSize1Is1, dimSize2Is1);
+    Value dimSizeEq = rewriter.create<CompareOp>(op.getLoc(), shape1, shape2,
+                                                 ComparisonDirection::EQ);
+    Value dimBroadcastable =
+        rewriter.create<OrOp>(op.getLoc(), eitherDimSizeIs1, dimSizeEq);
+
+    // Iterate over each dim to check that all dims are broadcastable.
+    auto boolType = RankedTensorType::get({1}, rewriter.getI1Type());
+    Value allBroadcastable = rewriter.create<ConstantOp>(
+        op.getLoc(), DenseIntElementsAttr::get(boolType, true));
+    for (auto i = 0; i < rank; ++i) {
+      Value broadcastable = rewriter.create<SliceOp>(
+          op.getLoc(), dimBroadcastable, rewriter.getDenseI64ArrayAttr(i),
+          rewriter.getDenseI64ArrayAttr(i + 1),
+          rewriter.getDenseI64ArrayAttr(1));
+      allBroadcastable = rewriter.create<AndOp>(op.getLoc(), allBroadcastable,
+                                                broadcastable);
+    }
+    Value allBroadcastableScalar = rewriter.create<ReshapeOp>(
+        op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()),
+        allBroadcastable);
+
+    // Add CustomCallOp and replace Cstr op with const witness, which is useful
+    // for canonicalizer to remove the shape.assuming region.
+    insertShapeAssertionCustomCall(rewriter, op->getLoc(),
+                                   allBroadcastableScalar);
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
+    return success();
+  }
+};
+
+bool hasIndexStyle(Value value) {
+  if (value.getType().isIndex()) return true;
+  auto type = mlir::dyn_cast<ShapedType>(value.getType());
+  return type && type.getElementType().isIndex();
+}
+
+struct ConvertShapeToStablehloWithConstraintsPass
+    : public impl::ConvertShapeToStablehloWithConstraintsPassBase<
+          ConvertShapeToStablehloWithConstraintsPass> {
+  void runOnOperation() override {
+    ConversionTarget target(getContext());
+    target.addIllegalDialect<shape::ShapeDialect>();
+    target.addIllegalDialect<tensor::TensorDialect>();
+    target.addIllegalOp<arith::IndexCastOp>();
+    target.addIllegalOp<arith::MulIOp>();
+    target.addDynamicallyLegalDialect<::mlir::stablehlo::StablehloDialect>(
+        [](Operation* op) {
+          return !llvm::any_of(op->getOperands(), hasIndexStyle);
+        });
+    target.addLegalOp<tensor::CastOp>();
+    target.addLegalOp<UnrealizedConversionCastOp>();
+    target.addLegalOp<shape::ConstWitnessOp, shape::AssumingOp,
+                      shape::AssumingYieldOp>();
+
+    RewritePatternSet patterns(&getContext());
+    ::mlir::stablehlo::populateShapeToStablehloPatterns(&getContext(),
+                                                        &patterns);
+
+    patterns.add<ConvertCstrBroadcastableOp>(&getContext());
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+}  // namespace
+}  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
index da59c218a56926..e6108ca6d13e02 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
@@ -153,6 +153,15 @@ def ConvertXlaCallModuleOpToBfloat16Pass : Pass<"stablehlo-convert-xla-call-modu
   ];
 }
 
+def ConvertShapeToStablehloWithConstraintsPass : Pass<"stablehlo-convert-shape-to-stablehlo-with-constraints", "mlir::func::FuncOp"> {
+  let summary = "Convert shape.cstr_broadcastable to stablehlo.custom_call @shape_assertion";
+  let dependentDialects = [
+    "mlir::shape::ShapeDialect",
+    "mlir::tensor::TensorDialect",
+    "mlir::stablehlo::StablehloDialect",
+  ];
+}
+
 def OptimizeGraphPass : Pass<"optimize-graph", "ModuleOp"> {
   let summary = "Optimize the sub-optimal patterns after quantization.";
   let dependentDialects = ["mlir::stablehlo::StablehloDialect",];
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir
new file mode 100644
index 00000000000000..ac7d6a51fb87b1
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir
@@ -0,0 +1,110 @@
+// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-convert-shape-to-stablehlo-with-constraints --verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @shape_cstr_broadcastable
+func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) {
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex>
+  shape.assuming %0 {
+  }
+  func.return
+  // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32>
+  // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32>
+  // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32>
+  // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1>
+  // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1>
+  // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense<true> : tensor<1xi1>
+  // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1>
+  // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1>
+  // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1>
+  // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1>
+  // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1>
+  // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> ()
+  // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true
+  // CHECK-NEXT: shape.assuming %[[WITNESS]] {
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_1
+func.func @shape_cstr_broadcastable_different_dims_1(%arg0: tensor<2xindex>, %arg1: tensor<1xindex>) {
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<1xindex>
+  shape.assuming %0 {
+  }
+  func.return
+  // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32>
+  // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<1xindex> to tensor<1xi32>
+  // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32>
+  // CHECK-NEXT: %[[DIMS2_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32>
+  // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1>
+  // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2_PAD]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1>
+  // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense<true> : tensor<1xi1>
+  // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1>
+  // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1>
+  // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1>
+  // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1>
+  // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1>
+  // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> ()
+  // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true
+  // CHECK-NEXT: shape.assuming %[[WITNESS]] {
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_2
+func.func @shape_cstr_broadcastable_different_dims_2(%arg0: tensor<1xindex>, %arg1: tensor<2xindex>) {
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<1xindex>, tensor<2xindex>
+  shape.assuming %0 {
+  }
+  func.return
+  // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<1xindex> to tensor<1xi32>
+  // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32>
+  // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32>
+  // CHECK-NEXT: %[[DIMS1_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32>
+  // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1>
+  // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1>
+  // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense<true> : tensor<1xi1>
+  // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1>
+  // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1>
+  // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1>
+  // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1>
+  // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1>
+  // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> ()
+  // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true
+  // CHECK-NEXT: shape.assuming %[[WITNESS]] {
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+}
+
+// -----
+
+func.func @shape_cstr_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) {
+  // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}}
+  %0 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex>
+  shape.assuming %0 {
+  }
+  func.return
+}
+
+// -----
+
+func.func @shape_cstr_broadcastable_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) {
+  // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}}
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
+  shape.assuming %0 {
+  }
+  func.return
+}
diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD
index 0a471554a7acc5..ecd5ac12539432 100644
--- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD
+++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD
@@ -181,6 +181,7 @@ cc_library(
         "@local_tsl//tsl/platform:ml_dtypes",
         "@stablehlo//:base",
         "@stablehlo//:stablehlo_ops",
+        "@stablehlo//:stablehlo_passes",
     ],
 )
 
diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
index 26ae589ac5e008..42b41b974c0e69 100644
--- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
+++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "mhlo/transforms/passes.h"
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -63,6 +64,7 @@ limitations under the License.
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "stablehlo/dialect/Base.h"
+#include "stablehlo/transforms/Passes.h"
 #include "xla/array.h"
 #include "xla/comparison_util.h"
 #include "xla/debug_options_flags.h"
@@ -4102,16 +4104,17 @@ absl::Status PrepareForExport(mlir::ModuleOp module) {
   // Only bounded dynamism is planned to be supported; unbounded dynamism
   // is out of scope for now.
   //
+  // Shape -> MHLO
   // Currently takes overhead if input is MHLO for MHLO->StableHLO, can
   // be deleted once conversion can assume StableHLO input.
   mlir::mhlo::HloLegalizeToStablehloPassOptions options;
   options.allow_xla_features_ = true;
-  pm.addPass(mhlo::createHloLegalizeToStablehloPass(options));
   pm.addNestedPass<mlir::func::FuncOp>(
       stablehlo_ext::createSymbolicShapeOptimizationPass());
+  pm.addPass(mhlo::createHloLegalizeToStablehloPass(options));
+  pm.addNestedPass<mlir::func::FuncOp>(
+      stablehlo::createShapeLegalizeToStablehloPass());
   pm.addPass(mhlo::createStablehloLegalizeToHloPass());
-
-  pm.addNestedPass<mlir::func::FuncOp>(mhlo::createShapeLegalizeToHloPass());
 }
 
 mlir::BaseScopedDiagnosticHandler handler(module.getContext());
diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD
index 6ced829ae19bd5..8f612af8ef5d75 100644
--- a/third_party/xla/xla/mlir_hlo/BUILD
+++ b/third_party/xla/xla/mlir_hlo/BUILD
@@ -383,7 +383,6 @@ cc_library(
         "mhlo/transforms/mhlo_passes.h.inc",
         "mhlo/transforms/optimize_mhlo/optimize_mhlo.cc",
         "mhlo/transforms/prepare_for_export/prepare_for_export.cc",
-        "mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc",
        "mhlo/transforms/shape_simplification/shape_simplification.cc",
         "mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc",
         "mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc",
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt
index 249531d74f7e31..15a015db03cc72 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt
@@ -43,7 +43,6 @@ add_mlir_library(MhloPasses
   mhlo_flatten_tuple/mhlo_flatten_tuple.cc
   prepare_for_export/prepare_for_export.cc
   optimize_mhlo/optimize_mhlo.cc
-  shape_legalize_to_hlo/shape_legalize_to_hlo.cc
   shape_simplification/shape_simplification.cc
   sink_constants_to_control_flow/sink_constants_to_control_flow.cc
   test_infer_shaped_type/test_infer_shaped_type_pass.cc
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td
index f3fc6cdec3a579..ab456fcaaea616 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td
@@ -218,20 +218,3 @@ def PrepareForExportPass : Pass<"xla-prepare-for-export", "mlir::func::FuncOp">
     canonicalization may undo transformations.
   }];
 }
-
-def ShapeLegalizeToHloPass : Pass<"shape-legalize-to-hlo", "func::FuncOp"> {
-  let summary = "Legalize shape-related ops to HLO.";
-  let constructor = "createShapeLegalizeToHloPass()";
-  let description = [{
-    An experimental pass that legalizes shape-related ops to MHLO ops.
-
-    Bringing shape and data computations together via an optional pass will
-    make it possible for the MHLO ecosystem to potentially leverage the
-    compilation pipelines that use HLO operations to model dynamism.
-  }];
-  let dependentDialects = ["mhlo::MhloDialect"];
-  let options = [
-    Option<"legalize_constraints_", "legalize-constraints", "bool",
-           /*default=*/"false", "Whether to legalize Cstr Ops to shape_assertion custom_call">
-  ];
-}
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h
index 271f7299e6f04b..4791d9a94550e1 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h
@@ -89,10 +89,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertToSignlessPass();
 // Legalizes from the StableHLO dialect to the MHLO dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createStablehloLegalizeToHloPass();
 
-// Legalizes from the Shape dialect to the MHLO dialect.
-std::unique_ptr<OperationPass<func::FuncOp>> createShapeLegalizeToHloPass(
-    bool legalizeConstraints = false);
-
 // Test passes.
 std::unique_ptr<Pass> createTestInferShapedTypeMethodsPass();
 std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc
deleted file mode 100644
index e00300f7a38f22..00000000000000
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc
+++ /dev/null
@@ -1,705 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "llvm/ADT/SmallVector.h"
-#include "mhlo/IR/hlo_ops.h"
-#include "mhlo/transforms/passes.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/DialectRegistry.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/TypeID.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace mhlo {
-
-#define GEN_PASS_DEF_SHAPELEGALIZETOHLOPASS
-#include "mhlo/transforms/mhlo_passes.h.inc"
-
-namespace {
-
-bool hasI32Style(Value value) {
-  auto type = mlir::dyn_cast<RankedTensorType>(value.getType());
-  return type && type.getElementType().isInteger(32);
-}
-
-// Cast from the index-based shape representation used in the Shape dialect to
-// the i32-based representation used in HLO:
-//   * index => tensor<i32>.
-//   * tensor<Nxindex> => tensor<Nxi32>.
-//   * All i32-based types from above => themselves.
-// There is no convenient op that can express this, so we're using
-// unrealized_conversion_cast (with the idea that all these casts will
-// annihilate at the end of the pass).
-Value castToI32(PatternRewriter& rewriter, Location loc, Value value) {
-  Type resultType;
-  if (value.getType().isIndex())
-    resultType = RankedTensorType::get({}, rewriter.getI32Type());
-  if (auto valueType = mlir::dyn_cast<RankedTensorType>(value.getType())) {
-    if (!valueType.hasStaticShape()) return {};
-    if (valueType.getElementType().isInteger(32)) return value;
-    if (valueType.getElementType().isIndex())
-      resultType =
-          RankedTensorType::get(valueType.getShape(), rewriter.getI32Type());
-  }
-  if (!resultType) return {};
-  auto cast =
-      rewriter.create<UnrealizedConversionCastOp>(loc, resultType, value);
-  return cast.getResult(0);
-}
-
-bool hasIndexStyle(Value value) {
-  if (value.getType().isIndex()) return true;
-  auto type = mlir::dyn_cast<ShapedType>(value.getType());
-  return type && type.getElementType().isIndex();
-}
-
-// Cast from the i32-based shape representation used in HLO to the index-based
-// representation used in the Shape dialect:
-//   * tensor<i32> => index.
-//   * tensor<Nxi32> => tensor<Nxindex>.
-//   * All index-based types from above => themselves.
-// There is no convenient op that can express this, so we're using
-// unrealized_conversion_cast (with the idea that all these casts will
-// annihilate at the end of the pass).
-Value castToIndex(PatternRewriter& rewriter, Location loc, Value value) {
-  Type resultType;
-  if (value.getType().isIndex()) return value;
-  if (auto valueType = mlir::dyn_cast<RankedTensorType>(value.getType())) {
-    if (!valueType.hasStaticShape()) return {};
-    if (valueType.getElementType().isInteger(32)) {
-      if (valueType.getRank() == 0) {
-        resultType = rewriter.getIndexType();
-      } else {
-        resultType = RankedTensorType::get(valueType.getShape(),
-                                           rewriter.getIndexType());
-      }
-    }
-    if (valueType.getElementType().isIndex()) return value;
-  }
-  if (!resultType) return {};
-  auto cast =
-      rewriter.create<UnrealizedConversionCastOp>(loc, resultType, value);
-  return cast.getResult(0);
-}
-
-void insertShapeAssertionCustomCall(OpBuilder builder, Location loc,
-                                    Value assert) {
-  auto customCall =
-      builder.create<CustomCallOp>(loc, TypeRange{}, ValueRange{assert});
-  customCall.setCallTargetName("shape_assertion");
-  customCall.setHasSideEffect(true);
-  customCall->setAttr("error_message",
-                      builder.getStringAttr("Shape assertion failed"));
-}
-
-struct ConvertNumElementsOpPattern
-    : public OpRewritePattern<shape::NumElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(shape::NumElementsOp op,
-                                PatternRewriter& rewriter) const override {
-    // Cast shape from tensor<Nxindex> to tensor<Nxi32>.
-    // This will error out if shape is !shape.shape.
-    auto shapeI32 = castToI32(rewriter, op.getLoc(), op.getShape());
-    if (!shapeI32) return rewriter.notifyMatchFailure(op, "cast to i32 failed");
-    auto rank =
-        mlir::cast<RankedTensorType>(shapeI32.getType()).getNumElements();
-
-    // Compute the product of the individual dimension sizes.
-    // Using this representation instead of mhlo::ReduceOp because it is more
-    // amenable to optimizations. (Reduce can be folded only if the entire
-    // shape is static, but individual multiplications can be folded if
-    // individual dimensions are static).
-    auto resultI32Type = RankedTensorType::get({}, rewriter.getI32Type());
-    Value resultI32 = rewriter.create<ConstantOp>(
-        op.getLoc(), DenseIntElementsAttr::get(resultI32Type, 1));
-    for (auto i = 0; i < rank; ++i) {
-      auto sizeI32x1 = rewriter.create<SliceOp>(
-          op.getLoc(), shapeI32, rewriter.getI64TensorAttr(i),
-          rewriter.getI64TensorAttr(i + 1), rewriter.getI64TensorAttr(1));
-      auto sizeI32 =
-          rewriter.create<ReshapeOp>(op.getLoc(), resultI32Type, sizeI32x1);
-      resultI32 = rewriter.create<MulOp>(op.getLoc(), resultI32, sizeI32);
-    }
-
-    // Cast result from tensor<i32> to index.
-    // This will error out if the result is !shape.size.
-    auto resultIndex = castToIndex(rewriter, op.getLoc(), resultI32);
-    if (!resultIndex || resultIndex.getType() != op.getResult().getType())
-      return rewriter.notifyMatchFailure(op, "cast to index failed");
-    rewriter.replaceOp(op, resultIndex);
-    return success();
-  }
-};
-
-struct ConvertShapeOfOpPattern : public OpRewritePattern<shape::ShapeOfOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
-                                PatternRewriter& rewriter) const override {
-    auto operandType = dyn_cast<RankedTensorType>(op.getArg().getType());
-    if (!operandType)
-      return rewriter.notifyMatchFailure(op, "expected ranked operand");
-
-    // Produce an MHLO equivalent of this shape::ShapeOfOp.
-    // This is a very laborious representation because MHLO is currently lacking
-    // convenient tools to express this.
-    Value shapeI32;
-    if (operandType.getRank() > 0) {
-      SmallVector<Value> sizesI32x1;
-      for (auto i = 0; i < operandType.getRank(); ++i) {
-        auto sizeI32 =
-            rewriter.create<GetDimensionSizeOp>(op.getLoc(), op.getArg(), i);
-        auto sizeI32x1 = rewriter.create<ReshapeOp>(
-            op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
-            sizeI32);
-        sizesI32x1.push_back(sizeI32x1);
-      }
-      shapeI32 = rewriter.create<ConcatenateOp>(op.getLoc(), sizesI32x1,
-                                                /*dimension=*/0);
-    } else {
-      shapeI32 = rewriter.create<ConstantOp>(
-          op.getLoc(), DenseElementsAttr::get(
-                           RankedTensorType::get({0}, rewriter.getI32Type()),
-                           ArrayRef<int32_t>()));
-    }
-
-    // Cast result from tensor<Nxi32> to tensor<Nxindex>.
-    // This will error out if the result is !shape.shape.
-    auto shapeIndex = castToIndex(rewriter, op.getLoc(), shapeI32);
-    if (!shapeIndex || shapeIndex.getType() != op.getType())
-      return rewriter.notifyMatchFailure(op, "cast to index failed");
-    rewriter.replaceOp(op, shapeIndex);
-    return success();
-  }
-};
-
-struct ConvertConstShapeOpPattern
-    : public OpRewritePattern<shape::ConstShapeOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(shape::ConstShapeOp op,
-                                PatternRewriter& rewriter) const override {
-    auto operandType =
-        mlir::dyn_cast<RankedTensorType>(op.getResult().getType());
-    if (!operandType)
-      return rewriter.notifyMatchFailure(op, "expected ranked operand");
-
-    llvm::SmallVector<int32_t> shape;
-    for (int i : op.getShape().getValues<int64_t>()) {
-      shape.push_back(i);
-    }
-    auto newConst = rewriter.create<ConstantOp>(
-        op.getLoc(), DenseElementsAttr::get(
-                         RankedTensorType::get({operandType.getDimSize(0)},
-                                               rewriter.getI32Type()),
-                         ArrayRef<int32_t>(shape)));
-    auto newConstIndex = castToIndex(rewriter, op.getLoc(), newConst);
-    rewriter.replaceOp(op, newConstIndex);
-    return success();
-  }
-};
-
-struct ConvertIndexCastOpPattern : public OpRewritePattern<arith::IndexCastOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(arith::IndexCastOp op,
-                                PatternRewriter& rewriter) const override {
-    Value result = op.getIn();
-    if (hasIndexStyle(op.getIn()) &&
-        !mlir::isa<ShapedType>(op.getIn().getType())) {
-      // Handle a special case of index -> i64.
-      // This is converted to the following sequence:
-      //   unrealized_conversion_cast index -> tensor<i32>
-      //   mhlo.convert tensor<i32> -> tensor<i64>
-      //   unrealized_conversion_cast tensor<i64> -> i64
-      result = castToI32(rewriter, op.getLoc(), result);
-      if (!op.getOut().getType().isInteger(32)) {
-        result = rewriter.create<ConvertOp>(op.getLoc(), result,
-                                            op.getOut().getType());
-      }
-      rewriter.replaceOp(op, rewriter.create<UnrealizedConversionCastOp>(
-                                 op.getLoc(), op.getOut().getType(), result));
-      return success();
-    }
-    if (!mlir::isa<ShapedType>(op.getIn().getType()) &&
-        hasIndexStyle(op.getOut())) {
-      // Handle a special case of i32 -> index.
-      // This is converted to the following sequence:
-      //   unrealized_conversion_cast i32 -> tensor<i32>
-      //   unrealized_conversion_cast tensor<i32> -> index
-      result = rewriter
-                   .create<UnrealizedConversionCastOp>(
-                       op.getLoc(), RankedTensorType::get({}, result.getType()),
-                       result)
-                   .getResult(0);
-      rewriter.replaceOp(op, rewriter.create<UnrealizedConversionCastOp>(
-                                 op.getLoc(), op.getOut().getType(), result));
-      return success();
-    }
-
-    if (hasIndexStyle(result)) {
-      result = castToI32(rewriter, op.getLoc(), result);
-    } else if (!hasI32Style(result)) {
-      return rewriter.notifyMatchFailure(op,
-                                         "expected input with index/i32 style");
-    }
-
-    if (hasIndexStyle(op.getOut())) {
-      result = castToIndex(rewriter, op.getLoc(), result);
-    } else if (!hasI32Style(op.getOut())) {
-      return rewriter.notifyMatchFailure(
-          op, "expected output with index/i32 style");
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-struct ConvertMulIOpPattern : public OpRewritePattern<arith::MulIOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(arith::MulIOp op,
-                                PatternRewriter& rewriter) const override {
-    // We only handle index types.
-    if (!hasIndexStyle(op.getLhs()) || !hasIndexStyle(op.getRhs()) ||
-        !hasIndexStyle(op.getResult())) {
-      return rewriter.notifyMatchFailure(op, "expected index type");
-    }
-    Value lhs = op.getLhs();
-    if (auto constIndex =
-            dyn_cast_or_null<arith::ConstantIndexOp>(lhs.getDefiningOp())) {
-      lhs = rewriter.create<ConstantOp>(
-          op.getLoc(), DenseIntElementsAttr::get(
-                           RankedTensorType::get({}, rewriter.getI32Type()),
-                           static_cast<int32_t>(constIndex.value())));
-    } else {
-      lhs = castToI32(rewriter, op.getLoc(), op.getLhs());
-    }
-    Value rhs = op.getRhs();
-    if (auto constIndex =
-            dyn_cast_or_null<arith::ConstantIndexOp>(rhs.getDefiningOp())) {
-      rhs = rewriter.create<ConstantOp>(
-          op.getLoc(), DenseIntElementsAttr::get(
-                           RankedTensorType::get({}, rewriter.getI32Type()),
-                           static_cast<int32_t>(constIndex.value())));
-    } else {
-      rhs = castToI32(rewriter, op.getLoc(), op.getRhs());
-    }
-    Value result = rewriter.create<MulOp>(op.getLoc(), lhs, rhs);
-    rewriter.replaceOp(op, castToIndex(rewriter, op.getLoc(), result));
-    return success();
-  }
-};
-
-// Pads input tensor<Nxi32> by X ones from the left. The number X is
-// determined by input pad. Result is tensor<(X+N) x i32>, where the first X
-// elements are ones.
-Value padFromLeft(PatternRewriter& rewriter, Location loc, Value input,
-                  int64_t pad) {
-  Value padI32 = rewriter.create<ConstantOp>(
-      loc, DenseIntElementsAttr::get(
-               RankedTensorType::get({pad}, rewriter.getI32Type()), 1));
-  return rewriter.create<ConcatenateOp>(loc, ValueRange{padI32, input},
-                                        /*dimension=*/0);
-}
-
-struct ConvertShapeBroadcastOpPattern
-    : public OpRewritePattern<shape::BroadcastOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(shape::BroadcastOp op,
-                                PatternRewriter& rewriter) const override {
-    // As defined, op inputs must be 1D tensor or !shape.shape.
-    // We only support inputs of two input 1D tensors.
-    if (op.getShapes().size() != 2) return failure();
-    auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front());
-    auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back());
-    if (!shape1 || !shape2) return failure();
-    auto tensorType1 = mlir::dyn_cast<RankedTensorType>(shape1.getType());
-    auto tensorType2 = mlir::dyn_cast<RankedTensorType>(shape2.getType());
-    if (!tensorType1 || !tensorType2) return failure();
-
-    // If the two operand shapes are of different sizes, the smaller one is
-    // padded with 1's from the left.
-    if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) {
-      shape1 =
-          padFromLeft(rewriter, op.getLoc(), shape1,
-                      tensorType2.getDimSize(0) - tensorType1.getDimSize(0));
-    } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) {
-      shape2 =
-          padFromLeft(rewriter, op.getLoc(), shape2,
-                      tensorType1.getDimSize(0) - tensorType2.getDimSize(0));
-    }
-
-    // By definition, broadcasted dims are:
-    //   result[i] = lhs[i] if lhs[i] == rhs[i]
-    //             = lhs[i] if rhs[i] == 1
-    //             = rhs[i] if lhs[i] == 1
-    //
-    // We assume that there is shape.cstr_broadcastable check done elsewhere to
-    // make sure the shapes are broadcastable, then we can calculate broadcast
-    // result simply using MaxOp. In case the shapes are not broadcastable, the
-    // result extent tensor is undefined according to spec. So this
-    // implementation is technically correct.
-    auto broadcasted =
-        rewriter.create<MaxOp>(op->getLoc(), shape1, shape2);
-
-    auto broadcastedIndex = castToIndex(rewriter, op.getLoc(), broadcasted);
-    if (!broadcastedIndex ||
-        broadcastedIndex.getType() != op.getResult().getType())
-      return rewriter.notifyMatchFailure(op, "cast to index failed");
-    rewriter.replaceOp(op, broadcastedIndex);
-    return success();
-  }
-};
-
-struct ConvertTensorDimPattern : public OpRewritePattern<tensor::DimOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tensor::DimOp op,
-                                PatternRewriter& rewriter) const override {
-    // We only support getting static index.
-    auto constIndex =
-        dyn_cast_or_null<arith::ConstantIndexOp>(op.getIndex().getDefiningOp());
-    if (!constIndex) {
-      return failure();
-    }
-
-    auto dim = rewriter.create<GetDimensionSizeOp>(
-        op->getLoc(), op.getSource(), constIndex.value());
-    auto dimIndex = castToIndex(rewriter, op.getLoc(), dim);
-    rewriter.replaceOp(op, dimIndex);
-    return success();
-  }
-};
-
-struct ConvertTensorExtractPattern
-    : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tensor::ExtractOp op,
-                                PatternRewriter& rewriter) const override {
-    SmallVector<int64_t> indices;
-    auto tensorType = op.getTensor().getType();
-    // We only support getting static indices.
-    for (auto index : op.getIndices()) {
-      auto constIndex =
-          dyn_cast_or_null<arith::ConstantIndexOp>(index.getDefiningOp());
-      if (!constIndex)
-        return rewriter.notifyMatchFailure(op, "expected constant index op");
-
-      // Check if the index is out of range.
-      int idx = indices.size();
-      if (tensorType.isDynamicDim(idx) ||
-          constIndex.value() >= tensorType.getDimSize(idx))
-        return rewriter.notifyMatchFailure(op, "index out of range");
-
-      indices.push_back(constIndex.value());
-    }
-    auto input = castToI32(rewriter, op.getLoc(), op.getTensor());
-    auto startIndices = DenseIntElementsAttr::get(
-        RankedTensorType::get({static_cast<int64_t>(indices.size())},
-                              rewriter.getI64Type()),
-        indices);
-    for (auto& index : indices) {
-      index += 1;
-    }
-    auto limitIndices = DenseIntElementsAttr::get(
-        RankedTensorType::get({static_cast<int64_t>(indices.size())},
-                              rewriter.getI64Type()),
-        indices);
-
-    Value extractedTensor = rewriter.create<SliceOp>(
-        op.getLoc(), input, startIndices, limitIndices,
-        /*strides=*/
-        DenseIntElementsAttr::get(
-            RankedTensorType::get({static_cast<int64_t>(indices.size())},
-                                  rewriter.getI64Type()),
-            1));
-    Value extractedScalarTensor = rewriter.create<ReshapeOp>(
-        op.getLoc(), RankedTensorType::get({}, rewriter.getI32Type()),
-        extractedTensor);
-    if (getElementTypeOrSelf(op.getResult().getType()).isIndex()) {
-      auto extractedIndex =
-          castToIndex(rewriter, op.getLoc(), extractedScalarTensor);
-      rewriter.replaceOp(op, extractedIndex);
-    } else {
-      // For the special case when the input is an i32 tensor and output is i32,
-      // convert the result back to i32 to be consistent:
-      //   unrealized_conversion_cast tensor<i32> -> i32
-      rewriter.replaceOp(op, rewriter.create<UnrealizedConversionCastOp>(
-                                 op.getLoc(), op.getResult().getType(),
-                                 extractedScalarTensor));
-    }
-    return success();
-  }
-};
-
-struct ConvertTensorFromElementsPattern
-    : public OpRewritePattern<tensor::FromElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tensor::FromElementsOp op,
-                                PatternRewriter& rewriter) const override {
-    auto tensorType =
-        mlir::dyn_cast_or_null<RankedTensorType>(op.getResult().getType());
-    if (!tensorType) {
-      return failure();
-    }
-    if (tensorType.getRank() == 0) {
-      // Handle the special case of tensor.from_elements i64 -> tensor<i64>.
-      // This is converted to unrealized_conversion_cast i64 -> tensor<i64>,
-      // which is later cancelled with the previous unrealized_conversion_cast
-      // op.
-      rewriter.replaceOp(
-          op, rewriter.create<UnrealizedConversionCastOp>(
-                  op.getLoc(), op.getResult().getType(), op.getElements()[0]));
-      return success();
-    }
-
-    // We only handle 1D tensor with index types. tensor.from_elements spec
-    // allows the same element type only for all input/output.
-    if (tensorType.getRank() != 1) return failure();
-    if (!hasIndexStyle(op.getResult())) return failure();
-
-    SmallVector<Value> elementI32x1;
-    for (size_t i = 0; i < op.getElements().size(); ++i) {
-      if (auto constIndex = dyn_cast_or_null<arith::ConstantIndexOp>(
-              op.getElements()[i].getDefiningOp())) {
-        elementI32x1.push_back(rewriter.create<ConstantOp>(
-            op.getLoc(), DenseIntElementsAttr::get(
-                             RankedTensorType::get({1}, rewriter.getI32Type()),
-                             static_cast<int32_t>(constIndex.value()))));
-      } else {
-        elementI32x1.push_back(rewriter.create<ReshapeOp>(
-            op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
-            castToI32(rewriter, op->getLoc(), op.getElements()[i])));
-      }
-    }
-    Value tensorI32 =
-        rewriter.create<ConcatenateOp>(op.getLoc(), elementI32x1,
-                                       /*dimension=*/0);
-
-    tensorI32 = hasI32Style(op.getResult())
-                    ? tensorI32
-                    : castToIndex(rewriter, op.getLoc(), tensorI32);
-    if (!tensorI32 || tensorI32.getType() != op.getResult().getType())
-      return rewriter.notifyMatchFailure(op, "cast to index failed");
-    rewriter.replaceOp(op, tensorI32);
-    return success();
-  }
-};
-
-struct ConvertCstrBroadcastableOp
-    : public OpRewritePattern<shape::CstrBroadcastableOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
-                                PatternRewriter& rewriter) const override {
-    // As defined, op inputs must be 1D tensor or !shape.shape.
-    // We only support inputs of two 1D tensors.
-    if (op.getShapes().size() != 2) return failure();
-    auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front());
-    auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back());
-    if (!shape1 || !shape2) return failure();
-    auto tensorType1 = mlir::dyn_cast<RankedTensorType>(shape1.getType());
-    auto tensorType2 = mlir::dyn_cast<RankedTensorType>(shape2.getType());
-    if (!tensorType1 || !tensorType2) return failure();
-
-    // If the two operand shapes are of different sizes, the smaller one is
-    // padded with 1's from the left.
-    int32_t rank =
-        std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0));
-    if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) {
-      shape1 =
-          padFromLeft(rewriter, op.getLoc(), shape1,
-                      tensorType2.getDimSize(0) - tensorType1.getDimSize(0));
-    } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) {
-      shape2 =
-          padFromLeft(rewriter, op.getLoc(), shape2,
-                      tensorType1.getDimSize(0) - tensorType2.getDimSize(0));
-    }
-
-    // Compute if each dim is broadcastable. A dim is broadcastable iff
-    // dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1
-    auto allOne = rewriter.create<ConstantOp>(
-        op.getLoc(), DenseIntElementsAttr::get(
-                         RankedTensorType::get({rank}, rewriter.getI32Type()),
-                         static_cast<int32_t>(1)));
-    Value dimSize1Is1 = rewriter.create<CompareOp>(
-        op.getLoc(), shape1, allOne, ComparisonDirection::EQ);
-    Value dimSize2Is1 = rewriter.create<CompareOp>(
-        op.getLoc(), shape2, allOne, ComparisonDirection::EQ);
-    Value eitherDimSizeIs1 =
-        rewriter.create<OrOp>(op.getLoc(), dimSize1Is1, dimSize2Is1);
-    Value dimSizeEq = rewriter.create<CompareOp>(
-        op.getLoc(), shape1, shape2, ComparisonDirection::EQ);
-    Value dimBroadcastable =
-        rewriter.create<OrOp>(op.getLoc(), eitherDimSizeIs1, dimSizeEq);
-
-    // Iterate over each dim to check that all dims are broadcastable.
-    auto boolType = RankedTensorType::get({1}, rewriter.getI1Type());
-    Value allBroadcastable = rewriter.create<ConstantOp>(
-        op.getLoc(), DenseIntElementsAttr::get(boolType, true));
-    for (auto i = 0; i < rank; ++i) {
-      Value broadcastable = rewriter.create<SliceOp>(
-          op.getLoc(), dimBroadcastable, rewriter.getI64TensorAttr(i),
-          rewriter.getI64TensorAttr(i + 1), rewriter.getI64TensorAttr(1));
-      allBroadcastable =
-          rewriter.create<AndOp>(op.getLoc(), allBroadcastable, broadcastable);
-    }
-    Value allBroadcastableScalar = rewriter.create<ReshapeOp>(
-        op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()),
-        allBroadcastable);
-
-    // Add CustomCallOp and replace Cstr op with const witness, which is useful
-    // for canonicalizer to remove the shape.assuming region.
-    insertShapeAssertionCustomCall(rewriter, op->getLoc(),
-                                   allBroadcastableScalar);
-    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
-    return success();
-  }
-};
-
-template <typename OpType>
-struct CastOperandsPattern : public OpRewritePattern<OpType> {
-  using OpRewritePattern<OpType>::OpRewritePattern;
-  LogicalResult matchAndRewrite(OpType op,
-                                PatternRewriter& rewriter) const override {
-    if (!llvm::any_of(op->getOperands(), hasIndexStyle))
-      return rewriter.notifyMatchFailure(op, "no operands need a cast to i32");
-
-    // If op has operands of type tensor<Nxindex>, cast them to tensor<Nxi32>.
-    // If producers of these operands have been transformed into casts from
-    // tensor<Nxi32> to tensor<Nxindex>, then these casts will annihilate with
-    // each other upon canonicalization.
-    SmallVector<Value> operandsI32;
-    for (auto operand : op->getOperands()) {
-      if (hasIndexStyle(operand)) {
-        operandsI32.push_back(castToI32(rewriter, op.getLoc(), operand));
-      } else {
-        operandsI32.push_back(operand);
-      }
-    }
-
-    rewriter.replaceOpWithNewOp<OpType>(op, op->getResultTypes(), operandsI32,
-                                        op->getAttrs());
-    return success();
-  }
-};
-
-// TODO(b/264240901): Comprehensively support shape computations to the extent
-// needed to support bounded dynamism in MHLO export.
-struct ShapeLegalizeToHloPass
-    : public impl::ShapeLegalizeToHloPassBase<ShapeLegalizeToHloPass> {
-  explicit ShapeLegalizeToHloPass(bool legalizeConstraints)
-      : impl::ShapeLegalizeToHloPassBase<
-            ShapeLegalizeToHloPass>::ShapeLegalizeToHloPassBase() {
-    this->legalize_constraints_ = legalizeConstraints;
-  }
-
-  void runOnOperation() override {
-    // In order to make dynamic MHLO programs compatible with HLO,
-    // we need to get rid of all non-MHLO ops.
-    //
-    // As an example, a cursory inspection of the TF/XLA bridge, which provides
-    // one data point of an MHLO producer that can generate dynamic MHLO
-    // programs, reveals the following non-MHLO ops:
-    //   * shape.broadcast
-    //   * shape.concat
-    //   * shape.cstr_broadcastable
-    //   * shape.cstr_eq
-    //   * shape.dim
-    //   * shape.split_at
-    //   * shape.to_extent_tensor
-    //   * shape.assuming
-    //   * shape.assuming_yield
-    //   * tensor.dim
-    //   * tensor.extract
-    //   * tensor.from_elements
-    //
-    // Most of these ops are convertible to MHLO, although the representation is
-    // going to be pretty laborious for many of them. Luckily, canonicalization
-    // is able to remove unnecessary cruft. At the moment, this pass is a
-    // work in progress, so not all of these ops are supported.
-    //
-    // When legalize_constraints_ is set true, cstr* ops are also legalized.
-    // A shape_assertion custom_call is used to check the constraint. And the
-    // shape.assuming region will consume a shape.const_witness that evaluate to
-    // true, so that it can be removed later in a canonicalizer pass.
-    ConversionTarget target(getContext());
-    target.addIllegalDialect<shape::ShapeDialect>();
-    target.addIllegalDialect<tensor::TensorDialect>();
-    target.addIllegalOp<arith::IndexCastOp>();
-    target.addIllegalOp<arith::MulIOp>();
-    target.addDynamicallyLegalDialect<MhloDialect>([](Operation* op) {
-      return !llvm::any_of(op->getOperands(), hasIndexStyle);
-    });
-    target.addLegalOp<tensor::CastOp>();
-    target.addLegalOp<UnrealizedConversionCastOp>();
-    if (this->legalize_constraints_) {
-      target.addLegalOp<shape::ConstWitnessOp, shape::AssumingOp,
-                        shape::AssumingYieldOp>();
-    }
-
-    // The patterns do what one might expect, converting between MLIR-style
-    // and HLO-style shape computations.
-    //
-    // The only complication is that MLIR style uses index/tensor<Nxindex>
-    // whereas HLO style uses tensor<i32>/vararg of tensor<i32>. We bridge
-    // this gap by producing unrealized_conversion_cast ops, which we expect
-    // to ultimately annihilate with each other upon canonicalization if
-    // everything went right.
-    RewritePatternSet patterns(&getContext());
-    patterns.add<ConvertConstShapeOpPattern>(&getContext());
-    patterns.add<ConvertMulIOpPattern>(&getContext());
-    patterns.add<ConvertIndexCastOpPattern>(&getContext());
-    patterns.add<ConvertNumElementsOpPattern>(&getContext());
-    patterns.add<ConvertShapeOfOpPattern>(&getContext());
-    patterns.add<ConvertShapeBroadcastOpPattern>(&getContext());
-    patterns.add<CastOperandsPattern<DynamicBroadcastInDimOp>>(&getContext());
-    patterns.add<CastOperandsPattern<DynamicReshapeOp>>(&getContext());
-    patterns.add<ConvertTensorDimPattern>(&getContext());
-    patterns.add<ConvertTensorExtractPattern>(&getContext());
-    patterns.add<ConvertTensorFromElementsPattern>(&getContext());
-    if (this->legalize_constraints_) {
-      patterns.add<ConvertCstrBroadcastableOp>(&getContext());
-    }
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<OperationPass<func::FuncOp>> createShapeLegalizeToHloPass(
-    bool legalizeConstraints) {
-  return std::make_unique<ShapeLegalizeToHloPass>(legalizeConstraints);
-}
-
-}  // namespace mhlo
-}  // namespace mlir
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir
deleted file mode 100644
index fe28bbe8f4977b..00000000000000
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir
+++ /dev/null
@@ -1,110 +0,0 @@
-// RUN: mlir-hlo-opt --shape-legalize-to-hlo=legalize-constraints=true --split-input-file --verify-diagnostics %s | FileCheck %s
-
-// -----
-
-// CHECK-LABEL: func.func @shape_cstr_broadcastable
-func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) {
-  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex>
-  shape.assuming %0 {
-  }
-  func.return
-  // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32>
-  // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32>
-  // CHECK-NEXT: %[[ONES:.*]] = mhlo.constant dense<1> : tensor<2xi32>
-  // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
-  // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = mhlo.compare EQ, %[[DIMS2]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
-  // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = mhlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1>
-  // CHECK-NEXT: %[[DIMS_EQ:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[DIMS2]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
-  // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = mhlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1>
-  // CHECK-NEXT: %[[TRUE:.*]] = mhlo.constant dense<true> : tensor<1xi1>
-  // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1>
-  // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = mhlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1>
-  // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1>
-  // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = mhlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1>
-  // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = mhlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1>
-  // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> ()
-  // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true
-  // CHECK-NEXT: shape.assuming %[[WITNESS]] {
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return
-}
-
-// -----
-
-func.func @shape_cstr_broadcastable_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) {
-  // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}}
-  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
-  shape.assuming %0 {
-  }
-  func.return
-}
-
-// -----
-
-func.func @shape_cstr_broadcastable_different_dims_1(%arg0: tensor<2xindex>, %arg1: tensor<1xindex>) {
-  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<1xindex>
-  shape.assuming %0 {
-  }
-  func.return
-  // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32>
-  // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<1xindex> to tensor<1xi32>
-  // CHECK-NEXT: %[[PAD:.*]] = mhlo.constant dense<1> : tensor<1xi32>
-  // CHECK-NEXT: %[[DIMS2_PAD:.*]] = "mhlo.concatenate"(%[[PAD]], %[[DIMS2]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
-  // CHECK-NEXT: %[[ONES:.*]] = mhlo.constant dense<1> : tensor<2xi32>
-  // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
-  // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = mhlo.compare EQ, %[[DIMS2_PAD]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
-  // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = mhlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1>
-  // CHECK-NEXT: %[[DIMS_EQ:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[DIMS2_PAD]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
-  // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = mhlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1>
-  // CHECK-NEXT: %[[TRUE:.*]] = mhlo.constant dense<true> : tensor<1xi1>
-  // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1>
-  // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = mhlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1>
-  // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1>
-  // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = mhlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1>
-  // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = mhlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1>
-  // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> ()
-  // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true
-  // CHECK-NEXT: shape.assuming %[[WITNESS]] {
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return
-}
-
-// -----
-
-func.func @shape_cstr_broadcastable_different_dims_2(%arg0: tensor<1xindex>, %arg1: tensor<2xindex>) {
-  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<1xindex>, tensor<2xindex>
-  shape.assuming %0 {
-  }
-  func.return
-  // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<1xindex> to tensor<1xi32>
-  // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32>
-  // CHECK-NEXT: %[[PAD:.*]] = mhlo.constant dense<1> : tensor<1xi32>
-  // CHECK-NEXT: %[[DIMS1_PAD:.*]] = "mhlo.concatenate"(%[[PAD]], %[[DIMS1]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
"mhlo.concatenate"(%[[PAD]], %[[DIMS1]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - // CHECK-NEXT: %[[ONES:.*]] = mhlo.constant dense<1> : tensor<2xi32> - // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = mhlo.compare EQ, %[[DIMS1_PAD]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = mhlo.compare EQ, %[[DIMS2]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = mhlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> - // CHECK-NEXT: %[[DIMS_EQ:.*]] = mhlo.compare EQ, %[[DIMS1_PAD]], %[[DIMS2]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = mhlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> - // CHECK-NEXT: %[[TRUE:.*]] = mhlo.constant dense : tensor<1xi1> - // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1> - // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = mhlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> - // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1> - // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = mhlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> - // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = mhlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor - // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () - // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true - // CHECK-NEXT: shape.assuming %[[WITNESS]] { - // CHECK-NEXT: } - // CHECK-NEXT: return -} - -// ----- - -func.func @shape_cstr_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) { - // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} - %0 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> - shape.assuming %0 { - } - func.return -} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir deleted file mode 100644 index f60b70b316a0c9..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir +++ /dev/null @@ -1,372 +0,0 @@ -// RUN: mlir-hlo-opt --shape-legalize-to-hlo --split-input-file --verify-diagnostics %s | FileCheck %s - -// CHECK-LABEL: func.func @num_elements_tensor_to_index -func.func @num_elements_tensor_to_index(%arg0: tensor<2xindex>) -> index { - %0 = shape.num_elements %arg0 : tensor<2xindex> -> index - func.return %0 : index - // CHECK: %[[ARG0_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> - // CHECK-NEXT: %[[TMP0:.*]] = mhlo.constant dense<1> : tensor - // CHECK-NEXT: %[[SIZE0x1:.*]] = "mhlo.slice"(%[[ARG0_I32]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> - // CHECK-NEXT: %[[SIZE0:.*]] = mhlo.reshape %[[SIZE0x1]] : (tensor<1xi32>) -> tensor - 
-  // CHECK-NEXT: %[[TMP1:.*]] = mhlo.multiply %[[TMP0]], %[[SIZE0]] : tensor<i32>
-  // CHECK-NEXT: %[[SIZE1x1:.*]] = "mhlo.slice"(%[[ARG0_I32]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32>
-  // CHECK-NEXT: %[[SIZE1:.*]] = mhlo.reshape %[[SIZE1x1]] : (tensor<1xi32>) -> tensor<i32>
-  // CHECK-NEXT: %[[RESULT_I32:.*]] = mhlo.multiply %[[TMP1]], %[[SIZE1]] : tensor<i32>
-  // CHECK-NEXT: %[[RESULT_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RESULT_I32]] : tensor<i32> to index
-  // CHECK-NEXT: return %[[RESULT_INDEX]] : index
-}
-
-// -----
-
-func.func @num_elements_shape_to_xxx(%arg0: !shape.shape) -> !shape.size {
-  // expected-error@+1 {{failed to legalize operation 'shape.num_elements' that was explicitly marked illegal}}
-  %0 = shape.num_elements %arg0 : !shape.shape -> !shape.size
-  func.return %0 : !shape.size
-}
-
-// -----
-
-func.func @num_elements_xxx_to_size(%arg0: tensor<2xindex>) -> !shape.size {
-  // expected-error@+1 {{failed to legalize operation 'shape.num_elements' that was explicitly marked illegal}}
-  %0 = shape.num_elements %arg0 : tensor<2xindex> -> !shape.size
-  func.return %0 : !shape.size
-}
-
-// -----
-
-// CHECK-LABEL: func.func @shape_of_ranked
-func.func @shape_of_ranked_to_index(%arg0: tensor<?x8xf32>) -> tensor<2xindex> {
-  %0 = shape.shape_of %arg0 : tensor<?x8xf32> -> tensor<2xindex>
-  func.return %0 : tensor<2xindex>
-  // CHECK: %[[SIZE0x1:.*]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor<?x8xf32>) -> tensor<i32>
-  // CHECK-NEXT: %[[SIZE0:.*]] = mhlo.reshape %[[SIZE0x1]] : (tensor<i32>) -> tensor<1xi32>
-  // CHECK-NEXT: %[[SIZE1x1:.*]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor<?x8xf32>) -> tensor<i32>
-  // CHECK-NEXT: %[[SIZE1:.*]] = mhlo.reshape %[[SIZE1x1]] : (tensor<i32>) -> tensor<1xi32>
-  // CHECK-NEXT: %[[RESULT_I32:.*]] = "mhlo.concatenate"(%[[SIZE0]], %[[SIZE1]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
-  // CHECK-NEXT: %[[RESULT_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RESULT_I32]] : tensor<2xi32> to tensor<2xindex>
-  // CHECK-NEXT: return %[[RESULT_INDEX]] : tensor<2xindex>
-}
-
-// -----
-
-func.func @shape_of_unranked_to_xxx(%arg0: tensor<*xf32>) -> tensor<?xindex> {
-  // expected-error@+1 {{failed to legalize operation 'shape.shape_of' that was explicitly marked illegal}}
-  %0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
-  func.return %0 : tensor<?xindex>
-}
-
-// -----
-
-func.func @shape_of_ranked_to_shape(%arg0: tensor<?x8xf32>) -> !shape.shape {
-  // expected-error@+1 {{failed to legalize operation 'shape.shape_of' that was explicitly marked illegal}}
-  %0 = shape.shape_of %arg0 : tensor<?x8xf32> -> !shape.shape
-  func.return %0 : !shape.shape
-}
-
-// -----
-
-// CHECK-LABEL: func.func @tensor_dim
-func.func @tensor_dim(%arg0: tensor<?x8xf32>) -> index {
-  %c0 = arith.constant 0 : index
-  %dim = tensor.dim %arg0, %c0 : tensor<?x8xf32>
-  func.return %dim : index
-  // CHECK: %[[DIM_SIZE:.*]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor<?x8xf32>) -> tensor<i32>
-  // CHECK-NEXT: %[[DIM_SIZE_INDEX:.*]] = builtin.unrealized_conversion_cast %[[DIM_SIZE]] : tensor<i32> to index
-  // CHECK-NEXT: return %[[DIM_SIZE_INDEX]] : index
-}
-
-// -----
-
-func.func @tensor_dim_dynamic(%arg0: tensor<?x8xf32>, %arg1: index) -> index {
-  // expected-error@+1 {{failed to legalize operation 'tensor.dim' that was explicitly marked illegal}}
-  %dim = tensor.dim %arg0, %arg1 : tensor<?x8xf32>
-  func.return %dim : index
-}
-
-// -----
-
-// CHECK-LABEL: func.func @tensor_from_elements
-func.func @tensor_from_elements(%arg0: index) -> tensor<2xindex> { - %c0 = arith.constant 0 : index - %0 = tensor.from_elements %arg0, %c0 : tensor<2xindex> - func.return %0 : tensor<2xindex> - // CHECK: %[[ELEMENT1_SCALAR:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor - // CHECK-NEXT: %[[ELEMENT1:.*]] = mhlo.reshape %[[ELEMENT1_SCALAR]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[ELEMENT2:.*]] = mhlo.constant dense<0> : tensor<1xi32> - // CHECK-NEXT: %[[CONCAT:.*]] = "mhlo.concatenate"(%[[ELEMENT1]], %[[ELEMENT2]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - // CHECK-NEXT: %[[CONCAT_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CONCAT]] : tensor<2xi32> to tensor<2xindex> - // CHECK-NEXT: return %[[CONCAT_INDEX]] : tensor<2xindex> -} - -// ----- - -func.func @tensor_from_elements_i8(%arg0: i8) -> tensor<2xi8> { - %c0 = arith.constant 0 : i8 - // expected-error@+1 {{failed to legalize operation 'tensor.from_elements' that was explicitly marked illegal}} - %0 = tensor.from_elements %arg0, %c0 : tensor<2xi8> - func.return %0 : tensor<2xi8> -} - -// ----- - -// CHECK-LABEL: func.func @tensor_from_elements_scalar -func.func @tensor_from_elements_scalar(%arg0: i64) -> tensor { - %0 = tensor.from_elements %arg0 : tensor - func.return %0 : tensor - // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %arg0 : i64 to tensor - // CHECK-NEXT: return %[[RESULT]] : tensor -} - -// ----- - -func.func @tensor_from_elements_rank2(%arg0: index) -> tensor<2x1xindex> { - %c0 = arith.constant 0 : index - // expected-error@+1 {{failed to legalize operation 'tensor.from_elements' that was explicitly marked illegal}} - %0 = tensor.from_elements %arg0, %c0 : tensor<2x1xindex> - func.return %0 : tensor<2x1xindex> -} - -// ----- - -// CHECK-LABEL: func.func @shape_broadcast -func.func @shape_broadcast(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>) -> tensor<4xindex> { - %0 = shape.broadcast %arg0, %arg1 : tensor<4xindex>, tensor<4xindex> -> tensor<4xindex> - func.return %0 : tensor<4xindex> - // CHECK: %[[LHS:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<4xindex> to tensor<4xi32> - // CHECK-NEXT: %[[RHS:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<4xindex> to tensor<4xi32> - // CHECK-NEXT: %[[BROADCAST:.*]] = mhlo.maximum %[[LHS]], %[[RHS]] : tensor<4xi32> - // CHECK-NEXT: %[[BROADCAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[BROADCAST]] : tensor<4xi32> to tensor<4xindex> - // CHECK-NEXT: return %[[BROADCAST_INDEX]] : tensor<4xindex> -} - -// ----- - -func.func @shape_broadcast_different_dims(%arg0: tensor<4xindex>, %arg1: tensor<6xindex>) -> tensor<6xindex> { - %0 = shape.broadcast %arg0, %arg1 : tensor<4xindex>, tensor<6xindex> -> tensor<6xindex> - func.return %0 : tensor<6xindex> - // CHECK: %[[LHS:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<4xindex> to tensor<4xi32> - // CHECK-NEXT: %[[RHS:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<6xindex> to tensor<6xi32> - // CHECK-NEXT: %[[PAD:.*]] = mhlo.constant dense<1> : tensor<2xi32> - // CHECK-NEXT: %[[LHS_PAD:.*]] = "mhlo.concatenate"(%[[PAD]], %[[LHS]]) <{dimension = 0 : i64}> : (tensor<2xi32>, tensor<4xi32>) -> tensor<6xi32> - // CHECK-NEXT: %[[BROADCAST:.*]] = mhlo.maximum %[[LHS_PAD]], %[[RHS]] : tensor<6xi32> - // CHECK-NEXT: %[[BROADCAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[BROADCAST]] : tensor<6xi32> to tensor<6xindex> - // CHECK-NEXT: return %[[BROADCAST_INDEX]] : tensor<6xindex> -} - -// ----- - -func.func 
@shape_broadcast_result_shape(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>) -> !shape.shape { - // expected-error@+1 {{failed to legalize operation 'shape.broadcast' that was explicitly marked illegal}} - %0 = shape.broadcast %arg0, %arg1 : tensor<4xindex>, tensor<4xindex> -> !shape.shape - func.return %0 : !shape.shape -} - -// ----- - -func.func @shape_broadcast_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) -> !shape.shape { - // expected-error@+1 {{failed to legalize operation 'shape.broadcast' that was explicitly marked illegal}} - %0 = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> !shape.shape - func.return %0 : !shape.shape -} - -// ----- - -func.func @shape_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) -> tensor<4xindex> { - // expected-error@+1 {{failed to legalize operation 'shape.broadcast' that was explicitly marked illegal}} - %0 = shape.broadcast %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> -> tensor<4xindex> - func.return %0 : tensor<4xindex> -} - -// ----- - -func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) -> !shape.witness { - // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} - %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex> - func.return %0 : !shape.witness -} - -// ----- - -// CHECK-LABEL: func @const_shape -func.func @const_shape() -> tensor<2xindex> { - %0 = shape.const_shape [6, 4] : tensor<2xindex> - return %0 : tensor<2xindex> - // CHECK: %[[CST:.*]] = mhlo.constant dense<[6, 4]> : tensor<2xi32> - // CHECK-NEXT: %[[CST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CST]] : tensor<2xi32> to tensor<2xindex> - // CHECK-NEXT: return %[[CST_INDEX]] : tensor<2xindex> -} - -// ----- - -// CHECK-LABEL: func @index_cast_index_to_i32 -func.func @index_cast_index_to_i32(%arg0: tensor<2xindex>) -> tensor<2xi32> { - %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi32> - return %0 : tensor<2xi32> - // CHECK-NEXT: %[[CST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> - // CHECK-NEXT: return %[[CST_I32]] : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: func @index_cast_i32_to_index -func.func @index_cast_i32_to_index(%arg0: tensor<2xi32>) -> tensor<2xindex> { - %0 = arith.index_cast %arg0 : tensor<2xi32> to tensor<2xindex> - return %0 : tensor<2xindex> - // CHECK-NEXT: %[[CST_INDEX:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xi32> to tensor<2xindex> - // CHECK-NEXT: return %[[CST_INDEX]] : tensor<2xindex> -} - -// ----- - -// CHECK-LABEL: func @index_cast_scalar_index_to_i32 -func.func @index_cast_scalar_index_to_i32(%arg0: index) -> i32 { - // CHECK: %[[CAST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor - // CHECK-NEXT: %[[CAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CAST_I32]] : tensor to i32 - // CHECK-NEXT: return %[[CAST_INDEX]] : i32 - %0 = arith.index_cast %arg0 : index to i32 - return %0 : i32 -} - -// ----- - -// CHECK-LABEL: func @index_cast_scalar_index_to_i64 -func.func @index_cast_scalar_index_to_i64(%arg0: index) -> i64 { - // CHECK: %[[CAST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor - // CHECK-NEXT: %[[CONVERT:.*]] = mhlo.convert %[[CAST_I32]] : (tensor) -> tensor - // CHECK-NEXT: %[[CAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CONVERT]] : tensor to i64 - // CHECK-NEXT: return %[[CAST_INDEX]] : 
i64 - %0 = arith.index_cast %arg0 : index to i64 - return %0 : i64 -} - -// ----- - -func.func @index_cast_scalar_i32_to_index(%arg0: i32) -> index { - // CHECK: %[[CAST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : i32 to tensor - // CHECK-NEXT: %[[CAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CAST_I32]] : tensor to index - // CHECK-NEXT: return %[[CAST_INDEX]] : index - %0 = arith.index_cast %arg0 : i32 to index - return %0 : index -} - -// ----- - -func.func @index_cast_index_to_i8(%arg0: tensor<2xindex>) -> tensor<2xi8> { - // expected-error@+1 {{failed to legalize operation 'arith.index_cast' that was explicitly marked illegal}} - %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi8> - return %0 : tensor<2xi8> -} - -// ----- - -func.func @index_cast_i8_to_index(%arg0: tensor<2xi8>) -> tensor<2xindex> { - // expected-error@+1 {{failed to legalize operation 'arith.index_cast' that was explicitly marked illegal}} - %0 = arith.index_cast %arg0 : tensor<2xi8> to tensor<2xindex> - return %0 : tensor<2xindex> -} - - -// ----- - -// CHECK-LABEL: func @muli -func.func @muli(%arg0: index, %arg1: index) -> index { - %0 = arith.muli %arg0, %arg1 : index - return %0 : index - // CHECK: %[[LHS:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor - // CHECK-NEXT: %[[RHS:.*]] = builtin.unrealized_conversion_cast %arg1 : index to tensor - // CHECK-NEXT: %[[RES:.*]] = mhlo.multiply %[[LHS]], %[[RHS]] : tensor - // CHECK-NEXT: %[[RES_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RES]] : tensor to index - // CHECK-NEXT: return %[[RES_INDEX]] : index -} - -// ----- - -// CHECK-LABEL: func @muli_const -func.func @muli_const() -> index { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %0 = arith.muli %c1, %c2 : index - return %0 : index - // CHECK: %[[LHS:.*]] = mhlo.constant dense<1> : tensor - // CHECK-NEXT: %[[RHS:.*]] = mhlo.constant dense<2> : tensor - // CHECK-NEXT: %[[RES:.*]] = mhlo.multiply %[[LHS]], %[[RHS]] : tensor - // CHECK-NEXT: %[[RES_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RES]] : tensor to index - // CHECK-NEXT: return %[[RES_INDEX]] : index -} - -// ----- - -func.func @muli_i32(%arg0: i32, %arg1: i32) -> i32 { - // expected-error@+1 {{failed to legalize operation 'arith.muli' that was explicitly marked illegal}} - %0 = arith.muli %arg0, %arg1 : i32 - return %0 : i32 -} - -// ----- - -// CHECK-LABEL: func @tensor_extract -func.func @tensor_extract(%arg0: tensor<3x3xindex>) -> index { - %c1 = arith.constant 0 : index - %c2 = arith.constant 1 : index - %0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xindex> - return %0 : index - // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<3x3xindex> to tensor<3x3xi32> - // CHECK-NEXT: %[[SLICE:.*]] = "mhlo.slice"(%[[CAST]]) - // CHECK-SAME: limit_indices = dense<[1, 2]> : tensor<2xi64> - // CHECK-SAME: start_indices = dense<[0, 1]> : tensor<2xi64> - // CHECK-SAME: strides = dense<1> : tensor<2xi64> - // CHECK-SAME: (tensor<3x3xi32>) -> tensor<1x1xi32> - // CHECK-NEXT: %[[RESHAPE:.*]] = mhlo.reshape %[[SLICE]] : (tensor<1x1xi32>) -> tensor - // CHECK-NEXT: %[[RES_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE]] : tensor to index - // CHECK-NEXT: return %[[RES_INDEX]] : index -} - -// ----- - -// CHECK-LABEL: func @tensor_extract_i32 -func.func @tensor_extract_i32(%arg0: tensor<3x3xi32>) -> i32 { - %c1 = arith.constant 0 : index - %c2 = arith.constant 1 : index - %0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xi32> - return %0 : i32 - // 
CHECK: %[[SLICE:.*]] = "mhlo.slice"(%arg0) - // CHECK-SAME: limit_indices = dense<[1, 2]> : tensor<2xi64> - // CHECK-SAME: start_indices = dense<[0, 1]> : tensor<2xi64> - // CHECK-SAME: strides = dense<1> : tensor<2xi64> - // CHECK-SAME: (tensor<3x3xi32>) -> tensor<1x1xi32> - // CHECK-NEXT: %[[RESHAPE:.*]] = mhlo.reshape %[[SLICE]] : (tensor<1x1xi32>) -> tensor - // CHECK-NEXT: %[[RES_I32:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE]] : tensor to i32 - // CHECK-NEXT: return %[[RES_I32]] : i32 -} - -// ----- - -func.func @tensor_extract_out_of_range(%arg0: tensor<3x3xindex>) -> index { - %c1 = arith.constant 4 : index - %c2 = arith.constant 4 : index - // expected-error@+1 {{failed to legalize operation 'tensor.extract' that was explicitly marked illegal}} - %0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xindex> - return %0 : index -} - -// ----- - -func.func @tensor_extract_dynamic(%arg0: tensor) -> index { - %c1 = arith.constant 0 : index - %c2 = arith.constant 2 : index - // expected-error@+1 {{failed to legalize operation 'tensor.extract' that was explicitly marked illegal}} - %0 = tensor.extract %arg0[%c1, %c2] : tensor - return %0 : index -} - -// ----- - -// CHECK-LABEL: func @shape_of_zero_ranked_tensor -func.func @shape_of_zero_ranked_tensor(%arg0 : tensor) -> tensor<0xindex> { - // CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi32> - // CHECK-NEXT: %[[RES_DIM0_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CONST]] : tensor<0xi32> to tensor<0xindex> - // CHECK-NEXT: return %[[RES_DIM0_INDEX]] : tensor<0xindex> - %0 = shape.shape_of %arg0 : tensor -> tensor<0xindex> - func.return %0 : tensor<0xindex> -} - From 9acbcb3b7bea605422d17a2c546eb9a793280699 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Tue, 1 Apr 2025 09:54:59 -0700 Subject: [PATCH 0096/1324] [XLA:GPU/TMA] Adjustments to TritonXLA ops in preparation for using them in the generic Triton Emitter. 
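In summary: `triton_xla.tile` drops its explicit size operands (tile sizes are
carried by the result `!triton_xla.tiled_tensor` type), offsets and strides
move from i32/i64 to `index`, and a new `layout` attribute carries the
minor-to-major order. A minimal sketch of the resulting syntax, adapted from
the updated `ops.mlir` tests below; `%src` and the `array<i64: 1, 0>` layout
value are illustrative, not quoted verbatim from this diff:

```mlir
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
// Tile a 512x128 tensor into 16x64 tiles; sizes come from the result type.
%tiled = triton_xla.tile %src [%c0, %c0] [%c128, %c1]
    {layout = array<i64: 1, 0>} : !triton_xla.tiled_tensor<16x64|512x128xbf16>
// Materialize one tile, then write it back at the same offsets.
%tile = triton_xla.extract %tiled [%c0, %c0]
    : tensor<512x128xbf16> to tensor<16x64xbf16>
%out = triton_xla.insert %tile into %tiled [%c0, %c0]
    : tensor<16x64xbf16> into tensor<512x128xbf16>
```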
PiperOrigin-RevId: 742739235 --- .../gpu/codegen/triton/emitter_helpers.cc | 15 ++ .../gpu/codegen/triton/emitter_helpers.h | 4 + .../gpu/codegen/triton/ir/tests/invalid.mlir | 24 ++- .../gpu/codegen/triton/ir/tests/ops.mlir | 16 +- .../gpu/codegen/triton/ir/triton_xla_attrs.cc | 9 +- .../codegen/triton/ir/triton_xla_dialect.td | 8 - .../gpu/codegen/triton/ir/triton_xla_ops.cc | 36 ++-- .../gpu/codegen/triton/ir/triton_xla_ops.td | 32 +--- .../gpu/codegen/triton/transforms/BUILD | 12 +- .../gpu/codegen/triton/transforms/passes.td | 1 + .../triton_xla_extract_insert_to_triton.mlir | 60 +++--- ...riton_xla_extract_insert_to_triton_pass.cc | 176 ++++++++++++++---- 12 files changed, 235 insertions(+), 158 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc index c2f0952c13e61e..e231a2c6a9e934 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc @@ -110,6 +110,21 @@ absl::StatusOr TritonType(EmitterLocOpBuilder& b, PrimitiveType t) { } } +absl::StatusOr GetPrimitiveType(Type t) { + if (t.isF64()) return F64; + if (t.isF32()) return F32; + if (t.isF16()) return F16; + if (t.isBF16()) return BF16; + if (t.isInteger(64)) return S64; + if (t.isInteger(32)) return S32; + if (t.isInteger(16)) return S16; + if (t.isInteger(8)) return S8; + if (t.isInteger(1)) return PRED; + if (mlir::isa(t)) return F8E5M2; + if (mlir::isa(t)) return F8E4M3FN; + return absl::UnimplementedError("Unsupported type in getPrimitiveType.\n"); +} + Type StorageType(EmitterLocOpBuilder& b, Type t) { if (t.isInteger(1)) { return b.getI8Type(); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h index 0ab584d5f7bb6e..20d2babaf5f3d1 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -101,6 +102,9 @@ llvm::SmallVector GetPaddedTileSizes( // XLA -> Triton type conversions. absl::StatusOr TritonType(EmitterLocOpBuilder& b, PrimitiveType t); +// Triton type -> XLA type conversions. +absl::StatusOr GetPrimitiveType(mlir::Type t); + mlir::Type StorageType(EmitterLocOpBuilder& b, mlir::Type t); // Get the value of the scalar constant's literal in a C++ type. diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/invalid.mlir b/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/invalid.mlir index b89fb9726f9741..0ba8c41d32efc8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/invalid.mlir +++ b/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/invalid.mlir @@ -2,12 +2,10 @@ // TODO(manany): Fix issue with this test. 
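// Context for the verifier tests below (see TileOp::verify later in this
// patch): the tensor rank must match the number of offset and stride operands,
// while tile sizes now come from the result !triton_xla.tiled_tensor type and
// are no longer operands.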
tt.func @tile_mismatch_rank(%arg0: tensor<256x256xbf16>) { - %cst_0_i32 = arith.constant 0 : i32 - %cst_0_i64 = arith.constant 0 : i64 - %cst_64_i64 = arith.constant 64 : i64 - // expected-error @+1 {{mismatch between tensor rank and one or more of offsets/sizes/strides}} - %tiled_tensor = triton_xla.tile %arg0 [%cst_0_i32][%cst_64_i64][%cst_0_i64] - : !triton_xla.tiled_tensor<16x64|256x256xbf16> + %cst_0 = arith.constant 0 : index + // expected-error @+1 {{mismatch between tensor rank and one or more of offsets and strides}} + %tiled_tensor = triton_xla.tile %arg0 [%cst_0][%cst_0] + {layout = array} : !triton_xla.tiled_tensor<16x64|256x256xbf16> tt.return } @@ -15,7 +13,7 @@ tt.func @tile_mismatch_rank(%arg0: tensor<256x256xbf16>) { tt.func @extract_mismatch_rank( %arg0: !triton_xla.tiled_tensor<16x64|256x256xbf16>) { - %cst = arith.constant 0 : i32 + %cst = arith.constant 0 : index // expected-error @+1 {{source tensor rank does not match number of offsets}} %extracted_tensor = triton_xla.extract %arg0 [%cst] : tensor<256x256xbf16> to tensor<16x64xbf16> @@ -27,7 +25,7 @@ tt.func @extract_mismatch_rank( tt.func @insert_mismatch_rank( %arg0: tensor<16x64xbf16>, %arg1: !triton_xla.tiled_tensor<16x64|256x256xbf16>) { - %cst = arith.constant 0 : i32 + %cst = arith.constant 0 : index // expected-error @+1 {{destination tensor rank does not match number of offsets}} %inserted_tensor = triton_xla.insert %arg0 into %arg1 [%cst,%cst,%cst] : tensor<16x64xbf16> into tensor<256x256xbf16> @@ -39,7 +37,7 @@ tt.func @insert_mismatch_rank( "tt.func"() <{function_type = (tensor) -> !triton_xla.tiled_tensor<|bf16>, sym_name = "xla_triton_tile"}> ({ ^bb0(%arg0: tensor): // expected-error @+1 {{cannot tile a 0-d tensor}} - %0 = "triton_xla.tile"(%arg0) : (tensor) -> !triton_xla.tiled_tensor<|bf16> + %0 = "triton_xla.tile"(%arg0) {layout = array} : (tensor) -> !triton_xla.tiled_tensor<|bf16> "tt.return"(%0) : (!triton_xla.tiled_tensor<|bf16>) -> () }) : () -> () @@ -47,9 +45,9 @@ tt.func @insert_mismatch_rank( "tt.func"() <{function_type = (!triton_xla.tiled_tensor<|bf16>) -> tensor, sym_name = "xla_triton_extract"}> ({ ^bb0(%arg0: !triton_xla.tiled_tensor<|bf16>): - %0 = "arith.constant"() <{value = 0 : i32}> : () -> i32 + %0 = "arith.constant"() <{value = 0 : index}> : () -> index // expected-error @+1 {{cannot extract a 0-d tensor}} - %1 = "triton_xla.extract"(%arg0, %0, %0) : (!triton_xla.tiled_tensor<|bf16>, i32, i32) -> tensor + %1 = "triton_xla.extract"(%arg0, %0, %0) : (!triton_xla.tiled_tensor<|bf16>, index, index) -> tensor "tt.return"(%1) : (tensor) -> () }) : () -> () @@ -57,8 +55,8 @@ tt.func @insert_mismatch_rank( "tt.func"() <{function_type = (tensor, !triton_xla.tiled_tensor<|bf16>) -> tensor, sym_name = "xla_triton_insert"}> ({ ^bb0(%arg0: tensor, %arg1: !triton_xla.tiled_tensor<|bf16>): - %0 = "arith.constant"() <{value = 0 : i32}> : () -> i32 + %0 = "arith.constant"() <{value = 0 : index}> : () -> index // expected-error @+1 {{cannot insert a 0-d tensor}} - %1 = "triton_xla.insert"(%arg0, %arg1, %0, %0) : (tensor, !triton_xla.tiled_tensor<|bf16>, i32, i32) -> tensor + %1 = "triton_xla.insert"(%arg0, %arg1, %0, %0) : (tensor, !triton_xla.tiled_tensor<|bf16>, index, index) -> tensor "tt.return"(%1) : (tensor) -> () }) : () -> () diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/ops.mlir b/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/ops.mlir index 3417192c5301b2..fc9757387f32c8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/ops.mlir +++ 
b/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/ops.mlir @@ -7,13 +7,11 @@ tt.func @xla_triton_tile(%arg0: tensor<512x128xbf16>) -> !triton_xla.tiled_tensor<16x64|512x128xbf16> { - %cst_0 = arith.constant 0 : i32 - %cst_1 = arith.constant 1 : i64 - %cst_16 = arith.constant 16 : i64 - %cst_64 = arith.constant 64 : i64 - %cst_128 = arith.constant 128 : i64 - %tiled_tensor = triton_xla.tile %arg0 [%cst_0, %cst_0] [%cst_16, %cst_64] [%cst_128, %cst_1] - : !triton_xla.tiled_tensor<16x64|512x128xbf16> + %cst_0 = arith.constant 0 : index + %cst_1 = arith.constant 1 : index + %cst_128 = arith.constant 128 : index + %tiled_tensor = triton_xla.tile %arg0 [%cst_0, %cst_0] [%cst_128, %cst_1] + {layout = array} : !triton_xla.tiled_tensor<16x64|512x128xbf16> tt.return %tiled_tensor : !triton_xla.tiled_tensor<16x64|512x128xbf16> } // CHECK-LABEL: xla_triton_tile @@ -23,7 +21,7 @@ tt.func @xla_triton_tile(%arg0: tensor<512x128xbf16>) tt.func @xla_triton_extract(%arg0: !triton_xla.tiled_tensor<16x64|512x128xbf16>) -> tensor<16x64xbf16> { - %cst = arith.constant 0 : i32 + %cst = arith.constant 0 : index %extracted_tensor = triton_xla.extract %arg0 [%cst, %cst] : tensor<512x128xbf16> to tensor<16x64xbf16> tt.return %extracted_tensor : tensor<16x64xbf16> @@ -35,7 +33,7 @@ tt.func @xla_triton_extract(%arg0: !triton_xla.tiled_tensor<16x64|512x128xbf16>) tt.func @xla_triton_insert(%src: tensor<16x64xbf16>, %dst: !triton_xla.tiled_tensor<16x64|512x128xbf16>) -> tensor<512x128xbf16> { - %cst = arith.constant 0 : i32 + %cst = arith.constant 0 : index %updated_tensor = triton_xla.insert %src into %dst [%cst, %cst] : tensor<16x64xbf16> into tensor<512x128xbf16> tt.return %updated_tensor : tensor<512x128xbf16> diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc index 489096478c828b..b11c6665ce5e55 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc @@ -13,20 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/ErrorHandling.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpDefinition.h" // IWYU pragma: keep +#include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" -#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" -#include "triton/Tools/LinearLayout.h" namespace mlir::triton::xla { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td index 08a513ed0ea26d..4afc0925af9b63 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td +++ b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td @@ -26,14 +26,6 @@ def XlaTritonDialect : Dialect { This dialect contains ops included in the xla extension point for Triton. }]; - // We need this to register interfaces for tensor and !ttg.memdesc types. 
- // TODO: b/382459490 - This is wrong layering, triton_xla should not depend on - // triton_gpu, remove this once we refactor the extension and catch up to - // upstream. - let dependentDialects = [ - "::mlir::triton::gpu::TritonGPUDialect", - ]; - let cppNamespace = "::mlir::triton::xla"; let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc index 746a8216ca8e74..025bd7f7eba569 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc @@ -30,7 +30,6 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.cc.inc" -#include "triton/Dialect/TritonGPU/IR/Types.h" using mlir::LogicalResult; using mlir::RankedTensorType; @@ -46,34 +45,23 @@ void TileOp::getAsmResultNames(function_ref setNameFn) { setNameFn(getResult(), "tiled_tensor"); } -template -mlir::ParseResult parseDenseIntArrayAttr(mlir::AsmParser& parser, - DenseIntArrayAttrType& array) { - array = mlir::dyn_cast_or_null( - DenseIntArrayAttrType::parse(parser, mlir::Type{})); - if (!array) return mlir::failure(); - return mlir::success(); -} - ParseResult TileOp::parse(OpAsmParser& parser, OperationState& result) { OpAsmParser::UnresolvedOperand src; TiledTensorType tiled_tensor_type; SmallVector offsets, sizes, strides; if (parser.parseOperand(src) || parser.parseOperandList(offsets, OpAsmParser::Delimiter::Square) || - parser.parseOperandList(sizes, OpAsmParser::Delimiter::Square) || parser.parseOperandList(strides, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(tiled_tensor_type)) { return failure(); } - auto offset_type = parser.getBuilder().getI32Type(); - auto size_and_stride_type = parser.getBuilder().getI64Type(); + + auto index_type = parser.getBuilder().getIndexType(); if (parser.resolveOperand(src, tiled_tensor_type.getOriginalType(), result.operands) || - parser.resolveOperands(offsets, offset_type, result.operands) || - parser.resolveOperands(sizes, size_and_stride_type, result.operands) || - parser.resolveOperands(strides, size_and_stride_type, result.operands)) { + parser.resolveOperands(offsets, index_type, result.operands) || + parser.resolveOperands(strides, index_type, result.operands)) { return failure(); } result.addTypes(tiled_tensor_type); @@ -85,10 +73,10 @@ void TileOp::print(OpAsmPrinter& p) { p << '['; llvm::interleaveComma(getOffsets(), p); p << "]["; - llvm::interleaveComma(getSizes(), p); - p << "]["; llvm::interleaveComma(getStrides(), p); - p << "] : " << getType(); + p << "] {layout = array} : " << getType(); } LogicalResult TileOp::verify() { @@ -96,11 +84,9 @@ LogicalResult TileOp::verify() { return emitError("cannot tile a 0-d tensor"); } auto tensor_rank = getTensor().getType().getRank(); - if (tensor_rank != getOffsets().size() || tensor_rank != getSizes().size() || - tensor_rank != getStrides().size()) + if (tensor_rank != getOffsets().size() || tensor_rank != getStrides().size()) return emitError( - "mismatch between tensor rank and one or more of " - "offsets/sizes/strides"); + "mismatch between tensor rank and one or more of offsets and strides"); return success(); } @@ -129,7 +115,7 @@ ParseResult ExtractOp::parse(OpAsmParser& parser, OperationState& 
result) { auto tiled_tensor_type = TiledTensorType::get( parser.getContext(), mlir::cast(tile_type), mlir::cast(original_type)); - auto offset_type = builder.getI32Type(); + auto offset_type = builder.getIndexType(); if (parser.resolveOperand(tiled_tensor, tiled_tensor_type, result.operands) || parser.resolveOperands(offsets, offset_type, result.operands)) { return failure(); @@ -185,7 +171,7 @@ ParseResult InsertOp::parse(OpAsmParser& parser, OperationState& result) { parser.getContext(), mlir::cast(tile_type), mlir::cast(original_type)); - auto offset_type = builder.getI32Type(); + auto offset_type = builder.getIndexType(); if (parser.resolveOperand(tiled_tensor, tiled_tensor_type, result.operands) || parser.resolveOperands(offsets, offset_type, result.operands)) { return failure(); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.td b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.td index 66b0f9eb99979f..8b88977e348511 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.td +++ b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.td @@ -48,36 +48,24 @@ def TTXLA_TileOp : TTXLA_Op<"tile", [Pure, SameVariadicOperandSize, Example: ``` %tensor: tensor<120x320xbf16> - %of = arith.constant 0 : i32 - %sz1 = arith.constant 16 : i64 - %sz2 = arith.constant 64 : i64 - %st1 = arith.constant 120 : i64 - %st2 = arith.constant 1 : i64 + %of = arith.constant 0 : index + %st1 = arith.constant 120 : index + %st2 = arith.constant 1 : index ... - %tiled_tensor = triton_xla.tile %tensor [%of, %of][%sz1, %sz2][%st1, %st2] - : !triton_xla.tiled_tensor<16x64|120x320xbf16> + %tiled_tensor = triton_xla.tile %tensor [%of, %of][%st1, %st2] + {layout = array}: !triton_xla.tiled_tensor<16x64|120x320xbf16> ``` }]; let arguments = (ins AnyRankedTensor:$tensor, - Variadic:$offsets, - Variadic:$sizes, - Variadic:$strides + Variadic:$offsets, + Variadic:$strides, + DenseI64ArrayAttr:$layout ); let results = (outs TTXLA_TiledTensorType:$tiled_tensor); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; - - let builders = [ - OpBuilder<(ins - "Value":$tensor, - "ValueRange":$offsets, - "ValueRange":$sizes, - "ValueRange":$strides, - "ArrayRef":$tensor_shape - )> - ]; } def TTXLA_ExtractOp : TTXLA_Op<"extract", [Pure, @@ -98,7 +86,7 @@ def TTXLA_ExtractOp : TTXLA_Op<"extract", [Pure, let arguments = (ins TTXLA_TiledTensorType:$src, - Variadic:$offsets + Variadic:$offsets ); let results = (outs AnyRankedTensor:$result); let hasCustomAssemblyFormat = 1; @@ -125,7 +113,7 @@ def TTXLA_InsertOp : TTXLA_Op<"insert", [Pure, let arguments = (ins AnyRankedTensor:$src, TTXLA_TiledTensorType:$dst, - Variadic:$offsets + Variadic:$offsets ); let results = (outs AnyRankedTensor:$result); let hasCustomAssemblyFormat = 1; diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD index 7c7ee2da031866..e992f71db66711 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD @@ -43,10 +43,13 @@ cc_library( hdrs = ["passes.h"], deps = [ ":passes_inc_gen", + "//xla:shape_util", "//xla/backends/gpu/codegen/triton:emitter_helpers", "//xla/backends/gpu/codegen/triton:tma_utils", "//xla/backends/gpu/codegen/triton/ir:triton_xla", "//xla/codegen:emitter_loc_op_builder", + "//xla/codegen/emitters/ir:xla", + "//xla/hlo/analysis:indexing_analysis", "//xla/service/llvm_ir:llvm_util", 
"//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", @@ -56,10 +59,7 @@ cc_library( "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", @@ -69,11 +69,5 @@ cc_library( "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@triton//:TritonDialects", - "@triton//:TritonGPUToLLVM", - "@triton//:TritonGPUTransforms", - "@triton//:TritonToTritonGPU", - "@triton//third_party/nvidia:NVGPUDialect", - "@triton//third_party/nvidia:NVGPUToLLVM", - "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", ], ) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.td b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.td index 4ce4096970756d..57c5e09b254b97 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.td +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.td @@ -27,6 +27,7 @@ def TritonXLAExtractInsertToTritonPass : Pass<"triton-xla-extract-insert-to-trit }]; let dependentDialects = [ "triton::TritonDialect", + "::xla::XlaDialect" ]; let options = [ Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir index 9dbcac0a855742..a0495da17696b3 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir @@ -8,28 +8,27 @@ func.func @lower_tile_extract_insert(%arg0: tensor<512x128xbf16>, %arg1: tensor<256x256xbf16>) -> tensor<256x256xbf16> { - %c = arith.constant 1 : i32 - %c_0 = arith.constant 0 : i32 - %c_1 = arith.constant 1 : i64 - %c_16 = arith.constant 16 : i64 - %c_64 = arith.constant 64 : i64 - %c_128 = arith.constant 128 : i64 - %tiled_tensor_in = triton_xla.tile %arg0 [%c_0, %c_0] [%c_16, %c_64] [%c_128, %c_1] - : !triton_xla.tiled_tensor<16x64|512x128xbf16> - %tiled_tensor_out = triton_xla.tile %arg1 [%c_0, %c_0] [%c_16, %c_64] [%c_128, %c_1] - : !triton_xla.tiled_tensor<16x64|256x256xbf16> - %extracted_tensor = triton_xla.extract %tiled_tensor_in [%c, %c] + %c_0 = arith.constant 0 : index + %c_1 = arith.constant 1 : index + %c_128 = arith.constant 128 : index + %tiled_tensor_in = triton_xla.tile %arg0 [%c_0, %c_0] [%c_128, %c_1] + {layout = array} : !triton_xla.tiled_tensor<16x64|512x128xbf16> + %tiled_tensor_out = triton_xla.tile %arg1 [%c_0, %c_0] [%c_128, %c_1] + {layout = array} : !triton_xla.tiled_tensor<16x64|256x256xbf16> + %extracted_tensor = triton_xla.extract %tiled_tensor_in [%c_1, %c_1] : tensor<512x128xbf16> to tensor<16x64xbf16> %updated_tensor = triton_xla.insert %extracted_tensor into - %tiled_tensor_out [%c, %c] + %tiled_tensor_out [%c_1, %c_1] : tensor<16x64xbf16> into tensor<256x256xbf16> func.return %updated_tensor : tensor<256x256xbf16> } // CHECK-LABEL: tt.func @lower_tile_extract_insert // CHECK-SAME: %[[ARG_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[ARG_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32} -// 
CHECK: %[[PTR_0:.*]] = tt.make_tensor_ptr %[[ARG_0]] -// CHECK: %[[PTR_1:.*]] = tt.make_tensor_ptr %[[ARG_1]] +// CHECK: %[[ADDPTR_0:.*]] = tt.addptr %[[ARG_0]] +// CHECK: %[[PTR_0:.*]] = tt.make_tensor_ptr %[[ADDPTR_0]] +// CHECK: %[[ADDPTR_1:.*]] = tt.addptr %[[ARG_1]] +// CHECK: %[[PTR_1:.*]] = tt.make_tensor_ptr %[[ADDPTR_1]] // CHECK: %[[ADV_0:.*]] = tt.advance %[[PTR_0]] // CHECK: %[[LOAD:.*]] = tt.load %[[ADV_0]] // CHECK: %[[ADV_1:.*]] = tt.advance %[[PTR_1]] @@ -50,13 +49,12 @@ func.func @lower_tile_extract_insert(%arg0: tensor<512x128xbf16>, func.func @non_perfect_tile_shape( %arg0: tensor<300x300xbf16>, %arg1: tensor<300x300xbf16>) -> tensor<300x300xbf16> { - %c_0 = arith.constant 0 : i32 - %c_1 = arith.constant 1 : i64 - %c_8 = arith.constant 8 : i64 - %tiled_tensor_in = triton_xla.tile %arg0 [%c_0, %c_0] [%c_8, %c_8] [%c_1, %c_1] - : !triton_xla.tiled_tensor<8x8|300x300xbf16> - %tiled_tensor_out = triton_xla.tile %arg1 [%c_0, %c_0] [%c_8, %c_8] [%c_1, %c_1] - : !triton_xla.tiled_tensor<8x8|300x300xbf16> + %c_0 = arith.constant 0 : index + %c_1 = arith.constant 1 : index + %tiled_tensor_in = triton_xla.tile %arg0 [%c_0, %c_0] [%c_1, %c_1] + {layout = array} : !triton_xla.tiled_tensor<8x8|300x300xbf16> + %tiled_tensor_out = triton_xla.tile %arg1 [%c_0, %c_0] [%c_1, %c_1] + {layout = array} : !triton_xla.tiled_tensor<8x8|300x300xbf16> %extracted_tensor = triton_xla.extract %tiled_tensor_in [%c_0, %c_0] : tensor<300x300xbf16> to tensor<8x8xbf16> %updated_tensor = triton_xla.insert %extracted_tensor into @@ -74,19 +72,17 @@ func.func @non_perfect_tile_shape( func.func @incompatible_tma_shapes(%arg0: tensor<1000x1000xbf16>, %arg1: tensor<1024x1024xbf16>) -> tensor<1024x1024xbf16> { - %c_0 = arith.constant 1 : i32 - %c_1 = arith.constant 1 : i64 - %c_512 = arith.constant 512 : i64 - %c_256 = arith.constant 256 : i64 - %c_128 = arith.constant 128 : i64 - %tiled_tensor_in = triton_xla.tile %arg0 [%c_0, %c_0] [%c_512, %c_256] [%c_128, %c_1] - : !triton_xla.tiled_tensor<512x256|1000x1000xbf16> - %tiled_tensor_out = triton_xla.tile %arg1 [%c_0, %c_0] [%c_512, %c_256] [%c_128, %c_1] - : !triton_xla.tiled_tensor<512x256|1024x1024xbf16> - %extracted_tensor = triton_xla.extract %tiled_tensor_in [%c_0, %c_0] + %c_0 = arith.constant 0 : index + %c_1 = arith.constant 1 : index + %c_128 = arith.constant 128 : index + %tiled_tensor_in = triton_xla.tile %arg0 [%c_0, %c_0] [%c_128, %c_1] + {layout = array} : !triton_xla.tiled_tensor<512x256|1000x1000xbf16> + %tiled_tensor_out = triton_xla.tile %arg1 [%c_0, %c_0] [%c_128, %c_1] + {layout = array} : !triton_xla.tiled_tensor<512x256|1024x1024xbf16> + %extracted_tensor = triton_xla.extract %tiled_tensor_in [%c_1, %c_1] : tensor<1000x1000xbf16> to tensor<512x256xbf16> %updated_tensor = triton_xla.insert %extracted_tensor into - %tiled_tensor_out [%c_0, %c_0] + %tiled_tensor_out [%c_1, %c_1] : tensor<512x256xbf16> into tensor<1024x1024xbf16> func.return %updated_tensor : tensor<1024x1024xbf16> } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc index 30ac3dd5fa3645..e144f234a64213 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc @@ -19,7 +19,6 @@ limitations under the License. 
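A note to ease reading the pass changes that follow: `RewriteTile` now advances
the base pointer once by a linearized tile offset (the `tt.addptr` checked
above), so the later `tt.make_tensor_ptr` is emitted with all-zero offsets. A
minimal host-side sketch of that linearization; `LinearOffset` and its
signature are illustrative only, since the pass emits the equivalent arithmetic
as MLIR ops (an ApplyIndexingOp plus an index cast):

```cpp
#include <cstdint>
#include <vector>

// Linearize per-dimension tile offsets against a dense layout: walk the
// dimensions from minor to major, accumulating the running element stride.
int64_t LinearOffset(const std::vector<int64_t>& offsets,
                     const std::vector<int64_t>& dims,
                     const std::vector<int64_t>& minor_to_major) {
  int64_t linear = 0;
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {
    linear += offsets[dim] * stride;  // this dimension's contribution
    stride *= dims[dim];              // stride of the next, more major dim
  }
  return linear;
}
```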
#include #include #include -#include #include #include #include @@ -31,9 +30,12 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -52,6 +54,10 @@ limitations under the License. #include "xla/backends/gpu/codegen/triton/tma_utils.h" #include "xla/backends/gpu/codegen/triton/transforms/passes.h" #include "xla/codegen/emitter_loc_op_builder.h" +#include "xla/codegen/emitters/ir/xla_ops.h" +#include "xla/hlo/analysis/indexing_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -89,6 +95,16 @@ bool AreRankedTensors(ArrayRef types) { }); } +SmallVector IndexCastUI(::xla::EmitterLocOpBuilder& builder, Type type, + ValueRange values) { + SmallVector result; + result.reserve(values.size()); + for (auto value : values) { + result.push_back(builder.create(type, value)); + } + return result; +} + bool TmaIsEnabledForDevice( const stream_executor::DeviceDescription& device_info) { bool is_cuda = std::holds_alternative( @@ -163,6 +179,84 @@ void ComputeBoundaryChecks(std::vector& boundary_checks, } } +// TensorPtr is intended to wrap the base pointer of the TiledHloInstruction and +// the necessary offsets so that Triton can compute the pointer to the +// block specific to the given pid. This option would yield simpler code, +// but cannot handle all combinations of strides and offsets, because Triton +// always multiplies the offset by the stride. E.g., it's not possible to +// slice [10] with [1:5:2] because the offset is misaligned with regards to the +// stride. +// +// Instead, we output a TensorPtr that points directly to the tile specific +// to the pid. All offset computation is done in advance. MakeTensorPtrOp +// sees 0 offsets. This allows Triton to read any block regardless of +// strides size or offsets. To make sure that masking is correct, we compute +// a "residual shape" which is the original parent shape minus the offsets. +SmallVector ComputeResidualShape(::xla::EmitterLocOpBuilder& builder, + ArrayRef original_shape, + ValueRange tile_offsets) { + SmallVector residual_shape; + for (auto [dim_idx, shape_and_tile_offset] : + llvm::enumerate(llvm::zip(original_shape, tile_offsets))) { + auto [shape, tile_offset] = shape_and_tile_offset; + Value size = + ::xla::gpu::triton::CreateConst(builder, builder.getI64Type(), shape) + .UnwrapScalar(); + // Offsets are necessarily positive since they represent a distance + // between 0 and the size of the tensor on the given axis. Therefore, it + // is safe to use 'IndexCastUI' here. This allows index canonicalizations + // later on. + Value offset = + builder.create(builder.getI64Type(), tile_offset); + residual_shape.push_back(builder.create(size, offset)); + } + + return residual_shape; +} + +// Compute physical strides of the tile. `tile_strides` contains strides for +// individual dimensions. We need to convert them to strides in the buffer +// taking into account physical layout. 
Note that we should pass in the +// minor-to-major layout for this to work correctly. +SmallVector ComputeStrides(::xla::EmitterLocOpBuilder& builder, + ArrayRef original_shape, + ValueRange tile_strides, + ArrayRef minor_to_major_layout) { + SmallVector strides(tile_strides.size()); + int64_t current_stride = 1; + for (int64_t cur_dim : minor_to_major_layout) { + strides[cur_dim] = builder.create( + builder.create(builder.getI64Type(), + tile_strides[cur_dim]), + ::xla::gpu::triton::CreateConst(builder, builder.getI64Type(), + current_stride) + .UnwrapScalar()); + current_stride *= original_shape[cur_dim]; + } + return strides; +} + +// Based on the multi-dimensional offsets and layout of the shape, we compute +// a linear offset. We do this because we move the pointer to the correct +// position via tt.addptr prior to calling tt.make_tensor_ptr. +Value ComputeLinearOffset(::xla::EmitterLocOpBuilder& builder, + const TiledTensorType& tiled_tensor_type, + ValueRange offsets, llvm::ArrayRef layout) { + ::xla::Shape shape = ::xla::ShapeUtil::MakeShapeWithDenseLayout( + xgt::GetPrimitiveType(tiled_tensor_type.getElementType()).value(), + tiled_tensor_type.getOriginalShape(), layout); + + ::xla::Shape linear_shape = ::xla::ShapeUtil::MakeShape( + shape.element_type(), {::xla::ShapeUtil::ElementsIn(shape)}); + auto bitcast_map = + ::xla::GetBitcastMap(shape, linear_shape, builder.getContext()); + + return builder.create( + builder.getI64Type(), + builder.create<::xla::ApplyIndexingOp>(offsets, bitcast_map) + .getResult(0)); +} + struct RewriteFuncOp : mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -250,9 +344,10 @@ struct RewriteTile : mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( TileOp op, mlir::PatternRewriter& rewriter) const override { ::xla::EmitterLocOpBuilder builder(op.getLoc(), rewriter); + auto tiled_tensor_type = op.getTiledTensor().getType(); - if (CanUseTMA(builder, tma_enabled, *device_description, - op.getTiledTensor().getType(), op.getTensor())) { + if (CanUseTMA(builder, tma_enabled, *device_description, tiled_tensor_type, + op.getTensor())) { // Add TMA attributes to the corresponding argument in the function. auto block_arg = mlir::dyn_cast(op.getTensor()); auto func_op = @@ -262,7 +357,6 @@ struct RewriteTile : mlir::OpRewritePattern { // Prefixing the attribute name with "tt", otherwise tt.func will // complain that it is not part of the dialect. Not the best way to // do this, but it works for now. - auto tiled_tensor_type = op.getTiledTensor().getType(); func_op.setArgAttr( block_arg.getArgNumber(), "tt.tma_descriptor", builder.getAttr( @@ -279,25 +373,19 @@ struct RewriteTile : mlir::OpRewritePattern { op.getTensor()) .getResult(0); - auto reinterpret_tensor_desc = - xg::EmitTmaDescriptor(builder, cast_to_tensor_ptr_type, - op.getTiledTensor().getType().getTileType()); + auto reinterpret_tensor_desc = xg::EmitTmaDescriptor( + builder, cast_to_tensor_ptr_type, tiled_tensor_type.getTileType()); // !tt.tensordesc -> tiled_tensor auto cast_desc_ptr_to_tiled_tensor_ptr_type = builder.create( - xgt::StorageType(builder, op.getTiledTensor().getType()), + xgt::StorageType(builder, tiled_tensor_type), reinterpret_tensor_desc); rewriter.replaceOp(op, cast_desc_ptr_to_tiled_tensor_ptr_type); return mlir::success(); } - // Order is rank - 1, ..., 1, 0. 
- std::vector dim_order(op.getSizes().size()); - std::iota(dim_order.begin(), dim_order.end(), 0); - std::reverse(dim_order.begin(), dim_order.end()); - // tensor -> !tt.ptr<> auto cast_to_tensor_ptr_type = builder @@ -307,23 +395,46 @@ struct RewriteTile : mlir::OpRewritePattern { op.getTensor()) .getResult(0); - auto tile_shape = op.getTiledTensor().getType().getTileShape(); - std::vector tile_shape_i32; - std::transform(tile_shape.begin(), tile_shape.end(), - std::back_inserter(tile_shape_i32), - [](int64_t value) { return static_cast(value); }); - auto tensor_ptr = - builder - .create(cast_to_tensor_ptr_type, op.getSizes(), - op.getStrides(), op.getOffsets(), - tile_shape_i32, dim_order) - .getResult(); + auto linear_offset = ComputeLinearOffset(builder, tiled_tensor_type, + op.getOffsets(), op.getLayout()); + + auto ptr = builder + .create(cast_to_tensor_ptr_type.getType(), + cast_to_tensor_ptr_type, linear_offset) + .getResult(); + + // Only emit make_tensor_ptr if the input is not a scalar. + auto tile_shape = tiled_tensor_type.getTileShape(); + if (!tile_shape.empty()) { + // TODO(b/342989850): Clarify and comment what `order` exactly is. It's + // not entirely clear from the Triton docs. Currently we are propagating + // the layout from the original tensor. + auto dim_order = llvm::to_vector_of(op.getLayout()); + + SmallVector residual_shape = ComputeResidualShape( + builder, tiled_tensor_type.getOriginalShape(), op.getOffsets()); + + // Offsets are always passed as 0 since we are using "residual shape". + SmallVector zero_offsets( + tile_shape.size(), + ::xla::gpu::triton::CreateConst(builder, builder.getI32Type(), 0) + .UnwrapScalar()); + + SmallVector strides = + ComputeStrides(builder, tiled_tensor_type.getOriginalShape(), + op.getStrides(), op.getLayout()); + + ptr = builder + .create( + ptr, residual_shape, strides, zero_offsets, + llvm::to_vector_of(tile_shape), dim_order) + .getResult(); + } // !tt.ptr -> tiled_tensor auto cast_to_tiled_tensor_type = builder.create( - xgt::StorageType(builder, op.getTiledTensor().getType()), - tensor_ptr); + xgt::StorageType(builder, tiled_tensor_type), ptr); rewriter.replaceOp(op, cast_to_tiled_tensor_type); return mlir::success(); @@ -357,7 +468,7 @@ struct RewriteExtract : mlir::OpRewritePattern { builder .create( op.getResult().getType(), cast_to_tensor_desc_ptr_type, - op.getOffsets()) + IndexCastUI(builder, builder.getI32Type(), op.getOffsets())) .getResult(); rewriter.replaceOp(op, descriptor_load); @@ -372,9 +483,9 @@ struct RewriteExtract : mlir::OpRewritePattern { op.getSrc()) .getResult(0); - auto advance = - builder.create(cast_to_tensor_ptr_type.getType(), - cast_to_tensor_ptr_type, op.getOffsets()); + auto advance = builder.create( + cast_to_tensor_ptr_type.getType(), cast_to_tensor_ptr_type, + IndexCastUI(builder, builder.getI32Type(), op.getOffsets())); std::vector boundary_checks; ComputeBoundaryChecks(boundary_checks, op.getSrc().getType()); std::optional padding; @@ -413,7 +524,8 @@ struct RewriteInsert : mlir::OpRewritePattern { .getResult(0); builder.create( - cast_to_tensor_desc_ptr_type, op.getSrc(), op.getOffsets()); + cast_to_tensor_desc_ptr_type, op.getSrc(), + IndexCastUI(builder, builder.getI32Type(), op.getOffsets())); } else { // tiled_tensor -> !tt.ptr auto cast_dst_to_tensor_ptr_type = @@ -427,7 +539,7 @@ struct RewriteInsert : mlir::OpRewritePattern { auto advance = builder.create( cast_dst_to_tensor_ptr_type.getType(), cast_dst_to_tensor_ptr_type, - op.getOffsets()); + IndexCastUI(builder, 
builder.getI32Type(), op.getOffsets())); std::vector boundary_checks; ComputeBoundaryChecks(boundary_checks, op.getDst().getType()); std::optional padding; From ea4cacead963a3217d02ef2668353beeef3f6ff6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 09:58:32 -0700 Subject: [PATCH 0097/1324] Remove mlir lowerings for operations that are covered by Xla Builder and the mlir tests that were covering that unused code. PiperOrigin-RevId: 742740372 --- .../compiler/mlir/lite/tests/prepare-tf.mlir | 14 - .../tests/legalize-tf-BatchMatMulV2.mlir | 89 -- .../tests/legalize-tf-binary-elementwise.mlir | 87 -- .../legalize-tf-with-tf2xla-hlo-importer.mlir | 9 - .../mlir/tf2xla/tests/legalize-tf.mlir | 1058 ----------------- .../mlir/tf2xla/transforms/legalize_tf.cc | 361 ------ .../tf2xla/transforms/legalize_tf_patterns.td | 37 - 7 files changed, 1655 deletions(-) delete mode 100644 tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index f6e8d6610aba74..974fbc2ab7c788 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -521,20 +521,6 @@ func.func @lower_rfft_to_rfft2d(%input: tensor<10x20x30xf32>, %fft_len: tensor<1 // CHECK: %[[SQE:.*]] = "tf.Squeeze"(%[[RFF]]) <{squeeze_dims = [-2]}> : (tensor<10x20x1x30xcomplex>) -> tensor<10x20x30xcomplex> } -// CHECK-LABEL: xla_gather_to_strided_slice -func.func @xla_gather_to_strided_slice(%arg0 : tensor<1x9x104x768xf32>) -> tensor { - %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<[1, 9, 23, 768]> : tensor<4xi32>} : () -> tensor<4xi32> - %2 = "tf.XlaGather"(%arg0, %0, %1) {device = "", dimension_numbers = "\0A\04\00\01\02\03\1A\01\02", indices_are_sorted = false} : (tensor<1x9x104x768xf32>, tensor<1xi32>, tensor<4xi32>) -> tensor - func.return %2 : tensor - -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : tensor<4xi64> -// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<[1, 9, 23, 768]> : tensor<4xi64> -// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : tensor<4xi64> -// CHECK: %[[V0:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x9x104x768xf32>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor -// CHECK: return %[[V0]] : tensor -} - // CHECK-LABEL: DontMatchFusedBatchNormV3 func.func @DontMatchFusedBatchNormV3(%arg0 :tensor, %arg1 : tensor<576xf32>, %arg2 : tensor<576xf32>, %arg3 : tensor<576xf32>,%arg4 : tensor<576xf32>) -> (tensor) { %result:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", device = "", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>) -> (tensor, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<*xf32>) diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir deleted file mode 100644 index f62e9a140e83d9..00000000000000 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir +++ /dev/null @@ -1,89 +0,0 @@ -// RUN: tf-opt -xla-legalize-tf %s | FileCheck %s -
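// The deleted checks below covered the MLIR-only lowering of tf.BatchMatMulV2:
// batch dimensions were reconciled with shape.split_at / shape.broadcast,
// operands broadcast through mhlo.dynamic_broadcast_in_dim, and the
// contraction emitted as mhlo.dot_general (complex adjoint operands were
// conjugated first).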
-//===----------------------------------------------------------------------===//
-// tf.BatchMatMulV2 op legalizations.
-//===----------------------------------------------------------------------===//
-
-func.func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
-// CHECK-LABEL: func @batchmatmulv2_basic
-// CHECK-SAME: ([[LHS:%.*]]: tensor<1x4x2xf32>, [[RHS:%.*]]: tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
-// CHECK: [[LHSSHAPE:%.*]] = shape.shape_of [[LHS]] : tensor<1x4x2xf32>
-// CHECK: [[RHSSHAPE:%.*]] = shape.shape_of [[RHS]] : tensor<3x2x4xf32>
-// CHECK: [[CM2:%.*]] = arith.constant -2 : index
-// CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]])
-// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]])
-// CHECK: [[BCASTHEAD:%.*]] = shape.broadcast [[LHSHEAD]], [[RHSHEAD]]
-// CHECK: [[LHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[LHSTAIL]]
-// CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]]
-// CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
-// CHECK: [[RHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[RHSTAIL]]
-// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHS]])
-// CHECK: return [[RESULT]] : tensor<3x4x4xf32>
-// CHECK: }
-
-  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
-  func.return %0 : tensor<3x4x4xf32>
-}
-
-func.func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
-// CHECK-LABEL: func @batchmatmulv2_lhs_batch
-// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}>
-// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{
-// CHECK-SAME: lhs_batching_dimensions = [0]
-// CHECK-SAME: rhs_batching_dimensions = [0]
-// CHECK-SAME: lhs_contracting_dimensions = [2]
-// CHECK-SAME: rhs_contracting_dimensions = [1]
-  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<3x4x2xf32>, tensor<2x4xf32>) -> tensor<3x4x4xf32>
-  func.return %0 : tensor<3x4x4xf32>
-}
-
-func.func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
-// CHECK-LABEL: func @batchmatmulv2_rhs_batch
-// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}>
-// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{
-// CHECK-SAME: lhs_batching_dimensions = [0]
-// CHECK-SAME: rhs_batching_dimensions = [0]
-// CHECK-SAME: lhs_contracting_dimensions = [2]
-// CHECK-SAME: rhs_contracting_dimensions = [1]
-  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
-  func.return %0 : tensor<3x4x4xf32>
-}
-
-func.func @batchmatmulv2_dynamic(%arg0: tensor<?x4x2xf32>, %arg1: tensor<?x2x4xf32>) -> tensor<?x4x4xf32> {
-// CHECK-LABEL: func @batchmatmulv2_dynamic
-// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{
-// CHECK-SAME: lhs_batching_dimensions = [0]
-// CHECK-SAME: rhs_batching_dimensions = [0]
-// CHECK-SAME: lhs_contracting_dimensions = [2]
-// CHECK-SAME: rhs_contracting_dimensions = [1]
-  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<?x4x2xf32>, tensor<?x2x4xf32>) -> tensor<?x4x4xf32>
-  func.return %0 : tensor<?x4x4xf32>
-}
-
-func.func @batchmatmulv2_adj_real(%arg0: tensor<2x5xf32>, %arg1: tensor<4x2xf32>) -> tensor<5x4xf32> {
-// CHECK-LABEL: func @batchmatmulv2_adj_real
-// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{
-// CHECK-NOT: lhs_batching_dimensions
-// CHECK-NOT: rhs_batching_dimensions
-// CHECK-SAME: lhs_contracting_dimensions = [0]
-// CHECK-SAME: rhs_contracting_dimensions = [1]
-  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xf32>, tensor<4x2xf32>) -> tensor<5x4xf32>
-  func.return %0 : tensor<5x4xf32>
-}
-
-func.func @batchmatmulv2_adj_complex(%arg0: tensor<2x5xcomplex<f32>>, %arg1: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
-// CHECK-LABEL: func @batchmatmulv2_adj_complex(
-// CHECK-SAME: [[LHS:%.*]]: tensor<2x5xcomplex<f32>>, [[RHS:%.*]]: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
-// CHECK: [[LHSRE:%.*]] = mhlo.real [[LHS]]
-// CHECK: [[LHSIM:%.*]] = mhlo.imag [[LHS]]
-// CHECK: [[LHSIMNEG:%.*]] = mhlo.negate [[LHSIM]]
-// CHECK: [[LHSCONJ:%.*]] = mhlo.complex [[LHSRE]], [[LHSIMNEG]]
-// CHECK: [[RHSRE:%.*]] = mhlo.real [[RHS]]
-// CHECK: [[RHSIM:%.*]] = mhlo.imag [[RHS]]
-// CHECK: [[RHSIMNEG:%.*]] = mhlo.negate [[RHSIM]]
-// CHECK: [[RHSCONJ:%.*]] = mhlo.complex [[RHSRE]], [[RHSIMNEG]]
-// CHECK: shape.shape_of [[LHSCONJ]]
-// CHECK: shape.shape_of [[RHSCONJ]]
-  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xcomplex<f32>>, tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
-  func.return %0 : tensor<5x4xcomplex<f32>>
-}
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
index da64452a3039f8..a1cfcb69f9c27e 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
@@ -73,13 +73,6 @@ func.func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   func.return %0: tensor<2xi32>
 }
 
-// CHECK-LABEL: func @shift_left
-func.func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
-  // CHECK: mhlo.shift_left %arg0, %arg1 : tensor<4xi32>
-  %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  func.return %0 : tensor<4xi32>
-}
-
 // CHECK-LABEL: func @div_unranked
 func.func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
   // CHECK-NEXT: tf.Div
@@ -94,22 +87,6 @@ func.func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
   func.return %0 : tensor<4xf32>
 }
 
-// CHECK-LABEL: func @minimum
-func.func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32>
-  %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  func.return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @mod
-// CHLO-LABEL: func @mod
-func.func @mod(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK-NEXT: mhlo.remainder %arg0, %arg1 : tensor<4xf32>
-  // CHLO: chlo.broadcast_remainder
-  %0 = "tf.Mod"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  func.return %0 : tensor<4xf32>
-}
-
 // CHECK-LABEL: func @mul
 func.func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32>
@@ -118,13 +95,6 @@ func.func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   func.return %0: tensor<2xi32>
 }
 
-// CHECK-LABEL: func @real_div
-func.func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
-  %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  func.return %0: tensor<2xi32>
-}
-
 // CHECK-LABEL: func @sub
 func.func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   // CHECK-NEXT: %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32>
@@ -133,28 +103,6 @@ func.func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   func.return %0: tensor<2xi32>
 }
 
-// CHECK-LABEL: func @shift_right
-func.func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
-  // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
-  %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  func.return %0 : tensor<4xi32>
-}
-
-// CHECK-LABEL: func @shift_right_unsigned
-func.func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> {
-  // CHECK: mhlo.shift_right_logical %arg0, %arg1 : tensor<4xui8>
-  %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8>
-  func.return %0 : tensor<4xui8>
-}
-
-// CHECK-LABEL: func @broadcast_shift_right_unsigned
-func.func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> {
-  // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<4xui8>) -> tensor<2x4xui8>
-  // CHECK: mhlo.shift_right_logical %[[BROADCAST]], %arg1 : tensor<2x4xui8>
-  %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8>
-  func.return %0 : tensor<2x4xui8>
-}
-
 // CHECK-LABEL: func @and
 func.func @and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> {
   // CHECK-NEXT: mhlo.and
@@ -176,20 +124,6 @@ func.func @or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> {
   func.return %0: tensor<2xi1>
 }
 
-// CHECK-LABEL: func @bitwise_or
-func.func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
-  // CHECK-NEXT: mhlo.or
-  %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  func.return %0: tensor<4xi32>
-}
-
-// CHECK-LABEL: func @bitwise_or_unsigned
-func.func @bitwise_or_unsigned(%arg0: tensor<4xui32>, %arg1: tensor<4xui32>) -> tensor<4xui32> {
-  // CHECK-NEXT: mhlo.or
-  %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xui32>, tensor<4xui32>) -> tensor<4xui32>
-  func.return %0: tensor<4xui32>
-}
-
 // CHECK-LABEL: func @bitwise_xor
 func.func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK-NEXT: mhlo.xor
@@ -204,27 +138,6 @@ func.func @bitwise_xor_unsigned(%arg0: tensor<4xui32>, %arg1: tensor<4xui32>) ->
   func.return %0: tensor<4xui32>
 }
 
-// CHECK-LABEL: func @bitwise_and
-func.func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
-  // CHECK-NEXT: mhlo.and
-  %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  func.return %0: tensor<4xi32>
-}
-
-// CHECK-LABEL: func @bitwise_and_unsigned
-func.func @bitwise_and_unsigned(%arg0: tensor<4xui32>, %arg1: tensor<4xui32>) -> tensor<4xui32> {
-  // CHECK-NEXT: mhlo.and
-  %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xui32>, tensor<4xui32>) -> tensor<4xui32>
-  func.return %0: tensor<4xui32>
-}
-
-// CHECK-LABEL: func @pow
-func.func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK-NEXT: mhlo.power
-  %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
-  func.return %0: tensor<2xf32>
-}
-
 //===----------------------------------------------------------------------===//
 // Equality op legalizations.
 // tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir
index f1fb2fec85722c..9aa7a763b329bd 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir
@@ -652,15 +652,6 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
     func.return %1 : tensor<64x128xf32>
   }
 
-  // CHECK-LABEL: func @tf_mod
-  func.func @tf_mod(%arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
-    %cst = "tf.Const"() {value = dense<7.000000e+00> : tensor<f32>} : () -> tensor<f32>
-    // CHECK: "mhlo.dynamic_broadcast_in_dim"
-    // CHECK: mhlo.remainder
-    %6 = "tf.Mod"(%arg1, %cst) {_global_shape = [#tf_type.shape<4x8>], device = ""} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
-    return %6 : tensor<2x2xf32>
-  }
-
   // CHECK-LABEL: func @concat_v2
   func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
     // CHECK: "mhlo.concatenate"({{.*}}) <{dimension = 0 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir
index 92754a181e8551..4c6a0c52c36edb 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir
@@ -1,11 +1,4 @@
 // RUN: tf-opt "-xla-legalize-tf=legalize-chlo=false" -split-input-file %s | FILECHECK_OPTS="" FileCheck %s
-// RUN: tf-opt "-xla-legalize-tf=legalize-chlo=true" -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix CHLO
-// This test runs twice:
-//   1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled since verifying
-//      that the chlo ops emit produces more useful tests.
-//   2. With chlo legalization enabled, verifying diagnostics to pick up any
-//      issues with the full lowering (can catch some broadcasting corner
-//      cases which emit with a warning).
 
 //===----------------------------------------------------------------------===//
 // BatchNorm op legalizations.
@@ -13,29 +6,6 @@
 
 // -----
 
-// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same
-// code), so only do a couple of basic checks.
-
-// CHECK-LABEL: fusedBatchNormV2_noTraining
-func.func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
-  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
-  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  func.return %0#0 : tensor<8x8x8x8xf32>
-}
-
-// -----
-
-// CHECK-LABEL: fusedBatchNormV2_training
-func.func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
-  // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
-  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  // CHECK: mhlo.constant
-  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
-  func.return %0#0 : tensor<8x8x8x8xf32>
-}
-
-// -----
-
 // CHECK-LABEL: fusedBatchNormV3_noTraining
 func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
   // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
@@ -160,139 +130,6 @@ func.func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor<?x?x?
 
 // -----
 
-// CHECK-LABEL: fusedBatchNormGrad_noTraining
-func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
-  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
-
-  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array<i64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
-  // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32>
-
-  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
-  // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
-  // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
-  // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32>
-
-  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
-  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>
-
-  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
-  // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
-  // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
-  // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32>
-
-  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
-
-  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  func.return %0#0 : tensor<8x8x8x8xf32>
-}
-
-// -----
-
-// CHECK-LABEL: fusedBatchNormGrad_Training
-func.func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
-  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
-  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
-
-  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  func.return %0#0 : tensor<8x8x8x8xf32>
-}
-
-// -----
-
-// CHECK-LABEL: fusedBatchNormGradV2_noTraining
-func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
-  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
-
-  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array<i64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
-  // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32>
-
-  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
-  // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
-  // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
-  // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32>
-
-  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
-  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>
-
-  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>
-
-  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
-  // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
-  // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
-  // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32>
-
-  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
-
-  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  func.return %0#0 : tensor<8x8x8x8xf32>
-}
-
-// -----
-
-// CHECK-LABEL: fusedBatchNormGradV2_Training
-func.func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
-  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
-  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
-
-  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  func.return %0#0 : tensor<8x8x8x8xf32>
-}
-
-// -----
-
-// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision
-func.func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
-  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
-
-  // CHECK: %[[x_backprop:.*]] = mhlo.convert {{.*}} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
-  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>
-
-  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  func.return %0#0 : tensor<8x8x8x8xbf16>
-}
-
-// -----
-
-// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision
-func.func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
-  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
-  // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
-  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
-  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>
-
-  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-  func.return %0#0 : tensor<8x8x8x8xbf16>
-}
-
-// -----
-
 // CHECK-LABEL: fusedBatchNormGradV3_noTraining
 func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
   // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
@@ -731,32 +568,6 @@ func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> ten
   func.return %2: tensor<3x5x7x9x11x4x10xf32>
 }
 
-//===----------------------------------------------------------------------===//
-// Erf
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func @erf
-func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK: mhlo.erf %arg0 : tensor<2x3xf32>
-  %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-  func.return %0 : tensor<2x3xf32>
-}
-
-//===----------------------------------------------------------------------===//
-// Erfc
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func @erfc
-func.func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK: chlo.erfc %arg0 : tensor<2x3xf32>
-  %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-  func.return %0 : tensor<2x3xf32>
-}
-
 //===----------------------------------------------------------------------===//
 // Einsum.
 //===----------------------------------------------------------------------===//
@@ -780,242 +591,6 @@ func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
   func.return %0: tensor<2x2xf32>
 }
 
-//===----------------------------------------------------------------------===//
-// FloorDiv and FloorMod.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func @floordiv_broadcast_i32
-func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
-  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array}
-  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = array}
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo}
-  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo}
-  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
-  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
-  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
-  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[SUB]], [[DIV]]
-  // CHECK: return [[SELECT]]
-  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
-  func.return %0: tensor<2x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
-func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array}
-  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]]
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = array, comparison_direction = #chlo}
-  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo}
-  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
-  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
-  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
-  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[SUB]], [[DIV]]
-  // CHECK: return [[SELECT]]
-  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
-  func.return %0: tensor<2x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @floordiv_f32
-func.func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK-NEXT: %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0
-  // CHECK-NEXT: %[[FLOOR:.*]] = mhlo.floor %[[DIV]]
-  // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32>
-  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
-  func.return %0: tensor<2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @floordiv_bf16
-func.func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
-  // CHECK-NEXT: mhlo.convert
-  // CHECK-NEXT: mhlo.convert
-  // CHECK-NEXT: chlo.broadcast_divide
-  // CHECK-NEXT: mhlo.floor
-  // CHECK-NEXT: mhlo.convert
-  // CHECK-NEXT: return
-  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16>
-  func.return %0: tensor<2xbf16>
-}
-
-// -----
-
-// CHECK-LABEL: func @floordiv_f16_broadcast
-func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
-  // CHECK-NEXT: chlo.broadcast_divide
-  // CHECK-NEXT: mhlo.floor
-  // CHECK-NEXT: return
-  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
-  func.return %0: tensor<2x3xf16>
-}
-
-// -----
-
-// CHECK-LABEL: func @floordiv_dynamic
-func.func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor {
-  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array}
-  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = array}
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo}
-  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo}
-  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
-  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
-  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
-  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[SUB]], [[DIV]]
-  // CHECK: return [[SELECT]]
-  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor
-  func.return %0: tensor
-}
-
-// -----
-
-// CHECK-LABEL: func @floordiv_unsigned
-func.func @floordiv_unsigned(%arg0: tensor, %arg1: tensor) -> tensor {
-  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array}
-  // CHECK: return [[DIV]]
-  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor
-  func.return %0: tensor
-}
-
-// -----
-
-// CHECK-LABEL: func @floordiv_int
-func.func @floordiv_int(%arg0: tensor, %arg1: tensor) -> tensor {
-  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo} : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> : tensor<i32>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> : tensor<i32>
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
-  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> : tensor<i32>
-  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
-  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[SUB]], [[DIV]]
-  // CHECK: return [[SELECT]]
-  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor
-  func.return %0: tensor
-}
-
-// -----
-
-// CHECK-LABEL: func @floormod_broadcast_numerator
-func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array}
-  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo}
-  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
-  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]]
-  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]]
-  // CHECK-NEXT: return [[SELECT]]
-  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
-  func.return %0: tensor<2x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @floormod_broadcast_denominator
-func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
-  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array}
-  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo}
-  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo}
-  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
-  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = array}
-  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]]
-  // CHECK-NEXT: return [[SELECT]]
-  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
-  func.return %0: tensor<2x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @floormod_unsigned_broadcast_denominator
-func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg1: tensor<3xui32>) -> tensor<2x3xui32> {
-  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array}
-  // CHECK-NEXT: return [[REM]]
-  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xui32>, tensor<3xui32>) -> tensor<2x3xui32>
-  func.return %0: tensor<2x3xui32>
-}
-
-// -----
-
-// CHECK-LABEL: func @floormod_dynamic_broadcast_numerator
-func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor, %arg1: tensor) -> tensor {
-  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array}
-  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo}
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo}
-  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo}
-  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
-  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = array}
-  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]]
-  // CHECK-NEXT: return [[SELECT]]
-  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor
-  func.return %0: tensor
-}
-
-// -----
-
-// CHECK-LABEL: func @floormod_dynamic_broadcast_denominator
-func.func @floormod_dynamic_broadcast_denominator_(%arg0: tensor, %arg1: tensor) -> tensor {
-  // CHECK-NOT: tf.FloorMod
-  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] : (tensor, tensor) -> tensor
-  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]] : tensor, tensor
-  // CHECK-NEXT: return [[SELECT]] : tensor
-  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor
-  func.return %0: tensor
-}
-
-//===----------------------------------------------------------------------===//
-// OnesLike
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: @ones_like
-// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>)
-func.func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
-  // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) <{value = 1.0{{.*}}}>
-  // CHECK: return %[[RES]]
-  %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
-  func.return %0 : tensor<2x?xf32>
-}
-
 //===----------------------------------------------------------------------===//
 // ZerosLike
 //===----------------------------------------------------------------------===//
@@ -1054,15 +629,6 @@ func.func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
 
 // -----
 
-// CHECK-LABEL: func @complex
-func.func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
-  // CHECK: chlo.broadcast_complex
-  %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
-  func.return %1 : tensor<3xcomplex<f32>>
-}
-
-// -----
-
 // CHECK-LABEL: func @imag
 func.func @imag(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> {
   // CHECK: mhlo.imag
@@ -1123,63 +689,6 @@ func.func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf3
   %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32>
   func.return %1 : tensor<3x6xf32>
 }
-
-//===----------------------------------------------------------------------===//
-// Pad op legalizations.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func @padv2_1D
-func.func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor<f32>) -> tensor<6xf32> {
-  %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64>
-  // CHECK: "mhlo.pad"(%arg0, %arg1) <{
-  // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>,
-  // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>,
-  // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64>
-  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<6xf32>
-  func.return %1 : tensor<6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @padv2_2D
-func.func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
-  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64>
-  // CHECK: "mhlo.pad"(%arg0, %arg1) <{
-  // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
-  // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
-  // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
-  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<6x9xf32>
-  func.return %1 : tensor<6x9xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @padv2_i32_paddings
-func.func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
-  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32>
-  // CHECK: "mhlo.pad"(%arg0, %arg1) <{
-  // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
-  // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
-  // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
-  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32>
-  func.return %1 : tensor<6x9xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @padv2_dynamic
-func.func @padv2_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tensor<1x2xi64>) -> tensor<?xf32> {
-  // CHECK: "mhlo.transpose"({{.*}}) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<1x2xi64>) -> tensor<2x1xi64>
-  // CHECK: mhlo.reshape {{.*}} : (tensor<2x1xi64>) -> tensor<2xi64>
-  // CHECK: "mhlo.slice"({{.*}}) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi64>) -> tensor<1xi64>
-  // CHECK: "mhlo.slice"({{.*}}) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi64>) -> tensor<1xi64>
-  // CHECK: mhlo.dynamic_pad {{.*}} : (tensor<?xf32>, tensor<f32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<?xf32>
-  %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor<?xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<?xf32>
-  func.return %1 : tensor<?xf32>
-}
-
 //===----------------------------------------------------------------------===//
 // Identity op legalizations.
 //===----------------------------------------------------------------------===//
@@ -1195,15 +704,6 @@ func.func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> {
 
 // -----
 
-// CHECK-LABEL: func @identityN
-func.func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) {
-  // CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32>
-  %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>)
-  func.return %0#0, %0#1: tensor<1xi32>, tensor<1xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @stopgradient
 func.func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
   // CHECK-NEXT: return %arg0 : tensor<1xi32>
@@ -1415,98 +915,6 @@ func.func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4
 
 // -----
 
-// CHECK-LABEL: maxpool_valid_padding
-// CHECK-SAME: %[[ARG:.*]]: tensor<2x12x20x7xi32>
-func.func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
-  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
-  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
-  // CHECK: <{window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}>
-  // CHECK: mhlo.maximum
-  // CHECK: mhlo.return
-
-  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
-  func.return %0 : tensor<2x3x5x7xi32>
-}
-
-// -----
-
-// CHECK-LABEL: maxpool_same_padding
-// CHECK-SAME: %[[ARG:.*]]: tensor<2x13x25x7xi32>
-func.func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> {
-  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
-
-  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32>
-  func.return %0 : tensor<2x4x7x7xi32>
-}
-
-// -----
-
-// CHECK-LABEL: maxpool_3d_valid_padding
-// CHECK-SAME: %[[ARG:.*]]: tensor<2x8x12x20x7xf32>
-func.func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> {
-  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
-  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
-  // CHECK: <{window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>}>
-  // CHECK: mhlo.maximum
-  // CHECK: mhlo.return
-
-  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32>
-  func.return %0 : tensor<2x8x3x5x7xf32>
-}
-
-// -----
-
-// CHECK-LABEL: maxpool_3d_same_padding
-// CHECK-SAME: %[[ARG:.*]]: tensor<2x8x13x25x7xf32>
-func.func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> {
-  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
-
-  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32>
-  func.return %0 : tensor<2x8x4x7x7xf32>
-}
-
-// -----
-
-// CHECK-LABEL: maxpool_explicit_padding
-func.func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
-  // CHECK: tf.MaxPool
-  // TODO(b/165938852): need to support explicit padding in max_pool.
-
-  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
-  func.return %0 : tensor<2x3x5x7xi32>
-}
-
-//===----------------------------------------------------------------------===//
-// MaxPoolGrad op legalizations.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: @max_pool_grad_valid
-// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32>
-func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
-  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) <{window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
-  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
-  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor<f32>, tensor<f32>) -> tensor<i1>
-  // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<i1>
-  // CHECK: }, {
-  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
-  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
-  // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<f32>
-  // CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
-  // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32>
-  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
-     data_format = "NHWC",
-     ksize = [1, 2, 2, 1],
-     padding = "VALID",
-     strides = [1, 2, 2, 1]
-  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32>
-  func.return %result : tensor<10x24x24x64xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @max_pool_3d_grad_valid
 // CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32>
 func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> {
@@ -1527,20 +935,6 @@ func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_
 
 // -----
 
-// CHECK-LABEL: @max_pool_grad_same
-func.func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
-  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
-  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
-     data_format = "NHWC",
-     ksize = [1, 2, 3, 1],
-     padding = "SAME",
-     strides = [1, 4, 4, 1]
-  } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32>
-  func.return %result : tensor<2x13x25x7xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @max_pool_3d_grad_same
 func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> {
   // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
@@ -1680,49 +1074,6 @@ func.func @callee() {
   func.return
 }
 
-//===----------------------------------------------------------------------===//
-// ReverseV2 op legalization.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: @reverse_func_32
-func.func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> {
-  %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>)
-
-  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) <{dimensions = dense<0> : tensor<1xi64>}>
-  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32>
-
-  // CHECK: return [[VAL]] : tensor<5xi32>
-  func.return %reversed : tensor<5xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @reverse_func_64
-func.func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> {
-  %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>)
-
-  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) <{dimensions = dense<0> : tensor<1xi64>}>
-  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32>
-
-  // CHECK: return [[VAL]] : tensor<5xi32>
-  func.return %reversed : tensor<5xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @reverse_func_neg
-func.func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> {
-  %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
-
-  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) <{dimensions = dense<1> : tensor<1xi64>}>
-  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32>
-
-  // CHECK: return [[VAL]] : tensor<5x5xi32>
-  func.return %reversed : tensor<5x5xi32>
-}
-
 //===----------------------------------------------------------------------===//
 // StatefulPartitionedCall op legalization.
 //===----------------------------------------------------------------------===//
@@ -1755,39 +1106,6 @@ func.func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -
   func.return %arg1, %arg0 : tensor<i32>, tensor<i32>
 }
 
-//===----------------------------------------------------------------------===//
-// Elu op legalizations.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func @elu
-func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1xf32>) -> tensor<1xf32>
-  // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg0, %[[ZERO]]
-  // CHECK-DAG: %[[EXP:.*]] = mhlo.exponential_minus_one %arg0
-  // CHECK: %[[RESULT:.*]] = mhlo.select %[[PRED]], %arg0, %[[EXP]]
-  // CHECK: return %[[RESULT]]
-  %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
-  func.return %0: tensor<1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @elu_grad
-// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor)
-func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-  // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
-  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = array<i64>, comparison_direction = #chlo}
-  // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = array<i64>}
-  // CHECK-DAG: %[[MULGRAD:.*]] = mhlo.multiply %[[GRADIENTS]], %[[ADD1]] : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32>
-  // CHECK: %[[RESULT:.*]] = mhlo.select %[[PRED]], %[[GRADIENTS]], %[[MULGRAD]]
-  // CHECK: return %[[RESULT]]
-  %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32>
-  func.return %2 : tensor<4x8xf32>
-}
-
 //===----------------------------------------------------------------------===//
 // Relu op legalizations.
 //===----------------------------------------------------------------------===//
@@ -1834,85 +1152,6 @@ func.func @relu6_unsigned(%arg0: tensor<?xui32>) -> tensor<?xui32> {
   func.return %0: tensor<?xui32>
 }
 
-// -----
-
-// CHECK-LABEL: func @leaky_relu
-func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} {
-  // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
-  // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
-  // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<1x4x4x3xf32>
-  // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]], NOTYPE : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1>
-  // CHECK-NEXT: %[[RES:.*]] = mhlo.select %[[CMP]], %[[INP]], %[[LEAKY]] : tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>
-  // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32>
-  %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
-  func.return %0 : tensor<1x4x4x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @leaky_relu_grad
-func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} {
-  // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
-  // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
-  // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<1x4x4xf32>
-  // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1>
-  // CHECK-NEXT: %[[RES:.*]] = mhlo.select %[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]] : tensor<1x4x4xi1>, tensor<1x4x4xf32>
-  // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32>
-  %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
-  func.return %0 : tensor<1x4x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @softsign
-func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> {
-  // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) <{value = 1.000000e+00 : f32}> : (tensor<4x10xf32>) -> tensor<4x10xf32>
-  // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32>
-  // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<4x10xf32>
-  // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<4x10xf32>
-  // CHECK-NEXT: return %[[DIV]] : tensor<4x10xf32>
-  %0 = "tf.Softsign"(%arg0) : (tensor<4x10xf32>) -> tensor<4x10xf32>
-  func.return %0 : tensor<4x10xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @softsign_grad
-func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> tensor<4x10xf32> {
-
-  // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
-  // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32>
-  // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = array<i64>} : (tensor<f32>, tensor<4x10xf32>) -> tensor<4x10xf32>
-  // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32>
-  // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
-  // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32>
-  %0 = "tf.SoftsignGrad"(%arg0, %arg1) : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
-  func.return %0 : tensor<4x10xf32>
-}
-
-//===----------------------------------------------------------------------===//
-// Roll op legalizations.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func @Roll_0D
-func.func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor<i32>) -> tensor<512xi32> {
-  %axis = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
-  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
-  // CHECK-DAG: %[[AXIS_SIZE:.*]] = mhlo.constant dense<512> : tensor<i32>
-  // CHECK: %[[T1:.+]] = mhlo.remainder %arg1, %[[AXIS_SIZE]] : tensor<i32>
-  // CHECK: %[[T2:.+]] = mhlo.add %[[T1]], %[[AXIS_SIZE]] : tensor<i32>
-  // CHECK: %[[T3:.+]] = mhlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor<i32>
-  // CHECK: %[[CONCAT:.+]] = "mhlo.concatenate"(%arg0, %arg0) <{dimension = 0 : i64}>
-  // CHECK: %[[OFFSET:.+]] = mhlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor<i32>
-  // CHECK: "mhlo.dynamic_slice"(%[[CONCAT]], %[[OFFSET]])
-  // CHECK-SAME: {slice_sizes = dense<512> : tensor<1xi64>}
-  // CHECK-SAME: (tensor<1024xi32>, tensor<i32>) -> tensor<512xi32>
-  %0 = "tf.Roll"(%arg0, %shift, %axis) {device = ""} : (tensor<512xi32>, tensor<i32>, tensor<i32>) -> tensor<512xi32>
-  func.return %0 : tensor<512xi32>
-}
-
 //===----------------------------------------------------------------------===//
 // Select op legalizations.
 //===----------------------------------------------------------------------===//
@@ -2025,15 +1264,6 @@ func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>
 
 // -----
 
-// CHECK-LABEL: func @fft_1D
-func.func @fft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
-  // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type FFT>}> : (tensor<8xcomplex<f32>>
-  %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
-  func.return %0 : tensor<8xcomplex<f32>>
-}
-
-// -----
-
 // CHECK-LABEL: func @ifft_1D
 func.func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
   // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type IFFT>}> : (tensor<8xcomplex<f32>>
@@ -2043,38 +1273,6 @@ func.func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
 
 // -----
 
-// CHECK-LABEL: func @rfft_1D
-func.func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex<f32>> {
-  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
-  // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>}> : (tensor<8xf32>
-  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>>
-  func.return %0 : tensor<5xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @rfft_1D_padded
-func.func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex<f32>> {
-  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
-  // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %{{.*}}) <{edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<7xf32>, tensor<f32>) -> tensor<8xf32>
-  // CHECK: "mhlo.fft"(%[[PADDED]]) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>}> : (tensor<8xf32>
-  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>>
-  func.return %0 : tensor<5xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @rfft_1D_sliced
-func.func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex<f32>> {
-  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
-  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) <{limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x9xf32>) -> tensor<2x8xf32>
-  // CHECK: "mhlo.fft"(%[[SLICED]]) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>}> : (tensor<2x8xf32>
-  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x5xcomplex<f32>>
-  func.return %0 : tensor<2x5xcomplex<f32>>
-}
-
-// -----
-
 // CHECK-LABEL: func @irfft_1D
 func.func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xf32> {
   %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
@@ -2084,25 +1282,6 @@ func.func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xf32> {
   func.return %0 : tensor<8xf32>
 }
 
-// -----
-
-// CHECK-LABEL: fft_1D_dynamic
-func.func @fft_1D_dynamic(%arg0: tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
-  // CHECK: "tf.FFT"
-  %0 = "tf.FFT"(%arg0) : (tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>>
-  func.return %0 : tensor<8xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: rfft_1D_dynamic
-func.func @rfft_1D_dynamic(%arg0: tensor<?xf32>) -> tensor<8xcomplex<f32>> {
-  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
-  // CHECK: "tf.RFFT"
-  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<?xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>>
-  func.return %0 : tensor<8xcomplex<f32>>
-}
-
 //===----------------------------------------------------------------------===//
 // Shape op legalization.
 //===----------------------------------------------------------------------===//

@@ -2221,188 +1400,6 @@ func.func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {

 // -----

-// CHECK-LABEL: @acos
-// CHLO-LABEL: @acos
-func.func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: chlo.acos %arg0 : tensor<2xf32>
-  // CHLO: %[[TEMP_0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
-  // CHLO: %[[TEMP_1:.*]] = mhlo.subtract %[[TEMP_0]], %arg0 : tensor<2xf32>
-  // CHLO: %[[TEMP_2:.*]] = mhlo.add %[[TEMP_0]], %arg0 : tensor<2xf32>
-  // CHLO: %[[TEMP_3:.*]] = mhlo.multiply %[[TEMP_1]], %[[TEMP_2]] : tensor<2xf32>
-  // CHLO: %[[TEMP_4:.*]] = mhlo.sqrt %[[TEMP_3]] : tensor<2xf32>
-  // CHLO: %[[TEMP_5:.*]] = mhlo.atan2 %[[TEMP_4]], %arg0 : tensor<2xf32>
-  // CHLO: return %[[TEMP_5]] : tensor<2xf32>
-  %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-  func.return %0 : tensor<2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @acos_complex
-// CHLO-LABEL: @acos_complex
-func.func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
-  // CHECK: chlo.acos
-  // CHLO: %[[TEMP_0:.*]] = mhlo.real %[[TEMP_arg0:.*]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
-  // CHLO: %[[TEMP_1:.*]] = mhlo.abs %[[TEMP_0]] : tensor<2xf32>
-  // CHLO: %[[TEMP_2:.*]] = mhlo.imag %[[TEMP_arg0:.*]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
-  // CHLO: %[[TEMP_3:.*]] = mhlo.abs %[[TEMP_2]] : tensor<2xf32>
-  // CHLO: %[[TEMP_4:.*]] = mhlo.maximum %[[TEMP_1]], %[[TEMP_3]] : tensor<2xf32>
-  // CHLO: %[[TEMP_5:.*]] = mhlo.constant dense<3.40282347E+38> : tensor<2xf32>
-  // CHLO: %[[TEMP_6:.*]] = mhlo.sqrt %[[TEMP_5]] : tensor<2xf32>
-  // CHLO: %[[TEMP_7:.*]] = mhlo.constant dense<8.000000e+00> : tensor<2xf32>
-  // CHLO: %[[TEMP_8:.*]] = mhlo.divide %[[TEMP_6]], %[[TEMP_7]] : tensor<2xf32>
-  // CHLO: %[[TEMP_9:.*]] = mhlo.compare GE, %[[TEMP_4]], %[[TEMP_8]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_10:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
-  // CHLO: %[[TEMP_11:.*]] = mhlo.compare LE, %[[TEMP_1]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_12:.*]] = mhlo.constant dense<5.000000e-01> : tensor<2xf32>
-  // CHLO: %[[TEMP_13:.*]] = mhlo.add %[[TEMP_1]], %[[TEMP_10]] : tensor<2xf32>
-  // CHLO: %[[TEMP_14:.*]] = mhlo.abs %[[TEMP_13]] : tensor<2xf32>
-  // CHLO: %[[TEMP_15:.*]] = mhlo.maximum %[[TEMP_14]], %[[TEMP_3]] : tensor<2xf32>
-  // CHLO: %[[TEMP_16:.*]] = mhlo.minimum %[[TEMP_14]], %[[TEMP_3]] : tensor<2xf32>
-  // CHLO: %[[TEMP_17:.*]] = mhlo.compare EQ, %[[TEMP_15]], %[[TEMP_16]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_18:.*]] = mhlo.constant dense<1.41421354> : tensor<2xf32>
-  // CHLO: %[[TEMP_19:.*]] = mhlo.multiply %[[TEMP_18]], %[[TEMP_15]] : tensor<2xf32>
-  // CHLO: %[[TEMP_20:.*]] = mhlo.divide %[[TEMP_16]], %[[TEMP_15]] : tensor<2xf32>
-  // CHLO: %[[TEMP_21:.*]] = mhlo.multiply %[[TEMP_20]], %[[TEMP_20]] : tensor<2xf32>
-  // CHLO: %[[TEMP_22:.*]] = mhlo.add %[[TEMP_10]], %[[TEMP_21]] : tensor<2xf32>
-  // CHLO: %[[TEMP_23:.*]] = mhlo.sqrt %[[TEMP_22]] : tensor<2xf32>
-  // CHLO: %[[TEMP_24:.*]] = mhlo.compare EQ, %[[TEMP_23]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_25:.*]] = mhlo.constant dense<0.000000e+00> : tensor<2xf32>
-  // CHLO: %[[TEMP_26:.*]] = mhlo.compare GT, %[[TEMP_21]], %[[TEMP_25]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_27:.*]] = mhlo.and %[[TEMP_24]], %[[TEMP_26]] : tensor<2xi1>
-  // CHLO: %[[TEMP_28:.*]] = mhlo.multiply %[[TEMP_15]], %[[TEMP_21]] : tensor<2xf32>
-  // CHLO: %[[TEMP_29:.*]] = mhlo.constant dense<2.000000e+00> : tensor<2xf32>
-  // CHLO: %[[TEMP_30:.*]] = mhlo.divide %[[TEMP_28]], %[[TEMP_29]] : tensor<2xf32>
-  // CHLO: %[[TEMP_31:.*]] = mhlo.add %[[TEMP_15]], %[[TEMP_30]] : tensor<2xf32>
-  // CHLO: %[[TEMP_32:.*]] = mhlo.multiply %[[TEMP_15]], %[[TEMP_23]] : tensor<2xf32>
-  // CHLO: %[[TEMP_33:.*]] = mhlo.select %[[TEMP_27]], %[[TEMP_31]], %[[TEMP_32]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_34:.*]] = mhlo.select %[[TEMP_17]], %[[TEMP_19]], %[[TEMP_33]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_35:.*]] = mhlo.subtract %[[TEMP_1]], %[[TEMP_10]] : tensor<2xf32>
-  // CHLO: %[[TEMP_36:.*]] = mhlo.abs %[[TEMP_35]] : tensor<2xf32>
-  // CHLO: %[[TEMP_37:.*]] = mhlo.maximum %[[TEMP_36]], %[[TEMP_3]] : tensor<2xf32>
-  // CHLO: %[[TEMP_38:.*]] = mhlo.minimum %[[TEMP_36]], %[[TEMP_3]] : tensor<2xf32>
-  // CHLO: %[[TEMP_39:.*]] = mhlo.compare EQ, %[[TEMP_37]], %[[TEMP_38]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_40:.*]] = mhlo.multiply %[[TEMP_18]], %[[TEMP_37]] : tensor<2xf32>
-  // CHLO: %[[TEMP_41:.*]] = mhlo.divide %[[TEMP_38]], %[[TEMP_37]] : tensor<2xf32>
-  // CHLO: %[[TEMP_42:.*]] = mhlo.multiply %[[TEMP_41]], %[[TEMP_41]] : tensor<2xf32>
-  // CHLO: %[[TEMP_43:.*]] = mhlo.add %[[TEMP_10]], %[[TEMP_42]] : tensor<2xf32>
-  // CHLO: %[[TEMP_44:.*]] = mhlo.sqrt %[[TEMP_43]] : tensor<2xf32>
-  // CHLO: %[[TEMP_45:.*]] = mhlo.compare EQ, %[[TEMP_44]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_46:.*]] = mhlo.compare GT, %[[TEMP_42]], %[[TEMP_25]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_47:.*]] = mhlo.and %[[TEMP_45]], %[[TEMP_46]] : tensor<2xi1>
-  // CHLO: %[[TEMP_48:.*]] = mhlo.multiply %[[TEMP_37]], %[[TEMP_42]] : tensor<2xf32>
-  // CHLO: %[[TEMP_49:.*]] = mhlo.divide %[[TEMP_48]], %[[TEMP_29]] : tensor<2xf32>
-  // CHLO: %[[TEMP_50:.*]] = mhlo.add %[[TEMP_37]], %[[TEMP_49]] : tensor<2xf32>
-  // CHLO: %[[TEMP_51:.*]] = mhlo.multiply %[[TEMP_37]], %[[TEMP_44]] : tensor<2xf32>
-  // CHLO: %[[TEMP_52:.*]] = mhlo.select %[[TEMP_47]], %[[TEMP_50]], %[[TEMP_51]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_53:.*]] = mhlo.select %[[TEMP_39]], %[[TEMP_40]], %[[TEMP_52]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_54:.*]] = mhlo.add %[[TEMP_34]], %[[TEMP_53]] : tensor<2xf32>
-  // CHLO: %[[TEMP_55:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_54]] : tensor<2xf32>
-  // CHLO: %[[TEMP_56:.*]] = mhlo.add %[[TEMP_55]], %[[TEMP_1]] : tensor<2xf32>
-  // CHLO: %[[TEMP_57:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_56]] : tensor<2xf32>
-  // CHLO: %[[TEMP_58:.*]] = mhlo.multiply %[[TEMP_3]], %[[TEMP_3]] : tensor<2xf32>
-  // CHLO: %[[TEMP_59:.*]] = mhlo.add %[[TEMP_34]], %[[TEMP_13]] : tensor<2xf32>
-  // CHLO: %[[TEMP_60:.*]] = mhlo.divide %[[TEMP_58]], %[[TEMP_59]] : tensor<2xf32>
-  // CHLO: %[[TEMP_61:.*]] = mhlo.subtract %[[TEMP_53]], %[[TEMP_35]] : tensor<2xf32>
-  // CHLO: %[[TEMP_62:.*]] = mhlo.add %[[TEMP_60]], %[[TEMP_61]] : tensor<2xf32>
-  // CHLO: %[[TEMP_63:.*]] = mhlo.multiply %[[TEMP_57]], %[[TEMP_62]] : tensor<2xf32>
-  // CHLO: %[[TEMP_64:.*]] = mhlo.sqrt %[[TEMP_63]] : tensor<2xf32>
-  // CHLO: %[[TEMP_65:.*]] = mhlo.divide %[[TEMP_57]], %[[TEMP_59]] : tensor<2xf32>
-  // CHLO: %[[TEMP_66:.*]] = mhlo.add %[[TEMP_53]], %[[TEMP_35]] : tensor<2xf32>
-  // CHLO: %[[TEMP_67:.*]] = mhlo.divide %[[TEMP_57]], %[[TEMP_66]] : tensor<2xf32>
-  // CHLO: %[[TEMP_68:.*]] = mhlo.add %[[TEMP_65]], %[[TEMP_67]] : tensor<2xf32>
-  // CHLO: %[[TEMP_69:.*]] = mhlo.sqrt %[[TEMP_68]] : tensor<2xf32>
-  // CHLO: %[[TEMP_70:.*]] = mhlo.multiply %[[TEMP_3]], %[[TEMP_69]] : tensor<2xf32>
-  // CHLO: %[[TEMP_71:.*]] = mhlo.select %[[TEMP_11]], %[[TEMP_64]], %[[TEMP_70]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_72:.*]] = mhlo.select %[[TEMP_9]], %[[TEMP_3]], %[[TEMP_71]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_73:.*]] = mhlo.constant dense<9.99999995E+11> : tensor<2xf32>
-  // CHLO: %[[TEMP_74:.*]] = mhlo.multiply %[[TEMP_8]], %[[TEMP_73]] : tensor<2xf32>
-  // CHLO: %[[TEMP_75:.*]] = mhlo.compare LT, %[[TEMP_1]], %[[TEMP_74]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_76:.*]] = mhlo.constant dense<9.99999997E-7> : tensor<2xf32>
-  // CHLO: %[[TEMP_77:.*]] = mhlo.multiply %[[TEMP_8]], %[[TEMP_76]] : tensor<2xf32>
-  // CHLO: %[[TEMP_78:.*]] = mhlo.constant dense<1.000000e+02> : tensor<2xf32>
-  // CHLO: %[[TEMP_79:.*]] = mhlo.multiply %[[TEMP_8]], %[[TEMP_78]] : tensor<2xf32>
-  // CHLO: %[[TEMP_80:.*]] = mhlo.select %[[TEMP_75]], %[[TEMP_77]], %[[TEMP_79]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_81:.*]] = mhlo.compare GE, %[[TEMP_3]], %[[TEMP_80]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_82:.*]] = mhlo.select %[[TEMP_81]], %[[TEMP_3]], %[[TEMP_1]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_83:.*]] = mhlo.select %[[TEMP_81]], %[[TEMP_80]], %[[TEMP_8]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_84:.*]] = mhlo.compare GE, %[[TEMP_82]], %[[TEMP_83]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_85:.*]] = mhlo.log %[[TEMP_29]] : tensor<2xf32>
-  // CHLO: %[[TEMP_86:.*]] = mhlo.log %[[TEMP_82]] : tensor<2xf32>
-  // CHLO: %[[TEMP_87:.*]] = mhlo.add %[[TEMP_85]], %[[TEMP_86]] : tensor<2xf32>
-  // CHLO: %[[TEMP_88:.*]] = mhlo.constant dense<0x7F800000> : tensor<2xf32>
-  // CHLO: %[[TEMP_89:.*]] = mhlo.compare EQ, %[[TEMP_3]], %[[TEMP_88]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_90:.*]] = mhlo.not %[[TEMP_89]] : tensor<2xi1>
-  // CHLO: %[[TEMP_91:.*]] = mhlo.and %[[TEMP_81]], %[[TEMP_90]] : tensor<2xi1>
-  // CHLO: %[[TEMP_92:.*]] = mhlo.divide %[[TEMP_1]], %[[TEMP_3]] : tensor<2xf32>
-  // CHLO: %[[TEMP_93:.*]] = mhlo.select %[[TEMP_91]], %[[TEMP_92]], %[[TEMP_25]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_94:.*]] = mhlo.multiply %[[TEMP_93]], %[[TEMP_93]] : tensor<2xf32>
-  // CHLO: %[[TEMP_95:.*]] = mhlo.log_plus_one %[[TEMP_94]] : tensor<2xf32>
-  // CHLO: %[[TEMP_96:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_95]] : tensor<2xf32>
-  // CHLO: %[[TEMP_97:.*]] = mhlo.add %[[TEMP_87]], %[[TEMP_96]] : tensor<2xf32>
-  // CHLO: %[[TEMP_98:.*]] = mhlo.constant dense<1.17549435E-38> : tensor<2xf32>
-  // CHLO: %[[TEMP_99:.*]] = mhlo.sqrt %[[TEMP_98]] : tensor<2xf32>
-  // CHLO: %[[TEMP_100:.*]] = mhlo.constant dense<4.000000e+00> : tensor<2xf32>
-  // CHLO: %[[TEMP_101:.*]] = mhlo.multiply %[[TEMP_99]], %[[TEMP_100]] : tensor<2xf32>
-  // CHLO: %[[TEMP_102:.*]] = mhlo.compare LT, %[[TEMP_3]], %[[TEMP_101]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_103:.*]] = mhlo.compare LT, %[[TEMP_1]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_104:.*]] = mhlo.and %[[TEMP_102]], %[[TEMP_103]] : tensor<2xi1>
-  // CHLO: %[[TEMP_105:.*]] = mhlo.multiply %[[TEMP_13]], %[[TEMP_35]] : tensor<2xf32>
-  // CHLO: %[[TEMP_106:.*]] = mhlo.add %[[TEMP_55]], %[[TEMP_10]] : tensor<2xf32>
-  // CHLO: %[[TEMP_107:.*]] = mhlo.divide %[[TEMP_105]], %[[TEMP_106]] : tensor<2xf32>
-  // CHLO: %[[TEMP_108:.*]] = mhlo.negate %[[TEMP_107]] : tensor<2xf32>
-  // CHLO: %[[TEMP_109:.*]] = mhlo.compare GE, %[[TEMP_1]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_110:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_58]] : tensor<2xf32>
-  // CHLO: %[[TEMP_111:.*]] = mhlo.divide %[[TEMP_110]], %[[TEMP_59]] : tensor<2xf32>
-  // CHLO: %[[TEMP_112:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_66]] : tensor<2xf32>
-  // CHLO: %[[TEMP_113:.*]] = mhlo.add %[[TEMP_111]], %[[TEMP_112]] : tensor<2xf32>
-  // CHLO: %[[TEMP_114:.*]] = mhlo.constant dense<1.500000e+00> : tensor<2xf32>
-  // CHLO: %[[TEMP_115:.*]] = mhlo.compare LE, %[[TEMP_55]], %[[TEMP_114]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_116:.*]] = mhlo.divide %[[TEMP_110]], %[[TEMP_61]] : tensor<2xf32>
-  // CHLO: %[[TEMP_117:.*]] = mhlo.add %[[TEMP_111]], %[[TEMP_116]] : tensor<2xf32>
-  // CHLO: %[[TEMP_118:.*]] = mhlo.subtract %[[TEMP_55]], %[[TEMP_10]] : tensor<2xf32>
-  // CHLO: %[[TEMP_119:.*]] = mhlo.select %[[TEMP_115]], %[[TEMP_117]], %[[TEMP_118]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_120:.*]] = mhlo.select %[[TEMP_109]], %[[TEMP_113]], %[[TEMP_119]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_121:.*]] = mhlo.select %[[TEMP_104]], %[[TEMP_108]], %[[TEMP_120]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_122:.*]] = mhlo.multiply %[[TEMP_121]], %[[TEMP_106]] : tensor<2xf32>
-  // CHLO: %[[TEMP_123:.*]] = mhlo.sqrt %[[TEMP_122]] : tensor<2xf32>
-  // CHLO: %[[TEMP_124:.*]] = mhlo.divide %[[TEMP_3]], %[[TEMP_123]] : tensor<2xf32>
-  // CHLO: %[[TEMP_125:.*]] = mhlo.add %[[TEMP_121]], %[[TEMP_123]] : tensor<2xf32>
-  // CHLO: %[[TEMP_126:.*]] = mhlo.log_plus_one %[[TEMP_125]] : tensor<2xf32>
-  // CHLO: %[[TEMP_127:.*]] = mhlo.select %[[TEMP_104]], %[[TEMP_124]], %[[TEMP_126]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_128:.*]] = mhlo.select %[[TEMP_84]], %[[TEMP_97]], %[[TEMP_127]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_129:.*]] = mhlo.complex %[[TEMP_72]], %[[TEMP_128]] : tensor<2xcomplex<f32>>
-  // CHLO: %[[TEMP_130:.*]] = mhlo.real %[[TEMP_129]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
-  // CHLO: %[[TEMP_131:.*]] = mhlo.real %[[TEMP_arg0:.*]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
-  // CHLO: %[[TEMP_132:.*]] = mhlo.atan2 %[[TEMP_130]], %[[TEMP_131]] : tensor<2xf32>
-  // CHLO: %[[TEMP_133:.*]] = mhlo.imag %[[TEMP_arg0:.*]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
-  // CHLO: %[[TEMP_134:.*]] = mhlo.constant dense<0.000000e+00> : tensor<2xf32>
-  // CHLO: %[[TEMP_135:.*]] = mhlo.compare LT, %[[TEMP_133]], %[[TEMP_134]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
-  // CHLO: %[[TEMP_136:.*]] = mhlo.imag %[[TEMP_129]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
-  // CHLO: %[[TEMP_137:.*]] = mhlo.negate %[[TEMP_136]] : tensor<2xf32>
-  // CHLO: %[[TEMP_138:.*]] = mhlo.select %[[TEMP_135]], %[[TEMP_136]], %[[TEMP_137]] : tensor<2xi1>, tensor<2xf32>
-  // CHLO: %[[TEMP_139:.*]] = mhlo.complex %[[TEMP_132]], %[[TEMP_138]] : tensor<2xcomplex<f32>>
-  // CHLO: return %[[TEMP_139:.*]] : tensor<2xcomplex<f32>>
-  %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
-  func.return %0 : tensor<2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: @acos_dynamic
-// CHLO-LABEL: @acos_dynamic
-func.func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: chlo.acos %arg0 : tensor<*xf32>
-  // `tf.Acos` is lowered to `chlo.constant_like` operations which can only be
-  // lowered further on ranked tensors. Unranked CHLO must be transformed to
-  // ranked code before further lowering.
-  // CHLO: "tf.Acos"
-  %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
-  func.return %0 : tensor<*xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @cast_dynamic_i2f
 func.func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
   // CHECK: mhlo.convert %arg0 : (tensor<?xi32>) -> tensor<?xf32>
@@ -2448,15 +1445,6 @@ func.func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {

 // -----

-// CHECK-LABEL: @complex_abs
-func.func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
-  // CHECK: mhlo.abs %arg0 : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
-  %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
-  func.return %0 : tensor<2xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @cos
 func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK: mhlo.cosine %arg0 : tensor<2xf32>
@@ -2466,15 +1454,6 @@ func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {

 // -----

-// CHECK-LABEL: @tan
-func.func @tan(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: mhlo.tan %arg0 : tensor<2xf32>
-  %0 = "tf.Tan"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-  func.return %0 : tensor<2xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @cos_dynamic
 func.func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   // CHECK: mhlo.cosine %arg0 : tensor<?xf32>
@@ -2617,40 +1596,3 @@ func.func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
 }

-// -----
-
-// CHECK-LABEL: @sigmoid_grad
-func.func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xf32>
-  // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
-  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xf32>
-  // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32>
-  // CHECK: return [[MUL1]]
-  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
-  func.return %0 : tensor<2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @sigmoid_grad_complex
-func.func @sigmoid_grad_complex(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
-  // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xcomplex<f32>>
-  // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>>
-  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xcomplex<f32>>
-  // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex<f32>>
-  // CHECK: return [[MUL1]]
-  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
-  func.return %0 : tensor<2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: @sigmoid_grad_dynamic
-func.func @sigmoid_grad_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = array<i64>} : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  func.return %0 : tensor<?xf32>
-}
-
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
index 047a5fb7b46bbc..ae649253605bb9 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
@@ -1541,119 +1541,6 @@ class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
   }
 };

-/// Converts a TF::RollOp to HLO. Only supports the 0D axis and shift case,
-/// and the axis has to be a constant.
-class ConvertRollOp : public OpRewritePattern<TF::RollOp> {
- public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(TF::RollOp op,
-                                PatternRewriter &rewriter) const override {
-    auto shift_ty = mlir::dyn_cast<RankedTensorType>(op.getShift().getType());
-    if (!shift_ty || shift_ty.getRank() != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "require the type of shift to be 0D tensor");
-    }
-
-    APInt val;
-    if (!matchPattern(op.getAxis(), m_ConstantInt(&val))) {
-      return rewriter.notifyMatchFailure(op, "require axis to be constant");
-    }
-    int axis = val.getSExtValue();
-
-    auto input_ty = mlir::dyn_cast<RankedTensorType>(op.getInput().getType());
-    if (!input_ty || !input_ty.hasStaticShape()) {
-      return rewriter.notifyMatchFailure(
-          op, "require the type of input to have static shapes");
-    }
-    ArrayRef<int64_t> input_shape = input_ty.getShape();
-    int input_rank = input_ty.getRank();
-    if (axis < 0) axis += input_rank;
-
-    // Adjust large offsets into [0, axis_size). This also makes negative
-    // offsets positive.
-    // offset = ((offset % axis_size) + axis_size) % axis_size
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value offset = op.getShift();
-    auto axis_size = b.create<mhlo::ConstantOp>(b.getIntegerAttr(
-        getElementTypeOrSelf(offset.getType()), input_shape[axis]));
-    offset = b.create<mhlo::RemOp>(
-        b.create<mhlo::AddOp>(b.create<mhlo::RemOp>(offset, axis_size),
-                              axis_size),
-        axis_size);
-
-    // Stack two copies of the dimension, then slice from the calculated
-    // offset. This also works if the shift is not constant.
-    // DynamicSliceOp requires the sizes to be integers, and we can get that
-    // information from the input shape.
-    auto concat = b.create<mhlo::ConcatenateOp>(
-        ValueRange{op.getInput(), op.getInput()}, b.getI64IntegerAttr(axis));
-    Value zero = b.create<mhlo::ConstantOp>(
-        b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0));
-    SmallVector<Value> slice_begin_indices(input_rank, zero);
-    slice_begin_indices[axis] = b.create<mhlo::SubtractOp>(axis_size, offset);
-    rewriter.replaceOpWithNewOp<mhlo::DynamicSliceOp>(
-        op, input_ty, concat, slice_begin_indices,
-        rewriter.getI64TensorAttr(input_shape));
-    return success();
-  }
-};
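The normalization in the deleted ConvertRollOp above maps any shift, including negative ones, into [0, axis_size). A quick illustrative check in plain Python (made-up values, not part of this patch):

def normalize_shift(shift, axis_size):
    # offset = ((offset % axis_size) + axis_size) % axis_size, as in the
    # comment above. In Python, % is already non-negative for a positive
    # modulus; the extra "+ axis_size" step matters for mhlo.remainder,
    # which has C-style (sign-of-dividend) semantics.
    return ((shift % axis_size) + axis_size) % axis_size

assert normalize_shift(515, 512) == 3
assert normalize_shift(-3, 512) == 509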
-
-/// Converts a TF::LeakyReluOp to HLO.
-/// LeakyRelu(x) = alpha * x if x < 0 else x.
-class ConvertLeakyReluOp : public OpRewritePattern<TF::LeakyReluOp> {
- public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(TF::LeakyReluOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value features = op.getFeatures();
-
-    // Use ConstantLike for `alpha` to match the shape of feature.
-    auto alphaVal = chlo::getConstantLike(
-        rewriter, loc, op.getAlpha().convertToFloat(), features);
-    Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
-
-    Value leakyActivationVal =
-        rewriter.create<mhlo::MulOp>(loc, features, alphaVal);
-
-    Value compareGtZero = rewriter.create<mhlo::CompareOp>(
-        loc, features, zeroVal, ComparisonDirection::GT);
-
-    rewriter.replaceOpWithNewOp<SelectOp>(op, compareGtZero, features,
-                                          leakyActivationVal);
-    return success();
-  }
-};
-
-/// Converts a TF::LeakyReluGradOp to HLO.
-/// LeakyReluGrad(gradient, inputs) = gradient if input > 0
-/// else alpha * gradient.
-class ConvertLeakyReluGradOp : public OpRewritePattern<TF::LeakyReluGradOp> {
- public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(TF::LeakyReluGradOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value gradients = op.getGradients();
-    Value features = op.getFeatures();
-    auto featureType = features.getType();
-
-    // Use ConstantLike for `alpha` to match the shape of feature.
-    auto alphaVal = chlo::getConstantLike(
-        rewriter, loc, op.getAlpha().convertToFloat(), features);
-    Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
-
-    Value leakyGradientVal =
-        rewriter.create<mhlo::MulOp>(loc, gradients, alphaVal);
-
-    Value compareGtZero = rewriter.create<mhlo::CompareOp>(
-        loc, features, zeroVal, ComparisonDirection::GT);
-
-    rewriter.replaceOpWithNewOp<SelectOp>(op, featureType, compareGtZero,
-                                          gradients, leakyGradientVal);
-    return success();
-  }
-};
-
 // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix.
 // For a Rank-2 input, it creates the following ops:
 //   %1 = "mhlo.iota"() {iota_dimension = 0 : i64}
@@ -2028,17 +1915,6 @@ class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
   }
 };

-// Bypasses IdentityN op.
-class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
- public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(TF::IdentityNOp op,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOp(op, op.getOperands());
-    return success();
-  }
-};
-
 template <typename OpTy>
 class ConvertFFTOp : public OpRewritePattern<OpTy> {
  public:
@@ -2117,7 +1993,6 @@ class ConvertFFTOp : public OpRewritePattern<OpTy> {
   }
 };

-using ConvertRFFTOp = ConvertFFTOp<TF::RFFTOp>;
 using ConvertIRFFTOp = ConvertFFTOp<TF::IRFFTOp>;

 // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
@@ -2244,10 +2119,6 @@ class ConvertFusedBatchNormGradBase
   }
 };

-using ConvertFusedBatchNormGradOp =
-    ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
-using ConvertFusedBatchNormGradV2Op =
-    ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
 using ConvertFusedBatchNormGradV3Op =
     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;

@@ -2446,8 +2317,6 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
   }
 };

-using ConvertFusedBatchNormV2Op =
-    ConvertFusedBatchNormBase<TF::FusedBatchNormV2Op>;
 using ConvertFusedBatchNormV3Op =
     ConvertFusedBatchNormBase<TF::FusedBatchNormV3Op>;

@@ -2820,54 +2689,6 @@ using ConvertAvgPool2DGradOp =
 using ConvertAvgPool3DGradOp =
     ConvertAvgPoolGradOp<TF::AvgPool3DGradOp, /*num_dims=*/5>;

-// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
-// dimensions with max as the reduction function.
-//
-// Sample result for VALID padding mode:
-//
-//   %init = arith.constant dense<...> : tensor<i32>
-//   %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
-//               {window_dimensions = ..., window_strides = ...}
-//
-template <typename OpTy, int num_dims>
-class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
- public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    Type element_type =
-        mlir::cast<TensorType>(op.getInput().getType()).getElementType();
-    if (!element_type.isSignlessIntOrFloat()) return failure();
-    tensorflow::Padding padding;
-    if (!GetPaddingFromString(op.getPadding().str(), &padding).ok())
-      return failure();
-    if (padding == tensorflow::Padding::EXPLICIT) {
-      return failure();
-    }
-    Location loc = op.getLoc();
-    ConstantOp init = GetScalarLimitConstOfType(
-        element_type, loc, hlo::kInfinityLowest, &rewriter);
-
-    auto input_ty = mlir::dyn_cast<RankedTensorType>(op.getInput().getType());
-    if (!input_ty) return failure();
-    DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
-        input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(),
-        &rewriter);
-    auto reduce = rewriter.create<ReduceWindowOp>(
-        loc, op.getType(), op.getInput(), init,
-        GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()),
-        /*base_dilations=*/DenseIntElementsAttr(),
-        /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
-    BuildReduceBody<MaxOp>(element_type, &reduce.getBody(), &rewriter);
-
-    rewriter.replaceOp(op, reduce.getResult(0));
-    return success();
-  }
-};
-
-using ConvertMaxPool2DOp = ConvertMaxPoolOp<TF::MaxPoolOp, /*num_dims=*/4>;
-using ConvertMaxPool3DOp = ConvertMaxPoolOp<TF::MaxPool3DOp, /*num_dims=*/5>;

 // Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on
 // the condition only.
@@ -3033,127 +2854,6 @@ class ConvertSliceOpDynamic : public OpRewritePattern<TF::SliceOp> {
   }
 };

-static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
-                                           Value *out_lhs, Value *out_rhs,
-                                           PatternRewriter *rewriter) {
-  // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is:
-  // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS]
-  // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS]
-  // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS]
-  // To perform the matmul, we need to first broadcast lhs and rhs to a common
-  // set of leading dimensions before doing the actual matmul.
-  // That's what the code below does.
-  // In particular, we populate out_lhs and out_rhs to have dimension structure:
-  // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS]
-  // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS]
-  // To do this, we need to calculate those output shapes, which involves
-  // slicing off the leading batch dims of each operand, broadcasting them,
-  // then concatenating the broadcasted leading dims back to the row/col dims.
-  // Finally, we create a TF::BroadcastTo op that does the actual broadcast.
-
-  // TODO(silvasean): Reduce duplication across reified shape calculations and
-  // the static computation of output types needed to create ops.
-  Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
-  Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
-  Value const_neg2 =
-      rewriter->create<arith::ConstantOp>(loc, rewriter->getIndexAttr(-2));
-  auto shape_type = shape::ShapeType::get(rewriter->getContext());
-  auto lhs_splitted = rewriter->create<shape::SplitAtOp>(
-      loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2);
-  auto rhs_splitted = rewriter->create<shape::SplitAtOp>(
-      loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2);
-  auto lhs_type = mlir::cast<RankedTensorType>(lhs.getType());
-  auto rhs_type = mlir::cast<RankedTensorType>(rhs.getType());
-  // The last two dimensions are the matrix row/col dimensions. Don't broadcast
-  // them.
-  SmallVector<int64_t> result_batch_shape_compile_time_extents;
-  mlir::OpTrait::util::getBroadcastedShape(
-      lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2),
-      result_batch_shape_compile_time_extents);
-  auto result_batch_shape = rewriter->create<shape::BroadcastOp>(
-      loc, shape_type, lhs_splitted.getHead(), rhs_splitted.getHead(),
-      /*error=*/nullptr);
-  // Lambda which handles the broadcasting of one side to the common
-  // leading-batch dimensions.
-  auto broadcast_one_side = [&](Value side, RankedTensorType type,
-                                Value tail_shape, Value *out_side) {
-    ArrayRef<int64_t> matrix_dims = type.getShape().take_back(2);
-    auto result_shape = result_batch_shape_compile_time_extents;
-    result_shape.append(matrix_dims.begin(), matrix_dims.end());
-    auto result_type = tensorflow::GetTypeFromTFTensorShape(
-        result_shape, type.getElementType());
-    auto shape = rewriter->create<shape::ConcatOp>(
-        loc, shape_type, result_batch_shape, tail_shape);
-    auto shape_tensor = rewriter->create<shape::ToExtentTensorOp>(
-        loc,
-        tensorflow::GetTypeFromTFTensorShape(
-            {static_cast<int64_t>(result_shape.size())},
-            rewriter->getIndexType()),
-        shape);
-    *out_side = rewriter->create<TF::BroadcastToOp>(loc, result_type, side,
-                                                    shape_tensor);
-  };
-  broadcast_one_side(lhs, lhs_type, lhs_splitted.getTail(), out_lhs);
-  broadcast_one_side(rhs, rhs_type, rhs_splitted.getTail(), out_rhs);
-}
-
-class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
- public:
-  // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved
-  // to CHLO and it is missing legalization to MHLO. Once that is done, this
-  // pattern's benefit can be changed back to one as well as the fallback
-  // lowering pattern for the op can be removed.
-  //
-  // Set benefit of this pattern to zero to prefer the fallback pattern when
-  // available and applicable. That pattern avoids broadcast on operands and is
-  // therefore faster.
-  //
-  // Native legalization for BatchMatMulV3 needs to be added as well.
-  explicit ConvertBatchMatMulV2Op(MLIRContext *context)
-      : OpRewritePattern<TF::BatchMatMulV2Op>(context, /*benefit=*/0) {}
-
-  LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
-                                PatternRewriter &rewriter) const override {
-    Value lhs = op.getX();
-    Value rhs = op.getY();
-    auto lhs_type = mlir::dyn_cast<RankedTensorType>(lhs.getType());
-    auto rhs_type = mlir::dyn_cast<RankedTensorType>(rhs.getType());
-    if (!lhs_type || !rhs_type) return failure();
-    if (mlir::isa<ComplexType>(lhs_type.getElementType()) && op.getAdjX()) {
-      lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
-    }
-    if (mlir::isa<ComplexType>(rhs_type.getElementType()) && op.getAdjY()) {
-      rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
-    }
-
-    // Broadcast both operands.
-    BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs,
-                                   &rewriter);
-    lhs_type = mlir::cast<RankedTensorType>(lhs.getType());
-    rhs_type = mlir::cast<RankedTensorType>(rhs.getType());
-    assert(lhs_type.getRank() == rhs_type.getRank());
-    int64_t rank = lhs_type.getRank();
-    auto batch_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2));
-    auto lhs_contracting_dimensions = llvm::to_vector<4>(
-        llvm::ArrayRef<int64_t>({op.getAdjX() ? rank - 2 : rank - 1}));
-    auto rhs_contracting_dimensions = llvm::to_vector<4>(
-        llvm::ArrayRef<int64_t>({op.getAdjY() ? rank - 1 : rank - 2}));
-    auto dimension_numbers = DotDimensionNumbersAttr::get(
-        rewriter.getContext(),
-        /*lhs_batching_dimensions=*/batch_dimensions,
-        /*rhs_batching_dimensions=*/batch_dimensions,
-        /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
-        /*rhs_contracting_dimensions=*/rhs_contracting_dimensions);
-    // TODO(silvasean): Emit shape checks for contracting dimensions.
-    // (The batch dimensions are checked by the broadcasting logic)
-    rewriter.replaceOpWithNewOp<DotGeneralOp>(
-        op, op.getType(), lhs, rhs, dimension_numbers,
-        /*precision_config=*/GetPrecisionConfig(&rewriter),
-        /*algorithm=*/DotAlgorithmAttr{});
-    return success();
-  }
-};
-
 // Converts the tf.Split op into a series of HLO slice ops when the tensor to be
 // split has fully static shape and the dimension to split is a constant.
 //
@@ -4722,8 +4422,6 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
   }
 };

-using ConvertMaxPool2DGradOp =
-    ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
 using ConvertMaxPool3DGradOp =
     ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;

@@ -5523,50 +5221,6 @@ class ConvertUnpackOpDynamic : public OpRewritePattern<TF::UnpackOp> {
   }
 };

-// Converts the tf.SigmoidGradOp
-// TODO(disc): To recover static special case's performance with folding and
-// canonicalization.
-class ConvertSigmoidGradOpDynamic : public OpRewritePattern<TF::SigmoidGradOp> {
- public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TF::SigmoidGradOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value y = op.getY();
-    Value dy = op.getDy();
-    auto tp_y = mlir::dyn_cast<RankedTensorType>(y.getType());
-    auto tp_dy = mlir::dyn_cast<RankedTensorType>(dy.getType());
-    if (!tp_y || !tp_dy) return failure();
-
-    // TODO(disc): Remove this constraint once fold and canonicalization
-    // implemented.
-    if (tp_y.hasStaticShape() || tp_dy.hasStaticShape()) return failure();
-
-    Attribute attr;
-    Type elem_tp = tp_y.getElementType();
-    if (elem_tp.isSignlessInteger()) {
-      attr = rewriter.getIntegerAttr(elem_tp, 1);
-    } else {
-      assert(mlir::isa<FloatType>(elem_tp));
-      attr = rewriter.getFloatAttr(elem_tp, 1);
-    }
-    Value one = rewriter.create<mhlo::ConstantOp>(
-        loc, DenseElementsAttr::get(
-                 tensorflow::GetTypeFromTFTensorShape({}, elem_tp), attr));
-
-    auto v0 = rewriter.create<chlo::BroadcastMulOp>(
-        loc, dy, y, hlo::getBroadcastDimensionsAttr(&rewriter, dy, y));
-    auto v1 = rewriter.create<chlo::BroadcastSubOp>(
-        loc, one, y, hlo::getBroadcastDimensionsAttr(&rewriter, one, y));
-    auto result = rewriter.create<chlo::BroadcastMulOp>(
-        loc, v0, v1, hlo::getBroadcastDimensionsAttr(&rewriter, v0, v1));
-
-    rewriter.replaceOp(op, result.getOperation()->getResults());
-    return success();
-  }
-};
-
 // Converts TF unsorted segment reduction ops to XLA HLO scatter op.
 //
 // TF unsorted segment reduction op performs the following calculation:
@@ -6787,7 +6441,6 @@ class LowerControlFlowOp : public OpConversionPattern<SrcOpT> {
 }  // end namespace

 #include "tensorflow/compiler/mlir/tf2xla/transforms/generated_legalize_tf.inc"
-// LINT.IfChange
 void PopulateLegalizeTfPatterns(MLIRContext *context,
                                 RewritePatternSet *patterns) {
   populateWithGenerated(*patterns);
@@ -6797,7 +6450,6 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
     ConvertAnyOp,
     ConvertArgMaxOp,
    ConvertArgMinOp,
-    ConvertBatchMatMulV2Op,
     ConvertBiasAddOp,
     ConvertBroadcastToOp,
     ConvertBF16FloorDivOp,
@@ -6816,15 +6468,10 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
     ConvertDynamicExpandDimsOp,
     ConvertDynamicSqueezeOp,
     ConvertEinsumOp,
-    ConvertRFFTOp,
     ConvertIRFFTOp,
-    ConvertFusedBatchNormGradOp,
-    ConvertFusedBatchNormGradV2Op,
     ConvertFusedBatchNormGradV3Op,
-    ConvertFusedBatchNormV2Op,
     ConvertFusedBatchNormV3Op,
     ConvertInfeedDequeueTupleOp,
-    ConvertIdentityNOp,
     ConvertInplaceUpdateOp,
     ConvertLinSpaceOp,
     ConvertMaxOp,
@@ -6833,9 +6480,6 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
     ConvertAvgPool3DOp,
     ConvertAvgPool2DGradOp,
     ConvertAvgPool3DGradOp,
-    ConvertMaxPool2DOp,
-    ConvertMaxPool3DOp,
-    ConvertMaxPool2DGradOp,
     ConvertMaxPool3DGradOp,
     ConvertMeanOp,
     ConvertOneHotOp,
@@ -6875,14 +6519,10 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
     ConvertXlaSortOp,
     ConvertXlaVariadicReduceV2Op,
     ConvertXlaVariadicSortOp,
-    ConvertRollOp,
-    ConvertLeakyReluOp,
-    ConvertLeakyReluGradOp,
     ConvertSplitOpDynamic,
     ConvertSliceOpDynamic,
     ConvertTileOpDynamic,
     ConvertUnpackOpDynamic,
-    ConvertSigmoidGradOpDynamic,
     ConvertConv2DDynamic,
     ConvertPadOpDynamic,
     ConvertGatherNdOpDynamic,
@@ -6892,6 +6532,5 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
                   LowerYieldOp>(context);
   // clang-format on
 }
-// LINT.ThenChange(:MlirAlwaysOps)
 }  // end namespace mhlo
 }  // end namespace mlir
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td
index 46f3ebfe19104d..8fd138d0f204d7 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td
@@ -238,8 +238,6 @@ class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
 foreach fromToBinPair = [[TF_LogicalAndOp, CHLO_BroadcastAndOp],
                          [TF_LogicalOrOp, CHLO_BroadcastOrOp],
-                         [TF_BitwiseAndOp, CHLO_BroadcastAndOp],
-                         [TF_BitwiseOrOp, CHLO_BroadcastOrOp],
                          [TF_BitwiseXorOp, CHLO_BroadcastXorOp]] in
   def : DirectLogicalBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
@@ -359,25 +357,6 @@ def LegalizeGatherV2 :
           (GetHLOAxisFromTFAxis $axis, $params),
           (GetHLOAxisFromTFAxis $batch_dims, $indices))>;

-//===----------------------------------------------------------------------===//
-// Pad op patterns.
-//===----------------------------------------------------------------------===//
-
-class SliceDenseIntElementsAttrColumn2D<string column> : NativeCodeCall<
-    "SliceDenseIntElementsAttrColumn2D($0.cast<DenseIntElementsAttr>(), " # column # " )">;
-
-class SliceDenseIntElementsAttr<string index, string axis> : NativeCodeCall<
-    "SliceDenseIntElementsAttr($0.cast<DenseIntElementsAttr>(), " # index # ", " # axis # ")">;
-
-// Interior padding attribute based on the TF padding.
-def GetInteriorPadding : NativeCodeCall <
-    "GetInteriorPadding($0.cast<DenseIntElementsAttr>())">;
-
-def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c),
-          (MHLO_PadOp $input, $c,
-              (SliceDenseIntElementsAttrColumn2D<"0"> $padding),
-              (SliceDenseIntElementsAttrColumn2D<"1"> $padding),
-              (GetInteriorPadding $padding))>;
-
 //===----------------------------------------------------------------------===//
 // Identity op patterns.
@@ -744,22 +723,6 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features),
 def : Pat<(TF_XlaReplicaIdOp),
           (TF_CastOp (MHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>;

-//===----------------------------------------------------------------------===//
-// XlaGather op.
-//===----------------------------------------------------------------------===//
-
-def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">;
-
-def HasValidGatherDims : Constraint<CPred<"HasValidGatherDims($0)">>;
-
-def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes),
-           $dimension_numbers, $indices_are_sorted),
-          (MHLO_GatherOp $operand, $start_indices,
-           (ToGatherDimNumsAttr $dimension_numbers),
-           (CastElementsToI64Elements $slice_sizes),
-           $indices_are_sorted),
-          [(HasValidGatherDims $dimension_numbers)]>;
-
 //===----------------------------------------------------------------------===//
 // XlaDotOp op.
 //===----------------------------------------------------------------------===//

From 831174367e8c156a9df451d8535c47773b6e168f Mon Sep 17 00:00:00 2001
From: David Dunleavy
Date: Tue, 1 Apr 2025 10:20:52 -0700
Subject: [PATCH 0098/1324] Rework `build.py` to properly force TensorFlow to
 read XLA

Previously the version of XLA on the PR was not being used (!!), as XLA is
known as `@local_xla` inside TensorFlow. However, using
`--override_repository=local_xla` would not work either, as TensorFlow
sometimes refers to things in XLA via `//third_party/xla` rather than `@xla`.

I will attempt to fix this in a future change, but this gets things working
properly for now. We also have to add some extra `sed` commands so that the
XLA on the PR references things the way TensorFlow wants (`@local_{tsl,xla}`).

This fixes my original mistake, which allowed the breakage caused by
https://github.com/openxla/xla/commit/1010bd13d27924ce1924fb6e7043d8858fda45d8
to go unnoticed by presubmit tests.

PiperOrigin-RevId: 742749886
---
 third_party/xla/build_tools/ci/build.py       | 66 ++++++++++++++++---
 .../xla/build_tools/ci/golden_commands.txt    | 22 ++++---
 2 files changed, 71 insertions(+), 17 deletions(-)

diff --git a/third_party/xla/build_tools/ci/build.py b/third_party/xla/build_tools/ci/build.py
index 161f0a95ff1fcf..d71bf4b22d9773 100755
--- a/third_party/xla/build_tools/ci/build.py
+++ b/third_party/xla/build_tools/ci/build.py
@@ -137,6 +137,7 @@ class Build:
   _builds: ClassVar[Dict[BuildType, "Build"]] = {}

   type_: BuildType
+  subcommand: str = "test"
   repo: str
   target_patterns: Tuple[str, ...]
   configs: Tuple[str, ...] = ()
@@ -145,9 +146,9 @@ class Build:
   action_env: Dict[str, Any] = dataclasses.field(default_factory=dict)
   test_env: Dict[str, Any] = dataclasses.field(default_factory=dict)
   repo_env: Dict[str, Any] = dataclasses.field(default_factory=dict)
+  override_repository: Dict[str, str] = dataclasses.field(default_factory=dict)
   options: Dict[str, Any] = dataclasses.field(default_factory=dict)
   extra_setup_commands: Tuple[List[str], ...] = ()
-  subcommand: str = "test"

   def __post_init__(self):
     # pylint: disable=protected-access
@@ -178,6 +179,10 @@ def bazel_command(
     action_env = [f"--action_env={k}={v}" for k, v in self.action_env.items()]
     test_env = [f"--test_env={k}={v}" for k, v in self.test_env.items()]
     repo_env = [f"--repo_env={k}={v}" for k, v in self.repo_env.items()]
+    override_repository = [
+        f"--override_repository={k}={v}"
+        for k, v in self.override_repository.items()
+    ]
     tag_filters = [build_tag_filters, test_tag_filters]

     all_options = (
@@ -186,6 +191,7 @@ def bazel_command(
         + action_env
         + test_env
         + repo_env
+        + override_repository
         + options
         + list(extra_options)
     )
@@ -463,10 +469,10 @@ def nvidia_gpu_build_with_compute_capability(
         JAX_NUM_GENERATED_CASES=25,
         JAX_SKIP_SLOW_TESTS=1,
     ),
-    options=dict(
-        **_DEFAULT_BAZEL_OPTIONS,
-        override_repository=f"xla={_GITHUB_WORKSPACE}/openxla/xla",
+    override_repository=dict(
+        xla=f"{_GITHUB_WORKSPACE}/openxla/xla",
     ),
+    options=_DEFAULT_BAZEL_OPTIONS,
     repo_env={"HERMETIC_PYTHON_VERSION": "3.12"},
 )

@@ -482,10 +488,10 @@ def nvidia_gpu_build_with_compute_capability(
         JAX_SKIP_SLOW_TESTS=1,
         TF_CPP_MIN_LOG_LEVEL=0,
         JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow",
     ),
-    options=dict(
-        **_DEFAULT_BAZEL_OPTIONS,
-        override_repository=f"xla={_GITHUB_WORKSPACE}/openxla/xla",
+    override_repository=dict(
+        xla=f"{_GITHUB_WORKSPACE}/openxla/xla",
     ),
+    options=_DEFAULT_BAZEL_OPTIONS,
     repo_env={"HERMETIC_PYTHON_VERSION": "3.10"},
 )

@@ -528,12 +534,33 @@ def nvidia_gpu_build_with_compute_capability(
     options=dict(
         verbose_failures=True,
         test_output="errors",
-        override_repository=f"xla={_GITHUB_WORKSPACE}/openxla/xla",
         profile="profile.json.gz",
         test_lang_filters="cc,py",
         color="yes",
     ),
     repo_env={"USE_PYWRAP_RULES": "True"},
+    extra_setup_commands=(
+        # This is pretty devious - but we have to do some ad hoc extra Copybara
+        # work here to get XLA into the shape TF expects. b/407638223
+        # pyformat:disable
+        [
+            "cp", "-r",
+            f"{_GITHUB_WORKSPACE}/openxla/xla",
+            f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party",
+        ],
+        [
+            "find",
+            f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party/xla",
+            "-type", "f",
+            "-exec", "sed", "-i", "s/@local_xla/@local_xla/g", "{}", "+",
+        ],
+        [
+            "find",
+            f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party/xla",
+            "-type", "f",
+            "-exec", "sed", "-i", "s/@local_tsl/@local_tsl/g", "{}", "+",
+        ],
+    ),
 )

@@ -557,12 +584,33 @@ def nvidia_gpu_build_with_compute_capability(
     options=dict(
         verbose_failures=True,
         test_output="errors",
-        override_repository=f"xla={_GITHUB_WORKSPACE}/openxla/xla",
         profile="profile.json.gz",
         test_lang_filters="cc,py",
         color="yes",
     ),
     repo_env={"USE_PYWRAP_RULES": "True"},
+    extra_setup_commands=(
+        # This is pretty devious - but we have to do some ad hoc extra Copybara
+        # work here to get XLA into the shape TF expects. b/407638223
+        # pyformat:disable
+        [
+            "cp", "-r",
+            f"{_GITHUB_WORKSPACE}/openxla/xla",
+            f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party",
+        ],
+        [
+            "find",
+            f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party/xla",
+            "-type", "f",
+            "-exec", "sed", "-i", "s/@local_xla/@local_xla/g", "{}", "+",
+        ],
+        [
+            "find",
+            f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party/xla",
+            "-type", "f",
+            "-exec", "sed", "-i", "s/@local_tsl/@local_tsl/g", "{}", "+",
+        ],
+    ),
 )

diff --git a/third_party/xla/build_tools/ci/golden_commands.txt b/third_party/xla/build_tools/ci/golden_commands.txt
index ebfd16ee64b2b7..9985958a2d2d41 100644
--- a/third_party/xla/build_tools/ci/golden_commands.txt
+++ b/third_party/xla/build_tools/ci/golden_commands.txt
@@ -1,21 +1,27 @@
 # BEGIN BuildType.JAX_LINUX_X86_CPU_GITHUB_ACTIONS
-parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters= --test_tag_filters= --config=rbe_linux_x86_64 --test_env=JAX_NUM_GENERATED_CASES=25 --test_env=JAX_SKIP_SLOW_TESTS=1 --repo_env=HERMETIC_PYTHON_VERSION=3.12 --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --nobuild -- //tests:cpu_tests //tests:backend_independent_tests
-bazel test --build_tag_filters= --test_tag_filters= --config=rbe_linux_x86_64 --test_env=JAX_NUM_GENERATED_CASES=25 --test_env=JAX_SKIP_SLOW_TESTS=1 --repo_env=HERMETIC_PYTHON_VERSION=3.12 --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla -- //tests:cpu_tests //tests:backend_independent_tests
+parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters= --test_tag_filters= --config=rbe_linux_x86_64 --test_env=JAX_NUM_GENERATED_CASES=25 --test_env=JAX_SKIP_SLOW_TESTS=1 --repo_env=HERMETIC_PYTHON_VERSION=3.12 --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //tests:cpu_tests //tests:backend_independent_tests
+bazel test --build_tag_filters= --test_tag_filters= --config=rbe_linux_x86_64 --test_env=JAX_NUM_GENERATED_CASES=25 --test_env=JAX_SKIP_SLOW_TESTS=1 --repo_env=HERMETIC_PYTHON_VERSION=3.12 --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //tests:cpu_tests //tests:backend_independent_tests
 bazel analyze-profile profile.json.gz
 # END BuildType.JAX_LINUX_X86_CPU_GITHUB_ACTIONS
 # BEGIN BuildType.JAX_LINUX_X86_GPU_T4_GITHUB_ACTIONS
-parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-multiaccelerator --test_tag_filters=-multiaccelerator --config=rbe_linux_x86_64_cuda --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow --repo_env=HERMETIC_PYTHON_VERSION=3.10 --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --nobuild -- //tests:gpu_tests //tests:backend_independent_tests
-bazel test --build_tag_filters=-multiaccelerator --test_tag_filters=-multiaccelerator --config=rbe_linux_x86_64_cuda --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow --repo_env=HERMETIC_PYTHON_VERSION=3.10 --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla -- //tests:gpu_tests //tests:backend_independent_tests
+parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-multiaccelerator --test_tag_filters=-multiaccelerator --config=rbe_linux_x86_64_cuda --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow --repo_env=HERMETIC_PYTHON_VERSION=3.10 --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //tests:gpu_tests //tests:backend_independent_tests
+bazel test --build_tag_filters=-multiaccelerator --test_tag_filters=-multiaccelerator --config=rbe_linux_x86_64_cuda --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow --repo_env=HERMETIC_PYTHON_VERSION=3.10 --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //tests:gpu_tests //tests:backend_independent_tests
 bazel analyze-profile profile.json.gz
 # END BuildType.JAX_LINUX_X86_GPU_T4_GITHUB_ACTIONS
 # BEGIN BuildType.TENSORFLOW_LINUX_X86_CPU_GITHUB_ACTIONS
-parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --config=release_cpu_linux --config=rbe_linux_cpu --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --profile=profile.json.gz --test_lang_filters=cc,py --color=yes --nobuild -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/...
-bazel test --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --config=release_cpu_linux --config=rbe_linux_cpu --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --profile=profile.json.gz --test_lang_filters=cc,py --color=yes -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/...
+cp -r $GITHUB_WORKSPACE/openxla/xla $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party
+find $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party/xla -type f -exec sed -i s/@local_xla/@local_xla/g {} +
+find $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party/xla -type f -exec sed -i s/@local_tsl/@local_tsl/g {} +
+parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --config=release_cpu_linux --config=rbe_linux_cpu --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --profile=profile.json.gz --test_lang_filters=cc,py --color=yes --nobuild -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/...
+bazel test --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --config=release_cpu_linux --config=rbe_linux_cpu --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --profile=profile.json.gz --test_lang_filters=cc,py --color=yes -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/...
 bazel analyze-profile profile.json.gz
 # END BuildType.TENSORFLOW_LINUX_X86_CPU_GITHUB_ACTIONS
 # BEGIN BuildType.TENSORFLOW_LINUX_X86_GPU_T4_GITHUB_ACTIONS
-parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --config=release_gpu_linux --config=rbe_linux_cuda --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --profile=profile.json.gz --test_lang_filters=cc,py --color=yes --nobuild -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/...
-bazel test --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --config=release_gpu_linux --config=rbe_linux_cuda --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --profile=profile.json.gz --test_lang_filters=cc,py --color=yes -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/...
+cp -r $GITHUB_WORKSPACE/openxla/xla $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party +find $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party/xla -type f -exec sed -i s/@local_xla/@local_xla/g {} + +find $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party/xla -type f -exec sed -i s/@local_tsl/@local_tsl/g {} + +parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --config=release_gpu_linux --config=rbe_linux_cuda --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --profile=profile.json.gz --test_lang_filters=cc,py --color=yes --nobuild -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/... +bazel test --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --config=release_gpu_linux --config=rbe_linux_cuda --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --profile=profile.json.gz --test_lang_filters=cc,py --color=yes -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/... bazel analyze-profile profile.json.gz # END BuildType.TENSORFLOW_LINUX_X86_GPU_T4_GITHUB_ACTIONS # BEGIN BuildType.XLA_LINUX_ARM64_CPU_48_VCPU_PRESUBMIT_GITHUB_ACTIONS From 7b361fb56236022e2cac36db2b0992c80d601212 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Tue, 1 Apr 2025 10:32:41 -0700 Subject: [PATCH 0099/1324] Harden PjRt-IFRT's `MakeErrorArrays` implementation by requiring the error to be not OK PiperOrigin-RevId: 742754522 --- third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc index 3df85dbc64ea25..283099f5dfd87e 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc @@ -990,6 +990,9 @@ absl::StatusOr>> PjRtClient::MakeErrorArrays(const absl::Status& error, absl::Span array_specs, tsl::RCReference user_context) { + if (error.ok()) { + return absl::InvalidArgumentError("Error status must not be OK"); + } DCHECK(this); std::vector> arrays; arrays.reserve(array_specs.size()); From bcf7217c882021680733da4817abb5bc9764566c Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 1 Apr 2025 10:50:40 -0700 Subject: [PATCH 0100/1324] [XLA:GPU][Emitters] Fix crash in allocate_shared when using int4 without attributes PiperOrigin-RevId: 742761377 --- .../xla/xla/codegen/emitters/transforms/lower_tensors.cc | 2 +- .../codegen/emitters/transforms/tests/lower_tensors.mlir | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc b/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc index d122f70533dc5c..f7f425fa4ea54c 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc +++ b/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc @@ -619,7 +619,7 @@ ml::GlobalOp CreateGlobalOp(mlir::Attribute value, // Needed to support complex element type. mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); - if (element_type.isIntOrFloat() && + if (value && element_type.isIntOrFloat() && element_type.getIntOrFloatBitWidth() == 4) { num_elements = CeilOfRatio(num_elements, 2); llvm_element_type = b.getI8Type(); diff --git a/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir b/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir index 83be709942d973..ffda69a09bb090 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir +++ b/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir @@ -378,6 +378,15 @@ func.func @shared_complex() -> tensor<10xcomplex> { // ----- +func.func @shared_i4() -> tensor<10xi4> { + %shared = xla_gpu.allocate_shared : tensor<10xi4> + return %shared : tensor<10xi4> +} +// CHECK: llvm.mlir.global private @{{.*}}() {addr_space = 3 : i32} : !llvm.array<10 x i4> +// CHECK-LABEL: @shared_i4 + +// ----- + func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) -> tensor<10xi4> { %v = tensor.extract %arg[%i] : tensor<10xi4> From c47e7e3df7a8ad3029f195b7620e52cc961fa750 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 11:03:27 -0700 Subject: [PATCH 0101/1324] Add XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp to allow passing a custom combiner for embedding lookup FWD pass. 
PiperOrigin-RevId: 742766663 --- .../compiler/mlir/tensorflow/ir/tf_ops.td | 30 +++ .../transforms/legalization_op_config.cc | 1 + .../transforms/legalization_op_config_test.cc | 2 +- ...MatmulCustomCombinerOnTcWithCsrInput.pbtxt | 4 + ...MatmulCustomCombinerOnTcWithCsrInput.pbtxt | 4 + tensorflow/core/tpu/kernels/BUILD | 2 + .../core/tpu/kernels/sparse_core_xla_ops.cc | 189 +++++++++++++++++- tensorflow/core/tpu/ops/sparse_core_ops.cc | 75 +++++++ tensorflow/python/tpu/ops/BUILD | 1 + .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + third_party/xla/xla/side_effect_util.cc | 2 + third_party/xla/xla/side_effect_util.h | 4 + 13 files changed, 320 insertions(+), 2 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 83dca69fc1a9d8..007cd3f652439e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -2210,6 +2210,36 @@ def TF_XlaSparseDenseMatmulWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulWithCsrIn ); } +def TF_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput", [Pure]> { + let summary = "This op looks up the embedding vectors on SparseCores and performs the given combiner computation on TensorCores."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$weights, + + ConfinedAttr]>:$input_size, + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + OptionalAttr:$quantization_config_low, + OptionalAttr:$quantization_config_high, + OptionalAttr:$quantization_config_num_buckets, + + SymbolRefAttr:$combiner_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$activations, + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors + ); +} + def TF_XlaSparseDenseMatmulGradWithSgdAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulGradWithSgdAndCsrInput", [Pure]> { let summary = ""; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc index 7df70e4de558a2..25c894b5785c91 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc @@ -358,6 +358,7 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get< TF::XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSizeOp>(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index 113b088b3db7d2..3b0f27330be9c9 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -83,7 +83,7 @@ TEST(LegalizationOpConfigTest, CountLoweringsSet) { // from MLIR to TF2XLA), these numbers should change. 
Or if TF Dialect adds // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 67); - EXPECT_EQ(tf2xla_fallback_count, 323); + EXPECT_EQ(tf2xla_fallback_count, 324); EXPECT_EQ(non_categorized_count, 431); } diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt new file mode 100644 index 00000000000000..72728218d6ead2 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt new file mode 100644 index 00000000000000..72728218d6ead2 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 8f14a5abac0c29..4378e884a65010 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -192,6 +192,7 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:statusor", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -204,6 +205,7 @@ cc_library( "@local_xla//xla/stream_executor/tpu:c_api_decl", "@local_xla//xla/stream_executor/tpu:tpu_api", "@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs", + "@local_xla//xla/tsl/platform:errors", ], alwayslink = 1, ) diff --git a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc index ecfb757e71d335..560bd16ca94977 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -37,6 +38,7 @@ limitations under the License. #include "xla/stream_executor/tpu/c_api_decl.h" #include "xla/stream_executor/tpu/tpu_api.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/macros.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" @@ -340,7 +342,7 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { ctx->SetOutput(0, result); } - private: + protected: int input_size_; int64_t num_sparsecores_per_chip_; std::optional quantization_config_low_; @@ -357,6 +359,191 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaSparseDenseMatmulWithCsrInput"), XlaSparseDenseMatmulWithCsrInputOp); +// Similar to XlaSparseDenseMatmulWithCsrInputOp, but with an additional field +// `sorted_pos_ids` in the input Csr, `weights` which is a tensor of shape +// [num_weights] to be used by the `combiner_computation`. It produces the same +// embedding look up result as `XlaSparseDenseMatmulWithCsrInputOp`. 
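+// The `combiner_computation` is compiled as a non-entry HLO computation whose
+// parameters are valencies (s32[input_size]) and vectors
+// (f32[input_size, max_valency, feature_width]), plus an extra weights
+// parameter (f32[input_size, num_weights]) when num_weights > 0; see
+// BuildTcCustomCombinerComputation below.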
+class XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp + : public XlaSparseDenseMatmulWithCsrInputOp { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp( + OpKernelConstruction* ctx) + : XlaSparseDenseMatmulWithCsrInputOp(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_valency", &max_valency_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_weights", &num_weights_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("combiner_computation", &name_attr)); + combiner_computation_ = *name_attr; + } + + ~XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp() override = default; + + absl::StatusOr BuildTcCustomCombinerComputation( + XlaOpKernelContext* ctx, const int32_t feature_width) { + XlaCompiler::CompileOptions options; + options.use_tuple_arg = false; + options.always_return_tuple = false; + options.is_entry_computation = false; + + XlaCompiler* compiler = ctx->compiler(); + XlaCompiler::CompilationResult custom_combiner_computation_result; + + XlaCompiler::Argument valencies_arg; + XlaCompiler::Argument vectors_arg; + + valencies_arg.kind = XlaCompiler::Argument::kParameter; + valencies_arg.type = DT_INT32; + valencies_arg.shape = xla::ShapeUtil::MakeShape(xla::S32, {input_size_}); + valencies_arg.name = "valencies"; + vectors_arg.kind = XlaCompiler::Argument::kParameter; + vectors_arg.type = DT_FLOAT; + vectors_arg.shape = xla::ShapeUtil::MakeShape( + xla::F32, {input_size_, max_valency_, feature_width}); + vectors_arg.name = "vectors"; + + std::vector arguments = {valencies_arg, vectors_arg}; + + // Don't add the weights argument if it's not needed. This helps avoid + // issues of passing around zero-sized tensors and Xla values. + if (num_weights_ > 0) { + XlaCompiler::Argument weights_arg; + weights_arg.kind = XlaCompiler::Argument::kParameter; + weights_arg.type = DT_FLOAT; + weights_arg.shape = + xla::ShapeUtil::MakeShape(xla::F32, {input_size_, num_weights_}); + weights_arg.name = "weights"; + arguments.push_back(weights_arg); + } + + TF_RETURN_IF_ERROR( + compiler->CompileFunction(options, combiner_computation_, arguments, + &custom_combiner_computation_result)); + return std::move(*custom_combiner_computation_result.computation); + } + + void Compile(XlaOpKernelContext* ctx) override { + int64_t per_sparse_core_batch_size = + input_size_ / num_sparsecores_per_chip_; + int64_t max_ids_per_partition = 0; + int64_t max_unique_ids_per_partition = 0; + + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp row_pointers = ctx->Input("row_pointers"); + xla::XlaOp sorted_sample_ids = ctx->Input("sorted_sample_ids"); + xla::XlaOp sorted_token_ids = ctx->Input("sorted_token_ids"); + xla::XlaOp sorted_pos_ids = ctx->Input("sorted_pos_ids"); + xla::XlaOp sorted_gains = ctx->Input("sorted_gains"); + xla::XlaOp embedding_table = ctx->Input("embedding_table"); + + OP_REQUIRES_VALUE(xla::Shape embedding_table_shape, ctx, + ctx->InputXlaShape("embedding_table")); + const int32_t feature_width = embedding_table_shape.dimensions(1); + + OP_REQUIRES_OK( + ctx, GetMaxIdsAndUniques(per_sparse_core_batch_size, feature_width, + &max_ids_per_partition, + &max_unique_ids_per_partition)); + // Log max_ids and max_uniques for offline analysis. We do this here since + // these values are fixed at TPU compile time and remain fixed during + // training. 
+ max_ids_per_partition_gauge_->GetCell(device_name_, table_name_) + ->Set(max_ids_per_partition); + max_unique_ids_per_partition_gauge_->GetCell(device_name_, table_name_) + ->Set(max_unique_ids_per_partition); + LOG(INFO) << "Lowering " + "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp to HLO: " + << "table_name = '" << table_name_ + << "', max_ids = " << max_ids_per_partition + << ", max_uniques = " << max_unique_ids_per_partition; + + xla::FrontendAttributes tc_frontend_attributes; + xla::FrontendAttributes sc_frontend_attributes; + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_compute_type", "sparse"}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_sharding_strategy", "mod"}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_pad_value", absl::StrCat(kXlaPadValue)}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_max_ids_per_partition", absl::StrCat(max_ids_per_partition)}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_max_unique_ids_per_partition", + absl::StrCat(max_unique_ids_per_partition)}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_max_valency", absl::StrCat(max_valency_)}); + + if (quantization_config_low_.has_value()) { + sc_frontend_attributes.mutable_map()->insert( + {"_xla_quantization_high_value", + absl::StrCat(quantization_config_high_.value())}); + sc_frontend_attributes.mutable_map()->insert( + {"_xla_quantization_low_value", + absl::StrCat(quantization_config_low_.value())}); + sc_frontend_attributes.mutable_map()->insert( + {"_xla_quantization_num_buckets_value", + absl::StrCat(quantization_config_num_buckets_.value())}); + } + + tc_frontend_attributes = + builder->SwapFrontendAttributes(sc_frontend_attributes); + + // Emit the custom call that performs the SC embedding lookup. + xla::Shape valencies_shape = + xla::ShapeUtil::MakeShape(xla::S32, {input_size_}); + xla::Shape vectors_shape = xla::ShapeUtil::MakeShape( + xla::F32, {input_size_, max_valency_, feature_width}); + xla::Shape gains_shape = + xla::ShapeUtil::MakeShape(xla::F32, {input_size_, max_valency_}); + xla::XlaOp sc_lookup_result_tuple = xla::CustomCall( + builder, "SparseDenseMatmulCustomCombinerTcCombinerMegachipOp", + {row_pointers, sorted_token_ids, sorted_sample_ids, sorted_pos_ids, + sorted_gains, embedding_table}, + xla::ShapeUtil::MakeTupleShape( + {valencies_shape, vectors_shape, gains_shape})); + + // Emit the custom combiner computation into an HLO computation. 
+ OP_REQUIRES_VALUE(xla::XlaComputation custom_combiner_tc_computation, ctx, + BuildTcCustomCombinerComputation(ctx, feature_width)); + + builder->SetFrontendAttributes(tc_frontend_attributes); + + xla::XlaOp valencies = xla::GetTupleElement(sc_lookup_result_tuple, 0); + xla::XlaOp vectors = xla::GetTupleElement(sc_lookup_result_tuple, 1); + + std::vector tc_combiner_args = {valencies, vectors}; + if (num_weights_ > 0) { + xla::XlaOp weights = ctx->Input("weights"); + tc_combiner_args.push_back(xla::Broadcast(weights, {input_size_})); + } + + xla::XlaOp tc_activations = + xla::Call(builder, custom_combiner_tc_computation, tc_combiner_args); + + ctx->SetOutput(0, tc_activations); + ctx->SetOutput(1, valencies); + ctx->SetOutput(2, vectors); + } + + private: + int max_valency_; + int num_weights_; + NameAttrList combiner_computation_; + + XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp( + const XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp&) = delete; + void operator=(const XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp&) = + delete; +}; + +REGISTER_XLA_OP(Name("XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput"), + XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp); + // Base class for all the minibatch with CSR input optimizer kernel. class XlaSparseDenseMatmulGradWithCsrInputBase : public XlaOpKernel { public: diff --git a/tensorflow/core/tpu/ops/sparse_core_ops.cc b/tensorflow/core/tpu/ops/sparse_core_ops.cc index 8d0186f145b36a..36c9248aad14ab 100644 --- a/tensorflow/core/tpu/ops/sparse_core_ops.cc +++ b/tensorflow/core/tpu/ops/sparse_core_ops.cc @@ -95,6 +95,81 @@ REGISTER_OP("XlaSparseDenseMatmulWithCsrInput") return absl::OkStatus(); }); +REGISTER_OP("XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput") + .Input("row_pointers: int32") + .Input("sorted_sample_ids: int32") + .Input("sorted_token_ids: int32") + .Input("sorted_pos_ids: int32") + .Input("sorted_gains: float32") + .Input("embedding_table: float32") + .Input("weights: float32") + .Output("activations: float32") + .Output("preserved_valencies: int32") + .Output("preserved_vectors: float32") + .Attr("input_size: int >= 0") + .Attr("max_valency: int >= 0") + .Attr("num_weights: int >= 0") + .Attr("combiner_computation: func") + .Attr("quantization_config_low: float") + .Attr("quantization_config_high: float") + .Attr("quantization_config_num_buckets: int >= 0") + .Attr("table_name: string") + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { + constexpr int kRowPointersIndex = 0; + constexpr int kSortedSampleIdsIndex = 1; + constexpr int kEmbeddingTableIndex = 5; + constexpr int kEmbeddingTableRank = 2; + constexpr int kWeightsIndex = 6; + constexpr int kWeightsRank = 1; + constexpr int kOutputActivationsIndex = 0; + constexpr int kPreservedValenciesIndex = 1; + constexpr int kPreservedVectorsIndex = 2; + // This input_size is per-chip batch size. 
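+      // The three outputs are sized from it below: activations is
+      // [input_size, feature_width], preserved_valencies is [input_size], and
+      // preserved_vectors is [input_size, max_valency, feature_width].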
+ int input_size; + TF_RETURN_IF_ERROR(c->GetAttr("input_size", &input_size)); + int max_valency; + TF_RETURN_IF_ERROR(c->GetAttr("max_valency", &max_valency)); + int num_weights; + TF_RETURN_IF_ERROR(c->GetAttr("num_weights", &num_weights)); + + shape_inference::ShapeHandle rank; + for (int i = kRowPointersIndex; i < kEmbeddingTableIndex; ++i) { + TF_RETURN_IF_ERROR( + c->WithRank(c->input(i), kSortedSampleIdsIndex, &rank)); + } + TF_RETURN_IF_ERROR(c->WithRank(c->input(kEmbeddingTableIndex), + kEmbeddingTableRank, &rank)); + for (int i = kSortedSampleIdsIndex + 1; i < kEmbeddingTableIndex; ++i) { + shape_inference::ShapeHandle merged; + TF_RETURN_IF_ERROR( + c->Merge(c->input(i), c->input(kSortedSampleIdsIndex), &merged)); + } + if (num_weights > 0) { + TF_RETURN_IF_ERROR( + c->WithRank(c->input(kWeightsIndex), kWeightsRank, &rank)); + shape_inference::DimensionHandle weights_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(kWeightsIndex), 0), + num_weights, &weights_dim)); + } + + shape_inference::DimensionHandle input_size_dim = c->MakeDim(input_size); + shape_inference::DimensionHandle max_valency_dim = + c->MakeDim(max_valency); + shape_inference::DimensionHandle feature_width_dim = + c->Dim(c->input(kEmbeddingTableIndex), 1); + shape_inference::ShapeHandle output_activations_shape; + TF_RETURN_IF_ERROR(c->ReplaceDim(c->input(kEmbeddingTableIndex), 0, + c->MakeDim(input_size), + &output_activations_shape)); + c->set_output(kOutputActivationsIndex, output_activations_shape); + c->set_output(kPreservedValenciesIndex, c->MakeShape({input_size_dim})); + c->set_output( + kPreservedVectorsIndex, + c->MakeShape({input_size_dim, max_valency_dim, feature_width_dim})); + + return absl::OkStatus(); + }); + REGISTER_OP("XlaSparseDenseMatmulGradWithSgdAndCsrInput") .Input("row_pointers: int32") .Input("sorted_sample_ids: int32") diff --git a/tensorflow/python/tpu/ops/BUILD b/tensorflow/python/tpu/ops/BUILD index c0a9a7590770b7..bf0bfc21f63a16 100644 --- a/tensorflow/python/tpu/ops/BUILD +++ b/tensorflow/python/tpu/ops/BUILD @@ -65,6 +65,7 @@ tf_gen_op_wrapper_py( "TPUAnnotateTensorsWithDynamicShape", "XlaSparseDenseMatmul", "XlaSparseDenseMatmulWithCsrInput", + "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput", "XlaSparseDenseMatmulGrad", "XlaSparseDenseMatmulGradWithCsrInput", "XlaSparseDenseMatmulGradWithSgdAndCsrInput", diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 6df7d49a48a9af..b3885a3b5b3d84 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -5740,6 +5740,10 @@ tf_module { name: "XlaSparseDenseMatmul" argspec: "args=[\'row_ids\', \'col_ids\', \'values\', \'offsets\', \'embedding_table\', \'max_ids_per_partition\', \'max_unique_ids_per_partition\', \'input_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_pos_ids\', \'sorted_gains\', \'embedding_table\', \'weights\', \'input_size\', \'max_valency\', \'num_weights\', \'combiner_computation\', \'quantization_config_low\', \'quantization_config_high\', \'quantization_config_num_buckets\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput" argspec: "args=[\'row_pointers\', 
\'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 6df7d49a48a9af..b3885a3b5b3d84 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -5740,6 +5740,10 @@ tf_module {
    name: "XlaSparseDenseMatmul"
    argspec: "args=[\'row_ids\', \'col_ids\', \'values\', \'offsets\', \'embedding_table\', \'max_ids_per_partition\', \'max_unique_ids_per_partition\', \'input_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
+  member_method {
+    name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput"
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_pos_ids\', \'sorted_gains\', \'embedding_table\', \'weights\', \'input_size\', \'max_valency\', \'num_weights\', \'combiner_computation\', \'quantization_config_low\', \'quantization_config_high\', \'quantization_config_num_buckets\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
  member_method {
    name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"
    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
diff --git a/third_party/xla/xla/side_effect_util.cc b/third_party/xla/xla/side_effect_util.cc
index b16896616ba930..9a883a29597153 100644
--- a/third_party/xla/xla/side_effect_util.cc
+++ b/third_party/xla/xla/side_effect_util.cc
@@ -38,6 +38,8 @@
 const char kXlaMaxIdsPerPartitionAttr[] = "_xla_max_ids_per_partition";
 const char kXlaMaxUniqueIdsPerPartitionAttr[] =
     "_xla_max_unique_ids_per_partition";

+const char kXlaMaxValencyAttr[] = "_xla_max_valency";
+
 const char kXlaShardingStrategyAttr[] = "_xla_sharding_strategy";

 const char kXlaShardingStrategyMod[] = "mod";
diff --git a/third_party/xla/xla/side_effect_util.h b/third_party/xla/xla/side_effect_util.h
index e6768072d9376e..c108898fe626a4 100644
--- a/third_party/xla/xla/side_effect_util.h
+++ b/third_party/xla/xla/side_effect_util.h
@@ -45,6 +45,10 @@
 extern const char kXlaMaxIdsPerPartitionAttr[];

 // partition *after* an input batch is partitioned.
 extern const char kXlaMaxUniqueIdsPerPartitionAttr[];

+// XLA frontend attribute name for the maximum valency of a sample. Currently
+// only used for the custom combiner coarse-grain op.
+extern const char kXlaMaxValencyAttr[];
+
 // XLA frontend attribute for how to assign ids to partitions.
 extern const char kXlaShardingStrategyAttr[];

From 41f0a9a1487162b7a366efe49bdf9d957a05b3f1 Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Tue, 1 Apr 2025 11:06:53 -0700
Subject: [PATCH 0102/1324] Convert CpuClient to also lazily tuplize arguments.

This avoids constructing an explicit tuplized buffer when it is not necessary.
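Instead of materializing a tupled TrackedCpuDeviceBuffer up front,
ExecuteHelper now allocates only a tuple index table once every leaf buffer is
ready, and MemoryForAllocation resolves the tuple parameter's root shape index
to that table and each leaf shape index directly to the corresponding untupled
argument buffer. Conceptually, the table is just an array of pointers (a
simplified sketch, not the actual helper names):

  // One void* slot per leaf argument buffer.
  void** table = Allocate(num_leaves * sizeof(void*));
  for (int i = 0; i < num_leaves; ++i) {
    table[i] = leaf_buffers[i]->untyped_data();
  }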
PiperOrigin-RevId: 742768145 --- third_party/xla/xla/pjrt/cpu/cpu_client.cc | 83 ++++++++++++++-------- 1 file changed, 55 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index c72b2f4c8a9aed..2d1c77b22dc9aa 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/pjrt/cpu/cpu_client.h" +#include + #define EIGEN_USE_THREADS #include @@ -1308,20 +1310,43 @@ static absl::StatusOr MemoryForAllocation( const BufferAllocation& allocation, absl::Span constants, absl::Span const> arguments, - BufferAlloc& buffer_alloc, BufferAllocAndCopy& buffer_alloc_and_copy) { + BufferAlloc& buffer_alloc, BufferAllocAndCopy& buffer_alloc_and_copy, + const tsl::AsyncValueRef& tuple_index_table) { BufferInfo buffer_info; if (allocation.is_entry_computation_parameter()) { - auto [can_donate, arg] = arguments[allocation.parameter_number()]; - tsl::AsyncValuePtr out = - arg->Buffer(allocation.param_shape_index()); - CHECK_EQ(allocation.size(), arg->BufferSize(allocation.param_shape_index())) + bool can_donate = false; + TrackedCpuDeviceBuffer* arg = nullptr; + size_t buffer_size; + tsl::AsyncValuePtr out; + if (tuple_index_table) { + if (allocation.param_shape_index().empty()) { + out = tuple_index_table.AsPtr(); + buffer_size = arguments.size() * sizeof(void*); + } else if (allocation.param_shape_index().size() == 1) { + std::tie(can_donate, arg) = + arguments[allocation.param_shape_index()[0]]; + out = arg->Buffer({}); + buffer_size = arg->BufferSize({}); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Nested tuples are not supported for argument: ", + allocation.parameter_number(), + " at shape index:", allocation.param_shape_index().ToString())); + } + } else { + std::tie(can_donate, arg) = arguments[allocation.parameter_number()]; + out = arg->Buffer(allocation.param_shape_index()); + buffer_size = arg->BufferSize(allocation.param_shape_index()); + } + CHECK_EQ(allocation.size(), buffer_size) << "Size mismatch on param " << allocation.parameter_number() << " at shape index " << allocation.param_shape_index().ToString(); // If we don't own the buffer, we can't overwrite it or donate it. For // example we might be pointing to a buffer owned by the client whose // lifetime will not extend past the lifetime of the donated input buffer. 
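+    // Under lazy tuplization `arg` may be null here (when this allocation is
+    // the tuple index table itself), so the ownership check below must guard
+    // against dereferencing it.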
- if ((!can_donate || !arg->owns_buffers()) && !allocation.is_readonly()) { + if ((!can_donate || (arg && !arg->owns_buffers())) && + !allocation.is_readonly()) { auto copy = tsl::MakeUnconstructedAsyncValueRef(); buffer_alloc_and_copy.src_buffers.push_back(out.CopyRef()); @@ -1335,8 +1360,8 @@ static absl::StatusOr MemoryForAllocation( } buffer_info.buffer = out.CopyRef(); - buffer_info.owns_buffer = arg->owns_buffers(); - buffer_info.buffer_size = arg->BufferSize(allocation.param_shape_index()); + buffer_info.owns_buffer = !arg || arg->owns_buffers(); + buffer_info.buffer_size = buffer_size; return buffer_info; } else if (allocation.is_constant() && @@ -1372,14 +1397,15 @@ static absl::StatusOr> CreateBufferTable( const BufferAssignment& assignment, absl::Span constants, absl::Span const> arguments, - BufferAlloc& buffer_alloc, BufferAllocAndCopy& buffer_alloc_and_copy) { + BufferAlloc& buffer_alloc, BufferAllocAndCopy& buffer_alloc_and_copy, + const tsl::AsyncValueRef& tuple_index_table) { std::vector buffer_table(assignment.Allocations().size()); for (BufferAllocation::Index i = 0; i < buffer_table.size(); ++i) { const BufferAllocation& allocation = assignment.GetAllocation(i); TF_ASSIGN_OR_RETURN( buffer_table[i], MemoryForAllocation(allocation, constants, arguments, buffer_alloc, - buffer_alloc_and_copy)); + buffer_alloc_and_copy, tuple_index_table)); } return std::move(buffer_table); } @@ -1546,29 +1572,29 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( // Tuplize the inputs if compiler expects a single tuple argument but runtime // gets many inputs that are not yet tupled. - std::unique_ptr tuplized_arg; + tsl::AsyncValueRef tuple_index_table; if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { - bool owns_buffers = true; absl::InlinedVector, 4> leaf_buffers; - absl::InlinedVector leaf_buffer_sizes; leaf_buffers.reserve(tracked_buffers.size()); - leaf_buffer_sizes.reserve(tracked_buffers.size()); for (const auto& tracked_buffer : tracked_buffers) { - owns_buffers = owns_buffers && tracked_buffer.second->owns_buffers(); auto span = tracked_buffer.second->Buffers(); leaf_buffers.insert(leaf_buffers.end(), span.begin(), span.end()); - auto size_span = tracked_buffer.second->BufferSizes(); - leaf_buffer_sizes.insert(leaf_buffer_sizes.end(), size_span.begin(), - size_span.end()); } - - // Tuplize into a single input. - tracked_buffers.clear(); - tuplized_arg = std::make_unique( - /*is_tuple=*/true, owns_buffers, std::move(leaf_buffers), - std::move(leaf_buffer_sizes), - /*definition_event=*/tsl::MakeConstructedAsyncValueRef()); - tracked_buffers.emplace_back(false, tuplized_arg.get()); + tuple_index_table = tsl::MakeUnconstructedAsyncValueRef(); + tsl::RunWhenReady( + absl::MakeConstSpan(leaf_buffers), + [buffers = leaf_buffers, tuple_index_table = tuple_index_table] { + size_t index_table_byte_size = buffers.size() * sizeof(void*); + // We assume tuple table allocations will not fail. 
+ tuple_index_table.emplace( + CpuDeviceMemory::Allocate(index_table_byte_size).value()); + uintptr_t* index_table = + reinterpret_cast(tuple_index_table->untyped_data()); + for (int i = 0; i < buffers.size(); ++i) { + index_table[i] = + absl::bit_cast(buffers[i]->untyped_data()); + } + }); } auto* cpu_executable = @@ -1581,7 +1607,8 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( std::vector buffer_table, CreateBufferTable(cpu_executable->buffer_assignment(), cpu_executable->constants(), tracked_buffers, - buffer_alloc, buffer_alloc_and_copy)); + buffer_alloc, buffer_alloc_and_copy, + tuple_index_table)); auto result_buffers_info = CreateResultBufferInfo(result_buffer_indices_, buffer_table); @@ -1774,7 +1801,7 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( device_assignment = std::move(device_assignment), cpu_run_options = std::move(cpu_run_options), compute_reservation = std::move(compute_reservation), - tuplized_arg = std::move(tuplized_arg), + tuple_index_table = std::move(tuple_index_table), donation_transactions = std::move(donation_transactions), scoped_async_execution = std::move(scoped_async_execution), input_deps_avs = std::move(input_deps_avs_copy), From 88cbfdecee62ee0cc00a03e009bd7018db8d0a74 Mon Sep 17 00:00:00 2001 From: Tom Ward Date: Tue, 1 Apr 2025 11:06:58 -0700 Subject: [PATCH 0103/1324] Reverts 59f2d850d14539a80060544a81eedcd3ed3b0067 PiperOrigin-RevId: 742768178 --- third_party/xla/xla/client/local_client.cc | 25 +- third_party/xla/xla/client/local_client.h | 9 - third_party/xla/xla/pjrt/BUILD | 5 - .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 34 +++ .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 10 + .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 5 +- .../xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc | 12 +- .../pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc | 9 +- .../xla/pjrt/pjrt_stream_executor_client.cc | 232 ++++++------------ .../xla/pjrt/pjrt_stream_executor_client.h | 27 +- .../xla/pjrt/stream_executor_executable.cc | 57 ----- .../xla/xla/pjrt/stream_executor_executable.h | 75 ++---- .../xla/xla/service/buffer_assignment.cc | 5 +- 13 files changed, 153 insertions(+), 352 deletions(-) diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index fd1bbae81e4503..2a1a6d74448b10 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -473,31 +473,18 @@ LocalClient::CompileAheadOfTime( absl::StatusOr> LocalClient::Load( const std::string& serialized_aot_result, const ExecutableBuildOptions& options) { - TF_ASSIGN_OR_RETURN(Compiler * compiler, - Compiler::GetForPlatform(platform())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr aot_result, - compiler->LoadAotCompilationResult(serialized_aot_result)); - return LoadInternal(std::move(aot_result), compiler, options); -} - -absl::StatusOr> LocalClient::Load( - std::unique_ptr aot_result, - const ExecutableBuildOptions& options) { - TF_ASSIGN_OR_RETURN(Compiler * compiler, - Compiler::GetForPlatform(platform())); - return LoadInternal(std::move(aot_result), compiler, options); -} - -absl::StatusOr> LocalClient::LoadInternal( - std::unique_ptr aot_result, Compiler* compiler, - const ExecutableBuildOptions& options) { TF_ASSIGN_OR_RETURN(ExecutableBuildOptions updated_options, UpdateBuildOptions(options, default_device_ordinal())); TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, backend().stream_executor(updated_options.device_ordinal())); + TF_ASSIGN_OR_RETURN(Compiler * compiler, + Compiler::GetForPlatform(platform())); + 
TF_ASSIGN_OR_RETURN( + std::unique_ptr aot_result, + compiler->LoadAotCompilationResult(serialized_aot_result)); + TF_ASSIGN_OR_RETURN( std::unique_ptr executable, std::move(*aot_result).LoadExecutable(compiler, executor)); diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index c687766fcc37b8..c9ee317bc42e5a 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -174,11 +174,6 @@ class LocalClient : public Client { const std::string& serialized_aot_result, const ExecutableBuildOptions& options); - // Variant of `Load()` that accepts an AotCompilationResult. - absl::StatusOr> Load( - std::unique_ptr aot_result, - const ExecutableBuildOptions& options); - // Copy the literal data to the device with the given ordinal and return as a // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. If null, the default memory allocator for the @@ -249,10 +244,6 @@ class LocalClient : public Client { private: LocalService* local_service_; - - absl::StatusOr> LoadInternal( - std::unique_ptr aot_result, Compiler* compiler, - const ExecutableBuildOptions& options); }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index b03ab7cb0020d5..2ee156f7faefe0 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -483,13 +483,9 @@ cc_library( srcs = ["stream_executor_executable.cc"], hdrs = ["stream_executor_executable.h"], deps = [ - ":host_memory_spaces", ":pjrt_common", ":pjrt_executable", ":stream_executor_executable_proto_cc", - "//xla:shape_util", - "//xla:util", - "//xla/client:local_client", "//xla/hlo/ir:hlo", "//xla/service:compiler", "@com_google_absl//absl/container:flat_hash_map", @@ -521,7 +517,6 @@ cc_library( ":pjrt_future", ":pjrt_stream_executor_device_description", ":semaphore", - ":stream_executor_executable", ":tracked_device_buffer", ":transpose", ":utils", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index e6d1ff8645a312..0c859c6e49fa34 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -906,6 +906,40 @@ StreamExecutorGpuClient::LoadSerialized(absl::string_view serialized, load_options); } +absl::StatusOr> +StreamExecutorGpuClient::Load(std::unique_ptr executable) { + auto se_executable = absl::WrapUnique( + tensorflow::down_cast(executable.release())); + + CompileOptions compile_options = se_executable->compile_options(); + CompileOptions input_options = compile_options; + TF_RETURN_IF_ERROR(compile_options.ApplyAllOptionOverrides()); + TF_ASSIGN_OR_RETURN(ExecutableExtras extras, + GetExecutableExtras(&compile_options)); + + // Load Executable from AOT compilation result. 
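+  // Each AotCompilationResult is round-tripped through its serialized form
+  // and LocalClient::Load below, which rebuilds a LocalExecutable for this
+  // client's stream executor.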
+ std::vector> local_executables; + local_executables.reserve(se_executable->aot_executables().size()); + for (std::unique_ptr& aot_executable : + se_executable->aot_executables()) { + TF_ASSIGN_OR_RETURN(std::string serialized, + aot_executable->SerializeAsString()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr local_executable, + client()->Load(serialized, compile_options.executable_build_options)); + local_executables.push_back(std::move(local_executable)); + } + bool parameter_is_tupled_arguments = + compile_options.parameter_is_tupled_arguments; + auto ret = std::make_unique( + std::move(local_executables), parameter_is_tupled_arguments, + std::move(extras.device_assignment), std::move(input_options), + std::move(extras.addressable_device_logical_ids), + std::move(extras.addressable_devices), this); + TF_RETURN_IF_ERROR(ret->SetUpDonation(parameter_is_tupled_arguments)); + return std::unique_ptr(std::move(ret)); +} + namespace { #if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index a78bf13e641fb2..7d40b04535a839 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -140,6 +140,16 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { return &topology_; } + absl::StatusOr> Load( + std::unique_ptr executable, + const LoadOptions& load_options) override { + return absl::WrapUnique( + tensorflow::down_cast(executable.release())); + } + + absl::StatusOr> Load( + std::unique_ptr executable); + absl::StatusOr> LoadSerialized( absl::string_view serialized, std::optional options, const LoadOptions& load_options); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 9737ae193e0db5..78bb3c249c085c 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -1154,7 +1154,7 @@ TEST(StreamExecutorGpuClientTest, GetDeviceFabricInfo) { &executor->GetDeviceDescription())) == 9) { auto fabric_info = GetDeviceFabricInfo(executor->device_ordinal()); if (fabric_info.ok()) { - ADD_FAILURE(); + EXPECT_FALSE(true); } } } @@ -1924,7 +1924,8 @@ TEST(StreamExecutorGpuClientTest, MlirParameterLayoutFromOptionsIsSetInHlo) { xla::CompileOptions options; options.argument_layouts = { {ShapeUtil::MakeShapeWithDenseLayout(S32, {2, 2, 2}, {0, 2, 1})}}; - TF_ASSERT_OK_AND_ASSIGN(auto executable, client->Compile(*module, options)); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + client->CompileAndLoad(*module, options)); TF_ASSERT_OK_AND_ASSIGN(auto modules, executable->GetHloModules()); auto first_param_layout = diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index 375748f4fad3bc..ace0900148d075 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -131,7 +131,7 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, if (!options.target_config) { if (client != nullptr) { TF_RETURN_IF_ERROR(IsValidTopologyAndClientForCompile(topology, client)); - return client->Compile(computation, options); + return client->CompileAndLoad(computation, options); } const auto& gpu_topology = tensorflow::down_cast( @@ -177,17 +177,21 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, const int num_partitions = 
hlo_module->config().num_partitions(); const std::string name = hlo_module->name(); const std::string fingerprint = hlo_module->GetFingerprint128(); + const int num_outputs = hlo_module->result_shape().IsTuple() + ? hlo_module->result_shape().tuple_shapes_size() + : 1; auto unique_module_group = std::make_unique(std::move(hlo_module)); TF_ASSIGN_OR_RETURN( std::vector> aot_results, gpu_compiler->CompileAheadOfTime(std::move(unique_module_group), aot_options)); + std::vector> output_memory_kinds(1); + output_memory_kinds[0].resize(num_outputs, + StreamExecutorGpuHbmMemorySpace::kKind); return std::make_unique( std::move(input_options), std::move(aot_results), num_replicas, - num_partitions, name, fingerprint, - /*default_memory_kind=*/StreamExecutorGpuHbmMemorySpace::kKind, - /*local_executables=*/std::nullopt); + num_partitions, name, fingerprint, std::move(output_memory_kinds)); } absl::StatusOr> diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc index 2ae9ee769fc260..8c0b9bc6d3a182 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc @@ -99,9 +99,8 @@ TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileMlirAndLoad) { TF_ASSERT_OK_AND_ASSIGN(auto executable, compiler.Compile(opts, mlir_module.get(), *topology, /*client=*/nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto loaded_executable, - se_client->Load(std::move(executable), LoadOptions())); + TF_ASSERT_OK_AND_ASSIGN(auto loaded_executable, + se_client->Load(std::move(executable))); TF_ASSERT_OK_AND_ASSIGN( std::vector>> result, @@ -130,7 +129,7 @@ TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileXlaAndLoad) { compiler.Compile(opts, computation, *topology, /*client=*/nullptr)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr loaded_executable, - se_client->Load(std::move(executable), LoadOptions())); + se_client->Load(std::move(executable))); TF_ASSERT_OK_AND_ASSIGN( std::vector>> result, loaded_executable->Execute(/*argument_handles=*/{{}}, {})); @@ -193,7 +192,7 @@ TEST(StreamExecutorGpuCompilerTest, SuccessSerializeDeserialize) { compiler.Compile(opts, computation, *topology, /*client=*/nullptr)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr loaded_executable, - se_client->Load(std::move(executable), LoadOptions())); + se_client->Load(std::move(executable))); // Serialize the executable and deserialize it without failure. TF_ASSERT_OK_AND_ASSIGN(std::string serialized_executable, diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 20eb79146e2c6d..097bc9071a85a0 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -117,7 +117,6 @@ limitations under the License. 
#include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/profiling/device_time_measurement.h" #include "xla/pjrt/semaphore.h" -#include "xla/pjrt/stream_executor_executable.h" #include "xla/pjrt/tracked_device_buffer.h" #include "xla/pjrt/transpose.h" #include "xla/pjrt/utils.h" @@ -3520,14 +3519,14 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) { return extras; } -absl::StatusOr> +absl::StatusOr> PjRtStreamExecutorClient::CompileInternal( const XlaComputation& computation, const std::vector& argument_layout_pointers, LayoutCanonicalizationCallback layout_canonicalization_callback, CompileOptions options) { - tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::CompileInternal"); - VLOG(1) << "PjRtStreamExecutorClient::CompileInternal"; + tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile"); + VLOG(1) << "PjRtStreamExecutorClient::Compile"; if (key_value_store().has_value() && !options.executable_build_options.key_value_store()) { options.executable_build_options.set_key_value_store(*key_value_store()); @@ -3536,6 +3535,13 @@ PjRtStreamExecutorClient::CompileInternal( TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides()); + TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&options)); + std::shared_ptr& device_assignment = + extras.device_assignment; + std::vector& + addressable_device_logical_ids = extras.addressable_device_logical_ids; + std::vector& addressable_devices = extras.addressable_devices; + // It is important to set the canonicalization callback after creating // a copy of the options so that the executable's options remain without // the callback - the callback would break the executable's serializability. @@ -3549,39 +3555,26 @@ PjRtStreamExecutorClient::CompileInternal( client()->Compile(computation, argument_layout_pointers, options.executable_build_options)); - return BuildPjRtExecutable(std::move(local_executables), input_options); -} + auto executable = std::make_unique( + std::move(local_executables), options.parameter_is_tupled_arguments, + std::move(device_assignment), std::move(input_options), + std::move(addressable_device_logical_ids), std::move(addressable_devices), + this); -absl::StatusOr> -PjRtStreamExecutorClient::Compile(const XlaComputation& computation, - CompileOptions options) { - std::vector argument_layout_pointers; - const ExecutableBuildOptions& build_options = - options.executable_build_options; - const bool allow_auto_layout = - build_options.has_debug_options() && - build_options.debug_options().xla_pjrt_allow_auto_layout_in_hlo(); - TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( - computation, - [local_client = client(), - allow_auto_layout](Shape shape) -> absl::StatusOr { - if (allow_auto_layout && !shape.has_layout()) { - return shape; - } - return local_client->backend() - .transfer_manager() - ->ChooseCompactLayoutForShape(shape); - }, - options.argument_layouts, &options.executable_build_options, - &argument_layout_pointers)); - return CompileInternal(computation, argument_layout_pointers, - /* layout_canonicalization_callback = */ nullptr, - options); + TF_RETURN_IF_ERROR( + executable->SetUpDonation(options.parameter_is_tupled_arguments)); + const auto& ex_options = options.executable_build_options; + if (ex_options.has_debug_options() && + ex_options.debug_options().xla_gpu_dump_hlo_unoptimized_snapshots()) { + executable->SetInputHloSnapshotBits( + computation.proto(), options.executable_build_options.debug_options()); + } + return 
std::unique_ptr(std::move(executable)); } -absl::StatusOr> -PjRtStreamExecutorClient::Compile(mlir::ModuleOp module, - CompileOptions options) { +absl::StatusOr> +PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, + CompileOptions options) { XlaComputation xla_computation; const ExecutableBuildOptions& exec_build_options = options.executable_build_options; @@ -3593,7 +3586,7 @@ PjRtStreamExecutorClient::Compile(mlir::ModuleOp module, // If the compile options specify argument layout, then let's // fall back to using the options to determine layouts. if (options.argument_layouts) { - return Compile(xla_computation, options); + return CompileAndLoad(xla_computation, options); } TF_ASSIGN_OR_RETURN(std::vector arg_layout_modes, @@ -3643,17 +3636,28 @@ PjRtStreamExecutorClient::Compile(mlir::ModuleOp module, absl::StatusOr> PjRtStreamExecutorClient::CompileAndLoad(const XlaComputation& computation, CompileOptions options) { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - Compile(computation, options)); - return Load(std::move(executable), LoadOptions()); -} - -absl::StatusOr> -PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, - CompileOptions options) { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - Compile(module, options)); - return Load(std::move(executable), LoadOptions()); + std::vector argument_layout_pointers; + const ExecutableBuildOptions& build_options = + options.executable_build_options; + const bool allow_auto_layout = + build_options.has_debug_options() && + build_options.debug_options().xla_pjrt_allow_auto_layout_in_hlo(); + TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( + computation, + [local_client = client(), + allow_auto_layout](Shape shape) -> absl::StatusOr { + if (allow_auto_layout && !shape.has_layout()) { + return shape; + } + return local_client->backend() + .transfer_manager() + ->ChooseCompactLayoutForShape(shape); + }, + options.argument_layouts, &options.executable_build_options, + &argument_layout_pointers)); + return CompileInternal(computation, argument_layout_pointers, + /* layout_canonicalization_callback = */ nullptr, + options); } absl::StatusOr PjRtStreamExecutorClient::SerializeExecutable( @@ -3690,61 +3694,21 @@ absl::StatusOr PjRtStreamExecutorClient::SerializeExecutable( return proto.SerializeAsString(); } -absl::StatusOr> -PjRtStreamExecutorClient::BuildPjRtExecutable( - std::vector> local_executables, - CompileOptions compile_options) { - if (local_executables.empty()) { - return Internal("No local executable"); - } - if (local_executables.size() != 1) { - return Unimplemented("Multiple executables are not supported"); - } - Executable* built_executable = local_executables[0]->executable(); - Compiler* compiler = client_->backend().compiler(); - TF_ASSIGN_OR_RETURN(std::unique_ptr aot_result, - compiler->Export(built_executable)); - std::vector> aot_results; - aot_results.push_back(std::move(aot_result)); - - if (!built_executable->has_module()) { - return absl::InternalError("Executable does not have HLO modules."); - } - const auto& hlo_module = built_executable->module(); - - const int num_replicas = hlo_module.config().replica_count(); - const int num_partitions = hlo_module.config().num_partitions(); - const std::string name = hlo_module.name(); - const std::string fingerprint = hlo_module.GetFingerprint128(); - - return std::make_unique( - std::move(compile_options), std::move(aot_results), num_replicas, - num_partitions, name, fingerprint, memory_spaces()[0]->kind(), - 
std::move(local_executables)); -} - -absl::StatusOr> -PjRtStreamExecutorClient::DeserializeExecutable( - absl::string_view serialized, - std::optional compile_options) { - TF_ASSIGN_OR_RETURN( - auto local_executables_and_options, - DeserializeToLocalExecutable(serialized, compile_options)); - - return BuildPjRtExecutable(std::move(local_executables_and_options.first), - local_executables_and_options.second); -} - -absl::StatusOr< - std::pair>, CompileOptions>> -PjRtStreamExecutorClient::DeserializeToLocalExecutable( - absl::string_view serialized, std::optional options) { +absl::StatusOr> +PjRtStreamExecutorClient::LoadSerializedExecutable( + absl::string_view serialized, std::optional options, + const LoadOptions& load_options) { ExecutableAndOptionsProto proto; if (serialized.size() > std::numeric_limits::max()) { - return Internal("Proto is too large (>2GB)"); + return Internal( + "PjRtStreamExecutorClient::DeserializeExecutable proto too large " + "(>2GB)"); } if (!proto.ParseFromArray(serialized.data(), serialized.size())) { - return Internal("Proto deserialization failed"); + return Internal( + "PjRtStreamExecutorClient::DeserializeExecutable proto " + "deserialization " + "failed"); } CompileOptions compile_options; @@ -3754,39 +3718,11 @@ PjRtStreamExecutorClient::DeserializeToLocalExecutable( TF_ASSIGN_OR_RETURN(compile_options, CompileOptions::FromProto(proto.compile_options())); } - - tsl::profiler::TraceMe traceme( - "PjRtStreamExecutorClient::DeserializeToLocalExecutable"); - VLOG(1) << "PjRtStreamExecutorClient::DeserializeToLocalExecutable"; - - std::string str = std::move(*proto.mutable_serialized_executable()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr loaded, - client()->Load(str, compile_options.executable_build_options)); - - std::vector> local_executables; - local_executables.push_back(std::move(loaded)); - - return std::make_pair(std::move(local_executables), compile_options); -} - -absl::StatusOr> -PjRtStreamExecutorClient::LoadSerializedExecutable( - absl::string_view serialized, std::optional options, - const LoadOptions& load_options) { - TF_ASSIGN_OR_RETURN(auto local_executables_and_options, - DeserializeToLocalExecutable(serialized, options)); - return LoadInternal(std::move(local_executables_and_options.first), - local_executables_and_options.second); -} - -absl::StatusOr> -PjRtStreamExecutorClient::LoadInternal( - std::vector> local_executables, - CompileOptions compile_options) { auto input_options = compile_options; - TF_RETURN_IF_ERROR(compile_options.ApplyAllOptionOverrides()); + tsl::profiler::TraceMe traceme( + "PjRtStreamExecutorClient::DeserializeExecutable"); + VLOG(1) << "PjRtStreamExecutorClient::DeserializeExecutable"; TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&compile_options)); @@ -3796,8 +3732,13 @@ PjRtStreamExecutorClient::LoadInternal( addressable_device_logical_ids = extras.addressable_device_logical_ids; std::vector& addressable_devices = extras.addressable_devices; - HloModuleProto hlo_module_proto = - local_executables[0]->executable()->module().ToProto(); + std::string str = std::move(*proto.mutable_serialized_executable()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr loaded, + client()->Load(str, compile_options.executable_build_options)); + + std::vector> local_executables; + local_executables.push_back(std::move(loaded)); auto executable = std::make_unique( std::move(local_executables), @@ -3808,40 +3749,9 @@ PjRtStreamExecutorClient::LoadInternal( TF_RETURN_IF_ERROR( 
executable->SetUpDonation(compile_options.parameter_is_tupled_arguments)); - const auto& ex_options = compile_options.executable_build_options; - if (ex_options.has_debug_options() && - ex_options.debug_options().xla_gpu_dump_hlo_unoptimized_snapshots()) { - executable->SetInputHloSnapshotBits( - std::move(hlo_module_proto), - compile_options.executable_build_options.debug_options()); - } return std::unique_ptr(std::move(executable)); } -absl::StatusOr> -PjRtStreamExecutorClient::Load(std::unique_ptr executable, - const LoadOptions& load_options) { - auto se_executable = absl::WrapUnique( - tensorflow::down_cast(executable.release())); - CompileOptions compile_options = se_executable->compile_options(); - - tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Load"); - VLOG(1) << "PjRtStreamExecutorClient::Load"; - - // Load Executables from AOT compilation results. - std::vector> local_executables; - local_executables.reserve(se_executable->aot_executables().size()); - for (int i = 0; i < se_executable->aot_executables().size(); ++i) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr local_executable, - client()->Load(std::move(se_executable->aot_executables()[i]), - compile_options.executable_build_options)); - local_executables.push_back(std::move(local_executable)); - } - - return LoadInternal(std::move(local_executables), compile_options); -} - bool PjRtStreamExecutorClient::IsDmaMapped(const void* data_start, int64_t transfer_size) { absl::MutexLock lock(&dma_maps_mutex_); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index b2e0f804f7d2bd..1e5c8c7cc3b001 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -276,22 +276,14 @@ class PjRtStreamExecutorClient : public PjRtClient { absl::StatusOr GetDefaultLayout( PrimitiveType element_type, absl::Span dims) override; - absl::StatusOr> Compile( - const XlaComputation& computation, CompileOptions options) override; absl::StatusOr> CompileAndLoad( const XlaComputation& computation, CompileOptions options) override; - absl::StatusOr> Compile( - mlir::ModuleOp mlir_module, CompileOptions options) override; absl::StatusOr> CompileAndLoad( mlir::ModuleOp mlir_module, CompileOptions options) override; virtual absl::StatusOr SerializeExecutable( const PjRtLoadedExecutable& executable) const; - absl::StatusOr> DeserializeExecutable( - absl::string_view serialized, - std::optional options) override; - // For PjRtStreamExecutorClient, `options` is mandatory. // This function returns an InvalidArgument error if `std::nullopt` is passed. 
// TODO(b/237720161): make it actually optional @@ -300,10 +292,6 @@ class PjRtStreamExecutorClient : public PjRtClient { std::optional options, const LoadOptions& load_options) override; - absl::StatusOr> Load( - std::unique_ptr executable, - const LoadOptions& load_options) override; - absl::StatusOr> GetHloCostAnalysis() const override; @@ -429,25 +417,12 @@ class PjRtStreamExecutorClient : public PjRtClient { }; absl::StatusOr GetExecutableExtras(CompileOptions* options); - absl::StatusOr> CompileInternal( + absl::StatusOr> CompileInternal( const XlaComputation& computation, const std::vector& argument_layout_pointers, LayoutCanonicalizationCallback layout_canonicalization_callback, CompileOptions options); - absl::StatusOr> BuildPjRtExecutable( - std::vector> local_executables, - CompileOptions compile_options); - - absl::StatusOr< - std::pair>, CompileOptions>> - DeserializeToLocalExecutable(absl::string_view serialized, - std::optional options); - - absl::StatusOr> LoadInternal( - std::vector> local_executables, - CompileOptions compile_options); - absl::StatusOr> BufferFromHostBufferInternal( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, diff --git a/third_party/xla/xla/pjrt/stream_executor_executable.cc b/third_party/xla/xla/pjrt/stream_executor_executable.cc index a500bcb07d33c3..ab82fdaf0c2ec1 100644 --- a/third_party/xla/xla/pjrt/stream_executor_executable.cc +++ b/third_party/xla/xla/pjrt/stream_executor_executable.cc @@ -18,15 +18,11 @@ limitations under the License. #include #include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/stream_executor_executable.pb.h" #include "xla/service/compiler.h" -#include "xla/shape.h" -#include "xla/util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -54,57 +50,4 @@ absl::StatusOr StreamExecutorExecutable::SerializeExecutable() compile_options_.ToProto()); return proto.SerializeAsString(); } - -namespace { - -absl::StatusOr MemoryKindFromSimpleShape( - const Shape& shape, absl::string_view default_memory_kind) { - if (!shape.has_layout()) { - return default_memory_kind; - } - switch (shape.layout().memory_space()) { - case Layout::kHostMemorySpace: - return PinnedHostMemorySpace::kKind; - case Layout::kGenericFastMemorySpace: - case Layout::kDefaultMemorySpace: - return default_memory_kind; - default: - return InvalidArgument("Unexpected memory space %d in output layout", - shape.layout().memory_space()); - } -} - -absl::StatusOr> MemoryKindsFromShape( - const Shape& shape, absl::string_view default_memory_kind) { - if (!shape.IsTuple()) { - TF_ASSIGN_OR_RETURN(absl::string_view memory_kind, - MemoryKindFromSimpleShape(shape, default_memory_kind)); - return {{memory_kind}}; - } - std::vector result; - result.reserve(shape.tuple_shapes_size()); - for (const auto& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN( - absl::string_view element_memory_kind, - MemoryKindFromSimpleShape(element_shape, default_memory_kind)); - result.push_back(element_memory_kind); - } - return result; -} - -} // namespace - -absl::StatusOr>> -StreamExecutorExecutable::GetOutputMemoryKinds() const { - TF_ASSIGN_OR_RETURN(auto shapes, GetOutputShapes()); - std::vector> out; - out.reserve(shapes.size()); - for (const auto& shape : shapes) { - TF_ASSIGN_OR_RETURN(std::vector memory_kind, - MemoryKindsFromShape(shape, default_memory_kind_)); - out.push_back(memory_kind); - } - return out; -} - } // namespace xla 
diff --git a/third_party/xla/xla/pjrt/stream_executor_executable.h b/third_party/xla/xla/pjrt/stream_executor_executable.h index aab85c2552924b..826e4f2912f176 100644 --- a/third_party/xla/xla/pjrt/stream_executor_executable.h +++ b/third_party/xla/xla/pjrt/stream_executor_executable.h @@ -27,7 +27,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/client/local_client.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" @@ -36,32 +35,19 @@ limitations under the License. namespace xla { class StreamExecutorExecutable : public PjRtExecutable { public: - // TODO(b/407470731): Make `xla::AotCompilationResult` provide APIs for - // getting code size, memory stats, etc, so that we do not need to rely on - // `LocalExecutable`s for such information. StreamExecutorExecutable( const CompileOptions& compile_options, - std::vector> aot_executables, + std::vector> executables, int num_replicas, int num_partitions, absl::string_view name, - absl::string_view fingerprint, absl::string_view default_memory_kind, - std::optional>> - local_executables) + absl::string_view fingerprint, + std::optional>> + output_memory_kinds) : compile_options_(compile_options), - aot_executables_(std::move(aot_executables)), + aot_executables_(std::move(executables)), num_replicas_(num_replicas), num_partitions_(num_partitions), name_(name), - fingerprint_(fingerprint), - default_memory_kind_(default_memory_kind), - local_executables_(std::move(local_executables)) { - if (local_executables_.has_value()) { - std::vector> hlo_modules; - for (const auto& local_executable : *local_executables_) { - hlo_modules.push_back(local_executable->executable()->shared_module()); - } - hlo_modules_ = std::move(hlo_modules); - } - } + fingerprint_(fingerprint) {} absl::StatusOr SerializeExecutable() const override; @@ -73,51 +59,22 @@ class StreamExecutorExecutable : public PjRtExecutable { } absl::StatusOr>> GetHloModules() const override { - if (!hlo_modules_.has_value()) { - return absl::UnimplementedError("GetHloModules is not supported."); - } - return *hlo_modules_; + return absl::UnimplementedError("GetHloModules is not supported."); } - absl::StatusOr GetCompiledMemoryStats() const override { - if (!local_executables_.has_value()) { - return absl::UnimplementedError( - "Retrieving CompiledMemoryStats is not supported."); - } - if (local_executables_->size() != 1) { - return absl::UnimplementedError( - "Retrieving CompiledMemoryStats is not supported for multiple " - "executables."); - } - CompiledMemoryStats memory_stats = CompiledMemoryStats(); - memory_stats.generated_code_size_in_bytes = SizeOfGeneratedCodeInBytes(); - const HloProto* proto = (*local_executables_)[0]->executable()->hlo_proto(); - if (proto != nullptr) { - memory_stats.serialized_hlo_proto = proto->SerializeAsString(); + absl::StatusOr>> + GetOutputMemoryKinds() const override { + if (output_memory_kinds_.has_value()) { + return *output_memory_kinds_; } - memory_stats.PopulateBufferStatsFromAllocations( - (*local_executables_)[0]->executable()->GetAllocations()); - return memory_stats; + return absl::UnimplementedError("GetOutputMemoryKinds is not supported."); } - - absl::StatusOr>> - GetOutputMemoryKinds() const override; - absl::StatusOr> GetCostAnalysis() const override { return absl::UnimplementedError("GetCostAnalysis is not supported."); } - int64_t SizeOfGeneratedCodeInBytes() const override { - 
if (!local_executables_.has_value()) {
-      return 0;
-    }
-    int64_t size = 0;
-    for (auto& executable : *local_executables_) {
-      size += executable->executable()->SizeOfGeneratedCodeInBytes();
-    }
-    return size;
-  }
+  int64_t SizeOfGeneratedCodeInBytes() const override { return 0; }

   const CompileOptions& compile_options() const { return compile_options_; }
   std::vector>& aot_executables() {
@@ -131,14 +88,12 @@
  private:
   CompileOptions compile_options_;
   std::vector> aot_executables_;
-  std::optional>> hlo_modules_;
   int num_replicas_;
   int num_partitions_;
   std::string name_;
   std::string fingerprint_;
-  absl::string_view default_memory_kind_;
-  std::optional>>
-      local_executables_;
+  std::optional>>
+      output_memory_kinds_;
 };
 } // namespace xla
diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc
index 1e6ea95df721fa..c809be5b0971f6 100644
--- a/third_party/xla/xla/service/buffer_assignment.cc
+++ b/third_party/xla/xla/service/buffer_assignment.cc
@@ -1225,11 +1225,8 @@ absl::StatusOr> BufferAssignment::FromProto(
       absl::c_copy(alloc_proto.parameter_shape_index(),
                    std::back_inserter(shape_idx_vals));
       ShapeIndex shape_index(shape_idx_vals);
-      const bool parameter_has_alias =
-          module->input_output_alias_config().ParameterHasAlias(
-              alloc_proto.parameter_number(), shape_index);
       allocation->set_entry_computation_parameter(
-          alloc_proto.parameter_number(), shape_index, parameter_has_alias);
+          alloc_proto.parameter_number(), shape_index, false);
     }

 // Process each logical buffer assigned to the current allocation and create

From 661bc9909db82ef5d3f9ebbf1391dcf1913ed69e Mon Sep 17 00:00:00 2001
From: Niklas Vangerow
Date: Tue, 1 Apr 2025 11:08:23 -0700
Subject: [PATCH 0104/1324] Fix repeated InterpreterClient execution of the
 same executable.

Previously, it was not possible to run the same InterpreterClient-generated
PjRtLoadedExecutable twice.
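
For illustration, a condensed sketch of the pattern that used to fail, adapted from the new interpreter_client_test.cc added below; it assumes a scope where TF_ASSIGN_OR_RETURN can propagate a non-OK status:

  InterpreterClient client;
  XlaBuilder builder("repro");
  const Shape shape = ShapeUtil::MakeShape(S32, {4});
  Add(Parameter(&builder, 0, shape, "parameter0"),
      ConstantR1<int32_t>(&builder, {1, 1, 1, 1}));
  TF_ASSIGN_OR_RETURN(XlaComputation computation, builder.Build());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
                      client.CompileAndLoad(computation, CompileOptions()));
  for (int run = 0; run < 2; ++run) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtBuffer> argument,
        client.BufferFromHostLiteral(
            LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4}),
            client.memory_spaces().front()));
    // The second iteration used to fail: the shared HloEvaluator still held
    // visit state from the first run. Evaluate() now resets that state first.
    TF_ASSIGN_OR_RETURN(
        std::vector<std::unique_ptr<PjRtBuffer>> results,
        executable->ExecuteSharded({argument.get()},
                                   client.addressable_devices().front(),
                                   ExecuteOptions()));
  }
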
PiperOrigin-RevId: 742768821 --- third_party/xla/xla/pjrt/interpreter/BUILD | 21 ++++ .../pjrt/interpreter/interpreter_client.cc | 1 + .../interpreter/interpreter_client_test.cc | 109 ++++++++++++++++++ 3 files changed, 131 insertions(+) create mode 100644 third_party/xla/xla/pjrt/interpreter/interpreter_client_test.cc diff --git a/third_party/xla/xla/pjrt/interpreter/BUILD b/third_party/xla/xla/pjrt/interpreter/BUILD index 23921961dee897..fd24ba80df4485 100644 --- a/third_party/xla/xla/pjrt/interpreter/BUILD +++ b/third_party/xla/xla/pjrt/interpreter/BUILD @@ -1,3 +1,4 @@ +load("//xla:xla.default.bzl", "xla_cc_test") load("//xla/tsl:tsl.bzl", "internal_visibility") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") @@ -67,3 +68,23 @@ cc_library( "@local_tsl//tsl/platform:fingerprint", ], ) + +xla_cc_test( + name = "interpreter_client_test", + srcs = ["interpreter_client_test.cc"], + deps = [ + ":interpreter_client", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_executable", + "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/pjrt/interpreter/interpreter_client.cc b/third_party/xla/xla/pjrt/interpreter/interpreter_client.cc index a68960ed0fcdb4..05cc0f7976c2ea 100644 --- a/third_party/xla/xla/pjrt/interpreter/interpreter_client.cc +++ b/third_party/xla/xla/pjrt/interpreter/interpreter_client.cc @@ -314,6 +314,7 @@ absl::StatusOr InterpreterLoadedExecutable::Evaluate( const HloComputation& computation, absl::Span arg_literals) { absl::MutexLock lock(&hlo_evaluator_lock_); + hlo_evaluator_->ResetVisitStates(); return hlo_evaluator_->Evaluate(computation, arg_literals); } diff --git a/third_party/xla/xla/pjrt/interpreter/interpreter_client_test.cc b/third_party/xla/xla/pjrt/interpreter/interpreter_client_test.cc new file mode 100644 index 00000000000000..bfe677111272fd --- /dev/null +++ b/third_party/xla/xla/pjrt/interpreter/interpreter_client_test.cc @@ -0,0 +1,109 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/pjrt/interpreter/interpreter_client.h" + +#include +#include +#include + +#include +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { +namespace { + +TEST(InterpreterClientTest, EvaluateOnceShouldSucceed) { + InterpreterClient client; + const Shape shape = ShapeUtil::MakeShape(S32, {4}); + XlaBuilder builder("test"); + Add(Parameter(&builder, 0, shape, "parameter0"), + ConstantR1(&builder, absl::Span{1, 1, 1, 1})); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + client.CompileAndLoad(computation, CompileOptions())); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr argument, + client.BufferFromHostLiteral( + LiteralUtil::CreateR1(absl::Span{1, 2, 3, 4}), + client.memory_spaces().front())); + TF_ASSERT_OK_AND_ASSIGN( + std::vector>> results, + executable->Execute({{argument.get()}}, ExecuteOptions())); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results.front().size(), 1); + Literal result_literal(shape); + TF_ASSERT_OK(results.front().front()->ToLiteralSync(&result_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal( + result_literal, + LiteralUtil::CreateR1(absl::Span{2, 3, 4, 5}))); +} + +TEST(InterpreterClientTest, EvaluateTwiceShouldSucceed) { + InterpreterClient client; + const Shape shape = ShapeUtil::MakeShape(S32, {4}); + XlaBuilder builder("test"); + Add(Parameter(&builder, 0, shape, "parameter0"), + ConstantR1(&builder, absl::Span{1, 1, 1, 1})); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + client.CompileAndLoad(computation, CompileOptions())); + + std::vector>> results; + for (const Literal& execution_argument : + {LiteralUtil::CreateR1(absl::Span{1, 2, 3, 4}), + LiteralUtil::CreateR1(absl::Span{4, 3, 2, 1})}) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr argument_buffer, + client.BufferFromHostLiteral(execution_argument, + client.memory_spaces().front())); + TF_ASSERT_OK_AND_ASSIGN( + results.emplace_back(), + executable->ExecuteSharded({argument_buffer.get()}, + client.addressable_devices().front(), + ExecuteOptions())); + } + + std::vector expected_literals; + expected_literals.push_back( + LiteralUtil::CreateR1(absl::Span{2, 3, 4, 5})); + expected_literals.push_back( + LiteralUtil::CreateR1(absl::Span{5, 4, 3, 2})); + + ASSERT_EQ(results.size(), 2); + Literal actual_literal(shape); + for (int i = 0; i < results.size(); ++i) { + const std::vector>& actual_buffers = results[i]; + EXPECT_EQ(actual_buffers.size(), 1); + TF_ASSERT_OK(actual_buffers.front()->ToLiteralSync(&actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(actual_literal, expected_literals[i])); + } +} + +} // namespace +} // namespace xla From 106318740f81528902f2b52e9e549fa2859abf27 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Tue, 1 Apr 2025 11:12:08 -0700 Subject: [PATCH 0105/1324] `mhlo.copy` op, remove the folder Motivation: `mhlo.copy` ODS spec has a `hasfolder` enabled. 
During MLIR dialect conversion, if the mhlo.copy op is not a legal op
(for example, during chlo -> stablehlo conversion), MLIR tries to legalize
the op by folding it if the op has `hasFolder` enabled
(`mlir/lib/Transforms/Utils/DialectConversion.cpp;l=2058-2067`). This
results in the mhlo.copy op being removed.

Ideally:
1. `mhlo.copy` should be used only when it is needed during chlo -> stablehlo.
2. If `mhlo.copy` is expected to be present in the input module during MLIR
   conversion (chlo -> stablehlo), it should not be folded and should be
   preserved.

To fix this, remove the folder for the op.

PiperOrigin-RevId: 742770255
---
 third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc       |  6 ------
 third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td       |  1 -
 .../tests/Dialect/mhlo/canonicalize/canonicalize.mlir | 11 -----------
 .../tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir    |  4 ++--
 .../tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir |  3 +--
 .../python/pjrt_ifrt/xla_executable_impl_test_lib.cc  | 10 +++-------
 .../xla/service/spmd/shardy/shardy_xla_pass_test.cc   |  6 +-----
 7 files changed, 7 insertions(+), 34 deletions(-)

diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
index 30c3153d63fc62..48c3c255fd3173 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
@@ -3763,12 +3763,6 @@ LogicalResult RecvOp::verify() {
                       getIsHostTransfer(), getResults());
 }

-//===----------------------------------------------------------------------===//
-// CopyOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult CopyOp::fold(FoldAdaptor) { return getOperand(); }
-
 //===----------------------------------------------------------------------===//
 // ReduceWindowOp
 //===----------------------------------------------------------------------===//
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td
index d735b5b46e6315..9d5ba5557e6c8f 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td
+++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td
@@ -2525,7 +2525,6 @@ def MHLO_CopyOp: MHLO_Op<"copy",
   );
   let results = (outs MHLO_TensorOrTokenOrTuple:$result);
   let hasCustomHLOConverter = 1;
-  let hasFolder = 1;
   let assemblyFormat = [{
     operands attr-dict
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir
index 134f5c5f57a9a0..258ad44e8b36bd 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir
@@ -211,17 +211,6 @@ func.func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
   func.return %2 : tensor<2x2xi32>
 }

-////////
-// CopyOp
-
-// CHECK-LABEL: func @fold_copy
-// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
-func.func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
-  // CHECK: return [[ARG]]
-  %0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
-  func.return %0 : tensor<1x4xf32>
-}
-
 ////////
 // DynamicBroadcastInDimOp

diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index 6e91799784d0bd..a3ea23db213eeb 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ 
b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir @@ -798,10 +798,10 @@ func.func @complex_sin(%arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex) -> tensor<2x4x8xf32> { %0 = "mhlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>) + // CHECK-PRIMITIVE: linalg.map + // CHECK: return [[ARG]] : tensor<2x4x8xf32> func.return %0 : tensor<2x4x8xf32> } -// CHECK: return [[ARG]] : tensor<2x4x8xf32> -// CHECK-PRIMITIVE: return [[ARG]] : tensor<2x4x8xf32> // ----- diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 99fc79ea16ab85..4a2c82e90763a7 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -2124,8 +2124,7 @@ func.func @op_bitcast(%arg0: tensor) -> tensor { // ----- func.func @op_copy(%arg0: tensor) -> tensor { - // mhlo.copy is immediately folded away at the first opportunity, - // so it doesn't seem to be possible to capture it in FileCheck tests. + // expected-error@+1 {{failed to legalize operation 'mhlo.copy' that was explicitly marked illegal}} %0 = "mhlo.copy"(%arg0) : (tensor) -> tensor func.return %0 : tensor } diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc index 5f2e0239125156..92e897e1b9646d 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc @@ -118,13 +118,9 @@ TEST(LoadedExecutableImplTest, GetDonatableInputIndices) { %arg2: tensor<2x3xf32> {jax.buffer_donor = true}, %arg3: tensor<2x3xf32> ) -> tensor<2x3xf32> { - %0 = "mhlo.copy"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> - %1 = "mhlo.copy"(%arg1) : (tensor<2x3xf32>) -> tensor<2x3xf32> - %2 = "mhlo.copy"(%arg2) : (tensor<2x3xf32>) -> tensor<2x3xf32> - %3 = "mhlo.copy"(%arg3) : (tensor<2x3xf32>) -> tensor<2x3xf32> - %4 = mhlo.add %0, %1 : tensor<2x3xf32> - %5 = mhlo.add %2, %3 : tensor<2x3xf32> - %6 = mhlo.add %4, %5 : tensor<2x3xf32> + %4 = stablehlo.add %arg0, %arg1 : tensor<2x3xf32> + %5 = stablehlo.add %arg2, %arg3 : tensor<2x3xf32> + %6 = stablehlo.add %4, %5 : tensor<2x3xf32> return %6 : tensor<2x3xf32> }})"; diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc index d2fe76322c5e1a..41ffe8c23ff6cf 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc @@ -71,7 +71,7 @@ TEST_F(ShardyXLATest, AllowSpmdShardingPropagationParametersOutputRespected) { %dot = f32[8,256,128] dot(%p0, %p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={2}, sharding={devices=[2,2,2]<=[8]} - ROOT %copy = f32[8,256,128] copy(%dot), sharding={replicated} + ROOT %tuple = (f32[8,256,128]) tuple(%dot) })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); @@ -118,10 +118,6 @@ TEST_F(ShardyXLATest, ElementWise) { EXPECT_THAT(module->entry_computation()->root_instruction(), op::Sharding("{devices=[2,1]<=[2]}")); - - // Conversions HLO -> StableHLO -> HLO removes the copy instructions. 
- auto* copy = FindInstruction(module.get(), xla::HloOpcode::kCopy); - EXPECT_EQ(copy, nullptr); } TEST_F(ShardyXLATest, CostantSplitter) { From 287f82198476183aeb9f6cc73a17ffd4b73fe66e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 11:25:28 -0700 Subject: [PATCH 0106/1324] Update ops-related pbtxt files. PiperOrigin-RevId: 742775606 --- ...MatmulCustomCombinerOnTcWithCsrInput.pbtxt | 79 +++++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 79 +++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt new file mode 100644 index 00000000000000..89d2a61e5c53ae --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt @@ -0,0 +1,79 @@ +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + output_arg { + name: "activations" + type: DT_FLOAT + } + output_arg { + name: "preserved_valencies" + type: DT_INT32 + } + output_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + attr { + name: "input_size" + type: "int" + has_minimum: true + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_computation" + type: "func" + } + attr { + name: "quantization_config_low" + type: "float" + } + attr { + name: "quantization_config_high" + type: "float" + } + attr { + name: "quantization_config_num_buckets" + type: "int" + has_minimum: true + } + attr { + name: "table_name" + type: "string" + } +} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index df2588439eda9b..3650e9d60bed3d 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -67609,6 +67609,85 @@ op { has_minimum: true } } +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + output_arg { + name: "activations" + type: DT_FLOAT + } + output_arg { + name: "preserved_valencies" + type: DT_INT32 + } + output_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + attr { + name: "input_size" + type: "int" + has_minimum: true + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_computation" + type: "func" + } + attr { + name: "quantization_config_low" + type: "float" + } + attr { + name: "quantization_config_high" + type: 
"float" + } + attr { + name: "quantization_config_num_buckets" + type: "int" + has_minimum: true + } + attr { + name: "table_name" + type: "string" + } +} op { name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput" input_arg { From 2b4159bc4019292aabec27d15002c63f82c741cd Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 1 Apr 2025 11:31:21 -0700 Subject: [PATCH 0107/1324] Move kernel_gen-specific passes from MHLO to kernel_gen dir PiperOrigin-RevId: 742777821 --- .../mlir/tools/kernel_gen/kernel_creator.cc | 6 +- .../tests}/broadcast_propagation.mlir | 2 +- .../kernel_gen/tests}/merge_assuming_ops.mlir | 2 +- .../tests/shape_simplification.mlir | 2 +- .../mlir/tools/kernel_gen/transforms/BUILD | 10 +- .../transforms/broadcast_propagation_pass.cc | 60 +++++++----- .../transforms/merge_assuming_ops_pass.cc | 98 +++++++++---------- .../mlir/tools/kernel_gen/transforms/passes.h | 3 + .../tools/kernel_gen/transforms/passes.td | 18 ++++ .../transforms/shape_simplification_pass.cc | 36 +++---- third_party/xla/xla/mlir_hlo/BUILD | 3 - .../mlir_hlo/mhlo/transforms/CMakeLists.txt | 3 - .../mlir_hlo/mhlo/transforms/mhlo_passes.td | 20 ---- .../xla/xla/mlir_hlo/mhlo/transforms/passes.h | 14 --- .../xla/mlir_hlo/mhlo/transforms/rewriters.h | 6 -- 15 files changed, 131 insertions(+), 152 deletions(-) rename {third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo => tensorflow/compiler/mlir/tools/kernel_gen/tests}/broadcast_propagation.mlir (99%) rename {third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo => tensorflow/compiler/mlir/tools/kernel_gen/tests}/merge_assuming_ops.mlir (99%) rename {third_party/xla/xla/mlir_hlo => tensorflow/compiler/mlir/tools/kernel_gen}/tests/shape_simplification.mlir (98%) rename third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc => tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc (91%) rename third_party/xla/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc => tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc (92%) rename third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc => tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc (91%) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 6e16ce51fdc761..5b53a3eb0d7752 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -170,9 +170,9 @@ absl::Status LowerHlotoLoops(mlir::ModuleOp module, pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(mlir::createCSEPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mlir::mhlo::createShapeSimplification()); - pm.addNestedPass(mlir::mhlo::createMergeAssumingOpsPass()); - pm.addNestedPass(mlir::mhlo::createBroadcastPropagationPass()); + pm.addNestedPass(mlir::kernel_gen::createShapeSimplificationPass()); + pm.addNestedPass(mlir::kernel_gen::createMergeAssumingOpsPass()); + pm.addNestedPass(mlir::kernel_gen::createBroadcastPropagationPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(mlir::createCSEPass()); diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/broadcast_propagation.mlir similarity index 99% rename from 
third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir rename to tensorflow/compiler/mlir/tools/kernel_gen/tests/broadcast_propagation.mlir index 4bf50644127e70..f366f1938e0a38 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/broadcast_propagation.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s --split-input-file --mhlo-broadcast-propagation | \ +// RUN: kernel-gen-opt %s --split-input-file --mhlo-broadcast-propagation | \ // RUN: FileCheck %s // CHECK-LABEL: @single_bcast diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/merge_assuming_ops.mlir similarity index 99% rename from third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir rename to tensorflow/compiler/mlir/tools/kernel_gen/tests/merge_assuming_ops.mlir index f8ff1a33d1c97b..d463da199549e3 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/merge_assuming_ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect \ +// RUN: kernel-gen-opt --split-input-file --allow-unregistered-dialect \ // RUN: --mhlo-merge-assuming-ops --canonicalize --cse %s | \ // RUN: FileCheck %s diff --git a/third_party/xla/xla/mlir_hlo/tests/shape_simplification.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/shape_simplification.mlir similarity index 98% rename from third_party/xla/xla/mlir_hlo/tests/shape_simplification.mlir rename to tensorflow/compiler/mlir/tools/kernel_gen/tests/shape_simplification.mlir index 998918bdfa0744..f7ff67753bc235 100644 --- a/third_party/xla/xla/mlir_hlo/tests/shape_simplification.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/shape_simplification.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -split-input-file -shape-simplification %s | FileCheck %s +// RUN: kernel-gen-opt -split-input-file -shape-simplification %s | FileCheck %s // Incompatible shapes. No folding. 
// CHECK-LABEL: func @f diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 13b61395bcda08..331d1aa9c28aa8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -181,15 +181,18 @@ cc_library( cc_library( name = "passes", srcs = [ + "broadcast_propagation_pass.cc", "buffer_reuse_pass.cc", "bufferize_pass.cc", "copy_cleanup_pass.cc", "embed_tf_framework_pass.cc", "func_to_jit_invocations.cc", "fuse_inner_parallel_loops_pass.cc", + "merge_assuming_ops_pass.cc", "parallel_loops_to_sequential.cc", "rewrite_tf_framework_assert.cc", "same_shape_propagation.cc", + "shape_simplification_pass.cc", "shape_to_descriptors_pass.cc", "tensorflow_abi_knowledge_propagation.cc", ], @@ -200,8 +203,6 @@ cc_library( ":embed_tf_framework", # buildcleaner: keep ":kernel_gen_passes_inc_gen", ":tf_framework_legalize_to_llvm", # buildcleaner: keep - ":utils", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -211,6 +212,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:MathDialect", @@ -226,7 +228,9 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:transforms_passes", + "@stablehlo//:base", ], ) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc similarity index 91% rename from third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc rename to tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc index c8268e4335dca2..840de572368c83 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc @@ -15,32 +15,41 @@ limitations under the License. 
==============================================================================*/ #include +#include +#include #include #include #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { -namespace mhlo { +namespace kernel_gen { + +using mhlo::DynamicBroadcastInDimOp; #define GEN_PASS_DEF_BROADCASTPROPAGATIONPASS -#include "mhlo/transforms/mhlo_passes.h.inc" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" namespace { @@ -62,26 +71,28 @@ struct BroadcastIntent { }; } // namespace -} // namespace mhlo +} // namespace kernel_gen } // namespace mlir namespace llvm { +using mlir::kernel_gen::BroadcastIntent; + template <> -struct DenseMapInfo { - static mlir::mhlo::BroadcastIntent getEmptyKey() { +struct DenseMapInfo { + static BroadcastIntent getEmptyKey() { return {DenseMapInfo::getEmptyKey(), DenseMapInfo::getEmptyKey(), DenseMapInfo::getEmptyKey(), DenseMapInfo::getEmptyKey()}; } - static mlir::mhlo::BroadcastIntent getTombstoneKey() { + static BroadcastIntent getTombstoneKey() { return {DenseMapInfo::getTombstoneKey(), DenseMapInfo::getTombstoneKey(), DenseMapInfo::getTombstoneKey(), DenseMapInfo::getTombstoneKey()}; } - static unsigned getHashValue(const mlir::mhlo::BroadcastIntent &intent) { + static unsigned getHashValue(const BroadcastIntent &intent) { return hash_combine( DenseMapInfo::getHashValue(intent.resultType), DenseMapInfo::getHashValue(intent.targetValue), @@ -89,8 +100,7 @@ struct DenseMapInfo { DenseMapInfo::getHashValue( intent.broadcastDimensions)); } - static bool isEqual(const mlir::mhlo::BroadcastIntent &lhs, - const mlir::mhlo::BroadcastIntent &rhs) { + static bool isEqual(const BroadcastIntent &lhs, const BroadcastIntent &rhs) { return lhs == rhs; } }; @@ -98,7 +108,7 @@ struct DenseMapInfo { } // namespace llvm namespace mlir { -namespace mhlo { +namespace kernel_gen { namespace { bool allowsForElementwiseBroadcastPropagation(Operation *op) { @@ -448,9 +458,5 @@ struct BroadcastPropagationPass } // namespace -std::unique_ptr> 
createBroadcastPropagationPass() { - return std::make_unique(); -} - -} // namespace mhlo +} // namespace kernel_gen } // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc similarity index 92% rename from third_party/xla/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc rename to tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc index d6c4b4767297d6..47a0d36fe2b748 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ +#include +#include #include #include #include @@ -21,28 +23,31 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" -#include "mhlo/transforms/rewriters.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { -namespace mhlo { +namespace kernel_gen { #define GEN_PASS_DEF_MERGEASSUMINGOPSPASS -#include "mhlo/transforms/mhlo_passes.h.inc" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" namespace { @@ -340,7 +345,7 @@ struct MergeAssumingOpsPattern : public OpRewritePattern { LogicalResult matchAndRewrite(shape::AssumingOp op, PatternRewriter &rewriter) const override { // Merge assuming op with directly preceding one if both witnesses are - // availiable. + // available. 
auto precedingOp = llvm::dyn_cast_or_null(op->getPrevNode()); if (!precedingOp) return failure(); @@ -422,44 +427,20 @@ struct EliminateDuplicateCstrBroadcastableOps } }; -struct MergeAssumingOpsPass - : public impl::MergeAssumingOpsPassBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - mhlo::populateMergeAssumingOpsPatterns(ctx, &patterns); - GreedyRewriteConfig config; - config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), - config))) { - return signalPassFailure(); - } - } -}; - -} // namespace - void populateMergeAssumingOpsPatterns(MLIRContext *context, RewritePatternSet *patterns) { - // clang-format off patterns->add< EliminateDuplicateCstrBroadcastableOps, InlineBroadcastedShapeOperandsPattern, - MergeAssumingOpsPattern, - MoveElementwiseOpsDownIntoAssumingOpPattern, + MergeAssumingOpsPattern, MoveElementwiseOpsDownIntoAssumingOpPattern, MoveElementwiseOpsUpIntoAssumingOpPattern, MoveUpIntoAssumingOpPattern, MoveUpIntoAssumingOpPattern, MoveUpIntoAssumingOpPattern, MoveUpOutOfAssumingOpPattern, MoveUpOutOfAssumingOpPattern, - MoveUpOutOfAssumingOpPattern, - ShapeReificationPattern>(context); - // clang-format on + MoveUpOutOfAssumingOpPattern, ShapeReificationPattern>( + context); mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns, context); mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context); @@ -470,9 +451,26 @@ void populateMergeAssumingOpsPatterns(MLIRContext *context, tensor::CastOp::getCanonicalizationPatterns(*patterns, context); } -std::unique_ptr> createMergeAssumingOpsPass() { - return std::make_unique(); -} +struct MergeAssumingOpsPass + : public impl::MergeAssumingOpsPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateMergeAssumingOpsPatterns(ctx, &patterns); + GreedyRewriteConfig config; + config.maxIterations = GreedyRewriteConfig::kNoLimit; + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); + } + } +}; + +} // namespace -} // namespace mhlo +} // namespace kernel_gen } // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 45e248ceb904ff..d9dca26c8ce3a3 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -38,6 +38,9 @@ limitations under the License. 
#define GEN_PASS_DECL_PROPAGATESHAPEKNOWLEDGETOKERNELS #define GEN_PASS_DECL_FUSEINNERPARALLELLOOPSPASS #define GEN_PASS_DECL_COPYCLEANUPPASS +#define GEN_PASS_DECL_SHAPESIMPLIFICATIONPASS +#define GEN_PASS_DECL_MERGEASSUMINGOPSPASS +#define GEN_PASS_DECL_BROADCASTPROPAGATIONPASS namespace mlir { namespace kernel_gen { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index 4f92be70d25397..9bd6fb8b2e8bf8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -137,4 +137,22 @@ def CopyCleanupPass : Pass<"copy-cleanup", "mlir::func::FuncOp"> { }]; } +def ShapeSimplificationPass + : Pass<"shape-simplification", "mlir::func::FuncOp"> { + let summary = "Simplify shape ops"; +} + +def MergeAssumingOpsPass : Pass<"mhlo-merge-assuming-ops", "func::FuncOp"> { + let summary = "Prepare moving dynamic broadcasts up over element-wise " + "operations and broadcast the operands rather than the result. This will " + "eventually allow for larger fusions."; +} + +def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "func::FuncOp"> { + let summary = "Move dynamic broadcasts up over element-wise operations and " + "broadcast the operands rather than the result. This will eventually allow " + "for larger fusions."; +} + + #endif // TF_KERNEL_GEN_PASSES diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc similarity index 91% rename from third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc rename to tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc index 1747bd93b492ef..b5ceec7f48e8fc 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc @@ -19,22 +19,22 @@ limitations under the License. 
#include #include -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" namespace mlir { -namespace mhlo { +namespace kernel_gen { -#define GEN_PASS_DEF_SHAPESIMPLIFICATION -#include "mhlo/transforms/mhlo_passes.h.inc" +#define GEN_PASS_DEF_SHAPESIMPLIFICATIONPASS +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" namespace { @@ -219,8 +219,8 @@ struct ExtractFromBroadcastedTensorCanonicalizationPattern } }; -struct ShapeSimplification - : public impl::ShapeSimplificationBase { +struct ShapeSimplificationPass + : public impl::ShapeSimplificationPassBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -249,9 +249,5 @@ struct ShapeSimplification } // namespace -std::unique_ptr> createShapeSimplification() { - return std::make_unique(); -} - -} // namespace mhlo +} // namespace kernel_gen } // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 8f612af8ef5d75..9caf96a8b72f2b 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -363,7 +363,6 @@ cc_library( cc_library( name = "mhlo_passes", srcs = [ - "mhlo/transforms/broadcast_propagation/broadcast_propagation.cc", "mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc", "mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc", "mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc", @@ -378,12 +377,10 @@ cc_library( "mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc", "mhlo/transforms/materialize_broadcasts/materialize_broadcasts.cc", "mhlo/transforms/materialize_broadcasts/materialize_broadcasts_pass.cc", - "mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc", "mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc", "mhlo/transforms/mhlo_passes.h.inc", "mhlo/transforms/optimize_mhlo/optimize_mhlo.cc", "mhlo/transforms/prepare_for_export/prepare_for_export.cc", - "mhlo/transforms/shape_simplification/shape_simplification.cc", "mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc", "mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc", "mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc", diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index 15a015db03cc72..d68486334c7316 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt +++ 
b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -29,7 +29,6 @@ add_public_tablegen_target(MLIRChloLegalizeToHloIncGen) add_mlir_library(MhloPasses - broadcast_propagation/broadcast_propagation.cc collapse_elementwise_map/collapse_elementwise_map.cc convert_to_signless/convert_to_signless_pass.cc expand_hlo_tuples/expand_hlo_tuples.cc @@ -39,11 +38,9 @@ add_mlir_library(MhloPasses legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc materialize_broadcasts/materialize_broadcasts.cc materialize_broadcasts/materialize_broadcasts_pass.cc - merge_assuming_ops/merge_assuming_ops.cc mhlo_flatten_tuple/mhlo_flatten_tuple.cc prepare_for_export/prepare_for_export.cc optimize_mhlo/optimize_mhlo.cc - shape_simplification/shape_simplification.cc sink_constants_to_control_flow/sink_constants_to_control_flow.cc test_infer_shaped_type/test_infer_shaped_type_pass.cc unfuse_batch_norm/unfuse_batch_norm.cc diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td index ab456fcaaea616..8503598e26b64d 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td @@ -129,26 +129,6 @@ def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", let constructor = "createTestInferShapedTypeMethodsPass()"; } -def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "func::FuncOp"> { - let summary = "Move dynamic broadcasts up over element-wise operations and " - "broadcast the operands rather than the result. This will eventually allow " - "for larger fusions."; - let constructor = "createBroadcastPropagationPass()"; -} - -def MergeAssumingOpsPass : Pass<"mhlo-merge-assuming-ops", "func::FuncOp"> { - let summary = "Prepare moving dynamic broadcasts up over element-wise " - "operations and broadcast the operands rather than the result. This will " - "eventually allow for larger fusions."; - let constructor = "createMergeAssumingOpsPass()"; -} - -def ShapeSimplification - : Pass<"shape-simplification", "mlir::func::FuncOp"> { - let summary = "Simplify shape ops"; - let constructor = "createShapeSimplification()"; -} - def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "func::FuncOp"> { let summary = "Test pass for materializing 'broadcast_dimensions' attributes."; let constructor = "createTestUnfuseBatchNormPass()"; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h index 4791d9a94550e1..632f0f955e1b34 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h @@ -53,20 +53,6 @@ createSinkConstantsToControlFlowPass(); /// that do not use intrinsics. std::unique_ptr> createLegalizeTrigonometricToApproximationPass(); - -// Move dynamic broadcasts up over element-wise operations and broadcast the -// operands rather than the result. This will eventually allow for larger -// fusions. -std::unique_ptr> createBroadcastPropagationPass(); - -// Prepare moving dynamic broadcasts up over element-wise operations and -// broadcast the operands rather than the result. This will eventually allow for -// larger fusions. -std::unique_ptr> createMergeAssumingOpsPass(); - -// Pass to simplify shape ops. 
-std::unique_ptr> createShapeSimplification(); - std::unique_ptr> createLegalizeDotToDotGeneralPass(); std::unique_ptr> diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h index ac6949551f0a09..194e92b14757e1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h @@ -101,12 +101,6 @@ inline void populateUnfuseBatchNormPatterns(MLIRContext *context, void populateTrigonometricToApproximationPatterns(MLIRContext *context, RewritePatternSet *patterns); -// Populate patterns to prepare moving dynamic broadcasts up over element-wise -// operations and broadcast the operands rather than the result. This will -// eventually allow for larger fusions. -void populateMergeAssumingOpsPatterns(MLIRContext *context, - RewritePatternSet *patterns); - // Populate patterns to group reduction and parallel dimensions of reduction // operations and realize them through equivalent 1D or 2D reductions. void populateGroupReductionDimensionsPatterns(MLIRContext *context, From c9476414294b07306007a048d3ee7a92957338f3 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Tue, 1 Apr 2025 11:57:56 -0700 Subject: [PATCH 0108/1324] Integrate LLVM at llvm/llvm-project@799e9053641a Updates LLVM usage to match [799e9053641a](https://github.com/llvm/llvm-project/commit/799e9053641a) PiperOrigin-RevId: 742787533 --- third_party/llvm/generated.patch | 545 -------- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 1104 ++++++++--------- third_party/shardy/workspace.bzl | 4 +- .../triton/llvm_integration/cl742325920.patch | 58 + .../triton/llvm_integration/series.bzl | 1 + .../xla/third_party/shardy/temporary.patch | 1104 ++++++++--------- .../xla/third_party/shardy/workspace.bzl | 4 +- .../triton/llvm_integration/cl742325920.patch | 58 + .../triton/llvm_integration/series.bzl | 1 + 10 files changed, 1228 insertions(+), 1655 deletions(-) create mode 100644 third_party/triton/llvm_integration/cl742325920.patch create mode 100644 third_party/xla/third_party/triton/llvm_integration/cl742325920.patch diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 99ef3cb5cdd7c0..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,546 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp ---- a/clang/lib/Driver/ToolChains/Clang.cpp -+++ b/clang/lib/Driver/ToolChains/Clang.cpp -@@ -6397,7 +6397,9 @@ - Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, - options::OPT_fno_convergent_functions); - -- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); -+ // NVPTX doesn't support PGO or coverage -+ if (!Triple.isNVPTX()) -+ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); - - Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); - -diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu ---- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu -+++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu -@@ -0,0 +1,33 @@ -+// Check that profiling/coverage arguments doen't get passed down to device-side -+// compilation. 
-+// -+// -+// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -+// XRUN: -fprofile-generate %s 2>&1 | \ -+// XRUN: FileCheck --check-prefixes=CHECK,PROF %s -+// -+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -+// RUN: -fprofile-instr-generate %s 2>&1 | \ -+// RUN: FileCheck --check-prefixes=CHECK,PROF %s -+// -+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -+// RUN: -coverage %s 2>&1 | \ -+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s -+// -+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -+// RUN: -ftest-coverage %s 2>&1 | \ -+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s -+// -+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -+// RUN: -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \ -+// RUN: FileCheck --check-prefixes=CHECK,PROF %s -+// -+// -+// CHECK-NOT: error: unsupported option '-fprofile -+// CHECK-NOT: error: invalid argument -+// CHECK-DAG: "-fcuda-is-device" -+// CHECK-NOT: "-f{{[^"/]*coverage.*}}" -+// CHECK-NOT: "-fprofile{{[^"]*}}" -+// CHECK: "-triple" "x86_64-unknown-linux-gnu" -+// PROF: "-fprofile{{.*}}" -+// GCOV: "-coverage-notes-file= -diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp ---- a/lldb/tools/lldb-dap/DAP.cpp -+++ b/lldb/tools/lldb-dap/DAP.cpp -@@ -711,12 +711,12 @@ - [](const std::string &message) -> llvm::StringRef { - return message; - }, -- [](const protocol::Response::Message &message) -+ [](const protocol::ResponseMessage &message) - -> llvm::StringRef { - switch (message) { -- case protocol::Response::Message::cancelled: -+ case protocol::eResponseMessageCancelled: - return "cancelled"; -- case protocol::Response::Message::notStopped: -+ case protocol::eResponseMessageNotStopped: - return "notStopped"; - } - }), -diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp ---- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp -+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp -@@ -7,6 +7,7 @@ - //===----------------------------------------------------------------------===// - - #include "Protocol/ProtocolBase.h" -+#include "lldb/lldb-enumerations.h" - #include "llvm/ADT/StringRef.h" - #include "llvm/ADT/StringSwitch.h" - #include "llvm/Support/ErrorHandling.h" -@@ -31,11 +32,8 @@ - - namespace lldb_dap::protocol { - --enum MessageType { -- eMessageTypeRequest, -- eMessageTypeResponse, -- eMessageTypeEvent --}; -+FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse, -+ eMessageTypeEvent}; - - bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) { - auto rawType = Params.getAsString(); -@@ -107,12 +105,12 @@ - - if (R.message) { - assert(!R.success && "message can only be used if success is false"); -- if (const auto *messageEnum = std::get_if(&*R.message)) { -+ if (const auto *messageEnum = std::get_if(&*R.message)) { - switch (*messageEnum) { -- case Response::Message::cancelled: -+ case eResponseMessageCancelled: - Result.insert({"message", "cancelled"}); - break; -- case Response::Message::notStopped: -+ case eResponseMessageNotStopped: - Result.insert({"message", "notStopped"}); - break; - } -@@ -129,16 +127,16 @@ - } - - bool fromJSON(json::Value const &Params, -- std::variant &M, json::Path P) { -+ std::variant &M, json::Path P) { - auto rawMessage = Params.getAsString(); - if (!rawMessage) { - P.report("expected a string"); - return false; - } -- 
std::optional message = -- StringSwitch>(*rawMessage) -- .Case("cancelled", Response::Message::cancelled) -- .Case("notStopped", Response::Message::notStopped) -+ std::optional message = -+ StringSwitch>(*rawMessage) -+ .Case("cancelled", eResponseMessageCancelled) -+ .Case("notStopped", eResponseMessageNotStopped) - .Default(std::nullopt); - if (message) - M = *message; -diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h ---- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h -+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h -@@ -20,6 +20,7 @@ - #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H - #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H - -+#include "lldb/lldb-enumerations.h" - #include "llvm/Support/JSON.h" - #include - #include -@@ -64,15 +65,15 @@ - llvm::json::Value toJSON(const Event &); - bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); - --/// Response for a request. --struct Response { -- enum class Message { -+FLAGS_ENUM(ResponseMessage){ - /// The request was cancelled -- cancelled, -+ eResponseMessageCancelled, - /// The request may be retried once the adapter is in a 'stopped' state -- notStopped, -- }; -+ eResponseMessageNotStopped, -+}; - -+/// Response for a request. -+struct Response { - /// Sequence number of the corresponding request. - int64_t request_seq; - -@@ -90,7 +91,7 @@ - /// Contains the raw error in short form if `success` is false. This raw error - /// might be interpreted by the client and is not shown in the UI. Some - /// predefined values exist. -- std::optional> message; -+ std::optional> message; - - /// Contains request result if success is true and error details if success is - /// false. -diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h ---- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h -+++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h -@@ -22,6 +22,8 @@ - - #include "Protocol/ProtocolBase.h" - #include "Protocol/ProtocolTypes.h" -+#include "lldb/lldb-enumerations.h" -+#include "llvm/ADT/DenseSet.h" - #include "llvm/Support/JSON.h" - #include - #include -@@ -55,26 +57,26 @@ - using DisconnectResponse = VoidResponse; - - /// Features supported by DAP clients. --enum ClientFeature { -- eClientFeatureVariableType, -- eClientFeatureVariablePaging, -- eClientFeatureRunInTerminalRequest, -- eClientFeatureMemoryReferences, -- eClientFeatureProgressReporting, -- eClientFeatureInvalidatedEvent, -- eClientFeatureMemoryEvent, -- /// Client supports the `argsCanBeInterpretedByShell` attribute on the -- /// `runInTerminal` request. -- eClientFeatureArgsCanBeInterpretedByShell, -- eClientFeatureStartDebuggingRequest, -- /// The client will interpret ANSI escape sequences in the display of -- /// `OutputEvent.output` and `Variable.value` fields when -- /// `Capabilities.supportsANSIStyling` is also enabled. -- eClientFeatureANSIStyling, -+FLAGS_ENUM(ClientFeature){ -+ eClientFeatureVariableType, -+ eClientFeatureVariablePaging, -+ eClientFeatureRunInTerminalRequest, -+ eClientFeatureMemoryReferences, -+ eClientFeatureProgressReporting, -+ eClientFeatureInvalidatedEvent, -+ eClientFeatureMemoryEvent, -+ /// Client supports the `argsCanBeInterpretedByShell` attribute on the -+ /// `runInTerminal` request. 
-+ eClientFeatureArgsCanBeInterpretedByShell, -+ eClientFeatureStartDebuggingRequest, -+ /// The client will interpret ANSI escape sequences in the display of -+ /// `OutputEvent.output` and `Variable.value` fields when -+ /// `Capabilities.supportsANSIStyling` is also enabled. -+ eClientFeatureANSIStyling, - }; - - /// Format of paths reported by the debug adapter. --enum PathFormat { ePatFormatPath, ePathFormatURI }; -+FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; - - /// Arguments for `initialize` request. - struct InitializeRequestArguments { -diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h ---- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h -+++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h -@@ -20,6 +20,7 @@ - #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H - #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H - -+#include "lldb/lldb-enumerations.h" - #include "llvm/ADT/DenseSet.h" - #include "llvm/Support/JSON.h" - #include -@@ -56,12 +57,8 @@ - }; - llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); - --enum ColumnType { -- eColumnTypeString, -- eColumnTypeNumber, -- eColumnTypeBoolean, -- eColumnTypeTimestamp --}; -+FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, -+ eColumnTypeTimestamp}; - - /// A ColumnDescriptor specifies what module attribute to show in a column of - /// the modules view, how to format it, and what the column’s label should be. -@@ -90,27 +87,23 @@ - - /// Names of checksum algorithms that may be supported by a debug adapter. - /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. --enum ChecksumAlgorithm { -- eChecksumAlgorithmMD5, -- eChecksumAlgorithmSHA1, -- eChecksumAlgorithmSHA256, -- eChecksumAlgorithmTimestamp --}; -+FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, -+ eChecksumAlgorithmSHA256, -+ eChecksumAlgorithmTimestamp}; - llvm::json::Value toJSON(const ChecksumAlgorithm &); - - /// Describes one or more type of breakpoint a BreakpointMode applies to. This - /// is a non-exhaustive enumeration and may expand as future breakpoint types - /// are added. --enum BreakpointModeApplicability { -- /// In `SourceBreakpoint`'s. -- eBreakpointModeApplicabilitySource, -- /// In exception breakpoints applied in the `ExceptionFilterOptions`. -- eBreakpointModeApplicabilityException, -- /// In data breakpoints requested in the `DataBreakpointInfo` request. -- eBreakpointModeApplicabilityData, -- /// In `InstructionBreakpoint`'s. -- eBreakpointModeApplicabilityInstruction --}; -+FLAGS_ENUM(BreakpointModeApplicability){ -+ /// In `SourceBreakpoint`'s. -+ eBreakpointModeApplicabilitySource, -+ /// In exception breakpoints applied in the `ExceptionFilterOptions`. -+ eBreakpointModeApplicabilityException, -+ /// In data breakpoints requested in the `DataBreakpointInfo` request. -+ eBreakpointModeApplicabilityData, -+ /// In `InstructionBreakpoint`'s. -+ eBreakpointModeApplicabilityInstruction}; - llvm::json::Value toJSON(const BreakpointModeApplicability &); - - /// A `BreakpointMode` is provided as a option when setting breakpoints on -@@ -133,101 +126,101 @@ - llvm::json::Value toJSON(const BreakpointMode &); - - /// Debug Adapter Features flags supported by lldb-dap. --enum AdapterFeature { -- /// The debug adapter supports ANSI escape sequences in styling of -- /// `OutputEvent.output` and `Variable.value` fields. 
-- eAdapterFeatureANSIStyling, -- /// The debug adapter supports the `breakpointLocations` request. -- eAdapterFeatureBreakpointLocationsRequest, -- /// The debug adapter supports the `cancel` request. -- eAdapterFeatureCancelRequest, -- /// The debug adapter supports the `clipboard` context value in the -- /// `evaluate` request. -- eAdapterFeatureClipboardContext, -- /// The debug adapter supports the `completions` request. -- eAdapterFeatureCompletionsRequest, -- /// The debug adapter supports conditional breakpoints. -- eAdapterFeatureConditionalBreakpoints, -- /// The debug adapter supports the `configurationDone` request. -- eAdapterFeatureConfigurationDoneRequest, -- /// The debug adapter supports the `asAddress` and `bytes` fields in the -- /// `dataBreakpointInfo` request. -- eAdapterFeatureDataBreakpointBytes, -- /// The debug adapter supports data breakpoints. -- eAdapterFeatureDataBreakpoints, -- /// The debug adapter supports the delayed loading of parts of the stack, -- /// which requires that both the `startFrame` and `levels` arguments and the -- /// `totalFrames` result of the `stackTrace` request are supported. -- eAdapterFeatureDelayedStackTraceLoading, -- /// The debug adapter supports the `disassemble` request. -- eAdapterFeatureDisassembleRequest, -- /// The debug adapter supports a (side effect free) `evaluate` request for -- /// data hovers. -- eAdapterFeatureEvaluateForHovers, -- /// The debug adapter supports `filterOptions` as an argument on the -- /// `setExceptionBreakpoints` request. -- eAdapterFeatureExceptionFilterOptions, -- /// The debug adapter supports the `exceptionInfo` request. -- eAdapterFeatureExceptionInfoRequest, -- /// The debug adapter supports `exceptionOptions` on the -- /// `setExceptionBreakpoints` request. -- eAdapterFeatureExceptionOptions, -- /// The debug adapter supports function breakpoints. -- eAdapterFeatureFunctionBreakpoints, -- /// The debug adapter supports the `gotoTargets` request. -- eAdapterFeatureGotoTargetsRequest, -- /// The debug adapter supports breakpoints that break execution after a -- /// specified number of hits. -- eAdapterFeatureHitConditionalBreakpoints, -- /// The debug adapter supports adding breakpoints based on instruction -- /// references. -- eAdapterFeatureInstructionBreakpoints, -- /// The debug adapter supports the `loadedSources` request. -- eAdapterFeatureLoadedSourcesRequest, -- /// The debug adapter supports log points by interpreting the `logMessage` -- /// attribute of the `SourceBreakpoint`. -- eAdapterFeatureLogPoints, -- /// The debug adapter supports the `modules` request. -- eAdapterFeatureModulesRequest, -- /// The debug adapter supports the `readMemory` request. -- eAdapterFeatureReadMemoryRequest, -- /// The debug adapter supports restarting a frame. -- eAdapterFeatureRestartFrame, -- /// The debug adapter supports the `restart` request. In this case a client -- /// should not implement `restart` by terminating and relaunching the -- /// adapter but by calling the `restart` request. -- eAdapterFeatureRestartRequest, -- /// The debug adapter supports the `setExpression` request. -- eAdapterFeatureSetExpression, -- /// The debug adapter supports setting a variable to a value. -- eAdapterFeatureSetVariable, -- /// The debug adapter supports the `singleThread` property on the execution -- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, -- /// `stepBack`). 
-- eAdapterFeatureSingleThreadExecutionRequests, -- /// The debug adapter supports stepping back via the `stepBack` and -- /// `reverseContinue` requests. -- eAdapterFeatureStepBack, -- /// The debug adapter supports the `stepInTargets` request. -- eAdapterFeatureStepInTargetsRequest, -- /// The debug adapter supports stepping granularities (argument -- /// `granularity`) for the stepping requests. -- eAdapterFeatureSteppingGranularity, -- /// The debug adapter supports the `terminate` request. -- eAdapterFeatureTerminateRequest, -- /// The debug adapter supports the `terminateThreads` request. -- eAdapterFeatureTerminateThreadsRequest, -- /// The debug adapter supports the `suspendDebuggee` attribute on the -- /// `disconnect` request. -- eAdapterFeatureSuspendDebuggee, -- /// The debug adapter supports a `format` attribute on the `stackTrace`, -- /// `variables`, and `evaluate` requests. -- eAdapterFeatureValueFormattingOptions, -- /// The debug adapter supports the `writeMemory` request. -- eAdapterFeatureWriteMemoryRequest, -- /// The debug adapter supports the `terminateDebuggee` attribute on the -- /// `disconnect` request. -- eAdapterFeatureTerminateDebuggee, -+FLAGS_ENUM(AdapterFeature){ -+ /// The debug adapter supports ANSI escape sequences in styling of -+ /// `OutputEvent.output` and `Variable.value` fields. -+ eAdapterFeatureANSIStyling, -+ /// The debug adapter supports the `breakpointLocations` request. -+ eAdapterFeatureBreakpointLocationsRequest, -+ /// The debug adapter supports the `cancel` request. -+ eAdapterFeatureCancelRequest, -+ /// The debug adapter supports the `clipboard` context value in the -+ /// `evaluate` request. -+ eAdapterFeatureClipboardContext, -+ /// The debug adapter supports the `completions` request. -+ eAdapterFeatureCompletionsRequest, -+ /// The debug adapter supports conditional breakpoints. -+ eAdapterFeatureConditionalBreakpoints, -+ /// The debug adapter supports the `configurationDone` request. -+ eAdapterFeatureConfigurationDoneRequest, -+ /// The debug adapter supports the `asAddress` and `bytes` fields in the -+ /// `dataBreakpointInfo` request. -+ eAdapterFeatureDataBreakpointBytes, -+ /// The debug adapter supports data breakpoints. -+ eAdapterFeatureDataBreakpoints, -+ /// The debug adapter supports the delayed loading of parts of the stack, -+ /// which requires that both the `startFrame` and `levels` arguments and the -+ /// `totalFrames` result of the `stackTrace` request are supported. -+ eAdapterFeatureDelayedStackTraceLoading, -+ /// The debug adapter supports the `disassemble` request. -+ eAdapterFeatureDisassembleRequest, -+ /// The debug adapter supports a (side effect free) `evaluate` request for -+ /// data hovers. -+ eAdapterFeatureEvaluateForHovers, -+ /// The debug adapter supports `filterOptions` as an argument on the -+ /// `setExceptionBreakpoints` request. -+ eAdapterFeatureExceptionFilterOptions, -+ /// The debug adapter supports the `exceptionInfo` request. -+ eAdapterFeatureExceptionInfoRequest, -+ /// The debug adapter supports `exceptionOptions` on the -+ /// `setExceptionBreakpoints` request. -+ eAdapterFeatureExceptionOptions, -+ /// The debug adapter supports function breakpoints. -+ eAdapterFeatureFunctionBreakpoints, -+ /// The debug adapter supports the `gotoTargets` request. -+ eAdapterFeatureGotoTargetsRequest, -+ /// The debug adapter supports breakpoints that break execution after a -+ /// specified number of hits. 
-+ eAdapterFeatureHitConditionalBreakpoints, -+ /// The debug adapter supports adding breakpoints based on instruction -+ /// references. -+ eAdapterFeatureInstructionBreakpoints, -+ /// The debug adapter supports the `loadedSources` request. -+ eAdapterFeatureLoadedSourcesRequest, -+ /// The debug adapter supports log points by interpreting the `logMessage` -+ /// attribute of the `SourceBreakpoint`. -+ eAdapterFeatureLogPoints, -+ /// The debug adapter supports the `modules` request. -+ eAdapterFeatureModulesRequest, -+ /// The debug adapter supports the `readMemory` request. -+ eAdapterFeatureReadMemoryRequest, -+ /// The debug adapter supports restarting a frame. -+ eAdapterFeatureRestartFrame, -+ /// The debug adapter supports the `restart` request. In this case a client -+ /// should not implement `restart` by terminating and relaunching the -+ /// adapter but by calling the `restart` request. -+ eAdapterFeatureRestartRequest, -+ /// The debug adapter supports the `setExpression` request. -+ eAdapterFeatureSetExpression, -+ /// The debug adapter supports setting a variable to a value. -+ eAdapterFeatureSetVariable, -+ /// The debug adapter supports the `singleThread` property on the execution -+ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, -+ /// `stepBack`). -+ eAdapterFeatureSingleThreadExecutionRequests, -+ /// The debug adapter supports stepping back via the `stepBack` and -+ /// `reverseContinue` requests. -+ eAdapterFeatureStepBack, -+ /// The debug adapter supports the `stepInTargets` request. -+ eAdapterFeatureStepInTargetsRequest, -+ /// The debug adapter supports stepping granularities (argument -+ /// `granularity`) for the stepping requests. -+ eAdapterFeatureSteppingGranularity, -+ /// The debug adapter supports the `terminate` request. -+ eAdapterFeatureTerminateRequest, -+ /// The debug adapter supports the `terminateThreads` request. -+ eAdapterFeatureTerminateThreadsRequest, -+ /// The debug adapter supports the `suspendDebuggee` attribute on the -+ /// `disconnect` request. -+ eAdapterFeatureSuspendDebuggee, -+ /// The debug adapter supports a `format` attribute on the `stackTrace`, -+ /// `variables`, and `evaluate` requests. -+ eAdapterFeatureValueFormattingOptions, -+ /// The debug adapter supports the `writeMemory` request. -+ eAdapterFeatureWriteMemoryRequest, -+ /// The debug adapter supports the `terminateDebuggee` attribute on the -+ /// `disconnect` request. -+ eAdapterFeatureTerminateDebuggee, - }; - - /// Information about the capabilities of a debug adapter. -@@ -268,10 +261,10 @@ - }; - llvm::json::Value toJSON(const Capabilities &); - --enum PresentationHint { -- ePresentationHintNormal, -- ePresentationHintEmphasize, -- ePresentationHintDeemphasize, -+FLAGS_ENUM(PresentationHint){ -+ ePresentationHintNormal, -+ ePresentationHintEmphasize, -+ ePresentationHintDeemphasize, - }; - - /// A `Source` is a descriptor for source code. 
It is returned from the debug -diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test ---- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test -+++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test -@@ -1,7 +1,7 @@ - // Header - // - // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) --// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) -+// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) - // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) - // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) - // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) -diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c ---- a/offload/test/offloading/gpupgo/pgo1.c -+++ b/offload/test/offloading/gpupgo/pgo1.c -@@ -14,7 +14,7 @@ - // RUN: %target_triple.%basename_t.clang.profraw | \ - // RUN: %fcheck-generic --check-prefix="CLANG-PGO" - --// REQUIRES: gpu -+// REQUIRES: amdgpu - // REQUIRES: pgo - - int test1(int a) { return a / 2; } -diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c ---- a/offload/test/offloading/gpupgo/pgo2.c -+++ b/offload/test/offloading/gpupgo/pgo2.c -@@ -48,7 +48,7 @@ - // RUN: %target_triple.%basename_t.hfdi.profraw \ - // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" - --// REQUIRES: gpu -+// REQUIRES: amdgpu - // REQUIRES: pgo - - int main() { diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 005737af0dd2ac..fd9baec89202f9 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" - LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" + LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" + LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 5a732df12a541f..1cc7fe6b95c33c 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,566 +1,566 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 509398d..99ef3cb 100644 +index 99ef3cb..509398d 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1 +1,546 @@ +@@ -1,546 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-+diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp -+--- a/clang/lib/Driver/ToolChains/Clang.cpp -++++ b/clang/lib/Driver/ToolChains/Clang.cpp -+@@ -6397,7 +6397,9 @@ -+ Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, -+ options::OPT_fno_convergent_functions); -+ -+- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); -++ // NVPTX doesn't support PGO or coverage -++ if (!Triple.isNVPTX()) -++ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); -+ -+ Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); -+ -+diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu -+--- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu -++++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu -+@@ -0,0 +1,33 @@ -++// Check that profiling/coverage arguments doen't get passed down to device-side -++// compilation. -++// -++// -++// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// XRUN: -fprofile-generate %s 2>&1 | \ -++// XRUN: FileCheck --check-prefixes=CHECK,PROF %s -++// -++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// RUN: -fprofile-instr-generate %s 2>&1 | \ -++// RUN: FileCheck --check-prefixes=CHECK,PROF %s -++// -++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// RUN: -coverage %s 2>&1 | \ -++// RUN: FileCheck --check-prefixes=CHECK,GCOV %s -++// -++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// RUN: -ftest-coverage %s 2>&1 | \ -++// RUN: FileCheck --check-prefixes=CHECK,GCOV %s -++// -++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// RUN: -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \ -++// RUN: FileCheck --check-prefixes=CHECK,PROF %s -++// -++// -++// CHECK-NOT: error: unsupported option '-fprofile -++// CHECK-NOT: error: invalid argument -++// CHECK-DAG: "-fcuda-is-device" -++// CHECK-NOT: "-f{{[^"/]*coverage.*}}" -++// CHECK-NOT: "-fprofile{{[^"]*}}" -++// CHECK: "-triple" "x86_64-unknown-linux-gnu" -++// PROF: "-fprofile{{.*}}" -++// GCOV: "-coverage-notes-file= -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp -+--- a/lldb/tools/lldb-dap/DAP.cpp -++++ b/lldb/tools/lldb-dap/DAP.cpp -+@@ -711,12 +711,12 @@ -+ [](const std::string &message) -> llvm::StringRef { -+ return message; -+ }, -+- [](const protocol::Response::Message &message) -++ [](const protocol::ResponseMessage &message) -+ -> llvm::StringRef { -+ switch (message) { -+- case protocol::Response::Message::cancelled: -++ case protocol::eResponseMessageCancelled: -+ return "cancelled"; -+- case protocol::Response::Message::notStopped: -++ case protocol::eResponseMessageNotStopped: -+ return "notStopped"; -+ } -+ }), -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp -+--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp -++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp -+@@ -7,6 +7,7 @@ -+ //===----------------------------------------------------------------------===// -+ -+ #include "Protocol/ProtocolBase.h" -++#include "lldb/lldb-enumerations.h" -+ #include "llvm/ADT/StringRef.h" -+ #include "llvm/ADT/StringSwitch.h" -+ #include "llvm/Support/ErrorHandling.h" -+@@ -31,11 +32,8 @@ -+ -+ namespace lldb_dap::protocol { -+ -+-enum MessageType { -+- 
eMessageTypeRequest, -+- eMessageTypeResponse, -+- eMessageTypeEvent -+-}; -++FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse, -++ eMessageTypeEvent}; -+ -+ bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) { -+ auto rawType = Params.getAsString(); -+@@ -107,12 +105,12 @@ -+ -+ if (R.message) { -+ assert(!R.success && "message can only be used if success is false"); -+- if (const auto *messageEnum = std::get_if(&*R.message)) { -++ if (const auto *messageEnum = std::get_if(&*R.message)) { -+ switch (*messageEnum) { -+- case Response::Message::cancelled: -++ case eResponseMessageCancelled: -+ Result.insert({"message", "cancelled"}); -+ break; -+- case Response::Message::notStopped: -++ case eResponseMessageNotStopped: -+ Result.insert({"message", "notStopped"}); -+ break; -+ } -+@@ -129,16 +127,16 @@ -+ } -+ -+ bool fromJSON(json::Value const &Params, -+- std::variant &M, json::Path P) { -++ std::variant &M, json::Path P) { -+ auto rawMessage = Params.getAsString(); -+ if (!rawMessage) { -+ P.report("expected a string"); -+ return false; -+ } -+- std::optional message = -+- StringSwitch>(*rawMessage) -+- .Case("cancelled", Response::Message::cancelled) -+- .Case("notStopped", Response::Message::notStopped) -++ std::optional message = -++ StringSwitch>(*rawMessage) -++ .Case("cancelled", eResponseMessageCancelled) -++ .Case("notStopped", eResponseMessageNotStopped) -+ .Default(std::nullopt); -+ if (message) -+ M = *message; -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h -+--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h -++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h -+@@ -20,6 +20,7 @@ -+ #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H -+ #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H -+ -++#include "lldb/lldb-enumerations.h" -+ #include "llvm/Support/JSON.h" -+ #include -+ #include -+@@ -64,15 +65,15 @@ -+ llvm::json::Value toJSON(const Event &); -+ bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); -+ -+-/// Response for a request. -+-struct Response { -+- enum class Message { -++FLAGS_ENUM(ResponseMessage){ -+ /// The request was cancelled -+- cancelled, -++ eResponseMessageCancelled, -+ /// The request may be retried once the adapter is in a 'stopped' state -+- notStopped, -+- }; -++ eResponseMessageNotStopped, -++}; -+ -++/// Response for a request. -++struct Response { -+ /// Sequence number of the corresponding request. -+ int64_t request_seq; -+ -+@@ -90,7 +91,7 @@ -+ /// Contains the raw error in short form if `success` is false. This raw error -+ /// might be interpreted by the client and is not shown in the UI. Some -+ /// predefined values exist. -+- std::optional> message; -++ std::optional> message; -+ -+ /// Contains request result if success is true and error details if success is -+ /// false. -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h -+--- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h -++++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h -+@@ -22,6 +22,8 @@ -+ -+ #include "Protocol/ProtocolBase.h" -+ #include "Protocol/ProtocolTypes.h" -++#include "lldb/lldb-enumerations.h" -++#include "llvm/ADT/DenseSet.h" -+ #include "llvm/Support/JSON.h" -+ #include -+ #include -+@@ -55,26 +57,26 @@ -+ using DisconnectResponse = VoidResponse; -+ -+ /// Features supported by DAP clients. 
-+-enum ClientFeature { -+- eClientFeatureVariableType, -+- eClientFeatureVariablePaging, -+- eClientFeatureRunInTerminalRequest, -+- eClientFeatureMemoryReferences, -+- eClientFeatureProgressReporting, -+- eClientFeatureInvalidatedEvent, -+- eClientFeatureMemoryEvent, -+- /// Client supports the `argsCanBeInterpretedByShell` attribute on the -+- /// `runInTerminal` request. -+- eClientFeatureArgsCanBeInterpretedByShell, -+- eClientFeatureStartDebuggingRequest, -+- /// The client will interpret ANSI escape sequences in the display of -+- /// `OutputEvent.output` and `Variable.value` fields when -+- /// `Capabilities.supportsANSIStyling` is also enabled. -+- eClientFeatureANSIStyling, -++FLAGS_ENUM(ClientFeature){ -++ eClientFeatureVariableType, -++ eClientFeatureVariablePaging, -++ eClientFeatureRunInTerminalRequest, -++ eClientFeatureMemoryReferences, -++ eClientFeatureProgressReporting, -++ eClientFeatureInvalidatedEvent, -++ eClientFeatureMemoryEvent, -++ /// Client supports the `argsCanBeInterpretedByShell` attribute on the -++ /// `runInTerminal` request. -++ eClientFeatureArgsCanBeInterpretedByShell, -++ eClientFeatureStartDebuggingRequest, -++ /// The client will interpret ANSI escape sequences in the display of -++ /// `OutputEvent.output` and `Variable.value` fields when -++ /// `Capabilities.supportsANSIStyling` is also enabled. -++ eClientFeatureANSIStyling, -+ }; -+ -+ /// Format of paths reported by the debug adapter. -+-enum PathFormat { ePatFormatPath, ePathFormatURI }; -++FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; -+ -+ /// Arguments for `initialize` request. -+ struct InitializeRequestArguments { -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h -+--- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h -++++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h -+@@ -20,6 +20,7 @@ -+ #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H -+ #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H -+ -++#include "lldb/lldb-enumerations.h" -+ #include "llvm/ADT/DenseSet.h" -+ #include "llvm/Support/JSON.h" -+ #include -+@@ -56,12 +57,8 @@ -+ }; -+ llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); -+ -+-enum ColumnType { -+- eColumnTypeString, -+- eColumnTypeNumber, -+- eColumnTypeBoolean, -+- eColumnTypeTimestamp -+-}; -++FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, -++ eColumnTypeTimestamp}; -+ -+ /// A ColumnDescriptor specifies what module attribute to show in a column of -+ /// the modules view, how to format it, and what the column’s label should be. -+@@ -90,27 +87,23 @@ -+ -+ /// Names of checksum algorithms that may be supported by a debug adapter. -+ /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. -+-enum ChecksumAlgorithm { -+- eChecksumAlgorithmMD5, -+- eChecksumAlgorithmSHA1, -+- eChecksumAlgorithmSHA256, -+- eChecksumAlgorithmTimestamp -+-}; -++FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, -++ eChecksumAlgorithmSHA256, -++ eChecksumAlgorithmTimestamp}; -+ llvm::json::Value toJSON(const ChecksumAlgorithm &); -+ -+ /// Describes one or more type of breakpoint a BreakpointMode applies to. This -+ /// is a non-exhaustive enumeration and may expand as future breakpoint types -+ /// are added. -+-enum BreakpointModeApplicability { -+- /// In `SourceBreakpoint`'s. -+- eBreakpointModeApplicabilitySource, -+- /// In exception breakpoints applied in the `ExceptionFilterOptions`. 
-+- eBreakpointModeApplicabilityException, -+- /// In data breakpoints requested in the `DataBreakpointInfo` request. -+- eBreakpointModeApplicabilityData, -+- /// In `InstructionBreakpoint`'s. -+- eBreakpointModeApplicabilityInstruction -+-}; -++FLAGS_ENUM(BreakpointModeApplicability){ -++ /// In `SourceBreakpoint`'s. -++ eBreakpointModeApplicabilitySource, -++ /// In exception breakpoints applied in the `ExceptionFilterOptions`. -++ eBreakpointModeApplicabilityException, -++ /// In data breakpoints requested in the `DataBreakpointInfo` request. -++ eBreakpointModeApplicabilityData, -++ /// In `InstructionBreakpoint`'s. -++ eBreakpointModeApplicabilityInstruction}; -+ llvm::json::Value toJSON(const BreakpointModeApplicability &); -+ -+ /// A `BreakpointMode` is provided as a option when setting breakpoints on -+@@ -133,101 +126,101 @@ -+ llvm::json::Value toJSON(const BreakpointMode &); -+ -+ /// Debug Adapter Features flags supported by lldb-dap. -+-enum AdapterFeature { -+- /// The debug adapter supports ANSI escape sequences in styling of -+- /// `OutputEvent.output` and `Variable.value` fields. -+- eAdapterFeatureANSIStyling, -+- /// The debug adapter supports the `breakpointLocations` request. -+- eAdapterFeatureBreakpointLocationsRequest, -+- /// The debug adapter supports the `cancel` request. -+- eAdapterFeatureCancelRequest, -+- /// The debug adapter supports the `clipboard` context value in the -+- /// `evaluate` request. -+- eAdapterFeatureClipboardContext, -+- /// The debug adapter supports the `completions` request. -+- eAdapterFeatureCompletionsRequest, -+- /// The debug adapter supports conditional breakpoints. -+- eAdapterFeatureConditionalBreakpoints, -+- /// The debug adapter supports the `configurationDone` request. -+- eAdapterFeatureConfigurationDoneRequest, -+- /// The debug adapter supports the `asAddress` and `bytes` fields in the -+- /// `dataBreakpointInfo` request. -+- eAdapterFeatureDataBreakpointBytes, -+- /// The debug adapter supports data breakpoints. -+- eAdapterFeatureDataBreakpoints, -+- /// The debug adapter supports the delayed loading of parts of the stack, -+- /// which requires that both the `startFrame` and `levels` arguments and the -+- /// `totalFrames` result of the `stackTrace` request are supported. -+- eAdapterFeatureDelayedStackTraceLoading, -+- /// The debug adapter supports the `disassemble` request. -+- eAdapterFeatureDisassembleRequest, -+- /// The debug adapter supports a (side effect free) `evaluate` request for -+- /// data hovers. -+- eAdapterFeatureEvaluateForHovers, -+- /// The debug adapter supports `filterOptions` as an argument on the -+- /// `setExceptionBreakpoints` request. -+- eAdapterFeatureExceptionFilterOptions, -+- /// The debug adapter supports the `exceptionInfo` request. -+- eAdapterFeatureExceptionInfoRequest, -+- /// The debug adapter supports `exceptionOptions` on the -+- /// `setExceptionBreakpoints` request. -+- eAdapterFeatureExceptionOptions, -+- /// The debug adapter supports function breakpoints. -+- eAdapterFeatureFunctionBreakpoints, -+- /// The debug adapter supports the `gotoTargets` request. -+- eAdapterFeatureGotoTargetsRequest, -+- /// The debug adapter supports breakpoints that break execution after a -+- /// specified number of hits. -+- eAdapterFeatureHitConditionalBreakpoints, -+- /// The debug adapter supports adding breakpoints based on instruction -+- /// references. -+- eAdapterFeatureInstructionBreakpoints, -+- /// The debug adapter supports the `loadedSources` request. 
-+- eAdapterFeatureLoadedSourcesRequest, -+- /// The debug adapter supports log points by interpreting the `logMessage` -+- /// attribute of the `SourceBreakpoint`. -+- eAdapterFeatureLogPoints, -+- /// The debug adapter supports the `modules` request. -+- eAdapterFeatureModulesRequest, -+- /// The debug adapter supports the `readMemory` request. -+- eAdapterFeatureReadMemoryRequest, -+- /// The debug adapter supports restarting a frame. -+- eAdapterFeatureRestartFrame, -+- /// The debug adapter supports the `restart` request. In this case a client -+- /// should not implement `restart` by terminating and relaunching the -+- /// adapter but by calling the `restart` request. -+- eAdapterFeatureRestartRequest, -+- /// The debug adapter supports the `setExpression` request. -+- eAdapterFeatureSetExpression, -+- /// The debug adapter supports setting a variable to a value. -+- eAdapterFeatureSetVariable, -+- /// The debug adapter supports the `singleThread` property on the execution -+- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, -+- /// `stepBack`). -+- eAdapterFeatureSingleThreadExecutionRequests, -+- /// The debug adapter supports stepping back via the `stepBack` and -+- /// `reverseContinue` requests. -+- eAdapterFeatureStepBack, -+- /// The debug adapter supports the `stepInTargets` request. -+- eAdapterFeatureStepInTargetsRequest, -+- /// The debug adapter supports stepping granularities (argument -+- /// `granularity`) for the stepping requests. -+- eAdapterFeatureSteppingGranularity, -+- /// The debug adapter supports the `terminate` request. -+- eAdapterFeatureTerminateRequest, -+- /// The debug adapter supports the `terminateThreads` request. -+- eAdapterFeatureTerminateThreadsRequest, -+- /// The debug adapter supports the `suspendDebuggee` attribute on the -+- /// `disconnect` request. -+- eAdapterFeatureSuspendDebuggee, -+- /// The debug adapter supports a `format` attribute on the `stackTrace`, -+- /// `variables`, and `evaluate` requests. -+- eAdapterFeatureValueFormattingOptions, -+- /// The debug adapter supports the `writeMemory` request. -+- eAdapterFeatureWriteMemoryRequest, -+- /// The debug adapter supports the `terminateDebuggee` attribute on the -+- /// `disconnect` request. -+- eAdapterFeatureTerminateDebuggee, -++FLAGS_ENUM(AdapterFeature){ -++ /// The debug adapter supports ANSI escape sequences in styling of -++ /// `OutputEvent.output` and `Variable.value` fields. -++ eAdapterFeatureANSIStyling, -++ /// The debug adapter supports the `breakpointLocations` request. -++ eAdapterFeatureBreakpointLocationsRequest, -++ /// The debug adapter supports the `cancel` request. -++ eAdapterFeatureCancelRequest, -++ /// The debug adapter supports the `clipboard` context value in the -++ /// `evaluate` request. -++ eAdapterFeatureClipboardContext, -++ /// The debug adapter supports the `completions` request. -++ eAdapterFeatureCompletionsRequest, -++ /// The debug adapter supports conditional breakpoints. -++ eAdapterFeatureConditionalBreakpoints, -++ /// The debug adapter supports the `configurationDone` request. -++ eAdapterFeatureConfigurationDoneRequest, -++ /// The debug adapter supports the `asAddress` and `bytes` fields in the -++ /// `dataBreakpointInfo` request. -++ eAdapterFeatureDataBreakpointBytes, -++ /// The debug adapter supports data breakpoints. 
-++ eAdapterFeatureDataBreakpoints, -++ /// The debug adapter supports the delayed loading of parts of the stack, -++ /// which requires that both the `startFrame` and `levels` arguments and the -++ /// `totalFrames` result of the `stackTrace` request are supported. -++ eAdapterFeatureDelayedStackTraceLoading, -++ /// The debug adapter supports the `disassemble` request. -++ eAdapterFeatureDisassembleRequest, -++ /// The debug adapter supports a (side effect free) `evaluate` request for -++ /// data hovers. -++ eAdapterFeatureEvaluateForHovers, -++ /// The debug adapter supports `filterOptions` as an argument on the -++ /// `setExceptionBreakpoints` request. -++ eAdapterFeatureExceptionFilterOptions, -++ /// The debug adapter supports the `exceptionInfo` request. -++ eAdapterFeatureExceptionInfoRequest, -++ /// The debug adapter supports `exceptionOptions` on the -++ /// `setExceptionBreakpoints` request. -++ eAdapterFeatureExceptionOptions, -++ /// The debug adapter supports function breakpoints. -++ eAdapterFeatureFunctionBreakpoints, -++ /// The debug adapter supports the `gotoTargets` request. -++ eAdapterFeatureGotoTargetsRequest, -++ /// The debug adapter supports breakpoints that break execution after a -++ /// specified number of hits. -++ eAdapterFeatureHitConditionalBreakpoints, -++ /// The debug adapter supports adding breakpoints based on instruction -++ /// references. -++ eAdapterFeatureInstructionBreakpoints, -++ /// The debug adapter supports the `loadedSources` request. -++ eAdapterFeatureLoadedSourcesRequest, -++ /// The debug adapter supports log points by interpreting the `logMessage` -++ /// attribute of the `SourceBreakpoint`. -++ eAdapterFeatureLogPoints, -++ /// The debug adapter supports the `modules` request. -++ eAdapterFeatureModulesRequest, -++ /// The debug adapter supports the `readMemory` request. -++ eAdapterFeatureReadMemoryRequest, -++ /// The debug adapter supports restarting a frame. -++ eAdapterFeatureRestartFrame, -++ /// The debug adapter supports the `restart` request. In this case a client -++ /// should not implement `restart` by terminating and relaunching the -++ /// adapter but by calling the `restart` request. -++ eAdapterFeatureRestartRequest, -++ /// The debug adapter supports the `setExpression` request. -++ eAdapterFeatureSetExpression, -++ /// The debug adapter supports setting a variable to a value. -++ eAdapterFeatureSetVariable, -++ /// The debug adapter supports the `singleThread` property on the execution -++ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, -++ /// `stepBack`). -++ eAdapterFeatureSingleThreadExecutionRequests, -++ /// The debug adapter supports stepping back via the `stepBack` and -++ /// `reverseContinue` requests. -++ eAdapterFeatureStepBack, -++ /// The debug adapter supports the `stepInTargets` request. -++ eAdapterFeatureStepInTargetsRequest, -++ /// The debug adapter supports stepping granularities (argument -++ /// `granularity`) for the stepping requests. -++ eAdapterFeatureSteppingGranularity, -++ /// The debug adapter supports the `terminate` request. -++ eAdapterFeatureTerminateRequest, -++ /// The debug adapter supports the `terminateThreads` request. -++ eAdapterFeatureTerminateThreadsRequest, -++ /// The debug adapter supports the `suspendDebuggee` attribute on the -++ /// `disconnect` request. -++ eAdapterFeatureSuspendDebuggee, -++ /// The debug adapter supports a `format` attribute on the `stackTrace`, -++ /// `variables`, and `evaluate` requests. 
-++ eAdapterFeatureValueFormattingOptions, -++ /// The debug adapter supports the `writeMemory` request. -++ eAdapterFeatureWriteMemoryRequest, -++ /// The debug adapter supports the `terminateDebuggee` attribute on the -++ /// `disconnect` request. -++ eAdapterFeatureTerminateDebuggee, -+ }; -+ -+ /// Information about the capabilities of a debug adapter. -+@@ -268,10 +261,10 @@ -+ }; -+ llvm::json::Value toJSON(const Capabilities &); -+ -+-enum PresentationHint { -+- ePresentationHintNormal, -+- ePresentationHintEmphasize, -+- ePresentationHintDeemphasize, -++FLAGS_ENUM(PresentationHint){ -++ ePresentationHintNormal, -++ ePresentationHintEmphasize, -++ ePresentationHintDeemphasize, -+ }; -+ -+ /// A `Source` is a descriptor for source code. It is returned from the debug -+diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test -+--- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test -++++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test -+@@ -1,7 +1,7 @@ -+ // Header -+ // -+ // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) -+-// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) -++// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) -+ // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) -+ // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) -+ // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) -+diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c -+--- a/offload/test/offloading/gpupgo/pgo1.c -++++ b/offload/test/offloading/gpupgo/pgo1.c -+@@ -14,7 +14,7 @@ -+ // RUN: %target_triple.%basename_t.clang.profraw | \ -+ // RUN: %fcheck-generic --check-prefix="CLANG-PGO" -+ -+-// REQUIRES: gpu -++// REQUIRES: amdgpu -+ // REQUIRES: pgo -+ -+ int test1(int a) { return a / 2; } -+diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c -+--- a/offload/test/offloading/gpupgo/pgo2.c -++++ b/offload/test/offloading/gpupgo/pgo2.c -+@@ -48,7 +48,7 @@ -+ // RUN: %target_triple.%basename_t.hfdi.profraw \ -+ // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" -+ -+-// REQUIRES: gpu -++// REQUIRES: amdgpu -+ // REQUIRES: pgo -+ -+ int main() { +-diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp +---- a/clang/lib/Driver/ToolChains/Clang.cpp +-+++ b/clang/lib/Driver/ToolChains/Clang.cpp +-@@ -6397,7 +6397,9 @@ +- Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, +- options::OPT_fno_convergent_functions); +- +-- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); +-+ // NVPTX doesn't support PGO or coverage +-+ if (!Triple.isNVPTX()) +-+ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); +- +- Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); +- +-diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu +---- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu +-+++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu +-@@ -0,0 +1,33 @@ +-+// Check that profiling/coverage arguments doen't get passed down to device-side +-+// compilation. 
+-+// +-+// +-+// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// XRUN: -fprofile-generate %s 2>&1 | \ +-+// XRUN: FileCheck --check-prefixes=CHECK,PROF %s +-+// +-+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// RUN: -fprofile-instr-generate %s 2>&1 | \ +-+// RUN: FileCheck --check-prefixes=CHECK,PROF %s +-+// +-+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// RUN: -coverage %s 2>&1 | \ +-+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s +-+// +-+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// RUN: -ftest-coverage %s 2>&1 | \ +-+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s +-+// +-+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// RUN: -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \ +-+// RUN: FileCheck --check-prefixes=CHECK,PROF %s +-+// +-+// +-+// CHECK-NOT: error: unsupported option '-fprofile +-+// CHECK-NOT: error: invalid argument +-+// CHECK-DAG: "-fcuda-is-device" +-+// CHECK-NOT: "-f{{[^"/]*coverage.*}}" +-+// CHECK-NOT: "-fprofile{{[^"]*}}" +-+// CHECK: "-triple" "x86_64-unknown-linux-gnu" +-+// PROF: "-fprofile{{.*}}" +-+// GCOV: "-coverage-notes-file= +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp +---- a/lldb/tools/lldb-dap/DAP.cpp +-+++ b/lldb/tools/lldb-dap/DAP.cpp +-@@ -711,12 +711,12 @@ +- [](const std::string &message) -> llvm::StringRef { +- return message; +- }, +-- [](const protocol::Response::Message &message) +-+ [](const protocol::ResponseMessage &message) +- -> llvm::StringRef { +- switch (message) { +-- case protocol::Response::Message::cancelled: +-+ case protocol::eResponseMessageCancelled: +- return "cancelled"; +-- case protocol::Response::Message::notStopped: +-+ case protocol::eResponseMessageNotStopped: +- return "notStopped"; +- } +- }), +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +---- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +-+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +-@@ -7,6 +7,7 @@ +- //===----------------------------------------------------------------------===// +- +- #include "Protocol/ProtocolBase.h" +-+#include "lldb/lldb-enumerations.h" +- #include "llvm/ADT/StringRef.h" +- #include "llvm/ADT/StringSwitch.h" +- #include "llvm/Support/ErrorHandling.h" +-@@ -31,11 +32,8 @@ +- +- namespace lldb_dap::protocol { +- +--enum MessageType { +-- eMessageTypeRequest, +-- eMessageTypeResponse, +-- eMessageTypeEvent +--}; +-+FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse, +-+ eMessageTypeEvent}; +- +- bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) { +- auto rawType = Params.getAsString(); +-@@ -107,12 +105,12 @@ +- +- if (R.message) { +- assert(!R.success && "message can only be used if success is false"); +-- if (const auto *messageEnum = std::get_if(&*R.message)) { +-+ if (const auto *messageEnum = std::get_if(&*R.message)) { +- switch (*messageEnum) { +-- case Response::Message::cancelled: +-+ case eResponseMessageCancelled: +- Result.insert({"message", "cancelled"}); +- break; +-- case Response::Message::notStopped: +-+ case eResponseMessageNotStopped: +- Result.insert({"message", "notStopped"}); +- break; +- } +-@@ -129,16 +127,16 @@ +- } +- +- bool fromJSON(json::Value const &Params, +-- std::variant &M, json::Path P) { +-+ std::variant &M, json::Path P) { +- auto rawMessage = 
Params.getAsString(); +- if (!rawMessage) { +- P.report("expected a string"); +- return false; +- } +-- std::optional message = +-- StringSwitch>(*rawMessage) +-- .Case("cancelled", Response::Message::cancelled) +-- .Case("notStopped", Response::Message::notStopped) +-+ std::optional message = +-+ StringSwitch>(*rawMessage) +-+ .Case("cancelled", eResponseMessageCancelled) +-+ .Case("notStopped", eResponseMessageNotStopped) +- .Default(std::nullopt); +- if (message) +- M = *message; +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +---- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +-+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +-@@ -20,6 +20,7 @@ +- #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H +- #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H +- +-+#include "lldb/lldb-enumerations.h" +- #include "llvm/Support/JSON.h" +- #include +- #include +-@@ -64,15 +65,15 @@ +- llvm::json::Value toJSON(const Event &); +- bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); +- +--/// Response for a request. +--struct Response { +-- enum class Message { +-+FLAGS_ENUM(ResponseMessage){ +- /// The request was cancelled +-- cancelled, +-+ eResponseMessageCancelled, +- /// The request may be retried once the adapter is in a 'stopped' state +-- notStopped, +-- }; +-+ eResponseMessageNotStopped, +-+}; +- +-+/// Response for a request. +-+struct Response { +- /// Sequence number of the corresponding request. +- int64_t request_seq; +- +-@@ -90,7 +91,7 @@ +- /// Contains the raw error in short form if `success` is false. This raw error +- /// might be interpreted by the client and is not shown in the UI. Some +- /// predefined values exist. +-- std::optional> message; +-+ std::optional> message; +- +- /// Contains request result if success is true and error details if success is +- /// false. +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h +---- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h +-+++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h +-@@ -22,6 +22,8 @@ +- +- #include "Protocol/ProtocolBase.h" +- #include "Protocol/ProtocolTypes.h" +-+#include "lldb/lldb-enumerations.h" +-+#include "llvm/ADT/DenseSet.h" +- #include "llvm/Support/JSON.h" +- #include +- #include +-@@ -55,26 +57,26 @@ +- using DisconnectResponse = VoidResponse; +- +- /// Features supported by DAP clients. +--enum ClientFeature { +-- eClientFeatureVariableType, +-- eClientFeatureVariablePaging, +-- eClientFeatureRunInTerminalRequest, +-- eClientFeatureMemoryReferences, +-- eClientFeatureProgressReporting, +-- eClientFeatureInvalidatedEvent, +-- eClientFeatureMemoryEvent, +-- /// Client supports the `argsCanBeInterpretedByShell` attribute on the +-- /// `runInTerminal` request. +-- eClientFeatureArgsCanBeInterpretedByShell, +-- eClientFeatureStartDebuggingRequest, +-- /// The client will interpret ANSI escape sequences in the display of +-- /// `OutputEvent.output` and `Variable.value` fields when +-- /// `Capabilities.supportsANSIStyling` is also enabled. +-- eClientFeatureANSIStyling, +-+FLAGS_ENUM(ClientFeature){ +-+ eClientFeatureVariableType, +-+ eClientFeatureVariablePaging, +-+ eClientFeatureRunInTerminalRequest, +-+ eClientFeatureMemoryReferences, +-+ eClientFeatureProgressReporting, +-+ eClientFeatureInvalidatedEvent, +-+ eClientFeatureMemoryEvent, +-+ /// Client supports the `argsCanBeInterpretedByShell` attribute on the +-+ /// `runInTerminal` request. 
+-+ eClientFeatureArgsCanBeInterpretedByShell, +-+ eClientFeatureStartDebuggingRequest, +-+ /// The client will interpret ANSI escape sequences in the display of +-+ /// `OutputEvent.output` and `Variable.value` fields when +-+ /// `Capabilities.supportsANSIStyling` is also enabled. +-+ eClientFeatureANSIStyling, +- }; +- +- /// Format of paths reported by the debug adapter. +--enum PathFormat { ePatFormatPath, ePathFormatURI }; +-+FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; +- +- /// Arguments for `initialize` request. +- struct InitializeRequestArguments { +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +---- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +-+++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +-@@ -20,6 +20,7 @@ +- #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H +- #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H +- +-+#include "lldb/lldb-enumerations.h" +- #include "llvm/ADT/DenseSet.h" +- #include "llvm/Support/JSON.h" +- #include +-@@ -56,12 +57,8 @@ +- }; +- llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); +- +--enum ColumnType { +-- eColumnTypeString, +-- eColumnTypeNumber, +-- eColumnTypeBoolean, +-- eColumnTypeTimestamp +--}; +-+FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, +-+ eColumnTypeTimestamp}; +- +- /// A ColumnDescriptor specifies what module attribute to show in a column of +- /// the modules view, how to format it, and what the column’s label should be. +-@@ -90,27 +87,23 @@ +- +- /// Names of checksum algorithms that may be supported by a debug adapter. +- /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. +--enum ChecksumAlgorithm { +-- eChecksumAlgorithmMD5, +-- eChecksumAlgorithmSHA1, +-- eChecksumAlgorithmSHA256, +-- eChecksumAlgorithmTimestamp +--}; +-+FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, +-+ eChecksumAlgorithmSHA256, +-+ eChecksumAlgorithmTimestamp}; +- llvm::json::Value toJSON(const ChecksumAlgorithm &); +- +- /// Describes one or more type of breakpoint a BreakpointMode applies to. This +- /// is a non-exhaustive enumeration and may expand as future breakpoint types +- /// are added. +--enum BreakpointModeApplicability { +-- /// In `SourceBreakpoint`'s. +-- eBreakpointModeApplicabilitySource, +-- /// In exception breakpoints applied in the `ExceptionFilterOptions`. +-- eBreakpointModeApplicabilityException, +-- /// In data breakpoints requested in the `DataBreakpointInfo` request. +-- eBreakpointModeApplicabilityData, +-- /// In `InstructionBreakpoint`'s. +-- eBreakpointModeApplicabilityInstruction +--}; +-+FLAGS_ENUM(BreakpointModeApplicability){ +-+ /// In `SourceBreakpoint`'s. +-+ eBreakpointModeApplicabilitySource, +-+ /// In exception breakpoints applied in the `ExceptionFilterOptions`. +-+ eBreakpointModeApplicabilityException, +-+ /// In data breakpoints requested in the `DataBreakpointInfo` request. +-+ eBreakpointModeApplicabilityData, +-+ /// In `InstructionBreakpoint`'s. +-+ eBreakpointModeApplicabilityInstruction}; +- llvm::json::Value toJSON(const BreakpointModeApplicability &); +- +- /// A `BreakpointMode` is provided as a option when setting breakpoints on +-@@ -133,101 +126,101 @@ +- llvm::json::Value toJSON(const BreakpointMode &); +- +- /// Debug Adapter Features flags supported by lldb-dap. +--enum AdapterFeature { +-- /// The debug adapter supports ANSI escape sequences in styling of +-- /// `OutputEvent.output` and `Variable.value` fields. 
+-- eAdapterFeatureANSIStyling, +-- /// The debug adapter supports the `breakpointLocations` request. +-- eAdapterFeatureBreakpointLocationsRequest, +-- /// The debug adapter supports the `cancel` request. +-- eAdapterFeatureCancelRequest, +-- /// The debug adapter supports the `clipboard` context value in the +-- /// `evaluate` request. +-- eAdapterFeatureClipboardContext, +-- /// The debug adapter supports the `completions` request. +-- eAdapterFeatureCompletionsRequest, +-- /// The debug adapter supports conditional breakpoints. +-- eAdapterFeatureConditionalBreakpoints, +-- /// The debug adapter supports the `configurationDone` request. +-- eAdapterFeatureConfigurationDoneRequest, +-- /// The debug adapter supports the `asAddress` and `bytes` fields in the +-- /// `dataBreakpointInfo` request. +-- eAdapterFeatureDataBreakpointBytes, +-- /// The debug adapter supports data breakpoints. +-- eAdapterFeatureDataBreakpoints, +-- /// The debug adapter supports the delayed loading of parts of the stack, +-- /// which requires that both the `startFrame` and `levels` arguments and the +-- /// `totalFrames` result of the `stackTrace` request are supported. +-- eAdapterFeatureDelayedStackTraceLoading, +-- /// The debug adapter supports the `disassemble` request. +-- eAdapterFeatureDisassembleRequest, +-- /// The debug adapter supports a (side effect free) `evaluate` request for +-- /// data hovers. +-- eAdapterFeatureEvaluateForHovers, +-- /// The debug adapter supports `filterOptions` as an argument on the +-- /// `setExceptionBreakpoints` request. +-- eAdapterFeatureExceptionFilterOptions, +-- /// The debug adapter supports the `exceptionInfo` request. +-- eAdapterFeatureExceptionInfoRequest, +-- /// The debug adapter supports `exceptionOptions` on the +-- /// `setExceptionBreakpoints` request. +-- eAdapterFeatureExceptionOptions, +-- /// The debug adapter supports function breakpoints. +-- eAdapterFeatureFunctionBreakpoints, +-- /// The debug adapter supports the `gotoTargets` request. +-- eAdapterFeatureGotoTargetsRequest, +-- /// The debug adapter supports breakpoints that break execution after a +-- /// specified number of hits. +-- eAdapterFeatureHitConditionalBreakpoints, +-- /// The debug adapter supports adding breakpoints based on instruction +-- /// references. +-- eAdapterFeatureInstructionBreakpoints, +-- /// The debug adapter supports the `loadedSources` request. +-- eAdapterFeatureLoadedSourcesRequest, +-- /// The debug adapter supports log points by interpreting the `logMessage` +-- /// attribute of the `SourceBreakpoint`. +-- eAdapterFeatureLogPoints, +-- /// The debug adapter supports the `modules` request. +-- eAdapterFeatureModulesRequest, +-- /// The debug adapter supports the `readMemory` request. +-- eAdapterFeatureReadMemoryRequest, +-- /// The debug adapter supports restarting a frame. +-- eAdapterFeatureRestartFrame, +-- /// The debug adapter supports the `restart` request. In this case a client +-- /// should not implement `restart` by terminating and relaunching the +-- /// adapter but by calling the `restart` request. +-- eAdapterFeatureRestartRequest, +-- /// The debug adapter supports the `setExpression` request. +-- eAdapterFeatureSetExpression, +-- /// The debug adapter supports setting a variable to a value. +-- eAdapterFeatureSetVariable, +-- /// The debug adapter supports the `singleThread` property on the execution +-- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, +-- /// `stepBack`). 
+-- eAdapterFeatureSingleThreadExecutionRequests, +-- /// The debug adapter supports stepping back via the `stepBack` and +-- /// `reverseContinue` requests. +-- eAdapterFeatureStepBack, +-- /// The debug adapter supports the `stepInTargets` request. +-- eAdapterFeatureStepInTargetsRequest, +-- /// The debug adapter supports stepping granularities (argument +-- /// `granularity`) for the stepping requests. +-- eAdapterFeatureSteppingGranularity, +-- /// The debug adapter supports the `terminate` request. +-- eAdapterFeatureTerminateRequest, +-- /// The debug adapter supports the `terminateThreads` request. +-- eAdapterFeatureTerminateThreadsRequest, +-- /// The debug adapter supports the `suspendDebuggee` attribute on the +-- /// `disconnect` request. +-- eAdapterFeatureSuspendDebuggee, +-- /// The debug adapter supports a `format` attribute on the `stackTrace`, +-- /// `variables`, and `evaluate` requests. +-- eAdapterFeatureValueFormattingOptions, +-- /// The debug adapter supports the `writeMemory` request. +-- eAdapterFeatureWriteMemoryRequest, +-- /// The debug adapter supports the `terminateDebuggee` attribute on the +-- /// `disconnect` request. +-- eAdapterFeatureTerminateDebuggee, +-+FLAGS_ENUM(AdapterFeature){ +-+ /// The debug adapter supports ANSI escape sequences in styling of +-+ /// `OutputEvent.output` and `Variable.value` fields. +-+ eAdapterFeatureANSIStyling, +-+ /// The debug adapter supports the `breakpointLocations` request. +-+ eAdapterFeatureBreakpointLocationsRequest, +-+ /// The debug adapter supports the `cancel` request. +-+ eAdapterFeatureCancelRequest, +-+ /// The debug adapter supports the `clipboard` context value in the +-+ /// `evaluate` request. +-+ eAdapterFeatureClipboardContext, +-+ /// The debug adapter supports the `completions` request. +-+ eAdapterFeatureCompletionsRequest, +-+ /// The debug adapter supports conditional breakpoints. +-+ eAdapterFeatureConditionalBreakpoints, +-+ /// The debug adapter supports the `configurationDone` request. +-+ eAdapterFeatureConfigurationDoneRequest, +-+ /// The debug adapter supports the `asAddress` and `bytes` fields in the +-+ /// `dataBreakpointInfo` request. +-+ eAdapterFeatureDataBreakpointBytes, +-+ /// The debug adapter supports data breakpoints. +-+ eAdapterFeatureDataBreakpoints, +-+ /// The debug adapter supports the delayed loading of parts of the stack, +-+ /// which requires that both the `startFrame` and `levels` arguments and the +-+ /// `totalFrames` result of the `stackTrace` request are supported. +-+ eAdapterFeatureDelayedStackTraceLoading, +-+ /// The debug adapter supports the `disassemble` request. +-+ eAdapterFeatureDisassembleRequest, +-+ /// The debug adapter supports a (side effect free) `evaluate` request for +-+ /// data hovers. +-+ eAdapterFeatureEvaluateForHovers, +-+ /// The debug adapter supports `filterOptions` as an argument on the +-+ /// `setExceptionBreakpoints` request. +-+ eAdapterFeatureExceptionFilterOptions, +-+ /// The debug adapter supports the `exceptionInfo` request. +-+ eAdapterFeatureExceptionInfoRequest, +-+ /// The debug adapter supports `exceptionOptions` on the +-+ /// `setExceptionBreakpoints` request. +-+ eAdapterFeatureExceptionOptions, +-+ /// The debug adapter supports function breakpoints. +-+ eAdapterFeatureFunctionBreakpoints, +-+ /// The debug adapter supports the `gotoTargets` request. +-+ eAdapterFeatureGotoTargetsRequest, +-+ /// The debug adapter supports breakpoints that break execution after a +-+ /// specified number of hits. 
+-+ eAdapterFeatureHitConditionalBreakpoints, +-+ /// The debug adapter supports adding breakpoints based on instruction +-+ /// references. +-+ eAdapterFeatureInstructionBreakpoints, +-+ /// The debug adapter supports the `loadedSources` request. +-+ eAdapterFeatureLoadedSourcesRequest, +-+ /// The debug adapter supports log points by interpreting the `logMessage` +-+ /// attribute of the `SourceBreakpoint`. +-+ eAdapterFeatureLogPoints, +-+ /// The debug adapter supports the `modules` request. +-+ eAdapterFeatureModulesRequest, +-+ /// The debug adapter supports the `readMemory` request. +-+ eAdapterFeatureReadMemoryRequest, +-+ /// The debug adapter supports restarting a frame. +-+ eAdapterFeatureRestartFrame, +-+ /// The debug adapter supports the `restart` request. In this case a client +-+ /// should not implement `restart` by terminating and relaunching the +-+ /// adapter but by calling the `restart` request. +-+ eAdapterFeatureRestartRequest, +-+ /// The debug adapter supports the `setExpression` request. +-+ eAdapterFeatureSetExpression, +-+ /// The debug adapter supports setting a variable to a value. +-+ eAdapterFeatureSetVariable, +-+ /// The debug adapter supports the `singleThread` property on the execution +-+ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, +-+ /// `stepBack`). +-+ eAdapterFeatureSingleThreadExecutionRequests, +-+ /// The debug adapter supports stepping back via the `stepBack` and +-+ /// `reverseContinue` requests. +-+ eAdapterFeatureStepBack, +-+ /// The debug adapter supports the `stepInTargets` request. +-+ eAdapterFeatureStepInTargetsRequest, +-+ /// The debug adapter supports stepping granularities (argument +-+ /// `granularity`) for the stepping requests. +-+ eAdapterFeatureSteppingGranularity, +-+ /// The debug adapter supports the `terminate` request. +-+ eAdapterFeatureTerminateRequest, +-+ /// The debug adapter supports the `terminateThreads` request. +-+ eAdapterFeatureTerminateThreadsRequest, +-+ /// The debug adapter supports the `suspendDebuggee` attribute on the +-+ /// `disconnect` request. +-+ eAdapterFeatureSuspendDebuggee, +-+ /// The debug adapter supports a `format` attribute on the `stackTrace`, +-+ /// `variables`, and `evaluate` requests. +-+ eAdapterFeatureValueFormattingOptions, +-+ /// The debug adapter supports the `writeMemory` request. +-+ eAdapterFeatureWriteMemoryRequest, +-+ /// The debug adapter supports the `terminateDebuggee` attribute on the +-+ /// `disconnect` request. +-+ eAdapterFeatureTerminateDebuggee, +- }; +- +- /// Information about the capabilities of a debug adapter. +-@@ -268,10 +261,10 @@ +- }; +- llvm::json::Value toJSON(const Capabilities &); +- +--enum PresentationHint { +-- ePresentationHintNormal, +-- ePresentationHintEmphasize, +-- ePresentationHintDeemphasize, +-+FLAGS_ENUM(PresentationHint){ +-+ ePresentationHintNormal, +-+ ePresentationHintEmphasize, +-+ ePresentationHintDeemphasize, +- }; +- +- /// A `Source` is a descriptor for source code. 
It is returned from the debug +-diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +---- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +-+++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +-@@ -1,7 +1,7 @@ +- // Header +- // +- // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) +--// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) +-+// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) +- // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) +- // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) +- // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) +-diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c +---- a/offload/test/offloading/gpupgo/pgo1.c +-+++ b/offload/test/offloading/gpupgo/pgo1.c +-@@ -14,7 +14,7 @@ +- // RUN: %target_triple.%basename_t.clang.profraw | \ +- // RUN: %fcheck-generic --check-prefix="CLANG-PGO" +- +--// REQUIRES: gpu +-+// REQUIRES: amdgpu +- // REQUIRES: pgo +- +- int test1(int a) { return a / 2; } +-diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c +---- a/offload/test/offloading/gpupgo/pgo2.c +-+++ b/offload/test/offloading/gpupgo/pgo2.c +-@@ -48,7 +48,7 @@ +- // RUN: %target_triple.%basename_t.hfdi.profraw \ +- // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" +- +--// REQUIRES: gpu +-+// REQUIRES: amdgpu +- // REQUIRES: pgo +- +- int main() { diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 725480b..005737a 100644 +index 005737a..fd9baec 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "71a977d0d611f3e9f6137a6b8a26b730b2886ce9" -- LLVM_SHA256 = "9bdf3ddf45c069248af36080a78b56d839d3aad6f9b727ec1ee1be72682888cc" -+ LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" -+ LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" +- LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" +- LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" ++ LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" ++ LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 7b1a0496a0965c..f93f7a93cea2e8 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "9435b34df0279d473240f5bcc2a829d0589ae372" - SHARDY_SHA256 = "5f2a037d3301a1407e5c94778dd56d855f5abe26999cce448ccfa1923cf9559f" + SHARDY_COMMIT = "f25a97f80402e43de93f46931c6dddc485e8dad0" + SHARDY_SHA256 = "c7e55d3902175c064d3dd6bffc856c5e0198ff6c8f1410c3a97ed2c8e85ddb30" tf_http_archive( name = "shardy", diff --git a/third_party/triton/llvm_integration/cl742325920.patch b/third_party/triton/llvm_integration/cl742325920.patch new file mode 100644 index 00000000000000..3a391a40c30b1a --- /dev/null +++ b/third_party/triton/llvm_integration/cl742325920.patch @@ -0,0 +1,58 @@ + +--- a/third_party/amd/include/Analysis/RangeAnalysis.h 2025-03-25 
07:48:50.000000000 -0700 ++++ b/third_party/amd/include/Analysis/RangeAnalysis.h 2025-03-31 11:20:15.000000000 -0700 +@@ -118,8 +118,11 @@ + + std::optional> + collectRanges(const DataFlowSolver &solver, ValueRange values); ++ + bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp); + ++bool isEmptyInitializedRange(ConstantIntRanges rv); ++ + } // namespace mlir::triton::AMD + + #endif + +--- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp 2025-03-25 07:48:50.000000000 -0700 ++++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp 2025-03-31 11:20:15.000000000 -0700 +@@ -186,6 +186,13 @@ + + namespace mlir::triton::AMD { + ++bool isEmptyInitializedRange(ConstantIntRanges rv) { ++ if (!rv.umin().getBitWidth() || !rv.umax().getBitWidth() || ++ !rv.smin().getBitWidth() || !rv.smax().getBitWidth()) ++ return true; ++ return false; ++} ++ + std::optional> + collectRanges(const DataFlowSolver &solver, ValueRange values) { + SmallVector ranges; +@@ -196,6 +203,8 @@ + return {}; + const ConstantIntRanges &inferredRange = + maybeInferredRange->getValue().getValue(); ++ if (isEmptyInitializedRange(inferredRange)) ++ return {}; + ranges.push_back(inferredRange); + } + return ranges; + +--- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp 2025-03-25 07:48:50.000000000 -0700 ++++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp 2025-03-31 11:20:16.000000000 -0700 +@@ -34,6 +34,13 @@ + return signalPassFailure(); + + auto nonNegativePred = [&solver](Value v) -> bool { ++ if (const auto *r = ++ solver->lookupState(v)) { ++ if (r->getValue().isUninitialized()) ++ return false; ++ if (AMD::isEmptyInitializedRange(r->getValue().getValue())) ++ return false; ++ } + return succeeded(dataflow::staticallyNonNegative(*solver, v)); + }; + mod->walk([&solver, nonNegativePred](Operation *op) { diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index b82b9f6b87bb1c..64e38dfaa30e75 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -10,5 +10,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. llvm_patch_list = [ "//third_party/triton:llvm_integration/cl740926882.patch", "//third_party/triton:llvm_integration/cl741558316.patch", + "//third_party/triton:llvm_integration/cl742325920.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 5a732df12a541f..1cc7fe6b95c33c 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,566 +1,566 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 509398d..99ef3cb 100644 +index 99ef3cb..509398d 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1 +1,546 @@ +@@ -1,546 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-+diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp -+--- a/clang/lib/Driver/ToolChains/Clang.cpp -++++ b/clang/lib/Driver/ToolChains/Clang.cpp -+@@ -6397,7 +6397,9 @@ -+ Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, -+ options::OPT_fno_convergent_functions); -+ -+- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); -++ // NVPTX doesn't support PGO or coverage -++ if (!Triple.isNVPTX()) -++ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); -+ -+ Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); -+ -+diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu -+--- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu -++++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu -+@@ -0,0 +1,33 @@ -++// Check that profiling/coverage arguments doen't get passed down to device-side -++// compilation. -++// -++// -++// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// XRUN: -fprofile-generate %s 2>&1 | \ -++// XRUN: FileCheck --check-prefixes=CHECK,PROF %s -++// -++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// RUN: -fprofile-instr-generate %s 2>&1 | \ -++// RUN: FileCheck --check-prefixes=CHECK,PROF %s -++// -++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// RUN: -coverage %s 2>&1 | \ -++// RUN: FileCheck --check-prefixes=CHECK,GCOV %s -++// -++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// RUN: -ftest-coverage %s 2>&1 | \ -++// RUN: FileCheck --check-prefixes=CHECK,GCOV %s -++// -++// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ -++// RUN: -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \ -++// RUN: FileCheck --check-prefixes=CHECK,PROF %s -++// -++// -++// CHECK-NOT: error: unsupported option '-fprofile -++// CHECK-NOT: error: invalid argument -++// CHECK-DAG: "-fcuda-is-device" -++// CHECK-NOT: "-f{{[^"/]*coverage.*}}" -++// CHECK-NOT: "-fprofile{{[^"]*}}" -++// CHECK: "-triple" "x86_64-unknown-linux-gnu" -++// PROF: "-fprofile{{.*}}" -++// GCOV: "-coverage-notes-file= -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp -+--- a/lldb/tools/lldb-dap/DAP.cpp -++++ b/lldb/tools/lldb-dap/DAP.cpp -+@@ -711,12 +711,12 @@ -+ [](const std::string &message) -> llvm::StringRef { -+ return message; -+ }, -+- [](const protocol::Response::Message &message) -++ [](const protocol::ResponseMessage &message) -+ -> llvm::StringRef { -+ switch (message) { -+- case protocol::Response::Message::cancelled: -++ case protocol::eResponseMessageCancelled: -+ return "cancelled"; -+- case protocol::Response::Message::notStopped: -++ case protocol::eResponseMessageNotStopped: -+ return "notStopped"; -+ } -+ }), -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp -+--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp -++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp -+@@ -7,6 +7,7 @@ -+ //===----------------------------------------------------------------------===// -+ -+ #include "Protocol/ProtocolBase.h" -++#include "lldb/lldb-enumerations.h" -+ #include "llvm/ADT/StringRef.h" -+ #include "llvm/ADT/StringSwitch.h" -+ #include "llvm/Support/ErrorHandling.h" -+@@ -31,11 +32,8 @@ -+ -+ namespace lldb_dap::protocol { -+ -+-enum MessageType { -+- 
eMessageTypeRequest, -+- eMessageTypeResponse, -+- eMessageTypeEvent -+-}; -++FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse, -++ eMessageTypeEvent}; -+ -+ bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) { -+ auto rawType = Params.getAsString(); -+@@ -107,12 +105,12 @@ -+ -+ if (R.message) { -+ assert(!R.success && "message can only be used if success is false"); -+- if (const auto *messageEnum = std::get_if(&*R.message)) { -++ if (const auto *messageEnum = std::get_if(&*R.message)) { -+ switch (*messageEnum) { -+- case Response::Message::cancelled: -++ case eResponseMessageCancelled: -+ Result.insert({"message", "cancelled"}); -+ break; -+- case Response::Message::notStopped: -++ case eResponseMessageNotStopped: -+ Result.insert({"message", "notStopped"}); -+ break; -+ } -+@@ -129,16 +127,16 @@ -+ } -+ -+ bool fromJSON(json::Value const &Params, -+- std::variant &M, json::Path P) { -++ std::variant &M, json::Path P) { -+ auto rawMessage = Params.getAsString(); -+ if (!rawMessage) { -+ P.report("expected a string"); -+ return false; -+ } -+- std::optional message = -+- StringSwitch>(*rawMessage) -+- .Case("cancelled", Response::Message::cancelled) -+- .Case("notStopped", Response::Message::notStopped) -++ std::optional message = -++ StringSwitch>(*rawMessage) -++ .Case("cancelled", eResponseMessageCancelled) -++ .Case("notStopped", eResponseMessageNotStopped) -+ .Default(std::nullopt); -+ if (message) -+ M = *message; -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h -+--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h -++++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h -+@@ -20,6 +20,7 @@ -+ #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H -+ #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H -+ -++#include "lldb/lldb-enumerations.h" -+ #include "llvm/Support/JSON.h" -+ #include -+ #include -+@@ -64,15 +65,15 @@ -+ llvm::json::Value toJSON(const Event &); -+ bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); -+ -+-/// Response for a request. -+-struct Response { -+- enum class Message { -++FLAGS_ENUM(ResponseMessage){ -+ /// The request was cancelled -+- cancelled, -++ eResponseMessageCancelled, -+ /// The request may be retried once the adapter is in a 'stopped' state -+- notStopped, -+- }; -++ eResponseMessageNotStopped, -++}; -+ -++/// Response for a request. -++struct Response { -+ /// Sequence number of the corresponding request. -+ int64_t request_seq; -+ -+@@ -90,7 +91,7 @@ -+ /// Contains the raw error in short form if `success` is false. This raw error -+ /// might be interpreted by the client and is not shown in the UI. Some -+ /// predefined values exist. -+- std::optional> message; -++ std::optional> message; -+ -+ /// Contains request result if success is true and error details if success is -+ /// false. -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h -+--- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h -++++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h -+@@ -22,6 +22,8 @@ -+ -+ #include "Protocol/ProtocolBase.h" -+ #include "Protocol/ProtocolTypes.h" -++#include "lldb/lldb-enumerations.h" -++#include "llvm/ADT/DenseSet.h" -+ #include "llvm/Support/JSON.h" -+ #include -+ #include -+@@ -55,26 +57,26 @@ -+ using DisconnectResponse = VoidResponse; -+ -+ /// Features supported by DAP clients. 
-+-enum ClientFeature { -+- eClientFeatureVariableType, -+- eClientFeatureVariablePaging, -+- eClientFeatureRunInTerminalRequest, -+- eClientFeatureMemoryReferences, -+- eClientFeatureProgressReporting, -+- eClientFeatureInvalidatedEvent, -+- eClientFeatureMemoryEvent, -+- /// Client supports the `argsCanBeInterpretedByShell` attribute on the -+- /// `runInTerminal` request. -+- eClientFeatureArgsCanBeInterpretedByShell, -+- eClientFeatureStartDebuggingRequest, -+- /// The client will interpret ANSI escape sequences in the display of -+- /// `OutputEvent.output` and `Variable.value` fields when -+- /// `Capabilities.supportsANSIStyling` is also enabled. -+- eClientFeatureANSIStyling, -++FLAGS_ENUM(ClientFeature){ -++ eClientFeatureVariableType, -++ eClientFeatureVariablePaging, -++ eClientFeatureRunInTerminalRequest, -++ eClientFeatureMemoryReferences, -++ eClientFeatureProgressReporting, -++ eClientFeatureInvalidatedEvent, -++ eClientFeatureMemoryEvent, -++ /// Client supports the `argsCanBeInterpretedByShell` attribute on the -++ /// `runInTerminal` request. -++ eClientFeatureArgsCanBeInterpretedByShell, -++ eClientFeatureStartDebuggingRequest, -++ /// The client will interpret ANSI escape sequences in the display of -++ /// `OutputEvent.output` and `Variable.value` fields when -++ /// `Capabilities.supportsANSIStyling` is also enabled. -++ eClientFeatureANSIStyling, -+ }; -+ -+ /// Format of paths reported by the debug adapter. -+-enum PathFormat { ePatFormatPath, ePathFormatURI }; -++FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; -+ -+ /// Arguments for `initialize` request. -+ struct InitializeRequestArguments { -+diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h -+--- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h -++++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h -+@@ -20,6 +20,7 @@ -+ #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H -+ #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H -+ -++#include "lldb/lldb-enumerations.h" -+ #include "llvm/ADT/DenseSet.h" -+ #include "llvm/Support/JSON.h" -+ #include -+@@ -56,12 +57,8 @@ -+ }; -+ llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); -+ -+-enum ColumnType { -+- eColumnTypeString, -+- eColumnTypeNumber, -+- eColumnTypeBoolean, -+- eColumnTypeTimestamp -+-}; -++FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, -++ eColumnTypeTimestamp}; -+ -+ /// A ColumnDescriptor specifies what module attribute to show in a column of -+ /// the modules view, how to format it, and what the column’s label should be. -+@@ -90,27 +87,23 @@ -+ -+ /// Names of checksum algorithms that may be supported by a debug adapter. -+ /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. -+-enum ChecksumAlgorithm { -+- eChecksumAlgorithmMD5, -+- eChecksumAlgorithmSHA1, -+- eChecksumAlgorithmSHA256, -+- eChecksumAlgorithmTimestamp -+-}; -++FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, -++ eChecksumAlgorithmSHA256, -++ eChecksumAlgorithmTimestamp}; -+ llvm::json::Value toJSON(const ChecksumAlgorithm &); -+ -+ /// Describes one or more type of breakpoint a BreakpointMode applies to. This -+ /// is a non-exhaustive enumeration and may expand as future breakpoint types -+ /// are added. -+-enum BreakpointModeApplicability { -+- /// In `SourceBreakpoint`'s. -+- eBreakpointModeApplicabilitySource, -+- /// In exception breakpoints applied in the `ExceptionFilterOptions`. 
-+- eBreakpointModeApplicabilityException, -+- /// In data breakpoints requested in the `DataBreakpointInfo` request. -+- eBreakpointModeApplicabilityData, -+- /// In `InstructionBreakpoint`'s. -+- eBreakpointModeApplicabilityInstruction -+-}; -++FLAGS_ENUM(BreakpointModeApplicability){ -++ /// In `SourceBreakpoint`'s. -++ eBreakpointModeApplicabilitySource, -++ /// In exception breakpoints applied in the `ExceptionFilterOptions`. -++ eBreakpointModeApplicabilityException, -++ /// In data breakpoints requested in the `DataBreakpointInfo` request. -++ eBreakpointModeApplicabilityData, -++ /// In `InstructionBreakpoint`'s. -++ eBreakpointModeApplicabilityInstruction}; -+ llvm::json::Value toJSON(const BreakpointModeApplicability &); -+ -+ /// A `BreakpointMode` is provided as a option when setting breakpoints on -+@@ -133,101 +126,101 @@ -+ llvm::json::Value toJSON(const BreakpointMode &); -+ -+ /// Debug Adapter Features flags supported by lldb-dap. -+-enum AdapterFeature { -+- /// The debug adapter supports ANSI escape sequences in styling of -+- /// `OutputEvent.output` and `Variable.value` fields. -+- eAdapterFeatureANSIStyling, -+- /// The debug adapter supports the `breakpointLocations` request. -+- eAdapterFeatureBreakpointLocationsRequest, -+- /// The debug adapter supports the `cancel` request. -+- eAdapterFeatureCancelRequest, -+- /// The debug adapter supports the `clipboard` context value in the -+- /// `evaluate` request. -+- eAdapterFeatureClipboardContext, -+- /// The debug adapter supports the `completions` request. -+- eAdapterFeatureCompletionsRequest, -+- /// The debug adapter supports conditional breakpoints. -+- eAdapterFeatureConditionalBreakpoints, -+- /// The debug adapter supports the `configurationDone` request. -+- eAdapterFeatureConfigurationDoneRequest, -+- /// The debug adapter supports the `asAddress` and `bytes` fields in the -+- /// `dataBreakpointInfo` request. -+- eAdapterFeatureDataBreakpointBytes, -+- /// The debug adapter supports data breakpoints. -+- eAdapterFeatureDataBreakpoints, -+- /// The debug adapter supports the delayed loading of parts of the stack, -+- /// which requires that both the `startFrame` and `levels` arguments and the -+- /// `totalFrames` result of the `stackTrace` request are supported. -+- eAdapterFeatureDelayedStackTraceLoading, -+- /// The debug adapter supports the `disassemble` request. -+- eAdapterFeatureDisassembleRequest, -+- /// The debug adapter supports a (side effect free) `evaluate` request for -+- /// data hovers. -+- eAdapterFeatureEvaluateForHovers, -+- /// The debug adapter supports `filterOptions` as an argument on the -+- /// `setExceptionBreakpoints` request. -+- eAdapterFeatureExceptionFilterOptions, -+- /// The debug adapter supports the `exceptionInfo` request. -+- eAdapterFeatureExceptionInfoRequest, -+- /// The debug adapter supports `exceptionOptions` on the -+- /// `setExceptionBreakpoints` request. -+- eAdapterFeatureExceptionOptions, -+- /// The debug adapter supports function breakpoints. -+- eAdapterFeatureFunctionBreakpoints, -+- /// The debug adapter supports the `gotoTargets` request. -+- eAdapterFeatureGotoTargetsRequest, -+- /// The debug adapter supports breakpoints that break execution after a -+- /// specified number of hits. -+- eAdapterFeatureHitConditionalBreakpoints, -+- /// The debug adapter supports adding breakpoints based on instruction -+- /// references. -+- eAdapterFeatureInstructionBreakpoints, -+- /// The debug adapter supports the `loadedSources` request. 
-+- eAdapterFeatureLoadedSourcesRequest, -+- /// The debug adapter supports log points by interpreting the `logMessage` -+- /// attribute of the `SourceBreakpoint`. -+- eAdapterFeatureLogPoints, -+- /// The debug adapter supports the `modules` request. -+- eAdapterFeatureModulesRequest, -+- /// The debug adapter supports the `readMemory` request. -+- eAdapterFeatureReadMemoryRequest, -+- /// The debug adapter supports restarting a frame. -+- eAdapterFeatureRestartFrame, -+- /// The debug adapter supports the `restart` request. In this case a client -+- /// should not implement `restart` by terminating and relaunching the -+- /// adapter but by calling the `restart` request. -+- eAdapterFeatureRestartRequest, -+- /// The debug adapter supports the `setExpression` request. -+- eAdapterFeatureSetExpression, -+- /// The debug adapter supports setting a variable to a value. -+- eAdapterFeatureSetVariable, -+- /// The debug adapter supports the `singleThread` property on the execution -+- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, -+- /// `stepBack`). -+- eAdapterFeatureSingleThreadExecutionRequests, -+- /// The debug adapter supports stepping back via the `stepBack` and -+- /// `reverseContinue` requests. -+- eAdapterFeatureStepBack, -+- /// The debug adapter supports the `stepInTargets` request. -+- eAdapterFeatureStepInTargetsRequest, -+- /// The debug adapter supports stepping granularities (argument -+- /// `granularity`) for the stepping requests. -+- eAdapterFeatureSteppingGranularity, -+- /// The debug adapter supports the `terminate` request. -+- eAdapterFeatureTerminateRequest, -+- /// The debug adapter supports the `terminateThreads` request. -+- eAdapterFeatureTerminateThreadsRequest, -+- /// The debug adapter supports the `suspendDebuggee` attribute on the -+- /// `disconnect` request. -+- eAdapterFeatureSuspendDebuggee, -+- /// The debug adapter supports a `format` attribute on the `stackTrace`, -+- /// `variables`, and `evaluate` requests. -+- eAdapterFeatureValueFormattingOptions, -+- /// The debug adapter supports the `writeMemory` request. -+- eAdapterFeatureWriteMemoryRequest, -+- /// The debug adapter supports the `terminateDebuggee` attribute on the -+- /// `disconnect` request. -+- eAdapterFeatureTerminateDebuggee, -++FLAGS_ENUM(AdapterFeature){ -++ /// The debug adapter supports ANSI escape sequences in styling of -++ /// `OutputEvent.output` and `Variable.value` fields. -++ eAdapterFeatureANSIStyling, -++ /// The debug adapter supports the `breakpointLocations` request. -++ eAdapterFeatureBreakpointLocationsRequest, -++ /// The debug adapter supports the `cancel` request. -++ eAdapterFeatureCancelRequest, -++ /// The debug adapter supports the `clipboard` context value in the -++ /// `evaluate` request. -++ eAdapterFeatureClipboardContext, -++ /// The debug adapter supports the `completions` request. -++ eAdapterFeatureCompletionsRequest, -++ /// The debug adapter supports conditional breakpoints. -++ eAdapterFeatureConditionalBreakpoints, -++ /// The debug adapter supports the `configurationDone` request. -++ eAdapterFeatureConfigurationDoneRequest, -++ /// The debug adapter supports the `asAddress` and `bytes` fields in the -++ /// `dataBreakpointInfo` request. -++ eAdapterFeatureDataBreakpointBytes, -++ /// The debug adapter supports data breakpoints. 
-++ eAdapterFeatureDataBreakpoints, -++ /// The debug adapter supports the delayed loading of parts of the stack, -++ /// which requires that both the `startFrame` and `levels` arguments and the -++ /// `totalFrames` result of the `stackTrace` request are supported. -++ eAdapterFeatureDelayedStackTraceLoading, -++ /// The debug adapter supports the `disassemble` request. -++ eAdapterFeatureDisassembleRequest, -++ /// The debug adapter supports a (side effect free) `evaluate` request for -++ /// data hovers. -++ eAdapterFeatureEvaluateForHovers, -++ /// The debug adapter supports `filterOptions` as an argument on the -++ /// `setExceptionBreakpoints` request. -++ eAdapterFeatureExceptionFilterOptions, -++ /// The debug adapter supports the `exceptionInfo` request. -++ eAdapterFeatureExceptionInfoRequest, -++ /// The debug adapter supports `exceptionOptions` on the -++ /// `setExceptionBreakpoints` request. -++ eAdapterFeatureExceptionOptions, -++ /// The debug adapter supports function breakpoints. -++ eAdapterFeatureFunctionBreakpoints, -++ /// The debug adapter supports the `gotoTargets` request. -++ eAdapterFeatureGotoTargetsRequest, -++ /// The debug adapter supports breakpoints that break execution after a -++ /// specified number of hits. -++ eAdapterFeatureHitConditionalBreakpoints, -++ /// The debug adapter supports adding breakpoints based on instruction -++ /// references. -++ eAdapterFeatureInstructionBreakpoints, -++ /// The debug adapter supports the `loadedSources` request. -++ eAdapterFeatureLoadedSourcesRequest, -++ /// The debug adapter supports log points by interpreting the `logMessage` -++ /// attribute of the `SourceBreakpoint`. -++ eAdapterFeatureLogPoints, -++ /// The debug adapter supports the `modules` request. -++ eAdapterFeatureModulesRequest, -++ /// The debug adapter supports the `readMemory` request. -++ eAdapterFeatureReadMemoryRequest, -++ /// The debug adapter supports restarting a frame. -++ eAdapterFeatureRestartFrame, -++ /// The debug adapter supports the `restart` request. In this case a client -++ /// should not implement `restart` by terminating and relaunching the -++ /// adapter but by calling the `restart` request. -++ eAdapterFeatureRestartRequest, -++ /// The debug adapter supports the `setExpression` request. -++ eAdapterFeatureSetExpression, -++ /// The debug adapter supports setting a variable to a value. -++ eAdapterFeatureSetVariable, -++ /// The debug adapter supports the `singleThread` property on the execution -++ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, -++ /// `stepBack`). -++ eAdapterFeatureSingleThreadExecutionRequests, -++ /// The debug adapter supports stepping back via the `stepBack` and -++ /// `reverseContinue` requests. -++ eAdapterFeatureStepBack, -++ /// The debug adapter supports the `stepInTargets` request. -++ eAdapterFeatureStepInTargetsRequest, -++ /// The debug adapter supports stepping granularities (argument -++ /// `granularity`) for the stepping requests. -++ eAdapterFeatureSteppingGranularity, -++ /// The debug adapter supports the `terminate` request. -++ eAdapterFeatureTerminateRequest, -++ /// The debug adapter supports the `terminateThreads` request. -++ eAdapterFeatureTerminateThreadsRequest, -++ /// The debug adapter supports the `suspendDebuggee` attribute on the -++ /// `disconnect` request. -++ eAdapterFeatureSuspendDebuggee, -++ /// The debug adapter supports a `format` attribute on the `stackTrace`, -++ /// `variables`, and `evaluate` requests. 
-++ eAdapterFeatureValueFormattingOptions, -++ /// The debug adapter supports the `writeMemory` request. -++ eAdapterFeatureWriteMemoryRequest, -++ /// The debug adapter supports the `terminateDebuggee` attribute on the -++ /// `disconnect` request. -++ eAdapterFeatureTerminateDebuggee, -+ }; -+ -+ /// Information about the capabilities of a debug adapter. -+@@ -268,10 +261,10 @@ -+ }; -+ llvm::json::Value toJSON(const Capabilities &); -+ -+-enum PresentationHint { -+- ePresentationHintNormal, -+- ePresentationHintEmphasize, -+- ePresentationHintDeemphasize, -++FLAGS_ENUM(PresentationHint){ -++ ePresentationHintNormal, -++ ePresentationHintEmphasize, -++ ePresentationHintDeemphasize, -+ }; -+ -+ /// A `Source` is a descriptor for source code. It is returned from the debug -+diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test -+--- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test -++++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test -+@@ -1,7 +1,7 @@ -+ // Header -+ // -+ // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) -+-// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) -++// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) -+ // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) -+ // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) -+ // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) -+diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c -+--- a/offload/test/offloading/gpupgo/pgo1.c -++++ b/offload/test/offloading/gpupgo/pgo1.c -+@@ -14,7 +14,7 @@ -+ // RUN: %target_triple.%basename_t.clang.profraw | \ -+ // RUN: %fcheck-generic --check-prefix="CLANG-PGO" -+ -+-// REQUIRES: gpu -++// REQUIRES: amdgpu -+ // REQUIRES: pgo -+ -+ int test1(int a) { return a / 2; } -+diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c -+--- a/offload/test/offloading/gpupgo/pgo2.c -++++ b/offload/test/offloading/gpupgo/pgo2.c -+@@ -48,7 +48,7 @@ -+ // RUN: %target_triple.%basename_t.hfdi.profraw \ -+ // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" -+ -+-// REQUIRES: gpu -++// REQUIRES: amdgpu -+ // REQUIRES: pgo -+ -+ int main() { +-diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp +---- a/clang/lib/Driver/ToolChains/Clang.cpp +-+++ b/clang/lib/Driver/ToolChains/Clang.cpp +-@@ -6397,7 +6397,9 @@ +- Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, +- options::OPT_fno_convergent_functions); +- +-- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); +-+ // NVPTX doesn't support PGO or coverage +-+ if (!Triple.isNVPTX()) +-+ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); +- +- Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); +- +-diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu +---- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu +-+++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu +-@@ -0,0 +1,33 @@ +-+// Check that profiling/coverage arguments doen't get passed down to device-side +-+// compilation. 
+-+// +-+// +-+// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// XRUN: -fprofile-generate %s 2>&1 | \ +-+// XRUN: FileCheck --check-prefixes=CHECK,PROF %s +-+// +-+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// RUN: -fprofile-instr-generate %s 2>&1 | \ +-+// RUN: FileCheck --check-prefixes=CHECK,PROF %s +-+// +-+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// RUN: -coverage %s 2>&1 | \ +-+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s +-+// +-+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// RUN: -ftest-coverage %s 2>&1 | \ +-+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s +-+// +-+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ +-+// RUN: -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \ +-+// RUN: FileCheck --check-prefixes=CHECK,PROF %s +-+// +-+// +-+// CHECK-NOT: error: unsupported option '-fprofile +-+// CHECK-NOT: error: invalid argument +-+// CHECK-DAG: "-fcuda-is-device" +-+// CHECK-NOT: "-f{{[^"/]*coverage.*}}" +-+// CHECK-NOT: "-fprofile{{[^"]*}}" +-+// CHECK: "-triple" "x86_64-unknown-linux-gnu" +-+// PROF: "-fprofile{{.*}}" +-+// GCOV: "-coverage-notes-file= +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp +---- a/lldb/tools/lldb-dap/DAP.cpp +-+++ b/lldb/tools/lldb-dap/DAP.cpp +-@@ -711,12 +711,12 @@ +- [](const std::string &message) -> llvm::StringRef { +- return message; +- }, +-- [](const protocol::Response::Message &message) +-+ [](const protocol::ResponseMessage &message) +- -> llvm::StringRef { +- switch (message) { +-- case protocol::Response::Message::cancelled: +-+ case protocol::eResponseMessageCancelled: +- return "cancelled"; +-- case protocol::Response::Message::notStopped: +-+ case protocol::eResponseMessageNotStopped: +- return "notStopped"; +- } +- }), +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +---- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +-+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +-@@ -7,6 +7,7 @@ +- //===----------------------------------------------------------------------===// +- +- #include "Protocol/ProtocolBase.h" +-+#include "lldb/lldb-enumerations.h" +- #include "llvm/ADT/StringRef.h" +- #include "llvm/ADT/StringSwitch.h" +- #include "llvm/Support/ErrorHandling.h" +-@@ -31,11 +32,8 @@ +- +- namespace lldb_dap::protocol { +- +--enum MessageType { +-- eMessageTypeRequest, +-- eMessageTypeResponse, +-- eMessageTypeEvent +--}; +-+FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse, +-+ eMessageTypeEvent}; +- +- bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) { +- auto rawType = Params.getAsString(); +-@@ -107,12 +105,12 @@ +- +- if (R.message) { +- assert(!R.success && "message can only be used if success is false"); +-- if (const auto *messageEnum = std::get_if(&*R.message)) { +-+ if (const auto *messageEnum = std::get_if(&*R.message)) { +- switch (*messageEnum) { +-- case Response::Message::cancelled: +-+ case eResponseMessageCancelled: +- Result.insert({"message", "cancelled"}); +- break; +-- case Response::Message::notStopped: +-+ case eResponseMessageNotStopped: +- Result.insert({"message", "notStopped"}); +- break; +- } +-@@ -129,16 +127,16 @@ +- } +- +- bool fromJSON(json::Value const &Params, +-- std::variant &M, json::Path P) { +-+ std::variant &M, json::Path P) { +- auto rawMessage = 
Params.getAsString(); +- if (!rawMessage) { +- P.report("expected a string"); +- return false; +- } +-- std::optional message = +-- StringSwitch>(*rawMessage) +-- .Case("cancelled", Response::Message::cancelled) +-- .Case("notStopped", Response::Message::notStopped) +-+ std::optional message = +-+ StringSwitch>(*rawMessage) +-+ .Case("cancelled", eResponseMessageCancelled) +-+ .Case("notStopped", eResponseMessageNotStopped) +- .Default(std::nullopt); +- if (message) +- M = *message; +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +---- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +-+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +-@@ -20,6 +20,7 @@ +- #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H +- #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H +- +-+#include "lldb/lldb-enumerations.h" +- #include "llvm/Support/JSON.h" +- #include +- #include +-@@ -64,15 +65,15 @@ +- llvm::json::Value toJSON(const Event &); +- bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); +- +--/// Response for a request. +--struct Response { +-- enum class Message { +-+FLAGS_ENUM(ResponseMessage){ +- /// The request was cancelled +-- cancelled, +-+ eResponseMessageCancelled, +- /// The request may be retried once the adapter is in a 'stopped' state +-- notStopped, +-- }; +-+ eResponseMessageNotStopped, +-+}; +- +-+/// Response for a request. +-+struct Response { +- /// Sequence number of the corresponding request. +- int64_t request_seq; +- +-@@ -90,7 +91,7 @@ +- /// Contains the raw error in short form if `success` is false. This raw error +- /// might be interpreted by the client and is not shown in the UI. Some +- /// predefined values exist. +-- std::optional> message; +-+ std::optional> message; +- +- /// Contains request result if success is true and error details if success is +- /// false. +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h +---- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h +-+++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h +-@@ -22,6 +22,8 @@ +- +- #include "Protocol/ProtocolBase.h" +- #include "Protocol/ProtocolTypes.h" +-+#include "lldb/lldb-enumerations.h" +-+#include "llvm/ADT/DenseSet.h" +- #include "llvm/Support/JSON.h" +- #include +- #include +-@@ -55,26 +57,26 @@ +- using DisconnectResponse = VoidResponse; +- +- /// Features supported by DAP clients. +--enum ClientFeature { +-- eClientFeatureVariableType, +-- eClientFeatureVariablePaging, +-- eClientFeatureRunInTerminalRequest, +-- eClientFeatureMemoryReferences, +-- eClientFeatureProgressReporting, +-- eClientFeatureInvalidatedEvent, +-- eClientFeatureMemoryEvent, +-- /// Client supports the `argsCanBeInterpretedByShell` attribute on the +-- /// `runInTerminal` request. +-- eClientFeatureArgsCanBeInterpretedByShell, +-- eClientFeatureStartDebuggingRequest, +-- /// The client will interpret ANSI escape sequences in the display of +-- /// `OutputEvent.output` and `Variable.value` fields when +-- /// `Capabilities.supportsANSIStyling` is also enabled. +-- eClientFeatureANSIStyling, +-+FLAGS_ENUM(ClientFeature){ +-+ eClientFeatureVariableType, +-+ eClientFeatureVariablePaging, +-+ eClientFeatureRunInTerminalRequest, +-+ eClientFeatureMemoryReferences, +-+ eClientFeatureProgressReporting, +-+ eClientFeatureInvalidatedEvent, +-+ eClientFeatureMemoryEvent, +-+ /// Client supports the `argsCanBeInterpretedByShell` attribute on the +-+ /// `runInTerminal` request. 
+-+ eClientFeatureArgsCanBeInterpretedByShell, +-+ eClientFeatureStartDebuggingRequest, +-+ /// The client will interpret ANSI escape sequences in the display of +-+ /// `OutputEvent.output` and `Variable.value` fields when +-+ /// `Capabilities.supportsANSIStyling` is also enabled. +-+ eClientFeatureANSIStyling, +- }; +- +- /// Format of paths reported by the debug adapter. +--enum PathFormat { ePatFormatPath, ePathFormatURI }; +-+FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; +- +- /// Arguments for `initialize` request. +- struct InitializeRequestArguments { +-diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +---- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +-+++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +-@@ -20,6 +20,7 @@ +- #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H +- #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H +- +-+#include "lldb/lldb-enumerations.h" +- #include "llvm/ADT/DenseSet.h" +- #include "llvm/Support/JSON.h" +- #include +-@@ -56,12 +57,8 @@ +- }; +- llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); +- +--enum ColumnType { +-- eColumnTypeString, +-- eColumnTypeNumber, +-- eColumnTypeBoolean, +-- eColumnTypeTimestamp +--}; +-+FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, +-+ eColumnTypeTimestamp}; +- +- /// A ColumnDescriptor specifies what module attribute to show in a column of +- /// the modules view, how to format it, and what the column’s label should be. +-@@ -90,27 +87,23 @@ +- +- /// Names of checksum algorithms that may be supported by a debug adapter. +- /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. +--enum ChecksumAlgorithm { +-- eChecksumAlgorithmMD5, +-- eChecksumAlgorithmSHA1, +-- eChecksumAlgorithmSHA256, +-- eChecksumAlgorithmTimestamp +--}; +-+FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, +-+ eChecksumAlgorithmSHA256, +-+ eChecksumAlgorithmTimestamp}; +- llvm::json::Value toJSON(const ChecksumAlgorithm &); +- +- /// Describes one or more type of breakpoint a BreakpointMode applies to. This +- /// is a non-exhaustive enumeration and may expand as future breakpoint types +- /// are added. +--enum BreakpointModeApplicability { +-- /// In `SourceBreakpoint`'s. +-- eBreakpointModeApplicabilitySource, +-- /// In exception breakpoints applied in the `ExceptionFilterOptions`. +-- eBreakpointModeApplicabilityException, +-- /// In data breakpoints requested in the `DataBreakpointInfo` request. +-- eBreakpointModeApplicabilityData, +-- /// In `InstructionBreakpoint`'s. +-- eBreakpointModeApplicabilityInstruction +--}; +-+FLAGS_ENUM(BreakpointModeApplicability){ +-+ /// In `SourceBreakpoint`'s. +-+ eBreakpointModeApplicabilitySource, +-+ /// In exception breakpoints applied in the `ExceptionFilterOptions`. +-+ eBreakpointModeApplicabilityException, +-+ /// In data breakpoints requested in the `DataBreakpointInfo` request. +-+ eBreakpointModeApplicabilityData, +-+ /// In `InstructionBreakpoint`'s. +-+ eBreakpointModeApplicabilityInstruction}; +- llvm::json::Value toJSON(const BreakpointModeApplicability &); +- +- /// A `BreakpointMode` is provided as a option when setting breakpoints on +-@@ -133,101 +126,101 @@ +- llvm::json::Value toJSON(const BreakpointMode &); +- +- /// Debug Adapter Features flags supported by lldb-dap. +--enum AdapterFeature { +-- /// The debug adapter supports ANSI escape sequences in styling of +-- /// `OutputEvent.output` and `Variable.value` fields. 
+-- eAdapterFeatureANSIStyling, +-- /// The debug adapter supports the `breakpointLocations` request. +-- eAdapterFeatureBreakpointLocationsRequest, +-- /// The debug adapter supports the `cancel` request. +-- eAdapterFeatureCancelRequest, +-- /// The debug adapter supports the `clipboard` context value in the +-- /// `evaluate` request. +-- eAdapterFeatureClipboardContext, +-- /// The debug adapter supports the `completions` request. +-- eAdapterFeatureCompletionsRequest, +-- /// The debug adapter supports conditional breakpoints. +-- eAdapterFeatureConditionalBreakpoints, +-- /// The debug adapter supports the `configurationDone` request. +-- eAdapterFeatureConfigurationDoneRequest, +-- /// The debug adapter supports the `asAddress` and `bytes` fields in the +-- /// `dataBreakpointInfo` request. +-- eAdapterFeatureDataBreakpointBytes, +-- /// The debug adapter supports data breakpoints. +-- eAdapterFeatureDataBreakpoints, +-- /// The debug adapter supports the delayed loading of parts of the stack, +-- /// which requires that both the `startFrame` and `levels` arguments and the +-- /// `totalFrames` result of the `stackTrace` request are supported. +-- eAdapterFeatureDelayedStackTraceLoading, +-- /// The debug adapter supports the `disassemble` request. +-- eAdapterFeatureDisassembleRequest, +-- /// The debug adapter supports a (side effect free) `evaluate` request for +-- /// data hovers. +-- eAdapterFeatureEvaluateForHovers, +-- /// The debug adapter supports `filterOptions` as an argument on the +-- /// `setExceptionBreakpoints` request. +-- eAdapterFeatureExceptionFilterOptions, +-- /// The debug adapter supports the `exceptionInfo` request. +-- eAdapterFeatureExceptionInfoRequest, +-- /// The debug adapter supports `exceptionOptions` on the +-- /// `setExceptionBreakpoints` request. +-- eAdapterFeatureExceptionOptions, +-- /// The debug adapter supports function breakpoints. +-- eAdapterFeatureFunctionBreakpoints, +-- /// The debug adapter supports the `gotoTargets` request. +-- eAdapterFeatureGotoTargetsRequest, +-- /// The debug adapter supports breakpoints that break execution after a +-- /// specified number of hits. +-- eAdapterFeatureHitConditionalBreakpoints, +-- /// The debug adapter supports adding breakpoints based on instruction +-- /// references. +-- eAdapterFeatureInstructionBreakpoints, +-- /// The debug adapter supports the `loadedSources` request. +-- eAdapterFeatureLoadedSourcesRequest, +-- /// The debug adapter supports log points by interpreting the `logMessage` +-- /// attribute of the `SourceBreakpoint`. +-- eAdapterFeatureLogPoints, +-- /// The debug adapter supports the `modules` request. +-- eAdapterFeatureModulesRequest, +-- /// The debug adapter supports the `readMemory` request. +-- eAdapterFeatureReadMemoryRequest, +-- /// The debug adapter supports restarting a frame. +-- eAdapterFeatureRestartFrame, +-- /// The debug adapter supports the `restart` request. In this case a client +-- /// should not implement `restart` by terminating and relaunching the +-- /// adapter but by calling the `restart` request. +-- eAdapterFeatureRestartRequest, +-- /// The debug adapter supports the `setExpression` request. +-- eAdapterFeatureSetExpression, +-- /// The debug adapter supports setting a variable to a value. +-- eAdapterFeatureSetVariable, +-- /// The debug adapter supports the `singleThread` property on the execution +-- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, +-- /// `stepBack`). 
+-- eAdapterFeatureSingleThreadExecutionRequests, +-- /// The debug adapter supports stepping back via the `stepBack` and +-- /// `reverseContinue` requests. +-- eAdapterFeatureStepBack, +-- /// The debug adapter supports the `stepInTargets` request. +-- eAdapterFeatureStepInTargetsRequest, +-- /// The debug adapter supports stepping granularities (argument +-- /// `granularity`) for the stepping requests. +-- eAdapterFeatureSteppingGranularity, +-- /// The debug adapter supports the `terminate` request. +-- eAdapterFeatureTerminateRequest, +-- /// The debug adapter supports the `terminateThreads` request. +-- eAdapterFeatureTerminateThreadsRequest, +-- /// The debug adapter supports the `suspendDebuggee` attribute on the +-- /// `disconnect` request. +-- eAdapterFeatureSuspendDebuggee, +-- /// The debug adapter supports a `format` attribute on the `stackTrace`, +-- /// `variables`, and `evaluate` requests. +-- eAdapterFeatureValueFormattingOptions, +-- /// The debug adapter supports the `writeMemory` request. +-- eAdapterFeatureWriteMemoryRequest, +-- /// The debug adapter supports the `terminateDebuggee` attribute on the +-- /// `disconnect` request. +-- eAdapterFeatureTerminateDebuggee, +-+FLAGS_ENUM(AdapterFeature){ +-+ /// The debug adapter supports ANSI escape sequences in styling of +-+ /// `OutputEvent.output` and `Variable.value` fields. +-+ eAdapterFeatureANSIStyling, +-+ /// The debug adapter supports the `breakpointLocations` request. +-+ eAdapterFeatureBreakpointLocationsRequest, +-+ /// The debug adapter supports the `cancel` request. +-+ eAdapterFeatureCancelRequest, +-+ /// The debug adapter supports the `clipboard` context value in the +-+ /// `evaluate` request. +-+ eAdapterFeatureClipboardContext, +-+ /// The debug adapter supports the `completions` request. +-+ eAdapterFeatureCompletionsRequest, +-+ /// The debug adapter supports conditional breakpoints. +-+ eAdapterFeatureConditionalBreakpoints, +-+ /// The debug adapter supports the `configurationDone` request. +-+ eAdapterFeatureConfigurationDoneRequest, +-+ /// The debug adapter supports the `asAddress` and `bytes` fields in the +-+ /// `dataBreakpointInfo` request. +-+ eAdapterFeatureDataBreakpointBytes, +-+ /// The debug adapter supports data breakpoints. +-+ eAdapterFeatureDataBreakpoints, +-+ /// The debug adapter supports the delayed loading of parts of the stack, +-+ /// which requires that both the `startFrame` and `levels` arguments and the +-+ /// `totalFrames` result of the `stackTrace` request are supported. +-+ eAdapterFeatureDelayedStackTraceLoading, +-+ /// The debug adapter supports the `disassemble` request. +-+ eAdapterFeatureDisassembleRequest, +-+ /// The debug adapter supports a (side effect free) `evaluate` request for +-+ /// data hovers. +-+ eAdapterFeatureEvaluateForHovers, +-+ /// The debug adapter supports `filterOptions` as an argument on the +-+ /// `setExceptionBreakpoints` request. +-+ eAdapterFeatureExceptionFilterOptions, +-+ /// The debug adapter supports the `exceptionInfo` request. +-+ eAdapterFeatureExceptionInfoRequest, +-+ /// The debug adapter supports `exceptionOptions` on the +-+ /// `setExceptionBreakpoints` request. +-+ eAdapterFeatureExceptionOptions, +-+ /// The debug adapter supports function breakpoints. +-+ eAdapterFeatureFunctionBreakpoints, +-+ /// The debug adapter supports the `gotoTargets` request. +-+ eAdapterFeatureGotoTargetsRequest, +-+ /// The debug adapter supports breakpoints that break execution after a +-+ /// specified number of hits. 
+-+ eAdapterFeatureHitConditionalBreakpoints, +-+ /// The debug adapter supports adding breakpoints based on instruction +-+ /// references. +-+ eAdapterFeatureInstructionBreakpoints, +-+ /// The debug adapter supports the `loadedSources` request. +-+ eAdapterFeatureLoadedSourcesRequest, +-+ /// The debug adapter supports log points by interpreting the `logMessage` +-+ /// attribute of the `SourceBreakpoint`. +-+ eAdapterFeatureLogPoints, +-+ /// The debug adapter supports the `modules` request. +-+ eAdapterFeatureModulesRequest, +-+ /// The debug adapter supports the `readMemory` request. +-+ eAdapterFeatureReadMemoryRequest, +-+ /// The debug adapter supports restarting a frame. +-+ eAdapterFeatureRestartFrame, +-+ /// The debug adapter supports the `restart` request. In this case a client +-+ /// should not implement `restart` by terminating and relaunching the +-+ /// adapter but by calling the `restart` request. +-+ eAdapterFeatureRestartRequest, +-+ /// The debug adapter supports the `setExpression` request. +-+ eAdapterFeatureSetExpression, +-+ /// The debug adapter supports setting a variable to a value. +-+ eAdapterFeatureSetVariable, +-+ /// The debug adapter supports the `singleThread` property on the execution +-+ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, +-+ /// `stepBack`). +-+ eAdapterFeatureSingleThreadExecutionRequests, +-+ /// The debug adapter supports stepping back via the `stepBack` and +-+ /// `reverseContinue` requests. +-+ eAdapterFeatureStepBack, +-+ /// The debug adapter supports the `stepInTargets` request. +-+ eAdapterFeatureStepInTargetsRequest, +-+ /// The debug adapter supports stepping granularities (argument +-+ /// `granularity`) for the stepping requests. +-+ eAdapterFeatureSteppingGranularity, +-+ /// The debug adapter supports the `terminate` request. +-+ eAdapterFeatureTerminateRequest, +-+ /// The debug adapter supports the `terminateThreads` request. +-+ eAdapterFeatureTerminateThreadsRequest, +-+ /// The debug adapter supports the `suspendDebuggee` attribute on the +-+ /// `disconnect` request. +-+ eAdapterFeatureSuspendDebuggee, +-+ /// The debug adapter supports a `format` attribute on the `stackTrace`, +-+ /// `variables`, and `evaluate` requests. +-+ eAdapterFeatureValueFormattingOptions, +-+ /// The debug adapter supports the `writeMemory` request. +-+ eAdapterFeatureWriteMemoryRequest, +-+ /// The debug adapter supports the `terminateDebuggee` attribute on the +-+ /// `disconnect` request. +-+ eAdapterFeatureTerminateDebuggee, +- }; +- +- /// Information about the capabilities of a debug adapter. +-@@ -268,10 +261,10 @@ +- }; +- llvm::json::Value toJSON(const Capabilities &); +- +--enum PresentationHint { +-- ePresentationHintNormal, +-- ePresentationHintEmphasize, +-- ePresentationHintDeemphasize, +-+FLAGS_ENUM(PresentationHint){ +-+ ePresentationHintNormal, +-+ ePresentationHintEmphasize, +-+ ePresentationHintDeemphasize, +- }; +- +- /// A `Source` is a descriptor for source code. 
It is returned from the debug +-diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +---- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +-+++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test +-@@ -1,7 +1,7 @@ +- // Header +- // +- // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) +--// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) +-+// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) +- // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) +- // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) +- // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) +-diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c +---- a/offload/test/offloading/gpupgo/pgo1.c +-+++ b/offload/test/offloading/gpupgo/pgo1.c +-@@ -14,7 +14,7 @@ +- // RUN: %target_triple.%basename_t.clang.profraw | \ +- // RUN: %fcheck-generic --check-prefix="CLANG-PGO" +- +--// REQUIRES: gpu +-+// REQUIRES: amdgpu +- // REQUIRES: pgo +- +- int test1(int a) { return a / 2; } +-diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c +---- a/offload/test/offloading/gpupgo/pgo2.c +-+++ b/offload/test/offloading/gpupgo/pgo2.c +-@@ -48,7 +48,7 @@ +- // RUN: %target_triple.%basename_t.hfdi.profraw \ +- // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" +- +--// REQUIRES: gpu +-+// REQUIRES: amdgpu +- // REQUIRES: pgo +- +- int main() { diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 725480b..005737a 100644 +index 005737a..fd9baec 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "71a977d0d611f3e9f6137a6b8a26b730b2886ce9" -- LLVM_SHA256 = "9bdf3ddf45c069248af36080a78b56d839d3aad6f9b727ec1ee1be72682888cc" -+ LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" -+ LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" +- LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" +- LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" ++ LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" ++ LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index 7b1a0496a0965c..f93f7a93cea2e8 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "9435b34df0279d473240f5bcc2a829d0589ae372" - SHARDY_SHA256 = "5f2a037d3301a1407e5c94778dd56d855f5abe26999cce448ccfa1923cf9559f" + SHARDY_COMMIT = "f25a97f80402e43de93f46931c6dddc485e8dad0" + SHARDY_SHA256 = "c7e55d3902175c064d3dd6bffc856c5e0198ff6c8f1410c3a97ed2c8e85ddb30" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/triton/llvm_integration/cl742325920.patch b/third_party/xla/third_party/triton/llvm_integration/cl742325920.patch new file mode 100644 index 00000000000000..3a391a40c30b1a --- /dev/null +++ 
b/third_party/xla/third_party/triton/llvm_integration/cl742325920.patch
@@ -0,0 +1,58 @@
+
+--- a/third_party/amd/include/Analysis/RangeAnalysis.h	2025-03-25 07:48:50.000000000 -0700
++++ b/third_party/amd/include/Analysis/RangeAnalysis.h	2025-03-31 11:20:15.000000000 -0700
+@@ -118,8 +118,11 @@
+
+ std::optional<SmallVector<ConstantIntRanges>>
+ collectRanges(const DataFlowSolver &solver, ValueRange values);
++
+ bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp);
+
++bool isEmptyInitializedRange(ConstantIntRanges rv);
++
+ } // namespace mlir::triton::AMD
+
+ #endif
+
+--- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp	2025-03-25 07:48:50.000000000 -0700
++++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp	2025-03-31 11:20:15.000000000 -0700
+@@ -186,6 +186,13 @@
+
+ namespace mlir::triton::AMD {
+
++bool isEmptyInitializedRange(ConstantIntRanges rv) {
++  if (!rv.umin().getBitWidth() || !rv.umax().getBitWidth() ||
++      !rv.smin().getBitWidth() || !rv.smax().getBitWidth())
++    return true;
++  return false;
++}
++
+ std::optional<SmallVector<ConstantIntRanges>>
+ collectRanges(const DataFlowSolver &solver, ValueRange values) {
+   SmallVector<ConstantIntRanges> ranges;
+@@ -196,6 +203,8 @@
+       return {};
+     const ConstantIntRanges &inferredRange =
+         maybeInferredRange->getValue().getValue();
++    if (isEmptyInitializedRange(inferredRange))
++      return {};
+     ranges.push_back(inferredRange);
+   }
+   return ranges;
+
+--- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp	2025-03-25 07:48:50.000000000 -0700
++++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp	2025-03-31 11:20:16.000000000 -0700
+@@ -34,6 +34,13 @@
+     return signalPassFailure();
+
+   auto nonNegativePred = [&solver](Value v) -> bool {
++    if (const auto *r =
++            solver->lookupState<dataflow::IntegerValueRangeLattice>(v)) {
++      if (r->getValue().isUninitialized())
++        return false;
++      if (AMD::isEmptyInitializedRange(r->getValue().getValue()))
++        return false;
++    }
+     return succeeded(dataflow::staticallyNonNegative(*solver, v));
+   };
+   mod->walk([&solver, nonNegativePred](Operation *op) {
diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl
index b82b9f6b87bb1c..64e38dfaa30e75 100644
--- a/third_party/xla/third_party/triton/llvm_integration/series.bzl
+++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl
@@ -10,5 +10,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list.
 llvm_patch_list = [
     "//third_party/triton:llvm_integration/cl740926882.patch",
     "//third_party/triton:llvm_integration/cl741558316.patch",
+    "//third_party/triton:llvm_integration/cl742325920.patch",
     # Add new patches just above this line
 ]

From c51a37f03a00157b23985c42c5fabe1d6bf3fd5f Mon Sep 17 00:00:00 2001
From: Kevin Gleason
Date: Tue, 1 Apr 2025 12:30:39 -0700
Subject: [PATCH 0109/1324] [MHLO] Allow partial conversion between StableHLO
 and MHLO for ops with direct HLO lowerings

PiperOrigin-RevId: 742799109
---
 .../xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td   |  6 +++++-
 third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h |  3 ---
 .../stablehlo_legalize_to_hlo_pass.cc                 | 11 ++++++-----
 .../mhlo/stablehlo-legalize-to-hlo-partial.mlir       | 10 ++++++++++
 4 files changed, 21 insertions(+), 9 deletions(-)
 create mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo-partial.mlir

diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td
index 8503598e26b64d..853531c1c6e0d7 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td
@@ -181,8 +181,12 @@ def HloLegalizeToStablehloPass : Pass<"hlo-legalize-to-stablehlo", "ModuleOp"> {

 def StablehloLegalizeToHloPass : Pass<"stablehlo-legalize-to-hlo", "ModuleOp"> {
   let summary = "Legalize StableHLO to HLO.";
-  let constructor = "createStablehloLegalizeToHloPass()";
   let dependentDialects = ["mhlo::MhloDialect"];
+  let options = [
+    Option<"convert_xla_supported_stablehlo_", "convert-xla-supported-stablehlo",
+      "bool", /*default=*/"true",
+      "Don't convert ops that have direct HLO lowering support.">
+  ];
 }

 def PrepareForExportPass : Pass<"xla-prepare-for-export", "mlir::func::FuncOp"> {
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h
index 632f0f955e1b34..3d2aa3b3d31b93 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h
@@ -72,9 +72,6 @@ std::unique_ptr<OperationPass<func::FuncOp>> createCollapseElementwiseMapPass();
 // Pass to replace unsigned types with signless integers.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertToSignlessPass();

-// Legalizes from the StableHLO dialect to the MHLO dialect.
-std::unique_ptr<OperationPass<ModuleOp>> createStablehloLegalizeToHloPass();
-
 // Test passes.
 std::unique_ptr<Pass> createTestInferShapedTypeMethodsPass();
 std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
index 9b4c4d4eb64f0f..9a06ce80bbccec 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
@@ -41,11 +41,17 @@ namespace {

 struct StablehloLegalizeToHloPass
     : public impl::StablehloLegalizeToHloPassBase<StablehloLegalizeToHloPass> {
+  using StablehloLegalizeToHloPassBase::StablehloLegalizeToHloPassBase;
   void runOnOperation() override {
     ConversionTarget target(getContext());
     target.addIllegalDialect<stablehlo::StablehloDialect>();
     target.addLegalDialect<mhlo::MhloDialect>();

+    // Allow injecting legal ops to permit gradual migration.
+    if (!convert_xla_supported_stablehlo_) {
+      target.addLegalOp();
+    }
+
     stablehlo::StablehloToHloTypeConverter converter;
     RewritePatternSet patterns(&getContext());
     stablehlo::populateStablehloToHloPatterns(&patterns, &converter,
@@ -63,10 +69,5 @@ struct StablehloLegalizeToHloPass

 }  // namespace

-std::unique_ptr<OperationPass<ModuleOp>>
-createStablehloLegalizeToHloPass() {
-  return std::make_unique<StablehloLegalizeToHloPass>();
-}
-
 }  // namespace mhlo
 }  // namespace mlir
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo-partial.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo-partial.mlir
new file mode 100644
index 00000000000000..5008e99d64deb9
--- /dev/null
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo-partial.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-hlo-opt --stablehlo-legalize-to-hlo=convert-xla-supported-stablehlo=false --split-input-file --verify-diagnostics %s | FileCheck %s
+
+
+// CHECK-LABEL: op_constant
+func.func @op_constant(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: stablehlo.constant
+  // CHECK-NOT: mhlo.constant
+  %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  return %cst : tensor<f32>
+}

From 9bdeda24e58b39852225c69b0ef7218835fa5bbf Mon Sep 17 00:00:00 2001
From: Allan Renucci
Date: Tue, 1 Apr 2025 12:37:11 -0700
Subject: [PATCH 0110/1324] [XLA:GPU] Delete no-op
 `xla_gpu_enable_nccl_clique_optimization` flag and associated code.

The flag was used in the now-deleted XLA GPU runtime
([#9329](https://github.com/openxla/xla/pull/9329)).

PiperOrigin-RevId: 742801435
---
 third_party/xla/xla/debug_options_flags.cc        |   7 -
 third_party/xla/xla/service/gpu/BUILD             |   2 -
 .../xla/xla/service/gpu/backend_configs.proto     |   9 +-
 .../xla/service/gpu/backend_configs_test.cc       |   1 -
 .../xla/xla/service/gpu/gpu_hlo_schedule.cc       |   2 -
 .../xla/service/gpu/gpu_hlo_schedule_test.cc      |   1 -
 .../xla/xla/service/gpu/transforms/BUILD          |  37 ----
 .../command_buffer_scheduling_test.cc             |  16 +-
 .../gpu/transforms/schedule_postprocessing.cc     | 158 -----------------
 .../gpu/transforms/schedule_postprocessing.h      |  50 ------
 .../schedule_postprocessing_test.cc               | 163 ------------------
 third_party/xla/xla/xla.proto                     |   4 +-
 12 files changed, 12 insertions(+), 438 deletions(-)
 delete mode 100644 third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc
 delete mode 100644 third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h
 delete mode 100644 third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc

diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index 0e7d84010421a4..8667c05bbefb19 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -254,7 +254,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_split_k_autotuning(true);

   opts.set_xla_gpu_enable_reduction_epilogue_fusion(true);
-  opts.set_xla_gpu_enable_nccl_clique_optimization(false);
   opts.set_xla_gpu_cublas_fallback(true);
   opts.set_xla_gpu_cudnn_gemm_fusion_level(0);
   opts.set_xla_gpu_enable_while_loop_double_buffering(false);
@@ -1920,12 +1919,6 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
           &DebugOptions::set_xla_gpu_enable_reduction_epilogue_fusion),
       debug_options->xla_gpu_enable_reduction_epilogue_fusion(),
       "Enable fusion for reduction epilogues"));
-  flag_list->push_back(
-      tsl::Flag("xla_gpu_enable_nccl_clique_optimization",
-                bool_setter_for(
-                    &DebugOptions::set_xla_gpu_enable_nccl_clique_optimization),
-                debug_options->xla_gpu_enable_nccl_clique_optimization(),
-                "Allow early return when acquiring NCCL cliques"));
   flag_list->push_back(
       tsl::Flag("xla_gpu_cublas_fallback",
                 bool_setter_for(&DebugOptions::set_xla_gpu_cublas_fallback),
diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 2a2646ea8691d1..d67ef9f5ee157b 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -2164,7 +2164,6 @@ cc_library(
         "//xla/service/gpu/model:sol_latency_estimator",
         "//xla/service/gpu/transforms:async_collective_annotator",
         "//xla/service/gpu/transforms:pgle_accuracy_checker",
-        "//xla/service/gpu/transforms:schedule_postprocessing",
         "//xla/service/gpu/transforms:scheduling_instruction_annotator",
         "//xla/service/gpu/transforms/collectives:collective_ops_utils",
         "//xla/stream_executor:device_description",
@@ -2207,7 +2206,6 @@ xla_test(
         "//xla/service:hlo_module_config",
         "//xla/service:latency_hiding_scheduler",
         "//xla/service:legalize_scheduling_annotations",
-        "//xla/service/gpu/transforms:schedule_postprocessing",
         "//xla/service/gpu/transforms:scheduling_instruction_annotator",
         "//xla/stream_executor:device_description",
         "//xla/tests:hlo_test_base",
diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto
index 515f672acbf3a9..37ad0b6e615ab1 100644
--- a/third_party/xla/xla/service/gpu/backend_configs.proto
+++ b/third_party/xla/xla/service/gpu/backend_configs.proto
@@ -118,19 +118,16 @@ message BitcastBackendConfig {
 // Backend config for async collective operations. Note that is_sync will
 // be false by default, so even if a backend config is not explicitly attached
 // to the HLOInstruction, getting the backend_config will yield a default-valued
-// proto which will have is_sync = false. Attribute no_parallel_custom_call
-// asserts that an asynchronous collective operation does not execute in
-// parallel with custom-calls, which can trigger device synchronization. This
-// attribute will also be false by default and should lead to conservative
-// runtime behavior.
+// proto which will have is_sync = false.
 message CollectiveBackendConfig {
   bool is_sync = 1;
-  bool no_parallel_custom_call = 2;
   // Determines whether the collective op of interest has been pipelined
   // within a loop.
   bool is_pipelined = 3;
   // Cost model prediction.
   repeated ReificationCost reification_cost = 4;
+
+  reserved 2;
 }

 // Backend config for cost model estimates.
diff --git a/third_party/xla/xla/service/gpu/backend_configs_test.cc b/third_party/xla/xla/service/gpu/backend_configs_test.cc
index 7883547f077dcb..16f05964536e71 100644
--- a/third_party/xla/xla/service/gpu/backend_configs_test.cc
+++ b/third_party/xla/xla/service/gpu/backend_configs_test.cc
@@ -59,7 +59,6 @@ TEST_F(BackendConfigsTest, DefaultCollectiveBackendConfig) {
   const auto& collective_backend_config =
       gpu_config.collective_backend_config();
   EXPECT_THAT(collective_backend_config.is_sync(), IsFalse());
-  EXPECT_THAT(collective_backend_config.no_parallel_custom_call(), IsFalse());
 }

 TEST_F(BackendConfigsTest, DefaultGpuBackendConfigParseOpQueue) {
diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
index c050fad9b8d68a..aacd6d264475fa 100644
--- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
@@ -59,7 +59,6 @@ limitations under the License.
#include "xla/service/gpu/transforms/async_collective_annotator.h" #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" #include "xla/service/gpu/transforms/pgle_accuracy_checker.h" -#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include "xla/service/hlo_module_config.h" #include "xla/service/latency_hiding_scheduler.h" @@ -596,7 +595,6 @@ absl::Status RunLatencyHidingSchedulerPasses( std::move(estimator), std::move(async_tracker), std::move(scheduler_core), shape_size_in_bytes); pipeline.AddPass(); - pipeline.AddPass(); return pipeline.Run(module).status(); } diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index f2d3c1ebbedfda..375210385f89e7 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -43,7 +43,6 @@ limitations under the License. #include "xla/service/backend.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" -#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include "xla/service/hlo_module_config.h" #include "xla/service/latency_hiding_scheduler.h" diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 03f25d09842259..56050746f2a4bc 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2828,43 +2828,6 @@ xla_cc_test( ], ) -cc_library( - name = "schedule_postprocessing", - srcs = ["schedule_postprocessing.cc"], - hdrs = ["schedule_postprocessing.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu/transforms/collectives:collective_ops_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "schedule_postprocessing_test", - srcs = ["schedule_postprocessing_test.cc"], - deps = [ - ":schedule_postprocessing", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", - "//xla/service/gpu:backend_configs_cc", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "scheduling_instruction_annotator", srcs = ["scheduling_instruction_annotator.cc"], diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 3c60b7f6ec2255..e74ff34ec0d625 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -209,7 +209,7 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) { %a = s32[4] parameter(0) %start = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": 
{"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start) })"; @@ -242,7 +242,7 @@ TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) { %start = (s32[2]{0}, s32[4]{0}) all-gather-start(%a), channel_id=555, replica_groups={{0,1}}, dimensions={0}, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} ROOT %done = s32[4]{0} all-gather-done(%start) })"; @@ -282,7 +282,7 @@ TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) { %start = ((s32[4]{0}), s32[2]{0}) reduce-scatter-start(%a), channel_id=555, replica_groups={{0,1}}, dimensions={0}, to_apply=add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} ROOT %done = s32[2]{0} reduce-scatter-done(%start) })"; @@ -321,7 +321,7 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByBitcast) { %a = s32[4] parameter(0) %start = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %bitcast = s32[4] bitcast(s32[4]{0} %a) ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start) })"; @@ -361,10 +361,10 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedAllReduceStart) { %a = s32[4] parameter(0) %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1) ROOT %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2) })"; @@ -418,11 +418,11 @@ TEST_F(CommandBufferSchedulingTest, DoNotCaptureUnmatchedAsyncDone) { %b = s32[] parameter(1) %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %c = s32[] custom-call(), custom_call_target="target" %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1) %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2) %fusion = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc deleted file mode 100644 index 0fd39a27d0e3ba..00000000000000 --- a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. 
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/transforms/schedule_postprocessing.h"
-
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-// Maps a computation to a boolean that indicates whether the computation may
-// invoke custom-calls directly or indirectly, which can eventually trigger gpu
-// synchronization.
-using CustomCallInComputation =
-    absl::flat_hash_map<const HloComputation*, bool>;
-
-// Returns whether the hlo may invoke custom-calls which may trigger gpu
-// synchronization. Currently, we only check for custom-calls, because they are
-// the only operations that can be parallel with asynchronous collective
-// operations in an hlo-schedule and may trigger gpu synchronization.
-bool MayInvokeCustomCall(
-    const HloInstruction* hlo,
-    const CustomCallInComputation& custom_call_in_computation) {
-  if (HloPredicateIsOp<HloOpcode::kCustomCall>(hlo)) {
-    return true;
-  }
-
-  return absl::c_any_of(
-      hlo->called_computations(), [&](const HloComputation* callee) {
-        return custom_call_in_computation.find(callee)->second;
-      });
-}
-
-// Returns true if this is an asynchronous collective start operation, excluding
-// P2P operations.
-bool IsRelevantAsynchronousStart(const HloInstruction* hlo) {
-  return hlo_query::IsAsyncCollectiveStartOp(hlo,
-                                             /*include_send_recv=*/false) &&
-         !IsGPUSyncCollective(*hlo);
-}
-
-// Returns true if this is a collective done operation, excluding P2P
-// operations.
-bool IsRelevantAsynchronousDone(const HloInstruction* hlo) {
-  return hlo_query::IsAsyncCollectiveDoneOp(hlo,
-                                            /*include_send_recv=*/false);
-}
-
-// For a given computation, finds all the asynchronous collective operations
-// that aren't parallel with custom-calls and sets its no_parallel_custom_call
-// attribute to true. Also records whether the given computation may invoke
-// custom-calls.
-absl::StatusOr<bool> ProcessComputation(
-    const HloSchedule& schedule, HloComputation* computation,
-    CustomCallInComputation& custom_call_in_computation) {
-  bool changed = false;
-  bool has_custom_call = false;
-  absl::flat_hash_set<HloInstruction*> async_starts;
-  const HloInstructionSequence& sequence = schedule.sequence(computation);
-
-  // Visit instructions in the sequence. Collect relevant asynchronous
-  // collective start ops. When we see a relevant asynchronous collective done
-  // op, remove the corresponding start op from the collection and set its
-  // attribute no_parallel_custom_call to true. When we see a custom-call, clear
-  // the start ops from the collection and keep their attribute
-  // no_parallel_custom_call as false.
-  const std::vector<HloInstruction*>& all_instructions =
-      sequence.instructions();
-  for (HloInstruction* hlo : all_instructions) {
-    if (MayInvokeCustomCall(hlo, custom_call_in_computation)) {
-      async_starts.clear();
-      has_custom_call = true;
-      continue;
-    }
-    if (IsRelevantAsynchronousStart(hlo)) {
-      async_starts.insert(hlo);
-      continue;
-    }
-
-    if (IsRelevantAsynchronousDone(hlo)) {
-      HloInstruction* async_start = hlo->mutable_operand(0);
-      if (async_starts.contains(async_start)) {
-        changed = true;
-        TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                            async_start->backend_config<GpuBackendConfig>());
-        CollectiveBackendConfig& collective_backend_config =
-            *gpu_config.mutable_collective_backend_config();
-        collective_backend_config.set_no_parallel_custom_call(true);
-        TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
-        async_starts.erase(async_start);
-      }
-    }
-  }
-
-  custom_call_in_computation[computation] = has_custom_call;
-  return changed;
-}
-
-}  // anonymous namespace
-
-absl::StatusOr<bool> SchedulePostprocessing::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  if (!module->has_schedule()) return false;
-  HloSchedule& schedule = module->schedule();
-  bool changed = false;
-  CustomCallInComputation custom_call_in_computation;
-
-  // We visit computations in the order of callees to callers, as information is
-  // propagated from callees to callers.
-  std::vector<HloComputation*> all_computations =
-      module->MakeComputationPostOrder(execution_threads);
-  for (auto iter = all_computations.begin(); iter != all_computations.end();
-       ++iter) {
-    HloComputation* computation = *iter;
-    if (computation->IsFusionComputation()) {
-      custom_call_in_computation[computation] = false;
-      continue;
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        bool result,
-        ProcessComputation(schedule, computation, custom_call_in_computation));
-    changed |= result;
-  }
-
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h
deleted file mode 100644
index d76faed7d260cc..00000000000000
--- a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
-#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/pass/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Amends a schedule result with the needed information to support a runtime
-// implementation. Currently, this pass refines attribute
-// no_parallel_custom_call for asynchronous collective operations to support
-// runtime optimization, such as skipping rendezvous of all participating
-// threads for NCCL collective operations. In particular, it sets the attribute
-// value for Collective-start operations with is_sync=false; it also keeps the
-// attribute value untouched for the operations with is_sync=true and for P2P
-// operations, assuming the runtime won't use those values.
-//
-class SchedulePostprocessing : public HloModulePass {
- public:
-  absl::string_view name() const override { return "schedule-postprocessing"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc
deleted file mode 100644
index 01659a11f6e66d..00000000000000
--- a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc
+++ /dev/null
@@ -1,163 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/transforms/schedule_postprocessing.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/parser/hlo_parser.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using SchedulePostprocessingTest = HloTestBase;
-
-TEST_F(SchedulePostprocessingTest, SynchronousOpsNotChanged) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-
-  ENTRY entry {
-    pf32 = f32[1] parameter(0)
-
-    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
-    ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  SchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(SchedulePostprocessingTest, P2POpsNotChanged) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-
-  ENTRY main {
-    f0 = f32[] constant(0.0)
-    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
-
-    after-all = token[] after-all()
-    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=2,
-      frontend_attributes={
-       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}}"
-    }
-    recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=2
-    ROOT recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  SchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(SchedulePostprocessingTest, AsynchronousOpsChanged) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-
-  ENTRY entry {
-    pf32 = f32[1] parameter(0)
-    pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
-    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
-    ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  SchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          start->backend_config<GpuBackendConfig>());
-  const CollectiveBackendConfig& collective_backend_config =
-      gpu_config.collective_backend_config();
-  EXPECT_TRUE(collective_backend_config.no_parallel_custom_call());
-}
-
-TEST_F(SchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-
-  ENTRY entry {
-    pf32 = f32[1] parameter(0)
-    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
-    pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
-    all-gather-done = f32[2] all-gather-done(all-gather-start)
-    ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  SchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_FALSE(changed);
-
-  HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          start->backend_config<GpuBackendConfig>());
-  const CollectiveBackendConfig& collective_backend_config =
-      gpu_config.collective_backend_config();
-  EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
-}
-
-TEST_F(SchedulePostprocessingTest,
-       AsynchronousOpsWithParallelNestedCustomcall) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-  foo {
-    v = f32[1] parameter(0)
-    ROOT ret = f32[1] custom-call(v), custom_call_target="my_custom_call"
-  }
-
-  ENTRY entry {
-    pf32 = f32[1] parameter(0)
-    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
-    pf32.2 = f32[1] call(f32[1] pf32), to_apply=foo
-    all-gather-done = f32[2] all-gather-done(all-gather-start)
-    ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  SchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_FALSE(changed);
-
-  HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          start->backend_config<GpuBackendConfig>());
-  const CollectiveBackendConfig& collective_backend_config =
-      gpu_config.collective_backend_config();
-  EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index 5e0d41c5cb9a5b..4a89560b8c9228 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -443,9 +443,6 @@ message DebugOptions {
   // threads. Setting to 0 (the default value) means no enforcement.
   bool xla_gpu_enable_llvm_module_compilation_parallelism = 268;

-  // Allow early return when acquiring NCCL cliques.
-  bool xla_gpu_enable_nccl_clique_optimization = 244;
-
   // Enable NCCL communicator splitting.
bool xla_gpu_enable_nccl_comm_splitting = 272; @@ -806,6 +803,7 @@ message DebugOptions { // go/keep-sorted end + reserved 244; // xla_gpu_enable_nccl_clique_optimization reserved 276; // xla_gpu_enable_nccl_per_stream_comms //--------------------------------------------------------------------------// From f2d8d9b836b7e82c1ae77d659039486a3be9591d Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 1 Apr 2025 12:50:25 -0700 Subject: [PATCH 0111/1324] Add tridiagonal tests for XlaBuilder PiperOrigin-RevId: 742805837 --- .../tests/tridiagonal_tridiagonal_solve.hlo | 399 ++++++++++++++++++ third_party/xla/xla/hlo/tools/tests/BUILD | 2 + .../tools/tests/hlo_opt_test_only_passes.cc | 45 +- 3 files changed, 440 insertions(+), 6 deletions(-) create mode 100644 third_party/xla/xla/hlo/builder/tests/tridiagonal_tridiagonal_solve.hlo diff --git a/third_party/xla/xla/hlo/builder/tests/tridiagonal_tridiagonal_solve.hlo b/third_party/xla/xla/hlo/builder/tests/tridiagonal_tridiagonal_solve.hlo new file mode 100644 index 00000000000000..84f46311e4a69a --- /dev/null +++ b/third_party/xla/xla/hlo/builder/tests/tridiagonal_tridiagonal_solve.hlo @@ -0,0 +1,399 @@ +// NOTE: Assertions have been autogenerated by hlo/tools/generate_hlo_test_checks.py +// RUN: hlo-opt --passes=test-only-xla-builder --split-input-file %s | FileCheck %s + +// CHECK-LABEL: HloModule tridiagonal_tridiagonal_solve, entry_computation_layout={(f32[3,3]{1,0}, f32[3,3]{1,0})->f32[3,3]{1,0}} + +// CHECK: %[[$preparation_body_29:[^ ]+]] +// CHECK-NEXT: %[[parameter_30:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_31:[^ ]+]] = s32[] get-tuple-element(%[[parameter_30]]), index=0 +// CHECK-NEXT: %[[constant_34:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_35:[^ ]+]] = s32[] add(%[[get_tuple_element_31]], %[[constant_34]]) +// CHECK-NEXT: %[[get_tuple_element_32:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_30]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_33:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_30]]), index=2 +// CHECK-NEXT: %[[constant_36:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_37:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_33]], %[[constant_36]], %[[get_tuple_element_31]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[constant_38:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_39:[^ ]+]] = f32[1,3]{1,0} dynamic-update-slice(%[[get_tuple_element_32]], %[[dynamic_slice_37]], %[[constant_38]], %[[get_tuple_element_31]]) +// CHECK-NEXT: ROOT %[[tuple_40:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) tuple(%[[add_35]], %[[dynamic_update_slice_39]], %[[get_tuple_element_33]]) + +// CHECK: %[[$preparation_condition_41:[^ ]+]] +// CHECK-NEXT: %[[parameter_42:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_44:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_42]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_45:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_42]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_43:[^ ]+]] = s32[] get-tuple-element(%[[parameter_42]]), index=0 +// CHECK-NEXT: %[[constant_46:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: ROOT %[[compare_47:[^ ]+]] = pred[] compare(%[[get_tuple_element_43]], %[[constant_46]]), direction=LT + +// CHECK: %[[$forward_transformation_body_54:[^ ]+]] +// CHECK-NEXT: %[[parameter_55:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{0,1}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, 
f32[3,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_56:[^ ]+]] = s32[] get-tuple-element(%[[parameter_55]]), index=0 +// CHECK-NEXT: %[[constant_63:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_64:[^ ]+]] = s32[] add(%[[get_tuple_element_56]], %[[constant_63]]) +// CHECK-NEXT: %[[get_tuple_element_57:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_55]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_58:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_55]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_59:[^ ]+]] = f32[3,3]{0,1} get-tuple-element(%[[parameter_55]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_60:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_55]]), index=4 +// CHECK-NEXT: %[[constant_69:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_65:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_66:[^ ]+]] = s32[] add(%[[get_tuple_element_56]], %[[constant_65]]) +// CHECK-NEXT: %[[dynamic_slice_70:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_58]], %[[constant_69]], %[[add_66]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[constant_67:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_68:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_57]], %[[constant_67]], %[[add_66]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[constant_74:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[subtract_73:[^ ]+]] = s32[] subtract(%[[add_66]], %[[constant_65]]) +// CHECK-NEXT: %[[dynamic_slice_75:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_60]], %[[constant_74]], %[[subtract_73]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[divide_76:[^ ]+]] = f32[1,1]{1,0} divide(%[[dynamic_slice_68]], %[[dynamic_slice_75]]) +// CHECK-NEXT: %[[get_tuple_element_61:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_55]]), index=5 +// CHECK-NEXT: %[[constant_78:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[subtract_77:[^ ]+]] = s32[] subtract(%[[add_66]], %[[constant_65]]) +// CHECK-NEXT: %[[dynamic_slice_79:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_61]], %[[constant_78]], %[[subtract_77]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[multiply_80:[^ ]+]] = f32[1,1]{1,0} multiply(%[[divide_76]], %[[dynamic_slice_79]]) +// CHECK-NEXT: %[[subtract_81:[^ ]+]] = f32[1,1]{1,0} subtract(%[[dynamic_slice_70]], %[[multiply_80]]) +// CHECK-NEXT: %[[constant_82:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_83:[^ ]+]] = f32[1,3]{1,0} dynamic-update-slice(%[[get_tuple_element_60]], %[[subtract_81]], %[[constant_82]], %[[add_66]]) +// CHECK-NEXT: %[[get_tuple_element_62:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_55]]), index=6 +// CHECK-NEXT: %[[constant_71:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_72:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_59]], %[[constant_71]], %[[add_66]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[reshape_87:[^ ]+]] = f32[1]{0} reshape(%[[divide_76]]) +// CHECK-NEXT: %[[broadcast_88:[^ ]+]] = f32[3,1]{1,0} broadcast(%[[reshape_87]]), dimensions={1} +// CHECK-NEXT: %[[constant_85:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[subtract_84:[^ ]+]] = s32[] subtract(%[[add_66]], %[[constant_65]]) +// CHECK-NEXT: %[[dynamic_slice_86:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_62]], %[[constant_85]], %[[subtract_84]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[multiply_89:[^ ]+]] = f32[3,1]{1,0} multiply(%[[broadcast_88]], %[[dynamic_slice_86]]) +// CHECK-NEXT: %[[subtract_90:[^ ]+]] = f32[3,1]{1,0} 
subtract(%[[dynamic_slice_72]], %[[multiply_89]]) +// CHECK-NEXT: %[[constant_91:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_92:[^ ]+]] = f32[3,3]{1,0} dynamic-update-slice(%[[get_tuple_element_62]], %[[subtract_90]], %[[constant_91]], %[[add_66]]) +// CHECK-NEXT: ROOT %[[tuple_93:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{0,1}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, f32[3,3]{1,0}) tuple(%[[add_64]], %[[get_tuple_element_57]], %[[get_tuple_element_58]], %[[get_tuple_element_59]], %[[dynamic_update_slice_83]], /*index=5*/%[[get_tuple_element_61]], %[[dynamic_update_slice_92]]) + +// CHECK: %[[$forward_transformation_condition_94:[^ ]+]] +// CHECK-NEXT: %[[parameter_95:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{0,1}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, f32[3,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_97:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_95]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_98:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_95]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_99:[^ ]+]] = f32[3,3]{0,1} get-tuple-element(%[[parameter_95]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_100:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_95]]), index=4 +// CHECK-NEXT: %[[get_tuple_element_101:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_95]]), index=5 +// CHECK-NEXT: %[[get_tuple_element_102:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_95]]), index=6 +// CHECK-NEXT: %[[get_tuple_element_96:[^ ]+]] = s32[] get-tuple-element(%[[parameter_95]]), index=0 +// CHECK-NEXT: %[[constant_103:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: ROOT %[[compare_104:[^ ]+]] = pred[] compare(%[[get_tuple_element_96]], %[[constant_103]]), direction=LT + +// CHECK: %[[$backward_reduction_body_127:[^ ]+]] +// CHECK-NEXT: %[[parameter_128:[^ ]+]] = (s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_129:[^ ]+]] = s32[] get-tuple-element(%[[parameter_128]]), index=0 +// CHECK-NEXT: %[[constant_134:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_135:[^ ]+]] = s32[] add(%[[get_tuple_element_129]], %[[constant_134]]) +// CHECK-NEXT: %[[get_tuple_element_130:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_128]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_131:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_128]]), index=2 +// CHECK-NEXT: %[[constant_139:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_136:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[subtract_138:[^ ]+]] = s32[] subtract(%[[constant_136]], %[[get_tuple_element_129]]) +// CHECK-NEXT: %[[dynamic_slice_140:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_131]], %[[constant_139]], %[[subtract_138]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[get_tuple_element_132:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_128]]), index=3 +// CHECK-NEXT: %[[constant_141:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_142:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_132]], %[[constant_141]], %[[subtract_138]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[reshape_146:[^ ]+]] = f32[1]{0} reshape(%[[dynamic_slice_142]]) +// CHECK-NEXT: %[[broadcast_147:[^ ]+]] = f32[3,1]{1,0} broadcast(%[[reshape_146]]), dimensions={1} +// CHECK-NEXT: %[[constant_144:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_137:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_143:[^ ]+]] = s32[] add(%[[subtract_138]], 
%[[constant_137]]) +// CHECK-NEXT: %[[dynamic_slice_145:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_130]], %[[constant_144]], %[[add_143]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[multiply_148:[^ ]+]] = f32[3,1]{1,0} multiply(%[[broadcast_147]], %[[dynamic_slice_145]]) +// CHECK-NEXT: %[[subtract_149:[^ ]+]] = f32[3,1]{1,0} subtract(%[[dynamic_slice_140]], %[[multiply_148]]) +// CHECK-NEXT: %[[get_tuple_element_133:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_128]]), index=4 +// CHECK-NEXT: %[[constant_150:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_151:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_133]], %[[constant_150]], %[[subtract_138]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[reshape_152:[^ ]+]] = f32[1]{0} reshape(%[[dynamic_slice_151]]) +// CHECK-NEXT: %[[broadcast_153:[^ ]+]] = f32[3,1]{1,0} broadcast(%[[reshape_152]]), dimensions={1} +// CHECK-NEXT: %[[divide_154:[^ ]+]] = f32[3,1]{1,0} divide(%[[subtract_149]], %[[broadcast_153]]) +// CHECK-NEXT: %[[constant_155:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_156:[^ ]+]] = f32[3,3]{1,0} dynamic-update-slice(%[[get_tuple_element_130]], %[[divide_154]], %[[constant_155]], %[[subtract_138]]) +// CHECK-NEXT: ROOT %[[tuple_157:[^ ]+]] = (s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) tuple(%[[add_135]], %[[dynamic_update_slice_156]], %[[get_tuple_element_131]], %[[get_tuple_element_132]], %[[get_tuple_element_133]]) + +// CHECK: %[[$backward_reduction_condition_158:[^ ]+]] +// CHECK-NEXT: %[[parameter_159:[^ ]+]] = (s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_161:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_159]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_162:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_159]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_163:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_159]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_164:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_159]]), index=4 +// CHECK-NEXT: %[[get_tuple_element_160:[^ ]+]] = s32[] get-tuple-element(%[[parameter_159]]), index=0 +// CHECK-NEXT: %[[constant_165:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: ROOT %[[compare_166:[^ ]+]] = pred[] compare(%[[get_tuple_element_160]], %[[constant_165]]), direction=LT + +// CHECK: %[[$xla_builder_tridiagonal_TridiagonalSolver_174:[^ ]+]] +// CHECK-NEXT: %[[constant_27:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_11:[^ ]+]] = f32[] constant(0) +// CHECK-NEXT: %[[broadcast_12:[^ ]+]] = f32[1,3]{1,0} broadcast(%[[constant_11]]), dimensions={} +// CHECK-NEXT: %[[arg0_1:[^ ]+]] = f32[3,3]{1,0} parameter(0) +// CHECK-NEXT: %[[slice_3:[^ ]+]] = f32[1,3]{1,0} slice(%[[arg0_1]]), slice={[0:1], [0:3]} +// CHECK-NEXT: %[[tuple_28:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) tuple(%[[constant_27]], %[[broadcast_12]], %[[slice_3]]) +// CHECK-NEXT: %[[while_48:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) while(%[[tuple_28]]), condition=%[[$preparation_condition_41]], body=%[[$preparation_body_29]] +// CHECK-NEXT: %[[get_tuple_element_49:[^ ]+]] = s32[] get-tuple-element(%[[while_48]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_51:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_48]]), index=2 +// CHECK-NEXT: %[[constant_52:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[slice_5:[^ ]+]] = f32[1,3]{1,0} slice(%[[arg0_1]]), slice={[2:3], [0:3]} +// CHECK-NEXT: %[[slice_4:[^ ]+]] 
= f32[1,3]{1,0} slice(%[[arg0_1]]), slice={[1:2], [0:3]} +// CHECK-NEXT: %[[arg1_2:[^ ]+]] = f32[3,3]{1,0} parameter(1) +// CHECK-NEXT: %[[transpose_6:[^ ]+]] = f32[3,3]{0,1} transpose(%[[arg1_2]]), dimensions={1,0} +// CHECK-NEXT: %[[constant_7:[^ ]+]] = f32[] constant(0) +// CHECK-NEXT: %[[broadcast_8:[^ ]+]] = f32[1,3]{1,0} broadcast(%[[constant_7]]), dimensions={} +// CHECK-NEXT: %[[constant_16:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_15:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_17:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[slice_4]], %[[constant_16]], %[[constant_15]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[constant_19:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_18:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_20:[^ ]+]] = f32[1,3]{1,0} dynamic-update-slice(%[[broadcast_8]], %[[dynamic_slice_17]], %[[constant_19]], %[[constant_18]]) +// CHECK-NEXT: %[[get_tuple_element_50:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_48]]), index=1 +// CHECK-NEXT: %[[constant_9:[^ ]+]] = f32[] constant(0) +// CHECK-NEXT: %[[broadcast_10:[^ ]+]] = f32[3,3]{1,0} broadcast(%[[constant_9]]), dimensions={} +// CHECK-NEXT: %[[constant_22:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_21:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_23:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[transpose_6]], %[[constant_22]], %[[constant_21]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[constant_25:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_24:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_26:[^ ]+]] = f32[3,3]{1,0} dynamic-update-slice(%[[broadcast_10]], %[[dynamic_slice_23]], %[[constant_25]], %[[constant_24]]) +// CHECK-NEXT: %[[tuple_53:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{0,1}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, f32[3,3]{1,0}) tuple(%[[constant_52]], %[[slice_5]], %[[slice_4]], %[[transpose_6]], %[[dynamic_update_slice_20]], /*index=5*/%[[get_tuple_element_50]], %[[dynamic_update_slice_26]]) +// CHECK-NEXT: %[[while_105:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{0,1}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, f32[3,3]{1,0}) while(%[[tuple_53]]), condition=%[[$forward_transformation_condition_94]], body=%[[$forward_transformation_body_54]] +// CHECK-NEXT: %[[get_tuple_element_106:[^ ]+]] = s32[] get-tuple-element(%[[while_105]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_107:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_105]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_108:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_105]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_109:[^ ]+]] = f32[3,3]{0,1} get-tuple-element(%[[while_105]]), index=3 +// CHECK-NEXT: %[[constant_125:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_13:[^ ]+]] = f32[] constant(0) +// CHECK-NEXT: %[[broadcast_14:[^ ]+]] = f32[3,3]{1,0} broadcast(%[[constant_13]]), dimensions={} +// CHECK-NEXT: %[[get_tuple_element_112:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[while_105]]), index=6 +// CHECK-NEXT: %[[constant_114:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_113:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: %[[dynamic_slice_115:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_112]], %[[constant_114]], %[[constant_113]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[get_tuple_element_110:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_105]]), index=4 +// CHECK-NEXT: %[[constant_117:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: 
%[[constant_116:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: %[[dynamic_slice_118:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_110]], %[[constant_117]], %[[constant_116]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[reshape_119:[^ ]+]] = f32[1]{0} reshape(%[[dynamic_slice_118]]) +// CHECK-NEXT: %[[broadcast_120:[^ ]+]] = f32[3,1]{1,0} broadcast(%[[reshape_119]]), dimensions={1} +// CHECK-NEXT: %[[divide_121:[^ ]+]] = f32[3,1]{1,0} divide(%[[dynamic_slice_115]], %[[broadcast_120]]) +// CHECK-NEXT: %[[constant_123:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_122:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: %[[dynamic_update_slice_124:[^ ]+]] = f32[3,3]{1,0} dynamic-update-slice(%[[broadcast_14]], %[[divide_121]], %[[constant_123]], %[[constant_122]]) +// CHECK-NEXT: %[[get_tuple_element_111:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_105]]), index=5 +// CHECK-NEXT: %[[tuple_126:[^ ]+]] = (s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) tuple(%[[constant_125]], %[[dynamic_update_slice_124]], %[[get_tuple_element_112]], %[[get_tuple_element_111]], %[[get_tuple_element_110]]) +// CHECK-NEXT: %[[while_167:[^ ]+]] = (s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) while(%[[tuple_126]]), condition=%[[$backward_reduction_condition_158]], body=%[[$backward_reduction_body_127]] +// CHECK-NEXT: %[[get_tuple_element_168:[^ ]+]] = s32[] get-tuple-element(%[[while_167]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_170:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[while_167]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_171:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_167]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_172:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_167]]), index=4 +// CHECK-NEXT: %[[get_tuple_element_169:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[while_167]]), index=1 +// CHECK-NEXT: ROOT %[[transpose_173:[^ ]+]] = f32[3,3]{0,1} transpose(%[[get_tuple_element_169]]), dimensions={1,0} + +// CHECK: ENTRY %[[$main_3:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = f32[3,3]{1,0} parameter(0) +// CHECK-NEXT: %[[Arg_1_2:[^ ]+]] = f32[3,3]{1,0} parameter(1) +// CHECK-NEXT: %[[custom_call_2:[^ ]+]] = f32[3,3]{1,0} custom-call(%[[Arg_0_1]], %[[Arg_1_2]]), custom_call_target="xla_builder.tridiagonal.TridiagonalSolver" +// CHECK-NEXT: ROOT %[[custom_call:[^ ]+]] = f32[3,3]{1,0} custom-call(%[[Arg_0_1]], %[[Arg_1_2]]), custom_call_target="xla_builder.tridiagonal.TridiagonalSolver", called_computations={%[[$xla_builder_tridiagonal_TridiagonalSolver_174]]} + +HloModule tridiagonal_tridiagonal_solve, entry_computation_layout={(f32[3,3],f32[3,3])->f32[3,3]} + +ENTRY %main.3 (Arg_0.1: f32[3,3], Arg_1.2: f32[3,3]) -> f32[3,3] { + %Arg_0.1 = f32[3,3] parameter(0) + %Arg_1.2 = f32[3,3] parameter(1) + ROOT %custom-call.2 = f32[3,3] custom-call(%Arg_0.1, %Arg_1.2), custom_call_target="xla_builder.tridiagonal.TridiagonalSolver" +} + +// ----- + +// CHECK-LABEL: HloModule tridiagonal_tridiagonal_solve_all_args, entry_computation_layout={(f32[1,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{1,0})->f32[3,3]{1,0}} + +// CHECK: %[[$preparation_body_27:[^ ]+]] +// CHECK-NEXT: %[[parameter_28:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_29:[^ ]+]] = s32[] get-tuple-element(%[[parameter_28]]), index=0 +// CHECK-NEXT: %[[constant_32:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_33:[^ ]+]] = s32[] add(%[[get_tuple_element_29]], %[[constant_32]]) +// CHECK-NEXT: 
%[[get_tuple_element_30:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_28]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_31:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_28]]), index=2 +// CHECK-NEXT: %[[constant_34:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_35:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_31]], %[[constant_34]], %[[get_tuple_element_29]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[constant_36:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_37:[^ ]+]] = f32[1,3]{1,0} dynamic-update-slice(%[[get_tuple_element_30]], %[[dynamic_slice_35]], %[[constant_36]], %[[get_tuple_element_29]]) +// CHECK-NEXT: ROOT %[[tuple_38:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) tuple(%[[add_33]], %[[dynamic_update_slice_37]], %[[get_tuple_element_31]]) + +// CHECK: %[[$preparation_condition_39:[^ ]+]] +// CHECK-NEXT: %[[parameter_40:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_42:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_40]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_43:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_40]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_41:[^ ]+]] = s32[] get-tuple-element(%[[parameter_40]]), index=0 +// CHECK-NEXT: %[[constant_44:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: ROOT %[[compare_45:[^ ]+]] = pred[] compare(%[[get_tuple_element_41]], %[[constant_44]]), direction=LT + +// CHECK: %[[$forward_transformation_body_52:[^ ]+]] +// CHECK-NEXT: %[[parameter_53:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, f32[3,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_54:[^ ]+]] = s32[] get-tuple-element(%[[parameter_53]]), index=0 +// CHECK-NEXT: %[[constant_61:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_62:[^ ]+]] = s32[] add(%[[get_tuple_element_54]], %[[constant_61]]) +// CHECK-NEXT: %[[get_tuple_element_55:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_53]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_56:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_53]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_57:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_53]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_58:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_53]]), index=4 +// CHECK-NEXT: %[[constant_67:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_63:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_64:[^ ]+]] = s32[] add(%[[get_tuple_element_54]], %[[constant_63]]) +// CHECK-NEXT: %[[dynamic_slice_68:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_56]], %[[constant_67]], %[[add_64]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[constant_65:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_66:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_55]], %[[constant_65]], %[[add_64]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[constant_72:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[subtract_71:[^ ]+]] = s32[] subtract(%[[add_64]], %[[constant_63]]) +// CHECK-NEXT: %[[dynamic_slice_73:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_58]], %[[constant_72]], %[[subtract_71]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[divide_74:[^ ]+]] = f32[1,1]{1,0} divide(%[[dynamic_slice_66]], %[[dynamic_slice_73]]) +// CHECK-NEXT: %[[get_tuple_element_59:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_53]]), index=5 +// CHECK-NEXT: %[[constant_76:[^ ]+]] = s32[] 
constant(0) +// CHECK-NEXT: %[[subtract_75:[^ ]+]] = s32[] subtract(%[[add_64]], %[[constant_63]]) +// CHECK-NEXT: %[[dynamic_slice_77:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_59]], %[[constant_76]], %[[subtract_75]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[multiply_78:[^ ]+]] = f32[1,1]{1,0} multiply(%[[divide_74]], %[[dynamic_slice_77]]) +// CHECK-NEXT: %[[subtract_79:[^ ]+]] = f32[1,1]{1,0} subtract(%[[dynamic_slice_68]], %[[multiply_78]]) +// CHECK-NEXT: %[[constant_80:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_81:[^ ]+]] = f32[1,3]{1,0} dynamic-update-slice(%[[get_tuple_element_58]], %[[subtract_79]], %[[constant_80]], %[[add_64]]) +// CHECK-NEXT: %[[get_tuple_element_60:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_53]]), index=6 +// CHECK-NEXT: %[[constant_69:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_70:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_57]], %[[constant_69]], %[[add_64]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[reshape_85:[^ ]+]] = f32[1]{0} reshape(%[[divide_74]]) +// CHECK-NEXT: %[[broadcast_86:[^ ]+]] = f32[3,1]{1,0} broadcast(%[[reshape_85]]), dimensions={1} +// CHECK-NEXT: %[[constant_83:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[subtract_82:[^ ]+]] = s32[] subtract(%[[add_64]], %[[constant_63]]) +// CHECK-NEXT: %[[dynamic_slice_84:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_60]], %[[constant_83]], %[[subtract_82]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[multiply_87:[^ ]+]] = f32[3,1]{1,0} multiply(%[[broadcast_86]], %[[dynamic_slice_84]]) +// CHECK-NEXT: %[[subtract_88:[^ ]+]] = f32[3,1]{1,0} subtract(%[[dynamic_slice_70]], %[[multiply_87]]) +// CHECK-NEXT: %[[constant_89:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_90:[^ ]+]] = f32[3,3]{1,0} dynamic-update-slice(%[[get_tuple_element_60]], %[[subtract_88]], %[[constant_89]], %[[add_64]]) +// CHECK-NEXT: ROOT %[[tuple_91:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, f32[3,3]{1,0}) tuple(%[[add_62]], %[[get_tuple_element_55]], %[[get_tuple_element_56]], %[[get_tuple_element_57]], %[[dynamic_update_slice_81]], /*index=5*/%[[get_tuple_element_59]], %[[dynamic_update_slice_90]]) + +// CHECK: %[[$forward_transformation_condition_92:[^ ]+]] +// CHECK-NEXT: %[[parameter_93:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, f32[3,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_95:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_93]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_96:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_93]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_97:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_93]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_98:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_93]]), index=4 +// CHECK-NEXT: %[[get_tuple_element_99:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_93]]), index=5 +// CHECK-NEXT: %[[get_tuple_element_100:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_93]]), index=6 +// CHECK-NEXT: %[[get_tuple_element_94:[^ ]+]] = s32[] get-tuple-element(%[[parameter_93]]), index=0 +// CHECK-NEXT: %[[constant_101:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: ROOT %[[compare_102:[^ ]+]] = pred[] compare(%[[get_tuple_element_94]], %[[constant_101]]), direction=LT + +// CHECK: %[[$backward_reduction_body_125:[^ ]+]] +// CHECK-NEXT: %[[parameter_126:[^ ]+]] = 
(s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_127:[^ ]+]] = s32[] get-tuple-element(%[[parameter_126]]), index=0 +// CHECK-NEXT: %[[constant_132:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_133:[^ ]+]] = s32[] add(%[[get_tuple_element_127]], %[[constant_132]]) +// CHECK-NEXT: %[[get_tuple_element_128:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_126]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_129:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_126]]), index=2 +// CHECK-NEXT: %[[constant_137:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_134:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[subtract_136:[^ ]+]] = s32[] subtract(%[[constant_134]], %[[get_tuple_element_127]]) +// CHECK-NEXT: %[[dynamic_slice_138:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_129]], %[[constant_137]], %[[subtract_136]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[get_tuple_element_130:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_126]]), index=3 +// CHECK-NEXT: %[[constant_139:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_140:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_130]], %[[constant_139]], %[[subtract_136]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[reshape_144:[^ ]+]] = f32[1]{0} reshape(%[[dynamic_slice_140]]) +// CHECK-NEXT: %[[broadcast_145:[^ ]+]] = f32[3,1]{1,0} broadcast(%[[reshape_144]]), dimensions={1} +// CHECK-NEXT: %[[constant_142:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_135:[^ ]+]] = s32[] constant(1) +// CHECK-NEXT: %[[add_141:[^ ]+]] = s32[] add(%[[subtract_136]], %[[constant_135]]) +// CHECK-NEXT: %[[dynamic_slice_143:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_128]], %[[constant_142]], %[[add_141]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[multiply_146:[^ ]+]] = f32[3,1]{1,0} multiply(%[[broadcast_145]], %[[dynamic_slice_143]]) +// CHECK-NEXT: %[[subtract_147:[^ ]+]] = f32[3,1]{1,0} subtract(%[[dynamic_slice_138]], %[[multiply_146]]) +// CHECK-NEXT: %[[get_tuple_element_131:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_126]]), index=4 +// CHECK-NEXT: %[[constant_148:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_149:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_131]], %[[constant_148]], %[[subtract_136]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[reshape_150:[^ ]+]] = f32[1]{0} reshape(%[[dynamic_slice_149]]) +// CHECK-NEXT: %[[broadcast_151:[^ ]+]] = f32[3,1]{1,0} broadcast(%[[reshape_150]]), dimensions={1} +// CHECK-NEXT: %[[divide_152:[^ ]+]] = f32[3,1]{1,0} divide(%[[subtract_147]], %[[broadcast_151]]) +// CHECK-NEXT: %[[constant_153:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_154:[^ ]+]] = f32[3,3]{1,0} dynamic-update-slice(%[[get_tuple_element_128]], %[[divide_152]], %[[constant_153]], %[[subtract_136]]) +// CHECK-NEXT: ROOT %[[tuple_155:[^ ]+]] = (s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) tuple(%[[add_133]], %[[dynamic_update_slice_154]], %[[get_tuple_element_129]], %[[get_tuple_element_130]], %[[get_tuple_element_131]]) + +// CHECK: %[[$backward_reduction_condition_156:[^ ]+]] +// CHECK-NEXT: %[[parameter_157:[^ ]+]] = (s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) parameter(0) +// CHECK-NEXT: %[[get_tuple_element_159:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[parameter_157]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_160:[^ ]+]] = f32[3,3]{1,0} 
get-tuple-element(%[[parameter_157]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_161:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_157]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_162:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[parameter_157]]), index=4 +// CHECK-NEXT: %[[get_tuple_element_158:[^ ]+]] = s32[] get-tuple-element(%[[parameter_157]]), index=0 +// CHECK-NEXT: %[[constant_163:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: ROOT %[[compare_164:[^ ]+]] = pred[] compare(%[[get_tuple_element_158]], %[[constant_163]]), direction=LT + +// CHECK: %[[$xla_builder_tridiagonal_TridiagonalSolver_171:[^ ]+]] +// CHECK-NEXT: %[[constant_25:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_9:[^ ]+]] = f32[] constant(0) +// CHECK-NEXT: %[[broadcast_10:[^ ]+]] = f32[1,3]{1,0} broadcast(%[[constant_9]]), dimensions={} +// CHECK-NEXT: %[[arg2_3:[^ ]+]] = f32[1,3]{1,0} parameter(2) +// CHECK-NEXT: %[[tuple_26:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) tuple(%[[constant_25]], %[[broadcast_10]], %[[arg2_3]]) +// CHECK-NEXT: %[[while_46:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}) while(%[[tuple_26]]), condition=%[[$preparation_condition_39]], body=%[[$preparation_body_27]] +// CHECK-NEXT: %[[get_tuple_element_47:[^ ]+]] = s32[] get-tuple-element(%[[while_46]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_49:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_46]]), index=2 +// CHECK-NEXT: %[[constant_50:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[arg0_1:[^ ]+]] = f32[1,3]{1,0} parameter(0) +// CHECK-NEXT: %[[arg1_2:[^ ]+]] = f32[1,3]{1,0} parameter(1) +// CHECK-NEXT: %[[arg3_4:[^ ]+]] = f32[3,3]{1,0} parameter(3) +// CHECK-NEXT: %[[constant_5:[^ ]+]] = f32[] constant(0) +// CHECK-NEXT: %[[broadcast_6:[^ ]+]] = f32[1,3]{1,0} broadcast(%[[constant_5]]), dimensions={} +// CHECK-NEXT: %[[constant_14:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_13:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_15:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[arg1_2]], %[[constant_14]], %[[constant_13]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[constant_17:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_16:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_18:[^ ]+]] = f32[1,3]{1,0} dynamic-update-slice(%[[broadcast_6]], %[[dynamic_slice_15]], %[[constant_17]], %[[constant_16]]) +// CHECK-NEXT: %[[get_tuple_element_48:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_46]]), index=1 +// CHECK-NEXT: %[[constant_7:[^ ]+]] = f32[] constant(0) +// CHECK-NEXT: %[[broadcast_8:[^ ]+]] = f32[3,3]{1,0} broadcast(%[[constant_7]]), dimensions={} +// CHECK-NEXT: %[[constant_20:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_19:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_slice_21:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[arg3_4]], %[[constant_20]], %[[constant_19]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[constant_23:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_22:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[dynamic_update_slice_24:[^ ]+]] = f32[3,3]{1,0} dynamic-update-slice(%[[broadcast_8]], %[[dynamic_slice_21]], %[[constant_23]], %[[constant_22]]) +// CHECK-NEXT: %[[tuple_51:[^ ]+]] = (s32[], f32[1,3]{1,0}, f32[1,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, f32[3,3]{1,0}) tuple(%[[constant_50]], %[[arg0_1]], %[[arg1_2]], %[[arg3_4]], %[[dynamic_update_slice_18]], /*index=5*/%[[get_tuple_element_48]], %[[dynamic_update_slice_24]]) +// CHECK-NEXT: %[[while_103:[^ ]+]] = (s32[], f32[1,3]{1,0}, 
f32[1,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, /*index=5*/f32[1,3]{1,0}, f32[3,3]{1,0}) while(%[[tuple_51]]), condition=%[[$forward_transformation_condition_92]], body=%[[$forward_transformation_body_52]] +// CHECK-NEXT: %[[get_tuple_element_104:[^ ]+]] = s32[] get-tuple-element(%[[while_103]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_105:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_103]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_106:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_103]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_107:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[while_103]]), index=3 +// CHECK-NEXT: %[[constant_123:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_11:[^ ]+]] = f32[] constant(0) +// CHECK-NEXT: %[[broadcast_12:[^ ]+]] = f32[3,3]{1,0} broadcast(%[[constant_11]]), dimensions={} +// CHECK-NEXT: %[[get_tuple_element_110:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[while_103]]), index=6 +// CHECK-NEXT: %[[constant_112:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_111:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: %[[dynamic_slice_113:[^ ]+]] = f32[3,1]{1,0} dynamic-slice(%[[get_tuple_element_110]], %[[constant_112]], %[[constant_111]]), dynamic_slice_sizes={3,1} +// CHECK-NEXT: %[[get_tuple_element_108:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_103]]), index=4 +// CHECK-NEXT: %[[constant_115:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_114:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: %[[dynamic_slice_116:[^ ]+]] = f32[1,1]{1,0} dynamic-slice(%[[get_tuple_element_108]], %[[constant_115]], %[[constant_114]]), dynamic_slice_sizes={1,1} +// CHECK-NEXT: %[[reshape_117:[^ ]+]] = f32[1]{0} reshape(%[[dynamic_slice_116]]) +// CHECK-NEXT: %[[broadcast_118:[^ ]+]] = f32[3,1]{1,0} broadcast(%[[reshape_117]]), dimensions={1} +// CHECK-NEXT: %[[divide_119:[^ ]+]] = f32[3,1]{1,0} divide(%[[dynamic_slice_113]], %[[broadcast_118]]) +// CHECK-NEXT: %[[constant_121:[^ ]+]] = s32[] constant(0) +// CHECK-NEXT: %[[constant_120:[^ ]+]] = s32[] constant(2) +// CHECK-NEXT: %[[dynamic_update_slice_122:[^ ]+]] = f32[3,3]{1,0} dynamic-update-slice(%[[broadcast_12]], %[[divide_119]], %[[constant_121]], %[[constant_120]]) +// CHECK-NEXT: %[[get_tuple_element_109:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_103]]), index=5 +// CHECK-NEXT: %[[tuple_124:[^ ]+]] = (s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) tuple(%[[constant_123]], %[[dynamic_update_slice_122]], %[[get_tuple_element_110]], %[[get_tuple_element_109]], %[[get_tuple_element_108]]) +// CHECK-NEXT: %[[while_165:[^ ]+]] = (s32[], f32[3,3]{1,0}, f32[3,3]{1,0}, f32[1,3]{1,0}, f32[1,3]{1,0}) while(%[[tuple_124]]), condition=%[[$backward_reduction_condition_156]], body=%[[$backward_reduction_body_125]] +// CHECK-NEXT: %[[get_tuple_element_166:[^ ]+]] = s32[] get-tuple-element(%[[while_165]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_167:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[while_165]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_168:[^ ]+]] = f32[3,3]{1,0} get-tuple-element(%[[while_165]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_169:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_165]]), index=3 +// CHECK-NEXT: ROOT %[[get_tuple_element_170:[^ ]+]] = f32[1,3]{1,0} get-tuple-element(%[[while_165]]), index=4 + +// CHECK: ENTRY %[[$main_3:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = f32[1,3]{1,0} parameter(0) +// CHECK-NEXT: %[[Arg_1_2:[^ ]+]] = f32[1,3]{1,0} parameter(1) +// CHECK-NEXT: %[[Arg_2_3:[^ ]+]] = f32[1,3]{1,0} parameter(2) +// 
CHECK-NEXT: %[[Arg_3_4:[^ ]+]] = f32[3,3]{1,0} parameter(3) +// CHECK-NEXT: %[[custom_call_2:[^ ]+]] = f32[3,3]{1,0} custom-call(%[[Arg_0_1]], %[[Arg_1_2]], %[[Arg_2_3]], %[[Arg_3_4]]), custom_call_target="xla_builder.tridiagonal.TridiagonalSolver" +// CHECK-NEXT: ROOT %[[custom_call:[^ ]+]] = f32[3,3]{1,0} custom-call(%[[Arg_0_1]], %[[Arg_1_2]], %[[Arg_2_3]], %[[Arg_3_4]]), custom_call_target="xla_builder.tridiagonal.TridiagonalSolver", called_computations={%[[$xla_builder_tridiagonal_TridiagonalSolver_171]]} + +HloModule tridiagonal_tridiagonal_solve_all_args, entry_computation_layout={(f32[1,3],f32[1,3],f32[1,3],f32[3,3])->f32[3,3]} + +ENTRY %main.3 (Arg_0.1: f32[1,3], Arg_1.2: f32[1,3], Arg_2.3: f32[1,3], Arg_3.4: f32[3,3]) -> f32[3,3] { + %Arg_0.1 = f32[1,3] parameter(0) + %Arg_1.2 = f32[1,3] parameter(1) + %Arg_2.3 = f32[1,3] parameter(2) + %Arg_3.4 = f32[3,3] parameter(3) + ROOT %custom-call.2 = f32[3,3] custom-call(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4), custom_call_target="xla_builder.tridiagonal.TridiagonalSolver" +} \ No newline at end of file diff --git a/third_party/xla/xla/hlo/tools/tests/BUILD b/third_party/xla/xla/hlo/tools/tests/BUILD index 814872908f243d..7d101c7dbbda9d 100644 --- a/third_party/xla/xla/hlo/tools/tests/BUILD +++ b/third_party/xla/xla/hlo/tools/tests/BUILD @@ -89,6 +89,7 @@ cc_library( "//xla/hlo/builder/lib:math", "//xla/hlo/builder/lib:matrix", "//xla/hlo/builder/lib:prng", + "//xla/hlo/builder/lib:tridiagonal", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/simplifiers:algebraic_simplifier", @@ -101,6 +102,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc b/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc index cf93d4ff4cb117..1379a912e72edd 100644 --- a/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc +++ b/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc @@ -15,18 +15,22 @@ limitations under the License. 
#include "xla/hlo/tools/tests/hlo_opt_test_only_passes.h" +#include #include #include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "xla/hlo/builder/lib/math.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/lib/prng.h" +#include "xla/hlo/builder/lib/tridiagonal.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_clone_context.h" @@ -71,17 +75,28 @@ std::vector GetParameters(XlaBuilder& builder, return parameters; } -absl::Status VerifyOperandCount(HloInstruction* instruction, - int64_t expected_operand_count, - absl::string_view custom_call_target) { - if (instruction->operand_count() != expected_operand_count) { +absl::Status VerifyOperandCounts( + HloInstruction* instruction, + const std::vector& expected_operand_counts, + absl::string_view custom_call_target) { + if (std::find(expected_operand_counts.begin(), expected_operand_counts.end(), + instruction->operand_count()) == + expected_operand_counts.end()) { return absl::InvalidArgumentError(absl::StrCat( - custom_call_target, " expected ", expected_operand_count, - " operands, but got ", instruction->operand_count(), " operands.")); + custom_call_target, " expected ", + absl::StrJoin(expected_operand_counts, " or "), " operands, but got ", + instruction->operand_count(), " operands.")); } return absl::OkStatus(); } +absl::Status VerifyOperandCount(HloInstruction* instruction, + int64_t expected_operand_count, + absl::string_view custom_call_target) { + return VerifyOperandCounts(instruction, {expected_operand_count}, + custom_call_target); +} + absl::StatusOr BuildAndReplace(XlaBuilder& builder, HloInstruction* instruction) { HloComputation* computation = instruction->parent(); @@ -183,6 +198,24 @@ absl::StatusOr XlaBuilderTestPass::ReplaceWithExpandedClientHlo( return BuildAndReplace(builder, instruction); } + // xla_builder.tridiagonal + if (custom_call_target == "xla_builder.tridiagonal.TridiagonalSolver") { + TF_RETURN_IF_ERROR( + VerifyOperandCounts(instruction, {2, 4}, custom_call_target)); + if (parameters.size() == 2) { + TF_ASSIGN_OR_RETURN( + std::ignore, xla::tridiagonal::TridiagonalSolver( + tridiagonal::SolverAlgorithm::kThomas, parameters[0], + parameters[1])); + return BuildAndReplace(builder, instruction); + } + TF_ASSIGN_OR_RETURN( + std::ignore, xla::tridiagonal::TridiagonalSolver( + tridiagonal::SolverAlgorithm::kThomas, parameters[0], + parameters[1], parameters[2], parameters[3])); + return BuildAndReplace(builder, instruction); + } + return absl::InvalidArgumentError(absl::StrCat( "Unsupported xla_builder custom call target: ", custom_call_target)); } From dc5126ed55b6a02769c4b8f7eb016d88be742589 Mon Sep 17 00:00:00 2001 From: Dimitris Vardoulakis Date: Tue, 1 Apr 2025 12:52:48 -0700 Subject: [PATCH 0112/1324] PR #24456: Add support for CUDA 13 (only when available locally) Imported from GitHub PR https://github.com/openxla/xla/pull/24456 Copybara import of the project: -- d152d725f2cbbe3bdd1df17a7edcc7da620ad703 by Dimitris Vardoulakis : Add support for CUDA 13 (only when available locally) Merging this change closes #24456 PiperOrigin-RevId: 742806633 --- third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl | 8 ++++++++ .../gpus/cuda/hermetic/cuda_redist_versions.bzl | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git 
From dc5126ed55b6a02769c4b8f7eb016d88be742589 Mon Sep 17 00:00:00 2001
From: Dimitris Vardoulakis
Date: Tue, 1 Apr 2025 12:52:48 -0700
Subject: [PATCH 0112/1324] PR #24456: Add support for CUDA 13 (only when
 available locally)

Imported from GitHub PR https://github.com/openxla/xla/pull/24456

Copybara import of the project:

--
d152d725f2cbbe3bdd1df17a7edcc7da620ad703 by Dimitris Vardoulakis :

Add support for CUDA 13 (only when available locally)

Merging this change closes #24456

PiperOrigin-RevId: 742806633
---
 third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl | 8 ++++++++
 .../gpus/cuda/hermetic/cuda_redist_versions.bzl         | 8 ++++++++
 2 files changed, 16 insertions(+)

diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl
index bf26061ee86139..e0543e8cc8e433 100644
--- a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl
+++ b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl
@@ -244,6 +244,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
     "cuda_cudart": {
         "repo_name": "cuda_cudart",
         "version_to_template": {
+            "13": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl",
             "12": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl",
             "11": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl",
         },
@@ -284,12 +285,14 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
     "libnvjitlink": {
         "repo_name": "cuda_nvjitlink",
         "version_to_template": {
+            "13": "//third_party/gpus/cuda/hermetic:cuda_nvjitlink.BUILD.tpl",
             "12": "//third_party/gpus/cuda/hermetic:cuda_nvjitlink.BUILD.tpl",
         },
     },
     "cuda_nvrtc": {
         "repo_name": "cuda_nvrtc",
         "version_to_template": {
+            "13": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl",
             "12": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl",
             "11": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl",
         },
@@ -297,6 +300,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
     "cuda_cccl": {
         "repo_name": "cuda_cccl",
         "version_to_template": {
+            "13": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl",
             "12": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl",
             "11": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl",
         },
@@ -304,6 +308,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
     "cuda_nvcc": {
         "repo_name": "cuda_nvcc",
         "version_to_template": {
+            "13": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl",
             "12": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl",
             "11": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl",
         },
@@ -311,6 +316,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
     "cuda_nvml_dev": {
         "repo_name": "cuda_nvml",
         "version_to_template": {
+            "13": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl",
             "12": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl",
             "11": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl",
         },
@@ -318,6 +324,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
     "cuda_nvprune": {
         "repo_name": "cuda_nvprune",
         "version_to_template": {
+            "13": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl",
             "12": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl",
             "11": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl",
         },
@@ -325,6 +332,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
     "cuda_nvtx": {
        "repo_name": "cuda_nvtx",
        "version_to_template": {
+            "13": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl",
            "12": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl",
            "11": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl",
        },
diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl
index bf26061ee86139..e0543e8cc8e433 100644
--- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl
+++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl
@@ -244,6 +244,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
     "cuda_cudart": {
         "repo_name": "cuda_cudart",
         "version_to_template": {
+            "13": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl",
             "12": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl",
             "11": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl",
         },
@@ -284,12 +285,14 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
     "libnvjitlink": {
         "repo_name": "cuda_nvjitlink",
"version_to_template": { + "13": "//third_party/gpus/cuda/hermetic:cuda_nvjitlink.BUILD.tpl", "12": "//third_party/gpus/cuda/hermetic:cuda_nvjitlink.BUILD.tpl", }, }, "cuda_nvrtc": { "repo_name": "cuda_nvrtc", "version_to_template": { + "13": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", "12": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", "11": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", }, @@ -297,6 +300,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = { "cuda_cccl": { "repo_name": "cuda_cccl", "version_to_template": { + "13": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", "12": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", "11": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", }, @@ -304,6 +308,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = { "cuda_nvcc": { "repo_name": "cuda_nvcc", "version_to_template": { + "13": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", "12": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", "11": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", }, @@ -311,6 +316,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = { "cuda_nvml_dev": { "repo_name": "cuda_nvml", "version_to_template": { + "13": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", "12": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", "11": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", }, @@ -318,6 +324,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = { "cuda_nvprune": { "repo_name": "cuda_nvprune", "version_to_template": { + "13": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", "12": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", "11": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", }, @@ -325,6 +332,7 @@ REDIST_VERSIONS_TO_BUILD_TEMPLATES = { "cuda_nvtx": { "repo_name": "cuda_nvtx", "version_to_template": { + "13": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", "12": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", "11": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", }, From 62898db8527be05a32166445c82c75cb35969fb6 Mon Sep 17 00:00:00 2001 From: Julia Guo Date: Tue, 1 Apr 2025 13:02:15 -0700 Subject: [PATCH 0113/1324] [XLA] Add num_repeats to nightly workflow runs to reduce noise. 
PiperOrigin-RevId: 742809909 --- .../xla/.github/workflows/benchmark_postsubmit.yml | 8 ++++---- third_party/xla/.github/workflows/benchmark_presubmit.yml | 2 +- .../xla/.github/workflows/cpu_benchmarks_nightly.yml | 6 +++--- .../xla/.github/workflows/gpu_benchmarks_nightly.yml | 4 ++++ 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/third_party/xla/.github/workflows/benchmark_postsubmit.yml b/third_party/xla/.github/workflows/benchmark_postsubmit.yml index a42368626378fb..cadb7f26538c49 100644 --- a/third_party/xla/.github/workflows/benchmark_postsubmit.yml +++ b/third_party/xla/.github/workflows/benchmark_postsubmit.yml @@ -140,10 +140,10 @@ jobs: pwd #print working directory if [[ "$platform" == "CPU" ]]; then - $binary_dir/multihost_hlo_runner/$runner_binary --device_type=host --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$xspace_file" "$test_hlo_file" > "$output_file" + $binary_dir/multihost_hlo_runner/$runner_binary --device_type=host --num_repeats=5 --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$xspace_file" "$test_hlo_file" > "$output_file" $binary_dir/compute_xspace_stats_main --input="$xspace_file" --device_type=CPU >> "$output_file" elif [[ "$platform" == "GPU" ]]; then - $binary_dir/multihost_hlo_runner/$runner_binary --device_type=gpu --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$xspace_file" "$test_hlo_file" > "$output_file" + $binary_dir/multihost_hlo_runner/$runner_binary --device_type=gpu --num_repeats=5 --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$xspace_file" "$test_hlo_file" > "$output_file" $binary_dir/compute_xspace_stats_main_gpu --input="$xspace_file" --device_type=GPU >> "$output_file" else echo "Unsupported platform: $platform" @@ -196,10 +196,10 @@ jobs: xspace_file="$GITHUB_WORKSPACE/xspace.pb" if [[ "$platform" == "CPU" ]]; then - $binary_dir/multihost_hlo_runner/$runner_binary --device_type=host --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$xspace_file" "$test_hlo_file" > "$output_file" + $binary_dir/multihost_hlo_runner/$runner_binary --device_type=host --num_repeats=5 --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$xspace_file" "$test_hlo_file" > "$output_file" $binary_dir/compute_xspace_stats_main --input="$xspace_file" --device_type=CPU >> "$output_file" elif [[ "$platform" == "GPU" ]]; then - $binary_dir/multihost_hlo_runner/$runner_binary --device_type=gpu --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$xspace_file" "$test_hlo_file" > "$output_file" + $binary_dir/multihost_hlo_runner/$runner_binary --device_type=gpu --num_repeats=5 --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$xspace_file" "$test_hlo_file" > "$output_file" $binary_dir/compute_xspace_stats_main_gpu --input="$xspace_file" --device_type=GPU >> "$output_file" else echo "Unsupported platform: $platform" diff --git a/third_party/xla/.github/workflows/benchmark_presubmit.yml b/third_party/xla/.github/workflows/benchmark_presubmit.yml index 0270695c824fab..516adbe9888fa9 100644 --- a/third_party/xla/.github/workflows/benchmark_presubmit.yml +++ b/third_party/xla/.github/workflows/benchmark_presubmit.yml @@ -129,6 +129,6 @@ jobs: echo "Running test with binary: $HLO_RUNNER_BINARY" pwd #print working directory - $HLO_RUNNER_BINARY --device_type=$DEVICE_TYPE --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$XSPACE_FILE_PATH" 
"$TEST_HLO_FILE" > "$OUTPUT_FILE_PATH" + $HLO_RUNNER_BINARY --device_type=$DEVICE_TYPE --use_spmd_partitioning --num_repeats=5 --profile_execution=True --xla_gpu_dump_xspace_to="$XSPACE_FILE_PATH" "$TEST_HLO_FILE" > "$OUTPUT_FILE_PATH" $COMPUTE_XSPACE_STATS_BINARY --input="$XSPACE_FILE_PATH" --device_type="${{ matrix.job_info.platform }}" >> "$OUTPUT_FILE_PATH" cat "$OUTPUT_FILE_PATH" diff --git a/third_party/xla/.github/workflows/cpu_benchmarks_nightly.yml b/third_party/xla/.github/workflows/cpu_benchmarks_nightly.yml index 6a941fe8b37553..0e2a4aa4a61476 100644 --- a/third_party/xla/.github/workflows/cpu_benchmarks_nightly.yml +++ b/third_party/xla/.github/workflows/cpu_benchmarks_nightly.yml @@ -140,7 +140,7 @@ jobs: test_hlo_file="xla/tools/hlo_opt/tests/cpu_hlo.hlo" echo "Running CPU test with binary: $binary_dir" pwd #print working directory - $binary_dir/multihost_hlo_runner/hlo_runner_main --device_type=host --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$XSPACE_FILE_PATH" "$test_hlo_file" > "$OUTPUT_FILE_PATH" + $binary_dir/multihost_hlo_runner/hlo_runner_main --device_type=host --num_repeats=5 --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$XSPACE_FILE_PATH" "$test_hlo_file" > "$OUTPUT_FILE_PATH" $binary_dir/compute_xspace_stats_main --input="$XSPACE_FILE_PATH" --device_type=CPU >> "$OUTPUT_FILE_PATH" cat "$OUTPUT_FILE_PATH" @@ -153,7 +153,7 @@ jobs: test_hlo_file="$OUTPUT_DIR/tmp_hlo/gemma2_2b_keras_jax.hlo" echo "Running CPU test with binary: $binary_dir" pwd #print working directory - $binary_dir/multihost_hlo_runner/hlo_runner_main --device_type=host --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$XSPACE_FILE_PATH" "$test_hlo_file" > "$OUTPUT_FILE_PATH" + $binary_dir/multihost_hlo_runner/hlo_runner_main --device_type=host --num_repeats=5 --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$XSPACE_FILE_PATH" "$test_hlo_file" > "$OUTPUT_FILE_PATH" $binary_dir/compute_xspace_stats_main --input="$XSPACE_FILE_PATH" --device_type=CPU >> "$OUTPUT_FILE_PATH" cat "$OUTPUT_FILE_PATH" @@ -166,7 +166,7 @@ jobs: test_hlo_file="$OUTPUT_DIR/tmp_hlo/gemma3_1b_flax_call.hlo" echo "Running CPU test with binary: $binary_dir" pwd #print working directory - $binary_dir/multihost_hlo_runner/hlo_runner_main --device_type=host --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$XSPACE_FILE_PATH" "$test_hlo_file" > "$OUTPUT_FILE_PATH" + $binary_dir/multihost_hlo_runner/hlo_runner_main --device_type=host --num_repeats=5 --use_spmd_partitioning --profile_execution=True --xla_gpu_dump_xspace_to="$XSPACE_FILE_PATH" "$test_hlo_file" > "$OUTPUT_FILE_PATH" $binary_dir/compute_xspace_stats_main --input="$XSPACE_FILE_PATH" --device_type=CPU >> "$OUTPUT_FILE_PATH" cat "$OUTPUT_FILE_PATH" diff --git a/third_party/xla/.github/workflows/gpu_benchmarks_nightly.yml b/third_party/xla/.github/workflows/gpu_benchmarks_nightly.yml index 9b4db985e1bf64..95c547ee7e5749 100644 --- a/third_party/xla/.github/workflows/gpu_benchmarks_nightly.yml +++ b/third_party/xla/.github/workflows/gpu_benchmarks_nightly.yml @@ -86,6 +86,7 @@ jobs: echo "Running GPU test: $HLO_FILE_GB" $binary_dir/multihost_hlo_runner/hlo_runner_main_gpu \ --device_type=gpu \ + --num_repeats=5 \ --use_spmd_partitioning \ --profile_execution=True \ --xla_gpu_dump_xspace_to="${OUTPUT_PREFIX_GB}_xspace.pb" \ @@ -105,6 +106,7 @@ jobs: echo "Running GPU test: $HLO_FILE_GEMMA" $binary_dir/multihost_hlo_runner/hlo_runner_main_gpu \ 
            --device_type=gpu \
+           --num_repeats=5 \
            --use_spmd_partitioning \
            --profile_execution=True \
            --xla_gpu_dump_xspace_to="${OUTPUT_PREFIX_GEMMA}_xspace.pb" \
@@ -126,6 +128,7 @@ jobs:
          echo "Running GPU test: $HLO_FILE_GEMMA3_CALL"
          $binary_dir/multihost_hlo_runner/hlo_runner_main_gpu \
            --device_type=gpu \
+           --num_repeats=5 \
            --use_spmd_partitioning \
            --profile_execution=True \
            --xla_gpu_dump_xspace_to="${OUTPUT_PREFIX_GEMMA3_CALL}_xspace.pb" \
@@ -147,6 +150,7 @@ jobs:
          echo "Running GPU test: $HLO_FILE_GEMMA3_SAMPLE_LOOP"
          $binary_dir/multihost_hlo_runner/hlo_runner_main_gpu \
            --device_type=gpu \
+           --num_repeats=5 \
            --use_spmd_partitioning \
            --profile_execution=True \
            --xla_gpu_dump_xspace_to="${OUTPUT_PREFIX_GEMMA3_SAMPLE_LOOP}_xspace.pb" \

From fdcb6dbbb2507346afda90ba82ec72c74263c7db Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 1 Apr 2025 13:04:44 -0700
Subject: [PATCH 0114/1324] Reverts ea4cacead963a3217d02ef2668353beeef3f6ff6

PiperOrigin-RevId: 742810754
---
 .../compiler/mlir/lite/tests/prepare-tf.mlir  |   14 +
 .../tests/legalize-tf-BatchMatMulV2.mlir      |   89 ++
 .../tests/legalize-tf-binary-elementwise.mlir |   87 ++
 .../legalize-tf-with-tf2xla-hlo-importer.mlir |    9 +
 .../mlir/tf2xla/tests/legalize-tf.mlir        | 1058 +++++++++++++++++
 .../mlir/tf2xla/transforms/legalize_tf.cc     |  361 ++++++
 .../tf2xla/transforms/legalize_tf_patterns.td |   37 +
 7 files changed, 1655 insertions(+)
 create mode 100644 tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir

diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 974fbc2ab7c788..f6e8d6610aba74 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -521,6 +521,20 @@ func.func @lower_rfft_to_rfft2d(%input: tensor<10x20x30xf32>, %fft_len: tensor<1
 // CHECK: %[[SQE:.*]] = "tf.Squeeze"(%[[RFF]]) <{squeeze_dims = [-2]}> : (tensor<10x20x1x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
 }
 
+// CHECK-LABEL: xla_gather_to_strided_slice
+func.func @xla_gather_to_strided_slice(%arg0 : tensor<1x9x104x768xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<[1, 9, 23, 768]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %2 = "tf.XlaGather"(%arg0, %0, %1) {device = "", dimension_numbers = "\0A\04\00\01\02\03\1A\01\02", indices_are_sorted = false} : (tensor<1x9x104x768xf32>, tensor<1xi32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
+  func.return %2 : tensor<?x?x?x?xf32>
+
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : tensor<4xi64>
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<[1, 9, 23, 768]> : tensor<4xi64>
+// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : tensor<4xi64>
+// CHECK: %[[V0:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x9x104x768xf32>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+// CHECK: return %[[V0]] : tensor<?x?x?x?xf32>
+}
+
 // CHECK-LABEL: DontMatchFusedBatchNormV3
 func.func @DontMatchFusedBatchNormV3(%arg0 :tensor<?x576xf32>, %arg1 : tensor<576xf32>, %arg2 : tensor<576xf32>, %arg3 : tensor<576xf32>,%arg4 : tensor<576xf32>) -> (tensor<?x576xf32>) {
   %result:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", device = "", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor<?x576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>,
tensor<576xf32>) -> (tensor<?x576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<*xf32>)

diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir
new file mode 100644
index 00000000000000..f62e9a140e83d9
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir
@@ -0,0 +1,89 @@
+// RUN: tf-opt -xla-legalize-tf %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// tf.BatchMatMulV2 op legalizations.
+//===----------------------------------------------------------------------===//
+
+func.func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
+// CHECK-LABEL: func @batchmatmulv2_basic
+// CHECK-SAME: ([[LHS:%.*]]: tensor<1x4x2xf32>, [[RHS:%.*]]: tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
+// CHECK: [[LHSSHAPE:%.*]] = shape.shape_of [[LHS]] : tensor<1x4x2xf32>
+// CHECK: [[RHSSHAPE:%.*]] = shape.shape_of [[RHS]] : tensor<3x2x4xf32>
+// CHECK: [[CM2:%.*]] = arith.constant -2 : index
+// CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]])
+// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]])
+// CHECK: [[BCASTHEAD:%.*]] = shape.broadcast [[LHSHEAD]], [[RHSHEAD]]
+// CHECK: [[LHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[LHSTAIL]]
+// CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]]
+// CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
+// CHECK: [[RHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[RHSTAIL]]
+// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHS]])
+// CHECK: return [[RESULT]] : tensor<3x4x4xf32>
+// CHECK: }
+
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
+  func.return %0 : tensor<3x4x4xf32>
+}
+
+func.func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
+// CHECK-LABEL: func @batchmatmulv2_lhs_batch
+// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}>
+// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{
+// CHECK-SAME: lhs_batching_dimensions = [0]
+// CHECK-SAME: rhs_batching_dimensions = [0]
+// CHECK-SAME: lhs_contracting_dimensions = [2]
+// CHECK-SAME: rhs_contracting_dimensions = [1]
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<3x4x2xf32>, tensor<2x4xf32>) -> tensor<3x4x4xf32>
+  func.return %0 : tensor<3x4x4xf32>
+}
+
+func.func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
+// CHECK-LABEL: func @batchmatmulv2_rhs_batch
+// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}>
+// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{
+// CHECK-SAME: lhs_batching_dimensions = [0]
+// CHECK-SAME: rhs_batching_dimensions = [0]
+// CHECK-SAME: lhs_contracting_dimensions = [2]
+// CHECK-SAME: rhs_contracting_dimensions = [1]
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
+  func.return %0 :
tensor<3x4x4xf32>
+}
+
+func.func @batchmatmulv2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-LABEL: func @batchmatmulv2_dynamic
+// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{
+// CHECK-SAME: lhs_batching_dimensions = [0]
+// CHECK-SAME: rhs_batching_dimensions = [0]
+// CHECK-SAME: lhs_contracting_dimensions = [2]
+// CHECK-SAME: rhs_contracting_dimensions = [1]
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  func.return %0 : tensor<?x?x?xf32>
+}
+
+func.func @batchmatmulv2_adj_real(%arg0: tensor<2x5xf32>, %arg1: tensor<4x2xf32>) -> tensor<5x4xf32> {
+// CHECK-LABEL: func @batchmatmulv2_adj_real
+// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{
+// CHECK-NOT: lhs_batching_dimensions
+// CHECK-NOT: rhs_batching_dimensions
+// CHECK-SAME: lhs_contracting_dimensions = [0]
+// CHECK-SAME: rhs_contracting_dimensions = [1]
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xf32>, tensor<4x2xf32>) -> tensor<5x4xf32>
+  func.return %0 : tensor<5x4xf32>
+}
+
+func.func @batchmatmulv2_adj_complex(%arg0: tensor<2x5xcomplex<f32>>, %arg1: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
+// CHECK-LABEL: func @batchmatmulv2_adj_complex(
+// CHECK-SAME: [[LHS:%.*]]: tensor<2x5xcomplex<f32>>, [[RHS:%.*]]: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
+// CHECK: [[LHSRE:%.*]] = mhlo.real [[LHS]]
+// CHECK: [[LHSIM:%.*]] = mhlo.imag [[LHS]]
+// CHECK: [[LHSIMNEG:%.*]] = mhlo.negate [[LHSIM]]
+// CHECK: [[LHSCONJ:%.*]] = mhlo.complex [[LHSRE]], [[LHSIMNEG]]
+// CHECK: [[RHSRE:%.*]] = mhlo.real [[RHS]]
+// CHECK: [[RHSIM:%.*]] = mhlo.imag [[RHS]]
+// CHECK: [[RHSIMNEG:%.*]] = mhlo.negate [[RHSIM]]
+// CHECK: [[RHSCONJ:%.*]] = mhlo.complex [[RHSRE]], [[RHSIMNEG]]
+// CHECK: shape.shape_of [[LHSCONJ]]
+// CHECK: shape.shape_of [[RHSCONJ]]
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xcomplex<f32>>, tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
+  func.return %0 : tensor<5x4xcomplex<f32>>
+}

diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
index a1cfcb69f9c27e..da64452a3039f8 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
@@ -73,6 +73,13 @@ func.func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   func.return %0: tensor<2xi32>
 }
 
+// CHECK-LABEL: func @shift_left
+func.func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK: mhlo.shift_left %arg0, %arg1 : tensor<4xi32>
+  %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  func.return %0 : tensor<4xi32>
+}
+
 // CHECK-LABEL: func @div_unranked
 func.func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
   // CHECK-NEXT: tf.Div
@@ -87,6 +94,22 @@ func.func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
   func.return %0 : tensor<4xf32>
 }
 
+// CHECK-LABEL: func @minimum
+func.func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32>
+  %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  func.return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @mod
+// CHLO-LABEL: func @mod
+func.func @mod(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK-NEXT: mhlo.remainder %arg0,
%arg1 : tensor<4xf32> + // CHLO: chlo.broadcast_remainder + %0 = "tf.Mod"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + // CHECK-LABEL: func @mul func.func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> @@ -95,6 +118,13 @@ func.func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { func.return %0: tensor<2xi32> } +// CHECK-LABEL: func @real_div +func.func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> + %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + func.return %0: tensor<2xi32> +} + // CHECK-LABEL: func @sub func.func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32> @@ -103,6 +133,28 @@ func.func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { func.return %0: tensor<2xi32> } +// CHECK-LABEL: func @shift_right +func.func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @shift_right_unsigned +func.func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { + // CHECK: mhlo.shift_right_logical %arg0, %arg1 : tensor<4xui8> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> + func.return %0 : tensor<4xui8> +} + +// CHECK-LABEL: func @broadcast_shift_right_unsigned +func.func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { + // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<4xui8>) -> tensor<2x4xui8> + // CHECK: mhlo.shift_right_logical %[[BROADCAST]], %arg1 : tensor<2x4xui8> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> + func.return %0 : tensor<2x4xui8> +} + // CHECK-LABEL: func @and func.func @and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { // CHECK-NEXT: mhlo.and @@ -124,6 +176,20 @@ func.func @or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { func.return %0: tensor<2xi1> } +// CHECK-LABEL: func @bitwise_or +func.func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: mhlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @bitwise_or_unsigned +func.func @bitwise_or_unsigned(%arg0: tensor<4xui32>, %arg1: tensor<4xui32>) -> tensor<4xui32> { + // CHECK-NEXT: mhlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xui32>, tensor<4xui32>) -> tensor<4xui32> + func.return %0: tensor<4xui32> +} + // CHECK-LABEL: func @bitwise_xor func.func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK-NEXT: mhlo.xor @@ -138,6 +204,27 @@ func.func @bitwise_xor_unsigned(%arg0: tensor<4xui32>, %arg1: tensor<4xui32>) -> func.return %0: tensor<4xui32> } +// CHECK-LABEL: func @bitwise_and +func.func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: mhlo.and + %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @bitwise_and_unsigned +func.func 
@bitwise_and_unsigned(%arg0: tensor<4xui32>, %arg1: tensor<4xui32>) -> tensor<4xui32> {
+  // CHECK-NEXT: mhlo.and
+  %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xui32>, tensor<4xui32>) -> tensor<4xui32>
+  func.return %0: tensor<4xui32>
+}
+
+// CHECK-LABEL: func @pow
+func.func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK-NEXT: mhlo.power
+  %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  func.return %0: tensor<2xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Equality op legalizations.
 // tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir
index 9aa7a763b329bd..f1fb2fec85722c 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir
@@ -652,6 +652,15 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
     func.return %1 : tensor<64x128xf32>
   }
 
+  // CHECK-LABEL: func @tf_mod
+  func.func @tf_mod(%arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+    %cst = "tf.Const"() {value = dense<7.000000e+00> : tensor<f32>} : () -> tensor<f32>
+    // CHECK: "mhlo.dynamic_broadcast_in_dim"
+    // CHECK: mhlo.remainder
+    %6 = "tf.Mod"(%arg1, %cst) {_global_shape = [#tf_type.shape<4x8>], device = ""} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
+    return %6 : tensor<2x2xf32>
+  }
+
   // CHECK-LABEL: func @concat_v2
   func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
     // CHECK: "mhlo.concatenate"({{.*}}) <{dimension = 0 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir
index 4c6a0c52c36edb..92754a181e8551 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir
@@ -1,4 +1,11 @@
 // RUN: tf-opt "-xla-legalize-tf=legalize-chlo=false" -split-input-file %s | FILECHECK_OPTS="" FileCheck %s
+// RUN: tf-opt "-xla-legalize-tf=legalize-chlo=true" -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix CHLO
+// This test runs twice:
+// 1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled since verifying
+//    that the chlo ops emit produces more useful tests.
+// 2. With chlo legalization enabled, verifying diagnostics to pick up any
+//    issues with the full lowering (can catch some broadcasting corner
+//    cases which emit with a warning).
 
 //===----------------------------------------------------------------------===//
 // BatchNorm op legalizations.
@@ -6,6 +13,29 @@
 
 // -----
 
+// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same
+// code), so only do a couple of basic checks.
+
+// CHECK-LABEL: fusedBatchNormV2_noTraining
+func.func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  func.return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fusedBatchNormV2_training
+func.func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+  // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
+  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  // CHECK: mhlo.constant
+  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+  func.return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fusedBatchNormV3_noTraining
 func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
   // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
@@ -130,6 +160,139 @@ func.func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor
+
+// CHECK-LABEL: fusedBatchNormGrad_noTraining
+func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
+
+  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array<i64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+  // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32>
+
+  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
+  // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[init:.*]] =
mhlo.constant dense<-0.000000e+00> : tensor<f32>
+  // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+  // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32>
+
+  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
+  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>
+
+  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
+  // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
+  // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+  // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32>
+
+  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
+
+  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  func.return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fusedBatchNormGrad_Training
+func.func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
+  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
+
+  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  func.return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fusedBatchNormGradV2_noTraining
+func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
+  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
+
+  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4,
%[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> + + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor + // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> + + // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> + + // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> + + // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor + // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32> + + // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_Training +func.func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, 
tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision +func.func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + + // CHECK: %[[x_backprop:.*]] = mhlo.convert {{.*}} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision +func.func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + // CHECK-LABEL: fusedBatchNormGradV3_noTraining func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> @@ -568,6 +731,32 @@ func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> ten func.return %2: tensor<3x5x7x9x11x4x10xf32> } +//===----------------------------------------------------------------------===// +// Erf +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @erf +func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK: mhlo.erf %arg0 : tensor<2x3xf32> + %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + func.return %0 : tensor<2x3xf32> +} + +//===----------------------------------------------------------------------===// 
+// Erfc
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @erfc
+func.func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK: chlo.erfc %arg0 : tensor<2x3xf32>
+  %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  func.return %0 : tensor<2x3xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Einsum.
 //===----------------------------------------------------------------------===//
@@ -591,6 +780,242 @@ func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
   func.return %0: tensor<2x2xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// FloorDiv and FloorMod.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @floordiv_broadcast_i32
+func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
+  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
+  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
+  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
+  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[SUB]], [[DIV]]
+  // CHECK: return [[SELECT]]
+  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
+  func.return %0: tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
+func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]]
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
+  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
+  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
+  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[SUB]], [[DIV]]
+  // CHECK: return [[SELECT]]
+  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  func.return %0: tensor<2x3xi32>
+}
+
+// -----
+
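A note for readers tracing the integer FloorDiv patterns above: they build floor division (round toward negative infinity) out of MHLO's truncating divide. A minimal C++ sketch of the same correction, with names of our own choosing rather than anything from the converter source:

#include <cstdint>
#include <iostream>

// Truncating division is adjusted by one when the remainder is nonzero and
// the operands have opposite signs -- the NE/LT/LT/NE compare chain checked
// in @floordiv_broadcast_i32 above.
int32_t FloorDiv(int32_t x, int32_t y) {
  int32_t div = x / y;                     // [[DIV]]
  bool rem_nonzero = div * y != x;         // [[MUL]] and [[CMP1]] (NE)
  bool signs_differ = (x < 0) != (y < 0);  // [[CMP2]]/[[CMP3]] (LT), [[CMP4]] (NE)
  return (rem_nonzero && signs_differ) ? div - 1 : div;  // [[SUB]]/[[SELECT]]
}

int main() {
  std::cout << FloorDiv(7, 2) << " " << FloorDiv(-7, 2) << "\n";  // prints: 3 -4
}

The same structure recurs in the dynamic and typed variants that follow; only the broadcast dimensions and tensor shapes change.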
+// CHECK-LABEL: func @floordiv_f32
+func.func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK-NEXT: %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0
+  // CHECK-NEXT: %[[FLOOR:.*]] = mhlo.floor %[[DIV]]
+  // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32>
+  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  func.return %0: tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @floordiv_bf16
+func.func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
+  // CHECK-NEXT: mhlo.convert
+  // CHECK-NEXT: mhlo.convert
+  // CHECK-NEXT: chlo.broadcast_divide
+  // CHECK-NEXT: mhlo.floor
+  // CHECK-NEXT: mhlo.convert
+  // CHECK-NEXT: return
+  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16>
+  func.return %0: tensor<2xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @floordiv_f16_broadcast
+func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
+  // CHECK-NEXT: chlo.broadcast_divide
+  // CHECK-NEXT: mhlo.floor
+  // CHECK-NEXT: return
+  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
+  func.return %0: tensor<2x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @floordiv_dynamic
+func.func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
+  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
+  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
+  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
+  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[SUB]], [[DIV]]
+  // CHECK: return [[SELECT]]
+  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
+  func.return %0: tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @floordiv_unsigned
+func.func @floordiv_unsigned(%arg0: tensor<?x?xui32>, %arg1: tensor<?xui32>) -> tensor<?x?xui32> {
+  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK: return [[DIV]]
+  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xui32>, tensor<?xui32>) -> tensor<?x?xui32>
+  func.return %0: tensor<?x?xui32>
+}
+
+// -----
+
+// CHECK-LABEL: func @floordiv_int
+func.func @floordiv_int(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
+  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo<comparison_direction NE>} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1>
+  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> : tensor<i32>
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo<comparison_direction LT>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi1>
+  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> : tensor<i32>
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo<comparison_direction LT>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi1>
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
+  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> : tensor<i32>
+  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
+  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[SUB]], [[DIV]]
+  // CHECK: return [[SELECT]]
+  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+  func.return %0: tensor<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @floormod_broadcast_numerator
+func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array<i64: 0, 1>, comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
+  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]]
+  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]]
+  // CHECK-NEXT: return [[SELECT]]
+  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  func.return %0: tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @floormod_broadcast_denominator
+func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
+  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
+  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]]
+  // CHECK-NEXT: return [[SELECT]]
+  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
+  func.return %0: tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @floormod_unsigned_broadcast_denominator
+func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg1: tensor<3xui32>) -> tensor<2x3xui32> {
+  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK-NEXT: return [[REM]]
+  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xui32>, tensor<3xui32>) -> tensor<2x3xui32>
+  func.return %0: tensor<2x3xui32>
+}
+
+// -----
+
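The FloorMod patterns above follow the same scheme as FloorDiv: keep the truncating remainder unless its sign disagrees with the divisor's, in which case add the divisor once. A small C++ sketch of that correction (illustrative only; the helper name is ours):

#include <cstdint>
#include <iostream>

// Mirrors the CHECK chain: [[REM]] is the truncating remainder, [[CMP1]] (NE)
// tests rem != 0, [[CMP2]]/[[CMP3]] (LT) take the two signs, [[CMP4]] (NE)
// compares them, and [[ADD]]/[[SELECT]] apply the fixup. The result's sign
// follows the divisor, matching Python-style modulo.
int32_t FloorMod(int32_t x, int32_t y) {
  int32_t rem = x % y;
  bool rem_nonzero = rem != 0;
  bool signs_differ = (y < 0) != (rem < 0);
  return (rem_nonzero && signs_differ) ? rem + y : rem;
}

int main() {
  std::cout << FloorMod(-7, 3) << " " << FloorMod(7, -3) << "\n";  // prints: 2 -2
}

The unsigned variant above needs none of this, which is why it reduces to the bare remainder.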
+// CHECK-LABEL: func @floormod_dynamic_broadcast_numerator
+func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
+  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction LT>}
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction NE>}
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
+  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = array<i64: 1>}
+  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]]
+  // CHECK-NEXT: return [[SELECT]]
+  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
+  func.return %0: tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @floormod_dynamic_broadcast_denominator
+func.func @floormod_dynamic_broadcast_denominator_(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?xf16>) -> tensor<?x?x?xf16> {
+  // CHECK-NOT: tf.FloorMod
+  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array<i64: 1, 2>} : (tensor<?x?x?xf16>, tensor<?x?xf16>) -> tensor<?x?x?xf16>
+  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f16>
+  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo<comparison_direction NE>} : (tensor<?x?x?xf16>, tensor<f16>) -> tensor<?x?x?xi1>
+  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f16>
+  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo<comparison_direction LT>} : (tensor<?x?xf16>, tensor<f16>) -> tensor<?x?xi1>
+  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array<i64>, comparison_direction = #chlo<comparison_direction LT>} : (tensor<?x?x?xf16>, tensor<f16>) -> tensor<?x?x?xi1>
+  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo<comparison_direction NE>} : (tensor<?x?xi1>, tensor<?x?x?xi1>) -> tensor<?x?x?xi1>
+  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] : (tensor<?x?x?xi1>, tensor<?x?x?xi1>) -> tensor<?x?x?xi1>
+  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] : (tensor<?x?xf16>, tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+  // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]] : tensor<?x?x?xi1>, tensor<?x?x?xf16>
+  // CHECK-NEXT: return [[SELECT]] : tensor<?x?x?xf16>
+  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?x?xf16>, tensor<?x?xf16>) -> tensor<?x?x?xf16>
+  func.return %0: tensor<?x?x?xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// OnesLike
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @ones_like
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>)
+func.func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) <{value = 1.0{{.*}}}>
+  // CHECK: return %[[RES]]
+  %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
+  func.return %0 : tensor<2x?xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// ZerosLike
+//===----------------------------------------------------------------------===//
+@@ -629,6 +1054,15 @@ func.func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
 
 // -----
 
+// CHECK-LABEL: func @complex
+func.func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
+  // CHECK: chlo.broadcast_complex
+  %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
"tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> + func.return %1 : tensor<3xcomplex> +} + +// ----- + // CHECK-LABEL: func @imag func.func @imag(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { // CHECK: mhlo.imag @@ -689,6 +1123,63 @@ func.func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf3 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> func.return %1 : tensor<3x6xf32> } + +//===----------------------------------------------------------------------===// +// Pad op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @padv2_1D +func.func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor<6xf32> { + %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64> + // CHECK: "mhlo.pad"(%arg0, %arg1) <{ + // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>, + // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>, + // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64> + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor) -> tensor<6xf32> + func.return %1 : tensor<6xf32> +} + +// ----- + +// CHECK-LABEL: func @padv2_2D +func.func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { + %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64> + // CHECK: "mhlo.pad"(%arg0, %arg1) <{ + // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, + // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, + // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor) -> tensor<6x9xf32> + func.return %1 : tensor<6x9xf32> +} + +// ----- + +// CHECK-LABEL: func @padv2_i32_paddings +func.func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { + %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32> + // CHECK: "mhlo.pad"(%arg0, %arg1) <{ + // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, + // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, + // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor) -> tensor<6x9xf32> + func.return %1 : tensor<6x9xf32> +} + +// ----- + +// CHECK-LABEL: func @padv2_dynamic +func.func @padv2_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor<1x2xi64>) -> tensor { + // CHECK: "mhlo.transpose"({{.*}}) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<1x2xi64>) -> tensor<2x1xi64> + // CHECK: mhlo.reshape {{.*}} : (tensor<2x1xi64>) -> tensor<2xi64> + // CHECK: "mhlo.slice"({{.*}}) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: "mhlo.slice"({{.*}}) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: mhlo.dynamic_pad {{.*}} : (tensor, tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor, tensor<1x2xi64>, tensor) -> tensor + func.return %1 : tensor +} + //===----------------------------------------------------------------------===// // Identity op legalizations. 
 //===----------------------------------------------------------------------===//
 // Identity op legalizations.
 //===----------------------------------------------------------------------===//
@@ -704,6 +1195,15 @@ func.func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> {
 
 // -----
 
+// CHECK-LABEL: func @identityN
+func.func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) {
+  // CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32>
+  %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>)
+  func.return %0#0, %0#1: tensor<1xi32>, tensor<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @stopgradient
 func.func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
   // CHECK-NEXT: return %arg0 : tensor<1xi32>
@@ -915,6 +1415,98 @@ func.func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4
 
 // -----
 
+// CHECK-LABEL: maxpool_valid_padding
+// CHECK-SAME: %[[ARG:.*]]: tensor<2x12x20x7xi32>
+func.func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
+  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
+  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
+  // CHECK: <{window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}>
+  // CHECK: mhlo.maximum
+  // CHECK: mhlo.return
+
+  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
+  func.return %0 : tensor<2x3x5x7xi32>
+}
+
+// -----
+
+// CHECK-LABEL: maxpool_same_padding
+// CHECK-SAME: %[[ARG:.*]]: tensor<2x13x25x7xi32>
+func.func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> {
+  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
+
+  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32>
+  func.return %0 : tensor<2x4x7x7xi32>
+}
+
+// -----
+
+// CHECK-LABEL: maxpool_3d_valid_padding
+// CHECK-SAME: %[[ARG:.*]]: tensor<2x8x12x20x7xf32>
+func.func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> {
+  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
+  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
+  // CHECK: <{window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>}>
+  // CHECK: mhlo.maximum
+  // CHECK: mhlo.return
+
+  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32>
+  func.return %0 : tensor<2x8x3x5x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: maxpool_3d_same_padding
+// CHECK-SAME: %[[ARG:.*]]: tensor<2x8x13x25x7xf32>
+func.func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> {
+  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
+
+  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32>
+  func.return %0 : tensor<2x8x4x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: maxpool_explicit_padding
+func.func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
+  // CHECK: tf.MaxPool
+  // TODO(b/165938852): need to support explicit padding in max_pool.
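+  // Until that lands, EXPLICIT padding is expected to stay as tf.MaxPool
+  // (matched by the bare `tf.MaxPool` CHECK above) rather than legalize.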
+
+  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
+  func.return %0 : tensor<2x3x5x7xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// MaxPoolGrad op legalizations.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @max_pool_grad_valid
+// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32>
+func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
+  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) <{window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
+  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<i1>
+  // CHECK: }, {
+  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
+  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
+  // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<f32>
+  // CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
+  // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32>
+  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
+     data_format = "NHWC",
+     ksize = [1, 2, 2, 1],
+     padding = "VALID",
+     strides = [1, 2, 2, 1]
+  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32>
+  func.return %result : tensor<10x24x24x64xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @max_pool_3d_grad_valid
 // CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32>
 func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> {
@@ -935,6 +1527,20 @@ func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_
 
 // -----
 
+// CHECK-LABEL: @max_pool_grad_same
+func.func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
+  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
+  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
+    data_format = "NHWC",
+    ksize = [1, 2, 3, 1],
+    padding = "SAME",
+    strides = [1, 4, 4, 1]
+  } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32>
+  func.return %result : tensor<2x13x25x7xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @max_pool_3d_grad_same
 func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> {
   // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
@@ -1074,6 +1680,49 @@ func.func @callee() {
   func.return
 }
 
+//===----------------------------------------------------------------------===//
+// ReverseV2 op legalization.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @reverse_func_32
+func.func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> {
+  %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>)
+
+  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) <{dimensions = dense<0> : tensor<1xi64>}>
+  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32>
+
+  // CHECK: return [[VAL]] : tensor<5xi32>
+  func.return %reversed : tensor<5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reverse_func_64
+func.func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> {
+  %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>)
+
+  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) <{dimensions = dense<0> : tensor<1xi64>}>
+  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32>
+
+  // CHECK: return [[VAL]] : tensor<5xi32>
+  func.return %reversed : tensor<5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reverse_func_neg
+func.func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> {
+  %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+
+  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) <{dimensions = dense<1> : tensor<1xi64>}>
+  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32>
+
+  // CHECK: return [[VAL]] : tensor<5x5xi32>
+  func.return %reversed : tensor<5x5xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // StatefulPartitionedCall op legalization.
 //===----------------------------------------------------------------------===//
@@ -1106,6 +1755,39 @@ func.func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -
   func.return %arg1, %arg0 : tensor<i32>, tensor<i32>
 }
 
+//===----------------------------------------------------------------------===//
+// Elu op legalizations.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @elu
+func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1xf32>) -> tensor<1xf32>
+  // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg0, %[[ZERO]]
+  // CHECK-DAG: %[[EXP:.*]] = mhlo.exponential_minus_one %arg0
+  // CHECK: %[[RESULT:.*]] = mhlo.select %[[PRED]], %arg0, %[[EXP]]
+  // CHECK: return %[[RESULT]]
+  %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  func.return %0: tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @elu_grad
+// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
+func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
+  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = array<i64>, comparison_direction = #chlo<comparison_direction GT>}
+  // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = array<i64>}
+  // CHECK-DAG: %[[MULGRAD:.*]] = mhlo.multiply %[[GRADIENTS]], %[[ADD1]] : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
+  // CHECK: %[[RESULT:.*]] = mhlo.select %[[PRED]], %[[GRADIENTS]], %[[MULGRAD]]
+  // CHECK: return %[[RESULT]]
+  %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
+  func.return %2 : tensor<4x8xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Relu op legalizations.
 //===----------------------------------------------------------------------===//
@@ -1152,6 +1834,85 @@ func.func @relu6_unsigned(%arg0: tensor<?xui32>) -> tensor<?xui32> {
   func.return %0: tensor<?xui32>
 }
 
+// -----
+
+// CHECK-LABEL: func @leaky_relu
+func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} {
+  // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
+  // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
+  // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<1x4x4x3xf32>
+  // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]], NOTYPE : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1>
+  // CHECK-NEXT: %[[RES:.*]] = mhlo.select %[[CMP]], %[[INP]], %[[LEAKY]] : tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>
+  // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32>
+  %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
+  func.return %0 : tensor<1x4x4x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @leaky_relu_grad
+func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} {
+  // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+  // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+  // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<1x4x4xf32>
+  // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1>
+  // CHECK-NEXT: %[[RES:.*]] = mhlo.select %[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]] : tensor<1x4x4xi1>, tensor<1x4x4xf32>
+  // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32>
+  %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+  func.return %0 : tensor<1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @softsign
+func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> {
+  // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) <{value = 1.000000e+00 : f32}> : (tensor<4x10xf32>) -> tensor<4x10xf32>
+  // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32>
+  // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<4x10xf32>
+  // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<4x10xf32>
+  // CHECK-NEXT: return %[[DIV]] : tensor<4x10xf32>
+  %0 = "tf.Softsign"(%arg0) : (tensor<4x10xf32>) -> tensor<4x10xf32>
+  func.return %0 : tensor<4x10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @softsign_grad
+func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> tensor<4x10xf32> {
+
+  // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32>
+  // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = array<i64>} : (tensor<f32>, tensor<4x10xf32>) -> tensor<4x10xf32>
+  // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32>
+  // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
+  // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32>
+  %0 = "tf.SoftsignGrad"(%arg0, %arg1) : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
+  func.return %0 : tensor<4x10xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// Roll op legalizations.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @Roll_0D
+func.func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor<i32>) -> tensor<512xi32> {
+  %axis = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
+  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
+  // CHECK-DAG: %[[AXIS_SIZE:.*]] = mhlo.constant dense<512> : tensor<i32>
+  // CHECK: %[[T1:.+]] = mhlo.remainder %arg1, %[[AXIS_SIZE]] : tensor<i32>
+  // CHECK: %[[T2:.+]] = mhlo.add %[[T1]], %[[AXIS_SIZE]] : tensor<i32>
+  // CHECK: %[[T3:.+]] = mhlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor<i32>
+  // CHECK: %[[CONCAT:.+]] = "mhlo.concatenate"(%arg0, %arg0) <{dimension = 0 : i64}>
+  // CHECK: %[[OFFSET:.+]] = mhlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor<i32>
+  // CHECK: "mhlo.dynamic_slice"(%[[CONCAT]], %[[OFFSET]])
+  // CHECK-SAME: {slice_sizes = dense<512> : tensor<1xi64>}
+  // CHECK-SAME: (tensor<1024xi32>, tensor<i32>) -> tensor<512xi32>
+  %0 = "tf.Roll"(%arg0, %shift, %axis) {device = ""} : (tensor<512xi32>, tensor<i32>, tensor<i32>) -> tensor<512xi32>
+  func.return %0 : tensor<512xi32>
+}
+
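The Roll lowering above can be read as: normalize the shift into [0, axis_size), double the input along the axis, then take one slice. A small C++ sketch of the index math (illustrative only; `Roll1D` is our name, not the converter's):

#include <cstdint>
#include <vector>

// Mirrors the CHECK lines: T3 = ((shift % size) + size) % size brings a
// possibly-negative shift into [0, size); the input is concatenated with
// itself and a window of `size` elements starting at `size - T3` is sliced.
std::vector<int32_t> Roll1D(const std::vector<int32_t>& x, int64_t shift) {
  const int64_t size = static_cast<int64_t>(x.size());
  const int64_t t3 = ((shift % size) + size) % size;  // [[T1]]..[[T3]]
  std::vector<int32_t> doubled(x);                    // [[CONCAT]]
  doubled.insert(doubled.end(), x.begin(), x.end());
  const int64_t offset = size - t3;                   // [[OFFSET]]
  return std::vector<int32_t>(doubled.begin() + offset,
                              doubled.begin() + offset + size);
}

For example, Roll1D({1, 2, 3, 4}, 1) yields {4, 1, 2, 3}, which is what the mhlo.dynamic_slice over the doubled tensor produces.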
 //===----------------------------------------------------------------------===//
 // Select op legalizations.
 //===----------------------------------------------------------------------===//
@@ -1264,6 +2025,15 @@ func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>
 
 // -----
 
+// CHECK-LABEL: func @fft_1D
+func.func @fft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
+  // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type FFT>}> : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
+  %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
+  func.return %0 : tensor<8xcomplex<f32>>
+}
+
+// -----
+
 // CHECK-LABEL: func @ifft_1D
 func.func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
   // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type IFFT>}> : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
@@ -1273,6 +2043,38 @@ func.func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
 
 // -----
 
+// CHECK-LABEL: func @rfft_1D
+func.func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex<f32>> {
+  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+  // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>}> : (tensor<8xf32>) -> tensor<5xcomplex<f32>>
+  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>>
+  func.return %0 : tensor<5xcomplex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: func @rfft_1D_padded
+func.func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex<f32>> {
+  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+  // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %{{.*}}) <{edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<7xf32>, tensor<f32>) -> tensor<8xf32>
+  // CHECK: "mhlo.fft"(%[[PADDED]]) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>}> : (tensor<8xf32>) -> tensor<5xcomplex<f32>>
+  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>>
+  func.return %0 : tensor<5xcomplex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: func @rfft_1D_sliced
+func.func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex<f32>> {
+  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) <{limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x9xf32>) -> tensor<2x8xf32>
+  // CHECK: "mhlo.fft"(%[[SLICED]]) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>}> : (tensor<2x8xf32>) -> tensor<2x5xcomplex<f32>>
+  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x5xcomplex<f32>>
+  func.return %0 : tensor<2x5xcomplex<f32>>
+}
+
+// -----
+
 // CHECK-LABEL: func @irfft_1D
 func.func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xf32> {
   %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
@@ -1282,6 +2084,25 @@ func.func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xf32> {
   func.return %0 : tensor<8xf32>
 }
 
+// -----
+
+// CHECK-LABEL: fft_1D_dynamic
+func.func @fft_1D_dynamic(%arg0: tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
+  // CHECK: "tf.FFT"
+  %0 = "tf.FFT"(%arg0) : (tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>>
+  func.return %0 : tensor<8xcomplex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: rfft_1D_dynamic
+func.func @rfft_1D_dynamic(%arg0: tensor<?xf32>) -> tensor<8xcomplex<f32>> {
+  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+  // CHECK: "tf.RFFT"
+  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<?xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>>
+  func.return %0 : tensor<8xcomplex<f32>>
+}
+
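A note on the RFFT shapes used above: the converter pads or slices the innermost dimension to fft_length before emitting mhlo.fft, and a real-to-complex FFT of length n produces n/2 + 1 coefficients. A small C++ sketch of that shape arithmetic (illustrative only; `RfftOutputSize` is our name, not the converter's):

#include <cstdint>
#include <iostream>

// fft_length = 8 explains all three static tests: a 7-element input is
// padded high by 1 (@rfft_1D_padded), a 9-element input is sliced to 8
// (@rfft_1D_sliced), and the complex output always has 8/2 + 1 = 5 elements.
int64_t RfftOutputSize(int64_t fft_length) { return fft_length / 2 + 1; }

int main() {
  const int64_t fft_length = 8;
  for (int64_t input_len : {7, 8, 9}) {
    if (input_len < fft_length)
      std::cout << input_len << ": pad high by " << (fft_length - input_len);
    else if (input_len > fft_length)
      std::cout << input_len << ": slice to limit " << fft_length;
    else
      std::cout << input_len << ": use as-is";
    std::cout << ", output size " << RfftOutputSize(fft_length) << "\n";
  }
}

The dynamic-shape variants cannot compute this padding statically, which is why they are left as tf.FFT / tf.RFFT.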
 //===----------------------------------------------------------------------===//
 // Shape op legalization.
 //===----------------------------------------------------------------------===//
@@ -1400,6 +2221,188 @@ func.func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
 
 // -----
 
+// CHECK-LABEL: @acos
+// CHLO-LABEL: @acos
+func.func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK: chlo.acos %arg0 : tensor<2xf32>
+  // CHLO: %[[TEMP_0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
+  // CHLO: %[[TEMP_1:.*]] = mhlo.subtract %[[TEMP_0]], %arg0 : tensor<2xf32>
+  // CHLO: %[[TEMP_2:.*]] = mhlo.add %[[TEMP_0]], %arg0 : tensor<2xf32>
+  // CHLO: %[[TEMP_3:.*]] = mhlo.multiply %[[TEMP_1]], %[[TEMP_2]] : tensor<2xf32>
+  // CHLO: %[[TEMP_4:.*]] = mhlo.sqrt %[[TEMP_3]] : tensor<2xf32>
+  // CHLO: %[[TEMP_5:.*]] = mhlo.atan2 %[[TEMP_4]], %arg0 : tensor<2xf32>
+  // CHLO: return %[[TEMP_5]] : tensor<2xf32>
+  %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @acos_complex
+// CHLO-LABEL: @acos_complex
+func.func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
+  // CHECK: chlo.acos
+  // CHLO: %[[TEMP_0:.*]] = mhlo.real %[[TEMP_arg0:.*]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+  // CHLO: %[[TEMP_1:.*]] = mhlo.abs %[[TEMP_0]] : tensor<2xf32>
+  // CHLO: %[[TEMP_2:.*]] = mhlo.imag %[[TEMP_arg0:.*]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+  // CHLO: %[[TEMP_3:.*]] = mhlo.abs %[[TEMP_2]] : tensor<2xf32>
+  // CHLO: %[[TEMP_4:.*]] = mhlo.maximum %[[TEMP_1]], %[[TEMP_3]] : tensor<2xf32>
+  // CHLO: %[[TEMP_5:.*]] = mhlo.constant dense<3.40282347E+38> : tensor<2xf32>
+  // CHLO: %[[TEMP_6:.*]] = mhlo.sqrt %[[TEMP_5]] : tensor<2xf32>
+  // CHLO: %[[TEMP_7:.*]] = mhlo.constant dense<8.000000e+00> : tensor<2xf32>
+  // CHLO: %[[TEMP_8:.*]] = mhlo.divide %[[TEMP_6]], %[[TEMP_7]] : tensor<2xf32>
+  // CHLO: %[[TEMP_9:.*]] = mhlo.compare GE, %[[TEMP_4]], %[[TEMP_8]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+  // CHLO: %[[TEMP_10:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
+  // CHLO: %[[TEMP_11:.*]] = mhlo.compare LE, %[[TEMP_1]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+  // CHLO: %[[TEMP_12:.*]] = mhlo.constant dense<5.000000e-01> : tensor<2xf32>
+  // CHLO: %[[TEMP_13:.*]] = mhlo.add %[[TEMP_1]], %[[TEMP_10]] : tensor<2xf32>
+  // CHLO: %[[TEMP_14:.*]] = mhlo.abs %[[TEMP_13]] : tensor<2xf32>
+  // CHLO: %[[TEMP_15:.*]] = mhlo.maximum %[[TEMP_14]], %[[TEMP_3]] : tensor<2xf32>
+  // CHLO: %[[TEMP_16:.*]] = mhlo.minimum %[[TEMP_14]], %[[TEMP_3]] : tensor<2xf32>
+  // CHLO: %[[TEMP_17:.*]] = mhlo.compare EQ, %[[TEMP_15]], %[[TEMP_16]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+  // CHLO: %[[TEMP_18:.*]] = mhlo.constant dense<1.41421354> : tensor<2xf32>
+  // CHLO: %[[TEMP_19:.*]] = mhlo.multiply %[[TEMP_18]], %[[TEMP_15]] : tensor<2xf32>
+  // CHLO: %[[TEMP_20:.*]] = mhlo.divide %[[TEMP_16]], %[[TEMP_15]] : tensor<2xf32>
+  // CHLO: %[[TEMP_21:.*]] = mhlo.multiply %[[TEMP_20]], %[[TEMP_20]] : tensor<2xf32>
+  // CHLO: %[[TEMP_22:.*]] = mhlo.add %[[TEMP_10]], %[[TEMP_21]] : tensor<2xf32>
+  // CHLO: %[[TEMP_23:.*]] = mhlo.sqrt %[[TEMP_22]] : tensor<2xf32>
+  // CHLO: %[[TEMP_24:.*]] = mhlo.compare EQ, %[[TEMP_23]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+  // CHLO: %[[TEMP_25:.*]] = mhlo.constant dense<0.000000e+00> : tensor<2xf32>
+  // CHLO: %[[TEMP_26:.*]] = mhlo.compare GT, %[[TEMP_21]], %[[TEMP_25]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+  // CHLO: %[[TEMP_27:.*]] 
= mhlo.and %[[TEMP_24]], %[[TEMP_26]] : tensor<2xi1> + // CHLO: %[[TEMP_28:.*]] = mhlo.multiply %[[TEMP_15]], %[[TEMP_21]] : tensor<2xf32> + // CHLO: %[[TEMP_29:.*]] = mhlo.constant dense<2.000000e+00> : tensor<2xf32> + // CHLO: %[[TEMP_30:.*]] = mhlo.divide %[[TEMP_28]], %[[TEMP_29]] : tensor<2xf32> + // CHLO: %[[TEMP_31:.*]] = mhlo.add %[[TEMP_15]], %[[TEMP_30]] : tensor<2xf32> + // CHLO: %[[TEMP_32:.*]] = mhlo.multiply %[[TEMP_15]], %[[TEMP_23]] : tensor<2xf32> + // CHLO: %[[TEMP_33:.*]] = mhlo.select %[[TEMP_27]], %[[TEMP_31]], %[[TEMP_32]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_34:.*]] = mhlo.select %[[TEMP_17]], %[[TEMP_19]], %[[TEMP_33]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_35:.*]] = mhlo.subtract %[[TEMP_1]], %[[TEMP_10]] : tensor<2xf32> + // CHLO: %[[TEMP_36:.*]] = mhlo.abs %[[TEMP_35]] : tensor<2xf32> + // CHLO: %[[TEMP_37:.*]] = mhlo.maximum %[[TEMP_36]], %[[TEMP_3]] : tensor<2xf32> + // CHLO: %[[TEMP_38:.*]] = mhlo.minimum %[[TEMP_36]], %[[TEMP_3]] : tensor<2xf32> + // CHLO: %[[TEMP_39:.*]] = mhlo.compare EQ, %[[TEMP_37]], %[[TEMP_38]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHLO: %[[TEMP_40:.*]] = mhlo.multiply %[[TEMP_18]], %[[TEMP_37]] : tensor<2xf32> + // CHLO: %[[TEMP_41:.*]] = mhlo.divide %[[TEMP_38]], %[[TEMP_37]] : tensor<2xf32> + // CHLO: %[[TEMP_42:.*]] = mhlo.multiply %[[TEMP_41]], %[[TEMP_41]] : tensor<2xf32> + // CHLO: %[[TEMP_43:.*]] = mhlo.add %[[TEMP_10]], %[[TEMP_42]] : tensor<2xf32> + // CHLO: %[[TEMP_44:.*]] = mhlo.sqrt %[[TEMP_43]] : tensor<2xf32> + // CHLO: %[[TEMP_45:.*]] = mhlo.compare EQ, %[[TEMP_44]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHLO: %[[TEMP_46:.*]] = mhlo.compare GT, %[[TEMP_42]], %[[TEMP_25]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHLO: %[[TEMP_47:.*]] = mhlo.and %[[TEMP_45]], %[[TEMP_46]] : tensor<2xi1> + // CHLO: %[[TEMP_48:.*]] = mhlo.multiply %[[TEMP_37]], %[[TEMP_42]] : tensor<2xf32> + // CHLO: %[[TEMP_49:.*]] = mhlo.divide %[[TEMP_48]], %[[TEMP_29]] : tensor<2xf32> + // CHLO: %[[TEMP_50:.*]] = mhlo.add %[[TEMP_37]], %[[TEMP_49]] : tensor<2xf32> + // CHLO: %[[TEMP_51:.*]] = mhlo.multiply %[[TEMP_37]], %[[TEMP_44]] : tensor<2xf32> + // CHLO: %[[TEMP_52:.*]] = mhlo.select %[[TEMP_47]], %[[TEMP_50]], %[[TEMP_51]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_53:.*]] = mhlo.select %[[TEMP_39]], %[[TEMP_40]], %[[TEMP_52]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_54:.*]] = mhlo.add %[[TEMP_34]], %[[TEMP_53]] : tensor<2xf32> + // CHLO: %[[TEMP_55:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_54]] : tensor<2xf32> + // CHLO: %[[TEMP_56:.*]] = mhlo.add %[[TEMP_55]], %[[TEMP_1]] : tensor<2xf32> + // CHLO: %[[TEMP_57:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_56]] : tensor<2xf32> + // CHLO: %[[TEMP_58:.*]] = mhlo.multiply %[[TEMP_3]], %[[TEMP_3]] : tensor<2xf32> + // CHLO: %[[TEMP_59:.*]] = mhlo.add %[[TEMP_34]], %[[TEMP_13]] : tensor<2xf32> + // CHLO: %[[TEMP_60:.*]] = mhlo.divide %[[TEMP_58]], %[[TEMP_59]] : tensor<2xf32> + // CHLO: %[[TEMP_61:.*]] = mhlo.subtract %[[TEMP_53]], %[[TEMP_35]] : tensor<2xf32> + // CHLO: %[[TEMP_62:.*]] = mhlo.add %[[TEMP_60]], %[[TEMP_61]] : tensor<2xf32> + // CHLO: %[[TEMP_63:.*]] = mhlo.multiply %[[TEMP_57]], %[[TEMP_62]] : tensor<2xf32> + // CHLO: %[[TEMP_64:.*]] = mhlo.sqrt %[[TEMP_63]] : tensor<2xf32> + // CHLO: %[[TEMP_65:.*]] = mhlo.divide %[[TEMP_57]], %[[TEMP_59]] : tensor<2xf32> + // CHLO: %[[TEMP_66:.*]] = mhlo.add %[[TEMP_53]], %[[TEMP_35]] : tensor<2xf32> + // CHLO: %[[TEMP_67:.*]] = mhlo.divide %[[TEMP_57]], 
%[[TEMP_66]] : tensor<2xf32> + // CHLO: %[[TEMP_68:.*]] = mhlo.add %[[TEMP_65]], %[[TEMP_67]] : tensor<2xf32> + // CHLO: %[[TEMP_69:.*]] = mhlo.sqrt %[[TEMP_68]] : tensor<2xf32> + // CHLO: %[[TEMP_70:.*]] = mhlo.multiply %[[TEMP_3]], %[[TEMP_69]] : tensor<2xf32> + // CHLO: %[[TEMP_71:.*]] = mhlo.select %[[TEMP_11]], %[[TEMP_64]], %[[TEMP_70]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_72:.*]] = mhlo.select %[[TEMP_9]], %[[TEMP_3]], %[[TEMP_71]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_73:.*]] = mhlo.constant dense<9.99999995E+11> : tensor<2xf32> + // CHLO: %[[TEMP_74:.*]] = mhlo.multiply %[[TEMP_8]], %[[TEMP_73]] : tensor<2xf32> + // CHLO: %[[TEMP_75:.*]] = mhlo.compare LT, %[[TEMP_1]], %[[TEMP_74]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHLO: %[[TEMP_76:.*]] = mhlo.constant dense<9.99999997E-7> : tensor<2xf32> + // CHLO: %[[TEMP_77:.*]] = mhlo.multiply %[[TEMP_8]], %[[TEMP_76]] : tensor<2xf32> + // CHLO: %[[TEMP_78:.*]] = mhlo.constant dense<1.000000e+02> : tensor<2xf32> + // CHLO: %[[TEMP_79:.*]] = mhlo.multiply %[[TEMP_8]], %[[TEMP_78]] : tensor<2xf32> + // CHLO: %[[TEMP_80:.*]] = mhlo.select %[[TEMP_75]], %[[TEMP_77]], %[[TEMP_79]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_81:.*]] = mhlo.compare GE, %[[TEMP_3]], %[[TEMP_80]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHLO: %[[TEMP_82:.*]] = mhlo.select %[[TEMP_81]], %[[TEMP_3]], %[[TEMP_1]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_83:.*]] = mhlo.select %[[TEMP_81]], %[[TEMP_80]], %[[TEMP_8]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_84:.*]] = mhlo.compare GE, %[[TEMP_82]], %[[TEMP_83]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHLO: %[[TEMP_85:.*]] = mhlo.log %[[TEMP_29]] : tensor<2xf32> + // CHLO: %[[TEMP_86:.*]] = mhlo.log %[[TEMP_82]] : tensor<2xf32> + // CHLO: %[[TEMP_87:.*]] = mhlo.add %[[TEMP_85]], %[[TEMP_86]] : tensor<2xf32> + // CHLO: %[[TEMP_88:.*]] = mhlo.constant dense<0x7F800000> : tensor<2xf32> + // CHLO: %[[TEMP_89:.*]] = mhlo.compare EQ, %[[TEMP_3]], %[[TEMP_88]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHLO: %[[TEMP_90:.*]] = mhlo.not %[[TEMP_89]] : tensor<2xi1> + // CHLO: %[[TEMP_91:.*]] = mhlo.and %[[TEMP_81]], %[[TEMP_90]] : tensor<2xi1> + // CHLO: %[[TEMP_92:.*]] = mhlo.divide %[[TEMP_1]], %[[TEMP_3]] : tensor<2xf32> + // CHLO: %[[TEMP_93:.*]] = mhlo.select %[[TEMP_91]], %[[TEMP_92]], %[[TEMP_25]] : tensor<2xi1>, tensor<2xf32> + // CHLO: %[[TEMP_94:.*]] = mhlo.multiply %[[TEMP_93]], %[[TEMP_93]] : tensor<2xf32> + // CHLO: %[[TEMP_95:.*]] = mhlo.log_plus_one %[[TEMP_94]] : tensor<2xf32> + // CHLO: %[[TEMP_96:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_95]] : tensor<2xf32> + // CHLO: %[[TEMP_97:.*]] = mhlo.add %[[TEMP_87]], %[[TEMP_96]] : tensor<2xf32> + // CHLO: %[[TEMP_98:.*]] = mhlo.constant dense<1.17549435E-38> : tensor<2xf32> + // CHLO: %[[TEMP_99:.*]] = mhlo.sqrt %[[TEMP_98]] : tensor<2xf32> + // CHLO: %[[TEMP_100:.*]] = mhlo.constant dense<4.000000e+00> : tensor<2xf32> + // CHLO: %[[TEMP_101:.*]] = mhlo.multiply %[[TEMP_99]], %[[TEMP_100]] : tensor<2xf32> + // CHLO: %[[TEMP_102:.*]] = mhlo.compare LT, %[[TEMP_3]], %[[TEMP_101]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHLO: %[[TEMP_103:.*]] = mhlo.compare LT, %[[TEMP_1]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHLO: %[[TEMP_104:.*]] = mhlo.and %[[TEMP_102]], %[[TEMP_103]] : tensor<2xi1> + // CHLO: %[[TEMP_105:.*]] = mhlo.multiply %[[TEMP_13]], %[[TEMP_35]] : tensor<2xf32> + // CHLO: %[[TEMP_106:.*]] = mhlo.add %[[TEMP_55]], 
%[[TEMP_10]] : tensor<2xf32>
+  // CHLO: %[[TEMP_107:.*]] = mhlo.divide %[[TEMP_105]], %[[TEMP_106]] : tensor<2xf32>
+  // CHLO: %[[TEMP_108:.*]] = mhlo.negate %[[TEMP_107]] : tensor<2xf32>
+  // CHLO: %[[TEMP_109:.*]] = mhlo.compare GE, %[[TEMP_1]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+  // CHLO: %[[TEMP_110:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_58]] : tensor<2xf32>
+  // CHLO: %[[TEMP_111:.*]] = mhlo.divide %[[TEMP_110]], %[[TEMP_59]] : tensor<2xf32>
+  // CHLO: %[[TEMP_112:.*]] = mhlo.multiply %[[TEMP_12]], %[[TEMP_66]] : tensor<2xf32>
+  // CHLO: %[[TEMP_113:.*]] = mhlo.add %[[TEMP_111]], %[[TEMP_112]] : tensor<2xf32>
+  // CHLO: %[[TEMP_114:.*]] = mhlo.constant dense<1.500000e+00> : tensor<2xf32>
+  // CHLO: %[[TEMP_115:.*]] = mhlo.compare LE, %[[TEMP_55]], %[[TEMP_114]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+  // CHLO: %[[TEMP_116:.*]] = mhlo.divide %[[TEMP_110]], %[[TEMP_61]] : tensor<2xf32>
+  // CHLO: %[[TEMP_117:.*]] = mhlo.add %[[TEMP_111]], %[[TEMP_116]] : tensor<2xf32>
+  // CHLO: %[[TEMP_118:.*]] = mhlo.subtract %[[TEMP_55]], %[[TEMP_10]] : tensor<2xf32>
+  // CHLO: %[[TEMP_119:.*]] = mhlo.select %[[TEMP_115]], %[[TEMP_117]], %[[TEMP_118]] : tensor<2xi1>, tensor<2xf32>
+  // CHLO: %[[TEMP_120:.*]] = mhlo.select %[[TEMP_109]], %[[TEMP_113]], %[[TEMP_119]] : tensor<2xi1>, tensor<2xf32>
+  // CHLO: %[[TEMP_121:.*]] = mhlo.select %[[TEMP_104]], %[[TEMP_108]], %[[TEMP_120]] : tensor<2xi1>, tensor<2xf32>
+  // CHLO: %[[TEMP_122:.*]] = mhlo.multiply %[[TEMP_121]], %[[TEMP_106]] : tensor<2xf32>
+  // CHLO: %[[TEMP_123:.*]] = mhlo.sqrt %[[TEMP_122]] : tensor<2xf32>
+  // CHLO: %[[TEMP_124:.*]] = mhlo.divide %[[TEMP_3]], %[[TEMP_123]] : tensor<2xf32>
+  // CHLO: %[[TEMP_125:.*]] = mhlo.add %[[TEMP_121]], %[[TEMP_123]] : tensor<2xf32>
+  // CHLO: %[[TEMP_126:.*]] = mhlo.log_plus_one %[[TEMP_125]] : tensor<2xf32>
+  // CHLO: %[[TEMP_127:.*]] = mhlo.select %[[TEMP_104]], %[[TEMP_124]], %[[TEMP_126]] : tensor<2xi1>, tensor<2xf32>
+  // CHLO: %[[TEMP_128:.*]] = mhlo.select %[[TEMP_84]], %[[TEMP_97]], %[[TEMP_127]] : tensor<2xi1>, tensor<2xf32>
+  // CHLO: %[[TEMP_129:.*]] = mhlo.complex %[[TEMP_72]], %[[TEMP_128]] : tensor<2xcomplex<f32>>
+  // CHLO: %[[TEMP_130:.*]] = mhlo.real %[[TEMP_129]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+  // CHLO: %[[TEMP_131:.*]] = mhlo.real %[[TEMP_arg0:.*]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+  // CHLO: %[[TEMP_132:.*]] = mhlo.atan2 %[[TEMP_130]], %[[TEMP_131]] : tensor<2xf32>
+  // CHLO: %[[TEMP_133:.*]] = mhlo.imag %[[TEMP_arg0:.*]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+  // CHLO: %[[TEMP_134:.*]] = mhlo.constant dense<0.000000e+00> : tensor<2xf32>
+  // CHLO: %[[TEMP_135:.*]] = mhlo.compare LT, %[[TEMP_133]], %[[TEMP_134]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+  // CHLO: %[[TEMP_136:.*]] = mhlo.imag %[[TEMP_129]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+  // CHLO: %[[TEMP_137:.*]] = mhlo.negate %[[TEMP_136]] : tensor<2xf32>
+  // CHLO: %[[TEMP_138:.*]] = mhlo.select %[[TEMP_135]], %[[TEMP_136]], %[[TEMP_137]] : tensor<2xi1>, tensor<2xf32>
+  // CHLO: %[[TEMP_139:.*]] = mhlo.complex %[[TEMP_132]], %[[TEMP_138]] : tensor<2xcomplex<f32>>
+  // CHLO: return %[[TEMP_139:.*]] : tensor<2xcomplex<f32>>
+  %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
+  func.return %0 : tensor<2xcomplex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: @acos_dynamic
+// CHLO-LABEL: @acos_dynamic
+func.func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: chlo.acos %arg0 : tensor<*xf32>
+  // `tf.Acos` is lowered to `chlo.constant_like` operations which can only be
can only be + // lowered further on ranked tensors. Unranked CHLO must be transformed to + // ranked code before further lowering. + // CHLO: "tf.Acos" + %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: func @cast_dynamic_i2f func.func @cast_dynamic_i2f(%arg0: tensor) -> tensor { // CHECK: mhlo.convert %arg0 : (tensor) -> tensor @@ -1445,6 +2448,15 @@ func.func @ceil_dynamic(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: @complex_abs +func.func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { + // CHECK: mhlo.abs %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> + %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + // CHECK-LABEL: @cos func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: mhlo.cosine %arg0 : tensor<2xf32> @@ -1454,6 +2466,15 @@ func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { // ----- +// CHECK-LABEL: @tan +func.func @tan(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: mhlo.tan %arg0 : tensor<2xf32> + %0 = "tf.Tan"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + // CHECK-LABEL: func @cos_dynamic func.func @cos_dynamic(%arg0: tensor) -> tensor { // CHECK: mhlo.cosine %arg0 : tensor @@ -1596,3 +2617,40 @@ func.func @sigmoid_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> } +// ----- + +// CHECK-LABEL: @sigmoid_grad +func.func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xf32> + // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32> + // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xf32> + // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32> + // CHECK: return [[MUL1]] + %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @sigmoid_grad_complex +func.func @sigmoid_grad_complex(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomplex>) -> tensor<2xcomplex> { + // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xcomplex> + // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex> + // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xcomplex> + // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex> + // CHECK: return [[MUL1]] + %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex>, tensor<2xcomplex>) -> tensor<2xcomplex> + func.return %0 : tensor<2xcomplex> +} + +// ----- + +// CHECK-LABEL: @sigmoid_grad_dynamic +func.func @sigmoid_grad_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: chlo.broadcast_multiply {{.*}} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = array} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_multiply {{.*}} : (tensor, tensor) -> tensor + %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index ae649253605bb9..047a5fb7b46bbc 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -1541,6 +1541,119 @@ class ConvertBroadcastToOp : public OpRewritePattern { } }; +/// 
Converts a TF::RollOp to HLO. Only support 0D axis and shift case, and axis +/// have to be a constant. +class ConvertRollOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::RollOp op, + PatternRewriter &rewriter) const override { + auto shift_ty = mlir::dyn_cast(op.getShift().getType()); + if (!shift_ty || shift_ty.getRank() != 0) { + return rewriter.notifyMatchFailure( + op, "require the type of shift to be 0D tensor"); + } + + APInt val; + if (!matchPattern(op.getAxis(), m_ConstantInt(&val))) { + return rewriter.notifyMatchFailure(op, "require axis to be constant"); + } + int axis = val.getSExtValue(); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty || !input_ty.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "require the type of input to have static shapes"); + } + ArrayRef input_shape = input_ty.getShape(); + int input_rank = input_ty.getRank(); + if (axis < 0) axis += input_rank; + + // Adjust large offsets into [0, axis_size). This also makes negative + // offsets positive. + // offset = ((offset % axis_size) + axis_size) % axis_size + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value offset = op.getShift(); + auto axis_size = b.create(b.getIntegerAttr( + getElementTypeOrSelf(offset.getType()), input_shape[axis])); + offset = b.create( + b.create(b.create(offset, axis_size), axis_size), + axis_size); + + // Stack two copies of the dimension, then slice from the calculated + // offset. This also works if shift is not constant. + // DynamicSliceOp requires the sizes being integer, and we can get the + // information from input shape. + auto concat = b.create( + ValueRange{op.getInput(), op.getInput()}, b.getI64IntegerAttr(axis)); + Value zero = b.create( + b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0)); + SmallVector slice_begin_indices(input_rank, zero); + slice_begin_indices[axis] = b.create(axis_size, offset); + rewriter.replaceOpWithNewOp( + op, input_ty, concat, slice_begin_indices, + rewriter.getI64TensorAttr(input_shape)); + return success(); + } +}; + +/// Converts a TF::LeakyReluOp to HLO. +/// LeakyRelu(x) = alpha * x if x < 0 else x. +class ConvertLeakyReluOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::LeakyReluOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value features = op.getFeatures(); + + // Use ConstantLike for `alpha` to match the shape of feature. + auto alphaVal = chlo::getConstantLike( + rewriter, loc, op.getAlpha().convertToFloat(), features); + Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); + + Value leakyActivationVal = + rewriter.create(loc, features, alphaVal); + + Value compareGtZero = rewriter.create( + loc, features, zeroVal, ComparisonDirection::GT); + + rewriter.replaceOpWithNewOp(op, compareGtZero, features, + leakyActivationVal); + return success(); + } +}; + +/// Converts a TF::LeakyReluGradOp to HLO. +/// LeakyReluGrad(gradient, inputs) = gradient if input > 0 +/// else alpha * gradient. 
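For reference alongside the LeakyRelu pattern above and the gradient pattern just below, a minimal host-side sketch of the element-wise semantics both lowerings encode. This is illustrative C++ only, not part of the patch:

```cpp
#include <cstddef>
#include <vector>

// LeakyRelu(x) = x if x > 0, else alpha * x. The gradient passes dy through
// where the input was positive and scales it by alpha elsewhere, which is
// exactly what the select-over-compare lowering computes element-wise.
std::vector<float> LeakyReluRef(const std::vector<float>& x, float alpha) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = x[i] > 0.0f ? x[i] : alpha * x[i];
  }
  return y;
}

std::vector<float> LeakyReluGradRef(const std::vector<float>& dy,
                                    const std::vector<float>& x, float alpha) {
  std::vector<float> dx(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    dx[i] = x[i] > 0.0f ? dy[i] : alpha * dy[i];
  }
  return dx;
}
```

For example, LeakyReluRef({2, -4}, 0.1f) yields {2, -0.4}, matching the select of the features against the alpha-scaled activation.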
+class ConvertLeakyReluGradOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::LeakyReluGradOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value gradients = op.getGradients(); + Value features = op.getFeatures(); + auto featureType = features.getType(); + + // Use ConstantLike for `alpha` to match the shape of feature. + auto alphaVal = chlo::getConstantLike( + rewriter, loc, op.getAlpha().convertToFloat(), features); + Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); + + Value leakyGradientVal = + rewriter.create(loc, gradients, alphaVal); + + Value compareGtZero = rewriter.create( + loc, features, zeroVal, ComparisonDirection::GT); + + rewriter.replaceOpWithNewOp(op, featureType, compareGtZero, + gradients, leakyGradientVal); + return success(); + } +}; + // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix. // For a Rank-2 input, it creates the following ops: // %1 = "mhlo.iota"() {iota_dimension = 0 : i64} @@ -1915,6 +2028,17 @@ class ConvertEinsumOp : public OpRewritePattern { } }; +// Bypasses IdentityN op. +class ConvertIdentityNOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::IdentityNOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + template class ConvertFFTOp : public OpRewritePattern { public: @@ -1993,6 +2117,7 @@ class ConvertFFTOp : public OpRewritePattern { } }; +using ConvertRFFTOp = ConvertFFTOp; using ConvertIRFFTOp = ConvertFFTOp; // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO @@ -2119,6 +2244,10 @@ class ConvertFusedBatchNormGradBase } }; +using ConvertFusedBatchNormGradOp = + ConvertFusedBatchNormGradBase; +using ConvertFusedBatchNormGradV2Op = + ConvertFusedBatchNormGradBase; using ConvertFusedBatchNormGradV3Op = ConvertFusedBatchNormGradBase; @@ -2317,6 +2446,8 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { } }; +using ConvertFusedBatchNormV2Op = + ConvertFusedBatchNormBase; using ConvertFusedBatchNormV3Op = ConvertFusedBatchNormBase; @@ -2689,6 +2820,54 @@ using ConvertAvgPool2DGradOp = using ConvertAvgPool3DGradOp = ConvertAvgPoolGradOp; +// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window +// dimensions with max as the reduction function. +// +// Sample result for VALID padding mode: +// +// %init = arith.constant dense<...> : tensor +// %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] +// {window_dimensions = ..., window_strides = ... 
} +// +template +class ConvertMaxPoolOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Type element_type = + mlir::cast(op.getInput().getType()).getElementType(); + if (!element_type.isSignlessIntOrFloat()) return failure(); + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + if (padding == tensorflow::Padding::EXPLICIT) { + return failure(); + } + Location loc = op.getLoc(); + ConstantOp init = GetScalarLimitConstOfType( + element_type, loc, hlo::kInfinityLowest, &rewriter); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty) return failure(); + DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( + input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); + auto reduce = rewriter.create( + loc, op.getType(), op.getInput(), init, + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + BuildReduceBody(element_type, &reduce.getBody(), &rewriter); + + rewriter.replaceOp(op, reduce.getResult(0)); + return success(); + } +}; + +using ConvertMaxPool2DOp = ConvertMaxPoolOp; +using ConvertMaxPool3DOp = ConvertMaxPoolOp; // Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on // the condition only. @@ -2854,6 +3033,127 @@ class ConvertSliceOpDynamic : public OpRewritePattern { } }; +static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, + Value *out_lhs, Value *out_rhs, + PatternRewriter *rewriter) { + // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is: + // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS] + // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS] + // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS] + // To perform the matmul, we need to first broadcast lhs and rhs to a common + // set of leading dimensions before doing the actual matmul. + // That's what the code below does. + // In particular, we populate out_lhs and out_rhs to have dimension structure: + // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS] + // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS] + // To do this, we need to calculate those output shapes, which involves + // slicing off the leading batch dims of each operand, broadcasting them, + // then concatenating the broadcasted leading dims back to the row/col dims. + // Finally, we create a TF::BroadcastTo op that does the actual broadcast. + + // TODO(silvasean): Reduce duplication across reified shape calculations and + // the static computation of output types needed to create ops. + Value lhs_shape = rewriter->create(loc, lhs); + Value rhs_shape = rewriter->create(loc, rhs); + Value const_neg2 = + rewriter->create(loc, rewriter->getIndexAttr(-2)); + auto shape_type = shape::ShapeType::get(rewriter->getContext()); + auto lhs_splitted = rewriter->create( + loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2); + auto rhs_splitted = rewriter->create( + loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2); + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); + // The last two dimensions are the matrix row/col dimensions. Don't broadcast + // them. 
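To make the leading-batch broadcast above concrete, a small sketch of the shape rule it relies on. This is a hypothetical helper assuming numpy-style broadcasting over compatible shapes, not code from the pattern:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Broadcasts two leading-batch shapes: left-pad the shorter one with 1s, then
// take the non-1 extent in each position. Assumes the shapes are compatible.
std::vector<int64_t> BroadcastBatchDims(std::vector<int64_t> a,
                                        std::vector<int64_t> b) {
  if (a.size() < b.size()) std::swap(a, b);
  b.insert(b.begin(), a.size() - b.size(), 1);
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    out[i] = a[i] == 1 ? b[i] : a[i];
  }
  return out;
}
```

For lhs f32[1,4,11,5] and rhs f32[3,1,5,7], the batch prefixes [1,4] and [3,1] broadcast to [3,4]; both operands are expanded to that prefix before the dot, so the result is f32[3,4,11,7].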
+ SmallVector result_batch_shape_compile_time_extents; + mlir::OpTrait::util::getBroadcastedShape( + lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2), + result_batch_shape_compile_time_extents); + auto result_batch_shape = rewriter->create( + loc, shape_type, lhs_splitted.getHead(), rhs_splitted.getHead(), + /*error=*/nullptr); + // Lambda which handles the broadcasting of one side to the common + // leading-batch dimensions. + auto broadcast_one_side = [&](Value side, RankedTensorType type, + Value tail_shape, Value *out_side) { + ArrayRef matrix_dims = type.getShape().take_back(2); + auto result_shape = result_batch_shape_compile_time_extents; + result_shape.append(matrix_dims.begin(), matrix_dims.end()); + auto result_type = tensorflow::GetTypeFromTFTensorShape( + result_shape, type.getElementType()); + auto shape = rewriter->create( + loc, shape_type, result_batch_shape, tail_shape); + auto shape_tensor = rewriter->create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(result_shape.size())}, + rewriter->getIndexType()), + shape); + *out_side = rewriter->create(loc, result_type, side, + shape_tensor); + }; + broadcast_one_side(lhs, lhs_type, lhs_splitted.getTail(), out_lhs); + broadcast_one_side(rhs, rhs_type, rhs_splitted.getTail(), out_rhs); +} + +class ConvertBatchMatMulV2Op : public OpRewritePattern { + public: + // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved + // to CHLO and it is missing legalization to MHLO. Once that is done, this + // pattern's benefit can be changed back to one as well as the fallback + // lowering pattern for the op can be removed. + // + // Set benefit of this pattern to zero to prefer the fallback pattern when + // available and applicable. That pattern avoids broadcast on operands and is + // therefore faster. + // + // Native legalization for BatchMatMulV3 needs to be added as well. + explicit ConvertBatchMatMulV2Op(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/0) {} + + LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op, + PatternRewriter &rewriter) const override { + Value lhs = op.getX(); + Value rhs = op.getY(); + auto lhs_type = mlir::dyn_cast(lhs.getType()); + auto rhs_type = mlir::dyn_cast(rhs.getType()); + if (!lhs_type || !rhs_type) return failure(); + if (mlir::isa(lhs_type.getElementType()) && op.getAdjX()) { + lhs = rewriter.create(op.getLoc(), lhs_type, lhs); + } + if (mlir::isa(rhs_type.getElementType()) && op.getAdjY()) { + rhs = rewriter.create(op.getLoc(), rhs_type, rhs); + } + + // Broadcast both operands. + BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs, + &rewriter); + lhs_type = mlir::cast(lhs.getType()); + rhs_type = mlir::cast(rhs.getType()); + assert(lhs_type.getRank() == rhs_type.getRank()); + int64_t rank = lhs_type.getRank(); + auto batch_dimensions = llvm::to_vector<4>(llvm::seq(0, rank - 2)); + auto lhs_contracting_dimensions = llvm::to_vector<4>( + llvm::ArrayRef({op.getAdjX() ? rank - 2 : rank - 1})); + auto rhs_contracting_dimensions = llvm::to_vector<4>( + llvm::ArrayRef({op.getAdjY() ? rank - 1 : rank - 2})); + auto dimension_numbers = DotDimensionNumbersAttr::get( + rewriter.getContext(), + /*lhs_batching_dimensions=*/batch_dimensions, + /*rhs_batching_dimensions=*/batch_dimensions, + /*lhs_contracting_dimensions=*/lhs_contracting_dimensions, + /*rhs_contracting_dimensions=*/rhs_contracting_dimensions); + // TODO(silvasean): Emit shape checks for contracting dimensions. 
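A condensed sketch of how the dimension numbers just above are chosen; the adj_x/adj_y transposition is folded into which of the two trailing dimensions contracts. Illustrative only:

```cpp
#include <cstdint>
#include <vector>

struct DotDims {
  std::vector<int64_t> batch;  // shared leading batch dimensions
  int64_t lhs_contracting;     // rank-2 if adj_x, else rank-1
  int64_t rhs_contracting;     // rank-1 if adj_y, else rank-2
};

DotDims ChooseDotDims(int64_t rank, bool adj_x, bool adj_y) {
  DotDims d;
  for (int64_t i = 0; i < rank - 2; ++i) d.batch.push_back(i);
  d.lhs_contracting = adj_x ? rank - 2 : rank - 1;
  d.rhs_contracting = adj_y ? rank - 1 : rank - 2;
  return d;
}
```

For rank-4 operands with adj_x = adj_y = false this yields batch = {0, 1}, lhs contracting = {3}, rhs contracting = {2}.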
+ // (The batch dimensions are checked by the broadcasting logic) + rewriter.replaceOpWithNewOp( + op, op.getType(), lhs, rhs, dimension_numbers, + /*precision_config=*/GetPrecisionConfig(&rewriter), + /*algorithm=*/DotAlgorithmAttr{}); + return success(); + } +}; + // Converts the tf.Split op into a series of HLO slice ops when the tensor to be // split has fully static shape and the dimension to split is a constant. // @@ -4422,6 +4722,8 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { } }; +using ConvertMaxPool2DGradOp = + ConvertMaxPoolGradOp; using ConvertMaxPool3DGradOp = ConvertMaxPoolGradOp; @@ -5221,6 +5523,50 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { } }; +// Converts the tf.SigmoidGradOp +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. +class ConvertSigmoidGradOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SigmoidGradOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value y = op.getY(); + Value dy = op.getDy(); + auto tp_y = mlir::dyn_cast(y.getType()); + auto tp_dy = mlir::dyn_cast(dy.getType()); + if (!tp_y || !tp_dy) return failure(); + + // TODO(disc): Remove this constraint once fold and canonicalization + // implemented. + if (tp_y.hasStaticShape() || tp_dy.hasStaticShape()) return failure(); + + Attribute attr; + Type elem_tp = tp_y.getElementType(); + if (elem_tp.isSignlessInteger()) { + attr = rewriter.getIntegerAttr(elem_tp, 1); + } else { + assert(mlir::isa(elem_tp)); + attr = rewriter.getFloatAttr(elem_tp, 1); + } + Value one = rewriter.create( + loc, DenseElementsAttr::get( + tensorflow::GetTypeFromTFTensorShape({}, elem_tp), attr)); + + auto v0 = rewriter.create( + loc, dy, y, hlo::getBroadcastDimensionsAttr(&rewriter, dy, y)); + auto v1 = rewriter.create( + loc, one, y, hlo::getBroadcastDimensionsAttr(&rewriter, one, y)); + auto result = rewriter.create( + loc, v0, v1, hlo::getBroadcastDimensionsAttr(&rewriter, v0, v1)); + + rewriter.replaceOp(op, result.getOperation()->getResults()); + return success(); + } +}; + // Converts TF unsorted segment reduction ops to XLA HLO scatter op. 
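Before the unsorted segment reduction description continues below, a quick numeric reference for the SigmoidGrad lowering above: for sigmoid output y and incoming gradient dy, the three broadcast ops compute dy * y * (1 - y). A host-side sketch, illustrative only:

```cpp
#include <cstddef>
#include <vector>

// dx[i] = dy[i] * y[i] * (1 - y[i]); e.g. y = 0.5, dy = 1 gives dx = 0.25.
std::vector<float> SigmoidGradRef(const std::vector<float>& y,
                                  const std::vector<float>& dy) {
  std::vector<float> dx(y.size());
  for (size_t i = 0; i < y.size(); ++i) {
    dx[i] = dy[i] * y[i] * (1.0f - y[i]);
  }
  return dx;
}
```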
//
// TF unsorted segment reduction op performs the following calculation:
@@ -6441,6 +6787,7 @@ class LowerControlFlowOp {
 }  // end namespace

 #include "tensorflow/compiler/mlir/tf2xla/transforms/generated_legalize_tf.inc"
+// LINT.IfChange
 void PopulateLegalizeTfPatterns(MLIRContext *context,
                                 RewritePatternSet *patterns) {
   populateWithGenerated(*patterns);
@@ -6450,6 +6797,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
     ConvertAnyOp,
     ConvertArgMaxOp,
     ConvertArgMinOp,
+    ConvertBatchMatMulV2Op,
     ConvertBiasAddOp,
     ConvertBroadcastToOp,
     ConvertBF16FloorDivOp,
@@ -6468,10 +6816,15 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
     ConvertDynamicExpandDimsOp,
     ConvertDynamicSqueezeOp,
     ConvertEinsumOp,
+    ConvertRFFTOp,
     ConvertIRFFTOp,
+    ConvertFusedBatchNormGradOp,
+    ConvertFusedBatchNormGradV2Op,
     ConvertFusedBatchNormGradV3Op,
+    ConvertFusedBatchNormV2Op,
     ConvertFusedBatchNormV3Op,
     ConvertInfeedDequeueTupleOp,
+    ConvertIdentityNOp,
     ConvertInplaceUpdateOp,
     ConvertLinSpaceOp,
     ConvertMaxOp,
@@ -6480,6 +6833,9 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
     ConvertAvgPool3DOp,
     ConvertAvgPool2DGradOp,
     ConvertAvgPool3DGradOp,
+    ConvertMaxPool2DOp,
+    ConvertMaxPool3DOp,
+    ConvertMaxPool2DGradOp,
     ConvertMaxPool3DGradOp,
     ConvertMeanOp,
     ConvertOneHotOp,
@@ -6519,10 +6875,14 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
     ConvertXlaSortOp,
     ConvertXlaVariadicReduceV2Op,
     ConvertXlaVariadicSortOp,
+    ConvertRollOp,
+    ConvertLeakyReluOp,
+    ConvertLeakyReluGradOp,
     ConvertSplitOpDynamic,
     ConvertSliceOpDynamic,
     ConvertTileOpDynamic,
     ConvertUnpackOpDynamic,
+    ConvertSigmoidGradOpDynamic,
     ConvertConv2DDynamic,
     ConvertPadOpDynamic,
     ConvertGatherNdOpDynamic,
@@ -6532,5 +6892,6 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
       LowerYieldOp>(context);
   // clang-format on
 }
+// LINT.ThenChange(:MlirAlwaysOps)
 }  // end namespace mhlo
 }  // end namespace mlir
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td
index 8fd138d0f204d7..46f3ebfe19104d 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td
@@ -238,6 +238,8 @@ class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
 foreach fromToBinPair = [[TF_LogicalAndOp, CHLO_BroadcastAndOp],
                          [TF_LogicalOrOp, CHLO_BroadcastOrOp],
+                         [TF_BitwiseAndOp, CHLO_BroadcastAndOp],
+                         [TF_BitwiseOrOp, CHLO_BroadcastOrOp],
                          [TF_BitwiseXorOp, CHLO_BroadcastXorOp]] in
   def : DirectLogicalBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
@@ -357,6 +359,25 @@ def LegalizeGatherV2 :
            (GetHLOAxisFromTFAxis $axis, $params),
            (GetHLOAxisFromTFAxis $batch_dims, $indices))>;

+//===----------------------------------------------------------------------===//
+// Pad op patterns.
+//===----------------------------------------------------------------------===//
+
+class SliceDenseIntElementsAttrColumn2D<string column> : NativeCodeCall<
+    "SliceDenseIntElementsAttrColumn2D($0.cast<ElementsAttr>(), " # column # " )">;
+
+class SliceDenseIntElementsAttr<string index, string axis> : NativeCodeCall<
+    "SliceDenseIntElementsAttr($0.cast<ElementsAttr>(), " # index # ", " # axis # ")">;
+
+// Interior padding attribute based on the TF padding.
+def GetInteriorPadding : NativeCodeCall < + "GetInteriorPadding($0.cast())">; + +def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), + (MHLO_PadOp $input, $c, + (SliceDenseIntElementsAttrColumn2D<"0"> $padding), + (SliceDenseIntElementsAttrColumn2D<"1"> $padding), + (GetInteriorPadding $padding))>; //===----------------------------------------------------------------------===// // Identity op patterns. @@ -723,6 +744,22 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), def : Pat<(TF_XlaReplicaIdOp), (TF_CastOp (MHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; +//===----------------------------------------------------------------------===// +// XlaGather op. +//===----------------------------------------------------------------------===// + +def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">; + +def HasValidGatherDims : Constraint>; + +def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes), + $dimension_numbers, $indices_are_sorted), + (MHLO_GatherOp $operand, $start_indices, + (ToGatherDimNumsAttr $dimension_numbers), + (CastElementsToI64Elements $slice_sizes), + $indices_are_sorted), + [(HasValidGatherDims $dimension_numbers)]>; + //===----------------------------------------------------------------------===// // XlaDotOp op. //===----------------------------------------------------------------------===// From ba785bec47cca9d9f7b444ddae0d5aa621f94e87 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 13:47:48 -0700 Subject: [PATCH 0115/1324] Add a `LiteralBase::EachCellUntilFailure()` to support early return when iterating over cells. Also use this to simplify the code in several places. PiperOrigin-RevId: 742826403 --- third_party/xla/xla/literal.h | 24 ++++++++++++++++++++++-- third_party/xla/xla/literal_test.cc | 24 ++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/literal.h b/third_party/xla/xla/literal.h index e7e4d38056a4a0..1690b36b69d994 100644 --- a/third_party/xla/xla/literal.h +++ b/third_party/xla/xla/literal.h @@ -292,6 +292,14 @@ class LiteralBase { void EachCell( absl::FunctionRef indices, NativeT value)> per_cell) const; + template + // Like the above, but allows early return. At any time in the iteration, if + // the callback returns false, the iteration will be aborted and the function + // will return false. Otherwise it will iterate over all elements and return + // true. + bool EachCellUntilFailure( + absl::FunctionRef indices, NativeT value)> + per_cell) const; // Checks whether all of this literal's values are equal to the given scalar // literal. 
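A usage sketch for the new early-return API declared above; the include path and namespace are assumed from the surrounding diff:

```cpp
#include <cstdint>

#include "absl/types/span.h"
#include "xla/literal.h"  // assumed header for LiteralBase

// Returns true iff every cell of an int32 literal is non-negative. Returning
// false from the callback aborts the scan early, unlike EachCell(), which
// always visits every element.
bool AllNonNegative(const xla::LiteralBase& literal) {
  return literal.EachCellUntilFailure<int32_t>(
      [](absl::Span<const int64_t> /*indices*/, int32_t value) {
        return value >= 0;
      });
}
```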
@@ -1960,10 +1968,21 @@ template <typename NativeT>
 TF_ATTRIBUTE_NOINLINE void LiteralBase::EachCell(
     absl::FunctionRef<void(absl::Span<const int64_t> indices, NativeT value)>
         per_cell) const {
+  EachCellUntilFailure<NativeT>(
+      [=](absl::Span<const int64_t> indices, NativeT value) {
+        per_cell(indices, value);
+        return true;
+      });
+}
+
+template <typename NativeT>
+TF_ATTRIBUTE_NOINLINE bool LiteralBase::EachCellUntilFailure(
+    absl::FunctionRef<bool(absl::Span<const int64_t> indices, NativeT value)>
+        per_cell) const {
   CHECK(LayoutUtil::IsDenseArray(shape()))
       << __func__ << " is only supported for dense arrays: " << shape();
   if (ShapeUtil::IsZeroElementArray(shape())) {
-    return;
+    return true;
   }

   std::vector<int64_t> indices(shape().dimensions().size(), 0);
@@ -1972,8 +1991,9 @@ TF_ATTRIBUTE_NOINLINE void LiteralBase::EachCell(
     shape_dynamic.set_dimensions(i, GetDynamicSize(i));
   }
   do {
-    per_cell(indices, Get<NativeT>(indices));
+    if (!per_cell(indices, Get<NativeT>(indices))) return false;
   } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
+  return true;
 }

 template
diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc
index 326e9975d57e1e..e7e49d1b7680ff 100644
--- a/third_party/xla/xla/literal_test.cc
+++ b/third_party/xla/xla/literal_test.cc
@@ -971,6 +971,30 @@ TEST_F(LiteralUtilTest, TransposeR0) {
   EXPECT_EQ(original, reshape);
 }

+TEST_F(LiteralUtilTest, EachCellUntilFailureAbortsOnFailure) {
+  auto original = LiteralUtil::CreateR1<int32_t>({1, 2, -2, 4, 5});
+  int count = 0;
+  auto check_positive = [&](absl::Span<const int64_t> indices,
+                            int32_t value) -> bool {
+    count++;
+    return value > 0;
+  };
+  EXPECT_FALSE(original.EachCellUntilFailure<int32_t>(check_positive));
+  EXPECT_EQ(count, 3);
+}
+
+TEST_F(LiteralUtilTest, EachCellUntilFailureGoesThroughAllCells) {
+  auto original = LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4, 5});
+  int count = 0;
+  auto check_positive = [&](absl::Span<const int64_t> indices,
+                            int32_t value) -> bool {
+    count++;
+    return value > 0;
+  };
+  EXPECT_TRUE(original.EachCellUntilFailure<int32_t>(check_positive));
+  EXPECT_EQ(count, 5);
+}
+
 TEST_F(LiteralUtilTest, TransposeR4) {
   // clang-format off
   // F32[1x3x2x4]
From 1418f4bcf3a0c5f87134719559c7f2c54522d078 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 1 Apr 2025 13:48:47 -0700
Subject: [PATCH 0116/1324] Enforce that `Shape::clear_layout()` is only called
 on array shapes.

PiperOrigin-RevId: 742826699
---
 third_party/xla/xla/shape.h | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h
index 6031146c1146c3..6fe7dd906a6f6d 100644
--- a/third_party/xla/xla/shape.h
+++ b/third_party/xla/xla/shape.h
@@ -371,12 +371,7 @@ class Shape {

   // Removes the layout of the shape, if any.
   // Precondition: this is an array shape.
-  void clear_layout() {
-    // TODO(b/404276923): ensure that this is never called on non-array shapes.
-    if (auto* const state = if_array_state()) {
-      state->layout = std::nullopt;
-    }
-  }
+  void clear_layout() { array_state().layout = std::nullopt; }

   // Recursively clear all dynamic dimension of a shape, including bounded and
   // unbounded dynamic dimensions. Clearing a dynamic dimension means
From 7327a29edb31fd61f69a94ae63e53e9560a71795 Mon Sep 17 00:00:00 2001
From: Abhinav Gunjal
Date: Tue, 1 Apr 2025 13:55:25 -0700
Subject: [PATCH 0117/1324] XLA:Translate: Reorder chlo passes to match with
 the PJRT lowering path.
PiperOrigin-RevId: 742829315 --- third_party/xla/xla/hlo/translate/BUILD | 1 + .../xla/xla/hlo/translate/stablehlo.cc | 12 +++++- third_party/xla/xla/hlo/translate/tests/BUILD | 3 ++ .../xla/xla/hlo/translate/tests/chlo.mlir | 38 +++++++++++++++++++ 4 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 third_party/xla/xla/hlo/translate/tests/chlo.mlir diff --git a/third_party/xla/xla/hlo/translate/BUILD b/third_party/xla/xla/hlo/translate/BUILD index 7eb5a0c0b63c61..4a5bcf1e127909 100644 --- a/third_party/xla/xla/hlo/translate/BUILD +++ b/third_party/xla/xla/hlo/translate/BUILD @@ -114,5 +114,6 @@ cc_library( "@llvm-project//mlir:Transforms", "@llvm-project//mlir:UBDialect", "@stablehlo//:register", + "@stablehlo//:stablehlo_passes", ], ) diff --git a/third_party/xla/xla/hlo/translate/stablehlo.cc b/third_party/xla/xla/hlo/translate/stablehlo.cc index e2f758a65e48f7..b6d7490f2b6e7c 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "stablehlo/dialect/Register.h" +#include "stablehlo/transforms/Passes.h" #include "xla/debug_options_flags.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h" #include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" @@ -67,9 +68,16 @@ absl::Status MhloToStablehlo(mlir::ModuleOp module) { absl::Status StablehloToMhlo(mlir::ModuleOp module, bool run_canonicalizer) { mlir::MLIRContext* context = module->getContext(); mlir::PassManager pm(context); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + // CHLO -> MHLO for high level ops (TopK, Erf, RaggedDot, etc.) + // CHLO -> StableHLO otherwise + pm.addNestedPass( + mlir::stablehlo_ext::createChloRecomposeOpsPass()); + pm.addPass(mlir::createSymbolDCEPass()); pm.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass()); + mlir::mhlo::createChloLegalizeToHighLevelMhloPass()); + pm.addNestedPass( + mlir::stablehlo::createChloLegalizeToStablehloPass()); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); if (run_canonicalizer) { pm.addNestedPass(mlir::createCanonicalizerPass()); } diff --git a/third_party/xla/xla/hlo/translate/tests/BUILD b/third_party/xla/xla/hlo/translate/tests/BUILD index 4c8c5068c6ebb6..68bb7fc8317101 100644 --- a/third_party/xla/xla/hlo/translate/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/tests/BUILD @@ -10,6 +10,8 @@ lit_test_suite( name = "all_tests", srcs = enforce_glob( [ + # go/keep-sorted start + "chlo.mlir", "emit_mhlo.hlo", "emit_proto.mlir", "print_large_constants.mlir", @@ -17,6 +19,7 @@ lit_test_suite( "simple.hlo", "simple.mlir", "vhlo_input.mlir", + # go/keep-sorted end ], include = [ "*.mlir", diff --git a/third_party/xla/xla/hlo/translate/tests/chlo.mlir b/third_party/xla/xla/hlo/translate/tests/chlo.mlir new file mode 100644 index 00000000000000..a20c78a6d10296 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/chlo.mlir @@ -0,0 +1,38 @@ +// RUN: hlo-translate -mlir-to-hlo -split-input-file %s | FileCheck %s + +// Validating chlo.op -> mhlo.op -> hlo.op conversion. 
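For reference, the reordered pipeline from stablehlo.cc above, condensed into one helper with the same pass names as the diff (includes and pass registration for the StableHLO extensions are assumed). Recomposition has to run before CHLO legalization so that decomposed ops such as chlo.top_k are rebuilt first, then routed either to the high-level MHLO ops or through StableHLO:

```cpp
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Mirrors the ordering installed in StablehloToMhlo() in the diff above.
void BuildChloToHloPipeline(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::stablehlo_ext::createChloRecomposeOpsPass());
  pm.addPass(mlir::createSymbolDCEPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createChloLegalizeToHighLevelMhloPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::stablehlo::createChloLegalizeToStablehloPass());
  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
}
```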
+ +// CHECK-LABEL: main +func.func @main(%arg : tensor) -> tensor { + // CHECK: %[[ARG:.*]] = f16[] parameter(0) + // CHECK: erf(%[[ARG]]) + %1 = "chlo.erf"(%arg) : (tensor) -> tensor + func.return %1 : tensor +} + +// ----- + +func.func @main(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { + // CHECK: %[[ARG:.*]] = f32[16,16] parameter(0) + // CHECK: (f32[16,8], s32[16,8]) topk(%[[ARG]]), k=8, largest=true + %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) + func.return %1#0, %1#1 : tensor<16x8xf32>, tensor<16x8xi32> +} + +// ----- + +func.func @main(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> { + // CHECK: ragged-dot + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [1], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [2], + lhs_ragged_dimensions = [1], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> + func.return %0 : tensor<2x11x7xf32> +} From f5b11ca4491e43a34c30e31ece924e0de75c5f9c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 13:58:41 -0700 Subject: [PATCH 0118/1324] Enforce that `Shape::tuple_shapes_size()` is only called on tuple shapes. PiperOrigin-RevId: 742830501 --- third_party/xla/xla/hlo/builder/xla_builder.cc | 2 +- .../xla/service/dynamic_dimension_inference.cc | 17 ++++++++++------- third_party/xla/xla/service/gpu/gpu_fusible.cc | 6 ++++-- .../xla/xla/service/hlo_cost_analysis.cc | 3 +++ third_party/xla/xla/service/hlo_verifier.cc | 2 +- third_party/xla/xla/shape.h | 8 +------- .../stream_executor/tpu/c_api_conversions.cc | 3 ++- .../tpu/c_api_conversions_test.cc | 15 +++++++-------- 8 files changed, 29 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index f4f94586a89eda..180d00d702afe2 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -759,7 +759,7 @@ absl::StatusOr XlaBuilder::Build( // the backend. if (remove_dynamic_dimensions) { std::function remove_dynamic_dimension = [&](Shape* shape) { - if (shape->tuple_shapes_size() != 0) { + if (shape->IsTuple()) { for (int i = 0; i < shape->tuple_shapes_size(); ++i) { remove_dynamic_dimension(shape->mutable_tuple_shapes(i)); } diff --git a/third_party/xla/xla/service/dynamic_dimension_inference.cc b/third_party/xla/xla/service/dynamic_dimension_inference.cc index b9404d2874ccb9..f92a01894420b1 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference.cc @@ -1946,16 +1946,16 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConditional( // Only look at branch_index + 1, the correct operand index for a // given branch. const int64_t operand_index = branch_index + 1; - + const Shape& operand_shape = hlo->operand(operand_index)->shape(); int operand_count = - hlo->operand(operand_index)->shape().tuple_shapes_size(); + operand_shape.IsTuple() ? operand_shape.tuple_shapes_size() : 0; // Prepare to pass dynamic dimension into the new computation and add // dynamic dimension sizes as parameters to the new tuple. 
TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand( hlo, operand_index, [&](HloInstruction*, ShapeIndex, int64_t, int64_t, HloInstruction* dynamic_size) -> absl::Status { - TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple()) + TF_RET_CHECK(operand_shape.IsTuple()) << "Only tuple typed inputs can have dynamic dimension. Please " "file a bug against XLA team."; const HloInstruction* tuple_operand = hlo->operand(operand_index); @@ -2035,7 +2035,8 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConditional( new_branch_computations.push_back(new_computation); new_operands.push_back(new_operand); } - int tuple_count = hlo->shape().tuple_shapes_size(); + int tuple_count = + hlo->shape().IsTuple() ? hlo->shape().tuple_shapes_size() : 0; // The dynamism of the output of branches can be different. // E.g., // true_branch (s32[<=4]) @@ -2113,7 +2114,8 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConditional( hlo->mutable_operand(0), new_branch_computations, new_operands)); HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix( - new_conditional, hlo->shape().tuple_shapes_size()); + new_conditional, + hlo->shape().IsTuple() ? hlo->shape().tuple_shapes_size() : 0); // Now set the dynamic dimensions of the newly created conditional. dynamic_output_mapping.ForEachElement( [&](const ShapeIndex& index, @@ -2230,11 +2232,12 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile( // (represented by a shape index as output index and an int64_t dimension // number) to output index (represented by an int64_t) is tracked for the // while instruction. - Shape original_shape = hlo->shape(); + const Shape& original_shape = hlo->shape(); ShapeTree> dynamic_output_mapping( original_shape); std::vector operands_to_add; - const int original_tuple_count = original_shape.tuple_shapes_size(); + const int original_tuple_count = + original_shape.IsTuple() ? 
original_shape.tuple_shapes_size() : 0; int operand_count = original_tuple_count; // Clean up the result shape DynamicParameterBinding binding_for_while; diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index 3644f02667612f..6cb217a2e3337c 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -842,8 +842,10 @@ bool MayPreventVectorization(const HloFusionAdaptor& fusion) { case HloOpcode::kConcatenate: return node.instruction().operand_count() > kMaxConcatArgumentsForUnrolling; - case HloOpcode::kReduce: - return node.instruction().shape().tuple_shapes_size() > 1; + case HloOpcode::kReduce: { + const Shape& shape = node.instruction().shape(); + return shape.IsTuple() && shape.tuple_shapes_size() > 1; + } default: return false; } diff --git a/third_party/xla/xla/service/hlo_cost_analysis.cc b/third_party/xla/xla/service/hlo_cost_analysis.cc index 89ba32c6221cd9..7a70d649c4b30e 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis.cc @@ -1170,6 +1170,9 @@ absl::Status HloCostAnalysis::FusionProcessOutputBytesAccessed( if (bytes_accessed != 0) { return bytes_accessed; } + if (!shape.IsTuple()) { + return bytes_accessed; + } for (int i = 0; i < shape.tuple_shapes_size(); ++i) { const Shape& subshape = shape.tuple_shapes(i); if (!subshape.IsTuple() && ShouldFilterFusionOutputIndex(fusion, {i})) { diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index f754f8e8926c50..164af80efccb19 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -835,7 +835,7 @@ absl::Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) { hlo->operands(), std::back_inserter(operand_shapes), [](const HloInstruction* operand) { return &(operand->shape()); }); std::vector context_shapes; - if (hlo->shape().tuple_shapes_size() > 2) { + if (hlo->shape().IsTuple() && hlo->shape().tuple_shapes_size() > 2) { context_shapes = std::vector(hlo->shape().tuple_shapes().begin() + 2, hlo->shape().tuple_shapes().end()); } diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 6fe7dd906a6f6d..387e335340e3db 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -310,13 +310,7 @@ class Shape { // Returns the number of top-level tuple components in this shape. // Precondition: this is a tuple shape. - int tuple_shapes_size() const { - if (const auto* const state = if_tuple_state()) { - return state->tuple_shapes.size(); - } - // TODO(b/404276923): ensure that this is never called on non-tuple shapes. - return 0; - } + int tuple_shapes_size() const { return tuple_state().tuple_shapes.size(); } // Returns the shape of the i-th tuple component. 
// Precondition: this is a tuple shape and `index` is a valid tuple component diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc index f6bf364f853da8..e0098ecf7e5d03 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc @@ -277,7 +277,8 @@ void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape) { CreateVector(xla_shape.dimensions(), &c_shape->dimensions); CreateVector(xla_shape.dynamic_dimensions(), &c_shape->dynamic_dimensions); - c_shape->ntuple_shapes = xla_shape.tuple_shapes_size(); + c_shape->ntuple_shapes = + xla_shape.IsTuple() ? xla_shape.tuple_shapes_size() : 0; if (c_shape->ntuple_shapes > 0) { c_shape->tuple_shapes = new XLA_Shape[c_shape->ntuple_shapes]; for (int i = 0; i < c_shape->ntuple_shapes; ++i) { diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc b/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc index c228a15322385c..b9465c1a46d1e7 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc @@ -218,10 +218,8 @@ TEST(XlaShape, ToCScalar) { MakeSpan(c_shape.dynamic_dimensions); EXPECT_EQ(cpp_dynamic_dimensions, c_dynamic_dimensions); - int cpp_ntuple_shapes = cpp_shape.tuple_shapes_size(); - int c_ntuple_shapes = c_shape.ntuple_shapes; - EXPECT_EQ(cpp_ntuple_shapes, c_ntuple_shapes); - EXPECT_EQ(cpp_ntuple_shapes, 0); + EXPECT_FALSE(cpp_shape.IsTuple()); + EXPECT_EQ(c_shape.ntuple_shapes, 0); bool cpp_has_layout = cpp_shape.has_layout(); bool c_has_layout = c_shape.has_layout; @@ -231,7 +229,8 @@ TEST(XlaShape, ToCScalar) { } TEST(XlaShape, ToCNested) { - xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); + const xla::Shape cpp_shape = + xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); XLA_Shape c_shape; ToC(cpp_shape, &c_shape); @@ -247,10 +246,10 @@ TEST(XlaShape, ToCNested) { MakeSpan(c_shape.dynamic_dimensions); EXPECT_EQ(cpp_dynamic_dimensions, c_dynamic_dimensions); - int cpp_ntuple_shapes = cpp_shape.tuple_shapes_size(); - int c_ntuple_shapes = c_shape.ntuple_shapes; - EXPECT_EQ(cpp_ntuple_shapes, c_ntuple_shapes); + EXPECT_FALSE(cpp_shape.IsTuple()); + EXPECT_EQ(c_shape.ntuple_shapes, 0); + const int c_ntuple_shapes = c_shape.ntuple_shapes; const std::vector& cpp_tuple_shapes = cpp_shape.tuple_shapes(); absl::Span c_tuple_shapes(c_shape.tuple_shapes, c_ntuple_shapes); From f7b45ffd95450609a3e96429ce9cb3bd4955c85f Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Tue, 1 Apr 2025 14:57:08 -0700 Subject: [PATCH 0119/1324] PinnedHostMemory Support in TfrtGpuClient transfer API PiperOrigin-RevId: 742851837 --- third_party/xla/xla/pjrt/gpu/tfrt/BUILD | 1 + .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 127 ++++++++++++----- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h | 4 +- .../xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc | 132 +++++++++++++++++- 4 files changed, 226 insertions(+), 38 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD index 3f863f27d1dfc3..e2af3a2459ae1e 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD +++ b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD @@ -30,6 +30,7 @@ cc_library( "//xla:shape_layout", "//xla:shape_tree", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:executable_build_options", diff --git 
a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc index 65b8ffec8b34f7..0707ad64cc8b4a 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc @@ -53,6 +53,7 @@ limitations under the License. #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/pjrt/gpu/gpu_helpers.h" #include "xla/pjrt/gpu/gpu_topology.h" @@ -88,6 +89,7 @@ limitations under the License. #include "xla/shape_layout.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" @@ -117,27 +119,61 @@ limitations under the License. namespace xla { namespace { +absl::StatusOr GetDestinationDeviceShape(const Shape& on_host_shape, + TfrtGpuDevice* device, + TfrtGpuClient* client, + PjRtMemorySpace* memory_space) { + PjRtMemorySpace* default_memory_space = + device->default_memory_space().value_or(nullptr); + if (!memory_space) { + memory_space = default_memory_space; + } + bool is_pinned_host_memory = + memory_space && (memory_space->kind() == PinnedHostMemorySpace::kKind); + // Only allow pinned host memory or device memory. + if (memory_space != default_memory_space && !is_pinned_host_memory) { + return InvalidArgument("Buffer allocation: invalid memory space"); + } + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); + TransferManager* transfer_manager = + client->xla_client()->backend().transfer_manager(); + auto memory_space_shape_fn = [is_pinned_host_memory, + transfer_manager](const Shape& on_host_shape) { + Shape result = transfer_manager->HostShapeToDeviceShape(on_host_shape); + if (is_pinned_host_memory) { + result.mutable_layout()->set_memory_space(Layout::kHostMemorySpace); + } + return result; + }; + Shape on_device_shape = memory_space_shape_fn(on_host_shape); + TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); + return on_device_shape; +} + absl::StatusOr> AllocateTfrtGpuDestinationBuffer( - const Shape& on_device_shape, + const Shape& on_host_shape, absl::InlinedVector, 4> definition_events, TfrtGpuDevice* device, TfrtGpuClient* client, PjRtMemorySpace* memory_space) { - if (on_device_shape.IsTuple()) { + if (on_host_shape.IsTuple()) { return Unimplemented( "tuple case not implemented for AllocateTfrtGpuDestinationBuffer"); } - size_t byte_size = ShapeUtil::ByteSizeOf(on_device_shape); - TF_ASSIGN_OR_RETURN( - auto device_buffer, - MaybeOwningGpuMemory::AllocateShared(device->allocator(), byte_size)); - auto buffer_async_value_ref = - tsl::MakeAvailableAsyncValueRef( - std::move(device_buffer)); - return std::make_unique( - on_device_shape, - std::make_unique( - buffer_async_value_ref, std::move(definition_events)), - client, device, memory_space); + TF_ASSIGN_OR_RETURN( + Shape on_device_shape, + GetDestinationDeviceShape(on_host_shape, device, client, memory_space)); + size_t byte_size = ShapeUtil::ByteSizeOf(on_device_shape); + TF_ASSIGN_OR_RETURN(auto device_buffer, MaybeOwningGpuMemory::AllocateShared( + device->allocator(), byte_size)); + auto buffer_async_value_ref = + tsl::MakeAvailableAsyncValueRef( + std::move(device_buffer)); + return std::make_unique( + on_device_shape, + std::make_unique( + buffer_async_value_ref, std::move(definition_events)), + client, 
device, memory_space); } void EnqueueWork(tsl::thread::ThreadPool* pool, @@ -262,23 +298,18 @@ class TfrtGpuAsyncHostToDeviceTransferManager final .transfer_manager() ->ChooseCompactLayoutForShape(device_shape)); } - int64_t byte_size = ShapeUtil::ByteSizeOf(device_shape); - - buffer_ptrs.push_back( - tsl::MakeUnconstructedAsyncValueRef()); - absl::StatusOr buffer_allocated = - MaybeOwningGpuMemory::AllocateShared(device->allocator(), byte_size); - if (!buffer_allocated.ok()) { - copy_event.SetError(buffer_allocated.status()); + absl::StatusOr> buffer = + AllocateTfrtGpuDestinationBuffer(device_shape, + definition_events.back(), device, + client, memory_space); + if (!buffer.ok()) { + copy_event.SetError(buffer.status()); return absl::InternalError("Failed to allocate buffer."); } else { - buffer_ptrs.back().emplace(std::move(buffer_allocated.value())); + buffer_ptrs.push_back(buffer->get()->GetBufferPtr()); } - auto tracked_device_buffer = std::make_unique( - buffer_ptrs.back(), definition_events.back()); - buffers.push_back(std::make_unique( - device_shape, std::move(tracked_device_buffer), client, device, - memory_space)); + + buffers.push_back(std::move(*buffer)); } return std::make_unique( @@ -1057,6 +1088,12 @@ TfrtGpuClient::TfrtGpuClient( memory_spaces_.push_back(pinned.get()); owned_memory_spaces_.push_back(std::move(pinned)); } + // We don't promise anything about the order of memory spaces, but this + // sorting is done for consistency with the device list that's sorted above. + absl::c_sort(memory_spaces_, + [](const PjRtMemorySpace* a, const PjRtMemorySpace* b) { + return a->id() < b->id(); + }); LOG(INFO) << "TfrtGpuClient created."; } @@ -1240,8 +1277,12 @@ TfrtGpuClient::CreateUninitializedBuffer(const Shape& shape, << shape.DebugString() << " memory_space: " << memory_space->DebugString(); } + TransferManager* transfer_manager = + xla_client()->backend().transfer_manager(); + TF_ASSIGN_OR_RETURN(Shape compact_shape, + transfer_manager->ChooseCompactLayoutForShape(shape)); return AllocateTfrtGpuDestinationBuffer( - shape, /*definition_events=*/{}, + compact_shape, /*definition_events=*/{}, tsl::down_cast(memory_space->devices()[0]), this, memory_space); } @@ -1482,6 +1523,7 @@ absl::StatusOr> TfrtGpuClient::BufferFromHostBuffer( ShapeUtil::ByteStrides(device_shape, absl::MakeSpan(tmp_strides))); byte_strides = tmp_strides; } + int64_t byte_size = ShapeUtil::ByteSizeOf(device_shape); TransferManager* transfer_manager = xla_client_->backend().transfer_manager(); @@ -1517,6 +1559,11 @@ absl::StatusOr> TfrtGpuClient::BufferFromHostBuffer( auto* gpu_device = tsl::down_cast(device); + TF_ASSIGN_OR_RETURN( + Shape destination_device_shape, + GetDestinationDeviceShape(device_shape, gpu_device, this, memory_space)); + byte_size = ShapeUtil::ByteSizeOf(destination_device_shape); + auto gpu_buffer = tsl::MakeUnconstructedAsyncValueRef(); absl::InlinedVector, 4> definition_events; tsl::AsyncValueRef copy_event = @@ -1876,6 +1923,17 @@ PjRtFuture<> TfrtGpuBuffer::GetReadyFuture() { }); } +bool TfrtGpuBuffer::IsOnCpu() const { + return memory_space() != nullptr && + memory_space()->kind() == PinnedHostMemorySpace::kKind; +} + +const tsl::AsyncValueRef& TfrtGpuBuffer::GetBufferPtr() + const { + absl::MutexLock lock(&mu_); + return tracked_device_buffer_->buffer(); +} + absl::StatusOr> TfrtGpuBuffer::AcquireExternalReference() { class ScopedExternalReference : public PjRtBuffer::ExternalReference { @@ -2334,19 +2392,18 @@ TfrtGpuExecutable::TfrtGpuExecutable( 
addressable_device_logical_ids_( std::move(addressable_device_logical_ids)), addressable_devices_(std::move(addressable_devices)) { - executables_.reserve(executables.size()); + TransferManager* transfer_manager = + client_->xla_client()->backend().transfer_manager(); tsl::Fprint128 fingerprint = tsl::Fingerprint128(fingerprint_); + executables_.reserve(executables.size()); for (auto& executable : executables) { const auto& computation_layout = executable->executable()->module().entry_computation_layout(); std::vector parameter_shapes; parameter_shapes.reserve(computation_layout.parameter_count()); for (int i = 0; i < computation_layout.parameter_count(); ++i) { - // TODO: b/400541410 - Convert to device shape when we have transfer - // manager. - // parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape( - // computation_layout.parameter_shape(i))); - parameter_shapes.push_back(computation_layout.parameter_shape(i)); + parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape( + computation_layout.parameter_shape(i))); } on_device_executable_parameter_shapes_.push_back( std::make_shared>(std::move(parameter_shapes))); diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h index 2213df8b6603a5..32f8dd2766d0b8 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h @@ -473,7 +473,9 @@ class TfrtGpuBuffer final : public PjRtBuffer { PjRtFuture<> GetReadyFuture() override; - bool IsOnCpu() const override { return false; } + bool IsOnCpu() const override; + + const tsl::AsyncValueRef& GetBufferPtr() const; private: // Acquires the device buffer for shared read-only usages, and it also adds diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc index 5af05acbd64c5f..24ba9815926254 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc @@ -80,7 +80,10 @@ class DonationTransactionPeer { namespace { using ::testing::ElementsAreArray; +using ::testing::Eq; +using ::testing::Gt; using ::testing::HasSubstr; +using ::testing::SizeIs; using ::testing::status::IsOkAndHolds; using ::testing::status::StatusIs; @@ -347,6 +350,60 @@ TEST(TfrtGpuClientTest, ShouldStageHostToDeviceTransfersSetToFalse) { LiteralTestUtil::Equal(*literal, LiteralUtil::CreateR1(data))); } +TEST(TfrtGpuClientTest, BufferFromHostBufferPinnedMemory) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + std::vector data{1, 2, 3, 4}; + Shape shape = ShapeUtil::MakeShape(S32, {4}); + TF_ASSERT_OK_AND_ASSIGN( + auto* pinned_memory_space, + client->addressable_devices()[0]->memory_space_by_kind( + PinnedHostMemorySpace::kKind)); + TF_ASSERT_OK_AND_ASSIGN( + auto buffer, + client->BufferFromHostBuffer( + data.data(), shape.element_type(), shape.dimensions(), + /*byte_strides=*/std::nullopt, + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, + pinned_memory_space, /*device_layout=*/nullptr)); + + EXPECT_EQ(buffer->memory_space()->kind(), "pinned_host"); + EXPECT_TRUE(buffer->IsOnCpu()); + + TF_ASSERT_OK_AND_ASSIGN(auto literal, buffer->ToLiteralSync()); + std::vector expected{1, 2, 3, 4}; + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1(expected), + *literal)); +} + +TEST(TfrtGpuClientTest, CopyToPinnedHostMemorySpace) { + TF_ASSERT_OK_AND_ASSIGN(auto client, 
GetTfrtGpuClient(GpuClientOptions())); + std::vector data{1, 2, 3, 4}; + Shape shape = ShapeUtil::MakeShape(S32, {4}); + auto device = client->addressable_devices()[0]; + TF_ASSERT_OK_AND_ASSIGN( + auto buffer, + client->BufferFromHostBuffer( + data.data(), shape.element_type(), shape.dimensions(), + /*byte_strides=*/std::nullopt, + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, + *device->default_memory_space(), /*device_layout=*/nullptr)); + + EXPECT_EQ(buffer->memory_space()->kind(), "device"); + + auto* pinned_memory_space = device->memory_spaces()[1]; + EXPECT_EQ(pinned_memory_space->kind_id(), PinnedHostMemorySpace::kKindId); + TF_ASSERT_OK_AND_ASSIGN(auto result, + buffer->CopyToMemorySpace(pinned_memory_space)); + + EXPECT_EQ(result->memory_space()->kind(), "pinned_host"); + EXPECT_TRUE(result->IsOnCpu()); + + TF_ASSERT_OK_AND_ASSIGN(auto literal, result->ToLiteralSync()); + std::vector expected{1, 2, 3, 4}; + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1(expected), + *literal)); +} + TEST(TfrtGpuClientTest, ToLiteralAsync) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 1); @@ -547,6 +604,77 @@ TEST(TfrtGpuClientTest, FromHostAsync) { } } +TEST(TfrtGpuClientTest, FromHostAsyncPinnedHost) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + ASSERT_GE(client->addressable_devices().size(), 1); + TF_ASSERT_OK_AND_ASSIGN( + auto* pinned_memory_space, + client->addressable_devices()[0]->memory_space_by_kind( + PinnedHostMemorySpace::kKind)); + + std::vector src_literals; + std::vector src_shapes; + for (int i = 0; i < 4; ++i) { + std::vector data(i + 1); + std::iota(data.begin(), data.end(), static_cast(i + 10)); + src_literals.emplace_back(LiteralUtil::CreateR1(data)); + src_shapes.push_back(src_literals.back().shape()); + } + TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + src_shapes, pinned_memory_space)); + std::vector> buffers; + for (int i = 0; i < src_shapes.size(); ++i) { + buffers.emplace_back(transfer_manager->RetrieveBuffer(i)); + } + + for (int i = 0; i < src_shapes.size(); ++i) { + TF_ASSERT_OK(transfer_manager->TransferRawDataToBuffer( + i, + absl::string_view(static_cast(src_literals[i].untyped_data()), + src_literals[i].size_bytes()), + [&]() {})); + } +} + +TEST(TfrtGpuClientTest, FromHostAsyncPinnedHostChunked) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetTfrtGpuClient(GpuClientOptions())); + ASSERT_THAT(client->addressable_devices(), SizeIs(Gt(0))); + TF_ASSERT_OK_AND_ASSIGN( + PjRtMemorySpace * memspace, + client->addressable_devices()[0]->memory_space_by_kind( + PinnedHostMemorySpace::kKind)); + std::vector data{1, 3, 5, 7, 11, 13, 17, 19}; + Shape shape = ShapeUtil::MakeShape(F32, {static_cast(data.size())}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr txm, + client->CreateBuffersForAsyncHostToDevice({shape}, memspace)); + std::unique_ptr buf = txm->RetrieveBuffer(0); + ASSERT_THAT(buf->GetReadyFuture().IsReady(), Eq(false)); + + absl::string_view raw_view(reinterpret_cast(data.data()), + data.size() * sizeof(data[0])); + int offset = 0; + while (true) { + int end = offset + 3; // unaligned chunk size + if (end > raw_view.size()) { + end = raw_view.size(); + } + int sz = end - offset; + bool reaches_end = end == raw_view.size(); + TF_ASSERT_OK(txm->TransferRawDataToSubBuffer( + /*buffer_index=*/0, raw_view.data() + offset, offset, sz, reaches_end, + /*on_done=*/[]() 
{})); + if (reaches_end) { + break; + } + offset = end; + } + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr lit, buf->ToLiteralSync()); + EXPECT_THAT(lit->data(), ElementsAreArray(data)); +} + TEST(TfrtGpuClientTest, CreateMixOfErrorBuffers) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, GetTfrtGpuClient(GpuClientOptions())); @@ -711,8 +839,8 @@ TEST(TfrtGpuClientTest, CopyRawToHostFuture) { auto ready = buffer->GetReadyFuture(); auto result = buffer->CopyRawToHostFuture(dst_future, 0, size); - // Drop the buffer before fulfilling `dst`. The transfer should still keep the - // buffer alive. + // Drop the buffer before fulfilling `dst`. The transfer should still keep + // the buffer alive. buffer.reset(); ready.OnReady([dst_promise = std::move(dst_promise), size](absl::Status status) mutable { From 01db53f6a5cefcfa9617f2b5bc616ef32a2690a0 Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Tue, 1 Apr 2025 16:09:52 -0700 Subject: [PATCH 0120/1324] Let `non_blocking_thread_pool` handle `on_done` in `AsyncHostToDeviceTransferManager` PiperOrigin-RevId: 742875944 --- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 20 +++++++++------- .../xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc | 24 +++++++++++++++++++ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc index 0707ad64cc8b4a..b78f462e460cc0 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc @@ -177,13 +177,14 @@ absl::StatusOr> AllocateTfrtGpuDestinationBuffer( } void EnqueueWork(tsl::thread::ThreadPool* pool, - absl::AnyInvocable callee) { + absl::AnyInvocable callee) { // TSL TheadPool expects std::function that must be copyable, so we are // forced to do a little bit of manual memory management here. - pool->Schedule([ptr = new absl::AnyInvocable(std::move(callee))]() { - (*ptr)(); - delete ptr; - }); + pool->Schedule( + [ptr = new absl::AnyInvocable(std::move(callee))]() { + std::move (*ptr)(); + delete ptr; + }); } // Enqueue to a thread pool when all `values` are ready. @@ -504,7 +505,7 @@ class TfrtGpuAsyncHostToDeviceTransferManager final // called on this thread, to avoid deadlock. l.Release(); - absl::AnyInvocable copy_to_gpu = + auto copy_to_gpu = [transfer_size, staging_buffer = std::move(staging_buffer), data, sub_buffer = std::move(sub_buffer), buffer_index, is_last_transfer, on_done = std::move(on_done), this]() mutable { @@ -572,7 +573,7 @@ class TfrtGpuAsyncHostToDeviceTransferManager final } // Call on_done after finishing all housekeeping and releasing the lock. - std::move(on_done)(); + EnqueueWork(client_->non_blocking_thread_pool(), std::move(on_done)); } absl::Mutex mu_; @@ -580,7 +581,10 @@ class TfrtGpuAsyncHostToDeviceTransferManager final // Retrieve. absl::InlinedVector, 4> buffers_; - // Just a single thread, to ensure transfers are ordered. + // Just a single thread, to ensure transfers are ordered. Its lifetime is + // managed by H2DTransferManager. We assume `h2d_thread` is destructed before + // `client_`, so `on_done` callbacks on `h2d_thread` will be handled by + // threads managed by `client_`. 
std::unique_ptr h2d_thread_; absl::InlinedVector, 4> buffer_ptrs_; diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc index 24ba9815926254..7e5fd24f32dcb2 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "xla/hlo/builder/xla_computation.h" @@ -956,5 +957,28 @@ TEST(TfrtGpuClientTest, AsyncCopyToDevice) { literal->Relayout(src_literal.shape().layout()).data()); } +TEST(TfrtGpuClientTest, OnDoneSafelyDestructTransferManagerAsync) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + ASSERT_GE(client->addressable_devices().size(), 1); + PjRtDevice* const device = client->addressable_devices()[0]; + + auto src_literal = LiteralUtil::CreateR1({41.0f, 42.0f, 43.0f, 44.0f}); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr + transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + {src_literal.shape()}, *device->default_memory_space())); + std::unique_ptr buffer = transfer_manager->RetrieveBuffer(0); + absl::Notification done; + EXPECT_OK(transfer_manager->TransferLiteralToBuffer( + 0, src_literal, + /*on_done=*/ + [&done, transfer_manager = std::move(transfer_manager)]() { + done.Notify(); + })); + done.WaitForNotification(); +} + } // namespace } // namespace xla From d09ca2dc7f578258b18f21c796c0131d50c39cb9 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 1 Apr 2025 16:13:03 -0700 Subject: [PATCH 0121/1324] Allow raising of all_reduce with multiple args Fixes: https://github.com/pytorch/xla/issues/8854 PiperOrigin-RevId: 742877085 --- .../hlo_legalize_to_stablehlo.cc | 5 ---- ...lo-legalize-to-stablehlo-experimental.mlir | 23 --------------- .../mhlo/hlo-legalize-to-stablehlo.mlir | 29 +++++++++++++++++++ 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 7b4f0458f95210..372970d2b47e25 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -87,11 +87,6 @@ bool hasPrivateFeaturesNotInStablehlo(HloOpTy hloOp) { // for StableHLO, and they are usually accompanied by a StableHLO GitHub ticket. template bool hasExperimentalFeaturesNotInStablehlo(HloOpTy hloOp) { - if constexpr (std::is_same::value) { - // StableHLO AllReduce doesn't support the tuple form yet. - // Proposal: https://github.com/openxla/stablehlo/issues/1370. - if (hloOp.getNumOperands() != 1) return true; - } if constexpr (std::is_same::value) { // StableHLO AllToAll doesn't support the tuple form yet. // Proposal: https://github.com/openxla/stablehlo/issues/574. 
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir index 38056bdb3d0f67..3ace2e8f8a79f8 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir @@ -4,29 +4,6 @@ // This test file runs both FileCheck and diagnostic check. These tests all // error when the experimental flag is disabled, and pass when it is enabled. -// CHECK-LABEL: "op_all_reduce_tuple" -func.func @op_all_reduce_tuple(%arg0: tensor<8xf32>, %arg1: tensor) -> (tensor<8xf32>, tensor) { - // CHECK: "stablehlo.custom_call"(%[[ARG0:.*]], %[[ARG1:.*]]) <{ - // CHECK-SAME: call_target_name = "mhlo.all_reduce", called_computations = [@all_reduce] - // CHECK-SAME: }> { - // CHECK-SAME{LITERAL}: mhlo.attributes = {replica_groups = dense<> : tensor<0x0xi64>} - // CHECK-SAME: } : (tensor<8xf32>, tensor) -> (tensor<8xf32>, tensor) - // CHECK: func.func - // CHECK-SAME: sym_name = "all_reduce" - // CHECK: ^bb0(%[[REDUCE_ARG0:.*]]: tensor, %[[REDUCE_ARG1:.*]]: tensor): - // CHECK-NEXT: %[[ADD:.*]] = "stablehlo.add"(%[[REDUCE_ARG0]], %[[REDUCE_ARG1]]) : (tensor, tensor) -> tensor - // CHECK-NEXT: "stablehlo.return"(%[[ADD]]) : (tensor) -> () - // expected-error@+1 {{failed to legalize operation 'mhlo.all_reduce' that was explicitly marked illegal}} - %0:2 = "mhlo.all_reduce"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %2 = mhlo.add %arg2, %arg3 : tensor - mhlo.return %2 : tensor - }) {replica_groups = dense<> : tensor<0x0xi64>} : (tensor<8xf32>, tensor) -> (tensor<8xf32>, tensor) - return %0#0, %0#1 : tensor<8xf32>, tensor -} - -// ----- - // CHECK-LABEL: "op_all_to_all_tuple" func.func @op_all_to_all_tuple(%arg0: tensor<128x4xf32>, %arg1: tensor<128x4xf32>) -> (tensor<128x4xf32>, tensor<128x4xf32>) { // CHECK: "stablehlo.custom_call"(%arg0, %arg1) <{call_target_name = "mhlo.all_to_all"}> { diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 4a2c82e90763a7..2196a2190d2d48 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -466,6 +466,19 @@ func.func @op_all_reduce(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "op_all_reduce_tuple" +func.func @op_all_reduce_tuple(%arg0: tensor<8xf32>, %arg1: tensor) -> (tensor<8xf32>, tensor) { + // CHECK: %[[RES:.*]]:2 = "stablehlo.all_reduce"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) + // CHECK-SAME: <{replica_groups = dense<> : tensor<0x0xi64>}> + // CHECK: "func.return"(%[[RES]]#0, %[[RES]]#1) : (tensor<8xf32>, tensor) + %0:2 = "mhlo.all_reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = mhlo.add %arg2, %arg3 : tensor + mhlo.return %2 : tensor + }) {replica_groups = dense<> : tensor<0x0xi64>} : (tensor<8xf32>, tensor) -> (tensor<8xf32>, tensor) + return %0#0, %0#1 : tensor<8xf32>, tensor +} + // CHECK-LABEL: "op_all_to_all" func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { // CHECK: "stablehlo.all_to_all"([[ARG0:%arg[0-9]+]]) <{ @@ -485,6 +498,22 @@ func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { func.return %0 : tensor<16x4xf32> } +// CHECK-LABEL: 
"op_all_to_all_tuple" +func.func @op_all_to_all_tuple(%arg0: tensor<2x4xi64>, %arg1: tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>) { + // CHECK: %[[RES:.*]]:2 = "stablehlo.all_to_all"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) + // CHECK-SAME: <{concat_dimension = 0 : i64, + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + // CHECK-SAME: split_count = 2 : i64, split_dimension = 1 : i64}> + // CHECK-SAME: : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>) + %0:2 = "stablehlo.all_to_all"(%arg0, %arg1) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 2 : i64, + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>) + return %0#0, %0#1 : tensor<4x2xi64>, tensor<4x2xi64> +} + // CHECK-LABEL: "op_and" func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "stablehlo.and"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor From a56dbb448958b7b8cc7d763005207f77dc1e0047 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Tue, 1 Apr 2025 16:30:28 -0700 Subject: [PATCH 0122/1324] Remove extra call to `StablehloToMhlo`. `ConvertStablehloToHloProtoInternal` is already calling it. Also, merged anonymous namespaces. PiperOrigin-RevId: 742882212 --- .../xla/xla/hlo/translate/stablehlo.cc | 68 +++++++++---------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/stablehlo.cc b/third_party/xla/xla/hlo/translate/stablehlo.cc index b6d7490f2b6e7c..d62fcee25dabd8 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo.cc @@ -98,41 +98,6 @@ absl::Status StablehloToMhlo(mlir::ModuleOp module, bool run_canonicalizer) { return absl::OkStatus(); } -} // namespace - -void RegisterMlirToHloDependentDialects(mlir::DialectRegistry& registry) { - mlir::stablehlo::registerAllDialects(registry); - mlir::func::registerAllExtensions(registry); - mlir::mhlo::registerAllMhloDialects(registry); - registry.insert(); -} - -absl::StatusOr> ConvertHloToStablehlo( - mlir::MLIRContext& ctx, const xla::HloModule* hlo_module) { - mlir::OwningOpRef mlir_module = - llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); - TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), - /*import_all_computation=*/true, - /*flatten_computation_args_result=*/true) - .Import(*hlo_module)); - TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); - return std::move(mlir_module); -} - -absl::StatusOr> ConvertHloToStablehlo( - mlir::MLIRContext& ctx, const xla::HloModuleProto* hlo_module_proto) { - mlir::OwningOpRef mlir_module = - llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); - TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), - /*import_all_computation=*/true, - /*flatten_computation_args_result=*/true) - .Import(*hlo_module_proto)); - TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); - return std::move(mlir_module); -} - -namespace { absl::Status ConvertStablehloToHloProtoInternal(mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args, @@ -171,6 +136,38 @@ absl::StatusOr> ConvertStablehloToHloInternal( } // namespace +void RegisterMlirToHloDependentDialects(mlir::DialectRegistry& registry) { + mlir::stablehlo::registerAllDialects(registry); + mlir::func::registerAllExtensions(registry); + mlir::mhlo::registerAllMhloDialects(registry); + registry.insert(); +} + +absl::StatusOr> ConvertHloToStablehlo( + 
    mlir::MLIRContext& ctx, const xla::HloModule* hlo_module) {
+  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
+      llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx));
+  TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(),
+                                       /*import_all_computation=*/true,
+                                       /*flatten_computation_args_result=*/true)
+                         .Import(*hlo_module));
+  TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get()));
+  return std::move(mlir_module);
+}
+
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToStablehlo(
+    mlir::MLIRContext& ctx, const xla::HloModuleProto* hlo_module_proto) {
+  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
+      llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx));
+  TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(),
+                                       /*import_all_computation=*/true,
+                                       /*flatten_computation_args_result=*/true)
+                         .Import(*hlo_module_proto));
+  TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get()));
+  return std::move(mlir_module);
+}
+
 absl::StatusOr<std::unique_ptr<HloModule>> ConvertStablehloToHlo(
     mlir::ModuleOp module) {
   return ConvertStablehloToHloInternal(module,
@@ -188,7 +185,6 @@ absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module,
                                         xla::HloProto* hlo_proto) {
   if (!module) return absl::InvalidArgumentError("Module is null");
 
-  TF_RETURN_IF_ERROR(StablehloToMhlo(module, /*run_canonicalizer=*/true));
   return ConvertStablehloToHloProtoInternal(module, hlo_proto,
                                             /*use_tuple_args=*/false,
                                             /*return_tuple=*/false,

From 9063d1d53bca1e75beb5b863a618874c114dfdd2 Mon Sep 17 00:00:00 2001
From: Niklas Vangerow
Date: Tue, 1 Apr 2025 16:37:41 -0700
Subject: [PATCH 0123/1324] Add `CreateR{2,3}Parameter` and
 `CreatePatternedMatrix` functions.

`ClientLibraryTestRunnerMixin` tries to (mostly on a 1:1 basis) replicate the
interface provided by `ClientLibraryTestBase`. These functions were missing
because they hadn't yet been used by any tests that were ported to use a
`HloRunnerAgnosticTestBase`. This change adds these functions.

We will add more functions from `ClientLibraryTestBase` as the need arises.

PiperOrigin-RevId: 742884852
---
 .../tests/client_library_test_runner_mixin.h  | 40 +++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/third_party/xla/xla/tests/client_library_test_runner_mixin.h b/third_party/xla/xla/tests/client_library_test_runner_mixin.h
index 0d8d8f769dbb30..7235f93870af50 100644
--- a/third_party/xla/xla/tests/client_library_test_runner_mixin.h
+++ b/third_party/xla/xla/tests/client_library_test_runner_mixin.h
@@ -255,6 +255,26 @@ class ClientLibraryTestRunnerMixin : public T {
     return literal;
   }
 
+  template <typename NativeT>
+  Literal CreateR2Parameter(const Array2D<NativeT>& array_2d,
+                            int64_t parameter_number, const std::string& name,
+                            XlaBuilder* builder, XlaOp* data_handle) {
+    Literal literal = LiteralUtil::CreateR2FromArray2D<NativeT>(array_2d);
+    literal = MaybeConvertLiteralToTestType(literal);
+    *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
+    return literal;
+  }
+
+  template <typename NativeT>
+  Literal CreateR3Parameter(const Array3D<NativeT>& array_3d,
+                            int64_t parameter_number, const std::string& name,
+                            XlaBuilder* builder, XlaOp* data_handle) {
+    Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(array_3d);
+    literal = MaybeConvertLiteralToTestType(literal);
+    *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
+    return literal;
+  }
+
   Literal MaybeConvertLiteralToTestType(const Literal& literal) const {
     switch (test_type_) {
       case BF16:
@@ -283,6 +303,26 @@ class ClientLibraryTestRunnerMixin : public T {
     return execution_options_.mutable_debug_options();
   }
 
+  // Creates a (rows x cols) array filled in the following form:
+  //
+  //  [      0              1 ...                        cols-1]
+  //  [  1,000          1,001 ...               1000.0 + cols-1]
+  //  [    ...            ... ...                           ...]
+  //  [(rows-1)*1000.0    ... ...     (rows-1)*1000.0 + cols-1]
+  //
+  // If provided, offset is added uniformly to every element (e.g. an offset of
+  // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
+  static std::unique_ptr<Array2D<float>> CreatePatternedMatrix(
+      const int rows, const int cols, float offset = 0.0) {
+    auto array = std::make_unique<Array2D<float>>(rows, cols);
+    for (int64_t row = 0; row < rows; ++row) {
+      for (int64_t col = 0; col < cols; ++col) {
+        (*array)(row, col) = col + (row * 1000.0f) + offset;
+      }
+    }
+    return array;
+  }
+
  private:
   absl::StatusOr<std::unique_ptr<HloModule>> BuildAndVerifyHloModule(
       const XlaComputation& computation,

From 8776bd45b405bb48a60e666c965a16de5daeca1e Mon Sep 17 00:00:00 2001
From: Terry Heo
Date: Tue, 1 Apr 2025 16:45:55 -0700
Subject: [PATCH 0124/1324] litert: Add fp16 support in GPU Accelerator

Use fp16 calculations if the target GPU supports them. Adds a naive
fp16<->fp32 conversion as an initial step; it should be vectorized.

PiperOrigin-RevId: 742887265
---
 tensorflow/lite/delegates/gpu/cl/buffer.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/cl/buffer.h b/tensorflow/lite/delegates/gpu/cl/buffer.h
index 088a66aa57af2b..01d4e631247737 100644
--- a/tensorflow/lite/delegates/gpu/cl/buffer.h
+++ b/tensorflow/lite/delegates/gpu/cl/buffer.h
@@ -97,8 +97,9 @@ template <typename T>
 absl::Status Buffer::WriteData(CLCommandQueue* queue,
                                const absl::Span<T> data) {
   if (size_ != sizeof(T) * data.size()) {
-    return absl::InvalidArgumentError(
-        "absl::Span data size is different from buffer allocated size.");
+    return absl::InvalidArgumentError(absl::StrCat(
+        "absl::Span data size is different from buffer allocated size: ",
+        size_, " vs ", sizeof(T) * data.size()));
   }
   RETURN_IF_ERROR(queue->EnqueueWriteBuffer(buffer_, size_, data.data()));
   return absl::OkStatus();

From 702875c4890ae832b2ed94129206aebdd54f84b2 Mon Sep 17 00:00:00 2001
From: Niklas Vangerow
Date: Tue, 1 Apr 2025 16:48:37 -0700
Subject: [PATCH 0125/1324] Port concat_test to `HloTestBase`.
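
The porting pattern, as a minimal hedged sketch (the test body below is
illustrative only, not taken from the diff; the real changes follow): tests
stop transferring literals to the service themselves and instead hand
`Literal*` arguments straight to the `ComputeAndCompare*` helpers, which the
mixin transfers as needed.

  using ConcatTest = ClientLibraryTestRunnerMixin<HloTestBase>;

  TEST_F(ConcatTest, IllustrativeOnly) {
    XlaBuilder builder(TestName());
    Literal x = LiteralUtil::CreateR0<float>(2.f);
    auto p = Parameter(&builder, 0, x.shape(), "x");
    Add(p, p);
    // Pass the literal directly; no client_->TransferToServer() round trip.
    ComputeAndCompareR0<float>(&builder, 4.f, {&x}, ErrorSpec(1e-4));
  }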
PiperOrigin-RevId: 742888086 --- third_party/xla/xla/tests/BUILD | 11 ++++---- third_party/xla/xla/tests/concat_test.cc | 33 ++++++++++-------------- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 18376e60fe59d6..c9294a05ba6b51 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2547,24 +2547,25 @@ xla_test( xla_test( name = "concat_test", srcs = ["concat_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_xla_cpu_no_thunks", + ], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:error_spec", "//xla:literal_util", "//xla:reference_util", - "//xla/client:local_client", + "//xla:shape_util", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/testlib:test", "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/tests/concat_test.cc b/third_party/xla/xla/tests/concat_test.cc index 7301d7bf8c9c05..6035876a30ece8 100644 --- a/third_party/xla/xla/tests/concat_test.cc +++ b/third_party/xla/xla/tests/concat_test.cc @@ -13,29 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include #include #include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/testlib/test.h" #include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/shape_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/test.h" namespace xla { namespace { -using ConcatTest = ClientLibraryTestBase; +using ConcatTest = ClientLibraryTestRunnerMixin; using ConcatTestHlo = HloTestBase; using ::testing::HasSubstr; @@ -492,7 +494,7 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { ConcatInDim(&builder, {h0, h1}, 2); - ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); + ComputeAndCompareR3(&builder, expected, {&p0, &p1}); } XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { @@ -517,8 +519,7 @@ XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { auto q = ConcatInDim(&builder, {p, p}, 0); ConcatInDim(&builder, {q, q}, 0); std::vector expected(131072, 256.0); - auto a_data = client_->TransferToServer(a_literal).value(); - ComputeAndCompareR1(&builder, expected, {a_data.get()}); + ComputeAndCompareR1(&builder, expected, {&a_literal}); } XLA_TEST_F(ConcatTestHlo, ConcatWithBitcast) { @@ -774,7 +775,7 @@ struct R2BinarySpec { }; // TEST_P harness for binary rank-2 concatenation. 
-class ConcatR2BinaryTest : public ClientLibraryTestBase,
+class ConcatR2BinaryTest : public ConcatTest,
                            public ::testing::WithParamInterface<R2BinarySpec> {
 };
 
@@ -808,8 +809,6 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
   auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
   auto x_literal = LiteralUtil::CreateR0<float>(2.f);
   auto y_literal = LiteralUtil::CreateR0<float>(3.f);
-  auto x_data = client_->TransferToServer(x_literal).value();
-  auto y_data = client_->TransferToServer(y_literal).value();
 
   XlaBuilder builder(TestName());
   auto x = Parameter(&builder, 0, f32_scalar, "x");
@@ -821,7 +820,7 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
   ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0);
 
   ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
-                             {x_data.get(), y_data.get()}, ErrorSpec(1e-4));
+                             {&x_literal, &y_literal}, ErrorSpec(1e-4));
 }
 
 // Test that the HLO optimization to replace a concat of a broadcasted scalar
@@ -831,9 +830,6 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
   auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
   auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
   auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
-  auto x_data = client_->TransferToServer(x_literal).value();
-  auto y_data = client_->TransferToServer(y_literal).value();
-  auto z_data = client_->TransferToServer(z_literal).value();
 
   XlaBuilder builder(TestName());
   auto x = Parameter(&builder, 0, x_literal.shape(), "x");
@@ -847,7 +843,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
   ComputeAndCompareR1<float>(
       &builder,
       {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f, 3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f},
-      {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4));
+      {&x_literal, &y_literal, &z_literal}, ErrorSpec(1e-4));
 }
 
 // Test that the HLO optimization to replace a concat of a broadcasted scalar
@@ -859,9 +855,6 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
   auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
   auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
   auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
-  auto x_data = client_->TransferToServer(x_literal).value();
-  auto y_data = client_->TransferToServer(y_literal).value();
-  auto z_data = client_->TransferToServer(z_literal).value();
 
   XlaBuilder builder(TestName());
   auto x = Parameter(&builder, 0, x_literal.shape(), "x");
@@ -877,7 +870,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
   auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1);
 
   ComputeAndCompareR3<float>(&builder, *concat1,
-                             {x_data.get(), y_data.get(), z_data.get()},
+                             {&x_literal, &y_literal, &z_literal},
                              ErrorSpec(1e-4));
 }
 
From dde3a6fae85aa3b2c8265a263d324b27b6303c33 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 1 Apr 2025 17:09:20 -0700
Subject: [PATCH 0126/1324] Copy insertion will elide unnecessary copies based
 on checking the live range of the current schedule.

Passes after copy-insertion might duplicate or reschedule nodes in a way that
causes live-range overlap and results in runtime corruption; this change adds
control dependencies to explicitly prevent that.
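
To make the constraint concrete with a rough sketch (not code from this
change; `last_use` and `next_def` are hypothetical `HloInstruction*`
stand-ins for a last use of the elided copy's source buffer and the next
definition of its destination buffer), the fix orders the former before the
latter:

  // Hedged sketch of the ordering this pass now records; both instructions
  // are assumed to live in the same computation.
  if (last_use->parent() == next_def->parent() && last_use != next_def) {
    TF_RETURN_IF_ERROR(last_use->AddControlDependencyTo(next_def));
  }

The actual implementation below applies this to every use in the merged
value lists of the copy's source and destination.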
PiperOrigin-RevId: 742894126 --- third_party/xla/xla/service/copy_insertion.cc | 32 +++++++++++++++++++ .../xla/xla/service/copy_insertion_test.cc | 6 ++-- .../cpu/small_while_loop_hoisting_pass.cc | 1 + .../xla/xla/service/gpu/gpu_compiler_test.cc | 2 +- 4 files changed, 38 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc index 8db7718be597ab..655d880e540048 100644 --- a/third_party/xla/xla/service/copy_insertion.cc +++ b/third_party/xla/xla/service/copy_insertion.cc @@ -1592,6 +1592,24 @@ class CopyRemover { VLOG(2) << "Region-based interference is false."; return false; }; + auto AddControlDependenciesBetween = [&](ValueNode* src, ValueNode* dst) { + if (src == nullptr || dst == nullptr) { + return; + } + for (auto use : src->uses) { + if (use->instruction->parent() != dst->value->instruction()->parent() || + use->instruction == dst->value->instruction()) { + // Don't add control dependencies if the use is in a different + // computation or if the use is the same as the destination. + continue; + } + VLOG(2) << "Adding control dependency:"; + VLOG(2) << " From: " << use->instruction->ToString(); + VLOG(2) << " To: " << dst->value->instruction()->ToString(); + CHECK_OK(use->instruction->AddControlDependencyTo( + dst->value->instruction())); + } + }; // A kCopy instruction copies an HLO value from a source buffer and // defines an HLO value in a destination buffer. Most generally, the @@ -1674,6 +1692,13 @@ class CopyRemover { kMergeFirstDestInSource)) { return false; } + // Ensure that the last uses of the copy source (e.g. s_x) are + // ordered before the next definition of the copy destination buffer + // (d_1). + AddControlDependenciesBetween(copy_node.src, Next(*copy_node.dest)); + // Also ensure that the last uses of the copy destination (e.g. d_m) are + // ordered before the next definition of the copy source buffer (s_{x+1}). + AddControlDependenciesBetween(copy_node.dest->prev, Next(*copy_node.src)); VLOG(2) << "Splice dest after source."; // Splice in destination buffer values list right after 'src'. SpliceAfter(copy_node.dest, copy_node.src); @@ -1707,6 +1732,13 @@ class CopyRemover { VLOG(2) << "Region-based analysis concludes interference."; return false; } + // Ensure that the last uses of the copy source (e.g. s_n) are + // ordered before the next definition of the copy destination buffer + // (d_{y+1}). + AddControlDependenciesBetween(Prev(*copy_node.dest), copy_node.src->next); + // Also ensure that the last uses of the copy source (e.g. s_n) are + // ordered before next definition of the copy destination (e.g. d_{y+1}). + AddControlDependenciesBetween(copy_node.src, Next(*copy_node.dest)); VLOG(2) << "Splice src after prev of dest."; // Splice source buffer values list right after 'prev_dest'. SpliceAfter(copy_node.src->next, Prev(*copy_node.dest)); diff --git a/third_party/xla/xla/service/copy_insertion_test.cc b/third_party/xla/xla/service/copy_insertion_test.cc index f26650863fc622..9071f4250386a6 100644 --- a/third_party/xla/xla/service/copy_insertion_test.cc +++ b/third_party/xla/xla/service/copy_insertion_test.cc @@ -953,7 +953,8 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) { InsertCopies(module_.get()); EXPECT_EQ(CountCopies(*body), 1); - EXPECT_EQ(CountControlEdges(*body), 0); + // Control edges exist for elided copies. 
+ EXPECT_EQ(CountControlEdges(*body), 1); EXPECT_THAT( body->root_instruction(), @@ -3527,7 +3528,8 @@ TEST_F(CopyInsertionTest, AddControlDependencyForInputOutputAlias) { /*use_region_based_live_range_analysis=*/-1); ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); EXPECT_EQ(CountCopies(*module), 1); - EXPECT_EQ(CountControlEdges(*module), 2); + // Include control edges from elided copies. + EXPECT_EQ(CountControlEdges(*module), 3); HloInstruction* add_instr = FindInstruction(module.get(), HloOpcode::kAdd); HloInstruction* mul_instr = diff --git a/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc b/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc index 767e914a075992..c599b5f3ff6425 100644 --- a/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc +++ b/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc @@ -127,6 +127,7 @@ absl::StatusOr SmallWhileLoopHoistingPass::Run( call_instruction->add_frontend_attribute("xla_cpu_small_call", "true"); TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(call_instruction)); + TF_RETURN_IF_ERROR(while_instr->SafelyDropAllControlDependencies()); TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr)); changed = true; diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 2adc6ea7c3e0e9..9f890c4076f782 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -830,7 +830,7 @@ CHECK: %[[RESULT_RECV:.*]] = recv(%[[AFTER_ALL]]) CHECK-SAME: channel_id=[[CHANNEL_ID]] CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}}, -CHECK-SAME: control-predecessors={%[[CUSTOM_CALL]]} +CHECK-SAME: control-predecessors={%[[CUSTOM_CALL:.*]]} CHECK: %[[RESULT_SEND:.*]] = send(%[[SOME_SEND_ARG:.*]], %[[AFTER_ALL]]) CHECK-SAME: channel_id=1 CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", From a98cf6fc01173bcab50308d83a7eab53b19ff4b5 Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Tue, 1 Apr 2025 17:21:32 -0700 Subject: [PATCH 0127/1324] Add dynamic gelu composite lowerings PiperOrigin-RevId: 742897623 --- .../transforms/composite_lowering_patterns.td | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index 1beff56b89b40c..656b9ec692568d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -133,6 +133,20 @@ def LegalizeCompositeGELU : Pat< (TFL_GeluOp $inputs, (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; +def LegalizeCompositeGELUDynamicShaped : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + +def LegalizeCompositeGELUDynamicShaped2 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + def LegalizeCompositeOdmlEmbeddingLookup : Pat< (MHLO_CompositeOp:$composite (variadic $indices, $table), From d61e6f81ce73869ab818f65ddfe81d1a2137ea1f Mon Sep 17 
00:00:00 2001 From: Niklas Vangerow Date: Tue, 1 Apr 2025 17:23:26 -0700 Subject: [PATCH 0128/1324] Migrate concat_test to PjRt runner. PiperOrigin-RevId: 742898141 --- third_party/xla/xla/tests/BUILD | 4 +++- third_party/xla/xla/tests/concat_test.cc | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index c9294a05ba6b51..08a7e476957a5c 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2548,11 +2548,13 @@ xla_test( name = "concat_test", srcs = ["concat_test.cc"], tags = [ + "test_migrated_to_hlo_runner_pjrt", "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", diff --git a/third_party/xla/xla/tests/concat_test.cc b/third_party/xla/xla/tests/concat_test.cc index 6035876a30ece8..7fa71f38934a7b 100644 --- a/third_party/xla/xla/tests/concat_test.cc +++ b/third_party/xla/xla/tests/concat_test.cc @@ -31,14 +31,16 @@ limitations under the License. #include "xla/reference_util.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" namespace xla { namespace { -using ConcatTest = ClientLibraryTestRunnerMixin; -using ConcatTestHlo = HloTestBase; +using ConcatTest = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; +using ConcatTestHlo = HloPjRtInterpreterReferenceMixin; using ::testing::HasSubstr; // Concatenate expects at least one argument. @@ -762,7 +764,7 @@ ENTRY jit_broken.874 { auto input_array = std::make_unique>(4, 2); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); - EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, error_spec_)); + EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, kDefaultErrorSpec)); } // Describes a binary rank-2 concatenation test. From 9ebac723d4a907c7291e7097e60bec0a18e4aec7 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Tue, 1 Apr 2025 18:11:15 -0700 Subject: [PATCH 0129/1324] Migrate scatter_test to PjRt runner. PiperOrigin-RevId: 742910507 --- third_party/xla/xla/tests/BUILD | 17 ++++++++++----- third_party/xla/xla/tests/scatter_test.cc | 26 +++++++++++++++++------ 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 08a7e476957a5c..e29b60c4c4070b 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1214,21 +1214,28 @@ xla_test( xla_test( name = "scatter_test", srcs = ["scatter_test.cc"], - tags = ["test_xla_cpu_no_thunks"], - # TODO(b/245550554): enable Pjrt runner for scatter test once it's fixed. 
+ tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":client_library_test_base", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", + ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", - "//xla:status_macros", "//xla:types", + "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test", + "//xla/service:hlo_module_config", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/tests/scatter_test.cc b/third_party/xla/xla/tests/scatter_test.cc index 3a17e9fbc5641e..1d83a247c5168b 100644 --- a/third_party/xla/xla/tests/scatter_test.cc +++ b/third_party/xla/xla/tests/scatter_test.cc @@ -13,32 +13,44 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include +#include +#include #include +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" +#include "absl/types/span.h" #include "xla/array2d.h" #include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/test.h" #include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/statusor.h" #include "xla/types.h" namespace xla { namespace { -class ScatterTest : public HloTestBase { +class ScatterTest : public HloPjRtInterpreterReferenceMixin { protected: - void RunTest(const std::string& hlo_text, Literal* operand, - Literal* scatter_indices, Literal* updates) { + void RunTest(const absl::string_view hlo_text, Literal* const operand, + Literal* const scatter_indices, Literal* const updates) { RunTest(hlo_text, {operand, scatter_indices, updates}); } - void RunTest(const std::string& hlo_text, absl::Span args) { + void RunTest(const absl::string_view hlo_text, + const absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, From 29aebb5b1065db931b1744ac4e96666d6d0f543b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 21:02:33 -0700 Subject: [PATCH 0130/1324] Automated Code Change PiperOrigin-RevId: 742956280 --- tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc | 1 + tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc | 1 - .../lite/delegates/gpu/gl/compiler/object_accessor_test.cc | 1 - .../lite/delegates/gpu/gl/compiler/variable_accessor.cc | 4 +++- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc index 985da96ebff678..7db139c4ccfa33 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include #include #include "absl/container/flat_hash_set.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc index 00a95c816e9976..f6aee5dd889678 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc @@ -27,7 +27,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/types.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc index fbca570d892f2f..0a057b14a80a2c 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc index d1a7fd78e1a87b..81b8e89f2252f0 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc @@ -15,15 +15,17 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" +#include +#include #include #include #include +#include #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/variable.h" From 976ef5831ab9c84e90887a20e3cd95856711a8fc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 21:54:23 -0700 Subject: [PATCH 0131/1324] Automated Code Change PiperOrigin-RevId: 742969440 --- tensorflow/core/kernels/rnn/BUILD | 2 ++ tensorflow/core/kernels/rnn/gru_ops.cc | 1 + tensorflow/core/kernels/rnn/lstm_ops.cc | 10 +++++++--- tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc | 2 ++ 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/rnn/BUILD b/tensorflow/core/kernels/rnn/BUILD index 3b9298c5bac42f..17b2545986a7c6 100644 --- a/tensorflow/core/kernels/rnn/BUILD +++ b/tensorflow/core/kernels/rnn/BUILD @@ -49,6 +49,8 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:eigen_helpers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@eigen_archive//:eigen3", ], ) diff --git a/tensorflow/core/kernels/rnn/gru_ops.cc b/tensorflow/core/kernels/rnn/gru_ops.cc index ed424e922a4fe3..f1722497cc81c3 100644 --- a/tensorflow/core/kernels/rnn/gru_ops.cc +++ b/tensorflow/core/kernels/rnn/gru_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/rnn/gru_ops.h" diff --git a/tensorflow/core/kernels/rnn/lstm_ops.cc b/tensorflow/core/kernels/rnn/lstm_ops.cc index 5bf12c3b56cd62..8fb0dcfd9ce645 100644 --- a/tensorflow/core/kernels/rnn/lstm_ops.cc +++ b/tensorflow/core/kernels/rnn/lstm_ops.cc @@ -13,15 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" #define EIGEN_USE_THREADS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "tensorflow/core/kernels/rnn/lstm_ops.h" - -#include #include #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive @@ -31,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/rnn/lstm_ops.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc index 795791254a475d..4b3867edfbf4bd 100644 --- a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU From f8888244556214c2fb76e954c994a986c7a8995a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Apr 2025 22:57:17 -0700 Subject: [PATCH 0132/1324] Automated Code Change PiperOrigin-RevId: 742985848 --- .../core/tfrt/saved_model/python/saved_model_load_and_run.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc b/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc index 448e05d411d165..c2dcfbebe87734 100644 --- a/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc +++ b/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.h" +#include + #include #include #include From d3a386a35ffba8ea1649471f4811c4da8e132d71 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 00:08:02 -0700 Subject: [PATCH 0133/1324] Automated Code Change PiperOrigin-RevId: 743004300 --- .../lite/delegates/gpu/common/selectors/simple_selectors.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc index 6d3dec487e4ea1..02b4e16aa1a78e 100644 --- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc @@ -18,6 +18,7 @@ limitations under the License. 
 #include
 #include
 #include
+#include
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"

From e42f73e990055b4d2d1f1aae3329a909d4bfb453 Mon Sep 17 00:00:00 2001
From: Henning Becker
Date: Wed, 2 Apr 2025 00:14:49 -0700
Subject: [PATCH 0134/1324] Remove runtime dependencies from gpu_executable

`GpuExecutable` is a type that is used by both the compiler and the runtime.
Therefore it shouldn't link runtime-only dependencies. So this change removes
those dependencies, which required adding some explicit dependencies to
downstream users of `GpuExecutable` that were transitively relying on them.

This is a prerequisite for turning XLA into a proper AOT/deviceless compiler
and also for the CUDA/ROCm modularization.

PiperOrigin-RevId: 743006733
---
 third_party/xla/xla/service/BUILD     |  4 ++--
 third_party/xla/xla/service/gpu/BUILD | 12 ++----------
 2 files changed, 4 insertions(+), 12 deletions(-)

diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index 4e3132117bacbf..6e21551653e39d 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -6240,9 +6240,9 @@ xla_cc_binary(
         "//xla/stream_executor/gpu:gpu_init",
         "//xla/stream_executor/rocm:rocm_platform",
     ]) + if_cuda([
-        "//xla/stream_executor/cuda:cublas_plugin",
+        "//xla/stream_executor/cuda:all_runtime",
     ]) + if_rocm([
-        "//xla/stream_executor/rocm:rocblas_plugin",
+        "//xla/stream_executor/rocm:all_runtime",
     ]) + xla_internal(["tools:xsymbol_repository"]),
 )

diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index d67ef9f5ee157b..be8ebb87957949 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -623,16 +623,7 @@ cc_library(
         "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/profiler/lib:scoped_annotation",
         "@local_tsl//tsl/profiler/lib:traceme",
-    ] + if_cuda_is_configured([
-        "//xla/stream_executor/cuda:cublas_plugin",
-        "//xla/stream_executor/cuda:cudnn_plugin",
-        "//xla/stream_executor/cuda:cufft_plugin",
-        "//xla/stream_executor/cuda:stream_executor_cuda",
-        "@local_config_cuda//cuda:cuda_headers",
-    ]) + if_rocm_is_configured([
-        "//xla/stream_executor/rocm:stream_executor_rocm",
-        "@local_config_rocm//rocm:rocm_headers",
-    ]),
+    ],
 )

 cc_library(
@@ -2369,6 +2360,7 @@ gpu_kernel_library(
     ]) + if_rocm_is_configured([
         "@local_config_rocm//rocm:rocm_headers",
         "@local_config_rocm//rocm:rocm_config",
+        "@local_config_rocm//rocm:hip",
     ]),
 )

From b2d98b30772ec8ec9de1b0025d77ae0ce81c3088 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Wed, 2 Apr 2025 00:26:02 -0700
Subject: [PATCH 0135/1324] [XLA:GPU] Delete
 `--xla_gpu_unsupported_force_triton_gemm`.

It used to be necessary to work around a lowering bug on A100, but our
version of Triton now carries the necessary patch.
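
With the flag gone, the remaining way for a test to keep GEMMs on Triton is
via the surviving debug options. A minimal hedged sketch, assembled from the
options this patch leaves in place (the base fixture name is whatever the
test already derives from):

  // Hedged sketch: keep GEMM fusions on Triton without the deleted flag.
  DebugOptions GetDebugOptionsForTest() const override {
    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
    // Do not fall back to cuBLAS, and rewrite GEMMs of any size.
    debug_options.set_xla_gpu_cublas_fallback(false);
    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
    return debug_options;
  }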
PiperOrigin-RevId: 743009816 --- .../fusion_emitter_device_legacy_test.cc | 40 ++++++------------- .../fusion_emitter_parametrized_test.cc | 1 - third_party/xla/xla/debug_options_flags.cc | 1 - .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 1 - .../xla/xla/service/gpu/float_support_test.cc | 1 - .../tests/tensor_float_32_global_var_test.cc | 1 - third_party/xla/xla/xla.proto | 9 +---- 7 files changed, 15 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc index 37141f2f0a0998..22eabbc2935b2a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc @@ -1398,16 +1398,7 @@ e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -class TritonGemmTestAny : public TritonGemmTest { - public: - DebugOptions GetDebugOptionsForTest() const override { - DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_unsupported_force_triton_gemm(true); - return debug_options; - } -}; - -TEST_F(TritonGemmTestAny, DoF32F32) { +TEST_F(TritonGemmTest, DoF32F32) { const std::string hlo_text = R"( HloModule t @@ -1427,7 +1418,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTestAny, DoAddConstantToScalarAndBroadcastThat) { +TEST_F(TritonGemmTest, DoAddConstantToScalarAndBroadcastThat) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "Not using autotuner on ROCM yet."; } @@ -1737,8 +1728,7 @@ ENTRY e { kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } -TEST_F(TritonGemmTestAny, - DoNotFuseConcatenationOfSplitNonContractingDimension) { +TEST_F(TritonGemmTest, DoNotFuseConcatenationOfSplitNonContractingDimension) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "Not using autotuner on ROCM yet."; } @@ -2132,7 +2122,7 @@ e { /*arel=*/1e-2})); } -TEST_F(TritonGemmTestAny, MinimumHandlesNaNsOnTheLeft) { +TEST_F(TritonGemmTest, MinimumHandlesNaNsOnTheLeft) { constexpr absl::string_view kHloText = R"( HloModule t @@ -2155,7 +2145,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTestAny, MinimumHandlesNaNsOnTheRight) { +TEST_F(TritonGemmTest, MinimumHandlesNaNsOnTheRight) { constexpr absl::string_view kHloText = R"( HloModule t @@ -2178,7 +2168,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTestAny, MaximumHandlesNaNsOnTheLeft) { +TEST_F(TritonGemmTest, MaximumHandlesNaNsOnTheLeft) { constexpr absl::string_view kHloText = R"( HloModule t @@ -2201,7 +2191,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTestAny, MaximumHandlesNaNsOnTheRight) { +TEST_F(TritonGemmTest, MaximumHandlesNaNsOnTheRight) { constexpr absl::string_view kHloText = R"( HloModule t @@ -2224,7 +2214,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTestAny, MinimumReturnsLHS) { +TEST_F(TritonGemmTest, MinimumReturnsLHS) { constexpr absl::string_view kHloText = R"( HloModule t @@ -2249,7 +2239,7 @@ ENTRY e { /*arel=*/1e-3})); } -TEST_F(TritonGemmTestAny, MinimumReturnsRHS) { +TEST_F(TritonGemmTest, MinimumReturnsRHS) { constexpr absl::string_view kHloText = R"( HloModule t 
@@ -2274,7 +2264,7 @@ ENTRY e { /*arel=*/1e-3})); } -TEST_F(TritonGemmTestAny, MaximumReturnsLHS) { +TEST_F(TritonGemmTest, MaximumReturnsLHS) { constexpr absl::string_view kHloText = R"( HloModule t @@ -2299,7 +2289,7 @@ ENTRY e { /*arel=*/1e-3})); } -TEST_F(TritonGemmTestAny, MaximumReturnsRHS) { +TEST_F(TritonGemmTest, MaximumReturnsRHS) { constexpr absl::string_view kHloText = R"( HloModule t @@ -2671,8 +2661,7 @@ ENTRY e { )"); } -TEST_F(TritonGemmTestAny, - LowerDotWithLhsWithoutNonContractingDimThroughTriton) { +TEST_F(TritonGemmTest, LowerDotWithLhsWithoutNonContractingDimThroughTriton) { const std::string hlo_text = R"( HloModule t @@ -2693,8 +2682,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTestAny, - LowerDotWithRhsWithoutNonContractingDimThroughTriton) { +TEST_F(TritonGemmTest, LowerDotWithRhsWithoutNonContractingDimThroughTriton) { const std::string hlo_text = R"( HloModule t @@ -3984,8 +3972,6 @@ class TritonGemmContractionDims : public TritonGemmTest { DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_ensure_minor_dot_contraction_dims(true); - debug_options.set_xla_gpu_unsupported_force_triton_gemm(true); - return debug_options; } }; diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc index 25ad4ae4e5e91e..5cff5b32eee616 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc @@ -151,7 +151,6 @@ class TritonTest : public GpuCodegenTest { public: DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_unsupported_force_triton_gemm(true); debug_options.set_xla_gpu_cublas_fallback(false); // Always rewrite Gemms with Triton regardless of size. 
debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 8667c05bbefb19..217a8c94db076a 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -230,7 +230,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_unsupported_enable_triton_multi_output_fusion(false); opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(true); - opts.set_xla_gpu_unsupported_force_triton_gemm(false); opts.set_xla_gpu_verify_triton_fusion_numerics(false); // Moving reduce-scatter out of while loops can increase memory footprint, so diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 78bb3c249c085c..71bdb41cab3564 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -2335,7 +2335,6 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id, DebugOptions& debug_options = *compile_options.executable_build_options.mutable_debug_options(); debug_options.set_xla_gpu_shard_autotuning(true); - debug_options.set_xla_gpu_unsupported_force_triton_gemm(true); debug_options.set_xla_gpu_cublas_fallback(false); if (node_id < num_nodes_using_cache) { diff --git a/third_party/xla/xla/service/gpu/float_support_test.cc b/third_party/xla/xla/service/gpu/float_support_test.cc index a0bddbe0957972..79f343f90c1a4c 100644 --- a/third_party/xla/xla/service/gpu/float_support_test.cc +++ b/third_party/xla/xla/service/gpu/float_support_test.cc @@ -51,7 +51,6 @@ class FloatSupportTestWithTriton : public FloatSupportTest { DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = FloatSupportTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_triton_gemm(true); - debug_options.set_xla_gpu_unsupported_force_triton_gemm(true); debug_options.set_xla_gpu_cublas_fallback(false); return debug_options; } diff --git a/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc b/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc index b7c9c92772b3c4..39d20ea80c9f28 100644 --- a/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc +++ b/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc @@ -50,7 +50,6 @@ class TensorFloat32GlobalVarTest : public ::testing::WithParamInterface, const bool enable_triton_gemm = GetParam(); if (enable_triton_gemm) { debug_options.set_xla_gpu_enable_triton_gemm(true); - debug_options.set_xla_gpu_unsupported_force_triton_gemm(true); debug_options.set_xla_gpu_cublas_fallback(false); } else { debug_options.set_xla_gpu_enable_triton_gemm(false); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 4a89560b8c9228..3d0c695fac4fcf 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -779,12 +779,6 @@ message DebugOptions { // TODO(b/390559452): Remove the flag once the feature is stable. bool xla_gpu_unsupported_enable_triton_multi_output_fusion = 382; - // Internal debug/testing flag to force all GEMMs to use Triton, independently - // of known issues. - // TODO(b/395903738): use to make specific tests pass on A100 while working - // around this bug. The can be removed once the bug is fixed. 
- bool xla_gpu_unsupported_force_triton_gemm = 369; - // Internal testing flag to enable one-shot kernel for single-host // ragged-all-to-all operations. bool xla_gpu_unsupported_use_ragged_all_to_all_one_shot_kernel = 375; @@ -1231,8 +1225,9 @@ message DebugOptions { // xla_gpu_enable_bf16_3way_gemm // xla_gpu_enable_bf16_6way_gemm // xla_gpu_enable_cudnn_fmha + // xla_gpu_unsupported_force_triton_gemm reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320, - 325, 326, 332, 361, 270, 229, 271, 279, 218; + 325, 326, 332, 361, 270, 229, 271, 279, 218, 369; } // Contains flags which affects the GPU compilation result. From 2b0f50dbdaf8603bd9d2277ea2997fc08ad53f55 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 02:02:51 -0700 Subject: [PATCH 0136/1324] compat: Update forward compatibility horizon to 2025-04-02 PiperOrigin-RevId: 743036112 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index fa2918c0912f24..a88e5ef33cfb53 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 1) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 2) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From dcf091be3a63fd17aa10886f003fbfabe8f36d5e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 02:02:59 -0700 Subject: [PATCH 0137/1324] Update GraphDef version to 2185. PiperOrigin-RevId: 743036193 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 970cc071347b35..10d9b03823b4dd 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2184 // Updated: 2025/4/1 +#define TF_GRAPH_DEF_VERSION 2185 // Updated: 2025/4/2 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). 
 //
From ad2d5710c7dbde1cb1902d8b0ded0b36321a4497 Mon Sep 17 00:00:00 2001
From: Theotime Combes
Date: Wed, 2 Apr 2025 02:22:40 -0700
Subject: [PATCH 0138/1324] [XLA:GPU] Add Triton support test for
 optim-barrier

Can be treated as a unary op

PiperOrigin-RevId: 743042393
---
 third_party/xla/xla/backends/gpu/codegen/triton/support.cc      | 1 -
 third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc
index e8771a2e767b65..e9f06dd0fcec2e 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc
@@ -514,7 +514,6 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) {
     case HloOpcode::kGetTupleElement:
     case HloOpcode::kInfeed:
     case HloOpcode::kMap:
-    case HloOpcode::kOptimizationBarrier:
     case HloOpcode::kOutfeed:
     case HloOpcode::kPad:
     case HloOpcode::kRaggedDot:
diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc
index 4a9b9d401a7d22..1b5ca1adb8345f 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc
@@ -410,6 +410,7 @@ constexpr std::array kTestedOpsUnaryElementwise = {
     HloOpcode::kLogistic,
     HloOpcode::kNegate,
     HloOpcode::kNot,
+    HloOpcode::kOptimizationBarrier,
     HloOpcode::kPopulationCount,
     HloOpcode::kReal,
     HloOpcode::kReducePrecision,
@@ -2379,7 +2380,6 @@ constexpr std::array kUnsupportedOps = {
     HloOpcode::kGetTupleElement,
     HloOpcode::kInfeed,
     HloOpcode::kMap,
-    HloOpcode::kOptimizationBarrier,
     HloOpcode::kOutfeed,
     HloOpcode::kPad,
     HloOpcode::kRaggedDot,
From 7836ee811eb77cadc9d48032a9880cd732c630d5 Mon Sep 17 00:00:00 2001
From: Goran Flegar
Date: Wed, 2 Apr 2025 03:00:58 -0700
Subject: [PATCH 0139/1324] Set up scaffolding for dynamic autotuner search
 space

For now we only connect a dummy class with an XLA flag, so we can later do
A/B comparisons to current behavior.
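For illustration, the A/B setup can be driven entirely through the new flag.
A minimal sketch (hedged: it mirrors the flag-gated test fixture added to
gemm_fusion_autotuner_test.cc later in this commit; the base class is
whichever test already configures the autotuner):

  // B arm of the A/B comparison: dynamic search space on. The A arm keeps
  // the default (false) registered in debug_options_flags.cc.
  DebugOptions GetDebugOptionsForTest() const override {
    DebugOptions debug_options =
        GemmFusionAutotunerTest::GetDebugOptionsForTest();
    debug_options.set_xla_gpu_experimental_enable_dynamic_dot_search_space(
        true);
    return debug_options;
  }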
PiperOrigin-RevId: 743053221 --- third_party/xla/xla/debug_options_flags.cc | 10 +++ .../xla/xla/service/gpu/autotuning/BUILD | 31 ++++++++ .../gpu/autotuning/dot_search_space.cc | 49 +++++++++++++ .../service/gpu/autotuning/dot_search_space.h | 54 ++++++++++++++ .../gpu/autotuning/dot_search_space_test.cc | 71 +++++++++++++++++++ .../gpu/autotuning/gemm_fusion_autotuner.cc | 35 ++++++--- .../autotuning/gemm_fusion_autotuner_test.cc | 32 +++++++++ third_party/xla/xla/xla.proto | 7 +- 8 files changed, 277 insertions(+), 12 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc create mode 100644 third_party/xla/xla/service/gpu/autotuning/dot_search_space.h create mode 100644 third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 217a8c94db076a..2711a0b7098065 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -269,6 +269,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_operand_bytes_threshold_for_windowed_einsum(-1); opts.set_xla_gpu_enable_triton_hopper(false); + opts.set_xla_gpu_experimental_enable_dynamic_dot_search_space(false); opts.set_xla_gpu_experimental_enable_fusion_block_level_rewriter(false); opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false); @@ -2009,6 +2010,15 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_hopper), debug_options->xla_gpu_enable_triton_hopper(), "Currently used to enable MMA_V3 for Hopper in Triton")); + flag_list->push_back(tsl::Flag( + "xla_gpu_experimental_enable_dynamic_dot_search_space", + bool_setter_for( + &DebugOptions:: + set_xla_gpu_experimental_enable_dynamic_dot_search_space), + debug_options->xla_gpu_experimental_enable_dynamic_dot_search_space(), + "Enable dynamically generating and pruning the autotuning search space " + "for Triton dot fusions, based on the properties of the problem and " + "hardware (shapes, instructions, GPU limits, etc.).")); flag_list->push_back(tsl::Flag( "xla_gpu_experimental_enable_fusion_block_level_rewriter", bool_setter_for( diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 41dfba0252c9ff..400df6ac61bf25 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -118,6 +118,7 @@ cc_library( ":autotuner_compile_util", ":autotuner_status_key", ":autotuner_util", + ":dot_search_space", "//xla:autotune_results_proto_cc", "//xla:autotuning_proto_cc", "//xla:shape_util", @@ -256,6 +257,36 @@ xla_test( ], ) +cc_library( + name = "dot_search_space", + srcs = ["dot_search_space.cc"], + hdrs = ["dot_search_space.h"], + tags = ["gpu"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service/gpu:matmul_utils", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/strings:str_format", + ], +) + +xla_test( + name = "dot_search_space_test", + srcs = ["dot_search_space_test.cc"], + backends = ["gpu"], + deps = [ + ":dot_search_space", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service/gpu:matmul_utils", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", + "//xla/tsl/platform:statusor", + "@com_google_googletest//:gtest_main", + ], +) + 
 cc_library(
     name = "gemm_algorithm_picker",
     srcs = ["gemm_algorithm_picker.cc"],
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
new file mode 100644
index 00000000000000..9e1510818150a4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
@@ -0,0 +1,49 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/dot_search_space.h"
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla::gpu {
+
+TritonDotFusionSearchSpace::TritonDotFusionSearchSpace(
+    const se::DeviceDescription& device_description,
+    const HloDotInstruction* dot)
+    :  // Set up basic information about the hardware and the problem.
+      device_description_(device_description) {
+  // TODO: b/404470821 - Do something based on `dot`.
+}
+
+std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::GenerateConfigs() {
+  // TODO: b/404470821 - Implement this properly rather than hardcoding the
+  // config.
+  return {TritonGemmConfig(
+      /*block_m=*/64, /*block_n=*/128, /*block_k=*/64,
+      /*split_k=*/1, /*num_stages=*/3, /*num_warps=*/4,
+      /*num_ctas=*/1)};
+}
+
+std::string TritonDotFusionSearchSpace::Serialize() {
+  return absl::StrFormat("TODO: b/404470821 - Implement this.");
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
new file mode 100644
index 00000000000000..3fd976b3fd62aa
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
@@ -0,0 +1,54 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_DOT_SEARCH_SPACE_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_DOT_SEARCH_SPACE_H_
+
+#include <string>
+#include <vector>
+
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla::gpu {
+
+// Generates the space of promising Triton configs for a given dot fusion
+// and hardware.
+//
+// Takes into account the properties of the problem (e.g., operand and result
+// shapes, fused instructions), and the hardware (e.g., number of cores,
+// available registers and memory per core).
+//
+// Internal doc with rationale: go/xla-gpu-dot-search
+class TritonDotFusionSearchSpace {
+ public:
+  TritonDotFusionSearchSpace(const se::DeviceDescription& device_description,
+                             const HloDotInstruction* dot);
+
+  // Generates the list of promising configs in the search space for the
+  // autotuner to try.
+  std::vector<TritonGemmConfig> GenerateConfigs();
+
+  // Serializes the search space to a human-readable string.
+  std::string Serialize();
+
+ private:
+  se::DeviceDescription device_description_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_AUTOTUNING_DOT_SEARCH_SPACE_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
new file mode 100644
index 00000000000000..bd6feba7bef256
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
@@ -0,0 +1,71 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/dot_search_space.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_description.pb.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+using ::testing::Field;
+using ::testing::Ge;
+
+auto IsValidConfig() {
+  return AllOf(Field("block_m", &TritonGemmConfig::block_m, Ge(1)),
+               Field("block_n", &TritonGemmConfig::block_n, Ge(1)),
+               Field("block_k", &TritonGemmConfig::block_k, Ge(1)),
+               Field("split_k", &TritonGemmConfig::split_k, Ge(1)),
+               Field("num_stages", &TritonGemmConfig::num_stages, Ge(1)),
+               Field("num_warps", &TritonGemmConfig::num_warps, Ge(1)),
+               Field("num_ctas", &TritonGemmConfig::num_ctas, Ge(1)));
+};
+
+class DotSearchSpaceTest : public HloHardwareIndependentTestBase {
+ protected:
+  se::DeviceDescription device_description_{
+      se::DeviceDescription(se::GpuDeviceInfoProto::default_instance())};
+};
+
+TEST_F(DotSearchSpaceTest, ReturnsValidConfigList) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[1024,1024] parameter(0)
+  p1 = f32[1024,1024] parameter(1)
+  ROOT r = f32[1024,1024] dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  TritonDotFusionSearchSpace search_space(
+      device_description_,
+      Cast<HloDotInstruction>(module->entry_computation()->root_instruction()));
+
+  EXPECT_THAT(search_space.GenerateConfigs(),
+              AllOf(Not(::testing::IsEmpty()), Each(IsValidConfig())));
+}
+
+}  // namespace
+}  //
namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
index 73b9285566818c..2b75d3fe4bb4a7 100644
--- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -63,6 +63,7 @@ limitations under the License.
 #include "xla/service/gpu/autotuning/autotuner_compile_util.h"
 #include "xla/service/gpu/autotuning/autotuner_status_key.h"
 #include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/autotuning/dot_search_space.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/buffer_comparator.h"
 #include "xla/service/gpu/gpu_float_support.h"
@@ -129,9 +130,6 @@ namespace {
 // Minimum tile size.
 constexpr int kMinTileSize = 16;

-// Default tiling when autotuning is disabled.
-constexpr TritonGemmConfig kDefaultGemmTiling = {32, 32, 32, 1, 1, 4};
-
 // Split-K is enabled when the estimate number of waves is lower than the limit.
 constexpr int kMaxWavesForSplitK = 5;

@@ -872,6 +870,27 @@ void ModifyPotentiallyFailingConfig(TritonGemmConfig& config, int minBitWidth,

 absl::StatusOr<std::vector<TritonGemmConfig>>
 GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
+  // Default tiling when autotuning is disabled.
+  constexpr TritonGemmConfig kDefaultConfig = {
+      /*block_m=*/32, /*block_n=*/32, /*block_k=*/32,
+      /*split_k=*/1, /*num_stages=*/1, /*num_warps=*/4,
+      /*num_ctas=*/1};
+  constexpr int kMinGemmElements = 2 * 32 * 32;
+  bool small_dot = ShapeUtil::ElementsIn(dot.operand(0)->shape()) +
+                       ShapeUtil::ElementsIn(dot.operand(1)->shape()) <=
+                   kMinGemmElements;
+
+  if (debug_options_.xla_gpu_experimental_enable_dynamic_dot_search_space()) {
+    if (small_dot || !IsAutotuningEnabled()) {
+      return {{kDefaultConfig}};
+    }
+    TritonDotFusionSearchSpace search_space(config_.GetDeviceDescription(),
+                                            &dot);
+    VLOG(1) << "Generating configs from search space: "
+            << search_space.Serialize();
+    return search_space.GenerateConfigs();
+  }
+
   // Retrieve the minimum bit-width participating in the dot. This is needed
   // to avoid autotuning configurations that are not supported by Triton. This
   // is used to restrict the values for tile_k.
@@ -897,20 +916,14 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {

   // Generate the list of configurations (once).
   if (triton_configs_.empty()) {
-    triton_configs_ = !IsAutotuningEnabled()
-                          ? std::vector<TritonGemmConfig>(1, kDefaultGemmTiling)
+    triton_configs_ = !IsAutotuningEnabled() ? std::vector<TritonGemmConfig>(1, kDefaultConfig)
                       : debug_options_.xla_gpu_exhaustive_tiling_search()
                           ? GetExhaustiveTritonConfigs()
                           : GetDefaultTritonConfigs();
   }

-  // Avoid autotuning tiny fusions.
-  constexpr int kMinGemmElements = 32 * 32;
-  bool small_dot =
-      ShapeUtil::ElementsIn(dot.operand(0)->shape()) <= kMinGemmElements &&
-      ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements;
   std::vector<TritonGemmConfig> triton_configs =
-      small_dot ?
std::vector<TritonGemmConfig>(1, kDefaultConfig) : triton_configs_;

   // Split-K optimization enables more even utilization of a GPU in cases
   // where tiling just the non-contracting dimensions of a GEMM does not create
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
index ccdccc0c5f3fa0..e57cb1ce5fcc48 100644
--- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
@@ -400,6 +400,18 @@ class GemmFusionAutotunerTestWithMorePreciseReduction
   }
 };

+// TODO: b/404470821 - Remove once this is enabled by default.
+class DynamicSearchSpaceAutotunerTest : public GemmFusionAutotunerTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() const override {
+    DebugOptions debug_options =
+        GemmFusionAutotunerTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_experimental_enable_dynamic_dot_search_space(
+        true);
+    return debug_options;
+  }
+};
+
 absl::StatusOr<std::vector<TritonGemmConfig>>
 GetPossibleMatmulAutotuneTritonConfigs(
     const HloDotInstruction& dot,
@@ -1757,6 +1769,26 @@ ENTRY e {
       [](const TritonGemmConfig& config) { return config.num_ctas > 2; }));
 }

+TEST_F(DynamicSearchSpaceAutotunerTest, AutotunesSimpleDotFusion) {
+  const std::string hlo = R"(
+HloModule module
+ENTRY e {
+  x = s8[128,64] parameter(0)
+  c = f16[128,64] convert(x)
+  y = f16[64,6144] parameter(1)
+  ROOT out = f16[128,6144] dot(c, y),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  CheckTritonAutotuning(hlo, R"(
+// CHECK: ENTRY
+// CHECK: ROOT
+// CHECK-SAME: kCustom
+// CHECK-SAME: block_m
+)");
+  EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/5e-3, /*arel=*/5e-3}));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index 3d0c695fac4fcf..881cbf5138e803 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -536,6 +536,11 @@ message DebugOptions {
   // xla_gpu_multi_streamed_windowed_einsum is set to true.
   bool xla_gpu_experimental_enable_alltoall_windowed_einsum = 360;

+  // Enable dynamically generating and pruning the autotuning search space for
+  // Triton dot fusions, based on the properties of the problem and hardware
+  // (shapes, instructions, GPU limits, etc.).
+  bool xla_gpu_experimental_enable_dynamic_dot_search_space = 385;
+
   // Enabling this flag will attempt to redirect every already-constructed
   // fusion possible to the Triton emitter.
   //
@@ -1198,7 +1203,7 @@ message DebugOptions {

   // Note: when adding a new flag, please add it to one of the hardware-specific
   // or hardware-agnostic sections at the top of this proto message.
-  // Next id: 385
+  // Next id: 386

   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
From f26000fa96d6d67ae3f365619a8d7bd782ede07f Mon Sep 17 00:00:00 2001
From: Ilia Sergachev
Date: Wed, 2 Apr 2025 03:26:20 -0700
Subject: [PATCH 0140/1324] PR #24451: [GPU] Upgrade cuDNN frontend to 1.11.0.

Imported from GitHub PR https://github.com/openxla/xla/pull/24451

Copybara import of the project:

--
f8439ae68e14bf9a6c3d83fa73e576d27920b92f by Ilia Sergachev :

[GPU] Upgrade cuDNN frontend to 1.11.0.
Merging this change closes #24451

PiperOrigin-RevId: 743060380
---
 tensorflow/workspace2.bzl      | 6 +++---
 third_party/xla/workspace2.bzl | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl
index 79d7ac7291db83..bd468e6428be15 100644
--- a/tensorflow/workspace2.bzl
+++ b/tensorflow/workspace2.bzl
@@ -202,9 +202,9 @@ def _tf_repositories():
         name = "cudnn_frontend_archive",
         build_file = "//third_party:cudnn_frontend.BUILD",
         patch_file = ["//third_party:cudnn_frontend_header_fix.patch"],
-        sha256 = "59fb63e273c845cb85996d536194a7e2b22012810983cbbf06c4a46b09d17a32",
-        strip_prefix = "cudnn-frontend-1.10.0",
-        urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.10.0.zip"),
+        sha256 = "34dfe01057e43e799af207522aa0c863ad3177f8c1568b6e7a7e4ccf1cbff769",
+        strip_prefix = "cudnn-frontend-1.11.0",
+        urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.11.0.zip"),
     )

     tf_http_archive(
diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl
index 8a0336e338aafa..7b29f1170df463 100644
--- a/third_party/xla/workspace2.bzl
+++ b/third_party/xla/workspace2.bzl
@@ -91,9 +91,9 @@ def _tf_repositories():
         name = "cudnn_frontend_archive",
         build_file = "//third_party:cudnn_frontend.BUILD",
         patch_file = ["//third_party:cudnn_frontend_header_fix.patch"],
-        sha256 = "59fb63e273c845cb85996d536194a7e2b22012810983cbbf06c4a46b09d17a32",
-        strip_prefix = "cudnn-frontend-1.10.0",
-        urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.10.0.zip"),
+        sha256 = "34dfe01057e43e799af207522aa0c863ad3177f8c1568b6e7a7e4ccf1cbff769",
+        strip_prefix = "cudnn-frontend-1.11.0",
+        urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.11.0.zip"),
     )

     tf_http_archive(
From 2cd427669b06d722f542b5c8a22f329943178865 Mon Sep 17 00:00:00 2001
From: Aliia Khasanova
Date: Wed, 2 Apr 2025 04:50:17 -0700
Subject: [PATCH 0141/1324] Temporarily replace type of compiler_for_platform
 with auto.

PiperOrigin-RevId: 743083188
---
 tensorflow/compiler/jit/xla_platform_info.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc
index f9af695e33c163..321d00ad728402 100644
--- a/tensorflow/compiler/jit/xla_platform_info.cc
+++ b/tensorflow/compiler/jit/xla_platform_info.cc
@@ -255,8 +255,9 @@ absl::Status BuildXlaDeviceCompiler(DeviceBase* device,
     return platform.status();
   }

-  absl::StatusOr<xla::Compiler*> compiler_for_platform =
-      xla::Compiler::GetForPlatform(platform.value());
+  // TODO(aliia): Replace auto with the actual type. This is a temporary change,
+  // needed to pass the OSS presubmits.
+  auto compiler_for_platform = xla::Compiler::GetForPlatform(platform.value());

   if (!compiler_for_platform.ok()) {
     // In some rare cases (usually in unit tests with very small clusters) we
     // may end up transforming an XLA cluster with at least one GPU operation
From 05e6cce54a157f37db93a97beea1aee661617688 Mon Sep 17 00:00:00 2001
From: Georg Stefan Schmid
Date: Wed, 2 Apr 2025 04:59:34 -0700
Subject: [PATCH 0142/1324] PR #23347: [gpu] Allow explicitly setting
 slice_index in se_gpu_pjrt_client

Imported from GitHub PR https://github.com/openxla/xla/pull/23347

Allows overriding the slice index used by se_gpu_pjrt_client.
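For illustration, a hedged usage sketch (not part of the PR itself; it relies
only on the GpuClientOptions field and factory signature touched below, with
the surrounding client setup elided):

  // Pin this process's devices to an explicit slice instead of deriving the
  // slice from /proc/sys/kernel/random/boot_id.
  xla::GpuClientOptions options;
  options.slice_index = 0;
  absl::StatusOr<std::unique_ptr<xla::PjRtClient>> client =
      xla::GetStreamExecutorGpuClient(options);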
More explicit control over which slice a device ends up in is desirable: - Various parts of the ecosystem equate slices with "devices communicating via fast interconnect". With the arrival of NVL72 we want devices managed by multiple hosts to form a single slice. - For debugging purposes it can be useful to allow devices on the same host (managed in separate processes) to be treated as different slices. For example, [Orbax](https://github.com/google/orbax)'s local checkpointing presumes the existence of at least two slices, so overriding the boot id will allow us to test local checkpointing on a single host. (Companion PR in JAX: https://github.com/jax-ml/jax/pull/26906) Copybara import of the project: -- 8d167908028f75c92e635abe65beff9206cf25ea by Georg Stefan Schmid : [gpu] Allow overriding XLA slice_index Merging this change closes #23347 PiperOrigin-RevId: 743085978 --- .../xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 7 ++ .../xla/xla/pjrt/distributed/protocol.proto | 3 + .../xla/xla/pjrt/distributed/topology_util.cc | 69 ++++++++++++------- .../xla/xla/pjrt/distributed/topology_util.h | 2 +- .../pjrt/distributed/topology_util_test.cc | 54 ++++++++++++--- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 19 +++-- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 1 + .../plugin/xla_gpu/xla_gpu_client_options.h | 2 + 8 files changed, 118 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index df283cb3f6938a..b00a40ce78bf51 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -89,6 +89,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { PJRT_NamedValue_Type::PJRT_NamedValue_kBool}, {"enable_mock_nccl", PJRT_NamedValue_Type::PJRT_NamedValue_kBool}, {"mock_gpu_topology", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, + {"slice_index", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, }); PJRT_RETURN_IF_ERROR( ValidateCreateOptions(create_options, kExpectedOptionNameAndTypes)); @@ -158,6 +159,11 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { it != create_options.end()) { mock_gpu_topology = std::get(it->second); } + std::optional slice_index; + if (auto it = create_options.find("slice_index"); + it != create_options.end()) { + slice_index = std::get(it->second); + } xla::GpuClientOptions options; options.allocator_config = allocator_config; @@ -172,6 +178,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { should_stage_host_to_device_transfers; options.enable_mock_nccl = enable_mock_nccl; options.mock_gpu_topology = mock_gpu_topology; + options.slice_index = slice_index; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetStreamExecutorGpuClient(options)); args->client = pjrt::CreateWrapperClient(std::move(client)); diff --git a/third_party/xla/xla/pjrt/distributed/protocol.proto b/third_party/xla/xla/pjrt/distributed/protocol.proto index 4c8b23ba85cd55..9d65bae39e4a24 100644 --- a/third_party/xla/xla/pjrt/distributed/protocol.proto +++ b/third_party/xla/xla/pjrt/distributed/protocol.proto @@ -87,6 +87,9 @@ message LocalTopologyProto { // See /proc/sys/kernel/random/boot_id. 
string boot_id = 2; repeated DeviceProto devices = 3; + + // Explicit slice index; derived from boot_id if absent + optional int32 slice_index = 4; } message GlobalTopologyProto { diff --git a/third_party/xla/xla/pjrt/distributed/topology_util.cc b/third_party/xla/xla/pjrt/distributed/topology_util.cc index ad4645085f0b4b..62a80ac692a811 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util.cc @@ -155,37 +155,59 @@ static absl::StatusOr> GetAllLocalTopologies( } // Steals the contents of `local_topologies`. -GlobalTopologyProto BuildGlobalTopology( +absl::StatusOr BuildGlobalTopology( absl::Span local_topologies, bool assign_global_device_ids) { + CHECK(!local_topologies.empty()); + bool explicit_slice_indices = local_topologies[0].has_slice_index(); + if (explicit_slice_indices) { + // Every local topology explicitly declares its slice_index. + for (LocalTopologyProto& local : local_topologies) { + if (!local.has_slice_index()) { + return InvalidArgument( + "Either all of or none of the local topologies " + "should explicitly set slice_index"); + } + int slice_index = local.slice_index(); + for (DeviceProto& device : *local.mutable_devices()) { + device.set_slice_index(slice_index); + } + } + } else { + // Assign local devices of the same host to the same slice_index. + absl::flat_hash_map boot_id_to_slice_index; + for (LocalTopologyProto& local : local_topologies) { + if (local.has_slice_index()) { + return InvalidArgument( + "Either all of or none of the local topologies " + "should explicitly set slice_index"); + } + // Every new boot_id seen is treated as a new host/slice. + auto [it, _] = boot_id_to_slice_index.try_emplace( + local.boot_id(), boot_id_to_slice_index.size()); + for (DeviceProto& device : *local.mutable_devices()) { + device.set_slice_index(it->second); + } + } + if (VLOG_IS_ON(10)) { + for (auto it = boot_id_to_slice_index.begin(); + it != boot_id_to_slice_index.end(); ++it) { + LOG(INFO) << "BuildGlobalTopology boot_id_to_slice_index " << it->first + << "->" << it->second; + } + } + } + GlobalTopologyProto global_topology; int next_global_device_id = 0; - // Assign local devices of the same host to the same slice_index. - int next_slice_index = 0; - absl::flat_hash_map boot_id_to_slice_index; for (LocalTopologyProto& local : local_topologies) { - // Every new boot_id seen is treated as a new host/slice. 
- absl::string_view boot_id = local.boot_id(); - auto [it, inserted] = - boot_id_to_slice_index.try_emplace(boot_id, next_slice_index); - if (inserted) { - ++next_slice_index; - } - for (DeviceProto& device : *local.mutable_devices()) { - if (assign_global_device_ids) { + if (assign_global_device_ids) { + for (DeviceProto& device : *local.mutable_devices()) { device.set_global_device_id(next_global_device_id++); } - device.set_slice_index(it->second); } global_topology.add_nodes()->Swap(&local); } - if (VLOG_IS_ON(10)) { - for (auto it = boot_id_to_slice_index.begin(); - it != boot_id_to_slice_index.end(); ++it) { - LOG(INFO) << "BuildGlobalTopology boot_id_to_slice_index " << it->first - << "->" << it->second; - } - } return global_topology; } @@ -239,9 +261,10 @@ absl::Status ExchangeTopologies(absl::string_view platform, int node_id, TF_ASSIGN_OR_RETURN(std::vector local_topologies, GetAllLocalTopologies(platform, num_nodes, kv_store, get_local_topology_timeout)); - *global_topology = + TF_ASSIGN_OR_RETURN( + *global_topology, BuildGlobalTopology(absl::Span(local_topologies), - assign_global_device_ids); + assign_global_device_ids)); TF_RETURN_IF_ERROR(kv_store->Set(global_topology_key, global_topology->SerializeAsString())); } else { diff --git a/third_party/xla/xla/pjrt/distributed/topology_util.h b/third_party/xla/xla/pjrt/distributed/topology_util.h index 2e492d9c907398..9997298e55cab0 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util.h +++ b/third_party/xla/xla/pjrt/distributed/topology_util.h @@ -53,7 +53,7 @@ absl::Status ExchangeTopologies(absl::string_view platform, int node_id, // Given a LocalTopologyProto object from each node, builds a // GlobalTopologyProto that describes all nodes. Steals the contents of the // LocalTopologyProtos. -GlobalTopologyProto BuildGlobalTopology( +absl::StatusOr BuildGlobalTopology( absl::Span local_topologies, bool assign_global_device_ids); diff --git a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc index 3926aa52c0a5df..ea9665936dc2f5 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc @@ -46,9 +46,10 @@ TEST(TopologyTest, BuildGlobalTopology) { DeviceProto* d3 = locals[1].add_devices(); d3->set_local_device_ordinal(1); - GlobalTopologyProto global = + TF_ASSERT_OK_AND_ASSIGN( + GlobalTopologyProto global, BuildGlobalTopology(absl::Span(locals), - /*assign_global_device_ids=*/true); + /*assign_global_device_ids=*/true)); EXPECT_EQ(global.nodes_size(), 2); EXPECT_EQ(global.nodes()[0].devices_size(), 2); EXPECT_EQ(global.nodes()[1].devices_size(), 2); @@ -177,6 +178,40 @@ TEST(TopologyTest, ExchangeTopology_TwiceWithDifferentLocalTopology_Fails) { } } +TEST(TopologyTest, BuildGlobalTopologyWithExplicitSliceIndices) { + // Set slice_index explicitly, and expect boot id to be ignored. + std::string boot_id = "foo"; + std::vector locals(2); + locals[0].set_boot_id(boot_id); + locals[1].set_boot_id(boot_id); + locals[0].set_node_id(0); + locals[1].set_node_id(1); + locals[0].set_slice_index(1); + locals[1].set_slice_index(0); + // Adds 2 devices to each host. 
+ DeviceProto* d0 = locals[0].add_devices(); + d0->set_local_device_ordinal(0); + DeviceProto* d1 = locals[0].add_devices(); + d1->set_local_device_ordinal(1); + DeviceProto* d2 = locals[1].add_devices(); + d2->set_local_device_ordinal(0); + DeviceProto* d3 = locals[1].add_devices(); + d3->set_local_device_ordinal(1); + + TF_ASSERT_OK_AND_ASSIGN( + GlobalTopologyProto global, + BuildGlobalTopology(absl::Span(locals), + /*assign_global_device_ids=*/true)); + + EXPECT_EQ(global.nodes_size(), 2); + EXPECT_EQ(global.nodes()[0].devices_size(), 2); + EXPECT_EQ(global.nodes()[0].devices()[0].slice_index(), 1); + EXPECT_EQ(global.nodes()[0].devices()[1].slice_index(), 1); + EXPECT_EQ(global.nodes()[1].devices_size(), 2); + EXPECT_EQ(global.nodes()[1].devices()[0].slice_index(), 0); + EXPECT_EQ(global.nodes()[1].devices()[1].slice_index(), 0); +} + TEST(TopologyTest, BuildGpuTopology) { std::string slice_0_boot_id = "foo"; std::string slice_1_boot_id = "bar"; @@ -200,9 +235,10 @@ TEST(TopologyTest, BuildGpuTopology) { d3->set_local_device_ordinal(1); d3->set_core_count(20); - GlobalTopologyProto global = + TF_ASSERT_OK_AND_ASSIGN( + GlobalTopologyProto global, BuildGlobalTopology(absl::Span(locals), - /*assign_global_device_ids=*/true); + /*assign_global_device_ids=*/true)); TF_ASSERT_OK_AND_ASSIGN(auto gpu_topology, BuildGpuTopology(global)); EXPECT_EQ(gpu_topology.device_ids_size(), 4); @@ -229,9 +265,10 @@ TEST(TopologyTest, BuildGpuTopologyWithDifferentNumHostsPerSlice) { DeviceProto* d2 = locals[2].add_devices(); d2->set_local_device_ordinal(0); - GlobalTopologyProto global = + TF_ASSERT_OK_AND_ASSIGN( + GlobalTopologyProto global, BuildGlobalTopology(absl::Span(locals), - /*assign_global_device_ids=*/true); + /*assign_global_device_ids=*/true)); TF_ASSERT_OK_AND_ASSIGN(auto gpu_topology, BuildGpuTopology(global)); EXPECT_EQ(gpu_topology.device_ids_size(), 3); @@ -256,9 +293,10 @@ TEST(TopologyTest, BuildGpuTopologyWithDifferentNumDevicesPerHost) { DeviceProto* d2 = locals[1].add_devices(); d2->set_local_device_ordinal(0); - GlobalTopologyProto global = + TF_ASSERT_OK_AND_ASSIGN( + GlobalTopologyProto global, BuildGlobalTopology(absl::Span(locals), - /*assign_global_device_ids=*/true); + /*assign_global_device_ids=*/true)); TF_ASSERT_OK_AND_ASSIGN(auto gpu_topology, BuildGpuTopology(global)); EXPECT_EQ(gpu_topology.device_ids_size(), 3); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 0c859c6e49fa34..23bf04851e88ae 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1137,7 +1137,7 @@ absl::StatusOr BuildDistributedDevices( gpu::GpuExecutableRunOptions* gpu_executable_run_options, std::shared_ptr kv_store, bool enable_mock_nccl, std::optional mock_gpu_topology, - absl::Duration get_local_topology_timeout, + std::optional slice_index, absl::Duration get_local_topology_timeout, absl::Duration get_global_topology_timeout) { std::vector> devices; LocalTopologyProto local_topology; @@ -1150,6 +1150,9 @@ absl::StatusOr BuildDistributedDevices( boot_id_str = boot_id_str_or_status.value(); } local_topology.set_boot_id(boot_id_str); + if (slice_index.has_value()) { + local_topology.set_slice_index(*slice_index); + } for (const auto& ordinal_and_device : local_device_states) { const se::Platform* platform = ordinal_and_device.second->executor()->GetPlatform(); @@ -1208,8 +1211,9 @@ absl::StatusOr BuildDistributedDevices( 
local_topologies[node_id].set_boot_id(absl::StrCat(i)); } } - global_topology = BuildGlobalTopology(absl::MakeSpan(local_topologies), - /*assign_global_device_ids=*/true); + TF_ASSIGN_OR_RETURN(global_topology, + BuildGlobalTopology(absl::MakeSpan(local_topologies), + /*assign_global_device_ids=*/true)); } else { TF_RETURN_IF_ERROR(ExchangeTopologies( platform_name, node_id, num_nodes, get_local_topology_timeout, @@ -1394,10 +1398,11 @@ absl::StatusOr> GetStreamExecutorGpuClient( TF_RET_CHECK(options.num_nodes == 1 || kv_store != nullptr); TF_ASSIGN_OR_RETURN( DeviceTopologyPair device_topology_pair, - BuildDistributedDevices( - pjrt_platform_name, std::move(local_device_states), options.node_id, - options.num_nodes, gpu_run_options.get(), kv_store, - options.enable_mock_nccl, options.mock_gpu_topology)); + BuildDistributedDevices(pjrt_platform_name, + std::move(local_device_states), options.node_id, + options.num_nodes, gpu_run_options.get(), + kv_store, options.enable_mock_nccl, + options.mock_gpu_topology, options.slice_index)); auto gpu_topology = std::shared_ptr( GpuTopology::FromProto(device_topology_pair.second)); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index 7d40b04535a839..e920edaf33c580 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -180,6 +180,7 @@ absl::StatusOr BuildDistributedDevices( gpu::GpuExecutableRunOptions* gpu_executable_run_options, std::shared_ptr kv_store, bool enable_mock_nccl, std::optional mock_gpu_topology = std::nullopt, + std::optional slice_index = std::nullopt, absl::Duration get_local_topology_timeout = absl::Minutes(2), absl::Duration get_global_topology_timeout = absl::Minutes(5)); diff --git a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h index e40be6b4c189ee..58c2c67ba5101d 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h +++ b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h @@ -46,6 +46,8 @@ struct GpuClientOptions { bool enable_mock_nccl = false; std::optional mock_gpu_topology; + + std::optional slice_index; }; } // namespace xla From 45f85586435474b3c4a7b6eae3b00c4001aef04b Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 2 Apr 2025 05:37:07 -0700 Subject: [PATCH 0143/1324] [XLA:GPU] Remove second run of SortRewriter from the GPU pipeline. This second run was possibly meant to rewrite the sorts created by DynamicPadder. But there are two issues here: We would need to run SortRewriter before StableSortExpander, and SortRewriter currently doesn't support sorting pairs where the key is not an unsigned integer. Add a pass order test to verify that SortRewriter runs before ComparisonExpander and StableSortExpander, as otherwise we would not match the expanded patterns. 
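Sketched as a pipeline fragment, the constraint reads as follows (hypothetical
placement; this change only removes the second run and adds the test, it does
not move ComparisonExpander):

  pipeline.AddPass<DynamicPadder>(dynamic_padder_options);
  // SortRewriter must see the original comparison computation...
  pipeline.AddPass<SortRewriter>();
  // ...because these expanders rewrite it into patterns it no longer matches.
  pipeline.AddPass<StableSortExpander>();
  pipeline.AddPass<ComparisonExpander>();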
PiperOrigin-RevId: 743095864
---
 third_party/xla/xla/service/gpu/gpu_compiler.cc      | 9 +++++----
 third_party/xla/xla/service/gpu/gpu_compiler_test.cc | 8 ++++++++
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index f73d785c71b21c..e51ac671036ff1 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -784,15 +784,16 @@ absl::Status RunOptimizationPasses(
     LOG(FATAL) << "Unreachable";
   }

+  // DynamicPadder creates a stable KeyValue sort for dynamic reshapes.
   pipeline.AddPass<DynamicPadder>(dynamic_padder_options);

+  // TODO(b/407909195): Add SortRewriter here once it supports S32 keys for
+  // KeyValueSort. It needs to run before StableSortExpander, otherwise we will
+  // not match the comparison computation.
+
   // Expand the sort op to support stable sorting if required.
   pipeline.AddPass<StableSortExpander>();

-  if (hlo_module->config().debug_options().xla_gpu_enable_cub_radix_sort()) {
-    pipeline.AddPass<SortRewriter>();
-  }
-
   se::GpuComputeCapability gpu_version =
       gpu_target_config.device_description.gpu_compute_capability();

diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
index 9f890c4076f782..7374c6d522fa63 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
@@ -1618,6 +1618,14 @@ TEST_F(PassOrderTest, FusionDispatchRunsAfterAllFusionPasses) {
                   /*include_pipeline_name=*/true);
 }

+TEST_F(PassOrderTest,
+       SortRewriterRunsBeforeStableSortExpanderAndComparisonExpander) {
+  VerifyPassOrder(/*first_pass_regex=*/"sort-rewriter",
+                  /*last_pass_regex=*/"stable-sort-expander");
+  VerifyPassOrder(/*first_pass_regex=*/"sort-rewriter",
+                  /*last_pass_regex=*/"comparison-expander");
+}
+
 TEST_F(PassOrderTest, CollectivePipelinerRunsAfterCollectiveQuantizer) {
   DebugOptions options = GetDebugOptionsForTest();
   options.set_xla_gpu_enable_pipelined_collectives(true);
From 9173cd54c01548e69612f51432294282f0c6dd5e Mon Sep 17 00:00:00 2001
From: Karlo Basioli
Date: Wed, 2 Apr 2025 05:45:19 -0700
Subject: [PATCH 0144/1324] [XLA:CPU] Add AOT compilation for microbenchmarks

PiperOrigin-RevId: 743098055
---
 .../xla/xla/backends/cpu/benchmarks/BUILD     |  22 +++
 .../cpu/benchmarks/aliasing_benchmark_test.cc |  23 ++--
 .../benchmarks/concatenate_benchmark_test.cc  |  11 +-
 .../benchmarks/convolution_benchmark_test.cc  | 128 +++++++++---------
 .../benchmarks/custom_call_benchmark_test.cc  |  24 ++--
 .../dag_execution_benchmark_test.cc           |   9 +-
 .../cpu/benchmarks/dot_benchmark_test.cc      |   9 +-
 .../dynamic_update_slice_benchmark_test.cc    |   9 +-
 .../benchmarks/elementwise_benchmark_test.cc  |  19 ++-
 .../cpu/benchmarks/exp_benchmark_test.cc      |  17 ++-
 .../cpu/benchmarks/fusion_benchmark_test.cc   |  40 ++++--
 .../cpu/benchmarks/gather_benchmark_test.cc   |   8 +-
 .../cpu/benchmarks/log_benchmark_test.cc      |  18 ++-
 .../benchmarks/optimizer_benchmark_test.cc    |   9 +-
 .../cpu/benchmarks/pad_benchmark_test.cc      |   8 +-
 .../benchmarks/reduction_benchmark_test.cc    |  15 +-
 .../cpu/benchmarks/scatter_benchmark_test.cc  |  35 +++--
 .../select_and_scatter_benchmark_test.cc      |  12 +-
 .../cpu/benchmarks/tanh_benchmark_test.cc     |  23 +++-
 .../cpu/benchmarks/topk_benchmark_test.cc     |  14 +-
 .../transposed_copy_benchmark_test.cc         |   9 +-
 .../transposed_dot_benchmark_test.cc          |  15 +-
 .../benchmarks/xnn_fusion_benchmark_test.cc   |  28 ++--
 23 files changed, 311 insertions(+), 194 deletions(-)

diff --git
a/third_party/xla/xla/backends/cpu/benchmarks/BUILD b/third_party/xla/xla/backends/cpu/benchmarks/BUILD index 5f528ed91c7683..0251f718339e38 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/BUILD +++ b/third_party/xla/xla/backends/cpu/benchmarks/BUILD @@ -86,6 +86,7 @@ xla_cc_test( srcs = ["aliasing_benchmark_test.cc"], deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:xla_data_proto_cc", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", @@ -102,6 +103,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -121,6 +123,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -140,6 +143,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -159,6 +163,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -178,6 +183,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -197,6 +203,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -216,6 +223,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -248,6 +256,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -267,6 +276,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -286,6 +296,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -303,6 +314,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -325,6 +337,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:array2d", "//xla:literal", "//xla:literal_util", @@ -345,6 +358,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. 
deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -364,6 +378,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -383,6 +398,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -400,6 +416,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -419,6 +436,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -438,6 +456,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:array2d", "//xla:literal", "//xla:literal_util", @@ -461,6 +480,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -480,6 +500,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -499,6 +520,7 @@ xla_cc_test( fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. deps = [ ":hlo_benchmark_runner", + ":multi_benchmark_config", "//xla:literal", "//xla:literal_util", "//xla:shape_util", diff --git a/third_party/xla/xla/backends/cpu/benchmarks/aliasing_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/aliasing_benchmark_test.cc index a03408f94ac841..4bfdd6ea31b4aa 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/aliasing_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/aliasing_benchmark_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/string_view.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" @@ -28,7 +29,8 @@ limitations under the License. 
namespace xla::cpu { namespace { -static void BM_ModelWithAliasing(benchmark::State& state) { +static void BM_ModelWithAliasing(benchmark::State& state, + HloBenchmarkOptions options) { int64_t num_executions = state.range(0); absl::string_view hlo = R"( @@ -42,22 +44,15 @@ ENTRY main.5 { } )"; - HloBenchmarkOptions benchmark_options; - benchmark_options.num_executions = num_executions; + options.num_executions = num_executions; - CHECK_OK(RunHloBenchmark(state, hlo, {}, {}, benchmark_options)); + CHECK_OK(RunHloBenchmark(state, hlo, {}, {}, options)); } -void GenerateModelWithAliasingArgs(benchmark::internal::Benchmark* benchmark) { - benchmark->MeasureProcessCPUTime(); - const std::vector num_executions = {1, 8}; - benchmark->ArgNames({"num_executions"}); - for (int64_t num_execution : num_executions) { - benchmark->Args({num_execution}); - } -} - -BENCHMARK(BM_ModelWithAliasing)->Apply(GenerateModelWithAliasingArgs); +XLA_CPU_BENCHMARK(BM_ModelWithAliasing) + ->ArgName("num_executions") + ->Arg(1) + ->Arg(8); } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/benchmarks/concatenate_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/concatenate_benchmark_test.cc index 733649e511f34f..793c79d4252974 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/concatenate_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/concatenate_benchmark_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape.h" @@ -32,7 +33,8 @@ limitations under the License. namespace xla::cpu { -static void BM_ConcatenateTwoR3F32(benchmark::State& state) { +static void BM_ConcatenateTwoR3F32(benchmark::State& state, + HloBenchmarkOptions options) { bool disable_parallel_backend = !static_cast(state.range(0)); int64_t dims[3] = {state.range(1), state.range(2), state.range(3)}; Shape shape = ShapeUtil::MakeShape(F32, dims); @@ -57,8 +59,7 @@ static void BM_ConcatenateTwoR3F32(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); auto p1 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); - HloBenchmarkOptions benchmark_options; - benchmark_options.disable_parallel_task_assigner = disable_parallel_backend; + options.disable_parallel_task_assigner = disable_parallel_backend; std::vector args = {&p0, &p1}; CHECK_OK(RunHloBenchmark(state, hlo, args, @@ -66,10 +67,10 @@ static void BM_ConcatenateTwoR3F32(benchmark::State& state) { {"$shape", absl::StrJoin(dims, ",")}, {"$out_shape", absl::StrJoin(out_dims, ",")}, {"$axis", absl::StrCat(axis)}}, - benchmark_options)); + options)); } -BENCHMARK(BM_ConcatenateTwoR3F32) +XLA_CPU_BENCHMARK(BM_ConcatenateTwoR3F32) ->MeasureProcessCPUTime() ->ArgNames({"parallel", "batch", "width", "height", "axis"}) // Fast Concat (memcpy, no parallelism) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/convolution_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/convolution_benchmark_test.cc index 3d355905eccc1a..e3beb047c09074 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/convolution_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/convolution_benchmark_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -32,7 +33,7 @@ namespace { bool IsOdd(int n) { return n % 2 == 1; } template -static void BM_Conv2D(benchmark::State& state) { +static void BM_Conv2D(benchmark::State& state, HloBenchmarkOptions options) { int batch = state.range(0); int height = state.range(1); int width = state.range(2); @@ -78,10 +79,12 @@ static void BM_Conv2D(benchmark::State& state) { {"$kernel_shape", kernel_shape.ToString()}, {"$window_size", absl::StrCat(kernel_h, "x", kernel_w)}, {"$padding", absl::StrCat(padding_h, "_", padding_h, "x", - padding_w, "_", padding_w)}})); + padding_w, "_", padding_w)}}, + options)); } -static void BM_GroupedConv2D(benchmark::State& state) { +static void BM_GroupedConv2D(benchmark::State& state, + HloBenchmarkOptions options) { int batch = state.range(0); int height = state.range(1); int width = state.range(2); @@ -133,11 +136,13 @@ static void BM_GroupedConv2D(benchmark::State& state) { {"$window_size", absl::StrCat(kernel_h, "x", kernel_w)}, {"$padding", absl::StrCat(padding_h, "_", padding_h, "x", padding_w, "_", padding_w)}, - {"$feature_group_count", absl::StrCat(feature_group_count)}})); + {"$feature_group_count", absl::StrCat(feature_group_count)}}, + options)); } // Regular strided 1D convolution. Shapes come from an actual use case. -static void BM_Conv1DStrided(benchmark::State& state) { +static void BM_Conv1DStrided(benchmark::State& state, + HloBenchmarkOptions options) { int input_channels = state.range(0); int output_channels = state.range(1); @@ -170,7 +175,8 @@ static void BM_Conv1DStrided(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo_module, args, {{"$input_shape", input_shape.ToString()}, {"$kernel_shape", kernel_shape.ToString()}, - {"$output_shape", output_shape.ToString()}})); + {"$output_shape", output_shape.ToString()}}, + options)); } // Transposed version (i.e. gradient) of BM_Conv1DStrided. In terms of shapes, @@ -179,7 +185,8 @@ static void BM_Conv1DStrided(benchmark::State& state) { // performance of this function with BM_Conv1DStrided). // Currently, the performance is few times worse than regular conv when they // should be similar. -static void BM_Conv1DTransposedStrided(benchmark::State& state) { +static void BM_Conv1DTransposedStrided(benchmark::State& state, + HloBenchmarkOptions options) { int input_channels = state.range(0); int output_channels = state.range(1); @@ -212,12 +219,13 @@ static void BM_Conv1DTransposedStrided(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo_module, args, {{"$input_shape", input_shape.ToString()}, {"$kernel_shape", kernel_shape.ToString()}, - {"$output_shape", output_shape.ToString()}})); + {"$output_shape", output_shape.ToString()}}, + options)); } // The same shapes as BM_Conv1DTransposedStrided, but with a different layout. 
static void BM_Conv1DTransposedStridedNonDefaultLayout( - benchmark::State& state) { + benchmark::State& state, HloBenchmarkOptions options) { int input_channels = state.range(0); int output_channels = state.range(1); std::string hlo_module = R"( @@ -249,12 +257,14 @@ static void BM_Conv1DTransposedStridedNonDefaultLayout( CHECK_OK(RunHloBenchmark(state, hlo_module, args, {{"$input_shape", input_shape.ToString()}, {"$kernel_shape", kernel_shape.ToString()}, - {"$output_shape", output_shape.ToString()}})); + {"$output_shape", output_shape.ToString()}}, + options)); } // Regular strided 2D convolution. Buffer sizes and convolution parameters are // based on an actual 1D use case, but adapted to a 2D convolution. -static void BM_Conv2DStrided(benchmark::State& state) { +static void BM_Conv2DStrided(benchmark::State& state, + HloBenchmarkOptions options) { std::string hlo_module = R"( HloModule jit_jconvf @@ -279,7 +289,7 @@ static void BM_Conv2DStrided(benchmark::State& state) { *LiteralUtil::CreateRandomLiteral(kernel_shape, &engine, 1.0f, 0.1f); std::vector args = {&input, &kernel}; - CHECK_OK(RunHloBenchmark(state, hlo_module, args)); + CHECK_OK(RunHloBenchmark(state, hlo_module, args, {}, options)); } // Transposed version (i.e. gradient) of BM_Conv2DStrided. In terms of shapes, @@ -288,7 +298,8 @@ static void BM_Conv2DStrided(benchmark::State& state) { // performance of this function with BM_Conv2DStrided). // Currently, the performance is orders of magnitude worse than regular conv // when they should be similar. -static void BM_Conv2DTransposedStrided(benchmark::State& state) { +static void BM_Conv2DTransposedStrided(benchmark::State& state, + HloBenchmarkOptions options) { std::string hlo_module = R"( HloModule jit_jconvt @@ -314,11 +325,12 @@ static void BM_Conv2DTransposedStrided(benchmark::State& state) { *LiteralUtil::CreateRandomLiteral(kernel_shape, &engine, 1.0f, 0.1f); std::vector args = {&input, &kernel}; - CHECK_OK(RunHloBenchmark(state, hlo_module, args)); + CHECK_OK(RunHloBenchmark(state, hlo_module, args, {}, options)); } // Regular (i.e. non-transposed) grouped and strided 2D convolution. -static void BM_GroupedConv2DStrided(benchmark::State& state) { +static void BM_GroupedConv2DStrided(benchmark::State& state, + HloBenchmarkOptions options) { int input_channels = state.range(0); int output_channels = state.range(1); int feature_group_count = state.range(2); @@ -356,14 +368,16 @@ static void BM_GroupedConv2DStrided(benchmark::State& state) { state, hlo_module, args, {{"$input_shape", input_shape.ToString()}, {"$kernel_shape", kernel_shape.ToString()}, - {"$feature_group_count", std::to_string(feature_group_count)}})); + {"$feature_group_count", std::to_string(feature_group_count)}}, + options)); } // Transposed version (i.e. gradient) of BM_GroupedConv2DStrided. In terms of // shapes, this operation can be thought of as reverse of regular strided // convolution, that's why input and output shapes are swapped (so we can // directly compare performance of this function with BM_GroupedConv2DStrided). 
-static void BM_GroupedConv2DTransposedStrided(benchmark::State& state) { +static void BM_GroupedConv2DTransposedStrided(benchmark::State& state, + HloBenchmarkOptions options) { int input_channels = state.range(0); int output_channels = state.range(1); int feature_group_count = state.range(2); @@ -401,56 +415,40 @@ static void BM_GroupedConv2DTransposedStrided(benchmark::State& state) { state, hlo_module, args, {{"$input_shape", input_shape.ToString()}, {"$kernel_shape", kernel_shape.ToString()}, - {"$feature_group_count", std::to_string(feature_group_count)}})); + {"$feature_group_count", std::to_string(feature_group_count)}}, + options)); } -// -------------------------------------------------------------------------- // -// Pixel CNN convolutions. -// -------------------------------------------------------------------------- // - -// Shapes from XLA convolution tests -BENCHMARK(BM_Conv2D) +XLA_CPU_BENCHMARK(BM_Conv2D) ->MeasureProcessCPUTime() + // -------------------------------------------------------------------------- + // // Pixel CNN convolutions. + // -------------------------------------------------------------------------- + // // Shapes from XLA convolution tests ->Args({8, 5, 5, 1, 1, 1, 32}) ->Args({8, 5, 5, 4, 1, 1, 32}) - ->Args({8, 128, 128, 4, 1, 1, 8}); - -// Shapes from TF convolution benchmarks. -BENCHMARK(BM_Conv2D) - ->MeasureProcessCPUTime() + ->Args({8, 128, 128, 4, 1, 1, 8}) + // Shapes from TF convolution benchmarks. ->Args({8, 32, 32, 128, 1, 1, 1024}) ->Args({16, 32, 32, 128, 1, 1, 1024}) - ->Args({32, 32, 32, 128, 1, 1, 1024}); - -// Shapes similar to Eigen spatial convolution benchmarks. -BENCHMARK(BM_Conv2D) - ->MeasureProcessCPUTime() + ->Args({32, 32, 32, 128, 1, 1, 1024}) + // Shapes similar to Eigen spatial convolution benchmarks. ->Args({32, 64, 64, 32, 1, 1, 64}) ->Args({32, 256, 256, 4, 1, 1, 16}) ->Args({32, 64, 64, 4, 1, 1, 16}) - ->Args({32, 32, 32, 96, 1, 1, 96}); - -// -------------------------------------------------------------------------- // -// 3x3 Convolution: SpatialConvolution -// -------------------------------------------------------------------------- // - -// Shapes from XLA convolution tests -BENCHMARK(BM_Conv2D) - ->MeasureProcessCPUTime() + ->Args({32, 32, 32, 96, 1, 1, 96}) + // -------------------------------------------------------------------------- + // // 3x3 Convolution: SpatialConvolution + // -------------------------------------------------------------------------- + // // Shapes from XLA convolution tests ->Args({8, 5, 5, 1, 3, 3, 32}) ->Args({8, 5, 5, 4, 3, 3, 32}) - ->Args({8, 128, 128, 4, 3, 3, 8}); - -// Shapes from TF convolution benchmarks -BENCHMARK(BM_Conv2D) - ->MeasureProcessCPUTime() + ->Args({8, 128, 128, 4, 3, 3, 8}) + // Shapes from TF convolution benchmarks ->Args({8, 32, 32, 128, 3, 3, 1024}) ->Args({16, 32, 32, 128, 3, 3, 1024}) - ->Args({32, 32, 32, 128, 3, 3, 1024}); - -// Shapes similar to Eigen spatial convolution benchmarks. -BENCHMARK(BM_Conv2D) - ->MeasureProcessCPUTime() + ->Args({32, 32, 32, 128, 3, 3, 1024}) + // Shapes similar to Eigen spatial convolution benchmarks. 
->Args({32, 64, 64, 32, 3, 3, 64}) ->Args({32, 256, 256, 4, 3, 3, 16}) ->Args({32, 64, 64, 4, 3, 3, 16}) @@ -460,7 +458,7 @@ BENCHMARK(BM_Conv2D) // Grouped convolution // -------------------------------------------------------------------------- // -BENCHMARK(BM_GroupedConv2D) +XLA_CPU_BENCHMARK(BM_GroupedConv2D) ->MeasureProcessCPUTime() ->Args({1, 45, 45, 1024, 5, 5, 1024, 1024}); @@ -468,38 +466,34 @@ BENCHMARK(BM_GroupedConv2D) // 1D and 2D strided convolutions // -------------------------------------------------------------------------- // -BENCHMARK(BM_Conv1DStrided) +XLA_CPU_BENCHMARK(BM_Conv1DStrided) ->MeasureProcessCPUTime() ->Args({1, 129}) ->Args({3, 129}); -BENCHMARK(BM_Conv1DTransposedStrided) +XLA_CPU_BENCHMARK(BM_Conv1DTransposedStrided) ->MeasureProcessCPUTime() ->Args({129, 1}) ->Args({129, 3}); -BENCHMARK(BM_Conv1DTransposedStridedNonDefaultLayout) +XLA_CPU_BENCHMARK(BM_Conv1DTransposedStridedNonDefaultLayout) ->MeasureProcessCPUTime() ->Args({129, 1}) ->Args({129, 3}); -BENCHMARK(BM_Conv2DStrided)->MeasureProcessCPUTime(); -BENCHMARK(BM_Conv2DTransposedStrided)->MeasureProcessCPUTime(); +XLA_CPU_BENCHMARK(BM_Conv2DStrided)->MeasureProcessCPUTime(); +XLA_CPU_BENCHMARK(BM_Conv2DTransposedStrided)->MeasureProcessCPUTime(); // -------------------------------------------------------------------------- // // Grouped strided convolutions // -------------------------------------------------------------------------- // -BENCHMARK(BM_GroupedConv2DStrided) - ->MeasureProcessCPUTime() - ->Args({128, 128, 128}); -BENCHMARK(BM_GroupedConv2DTransposedStrided) - ->MeasureProcessCPUTime() - ->Args({128, 128, 128}); -BENCHMARK(BM_GroupedConv2DStrided) +XLA_CPU_BENCHMARK(BM_GroupedConv2DStrided) ->MeasureProcessCPUTime() + ->Args({128, 128, 128}) ->Args({128, 128, 16}); -BENCHMARK(BM_GroupedConv2DTransposedStrided) +XLA_CPU_BENCHMARK(BM_GroupedConv2DTransposedStrided) ->MeasureProcessCPUTime() + ->Args({128, 128, 128}) ->Args({128, 128, 16}); } // namespace diff --git a/third_party/xla/xla/backends/cpu/benchmarks/custom_call_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/custom_call_benchmark_test.cc index c17f162a46191e..c5a759c7ca0648 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/custom_call_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/custom_call_benchmark_test.cc @@ -24,6 +24,7 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" #include "xla/literal.h" @@ -49,7 +50,8 @@ XLA_FFI_DEFINE_HANDLER( XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_bm$$minimal", "Host", kMinimal); -static void BM_CustomCall_Minimal(benchmark::State& state) { +static void BM_CustomCall_Minimal(benchmark::State& state, + HloBenchmarkOptions options) { const char* kModuleStr = R"( HloModule module @@ -60,7 +62,7 @@ static void BM_CustomCall_Minimal(benchmark::State& state) { } )"; CHECK_OK(RunHloBenchmark(state, kModuleStr, /*args=*/{}, - /*replacements=*/{})); + /*replacements=*/{}, options)); } static absl::Status ManyIntAttributes( @@ -94,7 +96,8 @@ XLA_FFI_DEFINE_HANDLER(kManyIntAttributes, ManyIntAttributes, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_bm$$many_int_attributes", "Host", kManyIntAttributes); -static void BM_CustomCall_16IntAttributes(benchmark::State& state) { +static void BM_CustomCall_16IntAttributes(benchmark::State& state, + HloBenchmarkOptions options) { absl::string_view hlo = R"( HloModule module @@ -111,7 +114,8 @@ static void BM_CustomCall_16IntAttributes(benchmark::State& state) { } config << "}"; CHECK_OK(RunHloBenchmark(state, hlo, /*args=*/{}, - /*replacements=*/{{"$config", config.str()}})); + /*replacements=*/{{"$config", config.str()}}, + options)); } static absl::Status ManyFloatBuffers( @@ -151,7 +155,8 @@ XLA_FFI_DEFINE_HANDLER(kManyFloatBuffers, ManyFloatBuffers, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_bm$$many_float_buffers", "Host", kManyFloatBuffers); -static void BM_CustomCall_16FloatBuffers(benchmark::State& state) { +static void BM_CustomCall_16FloatBuffers(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d = 128; absl::string_view hlo = R"( @@ -182,12 +187,13 @@ static void BM_CustomCall_16FloatBuffers(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); std::vector args(10, &p0); - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d", absl::StrCat(d)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d", absl::StrCat(d)}}, options)); } -BENCHMARK(BM_CustomCall_Minimal)->MeasureProcessCPUTime(); -BENCHMARK(BM_CustomCall_16IntAttributes)->MeasureProcessCPUTime(); -BENCHMARK(BM_CustomCall_16FloatBuffers)->MeasureProcessCPUTime(); +XLA_CPU_BENCHMARK(BM_CustomCall_Minimal)->MeasureProcessCPUTime(); +XLA_CPU_BENCHMARK(BM_CustomCall_16IntAttributes)->MeasureProcessCPUTime(); +XLA_CPU_BENCHMARK(BM_CustomCall_16FloatBuffers)->MeasureProcessCPUTime(); } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/benchmarks/dag_execution_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/dag_execution_benchmark_test.cc index 3d1782f34db096..82423c92a4abe5 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/dag_execution_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/dag_execution_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,8 @@ limitations under the License. 
namespace xla::cpu { -static void BM_DagExecution(benchmark::State& state) { +static void BM_DagExecution(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); // We use this benchmark to test how well XLA does the scheduling of the HLO @@ -88,10 +90,11 @@ static void BM_DagExecution(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); std::vector args = {&p0}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -BENCHMARK(BM_DagExecution) +XLA_CPU_BENCHMARK(BM_DagExecution) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/dot_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/dot_benchmark_test.cc index 3e46d5ea1fe9ef..ced6f8f7af0a93 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/dot_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/dot_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" @@ -31,7 +32,8 @@ limitations under the License. namespace xla::cpu { -static void BM_BatchedDot(benchmark::State& state) { +static void BM_BatchedDot(benchmark::State& state, + HloBenchmarkOptions options) { PrimitiveType dtype = static_cast(state.range(0)); int64_t d0 = state.range(1); int64_t d1 = state.range(2); @@ -68,11 +70,12 @@ static void BM_BatchedDot(benchmark::State& state) { state, hlo, args, {{"$dtype", primitive_util::LowercasePrimitiveTypeName(dtype)}, {"$d0", absl::StrCat(d0)}, - {"$d1", absl::StrCat(d1)}})); + {"$d1", absl::StrCat(d1)}}, + options)); } #define BENCHMARK_BATCHED_DOT(dtype) \ - BENCHMARK(BM_BatchedDot) \ + XLA_CPU_BENCHMARK(BM_BatchedDot) \ ->MeasureProcessCPUTime() \ ->Args({dtype, 1, 2}) \ ->Args({dtype, 1, 32}) \ diff --git a/third_party/xla/xla/backends/cpu/benchmarks/dynamic_update_slice_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/dynamic_update_slice_benchmark_test.cc index cd06619c55ebc7..1763738647c9b0 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/dynamic_update_slice_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/dynamic_update_slice_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,8 @@ limitations under the License. 
namespace xla::cpu { -static void BM_DynamicUpdateSliceF32(benchmark::State& state) { +static void BM_DynamicUpdateSliceF32(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -55,10 +57,11 @@ static void BM_DynamicUpdateSliceF32(benchmark::State& state) { auto p3 = LiteralUtil::CreateR0(0); std::vector args = {&p0, &p1, &p2, &p3}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -BENCHMARK(BM_DynamicUpdateSliceF32) +XLA_CPU_BENCHMARK(BM_DynamicUpdateSliceF32) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/elementwise_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/elementwise_benchmark_test.cc index 61225745a41a77..7960d8558f9df8 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/elementwise_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/elementwise_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,7 @@ limitations under the License. namespace xla::cpu { -static void BM_AddF32(benchmark::State& state) { +static void BM_AddF32(benchmark::State& state, HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -50,10 +51,11 @@ static void BM_AddF32(benchmark::State& state) { auto p1 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); std::vector args = {&p0, &p1}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -static void BM_AddBF16(benchmark::State& state) { +static void BM_AddBF16(benchmark::State& state, HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -73,10 +75,12 @@ static void BM_AddBF16(benchmark::State& state) { auto p1 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); std::vector args = {&p0, &p1}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -static void BM_ConvertF32ToBF16(benchmark::State& state) { +static void BM_ConvertF32ToBF16(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -94,11 +98,12 @@ static void BM_ConvertF32ToBF16(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); std::vector args = {&p0}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } #define BENCHMARK_SIZES(NAME) \ - BENCHMARK(NAME) \ + XLA_CPU_BENCHMARK(NAME) \ ->MeasureProcessCPUTime() \ ->Arg(128) \ ->Arg(256) \ diff --git a/third_party/xla/xla/backends/cpu/benchmarks/exp_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/exp_benchmark_test.cc index 9aad43a91473be..7369cc34f25eba 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/exp_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/exp_benchmark_test.cc @@ -21,6 +21,7 @@ 
limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,7 @@ limitations under the License. namespace xla::cpu { -static void BM_ExpF32(benchmark::State& state) { +static void BM_ExpF32(benchmark::State& state, HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -48,7 +49,8 @@ static void BM_ExpF32(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); std::vector args = {&p0}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } static void BM_ExpF16(benchmark::State& state) { @@ -73,7 +75,7 @@ static void BM_ExpF16(benchmark::State& state) { } #define REGISTER_EXP_BENCHMARK(NAME) \ - BENCHMARK(NAME) \ + XLA_CPU_BENCHMARK(NAME) \ ->MeasureProcessCPUTime() \ ->Arg(128) \ ->Arg(256) \ @@ -82,6 +84,13 @@ static void BM_ExpF16(benchmark::State& state) { ->Arg(4096); REGISTER_EXP_BENCHMARK(BM_ExpF32); -REGISTER_EXP_BENCHMARK(BM_ExpF16); +// TODO(b/406431945): add AOT for f16 exp +BENCHMARK(BM_ExpF16) + ->MeasureProcessCPUTime() + ->Arg(128) + ->Arg(256) + ->Arg(512) + ->Arg(1024) + ->Arg(4096); } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/benchmarks/fusion_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/fusion_benchmark_test.cc index 6556be2f8905af..1a1d79d00963e7 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/fusion_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/fusion_benchmark_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -32,7 +33,7 @@ limitations under the License. 
namespace xla::cpu { -static void BM_FusionF32(benchmark::State& state) { +static void BM_FusionF32(benchmark::State& state, HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -62,10 +63,12 @@ static void BM_FusionF32(benchmark::State& state) { auto p2 = *LiteralUtil::CreateRandomLiteral(scalar, &engine, 1.0f, 0.1f); std::vector args = {&p0, &p1, &p2}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -static void BM_FusionF32_2(benchmark::State& state) { +static void BM_FusionF32_2(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -138,10 +141,12 @@ static void BM_FusionF32_2(benchmark::State& state) { auto p6 = *LiteralUtil::CreateRandomLiteral(shape3, &engine, 1.0f, 0.1f); std::vector args = {&p0, &p1, &p2, &p3, &p4, &p5, &p6}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -static void BM_BcastFusionF32(benchmark::State& state) { +static void BM_BcastFusionF32(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -163,10 +168,12 @@ static void BM_BcastFusionF32(benchmark::State& state) { auto p1 = *LiteralUtil::CreateRandomLiteral(scalar, &engine, 1.0f, 0.1f); std::vector args = {&p0, &p1}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -static void BM_DynamicUpdateSliceFusionF32(benchmark::State& state) { +static void BM_DynamicUpdateSliceFusionF32(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -190,10 +197,12 @@ static void BM_DynamicUpdateSliceFusionF32(benchmark::State& state) { auto p2 = LiteralUtil::CreateR0(0); std::vector args = {&p0, &p1, &p2}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -static void BM_ChainOfAddF32(benchmark::State& state) { +static void BM_ChainOfAddF32(benchmark::State& state, + HloBenchmarkOptions options) { int64_t size = state.range(0); // In this benchmark we create a chain of additions starting from `p2` and @@ -242,10 +251,11 @@ static void BM_ChainOfAddF32(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$size", absl::StrCat(size)}, {"$parameters", parameters}, - {"$additions", additions}})); + {"$additions", additions}}, + options)); } -BENCHMARK(BM_FusionF32) +XLA_CPU_BENCHMARK(BM_FusionF32) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) @@ -254,14 +264,14 @@ BENCHMARK(BM_FusionF32) ->Arg(8192) ->Arg(16384); -BENCHMARK(BM_FusionF32_2) +XLA_CPU_BENCHMARK(BM_FusionF32_2) ->MeasureProcessCPUTime() ->Arg(40) ->Arg(80) ->Arg(160) ->Arg(240); -BENCHMARK(BM_BcastFusionF32) +XLA_CPU_BENCHMARK(BM_BcastFusionF32) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) @@ -270,7 +280,7 @@ BENCHMARK(BM_BcastFusionF32) ->Arg(8192) ->Arg(16384); -BENCHMARK(BM_DynamicUpdateSliceFusionF32) +XLA_CPU_BENCHMARK(BM_DynamicUpdateSliceFusionF32) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) @@ -279,7 +289,7 @@ BENCHMARK(BM_DynamicUpdateSliceFusionF32) ->Arg(8192) ->Arg(16384); -BENCHMARK(BM_ChainOfAddF32) +XLA_CPU_BENCHMARK(BM_ChainOfAddF32) 
->MeasureProcessCPUTime() ->Arg(64) ->Arg(128) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/gather_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/gather_benchmark_test.cc index 7cba1748fba918..fe5ca3c8eb2842 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/gather_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/gather_benchmark_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/array2d.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -31,7 +32,7 @@ limitations under the License. namespace xla::cpu { -static void BM_GatherS32(benchmark::State& state) { +static void BM_GatherS32(benchmark::State& state, HloBenchmarkOptions options) { int64_t d0 = state.range(0); int64_t d1 = state.range(1); int64_t slice_size = state.range(2); @@ -74,10 +75,11 @@ static void BM_GatherS32(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}, {"$d1", absl::StrCat(d1)}, - {"$slice_size", absl::StrCat(slice_size)}})); + {"$slice_size", absl::StrCat(slice_size)}}, + options)); } -BENCHMARK(BM_GatherS32) +XLA_CPU_BENCHMARK(BM_GatherS32) ->MeasureProcessCPUTime() ->Args({3, 3, 1}) ->Args({3, 3, 2}) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/log_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/log_benchmark_test.cc index 712c4305d6a163..ae6c518561405a 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/log_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/log_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,7 @@ limitations under the License. 
namespace xla::cpu { -static void BM_LogF32(benchmark::State& state) { +static void BM_LogF32(benchmark::State& state, HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -48,7 +49,8 @@ static void BM_LogF32(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); std::vector args = {&p0}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } static void BM_LogF16(benchmark::State& state) { @@ -73,7 +75,7 @@ static void BM_LogF16(benchmark::State& state) { } #define REGISTER_EXP_BENCHMARK(NAME) \ - BENCHMARK(NAME) \ + XLA_CPU_BENCHMARK(NAME) \ ->MeasureProcessCPUTime() \ ->Arg(128) \ ->Arg(256) \ @@ -82,6 +84,14 @@ static void BM_LogF16(benchmark::State& state) { ->Arg(4096); REGISTER_EXP_BENCHMARK(BM_LogF32); -REGISTER_EXP_BENCHMARK(BM_LogF16); + +// TODO(b/406431945): add AOT for f16 log +BENCHMARK(BM_LogF16) + ->MeasureProcessCPUTime() + ->Arg(128) + ->Arg(256) + ->Arg(512) + ->Arg(1024) + ->Arg(4096); } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/benchmarks/optimizer_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/optimizer_benchmark_test.cc index 3a4a6d80b1c166..0daf44efc6a950 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/optimizer_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/optimizer_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,8 @@ limitations under the License. namespace xla::cpu { -static void BM_Optimizer0(benchmark::State& state) { +static void BM_Optimizer0(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -107,10 +109,11 @@ static void BM_Optimizer0(benchmark::State& state) { auto p1 = *LiteralUtil::CreateRandomLiteral(scalar, &engine, 1, 2); std::vector args = {&p0, &p0, &p0, &p1}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -BENCHMARK(BM_Optimizer0) +XLA_CPU_BENCHMARK(BM_Optimizer0) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/pad_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/pad_benchmark_test.cc index 195c7b5e6a6293..15fe6545d6c3b0 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/pad_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/pad_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,7 @@ limitations under the License. 
namespace xla::cpu { -static void BM_PadF32(benchmark::State& state) { +static void BM_PadF32(benchmark::State& state, HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -53,10 +54,11 @@ static void BM_PadF32(benchmark::State& state) { *LiteralUtil::CreateRandomLiteral(value_shape, &engine, 1.0f, 0.1f); std::vector args = {&p0, &p1}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -BENCHMARK(BM_PadF32) +XLA_CPU_BENCHMARK(BM_PadF32) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/reduction_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/reduction_benchmark_test.cc index 3e9222f63a858e..c99081f5ce0eac 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/reduction_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/reduction_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,8 @@ limitations under the License. namespace xla::cpu { -static void BM_ReduceAddF32(benchmark::State& state) { +static void BM_ReduceAddF32(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -55,10 +57,12 @@ static void BM_ReduceAddF32(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); std::vector args = {&p0}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -static void BM_ReduceAddBF16(benchmark::State& state) { +static void BM_ReduceAddBF16(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -83,11 +87,12 @@ static void BM_ReduceAddBF16(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); std::vector args = {&p0}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } #define BENCHMARK_SIZES(NAME) \ - BENCHMARK(NAME) \ + XLA_CPU_BENCHMARK(NAME) \ ->MeasureProcessCPUTime() \ ->Arg(128) \ ->Arg(256) \ diff --git a/third_party/xla/xla/backends/cpu/benchmarks/scatter_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/scatter_benchmark_test.cc index 44d05970d6e6d7..51cc9e1714dddc 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/scatter_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/scatter_benchmark_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "xla/array2d.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/parser/hlo_parser.h" @@ -72,7 +73,7 @@ Literal CreateReduceIndices(int32_t num_elems, int32_t step) { return LiteralUtil::CreateR2FromArray2D(array); } -void BM_ScatterS32_R1(benchmark::State& state) { +void BM_ScatterS32_R1(benchmark::State& state, HloBenchmarkOptions options) { const int64_t d0 = state.range(0); const int64_t slice_size = state.range(1); @@ -113,12 +114,13 @@ void BM_ScatterS32_R1(benchmark::State& state) { std::vector args = {&operand, &scatter_indices, &update}; CHECK_OK(RunHloBenchmark( state, hlo, args, - {{"$d0", absl::StrCat(d0)}, {"$slice_size", absl::StrCat(slice_size)}})); + {{"$d0", absl::StrCat(d0)}, {"$slice_size", absl::StrCat(slice_size)}}, + options)); state.SetComplexityN(state.range(1)); } -void BM_ScatterS32_R2(benchmark::State& state) { +void BM_ScatterS32_R2(benchmark::State& state, HloBenchmarkOptions options) { const int64_t d0 = state.range(0); const int64_t d1 = d0; const int64_t slice_size = state.range(1); @@ -160,10 +162,11 @@ void BM_ScatterS32_R2(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}, {"$d1", absl::StrCat(d1)}, - {"$slice_size", absl::StrCat(slice_size)}})); + {"$slice_size", absl::StrCat(slice_size)}}, + options)); } -void BM_ScatterS32_R3(benchmark::State& state) { +void BM_ScatterS32_R3(benchmark::State& state, HloBenchmarkOptions options) { const int64_t d0 = state.range(0); const int64_t d1 = d0; const int64_t d2 = d0; @@ -208,10 +211,12 @@ void BM_ScatterS32_R3(benchmark::State& state) { {{"$d0", absl::StrCat(d0)}, {"$d1", absl::StrCat(d1)}, {"$d2", absl::StrCat(d2)}, - {"$slice_size", absl::StrCat(slice_size)}})); + {"$slice_size", absl::StrCat(slice_size)}}, + options)); } -void BM_SimpleScatterReduceF32_R3(benchmark::State& state) { +void BM_SimpleScatterReduceF32_R3(benchmark::State& state, + HloBenchmarkOptions options) { const int64_t d0 = state.range(0); const int64_t d1 = state.range(1); const int64_t d2 = state.range(2); @@ -263,16 +268,22 @@ void BM_SimpleScatterReduceF32_R3(benchmark::State& state) { update_shape, &engine, /*mean=*/50, /*stddev=*/10); std::vector args = {&operand, &indices, &update}; - CHECK_OK(RunHloBenchmark(state, hlo, args)); + CHECK_OK(RunHloBenchmark(state, hlo, args, {}, options)); } // these all have the same number of elements in the operand // (2^18) == (2^9)^2 == (2^6)^3 -BENCHMARK(BM_ScatterS32_R1)->MeasureProcessCPUTime()->Args({1 << 18, 1 << 18}); -BENCHMARK(BM_ScatterS32_R2)->MeasureProcessCPUTime()->Args({1 << 9, 1 << 9}); -BENCHMARK(BM_ScatterS32_R3)->MeasureProcessCPUTime()->Args({1 << 6, 1 << 6}); +XLA_CPU_BENCHMARK(BM_ScatterS32_R1) + ->MeasureProcessCPUTime() + ->Args({1 << 18, 1 << 18}); +XLA_CPU_BENCHMARK(BM_ScatterS32_R2) + ->MeasureProcessCPUTime() + ->Args({1 << 9, 1 << 9}); +XLA_CPU_BENCHMARK(BM_ScatterS32_R3) + ->MeasureProcessCPUTime() + ->Args({1 << 6, 1 << 6}); -BENCHMARK(BM_SimpleScatterReduceF32_R3) +XLA_CPU_BENCHMARK(BM_SimpleScatterReduceF32_R3) ->MeasureProcessCPUTime() ->ArgNames({"d0", "d1", "d2", "num_slices"}) ->Args({1, 64, 8, 1}) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/select_and_scatter_benchmark_test.cc index 
9d29c2681ad3f4..83d7db3858d86a 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,8 @@ limitations under the License. namespace xla::cpu { -static void BM_SelectAndScatterF32(benchmark::State& state) { +static void BM_SelectAndScatterF32(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); int64_t d1 = (d0 - 1) / 2; @@ -67,12 +69,12 @@ static void BM_SelectAndScatterF32(benchmark::State& state) { auto p2 = LiteralUtil::CreateR0(1.0f); std::vector args = {&p0, &p1, &p2}; - CHECK_OK( - RunHloBenchmark(state, hlo, args, - {{"$d0", absl::StrCat(d0)}, {"$d1", absl::StrCat(d1)}})); + CHECK_OK(RunHloBenchmark( + state, hlo, args, {{"$d0", absl::StrCat(d0)}, {"$d1", absl::StrCat(d1)}}, + options)); } -BENCHMARK(BM_SelectAndScatterF32) +XLA_CPU_BENCHMARK(BM_SelectAndScatterF32) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/tanh_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/tanh_benchmark_test.cc index f0df31c0e68b25..11ffce2f19d2a6 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/tanh_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/tanh_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,7 @@ limitations under the License. 
namespace xla::cpu { -static void BM_TanhF32(benchmark::State& state) { +static void BM_TanhF32(benchmark::State& state, HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -48,7 +49,8 @@ static void BM_TanhF32(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); std::vector args = {&p0}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } static void BM_TanhF16(benchmark::State& state) { @@ -72,7 +74,7 @@ static void BM_TanhF16(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); } -static void BM_TanhF64(benchmark::State& state) { +static void BM_TanhF64(benchmark::State& state, HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -90,11 +92,12 @@ static void BM_TanhF64(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); std::vector args = {&p0}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } #define REGISTER_TANH_BENCHMARK(NAME) \ - BENCHMARK(NAME) \ + XLA_CPU_BENCHMARK(NAME) \ ->MeasureProcessCPUTime() \ ->Arg(128) \ ->Arg(256) \ @@ -103,7 +106,15 @@ static void BM_TanhF64(benchmark::State& state) { ->Arg(4096); REGISTER_TANH_BENCHMARK(BM_TanhF32); -REGISTER_TANH_BENCHMARK(BM_TanhF16); REGISTER_TANH_BENCHMARK(BM_TanhF64); +// TODO(b/406431945): add AOT for f16 tanh +BENCHMARK(BM_TanhF16) + ->MeasureProcessCPUTime() + ->Arg(128) + ->Arg(256) + ->Arg(512) + ->Arg(1024) + ->Arg(4096); + } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/benchmarks/topk_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/topk_benchmark_test.cc index 29557cc1304508..eac2774bcaab5e 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/topk_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/topk_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/tsl/platform/test_benchmark.h" @@ -28,7 +29,8 @@ limitations under the License. 
namespace xla::cpu { -static void BM_TopKCustomCall_F32(benchmark::State& state) { +static void BM_TopKCustomCall_F32(benchmark::State& state, + HloBenchmarkOptions options) { int64_t k = state.range(0); int64_t batch = state.range(1); int64_t length = state.range(2); @@ -53,10 +55,11 @@ static void BM_TopKCustomCall_F32(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, {&x}, {{"$batch", absl::StrCat(batch)}, {"$length", absl::StrCat(length)}, - {"$k", absl::StrCat(k)}})); + {"$k", absl::StrCat(k)}}, + options)); } -static void BM_TopK_BF16(benchmark::State& state) { +static void BM_TopK_BF16(benchmark::State& state, HloBenchmarkOptions options) { int64_t k = state.range(0); int64_t batch = state.range(1); int64_t length = state.range(2); @@ -80,11 +83,12 @@ static void BM_TopK_BF16(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, {&x}, {{"$batch", absl::StrCat(batch)}, {"$length", absl::StrCat(length)}, - {"$k", absl::StrCat(k)}})); + {"$k", absl::StrCat(k)}}, + options)); } #define BENCHMARK_TOPK(name) \ - BENCHMARK(name) \ + XLA_CPU_BENCHMARK(name) \ ->MeasureProcessCPUTime() \ ->ArgNames({"k", "batch", "length"}) \ ->Args({4, 4, 64}) \ diff --git a/third_party/xla/xla/backends/cpu/benchmarks/transposed_copy_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/transposed_copy_benchmark_test.cc index f999ae9395289a..c11698371fce62 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/transposed_copy_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/transposed_copy_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,8 @@ limitations under the License. namespace xla::cpu { -static void BM_TransposeAndCopy(benchmark::State& state) { +static void BM_TransposeAndCopy(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -49,11 +51,12 @@ static void BM_TransposeAndCopy(benchmark::State& state) { auto p0 = *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); std::vector args = {&p0}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } #define REGISTER_BENCHMARK(NAME) \ - BENCHMARK(NAME) \ + XLA_CPU_BENCHMARK(NAME) \ ->MeasureProcessCPUTime() \ ->Arg(128) \ ->Arg(256) \ diff --git a/third_party/xla/xla/backends/cpu/benchmarks/transposed_dot_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/transposed_dot_benchmark_test.cc index 02b875aa986744..c0ec6b61537668 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/transposed_dot_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/transposed_dot_benchmark_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -30,7 +31,8 @@ limitations under the License. 
namespace xla::cpu { -static void BM_TransposeAndDot(benchmark::State& state) { +static void BM_TransposeAndDot(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -56,10 +58,12 @@ static void BM_TransposeAndDot(benchmark::State& state) { *LiteralUtil::CreateRandomLiteral(p1_shape, &engine, 1.0f, 0.1f); std::vector args = {&p0, &p1}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } -static void BM_TransposedDot(benchmark::State& state) { +static void BM_TransposedDot(benchmark::State& state, + HloBenchmarkOptions options) { int64_t d0 = state.range(0); absl::string_view hlo = R"( @@ -84,11 +88,12 @@ static void BM_TransposedDot(benchmark::State& state) { *LiteralUtil::CreateRandomLiteral(p1_shape, &engine, 1.0f, 0.1f); std::vector args = {&p0, &p1}; - CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); + CHECK_OK( + RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}, options)); } #define REGISTER_BENCHMARK(NAME) \ - BENCHMARK(NAME) \ + XLA_CPU_BENCHMARK(NAME) \ ->MeasureProcessCPUTime() \ ->Arg(128) \ ->Arg(256) \ diff --git a/third_party/xla/xla/backends/cpu/benchmarks/xnn_fusion_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/xnn_fusion_benchmark_test.cc index e1284b1fd27f93..50463098989302 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/xnn_fusion_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/xnn_fusion_benchmark_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/strings/str_cat.h" @@ -23,6 +24,7 @@ limitations under the License. #include "absl/strings/substitute.h" #include "absl/types/span.h" #include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/backends/cpu/benchmarks/multi_benchmark_config.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" @@ -33,6 +35,7 @@ limitations under the License. namespace xla::cpu { static absl::Status RunFusionBenchmark(benchmark::State& state, + HloBenchmarkOptions options, absl::string_view hlo, bool is_xnn_fusion = false) { int64_t d0 = state.range(0); // Tensor size. @@ -55,7 +58,6 @@ static absl::Status RunFusionBenchmark(benchmark::State& state, ShapeUtil::MakeShape(F32, {d0, d0}), &engine, 1.0f, 0.1f); std::vector args = {&p0, &p1}; - HloBenchmarkOptions options; if (is_xnn_fusion) options.disable_parallel_task_assigner = true; return RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}, @@ -64,7 +66,8 @@ static absl::Status RunFusionBenchmark(benchmark::State& state, options); } -static void BM_EltwiseF32(benchmark::State& state) { +static void BM_EltwiseF32(benchmark::State& state, + HloBenchmarkOptions options) { // Perform `n+1` iterations of `add` and `multiply`, then end with `subtract`. absl::string_view hlo = R"( HloModule eltwise_f32_$n @@ -78,10 +81,11 @@ static void BM_EltwiseF32(benchmark::State& state) { ROOT sub = f32[$d0,$d0] subtract(mul$n, p0) } )"; - CHECK_OK(RunFusionBenchmark(state, hlo)); + CHECK_OK(RunFusionBenchmark(state, std::move(options), hlo)); } -static void BM_XnnEltwiseF32(benchmark::State& state) { +static void BM_XnnEltwiseF32(benchmark::State& state, + HloBenchmarkOptions options) { // Perform `n+1` iterations of `add` and `multiply`, then end with `subtract`. 
absl::string_view hlo = R"( HloModule eltwise_f32_$n @@ -103,10 +107,12 @@ static void BM_XnnEltwiseF32(benchmark::State& state) { backend_config={"fusion_config": {kind: "__xnn_fusion"}} } )"; - CHECK_OK(RunFusionBenchmark(state, hlo, /*is_xnn_fusion=*/true)); + CHECK_OK(RunFusionBenchmark(state, std::move(options), hlo, + /*is_xnn_fusion=*/true)); } -static void BM_DotAndEltwiseF32(benchmark::State& state) { +static void BM_DotAndEltwiseF32(benchmark::State& state, + HloBenchmarkOptions options) { // Perform `dot` followed by `n+1` iterations of `add` and `multiply`, then // end with `subtract`. absl::string_view hlo = R"( @@ -123,10 +129,11 @@ static void BM_DotAndEltwiseF32(benchmark::State& state) { ROOT sub = f32[$d0,$d0] subtract(mul$n, p0) } )"; - CHECK_OK(RunFusionBenchmark(state, hlo)); + CHECK_OK(RunFusionBenchmark(state, std::move(options), hlo)); } -static void BM_XnnDotAndEltwiseF32(benchmark::State& state) { +static void BM_XnnDotAndEltwiseF32(benchmark::State& state, + HloBenchmarkOptions options) { // Perform `dot` followed by `n+1` iterations of `add` and `multiply`, then // end with `subtract`. absl::string_view hlo = R"( @@ -151,11 +158,12 @@ static void BM_XnnDotAndEltwiseF32(benchmark::State& state) { backend_config={"fusion_config": {kind: "__xnn_fusion"}} } )"; - CHECK_OK(RunFusionBenchmark(state, hlo, /*is_xnn_fusion=*/true)); + CHECK_OK(RunFusionBenchmark(state, std::move(options), hlo, + /*is_xnn_fusion=*/true)); } #define BENCHMARK_FUSION(name) \ - BENCHMARK(name) \ + XLA_CPU_BENCHMARK(name) \ ->MeasureProcessCPUTime() \ ->Args({1024, 4}) \ ->Args({1024, 8}) \ From eb5f107dca93442de87b531d83a66ba095a253d1 Mon Sep 17 00:00:00 2001 From: Ranko Sredojevic Date: Wed, 2 Apr 2025 05:46:34 -0700 Subject: [PATCH 0145/1324] Provides a simple task executor API on top of a fixed-size thread pool. - Tries to fail fast should any action from a batch fail, i.e. does not wait for all submitted actions to finish. - Allows user to enforce the number of worker threads per parallelization call, up to the maximum number of threads existing in the pool. PiperOrigin-RevId: 743098459 --- .../xla/xla/hlo/utils/concurrency/BUILD | 66 +++++++++ .../utils/concurrency/tsl_task_executor.cc | 137 ++++++++++++++++++ .../hlo/utils/concurrency/tsl_task_executor.h | 71 +++++++++ .../concurrency/tsl_task_executor_test.cc | 119 +++++++++++++++ .../xla/hlo/utils/concurrency/type_adapters.h | 66 +++++++++ .../utils/concurrency/type_adapters_test.cc | 73 ++++++++++ 6 files changed, 532 insertions(+) create mode 100644 third_party/xla/xla/hlo/utils/concurrency/BUILD create mode 100644 third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.cc create mode 100644 third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.h create mode 100644 third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor_test.cc create mode 100644 third_party/xla/xla/hlo/utils/concurrency/type_adapters.h create mode 100644 third_party/xla/xla/hlo/utils/concurrency/type_adapters_test.cc diff --git a/third_party/xla/xla/hlo/utils/concurrency/BUILD b/third_party/xla/xla/hlo/utils/concurrency/BUILD new file mode 100644 index 00000000000000..19bf1df8dff806 --- /dev/null +++ b/third_party/xla/xla/hlo/utils/concurrency/BUILD @@ -0,0 +1,66 @@ +# Infrastructure for parallelization of compilation tasks. 
+load("//xla:xla.default.bzl", "xla_cc_test") +# copybara:uncomment load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "type_adapters", + hdrs = ["type_adapters.h"], + # copybara:uncomment compatible_with = get_compatible_with_portable(), + deps = [ + "@com_google_absl//absl/functional:any_invocable", + ], +) + +cc_library( + name = "tsl_task_executor", + srcs = ["tsl_task_executor.cc"], + hdrs = ["tsl_task_executor.h"], + # copybara:uncomment compatible_with = get_compatible_with_portable(), + deps = [ + ":type_adapters", + "//xla/tsl/platform:env", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:platform_port", + ], +) + +## Tests below. + +xla_cc_test( + name = "type_adapters_test", + srcs = ["type_adapters_test.cc"], + deps = [ + ":type_adapters", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_googletest//:gtest_main", + ], +) + +xla_cc_test( + name = "tsl_task_executor_test", + size = "small", + srcs = ["tsl_task_executor_test.cc"], + deps = [ + ":tsl_task_executor", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.cc b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.cc new file mode 100644 index 00000000000000..c363b5420ca1a2 --- /dev/null +++ b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.cc @@ -0,0 +1,137 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/utils/concurrency/tsl_task_executor.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/synchronization/blocking_counter.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/utils/concurrency/type_adapters.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/threadpool.h" +#include "tsl/platform/cpu_info.h" + +namespace xla::concurrency { +namespace { + +int ResolveParallelism(std::optional parallelism) { + if (!parallelism.has_value() || *parallelism <= 0 || + *parallelism > tsl::port::MaxParallelism()) { + return tsl::port::MaxParallelism(); + } + return *parallelism; +} + +// Run all actions in a loop within a single schedulable unit. +// This way we guarantee sequential execution. 
+void DispatchSequentialRun(tsl::thread::ThreadPool* thread_pool,
+                           absl::Status& final_status,
+                           absl::BlockingCounter& finished_counter,
+                           std::vector<Task>& original_actions) {
+  thread_pool->Schedule(
+      [&final_status, &finished_counter,
+       actions = TurnMoveOnlyToCopyableWithCaching<absl::Status>::FromVector(
+           std::move(original_actions))]() mutable {
+        for (auto& action : actions) {
+          auto action_status = std::move(action)();
+          if (!action_status.ok()) {
+            final_status = action_status;
+            finished_counter
+                .DecrementCount();  // this will unblock the caller; count == 1
+            return;
+          }
+        }
+        final_status = absl::OkStatus();
+        finished_counter.DecrementCount();
+      });
+}
+
+// Run each action as a separately schedulable unit.
+void DispatchParallelRun(tsl::thread::ThreadPool* thread_pool,
+                         absl::Status& final_status,
+                         absl::BlockingCounter& finished_counter,
+                         absl::Mutex& mu_final_status,
+                         std::vector<Task>& actions) {
+  // When using `tsl::thread::ThreadPool` directly we need to count finished
+  // tasks and signal completion once all are done. Without `finished_counter`
+  // we would not know when all tasks have finished and the final status can
+  // be returned to the caller.
+  for (auto& action : actions) {
+    thread_pool->Schedule([&final_status, &finished_counter, &mu_final_status,
+                           action = TurnMoveOnlyToCopyableWithCaching(
+                               std::move(action))]() mutable {
+      // Pseudo-cancellation.
+      // The actions will not be invoked. However, the `ThreadPool` will
+      // iterate through all the scheduled tasks and check the status.
+      // Cancellation complexity is O(#tasks).
+      absl::Status current_status = absl::OkStatus();
+      {
+        absl::ReaderMutexLock reader_lock{&mu_final_status};
+        current_status = final_status;
+      }
+      if (current_status.ok()) {
+        auto action_status = std::move(action)();
+        if (!action_status.ok()) {
+          absl::MutexLock write_lock{&mu_final_status};
+          final_status = action_status;
+        }
+      }
+      // Must be the last thing we touch.
+      finished_counter.DecrementCount();
+    });
+  }
+}
+
+}  // namespace
+
+TslTaskExecutor::TslTaskExecutor(std::optional<int> max_parallelism) {
+  auto parallelism = ResolveParallelism(max_parallelism);
+
+  thread_pool_ = std::make_unique<tsl::thread::ThreadPool>(
+      tsl::Env::Default(), kThreadPoolName, parallelism);
+}
+
+absl::Status TslTaskExecutor::ExecuteIndependentTasks(
+    std::vector<Task> tasks, std::optional<int> parallelism) {
+  auto actual_parallelism = ResolveParallelism(parallelism);
+
+  if (actual_parallelism == 1) {  // NOMUTANTS -- Functionally equivalent code
+                                  // paths; but the other is parallelized.
+    // Enforce sequential execution for debugging.
+    absl::BlockingCounter finished_counter(1);
+    absl::Status final_status = absl::OkStatus();
+    DispatchSequentialRun(thread_pool_.get(), final_status, finished_counter,
+                          tasks);
+    finished_counter.Wait();
+    return final_status;
+  }
+
+  absl::Status final_status = absl::OkStatus();
+  {
+    absl::BlockingCounter finished_counter(tasks.size());
+    absl::Mutex mu_final_status;
+
+    DispatchParallelRun(thread_pool_.get(), final_status, finished_counter,
+                        mu_final_status, tasks);
+    // Wait for all tasks to finish, so `finished_counter` and
+    // `mu_final_status` can be destroyed.
+    finished_counter.Wait();
+  }
+  return final_status;
+}
+}  // namespace xla::concurrency
diff --git a/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.h b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.h
new file mode 100644
index 00000000000000..9f8b684ad6a50c
--- /dev/null
+++ b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.h
@@ -0,0 +1,71 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_HLO_UTILS_CONCURRENCY_TSL_TASK_EXECUTOR_H_
+#define XLA_HLO_UTILS_CONCURRENCY_TSL_TASK_EXECUTOR_H_
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "xla/tsl/platform/threadpool.h"
+
+namespace xla::concurrency {
+
+// Tasks must signal a status. We promise to call tasks at most once.
+using Task = absl::AnyInvocable<absl::Status() &&>;
+
+// A thread pool with a higher-level API for parallelization of compiler
+// passes. Not thread safe.
+//
+// All calls are synchronous. Specifically, the call to parallelize work blocks
+// until the work is done, or canceled due to failure of any of the submitted
+// tasks. Once a parallelization call unblocks, implementations must guarantee
+// that no value captured by any of the submitted tasks will be accessed going
+// forward. Specifically, any captured values can be destroyed after the
+// parallelization call returns, even when the work is canceled.
+//
+// This design is chosen for simplicity & expediency. It has the obvious
+// downside that blocking until all work is done will result in many threads
+// idling towards the end of the execution.
+//
+// Features
+// - A batch submitted for execution fails if any individual task fails.
+// - Guarantees in-order processing of tasks when `parallelism` is 1.
+class TslTaskExecutor {
+ public:
+  // Runs all the actions on `parallelism` threads. If fewer threads are
+  // available, runs on as many as it has.
+  //
+  // When `parallelism` == 1, sequential execution is guaranteed.
+  absl::Status ExecuteIndependentTasks(
+      std::vector<Task> tasks, std::optional<int> parallelism = std::nullopt);
+
+  explicit TslTaskExecutor(std::optional<int> max_parallelism = std::nullopt);
+
+ private:
+  std::unique_ptr<tsl::thread::ThreadPool> thread_pool_;
+
+  // std::string because `tsl::thread::ThreadPool` wants a string and not a
+  // view.
+  const std::string kThreadPoolName = "TslTaskExecutor";
+};
+
+}  // namespace xla::concurrency
+#endif  // XLA_HLO_UTILS_CONCURRENCY_TSL_TASK_EXECUTOR_H_
diff --git a/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor_test.cc b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor_test.cc
new file mode 100644
index 00000000000000..1ccd5129463d89
--- /dev/null
+++ b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor_test.cc
@@ -0,0 +1,119 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor_test.cc b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor_test.cc
new file mode 100644
index 00000000000000..1ccd5129463d89
--- /dev/null
+++ b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor_test.cc
@@ -0,0 +1,119 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/hlo/utils/concurrency/tsl_task_executor.h"
+
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+
+namespace xla::concurrency {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+TEST(TslTaskExecutorTest, ParallelismOneExecutesInOrder) {
+  const int kSlowWrite = 42;
+  const int kMediumWrite = 79;
+  const int kFastWrite = 255;
+
+  const unsigned int kSlowWait = 1000;
+  const unsigned int kMediumWait = 300;
+  const unsigned int kFastWait = 10;
+
+  auto task_executor = TslTaskExecutor(3);
+
+  std::vector<int> results;
+
+  std::vector<Task> actions;
+  actions.push_back([&results, kSlowWrite]() {
+    absl::SleepFor(absl::Milliseconds(kSlowWait));
+    results.push_back(kSlowWrite);
+    return absl::OkStatus();
+  });
+  actions.push_back([&results, kMediumWrite]() {
+    absl::SleepFor(absl::Milliseconds(kMediumWait));
+    results.push_back(kMediumWrite);
+    return absl::OkStatus();
+  });
+  actions.push_back([&results, kFastWrite]() {
+    absl::SleepFor(absl::Milliseconds(kFastWait));
+    results.push_back(kFastWrite);
+    return absl::OkStatus();
+  });
+
+  TF_ASSERT_OK(task_executor.ExecuteIndependentTasks(std::move(actions), 1));
+  EXPECT_THAT(results,
+              ElementsAreArray({kSlowWrite, kMediumWrite, kFastWrite}));
+}
+
+TEST(TslTaskExecutorTest, SuccessfulExecutionReturnsOkStatus) {
+  auto task_executor = TslTaskExecutor(3);
+
+  std::vector<Task> actions;
+  for (int i = 0; i < 20; ++i) {
+    actions.push_back([]() { return absl::OkStatus(); });
+  }
+
+  TF_EXPECT_OK(task_executor.ExecuteIndependentTasks(std::move(actions)));
+}
+
+TEST(TslTaskExecutorTest, OnFailureNotAllWorkFinishes) {
+  const int kBeforeCount = 20;
+  const int kAfterCount = 100;
+  const int kThreadCount = 5;
+  auto task_executor = TslTaskExecutor(kThreadCount);
+
+  int finish_counter = 0;
+  absl::Mutex mu_finish_counter;
+
+  std::vector<Task> actions;
+  for (int i = 0; i < kBeforeCount; ++i) {
+    actions.push_back([&]() {
+      absl::MutexLock lock{&mu_finish_counter};
+      ++finish_counter;
+      absl::SleepFor(absl::Milliseconds(10));
+      return absl::OkStatus();
+    });
+  }
+
+  actions.push_back(
+      []() { return absl::UnimplementedError("force a failure"); });
+
+  for (int i = 0; i < kAfterCount; ++i) {
+    actions.push_back([&]() {
+      absl::MutexLock lock{&mu_finish_counter};
+      ++finish_counter;
+      absl::SleepFor(absl::Milliseconds(10));
+      return absl::OkStatus();
+    });
+  }
+
+  // ::testing::StatusIs not available in oss. // copybara:strip
+  EXPECT_EQ(task_executor.ExecuteIndependentTasks(std::move(actions)).code(),
+            absl::StatusCode::kUnimplemented);
+}
+
+}  // namespace
+}  // namespace xla::concurrency
diff --git a/third_party/xla/xla/hlo/utils/concurrency/type_adapters.h b/third_party/xla/xla/hlo/utils/concurrency/type_adapters.h
new file mode 100644
index 00000000000000..e5af118e709ab5
--- /dev/null
+++ b/third_party/xla/xla/hlo/utils/concurrency/type_adapters.h
@@ -0,0 +1,66 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_HLO_UTILS_CONCURRENCY_TYPE_ADAPTERS_H_
+#define XLA_HLO_UTILS_CONCURRENCY_TYPE_ADAPTERS_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/functional/any_invocable.h"
+
+namespace xla::concurrency {
+
+// Turns a move-only & call-once function into a copyable one by caching its
+// result.
+//
+// Basically an `absl::AnyInvocable<R() &&>` -> `std::function<R()>` adapter.
+template <typename R>
+class TurnMoveOnlyToCopyableWithCaching {
+ public:
+  using InnerFunT = absl::AnyInvocable<R() &&>;
+  explicit TurnMoveOnlyToCopyableWithCaching(InnerFunT inner_fun)
+      : fun_{std::make_shared<InnerFunT>(std::move(inner_fun))} {}
+
+  // Wraps each element of a vector of move-only functions to make them
+  // copyable.
+  static std::vector<TurnMoveOnlyToCopyableWithCaching<R>> FromVector(
+      std::vector<InnerFunT> funs) {
+    std::vector<TurnMoveOnlyToCopyableWithCaching<R>> res;
+    res.reserve(funs.size());
+    for (auto& f : funs) {
+      res.emplace_back(std::move(f));
+    }
+    return res;
+  }
+
+  // Make it callable. The wrapped function runs at most once; subsequent
+  // calls return the cached result.
+  R operator()() {
+    if (res_ == nullptr) {
+      res_ = std::make_shared<R>(std::move(*fun_)());
+    }
+    return *res_;
+  }
+
+ private:
+  std::shared_ptr<InnerFunT> fun_ = nullptr;
+  std::shared_ptr<R> res_ = nullptr;
+};
+
+// CTAD guide: deduce `R` from the wrapped invocable's return type.
+template <typename R>
+TurnMoveOnlyToCopyableWithCaching(absl::AnyInvocable<R() &&>)
+    -> TurnMoveOnlyToCopyableWithCaching<R>;
+
+}  // namespace xla::concurrency
+#endif  // XLA_HLO_UTILS_CONCURRENCY_TYPE_ADAPTERS_H_
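One subtlety of the adapter above: copies share the wrapped callable through `fun_`, but each copy holds its own `res_` pointer, so the cached result is shared only by copies made after the first invocation; the intended contract is still to invoke any given wrapped function once. A small sketch of that call-once usage follows (illustrative only, not part of this patch).

#include <utility>

#include "absl/functional/any_invocable.h"
#include "xla/hlo/utils/concurrency/type_adapters.h"

// Sketch only: wrap a move-only, call-once callable so it can live in
// copy-requiring contexts such as std::function or thread-pool closures.
int CachingDemo() {
  absl::AnyInvocable<int() &&> once = []() { return 7; };
  xla::concurrency::TurnMoveOnlyToCopyableWithCaching wrapped(std::move(once));
  int a = wrapped();    // consumes the inner callable and caches 7
  auto copy = wrapped;  // cheap: copies two shared_ptrs, shares the cache
  int b = copy();       // returns the cached 7; nothing is re-invoked
  return a + b;         // 14
}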
diff --git a/third_party/xla/xla/hlo/utils/concurrency/type_adapters_test.cc b/third_party/xla/xla/hlo/utils/concurrency/type_adapters_test.cc
new file mode 100644
index 00000000000000..cb89a8c2771a3e
--- /dev/null
+++ b/third_party/xla/xla/hlo/utils/concurrency/type_adapters_test.cc
@@ -0,0 +1,73 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/hlo/utils/concurrency/type_adapters.h"
+
+#include <functional>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/functional/any_invocable.h"
+
+namespace xla::concurrency {
+namespace {
+using ::testing::ElementsAreArray;
+
+int call_fun(std::function<int()> f) { return f(); }
+
+TEST(TurnMoveOnlyToCopyableWithCachingTest, CanCopyAssign) {
+  const int kVal = -42;
+  absl::AnyInvocable<int() &&> my_fun = []() { return kVal; };
+
+  auto copyable_my_fun =
+      TurnMoveOnlyToCopyableWithCaching(std::move(my_fun));
+  EXPECT_EQ(copyable_my_fun(), kVal);
+
+  auto my_fun_copy = copyable_my_fun;
+  EXPECT_EQ(my_fun_copy(), kVal);
+}
+
+TEST(TurnMoveOnlyToCopyableWithCachingTest, CanCaptureCopyable) {
+  const int kVal = -42;
+  absl::AnyInvocable<int() &&> my_fun = []() { return kVal; };
+
+  EXPECT_EQ(call_fun([f = TurnMoveOnlyToCopyableWithCaching(
+                          std::move(my_fun))]() mutable { return f(); }),
+            kVal);
+}
+
+TEST(TurnMoveOnlyToCopyableWithCachingTest, VectorWrappingWrapsEachElement) {
+  const int kVal0 = 42;
+  const int kVal1 = 77;
+
+  std::vector<absl::AnyInvocable<int() &&>> funs;
+  funs.push_back([]() { return kVal0; });
+  funs.push_back([]() { return kVal1; });
+
+  std::vector<int> call0;
+  std::vector<int> call1;
+  for (auto& f :
+       TurnMoveOnlyToCopyableWithCaching<int>::FromVector(std::move(funs))) {
+    call0.push_back(f());
+    call1.push_back(f());
+  }
+
+  EXPECT_THAT(call0, ElementsAreArray({kVal0, kVal1}));
+  EXPECT_THAT(call1, ElementsAreArray({kVal0, kVal1}));
+}
+}  // namespace
+}  // namespace xla::concurrency

From 7f80914f175381d8997847045effc676149fdb80 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" 
Date: Wed, 2 Apr 2025 05:53:50 -0700
Subject: [PATCH 0146/1324] Reverts dde3a6fae85aa3b2c8265a263d324b27b6303c33

PiperOrigin-RevId: 743100614
---
 third_party/xla/xla/service/copy_insertion.cc | 32 -------------------
 .../xla/xla/service/copy_insertion_test.cc    |  6 ++--
 .../cpu/small_while_loop_hoisting_pass.cc     |  1 -
 .../xla/xla/service/gpu/gpu_compiler_test.cc  |  2 +-
 4 files changed, 3 insertions(+), 38 deletions(-)

diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc
index 655d880e540048..8db7718be597ab 100644
--- a/third_party/xla/xla/service/copy_insertion.cc
+++ b/third_party/xla/xla/service/copy_insertion.cc
@@ -1592,24 +1592,6 @@ class CopyRemover {
       VLOG(2) << "Region-based interference is false.";
       return false;
     };
-    auto AddControlDependenciesBetween = [&](ValueNode* src, ValueNode* dst) {
-      if (src == nullptr || dst == nullptr) {
-        return;
-      }
-      for (auto use : src->uses) {
-        if (use->instruction->parent() != dst->value->instruction()->parent() ||
-            use->instruction == dst->value->instruction()) {
-          // Don't add control dependencies if the use is in a different
-          // computation or if the use is the same as the destination.
-          continue;
-        }
-        VLOG(2) << "Adding control dependency:";
-        VLOG(2) << "  From: " << use->instruction->ToString();
-        VLOG(2) << "  To: " << dst->value->instruction()->ToString();
-        CHECK_OK(use->instruction->AddControlDependencyTo(
-            dst->value->instruction()));
-      }
-    };
 
     // A kCopy instruction copies an HLO value from a source buffer and
     // defines an HLO value in a destination buffer. Most generally, the
@@ -1692,13 +1674,6 @@ class CopyRemover {
                               kMergeFirstDestInSource)) {
         return false;
       }
-      // Ensure that the last uses of the copy source (e.g. s_x) are
-      // ordered before the next definition of the copy destination buffer
-      // (d_1).
- AddControlDependenciesBetween(copy_node.src, Next(*copy_node.dest)); - // Also ensure that the last uses of the copy destination (e.g. d_m) are - // ordered before the next definition of the copy source buffer (s_{x+1}). - AddControlDependenciesBetween(copy_node.dest->prev, Next(*copy_node.src)); VLOG(2) << "Splice dest after source."; // Splice in destination buffer values list right after 'src'. SpliceAfter(copy_node.dest, copy_node.src); @@ -1732,13 +1707,6 @@ class CopyRemover { VLOG(2) << "Region-based analysis concludes interference."; return false; } - // Ensure that the last uses of the copy source (e.g. s_n) are - // ordered before the next definition of the copy destination buffer - // (d_{y+1}). - AddControlDependenciesBetween(Prev(*copy_node.dest), copy_node.src->next); - // Also ensure that the last uses of the copy source (e.g. s_n) are - // ordered before next definition of the copy destination (e.g. d_{y+1}). - AddControlDependenciesBetween(copy_node.src, Next(*copy_node.dest)); VLOG(2) << "Splice src after prev of dest."; // Splice source buffer values list right after 'prev_dest'. SpliceAfter(copy_node.src->next, Prev(*copy_node.dest)); diff --git a/third_party/xla/xla/service/copy_insertion_test.cc b/third_party/xla/xla/service/copy_insertion_test.cc index 9071f4250386a6..f26650863fc622 100644 --- a/third_party/xla/xla/service/copy_insertion_test.cc +++ b/third_party/xla/xla/service/copy_insertion_test.cc @@ -953,8 +953,7 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) { InsertCopies(module_.get()); EXPECT_EQ(CountCopies(*body), 1); - // Control edges exist for elided copies. - EXPECT_EQ(CountControlEdges(*body), 1); + EXPECT_EQ(CountControlEdges(*body), 0); EXPECT_THAT( body->root_instruction(), @@ -3528,8 +3527,7 @@ TEST_F(CopyInsertionTest, AddControlDependencyForInputOutputAlias) { /*use_region_based_live_range_analysis=*/-1); ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); EXPECT_EQ(CountCopies(*module), 1); - // Include control edges from elided copies. 
- EXPECT_EQ(CountControlEdges(*module), 3); + EXPECT_EQ(CountControlEdges(*module), 2); HloInstruction* add_instr = FindInstruction(module.get(), HloOpcode::kAdd); HloInstruction* mul_instr = diff --git a/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc b/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc index c599b5f3ff6425..767e914a075992 100644 --- a/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc +++ b/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc @@ -127,7 +127,6 @@ absl::StatusOr SmallWhileLoopHoistingPass::Run( call_instruction->add_frontend_attribute("xla_cpu_small_call", "true"); TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(call_instruction)); - TF_RETURN_IF_ERROR(while_instr->SafelyDropAllControlDependencies()); TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr)); changed = true; diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 7374c6d522fa63..06af1b5b077283 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -830,7 +830,7 @@ CHECK: %[[RESULT_RECV:.*]] = recv(%[[AFTER_ALL]]) CHECK-SAME: channel_id=[[CHANNEL_ID]] CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}}, -CHECK-SAME: control-predecessors={%[[CUSTOM_CALL:.*]]} +CHECK-SAME: control-predecessors={%[[CUSTOM_CALL]]} CHECK: %[[RESULT_SEND:.*]] = send(%[[SOME_SEND_ARG:.*]], %[[AFTER_ALL]]) CHECK-SAME: channel_id=1 CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", From 2a1c0b6609cc3c9438890f1947ae6984645277de Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Wed, 2 Apr 2025 06:30:34 -0700 Subject: [PATCH 0147/1324] [XLA:GPU/TMA] Replacing Triton specific instructions for loads/stores with TritonXLA ops. Currently this is only enabling standard loads/stores. TMA equivalents will be considered later. 
PiperOrigin-RevId: 743111603 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 42 +- .../gpu/codegen/triton/fusion_emitter.cc | 366 +++++++++--------- .../gpu/codegen/triton/fusion_emitter.h | 13 +- .../triton/fusion_emitter_device_test.cc | 4 +- .../triton/fusion_emitter_legacy_matmul.cc | 5 +- .../triton/fusion_emitter_legacy_matmul.h | 5 +- .../fusion_emitter_legacy_matmul_stub.cc | 4 +- .../triton/fusion_emitter_mem_utils_test.cc | 256 ------------ .../gpu/codegen/triton/fusion_emitter_stub.cc | 3 +- .../triton/fusion_emitter_stub_test.cc | 4 +- .../gpu/codegen/triton/ir/tests/invalid.mlir | 9 - .../gpu/codegen/triton/ir/triton_xla_ops.cc | 3 - 12 files changed, 215 insertions(+), 499 deletions(-) delete mode 100644 third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_mem_utils_test.cc diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 71accfea7ae094..ff982d7158f88a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -242,7 +242,9 @@ cc_library( "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:IndexToLLVM", "@llvm-project//mlir:LLVMDialect", @@ -255,6 +257,7 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:errors", @@ -331,6 +334,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:SCFDialect", @@ -387,10 +391,12 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -413,6 +419,7 @@ cc_library( ], deps = [ "//xla:autotuning_proto_cc", + "//xla/backends/gpu/codegen/triton/ir:triton_xla", "//xla/codegen:emitter_loc_op_builder", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", @@ -430,6 +437,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@triton//:TritonDialects", @@ -443,6 +451,7 @@ xla_cc_test( ":fusion_emitter_stub_for_testing", "//xla:literal", "//xla:literal_util", + "//xla/backends/gpu/codegen/triton/ir:triton_xla", "//xla/codegen:emitter_loc_op_builder", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", @@ -748,39 +757,6 @@ cc_library( ], ) -xla_cc_test( - name = "fusion_emitter_mem_utils_test", - srcs = if_cuda_is_configured(["fusion_emitter_mem_utils_test.cc"]), - fail_if_no_test_linked = False, # NOLINT=The test is empty on non-CUDA platforms. 
- deps = [ - ":fusion_emitter", - "//xla/codegen:emitter_loc_op_builder", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:verified_hlo_module", - "//xla/hlo/utils:hlo_traversal", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu/model:symbolic_tile_analysis", - "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", - "//xla/service/gpu/model:triton_emitter_constraints", - "//xla/service/llvm_ir:llvm_util", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:logging", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:NVVMDialect", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:logging", - "@triton//:TritonDialects", - ], -) - xla_test( name = "fusion_emitter_large_test", size = "large", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 4b9271dfc49809..dbfa75fb4e2156 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include #include // NOLINT(build/c++11): required to interface with LLVM @@ -41,7 +40,6 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" -#include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -51,11 +49,13 @@ limitations under the License. #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" @@ -72,6 +72,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" @@ -103,6 +104,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/permutation_util.h" @@ -125,12 +127,10 @@ limitations under the License. 
#include "xla/stream_executor/gpu/tma_metadata.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tools/hlo_decomposer.h" -#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/path.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -143,6 +143,7 @@ namespace gpu { namespace arith = ::mlir::arith; namespace ttir = ::mlir::triton; +namespace mtx = ::mlir::triton::xla; using ::llvm::SmallVector; using ::mlir::ArrayRef; @@ -175,30 +176,32 @@ ScalarOrTensor Range(EmitterLocOpBuilder& b, int32_t limit) { return ScalarOrTensor(b.create(type, 0, limit)); } -Value AddPtr(EmitterLocOpBuilder& b, Value ptr, Value offset) { - return b.create(ptr.getType(), ptr, offset); -} - -ScalarOrTensor EmitParameterLoad(EmitterLocOpBuilder& b, Value pointer, - ArrayRef boundary_checks) { +ScalarOrTensor EmitParameterExtract(EmitterLocOpBuilder& b, + mlir::triton::xla::TileOp tile_op) { // For a pointer to a scalar or a zero-dimensional tensor, load the base // pointer directly. This shortcut is necessary because Triton does not // support 0-D tensors. Looking for the defining make_tensor_ptr op is // sufficient because pointers to 0-D tensors are never modified by e.g. // `tt.advance`. - if (auto make_tensor_ptr = pointer.getDefiningOp(); - make_tensor_ptr && make_tensor_ptr.getShape().empty()) { - pointer = make_tensor_ptr.getBase(); - } - std::optional padding; - if (!boundary_checks.empty()) { - padding = ttir::PaddingOption::PAD_ZERO; + auto tiled_tensor_type = + mlir::dyn_cast(tile_op.getResult().getType()); + CHECK(tiled_tensor_type) << "Expected a TiledTensorType\n"; + + if (tiled_tensor_type.getTileShape().empty()) { + return ScalarOrTensor( + b.create(tile_op.getTensor(), {})); } - bool is_volatile = false; - return ScalarOrTensor(b.create( - pointer, boundary_checks, padding, ttir::CacheModifier::NONE, - ttir::EvictionPolicy::NORMAL, is_volatile)); + + SmallVector offsets( + tile_op.getTiledTensor().getType().getTileShape().size(), + CreateConst(b, b.getIndexType(), 0).UnwrapScalar()); + + return ScalarOrTensor(b.create( + mlir::RankedTensorType::get( + tiled_tensor_type.getTileShape(), + StorageType(b, tiled_tensor_type.getElementType())), + tile_op.getResult(), offsets)); } absl::StatusOr EmitScope( @@ -353,10 +356,9 @@ absl::StatusOr EmitNestedFusion( ScalarOrTensor EmitTiledBroadcast( EmitterLocOpBuilder& b, const TiledHloInstruction& tiled_broadcast, absl::flat_hash_map& values) { - const llvm::SmallVector& input_tile_shape = + const SmallVector& input_tile_shape = tiled_broadcast.operand(0)->tile_sizes(); - const llvm::SmallVector& output_tile_shape = - tiled_broadcast.tile_sizes(); + const SmallVector& output_tile_shape = tiled_broadcast.tile_sizes(); if (input_tile_shape.empty() && output_tile_shape.empty()) { return values[tiled_broadcast.operand(0)]; @@ -605,7 +607,7 @@ absl::StatusOr> EmitTiledComputation( EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, - const TiledHloComputation& tiled_computation, mlir::triton::FuncOp fn, + const TiledHloComputation& tiled_computation, mlir::FunctionOpInterface fn, ValueRange tile_multi_index); bool UseGenericTritonEmitterForGemms(const HloInstruction* hlo) { @@ -663,7 +665,7 @@ absl::StatusOr 
MaskDotOperand(EmitterLocOpBuilder& b, llvm::ArrayRef tile_shape = mlir::cast(dot_operand_value.getType()).getShape(); - int64_t rank = dot_operand.hlo()->shape().dimensions_size(); + int64_t rank = dot_operand.hlo()->shape().dimensions().size(); int64_t contracting_dimension_size = dot_operand.hlo()->shape().dimensions(contraction_dimension_index); int64_t tile_size = tile_shape[contraction_dimension_index]; @@ -714,12 +716,25 @@ absl::StatusOr MaskDotOperand(EmitterLocOpBuilder& b, return dot_operand_value; } +// Computes the base pointer offset for the given tile multi-index and hlo shape +// taking into account the physical layout of the hlo buffer. +absl::StatusOr> ComputeBasePtrOffset( + EmitterLocOpBuilder& b, ValueRange tile_multi_index, + const TiledHloInstruction& tiled_hlo) { + TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, + tiled_hlo.tile_offsets_indexing()); + + return emitters::ApplyIndexing(tile_offsets_indexing, + /*dims=*/tile_multi_index, + /*symbols=*/{}, b); +} + absl::StatusOr EmitDot(EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, const TiledHloInstruction& tiled_hlo_dot, - mlir::triton::FuncOp fn, + mlir::FunctionOpInterface fn, ValueRange tile_multi_index) { // We expect to get a tiled HLO in form: // @@ -783,7 +798,7 @@ absl::StatusOr EmitDot(EmitterLocOpBuilder& b, // (tile multi-index.., loop index) -> .... SmallVector computation_index(tile_multi_index); Value ki = for_op.getInductionVar(); - const Value ki_index = b.create(b.getIndexType(), ki); + const Value ki_index = b.create(b.getIndexType(), ki); computation_index.push_back(ki_index); for (const TiledHloInstruction* operand : tiled_hlo_dot.operands()) { VLOG(3) << "Emitting dot operand: " << operand->ToString(); @@ -850,7 +865,7 @@ absl::StatusOr EmitConcatenate( EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, - const TiledHloInstruction& tiled_concatenate, mlir::triton::FuncOp fn, + const TiledHloInstruction& tiled_concatenate, mlir::FunctionOpInterface fn, ValueRange tile_multi_index) { const int64_t concatenate_dimension = tiled_concatenate.hlo()->concatenate_dimension(); @@ -984,7 +999,7 @@ absl::StatusOr EmitTiledHloInstruction( EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, const TiledHloInstruction& tiled_hlo, - mlir::triton::FuncOp fn, ValueRange tile_multi_index, + mlir::FunctionOpInterface fn, ValueRange tile_multi_index, absl::flat_hash_map& values) { const HloInstruction* hlo = tiled_hlo.hlo(); VLOG(4) << "EmitTiledHloInstruction: " << hlo->ToString(); @@ -993,13 +1008,11 @@ absl::StatusOr EmitTiledHloInstruction( // If the fusion instruction is a user of `hlo`, then `hlo` is an operand // to the fusion instruction. int64_t arg_index = GetOutermostFusionOperandParameterIndex(fusion, hlo); - TF_ASSIGN_OR_RETURN( - auto make_tensor, - ir_emitter_triton_internal::CreateMakeTensorPtrOp( - b, tile_multi_index, tiled_hlo, fn.getArgument(arg_index))); + TF_ASSIGN_OR_RETURN(auto tile_op, ir_emitter_triton_internal::CreateTileOp( + b, tile_multi_index, tiled_hlo, + fn.getArgument(arg_index))); - ScalarOrTensor parameter = - EmitParameterLoad(b, make_tensor.op, make_tensor.boundary_checks); + ScalarOrTensor parameter = EmitParameterExtract(b, tile_op); // Some types are stored using different types, e.g. 
i1 is stored in memory // as i8. It's important to type checking that we perform a conversion after @@ -1135,7 +1148,7 @@ absl::StatusOr> EmitTiledComputation( EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, - const TiledHloComputation& tiled_computation, mlir::triton::FuncOp fn, + const TiledHloComputation& tiled_computation, mlir::FunctionOpInterface fn, ValueRange tile_multi_index) { VLOG(2) << "EmitTiledComputation: " << tiled_computation.ToString(); absl::flat_hash_map values; @@ -1226,33 +1239,6 @@ absl::StatusOr EmitScope( } return values[instructions.back()]; } - -// Computes the base pointer offset for the given tile multi-index and hlo shape -// taking into account the physical layout of the hlo buffer. -absl::StatusOr ComputeBasePtrOffset( - EmitterLocOpBuilder& b, ValueRange tile_multi_index, - const TiledHloInstruction& tiled_hlo) { - const Shape& shape = tiled_hlo.hlo()->shape(); - Shape linear_shape = ShapeUtil::MakeShape(shape.element_type(), - {ShapeUtil::ElementsIn(shape)}); - - // Bitcast map gives an indexing map from the parameter shape (multi-index) to - // a linear index respecting physical layout of the memory. - auto bitcast_map = GetBitcastMap(shape, linear_shape, b.getContext()); - - TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, - tiled_hlo.tile_offsets_indexing()); - - auto compose_indexing_maps = - ComposeIndexingMaps(tile_offsets_indexing, bitcast_map); - compose_indexing_maps.Simplify(); - - return b.create( - b.getI64Type(), emitters::ApplyIndexing(compose_indexing_maps, - /*dims=*/tile_multi_index, - /*symbols=*/{}, b)[0]); -} - } // namespace namespace ir_emitter_triton_internal { @@ -1284,91 +1270,40 @@ SmallVector ComputeDelinearizedTileIndex( /*symbols=*/{}, b); } -absl::StatusOr CreateMakeTensorPtrOp( - EmitterLocOpBuilder& b, ValueRange tile_multi_index, - const TiledHloInstruction& tiled_hlo, Value parent_base_ptr) { - const Shape& shape = tiled_hlo.hlo()->shape(); - - // Compute physical strides of the tile. `tile_strides` contains strides for - // individual dimensions. We need to convert them to strides in the buffer - // taking into account physical layout. - // TODO(b/331332678): Compute indexing maps to physical layout indexing in - // SymbolicTileAnalysis. - const llvm::SmallVector& tile_strides = tiled_hlo.tile_strides(); - llvm::SmallVector strides(tile_strides.size()); - int64_t current_stride = 1; - for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) { - strides[cur_dim] = - CreateConst(b, b.getI64Type(), tile_strides[cur_dim] * current_stride) - .UnwrapScalar(); - current_stride *= shape.dimensions(cur_dim); +SmallVector CreateIndexValues(EmitterLocOpBuilder& builder, + const ArrayRef& values) { + SmallVector result; + result.reserve(values.size()); + for (int64_t value : values) { + result.push_back( + CreateConst(builder, builder.getIndexType(), value).UnwrapScalar()); } + return result; +} - TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, - tiled_hlo.tile_offsets_indexing()); - auto tile_offsets_as_indices = - emitters::ApplyIndexing(tile_offsets_indexing, - /*dims=*/tile_multi_index, - /*symbols=*/{}, b); +absl::StatusOr CreateTileOp( + EmitterLocOpBuilder& b, ValueRange tile_multi_index, + const TiledHloInstruction& tiled_hlo, Value parent_base_ptr) { + TF_ASSIGN_OR_RETURN(SmallVector ptr_offsets, + ComputeBasePtrOffset(b, tile_multi_index, tiled_hlo)); // Triton requires that all block dimensions are a power of 2. 
SmallVector padded_tile_sizes = GetPaddedTileSizes(tiled_hlo.tile_sizes()); - // TensorPtr is intended to be a base pointer of the TiledHloInstruction and - // plus the necessary offsets so that Triton can compute the pointer to the - // block specific to the given pid. This option would yield simpler code, but - // cannot handle all combinations of strides and offsets, because Triton - // always multiplies the offset by the stride. E.g., it's not possible to - // slice [10] with [1:5:2] because the first element will always be at an even - // offset. - // - // Instead, we output a TensorPtr that points directly to the tile specific - // to the pid. All offset computation is done in advance. MakeTensorPtrOp - // sees 0 offsets. This allows Triton to read any block regardless of strides - // size or offsets. To make sure that masking is correct, we compute a - // "residual shape" which is the original parent shape minus the offsets. - - llvm::SmallVector residual_shape; - llvm::SmallVector boundary_checks; - for (int dim_idx = 0; dim_idx < padded_tile_sizes.size(); ++dim_idx) { - Value parent_size = - CreateConst(b, b.getI64Type(), shape.dimensions(dim_idx)) - .UnwrapScalar(); - // Offsets are necessarily positive since they represent a distance between - // 0 and the size of the tensor on the given axis. Therefore, it is safe to - // use 'IndexCastUI' here. This allows index canonicalizations later on. - Value offset = b.create( - b.getI64Type(), tile_offsets_as_indices[dim_idx]); - residual_shape.push_back(b.create(parent_size, offset)); - - if (shape.dimensions(dim_idx) % padded_tile_sizes[dim_idx] != 0) { - boundary_checks.push_back(dim_idx); - } - } - - TF_ASSIGN_OR_RETURN(Value ptr_offset, - ComputeBasePtrOffset(b, tile_multi_index, tiled_hlo)); - auto tile_ptr = AddPtr(b, parent_base_ptr, ptr_offset); - - llvm::SmallVector offsets( - padded_tile_sizes.size(), - CreateConst(b, b.getI32Type(), 0).UnwrapScalar()); - - // TODO(b/342989850): Clarify and comment what `order` exactly is. It's not - // entirely clear from the Triton docs. - llvm::SmallVector order(padded_tile_sizes.size()); - std::iota(order.rbegin(), order.rend(), 0); - - auto make_tensor_ptr = b.create( - /*base*/ tile_ptr, - /*shape*/ residual_shape, - /*strides*/ strides, - /*offsets*/ offsets, - /*tensorShape*/ llvm::to_vector_of(padded_tile_sizes), - /*order*/ order); - - return MakeTensorPtrOpAndBoundaryChecks{make_tensor_ptr, boundary_checks}; + const Shape& shape = tiled_hlo.hlo()->shape(); + TF_ASSIGN_OR_RETURN(Type expected_element_type, + TritonType(b, shape.element_type())); + auto result_type = mtx::TiledTensorType::get( + b.getContext(), padded_tile_sizes, + llvm::ArrayRef(shape.dimensions().data(), + shape.dimensions().size()), + expected_element_type); + + return b.create( + result_type, parent_base_ptr, ptr_offsets, + CreateIndexValues(b, tiled_hlo.tile_strides()), + llvm::to_vector(LayoutUtil::MinorToMajor(shape))); } } // namespace ir_emitter_triton_internal @@ -1378,12 +1313,11 @@ namespace { using ::xla::gpu::ir_emitter_triton_internal::DumpTritonIR; // Generate Triton IR inside 'fn', using the given block_level_parameters. 
-absl::Status EmitGeneric(mlir::OpBuilder builder, - absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, - mlir::triton::FuncOp fn, - const BlockLevelParameters& block_level_parameters) { +absl::StatusOr> EmitGeneric( + mlir::OpBuilder builder, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloFusionInstruction* fusion, mlir::FunctionOpInterface fn, + const BlockLevelParameters& block_level_parameters) { const HloComputation* computation = fusion->fused_instructions_computation(); SymbolicTileAnalysisOrError symbolic_tile_analysis_or = SymbolicTileAnalysis::AnalyzeComputation( @@ -1429,6 +1363,7 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, EmitTiledComputation(b, libdevice_path, device_info, fusion, tiled_hlo_computation, fn, tile_multi_index)); + SmallVector insert_results; for (auto [root, result, parent_base_ptr] : llvm::zip(tiled_hlo_computation.GetRoots(), results, fn.getArguments().drop_front(computation->num_parameters()))) { @@ -1444,23 +1379,39 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, } if (result.IsScalar()) { - b.create(parent_base_ptr, result.UnwrapScalar(), - ttir::CacheModifier::NONE, - ttir::EvictionPolicy::NORMAL); + ValueRange indices = {}; + insert_results.push_back( + b.create(result.UnwrapScalar(), + parent_base_ptr, indices) + .getResult()); continue; } CHECK(root->hlo()->shape().IsArray() && - root->hlo()->shape().dimensions_size() > 0); - TF_ASSIGN_OR_RETURN(auto make_tensor, - ir_emitter_triton_internal::CreateMakeTensorPtrOp( + !root->hlo()->shape().dimensions().empty()); + TF_ASSIGN_OR_RETURN(mlir::triton::xla::TileOp tile_op, + ir_emitter_triton_internal::CreateTileOp( b, tile_multi_index, *root, parent_base_ptr)); - b.create( - make_tensor.op, result.UnwrapTensor(), make_tensor.boundary_checks, - ttir::CacheModifier::NONE, ttir::EvictionPolicy::NORMAL); + + // Should not be scalar at this point. + auto tiled_tensor_type = + mlir::dyn_cast(tile_op.getResult().getType()); + CHECK(tiled_tensor_type) << "Expected a tiled tensor type since scalars " + "should've been handled at this point."; + + SmallVector offsets( + tiled_tensor_type.getTileShape().size(), + CreateConst(b, b.getIndexType(), 0).UnwrapScalar()); + + insert_results.push_back( + b.create( + mlir::RankedTensorType::get(tiled_tensor_type.getOriginalShape(), + result_storage_type), + result.UnwrapTensor(), tile_op.getResult(), offsets) + .getResult()); } - return absl::OkStatus(); + return insert_results; } } // namespace @@ -1470,7 +1421,8 @@ void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context) { .loadDialect(); + xla::gpu::XlaGpuDialect, ttir::xla::XlaTritonDialect, + mlir::func::FuncDialect, mlir::tensor::TensorDialect>(); mlir::DialectRegistry registry; mlir::func::registerInlinerExtension(registry); mlir::LLVM::registerInlinerInterface(registry); @@ -1515,6 +1467,62 @@ absl::Status CreateInternalError(absl::string_view message, return absl::InternalError(err); } +// Legacy emitter works with tt.func. New emitter works with func.func. 
+void AppendFuncArgType(EmitterLocOpBuilder& b, absl::Span dims, + absl::string_view fusion_kind, Type ir_type, + SmallVector& fn_arg_types) { + if (fusion_kind == kTritonGemmFusionKind) { + fn_arg_types.push_back(ttir::PointerType::get( + StorageType(b, ir_type), mlir::NVVM::kGlobalMemorySpace)); + } else { + fn_arg_types.push_back(mlir::RankedTensorType::get( + llvm::ArrayRef(dims.data(), dims.size()), + StorageType(b, ir_type))); + } +} + +// Only needed for the new emitter since we are using func.func instead of +// tt.func. +void AppendFuncResultType(EmitterLocOpBuilder& b, absl::string_view fusion_kind, + absl::Span dims, Type ir_type, + SmallVector& fn_result_types) { + if (fusion_kind != kTritonGemmFusionKind) { + fn_result_types.push_back(mlir::RankedTensorType::get( + llvm::ArrayRef(dims.data(), dims.size()), + StorageType(b, ir_type))); + } +} + +// Legacy emitter works with tt.func. New emitter works with func.func. +mlir::FunctionOpInterface CreateFuncOp(EmitterLocOpBuilder& b, + absl::string_view fn_name, + absl::string_view fusion_kind, + SmallVector& fn_arg_types, + SmallVector& fn_result_types) { + mlir::FunctionOpInterface fn; + if (fusion_kind == kTritonGemmFusionKind) { + fn = b.create(fn_name, + b.getFunctionType(fn_arg_types, std::nullopt)); + for (int i = 0; i < fn.getNumArguments(); ++i) { + fn.setArgAttr(i, "tt.divisibility", b.getIntegerAttr(b.getI32Type(), 16)); + } + } else { + fn = b.create( + fn_name, b.getFunctionType(fn_arg_types, fn_result_types)); + } + return fn; +} + +// Legacy emitter works with tt.return. New emitter works with func.return. +void EmitReturnOp(EmitterLocOpBuilder& b, absl::string_view fusion_kind, + SmallVector insert_results) { + if (fusion_kind == kTritonGemmFusionKind) { + b.create(); + } else { + b.create(insert_results); + } +} + absl::StatusOr CreateTritonModule( absl::string_view fn_name, const HloFusionInstruction* fusion, const se::DeviceDescription& device_info, @@ -1536,6 +1544,10 @@ absl::StatusOr CreateTritonModule( llvm_ir::CreateMlirModuleOp(loc); b.setInsertionPointToEnd(triton_module->getBody()); + auto backend_config = + fusion->backend_config()->fusion_backend_config(); + absl::string_view fusion_kind = backend_config.kind(); + // Build Triton kernel. 
SmallVector fn_arg_types; for (HloInstruction* p : hlo_computation->parameter_instructions()) { @@ -1548,22 +1560,24 @@ absl::StatusOr CreateTritonModule( } else { TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type)); } - fn_arg_types.push_back(ttir::PointerType::get( - StorageType(b, ir_type), mlir::NVVM::kGlobalMemorySpace)); + + AppendFuncArgType(b, p->shape().dimensions(), fusion_kind, ir_type, + fn_arg_types); } + SmallVector fn_result_types; + for (const ShapeUtil::IndexedShape& s : ShapeUtil::GetLeafShapes(fusion->shape())) { TF_ASSIGN_OR_RETURN(Type triton_ty, TritonType(b, s.shape.element_type())); - fn_arg_types.push_back(ttir::PointerType::get( - StorageType(b, triton_ty), mlir::NVVM::kGlobalMemorySpace)); + AppendFuncArgType(b, s.shape.dimensions(), fusion_kind, triton_ty, + fn_arg_types); + AppendFuncResultType(b, fusion_kind, s.shape.dimensions(), triton_ty, + fn_result_types); } - auto fn = b.create( - fn_name, b.getFunctionType(fn_arg_types, std::nullopt)); - for (int i = 0; i < fn.getNumArguments(); ++i) { - fn.setArgAttr(i, "tt.divisibility", b.getIntegerAttr(b.getI32Type(), 16)); - } + mlir::FunctionOpInterface fn = + CreateFuncOp(b, fn_name, fusion_kind, fn_arg_types, fn_result_types); fn.addEntryBlock(); b.setInsertionPointToStart(&fn.front()); @@ -1571,12 +1585,9 @@ absl::StatusOr CreateTritonModule( std::string libdevice_path = GetLibdevicePath(fusion->GetModule()->config(), device_info); - auto backend_config = - fusion->backend_config()->fusion_backend_config(); - absl::string_view fusion_kind = backend_config.kind(); - // It's okay for tma_metadata to be empty; it's only populated when used // explicitly. + SmallVector insert_results; std::optional tma_metadata = std::nullopt; if (fusion_kind == kTritonGemmFusionKind) { // If the generic Triton emitter is enabled, we should never go through the @@ -1587,13 +1598,14 @@ absl::StatusOr CreateTritonModule( block_level_parameters)); } else if (fusion_kind == kTritonFusionKind || fusion_kind == kTritonNestedGemmFusionKind) { - TF_RETURN_IF_ERROR(EmitGeneric(b, libdevice_path, device_info, fusion, fn, - block_level_parameters)); + TF_ASSIGN_OR_RETURN(insert_results, + EmitGeneric(b, libdevice_path, device_info, fusion, fn, + block_level_parameters)); } else { return Internal("Unsupported fusion kind: %s", fusion_kind); } - b.create(); + EmitReturnOp(b, fusion_kind, insert_results); if (DumpingEnabledForHloModule(*hlo_computation->parent())) { auto suffix = absl::StrCat(fusion->name(), ".before_validation.ttir"); @@ -1615,6 +1627,10 @@ absl::StatusOr CreateTritonModule( } mlir::PassManager pm(&mlir_context); + + // TODO(b/315957220): Pass device and tma_flag to the pass. + pm.addPass(mlir::triton::xla::CreateTritonXLAExtractInsertToTritonPass()); + pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); if (mlir::failed(pm.run(triton_module.get()))) { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.h b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.h index 6f98930a24adb4..3cbd9a5039c34d 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.h @@ -33,6 +33,7 @@ limitations under the License. 
#include "mlir/IR/ValueRange.h" #include "mlir/Pass/PassManager.h" #include "xla/autotuning.pb.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -42,7 +43,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/tma_metadata.h" #include "xla/stream_executor/launch_dim.h" -#include "triton/Dialect/Triton/IR/Dialect.h" namespace mlir { namespace triton { @@ -112,16 +112,7 @@ namespace ir_emitter_triton_internal { llvm::SmallVector ComputeDelinearizedTileIndex( EmitterLocOpBuilder& b, absl::Span num_output_tiles_per_dim); -// Used for creating Triton Load and Store ops. -struct MakeTensorPtrOpAndBoundaryChecks { - ::mlir::triton::MakeTensorPtrOp op; - - // Indices of dimensions where the original tile size is not a power of 2 and - // requires a boundary check. - llvm::SmallVector boundary_checks; -}; - -absl::StatusOr CreateMakeTensorPtrOp( +absl::StatusOr CreateTileOp( EmitterLocOpBuilder& b, mlir::ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo, mlir::Value parent_base_ptr); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index 8b59cc52b088a1..7163dbf516b782 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -468,7 +468,7 @@ ENTRY main { "num_stages":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 127), domain: pid_0 in [0, 124]"> +CHECK: #indexing_map = #xla.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 @@ -536,7 +536,7 @@ ENTRY main { "num_stages":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 127), domain: pid_0 in [0, 124]"> +CHECK: #indexing_map = #xla.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn( CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc index 904f49ee7acc9c..e473078c81737e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc @@ -51,6 +51,7 @@ limitations under the License. 
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Support/LLVM.h" #include "xla/backends/gpu/codegen/triton/dot_algorithms.h" #include "xla/backends/gpu/codegen/triton/emitter_helpers.h" @@ -1473,7 +1474,7 @@ class MatMulEmitterHelper { MatMulLaunchConfig launch_config_; }; -absl::StatusOr> GetArguments(mlir::triton::FuncOp fn, +absl::StatusOr> GetArguments(mlir::FunctionOpInterface fn, const HloInstruction& input) { if (input.opcode() == HloOpcode::kParameter) { return {{fn.getArgument(input.parameter_number())}}; @@ -1902,7 +1903,7 @@ absl::Status EmitForLoopBody(EmitterLocOpBuilder& b, absl::StatusOr> EmitMatMul( EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, mlir::triton::FuncOp fn, + const HloFusionInstruction* fusion, mlir::FunctionOpInterface fn, const BlockLevelParameters&) { // TODO b/315957220: Populate tma_metadata. stream_executor::gpu::TmaMetadata tma_metadata; diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h index 723cc2df2d73ff..99d7e10fdf9d9a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h @@ -18,9 +18,9 @@ limitations under the License. #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/utils/hlo_traversal.h" @@ -30,7 +30,6 @@ limitations under the License. #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/tma_metadata.h" -#include "triton/Dialect/Triton/IR/Dialect.h" namespace xla::gpu { @@ -45,7 +44,7 @@ absl::StatusOr GetMatMulLaunchDimensions( absl::StatusOr> EmitMatMul( EmitterLocOpBuilder& builder, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, mlir::triton::FuncOp fn, + const HloFusionInstruction* fusion, mlir::FunctionOpInterface fn, const BlockLevelParameters&); } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul_stub.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul_stub.cc index e8e84fb5abd44f..eb845dab9830da 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul_stub.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul_stub.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/tma_metadata.h" -#include "triton/Dialect/Triton/IR/Dialect.h" namespace xla::gpu { @@ -42,7 +42,7 @@ absl::StatusOr GetMatMulLaunchDimensions( absl::StatusOr> EmitMatMul( EmitterLocOpBuilder& builder, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, mlir::triton::FuncOp fn, + const HloFusionInstruction* fusion, mlir::FunctionOpInterface fn, const BlockLevelParameters&) { return absl::UnimplementedError("not supported for this build configuration"); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_mem_utils_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_mem_utils_test.cc deleted file mode 100644 index 0b8bef542b48d6..00000000000000 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_mem_utils_test.cc +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/log/check.h" -#include "absl/strings/str_join.h" -#include "absl/strings/substitute.h" -#include "absl/types/span.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OwningOpRef.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Support/LLVM.h" -#include "xla/backends/gpu/codegen/triton/fusion_emitter.h" -#include "xla/codegen/emitter_loc_op_builder.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/hlo/utils/hlo_traversal.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/model/symbolic_tile_analysis.h" -#include "xla/service/gpu/model/tiled_hlo_computation.h" -#include "xla/service/gpu/model/tiled_hlo_instruction.h" -#include "xla/service/gpu/model/triton_emitter_constraints.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/logging.h" -#include "tsl/platform/logging.h" // IWYU pragma: keep -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" - -namespace xla::gpu::ir_emitter_triton_internal { -namespace { - -using ::llvm::SmallVector; -using ::mlir::MLIRContext; -using ::mlir::OpBuilder; -using ::mlir::Type; -using ::mlir::Value; -using ::testing::ElementsAre; - -class TritonMakeTensorPtrTest : public HloTestBase { - public: - void SetUp() override 
{ LoadMlirDialectsForTriton(mlir_context_); } - - std::pair, TiledHloComputation> - CreateAndTileHloComputation(std::vector shape_sizes, - const std::vector& tile_sizes, - const std::vector& tile_strides); - - std::pair, MakeTensorPtrOpAndBoundaryChecks> - CreateTestTensorPtr(const std::vector& parent_shape, - const std::vector& tile_sizes, - const std::vector& tile_strides); - - protected: - MLIRContext mlir_context_; -}; - -// Returns a HloModule and a corresponding TiledHloComputation using -// `shape_sizes` to replace the placeholders in the hardcoded hlo text. -// `tile_sizes` and `tile_strides` are used to tile the hlo computation. -std::pair, TiledHloComputation> -TritonMakeTensorPtrTest::CreateAndTileHloComputation( - std::vector shape_sizes, const std::vector& tile_sizes, - const std::vector& tile_strides) { - const std::string hlo_text = R"( - HloModule test_module - - fusion { - p0 = f32[$0] parameter(0) - ROOT log = f32[$0] log(p0) - } - - ENTRY %main{ - p0.1 = f32[$0] parameter(0) - ROOT fusion = f32[$0] fusion(p0.1), kind=kLoop, calls=%fusion - })"; - auto verified_hlo_module_or = ParseAndReturnVerifiedModule( - absl::Substitute(hlo_text, absl::StrJoin(shape_sizes, ","))); - CHECK_OK(verified_hlo_module_or); - - std::unique_ptr verified_hlo_module = - std::move(verified_hlo_module_or).value(); - - auto fusion_adaptor = HloFusionAdaptor::ForInstruction( - verified_hlo_module->entry_computation()->root_instruction()); - - SymbolicTileAnalysisOrError symbolic_tile_analysis_or = - SymbolicTileAnalysis::AnalyzeFusion( - *fusion_adaptor, &mlir_context_, - TritonEmitterConstraints::GetBuilder( - TestGpuDeviceInfo::RTXA6000DeviceInfo())); - CHECK( - std::holds_alternative(symbolic_tile_analysis_or)); - - SymbolicTileAnalysis symbolic_tile_analysis = - std::get(std::move(symbolic_tile_analysis_or)); - - auto tiled_hlo_computation_or = - symbolic_tile_analysis.ComputeTiledHloInstructions( - tile_sizes, /*constraints_are_known_satisfied=*/true, - /*compute_all_tile_offset_indexing_maps=*/true); - TF_EXPECT_OK(tiled_hlo_computation_or.status()); - return std::make_pair(std::move(verified_hlo_module), - *std::move(tiled_hlo_computation_or)); -} - -mlir::triton::FuncOp CreateTritonFunction( - EmitterLocOpBuilder& b, const std::vector shape_sizes) { - auto fn = b.create<::mlir::triton::FuncOp>( - "func", - b.getFunctionType({::mlir::triton::PointerType::get( - b.getF32Type(), mlir::NVVM::kGlobalMemorySpace)}, - std::nullopt)); - for (int i = 0; i < fn.getNumArguments(); ++i) { - fn.setArgAttr(i, "tt.divisibility", b.getIntegerAttr(b.getI32Type(), 16)); - } - b.setInsertionPointToStart(fn.addEntryBlock()); - return fn; -} - -std::pair, MakeTensorPtrOpAndBoundaryChecks> -TritonMakeTensorPtrTest::CreateTestTensorPtr( - const std::vector& parent_shape, - const std::vector& tile_sizes, - const std::vector& tile_strides) { - auto [hlo_module, tiled_hlo_computation] = - CreateAndTileHloComputation(parent_shape, tile_sizes, tile_strides); - - const TiledHloInstruction* tiled_parameter = - tiled_hlo_computation.GetRoots()[0]->operand(0); - const HloInstruction* parameter = tiled_parameter->hlo(); - - OpBuilder builder(&mlir_context_); - auto loc = mlir::NameLoc::get(builder.getStringAttr(parameter->name())); - mlir::OwningOpRef triton_module = - llvm_ir::CreateMlirModuleOp(loc); - builder.setInsertionPointToEnd(triton_module->getBody()); - - EmitterLocOpBuilder b(loc, builder); - auto fn = CreateTritonFunction(b, parent_shape); - - SmallVector tile_multi_index = ComputeDelinearizedTileIndex( - b, 
tiled_hlo_computation.num_output_tiles_per_dim()); - - return std::make_pair( - std::move(triton_module), - *ir_emitter_triton_internal::CreateMakeTensorPtrOp( - b, tile_multi_index, *tiled_parameter, fn.getArgument(0))); -} - -std::vector ConstOpValuesToInt(const mlir::ValueRange values) { - std::vector result; - for (Value v : values) { - auto const_op = v.getDefiningOp(); - CHECK_NOTNULL(const_op); - auto int_attr = mlir::cast(const_op.getValueAttr()); - result.push_back(int_attr.getInt()); - } - return result; -} - -mlir::ArrayRef TensorShape(const ::mlir::triton::MakeTensorPtrOp& op) { - auto ptr = - mlir::cast<::mlir::triton::PointerType>(op->getResult(0).getType()); - auto tensor = mlir::cast(ptr.getPointeeType()); - return tensor.getShape(); -} - -void CheckSizesAreSubtractions(const mlir::ValueRange size_values) { - for (Value v : size_values) { - EXPECT_NE(v.getDefiningOp(), nullptr); - } -} -TEST_F(TritonMakeTensorPtrTest, BlockProperties) { - { - auto [module, ptr] = CreateTestTensorPtr({15, 20}, {3, 4}, {1, 1}); - CheckSizesAreSubtractions(ptr.op.getShape()); - EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4)); - EXPECT_THAT(ptr.boundary_checks, ElementsAre(0)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(20, 1)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0)); - EXPECT_THAT(ptr.op.getOrder(), ElementsAre(1, 0)); - } - { - auto [module, ptr] = CreateTestTensorPtr({20, 20}, {4, 4}, {1, 1}); - CheckSizesAreSubtractions(ptr.op.getShape()); - EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4)); - EXPECT_TRUE(ptr.boundary_checks.empty()); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(20, 1)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0)); - EXPECT_THAT(ptr.op.getOrder(), ElementsAre(1, 0)); - } - { - auto [module, ptr] = CreateTestTensorPtr({5}, {1}, {1}); - CheckSizesAreSubtractions(ptr.op.getShape()); - EXPECT_THAT(TensorShape(ptr.op), ElementsAre(1)); - EXPECT_TRUE(ptr.boundary_checks.empty()); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(1)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0)); - EXPECT_THAT(ptr.op.getOrder(), ElementsAre(0)); - } - { - auto [module, ptr] = CreateTestTensorPtr({5, 5, 5}, {1, 1, 1}, {1, 1, 1}); - CheckSizesAreSubtractions(ptr.op.getShape()); - EXPECT_THAT(TensorShape(ptr.op), ElementsAre(1, 1, 1)); - EXPECT_TRUE(ptr.boundary_checks.empty()); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(25, 5, 1)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0, 0)); - EXPECT_THAT(ptr.op.getOrder(), ElementsAre(2, 1, 0)); - } - { - auto [module, ptr] = CreateTestTensorPtr({5, 15, 20}, {1, 3, 4}, {1, 1, 1}); - CheckSizesAreSubtractions(ptr.op.getShape()); - EXPECT_THAT(TensorShape(ptr.op), ElementsAre(1, 4, 4)); - EXPECT_THAT(ptr.boundary_checks, ElementsAre(1)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), - ElementsAre(300, 20, 1)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0, 0)); - EXPECT_THAT(ptr.op.getOrder(), ElementsAre(2, 1, 0)); - } -} - -} // namespace -} // namespace xla::gpu::ir_emitter_triton_internal diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub.cc index 480081da4b1b8d..19ed1aab38c432 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub.cc +++ 
b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "xla/autotuning.pb.h" #include "xla/backends/gpu/codegen/triton/fusion_emitter.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -91,7 +92,7 @@ llvm::SmallVector ComputeDelinearizedTileIndex( return {}; } -absl::StatusOr CreateMakeTensorPtrOp( +absl::StatusOr CreateTileOp( EmitterLocOpBuilder& b, mlir::ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo, mlir::Value parent_base_ptr) { return absl::UnimplementedError("not supported for this build configuration"); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub_test.cc index ecc560ffd1f54e..204d7e9da57349 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub_test.cc @@ -68,8 +68,8 @@ TEST(TritonStub, CallStubApi) { auto tiled_hlo = TiledHloInstruction::Create(&constant, {}, {1}, {1}, {}); EXPECT_TRUE(tiled_hlo.ok()); - EXPECT_FALSE(ir_emitter_triton_internal::CreateMakeTensorPtrOp( - builder, {}, *tiled_hlo.value(), {}) + EXPECT_FALSE(ir_emitter_triton_internal::CreateTileOp(builder, {}, + *tiled_hlo.value(), {}) .ok()); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/invalid.mlir b/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/invalid.mlir index 0ba8c41d32efc8..7094a48994a58a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/invalid.mlir +++ b/third_party/xla/xla/backends/gpu/codegen/triton/ir/tests/invalid.mlir @@ -34,15 +34,6 @@ tt.func @insert_mismatch_rank( // ----- -"tt.func"() <{function_type = (tensor) -> !triton_xla.tiled_tensor<|bf16>, sym_name = "xla_triton_tile"}> ({ -^bb0(%arg0: tensor): - // expected-error @+1 {{cannot tile a 0-d tensor}} - %0 = "triton_xla.tile"(%arg0) {layout = array} : (tensor) -> !triton_xla.tiled_tensor<|bf16> - "tt.return"(%0) : (!triton_xla.tiled_tensor<|bf16>) -> () -}) : () -> () - -// ----- - "tt.func"() <{function_type = (!triton_xla.tiled_tensor<|bf16>) -> tensor, sym_name = "xla_triton_extract"}> ({ ^bb0(%arg0: !triton_xla.tiled_tensor<|bf16>): %0 = "arith.constant"() <{value = 0 : index}> : () -> index diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc index 025bd7f7eba569..73a16722234c16 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc @@ -80,9 +80,6 @@ void TileOp::print(OpAsmPrinter& p) { } LogicalResult TileOp::verify() { - if (getTensor().getType().getRank() == 0) { - return emitError("cannot tile a 0-d tensor"); - } auto tensor_rank = getTensor().getType().getRank(); if (tensor_rank != getOffsets().size() || tensor_rank != getStrides().size()) return emitError( From 26570206d9c3b5e430a35fe4dc94017b24695782 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Wed, 2 Apr 2025 06:43:31 -0700 Subject: [PATCH 0148/1324] PR #24181: [ds-fusion] Fix the issue with JAX tests with no indvars Imported from GitHub PR https://github.com/openxla/xla/pull/24181 JAX has some testcases with no induction variable 
but static slices inside the while loop. Such slices are fused with the hero operation, but while lowering they fail on an assert (which is not required). Added a test to demonstrate this. Copybara import of the project: -- 0adcb04d43b75c02e9e4f6af2e0484245e49e990 by Shraiysh Vaishay : [ds-fusion] Fix the issue with JAX tests with no indvars JAX has some testcases with no induction variable but static slices inside the while loop. Such slices are fused with the hero operation, but while lowering they fail on an assert (which is not required). Added a test to demonstrate this. Merging this change closes #24181 PiperOrigin-RevId: 743115370 --- .../xla/xla/backends/gpu/codegen/custom.cc | 13 ++-- .../gpu/codegen/dynamic_slice_fusion_test.cc | 61 +++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/custom.cc b/third_party/xla/xla/backends/gpu/codegen/custom.cc index 4fb423f734a0ca..afbc30798b4ee0 100644 --- a/third_party/xla/xla/backends/gpu/codegen/custom.cc +++ b/third_party/xla/xla/backends/gpu/codegen/custom.cc @@ -264,8 +264,9 @@ std::unique_ptr ExtractOffsetModule( std::unique_ptr ExtractWhileUpdateModule( const HloInstruction* while_op) { std::optional tuple_idx = GetLoopInductionVarTupleIdx(while_op); - CHECK(tuple_idx != std::nullopt) - << "Unable to get tuple idx for whileop " << while_op->ToString(); + if (tuple_idx == std::nullopt) { + return nullptr; + } const HloInstruction* update = while_op->while_body()->root_instruction()->operand(*tuple_idx); return ExtractOffsetModule(update, while_op); @@ -277,8 +278,9 @@ std::unique_ptr ExtractWhileUpdateModule( std::unique_ptr ExtractWhileInitModule( const HloInstruction* while_op) { std::optional tuple_idx = GetLoopInductionVarTupleIdx(while_op); - CHECK(tuple_idx != std::nullopt) - << "Unable to get tuple idx for while op: " << while_op->ToString(); + if (tuple_idx == std::nullopt) { + return nullptr; + } const HloInstruction* init = while_op->operand(0)->operand(*tuple_idx); std::unique_ptr init_module = ExtractModule( /*instruction=*/init, /*height=*/-1, /*extract_selector=*/nullptr, @@ -1051,11 +1053,14 @@ CollectSliceArgumentMetadataForCollectives( SliceDataForCollectives slice_data(num_args); std::optional while_op = GetParentWhileOp(fusion_instr, call_graph); + VLOG(0) << "Collecting while op data"; if (while_op != std::nullopt) { CHECK(while_op.value() != nullptr) << "GetParentWhileOp is not expected to return nullptr."; slice_data.init_module = ExtractWhileInitModule(*while_op); + VLOG(0) << "Extracted init module"; slice_data.update_module = ExtractWhileUpdateModule(*while_op); + VLOG(0) << "Extracted update module"; } slice_data.can_compute_indvar_on_host = (slice_data.init_module != nullptr && slice_data.update_module != nullptr); diff --git a/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc b/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc index 2ba89b7fa389aa..eb7ed0c1f28d0b 100644 --- a/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc @@ -3703,6 +3703,67 @@ TEST_F(DynamicSliceFusionTest, /*run_hlo_passes=*/false, /*use_threads=*/true, std::nullopt)); } +TEST_F(DynamicSliceFusionTest, WhileLoopSliceWithNoInductionVariable) { + const char* hlo = R"( + HloModule test, replica_count=2 + + add { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) + } + + body { + param = (s32[], s32[], 
f32[128,128], f32[1024,128]) parameter(0) + iter0 = s32[] get-tuple-element(param), index=0 + iter1 = s32[] get-tuple-element(param), index=1 + c0 = s32[] constant(0) + c1 = s32[] constant(1) + add0 = s32[] add(iter0, iter0) + add1 = s32[] add(iter1, c1) + a = f32[128,128] get-tuple-element(param), index=2 + b = f32[1024,128] get-tuple-element(param), index=3 + slice = f32[256,128] slice(b), slice={[0:256], [0:128]} + rs = f32[128,128] reduce-scatter(slice), replica_groups={{0,1}}, dimensions={0}, to_apply=add + ROOT tuple = tuple(add0, add1, rs, b) + } + + condition { + param = (s32[], s32[], f32[128,128], f32[1024,128]) parameter(0) + iter = s32[] get-tuple-element(param), index=0 + iter1 = s32[] get-tuple-element(param), index=1 + c8 = s32[] constant(8) + compare1 = pred[] compare(iter, c8), direction=LT + compare2 = pred[] compare(iter1, c8), direction=LT + ROOT compare = pred[] and(compare1, compare2) + } + + ENTRY main { + c1 = s32[] constant(1) + a = f32[128,128] parameter(0) + b = f32[1024,128] parameter(1) + tuple = tuple(c1, c1, a, b) + while = while(tuple), body=body, condition=condition + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + m->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_dynamic_slice_fusion(false); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m_ref, + GetOptimizedModule(m->Clone())); + m->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_dynamic_slice_fusion(true); + TF_ASSERT_OK_AND_ASSIGN(m, GetOptimizedModule(std::move(m))); + // VLOG(0) << "Fused module: " << m->ToString(); + ErrorSpec error_spec(1e-5, 1e-5); + EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(m), std::move(m_ref), + /*run_hlo_passes=*/false, + /*use_threads=*/true, + error_spec)); +} } // namespace } // namespace gpu } // namespace xla From b4bdb0a139fcf17a173ebc4074594237fb107862 Mon Sep 17 00:00:00 2001 From: Ranko Sredojevic Date: Wed, 2 Apr 2025 07:03:17 -0700 Subject: [PATCH 0149/1324] Provide conveniences for parallelization of simple loops when actions return a value. PiperOrigin-RevId: 743121338 --- .../xla/xla/hlo/utils/concurrency/BUILD | 26 +++++ .../hlo/utils/concurrency/concurrency_utils.h | 96 +++++++++++++++++++ .../concurrency/concurrency_utils_test.cc | 86 +++++++++++++++++ .../hlo/utils/concurrency/tsl_task_executor.h | 1 - 4 files changed, 208 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h create mode 100644 third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc diff --git a/third_party/xla/xla/hlo/utils/concurrency/BUILD b/third_party/xla/xla/hlo/utils/concurrency/BUILD index 19bf1df8dff806..93ff7b472d7084 100644 --- a/third_party/xla/xla/hlo/utils/concurrency/BUILD +++ b/third_party/xla/xla/hlo/utils/concurrency/BUILD @@ -39,6 +39,18 @@ cc_library( ], ) +cc_library( + name = "concurrency_utils", + hdrs = ["concurrency_utils.h"], + # copybara:uncomment compatible_with = get_compatible_with_portable(), + deps = [ + ":tsl_task_executor", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + ## Tests below. 
 xla_cc_test(
@@ -64,3 +76,17 @@ xla_cc_test(
         "@com_google_googletest//:gtest_main",
     ],
 )
+
+xla_cc_test(
+    name = "concurrency_utils_test",
+    size = "small",
+    srcs = ["concurrency_utils_test.cc"],
+    deps = [
+        ":concurrency_utils",
+        ":tsl_task_executor",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
diff --git a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h
new file mode 100644
index 00000000000000..b962801c051191
--- /dev/null
+++ b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h
@@ -0,0 +1,96 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_HLO_UTILS_CONCURRENCY_CONCURRENCY_UTILS_H_
+#define XLA_HLO_UTILS_CONCURRENCY_CONCURRENCY_UTILS_H_
+
+#include <iterator>
+#include <optional>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "absl/functional/any_invocable.h"
+#include "absl/status/statusor.h"
+#include "xla/hlo/utils/concurrency/tsl_task_executor.h"
+
+namespace xla::concurrency {
+// Runs an action on all elements from an iterator. A successful run collects
+// all the return values from actions. The implementation guarantees that the
+// order of returned values corresponds to the order of elements in the
+// argument iterator [action(begin), ... action(end-1)]. Note that the action
+// can mutate the objects it receives from the iterator according to their
+// semantics.
+//
+// The overload below is for actions that return a value. `ActionReturnT` must
+// be default constructible.
+//
+// Returns synchronously when all actions finish. Aborts the run on the first
+// failure. If a run aborts the underlying data is likely to be corrupted or
+// partially modified.
+//
+// For synchronization, clients should make sure that actions do not deadlock
+// or corrupt any state they access. Specifically, if actions access any shared
+// mutable state clients must make sure that such access is synchronized. The
+// run can deadlock in all the standard ways. Specifically, if the action locks
+// a set of shared resources make sure that all locks are acquired in the same
+// order.
+template <typename ActionReturnT, typename ForwardItT, typename TaskExecutorT>
+#if __cplusplus >= 202002L
+  requires(std::forward_iterator<ForwardItT> && !std::is_void_v<ActionReturnT>)
+#endif
+absl::StatusOr<std::vector<ActionReturnT>> ForEach(
+    ForwardItT begin, ForwardItT end,
+    absl::AnyInvocable<absl::StatusOr<ActionReturnT>(
+        typename std::iterator_traits<ForwardItT>::value_type)>
+        action,
+    TaskExecutorT& task_executor,
+    std::optional<int> parallelism = std::nullopt) {
+  static_assert(!std::is_same_v<ActionReturnT, bool>,
+                "Cannot collect vector<bool> concurrently. If you need bool "
+                "return wrap it in a struct.");
+  auto result_size = std::distance(begin, end);
+  std::vector<ActionReturnT> result_storage(result_size);
+  std::vector<Task> tasks;
+  tasks.reserve(result_size);
+
+  auto result_iterator = result_storage.begin();
+  for (auto argument_iterator = begin; argument_iterator != end;
+       ++argument_iterator) {
+    // If modifying this function, keep an eye on iterator capture.
+    // Specifically, evaluate whether capturing the iterator is correct.
+    // For example, we can capture `result_iterator` because we are using
+    // `std::vector`. Should you want to change the result collection consider
+    // if the capture needs to change.
+    auto argument = *argument_iterator;
+    tasks.push_back([result_iterator, argument, &action]() {
+      auto result = action(argument);
+      if (result.ok()) {
+        *result_iterator = *result;
+      }
+      return result.status();
+    });
+    ++result_iterator;
+  }
+  auto status =
+      task_executor.ExecuteIndependentTasks(std::move(tasks), parallelism);
+  if (status.ok()) {
+    return result_storage;
+  }
+  return status;
+}
+
+}  // namespace xla::concurrency
+
+#endif  // XLA_HLO_UTILS_CONCURRENCY_CONCURRENCY_UTILS_H_
diff --git a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc
new file mode 100644
index 00000000000000..a6009361701023
--- /dev/null
+++ b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc
@@ -0,0 +1,86 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/hlo/utils/concurrency/concurrency_utils.h"
+
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "xla/hlo/utils/concurrency/tsl_task_executor.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla::concurrency {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+TEST(ForEachTest, ActionReturnedValuesCollected) {
+  TslTaskExecutor task_executor{3};
+
+  constexpr int kx0 = 0;
+  constexpr int kx1 = 1;
+  constexpr int kx2 = 2;
+
+  int v0 = kx0;
+  int v1 = kx1;
+  int v2 = kx2;
+
+  std::vector<int*> v = {&v0, &v1, &v2};
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto result,
+      (ForEach<int>(
+          v.begin(), v.end(),
+          [](int* element) -> absl::StatusOr<int> { return ++(*element); },
+          task_executor)));
+
+  EXPECT_EQ(v0, kx0 + 1);
+  EXPECT_EQ(v1, kx1 + 1);
+  EXPECT_EQ(v2, kx2 + 1);
+
+  EXPECT_THAT(result, ElementsAreArray({1, 2, 3}));
+}
+
+TEST(ForEachTest, FailureOfTheFirstActionPropagates) {
+  TslTaskExecutor task_executor{3};
+
+  constexpr int kx0 = 0;
+  constexpr int kx1 = 1;
+  constexpr int kx2 = 2;
+
+  int v0 = kx0;
+  int v1 = kx1;
+  int v2 = kx2;
+
+  std::vector<int*> v = {&v0, &v1, &v2};
+
+  EXPECT_EQ(ForEach<int>(
+                v.begin(), v.end(),
+                [](int* element) -> absl::StatusOr<int> {
+                  if (*element % 2 == 1)
+                    return absl::CancelledError("Force a failure.");
+                  return ++(*element);
+                },
+                task_executor)
+                .status()
+                .code(),
+            absl::StatusCode::kCancelled);
+}
+
+}  // namespace
+}  // namespace xla::concurrency
diff --git a/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.h b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.h
index 9f8b684ad6a50c..dbea62a3585bd7 100644
--- a/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.h
+++ b/third_party/xla/xla/hlo/utils/concurrency/tsl_task_executor.h
@@ -53,7 +53,6 @@ class TslTaskExecutor {
   // available, runs on as many as it has.
   //
   // When `parallelism` == 1 sequential execution is guaranteed.
-  //
   absl::Status ExecuteIndependentTasks(
       std::vector<Task> tasks,
       std::optional<int> parallelism = std::nullopt);

From 8985adef6fac9160c02e6152abace9d4e8828586 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Wed, 2 Apr 2025 07:16:17 -0700
Subject: [PATCH 0150/1324] [xla:gpu] CommandBuffer: return recorded commands
 from CommandBufferCmd::Record

Every CommandBufferCmd has to track commands in the underlying
stream_executor::CommandBuffer, so that it can use explicit update APIs
(coming next). For now the returned commands let CommandBufferCmdSequence
track dependencies between recorded commands.
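To make the new contract concrete, here is a minimal sketch of a command
implementing it. `MyMemzeroCmd` is hypothetical; the shape of the code mirrors
the MemzeroCmd change in the diff below but is not itself part of this patch:

    absl::StatusOr<CommandBufferCmd::RecordedCommands> MyMemzeroCmd::Record(
        const Thunk::ExecuteParams& execute_params,
        const RecordParams& record_params, se::CommandBuffer* command_buffer) {
      se::DeviceMemoryBase dst =
          execute_params.buffer_allocations->GetDeviceAddress(dst_);
      // A zero-sized destination records no commands at all.
      if (dst_.size() == 0) return RecordedCommands{};
      // Instead of discarding the command handle via `.status()`, wrap it in
      // RecordedCommands so that a later Record call into the same command
      // buffer can update the recorded command in place.
      return RecordedCommands::Create(
          command_buffer->Memset(&dst, uint8_t{0},
                                 /*num_elements=*/dst_.size(), {}));
    }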
PiperOrigin-RevId: 743125149 --- .../xla/xla/backends/gpu/runtime/BUILD | 1 + .../gpu/runtime/command_buffer_cmd.cc | 258 +++++++++--------- .../backends/gpu/runtime/command_buffer_cmd.h | 161 ++++++----- .../gpu/runtime/command_buffer_cmd_test.cc | 18 +- 4 files changed, 241 insertions(+), 197 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index eb11475a5f51f0..a5a22480624d46 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -183,6 +183,7 @@ xla_test( "//xla/tsl/platform:test_main", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 010d66f6d7990d..0a18b5725a0f10 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -151,6 +151,13 @@ static se::CommandBuffer::Builder CreateExecutionScopeBuilder( }; } +absl::StatusOr +CommandBufferCmd::RecordedCommands::Create( + absl::StatusOr command) { + if (!command.ok()) return command.status(); + return RecordedCommands{{*command}}; +} + //===----------------------------------------------------------------------===// // CommandBufferCmd //===----------------------------------------------------------------------===// @@ -243,8 +250,10 @@ absl::Status CommandBufferCmdSequence::Record( std::optional annotation = GetKernelAnnotation(command->profile_annotation()); - TF_RETURN_IF_ERROR( + TF_ASSIGN_OR_RETURN( + CommandBufferCmd::RecordedCommands recorded_commands, command->Record(execute_params, record_params, command_buffer)); + (void)recorded_commands; ++num_recorded_commands; } @@ -352,7 +361,8 @@ TracedCommandBufferCmd::TracedCommandBufferCmd( CommandBufferCmdType cmd_type, ExecutionStreamId execution_stream_id) : CommandBufferCmd(cmd_type, execution_stream_id) {} -absl::Status TracedCommandBufferCmd::AddTracedCommandBuffer( +absl::StatusOr +TracedCommandBufferCmd::AddTracedCommandBuffer( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer, absl::FunctionRef trace) { @@ -370,7 +380,7 @@ absl::Status TracedCommandBufferCmd::AddTracedCommandBuffer( execute_params.command_buffer_trace_stream, trace)); VLOG(5) << "Add nested command buffer"; - return command_buffer->AddNestedCommandBuffer(*nested_cmd, {}).status(); + return command_buffer->AddNestedCommandBuffer(*nested_cmd, {}); } //===----------------------------------------------------------------------===// @@ -464,7 +474,7 @@ absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::Status ComputationIdCmd::Record( +absl::StatusOr ComputationIdCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = @@ -502,11 +512,12 @@ absl::Status ComputationIdCmd::Record( } auto args = se::PackKernelArgs(/*shmem_bytes=*/0, int64_t{1}, value, dst); - return command_buffer - ->Launch(se::ThreadDim(1), se::BlockDim(1), *memset_kernel, *args, {}) - .status(); + return RecordedCommands::Create(command_buffer->Launch( + se::ThreadDim(1), se::BlockDim(1), *memset_kernel, *args, {})); + 
} else { - return command_buffer->Memset(&dst, value, /*num_elements=*/1, {}).status(); + return RecordedCommands::Create( + command_buffer->Memset(&dst, value, /*num_elements=*/1, {})); } } @@ -543,9 +554,9 @@ absl::Status LaunchCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::Status LaunchCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr LaunchCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { VLOG(5) << "LaunchCmd: kernel=" << kernel_name_ << "; shmem_bytes=" << shmem_bytes_; @@ -570,10 +581,9 @@ absl::Status LaunchCmd::Record(const Thunk::ExecuteParams& execute_params, TF_ASSIGN_OR_RETURN(auto kernel_args, se::PackKernelArgs(buffers, shmem_bytes_)); - return command_buffer - ->Launch(dims_.thread_counts_per_block(), dims_.block_counts(), *kernel, - *kernel_args, {}) - .status(); + return RecordedCommands::Create( + command_buffer->Launch(dims_.thread_counts_per_block(), + dims_.block_counts(), *kernel, *kernel_args, {})); } CommandBufferCmd::BufferUseVector LaunchCmd::buffers() { @@ -614,9 +624,10 @@ absl::Status CustomKernelLaunchCmd::Initialize( return absl::OkStatus(); } -absl::Status CustomKernelLaunchCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { +absl::StatusOr +CustomKernelLaunchCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { VLOG(5) << "CustomKernelLaunchCmd: custom_kernel=" << custom_kernel_.name(); se::Kernel* kernel = [&] { @@ -641,10 +652,9 @@ absl::Status CustomKernelLaunchCmd::Record( se::KernelArgsDeviceMemoryArray kernel_args( buffers, custom_kernel_.shared_memory_bytes()); - return command_buffer - ->Launch(custom_kernel_.thread_dims(), custom_kernel_.block_dims(), - *kernel, kernel_args, {}) - .status(); + return RecordedCommands::Create(command_buffer->Launch( + custom_kernel_.thread_dims(), custom_kernel_.block_dims(), *kernel, + kernel_args, {})); } CommandBufferCmd::BufferUseVector CustomKernelLaunchCmd::buffers() { @@ -668,9 +678,10 @@ MemcpyDeviceToDeviceCmd::MemcpyDeviceToDeviceCmd( src_(src), num_bytes_(num_bytes) {} -absl::Status MemcpyDeviceToDeviceCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { +absl::StatusOr +MemcpyDeviceToDeviceCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = execute_params.buffer_allocations->GetDeviceAddress(dst_); se::DeviceMemoryBase src = @@ -682,11 +693,11 @@ absl::Status MemcpyDeviceToDeviceCmd::Record( if (num_bytes_ == 0) { VLOG(5) << "Skip recording MemcpyDeviceToDeviceCmd command of 0 bytes"; - return absl::OkStatus(); + return RecordedCommands{}; } - return command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_, {}) - .status(); + return RecordedCommands::Create( + command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_, {})); } CommandBufferCmd::BufferUseVector MemcpyDeviceToDeviceCmd::buffers() { @@ -702,9 +713,9 @@ MemzeroCmd::MemzeroCmd(ExecutionStreamId execution_stream_id, : CommandBufferCmd(CommandBufferCmdType::kMemzeroCmd, execution_stream_id), dst_(dst) {} -absl::Status MemzeroCmd::Record(const Thunk::ExecuteParams& execute_params, - 
const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr MemzeroCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = execute_params.buffer_allocations->GetDeviceAddress(dst_); @@ -713,13 +724,12 @@ absl::Status MemzeroCmd::Record(const Thunk::ExecuteParams& execute_params, if (dst_.size() == 0) { VLOG(5) << "Skip recording MemzeroCmd command of 0 bytes"; - return absl::OkStatus(); + return RecordedCommands{}; } - return command_buffer - ->Memset(&dst, uint8_t{0}, - /*num_elements=*/dst_.size(), {}) - .status(); + return RecordedCommands::Create( + command_buffer->Memset(&dst, uint8_t{0}, + /*num_elements=*/dst_.size(), {})); } CommandBufferCmd::BufferUseVector MemzeroCmd::buffers() { @@ -736,9 +746,9 @@ Memset32Cmd::Memset32Cmd(ExecutionStreamId execution_stream_id, dst_(dst), bit_pattern_(bit_pattern) {} -absl::Status Memset32Cmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr Memset32Cmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = execute_params.buffer_allocations->GetDeviceAddress(dst_); @@ -747,13 +757,12 @@ absl::Status Memset32Cmd::Record(const Thunk::ExecuteParams& execute_params, if (dst_.size() == 0) { VLOG(5) << "Skip recording Memset32Cmd command of 0 bytes"; - return absl::OkStatus(); + return RecordedCommands{}; } - return command_buffer - ->Memset(&dst, bit_pattern_, - /*num_elements=*/dst_.size() / sizeof(uint32_t), {}) - .status(); + return RecordedCommands::Create(command_buffer->Memset( + &dst, bit_pattern_, + /*num_elements=*/dst_.size() / sizeof(uint32_t), {})); } CommandBufferCmd::BufferUseVector Memset32Cmd::buffers() { @@ -780,9 +789,9 @@ absl::Status CaseCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::Status CaseCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr CaseCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase index = execute_params.buffer_allocations->GetDeviceAddress(index_); @@ -790,19 +799,17 @@ absl::Status CaseCmd::Record(const Thunk::ExecuteParams& execute_params, VLOG(5) << " index: " << index_ << " (" << index.opaque() << ")"; if (index_is_bool_) { - return command_buffer - ->Case(se::DeviceMemory(index), - CreateBuilders(absl::MakeSpan(branches_commands_), - &execute_params, &record_params), - {}) - .status(); + return RecordedCommands::Create( + command_buffer->Case(se::DeviceMemory(index), + CreateBuilders(absl::MakeSpan(branches_commands_), + &execute_params, &record_params), + {})); } else { - return command_buffer - ->Case(se::DeviceMemory(index), - CreateBuilders(absl::MakeSpan(branches_commands_), - &execute_params, &record_params), - {}) - .status(); + return RecordedCommands::Create( + command_buffer->Case(se::DeviceMemory(index), + CreateBuilders(absl::MakeSpan(branches_commands_), + &execute_params, &record_params), + {})); } } @@ -839,9 +846,9 @@ absl::Status WhileCmd::Initialize(const Thunk::InitializeParams& params, return body_commands_.Initialize(params, state); } -absl::Status WhileCmd::Record(const Thunk::ExecuteParams& execute_params, - const 
RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr WhileCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase pred = execute_params.buffer_allocations->GetDeviceAddress(pred_); @@ -849,13 +856,11 @@ absl::Status WhileCmd::Record(const Thunk::ExecuteParams& execute_params, << " body_commands=" << body_commands_.size(); VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")"; - return command_buffer - ->While(se::DeviceMemory(pred), - CreateExecutionScopeBuilder(&cond_commands_, &execute_params, - &record_params), - CreateBuilder(&body_commands_, &execute_params, &record_params), - {}) - .status(); + return RecordedCommands::Create(command_buffer->While( + se::DeviceMemory(pred), + CreateExecutionScopeBuilder(&cond_commands_, &execute_params, + &record_params), + CreateBuilder(&body_commands_, &execute_params, &record_params), {})); } bool WhileCmd::force_update() { @@ -898,9 +903,9 @@ absl::Status GemmCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::Status GemmCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr GemmCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase lhs = execute_params.buffer_allocations->GetDeviceAddress(lhs_buffer_); se::DeviceMemoryBase rhs = @@ -916,11 +921,11 @@ absl::Status GemmCmd::Record(const Thunk::ExecuteParams& execute_params, VLOG(5) << " Out: " << output_buffer_ << " (" << out.opaque() << ")"; VLOG(5) << " Workspace: " << workspace_ << " (" << workspace.opaque() << ")"; - return AddTracedCommandBuffer( + return RecordedCommands::Create(AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { return RunGemm(config_, lhs, rhs, out, workspace, deterministic_, stream); - }); + })); } CommandBufferCmd::BufferUseVector GemmCmd::buffers() { @@ -1004,9 +1009,9 @@ absl::Status CublasLtCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::Status CublasLtCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr CublasLtCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(execute_params.stream)); TF_ASSIGN_OR_RETURN(auto algorithm, GetMatmulAlgorithm(execute_params.stream, plan, @@ -1051,7 +1056,7 @@ absl::Status CublasLtCmd::Record(const Thunk::ExecuteParams& execute_params, VLOG(5) << " d_amax_buffer: " << d_amax_buffer_.ToString(); VLOG(5) << " workspace_buffer: " << workspace_buffer_.ToString(); - return AddTracedCommandBuffer( + return RecordedCommands::Create(AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { return plan->ExecuteOnStream( stream, allocs.GetDeviceAddress(a_buffer_), @@ -1060,7 +1065,7 @@ absl::Status CublasLtCmd::Record(const Thunk::ExecuteParams& execute_params, allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, algorithm, allocs.GetDeviceAddress(workspace_buffer_)); - }); + })); } CommandBufferCmd::BufferUseVector CublasLtCmd::buffers() { @@ -1116,9 +1121,9 @@ absl::Status CuDnnCmd::Initialize(const 
Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::Status CuDnnCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr CuDnnCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { CHECK(graph_ != nullptr); std::vector operands; operands.reserve(args_.size()); @@ -1129,12 +1134,12 @@ absl::Status CuDnnCmd::Record(const Thunk::ExecuteParams& execute_params, operands.push_back(buf); } - return AddTracedCommandBuffer( + return RecordedCommands::Create(AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { return graph_->get()->Execute( *stream, absl::Span(operands), execute_params.collective_params->local_device_ordinal); - }); + })); } CommandBufferCmd::BufferUseVector CuDnnCmd::buffers() { @@ -1151,9 +1156,9 @@ CommandBufferCmd::BufferUseVector CuDnnCmd::buffers() { // CustomCallCmd //===----------------------------------------------------------------------===// -absl::Status CustomCallCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr CustomCallCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { if (handler_ == nullptr) { return RecordLegacyCustomCall(execute_params, record_params, command_buffer); @@ -1189,7 +1194,8 @@ absl::Status GetBuffers( } } // namespace -absl::Status CustomCallCmd::RecordLegacyCustomCall( +absl::StatusOr +CustomCallCmd::RecordLegacyCustomCall( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer) { std::vector buffers; @@ -1217,12 +1223,14 @@ absl::Status CustomCallCmd::RecordLegacyCustomCall( return absl::OkStatus(); })); - return command_buffer->AddNestedCommandBuffer(*nested_cmd, {}).status(); + return RecordedCommands::Create( + command_buffer->AddNestedCommandBuffer(*nested_cmd, {})); } -absl::Status CustomCallCmd::RecordXlaFfiCall( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { +absl::StatusOr +CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing // a lot of extra allocation on every call. 
We have to keep attributes // separate from arguments, as they do not change after thunk is @@ -1289,7 +1297,8 @@ absl::Status CustomCallCmd::RecordXlaFfiCall( return ffi::Call(handler_, call_frame, options); })); - return command_buffer->AddNestedCommandBuffer(*nested_cmd, {}).status(); + return RecordedCommands::Create( + command_buffer->AddNestedCommandBuffer(*nested_cmd, {})); } CommandBufferCmd::BufferUseVector CustomCallCmd::buffers() { @@ -1328,7 +1337,8 @@ absl::Status CollectiveCmd::Prepare( return resource_requests.AddClique(clique_key); } -absl::Status CollectiveCmd::AddTracedCommandBuffer( +absl::StatusOr +CollectiveCmd::AddTracedCommandBuffer( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer, absl::FunctionRef trace) { @@ -1337,7 +1347,7 @@ absl::Status CollectiveCmd::AddTracedCommandBuffer( execute_params.stream->parent(), execute_params.command_buffer_trace_stream, trace)); - return command_buffer->AddNestedCommandBuffer(*nested_cmd, {}).status(); + return command_buffer->AddNestedCommandBuffer(*nested_cmd, {}); } //===----------------------------------------------------------------------===// @@ -1354,9 +1364,9 @@ AllReduceCmd::AllReduceCmd(ExecutionStreamId execution_stream_id, reduction_kind_(reduction_kind), buffers_(buffers.begin(), buffers.end()) {} -absl::Status AllReduceCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr AllReduceCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, @@ -1385,11 +1395,11 @@ absl::Status AllReduceCmd::Record(const Thunk::ExecuteParams& execute_params, *execute_params.collective_cliques, config().replica_groups, config().group_mode, GetAsyncStreamKind())); - return AddTracedCommandBuffer( + return RecordedCommands::Create(AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { return RunAllReduce(collectives, reduction_kind_, device_buffers, *stream, comm_handle.comm); - }); + })); } CommandBufferCmd::BufferUseVector AllReduceCmd::buffers() { @@ -1415,7 +1425,7 @@ ReduceScatterCmd::ReduceScatterCmd( reduction_kind_(reduction_kind), buffers_(buffers.begin(), buffers.end()) {} -absl::Status ReduceScatterCmd::Record( +absl::StatusOr ReduceScatterCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( @@ -1447,11 +1457,11 @@ absl::Status ReduceScatterCmd::Record( *execute_params.collective_cliques, config().replica_groups, config().group_mode, GetAsyncStreamKind())); - return AddTracedCommandBuffer( + return RecordedCommands::Create(AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { return RunReduceScatter(collectives, reduction_kind_, device_buffers, *stream, comm_handle.comm); - }); + })); } CommandBufferCmd::BufferUseVector ReduceScatterCmd::buffers() { @@ -1476,9 +1486,9 @@ AllToAllCmd::AllToAllCmd(ExecutionStreamId execution_stream_id, has_split_dimension_(has_split_dimension), buffers_(buffers.begin(), buffers.end()) {} -absl::Status AllToAllCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr AllToAllCmd::Record( 
+ const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, @@ -1506,11 +1516,11 @@ absl::Status AllToAllCmd::Record(const Thunk::ExecuteParams& execute_params, *execute_params.collective_cliques, config().replica_groups, config().group_mode, GetAsyncStreamKind())); - return AddTracedCommandBuffer( + return RecordedCommands::Create(AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { return RunAllToAll(collectives, has_split_dimension_, device_buffers, *stream, comm_handle.comm); - }); + })); } CommandBufferCmd::BufferUseVector AllToAllCmd::buffers() { @@ -1534,9 +1544,9 @@ AllGatherCmd::AllGatherCmd(ExecutionStreamId execution_stream_id, async_from_stream_id, std::move(config)), buffers_(buffers.begin(), buffers.end()) {} -absl::Status AllGatherCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { +absl::StatusOr AllGatherCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, @@ -1565,11 +1575,11 @@ absl::Status AllGatherCmd::Record(const Thunk::ExecuteParams& execute_params, *execute_params.collective_cliques, config().replica_groups, config().group_mode, GetAsyncStreamKind())); - return AddTracedCommandBuffer( + return RecordedCommands::Create(AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { return RunAllGather(collectives, device_buffers, *stream, comm_handle.comm); - }); + })); } CommandBufferCmd::BufferUseVector AllGatherCmd::buffers() { @@ -1594,9 +1604,10 @@ CollectiveBroadcastCmd::CollectiveBroadcastCmd( std::move(config)), buffers_(buffers.begin(), buffers.end()) {} -absl::Status CollectiveBroadcastCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { +absl::StatusOr +CollectiveBroadcastCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, @@ -1625,11 +1636,11 @@ absl::Status CollectiveBroadcastCmd::Record( *execute_params.collective_cliques, config().replica_groups, config().group_mode, GetAsyncStreamKind())); - return AddTracedCommandBuffer( + return RecordedCommands::Create(AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { return RunCollectiveBroadcast(device_buffers, *stream, comm_handle.comm, collectives); - }); + })); } CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() { @@ -1738,9 +1749,10 @@ absl::Status DynamicSliceFusionCmd::Prepare( return absl::OkStatus(); } -absl::Status DynamicSliceFusionCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { +absl::StatusOr +DynamicSliceFusionCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { se::Stream& stream = *execute_params.stream; const BufferAllocations& orig_allocations = @@ -1871,8 +1883,8 @@ 
absl::Status DynamicSliceFusionCmd::Record( .value(); TF_RETURN_IF_ERROR(embedded_commands_->Record(new_params, record_params, nested_command_buffer.get())); - return command_buffer->AddNestedCommandBuffer(*nested_command_buffer, {}) - .status(); + return RecordedCommands::Create( + command_buffer->AddNestedCommandBuffer(*nested_command_buffer, {})); } CommandBufferCmd::BufferUseVector DynamicSliceFusionCmd::buffers() { diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index 7454725e13431b..fa54c980ec092d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -177,6 +177,15 @@ class CommandBufferCmd { StateManager& state; }; + // A list of commands recorded into the command buffer (or updated). + struct RecordedCommands { + // Creates a recorded commands from a single se::CommandBuffer command. + static absl::StatusOr Create( + absl::StatusOr command); + + absl::InlinedVector commands; + }; + // See Thunk documentation for XLA execution stages (prepare, initialize, // execute). Commands mirror thunks as they are executed as CommandBufferThunk // that is plugged into the Thunk execution cycle. @@ -197,10 +206,12 @@ class CommandBufferCmd { return absl::OkStatus(); } - // Records command into the command buffer using given execution scope. - virtual absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) = 0; + // Records commands into the command buffer. Returned commands will be passed + // back on the next call to `Record` into the same command buffer, so that it + // can do efficient command buffer updates. + virtual absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) = 0; // For some commands need to force update on Record even the input device // pointers do not change, e.g. command that has state that can be changed by @@ -378,7 +389,7 @@ class TracedCommandBufferCmd : public CommandBufferCmd { // Creates a command buffer by calling a user-provided `trace` function and // adds it as a nested command to `command_buffer`. Traced command buffers // cached and reused in an instance of `TracedCommandBuffer` kept in `state`. 
- absl::Status AddTracedCommandBuffer( + absl::StatusOr AddTracedCommandBuffer( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer, absl::FunctionRef trace); @@ -398,9 +409,10 @@ class ComputationIdCmd : public CommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -435,9 +447,10 @@ class LaunchCmd : public CommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -469,9 +482,10 @@ class CustomKernelLaunchCmd : public CommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -497,9 +511,10 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { BufferAllocation::Slice dst, BufferAllocation::Slice src, int64_t num_bytes); - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -518,9 +533,10 @@ class MemzeroCmd : public CommandBufferCmd { MemzeroCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice dst); - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -537,9 +553,10 @@ class Memset32Cmd : public CommandBufferCmd { Memset32Cmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice dst, uint32_t bit_pattern); - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -561,9 +578,10 @@ class CaseCmd : public CommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + 
absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; bool force_update() override; @@ -588,9 +606,10 @@ class WhileCmd : public CommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; bool force_update() override; @@ -617,9 +636,10 @@ class GemmCmd : public TracedCommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -659,9 +679,10 @@ class CublasLtCmd : public TracedCommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -712,9 +733,10 @@ class CuDnnCmd : public TracedCommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -765,20 +787,21 @@ class CustomCallCmd : public CommandBufferCmd { operands_(std::move(operands)), results_(std::move(results)) {} - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } private: - absl::Status RecordLegacyCustomCall(const Thunk::ExecuteParams& execute_param, - const RecordParams& record_params, - se::CommandBuffer* command_buffer); - absl::Status RecordXlaFfiCall(const Thunk::ExecuteParams& execute_param, - const RecordParams& record_params, - se::CommandBuffer* command_buffer); + absl::StatusOr RecordLegacyCustomCall( + const Thunk::ExecuteParams& execute_param, + const RecordParams& record_params, se::CommandBuffer* command_buffer); + absl::StatusOr RecordXlaFfiCall( + const Thunk::ExecuteParams& execute_param, + const RecordParams& record_params, se::CommandBuffer* command_buffer); std::string target_name_; @@ -817,7 +840,7 @@ class CollectiveCmd : public CommandBufferCmd { bool IsNestedCommandBuffer() const final { return true; } - absl::Status AddTracedCommandBuffer( + 
absl::StatusOr AddTracedCommandBuffer( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer, absl::FunctionRef trace); @@ -855,9 +878,10 @@ class AllReduceCmd : public CollectiveCmd { ReductionKind reduction_kind, absl::Span buffers); - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -881,9 +905,10 @@ class ReduceScatterCmd : public CollectiveCmd { CollectiveConfig config, ReductionKind reduction_kind, absl::Span buffers); - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -907,9 +932,10 @@ class AllToAllCmd : public CollectiveCmd { bool has_split_dimension, absl::Span buffers); - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -932,9 +958,10 @@ class AllGatherCmd : public CollectiveCmd { ExecutionStreamId async_from_stream_id, CollectiveConfig config, absl::Span buffers); - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -957,9 +984,10 @@ class CollectiveBroadcastCmd : public CollectiveCmd { CollectiveConfig config, absl::Span buffers); - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -991,9 +1019,10 @@ class DynamicSliceFusionCmd : public CommandBufferCmd { const Thunk::PrepareParams& params, Thunk::ResourceRequestsInterface& resource_requests) final; - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index 078b2184d648b9..a3e70675c6c40b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "absl/functional/function_ref.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/types/span.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -69,9 +70,10 @@ struct TestOnlyCommandBufferCmd : public CommandBufferCmd { execution_stream_id), buffer_usage(buffer_usage) {} - absl::Status Record(const Thunk::ExecuteParams&, const RecordParams&, - se::CommandBuffer*) override { - return absl::OkStatus(); + absl::StatusOr Record(const Thunk::ExecuteParams&, + const RecordParams&, + se::CommandBuffer*) override { + return RecordedCommands{}; } BufferUseVector buffers() override { return buffer_usage; } @@ -81,14 +83,14 @@ struct TestOnlyCommandBufferCmd : public CommandBufferCmd { class FakeCmd : public CommandBufferCmd { public: - FakeCmd(ExecutionStreamId execution_stream_id) + explicit FakeCmd(ExecutionStreamId execution_stream_id) : CommandBufferCmd(CommandBufferCmdType::kTracedCommandBufferCmd, execution_stream_id) {} - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override { - return absl::OkStatus(); + absl::StatusOr Record(const Thunk::ExecuteParams&, + const RecordParams&, + se::CommandBuffer*) override { + return RecordedCommands{}; } BufferUseVector buffers() override { return BufferUseVector{}; } }; From c83672b45f89f893fdb9b3bbea1cb843d10eca3c Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Wed, 2 Apr 2025 07:21:12 -0700 Subject: [PATCH 0151/1324] PR #24449: Remove HloComputation::ConditionalCallInstruction. Imported from GitHub PR https://github.com/openxla/xla/pull/24449 This is deprecated and broken. Step 4/5 of removing instruction_type. Copybara import of the project: -- fd2eb8a794915c0c8dacee2fafae94b4cb89292c by Johannes Reifferscheid : Remove HloComputation::ConditionalCallInstruction. This is deprecated and broken. Step 4/5 of removing instruction_type. Merging this change closes #24449 PiperOrigin-RevId: 743126664 --- third_party/xla/xla/hlo/ir/hlo_computation.h | 29 +------------------ third_party/xla/xla/hlo/ir/hlo_instruction.cc | 5 ---- .../simplifiers/flatten_call_graph.cc | 4 --- third_party/xla/xla/service/copy_insertion.cc | 2 +- .../xla/xla/service/hlo_instruction_test.cc | 16 ++++++---- 5 files changed, 12 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 4ed6f2dab39c9d..cbe8fb6ccac4dd 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -215,10 +215,8 @@ class HloComputation { // unreachable, and its instruction is set to null. We still need to regard // such computations as fusion computations for HLO scheduling purposes. kFusion, - // This computation is a conditional branch computation. - kConditional, // Last Value for range checking. - kLast = kConditional, + kLast = kFusion, }; static constexpr uintptr_t kInstructionTypeMask = 0b111; static_assert(static_cast(InstructionType::kUnset) == 0, @@ -803,31 +801,6 @@ class HloComputation { SetInstruction(fusion_instruction, InstructionType::kFusion); } - // Returns if this computation is a branch computation of a conditional. - [[deprecated( - "This is broken. 
Use CallGraph::GetComputationCallers() instead")]] - bool IsConditionalBranchComputation() const { - return instruction_type() == InstructionType::kConditional; - } - - // Returns the owning conditional call instruction, or nullptr if this is not - // a conditional branch computation. - [[deprecated( - "This is broken. Use CallGraph::GetComputationCallers() instead")]] - HloInstruction* ConditionalCallInstruction() const { - return instruction_type() == InstructionType::kConditional ? instruction() - : nullptr; - } - - [[deprecated( - "This is broken. Use CallGraph::GetComputationCallers() instead")]] - void SetConditionalCallInstruction( - HloInstruction* conditional_call_instruction) { - CHECK(conditional_call_instruction != nullptr); - CHECK(conditional_call_instruction->opcode() == HloOpcode::kConditional); - SetInstruction(conditional_call_instruction, InstructionType::kConditional); - } - // Returns if this computation is an async computation. bool IsAsyncComputation() const { return async_start_ != nullptr; } diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index e485866ff62ddc..6fa7756dfe9eba 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -2004,9 +2004,6 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, // kFalseComputationIndex. instruction->AppendComputation(true_computation); instruction->AppendComputation(false_computation); - // Set back pointer from computations to the conditional instruction. - true_computation->SetConditionalCallInstruction(instruction.get()); - false_computation->SetConditionalCallInstruction(instruction.get()); return instruction; } @@ -2021,8 +2018,6 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, for (int i = 0; i < branch_computations.size(); ++i) { instruction->AppendComputation(branch_computations[i]); instruction->AppendOperand(branch_computation_args[i]); - // Set back pointer from the computation to the conditional instruction. - branch_computations[i]->SetConditionalCallInstruction(instruction.get()); } return instruction; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.cc b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.cc index 7615849220fa9f..264fea757ae652 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.cc @@ -137,10 +137,6 @@ absl::Status AnnotateNode(const CallGraphNode& node) { for (HloComputation* computation : instruction->called_computations()) { computation->SetFusionInstruction(instruction); } - } else if (instruction->opcode() == HloOpcode::kConditional) { - for (HloComputation* branch : instruction->branch_computations()) { - branch->SetConditionalCallInstruction(instruction); - } } } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc index 8db7718be597ab..6d8ec0ba5c3509 100644 --- a/third_party/xla/xla/service/copy_insertion.cc +++ b/third_party/xla/xla/service/copy_insertion.cc @@ -890,7 +890,7 @@ class ComputeRelativeLocation { // A proper solution would be to track output index in // LiveRangeRegions::InstructionInfo. 
if (use->parent() == def->parent() &&
-        def->parent()->IsConditionalBranchComputation() &&
+        !def->parent()->caller_instructions(HloOpcode::kConditional).empty() &&
         def == entry2.first && def->shape().IsTuple()) {
       VLOG(3) << "Setting interception for multi-output instruction inside "
                  "conditional branch: "
diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc
index f526838ecd2d5f..72b46c7861d572 100644
--- a/third_party/xla/xla/service/hlo_instruction_test.cc
+++ b/third_party/xla/xla/service/hlo_instruction_test.cc
@@ -2766,10 +2766,12 @@ TEST_F(HloInstructionTest,
   // point to the conditional instruction.
   int num_conditional_branch_comp = 0;
   for (HloComputation* comp : module->MakeComputationPostOrder()) {
-    if (comp->IsConditionalBranchComputation()) {
+    auto conditional_callers =
+        comp->caller_instructions(HloOpcode::kConditional);
+    if (!conditional_callers.empty()) {
       num_conditional_branch_comp += 1;
-      EXPECT_EQ(comp->ConditionalCallInstruction(),
-                module->entry_computation()->root_instruction());
+      EXPECT_THAT(conditional_callers,
+                  ElementsAre(module->entry_computation()->root_instruction()));
     }
   }
   EXPECT_EQ(num_conditional_branch_comp, 2);
@@ -2841,10 +2843,12 @@ TEST_F(HloInstructionTest,
   // point to the conditional instruction.
   int num_conditional_branch_comp = 0;
   for (HloComputation* comp : module->MakeComputationPostOrder()) {
-    if (comp->IsConditionalBranchComputation()) {
+    auto conditional_callers =
+        comp->caller_instructions(HloOpcode::kConditional);
+    if (!conditional_callers.empty()) {
       num_conditional_branch_comp += 1;
-      EXPECT_EQ(comp->ConditionalCallInstruction(),
-                module->entry_computation()->root_instruction());
+      EXPECT_THAT(conditional_callers,
+                  ElementsAre(module->entry_computation()->root_instruction()));
     }
   }
   EXPECT_EQ(num_conditional_branch_comp, branch_computations.size());

From b8905ae89f2d5da8d07d0c661c461c9d7055c6a5 Mon Sep 17 00:00:00 2001
From: Aliia Khasanova
Date: Wed, 2 Apr 2025 07:28:57 -0700
Subject: [PATCH 0152/1324] Create a compiler object on every call of
 Compiler::GetForPlatform().

This is a first step in moving away from the xla::Compiler instance being a
statically constructed singleton. We want to move all registrations to
factories instead of singletons, so that every user creates their own
compiler instance and sharing of instances becomes explicit. This should
prevent race conditions and make the code thread-safe.
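To make the new contract concrete, here is a minimal, self-contained sketch of
the factory-per-call pattern this change moves toward. This is illustrative
only, not the XLA implementation: the registry shape, the int platform key, and
every name other than Compiler and GetForPlatform are assumptions made for the
example.

#include <functional>
#include <map>
#include <memory>
#include <mutex>

// Stand-in for xla::Compiler.
class Compiler {
 public:
  virtual ~Compiler() = default;
};

using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;

std::mutex registry_mutex;
std::map<int, CompilerFactory>& Factories() {
  static std::map<int, CompilerFactory> factories;
  return factories;
}

// Old behavior (sketch): the first call constructed the compiler and cached
// it, so every caller silently shared one mutable singleton. New behavior
// (sketch): every call invokes the registered factory and hands the caller a
// fresh instance, so any sharing has to be made explicit by the caller.
std::unique_ptr<Compiler> GetForPlatform(int platform_id) {
  std::lock_guard<std::mutex> lock(registry_mutex);
  auto it = Factories().find(platform_id);
  if (it == Factories().end()) return nullptr;
  return it->second();  // Fresh object on every call.
}

The hunks below follow from this: callers such as the backend and the
compile-only service now own a std::unique_ptr to the compiler instead of
borrowing a raw pointer to a shared singleton.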
PiperOrigin-RevId: 743128681 --- tensorflow/compiler/jit/xla_platform_info.cc | 5 ++--- third_party/xla/xla/client/local_client.cc | 4 ++-- .../xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc | 7 ++++--- third_party/xla/xla/service/backend.cc | 13 ++++++------- third_party/xla/xla/service/backend.h | 6 +++--- .../xla/xla/service/compile_only_service.cc | 9 ++++++--- .../xla/xla/service/compile_only_service.h | 4 ++-- third_party/xla/xla/service/compiler.cc | 19 ++----------------- third_party/xla/xla/service/compiler.h | 3 ++- .../gpu/autotuning/autotuner_compile_util.cc | 10 +++++----- .../gpu/autotuning/autotuner_compile_util.h | 5 +++-- .../xla/xla/tools/hlo_opt/compiled_opt_lib.cc | 9 +++++---- .../xla/xla/tools/hlo_opt/compiled_opt_lib.h | 2 +- third_party/xla/xla/tools/hlo_opt/gpu_opt.cc | 4 ++-- 14 files changed, 45 insertions(+), 55 deletions(-) diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index 321d00ad728402..2fa93816071225 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -255,9 +255,8 @@ absl::Status BuildXlaDeviceCompiler(DeviceBase* device, return platform.status(); } - // TODO(aliia): Replace auto with the actual type. This is a temporary change, - // needed to pass the OSS presubmits. - auto compiler_for_platform = xla::Compiler::GetForPlatform(platform.value()); + absl::StatusOr> compiler_for_platform = + xla::Compiler::GetForPlatform(platform.value()); if (!compiler_for_platform.ok()) { // In some rare cases (usually in unit tests with very small clusters) we // may end up transforming an XLA cluster with at least one GPU operation diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index 2a1a6d74448b10..9296eeb2671521 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -479,7 +479,7 @@ absl::StatusOr> LocalClient::Load( se::StreamExecutor * executor, backend().stream_executor(updated_options.device_ordinal())); - TF_ASSIGN_OR_RETURN(Compiler * compiler, + TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, Compiler::GetForPlatform(platform())); TF_ASSIGN_OR_RETURN( std::unique_ptr aot_result, @@ -487,7 +487,7 @@ absl::StatusOr> LocalClient::Load( TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - std::move(*aot_result).LoadExecutable(compiler, executor)); + std::move(*aot_result).LoadExecutable(compiler.get(), executor)); return std::make_unique(std::move(executable), local_service_->mutable_backend(), updated_options); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index ace0900148d075..2c08169881592f 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -90,13 +90,14 @@ absl::Status IsValidTopologyAndClientForCompile( return absl::OkStatus(); } -absl::StatusOr GetCompilerForDefaultGpuPlatform() { +absl::StatusOr> +GetCompilerForDefaultGpuPlatform() { TF_ASSIGN_OR_RETURN(stream_executor::Platform * platform, PlatformUtil::GetPlatform("gpu")); return Compiler::GetForPlatform(platform); } -absl::StatusOr GetCompilerForPlatform( +absl::StatusOr> GetCompilerForPlatform( std::optional platform_id) { if (!platform_id.has_value()) { return GetCompilerForDefaultGpuPlatform(); @@ -163,7 +164,7 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, HloModule::CreateFromProto(hlo_module_proto, *hlo_config)); 
UpdateEntryComputationLayout( hlo_module.get(), std::bind(&Compiler::DefaultDeviceShapeRepresentation, - gpu_compiler, std::placeholders::_1)); + gpu_compiler.get(), std::placeholders::_1)); DumpHloModuleIfEnabled(*hlo_module, kBeforeOptimizationsDumpName); Compiler::CompileOptions opts; opts.target_config = options.target_config; diff --git a/third_party/xla/xla/service/backend.cc b/third_party/xla/xla/service/backend.cc index ed301131a22d89..9ee413892de042 100644 --- a/third_party/xla/xla/service/backend.cc +++ b/third_party/xla/xla/service/backend.cc @@ -24,8 +24,6 @@ limitations under the License. #include "tsl/platform/statusor.h" #define EIGEN_USE_THREADS -#include "xla/service/backend.h" - #include #include #include @@ -35,6 +33,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "unsupported/Eigen/CXX11/Tensor" +#include "xla/service/backend.h" #include "xla/service/compiler.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/host/host_platform_id.h" @@ -99,9 +98,9 @@ struct Backend::IntraOpThreadPool { TransferManager::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto computation_placer, ComputationPlacer::GetForPlatform(platform)); - std::unique_ptr backend( - new Backend(platform, compiler, stream_executors, transfer_manager, - computation_placer, options.intra_op_parallelism_threads())); + std::unique_ptr backend(new Backend( + platform, std::move(compiler), stream_executors, transfer_manager, + computation_placer, options.intra_op_parallelism_threads())); return std::move(backend); } @@ -145,13 +144,13 @@ absl::StatusOr> Backend::BorrowStreams( return ptrs; } -Backend::Backend(se::Platform* platform, Compiler* compiler, +Backend::Backend(se::Platform* platform, std::unique_ptr compiler, absl::Span stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, int intra_op_parallelism_threads) : platform_(platform), - compiler_(compiler), + compiler_(std::move(compiler)), transfer_manager_(transfer_manager), computation_placer_(computation_placer), stream_executors_(stream_executors.begin(), stream_executors.end()) { diff --git a/third_party/xla/xla/service/backend.h b/third_party/xla/xla/service/backend.h index 85dbfea69c7faa..4da3740b4ed798 100644 --- a/third_party/xla/xla/service/backend.h +++ b/third_party/xla/xla/service/backend.h @@ -92,7 +92,7 @@ class Backend { // Accessors for the various objects. se::Platform* platform() const { return platform_; } - Compiler* compiler() const { return compiler_; } + Compiler* compiler() const { return compiler_.get(); } se::DeviceMemoryAllocator* memory_allocator() const { return memory_allocator_.get(); } @@ -178,7 +178,7 @@ class Backend { absl::Status ResetDevices(); private: - Backend(se::Platform* platform, Compiler* compiler, + Backend(se::Platform* platform, std::unique_ptr compiler, absl::Span stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, @@ -187,7 +187,7 @@ class Backend { Backend& operator=(const Backend&) = delete; se::Platform* platform_; - Compiler* compiler_; + std::unique_ptr compiler_; TransferManager* transfer_manager_; ComputationPlacer* computation_placer_; diff --git a/third_party/xla/xla/service/compile_only_service.cc b/third_party/xla/xla/service/compile_only_service.cc index a0e48bd8174d3c..873840298e44a3 100644 --- a/third_party/xla/xla/service/compile_only_service.cc +++ b/third_party/xla/xla/service/compile_only_service.cc @@ -22,9 +22,11 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "xla/debug_options_flags.h" #include "xla/service/backend.h" +#include "xla/service/compiler.h" #include "xla/service/computation_layout.h" #include "xla/service/dump.h" #include "xla/service/platform_util.h" +#include "xla/service/service.h" #include "xla/status_macros.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" @@ -50,13 +52,14 @@ CompileOnlyService::NewService(const ServiceOptions& options) { TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); std::unique_ptr service( - new CompileOnlyService(options, compiler)); + new CompileOnlyService(options, std::move(compiler))); return std::move(service); } CompileOnlyService::CompileOnlyService(const ServiceOptions& options, - Compiler* compiler) - : Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {} + std::unique_ptr compiler) + : Service(options, /*execute_backend=*/nullptr), + compiler_(std::move(compiler)) {} absl::StatusOr>> CompileOnlyService::CompileAheadOfTime( diff --git a/third_party/xla/xla/service/compile_only_service.h b/third_party/xla/xla/service/compile_only_service.h index 0238a16f282946..dc83313b2fd01a 100644 --- a/third_party/xla/xla/service/compile_only_service.h +++ b/third_party/xla/xla/service/compile_only_service.h @@ -73,14 +73,14 @@ class CompileOnlyService : public Service { private: explicit CompileOnlyService(const ServiceOptions& options, - Compiler* compiler); + std::unique_ptr compiler); CompileOnlyService(const CompileOnlyService&) = delete; void operator=(const CompileOnlyService&) = delete; // The compiler for the target platform. This is included in place of // the Service::execute_backend_'s compiler, since execute_backend_ is a // nullptr in CompileOnlyService. - Compiler* compiler_; + std::unique_ptr compiler_; }; } // namespace xla diff --git a/third_party/xla/xla/service/compiler.cc b/third_party/xla/xla/service/compiler.cc index 43f33e30ea418c..a9f26dafb7d005 100644 --- a/third_party/xla/xla/service/compiler.cc +++ b/third_party/xla/xla/service/compiler.cc @@ -108,22 +108,10 @@ Compiler::GetPlatformCompilers() { (*factories)[platform_id] = std::move(compiler_factory); } -/* static */ absl::StatusOr Compiler::GetForPlatform( +/* static */ absl::StatusOr> Compiler::GetForPlatform( const se::Platform* platform) { absl::MutexLock lock(&platform_compiler_mutex_); - auto* compilers = GetPlatformCompilers(); - // See if we already instantiated a compiler for this platform. - { - auto it = compilers->find(platform->id()); - if (it != compilers->end()) { - return it->second.get(); - } - - // If not, we just fall through to try to create one with a registered - // factory. - } - auto* factories = GetPlatformCompilerFactories(); auto it = factories->find(platform->id()); if (it == factories->end()) { @@ -132,10 +120,7 @@ Compiler::GetPlatformCompilers() { "that platform linked in?", platform->Name()); } - - // And then we invoke the factory, placing the result into the mapping. 
- compilers->insert(std::make_pair(platform->id(), it->second())); - return compilers->at(platform->id()).get(); + return it->second(); } // Default implementation diff --git a/third_party/xla/xla/service/compiler.h b/third_party/xla/xla/service/compiler.h index f916d273ea4031..bbb7ccc35842c9 100644 --- a/third_party/xla/xla/service/compiler.h +++ b/third_party/xla/xla/service/compiler.h @@ -293,7 +293,8 @@ class Compiler { // Returns the compiler singleton pointer if it is available for the given // platform, or an error status if it is not. - static absl::StatusOr GetForPlatform(const se::Platform* platform); + static absl::StatusOr> GetForPlatform( + const se::Platform* platform); // Returns a function that computes the size in bytes of the logical // buffer that contains a shape. diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc index 50a70dc48a5b59..b19ccc92f98039 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc @@ -71,13 +71,13 @@ std::vector ExecutionInputsFromBuffers( } // namespace AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, - Compiler* compiler, + std::unique_ptr compiler, se::StreamExecutor& stream_executor, se::Stream& stream, se::DeviceMemoryAllocator& allocator, const DebugOptions& opts) : config_(config), - compiler_(compiler), + compiler_(std::move(compiler)), stream_executor_(stream_executor), stream_(stream), allocator_(allocator), @@ -165,10 +165,10 @@ absl::StatusOr> AutotunerCompileUtil::ExtractModule( se::StreamExecutor* stream_exec = config.GetExecutor(); se::DeviceMemoryAllocator* allocator = config.GetAllocator(); TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream()); - TF_ASSIGN_OR_RETURN(Compiler * compiler, + TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, Compiler::GetForPlatform(stream_exec->GetPlatform())); - return AutotunerCompileUtil(config, compiler, *stream_exec, *stream, - *allocator, opts); + return AutotunerCompileUtil(config, std::move(compiler), *stream_exec, + *stream, *allocator, opts); } absl::StatusOr AutotunerCompileUtil::Execute( diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h index 0e0fcc712a6eb9..b2b11b86878cab 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h @@ -98,7 +98,8 @@ class AutotunerCompileUtil { GenerateModuleFn extractor); private: - AutotunerCompileUtil(const AutotuneConfig& config, Compiler* compiler, + AutotunerCompileUtil(const AutotuneConfig& config, + std::unique_ptr compiler, se::StreamExecutor& stream_executor, se::Stream& stream, se::DeviceMemoryAllocator& allocator, const DebugOptions& opts); @@ -108,7 +109,7 @@ class AutotunerCompileUtil { ExecutionProfile* profile = nullptr); AutotuneConfig config_; - Compiler* compiler_; + std::unique_ptr compiler_; se::StreamExecutor& stream_executor_; se::Stream& stream_; se::DeviceMemoryAllocator& allocator_; diff --git a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc index 62304da074ccc9..9718c8c42c1ab8 100644 --- a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc +++ b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc @@ -96,11 +96,12 @@ absl::StatusOr> 
CompiledOptProvider::GenerateStage( return std::nullopt; } -absl::StatusOr CompiledOptProvider::GetCompiler() { +absl::StatusOr> CompiledOptProvider::GetCompiler() { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform(GetPlatformName())); - TF_ASSIGN_OR_RETURN(Compiler * compiler, Compiler::GetForPlatform(platform)); + TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, + Compiler::GetForPlatform(platform)); return compiler; } @@ -110,7 +111,7 @@ absl::StatusOr> CompiledOptProvider::GetOptimizedHlo( DebugOptions debug_opts = GetDebugOptionsFromFlags(); Compiler::CompileOptions opts; - TF_ASSIGN_OR_RETURN(Compiler * compiler, GetCompiler()); + TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, GetCompiler()); DebugOptions d = input_module->config().debug_options(); d.set_xla_embed_ir_in_executable(true); input_module->mutable_config().set_debug_options(d); @@ -133,7 +134,7 @@ absl::StatusOr> CompiledOptProvider::GetExecutable( TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, GetOptimizedHlo(std::move(input_module))); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, GetExecutor()); - TF_ASSIGN_OR_RETURN(Compiler * compiler, GetCompiler()); + TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, GetCompiler()); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, compiler->RunBackend(std::move(optimized_module), executor, opts)); diff --git a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.h b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.h index eaabe294b5533a..8cecb1f3abb689 100644 --- a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.h +++ b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.h @@ -62,7 +62,7 @@ class CompiledOptProvider : public OptProvider { std::unique_ptr input_module); // Gets a compiler associated with the provider. - virtual absl::StatusOr GetCompiler(); + virtual absl::StatusOr> GetCompiler(); // Registers hardware-specific passes which are shared by // multiple backends (CPU, GPU, xPU). diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc index b2f842d2e1c0ff..df2e399acee3ed 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc @@ -177,10 +177,10 @@ class GpuOptProvider : public CompiledOptProvider { GetDeviceDescription(optimized_module)); TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform(GetPlatformName())); - TF_ASSIGN_OR_RETURN(Compiler * compiler, + TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, Compiler::GetForPlatform(platform)); - auto* gpu_compiler = static_cast(compiler); + auto* gpu_compiler = static_cast(compiler.get()); if (!optimized_module->has_schedule()) { TF_ASSIGN_OR_RETURN(gpu::ScheduleMetadata schedule_metadata, gpu::ScheduleGpuModule(optimized_module, From 174510c03d9e16cdf866f72f783ed6876b00896d Mon Sep 17 00:00:00 2001 From: Mikhail Goncharov Date: Wed, 2 Apr 2025 07:53:06 -0700 Subject: [PATCH 0153/1324] [XLA:GPU] check in nest_gemm_fusion if the resulting computation is supported Normally we run a support check on HLO before assigning it to the generic emitter. In nested gemm fusion we switch the backend forcefully so it is a good idea to make sure that the resulting HLO is supported by the backend. 
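Sketched as pseudo-XLA, the commit boils down to a check-after-rewrite
pattern. In the sketch below, RewriteToNestedFusion is a hypothetical stand-in
for the nesting rewrite; IsTritonSupportedComputation and the surrounding
types are the real ones used in the diff that follows, so the snippet assumes
the XLA headers from this series rather than being fully standalone.

// Sketch only: run the support oracle on the computation *after* the forced
// backend switch, because the rewritten HLO never went through the usual
// pre-fusion support check.
absl::Status RewriteThenVerify(HloComputation* computation,
                               const se::GpuComputeCapability& cc) {
  // Hypothetical rewrite that changes the fusion kind to the generic emitter.
  TF_RETURN_IF_ERROR(RewriteToNestedFusion(computation));
  if (!IsTritonSupportedComputation(*computation, cc)) {
    return absl::InternalError(
        "Rewritten computation is not supported by Triton.");
  }
  return absl::OkStatus();
}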
PiperOrigin-RevId: 743136039 --- .../gpu/codegen/triton/fusion_emitter.cc | 41 ++----- .../backends/gpu/codegen/triton/support.cc | 49 +++++++-- .../gpu/codegen/triton/support_test.cc | 48 ++++++++- .../xla/xla/service/gpu/gpu_compiler.cc | 3 +- .../xla/xla/service/gpu/transforms/BUILD | 6 ++ .../gpu/transforms/nest_gemm_fusion.cc | 44 ++++++-- .../service/gpu/transforms/nest_gemm_fusion.h | 5 + .../gpu/transforms/nest_gemm_fusion_test.cc | 100 ++++++++++++++---- 8 files changed, 219 insertions(+), 77 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index dbfa75fb4e2156..ddec79e8f3fb99 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -90,6 +90,7 @@ limitations under the License. #include "xla/backends/gpu/codegen/triton/emitter_helpers.h" #include "xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h" #include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" +#include "xla/backends/gpu/codegen/triton/support.h" #include "xla/backends/gpu/codegen/triton/transforms/passes.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/codegen/emitters/elemental_hlo_to_mlir.h" @@ -1109,38 +1110,6 @@ absl::StatusOr EmitTiledHloInstruction( absl::StrCat("Unsupported operation ", hlo->ToString())); } -// Verifies that the nested fusion instruction conforms to the assumptions of -// the emitter. Currently, we expect nested fusions: -// - of kind `__triton_nested_gemm_fusion` -// - to have a single user that is either a `dot` or a `concatenate`. -absl::Status VerifyNestedFusion(const HloInstruction& hlo) { - // TODO(b/393299275): test cases when there are multiple dot users of the - // same fusion. - if (hlo.user_count() != 1) { - return absl::FailedPreconditionError( - absl::StrCat("Expected only one user for fusion ", hlo.ToString(), - " but got ", hlo.user_count())); - } - TF_ASSIGN_OR_RETURN(GpuBackendConfig backend_config, - hlo.backend_config()); - if (const std::string& kind = backend_config.fusion_backend_config().kind(); - kind != kTritonNestedGemmFusionKind) { - return absl::FailedPreconditionError(absl::StrCat( - "Expected ", hlo.ToString(), - " with fusion backend kind __triton_nested_gemm_fusion, got ", kind)); - } - const HloInstruction* user = hlo.users().front(); - switch (user->opcode()) { - case HloOpcode::kDot: - case HloOpcode::kConcatenate: - return absl::OkStatus(); - default: - return absl::FailedPreconditionError( - absl::StrCat("Unexpected user ", user->ToString(), - " of nested fusion ", hlo.ToString())); - } -} - // Emit a sequence of instructions using compatible tiling with producers // ordered before consumers in `tiled_computation`. Returns the results for the // roots of `tiled_computation`. @@ -1158,7 +1127,13 @@ absl::StatusOr> EmitTiledComputation( // Skip generating nested fusions, they are emitted by their consumer. 
if (hlo->parent()->IsFusionComputation() && hlo->opcode() == HloOpcode::kFusion) { - TF_RETURN_IF_ERROR(VerifyNestedFusion(*hlo)); + CodegenDecision decision = IsTritonSupportedInstruction( + *hlo, device_info.gpu_compute_capability()); + if (!decision.CanFuse()) { + return absl::FailedPreconditionError( + absl::StrCat("Fusion ", hlo->ToString(), + " is not supported: ", decision.Explain())); + } VLOG(1) << "Skipping nested fusion: " << hlo->ToString(); continue; } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index e9f06dd0fcec2e..f7d59d1e312847 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/backends/gpu/codegen/triton/support.h" #include +#include #include #include @@ -367,20 +368,49 @@ CodegenDecision IsTritonSupportedDot( return CodegenDecision::Allow(); } -CodegenDecision IsSupportedFusion(const HloFusionInstruction& fusion) { +// Verifies that the nested fusion instruction conforms to the assumptions of +// the emitter. Currently, we expect nested fusions: +// - of kind `__triton_nested_gemm_fusion`; +// - to have a single user that is either a `dot` or a `concatenate`; +// - calls a supported computation. +CodegenDecision IsSupportedFusion(const HloFusionInstruction& hlo, + const se::GpuComputeCapability& capability) { + // TODO(b/393299275): test cases when there are multiple dot users of the + // same fusion. + if (hlo.user_count() != 1) { + return CodegenDecision::Forbid( + absl::StrCat("Expected only one user for fusion ", hlo.ToString(), + " but got ", hlo.user_count())); + } absl::StatusOr backend_config = - fusion.backend_config(); + hlo.backend_config(); if (!backend_config.ok()) { return CodegenDecision(backend_config.status()); } - absl::string_view fusion_kind = - backend_config.value().fusion_backend_config().kind(); - // Note: kTritonFusionKind is NOT expected to be set for nested fusions. - if (fusion_kind != kTritonNestedGemmFusionKind) { + if (const std::string& kind = + backend_config.value().fusion_backend_config().kind(); + kind != kTritonNestedGemmFusionKind) { return CodegenDecision::Forbid( - absl::StrCat("Unsupported fusion kind: ", fusion_kind)); + absl::StrCat("Expected ", hlo.ToString(), " with fusion backend kind ", + kTritonNestedGemmFusionKind, ", got ", kind)); } - return CodegenDecision::Allow(); + const HloInstruction* user = hlo.users().front(); + switch (user->opcode()) { + case HloOpcode::kDot: + case HloOpcode::kConcatenate: + break; + default: + return CodegenDecision::Forbid(absl::StrCat( + "Unexpected user opcode ", user->opcode(), " of nested fusion")); + } + CodegenDecision decision = + IsTritonSupportedComputation(*hlo.called_computation(), capability); + if (decision.CanFuse()) { + return CodegenDecision::Allow(); + } + return CodegenDecision::Forbid( + absl::StrCat("Computation called by fusion ", hlo.ToString(), + " is not supported: ", decision.Explain())); } CodegenDecision IsTritonSupportedConcatenate(const HloInstruction& hlo) { @@ -471,7 +501,8 @@ CodegenDecision IsTritonSupportedInstructionImpl( return IsTritonSupportedDot(*Cast(&instr), gpu_version); case HloOpcode::kFusion: - return IsSupportedFusion(*Cast(&instr)); + return IsSupportedFusion(*Cast(&instr), + gpu_version); default: // Not all instructions have a special handling. 
break;

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc
index 1b5ca1adb8345f..aa30303e37cce0 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc
@@ -2280,7 +2280,8 @@ class FusionKindsTest
 TEST_P(FusionKindsTest, OperandOfDot) {
   auto [kind, cc] = GetParam();
-  const std::string hlo_text = absl::Substitute(R"(
+  const std::string hlo_text = absl::Substitute(
+      R"(
 flhs {
   ROOT result = f32[128,256] parameter(0)
 }
@@ -2302,7 +2303,7 @@ ENTRY triton_computation {
     lhs_contracting_dims={1}, rhs_contracting_dims={0}
 }
 )",
-      kind);
+      kind);

   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
@@ -2361,6 +2362,49 @@ INSTANTIATE_TEST_SUITE_P(
                        ::testing::ValuesIn(AllDevicesToTest())),
     FusionKindsTestName);

+using FusionTest = TritonSupportTest;
+
+TEST_F(FusionTest, FusionComputationIsCheckedRecursively) {
+  // We expect the support check to fail, as `flhs` is not a supported
+  // computation: the fusion within it is not an operand of a dot or a
+  // concatenate.
+  absl::string_view hlo_text = R"(
+identity {
+  ROOT result = f32[128,256] parameter(0)
+}
+
+flhs {
+  p0 = f32[128,256] parameter(0)
+  ROOT result = f32[128,256] fusion(p0), kind=kCustom, calls=identity, backend_config={
+    "fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
+      "output_tiles":[{"sizes":["16", "64"]}]}}}
+}
+
+frhs {
+  ROOT result = f32[256,512] parameter(0)
+}
+
+ENTRY triton_computation {
+  p0 = f32[128,256] parameter(0)
+  p1 = f32[256,512] parameter(1)
+  lhs = f32[128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={
+    "fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
+      "output_tiles":[{"sizes":["16", "64"]}]}}}
+  rhs = f32[256,512]{1,0} fusion(p1), kind=kCustom, calls=frhs,
+    backend_config={ "fusion_backend_config":{ "kind":"__triton_nested_gemm_fusion",
+      "block_level_fusion_config": {"output_tiles":[{"sizes":["64", "32"]}]}}}
+  ROOT result = f32[128,512]{1,0} dot(lhs, rhs),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti,
+      ParseTemplateAndGetInstruction(hlo_text, F32, HloOpcode::kFusion,
+                                     /*use_nested_gemm_fusions=*/true));
+  se::GpuComputeCapability cc = se::CudaComputeCapability::Ampere();
+  ASSERT_FALSE(IsTritonSupportedInstruction(ti.Instruction(), cc));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{64, 32}, cc);
+}
+
 constexpr std::array kUnsupportedOps = {
     // clang-format off
     // go/keep-sorted start

diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index e51ac671036ff1..0fb4047a3d1e3e 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -1656,7 +1656,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(

   if (debug_options
           .xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms()) {
-    pipeline.AddPass();
+    pipeline.AddPass(
        gpu_target_config.device_description.gpu_compute_capability());
   }

   // Inline back the calls which have better performance with cuBLAS.
pipeline.AddPass( diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 56050746f2a4bc..768b37554a6632 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2334,6 +2334,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla/backends/gpu/codegen/triton:support", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/simplifiers:hlo_dce", @@ -2347,6 +2348,7 @@ cc_library( "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", + "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", @@ -2377,12 +2379,16 @@ xla_cc_test( "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", + "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", ], diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index a4fa7730c45838..3b8c0eb78b93fe 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/MLIRContext.h" +#include "xla/backends/gpu/codegen/triton/support.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -59,6 +60,7 @@ limitations under the License. #include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" @@ -709,38 +711,58 @@ absl::Status TryHoistBitcastsInComputationToCallers(HloInstruction* dot, class NestGemmFusionVisitor : public DfsHloRewriteVisitor { public: - explicit NestGemmFusionVisitor(mlir::MLIRContext* ctx, CallGraph* call_graph) - : ctx_(ctx), call_graph_(call_graph) {} + explicit NestGemmFusionVisitor( + mlir::MLIRContext* ctx, CallGraph* call_graph, + const se::GpuComputeCapability compute_capability) + : ctx_(ctx), + call_graph_(call_graph), + compute_capability_(compute_capability) {} absl::Status HandleFusion(HloInstruction* instruction) override { HloFusionInstruction* fusion = Cast(instruction); absl::StatusOr config = GetTritonGemmConfig(*fusion); if (!config.ok()) { - return absl::OkStatus(); // Skip because it's not a Triton gemm fusion. 
+      VLOG(2) << "Skipping fusion as it does not have a TritonGemmConfig";
+      return absl::OkStatus();
     }
     HloComputation* computation = fusion->called_computation();
     HloInstruction* instr =
         hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot);
-    if (dot == nullptr) {
-      return absl::OkStatus();  // Skip because fusion has no dot.
+    if (instr == nullptr) {
+      VLOG(2) << "Skipping fusion as it has no dot instruction";
+      return absl::OkStatus();
     }
     DCHECK_EQ(GetDotCount(computation), 1) << "Fusion has more than one dot.";
-
+    HloDotInstruction* dot = Cast(instr);
     TF_RETURN_IF_ERROR(
-        TryHoistBitcastsInComputationToCallers(dot, call_graph_));
+        TryHoistBitcastsInComputationToCallers(instr, call_graph_));
     VLOG(2) << "After hoisting bitcasts: " << computation->ToString();
-    TF_RETURN_IF_ERROR(MakeNestedFusionFromGemmFusion(
-        fusion, config.value(), Cast(dot), ctx_));
+
+    TF_RETURN_IF_ERROR(
+        MakeNestedFusionFromGemmFusion(fusion, config.value(), dot, ctx_));
     this->MarkAsChanged();
+    // TODO(b/393299275): support checks should be run *before* the fusion is
+    // constructed and this pass should only be applied to the known supported
+    // HLO. Currently, though, we are at the mercy of what the GemmFusion pass
+    // thinks the legacy emitter can handle. We change the kind of the fusion
+    // here and switch the track. Thus it is on us to make sure that the
+    // generic emitter will be able to handle the result. This is an early
+    // check to make sure that the nesting did not produce an unsupported HLO.
+    if (!IsTritonSupportedComputation(*computation, compute_capability_)) {
+      return absl::InternalError(absl::StrCat("Computation of fusion ",
+                                              fusion->ToString(),
+                                              " is not supported by Triton."));
+    }
     return absl::OkStatus();
   }

  private:
   mlir::MLIRContext* ctx_;
   CallGraph* call_graph_;
+  const se::GpuComputeCapability compute_capability_;
 };

 }  // namespace

@@ -753,7 +775,7 @@ absl::StatusOr NestGemmFusion::Run(
   mlir::MLIRContext ctx;
   for (HloComputation* computation :
        module->MakeNonfusionComputations(execution_threads)) {
-    NestGemmFusionVisitor visitor(&ctx, call_graph.get());
+    NestGemmFusionVisitor visitor(&ctx, call_graph.get(), compute_capability_);
     TF_RETURN_IF_ERROR(computation->Accept(&visitor));
     changed |= visitor.changed();
   }

diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h
index 134810c1600b66..9ae4f4bf81fee5 100644
--- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h
+++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/pass/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"

 namespace xla::gpu {

@@ -42,6 +43,9 @@ namespace xla::gpu {
 // nested fusions, each with their own BlockLevelFusionConfig.
class NestGemmFusion : public HloModulePass { public: + explicit NestGemmFusion(const se::GpuComputeCapability& compute_capability) + : compute_capability_(compute_capability) {} + absl::string_view name() const override { return "nest_gemm_fusion"; } using HloPassInterface::Run; @@ -50,6 +54,7 @@ class NestGemmFusion : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: + const se::GpuComputeCapability compute_capability_; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc index 01275f2e0217ed..8c712ab27814b2 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc @@ -20,20 +20,27 @@ limitations under the License. #include #include #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/pattern_matcher.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/platform/statusor.h" using ::testing::ElementsAre; +using ::testing::Not; +using ::tsl::testing::IsOk; using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; namespace xla { @@ -69,21 +76,26 @@ MATCHER_P(OutputTileSizesIs, matcher, "") { return ExplainMatchResult(matcher, output_tile_sizes, result_listener); } -class NestGemmFusionTest : public HloTestBase {}; +class NestGemmFusionTest : public HloTestBase { + protected: + const se::GpuComputeCapability compute_capability_{ + TestGpuDeviceInfo::RTXA6000DeviceInfo(se::CudaComputeCapability::Ampere()) + .gpu_compute_capability()}; +}; TEST_F(NestGemmFusionTest, BasicTest) { absl::string_view hlo = R"( dot { - lhs = bf16[8192,512] parameter(0) - rhs = bf16[512,512] parameter(1) - ROOT dot = bf16[8192,512] dot(lhs, rhs), + lhs = f32[8192,512] parameter(0) + rhs = f32[512,512] parameter(1) + ROOT dot = f32[8192,512] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY entry { - p0 = bf16[8192,512] parameter(0) - p1 = bf16[512,512] parameter(1) - ROOT fusion = bf16[8192,512] fusion(p0, p1), + p0 = f32[8192,512] parameter(0) + p1 = f32[512,512] parameter(1) + ROOT fusion = f32[8192,512] fusion(p0, p1), kind=kCustom, calls=dot, backend_config={ "fusion_backend_config": { "kind":"__triton_gemm", "triton_gemm_config": { @@ -95,7 +107,8 @@ ENTRY entry { })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + ASSERT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); const HloInstruction* fusion = nullptr; @@ -154,7 +167,8 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + ASSERT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); 
TF_ASSERT_OK(verifier().Run(module.get()).status()); const HloInstruction* fusion = nullptr; @@ -196,7 +210,8 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + ASSERT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); const HloInstruction* fusion = nullptr; @@ -238,7 +253,8 @@ ENTRY entry { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); } @@ -268,7 +284,8 @@ ENTRY entry { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); } @@ -303,7 +320,8 @@ ENTRY entry { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); } @@ -335,7 +353,8 @@ ENTRY entry { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); } @@ -370,7 +389,8 @@ ENTRY entry_computation { )"; // Note: block sizes were 16,16,32, but that now fails to satisfy constraints. TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); } @@ -402,10 +422,44 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + // TODO(b/393299275): rhs_contracting_dims={0} is not currently supported. + EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + Not(IsOk())); TF_ASSERT_OK(verifier().Run(module.get()).status()); } +TEST_F(NestGemmFusionTest, UnsupportedComputationsAreRejected) { + // Fusions other than kTritonNestedGemmFusionKind are not supported so + // we expect that the pass will fail as the resulting computation is not + // supported. 
+ absl::string_view hlo = R"( +identity { + ROOT result = f32[128,128]{1,0} parameter(0) +} + +triton_dot { + p0 = f32[128,128]{1,0} parameter(0) + cp0 = f32[128,128]{1,0} fusion(p0), kind=kCustom, calls=identity + p1 = f32[128,128]{1,0} parameter(1) + ROOT result = f32[128,128]{1,0} dot(cp0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[128,128]{1,0} parameter(0) + p1 = f32[128,128]{1,0} parameter(1) + ROOT result = f32[128,128] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + "triton_gemm_config": { + "block_m":32,"block_n":16,"block_k":128, + "split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}}} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + absl::StatusOr result = + NestGemmFusion(compute_capability_).Run(module.get()); + EXPECT_THAT(result, StatusIs(absl::StatusCode::kInternal)) << result.status(); +} + // TODO(b/393299275): correctly hoist bitcast through compare. // Fails with: "... [Unknown]: Expected comparison type UNSIGNED.". TEST_F(NestGemmFusionTest, DISABLED_BitcastsAreHoistedPastCompare) { @@ -434,7 +488,8 @@ ENTRY e { "split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}}} )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); } @@ -461,7 +516,8 @@ ENTRY e { "split_k":1,"num_stages":1,"num_warps":4,"num_ctas":1}}}} )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( @@ -502,7 +558,7 @@ ENTRY e { "split_k":1,"num_stages":1,"num_warps":4,"num_ctas":1}}}} )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_TRUE(!NestGemmFusion().Run(module.get()).ok()); + EXPECT_TRUE(!NestGemmFusion(compute_capability_).Run(module.get()).ok()); TF_ASSERT_OK(verifier().Run(module.get()).status()); // Cos should not be rewritten as we cannot hoist bitcast. 
EXPECT_THAT(
      RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"(
@@ -537,7 +593,8 @@ ENTRY e {
     "split_k":1,"num_stages":1,"num_warps":4,"num_ctas":1}}}}
 )";
   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
-  EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()),
+              IsOkAndHolds(true));
   TF_ASSERT_OK(verifier().Run(module.get()).status());
   EXPECT_THAT(
       RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"(
@@ -572,7 +629,8 @@ ENTRY e {
     "split_k":1,"num_stages":1,"num_warps":4,"num_ctas":1}}}}
 )";
   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
-  EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()),
+              IsOkAndHolds(true));
   TF_ASSERT_OK(verifier().Run(module.get()).status());
   EXPECT_THAT(
       RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"(

From 174510c03d9e16cdf866f72f783ed6876b00896d Mon Sep 17 00:00:00 2001
From: Kuter Dinel
Date: Wed, 2 Apr 2025 08:11:25 -0700
Subject: [PATCH 0154/1324] [XLA:LatencyHidingScheduler] Extend ScheduleProto
 for the whole module.

Instead of dumping one schedule proto per computation, extend ScheduleProto
to contain the schedule info of all computations in a module.

PiperOrigin-RevId: 743142984
---
 .../xla/service/latency_hiding_scheduler.cc | 41 +++++++++++--------
 .../xla/service/latency_hiding_scheduler.h  | 33 +++++++++++++--
 third_party/xla/xla/xla.proto               | 16 +++++---
 3 files changed, 63 insertions(+), 27 deletions(-)

diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc
index 39cc72c971644b..88ab9546cef438 100644
--- a/third_party/xla/xla/service/latency_hiding_scheduler.cc
+++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc
@@ -2402,6 +2402,7 @@ void HloScheduleGraph::AnnotateGraph(

 absl::Status DefaultSchedulerCore::InitializeScheduler(
     const HloModule* module) {
+  module_ = module;
   TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module));
   module_pressure_state_ = std::make_unique(
       module, alias_analysis_.get(), shape_size_bytes_);
@@ -2411,6 +2412,7 @@ absl::Status DefaultSchedulerCore::InitializeScheduler(
   if (VLOG_IS_ON(2)) {
     annotation_tracker_->PrintAnnotationSets(2);
   }
+
   if (!scheduling_instruction_crosses_overlap_limit_) {
     scheduling_instruction_crosses_overlap_limit_ =
         [](const SchedulingState& sched_state, const HloGraphNode* node) {
@@ -2589,25 +2591,22 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) {
           .GetNode(sched_state.new_sequence_reversed.front())
           .GetReadyTime();

-  const auto& debug_options = xla::GetDebugOptionsFromFlags();
-  if (debug_options.xla_dump_latency_hiding_schedule()) {
-    int core_freq = latency_estimator_->CyclesPerMicrosecond();
-    DumpLatencyHidingSchedule(computation, sched_state.sched_graph,
-                              sched_state.new_sequence_reversed, core_freq,
-                              debug_options);
+  if (schedule_proto_.has_value()) {
+    *schedule_proto_->add_computation_schedules() = ComputationScheduleToProto(
+        computation, sched_state.sched_graph, *latency_estimator_,
+        sched_state.new_sequence_reversed);
   }
-
   return std::move(sched_state.new_sequence_reversed);
 }

-void DefaultSchedulerCore::DumpLatencyHidingSchedule(
+ScheduleProto::ComputationScheduleProto
+DefaultSchedulerCore::ComputationScheduleToProto(
     const HloComputation* computation, const HloScheduleGraph& schedule_graph,
-    const std::vector& instructions,
-    const int cycles_per_microsecond,
const DebugOptions& debug_options) { - ScheduleProto proto; + const LatencyEstimator& estimator, + const std::vector& instructions) { + ScheduleProto::ComputationScheduleProto proto; proto.set_computation_id(computation->unique_id()); - proto.set_cycles_per_microsecond(cycles_per_microsecond); - + proto.set_cycles_per_microsecond(estimator.CyclesPerMicrosecond()); *proto.mutable_scheduler_statistics() = LatencyHidingScheduler::LatencyHidingStatistics( computation, latency_estimator_, async_tracker_, shape_size_bytes_) @@ -2626,10 +2625,7 @@ void DefaultSchedulerCore::DumpLatencyHidingSchedule( instr_msg->set_start_timestamp_cycles(start_time); instr_msg->set_end_timestamp_cycles(end_time); } - *proto.mutable_hlo_module() = computation->parent()->ToProto(); - - const std::string fn = absl::StrFormat("%s.schedule", computation->name()); - DumpProtobufToFile(proto, debug_options, fn); + return proto; } LatencyHidingScheduler::SchedulerStatistics @@ -2874,6 +2870,11 @@ absl::StatusOr LatencyHidingScheduler::Run( absl::flat_hash_map> saved_schedules; TF_RETURN_IF_ERROR(scheduler_core_->InitializeScheduler(module)); + const auto& debug_options = xla::GetDebugOptionsFromFlags(); + if (debug_options.xla_dump_latency_hiding_schedule()) { + TF_RETURN_IF_ERROR(scheduler_core_->CaptureScheduleProto()); + } + for (HloComputation* computation : computations_to_schedule) { TF_ASSIGN_OR_RETURN(std::vector new_schedule, scheduler_core_->ScheduleComputation(computation)); @@ -2909,6 +2910,12 @@ absl::StatusOr LatencyHidingScheduler::Run( VLOG(1) << "Statistics after scheduling:"; LogScheduleStatistics(computation); } + if (debug_options.xla_dump_latency_hiding_schedule()) { + TF_ASSIGN_OR_RETURN(ScheduleProto proto, + scheduler_core_->GetCapturedScheduleProto()); + const std::string filename = absl::StrFormat("%s.schedule", module->name()); + DumpProtobufToFile(proto, debug_options, filename); + } return true; } diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index 74c4dce6e82c3e..9f0cb987d78e46 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -333,8 +333,14 @@ class AsyncTracker { class SchedulerCore { public: virtual absl::Status InitializeScheduler(const HloModule* module) = 0; + + virtual absl::Status CaptureScheduleProto() = 0; + + virtual absl::StatusOr GetCapturedScheduleProto() = 0; + virtual absl::StatusOr> ScheduleComputation( const HloComputation* computation) = 0; + virtual ~SchedulerCore() = default; virtual int64_t GetMemoryPeak() = 0; virtual void SetMemoryLimit(uint64_t new_limit) = 0; @@ -1130,7 +1136,23 @@ class DefaultSchedulerCore : public SchedulerCore { post_processing_fn_(post_processing_fn), scheduling_instruction_crosses_overlap_limit_( scheduling_instruction_crosses_overlap_limit) {} + absl::Status InitializeScheduler(const HloModule* module) override; + + absl::Status CaptureScheduleProto() override { + schedule_proto_ = ScheduleProto(); + *schedule_proto_->mutable_hlo_module() = module_->ToProto(); + + return absl::OkStatus(); + } + + absl::StatusOr GetCapturedScheduleProto() override { + if (!schedule_proto_.has_value()) { + return absl::FailedPreconditionError("Schedule proto not captured."); + } + return schedule_proto_.value(); + } + absl::StatusOr> ScheduleComputation( const HloComputation* computation) override; static bool AddOccupierToResource( @@ -1152,6 +1174,11 @@ class DefaultSchedulerCore : public 
SchedulerCore { absl::flat_hash_map GetNumResourcesNeededForAnnotation( const SchedulingState& sched_state, int64_t annotation); + ScheduleProto::ComputationScheduleProto ComputationScheduleToProto( + const HloComputation* computation, const HloScheduleGraph& schedule_graph, + const LatencyEstimator& estimator, + const std::vector& instructions); + protected: virtual void LogInstruction(const HloInstruction* instr) const; // Schedules the given annotated node. @@ -1172,10 +1199,6 @@ class DefaultSchedulerCore : public SchedulerCore { virtual absl::StatusOr FindAndExtractBestNodeAvailable( SchedulingState& sched_state, DefaultSchedulerCore::ShouldSkipNodeFunction should_skip_node); - void DumpLatencyHidingSchedule( - const HloComputation* computation, const HloScheduleGraph& schedule_graph, - const std::vector& instructions, - int cycles_per_microsecond, const DebugOptions& debug_options); HloCostAnalysis::ShapeSizeFunction shape_size_bytes_; std::unique_ptr module_pressure_state_; @@ -1188,6 +1211,8 @@ class DefaultSchedulerCore : public SchedulerCore { PostProcessingFn post_processing_fn_ = nullptr; OverlapLimitRule scheduling_instruction_crosses_overlap_limit_ = nullptr; std::unique_ptr annotation_tracker_; + std::optional schedule_proto_; + const HloModule* module_ = nullptr; }; // A scheduler oriented to hiding latencies of operations that can run in diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 881cbf5138e803..c4e1837272c7b1 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -1499,7 +1499,6 @@ message ScheduleProto { double start_timestamp_cycles = 2; double end_timestamp_cycles = 3; } - repeated Instruction instructions = 1; message SchedulerStatisticsProto { double all_gather_wasted_cycles = 1; @@ -1516,11 +1515,16 @@ message ScheduleProto { int64 memory_pressure_peak = 12; } - // Computation id (matches the id in HloComputationProto). - int64 computation_id = 2; - HloModuleProto hlo_module = 3; - int64 cycles_per_microsecond = 4; - SchedulerStatisticsProto scheduler_statistics = 5; + message ComputationScheduleProto { + // Computation id (matches the id in HloComputationProto). + int64 computation_id = 1; + repeated Instruction instructions = 2; + SchedulerStatisticsProto scheduler_statistics = 3; + int64 cycles_per_microsecond = 4; + } + + HloModuleProto hlo_module = 1; + repeated ComputationScheduleProto computation_schedules = 2; } // Message that captures sharding configuration of an HLO op. From 78e5fe8d964a70c0f91781038dc3b4582a11d1cf Mon Sep 17 00:00:00 2001 From: Ranko Sredojevic Date: Wed, 2 Apr 2025 08:15:51 -0700 Subject: [PATCH 0155/1324] Provide conveniences for parallelization of simple loops when actions return void. 
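As a hedged usage sketch (mirroring the tests added below; the executor
parallelism and the function name IncrementAll are arbitrary choices for the
example, not part of the API):

#include <vector>

#include "absl/status/status.h"
#include "xla/hlo/utils/concurrency/concurrency_utils.h"
#include "xla/hlo/utils/concurrency/tsl_task_executor.h"

// Increments every pointed-to element in place, in parallel. No per-element
// values are collected; only the overall absl::Status is returned, which is
// the convenience this change introduces.
absl::Status IncrementAll(std::vector<int*>& values) {
  xla::concurrency::TslTaskExecutor executor(4);
  return xla::concurrency::ForEach(
      values.begin(), values.end(),
      [](int* v) {
        ++(*v);  // The side effect is the whole point of the action.
        return absl::OkStatus();
      },
      executor);
}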
PiperOrigin-RevId: 743144415 --- .../xla/xla/hlo/utils/concurrency/BUILD | 1 + .../hlo/utils/concurrency/concurrency_utils.h | 35 +++++++++++++ .../concurrency/concurrency_utils_test.cc | 49 +++++++++++++++++++ 3 files changed, 85 insertions(+) diff --git a/third_party/xla/xla/hlo/utils/concurrency/BUILD b/third_party/xla/xla/hlo/utils/concurrency/BUILD index 93ff7b472d7084..3e12edf86fda32 100644 --- a/third_party/xla/xla/hlo/utils/concurrency/BUILD +++ b/third_party/xla/xla/hlo/utils/concurrency/BUILD @@ -46,6 +46,7 @@ cc_library( deps = [ ":tsl_task_executor", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], diff --git a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h index b962801c051191..1adc1b62869074 100644 --- a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h +++ b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/hlo/utils/concurrency/tsl_task_executor.h" @@ -91,6 +92,40 @@ absl::StatusOr<std::vector<R>> ForEach( return status; } +// Runs an action on all elements from an iterator. Note that the action must be +// side-effecting to make any sense, and specifically it can be mutating. +// +// Returns synchronously when all actions finish. Aborts the run on the first +// failure. If a run aborts, the underlying data is likely to be corrupted or +// partially modified. +// +// For synchronization, clients should make sure that actions do not deadlock or +// corrupt any state they access. Specifically, if actions access any shared +// mutable state clients must make sure that such access is synchronized. The +// run can deadlock in all the standard ways. Specifically, if the action locks +// a set of shared resources make sure that all locks are acquired in the same +// order.
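+//
+// A minimal usage sketch (hypothetical; assumes a `TslTaskExecutor` named
+// `executor`, mirroring the tests added below):
+//
+//   std::vector<int*> items = ...;
+//   absl::Status s = ForEach(
+//       items.begin(), items.end(),
+//       [](int* item) {
+//         ++(*item);  // Side-effecting action.
+//         return absl::OkStatus();
+//       },
+//       executor);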
+template <typename ForwardItT, typename TaskExecutorT> +#if __cplusplus >= 202002L + requires(std::forward_iterator<ForwardItT>) +#endif +absl::Status ForEach(ForwardItT begin, ForwardItT end, + absl::AnyInvocable<absl::Status( + typename std::iterator_traits<ForwardItT>::value_type)> + action, + TaskExecutorT& task_executor, + std::optional<int> parallelism = std::nullopt) { + auto result_size = std::distance(begin, end); + std::vector<Task> tasks; + tasks.reserve(result_size); + + for (auto iterator = begin; iterator != end; ++iterator) { + auto argument = *iterator; + tasks.push_back([argument, &action]() { return action(argument); }); + } + return task_executor.ExecuteIndependentTasks(std::move(tasks), parallelism); +} + } // namespace xla::concurrency #endif // XLA_HLO_UTILS_CONCURRENCY_CONCURRENCY_UTILS_H_ diff --git a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc index a6009361701023..6a36b29b0a180e 100644 --- a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc +++ b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc @@ -29,6 +29,55 @@ namespace { using ::testing::ElementsAreArray; +TEST(ForEachTest, IterVariantConcurrentlyIncrementsIntegers) { + TslTaskExecutor task_executor(5); + + constexpr int kx0 = 0; + constexpr int kx1 = 1; + constexpr int kx2 = 2; + + int v0 = kx0; + int v1 = kx1; + int v2 = kx2; + + std::vector<int*> v = {&v0, &v1, &v2}; + + ASSERT_EQ(ForEach( + v.begin(), v.end(), + [](int* element) { + ++(*element); + return absl::OkStatus(); + }, + task_executor), + absl::OkStatus()); + + EXPECT_EQ(v0, kx0 + 1); + EXPECT_EQ(v1, kx1 + 1); + EXPECT_EQ(v2, kx2 + 1); +} + +TEST(ForEachTest, NonOkStatusPropagatesAsTheFinalResult) { + const absl::Status status = absl::CancelledError("Test Error"); + + TslTaskExecutor task_executor{3}; + + constexpr int kx0 = 0; + constexpr int kx1 = 1; + constexpr int kx2 = 2; + + int v0 = kx0; + int v1 = kx1; + int v2 = kx2; + + std::vector<int*> v = {&v0, &v1, &v2}; + + EXPECT_THAT(ForEach( + v.begin(), v.end(), + [&status](int* element) { return status; }, task_executor) + .code(), + absl::StatusCode::kCancelled); +} + TEST(ForEachTest, ActionReturnedValuesCollected) { TslTaskExecutor task_executor{3}; From 78f750ffab7eaeb715cd9b38cf1e08baca67edbf Mon Sep 17 00:00:00 2001 From: Theotime Combes Date: Wed, 2 Apr 2025 08:44:29 -0700 Subject: [PATCH 0156/1324] [XLA:GPU] Add triton support test for bitcast-convert PiperOrigin-RevId: 743153060 --- .../backends/gpu/codegen/triton/support.cc | 1 - .../gpu/codegen/triton/support_test.cc | 69 ++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index f7d59d1e312847..c1a9a32e39a760 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -531,7 +531,6 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { switch (opcode) { case HloOpcode::kAddDependency: case HloOpcode::kAfterAll: - case HloOpcode::kBitcastConvert: case HloOpcode::kCholesky: case HloOpcode::kConvolution: case HloOpcode::kCopyDone: diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index aa30303e37cce0..39a874ac409057 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -2405,12 +2405,78 @@ ENTRY triton_computation {
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{64, 32}, cc); } +class BitcastConvertTest + : public TritonSupportTest, + public ::testing::WithParamInterface< + std::tuple<PrimitiveType, PrimitiveType, se::GpuComputeCapability>> { +}; + +TEST_P(BitcastConvertTest, BitcastConvert) { + auto [data_type_in, data_type_out, cc] = GetParam(); + + if (primitive_util::IsComplexType(data_type_in) != + primitive_util::IsComplexType(data_type_out)) { + GTEST_SKIP() + << "BitcastConvert does not support complex <-> real conversion."; + } + + std::string hlo_text; + std::vector<int64_t> output_tile_sizes = {1, 32}; + + const int bit_width_in = primitive_util::BitWidth(data_type_in); + const int bit_width_out = primitive_util::BitWidth(data_type_out); + const std::string data_type_in_str = + primitive_util::LowercasePrimitiveTypeName(data_type_in); + const std::string data_type_out_str = + primitive_util::LowercasePrimitiveTypeName(data_type_out); + + if (bit_width_in == bit_width_out) { + hlo_text = absl::Substitute( + R"( +ENTRY triton_computation { + parameter = $0[33,68] parameter(0) + ROOT bc_convert = $1[33,68] bitcast-convert(parameter) +})", + data_type_in_str, data_type_out_str); + } else if (bit_width_in > bit_width_out) { + hlo_text = absl::Substitute( + R"( +ENTRY triton_computation { + parameter = $0[33] parameter(0) + ROOT bc_convert = $1[33, $2] bitcast-convert(parameter) +})", + data_type_in_str, data_type_out_str, bit_width_in / bit_width_out); + } else { // bit_width_in < bit_width_out + hlo_text = absl::Substitute( + R"( +ENTRY triton_computation { +parameter = $0[33, $1] parameter(0) +ROOT bc_convert = $2[33] bitcast-convert(parameter) +})", + data_type_in_str, bit_width_out / bit_width_in, data_type_out_str); + output_tile_sizes = {1}; + } + + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(hlo_text, data_type_in, + HloOpcode::kBitcastConvert)); + + RunSupportTest(std::move(ti), output_tile_sizes, cc); +} + +INSTANTIATE_TEST_SUITE_P( + BitcastConvertSuite, BitcastConvertTest, + ::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()), + ::testing::ValuesIn(AllXlaDataTypes()), + ::testing::ValuesIn(AllDevicesToTest())), + TritonSupportTestTwoTypesAndDeviceToString); + constexpr std::array kUnsupportedOps = { // clang-format off // go/keep-sorted start HloOpcode::kAddDependency, HloOpcode::kAfterAll, - HloOpcode::kBitcastConvert, HloOpcode::kCholesky, HloOpcode::kConvolution, HloOpcode::kCopyDone, @@ -2468,6 +2534,7 @@ absl::flat_hash_set<HloOpcode> AllTestedOpcodes() { ret.emplace(HloOpcode::kBatchNormGrad); ret.emplace(HloOpcode::kBatchNormInference); ret.emplace(HloOpcode::kBatchNormTraining); + ret.emplace(HloOpcode::kBitcastConvert); ret.emplace(HloOpcode::kCall); ret.emplace(HloOpcode::kComplex); ret.emplace(HloOpcode::kConditional); From 64414c61826f66cb13918f921f2a10f1f05233d8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 08:44:52 -0700 Subject: [PATCH 0157/1324] No public change. PiperOrigin-RevId: 743153161 --- third_party/xla/xla/service/metrics.proto | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/xla/xla/service/metrics.proto b/third_party/xla/xla/service/metrics.proto index ab359a56d7de9c..282642c9a1037f 100644 --- a/third_party/xla/xla/service/metrics.proto +++ b/third_party/xla/xla/service/metrics.proto @@ -6,6 +6,8 @@ import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/timestamp.proto"; +// internal imports + // Defines generic pass stats.
message KeyValueMetric { string key = 1; From 82740c5a6b7c84d05e2adee2b35fa3f2a56ae132 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 2 Apr 2025 09:38:12 -0700 Subject: [PATCH 0158/1324] [XLA:GPU] Add support for dot algorithms to the generic Triton emitter. Support for these algorithms should now match what is implemented in the legacy dot emitter, since the logic is shared; the match is enforced by `support_test.cc`. PiperOrigin-RevId: 743171324 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 15 +- .../gpu/codegen/triton/fusion_emitter.cc | 52 +-- .../triton/fusion_emitter_device_test.cc | 344 +++++++++++++++++- .../backends/gpu/codegen/triton/support.cc | 138 ++++++- .../xla/backends/gpu/codegen/triton/support.h | 5 + .../gpu/codegen/triton/support_test.cc | 28 +- .../backends/gpu/codegen/triton/test_utils.cc | 24 +- .../backends/gpu/codegen/triton/test_utils.h | 2 + 8 files changed, 529 insertions(+), 79 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index ff982d7158f88a..84a4652b79cf1d 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -180,6 +180,7 @@ cc_library( hdrs = ["fusion_emitter.h"], deps = [ ":compilation_pipeline", + ":dot_algorithms", ":emitter_helpers", ":fusion_emitter_legacy_matmul", ":support", @@ -650,11 +651,13 @@ xla_test( "gpu_b200", "gpu_amd_any", ], + shard_count = 5, tags = [ "no_mac", ], deps = [ ":fusion_emitter", + ":support", ":test_utils", "//xla:autotuning_proto_cc", "//xla:error_spec", @@ -664,6 +667,7 @@ xla_test( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:verified_hlo_module", + "//xla/service:algorithm_util", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", @@ -675,6 +679,8 @@ xla_test( "//xla/tsl/platform:errors", "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", @@ -727,7 +733,6 @@ cc_library( ":fusion_emitter", "//xla:shape_util", "//xla:status_macros", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/testlib:filecheck", @@ -741,7 +746,11 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/stream_executor:device_description", + "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tests:hlo_test_base", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -752,8 +761,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:protobuf", ], ) @@ -828,6 +836,7 @@ cc_library( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:algorithm_util", "//xla/service:instruction_fusion", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:ir_emission_utils", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc
b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index ddec79e8f3fb99..503adba0bc4fd2 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -87,6 +87,7 @@ limitations under the License. #include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.h" #include "xla/backends/gpu/codegen/emitters/transforms/passes.h" #include "xla/backends/gpu/codegen/triton/compilation_pipeline.h" +#include "xla/backends/gpu/codegen/triton/dot_algorithms.h" #include "xla/backends/gpu/codegen/triton/emitter_helpers.h" #include "xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h" #include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" @@ -778,10 +779,14 @@ absl::StatusOr<ScalarOrTensor> EmitDot(EmitterLocOpBuilder& b, return absl::FailedPreconditionError("Expected dot operands to be fusions"); } - // Iteration arguments only contain the accumulator. - TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, dot.shape().element_type())); - SmallVector<Value> iter_args = { - CreateConst(b, ty, 0.0f, tiled_hlo_dot.tile_sizes()).UnwrapUnsafe()}; + // The specific accumulator type to use may not correspond to the output type + // of the dot. In particular, that is the case when an algorithm is specified + // and the dot's output type does not match its expectations. + TF_ASSIGN_OR_RETURN(Type accumulator_type, + triton::GetDotAccumulatorType(b, dot)); + Value accumulator = + CreateConst(b, accumulator_type, 0.0f, tiled_hlo_dot.tile_sizes()) + .UnwrapTensor(); auto ci64 = [&](int64_t value) -> Value { return b.create<ma::ConstantOp>(b.getIntegerAttr(b.getI64Type(), value)); @@ -790,7 +795,7 @@ GetDotLoopIterationCount(tiled_hlo_dot)); auto for_op = b.create<mlir::scf::ForOp>( /*lowerBound=*/ci64(0), /*upperBound=*/ci64(loop_iteration_count), - /*step=*/ci64(1), iter_args); + /*step=*/ci64(1), SmallVector<Value>{accumulator}); { // Loop body. mlir::OpBuilder::InsertionGuard g(b); b.setInsertionPointToStart(for_op.getBody()); @@ -819,22 +824,7 @@ absl::StatusOr<ScalarOrTensor> EmitDot(EmitterLocOpBuilder& b, } dot_args.push_back(result.front().UnwrapTensor()); } - QCHECK_EQ(dot_args.size(), 2); - QCHECK_EQ(iter_args.size(), 1); Value acc = for_op.getRegionIterArgs().front(); - const PrecisionConfig& precision_config = dot.precision_config(); - // TODO(b/393299275): Support precision config. Right now we bail out if - // user wants anything but the default. - if (precision_config.algorithm() != PrecisionConfig::ALG_UNSET || - absl::c_any_of(precision_config.operand_precision(), - [](const int precision) { - return precision != PrecisionConfig::DEFAULT; - })) { - return absl::UnimplementedError( - absl::StrCat("Unsupported precision config: ", - precision_config.ShortDebugString())); - } - int64_t lhs_contracting_dim_idx = dot.dot_dimension_numbers().lhs_contracting_dimensions(0); @@ -853,13 +843,23 @@ absl::StatusOr<ScalarOrTensor> EmitDot(EmitterLocOpBuilder& b, Value rhs, MaskDotOperand(b, *tiled_hlo_dot.operand(1), dot_args[1], ki_i32, rhs_contracting_dim_idx)); - Value dot_result = - b.create<ttir::DotOp>(lhs, rhs, acc, - /*inputPrecision=*/ttir::InputPrecision::IEEE, - /*maxNumImpreciseAcc=*/0); - b.create<mlir::scf::YieldOp>(dot_result); + TF_ASSIGN_OR_RETURN( + Value acc_next, + triton::EmitSingleTileDot(b, dot, triton::DotOperands{lhs, rhs, acc})); + b.create<mlir::scf::YieldOp>(acc_next); } - return ScalarOrTensor(for_op.getResult(0)); + + // The output of the loop may not match the expected output type of the dot. + // We make sure to issue a conversion if necessary.
+ TF_ASSIGN_OR_RETURN(Type dot_output_type, + TritonType(b, dot.shape().element_type())); + + Value result = for_op.getResult(0); + if (dot_output_type != accumulator_type) { + result = Cast(b, result, dot_output_type); + } + + return ScalarOrTensor(result); } absl::StatusOr<ScalarOrTensor> EmitConcatenate( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index 7163dbf516b782..fc0a67905f85f7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include #include #include @@ -22,6 +24,8 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -32,14 +36,14 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "xla/autotuning.pb.h" #include "xla/backends/gpu/codegen/triton/fusion_emitter.h" +#include "xla/backends/gpu/codegen/triton/support.h" #include "xla/backends/gpu/codegen/triton/test_utils.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/primitive_util.h" +#include "xla/service/algorithm_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" @@ -1703,10 +1707,16 @@ CHECK: tt.broadcast {{.*}} -> tensor<1x2x64x8x EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); } +std::string TypeTestParamToString( + const ::testing::TestParamInfo<PrimitiveType>& data) { + return primitive_util::LowercasePrimitiveTypeName(data.param); +} + INSTANTIATE_TEST_SUITE_P(IotaEmitterParametrizedTestSuite, IotaEmitterParametrizedTest, ::testing::ValuesIn({S8, S16, S32, S64, BF16, F16, F32, - F64})); + F64}), + TypeTestParamToString); TEST_F(TritonEmitterTest, ReducePrecisionIsLoweredCorrectly) { const std::string kHloText = R"( @@ -2078,7 +2088,8 @@ fdot { } } ROOT fdot.root = f32[16,16]{1,0} dot(fdot.lhs, fdot.rhs), - lhs_contracting_dims={1}, rhs_contracting_dims={0} + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_f32_f32_f32 } ENTRY entry { @@ -2132,7 +2143,8 @@ fdot { } } ROOT fdot.root = f32[32,512]{1,0} dot(fdot.lhs, fdot.rhs), - lhs_contracting_dims={1}, rhs_contracting_dims={0} + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_f32_f32_f32 } ENTRY entry { @@ -2199,7 +2211,8 @@ fdot { } } ROOT fdot.root = f32[32,512]{1,0} dot(fdot.lhs, fdot.rhs), - lhs_contracting_dims={1}, rhs_contracting_dims={0} + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_f32_f32_f32 } ENTRY entry { @@ -2379,7 +2392,8 @@ dot { } } ROOT dot = f32[32,512]{1,0} dot(lhs, rhs), - lhs_contracting_dims={1}, rhs_contracting_dims={0} + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_f32_f32_f32 } ENTRY entry { @@ -2423,7 +2437,9 @@ triton_dot (p0: f32[264], p1: f32[128,8]) -> f32[264,8] {
f32[264,128]{1,0} fusion(p0), kind=kCustom, calls=flhs, backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion","block_level_fusion_config":{"num_warps":"1","output_tiles":[{"sizes":["32","16"]}]}}} p1 = f32[128,8]{1,0} parameter(1) rhs = f32[128,8]{1,0} fusion(p1), kind=kCustom, calls=frhs, backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion","block_level_fusion_config":{"num_warps":"1","output_tiles":[{"sizes":["16","16"]}]}}} - ROOT result = f32[264,8]{1,0} dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT result = f32[264,8]{1,0} dot(lhs, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_f32_f32_f32 } ENTRY e (p0.1: f32[11,1,24,1], p1.1: f32[128,8]) -> f32[264,8] { @@ -2445,6 +2461,318 @@ ENTRY e (p0.1: f32[11,1,24,1], p1.1: f32[128,8]) -> f32[264,8] { kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } +// The template is parametrized by the type of the lhs/rhs, the type of the +// dot output, and the algorithm. +constexpr absl::string_view kHloForDotAlgorithmTestTemplate = R"( +lhs { + ROOT p0 = $0[512,512] parameter(0) +} + +rhs { + ROOT p0 = $0[512,512] parameter(0) +} + +dot { + p0 = $0[512,512] parameter(0) + p1 = $0[512,512] parameter(1) + lhs = $0[512,512] fusion(p0), kind=kCustom, calls=lhs, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["16", "32"]}] + }}} + rhs = $0[512,512]{1,0} fusion(p1), kind=kCustom, calls=rhs, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32", "64"]}] + }}} + ROOT dot = $1[512,512] dot(lhs, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=$2 +} + +ENTRY entry { + p0 = $0[512,512] parameter(0) + p1 = $0[512,512] parameter(1) + ROOT fusion = $1[512,512] fusion(p0, p1), + kind=kCustom, calls=dot, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["16","64"]}], + "num_warps":"1", "num_ctas":"1", "num_stages":"1" + }}} +})"; + +std::string GetDotAlgorithmHlo(PrimitiveType in_ty, PrimitiveType out_ty, + PrecisionConfig::Algorithm algorithm) { + constexpr absl::string_view kAlgorithmPrefix = "ALG_"; + std::string in_ty_str = primitive_util::LowercasePrimitiveTypeName(in_ty); + std::string out_ty_str = primitive_util::LowercasePrimitiveTypeName(out_ty); + std::string algorithm_str = PrecisionConfig::Algorithm_Name(algorithm).substr( + kAlgorithmPrefix.size()); + return absl::Substitute(kHloForDotAlgorithmTestTemplate, in_ty_str, + out_ty_str, algorithm_str); +} + +// TODO(b/407744579): narrow down the error specs for the various dot +// algorithms. +// +// The non-default values are either taken from the pre-existing +// `dot_algorithms_test` as of 2025-04-01, or approximated. It's not clear +// whether even the pre-existing values were derived to adhere precisely to the +// numerical expectations of the corresponding algorithms. We should narrow this +// down in the future. +ErrorSpec ErrorSpecForDotAlgorithm(PrecisionConfig::Algorithm algorithm) { + // A default error spec, not particularly tuned to any algorithm. + ErrorSpec default_error_spec{/*aabs=*/1e-4, /*arel=*/1e-6}; + switch (algorithm) { + case PrecisionConfig::ALG_UNSET: + // Give a loose tolerance to ALG_UNSET, as the expected behaviour is + // not deducible from the algorithm name alone. 
+ return ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-3}; + case PrecisionConfig::ALG_DOT_F16_F16_F16: + // Computed to make the tests pass (and it seems reasonable on the face of + // it), and not derived from first principles. + return ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-3}; + case PrecisionConfig::ALG_DOT_F32_F32_F32: + return default_error_spec; + case PrecisionConfig::ALG_DOT_F64_F64_F64: + // Computed to make the tests pass (and it seems reasonable on the face of + // it), and not derived from first principles. + return ErrorSpec{/*aabs=*/2e-6, /*arel=*/2e-6}; + case PrecisionConfig::ALG_DOT_F16_F16_F32: + return default_error_spec; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + // Taken from `dot_algorithms_test`. + return ErrorSpec{/*aabs=*/0, /*arel=*/6e-5}; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + // Taken from `dot_algorithms_test`. + return ErrorSpec{/*aabs=*/0, /*arel=*/7e-6}; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + // Computed to make the tests pass (and it seems reasonable on the face of + // it), and not derived from first principles. + return ErrorSpec{/*aabs=*/2e-6, /*arel=*/2e-6}; + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + // Computed to make the tests pass (and it seems reasonable on the face of + // it), and not derived from first principles. + return ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}; + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: + // Computed to make the tests pass (and it seems reasonable on the face of + // it), and not derived from first principles. + return ErrorSpec{/*aabs=*/2e-6, /*arel=*/3e-6}; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9: + // Computed to make the tests pass (and it seems reasonable on the face of + // it), and not derived from first principles. + return ErrorSpec{/*aabs=*/2e-6, /*arel=*/2e-6}; + case PrecisionConfig::ALG_DOT_BF16_BF16_BF16: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: + return kExactMatch; + // Keep in order to make the switch exhaustive. + case PrecisionConfig_Algorithm_PrecisionConfig_Algorithm_INT_MIN_SENTINEL_DO_NOT_USE_: // NOLINT(whitespace/line_length) + case PrecisionConfig_Algorithm_PrecisionConfig_Algorithm_INT_MAX_SENTINEL_DO_NOT_USE_: // NOLINT(whitespace/line_length) + LOG(FATAL) << "Unsupported algorithm: " << algorithm; + } +} + +class TritonEmitterTestWithAlgorithmParam + : public TritonEmitterTest, + public ::testing::WithParamInterface<PrecisionConfig::Algorithm> {}; + +// Regroups tests for dot algorithms that have no ambiguous type parameters as +// per `algorithm_util::GetAllowedOperandsTypeForAlgorithm` and +// `algorithm_util::GetDotAccumulatorType`, and do not decompose each tiled step +// into multiple `dot` operations. We call these algorithms "basic" algorithms +// here.
+using BasicDotAlgorithmEmitterTest = TritonEmitterTestWithAlgorithmParam; + +constexpr std::array kBasicAlgorithms = { + PrecisionConfig::ALG_DOT_F16_F16_F16, + PrecisionConfig::ALG_DOT_F32_F32_F32, + PrecisionConfig::ALG_DOT_F64_F64_F64, + PrecisionConfig::ALG_DOT_F16_F16_F32, + PrecisionConfig::ALG_DOT_BF16_BF16_F32, + PrecisionConfig::ALG_DOT_TF32_TF32_F32, +}; + +TEST_P(BasicDotAlgorithmEmitterTest, BasicAlgorithmIsEmittedCorrectly) { + auto algorithm = GetParam(); + TF_ASSERT_OK_AND_ASSIGN( + std::vector<PrimitiveType> allowed_types, + algorithm_util::GetAllowedOperandsTypeForAlgorithm(algorithm)); + ASSERT_EQ(allowed_types.size(), 1); + PrimitiveType in_ty = allowed_types.front(); + TF_ASSERT_OK_AND_ASSIGN(PrimitiveType out_ty, + algorithm_util::GetDotAccumulatorType(algorithm)); + const std::string kHloText = GetDotAlgorithmHlo(in_ty, out_ty, algorithm); + + TF_EXPECT_OK(CreateTritonIrAndFileCheck( + this, kHloText, "dot", + absl::Substitute(R"( + CHECK: tt.dot{{.*}} : tensor<16x32x$0> * tensor<32x64x$0> -> tensor<16x64x$1> + )", + primitive_util::LowercasePrimitiveTypeName(in_ty), + primitive_util::LowercasePrimitiveTypeName(out_ty)))); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpecForDotAlgorithm(algorithm))); +} + +std::string DotAlgorithmTestToString( + const ::testing::TestParamInfo<PrecisionConfig::Algorithm>& data) { + return PrecisionConfig::Algorithm_Name(data.param); +} + +INSTANTIATE_TEST_SUITE_P(BasicDotAlgorithmEmitterTestSuite, + BasicDotAlgorithmEmitterTest, + ::testing::ValuesIn(kBasicAlgorithms), + DotAlgorithmTestToString); + +// Regroups tests for dot algorithms that issue several dot instructions. +using MultiDotAlgorithmEmitterTest = TritonEmitterTestWithAlgorithmParam; + +constexpr std::array kMultiDotAlgorithms = { + PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3, + PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6, + PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3, + PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9, +}; + +TEST_P(MultiDotAlgorithmEmitterTest, MultiDotAlgorithmIsEmittedCorrectly) { + auto algorithm = GetParam(); + TF_ASSERT_OK_AND_ASSIGN(PrimitiveType out_ty, + algorithm_util::GetDotAccumulatorType(algorithm)); + PrimitiveType in_ty = + algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 ? F32 : BF16; + // Dummy value to ensure that the dot count is explicitly set. + int dot_count_for_algorithm = 0x1337; + std::string input_precision_string = ""; + switch (algorithm) { + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + dot_count_for_algorithm = 3; + break; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + dot_count_for_algorithm = 6; + break; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9: + dot_count_for_algorithm = 9; + break; + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: + // Triton implements TF32x3 as a specific precision mode. + input_precision_string = "tf32x3"; + dot_count_for_algorithm = 1; + break; + default: + // Unreachable.
+ ASSERT_TRUE(false); + } + + const std::string kHloText = GetDotAlgorithmHlo(in_ty, out_ty, algorithm); + + TF_EXPECT_OK(CreateTritonIrAndFileCheck( + this, kHloText, "dot", + absl::Substitute(R"( + CHECK-COUNT-$2: tt.dot{{.*}}$3{{.*}} : tensor<16x32x$0> * tensor<32x64x$0> -> tensor<16x64x$1> + )", + primitive_util::LowercasePrimitiveTypeName(in_ty), + primitive_util::LowercasePrimitiveTypeName(out_ty), + dot_count_for_algorithm, input_precision_string))); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpecForDotAlgorithm(algorithm))); +} + +INSTANTIATE_TEST_SUITE_P(MultiDotAlgorithmEmitterTestSuite, + MultiDotAlgorithmEmitterTest, + ::testing::ValuesIn(kMultiDotAlgorithms), + DotAlgorithmTestToString); + +// Regroups tests that use TF32 precision by definition. +using TF32DotAlgorithmEmitterTest = TritonEmitterTestWithAlgorithmParam; + +constexpr std::array kTF32DotAlgorithms = { + PrecisionConfig::ALG_DOT_TF32_TF32_F32, + PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3}; + +TEST_P(TF32DotAlgorithmEmitterTest, TF32AlgorithmsUseTF32InputPrecision) { + auto algorithm = GetParam(); + TF_ASSERT_OK_AND_ASSIGN( + std::vector<PrimitiveType> allowed_types, + algorithm_util::GetAllowedOperandsTypeForAlgorithm(algorithm)); + ASSERT_EQ(allowed_types.size(), 1); + PrimitiveType in_ty = allowed_types.front(); + TF_ASSERT_OK_AND_ASSIGN(PrimitiveType out_ty, + algorithm_util::GetDotAccumulatorType(algorithm)); + const std::string kHloText = GetDotAlgorithmHlo(in_ty, out_ty, algorithm); + + std::string input_precision_string = + algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 ? "tf32x3" + : "tf32"; + + TF_EXPECT_OK(CreateTritonIrAndFileCheck( + this, kHloText, "dot", + absl::Substitute(R"( + CHECK: tt.dot{{.*}} inputPrecision = $2 : tensor<16x32x$0> * tensor<32x64x$0> -> tensor<16x64x$1> + )", + primitive_util::LowercasePrimitiveTypeName(in_ty), + primitive_util::LowercasePrimitiveTypeName(out_ty), + input_precision_string))); + // No need to `RunAndCompare` here, these algorithms are already covered by + // other tests. +} + +INSTANTIATE_TEST_SUITE_P(TF32DotAlgorithmEmitterTestSuite, + TF32DotAlgorithmEmitterTest, + ::testing::ValuesIn(kTF32DotAlgorithms), + DotAlgorithmTestToString); + +class DotUnsetAlgorithmEmitterTest + : public TritonEmitterTest, + public ::testing::WithParamInterface<PrimitiveType> {}; + +TEST_P(DotUnsetAlgorithmEmitterTest, UnsetAlgorithmIsEmittedCorrectly) { + // This currently assumes that the dot output type is the same as the input + // type. This is not enforced by the verifier/HLO spec, but is currently true + // for our emitters, and is enforced by `support_test.cc`. This test may + // require upgrading if we ever consider emitting code for truly mixed type + // `dot`s. + PrimitiveType ty = GetParam(); + if (!internal::IsResultTypeSupportedByAlgUnsetDot(ty, + GpuComputeCapability())) { + GTEST_SKIP() << primitive_util::LowercasePrimitiveTypeName(ty) + << " is not supported on this platform."; + } + + ErrorSpec error_spec = ErrorSpecForDotAlgorithm(PrecisionConfig::ALG_UNSET); + // For 8-bit floating point types, we need to allow large errors.
+ if (primitive_util::IsFloatingPointType(ty) && + primitive_util::BitWidth(ty) == 8) { + error_spec = ErrorSpec{/*aabs=*/1e0, /*arel=*/1e-1}; + } + + const std::string kHloText = + GetDotAlgorithmHlo(ty, ty, PrecisionConfig::ALG_UNSET); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, error_spec)); +} + +std::vector<PrimitiveType> AllXlaDataTypesSupportedByAlgUnsetDotLowering() { + // We don't have a pointer to stream executor available here so we can't + // detect the particular device we're running on with a canonical API call. + // Instead, we just return a superset of the supported types (i.e. those that + // are supported on the latest device), and filter out the unsupported types + // in the test body. + std::vector<PrimitiveType> supported_types; + absl::c_copy_if(AllXlaDataTypes(), std::back_inserter(supported_types), + [](PrimitiveType type) { + return internal::IsResultTypeSupportedByAlgUnsetDot( + type, se::CudaComputeCapability::Blackwell()); + }); + return supported_types; +} + +INSTANTIATE_TEST_SUITE_P( + DotUnsetAlgorithmEmitterTestSuite, DotUnsetAlgorithmEmitterTest, + ::testing::ValuesIn(AllXlaDataTypesSupportedByAlgUnsetDotLowering()), + TypeTestParamToString); + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index c1a9a32e39a760..2126a3d8dcea64 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" #include "xla/primitive_util.h" +#include "xla/service/algorithm_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_indexing_utils.h" @@ -316,6 +317,91 @@ absl::Status CheckSupportedCheckDotDimensions(const HloDotInstruction& dot) { return absl::OkStatus(); } +bool IsSupportedDotAlgorithm(PrecisionConfig::Algorithm algorithm) { + switch (algorithm) { + case PrecisionConfig::ALG_UNSET: + case PrecisionConfig::ALG_DOT_F16_F16_F16: + case PrecisionConfig::ALG_DOT_F32_F32_F32: + case PrecisionConfig::ALG_DOT_F64_F64_F64: + case PrecisionConfig::ALG_DOT_F16_F16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9: + return true; + case PrecisionConfig::ALG_DOT_BF16_BF16_BF16: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: + default: + break; + } + + return false; +} + +// Checks whether the conversions generated during the lowering of the relevant +// dot algorithm for the relevant input and output types are supported by +// Triton. +// +// When the algorithm is `ALG_UNSET`, nothing is checked.
+CodegenDecision AreDotAlgorithmInputAndOutputConversionsSupported( + PrecisionConfig::Algorithm algorithm, PrimitiveType lhs_type, + PrimitiveType rhs_type, PrimitiveType result_type, + const se::GpuComputeCapability& gpu_version) { + if (algorithm == PrecisionConfig::ALG_UNSET) { + return CodegenDecision::Allow(); + } + + auto forbid = [&algorithm](absl::string_view message) { + return CodegenDecision::Forbid( + absl::StrCat(message, " for dot algorithm ", + PrecisionConfig::Algorithm_Name(algorithm))); + }; + + absl::StatusOr<std::vector<PrimitiveType>> allowed_operands_types_or = + algorithm_util::GetAllowedOperandsTypeForAlgorithm(algorithm); + absl::StatusOr<PrimitiveType> expected_accumulator_type = + algorithm_util::GetDotAccumulatorType(algorithm); + if (!allowed_operands_types_or.ok() || !expected_accumulator_type.ok()) { + return forbid("Failed to recover operands types or accumulator type"); + } + CHECK(!allowed_operands_types_or->empty()); + + if (result_type != *expected_accumulator_type) { + if (!IsTritonSupportedConversion(*expected_accumulator_type, result_type, + gpu_version) || + !IsTritonSupportedConversion(result_type, *expected_accumulator_type, + gpu_version)) { + return forbid("Unsupported result conversion"); + } + } + + if (allowed_operands_types_or->size() != 1 && + (lhs_type != rhs_type || + !absl::c_linear_search(*allowed_operands_types_or, lhs_type))) { + return forbid("Unsupported operand types"); + } else if (allowed_operands_types_or->size() == 1) { + return CodegenDecision::Allow(); + } + + PrimitiveType expected_operands_type = allowed_operands_types_or->front(); + + if (lhs_type != expected_operands_type && + !IsTritonSupportedConversion(expected_operands_type, lhs_type, + gpu_version)) { + return forbid("Unsupported lhs conversion"); + } + if (rhs_type != expected_operands_type && + !IsTritonSupportedConversion(expected_operands_type, rhs_type, + gpu_version)) { + return forbid("Unsupported rhs conversion"); + } + + return CodegenDecision::Allow(); +} + CodegenDecision IsTritonSupportedDot( const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) { if (!IsInTritonNestedGemmFusion(dot)) { @@ -334,19 +420,11 @@ CodegenDecision IsTritonSupportedDot( "Only operands that are fusions are supported."); } + // TODO(b/393299275): add support tests for mixed types.
if (result_type != lhs_type || result_type != rhs_type) { return CodegenDecision::Forbid( "Dot operation only supports same types for the result, lhs and rhs."); } - if (absl::c_linear_search( - std::vector<PrimitiveType>{PrimitiveType::F8E5M2, PrimitiveType::BF16, - PrimitiveType::F8E4M3FN, - PrimitiveType::S32, PrimitiveType::S64, - PrimitiveType::S16, PrimitiveType::S8}, - result_type)) { - return CodegenDecision::Forbid( - absl::StrCat(PrimitiveType_Name(result_type), " is not supported")); - } absl::Status status = CheckSupportedCheckDotDimensions(dot); if (!status.ok()) { @@ -354,15 +432,25 @@ CodegenDecision IsTritonSupportedDot( const PrecisionConfig& precision_config = dot.precision_config(); - if (precision_config.algorithm() != PrecisionConfig::ALG_UNSET || - absl::c_any_of(precision_config.operand_precision(), - [](const int precision) { - return precision != PrecisionConfig::DEFAULT; - })) { - LOG(INFO) << "Unsupported precision config: " - << precision_config.ShortDebugString(); - return CodegenDecision::Forbid(absl::StrCat( - "Unsupported precision config: ", precision_config.ShortDebugString())); + const PrecisionConfig::Algorithm algorithm = precision_config.algorithm(); + + if (!IsSupportedDotAlgorithm(algorithm)) { + return CodegenDecision::Forbid( + absl::StrCat("Unsupported dot algorithm: ", + PrecisionConfig::Algorithm_Name(algorithm))); + } + + if (algorithm == PrecisionConfig::ALG_UNSET && + !internal::IsResultTypeSupportedByAlgUnsetDot(result_type, gpu_version)) { + return CodegenDecision::Forbid( + "Unsupported result type for dot algorithm ALG_UNSET."); + } + + if (CodegenDecision conversion_decision = + AreDotAlgorithmInputAndOutputConversionsSupported( + algorithm, lhs_type, rhs_type, result_type, gpu_version); + !conversion_decision) { + return conversion_decision; } return CodegenDecision::Allow(); @@ -565,6 +653,20 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { return false; } } + +bool IsResultTypeSupportedByAlgUnsetDot( + PrimitiveType result_type, const se::GpuComputeCapability& gpu_version) { + std::vector<PrimitiveType> supported_types = {BF16, F16, F32, F64, F8E5M2}; + + if (auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version)) { + if (cuda_cc->IsAtLeastHopper()) { + supported_types.push_back(F8E4M3FN); + } + } + + return absl::c_linear_search(supported_types, result_type); +} + } // namespace internal absl::Status EnsureTritonSupportsComputeCapability( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.h b/third_party/xla/xla/backends/gpu/codegen/triton/support.h index de2c15c6c47011..47f2f02c6f1ead 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.h @@ -71,6 +71,11 @@ namespace internal { // HLOs. This is exposed for testing purposes only and will be removed in the // near future. Do not use. This function only returns a partial result. bool IsTritonUnsupportedOpcode(HloOpcode opcode); + +// This is exposed for testing purposes only. Do not use.
+bool IsResultTypeSupportedByAlgUnsetDot( + PrimitiveType result_type, const se::GpuComputeCapability& gpu_version); + } // namespace internal } // namespace gpu diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index 39a874ac409057..ca22498aecae07 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -56,23 +56,6 @@ namespace { using ::testing::Not; using ::tsl::testing::IsOk; -std::vector<PrimitiveType> AllXlaDataTypes() { - std::vector<PrimitiveType> xla_data_types; - std::vector<PrimitiveType> to_filter_out = {PRIMITIVE_TYPE_INVALID, - TUPLE, OPAQUE_TYPE, TOKEN}; - const tsl::protobuf::EnumDescriptor* xla_type_descriptor = - tsl::protobuf::GetEnumDescriptor<PrimitiveType>(); - for (int enum_ix = 0; enum_ix < xla_type_descriptor->value_count(); - ++enum_ix) { - xla::PrimitiveType xla_type = static_cast<PrimitiveType>( - xla_type_descriptor->value(enum_ix)->number()); - if (!absl::c_linear_search(to_filter_out, xla_type)) { - xla_data_types.push_back(xla_type); - } - } - return xla_data_types; -} - // Returns true if the given `opcode` supports the given `type` with respect to // HLO semantics. This is completely independent of what Triton supports or // what the hardware supports. @@ -203,7 +186,10 @@ auto AllTestCombinationsForOpcodes(absl::Span<const HloOpcode> opcodes) { // Expected failure mode of the Triton lowering. enum class ExpectedFailMode { + // Denotes a graceful failure, e.g. a verifier failure, or an absl::Status. kFail, + // Denotes a crash. That is typically the case when encountering a bug in + // the Triton compiler itself. kCrash, // Use only in cases when the failure mode depends on the compilation mode // (i.e. when the failure is caused by a CHECK). @@ -2170,9 +2156,7 @@ ENTRY triton_computation { PrecisionToString(lhs_precision), PrecisionToString(rhs_precision)); ExpectedFailMode fail_mode = ExpectedFailMode::kFail; - if (absl::c_linear_search(std::vector<PrimitiveType>{F8E5M2, F8E4M3FN, S8}, data_type) && - lhs_precision == PrecisionConfig::DEFAULT && - rhs_precision == PrecisionConfig::DEFAULT) { + if (absl::c_linear_search(std::vector<PrimitiveType>{F8E5M2, F8E4M3FN, S8}, data_type)) { fail_mode = ExpectedFailMode::kFailOrCrash; } TF_ASSERT_OK_AND_ASSIGN( @@ -2258,8 +2242,8 @@ ENTRY triton_computation { /* use_nested_gemm_fusions=*/true)); ExpectedFailMode fail_mode = ExpectedFailMode::kFail; - if (absl::c_linear_search(std::vector<PrimitiveType>{F8E5M2, F8E4M3FN, S8}, data_type) && - algorithm == PrecisionConfig::ALG_UNSET) { + if (absl::c_linear_search(std::vector<PrimitiveType>{F8E5M2, F8E4M3FN, F8E4M3, S8}, + data_type)) { fail_mode = ExpectedFailMode::kFailOrCrash; } RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc, fail_mode); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc index c09f96a58a053d..615bb56f5adbf1 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -51,13 +52,32 @@ limitations under the License.
#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/status_macros.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "tsl/platform/protobuf.h" namespace xla::gpu { +std::vector AllXlaDataTypes() { + std::vector xla_data_types; + std::vector to_filter_out = {PRIMITIVE_TYPE_INVALID, + TUPLE, OPAQUE_TYPE, TOKEN}; + const tsl::protobuf::EnumDescriptor* xla_type_descriptor = + tsl::protobuf::GetEnumDescriptor(); + for (int enum_ix = 0; enum_ix < xla_type_descriptor->value_count(); + ++enum_ix) { + xla::PrimitiveType xla_type = static_cast( + xla_type_descriptor->value(enum_ix)->number()); + if (!absl::c_linear_search(to_filter_out, xla_type)) { + xla_data_types.push_back(xla_type); + } + } + return xla_data_types; +} + bool SupportsBF16(const stream_executor::GpuComputeCapability& cc) { if (std::holds_alternative(cc)) { return std::get(cc).IsAtLeast( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h index bbd640511ba3dc..899483fd942f8e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h @@ -42,6 +42,8 @@ limitations under the License. namespace xla::gpu { +std::vector AllXlaDataTypes(); + bool SupportsBF16(const stream_executor::GpuComputeCapability& cc); std::string ComputeCapabilityToString( From 5192406dd8d8cf9776cc5d8e1de45fb03b5fe3d6 Mon Sep 17 00:00:00 2001 From: Mikhail Goncharov Date: Wed, 2 Apr 2025 09:46:01 -0700 Subject: [PATCH 0159/1324] [XLA:GPU] run nest_gemm_fusion pass in gemm fusion autotuner gemm_fusion_autotuner runs it's own pipeline that should also include the new pass PiperOrigin-RevId: 743174019 --- third_party/xla/xla/service/gpu/autotuning/BUILD | 1 + .../xla/service/gpu/autotuning/gemm_fusion_autotuner.cc | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 400df6ac61bf25..c778f212a0a794 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -156,6 +156,7 @@ cc_library( "//xla/service/gpu/transforms:dot_algorithm_rewriter", "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_rewriter", + "//xla/service/gpu/transforms:nest_gemm_fusion", "//xla/service/gpu/transforms:priority_fusion", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 2b75d3fe4bb4a7..6370fae5097a9a 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -79,6 +79,7 @@ limitations under the License. 
#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" +#include "xla/service/gpu/transforms/nest_gemm_fusion.h" #include "xla/service/gpu/transforms/priority_fusion.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_graph_dumper.h" @@ -330,6 +331,12 @@ absl::StatusOr> TritonGemmAutotuneExtractor( *backend_config.mutable_triton_gemm_config() = config.ToProto(); TF_RETURN_IF_ERROR(cloned_dot_fusion->set_backend_config(gpu_config)); + if (debug_opts + .xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms()) { + NestGemmFusion nest_gemm_fusion(gpu_device_info.gpu_compute_capability()); + TF_RETURN_IF_ERROR(nest_gemm_fusion.Run(new_module.get()).status()); + } + if (config.split_k > 1) { TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config)); for (PrimitiveType type : From b8fce922561bde751cd9a39ffad356f6959ca771 Mon Sep 17 00:00:00 2001 From: Ranko Sredojevic Date: Wed, 2 Apr 2025 10:27:45 -0700 Subject: [PATCH 0160/1324] Provide sugared versions of ForEach for HloComputations. PiperOrigin-RevId: 743190903 --- .../xla/xla/hlo/utils/concurrency/BUILD | 9 ++ .../hlo/utils/concurrency/concurrency_utils.h | 81 ++++++++++++ .../concurrency/concurrency_utils_test.cc | 120 ++++++++++++++++++ 3 files changed, 210 insertions(+) diff --git a/third_party/xla/xla/hlo/utils/concurrency/BUILD b/third_party/xla/xla/hlo/utils/concurrency/BUILD index 3e12edf86fda32..ca54aee41d4197 100644 --- a/third_party/xla/xla/hlo/utils/concurrency/BUILD +++ b/third_party/xla/xla/hlo/utils/concurrency/BUILD @@ -45,10 +45,13 @@ cc_library( # copybara:uncomment compatible_with = get_compatible_with_portable(), deps = [ ":tsl_task_executor", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) @@ -85,9 +88,15 @@ xla_cc_test( deps = [ ":concurrency_utils", ":tsl_task_executor", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:hlo_module_config", + "//xla/tests:hlo_test_base", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) diff --git a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h index 1adc1b62869074..d7abb443c87a0b 100644 --- a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h +++ b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils.h @@ -16,15 +16,21 @@ limitations under the License. 
#ifndef XLA_HLO_UTILS_CONCURRENCY_CONCURRENCY_UTILS_H_ #define XLA_HLO_UTILS_CONCURRENCY_CONCURRENCY_UTILS_H_ +#include #include #include #include #include #include +#include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/concurrency/tsl_task_executor.h" namespace xla::concurrency { @@ -126,6 +132,81 @@ absl::Status ForEach(ForwardItT begin, ForwardItT end, return task_executor.ExecuteIndependentTasks(std::move(tasks), parallelism); } +// Specializes `ForEach` for an iterator of `xla::HloComputation` and provides a +// parameter to use when combining return values from individual actions. +template <typename CombinedResultT, typename ResultT, typename ForwardItT, typename TaskExecutorT> +absl::StatusOr<CombinedResultT> ForEachHloComputation( + ForwardItT begin, ForwardItT end, + absl::AnyInvocable<absl::StatusOr<ResultT>(HloComputation*)> action, + absl::AnyInvocable< + absl::StatusOr<CombinedResultT>(std::vector<ResultT>&)> + combiner, + TaskExecutorT& task_executor, + std::optional<int> parallelism = std::nullopt) { + auto result_for_each = + ForEach(begin, end, std::move(action), task_executor, parallelism); + if (!result_for_each.ok()) { + return result_for_each.status(); + } + + return combiner(*result_for_each); +} + +// Specializes `ForEach` for a span of `xla::HloComputation` and provides a +// parameter to use when combining return values from individual actions. +template <typename CombinedResultT, typename ResultT, typename TaskExecutorT> +absl::StatusOr<CombinedResultT> ForEachHloComputation( + absl::Span<HloComputation* const> computations, + absl::AnyInvocable<absl::StatusOr<ResultT>(HloComputation*)> action, + absl::AnyInvocable< + absl::StatusOr<CombinedResultT>(std::vector<ResultT>&)> + combiner, + TaskExecutorT& task_executor, + std::optional<int> parallelism = std::nullopt) { + return ForEachHloComputation(computations.begin(), computations.end(), + std::move(action), std::move(combiner), + task_executor, parallelism); +} + +// Specializes `ForEachHloComputation` to take an `xla::HloModule` and run on +// all computations in it. +template <typename CombinedResultT, typename ResultT, typename TaskExecutorT> +absl::StatusOr<CombinedResultT> ForEachHloComputation( + HloModule* module, + absl::AnyInvocable<absl::StatusOr<ResultT>(HloComputation*)> action, + absl::AnyInvocable< + absl::StatusOr<CombinedResultT>(std::vector<ResultT>&)> + combiner, + TaskExecutorT& task_executor, + std::optional<int> parallelism = std::nullopt) { + // The returned type is not a `forward_iterator` so we create one. + auto it = module->computations(); + std::vector<HloComputation*> computations{it.begin(), it.end()}; + return ForEachHloComputation(computations, std::move(action), + std::move(combiner), task_executor, parallelism); +} + +// Specializes `ForEachHloComputation` to take an `xla::HloModule` and run on +// all non-fusion computations in it.
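+//
+// A minimal usage sketch (hypothetical; `executor` is assumed to be a
+// `TslTaskExecutor` and the combined result counts the visited computations):
+//
+//   auto count = ForEachNonfusionHloComputation<int, bool>(
+//       module, /*execution_threads=*/{},
+//       [](HloComputation* c) -> absl::StatusOr<bool> { return true; },
+//       [](std::vector<bool>& results) {
+//         return static_cast<int>(results.size());
+//       },
+//       executor);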
+template <typename CombinedResultT, typename ResultT, typename TaskExecutorT> +absl::StatusOr<CombinedResultT> ForEachNonfusionHloComputation( + HloModule* module, + const absl::flat_hash_set<absl::string_view>& execution_threads, + absl::AnyInvocable<absl::StatusOr<ResultT>(HloComputation*)> action, + absl::AnyInvocable< + absl::StatusOr<CombinedResultT>(std::vector<ResultT>&)> + combiner, + TaskExecutorT& task_executor, + std::optional<int> parallelism = std::nullopt) { + auto computations = module->MakeNonfusionComputations(execution_threads); + return ForEachHloComputation(computations, std::move(action), + std::move(combiner), task_executor, parallelism); +} + } // namespace xla::concurrency #endif // XLA_HLO_UTILS_CONCURRENCY_CONCURRENCY_UTILS_H_ diff --git a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc index 6a36b29b0a180e..242e17c79521d2 100644 --- a/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc +++ b/third_party/xla/xla/hlo/utils/concurrency/concurrency_utils_test.cc @@ -15,13 +15,24 @@ limitations under the License. #include "xla/hlo/utils/concurrency/concurrency_utils.h" +#include +#include #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/concurrency/tsl_task_executor.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tsl/platform/statusor.h" namespace xla::concurrency { @@ -29,6 +40,11 @@ namespace { using ::testing::ElementsAreArray; +template <typename T> +struct WrappedT { + T val; +}; + TEST(ForEachTest, IterVariantConcurrentlyIncrementsIntegers) { TslTaskExecutor task_executor(5); @@ -131,5 +147,109 @@ TEST(ForEachTest, FailureOfTheFirstActionPropagates) { absl::StatusCode::kCancelled); } +class HloComputationTest : public HloHardwareIndependentTestBase { + protected: + HloComputationTest() = default; + + // Create a computation which takes a scalar and returns its negation. + std::unique_ptr<HloComputation> CreateNegateComputation( + absl::string_view name = "Negate") { + auto builder = HloComputation::Builder(name); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); + return builder.Build(); + } + + std::unique_ptr<HloModule> CreateNegateModule() { + auto module = + std::make_unique<HloModule>("NegateModule", HloModuleConfig{}); + module->AddComputation(CreateNegateComputation("Negate0"), true); + module->AddComputation(CreateNegateComputation("Negate1"), false); + module->AddComputation(CreateNegateComputation("Negate2"), false); + + return module; + }; + + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + + TslTaskExecutor task_executor_{5}; +}; + +TEST_F(HloComputationTest, ForEachHloComputationBasicCall) { + auto comp0 = CreateNegateComputation(); + auto comp1 = CreateNegateComputation(); + auto comp2 = CreateNegateComputation(); + + std::vector<HloComputation*> v = {comp0.get(), comp1.get(), comp2.get()}; + + auto result = ForEachHloComputation<bool, WrappedT<bool>>( + v.begin(), v.end(), + [](HloComputation* comp) -> absl::StatusOr<WrappedT<bool>> { + return WrappedT<bool>{true}; + }, + [](std::vector<WrappedT<bool>>& results) { + return std::any_of(results.begin(), results.end(), + [](WrappedT<bool> b) { return b.val; }); + }, + task_executor_); + // For compatibility with OpenXLA.
+ ASSERT_EQ(result.status(), absl::OkStatus()); + EXPECT_EQ(*result, true); +} + +TEST_F(HloComputationTest, ForEachHloComputationSpanBasicCall) { + auto comp0 = CreateNegateComputation(); + auto comp1 = CreateNegateComputation(); + auto comp2 = CreateNegateComputation(); + + std::vector<HloComputation*> v = {comp0.get(), comp1.get(), comp2.get()}; + + auto result = ForEachHloComputation<bool, WrappedT<bool>>( + v, + [](HloComputation* comp) -> absl::StatusOr<WrappedT<bool>> { + return WrappedT<bool>{true}; + }, + [](std::vector<WrappedT<bool>>& results) { + return std::any_of(results.begin(), results.end(), + [](WrappedT<bool> b) { return b.val; }); + }, + task_executor_); + // For compatibility with OpenXLA. + ASSERT_EQ(result.status(), absl::OkStatus()); + EXPECT_EQ(*result, true); +} + +TEST_F(HloComputationTest, ForEachHloComputationModuleBasicCall) { + auto module = CreateNegateModule(); + + auto result = ForEachHloComputation<int, WrappedT<bool>>( + module.get(), + [](HloComputation* comp) -> absl::StatusOr<WrappedT<bool>> { + return WrappedT<bool>{true}; + }, + [](std::vector<WrappedT<bool>>& results) { return results.size(); }, + task_executor_); + // For compatibility with OpenXLA. + ASSERT_EQ(result.status(), absl::OkStatus()); + EXPECT_EQ(*result, 3); +} + +TEST_F(HloComputationTest, ForEachNonfusionHloComputationModuleBasicCall) { + auto module = CreateNegateModule(); + + auto result = ForEachNonfusionHloComputation<int, WrappedT<bool>>( + module.get(), {}, + [](HloComputation* comp) -> absl::StatusOr<WrappedT<bool>> { + return WrappedT<bool>{true}; + }, + [](std::vector<WrappedT<bool>>& results) { return results.size(); }, + task_executor_); + // For compatibility with OpenXLA. + ASSERT_EQ(result.status(), absl::OkStatus()); + EXPECT_EQ(*result, 3); +} + } // namespace } // namespace xla::concurrency From aa72ff64492969ce050543cf6fcb4fb18453585c Mon Sep 17 00:00:00 2001 From: Weiyi Wang Date: Wed, 2 Apr 2025 10:44:21 -0700 Subject: [PATCH 0161/1324] TFL_DynamicUpdateSliceOp to support int16 operands.
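After this change, the TFLite converter and kernel accept int16 operands for dynamic update slice. For example, the converter now legalizes the following op (a sketch mirroring the new legalize-tf.mlir test below; the SSA value names are illustrative):

    %0 = "tfl.dynamic_update_slice"(%operand, %update, %indices)
        : (tensor<4x5xi16>, tensor<1x5xi16>, tensor<2xi32>) -> tensor<4x5xi16>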
PiperOrigin-RevId: 743197416 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 6 +++--- .../compiler/mlir/lite/tests/legalize-tf.mlir | 8 ++++++++ tensorflow/lite/core/kernels/register.cc | 2 +- tensorflow/lite/kernels/dynamic_update_slice.cc | 4 ++++ .../lite/kernels/dynamic_update_slice_test.cc | 15 +++++++++++++++ tensorflow/lite/tools/versioning/op_version.cc | 4 +++- .../lite/tools/versioning/op_version_test.cc | 4 ++++ .../lite/tools/versioning/runtime_version.cc | 1 + 8 files changed, 39 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 986e02fe8e335a..8ffd327e94583b 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -4153,13 +4153,13 @@ def TFL_DynamicUpdateSliceOp: TFL_Op<"dynamic_update_slice", [ }]; let arguments = (ins - TFL_TensorOf<[I1, I8, I32, I64, F32, F16]>:$operand, - TFL_TensorOf<[I1, I8, I32, I64, F32, F16]>:$update, + TFL_TensorOf<[I1, I8, I16, I32, I64, F32, F16]>:$operand, + TFL_TensorOf<[I1, I8, I16, I32, I64, F32, F16]>:$update, TFL_I32OrI64Tensor:$start_indices ); let results = ( - outs TFL_TensorOf<[I1, I8, I32, I64, F32, F16]>:$output); + outs TFL_TensorOf<[I1, I8, I16, I32, I64, F32, F16]>:$output); let hasFolder = 1; } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index c0978d484ee11e..c3dc00ca74f1ae 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -2589,6 +2589,14 @@ func.func @dynamic_update_slice_f16_arg(%arg0: tensor<4x5xf16>, %arg1: tensor<1x // CHECK: "tfl.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<4x5xf16>, tensor<1x5xf16>, tensor<2xi32>) -> tensor<4x5xf16> } +func.func @dynamic_update_slice_i16(%arg0: tensor<4x5xi16>, %arg1: tensor<1x5xi16>, %arg2: tensor<2xi32>) -> tensor<4x5xi16> { + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x5xi16>, tensor<1x5xi16>, tensor<2xi32>) -> tensor<4x5xi16> + func.return %0 : tensor<4x5xi16> + +// CHECK-LABEL:dynamic_update_slice_i16 +// CHECK: "tfl.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<4x5xi16>, tensor<1x5xi16>, tensor<2xi32>) -> tensor<4x5xi16> +} + func.func @testReluI32(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> func.return %0: tensor<1xi32> diff --git a/tensorflow/lite/core/kernels/register.cc b/tensorflow/lite/core/kernels/register.cc index 216f1dece7ef8e..0c331b98ead1f9 100644 --- a/tensorflow/lite/core/kernels/register.cc +++ b/tensorflow/lite/core/kernels/register.cc @@ -354,7 +354,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_DYNAMIC_UPDATE_SLICE, Register_DYNAMIC_UPDATE_SLICE(), /* min_version = */ 1, - /* max_version = */ 3); + /* max_version = */ 4); AddBuiltin(BuiltinOperator_UNSORTED_SEGMENT_PROD, Register_UNSORTED_SEGMENT_PROD()); AddBuiltin(BuiltinOperator_UNSORTED_SEGMENT_MAX, diff --git a/tensorflow/lite/kernels/dynamic_update_slice.cc b/tensorflow/lite/kernels/dynamic_update_slice.cc index 776379058cc1ff..5c5cbcd8f963b4 100644 --- a/tensorflow/lite/kernels/dynamic_update_slice.cc +++ b/tensorflow/lite/kernels/dynamic_update_slice.cc @@ -219,6 +219,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { DynamicUpdateSlice(operand, update, indices_data_i64.data(), output); break; + case kTfLiteInt16: + DynamicUpdateSlice(operand, update, indices_data_i64.data(), + 
output); + break; case kTfLiteInt32: DynamicUpdateSlice(operand, update, indices_data_i64.data(), output); diff --git a/tensorflow/lite/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/kernels/dynamic_update_slice_test.cc index 867f2b9b8cc029..373a719d5ac412 100644 --- a/tensorflow/lite/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/kernels/dynamic_update_slice_test.cc @@ -177,6 +177,21 @@ TEST(DynamicUpdateSliceOpTest, SimpleTestI8) { 7, -2, 9})); } +TEST(DynamicUpdateSliceOpTest, SimpleTestI16) { + DynamicUpdateSliceOpModel m({TensorType_INT16, {3, 3}}, + {TensorType_INT16, {2, 1}}, + {TensorType_INT32, {2}}); + m.SetInput({1, 2, 3, // + 4, 5, 6, // + 7, 8, 9}); + m.SetUpdate({-1, -2}); + m.SetStartIndices({1, 1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, // + 4, -1, 6, // + 7, -2, 9})); +} + TEST(DynamicUpdateSliceOpTest, SimpleTestI32) { DynamicUpdateSliceOpModel m({TensorType_INT32, {3, 3}}, {TensorType_INT32, {2, 1}}, diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 7d4c3c6ca25ec7..82fb3974639c81 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -1048,7 +1048,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; case BuiltinOperator_DYNAMIC_UPDATE_SLICE: - if (op_sig.inputs.at(0).type == kTfLiteFloat16) { + if (op_sig.inputs.at(0).type == kTfLiteInt16) { + return 4; + } else if (op_sig.inputs.at(0).type == kTfLiteFloat16) { return 3; } else if (op_sig.inputs.at(2).type == kTfLiteInt64) { return 2; diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index a0f76a32780220..7afd313e0e339c 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -1433,5 +1433,9 @@ TEST(OpVersionTest, VersioningDynamicUpdateSliceTest) { fake_op_sig.inputs = CreateOpSignatureTensorSpecs( std::vector{kTfLiteFloat16, kTfLiteFloat16, kTfLiteInt32}); EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt16, kTfLiteInt16, kTfLiteInt32}); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); } } // namespace tflite diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index acfcd2754dccb6..3c121a92d47a97 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -432,6 +432,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 1}, "2.9.0"}, {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 2}, "2.17.0"}, {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 3}, "2.19.0"}, + {{BuiltinOperator_DYNAMIC_UPDATE_SLICE, 4}, "2.20.0"}, {{BuiltinOperator_UNSORTED_SEGMENT_PROD, 1}, "2.10.0"}, {{BuiltinOperator_UNSORTED_SEGMENT_MAX, 1}, "2.10.0"}, {{BuiltinOperator_UNSORTED_SEGMENT_MIN, 1}, "2.11.0"}, From 60f3ace93ab3ef1cc2474a2e87cd799a6937befc Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 2 Apr 2025 10:50:56 -0700 Subject: [PATCH 0162/1324] [XLA:GPU] Handle transposed `dot`s and `dot`s with more than two dimensions in the generic Triton emitter. We still have a restriction that exactly two dimensions need to be tiled with a non-unit size. 
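
To make the tiling requirement concrete before diving into the diff: after
dropping unit dimensions from a dot operand's tile sizes, exactly two entries
must remain. A minimal standalone sketch of that collapsing step (plain
standard-library types here; the patch itself implements the equivalent helper
over llvm::ArrayRef/SmallVector):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Keep only the non-unit tile sizes and remember their original indices,
    // e.g. [1, 32, 1, 64] -> sizes [32, 64] at indices [1, 3].
    std::pair<std::vector<int64_t>, std::vector<int64_t>> CollapseUnitDims(
        const std::vector<int64_t>& shape) {
      std::vector<int64_t> sizes, indices;
      for (int64_t i = 0; i < static_cast<int64_t>(shape.size()); ++i) {
        if (shape[i] != 1) {
          sizes.push_back(shape[i]);
          indices.push_back(i);
        }
      }
      return {std::move(sizes), std::move(indices)};
    }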
This restriction will be enforced by `SymbolicTileAnalysis` in a future change. PiperOrigin-RevId: 743199898 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 1 + .../gpu/codegen/triton/fusion_emitter.cc | 99 ++++++++++- .../triton/fusion_emitter_device_test.cc | 155 ++++++++++++++++++ .../backends/gpu/codegen/triton/support.cc | 31 ---- .../gpu/codegen/triton/support_test.cc | 12 +- .../gpu/transforms/nest_gemm_fusion_test.cc | 12 +- 6 files changed, 258 insertions(+), 52 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 84a4652b79cf1d..cdd8ec6b54dfe4 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -200,6 +200,7 @@ cc_library( "//xla/codegen/emitters/ir:xla", "//xla/codegen/emitters/transforms:passes", "//xla/hlo/analysis:indexing_analysis", + "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", "//xla/hlo/utils:hlo_traversal", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 503adba0bc4fd2..4b539922d183cc 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -99,6 +100,7 @@ limitations under the License. #include "xla/codegen/emitters/transforms/passes.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -474,7 +476,7 @@ ScalarOrTensor ReshapeTensorToScalar(EmitterLocOpBuilder& b, Value input) { if (mlir::cast(input.getType()).getRank() > 1) { Type output_tensor_type = mlir::RankedTensorType::get({1}, element_type); single_dim_tensor = b.create(output_tensor_type, input, - /*allow_reorder*/ true); + /*allow_reorder=*/true); } // Second, reduce to a scalar. @@ -515,7 +517,6 @@ absl::StatusOr EmitTiledReshape(EmitterLocOpBuilder& b, } // At this point we know that the input is a non-0D tensor. - auto input_shaped_type = mlir::cast(input.getType()); // Handle the case of reshaping [1,1,1...] to a scalar. @@ -524,7 +525,6 @@ absl::StatusOr EmitTiledReshape(EmitterLocOpBuilder& b, } // At this point we know that neither the input nor the output are 0D tensors. - Type output_tensor_type = mlir::RankedTensorType::get( padded_tile_sizes, input_shaped_type.getElementType()); @@ -731,6 +731,65 @@ absl::StatusOr> ComputeBasePtrOffset( /*symbols=*/{}, b); } +// Returns `shape` without all its unit dimensions, as well as the index of the +// remaining dimensions in the original `shape`. +std::pair, SmallVector> CollapseUnitDims( + llvm::ArrayRef shape) { + SmallVector shape_without_unit_dims; + SmallVector non_unit_dims_indices; + for (auto [i, size] : llvm::enumerate(shape)) { + if (size != 1) { + shape_without_unit_dims.push_back(size); + non_unit_dims_indices.push_back(i); + } + } + return {std::move(shape_without_unit_dims), std::move(non_unit_dims_indices)}; +} + +enum class DotOperandSide { kLhs, kRhs }; + +// Canonicalizes the given operand of a dot operation, i.e. 
make it a 2D tensor, +// and make sure that the contracting dimension is where we expect it to be for +// the given side (the second dimension for LHS, the first dimension for the +// RHS). +// +// Returns an error if canonicalization is not possible. +absl::StatusOr CanonicalizeDotOperand(EmitterLocOpBuilder& b, + Value operand, + int64_t contracting_dim_idx, + DotOperandSide side) { + llvm::ArrayRef shape = + mlir::cast(operand.getType()).getShape(); + auto [shape_without_unit_dims, non_unit_dims_indices] = + CollapseUnitDims(shape); + + if (shape_without_unit_dims.size() != 2) { + return absl::FailedPreconditionError( + "Expected dot operand tile to have exactly two non-unit tile sizes"); + } + + if (shape.size() != shape_without_unit_dims.size()) { + TF_ASSIGN_OR_RETURN( + ScalarOrTensor wrapped_operand, + EmitTiledReshape(b, shape_without_unit_dims, ScalarOrTensor(operand))); + operand = wrapped_operand.UnwrapTensor(); + } + + int expected_contracting_dim_position = side == DotOperandSide::kLhs ? 1 : 0; + bool is_transposed = + non_unit_dims_indices[expected_contracting_dim_position] != + contracting_dim_idx; + + if (is_transposed) { + SmallVector transposed_shape{shape_without_unit_dims[1], + shape_without_unit_dims[0]}; + operand = + EmitTiledTranspose(b, transposed_shape, /*dimensions=*/{1, 0}, operand); + } + + return operand; +} + absl::StatusOr EmitDot(EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, @@ -779,13 +838,29 @@ absl::StatusOr EmitDot(EmitterLocOpBuilder& b, return absl::FailedPreconditionError("Expected dot operands to be fusions"); } + SmallVector padded_tile_sizes = + GetPaddedTileSizes(tiled_hlo_dot.tile_sizes()); + + SmallVector padded_tile_sizes_no_unit_dims = + CollapseUnitDims(padded_tile_sizes).first; + + // Sanity check: Triton historically did not support non-2D dots (and still + // doesn't support arbitrary nD dots), so we require that the dot is tiled + // with exactly two non-unit tile sizes. This anyway matches the hardware's + // expectations, so seems like a reasonable requirement. + // TODO(b/393299275): this needs to be enforced in tiling. + if (padded_tile_sizes_no_unit_dims.size() != 2) { + return absl::FailedPreconditionError( + "Expected dot to be tiled with exactly two non-unit tile sizes"); + } + // The specific accumulator type to use may not correspond to the output type // of the dot. In particular, that is the case when an algorithm is specified // and the dot's output type does not match its expectations. TF_ASSIGN_OR_RETURN(Type accumulator_type, triton::GetDotAccumulatorType(b, dot)); Value accumulator = - CreateConst(b, accumulator_type, 0.0f, tiled_hlo_dot.tile_sizes()) + CreateConst(b, accumulator_type, 0.0f, padded_tile_sizes_no_unit_dims) .UnwrapTensor(); auto ci64 = [&](int64_t value) -> Value { @@ -843,6 +918,15 @@ absl::StatusOr EmitDot(EmitterLocOpBuilder& b, Value rhs, MaskDotOperand(b, *tiled_hlo_dot.operand(1), dot_args[1], ki_i32, rhs_contracting_dim_idx)); + // Canonicalize the dot operands to match Triton's/the hardware's + // expectations. 
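+  // Example: for an RHS tile of [1, 32, 1, 64] with contracting dimension 1,
+  // CollapseUnitDims keeps [32, 64] at original indices [1, 3]; index 1 is
+  // already in the leading position expected for the RHS, so only a reshape
+  // (and no transpose) is emitted for that operand.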
+ TF_ASSIGN_OR_RETURN(lhs, + CanonicalizeDotOperand(b, lhs, lhs_contracting_dim_idx, + DotOperandSide::kLhs)); + TF_ASSIGN_OR_RETURN(rhs, + CanonicalizeDotOperand(b, rhs, rhs_contracting_dim_idx, + DotOperandSide::kRhs)); + TF_ASSIGN_OR_RETURN( Value acc_next, triton::EmitSingleTileDot(b, dot, triton::DotOperands{lhs, rhs, acc})); @@ -859,6 +943,13 @@ absl::StatusOr EmitDot(EmitterLocOpBuilder& b, result = Cast(b, result, dot_output_type); } + if (padded_tile_sizes.size() != padded_tile_sizes_no_unit_dims.size()) { + TF_ASSIGN_OR_RETURN( + ScalarOrTensor wrapped_result, + EmitTiledReshape(b, padded_tile_sizes, ScalarOrTensor(result))); + result = wrapped_result.UnwrapTensor(); + } + return ScalarOrTensor(result); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index fc0a67905f85f7..b8c353bbeeb689 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -2232,6 +2232,161 @@ ENTRY entry { kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } +TEST_F(TritonEmitterTest, DotWithMajorLhsContractingDimIsEmittedCorrectly) { + const std::string kHloText = R"( +lhs { + ROOT p0 = f32[299,32] parameter(0) +} + +rhs { + ROOT p0 = f32[299,512] parameter(0) +} + +fdot { + p0 = f32[299,32] parameter(0) + p1 = f32[299,512] parameter(1) + lhs = f32[299,32] fusion(p0), kind=kCustom, calls=lhs, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32", "16"]}] + } + } + } + rhs = f32[299,512]{1,0} fusion(p1), kind=kCustom, calls=rhs, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32", "64"]}] + } + } + } + ROOT dot = f32[32,512]{1,0} dot(lhs, rhs), + lhs_contracting_dims={0}, rhs_contracting_dims={0}, + algorithm=dot_f32_f32_f32 +} + +ENTRY entry { + // Take in boolean inputs for the test, in order to allow exact accumulation. 
+ p0 = pred[299,32] parameter(0) + p1 = pred[299,512] parameter(1) + p0_f32 = f32[299,32] convert(p0) + p1_f32 = f32[299,512] convert(p1) + ROOT fusion = f32[32,512] fusion(p0_f32, p1_f32), + kind=kCustom, calls=fdot, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["16", "64"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} +})"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); +} + +TEST_F(TritonEmitterTest, DotWithMinorRhsContractingDimIsEmittedCorrectly) { + const std::string kHloText = R"( +lhs { + ROOT p0 = f32[32,299] parameter(0) +} + +rhs { + ROOT p0 = f32[512,299] parameter(0) +} + +fdot { + p0 = f32[32,299] parameter(0) + p1 = f32[512,299] parameter(1) + lhs = f32[32,299] fusion(p0), kind=kCustom, calls=lhs, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["16", "32"]}] + } + } + } + rhs = f32[512,299]{1,0} fusion(p1), kind=kCustom, calls=rhs, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["64", "32"]}] + } + } + } + ROOT dot = f32[32,512]{1,0} dot(lhs, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + algorithm=dot_f32_f32_f32 +} + +ENTRY entry { + // Take in boolean inputs for the test, in order to allow exact accumulation. + p0 = pred[32,299] parameter(0) + p1 = pred[512,299] parameter(1) + p0_f32 = f32[32,299] convert(p0) + p1_f32 = f32[512,299] convert(p1) + ROOT fusion = f32[32,512] fusion(p0_f32, p1_f32), + kind=kCustom, calls=fdot, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["16", "64"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} +})"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); +} + +TEST_F(TritonEmitterTest, + DotWithAdditionalDimensionsWithUnitTileSizesIsEmittedCorrectly) { + const std::string kHloText = R"( +lhs { + ROOT p0 = f32[2,3,32,125] parameter(0) +} + +rhs { + ROOT p0 = f32[2,125,3,256] parameter(0) +} + +fdot { + p0 = f32[2,3,32,125] parameter(0) + p1 = f32[2,125,3,256] parameter(1) + lhs = f32[2,3,32,125] fusion(p0), kind=kCustom, calls=lhs, + backend_config={"fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["1", "1", "16", "32"]}] + } + } + } + rhs = f32[2,125,3,256] fusion(p1), kind=kCustom, calls=rhs, + backend_config={"fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["1", "32", "1", "64"]}] + } + } + } + ROOT dot = f32[2,3,32,256] dot(lhs, rhs), + lhs_batch_dims={0,1}, rhs_batch_dims={0,2}, + lhs_contracting_dims={3}, rhs_contracting_dims={1}, + algorithm=dot_f32_f32_f32 +} + +ENTRY entry { + // Take in boolean inputs for the test, in order to allow exact accumulation. 
+ p0 = pred[2,3,32,125] parameter(0) + p1 = pred[2,125,3,256] parameter(1) + p0_f32 = f32[2,3,32,125] convert(p0) + p1_f32 = f32[2,125,3,256] convert(p1) + ROOT fusion = f32[2,3,32,256] fusion(p0_f32, p1_f32), + kind=kCustom, calls=fdot, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["1", "1", "16", "64"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} +})"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); +} + TEST_F(TritonEmitterTest, ConcatenateOfNestsIsEmittedCorrectly) { const std::string kHloText = R"( nest0 { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index 2126a3d8dcea64..9a3d6160535b73 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -275,45 +275,14 @@ bool IsInTritonNestedGemmFusion(const HloInstruction& hlo) { } absl::Status CheckSupportedCheckDotDimensions(const HloDotInstruction& dot) { - const Shape& lhs_shape = dot.operand(0)->shape(); - const Shape& rhs_shape = dot.operand(1)->shape(); const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); - TF_ASSIGN_OR_RETURN( - std::vector lhs_non_contracting_dims, - GetNonContractingDims(lhs_shape, dim_numbers.lhs_batch_dimensions(), - dim_numbers.lhs_contracting_dimensions())); - TF_ASSIGN_OR_RETURN( - std::vector rhs_non_contracting_dims, - GetNonContractingDims(rhs_shape, dim_numbers.rhs_batch_dimensions(), - dim_numbers.rhs_contracting_dimensions())); - if (lhs_non_contracting_dims.size() > 1 || - rhs_non_contracting_dims.size() > 1) { - return absl::UnimplementedError(absl::StrCat( - "Multiple non-contracting dimensions are not supported, got LHS: [", - absl::StrJoin(lhs_non_contracting_dims, ","), "], RHS: [", - absl::StrJoin(rhs_non_contracting_dims, ","), "]")); - } // Only checking one side of bach and contracting dimensions, since they must // be the same for left and right. 
- if (dim_numbers.lhs_batch_dimensions_size() > 0) { - return absl::UnimplementedError( - absl::StrCat("Batch dimensions are not supported yet, got ", - absl::StrJoin(dim_numbers.lhs_batch_dimensions(), ","))); - } if (dim_numbers.lhs_contracting_dimensions_size() != 1) { return absl::UnimplementedError(absl::StrCat( "Exactly one contracting dimension is supported, got ", absl::StrJoin(dim_numbers.lhs_contracting_dimensions(), ","))); } - if (dim_numbers.lhs_contracting_dimensions(0) != 1 || - dim_numbers.rhs_contracting_dimensions(0) != 0) { - return absl::UnimplementedError(absl::StrCat( - "Only lhs_contracting_dimensions=1 (got ", - absl::StrJoin(dim_numbers.lhs_contracting_dimensions(), ","), - ") and rhs_contracting_dimensions=0 (got ", - absl::StrJoin(dim_numbers.rhs_contracting_dimensions(), ","), - ") are supported.")); - } return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index ca22498aecae07..819a875214847b 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -1871,14 +1871,14 @@ ENTRY triton_computation { lhs = $0[16,128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ "fusion_backend_config":{ "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ - "output_tiles":[{"sizes":["16", "16", "64"]}] + "output_tiles":[{"sizes":["1", "16", "64"]}] } } } rhs = $0[16,256,512] fusion(p1), kind=kCustom, calls=frhs, backend_config={ "fusion_backend_config":{ "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ - "output_tiles":[{"sizes":["16", "64", "32"]}] + "output_tiles":[{"sizes":["1", "64", "32"]}] } } } @@ -1891,7 +1891,7 @@ ENTRY triton_computation { TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, /* use_nested_gemm_fusions=*/true)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 16, 32}, + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 16, 32}, se::CudaComputeCapability::Ampere()); } @@ -1911,14 +1911,14 @@ ENTRY triton_computation { lhs = $0[16,128,256] fusion(p0), kind=kCustom, calls=flhs, backend_config={ "fusion_backend_config":{ "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ - "output_tiles":[{"sizes":["4", "16", "64"]}] + "output_tiles":[{"sizes":["1", "16", "64"]}] } } } rhs = $0[16,256,512] fusion(p1), kind=kCustom, calls=frhs, backend_config={ "fusion_backend_config":{ "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ - "output_tiles":[{"sizes":["4", "64", "32"]}] + "output_tiles":[{"sizes":["1", "64", "32"]}] } } } @@ -1930,7 +1930,7 @@ ENTRY triton_computation { TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, /* use_nested_gemm_fusions=*/true)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{4, 16, 4, 32}, + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 16, 1, 32}, se::CudaComputeCapability::Ampere()); } diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc index 8c712ab27814b2..bf42af6962ef93 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc @@ -15,8 +15,6 @@ limitations under the License. 
#include "xla/service/gpu/transforms/nest_gemm_fusion.h" -#include - #include #include #include "absl/log/log.h" @@ -37,18 +35,11 @@ limitations under the License. #include "xla/tsl/platform/statusor.h" using ::testing::ElementsAre; -using ::testing::Not; -using ::tsl::testing::IsOk; using ::tsl::testing::IsOkAndHolds; using ::tsl::testing::StatusIs; namespace xla { -// Gtest hook to pretty-print an HloInstruction. -static void PrintTo(const HloInstruction& hlo, std::ostream* os) { - *os << hlo.ToString(); -} - namespace gpu { namespace { @@ -422,9 +413,8 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - // TODO(b/393299275): rhs_contracting_dims={0} is not currently supported. EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), - Not(IsOk())); + IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); } From f191703e3bf1dcaa6c358a4d29ba5b062e0d1c63 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 2 Apr 2025 11:09:45 -0700 Subject: [PATCH 0163/1324] [XLA:GPU] Annotate combinable sync collectives. Instead of computing the set of synchronous collectives once per combiner pass, we compute it once for all combiner passes. We add a new pass prior to combiner passes where we annotate synchronous collectives. PiperOrigin-RevId: 743207737 --- third_party/xla/xla/service/gpu/BUILD | 4 +- .../xla/xla/service/gpu/gpu_compiler.cc | 8 +- .../service/gpu/transforms/collectives/BUILD | 57 +++++-- .../collectives/all_gather_combiner.cc | 19 +-- .../collectives/all_gather_combiner_test.cc | 12 +- .../collectives/all_reduce_combiner.cc | 19 +-- .../collectives/all_reduce_combiner_test.cc | 26 +-- .../collective_combiner_annotator.cc | 140 +++++++++++++++ .../collective_combiner_annotator.h | 61 +++++++ .../collective_combiner_annotator_test.cc | 159 ++++++++++++++++++ .../gpu_collective_combiner_utils.cc | 74 -------- .../gpu_collective_combiner_utils.h | 8 - .../gpu_collective_combiner_utils_test.cc | 78 --------- .../collectives/reduce_scatter_combiner.cc | 19 +-- .../reduce_scatter_combiner_test.cc | 12 +- 15 files changed, 443 insertions(+), 253 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator.cc create mode 100644 third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator.h create mode 100644 third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index be8ebb87957949..fd9d73f1bdec3b 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1414,6 +1414,7 @@ cc_library( "@llvm-project//llvm:BitWriter", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:TransformUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -1493,6 +1494,7 @@ cc_library( "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:matmul_ptable_stats_collection", "//xla/service/gpu/model:sol_gpu_cost_model_stats_collection", + "//xla/service/gpu/transforms/collectives:collective_combiner_annotator", "//xla/service/gpu/transforms/collectives:convert_async_collectives_to_sync", "//xla/service/gpu/transforms/collectives:gpu_all_gather_combiner", "//xla/service/gpu/transforms/collectives:gpu_all_reduce_combiner", @@ -1641,7 +1643,7 @@ cc_library( ]) + 
xla_internal(["service:export_hlo"]) + if_google([ "//xla/hlo/experimental/auto_sharding", "//xla/hlo/experimental/auto_sharding:auto_sharding_option", - ]) + ["@llvm-project//llvm:TargetParser"], + ]), ) xla_test( diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 0fb4047a3d1e3e..76ae8d1de9237c 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -37,7 +37,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" -#include "absl/types/variant.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringRef.h" @@ -194,6 +193,7 @@ limitations under the License. #include "xla/service/gpu/transforms/collective_select_folder.h" #include "xla/service/gpu/transforms/collectives/all_gather_combiner.h" #include "xla/service/gpu/transforms/collectives/all_reduce_combiner.h" +#include "xla/service/gpu/transforms/collectives/collective_combiner_annotator.h" #include "xla/service/gpu/transforms/collectives/convert_async_collectives_to_sync.h" #include "xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h" #include "xla/service/gpu/transforms/collectives/reduce_scatter_combiner.h" @@ -1150,6 +1150,12 @@ absl::Status RunPostFusionPasses( HloPassPipeline pipeline("post-fusion optimization"); pipeline.AddPass(); + if (hlo_module->config() + .debug_options() + .xla_gpu_experimental_enable_sync_collective_combining()) { + pipeline.AddPass(device_description, + pointer_size); + } pipeline.AddPass( device_description, /*default_combine_threshold_in_bytes=*/kDefaultAllGatherCombineThreshold, diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/BUILD b/third_party/xla/xla/service/gpu/transforms/collectives/BUILD index 5c95563786108c..d6f76342cdeb6d 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/collectives/BUILD @@ -67,23 +67,15 @@ cc_library( srcs = ["gpu_collective_combiner_utils.cc"], hdrs = ["gpu_collective_combiner_utils.h"], deps = [ - ":collective_ops_utils", - ":convert_async_collectives_to_sync", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", "//xla/service:collective_utils", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_hlo_schedule", "//xla/stream_executor:device_description", - "//xla/tsl/platform:errors", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", ], ) @@ -103,9 +95,7 @@ xla_cc_test( "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", @@ -117,6 +107,7 @@ cc_library( srcs = ["all_gather_combiner.cc"], hdrs = ["all_gather_combiner.h"], deps = [ + ":collective_combiner_annotator", ":gpu_collective_combiner_utils", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -126,7 +117,6 @@ cc_library( "//xla/stream_executor:device_description", "//xla/tsl/platform:statusor", 
"@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -157,6 +147,7 @@ cc_library( srcs = ["reduce_scatter_combiner.cc"], hdrs = ["reduce_scatter_combiner.h"], deps = [ + ":collective_combiner_annotator", ":gpu_collective_combiner_utils", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -166,7 +157,6 @@ cc_library( "//xla/stream_executor:device_description", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -196,6 +186,7 @@ cc_library( srcs = ["all_reduce_combiner.cc"], hdrs = ["all_reduce_combiner.h"], deps = [ + ":collective_combiner_annotator", ":gpu_collective_combiner_utils", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -205,7 +196,6 @@ cc_library( "//xla/stream_executor:device_description", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -229,3 +219,44 @@ xla_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "collective_combiner_annotator", + srcs = ["collective_combiner_annotator.cc"], + hdrs = ["collective_combiner_annotator.h"], + deps = [ + ":collective_ops_utils", + ":convert_async_collectives_to_sync", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/utils:hlo_query", + "//xla/service/gpu:gpu_hlo_schedule", + "//xla/stream_executor:device_description", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "collective_combiner_annotator_test", + srcs = ["collective_combiner_annotator_test.cc"], + deps = [ + ":collective_combiner_annotator", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner.cc b/third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner.cc index 92933d75b34874..337b1a03d58009 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" -#include "absl/functional/bind_front.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -26,6 +25,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/collectives/all_gather_combiner.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/transforms/collectives/collective_combiner_annotator.h" #include "xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h" #include "xla/service/hlo_domain_map.h" #include "xla/tsl/platform/statusor.h" @@ -49,10 +49,9 @@ std::optional PipelinedCombinerKey( } std::optional SynchronousCombinerKey( - const absl::flat_hash_set& sync_collectives, const HloInstruction* instruction, const HloDomainMap& domain_map, bool combine_by_dim, bool combine_different_dtypes) { - if (!sync_collectives.contains(instruction)) { + if (!IsCombinableSyncCollective(*instruction)) { return std::nullopt; } return AllGatherCombiner::CombineKey(instruction, domain_map, combine_by_dim, @@ -77,21 +76,11 @@ absl::StatusOr GpuAllGatherCombiner::Run( bool changed = false; // Combine as much as possible for synchronous collectives. - absl::flat_hash_set sync_collectives; - if (module->config() - .debug_options() - .xla_gpu_experimental_enable_sync_collective_combining()) { - TF_ASSIGN_OR_RETURN( - sync_collectives, - SynchronousCollectives(*module, pointer_size_, device_info_)); - } - if (!sync_collectives.empty()) { + if (ContainsCombinableSyncCollective(*module)) { combine_threshold_in_bytes_ = MaxAvailableMemory(*module, device_info_); TF_ASSIGN_OR_RETURN( bool combined, - RunWithKeyCombiner( - module, execution_threads, - absl::bind_front(SynchronousCombinerKey, sync_collectives))); + RunWithKeyCombiner(module, execution_threads, SynchronousCombinerKey)); changed |= combined; } diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner_test.cc index 98e1154716b746..fb1543ca76c79c 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner_test.cc @@ -356,8 +356,10 @@ TEST_F(GpuAllGatherCombinerTest, CombinesSynchronousCollectivesMaximally) { p1 = f16[5000000]{0} parameter(1) // 20MB combinable all-gather collectives. Default combiner threshold is 30MB. 
- ag0 = f16[10000000]{0} all-gather(p0), replica_groups={}, dimensions={0} - ag1 = f16[10000000]{0} all-gather(p1), replica_groups={}, dimensions={0} + ag0 = f16[10000000]{0} all-gather(p0), replica_groups={}, dimensions={0}, + frontend_attributes={sync_collective="true"} + ag1 = f16[10000000]{0} all-gather(p1), replica_groups={}, dimensions={0}, + frontend_attributes={sync_collective="true"} ROOT result = tuple(ag0, ag1) } )"; @@ -373,13 +375,7 @@ TEST_F(GpuAllGatherCombinerTest, CombinesSynchronousCollectivesMaximally) { /*combine_by_dim=*/false, /*combine_different_dtypes=*/true, /*pointer_size=*/4); - EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(false)); - - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_experimental_enable_sync_collective_combining(true); EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(true)); - Matcher combined_all_gather = op::AllGather(op::Parameter(0), op::Parameter(1)); EXPECT_THAT(module->entry_computation()->root_instruction(), diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner.cc b/third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner.cc index d1d40692e217c4..e97364f8924235 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" -#include "absl/functional/bind_front.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -26,6 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/collectives/all_reduce_combiner.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/transforms/collectives/collective_combiner_annotator.h" #include "xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h" #include "xla/service/hlo_domain_map.h" #include "xla/tsl/platform/statusor.h" @@ -47,9 +47,8 @@ std::optional PipelinedCombinerKey( } std::optional SynchronousCombinerKey( - const absl::flat_hash_set& sync_collectives, const HloInstruction* instruction, const HloDomainMap& domain_map) { - if (!sync_collectives.contains(instruction)) { + if (!IsCombinableSyncCollective(*instruction)) { return std::nullopt; } return AllReduceCombiner::CombineKey(instruction, domain_map); @@ -73,21 +72,11 @@ absl::StatusOr GpuAllReduceCombiner::Run( bool changed = false; // Combine as much as possible for synchronous collectives. 
- absl::flat_hash_set sync_collectives; - if (module->config() - .debug_options() - .xla_gpu_experimental_enable_sync_collective_combining()) { - TF_ASSIGN_OR_RETURN( - sync_collectives, - SynchronousCollectives(*module, pointer_size_, device_info_)); - } - if (!sync_collectives.empty()) { + if (ContainsCombinableSyncCollective(*module)) { combine_threshold_in_bytes_ = MaxAvailableMemory(*module, device_info_); TF_ASSIGN_OR_RETURN( bool combined, - RunWithKeyCombiner( - module, execution_threads, - absl::bind_front(SynchronousCombinerKey, sync_collectives))); + RunWithKeyCombiner(module, execution_threads, SynchronousCombinerKey)); changed |= combined; } diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner_test.cc index 75028d75d80c80..b6ecba042217aa 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner_test.cc @@ -359,8 +359,10 @@ TEST_F(GpuAllReduceCombinerTest, CombinesSynchronousCollectivesMaximally) { p1 = f16[10000000]{0} parameter(1) // 20MB combinable all-reduce collectives. Default combiner threshold is 30MB. - ar0 = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add - ar1 = f16[10000000]{0} all-reduce(p1), replica_groups={}, to_apply=add + ar0 = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add, + frontend_attributes={sync_collective="true"} + ar1 = f16[10000000]{0} all-reduce(p1), replica_groups={}, to_apply=add, + frontend_attributes={sync_collective="true"} ROOT result = tuple(ar0, ar1) } )"; @@ -374,13 +376,7 @@ TEST_F(GpuAllReduceCombinerTest, CombinesSynchronousCollectivesMaximally) { /*combine_threshold_in_bytes=*/kDefaultAllReduceCombineThreshold, /*combine_threshold_count=*/256, /*pointer_size=*/4); - EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(false)); - - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_experimental_enable_sync_collective_combining(true); EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(true)); - Matcher combined_all_reduce = op::AllReduce(op::Parameter(0), op::Parameter(1)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -400,17 +396,17 @@ TEST_F(GpuAllReduceCombinerTest, } ENTRY main { - p0 = f16[10000000]{0} parameter(0) - p1 = f16[10000000]{0} parameter(1) + p0 = f16[10000]{0} parameter(0) + p1 = f16[10000]{0} parameter(1) // This all-reduce must happen first, which is enforced by the control // dependency and must be respected. - lead_ar = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add + lead_ar = f16[10000]{0} all-reduce(p0), replica_groups={}, to_apply=add // These all-reduce have control dependencies and must not be combined. 
- ar0 = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add, + ar0 = f16[10000]{0} all-reduce(p0), replica_groups={}, to_apply=add, control-predecessors={lead_ar} - ar1 = f16[10000000]{0} all-reduce(p1), replica_groups={}, to_apply=add, + ar1 = f16[10000]{0} all-reduce(p1), replica_groups={}, to_apply=add, control-predecessors={lead_ar} ROOT result = tuple(ar0, ar1) } @@ -424,10 +420,6 @@ TEST_F(GpuAllReduceCombinerTest, kDefaultAllReduceCombineThreshold, /*combine_threshold_in_bytes=*/kDefaultAllReduceCombineThreshold, /*combine_threshold_count=*/256, /*pointer_size=*/4); - - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_experimental_enable_sync_collective_combining(true); EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(false)); } diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator.cc new file mode 100644 index 00000000000000..6c03d576c7454e --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator.cc @@ -0,0 +1,140 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/collectives/collective_combiner_annotator.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" +#include "xla/service/gpu/transforms/collectives/convert_async_collectives_to_sync.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla::gpu { +namespace { + +static constexpr const char kCollectiveIdAttr[] = "collective_id"; +static constexpr const char kCollectiveSyncAttr[] = "sync_collective"; + +std::string CollectiveId(const HloInstruction* instr) { + return absl::StrCat(instr->unique_id()); +} + +// Annotate all collective instructions with a unique identifier that will be +// preserved after async collective conversion. 
+void AnnotateCollectives(HloModule* module) { + HloPredicate is_collective = [](const HloInstruction* instr) { + return hlo_query::IsCollectiveCommunicationOp(instr->opcode()); + }; + hlo_query::ForEachInstructionWithPred( + *module, is_collective, [](HloInstruction* instr) { + instr->add_frontend_attribute(kCollectiveIdAttr, CollectiveId(instr)); + }); +} + +absl::Status AnnotateSyncCollectives(HloModule* module) { + HloPassPipeline pipeline("annotate-sync-collectives"); + pipeline.AddPass(); + return pipeline.Run(module).status(); +} + +absl::flat_hash_set SyncCollectiveIds(const HloModule& module) { + absl::flat_hash_set sync_collective_ids; + HloPredicate is_sync_collective = [](const HloInstruction* instr) { + return IsGPUSyncCollective(*instr); + }; + hlo_query::ForEachInstructionWithPred( + module, is_sync_collective, + [&sync_collective_ids](const HloInstruction* instr) { + sync_collective_ids.insert( + *instr->get_frontend_attribute(kCollectiveIdAttr)); + }); + return sync_collective_ids; +} + +// Return the set of collective instructions that are synchronous post +// scheduling. +absl::StatusOr> SynchronousCollectives( + const HloModule& module, int64_t pointer_size, + const se::DeviceDescription& device_info) { + std::unique_ptr cloned_module = module.Clone(); + AnnotateCollectives(cloned_module.get()); + TF_RETURN_IF_ERROR(RunAsyncCollectivesConversionPasses(cloned_module.get())); + TF_RETURN_IF_ERROR( + ScheduleGpuModule(cloned_module.get(), pointer_size, device_info) + .status()); + TF_RETURN_IF_ERROR(AnnotateSyncCollectives(cloned_module.get())); + return SyncCollectiveIds(*cloned_module); +} + +} // namespace + +absl::StatusOr CollectiveCombinerAnnotator::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + TF_ASSIGN_OR_RETURN( + absl::flat_hash_set sync_collectives, + SynchronousCollectives(*module, pointer_size_, device_info_)); + if (sync_collectives.empty()) { + return false; + } + + bool changed = false; + for (HloComputation* comp : module->computations(execution_threads)) { + for (HloInstruction* instr : comp->instructions()) { + if (!sync_collectives.contains(CollectiveId(instr))) { + continue; + } + instr->add_frontend_attribute(kCollectiveSyncAttr, "true"); + changed = true; + } + } + + return changed; +} + +bool IsCombinableSyncCollective(const HloInstruction& instr) { + return instr.get_frontend_attribute(kCollectiveSyncAttr).value_or("false") == + "true"; +} + +bool ContainsCombinableSyncCollective(const HloModule& module) { + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instr : computation->instructions()) { + if (IsCombinableSyncCollective(*instr)) { + return true; + } + } + } + return false; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator.h b/third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator.h new file mode 100644 index 00000000000000..cfac165316bf79 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator.h @@ -0,0 +1,61 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVES_COLLECTIVE_COMBINER_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVES_COLLECTIVE_COMBINER_ANNOTATOR_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" + +namespace xla::gpu { + +// Annotates collective operations with metadata used by collective combiners. +class CollectiveCombinerAnnotator : public HloModulePass { + public: + CollectiveCombinerAnnotator(se::DeviceDescription device_info, + int64_t pointer_size) + : device_info_(std::move(device_info)), pointer_size_(pointer_size) {} + + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + absl::string_view name() const override { + return "collective-combiner-annotator"; + } + + private: + se::DeviceDescription device_info_; + int64_t pointer_size_; +}; + +// Returns true if `instr` is a combinable sync collective. False otherwise. +bool IsCombinableSyncCollective(const HloInstruction& instr); + +// Returns true if module contains any combinable sync collective. False +// otherwise. +bool ContainsCombinableSyncCollective(const HloModule& module); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVES_COLLECTIVE_COMBINER_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator_test.cc new file mode 100644 index 00000000000000..641640d7b5be7d --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_combiner_annotator_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/transforms/collectives/collective_combiner_annotator.h" + +#include + +#include +#include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +class CollectiveCombinerAnnotatorTest : public HloHardwareIndependentTestBase { + protected: + absl::StatusOr RunCollectiveCombinerAnnotator(HloModule* module) { + int pointer_size = 4; + stream_executor::DeviceDescription device_info; + device_info.set_device_memory_size(20000); + return RunHloPass( + CollectiveCombinerAnnotator(std::move(device_info), pointer_size), + module); + } +}; + +TEST_F(CollectiveCombinerAnnotatorTest, SynchronousCollectivesNoOverlap) { + absl::string_view kHloText = R"( + HloModule m + + add { + p0 = f16[] parameter(0) + p1 = f16[] parameter(1) + ROOT add = f16[] add(p0, p1) + } + + ENTRY main { + p0 = f16[10000000]{0} parameter(0) + p1 = f16[10000000]{0} parameter(1) + ar0 = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add + ar1 = f16[10000000]{0} all-reduce(p1), replica_groups={}, to_apply=add + ROOT result = tuple(ar0, ar1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + EXPECT_THAT(RunCollectiveCombinerAnnotator(module.get()), IsOkAndHolds(true)); + const HloInstruction* ar0 = + module->entry_computation()->root_instruction()->operand(0); + EXPECT_TRUE(IsCombinableSyncCollective(*ar0)); + const HloInstruction* ar1 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(IsCombinableSyncCollective(*ar1)); +} + +TEST_F(CollectiveCombinerAnnotatorTest, SynchronousCollectivesWithOverlap) { + // Expected schedule: + // ------------------ + // c0 –> ar0 + // c1 –> ar1 + // ------------------ + absl::string_view kHloText = R"( + HloModule m + + add { + p0 = f16[] parameter(0) + p1 = f16[] parameter(1) + ROOT add = f16[] add(p0, p1) + } + + ENTRY main { + p0 = f16[10000000]{0} parameter(0) + p1 = f16[10000000]{0} parameter(1) + + c0 = f16[10000000]{0} copy(p0) + c1 = f16[10000000]{0} copy(p1) + + ar0 = f16[10000000]{0} all-reduce(c0), replica_groups={}, to_apply=add + ar1 = f16[10000000]{0} all-reduce(c1), replica_groups={}, to_apply=add + + ROOT result = tuple(ar0, ar1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + EXPECT_THAT(RunCollectiveCombinerAnnotator(module.get()), IsOkAndHolds(true)); + const HloInstruction* ar0 = + module->entry_computation()->root_instruction()->operand(0); + EXPECT_FALSE(IsCombinableSyncCollective(*ar0)); + const HloInstruction* ar1 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(IsCombinableSyncCollective(*ar1)); +} + +TEST_F(CollectiveCombinerAnnotatorTest, + ContainsCombinableSyncCollectiveReturnFalseForNonAnnotatedCollectives) { + absl::string_view kHloText = R"( + HloModule m + + add { + p0 = f16[] parameter(0) + p1 = f16[] parameter(1) + ROOT add = f16[] add(p0, p1) + } + + ENTRY main { + p0 = f16[10000000]{0} parameter(0) + ROOT result = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, 
ParseAndReturnVerifiedModule(kHloText)); + EXPECT_FALSE(ContainsCombinableSyncCollective(*module)); +} + +TEST_F(CollectiveCombinerAnnotatorTest, + ContainsCombinableSyncCollectiveReturnTRUEForAnnotatedCollectives) { + absl::string_view kHloText = R"( + HloModule m + + add { + p0 = f16[] parameter(0) + p1 = f16[] parameter(1) + ROOT add = f16[] add(p0, p1) + } + + ENTRY main { + p0 = f16[10000000]{0} parameter(0) + ROOT result = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add, + frontend_attributes={sync_collective="true"} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + EXPECT_TRUE(ContainsCombinableSyncCollective(*module)); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.cc b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.cc index 5a263db9b8e7da..748802ae8e5e64 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.cc @@ -14,27 +14,18 @@ limitations under the License. ==============================================================================*/ #include -#include -#include -#include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" -#include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/collective_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_hlo_schedule.h" -#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" -#include "xla/service/gpu/transforms/collectives/convert_async_collectives_to_sync.h" #include "xla/stream_executor/device_description.h" -#include "xla/tsl/platform/errors.h" #include "xla/util.h" namespace xla::gpu { @@ -53,73 +44,8 @@ int64_t GetDefaultValue(HloOpcode opcode) { return -1; } -static constexpr const char* kCollectiveIdAttr = "collective_id"; - -std::string CollectiveId(const HloInstruction* instr) { - return absl::StrCat(instr->unique_id()); -} - -// Annotate all collective instructions with a unique identifier that will be -// preserved after async collective conversion. 
-void AnnotateCollectives(HloModule* module) { - HloPredicate is_collective = [](const HloInstruction* instr) { - return hlo_query::IsCollectiveCommunicationOp(instr->opcode()); - }; - hlo_query::ForEachInstructionWithPred( - *module, is_collective, [](HloInstruction* instr) { - instr->add_frontend_attribute(kCollectiveIdAttr, CollectiveId(instr)); - }); -} - -absl::Status AnnotateSyncCollectives(HloModule* module) { - HloPassPipeline pipeline("annotate-sync-collectives"); - pipeline.AddPass(); - return pipeline.Run(module).status(); -} - -absl::flat_hash_set SyncCollectiveIds(const HloModule& module) { - absl::flat_hash_set sync_collective_ids; - HloPredicate is_sync_collective = [](const HloInstruction* instr) { - return IsGPUSyncCollective(*instr); - }; - hlo_query::ForEachInstructionWithPred( - module, is_sync_collective, - [&sync_collective_ids](const HloInstruction* instr) { - sync_collective_ids.insert( - *instr->get_frontend_attribute(kCollectiveIdAttr)); - }); - return sync_collective_ids; -} - } // namespace -absl::StatusOr> SynchronousCollectives( - const HloModule& module, int64_t pointer_size, - const se::DeviceDescription& device_info) { - std::unique_ptr cloned_module = module.Clone(); - AnnotateCollectives(cloned_module.get()); - TF_RETURN_IF_ERROR(RunAsyncCollectivesConversionPasses(cloned_module.get())); - TF_RETURN_IF_ERROR( - ScheduleGpuModule(cloned_module.get(), pointer_size, device_info) - .status()); - TF_RETURN_IF_ERROR(AnnotateSyncCollectives(cloned_module.get())); - - absl::flat_hash_set sync_collective_ids = - SyncCollectiveIds(*cloned_module); - - // Find the corresponding sync collective instructions in the original module. - absl::flat_hash_set sync_collectives; - HloPredicate is_sync_collective = - [&sync_collective_ids](const HloInstruction* instr) { - return sync_collective_ids.contains(CollectiveId(instr)); - }; - hlo_query::ForEachInstructionWithPred( - module, is_sync_collective, [&sync_collectives](HloInstruction* instr) { - sync_collectives.insert(instr); - }); - return sync_collectives; -} - int64_t MaxAvailableMemory(const HloModule& module, const se::DeviceDescription& device_info) { int64_t base_limit = module.config().device_memory_size() != 0 diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h index 225182961ed3cf..f23792c0744deb 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h +++ b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h @@ -18,9 +18,7 @@ limitations under the License. #include -#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -29,12 +27,6 @@ limitations under the License. namespace xla::gpu { -// Return the set of collective instructions that are synchronous post -// scheduling. -absl::StatusOr> SynchronousCollectives( - const HloModule& module, int64_t pointer_size, - const se::DeviceDescription& device_info); - // Returns the maximum available memory on a device. 
int64_t MaxAvailableMemory(const HloModule& module, const se::DeviceDescription& device_info); diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils_test.cc index d2d740faf74b80..371a464afefedf 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils_test.cc @@ -18,9 +18,7 @@ limitations under the License. #include #include -#include #include -#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -34,16 +32,12 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" namespace xla::gpu { namespace { -using ::testing::UnorderedElementsAre; -using ::tsl::testing::IsOkAndHolds; - using CollectiveCombinerUtilsTest = HloTestBase; TEST_F(CollectiveCombinerUtilsTest, @@ -519,77 +513,5 @@ TEST_F(CollectiveCombinerUtilsTest, EXPECT_FALSE(ContainsPipelinedInstruction(*module)); } -absl::StatusOr> SynchronousCollectives( - const HloModule& module) { - int pointer_size = 4; - stream_executor::DeviceDescription device_info; - device_info.set_device_memory_size(20000); - return xla::gpu::SynchronousCollectives(module, pointer_size, device_info); -} - -TEST_F(CollectiveCombinerUtilsTest, SynchronousCollectivesNoOverlap) { - absl::string_view kHloText = R"( - HloModule m - - add { - p0 = f16[] parameter(0) - p1 = f16[] parameter(1) - ROOT add = f16[] add(p0, p1) - } - - ENTRY main { - p0 = f16[10000000]{0} parameter(0) - p1 = f16[10000000]{0} parameter(1) - ar0 = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add - ar1 = f16[10000000]{0} all-reduce(p1), replica_groups={}, to_apply=add - ROOT result = tuple(ar0, ar1) - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - const HloInstruction* ar0 = - module->entry_computation()->root_instruction()->operand(0); - const HloInstruction* ar1 = - module->entry_computation()->root_instruction()->operand(1); - EXPECT_THAT(SynchronousCollectives(*module), - IsOkAndHolds(UnorderedElementsAre(ar0, ar1))); -} - -TEST_F(CollectiveCombinerUtilsTest, SynchronousCollectivesWithOverlap) { - // Expected schedule: - // ------------------ - // c0 –> ar0 - // c1 –> ar1 - // ------------------ - absl::string_view kHloText = R"( - HloModule m - - add { - p0 = f16[] parameter(0) - p1 = f16[] parameter(1) - ROOT add = f16[] add(p0, p1) - } - - ENTRY main { - p0 = f16[10000000]{0} parameter(0) - p1 = f16[10000000]{0} parameter(1) - - c0 = f16[10000000]{0} copy(p0) - c1 = f16[10000000]{0} copy(p1) - - ar0 = f16[10000000]{0} all-reduce(c0), replica_groups={}, to_apply=add - ar1 = f16[10000000]{0} all-reduce(c1), replica_groups={}, to_apply=add - - ROOT result = tuple(ar0, ar1) - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - const HloInstruction* ar1 = - module->entry_computation()->root_instruction()->operand(1); - EXPECT_THAT(SynchronousCollectives(*module), - IsOkAndHolds(UnorderedElementsAre(ar1))); -} - } // namespace } // namespace xla::gpu diff --git 
a/third_party/xla/xla/service/gpu/transforms/collectives/reduce_scatter_combiner.cc b/third_party/xla/xla/service/gpu/transforms/collectives/reduce_scatter_combiner.cc index ab2c7f9fd0b75f..cf6677f9d6b2ef 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/reduce_scatter_combiner.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/reduce_scatter_combiner.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" -#include "absl/functional/bind_front.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/transforms/collectives/collective_combiner_annotator.h" #include "xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h" #include "xla/service/hlo_domain_map.h" #include "xla/service/reduce_scatter_combiner.h" @@ -47,10 +47,9 @@ std::optional PipelinedCombinerKey( } std::optional SynchronousCombinerKey( - const absl::flat_hash_set& sync_collectives, const HloInstruction* instruction, const HloDomainMap& domain_map, bool combine_by_dim) { - if (!sync_collectives.contains(instruction)) { + if (!IsCombinableSyncCollective(*instruction)) { return std::nullopt; } return ReduceScatterCombiner::CombineKey(instruction, domain_map, @@ -75,21 +74,11 @@ absl::StatusOr GpuReduceScatterCombiner::Run( bool changed = false; // Combine as much as possible for synchronous collectives. - absl::flat_hash_set sync_collectives; - if (module->config() - .debug_options() - .xla_gpu_experimental_enable_sync_collective_combining()) { - TF_ASSIGN_OR_RETURN( - sync_collectives, - SynchronousCollectives(*module, pointer_size_, device_info_)); - } - if (!sync_collectives.empty()) { + if (ContainsCombinableSyncCollective(*module)) { combine_threshold_in_bytes_ = MaxAvailableMemory(*module, device_info_); TF_ASSIGN_OR_RETURN( bool combined, - RunWithKeyCombiner( - module, execution_threads, - absl::bind_front(SynchronousCombinerKey, sync_collectives))); + RunWithKeyCombiner(module, execution_threads, SynchronousCombinerKey)); changed |= combined; } diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/reduce_scatter_combiner_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/reduce_scatter_combiner_test.cc index 1157ac1899de08..8999aba09c25d1 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/reduce_scatter_combiner_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/reduce_scatter_combiner_test.cc @@ -363,8 +363,10 @@ TEST_F(GpuReduceScatterCombinerTest, CombinesSynchronousCollectivesMaximally) { p1 = f16[20000000]{0} parameter(1) // 20MB combinable reduce-scatter collectives. Default combiner threshold is 30MB. 
- rs0 = f16[10000000]{0} reduce-scatter(p0), replica_groups={{0,1}}, dimensions={0}, to_apply=add
- rs1 = f16[10000000]{0} reduce-scatter(p1), replica_groups={{0,1}}, dimensions={0}, to_apply=add
+ rs0 = f16[10000000]{0} reduce-scatter(p0), replica_groups={{0,1}}, dimensions={0}, to_apply=add,
+ frontend_attributes={sync_collective="true"}
+ rs1 = f16[10000000]{0} reduce-scatter(p1), replica_groups={{0,1}}, dimensions={0}, to_apply=add,
+ frontend_attributes={sync_collective="true"}
 ROOT result = tuple(rs0, rs1)
 }
 )";
@@ -379,13 +381,7 @@ TEST_F(GpuReduceScatterCombinerTest, CombinesSynchronousCollectivesMaximally) {
 /*combine_threshold_count=*/256,
 /*combine_by_dim=*/false,
 /*pointer_size=*/4);
- EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(false));
-
- module->mutable_config()
- .mutable_debug_options()
- .set_xla_gpu_experimental_enable_sync_collective_combining(true);
 EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(true));
-
 Matcher<const HloInstruction*> combined_reduce_scatter =
 op::ReduceScatter(op::Parameter(0), op::Parameter(1));
 EXPECT_THAT(module->entry_computation()->root_instruction(),

From fbfb80b0b0ebd8a892f931f482b2221a916612fb Mon Sep 17 00:00:00 2001
From: Joshua Lang
Date: Wed, 2 Apr 2025 11:49:18 -0700
Subject: [PATCH 0164/1324] [XLA:GPU] update nvjitlink and compilation_provider tests to support cuda 12.8

Update test expectations because CUDA 12.8 changes nvJitLinkCreate's behavior
to fail when an invalid sm architecture is provided.
Update nvJitLinkDestroy usage, as ASan detects a leak if it is not run after
an nvJitLinkCreate failure.

PiperOrigin-RevId: 743222408
---
 .../xla/xla/stream_executor/cuda/BUILD | 1 +
 .../stream_executor/cuda/nvjitlink_impl.cc | 20 +++++++++++++++----
 .../stream_executor/cuda/nvjitlink_test.cc | 4 +++-
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD
index 1344ae75dfd015..4dd29c210fc903 100644
--- a/third_party/xla/xla/stream_executor/cuda/BUILD
+++ b/third_party/xla/xla/stream_executor/cuda/BUILD
@@ -930,6 +930,7 @@ xla_cc_test(
 ":nvjitlink_support",
 "//xla/stream_executor:device_description",
 "//xla/stream_executor/gpu:gpu_asm_opts",
+ "//xla/tsl/platform:status_matchers",
 "@com_google_absl//absl/status",
 "@com_google_absl//absl/strings",
 "@com_google_absl//absl/types:span",
diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc
index 6d8f298fe25622..445e96c2d9c68b 100644
--- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc
+++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc
@@ -163,13 +163,25 @@ absl::StatusOr<std::vector<uint8_t>> CompileAndLinkUsingLibNvJitLink(
 absl::c_transform(cli_args, std::back_inserter(cli_args_ptrs),
 [](const std::string& s) { return s.c_str(); });

- nvJitLinkHandle link_handle{};
- RETURN_IF_NVJITLINK_ERROR(nvJitLinkCreate(&link_handle, cli_args_ptrs.size(),
- cli_args_ptrs.data()));
+ nvJitLinkHandle link_handle = nullptr;
+ nvJitLinkResult create_result =
+ nvJitLinkCreate(&link_handle, cli_args_ptrs.size(), cli_args_ptrs.data());
+
 absl::Cleanup link_handle_cleaner = [&link_handle] {
- CHECK_EQ(nvJitLinkDestroy(&link_handle), NVJITLINK_SUCCESS);
+ if (link_handle != nullptr) {
+ CHECK_EQ(nvJitLinkDestroy(&link_handle), NVJITLINK_SUCCESS);
+ }
 };
+
+ if (create_result != NVJITLINK_SUCCESS) {
+ TF_ASSIGN_OR_RETURN(std::string error_log,
+ nvJitLinkGetErrorLog(link_handle));
+
+ VLOG(3) << "libnvjitlink error log output: " << error_log;
+
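+ // Surface the linker's own error log in the returned status; with CUDA
+ // 12.8, nvJitLinkCreate itself fails when given an invalid sm
+ // architecture, so this is the first point where a useful message exists.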
+ return ToStatus(create_result, error_log);
+ }
+
 for (auto& image : inputs) {
 nvJitLinkInputType input_type = image.type == NvJitLinkInput::Type::kPtx
 ? NVJITLINK_INPUT_PTX
diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_test.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_test.cc
index f3d30ae522f720..84cea34aa59f23 100644
--- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_test.cc
+++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "xla/stream_executor/cuda/nvjitlink_support.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/gpu/gpu_asm_opts.h"
+#include "xla/tsl/platform/status_matchers.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/test.h"
@@ -165,7 +166,8 @@ TEST_F(NvJitLinkTest, IdentifiesUnsupportedArchitecture) {
 EXPECT_THAT(
 CompileAndLinkHelper(stream_executor::CudaComputeCapability{100, 0},
 {kStandalonePtx}),
- tsl::testing::StatusIs(absl::StatusCode::kUnimplemented));
+ tsl::testing::StatusIs(testing::AnyOf(absl::StatusCode::kUnknown,
+ absl::StatusCode::kUnimplemented)));
 }

 TEST_F(NvJitLinkTest, LinkingTwoCompilationUnitsSucceeds) {

From 9e71b964a5b318784cee208e8ab798ab61c352fa Mon Sep 17 00:00:00 2001
From: Deqiang Chen
Date: Wed, 2 Apr 2025 12:21:21 -0700
Subject: [PATCH 0165/1324] Fix: make mlrt CaseOp kernel behave the same as legacy tfrt

PiperOrigin-RevId: 743234704
---
 .../tfrt/mlrt/interpreter/builtin_kernels.cc | 6 +--
 .../tfrt/mlrt/interpreter/interpreter_test.cc | 37 +++++++++++++++++++
 2 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc b/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc
index 8ca71ba8e25b88..fdac986b64e4c6 100644
--- a/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc
+++ b/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc
@@ -142,10 +142,8 @@ void CaseOp::Invoke() {
 mlrt::bc::Vector attribute_function_indices = function_indices();

 if (argument_branch_idx >= attribute_function_indices.size()) {
- execution_context().Fail(absl::InvalidArgumentError(
- absl::StrCat("Case branch number ", argument_branch_idx,
- " exceeds limit ", attribute_function_indices.size())));
- return;
+ // Consistent with the behavior of the legacy TFRT case kernel.
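+ // For example, with the two-branch executable in the interpreter test
+ // below, an out-of-range branch index such as 10 dispatches to branch 1,
+ // the last branch.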
+ argument_branch_idx = attribute_function_indices.size() - 1; } auto function = diff --git a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc index 34b53ac1fb1bd6..568949ac3792c0 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc +++ b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc @@ -2825,6 +2825,43 @@ TEST(KernelTest, Case) { } } +TEST(KernelTest, CaseInvalidBranchIndexShallChooseLastBranch) { + auto buffer = CreateCaseExecutable(); + + bc::Executable executable(buffer.data()); + + KernelRegistry registry; + RegisterBuiltinKernels(registry); + LoadedExecutable loaded_executable(executable, registry); + + ExecutionContext execution_context(&loaded_executable); + + auto function = loaded_executable.GetFunction("main"); + ASSERT_TRUE(function); + + Value inputs[3]; + + constexpr int32_t kBranch0In = 123; + constexpr int32_t kBranch1In = 456; + + // Test Invalid Branch 10 + { + inputs[0].Set(10); + inputs[1].Set(kBranch0In); + inputs[2].Set(kBranch1In); + Value output; + + std::vector last_uses = {true, true, true}; + execution_context.Call(function, last_uses, absl::MakeSpan(inputs), + absl::Span(&output, 1)); + + Execute(execution_context); + + ASSERT_TRUE(output.HasValue()); + EXPECT_EQ(kBranch1In, output.Get()); + } +} + struct TestPromiseReturnOp : PromiseReturnOpBase { using PromiseReturnOpBase::PromiseReturnOpBase; From 28c32ab79c6815f37b41f24202ed21c92e18d48b Mon Sep 17 00:00:00 2001 From: Aravindh Balaji <77819568+aravindhbalaji1985@users.noreply.github.com> Date: Wed, 2 Apr 2025 12:41:10 -0700 Subject: [PATCH 0166/1324] Update gpu_prim.h Applying https://github.com/tensorflow/tensorflow/issues/16095 to avoid compilation error when compiling sparse_grad_op_gpu.cu.cc using clang compiler. --- tensorflow/core/kernels/gpu_prim.h | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h index bef22b50ada12c..52d324c1e6dac6 100644 --- a/tensorflow/core/kernels/gpu_prim.h +++ b/tensorflow/core/kernels/gpu_prim.h @@ -44,10 +44,9 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr( Eigen::numext::bit_cast(val); } -template <> -__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer( +__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer( Eigen::half *ptr, Int2Type /*is_primitive*/) { - uint16_t result = *reinterpret_cast(ptr); + const uint16_t result = *reinterpret_cast(ptr); return Eigen::numext::bit_cast(result); } @@ -59,10 +58,8 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr( Eigen::numext::bit_cast(val); } -template <> -__device__ __forceinline__ Eigen::bfloat16 -ThreadLoadVolatilePointer(Eigen::bfloat16 *ptr, - Int2Type /*is_primitive*/) { +__device__ __forceinline__ Eigen::bfloat16 ThreadLoadVolatilePointer( + Eigen::bfloat16 *ptr, Int2Type /*is_primitive*/) { uint16_t result = *reinterpret_cast(ptr); return Eigen::numext::bit_cast(result); } From 8b9abc68b459658820c0a7019d38dc589738fde9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 12:38:39 -0700 Subject: [PATCH 0167/1324] Style improvements: - Remove uses of `const_cast` in inputbuffer. In general, `const_cast` is unsafe and can easily lead to undefined behavior. - Use `std::unique_ptr` instead of raw pointer to manage the buffer. 
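For background, a minimal illustration (not code from this change) of why
casting away constness is dangerous: writing through the resulting pointer is
only defined if the underlying object was not itself declared const.

```
const int k = 42;
int* p = const_cast<int*>(&k);  // compiles fine, but...
*p = 7;  // undefined behavior: the object k really is const
```

The rewrite below avoids the question entirely by re-deriving a mutable
position from the buffer owned via std::unique_ptr instead of const_casting
the const char* returned by the varint parsers.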
PiperOrigin-RevId: 743240541 --- third_party/xla/xla/tsl/lib/io/BUILD | 3 + third_party/xla/xla/tsl/lib/io/inputbuffer.cc | 58 +++++++++++-------- third_party/xla/xla/tsl/lib/io/inputbuffer.h | 45 +++++++++----- 3 files changed, 67 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/tsl/lib/io/BUILD b/third_party/xla/xla/tsl/lib/io/BUILD index d1a85004627caa..3b90095c60a1ce 100644 --- a/third_party/xla/xla/tsl/lib/io/BUILD +++ b/third_party/xla/xla/tsl/lib/io/BUILD @@ -89,6 +89,9 @@ cc_library( "//xla/tsl/platform:macros", "//xla/tsl/platform:status", "//xla/tsl/platform:types", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:coding", ], alwayslink = True, diff --git a/third_party/xla/xla/tsl/lib/io/inputbuffer.cc b/third_party/xla/xla/tsl/lib/io/inputbuffer.cc index e7823794df8f76..a3fff3b3127671 100644 --- a/third_party/xla/xla/tsl/lib/io/inputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer.cc @@ -16,7 +16,13 @@ limitations under the License. #include "xla/tsl/lib/io/inputbuffer.h" #include +#include +#include +#include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" @@ -28,18 +34,18 @@ InputBuffer::InputBuffer(RandomAccessFile* file, size_t buffer_bytes) file_pos_(0), size_(buffer_bytes), buf_(new char[size_]), - pos_(buf_), - limit_(buf_) {} + pos_(buf_.get()), + limit_(buf_.get()) {} -InputBuffer::~InputBuffer() { delete[] buf_; } +InputBuffer::~InputBuffer() = default; absl::Status InputBuffer::FillBuffer() { absl::string_view data; - absl::Status s = file_->Read(file_pos_, size_, &data, buf_); - if (data.data() != buf_) { - memmove(buf_, data.data(), data.size()); + absl::Status s = file_->Read(file_pos_, size_, &data, buf()); + if (data.data() != buf()) { + memmove(buf(), data.data(), data.size()); } - pos_ = buf_; + pos_ = buf(); limit_ = pos_ + data.size(); file_pos_ += data.size(); return s; @@ -50,12 +56,13 @@ absl::Status InputBuffer::ReadLine(T* result) { result->clear(); absl::Status s; do { - size_t buf_remain = limit_ - pos_; - char* newline = static_cast(memchr(pos_, '\n', buf_remain)); + const int buf_remain = num_remaining_bytes(); + const char* const newline = + static_cast(memchr(pos_, '\n', buf_remain)); if (newline != nullptr) { - size_t result_len = newline - pos_; + const int result_len = newline - pos_; result->append(pos_, result_len); - pos_ = newline + 1; + pos_ += result_len + 1; if (!result->empty() && result->back() == '\r') { result->resize(result->size() - 1); } @@ -64,8 +71,8 @@ absl::Status InputBuffer::ReadLine(T* result) { if (buf_remain > 0) result->append(pos_, buf_remain); // Get more data into buffer s = FillBuffer(); - DCHECK_EQ(pos_, buf_); - } while (limit_ != buf_); + DCHECK_EQ(pos_, buf()); + } while (limit_ != buf()); if (!result->empty() && result->back() == '\r') { result->resize(result->size() - 1); } @@ -104,13 +111,13 @@ absl::Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, if (pos_ == limit_) { // Get more data into buffer. status = FillBuffer(); - if (limit_ == buf_) { + if (limit_ == buf()) { break; } } // Do not go over the buffer boundary. const int64_t bytes_to_copy = - std::min(limit_ - pos_, bytes_to_read - *bytes_read); + std::min(num_remaining_bytes(), bytes_to_read - *bytes_read); // Copies buffered data into the destination. 
memcpy(result + *bytes_read, pos_, bytes_to_copy); pos_ += bytes_to_copy; @@ -166,12 +173,12 @@ absl::Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { if (pos_ == limit_) { // Get more data into buffer s = FillBuffer(); - if (limit_ == buf_) { + if (limit_ == buf()) { break; } } const int64_t bytes_to_advance = - std::min(limit_ - pos_, bytes_to_skip - bytes_skipped); + std::min(num_remaining_bytes(), bytes_to_skip - bytes_skipped); bytes_skipped += bytes_to_advance; pos_ += bytes_to_advance; } @@ -187,14 +194,15 @@ absl::Status InputBuffer::Seek(int64_t position) { position); } // Position of the buffer within file. - const int64_t bufpos = file_pos_ - static_cast(limit_ - buf_); + const int64_t bufpos = file_pos_ - static_cast(limit_ - buf()); if (position >= bufpos && position < file_pos_) { // Seeks to somewhere inside the buffer. - pos_ = buf_ + (position - bufpos); - DCHECK(pos_ >= buf_ && pos_ < limit_); + pos_ = buf() + position - bufpos; + DCHECK_GE(pos_, buf()); + DCHECK_LT(pos_, limit_); } else { // Seeks to somewhere outside. Discards the buffered data. - pos_ = limit_ = buf_; + pos_ = limit_ = buf(); file_pos_ = position; } return absl::OkStatus(); @@ -211,7 +219,7 @@ absl::Status InputBuffer::Hint(int64_t bytes_to_read) { return absl::OkStatus(); } - const int64_t bytes_remain_in_buf = static_cast(limit_ - pos_); + const int64_t bytes_remain_in_buf = num_remaining_bytes(); // There are enough data in the buffer. Do nothing. if (bytes_to_read <= bytes_remain_in_buf) { @@ -219,9 +227,9 @@ absl::Status InputBuffer::Hint(int64_t bytes_to_read) { } // Additional read from file is necessary. Make some room. - memmove(buf_, pos_, bytes_remain_in_buf); - pos_ = buf_; - limit_ = buf_ + bytes_remain_in_buf; + memmove(buf(), pos_, bytes_remain_in_buf); + pos_ = buf(); + limit_ = buf() + bytes_remain_in_buf; bytes_to_read -= bytes_remain_in_buf; // Read the remaining bytes from file. diff --git a/third_party/xla/xla/tsl/lib/io/inputbuffer.h b/third_party/xla/xla/tsl/lib/io/inputbuffer.h index 1d9db6bf19c5ad..5dd9923d248fb3 100644 --- a/third_party/xla/xla/tsl/lib/io/inputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer.h @@ -16,9 +16,15 @@ limitations under the License. #ifndef XLA_TSL_LIB_IO_INPUTBUFFER_H_ #define XLA_TSL_LIB_IO_INPUTBUFFER_H_ +#include +#include +#include #include +#include "absl/status/status.h" #include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_system.h" #include "xla/tsl/platform/macros.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/types.h" @@ -80,7 +86,7 @@ class InputBuffer { absl::Status Hint(int64_t bytes_to_read); // Returns the position in the file. - int64_t Tell() const { return file_pos_ - (limit_ - pos_); } + int64_t Tell() const { return file_pos_ - num_remaining_bytes(); } // Returns the underlying RandomAccessFile. RandomAccessFile* file() const { return file_; } @@ -100,13 +106,22 @@ class InputBuffer { template absl::Status ReadVarintFallback(T* result, int max_bytes); - RandomAccessFile* file_; // Not owned - int64_t file_pos_; // Next position to read from in "file_" - size_t size_; // Size of "buf_" - char* buf_; // The buffer itself - // [pos_,limit_) hold the "limit_ - pos_" bytes just before "file_pos_" - char* pos_; // Current position in "buf" - char* limit_; // Just past end of valid data in "buf" + // The buffer itself. + char* buf() { return buf_.get(); } + const char* buf() const { return buf_.get(); } + + // Number of bytes remaining in the buffer. 
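+ // Replaces the raw "limit_ - pos_" arithmetic previously repeated in
+ // ReadLine, ReadNBytes, SkipNBytes, Tell and Hint.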
+ int num_remaining_bytes() const { return limit_ - pos_; } + + RandomAccessFile* const file_ = nullptr; // Not owned + int64_t file_pos_ = 0; // Next position to read from in "file_" + const int size_ = 0; // Size of "buf_" + const std::unique_ptr buf_; // The buffer itself. Must not be null. + // [pos_,limit_) hold the "limit_ - pos_" bytes just before "file_pos_". + char* pos_ = + nullptr; // Current position in "buf". Must be in [buf(), buf() + size_]. + char* limit_ = nullptr; // Just past end of valid data in "buf". Must be in + // [buf(), buf() + size_]. InputBuffer(const InputBuffer&) = delete; void operator=(const InputBuffer&) = delete; @@ -123,9 +138,10 @@ inline absl::Status InputBuffer::ReadVarint32(uint32* result) { if (pos_ + core::kMaxVarint32Bytes <= limit_) { // Fast path: directly parse from buffered data. // Reads strictly from the range [pos_, limit_). - const char* offset = core::GetVarint32Ptr(pos_, limit_, result); - if (offset == nullptr) return errors::OutOfRange("Parsed past limit."); - pos_ = const_cast(offset); + const char* const new_pos = core::GetVarint32Ptr(pos_, limit_, result); + if (new_pos == nullptr) return errors::OutOfRange("Parsed past limit."); + const int offset = new_pos - buf(); + pos_ = buf() + offset; return absl::OkStatus(); } else { return ReadVarint32Fallback(result); @@ -137,9 +153,10 @@ inline absl::Status InputBuffer::ReadVarint64(uint64* result) { if (pos_ + core::kMaxVarint64Bytes <= limit_) { // Fast path: directly parse from buffered data. // Reads strictly from the range [pos_, limit_). - const char* offset = core::GetVarint64Ptr(pos_, limit_, result); - if (offset == nullptr) return errors::OutOfRange("Parsed past limit."); - pos_ = const_cast(offset); + const char* const new_pos = core::GetVarint64Ptr(pos_, limit_, result); + if (new_pos == nullptr) return errors::OutOfRange("Parsed past limit."); + const int offset = new_pos - buf(); + pos_ = buf() + offset; return absl::OkStatus(); } else { return ReadVarint64Fallback(result); From 1d3cca1a6175003790d1c58f75b24130e8bcfe16 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Wed, 2 Apr 2025 12:41:49 -0700 Subject: [PATCH 0168/1324] Direct StableHLO to HLO conversion : prototype with AddOp and ConstantOp. The cl demonstrates 1. codegen for stablehlo ops. 2. changes needed in conversion pipeline to allow stablehlo ops to convert directly to hlo without mhlo step. Note: The new direct path is still disabled for the production until we add all stablehlo ops to the codegen. Example: ``` hlo-translate --mlir-to-hlo $PWD/1.mlir ``` Input ``` func.func @main(%arg0: tensor) -> tensor { %c = stablehlo.constant dense<2> : tensor %0 = stablehlo.add %arg0, %c : tensor %1 = mhlo.multiply %c, %0 : tensor return %1 : tensor } ``` after StableHLO -> MHLO conversion, stablehlo.add and stablehlo.constant are preserved, not converted to mhlo. 
``` mlir-hlo-opt --stablehlo-legalize-to-hlo=legalize-partially=true --chlo-legalize-to-hlo $PWD/1.mlir module { func.func @main(%arg0: tensor) -> tensor { %c = stablehlo.constant dense<2> : tensor %0 = stablehlo.add %arg0, %c : tensor %1 = mhlo.multiply %c, %0 : tensor return %1 : tensor } } ``` Final Output ``` HloModule main, entry_computation_layout={(s32[])->s32[]} ENTRY %main.5 (Arg_0.1: s32[]) -> s32[] { %Arg_0.1 = s32[] parameter(0) %constant.2 = s32[] constant(2) %add.3 = s32[] add(%Arg_0.1, %constant.2), metadata= ROOT %multiply.4 = s32[] multiply(%add.3, %constant.2), metadata= } ``` PiperOrigin-RevId: 743241700 --- .../xla/xla/hlo/translate/mhlo_to_hlo/BUILD | 2 + .../mhlo_to_hlo/gen_hlo_op_writer.cc | 65 +++++-- .../mhlo_to_hlo/gen_hlo_op_writer.td | 177 ++++++++++++++++++ .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 16 +- third_party/xla/xla/hlo/translate/tests/BUILD | 2 + .../xla/hlo/translate/tests/stablehlo.mlir | 15 ++ .../stablehlo_legalize_to_hlo_pass.cc | 6 +- 7 files changed, 262 insertions(+), 21 deletions(-) create mode 100644 third_party/xla/xla/hlo/translate/tests/stablehlo.mlir diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD index ecd5ac12539432..7ad5ef164428bc 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD @@ -196,6 +196,7 @@ cc_binary( name = "gen_hlo_op_writer", srcs = ["gen_hlo_op_writer.cc"], deps = [ + "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Support", "@llvm-project//llvm:TableGen", "@llvm-project//mlir:TableGen", @@ -213,6 +214,7 @@ gentbl_cc_library( "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@stablehlo//:stablehlo_ops_td_filegroup", ], ) diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.cc index 7ca8574eb7ab12..6d99f7d0b26993 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include +#include "absl/container/flat_hash_set.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" @@ -152,21 +154,42 @@ static bool OperatorWritersMain(raw_ostream& os, const RecordKeeper& records) { llvm::errs() << "Failed to get CustomHloConverterOps list\n"; return false; } - // Convert the list to a set for faster lookups. std::unordered_set custom_convert_op_names; for (const auto* op_def : custom_convert_op_defs->getValues()) custom_convert_op_names.insert(op_def->getAsString()); + // Get the list of StableHLO operations that are allowed to be directly + // converted to HLO without intermediate MHLO step. + const auto* hlo_conversion_allowed_op_defs = + llvm::dyn_cast_or_null( + records.getGlobal("HloConversionAllowedOps")); + if (!hlo_conversion_allowed_op_defs) { + llvm::errs() << "Failed to get HloConversionAllowedOps list\n"; + return false; + } + // Convert the list to a set for faster lookups. 
+ absl::flat_hash_set hlo_conversion_allowed_op_names; + for (const auto* op_def : hlo_conversion_allowed_op_defs->getValues()) + hlo_conversion_allowed_op_names.insert(op_def->getAsString()); + emitSourceFileHeader("MLIR XLA Builders", os); - // Emit all the helper functions. - for (const auto* def : records.getAllDerivedDefinitions("MHLO_Op")) { - Operator op(def); + // Emit all the HLO writers. + std::array dialect_defs = {"MHLO_Op", "StableHLO_Op"}; + for (auto dialect_def : dialect_defs) { + for (const auto* def : records.getAllDerivedDefinitions(dialect_def)) { + Operator op(def); + + if (dialect_def == "StableHLO_Op" && + !(hlo_conversion_allowed_op_names.contains(def->getName().str()))) { + continue; + } + + if (custom_convert_op_names.count(def->getName().str()) > 0) continue; - // Skip operations that have a custom exporter. - if (!(custom_convert_op_names.count(def->getName().str()) > 0)) BuildOperator(op, os); + } } // Emit a function to generate an XLA operation for the operations with @@ -193,22 +216,28 @@ static bool OperatorWritersMain(raw_ostream& os, const RecordKeeper& records) { "op, lowering_context.frame_index_builder));\n\n"; // Retrieve all the definitions derived from MHLO_Op and sort by record name. - for (const auto* def : records.getAllDerivedDefinitions("MHLO_Op")) { - // Skip operations that have a custom exporter. - Operator op(def); + for (auto dialect_def : dialect_defs) { + for (const auto* def : records.getAllDerivedDefinitions(dialect_def)) { + // Skip operations that have a custom exporter. + Operator op(def); + + if (dialect_def == "StableHLO_Op" && + !(hlo_conversion_allowed_op_names.contains(def->getName().str()))) { + continue; + } - // Cast to the current operation and build the exporter. - os << " if (auto xla_op = llvm::dyn_cast<" << op.getCppNamespace() - << "::" << op.getCppClassName() << ">(op)) {\n"; - os << " return "; + // Cast to the current operation and build the exporter. + os << " if (auto xla_op = llvm::dyn_cast<" << op.getCppNamespace() + << "::" << op.getCppClassName() << ">(op)) {\n"; + os << " return "; - if (custom_convert_op_names.count(def->getName().str()) > 0) - os << op.getCppNamespace() << "::"; + if (custom_convert_op_names.count(def->getName().str()) > 0) + os << op.getCppNamespace() << "::"; - os << "ExportXlaOp(xla_op, lowering_context);\n"; - os << " }\n"; + os << "ExportXlaOp(xla_op, lowering_context);\n"; + os << " }\n"; + } } - os << " return mlir::failure();\n" "}\n"; return false; diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td index 8e989cef058f6e..f0e61f04070e5f 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td @@ -14,9 +14,186 @@ limitations under the License. ==============================================================================*/ include "xla/mlir_hlo/mhlo/IR/hlo_ops.td" +include "stablehlo/dialect/StablehloOps.td" + +// List of StableHLO ops that are allowed to be directly converted to HLO +// without intermediate MHLO step. 
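+// Only the prototype subset (AddOp, ConstantOp) is enabled so far; the
+// remaining entries stay commented out until their HLO writers exist.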
+defvar HloConversionAllowedOps = [ + // StableHLO_AbsOp, + StableHLO_AddOp, + // StableHLO_AfterAllOp, + // StableHLO_AllGatherOp, + // StableHLO_AllReduceOp, + // StableHLO_AllToAllOp, + // StableHLO_AndOp, + // StableHLO_Atan2Op, + // StableHLO_BatchNormGradOp, + // StableHLO_BatchNormInferenceOp, + // StableHLO_BatchNormTrainingOp, + // StableHLO_BitcastConvertOp, + // StableHLO_BroadcastInDimOp, + // StableHLO_BroadcastOp, + // StableHLO_CaseOp, + // StableHLO_CbrtOp, + // StableHLO_CeilOp, + // StableHLO_CholeskyOp, + // StableHLO_ClampOp, + // StableHLO_ClzOp, + // StableHLO_CollectiveBroadcastOp, + // StableHLO_CollectivePermuteOp, + // StableHLO_CompareOp, + // StableHLO_ComplexOp, + // StableHLO_CompositeOp, + // StableHLO_ConcatenateOp, + StableHLO_ConstantOp, + // StableHLO_ConvertOp, + // StableHLO_ConvolutionOp, + // StableHLO_CosineOp, + // StableHLO_CreateTokenOp, + // StableHLO_CrossReplicaSumOp, + // StableHLO_CustomCallOp, + // StableHLO_DivOp, + // StableHLO_DotGeneralOp, + // StableHLO_DotOp, + // StableHLO_DynamicBroadcastInDimOp, + // StableHLO_DynamicConvOp, + // StableHLO_DynamicGatherOp, + // StableHLO_DynamicIotaOp, + // StableHLO_DynamicPadOp, + // StableHLO_DynamicReshapeOp, + // StableHLO_DynamicSliceOp, + // StableHLO_DynamicUpdateSliceOp, + // StableHLO_EinsumOp, + // StableHLO_Expm1Op, + // StableHLO_ExpOp, + // StableHLO_FftOp, + // StableHLO_FloorOp, + // StableHLO_GatherOp, + // StableHLO_GetDimensionSizeOp, + // StableHLO_GetTupleElementOp, + // StableHLO_IfOp, + // StableHLO_ImagOp, + // StableHLO_InfeedOp, + // StableHLO_IotaOp, + // StableHLO_IsFiniteOp, + // StableHLO_Log1pOp, + // StableHLO_LogisticOp, + // StableHLO_LogOp, + // StableHLO_MapOp, + // StableHLO_MaxOp, + // StableHLO_MinOp, + // StableHLO_MulOp, + // StableHLO_NegOp, + // StableHLO_NotOp, + // StableHLO_OptimizationBarrierOp, + // StableHLO_OrOp, + // StableHLO_OutfeedOp, + // StableHLO_PadOp, + // StableHLO_PartitionIdOp, + // StableHLO_PopulationCountOp, + // StableHLO_PowOp, + // StableHLO_RealDynamicSliceOp, + // StableHLO_RealOp, + // StableHLO_RecvOp, + // StableHLO_ReduceOp, + // StableHLO_ReducePrecisionOp, + // StableHLO_ReduceScatterOp, + // StableHLO_ReduceWindowOp, + // StableHLO_RemOp, + // StableHLO_ReplicaIdOp, + // StableHLO_ReshapeOp, + // StableHLO_ReturnOp, + // StableHLO_ReverseOp, + // StableHLO_RngBitGeneratorOp, + // StableHLO_RngOp, + // StableHLO_RoundNearestEvenOp, + // StableHLO_RoundOp, + // StableHLO_RsqrtOp, + // StableHLO_ScatterOp, + // StableHLO_SelectAndScatterOp, + // StableHLO_SelectOp, + // StableHLO_SendOp, + // StableHLO_SetDimensionSizeOp, + // StableHLO_ShiftLeftOp, + // StableHLO_ShiftRightArithmeticOp, + // StableHLO_ShiftRightLogicalOp, + // StableHLO_SignOp, + // StableHLO_SineOp, + // StableHLO_SliceOp, + // StableHLO_SortOp, + // StableHLO_SqrtOp, + // StableHLO_SubtractOp, + // StableHLO_TanhOp, + // StableHLO_TanOp, + // StableHLO_TorchIndexSelectOp, + // StableHLO_TransposeOp, + // StableHLO_TriangularSolveOp, + // StableHLO_TupleOp, + // StableHLO_UnaryEinsumOp, + // StableHLO_UniformDequantizeOp, + // StableHLO_UniformQuantizeOp, + // StableHLO_WhileOp, + // StableHLO_XorOp, +]; // List of StableHLO and MHLO ops that need a custom HLO converter. 
defvar CustomHloConverterOps = [ + // StableHLO ops + // go/keep-sorted start + // StableHLO_AllGatherOp, + // StableHLO_AllReduceOp, + // StableHLO_AllToAllOp, + // StableHLO_BatchNormGradOp, + // StableHLO_BatchNormTrainingOp, + // StableHLO_BitcastConvertOp, + // StableHLO_BroadcastInDimOp, + // StableHLO_CaseOp, + // StableHLO_CollectiveBroadcastOp, + // StableHLO_CompareOp, + // StableHLO_CompositeOp, + StableHLO_ConstantOp, + // StableHLO_ConvertOp, + // StableHLO_ConvolutionOp, + // StableHLO_CosineOp, + // StableHLO_CustomCallOp, + // StableHLO_DotGeneralOp, + // StableHLO_DotOp, + // StableHLO_DynamicBroadcastInDimOp, + // StableHLO_DynamicConvOp, + // StableHLO_DynamicGatherOp, + // StableHLO_DynamicIotaOp, + // StableHLO_DynamicPadOp, + // StableHLO_DynamicReshapeOp, + // StableHLO_IfOp, + // StableHLO_InfeedOp, + // StableHLO_IotaOp, + // StableHLO_MapOp, + // StableHLO_OptimizationBarrierOp, + // StableHLO_OutfeedOp, + // StableHLO_PadOp, + // StableHLO_PartitionIdOp, + // StableHLO_RealDynamicSliceOp, + // StableHLO_RecvOp, + // StableHLO_ReduceOp, + // StableHLO_ReduceScatterOp, + // StableHLO_ReduceWindowOp, + // StableHLO_ReshapeOp, + // StableHLO_ReturnOp, + // StableHLO_RngBitGeneratorOp, + // StableHLO_RngOp, + // StableHLO_ScatterOp, + // StableHLO_SelectAndScatterOp, + // StableHLO_SendOp, + // StableHLO_SetDimensionSizeOp, + // StableHLO_SineOp, + // StableHLO_SortOp, + // StableHLO_SubtractOp, + // StableHLO_TanOp, + // StableHLO_UniformDequantizeOp, + // StableHLO_UniformQuantizeOp, + // StableHLO_WhileOp, + // go/keep-sorted end + // MHLO ops. // go/keep-sorted start MHLO_AddDependencyOp, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 42b41b974c0e69..7ec6c4d8989b79 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -64,6 +64,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/RegionUtils.h" #include "stablehlo/dialect/Base.h" +#include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" #include "xla/array.h" #include "xla/comparison_util.h" @@ -1086,6 +1087,17 @@ void BuildGetTupleElementsForTupleResults(mlir::Operation* op, xla::XlaOp tuple, } // namespace namespace mlir { + +namespace stablehlo { +namespace { + +LogicalResult ExportXlaOp(ConstantOp op, OpLoweringContext ctx) { + return failure(); +} + +} // namespace +} // namespace stablehlo + namespace mhlo { namespace { LogicalResult ExportXlaOp(CollectiveBroadcastOp op, OpLoweringContext ctx) { @@ -3639,8 +3651,8 @@ LogicalResult ConvertToHloModule::Lower( ConvertToHloModule::ValueLoweringMap* value_lowering, xla::XlaOp* return_value) { // Explicitly fail for ops that are not supported for export. 
- if (inst->getDialect() != - inst->getContext()->getLoadedDialect() && + if (!mlir::isa( + inst->getDialect()) && !mlir::isa(inst)) { diff --git a/third_party/xla/xla/hlo/translate/tests/BUILD b/third_party/xla/xla/hlo/translate/tests/BUILD index 68bb7fc8317101..d003281e4baeb8 100644 --- a/third_party/xla/xla/hlo/translate/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/tests/BUILD @@ -18,6 +18,7 @@ lit_test_suite( "print_layouts.mlir", "simple.hlo", "simple.mlir", + "stablehlo.mlir", "vhlo_input.mlir", # go/keep-sorted end ], @@ -33,6 +34,7 @@ lit_test_suite( ], tools = [ "//xla/hlo/tools:hlo-translate", + "//xla/mlir_hlo:mlir-hlo-opt", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir new file mode 100644 index 00000000000000..7e782eb10f3c97 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir @@ -0,0 +1,15 @@ +// RUN: hlo-translate -mlir-to-hlo %s | FileCheck %s +// RUN: mlir-hlo-opt --stablehlo-legalize-to-hlo=convert-xla-supported-stablehlo=false %s | FileCheck %s --check-prefix CHECK-DIRECT + +// Tests for all stablehlo ops to validate stablehlo -> hlo conversion. + + +// CHECK-LABEL: HloModule + +// CHECK: %[[ARG0:.*]] = f32[4] parameter(0) +// CHECK: %[[ARG1:.*]] = f32[4] parameter(1) +// CHECK: ROOT %add.3 = f32[4] add(%[[ARG0]], %[[ARG1]]) +func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32> func.return %0 : tensor<4xf32> +} +// CHECK-DIRECT: stablehlo.add diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc index 9a06ce80bbccec..9257d88984b097 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc @@ -39,6 +39,10 @@ namespace mhlo { namespace { +void legalDirectStablehloToHloConversionOps(ConversionTarget& target) { + target.addLegalOp(); +} + struct StablehloLegalizeToHloPass : public impl::StablehloLegalizeToHloPassBase { using StablehloLegalizeToHloPassBase::StablehloLegalizeToHloPassBase; @@ -49,7 +53,7 @@ struct StablehloLegalizeToHloPass // Allow injecting legal ops to permit gradual migration. if (!convert_xla_supported_stablehlo_) { - target.addLegalOp(); + legalDirectStablehloToHloConversionOps(target); } stablehlo::StablehloToHloTypeConverter converter; From d7b2ae622162785d18116314d888532b7cec363f Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 2 Apr 2025 13:04:12 -0700 Subject: [PATCH 0169/1324] [XLA:GPU][NFC] Turn a crash into a graceful failure to avoid interrupting test runs when we hit it while developing. 
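The mechanical shape of the change, sketched (see the diff below for the real
condition and message):

```
// Before: QCHECK aborts the whole process, taking the test run down with it.
QCHECK(!condition);

// After: the same violation is reported as a recoverable error.
if (condition) {
  return absl::FailedPreconditionError("explanatory message");
}
```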
PiperOrigin-RevId: 743249260 --- .../xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 4b539922d183cc..d2ca7eadac1dfa 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -1658,7 +1658,11 @@ absl::StatusOr CreateTritonModule( if (fusion_kind == kTritonGemmFusionKind) { // If the generic Triton emitter is enabled, we should never go through the // legacy MatMul emitter. - QCHECK(!UseGenericTritonEmitterForGemms(fusion)); + if (UseGenericTritonEmitterForGemms(fusion)) { + return absl::FailedPreconditionError( + "The generic Triton emitter is enabled, but the legacy MatMul " + "emitter is being used."); + } TF_ASSIGN_OR_RETURN(tma_metadata, EmitMatMul(b, libdevice_path, device_info, fusion, fn, block_level_parameters)); From dd29c8f5dc753d05b4587f72ff5af44c0c9538f5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 13:06:32 -0700 Subject: [PATCH 0170/1324] Allow for viewing of memory_viewer trace properties by memory type. PiperOrigin-RevId: 743250035 --- .../profiler/convert/hlo_to_tools_data.cc | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/profiler/convert/hlo_to_tools_data.cc b/tensorflow/core/profiler/convert/hlo_to_tools_data.cc index 0978f0211d4d8f..ba4a13fa6c52ba 100644 --- a/tensorflow/core/profiler/convert/hlo_to_tools_data.cc +++ b/tensorflow/core/profiler/convert/hlo_to_tools_data.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/platform/statusor.h" @@ -38,12 +39,11 @@ namespace profiler { namespace { absl::StatusOr GetMemoryViewerPreprocessResult( - const xla::HloProto& hlo_proto) { + const xla::HloProto& hlo_proto, int memory_space_color) { static constexpr int kSmallBufferSize = 16 * 1024; // 16KB - static constexpr int kMemorySpaceColor = 0; // HBM auto result_or = ConvertHloProtoToPreprocessResult( - hlo_proto, kSmallBufferSize, kMemorySpaceColor); + hlo_proto, kSmallBufferSize, memory_space_color); if (!result_or.ok()) { return errors::Internal( "Failed to convert HLO proto to memory viewer result: ", @@ -53,8 +53,9 @@ absl::StatusOr GetMemoryViewerPreprocessResult( } absl::StatusOr ConvertHloProtoToMemoryViewer( - const xla::HloProto& hlo_proto) { - auto result_or = GetMemoryViewerPreprocessResult(hlo_proto); + const xla::HloProto& hlo_proto, int memory_space_color) { + auto result_or = + GetMemoryViewerPreprocessResult(hlo_proto, memory_space_color); if (!result_or.ok()) { return result_or.status(); } @@ -75,8 +76,9 @@ absl::StatusOr ConvertHloProtoToMemoryViewer( } absl::StatusOr ConvertHloProtoToAllocationTimeline( - const xla::HloProto& hlo_proto) { - auto result_or = GetMemoryViewerPreprocessResult(hlo_proto); + const xla::HloProto& hlo_proto, int memory_space_color) { + auto result_or = + GetMemoryViewerPreprocessResult(hlo_proto, memory_space_color); if (!result_or.ok()) { return result_or.status(); } @@ -117,11 +119,18 @@ absl::StatusOr ConvertHloProtoToToolData( GetHloProtoByModuleName(session_snapshot, *hlo_module_name)); // Convert from HLO proto to tools data. 
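+ // The "memory_space" request option selects which buffer color to
+ // visualize (0 is the default and was previously hard-coded as the HBM
+ // color); values that fail to parse fall back to 0.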
+ int memory_space_color = 0; + if (!absl::SimpleAtoi( + GetParamWithDefault(options, "memory_space", std::string("0")), + &memory_space_color)) { + memory_space_color = 0; + } + if (tool_name == "memory_viewer") { if (GetParamWithDefault(options, "view_memory_allocation_timeline", 0)) { - return ConvertHloProtoToAllocationTimeline(hlo_proto); + return ConvertHloProtoToAllocationTimeline(hlo_proto, memory_space_color); } - return ConvertHloProtoToMemoryViewer(hlo_proto); + return ConvertHloProtoToMemoryViewer(hlo_proto, memory_space_color); } else if (tool_name == "graph_viewer") { return ConvertHloProtoToGraphViewer(hlo_proto, options); } else { From 4ee57e675d8684410d3d83cebad82adf5ff2bf39 Mon Sep 17 00:00:00 2001 From: Praveen Batra Date: Wed, 2 Apr 2025 13:15:00 -0700 Subject: [PATCH 0171/1324] Remove allow_get_default_platform from XLA_FLAGS and set it directly as an env variable. Setting it as an XLA_FLAG in the build_defs would prevent a user from adding more flags to a test from the command line. Also, use DeepCopy() so that the build_def isn't using the same env variables across multiple targets. PiperOrigin-RevId: 743253142 --- third_party/xla/xla/debug_options_flags.cc | 6 ---- third_party/xla/xla/service/BUILD | 9 +++--- third_party/xla/xla/service/platform_util.cc | 29 +++++++++++++------- third_party/xla/xla/xla.proto | 8 ++---- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 2711a0b7098065..e8547150564a7b 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -338,7 +338,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_unsupported_crash_on_hlo_pass_fix_max_iterations(false); opts.set_xla_hlo_pass_fix_detect_cycles(false); opts.set_xla_gpu_experimental_enable_sync_collective_combining(false); - opts.set_xla_allow_get_default_platform(true); opts.set_xla_unsupported_crash_on_hlo_pass_silent_hlo_change(false); opts.set_xla_unsupported_crash_on_hlo_pass_noop_change(false); return opts; @@ -2316,11 +2315,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, set_xla_gpu_experimental_enable_sync_collective_combining), debug_options->xla_gpu_experimental_enable_sync_collective_combining(), "Enable sync collective combining.")); - flag_list->push_back(tsl::Flag( - "xla_allow_get_default_platform", - bool_setter_for(&DebugOptions::set_xla_allow_get_default_platform), - debug_options->xla_allow_get_default_platform(), - "If false, GetDefaultPlatform will cause an error if called.")); flag_list->push_back(tsl::Flag( "xla_gpu_experimental_collective_cse_distance_threshold", int64_setter_for( diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 6e21551653e39d..cbff4fa8204e07 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1094,15 +1094,16 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - 
"@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/platform_util.cc b/third_party/xla/xla/service/platform_util.cc index a7e6fb88966373..49d6db3af2cc24 100644 --- a/third_party/xla/xla/service/platform_util.cc +++ b/third_party/xla/xla/service/platform_util.cc @@ -15,17 +15,21 @@ limitations under the License. #include "xla/service/platform_util.h" +#include #include #include #include #include +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" #include "xla/service/compiler.h" #include "xla/status_macros.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/host/host_platform_id.h" @@ -33,12 +37,11 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/threadpool.h" #include "xla/util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/threadpool.h" namespace xla { @@ -100,11 +103,17 @@ PlatformUtil::GetSupportedPlatforms() { } absl::StatusOr PlatformUtil::GetDefaultPlatform() { - TF_RET_CHECK(GetDebugOptionsFromFlags().xla_allow_get_default_platform()) - << "--xla_allow_get_default_platform=false means GetDefaultPlatform is " - "not allowed and the platform must be specified. If this is a test " - "that has been migrated to PJRT, double-check that you are using a " - "PJRT-compatible test class."; + const char* maybe_allow_get_default_platform = + getenv("XLA_ALLOW_GET_DEFAULT_PLATFORM"); + if (maybe_allow_get_default_platform != nullptr) { + std::string allow_get_default_platform(maybe_allow_get_default_platform); + TF_RET_CHECK(allow_get_default_platform == "true") + << "GetDefaultPlatform is not allowed (XLA_ALLOW_GET_DEFAULT_PLATFORM=" + << allow_get_default_platform + << ") and the platform must be specified. If this is a test that has " + "been migrated to PJRT, double-check that you are using a " + "PJRT-compatible test class."; + } TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); se::Platform* platform = nullptr; diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index c4e1837272c7b1..924701a09deda4 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -1188,11 +1188,6 @@ message DebugOptions { bool xla_pjrt_allow_auto_layout_in_hlo = 344; - // If false, platform_util::GetDefaultPlatform() will return an error. Used as - // a safeguard for tests where GetDefaultPlatform() won't work properly and - // its use indicates an error in test setup. - bool xla_allow_get_default_platform = 371; - // If true, the test launched with ClientLibraryTestBase will also try to // use command buffer mode to run. 
bool xla_test_add_command_buffer_mode = 373; @@ -1231,8 +1226,9 @@ message DebugOptions { // xla_gpu_enable_bf16_6way_gemm // xla_gpu_enable_cudnn_fmha // xla_gpu_unsupported_force_triton_gemm + // xla_allow_get_default_platform reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320, - 325, 326, 332, 361, 270, 229, 271, 279, 218, 369; + 325, 326, 332, 361, 270, 229, 271, 279, 218, 369, 371; } // Contains flags which affects the GPU compilation result. From d820b0160f63bd8ad844a5d78b65ef7315c72af3 Mon Sep 17 00:00:00 2001 From: Theotime Combes Date: Wed, 2 Apr 2025 13:34:13 -0700 Subject: [PATCH 0172/1324] [XLA:GPU] Add triton support test for after-all, add-dependency, custom-call PiperOrigin-RevId: 743260233 --- .../backends/gpu/codegen/triton/support.cc | 3 - .../gpu/codegen/triton/support_test.cc | 68 ++++++++++++++++++- 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index 9a3d6160535b73..aefc67e95d01c3 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -586,13 +586,10 @@ CodegenDecision IsTritonSupportedInstructionImpl( namespace internal { bool IsTritonUnsupportedOpcode(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kAddDependency: - case HloOpcode::kAfterAll: case HloOpcode::kCholesky: case HloOpcode::kConvolution: case HloOpcode::kCopyDone: case HloOpcode::kCopyStart: - case HloOpcode::kCustomCall: case HloOpcode::kDynamicReshape: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index 819a875214847b..10b964f733bd3a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -2456,16 +2456,75 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(AllDevicesToTest())), TritonSupportTestTwoTypesAndDeviceToString); +using AddDependencyTest = TritonSupportTestWithDeviceParam; + +TEST_P(AddDependencyTest, AddDependency) { + auto cc = GetParam(); + const std::string kHloTestTemplate = R"( + ENTRY triton_computation { + param = f32[10] parameter(0) + token0 = token[] after-all() + ROOT add_dep = f32[10] add-dependency(param, token0) + })"; + TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( + kHloTestTemplate, + F32, // Type is irrelevant. + HloOpcode::kAddDependency)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); +} + +INSTANTIATE_TEST_SUITE_P(AddDependencySuite, AddDependencyTest, + ::testing::ValuesIn(AllDevicesToTest()), + TritonSupportTestDeviceToString); + +using AfterAllTest = TritonSupportTestWithDeviceParam; + +TEST_P(AfterAllTest, AfterAll) { + auto cc = GetParam(); + const std::string kHloTestTemplate = R"( + ENTRY triton_computation { + token0 = token[] after-all() + token1 = token[] after-all() + ROOT token2 = token[] after-all(token0, token1) + })"; + TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( + kHloTestTemplate, + F32, // Type is irrelevant. 
+                                                    HloOpcode::kAfterAll));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{}, cc);
+}
+
+INSTANTIATE_TEST_SUITE_P(AfterAllSuite, AfterAllTest,
+                         ::testing::ValuesIn(AllDevicesToTest()),
+                         TritonSupportTestDeviceToString);
+
+using CustomCallTest = TritonSupportTestWithDeviceParam;
+
+TEST_P(CustomCallTest, CustomCall) {
+  auto cc = GetParam();
+  const std::string kHloTestTemplate = R"(
+  ENTRY triton_computation {
+    parameter = f32[10] parameter(0)
+    ROOT custom_call_op = f32[10] custom-call(parameter), custom_call_target="SomeTarget"
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction(
+                                                    kHloTestTemplate,
+                                                    F32,  // Type is irrelevant.
+                                                    HloOpcode::kCustomCall));
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc);
+}
+
+INSTANTIATE_TEST_SUITE_P(CustomCallSuite, CustomCallTest,
+                         ::testing::ValuesIn(AllDevicesToTest()),
+                         TritonSupportTestDeviceToString);
+
 constexpr std::array kUnsupportedOps = {
     // clang-format off
     // go/keep-sorted start
-    HloOpcode::kAddDependency,
-    HloOpcode::kAfterAll,
     HloOpcode::kCholesky,
     HloOpcode::kConvolution,
     HloOpcode::kCopyDone,
     HloOpcode::kCopyStart,
-    HloOpcode::kCustomCall,
     HloOpcode::kDynamicReshape,
     HloOpcode::kDynamicSlice,
     HloOpcode::kDynamicUpdateSlice,
@@ -2515,6 +2574,8 @@ absl::flat_hash_set<HloOpcode> AllTestedOpcodes() {
   ret.insert(kTestedOpsIota.begin(), kTestedOpsIota.end());
   ret.insert(kTestedOpsRng.begin(), kTestedOpsRng.end());
 
+  ret.emplace(HloOpcode::kAfterAll);
+  ret.emplace(HloOpcode::kAddDependency);
   ret.emplace(HloOpcode::kBatchNormGrad);
   ret.emplace(HloOpcode::kBatchNormInference);
   ret.emplace(HloOpcode::kBatchNormTraining);
@@ -2522,6 +2583,7 @@ absl::flat_hash_set<HloOpcode> AllTestedOpcodes() {
   ret.emplace(HloOpcode::kCall);
   ret.emplace(HloOpcode::kComplex);
   ret.emplace(HloOpcode::kConditional);
+  ret.emplace(HloOpcode::kCustomCall);
   ret.emplace(HloOpcode::kDomain);
   ret.emplace(HloOpcode::kDot);
   ret.emplace(HloOpcode::kGetDimensionSize);

From a73efe252847f72e3b74a5c403fd1aadc9be0ceb Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Wed, 2 Apr 2025 13:52:28 -0700
Subject: [PATCH 0173/1324] Remove tuples from the TfrtCpuClient by forcing the
 use of untuple_result.
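
For illustration only (a hedged sketch against the public PJRT headers, not
code from this change; `executable` and `argument_handles` stand in for the
caller's own values): a caller that used to receive a single tuple-shaped
PjRtBuffer now has to request untupled results explicitly:

    xla::ExecuteOptions options;
    options.untuple_result = true;  // Required for tuple-shaped results.
    // Execute() returns one std::vector<std::unique_ptr<PjRtBuffer>> per
    // addressable device; with untuple_result set, each element is one leaf
    // of the result tuple rather than a single tuple-shaped buffer.
    TF_ASSIGN_OR_RETURN(auto per_device_results,
                        executable->Execute(argument_handles, options));

TfrtCpuClient now rejects tuple-shaped results when untuple_result is unset;
see the checks added to Execute() and ExecuteSharded() below.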
PiperOrigin-RevId: 743267124
---
 .../cpu/benchmarks/hlo_benchmark_runner.cc    |  63 +++---
 .../xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc  | 186 ++++++------
 third_party/xla/xla/pjrt/cpu/cpu_client.cc    |  67 ++++---
 .../xla/pjrt/cpu/tracked_cpu_device_buffer.cc |  91 ++-------
 .../xla/pjrt/cpu/tracked_cpu_device_buffer.h  |  38 ++--
 .../cpu/tracked_cpu_device_buffer_test.cc     | 146 +-------------
 6 files changed, 156 insertions(+), 435 deletions(-)

diff --git a/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc b/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc
index 772c22f944ce0d..0cd50af444f5e9 100644
--- a/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc
+++ b/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc
@@ -84,40 +84,28 @@ class AliasHelper {
   }
 
   absl::Status SwapOutputAliasedBuffersToArgumentBuffers(
-      PjRtBuffer* result,
+      std::vector<std::unique_ptr<PjRtBuffer>>& results,
       std::vector<std::unique_ptr<PjRtBuffer>>& args_buffers,
       std::vector<PjRtBuffer*>& args_ptrs) {
     if (!ComputationHasAliasing()) {
       return absl::OkStatus();
     }
-    TfrtCpuBuffer* result_tfrt_cpu_buffer =
-        tsl::down_cast<TfrtCpuBuffer*>(result);
-
-    TF_ASSIGN_OR_RETURN(
-        AbstractTfrtCpuBuffer::DonationTransaction buffer_donation,
-        result_tfrt_cpu_buffer->AcquireDonation());
-    TrackedCpuDeviceBuffer* tracked_tfrt_cpu_device_buffer =
-        buffer_donation.device_buffer();
-
-    for (const auto& [output_index, arg_index] :
+    for (const auto& [output_sindex, arg_index] :
         aliased_output_index_to_argument_index_) {
-      // we don't need the entire buffer just the one at the output index
-      tsl::AsyncValuePtr<CpuDeviceMemory> output_cpu_memory =
-          tracked_tfrt_cpu_device_buffer->Buffer(output_index);
-
-      auto tracked_device_buffer = std::make_unique<TrackedCpuDeviceBuffer>(
-          /*is_tuple=*/false, /*owns_buffers=*/true,
-          absl::InlinedVector<tsl::AsyncValueRef<CpuDeviceMemory>, 4>{
-              output_cpu_memory.CopyRef()},
-          tsl::MakeAvailableAsyncValueRef<CpuEvent>());
-
-      args_buffers[arg_index] = std::make_unique<TfrtCpuBuffer>(
-          tsl::down_cast<TfrtCpuBuffer*>(args_buffers[arg_index].get())
-              ->on_device_shape(),
-          std::move(tracked_device_buffer),
-          tsl::down_cast<TfrtCpuClient*>(client_),
-          tsl::down_cast<TfrtCpuDevice*>(device_), memory_space_);
-
+      if (output_sindex.size() > 1) {
+        return absl::InvalidArgumentError("Nested tuples not supported");
+      }
+      size_t output_index = 0;
+      if (output_sindex.size() == 1) {
+        output_index = output_sindex[0];
+      }
+      if (output_index >= results.size()) {
+        return absl::InvalidArgumentError("index out of bounds.");
+      }
+      if (!results[output_index]) {
+        return absl::InvalidArgumentError("Result already donated.");
+      }
+      args_buffers[arg_index] = std::move(results[output_index]);
       args_ptrs[arg_index] = args_buffers[arg_index].get();
     }
     return absl::OkStatus();
@@ -215,6 +203,7 @@ absl::Status RunHloBenchmark(benchmark::State& state,
   // thread pool if we need to run multiple executions in parallel.
   ExecuteOptions execute_options;
   execute_options.execution_mode = ExecuteOptions::ExecutionMode::kSynchronous;
+  execute_options.untuple_result = true;
 
   std::vector<std::vector<PjRtBuffer*>> execution_args_ptrs(
       benchmark_options.num_executions);
@@ -266,20 +255,12 @@ absl::Status RunHloBenchmark(benchmark::State& state,
   for (size_t i = 0; i < benchmark_options.num_executions; ++i) {
     for (const auto& result : execution_results[i]) {
       CHECK_OK(result->GetReadyFuture().Await());
-      CHECK(!alias_helper.ComputationHasAliasing() ||
-            result->IsTuple() && execution_results[i].size() == 1)
-          << "Only single output tuple is supported in benchmarking aliased "
-             "models. "
-             "result->IsTuple(): "
-          << result->IsTuple()
-          << " execution_results size: " << execution_results[i].size();
-      std::vector<std::unique_ptr<PjRtBuffer>>& args_buffers =
-          execution_args_buffers[i];
-      std::vector<PjRtBuffer*>& args_ptrs = execution_args_ptrs[i];
-      TF_RETURN_IF_ERROR(
-          alias_helper.SwapOutputAliasedBuffersToArgumentBuffers(
-              result.get(), args_buffers, args_ptrs));
     }
+    std::vector<std::unique_ptr<PjRtBuffer>>& args_buffers =
+        execution_args_buffers[i];
+    std::vector<PjRtBuffer*>& args_ptrs = execution_args_ptrs[i];
+    TF_RETURN_IF_ERROR(alias_helper.SwapOutputAliasedBuffersToArgumentBuffers(
+        execution_results[i], args_buffers, args_ptrs));
   }
 
   return absl::OkStatus();

diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
index e859b637a5133b..85f60b0a0b6481 100644
--- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
+++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
@@ -89,48 +89,29 @@ void UnpackIntNToLiteral(PrimitiveType input_element_type,
 void CopyCpuBufferToLiteral(const Shape& device_shape,
                             TrackedCpuDeviceBuffer* device_buffer,
                             MutableLiteralBase* literal) {
-  if (!device_shape.IsTuple()) {
-    const tsl::AsyncValueRef<CpuDeviceMemory>& b = device_buffer->Buffers()[0];
-    CHECK(b.IsConcrete());
-    if (primitive_util::IsSubByteNonPredType(device_shape.element_type())) {
-      UnpackIntNToLiteral(device_shape.element_type(), *b, literal,
-                          /*shape_index=*/{});
-    } else {
-      std::memcpy(literal->untyped_data(), b->untyped_data(),
-                  ShapeUtil::ByteSizeOf(device_shape));
-    }
+  CHECK(!device_shape.IsTuple());
+  const tsl::AsyncValueRef<CpuDeviceMemory>& b = device_buffer->buffer();
+  CHECK(b.IsConcrete());
+  if (primitive_util::IsSubByteNonPredType(device_shape.element_type())) {
+    UnpackIntNToLiteral(device_shape.element_type(), *b, literal,
+                        /*shape_index=*/{});
   } else {
-    // Tuple case.
-    int num_leaves = literal->shape().tuple_shapes().size();
-    for (int i = 0; i < num_leaves; ++i) {
-      const tsl::AsyncValueRef<CpuDeviceMemory>& b =
-          device_buffer->Buffers()[i];
-      CHECK(b.IsConcrete());
-      if (primitive_util::IsSubByteNonPredType(device_shape.element_type())) {
-        UnpackIntNToLiteral(device_shape.element_type(), *b, literal, {i});
-      } else {
-        std::memcpy(
-            literal->untyped_data({i}), b->untyped_data(),
-            ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(device_shape, {i})));
-      }
-    }
+    std::memcpy(literal->untyped_data(), b->untyped_data(),
+                ShapeUtil::ByteSizeOf(device_shape));
   }
 }
 
 // `buffers` must be available.
-ShapedBuffer AsShapedBuffer(
-    int device_ordinal, const Shape& on_device_shape,
-    absl::Span<const tsl::AsyncValueRef<CpuDeviceMemory>> buffers) {
+ShapedBuffer AsShapedBuffer(int device_ordinal, const Shape& on_device_shape,
+                            tsl::AsyncValueRef<CpuDeviceMemory> buf) {
   ShapedBuffer shaped_buffer(on_device_shape, device_ordinal);
   ShapeTree<se::DeviceMemoryBase>::iterator iterator = shaped_buffer.buffers().begin();
-  for (const auto& buf : buffers) {
-    CHECK(buf.IsConcrete());
-    CHECK(iterator != shaped_buffer.buffers().end());
-    iterator->second =
-        se::DeviceMemoryBase(buf->untyped_data(), buf->size_bytes());
-    ++iterator;
-  }
+  CHECK(buf.IsConcrete());
+  CHECK(iterator != shaped_buffer.buffers().end());
+  iterator->second =
+      se::DeviceMemoryBase(buf->untyped_data(), buf->size_bytes());
+  ++iterator;
   CHECK(iterator == shaped_buffer.buffers().end());
   return shaped_buffer;
 }
@@ -170,7 +151,7 @@ absl::StatusOr<Shape> AbstractTfrtCpuBuffer::logical_on_device_shape() {
 
   // Safe to call `AsShapedBuffer` because the definition event is ready.
ShapedBuffer shaped_buffer = AsShapedBuffer(device()->local_hardware_id().value(), on_device_shape_, - device_buffer->Buffers()); + device_buffer->buffer()); Shape ret_shape = on_device_shape_; TF_RETURN_IF_ERROR(ReadDynamicShapesOnCpu( &shaped_buffer, &ret_shape, cpu::CpuExecutable::ShapeSizeBytes)); @@ -213,7 +194,7 @@ AbstractTfrtCpuBuffer::AcquireExternalReference() { ++external_reference_counter_; return {std::make_unique( - this, tracked_device_buffer_->Buffers()[0])}; + this, tracked_device_buffer_->buffer())}; } void AbstractTfrtCpuBuffer::DropExternalReference() { @@ -233,7 +214,7 @@ class TrackedCpuDeviceBufferExternalReference : tracked_device_buffer_(std::move(tracked_device_buffer)) { // We need to wait for the memory to be allocated before sharing it with // external frameworks like NumPy. - const auto& buffer = tracked_device_buffer_->Buffers()[0]; + const auto& buffer = tracked_device_buffer_->buffer(); tsl::BlockUntilReady(buffer); CHECK(buffer.IsConcrete()); data_ptr_ = buffer->untyped_data(); @@ -497,7 +478,7 @@ PjRtFuture<> AbstractTfrtCpuBuffer::CopyRawToHostHelper( offset + transfer_size > ShapeUtil::ByteSizeOf(device_shape)) { return InvalidArgument("CopyRawToHost out of bounds."); } - const tsl::AsyncValueRef& b = device_buffer->Buffers()[0]; + const tsl::AsyncValueRef& b = device_buffer->buffer(); CHECK(b.IsConcrete()); std::memcpy(dst, reinterpret_cast(b->untyped_data()) + offset, transfer_size); @@ -536,58 +517,36 @@ AbstractTfrtCpuBuffer::CopyToDeviceHelper(AsyncWorkRunner* async_work_runner) { } MarkEventReadyOnExit ready_on_exit(std::move(usage_event)); - int num_leaf_buffers = src_device_buffer->Buffers().size(); - absl::InlinedVector, 4> src_buffers; - absl::InlinedVector, 4> dst_buffers; - absl::InlinedVector dst_buffers_sizes; - absl::InlinedVector, 4> dst_definition_events; - src_buffers.reserve(num_leaf_buffers); - dst_buffers.reserve(num_leaf_buffers); - dst_buffers_sizes.reserve(num_leaf_buffers); - dst_definition_events.reserve(num_leaf_buffers); - - for (int i = 0; i < num_leaf_buffers; ++i) { - src_buffers.push_back(std::move(src_device_buffer->Buffers()[i])); - dst_buffers.push_back( - tsl::MakeUnconstructedAsyncValueRef()); - dst_buffers_sizes.push_back(src_device_buffer->BufferSizes()[i]); - dst_definition_events.push_back( - tsl::MakeConstructedAsyncValueRef()); - } + auto dst_buffer = tsl::MakeUnconstructedAsyncValueRef(); + auto dst_definition_event = tsl::MakeConstructedAsyncValueRef(); // Wait for src buffer definition events to finish before d2d dispatch. // Errors are propagated asynchronously in dst buffer's definition events. const auto& src_definition_event = src_device_buffer->definition_event(); - auto copy_task = [num_leaf_buffers, src_buffers = std::move(src_buffers), - dst_buffers_copies = dst_buffers, dst_definition_events, + auto copy_task = [src_buffer = src_device_buffer->buffer(), + dst_buffer_copy = dst_buffer, dst_definition_event, src_definition_event, ready_on_exit = std::move(ready_on_exit)]() mutable { tsl::profiler::TraceMe traceme("D2D Dispatch"); if (auto* error = src_definition_event.GetErrorIfPresent()) { - for (int i = 0; i < num_leaf_buffers; ++i) { - // Any error discovered in src buffer are propagated to dst buffer - // definition events, which will surface to users in - // dst_buffer->ToLiteral(). - dst_definition_events[i].SetError(*error); - } + // Any error discovered in src buffer are propagated to dst buffer + // definition events, which will surface to users in + // dst_buffer->ToLiteral(). 
+ dst_definition_event.SetError(*error); return; } - for (int i = 0; i < num_leaf_buffers; ++i) { - // `src_buffers` are available because `src_definition_event` should have - // been ready. - CHECK(src_buffers[i].IsConcrete()); - auto dst_memory = CpuDeviceMemory::Allocate(src_buffers[i]->size_bytes()); - if (!dst_memory.ok()) { - dst_definition_events[i].SetError(dst_memory.status()); - continue; - } - dst_buffers_copies[i].emplace(std::move(*dst_memory)); - std::memcpy(dst_buffers_copies[i]->untyped_data(), - src_buffers[i]->untyped_data(), src_buffers[i]->size_bytes()); - dst_definition_events[i].SetStateConcrete(); + CHECK(src_buffer.IsConcrete()); + auto dst_memory = CpuDeviceMemory::Allocate(src_buffer->size_bytes()); + if (!dst_memory.ok()) { + dst_definition_event.SetError(dst_memory.status()); + return; } + dst_buffer_copy.emplace(std::move(*dst_memory)); + std::memcpy(dst_buffer_copy->untyped_data(), src_buffer->untyped_data(), + src_buffer->size_bytes()); + dst_definition_event.SetStateConcrete(); }; src_definition_event.AndThen( @@ -596,8 +555,9 @@ AbstractTfrtCpuBuffer::CopyToDeviceHelper(AsyncWorkRunner* async_work_runner) { }); return std::make_unique( - on_device_shape_.IsTuple(), /*owns_buffers=*/true, std::move(dst_buffers), - std::move(dst_buffers_sizes), std::move(dst_definition_events)); + /*owns_buffers=*/true, dst_buffer, src_device_buffer->BufferSize(), + absl::InlinedVector, 4>{ + std::move(dst_definition_event)}); } PjRtFuture<> AbstractTfrtCpuBuffer::GetReadyFuture() { @@ -677,65 +637,35 @@ void AbstractTfrtCpuBuffer::CopyFromLiteral( auto usage_event = tsl::MakeAvailableAsyncValueRef(); auto* device_buffer = AcquireUsage(std::move(usage_event)); CHECK(device_buffer); - if (!shape.IsTuple()) { - // It is OK to capture `buffer` pointer because the `output_buffer` can't be - // deleted until all the usage holds have gone away. - async_work_runner->Schedule( - [literal, av = (*avs)[0].CopyRef(), device_buffer, shape]() mutable { - tsl::profiler::TraceMe traceme("H2D Dispatch"); - const tsl::AsyncValueRef& b = - device_buffer->Buffers()[0]; - CHECK(b.IsConcrete()); - PackOrCopy(shape.element_type(), literal, b->untyped_data(), - b->size_bytes()); - // Signal copy is complete. - av->SetStateConcrete(); - }); - } else { - // For tuple, transfer leaf literal individually in parallel. - for (int i = 0; i < shape.tuple_shapes_size(); ++i) { - // It is OK to capture `buffer` pointer because the `output_buffer` can't - // be deleted until all the usage holds have gone away. - async_work_runner->Schedule([i, literal, av = (*avs)[i].CopyRef(), shape, - device_buffer]() mutable { + CHECK(!shape.IsTuple()); + // It is OK to capture `buffer` pointer because the `output_buffer` can't be + // deleted until all the usage holds have gone away. + async_work_runner->Schedule( + [literal, av = (*avs)[0].CopyRef(), device_buffer, shape]() mutable { tsl::profiler::TraceMe traceme("H2D Dispatch"); - auto slice = LiteralSlice(literal, {i}); - const tsl::AsyncValueRef& b = - device_buffer->Buffers()[i]; + const tsl::AsyncValueRef& b = device_buffer->buffer(); CHECK(b.IsConcrete()); - PackOrCopy(slice.shape().element_type(), slice, b->untyped_data(), + PackOrCopy(shape.element_type(), literal, b->untyped_data(), b->size_bytes()); // Signal copy is complete. 
av->SetStateConcrete(); }); - } - } } /*static*/ absl::StatusOr> AbstractTfrtCpuBuffer::AllocateTrackedDeviceBuffer( const Shape& on_device_shape, absl::InlinedVector, 4> definition_events) { - absl::InlinedVector, 4> buffers; - if (!on_device_shape.IsTuple()) { - size_t byte_size = ShapeUtil::ByteSizeOf(on_device_shape); - TF_ASSIGN_OR_RETURN(tsl::AsyncValueRef device_buffer, - CpuDeviceMemory::AllocateAvailable(byte_size)); - buffers.push_back(std::move(device_buffer)); - return std::make_unique( - /*is_tuple=*/false, /*owns_buffers=*/true, std::move(buffers), - std::move(definition_events)); - } - // Tuple case. - buffers.reserve(on_device_shape.tuple_shapes().size()); - for (const auto& leaf_shape : on_device_shape.tuple_shapes()) { - size_t byte_size = ShapeUtil::ByteSizeOf(leaf_shape); - TF_ASSIGN_OR_RETURN(tsl::AsyncValueRef device_buffer, - CpuDeviceMemory::AllocateAvailable(byte_size)); - buffers.push_back(std::move(device_buffer)); - } + if (on_device_shape.IsTuple()) { + return absl::InvalidArgumentError( + absl::StrCat("Tuples are not supported for cpu-buffers: ", + on_device_shape.ToString())); + } + size_t byte_size = ShapeUtil::ByteSizeOf(on_device_shape); + TF_ASSIGN_OR_RETURN(tsl::AsyncValueRef device_buffer, + CpuDeviceMemory::AllocateAvailable(byte_size)); return std::make_unique( - /*is_tuple=*/true, /*owns_buffers=*/true, std::move(buffers), + /*owns_buffers=*/true, std::move(device_buffer), std::move(definition_events)); } @@ -886,8 +816,8 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( } } return std::make_unique( - /*is_tuple=*/false, owns_buffers, std::move(buffers), - std::move(definition_events), std::move(on_delete_callback)); + owns_buffers, std::move(buffers[0]), std::move(definition_events), + std::move(on_delete_callback)); } AbstractAsyncHostToHostMemoryTransferManager:: @@ -1008,7 +938,7 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer( tsl::RCReference event; { absl::MutexLock l(&mu_); - const auto& b = device_buffers_[buffer_index]->Buffers()[0]; + const auto& b = device_buffers_[buffer_index]->buffer(); CHECK(b.IsConcrete()); fill_fn(reinterpret_cast(b->untyped_data()), b->size_bytes()); if (is_last_transfer) { diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index 2d1c77b22dc9aa..46f460a7d34b87 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -987,13 +987,11 @@ TfrtCpuClient::CreateViewOfDeviceBuffer( reinterpret_cast(device_ptr), cpu_function_runtime::MinAlign()); } - absl::InlinedVector, 4> buffers; size_t byte_size = ShapeUtil::ByteSizeOf(shape); auto non_owning_buffer = tsl::MakeAvailableAsyncValueRef(device_ptr, byte_size); - buffers.push_back(std::move(non_owning_buffer)); auto tracked_device_buffer = std::make_unique( - /*is_tuple=*/false, /*owns_buffers=*/false, std::move(buffers), + /*owns_buffers=*/false, std::move(non_owning_buffer), /*definition_event=*/tsl::MakeAvailableAsyncValueRef(), std::move(on_delete_callback)); CHECK_EQ(memory_space->devices().size(), 1); @@ -1017,9 +1015,7 @@ absl::StatusOr> TfrtCpuClient::CreateErrorBuffer( return std::make_unique( shape, std::make_unique( - /*is_tuple=*/false, /*owns_buffers=*/true, - absl::InlinedVector, 4>{ - std::move(buffer)}, + /*owns_buffers=*/true, std::move(buffer), absl::InlinedVector, 4>{ tsl::AsyncValueRef( tsl::MakeErrorAsyncValueRef(std::move(error)))}), @@ -1325,18 +1321,23 @@ static absl::StatusOr MemoryForAllocation( } else if 
(allocation.param_shape_index().size() == 1) { std::tie(can_donate, arg) = arguments[allocation.param_shape_index()[0]]; - out = arg->Buffer({}); - buffer_size = arg->BufferSize({}); + out = arg->buffer().AsPtr(); + buffer_size = arg->BufferSize(); } else { return absl::InvalidArgumentError(absl::StrCat( "Nested tuples are not supported for argument: ", allocation.parameter_number(), " at shape index:", allocation.param_shape_index().ToString())); } + } else if (!allocation.param_shape_index().empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Nested tuples are not supported for argument: ", + allocation.parameter_number(), + " at shape index:", allocation.param_shape_index().ToString())); } else { std::tie(can_donate, arg) = arguments[allocation.parameter_number()]; - out = arg->Buffer(allocation.param_shape_index()); - buffer_size = arg->BufferSize(allocation.param_shape_index()); + out = arg->buffer().AsPtr(); + buffer_size = arg->BufferSize(); } CHECK_EQ(allocation.size(), buffer_size) << "Size mismatch on param " << allocation.parameter_number() @@ -1577,8 +1578,7 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( absl::InlinedVector, 4> leaf_buffers; leaf_buffers.reserve(tracked_buffers.size()); for (const auto& tracked_buffer : tracked_buffers) { - auto span = tracked_buffer.second->Buffers(); - leaf_buffers.insert(leaf_buffers.end(), span.begin(), span.end()); + leaf_buffers.push_back(tracked_buffer.second->buffer()); } tuple_index_table = tsl::MakeUnconstructedAsyncValueRef(); tsl::RunWhenReady( @@ -1924,7 +1924,7 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( // Create output TFRT buffers. const Shape& result_shape = cpu_executable_->result_shape(); std::vector> res; - if (options.untuple_result && result_shape.IsTuple()) { + if (result_shape.IsTuple()) { res.reserve(result_buffers_info.size()); for (int i = 0; i < result_buffers_info.size(); ++i) { // Program execution writes to output buffers so it's a definition event. @@ -1932,32 +1932,21 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( definition_events.push_back(execute_event.CopyRef()); auto leaf_tracked_device_buffer = std::make_unique( - /*is_tuple=*/false, result_buffers_info[i].owns_buffer, - absl::InlinedVector, 4>{ - std::move(result_buffers_info[i].buffer)}, - absl::InlinedVector{ - result_buffers_info[i].buffer_size}, - std::move(definition_events)); + result_buffers_info[i].owns_buffer, + std::move(result_buffers_info[i].buffer), + result_buffers_info[i].buffer_size, std::move(definition_events)); auto leaf_buffer = std::make_unique( result_shape.tuple_shapes(i), std::move(leaf_tracked_device_buffer), client_, device, *device->default_memory_space()); res.push_back(std::move(leaf_buffer)); } } else { - bool owns_buffers = true; - absl::InlinedVector, 4> sub_buffers; - absl::InlinedVector sub_buffer_sizes; - sub_buffers.reserve(result_buffers_info.size()); - sub_buffer_sizes.reserve(result_buffers_info.size()); - for (int i = 0; i < result_buffers_info.size(); ++i) { - owns_buffers = owns_buffers && result_buffers_info[i].owns_buffer; - sub_buffers.push_back(std::move(result_buffers_info[i].buffer)); - sub_buffer_sizes.push_back(result_buffers_info[i].buffer_size); - } + CHECK_EQ(result_buffers_info.size(), 1); // Program execution writes to output buffers so it's a definition event. 
auto tracked_device_buffer = std::make_unique( - /*is_tuple=*/result_shape.IsTuple(), owns_buffers, - std::move(sub_buffers), std::move(sub_buffer_sizes), + result_buffers_info[0].owns_buffer, + std::move(result_buffers_info[0].buffer), + result_buffers_info[0].buffer_size, /*definition_event=*/execute_event); auto tfrt_output_buffer = std::make_unique( result_shape, std::move(tracked_device_buffer), client_, device, @@ -2023,6 +2012,14 @@ TfrtCpuExecutable::Execute( tsl::profiler::TraceMeProducer activity("TfrtCpuExecutable::Execute", tsl::profiler::ContextType::kPjRt, run_id.ToInt()); + if (!options.untuple_result && cpu_executable_->module() + .config() + .entry_computation_layout() + .result_shape() + .IsTuple()) { + return InvalidArgument( + "Tuple results must be untupled using ExecuteOptions::untuple_result."); + } if (device_assignment_ == nullptr) { return InvalidArgument("Execute expects a non-null device_assignment"); } @@ -2146,6 +2143,14 @@ TfrtCpuExecutable::ExecuteSharded( if (device_assignment_ == nullptr) { return InvalidArgument("ExecuteShard expects a non-null device_assignment"); } + if (!options.untuple_result && cpu_executable_->module() + .config() + .entry_computation_layout() + .result_shape() + .IsTuple()) { + return InvalidArgument( + "Tuple results must be untupled using ExecuteOptions::untuple_result."); + } for (int i = 0; i < addressable_devices_.size(); ++i) { if (addressable_devices_[i] == device) { VLOG(1) << "ExecuteShard executes computation " << name() diff --git a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.cc b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.cc index be7ee83a4a1979..6a3856e60a21b5 100644 --- a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.cc @@ -109,32 +109,28 @@ absl::StatusOr CpuDeviceMemory::Allocate(size_t size_bytes) { } TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer( - bool is_tuple, bool owns_buffers, - absl::InlinedVector, 4> buffers, + bool owns_buffers, tsl::AsyncValueRef buffer, absl::InlinedVector, 4> definition_events, absl::AnyInvocable on_delete_callback) - : TrackedCpuDeviceBuffer(is_tuple, owns_buffers, std::move(buffers), + : TrackedCpuDeviceBuffer(owns_buffers, std::move(buffer), AfterAll(definition_events), std::move(on_delete_callback)) {} TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer( - bool is_tuple, bool owns_buffers, - absl::InlinedVector, 4> buffers, - absl::InlinedVector buffer_sizes, + bool owns_buffers, tsl::AsyncValueRef buffer, + size_t buffer_size, absl::InlinedVector, 4> definition_events, absl::AnyInvocable on_delete_callback) - : TrackedCpuDeviceBuffer( - is_tuple, owns_buffers, std::move(buffers), std::move(buffer_sizes), - AfterAll(definition_events), std::move(on_delete_callback)) {} + : TrackedCpuDeviceBuffer(owns_buffers, std::move(buffer), buffer_size, + AfterAll(definition_events), + std::move(on_delete_callback)) {} TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer( - bool is_tuple, bool owns_buffers, - absl::InlinedVector, 4> buffers, + bool owns_buffers, tsl::AsyncValueRef buffer, tsl::AsyncValueRef definition_event, absl::AnyInvocable on_delete_callback) - : is_tuple_(is_tuple), - owns_buffers_(owns_buffers), - buffers_(std::move(buffers)), + : owns_buffers_(owns_buffers), + buffers_({std::move(buffer)}), definition_event_(std::move(definition_event)), on_delete_callback_(std::move(on_delete_callback)) { DCHECK(definition_event_); @@ -142,49 +138,20 @@ TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer( 
CHECK(buffer.IsConcrete()); buffer_sizes_.push_back(buffer->size_bytes()); } - if (is_tuple) { - size_t index_table_byte_size = buffers_.size() * sizeof(void*); - // We assume tuple table allocations will not fail. - tuple_index_table_ = - CpuDeviceMemory::AllocateAvailable(index_table_byte_size).value(); - uintptr_t* index_table = - reinterpret_cast(tuple_index_table_->untyped_data()); - for (int i = 0; i < buffers_.size(); ++i) { - index_table[i] = absl::bit_cast(buffers_[i]->untyped_data()); - } - } + CHECK_EQ(buffers_.size(), 1); } TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer( - bool is_tuple, bool owns_buffers, - absl::InlinedVector, 4> buffers, - absl::InlinedVector buffer_sizes, - tsl::AsyncValueRef definition_event, + bool owns_buffers, tsl::AsyncValueRef buffer, + size_t buffer_size, tsl::AsyncValueRef definition_event, absl::AnyInvocable on_delete_callback) - : is_tuple_(is_tuple), - owns_buffers_(owns_buffers), - buffers_(std::move(buffers)), - buffer_sizes_(std::move(buffer_sizes)), + : owns_buffers_(owns_buffers), + buffers_({std::move(buffer)}), + buffer_sizes_({buffer_size}), definition_event_(std::move(definition_event)), on_delete_callback_(std::move(on_delete_callback)) { DCHECK(definition_event_); - if (is_tuple) { - tuple_index_table_ = tsl::MakeUnconstructedAsyncValueRef(); - tsl::RunWhenReady( - absl::MakeConstSpan(buffers_), - [buffers = buffers_, tuple_index_table = tuple_index_table_] { - size_t index_table_byte_size = buffers.size() * sizeof(void*); - // We assume tuple table allocations will not fail. - tuple_index_table.emplace( - CpuDeviceMemory::Allocate(index_table_byte_size).value()); - uintptr_t* index_table = - reinterpret_cast(tuple_index_table->untyped_data()); - for (int i = 0; i < buffers.size(); ++i) { - index_table[i] = - absl::bit_cast(buffers[i]->untyped_data()); - } - }); - } + CHECK_EQ(buffers_.size(), 1); } TrackedCpuDeviceBuffer::~TrackedCpuDeviceBuffer() { @@ -196,28 +163,11 @@ TrackedCpuDeviceBuffer::~TrackedCpuDeviceBuffer() { tsl::AsyncValuePtr TrackedCpuDeviceBuffer::Buffer( const ShapeIndex& shape_index) { - if (shape_index.empty()) { - // shape_index={} - if (is_tuple_) return tuple_index_table_.AsPtr(); - return buffers_[0].AsPtr(); - } - // shape_index={i} - CHECK(is_tuple_); - CHECK_EQ(shape_index.size(), 1) << "nested tuple not supported"; - return buffers_[shape_index[0]].AsPtr(); + CHECK(shape_index.empty()); + return buffers_[0].AsPtr(); } -size_t TrackedCpuDeviceBuffer::BufferSize(const ShapeIndex& shape_index) { - if (shape_index.empty()) { - // shape_index={} - if (is_tuple_) return buffers_.size() * sizeof(void*); - return buffer_sizes_[0]; - } - // shape_index={i} - CHECK(is_tuple_); - CHECK_EQ(shape_index.size(), 1) << "nested tuple not supported"; - return buffer_sizes_[shape_index[0]]; -} +size_t TrackedCpuDeviceBuffer::BufferSize() { return buffer_sizes_[0]; } void TrackedCpuDeviceBuffer::AddUsageEvents( absl::Span> events) { @@ -246,7 +196,6 @@ TrackedCpuDeviceBuffer::LockUseAndTransferUsageEvents() { } void TrackedCpuDeviceBuffer::ReleaseDeviceMemory() { - tuple_index_table_.reset(); buffers_.clear(); definition_event_.reset(); usage_events_.clear(); diff --git a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h index 739dd5a245d74d..c33526af99c947 100644 --- a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h @@ -73,39 +73,34 @@ class TrackedCpuDeviceBuffer { // For tuple, takes the 
leaf buffers. Tuple index table created internally. // Nested tuple is not supported. - // Constructor for allocated cpu memory, i.e., `buffers` should have concrete + // Constructor for allocated cpu memory, i.e., `buffer` should have concrete // states. Definition event is after the list of `definition_events`. TrackedCpuDeviceBuffer( - bool is_tuple, bool owns_buffers, - absl::InlinedVector, 4> buffers, + bool owns_buffers, tsl::AsyncValueRef buffer, absl::InlinedVector, 4> definition_events, absl::AnyInvocable on_delete_callback = nullptr); // Variant with single definition event. TrackedCpuDeviceBuffer( - bool is_tuple, bool owns_buffers, - absl::InlinedVector, 4> buffers, + bool owns_buffers, tsl::AsyncValueRef buffer, tsl::AsyncValueRef definition_event, absl::AnyInvocable on_delete_callback = nullptr); - // Constructor for unallocated cpu memory, i.e., `buffers` have unconstructed - // states, also needs to provide `buffer_sizes` which will be the sizes of - // the `buffers` after allocation. Definition event is after the list of - // `definition_events`. Callers need to ensure cpu memory is allocated before - // the definition event is ready. + // Constructor for unallocated cpu memory, i.e., `buffer` will have + // unconstructed states, and we also need to provide `buffer_size` which will + // be the size of the `buffer` after allocation. Definition event is after the + // list of `definition_events`. Callers need to ensure cpu memory is allocated + // before the definition event is ready. TrackedCpuDeviceBuffer( - bool is_tuple, bool owns_buffers, - absl::InlinedVector, 4> buffers, - absl::InlinedVector buffer_sizes, + bool owns_buffers, tsl::AsyncValueRef buffer, + size_t buffer_size, absl::InlinedVector, 4> definition_events, absl::AnyInvocable on_delete_callback = nullptr); // Variant with single definition event. TrackedCpuDeviceBuffer( - bool is_tuple, bool owns_buffers, - absl::InlinedVector, 4> buffers, - absl::InlinedVector buffer_sizes, - tsl::AsyncValueRef definition_event, + bool owns_buffers, tsl::AsyncValueRef buffer, + size_t buffer_size, tsl::AsyncValueRef definition_event, absl::AnyInvocable on_delete_callback = nullptr); TrackedCpuDeviceBuffer(TrackedCpuDeviceBuffer&&) noexcept = default; @@ -114,15 +109,13 @@ class TrackedCpuDeviceBuffer { ~TrackedCpuDeviceBuffer(); - absl::Span> Buffers() { - return buffers_; - } + const tsl::AsyncValueRef& buffer() { return buffers_[0]; } absl::Span BufferSizes() { return buffer_sizes_; } tsl::AsyncValuePtr Buffer(const ShapeIndex& shape_index); - size_t BufferSize(const ShapeIndex& shape_index); + size_t BufferSize(); const tsl::AsyncValueRef& definition_event() const { return definition_event_; @@ -146,11 +139,8 @@ class TrackedCpuDeviceBuffer { // buffer is passed to a computation that aliases its inputs to outputs. void ReleaseDeviceMemory(); - bool is_tuple_; bool owns_buffers_; - // If tuple, tuple index table is created and stored. - tsl::AsyncValueRef tuple_index_table_; // If non-tuple, `buffers_` contains 1 buffer; otherwise all leaf buffers. 
absl::InlinedVector, 4> buffers_; // Should correspond to size of each buffer in `buffers_` when `buffers_` is diff --git a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer_test.cc b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer_test.cc index 913d0281014299..c01015c8472a95 100644 --- a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer_test.cc @@ -51,60 +51,18 @@ TEST(TrackedCpuDeviceBufferTest, Basic) { }); TrackedCpuDeviceBuffer tracked_buffer( - /*is_tuple=*/false, /*owns_buffers=*/true, {buffer}, definition_event, + /*owns_buffers=*/true, buffer, definition_event, /*on_delete_callback_=*/nullptr); BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue()); - auto result = tracked_buffer.Buffers()[0]; + auto result = tracked_buffer.buffer(); ASSERT_TRUE(result.IsAvailable()); EXPECT_EQ(std::string(static_cast(result->untyped_data()), result->size_bytes()), expected); } -TEST(TrackedCpuDeviceBufferTest, Tuple) { - std::string expected_0 = "tracked_cpu_device_buffer_test"; - std::string expected_1 = "tuple"; - TF_ASSERT_OK_AND_ASSIGN( - auto buffer_0, CpuDeviceMemory::AllocateAvailable(expected_0.size())); - TF_ASSERT_OK_AND_ASSIGN( - auto buffer_1, CpuDeviceMemory::AllocateAvailable(expected_1.size())); - - auto definition_event_0 = MakeConstructedAsyncValueRef(); - auto definition_event_1 = MakeConstructedAsyncValueRef(); - - ThreadPool thread_pool(tsl::Env::Default(), "tracked_buffer_test", - /*num_threads=*/4); - - thread_pool.Schedule([&]() { - std::memcpy(buffer_0->untyped_data(), expected_0.data(), expected_0.size()); - definition_event_0.SetStateConcrete(); - }); - thread_pool.Schedule([&]() { - std::memcpy(buffer_1->untyped_data(), expected_1.data(), expected_1.size()); - definition_event_1.SetStateConcrete(); - }); - - TrackedCpuDeviceBuffer tracked_buffer( - /*is_tuple=*/true, /*owns_buffers=*/true, {buffer_0, buffer_1}, - {definition_event_0, definition_event_1}, - /*on_delete_callback_=*/nullptr); - - BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue()); - - auto result_0 = tracked_buffer.Buffers()[0]; - auto result_1 = tracked_buffer.Buffers()[1]; - ASSERT_TRUE(result_0.IsAvailable()); - ASSERT_TRUE(result_1.IsAvailable()); - EXPECT_EQ(std::string(static_cast(result_0->untyped_data()), - result_0->size_bytes()), - expected_0); - EXPECT_EQ(std::string(static_cast(result_1->untyped_data()), - result_1->size_bytes()), - expected_1); -} - TEST(TrackedCpuDeviceBufferTest, BasicError) { TF_ASSERT_OK_AND_ASSIGN(auto buffer, CpuDeviceMemory::AllocateAvailable(64)); @@ -119,7 +77,7 @@ TEST(TrackedCpuDeviceBufferTest, BasicError) { }); TrackedCpuDeviceBuffer tracked_buffer( - /*is_tuple=*/false, /*owns_buffers=*/true, {buffer}, definition_event, + /*owns_buffers=*/true, buffer, definition_event, /*on_delete_callback_=*/nullptr); BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue()); @@ -129,40 +87,6 @@ TEST(TrackedCpuDeviceBufferTest, BasicError) { "tracked_cpu_device_buffer_test error."); } -TEST(TrackedCpuDeviceBufferTest, TupleError) { - std::string expected = "tracked_cpu_device_buffer_test"; - TF_ASSERT_OK_AND_ASSIGN(auto buffer_0, - CpuDeviceMemory::AllocateAvailable(expected.size())); - TF_ASSERT_OK_AND_ASSIGN(auto buffer_1, - CpuDeviceMemory::AllocateAvailable(expected.size())); - - auto definition_event_0 = MakeConstructedAsyncValueRef(); - auto definition_event_1 = MakeConstructedAsyncValueRef(); - - ThreadPool thread_pool(tsl::Env::Default(), 
"tracked_buffer_test", - /*num_threads=*/4); - - thread_pool.Schedule([&]() { - std::memcpy(buffer_0->untyped_data(), expected.data(), expected.size()); - definition_event_0.SetStateConcrete(); - }); - thread_pool.Schedule([&]() { - definition_event_1.SetError( - Internal("tracked_cpu_device_buffer_test tuple error.")); - }); - - TrackedCpuDeviceBuffer tracked_buffer( - /*is_tuple=*/true, /*owns_buffers=*/true, {buffer_0, buffer_1}, - {definition_event_0, definition_event_1}, - /*on_delete_callback_=*/nullptr); - - BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue()); - - ASSERT_TRUE(tracked_buffer.definition_event().IsError()); - EXPECT_EQ(tracked_buffer.definition_event().GetError().message(), - "tracked_cpu_device_buffer_test tuple error."); -} - TEST(TrackedCpuDeviceBufferTest, DelayedAllocation) { std::string expected = "tracked_cpu_device_buffer_test"; @@ -174,11 +98,10 @@ TEST(TrackedCpuDeviceBufferTest, DelayedAllocation) { }); auto definition_event = MakeConstructedAsyncValueRef(); - TrackedCpuDeviceBuffer tracked_buffer(/*is_tuple=*/false, - /*owns_buffers=*/true, {buffer}, - {expected.size()}, definition_event, + TrackedCpuDeviceBuffer tracked_buffer(/*owns_buffers=*/true, buffer, + expected.size(), definition_event, /*on_delete_callback_=*/nullptr); - auto result = tracked_buffer.Buffers()[0]; + auto result = tracked_buffer.buffer(); ASSERT_FALSE(result.IsAvailable()); ASSERT_EQ(tracked_buffer.BufferSizes()[0], expected.size()); @@ -198,62 +121,5 @@ TEST(TrackedCpuDeviceBufferTest, DelayedAllocation) { expected); } -TEST(TrackedCpuDeviceBufferTest, DelayedAllocationTuple) { - std::string expected_0 = "tracked_cpu_device_buffer_test"; - std::string expected_1 = "tuple"; - - auto buffer_0 = MakeUnconstructedAsyncValueRef(); - auto malloc_event_0 = MakeConstructedAsyncValueRef(); - malloc_event_0.AndThen( - [buffer_0_copy = buffer_0.CopyRef(), buffer_0_size = expected_0.size()] { - buffer_0_copy.emplace(CpuDeviceMemory::Allocate(buffer_0_size).value()); - }); - auto buffer_1 = MakeUnconstructedAsyncValueRef(); - auto malloc_event_1 = MakeConstructedAsyncValueRef(); - malloc_event_1.AndThen( - [buffer_1_copy = buffer_1.CopyRef(), buffer_1_size = expected_1.size()] { - buffer_1_copy.emplace(CpuDeviceMemory::Allocate(buffer_1_size).value()); - }); - - auto definition_event_0 = MakeConstructedAsyncValueRef(); - auto definition_event_1 = MakeConstructedAsyncValueRef(); - TrackedCpuDeviceBuffer tracked_buffer( - /*is_tuple=*/true, - /*owns_buffers=*/true, {buffer_0, buffer_1}, - {expected_0.size(), expected_1.size()}, - {definition_event_0, definition_event_1}, - /*on_delete_callback_=*/nullptr); - - auto result_0 = tracked_buffer.Buffers()[0]; - auto result_1 = tracked_buffer.Buffers()[1]; - ASSERT_FALSE(result_0.IsAvailable()); - ASSERT_FALSE(result_1.IsAvailable()); - ASSERT_EQ(tracked_buffer.BufferSizes()[0], expected_0.size()); - ASSERT_EQ(tracked_buffer.BufferSizes()[1], expected_1.size()); - - ThreadPool thread_pool(tsl::Env::Default(), "tracked_buffer_test", - /*num_threads=*/4); - - thread_pool.Schedule([&]() { - malloc_event_0.SetStateConcrete(); - std::memcpy(buffer_0->untyped_data(), expected_0.data(), expected_0.size()); - definition_event_0.SetStateConcrete(); - }); - thread_pool.Schedule([&]() { - malloc_event_1.SetStateConcrete(); - std::memcpy(buffer_1->untyped_data(), expected_1.data(), expected_1.size()); - definition_event_1.SetStateConcrete(); - }); - - BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue()); - - 
EXPECT_EQ(std::string(static_cast<const char*>(result_0->untyped_data()),
-                        result_0->size_bytes()),
-            expected_0);
-  EXPECT_EQ(std::string(static_cast<const char*>(result_1->untyped_data()),
-                        result_1->size_bytes()),
-            expected_1);
-}
-
 }  // namespace
 }  // namespace xla

From 4f9ec868fcedcb32e601edec031b69b079a262a2 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Wed, 2 Apr 2025 14:00:44 -0700
Subject: [PATCH 0174/1324] [XLA:GPU] Fork
 `fusion_emitter_device_legacy_test.cc` into
 `fusion_emitter_device_legacy_port_test.cc`.

The goal of this new file is to faithfully replicate the tests done in its
parent using the generic Triton emitter infrastructure, in order to track our
progress in replacing it.

For now, most tests are disabled, and we'll be looking into enabling them one
by one.

PiperOrigin-RevId: 743269941
---
 .../xla/xla/backends/gpu/codegen/triton/BUILD |   50 +
 .../fusion_emitter_device_legacy_port_test.cc | 4193 +++++++++++++++++
 2 files changed, 4243 insertions(+)
 create mode 100644 third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD
index cdd8ec6b54dfe4..72c937aa373ac1 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD
@@ -546,6 +546,55 @@ xla_test(
     ],
 )
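
A quick hedged sketch of the enabling path the commit message describes
(hypothetical test and HLO names; the helpers are the ones defined in the new
test file below, and gtest skips any test whose name carries the `DISABLED_`
prefix, so a ported case is re-enabled by a pure rename with the body left
untouched):

    TEST_F(TritonTest, SomePortedGemmCase) {  // was DISABLED_SomePortedGemmCase
      // kHloText: the unchanged HLO snippet carried over from the legacy test.
      TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata metadata,
                              GetModuleAndNestedFusionMetadata(kHloText));
      TF_EXPECT_OK(CreateTritonIrAndFileCheck(
          *metadata.computation, metadata.block_level_parameters,
          "CHECK: tt.dot"));
    }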
--- /dev/null +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -0,0 +1,4193 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "llvm/IR/LLVMContext.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/PassManager.h" +#include "xla/autotuning.pb.h" +#include "xla/backends/gpu/codegen/triton/fusion_emitter.h" +#include "xla/backends/gpu/codegen/triton/test_utils.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/transforms/nest_gemm_fusion.h" +#include "xla/service/pattern_matcher.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/xla.pb.h" +#include "tsl/platform/path.h" + +namespace xla { +namespace gpu { +namespace { + +namespace m = ::xla::match; +using tsl::testing::StatusIs; + +struct ModuleAndNestedFusionMetadata { + std::unique_ptr module; + HloComputation* computation; + BlockLevelParameters block_level_parameters; +}; + +class TritonTest : public GpuCodegenTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + debug_options + .set_xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms(true); + return debug_options; + } + + stream_executor::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + + const stream_executor::GpuComputeCapability& GpuComputeCapability() { + return device_desc().gpu_compute_capability(); + } + + stream_executor::GpuComputeCapability CudaAmpereOrRocm() { + if (std::holds_alternative( + GpuComputeCapability())) { + return stream_executor::GpuComputeCapability{ + device_desc().rocm_compute_capability()}; + } else 
{ + return stream_executor::GpuComputeCapability{ + stream_executor::CudaComputeCapability{ + stream_executor::CudaComputeCapability::kAmpere, 0}}; + } + } + + // Returns the computation and block level parameters from an HLO module text + // whose entry computation contains a single GEMM fusion. + absl::StatusOr + GetModuleAndNestedFusionMetadata(absl::string_view hlo_text) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSIGN_OR_RETURN( + bool fusion_was_nested, + NestGemmFusion(GpuComputeCapability()).Run(module.get())); + if (!fusion_was_nested) { + return absl::InternalError("Failed to nest the GEMM fusion."); + } + HloFusionInstruction* fusion = + Cast(hlo_query::GetFirstInstructionWithOpcode( + *module->entry_computation(), HloOpcode::kFusion)); + HloComputation* computation = fusion->fused_instructions_computation(); + BlockLevelParameters block_level_parameters = + BlockLevelParameters::FromBlockLevelFusionConfig( + fusion->backend_config() + ->fusion_backend_config() + .block_level_fusion_config()); + return ModuleAndNestedFusionMetadata{std::move(module), computation, + std::move(block_level_parameters)}; + } + + protected: + const stream_executor::DeviceDescription& device_desc() { + return backend().default_stream_executor()->GetDeviceDescription(); + } +}; + +class TritonGemmTest : public TritonTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); + // Do not fall back to cuBLAS and disable cuDNN; we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(0); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + // Always rewrite Gemms with Triton regardless of size. + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + return debug_options; + } + + void MatchHloModule(HloModule& module, absl::string_view pattern) { + TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result, + RunFileCheck(module.ToString(), pattern)); + EXPECT_TRUE(filecheck_result); + } +}; + +class TritonGemmTestWithSplitK : public TritonGemmTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_split_k_autotuning(true); + return debug_options; + } +}; + +TEST_F(TritonGemmTest, DISABLED_FP8DotSmallTileDoesNotCrash) { + GTEST_SKIP() << "TODO(b/337839570): Re-enable once the bug is fixed. " + "Currently the test is not representative of the issue. 
" + "While the test passes, the end-to-end model fails."; + + if (!GetCudaComputeCapability().IsAtLeastHopper()) { + GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; + } + + constexpr absl::string_view kHloText = R"( +HloModule m + +triton_dot { + %parameter_0 = f8e4m3fn[32,32]{1,0} parameter(0) + %parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) + ROOT %dot.1643 = bf16[32,32]{1,0} dot(f8e4m3fn[32,32]{1,0} %parameter_0, f8e4m3fn[32,32]{0,1} %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f8e4m3fn[32,32]{1,0} parameter(0) + p1 = f8e4m3fn[32,32]{1,0} parameter(1) + ROOT _ = bf16[32,32] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":16,"block_k":16, + "split_k":1,"num_stages":2,"num_warps":2, + "num_ctas":1}}} +})"; + EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); +} + +TEST_F(TritonTest, DISABLED_TestGemmWithTrivialNonContractingDimension) { + constexpr absl::string_view kHloText = R"( +HloModule t, is_scheduled=true + +triton_dot { + p0 = f32[137,115]{1,0} parameter(0) + p1 = f32[1,115]{1,0} parameter(1) + ROOT dot = f32[137,1]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f32[137,115]{1,0} parameter(0) + p1 = f32[1,115]{1,0} parameter(1) + ROOT custom-call = f32[137,1]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":16,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} +})"; + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloText)); + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(*module_and_metadata.computation, + module_and_metadata.block_level_parameters, + R"( +CHECK: tt.dot {{.*}} : tensor<16x32xf32> * tensor<32x16xf32> -> tensor<16x16xf32> +)")); +} + +TEST_F(TritonTest, PredParametersAreTruncatedToI1) { + constexpr absl::string_view kHloText = R"( +HloModule m + +triton_gemm_computation { + p = pred[2,2]{1,0} parameter(0) + a = f32[2,2]{1,0} parameter(1) + b = f32[2,2]{1,0} parameter(2) + c = f32[2,2]{1,0} parameter(3) + compare = pred[2,2]{1,0} compare(a, b), direction=LT + and = pred[2,2]{1,0} and(p, compare) + convert = f32[2,2]{1,0} convert(and) + ROOT r = f32[2,2]{1,0} dot(convert, c), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p = pred[2,2]{1,0} parameter(0) + a = f32[2,2]{1,0} parameter(1) + b = f32[2,2]{1,0} parameter(2) + c = f32[2,2]{1,0} parameter(3) + ROOT triton_gemm = f32[2,2]{1,0} fusion(p, a, b, c), kind=kCustom, + calls=triton_gemm_computation, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: { + "block_m":16,"block_n":16,"block_k":16, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1 + } + } + } +} +)"; + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloText)); + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(*module_and_metadata.computation, + module_and_metadata.block_level_parameters, + R"( +CHECK: %[[LOAD:.*]] = tt.load %{{.*}} {{.*}} : !tt.ptr> +CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1> +CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1> +)")); +} + +TEST_F(TritonTest, + DISABLED_CodegenBatchedDotWithConcatenationWithCorrectBatchStride) { + constexpr 
absl::string_view kHloText = R"( +HloModule t, is_scheduled=true + +triton_gemm { + parameter_0 = f32[2,3,10]{2,1,0} parameter(0) + parameter_1 = f32[2,10,128]{2,1,0} parameter(1) + parameter_2 = f32[2,10,256]{2,1,0} parameter(2) + concatenate = f32[2,10,384]{2,1,0} concatenate(parameter_1, parameter_2), dimensions={2} + ROOT dot = f32[2,3,384]{2,1,0} dot(parameter_0, concatenate), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} +} + +ENTRY e { + parameter_0 = f32[2,3,10]{2,1,0} parameter(0) + parameter_1 = f32[2,10,128]{2,1,0} parameter(1) + parameter_2 = f32[2,10,256]{2,1,0} parameter(2) + ROOT dot = f32[2,3,384]{2,1,0} fusion(parameter_0, parameter_1, parameter_2), + kind=kCustom, calls=triton_gemm, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} +})"; + + TF_EXPECT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm", R"( +CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr +CHECK-SAME: %[[P1:[^:]*]]: !tt.ptr +CHECK-SAME: %[[P2:[^:]*]]: !tt.ptr +CHECK-DAG: %[[ARG_PTR:.*]] = arith.select %[[CONCAT_COND:.*]], %[[P1]], %[[P2]] +CHECK-DAG: %[[BATCH_STRIDE_P1:.*]] = arith.constant 1280 +CHECK-DAG: %[[BATCH_STRIDE_P2:.*]] = arith.constant 2560 +CHECK-DAG: %[[BATCH_STRIDE:.*]] = arith.select %[[CONCAT_COND_2:.*]], %[[BATCH_STRIDE_P1]], %[[BATCH_STRIDE_P2]] +CHECK-DAG: %[[PID_BATCH:.*]] = tt.get_program_id y +CHECK-DAG: %[[OFFSET:.*]] = arith.muli %[[PID_BATCH]], %[[BATCH_STRIDE]] +CHECK: %[[BLOCK_BASE_PTR:.*]] = tt.addptr %[[ARG_PTR]], %[[OFFSET]] +)")); +} + +TEST_F(TritonTest, DISABLED_CodegenDynamicSliceWithCorrectOffsets) { + // The start index(es) for the non-majormost dimension(s) are constant zero(s) + // because we don't support dynamic slice on those dimensions. 
+ constexpr absl::string_view kHloText = R"( +HloModule t + +triton_gemm { + dot_lhs = f32[2,4] parameter(0) + dynamic_slice_input = f32[4,5,2] parameter(1) + start_index0 = s32[] parameter(2) + start_index1 = s32[] parameter(3) + start_index2 = s32[] parameter(4) + dynamic_slice = f32[1,5,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1, start_index2), dynamic_slice_sizes={1,5,2} + bitcast = f32[5,2] bitcast(dynamic_slice) + ROOT dot = f32[4,5] dot(dot_lhs, bitcast), lhs_contracting_dims={0}, rhs_contracting_dims={1} +} + +ENTRY e { + dot_lhs = f32[2,4] parameter(0) + dynamic_slice_input = f32[4,5,2] parameter(1) + start_index0 = s32[] parameter(2) + start_index1 = s32[] constant(0) + start_index2 = s32[] constant(0) + ROOT fusion = f32[4,5] fusion(dot_lhs, dynamic_slice_input, start_index0, start_index1, start_index2), + kind=kCustom, calls=triton_gemm, + backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm","triton_gemm_config":{ + "block_m":"32","block_n":"32","block_k":"32","split_k":"1", + "num_stages":"1","num_warps":"4","num_ctas":"1"}}} +})"; + + ASSERT_THAT( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm", R"( +CHECK: tt.func @triton_fn({{[^,]*}}, %[[DYNAMIC_SLICE_INPUT:[^:]*]]: !tt.ptr {{[^,]*}}, %[[START_INDEX0_PTR:[^:]*]]: !tt.ptr +CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32 +CHECK-DAG: %[[C1_i64:.*]] = arith.constant 1 : i64 +CHECK-DAG: %[[C2_i64:.*]] = arith.constant 2 : i64 +CHECK-DAG: %[[C3_i32:.*]] = arith.constant 3 : i32 +CHECK-DAG: %[[C5_i32:.*]] = arith.constant 5 : i32 +CHECK-DAG: %[[C5_i64:.*]] = arith.constant 5 : i64 +CHECK-DAG: %[[START_INDEX0:.*]] = tt.load %[[START_INDEX0_PTR]] : !tt.ptr +CHECK-DAG: %[[SEMI_CLAMPED_START_INDEX0:.*]] = arith.maxsi %[[START_INDEX0]], %[[C0_i32]] : i32 +CHECK-DAG: %[[CLAMPED_START_INDEX0:.*]] = arith.minsi %[[SEMI_CLAMPED_START_INDEX0]], %[[C3_i32]] : i32 +CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.muli %[[CLAMPED_START_INDEX0]], %[[C5_i32]] : i32 +CHECK-DAG: %[[ROW_OFFSET_i64:.*]] = arith.extsi %[[ROW_OFFSET]] : i32 to i64 +CHECK-DAG: %[[ROW_LIMIT:.*]] = arith.addi %[[ROW_OFFSET_i64]], %[[C5_i64]] : i64 +CHECK-DAG: tt.make_tensor_ptr %[[DYNAMIC_SLICE_INPUT]], [%[[C2_i64]], %[[ROW_LIMIT]]], [%[[C1_i64]], %[[C2_i64]]], [%[[C0_i32]], %[[ROW_OFFSET]]] +)"), + tsl::testing::IsOk()); +} + +TEST_F(TritonGemmTest, DISABLED_DoNotUseTensorCoresWithNonDefaultPrecision) { + constexpr absl::string_view kHloText = R"( +triton_gemm_r { + parameter_0 = s8[80,15]{1,0} parameter(0) + convert.3 = f32[80,15]{1,0} convert(parameter_0) + parameter_1 = f32[16,15]{1,0} parameter(1) + ROOT r.1 = f32[80,16]{1,0} dot(convert.3, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + p1 = f32[16,15]{1,0} parameter(1) + p0 = s8[80,15]{1,0} parameter(0) + ROOT triton_gemm_r = f32[80,16]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_r, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( +CHECK-NOT: mma +)"); +} + +TEST_F(TritonGemmTest, DISABLED_DebugOptionsArePropagated) { + constexpr absl::string_view kHloText = R"( +ENTRY e { + p0 = f16[30,30] parameter(0) + p1 = s8[30,30] parameter(1) + cp1 = f16[30,30] convert(p1) + ROOT _ = 
f16[30,30] dot(p0, cp1),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> verified_module,
+                          ParseAndReturnVerifiedModule(kHloText));
+  std::string output_directory;
+  if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) {
+    output_directory = tsl::testing::TmpDir();
+  }
+  DebugOptions debug_options = verified_module->config().debug_options();
+  debug_options.set_xla_dump_to(output_directory);
+  debug_options.set_xla_dump_hlo_pass_re("triton-fusion-emitter");
+  verified_module->mutable_config().set_debug_options(debug_options);
+
+  EXPECT_TRUE(RunAndCompare(std::move(verified_module),
+                            ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+
+  std::vector<std::string> paths;
+  TF_EXPECT_OK(tsl::Env::Default()->GetMatchingPaths(
+      tsl::io::JoinPath(output_directory, "*.triton-passes.log"), &paths));
+  EXPECT_EQ(paths.size(), 1);
+}
+
+TEST_F(TritonGemmTest, DISABLED_DotWithPredFromCompareProducesCorrectResult) {
+  const std::string hlo_text = R"(
+triton_dot {
+  parameter_0 = s32[4,128]{1,0} parameter(0)
+  broadcast.255 = s32[4,128,64]{2,1,0} broadcast(parameter_0), dimensions={0,1}
+  parameter_1 = s32[4,128,64]{2,1,0} parameter(1)
+  compare.39 = pred[4,128,64]{2,1,0} compare(broadcast.255, parameter_1), direction=EQ
+  bitcast.1097 = pred[512,64]{1,0} reshape(compare.39)
+  convert.229 = bf16[512,64]{1,0} convert(bitcast.1097)
+  parameter_2 = bf16[64,256]{0,1} parameter(2)
+  ROOT dot.21 = bf16[512,256]{1,0} dot(convert.229, parameter_2),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+ENTRY main {
+  p0 = s32[4,128]{1,0} parameter(0)
+  p1 = s32[4,128,64]{2,1,0} parameter(1)
+  p2 = bf16[64,256]{0,1} parameter(2)
+  ROOT gemm_fusion_dot.0 = bf16[512,256]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=triton_dot, backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"64","block_n":"128","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4","num_ctas":"1"}}}
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_UseTensorCoresForF32OnAmpere) {
+  constexpr absl::string_view kHloText = R"(
+triton_gemm_r {
+  parameter_0 = f16[80,15]{1,0} parameter(0)
+  convert.3 = f32[80,15]{1,0} convert(parameter_0)
+  parameter_1 = f32[16,15]{1,0} parameter(1)
+  ROOT r.1 = f32[80,16]{1,0} dot(convert.3, parameter_1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY e {
+  p1 = f32[16,15]{1,0} parameter(1)
+  p0 = f16[80,15]{1,0} parameter(0)
+  ROOT triton_gemm_r = f32[80,16]{1,0} fusion(p0, p1), kind=kCustom,
+    calls=triton_gemm_r,
+    backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config:
+      {"block_m":32,"block_n":32,"block_k":32,
+       "split_k":1,"num_stages":1,"num_warps":2,
+       "num_ctas":1}}}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> verified_module,
+                          ParseAndReturnVerifiedModule(kHloText));
+
+  CompileAndOptionallyVerifyPtx(std::move(verified_module),
+                                R"(
+CHECK: mma
+)");
+}
+
+TEST_F(TritonGemmTest, DISABLED_FailIfTooMuchShmem) {
+  if (std::holds_alternative<se::RocmComputeCapability>(
+          GpuComputeCapability())) {
+    GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet.";
+  }
+  constexpr absl::string_view kHloText = R"(
+HloModule module, is_scheduled=true
+
+triton_gemm_dot {
+  p0 = s8[1024,1024] parameter(0)
+  p1 = f32[1024,1024] parameter(1)
+  c0 = f32[1024,1024] convert(p0)
+  ROOT dot.0 = f32[1024,1024] dot(c0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY entry {
+  p0 = 
s8[1024,1024] parameter(0)
+  p1 = f32[1024,1024] parameter(1)
+  ROOT r = f32[1024,1024] fusion(p0, p1),
+    kind=kCustom, calls=triton_gemm_dot,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
+                          ParseAndReturnVerifiedModule(kHloText));
+  HloFusionInstruction* triton_dot_fusion = Cast<HloFusionInstruction>(
+      hlo_module->entry_computation()->root_instruction());
+  const se::DeviceDescription dev_info =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  llvm::LLVMContext llvm_ctx;
+  llvm::Module llvm_module("module", llvm_ctx);
+  mlir::MLIRContext mlir_context;
+
+  auto backend_config_or =
+      triton_dot_fusion->backend_config<GpuBackendConfig>();
+  TF_ASSERT_OK(backend_config_or);
+  GpuBackendConfig& backend_config = *backend_config_or;
+
+  FusionBackendConfig& fusion_backend_config =
+      *backend_config.mutable_fusion_backend_config();
+  auto& config = *fusion_backend_config.mutable_triton_gemm_config();
+  config.set_block_m(16);
+  config.set_block_n(32);
+  config.set_block_k(512);
+  config.set_split_k(1);
+  config.set_num_ctas(1);
+  config.set_num_warps(8);
+  config.set_num_stages(4);
+
+  TF_ASSERT_OK(triton_dot_fusion->set_backend_config(backend_config));
+
+  BlockLevelParameters block_level_parameters;
+  block_level_parameters.num_ctas = 1;
+  block_level_parameters.num_stages = 4;
+  block_level_parameters.num_warps = 8;
+
+  EXPECT_THAT(
+      TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info,
+                    block_level_parameters, &llvm_module, mlir_context),
+      StatusIs(tsl::error::RESOURCE_EXHAUSTED,
+               ::testing::HasSubstr("Shared memory size limit exceeded")));
+
+  config.set_block_m(64);
+  config.set_block_n(128);
+  config.set_block_k(128);
+  block_level_parameters.num_stages = 1;
+  TF_ASSERT_OK(triton_dot_fusion->set_backend_config(backend_config));
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      const auto result,
+      TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info,
+                    block_level_parameters, &llvm_module, mlir_context));
+  // Use optin shared memory which is > shared_memory_per_block.
+  EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block());
+}
+
+TEST_F(TritonGemmTestWithSplitK,
+       DISABLED_WorksWhenKIsDivisibleByBlockKButNotByBlockKTimesSplitK) {
+  // The condition mentioned in the test name is fulfilled by
+  // GemmKey(16, 64, 256, 8, 1, 4), which was part of the default configs for
+  // Ampere at the time of the addition of this test case.
+  constexpr absl::string_view kHloText = R"(
+HloModule extracted
+
+ENTRY e {
+  a = f16[16,5120]{1,0} parameter(0)
+  b = s8[5120,10240]{1,0} parameter(1)
+  converted_b = f16[5120,10240]{1,0} convert(b)
+  ROOT r = f16[16,10240]{1,0} dot(a, converted_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  // This check tests whether Triton is used at all, and it also runs
+  // GemmFusionAutotuner, which verifies that the generated kernels can run
+  // without errors such as CUDA_ERROR_ILLEGAL_ADDRESS.
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: ENTRY
+; CHECK-NEXT: parameter
+; CHECK-NEXT: parameter
+; CHECK-NEXT: fusion(
+; CHECK-SAME: kind=kCustom
+; CHECK-PTX-SAME: "block_m":
+  )");
+
+  // Not doing a comparison here, because the input matrices are quite big.
+  // If their size were reduced, they could no longer trigger the error that
+  // this test case is meant to guard against.
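+  // For reference, with the block_k = 256, split_k = 8 configuration
+  // mentioned above: K = 5120 is divisible by block_k (5120 / 256 = 20) but
+  // not by block_k * split_k = 2048 (5120 / 2048 = 2.5).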
+} + +TEST_F(TritonGemmTest, DISABLED_MultipleDims) { + const std::string hlo_text = R"( +HloModule t + +ENTRY e { + p0 = f16[1,16,17,3] parameter(0) + p1 = s8[16,17,3] parameter(1) + cp1 = f16[16,17,3] convert(p1) + ROOT _ = f16[1,16,16] dot(p0, cp1), + lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2} +})"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: "block_m": + )"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_PredWithBF16DotProducesCorrectResult) { + const std::string hlo_text = R"( +triton_dot { + p0 = pred[8,640]{1,0} parameter(0) + cvt = bf16[8,640]{1,0} convert(pred[8,640]{1,0} p0) + p1 = bf16[4096,640]{1,0} parameter(1) + ROOT dot.10277 = bf16[8,4096]{1,0} dot(cvt, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = pred[8,640]{1,0} parameter(0) + p1 = bf16[4096,640]{1,0} parameter(1) + ROOT dot = bf16[8,4096]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: + {"block_m":16,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":8, + "num_ctas":1}}} +})"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_NoPadding) { + const char* hlo_text = R"( +HloModule t + +ENTRY e { + p0 = f16[15,19] parameter(0) + p1 = s8[19,17] parameter(1) + cp1 = f16[19,17] convert(p1) + ROOT _ = f16[15,17] dot(p0, cp1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: ROOT +; CHECK-SAME: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: "block_m": +; CHECK-NOT: pad +; CHECK-NOT: slice +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_S8xS8) { + const std::string hlo_text = R"( +HloModule t + +ENTRY f { + x = s8[1024,1024]{1,0} parameter(0) + y = s8[1024,1024]{1,0} parameter(1) + ROOT z = s32[1024,1024]{1,0} dot(x, y), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_SplitLhsNoncontractingTransposeRhs) { + const std::string hlo_text = R"( +HloModule t + +ENTRY e { + p0 = pred[3,122,96,12]{3,2,1,0} parameter(0) + cp0 = f16[3,122,96,12]{3,2,1,0} convert(p0) + p1 = pred[1,5,122]{2,1,0} parameter(1) + cp1 = f16[1,5,122]{2,1,0} convert(p1) + ROOT _ = f16[3,96,12,1,5]{4,3,2,1,0} dot(cp0, cp1), + lhs_contracting_dims={1}, rhs_contracting_dims={2} +})"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: "block_m": +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(TritonGemmTest, DISABLED_SplitLhsNoncontracting) { + const std::string hlo_text = R"( +HloModule t + +ENTRY e { + p0 = f32[72,72] parameter(0) + bc1 = f32[4,3,3,2,4,3,3,2] reshape(p0) + tr = f32[4,3,3,2,2,4,3,3] transpose(bc1), dimensions={0,1,2,3,7,4,5,6} + bc2 = f32[144,36] reshape(tr) + p1 = f16[36,3] parameter(1) + c7 = f32[36,3] convert(p1) + ROOT _ = f32[144,3] dot(bc2, c7), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + 
MatchOptimizedHlo(hlo_text, R"(
+; CHECK: ENTRY
+; CHECK-NEXT: parameter
+; CHECK-NEXT: parameter
+; CHECK-NEXT: fusion(
+; CHECK-SAME: kind=kCustom
+; CHECK-PTX-SAME: "block_m":
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_SplitAndTransposeLhsExecutesCorrectly) {
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  tmp_0 = s8[5,50,2,128] parameter(1)
+  tmp_2 = s8[50,5,2,128] transpose(tmp_0), dimensions={1,0,2,3}
+  tmp_3 = s8[50,1280] reshape(tmp_2)
+  tmp_4 = f16[50,1280] convert(tmp_3)
+  tmp_5 = f16[50,79] parameter(0)
+  ROOT tmp_6 = f16[1280,79] dot(tmp_4, tmp_5),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: ENTRY
+; CHECK-NEXT: parameter
+; CHECK-NEXT: parameter
+; CHECK-NEXT: ROOT
+; CHECK-SAME: fusion
+; CHECK-SAME: kind=kCustom
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_NondefaultOperandLayoutIsSupported) {
+  // TODO(bchetioui): reenable when b/285866137 is fixed.
+#ifndef NDEBUG
+  GTEST_SKIP() << "This test times out when -UNDEBUG is set.";
+#endif
+  constexpr absl::string_view kHloText = R"(
+ENTRY r {
+  p1 = f16[9,140,128]{2,1,0} parameter(1)
+  cp = f16[9,140,128]{2,0,1} copy(p1)
+  cv = f32[9,140,128]{2,0,1} convert(cp)
+  p0 = f32[9,140,123]{2,1,0} parameter(0)
+  ROOT d = f32[9,128,123]{2,1,0} dot(cv, p0),
+    lhs_batch_dims={0}, lhs_contracting_dims={1},
+    rhs_batch_dims={0}, rhs_contracting_dims={1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_DoNotFuseSplitRhsContractingTranspose) {
+  const std::string hlo_text = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f16[5,8] parameter(0)
+  p1 = s8[2,3,4] parameter(1)
+  c0 = f16[2,3,4] convert(p1)
+  t1 = f16[3,2,4] transpose(c0), dimensions={1,0,2}
+  r1 = f16[3,8] reshape(t1)
+  ROOT _ = f16[5,3] dot(p0, r1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})";
+
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK: ENTRY
+; CHECK: transpose
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-PTX-SAME: "block_m":
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_DoNotFuseSplitLhsContractingTranspose) {
+  const std::string hlo_text = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f16[3,16,25]{2,1,0} parameter(0)
+  p0t = f16[16,3,25]{2,1,0} transpose(p0), dimensions={1,0,2}
+  p0tr = f16[16,75]{1,0} reshape(p0t)
+  p1 = s8[128,75]{1,0} parameter(1)
+  cp1 = f16[128,75]{1,0} convert(p1)
+  ROOT dot.126 = f16[16,128]{1,0} dot(p0tr, cp1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})";
+
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK: ENTRY
+; CHECK: transpose
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-PTX-SAME: "block_m":
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_BatchF32F16) {
+  const std::string hlo_text = R"(
+HloModule t
+
+ENTRY e {
+  x = f32[5,2,3] parameter(0)
+  y = f16[5,3,4] parameter(1)
+  cy = f32[5,3,4] convert(y)
+  ROOT _ = f32[5,2,4] dot(x, cy),
+    lhs_contracting_dims={2}, 
rhs_contracting_dims={1}, + lhs_batch_dims={0}, rhs_batch_dims={0} +})"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: fusion +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: "block_m": +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, DISABLED_NonMajorMostInputBatchWorksCorrectly) { + const std::string hlo_text = R"( +HloModule t + +ENTRY e { + x = f32[20,50,30] parameter(0) + y = f16[30,50,40] parameter(1) + cy = f32[30,50,40] convert(y) + ROOT _ = f32[50,20,40] dot(x, cy), + lhs_contracting_dims={2}, rhs_contracting_dims={0}, + lhs_batch_dims={1}, rhs_batch_dims={1} +})"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: fusion +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: "block_m": +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_BatchTransposeF32F16) { + const std::string hlo_text = R"( +HloModule t + +ENTRY e { + x = f32[5,3,2] parameter(0) + y = f16[5,3,4] parameter(1) + cy = f32[5,3,4] convert(y) + x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1} + ROOT _ = f32[5,2,4] dot(x_transposed, cy), + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + lhs_batch_dims={0}, rhs_batch_dims={0} +})"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: fusion +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: "block_m": +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, DISABLED_DoNotFuseArbitraryReshape) { + const std::string hlo_text = R"( +HloModule m + +ENTRY e { + p0 = f16[5,2,3] parameter(0) + p0c = f32[5,2,3] convert(p0) + p1 = f32[20,3] parameter(1) + p1r = f32[5,3,4] reshape(p1) + ROOT dot.5 = f32[5,2,4] dot(p0c, p1r), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} +})"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK: f32[5,3,4]{2,1,0} bitcast +; CHECK: fusion +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: "block_m": +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); +} + +TEST_F(TritonGemmTest, DISABLED_MultipleBatchRequireSeparateTranspose) { + constexpr absl::string_view kHloText = R"( +HloModule m + +ENTRY e { + Arg_0 = f16[3,4,2,5,4] parameter(0) + c = f32[3,4,2,5,4] convert(Arg_0) + Arg_1 = f32[5,3,4,3,2] parameter(1) + ROOT dot.3 = f32[5,3,4,4,3] dot(c, Arg_1), + lhs_batch_dims={3,0,1}, lhs_contracting_dims={2}, + rhs_batch_dims={0,1,2}, rhs_contracting_dims={4} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ROOT +; CHECK: transpose( +; CHECK: bitcast( +; CHECK: kCustom +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); +} + +TEST_F(TritonGemmTest, + DISABLED_CanCodegenNonBatchedDotWithConcatenationCorrectly) { + constexpr absl::string_view kHloText = R"( +ENTRY e { + parameter_0 = f32[3,10]{1,0} parameter(0) + parameter_1 = f32[10,128]{1,0} parameter(1) + parameter_2 = f32[10,256]{1,0} parameter(2) + concatenate = f32[10,384]{1,0} concatenate(parameter_1, parameter_2), dimensions={1} + ROOT dot = f32[3,384]{1,0} dot(parameter_0, concatenate), + lhs_batch_dims={}, lhs_contracting_dims={1}, + rhs_batch_dims={}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK-NOT: concatenate +; 
CHECK: fusion
+; CHECK-SAME: kind=kCustom
+)");
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHloText));
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest,
+       DISABLED_CanCodegenBatchedDotWithConcatenationCorrectly) {
+  constexpr absl::string_view kHloText = R"(
+ENTRY e {
+  parameter_0 = f32[2,3,10]{2,1,0} parameter(0)
+  parameter_1 = f32[2,10,128]{2,1,0} parameter(1)
+  parameter_2 = f32[2,10,256]{2,1,0} parameter(2)
+  concatenate = f32[2,10,384]{2,1,0} concatenate(parameter_1, parameter_2), dimensions={2}
+  ROOT dot = f32[2,3,384]{2,1,0} dot(parameter_0, concatenate),
+    lhs_batch_dims={0}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={1}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: ENTRY
+; CHECK-NOT: concatenate
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+)");
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHloText));
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonTest, DISABLED_FloatToSignedIntConversion) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t, is_scheduled=true
+
+triton_gemm_r {
+  p_0 = s8[32,32]{1,0} parameter(0)
+  p_1 = f16[32,32]{1,0} parameter(1)
+  cvt_1 = s8[32,32]{1,0} convert(p_1)
+  ROOT r.1 = f32[32,32]{1,0} dot(p_0, cvt_1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY e {
+  p_0 = s8[32,32]{1,0} parameter(0)
+  p_1 = f16[32,32]{1,0} parameter(1)
+  ROOT triton_gemm_r = f32[32,32]{1,0} fusion(p_0, p_1), kind=kCustom,
+    calls=triton_gemm_r,
+    backend_config={"fusion_backend_config": {kind: "__triton_gemm",
+      triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,
+                           "split_k":1,"num_stages":1,"num_warps":4,
+                           "num_ctas":1}}}
+})";
+  TF_EXPECT_OK(
+      CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm_r", R"(
+CHECK:     tt.func @triton_fn
+CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0>
+CHECK-DAG: %[[FMIN:.*]] = arith.constant dense<-1.280000e+02>
+CHECK-DAG: %[[IMIN:.*]] = arith.constant dense<-128>
+CHECK-DAG: %[[FMAX:.*]] = arith.constant dense<1.270000e+02>
+CHECK-DAG: %[[IMAX:.*]] = arith.constant dense<127>
+CHECK:     %[[FPTOSI:.*]] = arith.fptosi %[[IN:.*]] :
+CHECK:     %[[CMP1:.*]] = arith.cmpf ole, %[[IN]], %[[FMIN]]
+CHECK:     %[[RES1:.*]] = arith.select %[[CMP1]], %[[IMIN]], %[[FPTOSI]]
+CHECK:     %[[CMP2:.*]] = arith.cmpf oge, %[[IN]], %[[FMAX]]
+CHECK:     %[[RES2:.*]] = arith.select %[[CMP2]], %[[IMAX]], %[[RES1]]
+CHECK:     %[[CMP3:.*]] = arith.cmpf uno, %[[IN]], %[[IN]]
+CHECK:     %[[RES3:.*]] = arith.select %[[CMP3]], %[[ZERO]], %[[RES2]]
+})"));
+}
+
+// This tests the complexity heuristics in TritonWrapper.
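+// For the failing config below (block_m = block_n = 512, block_k = 32,
+// num_warps = 2) the reported value appears consistent with the heuristic
+// (block_m * block_n + (block_m + block_n) * block_k) / num_warps
+//   = (512 * 512 + 1024 * 32) / 2 = 147456 > 9000,
+// while the passing config gives (32 * 32 + 64 * 32) / 2 = 1536 <= 9000.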
+TEST_F(TritonGemmTest, DISABLED_FailForTooComplexTiling) {
+  constexpr absl::string_view kHloText = R"(
+HloModule module, is_scheduled=true
+
+triton_gemm_dot {
+  p0 = s8[1024,1024] parameter(0)
+  p1 = f32[1024,1024] parameter(1)
+  c0 = f32[1024,1024] convert(p0)
+  ROOT dot.0 = f32[1024,1024] dot(c0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY entry {
+  p0 = s8[1024,1024] parameter(0)
+  p1 = f32[1024,1024] parameter(1)
+  ROOT r = f32[1024,1024] fusion(p0, p1),
+    kind=kCustom, calls=triton_gemm_dot,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
+                          ParseAndReturnVerifiedModule(kHloText));
+  HloFusionInstruction* triton_dot_fusion = Cast<HloFusionInstruction>(
+      hlo_module->entry_computation()->root_instruction());
+  const se::DeviceDescription dev_info =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  llvm::LLVMContext llvm_ctx;
+  llvm::Module llvm_module("module", llvm_ctx);
+  mlir::MLIRContext mlir_context;
+
+  auto backend_config_or =
+      triton_dot_fusion->backend_config<GpuBackendConfig>();
+  TF_ASSERT_OK(backend_config_or);
+  GpuBackendConfig& backend_config = *backend_config_or;
+
+  FusionBackendConfig& fusion_backend_config =
+      *backend_config.mutable_fusion_backend_config();
+  auto& config = *fusion_backend_config.mutable_triton_gemm_config();
+  // Fails if the tiling is too complex.
+  config.set_block_m(512);
+  config.set_block_n(512);
+  config.set_block_k(32);
+  config.set_split_k(1);
+  config.set_num_ctas(1);
+  config.set_num_stages(1);
+  config.set_num_warps(2);
+  TF_ASSERT_OK(triton_dot_fusion->set_backend_config(backend_config));
+
+  BlockLevelParameters block_level_parameters;
+  block_level_parameters.num_ctas = 1;
+  block_level_parameters.num_stages = 1;
+  block_level_parameters.num_warps = 2;
+  EXPECT_THAT(
+      TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info,
+                    block_level_parameters, &llvm_module, mlir_context),
+      StatusIs(tsl::error::RESOURCE_EXHAUSTED,
+               "Tiling complexity heuristic exceeded: 147456 > 9000"));
+
+  // Succeeds if the tiling is not too complex.
+  config.set_block_m(32);
+  config.set_block_n(32);
+  config.set_block_k(32);
+  TF_ASSERT_OK(triton_dot_fusion->set_backend_config(backend_config));
+
+  TF_ASSERT_OK(TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(),
+                             dev_info, block_level_parameters, &llvm_module,
+                             mlir_context)
+                   .status());
+}
+
+// The Triton compiler used to have an issue with reordering constants:
+// https://github.com/openai/triton/issues/1864
+TEST_F(TritonGemmTest, DISABLED_TritonCompilerDoesNotFailOnConstants) {
+  TF_ASSERT_OK(GetOptimizedModule(R"(
+HloModule m
+
+triton_gemm___computation {
+  parameter_0 = f32[92,11]{1,0} parameter(0)
+  c = f32[] constant(0)
+  b = f32[11,63] broadcast(c)
+  ROOT _.1 = f32[92,63]{1,0} dot(parameter_0, b),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+  p0 = f32[92,11]{1,0} parameter(0)
+  ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0), kind=kCustom,
+    calls=triton_gemm___computation,
+    backend_config={"fusion_backend_config": {"kind":"__triton_gemm",
+      "triton_gemm_config":{"block_m":"16","block_n":"64",
+                            "block_k":"16","split_k":"1",
+                            "num_stages":"3","num_warps":"2",
+                            "num_ctas":"1"}}}
+})")
+                   .status());
+}
+
+// Normally optimized HLO should contain `copy` instead of `transpose` but
+// it's also possible to get transposes by modifying the compiler's pipeline.
+// The emitter just has to skip through the transpose - it's handled by the
+// tiled fusion analysis.
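+// (In the test below the transpose permutes f16[111,77,99] with
+// dimensions={1,2,0} into f16[77,99,111], so the dot's batch dimension of
+// size 77 becomes majormost on the RHS.)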
+TEST_F(TritonGemmTest, DISABLED_TritonEmitterCanHandleTransposes) {
+  MatchOptimizedHlo(R"(
+t {
+  p0 = f16[55,77,111]{2,1,0} parameter(0)
+  p1 = f16[111,77,99]{2,1,0} parameter(1)
+  t = f16[77,99,111]{2,1,0} transpose(p1), dimensions={1,2,0}
+  ROOT d = f16[77,55,99]{2,1,0} dot(p0, t),
+    lhs_batch_dims={1}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+
+ENTRY e {
+  p0 = f16[55,77,111]{2,1,0} parameter(0)
+  p1 = f16[111,77,99]{2,1,0} parameter(1)
+  ROOT r = f16[77,55,99]{2,1,0} fusion(p0, p1), kind=kCustom,
+    calls=t, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}}
+})",
+                    // This partially optimized HLO will go through the
+                    // autotuner which will run the fusion through the emitter
+                    // multiple times and assign block sizes on success.
+                    R"(
+; CHECK: f16[77,99,111]{2,1,0} transpose
+; CHECK-PTX: block_m
+)");
+}
+
+TEST_F(TritonGemmTest, DISABLED_SingleElementTileIsHandled) {
+  if (std::holds_alternative<se::RocmComputeCapability>(
+          GpuComputeCapability())) {
+    GTEST_SKIP() << "Not using autotuner on ROCM yet.";
+  }
+  MatchOptimizedHlo(R"(
+t {
+  p0 = f32[2,7,3]{2,1,0} parameter(0)
+  p1 = s32[2,1]{1,0} parameter(1)
+  c = s32[] constant(1)
+  br0 = s32[2,1]{1,0} broadcast(c), dimensions={}
+  cmp = pred[2,1]{1,0} compare(p1, br0), direction=LT
+  bc0 = pred[2]{0} bitcast(cmp)
+  br1 = pred[2,1,3,3]{3,2,0,1} broadcast(bc0), dimensions={0}
+  cvt = f32[2,1,3,3]{3,2,0,1} convert(br1)
+  bc1 = f32[2,3,3]{2,1,0} bitcast(cvt)
+  ROOT d = f32[2,7,3]{2,1,0} dot(p0, bc1),
+    lhs_batch_dims={0}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={1}
+}
+
+ENTRY e {
+  p0 = f32[2,7,3]{2,1,0} parameter(0)
+  p1 = s32[2,1]{1,0} parameter(1)
+  ROOT r = f32[2,7,3]{2,1,0} fusion(p0, p1), kind=kCustom,
+    calls=t, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}}
+})",
+                    // This partially optimized HLO will go through the
+                    // autotuner which will run the fusion through the emitter
+                    // multiple times and assign block sizes on success.
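+                    // The CHECK on "block_m" below therefore only matches if
+                    // the autotuner actually assigned block sizes to the
+                    // fusion's backend config.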
+                    R"(
+; CHECK: block_m
+)");
+}
+
+TEST_F(
+    TritonGemmTest,
+    DISABLED_BroadcastsOfTriviallySizedNonContractingDimensionsAreSupported) {
+  EXPECT_TRUE(RunAndCompare(R"(
+f {
+  p0 = f32[64,6464] parameter(0)
+  p1 = f32[16,6464] parameter(1)
+  dot = f32[16,64] dot(p1, p0),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+  bc0 = f32[1,16,64] bitcast(dot)
+  p2 = f32[64] parameter(2)
+  bc1 = f32[1,64] bitcast(p2)
+  br = f32[1,16,64] broadcast(bc1), dimensions={0,2}
+  m = f32[1,16,64] multiply(bc0, br)
+}
+
+e {
+  p0 = f32[64,6464] parameter(0)
+  p1 = f32[16,6464] parameter(1)
+  p2 = f32[64] parameter(2)
+  f = f32[1,16,64] fusion(p0, p1, p2),
+    kind=kCustom, calls=f, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}}
+})",
+                            ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest,
+       DISABLED_BroadcastsOfTriviallySizedContractingDimensionsAreSupported) {
+  EXPECT_TRUE(RunAndCompare(R"(
+f {
+  a = f16[2] parameter(0)
+  bc0 = f16[1,2] bitcast(a)
+  br = f16[1,4000,2] broadcast(bc0), dimensions={0,2}
+  bc1 = f16[4000,2] bitcast(br)
+  b = f16[3,4000] parameter(1)
+  d = f16[2,3] dot(bc1, b),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+}
+
+e {
+  a = f16[2] parameter(0)
+  b = f16[3,4000] parameter(1)
+  f = f16[2,3] fusion(a, b),
+    kind=kCustom, calls=f, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}}
+})",
+                            ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_DoF32F32) {
+  const std::string hlo_text = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[3,5] parameter(0)
+  p1 = f32[5,7] parameter(1)
+  ROOT _ = f32[3,7] dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK: fusion(
+; CHECK-SAME: kind=kCustom
+; CHECK-PTX-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_DoAddConstantToScalarAndBroadcastThat) {
+  if (std::holds_alternative<se::RocmComputeCapability>(
+          GpuComputeCapability())) {
+    GTEST_SKIP() << "Not using autotuner on ROCM yet.";
+  }
+  const std::string hlo_text = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[] parameter(0)
+  p1 = f32[5,5] parameter(1)
+  %constant = f32[] constant(8)
+  add = f32[] add(p0, constant)
+  broadcast = f32[5,5] broadcast(add), dimensions={}
+  ROOT _ = f32[5,5] dot(broadcast, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK: fusion({{.*}} kind=kCustom, {{.*}}block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_SameInput) {
+  const std::string hlo_text = R"(
+HloModule m
+
+ENTRY e {
+  p0 = pred[5,5]{1,0} parameter(0)
+  c = f32[5,5]{1,0} convert(p0)
+  ROOT r = f32[5,5]{1,0} dot(c, c),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})";
+
+  // The fusion has separate parameters for each scope.
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK: ENTRY
+; CHECK: %[[p0:.*]] = pred[5,5]{1,0} parameter(0)
+; CHECK: fusion(%[[p0]], %[[p0]]), kind=kCustom
+; CHECK-PTX-SAME: "block_m":
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_DynamicSliceIsSupportedInLhsEndToEnd) {
+  // The select is used to restrict the start index to values that make sense.
+  // If it were constant, then the dynamic-slice would be optimized to a slice.
+  // It is not strictly needed, because we also support clamping the indices.
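+  // (Clamping of out-of-range start indices is exercised separately by
+  // TritonGemmDynamicSliceClampingTest below.)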
+  // The start index(es) for the non-majormost dimension(s) are constant zero(s)
+  // because we don't support dynamic slice on those dimensions.
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  dot_lhs = f32[2,4] parameter(0)
+  dynamic_slice_input = f32[7,2] parameter(1)
+  pred0 = pred[] parameter(2)
+  c1 = s32[] constant(1)
+  c2 = s32[] constant(2)
+  start_index0 = s32[] select(pred0, c1, c2)
+  start_index1 = s32[] constant(0)
+  dynamic_slice = f32[5,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1),
+                  dynamic_slice_sizes={5,2}
+  ROOT dot = f32[4,5] dot(dot_lhs, dynamic_slice),
+          lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(),
+                           m::Fusion(m::Parameter()), m::Constant())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+  // Check that it's not optimized away.
+  MatchHloModule(*module, "; CHECK: dynamic-slice(");
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_DynamicSliceIsSupportedInRhs) {
+  // The start index(es) for the non-majormost dimension(s) are constant zero(s)
+  // because we don't support dynamic slice on those dimensions.
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+triton_gemm {
+  dynamic_slice_input = f32[7,2] parameter(0)
+  dot_rhs = f32[2,4] parameter(1)
+  start_index0 = s32[] parameter(2)
+  start_index1 = s32[] parameter(3)
+  dynamic_slice = f32[5,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1),
+                  dynamic_slice_sizes={5,2}
+  ROOT dot = f32[5,4] dot(dynamic_slice, dot_rhs),
+          lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+  dynamic_slice_input = f32[7,2] parameter(0)
+  dot_rhs = f32[2,4] parameter(1)
+  start_index0 = s32[] constant(1)
+  start_index1 = s32[] constant(0)
+  ROOT fusion = f32[5,4] fusion(dynamic_slice_input, dot_rhs, start_index0, start_index1),
+       kind=kCustom, calls=triton_gemm,
+       backend_config={
+         "fusion_backend_config":{
+           "kind":"__triton_gemm","triton_gemm_config":{
+             "block_m":"32","block_n":"32","block_k":"32","split_k":"1",
+             "num_stages":"1","num_warps":"4","num_ctas":"1"}}}
+})";
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(
+      kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_MultiplePathsToSameOperandWorks) {
+  constexpr absl::string_view kHloText = R"(
+triton_computation {
+  p0 = bf16[8192,512]{1,0} parameter(0)
+  p1 = bf16[512,512]{1,0} parameter(1)
+  dot = bf16[8192,512]{1,0} dot(bf16[8192,512]{1,0} p0, bf16[512,512]{1,0} p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  p2 = bf16[8192,512]{1,0} parameter(2)
+  multiply.1 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} dot, bf16[8192,512]{1,0} p2)
+  ROOT multiply.2 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} multiply.1, bf16[8192,512]{1,0} p2)
+}
+
+ENTRY e {
+  p0 = bf16[8192,512]{1,0} parameter(0)
+  p1 = bf16[512,512]{1,0} parameter(1)
+  p2 = bf16[8192,512]{1,0} parameter(2)
+  ROOT fusion = bf16[8192,512]{1,0} fusion(p0,p1,p2), kind=kCustom, calls=triton_computation,
+    backend_config={"fusion_backend_config":
+      {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"64","block_n":"256","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4","num_ctas":"1"}}}
+})";
+
+  TF_ASSERT_OK(
+      CreateTritonIrAndFileCheckForDot(this, kHloText, 
"triton_computation", R"( + CHECK: tt.dot + CHECK-SAME: tensor<64x32xbf16> * tensor<32x256xbf16> -> tensor<64x256xf32> + CHECK: arith.mulf + CHECK: arith.mulf + )")); +} + +class TritonGemmDynamicSliceClampingTest + : public TritonTest, + public ::testing::WithParamInterface {}; + +TEST_P(TritonGemmDynamicSliceClampingTest, + DISABLED_DynamicSliceIsSupportedWhenTheStartIndexNeedsClamping) { + // The start index(es) for the non-majormost dimension(s) are constant zero(s) + // because we don't support dynamic slice on those dimensions. + + const std::string hlo_text = absl::Substitute(R"( +HloModule m + +triton_gemm { + dynamic_slice_input = f32[7,2] parameter(0) + dot_rhs = f32[2,4] parameter(1) + start_index0 = s32[] parameter(2) + start_index1 = s32[] parameter(3) + dynamic_slice = f32[5,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1), + dynamic_slice_sizes={5,2} + ROOT dot = f32[5, 4] dot(dynamic_slice, dot_rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + dynamic_slice_input = f32[7,2] parameter(0) + dot_rhs = f32[2,4] parameter(1) + start_index0 = s32[] constant($0) + start_index1 = s32[] constant(0) + ROOT fusion = f32[5,4] fusion(dynamic_slice_input, dot_rhs, start_index0, start_index1), + kind=kCustom, calls=triton_gemm, + backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm","triton_gemm_config":{ + "block_m":"32","block_n":"32","block_k":"32","split_k":"1", + "num_stages":"1","num_warps":"4","num_ctas":"1"}}} +})", + GetParam()); + + EXPECT_TRUE(RunAndCompareNoHloPasses( + hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); +} + +std::string OffsetParamToString(const ::testing::TestParamInfo& data) { + return absl::StrCat("WithOffsetEq", data.param < 0 ? "Negative" : "", + std::abs(data.param)); +} + +INSTANTIATE_TEST_SUITE_P(All, TritonGemmDynamicSliceClampingTest, + ::testing::Values(-100, 3, 999), OffsetParamToString); + +TEST_F(TritonGemmTest, + DISABLED_DynamicSliceOfMajormostContractingDimIsSupported) { + // Tests that dynamic-slice works on the majormost dimension even if that + // dimension is contracted. + // The start index(es) for the non-majormost dimension(s) are constant zero(s) + // because we don't support dynamic slice on those dimensions. + constexpr absl::string_view kHloText = R"( +HloModule m + +triton_gemm { + dot_lhs = f32[2,4] parameter(0) + dynamic_slice_input = f32[5,4] parameter(1) + start_index0 = s32[] parameter(2) + start_index1 = s32[] parameter(3) + dynamic_slice = f32[2,4] dynamic-slice(dynamic_slice_input, start_index0, start_index1), + dynamic_slice_sizes={2,4} + ROOT dot = f32[4,4] dot(dot_lhs, dynamic_slice), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY e { + dot_lhs = f32[2,4] parameter(0) + dynamic_slice_input = f32[5,4] parameter(1) + start_index0 = s32[] constant(2) + start_index1 = s32[] constant(0) + ROOT fusion = f32[4,4] fusion(dot_lhs, dynamic_slice_input, start_index0, start_index1), + kind=kCustom, calls=triton_gemm, + backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm","triton_gemm_config":{ + "block_m":"32","block_n":"32","block_k":"32","split_k":"1", + "num_stages":"1","num_warps":"4","num_ctas":"1"}}} +})"; + + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); +} + +TEST_F(TritonGemmTest, DISABLED_DynamicSliceOfMajormostBatchDimIsSupported) { + // Tests that dynamic-slice works on the majormost dimension even if that + // dimension is a batch. 
+  // The start index(es) for the non-majormost dimension(s) are constant zero(s)
+  // because we don't support dynamic slice on those dimensions.
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+triton_gemm {
+  dot_lhs = f32[2,2,4] parameter(0)
+  dynamic_slice_input = f32[7,2,4] parameter(1)
+  start_index0 = s32[] parameter(2)
+  start_index1 = s32[] parameter(3)
+  start_index2 = s32[] parameter(4)
+  dynamic_slice = f32[2,2,4] dynamic-slice(dynamic_slice_input, start_index0, start_index1, start_index2),
+                  dynamic_slice_sizes={2,2,4}
+  ROOT dot = f32[2,4,4] dot(dot_lhs, dynamic_slice),
+          lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY e {
+  dot_lhs = f32[2,2,4] parameter(0)
+  dynamic_slice_input = f32[7,2,4] parameter(1)
+  start_index0 = s32[] constant(2)
+  start_index1 = s32[] constant(0)
+  start_index2 = s32[] constant(0)
+  ROOT fusion = f32[2,4,4] fusion(dot_lhs, dynamic_slice_input, start_index0, start_index1, start_index2),
+       kind=kCustom, calls=triton_gemm,
+       backend_config={
+         "fusion_backend_config":{
+           "kind":"__triton_gemm","triton_gemm_config":{
+             "block_m":"32","block_n":"32","block_k":"32","split_k":"1",
+             "num_stages":"1","num_warps":"4","num_ctas":"1"}}}
+})";
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(
+      kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
+}
+
+TEST_F(TritonGemmTest,
+       DISABLED_DynamicSliceSingleDimensionIntoReshapeIsSupported) {
+  // This directly tests the targeted use case (b/307922364) of iterating over
+  // layer weights and extracting them with dynamic slice.
+  // The start index(es) for the non-majormost dimension(s) are constant zero(s)
+  // because we don't support dynamic slice on those dimensions.
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+triton_gemm {
+  dot_lhs = f32[2,4] parameter(0)
+  dynamic_slice_input = f32[4,5,2] parameter(1)
+  start_index0 = s32[] parameter(2)
+  start_index1 = s32[] parameter(3)
+  start_index2 = s32[] parameter(4)
+  dynamic_slice = f32[1,5,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1, start_index2),
+                  dynamic_slice_sizes={1,5,2}
+  reshape = f32[5,2] reshape(dynamic_slice)
+  ROOT d = f32[4,5] dot(dot_lhs, reshape),
+          lhs_contracting_dims={0}, rhs_contracting_dims={1}
+}
+
+ENTRY e {
+  dot_lhs = f32[2,4] parameter(0)
+  dynamic_slice_input = f32[4,5,2] parameter(1)
+  start_index0 = s32[] constant(3)
+  start_index1 = s32[] constant(0)
+  start_index2 = s32[] constant(0)
+  ROOT fusion = f32[4,5] fusion(dot_lhs, dynamic_slice_input, start_index0, start_index1, start_index2),
+       kind=kCustom, calls=triton_gemm,
+       backend_config={
+         "fusion_backend_config":{
+           "kind":"__triton_gemm","triton_gemm_config":{
+             "block_m":"32","block_n":"32","block_k":"32","split_k":"1",
+             "num_stages":"1","num_warps":"4","num_ctas":"1"}}}
+})";
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(
+      kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
+}
+
+TEST_F(TritonGemmTest,
+       DISABLED_DoNotFuseConcatenationOfSplitNonContractingDimension) {
+  if (std::holds_alternative<se::RocmComputeCapability>(
+          GpuComputeCapability())) {
+    GTEST_SKIP() << "Not using autotuner on ROCM yet.";
+  }
+  if (!SupportsBF16(GpuComputeCapability())) {
+    GTEST_SKIP() << "BF16 not supported.";
+  }
+  const std::string hlo_text = R"(
+HloModule m
+
+ENTRY e {
+  x = bf16[2,128,10] parameter(0)
+  y = bf16[2,256,10] parameter(1)
+  concat = bf16[2,384,10] concatenate(x, y), dimensions={1}
+  z = bf16[10,20] parameter(2)
+  ROOT d = bf16[2,384,20] dot(concat, z), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+})";
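+  // Because the concatenation is on a split non-contracting dimension
+  // (384 = 128 + 256), it is expected to stay outside the kCustom fusion;
+  // the CHECK lines below verify that the concatenate remains in the entry
+  // computation.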
+
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK: ENTRY
+; CHECK: concatenate
+; CHECK: ROOT
+; CHECK-SAME: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: "block_m"
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_BroadcastOfScalarWorksCorrectly) {
+  constexpr absl::string_view kHloText = R"(
+fusion {
+  p0 = f16[2,18] parameter(0)
+  p1 = f16[256,2] parameter(1)
+  d = f16[18,256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={1}
+  p2 = f16[] parameter(2)
+  p3 = f16[] parameter(3)
+  multiply = f16[] multiply(p2, p3)
+  broadcast = f16[18,256] broadcast(multiply), dimensions={}
+  ROOT multiply.3 = f16[18,256] multiply(d, broadcast)
+}
+ENTRY e {
+  p0 = f16[2,18] parameter(0)
+  p1 = f16[256,2] parameter(1)
+  p2 = f16[] parameter(2)
+  p3 = f16[] parameter(3)
+  ROOT gemm_fusion = f16[18,256]{1,0} fusion(p0, p1, p2, p3), kind=kCustom, calls=fusion, backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"16","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}}}
+})";
+
+  TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(this, kHloText, "fusion", R"(
+  CHECK:      tt.dot
+  CHECK:      arith.mulf %{{.*}}, %{{.*}} : tensor<f16>
+  CHECK:      tt.broadcast %{{.*}} : tensor<1x1xf16> -> tensor<32x32xf16>
+  CHECK:      arith.mulf %{{.*}}, %{{.*}} : tensor<32x32xf16>
+  )"));
+  const se::DeviceDescription dev_info =
+      backend().default_stream_executor()->GetDeviceDescription();
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
+                          ParseAndReturnVerifiedModule(kHloText));
+  const HloFusionInstruction* triton_dot_fusion = Cast<HloFusionInstruction>(
+      hlo_module->entry_computation()->root_instruction());
+  llvm::LLVMContext llvm_ctx;
+  llvm::Module llvm_module("module", llvm_ctx);
+  mlir::MLIRContext mlir_context;
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto gpu_config,
+      triton_dot_fusion->backend_config<GpuBackendConfig>());
+  const FusionBackendConfig& config = gpu_config.fusion_backend_config();
+  auto gemm_config = config.triton_gemm_config();
+  BlockLevelParameters block_level_parameters;
+  block_level_parameters.num_ctas = gemm_config.num_ctas();
+  block_level_parameters.num_warps = gemm_config.num_warps();
+  block_level_parameters.num_stages = gemm_config.num_stages();
+
+  TF_ASSERT_OK(TritonWrapper("test_fn", triton_dot_fusion,
+                             GpuComputeCapability(), dev_info,
+                             block_level_parameters, &llvm_module, mlir_context)
+                   .status());
+}
+
+TEST_F(TritonGemmTest, DISABLED_BinaryOperationWithSmallInputsIsFused) {
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  p0 = s8[7,3] parameter(0)
+  p1 = f32[3,16] parameter(1)
+  p2 = f32[3,16] parameter(2)
+  e = f32[3,16] exponential(p1)
+  a = f32[3,16] add(e, p2)
+  c = f32[7,3] convert(p0)
+  ROOT d = f32[7,16] dot(c, a),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_BinaryOperationWithLargeInputsIsNotFused) {
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[333,1000] parameter(0)
+  p1 = f32[1000,333] parameter(1)
+  p1n = f32[1000,333] negate(p1)
+  p2 = f32[1000,333] parameter(2)
+  p2n = f32[1000,333] 
negate(p2)
+  s = f32[1000,333] subtract(p1n, p2n)
+  c = f32[333,1000] convert(p0)
+  ROOT d = f32[1000,1000] dot(s, c),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fused_subtract
+; CHECK: negate
+; CHECK: negate
+; CHECK: ROOT
+; CHECK-SAME: subtract
+; CHECK: ENTRY
+; CHECK: kLoop
+; CHECK: kCustom
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest,
+       DISABLED_ParametersWithDifferentLayoutsAreSupportedInOneScope) {
+  constexpr absl::string_view kHloText = R"(
+ENTRY e {
+  p0 = s8[5,3] parameter(0)
+  p0c = f16[5,3] convert(p0)
+  p1 = f16[5,7] parameter(1)
+  p2 = f16[7,5] parameter(2)
+  t = f16[5,7] transpose(p2), dimensions={1,0}
+  a = f16[5,7] add(t, p1)
+  ROOT d = f16[3,7] dot(p0c, a),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_BinaryOperationOnLargeParametersIsFused) {
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[1000,111] parameter(0)
+  p1 = f32[111,10000] parameter(1)
+  p2 = f32[111,10000] parameter(2)
+  s = f32[111,10000] subtract(p1, p2)
+  c = f32[1000,111] convert(p0)
+  ROOT d = f32[10000,1000] dot(s, c),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_LinkingLibdeviceTwiceWorks) {
+  constexpr absl::string_view kHloText = R"(
+ENTRY e {
+  p0 = s8[7,3] parameter(0)
+  c0 = f32[7,3] convert(p0)
+  p1 = f32[3,16] parameter(1)
+  e1 = f32[3,16] exponential(p1)
+  d0 = f32[7,16] dot(c0, e1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  p2 = s8[7,3] parameter(2)
+  c2 = f32[7,3] convert(p2)
+  e2 = f32[7,3] exponential(c2)
+  p3 = f32[3,16] parameter(3)
+  d1 = f32[7,16] dot(e2, p3),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT a = f32[7,16] add(d0, d1)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Add(
+                  m::Fusion(m::Parameter(), m::Parameter())
+                      .WithFusionKind(HloInstruction::FusionKind::kCustom),
+                  m::Fusion(m::Parameter(), m::Parameter())
+                      .WithFusionKind(HloInstruction::FusionKind::kCustom))));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_BroadcastOfScalarParameterIsFused) {
+  constexpr absl::string_view kHloText = R"(
+ENTRY e {
+  p0 = f16[64,256] parameter(0)
+  p0c = f32[64,256] convert(p0)
+  p1 = f32[] parameter(1)
+  b = f32[256,128] broadcast(p1), dimensions={}
+  ROOT d = f32[64,128] dot(p0c, b),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+  EXPECT_THAT(
+      
module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_BroadcastOfScalarConstantIsFused) {
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[70,30] parameter(0)
+  p0c = f32[70,30] convert(p0)
+  constant_3663 = f32[] constant(4321)
+  bc0 = f32[30,5] broadcast(constant_3663)
+  ROOT d = f32[70,5] dot(p0c, bc0),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_DoubleBroadcastOfScalarConstantIsHandled) {
+  if (!SupportsBF16(GpuComputeCapability())) {
+    GTEST_SKIP() << "BF16 not supported.";
+  }
+  constexpr absl::string_view kHloText = R"(
+ENTRY e {
+  c = s32[] constant(1)
+  bc1 = s32[21]{0} broadcast(c), dimensions={}
+  p0 = s32[21]{0} parameter(0)
+  cmp = pred[21]{0} compare(bc1, p0), direction=EQ
+  convert.6 = bf16[21]{0} convert(cmp)
+  bc2 = bf16[3,21]{1,0} broadcast(convert.6), dimensions={1}
+  p1 = bf16[21,71]{1,0} parameter(1)
+  ROOT d = bf16[3,71]{1,0} dot(bc2, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_BroadcastOfVectorConstantIsFused) {
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  p0 = s8[60,5] parameter(0)
+  c0 = f16[60,5] convert(p0)
+  cst1 = f16[120] constant({...})
+  r1 = f16[5,120] broadcast(cst1), dimensions={1}
+  ROOT d = f16[60,120] dot(c0, r1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Constant())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_AlwaysFuseScalarConstantAtBroadcastInput) {
+  if (!SupportsBF16(GpuComputeCapability())) {
+    GTEST_SKIP() << "BF16 not supported.";
+  }
+  constexpr absl::string_view kHloText = R"(
+ENTRY e {
+  p0 = bf16[2,3,3]{2,1,0} parameter(0)
+  p1 = bf16[3,2,3]{2,1,0} parameter(1)
+  d = bf16[2,3,3]{2,1,0} dot(p0, p1),
+    lhs_batch_dims={0}, lhs_contracting_dims={2},
+    rhs_batch_dims={1}, rhs_contracting_dims={0}
+  t = bf16[3,2,3]{2,0,1} transpose(d), dimensions={1,0,2}
+  c = bf16[] constant(0.123)
+  b = bf16[3,2,3]{2,1,0} broadcast(c), dimensions={}
+  m = bf16[3,2,3]{2,0,1} multiply(t, b)
+  ROOT tu = (bf16[3,2,3]{2,0,1}, bf16[3,2,3]{2,1,0}) tuple(m, b)
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: gemm_fusion_dot
+; CHECK: dot(
+; CHECK: bf16[] constant(0.123)
+; CHECK: ROOT
+; CHECK: ENTRY
+; CHECK: kCustom
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, 
ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_BroadcastOfVectorParameterIsFused) {
+  constexpr absl::string_view kHloText = R"(
+triton_dot {
+  p0 = f16[75] parameter(0)
+  bc0 = f16[75,67] broadcast(p0), dimensions={0}
+  p1 = f16[92,75] parameter(1)
+  ROOT d = f16[92,67] dot(p1, bc0),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+  p0 = f16[75] parameter(0)
+  p1 = f16[92,75] parameter(1)
+  ROOT _ = f16[92,67] fusion(p0, p1), kind=kCustom, calls=triton_dot,
+    backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config:
+      {"block_m":32,"block_n":64,"block_k":32,
+       "split_k":1,"num_stages":1,"num_warps":1,
+       "num_ctas":1}}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHloText));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_FuseConcatenation) {
+  if (!SupportsBF16(GpuComputeCapability())) {
+    GTEST_SKIP() << "BF16 not supported.";
+  }
+  constexpr absl::string_view kHloText = R"(
+e {
+  p0 = s8[153,1536] parameter(0)
+  p1 = s8[153,128] parameter(1)
+  p2 = s8[153,128] parameter(2)
+  cat = s8[153,1792] concatenate(p0, p1, p2), dimensions={1}
+  cvt = bf16[153,1792] convert(cat)
+  p3 = bf16[16,153] parameter(3)
+  ROOT d = bf16[16,1792] dot(p3, cvt),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
+                           m::Parameter())
+                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2,
+                                                /*arel=*/1e-2}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_MinimumHandlesNaNsOnTheLeft) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  min = f32[5,5] minimum(nans, neg1s)
+  ROOT _ = f32[5,5] dot(p0, min),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion(
+; CHECK-SAME: kind=kCustom
+; CHECK-PTX-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_MinimumHandlesNaNsOnTheRight) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  min = f32[5,5] minimum(neg1s, nans)
+  ROOT _ = f32[5,5] dot(p0, min),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion(
+; CHECK-SAME: kind=kCustom
+; CHECK-PTX-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_MaximumHandlesNaNsOnTheLeft) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  max = f32[5,5] maximum(nans, neg1s)
+  ROOT _ = f32[5,5] dot(p0, max),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion(
+; CHECK-SAME: kind=kCustom
+; CHECK-PTX-SAME: block_m
+)");
+
+  
EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_MaximumHandlesNaNsOnTheRight) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + neg1 = f32[] constant(-1) + neg1s = f32[5,5] broadcast(neg1), dimensions={} + nans = f32[5,5] sqrt(neg1s) + max = f32[5,5] maximum(neg1s, nans) + ROOT _ = f32[5,5] dot(p0, max), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_MinimumReturnsLHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + min = f32[5,5] minimum(zeros, ones) + ROOT _ = f32[5,5] dot(p0, min), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_MinimumReturnsRHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + min = f32[5,5] minimum(ones, zeros) + ROOT _ = f32[5,5] dot(p0, min), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_MaximumReturnsLHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + max = f32[5,5] maximum(ones, zeros) + ROOT _ = f32[5,5] dot(p0, max), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_MaximumReturnsRHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + max = f32[5,5] maximum(zeros, ones) + ROOT _ = f32[5,5] dot(p0, max), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-PTX-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, DISABLED_SineOutputIsNotFused) { + constexpr absl::string_view kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,101] parameter(0) + p1 = f32[101,16] parameter(1) + c = f32[7,101] convert(p0) + d = f32[7,16] dot(c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} 
+ ROOT r = f32[7,16] sine(d)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ GmockMatch(m::Sin(
+ m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom))));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_SliceInputIsFused) {
+ constexpr absl::string_view kHloText = R"(
+ENTRY e {
+ p0 = f16[97,121] parameter(0)
+ s0 = f16[7,101] slice(p0), slice={[3:10], [10:111]}
+ p1 = f32[101,16] parameter(1)
+ c = f32[7,101] convert(s0)
+ ROOT d = f32[16,7] dot(p1, c),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_SliceInputWithReshapeIsFused) {
+ constexpr absl::string_view kHloText = R"(
+ENTRY e {
+ p0 = f32[363,1536] parameter(0)
+ p1 = f32[4,1536,611] parameter(1)
+ s = f32[1,1536,611] slice(p1),
+ slice={[1:2], [0:1536], [0:611]}
+ r = f32[1536,611] reshape(s)
+ ROOT d = f32[363,611] dot(p0, r),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_NestedSlicingWorks) {
+ constexpr absl::string_view kHloText = R"(
+ENTRY e {
+ p1 = f32[6,24] parameter(1)
+ slice1 = f32[5,20] slice(p1), slice={[1:6], [3:23]}
+ n1 = f32[5,20] negate(slice1)
+ slice2 = f32[3,7] slice(n1), slice={[1:4], [13:20]}
+ p0 = f32[7,37] parameter(0)
+ ROOT d = f32[3,37] dot(slice2, p0),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_SlicedBatchDimensionIsSupported) {
+ constexpr absl::string_view kHloText = R"(
+ENTRY e {
+ p0 = f16[3,3,256] parameter(0)
+ s0 = f16[3,3,128] slice(p0), slice={[0:3], [0:3], [123:251]}
+ r0 = f16[3,3,128] reshape(s0)
+ p1 = f16[3,3,256] parameter(1)
+ svar1 = f16[3,3,128] slice(p1), slice={[0:3], [0:3], [30:158]}
+ r1 = f16[3,3,128] reshape(svar1)
+ ROOT d = f16[128,3,3]{2,1,0} dot(r0, r1),
+ lhs_batch_dims={2}, lhs_contracting_dims={1},
+ rhs_batch_dims={2}, rhs_contracting_dims={1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTestWithSplitK,
+ DISABLED_SplitKDoesNotBreakSlicedFragmentedContractingDimension) {
+ constexpr absl::string_view kHloText = R"(
+ENTRY e {
+ p0 = f16[16,8,128]{2,1,0} parameter(0)
+ s0 = f16[16,4,128]{2,1,0} slice(p0),
+ slice={[0:16], [0:4], [0:128]}
+ r0 = f16[16,512]{1,0} reshape(s0)
+ p1 = s8[4096,4,128]{2,1,0} parameter(1)
+ r1 = s8[512,4096]{0,1} reshape(p1)
+ c1 = f16[512,4096]{0,1} convert(r1)
+ ROOT d = f16[16,4096]{1,0} dot(r0, c1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}));
+}
+
+TEST_F(TritonGemmTestWithSplitK, DISABLED_SplitKWithTrivialDimension) {
+ constexpr absl::string_view kHloText = R"(
+ENTRY entry_computation {
+ p0 = f16[1001,1]{1,0} parameter(0)
+ convert = f32[1001,1]{1,0} convert(p0)
+ p1 = f32[1001,2048]{1,0} parameter(1)
+ ROOT dot = f32[1,2048]{1,0} dot(convert, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})";
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_NarrowingConvertOutputIsFused) {
+ constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+ p0 = s8[22,80] parameter(0)
+ p1 = f32[80,54] parameter(1)
+ c = f32[22,80] convert(p0)
+ d = f32[54,22] dot(p1, c),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+ ROOT r = f16[54,22] convert(d)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/3e-2, /*arel=*/3e-2}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_ParameterAfterDotIsFused) {
+ if (!SupportsBF16(GpuComputeCapability())) {
+ GTEST_SKIP() << "BF16 not supported.";
+ }
+ constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+ p0 = bf16[350,1280]{1,0} parameter(0)
+ p1 = s16[1280,690]{0,1} parameter(1)
+ p1c = bf16[1280,690]{0,1} convert(p1)
+ dot.21 = bf16[350,690]{1,0} dot(p0, p1c),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ p2 = bf16[350,690]{1,0} parameter(2)
+ ROOT r = bf16[350,690]{1,0} multiply(p2, dot.21)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ const HloInstruction* instr = module->entry_computation()->root_instruction();
+ if (!instr->IsCustomFusion()) {
+ instr = instr->operand(0);
+ ASSERT_TRUE(instr->IsCustomFusion());
+ }
+ EXPECT_THAT(
+ instr,
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_OutputFusionExecutesCorrectly) {
+ if (!SupportsBF16(GpuComputeCapability())) {
+ GTEST_SKIP() << "BF16 not supported.";
+ }
+ constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+ p0 = f16[350,1280]{1,0} parameter(0)
+ p0c = bf16[350,1280]{1,0} convert(p0)
+ p1 = bf16[1280,690]{0,1} parameter(1)
+ d = bf16[350,690]{1,0} dot(p0c, p1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ p3 = bf16[350,690]{1,0} parameter(3)
+ multiply.8811 = bf16[350,690]{1,0} multiply(d, p3)
+ neg.484 = bf16[350,690]{1,0} negate(multiply.8811)
+ p2 = bf16[350,690]{1,0} parameter(2)
+ ROOT multiply.8808 = bf16[350,690]{1,0} multiply(neg.484, p2)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ const HloInstruction* instr = module->entry_computation()->root_instruction();
+ if (!instr->IsCustomFusion()) {
+ instr = instr->operand(0);
+ ASSERT_TRUE(instr->IsCustomFusion());
+ }
+ EXPECT_THAT(
+ instr,
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
+ m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_SplitLHSOutputTransposeAloneIsNotFused) {
+ if (!SupportsBF16(GpuComputeCapability())) {
+ GTEST_SKIP() << "BF16 not supported.";
+ }
+ constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+ p0 = s8[18,15000] parameter(0)
+ p0c = bf16[18,15000] convert(p0)
+ p1 = bf16[42,18] parameter(1)
+ d = bf16[15000,42] dot(p0c, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+ r1 = bf16[5,200,15,42] reshape(d)
+ ROOT t1 = bf16[5,42,200,15] transpose(r1), dimensions={0,3,1,2}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Bitcast(
+ m::Fusion(m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom))
+ .WithFusionKind(HloInstruction::FusionKind::kInput))));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_SplitLHSInputOutputIsFused) {
+ if (!SupportsBF16(GpuComputeCapability())) {
+ GTEST_SKIP() << "BF16 not supported.";
+ }
+ if (std::holds_alternative<se::RocmComputeCapability>(
+ GpuComputeCapability())) {
+ GTEST_SKIP() << "Skipped until corresponding issue on ROCm is fixed.";
+ }
+
+ constexpr absl::string_view kHloText = R"(
+ENTRY e {
+ p0t = (s8[5,18,20,150]) parameter(0)
+ p0 = s8[5,18,20,150] get-tuple-element(p0t), index=0
+ p0c = bf16[5,18,20,150] convert(p0)
+ t0 = bf16[18,5,20,150] transpose(p0c), dimensions={1,0,2,3}
+ r0 = bf16[18,15000] reshape(t0)
+ p1 = bf16[42,18] parameter(1)
+ d = bf16[15000,42] dot(r0, p1),
+ lhs_contracting_dims={0}, rhs_contracting_dims={1}
+ r1 = bf16[5,20,150,42] reshape(d)
+ ROOT t1 = bf16[5,42,20,150] transpose(r1), dimensions={0,3,1,2}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::GetTupleElement(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_SupportPredParametersUsedInExpressions) {
+ constexpr absl::string_view kHloText = R"(
+ENTRY e {
+ p = pred[2,2]{1,0} parameter(0)
+ a = f32[2,2]{1,0} parameter(1)
+ b = f32[2,2]{1,0} parameter(2)
+ c = f32[2,2]{1,0} parameter(3)
+ compare = pred[2,2]{1,0} compare(a, b), direction=LT
+ and = pred[2,2]{1,0} and(p, compare)
+ convert = f32[2,2]{1,0} convert(and)
+ ROOT r = f32[2,2]{1,0} dot(convert, c),
+ lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(kHloText));
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
+ m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+ EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest, DISABLED_Naming) {
+ const char* hlo_text = R"(
+HloModule t
+
+ENTRY e {
+ p0 = f16[15,19] parameter(0)
+ p1 = s8[19,17] parameter(1)
+ cp1 = f16[19,17] convert(p1)
+ ROOT r = f16[15,17] dot(p0, cp1),
+ lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+ MatchOptimizedHlo(hlo_text, R"(
+; CHECK: %gemm_fusion_r_computation (
+; CHECK: ROOT %gemm_fusion_r
+; CHECK-SAME: kCustom
+)");
+}
+
+TEST_F(TritonGemmTest,
+ DISABLED_LowerDotWithLhsWithoutNonContractingDimThroughTriton) {
+ const std::string hlo_text = R"(
+HloModule t
+
+ENTRY e {
+ parameter_0 = f32[1,40] parameter(0)
+ parameter_1 = f32[1,40,250000] parameter(1)
+ ROOT dot = f32[1,250000] dot(parameter_0, parameter_1), lhs_batch_dims={0},
+ lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(hlo_text));
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmTest,
+ DISABLED_LowerDotWithRhsWithoutNonContractingDimThroughTriton) {
+ const std::string hlo_text = R"(
+HloModule t
+
+ENTRY e {
+ parameter_0 = f32[1,40,250000] parameter(0)
+ parameter_1 = f32[1,40] parameter(1)
+ ROOT dot = f32[1,250000] dot(parameter_0, parameter_1), lhs_batch_dims={0},
+ lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ GetOptimizedModule(hlo_text));
+
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction(),
+ GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
+ .WithFusionKind(HloInstruction::FusionKind::kCustom)));
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+// This group of tests compares GPU results of dots already rewritten
+// into Triton fusions.
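+//
+// In short: each test defines a reference module (typically a cuBLAS
+// custom call, or a Triton fusion with a different tiling) and a test
+// module containing a hand-written Triton fusion for the same dot, runs
+// both with run_hlo_passes=false so the fusions execute exactly as
+// written, and checks that the outputs agree within the given ErrorSpec.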
+using CompareTest = TritonGemmTest; + +TEST_F(CompareTest, DISABLED_DifferentTilingsProduceSameResult) { + const char* hlo_text_ref = R"( +HloModule t + +triton_dot { + p0 = s8[101,202] parameter(0) + p0c = f32[101,202] convert(p0) + p1 = f32[202,303] parameter(1) + ROOT dot = f32[101,303] dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[101,202]{1,0} parameter(0) + p1 = f32[202,303]{1,0} parameter(1) + ROOT _ = f32[101,303] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":3,"num_warps":8, + "num_ctas":1}}} +})"; + + const char* hlo_text_triton = R"( +HloModule t + +triton_dot { + p0 = s8[101,202] parameter(0) + p0c = f32[101,202] convert(p0) + p1 = f32[202,303] parameter(1) + ROOT dot = f32[101,303] dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[101,202]{1,0} parameter(0) + p1 = f32[202,303]{1,0} parameter(1) + ROOT _ = f32[101,303] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":128,"block_k":32, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_F16) { + const char* hlo_text_ref = R"( +HloModule r + +ENTRY e { + arg0 = f16[5,7] parameter(0) + arg1 = f16[7,33] parameter(1) + gemm = (f16[5,33], s8[0]{0}) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f16[5,33]{1,0} get-tuple-element((f16[5,33]{1,0}, s8[0]{0}) gemm), index=0 +} +)"; + + const char* hlo_text_triton = R"( +HloModule t + +triton_dot { + p0 = f16[5,7] parameter(0) + p1 = f16[7,33] parameter(1) + ROOT dot = f16[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f16[5,7]{1,0} parameter(0) + p1 = f16[7,33]{1,0} parameter(1) + ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_F32) { + const char* hlo_text_ref = R"( +HloModule r + +ENTRY e { + arg0 = f32[5,7] parameter(0) + arg1 = f32[7,33] parameter(1) + gemm = (f32[5,33], s8[0]{0}) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[5,33]{1,0} get-tuple-element((f32[5,33]{1,0}, s8[0]{0}) gemm), index=0 +} 
+)"; + + const char* hlo_text_triton = R"( +HloModule t + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_F32WithTrivialNonContractingDimension) { + const char* hlo_text_ref = R"( +HloModule r + +ENTRY e { + arg0 = f32[5,7] parameter(0) + arg1 = f32[1,7] parameter(1) + gemm = (f32[5,1], s8[0]{0}) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[5,1]{1,0} get-tuple-element((f32[5,1]{1,0}, s8[0]{0}) gemm), index=0 +} +)"; + + const char* hlo_text_triton = R"( +HloModule t + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[1,7] parameter(1) + ROOT dot = f32[5,1] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[1,7]{1,0} parameter(1) + ROOT _ = f32[5,1] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_BF16TransposedLHS) { + if (!SupportsBF16(GpuComputeCapability())) { + GTEST_SKIP() << "BF16 not supported."; + } + const char* hlo_text_ref = R"( +HloModule r + +ENTRY e { + arg0 = bf16[512,16]{1,0} parameter(0) + arg1 = bf16[512,256]{1,0} parameter(1) + gemm = (bf16[16,256]{1,0}, s8[0]{0}) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = bf16[16,256]{1,0} get-tuple-element((bf16[16,256]{1,0}, s8[0]{0}) gemm), index=0 +} +)"; + + const char* hlo_text_triton = R"( +HloModule t + +triton_dot { + arg0 = bf16[512,16]{1,0} parameter(0) + arg1 = bf16[512,256]{1,0} parameter(1) + ROOT dot = bf16[16,256]{1,0} dot(arg0, arg1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY e { + arg0 = bf16[512,16]{1,0} parameter(0) + arg1 = bf16[512,256]{1,0} parameter(1) + ROOT _ = bf16[16,256]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":128,"block_n":32,"block_k":16, + "split_k":1,"num_stages":2,"num_warps":4, 
+ "num_ctas":1}}} +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, + ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_UsingOptinSharedMemoryOnAmpereProducesSameResult) { + if (std::holds_alternative( + GpuComputeCapability())) { + GTEST_SKIP() << "No Optin Shared Memory on AMD."; + } + const se::DeviceDescription dev_info = + backend().default_stream_executor()->GetDeviceDescription(); + constexpr int kBytesOfSharedMemoryTested = 64 * 1024; + EXPECT_GE(dev_info.shared_memory_per_block_optin(), + kBytesOfSharedMemoryTested); + + const std::string kHloTextOptinShmem = R"( +HloModule t + +triton_dot { + param_0.1 = s8[332,441]{1,0} parameter(0) + p0c = f16[332,441]{1,0} convert(param_0.1) + param_1.1 = f16[441,39]{1,0} parameter(1) + ROOT dot = f16[332,39]{1,0} dot(p0c, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[332,441]{1,0} parameter(0) + p1 = f16[441,39]{1,0} parameter(1) + ROOT _ = f16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":128,"block_n":128,"block_k":128, + "split_k":1,"num_stages":2,"num_warps":32, + "num_ctas":1}}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTextOptinShmem)); + const HloFusionInstruction* triton_dot_fusion = Cast( + hlo_module->entry_computation()->root_instruction()); + llvm::LLVMContext llvm_ctx; + llvm::Module llvm_module("module", llvm_ctx); + mlir::MLIRContext mlir_context; + + TF_ASSERT_OK_AND_ASSIGN( + auto gpu_config, triton_dot_fusion->backend_config()); + const FusionBackendConfig& config = gpu_config.fusion_backend_config(); + auto gemm_config = config.triton_gemm_config(); + BlockLevelParameters block_level_parameters; + block_level_parameters.num_ctas = gemm_config.num_ctas(); + block_level_parameters.num_warps = gemm_config.num_warps(); + block_level_parameters.num_stages = gemm_config.num_stages(); + TF_ASSERT_OK_AND_ASSIGN( + const auto result, + TritonWrapper("test_fn", triton_dot_fusion, GpuComputeCapability(), + dev_info, block_level_parameters, &llvm_module, + mlir_context)); + // The config is chosen so that the used memory size is slightly above the + // 48 kB boundary of standard / optin shared memory so that any GPU that + // has the optin one should be able to execute the test. + EXPECT_EQ(result.shmem_bytes, kBytesOfSharedMemoryTested); + // Make sure the written config indeed has to use optin shared memory. 
+ EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block()); + + const std::string kHloTextLowShmem = R"( +HloModule t + +triton_dot { + param_0.1 = s8[332,441]{1,0} parameter(0) + p0c = f16[332,441]{1,0} convert(param_0.1) + param_1.1 = f16[441,39]{1,0} parameter(1) + ROOT dot = f16[332,39]{1,0} dot(p0c, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[332,441]{1,0} parameter(0) + p1 = f16[441,39]{1,0} parameter(1) + ROOT _ = f16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextLowShmem, kHloTextOptinShmem, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_F16TransposedRHS) { + const char* hlo_text_ref = R"( +HloModule r + +ENTRY e { + arg0 = f16[128,32]{1,0} parameter(0) + arg1 = f16[64,32]{1,0} parameter(1) + gemm = (f16[128,64]{1,0}, s8[0]{0}) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f16[128,64]{1,0} get-tuple-element((f16[128,64]{1,0}, s8[0]{0}) gemm), index=0 +} +)"; + + const char* hlo_text_triton = R"( +HloModule t + +triton_dot { + arg0 = f16[128,32]{1,0} parameter(0) + arg1 = f16[64,32]{1,0} parameter(1) + ROOT dot = f16[128,64]{1,0} dot(arg0, arg1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + arg0 = f16[128,32]{1,0} parameter(0) + arg1 = f16[64,32]{1,0} parameter(1) + ROOT _ = f16[128,64]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":128,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, + ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_F32TransposedBoth) { + const char* hlo_text_ref = R"( +HloModule r + +ENTRY e { + arg0 = f32[64,128]{1,0} parameter(0) + arg1 = f32[1024,64]{1,0} parameter(1) + gemm = (f32[128,1024]{1,0}, s8[0]{0}) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[128,1024]{1,0} get-tuple-element((f32[128,1024]{1,0}, s8[0]{0}) gemm), index=0 +} +)"; + + const char* hlo_text_triton = R"( +HloModule t + +triton_dot { + arg0 = f32[64,128]{1,0} parameter(0) + arg1 = f32[1024,64]{1,0} parameter(1) + ROOT dot = f32[128,1024]{1,0} dot(arg0, arg1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +} + +ENTRY e { + arg0 = f32[64,128]{1,0} parameter(0) + arg1 = f32[1024,64]{1,0} parameter(1) + ROOT _ = f32[128,1024]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, + 
backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_S8BF16) { + if (!SupportsBF16(GpuComputeCapability())) { + GTEST_SKIP() << "BF16 not supported."; + } + const char* hlo_text_ref = R"( +HloModule r + +fused_computation { + param_0.1 = s8[144,256]{1,0} parameter(0) + ROOT convert.4 = bf16[144,256]{1,0} convert(param_0.1) +} + +ENTRY e { + p0 = s8[144,256]{1,0} parameter(0) + fusion = bf16[144,256]{1,0} fusion(p0), kind=kInput, calls=fused_computation + p1 = bf16[256,122]{1,0} parameter(1) + gemm = (bf16[144,122]{1,0}, s8[0]{0}) custom-call(fusion, p1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = bf16[144,122]{1,0} get-tuple-element((bf16[144,122]{1,0}, s8[0]{0}) gemm), index=0 +} +)"; + + const char* hlo_text_triton = R"( +HloModule t + +triton_dot { + param_0.1 = s8[144,256]{1,0} parameter(0) + p0c = bf16[144,256]{1,0} convert(param_0.1) + param_1.1 = bf16[256,122]{1,0} parameter(1) + ROOT dot = bf16[144,122]{1,0} dot(p0c, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[144,256]{1,0} parameter(0) + p1 = bf16[256,122]{1,0} parameter(1) + ROOT _ = bf16[144,122]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":64,"block_k":64, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, + ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_SplitK) { + if (!SupportsBF16(GpuComputeCapability())) { + GTEST_SKIP() << "BF16 not supported."; + } + const std::string hlo_text_ref = R"( +HloModule t, is_scheduled=true + +triton_gemm_r { + parameter_0 = s8[480,120]{1,0} parameter(0) + convert.3 = bf16[480,120]{1,0} convert(parameter_0) + parameter_1 = bf16[16,120]{1,0} parameter(1) + ROOT r.1 = bf16[480,16]{1,0} dot(convert.3, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p1 = bf16[16,120]{1,0} parameter(1) + p0 = s8[3,120,5,32]{3,2,1,0} parameter(0) + bitcast.4 = s8[480,120]{1,0} bitcast(p0) + ROOT triton_gemm_r = bf16[480,16]{1,0} fusion(bitcast.4, p1), kind=kCustom, + calls=triton_gemm_r, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + "split_k":1,"num_stages":4,"num_warps":4, + "num_ctas":1}}} +})"; + + const std::string hlo_text_splitk = R"( +HloModule t, is_scheduled=true + +triton_gemm_r { + parameter_0 = s8[480,120]{1,0} parameter(0) + convert.3 = bf16[480,120]{1,0} convert(parameter_0) + bitcast.11 = bf16[480,4,30]{2,1,0} bitcast(convert.3) + parameter_1 = bf16[16,120]{1,0} parameter(1) + bitcast.12 = bf16[16,4,30]{2,1,0} bitcast(parameter_1) + ROOT dot.1 = bf16[4,480,16]{2,1,0} dot(bitcast.11, bitcast.12), + lhs_batch_dims={1}, 
lhs_contracting_dims={2}, + rhs_batch_dims={1}, rhs_contracting_dims={2} +} + +add { + rhs.1 = f32[] parameter(1) + lhs.1 = f32[] parameter(0) + ROOT add.1 = f32[] add(lhs.1, rhs.1) +} + +fused_computation { + param_0.2 = bf16[4,480,16]{2,1,0} parameter(0) + convert.18 = f32[4,480,16]{2,1,0} convert(param_0.2) + constant_1 = bf16[] constant(0) + convert.17 = f32[] convert(constant_1) + reduce.1 = f32[480,16]{1,0} reduce(convert.18, convert.17), dimensions={0}, + to_apply=add + ROOT convert.16 = bf16[480,16]{1,0} convert(reduce.1) +} + +ENTRY e { + p1 = bf16[16,120]{1,0} parameter(1) + p0 = s8[3,120,5,32]{3,2,1,0} parameter(0) + bitcast.4 = s8[480,120]{1,0} bitcast(p0) + triton_gemm_r = bf16[4,480,16]{2,1,0} fusion(bitcast.4, p1), kind=kCustom, + calls=triton_gemm_r, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":128, + "split_k":4,"num_stages":1,"num_warps":4, + "num_ctas":1}}} + ROOT fusion.1 = bf16[480,16]{1,0} fusion(triton_gemm_r), kind=kLoop, + calls=fused_computation +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_splitk, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_SplitKBatch) { + if (!SupportsBF16(GpuComputeCapability())) { + GTEST_SKIP() << "BF16 not supported."; + } + const std::string kHloTextRef = R"( +HloModule m, is_scheduled=true + +triton_gemm_dot.24 { + parameter_1 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1) + bitcast.3 = bf16[800,5,128]{2,1,0} bitcast(parameter_1) + convert.3 = f32[800,5,128]{2,1,0} convert(bitcast.3) + parameter_0 = f32[1,5,700,800]{3,2,1,0} parameter(0) + bitcast.2 = f32[5,700,800]{2,1,0} bitcast(parameter_0) + ROOT dot.26 = f32[5,128,700]{2,1,0} dot(convert.3, bitcast.2), lhs_batch_dims={1}, lhs_contracting_dims={0}, rhs_batch_dims={0}, rhs_contracting_dims={2} +} + +ENTRY e { + tmp_3 = f32[1,5,700,800]{3,2,1,0} parameter(0) + tmp_0 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1) + ROOT triton_gemm_dot.24 = f32[5,128,700]{2,1,0} fusion(tmp_3, tmp_0), + kind=kCustom, calls=triton_gemm_dot.24, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":8, + "num_ctas":1}}} +})"; + + const std::string kHloTextSplitK = R"( +HloModule m, is_scheduled=true + +triton_gemm_dot { + parameter_1 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1) + bitcast.3 = bf16[800,5,128]{2,1,0} bitcast(parameter_1) + convert.3 = f32[800,5,128]{2,1,0} convert(bitcast.3) + bitcast = f32[8,100,5,128]{3,2,1,0} bitcast(convert.3) + parameter_0 = f32[1,5,700,800]{3,2,1,0} parameter(0) + bitcast.2 = f32[5,700,800]{2,1,0} bitcast(parameter_0) + bitcast.1 = f32[5,700,8,100]{3,2,1,0} bitcast(bitcast.2) + ROOT dot = f32[8,5,128,700]{3,2,1,0} dot(bitcast, bitcast.1), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={2,0}, rhs_contracting_dims={3} +} + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY e { + tmp_3 = f32[1,5,700,800]{3,2,1,0} parameter(0) + tmp_0 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1) + triton_gemm_dot.24 = f32[8,5,128,700]{3,2,1,0} fusion(tmp_3, tmp_0), + kind=kCustom, calls=triton_gemm_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + "split_k":8,"num_stages":1,"num_warps":4, + "num_ctas":1}}} + constant = f32[] constant(0) + ROOT reduce = 
f32[5,128,700]{2,1,0} reduce(triton_gemm_dot.24, constant), dimensions={0}, to_apply=add +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextSplitK, + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_SplitKNontrivialBitcast) { + if (!SupportsBF16(GpuComputeCapability())) { + GTEST_SKIP() << "BF16 not supported."; + } + const std::string kHloTextRef = R"( +HloModule module, is_scheduled=true + +triton_gemm_dot.5316 { + parameter_1 = bf16[16,4,128]{2,1,0} parameter(1) + bitcast.2 = bf16[16,512]{1,0} bitcast(parameter_1) + parameter_0 = s8[512,96]{1,0} parameter(0) + convert.4 = bf16[512,96]{1,0} convert(parameter_0) + ROOT dot.0 = bf16[16,96]{1,0} dot(bitcast.2, convert.4), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + parameter_0.1 = s8[96,4,128]{2,1,0} parameter(0) + bitcast.6 = s8[512,96]{1,0} bitcast(parameter_0.1) + parameter_1.1 = bf16[16,4,128]{2,1,0} parameter(1) + ROOT triton_gemm_dot.5316 = bf16[16,96]{1,0} fusion(bitcast.6, parameter_1.1), + kind=kCustom, calls=triton_gemm_dot.5316, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":256, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} +})"; + + const std::string kHloTextSplitK = R"( +HloModule module, is_scheduled=true + +triton_gemm_dot.5316 { + parameter_1 = bf16[16,4,128]{2,1,0} parameter(1) + bitcast.2 = bf16[16,512]{1,0} bitcast(parameter_1) + bitcast.17 = bf16[16,16,32]{2,1,0} bitcast(bitcast.2) + parameter_0 = s8[512,96]{1,0} parameter(0) + convert.4 = bf16[512,96]{1,0} convert(parameter_0) + bitcast.18 = bf16[16,32,96]{2,1,0} bitcast(convert.4) + ROOT dot.4 = bf16[16,16,96]{2,1,0} dot(bitcast.17, bitcast.18), + lhs_batch_dims={1}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} +} + +triton_gemm_dot.5316.reduce_sub_computation.clone { + rhs.1 = f32[] parameter(1) + lhs.1 = f32[] parameter(0) + ROOT add.1 = f32[] add(lhs.1, rhs.1) +} + +fused_computation { + param_0.2 = bf16[16,16,96]{2,1,0} parameter(0) + convert.19 = f32[16,16,96]{2,1,0} convert(param_0.2) + constant_1 = bf16[] constant(0) + convert.18 = f32[] convert(constant_1) + reduce.1 = f32[16,96]{1,0} reduce(convert.19, convert.18), + dimensions={0}, to_apply=triton_gemm_dot.5316.reduce_sub_computation.clone + ROOT convert.17 = bf16[16,96]{1,0} convert(reduce.1) +} + +ENTRY entry { + parameter_0.1 = s8[96,4,128]{2,1,0} parameter(0) + bitcast.6 = s8[512,96]{1,0} bitcast(parameter_0.1) + parameter_1.1 = bf16[16,4,128]{2,1,0} parameter(1) + triton_gemm_dot.5316 = bf16[16,16,96]{2,1,0} fusion(bitcast.6, parameter_1.1), + kind=kCustom, calls=triton_gemm_dot.5316, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":32, + "split_k":16,"num_stages":1,"num_warps":4, + "num_ctas":1}}} + ROOT fusion.1 = bf16[16,96]{1,0} fusion(triton_gemm_dot.5316), + kind=kLoop, calls=fused_computation +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextSplitK, + ErrorSpec{/*aabs=*/2, /*arel=*/1e-2}, + /*run_hlo_passes=*/false)); +} + +// This is based on gemm_fusion_test.cc/SplitKTest.SupportsIndivisible. +// +// There were relatively large numeric errors with an f16 temporary buffer, so I +// ended up using --xla_gpu_triton_gemm_disable_reduced_precision_reduction=true +// when generating this test case. 
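+//
+// For reference, the split-K rewrite turns a slice of the contracting
+// dimension K into an extra batch dimension of size split_k that a
+// separate reduction fusion sums over afterwards. In the first test below
+// K=129 is not divisible by split_k=2, so both operands are zero-padded to
+// K=130 first (the `pad ... padding=0_0x0_1` instructions); the extra
+// zero products do not change the result of the dot.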
+TEST_F(CompareTest, DISABLED_SupportsSplitKWithIndivisibleKComplexExample) { + constexpr absl::string_view kHloTextRef = R"( +HloModule extracted, entry_computation_layout={(s8[3,129,5,32]{3,2,1,0}, f16[16,129]{1,0})->f16[480,16]{1,0}} + +triton_gemm_dot.clone { + parameter_0 = s8[3,129,5,32]{3,2,1,0} parameter(0) + bitcast.1 = s8[3,5,32,129]{2,1,3,0} bitcast(parameter_0) + copy.1 = s8[3,5,32,129]{3,2,1,0} copy(bitcast.1) + reshape.5 = s8[480,129]{1,0} reshape(copy.1) + convert.8 = f16[480,129]{1,0} convert(reshape.5) + parameter_1 = f16[16,129]{1,0} parameter(1) + ROOT dot.0 = f16[480,16]{1,0} dot(convert.8, parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY entry_computation { + p0 = s8[3,129,5,32]{3,2,1,0} parameter(0) + p1 = f16[16,129]{1,0} parameter(1) + ROOT fusion = f16[480,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256", + "split_k":"1","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} +} +)"; + + constexpr absl::string_view kHloTextSplitK = R"( +HloModule extracted, entry_computation_layout={(s8[3,129,5,32]{3,2,1,0}, f16[16,129]{1,0})->f16[480,16]{1,0}} + +triton_gemm_dot.clone { + parameter_0 = s8[3,129,5,32]{3,2,1,0} parameter(0) + bitcast.1 = s8[3,5,32,129]{2,1,3,0} bitcast(parameter_0) + copy.1 = s8[3,5,32,129]{3,2,1,0} copy(bitcast.1) + reshape.5 = s8[480,129]{1,0} reshape(copy.1) + convert.8 = f16[480,129]{1,0} convert(reshape.5) + constant = f16[] constant(0) + pad = f16[480,130]{1,0} pad(convert.8, constant), padding=0_0x0_1 + bitcast = f16[480,2,65]{2,1,0} bitcast(pad) + convert.1 = f32[480,2,65]{2,1,0} convert(bitcast) + parameter_1 = f16[16,129]{1,0} parameter(1) + constant.1 = f16[] constant(0) + pad.1 = f16[16,130]{1,0} pad(parameter_1, constant.1), padding=0_0x0_1 + bitcast.2 = f16[16,2,65]{2,1,0} bitcast(pad.1) + convert.2 = f32[16,2,65]{2,1,0} convert(bitcast.2) + ROOT dot.2 = f32[2,480,16]{2,1,0} dot(convert.1, convert.2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2} +} + +fusion.reduce_sub_computation { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0.1 = f32[2,480,16]{2,1,0} parameter(0) + constant.3 = f32[] constant(0) + reduce.1 = f32[480,16]{1,0} reduce(param_0.1, constant.3), dimensions={0}, to_apply=fusion.reduce_sub_computation + ROOT convert.3 = f16[480,16]{1,0} convert(reduce.1) +} + +ENTRY entry_computation { + p0 = s8[3,129,5,32]{3,2,1,0} parameter(0) + p1 = f16[16,129]{1,0} parameter(1) + fusion = f32[2,480,16]{2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"64", + "split_k":"2","num_stages":"1","num_warps":"8", + "num_ctas":"1"}}} + ROOT fusion.1 = f16[480,16]{1,0} fusion(fusion), kind=kLoop, calls=fused_computation +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextSplitK, + ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_SupportsSplitKWithIndivisibleKUsingPaddingEqual1) { + constexpr absl::string_view kHloTextRef = R"( +HloModule extracted, entry_computation_layout={(f16[1,8,4,1023]{3,2,1,0}, f16[1,1023,128]{2,1,0})->f16[1,8,4,128]{3,2,1,0}} + +triton_gemm_dot.7103_computation.clone { + parameter_0.499 = 
f16[1,8,4,1023]{3,2,1,0} parameter(0) + bitcast.7923 = f16[32,1023]{1,0} bitcast(parameter_0.499) + parameter_1.499 = f16[1,1023,128]{2,1,0} parameter(1) + bitcast.7924 = f16[1023,128]{1,0} bitcast(parameter_1.499) + dot.9350 = f16[32,128]{1,0} dot(bitcast.7923, bitcast.7924), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT bitcast.7925 = f16[1,8,4,128]{3,2,1,0} bitcast(dot.9350) +} + +ENTRY entry_computation { + p0 = f16[1,8,4,1023]{3,2,1,0} parameter(0) + p1 = f16[1,1023,128]{2,1,0} parameter(1) + ROOT triton_gemm_dot.7103 = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"32", + "split_k":"1","num_stages":"4","num_warps":"4", + "num_ctas":"1"}}} +} +)"; + + constexpr absl::string_view kHloTextSplitK = R"( +HloModule extracted, entry_computation_layout={(f16[1,8,4,1023]{3,2,1,0}, f16[1,1023,128]{2,1,0})->f16[1,8,4,128]{3,2,1,0}} + +triton_gemm_dot.7103_computation.clone { + parameter_0.499 = f16[1,8,4,1023]{3,2,1,0} parameter(0) + bitcast.7923 = f16[32,1023]{1,0} bitcast(parameter_0.499) + constant = f16[] constant(0) + pad = f16[32,1024]{1,0} pad(bitcast.7923, constant), padding=0_0x0_1 + bitcast = f16[32,8,128]{2,1,0} bitcast(pad) + parameter_1.499 = f16[1,1023,128]{2,1,0} parameter(1) + bitcast.7924 = f16[1023,128]{1,0} bitcast(parameter_1.499) + constant.1 = f16[] constant(0) + pad.1 = f16[1024,128]{1,0} pad(bitcast.7924, constant.1), padding=0_1x0_0 + bitcast.1 = f16[8,128,128]{2,1,0} bitcast(pad.1) + dot.1 = f16[8,32,128]{2,1,0} dot(bitcast, bitcast.1), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} + ROOT bitcast.7925.clone = f16[8,1,8,4,128]{4,3,2,1,0} bitcast(dot.1) +} + +triton_gemm_dot.7103.reduce_sub_computation.clone { + lhs.1 = f32[] parameter(0) + rhs.1 = f32[] parameter(1) + add.2 = f32[] add(lhs.1, rhs.1) + convert.13 = f16[] convert(add.2) + ROOT convert.12 = f32[] convert(convert.13) +} + +fused_computation.1 { + param_0.5 = f16[8,1,8,4,128]{4,3,2,1,0} parameter(0) + convert.16 = f32[8,1,8,4,128]{4,3,2,1,0} convert(param_0.5) + constant.3 = f16[] constant(0) + convert.15 = f32[] convert(constant.3) + reduce.1 = f32[1,8,4,128]{3,2,1,0} reduce(convert.16, convert.15), dimensions={0}, to_apply=triton_gemm_dot.7103.reduce_sub_computation.clone + ROOT convert.14 = f16[1,8,4,128]{3,2,1,0} convert(reduce.1) +} + +ENTRY entry_computation { + p0 = f16[1,8,4,1023]{3,2,1,0} parameter(0) + p1 = f16[1,1023,128]{2,1,0} parameter(1) + triton_gemm_dot.7103 = f16[8,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"16","block_n":"128","block_k":"32", + "split_k":"8","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} + ROOT fusion.1 = f16[1,8,4,128]{3,2,1,0} fusion(triton_gemm_dot.7103), kind=kLoop, calls=fused_computation.1 +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextSplitK, + ErrorSpec{/*aabs=*/4e-2, /*arel=*/2e-2}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_SupportsSplitKWithIndivisibleKUsingPaddingEqual5) { + constexpr absl::string_view kHloTextRef = R"( +HloModule extracted, entry_computation_layout={(f16[1,8,4,1019]{3,2,1,0}, f16[1,1019,128]{2,1,0})->f16[1,8,4,128]{3,2,1,0}} + +triton_gemm_dot.7103_computation.clone { + parameter_0.499 = 
f16[1,8,4,1019]{3,2,1,0} parameter(0) + bitcast.7923 = f16[32,1019]{1,0} bitcast(parameter_0.499) + parameter_1.499 = f16[1,1019,128]{2,1,0} parameter(1) + bitcast.7924 = f16[1019,128]{1,0} bitcast(parameter_1.499) + dot.9350 = f16[32,128]{1,0} dot(bitcast.7923, bitcast.7924), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT bitcast.7925 = f16[1,8,4,128]{3,2,1,0} bitcast(dot.9350) +} + +ENTRY entry_computation { + p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) + p1 = f16[1,1019,128]{2,1,0} parameter(1) + ROOT triton_gemm_dot.7103 = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256", + "split_k":"1","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} +} +)"; + + constexpr absl::string_view kHloTextSplitK = R"( +HloModule extracted, entry_computation_layout={(f16[1,8,4,1019]{3,2,1,0}, f16[1,1019,128]{2,1,0})->f16[1,8,4,128]{3,2,1,0}} + +triton_gemm_dot.7103_computation.clone { + parameter_0.499 = f16[1,8,4,1019]{3,2,1,0} parameter(0) + bitcast.7923 = f16[32,1019]{1,0} bitcast(parameter_0.499) + constant = f16[] constant(0) + pad = f16[32,1024]{1,0} pad(bitcast.7923, constant), padding=0_0x0_5 + bitcast = f16[32,16,64]{2,1,0} bitcast(pad) + parameter_1.499 = f16[1,1019,128]{2,1,0} parameter(1) + bitcast.7924 = f16[1019,128]{1,0} bitcast(parameter_1.499) + constant.1 = f16[] constant(0) + pad.1 = f16[1024,128]{1,0} pad(bitcast.7924, constant.1), padding=0_5x0_0 + bitcast.1 = f16[16,64,128]{2,1,0} bitcast(pad.1) + dot.1 = f16[16,32,128]{2,1,0} dot(bitcast, bitcast.1), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} + ROOT bitcast.7925.clone = f16[16,1,8,4,128]{4,3,2,1,0} bitcast(dot.1) +} + +triton_gemm_dot.7103.reduce_sub_computation.clone { + lhs.1 = f32[] parameter(0) + rhs.1 = f32[] parameter(1) + add.2 = f32[] add(lhs.1, rhs.1) + convert.13 = f16[] convert(add.2) + ROOT convert.12 = f32[] convert(convert.13) +} + +fused_computation.1 { + param_0.5 = f16[16,1,8,4,128]{4,3,2,1,0} parameter(0) + convert.16 = f32[16,1,8,4,128]{4,3,2,1,0} convert(param_0.5) + constant.3 = f16[] constant(0) + convert.15 = f32[] convert(constant.3) + reduce.1 = f32[1,8,4,128]{3,2,1,0} reduce(convert.16, convert.15), dimensions={0}, to_apply=triton_gemm_dot.7103.reduce_sub_computation.clone + ROOT convert.14 = f16[1,8,4,128]{3,2,1,0} convert(reduce.1) +} + +ENTRY entry_computation { + p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) + p1 = f16[1,1019,128]{2,1,0} parameter(1) + triton_gemm_dot.7103 = f16[16,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"64","block_n":"32","block_k":"32", + "split_k":"16","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} + ROOT fusion.1 = f16[1,8,4,128]{3,2,1,0} fusion(triton_gemm_dot.7103), kind=kLoop, calls=fused_computation.1 +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextSplitK, + ErrorSpec{/*aabs=*/4e-2, /*arel=*/2e-2}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_NonMajorMostOutputBatchWorksCorrectly) { + const std::string kHloTextTest = R"( +HloModule m + +triton_gemm_dot.6 { + parameter_1 = f32[32,50,104]{2,1,0} parameter(1) + parameter_0 = s8[32,26,104]{2,1,0} parameter(0) + convert.22 = f32[32,26,104]{2,1,0} convert(parameter_0) + ROOT dot.127 = 
f32[32,50,26]{2,0,1} dot(parameter_1, convert.22), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} +} + +ENTRY e { + p0 = s8[32,26,104]{2,1,0} parameter(0) + p1 = f32[32,50,104]{2,1,0} parameter(1) + ROOT triton_gemm_dot.6 = f32[32,50,26]{2,0,1} fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot.6, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":16,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} +})"; + + const std::string kHloTextRef = R"( +HloModule m + +%triton_gemm_dot.127 { + %parameter_1.1 = f32[32,50,104]{2,1,0} parameter(1) + %parameter_0.1 = s8[32,26,104]{2,1,0} parameter(0) + %convert.0 = f32[32,26,104]{2,1,0} convert(%parameter_0.1) + ROOT %dot.0 = f32[32,50,26]{2,1,0} dot(%parameter_1.1, %convert.0), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} +} + +%fused_computation { + %param_0.1 = f32[32,50,26]{2,1,0} parameter(0) + %transpose.1 = f32[50,32,26]{2,1,0} transpose(%param_0.1), dimensions={1,0,2} + ROOT %bitcast.7 = f32[32,50,26]{2,0,1} bitcast(%transpose.1) +} + +ENTRY e { + %parameter_0 = s8[32,26,104]{2,1,0} parameter(0) + %parameter_1 = f32[32,50,104]{2,1,0} parameter(1) + %triton_gemm_dot.127 = f32[32,50,26]{2,1,0} fusion(%parameter_0, %parameter_1), + kind=kCustom, calls=%triton_gemm_dot.127, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":128,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} + ROOT %fusion.1 = f32[32,50,26]{2,0,1} fusion(%triton_gemm_dot.127), kind=kLoop, calls=%fused_computation +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_TritonDotFusionCanHaveOnlyRHSParameter) { + const std::string kHloTextTest = R"( +HloModule m, is_scheduled=true + +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + c = f16[] constant(321) + b = f16[11,63] broadcast(c) + cc = f32[11,63] convert(b) + ROOT _.1 = f32[63,92]{1,0} dot(cc, parameter_0), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f32[92,11]{1,0} parameter(0) + ROOT triton_gemm__ = f32[63,92]{1,0} fusion(p0), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"16","block_n":"64", + "block_k":"16","split_k":"1", + "num_stages":"3","num_warps":"2", + "num_ctas":"1"}}} +})"; + + const std::string kHloTextRef = R"( +HloModule m, is_scheduled=true + +ENTRY e { + constant_2 = f32[] constant(321) + parameter_0 = f32[92,11]{1,0} parameter(0) + broadcast.2 = f32[11,63]{1,0} broadcast(constant_2), dimensions={} + gemm = (f32[63,92]{1,0}, s8[0]{0}) custom-call(broadcast.2, parameter_0), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[63,92]{1,0} get-tuple-element((f32[63,92]{1,0}, s8[0]{0}) gemm), index=0 +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, + ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}, + 
/*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_TritonDotFusionCanHaveNoParametersAtAll) { + const std::string kHloTextTest = R"( +HloModule m, is_scheduled=true + +triton_gemm___computation { + c = f32[] constant(7) + b = f32[11,61] broadcast(c) + c2 = f32[] constant(5) + b2 = f32[61,45] broadcast(c2) + ROOT _.1 = f32[11,45]{1,0} dot(b, b2), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + ROOT triton_gemm__ = f32[11,45]{1,0} fusion(), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"16","block_n":"64", + "block_k":"16","split_k":"1", + "num_stages":"3","num_warps":"2", + "num_ctas":"1"}}} +})"; + + const std::string kHloTextRef = R"( +HloModule m, is_scheduled=true + +ENTRY triton_gemm___computation { + constant_1 = f32[] constant(7) + constant = f32[] constant(5) + broadcast = f32[11,61]{1,0} broadcast(constant), dimensions={} + broadcast.1 = f32[61,45]{1,0} broadcast(constant_1), dimensions={} + gemm = (f32[11,45]{1,0}, s8[0]{0}) custom-call(broadcast, broadcast.1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[11,45]{1,0} get-tuple-element((f32[11,45]{1,0}, s8[0]{0}) gemm), index=0 +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_TritonDotFusionCanHaveManyParameters) { + const std::string kHloTextTest = R"( +HloModule m + +triton_gemm_dot_computation { + tmp_1 = pred[3,32]{1,0} parameter(0) + tmp_2 = f32[3,32]{1,0} parameter(1) + tmp_3 = f32[3,32]{1,0} parameter(2) + tmp_4 = f32[3,32]{1,0} select(tmp_1, tmp_2, tmp_3) + tmp_5 = f32[3,32]{1,0} parameter(3) + tmp_6 = f32[3,32]{1,0} multiply(tmp_4, tmp_5) + tmp_7 = f32[3,32]{1,0} parameter(4) + tmp_8 = f32[3,32]{1,0} maximum(tmp_6, tmp_7) + tmp_9 = f32[3,57]{1,0} parameter(9) + tmp_10 = f32[3,57]{1,0} parameter(10) + tmp_11 = f32[3,57]{1,0} multiply(tmp_9, tmp_10) + tmp_12 = f32[3,57]{1,0} parameter(11) + tmp_13 = f32[3,57]{1,0} add(tmp_11, tmp_12) + tmp_14 = pred[3,57]{1,0} parameter(5) + tmp_15 = f32[3,57]{1,0} parameter(6) + tmp_16 = f32[3,57]{1,0} parameter(7) + tmp_17 = f32[3,57]{1,0} select(tmp_14, tmp_15, tmp_16) + tmp_18 = f32[3,57]{1,0} parameter(8) + tmp_19 = f32[3,57]{1,0} multiply(tmp_17, tmp_18) + tmp_20 = f32[3,57]{1,0} negate(tmp_19) + tmp_21 = f32[3,57]{1,0} add(tmp_13, tmp_20) + const_1 = f32[] constant(-3e-3) + const_2 = f32[] constant(3e-2) + broadcast_1 = f32[3,57]{1,0} broadcast(const_1), dimensions={} + broadcast_2 = f32[3,57]{1,0} broadcast(const_2), dimensions={} + tmp_22 = f32[3,57]{1,0} clamp(broadcast_1, tmp_21, broadcast_2) + ROOT tmp_23 = f32[32,57]{0,1} dot(tmp_8, tmp_22), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY e { + tmp_1 = pred[3,32]{1,0} parameter(0) + tmp_2 = f32[3,32]{1,0} parameter(1) + tmp_3 = f32[3,32]{1,0} parameter(2) + tmp_5 = f32[3,32]{1,0} parameter(3) + tmp_7 = f32[3,32]{1,0} parameter(4) + tmp_14 = pred[3,57]{1,0} parameter(5) + tmp_15 = f32[3,57]{1,0} parameter(6) + tmp_16 = f32[3,57]{1,0} parameter(7) + tmp_18 = f32[3,57]{1,0} parameter(8) + tmp_9 = f32[3,57]{1,0} 
parameter(9) + tmp_10 = f32[3,57]{1,0} parameter(10) + tmp_12 = f32[3,57]{1,0} parameter(11) + ROOT r = f32[32,57]{0,1} fusion(tmp_1, tmp_2, tmp_3, tmp_5, tmp_7, tmp_14, tmp_15, tmp_16, tmp_18, tmp_9, tmp_10, tmp_12), kind=kCustom, + calls=triton_gemm_dot_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"64","block_n":"64", + "block_k":"64","split_k":"1", + "num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} +})"; + + const std::string kHloTextRef = R"( +HloModule m + +fused_computation { + param_5.1 = f32[3,57]{1,0} parameter(5) + param_6 = f32[3,57]{1,0} parameter(6) + multiply.4 = f32[3,57]{1,0} multiply(param_5.1, param_6) + param_4.2 = f32[3,57]{1,0} parameter(4) + add.3 = f32[3,57]{1,0} add(multiply.4, param_4.2) + param_1.4 = pred[3,57]{1,0} parameter(1) + param_2.2 = f32[3,57]{1,0} parameter(2) + param_3.1 = f32[3,57]{1,0} parameter(3) + select.2 = f32[3,57]{1,0} select(param_1.4, param_2.2, param_3.1) + param_0.1 = f32[3,57]{1,0} parameter(0) + multiply.3 = f32[3,57]{1,0} multiply(select.2, param_0.1) + negate.1 = f32[3,57]{1,0} negate(multiply.3) + add.2 = f32[3,57]{1,0} add(add.3, negate.1) + const.1 = f32[] constant(-3e-3) + const.2 = f32[] constant(3e-2) + broadcast.1 = f32[3,57]{1,0} broadcast(const.1), dimensions={} + broadcast.2 = f32[3,57]{1,0} broadcast(const.2), dimensions={} + ROOT clamp = f32[3,57]{1,0} clamp(broadcast.1, add.2, broadcast.2) +} + +fused_computation.1 { + param_2.4 = pred[3,32]{1,0} parameter(2) + param_3.2 = f32[3,32]{1,0} parameter(3) + param_4.3 = f32[3,32]{1,0} parameter(4) + select.3 = f32[3,32]{1,0} select(param_2.4, param_3.2, param_4.3) + param_1.7 = f32[3,32]{1,0} parameter(1) + multiply.5 = f32[3,32]{1,0} multiply(select.3, param_1.7) + param_0.3 = f32[3,32]{1,0} parameter(0) + ROOT maximum.1 = f32[3,32]{1,0} maximum(multiply.5, param_0.3) +} + +ENTRY e { + tmp_18 = f32[3,57]{1,0} parameter(8) + tmp_16 = f32[3,57]{1,0} parameter(7) + tmp_15 = f32[3,57]{1,0} parameter(6) + tmp_14 = pred[3,57]{1,0} parameter(5) + tmp_12 = f32[3,57]{1,0} parameter(11) + tmp_10 = f32[3,57]{1,0} parameter(10) + tmp_9 = f32[3,57]{1,0} parameter(9) + tmp_7 = f32[3,32]{1,0} parameter(4) + tmp_5 = f32[3,32]{1,0} parameter(3) + tmp_3 = f32[3,32]{1,0} parameter(2) + tmp_2 = f32[3,32]{1,0} parameter(1) + tmp_1 = pred[3,32]{1,0} parameter(0) + fusion.1 = f32[3,32]{1,0} fusion(tmp_7, tmp_5, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.1 + fusion = f32[3,57]{1,0} fusion(tmp_18, tmp_14, tmp_15, tmp_16, tmp_12, /*index=5*/tmp_9, tmp_10), kind=kLoop, calls=fused_computation + gemm = (f32[32,57]{0,1}, s8[0]{0}) custom-call(fusion.1, fusion), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[32,57]{0,1} get-tuple-element((f32[32,57]{0,1}, s8[0]{0}) gemm), index=0 +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, + ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_PredToBF16ConversionWorks) { + if (!SupportsBF16(GpuComputeCapability())) { + GTEST_SKIP() << "BF16 not supported."; + } + const std::string kHloTextTest = R"( +HloModule m, is_scheduled=true + +triton_gemm_computation { + parameter_0 = 
bf16[92,11]{1,0} parameter(0) + parameter_1 = s32[11,63]{1,0} parameter(1) + parameter_2 = s32[11,63]{1,0} parameter(2) + f1.1 = pred[11,63]{1,0} compare(parameter_1, parameter_2), direction=GE + c.1 = bf16[11,63]{1,0} convert(f1.1) + ROOT _.1 = bf16[92,63]{1,0} dot(parameter_0, c.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = bf16[92,11]{1,0} parameter(0) + p1 = s32[11,63]{1,0} parameter(1) + p2 = s32[11,63]{1,0} parameter(2) + ROOT triton_gemm__ = bf16[92,63]{1,0} fusion(p0, p1, p2), kind=kCustom, + calls=triton_gemm_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"32","block_n":"16", + "block_k":"32","split_k":"1", + "num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} +})"; + + const std::string kHloTextRef = R"( +HloModule m, is_scheduled=true + +fused_computation { + p0 = s32[11,63]{1,0} parameter(0) + p1 = s32[11,63]{1,0} parameter(1) + f.1 = pred[11,63]{1,0} compare(p0, p1), direction=GE + ROOT convert.1 = bf16[11,63]{1,0} convert(f.1) +} + +ENTRY e { + p2 = s32[11,63]{1,0} parameter(2) + p1 = s32[11,63]{1,0} parameter(1) + p0 = bf16[92,11]{1,0} parameter(0) + fusion = bf16[11,63]{1,0} fusion(p1, p2), kind=kLoop, calls=fused_computation + gemm = (bf16[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers": + {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, + "alpha_imag":0,"precision_config": + {"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = bf16[92,63]{1,0} get-tuple-element((bf16[92,63]{1,0}, s8[0]{0}) gemm), index=0 +})"; + + EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); +} + +TEST_F(CompareTest, DISABLED_DifferentLayoutsAreSupportedInOneScope) { + const std::string kHloTextTest = R"( +triton_dot { + p1 = f16[3,3,2,16]{1,3,2,0} parameter(1) + cvt1 = f32[3,3,2,16]{1,3,2,0} convert(p1) + p0 = f16[9,32]{0,1} parameter(0) + b0 = f16[3,3,2,16]{1,0,3,2} bitcast(p0) + cp0b0 = f16[2,16,3,3]{3,2,1,0} bitcast(b0) + cp0t0 = f16[3,2,16,3]{3,2,1,0} transpose(cp0b0), dimensions={2,0,1,3} + cp0b1 = f16[3,3,2,16]{1,3,2,0} bitcast(cp0t0) + cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0b1) + m = f32[3,3,2,16]{1,3,2,0} multiply(cvt1, cvt0) + cvt2 = f16[3,3,2,16]{1,3,2,0} convert(m) + cp1b0 = f16[3,2,16,3]{3,2,1,0} bitcast(cvt2) + cp1t0 = f16[3,3,2,16]{3,2,1,0} transpose(cp1b0), dimensions={0,3,1,2} + b1 = f16[9,32]{1,0} bitcast(cp1t0) + p2 = f16[32,32]{1,0} parameter(2) + ROOT r = f16[9,32]{1,0} dot(b1, p2), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f16[9,32]{0,1} parameter(0) + p1 = f16[3,3,2,16]{1,3,2,0} parameter(1) + p2 = f16[32,32]{1,0} parameter(2) + ROOT r = f16[9,32]{1,0} fusion(p0, p1, p2), + kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":"1"}}} +})"; + + const std::string kHloTextRef = R"( +ENTRY e { + p1 = f16[3,3,2,16]{1,3,2,0} parameter(1) + cvt1 = f32[3,3,2,16]{1,3,2,0} convert(p1) + p0 = f16[9,32]{0,1} parameter(0) + b0 = f16[3,3,2,16]{1,0,3,2} bitcast(p0) + cp0b0 = f16[2,16,3,3]{3,2,1,0} bitcast(b0) + cp0t0 = f16[3,2,16,3]{3,2,1,0} transpose(cp0b0), 
dimensions={2,0,1,3}
+  cp0b1 = f16[3,3,2,16]{1,3,2,0} bitcast(cp0t0)
+  cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0b1)
+  m = f32[3,3,2,16]{1,3,2,0} multiply(cvt1, cvt0)
+  cvt2 = f16[3,3,2,16]{1,3,2,0} convert(m)
+  cp1b0 = f16[3,2,16,3]{3,2,1,0} bitcast(cvt2)
+  cp1t0 = f16[3,3,2,16]{3,2,1,0} transpose(cp1b0), dimensions={0,3,1,2}
+  b1 = f16[9,32]{1,0} bitcast(cp1t0)
+  p2 = f16[32,32]{1,0} parameter(2)
+  ROOT r = f16[9,32]{1,0} dot(b1, p2),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest,
+                                      ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4},
+                                      /*run_hlo_passes=*/false));
+}
+
+class TritonGemmContractionDims : public TritonGemmTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() const override {
+    DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_ensure_minor_dot_contraction_dims(true);
+    return debug_options;
+  }
+};
+
+TEST_F(TritonGemmContractionDims, DISABLED_TritonDotForceContractionDims_1_0) {
+  if (!SupportsBF16(GpuComputeCapability())) {
+    GTEST_SKIP() << "BF16 not supported.";
+  }
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  parameter.0 = bf16[16,40]{1,0} parameter(0)
+  parameter.1 = bf16[40,32]{1,0} parameter(1)
+  ROOT dot.31472 = bf16[16,32]{1,0} dot(parameter.0, parameter.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+
+  EXPECT_THAT(module->entry_computation()
+                  ->root_instruction()
+                  ->fused_instructions_computation()
+                  ->root_instruction(),
+              GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 40}, {1, 0}),
+                                m::Op().WithShape(BF16, {40, 32}, {0, 1}))
+                             .WithShape(BF16, {16, 32}, {1, 0})));
+}
+
+TEST_F(TritonGemmContractionDims,
+       DISABLED_TritonDotForceContractionDims_1_2_1_2) {
+  if (!SupportsBF16(GpuComputeCapability())) {
+    GTEST_SKIP() << "BF16 not supported.";
+  }
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  parameter_0 = bf16[32,4,36]{2,1,0} parameter(0)
+  parameter_1 = bf16[40,4,36]{2,1,0} parameter(1)
+  ROOT dot.16450 = bf16[4,32,40]{2,1,0} dot(parameter_0, parameter_1),
+    lhs_batch_dims={1}, lhs_contracting_dims={2},
+    rhs_batch_dims={1}, rhs_contracting_dims={2}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+
+  // The contracting dims were already minor, so the layout is unchanged
+  // (non-major batch dims are fine).
+  EXPECT_THAT(module->entry_computation()
+                  ->root_instruction()
+                  ->fused_instructions_computation()
+                  ->root_instruction(),
+              GmockMatch(m::Dot(m::Op().WithShape(BF16, {32, 4, 36}, {2, 1, 0}),
+                                m::Op().WithShape(BF16, {40, 4, 36}, {2, 1, 0}))
+                             .WithShape(BF16, {4, 32, 40}, {2, 1, 0})));
+}
+
+TEST_F(TritonGemmContractionDims,
+       DISABLED_TritonDotForceContractionDims_1_2_0_1) {
+  if (!SupportsBF16(GpuComputeCapability())) {
+    GTEST_SKIP() << "BF16 not supported.";
+  }
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  parameter_1 = bf16[16,16,48]{2,1,0} parameter(1)
+  parameter_2 = bf16[16,48,32]{2,1,0} parameter(0)
+  ROOT dot.16125 = bf16[16,16,32]{2,1,0} dot(parameter_1, parameter_2),
+    lhs_batch_dims={1}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+
+  // lhs already has minor contracting dims, so its layout is unchanged.
+  // rhs changes layout to have minor contracting dims.
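+  // Concretely, in the rhs matcher below, layout {1, 2, 0} puts dimension 1
+  // (the contracting dimension of size 48) minor-most, while the lhs keeps
+  // the default {2, 1, 0} layout because its contracting dimension 2 is
+  // already minor.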
+  EXPECT_THAT(
+      module->entry_computation()
+          ->root_instruction()
+          ->fused_instructions_computation()
+          ->root_instruction(),
+      GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 16, 48}, {2, 1, 0}),
+                        m::Op().WithShape(BF16, {16, 48, 32}, {1, 2, 0}))
+                     .WithShape(BF16, {16, 16, 32}, {2, 1, 0})));
+}
+
+TEST_F(TritonGemmContractionDims, DISABLED_TritonDotForceContractionDims_1_1) {
+  if (!SupportsBF16(GpuComputeCapability())) {
+    GTEST_SKIP() << "BF16 not supported.";
+  }
+  constexpr absl::string_view kHloText = R"(
+HloModule m
+
+ENTRY e {
+  parameter_0 = bf16[16,32]{1,0} parameter(0)
+  parameter_1 = bf16[40,32]{0,1} parameter(1)
+  ROOT dot.15148 = bf16[16,40]{1,0} dot(parameter_0, parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(kHloText));
+  EXPECT_THAT(module->entry_computation()
+                  ->root_instruction()
+                  ->fused_instructions_computation()
+                  ->root_instruction(),
+              GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 32}, {1, 0}),
+                                m::Op().WithShape(BF16, {32, 40}, {1, 0}))
+                             .WithShape(BF16, {16, 40}, {1, 0})));
+}
+
+// This test could be modified to allow TF32 once this bug is fixed.
+// TODO(b/320659359) Allow TF32 for 8-bit or less types with F32.
+TEST_F(TritonTest, DISABLED_NoTF32For8BitOrLessWithF32) {
+  const std::string hlo_text = R"(
+HloModule t
+
+triton_dot {
+  parameter_0 = s32[11,24]{1,0} parameter(0)
+  broadcast.1747 = s32[11,24,128]{2,1,0} broadcast(parameter_0),
+    dimensions={0,1}
+  parameter_1 = s32[11,24,128]{2,1,0} parameter(1)
+  compare.49 = pred[11,24,128]{2,1,0} compare(broadcast.1747, parameter_1),
+    direction=EQ
+  bitcast.4717 = pred[264,128]{1,0} bitcast(compare.49)
+  convert.142 = f32[264,128]{1,0} convert(bitcast.4717)
+  parameter_2 = f32[128,8]{1,0} parameter(2)
+  ROOT dot.381 = f32[264,8]{1,0} dot(convert.142, parameter_2),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+  p0 = s32[11,24]{1,0} parameter(0)
+  p1 = s32[11,24,128]{2,1,0} parameter(1)
+  p2 = f32[128,8]{1,0} parameter(2)
+  ROOT _ = f32[264,8] fusion(p0, p1, p2), kind=kCustom, calls=triton_dot,
+    backend_config={"fusion_backend_config": {kind: "__triton_gemm",
+      triton_gemm_config:
+        {"block_m":32,"block_n":16,"block_k":128,
+         "split_k":1,"num_stages":1,"num_warps":4,
+         "num_ctas":1}}}
+})";
+
+  TF_ASSERT_OK(
+      CreateTritonIrAndFileCheckForDot(this, hlo_text, "triton_dot", R"(
+CHECK: tt.dot
+CHECK-NOT: inputPrecision = tf32
+  )"));
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonTest, DISABLED_Fp8LoweringIsSupportedPostHopper) {
+  if (!GetCudaComputeCapability().IsAtLeastHopper()) {
+    GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs.";
+  }
+  const std::string hlo_text = R"(
+HloModule t
+
+triton_dot {
+  parameter_0 = f8e4m3fn[1600,1600]{1,0} parameter(0)
+  parameter_1 = f8e4m3fn[1600,1600]{1,0} parameter(1)
+  transpose = f8e4m3fn[1600,1600]{0,1} transpose(parameter_1), dimensions={1,0}
+  ROOT dot = f16[1600,1600]{1,0} dot(parameter_0, transpose),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY main {
+  parameter_1 = f8e4m3fn[1600,1600]{1,0} parameter(1)
+  parameter_0 = f8e4m3fn[1600,1600]{1,0} parameter(0)
+  ROOT gemm_fusion_dot = f16[1600,1600]{1,0} fusion(parameter_0, parameter_1),
+    kind=kCustom, calls=triton_dot,
+    backend_config={
+      "fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":
+        {"block_m":"128","block_n":"32","block_k":"64","split_k":"1",
+
"num_stages":"4","num_warps":"4","num_ctas":"1"}}} +})"; + + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, hlo_text, "triton_dot", R"( +CHECK: tt.dot {{.*}}{maxNumImpreciseAcc = 2147483647 : i32} : tensor<128x64xf8E4M3FN> * tensor<64x32xf8E4M3FN> -> tensor<128x32xf32> + )")); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); +} + +TEST_F(TritonTest, DISABLED_BF16ToFP8EndToEnd) { + if (!GetCudaComputeCapability().IsAtLeastHopper()) { + GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; + } + + const std::string hlo_text = R"( +HloModule t + +triton_dot { + parameter_0 = bf16[32,32]{1,0} parameter(0) + parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) + convert = f8e4m3fn[32,32]{1,0} convert(parameter_0) + ROOT dot = f32[32,32]{1,0} dot(convert, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY main { + parameter_0 = bf16[32,32]{1,0} parameter(0) + parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) + ROOT gemm_fusion_dot = f32[32,32]{1,0} fusion(parameter_0, parameter_1), + kind=kCustom, calls=triton_dot, + backend_config={ + "fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config": + {"block_m":"32","block_n":"32","block_k":"32","split_k":"1", + "num_stages":"1","num_warps":"4","num_ctas":"1"}}} +})"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); +} + +TEST_F(TritonTest, DISABLED_FP8ToFP8EndToEnd) { + if (!GetCudaComputeCapability().IsAtLeastHopper()) { + GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; + } + + const std::string hlo_text = R"( +HloModule t + +triton_dot { + parameter_0 = f8e5m2[32,32]{1,0} parameter(0) + parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) + convert = f8e4m3fn[32,32]{1,0} convert(parameter_0) + ROOT dot = f32[32,32]{1,0} dot(convert, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY main { + parameter_0 = f8e5m2[32,32]{1,0} parameter(0) + parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) + ROOT gemm_fusion_dot = f32[32,32]{1,0} fusion(parameter_0, parameter_1), + kind=kCustom, calls=triton_dot, + backend_config={ + "fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config": + {"block_m":"32","block_n":"32","block_k":"32","split_k":"1", + "num_stages":"1","num_warps":"4","num_ctas":"1"}}} +})"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); +} + +// Test PreventMmaV3LoopUnrolling pass in order to keep compile time low. +// See b/344841434. 
+TEST_F(TritonGemmTest, DISABLED_TestPreventMMAV3LoopUnrolling) { + if (GetCudaComputeCapability().major != se::CudaComputeCapability::kHopper) { + GTEST_SKIP() << "wgmma instruction is only available on Hopper"; + } + const std::string hlo_text = R"( +gemm_fusion_dot { + %p0 = f16[64,1024]{1,0} parameter(0) + %p1 = f16[1024,32,32]{2,1,0} parameter(1) + %bitcast.74246 = f16[1024,1024]{0,1} bitcast(f16[1024,32,32]{2,1,0} %p1) + ROOT %dot.1302 = f16[64,1024]{1,0} dot(f16[64,1024]{1,0} %p0, f16[1024,1024]{0,1} %bitcast.74246), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"} +} + +ENTRY e { + p0 = f16[64,1024]{1,0} parameter(0) + p1 = f16[1024,32,32]{2,1,0} parameter(1) + ROOT triton_gemm_fusion_dot = f16[64,1024]{1,0} fusion(p0, p1), kind=kCustom, + calls=gemm_fusion_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":64,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( +CHECK: $L__BB0_1: +CHECK-NEXT: // begin inline asm +CHECK-NEXT: .pragma "nounroll"; +CHECK: wgmma +)"); +} + +TEST_F(TritonGemmTest, DISABLED_WgmmaIsUsedForMemBoundShape) { + if (GetCudaComputeCapability().major != se::CudaComputeCapability::kHopper) { + GTEST_SKIP() << "wgmma instruction is only available on Hopper"; + } + const std::string hlo_text = R"( +gemm_fusion_dot { + p0 = s8[128,128]{1,0} parameter(0) + p1 = bf16[128,16]{1,0} parameter(1) + convert = bf16[128,128]{1,0} convert(p0) + ROOT %dot = bf16[128,16]{1,0} dot(convert, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[128,128]{1,0} parameter(0) + p1 = bf16[128,16]{1,0} parameter(1) + ROOT triton_gemm_fusion_dot = bf16[128,16]{1,0} fusion(p0, p1), kind=kCustom, + calls=gemm_fusion_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":128,"block_n":16,"block_k":16, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndOptionallyVerifyPtx(std::move(verified_module), R"( +CHECK: wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 +)"); +} + +// Test presence of default matmul config information +// when gemm autotuner is not present in pipeline, +// (which is currently the case on rocm). 
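+// The test below disables autotuning (xla_gpu_autotune_level = 0) and then
+// expects the optimized module to still contain a kCustom fusion carrying a
+// __triton_gemm backend config chosen by the default heuristics.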
+TEST_F(TritonGemmTest, DISABLED_TestNoAutotuner) { + if (std::holds_alternative( + GpuComputeCapability())) { + GTEST_SKIP() << "Autotuner is always in pipeline on Cuda."; + } + constexpr absl::string_view kHloText = R"( +ENTRY e { + p0 = f16[30,30] parameter(0) + p1 = s8[30,30] parameter(1) + cp1 = f16[30,30] convert(p1) + ROOT _ = f16[30,30] dot(p0, cp1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + DebugOptions debug_options = verified_module->config().debug_options(); + debug_options.set_xla_gpu_autotune_level(0); + verified_module->mutable_config().set_debug_options(debug_options); + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton_gemm + )"); + + EXPECT_TRUE(RunAndCompare(std::move(verified_module), + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +} // namespace +} // namespace gpu +} // namespace xla From 01a71e3969a0fe0170bb03048874cf87b61e12a4 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 2 Apr 2025 14:23:30 -0700 Subject: [PATCH 0175/1324] [xla:gpu] CommandBuffer: add support for attaching multiple state objects to a command PiperOrigin-RevId: 743278958 --- .../xla/xla/backends/gpu/runtime/BUILD | 1 + .../gpu/runtime/command_buffer_cmd.cc | 19 ++++-- .../backends/gpu/runtime/command_buffer_cmd.h | 32 +++++++--- .../gpu/runtime/command_buffer_cmd_test.cc | 61 ++++++++++++------- 4 files changed, 77 insertions(+), 36 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index a5a22480624d46..471c54cecde82b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -95,6 +95,7 @@ cc_library( "//xla/stream_executor:trace_command_buffer_factory", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/tsl/lib/gtl:int_type", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 0a18b5725a0f10..6c14705755d4d5 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/command_buffer_cmd.h" #include +#include #include #include #include @@ -162,21 +163,29 @@ CommandBufferCmd::RecordedCommands::Create( // CommandBufferCmd //===----------------------------------------------------------------------===// +CommandBufferCmd::StateManager::TypeId +CommandBufferCmd::StateManager::GetNextTypeId() { + static auto* counter = new std::atomic(1); + return TypeId(counter->fetch_add(1)); +} + CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrNull( - const CommandBufferCmd* cmd) { - if (auto it = state_.find(cmd); it != state_.end()) { + const CommandBufferCmd* cmd, TypeId type_id) { + StateKey key = {cmd, type_id}; + if (auto it = state_.find(key); it != state_.end()) { return it->second.get(); } return nullptr; } CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrCreate( - const CommandBufferCmd* cmd, + const CommandBufferCmd* cmd, TypeId type_id, absl::FunctionRef()> create) { - if (auto it = state_.find(cmd); it != state_.end()) { + StateKey key = {cmd, type_id}; + if (auto it = state_.find(key); it != state_.end()) { return it->second.get(); } - return state_.try_emplace(cmd, create()).first->second.get(); + return state_.try_emplace(key, create()).first->second.get(); } //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index fa54c980ec092d..433d7f9a945c79 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -58,6 +58,7 @@ limitations under the License. #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla::gpu { @@ -142,7 +143,8 @@ class CommandBufferCmd { template ConcreteState* GetOrNull(const CommandBufferCmd* cmd) { static_assert(std::is_base_of_v); - return static_cast(GetOrNull(cmd)); + return static_cast( + GetOrNull(cmd, GetTypeId())); } template @@ -150,24 +152,36 @@ class CommandBufferCmd { const CommandBufferCmd* cmd, absl::FunctionRef()> create) { static_assert(std::is_base_of_v); - return static_cast(GetOrCreate( - cmd, [&]() -> std::unique_ptr { return create(); })); + return static_cast( + GetOrCreate(cmd, GetTypeId(), + [&]() -> std::unique_ptr { return create(); })); } template ConcreteState* GetOrCreate(const CommandBufferCmd* cmd) { - static_assert(std::is_base_of_v); - return static_cast( - GetOrCreate(cmd, [] { return std::make_unique(); })); + return GetOrCreate( + cmd, [] { return std::make_unique(); }); } private: - State* GetOrNull(const CommandBufferCmd* cmd); + // We use TypeId to distinguish between different state types. + TSL_LIB_GTL_DEFINE_INT_TYPE(TypeId, int64_t); + + template + static TypeId GetTypeId() { + static const TypeId id = GetNextTypeId(); + return id; + } + + static TypeId GetNextTypeId(); + + State* GetOrNull(const CommandBufferCmd* cmd, TypeId type_id); - State* GetOrCreate(const CommandBufferCmd* cmd, + State* GetOrCreate(const CommandBufferCmd* cmd, TypeId type_id, absl::FunctionRef()> create); - absl::flat_hash_map> state_; + using StateKey = std::pair; + absl::flat_hash_map> state_; }; // Parameters for recording commands into the command buffer. 
diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc
index a3e70675c6c40b..bb2f58cb6d584f 100644
--- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc
+++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc
@@ -95,6 +95,45 @@ class FakeCmd : public CommandBufferCmd {
   BufferUseVector buffers() override { return BufferUseVector{}; }
 };
 
+TEST(CommandBufferCmdStateManageTest, GetOrCreateState) {
+  struct StateA : public CommandBufferCmd::State {
+    int32_t value = 0;
+  };
+
+  struct StateB : public CommandBufferCmd::State {
+    float value = 0;
+  };
+
+  // We need a fake command buffer pointer to use as a key.
+  CommandBufferCmd* cmd = reinterpret_cast<CommandBufferCmd*>(0x1234567);
+
+  CommandBufferCmd::StateManager state_manager;
+
+  // Create a state of type StateA.
+  auto* stateA0 = state_manager.GetOrNull<StateA>(cmd);
+  ASSERT_EQ(stateA0, nullptr);
+
+  auto* stateA1 = state_manager.GetOrCreate<StateA>(cmd);
+  ASSERT_EQ(stateA1->value, 0);
+  stateA1->value += 42;
+
+  auto* stateA2 = state_manager.GetOrCreate<StateA>(cmd);
+  ASSERT_EQ(stateA2->value, 42);
+  ASSERT_EQ(stateA1, stateA2);
+
+  // StateB has a different type, and has no connection to StateA created above.
+  auto* stateB0 = state_manager.GetOrNull<StateB>(cmd);
+  ASSERT_EQ(stateB0, nullptr);
+
+  auto* stateB1 = state_manager.GetOrCreate<StateB>(cmd);
+  ASSERT_EQ(stateB1->value, 0);
+  stateB1->value += 42.0;
+
+  auto* stateB2 = state_manager.GetOrCreate<StateB>(cmd);
+  ASSERT_EQ(stateB2->value, 42.0);
+  ASSERT_EQ(stateB1, stateB2);
+}
+
 TEST(CommandBufferCmdTest, SerializeExecution) {
   BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0);
 
@@ -279,28 +318,6 @@ TEST(CommandBufferCmdTest, LaunchCmd) {
   ASSERT_EQ(dst, std::vector<int32_t>(4, 42 + 42));
 }
 
-TEST(CommandBufferCmdStateManageTest, GetOrCreateState) {
-  struct TestState : public CommandBufferCmd::State {
-    int32_t value = 0;
-  };
-
-  // We need a fake command buffer pointer to use as a key.
-  CommandBufferCmd* cmd = reinterpret_cast<CommandBufferCmd*>(0x1234567);
-
-  CommandBufferCmd::StateManager state_manager;
-
-  auto* state0 = state_manager.GetOrNull<TestState>(cmd);
-  ASSERT_EQ(state0, nullptr);
-
-  auto* state1 = state_manager.GetOrCreate<TestState>(cmd);
-  ASSERT_EQ(state1->value, 0);
-  state1->value += 42;
-
-  auto* state2 = state_manager.GetOrCreate<TestState>(cmd);
-  ASSERT_EQ(state2->value, 42);
-  ASSERT_EQ(state1, state2);
-}
-
 TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
   auto run_traced_test = [](int trace_cache_size) {
     se::StreamExecutor* executor = GpuExecutor();

From 16fc7aafcae0a40faf3d5dc89868fe7320029b32 Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Wed, 2 Apr 2025 14:33:33 -0700
Subject: [PATCH 0176/1324] Remove Tuple support from stream executor client.
PiperOrigin-RevId: 743282550 --- third_party/xla/xla/pjrt/gpu/raw_buffer.cc | 7 +- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 20 ++- .../xla/pjrt/pjrt_stream_executor_client.cc | 114 +++++------------- .../pjrt/pjrt_stream_executor_client_test.cc | 5 +- .../xla/xla/pjrt/tracked_device_buffer.cc | 38 ++---- .../xla/xla/pjrt/tracked_device_buffer.h | 27 ++--- .../xla/pjrt/tracked_device_buffer_test.cc | 29 +---- 7 files changed, 66 insertions(+), 174 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/raw_buffer.cc b/third_party/xla/xla/pjrt/gpu/raw_buffer.cc index f8a594a87f0598..7e6ba82da584fc 100644 --- a/third_party/xla/xla/pjrt/gpu/raw_buffer.cc +++ b/third_party/xla/xla/pjrt/gpu/raw_buffer.cc @@ -53,13 +53,14 @@ CreateGPURawBuffer(PjRtBuffer* buffer) { return hold.status(); } const auto& device_buffer = hold.buffer(); - if (device_buffer->device_memory().size() != 1) { - return absl::InvalidArgumentError("Copy raw buffer called on tuple"); + if (!device_buffer->device_memory()) { + return absl::InvalidArgumentError( + "Create raw buffer called on an invalid buffer"); } return tsl::MakeRef( se_client, se_buffer->memory_space(), se_buffer->device()->local_device_state(), - device_buffer->device_memory()[0]); + device_buffer->device_memory()); } return std::nullopt; } diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 23bf04851e88ae..3ddd2c2caeba0a 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -215,8 +215,8 @@ class AsyncHostToDeviceTransferManager device_(device) { buffer_sizes_.reserve(buffer_ptrs_.size()); for (const auto& ptr : buffer_ptrs_) { - DCHECK_EQ(ptr->device_memory().size(), 1); - buffer_sizes_.push_back(ptr->device_memory()[0]->mem().size()); + DCHECK(ptr->device_memory()); + buffer_sizes_.push_back(ptr->device_memory()->mem().size()); } last_transfer_started_.resize(buffer_ptrs_.size(), false); } @@ -275,15 +275,13 @@ class AsyncHostToDeviceTransferManager last_transfer_started_[buffer_index] = true; buffer = buffer_ptrs_[buffer_index]; DCHECK(buffer); - if (buffer->device_memory().empty()) { + if (!buffer->device_memory()) { return InvalidArgument( "TransferLiteralToBuffer requested for buffer index %d which has " "been donated. Async transfer of donated buffers is not supported " "in SE:GPU", buffer_index); } - DCHECK_EQ(buffer->device_memory().size(), 1); - ++transfers_in_flight_; } @@ -377,15 +375,14 @@ class AsyncHostToDeviceTransferManager last_transfer_started_[buffer_index] = true; } DCHECK(buffer_ptrs_[buffer_index]); - if (buffer_ptrs_[buffer_index]->device_memory().empty()) { + if (!buffer_ptrs_[buffer_index]->device_memory()) { return InvalidArgument( "TransferRawDataToSubBuffer requested for buffer index %d which has " "been donated. 
Async transfer of donated buffers is not supported " "in SE:GPU", buffer_index); } - DCHECK_EQ(buffer_ptrs_[buffer_index]->device_memory().size(), 1); - auto& buffer_memory = buffer_ptrs_[buffer_index]->device_memory()[0]->mem(); + auto& buffer_memory = buffer_ptrs_[buffer_index]->device_memory()->mem(); se::DeviceMemoryBase sub_buffer; CHECK_LE(offset, buffer_memory.size()); CHECK_LE(transfer_size, buffer_memory.size() - offset); @@ -652,8 +649,9 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( } auto device_buffer = hold.buffer(); - if (device_buffer->device_memory().size() != 1) { - return PjRtFuture<>(InvalidArgument("Copy raw buffer called on tuple")); + if (!device_buffer->device_memory()) { + return PjRtFuture<>( + InvalidArgument("Copy raw buffer called on an invalid buffer")); } auto promise = PjRtFuture<>::CreatePromise(); @@ -686,7 +684,7 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( return; } - auto& device_memory = device_buffer->device_memory()[0]->mem(); + auto& device_memory = device_buffer->device_memory()->mem(); if (offset < 0 || offset > device_memory.size() || device_memory.size() - offset < transfer_size) { promise.Set( diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 097bc9071a85a0..338312b2d208d0 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -427,8 +427,9 @@ AllocateDestinationBuffer( bool is_uninitialized_create, PjRtStreamExecutorClient* client, std::shared_ptr definition_event, PjRtMemorySpace* memory_space) { - if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) { - return InvalidArgument("Can't make a buffer from an empty tuple"); + if (on_host_shape.IsTuple()) { + return InvalidArgument( + "Cannot allocate a PjRtStreamExecutorBuffer for a tuple."); } PjRtMemorySpace* default_memory_space = @@ -510,54 +511,18 @@ AllocateDestinationBuffer( std::make_shared(client->thread_pool())); } } - se::Stream* tuple_table_stream = local_device->host_to_device_stream(); - if (on_device_shape.IsTuple()) { - // We also need to copy the tuple tables, so we'll have an additional - // definition event for that copy to complete. - if (tuple_table_stream != copy_stream) { - if (local_device->allocation_model() == - LocalDeviceState::kComputeSynchronized) { - DCHECK( - tuple_table_stream->WaitFor(local_device->compute_stream()).ok()); - } else { - DCHECK(transfer_manager->CanShapedBufferBeAccessedNow( - local_device->compute_stream()->parent(), dst_buffer)); - } - } - TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( - tuple_table_stream, dst_buffer)); - // CAUTION: From this point onwards we need to be careful about returning - // from error cases because we have started a transfer and must not allow - // dst_buffer to be freed too soon in the non-async allocation models. 
+ auto mem = RawSEDeviceMemory::Create(dst_buffer.buffer({}), + device->local_device_id(), + dst_buffer.memory_allocator()); + dst_buffer.clear(); - definition_events.emplace_back( - std::make_shared(client->thread_pool())); - absl::StatusOr event_or = - local_device->event_pool().ThenAllocateAndRecordEvent( - tuple_table_stream); - if (!event_or.ok()) { - StallStreamOnError(local_device, tuple_table_stream); - return event_or.status(); - } - definition_events.back()->SetSequencingEvent(std::move(event_or).value(), - tuple_table_stream); - } - std::shared_ptr dst_device_buffer = - TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, - definition_events, device); + auto dst_device_buffer = std::make_shared( + device, std::move(mem), definition_events); auto py_buffer = std::make_unique( on_device_shape, std::move(dst_device_buffer), client, device, memory_space); - - if (on_device_shape.IsTuple()) { - // Add a usage hold for the tuple table write and immediately convert it to - // the appropriate form of synchronization. - RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device, - definition_events.back(), tuple_table_stream); - } - return py_buffer; } @@ -667,7 +632,7 @@ class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference { : external_reference_(std::move(hold)) { CHECK(external_reference_.type() == PjRtStreamExecutorBuffer::ScopedHold::kExternalReference); - data_ptr_ = external_reference_->device_memory().front()->opaque(); + data_ptr_ = external_reference_->device_memory()->opaque(); } ~ScopedHoldAsExternalReference() override = default; @@ -701,7 +666,7 @@ class TrackedDeviceBufferExternalReference explicit TrackedDeviceBufferExternalReference( std::shared_ptr tracked_device_buffer) : tracked_device_buffer_(std::move(tracked_device_buffer)) { - data_ptr_ = tracked_device_buffer_->device_memory()[0]->opaque(); + data_ptr_ = tracked_device_buffer_->device_memory()->opaque(); } ~TrackedDeviceBufferExternalReference() override = default; @@ -744,9 +709,6 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(PjRtFuture<> dependency) { } // Copy all the data in the existing tracked_buffer. - absl::InlinedVector, 4> buffers( - tracked_buffer->device_memory().begin(), - tracked_buffer->device_memory().end()); auto original_definition_events = tracked_buffer->definition_events(); absl::InlinedVector, 4> definition_events; @@ -761,7 +723,7 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(PjRtFuture<> dependency) { original_definition_events.end()); auto new_device_buffer = std::make_shared( - device(), std::move(buffers), std::move(definition_events)); + device(), tracked_buffer->device_memory(), std::move(definition_events)); // Make the new buffer which is identical to the old, except for the new // definition event. @@ -946,7 +908,7 @@ PjRtStreamExecutorClient::BufferFromHostBufferInternal( // allocation. se::DeviceMemoryBase device_memory = - device_buffer->device_memory()[0]->mem(); + device_buffer->device_memory()->mem(); // If applicable on the backend, stage the transfer via host memory // allocated via the host_memory_allocator. On GPU, this is pinned @@ -1071,7 +1033,7 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error, // Create an empty buffer. 
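   // An error buffer attaches no device memory at all; it is defined solely
   // by its definition event, which carries the error status.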
auto dummy_device_buffer = std::make_shared( - device, absl::Span>(), + device, tsl::RCReference(), absl::MakeSpan(&definition_event, 1)); auto py_buffer = std::make_unique( @@ -1212,9 +1174,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer( definition_stream); auto device_buffer = std::make_shared( - device, - std::initializer_list>{buffer}, - definition_events); + device, std::move(buffer), definition_events); return std::unique_ptr(std::make_unique( shape, std::move(device_buffer), this, device, device->default_memory_space().value_or(nullptr))); @@ -1728,15 +1688,11 @@ PjRtFuture<> PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal) { absl::StatusOr PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes() const { absl::MutexLock lock(&mu_); - if (device_buffer_ == nullptr) { + if (device_buffer_ == nullptr || !device_buffer_->device_memory()) { return InvalidArgument( "GetOnDeviceSizeInBytes called on deleted or donated buffer"); } - if (device_buffer_->device_memory().size() != 1) { - return InvalidArgument( - "GetOnDeviceSizeInBytes called on tuple-shaped buffer"); - } - return device_buffer_->device_memory()[0]->mem().size(); + return device_buffer_->device_memory()->mem().size(); } PjRtFuture<> PjRtStreamExecutorBuffer::CopyRawToHost(void* dst, int64_t offset, @@ -2166,15 +2122,10 @@ MakeTupleHelper( // Then set each sub-tuple in turn from the parameters. for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer : device_buffers) { - for (const tsl::RCReference& buf : - device_buffer->device_memory()) { - CHECK(input_iterator != iterator_end); - input_iterator->second = { - device_buffer.type() == - PjRtStreamExecutorBuffer::ScopedHold::kDonation, - buf}; - ++input_iterator; - } + input_iterator->second = { + device_buffer.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation, + device_buffer->device_memory()}; + ++input_iterator; } CHECK(input_iterator == iterator_end); @@ -2207,12 +2158,15 @@ absl::StatusOr> OutputBufferHelper( std::shared_ptr definition_event, PjRtClient* client, PjRtDevice* device, LocalDeviceState* local_device, std::vector>& buffers_to_release) { + if (result_buffer.shape().IsTuple()) { + return absl::InternalError("OutputBufferHelper called on tuple."); + } absl::InlinedVector, 1> buffers; for (auto& item : result_buffer) { buffers.push_back(std::move(item.second)); } auto out_buffer = std::make_shared( - device, absl::Span const>(buffers), + device, std::move(buffers[0]), absl::Span>{ definition_event}); const Shape& shape = result_buffer.shape(); @@ -2400,15 +2354,13 @@ PjRtStreamExecutorLoadedExecutable::MakeExecutionInputsAndWaitForEvents( execution_inputs.back(); auto input_iterator = execution_input.begin(); auto iterator_end = execution_input.end(); - for (const tsl::RCReference& buf : - device_buffers[i]->device_memory()) { - CHECK(input_iterator != iterator_end); - input_iterator->second = { - device_buffers[i].type() == - PjRtStreamExecutorBuffer::ScopedHold::kDonation, - buf}; - ++input_iterator; - } + const auto& buf = device_buffers[i]->device_memory(); + CHECK(input_iterator != iterator_end); + input_iterator->second = { + device_buffers[i].type() == + PjRtStreamExecutorBuffer::ScopedHold::kDonation, + buf}; + ++input_iterator; CHECK(input_iterator == iterator_end); } } @@ -2959,7 +2911,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers( tsl::profiler::TraceMe traceme("MakeOutputBuffers"); std::vector> outputs; LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); - if (options.untuple_result 
&& result_buffer.shape().IsTuple()) { + if (result_buffer.shape().IsTuple()) { int tuple_count = result_buffer.shape().tuple_shapes_size(); outputs.reserve(tuple_count); // Take ownership of each of the output values, leaving only the root table diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc index 5bd6dd8ff09510..cf6510a0b6da38 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc @@ -105,8 +105,9 @@ absl::Status ExecuteWithSameInputBuffer( shape, *device0->default_memory_space())); TF_ASSIGN_OR_RETURN(auto executable, ToyExecutable(*client, shape, std::move(set_up_aliases))); - return executable->Execute({{buffer.get(), buffer.get()}}, /*options=*/{}) - .status(); + xla::ExecuteOptions options; + options.untuple_result = true; + return executable->Execute({{buffer.get(), buffer.get()}}, options).status(); } TEST(PjRtStreamExecutorClientTest, DonateSameBufferTwice) { diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index 0cfe41b9c2326e..9241d470f4da0b 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -242,31 +242,6 @@ tsl::RCReference RawSEDeviceMemory::CreateForeign( std::move(on_delete_callback)); } -/* static */ std::shared_ptr -TrackedDeviceBuffer::FromScopedShapedBuffer( - ScopedShapedBuffer* shaped_buffer, - absl::Span> definition_events, - PjRtDevice* device) { - ShapeTree::iterator iterator = - shaped_buffer->buffers().begin(); - std::vector> buffers; - buffers.reserve(1); - - ShapeUtil::ForEachSubshape( - shaped_buffer->on_device_shape(), [&](const Shape&, const ShapeIndex&) { - CHECK(iterator != shaped_buffer->buffers().end()); - buffers.push_back(RawSEDeviceMemory::Create( - iterator->second, device->local_device_id(), - shaped_buffer->memory_allocator())); - iterator->second = se::DeviceMemoryBase(); - ++iterator; - }); - CHECK(iterator == shaped_buffer->buffers().end()); - return std::make_shared( - device, absl::Span>(buffers), - definition_events); -} - ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( const Shape& on_device_shape) const { ShapedBuffer shaped_buffer(on_device_shape, @@ -274,9 +249,9 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( device_->local_hardware_id().value()); ShapeTree::iterator iterator = shaped_buffer.buffers().begin(); - for (const tsl::RCReference& buf : device_memory_) { + if (device_memory_) { CHECK(iterator != shaped_buffer.buffers().end()); - iterator->second = buf->mem(); + iterator->second = device_memory_->mem(); ++iterator; } CHECK(iterator == shaped_buffer.buffers().end()); @@ -284,18 +259,19 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( } TrackedDeviceBuffer::TrackedDeviceBuffer( - PjRtDevice* device, - absl::Span const> device_memory, + PjRtDevice* device, tsl::RCReference device_memory, absl::Span> definition_events) : device_(device), - device_memory_(device_memory.begin(), device_memory.end()), + device_memory_(std::move(device_memory)), definition_events_(std::make_move_iterator(definition_events.begin()), std::make_move_iterator(definition_events.end())), in_use_(true) {} TrackedDeviceBuffer::~TrackedDeviceBuffer() = default; -void TrackedDeviceBuffer::ReleaseDeviceMemory() { device_memory_.clear(); } +void TrackedDeviceBuffer::ReleaseDeviceMemory() { + device_memory_ = tsl::RCReference(); +} void 
TrackedDeviceBuffer::AddUsageEvent( se::Stream* usage_stream, std::shared_ptr event, diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.h b/third_party/xla/xla/pjrt/tracked_device_buffer.h index 79cd45650b4e6c..c3a3be91aaf773 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.h +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.h @@ -239,14 +239,6 @@ class TrackedDeviceBuffer { bool reference_held; }; - // Converts a ScopedShapedBuffer into a TrackedDeviceBuffer. Takes ownership - // of the buffers of the shaped_buffer. - static std::shared_ptr FromScopedShapedBuffer( - ScopedShapedBuffer* shaped_buffer, - absl::Span> - definition_events, - PjRtDevice* device); - // Builds a ShapedBuffer view onto the buffers of 'tree'. ShapedBuffer AsShapedBuffer(const Shape& on_device_shape) const; @@ -273,13 +265,10 @@ class TrackedDeviceBuffer { ExecutionInput* execution_input, se::DeviceMemoryAllocator* allocator) const; - absl::InlinedVector, 1>& device_memory() { - return device_memory_; - } - const absl::InlinedVector, 1>& - device_memory() const { + const tsl::RCReference& device_memory() const { return device_memory_; } + absl::Span> definition_events() const { return definition_events_; @@ -312,19 +301,17 @@ class TrackedDeviceBuffer { // any stream and, e.g. AddUsageHold will CHECK fail. StreamAndEventContainer LockUseAndTransferUsageEvents(); - TrackedDeviceBuffer() : in_use_(true) {} - TrackedDeviceBuffer( - PjRtDevice* device, - absl::Span const> device_memory, - absl::Span> - definition_events); + TrackedDeviceBuffer(PjRtDevice* device, + tsl::RCReference device_memory, + absl::Span> + definition_events); ~TrackedDeviceBuffer(); private: PjRtDevice* device_; // Each host-side buffer may have several buffers on-device. - absl::InlinedVector, 1> device_memory_; + tsl::RCReference device_memory_; // Events that are triggered when the content of one or more buffers is ready // during multistream execution. 
May be nullptr, which is used in the diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc index 5518edf6880b9d..86699456f18283 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc @@ -98,7 +98,7 @@ absl::StatusOr> MakeArray( return absl::OkStatus(); })); return std::make_shared( - device, device_buffers, + device, device_buffers[0], absl::Span>()); } @@ -113,12 +113,9 @@ TEST(TrackedDeviceBufferTest, AsShapedBuffer) { TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, MakeArray(b_shape, client, &device)); TF_ASSERT_OK_AND_ASSIGN(auto c_buffer, MakeArray(c_shape, client, &device)); - ASSERT_EQ(a_buffer->device_memory().size(), 1); - ASSERT_EQ(b_buffer->device_memory().size(), 1); - ASSERT_EQ(c_buffer->device_memory().size(), 1); std::vector expected_buffer_sequence = { - a_buffer->device_memory()[0]->mem(), b_buffer->device_memory()[0]->mem(), - c_buffer->device_memory()[0]->mem()}; + a_buffer->device_memory()->mem(), b_buffer->device_memory()->mem(), + c_buffer->device_memory()->mem()}; ShapedBuffer shaped_a = a_buffer->AsShapedBuffer( client->backend().transfer_manager()->HostShapeToDeviceShape(a_shape)); ShapedBuffer shaped_b = b_buffer->AsShapedBuffer( @@ -147,25 +144,5 @@ TEST(TrackedDeviceBufferTest, AsShapedBuffer) { EXPECT_TRUE(expected_it == expected_buffer_sequence.end()); } -TEST(TrackedDeviceBufferTest, FromScopedShapedBuffer) { - TestDevice device; - LocalClient* client = ClientLibrary::LocalClientOrDie(); - - Literal literal = LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateFullWithDescendingLayout({10, 3, 7}, 33.4f), - LiteralUtil::One(S64)); - - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer shaped_buffer, - client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); - std::shared_ptr device_buffer = - TrackedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, {}, &device); - - EXPECT_EQ(device_buffer->device_memory().size(), - ShapeUtil::SubshapeCount( - client->backend().transfer_manager()->HostShapeToDeviceShape( - literal.shape()))); -} - } // namespace } // namespace xla From 8bc557de5b59bce4ea06413d5f2d9fb3c6e6b751 Mon Sep 17 00:00:00 2001 From: Matthias Guenther Date: Wed, 2 Apr 2025 15:15:22 -0700 Subject: [PATCH 0177/1324] Move the test-only `sharding_format_picker.cc` next to the test that uses it. While working to improve `xla/hlo/` component coverage, we found that this file was test-only and used exclusively by a test in the `xla/service/` component. We're therefore moving it into `xla/service/`, next to the test that uses it. 
PiperOrigin-RevId: 743297363 --- third_party/xla/xla/hlo/tools/hlo_opt/BUILD | 1 - .../xla/xla/hlo/tools/hlo_opt/opt_lib.cc | 9 ++++---- third_party/xla/xla/hlo/transforms/BUILD | 17 -------------- third_party/xla/xla/service/BUILD | 8 ------- .../xla/xla/service/sharding_format_picker.h | 22 ------------------- third_party/xla/xla/service/spmd/BUILD | 19 +++++++++++++++- .../spmd}/sharding_format_picker.cc | 4 +++- .../spmd}/sharding_format_picker.h | 10 ++++----- .../xla/service/spmd/spmd_partitioner_test.cc | 3 ++- third_party/xla/xla/tools/hlo_opt/BUILD | 1 + .../xla/xla/tools/hlo_opt/compiled_opt_lib.cc | 8 +++++++ 11 files changed, 42 insertions(+), 60 deletions(-) delete mode 100644 third_party/xla/xla/service/sharding_format_picker.h rename third_party/xla/xla/{hlo/transforms => service/spmd}/sharding_format_picker.cc (98%) rename third_party/xla/xla/{hlo/transforms => service/spmd}/sharding_format_picker.h (88%) diff --git a/third_party/xla/xla/hlo/tools/hlo_opt/BUILD b/third_party/xla/xla/hlo/tools/hlo_opt/BUILD index 23f27e3fe27246..d724068ad2cd7e 100644 --- a/third_party/xla/xla/hlo/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_opt/BUILD @@ -53,7 +53,6 @@ cc_library( "//xla/hlo/transforms:literal_canonicalizer", "//xla/hlo/transforms:memory_space_propagation", "//xla/hlo/transforms:operand_upcaster", - "//xla/hlo/transforms:sharding_format_picker", "//xla/hlo/transforms:while_loop_trip_count_annotator", "//xla/hlo/transforms/collectives:all_gather_broadcast_reorder", "//xla/hlo/transforms/collectives:all_gather_combiner", diff --git a/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc b/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc index f51823f03fb4bc..d8c1100898afdc 100644 --- a/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc +++ b/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc @@ -81,7 +81,6 @@ limitations under the License. #include "xla/hlo/transforms/literal_canonicalizer.h" #include "xla/hlo/transforms/memory_space_propagation.h" #include "xla/hlo/transforms/operand_upcaster.h" -#include "xla/hlo/transforms/sharding_format_picker.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/hlo/transforms/simplifiers/all_reduce_folder.h" #include "xla/hlo/transforms/simplifiers/ar_crs_combiner.h" @@ -322,8 +321,6 @@ void OptProvider::RegisterAllHardwareIndependentPasses() { RegisterPass(RandomAlgorithm::RNG_THREE_FRY); RegisterPass(); RegisterPass(); - RegisterPass( - /*sharding_type=*/ShardingFormatPicker::ShardingType::kBestEffortV2); RegisterPass(); RegisterPass(); RegisterPass(); @@ -343,13 +340,17 @@ void OptProvider::RegisterAllHardwareIndependentPasses() { // pass specific customization to the `RegisterPass`. // Dummy passes for unit-testing the `hlo-opt` tool itself. - RegisterPass(); + // go/keep-sorted start RegisterPass(); + RegisterPass(); + // go/keep-sorted end // Test-only passes exposing behavior that isn't easily testable through // standard passes, e.g. internal or config-dependent behavior. 
+ // go/keep-sorted start RegisterPass(); RegisterPass(); + // go/keep-sorted end } } // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/BUILD b/third_party/xla/xla/hlo/transforms/BUILD index c5f8db622e4d7e..c26539cdf2e81d 100644 --- a/third_party/xla/xla/hlo/transforms/BUILD +++ b/third_party/xla/xla/hlo/transforms/BUILD @@ -479,23 +479,6 @@ xla_cc_test( ], ) -cc_library( - name = "sharding_format_picker", - testonly = True, - srcs = ["sharding_format_picker.cc"], - hdrs = ["sharding_format_picker.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:tile_assignment", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - cc_library( name = "add_original_value", srcs = ["add_original_value.cc"], diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index cbff4fa8204e07..ba631fa150c014 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -6172,14 +6172,6 @@ cc_library( deps = ["//xla/hlo/transforms/simplifiers:sub_byte_normalization"], ) -cc_library( - name = "sharding_format_picker", - testonly = True, - hdrs = ["sharding_format_picker.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:sharding_format_picker instead.", - deps = ["//xla/hlo/transforms:sharding_format_picker"], -) - xla_cc_test( name = "batched_gather_scatter_normalizer_test", srcs = ["batched_gather_scatter_normalizer_test.cc"], diff --git a/third_party/xla/xla/service/sharding_format_picker.h b/third_party/xla/xla/service/sharding_format_picker.h deleted file mode 100644 index 9a369faedf284b..00000000000000 --- a/third_party/xla/xla/service/sharding_format_picker.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_SHARDING_FORMAT_PICKER_H_ -#define XLA_SERVICE_SHARDING_FORMAT_PICKER_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/transforms/sharding_format_picker.h" - -#endif // XLA_SERVICE_SHARDING_FORMAT_PICKER_H_ diff --git a/third_party/xla/xla/service/spmd/BUILD b/third_party/xla/xla/service/spmd/BUILD index b78792deb4bbcb..b2e3ff2783b7a1 100644 --- a/third_party/xla/xla/service/spmd/BUILD +++ b/third_party/xla/xla/service/spmd/BUILD @@ -97,6 +97,7 @@ xla_cc_test( srcs = ["spmd_partitioner_test.cc"], shard_count = 10, deps = [ + ":sharding_format_picker", ":spmd_partitioner", ":spmd_prepare", "//xla:shape_util", @@ -104,7 +105,6 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:sharding_format_picker", "//xla/hlo/utils:hlo_matchers", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:hlo_module_config", @@ -125,6 +125,23 @@ xla_cc_test( ], ) +cc_library( + name = "sharding_format_picker", + testonly = True, + srcs = ["sharding_format_picker.cc"], + hdrs = ["sharding_format_picker.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:tile_assignment", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + xla_cc_test( name = "canonicalize_all_gather_for_cse_test", srcs = ["canonicalize_all_gather_for_cse_test.cc"], diff --git a/third_party/xla/xla/hlo/transforms/sharding_format_picker.cc b/third_party/xla/xla/service/spmd/sharding_format_picker.cc similarity index 98% rename from third_party/xla/xla/hlo/transforms/sharding_format_picker.cc rename to third_party/xla/xla/service/spmd/sharding_format_picker.cc index 90b192400f1f1e..c6a1a6a8008194 100644 --- a/third_party/xla/xla/hlo/transforms/sharding_format_picker.cc +++ b/third_party/xla/xla/service/spmd/sharding_format_picker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/hlo/transforms/sharding_format_picker.h" +#include "xla/service/spmd/sharding_format_picker.h" #include #include @@ -46,6 +46,7 @@ class HloShardingTestHelper { } }; +namespace test_only { namespace { bool PermuteDimsHelper(absl::Span dims, absl::Span perm, @@ -195,4 +196,5 @@ absl::StatusOr ShardingFormatPicker::Run( return changed; } +} // namespace test_only } // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/sharding_format_picker.h b/third_party/xla/xla/service/spmd/sharding_format_picker.h similarity index 88% rename from third_party/xla/xla/hlo/transforms/sharding_format_picker.h rename to third_party/xla/xla/service/spmd/sharding_format_picker.h index a6cbeb9420a4c8..583444157b631e 100644 --- a/third_party/xla/xla/hlo/transforms/sharding_format_picker.h +++ b/third_party/xla/xla/service/spmd/sharding_format_picker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_HLO_TRANSFORMS_SHARDING_FORMAT_PICKER_H_ -#define XLA_HLO_TRANSFORMS_SHARDING_FORMAT_PICKER_H_ +#ifndef XLA_SERVICE_SPMD_SHARDING_FORMAT_PICKER_H_ +#define XLA_SERVICE_SPMD_SHARDING_FORMAT_PICKER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -22,7 +22,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" -namespace xla { +namespace xla::test_only { // Test-only pass to transform the HloSharding format of all the instructions in // a module to the selected format. @@ -44,6 +44,6 @@ class ShardingFormatPicker : public HloModulePass { const ShardingType sharding_type_; }; -} // namespace xla +} // namespace xla::test_only -#endif // XLA_HLO_TRANSFORMS_SHARDING_FORMAT_PICKER_H_ +#endif // XLA_SERVICE_SPMD_SHARDING_FORMAT_PICKER_H_ diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index f8fd9b6150b8b9..fd2799ead72a5d 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -41,12 +41,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/hlo/transforms/sharding_format_picker.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/layout_util.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" +#include "xla/service/spmd/sharding_format_picker.h" #include "xla/service/spmd/spmd_prepare.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" @@ -62,6 +62,7 @@ namespace { using ::testing::_; using ::testing::AllOf; +using ::xla::test_only::ShardingFormatPicker; namespace op = xla::testing::opcode_matchers; class SpmdPartitioningTest diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index d54dae0d31c083..a51ce425f9584e 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -64,6 +64,7 @@ cc_library( "//xla/service/gpu/transforms:collective_permute_valid_iteration_annotator", "//xla/service/gpu/transforms:scatter_expander", "//xla/service/gpu/transforms:scatter_slice_simplifier", + "//xla/service/spmd:sharding_format_picker", "//xla/service/spmd/shardy:shardy_xla_pass", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", diff --git a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc index 9718c8c42c1ab8..fa241dd4d003ff 100644 --- a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc +++ b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc @@ -49,6 +49,7 @@ limitations under the License. #include "xla/service/scatter_simplifier.h" #include "xla/service/select_and_scatter_expander.h" #include "xla/service/sharding_remover.h" +#include "xla/service/spmd/sharding_format_picker.h" #include "xla/service/spmd/shardy/shardy_xla_pass.h" #include "xla/service/topk_rewriter.h" #include "xla/service/triangular_solve_expander.h" @@ -178,6 +179,13 @@ void CompiledOptProvider::RegisterSharedHardwareSpecificPasses() { RegisterPass(); RegisterPass(); // go/keep-sorted end + + // Test-only passes exposing behavior that isn't easily testable through + // standard passes, e.g. internal or config-dependent behavior. 
+ // go/keep-sorted start + RegisterPass( + test_only::ShardingFormatPicker::ShardingType::kBestEffortV2); + // go/keep-sorted end } } // namespace xla From 9adefd9c7f62aa051f7f2e7219f560dc50216206 Mon Sep 17 00:00:00 2001 From: Julia Guo Date: Wed, 2 Apr 2025 15:21:53 -0700 Subject: [PATCH 0178/1324] [XLA] Skip gemma HLOs for `linux-arm64-t2a-48` PiperOrigin-RevId: 743299644 --- third_party/xla/.github/workflows/cpu_benchmarks_nightly.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/.github/workflows/cpu_benchmarks_nightly.yml b/third_party/xla/.github/workflows/cpu_benchmarks_nightly.yml index 0e2a4aa4a61476..f2e4ee1b979a12 100644 --- a/third_party/xla/.github/workflows/cpu_benchmarks_nightly.yml +++ b/third_party/xla/.github/workflows/cpu_benchmarks_nightly.yml @@ -145,6 +145,7 @@ jobs: cat "$OUTPUT_FILE_PATH" - name: Run hlo_runner_main on gemma2_2b_keras_jax.hlo and collect profile stats + if: ${{ matrix.job_info.pool != 'linux-arm64-t2a-48' }} run: | bazel_arch_dir="${{ matrix.job_info.bazel_arch_dir }}" binary_dir="./bazel-out/${bazel_arch_dir}/bin/xla/tools" @@ -158,6 +159,7 @@ jobs: cat "$OUTPUT_FILE_PATH" - name: Run hlo_runner_main on gemma3_1b_flax_call.hlo and collect profile stats + if: ${{ matrix.job_info.pool != 'linux-arm64-t2a-48' }} run: | bazel_arch_dir="${{ matrix.job_info.bazel_arch_dir }}" binary_dir="./bazel-out/${bazel_arch_dir}/bin/xla/tools" @@ -171,6 +173,7 @@ jobs: cat "$OUTPUT_FILE_PATH" - name: Upload HLO test output to a GCS bucket + if: ${{ matrix.job_info.pool != 'linux-arm64-t2a-48' }} run: | GCS_BUCKET="gs://openxla-nightly-transient" TIMESTAMP=$(date +%Y%m%d_%H%M%S) From 96f8e72a90a50f2f93acffadd5a775a7e8e442f9 Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Wed, 2 Apr 2025 15:23:52 -0700 Subject: [PATCH 0179/1324] Add more tests for TfrtGpuClient PiperOrigin-RevId: 743300214 --- third_party/xla/xla/pjrt/gpu/tfrt/BUILD | 18 + .../xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc | 405 ++++++++++++++++++ 2 files changed, 423 insertions(+) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD index e2af3a2459ae1e..3380c24718838b 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD +++ b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD @@ -122,8 +122,10 @@ xla_cc_test( ":gpu_event", ":tfrt_gpu_client", ":tracked_tfrt_gpu_device_buffer", + "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -132,20 +134,35 @@ xla_cc_test( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", "//xla:literal", "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", "//xla:util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/ffi", + "//xla/ffi:ffi_api", "//xla/hlo/builder:xla_computation", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", "//xla/pjrt:host_memory_spaces", + "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", + "//xla/pjrt:raw_buffer", + "//xla/pjrt/gpu:gpu_topology", + "//xla/pjrt/gpu:gpu_topology_proto_cc", "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/service:gpu_plugin", + "//xla/service:platform_util", + 
"//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "//xla/tests:literal_test_util", "//xla/tsl/concurrency:async_value", # copybara:uncomment "//xla/tsl/framework:allocator", @@ -155,6 +172,7 @@ xla_cc_test( "//xla/tsl/platform:statusor", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc index 7e5fd24f32dcb2..71c2615514244b 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -29,6 +30,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -38,22 +40,35 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "absl/types/span.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/pjrt/gpu/gpu_topology.h" +#include "xla/pjrt/gpu/gpu_topology.pb.h" #include "xla/pjrt/gpu/tfrt/gpu_event.h" #include "xla/pjrt/gpu/tfrt/tracked_tfrt_gpu_device_buffer.h" #include "xla/pjrt/host_memory_spaces.h" +#include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" +#include "xla/pjrt/raw_buffer.h" +#include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/framework/allocator.h" @@ -62,8 +77,11 @@ limitations under the License. #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" #include "tsl/platform/mem.h" +#include "tsl/platform/protobuf.h" namespace xla { @@ -80,6 +98,7 @@ class DonationTransactionPeer { namespace { +using ::testing::ElementsAre; using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::Gt; @@ -268,6 +287,176 @@ TEST(TfrtGpuClientTest, SendRecvChunked) { *result_literal)); } +TEST(TfrtGpuClientTest, SendErrorNoDeadLock) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kProgram, *client)); + + // Always-failing Send handler. + SendCallback send_callback = { + /*channel_id=*/1, + [&](const PjRtTransferMetadata&, PjRtChunk, int64_t, bool) { + return Internal("Uh-oh, can send chunk to host"); + }}; + + // No-op Recv handler. 
+ RecvCallback recv_callback = { + /*channel_id=*/2, [&](const PjRtTransferMetadata& m, + std::unique_ptr stream) { + return absl::OkStatus(); + }}; + + // Callbacks for point-to-point communication ops. + std::vector> send_callbacks = {{send_callback}}; + std::vector> recv_callbacks = {{recv_callback}}; + + ExecuteOptions opts; + opts.send_callbacks = send_callbacks; + opts.recv_callbacks = recv_callbacks; + + // Check that send error safely rejected and we do not dead lock. + auto result = executable->Execute(/*argument_handles=*/{{}}, opts); + EXPECT_THAT(ExtractSingleResult(result).status(), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("Uh-oh, can send chunk to host"))); +} + +TEST(TfrtGpuClientTest, RecvErrorNoDeadLock) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kProgram, *client)); + + // No-op Send handler. + SendCallback send_callback = { + /*channel_id=*/1, [&](const PjRtTransferMetadata&, PjRtChunk, int64_t, + bool) { return absl::OkStatus(); }}; + + // Invalid Recv handler that tries to add invalid chunk. + RecvCallback recv_callback = { + /*channel_id=*/2, [&](const PjRtTransferMetadata& m, + std::unique_ptr stream) { + auto chunk = PjRtChunk::AllocateDefault(10 * sizeof(float)); + stream->AddChunk(std::move(chunk)).Await().IgnoreError(); + // Return ok status to proceed to corresponding recv-done call. + return absl::OkStatus(); + }}; + + // Callbacks for point-to-point communication ops. + std::vector> send_callbacks = {{send_callback}}; + std::vector> recv_callbacks = {{recv_callback}}; + + ExecuteOptions opts; + opts.send_callbacks = send_callbacks; + opts.recv_callbacks = recv_callbacks; + + // Check that invalid chunk safely rejected and we do not dead lock. + auto result = executable->Execute(/*argument_handles=*/{{}}, opts); + EXPECT_THAT( + ExtractSingleResult(result).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Adding chunk of size 40 would overflow buffer " + "of size 8 (0 already transferred)"))); +} + +// User-defined data type to be passed to FFI handler via the execute context +// side channel. 
+struct MemsetValue { + explicit MemsetValue(float value) : value(value) {} + float value; +}; + +static absl::Status MemsetFromValue( + se::Stream* stream, ffi::Result> result, + MemsetValue* memset_value) { + uint32_t pattern; + std::memcpy(&pattern, &memset_value->value, sizeof(pattern)); + + se::DeviceMemoryBase base = result->device_memory(); + return stream->Memset32(&base, pattern, base.size()); +} + +XLA_FFI_DEFINE_HANDLER(kMemsetFromValue, MemsetFromValue, + ffi::Ffi::Bind() + .Ctx() + .Ret>() + .Ctx>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "MemsetFromValue", + PlatformUtil::CanonicalPlatformName("GPU").value(), + kMemsetFromValue); + +TEST(TfrtGpuClientTest, ForwardUserDataToFfiHandler) { + static constexpr char const* kProgram = R"( + HloModule ffi_handler + ENTRY main { + ROOT %custom-call = f32[4] custom-call(), + custom_call_target="MemsetFromValue", + api_version=API_VERSION_TYPED_FFI + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kProgram, *client)); + + ExecuteContext context; + TF_ASSERT_OK(context.ffi_context().Emplace(42.0f)); + + ExecuteOptions opts; + opts.context = &context; + + auto result = executable->Execute(/*argument_handles=*/{{}}, opts); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr result_literal, + ExtractSingleResult(result)); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR1({42.0f, 42.0f, 42.0f, 42.0f}), + *result_literal)); +} + +static absl::Status MemsetFromAttr( + se::Stream* stream, float attr, + ffi::Result> result) { + uint32_t pattern; + std::memcpy(&pattern, &attr, sizeof(pattern)); + + se::DeviceMemoryBase base = result->device_memory(); + return stream->Memset32(&base, pattern, base.size()); +} + +XLA_FFI_DEFINE_HANDLER(kMemsetFromAttr, MemsetFromAttr, + ffi::Ffi::Bind() + .Ctx() + .Attr("attr") + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "MemsetFromAttr", + PlatformUtil::CanonicalPlatformName("GPU").value(), + kMemsetFromAttr); + +TEST(TfrtGpuClientTest, PassAttrToFfiHandler) { + static constexpr char const* kProgram = R"( + HloModule ffi_handler + ENTRY main { + ROOT %custom-call = f32[4] custom-call(), + custom_call_target="MemsetFromAttr", + api_version=API_VERSION_TYPED_FFI, + backend_config={"custom_call_backend_config": {"attributes": "{attr = 3.0 : f32}"}} + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kProgram, *client)); + + ExecuteOptions opts; + auto result = executable->Execute(/*argument_handles=*/{{}}, opts); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr result_literal, + ExtractSingleResult(result)); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f, 3.0f}), *result_literal)); +} + TEST(TfrtGpuClientTest, AcquireDonation) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); ASSERT_GE(client->devices().size(), 1); @@ -676,6 +865,53 @@ TEST(TfrtGpuClientTest, FromHostAsyncPinnedHostChunked) { EXPECT_THAT(lit->data(), ElementsAreArray(data)); } +TEST(TfrtGpuClientTest, DeleteBufferThenFulfillBufferNoDeadLock) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetTfrtGpuClient(GpuClientOptions())); + ASSERT_THAT(client->addressable_devices(), SizeIs(Gt(0))); + TF_ASSERT_OK_AND_ASSIGN( + PjRtMemorySpace * memspace, + client->addressable_devices()[0]->memory_space_by_kind( + PinnedHostMemorySpace::kKind)); + std::vector data{1, 3, 5, 7, 
11, 13, 17, 19}; + Shape shape = ShapeUtil::MakeShape(F32, {static_cast(data.size())}); + std::vector> + txms; + for (int i = 0; i < 10000; ++i) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr txm, + client->CreateBuffersForAsyncHostToDevice({shape}, memspace)); + std::unique_ptr buf = txm->RetrieveBuffer(0); + ASSERT_THAT(buf->GetReadyFuture().IsReady(), Eq(false)); + txms.push_back(std::move(txm)); + // Delete the buffer + } + + // At this point, we have 10000 buffers pending deallocation. + + absl::string_view raw_view( + reinterpret_cast(data.data()), // REINTERPRET_CAST_OK=test + data.size() * sizeof(data[0])); + for (auto& txm : txms) { + int offset = 0; + while (true) { + int end = offset + 3; // unaligned chunk size + if (end > raw_view.size()) { + end = raw_view.size(); + } + int sz = end - offset; + bool reaches_end = end == raw_view.size(); + TF_ASSERT_OK(txm->TransferRawDataToSubBuffer( + /*buffer_index=*/0, raw_view.data() + offset, offset, sz, reaches_end, + /*on_done=*/[]() {})); + if (reaches_end) { + break; + } + offset = end; + } + } +} + TEST(TfrtGpuClientTest, CreateMixOfErrorBuffers) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, GetTfrtGpuClient(GpuClientOptions())); @@ -858,6 +1094,175 @@ TEST(TfrtGpuClientTest, CopyRawToHostFuture) { tsl::port::AlignedSizedFree(dst, tsl::Allocator::kAllocatorAlignment, size); } +TEST(GpuTopology, FromProto) { + GpuTopologyProto msg; + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + device_ids: [ 3, 2, 1 ] + platform_version: "platform_version" + num_slices: 2 + num_hosts_per_slice: 1 + num_devices_per_host: 3 + )pb", + &msg)); + + std::unique_ptr gpu_topology = GpuTopology::FromProto(msg); + EXPECT_THAT(gpu_topology->device_ids(), ElementsAre(3, 2, 1)); + EXPECT_THAT(gpu_topology->platform_version(), "platform_version"); + EXPECT_THAT(gpu_topology->num_slices(), 2); + EXPECT_THAT(gpu_topology->num_hosts_per_slice(), 1); + EXPECT_THAT(gpu_topology->num_devices_per_host(), 3); +} + +TEST(GpuTopology, ToProto) { + GpuTopology gpu_topology(/*gpu_device_ids=*/{3, 2, 1}, + /*platform_version=*/"platform_version", + /*num_slices=*/2, + /*num_hosts_per_slice=*/1, + /*num_devices_per_host=*/3); + GpuTopologyProto msg = gpu_topology.ToProto(); + EXPECT_THAT(msg.device_ids(), ElementsAre(3, 2, 1)); + EXPECT_THAT(msg.platform_version(), "platform_version"); + EXPECT_THAT(msg.num_slices(), 2); + EXPECT_THAT(msg.num_hosts_per_slice(), 1); + EXPECT_THAT(msg.num_devices_per_host(), 3); +} + +namespace { + +constexpr char const* kD2HProgram = R"( + HloModule f + + ENTRY main.5 { + p = s32[4]{0} parameter(0) + ROOT cc = s32[4] custom-call(p), + custom_call_target="annotate_device_placement", + frontend_attributes={_xla_buffer_placement="pinned_host"} + } +)"; + +constexpr char const* kD2HProgramTupleOutput = R"( + HloModule f + + ENTRY main.5 { + p = s32[4]{0} parameter(0) + cc = s32[4] custom-call(p), + custom_call_target="annotate_device_placement", + frontend_attributes={_xla_buffer_placement="pinned_host"} + ROOT tuple = (s32[4]{0}, s32[4]{0}) tuple(s32[4]{0} p, s32[4]{0} cc) + } +)"; + +} // namespace + +TEST(TfrtGpuClientTest, ExecutablePinnedHostOutputMemoryKindTest) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kD2HProgram, *client)); + + TF_ASSERT_OK_AND_ASSIGN(auto memory_kinds, + executable->GetOutputMemoryKinds()); + EXPECT_EQ(memory_kinds.size(), 1); + EXPECT_EQ(memory_kinds[0].size(), 1); + EXPECT_EQ(memory_kinds[0][0], 
"pinned_host"); +} + +TEST(TfrtGpuClientTest, ExecutablePinnedHostTupleOutputMemoryKindTest) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + + // Build the output shape with the correct memory space set. + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {4}, {0}); + Shape host_shape = shape; + host_shape.mutable_layout()->set_memory_space(Layout::kHostMemorySpace); + Shape out_shape = ShapeUtil::MakeTupleShape({shape, host_shape}); + + // Set the result layout so that the compiler assertions on memory + // spaces pass. + xla::CompileOptions options; + options.executable_build_options.set_result_layout(out_shape); + + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + CompileExecutable(kD2HProgramTupleOutput, *client, options)); + + TF_ASSERT_OK_AND_ASSIGN(auto memory_kinds, + executable->GetOutputMemoryKinds()); + EXPECT_EQ(memory_kinds.size(), 1); + EXPECT_EQ(memory_kinds[0].size(), 2); + EXPECT_EQ(memory_kinds[0][0], "device"); + EXPECT_EQ(memory_kinds[0][1], "pinned_host"); +} + +TEST(TfrtGpuClientTest, MlirParameterLayoutFromOptionsIsSetInHlo) { + constexpr char kMlirCopy[] = + R"( + func.func public @main(%arg0: tensor<2x2x2xi32> { + mhlo.layout_mode = "default" + }) -> (tensor<2x2x2xi32> { + jax.result_info = "", + mhlo.layout_mode = "default"}) { + return %arg0 : tensor<2x2x2xi32> + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseMlirModuleString(kMlirCopy, context)); + + xla::CompileOptions options; + options.argument_layouts = { + {ShapeUtil::MakeShapeWithDenseLayout(S32, {2, 2, 2}, {0, 2, 1})}}; + TF_ASSERT_OK_AND_ASSIGN(auto executable, + client->CompileAndLoad(*module, options)); + TF_ASSERT_OK_AND_ASSIGN(auto modules, executable->GetHloModules()); + + auto first_param_layout = + modules[0]->entry_computation_layout().parameter_layout(0).layout(); + EXPECT_EQ(first_param_layout, Layout({0, 2, 1})); +} + +TEST(TfrtGpuClientTest, GetDefaultLayout) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + auto shape = ShapeUtil::MakeShape(S4, {2, 2}); + TF_ASSERT_OK_AND_ASSIGN( + auto layout, + client->GetDefaultLayout(shape.element_type(), shape.dimensions())); + EXPECT_EQ(layout.element_size_in_bits(), 4); +} + +TEST(TfrtGpuClientTest, AutoLayoutIsSupported) { + const char* hlo_text = R"( + HloModule DotLayout, + entry_computation_layout={(f32[2,3,5],f32[3,4,5])->f32[5,2,4]{2,1,0}} + + ENTRY dot { + p0 = f32[2,3,5]{2,1,0} parameter(0) + p1 = f32[3,4,5]{2,1,0} parameter(1) + ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1), + lhs_batch_dims={2}, lhs_contracting_dims={1}, + rhs_batch_dims={2}, rhs_contracting_dims={0} + })"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnUnverifiedModule( + hlo_text, {}, HloParserOptions().set_fill_missing_layouts(false))); + + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + CompileOptions compile_options; + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_pjrt_allow_auto_layout_in_hlo(true); + XlaComputation computation = m->ToProto(); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + client->CompileAndLoad(computation, compile_options)); + TF_ASSERT_OK_AND_ASSIGN(auto layouts, executable->GetParameterLayouts()); + // Check that the assigned layouts are not default. 
+ EXPECT_NE(layouts[0]->ToString(), "{2,1,0}"); + EXPECT_NE(layouts[1]->ToString(), "{2,1,0}"); +} + TEST(TfrtGpuClientTest, CreateUninitializedBuffer) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); From c2032e84aec1063e48007e528b53793c5e368b1c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 16:15:55 -0700 Subject: [PATCH 0180/1324] Avoid recomputing post-order when possible The non-determinism in post order traversal causes incorrect instruction order when post-order is called again. Other callers can also avoid this by passing a pre-computed post-order (when possible). In this patch I have not modified other callers. Added a testcase to check def-use order is preserved. Bug: 202886652 PiperOrigin-RevId: 743317623 --- .../xla/xla/hlo/analysis/hlo_reachability.cc | 7 ++- .../xla/xla/hlo/analysis/hlo_reachability.h | 3 +- .../collectives_schedule_linearizer.cc | 5 +- .../collectives_schedule_linearizer_test.cc | 47 +++++++++++++++++++ 4 files changed, 57 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_reachability.cc b/third_party/xla/xla/hlo/analysis/hlo_reachability.cc index a843013b9eb31d..d059ece78afa7f 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_reachability.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_reachability.cc @@ -110,11 +110,14 @@ std::unique_ptr HloReachabilityMap::BuildWithRestrictions( } std::unique_ptr HloReachabilityMap::Build( - const HloComputation* computation) { + const HloComputation* computation, + const std::vector& po_instructions) { HloComputation::ChannelDependencies channel_dependencies = computation->ComputeChannelDependencies(); std::vector instructions = - computation->MakeInstructionPostOrder(channel_dependencies); + po_instructions.empty() + ? computation->MakeInstructionPostOrder(channel_dependencies) + : po_instructions; auto result = std::make_unique(instructions); auto get_bit_set = [&](const HloInstruction* instruction) -> BitSet& { diff --git a/third_party/xla/xla/hlo/analysis/hlo_reachability.h b/third_party/xla/xla/hlo/analysis/hlo_reachability.h index 68faafe43cab15..74418975b22d7a 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_reachability.h +++ b/third_party/xla/xla/hlo/analysis/hlo_reachability.h @@ -55,7 +55,8 @@ class HloReachabilityMap { // dependencies (operands) and control dependencies are considered for // reachability. Trivially an instruction is reachable from itself. 
static std::unique_ptr Build( - const HloComputation* computation); + const HloComputation* computation, + const std::vector& po_instructions = {}); // Similar to the above Build operation except that it tries to identify // paths between instructions that do not contain control instructions diff --git a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc index b5e755e8df41a0..772992b4e1c5d5 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc @@ -43,14 +43,15 @@ absl::StatusOr CollectivesScheduleLinearizer::Run( module->MakeNonfusionComputations(execution_threads)) { std::unique_ptr reachability; HloInstruction* prev_done = nullptr; - for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { + auto post_order = computation->MakeInstructionPostOrder(); + for (HloInstruction* inst : post_order) { auto* next = DynCast(inst); if (!next) { continue; } // Build reachability map on demand if we actually see collectives. if (!reachability) { - reachability = HloReachabilityMap::Build(computation); + reachability = HloReachabilityMap::Build(computation, post_order); } // Derive the 'start' and 'done' peers of this instruction. For non-async // variants of collectives, they are the same as this instruction. For diff --git a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc index f6bd23d2ee4ead..2ef88bbea8cc4b 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc @@ -207,5 +207,52 @@ ENTRY entry { EXPECT_TRUE(absl::c_linear_search(ars1->control_predecessors(), ard0)); } +TEST_F(CollectivesScheduleLinearizerTest, DefUseOrder) { + absl::string_view hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT out = f32[] add(a, b) +} + +ENTRY entry { + p0 = f32[100] parameter(0), parameter_replication={false} + p1 = f32[100] parameter(1), parameter_replication={false} + i0 = f32[100] add(p0, p1) + i1 = f32[100] multiply(p0, p1) + i2 = f32[100] divide(p0, p1) + c1 = f32[100] all-reduce(i0), replica_groups={}, to_apply=sum, channel_id=1 + c2 = f32[100] all-reduce(i1), replica_groups={}, to_apply=sum, channel_id=1 + c3 = f32[100] all-reduce(i2), replica_groups={}, to_apply=sum, channel_id=1 + t = f32[100] add(c1, c2) + ROOT out = f32[100] add(t, c3) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCollectivesSchedule(module.get()); + EXPECT_EQ(CountControlEdges(*module->entry_computation()), 2); + + const HloInstruction *root = module->entry_computation()->root_instruction(); + const HloInstruction *t = root->operand(0); // t = add(c1, c2) + const HloInstruction *c3 = root->operand(1); // c3 = all-reduce(i2)... + EXPECT_EQ(t->opcode(), HloOpcode::kAdd); + EXPECT_EQ(c3->opcode(), HloOpcode::kAllReduce); + + const HloInstruction *c1 = t->operand(0); + const HloInstruction *c2 = t->operand(1); + EXPECT_EQ(c1->opcode(), HloOpcode::kAllReduce); + EXPECT_EQ(c2->opcode(), HloOpcode::kAllReduce); + + bool found_i0 = false; + // Verify that i0 is before c1. 
+ for (const auto &instruction : module->entry_computation()->instructions()) { + if (instruction->name() == "c1") EXPECT_TRUE(found_i0); + if (instruction->name() == "i0") found_i0 = true; + } +} + } // namespace } // namespace xla From a58f7378a75ade110522fda816aba695bc6af05c Mon Sep 17 00:00:00 2001 From: Tongfei Guo Date: Wed, 2 Apr 2025 16:20:46 -0700 Subject: [PATCH 0181/1324] [XLA:SPMD] Fix non-determinism in SHARD_AS/SHARD_LIKE. PiperOrigin-RevId: 743319126 --- .../xla/xla/service/sharding_propagation.cc | 24 +++++++++---------- .../xla/xla/service/sharding_propagation.h | 11 +++++---- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index e37abb77d8978e..bd20f42360076e 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -1003,7 +1003,7 @@ bool InferUnspecifiedDimsFromUsers(HloInstruction* annotate_op, bool InferUnspecifiedDimsFromShardGroup( HloInstruction* annotate_op, absl::Span unspecified_dims, - const absl::flat_hash_set& shard_group) { + const std::vector& shard_group) { // ProcessShardingInstruction will either keep the "Sharding" custom call as // is or replace it with a copy. CHECK(annotate_op->IsCustomCall("Sharding") || @@ -1382,9 +1382,9 @@ absl::StatusOr ProcessShardingInstruction( absl::flat_hash_map* saved_parameter_shardings, absl::flat_hash_map* instruction_to_shard_group_id, - absl::flat_hash_map>* + absl::flat_hash_map>* shard_group_id_to_shard_as_group, - absl::flat_hash_map>* + absl::flat_hash_map>* shard_group_id_to_shard_like_group, const std::vector* allow_spmd_sharding_propagation_to_parameters_vector, @@ -1440,7 +1440,7 @@ absl::StatusOr ProcessShardingInstruction( "instructions within the same shard_as group: " << (*shard_as_group.begin())->shape().ToString(); } - shard_as_group.insert(instruction); + shard_as_group.push_back(instruction); } else { auto& shard_like_group = (*shard_group_id_to_shard_like_group)[shard_group_id]; @@ -1452,7 +1452,7 @@ absl::StatusOr ProcessShardingInstruction( "instructions within the same shard_like group: " << (*shard_like_group.begin())->shape().ToString(); } - shard_like_group.insert(instruction); + shard_like_group.push_back(instruction); } HloSharding sharding = instruction->sharding(); sharding.ClearShardGroup(); @@ -2038,7 +2038,7 @@ bool InferDynamicUpdateSliceShardingFromOperand0( bool ShardingPropagation::InferShardingFromShardGroup( HloInstruction* instruction, int64_t aggressiveness, - const absl::flat_hash_set& shard_group) { + const std::vector& shard_group) { if (!CanPropagateThroughAtAggressiveLevel(*instruction, aggressiveness)) { return false; } @@ -2836,9 +2836,9 @@ absl::StatusOr ShardingPropagation::RunToFixPoint( unspecified_dims, absl::flat_hash_map& instruction_to_shard_group_id, - absl::flat_hash_map>& + absl::flat_hash_map>& shard_group_id_to_shard_as_group, - absl::flat_hash_map>& + absl::flat_hash_map>& shard_group_id_to_shard_like_group, int64_t& iterations) { bool changed = false; @@ -2885,7 +2885,7 @@ absl::StatusOr ShardingPropagation::RunToFixPoint( } if (instruction_to_shard_group_id.contains(hlo)) { const int64_t shard_group_id = instruction_to_shard_group_id.at(hlo); - const absl::flat_hash_set& shard_group = + const std::vector& shard_group = shard_group_id_to_shard_as_group.contains(shard_group_id) ? 
shard_group_id_to_shard_as_group.at(shard_group_id) : shard_group_id_to_shard_like_group.at(shard_group_id); @@ -2908,7 +2908,7 @@ absl::StatusOr ShardingPropagation::RunToFixPoint( } const int64_t shard_group_id = instruction_to_shard_group_id.at(instruction); - const absl::flat_hash_set& shard_group = + const std::vector& shard_group = shard_group_id_to_shard_as_group.contains(shard_group_id) ? shard_group_id_to_shard_as_group.at(shard_group_id) : shard_group_id_to_shard_like_group.at(shard_group_id); @@ -3158,9 +3158,9 @@ absl::StatusOr ShardingPropagation::Run( std::vector saved_root_shardings; absl::flat_hash_map saved_parameter_shardings; absl::flat_hash_map instruction_to_shard_group_id; - absl::flat_hash_map> + absl::flat_hash_map> shard_group_id_to_shard_as_group; - absl::flat_hash_map> + absl::flat_hash_map> shard_group_id_to_shard_like_group; TF_ASSIGN_OR_RETURN( bool changed, diff --git a/third_party/xla/xla/service/sharding_propagation.h b/third_party/xla/xla/service/sharding_propagation.h index 903d5d7730822d..f1b05bcba32aee 100644 --- a/third_party/xla/xla/service/sharding_propagation.h +++ b/third_party/xla/xla/service/sharding_propagation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SHARDING_PROPAGATION_H_ #define XLA_SERVICE_SHARDING_PROPAGATION_H_ +#include #include #include #include @@ -74,9 +75,9 @@ absl::StatusOr ProcessShardingInstruction( absl::flat_hash_map* saved_parameter_shardings, absl::flat_hash_map* instruction_to_shard_group_id = nullptr, - absl::flat_hash_map>* + absl::flat_hash_map>* shard_group_id_to_shard_as_group = nullptr, - absl::flat_hash_map>* + absl::flat_hash_map>* shard_group_id_to_shard_like_group = nullptr, const std::vector* allow_spmd_sharding_propagation_to_parameters_vector = nullptr, @@ -149,7 +150,7 @@ class ShardingPropagation : public HloModulePass { private: bool InferShardingFromShardGroup( HloInstruction* instruction, int64_t aggressiveness, - const absl::flat_hash_set& shard_group); + const std::vector& shard_group); bool InferShardingFromOperands( HloInstruction* instruction, const ComputationMap& computation_map, int64_t aggressiveness, const CallGraph& call_graph, @@ -171,9 +172,9 @@ class ShardingPropagation : public HloModulePass { unspecified_dims, absl::flat_hash_map& instruction_to_shard_group_id, - absl::flat_hash_map>& + absl::flat_hash_map>& shard_group_id_to_shard_as_group, - absl::flat_hash_map>& + absl::flat_hash_map>& shard_group_id_to_shard_like_group, int64_t& iterations); From 481e5e4f0f8ebc9fd35ebd284af37ee14e6bf133 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 17:14:34 -0700 Subject: [PATCH 0182/1324] clang-tidy: No header providing "xla::HloPrintOptions" is directly included PiperOrigin-RevId: 743335315 --- third_party/xla/xla/hlo/ir/hlo_computation.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 37d43ee73e9b1f..fc011d551b4f53 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -51,6 +51,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_original_value.h" +#include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/ir/ptrvec.h" #include "xla/literal.h" #include "xla/map_util.h" From d1f375444420d04fd6bbab6d7284935e6f26e1d7 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 2 Apr 2025 19:01:03 -0700 Subject: [PATCH 0183/1324] [PjRt-IFRT] Support a mix of addressable and non-addressable devices in Array creation methods `xla::ifrt::PjRtClient::MakeArrayFromHostBuffer()` and `...::MakeErrorArrays()` now support creating multi-shard Array(s) that contains non-addressable devices, as long as the sharding is fully replicated (pre-existent condition). This does not change the current requirement that `xla::ifrt::PjRtClient::MakeArrayFromHostBuffer()` should have at least one addressable device (which may be relaxed in the future, but not in this change). Bumping `JAX_IFRT_VERSION_NUMBER` because when JAX relies on this feature, the old implementation will fail array creation with non-addressable shards, and thus JAX has to take a fallback path based on the version. PiperOrigin-RevId: 743360913 --- .../xla/python/ifrt/array_impl_test_lib.cc | 101 ++++++++++++++++-- .../xla/xla/python/pjrt_ifrt/pjrt_client.cc | 37 ++++--- third_party/xla/xla/python/version.h | 2 +- 3 files changed, 111 insertions(+), 29 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc index 0c445013ea04d4..ffaf2aee644ad4 100644 --- a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc @@ -57,6 +57,17 @@ using ::testing::HasSubstr; using ::testing::SizeIs; using ::tsl::testing::StatusIs; +// Returns a list of non-addressable devices in the client. 
+std::vector GetNonAddressableDevices(Client* client) { + std::vector devices; + for (auto* device : client->devices()) { + if (!device->IsAddressable()) { + devices.push_back(device); + } + } + return devices; +} + TEST(ArrayImplTest, MakeArrayFromHostBuffer) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); @@ -80,6 +91,43 @@ TEST(ArrayImplTest, MakeArrayFromHostBuffer) { EXPECT_EQ(array->shared_ptr_sharding().get(), sharding.get()); } +TEST(ArrayImplTest, + MakeArrayFromHostBufferWithAddressableAndNonAddressableDevice) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + + std::vector non_addressable_devices = + GetNonAddressableDevices(client.get()); + if (non_addressable_devices.empty()) { + GTEST_SKIP() << "Skipping test; needs at least 1 non-addressable device."; + } + + DType dtype(DType::kF32); + Shape shape({2, 3}); + auto data = std::make_unique>(6); + std::iota(data->begin(), data->end(), 0); + + std::vector devices; + devices.reserve(2); + devices.push_back(non_addressable_devices.at(0)); + devices.push_back(client->addressable_devices().at(0)); + std::shared_ptr sharding = + xla::ifrt::ConcreteEvenSharding::Create(client->MakeDeviceList(devices), + xla::ifrt::MemoryKind(), shape, + /*shard_shape=*/shape, + /*is_fully_replicated=*/true); + + TF_ASSERT_OK_AND_ASSIGN( + auto array, client->MakeArrayFromHostBuffer( + data->data(), dtype, shape, + /*byte_strides=*/std::nullopt, sharding, + Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/nullptr)); + + EXPECT_EQ(array->dtype(), dtype); + EXPECT_EQ(array->shape(), shape); + EXPECT_EQ(array->shared_ptr_sharding().get(), sharding.get()); +} + class ArrayImplWithHostBufferSemanticsTest : public testing::TestWithParam {}; @@ -465,6 +513,43 @@ TEST(ArrayImplTest, MakeErrorArrays) { StatusIs(_, HasSubstr("injected error"))); } +TEST(ArrayImplTest, MakeErrorArraysWithAddressableAndNonAddressableDevice) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + + std::vector non_addressable_devices = + GetNonAddressableDevices(client.get()); + if (non_addressable_devices.empty()) { + GTEST_SKIP() << "Skipping test; needs at least 1 non-addressable device."; + } + + Shape shape({2, 2}); + + std::vector devices; + devices.reserve(2); + devices.push_back(client->addressable_devices().at(0)); + devices.push_back(non_addressable_devices.at(0)); + std::shared_ptr sharding = + ConcreteEvenSharding::Create(client->MakeDeviceList(devices), + MemoryKind(), shape, /*shard_shape=*/shape, + /*is_fully_replicated=*/true); + + ArraySpec array_spec = {/*dtype=*/xla::ifrt::DType(xla::ifrt::DType::kS8), + /*shape=*/shape, + /*sharding=*/sharding}; + + const absl::Status error = absl::InternalError("injected error"); + TF_ASSERT_OK_AND_ASSIGN( + const std::vector> arrays, + client->MakeErrorArrays(error, {array_spec, array_spec}, + client->CreateUserContext())); + ASSERT_EQ(arrays.size(), 2); + + EXPECT_THAT(arrays[0]->GetReadyFuture().Await(), + StatusIs(_, HasSubstr("injected error"))); + EXPECT_THAT(arrays[1]->GetReadyFuture().Await(), + StatusIs(_, HasSubstr("injected error"))); +} + TEST(ArrayImplTest, AssembleArray) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); @@ -656,7 +741,9 @@ TEST(ArrayImplTest, CopyToSameDevices) { TEST(ArrayImplTest, AssembleAndDisassembleNonAddressableArray) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); - if (client->device_count() - client->addressable_device_count() < 2) { + std::vector non_addressable_devices = + 
GetNonAddressableDevices(client.get()); + if (non_addressable_devices.size() < 2) { GTEST_SKIP() << "Skipping test; needs at least 2 non-addressable devices."; } @@ -680,16 +767,8 @@ TEST(ArrayImplTest, AssembleAndDisassembleNonAddressableArray) { for (auto* device : client->addressable_devices()) { addressable_device_ids.insert(device->Id()); } - std::vector non_addressable_devices; - for (auto* device : client->devices()) { - if (!addressable_device_ids.contains(device->Id())) { - non_addressable_devices.push_back(device); - } - if (non_addressable_devices.size() >= 2) { - break; - } - } - auto ifrt_device_list = client->MakeDeviceList(non_addressable_devices); + auto ifrt_device_list = client->MakeDeviceList( + absl::MakeConstSpan(non_addressable_devices).subspan(0, 2)); TF_ASSERT_OK_AND_ASSIGN( std::shared_ptr sharding_param_sharding, ShardingParamSharding::Create(std::move(sharding_param), ifrt_device_list, diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc index 283099f5dfd87e..da3c58e84d3e8a 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc @@ -909,7 +909,14 @@ absl::StatusOr> PjRtClient::MakeArrayFromHostBuffer( } TF_ASSIGN_OR_RETURN(auto primitive_type, ToPrimitiveType(dtype)); - auto count = std::make_shared>(sharding->devices()->size()); + absl::Span ifrt_addressable_devices = + sharding->devices()->AddressableDeviceList()->devices(); + auto count = + std::make_shared>(ifrt_addressable_devices.size()); + if (ifrt_addressable_devices.empty()) { + return InvalidArgument("Cannot copy array to non-addressable device: %s", + sharding->devices()->DebugString()); + } std::function on_done_with_host_buffer_per_device; if (on_done_with_host_buffer) { on_done_with_host_buffer_per_device = @@ -924,8 +931,8 @@ absl::StatusOr> PjRtClient::MakeArrayFromHostBuffer( } PjRtArray::PjRtBuffers buffers; - buffers.reserve(sharding->devices()->size()); - for (xla::ifrt::Device* const device : sharding->devices()->devices()) { + buffers.reserve(ifrt_addressable_devices.size()); + for (xla::ifrt::Device* const device : ifrt_addressable_devices) { std::unique_ptr buffer; // If the sharding has memory_kind specified, use a version of // `PjRtClient::BufferFromHostBuffer` that accepts `PjRtMemorySpace`. 
@@ -945,8 +952,8 @@ absl::StatusOr> PjRtClient::MakeArrayFromHostBuffer( return InvalidArgument( "Invalid memory kind: %s; available memory kinds: %s", *sharding->memory_kind().memory_kind(), - absl::StrJoin(sharding->devices()->devices().front()->Memories(), - ", ", [](std::string* out, Memory* ms) { + absl::StrJoin(ifrt_addressable_devices.front()->Memories(), ", ", + [](std::string* out, Memory* ms) { absl::StrAppend(out, *ms->Kind().memory_kind()); })); } @@ -957,10 +964,6 @@ absl::StatusOr> PjRtClient::MakeArrayFromHostBuffer( tensorflow::down_cast(memory)->pjrt_memory(), /*device_layout=*/nullptr)); } else { - if (!device->IsAddressable()) { - return InvalidArgument("Cannot copy array to non-addressable device %s", - device->DebugString()); - } TF_ASSIGN_OR_RETURN(xla::PjRtMemorySpace * memory_space, tensorflow::down_cast(device) ->pjrt_device() @@ -1007,15 +1010,16 @@ PjRtClient::MakeErrorArrays(const absl::Status& error, } TF_ASSIGN_OR_RETURN(auto primitive_type, ToPrimitiveType(array_spec.dtype)); + absl::Span ifrt_addressable_devices = + array_spec.sharding->devices()->AddressableDeviceList()->devices(); TF_ASSIGN_OR_RETURN(Shape shard_shape, array_spec.sharding->GetShardShape(array_spec.shape)); xla::Shape xla_shape = xla::ShapeUtil::MakeShape(primitive_type, shard_shape.dims()); PjRtArray::PjRtBuffers buffers; - buffers.reserve(array_spec.sharding->devices()->size()); - for (xla::ifrt::Device* const device : - array_spec.sharding->devices()->devices()) { + buffers.reserve(ifrt_addressable_devices.size()); + for (xla::ifrt::Device* const device : ifrt_addressable_devices) { std::unique_ptr buffer; // Find `PjRtMemorySpace` that is associated with the sharding's device // and matches the sharding's memory_kind. @@ -1030,11 +1034,10 @@ PjRtClient::MakeErrorArrays(const absl::Status& error, return absl::InvalidArgumentError(absl::StrFormat( "Invalid memory kind: %s; available memory kinds: %s", *array_spec.sharding->memory_kind().memory_kind(), - absl::StrJoin( - array_spec.sharding->devices()->devices().front()->Memories(), - ", ", [](std::string* out, Memory* ms) { - absl::StrAppend(out, *ms->Kind().memory_kind()); - }))); + absl::StrJoin(ifrt_addressable_devices.front()->Memories(), ", ", + [](std::string* out, Memory* ms) { + absl::StrAppend(out, *ms->Kind().memory_kind()); + }))); } TF_ASSIGN_OR_RETURN( buffers.emplace_back(), diff --git a/third_party/xla/xla/python/version.h b/third_party/xla/xla/python/version.h index 811db0ec9dbdb8..a8891ef4b22cc0 100644 --- a/third_party/xla/xla/python/version.h +++ b/third_party/xla/xla/python/version.h @@ -18,6 +18,6 @@ limitations under the License. // An increasing version number to protect jax code against breaking changes. // In JAX, reference this via jax._src.lib.ifrt_version. -#define JAX_IFRT_VERSION_NUMBER 1 +#define JAX_IFRT_VERSION_NUMBER 2 #endif // XLA_PYTHON_VERSION_H_ From d0013061d09f793a39325fdedc079d7a238ee356 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 20:55:51 -0700 Subject: [PATCH 0184/1324] Allow LogicalBufferAnalysis and TuplePointsToAnalysis to visit instructions in a fusion that are not reachable from the fusion root. 
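In the traversal rooted at the fusion root, Accept() only reaches the root's
transitive operands, so a side-effecting instruction inside the fusion with no
data path to the root (like the placeholder custom-call in the new test below)
is never visited and never gets buffers assigned. A minimal sketch of the two
call forms used in this patch, assuming `fusion` is any HloInstruction* of
fusion kind:

  // Old: visits only instructions reachable from the fusion root.
  TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(this));

  // New: visits every instruction in the fused computation, including
  // instructions that no use-def chain connects to the root.
  TF_RETURN_IF_ERROR(fusion->fused_instructions_computation()->Accept(this));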
PiperOrigin-RevId: 743385414 --- .../hlo/analysis/logical_buffer_analysis.cc | 3 ++- .../hlo/analysis/tuple_points_to_analysis.cc | 3 ++- .../analysis/tuple_points_to_analysis_test.cc | 25 +++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc b/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc index 0b0f789ea3fb8f..b6d92b51046174 100644 --- a/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc @@ -82,7 +82,8 @@ absl::Status LogicalBufferAnalysis::Analyze() { } } for (auto* instruction : fusion_instructions) { - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + TF_RETURN_IF_ERROR( + instruction->fused_instructions_computation()->Accept(this)); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc index 9b8b2dc66c3865..d012c13e67ce7d 100644 --- a/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc @@ -174,7 +174,8 @@ absl::Status TuplePointsToAnalysis::Analyze() { } // Run points-to analysis on fusion instructions in 'computation'. for (auto* instruction : fusion_instructions) { - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + TF_RETURN_IF_ERROR( + instruction->fused_instructions_computation()->Accept(this)); TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); } diff --git a/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc index 8cd47c1d9d2d7b..3c41f7c1745d14 100644 --- a/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc @@ -703,6 +703,31 @@ ENTRY %FusionParam0OneUser (param0: (f32[8], f32[3])) -> f32[8] { Run(hlo_str, /*expected_num_users=*/1); } +TEST_F(FusionPointsToAnalysisTest, + FusionParam0OneUserWithUnreachableInstructionInFusion) { + std::string hlo_str = R"( +HloModule FusionParam0OneUser + +%fused_computation (param_1.2: (f32[8], f32[3], f32[7])) -> f32[8] { + %param_1.2 = (f32[8]{0}, f32[3]{0}, f32[7]{0}) parameter(0) + %get-tuple-element.1 = f32[8]{0} get-tuple-element(%param_1.2), index=0 + %get-tuple-element.2 = f32[3]{0} get-tuple-element(%param_1.2), index=1 + %get-tuple-element.3 = f32[7]{0} get-tuple-element(%param_1.2), index=2 + %placeholder = f32[7]{0} custom-call(%get-tuple-element.3), custom_call_target="IntermediateBufferDummyConsumer", custom_call_has_side_effect=true + %constant.3 = f32[3]{0} constant({1, 1, 1}) + %add.1 = f32[3]{0} add(f32[3]{0} %get-tuple-element.2, f32[3]{0} %constant.3) + %constant.2 = s32[] constant(0) + ROOT %dynamic-update-slice.1 = f32[8]{0} dynamic-update-slice(f32[8]{0} %get-tuple-element.1, f32[3]{0} %add.1, s32[] %constant.2) +} + +ENTRY %FusionParam0OneUser (param0: (f32[8], f32[3], f32[7])) -> f32[8] { + %param0 = (f32[8]{0}, f32[3]{0}, f32[7]{0}) parameter(0) + ROOT %fusion = f32[8]{0} fusion(%param0), kind=kLoop, calls=%fused_computation +} +)"; + Run(hlo_str, /*expected_num_users=*/1); +} + // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users. // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices. 
// Tests that there are two users of the aliases of tuple-shaped fusion From 0a865fc582dc73521426d85762e56c257103b46f Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Wed, 2 Apr 2025 23:28:53 -0700 Subject: [PATCH 0185/1324] Make TF's runtime target explicitly depend on XLA's runtime target So far TF only implicitly depended on the XLA runtime through the XLA compiler target (which pulls in the runtime). We are trying to fully separate runtime and compiler; things will therefore start breaking if TF doesn't explicitly depend on the runtime, which is what this change adds. PiperOrigin-RevId: 743425827 --- tensorflow/core/common_runtime/gpu/BUILD | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index 15968c019e9ee8..56edf351c30cd3 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -1,4 +1,6 @@ load("@bazel_skylib//lib:selects.bzl", "selects") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("@local_xla//xla/tsl:tsl.bzl", "if_cuda_libs") load( "//tensorflow:tensorflow.bzl", @@ -230,6 +232,10 @@ tf_cuda_library( ]), ) + if_cuda_or_rocm([ "@local_tsl//tsl/platform:dso_loader", + ]) + if_cuda([ + "@local_xla//xla/stream_executor/cuda:all_runtime", + ]) + if_rocm([ + "@local_xla//xla/stream_executor/rocm:all_runtime", ]), alwayslink = 1, ) From 84b62cd523594439b6f47fb226d562c053b8b980 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Apr 2025 23:29:37 -0700 Subject: [PATCH 0186/1324] Automated Code Change PiperOrigin-RevId: 743425987 --- .../compiler/mlir/lite/integrations/model_utils_core_pybind.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc index 53095f8436ccd5..2a913b5989217e 100644 --- a/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc +++ b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc @@ -71,7 +71,7 @@ class MlirPythonPass pyfunc.inc_ref(); } - ~MlirPythonPass() = default; + ~MlirPythonPass() override = default; mlir::StringRef getName() const override { return name_; } mlir::StringRef getArgument() const override { return name_; } From 6a094043885f3c879cceba036dc011343278dc4b Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Thu, 3 Apr 2025 00:07:07 -0700 Subject: [PATCH 0187/1324] Make TF's XLA_GPU_JIT target depend on XLA's runtime explicitly I previously removed the dependency of the XLA runtime from the GpuExecutable target, which broke users of TF's xla_gpu_jit that were transitively relying on the removed dependency.
This change fixes the issue for users of xla_gpu_jit PiperOrigin-RevId: 743436084 --- tensorflow/compiler/jit/BUILD | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 23a6fa0d240440..39f93d17aa2932 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,3 +1,5 @@ +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("@local_xla//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load( "@local_xla//xla/tsl:tsl.bzl", @@ -106,6 +108,10 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "@local_xla//xla/service:gpu_plugin", "//tensorflow/core/tfrt/common:pjrt_gpu_client_registration", ]) + if_cuda([ "@local_xla//xla/stream_executor/cuda:all_runtime", # buildcleaner: keep ]) + if_rocm([ "@local_xla//xla/stream_executor/rocm:all_runtime", # buildcleaner: keep ]), alwayslink = 1, ) From ed1888ee579faee018d8c16bca4b030a95928d1d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 00:59:04 -0700 Subject: [PATCH 0188/1324] Automated Code Change PiperOrigin-RevId: 743449890 --- tensorflow/lite/core/c/BUILD | 1 + tensorflow/lite/core/c/c_api.cc | 1 + tensorflow/lite/core/c/c_api_experimental.cc | 2 -- tensorflow/lite/core/c/c_api_experimental_test.cc | 1 + tensorflow/lite/core/c/c_api_test.cc | 2 ++ tensorflow/lite/core/c/common_test.cc | 2 ++ 6 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index b5e2bf493757e5..0f09f323530320 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -618,6 +618,7 @@ cc_test( "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/delegates:delegate_test_util", "//tensorflow/lite/testing:util", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/core/c/c_api.cc b/tensorflow/lite/core/c/c_api.cc index beb924415e298e..fcbda4e4fb0c81 100644 --- a/tensorflow/lite/core/c/c_api.cc +++ b/tensorflow/lite/core/c/c_api.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include // NOLINT diff --git a/tensorflow/lite/core/c/c_api_experimental.cc b/tensorflow/lite/core/c/c_api_experimental.cc index d2128efe608fbc..a07d8246b62923 100644 --- a/tensorflow/lite/core/c/c_api_experimental.cc +++ b/tensorflow/lite/core/c/c_api_experimental.cc @@ -17,9 +17,7 @@ limitations under the License. #include -#include #include -#include #include #include "tensorflow/lite/builtin_ops.h" diff --git a/tensorflow/lite/core/c/c_api_experimental_test.cc b/tensorflow/lite/core/c/c_api_experimental_test.cc index f98ddb0b2c00db..7ee05979e427db 100644 --- a/tensorflow/lite/core/c/c_api_experimental_test.cc +++ b/tensorflow/lite/core/c/c_api_experimental_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/c_api.h" diff --git a/tensorflow/lite/core/c/c_api_test.cc b/tensorflow/lite/core/c/c_api_test.cc index 8aeb116b692260..b9a0af08807292 100644 --- a/tensorflow/lite/core/c/c_api_test.cc +++ b/tensorflow/lite/core/c/c_api_test.cc @@ -21,6 +21,8 @@ limitations under the License.
#include #include +#include +#include #include #include #include diff --git a/tensorflow/lite/core/c/common_test.cc b/tensorflow/lite/core/c/common_test.cc index fadc3f2bc68f08..e449b4821a4404 100644 --- a/tensorflow/lite/core/c/common_test.cc +++ b/tensorflow/lite/core/c/common_test.cc @@ -17,9 +17,11 @@ limitations under the License. #include #include +#include #include #include #include +#include #include #include From 4d22cc45b38147f5b126a8f2857625a63c8cbd56 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 01:05:19 -0700 Subject: [PATCH 0189/1324] Automated Code Change PiperOrigin-RevId: 743451781 --- third_party/xla/xla/service/gpu/BUILD | 1 + third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index fd9d73f1bdec3b..ed3775413cbbce 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -839,6 +839,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc index fb020163991f53..d46dd517e0dcae 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" From e712c69394102596d5b2ca4e2e72494699f24690 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 3 Apr 2025 01:11:35 -0700 Subject: [PATCH 0190/1324] [XLA:GPU] Add an argument to `CompileAndOptionallyVerifyPtx` to decide whether the optimization pipeline should be run. Seems like many tests use this but actually don't need/expect optimizations to be run. PiperOrigin-RevId: 743453445 --- third_party/xla/xla/service/gpu/tests/BUILD | 2 ++ .../xla/xla/service/gpu/tests/gpu_codegen_test.cc | 10 +++++++--- .../xla/xla/service/gpu/tests/gpu_codegen_test.h | 3 ++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 48fedb3b6f996c..501f0a379203f7 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -62,8 +62,10 @@ cc_library( "//xla/service/gpu:gpu_executable", "//xla/stream_executor:platform_manager", "//xla/tests:llvm_irgen_test_base", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_for_library", ], ) diff --git a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.cc index 6525d792775df8..4f31f91069fe01 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include #include "absl/status/statusor.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" @@ -29,6 +30,7 @@ limitations under the License. #include "xla/service/gpu/gpu_executable.h" #include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -49,9 +51,11 @@ GpuCodegenTest::CreateNewVerifiedModuleWithFTZ(bool ftz) { } void GpuCodegenTest::CompileAndOptionallyVerifyPtx( - std::unique_ptr hlo_module, absl::string_view pattern) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, - CompileToExecutable(std::move(hlo_module))); + std::unique_ptr hlo_module, absl::string_view pattern, + bool run_optimization_passes) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + CompileToExecutable(std::move(hlo_module), run_optimization_passes)); std::string ptx_str(static_cast(executable.get())->text()); // On the ROCM platform the "ptx" string is not populated for the compiled diff --git a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h index d77a4463055fa5..6fe9ac6ddca2ae 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h +++ b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h @@ -48,7 +48,8 @@ class GpuCodegenTest : public LlvmIrGenTestBase { // and hence the "Optionally" in function name. // For ROCm platform this routine will only do the "Compile" part. void CompileAndOptionallyVerifyPtx( - std::unique_ptr hlo_module, absl::string_view pattern); + std::unique_ptr hlo_module, absl::string_view pattern, + bool run_optimization_passes = true); bool is_built_with_rocm_; }; From 61cb6ff0444d619789ee1ff3b04fd6773d7f5f85 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 3 Apr 2025 01:53:54 -0700 Subject: [PATCH 0191/1324] [xla:cpu] Update XLA's XNNPACK and pthreadpool commits This is to use new features such as bf16 batch matrix multiplication and weak pthreadpool symbols. 
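As a rough illustration of what the weak pthreadpool symbols enable — this is
a hypothetical sketch, not code from this change, and `PthreadpoolIsLinked` is
a made-up name — a binary can declare the pthreadpool entry point with
GCC/Clang-style weak linkage and probe for it at runtime instead of failing at
link time:

  #include <cstddef>

  // Declaration mirroring pthreadpool's public creation API; marked weak so
  // the symbol resolves to null when no pthreadpool is linked in.
  typedef struct pthreadpool* pthreadpool_t;
  extern "C" pthreadpool_t pthreadpool_create(std::size_t threads_count)
      __attribute__((weak));

  // Hypothetical helper: true only when a real pthreadpool implementation
  // is present in the binary.
  bool PthreadpoolIsLinked() { return &pthreadpool_create != nullptr; }

This is the kind of optional-dependency setup that weak symbols make possible.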
PiperOrigin-RevId: 743465001 --- tensorflow/lite/cmake/DownloadPThreadPool.cmake | 4 ++-- tensorflow/workspace2.bzl | 6 +++--- third_party/xla/workspace2.bzl | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/cmake/DownloadPThreadPool.cmake b/tensorflow/lite/cmake/DownloadPThreadPool.cmake index e12799e3231a31..a38d1b319d8bc7 100644 --- a/tensorflow/lite/cmake/DownloadPThreadPool.cmake +++ b/tensorflow/lite/cmake/DownloadPThreadPool.cmake @@ -19,8 +19,8 @@ PROJECT(pthreadpool-download NONE) INCLUDE(ExternalProject) ExternalProject_Add(pthreadpool - URL https://github.com/google/pthreadpool/archive/b92447772365661680f486e39a91dfe6675adafc.zip - URL_HASH SHA256=745e56516d6a58d183eb33d9017732d87cff43ce9f78908906f9faa52633e421 + URL https://github.com/google/pthreadpool/archive/706a8ea9e4b8c2129718af195ddce7fc2573e719.zip + URL_HASH SHA256=2d56c31ebf6509d171d12ace2b543f6182ff0083ba674541515fc573738a3238 SOURCE_DIR "${CMAKE_BINARY_DIR}/pthreadpool-source" BINARY_DIR "${CMAKE_BINARY_DIR}/pthreadpool" CONFIGURE_COMMAND "" diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index bd468e6428be15..82491ab8b087df 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -184,9 +184,9 @@ def _tf_repositories(): # LINT.IfChange(pthreadpool) tf_http_archive( name = "pthreadpool", - sha256 = "745e56516d6a58d183eb33d9017732d87cff43ce9f78908906f9faa52633e421", - strip_prefix = "pthreadpool-b92447772365661680f486e39a91dfe6675adafc", - urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/b92447772365661680f486e39a91dfe6675adafc.zip"), + sha256 = "2d56c31ebf6509d171d12ace2b543f6182ff0083ba674541515fc573738a3238", + strip_prefix = "pthreadpool-706a8ea9e4b8c2129718af195ddce7fc2573e719", + urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/706a8ea9e4b8c2129718af195ddce7fc2573e719.zip"), ) # LINT.ThenChange(//tensorflow/lite/cmake/DownloadPThreadPool.cmake) diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index 7b29f1170df463..7fde4ef61d29c8 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -45,9 +45,9 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "9e290e7b094134bdda0cad4ef4b89625fbde3c4b8e8f5dc84044c0f2e55b875a", - strip_prefix = "XNNPACK-5b4978cae19292232a27bdf0f495819bf5297167", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/5b4978cae19292232a27bdf0f495819bf5297167.zip"), + sha256 = "72e4368ff3e7bdefd8b43fc6e5708b8e9fada7a8302ba2362028832df6262c13", + strip_prefix = "XNNPACK-e67c0fbc360903f921ff286a235c18d9e12c6df6", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/e67c0fbc360903f921ff286a235c18d9e12c6df6.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) @@ -74,9 +74,9 @@ def _tf_repositories(): tf_http_archive( name = "pthreadpool", - sha256 = "215724985c4845cdcadcb5f26a2a8777943927bb5a172a00e7716fe16a6f3c1b", - strip_prefix = "pthreadpool-b1aee199d54003fb557076a201bcac3398af580b", - urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/b1aee199d54003fb557076a201bcac3398af580b.zip"), + sha256 = "2d56c31ebf6509d171d12ace2b543f6182ff0083ba674541515fc573738a3238", + strip_prefix = "pthreadpool-706a8ea9e4b8c2129718af195ddce7fc2573e719", + urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/706a8ea9e4b8c2129718af195ddce7fc2573e719.zip"), ) tf_http_archive( From 8ed060f0c234f4c069cde7fc7080f166c34c410a 
Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 02:03:16 -0700 Subject: [PATCH 0192/1324] compat: Update forward compatibility horizon to 2025-04-03 PiperOrigin-RevId: 743467814 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index a88e5ef33cfb53..6f4f7243d7f4d1 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 2) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 3) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 35507883d5ed1fb10073a0557a35b52eccf989af Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 02:03:22 -0700 Subject: [PATCH 0193/1324] Update GraphDef version to 2186. PiperOrigin-RevId: 743467850 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 10d9b03823b4dd..73057f3f941700 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2185 // Updated: 2025/4/2 +#define TF_GRAPH_DEF_VERSION 2186 // Updated: 2025/4/3 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From ecf6833a81a185331b8e3cfcc4487ef98a46aa49 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 3 Apr 2025 04:21:42 -0700 Subject: [PATCH 0194/1324] [XLA:GPU] Make `NestGemmFusion` hoist `concatenate`s. Allows enabling one more test in `fusion_emitter_device_legacy_port_test.cc`. 
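Schematically, for a GEMM fusion whose dot consumes a concatenate (HLO sketch
with shapes borrowed from the test added below):

    concatenate = f32[2,10,384] concatenate(p1, p2), dimensions={2}
    ROOT dot = f32[2,3,384] dot(p0, concatenate), ...

the pass now wraps each concatenate operand in a nested fusion annotated with
`kTritonNestedGemmFusionKind` before nesting the dot operands, so after the
rewrite the dot's right-hand side is a nested fusion whose root is a
concatenate of nested fusions — the structure the new test asserts with
`match::Concatenate(match::Fusion(), match::Fusion())`.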
PiperOrigin-RevId: 743505373 --- .../fusion_emitter_device_legacy_port_test.cc | 25 ++++----- .../gpu/transforms/nest_gemm_fusion.cc | 28 ++++++++++ .../gpu/transforms/nest_gemm_fusion_test.cc | 54 +++++++++++++++---- 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index 0bc9bded9a705c..13398805cf3e43 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -260,8 +260,7 @@ ENTRY e { } } } -} -)"; +})"; TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, GetModuleAndNestedFusionMetadata(kHloText)); TF_EXPECT_OK( @@ -274,10 +273,9 @@ CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1> )")); } -TEST_F(TritonTest, - DISABLED_CodegenBatchedDotWithConcatenationWithCorrectBatchStride) { +TEST_F(TritonTest, CodegenBatchedDotWithConcatenationWithCorrectBatchStride) { constexpr absl::string_view kHloText = R"( -HloModule t, is_scheduled=true +HloModule t triton_gemm { parameter_0 = f32[2,3,10]{2,1,0} parameter(0) @@ -301,18 +299,13 @@ ENTRY e { "num_ctas":1}}} })"; + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloText)); TF_EXPECT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm", R"( -CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr -CHECK-SAME: %[[P1:[^:]*]]: !tt.ptr -CHECK-SAME: %[[P2:[^:]*]]: !tt.ptr -CHECK-DAG: %[[ARG_PTR:.*]] = arith.select %[[CONCAT_COND:.*]], %[[P1]], %[[P2]] -CHECK-DAG: %[[BATCH_STRIDE_P1:.*]] = arith.constant 1280 -CHECK-DAG: %[[BATCH_STRIDE_P2:.*]] = arith.constant 2560 -CHECK-DAG: %[[BATCH_STRIDE:.*]] = arith.select %[[CONCAT_COND_2:.*]], %[[BATCH_STRIDE_P1]], %[[BATCH_STRIDE_P2]] -CHECK-DAG: %[[PID_BATCH:.*]] = tt.get_program_id y -CHECK-DAG: %[[OFFSET:.*]] = arith.muli %[[PID_BATCH]], %[[BATCH_STRIDE]] -CHECK: %[[BLOCK_BASE_PTR:.*]] = tt.addptr %[[ARG_PTR]], %[[OFFSET]] + CreateTritonIrAndFileCheck(*module_and_metadata.computation, + module_and_metadata.block_level_parameters, R"( +CHECK: scf.if {{.*}} -> (tensor<1x32x64xf32>) +CHECK: tt.dot {{.*}} : tensor<16x32xf32> * tensor<32x64xf32> -> tensor<16x64xf32> )")); } diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index 3b8c0eb78b93fe..792d38c8a62384 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -294,6 +294,30 @@ absl::StatusOr GetTritonGemmConfig( return TritonGemmConfig::FromProto(backend_config.triton_gemm_config()); } +// Constructs nested fusion nodes for the operands of `concatenate` instructions +// and annotates them with `kTritonNestedGemmFusionKind`. 
+absl::Status FuseAndAnnotateConcatOperands(HloComputation* computation) { + for (HloInstruction* instr : computation->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kConcatenate) { + for (int i = 0; i < instr->operand_count(); ++i) { + TF_RETURN_IF_ERROR(FuseInstructionsForConsumer( + computation->MakeInstructionPostOrderFrom( + *instr->mutable_operand(i)), + *instr)); + HloInstruction* new_operand = instr->mutable_operand(i); + TF_ASSIGN_OR_RETURN(auto gpu_config, + new_operand->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.clear_triton_gemm_config(); + backend_config.set_kind(std::string(kTritonNestedGemmFusionKind)); + TF_RETURN_IF_ERROR(new_operand->set_backend_config(gpu_config)); + } + } + } + return absl::OkStatus(); +} + // Transforms a fusion into an equivalent nested fusion if it has a single dot. // Returns ok if the transformation was successful. absl::Status MakeNestedFusionFromGemmFusion(HloFusionInstruction* fusion, @@ -304,6 +328,10 @@ absl::Status MakeNestedFusionFromGemmFusion(HloFusionInstruction* fusion, HloComputation* computation = fusion->called_computation(); + // First, create nested fusions for the operands of `concatenate` instructions + // if they exist. + TF_RETURN_IF_ERROR(FuseAndAnnotateConcatOperands(computation)); + // Left-hand side of the dot. TF_RETURN_IF_ERROR(FuseInstructionsForConsumer( computation->MakeInstructionPostOrderFrom(*dot->mutable_operand(0)), diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc index bf42af6962ef93..baffca3da9b6cd 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc @@ -272,8 +272,7 @@ ENTRY entry { } } } -} -)"; +})"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), IsOkAndHolds(true)); @@ -308,8 +307,7 @@ ENTRY entry { } } } -} -)"; +})"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), IsOkAndHolds(true)); @@ -341,8 +339,7 @@ ENTRY entry { } } } -} -)"; +})"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), IsOkAndHolds(true)); @@ -376,8 +373,7 @@ ENTRY entry_computation { } } } -} -)"; +})"; // Note: block sizes were 16,16,32, but that now fails to satisfy constraints. 
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), @@ -410,12 +406,52 @@ ENTRY entry_computation { } } } +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), + IsOkAndHolds(true)); + TF_ASSERT_OK(verifier().Run(module.get()).status()); } -)"; + +TEST_F(NestGemmFusionTest, ConcatenationsAreHoistedWithinNestedGemmFusions) { + absl::string_view hlo = R"( +HloModule t + +triton_gemm { + parameter_0 = f32[2,3,10]{2,1,0} parameter(0) + parameter_1 = f32[2,10,128]{2,1,0} parameter(1) + parameter_2 = f32[2,10,256]{2,1,0} parameter(2) + concatenate = f32[2,10,384]{2,1,0} concatenate(parameter_1, parameter_2), dimensions={2} + ROOT dot = f32[2,3,384]{2,1,0} dot(parameter_0, concatenate), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} +} + +ENTRY e { + parameter_0 = f32[2,3,10]{2,1,0} parameter(0) + parameter_1 = f32[2,10,128]{2,1,0} parameter(1) + parameter_2 = f32[2,10,256]{2,1,0} parameter(2) + ROOT dot = f32[2,3,384]{2,1,0} fusion(parameter_0, parameter_1, parameter_2), + kind=kCustom, calls=triton_gemm, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} +})"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); EXPECT_THAT(NestGemmFusion(compute_capability_).Run(module.get()), IsOkAndHolds(true)); TF_ASSERT_OK(verifier().Run(module.get()).status()); + HloComputation* fusion_computation = module->entry_computation() + ->root_instruction() + ->fused_instructions_computation(); + HloInstruction* dot_lhs; + HloInstruction* dot_rhs; + EXPECT_THAT( + fusion_computation->root_instruction(), + GmockMatch(match::Dot(match::Fusion(&dot_lhs), match::Fusion(&dot_rhs)))); + EXPECT_THAT(dot_rhs->fused_instructions_computation()->root_instruction(), + GmockMatch(match::Concatenate(match::Fusion(), match::Fusion()))); } TEST_F(NestGemmFusionTest, UnsupportedComputationsAreRejected) { From 06270d1cb6ec2550273f25336003011edeffe880 Mon Sep 17 00:00:00 2001 From: Venkat6871 Date: Thu, 3 Apr 2025 17:54:10 +0530 Subject: [PATCH 0195/1324] Fix typos in documentation strings --- tensorflow/python/keras/engine/base_layer_utils.py | 2 +- tensorflow/python/keras/engine/compile_utils.py | 2 +- tensorflow/python/keras/engine/data_adapter.py | 2 +- tensorflow/python/keras/engine/sequential.py | 2 +- tensorflow/python/keras/engine/training_utils.py | 2 +- tensorflow/python/keras/layers/core.py | 4 ++-- tensorflow/python/keras/models.py | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index dfa0f207a01fde..bf27b391815439 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -877,7 +877,7 @@ def no_ragged_support(inputs, layer_name): def is_split_variable(v): - """Returns True if `v` is either a PartionedVariable or a ShardedVariable.""" + """Returns True if `v` is either a PartitionedVariable or a ShardedVariable.""" return hasattr(v, '_variable_list') or hasattr(v, '_variables') diff --git a/tensorflow/python/keras/engine/compile_utils.py b/tensorflow/python/keras/engine/compile_utils.py index 81f202d4d8b043..05ef59d5317652 100644 --- 
a/tensorflow/python/keras/engine/compile_utils.py +++ b/tensorflow/python/keras/engine/compile_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilites for `Model.compile`.""" +"""Utilities for `Model.compile`.""" import copy diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index f9db6cfe4a0cea..aec3e81cb0a887 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -1619,7 +1619,7 @@ def pack_x_y_sample_weight(x, y=None, sample_weight=None): # For single x-input, we do no tuple wrapping since in this case # there is no ambiguity. This also makes NumPy and Dataset # consistent in that the user does not have to wrap their Dataset - # data in an unecessary tuple + # data in an unnecessary tuple if not nest.is_nested(x): return x else: diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index 0f46e17d37837d..5ea6306f31bbb4 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -358,7 +358,7 @@ def call(self, inputs, training=None, mask=None): # pylint: disable=redefined-o if not self._has_explicit_input_shape: if not tensor_util.is_tf_type(inputs) and not isinstance( inputs, np_arrays.ndarray): - # This is a Sequential with mutiple inputs. This is technically an + # This is a Sequential with multiple inputs. This is technically an # invalid use case of Sequential, but we tolerate it for backwards # compatibility. self._use_legacy_deferred_behavior = True diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index ca70099e066fd2..be9c360264fc4a 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -125,7 +125,7 @@ class RespectCompiledTrainableState(object): at `Model.compile` time will be used when training that model. In order to respect this requirement, it may be necessary to set the trainable value of layers to their compile time values before beginning a training endpoint and - restore the values before returing from said endpoint. This scope checks if + restore the values before returning from said endpoint. This scope checks if any layer's trainable state has changed since Model compile, and performs this set and un-set bookkeeping. diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index fe1a5022c5296f..fb5ec2792fde13 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -922,7 +922,7 @@ def _check_variables(self, created_variables, accessed_variables): but are not tracked by said layer: {variable_str} The layer cannot safely ensure proper Variable reuse across multiple - calls, and consquently this behavior is disallowed for safety. Lambda + calls, and consequently this behavior is disallowed for safety. 
Lambda layers are not well suited to stateful computation; instead, writing a subclassed Layer is the recommend way to define layers with Variables.''' @@ -1399,7 +1399,7 @@ def _check_variables(self, created_variables, accessed_variables): but are not tracked by said layer: {variable_str} The layer cannot safely ensure proper Variable reuse across multiple - calls, and consquently this behavior is disallowed for safety. Lambda + calls, and consequently this behavior is disallowed for safety. Lambda layers are not well suited to stateful computation; instead, writing a subclassed Layer is the recommend way to define layers with Variables.''' diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index b5eaccc3c579bd..4f14e1a870026a 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -451,7 +451,7 @@ def clone_model(model, input_tensors=None, clone_function=None): model, input_tensors=input_tensors, layer_fn=clone_function) -# "Clone" a subclassed model by reseting all of the attributes. +# "Clone" a subclassed model by resetting all of the attributes. def _in_place_subclassed_model_reset(model): """Substitute for model cloning that works for subclassed models. From ddd30fd34ef109088491030be909bec224d4243d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 06:24:40 -0700 Subject: [PATCH 0196/1324] Replace outdated select() on --cpu in tensorflow/lite/kernels/internal/BUILD with platform API equivalent. PiperOrigin-RevId: 743537069 --- tensorflow/lite/kernels/internal/BUILD | 154 ++----------------------- 1 file changed, 11 insertions(+), 143 deletions(-) diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 4be963cc8e3607..f21180e89ee6b8 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -20,18 +20,11 @@ HARD_FP_FLAGS_IF_APPLICABLE = select({ }) NEON_FLAGS_IF_APPLICABLE = select({ - ":arm": [ - "-O3", - ], - ":armeabi-v7a": [ - "-O3", - "-mfpu=neon", - ], ":armhf": [ "-O3", "-mfpu=neon", ], - ":armv7a": [ + ":armv7": [ "-O3", "-mfpu=neon", ], @@ -107,147 +100,25 @@ config_setting( ], ) -config_setting( - name = "arm", - values = { - "cpu": "arm", - }, -) - -config_setting( - name = "arm64-v8a", - values = { - "cpu": "arm64-v8a", - }, -) - config_setting( name = "armhf", - values = { - "cpu": "armhf", - }, -) - -config_setting( - name = "armv7a", - values = { - "cpu": "armv7a", - }, -) - -config_setting( - name = "armeabi-v7a", - values = { - "cpu": "armeabi-v7a", - }, -) - -config_setting( - name = "haswell", - values = { - "cpu": "haswell", - }, -) - -config_setting( - name = "ios_armv7", - values = { - "cpu": "ios_armv7", - }, -) - -config_setting( - name = "ios_arm64", - values = { - "cpu": "ios_arm64", - }, -) - -config_setting( - name = "ios_arm64e", - values = { - "cpu": "ios_arm64e", - }, -) - -config_setting( - name = "ios_sim_arm64", - values = { - "cpu": "ios_sim_arm64", - }, -) - -config_setting( - name = "visionos_arm64", - values = { - "cpu": "visionos_arm64", - }, -) - -config_setting( - name = "visionos_sim_arm64", - values = { - "cpu": "visionos_sim_arm64", - }, -) - -config_setting( - name = "k8", - values = { - "cpu": "k8", - }, + constraint_values = [ + "@platforms//cpu:armv7e-mf", + ], ) config_setting( name = "x86", - values = { - "cpu": "x86", - }, + constraint_values = [ + "@platforms//cpu:x86_32", + ], ) config_setting( name = "x86_64", - values = { - "cpu": "x86_64", - }, -) - 
-config_setting( - name = "darwin", - values = { - "cpu": "darwin", - }, -) - -config_setting( - name = "darwin_arm64", - values = { - "cpu": "darwin_arm64", - }, -) - -config_setting( - name = "freebsd", - values = { - "cpu": "freebsd", - }, -) - -config_setting( - name = "windows", - values = { - "cpu": "x64_windows", - }, -) - -config_setting( - name = "raspberry_pi_with_neon", - define_values = { - "raspberry_pi_with_neon": "true", - }, - values = { - "cpu": "armeabi", - }, + constraint_values = [ + "@platforms//cpu:x86_64", + ], ) selects.config_setting_group( @@ -1450,10 +1321,7 @@ cc_test( srcs = ["optimized/avx2_quantization_utils_test.cc"], copts = select( { - ":haswell": [ - "-mavx2", - ], - ":k8": [ + ":x86_64": [ "-mavx2", ], "//conditions:default": [ From 3519332585b4f9724a244a368dd821d11c484764 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 3 Apr 2025 06:25:36 -0700 Subject: [PATCH 0197/1324] Integrate Triton up to [4e364a7](https://github.com/openai/triton/commits/4e364a7871231b5df903bc85ce4ed4256118605e) PiperOrigin-RevId: 743537344 --- .../triton/llvm_integration/cl734808760.patch | 33 -------------- .../triton/llvm_integration/cl737995800.patch | 33 -------------- .../triton/llvm_integration/cl740926882.patch | 45 ------------------- .../triton/llvm_integration/cl742325920.patch | 3 +- .../triton/llvm_integration/series.bzl | 1 - .../no_accelerate_through_broadcast.patch | 17 ------- third_party/triton/temporary/ptxas_12_4.patch | 15 +++++++ third_party/triton/temporary/series.bzl | 2 +- third_party/triton/workspace.bzl | 8 ++-- .../triton/llvm_integration/cl734808760.patch | 33 -------------- .../triton/llvm_integration/cl737995800.patch | 33 -------------- .../triton/llvm_integration/cl740926882.patch | 45 ------------------- .../triton/llvm_integration/cl742325920.patch | 3 +- .../triton/llvm_integration/series.bzl | 1 - .../no_accelerate_through_broadcast.patch | 17 ------- .../triton/temporary/ptxas_12_4.patch | 15 +++++++ .../third_party/triton/temporary/series.bzl | 2 +- .../xla/third_party/triton/workspace.bzl | 8 ++-- .../triton/compilation_pipeline_cuda.cc | 4 +- .../triton/compilation_pipeline_rocm.cc | 18 ++++++-- .../backends/gpu/codegen/triton/tma_utils.h | 2 +- .../triton_xla_extract_insert_to_triton.mlir | 4 +- ...riton_xla_extract_insert_to_triton_pass.cc | 4 +- .../service/gpu/autotuning/autotuner_util.h | 2 +- 24 files changed, 68 insertions(+), 280 deletions(-) delete mode 100644 third_party/triton/llvm_integration/cl734808760.patch delete mode 100644 third_party/triton/llvm_integration/cl737995800.patch delete mode 100644 third_party/triton/llvm_integration/cl740926882.patch delete mode 100644 third_party/triton/temporary/no_accelerate_through_broadcast.patch create mode 100644 third_party/triton/temporary/ptxas_12_4.patch delete mode 100644 third_party/xla/third_party/triton/llvm_integration/cl734808760.patch delete mode 100644 third_party/xla/third_party/triton/llvm_integration/cl737995800.patch delete mode 100644 third_party/xla/third_party/triton/llvm_integration/cl740926882.patch delete mode 100644 third_party/xla/third_party/triton/temporary/no_accelerate_through_broadcast.patch create mode 100644 third_party/xla/third_party/triton/temporary/ptxas_12_4.patch diff --git a/third_party/triton/llvm_integration/cl734808760.patch b/third_party/triton/llvm_integration/cl734808760.patch deleted file mode 100644 index cd9275bedeed20..00000000000000 --- a/third_party/triton/llvm_integration/cl734808760.patch +++ /dev/null @@ -1,33 +0,0 @@ - ---- 
a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp 2025-02-03 07:46:30.000000000 -0800 -+++ b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp 2025-03-07 23:43:21.000000000 -0800 -@@ -38,10 +38,11 @@ - - // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) - struct MoveSplatAfterElementwisePattern -- : public OpTraitRewritePattern { -+ : public OpTraitRewritePattern::SplitMatchAndRewrite { - - MoveSplatAfterElementwisePattern(MLIRContext *context) -- : OpTraitRewritePattern(context) {} -+ : SplitMatchAndRewrite(Pattern::MatchTraitOpTypeTag(), -+ TypeID::get(), 1, context) {} - - LogicalResult match(Operation *op) const override { - if (!isMemoryEffectFree(op)) { -@@ -103,10 +104,13 @@ - // This also generalizes to multiple arguments when the rest are splat-like - // Not handled: multiple broadcasted arguments - struct MoveBroadcastAfterElementwisePattern -- : public OpTraitRewritePattern { -+ : public OpTraitRewritePattern::SplitMatchAndRewrite { -+ -+ using SplitMatchAndRewrite::SplitMatchAndRewrite; - - MoveBroadcastAfterElementwisePattern(MLIRContext *context) -- : OpTraitRewritePattern(context) {} -+ : SplitMatchAndRewrite(Pattern::MatchTraitOpTypeTag(), -+ TypeID::get(), 1, context) {} - - LogicalResult match(Operation *op) const override { - if (!isMemoryEffectFree(op)) { diff --git a/third_party/triton/llvm_integration/cl737995800.patch b/third_party/triton/llvm_integration/cl737995800.patch deleted file mode 100644 index 5add139f8c0f41..00000000000000 --- a/third_party/triton/llvm_integration/cl737995800.patch +++ /dev/null @@ -1,33 +0,0 @@ - ---- a/bin/triton-llvm-opt.cpp 2024-05-02 01:41:56.000000000 -0700 -+++ b/bin/triton-llvm-opt.cpp 2025-03-18 07:35:39.000000000 -0700 -@@ -91,7 +91,7 @@ - } - // If we are supposed to override the target triple or data layout, do so now. - if (!TargetTriple.empty()) -- M->setTargetTriple(Triple::normalize(TargetTriple)); -+ M->setTargetTriple(llvm::Triple(Triple::normalize(TargetTriple))); - auto optPipeline = makeOptimizingPipeline(); - if (auto err = optPipeline(M.get())) { - llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; - ---- a/python/src/llvm.cc 2025-01-21 05:40:49.000000000 -0800 -+++ b/python/src/llvm.cc 2025-03-18 07:35:39.000000000 -0700 -@@ -59,7 +59,7 @@ - opt.MCOptions.AsmVerbose = true; - opt.MCOptions.PreserveAsmComments = true; - std::unique_ptr machine{target->createTargetMachine( -- module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, -+ module->getTargetTriple().str(), proc, features, opt, llvm::Reloc::PIC_, - std::nullopt, - disableLLVMOpt ? 
llvm::CodeGenOptLevel::None - : llvm::CodeGenOptLevel::Aggressive)}; -@@ -132,7 +132,7 @@ - // module->print(llvm::outs(), nullptr); - - // create machine -- module.setTargetTriple(triple); -+ module.setTargetTriple(llvm::Triple(triple)); - auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); - // set data layout - module.setDataLayout(machine->createDataLayout()); diff --git a/third_party/triton/llvm_integration/cl740926882.patch b/third_party/triton/llvm_integration/cl740926882.patch deleted file mode 100644 index 24796e97184c5d..00000000000000 --- a/third_party/triton/llvm_integration/cl740926882.patch +++ /dev/null @@ -1,45 +0,0 @@ - ---- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp 2025-03-25 07:48:50.000000000 -0700 -+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp 2025-03-26 15:43:35.000000000 -0700 -@@ -298,16 +298,16 @@ - ConversionPatternRewriter &rewriter, - const SmallVector &v) { - assert(v.size() == 2); -- return cvtScaleFp8ToFp32(loc, rewriter, v[0], -- v[1]); -+ return cvtScaleFp8ToFp32(loc, rewriter, v[0], -+ v[1]); - } - - static SmallVector Fp8E5M2_to_Fp32(Location loc, - ConversionPatternRewriter &rewriter, - const SmallVector &v) { - assert(v.size() == 2); -- return cvtScaleFp8ToFp32(loc, rewriter, v[0], -- v[1]); -+ return cvtScaleFp8ToFp32(loc, rewriter, v[0], -+ v[1]); - } - - template -@@ -336,16 +336,16 @@ - ConversionPatternRewriter &rewriter, - const SmallVector &v) { - assert(v.size() == 2); -- return cvtScaleFp32ToFp8(loc, rewriter, v[0], -- v[1]); -+ return cvtScaleFp32ToFp8(loc, rewriter, v[0], -+ v[1]); - } - - static SmallVector Fp32_to_Fp8E5M2(Location loc, - ConversionPatternRewriter &rewriter, - const SmallVector &v) { - assert(v.size() == 2); -- return cvtScaleFp32ToFp8(loc, rewriter, v[0], -- v[1]); -+ return cvtScaleFp32ToFp8(loc, rewriter, v[0], -+ v[1]); - } - - static SmallVector diff --git a/third_party/triton/llvm_integration/cl742325920.patch b/third_party/triton/llvm_integration/cl742325920.patch index 3a391a40c30b1a..8a09ff47870272 100644 --- a/third_party/triton/llvm_integration/cl742325920.patch +++ b/third_party/triton/llvm_integration/cl742325920.patch @@ -42,7 +42,7 @@ --- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp 2025-03-25 07:48:50.000000000 -0700 +++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp 2025-03-31 11:20:16.000000000 -0700 -@@ -34,6 +34,13 @@ +@@ -34,7 +34,14 @@ return signalPassFailure(); auto nonNegativePred = [&solver](Value v) -> bool { @@ -55,4 +55,5 @@ + } return succeeded(dataflow::staticallyNonNegative(*solver, v)); }; + mod->walk([&solver, nonNegativePred](Operation *op) { diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 64e38dfaa30e75..f95ab09fd5f2cb 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,7 +8,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. 
""" llvm_patch_list = [ - "//third_party/triton:llvm_integration/cl740926882.patch", "//third_party/triton:llvm_integration/cl741558316.patch", "//third_party/triton:llvm_integration/cl742325920.patch", # Add new patches just above this line diff --git a/third_party/triton/temporary/no_accelerate_through_broadcast.patch b/third_party/triton/temporary/no_accelerate_through_broadcast.patch deleted file mode 100644 index 719eec4bc76555..00000000000000 --- a/third_party/triton/temporary/no_accelerate_through_broadcast.patch +++ /dev/null @@ -1,17 +0,0 @@ - ---- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2025-03-20 06:02:06.000000000 -0700 -+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2025-03-24 11:10:52.000000000 -0700 -@@ -246,9 +246,12 @@ - return false; - } - -+ // b/405045790: We don't want to propagate through the BroadcastOp because we -+ // probably don't care about the load before a broadcast as it would likely be -+ // small. This is just a heuristic to avoid a regression. - return (op->hasTrait() && isMemoryEffectFree(op)) || - isView(op) || -- isa(op); - } - diff --git a/third_party/triton/temporary/ptxas_12_4.patch b/third_party/triton/temporary/ptxas_12_4.patch new file mode 100644 index 00000000000000..fcc71519083af1 --- /dev/null +++ b/third_party/triton/temporary/ptxas_12_4.patch @@ -0,0 +1,15 @@ +This can be removed as soon as we updated ptxas to 12.8 (b/385480934). + +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2025-03-26 00:22:57.000000000 -0700 ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2025-04-02 02:43:36.000000000 -0700 +@@ -180,6 +180,10 @@ + // warpgroup. + Value warp = + b.and_(rewriter.create(loc), b.i32_val(0xFFFFFFFC)); ++ // Workaround for a bug in ptxas 12.3 that cause a failure in ++ // test_core.py::test_dot. The shuffle will force the compiler to treat the ++ // value as uniform and prevent wrong optimizations. ++ warp = mlir::LLVM::NVIDIA::shuffleIdx(loc, rewriter, warp, 0); + Value warpM = b.urem(warp, b.i32_val(wpt[0])); + Value warpId = b.urem(warpM, b.i32_val(shapePerCTA[0] / instrShape[0])); + diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 6c4e77dff5e8a9..f38405c475ed6c 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -14,6 +14,6 @@ those to this list. 
""" temporary_patch_list = [ - "//third_party/triton:temporary/no_accelerate_through_broadcast.patch", + "//third_party/triton:temporary/ptxas_12_4.patch", # Add new patches just above this line ] diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index ec40f3de7a37ca..656679a89df22a 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,12 +8,12 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "triton_integrate_branch/1.1" - TRITON_SHA256 = "66666f46227b4ab10b6c5ff26bfa57446b0621ef13ebae407861328dd2dfe550" + TRITON_COMMIT = "triton_integrate_branch-1.2" + TRITON_SHA256 = "ba715575f8e8ead49df545a40c9557a4e40174400892fcf28fefdd15ff3f2c6a" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, - strip_prefix = "triton-" + TRITON_COMMIT.replace("/", "-"), - urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{commit}.tar.gz".format(commit = TRITON_COMMIT)), + strip_prefix = "triton-" + TRITON_COMMIT, + urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{}.tar.gz".format(TRITON_COMMIT)), patch_file = extensions_files_patch_list + llvm_patch_list + temporary_patch_list, ) diff --git a/third_party/xla/third_party/triton/llvm_integration/cl734808760.patch b/third_party/xla/third_party/triton/llvm_integration/cl734808760.patch deleted file mode 100644 index cd9275bedeed20..00000000000000 --- a/third_party/xla/third_party/triton/llvm_integration/cl734808760.patch +++ /dev/null @@ -1,33 +0,0 @@ - ---- a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp 2025-02-03 07:46:30.000000000 -0800 -+++ b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp 2025-03-07 23:43:21.000000000 -0800 -@@ -38,10 +38,11 @@ - - // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) - struct MoveSplatAfterElementwisePattern -- : public OpTraitRewritePattern { -+ : public OpTraitRewritePattern::SplitMatchAndRewrite { - - MoveSplatAfterElementwisePattern(MLIRContext *context) -- : OpTraitRewritePattern(context) {} -+ : SplitMatchAndRewrite(Pattern::MatchTraitOpTypeTag(), -+ TypeID::get(), 1, context) {} - - LogicalResult match(Operation *op) const override { - if (!isMemoryEffectFree(op)) { -@@ -103,10 +104,13 @@ - // This also generalizes to multiple arguments when the rest are splat-like - // Not handled: multiple broadcasted arguments - struct MoveBroadcastAfterElementwisePattern -- : public OpTraitRewritePattern { -+ : public OpTraitRewritePattern::SplitMatchAndRewrite { -+ -+ using SplitMatchAndRewrite::SplitMatchAndRewrite; - - MoveBroadcastAfterElementwisePattern(MLIRContext *context) -- : OpTraitRewritePattern(context) {} -+ : SplitMatchAndRewrite(Pattern::MatchTraitOpTypeTag(), -+ TypeID::get(), 1, context) {} - - LogicalResult match(Operation *op) const override { - if (!isMemoryEffectFree(op)) { diff --git a/third_party/xla/third_party/triton/llvm_integration/cl737995800.patch b/third_party/xla/third_party/triton/llvm_integration/cl737995800.patch deleted file mode 100644 index 5add139f8c0f41..00000000000000 --- a/third_party/xla/third_party/triton/llvm_integration/cl737995800.patch +++ /dev/null @@ -1,33 +0,0 @@ - ---- a/bin/triton-llvm-opt.cpp 2024-05-02 01:41:56.000000000 -0700 -+++ b/bin/triton-llvm-opt.cpp 2025-03-18 07:35:39.000000000 -0700 -@@ -91,7 +91,7 @@ - } - // If we are supposed to override the target triple or data layout, do so now. 
- if (!TargetTriple.empty()) -- M->setTargetTriple(Triple::normalize(TargetTriple)); -+ M->setTargetTriple(llvm::Triple(Triple::normalize(TargetTriple))); - auto optPipeline = makeOptimizingPipeline(); - if (auto err = optPipeline(M.get())) { - llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; - ---- a/python/src/llvm.cc 2025-01-21 05:40:49.000000000 -0800 -+++ b/python/src/llvm.cc 2025-03-18 07:35:39.000000000 -0700 -@@ -59,7 +59,7 @@ - opt.MCOptions.AsmVerbose = true; - opt.MCOptions.PreserveAsmComments = true; - std::unique_ptr machine{target->createTargetMachine( -- module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, -+ module->getTargetTriple().str(), proc, features, opt, llvm::Reloc::PIC_, - std::nullopt, - disableLLVMOpt ? llvm::CodeGenOptLevel::None - : llvm::CodeGenOptLevel::Aggressive)}; -@@ -132,7 +132,7 @@ - // module->print(llvm::outs(), nullptr); - - // create machine -- module.setTargetTriple(triple); -+ module.setTargetTriple(llvm::Triple(triple)); - auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); - // set data layout - module.setDataLayout(machine->createDataLayout()); diff --git a/third_party/xla/third_party/triton/llvm_integration/cl740926882.patch b/third_party/xla/third_party/triton/llvm_integration/cl740926882.patch deleted file mode 100644 index 24796e97184c5d..00000000000000 --- a/third_party/xla/third_party/triton/llvm_integration/cl740926882.patch +++ /dev/null @@ -1,45 +0,0 @@ - ---- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp 2025-03-25 07:48:50.000000000 -0700 -+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp 2025-03-26 15:43:35.000000000 -0700 -@@ -298,16 +298,16 @@ - ConversionPatternRewriter &rewriter, - const SmallVector &v) { - assert(v.size() == 2); -- return cvtScaleFp8ToFp32(loc, rewriter, v[0], -- v[1]); -+ return cvtScaleFp8ToFp32(loc, rewriter, v[0], -+ v[1]); - } - - static SmallVector Fp8E5M2_to_Fp32(Location loc, - ConversionPatternRewriter &rewriter, - const SmallVector &v) { - assert(v.size() == 2); -- return cvtScaleFp8ToFp32(loc, rewriter, v[0], -- v[1]); -+ return cvtScaleFp8ToFp32(loc, rewriter, v[0], -+ v[1]); - } - - template -@@ -336,16 +336,16 @@ - ConversionPatternRewriter &rewriter, - const SmallVector &v) { - assert(v.size() == 2); -- return cvtScaleFp32ToFp8(loc, rewriter, v[0], -- v[1]); -+ return cvtScaleFp32ToFp8(loc, rewriter, v[0], -+ v[1]); - } - - static SmallVector Fp32_to_Fp8E5M2(Location loc, - ConversionPatternRewriter &rewriter, - const SmallVector &v) { - assert(v.size() == 2); -- return cvtScaleFp32ToFp8(loc, rewriter, v[0], -- v[1]); -+ return cvtScaleFp32ToFp8(loc, rewriter, v[0], -+ v[1]); - } - - static SmallVector diff --git a/third_party/xla/third_party/triton/llvm_integration/cl742325920.patch b/third_party/xla/third_party/triton/llvm_integration/cl742325920.patch index 3a391a40c30b1a..8a09ff47870272 100644 --- a/third_party/xla/third_party/triton/llvm_integration/cl742325920.patch +++ b/third_party/xla/third_party/triton/llvm_integration/cl742325920.patch @@ -42,7 +42,7 @@ --- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp 2025-03-25 07:48:50.000000000 -0700 +++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp 2025-03-31 11:20:16.000000000 -0700 -@@ -34,6 +34,13 @@ +@@ -34,7 +34,14 @@ return signalPassFailure(); auto nonNegativePred = [&solver](Value v) -> bool { @@ -55,4 +55,5 @@ + } return succeeded(dataflow::staticallyNonNegative(*solver, v)); }; + mod->walk([&solver, 
nonNegativePred](Operation *op) { diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl index 64e38dfaa30e75..f95ab09fd5f2cb 100644 --- a/third_party/xla/third_party/triton/llvm_integration/series.bzl +++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl @@ -8,7 +8,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ - "//third_party/triton:llvm_integration/cl740926882.patch", "//third_party/triton:llvm_integration/cl741558316.patch", "//third_party/triton:llvm_integration/cl742325920.patch", # Add new patches just above this line diff --git a/third_party/xla/third_party/triton/temporary/no_accelerate_through_broadcast.patch b/third_party/xla/third_party/triton/temporary/no_accelerate_through_broadcast.patch deleted file mode 100644 index 719eec4bc76555..00000000000000 --- a/third_party/xla/third_party/triton/temporary/no_accelerate_through_broadcast.patch +++ /dev/null @@ -1,17 +0,0 @@ - ---- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2025-03-20 06:02:06.000000000 -0700 -+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2025-03-24 11:10:52.000000000 -0700 -@@ -246,9 +246,12 @@ - return false; - } - -+ // b/405045790: We don't want to propagate through the BroadcastOp because we -+ // probably don't care about the load before a broadcast as it would likely be -+ // small. This is just a heuristic to avoid a regression. - return (op->hasTrait() && isMemoryEffectFree(op)) || - isView(op) || -- isa(op); - } - diff --git a/third_party/xla/third_party/triton/temporary/ptxas_12_4.patch b/third_party/xla/third_party/triton/temporary/ptxas_12_4.patch new file mode 100644 index 00000000000000..fcc71519083af1 --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/ptxas_12_4.patch @@ -0,0 +1,15 @@ +This can be removed as soon as we updated ptxas to 12.8 (b/385480934). + +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2025-03-26 00:22:57.000000000 -0700 ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2025-04-02 02:43:36.000000000 -0700 +@@ -180,6 +180,10 @@ + // warpgroup. + Value warp = + b.and_(rewriter.create(loc), b.i32_val(0xFFFFFFFC)); ++ // Workaround for a bug in ptxas 12.3 that cause a failure in ++ // test_core.py::test_dot. The shuffle will force the compiler to treat the ++ // value as uniform and prevent wrong optimizations. ++ warp = mlir::LLVM::NVIDIA::shuffleIdx(loc, rewriter, warp, 0); + Value warpM = b.urem(warp, b.i32_val(wpt[0])); + Value warpId = b.urem(warpM, b.i32_val(shapePerCTA[0] / instrShape[0])); + diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl index 6c4e77dff5e8a9..f38405c475ed6c 100644 --- a/third_party/xla/third_party/triton/temporary/series.bzl +++ b/third_party/xla/third_party/triton/temporary/series.bzl @@ -14,6 +14,6 @@ those to this list. 
""" temporary_patch_list = [ - "//third_party/triton:temporary/no_accelerate_through_broadcast.patch", + "//third_party/triton:temporary/ptxas_12_4.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index ec40f3de7a37ca..656679a89df22a 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -8,12 +8,12 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "triton_integrate_branch/1.1" - TRITON_SHA256 = "66666f46227b4ab10b6c5ff26bfa57446b0621ef13ebae407861328dd2dfe550" + TRITON_COMMIT = "triton_integrate_branch-1.2" + TRITON_SHA256 = "ba715575f8e8ead49df545a40c9557a4e40174400892fcf28fefdd15ff3f2c6a" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, - strip_prefix = "triton-" + TRITON_COMMIT.replace("/", "-"), - urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{commit}.tar.gz".format(commit = TRITON_COMMIT)), + strip_prefix = "triton-" + TRITON_COMMIT, + urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{}.tar.gz".format(TRITON_COMMIT)), patch_file = extensions_files_patch_list + llvm_patch_list + temporary_patch_list, ) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc index 226cfd8d44783a..835b25a5fef1de 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc @@ -85,6 +85,7 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm->addPass( mt::gpu::createTritonGPUOptimizeDotOperands({cc.IsAtLeastAmpere()})); + pm->addPass(mlir::createTritonNvidiaGPUOptimizeDescriptorEncodingPass()); pm->addPass(mlir::createCSEPass()); if (cc.IsAtLeastBlackwell()) { @@ -92,6 +93,8 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mlir::createLoopInvariantCodeMotionPass()); pm->addPass(mt::gpu::createTritonGPUOptimizeAccumulatorInit()); + pm->addPass( + mt::gpu::createTritonGPUAutomaticWarpSpecialization({num_stages})); pm->addPass(mt::gpu::createTritonGPUPipeline({num_stages})); pm->addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf()); pm->addPass(mlir::createTritonNvidiaGPUPromoteLHSToTMemPass()); @@ -103,7 +106,6 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, pm->addPass(mt::gpu::createTritonGPUFuseNestedLoops()); pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mlir::createLoopInvariantCodeMotionPass()); - pm->addPass(mt::gpu::createTritonGPUOptimizeAccumulatorInit()); pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf()); pm->addPass(mt::gpu::createTritonGPUPipeline({num_stages})); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc index 6012e001a24661..9b75ab1f1eb9b7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc @@ -96,18 +96,32 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, 
pm->addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass()); pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); pm->addPass(mlir::createTritonAMDGPUHoistLayoutConversionsPass()); + + pm->addPass(mt::gpu::createTritonGPUFuseNestedLoops()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createLoopInvariantCodeMotionPass()); + pm->addPass(mlir::createCanonicalizerPass()); + if (num_stages == kAmdDoubleBuffering && cc.has_amd_matrix_core()) { pm->addPass(mlir::createTritonAMDGPUStreamPipelinePass( - num_stages, /*stream_prefetch=*/true)); + num_stages, /*stream_prefetch=*/true, /*use_async_copy=*/true)); + pm->addPass(mlir::createTritonAMDGPUCoalesceAsyncCopyPass()); pm->addPass(mlir::createCanonicalizerPass()); } pm->addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass("default")); pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm->addPass(mt::gpu::createTritonGPUReduceDataDuplication()); + if (false) { // Not enabled by default. + pm->addPass(mlir::createTritonAMDGPUInThreadTransposePass()); + pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + } if (num_stages != kAmdDoubleBuffering) { pm->addPass(mt::gpu::createTritonGPUReorderInstructions()); } + if (false) { // For upstream, this is enabled iff arch == gfx942. + pm->addPass(mlir::createTritonAMDGPUBlockPingpongPass(num_stages)); + } pm->addPass(mlir::createTritonAMDGPUCanonicalizePointersPass()); pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mlir::createTritonAMDGPUConvertToBufferOpsPass(arch_name)); @@ -117,8 +131,6 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, // Based on make_llir() in // @triton//:third_party/amd/backend/compiler.py - pm->addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass( - cc.gfx_version())); const int custom_lds_size = 0; pm->addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(cc.gfx_version(), custom_lds_size)); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.h b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.h index 4bc39621ff742c..3a6382f5ea75ab 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.h @@ -37,7 +37,7 @@ absl::StatusOr Create2DTmaDescriptor( mlir::Type element_type); // Emit a TmaDescriptor for the given argument & tensor type. It can then be -// used to load a tensor using the ExperimentalDescriptorLoadOp. +// used to load a tensor using the DescriptorLoadOp. 
 mlir::Value EmitTmaDescriptor(EmitterLocOpBuilder& b, mlir::Value arg,
                               mlir::RankedTensorType tensor_type);

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir
index a0495da17696b3..c0367fcf859649 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir
@@ -40,8 +40,8 @@ func.func @lower_tile_extract_insert(%arg0: tensor<512x128xbf16>,
 // CHECK-TMA-SAME: %[[ARG_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.nv_tma_desc = 1 : i32, tt.tma_descriptor = #triton_xla.tma_descriptor}
 // CHECK-TMA: %[[DESC_0:.*]] = tt.reinterpret_tensor_descriptor %[[ARG_0]]
 // CHECK-TMA: %[[DESC_1:.*]] = tt.reinterpret_tensor_descriptor %[[ARG_1]]
-// CHECK-TMA: %[[LOAD:.*]] = tt.experimental_descriptor_load %[[DESC_0]]
-// CHECK-TMA: tt.experimental_descriptor_store %[[DESC_1]][{{.*}}], %[[LOAD]]
+// CHECK-TMA: %[[LOAD:.*]] = tt.descriptor_load %[[DESC_0]]
+// CHECK-TMA: tt.descriptor_store %[[DESC_1]][{{.*}}], %[[LOAD]]
 // CHECK-TMA: tt.return

 // -----

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc
index e144f234a64213..e7f064d64ba535 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc
@@ -466,7 +466,7 @@ struct RewriteExtract : mlir::OpRewritePattern {
   auto descriptor_load =
       builder
-          .create<mt::ExperimentalDescriptorLoadOp>(
+          .create<mt::DescriptorLoadOp>(
               op.getResult().getType(), cast_to_tensor_desc_ptr_type,
               IndexCastUI(builder, builder.getI32Type(), op.getOffsets()))
           .getResult();
@@ -523,7 +523,7 @@ struct RewriteInsert : mlir::OpRewritePattern {
                op.getDst())
               .getResult(0);

-    builder.create<mt::ExperimentalDescriptorStoreOp>(
+    builder.create<mt::DescriptorStoreOp>(
         cast_to_tensor_desc_ptr_type, op.getSrc(),
         IndexCastUI(builder, builder.getI32Type(), op.getOffsets()));
   } else {

diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
index 76eecaf0c1125d..d49545966372cf 100644
--- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
@@ -61,7 +61,7 @@ class AutotuneCacheKey {
   // Tie a version to the cache key in order to invalidate the cache when
   // necessary. This should be incremented on triton upgrades or any other
   // changes that may affect the autotuning results.
-  static constexpr int kCurrentVersion = 1;
+  static constexpr int kCurrentVersion = 2;

   AutotuneCacheKey(const se::DeviceDescription& device_description,
                    const HloInstruction& instruction)

From 291e1b630eb24bf86f0d9a960a38827e5015ea92 Mon Sep 17 00:00:00 2001
From: Ilya Tikhonovskiy
Date: Thu, 3 Apr 2025 07:03:58 -0700
Subject: [PATCH 0198/1324] [XLA:GPU] Clean-up. Fix precision issues in Triton
 GEMM device tests for int4

int4 tests make sense for the cases where we have bf16 weights and
activations; f32 types look like overkill. With the f32 type, the precision
changes too much when the actual matmul happens in tf32.
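For intuition: tf32 keeps 10 explicit mantissa bits versus 23 for f32, so
tensor cores effectively drop the low mantissa bits of f32 inputs, while bf16
inputs (7 explicit mantissa bits) survive that rounding exactly. A minimal C++
sketch of the truncation (an illustrative model only; the helper name
`TruncateToTf32` is made up, and real hardware rounds rather than truncates):

    #include <cstdint>
    #include "absl/base/casts.h"

    // Models tf32's 1-sign/8-exponent/10-mantissa layout by zeroing the
    // low 13 mantissa bits of an f32 value.
    float TruncateToTf32(float x) {
      uint32_t bits = absl::bit_cast<uint32_t>(x);
      bits &= 0xFFFFE000u;  // keep sign, exponent, and top 10 mantissa bits
      return absl::bit_cast<float>(bits);
    }

A bf16 value converted to f32 round-trips through this unchanged, which is why
the f32 references in these tests stay in agreement with the device results
once the inputs are bf16.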
PiperOrigin-RevId: 743547757 --- .../triton/fusion_emitter_int4_device_test.cc | 100 +++++++++--------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_int4_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_int4_device_test.cc index 56c7b4903464f1..391895ea67b268 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_int4_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_int4_device_test.cc @@ -232,16 +232,16 @@ TEST_F(TritonTest, FuseSubchannelDequantizationFused) { w.s4 = s4[16,2048,4096]{2,1,0:E(4)} parameter(0) w.s8 = s8[16,2048,4096]{2,1,0} convert(w.s4) w.s8.bitcast = s8[16,8,256,4096]{3,2,1,0} bitcast(w.s8) - w.f32 = f32[16,8,256,4096]{3,2,1,0} convert(w.s8.bitcast) + w.bf16 = bf16[16,8,256,4096]{3,2,1,0} convert(w.s8.bitcast) - s.f32 = f32[16,8,1,4096]{3,2,1,0} parameter(1) - s.f32.bitcast = f32[16,8,4096]{2,1,0} bitcast(s.f32) - s.f32.broadcast = f32[16,8,256,4096]{3,2,1,0} broadcast(s.f32.bitcast), dimensions={0,1,3} - w = f32[16,8,256,4096]{3,2,1,0} multiply(w.f32, s.f32.broadcast) - w.bitcast = f32[16,2048,4096]{2,1,0} bitcast(w) + s.bf16 = bf16[16,8,1,4096]{3,2,1,0} parameter(1) + s.bf16.bitcast = bf16[16,8,4096]{2,1,0} bitcast(s.bf16) + s.bf16.broadcast = bf16[16,8,256,4096]{3,2,1,0} broadcast(s.bf16.bitcast), dimensions={0,1,3} + w = bf16[16,8,256,4096]{3,2,1,0} multiply(w.bf16, s.bf16.broadcast) + w.bitcast = bf16[16,2048,4096]{2,1,0} bitcast(w) - a = f32[2,16,1,2048]{3,2,1,0} parameter(2) - a.bitcast = f32[2,16,2048]{2,1,0} bitcast(a) + a = bf16[2,16,1,2048]{3,2,1,0} parameter(2) + a.bitcast = bf16[2,16,2048]{2,1,0} bitcast(a) ROOT dot = f32[16,4096,2]{2,1,0} dot(w.bitcast, a.bitcast), lhs_batch_dims={0}, lhs_contracting_dims={1}, @@ -251,9 +251,9 @@ TEST_F(TritonTest, FuseSubchannelDequantizationFused) { ENTRY main { w.s4 = s4[16,2048,4096]{2,1,0:E(4)} parameter(0) - s.f32 = f32[16,8,1,4096]{3,2,1,0} parameter(1) - a.f32 = f32[2,16,1,2048]{3,2,1,0} parameter(2) - ROOT fusion = f32[16,4096,2]{2,1,0} fusion(w.s4, s.f32, a.f32), + s.bf16 = bf16[16,8,1,4096]{3,2,1,0} parameter(1) + a.bf16 = bf16[2,16,1,2048]{3,2,1,0} parameter(2) + ROOT fusion = f32[16,4096,2]{2,1,0} fusion(w.s4, s.bf16, a.bf16), kind=kCustom, calls=fusion, backend_config={ @@ -285,22 +285,22 @@ TEST_F(TritonTest, FuseSubchannelDequantizationFusedWithSmallBlockKSize) { // The case where we do: // param -> bitcast -> broadcast -> multiply -> bitcast -> dot. 
constexpr absl::string_view kHloText = R"( - HloModule FuseSubchannelDequantizationFused + HloModule FuseSubchannelDequantizationFusedWithSmallBlockKSize fusion { w.s4 = s4[16,2048,4096]{2,1,0:E(4)} parameter(0) w.s8 = s8[16,2048,4096]{2,1,0} convert(w.s4) w.s8.bitcast = s8[16,8,256,4096]{3,2,1,0} bitcast(w.s8) - w.f32 = f32[16,8,256,4096]{3,2,1,0} convert(w.s8.bitcast) + w.bf16 = bf16[16,8,256,4096]{3,2,1,0} convert(w.s8.bitcast) - s.f32 = f32[16,8,1,4096]{3,2,1,0} parameter(1) - s.f32.bitcast = f32[16,8,4096]{2,1,0} bitcast(s.f32) - s.f32.broadcast = f32[16,8,256,4096]{3,2,1,0} broadcast(s.f32.bitcast), dimensions={0,1,3} - w = f32[16,8,256,4096]{3,2,1,0} multiply(w.f32, s.f32.broadcast) - w.bitcast = f32[16,2048,4096]{2,1,0} bitcast(w) + s.bf16 = bf16[16,8,1,4096]{3,2,1,0} parameter(1) + s.bf16.bitcast = bf16[16,8,4096]{2,1,0} bitcast(s.bf16) + s.bf16.broadcast = bf16[16,8,256,4096]{3,2,1,0} broadcast(s.bf16.bitcast), dimensions={0,1,3} + w = bf16[16,8,256,4096]{3,2,1,0} multiply(w.bf16, s.bf16.broadcast) + w.bitcast = bf16[16,2048,4096]{2,1,0} bitcast(w) - a = f32[2,16,1,2048]{3,2,1,0} parameter(2) - a.bitcast = f32[2,16,2048]{2,1,0} bitcast(a) + a = bf16[2,16,1,2048]{3,2,1,0} parameter(2) + a.bitcast = bf16[2,16,2048]{2,1,0} bitcast(a) ROOT dot = f32[16,4096,2]{2,1,0} dot(w.bitcast, a.bitcast), lhs_batch_dims={0}, lhs_contracting_dims={1}, @@ -310,9 +310,9 @@ TEST_F(TritonTest, FuseSubchannelDequantizationFusedWithSmallBlockKSize) { ENTRY main { w.s4 = s4[16,2048,4096]{2,1,0:E(4)} parameter(0) - s.f32 = f32[16,8,1,4096]{3,2,1,0} parameter(1) - a.f32 = f32[2,16,1,2048]{3,2,1,0} parameter(2) - ROOT fusion = f32[16,4096,2]{2,1,0} fusion(w.s4, s.f32, a.f32), + s.bf16 = bf16[16,8,1,4096]{3,2,1,0} parameter(1) + a.bf16 = bf16[2,16,1,2048]{3,2,1,0} parameter(2) + ROOT fusion = f32[16,4096,2]{2,1,0} fusion(w.s4, s.bf16, a.bf16), kind=kCustom, calls=fusion, backend_config={ @@ -323,7 +323,7 @@ TEST_F(TritonTest, FuseSubchannelDequantizationFusedWithSmallBlockKSize) { "triton_gemm_config":{ "block_m":16, "block_n":16, - "block_k":16, + "block_k":128, "split_k":1, "num_stages":1, "num_warps":2, @@ -508,11 +508,11 @@ TEST_F(TritonTest, DotWithInt4WeightsOnLhsFusedWithMultiplyByChannelScales) { DotWithI4WeightsOnLhsFusedWithMultiplyByChannelScales { w = s4[32,64,128]{2,1,0} parameter(0) w.i8 = s8[32,64,128]{2,1,0} convert(w) - w.f32 = f32[32,64,128]{2,1,0} convert(w.i8) - scales = f32[32,128]{1,0} parameter(1) - scales.broadcast = f32[32,64,128]{2,1,0} broadcast(scales), dimensions={0,2} - weights.scaled = f32[32,64,128]{2,1,0} multiply(w.f32, scales.broadcast) - activations = f32[32,64,256]{2,1,0} parameter(2) + w.bf16 = bf16[32,64,128]{2,1,0} convert(w.i8) + scales = bf16[32,128]{1,0} parameter(1) + scales.broadcast = bf16[32,64,128]{2,1,0} broadcast(scales), dimensions={0,2} + weights.scaled = bf16[32,64,128]{2,1,0} multiply(w.bf16, scales.broadcast) + activations = bf16[32,64,256]{2,1,0} parameter(2) ROOT dot = f32[32,128,256]{2,1,0} dot(weights.scaled, activations), lhs_batch_dims={0}, lhs_contracting_dims={1}, @@ -522,8 +522,8 @@ TEST_F(TritonTest, DotWithInt4WeightsOnLhsFusedWithMultiplyByChannelScales) { ENTRY main { w = s4[32,64,128]{2,1,0} parameter(0) - scales = f32[32,128]{1,0} parameter(1) - p2 = f32[32,64,256]{2,1,0} parameter(2) + scales = bf16[32,128]{1,0} parameter(1) + p2 = bf16[32,64,256]{2,1,0} parameter(2) ROOT dot = f32[32,128,256]{2,1,0} fusion(w, scales, p2), kind=kCustom, calls=DotWithI4WeightsOnLhsFusedWithMultiplyByChannelScales, @@ -543,14 +543,14 @@ TEST_F(TritonTest, 
FuseMultiplyInPrologue) { HloModule FuseMultiplyInPrologue ENTRY main { - t = (s4[32,64,128]{2,1,0}, f32[32,128]{0,1}, f32[32,64,256]{2,1,0}) parameter(0) + t = (s4[32,64,128]{2,1,0}, bf16[32,128]{0,1}, bf16[32,64,256]{2,1,0}) parameter(0) w = s4[32,64,128]{2,1,0} get-tuple-element(t), index=0 w.i8 = s8[32,64,128]{2,1,0} convert(w) - w.f32 = f32[32,64,128]{2,1,0} convert(w.i8) - scales = f32[32,128]{0,1} get-tuple-element(t), index=1 - scales.broadcast = f32[32,64,128]{2,1,0} broadcast(scales), dimensions={0,2} - weights.scaled = f32[32,64,128]{2,1,0} multiply(w.f32, scales.broadcast) - activations = f32[32,64,256]{2,1,0} get-tuple-element(t), index=2 + w.bf16 = bf16[32,64,128]{2,1,0} convert(w.i8) + scales = bf16[32,128]{0,1} get-tuple-element(t), index=1 + scales.broadcast = bf16[32,64,128]{2,1,0} broadcast(scales), dimensions={0,2} + weights.scaled = bf16[32,64,128]{2,1,0} multiply(w.bf16, scales.broadcast) + activations = bf16[32,64,256]{2,1,0} get-tuple-element(t), index=2 ROOT dot = f32[32,128,256]{2,1,0} dot(weights.scaled, activations), lhs_batch_dims={0}, lhs_contracting_dims={1}, @@ -560,7 +560,7 @@ TEST_F(TritonTest, FuseMultiplyInPrologue) { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( - CHECK: %[[multiply:.*]] = f32[32,64,128]{2,1,0} multiply + CHECK: %[[multiply:.*]] = bf16[32,64,128]{2,1,0} multiply CHECK: %[[dot:.*]] = f32[32,128,256]{2,1,0} dot CHECK: ENTRY %main )")); @@ -651,16 +651,16 @@ TEST_P(ParametrizedTritonTest, Int4WeightsOnTheLhs) { lhs_${name} { w.s4 = s4[${lhs}]{1,0} parameter(0) w.s8 = s8[${lhs}]{1,0} convert(w.s4) - w.f32 = f32[${lhs}]{1,0} convert(w.s8) - a = f32[${rhs}]{1,0} parameter(1) - ROOT lhs_${name} = f32[${out}]{1,0} dot(w.f32, a), + w.bf16 = bf16[${lhs}]{1,0} convert(w.s8) + a = bf16[${rhs}]{1,0} parameter(1) + ROOT lhs_${name} = f32[${out}]{1,0} dot(w.bf16, a), lhs_contracting_dims={${lhs_contracting_dim}}, rhs_contracting_dims={${rhs_contracting_dim}} } ENTRY main { w = s4[${lhs}]{1,0} parameter(0) - a = f32[${rhs}]{1,0} parameter(1) + a = bf16[${rhs}]{1,0} parameter(1) ROOT gemm_fusion_dot.2 = f32[${out}]{1,0} fusion(w, a), kind=kCustom, calls=lhs_${name}, @@ -687,9 +687,9 @@ TEST_P(ParametrizedTritonTest, Int4WeightsOnTheLhsWithBatchDim) { fusion { w.s4 = s4[${lhs}]{2,1,0} parameter(0) w.s8 = s8[${lhs}]{2,1,0} convert(w.s4) - w.f32 = f32[${lhs}]{2,1,0} convert(w.s8) - a = f32[${rhs}]{2,1,0} parameter(1) - ROOT dot.0 = f32[${out}]{2,1,0} dot(w.f32, a), + w.bf16 = bf16[${lhs}]{2,1,0} convert(w.s8) + a = bf16[${rhs}]{2,1,0} parameter(1) + ROOT dot.0 = f32[${out}]{2,1,0} dot(w.bf16, a), lhs_contracting_dims={${lhs_contracting_dim}}, rhs_contracting_dims={${rhs_contracting_dim}}, lhs_batch_dims={0}, @@ -698,7 +698,7 @@ TEST_P(ParametrizedTritonTest, Int4WeightsOnTheLhsWithBatchDim) { ENTRY gemm_fusion_dot_computation { w = s4[${lhs}]{2,1,0} parameter(0) - a = f32[${rhs}]{2,1,0} parameter(1) + a = bf16[${rhs}]{2,1,0} parameter(1) ROOT gemm_fusion_dot.2 = f32[${out}]{2,1,0} fusion(w, a), kind=kCustom, calls=fusion, @@ -724,17 +724,17 @@ TEST_P(ParametrizedTritonTest, Int4WeightsOnTheRhs) { HloModule rhs_${name} rhs_${name} { - a = f32[${lhs}]{1,0} parameter(0) + a = bf16[${lhs}]{1,0} parameter(0) w.s4 = s4[${rhs}]{1,0} parameter(1) w.s8 = s8[${rhs}]{1,0} convert(w.s4) - w.f32 = f32[${rhs}]{1,0} convert(w.s8) - ROOT rhs_${name} = f32[${out}]{1,0} dot(a, w.f32), + w.bf16 = bf16[${rhs}]{1,0} convert(w.s8) + ROOT rhs_${name} = f32[${out}]{1,0} dot(a, w.bf16), 
lhs_contracting_dims={${lhs_contracting_dim}}, rhs_contracting_dims={${rhs_contracting_dim}} } ENTRY main { - a = f32[${lhs}]{1,0} parameter(0) + a = bf16[${lhs}]{1,0} parameter(0) w = s4[${rhs}]{1,0} parameter(1) ROOT rhs_${name} = f32[${out}]{1,0} fusion(a, w), kind=kCustom, From 997f64cb6021b63e02955174d28e805e73753e56 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Thu, 3 Apr 2025 07:40:58 -0700 Subject: [PATCH 0199/1324] [XLA:GPU] Re-enable Tensor-Cores for bitwidth <= 8 x F32. These used to crash in Triton. PiperOrigin-RevId: 743558633 --- .../gpu/codegen/triton/dot_algorithms.cc | 20 ++----------------- .../fusion_emitter_device_legacy_test.cc | 12 +++++------ .../fusion_emitter_parametrized_test.cc | 2 +- 3 files changed, 8 insertions(+), 26 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc index 6ffdbd24aae936..a45b6854372297 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc @@ -254,30 +254,14 @@ bool IsTf32Allowed(const HloDotInstruction& dot) { return algorithm_util::HasTf32InputType(precision_config.algorithm()); } -bool DotDependsOnConvertFromByteWideOrSmallerTypeToF32( - const HloDotInstruction* dot) { - return HloBfsAnyOf({dot}, [&](const HloInstruction* node) { - if (node->opcode() != HloOpcode::kConvert) { - return false; - } - int in_width = - primitive_util::BitWidth(node->operand(0)->shape().element_type()); - return in_width <= 8 && node->shape().element_type() == F32; - }); -} - ttir::InputPrecision InferDotPrecision(const HloDotInstruction& dot) { if (dot.precision_config().algorithm() == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) { return ttir::InputPrecision::TF32x3; } - bool use_tf32 = IsTf32Allowed(dot); - // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. - use_tf32 = - use_tf32 && !DotDependsOnConvertFromByteWideOrSmallerTypeToF32(&dot); - - return use_tf32 ? ttir::InputPrecision::TF32 : ttir::InputPrecision::IEEE; + return IsTf32Allowed(dot) ? 
ttir::InputPrecision::TF32 + : ttir::InputPrecision::IEEE; } absl::StatusOr GetAlgUnsetAccumulatorType(EmitterLocOpBuilder& b, diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc index 22eabbc2935b2a..e33846f91983fa 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc @@ -563,7 +563,7 @@ ENTRY main { TEST_F(TritonGemmTest, UseTensorCoresForF32OnAmpere) { constexpr absl::string_view kHloText = R"( triton_gemm_r { - parameter_0 = f16[80,15]{1,0} parameter(0) + parameter_0 = s8[80,15]{1,0} parameter(0) convert.3 = f32[80,15]{1,0} convert(parameter_0) parameter_1 = f32[16,15]{1,0} parameter(1) ROOT r.1 = f32[80,16]{1,0} dot(convert.3, parameter_1), @@ -572,7 +572,7 @@ triton_gemm_r { ENTRY e { p1 = f32[16,15]{1,0} parameter(1) - p0 = f16[80,15]{1,0} parameter(0) + p0 = s8[80,15]{1,0} parameter(0) ROOT triton_gemm_r = f32[80,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_r, backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: @@ -2484,7 +2484,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/3e-2, /*arel=*/3e-2})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/4e-2, /*arel=*/6e-2})); } TEST_F(TritonGemmTest, ParameterAfterDotIsFused) { @@ -4084,9 +4084,7 @@ ENTRY e { .WithShape(BF16, {16, 40}, {1, 0}))); } -// This test could be modified to allow TF32 once this bug is fixed. -// TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. -TEST_F(TritonTest, NoTF32For8BitOrLessWithF32) { +TEST_F(TritonTest, UseTF32For8BitOrLessWithF32) { const std::string hlo_text = R"( HloModule t @@ -4118,7 +4116,7 @@ ENTRY e { TF_ASSERT_OK( CreateTritonIrAndFileCheckForDot(this, hlo_text, "triton_dot", R"( CHECK: tt.dot -CHECK-NOT: inputPrecision = tf32 +CHECK: inputPrecision = tf32 )")); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc index 5cff5b32eee616..b250de906e2daa 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc @@ -117,7 +117,7 @@ INSTANTIATE_TEST_SUITE_P(RewriteTestSuite, MixedTypeTest, ::testing::ValuesIn({ MixTypeParams{PRED, F16, 16, 32, 8}, MixTypeParams{PRED, BF16, 16, 32, 8}, - MixTypeParams{PRED, F32, 16, 32, 8, 1e-4, 1e-3}, + MixTypeParams{PRED, F32, 16, 32, 8, 2e-4, 2e-3}, MixTypeParams{S8, F16, 16, 32, 8}, MixTypeParams{S8, BF16, 16, 32, 8}, MixTypeParams{S8, F32, 16, 32, 8, 5e-2, 1e-2}, From 12a4097509260c6097f8f6f1bdf1c9e5720c8924 Mon Sep 17 00:00:00 2001 From: "Patrick J. LoPresti" Date: Thu, 3 Apr 2025 10:24:37 -0700 Subject: [PATCH 0200/1324] Remove ambiguous inherited constructor in default_quant_params.cc. GCC complains about this (https://stackoverflow.com/q/79553477/). Fix is trivial and harmless. Fixes #84977. 
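
A minimal sketch of what the inherited-constructor declaration does (with
hypothetical `Base`/`Derived` names, not the actual MLIR pass classes):

    struct Base {
      Base() = default;
      explicit Base(int v) : v_(v) {}
      int v_ = 0;
    };

    struct Derived : Base {
      using Base::Base;  // Pulls Base() and Base(int) into Derived's
                         // overload set, as if Derived declared them.
    };

The change below drops the `using` declaration in favor of an explicitly
written default constructor, so the affected GCC versions no longer see the
overload set they diagnose as ambiguous.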
---
 .../compiler/mlir/lite/transforms/default_quant_params.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
index f1b602a6763aca..7acbb7d17240b8 100644
--- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
@@ -54,7 +54,9 @@ namespace {
 class DefaultQuantParamsPass
     : public impl::DefaultQuantParamsPassBase<DefaultQuantParamsPass> {
  public:
-  using DefaultQuantParamsPassBase::DefaultQuantParamsPassBase;
+  DefaultQuantParamsPass()
+  {
+  }
 
   explicit DefaultQuantParamsPass(double default_min, double default_max,
                                   bool is_signed) {

From 8afd8f56b53f08201b381885eda3de4d055833a8 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 3 Apr 2025 11:55:19 -0700
Subject: [PATCH 0201/1324] Introduce a `safe_reinterpret_cast` library to
 prevent unsafe `reinterpret_cast`s.

Also replace `reinterpret_cast` with `safe_reinterpret_cast` in some places
as examples.

PiperOrigin-RevId: 743646228
---
 tensorflow/core/BUILD                         |  1 +
 third_party/xla/xla/tsl/framework/BUILD       |  1 +
 .../xla/xla/tsl/framework/bfc_allocator.h     |  5 +-
 third_party/xla/xla/tsl/lib/gtl/BUILD         |  5 +-
 .../xla/xla/tsl/lib/gtl/compactptrset.h       | 10 +-
 third_party/xla/xla/tsl/lib/io/BUILD          |  1 +
 third_party/xla/xla/tsl/lib/io/cache_test.cc  |  6 +-
 third_party/xla/xla/tsl/profiler/utils/BUILD  |  1 +
 .../xla/xla/tsl/profiler/utils/buffer_pool.cc | 10 +-
 third_party/xla/xla/tsl/util/BUILD            | 29 ++++++
 .../xla/xla/tsl/util/byte_swap_array.cc       |  9 +-
 .../xla/xla/tsl/util/safe_reinterpret_cast.h  | 98 +++++++++++++++++++
 .../tsl/util/safe_reinterpret_cast_test.cc    | 82 ++++++++++++++++
 13 files changed, 243 insertions(+), 15 deletions(-)
 create mode 100644 third_party/xla/xla/tsl/util/safe_reinterpret_cast.h
 create mode 100644 third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b67eca54f7a275..35218525f70f26 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1074,6 +1074,7 @@ cc_library(
 #            "@com_google_absl//absl/strings",
 #            "@com_google_absl//absl/types:optional",
 #            "@local_xla//xla/tsl/framework/fixedpoint",
+#            "@local_xla//xla/tsl/util:safe_reinterpret_cast",
 #            "//tensorflow/core/platform:resource",
 #            "//tensorflow/core/util:managed_stack_trace",
 #            "//tensorflow/core/util:stats_calculator_portable",

diff --git a/third_party/xla/xla/tsl/framework/BUILD b/third_party/xla/xla/tsl/framework/BUILD
index 1b6165dd1d6792..0a2fd2de53be23 100644
--- a/third_party/xla/xla/tsl/framework/BUILD
+++ b/third_party/xla/xla/tsl/framework/BUILD
@@ -200,6 +200,7 @@ cc_library(
         "//xla/tsl/platform:types",
         "//xla/tsl/profiler/utils:trace_filter_utils",
         "//xla/tsl/protobuf:bfc_memory_map_proto_cc",
+        "//xla/tsl/util:safe_reinterpret_cast",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",

diff --git a/third_party/xla/xla/tsl/framework/bfc_allocator.h b/third_party/xla/xla/tsl/framework/bfc_allocator.h
index a0d6568efab2fc..599e5e026b238a 100644
--- a/third_party/xla/xla/tsl/framework/bfc_allocator.h
+++ b/third_party/xla/xla/tsl/framework/bfc_allocator.h
@@ -37,6 +37,7 @@ limitations under the License.
#include "xla/tsl/lib/core/bits.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/types.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" #include "tsl/platform/numbers.h" namespace tensorflow { @@ -339,8 +340,8 @@ class BFCAllocator : public Allocator { } size_t IndexFor(const void* p) const { - std::uintptr_t p_int = reinterpret_cast(p); - std::uintptr_t base_int = reinterpret_cast(ptr_); + std::uintptr_t p_int = safe_reinterpret_cast(p); + std::uintptr_t base_int = safe_reinterpret_cast(ptr_); DCHECK_GE(p_int, base_int); DCHECK_LT(p_int, base_int + memory_size_); return static_cast(((p_int - base_int) >> kMinAllocationBits)); diff --git a/third_party/xla/xla/tsl/lib/gtl/BUILD b/third_party/xla/xla/tsl/lib/gtl/BUILD index de5643bee7b310..530eca7be2d140 100644 --- a/third_party/xla/xla/tsl/lib/gtl/BUILD +++ b/third_party/xla/xla/tsl/lib/gtl/BUILD @@ -44,7 +44,10 @@ package( cc_library( name = "compactptrset", hdrs = ["compactptrset.h"], - deps = [":flatset"], + deps = [ + ":flatset", + "//xla/tsl/util:safe_reinterpret_cast", + ], ) cc_library( diff --git a/third_party/xla/xla/tsl/lib/gtl/compactptrset.h b/third_party/xla/xla/tsl/lib/gtl/compactptrset.h index 3848430e76fb92..6d009230c55a3a 100644 --- a/third_party/xla/xla/tsl/lib/gtl/compactptrset.h +++ b/third_party/xla/xla/tsl/lib/gtl/compactptrset.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef XLA_TSL_LIB_GTL_COMPACTPTRSET_H_ #define XLA_TSL_LIB_GTL_COMPACTPTRSET_H_ +#include #include #include "xla/tsl/lib/gtl/flatset.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" namespace tsl { namespace gtl { @@ -126,7 +128,7 @@ class CompactPointerSet { std::pair insert(T elem) { if (!isbig()) { if (rep_ == 0) { - uintptr_t v = reinterpret_cast(elem); + uintptr_t v = safe_reinterpret_cast(elem); if (v == 0 || ((v & 0x3) != 0)) { // Cannot use small representation for nullptr. Fall through. } else { @@ -155,7 +157,7 @@ class CompactPointerSet { } iterator find(T elem) const { - if (rep_ == reinterpret_cast(elem)) { + if (rep_ == safe_reinterpret_cast(elem)) { return iterator(rep_); } else if (!isbig()) { return iterator(0); @@ -168,7 +170,7 @@ class CompactPointerSet { size_t erase(T elem) { if (!isbig()) { - if (rep_ == reinterpret_cast(elem)) { + if (rep_ == safe_reinterpret_cast(elem)) { rep_ = 0; return 1; } else { @@ -199,7 +201,7 @@ class CompactPointerSet { if (rep_ != 0) { big->insert(reinterpret_cast(rep_)); } - rep_ = reinterpret_cast(big) + 0x1; + rep_ = safe_reinterpret_cast(big) + 0x1; } }; diff --git a/third_party/xla/xla/tsl/lib/io/BUILD b/third_party/xla/xla/tsl/lib/io/BUILD index 3b90095c60a1ce..8287664f9345e1 100644 --- a/third_party/xla/xla/tsl/lib/io/BUILD +++ b/third_party/xla/xla/tsl/lib/io/BUILD @@ -468,6 +468,7 @@ tsl_cc_test( deps = [ ":cache", "//xla/tsl/platform:test", + "//xla/tsl/util:safe_reinterpret_cast", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:coding", "@local_tsl//tsl/platform:raw_coding", diff --git a/third_party/xla/xla/tsl/lib/io/cache_test.cc b/third_party/xla/xla/tsl/lib/io/cache_test.cc index 14f28f22208930..a6806450973e1b 100644 --- a/third_party/xla/xla/tsl/lib/io/cache_test.cc +++ b/third_party/xla/xla/tsl/lib/io/cache_test.cc @@ -15,10 +15,12 @@ limitations under the License. 
#include "xla/tsl/lib/io/cache.h" +#include #include #include #include "xla/tsl/platform/test.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" #include "tsl/platform/coding.h" #include "tsl/platform/raw_coding.h" @@ -36,7 +38,9 @@ static int DecodeKey(Slice k) { return core::DecodeFixed32(k.data()); } static void* EncodeValue(uintptr_t v) { return reinterpret_cast(v); } -static int DecodeValue(void* v) { return reinterpret_cast(v); } +static int DecodeValue(void* v) { + return safe_reinterpret_cast(v); +} class CacheTest : public ::testing::Test { public: diff --git a/third_party/xla/xla/tsl/profiler/utils/BUILD b/third_party/xla/xla/tsl/profiler/utils/BUILD index 1a4d7f7af23ee5..3fbffc13ce1dde 100644 --- a/third_party/xla/xla/tsl/profiler/utils/BUILD +++ b/third_party/xla/xla/tsl/profiler/utils/BUILD @@ -408,6 +408,7 @@ cc_library( ]), deps = [ "//xla/tsl/platform:logging", + "//xla/tsl/util:safe_reinterpret_cast", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:thread_annotations", diff --git a/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc index 17bcb573b01cbf..4bab6056c6eeff 100644 --- a/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc +++ b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc @@ -15,9 +15,11 @@ limitations under the License. #include "xla/tsl/profiler/utils/buffer_pool.h" +#include #include #include "xla/tsl/platform/logging.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" #include "tsl/platform/mem.h" #include "tsl/platform/mutex.h" @@ -41,7 +43,7 @@ uint8_t* BufferPool::GetOrCreateBuffer() { return nullptr; } VLOG(3) << "Reused Buffer, buffer=" << std::hex - << reinterpret_cast(buffer) << std::dec; + << safe_reinterpret_cast(buffer) << std::dec; return buffer; } } @@ -55,7 +57,7 @@ uint8_t* BufferPool::GetOrCreateBuffer() { return nullptr; } VLOG(3) << "Allocated Buffer, buffer=" << std::hex - << reinterpret_cast(buffer) << std::dec + << safe_reinterpret_cast(buffer) << std::dec << " size=" << buffer_size_in_bytes_; return buffer; } @@ -65,14 +67,14 @@ void BufferPool::ReclaimBuffer(uint8_t* buffer) { buffers_.push_back(buffer); VLOG(3) << "Reclaimed Buffer, buffer=" << std::hex - << reinterpret_cast(buffer) << std::dec; + << safe_reinterpret_cast(buffer) << std::dec; } void BufferPool::DestroyAllBuffers() { mutex_lock lock(buffers_mutex_); for (uint8_t* buffer : buffers_) { VLOG(3) << "Freeing Buffer, buffer:" << std::hex - << reinterpret_cast(buffer) << std::dec; + << safe_reinterpret_cast(buffer) << std::dec; port::AlignedFree(buffer); } buffers_.clear(); diff --git a/third_party/xla/xla/tsl/util/BUILD b/third_party/xla/xla/tsl/util/BUILD index a261a68b40ece5..52637b68e8dd25 100644 --- a/third_party/xla/xla/tsl/util/BUILD +++ b/third_party/xla/xla/tsl/util/BUILD @@ -122,11 +122,26 @@ filegroup( ], ) +filegroup( + name = "xla_cpu_runtime_hdrs", + srcs = [ + "safe_reinterpret_cast.h", + ], + compatible_with = get_compatible_with_portable(), +) + +filegroup( + name = "xla_cpu_runtime_srcs", + srcs = [], + compatible_with = get_compatible_with_portable(), +) + cc_library( name = "byte_swap_array", srcs = ["byte_swap_array.cc"], hdrs = ["byte_swap_array.h"], deps = [ + ":safe_reinterpret_cast", "//xla/tsl/platform:byte_order", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", @@ -366,3 +381,17 @@ tsl_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "safe_reinterpret_cast", + hdrs = 
["safe_reinterpret_cast.h"],
+)
+
+tsl_cc_test(
+    name = "safe_reinterpret_cast_test",
+    srcs = ["safe_reinterpret_cast_test.cc"],
+    deps = [
+        ":safe_reinterpret_cast",
+        "@com_google_googletest//:gtest_main",
+    ],
+)

diff --git a/third_party/xla/xla/tsl/util/byte_swap_array.cc b/third_party/xla/xla/tsl/util/byte_swap_array.cc
index 53bc7d9124f6be..1e30dd200301e9 100644
--- a/third_party/xla/xla/tsl/util/byte_swap_array.cc
+++ b/third_party/xla/xla/tsl/util/byte_swap_array.cc
@@ -15,7 +15,10 @@ limitations under the License.
 
 #include "xla/tsl/util/byte_swap_array.h"
 
+#include <cstdint>
+
 #include "xla/tsl/platform/errors.h"
+#include "xla/tsl/util/safe_reinterpret_cast.h"
 
 namespace tsl {
 
@@ -24,19 +27,19 @@ absl::Status ByteSwapArray(char* array, size_t bytes_per_elem, int array_len) {
     // No-op
     return absl::OkStatus();
   } else if (bytes_per_elem == 2) {
-    auto array_16 = reinterpret_cast<uint16_t*>(array);
+    auto array_16 = safe_reinterpret_cast<uint16_t*>(array);
     for (int i = 0; i < array_len; i++) {
       array_16[i] = BYTE_SWAP_16(array_16[i]);
     }
     return absl::OkStatus();
   } else if (bytes_per_elem == 4) {
-    auto array_32 = reinterpret_cast<uint32_t*>(array);
+    auto array_32 = safe_reinterpret_cast<uint32_t*>(array);
     for (int i = 0; i < array_len; i++) {
       array_32[i] = BYTE_SWAP_32(array_32[i]);
     }
     return absl::OkStatus();
   } else if (bytes_per_elem == 8) {
-    auto array_64 = reinterpret_cast<uint64_t*>(array);
+    auto array_64 = safe_reinterpret_cast<uint64_t*>(array);
     for (int i = 0; i < array_len; i++) {
       array_64[i] = BYTE_SWAP_64(array_64[i]);
     }

diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h
new file mode 100644
index 00000000000000..636bcf0a42ea9d
--- /dev/null
+++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h
@@ -0,0 +1,98 @@
+// Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+// This file provides a safe_reinterpret_cast function template that is like
+// reinterpret_cast, but compiles only if the cast is safe.
+//
+// In general, reinterpret_cast is unsafe because it can easily cause
+// undefined behavior. For example,
+//
+//   Foo* foo = ...;
+//   Bar* bar = reinterpret_cast<Bar*>(foo);
+//   *bar = ...;
+//
+// is undefined behavior unless Foo or Bar is a character type. See
+// https://en.cppreference.com/w/cpp/language/reinterpret_cast for more
+// details.
+//
+// safe_reinterpret_cast supports a subset of the casts that are always safe.
+// We can add more as needed.
+
+#ifndef XLA_TSL_UTIL_SAFE_REINTERPRET_CAST_H_
+#define XLA_TSL_UTIL_SAFE_REINTERPRET_CAST_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
+
+namespace tsl {
+
+namespace internal {
+
+// IsSafeCast<From, To>::value is true if it is safe to reinterpret_cast a
+// value of type From to a value of type To.
+//
+// This is a subset of the types that are safe to cast, but it's the only
+// subset that we need for now. We can add more as needed.
+template <typename From, typename To>
+struct IsSafeCast : std::false_type {};
+
+// It's safe to cast a type to itself.
+template <typename T>
+struct IsSafeCast<T, T> : std::true_type {};
+
+// It's safe to cast a pointer to any character pointer.
+template <typename From>
+struct IsSafeCast<From*, char*> : std::true_type {};
+template <typename From>
+struct IsSafeCast<From*, const char*> : std::true_type {};
+template <typename From>
+struct IsSafeCast<From*, unsigned char*> : std::true_type {};
+template <typename From>
+struct IsSafeCast<From*, const unsigned char*> : std::true_type {};
+template <typename From>
+struct IsSafeCast<From*, std::byte*> : std::true_type {};
+template <typename From>
+struct IsSafeCast<From*, const std::byte*> : std::true_type {};
+
+// It's safe to cast a character pointer to a pointer to any type.
+template <typename To>
+struct IsSafeCast<char*, To*> : std::true_type {};
+template <typename To>
+struct IsSafeCast<const char*, To*> : std::true_type {};
+template <typename To>
+struct IsSafeCast<unsigned char*, To*> : std::true_type {};
+template <typename To>
+struct IsSafeCast<const unsigned char*, To*> : std::true_type {};
+template <typename To>
+struct IsSafeCast<std::byte*, To*> : std::true_type {};
+template <typename To>
+struct IsSafeCast<const std::byte*, To*> : std::true_type {};
+
+// It's safe to cast a pointer to/from std::uintptr_t.
+template <typename From>
+struct IsSafeCast<From*, std::uintptr_t> : std::true_type {};
+template <typename To>
+struct IsSafeCast<std::uintptr_t, To*> : std::true_type {};
+
+}  // namespace internal
+
+// Like reinterpret_cast, but compiles only if it's safe.
+template <typename To, typename From,
+          typename = std::enable_if_t<internal::IsSafeCast<From, To>::value>>
+To safe_reinterpret_cast(From from) {
+  return reinterpret_cast<To>(from);  // REINTERPRET_CAST_OK=for implementing
+                                      // safe_reinterpret_cast.
+}
+
+}  // namespace tsl
+
+#endif  // XLA_TSL_UTIL_SAFE_REINTERPRET_CAST_H_

diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc
new file mode 100644
index 00000000000000..648a3cea2f59c3
--- /dev/null
+++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc
@@ -0,0 +1,82 @@
+// Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+#include "xla/tsl/util/safe_reinterpret_cast.h"
+
+#include <cstddef>
+#include <cstdint>
+
+#include <gtest/gtest.h>
+
+namespace tsl {
+namespace {
+
+TEST(SafeReinterpretCast, CanCastPointerToFromConstCharPointer) {
+  const int x = 42;
+  const char* const char_p = safe_reinterpret_cast<const char*>(&x);
+  EXPECT_EQ(
+      char_p,  //
+      reinterpret_cast<const char*>(&x));  // REINTERPRET_CAST_OK=for testing.
+
+  const int* const int_p = safe_reinterpret_cast<const int*>(char_p);
+  EXPECT_EQ(int_p, &x);
+}
+
+TEST(SafeReinterpretCast, CanCastPointerToFromConstBytePointer) {
+  const int x = 42;
+  const ::std::byte* const char_p =
+      safe_reinterpret_cast<const ::std::byte*>(&x);
+  EXPECT_EQ(
+      char_p,  //
+      reinterpret_cast<const ::std::byte*>(  // REINTERPRET_CAST_OK=for testing.
+          &x));
+
+  const int* const int_p = safe_reinterpret_cast<const int*>(char_p);
+  EXPECT_EQ(int_p, &x);
+}
+
+TEST(SafeReinterpretCast, CanCastPointerToFromConstUnsignedCharPointer) {
+  const int x = 42;
+  const unsigned char* const char_p =
+      safe_reinterpret_cast<const unsigned char*>(&x);
+  EXPECT_EQ(char_p,  //
+            reinterpret_cast<const unsigned char*>(  // REINTERPRET_CAST_OK=for
+                                                     // testing.
+                &x));
+
+  const int* const int_p = safe_reinterpret_cast<const int*>(char_p);
+  EXPECT_EQ(int_p, &x);
+}
+
+TEST(SafeReinterpretCast, CanCastPointerToFromMutableCharPointer) {
+  int x = 42;
+  char* const char_p = safe_reinterpret_cast<char*>(&x);
+  EXPECT_EQ(char_p,  //
+            reinterpret_cast<char*>(&x));  // REINTERPRET_CAST_OK=for testing.
+ + int* const int_p = safe_reinterpret_cast(char_p); + EXPECT_EQ(int_p, &x); +} + +TEST(SafeReinterpretCast, CanCastPointerToFromStdUintptrT) { + const int x = 42; + const ::std::uintptr_t uintptr_t_p = + safe_reinterpret_cast<::std::uintptr_t>(&x); + EXPECT_EQ( + uintptr_t_p, // + reinterpret_cast<::std::uintptr_t>( // REINTERPRET_CAST_OK=for testing. + &x)); + EXPECT_EQ(safe_reinterpret_cast(uintptr_t_p), &x); +} + +} // namespace +} // namespace tsl From 55ee6b32cb7a0a64fd4df5e018dc4dbe9ee75f50 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 3 Apr 2025 11:57:33 -0700 Subject: [PATCH 0202/1324] Move on_delete_callback into CpuDeviceMemory and create a hierarchy of allocated memory instead of handling everything through a single type. PiperOrigin-RevId: 743646968 --- third_party/xla/xla/pjrt/cpu/BUILD | 1 + .../xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc | 43 ++----- .../xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h | 6 - third_party/xla/xla/pjrt/cpu/cpu_client.cc | 51 ++++---- .../xla/pjrt/cpu/tracked_cpu_device_buffer.cc | 111 ++++++++++-------- .../xla/pjrt/cpu/tracked_cpu_device_buffer.h | 87 ++++++++------ .../cpu/tracked_cpu_device_buffer_test.cc | 22 ++-- 7 files changed, 158 insertions(+), 163 deletions(-) diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index ff3233570d7c34..9511e6bb9cc98a 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -58,6 +58,7 @@ xla_cc_test( "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:env", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", ], ) diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index 85f60b0a0b6481..db483e43bc2229 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -201,9 +201,6 @@ void AbstractTfrtCpuBuffer::DropExternalReference() { absl::MutexLock lock(&mu_); CHECK_GT(external_reference_counter_, 0); --external_reference_counter_; - if (external_reference_counter_ == 0 && external_references_dropped_event_) { - external_references_dropped_event_->SetStateConcrete(); - } } class TrackedCpuDeviceBufferExternalReference @@ -263,16 +260,10 @@ void AbstractTfrtCpuBuffer::AbortDonation( void AbstractTfrtCpuBuffer::Delete() { std::unique_ptr device_buffer; - std::optional> external_references_dropped_event; { absl::MutexLock lock(&mu_); device_buffer = ReleaseBufferLocked(); if (device_buffer == nullptr) return; - - if (external_reference_counter_ > 0) { - external_references_dropped_event = external_references_dropped_event_ = - tsl::MakeConstructedAsyncValueRef(); - } } // Now that all holds have completed and no more can be added, we can get @@ -288,9 +279,6 @@ void AbstractTfrtCpuBuffer::Delete() { // We should also wait for the definition event. 
event_avs.push_back(device_buffer->definition_event().GetAsyncValue()); - if (external_references_dropped_event) { - event_avs.push_back(external_references_dropped_event->GetAsyncValue()); - } RunWhenReady(event_avs, [device_buffer = std::move(device_buffer)]() mutable { device_buffer.reset(); @@ -517,7 +505,7 @@ AbstractTfrtCpuBuffer::CopyToDeviceHelper(AsyncWorkRunner* async_work_runner) { } MarkEventReadyOnExit ready_on_exit(std::move(usage_event)); - auto dst_buffer = tsl::MakeUnconstructedAsyncValueRef(); + auto dst_buffer = tsl::MakeUnconstructedAsyncValueRef(); auto dst_definition_event = tsl::MakeConstructedAsyncValueRef(); // Wait for src buffer definition events to finish before d2d dispatch. @@ -538,12 +526,12 @@ AbstractTfrtCpuBuffer::CopyToDeviceHelper(AsyncWorkRunner* async_work_runner) { } CHECK(src_buffer.IsConcrete()); - auto dst_memory = CpuDeviceMemory::Allocate(src_buffer->size_bytes()); - if (!dst_memory.ok()) { - dst_definition_event.SetError(dst_memory.status()); + auto status = CpuDeviceMemoryOwned::AllocateInto(src_buffer->size_bytes(), + dst_buffer_copy); + if (!status.ok()) { + dst_definition_event.SetError(status); return; } - dst_buffer_copy.emplace(std::move(*dst_memory)); std::memcpy(dst_buffer_copy->untyped_data(), src_buffer->untyped_data(), src_buffer->size_bytes()); dst_definition_event.SetStateConcrete(); @@ -716,7 +704,6 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( absl::InlinedVector, 4> buffers; absl::InlinedVector, 4> definition_events; - absl::AnyInvocable on_delete_callback; size_t byte_size = ShapeUtil::ByteSizeOf(shape); bool owns_buffers = true; @@ -724,20 +711,15 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( // For a mutable zero copy semantics we pass a no-op deleter because // underlying buffer is owned by the caller and it will free it when // PjRt will call `on_done_with_host_buffer` callback. - CpuDeviceMemory::OwnedData::deleter_type no_op = +[](void*) {}; - buffers.push_back(tsl::MakeAvailableAsyncValueRef( - CpuDeviceMemory::OwnedData( - reinterpret_cast(const_cast(data)), no_op), - byte_size)); - on_delete_callback = std::move(on_done_with_host_buffer); - + buffers.push_back(CpuDeviceMemory::CreateForeignMemory( + const_cast(data), byte_size, // CONST_CAST_OK=flag controlled. + std::move(on_done_with_host_buffer))); } else if (can_use_zero_copy && immutable_zero_copy_semantics) { // For immutable zero-copy semantics we pass non-owning cpu memory. owns_buffers = false; - buffers.push_back(tsl::MakeAvailableAsyncValueRef( - const_cast(data), byte_size)); - on_delete_callback = std::move(on_done_with_host_buffer); - + buffers.push_back(CpuDeviceMemory::CreateForeignMemory( + const_cast(data), byte_size, // CONST_CAST_OK=flag controlled. + std::move(on_done_with_host_buffer))); } else { size_t dst_byte_size = is_packed ? 
CeilOfRatio(byte_size, 8 / bit_width) : byte_size; @@ -816,8 +798,7 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( } } return std::make_unique( - owns_buffers, std::move(buffers[0]), std::move(definition_events), - std::move(on_delete_callback)); + owns_buffers, std::move(buffers[0]), std::move(definition_events)); } AbstractAsyncHostToHostMemoryTransferManager:: diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index 3c3ebcdf7004e9..a6f2f07bb51d26 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -297,12 +297,6 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { // Count of external references on the buffer. int external_reference_counter_ ABSL_GUARDED_BY(mu_) = 0; - // If this buffer has external references when Delete() is called, this event - // is populated by Delete(). When the last external reference is released, - // the event is triggered, which is a precondition for the buffer being - std::optional> external_references_dropped_event_ - ABSL_GUARDED_BY(mu_); - // `pending_donation_` indicates whether a donation is pending. The destructor // of the AbstractTfrtCpuBuffer will wait for a pending donation, as the // donation might fail. Note that concurrent calls to AcquireUsage() and diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index 46f460a7d34b87..cf6a24d9195bd9 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -988,12 +988,11 @@ TfrtCpuClient::CreateViewOfDeviceBuffer( cpu_function_runtime::MinAlign()); } size_t byte_size = ShapeUtil::ByteSizeOf(shape); - auto non_owning_buffer = - tsl::MakeAvailableAsyncValueRef(device_ptr, byte_size); + auto non_owning_buffer = CpuDeviceMemory::CreateForeignMemory( + device_ptr, byte_size, std::move(on_delete_callback)); auto tracked_device_buffer = std::make_unique( /*owns_buffers=*/false, std::move(non_owning_buffer), - /*definition_event=*/tsl::MakeAvailableAsyncValueRef(), - std::move(on_delete_callback)); + /*definition_event=*/tsl::MakeAvailableAsyncValueRef()); CHECK_EQ(memory_space->devices().size(), 1); auto* device = memory_space->devices().front(); return std::unique_ptr(std::make_unique( @@ -1260,17 +1259,17 @@ struct BufferInfo { struct BufferAlloc { // All data members should have the same size. - absl::InlinedVector, 4> buffers; + absl::InlinedVector, 4> buffers; absl::InlinedVector allocation_sizes; void Allocate() { for (int i = 0; i < buffers.size(); ++i) { - auto memory = CpuDeviceMemory::Allocate(allocation_sizes[i]); - if (!memory.ok()) { - buffers[i].SetError(memory.status()); + auto status = + CpuDeviceMemoryOwned::AllocateInto(allocation_sizes[i], buffers[i]); + if (!status.ok()) { + buffers[i].SetError(status); return; } - buffers[i].emplace(std::move(*memory)); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(buffers[i]->untyped_data(), allocation_sizes[i]); } @@ -1280,17 +1279,17 @@ struct BufferAlloc { struct BufferAllocAndCopy { // All data members should have the same size. 
absl::InlinedVector, 4> src_buffers; - absl::InlinedVector, 4> dst_buffers; + absl::InlinedVector, 4> dst_buffers; absl::InlinedVector allocation_sizes; void AllocateAndCopy() { for (int i = 0; i < src_buffers.size(); ++i) { - auto memory = CpuDeviceMemory::Allocate(allocation_sizes[i]); - if (!memory.ok()) { - dst_buffers[i].SetError(memory.status()); + auto status = CpuDeviceMemoryOwned::AllocateInto(allocation_sizes[i], + dst_buffers[i]); + if (!status.ok()) { + dst_buffers[i].SetError(status); return; } - dst_buffers[i].emplace(std::move(*memory)); CHECK(src_buffers[i].IsConcrete()); std::memcpy(dst_buffers[i]->untyped_data(), src_buffers[i]->untyped_data(), allocation_sizes[i]); @@ -1348,7 +1347,7 @@ static absl::StatusOr MemoryForAllocation( // lifetime will not extend past the lifetime of the donated input buffer. if ((!can_donate || (arg && !arg->owns_buffers())) && !allocation.is_readonly()) { - auto copy = tsl::MakeUnconstructedAsyncValueRef(); + auto copy = tsl::MakeUnconstructedAsyncValueRef(); buffer_alloc_and_copy.src_buffers.push_back(out.CopyRef()); buffer_alloc_and_copy.dst_buffers.push_back(copy); @@ -1369,21 +1368,21 @@ static absl::StatusOr MemoryForAllocation( allocation.index() < constants.size()) { se::DeviceMemoryBase constant = constants[allocation.index()].AsDeviceMemoryBase(); - buffer_info.buffer = tsl::MakeAvailableAsyncValueRef( + buffer_info.buffer = CpuDeviceMemory::CreateUnownedConstant( constant.opaque(), constant.size()); buffer_info.owns_buffer = false; buffer_info.buffer_size = constant.size(); return buffer_info; } else if (allocation.is_constant() || allocation.is_thread_local()) { - buffer_info.buffer = tsl::MakeAvailableAsyncValueRef(); + buffer_info.buffer = CpuDeviceMemory::CreateUnownedConstant(nullptr, 0); buffer_info.owns_buffer = true; buffer_info.buffer_size = 0; return buffer_info; } // Output and temporary buffer. - auto out = tsl::MakeUnconstructedAsyncValueRef(); + auto out = tsl::MakeUnconstructedAsyncValueRef(); buffer_alloc.buffers.push_back(out); buffer_alloc.allocation_sizes.push_back(allocation.size()); @@ -1433,11 +1432,11 @@ absl::Status TfrtCpuExecutable::CheckBufferCompatibilities( } for (int i = 0; i < input_buffers.size(); ++i) { const auto& buffer = input_buffers[i].second; - if (input_buffer_sizes_in_bytes_[i] != buffer->BufferSizes()[0]) { + if (input_buffer_sizes_in_bytes_[i] != buffer->BufferSize()) { return InvalidArgument( "Executable expected parameter %d of size %lld but got buffer with " "incompatible size %lld", - i, input_buffer_sizes_in_bytes_[i], buffer->BufferSizes()[0]); + i, input_buffer_sizes_in_bytes_[i], buffer->BufferSize()); } } return absl::OkStatus(); @@ -1573,21 +1572,23 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( // Tuplize the inputs if compiler expects a single tuple argument but runtime // gets many inputs that are not yet tupled. 
- tsl::AsyncValueRef tuple_index_table; + tsl::AsyncValueRef tuple_index_table; if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { absl::InlinedVector, 4> leaf_buffers; leaf_buffers.reserve(tracked_buffers.size()); for (const auto& tracked_buffer : tracked_buffers) { leaf_buffers.push_back(tracked_buffer.second->buffer()); } - tuple_index_table = tsl::MakeUnconstructedAsyncValueRef(); + tuple_index_table = + tsl::MakeUnconstructedAsyncValueRef(); tsl::RunWhenReady( absl::MakeConstSpan(leaf_buffers), - [buffers = leaf_buffers, tuple_index_table = tuple_index_table] { + [buffers = leaf_buffers, + tuple_index_table = tuple_index_table]() mutable { size_t index_table_byte_size = buffers.size() * sizeof(void*); // We assume tuple table allocations will not fail. - tuple_index_table.emplace( - CpuDeviceMemory::Allocate(index_table_byte_size).value()); + CHECK_OK(CpuDeviceMemoryOwned::AllocateInto(index_table_byte_size, + tuple_index_table)); uintptr_t* index_table = reinterpret_cast(tuple_index_table->untyped_data()); for (int i = 0; i < buffers.size(); ++i) { diff --git a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.cc b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.cc index 6a3856e60a21b5..0bc6851c0808ac 100644 --- a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.cc @@ -81,93 +81,104 @@ tsl::AsyncValueRef AfterAll( } // namespace -// Creates non-owning CPU device memory from a raw data pointer. -CpuDeviceMemory::CpuDeviceMemory(void* data, size_t size_bytes) - : data_(data), size_bytes_(size_bytes) {} +class CpuDeviceMemoryForeign : public CpuDeviceMemory { + public: + CpuDeviceMemoryForeign(void* base, size_t size, + absl::AnyInvocable on_delete_callback) + : CpuDeviceMemory(base, size), + on_delete_callback_(std::move(on_delete_callback)) {} + ~CpuDeviceMemoryForeign() override { + if (on_delete_callback_) { + std::move(on_delete_callback_)(); + } + } + + private: + absl::AnyInvocable on_delete_callback_; +}; + +tsl::AsyncValueRef CpuDeviceMemory::CreateForeignMemory( + void* base, size_t size, absl::AnyInvocable on_delete_callback) { + return tsl::MakeAvailableAsyncValueRef( + base, size, std::move(on_delete_callback)); +} + +class CpuDeviceMemoryConstant : public CpuDeviceMemory { + public: + CpuDeviceMemoryConstant(void* base, size_t size) + : CpuDeviceMemory(base, size) {} + using CpuDeviceMemory::CpuDeviceMemory; +}; + +tsl::AsyncValueRef CpuDeviceMemory::CreateUnownedConstant( + void* base, size_t size) { + return tsl::MakeAvailableAsyncValueRef(base, size); +} -// Creates owning CPU device memory from an owned data pointer. -CpuDeviceMemory::CpuDeviceMemory(OwnedData data, size_t size_bytes) - : data_(data.get()), - owned_data_(std::move(data)), - size_bytes_(size_bytes) {} +CpuDeviceMemoryOwned::~CpuDeviceMemoryOwned() { + CHECK_NE(untyped_data(), nullptr); + tsl::port::AlignedSizedFree(untyped_data(), cpu::MinAlign(), size_bytes()); +} // Allocates owning memory wrapped in an available `AsyncValueRef`. absl::StatusOr> CpuDeviceMemory::AllocateAvailable(size_t size_bytes) { - TF_ASSIGN_OR_RETURN(CpuDeviceMemory memory, Allocate(size_bytes)); - return tsl::MakeAvailableAsyncValueRef(std::move(memory)); + if (void* data = tsl::port::AlignedMalloc(size_bytes, cpu::MinAlign())) { + return tsl::MakeAvailableAsyncValueRef(data, + size_bytes); + } + return ResourceExhausted("Out of memory allocating %d bytes.", size_bytes); } // Allocates raw owning memory. 
The typical usage is for delayed allocation. -absl::StatusOr CpuDeviceMemory::Allocate(size_t size_bytes) { +absl::Status CpuDeviceMemoryOwned::AllocateInto( + size_t size_bytes, tsl::AsyncValueRef& out) { if (void* data = tsl::port::AlignedMalloc(size_bytes, cpu::MinAlign())) { - return CpuDeviceMemory( - OwnedData{static_cast(data), tsl::port::AlignedFree}, - size_bytes); + out.emplace(data, size_bytes); + return absl::OkStatus(); } return ResourceExhausted("Out of memory allocating %d bytes.", size_bytes); } TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer( bool owns_buffers, tsl::AsyncValueRef buffer, - absl::InlinedVector, 4> definition_events, - absl::AnyInvocable on_delete_callback) + absl::InlinedVector, 4> definition_events) : TrackedCpuDeviceBuffer(owns_buffers, std::move(buffer), - AfterAll(definition_events), - std::move(on_delete_callback)) {} + AfterAll(definition_events)) {} TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer( bool owns_buffers, tsl::AsyncValueRef buffer, size_t buffer_size, - absl::InlinedVector, 4> definition_events, - absl::AnyInvocable on_delete_callback) + absl::InlinedVector, 4> definition_events) : TrackedCpuDeviceBuffer(owns_buffers, std::move(buffer), buffer_size, - AfterAll(definition_events), - std::move(on_delete_callback)) {} + AfterAll(definition_events)) {} TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer( bool owns_buffers, tsl::AsyncValueRef buffer, - tsl::AsyncValueRef definition_event, - absl::AnyInvocable on_delete_callback) + tsl::AsyncValueRef definition_event) : owns_buffers_(owns_buffers), - buffers_({std::move(buffer)}), - definition_event_(std::move(definition_event)), - on_delete_callback_(std::move(on_delete_callback)) { + buffer_(std::move(buffer)), + definition_event_(std::move(definition_event)) { DCHECK(definition_event_); - for (const auto& buffer : buffers_) { - CHECK(buffer.IsConcrete()); - buffer_sizes_.push_back(buffer->size_bytes()); - } - CHECK_EQ(buffers_.size(), 1); + CHECK(buffer_.IsConcrete()); + buffer_size_ = buffer_->size_bytes(); } TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer( bool owns_buffers, tsl::AsyncValueRef buffer, - size_t buffer_size, tsl::AsyncValueRef definition_event, - absl::AnyInvocable on_delete_callback) + size_t buffer_size, tsl::AsyncValueRef definition_event) : owns_buffers_(owns_buffers), - buffers_({std::move(buffer)}), - buffer_sizes_({buffer_size}), - definition_event_(std::move(definition_event)), - on_delete_callback_(std::move(on_delete_callback)) { + buffer_(std::move(buffer)), + buffer_size_(buffer_size), + definition_event_(std::move(definition_event)) { DCHECK(definition_event_); - CHECK_EQ(buffers_.size(), 1); } TrackedCpuDeviceBuffer::~TrackedCpuDeviceBuffer() { ReleaseDeviceMemory(); - if (on_delete_callback_) { - std::move(on_delete_callback_)(); - } -} - -tsl::AsyncValuePtr TrackedCpuDeviceBuffer::Buffer( - const ShapeIndex& shape_index) { - CHECK(shape_index.empty()); - return buffers_[0].AsPtr(); } -size_t TrackedCpuDeviceBuffer::BufferSize() { return buffer_sizes_[0]; } +size_t TrackedCpuDeviceBuffer::BufferSize() { return buffer_size_; } void TrackedCpuDeviceBuffer::AddUsageEvents( absl::Span> events) { @@ -196,7 +207,7 @@ TrackedCpuDeviceBuffer::LockUseAndTransferUsageEvents() { } void TrackedCpuDeviceBuffer::ReleaseDeviceMemory() { - buffers_.clear(); + buffer_ = tsl::AsyncValueRef(); definition_event_.reset(); usage_events_.clear(); } diff --git a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h index 
c33526af99c947..7bbad71a8172dd 100644 --- a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h @@ -35,32 +35,51 @@ namespace xla { // memory can be either owned or non-owned. class CpuDeviceMemory { public: - using OwnedData = std::unique_ptr; + virtual ~CpuDeviceMemory() = default; - CpuDeviceMemory() = default; - CpuDeviceMemory(CpuDeviceMemory&&) = default; - CpuDeviceMemory& operator=(CpuDeviceMemory&&) = default; + CpuDeviceMemory(const CpuDeviceMemory& other) = delete; + CpuDeviceMemory(CpuDeviceMemory&& other) = delete; + CpuDeviceMemory& operator=(const CpuDeviceMemory&) = delete; + CpuDeviceMemory& operator=(CpuDeviceMemory&&) = delete; - // Creates non-owning CPU device memory from a raw data pointer. - CpuDeviceMemory(void* data, size_t size_bytes); - - // Creates owning CPU device memory from an owned data pointer. - CpuDeviceMemory(OwnedData data, size_t size_bytes); + void* untyped_data() const { return base_; } + size_t size_bytes() const { return size_bytes_; } // Allocates owning memory wrapped in an available `AsyncValueRef`. static absl::StatusOr> AllocateAvailable( size_t size_bytes); - // Allocates raw owning memory. The typical usage is for delayed allocation. - static absl::StatusOr Allocate(size_t size_bytes); + // Creates an available asyncref to a CpuDeviceMemory that wraps foreign + // memory. Will call on_delete_callback on the last-ref. + static tsl::AsyncValueRef CreateForeignMemory( + void* base, size_t size, + absl::AnyInvocable on_delete_callback); - void* untyped_data() const { return data_; } - size_t size_bytes() const { return size_bytes_; } + // Creates an available asyncref to a CpuDeviceMemory that wraps an unowned + // constant. No action will be taken on decref. + static tsl::AsyncValueRef CreateUnownedConstant(void* base, + size_t size); - private: - void* data_ = nullptr; // non-owning data pointer - OwnedData owned_data_ = {nullptr, free}; // optional owning data pointer - size_t size_bytes_ = 0; + protected: + explicit CpuDeviceMemory(void* base, size_t size) + : base_(base), size_bytes_(size) {} + + void* base_; + size_t size_bytes_; +}; + +// CpuDeviceMemory that has been allocated explicitly by PjRt. +class CpuDeviceMemoryOwned : public CpuDeviceMemory { + public: + ~CpuDeviceMemoryOwned() final; + + // Allocates raw owning memory. The typical usage is for delayed allocation. + static absl::Status AllocateInto( + size_t size_bytes, tsl::AsyncValueRef& out); + + protected: + friend class tsl::internal::ConcreteAsyncValue; + using CpuDeviceMemory::CpuDeviceMemory; }; // A class that represents a CPU device buffer: it can be a single memory region @@ -77,14 +96,12 @@ class TrackedCpuDeviceBuffer { // states. Definition event is after the list of `definition_events`. TrackedCpuDeviceBuffer( bool owns_buffers, tsl::AsyncValueRef buffer, - absl::InlinedVector, 4> definition_events, - absl::AnyInvocable on_delete_callback = nullptr); + absl::InlinedVector, 4> definition_events); // Variant with single definition event. 
- TrackedCpuDeviceBuffer( - bool owns_buffers, tsl::AsyncValueRef buffer, - tsl::AsyncValueRef definition_event, - absl::AnyInvocable on_delete_callback = nullptr); + TrackedCpuDeviceBuffer(bool owns_buffers, + tsl::AsyncValueRef buffer, + tsl::AsyncValueRef definition_event); // Constructor for unallocated cpu memory, i.e., `buffer` will have // unconstructed states, and we also need to provide `buffer_size` which will @@ -94,14 +111,13 @@ class TrackedCpuDeviceBuffer { TrackedCpuDeviceBuffer( bool owns_buffers, tsl::AsyncValueRef buffer, size_t buffer_size, - absl::InlinedVector, 4> definition_events, - absl::AnyInvocable on_delete_callback = nullptr); + absl::InlinedVector, 4> definition_events); // Variant with single definition event. - TrackedCpuDeviceBuffer( - bool owns_buffers, tsl::AsyncValueRef buffer, - size_t buffer_size, tsl::AsyncValueRef definition_event, - absl::AnyInvocable on_delete_callback = nullptr); + TrackedCpuDeviceBuffer(bool owns_buffers, + tsl::AsyncValueRef buffer, + size_t buffer_size, + tsl::AsyncValueRef definition_event); TrackedCpuDeviceBuffer(TrackedCpuDeviceBuffer&&) noexcept = default; TrackedCpuDeviceBuffer& operator=(TrackedCpuDeviceBuffer&&) noexcept = @@ -109,11 +125,7 @@ class TrackedCpuDeviceBuffer { ~TrackedCpuDeviceBuffer(); - const tsl::AsyncValueRef& buffer() { return buffers_[0]; } - - absl::Span BufferSizes() { return buffer_sizes_; } - - tsl::AsyncValuePtr Buffer(const ShapeIndex& shape_index); + const tsl::AsyncValueRef& buffer() { return buffer_; } size_t BufferSize(); @@ -142,18 +154,15 @@ class TrackedCpuDeviceBuffer { bool owns_buffers_; // If non-tuple, `buffers_` contains 1 buffer; otherwise all leaf buffers. - absl::InlinedVector, 4> buffers_; + tsl::AsyncValueRef buffer_; // Should correspond to size of each buffer in `buffers_` when `buffers_` is // available. - absl::InlinedVector buffer_sizes_; + size_t buffer_size_; // The definition event are associated with CPU operations that write to the // buffers. tsl::AsyncValueRef definition_event_; // Usage events are associated with CPU operations that read from the buffers. absl::InlinedVector, 4> usage_events_; - // A callback to call when the TrackedCpuDeviceBuffer is about to be - // destroyed. - absl::AnyInvocable on_delete_callback_; }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer_test.cc b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer_test.cc index c01015c8472a95..57572a9d361543 100644 --- a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/log/check.h" #include "xla/pjrt/cpu/cpu_event.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -51,8 +52,7 @@ TEST(TrackedCpuDeviceBufferTest, Basic) { }); TrackedCpuDeviceBuffer tracked_buffer( - /*owns_buffers=*/true, buffer, definition_event, - /*on_delete_callback_=*/nullptr); + /*owns_buffers=*/true, buffer, definition_event); BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue()); @@ -77,8 +77,7 @@ TEST(TrackedCpuDeviceBufferTest, BasicError) { }); TrackedCpuDeviceBuffer tracked_buffer( - /*owns_buffers=*/true, buffer, definition_event, - /*on_delete_callback_=*/nullptr); + /*owns_buffers=*/true, buffer, definition_event); BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue()); @@ -90,20 +89,19 @@ TEST(TrackedCpuDeviceBufferTest, BasicError) { TEST(TrackedCpuDeviceBufferTest, DelayedAllocation) { std::string expected = "tracked_cpu_device_buffer_test"; - auto buffer = MakeUnconstructedAsyncValueRef(); + auto buffer = MakeUnconstructedAsyncValueRef(); auto malloc_event = MakeConstructedAsyncValueRef(); - malloc_event.AndThen( - [buffer_copy = buffer.CopyRef(), buffer_size = expected.size()] { - buffer_copy.emplace(CpuDeviceMemory::Allocate(buffer_size).value()); - }); + malloc_event.AndThen([buffer_copy = buffer.CopyRef(), + buffer_size = expected.size()]() mutable { + CHECK_OK(CpuDeviceMemoryOwned::AllocateInto(buffer_size, buffer_copy)); + }); auto definition_event = MakeConstructedAsyncValueRef(); TrackedCpuDeviceBuffer tracked_buffer(/*owns_buffers=*/true, buffer, - expected.size(), definition_event, - /*on_delete_callback_=*/nullptr); + expected.size(), definition_event); auto result = tracked_buffer.buffer(); ASSERT_FALSE(result.IsAvailable()); - ASSERT_EQ(tracked_buffer.BufferSizes()[0], expected.size()); + ASSERT_EQ(tracked_buffer.BufferSize(), expected.size()); ThreadPool thread_pool(tsl::Env::Default(), "tracked_buffer_test", /*num_threads=*/4); From 460fff279fa866db19fd9b4c8f842c4b3f1f19e8 Mon Sep 17 00:00:00 2001 From: Ezekiel Calubaquib Date: Thu, 3 Apr 2025 12:06:45 -0700 Subject: [PATCH 0203/1324] fix failing to build //tensorflow/core:portable_tensorflow_lib_lite PiperOrigin-RevId: 743650381 --- tensorflow/core/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 35218525f70f26..cda9e305cc4c60 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1037,6 +1037,7 @@ cc_library( "//tensorflow/core:mobile_additional_lib_deps", "//tensorflow/core/platform:resource", "//tensorflow/core/public:release_version", + "//tensorflow/core/util:onednn_env_vars", "//tensorflow/core/util:stats_calculator_portable", ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(), alwayslink = 1, From af3fe1d51aafcda852051a751b2f47093ce92457 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 3 Apr 2025 12:19:21 -0700 Subject: [PATCH 0204/1324] [PjRt-IFRT] Treat IFRT HloSharding with a single tile as a fully replicated sharding `xla::ifrt::HloSharding` that has a single tile would have the shard buffer to be the same as the global array both in the content and the shape. Thus, we treat it as a fully replicated sharding in the context of IFRT APIs even though `HloSharding` does not explicitly say full replication in the context of the XLA `HloSharding`. This allows taking a runtime path for array creation specialized for fully replicated sharding for a single-tile `HloSharding`. 
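
A small illustration of the reasoning, mirroring the new test below
(illustrative only):

    // A 1x1 tile assignment keeps the whole array in a single tile, so the
    // per-shard buffer is bit-identical to the global array.
    auto xla_hlo_sharding = xla::HloSharding::IotaTile({1, 1});
    // xla_hlo_sharding.IsReplicated() == false  (not replicated in XLA terms)
    // xla_hlo_sharding.IsTiled() == true, xla_hlo_sharding.NumTiles() == 1
    // With this change, the IFRT-level sharding built from it reports
    // IsFullyReplicated() == true and can take the fully replicated path.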
PiperOrigin-RevId: 743654908
---
 .../xla/xla/python/pjrt_ifrt/xla_sharding.cc  |  4 +-
 .../xla/python/pjrt_ifrt/xla_sharding_test.cc | 49 +++++++++++++++++--
 2 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc
index 4aaf505ce6819e..0dc54be5e1e429 100644
--- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc
+++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc
@@ -127,7 +127,9 @@ std::unique_ptr<HloSharding> HloSharding::Create(
 HloSharding::HloSharding(DeviceListRef devices, MemoryKind memory_kind,
                          xla::HloSharding xla_hlo_sharding)
     : llvm::RTTIExtends<HloSharding, XlaCompatibleSharding>(
-          std::move(devices), memory_kind, xla_hlo_sharding.IsReplicated()),
+          std::move(devices), memory_kind,
+          (xla_hlo_sharding.IsReplicated() ||
+           (xla_hlo_sharding.IsTiled() && xla_hlo_sharding.NumTiles() == 1))),
       xla_hlo_sharding_(std::move(xla_hlo_sharding)) {}
 
 absl::StatusOr<Shape> HloSharding::GetShardShape(const Shape& shape) const {

diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc
index 040e4826ffe3ed..8b6a522c6a7321 100644
--- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc
+++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc
@@ -64,22 +64,63 @@ TEST_P(HloShardingTest, CreateWithBadDeviceList) {
 TEST_P(HloShardingTest, IsFullyReplicated) {
   auto device_list = GetDevices({0, 1, 2, 3, 4, 5});
   {
-    // Fully replicated.
+    // Fully replicated HloSharding is fully replicated.
     auto xla_hlo_sharding = xla::HloSharding::Replicate();
     std::shared_ptr<const Sharding> sharding =
         HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding);
     EXPECT_TRUE(sharding->IsFullyReplicated());
   }
   {
-    // Not fully replicated.
+    // Single-tile HloSharding is fully replicated.
+    auto device_list = GetDevices({0});  // This sharding uses 1 device.
+    auto xla_hlo_sharding = xla::HloSharding::IotaTile({1, 1});
+    std::shared_ptr<const Sharding> sharding =
+        HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding);
+    EXPECT_TRUE(sharding->IsFullyReplicated());
+  }
+  {
+    // Multi-tile HloSharding with last_dim_replicate where all replicas are
+    // on the last tile dimension is fully replicated.
+    auto xla_hlo_sharding = xla::HloSharding::PartialTile(
+        xla::TileAssignment(xla::IotaTileAssignment::Create({1, 6})));
+    std::shared_ptr<const Sharding> sharding =
+        HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding);
+    EXPECT_TRUE(sharding->IsFullyReplicated());
+  }
+  {
+    // Multi-tile HloSharding with last_dim_replicate where not all replicas
+    // are on the last tile dimension is not fully replicated.
+    auto xla_hlo_sharding = xla::HloSharding::PartialTile(
+        xla::TileAssignment(xla::IotaTileAssignment::Create({2, 3})));
+    std::shared_ptr<const Sharding> sharding =
+        HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding);
+    EXPECT_FALSE(sharding->IsFullyReplicated());
+  }
+  {
+    // Multi-tile HloSharding with no last_dim_replicate is not fully
+    // replicated.
     auto xla_hlo_sharding = xla::HloSharding::IotaTile({1, 6});
     std::shared_ptr<const Sharding> sharding =
         HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding);
     EXPECT_FALSE(sharding->IsFullyReplicated());
   }
   {
-    // Not fully replicated.
-    auto xla_hlo_sharding = xla::HloSharding::IotaTile({2, 3});
+    // Maximal HloSharding is not fully replicated.
+    auto xla_hlo_sharding = xla::HloSharding::AssignDevice(/*device_id=*/0);
+    std::shared_ptr sharding =
+        HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding);
+    EXPECT_FALSE(sharding->IsFullyReplicated());
+  }
+  {
+    // Manual HloSharding is not fully replicated.
+    auto xla_hlo_sharding = xla::HloSharding::Manual();
+    std::shared_ptr sharding =
+        HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding);
+    EXPECT_FALSE(sharding->IsFullyReplicated());
+  }
+  {
+    // Unknown HloSharding is not fully replicated.
+    auto xla_hlo_sharding = xla::HloSharding::Unknown();
     std::shared_ptr sharding =
         HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding);
     EXPECT_FALSE(sharding->IsFullyReplicated());

From 37a0b48d692fb2756e87afa187fa93f918180f72 Mon Sep 17 00:00:00 2001
From: Greg Olechwierowicz
Date: Thu, 3 Apr 2025 12:20:54 -0700
Subject: [PATCH 0205/1324] [XLA:GPU] Add fast lookup interpolator.

Supports O(1) lookup via power-of-two complements and next-power-of-two
rounding.

PiperOrigin-RevId: 743655380
---
 third_party/xla/xla/service/gpu/model/BUILD   |   8 +-
 .../xla/xla/service/gpu/model/interpolator.h  | 101 +++++++++++-
 .../service/gpu/model/interpolator_test.cc    | 153 +++++++++++++-----
 3 files changed, 214 insertions(+), 48 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD
index 6d4514555f800b..acbabc297c174a 100644
--- a/third_party/xla/xla/service/gpu/model/BUILD
+++ b/third_party/xla/xla/service/gpu/model/BUILD
@@ -1044,7 +1044,11 @@ xla_cc_test(
 cc_library(
     name = "interpolator",
     hdrs = ["interpolator.h"],
-    deps = ["@com_google_absl//absl/log:check"],
+    deps = [
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
+    ],
 )

 xla_cc_test(
@@ -1052,6 +1056,8 @@ xla_cc_test(
     srcs = ["interpolator_test.cc"],
     deps = [
         ":interpolator",
+        "//xla/service/gpu:variant_visitor",
+        "@com_google_absl//absl/log",
        "@com_google_googletest//:gtest_main",
     ],
 )
diff --git a/third_party/xla/xla/service/gpu/model/interpolator.h b/third_party/xla/xla/service/gpu/model/interpolator.h
index 491407164e7418..76d3260fe819d1 100644
--- a/third_party/xla/xla/service/gpu/model/interpolator.h
+++ b/third_party/xla/xla/service/gpu/model/interpolator.h
@@ -20,10 +20,14 @@ limitations under the License.
 #include
 #include
 #include
+#include
+#include
 #include
 #include

+#include "absl/container/flat_hash_map.h"
 #include "absl/log/check.h"
+#include "absl/strings/str_join.h"

 namespace xla::gpu {

@@ -40,7 +44,7 @@ class InterpolatorBase {
   };

   // Returns interpolated value.
-  virtual R Eval(std::array& point) = 0;
+  virtual R Eval(std::array& point) const = 0;

  protected:
   std::vector, R>> plane_;

@@ -52,12 +56,10 @@ class InterpolatorBase {
 // is to make it aware of the n-dimensional grid properties (like a constant
 // distance per dimension between neighbouring points) which in turn can
 // shave off a bunch of time complexity.
-// TODO: Speed up NN retrieval if it happens to be a compilation bottleneck (by
-// rounding, k-d trees etc).
template class EuclideanNNInterpolator : public InterpolatorBase { public: - R Eval(std::array& point) override { + R Eval(std::array& point) const override { CHECK_GT(this->plane_.size(), 0); R result; @@ -75,7 +77,7 @@ class EuclideanNNInterpolator : public InterpolatorBase { private: int64_t Norm2(const std::array& lhs, - const std::array& rhs) { + const std::array& rhs) const { int64_t dist = 0; for (int i = 0; i < lhs.size(); ++i) { int coord = lhs[i]; @@ -86,6 +88,95 @@ class EuclideanNNInterpolator : public InterpolatorBase { } }; +template +class EuclideanComplementInterpolator : public EuclideanNNInterpolator { + public: + explicit EuclideanComplementInterpolator( + std::array next_context, + std::array next_power_context, + std::array max_context, std::array min_context) + : retrieval_ctx_(next_context), + retrieval_pow_ctx_(next_power_context), + max_ctx_(max_context), + min_ctx_(min_context) {} + + void Add(std::array& point, R val) { + EuclideanNNInterpolator::Add(point, val); + retrieval_[point] = val; + } + + R Eval(std::array& point) const override { + CHECK_GT(this->plane_.size(), 0); + std::array interpolation_point; + for (int i = 0; i < point.size(); ++i) { + std::optional next_potential_dim; + if (retrieval_ctx_[i] != -1) { + int64_t next = retrieval_ctx_[i]; + next_potential_dim = Closest(point[i], PrevComplement(point[i], next), + NextComplement(point[i], next)); + } + if (retrieval_pow_ctx_[i] != -1) { + next_potential_dim = Closest(point[i], PrevPowerOfTwo(point[i]), + NextPowerOfTwo(point[i])); + } + CHECK(next_potential_dim.has_value()); + interpolation_point[i] = + std::max(std::min(*next_potential_dim, max_ctx_[i]), min_ctx_[i]); + } + return retrieval_.at(interpolation_point); + } + + private: + int64_t Closest(int64_t n, int64_t prev, int64_t next) const { + if (n - prev < next - n) { + return prev; + } + return next; + } + + int64_t NextComplement(int64_t n, int64_t complement) const { + return (n + complement) & ~(complement - 1); + } + + int64_t PrevComplement(int64_t n, int64_t complement) const { + return n & ~(complement - 1); + } + + bool IsPowerOfTwo(int n) { + if (n <= 0) { + return false; + } + return (n & (n - 1)) == 0; + } + + int64_t PrevPowerOfTwo(int64_t n) const { return NextPowerOfTwo(n << 1); } + + int64_t NextPowerOfTwo(int64_t n) const { + if (n == 0) { + return 1; + } + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n |= n >> 32; + return n + 1; + } + + std::string PointStr(std::array point) const { + return absl::StrJoin(point, ", "); + } + + std::array retrieval_ctx_; + std::array retrieval_pow_ctx_; + std::array max_ctx_; + std::array min_ctx_; + + absl::flat_hash_map, R> retrieval_; +}; + } // namespace xla::gpu #endif // XLA_SERVICE_GPU_MODEL_INTERPOLATOR_H_ diff --git a/third_party/xla/xla/service/gpu/model/interpolator_test.cc b/third_party/xla/xla/service/gpu/model/interpolator_test.cc index 7223b3751c92c7..b5289b2412aa25 100644 --- a/third_party/xla/xla/service/gpu/model/interpolator_test.cc +++ b/third_party/xla/xla/service/gpu/model/interpolator_test.cc @@ -18,22 +18,40 @@ limitations under the License. 
#include #include #include +#include #include +#include #include +#include #include #include +#include "absl/log/log.h" +#include "xla/service/gpu/variant_visitor.h" namespace xla::gpu { namespace { +using ::testing::Combine; using ::testing::TestWithParam; -using ::testing::Values; +using ::testing::ValuesIn; + +enum class InterpolatorType { + NN = 0, + Complement = 1, +}; + +template +using Interpolator = + std::variant>, + std::unique_ptr>>; class InterpolatorFake : public InterpolatorBase { public: + ~InterpolatorFake() override = default; + // Fake eval function which just returns the size of the consumed set. - int Eval(std::array& x) override { return plane_.size(); } + int Eval(std::array& x) const override { return plane_.size(); } }; TEST(Interpolator, PersistsEuclideanPoints) { @@ -56,57 +74,108 @@ struct EuclideanNNInterpolatorTestCase { }; class EuclideanNN2DInterpolatorTest - : public TestWithParam> { + : public TestWithParam>> { void SetUp() override { - std::array p1 = {3, 4}; - std::array p2 = {5, 7}; - interpolator_.Add(p1, /*val=*/1); - interpolator_.Add(p2, /*val=*/2); + std::array p1 = {8, 16}; + std::array p2 = {8, 8}; + std::array p3 = {16, 8}; + std::array p4 = {16, 16}; + plane_.push_back({p1, 1}); + plane_.push_back({p2, 2}); + plane_.push_back({p3, 3}); + plane_.push_back({p4, 4}); } protected: - EuclideanNNInterpolator interpolator_; std::vector, int>> plane_; + + Interpolator DispatchInterpolator(InterpolatorType type) { + if (type == InterpolatorType::NN) { + auto interpolator = + std::make_unique>(); + return std::move(interpolator); + } + if (type == InterpolatorType::Complement) { + auto interpolator = + std::make_unique>( + /*next_context=*/std::array{8, 8}, + /*next_power_context=*/std::array{-1, -1}, + /*max_context=*/std::array{16, 16}, + /*min_context=*/std::array{8, 8}); + return std::move(interpolator); + } + LOG(FATAL) << "Unreachable."; + } }; TEST_P(EuclideanNN2DInterpolatorTest, ReturnsNearestNeighbour) { - auto param = GetParam(); - for (auto& [plane_point, val] : plane_) { - interpolator_.Add(plane_point, val); + InterpolatorType interpolator_type = std::get<0>(GetParam()); + auto param = std::get<1>(GetParam()); + + Interpolator interpolator = + DispatchInterpolator(interpolator_type); + for (const auto& point : plane_) { + std::array plane_point = point.first; + int val = point.second; + std::visit( + VariantVisitor{ + [&](const std::unique_ptr>& + nn) { return nn->Add(plane_point, val); }, + [&](const std::unique_ptr< + EuclideanComplementInterpolator>& comp) { + return comp->Add(plane_point, val); + }}, + interpolator); } - EXPECT_EQ(interpolator_.Eval(param.eval_point), param.expected_value); + std::visit( + VariantVisitor{ + [&](const std::unique_ptr>& nn) { + EXPECT_EQ(nn->Eval(param.eval_point), param.expected_value); + }, + [&](const std::unique_ptr< + EuclideanComplementInterpolator>& comp) { + EXPECT_EQ(comp->Eval(param.eval_point), param.expected_value); + }}, + interpolator); } -// We have 2 points on a 2D plane. 
-// X = {(3,4), (5,7)} -INSTANTIATE_TEST_SUITE_P(EuclideanNNInterpolator2DIntegerTest, - EuclideanNN2DInterpolatorTest, - Values(EuclideanNNInterpolatorTestCase{ - /*test_name=*/"near_first_point", - /*eval_point=*/{4, 3}, - /*expected_value=*/1, - }, - EuclideanNNInterpolatorTestCase{ - /*test_name=*/"near_second_point", - /*eval_point=*/{7, 5}, - /*expected_value=*/2, - }, - EuclideanNNInterpolatorTestCase{ - /*test_name=*/"nearer_only_by_one", - /*eval_point=*/{4, 6}, - /*expected_value=*/2, - }, - EuclideanNNInterpolatorTestCase{ - /*test_name=*/"extrapolate_first_point", - /*eval_point=*/{2, 3}, - /*expected_value=*/1, - }, - EuclideanNNInterpolatorTestCase{ - /*test_name=*/"extrapolate_second_point", - /*eval_point=*/{6, 8}, - /*expected_value=*/2, - }), - [](const auto& info) { return info.param.test_name; }); +// We have 4 points on a 2D plane. +// X = {(8,8), (8,16), (16,8), (16,16)} +INSTANTIATE_TEST_SUITE_P( + EuclideanNNInterpolator2DIntegerTest, EuclideanNN2DInterpolatorTest, + Combine(ValuesIn({InterpolatorType::NN, InterpolatorType::Complement}), + ValuesIn({ + EuclideanNNInterpolatorTestCase{ + /*test_name=*/"near_first_point", + /*eval_point=*/{7, 9}, + /*expected_value=*/2, + }, + EuclideanNNInterpolatorTestCase{ + /*test_name=*/"near_second_point", + /*eval_point=*/{15, 17}, + /*expected_value=*/4, + }, + EuclideanNNInterpolatorTestCase{ + /*test_name=*/"nearer_only_by_one", + /*eval_point=*/{13, 8}, + /*expected_value=*/3, + }, + EuclideanNNInterpolatorTestCase{ + /*test_name=*/"extrapolate_first_point", + /*eval_point=*/{7, 7}, + /*expected_value=*/2, + }, + EuclideanNNInterpolatorTestCase{ + /*test_name=*/"extrapolate_second_point", + /*eval_point=*/{17, 9}, + /*expected_value=*/3, + }, + })), + [](const auto& info) { + return absl::StrCat(std::get<1>(info.param).test_name, "x", + std::get<0>(info.param)); + }); } // namespace } // namespace xla::gpu From adea2a3bf98a9b383eadb387c404ada1b72f0051 Mon Sep 17 00:00:00 2001 From: pizzud Date: Thu, 3 Apr 2025 12:24:08 -0700 Subject: [PATCH 0206/1324] [XLA:GPU] sort_rewriter: Don't run SortRewriter if there's no GPU present. Even when the compiler is built with CUDA, the compilation may not occur on a machine with a GPU. Currently SortRewriter will always attempt to contact the driver, which will fail in such scenarios. Let's avoid that at the admitted runtime cost, especially since the error message ("invalid device ordinal") is not particularly clear. 
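In sketch form, the guard is as follows (condensed from the diff below; the
pass name is taken from the subject line, and names match gpu_compiler.cc):

    // Only add SortRewriter when a StreamExecutor is available: the pass
    // queries the device for CUB scratch-space sizes, which requires a
    // working driver connection.
    if (hlo_module->config().debug_options().xla_gpu_enable_cub_radix_sort()) {
      if (stream_exec != nullptr) {
        pipeline.AddPass<SortRewriter>();
      } else {
        LOG(WARNING) << "Using fallback sort algorithm; compile with a GPU "
                        "present to enable SortRewriter.";
      }
    }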
PiperOrigin-RevId: 743656391 --- third_party/xla/xla/service/gpu/gpu_compiler.cc | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 76ae8d1de9237c..65681113f4d6e2 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -657,7 +657,8 @@ absl::Status RunSPMDPasses( } absl::Status RunOptimizationPasses( - HloModule* hlo_module, const Compiler::TargetConfig& gpu_target_config, + HloModule* hlo_module, stream_executor::StreamExecutor* stream_exec, + const Compiler::TargetConfig& gpu_target_config, const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts) { const DebugOptions& debug_options = hlo_module->config().debug_options(); @@ -697,8 +698,16 @@ absl::Status RunOptimizationPasses( pipeline.AddPass(); pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); + // SortRewriter needs to ask the device how much scratch space is needed, + // which isn't feasible if we don't have a device. if (hlo_module->config().debug_options().xla_gpu_enable_cub_radix_sort()) { - pipeline.AddPass(); + if (stream_exec != nullptr) { + pipeline.AddPass(); + } else { + LOG(WARNING) << "Using fallback sort algorithm rather than SortRewriter, " + "which will be slower at runtime. To avoid this, " + "compile with a GPU present."; + } } // Comparison total order expander @@ -1371,7 +1380,8 @@ absl::Status GpuCompiler::OptimizeHloModule( TF_RETURN_IF_ERROR(RunPreSPMDPartitionerPasses(hlo_module)); TF_RETURN_IF_ERROR(RunSPMDPasses(hlo_module, gpu_target_config, layout_insensitive_algsimp_opts)); - TF_RETURN_IF_ERROR(RunOptimizationPasses(hlo_module, gpu_target_config, + TF_RETURN_IF_ERROR(RunOptimizationPasses(hlo_module, stream_exec, + gpu_target_config, layout_insensitive_algsimp_opts)); se::GpuComputeCapability gpu_version = device_description.gpu_compute_capability(); From 091dca9543d9b33e73d145a576ce496a91256a0b Mon Sep 17 00:00:00 2001 From: "Patrick J. LoPresti" Date: Thu, 3 Apr 2025 12:59:24 -0700 Subject: [PATCH 0207/1324] Fix NVCC+Clang build failure. "nvcc --compiler-bindir /path/to/clang" sets __clang__ while compiling CUDA code. This causes gpu_device_functions.h to think it is being compiled with Clang and try to use a Clang-specific function. Fixes #90578. --- tensorflow/core/util/gpu_device_functions.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/util/gpu_device_functions.h b/tensorflow/core/util/gpu_device_functions.h index bb9ff8c750b7d3..532b7a7209446d 100644 --- a/tensorflow/core/util/gpu_device_functions.h +++ b/tensorflow/core/util/gpu_device_functions.h @@ -194,7 +194,7 @@ __device__ const unsigned kGpuWarpAll = 0xffffffff; __device__ inline unsigned GpuLaneId() { unsigned int lane_id; #if GOOGLE_CUDA -#if __clang__ +#if __clang__ && !__NVCC__ return __nvvm_read_ptx_sreg_laneid(); #else // __clang__ asm("mov.u32 %0, %%laneid;" : "=r"(lane_id)); From f1560921d61e6b15657a2d58a487113fe39552ed Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Thu, 3 Apr 2025 12:25:53 -0700 Subject: [PATCH 0208/1324] PR #22541: [ROCm] Cleanup atomics support Imported from GitHub PR https://github.com/openxla/xla/pull/22541 Weaken the ordering barriers to match what atomicAdd does on rocm. Emulate fp16 atomic on top of packed fp16 atomic where possible. Also for bfloat16 atomics, albeit those don't get matched right now due to FloatNormalization. 
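Illustratively, the emulation rewrites a scalar f16 add at address `addr`
into a packed 2xf16 add on the enclosing 4-byte word. A host-side C++ sketch
of the address math (names are illustrative; the pass emits the equivalent
LLVM/MLIR ops):

    #include <cstdint>

    struct PackedF16Add {
      std::uintptr_t aligned_addr;  // addr & ~3: word holding both halves
      std::uint32_t packed_bits;    // f16 bits shifted into the right half
    };

    PackedF16Add PrepareF16Add(std::uintptr_t addr, std::uint16_t f16_bits) {
      std::uintptr_t aligned = addr & ~std::uintptr_t{3};
      // Byte offset within the word (0 or 2), times 8, gives the bit shift.
      std::uint32_t shift = (static_cast<std::uint32_t>(addr) & 2u) * 8u;
      return {aligned, static_cast<std::uint32_t>(f16_bits) << shift};
    }

    // The 32-bit word is then bitcast to vector<2xf16> and used as the
    // operand of a packed atomicrmw fadd; the untouched lane adds a zero
    // bit pattern (+0.0).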
Left in support for fp16 and bfloat16 vector atomics. We might enable the
vectorization for them in the future if we can prove the access satisfies
4-byte alignment.

Copybara import of the project:

--
b53668020f946207aa879eecd8e0b70173f75570 by Dragan Mladjenovic :

[ROCm] Cleanup atomics support

Merging this change closes #22541

PiperOrigin-RevId: 743656949
---
 .../emitters/transforms/atomic_rmw_utils.cc   |   1 +
 .../emitters/transforms/lower_tensors.cc      | 100 +++++++++++++-----
 .../xla/codegen/emitters/transforms/passes.td |   1 +
 .../transforms/tests/lower_tensors.mlir       |  72 +++++++++----
 .../tests/lower_xla_loops_to_scf.mlir         |   8 +-
 .../xla/service/gpu/tests/gpu_atomic_test.cc  |   2 +-
 .../xla/stream_executor/device_description.h  |   7 +-
 7 files changed, 138 insertions(+), 53 deletions(-)

diff --git a/third_party/xla/xla/codegen/emitters/transforms/atomic_rmw_utils.cc b/third_party/xla/xla/codegen/emitters/transforms/atomic_rmw_utils.cc
index b1126d3c3402e1..23388b80472898 100644
--- a/third_party/xla/xla/codegen/emitters/transforms/atomic_rmw_utils.cc
+++ b/third_party/xla/xla/codegen/emitters/transforms/atomic_rmw_utils.cc
@@ -104,6 +104,7 @@ std::optional> GetAtomicModifierParameters(
     return std::nullopt;
   }
   // Match the kind of the atomic op.
+  // TODO(rocm): Match bf16 ops
   mlir::Operation* modifier_op = &operations.front();
   auto kind = GetAtomicBinOp(modifier_op, element_type);
   if (!kind.has_value()) {
diff --git a/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc b/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc
index f7f425fa4ea54c..f4cddfe8c0b9fb 100644
--- a/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc
+++ b/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -808,6 +809,13 @@ class RewriteAtomicRMW : public OpRewritePattern {
   }

  private:
+  llvm::StringRef determinateScope() const {
+    if (device_spec_.IsAmdGpu()) {
+      return llvm::StringRef("agent-one-as");
+    }
+    return llvm::StringRef();
+  }
+
   // Certain computations, such as floating-point addition and integer
   // maximization, can be simply implemented using an LLVM atomic instruction.
   // If "computation" is one of this kind, emits code to do that and returns
@@ -826,8 +834,7 @@ class RewriteAtomicRMW : public OpRewritePattern {
     ml::AtomicBinOp atomic_bin_op = modifier_parameters->second;

     Location loc = op.getLoc();
-    bool is_amd = device_spec_.IsAmdGpu();
-    llvm::StringRef sync_scope = is_amd ? "agent" : "";
+    auto sync_scope = determinateScope();
     mlir::ImplicitLocOpBuilder b(loc, rewriter);
     Value addr = CreateGep(op.getInput(), op.getIndices(), b);
@@ -852,12 +859,12 @@ class RewriteAtomicRMW : public OpRewritePattern {
       }
       case ml::AtomicBinOp::fadd: {
         // TODO(b/336367154): Introduce an atomic_rmw op with the binOp attr.
-        return is_amd
+        return device_spec_.IsAmdGpu()
                    ?
emitAMDAtomicFAdd( - loc, modifier_arg, addr, sync_scope, + loc, modifier_arg, addr, device_spec_.gpu().rocm_compute_capability(), rewriter) : emitNVidiaAtomicFAdd( - loc, modifier_arg, addr, sync_scope, + loc, modifier_arg, addr, device_spec_.gpu().cuda_compute_capability(), rewriter, op); } @@ -872,7 +879,7 @@ class RewriteAtomicRMW : public OpRewritePattern { } LogicalResult emitNVidiaAtomicFAdd( - Location loc, Value modifier_arg, Value addr, llvm::StringRef sync_scope, + Location loc, Value modifier_arg, Value addr, const se::CudaComputeCapability& cuda_compute_capability, OpBuilder& b, AtomicRMWOp& op) const { Type element_type = modifier_arg.getType(); @@ -897,7 +904,7 @@ class RewriteAtomicRMW : public OpRewritePattern { } b.create(loc, ml::AtomicBinOp::fadd, addr, modifier_arg, - ml::AtomicOrdering::monotonic, sync_scope); + ml::AtomicOrdering::monotonic); return success(); } @@ -950,28 +957,71 @@ class RewriteAtomicRMW : public OpRewritePattern { } LogicalResult emitAMDAtomicFAdd( - Location loc, Value modifier_arg, Value addr, llvm::StringRef sync_scope, + Location loc, Value modifier_arg, Value addr, const se::RocmComputeCapability& rocm_compute_capability, OpBuilder& b) const { Type element_type = modifier_arg.getType(); - bool is_supported_f16_atomic = - element_type.isF16() && - rocm_compute_capability.has_fp16_atomics_support(); - if (!element_type.isF32() && !is_supported_f16_atomic) { + if (auto vector_type = dyn_cast_or_null(element_type)) { + // TODO(rocm) Don't vectorize atomics if we cannot satisfy 4-byte + // alignment + if (!(vector_type.getNumElements() == 2 && + (vector_type.getElementType().isF16() || + vector_type.getElementType().isBF16()))) { + return failure(); + } + } else if (!element_type.isF32() && !element_type.isF16() && + !element_type.isBF16() && !element_type.isF64()) { return failure(); } - constexpr int kGlobalMemory = 1; - constexpr int kSharedMemory = 3; - auto addr_type = mlir::cast(addr.getType()); - // adds to shared memory are always atomic. 
- if (addr_type.getAddressSpace() != kSharedMemory) { - // The compiler will only generate a global_atomic_fadd if the pointer - // is in global addrspace (1) - addr = b.create( - loc, ml::LLVMPointerType::get(b.getContext(), kGlobalMemory), addr); + + if ((element_type.isF16() && + rocm_compute_capability.has_packed_fp16_atomics_support()) || + (element_type.isBF16() && + rocm_compute_capability.has_packed_bf16_atomics_support())) { + auto packed_type = mlir::VectorType::get({2}, element_type); + auto i64_type = b.getI64Type(); + auto i32_type = b.getI32Type(); + auto i16_type = b.getI16Type(); + Value addr_int = b.create(loc, i64_type, addr); + Value addr_masked = b.create( + loc, addr_int, b.create(loc, i64_type, -4)); + + Value offset = b.create( + loc, b.create(loc, i32_type, addr_int), + b.create(loc, i32_type, 2)); + + Value shift = b.create( + loc, offset, b.create(loc, i32_type, 8)); + + Value modifier_int = b.create(loc, i16_type, modifier_arg); + + Value modifier_masked = b.create( + loc, b.create(loc, i32_type, modifier_int), shift); + + constexpr int kGlobalMemory = 1; + addr = b.create( + loc, ml::LLVMPointerType::get(b.getContext(), kGlobalMemory), + addr_masked); + + modifier_arg = b.create(loc, packed_type, modifier_masked); + element_type = packed_type; } - b.create(loc, ml::AtomicBinOp::fadd, addr, modifier_arg, - ml::AtomicOrdering::monotonic, sync_scope); + + auto op = b.create( + loc, ml::AtomicBinOp::fadd, addr, modifier_arg, + ml::AtomicOrdering::monotonic, "agent-one-as"); + + auto unitAttr = b.getUnitAttr(); + auto* rocdl = + op->getContext()->getOrLoadDialect(); + auto noRemoteMemHelper = rocdl->getNoRemoteMemoryAttrHelper(); + auto noFineMemHelper = rocdl->getNoFineGrainedMemoryAttrHelper(); + auto ignoreDenormalModeHelper = rocdl->getIgnoreDenormalModeAttrHelper(); + + noRemoteMemHelper.setAttr(op, unitAttr); + noFineMemHelper.setAttr(op, unitAttr); + ignoreDenormalModeHelper.setAttr(op, unitAttr); + return success(); } @@ -1182,11 +1232,13 @@ class RewriteAtomicRMW : public OpRewritePattern { new_value = CreateBitcast(b, op, result, atomic_ty); } + auto sync_scope = determinateScope(); + // Try saving the result atomically, retry if failed. 
Value cmpxchg = b.create( loc, addr, old_value, new_value, /*success_ordering=*/ml::AtomicOrdering::monotonic, - /*failure_ordering=*/ml::AtomicOrdering::monotonic); + /*failure_ordering=*/ml::AtomicOrdering::monotonic, sync_scope); Value next = b.create(cmpxchg, 0); Value ok = b.create(cmpxchg, 1); Value low_bit = b.create(b.getOneAttr(b.getI1Type())); diff --git a/third_party/xla/xla/codegen/emitters/transforms/passes.td b/third_party/xla/xla/codegen/emitters/transforms/passes.td index 24e7b71573d6d9..68fe4afe892f08 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/passes.td +++ b/third_party/xla/xla/codegen/emitters/transforms/passes.td @@ -99,6 +99,7 @@ def LowerTensorsPass : Pass<"xla-lower-tensors", "mlir::ModuleOp"> { "xla::gpu::XlaGpuDialect", "xla::XlaDialect", "mlir::vector::VectorDialect", + "mlir::ROCDL::ROCDLDialect", ]; let options = [ Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"", diff --git a/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir b/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir index ffda69a09bb090..2a927f8439857c 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir +++ b/third_party/xla/xla/codegen/emitters/transforms/tests/lower_tensors.mlir @@ -280,8 +280,8 @@ func.func @atomic_rmw_f16(%in: tensor<8xf16>, %i: index) // CHECK-NEXT: llvm.bitcast %[[VAR_TRUNC]] : i16 to f16 // CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16 // CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]] -// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}} -// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}} +// CHECK-DAG: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}} +// CHECK-DAG: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}} // CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]] // CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]] @@ -362,8 +362,8 @@ func.func @atomic_rmw_overwrite(%in: tensor<8xf16>, %i: index) // CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]]) // CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16 // CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]] -// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}} -// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}} +// CHECK-DAG: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}} +// CHECK-DAG: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}} // CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]] // CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]] @@ -407,16 +407,16 @@ func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) // CHECK: scf.while (%[[INIT:.*]] = %[[CURRENT_I32]]) // CHECK: %[[SHIFTED:.*]] = llvm.lshr %[[INIT]] // CHECK: %[[CURRENT:.*]] = llvm.trunc %[[SHIFTED]] -// CHECK: %[[MASKED_CURRENT_LO:.*]] = arith.andi %[[CURRENT]], %[[C_NEG16]] : i8 -// CHECK: %[[MASKED_VALUE_I8:.*]] = arith.andi %[[VALUE_I8]], %[[C15]] : i8 +// CHECK-DAG: %[[MASKED_VALUE_I8:.*]] = arith.andi %[[VALUE_I8]], %[[C15]] : i8 +// CHECK-DAG: %[[MASKED_CURRENT_LO:.*]] = arith.andi %[[CURRENT]], %[[C_NEG16]] : i8 // CHECK: %[[NEW_LO:.*]] = arith.ori %[[MASKED_CURRENT_LO]], %[[MASKED_VALUE_I8]] : i8 -// CHECK: %[[MASKED_CURRENT_HI:.*]] = arith.andi %[[CURRENT]], %[[C15]] : i8 -// CHECK: %[[VALUE_HI:.*]] = arith.shli %[[VALUE_I8]], %[[C4]] : i8 +// CHECK-DAG: %[[VALUE_HI:.*]] = arith.shli %[[VALUE_I8]], %[[C4]] : i8 +// CHECK-DAG: %[[MASKED_CURRENT_HI:.*]] = arith.andi %[[CURRENT]], %[[C15]] : i8 // CHECK: 
%[[NEW_HI:.*]] = arith.ori %[[MASKED_CURRENT_HI]], %[[VALUE_HI]] : i8 // CHECK: %[[NEW_VALUE:.*]] = arith.select %{{.*}}, %[[NEW_LO]], %[[NEW_HI]] : i8 // CHECK: %[[NEW_VALUE_I32:.*]] = llvm.zext %[[NEW_VALUE]] -// CHECK: %[[MASKED_INIT:.*]] = llvm.and %[[INIT]] -// CHECK: %[[NEW_VALUE_SHIFTED:.*]] = llvm.shl %[[NEW_VALUE_I32]] +// CHECK-DAG: %[[NEW_VALUE_SHIFTED:.*]] = llvm.shl %[[NEW_VALUE_I32]] +// CHECK-DAG: %[[MASKED_INIT:.*]] = llvm.and %[[INIT]] // CHECK: %[[NEW_INIT:.*]] = llvm.or %[[MASKED_INIT]], %[[NEW_VALUE_SHIFTED]] // CHECK: llvm.cmpxchg %{{.*}}, %[[INIT]], %[[NEW_INIT]] monotonic monotonic // CHECK: scf.condition @@ -552,14 +552,12 @@ func.func @direct_atomic_rmw_fadd_f32(%in: tensor<8xf32>, // CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f32 // CHECK-GFX908-MI100: %[[C2:.*]] = arith.constant 2 // CHECK-GFX908-MI100: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-GFX908-MI100: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1> -// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") monotonic +// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] syncscope("agent-one-as") monotonic // CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f32 // CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2 // CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr -// CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1> -// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") monotonic +// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] syncscope("agent-one-as") monotonic // ----- @@ -587,13 +585,41 @@ func.func @direct_atomic_rmw_fadd_f16(%in: tensor<8xf16>, // CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f16 -// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd +// CHECK-GFX908-MI100: %[[CST:.*]] = arith.constant 2 +// CHECK-GFX908-MI100: %[[C_NEG4:.*]] = llvm.mlir.constant(-4 : i64) : i64 +// CHECK-GFX908-MI100: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK-GFX908-MI100: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32 +// CHECK-GFX908-MI100: %[[ADDR:.*]] = llvm.getelementptr +// CHECK-GFX908-MI100: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]] +// CHECK-GFX908-MI100: %[[ADDR_MASKED:.*]] = llvm.and %[[ADDR_INT]], %[[C_NEG4]] +// CHECK-GFX908-MI100: %[[ADDR_TRUNC:.*]] = llvm.trunc %[[ADDR_INT]] +// CHECK-GFX908-MI100: %[[OFFSET:.*]] = llvm.and %[[ADDR_TRUNC]], %[[C2]] +// CHECK-GFX908-MI100: %[[SHIFT:.*]] = llvm.mul %[[OFFSET]], %[[C8]] +// CHECK-GFX908-MI100: %[[VAL_INT:.*]] = llvm.bitcast %[[CST]] : f16 to i16 +// CHECK-GFX908-MI100: %[[VAL_WIDE:.*]] = llvm.zext %[[VAL_INT]] : i16 to i32 +// CHECK-GFX908-MI100: %[[VAL_SHIFT:.*]] = llvm.shl %[[VAL_WIDE]], %[[SHIFT]] +// CHECK-GFX908-MI100: %[[ADDR:.*]] = llvm.inttoptr %[[ADDR_MASKED]] +// CHECK-GFX908-MI100: %[[VAL:.*]] = llvm.bitcast %[[VAL_SHIFT]] : i32 to vector<2xf16> +// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR]], %[[VAL]] syncscope("agent-one-as") monotonic + // CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f16 -// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2 +// CHECK-GFX90A-MI200: %[[CST:.*]] = arith.constant 2 +// CHECK-GFX90A-MI200: %[[C_NEG4:.*]] = llvm.mlir.constant(-4 : i64) : i64 +// CHECK-GFX90A-MI200: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK-GFX90A-MI200: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr -// 
CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1> -// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") monotonic +// CHECK-GFX90A-MI200: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]] +// CHECK-GFX90A-MI200: %[[ADDR_MASKED:.*]] = llvm.and %[[ADDR_INT]], %[[C_NEG4]] +// CHECK-GFX90A-MI200: %[[ADDR_TRUNC:.*]] = llvm.trunc %[[ADDR_INT]] +// CHECK-GFX90A-MI200: %[[OFFSET:.*]] = llvm.and %[[ADDR_TRUNC]], %[[C2]] +// CHECK-GFX90A-MI200: %[[SHIFT:.*]] = llvm.mul %[[OFFSET]], %[[C8]] +// CHECK-GFX90A-MI200: %[[VAL_INT:.*]] = llvm.bitcast %[[CST]] : f16 to i16 +// CHECK-GFX90A-MI200: %[[VAL_WIDE:.*]] = llvm.zext %[[VAL_INT]] : i16 to i32 +// CHECK-GFX90A-MI200: %[[VAL_SHIFT:.*]] = llvm.shl %[[VAL_WIDE]], %[[SHIFT]] +// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.inttoptr %[[ADDR_MASKED]] +// CHECK-GFX90A-MI200: %[[VAL:.*]] = llvm.bitcast %[[VAL_SHIFT]] : i32 to vector<2xf16> +// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR]], %[[VAL]] syncscope("agent-one-as") monotonic // ----- @@ -643,10 +669,16 @@ func.func @direct_atomic_rmw_fadd_f64(%in: tensor<8xf64>, // CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] monotonic // CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f64 -// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd +// CHECK-GFX908-MI100: %[[C2:.*]] = arith.constant 2 +// CHECK-GFX908-MI100: %[[ADDR:.*]] = llvm.getelementptr +// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] syncscope("agent-one-as") monotonic + // CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f64 -// CHECK-GFX90A-MI200-NOT: llvm.atomicrmw fadd +// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2 +// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr +// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] syncscope("agent-one-as") monotonic + // ----- diff --git a/third_party/xla/xla/codegen/emitters/transforms/tests/lower_xla_loops_to_scf.mlir b/third_party/xla/xla/codegen/emitters/transforms/tests/lower_xla_loops_to_scf.mlir index 798e72fc1e3016..bbe7f9a6f11060 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/tests/lower_xla_loops_to_scf.mlir +++ b/third_party/xla/xla/codegen/emitters/transforms/tests/lower_xla_loops_to_scf.mlir @@ -37,11 +37,11 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // CHECK: %[[INDEX:.*]] = xla.apply_indexing // CHECK-SAME: #[[$MAP]](%[[DIM]])[%[[I]], %[[J]]] -// CHECK: %[[VAL1:.*]] = arith.cmpi sge, %[[INDEX]], %[[C0]] : index -// CHECK: %[[VAL2:.*]] = arith.cmpi sle, %[[INDEX]], %[[C90]] : index +// CHECK-DAG: %[[VAL2:.*]] = arith.cmpi sle, %[[INDEX]], %[[C90]] : index +// CHECK-DAG: %[[VAL1:.*]] = arith.cmpi sge, %[[INDEX]], %[[C0]] : index // CHECK: %[[VAL3:.*]] = arith.andi %[[VAL1]], %[[VAL2]] : i1 -// CHECK: %[[VAL4:.*]] = arith.cmpi sge, %[[DIM]], %[[C0]] : index -// CHECK: %[[VAL5:.*]] = arith.cmpi sle, %[[DIM]], %[[C3]] : index +// CHECK-DAG: %[[VAL5:.*]] = arith.cmpi sle, %[[DIM]], %[[C3]] : index +// CHECK-DAG: %[[VAL4:.*]] = arith.cmpi sge, %[[DIM]], %[[C0]] : index // CHECK: %[[VAL6:.*]] = arith.andi %[[VAL4]], %[[VAL5]] : i1 // CHECK: %[[INBOUNDS:.*]] = arith.andi %[[VAL3]], %[[VAL6]] : i1 // CHECK: %[[IF_RESULT:.*]] = scf.if %[[INBOUNDS]] -> (f32) { diff --git a/third_party/xla/xla/service/gpu/tests/gpu_atomic_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_atomic_test.cc index 53cdf60e56c99a..c3eac7e9df6a3a 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_atomic_test.cc +++ 
b/third_party/xla/xla/service/gpu/tests/gpu_atomic_test.cc
@@ -101,7 +101,7 @@ TEST_F(GpuAtomicTest, TestAddAtomicF32) {
 )";

   CompileAndVerifyIr(hlo_string, is_built_with_rocm_ ? R"(
-CHECK: atomicrmw fadd ptr addrspace(1) %[[ADDR:.*]], float %[[VALUE:.*]] syncscope("agent") monotonic
+CHECK: atomicrmw fadd ptr %[[ADDR:.*]], float %[[VALUE:.*]] syncscope("agent-one-as") monotonic
 )"
                                                      : R"(
 CHECK: atomicrmw fadd ptr %[[ADDR:.*]], float %[[VALUE:.*]] monotonic
diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h
index 476e3a0257135e..8263373452617f 100644
--- a/third_party/xla/xla/stream_executor/device_description.h
+++ b/third_party/xla/xla/stream_executor/device_description.h
@@ -111,10 +111,9 @@ class RocmComputeCapability {
             gfx_version().find("gfx12"));
   }

-  bool has_fp16_atomics_support() const {
-    // TODO(rocm): Check. This should be the same as has_fast_fp16_support().
-    return gfx9_mi200_or_later();
-  }
+  bool has_packed_fp16_atomics_support() const { return gfx9_mi100_or_later(); }
+
+  bool has_packed_bf16_atomics_support() const { return gfx9_mi300_series(); }

   bool fence_before_barrier() const {
     return gfx_version() != "gfx900" && gfx_version() != "gfx906";

From 945522c8d3c51742e70d9bafef8cd3c45887f509 Mon Sep 17 00:00:00 2001
From: Haibo Huang
Date: Thu, 3 Apr 2025 12:28:57 -0700
Subject: [PATCH 0209/1324] Clean up TfrtGpuAsyncHostToDeviceTransferManager

Non-const fields should all be protected by the lock.

PiperOrigin-RevId: 743657986
---
 .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 81 +++++++++++--------
 1 file changed, 47 insertions(+), 34 deletions(-)

diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc
index b78f462e460cc0..155a1cf07de8b2 100644
--- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc
+++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc
@@ -273,6 +273,7 @@ class TfrtGpuAsyncHostToDeviceTransferManager final
         definition_events;
     absl::InlinedVector device_shapes;
     buffers.reserve(shape_specs.size());
+    buffer_ptrs.reserve(shape_specs.size());
     definition_events.reserve(shape_specs.size());
     device_shapes.reserve(shape_specs.size());
     for (int i = 0; i < shape_specs.size(); ++i) {
@@ -331,6 +332,7 @@ class TfrtGpuAsyncHostToDeviceTransferManager final
             tsl::Env::Default(),
             "TfrtGpuAsyncHostToDeviceTransferManager_h2d_thread")),
         buffer_ptrs_(std::move(buffer_ptrs)),
+        buffer_sizes_(GetBufferSizes(buffers_)),
        definition_events_(std::move(definition_events)),
         device_shapes_(std::move(device_shapes)),
         remaining_buffer_count_(buffers_.size()),
@@ -340,10 +342,6 @@
         "TfrtGpuAsyncHostToDeviceTransferManager: this=" << this
         << " buffers_.size()=" << buffers_.size();

-    buffer_sizes_.reserve(buffers_.size());
-    for (const auto& buffer : buffers_) {
-      buffer_sizes_.push_back(buffer->GetOnDeviceSizeInBytes().value());
-    }
     last_transfer_started_.resize(buffer_ptrs_.size(), false);
   }

@@ -360,7 +358,7 @@
     }
   }

-  size_t buffer_count() const override { return buffers_.size(); };
+  size_t buffer_count() const override { return buffer_sizes_.size(); };

   size_t buffer_size(int buffer_index) const override {
     DCHECK_LT(buffer_index, buffer_sizes_.size());
@@ -370,6 +368,7 @@
   PjRtDevice* device() const override { return device_; }

   std::unique_ptr RetrieveBuffer(int buffer_index) override {
+
absl::MutexLock l(&mu_); DCHECK_LT(buffer_index, buffers_.size()); return std::move(buffers_[buffer_index]); }; @@ -410,7 +409,9 @@ class TfrtGpuAsyncHostToDeviceTransferManager final // because it includes linearization that may be slow. // TODO(misard) assess if it would be preferable to introduce a heuristic // to put the transfer into the calling thread for small literals. - auto transfer_h2d = [this, buffer_index, transfer_manager, literal, buffer, + auto transfer_h2d = [this, buffer_index, transfer_manager, + literal = std::move(literal), + buffer = std::move(buffer), on_done = std::move(on_done)]() mutable { tsl::profiler::TraceMe traceme( "TfrtGpuAsyncHostToDeviceTransferManager::TransferLiteralToBuffer::" @@ -477,33 +478,33 @@ class TfrtGpuAsyncHostToDeviceTransferManager final std::unique_ptr> staging_buffer = host_memory_allocator->Allocate(transfer_size); - absl::ReleasableMutexLock l(&mu_); - DCHECK_LT(buffer_index, buffer_ptrs_.size()); - if (last_transfer_started_[buffer_index]) { - return InvalidArgument( - "TransferRawData requested for buffer index %d which has " - "already been fully transferred", - buffer_index); - } - if (is_last_transfer) { - last_transfer_started_[buffer_index] = true; - } - DCHECK(buffer_ptrs_[buffer_index]); - tsl::AsyncValueRef& buffer_memory = - buffer_ptrs_[buffer_index]; se::DeviceMemoryBase sub_buffer; - CHECK_LE(offset, buffer_memory->size()); - CHECK_LE(transfer_size, buffer_memory->size() - offset); - if (transfer_size < buffer_memory->size()) { - sub_buffer = buffer_memory->buffer().GetByteSlice(offset, transfer_size); - } else { - sub_buffer = buffer_memory->buffer(); - } + { + absl::MutexLock l(&mu_); + DCHECK_LT(buffer_index, buffer_ptrs_.size()); + if (last_transfer_started_[buffer_index]) { + return InvalidArgument( + "TransferRawData requested for buffer index %d which has " + "already been fully transferred", + buffer_index); + } + if (is_last_transfer) { + last_transfer_started_[buffer_index] = true; + } + DCHECK(buffer_ptrs_[buffer_index]); + tsl::AsyncValueRef& buffer_memory = + buffer_ptrs_[buffer_index]; + CHECK_LE(offset, buffer_memory->size()); + CHECK_LE(transfer_size, buffer_memory->size() - offset); + if (transfer_size < buffer_memory->size()) { + sub_buffer = + buffer_memory->buffer().GetByteSlice(offset, transfer_size); + } else { + sub_buffer = buffer_memory->buffer(); + } - ++transfers_in_flight_; - // Release the lock before transfer in case transfer or cleanup could be - // called on this thread, to avoid deadlock. - l.Release(); + ++transfers_in_flight_; + } auto copy_to_gpu = [transfer_size, staging_buffer = std::move(staging_buffer), data, @@ -546,6 +547,16 @@ class TfrtGpuAsyncHostToDeviceTransferManager final void AddTransferMetadata(const TransferMetadata& meta) override {} private: + static absl::InlinedVector GetBufferSizes( + absl::InlinedVector, 4>& buffers) { + absl::InlinedVector buffer_sizes; + buffer_sizes.reserve(buffers.size()); + for (const auto& buffer : buffers) { + buffer_sizes.push_back(buffer->GetOnDeviceSizeInBytes().value()); + } + return buffer_sizes; + } + void CleanUp(int buffer_index, bool is_last_transfer, absl::AnyInvocable on_done) { { @@ -579,7 +590,8 @@ class TfrtGpuAsyncHostToDeviceTransferManager final absl::Mutex mu_; // The newly created buffers, which will be returned to the caller via // Retrieve. - absl::InlinedVector, 4> buffers_; + absl::InlinedVector, 4> buffers_ + ABSL_GUARDED_BY(mu_); // Just a single thread, to ensure transfers are ordered. 
Its lifetime is // managed by H2DTransferManager. We assume `h2d_thread` is destructed before @@ -587,10 +599,11 @@ class TfrtGpuAsyncHostToDeviceTransferManager final // threads managed by `client_`. std::unique_ptr h2d_thread_; - absl::InlinedVector, 4> buffer_ptrs_; + absl::InlinedVector, 4> buffer_ptrs_ + ABSL_GUARDED_BY(mu_); // Cached versions of the sizes of all the buffers, so we can return them // without acquiring mu_. - absl::InlinedVector buffer_sizes_; + const absl::InlinedVector buffer_sizes_; // True if the last transfer for a buffer has been initiated. Used to // prevent a client initiating another transfer after the last transfer has // already been initiated. From 735e3e631e111556c271af72d74ee4f9f609aca5 Mon Sep 17 00:00:00 2001 From: Matthias Guenther Date: Thu, 3 Apr 2025 12:37:41 -0700 Subject: [PATCH 0210/1324] Integrate StableHLO at openxla/stablehlo@4bf77d23 PiperOrigin-RevId: 743660800 --- third_party/stablehlo/temporary.patch | 576 ------------------ third_party/stablehlo/workspace.bzl | 4 +- .../xla/third_party/stablehlo/temporary.patch | 576 ------------------ .../xla/third_party/stablehlo/workspace.bzl | 4 +- 4 files changed, 4 insertions(+), 1156 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 949ebc772ae60e..8b137891791fe9 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,577 +1 @@ -diff --ruN a/stablehlo/stablehlo/dialect/Serialization.cpp b/stablehlo/stablehlo/dialect/Serialization.cpp ---- stablehlo/stablehlo/dialect/Serialization.cpp -+++ stablehlo/stablehlo/dialect/Serialization.cpp -@@ -32,20 +32,25 @@ - #include "stablehlo/dialect/VhloOps.h" - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { - - LogicalResult serializePortableArtifact(ModuleOp module, - StringRef targetVersion, -- raw_ostream& os) { -+ raw_ostream& os, -+ bool allowOtherDialects) { - MLIRContext* context = module.getContext(); - -- // Convert StableHLO --> VHLO. Will fail if entire program is not StableHLO. -+ // Convert StableHLO --> VHLO. -+ // If allowOtherDialects is true, we will allow other dialects to be present -+ // in the module, otherwise will fail if there are any other dialects present. - { - PassManager pm(context); -- pm.addPass(stablehlo::createStablehloLegalizeToVhloPass()); -+ StablehloLegalizeToVhloPassOptions options; -+ options.allowOtherDialects = allowOtherDialects; -+ pm.addPass(stablehlo::createStablehloLegalizeToVhloPass(options)); - if (!succeeded(pm.run(module))) { - return failure(); - } -diff --ruN a/stablehlo/stablehlo/dialect/Serialization.h b/stablehlo/stablehlo/dialect/Serialization.h ---- stablehlo/stablehlo/dialect/Serialization.h -+++ stablehlo/stablehlo/dialect/Serialization.h -@@ -34,7 +34,8 @@ - // unsupported dialects. 
- LogicalResult serializePortableArtifact(ModuleOp module, - StringRef targetVersion, -- raw_ostream& os); -+ raw_ostream& os, -+ bool allowOtherDialects = false); - - // Read StableHLO portable artifact - // -diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.cpp b/stablehlo/stablehlo/dialect/VhloTypes.cpp ---- stablehlo/stablehlo/dialect/VhloTypes.cpp -+++ stablehlo/stablehlo/dialect/VhloTypes.cpp -@@ -323,6 +323,20 @@ - } - - namespace { -+Value materializeIllegalCast(OpBuilder& builder, Type type, ValueRange inputs, -+ Location loc) { -+ return builder.create(loc, type, inputs) -+ ->getResult(0); -+} -+} // namespace -+ -+void VhloTypeConverter::addUnrealizedMaterializations() { -+ addTargetMaterialization(materializeIllegalCast); -+ addSourceMaterialization(materializeIllegalCast); -+ addArgumentMaterialization(materializeIllegalCast); -+} -+ -+namespace { - // Helper functions for VHLO verifiers - template - bool isFromVhlo(TypeOrAttr t) { -diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/dialect/VhloTypes.h ---- stablehlo/stablehlo/dialect/VhloTypes.h -+++ stablehlo/stablehlo/dialect/VhloTypes.h -@@ -55,6 +55,9 @@ - // it is likely that a subclass should call this last, especially if a default - // `Type -> Type` fallback conversion is registered. - void addBuiltinToVhloConversions(); -+ -+ // Mark unrealized casts as legal. Useful for dialect mixing. -+ void addUnrealizedMaterializations(); - }; - - // Autogenerated VHLO type printers and parsers. -diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir ---- stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir -+++ stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir -@@ -1338,7 +1338,7 @@ - - // ----- - --// expected-error@+1 {{scale out of expressed type range}} -+// expected-error@+1 {{scale 1.055040e+05 out of expressed type range}} - func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform>) { - %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> - func.return -@@ -1346,7 +1346,7 @@ - - // ----- - --// expected-error@+1 {{scale out of expressed type range}} -+// expected-error@+1 {{scale 4.960464e-08 out of expressed type range}} - func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform>) { - %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> - func.return -diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir ---- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir -+++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir -@@ -0,0 +1,189 @@ -+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -+ -+// The script is designed to make adding checks to -+// a test case fast, it is *not* designed to be authoritative -+// about what constitutes a good test! The CHECK should be -+// minimized and named to reflect the test intent. -+ -+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -+ -+// The script is designed to make adding checks to -+// a test case fast, it is *not* designed to be authoritative -+// minimized and named to reflect the test intent. 
-+ -+// RUN: stablehlo-opt %s --stablehlo-legalize-to-vhlo=allow-other-dialects | FileCheck %s -+// RUN: stablehlo-opt %s > %t.0 -+// RUN: stablehlo-opt %s --stablehlo-legalize-to-vhlo=allow-other-dialects | stablehlo-opt --vhlo-legalize-to-stablehlo > %t.1 -+// RUN: diff %t.0 %t.1 -+ -+// CHECK-LABEL: vhlo.func_v1 @op_other( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { -+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_2:.*]] = arith.addf %[[VAL_1]], %[[VAL_1]] : tensor -+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : tensor to !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_3]]) : (!vhlo.tensor_v1) -> () -+// CHECK: } -+func.func @op_other(%arg0: tensor) -> tensor { -+ %0 = arith.addf %arg0, %arg0 : tensor -+ return %0 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @op_shlo( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { -+// CHECK: %[[VAL_1:.*]] = "vhlo.add_v1"(%[[VAL_0]], %[[VAL_0]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_1]]) : (!vhlo.tensor_v1) -> () -+// CHECK: } -+func.func @op_shlo(%arg0: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg0 : tensor -+ return %0 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @mixed_shlo_other_shlo( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { -+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_2:.*]] = "vhlo.abs_v1"(%[[VAL_0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_3]], %[[VAL_1]] : tensor -+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : tensor to !vhlo.tensor_v1 -+// CHECK: %[[VAL_6:.*]] = "vhlo.abs_v1"(%[[VAL_5]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_6]]) : (!vhlo.tensor_v1) -> () -+// CHECK: } -+func.func @mixed_shlo_other_shlo(%arg0: tensor) -> tensor { -+ %0 = stablehlo.abs %arg0 : tensor -+ %1 = arith.addf %0, %arg0 : tensor -+ %2 = stablehlo.abs %1 : tensor -+ return %2 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @mixed_other_shlo_other( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { -+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_2:.*]] = arith.addf %[[VAL_1]], %[[VAL_1]] : tensor -+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : tensor to !vhlo.tensor_v1 -+// CHECK: %[[VAL_4:.*]] = "vhlo.add_v1"(%[[VAL_3]], %[[VAL_0]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[VAL_1]] : tensor -+// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : tensor to !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_7]]) : (!vhlo.tensor_v1) -> () -+// CHECK: } -+func.func @mixed_other_shlo_other(%arg0: tensor) -> tensor { -+ %0 = arith.addf %arg0, %arg0 : tensor -+ %1 = stablehlo.add %0, %arg0 : tensor -+ %2 = arith.addf %1, %arg0 : tensor -+ return %2 : tensor -+} -+ -+// 
----- -+ -+// CHECK-LABEL: vhlo.func_v1 @op_with_region( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<1x16x16x320x!vhlo.f32_v1>, -+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1<1x320x!vhlo.f32_v1>) { -+// CHECK: %[[VAL_2:.*]] = "vhlo.reduce_v1"(%[[VAL_0]], %[[VAL_1]]) <{dimensions = #{{.*}} : tensor<2xi64>>}> ({ -+// CHECK: ^bb0(%[[VAL_3:.*]]: !vhlo.tensor_v1, %[[VAL_4:.*]]: !vhlo.tensor_v1): -+// CHECK: %[[VAL_5:.*]] = "vhlo.add_v1"(%[[VAL_3]], %[[VAL_4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_5]]) : (!vhlo.tensor_v1) -> () -+// CHECK: }) : (!vhlo.tensor_v1<1x16x16x320x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<1x320x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_2]]) : (!vhlo.tensor_v1<1x320x!vhlo.f32_v1>) -> () -+// CHECK: } -+func.func @op_with_region(%arg0: tensor<1x16x16x320xf32>, %arg1: tensor) -> tensor<1x320xf32> { -+ %0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.add across dimensions = [1, 2] : (tensor<1x16x16x320xf32>, tensor) -> tensor<1x320xf32> -+ return %0 : tensor<1x320xf32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @op_with_region_mixed_other_shlo_other( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<7x5x!vhlo.f32_v1>, -+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<5x!vhlo.f32_v1>) { -+// CHECK: %[[VAL_2:.*]] = "vhlo.reduce_v1"(%[[VAL_0]], %[[VAL_1]]) <{dimensions = #{{.*}} : tensor<1xi64>>}> ({ -+// CHECK: ^bb0(%[[VAL_3:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>, %[[VAL_4:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>): -+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_3]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_6]], %[[VAL_5]] : tensor<5xf32> -+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : tensor<5xf32> to !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: %[[VAL_9:.*]] = "vhlo.add_v1"(%[[VAL_8]], %[[VAL_3]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>, !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: %[[VAL_10:.*]] = builtin.unrealized_conversion_cast %[[VAL_9]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_10]], %[[VAL_5]] : tensor<5xf32> -+// CHECK: %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : tensor<5xf32> to !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_12]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> () -+// CHECK: }) : (!vhlo.tensor_v1<7x5x!vhlo.f32_v1>, !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_2]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> () -+// CHECK: } -+func.func @op_with_region_mixed_other_shlo_other(%arg0: tensor<7x5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { -+ %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32> -+ reducer(%arg2: tensor<5xf32>, %arg3: tensor<5xf32>) { -+ %1 = arith.addf %arg2, %arg3 : tensor<5xf32> -+ %2 = stablehlo.add %1, %arg2 : tensor<5xf32> -+ %3 = arith.addf %2, %arg3 : tensor<5xf32> -+ stablehlo.return %3 : tensor<5xf32> -+ } -+ return %0 : tensor<5xf32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 
@op_with_region_mixed_shlo_other_shlo( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<7x5x!vhlo.f32_v1>, -+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<5x!vhlo.f32_v1>) { -+// CHECK: %[[VAL_2:.*]] = "vhlo.reduce_v1"(%[[VAL_0]], %[[VAL_1]]) <{dimensions = #{{.*}} : tensor<1xi64>>}> ({ -+// CHECK: ^bb0(%[[VAL_3:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>, %[[VAL_4:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>): -+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_6:.*]] = "vhlo.abs_v1"(%[[VAL_3]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_7]], %[[VAL_5]] : tensor<5xf32> -+// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : tensor<5xf32> to !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: %[[VAL_10:.*]] = "vhlo.abs_v1"(%[[VAL_9]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_10]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> () -+// CHECK: }) : (!vhlo.tensor_v1<7x5x!vhlo.f32_v1>, !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_2]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> () -+// CHECK: } -+func.func @op_with_region_mixed_shlo_other_shlo(%arg0: tensor<7x5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { -+ %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32> -+ reducer(%arg2: tensor<5xf32>, %arg3: tensor<5xf32>) { -+ %1 = stablehlo.abs %arg2 : tensor<5xf32> -+ %2 = arith.addf %1, %arg3 : tensor<5xf32> -+ %3 = stablehlo.abs %2 : tensor<5xf32> -+ stablehlo.return %3 : tensor<5xf32> -+ } -+ return %0 : tensor<5xf32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @stablehlo_in_other_op_region( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<2x!vhlo.f32_v1>, -+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.index_v1) -> (!vhlo.tensor_v1<2x!vhlo.f32_v1>) { -+// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1<2x!vhlo.f32_v1> to tensor<2xf32> -+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index -+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index -+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index -+// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -+// CHECK: %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_9:.*]] = %[[VAL_2]]) -> (tensor<2xf32>) { -+// CHECK: %[[VAL_10:.*]] = tensor.insert %[[VAL_6]] into %[[VAL_9]]{{\[}}%[[VAL_8]]] : tensor<2xf32> -+// CHECK: %[[VAL_11:.*]] = builtin.unrealized_conversion_cast %[[VAL_10]] : tensor<2xf32> to !vhlo.tensor_v1<2x!vhlo.f32_v1> -+// CHECK: %[[VAL_12:.*]] = "vhlo.add_v1"(%[[VAL_11]], %[[VAL_11]]) : (!vhlo.tensor_v1<2x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.f32_v1>) -> !vhlo.tensor_v1<2x!vhlo.f32_v1> -+// CHECK: %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !vhlo.tensor_v1<2x!vhlo.f32_v1> to tensor<2xf32> -+// CHECK: scf.yield %[[VAL_13]] : tensor<2xf32> -+// CHECK: } -+// CHECK: %[[VAL_14:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : tensor<2xf32> to !vhlo.tensor_v1<2x!vhlo.f32_v1> -+// CHECK: 
"vhlo.return_v1"(%[[VAL_14]]) : (!vhlo.tensor_v1<2x!vhlo.f32_v1>) -> () -+// CHECK: } -+func.func @stablehlo_in_other_op_region(%arg0: tensor<2xf32>, %arg1: index) -> tensor<2xf32> { -+ %c0 = arith.constant 0 : index -+ %c1 = arith.constant 1 : index -+ %c2 = arith.constant 2 : index -+ %cst = arith.constant 0.0 : f32 -+ -+ %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> { -+ %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32> -+ %new_out_add = stablehlo.add %new_out, %new_out : tensor<2xf32> -+ scf.yield %new_out_add : tensor<2xf32> -+ } -+ return %for : tensor<2xf32> -+} -diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td ---- stablehlo/stablehlo/transforms/Passes.td -+++ stablehlo/stablehlo/transforms/Passes.td -@@ -308,7 +308,26 @@ - - def StablehloLegalizeToVhloPass : Pass<"stablehlo-legalize-to-vhlo", "ModuleOp"> { - let summary = "Legalize StableHLO to VHLO."; -+ let description = [{ -+ Legalize StableHLO to the latest version of ops in VHLO. These ops can then -+ be downgraded to older versions of VHLO for forward compatibility using -+ `VhloToVersionPass`. -+ -+ ```mlir -+ stablehlo.exponential %[[ARG0]] <{result_accuracy = DEFAULT}> : tensor -+ # ====> -+ "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = #vhlo.DEFAULT_v1}> : !vhlo.tensor_v1 -+ ``` -+ -+ See [vhlo.md > The VHLO dialect](https://github.com/openxla/stablehlo/blob/main/docs/vhlo.md) -+ for full details on how VHLO is used to preserve forward and backward -+ compatibility. -+ }]; - let dependentDialects = ["mlir::vhlo::VhloDialect"]; -+ let options = [ -+ Option<"allowOtherDialects", "allow-other-dialects", "bool", /*default=*/"false", -+ "Allow serialization to use other (potentially unstable) dialects, inserts unrealized casts between dialects.">, -+ ]; - } - - def StablehloRefineArgumentsPass : Pass<"stablehlo-refine-arguments", "ModuleOp"> { -@@ -330,6 +349,7 @@ - %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {...} - : (tensor<16xf32>, tensor<1xi64>) -> tensor - ... -+ } - ``` - - The `refinedTypesOption` can be used to specify a list of refined types. -@@ -402,7 +422,23 @@ - } - - def VhloToVersionPass : Pass<"vhlo-to-version"> { -- let summary = "Convert between versions of VHLO."; -+ let summary = "Convert between versions of VHLO for compatibility."; -+ let description = [{ -+ Converts between versions of VHLO for IR upgrading and downgrading to -+ preserve forward and backward compatibility. -+ -+ ```mlir -+ "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = DEFAULT}> -+ # ==( -target=1.0.0 )==> -+ "vhlo.exponential_v1"(%[[ARG0]]) -+ # ==( -target=1.9.0 )==> -+ "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = DEFAULT}> -+ ``` -+ -+ See [vhlo.md > The VHLO dialect](https://github.com/openxla/stablehlo/blob/main/docs/vhlo.md) -+ for full details on how VHLO is used to preserve forward and backward -+ compatibility. -+ }]; - let options = [ - Option<"targetVersionOption", "target", "std::string", "", - "The target version. 
Must be a version of the form #.#.# .">, -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp ---- stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp -+++ stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp -@@ -47,7 +47,7 @@ - #include "stablehlo/transforms/PassUtils.h" // IWYU pragma: keep - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp -@@ -12,6 +12,7 @@ - limitations under the License. - ==============================================================================*/ - -+#include - #include - #include - #include -@@ -24,6 +25,8 @@ - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinDialect.h" -+#include "mlir/IR/BuiltinOps.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/PatternMatch.h" - #include "mlir/IR/ValueRange.h" -@@ -37,7 +40,7 @@ - #include "stablehlo/transforms/MapStablehloToVhlo.h" - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { -@@ -53,17 +56,33 @@ - - class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { - public: -- StablehloToVhloTypeConverter() : vhlo::VhloTypeConverter() { -- addConversion([](Type type) -> Type { -- if (llvm::isa(type.getDialect())) return type; -- -- LLVM_DEBUG(llvm::dbgs() << "Invalid type: " << type << '\n'); -+ StablehloToVhloTypeConverter(bool allowOtherDialects) -+ : vhlo::VhloTypeConverter() { -+ LLVM_DEBUG( -+ llvm::dbgs() -+ << "[StablehloToVhloTypeConverter] Creating with allowOtherDialects: " -+ << allowOtherDialects << '\n'); -+ -+ addConversion([&allowOtherDialects](Type type) -> Type { -+ if (isa(type.getDialect())) return type; -+ -+ if (allowOtherDialects && -+ !isa(type.getDialect())) { -+ LLVM_DEBUG(llvm::dbgs() -+ << "[StablehloToVhloTypeConverter] Valid non-VHLO type: " -+ << type << '\n'); -+ return type; -+ } -+ -+ LLVM_DEBUG(llvm::dbgs() << "[StablehloToVhloTypeConverter] Invalid type: " -+ << type << '\n'); - return {}; - }); - addConversion([](TokenType token) -> Type { - return vhlo::TokenV1Type::get(token.getContext()); - }); - addBuiltinToVhloConversions(); -+ if (allowOtherDialects) addUnrealizedMaterializations(); - } - - Attribute convertEncoding(Attribute attr) const final { -@@ -1021,14 +1040,27 @@ - struct StablehloLegalizeToVhloPass - : public impl::StablehloLegalizeToVhloPassBase< - StablehloLegalizeToVhloPass> { -+ using StablehloLegalizeToVhloPassBase::StablehloLegalizeToVhloPassBase; -+ - LogicalResult initialize(MLIRContext* context) override { - target = std::make_shared(*context); - target->addIllegalDialect(); - target->addIllegalDialect(); - target->addLegalDialect(); -+ LLVM_DEBUG(llvm::dbgs() -+ << "allowOtherDialects: " << allowOtherDialects << "\n"); -+ -+ converter = -+ std::make_shared(allowOtherDialects); -+ if (allowOtherDialects) { -+ target->addLegalOp(); -+ } else { -+ target->addIllegalOp(); -+ } - - RewritePatternSet patterns_(context); -- 
stablehlo::populateStablehloToVhloPatterns(&patterns_, &converter, context); -+ stablehlo::populateStablehloToVhloPatterns(&patterns_, converter.get(), -+ context); - patterns = std::move(patterns_); - - return success(); -@@ -1043,7 +1075,7 @@ - } - - private: -- StablehloToVhloTypeConverter converter; -+ std::shared_ptr converter; - FrozenRewritePatternSet patterns; - std::shared_ptr target; - }; -diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ---- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -@@ -27,6 +27,7 @@ - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinOps.h" - #include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/PatternMatch.h" - #include "mlir/IR/ValueRange.h" -@@ -40,7 +41,7 @@ - #include "stablehlo/transforms/MapStablehloToVhlo.h" - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { -@@ -63,6 +64,7 @@ - return stablehlo::TokenType::get(token.getContext()); - }); - addVhloToBuiltinConversions(); -+ addUnrealizedMaterializations(); - } - - Attribute convertEncoding(Attribute attr) const final { -@@ -1021,6 +1023,36 @@ - } - }; - -+// Fold unnecessary unrealized conversion casts. -+// unrealized_conversion(unrealized_conversion(X) : Y) : X -> X -+// Not as complicated at mlir::reconcileUnrealizedCasts because we know that the -+// types must be the same and shouldn't be in chains greater than 2. -+struct ReconcileUnrealizedConversionCasts -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector unrealizedCasts( -+ op->getNumOperands() + 1); -+ unrealizedCasts.push_back(op); -+ for (auto operand : op.getOperands()) { -+ auto unrealizedOperand = -+ operand.getDefiningOp(); -+ if (!unrealizedOperand) { -+ LLVM_DEBUG(llvm::dbgs() << "Failed to reconcile unrealized conversion " -+ "casts: " -+ << op.getOperationName() << "\n"); -+ return success(); -+ } -+ unrealizedCasts.push_back(unrealizedOperand); -+ } -+ LLVM_DEBUG(llvm::dbgs() << "Reconciling unrealized conversion casts: " -+ << op.getOperationName() << "\n"); -+ mlir::reconcileUnrealizedCasts(unrealizedCasts); -+ return success(); -+ } -+}; -+ - template - void populateVhloToStablehloPatterns(RewritePatternSet* patterns, - TypeConverter* converter, -@@ -1043,6 +1075,7 @@ - - RewritePatternSet patterns_(context); - stablehlo::populateVhloToStablehloPatterns(&patterns_, &converter, context); -+ patterns_.add(context); - patterns = std::move(patterns_); - - return success(); -@@ -1055,6 +1088,12 @@ - if (failed(applyPartialConversion(getOperation(), *target, patterns))) { - return signalPassFailure(); - } -+ -+ // Cleanup unrealized conversion casts (if any, in case of dialect mixing) -+ SmallVector ops; -+ getOperation().walk( -+ [&ops](UnrealizedConversionCastOp op) { ops.push_back(op); }); -+ reconcileUnrealizedCasts(ops); - } - - private: -diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp ---- stablehlo/stablehlo/transforms/VhloToVersion.cpp -+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp -@@ -40,7 +40,7 @@ - #include 
"stablehlo/dialect/VhloTypes.h" - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index a6a1f38b9f4494..da51958d55b282 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "be8ce602efbd90fd677247075745bf16eb4b31ac" - STABLEHLO_SHA256 = "81f44a6f4c37599fc600c159a899602590b17f3d3858fe6a400bb5643b0c9ba1" + STABLEHLO_COMMIT = "4bf77d23bd9150782a70d85fda9c12a2dec5328c" + STABLEHLO_SHA256 = "0efae2563d87c642cf9ad5c576911c5f08f9b5ee023b626ddb2a51a87d93297f" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 949ebc772ae60e..8b137891791fe9 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1,577 +1 @@ -diff --ruN a/stablehlo/stablehlo/dialect/Serialization.cpp b/stablehlo/stablehlo/dialect/Serialization.cpp ---- stablehlo/stablehlo/dialect/Serialization.cpp -+++ stablehlo/stablehlo/dialect/Serialization.cpp -@@ -32,20 +32,25 @@ - #include "stablehlo/dialect/VhloOps.h" - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { - - LogicalResult serializePortableArtifact(ModuleOp module, - StringRef targetVersion, -- raw_ostream& os) { -+ raw_ostream& os, -+ bool allowOtherDialects) { - MLIRContext* context = module.getContext(); - -- // Convert StableHLO --> VHLO. Will fail if entire program is not StableHLO. -+ // Convert StableHLO --> VHLO. -+ // If allowOtherDialects is true, we will allow other dialects to be present -+ // in the module, otherwise will fail if there are any other dialects present. - { - PassManager pm(context); -- pm.addPass(stablehlo::createStablehloLegalizeToVhloPass()); -+ StablehloLegalizeToVhloPassOptions options; -+ options.allowOtherDialects = allowOtherDialects; -+ pm.addPass(stablehlo::createStablehloLegalizeToVhloPass(options)); - if (!succeeded(pm.run(module))) { - return failure(); - } -diff --ruN a/stablehlo/stablehlo/dialect/Serialization.h b/stablehlo/stablehlo/dialect/Serialization.h ---- stablehlo/stablehlo/dialect/Serialization.h -+++ stablehlo/stablehlo/dialect/Serialization.h -@@ -34,7 +34,8 @@ - // unsupported dialects. 
- LogicalResult serializePortableArtifact(ModuleOp module, - StringRef targetVersion, -- raw_ostream& os); -+ raw_ostream& os, -+ bool allowOtherDialects = false); - - // Read StableHLO portable artifact - // -diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.cpp b/stablehlo/stablehlo/dialect/VhloTypes.cpp ---- stablehlo/stablehlo/dialect/VhloTypes.cpp -+++ stablehlo/stablehlo/dialect/VhloTypes.cpp -@@ -323,6 +323,20 @@ - } - - namespace { -+Value materializeIllegalCast(OpBuilder& builder, Type type, ValueRange inputs, -+ Location loc) { -+ return builder.create(loc, type, inputs) -+ ->getResult(0); -+} -+} // namespace -+ -+void VhloTypeConverter::addUnrealizedMaterializations() { -+ addTargetMaterialization(materializeIllegalCast); -+ addSourceMaterialization(materializeIllegalCast); -+ addArgumentMaterialization(materializeIllegalCast); -+} -+ -+namespace { - // Helper functions for VHLO verifiers - template - bool isFromVhlo(TypeOrAttr t) { -diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/dialect/VhloTypes.h ---- stablehlo/stablehlo/dialect/VhloTypes.h -+++ stablehlo/stablehlo/dialect/VhloTypes.h -@@ -55,6 +55,9 @@ - // it is likely that a subclass should call this last, especially if a default - // `Type -> Type` fallback conversion is registered. - void addBuiltinToVhloConversions(); -+ -+ // Mark unrealized casts as legal. Useful for dialect mixing. -+ void addUnrealizedMaterializations(); - }; - - // Autogenerated VHLO type printers and parsers. -diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir ---- stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir -+++ stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir -@@ -1338,7 +1338,7 @@ - - // ----- - --// expected-error@+1 {{scale out of expressed type range}} -+// expected-error@+1 {{scale 1.055040e+05 out of expressed type range}} - func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform>) { - %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> - func.return -@@ -1346,7 +1346,7 @@ - - // ----- - --// expected-error@+1 {{scale out of expressed type range}} -+// expected-error@+1 {{scale 4.960464e-08 out of expressed type range}} - func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform>) { - %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> - func.return -diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir ---- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir -+++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir -@@ -0,0 +1,189 @@ -+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -+ -+// The script is designed to make adding checks to -+// a test case fast, it is *not* designed to be authoritative -+// about what constitutes a good test! The CHECK should be -+// minimized and named to reflect the test intent. -+ -+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -+ -+// The script is designed to make adding checks to -+// a test case fast, it is *not* designed to be authoritative -+// minimized and named to reflect the test intent. 
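-+
-+// The RUN lines below exercise the legalization with allow-other-dialects
-+// enabled and then round-trip the result back through
-+// --vhlo-legalize-to-stablehlo; the final diff requires the round trip to
-+// reproduce the original module exactly.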
-+ -+// RUN: stablehlo-opt %s --stablehlo-legalize-to-vhlo=allow-other-dialects | FileCheck %s -+// RUN: stablehlo-opt %s > %t.0 -+// RUN: stablehlo-opt %s --stablehlo-legalize-to-vhlo=allow-other-dialects | stablehlo-opt --vhlo-legalize-to-stablehlo > %t.1 -+// RUN: diff %t.0 %t.1 -+ -+// CHECK-LABEL: vhlo.func_v1 @op_other( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { -+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_2:.*]] = arith.addf %[[VAL_1]], %[[VAL_1]] : tensor -+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : tensor to !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_3]]) : (!vhlo.tensor_v1) -> () -+// CHECK: } -+func.func @op_other(%arg0: tensor) -> tensor { -+ %0 = arith.addf %arg0, %arg0 : tensor -+ return %0 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @op_shlo( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { -+// CHECK: %[[VAL_1:.*]] = "vhlo.add_v1"(%[[VAL_0]], %[[VAL_0]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_1]]) : (!vhlo.tensor_v1) -> () -+// CHECK: } -+func.func @op_shlo(%arg0: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg0 : tensor -+ return %0 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @mixed_shlo_other_shlo( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { -+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_2:.*]] = "vhlo.abs_v1"(%[[VAL_0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_3]], %[[VAL_1]] : tensor -+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : tensor to !vhlo.tensor_v1 -+// CHECK: %[[VAL_6:.*]] = "vhlo.abs_v1"(%[[VAL_5]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_6]]) : (!vhlo.tensor_v1) -> () -+// CHECK: } -+func.func @mixed_shlo_other_shlo(%arg0: tensor) -> tensor { -+ %0 = stablehlo.abs %arg0 : tensor -+ %1 = arith.addf %0, %arg0 : tensor -+ %2 = stablehlo.abs %1 : tensor -+ return %2 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @mixed_other_shlo_other( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { -+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_2:.*]] = arith.addf %[[VAL_1]], %[[VAL_1]] : tensor -+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : tensor to !vhlo.tensor_v1 -+// CHECK: %[[VAL_4:.*]] = "vhlo.add_v1"(%[[VAL_3]], %[[VAL_0]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !vhlo.tensor_v1 to tensor -+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[VAL_1]] : tensor -+// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : tensor to !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_7]]) : (!vhlo.tensor_v1) -> () -+// CHECK: } -+func.func @mixed_other_shlo_other(%arg0: tensor) -> tensor { -+ %0 = arith.addf %arg0, %arg0 : tensor -+ %1 = stablehlo.add %0, %arg0 : tensor -+ %2 = arith.addf %1, %arg0 : tensor -+ return %2 : tensor -+} -+ -+// 
----- -+ -+// CHECK-LABEL: vhlo.func_v1 @op_with_region( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<1x16x16x320x!vhlo.f32_v1>, -+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1<1x320x!vhlo.f32_v1>) { -+// CHECK: %[[VAL_2:.*]] = "vhlo.reduce_v1"(%[[VAL_0]], %[[VAL_1]]) <{dimensions = #{{.*}} : tensor<2xi64>>}> ({ -+// CHECK: ^bb0(%[[VAL_3:.*]]: !vhlo.tensor_v1, %[[VAL_4:.*]]: !vhlo.tensor_v1): -+// CHECK: %[[VAL_5:.*]] = "vhlo.add_v1"(%[[VAL_3]], %[[VAL_4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+// CHECK: "vhlo.return_v1"(%[[VAL_5]]) : (!vhlo.tensor_v1) -> () -+// CHECK: }) : (!vhlo.tensor_v1<1x16x16x320x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<1x320x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_2]]) : (!vhlo.tensor_v1<1x320x!vhlo.f32_v1>) -> () -+// CHECK: } -+func.func @op_with_region(%arg0: tensor<1x16x16x320xf32>, %arg1: tensor) -> tensor<1x320xf32> { -+ %0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.add across dimensions = [1, 2] : (tensor<1x16x16x320xf32>, tensor) -> tensor<1x320xf32> -+ return %0 : tensor<1x320xf32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @op_with_region_mixed_other_shlo_other( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<7x5x!vhlo.f32_v1>, -+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<5x!vhlo.f32_v1>) { -+// CHECK: %[[VAL_2:.*]] = "vhlo.reduce_v1"(%[[VAL_0]], %[[VAL_1]]) <{dimensions = #{{.*}} : tensor<1xi64>>}> ({ -+// CHECK: ^bb0(%[[VAL_3:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>, %[[VAL_4:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>): -+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_3]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_6]], %[[VAL_5]] : tensor<5xf32> -+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : tensor<5xf32> to !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: %[[VAL_9:.*]] = "vhlo.add_v1"(%[[VAL_8]], %[[VAL_3]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>, !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: %[[VAL_10:.*]] = builtin.unrealized_conversion_cast %[[VAL_9]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_10]], %[[VAL_5]] : tensor<5xf32> -+// CHECK: %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : tensor<5xf32> to !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_12]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> () -+// CHECK: }) : (!vhlo.tensor_v1<7x5x!vhlo.f32_v1>, !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_2]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> () -+// CHECK: } -+func.func @op_with_region_mixed_other_shlo_other(%arg0: tensor<7x5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { -+ %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32> -+ reducer(%arg2: tensor<5xf32>, %arg3: tensor<5xf32>) { -+ %1 = arith.addf %arg2, %arg3 : tensor<5xf32> -+ %2 = stablehlo.add %1, %arg2 : tensor<5xf32> -+ %3 = arith.addf %2, %arg3 : tensor<5xf32> -+ stablehlo.return %3 : tensor<5xf32> -+ } -+ return %0 : tensor<5xf32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 
@op_with_region_mixed_shlo_other_shlo( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<7x5x!vhlo.f32_v1>, -+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<5x!vhlo.f32_v1>) { -+// CHECK: %[[VAL_2:.*]] = "vhlo.reduce_v1"(%[[VAL_0]], %[[VAL_1]]) <{dimensions = #{{.*}} : tensor<1xi64>>}> ({ -+// CHECK: ^bb0(%[[VAL_3:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>, %[[VAL_4:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>): -+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_6:.*]] = "vhlo.abs_v1"(%[[VAL_3]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32> -+// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_7]], %[[VAL_5]] : tensor<5xf32> -+// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : tensor<5xf32> to !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: %[[VAL_10:.*]] = "vhlo.abs_v1"(%[[VAL_9]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_10]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> () -+// CHECK: }) : (!vhlo.tensor_v1<7x5x!vhlo.f32_v1>, !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1> -+// CHECK: "vhlo.return_v1"(%[[VAL_2]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> () -+// CHECK: } -+func.func @op_with_region_mixed_shlo_other_shlo(%arg0: tensor<7x5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { -+ %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32> -+ reducer(%arg2: tensor<5xf32>, %arg3: tensor<5xf32>) { -+ %1 = stablehlo.abs %arg2 : tensor<5xf32> -+ %2 = arith.addf %1, %arg3 : tensor<5xf32> -+ %3 = stablehlo.abs %2 : tensor<5xf32> -+ stablehlo.return %3 : tensor<5xf32> -+ } -+ return %0 : tensor<5xf32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: vhlo.func_v1 @stablehlo_in_other_op_region( -+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<2x!vhlo.f32_v1>, -+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.index_v1) -> (!vhlo.tensor_v1<2x!vhlo.f32_v1>) { -+// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1<2x!vhlo.f32_v1> to tensor<2xf32> -+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index -+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index -+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index -+// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -+// CHECK: %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_9:.*]] = %[[VAL_2]]) -> (tensor<2xf32>) { -+// CHECK: %[[VAL_10:.*]] = tensor.insert %[[VAL_6]] into %[[VAL_9]]{{\[}}%[[VAL_8]]] : tensor<2xf32> -+// CHECK: %[[VAL_11:.*]] = builtin.unrealized_conversion_cast %[[VAL_10]] : tensor<2xf32> to !vhlo.tensor_v1<2x!vhlo.f32_v1> -+// CHECK: %[[VAL_12:.*]] = "vhlo.add_v1"(%[[VAL_11]], %[[VAL_11]]) : (!vhlo.tensor_v1<2x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.f32_v1>) -> !vhlo.tensor_v1<2x!vhlo.f32_v1> -+// CHECK: %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !vhlo.tensor_v1<2x!vhlo.f32_v1> to tensor<2xf32> -+// CHECK: scf.yield %[[VAL_13]] : tensor<2xf32> -+// CHECK: } -+// CHECK: %[[VAL_14:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : tensor<2xf32> to !vhlo.tensor_v1<2x!vhlo.f32_v1> -+// CHECK: 
"vhlo.return_v1"(%[[VAL_14]]) : (!vhlo.tensor_v1<2x!vhlo.f32_v1>) -> () -+// CHECK: } -+func.func @stablehlo_in_other_op_region(%arg0: tensor<2xf32>, %arg1: index) -> tensor<2xf32> { -+ %c0 = arith.constant 0 : index -+ %c1 = arith.constant 1 : index -+ %c2 = arith.constant 2 : index -+ %cst = arith.constant 0.0 : f32 -+ -+ %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> { -+ %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32> -+ %new_out_add = stablehlo.add %new_out, %new_out : tensor<2xf32> -+ scf.yield %new_out_add : tensor<2xf32> -+ } -+ return %for : tensor<2xf32> -+} -diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td ---- stablehlo/stablehlo/transforms/Passes.td -+++ stablehlo/stablehlo/transforms/Passes.td -@@ -308,7 +308,26 @@ - - def StablehloLegalizeToVhloPass : Pass<"stablehlo-legalize-to-vhlo", "ModuleOp"> { - let summary = "Legalize StableHLO to VHLO."; -+ let description = [{ -+ Legalize StableHLO to the latest version of ops in VHLO. These ops can then -+ be downgraded to older versions of VHLO for forward compatibility using -+ `VhloToVersionPass`. -+ -+ ```mlir -+ stablehlo.exponential %[[ARG0]] <{result_accuracy = DEFAULT}> : tensor -+ # ====> -+ "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = #vhlo.DEFAULT_v1}> : !vhlo.tensor_v1 -+ ``` -+ -+ See [vhlo.md > The VHLO dialect](https://github.com/openxla/stablehlo/blob/main/docs/vhlo.md) -+ for full details on how VHLO is used to preserve forward and backward -+ compatibility. -+ }]; - let dependentDialects = ["mlir::vhlo::VhloDialect"]; -+ let options = [ -+ Option<"allowOtherDialects", "allow-other-dialects", "bool", /*default=*/"false", -+ "Allow serialization to use other (potentially unstable) dialects, inserts unrealized casts between dialects.">, -+ ]; - } - - def StablehloRefineArgumentsPass : Pass<"stablehlo-refine-arguments", "ModuleOp"> { -@@ -330,6 +349,7 @@ - %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {...} - : (tensor<16xf32>, tensor<1xi64>) -> tensor - ... -+ } - ``` - - The `refinedTypesOption` can be used to specify a list of refined types. -@@ -402,7 +422,23 @@ - } - - def VhloToVersionPass : Pass<"vhlo-to-version"> { -- let summary = "Convert between versions of VHLO."; -+ let summary = "Convert between versions of VHLO for compatibility."; -+ let description = [{ -+ Converts between versions of VHLO for IR upgrading and downgrading to -+ preserve forward and backward compatibility. -+ -+ ```mlir -+ "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = DEFAULT}> -+ # ==( -target=1.0.0 )==> -+ "vhlo.exponential_v1"(%[[ARG0]]) -+ # ==( -target=1.9.0 )==> -+ "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = DEFAULT}> -+ ``` -+ -+ See [vhlo.md > The VHLO dialect](https://github.com/openxla/stablehlo/blob/main/docs/vhlo.md) -+ for full details on how VHLO is used to preserve forward and backward -+ compatibility. -+ }]; - let options = [ - Option<"targetVersionOption", "target", "std::string", "", - "The target version. 
Must be a version of the form #.#.# .">, -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp ---- stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp -+++ stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp -@@ -47,7 +47,7 @@ - #include "stablehlo/transforms/PassUtils.h" // IWYU pragma: keep - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp -@@ -12,6 +12,7 @@ - limitations under the License. - ==============================================================================*/ - -+#include - #include - #include - #include -@@ -24,6 +25,8 @@ - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinDialect.h" -+#include "mlir/IR/BuiltinOps.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/PatternMatch.h" - #include "mlir/IR/ValueRange.h" -@@ -37,7 +40,7 @@ - #include "stablehlo/transforms/MapStablehloToVhlo.h" - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { -@@ -53,17 +56,33 @@ - - class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { - public: -- StablehloToVhloTypeConverter() : vhlo::VhloTypeConverter() { -- addConversion([](Type type) -> Type { -- if (llvm::isa(type.getDialect())) return type; -- -- LLVM_DEBUG(llvm::dbgs() << "Invalid type: " << type << '\n'); -+ StablehloToVhloTypeConverter(bool allowOtherDialects) -+ : vhlo::VhloTypeConverter() { -+ LLVM_DEBUG( -+ llvm::dbgs() -+ << "[StablehloToVhloTypeConverter] Creating with allowOtherDialects: " -+ << allowOtherDialects << '\n'); -+ -+ addConversion([&allowOtherDialects](Type type) -> Type { -+ if (isa(type.getDialect())) return type; -+ -+ if (allowOtherDialects && -+ !isa(type.getDialect())) { -+ LLVM_DEBUG(llvm::dbgs() -+ << "[StablehloToVhloTypeConverter] Valid non-VHLO type: " -+ << type << '\n'); -+ return type; -+ } -+ -+ LLVM_DEBUG(llvm::dbgs() << "[StablehloToVhloTypeConverter] Invalid type: " -+ << type << '\n'); - return {}; - }); - addConversion([](TokenType token) -> Type { - return vhlo::TokenV1Type::get(token.getContext()); - }); - addBuiltinToVhloConversions(); -+ if (allowOtherDialects) addUnrealizedMaterializations(); - } - - Attribute convertEncoding(Attribute attr) const final { -@@ -1021,14 +1040,27 @@ - struct StablehloLegalizeToVhloPass - : public impl::StablehloLegalizeToVhloPassBase< - StablehloLegalizeToVhloPass> { -+ using StablehloLegalizeToVhloPassBase::StablehloLegalizeToVhloPassBase; -+ - LogicalResult initialize(MLIRContext* context) override { - target = std::make_shared(*context); - target->addIllegalDialect(); - target->addIllegalDialect(); - target->addLegalDialect(); -+ LLVM_DEBUG(llvm::dbgs() -+ << "allowOtherDialects: " << allowOtherDialects << "\n"); -+ -+ converter = -+ std::make_shared(allowOtherDialects); -+ if (allowOtherDialects) { -+ target->addLegalOp(); -+ } else { -+ target->addIllegalOp(); -+ } - - RewritePatternSet patterns_(context); -- 
stablehlo::populateStablehloToVhloPatterns(&patterns_, &converter, context); -+ stablehlo::populateStablehloToVhloPatterns(&patterns_, converter.get(), -+ context); - patterns = std::move(patterns_); - - return success(); -@@ -1043,7 +1075,7 @@ - } - - private: -- StablehloToVhloTypeConverter converter; -+ std::shared_ptr converter; - FrozenRewritePatternSet patterns; - std::shared_ptr target; - }; -diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ---- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -@@ -27,6 +27,7 @@ - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinOps.h" - #include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/PatternMatch.h" - #include "mlir/IR/ValueRange.h" -@@ -40,7 +41,7 @@ - #include "stablehlo/transforms/MapStablehloToVhlo.h" - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { -@@ -63,6 +64,7 @@ - return stablehlo::TokenType::get(token.getContext()); - }); - addVhloToBuiltinConversions(); -+ addUnrealizedMaterializations(); - } - - Attribute convertEncoding(Attribute attr) const final { -@@ -1021,6 +1023,36 @@ - } - }; - -+// Fold unnecessary unrealized conversion casts. -+// unrealized_conversion(unrealized_conversion(X) : Y) : X -> X -+// Not as complicated at mlir::reconcileUnrealizedCasts because we know that the -+// types must be the same and shouldn't be in chains greater than 2. -+struct ReconcileUnrealizedConversionCasts -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector unrealizedCasts( -+ op->getNumOperands() + 1); -+ unrealizedCasts.push_back(op); -+ for (auto operand : op.getOperands()) { -+ auto unrealizedOperand = -+ operand.getDefiningOp(); -+ if (!unrealizedOperand) { -+ LLVM_DEBUG(llvm::dbgs() << "Failed to reconcile unrealized conversion " -+ "casts: " -+ << op.getOperationName() << "\n"); -+ return success(); -+ } -+ unrealizedCasts.push_back(unrealizedOperand); -+ } -+ LLVM_DEBUG(llvm::dbgs() << "Reconciling unrealized conversion casts: " -+ << op.getOperationName() << "\n"); -+ mlir::reconcileUnrealizedCasts(unrealizedCasts); -+ return success(); -+ } -+}; -+ - template - void populateVhloToStablehloPatterns(RewritePatternSet* patterns, - TypeConverter* converter, -@@ -1043,6 +1075,7 @@ - - RewritePatternSet patterns_(context); - stablehlo::populateVhloToStablehloPatterns(&patterns_, &converter, context); -+ patterns_.add(context); - patterns = std::move(patterns_); - - return success(); -@@ -1055,6 +1088,12 @@ - if (failed(applyPartialConversion(getOperation(), *target, patterns))) { - return signalPassFailure(); - } -+ -+ // Cleanup unrealized conversion casts (if any, in case of dialect mixing) -+ SmallVector ops; -+ getOperation().walk( -+ [&ops](UnrealizedConversionCastOp op) { ops.push_back(op); }); -+ reconcileUnrealizedCasts(ops); - } - - private: -diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp ---- stablehlo/stablehlo/transforms/VhloToVersion.cpp -+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp -@@ -40,7 +40,7 @@ - #include 
"stablehlo/dialect/VhloTypes.h" - #include "stablehlo/transforms/Passes.h" - --#define DEBUG_TYPE "compat-passes" -+#define DEBUG_TYPE "stablehlo-compat" - - namespace mlir { - namespace stablehlo { diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index a6a1f38b9f4494..da51958d55b282 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "be8ce602efbd90fd677247075745bf16eb4b31ac" - STABLEHLO_SHA256 = "81f44a6f4c37599fc600c159a899602590b17f3d3858fe6a400bb5643b0c9ba1" + STABLEHLO_COMMIT = "4bf77d23bd9150782a70d85fda9c12a2dec5328c" + STABLEHLO_SHA256 = "0efae2563d87c642cf9ad5c576911c5f08f9b5ee023b626ddb2a51a87d93297f" # LINT.ThenChange(Google-internal path) tf_http_archive( From 7c0be5b78503c5c6c7ea55d44786b75570e34595 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Thu, 3 Apr 2025 13:47:23 -0700 Subject: [PATCH 0211/1324] Add basic support for RaggedDot in SPMD partitioner. This cl is the 4th step to fully support RaggedDot in Shardy. 1. Import and export the RaggedDot into Shardy. 2. Add a sharding rule for the new operation. 3. Handle the new operation in the explicit reshard in Shardy. 4. Handle it in SPMD partitioner without resolving any conflicts. Steps 1 and 2 were in cl/737011229. We will proceed with Step 3 afterwards. This operation is a great example to demonstrate it is easy to support new and customized operations in Shardy system. PiperOrigin-RevId: 743684094 --- third_party/xla/xla/hlo/utils/hlo_matchers.h | 1 + .../xla/xla/service/spmd/spmd_partitioner.cc | 54 +++++++++++ .../xla/xla/service/spmd/spmd_partitioner.h | 1 + .../xla/service/spmd/spmd_partitioner_test.cc | 96 +++++++++++++++++++ 4 files changed, 152 insertions(+) diff --git a/third_party/xla/xla/hlo/utils/hlo_matchers.h b/third_party/xla/xla/hlo/utils/hlo_matchers.h index 93c11551ce8b15..914e120e50318c 100644 --- a/third_party/xla/xla/hlo/utils/hlo_matchers.h +++ b/third_party/xla/xla/hlo/utils/hlo_matchers.h @@ -325,6 +325,7 @@ HLO_MATCHER(Pad); HLO_MATCHER(PartitionId); HLO_MATCHER(Power); HLO_MATCHER(RaggedAllToAll); +HLO_MATCHER(RaggedDot); HLO_MATCHER(Recv); HLO_MATCHER(RecvDone); HLO_MATCHER(Reduce); diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index bd0f12b5d8fddd..f6333c3f530b21 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -4976,6 +4976,60 @@ absl::Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) { "the data is replicated, and if the latter which data is replicated."); } +absl::Status SpmdPartitioningVisitor::HandleRaggedDot(HloInstruction* hlo) { + LOG(WARNING) << "You have to use Shardy for RaggedDot. 
If not, the behavior " + "is undefined."; + + const RaggedDotDimensionNumbers& ragged_dot_dnums = + hlo->ragged_dot_dimension_numbers(); + const DotDimensionNumbers& dot_dnums = + ragged_dot_dnums.dot_dimension_numbers(); + + CHECK_EQ(ragged_dot_dnums.lhs_ragged_dimensions_size(), 1); + int64_t lhs_ragged_dim = ragged_dot_dnums.lhs_ragged_dimensions(0); + + PartitionedHlo& lhs = GetPartitionedHlo(hlo->operand(0)); + PartitionedHlo& rhs = GetPartitionedHlo(hlo->operand(1)); + PartitionedHlo& group_sizes = GetPartitionedHlo(hlo->operand(2)); + if (lhs.hlo() == rhs.hlo()) { + rhs = MakeACopyAndReturnItsPartitionedHlo(rhs, builder()); + } + + std::vector sharded_lhs_contracting_dims; + if (lhs.sharding().IsTiled()) { + for (int64_t dim : dot_dnums.lhs_contracting_dimensions()) { + if (lhs.sharding().tile_assignment().dim(dim) > 1) { + sharded_lhs_contracting_dims.push_back(dim); + } + } + } + + if (!sharded_lhs_contracting_dims.empty()) { + lhs = lhs.PadWithZero(); + rhs = rhs.PadWithZero(); + } + + HloInstruction* phlo; + Shape pshape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + if (absl::c_linear_search(dot_dnums.lhs_batch_dimensions(), lhs_ragged_dim)) { + phlo = b_.AddInstruction(HloInstruction::CreateDot( + pshape, lhs.hlo(), rhs.hlo(), dot_dnums, hlo->precision_config())); + } else { + phlo = b_.AddInstruction(hlo->CloneWithNewOperands( + pshape, {lhs.hlo(), rhs.hlo(), group_sizes.hlo()})); + } + + if (!sharded_lhs_contracting_dims.empty()) { + phlo = lhs.state().partitioner->AllReduceAlongShardingDims( + lhs.state().b, phlo, lhs.sharding(), lhs.state().next_channel_id, + sharded_lhs_contracting_dims, lhs.state().collective_ops_creator, + MakeBinaryAdd(phlo->shape().element_type(), lhs.state().module)); + } + + SetPartitionedHlo(hlo, [&]() { return phlo; }); + return absl::OkStatus(); +} + SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions, int64_t num_replicas) { return { diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.h b/third_party/xla/xla/service/spmd/spmd_partitioner.h index de4e76d42a5059..a0994bd6433085 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.h +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.h @@ -620,6 +620,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { absl::Status HandlePad(HloInstruction* hlo) override; absl::Status HandleParameter(HloInstruction* hlo) override; absl::Status HandlePartitionId(HloInstruction* hlo) override; + absl::Status HandleRaggedDot(HloInstruction* hlo) override; absl::Status HandleReduce(HloInstruction* hlo) override; absl::Status HandleReduceWindow(HloInstruction* hlo) override; absl::Status HandleReshape(HloInstruction* hlo) override; diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index fd2799ead72a5d..ad2863654444a2 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -15749,6 +15749,102 @@ ENTRY main.12 { EXPECT_THAT(cp, op::Shape("s32[1]{0}")); } +TEST_P(SpmdPartitioningTest, RaggedDotNonContractingMode) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + a = f32[16,32,64] parameter(0), sharding={devices=[2,1,2,2]<=[8] last_tile_dim_replicate} + b = f32[4,16,64,8] parameter(1), sharding={devices=[1,2,2,2]<=[8]} + c = u32[16,4] parameter(2), sharding={devices=[2,1,4]<=[8] last_tile_dim_replicate} + ROOT dot = f32[16,32,8] ragged-dot(a, b, c), + 
lhs_batch_dims={0}, rhs_batch_dims={1}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + lhs_ragged_dims={1}, rhs_group_dims={0}, + sharding={devices=[2,1,2,2]<=[2,2,2]T(0,2,1) last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[8,32,32]")); + auto param1 = AllOf(op::Parameter(1), op::Shape("f32[4,8,32,4]")); + auto param2 = AllOf(op::Parameter(2), op::Shape("u32[8,4]")); + auto ragged_dot = + AllOf(op::RaggedDot(param0, param1, param2), op::Shape("f32[8,32,4]")); + auto all_reduce = AllOf(op::AllReduce(ragged_dot), op::Shape("f32[8,32,4]")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, all_reduce); + + auto replica_groups = Cast(root)->replica_groups(); + EXPECT_EQ(replica_groups.size(), 4); + EXPECT_THAT(replica_groups[0].replica_ids(), ::testing::ElementsAre(0, 2)); +} + +TEST_P(SpmdPartitioningTest, RaggedDotContractingMode) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + a = f32[15,33,64] parameter(0), sharding={devices=[2,2,1,2]<=[8] last_tile_dim_replicate} + b = f32[15,64,9] parameter(1), sharding={devices=[2,1,2,2]<=[2,2,2]T(0,2,1) last_tile_dim_replicate} + c = u32[15,4] parameter(2), sharding={devices=[2,1,4]<=[8] last_tile_dim_replicate} + ROOT dot = f32[4,15,33,9] ragged-dot(a, b, c), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + lhs_ragged_dims={2}, sharding={devices=[1,2,2,2]<=[8]} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[8,17,64]")); + auto param1 = AllOf(op::Parameter(1), op::Shape("f32[8,64,5]")); + auto param2 = AllOf(op::Parameter(2), op::Shape("u32[8,4]")); + auto ragged_dot = + AllOf(op::RaggedDot(param0, param1, param2), op::Shape("f32[4,8,17,5]")); + EXPECT_THAT(module->entry_computation()->root_instruction(), ragged_dot); +} + +TEST_P(SpmdPartitioningTest, RaggedDotBatchMode) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + a = f32[16,32,63] parameter(0), sharding={devices=[2,2,2,2]<=[2,2,2,2]T(0,1,3,2) last_tile_dim_replicate} + b = f32[16,63,8] parameter(1), sharding={devices=[2,2,2,2]<=[2,2,2,2]T(0,3,2,1) last_tile_dim_replicate} + c = u32[4] parameter(2), sharding={devices=[4,4]<=[16] last_tile_dim_replicate} + ROOT dot = f32[16,32,8] ragged-dot(a, b, c), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + lhs_ragged_dims={0}, + sharding={devices=[2,2,2,2]<=[16] last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/16)); + LOG(ERROR) << module->ToString(); + + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[8,16,32]")); + auto param0_pad = AllOf(op::Select(_, param0, op::Broadcast(op::Constant())), + op::Shape("f32[8,16,32]")); + + auto param1 = AllOf(op::Parameter(1), op::Shape("f32[8,32,4]")); + auto param1_pad = AllOf(op::Select(_, param1, op::Broadcast(op::Constant())), + op::Shape("f32[8,32,4]")); + + auto dot = AllOf(op::Dot(param0_pad, param1_pad), op::Shape("f32[8,16,4]")); + auto all_reduce = AllOf(op::AllReduce(dot), op::Shape("f32[8,16,4]")); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, all_reduce); + + auto replica_groups = Cast(root)->replica_groups(); + 
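+  // The lhs contracting dimension (size 63) is sharded 2-way, so the partial
+  // dot results are all-reduced across pairs of devices: 16 devices with
+  // 2 devices per group yields 8 replica groups.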
EXPECT_EQ(replica_groups.size(), 8); + EXPECT_THAT(replica_groups[0].replica_ids(), ::testing::ElementsAre(0, 1)); +} + } // namespace } // namespace spmd } // namespace xla From 7278ba5ac094dd252e52254c391bf927ec08c223 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 3 Apr 2025 14:27:09 -0700 Subject: [PATCH 0212/1324] [IFRT Proxy] Add missing `ArrayStore::Reservation::ProcessResponse()` calls When using `ArrayStore::Reservation`, `ArrayStore::Reservation::ProcessResponse()` must be used to catch any errors raised during request handling. Otherwise, the reservation remains non-filled and would break the proxy server invariant. PiperOrigin-RevId: 743698553 --- .../xla/xla/python/ifrt_proxy/server/ifrt_backend.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc index a58781905c37b4..9d73688e7f0914 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -526,13 +526,14 @@ Future IfrtBackend::ProcessInternal( asr.emplace(request->make_arrays_from_host_buffer_shards_request() .array_handles(), &array_store_); - return Future(HandleMakeArraysFromHostBufferShardsRequest( - *asr, std::move(request))); + return Future( + asr->ProcessResponse(HandleMakeArraysFromHostBufferShardsRequest( + *asr, std::move(request)))); case IfrtRequest::RequestCase::kMakeErrorArraysRequest: asr.emplace(request->make_error_arrays_request().array_handles(), &array_store_); - return Future( - HandleMakeErrorArraysRequest(*asr, std::move(request))); + return Future(asr->ProcessResponse( + HandleMakeErrorArraysRequest(*asr, std::move(request)))); case IfrtRequest::RequestCase::kAssembleArrayFromSingleDeviceArraysRequest: asr.emplace(request->assemble_array_from_single_device_arrays_request() .result_handle(), From fbe4b80b8f7ee731b713a9416b487249a9fe88e6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 14:34:01 -0700 Subject: [PATCH 0213/1324] Enable loading CUDA redistributions in CPU Linux RBE configurations. This change is made to prevent hermetic CUDA repositories cache invalidation between the builds with `--config=cuda` and without it. It should speed up Github presubmit jobs. Currently CPU and GPU jobs use the machines in the same pool, and they share the RBE cache. Previously the cache was invalidated every time when `TF_NEED_CUDA` value changed between CPU and GPU builds, hence loading CUDA redistributions for GPU jobs took several minutes (see [this job](https://github.com/openxla/xla/actions/runs/14114621736/job/39541688832) for example: all the test results are cached, but CUDA redistributions were still downloaded). With adding `--repo_env USE_CUDA_REDISTRIBUTIONS=1` to RBE CPU linux job configurations, we load some CUDA redistributions once in RBE cache, and then reuse it between the jobs. 
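
Concretely, the CPU RBE configuration gains the following flags (the same
hunks appear in both .bazelrc and tensorflow.bazelrc below):

  build:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1
  build:rbe_linux_cpu --config=cuda_version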
PiperOrigin-RevId: 743700713 --- .bazelrc | 16 +++-- .../gpus/cuda/hermetic/cuda_configure.bzl | 68 ++++++++++++++----- third_party/nccl/hermetic/nccl_configure.bzl | 13 +++- third_party/xla/tensorflow.bazelrc | 16 +++-- .../gpus/cuda/hermetic/cuda_configure.bzl | 68 ++++++++++++++----- .../nccl/hermetic/nccl_configure.bzl | 13 +++- 6 files changed, 146 insertions(+), 48 deletions(-) diff --git a/.bazelrc b/.bazelrc index daa6094966b8f3..c33396789f01c0 100644 --- a/.bazelrc +++ b/.bazelrc @@ -252,13 +252,15 @@ build:mkl_aarch64 -c opt build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true build:mkl_aarch64_threadpool -c opt +# Default CUDA and CUDNN versions. +build:cuda_version --repo_env=HERMETIC_CUDA_VERSION="12.5.1" +build:cuda_version --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" + # CUDA: This config refers to building CUDA op kernels with nvcc. build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda -# Default CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.5.1" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" +build:cuda --config=cuda_version # This flag is needed to include CUDA libraries. build:cuda --@local_config_cuda//cuda:include_cuda_libs=true @@ -288,8 +290,7 @@ build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.5.1" -build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" +build:cuda_clang_official --config=cuda_version build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:cuda_clang_official --crosstool_top="@local_config_cuda//crosstool:toolchain" @@ -592,6 +593,11 @@ build:rbe_linux_cpu --python_path="/usr/bin/python3" # These you may need to change for your own GCP project. common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instances/default_instance +# Download CUDA/CUDNN redistributions to preserve the repositories cache between +# CPU and GPU builds. +build:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1 +build:rbe_linux_cpu --config=cuda_version + # TODO(kanglan): Remove it after toolchain update is complete. build:rbe_linux_cpu_old --config=rbe_linux build:rbe_linux_cpu_old --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain" diff --git a/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/gpus/cuda/hermetic/cuda_configure.bzl index b20ee5c3921732..826b91b03d9f99 100644 --- a/third_party/gpus/cuda/hermetic/cuda_configure.bzl +++ b/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -2,7 +2,10 @@ `cuda_configure` depends on the following environment variables: - * `TF_NEED_CUDA`: Whether to enable building with CUDA. + * `TF_NEED_CUDA`: Whether to enable building with CUDA toolchain. + * `USE_CUDA_REDISTRIBUTIONS`: Whether to use CUDA redistributions, but not + the CUDA toolchain. This can be used to preserve the cache between GPU and + CPU builds. * `TF_NVCC_CLANG` (deprecated): Whether to use clang for C++ and NVCC for Cuda compilation. * `CUDA_NVCC`: Whether to use NVCC for Cuda compilation. 
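# For illustration only: a local build can exercise the same preloading path
# by setting the documented environment variables, e.g.
#   bazel build --repo_env=USE_CUDA_REDISTRIBUTIONS=1 \
#     --repo_env=HERMETIC_CUDA_VERSION=12.5.1 //...
# (hypothetical invocation; the RBE jobs set these via the .bazelrc configs).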
@@ -120,6 +123,11 @@ def enable_cuda(repository_ctx): """Returns whether to build with CUDA support.""" return int(get_host_environ(repository_ctx, TF_NEED_CUDA, False)) +def use_cuda_redistributions(repository_ctx): + """Returns whether to use CUDA redistributions.""" + return (int(get_host_environ(repository_ctx, USE_CUDA_REDISTRIBUTIONS, False)) and + not int(get_host_environ(repository_ctx, _TF_NEED_ROCM, False))) + def _flag_enabled(repository_ctx, flag_name): return get_host_environ(repository_ctx, flag_name) == "1" @@ -459,23 +467,43 @@ def _create_dummy_repository(repository_ctx): # Set up cuda_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. - repository_ctx.template( - "cuda/cuda/cuda_config.h", - repository_ctx.attr.cuda_config_tpl, - { - "%{cuda_version}": "", - "%{cudart_version}": "", - "%{cupti_version}": "", - "%{cublas_version}": "", - "%{cusolver_version}": "", - "%{curand_version}": "", - "%{cufft_version}": "", - "%{cusparse_version}": "", - "%{cudnn_version}": "", - "%{cuda_toolkit_path}": "", - "%{cuda_compute_capabilities}": "", - }, - ) + if use_cuda_redistributions(repository_ctx): + cuda_config = _get_cuda_config(repository_ctx) + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": cuda_config.cudart_version, + "%{cudart_version}": cuda_config.cudart_version, + "%{cupti_version}": cuda_config.cupti_version, + "%{cublas_version}": cuda_config.cublas_version, + "%{cusolver_version}": cuda_config.cusolver_version, + "%{curand_version}": cuda_config.curand_version, + "%{cufft_version}": cuda_config.cufft_version, + "%{cusparse_version}": cuda_config.cusparse_version, + "%{cudnn_version}": cuda_config.cudnn_version, + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": "", + }, + ) + else: + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": "", + "%{cudart_version}": "", + "%{cupti_version}": "", + "%{cublas_version}": "", + "%{cusolver_version}": "", + "%{curand_version}": "", + "%{cufft_version}": "", + "%{cusparse_version}": "", + "%{cudnn_version}": "", + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": "", + }, + ) # Set up cuda_config.py, which is used by gen_build_info to provide # static build environment info to the API @@ -586,6 +614,8 @@ _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" HERMETIC_CUDA_VERSION = "HERMETIC_CUDA_VERSION" TF_CUDA_VERSION = "TF_CUDA_VERSION" TF_NEED_CUDA = "TF_NEED_CUDA" +_TF_NEED_ROCM = "TF_NEED_ROCM" +USE_CUDA_REDISTRIBUTIONS = "USE_CUDA_REDISTRIBUTIONS" _TF_NVCC_CLANG = "TF_NVCC_CLANG" _CUDA_NVCC = "CUDA_NVCC" _TF_SYSROOT = "TF_SYSROOT" @@ -595,6 +625,7 @@ _ENVIRONS = [ _CC, _CLANG_CUDA_COMPILER_PATH, TF_NEED_CUDA, + _TF_NEED_ROCM, _TF_NVCC_CLANG, _CUDA_NVCC, TF_CUDA_VERSION, @@ -606,6 +637,7 @@ _ENVIRONS = [ _TMPDIR, "LOCAL_CUDA_PATH", "LOCAL_CUDNN_PATH", + USE_CUDA_REDISTRIBUTIONS, ] cuda_configure = repository_rule( diff --git a/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/nccl/hermetic/nccl_configure.bzl index c1e49a6b9f1dd2..acbfd146e2392f 100644 --- a/third_party/nccl/hermetic/nccl_configure.bzl +++ b/third_party/nccl/hermetic/nccl_configure.bzl @@ -14,8 +14,10 @@ load( "HERMETIC_CUDA_VERSION", "TF_CUDA_VERSION", "TF_NEED_CUDA", + "USE_CUDA_REDISTRIBUTIONS", "enable_cuda", "get_cuda_version", + "use_cuda_redistributions", ) load( "//third_party/remote_config:common.bzl", @@ -157,7 +159,14 @@ def 
_nccl_autoconf_impl(repository_ctx): get_cpu_value(repository_ctx) != "Linux"): # Add a dummy build file to make bazel query happy. repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) - repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"") + if use_cuda_redistributions(repository_ctx): + nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + repository_ctx.file( + "nccl_config.h", + "#define TF_NCCL_VERSION \"%s\"" % nccl_version, + ) + else: + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"") else: _create_local_nccl_repository(repository_ctx) @@ -167,6 +176,8 @@ _ENVIRONS = [ _TF_NCCL_USE_STUB, HERMETIC_CUDA_VERSION, "LOCAL_NCCL_PATH", + USE_CUDA_REDISTRIBUTIONS, + "TF_NEED_ROCM", ] nccl_configure = repository_rule( diff --git a/third_party/xla/tensorflow.bazelrc b/third_party/xla/tensorflow.bazelrc index 13417de95522f5..1d47e43735baf3 100644 --- a/third_party/xla/tensorflow.bazelrc +++ b/third_party/xla/tensorflow.bazelrc @@ -157,13 +157,15 @@ build:mkl_aarch64 -c opt build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true build:mkl_aarch64_threadpool -c opt +# Default CUDA and CUDNN versions. +build:cuda_version --repo_env=HERMETIC_CUDA_VERSION="12.6.3" +build:cuda_version --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" + # CUDA: This config refers to building CUDA op kernels with nvcc. build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda -# Default CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.3" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" +build:cuda --config=cuda_version # This flag is needed to include CUDA libraries. build:cuda --@local_config_cuda//cuda:include_cuda_libs=true @@ -193,8 +195,7 @@ build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.6.3" -build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" +build:cuda_clang_official --config=cuda_version build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:cuda_clang_official --crosstool_top="@local_config_cuda//crosstool:toolchain" @@ -451,6 +452,11 @@ build:rbe_linux_cpu --python_path="/usr/bin/python3" # These you may need to change for your own GCP project. common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instances/default_instance +# Download CUDA/CUDNN redistributions to preserve the repositories cache between +# CPU and GPU builds. +build:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1 +build:rbe_linux_cpu --config=cuda_version + build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_configure.bzl index b20ee5c3921732..826b91b03d9f99 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_configure.bzl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -2,7 +2,10 @@ `cuda_configure` depends on the following environment variables: - * `TF_NEED_CUDA`: Whether to enable building with CUDA. + * `TF_NEED_CUDA`: Whether to enable building with CUDA toolchain. 
+ * `USE_CUDA_REDISTRIBUTIONS`: Whether to use CUDA redistributions, but not + the CUDA toolchain. This can be used to preserve the cache between GPU and + CPU builds. * `TF_NVCC_CLANG` (deprecated): Whether to use clang for C++ and NVCC for Cuda compilation. * `CUDA_NVCC`: Whether to use NVCC for Cuda compilation. @@ -120,6 +123,11 @@ def enable_cuda(repository_ctx): """Returns whether to build with CUDA support.""" return int(get_host_environ(repository_ctx, TF_NEED_CUDA, False)) +def use_cuda_redistributions(repository_ctx): + """Returns whether to use CUDA redistributions.""" + return (int(get_host_environ(repository_ctx, USE_CUDA_REDISTRIBUTIONS, False)) and + not int(get_host_environ(repository_ctx, _TF_NEED_ROCM, False))) + def _flag_enabled(repository_ctx, flag_name): return get_host_environ(repository_ctx, flag_name) == "1" @@ -459,23 +467,43 @@ def _create_dummy_repository(repository_ctx): # Set up cuda_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. - repository_ctx.template( - "cuda/cuda/cuda_config.h", - repository_ctx.attr.cuda_config_tpl, - { - "%{cuda_version}": "", - "%{cudart_version}": "", - "%{cupti_version}": "", - "%{cublas_version}": "", - "%{cusolver_version}": "", - "%{curand_version}": "", - "%{cufft_version}": "", - "%{cusparse_version}": "", - "%{cudnn_version}": "", - "%{cuda_toolkit_path}": "", - "%{cuda_compute_capabilities}": "", - }, - ) + if use_cuda_redistributions(repository_ctx): + cuda_config = _get_cuda_config(repository_ctx) + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": cuda_config.cudart_version, + "%{cudart_version}": cuda_config.cudart_version, + "%{cupti_version}": cuda_config.cupti_version, + "%{cublas_version}": cuda_config.cublas_version, + "%{cusolver_version}": cuda_config.cusolver_version, + "%{curand_version}": cuda_config.curand_version, + "%{cufft_version}": cuda_config.cufft_version, + "%{cusparse_version}": cuda_config.cusparse_version, + "%{cudnn_version}": cuda_config.cudnn_version, + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": "", + }, + ) + else: + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": "", + "%{cudart_version}": "", + "%{cupti_version}": "", + "%{cublas_version}": "", + "%{cusolver_version}": "", + "%{curand_version}": "", + "%{cufft_version}": "", + "%{cusparse_version}": "", + "%{cudnn_version}": "", + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": "", + }, + ) # Set up cuda_config.py, which is used by gen_build_info to provide # static build environment info to the API @@ -586,6 +614,8 @@ _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" HERMETIC_CUDA_VERSION = "HERMETIC_CUDA_VERSION" TF_CUDA_VERSION = "TF_CUDA_VERSION" TF_NEED_CUDA = "TF_NEED_CUDA" +_TF_NEED_ROCM = "TF_NEED_ROCM" +USE_CUDA_REDISTRIBUTIONS = "USE_CUDA_REDISTRIBUTIONS" _TF_NVCC_CLANG = "TF_NVCC_CLANG" _CUDA_NVCC = "CUDA_NVCC" _TF_SYSROOT = "TF_SYSROOT" @@ -595,6 +625,7 @@ _ENVIRONS = [ _CC, _CLANG_CUDA_COMPILER_PATH, TF_NEED_CUDA, + _TF_NEED_ROCM, _TF_NVCC_CLANG, _CUDA_NVCC, TF_CUDA_VERSION, @@ -606,6 +637,7 @@ _ENVIRONS = [ _TMPDIR, "LOCAL_CUDA_PATH", "LOCAL_CUDNN_PATH", + USE_CUDA_REDISTRIBUTIONS, ] cuda_configure = repository_rule( diff --git a/third_party/xla/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/xla/third_party/nccl/hermetic/nccl_configure.bzl index c1e49a6b9f1dd2..acbfd146e2392f 100644 --- 
a/third_party/xla/third_party/nccl/hermetic/nccl_configure.bzl +++ b/third_party/xla/third_party/nccl/hermetic/nccl_configure.bzl @@ -14,8 +14,10 @@ load( "HERMETIC_CUDA_VERSION", "TF_CUDA_VERSION", "TF_NEED_CUDA", + "USE_CUDA_REDISTRIBUTIONS", "enable_cuda", "get_cuda_version", + "use_cuda_redistributions", ) load( "//third_party/remote_config:common.bzl", @@ -157,7 +159,14 @@ def _nccl_autoconf_impl(repository_ctx): get_cpu_value(repository_ctx) != "Linux"): # Add a dummy build file to make bazel query happy. repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) - repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"") + if use_cuda_redistributions(repository_ctx): + nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + repository_ctx.file( + "nccl_config.h", + "#define TF_NCCL_VERSION \"%s\"" % nccl_version, + ) + else: + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"") else: _create_local_nccl_repository(repository_ctx) @@ -167,6 +176,8 @@ _ENVIRONS = [ _TF_NCCL_USE_STUB, HERMETIC_CUDA_VERSION, "LOCAL_NCCL_PATH", + USE_CUDA_REDISTRIBUTIONS, + "TF_NEED_ROCM", ] nccl_configure = repository_rule( From 40d2466c71f8a84268ffdaa31e2e60baa09d9d0c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 14:41:52 -0700 Subject: [PATCH 0214/1324] Clean up `Read()` in `file_system`: - Add a default implementation of the deprecated `Read()` in the base class to enable subclasses to migrate to implementing the new, safe `Read()`. - Use `ABSL_DEPRECATE_AND_INLINE()` to enable bots to fix existing call sites. - Remove the redundant size parameter from the new `Read()` as it can be inferred from the `Span` parameter. PiperOrigin-RevId: 743703411 --- third_party/xla/xla/tsl/platform/BUILD | 3 +- .../xla/xla/tsl/platform/file_system.h | 29 +++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/tsl/platform/BUILD b/third_party/xla/xla/tsl/platform/BUILD index 44daddd98211f5..cedaa35ff49072 100644 --- a/third_party/xla/xla/tsl/platform/BUILD +++ b/third_party/xla/xla/tsl/platform/BUILD @@ -421,14 +421,15 @@ cc_library( ], hdrs = ["file_system_helper.h"], deps = [ + ":env", ":errors", ":file_statistics", ":macros", ":status", ":statusor", ":types", - "//xla/tsl/platform:env", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:cord", diff --git a/third_party/xla/xla/tsl/platform/file_system.h b/third_party/xla/xla/tsl/platform/file_system.h index 65e530152c7e82..866c482d1dd194 100644 --- a/third_party/xla/xla/tsl/platform/file_system.h +++ b/third_party/xla/xla/tsl/platform/file_system.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include #include #include #include @@ -26,6 +27,8 @@ limitations under the License. #include #include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/file_statistics.h" @@ -783,13 +786,29 @@ class RandomAccessFile { /// because of EOF. /// /// Safe for concurrent use by multiple threads. 
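To make the migration in the hunk below concrete: once the base class provides a default body for the deprecated (offset, n, result, scratch) overload, a subclass only needs to implement the span-based overload. A minimal sketch of a migrated subclass (the POSIX-backed class and its pread-based body are hypothetical, not part of this patch):

    #include <unistd.h>

    #include "absl/status/status.h"
    #include "absl/strings/string_view.h"
    #include "absl/types/span.h"
    #include "xla/tsl/platform/file_system.h"

    class PosixReadOnlyFile : public tsl::RandomAccessFile {
     public:
      explicit PosixReadOnlyFile(int fd) : fd_(fd) {}

      // Un-hide the deprecated base-class overload so both call forms
      // still resolve on the derived type during the migration.
      using tsl::RandomAccessFile::Read;

      // Only the new, safe overload is implemented; the base class
      // forwards the deprecated form here.
      absl::Status Read(tsl::uint64 offset, absl::string_view& result,
                        absl::Span<char> scratch) const override {
        const ssize_t n = pread(fd_, scratch.data(), scratch.size(),
                                static_cast<off_t>(offset));
        if (n < 0) return absl::InternalError("pread failed");
        result = absl::string_view(scratch.data(), static_cast<size_t>(n));
        // Per the contract above, a short read due to EOF is OUT_OF_RANGE.
        return static_cast<size_t>(n) == scratch.size()
                   ? absl::OkStatus()
                   : absl::OutOfRangeError("reached end of file");
      }

     private:
      const int fd_;
    };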
- ABSL_DEPRECATED("Use the version that takes absl::Span instead.") + ABSL_DEPRECATE_AND_INLINE() virtual absl::Status Read(uint64 offset, size_t n, absl::string_view* result, - char* scratch) const = 0; - - virtual absl::Status Read(uint64 offset, size_t n, absl::string_view& result, + char* scratch) const { + // Subclasses should implement the safe version of Read() below instead of + // this. This implementation is provided to enable the migration: without + // this, when a subclass switches from implementing this (deprecated) Read() + // to the safe version, the compiler will complain that the subclass + // doesn't implement the old one. + return Read(offset, *result, absl::MakeSpan(scratch, n)); + } + + // Like the above, but takes an absl::Span instead of a size_t and a + // char*. + // TODO(b/393630847): + // - Make subclasses implement this method instead of the above, + // - Remove the above. + // - Mark this method as `= 0` to force subclasses to implement it. + virtual absl::Status Read(uint64 offset, absl::string_view& result, absl::Span scratch) const { - return Read(offset, n, &result, scratch.data()); + // This implementation is provided only for backward compatibility. + // If a subclass implements the deprecated Read() above instead of this, it + // will still work. + return Read(offset, scratch.size(), &result, scratch.data()); } #if defined(TF_CORD_SUPPORT) From 1f50c0f61e3c1ad91001a45bb4ccceea347a2b2d Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Thu, 3 Apr 2025 14:47:08 -0700 Subject: [PATCH 0215/1324] Add cache for OpStats PiperOrigin-RevId: 743705265 --- tensorflow/core/profiler/convert/repository.h | 4 +- .../profiler/convert/xplane_to_tools_data.cc | 76 ++++++++++--------- 2 files changed, 43 insertions(+), 37 deletions(-) diff --git a/tensorflow/core/profiler/convert/repository.h b/tensorflow/core/profiler/convert/repository.h index df90b16f5a8748..84ac5dd3b0188a 100644 --- a/tensorflow/core/profiler/convert/repository.h +++ b/tensorflow/core/profiler/convert/repository.h @@ -43,11 +43,13 @@ constexpr char kNoHostIdentifier[] = "NO_HOST"; enum StoredDataType { DCN_COLLECTIVE_STATS, + OP_STATS, }; static auto* kHostDataSuffixes = new std::vector>( - {{StoredDataType::DCN_COLLECTIVE_STATS, ".dcn_collective_stats.pb"}}); + {{StoredDataType::DCN_COLLECTIVE_STATS, ".dcn_collective_stats.pb"}, + {StoredDataType::OP_STATS, ".op_stats.pb"}}); // File system directory snapshot of a profile session. 
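The suffix table above is what turns the new enum value into an on-disk cache name. A minimal standalone sketch of that mapping (hypothetical helper; in the real code the lookup goes through SessionSnapshot and kHostDataSuffixes):

    #include <string>

    enum StoredDataType { DCN_COLLECTIVE_STATS, OP_STATS };

    // Each cached artifact is named <host identifier> + <type suffix>.
    std::string CacheFileName(StoredDataType type, const std::string& host) {
      switch (type) {
        case DCN_COLLECTIVE_STATS:
          return host + ".dcn_collective_stats.pb";
        case OP_STATS:
          return host + ".op_stats.pb";
      }
      return host;  // Unreachable for valid enum values.
    }

    // For the combined op stats written by the converter below, the host
    // identifier is the all-hosts sentinel, so the cache file name ends
    // in ".op_stats.pb".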
class SessionSnapshot { diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc index 816267f43362c0..885c729af15b50 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc @@ -155,15 +155,36 @@ absl::StatusOr ConvertXSpaceToTraceEvents( } } -absl::StatusOr ConvertMultiXSpacesToOverviewPage( - const SessionSnapshot& session_snapshot) { +absl::Status ConvertMultiXSpaceToCombinedOpStatsWithCache( + const SessionSnapshot& session_snapshot, OpStats* combined_op_stats) { OpStatsOptions options; - options.generate_kernel_stats_db = true; options.generate_op_metrics_db = true; options.generate_step_db = true; + options.generate_kernel_stats_db = true; + TF_ASSIGN_OR_RETURN(auto has_cache, + session_snapshot.HasCacheFile(StoredDataType::OP_STATS)); + if (has_cache.first) { + TF_RETURN_IF_ERROR(ReadBinaryProto(session_snapshot, + StoredDataType::OP_STATS, + kAllHostsIdentifier, combined_op_stats)); + + } else { + TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( + session_snapshot, options, combined_op_stats)); + if (!WriteBinaryProto(session_snapshot, StoredDataType::OP_STATS, + kAllHostsIdentifier, *combined_op_stats) + .ok()) { + LOG(WARNING) << "Failed to write op stats cache file."; + }; + } + return absl::OkStatus(); +} + +absl::StatusOr ConvertMultiXSpacesToOverviewPage( + const SessionSnapshot& session_snapshot) { OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); + TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( + session_snapshot, &combined_op_stats)); OverviewPage overview_page = ConvertOpStatsToOverviewPage(combined_op_stats); InferenceStats inference_stats; TF_RETURN_IF_ERROR(ConvertMultiXSpaceToInferenceStats(session_snapshot, "", @@ -175,34 +196,26 @@ absl::StatusOr ConvertMultiXSpacesToOverviewPage( absl::StatusOr ConvertMultiXSpacesToInputPipeline( const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); + TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( + session_snapshot, &combined_op_stats)); return ConvertOpStatsToInputPipelineAnalysis(combined_op_stats) .SerializeAsString(); } absl::StatusOr ConvertMultiXSpacesToTfStats( const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_kernel_stats_db = true; OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); + TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( + session_snapshot, &combined_op_stats)); return ConvertOpStatsToTfStats(combined_op_stats).SerializeAsString(); } absl::StatusOr ConvertMultiXSpacesToKernelStats( const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_kernel_stats_db = true; OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); + TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( + session_snapshot, &combined_op_stats)); return combined_op_stats.kernel_stats_db().SerializeAsString(); } @@ -225,12 +238,9 @@ absl::StatusOr 
ConvertXSpaceToMemoryProfile( absl::StatusOr ConvertMultiXSpacesToPodViewer( const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); + TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( + session_snapshot, &combined_op_stats)); std::string json_output; tsl::protobuf::util::JsonPrintOptions opts; @@ -275,11 +285,9 @@ absl::StatusOr ConvertMultiXSpacesToTfDataBottleneckAnalysis( absl::StatusOr ConvertMultiXSpacesToHloStats( const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); + TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( + session_snapshot, &combined_op_stats)); hlo_stats::HloStatsDatabase hlo_stats_db = ConvertOpStatsToHloStats(combined_op_stats); return HloStatsToDataTableJson(hlo_stats_db); @@ -287,11 +295,9 @@ absl::StatusOr ConvertMultiXSpacesToHloStats( absl::StatusOr ConvertMultiXSpacesToRooflineModel( const SessionSnapshot& session_snapshot) { - OpStatsOptions op_stats_options; - op_stats_options.generate_op_metrics_db = true; OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, op_stats_options, &combined_op_stats)); + TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( + session_snapshot, &combined_op_stats)); RooflineModelDatabase result = ConvertOpStatsToRooflineModel(combined_op_stats, true); RooflineModelDatabase result_without_infeed_outfeed = @@ -303,11 +309,9 @@ absl::StatusOr ConvertMultiXSpacesToRooflineModel( absl::StatusOr ConvertMultiXSpacesToOpProfileViewer( const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); + TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( + session_snapshot, &combined_op_stats)); tensorflow::profiler::op_profile::Profile profile; ConvertOpStatsToOpProfile( From ab21d2621d0d7e946de9f1f10ca3a8983d163915 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 3 Apr 2025 14:53:22 -0700 Subject: [PATCH 0216/1324] Integrate LLVM at llvm/llvm-project@537b6541e806 Updates LLVM usage to match [537b6541e806](https://github.com/llvm/llvm-project/commit/537b6541e806) PiperOrigin-RevId: 743707094 --- .../convert_control_to_data_outputs.cc | 2 +- third_party/llvm/generated.patch | 22 + third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 635 +++--------------- third_party/shardy/workspace.bzl | 4 +- third_party/stablehlo/temporary.patch | 48 ++ .../xla/third_party/shardy/temporary.patch | 635 +++--------------- .../xla/third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/stablehlo/temporary.patch | 48 ++ .../deallocation/transforms/buffer_reuse.cc | 2 +- 10 files changed, 292 insertions(+), 1112 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index 6262cad26ca6e3..dd2f4b7309bf48 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -445,7 +445,7 @@ void ChainResourceOps( // by `class_iter`). Keep track of ops that have already been processed. llvm::SmallDenseSet processed_ops; for (auto member_iter = - resource_equivalence_classes.member_begin(class_iter); + resource_equivalence_classes.member_begin(*class_iter); member_iter != resource_equivalence_classes.member_end(); ++member_iter) { ResourceAndDevice resource_and_device = *member_iter; diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..2e6ff5801f349f 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,23 @@ Auto generated patch. Do not edit or delete it, even if empty. 
+diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp +--- a/clang/lib/Sema/SemaCXXScopeSpec.cpp ++++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp +@@ -873,6 +873,7 @@ + DependentTemplateSpecializationTypeLoc SpecTL + = Builder.push(T); + SpecTL.setElaboratedKeywordLoc(SourceLocation()); ++ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); + SpecTL.setTemplateKeywordLoc(TemplateKWLoc); + SpecTL.setTemplateNameLoc(TemplateNameLoc); + SpecTL.setLAngleLoc(LAngleLoc); +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +@@ -1902,7 +1902,6 @@ + name = "inv_trigf_utils", + srcs = ["src/math/generic/inv_trigf_utils.cpp"], + hdrs = [ +- "src/math/generic/atan_utils.h", + "src/math/generic/inv_trigf_utils.h", + ], + deps = [ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index fd9baec89202f9..91166c32db1a11 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" - LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" + LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" + LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 1cc7fe6b95c33c..a292270454defc 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,566 +1,97 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 99ef3cb..509398d 100644 +index 509398d..2e6ff58 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,546 +1 @@ +@@ -1 +1,23 @@ Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp ----- a/clang/lib/Driver/ToolChains/Clang.cpp --+++ b/clang/lib/Driver/ToolChains/Clang.cpp --@@ -6397,7 +6397,9 @@ -- Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, -- options::OPT_fno_convergent_functions); -- --- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); --+ // NVPTX doesn't support PGO or coverage --+ if (!Triple.isNVPTX()) --+ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); -- -- Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); -- --diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu ----- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu --+++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu --@@ -0,0 +1,33 @@ --+// Check that profiling/coverage arguments doen't get passed down to device-side --+// compilation. 
--+// --+// --+// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// XRUN: -fprofile-generate %s 2>&1 | \ --+// XRUN: FileCheck --check-prefixes=CHECK,PROF %s --+// --+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// RUN: -fprofile-instr-generate %s 2>&1 | \ --+// RUN: FileCheck --check-prefixes=CHECK,PROF %s --+// --+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// RUN: -coverage %s 2>&1 | \ --+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s --+// --+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// RUN: -ftest-coverage %s 2>&1 | \ --+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s --+// --+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// RUN: -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \ --+// RUN: FileCheck --check-prefixes=CHECK,PROF %s --+// --+// --+// CHECK-NOT: error: unsupported option '-fprofile --+// CHECK-NOT: error: invalid argument --+// CHECK-DAG: "-fcuda-is-device" --+// CHECK-NOT: "-f{{[^"/]*coverage.*}}" --+// CHECK-NOT: "-fprofile{{[^"]*}}" --+// CHECK: "-triple" "x86_64-unknown-linux-gnu" --+// PROF: "-fprofile{{.*}}" --+// GCOV: "-coverage-notes-file= --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp ----- a/lldb/tools/lldb-dap/DAP.cpp --+++ b/lldb/tools/lldb-dap/DAP.cpp --@@ -711,12 +711,12 @@ -- [](const std::string &message) -> llvm::StringRef { -- return message; -- }, --- [](const protocol::Response::Message &message) --+ [](const protocol::ResponseMessage &message) -- -> llvm::StringRef { -- switch (message) { --- case protocol::Response::Message::cancelled: --+ case protocol::eResponseMessageCancelled: -- return "cancelled"; --- case protocol::Response::Message::notStopped: --+ case protocol::eResponseMessageNotStopped: -- return "notStopped"; -- } -- }), --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp ----- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp --+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp --@@ -7,6 +7,7 @@ -- //===----------------------------------------------------------------------===// -- -- #include "Protocol/ProtocolBase.h" --+#include "lldb/lldb-enumerations.h" -- #include "llvm/ADT/StringRef.h" -- #include "llvm/ADT/StringSwitch.h" -- #include "llvm/Support/ErrorHandling.h" --@@ -31,11 +32,8 @@ -- -- namespace lldb_dap::protocol { -- ---enum MessageType { --- eMessageTypeRequest, --- eMessageTypeResponse, --- eMessageTypeEvent ---}; --+FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse, --+ eMessageTypeEvent}; -- -- bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) { -- auto rawType = Params.getAsString(); --@@ -107,12 +105,12 @@ -- -- if (R.message) { -- assert(!R.success && "message can only be used if success is false"); --- if (const auto *messageEnum = std::get_if(&*R.message)) { --+ if (const auto *messageEnum = std::get_if(&*R.message)) { -- switch (*messageEnum) { --- case Response::Message::cancelled: --+ case eResponseMessageCancelled: -- Result.insert({"message", "cancelled"}); -- break; --- case Response::Message::notStopped: --+ case eResponseMessageNotStopped: -- Result.insert({"message", "notStopped"}); -- break; -- } --@@ -129,16 +127,16 @@ -- } -- -- bool fromJSON(json::Value const &Params, --- std::variant &M, json::Path P) { --+ std::variant &M, json::Path P) { -- auto rawMessage = 
Params.getAsString(); -- if (!rawMessage) { -- P.report("expected a string"); -- return false; -- } --- std::optional message = --- StringSwitch>(*rawMessage) --- .Case("cancelled", Response::Message::cancelled) --- .Case("notStopped", Response::Message::notStopped) --+ std::optional message = --+ StringSwitch>(*rawMessage) --+ .Case("cancelled", eResponseMessageCancelled) --+ .Case("notStopped", eResponseMessageNotStopped) -- .Default(std::nullopt); -- if (message) -- M = *message; --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h ----- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h --+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h --@@ -20,6 +20,7 @@ -- #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H -- #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H -- --+#include "lldb/lldb-enumerations.h" -- #include "llvm/Support/JSON.h" -- #include -- #include --@@ -64,15 +65,15 @@ -- llvm::json::Value toJSON(const Event &); -- bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); -- ---/// Response for a request. ---struct Response { --- enum class Message { --+FLAGS_ENUM(ResponseMessage){ -- /// The request was cancelled --- cancelled, --+ eResponseMessageCancelled, -- /// The request may be retried once the adapter is in a 'stopped' state --- notStopped, --- }; --+ eResponseMessageNotStopped, --+}; -- --+/// Response for a request. --+struct Response { -- /// Sequence number of the corresponding request. -- int64_t request_seq; -- --@@ -90,7 +91,7 @@ -- /// Contains the raw error in short form if `success` is false. This raw error -- /// might be interpreted by the client and is not shown in the UI. Some -- /// predefined values exist. --- std::optional> message; --+ std::optional> message; -- -- /// Contains request result if success is true and error details if success is -- /// false. --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h ----- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h --+++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h --@@ -22,6 +22,8 @@ -- -- #include "Protocol/ProtocolBase.h" -- #include "Protocol/ProtocolTypes.h" --+#include "lldb/lldb-enumerations.h" --+#include "llvm/ADT/DenseSet.h" -- #include "llvm/Support/JSON.h" -- #include -- #include --@@ -55,26 +57,26 @@ -- using DisconnectResponse = VoidResponse; -- -- /// Features supported by DAP clients. ---enum ClientFeature { --- eClientFeatureVariableType, --- eClientFeatureVariablePaging, --- eClientFeatureRunInTerminalRequest, --- eClientFeatureMemoryReferences, --- eClientFeatureProgressReporting, --- eClientFeatureInvalidatedEvent, --- eClientFeatureMemoryEvent, --- /// Client supports the `argsCanBeInterpretedByShell` attribute on the --- /// `runInTerminal` request. --- eClientFeatureArgsCanBeInterpretedByShell, --- eClientFeatureStartDebuggingRequest, --- /// The client will interpret ANSI escape sequences in the display of --- /// `OutputEvent.output` and `Variable.value` fields when --- /// `Capabilities.supportsANSIStyling` is also enabled. --- eClientFeatureANSIStyling, --+FLAGS_ENUM(ClientFeature){ --+ eClientFeatureVariableType, --+ eClientFeatureVariablePaging, --+ eClientFeatureRunInTerminalRequest, --+ eClientFeatureMemoryReferences, --+ eClientFeatureProgressReporting, --+ eClientFeatureInvalidatedEvent, --+ eClientFeatureMemoryEvent, --+ /// Client supports the `argsCanBeInterpretedByShell` attribute on the --+ /// `runInTerminal` request. 
--+ eClientFeatureArgsCanBeInterpretedByShell, --+ eClientFeatureStartDebuggingRequest, --+ /// The client will interpret ANSI escape sequences in the display of --+ /// `OutputEvent.output` and `Variable.value` fields when --+ /// `Capabilities.supportsANSIStyling` is also enabled. --+ eClientFeatureANSIStyling, -- }; -- -- /// Format of paths reported by the debug adapter. ---enum PathFormat { ePatFormatPath, ePathFormatURI }; --+FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; -- -- /// Arguments for `initialize` request. -- struct InitializeRequestArguments { --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h ----- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h --+++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h --@@ -20,6 +20,7 @@ -- #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H -- #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H -- --+#include "lldb/lldb-enumerations.h" -- #include "llvm/ADT/DenseSet.h" -- #include "llvm/Support/JSON.h" -- #include --@@ -56,12 +57,8 @@ -- }; -- llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); -- ---enum ColumnType { --- eColumnTypeString, --- eColumnTypeNumber, --- eColumnTypeBoolean, --- eColumnTypeTimestamp ---}; --+FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, --+ eColumnTypeTimestamp}; -- -- /// A ColumnDescriptor specifies what module attribute to show in a column of -- /// the modules view, how to format it, and what the column’s label should be. --@@ -90,27 +87,23 @@ -- -- /// Names of checksum algorithms that may be supported by a debug adapter. -- /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. ---enum ChecksumAlgorithm { --- eChecksumAlgorithmMD5, --- eChecksumAlgorithmSHA1, --- eChecksumAlgorithmSHA256, --- eChecksumAlgorithmTimestamp ---}; --+FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, --+ eChecksumAlgorithmSHA256, --+ eChecksumAlgorithmTimestamp}; -- llvm::json::Value toJSON(const ChecksumAlgorithm &); -- -- /// Describes one or more type of breakpoint a BreakpointMode applies to. This -- /// is a non-exhaustive enumeration and may expand as future breakpoint types -- /// are added. ---enum BreakpointModeApplicability { --- /// In `SourceBreakpoint`'s. --- eBreakpointModeApplicabilitySource, --- /// In exception breakpoints applied in the `ExceptionFilterOptions`. --- eBreakpointModeApplicabilityException, --- /// In data breakpoints requested in the `DataBreakpointInfo` request. --- eBreakpointModeApplicabilityData, --- /// In `InstructionBreakpoint`'s. --- eBreakpointModeApplicabilityInstruction ---}; --+FLAGS_ENUM(BreakpointModeApplicability){ --+ /// In `SourceBreakpoint`'s. --+ eBreakpointModeApplicabilitySource, --+ /// In exception breakpoints applied in the `ExceptionFilterOptions`. --+ eBreakpointModeApplicabilityException, --+ /// In data breakpoints requested in the `DataBreakpointInfo` request. --+ eBreakpointModeApplicabilityData, --+ /// In `InstructionBreakpoint`'s. --+ eBreakpointModeApplicabilityInstruction}; -- llvm::json::Value toJSON(const BreakpointModeApplicability &); -- -- /// A `BreakpointMode` is provided as a option when setting breakpoints on --@@ -133,101 +126,101 @@ -- llvm::json::Value toJSON(const BreakpointMode &); -- -- /// Debug Adapter Features flags supported by lldb-dap. ---enum AdapterFeature { --- /// The debug adapter supports ANSI escape sequences in styling of --- /// `OutputEvent.output` and `Variable.value` fields. 
--- eAdapterFeatureANSIStyling, --- /// The debug adapter supports the `breakpointLocations` request. --- eAdapterFeatureBreakpointLocationsRequest, --- /// The debug adapter supports the `cancel` request. --- eAdapterFeatureCancelRequest, --- /// The debug adapter supports the `clipboard` context value in the --- /// `evaluate` request. --- eAdapterFeatureClipboardContext, --- /// The debug adapter supports the `completions` request. --- eAdapterFeatureCompletionsRequest, --- /// The debug adapter supports conditional breakpoints. --- eAdapterFeatureConditionalBreakpoints, --- /// The debug adapter supports the `configurationDone` request. --- eAdapterFeatureConfigurationDoneRequest, --- /// The debug adapter supports the `asAddress` and `bytes` fields in the --- /// `dataBreakpointInfo` request. --- eAdapterFeatureDataBreakpointBytes, --- /// The debug adapter supports data breakpoints. --- eAdapterFeatureDataBreakpoints, --- /// The debug adapter supports the delayed loading of parts of the stack, --- /// which requires that both the `startFrame` and `levels` arguments and the --- /// `totalFrames` result of the `stackTrace` request are supported. --- eAdapterFeatureDelayedStackTraceLoading, --- /// The debug adapter supports the `disassemble` request. --- eAdapterFeatureDisassembleRequest, --- /// The debug adapter supports a (side effect free) `evaluate` request for --- /// data hovers. --- eAdapterFeatureEvaluateForHovers, --- /// The debug adapter supports `filterOptions` as an argument on the --- /// `setExceptionBreakpoints` request. --- eAdapterFeatureExceptionFilterOptions, --- /// The debug adapter supports the `exceptionInfo` request. --- eAdapterFeatureExceptionInfoRequest, --- /// The debug adapter supports `exceptionOptions` on the --- /// `setExceptionBreakpoints` request. --- eAdapterFeatureExceptionOptions, --- /// The debug adapter supports function breakpoints. --- eAdapterFeatureFunctionBreakpoints, --- /// The debug adapter supports the `gotoTargets` request. --- eAdapterFeatureGotoTargetsRequest, --- /// The debug adapter supports breakpoints that break execution after a --- /// specified number of hits. --- eAdapterFeatureHitConditionalBreakpoints, --- /// The debug adapter supports adding breakpoints based on instruction --- /// references. --- eAdapterFeatureInstructionBreakpoints, --- /// The debug adapter supports the `loadedSources` request. --- eAdapterFeatureLoadedSourcesRequest, --- /// The debug adapter supports log points by interpreting the `logMessage` --- /// attribute of the `SourceBreakpoint`. --- eAdapterFeatureLogPoints, --- /// The debug adapter supports the `modules` request. --- eAdapterFeatureModulesRequest, --- /// The debug adapter supports the `readMemory` request. --- eAdapterFeatureReadMemoryRequest, --- /// The debug adapter supports restarting a frame. --- eAdapterFeatureRestartFrame, --- /// The debug adapter supports the `restart` request. In this case a client --- /// should not implement `restart` by terminating and relaunching the --- /// adapter but by calling the `restart` request. --- eAdapterFeatureRestartRequest, --- /// The debug adapter supports the `setExpression` request. --- eAdapterFeatureSetExpression, --- /// The debug adapter supports setting a variable to a value. --- eAdapterFeatureSetVariable, --- /// The debug adapter supports the `singleThread` property on the execution --- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, --- /// `stepBack`). 
--- eAdapterFeatureSingleThreadExecutionRequests, --- /// The debug adapter supports stepping back via the `stepBack` and --- /// `reverseContinue` requests. --- eAdapterFeatureStepBack, --- /// The debug adapter supports the `stepInTargets` request. --- eAdapterFeatureStepInTargetsRequest, --- /// The debug adapter supports stepping granularities (argument --- /// `granularity`) for the stepping requests. --- eAdapterFeatureSteppingGranularity, --- /// The debug adapter supports the `terminate` request. --- eAdapterFeatureTerminateRequest, --- /// The debug adapter supports the `terminateThreads` request. --- eAdapterFeatureTerminateThreadsRequest, --- /// The debug adapter supports the `suspendDebuggee` attribute on the --- /// `disconnect` request. --- eAdapterFeatureSuspendDebuggee, --- /// The debug adapter supports a `format` attribute on the `stackTrace`, --- /// `variables`, and `evaluate` requests. --- eAdapterFeatureValueFormattingOptions, --- /// The debug adapter supports the `writeMemory` request. --- eAdapterFeatureWriteMemoryRequest, --- /// The debug adapter supports the `terminateDebuggee` attribute on the --- /// `disconnect` request. --- eAdapterFeatureTerminateDebuggee, --+FLAGS_ENUM(AdapterFeature){ --+ /// The debug adapter supports ANSI escape sequences in styling of --+ /// `OutputEvent.output` and `Variable.value` fields. --+ eAdapterFeatureANSIStyling, --+ /// The debug adapter supports the `breakpointLocations` request. --+ eAdapterFeatureBreakpointLocationsRequest, --+ /// The debug adapter supports the `cancel` request. --+ eAdapterFeatureCancelRequest, --+ /// The debug adapter supports the `clipboard` context value in the --+ /// `evaluate` request. --+ eAdapterFeatureClipboardContext, --+ /// The debug adapter supports the `completions` request. --+ eAdapterFeatureCompletionsRequest, --+ /// The debug adapter supports conditional breakpoints. --+ eAdapterFeatureConditionalBreakpoints, --+ /// The debug adapter supports the `configurationDone` request. --+ eAdapterFeatureConfigurationDoneRequest, --+ /// The debug adapter supports the `asAddress` and `bytes` fields in the --+ /// `dataBreakpointInfo` request. --+ eAdapterFeatureDataBreakpointBytes, --+ /// The debug adapter supports data breakpoints. --+ eAdapterFeatureDataBreakpoints, --+ /// The debug adapter supports the delayed loading of parts of the stack, --+ /// which requires that both the `startFrame` and `levels` arguments and the --+ /// `totalFrames` result of the `stackTrace` request are supported. --+ eAdapterFeatureDelayedStackTraceLoading, --+ /// The debug adapter supports the `disassemble` request. --+ eAdapterFeatureDisassembleRequest, --+ /// The debug adapter supports a (side effect free) `evaluate` request for --+ /// data hovers. --+ eAdapterFeatureEvaluateForHovers, --+ /// The debug adapter supports `filterOptions` as an argument on the --+ /// `setExceptionBreakpoints` request. --+ eAdapterFeatureExceptionFilterOptions, --+ /// The debug adapter supports the `exceptionInfo` request. --+ eAdapterFeatureExceptionInfoRequest, --+ /// The debug adapter supports `exceptionOptions` on the --+ /// `setExceptionBreakpoints` request. --+ eAdapterFeatureExceptionOptions, --+ /// The debug adapter supports function breakpoints. --+ eAdapterFeatureFunctionBreakpoints, --+ /// The debug adapter supports the `gotoTargets` request. --+ eAdapterFeatureGotoTargetsRequest, --+ /// The debug adapter supports breakpoints that break execution after a --+ /// specified number of hits. 
--+ eAdapterFeatureHitConditionalBreakpoints, --+ /// The debug adapter supports adding breakpoints based on instruction --+ /// references. --+ eAdapterFeatureInstructionBreakpoints, --+ /// The debug adapter supports the `loadedSources` request. --+ eAdapterFeatureLoadedSourcesRequest, --+ /// The debug adapter supports log points by interpreting the `logMessage` --+ /// attribute of the `SourceBreakpoint`. --+ eAdapterFeatureLogPoints, --+ /// The debug adapter supports the `modules` request. --+ eAdapterFeatureModulesRequest, --+ /// The debug adapter supports the `readMemory` request. --+ eAdapterFeatureReadMemoryRequest, --+ /// The debug adapter supports restarting a frame. --+ eAdapterFeatureRestartFrame, --+ /// The debug adapter supports the `restart` request. In this case a client --+ /// should not implement `restart` by terminating and relaunching the --+ /// adapter but by calling the `restart` request. --+ eAdapterFeatureRestartRequest, --+ /// The debug adapter supports the `setExpression` request. --+ eAdapterFeatureSetExpression, --+ /// The debug adapter supports setting a variable to a value. --+ eAdapterFeatureSetVariable, --+ /// The debug adapter supports the `singleThread` property on the execution --+ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, --+ /// `stepBack`). --+ eAdapterFeatureSingleThreadExecutionRequests, --+ /// The debug adapter supports stepping back via the `stepBack` and --+ /// `reverseContinue` requests. --+ eAdapterFeatureStepBack, --+ /// The debug adapter supports the `stepInTargets` request. --+ eAdapterFeatureStepInTargetsRequest, --+ /// The debug adapter supports stepping granularities (argument --+ /// `granularity`) for the stepping requests. --+ eAdapterFeatureSteppingGranularity, --+ /// The debug adapter supports the `terminate` request. --+ eAdapterFeatureTerminateRequest, --+ /// The debug adapter supports the `terminateThreads` request. --+ eAdapterFeatureTerminateThreadsRequest, --+ /// The debug adapter supports the `suspendDebuggee` attribute on the --+ /// `disconnect` request. --+ eAdapterFeatureSuspendDebuggee, --+ /// The debug adapter supports a `format` attribute on the `stackTrace`, --+ /// `variables`, and `evaluate` requests. --+ eAdapterFeatureValueFormattingOptions, --+ /// The debug adapter supports the `writeMemory` request. --+ eAdapterFeatureWriteMemoryRequest, --+ /// The debug adapter supports the `terminateDebuggee` attribute on the --+ /// `disconnect` request. --+ eAdapterFeatureTerminateDebuggee, -- }; -- -- /// Information about the capabilities of a debug adapter. --@@ -268,10 +261,10 @@ -- }; -- llvm::json::Value toJSON(const Capabilities &); -- ---enum PresentationHint { --- ePresentationHintNormal, --- ePresentationHintEmphasize, --- ePresentationHintDeemphasize, --+FLAGS_ENUM(PresentationHint){ --+ ePresentationHintNormal, --+ ePresentationHintEmphasize, --+ ePresentationHintDeemphasize, -- }; -- -- /// A `Source` is a descriptor for source code. 
It is returned from the debug --diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test ----- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test --+++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test --@@ -1,7 +1,7 @@ -- // Header -- // -- // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) ---// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) --+// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) -- // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) -- // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) -- // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) --diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c ----- a/offload/test/offloading/gpupgo/pgo1.c --+++ b/offload/test/offloading/gpupgo/pgo1.c --@@ -14,7 +14,7 @@ -- // RUN: %target_triple.%basename_t.clang.profraw | \ -- // RUN: %fcheck-generic --check-prefix="CLANG-PGO" -- ---// REQUIRES: gpu --+// REQUIRES: amdgpu -- // REQUIRES: pgo -- -- int test1(int a) { return a / 2; } --diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c ----- a/offload/test/offloading/gpupgo/pgo2.c --+++ b/offload/test/offloading/gpupgo/pgo2.c --@@ -48,7 +48,7 @@ -- // RUN: %target_triple.%basename_t.hfdi.profraw \ -- // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" -- ---// REQUIRES: gpu --+// REQUIRES: amdgpu -- // REQUIRES: pgo -- -- int main() { ++diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp ++--- a/clang/lib/Sema/SemaCXXScopeSpec.cpp +++++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp ++@@ -873,6 +873,7 @@ ++ DependentTemplateSpecializationTypeLoc SpecTL ++ = Builder.push(T); ++ SpecTL.setElaboratedKeywordLoc(SourceLocation()); +++ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); ++ SpecTL.setTemplateKeywordLoc(TemplateKWLoc); ++ SpecTL.setTemplateNameLoc(TemplateNameLoc); ++ SpecTL.setLAngleLoc(LAngleLoc); ++diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ++--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +++++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ++@@ -1902,7 +1902,6 @@ ++ name = "inv_trigf_utils", ++ srcs = ["src/math/generic/inv_trigf_utils.cpp"], ++ hdrs = [ ++- "src/math/generic/atan_utils.h", ++ "src/math/generic/inv_trigf_utils.h", ++ ], ++ deps = [ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 005737a..fd9baec 100644 +index fd9baec..91166c3 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" -- LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" -+ LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" -+ LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" +- LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" +- LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" ++ LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" ++ LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" 
tf_http_archive( name = name, +diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch +index 8b13789..90ca4ec 100755 +--- a/third_party/stablehlo/temporary.patch ++++ b/third_party/stablehlo/temporary.patch +@@ -1 +1,49 @@ ++diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir ++--- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir +++++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir ++@@ -12,7 +12,7 @@ ++ return %2 : tensor<14x15x0x33xf64> ++ } ++ func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { ++- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> +++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> ++ return %cst : tensor<14x15x0x17xcomplex> ++ } ++ func.func private @expected() -> (tensor<14x15x0x33xf64> {mhlo.layout_mode = "default"}) { ++diff --ruN a/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir ++--- stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir +++++ stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir ++@@ -12,7 +12,7 @@ ++ return %2 : tensor<14x15x0x33xf32> ++ } ++ func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { ++- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> +++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> ++ return %cst : tensor<14x15x0x17xcomplex> ++ } ++ func.func private @expected() -> (tensor<14x15x0x33xf32> {mhlo.layout_mode = "default"}) { ++diff --ruN a/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir ++--- stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir +++++ stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir ++@@ -16,7 +16,7 @@ ++ return %cst : tensor<14x15x0x17xf32> ++ } ++ func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { ++- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> +++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> ++ return %cst : tensor<14x15x0x9xcomplex> ++ } ++ } ++diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir ++--- stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir +++++ stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir ++@@ -16,7 +16,7 @@ ++ return %cst : tensor<14x15x0x17xf64> ++ } ++ func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { ++- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> +++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> ++ return %cst : tensor<14x15x0x9xcomplex> ++ } ++ } + diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index f93f7a93cea2e8..92b2397d3a6afc 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "f25a97f80402e43de93f46931c6dddc485e8dad0" - SHARDY_SHA256 = "c7e55d3902175c064d3dd6bffc856c5e0198ff6c8f1410c3a97ed2c8e85ddb30" + SHARDY_COMMIT = "84e1e3a76e7b827a6d1621df0c649c3449da1540" + SHARDY_SHA256 = "ddb163c1a40466e320b882821e63ee7cb48e31e96bd05109daff38fca0086f21" tf_http_archive( name = "shardy", diff --git 
a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b137891791fe9..90ca4ec1d0d819 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,49 @@ +diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir +--- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir ++++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir +@@ -12,7 +12,7 @@ + return %2 : tensor<14x15x0x33xf64> + } + func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { +- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> ++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> + return %cst : tensor<14x15x0x17xcomplex> + } + func.func private @expected() -> (tensor<14x15x0x33xf64> {mhlo.layout_mode = "default"}) { +diff --ruN a/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir +--- stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir ++++ stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir +@@ -12,7 +12,7 @@ + return %2 : tensor<14x15x0x33xf32> + } + func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { +- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> ++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> + return %cst : tensor<14x15x0x17xcomplex> + } + func.func private @expected() -> (tensor<14x15x0x33xf32> {mhlo.layout_mode = "default"}) { +diff --ruN a/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir +--- stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir ++++ stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir +@@ -16,7 +16,7 @@ + return %cst : tensor<14x15x0x17xf32> + } + func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { +- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> ++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> + return %cst : tensor<14x15x0x9xcomplex> + } + } +diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir +--- stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir ++++ stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir +@@ -16,7 +16,7 @@ + return %cst : tensor<14x15x0x17xf64> + } + func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { +- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> ++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> + return %cst : tensor<14x15x0x9xcomplex> + } + } diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 1cc7fe6b95c33c..a292270454defc 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,566 +1,97 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 99ef3cb..509398d 100644 +index 509398d..2e6ff58 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,546 +1 @@ +@@ -1 +1,23 @@ Auto generated patch. Do not edit or delete it, even if empty. 
--diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp ----- a/clang/lib/Driver/ToolChains/Clang.cpp --+++ b/clang/lib/Driver/ToolChains/Clang.cpp --@@ -6397,7 +6397,9 @@ -- Args.AddLastArg(CmdArgs, options::OPT_fconvergent_functions, -- options::OPT_fno_convergent_functions); -- --- addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); --+ // NVPTX doesn't support PGO or coverage --+ if (!Triple.isNVPTX()) --+ addPGOAndCoverageFlags(TC, C, JA, Output, Args, SanitizeArgs, CmdArgs); -- -- Args.AddLastArg(CmdArgs, options::OPT_fclang_abi_compat_EQ); -- --diff -ruN --strip-trailing-cr a/clang/test/Driver/cuda-no-pgo-or-coverage.cu b/clang/test/Driver/cuda-no-pgo-or-coverage.cu ----- a/clang/test/Driver/cuda-no-pgo-or-coverage.cu --+++ b/clang/test/Driver/cuda-no-pgo-or-coverage.cu --@@ -0,0 +1,33 @@ --+// Check that profiling/coverage arguments doen't get passed down to device-side --+// compilation. --+// --+// --+// XRUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// XRUN: -fprofile-generate %s 2>&1 | \ --+// XRUN: FileCheck --check-prefixes=CHECK,PROF %s --+// --+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// RUN: -fprofile-instr-generate %s 2>&1 | \ --+// RUN: FileCheck --check-prefixes=CHECK,PROF %s --+// --+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// RUN: -coverage %s 2>&1 | \ --+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s --+// --+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// RUN: -ftest-coverage %s 2>&1 | \ --+// RUN: FileCheck --check-prefixes=CHECK,GCOV %s --+// --+// RUN: not %clang -### --target=x86_64-linux-gnu -c --cuda-gpu-arch=sm_20 \ --+// RUN: -fprofile-instr-generate -fcoverage-mapping %s 2>&1 | \ --+// RUN: FileCheck --check-prefixes=CHECK,PROF %s --+// --+// --+// CHECK-NOT: error: unsupported option '-fprofile --+// CHECK-NOT: error: invalid argument --+// CHECK-DAG: "-fcuda-is-device" --+// CHECK-NOT: "-f{{[^"/]*coverage.*}}" --+// CHECK-NOT: "-fprofile{{[^"]*}}" --+// CHECK: "-triple" "x86_64-unknown-linux-gnu" --+// PROF: "-fprofile{{.*}}" --+// GCOV: "-coverage-notes-file= --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp ----- a/lldb/tools/lldb-dap/DAP.cpp --+++ b/lldb/tools/lldb-dap/DAP.cpp --@@ -711,12 +711,12 @@ -- [](const std::string &message) -> llvm::StringRef { -- return message; -- }, --- [](const protocol::Response::Message &message) --+ [](const protocol::ResponseMessage &message) -- -> llvm::StringRef { -- switch (message) { --- case protocol::Response::Message::cancelled: --+ case protocol::eResponseMessageCancelled: -- return "cancelled"; --- case protocol::Response::Message::notStopped: --+ case protocol::eResponseMessageNotStopped: -- return "notStopped"; -- } -- }), --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp ----- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp --+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp --@@ -7,6 +7,7 @@ -- //===----------------------------------------------------------------------===// -- -- #include "Protocol/ProtocolBase.h" --+#include "lldb/lldb-enumerations.h" -- #include "llvm/ADT/StringRef.h" -- #include "llvm/ADT/StringSwitch.h" -- #include "llvm/Support/ErrorHandling.h" --@@ -31,11 +32,8 @@ -- -- namespace lldb_dap::protocol { -- ---enum MessageType { --- 
eMessageTypeRequest, --- eMessageTypeResponse, --- eMessageTypeEvent ---}; --+FLAGS_ENUM(MessageType){eMessageTypeRequest, eMessageTypeResponse, --+ eMessageTypeEvent}; -- -- bool fromJSON(const json::Value &Params, MessageType &M, json::Path P) { -- auto rawType = Params.getAsString(); --@@ -107,12 +105,12 @@ -- -- if (R.message) { -- assert(!R.success && "message can only be used if success is false"); --- if (const auto *messageEnum = std::get_if(&*R.message)) { --+ if (const auto *messageEnum = std::get_if(&*R.message)) { -- switch (*messageEnum) { --- case Response::Message::cancelled: --+ case eResponseMessageCancelled: -- Result.insert({"message", "cancelled"}); -- break; --- case Response::Message::notStopped: --+ case eResponseMessageNotStopped: -- Result.insert({"message", "notStopped"}); -- break; -- } --@@ -129,16 +127,16 @@ -- } -- -- bool fromJSON(json::Value const &Params, --- std::variant &M, json::Path P) { --+ std::variant &M, json::Path P) { -- auto rawMessage = Params.getAsString(); -- if (!rawMessage) { -- P.report("expected a string"); -- return false; -- } --- std::optional message = --- StringSwitch>(*rawMessage) --- .Case("cancelled", Response::Message::cancelled) --- .Case("notStopped", Response::Message::notStopped) --+ std::optional message = --+ StringSwitch>(*rawMessage) --+ .Case("cancelled", eResponseMessageCancelled) --+ .Case("notStopped", eResponseMessageNotStopped) -- .Default(std::nullopt); -- if (message) -- M = *message; --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h ----- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h --+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h --@@ -20,6 +20,7 @@ -- #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_H -- #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_H -- --+#include "lldb/lldb-enumerations.h" -- #include "llvm/Support/JSON.h" -- #include -- #include --@@ -64,15 +65,15 @@ -- llvm::json::Value toJSON(const Event &); -- bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); -- ---/// Response for a request. ---struct Response { --- enum class Message { --+FLAGS_ENUM(ResponseMessage){ -- /// The request was cancelled --- cancelled, --+ eResponseMessageCancelled, -- /// The request may be retried once the adapter is in a 'stopped' state --- notStopped, --- }; --+ eResponseMessageNotStopped, --+}; -- --+/// Response for a request. --+struct Response { -- /// Sequence number of the corresponding request. -- int64_t request_seq; -- --@@ -90,7 +91,7 @@ -- /// Contains the raw error in short form if `success` is false. This raw error -- /// might be interpreted by the client and is not shown in the UI. Some -- /// predefined values exist. --- std::optional> message; --+ std::optional> message; -- -- /// Contains request result if success is true and error details if success is -- /// false. --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h ----- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h --+++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h --@@ -22,6 +22,8 @@ -- -- #include "Protocol/ProtocolBase.h" -- #include "Protocol/ProtocolTypes.h" --+#include "lldb/lldb-enumerations.h" --+#include "llvm/ADT/DenseSet.h" -- #include "llvm/Support/JSON.h" -- #include -- #include --@@ -55,26 +57,26 @@ -- using DisconnectResponse = VoidResponse; -- -- /// Features supported by DAP clients. 
---enum ClientFeature { --- eClientFeatureVariableType, --- eClientFeatureVariablePaging, --- eClientFeatureRunInTerminalRequest, --- eClientFeatureMemoryReferences, --- eClientFeatureProgressReporting, --- eClientFeatureInvalidatedEvent, --- eClientFeatureMemoryEvent, --- /// Client supports the `argsCanBeInterpretedByShell` attribute on the --- /// `runInTerminal` request. --- eClientFeatureArgsCanBeInterpretedByShell, --- eClientFeatureStartDebuggingRequest, --- /// The client will interpret ANSI escape sequences in the display of --- /// `OutputEvent.output` and `Variable.value` fields when --- /// `Capabilities.supportsANSIStyling` is also enabled. --- eClientFeatureANSIStyling, --+FLAGS_ENUM(ClientFeature){ --+ eClientFeatureVariableType, --+ eClientFeatureVariablePaging, --+ eClientFeatureRunInTerminalRequest, --+ eClientFeatureMemoryReferences, --+ eClientFeatureProgressReporting, --+ eClientFeatureInvalidatedEvent, --+ eClientFeatureMemoryEvent, --+ /// Client supports the `argsCanBeInterpretedByShell` attribute on the --+ /// `runInTerminal` request. --+ eClientFeatureArgsCanBeInterpretedByShell, --+ eClientFeatureStartDebuggingRequest, --+ /// The client will interpret ANSI escape sequences in the display of --+ /// `OutputEvent.output` and `Variable.value` fields when --+ /// `Capabilities.supportsANSIStyling` is also enabled. --+ eClientFeatureANSIStyling, -- }; -- -- /// Format of paths reported by the debug adapter. ---enum PathFormat { ePatFormatPath, ePathFormatURI }; --+FLAGS_ENUM(PathFormat){ePatFormatPath, ePathFormatURI}; -- -- /// Arguments for `initialize` request. -- struct InitializeRequestArguments { --diff -ruN --strip-trailing-cr a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h ----- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h --+++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h --@@ -20,6 +20,7 @@ -- #ifndef LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H -- #define LLDB_TOOLS_LLDB_DAP_PROTOCOL_PROTOCOL_TYPES_H -- --+#include "lldb/lldb-enumerations.h" -- #include "llvm/ADT/DenseSet.h" -- #include "llvm/Support/JSON.h" -- #include --@@ -56,12 +57,8 @@ -- }; -- llvm::json::Value toJSON(const ExceptionBreakpointsFilter &); -- ---enum ColumnType { --- eColumnTypeString, --- eColumnTypeNumber, --- eColumnTypeBoolean, --- eColumnTypeTimestamp ---}; --+FLAGS_ENUM(ColumnType){eColumnTypeString, eColumnTypeNumber, eColumnTypeBoolean, --+ eColumnTypeTimestamp}; -- -- /// A ColumnDescriptor specifies what module attribute to show in a column of -- /// the modules view, how to format it, and what the column’s label should be. --@@ -90,27 +87,23 @@ -- -- /// Names of checksum algorithms that may be supported by a debug adapter. -- /// Values: ‘MD5’, ‘SHA1’, ‘SHA256’, ‘timestamp’. ---enum ChecksumAlgorithm { --- eChecksumAlgorithmMD5, --- eChecksumAlgorithmSHA1, --- eChecksumAlgorithmSHA256, --- eChecksumAlgorithmTimestamp ---}; --+FLAGS_ENUM(ChecksumAlgorithm){eChecksumAlgorithmMD5, eChecksumAlgorithmSHA1, --+ eChecksumAlgorithmSHA256, --+ eChecksumAlgorithmTimestamp}; -- llvm::json::Value toJSON(const ChecksumAlgorithm &); -- -- /// Describes one or more type of breakpoint a BreakpointMode applies to. This -- /// is a non-exhaustive enumeration and may expand as future breakpoint types -- /// are added. ---enum BreakpointModeApplicability { --- /// In `SourceBreakpoint`'s. --- eBreakpointModeApplicabilitySource, --- /// In exception breakpoints applied in the `ExceptionFilterOptions`. 
--- eBreakpointModeApplicabilityException, --- /// In data breakpoints requested in the `DataBreakpointInfo` request. --- eBreakpointModeApplicabilityData, --- /// In `InstructionBreakpoint`'s. --- eBreakpointModeApplicabilityInstruction ---}; --+FLAGS_ENUM(BreakpointModeApplicability){ --+ /// In `SourceBreakpoint`'s. --+ eBreakpointModeApplicabilitySource, --+ /// In exception breakpoints applied in the `ExceptionFilterOptions`. --+ eBreakpointModeApplicabilityException, --+ /// In data breakpoints requested in the `DataBreakpointInfo` request. --+ eBreakpointModeApplicabilityData, --+ /// In `InstructionBreakpoint`'s. --+ eBreakpointModeApplicabilityInstruction}; -- llvm::json::Value toJSON(const BreakpointModeApplicability &); -- -- /// A `BreakpointMode` is provided as a option when setting breakpoints on --@@ -133,101 +126,101 @@ -- llvm::json::Value toJSON(const BreakpointMode &); -- -- /// Debug Adapter Features flags supported by lldb-dap. ---enum AdapterFeature { --- /// The debug adapter supports ANSI escape sequences in styling of --- /// `OutputEvent.output` and `Variable.value` fields. --- eAdapterFeatureANSIStyling, --- /// The debug adapter supports the `breakpointLocations` request. --- eAdapterFeatureBreakpointLocationsRequest, --- /// The debug adapter supports the `cancel` request. --- eAdapterFeatureCancelRequest, --- /// The debug adapter supports the `clipboard` context value in the --- /// `evaluate` request. --- eAdapterFeatureClipboardContext, --- /// The debug adapter supports the `completions` request. --- eAdapterFeatureCompletionsRequest, --- /// The debug adapter supports conditional breakpoints. --- eAdapterFeatureConditionalBreakpoints, --- /// The debug adapter supports the `configurationDone` request. --- eAdapterFeatureConfigurationDoneRequest, --- /// The debug adapter supports the `asAddress` and `bytes` fields in the --- /// `dataBreakpointInfo` request. --- eAdapterFeatureDataBreakpointBytes, --- /// The debug adapter supports data breakpoints. --- eAdapterFeatureDataBreakpoints, --- /// The debug adapter supports the delayed loading of parts of the stack, --- /// which requires that both the `startFrame` and `levels` arguments and the --- /// `totalFrames` result of the `stackTrace` request are supported. --- eAdapterFeatureDelayedStackTraceLoading, --- /// The debug adapter supports the `disassemble` request. --- eAdapterFeatureDisassembleRequest, --- /// The debug adapter supports a (side effect free) `evaluate` request for --- /// data hovers. --- eAdapterFeatureEvaluateForHovers, --- /// The debug adapter supports `filterOptions` as an argument on the --- /// `setExceptionBreakpoints` request. --- eAdapterFeatureExceptionFilterOptions, --- /// The debug adapter supports the `exceptionInfo` request. --- eAdapterFeatureExceptionInfoRequest, --- /// The debug adapter supports `exceptionOptions` on the --- /// `setExceptionBreakpoints` request. --- eAdapterFeatureExceptionOptions, --- /// The debug adapter supports function breakpoints. --- eAdapterFeatureFunctionBreakpoints, --- /// The debug adapter supports the `gotoTargets` request. --- eAdapterFeatureGotoTargetsRequest, --- /// The debug adapter supports breakpoints that break execution after a --- /// specified number of hits. --- eAdapterFeatureHitConditionalBreakpoints, --- /// The debug adapter supports adding breakpoints based on instruction --- /// references. --- eAdapterFeatureInstructionBreakpoints, --- /// The debug adapter supports the `loadedSources` request. 
--- eAdapterFeatureLoadedSourcesRequest, --- /// The debug adapter supports log points by interpreting the `logMessage` --- /// attribute of the `SourceBreakpoint`. --- eAdapterFeatureLogPoints, --- /// The debug adapter supports the `modules` request. --- eAdapterFeatureModulesRequest, --- /// The debug adapter supports the `readMemory` request. --- eAdapterFeatureReadMemoryRequest, --- /// The debug adapter supports restarting a frame. --- eAdapterFeatureRestartFrame, --- /// The debug adapter supports the `restart` request. In this case a client --- /// should not implement `restart` by terminating and relaunching the --- /// adapter but by calling the `restart` request. --- eAdapterFeatureRestartRequest, --- /// The debug adapter supports the `setExpression` request. --- eAdapterFeatureSetExpression, --- /// The debug adapter supports setting a variable to a value. --- eAdapterFeatureSetVariable, --- /// The debug adapter supports the `singleThread` property on the execution --- /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, --- /// `stepBack`). --- eAdapterFeatureSingleThreadExecutionRequests, --- /// The debug adapter supports stepping back via the `stepBack` and --- /// `reverseContinue` requests. --- eAdapterFeatureStepBack, --- /// The debug adapter supports the `stepInTargets` request. --- eAdapterFeatureStepInTargetsRequest, --- /// The debug adapter supports stepping granularities (argument --- /// `granularity`) for the stepping requests. --- eAdapterFeatureSteppingGranularity, --- /// The debug adapter supports the `terminate` request. --- eAdapterFeatureTerminateRequest, --- /// The debug adapter supports the `terminateThreads` request. --- eAdapterFeatureTerminateThreadsRequest, --- /// The debug adapter supports the `suspendDebuggee` attribute on the --- /// `disconnect` request. --- eAdapterFeatureSuspendDebuggee, --- /// The debug adapter supports a `format` attribute on the `stackTrace`, --- /// `variables`, and `evaluate` requests. --- eAdapterFeatureValueFormattingOptions, --- /// The debug adapter supports the `writeMemory` request. --- eAdapterFeatureWriteMemoryRequest, --- /// The debug adapter supports the `terminateDebuggee` attribute on the --- /// `disconnect` request. --- eAdapterFeatureTerminateDebuggee, --+FLAGS_ENUM(AdapterFeature){ --+ /// The debug adapter supports ANSI escape sequences in styling of --+ /// `OutputEvent.output` and `Variable.value` fields. --+ eAdapterFeatureANSIStyling, --+ /// The debug adapter supports the `breakpointLocations` request. --+ eAdapterFeatureBreakpointLocationsRequest, --+ /// The debug adapter supports the `cancel` request. --+ eAdapterFeatureCancelRequest, --+ /// The debug adapter supports the `clipboard` context value in the --+ /// `evaluate` request. --+ eAdapterFeatureClipboardContext, --+ /// The debug adapter supports the `completions` request. --+ eAdapterFeatureCompletionsRequest, --+ /// The debug adapter supports conditional breakpoints. --+ eAdapterFeatureConditionalBreakpoints, --+ /// The debug adapter supports the `configurationDone` request. --+ eAdapterFeatureConfigurationDoneRequest, --+ /// The debug adapter supports the `asAddress` and `bytes` fields in the --+ /// `dataBreakpointInfo` request. --+ eAdapterFeatureDataBreakpointBytes, --+ /// The debug adapter supports data breakpoints. 
--+ eAdapterFeatureDataBreakpoints, --+ /// The debug adapter supports the delayed loading of parts of the stack, --+ /// which requires that both the `startFrame` and `levels` arguments and the --+ /// `totalFrames` result of the `stackTrace` request are supported. --+ eAdapterFeatureDelayedStackTraceLoading, --+ /// The debug adapter supports the `disassemble` request. --+ eAdapterFeatureDisassembleRequest, --+ /// The debug adapter supports a (side effect free) `evaluate` request for --+ /// data hovers. --+ eAdapterFeatureEvaluateForHovers, --+ /// The debug adapter supports `filterOptions` as an argument on the --+ /// `setExceptionBreakpoints` request. --+ eAdapterFeatureExceptionFilterOptions, --+ /// The debug adapter supports the `exceptionInfo` request. --+ eAdapterFeatureExceptionInfoRequest, --+ /// The debug adapter supports `exceptionOptions` on the --+ /// `setExceptionBreakpoints` request. --+ eAdapterFeatureExceptionOptions, --+ /// The debug adapter supports function breakpoints. --+ eAdapterFeatureFunctionBreakpoints, --+ /// The debug adapter supports the `gotoTargets` request. --+ eAdapterFeatureGotoTargetsRequest, --+ /// The debug adapter supports breakpoints that break execution after a --+ /// specified number of hits. --+ eAdapterFeatureHitConditionalBreakpoints, --+ /// The debug adapter supports adding breakpoints based on instruction --+ /// references. --+ eAdapterFeatureInstructionBreakpoints, --+ /// The debug adapter supports the `loadedSources` request. --+ eAdapterFeatureLoadedSourcesRequest, --+ /// The debug adapter supports log points by interpreting the `logMessage` --+ /// attribute of the `SourceBreakpoint`. --+ eAdapterFeatureLogPoints, --+ /// The debug adapter supports the `modules` request. --+ eAdapterFeatureModulesRequest, --+ /// The debug adapter supports the `readMemory` request. --+ eAdapterFeatureReadMemoryRequest, --+ /// The debug adapter supports restarting a frame. --+ eAdapterFeatureRestartFrame, --+ /// The debug adapter supports the `restart` request. In this case a client --+ /// should not implement `restart` by terminating and relaunching the --+ /// adapter but by calling the `restart` request. --+ eAdapterFeatureRestartRequest, --+ /// The debug adapter supports the `setExpression` request. --+ eAdapterFeatureSetExpression, --+ /// The debug adapter supports setting a variable to a value. --+ eAdapterFeatureSetVariable, --+ /// The debug adapter supports the `singleThread` property on the execution --+ /// requests (`continue`, `next`, `stepIn`, `stepOut`, `reverseContinue`, --+ /// `stepBack`). --+ eAdapterFeatureSingleThreadExecutionRequests, --+ /// The debug adapter supports stepping back via the `stepBack` and --+ /// `reverseContinue` requests. --+ eAdapterFeatureStepBack, --+ /// The debug adapter supports the `stepInTargets` request. --+ eAdapterFeatureStepInTargetsRequest, --+ /// The debug adapter supports stepping granularities (argument --+ /// `granularity`) for the stepping requests. --+ eAdapterFeatureSteppingGranularity, --+ /// The debug adapter supports the `terminate` request. --+ eAdapterFeatureTerminateRequest, --+ /// The debug adapter supports the `terminateThreads` request. --+ eAdapterFeatureTerminateThreadsRequest, --+ /// The debug adapter supports the `suspendDebuggee` attribute on the --+ /// `disconnect` request. --+ eAdapterFeatureSuspendDebuggee, --+ /// The debug adapter supports a `format` attribute on the `stackTrace`, --+ /// `variables`, and `evaluate` requests. 
--+ eAdapterFeatureValueFormattingOptions, --+ /// The debug adapter supports the `writeMemory` request. --+ eAdapterFeatureWriteMemoryRequest, --+ /// The debug adapter supports the `terminateDebuggee` attribute on the --+ /// `disconnect` request. --+ eAdapterFeatureTerminateDebuggee, -- }; -- -- /// Information about the capabilities of a debug adapter. --@@ -268,10 +261,10 @@ -- }; -- llvm::json::Value toJSON(const Capabilities &); -- ---enum PresentationHint { --- ePresentationHintNormal, --- ePresentationHintEmphasize, --- ePresentationHintDeemphasize, --+FLAGS_ENUM(PresentationHint){ --+ ePresentationHintNormal, --+ ePresentationHintEmphasize, --+ ePresentationHintDeemphasize, -- }; -- -- /// A `Source` is a descriptor for source code. It is returned from the debug --diff -ruN --strip-trailing-cr a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test ----- a/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test --+++ b/llvm/test/tools/llvm-profdata/malformed-ptr-to-counter-array.test --@@ -1,7 +1,7 @@ -- // Header -- // -- // INSTR_PROF_RAW_HEADER(uint64_t, Magic, __llvm_profile_get_magic()) ---// INSTR_PROF_RAW_HEADER(uint64_t, Version, Version) --+// INSTR_PROF_RAW_HEADER(uint64_t, Version, __llvm_profile_get_version()) -- // INSTR_PROF_RAW_HEADER(uint64_t, BinaryIdsSize, __llvm_write_binary_ids(NULL)) -- // INSTR_PROF_RAW_HEADER(uint64_t, DataSize, DataSize) -- // INSTR_PROF_RAW_HEADER(uint64_t, CountersSize, CountersSize) --diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo1.c b/offload/test/offloading/gpupgo/pgo1.c ----- a/offload/test/offloading/gpupgo/pgo1.c --+++ b/offload/test/offloading/gpupgo/pgo1.c --@@ -14,7 +14,7 @@ -- // RUN: %target_triple.%basename_t.clang.profraw | \ -- // RUN: %fcheck-generic --check-prefix="CLANG-PGO" -- ---// REQUIRES: gpu --+// REQUIRES: amdgpu -- // REQUIRES: pgo -- -- int test1(int a) { return a / 2; } --diff -ruN --strip-trailing-cr a/offload/test/offloading/gpupgo/pgo2.c b/offload/test/offloading/gpupgo/pgo2.c ----- a/offload/test/offloading/gpupgo/pgo2.c --+++ b/offload/test/offloading/gpupgo/pgo2.c --@@ -48,7 +48,7 @@ -- // RUN: %target_triple.%basename_t.hfdi.profraw \ -- // RUN: | %fcheck-generic --check-prefix="LLVM-DEVICE" -- ---// REQUIRES: gpu --+// REQUIRES: amdgpu -- // REQUIRES: pgo -- -- int main() { ++diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp ++--- a/clang/lib/Sema/SemaCXXScopeSpec.cpp +++++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp ++@@ -873,6 +873,7 @@ ++ DependentTemplateSpecializationTypeLoc SpecTL ++ = Builder.push(T); ++ SpecTL.setElaboratedKeywordLoc(SourceLocation()); +++ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); ++ SpecTL.setTemplateKeywordLoc(TemplateKWLoc); ++ SpecTL.setTemplateNameLoc(TemplateNameLoc); ++ SpecTL.setLAngleLoc(LAngleLoc); ++diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ++--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +++++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ++@@ -1902,7 +1902,6 @@ ++ name = "inv_trigf_utils", ++ srcs = ["src/math/generic/inv_trigf_utils.cpp"], ++ hdrs = [ ++- "src/math/generic/atan_utils.h", ++ "src/math/generic/inv_trigf_utils.h", ++ ], ++ deps = [ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 005737a..fd9baec 100644 +index fd9baec..91166c3 100644 --- 
a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "c0952a931c7d556ca9f0073d86d591a37eb60477" -- LLVM_SHA256 = "0a24477c0e3d6f3418dad1fe6375a74381b7b174c32c750f97ea05d540dddb84" -+ LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" -+ LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" +- LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" +- LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" ++ LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" ++ LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" tf_http_archive( name = name, +diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch +index 8b13789..90ca4ec 100755 +--- a/third_party/stablehlo/temporary.patch ++++ b/third_party/stablehlo/temporary.patch +@@ -1 +1,49 @@ ++diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir ++--- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir +++++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir ++@@ -12,7 +12,7 @@ ++ return %2 : tensor<14x15x0x33xf64> ++ } ++ func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { ++- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> +++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> ++ return %cst : tensor<14x15x0x17xcomplex> ++ } ++ func.func private @expected() -> (tensor<14x15x0x33xf64> {mhlo.layout_mode = "default"}) { ++diff --ruN a/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir ++--- stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir +++++ stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir ++@@ -12,7 +12,7 @@ ++ return %2 : tensor<14x15x0x33xf32> ++ } ++ func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { ++- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> +++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> ++ return %cst : tensor<14x15x0x17xcomplex> ++ } ++ func.func private @expected() -> (tensor<14x15x0x33xf32> {mhlo.layout_mode = "default"}) { ++diff --ruN a/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir ++--- stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir +++++ stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir ++@@ -16,7 +16,7 @@ ++ return %cst : tensor<14x15x0x17xf32> ++ } ++ func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { ++- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> +++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> ++ return %cst : tensor<14x15x0x9xcomplex> ++ } ++ } ++diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir ++--- stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir +++++ stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir ++@@ -16,7 +16,7 @@ ++ return %cst : tensor<14x15x0x17xf64> ++ } ++ func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { ++- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> +++ %cst = 
stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> ++ return %cst : tensor<14x15x0x9xcomplex> ++ } ++ } + diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index f93f7a93cea2e8..92b2397d3a6afc 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "f25a97f80402e43de93f46931c6dddc485e8dad0" - SHARDY_SHA256 = "c7e55d3902175c064d3dd6bffc856c5e0198ff6c8f1410c3a97ed2c8e85ddb30" + SHARDY_COMMIT = "84e1e3a76e7b827a6d1621df0c649c3449da1540" + SHARDY_SHA256 = "ddb163c1a40466e320b882821e63ee7cb48e31e96bd05109daff38fca0086f21" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 8b137891791fe9..90ca4ec1d0d819 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1 +1,49 @@ +diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir +--- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir ++++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir +@@ -12,7 +12,7 @@ + return %2 : tensor<14x15x0x33xf64> + } + func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { +- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> ++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> + return %cst : tensor<14x15x0x17xcomplex> + } + func.func private @expected() -> (tensor<14x15x0x33xf64> {mhlo.layout_mode = "default"}) { +diff --ruN a/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir +--- stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir ++++ stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir +@@ -12,7 +12,7 @@ + return %2 : tensor<14x15x0x33xf32> + } + func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { +- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> ++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> + return %cst : tensor<14x15x0x17xcomplex> + } + func.func private @expected() -> (tensor<14x15x0x33xf32> {mhlo.layout_mode = "default"}) { +diff --ruN a/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir +--- stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir ++++ stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir +@@ -16,7 +16,7 @@ + return %cst : tensor<14x15x0x17xf32> + } + func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { +- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> ++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> + return %cst : tensor<14x15x0x9xcomplex> + } + } +diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir +--- stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir ++++ stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir +@@ -16,7 +16,7 @@ + return %cst : tensor<14x15x0x17xf64> + } + func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { +- %cst = 
stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> ++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> + return %cst : tensor<14x15x0x9xcomplex> + } + } diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc index 3a09b6e3b33814..464884a2d426f0 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc +++ b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc @@ -461,7 +461,7 @@ bool simplifyLoopDeallocs(Block& block) { breaks_if_you_move_ops::ValueSet equivalentOperands; llvm::SmallVector deallocs; bool failed = false; - for (auto member = eq.member_begin(it); + for (auto member = eq.member_begin(*it); !failed && member != eq.member_end(); ++member) { if (operands.contains(*member)) { equivalentOperands.insert(*member); From af3c61d6bf1e9bd08fcab6cd0056ce0b83c478c1 Mon Sep 17 00:00:00 2001 From: Rachel Han Date: Thu, 3 Apr 2025 15:18:20 -0700 Subject: [PATCH 0217/1324] Add HLO->MLIR conversion for result accuracy. PiperOrigin-RevId: 743715760 --- .../hlo_to_mhlo/attribute_importer.cc | 22 +++++++++++ .../hlo_to_mhlo/attribute_importer.h | 4 ++ .../hlo_to_mhlo/hlo_function_importer.cc | 39 +++++++++++++------ .../xla/hlo/translate/hlo_to_mhlo/tests/BUILD | 1 + .../hlo_to_mhlo/tests/result_accuracy.hlo | 33 ++++++++++++++++ 5 files changed, 87 insertions(+), 12 deletions(-) create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/result_accuracy.hlo diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc index aca98fa965126b..0be4b75f5e11e8 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -406,4 +407,25 @@ absl::StatusOr ExtractLayoutsFromTuple( return ExtractLayoutsFromShapes(shape.tuple_shapes(), builder); } +mlir::mhlo::ResultAccuracyAttr ConvertResultAccuracy( + const ResultAccuracy& result_accuracy, mlir::Builder* builder) { + if (result_accuracy.has_tolerance()) { + return mlir::mhlo::ResultAccuracyAttr::get( + builder->getContext(), + llvm::APFloat(result_accuracy.tolerance().atol()), + llvm::APFloat(result_accuracy.tolerance().rtol()), + result_accuracy.tolerance().ulps(), + // Explicitly set the mode to TOLERANCE since ResultAccuracy has no + // TOLERANCE enum. 
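+          // Illustrative sketch (hypothetical values, not from this change):
+          // a tolerance of {atol: 1.0, rtol: 0.0, ulps: 10} is expected to
+          // import as roughly
+          //   #mhlo.result_accuracy<atol = 1.0, rtol = 0.0, ulps = 10,
+          //     mode = #mhlo.result_accuracy_mode<TOLERANCE>>.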
+          mlir::mhlo::ResultAccuracyModeAttr::get(
+              builder->getContext(),
+              mlir::mhlo::ResultAccuracyMode::TOLERANCE));
+  }
+  return mlir::mhlo::ResultAccuracyAttr::get(
+      builder->getContext(), llvm::APFloat(0.0), llvm::APFloat(0.0), 0,
+      mlir::mhlo::ResultAccuracyModeAttr::get(
+          builder->getContext(),
+          mlir::mhlo::symbolizeResultAccuracyMode(result_accuracy.mode())
+              .value()));
+}
+
 }  // namespace xla
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h
index 93e4514aa6091c..b64ed37483a264 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h
@@ -103,6 +103,10 @@ absl::StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
 absl::StatusOr<mlir::ArrayAttr> ExtractLayoutsFromTuple(const xla::Shape shape,
                                                         mlir::Builder* builder);
 
+// Converts the ResultAccuracy to ResultAccuracyAttr.
+mlir::mhlo::ResultAccuracyAttr ConvertResultAccuracy(
+    const ResultAccuracy& result_accuracy, mlir::Builder* builder);
+
 }  // namespace xla
 
 #endif  // XLA_HLO_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
index 72acb9e0fc58f3..44d298bd69f5a8 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -2047,20 +2047,14 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       NO_ATTRIBUTE_CASE(kAnd, AndOp);
       NO_ATTRIBUTE_CASE(kAtan2, Atan2Op);
       NO_ATTRIBUTE_CASE(kBitcastConvert, BitcastConvertOp);
-      NO_ATTRIBUTE_CASE(kCbrt, CbrtOp);
       NO_ATTRIBUTE_CASE(kClz, ClzOp);
       NO_ATTRIBUTE_CASE(kCeil, CeilOp);
       NO_ATTRIBUTE_CASE(kClamp, ClampOp);
       NO_ATTRIBUTE_CASE(kComplex, ComplexOp);
-      NO_ATTRIBUTE_CASE(kCos, CosineOp);
       NO_ATTRIBUTE_CASE(kDivide, DivOp);
-      NO_ATTRIBUTE_CASE(kExp, ExpOp);
-      NO_ATTRIBUTE_CASE(kExpm1, Expm1Op);
       NO_ATTRIBUTE_CASE(kFloor, FloorOp);
       NO_ATTRIBUTE_CASE(kIsFinite, IsFiniteOp);
       NO_ATTRIBUTE_CASE(kImag, ImagOp);
-      NO_ATTRIBUTE_CASE(kLog, LogOp);
-      NO_ATTRIBUTE_CASE(kLog1p, Log1pOp);
       NO_ATTRIBUTE_CASE(kMaximum, MaxOp);
       NO_ATTRIBUTE_CASE(kMinimum, MinOp);
       NO_ATTRIBUTE_CASE(kMultiply, MulOp);
@@ -2074,7 +2068,6 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       NO_ATTRIBUTE_CASE(kRemainder, RemOp);
       NO_ATTRIBUTE_CASE(kReplicaId, ReplicaIdOp);
       NO_ATTRIBUTE_CASE(kStochasticConvert, StochasticConvertOp);
-      NO_ATTRIBUTE_CASE(kLogistic, LogisticOp);
       NO_ATTRIBUTE_CASE(kErf, ErfOp);
       // The dimensions attribute is not present on the HLO Reshape
       // instruction. If dimensions are non-default, the XLA builder
@@ -2082,23 +2075,45 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       NO_ATTRIBUTE_CASE(kReshape, ReshapeOp);
       NO_ATTRIBUTE_CASE(kRoundNearestAfz, RoundOp);
       NO_ATTRIBUTE_CASE(kRoundNearestEven, RoundNearestEvenOp);
-      NO_ATTRIBUTE_CASE(kRsqrt, RsqrtOp);
       NO_ATTRIBUTE_CASE(kSelect, SelectOp);
       NO_ATTRIBUTE_CASE(kShiftLeft, ShiftLeftOp);
       NO_ATTRIBUTE_CASE(kShiftRightArithmetic, ShiftRightArithmeticOp);
       NO_ATTRIBUTE_CASE(kShiftRightLogical, ShiftRightLogicalOp);
       NO_ATTRIBUTE_CASE(kSign, SignOp);
-      NO_ATTRIBUTE_CASE(kSin, SineOp);
-      NO_ATTRIBUTE_CASE(kSqrt, SqrtOp);
       NO_ATTRIBUTE_CASE(kSubtract, SubtractOp);
-      NO_ATTRIBUTE_CASE(kTan, TanOp);
-      NO_ATTRIBUTE_CASE(kTanh, TanhOp);
       NO_ATTRIBUTE_CASE(kTuple, TupleOp);
       NO_ATTRIBUTE_CASE(kXor, XorOp);
       NO_ATTRIBUTE_CASE(kCopy, CopyOp);
 #undef NO_ATTRIBUTE_CASE
 
+#define RESULT_ACCURACY_CASE(hlo_op_code, mlir_op)                          \
+  case HloOpcode::hlo_op_code: {                                            \
+    if (instruction->has_result_accuracy()) {                               \
+      attributes.push_back(builder_->getNamedAttr(                          \
+          "result_accuracy",                                                \
+          ConvertResultAccuracy(instruction->result_accuracy(), builder_))); \
+    }                                                                       \
+    return func_builder                                                     \
+        ->create<mlir_op>(loc, result_type, operands, attributes)           \
+        .getOperation();                                                    \
+  }
+
+      RESULT_ACCURACY_CASE(kCbrt, CbrtOp);
+      RESULT_ACCURACY_CASE(kCos, CosineOp);
+      RESULT_ACCURACY_CASE(kExp, ExpOp);
+      RESULT_ACCURACY_CASE(kExpm1, Expm1Op);
+      RESULT_ACCURACY_CASE(kLog, LogOp);
+      RESULT_ACCURACY_CASE(kLog1p, Log1pOp);
+      RESULT_ACCURACY_CASE(kLogistic, LogisticOp);
+      RESULT_ACCURACY_CASE(kRsqrt, RsqrtOp);
+      RESULT_ACCURACY_CASE(kSin, SineOp);
+      RESULT_ACCURACY_CASE(kSqrt, SqrtOp);
+      RESULT_ACCURACY_CASE(kTan, TanOp);
+      RESULT_ACCURACY_CASE(kTanh, TanhOp);
+
+#undef RESULT_ACCURACY_CASE
+
     case HloOpcode::kFusion: {
       // Flatten the tuple-typed operands.
       llvm::SmallVector<mlir::Value> flattened_operands =
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD
index b4c7e22be7f205..a407328a2ada0b 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD
@@ -38,6 +38,7 @@ lit_test_suite(
         "stacktrace_to_location.hlo",
         "types.hlo",
         "while.hlo",
+        "result_accuracy.hlo",
     ],
     include = [
         "*.hlo",
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/result_accuracy.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/result_accuracy.hlo
new file mode 100644
index 00000000000000..5464e511c9dd8b
--- /dev/null
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/result_accuracy.hlo
@@ -0,0 +1,33 @@
+// RUN: hlo-translate -hlo-to-mlir -emit-mhlo -split-input-file %s | FileCheck %s
+
+HloModule main, entry_computation_layout={(f32[])->f32[]}
+
+ENTRY %main (Arg_0.1: f32[]) -> f32[] {
+  %Arg_0.1 = f32[] parameter(0)
+  // CHECK: %0 = mhlo.exponential %arg0 {result_accuracy = #mhlo.result_accuracy<atol = 1.000000e+00, rtol = 0.000000e+00, ulps = 10, mode = #mhlo.result_accuracy_mode<TOLERANCE>>} : tensor<f32>
+  ROOT %exponential.2 = f32[] exponential(%Arg_0.1), result_accuracy={tolerance={atol=1.0,rtol=0,ulps=10}}
+}
+
+// -----
+
+ENTRY %main (Arg_0.1: f32[]) -> f32[] {
+  %Arg_0.1 = f32[] parameter(0)
+  // CHECK: %0 = mhlo.exponential %arg0 {result_accuracy = #mhlo.result_accuracy<atol = 0.000000e+00, rtol = 0.000000e+00, ulps = 0, mode = #mhlo.result_accuracy_mode<HIGHEST>>} : tensor<f32>
+  ROOT %exponential.2 = f32[] exponential(%Arg_0.1), result_accuracy={mode=HIGHEST}
+}
+
+// -----
+
+ENTRY %main (Arg_0.1: f32[]) -> f32[] {
+  %Arg_0.1 = f32[] parameter(0)
+  // CHECK: %0 = mhlo.exponential %arg0 : tensor<f32>
+  ROOT %exponential.2 = f32[] exponential(%Arg_0.1), result_accuracy={mode=DEFAULT}
+}
+
+// -----
+
+ENTRY %main (Arg_0.1: f32[]) -> f32[] {
+  %Arg_0.1 = f32[] parameter(0)
+  // CHECK: %0 = mhlo.exponential %arg0 : tensor<f32>
+  ROOT %exponential.2 = f32[] exponential(%Arg_0.1)
+}

From 40dc788ec0377f20ebf09ee932d8fb9ca1bfa78c Mon Sep 17 00:00:00 2001
From: Penporn Koanantakool
Date: Thu, 3 Apr 2025 16:04:41 -0700
Subject: [PATCH 0218/1324] [xla:cpu:xnn] Make XnnDotThunk support BF16 x BF16 -> F32 matmul.
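
Illustrative example (shapes are hypothetical, not taken from the change
itself): after this patch a mixed-precision dot such as

  ENTRY e {
    p0 = bf16[64,64] parameter(0)
    p1 = bf16[64,64] parameter(1)
    ROOT dot = f32[64,64] dot(p0, p1),
      lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }

can be dispatched to XnnDotThunk on CPUs with AVX512_BF16 instead of being
upcast to f32 by FloatNormalization.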
PiperOrigin-RevId: 743730787 --- third_party/xla/xla/backends/cpu/BUILD | 2 + .../cpu/benchmarks/dot_benchmark_test.cc | 4 +- .../xla/xla/backends/cpu/runtime/BUILD | 1 - .../xla/xla/backends/cpu/runtime/dot_thunk.cc | 3 + .../xla/backends/cpu/runtime/xnnpack/BUILD | 6 +- .../cpu/runtime/xnnpack/xnn_dot_thunk.cc | 17 +- .../cpu/runtime/xnnpack/xnn_dot_thunk_test.cc | 54 +++++-- .../xla/xla/backends/cpu/xnn_fusion.cc | 18 ++- third_party/xla/xla/backends/cpu/xnn_fusion.h | 4 +- third_party/xla/xla/service/cpu/BUILD | 13 +- .../xla/xla/service/cpu/cpu_compiler.cc | 30 +++- .../xla/xla/service/cpu/cpu_float_support.h | 18 ++- .../xla/service/cpu/cpu_float_support_test.cc | 152 +++++++++++++----- .../xla/xla/service/cpu/thunk_emitter.cc | 9 +- 14 files changed, 257 insertions(+), 74 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/BUILD b/third_party/xla/xla/backends/cpu/BUILD index 726b57671f6e80..42d359f6f76cd3 100644 --- a/third_party/xla/xla/backends/cpu/BUILD +++ b/third_party/xla/xla/backends/cpu/BUILD @@ -97,11 +97,13 @@ cc_library( deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/codegen:target_machine_features", "//xla/backends/cpu/runtime:dot_lib", "//xla/hlo/ir:hlo", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", ], ) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/dot_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/dot_benchmark_test.cc index ced6f8f7af0a93..56211f3ef29a44 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/dot_benchmark_test.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/dot_benchmark_test.cc @@ -35,6 +35,7 @@ namespace xla::cpu { static void BM_BatchedDot(benchmark::State& state, HloBenchmarkOptions options) { PrimitiveType dtype = static_cast(state.range(0)); + PrimitiveType out_dtype = F32; int64_t d0 = state.range(1); int64_t d1 = state.range(2); @@ -44,7 +45,7 @@ static void BM_BatchedDot(benchmark::State& state, ENTRY e { p0 = $dtype[$d0,$d1,$d1] parameter(0) p1 = $dtype[$d0,$d1,$d1] parameter(1) - ROOT dot = $dtype[$d0,$d1,$d1] dot(p0, p1), + ROOT dot = $out_dtype[$d0,$d1,$d1] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} } @@ -69,6 +70,7 @@ static void BM_BatchedDot(benchmark::State& state, CHECK_OK(RunHloBenchmark( state, hlo, args, {{"$dtype", primitive_util::LowercasePrimitiveTypeName(dtype)}, + {"$out_dtype", primitive_util::LowercasePrimitiveTypeName(out_dtype)}, {"$d0", absl::StrCat(d0)}, {"$d1", absl::StrCat(d1)}}, options)); diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index 5759e3657150a4..eae9577bbbe3e1 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -818,7 +818,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", - "@local_tsl//tsl/profiler/lib:traceme", ], ) diff --git a/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc index 7a68eaf0503fc8..cd68d42acf28b2 100644 --- a/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc @@ -172,6 +172,9 @@ tsl::AsyncValueRef DotThunk::Execute( }; switch (element_type) { + case BF16: + dispatch(bfloat16{}); // 
Enable Eigen BF16 kernel for fallback. + break; case F16: dispatch(half{}); break; diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD b/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD index 897feba2b377ea..44de812771bdde 100644 --- a/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD @@ -141,6 +141,7 @@ cc_library( ":xnn_fusion_thunk", ":xnn_interop", "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/runtime:dot_lib", "//xla/backends/cpu/runtime:thunk", @@ -162,7 +163,6 @@ xla_cc_test( srcs = ["xnn_dot_thunk_test.cc"], deps = [ ":xnn_dot_thunk", - "//xla:executable_run_options", "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -173,10 +173,10 @@ xla_cc_test( "//xla/tsl/platform:env", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:platform_port", ], ) diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc index 7a667cd47439aa..8ff54f739a9e8e 100644 --- a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc @@ -33,11 +33,13 @@ limitations under the License. #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h" #include "xla/backends/cpu/runtime/xnnpack/xnn_interop.h" +#include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/shape.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { @@ -60,16 +62,25 @@ absl::StatusOr XnnDotThunk::BuildDotSubgraph( std::vector rhs_dims = dims(dot_slices_.rhs_shape.dimensions()); std::vector out_dims = dims(dot_slices_.out_shape.dimensions()); + PrimitiveType dtype = dot_slices_.lhs_shape.element_type(); + if (dtype != F32 && dtype != BF16) { + return InvalidArgument("Unsupported input data type for XnnDotThunk: %s", + primitive_util::LowercasePrimitiveTypeName(dtype)); + } + xnn_datatype input_dtype = + (dtype == F32) ? 
xnn_datatype_fp32 : xnn_datatype_bf16; + xnn_datatype output_dtype = xnn_datatype_fp32; + XNN_RETURN_IF_ERROR(xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, lhs_dims.size(), lhs_dims.data(), nullptr, + subgraph, input_dtype, lhs_dims.size(), lhs_dims.data(), nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id)); XNN_RETURN_IF_ERROR(xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, rhs_dims.size(), rhs_dims.data(), nullptr, + subgraph, input_dtype, rhs_dims.size(), rhs_dims.data(), nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &rhs_id)); XNN_RETURN_IF_ERROR(xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, out_dims.size(), out_dims.data(), nullptr, + subgraph, output_dtype, out_dims.size(), out_dims.data(), nullptr, /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out_id)); XNN_RETURN_IF_ERROR(xnn_define_batch_matrix_multiply( diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk_test.cc index 42c218cab33274..547f6c0c65ce64 100644 --- a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h" +#include +#include + +#include "absl/strings/str_cat.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/thunk_testlib.h" @@ -27,6 +31,7 @@ limitations under the License. #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/threadpool.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/cpu_info.h" #define EIGEN_USE_THREADS #include "unsupported/Eigen/CXX11/Tensor" @@ -34,12 +39,27 @@ limitations under the License. namespace xla::cpu { namespace { -class XnnDotThunkTest : public testing::TestWithParam { - protected: - bool use_threadpool() const { return GetParam(); } +struct XnnDotThunkTestSpec { + PrimitiveType input_type; + bool use_threadpool; +}; + +class XnnDotThunkTest : public testing::TestWithParam { + public: + static std::string Name( + const ::testing::TestParamInfo& info) { + return absl::StrCat( + PrimitiveType_Name(info.param.input_type), "_", + info.param.use_threadpool ? 
"threadpool" : "single_threaded"); + } }; TEST_P(XnnDotThunkTest, SimpleDot) { + XnnDotThunkTestSpec spec = GetParam(); + if (spec.input_type == BF16 && + !tsl::port::TestCPUFeature(tsl::port::AVX512_BF16)) { + GTEST_SKIP() << "CPU needs AVX512_BF16 for this test."; + } tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), threads.NumThreads()); @@ -47,6 +67,10 @@ TEST_P(XnnDotThunkTest, SimpleDot) { auto lhs = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto rhs = LiteralUtil::CreateR2({{4.0, 3.0}, {2.0, 1.0}}); auto out = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); + if (spec.input_type == BF16) { + lhs = LiteralUtil::ConvertF32ToBF16(lhs); + rhs = LiteralUtil::ConvertF32ToBF16(rhs); + } BufferAllocations allocations = CreateBufferAllocations(lhs, rhs, out); @@ -55,20 +79,22 @@ TEST_P(XnnDotThunkTest, SimpleDot) { auto [lhs_slice, rhs_slice, out_slice] = CreateBufferAllocationSlice(lhs_alloc, rhs_alloc, out_alloc); - Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + Shape input_shape = ShapeUtil::MakeShape(spec.input_type, {2, 2}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 2}); DotDimensionNumbers dot_dimensions; dot_dimensions.add_lhs_contracting_dimensions(1); dot_dimensions.add_rhs_contracting_dimensions(0); TF_ASSERT_OK_AND_ASSIGN( - auto thunk, XnnDotThunk::Create(XnnDotThunk::Options{use_threadpool()}, - {"dot"}, dot_dimensions, lhs_slice, shape, - rhs_slice, shape, out_slice, shape)); + auto thunk, + XnnDotThunk::Create(XnnDotThunk::Options{spec.use_threadpool}, {"dot"}, + dot_dimensions, lhs_slice, input_shape, rhs_slice, + input_shape, out_slice, output_shape)); Thunk::ExecuteParams params; params.buffer_allocations = &allocations; - params.intra_op_threadpool = use_threadpool() ? &device : nullptr; + params.intra_op_threadpool = spec.use_threadpool ? &device : nullptr; auto execute_event = thunk->Execute(params); tsl::BlockUntilReady(execute_event); @@ -77,7 +103,17 @@ TEST_P(XnnDotThunkTest, SimpleDot) { EXPECT_EQ(out, LiteralUtil::CreateR2({{8.0, 5.0}, {20.0, 13.0}})); } -INSTANTIATE_TEST_SUITE_P(XnnDot, XnnDotThunkTest, testing::Values(true, false)); +std::vector GetXnnDotThunkTestSpecs() { + return std::vector{ + XnnDotThunkTestSpec{F32, /*use_threadpool=*/true}, + XnnDotThunkTestSpec{F32, /*use_threadpool=*/false}, + XnnDotThunkTestSpec{BF16, /*use_threadpool=*/true}, + XnnDotThunkTestSpec{BF16, /*use_threadpool=*/false}}; +} + +INSTANTIATE_TEST_SUITE_P(XnnDot, XnnDotThunkTest, + ::testing::ValuesIn(GetXnnDotThunkTestSpecs()), + XnnDotThunkTest::Name); } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/xnn_fusion.cc b/third_party/xla/xla/backends/cpu/xnn_fusion.cc index 0f0887ba3978f2..70e45e3ea1df99 100644 --- a/third_party/xla/xla/backends/cpu/xnn_fusion.cc +++ b/third_party/xla/xla/backends/cpu/xnn_fusion.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/status/statusor.h" +#include "xla/backends/cpu/codegen/target_machine_features.h" #include "xla/backends/cpu/runtime/dot_lib.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -78,10 +79,21 @@ bool XnnShouldUseThreadPool(const HloComputation* computation) { absl::StatusOr IsXnnDotSupported( const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape, - const Shape& rhs_shape, const Shape& out_shape) { + const Shape& rhs_shape, const Shape& out_shape, + TargetMachineFeatures* cpu_features) { // TODO(ezhulenev): Support other element types. - if (lhs_shape.element_type() != F32 || rhs_shape.element_type() != F32 || - out_shape.element_type() != F32) { + auto check_dtype = [&](PrimitiveType in_dtype, PrimitiveType out_dtype) { + return lhs_shape.element_type() == in_dtype && + rhs_shape.element_type() == in_dtype && + out_shape.element_type() == out_dtype; + }; + + // We assume that the feature is available if `cpu_features` is not provided. + bool cpu_has_avx512bf16 = + cpu_features == nullptr || cpu_features->has_avx512bf16(); + bool dtype_is_supported = + check_dtype(F32, F32) || (check_dtype(BF16, F32) && cpu_has_avx512bf16); + if (!dtype_is_supported) { return false; } diff --git a/third_party/xla/xla/backends/cpu/xnn_fusion.h b/third_party/xla/xla/backends/cpu/xnn_fusion.h index 96ad289b853dcd..755594af7208a5 100644 --- a/third_party/xla/xla/backends/cpu/xnn_fusion.h +++ b/third_party/xla/xla/backends/cpu/xnn_fusion.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/backends/cpu/codegen/target_machine_features.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/shape.h" @@ -37,7 +38,8 @@ bool XnnShouldUseThreadPool(const HloComputation* computation); // if the dot operation shape is invalid. 
absl::StatusOr IsXnnDotSupported( const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape, - const Shape& rhs_shape, const Shape& out_shape); + const Shape& rhs_shape, const Shape& out_shape, + TargetMachineFeatures* cpu_features = nullptr); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 7ab47f7bd9c6f1..fb7865d60a4b54 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -200,6 +200,7 @@ cc_library( ":conv_canonicalization", ":cpu_aot_compilation_result", ":cpu_executable", + ":cpu_float_support", ":cpu_instruction_fusion", ":cpu_layout_assignment", ":cpu_options", @@ -230,6 +231,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/backends/cpu:constant_allocation", + "//xla/backends/cpu:xnn_fusion", "//xla/backends/cpu/codegen:compiled_function_library", "//xla/backends/cpu/codegen:cpu_features", "//xla/backends/cpu/codegen:execution_engine", @@ -2069,6 +2071,8 @@ cc_library( hdrs = ["cpu_float_support.h"], copts = tsl_copts(), deps = [ + "//xla/backends/cpu:xnn_fusion", + "//xla/backends/cpu/codegen:target_machine_features", "//xla/hlo/ir:hlo", "//xla/service:float_support", ], @@ -2080,14 +2084,17 @@ xla_cc_test( tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":cpu_float_support", + "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/codegen:target_machine_features", + "//xla/backends/cpu/codegen:target_machine_test_base", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/transforms/simplifiers:float_normalization", - "//xla/service:hlo_runner", - "//xla/service:platform_util", - "//xla/tests:hlo_runner_agnostic_test_base", + "//xla/service:hlo_module_config", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index eca850642fbe51..55c06f338ac6aa 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -95,6 +95,7 @@ limitations under the License. #include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/backends/cpu/runtime/thunk_proto_serdes.h" #include "xla/backends/cpu/transforms/xnn_graph_fusion.h" +#include "xla/backends/cpu/xnn_fusion.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/analysis/indexed_array_analysis.h" @@ -166,6 +167,7 @@ limitations under the License. #include "xla/service/cpu/conv_canonicalization.h" #include "xla/service/cpu/cpu_aot_compilation_result.h" #include "xla/service/cpu/cpu_executable.h" +#include "xla/service/cpu/cpu_float_support.h" #include "xla/service/cpu/cpu_instruction_fusion.h" #include "xla/service/cpu/cpu_layout_assignment.h" #include "xla/service/cpu/cpu_options.h" @@ -555,7 +557,30 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( AddHloVerifier(&pipeline); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); + + // If XNNPACK is enabled, we only need to upcast dots that XnnDotThunk does + // not support. `upcaster_filter` returns false if the instruction shouldn't + // be processed. + // TODO(b/406806134): Stop calling XNNPACK from regular Dot thunks. All XNN + // Dots should be wrapped in an `__xnn_fusion` fusion region and processed in + // `XnnFusionThunk`. 
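+  // Rough sketch of the intended filtering (assuming a CPU with AVX512_BF16):
+  // a bf16 x bf16 -> f32 dot that XnnDotThunk supports makes `upcaster_filter`
+  // return false, so it keeps its bf16 operands; any other dot returns true
+  // and is upcast to f32 as before.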
+ bool xnnpack_enabled = module->config().debug_options().xla_cpu_use_xnnpack(); + auto call_library_for_dot = [&](const HloInstruction& instr) { + if (!xnnpack_enabled) return false; + DotImplementationStrategy strategy = GetDotImplementationStrategy( + module->config(), instr, *target_machine_features, + /*allow_runtime_calls=*/true); + return strategy == DotImplementationStrategy::kEigen; + }; + HloPredicate upcaster_filter = [&](const HloInstruction* instr) { + if (!call_library_for_dot(*instr)) return true; + return !IsXnnDotSupported(instr->dot_dimension_numbers(), + instr->operand(0)->shape(), + instr->operand(1)->shape(), instr->shape(), + target_machine_features) + .value_or(false); + }; + pipeline.AddPass(upcaster_filter); // Expand random number generation. pipeline.AddPass(); @@ -609,7 +634,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( // Convert BF16 and F8 operations to F32 and F16 respectively so that the CPU // backend can support BF16/F8 operations without directly implementing a // BF16/F8 lowering for most ops. - FloatSupport bf16_support(BF16); + CpuFloatSupport bf16_support(BF16, call_library_for_dot, + target_machine_features); #if defined(INTEL_MKL) OneDnnFloatSupport onednn_bf16_support(BF16); if (!is_aot_compile && !is_thunk_runtime) { diff --git a/third_party/xla/xla/service/cpu/cpu_float_support.h b/third_party/xla/xla/service/cpu/cpu_float_support.h index 5fbe04bd638fd3..4582d97d78bff0 100644 --- a/third_party/xla/xla/service/cpu/cpu_float_support.h +++ b/third_party/xla/xla/service/cpu/cpu_float_support.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "xla/backends/cpu/codegen/target_machine_features.h" +#include "xla/backends/cpu/xnn_fusion.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/float_support.h" @@ -30,17 +32,25 @@ class CpuFloatSupport : public FloatSupport { using DotStrategyChecker = std::function; explicit CpuFloatSupport(PrimitiveType low_precision_type, - DotStrategyChecker call_library_for_dot) + DotStrategyChecker call_library_for_dot, + TargetMachineFeatures* cpu_features) : FloatSupport(low_precision_type), - call_library_for_dot_(call_library_for_dot) {} + call_library_for_dot_(call_library_for_dot), + cpu_features_(cpu_features) {} - // A hatch to skip FloatNormalization for certain instructions. + // Skip trying to upcast the dot if XNNPACK is enabled and the dot is + // supported by XNNPACK. bool ShouldSkipInstruction(const HloInstruction& hlo) const override { - return hlo.opcode() == HloOpcode::kDot && call_library_for_dot_(hlo); + return hlo.opcode() == HloOpcode::kDot && call_library_for_dot_(hlo) && + IsXnnDotSupported(hlo.dot_dimension_numbers(), + hlo.operand(0)->shape(), hlo.operand(1)->shape(), + hlo.shape(), cpu_features_) + .value_or(false); } private: DotStrategyChecker call_library_for_dot_; + TargetMachineFeatures* cpu_features_; }; } // namespace cpu diff --git a/third_party/xla/xla/service/cpu/cpu_float_support_test.cc b/third_party/xla/xla/service/cpu/cpu_float_support_test.cc index e4af26e0629fa9..79ac03cfdddf90 100644 --- a/third_party/xla/xla/service/cpu/cpu_float_support_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_float_support_test.cc @@ -16,13 +16,22 @@ limitations under the License. 
#include "xla/service/cpu/cpu_float_support.h" #include +#include +#include +#include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/backends/cpu/codegen/target_machine_features.h" +#include "xla/backends/cpu/codegen/target_machine_test_base.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/transforms/simplifiers/float_normalization.h" -#include "xla/service/hlo_runner.h" -#include "xla/service/platform_util.h" -#include "xla/tests/hlo_runner_agnostic_test_base.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" @@ -30,52 +39,115 @@ limitations under the License. namespace xla::cpu { namespace { -class SkipInstructionTest : public HloRunnerAgnosticTestBase { +struct SkipInstructionTestSpec { + HloOpcode op; + bool call_library_for_dot; + std::string cpu_name; + std::string features; + bool upcast; +}; + +class SkipInstructionTest + : public TargetMachineTestBase, + public ::testing::WithParamInterface { public: - SkipInstructionTest() - : HloRunnerAgnosticTestBase(std::make_unique( - PlatformUtil::GetDefaultPlatform().value())) {} - void SetUp() override { HloRunnerAgnosticTestBase::SetUp(); } + static std::string Name( + const ::testing::TestParamInfo& info) { + absl::string_view op = HloOpcodeString(info.param.op); + absl::string_view dot_strategy = + info.param.call_library_for_dot ? "LibDot" : "NoLibDot"; + absl::string_view bf16_strategy = + absl::StrContains(info.param.features, "+avx512bf16") ? "Bf16" + : "NoBf16"; + return absl::StrCat(op, "_", dot_strategy, "_", bf16_strategy); + } + + void SetUp() override { TargetMachineTestBase::SetUp(); } + + void CheckDtype(HloModule* module, PrimitiveType lhs_type, + PrimitiveType rhs_type, PrimitiveType out_type) { + HloInstruction* op = module->entry_computation()->root_instruction(); + EXPECT_EQ(op->operand(0)->shape().element_type(), lhs_type); + EXPECT_EQ(op->operand(1)->shape().element_type(), rhs_type); + EXPECT_EQ(op->shape().element_type(), out_type); + } }; -TEST_F(SkipInstructionTest, SkipDot) { - constexpr absl::string_view kHlo = R"( - HloModule test - - ENTRY main { - p0 = bf16[100,100] parameter(0) - p1 = bf16[100,100] parameter(1) - ROOT dot = f32[100,100] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHlo)); +TEST_P(SkipInstructionTest, Bf16InF32Out) { + SkipInstructionTestSpec spec = GetParam(); + + // Create the HLO module: p0 p1. 
+ HloComputation::Builder builder("SkipInstructionTest"); + Shape input_shape = ShapeUtil::MakeShape(BF16, {100, 100}); + Shape output_shape = ShapeUtil::MakeShape(F32, {100, 100}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, input_shape, "p1")); + if (spec.op == HloOpcode::kDot) { + DotDimensionNumbers dot_dimensions; + dot_dimensions.add_lhs_contracting_dimensions(1); + dot_dimensions.add_rhs_contracting_dimensions(0); + builder.AddInstruction(HloInstruction::CreateDot( + output_shape, p0, p1, dot_dimensions, PrecisionConfig())); + } else { + builder.AddInstruction( + HloInstruction::CreateBinary(output_shape, spec.op, p0, p1)); + } + std::unique_ptr computation = builder.Build(); + std::unique_ptr module = std::make_unique( + "test", HloModuleConfig(), + /*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true, + ShapeUtil::ByteSizeOfElements); + module->AddEntryComputation(std::move(computation)); + + // Create CpuFloatSupport. CpuFloatSupport::DotStrategyChecker call_library_for_dot = - [](const HloInstruction& hlo) { return true; }; - CpuFloatSupport cpu_float_support(BF16, call_library_for_dot); + [&spec](const HloInstruction& hlo) { return spec.call_library_for_dot; }; + std::unique_ptr features = CreateTargetMachineFeatures( + "x86_64-unknown-linux-gnu", spec.cpu_name, spec.features); + CpuFloatSupport cpu_float_support(BF16, call_library_for_dot, features.get()); + + // Run FloatNormalization and check the results. FloatNormalization float_normalization(&cpu_float_support); TF_ASSERT_OK_AND_ASSIGN(bool upcast, float_normalization.Run(module.get())); - EXPECT_EQ(upcast, false); + EXPECT_EQ(upcast, spec.upcast); + PrimitiveType expected_input_dtype = spec.upcast ? F32 : BF16; + CheckDtype(module.get(), expected_input_dtype, expected_input_dtype, F32); } -TEST_F(SkipInstructionTest, UpcastAdd) { - constexpr absl::string_view kHlo = R"( - HloModule test - - ENTRY main { - p0 = bf16[100,100] parameter(0) - p1 = bf16[100,100] parameter(1) - ROOT add = f32[100,100] add(p0, p1) - })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHlo)); - CpuFloatSupport::DotStrategyChecker call_library_for_dot = - [](const HloInstruction& hlo) { return true; }; - CpuFloatSupport cpu_float_support(BF16, call_library_for_dot); - FloatNormalization float_normalization(&cpu_float_support); - TF_ASSERT_OK_AND_ASSIGN(bool upcast, float_normalization.Run(module.get())); - EXPECT_EQ(upcast, true); +std::vector GetSkipInstructionTestSpecs() { + return std::vector{ + // Add op, always upcast. + SkipInstructionTestSpec{HloOpcode::kAdd, + /*call_library_for_dot=*/true, + /*cpu_name=*/"sapphirerapids", + /*features=*/"+avx512bf16", + /*upcast=*/true}, + // CPU has BF16, but library dot is disabled. + SkipInstructionTestSpec{HloOpcode::kDot, + /*call_library_for_dot=*/false, + /*cpu_name=*/"sapphirerapids", + /*features=*/"+avx512bf16", + /*upcast=*/true}, + // Library dot is enabled, but CPU does not have BF16. + SkipInstructionTestSpec{HloOpcode::kDot, + /*call_library_for_dot=*/true, + /*cpu_name=*/"znver3", + /*features=*/"+avx2", + /*upcast=*/true}, + // Library dot is enabled and CPU has BF16. Use mixed precision. 
+ SkipInstructionTestSpec{HloOpcode::kDot, + /*call_library_for_dot=*/true, + /*cpu_name=*/"sapphirerapids", + /*features=*/"+avx512bf16", + /*upcast=*/false}}; } +INSTANTIATE_TEST_SUITE_P(SkipInstructionTestSuite, SkipInstructionTest, + ::testing::ValuesIn(GetSkipInstructionTestSpecs()), + SkipInstructionTest::Name); + } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index b97d6e483e7640..7d88c6d3bb3a04 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -852,10 +852,11 @@ absl::StatusOr ThunkEmitter::EmitDotThunk( const HloInstruction* lhs = instruction->operand(0); const HloInstruction* rhs = instruction->operand(1); - TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - *instruction, /*operands=*/{lhs, rhs}, - /*supported_types=*/ - {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128})); + TF_RETURN_IF_ERROR( + ElementTypesSameAndSupported(*instruction, /*operands=*/{lhs, rhs}, + /*supported_types=*/ + {PRED, S8, U8, S16, U16, S32, U32, S64, U64, + BF16, F16, F32, F64, C64, C128})); const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); if (dnums.lhs_contracting_dimensions_size() != 1) { From 157db184db3444b701397bb45a92725ccca2e1c4 Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Thu, 3 Apr 2025 16:14:00 -0700 Subject: [PATCH 0219/1324] [XLA] Make Bfloat16Propagation flow through execution threads This simply adds the needed control flow handling for kCall, kAsyncStart and kAsyncDone. PiperOrigin-RevId: 743733923 --- third_party/xla/xla/hlo/transforms/BUILD | 2 + .../hlo/transforms/bfloat16_propagation.cc | 166 +++++++++++++++--- .../xla/hlo/transforms/bfloat16_propagation.h | 28 +-- .../transforms/bfloat16_propagation_test.cc | 101 +++++++++++ .../xla/xla/hlo/transforms/simplifiers/BUILD | 1 + .../simplifiers/float_normalization.cc | 2 + .../simplifiers/float_normalization_test.cc | 68 +++++++ 7 files changed, 332 insertions(+), 36 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/BUILD b/third_party/xla/xla/hlo/transforms/BUILD index c26539cdf2e81d..84943b76aecb3b 100644 --- a/third_party/xla/xla/hlo/transforms/BUILD +++ b/third_party/xla/xla/hlo/transforms/BUILD @@ -66,8 +66,10 @@ xla_cc_test( "//xla/service:float_support", "//xla/service:hlo_verifier", "//xla/tests:literal_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", # fixdeps: keep "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/hlo/transforms/bfloat16_propagation.cc b/third_party/xla/xla/hlo/transforms/bfloat16_propagation.cc index ef892467482774..b4dac526785c0c 100644 --- a/third_party/xla/xla/hlo/transforms/bfloat16_propagation.cc +++ b/third_party/xla/xla/hlo/transforms/bfloat16_propagation.cc @@ -245,6 +245,57 @@ void BFloat16Propagation::DetermineConditionalComputationsPrecision( } } +void BFloat16Propagation::DetermineAsyncComputationsPrecision( + HloInstruction* async_start) { + CHECK_EQ(async_start->opcode(), HloOpcode::kAsyncStart); + + auto root = async_start->async_wrapped_instruction(); + ShapeUtil::ForEachSubshape(root->shape(), [&](const Shape& subshape, + const ShapeIndex& index) { + if (subshape.element_type() != F32) { + return; + } + if (OutputTypeAfterChange(async_start->async_chain_done(), index) == BF16) { + 
AddToOrRemoveFromBF16ChangeSet(root, index, BF16); + VLOG(2) << "Async wrapped computation root " << root->ToString() + << " at shape index " << index + << " changed to BF16 precision for async start " + << async_start->ToString(); + } + }); + auto insts = + async_start->async_wrapped_computation()->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_backward_pass_.insert( + async_start->async_wrapped_computation()); +} + +void BFloat16Propagation::DetermineCalledComputationsPrecision( + HloInstruction* call) { + CHECK_EQ(call->opcode(), HloOpcode::kCall); + + auto root = call->to_apply()->root_instruction(); + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { + return; + } + if (OutputTypeAfterChange(call, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(root, index, BF16); + VLOG(2) << "Called computation root " << root->ToString() + << " at shape index " << index + << " changed to BF16 precision for call " << call->ToString(); + } + }); + auto insts = call->to_apply()->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_backward_pass_.insert(call->to_apply()); +} + bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const { // If the subshape isn't floating point then none of the users will be BF16. @@ -315,6 +366,28 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, return false; } continue; + } else if (use.instruction->opcode() == HloOpcode::kAsyncStart && + HloInstruction::IsThreadIncluded( + use.instruction->async_execution_thread(), + execution_threads_)) { + auto* async_parameter = + use.instruction->async_wrapped_computation()->parameter_instruction( + use.operand_number); + if (OutputTypeAfterChange(async_parameter, use.operand_index) != BF16) { + return false; + } + continue; + } else if (use.instruction->opcode() == HloOpcode::kCall) { + auto* call_parameter = + use.instruction->to_apply()->parameter_instruction( + use.operand_number); + if (OutputTypeAfterChange(call_parameter, use.operand_index) != BF16) { + return false; + } + continue; + } else if (use.instruction->opcode() == HloOpcode::kAsyncDone) { + // async-done consumes whatever async-start gives it. + continue; } if (bfloat16_support_->EffectiveOperandPrecisionIsLowPrecision( *use.instruction, use.operand_number)) { @@ -371,9 +444,11 @@ bool BFloat16Propagation::ShouldKeepPrecisionUnchanged( // since it is merely a buffer allocation and does not have any side effects. 
return (inst->opcode() == HloOpcode::kCustomCall && !inst->IsCustomCall("AllocateBuffer")) || - inst->opcode() == HloOpcode::kCall || inst->opcode() == HloOpcode::kBitcastConvert || - inst->HasSideEffectNoRecurse(); + inst->HasSideEffectNoRecurse() || + (inst->IsAsynchronous() && + !HloInstruction::IsThreadIncluded(inst->async_execution_thread(), + execution_threads_)); } void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, @@ -392,6 +467,12 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, DetermineWhileComputationsPrecision(hlo); } else if (hlo->opcode() == HloOpcode::kConditional) { DetermineConditionalComputationsPrecision(hlo); + } else if (hlo->opcode() == HloOpcode::kAsyncStart && + HloInstruction::IsThreadIncluded(hlo->async_execution_thread(), + execution_threads_)) { + DetermineAsyncComputationsPrecision(hlo); + } else if (hlo->opcode() == HloOpcode::kCall) { + DetermineCalledComputationsPrecision(hlo); } } instructions_visited_in_backward_pass_.insert(hlo); @@ -412,6 +493,20 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, return; } + if (hlo->opcode() == HloOpcode::kAsyncStart && + HloInstruction::IsThreadIncluded(hlo->async_execution_thread(), + execution_threads_) && + caller_counts_[hlo->async_wrapped_computation()] > 1) { + postpone_processing_called_computations = true; + return; + } + + if (hlo->opcode() == HloOpcode::kCall && + caller_counts_[hlo->to_apply()] > 1) { + postpone_processing_called_computations = true; + return; + } + // Prevent root instructions from having their output modified by recording // all F32 output values as needing to stay as F32. CHECK(hlo->parent() != nullptr); @@ -521,6 +616,15 @@ void BFloat16Propagation::AdjustCalledComputationParameters( {hlo->mutable_operand(i + 1)}); } break; + case HloOpcode::kAsyncStart: + if (HloInstruction::IsThreadIncluded(hlo->async_execution_thread(), + execution_threads_)) { + adjust_computation(hlo->async_wrapped_computation(), hlo->operands()); + } + break; + case HloOpcode::kCall: + adjust_computation(hlo->to_apply(), hlo->operands()); + break; default: break; } @@ -576,6 +680,15 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { adjust_computation(branch, hlo); } break; + case HloOpcode::kAsyncStart: + if (HloInstruction::IsThreadIncluded(hlo->async_execution_thread(), + execution_threads_)) { + adjust_computation(hlo->async_wrapped_computation(), hlo); + } + break; + case HloOpcode::kCall: + adjust_computation(hlo->to_apply(), hlo); + break; default: break; } @@ -698,6 +811,14 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( ResolveInconsistencyOfAliasingBuffersHelper(branch, visited_computations); } + } else if (hlo->opcode() == HloOpcode::kAsyncStart && + HloInstruction::IsThreadIncluded(hlo->async_execution_thread(), + execution_threads_)) { + ResolveInconsistencyOfAliasingBuffersHelper( + hlo->async_wrapped_computation(), visited_computations); + } else if (hlo->opcode() == HloOpcode::kCall) { + ResolveInconsistencyOfAliasingBuffersHelper(hlo->to_apply(), + visited_computations); } } if (!any_change) { @@ -712,10 +833,9 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( } void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( - HloModule* module, - const absl::flat_hash_set& execution_threads) { + HloModule* module) { const auto& computations_topological_order = - module->MakeComputationPostOrder(execution_threads); + 
module->MakeComputationPostOrder(execution_threads_); absl::flat_hash_set resolved; for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { @@ -727,8 +847,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( } absl::Status BFloat16Propagation::ResolveInconsistentFusions( - HloModule* module, - const absl::flat_hash_set& execution_threads) { + HloModule* module) { // We could have changed a fusion computation's root shape to have a different // precision than the fusion node's output, if the fusion root does not // define a buffer (e.g., a tuple). Now we add conversions after such fusion @@ -754,7 +873,8 @@ absl::Status BFloat16Propagation::ResolveInconsistentFusions( // (1) a is F32 but tuple is BF16 // (2) after adding conversion // (3) after tuple simplifier and DCE. - for (auto computation : module->MakeComputationPostOrder(execution_threads)) { + for (auto computation : + module->MakeComputationPostOrder(execution_threads_)) { auto insts = computation->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { auto hlo = *inst_it; @@ -789,9 +909,7 @@ absl::Status BFloat16Propagation::ResolveInconsistentFusions( return absl::OkStatus(); } -absl::Status BFloat16Propagation::ResolveConvertedConstants( - HloModule* module, - const absl::flat_hash_set& execution_threads) { +absl::Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { // We may have converted some constants from F32 to BF16, so adjust the // constant literals in such cases. We do this here instead of when the // constant node's is changed because 1) the HloInstruction interface does not @@ -802,7 +920,8 @@ absl::Status BFloat16Propagation::ResolveConvertedConstants( // can avoid repeated conversions. // // TODO(b/73833576): Consider resetting literal in HloInstruction. - for (auto computation : module->MakeComputationPostOrder(execution_threads)) { + for (auto computation : + module->MakeComputationPostOrder(execution_threads_)) { for (auto hlo : computation->MakeInstructionPostOrder()) { if (hlo->opcode() != HloOpcode::kConstant) { continue; @@ -821,10 +940,8 @@ absl::Status BFloat16Propagation::ResolveConvertedConstants( return absl::OkStatus(); } -absl::Status BFloat16Propagation::SkipNoopConversions( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - for (auto computation : module->computations(execution_threads)) { +absl::Status BFloat16Propagation::SkipNoopConversions(HloModule* module) { + for (auto computation : module->computations(execution_threads_)) { for (auto hlo : computation->MakeInstructionPostOrder()) { if (hlo->opcode() != HloOpcode::kConvert) { continue; @@ -859,9 +976,10 @@ absl::StatusOr BFloat16Propagation::Run( caller_counts_.clear(); changes_to_bf16_.clear(); changed_ = false; + execution_threads_ = execution_threads; auto computations_topological_order = - module->MakeComputationPostOrder(execution_threads); + module->MakeComputationPostOrder(execution_threads_); // Before running the propagation pass, we insert copies (kConvert to the same // type) of F32 inputs to while loops. This prevents other uses of the same @@ -929,7 +1047,7 @@ absl::StatusOr BFloat16Propagation::Run( // It's possible that an instruction does not define a buffer, but the // defining instruction's shape has changed. So we need to adjust the output // shapes of instructions according to the HLO values they refer to. 
- ResolveInconsistencyOfAliasingBuffers(module, execution_threads); + ResolveInconsistencyOfAliasingBuffers(module); // Apply the changes in changes_to_bf16_. for (auto& change : changes_to_bf16_) { @@ -978,13 +1096,13 @@ absl::StatusOr BFloat16Propagation::Run( // Removes redundant HLOs added by this pass, either when inserting // de-aliasing copies to while loop inputs, or later when converting output // types. - auto clean_up = [this, module, &execution_threads]() { - TF_RETURN_IF_ERROR(SkipNoopConversions(module, execution_threads)); + auto clean_up = [this, module]() { + TF_RETURN_IF_ERROR(SkipNoopConversions(module)); TupleSimplifier tuple_simplifier; TF_RETURN_IF_ERROR( - tuple_simplifier.Run(module, execution_threads).status()); + tuple_simplifier.Run(module, execution_threads_).status()); HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status()); + TF_RETURN_IF_ERROR(dce.Run(module, execution_threads_).status()); return absl::OkStatus(); }; @@ -993,8 +1111,8 @@ absl::StatusOr BFloat16Propagation::Run( return false; } - TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module, execution_threads)); - TF_RETURN_IF_ERROR(ResolveConvertedConstants(module, execution_threads)); + TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module)); + TF_RETURN_IF_ERROR(ResolveConvertedConstants(module)); TF_RETURN_IF_ERROR(clean_up()); return true; diff --git a/third_party/xla/xla/hlo/transforms/bfloat16_propagation.h b/third_party/xla/xla/hlo/transforms/bfloat16_propagation.h index 317d754cb60c05..6d412d4265e10e 100644 --- a/third_party/xla/xla/hlo/transforms/bfloat16_propagation.h +++ b/third_party/xla/xla/hlo/transforms/bfloat16_propagation.h @@ -130,6 +130,16 @@ class BFloat16Propagation : public HloModulePass { // Precondition: hlo->opcode() == kConditional void DetermineConditionalComputationsPrecision(HloInstruction* cond); + // Special handling in the opportunity-finding pass for async computations. + // + // Precondition: hlo->opcode() == kAsyncStart + void DetermineAsyncComputationsPrecision(HloInstruction* async_start); + + // Special handling in the opportunity-finding pass for called computations. + // + // Precondition: hlo->opcode() == kCall + void DetermineCalledComputationsPrecision(HloInstruction* call); + // The set of HloInstructions that have been visited in the // opportunity-finding pass. absl::flat_hash_set @@ -146,9 +156,7 @@ class BFloat16Propagation : public HloModulePass { // Adjusts the output shapes of HloInstructions such that if two // HloInstructions have aliasing buffers in their outputs, they must have the // same precision. - void ResolveInconsistencyOfAliasingBuffers( - HloModule* module, - const absl::flat_hash_set& execution_threads); + void ResolveInconsistencyOfAliasingBuffers(HloModule* module); // Resolves inconsistency of aliasing buffers for the given computation, and // recursively runs on a while instruction's condition and body until a fixed @@ -170,21 +178,15 @@ class BFloat16Propagation : public HloModulePass { // Resolves inconsistencies introduced by this pass for fusions with // tuple-type output. - absl::Status ResolveInconsistentFusions( - HloModule* module, - const absl::flat_hash_set& execution_threads); + absl::Status ResolveInconsistentFusions(HloModule* module); // Converts the literals in kConstant HLOs which have their types changed to // BF16 by this pass. 
-  absl::Status ResolveConvertedConstants(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads);
+  absl::Status ResolveConvertedConstants(HloModule* module);

   // Skips no-op conversions (same source and target shapes) that can be
   // produced by this pass, i.e., replaces them in their uses with their operands.
-  absl::Status SkipNoopConversions(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads);
+  absl::Status SkipNoopConversions(HloModule* module);

   // ***************************
   // Functions called and state used by two or more passes.
@@ -232,6 +234,8 @@ class BFloat16Propagation : public HloModulePass {
   bool changed_ = false;

   std::unique_ptr<HloDataflowAnalysis> dataflow_;
+
+  absl::flat_hash_set<absl::string_view> execution_threads_;
 };

 }  // namespace xla
diff --git a/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc b/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc
index cf14c05d6a7365..505b18a1ee9c3b 100644
--- a/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc
+++ b/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc
@@ -19,8 +19,10 @@ limitations under the License.

 #include
 #include
+#include
 #include "absl/log/log.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "xla/comparison_util.h"
 #include "xla/hlo/ir/collective_device_list.h"
 #include "xla/hlo/ir/hlo_computation.h"
@@ -36,6 +38,7 @@ limitations under the License.
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/tests/literal_test_util.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/statusor.h"

@@ -437,6 +440,104 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
   EXPECT_TRUE(OutputsBF16(b_f1));
 }

+// Tests that BF16 is propagated properly through called fused computations.
+TEST_F(BFloat16PropagationTest, PropagateThroughCalledFusion) {
+  constexpr absl::string_view kHlo = R"(
+HloModule main
+
+ENTRY main {
+  arg.0 = f32[4,4] parameter(0)
+  add.0 = f32[4,4] add(arg.0, arg.0)
+  call.0 = call(add.0, add.0), to_apply={
+    arg.0 = f32[4,4] parameter(0)
+    arg.1 = f32[4,4] parameter(1)
+    ROOT fusion.0 = (f32[4,4], f32[4,4]) fusion(arg.0, arg.1), kind=kCustom, calls={
+      arg.0 = f32[4,4] parameter(0)
+      arg.1 = f32[4,4] parameter(1)
+      ROOT tuple.0 = tuple(arg.0, arg.1)
+    }
+  }
+  ROOT fusion.1 = f32[4,4] fusion(call.0), kind=kCustom, calls={
+    arg.0 = (f32[4,4], f32[4,4]) parameter(0)
+    gte.0 = get-tuple-element(arg.0), index=0
+    gte.1 = get-tuple-element(arg.0), index=1
+    ROOT dot.0 = dot(gte.0, gte.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  }
+}
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+
+  EXPECT_TRUE(PropagatePrecision(module.get()));
+
+  HloInstruction* add0 = FindInstruction(module.get(), "add.0");
+  ASSERT_NE(add0, nullptr);
+  EXPECT_TRUE(OutputsBF16(add0));
+  HloInstruction* call = FindInstruction(module.get(), "call.0");
+  ASSERT_NE(call, nullptr);
+  HloInstruction* arg0 = call->to_apply()->parameter_instruction(0);
+  EXPECT_TRUE(OutputsBF16(arg0));
+  HloInstruction* arg1 = call->to_apply()->parameter_instruction(1);
+  EXPECT_TRUE(OutputsBF16(arg1));
+  HloInstruction* gte0 = FindInstruction(module.get(), "gte.0");
+  ASSERT_NE(gte0, nullptr);
+  EXPECT_TRUE(OutputsBF16(gte0));
+  HloInstruction* gte1 = FindInstruction(module.get(), "gte.1");
+  ASSERT_NE(gte1, nullptr);
+  EXPECT_TRUE(OutputsBF16(gte1));
+}
+
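+// A usage sketch (hypothetical driver code, not part of this test): with the
+// thread filter now stored on the pass, a caller can scope propagation to
+// chosen execution threads, and async computations pinned to other threads
+// keep their F32 types.
+//
+//   FloatSupport bf16_support(BF16);
+//   BFloat16Propagation propagation(&bf16_support);
+//   absl::flat_hash_set<absl::string_view> threads = {"main"};
+//   TF_ASSERT_OK_AND_ASSIGN(bool changed,
+//                           propagation.Run(module.get(), threads));
+
+// Tests that BF16 is propagated properly through async fused computations.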
+TEST_F(BFloat16PropagationTest, PropagateThroughAsyncFusion) {
+  constexpr absl::string_view kHlo = R"(
+HloModule main
+
+ENTRY main {
+  arg.0 = f32[4,4] parameter(0)
+  add.0 = f32[4,4] add(arg.0, arg.0)
+  fusion-start.0 = ((f32[4,4], f32[4,4]), (f32[4,4], f32[4,4]), s32[]) fusion-start(add.0, add.0), kind=kCustom, calls={
+    arg.0 = f32[4,4] parameter(0)
+    arg.1 = f32[4,4] parameter(1)
+    ROOT tuple.0 = tuple(arg.0, arg.1)
+  }, async_execution_thread="main"
+  fusion-done.0 = (f32[4,4], f32[4,4]) fusion-done(fusion-start.0)
+  ROOT fusion.1 = f32[4,4] fusion(fusion-done.0), kind=kCustom, calls={
+    arg.0 = (f32[4,4], f32[4,4]) parameter(0)
+    gte.0 = get-tuple-element(arg.0), index=0
+    gte.1 = get-tuple-element(arg.0), index=1
+    ROOT dot.0 = dot(gte.0, gte.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  }
+}
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+
+  EXPECT_TRUE(PropagatePrecision(module.get()));
+
+  HloInstruction* add0 = FindInstruction(module.get(), "add.0");
+  ASSERT_NE(add0, nullptr);
+  EXPECT_TRUE(OutputsBF16(add0));
+  HloInstruction* fusion0 = FindInstruction(module.get(), "fusion-start.0");
+  ASSERT_NE(fusion0, nullptr);
+  HloInstruction* async_arg0 =
+      fusion0->async_wrapped_computation()->parameter_instruction(0);
+  EXPECT_TRUE(OutputsBF16(async_arg0));
+  HloInstruction* async_arg1 =
+      fusion0->async_wrapped_computation()->parameter_instruction(1);
+  EXPECT_TRUE(OutputsBF16(async_arg1));
+  HloInstruction* arg0 = fusion0->async_wrapped_instruction()
+                             ->called_computations()[0]
+                             ->parameter_instruction(0);
+  EXPECT_TRUE(OutputsBF16(arg0));
+  HloInstruction* arg1 = fusion0->async_wrapped_instruction()
+                             ->called_computations()[0]
+                             ->parameter_instruction(1);
+  EXPECT_TRUE(OutputsBF16(arg1));
+  HloInstruction* gte0 = FindInstruction(module.get(), "gte.0");
+  ASSERT_NE(gte0, nullptr);
+  EXPECT_TRUE(OutputsBF16(gte0));
+  HloInstruction* gte1 = FindInstruction(module.get(), "gte.1");
+  ASSERT_NE(gte1, nullptr);
+  EXPECT_TRUE(OutputsBF16(gte1));
+}
+
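+// A hedged aside on the accessors exercised above (illustrative only; `start`
+// stands for any kAsyncStart instruction): a pass can hop from an async-start
+// into the computation it wraps without special-casing the async pair.
+//
+//   HloComputation* wrapped = start->async_wrapped_computation();
+//   HloInstruction* fused = start->async_wrapped_instruction();
+//   HloComputation* fusion_body = fused->called_computations()[0];
+//
+// This mirrors the traversal BFloat16Propagation now performs for kAsyncStart.
+
 // Tests that a fusion with a bitcast-convert as its root is changed via adding
 // extra convert, instead of changing the type in-place.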
TEST_F(BFloat16PropagationTest, FusionWithBitcastConvertRoot) { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD index e7bd19914b5aed..41a2ce422b6d86 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD +++ b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD @@ -202,6 +202,7 @@ xla_cc_test( "//xla/service:hlo_creation_utils", "//xla/service:hlo_verifier", "//xla/service:pattern_matcher", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc index 0867254d42df2f..22d1c19704c9c8 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc @@ -593,6 +593,8 @@ absl::Status FloatNormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kWhile || // hlo->opcode() == HloOpcode::kConditional || // hlo->opcode() == HloOpcode::kBitcastConvert || // + hlo->opcode() == HloOpcode::kAsyncStart || // + hlo->opcode() == HloOpcode::kAsyncDone || // hlo->HasSideEffectNoRecurse()) { return absl::OkStatus(); } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization_test.cc index f61d7f0c1d3865..a78fe9296ae8e0 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization_test.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -212,6 +213,73 @@ TEST_F(FloatNormalizationTest, ResolveIfUnsupportedBF16) { EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert); } +TEST_F(FloatNormalizationTest, ResolveIfUnsupportedBF16CalledComputation) { + constexpr absl::string_view kHlo = R"( +HloModule main + +ENTRY main { + arg.0 = f32[2,4] parameter(0) + arg.1 = bf16[2,4] parameter(1) + arg.2 = f32[2,4] parameter(2) + ROOT call.0 = call(arg.0, arg.1, arg.2), to_apply={ + arg.0 = f32[2,4] parameter(0) + arg.1 = bf16[2,4] parameter(1) + arg.2 = f32[2,4] parameter(2) + multiply.0 = bf16[2,4] multiply(arg.0, arg.1) + ROOT multiply.1 = bf16[2,4] multiply(multiply.0, arg.2) + } +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + + EXPECT_TRUE(Normalize(module.get())); + + HloInstruction* call0 = FindInstruction(module.get(), "call.0"); + ASSERT_NE(call0, nullptr); + HloComputation* computation = call0->to_apply(); + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + HloInstruction* multiply1 = FindInstruction(module.get(), "multiply.1"); + ASSERT_NE(multiply1, nullptr); + EXPECT_EQ(computation->root_instruction()->operand(0), multiply1); + EXPECT_EQ(multiply1->shape().element_type(), F32); + EXPECT_EQ(multiply1->shape().element_type(), F32); + EXPECT_EQ(multiply1->operand(0)->opcode(), HloOpcode::kConvert); +} + +TEST_F(FloatNormalizationTest, ResolveIfUnsupportedBF16AsyncComputation) { + constexpr absl::string_view kHlo = R"( +HloModule main + +ENTRY main { + arg.0 = f32[2,4] parameter(0) + arg.1 = bf16[2,4] parameter(1) + arg.2 = f32[2,4] parameter(2) + call-start.0 = ((f32[2,4], bf16[2,4], f32[2,4]), bf16[2,4], s32[]) call-start(arg.0, arg.1, arg.2), to_apply={ + arg.0 = f32[2,4] parameter(0) + arg.1 = bf16[2,4] parameter(1) + arg.2 = f32[2,4] parameter(2) + multiply.0 = bf16[2,4] multiply(arg.0, arg.1) + ROOT multiply.1 = bf16[2,4] multiply(multiply.0, arg.2) + } + ROOT call-done.0 = bf16[2,4] call-done(call-start.0) +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + + EXPECT_TRUE(Normalize(module.get())); + HloInstruction* call_start0 = FindInstruction(module.get(), "call-start.0"); + ASSERT_NE(call_start0, nullptr); + HloComputation* computation = + call_start0->async_wrapped_instruction()->to_apply(); + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + HloInstruction* multiply1 = FindInstruction(module.get(), "multiply.1"); + ASSERT_NE(multiply1, nullptr); + EXPECT_EQ(computation->root_instruction()->operand(0), multiply1); + EXPECT_EQ(multiply1->shape().element_type(), F32); + EXPECT_EQ(multiply1->shape().element_type(), F32); + EXPECT_EQ(multiply1->operand(0)->opcode(), HloOpcode::kConvert); +} + TEST_F(FloatNormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); From c2f34790292c650669ce6c3d73bbfb211cc6bdc0 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Thu, 3 Apr 2025 16:53:56 -0700 Subject: [PATCH 0220/1324] Make TensorFlow CI properly read changes to `third_party/` which are Copybara'd to XLA This is correct as files on GitHub in both `openxla/xla/third_party` and `tensorflow/tensorflow/third_party` are the same underlying file internally. 
PiperOrigin-RevId: 743745178 --- third_party/xla/build_tools/ci/build.py | 40 ++++++++++++------- .../xla/build_tools/ci/golden_commands.txt | 10 +++-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/third_party/xla/build_tools/ci/build.py b/third_party/xla/build_tools/ci/build.py index d71bf4b22d9773..2b5faa2f8ef1cf 100755 --- a/third_party/xla/build_tools/ci/build.py +++ b/third_party/xla/build_tools/ci/build.py @@ -543,23 +543,29 @@ def nvidia_gpu_build_with_compute_capability( # This is pretty devious - but we have to do some adhoc extra Copybara # work here to get XLA into the shape TF expects. b/407638223 # pyformat:disable - [ - "cp", "-r", - f"{_GITHUB_WORKSPACE}/openxla/xla", - f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party", - ], [ "find", - f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party/xla", + f"{_GITHUB_WORKSPACE}/openxla/xla", "-type", "f", "-exec", "sed", "-i", "s/@local_xla/@local_xla/g", "{}", "+", ], [ "find", - f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party/xla", + f"{_GITHUB_WORKSPACE}/openxla/xla", "-type", "f", "-exec", "sed", "-i", "s/@local_tsl/@local_tsl/g", "{}", "+", ], + [ + "cp", "-r", + f"{_GITHUB_WORKSPACE}/openxla/xla", + f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party", + ], + [ + "find", + f"{_GITHUB_WORKSPACE}/openxla/xla/third_party/", + "-maxdepth", "1", "-exec", "cp", "-r", "{}", + f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party", ";", + ], ), ) @@ -593,23 +599,29 @@ def nvidia_gpu_build_with_compute_capability( # This is pretty devious - but we have to do some adhoc extra Copybara # work here to get XLA into the shape TF expects. b/407638223 # pyformat:disable - [ - "cp", "-r", - f"{_GITHUB_WORKSPACE}/openxla/xla", - f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party", - ], [ "find", - f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party/xla", + f"{_GITHUB_WORKSPACE}/openxla/xla", "-type", "f", "-exec", "sed", "-i", "s/@local_xla/@local_xla/g", "{}", "+", ], [ "find", - f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party/xla", + f"{_GITHUB_WORKSPACE}/openxla/xla", "-type", "f", "-exec", "sed", "-i", "s/@local_tsl/@local_tsl/g", "{}", "+", ], + [ + "cp", "-r", + f"{_GITHUB_WORKSPACE}/openxla/xla", + f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party", + ], + [ + "find", + f"{_GITHUB_WORKSPACE}/openxla/xla/third_party/", + "-maxdepth", "1", "-exec", "cp", "-r", "{}", + f"{_GITHUB_WORKSPACE}/tensorflow/tensorflow/third_party", ";", + ], ), ) diff --git a/third_party/xla/build_tools/ci/golden_commands.txt b/third_party/xla/build_tools/ci/golden_commands.txt index 9985958a2d2d41..7a1cec6d41c8e9 100644 --- a/third_party/xla/build_tools/ci/golden_commands.txt +++ b/third_party/xla/build_tools/ci/golden_commands.txt @@ -9,17 +9,19 @@ bazel test --build_tag_filters=-multiaccelerator --test_tag_filters=-multiaccele bazel analyze-profile profile.json.gz # END BuildType.JAX_LINUX_X86_GPU_T4_GITHUB_ACTIONS # BEGIN BuildType.TENSORFLOW_LINUX_X86_CPU_GITHUB_ACTIONS +find $GITHUB_WORKSPACE/openxla/xla -type f -exec sed -i s/@local_xla/@local_xla/g {} + +find $GITHUB_WORKSPACE/openxla/xla -type f -exec sed -i s/@local_tsl/@local_tsl/g {} + cp -r $GITHUB_WORKSPACE/openxla/xla $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party -find $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party/xla -type f -exec sed -i s/@local_xla/@local_xla/g {} + -find $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party/xla -type f -exec sed -i s/@local_tsl/@local_tsl/g {} + +find 
$GITHUB_WORKSPACE/openxla/xla/third_party/ -maxdepth 1 -exec cp -r {} $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party ; parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --config=release_cpu_linux --config=rbe_linux_cpu --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --profile=profile.json.gz --test_lang_filters=cc,py --color=yes --nobuild -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/... bazel test --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-gpu --config=release_cpu_linux --config=rbe_linux_cpu --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --profile=profile.json.gz --test_lang_filters=cc,py --color=yes -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/... bazel analyze-profile profile.json.gz # END BuildType.TENSORFLOW_LINUX_X86_CPU_GITHUB_ACTIONS # BEGIN BuildType.TENSORFLOW_LINUX_X86_GPU_T4_GITHUB_ACTIONS +find $GITHUB_WORKSPACE/openxla/xla -type f -exec sed -i s/@local_xla/@local_xla/g {} + +find $GITHUB_WORKSPACE/openxla/xla -type f -exec sed -i s/@local_tsl/@local_tsl/g {} + cp -r $GITHUB_WORKSPACE/openxla/xla $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party -find $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party/xla -type f -exec sed -i s/@local_xla/@local_xla/g {} + -find $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party/xla -type f -exec sed -i s/@local_tsl/@local_tsl/g {} + +find $GITHUB_WORKSPACE/openxla/xla/third_party/ -maxdepth 1 -exec cp -r {} $GITHUB_WORKSPACE/tensorflow/tensorflow/third_party ; parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --config=release_gpu_linux --config=rbe_linux_cuda --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --profile=profile.json.gz --test_lang_filters=cc,py --color=yes --nobuild -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/... bazel test --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-tpu,-benchmark-test,-v1only,-no_gpu,-no_gpu_presubmit,-no_cuda11,+gpu --config=release_gpu_linux --config=rbe_linux_cuda --repo_env=USE_PYWRAP_RULES=True --verbose_failures --test_output=errors --profile=profile.json.gz --test_lang_filters=cc,py --color=yes -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... 
-//tensorflow/python/distribute/... -//tensorflow/python/kernel_tests/... -//tensorflow/python/data/... -//tensorflow/python/compiler/tensorrt/... bazel analyze-profile profile.json.gz From 8efa3d771e1a38992be5263975275cfa8f238663 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Thu, 3 Apr 2025 16:54:51 -0700 Subject: [PATCH 0221/1324] Add check to prevent HloTestBase and HloPjRtTestBase being used together. PiperOrigin-RevId: 743745454 --- .../xla/xla/tests/client_library_test_base.h | 13 +++++++++++++ third_party/xla/xla/tests/hlo_pjrt_test_base.h | 13 +++++++++++++ third_party/xla/xla/tests/hlo_test_base.h | 13 +++++++++++++ 3 files changed, 39 insertions(+) diff --git a/third_party/xla/xla/tests/client_library_test_base.h b/third_party/xla/xla/tests/client_library_test_base.h index 8c092a906ea76a..a5d0a457525275 100644 --- a/third_party/xla/xla/tests/client_library_test_base.h +++ b/third_party/xla/xla/tests/client_library_test_base.h @@ -16,6 +16,19 @@ limitations under the License. #ifndef XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ #define XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ +// Inclusion of this header indicates that the test has NOT been migrated to use +// HloRunnerPjRt. Migration requires tagging the build target so that the +// correct dependencies are included. The whole target must be migrated at once. +// This macro helps to ensure that migration test base classes are not used in +// conjunction with ClientLibraryTestBase. +// TODO: b/408276009 - Remove these macros once all tests have been migrated. +#define XLA_TEST_NOT_MIGRATED_TO_HLO_RUNNER_PJRT +#ifdef XLA_TEST_MIGRATED_TO_HLO_RUNNER_PJRT +static_assert(false, + "ClientLibraryTestBase cannot be used in the same target as a " + "test that has been explicitly migrated to use HloRunnerPjRt."); +#endif // XLA_TEST_MIGRATED_TO_HLO_RUNNER_PJRT + #include #include #include diff --git a/third_party/xla/xla/tests/hlo_pjrt_test_base.h b/third_party/xla/xla/tests/hlo_pjrt_test_base.h index 989ae6e769d29a..426fb8073e1034 100644 --- a/third_party/xla/xla/tests/hlo_pjrt_test_base.h +++ b/third_party/xla/xla/tests/hlo_pjrt_test_base.h @@ -16,6 +16,19 @@ limitations under the License. #ifndef XLA_TESTS_HLO_PJRT_TEST_BASE_H_ #define XLA_TESTS_HLO_PJRT_TEST_BASE_H_ +// Inclusion of this header indicates that the test has been migrated to use +// HloRunnerPjRt. Since this requires tagging the build target, the whole target +// must be migrated at once. HloPjRtTestBase cannot be used in conjunction with +// HloTestBase and vice versa. This macro helps to ensure that migration test +// base classes are not used in conjunction with HloTestBase. +// TODO: b/408276009 - Remove these macros once all tests have been migrated. +#define XLA_TEST_MIGRATED_TO_HLO_RUNNER_PJRT +#ifdef XLA_TEST_NOT_MIGRATED_TO_HLO_RUNNER_PJRT +static_assert(false, + "HloPjRtTestBase cannot be used in the same target as a test " + "that uses HloTestBase."); +#endif // XLA_TEST_NOT_MIGRATED_TO_HLO_RUNNER_PJRT + #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index dd37ead6b409e5..46ed5520dc2749 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -16,6 +16,19 @@ limitations under the License. #ifndef XLA_TESTS_HLO_TEST_BASE_H_ #define XLA_TESTS_HLO_TEST_BASE_H_ +// Inclusion of this header indicates that the test has NOT been migrated to use +// HloRunnerPjRt. 
Migration requires tagging the build target so that the
+// correct dependencies are included. The whole target must be migrated at once.
+// This macro helps to ensure that migration test base classes are not used in
+// conjunction with HloTestBase.
+// TODO: b/408276009 - Remove these macros once all tests have been migrated.
+#define XLA_TEST_NOT_MIGRATED_TO_HLO_RUNNER_PJRT
+#ifdef XLA_TEST_MIGRATED_TO_HLO_RUNNER_PJRT
+static_assert(false,
+              "HloTestBase cannot be used in the same target as a test that "
+              "has been explicitly migrated to use HloRunnerPjRt.");
+#endif  // XLA_TEST_MIGRATED_TO_HLO_RUNNER_PJRT
+
 #include
 #include
 #include

From d33e9de735558f265294fdc32f481cfa321cdd4c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 3 Apr 2025 17:09:53 -0700
Subject: [PATCH 0222/1324] To better preserve metadata when rewriting an
 instruction, this change does the following:

1- Allows HloInstruction::AddInstruction to accept a name for the derived
instruction.
2- When replacing a while instruction with another one, we create the new
while as a derived instruction from the original.

PiperOrigin-RevId: 743749496
---
 third_party/xla/xla/hlo/ir/hlo_instruction.cc | 6 +++++-
 third_party/xla/xla/hlo/ir/hlo_instruction.h  | 6 ++++--
 third_party/xla/xla/service/while_util.cc     | 2 +-
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc
index 6fa7756dfe9eba..3fafafa4ced082 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc
@@ -264,7 +264,11 @@ void HloInstruction::ClearCalledComputations() {
 }

 HloInstruction* HloInstruction::AddInstruction(
-    std::unique_ptr<HloInstruction> derived_instruction) {
+    std::unique_ptr<HloInstruction> derived_instruction,
+    absl::string_view new_name) {
+  if (!new_name.empty()) {
+    derived_instruction->SetAndSanitizeName(new_name);
+  }
   HloInstruction* derived =
       parent()->AddInstruction(std::move(derived_instruction));
   const bool has_prior_sharding = derived->has_sharding();
diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h
index 4e446867513979..4f0df16edb34f3 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instruction.h
+++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h
@@ -300,9 +300,11 @@ class HloInstruction {
   void DetachFromOperandsAndUsers();

   // Adds a derived instruction to the parent computation of this instruction.
-  // Also update setup the new instruction as a derived instruction.
+  // Also sets up the new instruction as a derived instruction, and sets the
+  // name of the new instruction (if `new_name` is not empty).
   HloInstruction* AddInstruction(
       std::unique_ptr<HloInstruction> derived_instruction,
       absl::string_view new_name = "");

   // Creates an instruction from the given proto. Arguments:
   //
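A short usage sketch of the extended API (hypothetical call site; `new_shape`, `new_condition`, `new_body`, and `new_init` are assumed names): deriving a replacement from an existing instruction carries over its metadata and sharding, and the optional second argument now names the result in the same step.

    HloInstruction* new_while = while_instr->AddInstruction(
        HloInstruction::CreateWhile(new_shape, new_condition, new_body,
                                    new_init),
        /*new_name=*/"while_with_extra_operands");

The while_util.cc change below is the first caller moved onto this derived-instruction path, so the new while inherits the original's metadata.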
diff --git a/third_party/xla/xla/service/while_util.cc b/third_party/xla/xla/service/while_util.cc
index edf9f4b0b84a90..98bc1d2d9c7214 100644
--- a/third_party/xla/xla/service/while_util.cc
+++ b/third_party/xla/xla/service/while_util.cc
@@ -156,7 +156,7 @@ WhileUtil::MakeInstructionsLiveIn(
   HloInstruction* new_while_init =
       TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
   HloComputation* containing_computation = while_instr->parent();
-  HloInstruction* new_while = containing_computation->AddInstruction(
+  HloInstruction* new_while = while_instr->AddInstruction(
       HloInstruction::CreateWhile(new_while_shape, new_while_condition,
                                   new_while_body, new_while_init));

From 2e56997b415e140c026f3955c9ff36fa4fb74ecd Mon Sep 17 00:00:00 2001
From: Zac Mustin
Date: Thu, 3 Apr 2025 17:18:20 -0700
Subject: [PATCH 0223/1324] Delete `PjRtClient.Defragment`.

The `Defragment` implementation for GPU is in `py_client.cc`, so this should
be a no-op.

PiperOrigin-RevId: 743751570
---
 third_party/xla/xla/pjrt/cpu/cpu_client.h              | 4 ----
 third_party/xla/xla/pjrt/pjrt_c_api_client.h           | 6 ------
 third_party/xla/xla/pjrt/pjrt_client.h                 | 6 ------
 third_party/xla/xla/pjrt/pjrt_stream_executor_client.h | 5 -----
 third_party/xla/xla/pjrt/tf_pjrt_client.h              | 1 -
 5 files changed, 22 deletions(-)

diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h
index 438f212b277ce0..b66328baa4bf43 100644
--- a/third_party/xla/xla/pjrt/cpu/cpu_client.h
+++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h
@@ -182,10 +182,6 @@ class TfrtCpuClient final : public PjRtClient {
       std::function<void()> on_delete_callback,
       std::optional<std::intptr_t> stream) override;

-  absl::Status Defragment() override {
-    return Unimplemented("Defragment not implemented.");
-  }
-
   tsl::thread::ThreadPool* pjrt_client_thread_pool() const {
     return pjrt_client_thread_pool_.get();
   }

diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h
index 59aecf6d440a91..247e57c57f4eff 100644
--- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h
+++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h
@@ -371,12 +371,6 @@ class PjRtCApiClient : public PjRtClient {
         "this feature.");
   }

-  absl::Status Defragment() override {
-    return Unimplemented(
-        "PJRT C API does not support Defragment. Please report an issue at "
-        "https://github.com/google/jax/issues if you need this feature.");
-  }
-
   absl::Status DmaMap(void* data, size_t size) override;

   absl::Status DmaUnmap(void* data) override;
diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h
index fefa0dc9c543de..8f8d9ae6495f19 100644
--- a/third_party/xla/xla/pjrt/pjrt_client.h
+++ b/third_party/xla/xla/pjrt/pjrt_client.h
@@ -931,12 +931,6 @@ class PjRtClient {
     return Unimplemented("MakeCrossHostReceiveBuffers is not implemented.");
   }

-  // TODO(zhangqiaorjc): Experimental API to be removed.
-  // Defragment device memory.
-  virtual absl::Status Defragment() {
-    return Unimplemented("Defragment is not implemented.");
-  }
-
   // Return the PjRtHostMemoryForDeviceManager for this client. It can be
   // nullptr if the implementation does not provide one.
virtual PjRtHostMemoryForDeviceManager* GetPjRtHostMemoryForDeviceManager() diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 1e5c8c7cc3b001..76b43fb985a684 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -337,11 +337,6 @@ class PjRtStreamExecutorClient : public PjRtClient { bool IsDmaMapped(const void* data_start, int64_t transfer_size); - // TODO(zhangqiaorjc): Experimental. Will be removed. - absl::Status Defragment() override { - return Unimplemented("Defragment not implemented"); - } - LocalDeviceState& device_state(int device_ordinal) const { return *tensorflow::down_cast( LookupAddressableDevice(xla::PjRtLocalDeviceId(device_ordinal)) diff --git a/third_party/xla/xla/pjrt/tf_pjrt_client.h b/third_party/xla/xla/pjrt/tf_pjrt_client.h index 59e11af8bdedd5..9662d9b3d73b89 100644 --- a/third_party/xla/xla/pjrt/tf_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tf_pjrt_client.h @@ -310,7 +310,6 @@ class TfPjRtClient : public PjRtClient { const override { return wrapped_->GetTopologyDescription(); } - absl::Status Defragment() override { return wrapped_->Defragment(); } PjRtClient* wrapped() const { return wrapped_.get(); } From c3477b7ece644f4456f402dbb4fc47e97acc06be Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 17:32:26 -0700 Subject: [PATCH 0224/1324] Internal fixes to add missing dependencies. PiperOrigin-RevId: 743754601 --- third_party/xla/xla/tsl/lib/math/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/tsl/lib/math/BUILD b/third_party/xla/xla/tsl/lib/math/BUILD index 6d2e3fe7b1ac2c..0a67561aca4c29 100644 --- a/third_party/xla/xla/tsl/lib/math/BUILD +++ b/third_party/xla/xla/tsl/lib/math/BUILD @@ -18,6 +18,7 @@ filegroup( srcs = [ "math_util.h", ], + compatible_with = get_compatible_with_portable(), ) cc_library( From c7419140c8837cec9ec79ab47df8bc7984369e41 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 3 Apr 2025 18:03:16 -0700 Subject: [PATCH 0225/1324] [XlaCallModule] Add better error message to computation deserialization failures. PiperOrigin-RevId: 743761053 --- .../compiler/tests/xla_call_module_test.py | 24 +++++++++++++++++++ .../tf2xla/kernels/xla_call_module_loader.cc | 12 +++++++--- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index b8d59d77641ada..197df89e2c0042 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -1568,6 +1568,30 @@ def f(x): self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),)) + def test_op_backward_incompatibility(self): + """Test for ensuring XlaCallModuleOp with invalid bytecode.""" + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + def f(x): + # Use an invalid MLIR string that will fail to parse when loading the + # call module op, emulating a backward incompatibility. 
+ corrupted_module = 'stablehlo.invalid_op' + return gen_xla_ops.xla_call_module( + [x], + version=xla.call_module_maximum_supported_version(), + module=corrupted_module, + Tout=[x.dtype], + Sout=[x.shape], + platforms=[self.testing_platform()], + ) + + # Expect any error message to be included after `:` + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Cannot deserialize computation: .+', + ): + f(x) + if __name__ == '__main__': ops.enable_eager_execution( diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index ddd1f23cbb068e..eb946ab9085b93 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -399,9 +399,15 @@ absl::Status XlaCallModuleLoader::LoadModule( } // Parse the StableHLO/VHLO bytecode - module_ = mlir::stablehlo::deserializePortableArtifact(module_str, context_); - if (!module_) { - return absl::InvalidArgumentError("Cannot deserialize computation"); + { + mlir::StatusScopedDiagnosticHandler diag_handler(context_); + module_ = + mlir::stablehlo::deserializePortableArtifact(module_str, context_); + if (!module_) { + return absl::InvalidArgumentError( + absl::StrCat("Cannot deserialize computation: ", + diag_handler.ConsumeStatus().ToString())); + } } VLOG(3) << "Parsed serialized module (version = " << version << ", platforms = [" << absl::StrJoin(platforms, ", ") From 27c036934e73460fd0a7ccd2e4fcfaddecc31857 Mon Sep 17 00:00:00 2001 From: Ce Zheng Date: Thu, 3 Apr 2025 18:19:12 -0700 Subject: [PATCH 0226/1324] [PJRT] Split AsyncWorkRunner to its own header. PiperOrigin-RevId: 743764232 --- third_party/xla/xla/pjrt/BUILD | 11 +++++ third_party/xla/xla/pjrt/async_work_runner.h | 41 +++++++++++++++++++ third_party/xla/xla/pjrt/cpu/BUILD | 2 + .../xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc | 1 + .../xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h | 14 +------ third_party/xla/xla/pjrt/cpu/cpu_client.h | 1 + 6 files changed, 57 insertions(+), 13 deletions(-) create mode 100644 third_party/xla/xla/pjrt/async_work_runner.h diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 2ee156f7faefe0..f498d64d1d2647 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -1158,3 +1158,14 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "async_work_runner", + hdrs = ["async_work_runner.h"], + deps = [ + "//xla/tsl/concurrency:async_value", + "//xla/tsl/concurrency:ref_count", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/types:span", + ], +) diff --git a/third_party/xla/xla/pjrt/async_work_runner.h b/third_party/xla/xla/pjrt/async_work_runner.h new file mode 100644 index 00000000000000..906d6d213fa850 --- /dev/null +++ b/third_party/xla/xla/pjrt/async_work_runner.h @@ -0,0 +1,41 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_PJRT_ASYNC_WORK_RUNNER_H_ +#define XLA_PJRT_ASYNC_WORK_RUNNER_H_ + +#include "absl/functional/any_invocable.h" +#include "absl/types/span.h" +#include "xla/tsl/concurrency/async_value.h" +#include "xla/tsl/concurrency/ref_count.h" + +namespace xla { + +// Async work runner abstracts away the implementation of the underlying thread +// pool (or concurrent work queue). +class AsyncWorkRunner { + public: + virtual ~AsyncWorkRunner() = default; + + // `work` euqueued by `Schedule` may run on the calling thread. + virtual void Schedule(absl::AnyInvocable work) = 0; + virtual void ScheduleWhenReady( + absl::Span> values, + absl::AnyInvocable work) = 0; +}; + +} // namespace xla + +#endif // XLA_PJRT_ASYNC_WORK_RUNNER_H_ diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 9511e6bb9cc98a..dd26df16f693a6 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -77,6 +77,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/pjrt:async_work_runner", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_future", "//xla/pjrt:transpose", @@ -173,6 +174,7 @@ cc_library( "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/pjrt:async_work_runner", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:host_callback", "//xla/pjrt:host_memory_spaces", diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index db483e43bc2229..45978171c6e855 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/cpu_function_runtime.h" #include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/pjrt/async_work_runner.h" #include "xla/pjrt/cpu/cpu_event.h" #include "xla/pjrt/cpu/tracked_cpu_device_buffer.h" #include "xla/pjrt/pjrt_client.h" diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index a6f2f07bb51d26..a107d70f5cc0c3 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -36,6 +36,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/literal.h" +#include "xla/pjrt/async_work_runner.h" #include "xla/pjrt/cpu/cpu_event.h" #include "xla/pjrt/cpu/tracked_cpu_device_buffer.h" #include "xla/pjrt/pjrt_client.h" @@ -75,19 +76,6 @@ class MarkEventReadyOnExit { tsl::AsyncValueRef event_; }; -// Async work runner abstracts away the implementation of the underlying thread -// pool (or concurrent work queue). -class AsyncWorkRunner { - public: - virtual ~AsyncWorkRunner() = default; - - // `work` euqueued by `Schedule` may run on the calling thread. 
diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD
index 9511e6bb9cc98a..dd26df16f693a6 100644
--- a/third_party/xla/xla/pjrt/cpu/BUILD
+++ b/third_party/xla/xla/pjrt/cpu/BUILD
@@ -77,6 +77,7 @@ cc_library(
         "//xla:shape_util",
         "//xla:util",
         "//xla:xla_data_proto_cc",
+        "//xla/pjrt:async_work_runner",
        "//xla/pjrt:pjrt_client",
         "//xla/pjrt:pjrt_future",
         "//xla/pjrt:transpose",
@@ -173,6 +174,7 @@ cc_library(
         "//xla/hlo/builder:xla_computation",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/ir:hlo_module_group",
+        "//xla/pjrt:async_work_runner",
         "//xla/pjrt:compile_options_proto_cc",
         "//xla/pjrt:host_callback",
         "//xla/pjrt:host_memory_spaces",
diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
index db483e43bc2229..45978171c6e855 100644
--- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
+++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
@@ -40,6 +40,7 @@ limitations under the License.
 #include "xla/cpu_function_runtime.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
+#include "xla/pjrt/async_work_runner.h"
 #include "xla/pjrt/cpu/cpu_event.h"
 #include "xla/pjrt/cpu/tracked_cpu_device_buffer.h"
 #include "xla/pjrt/pjrt_client.h"
diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
index a6f2f07bb51d26..a107d70f5cc0c3 100644
--- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
+++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
@@ -36,6 +36,7 @@ limitations under the License.
 #include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
 #include "xla/literal.h"
+#include "xla/pjrt/async_work_runner.h"
 #include "xla/pjrt/cpu/cpu_event.h"
 #include "xla/pjrt/cpu/tracked_cpu_device_buffer.h"
 #include "xla/pjrt/pjrt_client.h"
@@ -75,19 +76,6 @@ class MarkEventReadyOnExit {
   tsl::AsyncValueRef<CpuEvent> event_;
 };

-// Async work runner abstracts away the implementation of the underlying thread
-// pool (or concurrent work queue).
-class AsyncWorkRunner {
- public:
-  virtual ~AsyncWorkRunner() = default;
-
-  // `work` euqueued by `Schedule` may run on the calling thread.
-  virtual void Schedule(absl::AnyInvocable<void()> work) = 0;
-  virtual void ScheduleWhenReady(
-      absl::Span<const tsl::RCReference<tsl::AsyncValue>> values,
-      absl::AnyInvocable<void()> work) = 0;
-};
-
 class AbstractTfrtCpuBuffer : public PjRtBuffer {
  public:
   AbstractTfrtCpuBuffer(
diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h
index b66328baa4bf43..8b1f708c444e30 100644
--- a/third_party/xla/xla/pjrt/cpu/cpu_client.h
+++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h
@@ -1,3 +1,4 @@
+#include "xla/pjrt/async_work_runner.h"
 /* Copyright 2021 The OpenXLA Authors.

 Licensed under the Apache License, Version 2.0 (the "License");

From bd9452dcc229a011626c91bb02dadd0198bf5e53 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 3 Apr 2025 18:50:21 -0700
Subject: [PATCH 0227/1324] Temporarily disable `USE_CUDA_REDISTRIBUTIONS`
 because it increased CPU job time.

PiperOrigin-RevId: 743771227
---
 .bazelrc                           | 5 +++--
 third_party/xla/tensorflow.bazelrc | 5 +++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/.bazelrc b/.bazelrc
index c33396789f01c0..98e967b5619716 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -595,8 +595,9 @@ common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instance

 # Download CUDA/CUDNN redistributions to preserve the repositories cache between
 # CPU and GPU builds.
-build:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1
-build:rbe_linux_cpu --config=cuda_version
+# TODO(ybaturina): Uncomment when RBE is ready to support this.
+# build:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1
+# build:rbe_linux_cpu --config=cuda_version

 # TODO(kanglan): Remove it after toolchain update is complete.
 build:rbe_linux_cpu_old --config=rbe_linux
diff --git a/third_party/xla/tensorflow.bazelrc b/third_party/xla/tensorflow.bazelrc
index 1d47e43735baf3..584a966f8cde27 100644
--- a/third_party/xla/tensorflow.bazelrc
+++ b/third_party/xla/tensorflow.bazelrc
@@ -454,8 +454,9 @@ common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instance

 # Download CUDA/CUDNN redistributions to preserve the repositories cache between
 # CPU and GPU builds.
-build:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1
-build:rbe_linux_cpu --config=cuda_version
+# TODO(ybaturina): Uncomment when RBE is ready to support this.
+# build:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1
+# build:rbe_linux_cpu --config=cuda_version

 build:rbe_linux_cuda --config=cuda_clang_official
 build:rbe_linux_cuda --config=rbe_linux_cpu

From 586ce56ec684b58a8bfb606bc343a3e4606bc113 Mon Sep 17 00:00:00 2001
From: Yin Zhang
Date: Thu, 3 Apr 2025 19:46:56 -0700
Subject: [PATCH 0228/1324] Migrate HloCostAnalysis helper libraries to open
 source.
PiperOrigin-RevId: 743781631 --- tensorflow/core/profiler/utils/BUILD | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index 4122fdefbfe62f..e41694b21f67d2 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -1,4 +1,3 @@ -load("//tensorflow:tensorflow.bzl", "tf_cuda_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") @@ -301,12 +300,9 @@ cc_library( ], ) -tf_cuda_library( +cc_library( name = "hlo_module_map", hdrs = ["hlo_module_map.h"], - cuda_deps = [ - "@local_xla//xla/service/gpu/model:gpu_hlo_cost_analysis", - ], visibility = [ "//perftools/accelerators/xprof/convert:__pkg__", "//tensorflow/core/profiler/convert:__pkg__", @@ -333,7 +329,10 @@ cc_library( cc_library( name = "xprof_gpu_cost_analysis", hdrs = ["xprof_gpu_cost_analysis.h"], - visibility = ["//perftools/accelerators/xprof/convert:__pkg__"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + ], deps = [ "@org_xprof//xprof/utils:xprof_gpu_cost_analysis", ], From 67374f0f820b1a2b2949ed5203ba0412aa1b4514 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 19:59:28 -0700 Subject: [PATCH 0229/1324] Add GetDefaultLayout API to IFRT Proxy PiperOrigin-RevId: 743783889 --- .../xla/xla/python/ifrt_proxy/client/BUILD | 1 + .../xla/python/ifrt_proxy/client/client.cc | 23 ++++++++++++ .../xla/xla/python/ifrt_proxy/client/client.h | 5 +-- .../python/ifrt_proxy/client/rpc_helper.cc | 1 + .../xla/python/ifrt_proxy/client/rpc_helper.h | 3 ++ .../ifrt_proxy/common/ifrt_service.proto | 11 ++++++ .../python/ifrt_proxy/server/ifrt_backend.cc | 30 ++++++++++++++++ .../python/ifrt_proxy/server/ifrt_backend.h | 2 ++ .../ifrt_proxy/server/ifrt_backend_test.cc | 36 +++++++++++++++++++ 9 files changed, 108 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 4180533cff7215..c69b75f8728d4f 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -158,6 +158,7 @@ cc_library( ":rpc_helper", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_device_description", + "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", "//xla/python/ifrt:attribute_map", "//xla/python/ifrt:basic_device_list", diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.cc b/third_party/xla/xla/python/ifrt_proxy/client/client.cc index 67072212b9c5e1..726f9b50ecd185 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.cc @@ -33,6 +33,7 @@ #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/attribute_map.h" @@ -415,6 +416,28 @@ xla::ifrt::DeviceListRef Client::MakeDeviceList( return xla::ifrt::BasicDeviceList::Create(devices); } +absl::StatusOr> Client::GetDefaultLayout( + xla::ifrt::DType dtype, absl::Span dims, + xla::ifrt::Device* device, xla::ifrt::MemoryKind memory_kind) const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointGetDefaultLayout"); + auto req = std::make_unique(); + 
*req->mutable_dtype() = dtype.ToProto(); + req->mutable_dims()->Reserve(dims.size()); + for (int64_t dim : dims) { + req->add_dims(dim); + } + req->set_device_id(device->Id().value()); + req->set_memory_kind(std::string(memory_kind.memory_kind().value_or(""))); + + auto future = rpc_helper_->GetDefaultLayout(std::move(req)); + TF_ASSIGN_OR_RETURN(auto response, future.Await()); + + TF_ASSIGN_OR_RETURN(auto layout, xla::PjRtLayout::Deserialize( + response->serialized_pjrt_layout())); + return layout; +} + } // namespace proxy } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.h b/third_party/xla/xla/python/ifrt_proxy/client/client.h index 4c41d468fede13..e859de91da7cc2 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.h @@ -151,10 +151,7 @@ class Client final : public llvm::RTTIExtends { absl::StatusOr> GetDefaultLayout( xla::ifrt::DType dtype, absl::Span dims, xla::ifrt::Device* device, - xla::ifrt::MemoryKind memory_kind) const override { - return absl::UnimplementedError( - "GetDefaultLayout is not supported for the IFRT proxy client."); - } + xla::ifrt::MemoryKind memory_kind) const override; tsl::RCReference CreateUserContext() override { return tsl::RCReference(); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc index 57f6cd4daec490..bfeb351a782972 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -342,6 +342,7 @@ RPC(LoadedExecutableIsDeleted, loaded_executable_is_deleted); RPC(LoadedExecutableDestruct, loaded_executable_destruct); RPC(LoadedHostCallbackPoll, loaded_host_callback_poll); RPC(LoadedHostCallbackReturn, loaded_host_callback_return); +RPC(GetDefaultLayout, get_default_layout); Future<> RpcHelper::CheckFuture(uint64_t handle) { auto req = std::make_unique(); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h index 5fc0440b44e37c..01ef59e534a8b3 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h @@ -144,6 +144,9 @@ class RpcHelper { ResponseFuture LoadedHostCallbackReturn( std::unique_ptr req); + ResponseFuture GetDefaultLayout( + std::unique_ptr req); + // Utility functions. 
// Generates a handle for new arrays, array data stored in HostBufferStore, diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto index 3d7ede383ca30e..329d6eec932302 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -81,6 +81,7 @@ message IfrtRequest { // ===== Client ===== GetDefaultDeviceAssignmentRequest get_default_device_assignment_request = 19; + GetDefaultLayoutRequest get_default_layout_request = 27; } reserved 10; @@ -133,6 +134,7 @@ message IfrtResponse { // ===== Client ===== GetDefaultDeviceAssignmentResponse get_default_device_assignment_response = 19; + GetDefaultLayoutResponse get_default_layout_response = 27; } reserved 10; @@ -571,3 +573,12 @@ message GetDefaultDeviceAssignmentRequest { message GetDefaultDeviceAssignmentResponse { xla.DeviceAssignmentProto device_assignment = 1; } +message GetDefaultLayoutRequest { + DTypeProto dtype = 1; + repeated int64 dims = 2; + int32 device_id = 3; + string memory_kind = 4; +} +message GetDefaultLayoutResponse { + bytes serialized_pjrt_layout = 1; +} diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc index 9d73688e7f0914..d94a53748591ba 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -617,6 +617,9 @@ Future IfrtBackend::ProcessInternal( case IfrtRequest::RequestCase::kGetDefaultDeviceAssignmentRequest: return Future( HandleGetDefaultDeviceAssignmentRequest(std::move(request))); + case IfrtRequest::RequestCase::kGetDefaultLayoutRequest: + return Future( + HandleGetDefaultLayoutRequest(std::move(request))); default: LOG(ERROR) << "Got unimplemented request type: " << request->DebugString(); @@ -1928,6 +1931,33 @@ IfrtBackend::HandleGetDefaultDeviceAssignmentRequest( return ifrt_resp; } +absl::StatusOr +IfrtBackend::HandleGetDefaultLayoutRequest( + std::unique_ptr request) { + const auto& get_default_layout_request = + request->get_default_layout_request(); + TF_ASSIGN_OR_RETURN(auto dtype, + DType::FromProto(get_default_layout_request.dtype())); + TF_ASSIGN_OR_RETURN( + Device* const device, + client_->LookupDevice(DeviceId(get_default_layout_request.device_id()))); + MemoryKind memory_kind = + get_default_layout_request.memory_kind().empty() + ? 
MemoryKind() + : MemoryKind(get_default_layout_request.memory_kind()); + TF_ASSIGN_OR_RETURN( + std::shared_ptr layout, + client_->GetDefaultLayout(dtype, get_default_layout_request.dims(), + device, memory_kind)); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + + *ifrt_resp->mutable_get_default_layout_response() + ->mutable_serialized_pjrt_layout() = layout->Serialize(); + + return ifrt_resp; +} + absl::StatusOr> IfrtBackend::GetLoadedExecutable(uint64_t handle) { absl::MutexLock lock(&executables_mutex_); diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h index 166f6123f73741..48124a7dd7824b 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h @@ -220,6 +220,8 @@ class IfrtBackend final : public BackendInterface { absl::StatusOr HandleGetDefaultDeviceAssignmentRequest( std::unique_ptr request); + absl::StatusOr HandleGetDefaultLayoutRequest( + std::unique_ptr request); ////////////////////////////////////////////////////////////////////// // Auxiliary/Helper methods for the handler methods above diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index d55a2b5d171bc9..e665b475d9a63e 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -1789,6 +1789,42 @@ TEST_P(IfrtBackendHandlerTest, StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); } +TEST_P(IfrtBackendHandlerTest, GetDefaultLayoutSuccess) { + const auto kDefaultLayout = std::make_shared( + xla::LayoutUtil::MakeDescendingLayout(1)); + const xla::ifrt::DType kDType = xla::ifrt::DType(xla::ifrt::DType::kF32); + const std::vector kDims = {1, 2, 3}; + const int64_t kDeviceId = 42; + const auto mock_device = std::make_unique(); + const std::string kMemoryKindStr = "xla::ifrt::MemoryKind()"; + const xla::ifrt::MemoryKind kMemoryKind(kMemoryKindStr); + + ON_CALL(*mock_client_, LookupDevice(DeviceId(kDeviceId))) + .WillByDefault(Return(mock_device.get())); + + EXPECT_CALL(*mock_client_, + GetDefaultLayout(kDType, absl::MakeConstSpan(kDims), + mock_device.get(), kMemoryKind)) + .WillOnce(Return(std::shared_ptr(kDefaultLayout))); + + auto request = NewIfrtRequest(NewOpId()); + auto* default_layout_request = request->mutable_get_default_layout_request(); + *default_layout_request->mutable_dtype() = kDType.ToProto(); + default_layout_request->mutable_dims()->Reserve(kDims.size()); + for (int64_t dim : kDims) { + default_layout_request->add_dims(dim); + } + default_layout_request->set_device_id(kDeviceId); + default_layout_request->set_memory_kind(kMemoryKindStr); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(request))); + TF_ASSERT_OK_AND_ASSIGN( + auto layout_got, + xla::PjRtLayout::Deserialize( + response->get_default_layout_response().serialized_pjrt_layout())); + EXPECT_EQ(*layout_got, *kDefaultLayout); +} + INSTANTIATE_TEST_SUITE_P( IfrtBackendHandlerTestWithAllVersions, IfrtBackendHandlerTest, testing::Range(kServerMinVersion, kServerMaxVersion + 1), From 89aa270ba6bdefccc944ae924073c21028cd9789 Mon Sep 17 00:00:00 2001 From: Arian Arfaian Date: Thu, 3 Apr 2025 20:22:22 -0700 Subject: [PATCH 0230/1324] Reorder tpose(reshape(tpose(input))) to tpose(tpose(reshape(input))). 
This reordering enables the two adjacent transpose ops to be folded into one, reducing unnecessary computation. PiperOrigin-RevId: 743790281 --- .../compiler/mlir/lite/tests/optimize.mlir | 18 +++ .../mlir/lite/transforms/optimize_pass.cc | 143 +++++++++++++++++- 2 files changed, 159 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 0a52298b17e7b9..4aa7057bdb7024 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -4610,3 +4610,21 @@ func.func @EliminateBooleanCastCompare(%arg0: tensor<*xi1>) -> (tensor<*xi1>, te // CHECK: %9 = "tfl.zeros_like"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> // CHECK: return %0, %1, %3, %arg0, %arg0, %4, %5, %7, %8, %arg0, %9, %arg0 : tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1> } + +// CHECK-LABEL: @ReorderTransposeReshapeTranspose +func.func @ReorderTransposeReshapeTranspose(%arg0: tensor<282x2048xf32>) -> tensor<2x1x282x1024xf32> { + %cst = arith.constant dense<[1, 0]> : tensor<2xi32> + %cst_1 = arith.constant dense<[2, 1024, 1, 282]> : tensor<4xi32> + %cst_2 = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<282x2048xf32>, tensor<2xi32>) -> tensor<2048x282xf32> + %1 = "tfl.reshape"(%0, %cst_1) : (tensor<2048x282xf32>, tensor<4xi32>) -> tensor<2x1024x1x282xf32> + %2 = "tfl.transpose"(%1, %cst_2) : (tensor<2x1024x1x282xf32>, tensor<4xi32>) -> tensor<2x1x282x1024xf32> + return %2: tensor<2x1x282x1024xf32> + + // CHECK: %cst = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32> + // CHECK-NEXT: %cst_0 = arith.constant dense<[282, 2, 1024, 1]> : tensor<4xi32> + // CHECK-NEXT: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<282x2048xf32>, tensor<4xi32>) -> tensor<282x2x1024x1xf32> + // CHECK-NEXT: %1 = "tfl.transpose"(%0, %cst) : (tensor<282x2x1024x1xf32>, tensor<4xi32>) -> tensor<2x1x282x1024xf32> + // CHECK-NEXT: return %1 : tensor<2x1x282x1024xf32> +} + diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc index cb2702886dc719..d7b64f89ae6f2f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc @@ -2798,6 +2798,145 @@ struct PushTransposeThroughSqueeze : public RewritePattern { } }; +// Helper function to check if a constant tensor attribute has the expected +// integer values +bool matchConstantIntPermutation(Value permValue, + ArrayRef expectedPerm) { + DenseElementsAttr permAttr; + if (!matchPattern(permValue, m_Constant(&permAttr))) { + return false; // Not a constant + } + if (!permAttr.getElementType().isInteger(32) && + !permAttr.getElementType().isInteger(64)) { + // TFLite perms are often i32, but accept i64 too + return false; + } + + auto values = permAttr.getValues(); + if (values.size() != expectedPerm.size()) { + return false; + } + for (size_t i = 0; i < expectedPerm.size(); ++i) { + if (values[i].getSExtValue() != expectedPerm[i]) { + return false; + } + } + return true; +} + +inline DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = mlir::RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, values); +} + +inline DenseIntElementsAttr 
GetI32ElementsAttr(ArrayRef<int64_t> values,
+                                            Builder *builder) {
+  llvm::SmallVector<int32_t> new_values;
+  for (auto el : values) {
+    new_values.push_back(static_cast<int32_t>(el));
+  }
+  RankedTensorType ty = mlir::RankedTensorType::get(
+      {static_cast<int64_t>(values.size())}, builder->getIntegerType(32));
+  return DenseIntElementsAttr::get(ty, new_values);
+}
+
+// Reorders a Transpose-Reshape-Transpose sequence to
+// Reshape-Transpose-Transpose to allow for further optimization.
+//
+// The pattern matches:
+//   Transpose(Reshape(Transpose(input, perm: [1, 0])))
+//
+// and rewrites it to:
+//   Transpose(Transpose(Reshape(input)))
+//
+// This reordering allows for further optimization by potentially fusing the
+// reshapes and transposes.
+struct ReorderTransposeReshapeTranspose
+    : public OpRewritePattern<TFL::TransposeOp> {
+  explicit ReorderTransposeReshapeTranspose(MLIRContext *context)
+      : OpRewritePattern<TFL::TransposeOp>(context, /*benefit=*/0) {}
+
+  LogicalResult matchAndRewrite(TFL::TransposeOp outer_tpose,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = outer_tpose.getInput().getDefiningOp<TFL::ReshapeOp>();
+    if (!reshape) return failure();
+
+    auto inner_tpose = reshape.getInput().getDefiningOp<TFL::TransposeOp>();
+    if (!inner_tpose) return failure();
+
+    auto inner_tpose_shape =
+        mlir::dyn_cast_or_null<RankedTensorType>(inner_tpose.getType());
+    if (!inner_tpose_shape) return failure();
+
+    auto input = inner_tpose.getInput();
+
+    auto inner_perm = inner_tpose.getPerm();
+    if (!matchConstantIntPermutation(inner_perm, {1, 0})) return failure();
+
+    int64_t perm0 = inner_tpose_shape.getDimSize(0);
+
+    llvm::SmallVector<int32_t> reshape_shape;
+    {
+      DenseIntElementsAttr reshape_shape_attr;
+      if (!matchPattern(reshape.getShape(), m_Constant(&reshape_shape_attr))) {
+        return failure();
+      }
+
+      for (auto dim : reshape_shape_attr) {
+        reshape_shape.push_back(static_cast<int32_t>(dim.getSExtValue()));
+      }
+    }
+
+    // Consume dimensions until we've equaled the size of the first dim in the
+    // permuted result of the inner tpose and record the dim.
+    int32_t dim = -1;
+    for (auto i = 0, running_total = 1; i < reshape_shape.size(); i++) {
+      running_total *= reshape_shape[i];
+      if (perm0 == running_total) {
+        dim = i;
+      }
+    }
+
+    if (dim == -1) return failure();
+
+    llvm::SmallVector<int64_t> new_reshape_shape(reshape_shape.size());
+    llvm::SmallVector<int32_t> new_inner_perm(reshape_shape.size());
+
+    int index = 0;
+    for (auto i = dim + 1; i < reshape_shape.size(); i++) {
+      new_inner_perm[i] = index;
+      new_reshape_shape[index++] = reshape_shape[i];
+    }
+    for (auto i = 0; i <= dim; i++) {
+      new_inner_perm[i] = index;
+      new_reshape_shape[index++] = reshape_shape[i];
+    }
+
+    auto reshape_type =
+        mlir::dyn_cast_or_null<RankedTensorType>(reshape.getType());
+    if (!reshape_type) return failure();
+
+    auto new_reshape_shape_const = rewriter.create<arith::ConstantOp>(
+        reshape.getLoc(), GetI32ElementsAttr(new_reshape_shape, &rewriter));
+
+    auto new_inner_reshape = rewriter.create<TFL::ReshapeOp>(
+        reshape.getLoc(),
+        RankedTensorType::get(new_reshape_shape, reshape_type.getElementType()),
+        input, new_reshape_shape_const.getResult());
+    auto new_inner_tpose = rewriter.create<TFL::TransposeOp>(
+        inner_tpose.getLoc(), reshape_type, new_inner_reshape,
+        rewriter.create<arith::ConstantOp>(
+            inner_tpose.getLoc(),
+            GetI32ElementsAttr(new_inner_perm, &rewriter)));
+
+    rewriter.replaceOp(reshape, new_inner_tpose);
+
+    return success();
+  }
+};
+
 // Adds canonicalization patterns to the list of patterns.
void AddCanonicalizationPatterns(MLIRContext *context, RewritePatternSet *patterns) { @@ -2858,8 +2997,8 @@ void OptimizePass::runOnOperation() { OptimizeTopK, FuseAddAndStridedSlice, FuseReshapeAndTransposeAroundBatchMatmul, FuseTransposeReshapeIntoBatchMatmul, MoveReshapeAfterFullyConnected, - EnableFullyConnectedKeepNumDimsBeforeReshape, ConvertTFLBroadcastToMulOp>( - ctx); + EnableFullyConnectedKeepNumDimsBeforeReshape, ConvertTFLBroadcastToMulOp, + ReorderTransposeReshapeTranspose>(ctx); if (!GetOptions().disable_fuse_mul_and_fc) { phase_2_patterns.add(ctx); } From 3445ee21caa26d1047cbeff897feb0e9ef76e442 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 20:27:50 -0700 Subject: [PATCH 0231/1324] Enforce that `Shape::dynamic_dimensions()` is only called on array shapes. PiperOrigin-RevId: 743791400 --- third_party/xla/xla/shape.h | 6 +----- .../xla/xla/stream_executor/tpu/c_api_conversions.cc | 9 +++++++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 387e335340e3db..1637eee92100f1 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -213,11 +213,7 @@ class Shape { // Returns a span to indicate whether each dimension is dynamic. // Precondition: this is an array shape. absl::Span dynamic_dimensions() const { - if (auto* const state = if_array_state()) { - return state->dynamic_dimensions; - } - // TODO(b/404276923): ensure that this is never called on non-array shapes. - return {}; + return array_state().dynamic_dimensions; } absl::Span mutable_dynamic_dimensions() { return absl::MakeSpan(array_state().dynamic_dimensions); diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc index e0098ecf7e5d03..035660859564c2 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc @@ -274,8 +274,13 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base) { void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape) { c_shape->element_type = xla_shape.element_type(); - CreateVector(xla_shape.dimensions(), &c_shape->dimensions); - CreateVector(xla_shape.dynamic_dimensions(), &c_shape->dynamic_dimensions); + if (xla_shape.IsArray()) { + CreateVector(xla_shape.dimensions(), &c_shape->dimensions); + CreateVector(xla_shape.dynamic_dimensions(), &c_shape->dynamic_dimensions); + } else { + c_shape->dimensions.size = 0; + c_shape->dynamic_dimensions.size = 0; + } c_shape->ntuple_shapes = xla_shape.IsTuple() ? xla_shape.tuple_shapes_size() : 0; From bec09f7151dc8d8cb1f37906d341b21a6bc2f674 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 20:49:27 -0700 Subject: [PATCH 0232/1324] Enforce that `TrueNumDimensions()` is only called on array shapes. 
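
The guard pattern this enforcement pushes call sites toward, sketched as standalone C++ for illustration (it mirrors the `instruction_fusion.cc` hunk below; the helper name is made up, while `Shape::IsArray()` and `ShapeUtil::TrueNumDimensions()` are the real APIs):

    #include <cstdint>

    #include "xla/shape.h"
    #include "xla/shape_util.h"

    // Illustrative only: count non-trivial array dimensions, treating
    // non-array shapes (tuples, tokens) as rank 0 instead of CHECK-failing.
    int64_t TrueNumDimensionsOrZero(const xla::Shape& shape) {
      // TrueNumDimensions() now CHECKs shape.IsArray(), so callers guard it.
      return shape.IsArray() ? xla::ShapeUtil::TrueNumDimensions(shape) : 0;
    }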
PiperOrigin-RevId: 743795543 --- third_party/xla/xla/service/instruction_fusion.cc | 7 +++++-- third_party/xla/xla/shape_util.cc | 11 +++++------ third_party/xla/xla/shape_util.h | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/instruction_fusion.cc b/third_party/xla/xla/service/instruction_fusion.cc index 2c1f02d37ba537..d9f66a552a5266 100644 --- a/third_party/xla/xla/service/instruction_fusion.cc +++ b/third_party/xla/xla/service/instruction_fusion.cc @@ -263,8 +263,11 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { ShapeUtil::IsEffectiveScalar(operand->shape())) { return false; } - return ShapeUtil::TrueNumDimensions(operand->shape()) >= - output_rank; + const int true_dims = + operand->shape().IsArray() + ? ShapeUtil::TrueNumDimensions(operand->shape()) + : 0; + return true_dims >= output_rank; }) <= 1; } diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 1c3fda998628dd..8e816014eb670d 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -225,14 +225,13 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { return equal; } -/* static */ int64_t ShapeUtil::TrueNumDimensions(const Shape& shape) { - if (!shape.IsArray()) { - // TODO(b/404276923): enforce that this is never called on non-array shapes. - return 0; - } +/* static */ int64_t ShapeUtil::TrueNumDimensions(const Shape& array_shape) { + CHECK(array_shape.IsArray()) + << "TrueNumDimensions called on non-array shape: " + << array_shape.ToString(); int64_t accum = 0; - for (const int64_t dimension : shape.dimensions()) { + for (const int64_t dimension : array_shape.dimensions()) { // We do not count unit dimensions. if (dimension != 1) { accum += 1; diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index b6ccf195b6047b..ea04f3b6bc95c7 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -294,8 +294,8 @@ class ShapeUtil { // 1. e.g., f32[2x1x1] has a true dimensionality of 1D, the other dimensions // are just fluff. Note that zero dimensions are included in the true // dimensionality, e.g., f32[3,0,1] has a true dimensionality of 2D. - // Precondition: shape.IsArray(). - static int64_t TrueNumDimensions(const Shape& shape); + // Precondition: array_shape.IsArray(). + static int64_t TrueNumDimensions(const Shape& array_shape); static ProgramShape MakeProgramShape(std::initializer_list parameters, Shape result); From 1ec925659ad7ca854b8d4bc1fb7e6a2cc04ce0a7 Mon Sep 17 00:00:00 2001 From: Arian Arfaian Date: Thu, 3 Apr 2025 21:27:24 -0700 Subject: [PATCH 0233/1324] Add optimization pass to rewrite FC(IsConst(x), y) as FC(y, x). Many downstream optimizations rely on proper ordering of the input operands into FullyConnected. Some JAX programs produce graphs with the inputs swapped and this rewrite pattern protects against that behavior. 
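
The identity the rewrite relies on can be checked in isolation: tfl.fully_connected computes Output[B, O] = Input[B, I] x transpose(Filter[O, I]), and since (A x B^T)^T = B x A^T, swapping the operands and transposing the result preserves values. A self-contained sanity check, written as plain C++ independent of the pass (shapes match the tests added below):

    #include <array>
    #include <cstdio>

    // FC semantics: out[b][o] = sum_i in[b][i] * filter[o][i].
    template <int B, int O, int I>
    std::array<std::array<float, O>, B> FC(
        const std::array<std::array<float, I>, B>& in,
        const std::array<std::array<float, I>, O>& filter) {
      std::array<std::array<float, O>, B> out{};
      for (int b = 0; b < B; ++b)
        for (int o = 0; o < O; ++o)
          for (int i = 0; i < I; ++i) out[b][o] += in[b][i] * filter[o][i];
      return out;
    }

    int main() {
      std::array<std::array<float, 2>, 2> c = {{{1, 2}, {3, 4}}};  // Const LHS.
      std::array<std::array<float, 2>, 4> v = {{{1, 0}, {0, 1}, {2, 3}, {4, 5}}};
      auto direct = FC<2, 4, 2>(c, v);   // FC(C, V): shape [B=2, O=4].
      auto swapped = FC<4, 2, 2>(v, c);  // FC(V, C): shape [4, 2].
      // transpose(swapped) must equal direct.
      for (int b = 0; b < 2; ++b)
        for (int o = 0; o < 4; ++o)
          if (direct[b][o] != swapped[o][b]) { std::puts("mismatch"); return 1; }
      std::puts("FC(C,V) == transpose(FC(V,C))");
      return 0;
    }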
PiperOrigin-RevId: 743803775 --- .../compiler/mlir/lite/tests/optimize.mlir | 45 +++++++++ .../mlir/lite/transforms/optimize_pass.cc | 91 ++++++++++++++++++- 2 files changed, 135 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 4aa7057bdb7024..6ad231ed9f54bc 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -4628,3 +4628,48 @@ func.func @ReorderTransposeReshapeTranspose(%arg0: tensor<282x2048xf32>) -> tens // CHECK-NEXT: return %1 : tensor<2x1x282x1024xf32> } +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConst +func.func @FullyConnectedSwapOperandsWhenLHSIsConst(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, none) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: %cst = arith.constant dense<[1, 0]> : tensor<2xi32> + // CHECK-NEXT: %cst_0 = arith.constant dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32> + // CHECK-NEXT: %0 = "tfl.fully_connected"(%arg0, %cst_0, %arg1) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> + // CHECK-NEXT: %1 = "tfl.transpose"(%0, %cst) : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<2x4xf32> + // CHECK-NEXT: return %1 : tensor<2x4xf32> +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstBias +func.func @FullyConnectedSwapOperandsWhenLHSIsConstBias(%arg0: tensor<4x2xf32>) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %cst_1 = arith.constant dense<2.0> : tensor<2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %cst_1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, tensor<2xf32>) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: [[cst:%.*]] = arith.constant + // CHECK-NEXT: [[cst_1:%.*]] = arith.constant + // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], [[cst_1]]) +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstKeepNumDimsTrue +func.func @FullyConnectedSwapOperandsWhenLHSIsConstKeepNumDimsTrue(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, none) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: [[cst:%.*]] = arith.constant + // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], %arg1) +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstFusedActivationFunction +func.func @FullyConnectedSwapOperandsWhenLHSIsConstFusedActivationFunction(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = 
"RELU", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, none) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: [[cst:%.*]] = arith.constant + // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], %arg1) +} + diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc index d7b64f89ae6f2f..2816e8bb090f1a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc @@ -2937,6 +2937,94 @@ struct ReorderTransposeReshapeTranspose } }; +// Some models produce FullyConnected ops where the LHS is a const and the RHS +// is the activation. This breaks some downstream optimizations (notably input +// caching in XNNPack among other things). This rewrite pattern swaps the +// operands to match the expected order and recomputes a new output shape for +// the resuling op. +// +// This pattern only applies when: +// * input and filter operands are 2D +// * bias = none +// * keep_num_dims = false (implied if input and filter are 2D) +// Support for additional cases to broaden applicability can be added later. +// TODO(b/408313959): Add support for more cases. +// +// Note that transposes are added to maintain correctness: +// +// Original: Output[B, O] = FC(Input[B, I](Const), Filter[O, I](Var), Bias=None) +// ~= matmul(C, transpose(V)) +// +// Transformed: +// Intermediate[O, B] = FC(Filter[O, I](Var), Input[B, I](Const), None) +// ~= matmul(V, transpose(C)) +// FinalOutput[B, O] = Transpose(Intermediate[O, B], perm=[1, 0]) +struct FullyConnectedSwapOperandsWhenLHSIsConst + : public OpRewritePattern { + explicit FullyConnectedSwapOperandsWhenLHSIsConst(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/0) {} + + LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc, + PatternRewriter &rewriter) const override { + if (!mlir::isa(fc.getBias().getType())) return failure(); + + auto input = fc.getInput(); + auto filter = fc.getFilter(); + + if (!matchPattern(input, m_Constant()) || + matchPattern(filter, m_Constant())) + return failure(); + + auto input_type = mlir::dyn_cast(input.getType()); + auto filter_type = mlir::dyn_cast(filter.getType()); + auto output_type = + mlir::dyn_cast(fc.getResult(0).getType()); + + if (!input_type || !filter_type || !output_type) return failure(); + + if (input_type.getRank() != 2 && filter_type.getRank() != 2) + return failure(); + + // Dimensions: B=Batch, I=InputDepth, O=OutputDepth + // Input: [B, I], Filter: [O, I] + // We extract B from the input operand and O from the filter operand + int64_t B = input_type.getDimSize(0); + int64_t O = filter_type.getDimSize(0); + + Type element_type = output_type.getElementType(); + Location loc = fc.getLoc(); + + RankedTensorType intermediate_type = + RankedTensorType::get({O, B}, element_type); + + auto new_fc = rewriter.create( + loc, + /*resultTypes=*/intermediate_type, + /*input=*/filter, // Original Filter V[O, I] + /*filter=*/input, // Original Input C[B, I] + /*bias=*/fc.getBias(), + /*fused_activation_function=*/ + rewriter.getStringAttr(fc.getFusedActivationFunction()), + /*weights_format=*/fc.getWeightsFormatAttr(), + /*keep_num_dims=*/rewriter.getBoolAttr(false), + /*asymmetric_quantize_inputs=*/ + fc.getAsymmetricQuantizeInputsAttr() // Propagate quant attr + ); + + RankedTensorType final_shape_type = + RankedTensorType::get({B, O}, element_type); + + Value transposed_result = rewriter.create( + loc, 
final_shape_type, new_fc.getResult(0), + rewriter.create( + loc, GetI32ElementsAttr(ArrayRef({1, 0}), &rewriter))); + + rewriter.replaceOp(fc, transposed_result); + + return success(); + } +}; + // Adds canonicalization patterns to the list of patterns. void AddCanonicalizationPatterns(MLIRContext *context, RewritePatternSet *patterns) { @@ -2998,7 +3086,8 @@ void OptimizePass::runOnOperation() { FuseReshapeAndTransposeAroundBatchMatmul, FuseTransposeReshapeIntoBatchMatmul, MoveReshapeAfterFullyConnected, EnableFullyConnectedKeepNumDimsBeforeReshape, ConvertTFLBroadcastToMulOp, - ReorderTransposeReshapeTranspose>(ctx); + ReorderTransposeReshapeTranspose, + FullyConnectedSwapOperandsWhenLHSIsConst>(ctx); if (!GetOptions().disable_fuse_mul_and_fc) { phase_2_patterns.add(ctx); } From 6dde6c88cd9d5173bf794eb8afe099cd8ab315f3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 22:11:47 -0700 Subject: [PATCH 0234/1324] Automated Code Change PiperOrigin-RevId: 743812649 --- tensorflow/compiler/tf2xla/mlir_bridge_pass.h | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index eae5fb83c5d682..f41c202b01e447 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" +#include "absl/status/status.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" From c33fd64c8590a03d4c54941f48495faedb477f89 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 22:22:46 -0700 Subject: [PATCH 0235/1324] Enforce that `Shape::tuple_shapes()` is only called on tuple shapes. PiperOrigin-RevId: 743814497 --- .../xla/xla/service/layout_assignment.cc | 9 +++++--- third_party/xla/xla/shape.cc | 7 +----- .../tpu/c_api_conversions_test.cc | 22 +++++-------------- 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 326d2a66d777af..1fb7e247e92661 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -2939,9 +2939,12 @@ bool LayoutAssignment::IsAtMostRank1(const Shape& shape) { if (shape.IsArray()) { return shape.dimensions_size() <= 1; } - return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) { - return IsAtMostRank1(subshape); - }); + if (shape.IsTuple()) { + return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) { + return IsAtMostRank1(subshape); + }); + } + return true; } absl::Status LayoutAssignment::Init(HloModule* module) { diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 0d9feb7ad2f556..670b4e7c67f0e8 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -253,12 +253,7 @@ void Shape::CheckStateIsEmpty() const { } const std::vector& Shape::tuple_shapes() const { - if (const auto* const state = if_tuple_state()) { - return state->tuple_shapes; - } - // TODO(b/404276923): ensure that this is never called on non-tuple shapes. 
- static const auto* const kEmpty = new std::vector(); - return *kEmpty; + return tuple_state().tuple_shapes; } void Shape::Clear() { diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc b/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc index b9465c1a46d1e7..79a357b99c4114 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc @@ -246,25 +246,13 @@ TEST(XlaShape, ToCNested) { MakeSpan(c_shape.dynamic_dimensions); EXPECT_EQ(cpp_dynamic_dimensions, c_dynamic_dimensions); - EXPECT_FALSE(cpp_shape.IsTuple()); - EXPECT_EQ(c_shape.ntuple_shapes, 0); - - const int c_ntuple_shapes = c_shape.ntuple_shapes; - const std::vector& cpp_tuple_shapes = cpp_shape.tuple_shapes(); - absl::Span c_tuple_shapes(c_shape.tuple_shapes, - c_ntuple_shapes); - for (int i = 0; i < c_ntuple_shapes; ++i) { - xla::Shape converted_c_shape = FromC(&c_tuple_shapes[i]); - EXPECT_EQ(cpp_tuple_shapes[i], converted_c_shape); - } + ASSERT_FALSE(cpp_shape.IsTuple()); + ASSERT_EQ(c_shape.ntuple_shapes, 0); - bool cpp_has_layout = cpp_shape.has_layout(); - bool c_has_layout = c_shape.has_layout; - EXPECT_EQ(cpp_has_layout, c_has_layout); + EXPECT_EQ(cpp_shape.has_layout(), c_shape.has_layout); - if (c_has_layout) { - xla::Layout converted_c_layout = FromC(&c_shape.layout); - EXPECT_EQ(cpp_shape.layout(), converted_c_layout); + if (c_shape.has_layout) { + EXPECT_EQ(cpp_shape.layout(), FromC(&c_shape.layout)); } Destroy(&c_shape); From ef92f4467ca8287447d0248963b83a6ba15592b6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Apr 2025 23:14:10 -0700 Subject: [PATCH 0236/1324] Automated Code Change PiperOrigin-RevId: 743824658 --- tensorflow/compiler/tf2xla/layout_util.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/tf2xla/layout_util.cc b/tensorflow/compiler/tf2xla/layout_util.cc index 69075d3c712523..e6486684e6cd85 100644 --- a/tensorflow/compiler/tf2xla/layout_util.cc +++ b/tensorflow/compiler/tf2xla/layout_util.cc @@ -72,8 +72,8 @@ absl::Status RewriteLayoutWithShardedShape( sharding->TileOffsetForDevice(*xla_shape, device); std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); - std::vector dimensions(xla_shape->dimensions_size()); - for (int64_t i = 0; i < xla_shape->dimensions_size(); ++i) { + std::vector dimensions(xla_shape->dimensions().size()); + for (int64_t i = 0; i < xla_shape->dimensions().size(); ++i) { dimensions[i] = limit[i] - offset[i]; } xla::Shape per_device_xla_shape = @@ -131,7 +131,7 @@ absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( hlo_sharding, fast_mem, shape_determination_fns, &to_shape)); } if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { - for (int64_t i = 0; i < original_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < original_shape.dimensions().size(); ++i) { to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); } } From 1383f6e00f0e1ece8096cf44dfce763fbd9cc261 Mon Sep 17 00:00:00 2001 From: Yin Zhang Date: Thu, 3 Apr 2025 23:33:36 -0700 Subject: [PATCH 0237/1324] Fix gpu cost analysis with: - Add XLA Ops processing when creating device op metrics db from XPlane - Create cost_analysis instance for gpu PiperOrigin-RevId: 743828244 --- tensorflow/core/profiler/convert/BUILD | 8 ++ tensorflow/core/profiler/convert/repository.h | 8 +- .../convert/xplane_to_op_metrics_db.cc | 112 +++++++++++++----- .../convert/xplane_to_op_metrics_db.h 
| 5 +- .../convert/xplane_to_op_metrics_db_test.cc | 15 ++- .../profiler/convert/xplane_to_op_stats.cc | 14 ++- 6 files changed, 126 insertions(+), 36 deletions(-) diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 6748b3a5e22e58..4ef2c038ee8511 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -34,6 +34,8 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", "@org_xprof//xprof/utils:cost_utils", + "@org_xprof//xprof/utils:gpu_event_stats", + "@org_xprof//xprof/utils:hlo_module_map", "@org_xprof//xprof/utils:op_metrics_db_utils", "@org_xprof//xprof/utils:op_utils", ], @@ -58,7 +60,10 @@ tf_cc_test( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@org_xprof//xprof/utils:hlo_cost_analysis_wrapper", + "@org_xprof//xprof/utils:hlo_module_map", "@org_xprof//xprof/utils:op_metrics_db_utils", + "@org_xprof//xprof/utils:xprof_gpu_cost_analysis", ], ) @@ -432,9 +437,11 @@ cc_library( "@org_xprof//xprof/utils:event_span", "@org_xprof//xprof/utils:gpu_event_stats", "@org_xprof//xprof/utils:hardware_type_utils", + "@org_xprof//xprof/utils:hlo_cost_analysis_wrapper", "@org_xprof//xprof/utils:hlo_module_map", "@org_xprof//xprof/utils:kernel_stats_utils", "@org_xprof//xprof/utils:op_utils", + "@org_xprof//xprof/utils:xprof_gpu_cost_analysis", ], ) @@ -1008,6 +1015,7 @@ cc_library( "@local_xla//xla/tsl/platform:statusor", "@local_xla//xla/tsl/profiler/utils:file_system_utils", "@org_xprof//xprof/utils:hlo_module_map", + "@org_xprof//xprof/utils:xprof_gpu_cost_analysis", ], ) diff --git a/tensorflow/core/profiler/convert/repository.h b/tensorflow/core/profiler/convert/repository.h index 84ac5dd3b0188a..9c5ee606552711 100644 --- a/tensorflow/core/profiler/convert/repository.h +++ b/tensorflow/core/profiler/convert/repository.h @@ -34,6 +34,7 @@ limitations under the License. #include "tsl/platform/path.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "xprof/utils/hlo_module_map.h" // from @org_xprof +#include "xprof/utils/xprof_gpu_cost_analysis.h" // from @org_xprof namespace tensorflow { namespace profiler { @@ -183,14 +184,19 @@ absl::Status ReadBinaryProto(const SessionSnapshot& session_snapshot, return session_snapshot.ReadBinaryProto(data_type, host, proto); } +// TODO(b/408280338) Remove this function as 0 reference is found. +// Add a dummy cost_analysis factory function as a no-op now. // Process HloModuleMap from all XSpaces in a session. inline absl::StatusOr ProcessHloModuleMap( const SessionSnapshot& session_snapshot) { HloModuleMap hlo_module_map; + tensorflow::profiler::HloCostAnalysisWrapper::Factory create_cost_analysis = + []() { return nullptr; }; for (int i = 0; i < session_snapshot.XSpaceSize(); i++) { TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, session_snapshot.GetXSpace(i)); - ProcessHloModuleMapFromXSpace(hlo_module_map, xspace.get()); + ProcessHloModuleMapFromXSpace(hlo_module_map, xspace.get(), + create_cost_analysis); } return hlo_module_map; } diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index 505ef66fb81591..f48acaea10d9d3 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -44,6 +45,8 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "xprof/utils/cost_utils.h" // from @org_xprof +#include "xprof/utils/gpu_event_stats.h" // from @org_xprof +#include "xprof/utils/hlo_module_map.h" // from @org_xprof #include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof #include "xprof/utils/op_utils.h" // from @org_xprof @@ -51,8 +54,24 @@ namespace tensorflow { namespace profiler { namespace { +using ::tensorflow::profiler::GpuEventStats; using tsl::profiler::GetDeviceEventTimespan; +struct HLOTracker { + uint64_t duration = 0; + uint64_t program_id = 0; + uint64_t group_id = 0; + bool is_eager; + const HloInstructionWrapper* hlo_instruction = nullptr; + std::string hlo_op_name; + + void Reset() { + duration = program_id = group_id = 0; + hlo_op_name.clear(); + hlo_instruction = nullptr; + } +}; + // Type of a TensorFlow Op activity, which is either beginning or ending an Op. enum TfActivityType { kTfOpBegin, kTfOpEnd }; @@ -276,7 +295,31 @@ OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb( return builder.Finalize(last_op_timestamp_ps - first_op_timestamp_ps); } -OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace) { +void AggregateHloFunc(HLOTracker& current, DeviceOpMetricsDbBuilder& metricDb) { + if (current.hlo_instruction == nullptr) return; + auto performance_info_wrapper = + current.hlo_instruction->GetPerformanceInfoWrapper(); + auto flops = 0; + auto bytes_accessed = 0; + if (performance_info_wrapper != nullptr) { + flops = performance_info_wrapper->flops(); + bytes_accessed = performance_info_wrapper->bytes_accessed(); + } + metricDb.EnterOp( + current.program_id, current.hlo_op_name, + current.hlo_instruction->Category(), current.hlo_instruction->TfOpName(), + current.hlo_instruction->DeduplicatedName(), current.is_eager, 1, + current.duration, 0, performance_info_wrapper->DeviceFlops(), + performance_info_wrapper->bytes_accessed(), + ConvertPerformanceInfo( + performance_info_wrapper->memory_accessed_breakdown(), 1), + performance_info_wrapper->ModelFlops(), + current.hlo_instruction->Expression()); + current.Reset(); +} + +OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( + const XPlane& device_trace, const HloModuleMap& hlo_module_map) { OpMetricsDb result; DeviceOpMetricsDbBuilder device_op_metrics_db_builder(&result); @@ -285,42 +328,55 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace) { TfOpRoofLineCostEstimator op_level_cost_estimator; XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); + HLOTracker current; plane.ForEachLine([&](const XLineVisitor& line) { if (IsDerivedThreadId(line.Id())) return; line.ForEachEvent([&](const XEventVisitor& event) { first_op_offset_ps = std::min(first_op_offset_ps, event.OffsetPs()); last_op_offset_ps = std::max(last_op_offset_ps, event.EndOffsetPs()); - absl::string_view tf_op_full_name; - bool is_eager = false; - int64_t program_id = 0; - absl::string_view deduplicated_name = ""; - event.ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kTfOp) { - tf_op_full_name = stat.StrOrRefValue(); - } else if (stat.Type() == StatType::kIsEager) { - is_eager = stat.IntValue(); - } else if (stat.Type() == StatType::kProgramId) { - program_id = stat.IntOrUintValue(); - } else if (stat.Type() == StatType::kDeduplicatedName) { - deduplicated_name = stat.StrOrRefValue(); 
+ GpuEventStats stats(&event); + if (stats.IsXlaOp()) { + const auto* hlo_instruction = GetHloInstruction( + hlo_module_map, stats.program_id, stats.hlo_op_names.back()); + if (hlo_instruction != nullptr) { + if (stats.hlo_op_names.back() != current.hlo_op_name || + stats.group_id != current.group_id) { + AggregateHloFunc(current, device_op_metrics_db_builder); + } + // Merge identical and contiguous HLOs. + current.hlo_instruction = hlo_instruction; + current.hlo_op_name = stats.hlo_op_names.back(); + current.duration += event.DurationPs(); + current.is_eager = stats.is_eager; + current.program_id = *stats.program_id; + if (stats.group_id.has_value()) { + current.group_id = *stats.group_id; + } } - }); - if (tf_op_full_name.empty()) return; - tsl::profiler::TfOp tf_op = - tsl::profiler::ParseTfOpFullname(tf_op_full_name); - TfOpRoofLineCostEstimator::OpRoofLineStats costs; - if (tf_op.category != tsl::profiler::Category::kUnknown) { - costs = op_level_cost_estimator.Predict(event); + } else if (stats.IsTfOp()) { + AggregateHloFunc(current, device_op_metrics_db_builder); + tsl::profiler::TfOp tf_op = + tsl::profiler::ParseTfOpFullname(stats.tf_op_fullname); + PerformanceInfo perf_info; + if (tf_op.category != tsl::profiler::Category::kUnknown) { + auto costs = op_level_cost_estimator.Predict(event); + // NOTE: events are per kernel, but costs are per tf-ops. + perf_info.set_flops(costs.flops); + perf_info.set_bytes_accessed(costs.bytes_accessed); + } + std::string name = absl::StrCat(tf_op.name, "/", event.Name()); + device_op_metrics_db_builder.EnterOp( + /*program_id=*/0, + /**name=*/name, + /**category=*/tf_op.type, + /*provenance=*/stats.tf_op_fullname, "", stats.is_eager, + /*occurrences=*/1, event.DurationPs(), + /*children_time_ps=*/0, perf_info.flops(), + perf_info.bytes_accessed()); } - device_op_metrics_db_builder.EnterOp( - /*program_id=*/program_id, - /**name=*/absl::StrCat(tf_op.name, "/", event.Name()), - /**category=*/tf_op.type, - /*provenance=*/tf_op_full_name, deduplicated_name, is_eager, - /*occurrences=*/1, event.DurationPs(), - /*children_time_ps=*/0, costs.flops, costs.bytes_accessed); }); + AggregateHloFunc(current, device_op_metrics_db_builder); }); SetTotalTimePs( result, last_op_offset_ps ? last_op_offset_ps - first_op_offset_ps : 0); diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h index 337b4bee2cf27a..39fea9a2ef786b 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h @@ -47,7 +47,10 @@ void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst); OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace); -OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace); +// Converts GPU device trace to OpMetricsDb. +// Will use HloModuleMap to source performance info for cost analysis. 
+OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( + const XPlane& device_trace, const HloModuleMap& hlo_module_map); // Convert TPU DeviceTrace XPlane to OpMetricDb OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb( diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc index 30fd0b1a5f26a7..a93cb82e98ed7b 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc @@ -33,7 +33,10 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/hlo_cost_analysis_wrapper.h" // from @org_xprof +#include "xprof/utils/hlo_module_map.h" // from @org_xprof #include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof +#include "xprof/utils/xprof_gpu_cost_analysis.h" // from @org_xprof namespace tensorflow { namespace profiler { @@ -185,16 +188,20 @@ TEST(ConvertXPlaneToOpMetricsDb, DeviceOpMetricsDb) { AddTensorFlowOpEvent(absl::StrCat(kTfOp2, ":", kTfOp2), kKernel3StartNs, kKernel3DurationNs, /*on_device=*/true, kKernel3, &device_plane, &stream2); - - OpMetricsDb op_metrics = ConvertDeviceTraceXPlaneToOpMetricsDb(*xplane); + HloModuleMap hlo_module_map; + tensorflow::profiler::HloCostAnalysisWrapper::Factory create_cost_analysis = + []() { return tensorflow::profiler::CreateXprofGpuCostAnalysis(); }; + ProcessHloModuleMapFromXSpace(hlo_module_map, &xspace, create_cost_analysis); + OpMetricsDb op_metrics = + ConvertDeviceTraceXPlaneToOpMetricsDb(*xplane, hlo_module_map); // kernel1, kernel2, kernel3, Idle. EXPECT_EQ(4, op_metrics.metrics_db_size()); uint64 total_op_duration = tsl::profiler::NanoToPico( kKernel1DurationNs * 2 + kKernel2DurationNs * 2 + kKernel3DurationNs); EXPECT_EQ(total_op_duration, op_metrics.total_op_time_ps()); - // For device, the total_duration for each device is the total duration merged - // from all GPU streams, which is from 100000 to 130000. + // For device, the total_duration for each device is the total duration + // merged from all GPU streams, which is from 100000 to 130000. uint64 total_duration = tsl::profiler::NanoToPico( kKernel3StartNs + kKernel3DurationNs - kKernel1StartNs); EXPECT_EQ(std::max(total_duration, total_op_duration), diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 3747047a2730a4..73f35d7ae0dc65 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -54,9 +54,11 @@ limitations under the License. 
#include "xprof/utils/event_span.h" // from @org_xprof #include "xprof/utils/gpu_event_stats.h" // from @org_xprof #include "xprof/utils/hardware_type_utils.h" // from @org_xprof +#include "xprof/utils/hlo_cost_analysis_wrapper.h" // from @org_xprof #include "xprof/utils/hlo_module_map.h" // from @org_xprof #include "xprof/utils/kernel_stats_utils.h" // from @org_xprof #include "xprof/utils/op_utils.h" // from @org_xprof +#include "xprof/utils/xprof_gpu_cost_analysis.h" // from @org_xprof namespace tensorflow { namespace profiler { @@ -313,7 +315,14 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, HloModuleMap hlo_module_map; if (options.generate_kernel_stats_db || (is_tpu && options.generate_op_metrics_db)) { - ProcessHloModuleMapFromXSpace(hlo_module_map, &space); + tensorflow::profiler::HloCostAnalysisWrapper::Factory create_cost_analysis = + []() { return nullptr; }; + if (is_gpu) { + create_cost_analysis = []() { + return tensorflow::profiler::CreateXprofGpuCostAnalysis(); + }; + } + ProcessHloModuleMapFromXSpace(hlo_module_map, &space, create_cost_analysis); } for (const XPlane* device_trace : device_planes) { if (options.generate_op_metrics_db) { @@ -322,7 +331,8 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, } if (!is_tpu) { OpMetricsDb device_op_metrics_db = - ConvertDeviceTraceXPlaneToOpMetricsDb(*device_trace); + ConvertDeviceTraceXPlaneToOpMetricsDb(*device_trace, + hlo_module_map); op_metrics_db_combiner.Combine(device_op_metrics_db); } else { // TODO(b/397774568): Remove this once the SparseCore OpMetricsDb is From 7ff1fec59b2492bbad0a1630828f0ef34b2e5002 Mon Sep 17 00:00:00 2001 From: Theotime Combes Date: Thu, 3 Apr 2025 23:59:16 -0700 Subject: [PATCH 0238/1324] [XLA:GPU] Add triton support test for cholesky PiperOrigin-RevId: 743833418 --- .../backends/gpu/codegen/triton/support.cc | 1 - .../gpu/codegen/triton/support_test.cc | 41 ++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index aefc67e95d01c3..6852852469608b 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -586,7 +586,6 @@ CodegenDecision IsTritonSupportedInstructionImpl( namespace internal { bool IsTritonUnsupportedOpcode(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kCholesky: case HloOpcode::kConvolution: case HloOpcode::kCopyDone: case HloOpcode::kCopyStart: diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index 10b964f733bd3a..cae593137e5454 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -89,6 +89,7 @@ bool DoesOpSupportType(HloOpcode opcode, PrimitiveType type) { case HloOpcode::kReal: case HloOpcode::kImag: case HloOpcode::kLogistic: + case HloOpcode::kCholesky: return pu::IsFloatingPointType(type) || pu::IsComplexType(type); case HloOpcode::kCbrt: case HloOpcode::kErf: @@ -2518,10 +2519,47 @@ INSTANTIATE_TEST_SUITE_P(CustomCallSuite, CustomCallTest, ::testing::ValuesIn(AllDevicesToTest()), TritonSupportTestDeviceToString); +class CholeskyTest + : public TritonSupportTest, + public ::testing::WithParamInterface< + // The bool parameter is used to parametrize the lower=?. 
+ std::tuple> {}; + +TEST_P(CholeskyTest, Cholesky) { + auto [data_type, cc, lower] = GetParam(); + + const std::string kHloTestTemplate = absl::Substitute( + R"( + ENTRY triton_computation { + parameter = $0[4,4] parameter(0) + ROOT cholesky_op = $0[4,4] cholesky(parameter), lower=$1 + })", + primitive_util::LowercasePrimitiveTypeName(data_type), lower); + + TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( + kHloTestTemplate, data_type, + HloOpcode::kCholesky)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 2}, cc); +} + +std::string CholeskyTestName( + const ::testing::TestParamInfo< + std::tuple>& data) { + const auto [data_type, cc, lower] = data.param; + return absl::StrCat(primitive_util::LowercasePrimitiveTypeName(data_type), + "_", ComputeCapabilityToString(cc), "_", lower); +} + +INSTANTIATE_TEST_SUITE_P( + CholeskySuite, CholeskyTest, + ::testing::Combine( + ::testing::ValuesIn(AllOpSupportedTypes(HloOpcode::kCholesky)), + ::testing::ValuesIn(AllDevicesToTest()), ::testing::Bool()), + CholeskyTestName); + constexpr std::array kUnsupportedOps = { // clang-format off // go/keep-sorted start - HloOpcode::kCholesky, HloOpcode::kConvolution, HloOpcode::kCopyDone, HloOpcode::kCopyStart, @@ -2581,6 +2619,7 @@ absl::flat_hash_set AllTestedOpcodes() { ret.emplace(HloOpcode::kBatchNormTraining); ret.emplace(HloOpcode::kBitcastConvert); ret.emplace(HloOpcode::kCall); + ret.emplace(HloOpcode::kCholesky); ret.emplace(HloOpcode::kComplex); ret.emplace(HloOpcode::kConditional); ret.emplace(HloOpcode::kCustomCall); From df8faa29371e4280704ea7510616350a1840c8e4 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 4 Apr 2025 00:00:52 -0700 Subject: [PATCH 0239/1324] PR #24436: Fix variadic reduction shared memory estimation. Imported from GitHub PR https://github.com/openxla/xla/pull/24436 Currently, the logic is broken for variadic reductions with heterogeneous input types, since it always uses the first input's primitive type to estimate the shared memory buffer size. It should be summing up the primitive sizes instead. Also expand the comments a bit to explain better what's going on there. This should fix https://github.com/jax-ml/jax/issues/27190. Copybara import of the project: -- 8bdb3fb57f10bd758fb6f1f2d159d7c5283322d6 by Johannes Reifferscheid : Fix variadic reduction shared memory estimation. Currently, the logic is broken for variadic reductions with heterogeneous input types, since it always uses the first input's primitive type to estimate the shared memory buffer size. It should be summing up the primitive sizes instead. Also expand the comments a bit to explain better what's going on there. 
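
In essence, the corrected estimate sums the element sizes of all reduction inputs instead of sizing the buffer from the first input alone. A condensed sketch of the fixed logic, as standalone C++ (the function name is made up for illustration; the constants mirror the `gpu_fusible.cc` hunk below):

    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Shared-memory estimate for a (possibly variadic) reduction, summing the
    // element sizes of all inputs instead of using only the first one.
    int64_t EstimateReductionSharedMemoryBytes(
        const std::vector<int64_t>& input_primitive_sizes,
        bool is_row_reduction) {
      const int64_t primitive_size_sum =
          std::accumulate(input_primitive_sizes.begin(),
                          input_primitive_sizes.end(), int64_t{0});
      if (is_row_reduction) {
        // At most one element per warp (<= 32 warps) is staged per input.
        return 32 * primitive_size_sum;
      }
      // Column-reduction cache is 32 x (vector_size * 32 + 1) per input;
      // assume the maximum vector size of 4.
      constexpr int64_t kMaxVectorSize = 4;
      return 32 * (kMaxVectorSize * 32 + 1) * primitive_size_sum;
    }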
Merging this change closes #24436 PiperOrigin-RevId: 743833721 --- .../xla/xla/service/gpu/gpu_fusible.cc | 29 +++++++++++----- .../xla/xla/service/gpu/gpu_fusible_test.cc | 33 ++++++++++++++++++- 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index 6cb217a2e3337c..ff019c3e42ecfe 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -509,17 +509,28 @@ static int64_t SharedMemoryUsageNoCache( IsReductionFromOrToContiguousDimensions(instr, device_info)) { ReductionDimensions reduction_info = GetReductionKindAndContiguousComponents(instr); - int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType( - instr.operand(0)->shape().element_type()); - int num_variadic = - instr.shape().IsTuple() ? instr.shape().tuple_shapes_size() : 1; + int64_t primitive_size_sum = 0; + // Variadic reductions will allocate one shared memory buffer for each + // input. They all have the same shape, so we can just sum up the primitive + // sizes of the inputs. + for (int i = 0; i < instr.operand_count() / 2; ++i) { + primitive_size_sum += ShapeUtil::ByteSizeOfPrimitiveType( + instr.operand(i)->shape().element_type()); + } + if (reduction_info.is_row_reduction) { - // __shared__[32] is used for row reduction. - return 32 * primitive_size * num_variadic; + // In row reductions, we write at most one element per warp to shared + // memory, regardless of whether the reduction is vectorized or not. We + // have at most 32 warps for a single row. We could tighten this estimate, + // but it doesn't really matter. Row reductions are very unlikely to ever + // run out of shared memory budget. + return 32 * primitive_size_sum; } else { - // __shared__[4][32][33] cache is used for column reduction ("4" comes - // from potential x-tiling). - return 4 * 32 * 33 * primitive_size * num_variadic; + // The shape of the cache for column reductions is 32x(vector_size * 32 + + // 1). We don't know the actual vector size here, so we assume the + // maximum. + constexpr int kMaxVectorSize = 4; + return 32 * (kMaxVectorSize * 32 + 1) * primitive_size_sum; } } else if (auto tr = GetDescriptionForTiledTransposeEmitter(instr)) { // Tile size for transposition. diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc index 43a6d4a4dbf990..35ca6e0c0c3063 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/
-
 #include "xla/service/gpu/gpu_fusible.h"

 #include
@@ -1551,6 +1550,38 @@ TEST_F(GpuFusibleTest, GetSharedMemoryUsage) {
   EXPECT_EQ(cache.GetSharedMemoryUsage(*fusion), 32 * 33 * 2 * 4);
 }

+TEST_F(GpuFusibleTest, GetSharedMemoryUsageVariadicReduction) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    reducer {
+      p0 = pred[] parameter(0)
+      p1 = s32[] parameter(1)
+      p2 = pred[] parameter(2)
+      p3 = s32[] parameter(3)
+      ROOT %tuple.20.0 = (pred[], s32[]) tuple(p2, p3)
+    }
+    reduce {
+      p0 = pred[4,128,128] parameter(0)
+      p1 = s32[4,128,128] parameter(1)
+      cfalse = pred[] constant(false)
+      c0 = s32[] constant(0)
+      ROOT reduce = (pred[4,128], s32[4,128]) reduce(p0, p1, cfalse, c0),
+          dimensions={1}, to_apply=reducer
+    }
+    ENTRY main {
+      p0 = pred[4,128,128] parameter(0)
+      p1 = s32[4,128,128] parameter(1)
+      ROOT fusion = (pred[4,128], s32[4,128]) fusion(p0, p1),
+          kind=kInput, calls=reduce
+    })")));
+  FusionInfoCache cache(device_description());
+  auto fusion = module->entry_computation()->root_instruction();
+  constexpr int kMaxVectorSize = 4;
+  EXPECT_EQ(
+      cache.GetSharedMemoryUsage(*fusion),
+      (sizeof(int8_t) + sizeof(int32_t)) * 32 * (32 * kMaxVectorSize + 1));
+}
+
 TEST_F(GpuFusibleTest, IsConsumerTheOnlyNonRootUser) {
   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
 e {

From 827c9453bc054a6ee237223ac95b27740e025461 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 4 Apr 2025 00:03:07 -0700
Subject: [PATCH 0240/1324] Automated Code Change

PiperOrigin-RevId: 743834600
---
 third_party/xla/xla/shape.cc           |  4 +-
 third_party/xla/xla/shape_util.cc      | 78 +++++++++++++-------------
 third_party/xla/xla/shape_util_test.cc |  2 +-
 3 files changed, 42 insertions(+), 42 deletions(-)

diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc
index 670b4e7c67f0e8..578c290eb0b503 100644
--- a/third_party/xla/xla/shape.cc
+++ b/third_party/xla/xla/shape.cc
@@ -353,7 +353,7 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
       VLOG(3) << "CompareShapes: lhs rank != rhs rank";
       return false;
     }
-    for (int i = 0; i < lhs.dimensions_size(); ++i) {
+    for (int i = 0; i < lhs.dimensions().size(); ++i) {
       if (ignore_dynamic_dimension_ &&
           (lhs.is_unbounded_dynamic_dimension(i) ||
            rhs.is_unbounded_dynamic_dimension(i))) {
@@ -403,7 +403,7 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
   }

   if (!ignore_dynamic_dimension_) {
-    for (int i = 0; i < lhs.dimensions_size(); ++i) {
+    for (int i = 0; i < lhs.dimensions().size(); ++i) {
       if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) {
         VLOG(3)
             << "CompareShapes: lhs and rhs have different dynamic dimensions.";
diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc
index 8e816014eb670d..42508bdb5f32ea 100644
--- a/third_party/xla/xla/shape_util.cc
+++ b/third_party/xla/xla/shape_util.cc
@@ -417,8 +417,8 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
 /* static */ Shape
 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
     const Shape& shape) {
-  std::vector<int64_t> dims(shape.dimensions_size());
-  for (int i = 0; i < shape.dimensions_size(); ++i) {
+  std::vector<int64_t> dims(shape.dimensions().size());
+  for (int i = 0; i < shape.dimensions().size(); ++i) {
     int dim = i;
     if (shape.has_layout()) {
       dim = LayoutUtil::Major(shape.layout(), dim);
@@ -436,7 +436,7 @@
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( new_shape.mutable_layout()->set_tail_padding_alignment_in_elements( shape.layout().tail_padding_alignment_in_elements()); } - for (int i = 0; i < shape.dimensions_size(); ++i) { + for (int i = 0; i < shape.dimensions().size(); ++i) { int dim = i; if (shape.has_layout()) { dim = LayoutUtil::Major(shape.layout(), dim); @@ -523,7 +523,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { CHECK(LayoutUtil::IsDenseArray(*shape)); if (shape->has_layout()) { - shape->mutable_layout()->add_minor_to_major(shape->dimensions_size()); + shape->mutable_layout()->add_minor_to_major(shape->dimensions().size()); } shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); @@ -559,15 +559,15 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } // Insert the newly added dimension at the minor-most position. shape->mutable_layout()->set_minor_to_major(0, - shape->dimensions_size() - 1); + shape->dimensions().size() - 1); } TF_DCHECK_OK(ValidateShape(*shape)); } /* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to, const Shape& from) { - CHECK_EQ(to->dimensions_size(), from.dimensions_size()); - for (int64_t i = 0; i < from.dimensions_size(); ++i) { + CHECK_EQ(to->dimensions().size(), from.dimensions().size()); + for (int64_t i = 0; i < from.dimensions().size(); ++i) { to->set_dynamic_dimension(i, from.is_dynamic_dimension(i)); } TF_DCHECK_OK(ValidateShape(*to)); @@ -579,7 +579,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { // If not, and the most major dimension's size is 1, then we can repeat the // same check for next most major dimension as returned by // LayoutUtil::Major(1) and so on. 
- for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { int64_t major_dimension = LayoutUtil::Major(shape.layout(), i); if (major_dimension == dimension) { return true; @@ -729,7 +729,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } }; print_one(0); - for (int i = 1, n = shape.dimensions_size(); i < n; ++i) { + for (int i = 1, n = shape.dimensions().size(); i < n; ++i) { printer->Append(","); print_one(i); } @@ -802,7 +802,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { if (!SameRank(lhs, rhs)) return false; - for (int i = 0; i < lhs.dimensions_size(); ++i) { + for (int i = 0; i < lhs.dimensions().size(); ++i) { if (!lhs.is_unbounded_dynamic_dimension(i) && !rhs.is_unbounded_dynamic_dimension(i) && lhs.dimensions(i) != rhs.dimensions(i)) { @@ -814,7 +814,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } /* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) { - return lhs.dimensions_size() == rhs.dimensions_size(); + return lhs.dimensions().size() == rhs.dimensions().size(); } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { @@ -849,8 +849,8 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { /* static */ DimensionVector ShapeUtil::CreateDimensionVectorFromShape( const Shape& shape) { DimensionVector dimensions; - dimensions.reserve(shape.dimensions_size()); - for (int i = 0; i < shape.dimensions_size(); ++i) { + dimensions.reserve(shape.dimensions().size()); + for (int i = 0; i < shape.dimensions().size(); ++i) { dimensions.push_back(shape.dimensions(i)); } return dimensions; @@ -864,7 +864,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { /* static */ int64_t ShapeUtil::GetDimensionNumber(const Shape& shape, int64_t dimension_number) { if (dimension_number < 0) { - dimension_number += shape.dimensions_size(); + dimension_number += shape.dimensions().size(); } CHECK_GE(dimension_number, 0); return dimension_number; @@ -939,7 +939,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return ShapeError(shape, "Shape cannot be serialiized."); } if (subshape.is_dynamic()) { - size += sizeof(DynamicSizeType) * subshape.dimensions_size(); + size += sizeof(DynamicSizeType) * subshape.dimensions().size(); } if (subshape.element_type() == PRED) { // PRED is packed 8 elements per byte. 
@@ -983,7 +983,7 @@ absl::Status ValidateShapeSize(const Shape& shape) { absl::Status ValidateDimensions(const Shape& shape) { bool any_overflows = false; int64_t product = 1; - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { int64_t dimension = shape.dimensions(i); if (dimension == Shape::kUnboundedSize) { continue; @@ -1186,7 +1186,7 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { /* static */ absl::StatusOr ShapeUtil::PackedFactorFor1DInterleavedArray(const Shape& shape) { - if (shape.dimensions_size() == 1 && shape.layout().tiles_size() == 3 && + if (shape.dimensions().size() == 1 && shape.layout().tiles_size() == 3 && shape.layout().tiles()[2].dimensions().size() == 2) { return shape.layout().tiles()[2].dimension(0); } @@ -1207,7 +1207,7 @@ ShapeUtil::PackedFactorFor1DInterleavedArray(const Shape& shape) { new_shape.add_dimensions(dim); } auto inv_permutation = InversePermutation(permutation); - for (int64_t i = 0; i < shape.dimensions_size(); i++) { + for (int64_t i = 0; i < shape.dimensions().size(); i++) { new_shape.set_dynamic_dimension(inv_permutation[i], shape.is_dynamic_dimension(i)); } @@ -1308,8 +1308,8 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, i < unmodified_dims.size() ? unmodified_dims[i] : std::make_pair( - static_cast(shape_pre.dimensions_size()), - static_cast(shape_post.dimensions_size())); + static_cast(shape_pre.dimensions().size()), + static_cast(shape_post.dimensions().size())); if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) { return std::nullopt; } @@ -1556,13 +1556,13 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified( Shape output_shape_dim0_major = MakeShapeWithDescendingLayout( output_shape.element_type(), output_shape.dimensions()); - for (int64_t input_dim = 0; input_dim < input_shape.dimensions_size(); + for (int64_t input_dim = 0; input_dim < input_shape.dimensions().size(); ++input_dim) { if (input_shape.dimensions(input_dim) <= 1) { continue; } - std::vector input_unit_index(input_shape.dimensions_size(), 0); + std::vector input_unit_index(input_shape.dimensions().size(), 0); input_unit_index[input_dim] = 1; int64_t logical_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major, @@ -1592,7 +1592,7 @@ static absl::Span LayoutPerm(const Shape& s) { /* static */ std::optional> ShapeUtil::DeduceTransposeDimensionsForBitcast(const Shape& input_shape, const Shape& output_shape) { - if (output_shape.dimensions_size() != input_shape.dimensions_size()) { + if (output_shape.dimensions().size() != input_shape.dimensions().size()) { return std::nullopt; } @@ -1649,7 +1649,7 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, // transpose1_dims * R = input_layout | * R, knowing R * R = I // transpose1_dims = input_layout * R decomposition.transpose1_dims = ComposePermutations( - LayoutPerm(input_shape), ReverseIota(input_shape.dimensions_size())); + LayoutPerm(input_shape), ReverseIota(input_shape.dimensions().size())); CHECK(TransposeIsBitcast(input_shape, decomposition.transpose1_shape, decomposition.transpose1_dims, /*ignore_element_type=*/false)); @@ -1658,7 +1658,7 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, // transpose2_dims * output_layout = R | * inv(output_layout) // transpose2_dims = R * inv(output_layout) decomposition.transpose2_dims = - ComposePermutations(ReverseIota(output_shape.dimensions_size()), + 
ComposePermutations(ReverseIota(output_shape.dimensions().size()), InversePermutation(LayoutPerm(output_shape))); CHECK(TransposeIsBitcast(decomposition.reshape_shape, output_shape, decomposition.transpose2_dims, @@ -1705,8 +1705,8 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, // For each one sized dimension in the output, increment the dimension // numbers in layout that are more minor than the one. absl::InlinedVector dim_map; - dim_map.reserve(simple_output_shape->dimensions_size()); - for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) { + dim_map.reserve(simple_output_shape->dimensions().size()); + for (int64_t i = 0; i < output_shape.dimensions().size(); ++i) { if (output_shape.dimensions(i) != 1) { dim_map.push_back(i); } @@ -1717,7 +1717,7 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, // Add the ones in descending order to the layout. Descending layouts tend // to reduce the number of copies inserted in layout assignment. - for (int64_t i = output_shape.dimensions_size() - 1; i >= 0; --i) { + for (int64_t i = output_shape.dimensions().size() - 1; i >= 0; --i) { if (output_shape.dimensions(i) == 1) { layout.push_back(i); } @@ -1729,7 +1729,7 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, auto common_factors = CommonFactors(input_shape.dimensions(), output_shape.dimensions()); - const int64_t input_rank = input_shape.dimensions_size(); + const int64_t input_rank = input_shape.dimensions().size(); DimensionVector input_to_factor(input_rank); for (int64_t pos = 0; pos < common_factors.size() - 1; ++pos) { const int64_t input_start = common_factors[pos].first; @@ -1747,7 +1747,7 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, } } - int64_t output_rank = output_shape.dimensions_size(); + int64_t output_rank = output_shape.dimensions().size(); DimensionVector output_layout; output_layout.reserve(output_rank); int64_t input_minor = 0; @@ -1801,10 +1801,10 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, /* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible( const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { - if (dynamic_shape.dimensions_size() != bounded_shape.dimensions_size()) { + if (dynamic_shape.dimensions().size() != bounded_shape.dimensions().size()) { return false; } - for (int64_t i = 0; i < dynamic_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < dynamic_shape.dimensions().size(); ++i) { if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) { return false; } @@ -1891,8 +1891,8 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, /* static */ absl::Status ShapeUtil::ForEachIndexParallelWithStatus( const Shape& shape, const ForEachParallelVisitorFunction& visitor_function) { - std::vector base(shape.dimensions_size()); - std::vector incr(shape.dimensions_size(), 1); + std::vector base(shape.dimensions().size()); + std::vector incr(shape.dimensions().size(), 1); return ForEachIndexParallelWithStatus(shape, base, /*count=*/shape.dimensions(), incr, visitor_function); @@ -1991,8 +1991,8 @@ struct ParallelState { // Compute the dimensions of the "work" which are defined by the count of // elements in each dimension and the increment. 
- std::vector work_dims(shape.dimensions_size()); - for (size_t d = 0; d < shape.dimensions_size(); ++d) { + std::vector work_dims(shape.dimensions().size()); + for (size_t d = 0; d < shape.dimensions().size(); ++d) { work_dims[d] = tsl::MathUtil::CeilOfRatio(count[d], incr[d]); } @@ -2082,7 +2082,7 @@ struct ParallelState { absl::FunctionRef p, Shape shape) { CHECK(shape.IsArray()); std::vector dims_to_delete; - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { if (!p(i)) { dims_to_delete.push_back(i); } @@ -2115,7 +2115,7 @@ absl::Status ShapeUtil::ByteStrides(const Shape& shape, absl::Span strides) { TF_RET_CHECK(shape.IsArray()); TF_RET_CHECK(shape.has_layout()); - TF_RET_CHECK(shape.dimensions_size() == strides.size()); + TF_RET_CHECK(shape.dimensions().size() == strides.size()); int64_t stride = ByteSizeOfPrimitiveType(shape.element_type()); for (int i : shape.layout().minor_to_major()) { @@ -2128,7 +2128,7 @@ absl::Status ShapeUtil::ByteStrides(const Shape& shape, /*static*/ std::optional> ShapeUtil::ByteStrides( const Shape& shape) { - absl::InlinedVector strides(shape.dimensions_size()); + absl::InlinedVector strides(shape.dimensions().size()); if (!ByteStrides(shape, absl::MakeSpan(strides)).ok()) { return std::nullopt; } diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc index 4f844100d66500..ae6542489eeca8 100644 --- a/third_party/xla/xla/shape_util_test.cc +++ b/third_party/xla/xla/shape_util_test.cc @@ -1058,7 +1058,7 @@ TEST(ShapeUtilTest, PermuteDynamicDimensions) { SCOPED_TRACE(absl::StrCat("permutation=", absl::StrJoin(permutation, ","))); auto permuted = ShapeUtil::PermuteDimensions(permutation, shape); - for (int i = 0; i < shape.dimensions_size(); i++) { + for (int i = 0; i < shape.dimensions().size(); i++) { EXPECT_EQ(permuted.dimensions(i), shape.dimensions(permutation[i])); EXPECT_EQ(permuted.is_dynamic_dimension(i), shape.is_dynamic_dimension(permutation[i])); From fba6fa0a0b2cfb8328047045f9a30974245e3bff Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Apr 2025 00:04:36 -0700 Subject: [PATCH 0241/1324] Automated Code Change PiperOrigin-RevId: 743835044 --- tensorflow/lite/toco/tflite/export.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index 18f4227259e35c..e33f7ac8997e47 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -14,7 +14,12 @@ limitations under the License. 
==============================================================================*/
 #include "tensorflow/lite/toco/tflite/export.h"

+#include
+#include
+#include
+#include
 #include
+#include

 #include "flatbuffers/flexbuffers.h"
 #include "absl/log/log.h"

From 36f4d571a983b4ab7dd85f5aa8630eb08cea33fa Mon Sep 17 00:00:00 2001
From: Shraiysh
Date: Fri, 4 Apr 2025 00:04:39 -0700
Subject: [PATCH 0242/1324] PR #24519: [nfc] Remove VLOG(s) added by mistake

Imported from GitHub PR https://github.com/openxla/xla/pull/24519

These were added by mistake in
https://github.com/openxla/xla/commit/cc047dd9d9207c217c6eecaa55700e82c2066995

Copybara import of the project:

--
7a2d859e700af5530b3b012e0c358fb87c97950c by Shraiysh Vaishay:

[nfc] Remove VLOG(s) added by mistake

Merging this change closes #24519

PiperOrigin-RevId: 743835059
---
 third_party/xla/xla/backends/gpu/codegen/custom.cc            | 3 ---
 .../xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc | 1 -
 2 files changed, 4 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/codegen/custom.cc b/third_party/xla/xla/backends/gpu/codegen/custom.cc
index afbc30798b4ee0..3ad6c564a0bea5 100644
--- a/third_party/xla/xla/backends/gpu/codegen/custom.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/custom.cc
@@ -1053,14 +1053,11 @@ CollectSliceArgumentMetadataForCollectives(
   SliceDataForCollectives slice_data(num_args);
   std::optional while_op =
       GetParentWhileOp(fusion_instr, call_graph);
-  VLOG(0) << "Collecting while op data";
   if (while_op != std::nullopt) {
     CHECK(while_op.value() != nullptr)
         << "GetParentWhileOp is not expected to return nullptr.";
     slice_data.init_module = ExtractWhileInitModule(*while_op);
-    VLOG(0) << "Extracted init module";
     slice_data.update_module = ExtractWhileUpdateModule(*while_op);
-    VLOG(0) << "Extracted update module";
   }
   slice_data.can_compute_indvar_on_host =
       (slice_data.init_module != nullptr &&
        slice_data.update_module != nullptr);
diff --git a/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc b/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
index eb7ed0c1f28d0b..d2ce9fd26fcbdd 100644
--- a/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
@@ -3757,7 +3757,6 @@ TEST_F(DynamicSliceFusionTest, WhileLoopSliceWithNoInductionVariable) {
       .mutable_debug_options()
       .set_xla_gpu_enable_dynamic_slice_fusion(true);
   TF_ASSERT_OK_AND_ASSIGN(m, GetOptimizedModule(std::move(m)));
-  // VLOG(0) << "Fused module: " << m->ToString();
   ErrorSpec error_spec(1e-5, 1e-5);
   EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(m), std::move(m_ref),
                                                 /*run_hlo_passes=*/false,

From b496bec157a2bc0be33f04b2362e498a4836882b Mon Sep 17 00:00:00 2001
From: Dragan Mladjenovic
Date: Fri, 4 Apr 2025 00:07:49 -0700
Subject: [PATCH 0243/1324] PR #24513: [ROCm] Use code object version 5

Imported from GitHub PR https://github.com/openxla/xla/pull/24513

Allows generating hsaco files on pre-6.3 rocm

Copybara import of the project:

--
084c021adf22a6d531feeb11f8c0daa7fd444f22 by Dragan Mladjenovic:

[ROCm] Use code object version 5

Allows generating hsaco files on pre-6.3 rocm

Merging this change closes #24513

PiperOrigin-RevId: 743835869
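
An editor's illustration of the mechanism (a standalone sketch against
the LLVM C++ API, not XLA code; the printed metadata numbering is
illustrative):

    #include <cstdint>
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      llvm::LLVMContext context;
      llvm::Module module("amdgpu_example", context);
      // Pin the HSA code object ABI to v5; the flag value is version * 100,
      // and llvm::Module::Error makes conflicting flags a hard error.
      const int32_t kAbiVersion = 500;
      module.addModuleFlag(llvm::Module::Error, "amdhsa_code_object_version",
                           kAbiVersion);
      // The printed module should now carry a flag along the lines of
      //   !{i32 1, !"amdhsa_code_object_version", i32 500}
      module.print(llvm::outs(), nullptr);
      return 0;
    }

---
 .../xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc
index ff15924754dc99..429393b123603f 100644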
--- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc @@ -320,6 +320,9 @@ absl::Status AMDGPUTargetModuleLinker( fn.addFnAttr("denormal-fp-math-f32", "preserve-sign"); } } + const int32_t kAbiVersion = 500; + module->addModuleFlag(llvm::Module::Error, "amdhsa_code_object_version", + kAbiVersion); return absl::OkStatus(); } From 52fe79a1ca63090223606d358004fc6fa94621e2 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 4 Apr 2025 00:11:30 -0700 Subject: [PATCH 0244/1324] [XLA:GPU] Remove gpu/codegen/transforms/flatten_tensors.cc. PiperOrigin-RevId: 743836704 --- .../emitters/transforms/expand_float_ops.cc | 714 ----------------- .../emitters/transforms/flatten_tensors.cc | 728 ------------------ 2 files changed, 1442 deletions(-) delete mode 100644 third_party/xla/xla/backends/gpu/codegen/emitters/transforms/expand_float_ops.cc delete mode 100644 third_party/xla/xla/backends/gpu/codegen/emitters/transforms/flatten_tensors.cc diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/expand_float_ops.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/expand_float_ops.cc deleted file mode 100644 index 00556c88cd3b82..00000000000000 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/expand_float_ops.cc +++ /dev/null @@ -1,714 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "llvm/ADT/APFloat.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -namespace ma = ::mlir::arith; - -using ma::SelectOp; -using mlir::Value; - -#define GEN_PASS_DEF_EXPANDFLOATOPSPASS -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc" - -namespace { - -// Wraps a Value to provide operator overloading for more readable expressions. 
-struct Val { - Value value; - mlir::ImplicitLocOpBuilder* b; - - operator Value() const { return value; } // NOLINT - - Val operator+(int64_t rhs) const { return Binop(rhs); } - Val operator+(Value rhs) const { return Binop(rhs); } - Val operator-(int64_t rhs) const { return Binop(rhs); } - Val operator-(Value rhs) const { return Binop(rhs); } - Val operator*(int64_t rhs) const { return Binop(rhs); } - Val operator*(Value rhs) const { return Binop(rhs); } - Val operator&(Value rhs) const { return Binop(rhs); } - Val operator&(int64_t rhs) const { return Binop(rhs); } - Val operator|(Value rhs) const { return Binop(rhs); } - Val operator|(int64_t rhs) const { return Binop(rhs); } - Val operator^(Value rhs) const { return Binop(rhs); } - Val shl(Value rhs) const { return Binop(rhs); } - Val shl(int64_t rhs) const { return Binop(rhs); } - Val shrui(Value rhs) const { return Binop(rhs); } - Val shrui(int64_t rhs) const { return Binop(rhs); } - - Val cmp(ma::CmpIPredicate pred, Value rhs) const { - return {b->create(pred, value, rhs), b}; - } - Val cmp(ma::CmpIPredicate pred, int64_t rhs) const { - return cmp(pred, MakeConstant(rhs)); - } - Val operator==(Value rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); } - Val operator==(int64_t rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); } - Val operator!=(int64_t rhs) const { return cmp(ma::CmpIPredicate::ne, rhs); } - - Val MakeConstant(int64_t c) const { - return {b->create(c, value.getType()), b}; - } - - private: - template - Val Binop(Value rhs) const { - return {b->create(value, rhs), b}; - } - - template - Val Binop(int64_t rhs) const { - return Binop(MakeConstant(rhs)); - } -}; - -struct RewriteErf32Pattern : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::math::ErfOp op, mlir::PatternRewriter& rewriter) const override { - if (!op.getType().isF32()) { - return rewriter.notifyMatchFailure(op, "not an f32 erf"); - } - - static const std::array kAlpha{ - 0.00022905065861350646f, 0.0034082910107109506f, 0.050955695062380861f, - 0.18520832239976145f, 1.128379143519084f}; - - static const std::array kBeta{-1.1791602954361697e-7, - 0.000023547966471313185f, - 0.0010179625278914885f, - 0.014070470171167667f, - 0.11098505178285362f, - 0.49746925110067538f, - 1.0f}; - - // We clamp x to be within [-c;c] where c = erfinv(1-2^-23), outside of - // which x should be +/-1. - constexpr float kErfInvOneMinusHalfULP = 3.7439211627767994f; - - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto c = [&](float v) -> Value { - return b.create(llvm::APFloat(v), - rewriter.getF32Type()); - }; - - auto poly = [&](auto x, auto coefficients) -> Value { - auto r = c(coefficients[0]); - for (int i = 1; i < coefficients.size(); ++i) { - r = b.create(r, x, c(coefficients[i])); - } - return r; - }; - - Value x = op.getOperand(); - x = b.create(x, c(-kErfInvOneMinusHalfULP)); - x = b.create(x, c(kErfInvOneMinusHalfULP)); - Value x2 = b.create(x, x); - - rewriter.replaceOpWithNewOp( - op, b.create(x, poly(x2, kAlpha)), poly(x2, kBeta)); - - return mlir::success(); - } -}; - -int GetSignificandBits(mlir::FloatType ty) { - return llvm::APFloat::semanticsPrecision(ty.getFloatSemantics()) - 1; -} - -int GetExponentBias(mlir::FloatType ty) { - return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()) - - llvm::isa(ty); // No zero exponent for E8M0. 
-} - -bool IsFNUZ(mlir::FloatType ty) { - return llvm::isa(ty); -} - -Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { - auto ty = mlir::cast(value.getType()); - if (mlir::LLVM::isCompatibleOuterType(ty)) { - value = b.create(value); - Value inf = b.create( - llvm::APFloat::getInf(ty.getFloatSemantics()), ty); - return b.create(ma::CmpFPredicate::OEQ, value, inf); - } - - assert(ty.getIntOrFloatBitWidth() <= 8); - // F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities. - if (llvm::isa(ty)) { - Val bits{b.create(b.getI8Type(), value), &b}; - return (bits & 0x7F) == 0x7C; - } else if (llvm::isa(ty)) { - Val bits{b.create(b.getI8Type(), value), &b}; - return (bits & 0x7F) == 0x78; - } else if (llvm::isa(ty)) { - Val bits{b.create(b.getI8Type(), value), &b}; - return (bits & 0x7F) == 0x70; - } else { - return b.create(false, b.getI1Type()); - } -} - -Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { - auto ty = value.getType(); - if (mlir::LLVM::isCompatibleOuterType(ty)) { - return b.create(ma::CmpFPredicate::UNO, value, value); - } - if (llvm::isa(ty)) { - return b.create(false, b.getI1Type()); - } - - assert(ty.getIntOrFloatBitWidth() == 8); - Val bits{b.create(b.getI8Type(), value), &b}; - if (llvm::isa(ty)) { - return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'1100); - } else if (llvm::isa(ty)) { - return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'1000); - } else if (llvm::isa(ty)) { - return (bits & 0b0111'1111) == 0b0111'1111; - } else if (llvm::isa(ty)) { - return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000); - } else if (llvm::isa(ty)) { - return bits == 0xFF; - } - return bits == 0x80; -} - -Value EmitReducePrecision(Value value, int exponent_bits, int mantissa_bits, - mlir::ImplicitLocOpBuilder& b) { - mlir::mhlo::ReducePrecisionOp::Properties properties; - properties.exponent_bits = b.getI32IntegerAttr(exponent_bits); - properties.mantissa_bits = b.getI32IntegerAttr(mantissa_bits); - return mlir::mhlo::MhloOpToStdScalarOp::mapOpOfType< - mlir::mhlo::ReducePrecisionOp>( - b.getLoc(), value.getType(), {value.getType()}, - mlir::mhlo::ReducePrecisionOp::Adaptor(value, nullptr, properties), - /*attributes=*/std::nullopt, &b); -} - -Value EmitF16ToF8e5m2(Value in, mlir::ImplicitLocOpBuilder& b) { - Val in_bits{b.create(b.getI16Type(), in), &b}; - // Use this method of checking for NaN because it's the same as what's used - // in the reduce precision lowering. - Value is_nan = (in_bits & 32767).cmp(ma::CmpIPredicate::ugt, 31744); - - Value value = EmitReducePrecision(in, 5, 2, b); - value = b.create(b.getI16Type(), value); - value = b.create(value, - b.create(8, b.getI16Type())); - value = b.create(b.getI8Type(), value); - // When the input is NaN, just truncating can turn a NaN into an inf if the - // mantissa becomes 0. - value = b.create( - is_nan, b.create(0x7F, value.getType()), value); - return b.create(b.getType(), value); -} - -Value EmitFloatConversion(Value value, mlir::FloatType to_ty, - mlir::ImplicitLocOpBuilder& b) { - using ma::CmpIPredicate; - - auto from_ty = mlir::cast(value.getType()); - if (to_ty == b.getType() && from_ty == b.getF16Type()) { - return EmitF16ToF8e5m2(value, b); - } - - if (to_ty == b.getType() && - from_ty == b.getBF16Type()) { - // Going through f32 and f16 is significantly faster than the fallback code - // below. - return EmitF16ToF8e5m2( - b.create(b.getF16Type(), - b.create(b.getF32Type(), value)), - b); - } - - // Fallback code. The generated code here is not good. 
If you end up here, - // you might want to add a more specific conversion. - // This is a port of ConvertImpl in - // https://github.com/jax-ml/ml_dtypes/blob/main/ml_dtypes/include/float8.h - - int from_mantissa = GetSignificandBits(from_ty); - int from_bias = GetExponentBias(from_ty); - int from_min_exp = - llvm::APFloat::semanticsMinExponent(from_ty.getFloatSemantics()); - int from_max_exp = - llvm::APFloat::semanticsMaxExponent(from_ty.getFloatSemantics()); - auto from_int_ty = b.getIntegerType(from_ty.getIntOrFloatBitWidth()); - - int to_mantissa = GetSignificandBits(to_ty); - int to_bias = GetExponentBias(to_ty); - int to_min_exp = - llvm::APFloat::semanticsMinExponent(to_ty.getFloatSemantics()); - int to_max_exp = - llvm::APFloat::semanticsMaxExponent(to_ty.getFloatSemantics()); - auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth()); - - mlir::IntegerType wide_int_ty; - if (from_ty.getWidth() <= 8 && to_ty.getWidth() <= 8) { - wide_int_ty = b.getI16Type(); - } else { - wide_int_ty = b.getIntegerType( - std::max(from_int_ty.getWidth(), to_int_ty.getWidth())); - // Avoid overflow for bit shifts. - auto may_overflow = [&](mlir::Type a, mlir::Type b) { - return llvm::isa(a) && b.isF16(); - }; - if (may_overflow(from_ty, to_ty) || may_overflow(to_ty, from_ty)) { - wide_int_ty = b.getI32Type(); - } - } - auto convert_int = [&](mlir::Type ty, Value v) -> Val { - if (v.getType() == ty) { - return {v, &b}; - } - if (ty.getIntOrFloatBitWidth() < v.getType().getIntOrFloatBitWidth()) { - return {b.create(ty, v), &b}; - } - return {b.create(ty, v), &b}; - }; - - int64_t exp_offset = to_bias - from_bias; - int digit_shift = to_mantissa - from_mantissa; - - int from_width = value.getType().getIntOrFloatBitWidth(); - Val from_bits{b.create(b.getIntegerType(from_width), value), - &b}; - if (from_width < 8) { - from_bits = convert_int(b.getIntegerType(8), from_bits); - } - - auto cst = [&](mlir::Type ty, int64_t n) -> Val { - return {b.create(n, ty), &b}; - }; - - // Shift bits to destination type, without sign bit. - Val from_sign_bit; - if (!llvm::isa(from_ty)) { - from_sign_bit = from_bits.shrui(from_width - 1) != 0; - from_bits = from_bits & ((1ULL << (from_width - 1)) - 1); - } - - auto cst_bits = [&](llvm::APFloat f) { - return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())), - f.bitcastToAPInt().getZExtValue()); - }; - Value to_nan; - Value to_inf; - Val to_zero; - - // MX float types have neither infinities nor NaNs. - if (llvm::isa(to_ty)) { - to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); - to_nan = to_zero | 0x8; - to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics())); - } else if (llvm::isa(to_ty)) { - to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); - to_inf = to_nan; - to_zero = Val{to_nan, &b}; - } else { - to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics())); - to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); - to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); - } - - auto round_bits_to_nearest_even = [&](Val bits, Val roundoff, - bool use_implicit_bit = false) { - assert(bits.value.getType() == roundoff.value.getType()); - // Round to nearest even by adding a bias term. - // Consider a bit pattern - // FFF...FLRTT...T, - // where bits RTT...T need to be rounded-off. We add a bias term to the - // bit pattern s.t. a carry is introduced to round up only if - // - L is 1, R is 1, OR - // - L is 0, R is 1, any T is one. 
- // We do this by adding L to a bit pattern consisting of all T = 1. - Val bias = !use_implicit_bit - ? (bits.shrui(roundoff) & 1) + - (bits.MakeConstant(1).shl(roundoff - 1) - 1) - : bits.MakeConstant(1).shl(roundoff - 1); - return bits + bias; - }; - - // Happy path: no subnormals, infinities or NaNs. - Value result; - { - // Round the mantissa if it is shrinking. - Val rounded_from_bits = convert_int(wide_int_ty, from_bits); - if (digit_shift < 0) { - rounded_from_bits = - round_bits_to_nearest_even( - rounded_from_bits, rounded_from_bits.MakeConstant(-digit_shift), - /*use_implicit_bit=*/to_mantissa == 0) & - ~((1ll << (-digit_shift)) - 1); - } - - // Re-bias the exponent. - rounded_from_bits = rounded_from_bits + (exp_offset << from_mantissa); - - // Check for overflows by aligning the significands. We always align the - // narrower significand to the wider significand. - int64_t to_highest = llvm::APFloat::getLargest(to_ty.getFloatSemantics()) - .bitcastToAPInt() - .getZExtValue(); - int64_t aligned_highest = to_highest; - if (digit_shift < 0) { - aligned_highest <<= -digit_shift; - // Shift down, all dropped bits should already be zero. - result = rounded_from_bits.shrui(-digit_shift); - } else { - // Shift up, inserting zeros in the newly created digits. - rounded_from_bits = rounded_from_bits.shl(digit_shift); - result = rounded_from_bits; - } - result = convert_int(to_int_ty, result); - - // `From` supports larger values than `To`, we may overflow. - if (std::make_pair(to_max_exp, to_mantissa) < - std::make_pair(from_max_exp, from_mantissa)) { - result = b.create( - rounded_from_bits.cmp(CmpIPredicate::ugt, aligned_highest), to_inf, - result); - } - } - - auto i32_ty = b.getI32Type(); - Val biased_from_exp = convert_int(i32_ty, from_bits.shrui(from_mantissa)); - - if (to_min_exp < from_min_exp) { - // `To` supports more exponents near zero which means that some subnormal - // values in `From` may become normal. - - // Subnormals. - Val bits = convert_int(wide_int_ty, from_bits); - - // Determine exponent in target type. - Value clz = convert_int( - i32_ty, b.create(from_bits)); - Value msb = cst(i32_ty, std::max(from_width, 8) - 1) - clz; - Value normalization_factor = cst(i32_ty, from_mantissa) - msb; - - Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor; - // If the result is subnormal, adjust the subnormal bits to account for - // the difference in exponent bias. - Value subnormal_bits = bits; - if (exp_offset < wide_int_ty.getWidth()) { - subnormal_bits = bits.shl(exp_offset); - } - - // Result is normal. Shift the mantissa to account for the number of - // leading zero digits, and clear the hidden bit. - // Insert the exponent bits. - Value normal_bits = - (bits.shl(convert_int(wide_int_ty, normalization_factor)) & - ~(1 << from_mantissa)) | - convert_int(wide_int_ty, biased_exponent).shl(from_mantissa); - - Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0); - bits.value = - b.create(biased_exp_sle_zero, subnormal_bits, normal_bits); - if (digit_shift >= 0) { - bits = bits.shl(digit_shift); - } else { - bits = round_bits_to_nearest_even( - bits, bits.MakeConstant(-digit_shift), - /*use_implicit_bit=*/to_mantissa == 0 && exp_offset != 0); - bits = bits.shrui(-digit_shift); - } - bits = convert_int(to_int_ty, bits); - - result = b.create(biased_from_exp == 0, bits, result); - } else if (to_min_exp > from_min_exp) { - // `To` supports fewer exponents near zero which means that some values in - // `From` may become subnormal. 
- Val biased_to_exp = biased_from_exp + (to_bias - from_bias); - // Subnormals and zero. - // Round and shift mantissa down. - Val from_has_leading_one = !llvm::isa(from_ty) - ? biased_from_exp != 0 - : cst(i32_ty, 1); - Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one); - from_has_leading_one = convert_int(from_int_ty, from_has_leading_one); - Val exponent_shift_i32 = - (from_has_leading_one_i32 - biased_to_exp) - digit_shift; - // Insert the implicit leading 1 bit on the mantissa for normalized - // inputs. - Val rounded_from_bits = (from_bits & ((1ll << from_mantissa) - 1)) | - from_has_leading_one.shl(from_mantissa); - - // NOTE: we need to round again from the original from_bits, - // otherwise the lower precision bits may already be lost. There is - // an edge-case where rounding to a normalized value would normally - // round down, but for a subnormal, we need to round up. - Val exponent_shift_from_ty = convert_int(from_int_ty, exponent_shift_i32); - Val exponent_shift_to_ty = convert_int(to_int_ty, exponent_shift_i32); - Val positive_bits = convert_int( - to_int_ty, - round_bits_to_nearest_even(rounded_from_bits, exponent_shift_from_ty) - .shrui(exponent_shift_from_ty)); - // To avoid UB, limit rounding and shifting to the full mantissa plus - // leading 1. - positive_bits.value = b.create( - exponent_shift_i32.cmp(CmpIPredicate::sle, from_mantissa + 1), - positive_bits, to_zero); - - Val negative_bits = convert_int(to_int_ty, rounded_from_bits) - .shl(to_zero - exponent_shift_to_ty); - Value bits = - b.create(exponent_shift_i32.cmp(CmpIPredicate::sgt, 0), - positive_bits, negative_bits); - result = b.create(biased_to_exp.cmp(CmpIPredicate::sle, 0), bits, - result); - } - - Value result_is_inf = IsInf(value, b); - Value input_is_nan = IsNaN(value, b); - - if (llvm::isa(to_ty)) { - // Converting a negative number to E8M0 results in NaN. - input_is_nan = from_sign_bit | input_is_nan; - } else if (IsFNUZ(to_ty)) { - // Clear the sign bit if the result is zero (the output has no negative - // zero). Handle the edge case when the input is zero and the result is not. - Val result_is_non_zero = - (digit_shift > 0 ? from_bits : Val{result, &b}) != 0; - from_sign_bit = from_sign_bit & result_is_non_zero; - } else if (IsFNUZ(from_ty)) { - // Clear the sign bit if the input is NaN (it's positive but encoded as - // negative 0). - from_sign_bit = from_sign_bit ^ input_is_nan; - } - - if (!llvm::isa(from_ty)) { - result = b.create(from_bits == 0, to_zero, result); - } - result = b.create(result_is_inf, to_inf, result); - result = b.create(input_is_nan, to_nan, result); - - // Insert sign bit. 
- if (!llvm::isa(from_ty)) { - Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); - result = b.create(from_sign_bit, neg_result, result); - } - result = b.create(to_ty, result); - return result; -} - -struct RewriteTruncFPattern : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - ma::TruncFOp op, mlir::PatternRewriter& rewriter) const override { - using FloatValue = mlir::TypedValue; - auto src = mlir::cast(op.getOperand()); - auto dst_ty = mlir::cast(op.getType()); - if (dst_ty.getWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) truncf"); - } - - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b)); - return mlir::success(); - } -}; - -struct RewriteExtFPattern : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - ma::ExtFOp op, mlir::PatternRewriter& rewriter) const override { - using FloatValue = mlir::TypedValue; - auto src = mlir::cast(op.getOperand()); - auto dst_ty = mlir::cast(op.getType()); - if (src.getType().getWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) extf"); - } - - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b)); - return mlir::success(); - } -}; - -// Lowering for cmpf : f8 for float to pred conversions. -struct RewriteF8Cst : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - ma::CmpFOp op, mlir::PatternRewriter& rewriter) const override { - using FloatValue = mlir::TypedValue; - auto lhs = mlir::cast(op.getLhs()); - auto rhs = mlir::cast(op.getRhs()); - - if (lhs.getType().getWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) cmpf"); - } - - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - // Skip the f32 conversion if we're comparing UNE.cst. - llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics()); - if (op.getPredicate() == ma::CmpFPredicate::UNE && - mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) { - mlir::Type int_ty = rewriter.getIntegerType(lhs.getType().getWidth()); - Val int_value{b.create(int_ty, lhs), &b}; - int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue(); - // If we're comparing to +-0, compare the absolute values. - if (rhs_cst.isZero() && !IsFNUZ(lhs.getType())) { - int64_t mask = (1 << (lhs.getType().getWidth() - 1)) - 1; - int_value = int_value & mask; - constant &= mask; - } - auto cst = b.create(constant, int_ty); - rewriter.replaceOpWithNewOp(op, ma::CmpIPredicate::ne, - int_value, cst); - return mlir::success(); - } - - auto lhs_ext = b.create(b.getF32Type(), lhs); - auto rhs_ext = b.create(b.getF32Type(), rhs); - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), - mlir::ValueRange{lhs_ext, rhs_ext}, - op->getAttrs()); - return mlir::success(); - } -}; - -struct RewriteAbsFPattern : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::math::AbsFOp op, mlir::PatternRewriter& rewriter) const override { - using FloatValue = mlir::TypedValue; - auto src = mlir::cast(op.getOperand()); - // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16. - // Once that's removed, remove the code for BF16 here. 
- if (src.getType().getWidth() > 8 && !src.getType().isBF16()) { - return rewriter.notifyMatchFailure(op, - "not an f8 (or less) or bf16 absf"); - } - - // If type is unsigned (E8M0), the operation is no-op. - if (!llvm::APFloat::semanticsHasSignedRepr( - src.getType().getFloatSemantics())) { - rewriter.replaceAllOpUsesWith(op, op.getOperand()); - return mlir::success(); - } - - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth()); - Val value{b.create(i_ty, src), &b}; - int64_t mask = (1ull << (src.getType().getWidth() - 1)) - 1; - value = value & mask; - rewriter.replaceOpWithNewOp(op, src.getType(), value); - return mlir::success(); - } -}; - -template -struct RewriteIToFpPattern : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getType().getIntOrFloatBitWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an f8 (or less) itofp"); - } - Value to_float = - rewriter.create(op.getLoc(), rewriter.getF32Type(), op.getIn()); - rewriter.replaceOpWithNewOp(op, op.getType(), to_float); - return mlir::success(); - } -}; - -template -struct RewriteFpToIPattern : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getIn().getType().getIntOrFloatBitWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an f8 (or less) fptoi"); - } - Value to_f32 = rewriter.create( - op.getLoc(), rewriter.getF32Type(), op.getIn()); - rewriter.replaceOpWithNewOp(op, op.getType(), to_f32); - return mlir::success(); - } -}; - -class ExpandFloatOpsPass - : public impl::ExpandFloatOpsPassBase { - public: - using ExpandFloatOpsPassBase::ExpandFloatOpsPassBase; - void runOnOperation() override { - mlir::RewritePatternSet patterns(&getContext()); - patterns.add, - RewriteIToFpPattern, - RewriteFpToIPattern, - RewriteFpToIPattern>(&getContext()); - mlir::populatePolynomialApproximateTanhPattern(patterns); - patterns.add(&getContext()); - if (mlir::failed( - mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr CreateExpandFloatOpsPass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/flatten_tensors.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/flatten_tensors.cc deleted file mode 100644 index d5b730c45406f1..00000000000000 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/flatten_tensors.cc +++ /dev/null @@ -1,728 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include -#include -#include -#include - -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/LogicalResult.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeRange.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/IR/Visitors.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.h" -#include "xla/hlo/analysis/indexing_analysis.h" -#include "xla/layout_util.h" -#include "xla/shape_util.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { -namespace { - -#define GEN_PASS_DEF_FLATTENTENSORSPASS -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc" - -using mlir::Attribute; -using mlir::Location; -using mlir::LogicalResult; -using mlir::MLIRContext; -using mlir::OpRewritePattern; -using mlir::PatternRewriter; -using mlir::RankedTensorType; -using mlir::ShapedType; -using mlir::SmallVector; -using mlir::Type; -using mlir::TypedValue; -using mlir::TypeRange; -using mlir::UnrealizedConversionCastOp; -using mlir::Value; -using mlir::ValueRange; -using mlir::VectorType; -using mlir::func::FuncOp; -using mlir::func::ReturnOp; -using mlir::scf::ForOp; -using mlir::scf::IfOp; -using mlir::scf::IndexSwitchOp; -using mlir::tensor::ExtractOp; -using mlir::tensor::InsertOp; -namespace mv = mlir::vector; - -RankedTensorType GetFlattenedType(RankedTensorType tensor_type) { - return RankedTensorType::get({tensor_type.getNumElements()}, - tensor_type.getElementType()); -} - -VectorType GetFlattenedType(VectorType vector_type) { - return VectorType::get({vector_type.getNumElements()}, - vector_type.getElementType()); -} - -ShapedType GetFlattenedType(Type type) { - if (auto vector_type = mlir::dyn_cast(type)) { - return GetFlattenedType(vector_type); - } - return GetFlattenedType(mlir::cast(type)); -} - -bool IsScalarOrFlat(Type type) { - if (auto shaped_type = mlir::dyn_cast(type)) { - return shaped_type.getRank() < 2; - } - return true; -} - -bool HasOnlyFlatTensorsFlatVectorsOrScalars(TypeRange types) { - return llvm::all_of(types, IsScalarOrFlat); -} - -Value Flatten(Value value, PatternRewriter& rewriter) { - if (IsScalarOrFlat(value.getType())) return value; - auto flat_type = GetFlattenedType(value.getType()); - return rewriter - .create(value.getLoc(), flat_type, value) - .getResult(0); -} - -struct RewriteFunctionSignatures : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(FuncOp op, - PatternRewriter& rewriter) const override { - auto input_types = op.getFunctionType().getInputs(); - auto result_types = 
op.getFunctionType().getResults(); - if (HasOnlyFlatTensorsFlatVectorsOrScalars(input_types) && - HasOnlyFlatTensorsFlatVectorsOrScalars(result_types)) { - return rewriter.notifyMatchFailure(op, "nothing to flatten"); - } - - auto loc = op.getLoc(); - mlir::Block* entry_block = &op.getBody().front(); - SmallVector new_result_types; - SmallVector new_results; - - // If some results are tensors or vectors, we need to flatten them. - auto terminator = entry_block->getTerminator(); - rewriter.setInsertionPoint(terminator); - - for (Value result : terminator->getOperands()) { - Value flattened = Flatten(result, rewriter); - new_results.push_back(flattened); - new_result_types.push_back(flattened.getType()); - } - rewriter.replaceOpWithNewOp(terminator, new_results); - - // Cast all function arguments to the original type. - SmallVector new_operand_types(input_types); - rewriter.setInsertionPointToStart(entry_block); - for (auto&& [index, operand_type] : llvm::enumerate(new_operand_types)) { - if (IsScalarOrFlat(operand_type)) continue; - mlir::BlockArgument func_argument = op.getArgument(index); - auto cast_to_orig_type = rewriter.create( - loc, operand_type, func_argument); - func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0), - cast_to_orig_type); - operand_type = GetFlattenedType(operand_type); - } - // Replace the function arguments with the new types. - for (auto [arg, arg_type] : - llvm::zip(entry_block->getArguments(), new_operand_types)) { - arg.setType(arg_type); - } - // Update function signature. - op.setType(rewriter.getFunctionType(new_operand_types, new_result_types)); - return mlir::success(); - } -}; - -struct RewritePureCall : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(PureCallOp op, - PatternRewriter& rewriter) const override { - if (HasOnlyFlatTensorsFlatVectorsOrScalars(op.getOperandTypes()) && - HasOnlyFlatTensorsFlatVectorsOrScalars(op.getResultTypes())) { - return rewriter.notifyMatchFailure(op, "nothing to flatten"); - } - SmallVector flat_operands; - flat_operands.reserve(op.getNumOperands()); - for (Value operand : op.getOperands()) { - flat_operands.push_back(Flatten(operand, rewriter)); - } - SmallVector flat_result_types; - flat_result_types.reserve(op.getNumResults()); - llvm::SmallBitVector results_to_update(op.getNumResults(), false); - for (auto [index, result_type] : llvm::enumerate(op.getResultTypes())) { - if (IsScalarOrFlat(result_type)) { - flat_result_types.push_back(result_type); - continue; - } - results_to_update.set(index); - flat_result_types.push_back(GetFlattenedType(result_type)); - } - Location loc = op.getLoc(); - auto new_call_op = rewriter.create( - loc, flat_result_types, op.getCalleeAttr(), flat_operands); - SmallVector new_results; - new_results.reserve(op.getNumResults()); - for (auto [index, new_result] : llvm::enumerate(new_call_op.getResults())) { - if (results_to_update.test(index)) { - new_results.push_back(new_result); - continue; - } - auto cast_to_orig_type = rewriter.create( - loc, op.getResult(index).getType(), new_result); - new_results.push_back(cast_to_orig_type.getResult(0)); - } - rewriter.replaceOp(op, new_results); - return mlir::success(); - } -}; - -// Returns the linearized index, if the rank is greater than 1. Otherwise, -// returns nullptr. 
-Value LinearizeIndex(Value value, ShapedType type, ValueRange indices, - PatternRewriter& rewriter, Attribute encoding = nullptr) { - if (type.getRank() < 2) { - return nullptr; - } - auto byte_shape = ShapeUtil::MakeShape(U8, type.getShape()); - if (encoding) { - *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector( - mlir::cast(encoding).getValues())); - } - auto linear_shape = - ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)}); - auto linearized_map = - GetBitcastMap(byte_shape, linear_shape, value.getContext()); - mlir::SmallVector result; - rewriter.createOrFold(result, value.getLoc(), indices, - ValueRange{}, linearized_map); - return result.front(); -} - -struct RewriteAllocateShared : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(AllocateSharedOp op, - PatternRewriter& rewriter) const override { - auto tensor_type = op.getResult().getType(); - if (IsScalarOrFlat(tensor_type)) { - return rewriter.notifyMatchFailure(op, "the tensor is already flat"); - } - auto flat_type = GetFlattenedType(tensor_type); - Location loc = op.getLoc(); - Value new_op = rewriter.create(op.getLoc(), flat_type); - auto cast_to_orig_type = - rewriter.create(loc, tensor_type, new_op); - rewriter.replaceOp(op, cast_to_orig_type.getResult(0)); - return mlir::success(); - } -}; - -struct RewriteConstant : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::arith::ConstantOp op, - PatternRewriter& rewriter) const override { - if (IsScalarOrFlat(op.getType())) { - return rewriter.notifyMatchFailure( - op, "the tensor or vector is already flat"); - } - auto dense_attr = mlir::dyn_cast(op.getValue()); - auto new_type = GetFlattenedType(op.getType()); - Value new_constant = rewriter.create( - op.getLoc(), dense_attr.reshape(new_type)); - rewriter.replaceOpWithNewOp(op, op.getType(), - new_constant); - return mlir::success(); - } -}; - -struct RewriteTensorExtract : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ExtractOp op, - PatternRewriter& rewriter) const override { - auto tensor = op.getTensor(); - auto tensor_type = tensor.getType(); - auto linear_index = LinearizeIndex(tensor, tensor_type, op.getIndices(), - rewriter, tensor_type.getEncoding()); - if (linear_index == nullptr) { - return rewriter.notifyMatchFailure(op, "the tensor is already flat"); - } - auto tensor_1D = rewriter - .create( - op.getLoc(), GetFlattenedType(tensor_type), tensor) - .getResult(0); - rewriter.replaceOpWithNewOp(op, tensor_1D, linear_index); - return mlir::success(); - } -}; - -struct RewriteVectorExtract : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mv::ExtractOp op, - PatternRewriter& rewriter) const override { - auto vector = op.getVector(); - auto vector_type = vector.getType(); - auto indices = - mv::getAsValues(rewriter, op.getLoc(), op.getMixedPosition()); - auto linear_index = LinearizeIndex(vector, vector_type, indices, rewriter); - if (linear_index == nullptr) { - return rewriter.notifyMatchFailure(op, "the vector is already flat"); - } - auto vector_1D = rewriter - .create( - op.getLoc(), GetFlattenedType(vector_type), vector) - .getResult(0); - rewriter.replaceOpWithNewOp(op, vector_1D, linear_index); - return mlir::success(); - } -}; - -struct RewriteTensorInsert : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(InsertOp op, - 
-struct RewriteTensorInsert : OpRewritePattern<InsertOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertOp op,
-                                PatternRewriter& rewriter) const override {
-    auto tensor = op.getDest();
-    auto tensor_type = tensor.getType();
-    auto linear_index = LinearizeIndex(tensor, tensor_type, op.getIndices(),
-                                       rewriter, tensor_type.getEncoding());
-    if (linear_index == nullptr) {
-      return rewriter.notifyMatchFailure(op, "the tensor is already flat");
-    }
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto tensor_1D = b.create<UnrealizedConversionCastOp>(
-                          GetFlattenedType(tensor_type), tensor)
-                         .getResult(0);
-    auto new_insert =
-        b.create<InsertOp>(op.getScalar(), tensor_1D, linear_index);
-    auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
-        tensor_type, new_insert.getResult());
-    rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
-    return mlir::success();
-  }
-};
-
-struct RewriteVectorInsert : OpRewritePattern<mv::InsertOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mv::InsertOp op,
-                                PatternRewriter& rewriter) const override {
-    auto vector = op.getDest();
-    auto vector_type = vector.getType();
-    auto indices =
-        mv::getAsValues(rewriter, op.getLoc(), op.getMixedPosition());
-    auto linear_index = LinearizeIndex(vector, vector_type, indices, rewriter);
-    if (linear_index == nullptr) {
-      return rewriter.notifyMatchFailure(op, "the vector is already flat");
-    }
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto vector_1D = b.create<UnrealizedConversionCastOp>(
-                          GetFlattenedType(vector_type), vector)
-                         .getResult(0);
-    auto new_insert =
-        b.create<mv::InsertOp>(op.getSource(), vector_1D, linear_index);
-    auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
-        vector_type, new_insert.getResult());
-    rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
-    return mlir::success();
-  }
-};
-
-struct RewriteAtomicRMW : OpRewritePattern<AtomicRMWOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(AtomicRMWOp op,
-                                PatternRewriter& rewriter) const override {
-    auto tensor = op.getInput();
-    auto tensor_type = tensor.getType();
-    auto linear_index = LinearizeIndex(tensor, tensor_type, op.getIndices(),
-                                       rewriter, tensor_type.getEncoding());
-    if (linear_index == nullptr) {
-      return rewriter.notifyMatchFailure(op, "the tensor is already flat");
-    }
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto tensor_1D = b.create<UnrealizedConversionCastOp>(
-                          GetFlattenedType(tensor_type), tensor)
-                         .getResult(0);
-    auto new_atomic_rmw = b.create<AtomicRMWOp>(tensor_1D, linear_index);
-    rewriter.inlineRegionBefore(op.getRegion(),
-                                &new_atomic_rmw.getRegion().front());
-    auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
-        tensor_type, new_atomic_rmw.getResult());
-    rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
-    return mlir::success();
-  }
-};
-
-// Returns the rank of the tensor or vector, or nullopt if it is of neither
-// type.
-std::optional<int64_t> GetRankOfTensorOrVector(Type type) {
-  if (auto shaped_type = mlir::dyn_cast<ShapedType>(type)) {
-    return shaped_type.getRank();
-  }
-  return std::nullopt;
-}
-
-// Checks that the value is produced by an unrealized conversion cast from 1D
-// tensor or vector to ND. Returns the 1D tensor or vector if so.
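
Concretely, the producer being matched is of the form

  %nd = builtin.unrealized_conversion_cast %flat
          : tensor<32xf32> to tensor<4x8xf32>

in which case the function returns %flat, letting the loop and branch
rewrites below thread the flat value through region boundaries directly.
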
-std::optional<Value> GetDelinearizedTensorOrVector(Value value) {
-  auto rank = GetRankOfTensorOrVector(value.getType());
-  if (!rank.has_value() || *rank < 2) {
-    return std::nullopt;
-  }
-  auto cast = value.getDefiningOp<UnrealizedConversionCastOp>();
-  if (!cast || cast->getNumResults() != 1 || cast->getNumOperands() != 1) {
-    return std::nullopt;
-  }
-  auto rank_before_linearization =
-      GetRankOfTensorOrVector(cast->getOperand(0).getType());
-  if (!rank_before_linearization.has_value() ||
-      *rank_before_linearization != 1) {
-    return std::nullopt;
-  }
-  return cast->getOperand(0);
-}
-
-struct RewriteFor : public OpRewritePattern<ForOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ForOp op,
-                                PatternRewriter& rewriter) const override {
-    llvm::SmallBitVector args_to_update(op.getNumResults(), false);
-    mlir::SmallVector<Value> new_init_args;
-    new_init_args.reserve(op.getNumResults());
-    for (auto [index, arg] : llvm::enumerate(op.getInitArgs())) {
-      auto type_before_linearization = GetDelinearizedTensorOrVector(arg);
-      if (!type_before_linearization.has_value()) {
-        new_init_args.push_back(arg);
-        continue;
-      }
-      new_init_args.push_back(*type_before_linearization);
-      args_to_update.set(index);
-    }
-    if (args_to_update.none()) {
-      return rewriter.notifyMatchFailure(op, "no args to update");
-    }
-    // Create new ForOp with updated init args.
-    Location loc = op.getLoc();
-    auto new_for_op =
-        rewriter.create<ForOp>(loc, op.getLowerBound(), op.getUpperBound(),
-                               op.getStep(), new_init_args);
-    new_for_op->setAttrs(op->getAttrs());
-
-    // Insert casts for the block arguments.
-    mlir::Block* new_body = new_for_op.getBody();
-    mlir::Block* old_body = op.getBody();
-    rewriter.setInsertionPoint(new_body, new_body->begin());
-    SmallVector<Value> updated_block_args{new_body->getArguments().begin(),
-                                          new_body->getArguments().end()};
-    for (auto [index, arg] :
-         llvm::enumerate(new_body->getArguments().drop_front())) {
-      if (!args_to_update.test(index)) continue;
-      updated_block_args[index + 1] =
-          rewriter
-              .create<UnrealizedConversionCastOp>(
-                  loc, old_body->getArgument(index + 1).getType(), arg)
-              .getResult(0);
-    }
-
-    // Move the body of the old ForOp to the new one.
-    rewriter.mergeBlocks(old_body, new_body, updated_block_args);
-
-    // Update the terminator.
-    auto new_terminator =
-        mlir::cast<mlir::scf::YieldOp>(new_body->getTerminator());
-    rewriter.setInsertionPoint(new_terminator);
-    for (auto&& [index, yielded_value] :
-         llvm::enumerate(new_terminator.getResultsMutable())) {
-      if (!args_to_update.test(index)) continue;
-      yielded_value.assign(
-          rewriter
-              .create<UnrealizedConversionCastOp>(
-                  loc, new_init_args[index].getType(), yielded_value.get())
-              .getResult(0));
-    }
-
-    // Cast back the results.
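
The step below gives each updated result a flat to N-D round-trip, along the
lines of (an illustrative sketch, with shapes assumed):

  %res  = scf.for ... iter_args(%arg = %flat) -> (tensor<32xf32>) { ... }
  %back = builtin.unrealized_conversion_cast %res
            : tensor<32xf32> to tensor<4x8xf32>

Existing users keep seeing the original N-D type; the casts are expected to
cancel against neighboring rewrites once the pass converges.
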
-    rewriter.setInsertionPointAfter(new_for_op);
-    SmallVector<Value> new_results(new_for_op.getResults());
-    for (auto&& [index, result] : llvm::enumerate(new_results)) {
-      if (!args_to_update.test(index)) continue;
-      result = rewriter
-                   .create<UnrealizedConversionCastOp>(
-                       loc, op->getResult(index).getType(), result)
-                   .getResult(0);
-    }
-    rewriter.replaceOp(op, new_results);
-    return mlir::success();
-  }
-};
-
-struct RewriteIf : public OpRewritePattern<IfOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter& rewriter) const override {
-    auto result_types = op.getResultTypes();
-    if (HasOnlyFlatTensorsFlatVectorsOrScalars(result_types)) {
-      return rewriter.notifyMatchFailure(op, "nothing to flatten");
-    }
-    mlir::scf::YieldOp then_yield = op.thenYield();
-    SmallVector<Type> new_result_types;
-    new_result_types.reserve(then_yield.getNumOperands());
-    bool found_cast = false;
-    for (auto& result : then_yield->getOpOperands()) {
-      auto delinearized_tensor = GetDelinearizedTensorOrVector(result.get());
-      if (!delinearized_tensor.has_value()) {
-        new_result_types.push_back(result.get().getType());
-        continue;
-      }
-      new_result_types.push_back(delinearized_tensor->getType());
-      result.set(*delinearized_tensor);
-      found_cast = true;
-    }
-    if (!found_cast) {
-      return rewriter.notifyMatchFailure(op, "no cast found");
-    }
-    Location loc = op.getLoc();
-    // Update the else branch if present.
-    bool has_else_region = !op.getElseRegion().empty();
-    if (has_else_region) {
-      mlir::scf::YieldOp else_yield = op.elseYield();
-      mlir::OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(else_yield);
-      for (auto&& [result, type] :
-           llvm::zip(else_yield->getOpOperands(), new_result_types)) {
-        if (result.get().getType() == type) continue;
-        result.set(
-            rewriter.create<UnrealizedConversionCastOp>(loc, type, result.get())
-                .getResult(0));
-      }
-    }
-    // Create new IfOp and move the old op's regions to the new one.
-    auto new_if_op = rewriter.create<IfOp>(loc, new_result_types,
-                                           op.getCondition(), has_else_region);
-    rewriter.inlineRegionBefore(op.getThenRegion(),
-                                &new_if_op.getThenRegion().back());
-    rewriter.eraseBlock(&new_if_op.getThenRegion().back());
-    if (has_else_region) {
-      rewriter.inlineRegionBefore(op.getElseRegion(),
-                                  &new_if_op.getElseRegion().back());
-      rewriter.eraseBlock(&new_if_op.getElseRegion().back());
-    }
-
-    // Update the results.
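
As with RewriteFor above, the new scf.if now yields the flat types while
users of the old op still expect N-D values, so each changed result is cast
back through an unrealized_conversion_cast; the pass's final walk (see
FlattenTensorsPass below) then checks that all such casts have cancelled out.
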
-    rewriter.setInsertionPointAfter(new_if_op);
-    SmallVector<Value> new_results(new_if_op.getResults());
-    for (auto&& [index, result] : llvm::enumerate(new_results)) {
-      Type old_type = op->getResult(index).getType();
-      if (result.getType() == old_type) continue;
-      result =
-          rewriter.create<UnrealizedConversionCastOp>(loc, old_type, result)
-              .getResult(0);
-    }
-    rewriter.replaceOp(op, new_results);
-    return mlir::success();
-  }
-};
-
-struct RewriteIndexSwitch : public OpRewritePattern<IndexSwitchOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IndexSwitchOp op,
-                                PatternRewriter& rewriter) const override {
-    auto result_types = op.getResultTypes();
-    if (HasOnlyFlatTensorsFlatVectorsOrScalars(result_types)) {
-      return rewriter.notifyMatchFailure(op, "nothing to flatten");
-    }
-    auto default_yield =
-        mlir::cast<mlir::scf::YieldOp>(op.getDefaultBlock().getTerminator());
-    SmallVector<Type> new_result_types;
-    new_result_types.reserve(default_yield.getNumOperands());
-    bool found_cast = false;
-    for (auto& result : default_yield->getOpOperands()) {
-      auto delinearized_tensor = GetDelinearizedTensorOrVector(result.get());
-      if (!delinearized_tensor.has_value()) {
-        new_result_types.push_back(result.get().getType());
-        continue;
-      }
-      new_result_types.push_back(delinearized_tensor->getType());
-      result.set(*delinearized_tensor);
-      found_cast = true;
-    }
-    if (!found_cast) {
-      return rewriter.notifyMatchFailure(op, "no cast found");
-    }
-    Location loc = op.getLoc();
-    // Update the "case" regions.
-    for (auto& case_region : op.getCaseRegions()) {
-      auto yield = mlir::cast<mlir::scf::YieldOp>(
-          case_region.getBlocks().front().getTerminator());
-      mlir::OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(yield);
-      for (auto&& [result, type] :
-           llvm::zip(yield->getOpOperands(), new_result_types)) {
-        if (result.get().getType() == type) continue;
-        result.set(
-            rewriter.create<UnrealizedConversionCastOp>(loc, type, result.get())
-                .getResult(0));
-      }
-    }
-    // Create new IndexSwitchOp and move the old op's regions to the new one.
-    auto new_index_switch = rewriter.create<IndexSwitchOp>(
-        loc, new_result_types, op.getArg(), op.getCases(), op.getNumCases());
-    for (auto&& [old_region, new_region] :
-         llvm::zip(op.getRegions(), new_index_switch.getRegions())) {
-      rewriter.inlineRegionBefore(*old_region, *new_region, new_region->end());
-    }
-    // Update the results.
-    rewriter.setInsertionPointAfter(new_index_switch);
-    SmallVector<Value> new_results(new_index_switch.getResults());
-    for (auto&& [index, result] : llvm::enumerate(new_results)) {
-      Type old_type = op->getResult(index).getType();
-      if (result.getType() == old_type) continue;
-      result =
-          rewriter.create<UnrealizedConversionCastOp>(loc, old_type, result)
-              .getResult(0);
-    }
-    rewriter.replaceOp(op, new_results);
-    return mlir::success();
-  }
-};
-
-struct RewriteSyncThreads : OpRewritePattern<SyncThreadsOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SyncThreadsOp op,
-                                PatternRewriter& rewriter) const override {
-    auto types = op.getResultTypes();
-    if (HasOnlyFlatTensorsFlatVectorsOrScalars(types)) {
-      return rewriter.notifyMatchFailure(op, "nothing to flatten");
-    }
-
-    auto loc = op.getLoc();
-
-    SmallVector<Value> new_operands;
-    new_operands.reserve(op.getNumOperands());
-    llvm::SmallBitVector results_to_update(op.getNumResults(), false);
-    for (auto& operand : op->getOpOperands()) {
-      auto tensor_type = mlir::cast<RankedTensorType>(operand.get().getType());
-      if (tensor_type.getRank() < 2) continue;
-      results_to_update.set(operand.getOperandNumber());
-      new_operands.push_back(
-          rewriter
-              .create<UnrealizedConversionCastOp>(
-                  loc, GetFlattenedType(tensor_type), operand.get())
-              .getResult(0));
-    }
-    auto new_op = rewriter.create<SyncThreadsOp>(loc, TypeRange(new_operands),
-                                                 new_operands);
-    SmallVector<Value> new_results;
-    new_results.reserve(op.getNumResults());
-    for (auto [index, result] : llvm::enumerate(new_op.getResults())) {
-      if (!results_to_update.test(index)) {
-        new_results.push_back(result);
-        continue;
-      }
-      auto cast_to_orig_type = rewriter.create<UnrealizedConversionCastOp>(
-          loc, result.getType(), result);
-      new_results.push_back(cast_to_orig_type.getResult(0));
-    }
-    rewriter.replaceOp(op, new_results);
-    return mlir::success();
-  }
-};
-
-class FlattenTensorsPass
-    : public impl::FlattenTensorsPassBase<FlattenTensorsPass> {
- public:
-  void runOnOperation() override {
-    mlir::ModuleOp module = getOperation();
-    MLIRContext* mlir_context = &getContext();
-    mlir::RewritePatternSet patterns(mlir_context);
-    // clang-format off
-    patterns.add<
-        RewriteAllocateShared,
-        RewriteAtomicRMW,
-        RewriteConstant,
-        RewriteFor,
-        RewriteFunctionSignatures,
-        RewriteIf,
-        RewriteIndexSwitch,
-        RewritePureCall,
-        RewriteSyncThreads,
-        RewriteTensorExtract,
-        RewriteTensorInsert,
-        RewriteVectorExtract,
-        RewriteVectorInsert
-    >(mlir_context);
-    // clang-format on
-    ApplyIndexingOp::getCanonicalizationPatterns(patterns, mlir_context);
-    if (mlir::failed(
-            mlir::applyPatternsGreedily(module, std::move(patterns)))) {
-      signalPassFailure();
-      return;
-    }
-    // Check that no unrealized_conversion_casts are left.
-    bool module_has_casts = module
-                                .walk([](UnrealizedConversionCastOp op) {
-                                  return mlir::WalkResult::interrupt();
-                                })
-                                .wasInterrupted();
-    if (module_has_casts) {
-      llvm::outs() << "FlattenTensorsPass failed to converge";
-      signalPassFailure();
-      return;
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass() {
-  return std::make_unique<FlattenTensorsPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla

From d321133420da96982847bddd3cf22fd1193673f6 Mon Sep 17 00:00:00 2001
From: Xuefei Jiang
Date: Fri, 4 Apr 2025 00:16:59 -0700
Subject: [PATCH 0245/1324] PR #23508: [ROCm] Enable OCP FP8 for latest AMD
 archs

Imported from GitHub PR https://github.com/openxla/xla/pull/23508

We are moving from NANOO FP8 (i.e., F8E4M3FNUZ, F8E5M2FNUZ) to OCP FP8
(i.e., F8E4M3FN, F8E5M2) for new archs. In this PR:
- We add support for OCP FP8 on ROCm.
  (`F8ConvertD` fusion is not enabled yet; we will submit a follow-up PR to
  support it.)
- We also found that, because a/b scales are always enabled in FP8 scenarios,
  the scaling factors must be passed to the autotuner (even though they are
  only dummy pointers) so that it selects algorithms that handle a/b scales
  properly. We introduce another `is_fp8` field to `GemmBackendConfig` to
  support this.

Copybara import of the project:

--
42914277bdcd7e7bb3d760c510fff212c220c8f7 by scxfjiang:
enable ocp fp8 for latest amd archs in gemm rewriter

--
c080d446deb92b100ba28d6838d11fe5f06543b2 by scxfjiang:
fix typo

--
ad4000fdc59061ab147d6d46d0549b289381fe1b by scxfjiang:
rm is_fp8 field

--
8e43ae776b84b08dd2a2289c7b0d6f20d8048382 by scxfjiang:
fix typo

Merging this change closes #23508

PiperOrigin-RevId: 743837883
---
 .../gpu/autotuning/gemm_algorithm_picker.cc   |  17 +-
 .../service/gpu/transforms/gemm_rewriter.cc   | 165 ++++++++++++++----
 .../gpu/transforms/gemm_rewriter_fp8_test.cc  | 157 +++++++++--------
 .../gpu/transforms/gemm_rewriter_test_lib.cc  |   6 +-
 .../gpu/transforms/gemm_rewriter_test_lib.h   |   2 +-
 .../xla/stream_executor/device_description.h  |   6 +-
 .../xla/stream_executor/rocm/hip_blas_lt.cc   |  60 ++++++-
 .../stream_executor/rocm/hip_blas_utils.cc    |  18 +-
 8 files changed, 313 insertions(+), 118 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc
index c6cd3e4b332086..fabeb049d66d47 100644
--- a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc
@@ -174,9 +174,24 @@ class GemmAutotuner {
     se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer,
         d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer;

+    int64_t input_buffer_idx = 2;  // lhs is at 0, rhs is at 1
     if (has_vector_bias) {
-      bias_buffer = rz_buffers_.input_buffers().at(has_matrix_bias ? 3 : 2);
+      if (has_matrix_bias) {
+        input_buffer_idx++;
+      }
+      bias_buffer = rz_buffers_.input_buffers().at(input_buffer_idx++);
+    }
+    // In the current GemmRewriter design for FP8, the a/b scales remain active
+    // even when they are not used. Consequently, we must inform the autotuner
+    // so it can choose algorithms that properly support a/b scales.
+    if (xla::primitive_util::IsF8Type(
+            gemm->operand(0)->shape().element_type()) &&
+        xla::primitive_util::IsF8Type(
+            gemm->operand(1)->shape().element_type())) {
+      a_scale_buffer = rz_buffers_.input_buffers().at(input_buffer_idx++);
+      b_scale_buffer = rz_buffers_.input_buffers().at(input_buffer_idx++);
     }
+
     if (has_aux_output) {
       aux_buffer = rz_buffers_.output_buffers().at(1);
     }
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
index 6f506f78a5f9b9..d0a37584173235 100644
--- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
@@ -1108,22 +1108,52 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     }

     if (IsRocm(gpu_version_)) {
-      if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
-        VLOG(1)
-            << "Failed to rewrite " << instr->ToShortString()
-            << " into FP8 Custom Call. The element type of one of the operands "
-               "must be F8E4M3FNUZ.";
-        return false;
+      TF_ASSIGN_OR_RETURN(auto rocm_compute_capability,
+                          GetRocmComputeCapability(gpu_version_));
+      if (rocm_compute_capability.has_ocp_fp8_support()) {
+        if (a_type == F8E5M2 && b_type == F8E5M2) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For "
+                  << rocm_compute_capability.gfx_version()
+                  << " arch, one of the input types must be F8E4M3FN, but got "
+                  << PrimitiveType_Name(a_type) << " and "
+                  << PrimitiveType_Name(b_type);
+          return false;
+        }
+        if ((a_type != F8E5M2 && a_type != F8E4M3FN) ||
+            (b_type != F8E5M2 && b_type != F8E4M3FN)) {
+          VLOG(1)
+              << "Failed to rewrite " << instr->ToShortString()
+              << " into FP8 Custom Call. For "
+              << rocm_compute_capability.gfx_version()
+              << " arch, the input types must be F8E5M2 or F8E4M3FN, but got "
+              << PrimitiveType_Name(a_type) << " and "
+              << PrimitiveType_Name(b_type);
+          return false;
+        }
       }
-      if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
-          (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
-        VLOG(1)
-            << "Failed to rewrite " << instr->ToShortString()
-            << " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
-               "F8E4M3FNUZ, but got "
-            << PrimitiveType_Name(a_type) << " and "
-            << PrimitiveType_Name(b_type);
-        return false;
+      if (rocm_compute_capability.has_nanoo_fp8_support()) {
+        if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
+          VLOG(1)
+              << "Failed to rewrite " << instr->ToShortString()
+              << " into FP8 Custom Call. For "
+              << rocm_compute_capability.gfx_version()
+              << " arch, one of the input types must be F8E4M3FNUZ, but got "
+              << PrimitiveType_Name(a_type) << " and "
+              << PrimitiveType_Name(b_type);
+          return false;
+        }
+        if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
+            (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For "
+                  << rocm_compute_capability.gfx_version()
+                  << " arch, the input types must be F8E5M2FNUZ or F8E4M3FNUZ, "
+                     "but got "
+                  << PrimitiveType_Name(a_type) << " and "
+                  << PrimitiveType_Name(b_type);
+          return false;
+        }
       }
     }
@@ -1170,25 +1200,56 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     }

     PrimitiveType d_type = instr->shape().element_type();
-    bool supported_d_type = (d_type == BF16 || d_type == F16 || d_type == F32);
-    if (IsCuda(gpu_version_) && (d_type == F8E4M3FN || d_type == F8E5M2)) {
-      supported_d_type = true;
-    }
-    if (IsRocm(gpu_version_) &&
-        toolkit_version_ >= stream_executor::SemanticVersion{6, 2, 0} &&
-        (d_type == F8E4M3FNUZ || d_type == F8E5M2FNUZ)) {
-      supported_d_type = true;
+    std::unordered_set<PrimitiveType> supported_d_types = {BF16, F16, F32};
+    if (IsCuda(gpu_version_)) {
+      supported_d_types.insert(F8E4M3FN);
+      supported_d_types.insert(F8E5M2);
+      if (supported_d_types.find(d_type) == supported_d_types.end()) {
+        VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                << " into FP8 Custom Call. Output type must be "
+                   "F8E4M3FN, F8E5M2, BF16, F16 or F32, but got "
+                << PrimitiveType_Name(d_type);
+        return false;
+      }
     }
-    if (!supported_d_type) {
-      VLOG(1) << "Failed to rewrite " << instr->ToShortString()
-              << " into FP8 Custom Call. Output element type must be "
-              << (IsCuda(gpu_version_) ? "F8E4M3FN, F8E5M2, BF16, F16 or F32. "
-                                       : toolkit_version_ >=
-                                                 stream_executor::SemanticVersion{6, 2, 0}
-                                             ? "F8E4M3FNUZ, F8E5M2FNUZ, BF16, F16 or F32. "
-                                             : "BF16, F16 or F32. ")
-              << "Actual element type is " << PrimitiveType_Name(d_type);
-      return false;
+    if (IsRocm(gpu_version_)) {
+      if (toolkit_version_ < stream_executor::SemanticVersion{6, 2, 0}) {
+        if (supported_d_types.find(d_type) == supported_d_types.end()) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For ROCm version < 6.2, output "
+                     "type must be BF16, F16 or F32, but got "
+                  << PrimitiveType_Name(d_type);
+          return false;
+        }
+      }
+      TF_ASSIGN_OR_RETURN(auto rocm_compute_capability,
+                          GetRocmComputeCapability(gpu_version_));
+      if (rocm_compute_capability.has_ocp_fp8_support()) {
+        supported_d_types.insert(F8E4M3FN);
+        supported_d_types.insert(F8E5M2);
+        if (supported_d_types.find(d_type) == supported_d_types.end()) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For "
+                  << rocm_compute_capability.gfx_version()
+                  << " arch output type must be F8E4M3FN, F8E5M2, BF16, F16 or "
+                     "F32, but got "
+                  << PrimitiveType_Name(d_type);
+          return false;
+        }
+      }
+      if (rocm_compute_capability.has_nanoo_fp8_support()) {
+        supported_d_types.insert(F8E4M3FNUZ);
+        supported_d_types.insert(F8E5M2FNUZ);
+        if (supported_d_types.find(d_type) == supported_d_types.end()) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For "
+                  << rocm_compute_capability.gfx_version()
+                  << " arch output type must be F8E4M3FNUZ, F8E5M2FNUZ, BF16, "
+                     "F16 or F32, but got "
+                  << PrimitiveType_Name(d_type);
+          return false;
+        }
+      }
     }

     // Each operand must have exactly one contracting and one non-contracting
@@ -1383,6 +1444,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
                                      HloInstruction *d_scale,
                                      HloInstruction *clamp_lower,
                                      HloInstruction *clamp_upper,
                                      bool mult_scale = false) {
+    // TODO: add ROCm support to this fusion pattern
+    if (IsRocm(gpu_version_)) {
+      return absl::OkStatus();
+    }
     // Verify the data types and the operands of clamp.
     // (Illustrative note, assumptions labeled: the pattern matched below is a
     // saturating cast, e.g. clamp(broadcast(-448.), x, broadcast(448.))
     // followed by a convert to f8e4m3fn; 448 is the largest finite F8E4M3FN
     // value, F8E5M2 uses 57344, and the NANOO F8E4M3FNUZ type uses 240, as
     // reflected by kF8E4M3AmaxPlaceholder in the tests below.)
if (instr->shape().element_type() == F8E4M3FN) { if (!clamp_lower->literal().IsAllFloat(static_cast( @@ -2129,7 +2194,39 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return true; } const TypeCombinations supported_hipblas_type_combinations = { - // FP8 types: + // OCP FP8 types: + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + PrimitiveType::F8E4M3FN, DataType::kBF16}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + PrimitiveType::F8E4M3FN, DataType::kHalf}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + PrimitiveType::F8E4M3FN, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + PrimitiveType::F8E5M2, DataType::kBF16}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + PrimitiveType::F8E5M2, DataType::kF8E4M3FN}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + PrimitiveType::F8E5M2, DataType::kF8E5M2}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + PrimitiveType::F8E5M2, DataType::kHalf}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + PrimitiveType::F8E5M2, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + PrimitiveType::F8E4M3FN, DataType::kBF16}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + PrimitiveType::F8E4M3FN, DataType::kF8E5M2}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + PrimitiveType::F8E4M3FN, DataType::kHalf}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + PrimitiveType::F8E4M3FN, DataType::kFloat}, + + // NANOO FP8 types: {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, PrimitiveType::F8E4M3FNUZ, DataType::kBF16}, {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_fp8_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_fp8_test.cc index 0d6088973651f1..c766ba81bd3179 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_fp8_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_fp8_test.cc @@ -57,11 +57,26 @@ class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTestBase { public: ParameterizedFp8GemmRewriteTest() { - replacements_[kF8E4M3DatatypePlaceholder] = - IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; - replacements_[kF8E5M2DatatypePlaceholder] = - IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; - replacements_[kF8E4M3AmaxPlaceholder] = IsCuda() ? "448." 
: "240."; + if (IsCuda()) { + replacements_[kF8E4M3DatatypePlaceholder] = "f8e4m3fn"; + replacements_[kF8E5M2DatatypePlaceholder] = "f8e5m2"; + replacements_[kF8E4M3AmaxPlaceholder] = "448."; + return; + } + if (IsRocm() && std::get(Capability()) + .has_ocp_fp8_support()) { + replacements_[kF8E4M3DatatypePlaceholder] = "f8e4m3fn"; + replacements_[kF8E5M2DatatypePlaceholder] = "f8e5m2"; + replacements_[kF8E4M3AmaxPlaceholder] = "448."; + return; + } + if (IsRocm() && std::get(Capability()) + .has_nanoo_fp8_support()) { + replacements_[kF8E4M3DatatypePlaceholder] = "f8e4m3fnuz"; + replacements_[kF8E5M2DatatypePlaceholder] = "f8e5m2fnuz"; + replacements_[kF8E4M3AmaxPlaceholder] = "240."; + return; + } } void SetUp() override { @@ -73,6 +88,12 @@ class ParameterizedFp8GemmRewriteTest GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; } + + if (IsRocm() && + !std::get(Capability()).has_fp8_support()) { + GTEST_SKIP() + << "F8 gemm rewrite is only supported on MI300 and newer archs."; + } } protected: @@ -295,7 +316,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), checks); } @@ -318,7 +339,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: <>[16,16]) -> <>[16,16] { @@ -367,7 +388,7 @@ HloModule test CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[2,64,32], {{.*}}: <>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,64,16] { @@ -424,7 +445,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -476,7 +497,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[13,17], {{.*}}: <>[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] { @@ -534,7 +555,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter pass(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, 
this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -564,7 +585,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> f32[16,16] { @@ -620,7 +641,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( @@ -679,7 +700,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( @@ -737,7 +758,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter pass(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -745,7 +766,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[32,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -800,7 +821,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter pass(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -808,7 +829,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] { @@ -867,7 +888,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter pass(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_FALSE(changed); @@ -896,7 +917,7 @@ 
TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[10,16,32], {{.*}}: <>[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] { @@ -951,7 +972,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( @@ -1007,7 +1028,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( @@ -1131,7 +1152,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), checks); } @@ -1227,7 +1248,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, )"; RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), checks); } @@ -1256,7 +1277,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", @@ -1291,7 +1312,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( @@ -1350,7 +1371,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( @@ -1413,7 +1434,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1}); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> <>[16,16] { @@ -1464,7 +1485,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) { CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1}); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), 
GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> f32[16,16] { @@ -1513,7 +1534,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) { CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1}); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> f32[16,16] { @@ -1564,7 +1585,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) { CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1}); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> f32[16,16] { @@ -1629,7 +1650,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { @@ -1694,7 +1715,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( @@ -1738,7 +1759,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { @@ -1817,7 +1838,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1}); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( @@ -1886,7 +1907,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1}); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( @@ -1952,7 +1973,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), 
GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -2012,7 +2033,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.}); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { @@ -2070,7 +2091,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter pass(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -2084,7 +2105,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] { @@ -2147,7 +2168,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter pass(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -2163,7 +2184,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,15,15], {{.*}}: <>[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] { @@ -2230,7 +2251,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter pass(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -2244,7 +2265,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] { @@ -2303,7 +2324,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, 
ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter pass(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -2319,7 +2340,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3,15,15], {{.*}}: <>[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] { @@ -2386,14 +2407,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter pass(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[48,16], {{.*}}: <>[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] { @@ -2455,7 +2476,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.}); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { @@ -2533,7 +2554,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { @@ -2611,7 +2632,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<>[16,16], f16[]) { @@ -2692,7 +2713,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { @@ -2816,7 +2837,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, 
ScaledABUnscaledDF8Parameterized) { RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", @@ -2884,7 +2905,7 @@ ENTRY f { RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", @@ -2915,7 +2936,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", @@ -2941,22 +2962,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; - if (IsCuda()) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - CudaHopperOrRocmMI300(), GetToolkitVersion(), - GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - this->RunHloPass(&pass, module.get())); - EXPECT_FALSE(changed); - return; - } - if (IsRocm()) { + if (IsRocm() && std::get(Capability()) + .has_nanoo_fp8_support()) { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriter(CudaHopperOrRocmCapability(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fnuz[16,32], {{.*}}: f8e4m3fnuz[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -2994,6 +3005,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { ; CHECK-DAG: "epilogue":"DEFAULT" ; CHECK: } )"); + } else { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass( + CudaHopperOrRocmCapability(), GetToolkitVersion(), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + this->RunHloPass(&pass, module.get())); + EXPECT_FALSE(changed); + return; } } diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.cc index b2650bf3e4e02f..ac8c56e79bccb3 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.cc @@ -67,11 +67,11 @@ bool GemmRewriteTestBase::IsBlackwell() const { } stream_executor::GpuComputeCapability -GemmRewriteTestBase::CudaHopperOrRocmMI300() { +GemmRewriteTestBase::CudaHopperOrRocmCapability() { if (IsCuda()) { - return stream_executor::CudaComputeCapability::Hopper(); + return se::CudaComputeCapability::Hopper(); } else { - return stream_executor::RocmComputeCapability{"gfx942"}; + return std::get(Capability()); } } diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.h b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.h index 8b628d8a8cea88..efefb8a25c0216 100644 
--- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.h
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.h
@@ -39,7 +39,7 @@ class GemmRewriteTestBase : public GpuCodegenTest {

   bool IsBlackwell() const;

-  stream_executor::GpuComputeCapability CudaHopperOrRocmMI300();
+  stream_executor::GpuComputeCapability CudaHopperOrRocmCapability();

   DebugOptions GetDebugOptionsForTest() const override;

diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h
index 8263373452617f..333cb0c4c4319d 100644
--- a/third_party/xla/xla/stream_executor/device_description.h
+++ b/third_party/xla/xla/stream_executor/device_description.h
@@ -124,9 +124,13 @@ class RocmComputeCapability {
   }

   bool has_fp8_support() const {
-    return gfx9_mi300_series() || gfx1200() || gfx1201();
+    return has_ocp_fp8_support() || has_nanoo_fp8_support();
   }

+  bool has_ocp_fp8_support() const { return gfx1200() || gfx1201(); }
+
+  bool has_nanoo_fp8_support() const { return gfx_version() == "gfx942"; }
+
   std::string ToString() const { return gcn_arch_name(); }

   RocmComputeCapabilityProto ToProto() const {
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
index c33a463307e853..37084b0ce38971 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
@@ -238,9 +238,27 @@ auto BlasLt::MatmulPlan::GetAlgorithms(const Stream* stream,
   // no algorithms can be found for "bias epilogues". This is to be removed
   // later when this limitation is gone.
   if (op_desc_.has_bias_epilogue()) {
-    static int64_t dummyPointer = 0xACEBALL;
+    static int64_t dummy_pointer = 0xACEBALL;
     TF_RETURN_IF_ERROR(SetAttr(
-        op_desc_.get(), HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &dummyPointer));
+        op_desc_.get(), HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &dummy_pointer));
+  }
+
+  // hipBlasLt requires setting the a/b scale pointer (even a dummy one),
+  // otherwise no algorithms can be found for "a/b scaling". This is to be
+  // removed later when this limitation is gone.
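
A note on the code below, for context: at GetAlgorithms() time the matmul
arguments are not yet bound, so a sentinel host pointer is attached purely so
that the hipblaslt heuristic enumerates scale-aware algorithms; the DoMatmul()
hunk further down then sets HIPBLASLT_MATMUL_DESC_{A,B}_SCALE_POINTER to the
real args.a_scale / args.b_scale device pointers.
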
+  auto IsFP8 = [&](const MatrixLayout& layout) -> bool {
+    return layout.type() == HIP_R_8F_E5M2_FNUZ ||
+           layout.type() == HIP_R_8F_E4M3_FNUZ ||
+           layout.type() == HIP_R_8F_E5M2 || layout.type() == HIP_R_8F_E4M3;
+  };
+  if (IsFP8(a_desc_) && IsFP8(b_desc_)) {
+    static int64_t dummy_pointer = 0xACEBALL;
+    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
+                               HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
+                               &dummy_pointer));
+    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
+                               HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
+                               &dummy_pointer));
   }

   int found_algorithm_count = 0;
@@ -419,9 +437,15 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
                                HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                args.b_scale.opaque()));
   }

-  if (args.c_scale != nullptr || args.d_scale != nullptr) {
-    return absl::InternalError(
-        "hipblaslt does not support c_scale or d_scale.");
+  if (args.c_scale != nullptr) {
+    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
+                               HIPBLASLT_MATMUL_DESC_C_SCALE_POINTER,
+                               args.c_scale.opaque()));
+  }
+  if (args.d_scale != nullptr) {
+    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
+                               HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER,
+                               args.d_scale.opaque()));
   }
 #else
   if (!(args.a_scale == nullptr && args.b_scale == nullptr &&
@@ -531,6 +555,32 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream(
                HIP_R_8F_E5M2_FNUZ, HIP_R_8F_E5M2_FNUZ)
 #endif

+#if TF_ROCM_VERSION >= 60300
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_16BF)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_8F_E4M3)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_8F_E4M3)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_16F)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E4M3, HIP_R_32F, HIP_R_32F)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E4M3, HIP_R_8F_E4M3,
+               HIP_R_8F_E4M3)
+
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E5M2, HIP_R_16BF, HIP_R_16BF)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E5M2, HIP_R_16BF, HIP_R_8F_E4M3)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E5M2, HIP_R_16BF, HIP_R_8F_E5M2)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E5M2, HIP_R_16F, HIP_R_8F_E4M3)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E5M2, HIP_R_16F, HIP_R_8F_E5M2)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E5M2, HIP_R_16F, HIP_R_16F)
+  TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_8F_E5M2, HIP_R_32F, HIP_R_32F)
+
+  TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_16BF)
+  TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_8F_E4M3)
+  TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_8F_E5M2)
+  TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_8F_E4M3)
+  TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_8F_E5M2)
+  TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_16F)
+  TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_8F_E4M3, HIP_R_32F, HIP_R_32F)
+#endif
+
   // Other data types:
   TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF)
   TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F)
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc
index 8864476bf0d825..23cbfdae2d47b4 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc
@@ -35,14 +35,12 @@ absl::Status ToStatus(hipblasStatus_t status, const char* prefix) {

 hipDataType AsHipblasDataType(blas::DataType type) {
   switch (type) {
-    case blas::DataType::kF8E5M2:
     case blas::DataType::kF8E4M3:
-    case blas::DataType::kF8E4M3FN:
     case blas::DataType::kF8E3M4:
     case blas::DataType::kF4E2M1FN:
     case blas::DataType::kF8E8M0FNU:
-      LOG(FATAL) << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN, "
-                    "F8E3M4, F4E2M1FN and F8E8M0FNU";
+      LOG(FATAL) << "hipblaslt does not support F8E4M3, F8E3M4, F4E2M1FN and "
+                    "F8E8M0FNU";
 #if TF_ROCM_VERSION >= 60000
     case blas::DataType::kF8E5M2FNUZ:
       return HIP_R_8F_E5M2_FNUZ;
@@ -51,7 +49,17 @@ hipDataType AsHipblasDataType(blas::DataType type) {
 #else
     case blas::DataType::kF8E5M2FNUZ:
     case blas::DataType::kF8E4M3FNUZ:
-      LOG(FATAL) << "hipblaslt only supports F8 in ROCm 6.0 and above";
+      LOG(FATAL) << "hipblaslt only supports NANOO F8 in ROCm 6.0 and above";
+#endif
+#if TF_ROCM_VERSION >= 60300
+    case blas::DataType::kF8E5M2:
+      return HIP_R_8F_E5M2;
+    case blas::DataType::kF8E4M3FN:
+      return HIP_R_8F_E4M3;
+#else
+    case blas::DataType::kF8E5M2:
+    case blas::DataType::kF8E4M3FN:
+      LOG(FATAL) << "hipblaslt only supports OCP F8 in ROCm 6.3 and above";
 #endif
     case blas::DataType::kHalf:
       return HIP_R_16F;

From f4aa69f7d1845a60c0448e90f4d69b25f0cd16ad Mon Sep 17 00:00:00 2001
From: Ilia Sergachev
Date: Fri, 4 Apr 2025 00:17:10 -0700
Subject: [PATCH 0246/1324] PR #24508: [GPU] Bump minimal supported cuDNN
 version to 8.9.

Imported from GitHub PR https://github.com/openxla/xla/pull/24508

Copybara import of the project:

--
4de77bbedbbfe68a33ad59d948fa49979f78988a by Ilia Sergachev:
[GPU] Bump minimal supported cuDNN version to 8.9.

Merging this change closes #24508

PiperOrigin-RevId: 743837925
---
 .../cuda/hermetic/cuda_redist_versions.bzl    |   4 -
 .../cuda/hermetic/cuda_redist_versions.bzl    |   4 -
 .../xla/xla/service/gpu/nvptx_compiler.cc     |   7 +-
 .../transforms/cudnn_fused_conv_rewriter.cc   |   3 -
 .../cudnn_fused_conv_rewriter_test.cc         |  19 +-
 .../transforms/cudnn_simplify_padding_test.cc |   2 +-
 .../cudnn_vectorize_convolutions.cc           |   6 +-
 .../cudnn_vectorize_convolutions_test.cc      |   4 +-
 .../gpu/transforms/layout_assignment.cc       |  18 -
 .../gpu/transforms/layout_assignment_test.cc  |   4 +-
 .../xla/xla/stream_executor/cuda/cuda_dnn.cc  | 577 +-----------------
 .../xla/xla/stream_executor/cuda/cuda_dnn.h   |   6 -
 .../hlo_opt/gpu_specs/a100_pcie_80.txtpb      |   4 +-
 third_party/xla/xla/tsl/util/use_cudnn.cc     |  10 +-
 14 files changed, 50 insertions(+), 618 deletions(-)

diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl
index e0543e8cc8e433..ae4802ee04519d 100644
--- a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl
+++ b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl
@@ -81,10 +81,6 @@ CUDA_REDIST_JSON_DICT = {
 }

 CUDNN_REDIST_JSON_DICT = {
-    "8.6": [
"https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.6.0.json", - "7f6f50bed4fd8216dc10d6ef505771dc0ecc99cce813993ab405cb507a21d51d", - ], "8.9.4.25": [ "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.4.25.json", "02258dba8384860c9230fe3c78522e7bd8e350e461ccd37a8d932cb64127ba57", diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 5290bdc984c115..fd74b5b5b06470 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -124,11 +124,8 @@ class ConvBfloat16Support : public FloatSupport { se::dnn::VersionInfo cudnn_version, se::CudaComputeCapability cuda_compute_capability) : FloatSupport(BF16), - is_conv_bf16_supported_((cudnn_version.major_version() > 8 || - (cudnn_version.major_version() == 8 && - cudnn_version.minor_version() >= 2)) && - cuda_compute_capability.IsAtLeast( - se::CudaComputeCapability::kAmpere)) {} + is_conv_bf16_supported_(cuda_compute_capability.IsAtLeast( + se::CudaComputeCapability::kAmpere)) {} bool SupportsLowPrecisionOperand(const HloInstruction& hlo, int64_t operand_index) const override { diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc index 0d4fe39f6c8741..e5f033611f6745 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc @@ -848,9 +848,6 @@ absl::StatusOr F8GraphConv(HloComputation* comp, const se::SemanticVersion& toolkit_version) { bool changed = false; - if (dnn_version < se::dnn::VersionInfo(8, 9, 0)) { - return false; - } if (toolkit_version < se::SemanticVersion{12, 0, 0}) { return false; } diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc index f2b00e2c2b87a5..e52961faad06ce 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc @@ -286,16 +286,15 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { } }; -#define MAYBE_SKIP_TEST(CAUSE) \ - do { \ - if (absl::string_view(CAUSE) == "F8" && IsCuda() && \ - (GetToolkitVersion() < se::SemanticVersion{12, 0, 0} || \ - GetDnnVersion() < se::dnn::VersionInfo(8, 9, 0))) { \ - GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; \ - } \ - if (!IsCuda()) { \ - GTEST_SKIP() << CAUSE " fusion is only supported on CUDA."; \ - } \ +#define MAYBE_SKIP_TEST(CAUSE) \ + do { \ + if (absl::string_view(CAUSE) == "F8" && IsCuda() && \ + GetToolkitVersion() < se::SemanticVersion{12, 0, 0}) { \ + GTEST_SKIP() << "FP8 convolutions require CUDA 12."; \ + } \ + if (!IsCuda()) { \ + GTEST_SKIP() << CAUSE " fusion is only supported on CUDA."; \ + } \ } while (0) TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) { diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc index deff8f7d9c8399..38cbe080d1d53b 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc @@ -64,7 +64,7 @@ class CudnnSimplifyPaddingTest : public HloTestBase { TF_RETURN_IF_ERROR( 
RunHloPass(CudnnVectorizeConvolutions( - cc, /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0}), + cc, /*cudnn_version=*/se::dnn::VersionInfo{8, 9, 0}), module) .status()); VLOG(1) << "after vectorizing convs:\n" << module->ToString(); diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc index 8325c0f2da2d20..0fdf28c82621a2 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc @@ -406,8 +406,7 @@ static absl::StatusOr TryRevectorizeConv( const auto& debug_options = conv->GetModule()->config().debug_options(); bool use_reordering = input_shape.element_type() == xla::S8 && vect_size == 32 && - debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() && - cudnn_version >= se::dnn::VersionInfo{8, 3, 0}; + debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering(); if (use_reordering) { // Reordering helper supports vector sizes of 4 and 32, so an additional // reshape-transpose-reshape is not necessary in these cases. @@ -551,8 +550,7 @@ static absl::StatusOr TryVectorizeConv( const auto& debug_options = conv->GetModule()->config().debug_options(); bool use_reordering = input_shape.element_type() == xla::S8 && vect_size == 32 && - debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() && - cudnn_version >= se::dnn::VersionInfo{8, 3, 0}; + debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering(); if (use_reordering) { new_operands[1] = filter; TF_RETURN_IF_ERROR(ReorderInt8NchwVect(conv, new_operands.data())); diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc index d3a0586598382b..5118d99d60248f 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc @@ -50,7 +50,7 @@ class CudnnVectorizeConvolutionsTest : public HloTestBase { CudnnVectorizeConvolutions pass( se::CudaComputeCapability{compute_capability.first, compute_capability.second}, - se::dnn::VersionInfo(8, 3, 0)); + se::dnn::VersionInfo(8, 9, 0)); TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&pass, module)); CallInliner inliner; @@ -229,7 +229,7 @@ TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4) { .value(); CudnnVectorizeConvolutions pass( /*compute_capability=*/{7, 5}, - /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0}); + /*cudnn_version=*/se::dnn::VersionInfo{8, 9, 0}); TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get())); SCOPED_TRACE(module->ToString()); diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index b5fbe5c093f104..b3ff0ea817d0bf 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -177,24 +177,6 @@ HeuristicLayoutAssignment(const HloInstruction* instr, instr->shape().tuple_shapes(0).dimensions_size() != 4) { return kAllNCHW; } - - // Empirically we've found with Volta and cudnn <= 7.3 that backward-input - // convs with stride are significantly faster with NCHW layouts. - // - // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW), - // which on paper gives good performance. 
However, there are two - // observations: - // * a mixed layout combination is more cuDNN-bug prone, based on empirical - // evidence. - // * we've also observed that for mixed layouts, cuDNN transposes data back - // and forth from a different layout combination. If we end up with - // transposes anyway, we prefer to have them in XLA, as they can be fused. - if (std::make_tuple(dnn_version.major_version(), - dnn_version.minor_version()) <= std::make_tuple(7, 3) && - instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && - window_util::HasStride(instr->window())) { - return kAllNCHW; - } } else if (std::holds_alternative(gpu_version)) { bool is_enabled = false; TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_ROCM_NHWC", diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc index 8650c3c4a52c22..896fb72c2da82d 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc @@ -70,10 +70,8 @@ class LayoutAssignmentTest : public HloTestBase { } se::dnn::VersionInfo GetDnnVersion() { - // GpuLayoutAssignment has a special case heuristic for cudnn <= 7.3, but - // none of the tests trigger this heuristic. return GetDnnVersionInfoOrDefault(backend().default_stream_executor(), - se::dnn::VersionInfo{8, 3, 0}); + se::dnn::VersionInfo{8, 9, 0}); } }; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 6f99c125a54040..53a18f564c8c71 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -81,7 +81,7 @@ limitations under the License. #include "third_party/gpus/cudnn/cudnn_adv.h" #include "third_party/gpus/cudnn/cudnn_cnn.h" #include "third_party/gpus/cudnn/cudnn_ops.h" -#elif CUDNN_VERSION >= 8100 +#else #include "third_party/gpus/cudnn/cudnn_adv_infer.h" #include "third_party/gpus/cudnn/cudnn_adv_train.h" #include "third_party/gpus/cudnn/cudnn_cnn_infer.h" @@ -92,7 +92,6 @@ limitations under the License. #include "third_party/gpus/cudnn/cudnn_backend.h" -#if CUDNN_VERSION >= 8100 #include "third_party/cudnn_frontend/include/cudnn_frontend.h" #include "third_party/cudnn_frontend/include/cudnn_frontend_utils.h" #include "third_party/cudnn_frontend/include/cudnn_frontend_EngineConfig.h" @@ -107,7 +106,6 @@ limitations under the License. #include "third_party/cudnn_frontend/include/cudnn_frontend_Rng.h" #include "third_party/cudnn_frontend/include/cudnn_frontend_Tensor.h" #include "third_party/cudnn_frontend/include/cudnn_frontend_VariantPack.h" -#endif // CUDNN_VERSION >= 8100 // clang-format on #ifdef __clang__ @@ -123,7 +121,7 @@ namespace gpu { namespace { -static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher"); +static_assert(CUDNN_VERSION >= 8900, "cuDNN needs to be version 8.9 or higher"); // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS. 
#define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS) @@ -429,17 +427,17 @@ void PreloadCudnnSubLibs(PreloadCudnnType type) { switch (type) { case PreloadCudnnType::ConvBwdFilter: case PreloadCudnnType::ConvBwdData: { -#if CUDNN_VERSION >= 8004 && CUDNN_VERSION < 90000 +#if CUDNN_VERSION < 90000 cudnnOpsTrainVersionCheck(); cudnnCnnTrainVersionCheck(); -#endif // CUDNN_VERSION >= 8004 && CUDNN_VERSION < 90000 +#endif // CUDNN_VERSION < 90000 [[clang::fallthrough]]; } case PreloadCudnnType::ConvFwd: { #if CUDNN_VERSION >= 90000 cudnnGraphVersionCheck(); cudnnOpsVersionCheck(); -#elif CUDNN_VERSION >= 8004 +#else cudnnOpsInferVersionCheck(); cudnnCnnInferVersionCheck(); #endif // CUDNN_VERSION >= 90000 @@ -449,7 +447,7 @@ void PreloadCudnnSubLibs(PreloadCudnnType type) { #if CUDNN_VERSION >= 90000 cudnnOpsVersionCheck(); cudnnAdvVersionCheck(); -#elif CUDNN_VERSION >= 8004 +#else cudnnOpsInferVersionCheck(); cudnnAdvInferVersionCheck(); cudnnOpsTrainVersionCheck(); @@ -612,20 +610,11 @@ struct RnnDescriptorDeleter { CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor)); } }; -#if CUDNN_VERSION < 8100 -struct PersistentRnnPlanDeleter { - void operator()(cudnnPersistentRNNPlan_t plan) const { - CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan)); - } -}; -#endif // CUDNN_VERSION < 8100 -#if CUDNN_VERSION >= 7603 struct CtcLossDescriptorDeleter { void operator()(cudnnCTCLossDescriptor_t descriptor) const { CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor)); } }; -#endif // RAII wrappers for cuDNN types. using TensorDescriptor = @@ -644,17 +633,10 @@ using ActivationDescriptor = using DropoutDescriptor = std::unique_ptr; using RnnDescriptor = std::unique_ptr; -#if CUDNN_VERSION >= 8100 struct DummyType {}; using PersistentRnnPlan = std::unique_ptr; -#else -using PersistentRnnPlan = - std::unique_ptr; -#endif // CUDNN_VERSION >= 8100 -#if CUDNN_VERSION >= 7603 using CtcLossDescriptor = std::unique_ptr; -#endif // Factory methods for cuDNN types. TensorDescriptor CreateTensorDescriptor() { @@ -702,23 +684,11 @@ RnnDescriptor CreateRnnDescriptor() { CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result)); return RnnDescriptor(result); } -#if CUDNN_VERSION >= 7603 CtcLossDescriptor CreateCtcLossDescriptor() { cudnnCTCLossDescriptor_t result; CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result)); return CtcLossDescriptor(result); } -#endif - -#if CUDNN_VERSION < 8100 -absl::StatusOr CreatePersistentRnnPlan( - cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) { - cudnnPersistentRNNPlan_t result; - RETURN_IF_CUDNN_ERROR( - cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result)); - return absl::StatusOr(PersistentRnnPlan(result)); -} -#endif // CUDNN_VERSION < 8100 // Turns a BatchDescriptor structure into a cudnn tensor handle within a // scope. @@ -829,7 +799,6 @@ class CudnnFilterDescriptor { FilterDescriptor handle_; // Owned. }; -#if CUDNN_VERSION >= 8100 // The errata sheet (JSON format) for marking the cudnn engines that might be // buggy. For example, we don't want the engine 999 of forward convolution: // R"({ "version" : 1, @@ -925,8 +894,6 @@ const json* CudnnExecutionPlanEngineFilterRuntime() { return json_handle; } -#endif // CUDNN_VERSION >= 8100 - // A helper function to decide whether to use // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. 
This mode can be faster in // some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT @@ -994,24 +961,15 @@ class CudnnConvolutionDescriptor { : CUDNN_CROSS_CORRELATION, data_type)); -#if CUDNN_MAJOR >= 7 VLOG(2) << "Requesting grouped convolution: " << convolution_descriptor.group_count(); CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount( handle_.get(), convolution_descriptor.group_count())); -#else - CHECK_EQ(convolution_descriptor.group_count(), 1) - << "Requested grouped convolution for cuDNN version < 7"; -#endif } void set_use_tensor_op_math(bool use_tensor_op_math) { cudnnMathType_t math_type = -#if CUDNN_VERSION >= 8000 (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH); -#else - (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH); -#endif CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type)); } @@ -1026,11 +984,7 @@ class CudnnConvolutionDescriptor { static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) { cudnnMathType_t math_type; CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type)); -#if CUDNN_VERSION >= 8000 return math_type != CUDNN_FMA_MATH; -#else - return math_type == CUDNN_TENSOR_OP_MATH; -#endif } static bool TensorOpMathAvailable( @@ -1044,13 +998,9 @@ static bool IsTensorMathEnabled(CudaComputeCapability cuda_compute_capability, return false; } if (input_type == dnn::DataType::kFloat) { -#if CUDNN_VERSION < 8000 - return false; -#else if (!allow_tf32 || !tsl::tensor_float_32_execution_enabled()) { return false; } -#endif } return true; } @@ -1213,16 +1163,12 @@ cudnn_frontend::DataType_t ToCudnnFrontendDataType( return cudnn_frontend::DataType_t::INT32; case dnn::DataType::kInt64: return cudnn_frontend::DataType_t::INT64; -#if CUDNN_VERSION >= 8200 case dnn::DataType::kBF16: return cudnn_frontend::DataType_t::BFLOAT16; -#endif -#if CUDNN_VERSION >= 8900 case dnn::DataType::kF8E4M3FN: return cudnn_frontend::DataType_t::FP8_E4M3; case dnn::DataType::kF8E5M2: return cudnn_frontend::DataType_t::FP8_E5M2; -#endif #if CUDNN_VERSION >= 90700 case dnn::DataType::kF4E2M1FN: return cudnn_frontend::DataType_t::FP4_E2M1; @@ -1483,15 +1429,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { "Algo requests disallowed tensor op evaluation."); } -#if CUDNN_VERSION >= 8000 cudnnMathType_t math_type = use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH; -#else - cudnnMathType_t math_type = - use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH; -#endif -#if CUDNN_VERSION >= 8000 cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS; uint32_t aux_flags = 0; if (use_padded_io) aux_flags |= CUDNN_RNN_PADDED_IO_ENABLED; @@ -1506,48 +1446,13 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(), /*auxFlags=*/aux_flags)); -#else - RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( - cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), - /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers, - /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode, - /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo, - /*dataType=*/compute_type)); - CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type)); - - if (proj_size < hidden_size) { - RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers( - cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), - /*recProjSize=*/proj_size, /*outProjSize=*/0)); - } - - // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs. 
- // But in the future if these APIs are used to process full length arrays, - // we need to distinguish when to set it. - if (use_padded_io) { - RETURN_IF_CUDNN_ERROR( - cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED)); - } -#endif absl::StatusOr rnn_plan_wrapper; PersistentRnnPlan rnn_plan; if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) { CHECK_GE(batch_size, 0); -#if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR( cudnnBuildRNNDynamic(cudnn.handle(), rnn_desc.get(), batch_size)); -#else - rnn_plan_wrapper = - CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type); - if (!rnn_plan_wrapper.ok()) { - return absl::StatusOr(rnn_plan_wrapper.status()); - } else { - rnn_plan = std::move(rnn_plan_wrapper).value(); - RETURN_IF_CUDNN_ERROR( - cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get())); - } -#endif // CUDNN_VERSION >= 8100 } // Create the params handle. @@ -1618,7 +1523,6 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { void operator=(const CudnnRnnDescriptor&) = delete; }; -#if CUDNN_VERSION >= 7603 class CudnnCtcLossDescriptor { public: explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type) @@ -1638,13 +1542,6 @@ class CudnnCtcLossDescriptor { CudnnCtcLossDescriptor(const CudnnCtcLossDescriptor&) = delete; void operator=(const CudnnCtcLossDescriptor&) = delete; }; -#else -// dummy class -class CudnnCtcLossDescriptor { - public: - CudnnCtcLossDescriptor(cudnnDataType_t data_type) {} -}; -#endif namespace { @@ -1665,7 +1562,6 @@ absl::Status CheckAndFetchProjectionWeights( cudnnRNNAlgo_t algo; cudnnDataType_t data_type; int rec_proj_size_v; -#if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v8( /*rnnDesc=*/rnn_desc, /*algo=*/&algo, @@ -1682,27 +1578,8 @@ absl::Status CheckAndFetchProjectionWeights( /*numLayers=*/&num_layers_v, /*dropoutDesc=*/&dropout_desc, /*auxFlags=*/nullptr)); -#else - RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, - /*hiddenSize=*/&hidden_size_v, - /*numLayers=*/&num_layers_v, - /*dropoutDesc=*/&dropout_desc, - /*inputMode=*/&input_mode, - /*direction=*/&direction, - /*mode=*/&mode, - /*algo=*/&algo, - /*mathPrec=*/&data_type)); - int out_proj_size_v; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers( - /*handle=*/cudnn.handle(), - /*rnnDesc=*/rnn_desc, - /*recProjSize*/ &rec_proj_size_v, - /*outProjSize*/ &out_proj_size_v)); -#endif // CUDNN_VERSION >= 8100 if (rec_proj_size_v != hidden_size_v) { int region_id = 8; -#if CUDNN_VERSION >= 8100 void* b_ptr = nullptr; void* m_ptr = nullptr; void* w_ptr = nullptr; @@ -1733,27 +1610,6 @@ absl::Status CheckAndFetchProjectionWeights( int64_t size = dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); int64_t offset = static_cast(m_ptr) - static_cast(w_ptr); -#else - void* offset = nullptr; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, - /*layer=*/layer, /*xDesc=*/input_desc.get(), - /*wDesc=*/filter_desc.get(), - /*w=*/nullptr, /*linLayerID=*/region_id, - /*linLayerMatDesc=*/region_desc_handle.get(), - /*linLayerMat or linLayerBias=*/&offset)); - int dims[] = {1, 1, 1}; - cudnnDataType_t data_type; - cudnnTensorFormat_t tensor_format; - int n_dims; - RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor( - /*filterDesc=*/region_desc_handle.get(), - /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), - /*dataType=*/&data_type, /*format=*/&tensor_format, - /*nbDims=*/&n_dims, /*filterDimA=*/dims)); - int64_t size = - dims[0] * dims[1] * 
dims[2] * CudnnDataTypeToByteSize(data_type); -#endif // CUDNN_VERSION >= 8100 dnn::RnnDescriptor::ParamsRegion region = {static_cast(offset), size}; weights->push_back(region); @@ -1776,16 +1632,9 @@ absl::StatusOr CudnnRnnParamsDescriptor::Create( /*strideA=*/strides)); size_t params_size = 0; -#if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, /*weightSpaceSize=*/¶ms_size)); -#else - RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, - /*xDesc=*/input_desc.get(), /*sizeInBytes=*/¶ms_size, - /*dataType=*/data_type)); -#endif // CUDNN_VERSION >= 8100 int64_t params_size_in_bytes = static_cast(params_size); FilterDescriptor filter_desc = CreateFilterDescriptor(); @@ -1823,7 +1672,6 @@ absl::StatusOr CudnnRnnParamsDescriptor::Create( for (int layer = 0; layer < layer_count; layer++) { for (int region = 0; region < region_count_per_layer; region++) { -#if CUDNN_VERSION >= 8100 void* m_ptr = nullptr; void* b_ptr = nullptr; void* w_ptr = nullptr; @@ -1867,40 +1715,6 @@ absl::StatusOr CudnnRnnParamsDescriptor::Create( int64_t b_offset = static_cast(b_ptr) - static_cast(w_ptr); dnn::RnnDescriptor::ParamsRegion b_region = {b_offset, b_size}; biases.push_back(b_region); -#else - for (int type = 0; type < 2; type++) { - void* offset = nullptr; - RETURN_IF_CUDNN_ERROR( - type == 0 ? cudnnGetRNNLinLayerMatrixParams( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, - /*layer=*/layer, /*xDesc=*/input_desc.get(), - /*wDesc=*/filter_desc.get(), - /*w=*/nullptr, /*linLayerID=*/region, - /*linLayerMatDesc=*/region_desc_handle.get(), - /*linLayerMat or linLayerBias=*/&offset) - : cudnnGetRNNLinLayerBiasParams( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, - /*layer=*/layer, /*xDesc=*/input_desc.get(), - /*wDesc=*/filter_desc.get(), - /*w=*/nullptr, /*linLayerID=*/region, - /*linLayerMatDesc=*/region_desc_handle.get(), - /*linLayerMat or linLayerBias=*/&offset)); - int dims[] = {1, 1, 1}; - cudnnDataType_t data_type; - cudnnTensorFormat_t tensor_format; - int n_dims; - RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor( - /*filterDesc=*/region_desc_handle.get(), - /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), - /*dataType=*/&data_type, /*format=*/&tensor_format, - /*nbDims=*/&n_dims, /*filterDimA=*/dims)); - int64_t size = - dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); - dnn::RnnDescriptor::ParamsRegion region = {static_cast(offset), - size}; - (type == 0 ? 
weights : biases).push_back(region); - } -#endif // CUDNN_VERSION >= 8100 } TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights( cudnn, rnn_desc, layer, input_desc, filter_desc, params_size_in_bytes, @@ -2113,16 +1927,9 @@ absl::Status CheckRNNParameterSize( const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc) { size_t params_size_in_bytes = 0; -#if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*sizeInBytes=*/¶ms_size_in_bytes)); -#else - RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes, - /*dataType=*/rnn_desc.data_type())); -#endif if (static_cast(params_size_in_bytes) != rnn_desc.ParamsSizeInBytes()) { return absl::InvalidArgumentError("Mismatching RNN parameter size"); @@ -2140,7 +1947,6 @@ absl::Status CreateRnnTempSpace( size_t reserve_space_size_in_bytes = 0; size_t workspace_size_in_bytes = 0; if (input_desc.is_var_seq_lengths()) { -#if CUDNN_VERSION >= 8100 auto rnn_fwd_mode = is_fwd_training ? CUDNN_FWD_MODE_TRAINING : CUDNN_FWD_MODE_INFERENCE; RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes( @@ -2150,10 +1956,6 @@ absl::Status CreateRnnTempSpace( /*xDesc=*/input_desc.data_handle(), /*workSpaceSize=*/&workspace_size_in_bytes, /*reserveSpaceSize=*/&reserve_space_size_in_bytes)); -#else - return tsl::errors::Internal( - "Sequence lengths for RNN are supported from CUDNN 8.1+"); -#endif // CUDNN_VERSION >= 8100 } else { #if CUDNN_VERSION >= 90000 return tsl::errors::Internal( @@ -2188,7 +1990,6 @@ absl::Status CreateRnnTempSpace( return absl::OkStatus(); } -#if CUDNN_VERSION >= 7402 absl::StatusOr> CreateBatchNormForwardWorkspace( Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode, const cudnnBatchNormOps_t& bn_ops, @@ -2239,8 +2040,6 @@ absl::StatusOr> CreateBatchNormBackwardWorkspace( return workspace_allocator->AllocateBytes(workspace_size_in_bytes); } -#endif - } // namespace // Populates the profile result if not empty. @@ -2304,10 +2103,6 @@ absl::Status CudnnSupport::DoRnnForwardImpl( } if (input_desc.is_var_seq_lengths()) { - // In CUDNN v8, the cudnnRNNForward*** and cudnnRNNForward***Ex have been - // deprecated. Instead, we use the cudnnRNNForward which requires the - // sequence_lengths parameter. -#if CUDNN_VERSION >= 8100 auto rnn_fwd_mode = is_training ? 
CUDNN_FWD_MODE_TRAINING : CUDNN_FWD_MODE_INFERENCE; RETURN_IF_CUDNN_ERROR(cudnnRNNForward( @@ -2326,41 +2121,6 @@ absl::Status CudnnSupport::DoRnnForwardImpl( /*workSpaceSize=*/workspace.size(), /*workspace=*/workspace.opaque(), /*reserveSpaceSizeInBytes=*/reserve_space.size(), /*reserveSpace=*/reserve_space.opaque())); -#else - if (!is_training) { - RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), - /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), - /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(), - /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(), - /*yDesc=*/output_desc.data_handle(), - /*y=*/output_data->opaque(), - /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(), - /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(), - nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, - /*workspace=*/workspace.opaque(), - /*workSpaceSizeInBytes=*/workspace.size())); - } else { - RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), - /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), - /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(), - /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(), - /*yDesc=*/output_desc.data_handle(), - /*y=*/output_data->opaque(), - /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(), - /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(), - nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, - /*workspace=*/workspace.opaque(), - /*workSpaceSizeInBytes=*/workspace.size(), - /*reserveSpace=*/reserve_space.opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space.size())); - } -#endif // CUDNN_VERSION >= 8100 } else { #if CUDNN_VERSION >= 90000 return tsl::errors::Internal( @@ -2457,10 +2217,6 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( } if (input_desc.is_var_seq_lengths()) { - // In CUDNN v8, the cudnnRNNBackward*** and cudnnRNNBackward***Ex have - // been deprecated. Instead, we use the cudnnRNNBackward***_v8 which - // requires the sequence_lengths parameter. 
-#if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData_v8( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*devSeqLengths=*/ @@ -2480,36 +2236,11 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( /*workSpaceSize=*/workspace.size(), /*workSpace=*/workspace.opaque(), /*reserveSpaceSize=*/reserve_space_data->size(), /*reserveSpace=*/reserve_space_data->opaque())); -#else - RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(), - /*dyDesc=*/output_desc.data_handle(), - /*dy=*/output_backprop_data.opaque(), nullptr, nullptr, - /*dhyDesc=*/output_h_desc.handle(), - /*dhy=*/output_h_backprop_data.opaque(), - /*dcyDesc=*/output_c_desc.handle(), - /*dcy=*/output_c_backprop_data.opaque(), - /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(), - /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), - /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(), - /*dxDesc=*/input_desc.data_handle(), - /*dx=*/input_backprop_data->opaque(), - /*dhxDesc=*/input_h_desc.handle(), - /*dhx=*/input_h_backprop_data->opaque(), - /*dcxDesc=*/input_c_desc.handle(), - /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr, - /*workspace=*/workspace.opaque(), - /*workSpaceSizeInBytes=*/workspace.size(), - /*reserveSpace=*/reserve_space_data->opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); -#endif // CUDNN_VERSION >= 8100 if (params_backprop_data != nullptr) { // Clear the dw to zeros. TF_RETURN_IF_ERROR( stream->MemZero(params_backprop_data, params_backprop_data->size())); -#if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), @@ -2528,20 +2259,6 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( /*workSpace=*/workspace.opaque(), /*reserveSpaceSize=*/reserve_space_data->size(), /*reserveSpace=*/reserve_space_data->opaque())); -#else - RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), - /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), - /*yDesc=*/output_desc.data_handle(), - /*y=*/output_data.opaque(), - /*workspace=*/workspace.opaque(), - /*workSpaceSizeInBytes=*/workspace.size(), - /*dwDesc=*/rnn_desc.params_handle(), - /*dw=*/params_backprop_data->opaque(), - /*reserveSpace=*/reserve_space_data->opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); -#endif // CUDNN_VERSION >= 8100 } } else { #if CUDNN_VERSION >= 90000 @@ -2617,7 +2334,6 @@ absl::Status CudnnSupport::DoCtcLossImpl( int total_size = kNumLabels * kNumTimestamps * kBatchSize; (void)total_size; -#if CUDNN_VERSION >= 7603 cudnnCTCLossAlgo_t ctc_loss_algo = static_cast(ctc_loss_algo_id); RETURN_IF_CUDNN_ERROR(cudnnCTCLoss( @@ -2631,11 +2347,6 @@ absl::Status CudnnSupport::DoCtcLossImpl( /*ctcLossDesc=*/ctc_loss_desc.handle(), /*workspace=*/scratch_memory.opaque(), /*workSpaceSizeInBytes=*/scratch_memory.size())); -#else - return absl::InvalidArgumentError( - "No supported cudnnCTCLoss when " - "CUDNN_VERSION < 7.6.3"); -#endif return absl::OkStatus(); } @@ -2997,7 +2708,6 @@ absl::StatusOr GetCudnnConvolutionForwardAlgo( const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit, size_t memory_limit_bytes) { -#if CUDNN_VERSION >= 8000 const int 
num_requested_algos = 5; int num_returned_algos = 0; cudnnConvolutionFwdAlgoPerf_t perf_results[num_requested_algos]; @@ -3018,16 +2728,6 @@ absl::StatusOr GetCudnnConvolutionForwardAlgo( return absl::InternalError( "cudnnGetConvolutionForwardAlgorithm_v7 returned " "no suitable algorithms. This could be a cudnn bug."); -#else - cudnnConvolutionFwdPreference_t preference = - specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; - cudnnConvolutionFwdAlgo_t algo_to_use; - RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm( - cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(), - output_nd.handle(), preference, memory_limit_bytes, &algo_to_use)); - return algo_to_use; -#endif } absl::StatusOr @@ -3038,7 +2738,6 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn, const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit, size_t memory_limit_bytes) { -#if CUDNN_VERSION >= 8000 const int num_requested_algos = 5; int num_returned_algos = 0; cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_requested_algos]; @@ -3060,17 +2759,6 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn, return absl::InternalError( "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned " "no suitable algorithms. This could be a cudnn bug."); -#else - cudnnConvolutionBwdDataPreference_t preference = - specify_workspace_limit - ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE; - cudnnConvolutionBwdDataAlgo_t algo_to_use; - RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm( - cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(), - input_nd.handle(), preference, memory_limit_bytes, &algo_to_use)); - return algo_to_use; -#endif } absl::StatusOr @@ -3081,7 +2769,6 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit, size_t memory_limit_bytes) { -#if CUDNN_VERSION >= 8000 const int num_requested_algos = 5; int num_returned_algos = 0; cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_requested_algos]; @@ -3102,17 +2789,6 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, return absl::InternalError( "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned " "no suitable algorithms. This could be a cudnn bug."); -#else - cudnnConvolutionBwdFilterPreference_t preference = - specify_workspace_limit - ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; - cudnnConvolutionBwdFilterAlgo_t algo_to_use; - RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm( - cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(), - filter.handle(), preference, memory_limit_bytes, &algo_to_use)); - return algo_to_use; -#endif } absl::StatusOr> AllocateCudnnConvolutionForwardWorkspace( @@ -3486,8 +3162,6 @@ struct FftTilingForward { // winograd-non-fused engines will be ruled out. struct WinogradNonfused { static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED"; - // NVIDIA has fixed winograd nonfused bug for cudnn v>=7. For older versions, - // we have a workaround. static constexpr bool kDefaultFlag = true; }; @@ -3513,20 +3187,11 @@ struct ConvDoFP32ComputationFP16Input { // in precision. struct RnnDoFP32ComputationFP16Input { static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE"; - // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug. 
- // Before cudnn 7.1.4 RNN are always done in fp32, no matter what math - // precision is set. - // Set it temporary to false s.t. no error is raised when using fp16 inputs, - // fp32 math precision. - // - // cuDNN == 7.5.0 is verified to have this fixed. - static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7500; + static constexpr bool kDefaultFlag = true; }; namespace { -#if CUDNN_VERSION >= 8100 - bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config, bool disable_winograd, bool disable_nondeterminism, bool disable_tensor_core) { @@ -3552,8 +3217,6 @@ bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config, return ret; } -#endif // CUDNN_VERSION >= 8100 - } // namespace cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) { @@ -3587,13 +3250,11 @@ dnn::DataType GetConvActivationType(dnn::DataType data_type) { case dnn::DataType::kInt8: case dnn::DataType::kInt32: // TODO(awpr): does int32 do blending in float? return dnn::DataType::kFloat; -#if CUDNN_VERSION >= 8200 // TODO(awpr): as with kHalf, this is not clear. case dnn::DataType::kBF16: return CudnnEnvVar::IsEnabled() ? dnn::DataType::kFloat : dnn::DataType::kBF16; -#endif default: LOG(FATAL) << "Invalid DNN data type: " << static_cast(data_type); } @@ -3611,24 +3272,18 @@ dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) { case dnn::DataType::kInt8: case dnn::DataType::kInt32: return dnn::DataType::kInt32; -#if CUDNN_VERSION >= 8200 case dnn::DataType::kBF16: return CudnnEnvVar::IsEnabled() ? dnn::DataType::kFloat : dnn::DataType::kBF16; -#endif -#if CUDNN_VERSION >= 8900 case dnn::DataType::kF8E4M3FN: case dnn::DataType::kF8E5M2: return dnn::DataType::kFloat; -#endif default: LOG(FATAL) << "Invalid DNN data type: " << static_cast(data_type); } } -#if CUDNN_VERSION >= 8100 - namespace { static bool allowAllConfig(cudnnBackendDescriptor_t engine_config) { (void)engine_config; @@ -3696,7 +3351,6 @@ std::tuple GetTensorVectorSizeAndDim( return std::make_tuple(vector_size, vector_dim); } -#if CUDNN_VERSION >= 8800 absl::StatusOr CreateCudnnTensor( absl::Span dims, absl::Span strides, int64_t uid, dnn::DataType dtype, int64_t vec_count, int64_t vec_dim, @@ -3718,41 +3372,10 @@ absl::StatusOr CreateCudnnTensor( RETURN_MSG_IF_CUDNN_ERROR(tensor); return tensor; } -#else -absl::StatusOr CreateCudnnTensor( - absl::Span dims, absl::Span strides, - int64_t uid, dnn::DataType dtype, int64_t vec_count, int64_t vec_dim, - bool is_virtual = false, bool is_reordered_nchw_vect = false) { - if (is_reordered_nchw_vect && (CUDNN_VERSION) < 8300) { - return tsl::errors::Internal( - "reordered nchw_vect requires cudnn 8.3+, but version was %d", - (CUDNN_VERSION)); - } - auto tensor = cudnn_frontend::TensorBuilder() - .setDim(dims.size(), dims.data()) - .setStride(strides.size(), strides.data()) - .setId(uid) - .setAlignment(32) - .setDataType(ToCudnnDataType(dtype)) - .setVectorCountAndDimension(vec_count, vec_dim) - .setVirtual(is_virtual) -// TODO(jlebar): remove guard after JAX no longer supports old cudnn -#if CUDNN_VERSION >= 8300 - .setReorderType(is_reordered_nchw_vect - - ? 
CUDNN_TENSOR_REORDERING_INT8x32 - : CUDNN_TENSOR_REORDERING_NONE) -#endif - .build(); - RETURN_MSG_IF_CUDNN_ERROR(tensor); - return tensor; -} -#endif absl::StatusOr CreateCudnnTensor( const cudnn_frontend::Tensor& original, int64_t uid, dnn::DataType dtype, bool is_virtual = false) { -#if CUDNN_VERSION >= 8900 auto tensor = cudnn_frontend::TensorBuilder() .cloneFrom(original, uid) .setAlignment(32) @@ -3761,13 +3384,8 @@ absl::StatusOr CreateCudnnTensor( .build(); RETURN_MSG_IF_CUDNN_ERROR(tensor); return tensor; -#else - return tsl::errors::Internal("Not implemented."); -#endif // CUDNN_VERSION >= 8900 } -#if CUDNN_VERSION >= 8800 - absl::StatusOr CreatePwDesc( dnn::DataType dtype, cudnnPointwiseMode_t mode) { auto pw_desc_created = cudnn_frontend::PointWiseDescBuilder() @@ -3821,7 +3439,6 @@ absl::StatusOr CreateTernaryPwOp( RETURN_MSG_IF_CUDNN_ERROR(pw_op_created); return pw_op_created; } -#endif // CUDNN_VERSION >= 8800 absl::StatusOr> GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -3868,31 +3485,17 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, std::vector filter_strides = filter_descriptor.vectorized_strides( dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim); -#if CUDNN_VERSION >= 8800 cudnnBackendTensorReordering_t tensor_ordering_type = filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX32_CudnnReordered ? CUDNN_TENSOR_REORDERING_INT8x32 : CUDNN_TENSOR_REORDERING_NONE; -#else - bool is_reordered_nchw_vect = - filter_descriptor.layout() == - dnn::FilterLayout::kOutputInputYX32_CudnnReordered; -#endif -#if CUDNN_VERSION >= 8800 TF_ASSIGN_OR_RETURN( auto tensor_w, CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, vector_size, vector_dim, /*is_virtual=*/false, tensor_ordering_type)); -#else - TF_ASSIGN_OR_RETURN( - auto tensor_w, - CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, - vector_size, vector_dim, - /*is_virtual=*/false, is_reordered_nchw_vect)); -#endif // conv_desc. auto mode = convolution_descriptor.convolution_not_crosscorr() @@ -4003,15 +3606,15 @@ OpNameStringToOperandKindAndMode(std::string opstring) { // Struct describing the convolution, pointwise and reduction ops in the // graph. struct OpDescriptor { - int uid; // The UID of the op. - std::vector operand_uids; // The UIDs of the operands of the op that - // are part of the graph. - OpMode mode; // The mode describing the op. - TensorKind operand_kind; // The kind of a second operand. - TensorKind result_kind; // The kind of the output. - dnn::DataType result_type; // The type of the output. - bool is_virtual; // A virtual op has a user within the graph. - int sequence_index; // The index of the op in the sequence. + int uid; // The UID of the op. + std::vector operand_uids; // The UIDs of the operands of the op that + // are part of the graph. + OpMode mode; // The mode describing the op. + TensorKind operand_kind; // The kind of a second operand. + TensorKind result_kind; // The kind of the output. + dnn::DataType result_type; // The type of the output. + bool is_virtual; // A virtual op has a user within the graph. + int sequence_index; // The index of the op in the sequence. 
}; // Class describing the graph of ops to be fused into the cuDNN convolution @@ -4498,33 +4101,18 @@ GetCudnnFusedOperationGraph( std::vector filter_strides = filter_descriptor.vectorized_strides( dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim); -#if CUDNN_VERSION >= 8800 cudnnBackendTensorReordering_t tensor_ordering_type = filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX32_CudnnReordered ? CUDNN_TENSOR_REORDERING_INT8x32 : CUDNN_TENSOR_REORDERING_NONE; -#else - bool is_reordered_nchw_vect = - filter_descriptor.layout() == - dnn::FilterLayout::kOutputInputYX32_CudnnReordered; -#endif -#if CUDNN_VERSION >= 8800 TF_ASSIGN_OR_RETURN( auto tensor_w, CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, vector_size, vector_dim, /*is_virtual=*/false, tensor_ordering_type)); // cuDNN 8.3 fails here -#else - TF_ASSIGN_OR_RETURN( - auto tensor_w, - CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, - vector_size, vector_dim, - /*is_virtual=*/false, - is_reordered_nchw_vect)); // cuDNN 8.3 fails here -#endif // For the purposes of the cudnn graph, say that the bias tensor has the same // layout as the output tensor. It doesn't actually matter, because bias is a @@ -4558,15 +4146,9 @@ GetCudnnFusedOperationGraph( // kFloat). If it's not, then cuDNN silently does the reordering under the // hood, which yields incorrect results as we already do the reordering // ourselves. - auto maybe_tensor_b = CreateCudnnTensor(bias_dims, bias_strides, 'b', - bias_type, vector_size, vector_dim, - /*is_virtual=*/false, -#if CUDNN_VERSION >= 8800 - tensor_ordering_type -#else - is_reordered_nchw_vect -#endif - ); // cuDNN 8.3 fails here + auto maybe_tensor_b = CreateCudnnTensor( + bias_dims, bias_strides, 'b', bias_type, vector_size, vector_dim, + /*is_virtual=*/false, tensor_ordering_type); if (!maybe_tensor_b.ok()) { maybe_tensor_b = CreateCudnnTensor(bias_dims, bias_strides, 'b', bias_type, vector_size, vector_dim); @@ -4902,7 +4484,6 @@ GetCudnnFusedMatmulGraph(dnn::DataType input_type, dnn::DataType bias_type, static absl::StatusOr GetExecPlanFromHeuristics( cudnn_frontend::OperationGraph&& opGraph, const CudnnHandle& cudnn, bool include_fallback_heuristics = false) { -#if (CUDNN_VERSION >= 8800) cudnn_frontend::EngineConfigList engine_configs; if (!include_fallback_heuristics) { cudnn_frontend::get_heuristics_list<1>( @@ -4940,9 +4521,6 @@ static absl::StatusOr GetExecPlanFromHeuristics( LOG(FATAL) << "Failed to generate cuDNN execution plan for opGraph " << opGraph.getTag() << ". absl::Status of final plan: " << CudnnStatusToString(status); -#else - return absl::UnimplementedError("Supported only for cuDNN >= 8.8.0"); -#endif } static absl::StatusOr RebuildExecutionPlan( @@ -4997,8 +4575,6 @@ static absl::StatusOr RebuildExecutionPlan( return {std::move(plan)}; } -#endif // CUDNN_VERSION >= 8100 - } // namespace void FixDimsForRaggedOffset(std::vector& dims, int max_reg_per_batch) { @@ -6076,16 +5652,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { } const auto get_fwd_bugs = [&]() -> absl::Status { -#if CUDNN_VERSION < 8000 - if (algo_id_ == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM && - ToCudnnDataType(input_type_) == CUDNN_DATA_INT8 && - ToCudnnDataType(output_type_) == CUDNN_DATA_FLOAT) { - return absl::FailedPreconditionError( - "This configuration potentially produces incorrect results."); - } -#else (void)output_type_; // To stop clang-tidy saying it's unused. 
-#endif return absl::OkStatus(); }; @@ -6298,7 +5865,6 @@ class ScalingParam { dnn::DataType default_target_dtype_; }; -#if CUDNN_VERSION >= 8100 struct BackendDescriptorDeleter { void operator()(cudnnBackendDescriptor_t desc) { cudnnBackendDestroyDescriptor(desc); @@ -6703,7 +6269,6 @@ absl::Status CreateOpRunners( } } // namespace -#endif // CUDNN_VERSION >= 8100 absl::Status CudnnSupport::GetConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, @@ -6718,20 +6283,15 @@ absl::Status CudnnSupport::GetConvolveRunners( ScratchAllocator* /*scratch_allocator*/, const NumericOptions& numeric_options, std::vector>* out_exec_plans) { - // cuDNN frontend support became sufficiently stable to use in 8.1. - // TODO(awpr): remove this condition once support for cuDNN 8.0 is dropped. - const bool is_pre_frontend_cudnn = CUDNN_VERSION < 8100; - // cuDNN frontend support for Tx32 convolutions added in 8.3. // If the filter is not reordered, do not use frontend (it is slow). const bool is_disabled_x32 = input_descriptor.layout() == dnn::kBatchDepthYX32 && - (CUDNN_VERSION < 8300 || - filter_descriptor.layout() != - dnn::FilterLayout::kOutputInputYX32_CudnnReordered); + (filter_descriptor.layout() != + dnn::FilterLayout::kOutputInputYX32_CudnnReordered); const bool actually_use_cudnn_frontend = - use_cudnn_frontend && !is_pre_frontend_cudnn && !is_disabled_x32; + use_cudnn_frontend && !is_disabled_x32; if (use_cudnn_frontend && !actually_use_cudnn_frontend) { // This will happen once per unique conv configuration/shape that gets @@ -6796,7 +6356,6 @@ absl::Status CudnnSupport::GetConvolveRunners( return absl::OkStatus(); } -#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( auto op_graph, @@ -6808,10 +6367,6 @@ absl::Status CudnnSupport::GetConvolveRunners( stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind, input_type, {'x', 'w', 'y'}, use_fallback, out_exec_plans, /*need_side_input=*/false, numeric_options); -#else - return tsl::errors::Unimplemented( - "Cudnn execution plans are only supported with Cudnn >= 8.1."); -#endif // CUDNN_VERSION >= 8100 } absl::Status CudnnSupport::GetGraphConvolveRunners( @@ -6878,7 +6433,6 @@ CudnnSupport::ConvolveRunnerFromDesc( return {std::make_unique(std::move(runner))}; } -#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( @@ -6897,10 +6451,6 @@ CudnnSupport::ConvolveRunnerFromDesc( /*need_side_input=*/false)); return {std::make_unique>( std::move(runner))}; -#else - return tsl::errors::Unimplemented( - "Cudnn execution plans are only supported with Cudnn >= 8.1."); -#endif } absl::StatusOr> @@ -6917,7 +6467,6 @@ CudnnSupport::GraphConvolveRunnerFromDesc( "cuDNN graph execution requires the use of the cuDNN frontend."); } -#if CUDNN_VERSION >= 8900 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( @@ -6937,10 +6486,6 @@ CudnnSupport::GraphConvolveRunnerFromDesc( /*need_side_input=*/false)); return {std::make_unique>( std::move(runner))}; -#else - return tsl::errors::Unimplemented( - "cuDNN graph execution requires cuDNN version 8.9 or higher."); -#endif } class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { @@ -7174,7 +6719,6 @@ CudnnSupport::FusedConvolveRunnerFromDesc( return {std::make_unique(std::move(runner))}; } -#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN(auto op_graph, @@ -7195,10 +6739,6 @@ CudnnSupport::FusedConvolveRunnerFromDesc( {'x', 'w', 'z', 'b', 'y'}, 
need_side_input)); return {std::make_unique>( std::move(runner))}; -#else - return tsl::errors::Unimplemented( - "Cudnn execution plans are only supported with Cudnn >= 8.1."); -#endif } absl::Status CudnnSupport::GetFusedConvolveRunners( @@ -7217,35 +6757,20 @@ absl::Status CudnnSupport::GetFusedConvolveRunners( // Fused convolutions with identity activations are broken in that they // implicitly do ReLU on some engines, and we can't reliably detect which // ones. - const bool is_broken_identity_fused_conv = -#if CUDNN_VERSION < 8205 - activation_mode == dnn::ActivationMode::kNone; -#else - false; -#endif - // cuDNN frontend support became sufficiently stable to use in 8.1. - // TODO(awpr): remove this condition once support for cuDNN 8.0 is dropped. - const bool is_pre_frontend_cudnn = CUDNN_VERSION < 8100; - - // cuDNN frontend support for Tx32 convolutions added in 8.3. // If the filter is not reordered, do not use frontend (it is slow). const bool is_disabled_x32 = input_descriptor.layout() == dnn::kBatchDepthYX32 && - (CUDNN_VERSION < 8300 || - filter_descriptor.layout() != - dnn::FilterLayout::kOutputInputYX32_CudnnReordered); + (filter_descriptor.layout() != + dnn::FilterLayout::kOutputInputYX32_CudnnReordered); const bool actually_use_cudnn_frontend = - use_cudnn_frontend && !is_pre_frontend_cudnn && - !is_broken_identity_fused_conv && !is_disabled_x32; + use_cudnn_frontend && !is_disabled_x32; if (use_cudnn_frontend && !actually_use_cudnn_frontend) { const char* reason = "the current cuDNN version does not support it."; if (is_disabled_x32) { reason = "Tx32 convolutions are disabled."; - } else if (is_broken_identity_fused_conv) { - reason = "it uses an identity activation."; } // This will happen once per unique conv configuration/shape that gets @@ -7266,14 +6791,6 @@ absl::Status CudnnSupport::GetFusedConvolveRunners( "on GPUs with compute capability 6.1 or later."); } - if (input_type == dnn::DataType::kInt8 && - output_type == dnn::DataType::kFloat && - (CUDNN_VERSION >= 8000 && CUDNN_VERSION <= 8200)) { - return tsl::errors::Unimplemented( - "int8 -> float fused conv is disabled for this cuDNN version. 
See " - "go/nvbugs/3326122"); - } - if (activation_mode != dnn::ActivationMode::kRelu && activation_mode != dnn::ActivationMode::kRelu6 && activation_mode != dnn::ActivationMode::kElu && @@ -7317,7 +6834,6 @@ absl::Status CudnnSupport::GetFusedConvolveRunners( return absl::OkStatus(); } -#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); auto op_graph_status = GetCudnnFusedOperationGraph( kind, input_type, bias_type, output_type, conv_scale, side_input_scale, @@ -7335,10 +6851,6 @@ absl::Status CudnnSupport::GetFusedConvolveRunners( stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind, input_type, {'x', 'w', 'z', 'b', 'y'}, use_fallback, out_exec_plans, need_side_input, numeric_options); -#else - return tsl::errors::Unimplemented( - "Cudnn execution plans are only supported with Cudnn >= 8.1."); -#endif // CUDNN_VERSION >= 8100 } absl::Status CudnnSupport::GetFusedMatmulRunners( @@ -7349,7 +6861,6 @@ absl::Status CudnnSupport::GetFusedMatmulRunners( const NumericOptions& numeric_options, std::vector>* out_exec_plans) { -#if CUDNN_VERSION >= 8400 if (!use_cudnn_frontend) { return tsl::errors::Unimplemented( "Cudnn execution plans for matmul are only supported with cudnn " @@ -7374,11 +6885,6 @@ absl::Status CudnnSupport::GetFusedMatmulRunners( stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), dnn::ConvolutionKind::INVALID, input_type, {'a', 'b', 'z', 'c'}, use_fallback, out_exec_plans, /*need_side_input=*/true, numeric_options); -#else - return tsl::errors::Unimplemented( - "Cudnn execution plans for matmul are only supported with Cudnn >= " - "8.4."); -#endif // CUDNN_VERSION >= 8400 } bool CudnnSupport::GetConvolveAlgorithms( @@ -7765,7 +7271,6 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( DeviceMemory workspace; DeviceMemory reserve_space; -#if CUDNN_VERSION >= 7402 const auto get_bn_ops = [&]() -> cudnnBatchNormOps_t { if (side_input.is_null()) { return activation_mode == dnn::ActivationMode::kNone @@ -7800,7 +7305,6 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( reserve_space_size_in_bytes)); } } -#endif auto check_no_side_input_or_activation = [&]() -> absl::Status { if (activation_mode != dnn::ActivationMode::kNone || @@ -7833,7 +7337,6 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( } bool called = false; -#if CUDNN_VERSION >= 7402 if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) { called = true; RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTrainingEx( @@ -7863,7 +7366,6 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( /*reserveSpace=*/reserve_space.opaque(), /*reserveSpaceSizeInBytes=*/reserve_space.size())); } -#endif if (!called) { TF_RETURN_IF_ERROR(check_no_side_input_or_activation()); RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining( @@ -7976,7 +7478,6 @@ absl::Status CudnnSupport::DoBatchNormalizationBackwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); bool called = false; -#if CUDNN_VERSION >= 7402 if (reserve_space_data != nullptr && workspace_allocator != nullptr) { called = true; const cudnnBatchNormOps_t bn_ops = [&]() { @@ -8031,7 +7532,6 @@ absl::Status CudnnSupport::DoBatchNormalizationBackwardImpl( /*reserveSpace=*/reserve_space_data->opaque(), /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); } -#endif auto check_no_side_input_or_activation = [&]() -> absl::Status { if (activation_mode != dnn::ActivationMode::kNone || !side_input_backprop->is_null()) { @@ -8079,14 +7579,6 @@ absl::Status 
CudnnSupport::DoFusedConvolve( "on GPUs with compute capability 6.1 or later."); } - if (input_type == dnn::DataType::kInt8 && - output_type == dnn::DataType::kFloat && - (CUDNN_VERSION >= 8000 && CUDNN_VERSION <= 8200)) { - return tsl::errors::Unimplemented( - "int8 -> float fused conv is disabled for this cuDNN version. See " - "go/nvbugs/3326122"); - } - if (activation_mode != dnn::ActivationMode::kRelu && activation_mode != dnn::ActivationMode::kNone) { return absl::InvalidArgumentError( @@ -8177,7 +7669,6 @@ absl::Status CudnnSupport::DoPrepareForCtcLoss( auto cudnn = cudnn_->GetHandle(parent_, stream); // Query the workspace size. size_t workspace_size_in_bytes = 0; -#if CUDNN_VERSION >= 7603 CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type)); const CudnnRnnStateTensorDescriptor& cudnn_probs_desc = static_cast(probs_desc); @@ -8216,11 +7707,6 @@ absl::Status CudnnSupport::DoPrepareForCtcLoss( /*sizeInBytes=*/&workspace_size_in_bytes)); } *ctc_loss_algo_id = algo; -#else - return absl::InvalidArgumentError( - "No supported cudnnGetCTCLossWorkspaceSize when " - "CUDNN_VERSION < 7.6.3"); -#endif // Allocate the workspace. if (workspace_size_in_bytes == 0) { *scratch_memory = DeviceMemory(); @@ -8246,10 +7732,10 @@ absl::Status CudnnSupport::DoCtcLoss( DeviceMemoryBase grads_data, DeviceMemory scratch_memory, int ctc_loss_algo_id) { // Current cuDNN CTC Loss only supports the float datatype - if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) { + if (element_type != dnn::DataType::kFloat) { return absl::InvalidArgumentError( "CudnnCtcLossDescriptor is supported only when the " - "CUDNN_VERSION >= 7.6.3 and DataType is float"); + "DataType is float"); } CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type)); const CudnnRnnStateTensorDescriptor& cudnn_probs_desc = @@ -8605,8 +8091,6 @@ bool CudnnSupport::DeriveOutputBatchDescriptor( return IsStatusOk(status, /*report_error=*/true); } -#if CUDNN_VERSION >= 8100 - absl::StatusOr> CudnnSupport::DeserializeGraph( Stream& stream, absl::string_view serialized_data) const { auto cudnn = cudnn_->GetHandle(stream.parent(), &stream); @@ -8663,16 +8147,10 @@ absl::Status CudnnGraph::Execute(Stream& stream, } if (dropout_rng_offset_increment_ > 0) { -#if CUDNN_VERSION >= 8800 UpdateDropoutState(local_device_ordinal); tensor_to_ptr_map[next_uid()] = (void*)&dropout_rng_seed_; tensor_to_ptr_map[next_uid()] = (void*)¤t_dropout_rng_offset_[local_device_ordinal]; -#else - return absl::UnimplementedError( - "Cudnn dropout offset and seed are only supported with Cudnn >= " - "8.8.0"); -#endif // CUDNN_VERSION >= 8800 } const CudnnSupport& dnn_support = @@ -8682,7 +8160,6 @@ absl::Status CudnnGraph::Execute(Stream& stream, graph_.execute(cudnn.handle(), tensor_to_ptr_map, workspace.opaque())); } -#endif // CUDNN_VERSION >= 8100 } // namespace gpu void initialize_cudnn() { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index e3a45ef761e37a..4e214cf4e11ac3 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -38,9 +38,7 @@ limitations under the License. 
#include "xla/stream_executor/stream_executor.h" #include "xla/tsl/protobuf/dnn.pb.h" -#if CUDNN_VERSION >= 8100 #include "third_party/cudnn_frontend/include/cudnn_frontend.h" -#endif // CUDNN_VERSION >= 8100 namespace stream_executor { namespace gpu { @@ -55,7 +53,6 @@ using BatchDescriptorSlice = absl::Span; template using DeviceMemorySlice = absl::Span* const>; -#if CUDNN_VERSION >= 8100 class CudnnGraph : public dnn::DnnGraph { public: explicit CudnnGraph(cudnn_frontend::graph::Graph&& graph) @@ -85,7 +82,6 @@ class CudnnGraph : public dnn::DnnGraph { mutable std::vector current_dropout_rng_offset_; int64_t dropout_rng_offset_increment_ = 0; }; -#endif // CUDNN_VERSION >= 8100 // cudnn-library based DNN support. For details on overridden interface // functions, see dnn.h. @@ -555,11 +551,9 @@ class CudnnSupport : public dnn::DnnSupport { void NotifyStreamDestroyed(Stream* stream) override; -#if CUDNN_VERSION >= 8100 // Loads complete graph from its serialized representation. absl::StatusOr> DeserializeGraph( Stream& stream, absl::string_view serialized_data) const override; -#endif // CUDNN_VERSION >= 8100 private: // Uses cuDNN handle for execution. diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb b/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb index 6de868f0d19899..4f58cc288f6a49 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb +++ b/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb @@ -37,7 +37,7 @@ gpu_device_info { platform_name: "CUDA" dnn_version_info { major: 8 - minor: 3 - patch: 2 + minor: 9 + patch: 4 } device_description_str: "A100 80GB" diff --git a/third_party/xla/xla/tsl/util/use_cudnn.cc b/third_party/xla/xla/tsl/util/use_cudnn.cc index ad94bb5ac13faf..cc9bc0987cd585 100644 --- a/third_party/xla/xla/tsl/util/use_cudnn.cc +++ b/third_party/xla/xla/tsl/util/use_cudnn.cc @@ -49,12 +49,10 @@ bool CudnnUseRuntimeFusion() { static bool result = [] { bool value = false; #if GOOGLE_CUDA - if (CUDNN_VERSION >= 8400) { - absl::Status status = - ReadBoolFromEnvVar("TF_CUDNN_USE_RUNTIME_FUSION", false, &value); - if (!status.ok()) { - LOG(ERROR) << status; - } + absl::Status status = + ReadBoolFromEnvVar("TF_CUDNN_USE_RUNTIME_FUSION", false, &value); + if (!status.ok()) { + LOG(ERROR) << status; } #endif // GOOGLE_CUDA return value; From 8cd432fd5b4d950440ba65a6b84d31bb6b316290 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 4 Apr 2025 01:58:54 -0700 Subject: [PATCH 0247/1324] [XLA:GPU] Enable `TritonGemmTest.DoNotUseTensorCoresWithNonDefaultPrecision` in `fusion_emitter_device_legacy_port_test.cc`. PiperOrigin-RevId: 743861172 --- .../fusion_emitter_device_legacy_port_test.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index 13398805cf3e43..a8603e0b9b89d3 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include #include #include #include @@ -69,7 +68,7 @@ namespace m = ::xla::match; using tsl::testing::StatusIs; struct ModuleAndNestedFusionMetadata { - std::unique_ptr module; + std::unique_ptr module; HloComputation* computation; BlockLevelParameters block_level_parameters; }; @@ -106,11 +105,12 @@ class TritonTest : public GpuCodegenTest { } } - // Returns the computation and block level parameters from an HLO module text - // whose entry computation contains a single GEMM fusion. + // Returns the module, its fusion computation and associated block level + // parameters from an HLO module text whose entry computation contains a + // single GEMM fusion. absl::StatusOr GetModuleAndNestedFusionMetadata(absl::string_view hlo_text) { - TF_ASSIGN_OR_RETURN(std::unique_ptr module, + TF_ASSIGN_OR_RETURN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); TF_ASSIGN_OR_RETURN( bool fusion_was_nested, @@ -361,7 +361,7 @@ CHECK-DAG: tt.make_tensor_ptr %[[DYNAMIC_SLICE_INPUT]], [%[[C2_i64]], %[[ROW_L tsl::testing::IsOk()); } -TEST_F(TritonGemmTest, DISABLED_DoNotUseTensorCoresWithNonDefaultPrecision) { +TEST_F(TritonGemmTest, DoNotUseTensorCoresWithNonDefaultPrecision) { constexpr absl::string_view kHloText = R"( triton_gemm_r { parameter_0 = s8[80,15]{1,0} parameter(0) @@ -382,10 +382,10 @@ ENTRY e { "split_k":1,"num_stages":1,"num_warps":2, "num_ctas":1}}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, - ParseAndReturnVerifiedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloText)); - CompileAndOptionallyVerifyPtx(std::move(verified_module), + CompileAndOptionallyVerifyPtx(std::move(module_and_metadata.module), R"( CHECK-NOT: mma )"); From fd28e20972ae5c7f0688e08f3ec58cf94633d4a5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Apr 2025 02:02:56 -0700 Subject: [PATCH 0248/1324] Update GraphDef version to 2187. PiperOrigin-RevId: 743862228 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 73057f3f941700..bd134534929aad 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2186 // Updated: 2025/4/3 +#define TF_GRAPH_DEF_VERSION 2187 // Updated: 2025/4/4 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 2765e59402df1d4e024d79045ba501947f51a0c6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Apr 2025 02:03:00 -0700 Subject: [PATCH 0249/1324] compat: Update forward compatibility horizon to 2025-04-04 PiperOrigin-RevId: 743862253 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 6f4f7243d7f4d1..d6f9f971055063 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
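# For example, a guard like `compat.forward_compatible(2025, 4, 4)` keeps
# returning False until this horizon moves past that date, at which point the
# guarded code path turns on.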
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 3) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 4) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 60fb772317a9e98d184c6d901289775960dccbb5 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Fri, 4 Apr 2025 14:35:37 -0700 Subject: [PATCH 0250/1324] Update profiler pin to (hopefully) fix builds PiperOrigin-RevId: 744073081 --- tensorflow/workspace2.bzl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 82491ab8b087df..f43770537af437 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -937,9 +937,9 @@ def _tf_repositories(): tf_http_archive( name = "org_xprof", - sha256 = "88bc65694f79f266e16269da73b5b9238db1552175d1cd75bc08c7337377ab0d", - strip_prefix = "profiler-5d90906294ecbd83639b583fc926cbedc06e60dc", - urls = tf_mirror_urls("https://github.com/tensorflow/profiler/archive/5d90906294ecbd83639b583fc926cbedc06e60dc.zip"), + sha256 = "dec4889a6a5123fca0a775ba20f22717b2d0c3af1491f41bb52e1b502595271e", + strip_prefix = "xprof-c3dbeb2c69b48163c6156d6f4a8c82ac34736f49", + urls = tf_mirror_urls("https://github.com/tensorflow/profiler/archive/c3dbeb2c69b48163c6156d6f4a8c82ac34736f49.zip"), ) # used for adding androidx.annotation dependencies in tflite android jni. From 76e3683c4e26a72854300189452dbef21a9ee4cb Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 4 Apr 2025 15:17:00 -0700 Subject: [PATCH 0251/1324] Fix element type mismatch for SPMD DUS indices PiperOrigin-RevId: 744085265 --- third_party/xla/xla/service/spmd/spmd_partitioner.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index f6333c3f530b21..b1c2d3ab974a22 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -3761,8 +3761,11 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice( if (slice_size == 1) { partitioned_slice_offsets.push_back(-1); } else { + const PrimitiveType elemType = + hlo->operand(i + 2)->shape().element_type(); partitioned_slice_offsets.push_back( - hlo->operand(i + 2)->literal().Get({})); + elemType == S64 ? 
hlo->operand(i + 2)->literal().Get<int64_t>({}) + : hlo->operand(i + 2)->literal().Get<int32_t>({})); } } } else if (hlo->sharding().tile_assignment().dim(i) != 1) { From dd5064300419f85a890b2bc616c954fe0a3d8dc0 Mon Sep 17 00:00:00 2001 From: Chen Li Date: Fri, 4 Apr 2025 15:18:11 -0700 Subject: [PATCH 0252/1324] Reverts 9bdeda24e58b39852225c69b0ef7218835fa5bbf PiperOrigin-RevId: 744085605 --- third_party/xla/xla/debug_options_flags.cc | 7 + third_party/xla/xla/service/gpu/BUILD | 2 + .../xla/xla/service/gpu/backend_configs.proto | 9 +- .../xla/service/gpu/backend_configs_test.cc | 1 + .../xla/xla/service/gpu/gpu_hlo_schedule.cc | 2 + .../xla/service/gpu/gpu_hlo_schedule_test.cc | 1 + .../xla/xla/service/gpu/transforms/BUILD | 37 ++++ .../command_buffer_scheduling_test.cc | 16 +- .../gpu/transforms/schedule_postprocessing.cc | 158 +++++++++++++++++ .../gpu/transforms/schedule_postprocessing.h | 50 ++++++ .../schedule_postprocessing_test.cc | 163 ++++++++++++++++++ third_party/xla/xla/xla.proto | 4 +- 12 files changed, 438 insertions(+), 12 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc create mode 100644 third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h create mode 100644 third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index e8547150564a7b..2c44ea0f19636d 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -253,6 +253,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_split_k_autotuning(true); opts.set_xla_gpu_enable_reduction_epilogue_fusion(true); + opts.set_xla_gpu_enable_nccl_clique_optimization(false); opts.set_xla_gpu_cublas_fallback(true); opts.set_xla_gpu_cudnn_gemm_fusion_level(0); opts.set_xla_gpu_enable_while_loop_double_buffering(false); @@ -1918,6 +1919,12 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list, &DebugOptions::set_xla_gpu_enable_reduction_epilogue_fusion), debug_options->xla_gpu_enable_reduction_epilogue_fusion(), "Enable fusion for reduction epilogues")); + flag_list->push_back( + tsl::Flag("xla_gpu_enable_nccl_clique_optimization", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_nccl_clique_optimization), + debug_options->xla_gpu_enable_nccl_clique_optimization(), + "Allow early return when acquiring NCCL cliques")); flag_list->push_back( tsl::Flag("xla_gpu_cublas_fallback", bool_setter_for(&DebugOptions::set_xla_gpu_cublas_fallback), diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index ed3775413cbbce..ab2b73adf48be0 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2158,6 +2158,7 @@ cc_library( "//xla/service/gpu/model:sol_latency_estimator", "//xla/service/gpu/transforms:async_collective_annotator", "//xla/service/gpu/transforms:pgle_accuracy_checker", + "//xla/service/gpu/transforms:schedule_postprocessing", "//xla/service/gpu/transforms:scheduling_instruction_annotator", "//xla/service/gpu/transforms/collectives:collective_ops_utils", "//xla/stream_executor:device_description", @@ -2200,6 +2201,7 @@ xla_test( "//xla/service:hlo_module_config", "//xla/service:latency_hiding_scheduler", "//xla/service:legalize_scheduling_annotations", + "//xla/service/gpu/transforms:schedule_postprocessing", "//xla/service/gpu/transforms:scheduling_instruction_annotator", "//xla/stream_executor:device_description", 
"//xla/tests:hlo_test_base", diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 37ad0b6e615ab1..515f672acbf3a9 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -118,16 +118,19 @@ message BitcastBackendConfig { // Backend config for async collective operations. Note that for is_sync will // be false by default, so even if a backend config is not explicitly attached // to the HLOInstruction, getting the backend_config will yield a default valued -// proto which will have is_sync = false. +// proto which will have is_sync = false. Attribute no_parallel_custom_call +// asserts that an asynchronous collective operation does not execute in +// parallel with custom-calls, which can trigger device synchronization . This +// attribute will also be false by default and should lead to conversative +// runtime behavior. message CollectiveBackendConfig { bool is_sync = 1; + bool no_parallel_custom_call = 2; // Determines whether the collective op of interested has been pipelined // within a loop. bool is_pipelined = 3; // Cost model prediction. repeated ReificationCost reification_cost = 4; - - reserved 2; } // Backend config for cost model estimates. diff --git a/third_party/xla/xla/service/gpu/backend_configs_test.cc b/third_party/xla/xla/service/gpu/backend_configs_test.cc index 16f05964536e71..7883547f077dcb 100644 --- a/third_party/xla/xla/service/gpu/backend_configs_test.cc +++ b/third_party/xla/xla/service/gpu/backend_configs_test.cc @@ -59,6 +59,7 @@ TEST_F(BackendConfigsTest, DefaultCollectiveBackendConfig) { const auto& collective_backend_config = gpu_config.collective_backend_config(); EXPECT_THAT(collective_backend_config.is_sync(), IsFalse()); + EXPECT_THAT(collective_backend_config.no_parallel_custom_call(), IsFalse()); } TEST_F(BackendConfigsTest, DefaultGpuBackendConfigParseOpQueue) { diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index aacd6d264475fa..c050fad9b8d68a 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -59,6 +59,7 @@ limitations under the License. #include "xla/service/gpu/transforms/async_collective_annotator.h" #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" #include "xla/service/gpu/transforms/pgle_accuracy_checker.h" +#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include "xla/service/hlo_module_config.h" #include "xla/service/latency_hiding_scheduler.h" @@ -595,6 +596,7 @@ absl::Status RunLatencyHidingSchedulerPasses( std::move(estimator), std::move(async_tracker), std::move(scheduler_core), shape_size_in_bytes); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(module).status(); } diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index 375210385f89e7..f2d3c1ebbedfda 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "xla/service/backend.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" +#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include "xla/service/hlo_module_config.h" #include "xla/service/latency_hiding_scheduler.h" diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 768b37554a6632..dfd5bb6144e019 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2834,6 +2834,43 @@ xla_cc_test( ], ) +cc_library( + name = "schedule_postprocessing", + srcs = ["schedule_postprocessing.cc"], + hdrs = ["schedule_postprocessing.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu/transforms/collectives:collective_ops_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "schedule_postprocessing_test", + srcs = ["schedule_postprocessing_test.cc"], + deps = [ + ":schedule_postprocessing", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "scheduling_instruction_annotator", srcs = ["scheduling_instruction_annotator.cc"], diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index e74ff34ec0d625..3c60b7f6ec2255 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -209,7 +209,7 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) { %a = s32[4] parameter(0) %start = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true}} + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start) })"; @@ -242,7 +242,7 @@ TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) { %start = (s32[2]{0}, s32[4]{0}) all-gather-start(%a), channel_id=555, replica_groups={{0,1}}, dimensions={0}, - backend_config={"collective_backend_config": {"is_sync":true}} + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} ROOT %done = s32[4]{0} all-gather-done(%start) })"; @@ -282,7 +282,7 @@ TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) { %start = ((s32[4]{0}), s32[2]{0}) reduce-scatter-start(%a), channel_id=555, replica_groups={{0,1}}, dimensions={0}, to_apply=add, - backend_config={"collective_backend_config": {"is_sync":true}} + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} ROOT %done = s32[2]{0} reduce-scatter-done(%start) })"; 
@@ -321,7 +321,7 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByBitcast) { %a = s32[4] parameter(0) %start = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true}} + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} %bitcast = s32[4] bitcast(s32[4]{0} %a) ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start) })"; @@ -361,10 +361,10 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedAllReduceStart) { %a = s32[4] parameter(0) %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true}} + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true}} + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1) ROOT %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2) })"; @@ -418,11 +418,11 @@ TEST_F(CommandBufferSchedulingTest, DoNotCaptureUnmatchedAsyncDone) { %b = s32[] parameter(1) %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true}} + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} %c = s32[] custom-call(), custom_call_target="target" %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true}} + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1) %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2) %fusion = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc new file mode 100644 index 00000000000000..0fd39a27d0e3ba --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc @@ -0,0 +1,158 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/transforms/schedule_postprocessing.h" + +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { +// Maps a computation to a boolean that indicates whether the computation may +// invoke custom-calls directly or indirectly, which can eventually trigger gpu +// synchronization. +using CustomCallInComputation = + absl::flat_hash_map<const HloComputation*, bool>; + +// Returns whether the hlo may invoke custom-calls which may trigger gpu +// synchronization. Currently, we only check for custom-calls, because they are +// the only operations that can be parallel with asynchronous collective +// operations in an hlo-schedule and may trigger gpu synchronization. +bool MayInvokeCustomCall( + const HloInstruction* hlo, + const CustomCallInComputation& custom_call_in_computation) { + if (HloPredicateIsOp<HloOpcode::kCustomCall>(hlo)) { + return true; + } + + return absl::c_any_of( + hlo->called_computations(), [&](const HloComputation* callee) { + return custom_call_in_computation.find(callee)->second; + }); +} + +// Returns true if this is an asynchronous collective start operation, excluding +// P2P operations. +bool IsRelevantAsynchronousStart(const HloInstruction* hlo) { + return hlo_query::IsAsyncCollectiveStartOp(hlo, + /*include_send_recv=*/false) && + !IsGPUSyncCollective(*hlo); +} + +// Returns true if this is a collective done operation, excluding P2P +// operations. +bool IsRelevantAsynchronousDone(const HloInstruction* hlo) { + return hlo_query::IsAsyncCollectiveDoneOp(hlo, + /*include_send_recv=*/false); +} + +// For a given computation, finds all the asynchronous collective operations +// that aren't parallel with custom-calls and sets their no_parallel_custom_call +// attribute to true. Also records whether the given computation may invoke +// custom-calls. +absl::StatusOr<bool> ProcessComputation( + const HloSchedule& schedule, HloComputation* computation, + CustomCallInComputation& custom_call_in_computation) { + bool changed = false; + bool has_custom_call = false; + absl::flat_hash_set<HloInstruction*> async_starts; + const HloInstructionSequence& sequence = schedule.sequence(computation); + + // Visit instructions in the sequence. Collect relevant asynchronous + // collective start ops. When we see a relevant asynchronous collective done + // op, remove the corresponding start op from the collection and set its + // attribute no_parallel_custom_call to true. When we see a custom-call, clear + // the start ops from the collection and keep their attribute + // no_parallel_custom_call as false. 
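+ // For example, in the schedule {start, custom-call, done} the custom-call
+ // clears the collection first, so `start` keeps the default value false;
+ // in {start, done, custom-call} the pair is matched before the custom-call
+ // is seen and the attribute is set to true.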
+  const std::vector<HloInstruction*>& all_instructions = + sequence.instructions(); + for (HloInstruction* hlo : all_instructions) { + if (MayInvokeCustomCall(hlo, custom_call_in_computation)) { + async_starts.clear(); + has_custom_call = true; + continue; + } + if (IsRelevantAsynchronousStart(hlo)) { + async_starts.insert(hlo); + continue; + } + + if (IsRelevantAsynchronousDone(hlo)) { + HloInstruction* async_start = hlo->mutable_operand(0); + if (async_starts.contains(async_start)) { + changed = true; + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + async_start->backend_config<GpuBackendConfig>()); + CollectiveBackendConfig& collective_backend_config = + *gpu_config.mutable_collective_backend_config(); + collective_backend_config.set_no_parallel_custom_call(true); + TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config)); + async_starts.erase(async_start); + } + } + } + + custom_call_in_computation[computation] = has_custom_call; + return changed; +} + +} // anonymous namespace + +absl::StatusOr<bool> SchedulePostprocessing::Run( + HloModule* module, + const absl::flat_hash_set<absl::string_view>& execution_threads) { + if (!module->has_schedule()) return false; + HloSchedule& schedule = module->schedule(); + bool changed = false; + CustomCallInComputation custom_call_in_computation; + + // We visit computations in the order of callees to callers, as information is + // propagated from callees to callers. + std::vector<HloComputation*> all_computations = + module->MakeComputationPostOrder(execution_threads); + for (auto iter = all_computations.begin(); iter != all_computations.end(); + ++iter) { + HloComputation* computation = *iter; + if (computation->IsFusionComputation()) { + custom_call_in_computation[computation] = false; + continue; + } + + TF_ASSIGN_OR_RETURN( + bool result, + ProcessComputation(schedule, computation, custom_call_in_computation)); + changed |= result; + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h new file mode 100644 index 00000000000000..d76faed7d260cc --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Amends a schedule result with the needed information to support a runtime +// implementation. 
Currently, this pass refines attribute +// no_parallel_custom_call for asynchronous collective operations to support +// runtime optimization, such as skipping rendezvous of all participating +// threads for NCCL collective operations. In particular, it sets the attribute +// value for Collective-start operations with is_sync=false; it also keeps the +// attribute value untouched for the operations with is_sync=true and for P2P +// operations, assuming the runtime won't use those values. +// +class SchedulePostprocessing : public HloModulePass { + public: + absl::string_view name() const override { return "schedule-postprocessing"; } + + using HloPassInterface::Run; + absl::StatusOr<bool> Run( + HloModule* module, + const absl::flat_hash_set<absl::string_view>& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc new file mode 100644 index 00000000000000..01659a11f6e66d --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc @@ -0,0 +1,163 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/transforms/schedule_postprocessing.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using SchedulePostprocessingTest = HloTestBase; + +TEST_F(SchedulePostprocessingTest, SynchronousOpsNotChanged) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + pf32 = f32[1] parameter(0) + + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}} + ROOT all-gather-done = f32[2] all-gather-done(all-gather-start) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + SchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(SchedulePostprocessingTest, P2POpsNotChanged) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY main { + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + + after-all = token[] after-all() + recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}}" + } + recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=2 + ROOT recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0 + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + SchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(SchedulePostprocessingTest, AsynchronousOpsChanged) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + pf32 = f32[1] parameter(0) + pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call" + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}} + ROOT all-gather-done = f32[2] all-gather-done(all-gather-start) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + SchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + start->backend_config()); + const CollectiveBackendConfig& collective_backend_config = + gpu_config.collective_backend_config(); + EXPECT_TRUE(collective_backend_config.no_parallel_custom_call()); +} + +TEST_F(SchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + pf32 = f32[1] parameter(0) + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}} + pf32.2 = f32[1] custom-call(pf32), 
custom_call_target="my_custom_call" + all-gather-done = f32[2] all-gather-done(all-gather-start) + ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + SchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_FALSE(changed); + + HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + start->backend_config()); + const CollectiveBackendConfig& collective_backend_config = + gpu_config.collective_backend_config(); + EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); +} + +TEST_F(SchedulePostprocessingTest, + AsynchronousOpsWithParallelNestedCustomcall) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + foo { + v = f32[1] parameter(0) + ROOT ret = f32[1] custom-call(v), custom_call_target="my_custom_call" + } + + ENTRY entry { + pf32 = f32[1] parameter(0) + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}} + pf32.2 = f32[1] call(f32[1] pf32), to_apply=foo + all-gather-done = f32[2] all-gather-done(all-gather-start) + ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + SchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_FALSE(changed); + + HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + start->backend_config()); + const CollectiveBackendConfig& collective_backend_config = + gpu_config.collective_backend_config(); + EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 924701a09deda4..bd7c0ccbd4b661 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -443,6 +443,9 @@ message DebugOptions { // threads. Setting to 0 (the default value) means no enforcement. bool xla_gpu_enable_llvm_module_compilation_parallelism = 268; + // Allow early return when acquiring NCCL cliques. + bool xla_gpu_enable_nccl_clique_optimization = 244; + // Enable NCCL communicator splitting. bool xla_gpu_enable_nccl_comm_splitting = 272; @@ -802,7 +805,6 @@ message DebugOptions { // go/keep-sorted end - reserved 244; // xla_gpu_enable_nccl_clique_optimization reserved 276; // xla_gpu_enable_nccl_per_stream_comms //--------------------------------------------------------------------------// From e2dca7eef9941a04a398031fab3fa3b1a67f3b16 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Apr 2025 15:18:12 -0700 Subject: [PATCH 0253/1324] Fix uses of the "rank" concept in `layout_util`. 
PiperOrigin-RevId: 744085609 --- third_party/xla/xla/layout.h | 3 +- third_party/xla/xla/layout_util.cc | 43 ++++++++------- third_party/xla/xla/layout_util.h | 15 ++--- third_party/xla/xla/layout_util_test.cc | 73 +++++++++++++------------ 4 files changed, 69 insertions(+), 65 deletions(-) diff --git a/third_party/xla/xla/layout.h b/third_party/xla/xla/layout.h index 8134ab15c45347..6c5dd212b0a433 100644 --- a/third_party/xla/xla/layout.h +++ b/third_party/xla/xla/layout.h @@ -99,7 +99,8 @@ using TileVector = absl::InlinedVector; // where the splits occur. For example, if the dimension contains 1024 elements, // a split indices value of {512} indicates splitting this dimension into two // right through the middle. The dimension here refers to the physical dimension -// such that 0 is the majormost dimension and rank-1 is the minormost dimension. +// such that 0 is the majormost dimension and (number of dimensions - 1) is the +// minormost dimension. class SplitConfig { public: SplitConfig(int64_t dimension, absl::Span split_indices) diff --git a/third_party/xla/xla/layout_util.cc b/third_party/xla/xla/layout_util.cc index e75110a93f7506..787d17bb60d4a2 100644 --- a/third_party/xla/xla/layout_util.cc +++ b/third_party/xla/xla/layout_util.cc @@ -113,14 +113,14 @@ absl::string_view BoolToString(bool b) { return b ? "true" : "false"; } return layout; } -/* static */ Layout LayoutUtil::MakeDescendingLayout(int64_t rank) { - std::vector layout(rank); +/* static */ Layout LayoutUtil::MakeDescendingLayout(int64_t num_dims) { + std::vector layout(num_dims); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); return MakeLayout(layout); } -/* static */ Layout LayoutUtil::MakeAscendingLayout(int64_t rank) { - std::vector layout(rank); +/* static */ Layout LayoutUtil::MakeAscendingLayout(int64_t num_dims) { + std::vector layout(num_dims); std::iota(layout.begin(), layout.end(), static_cast(0)); return MakeLayout(layout); } @@ -136,11 +136,12 @@ absl::string_view BoolToString(bool b) { return b ? "true" : "false"; } namespace { -// Internal helper that creates a default layout for an array of the given rank. -Layout CreateDefaultLayoutForRank(int64_t rank) { +// Internal helper that creates a default layout for an array of the given +// number of dimensions. 
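+// (The default produced here is the descending layout {n-1, ..., 1, 0}, i.e.
+// the last logical dimension is minormost, as in row-major order.)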
+Layout CreateDefaultLayoutForRank(int64_t num_dims) { Layout layout; auto* minor_to_major = layout.mutable_minor_to_major(); - minor_to_major->resize(rank, 0); + minor_to_major->resize(num_dims, 0); SetDefaultLayoutToContainer(minor_to_major); return layout; } @@ -158,8 +159,8 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { return CreateDefaultLayoutForRank(shape.dimensions_size()); } -/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64_t rank) { - return CreateDefaultLayoutForRank(rank); +/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64_t num_dims) { + return CreateDefaultLayoutForRank(num_dims); } /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() { @@ -236,7 +237,7 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { if (layout.minor_to_major_size() != shape.dimensions_size()) { return InvalidArgument( "layout minor_to_major field contains %d elements, " - "but shape is rank %d: {%s}; shape: %s", + "but shape has %d dimensions: {%s}; shape: %s", layout.minor_to_major_size(), shape.dimensions_size(), absl::StrJoin(layout.minor_to_major(), ", "), shape.ShortDebugString()); } @@ -268,8 +269,8 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { dim_level_types[i] = layout.dim_level_type(i); } return InvalidArgument( - "layout dim_level_types field contains %d elements, but shape is " - "rank %d: {%s}; shape: %s", + "layout dim_level_types field contains %d elements, but shape has " + "%d dimensions: {%s}; shape: %s", layout.dim_level_types_size(), shape.dimensions_size(), absl::StrJoin(dim_level_types, ", ", [](std::string* out, DimLevelType dim_level_type) { @@ -287,8 +288,8 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { dim_unique[i] = layout.dim_unique(i); } return InvalidArgument( - "layout dim_unique field contains %d elements, but shape is " - "rank %d: {%s}; shape: %s", + "layout dim_unique field contains %d elements, but shape has " + "%d dimensions: {%s}; shape: %s", layout.dim_unique_size(), shape.dimensions_size(), absl::StrJoin(dim_unique, ", ", [](std::string* out, bool dim_unique) { @@ -305,8 +306,8 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { dim_ordered[i] = layout.dim_ordered(i); } return InvalidArgument( - "layout dim_ordered field contains %d elements, but shape is " - "rank %d: {%s}; shape: %s", + "layout dim_ordered field contains %d elements, but shape has " + "%d dimensions: {%s}; shape: %s", layout.dim_ordered_size(), shape.dimensions_size(), absl::StrJoin(dim_ordered, ", ", [](std::string* out, bool dim_ordered) { @@ -677,13 +678,13 @@ absl::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { absl::Span indices) { CHECK(shape.IsArray()); CHECK(shape.has_layout()); - const int rank = shape.dimensions_size(); - CHECK_EQ(rank, indices.size()); + const int num_dims = shape.dimensions().size(); + CHECK_EQ(num_dims, indices.size()); - if (rank == 0) { + if (num_dims == 0) { return 0; } - if (rank == 1) { + if (num_dims == 1) { return indices[0]; } @@ -701,7 +702,7 @@ absl::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { int64_t within_tile_multiplier = 1; // We only look at the top-level tile. 
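 // For example, an untiled f32[2,3] array with layout {1,0} maps indices
 // {i, j} to the linear index i * 3 + j.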
- for (int64_t minor = 0; minor < rank; minor++) { + for (int64_t minor = 0; minor < num_dims; minor++) { int64_t logical_dim = Minor(shape.layout(), minor); int64_t shape_dim_size = shape.dimensions(logical_dim); int64_t index = indices[logical_dim]; diff --git a/third_party/xla/xla/layout_util.h b/third_party/xla/xla/layout_util.h index b70a61b23bd712..71507049bd71f3 100644 --- a/third_party/xla/xla/layout_util.h +++ b/third_party/xla/xla/layout_util.h @@ -61,17 +61,17 @@ class LayoutUtil { // Returns a layout with descending ((i.e. {n-1, n-2, ... 0}) minor-to-major // dimensions. - static Layout MakeDescendingLayout(int64_t rank); + static Layout MakeDescendingLayout(int64_t num_dims); // Returns a layout with ascending ((i.e. {0, 1, ... n-1}) minor-to-major // dimensions. - static Layout MakeAscendingLayout(int64_t rank); + static Layout MakeAscendingLayout(int64_t num_dims); // Returns default layout for the given shape. static Layout GetDefaultLayoutForShape(const Shape& shape); // Helper functions that create default layouts for various ranks. - static Layout GetDefaultLayoutForRank(int64_t rank); + static Layout GetDefaultLayoutForRank(int64_t num_dims); static Layout GetDefaultLayoutForR2(); static Layout GetDefaultLayoutForR3(); static Layout GetDefaultLayoutForR4(); @@ -228,8 +228,8 @@ class LayoutUtil { // // In the returned vector, the first element represents the most major logical // dimension. The element whose contents are 0 represents the most major - // physical dimension, and the element with contents (rank - 1) represents - // the most minor physical dimension. + // physical dimension, and the element with contents (number of dimensions - + // 1) represents the most minor physical dimension. static std::vector MakeLogicalToPhysical(const Layout& layout); // Prints a human-readable string that represents the given layout. @@ -241,7 +241,8 @@ class LayoutUtil { // Copies the layout from 'src' to 'dst'. Recursively copies layouts of // tuples. 'src' and 'dst' need not be compatible but the two shapes must // have the same tuple structure (if any) and arrays must have the same - // rank. within the shapes must have the same number of dimensions. + // number of dimensions. within the shapes must have the same number of + // dimensions. static absl::Status CopyLayoutBetweenShapes(const Shape& src, Shape* dst); // Returns true if the layouts of lhs and rhs are equal, false @@ -249,7 +250,7 @@ class LayoutUtil { // // lhs and rhs need not be compatible to have the same layout but the two // shapes must have the same tuple structure (if any) and arrays must have the - // same rank. Element type is ignored. + // same number of dimensions. Element type is ignored. static bool LayoutsInShapesEqual( const Shape& lhs, const Shape& rhs, std::optional equal = std::nullopt); diff --git a/third_party/xla/xla/layout_util_test.cc b/third_party/xla/xla/layout_util_test.cc index c1aa071864c427..383d29348a2b40 100644 --- a/third_party/xla/xla/layout_util_test.cc +++ b/third_party/xla/xla/layout_util_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include +#include #include "absl/types/span.h" #include "xla/hlo/testlib/test.h" #include "xla/hlo/testlib/test_helpers.h" @@ -32,6 +34,11 @@ limitations under the License. 
namespace xla { namespace { +using ::testing::ContainsRegex; +using ::testing::HasSubstr; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + class LayoutUtilTest : public ::testing::Test { protected: Shape MakeShapeWithLayout( @@ -176,8 +183,7 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - ::testing::ContainsRegex("cannot copy layout from shape")); + EXPECT_THAT(status.message(), ContainsRegex("cannot copy layout from shape")); } TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { @@ -195,8 +201,7 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - ::testing::ContainsRegex("cannot copy layout from shape")); + EXPECT_THAT(status.message(), ContainsRegex("cannot copy layout from shape")); } TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { @@ -207,9 +212,9 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.message(), ::testing::ContainsRegex( - "layout minor_to_major field contains .* " - "elements, but shape is rank")); + EXPECT_THAT(status.message(), + ContainsRegex("layout minor_to_major field contains .* " + "elements, but shape has")); } TEST_F(LayoutUtilTest, CopyTokenLayout) { @@ -422,14 +427,14 @@ TEST_F(LayoutUtilTest, ValidateLayout_InvalidArrayLayout) { LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.message(), - ::testing::HasSubstr("layout minor_to_major field " - "contains 3 elements, but shape is rank 2")); + HasSubstr("layout minor_to_major field " + "contains 3 elements, but shape has 2 dimensions")); status = LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.message(), - ::testing::HasSubstr("layout minor_to_major field " - "contains 3 elements, but shape is rank 2")); + HasSubstr("layout minor_to_major field " + "contains 3 elements, but shape has 2 dimensions")); } TEST_F(LayoutUtilTest, ValidateLayout_InvalidDimLevelTypes) { @@ -442,14 +447,14 @@ TEST_F(LayoutUtilTest, ValidateLayout_InvalidDimLevelTypes) { LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.message(), - ::testing::HasSubstr("layout dim_level_types field " - "contains 3 elements, but shape is rank 2")); + HasSubstr("layout dim_level_types field " + "contains 3 elements, but shape has 2 dimensions")); status = LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.message(), - ::testing::HasSubstr("layout dim_level_types field " - "contains 3 elements, but shape is rank 2")); + HasSubstr("layout dim_level_types field " + "contains 3 elements, but shape has 2 dimensions")); } TEST_F(LayoutUtilTest, ValidateLayout_MissingArrayLayout) { @@ -459,7 +464,7 @@ TEST_F(LayoutUtilTest, ValidateLayout_MissingArrayLayout) { LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.message(), - ::testing::HasSubstr("shape f32[2,3] does not have a layout")); + HasSubstr("shape f32[2,3] does not have a layout")); status = 
LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); EXPECT_TRUE(status.ok()); @@ -469,41 +474,37 @@ TEST_F(LayoutUtilTest, ValidateLayout_Sparse) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); *shape.mutable_layout() = LayoutUtil::MakeLayout( {1, 0}, {DIM_DENSE, DIM_COMPRESSED}, {}, {}, {Tile({10, 10})}); - EXPECT_THAT(LayoutUtil::ValidateLayoutInShape(shape), - tsl::testing::StatusIs( - tsl::error::INVALID_ARGUMENT, - ::testing::HasSubstr( - "layout has tiles, but the shape is a sparse array"))); + EXPECT_THAT( + LayoutUtil::ValidateLayoutInShape(shape), + StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("layout has tiles, but the shape is a sparse array"))); shape.mutable_layout()->clear_tiles(); - EXPECT_THAT(LayoutUtil::ValidateLayoutInShape(shape), tsl::testing::IsOk()); + EXPECT_THAT(LayoutUtil::ValidateLayoutInShape(shape), IsOk()); *shape.mutable_layout()->mutable_physical_shape() = ShapeUtil::MakeShape(F32, {6}); - EXPECT_THAT(LayoutUtil::ValidateLayoutInShape(shape), tsl::testing::IsOk()); + EXPECT_THAT(LayoutUtil::ValidateLayoutInShape(shape), IsOk()); *shape.mutable_layout() ->mutable_physical_shape() ->mutable_layout() ->mutable_physical_shape() = ShapeUtil::MakeShape(S32, {10}); EXPECT_THAT( LayoutUtil::ValidateLayoutInShape(shape), - tsl::testing::StatusIs( + StatusIs( tsl::error::INVALID_ARGUMENT, - ::testing::HasSubstr( - "layout has a physical_shape, but is not a sparse array"))); + HasSubstr("layout has a physical_shape, but is not a sparse array"))); shape.mutable_layout()->mutable_physical_shape()->clear_layout(); shape.mutable_layout()->clear_dim_level_types(); EXPECT_THAT( LayoutUtil::ValidateLayoutInShape(shape), - tsl::testing::StatusIs( + StatusIs( tsl::error::INVALID_ARGUMENT, - ::testing::HasSubstr( - "layout has a physical_shape, but is not a sparse array"))); + HasSubstr("layout has a physical_shape, but is not a sparse array"))); *shape.mutable_layout() = LayoutUtil::MakeLayout({1, 0}, {DIM_DENSE, DIM_DENSE}, {true, false}); EXPECT_THAT(LayoutUtil::ValidateLayoutInShape(shape), - tsl::testing::StatusIs( - tsl::error::INVALID_ARGUMENT, - ::testing::HasSubstr("layout dimension 1 has invalid level " - "encoding DIM_DENSE, non-unique"))); + StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("layout dimension 1 has invalid level " + "encoding DIM_DENSE, non-unique"))); } TEST_F(LayoutUtilTest, ValidateLayout_TupleSubshapesWithMissingLayouts) { @@ -521,7 +522,7 @@ TEST_F(LayoutUtilTest, ValidateLayout_TupleSubshapesWithMissingLayouts) { LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.message(), - ::testing::HasSubstr("shape f32[1,2] does not have a layout")); + HasSubstr("shape f32[1,2] does not have a layout")); status = LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); EXPECT_TRUE(status.ok()); @@ -534,8 +535,8 @@ TEST_F(LayoutUtilTest, ValidateLayout_TupleSubshapesWithMissingLayouts) { LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.message(), - ::testing::HasSubstr("layout minor_to_major field " - "contains 3 elements, but shape is rank 1")); + HasSubstr("layout minor_to_major field " + "contains 3 elements, but shape has 1 dimensions")); } TEST_F(LayoutUtilTest, MoveDimToMajor) { From 523c5a4679402bf31047465845df5b1e0b1a388b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 4 Apr 2025 15:27:39 -0700 Subject: [PATCH 0254/1324] [xla:gpu] CommandBuffer: add 
CommandBufferCmdSequence::Builder for building command buffers For now always use sequential synchronization mode for command buffers. DAG command buffers with automatic synchronization coming in followup CLs. PiperOrigin-RevId: 744088007 --- .../gpu/runtime/command_buffer_cmd.cc | 41 +++-- .../backends/gpu/runtime/command_buffer_cmd.h | 69 +++++---- .../gpu/runtime/command_buffer_cmd_emitter.cc | 87 ++++------- .../gpu/runtime/command_buffer_cmd_emitter.h | 3 +- .../gpu/runtime/command_buffer_cmd_test.cc | 46 +++--- .../gpu/runtime/command_buffer_thunk_test.cc | 144 ++++++++++-------- .../xla/service/gpu/ir_emitter_unnested.cc | 9 +- 7 files changed, 210 insertions(+), 189 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 6c14705755d4d5..8d75b609ed1b56 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -192,17 +192,28 @@ CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrCreate( // CommandBufferCmdSequence //===----------------------------------------------------------------------===// -CommandBufferCmdSequence::CommandBufferCmdSequence( - SynchronizationMode synchronization_mode) - : synchronization_mode_(synchronization_mode) {} +void CommandBufferCmdSequence::Builder::Append( + std::unique_ptr cmd) { + commands_.push_back({std::move(cmd)}); +} -void CommandBufferCmdSequence::Append(std::unique_ptr cmd) { - for (const BufferUse& buffer : cmd->buffers()) { - buffers_.insert(buffer); - allocs_indices_.insert(buffer.slice().index()); - } +CommandBufferCmdSequence CommandBufferCmdSequence::Builder::Build( + SynchronizationMode synchronization_mode) && { + return CommandBufferCmdSequence(synchronization_mode, std::move(commands_)); +} - commands_.push_back({std::move(cmd)}); +CommandBufferCmdSequence::CommandBufferCmdSequence( + SynchronizationMode synchronization_mode, + std::vector> commands) + : synchronization_mode_(synchronization_mode), + commands_(std::move(commands)) { + // Record all buffers used by commands in the sequence. 
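+ // (These sets summarize every buffer slice and allocation index the whole
+ // sequence touches, so later stages can reason about the sequence without
+ // re-walking each command.)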
+ for (const std::unique_ptr& cmd : commands_) { + for (const BufferUse& buffer : cmd->buffers()) { + buffers_.insert(buffer); + allocs_indices_.insert(buffer.slice().index()); + } + } } absl::Status CommandBufferCmdSequence::Prepare( @@ -1667,7 +1678,7 @@ CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() { DynamicSliceFusionCmd::DynamicSliceFusionCmd( ExecutionStreamId execution_stream_id, - std::unique_ptr embedded_commands, + CommandBufferCmdSequence embedded_commands, std::vector> arguments, std::vector> fake_allocations, std::vector>> offsets, @@ -1722,7 +1733,7 @@ bool DynamicSliceFusionCmd::force_update() { absl::Status DynamicSliceFusionCmd::Initialize( const Thunk::InitializeParams& params, StateManager& state) { - TF_RETURN_IF_ERROR(embedded_commands_->Initialize(params, state)); + TF_RETURN_IF_ERROR(embedded_commands_.Initialize(params, state)); absl::MutexLock lock(&mutex_); if (offsets_allocs_.contains(params.executor)) return absl::OkStatus(); @@ -1754,7 +1765,7 @@ absl::Status DynamicSliceFusionCmd::Prepare( slice.orig_shape->dimensions_size()); } } - TF_RETURN_IF_ERROR(embedded_commands_->Prepare(params, resource_requests)); + TF_RETURN_IF_ERROR(embedded_commands_.Prepare(params, resource_requests)); return absl::OkStatus(); } @@ -1890,15 +1901,15 @@ DynamicSliceFusionCmd::Record(const Thunk::ExecuteParams& execute_params, execute_params.stream->parent() ->CreateCommandBuffer(se::CommandBuffer::Mode::kNested) .value(); - TF_RETURN_IF_ERROR(embedded_commands_->Record(new_params, record_params, - nested_command_buffer.get())); + TF_RETURN_IF_ERROR(embedded_commands_.Record(new_params, record_params, + nested_command_buffer.get())); return RecordedCommands::Create( command_buffer->AddNestedCommandBuffer(*nested_command_buffer, {})); } CommandBufferCmd::BufferUseVector DynamicSliceFusionCmd::buffers() { CommandBufferCmd::BufferUseVector buffers; - auto embed_buffers = embedded_commands_->buffers(); + auto embed_buffers = embedded_commands_.buffers(); for (auto buffer_usage : embed_buffers) { CHECK( embeded_to_origin_slice_map_[buffer_usage.slice().index()].has_value()); diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index 433d7f9a945c79..510c5acfe02557 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -267,33 +267,42 @@ class CommandBufferCmd { // purpose is to manipulate command buffers at run time. class CommandBufferCmdSequence { public: - // Synchronization mode defines how execution streams gets converted to - // command buffer execution scopes and barriers. - // - // Each individual Thunk assigned an execution stream id, and we have explicit - // inter-stream synchronization (`Thunk::Kind::kWaitForStreams`) between - // streams. Thunks assigned to the same stream are implicitly synchronized. - // - // Command buffers on the other hand by default can execute commands - // concurrently and require barriers to enforce execution order. - // - // WARNING: We do not have implicit synchronization between execution scopes - // corresponding to different execution streams and rely on explicit barriers - // emitted from thunks. Synchronization mode controls only barriers within - // a single exection scope (corresponds to execution stream). 
+  CommandBufferCmdSequence() = default;
+  CommandBufferCmdSequence(CommandBufferCmdSequence&&) = default;
+  CommandBufferCmdSequence& operator=(CommandBufferCmdSequence&&) = default;
+
+  // Synchronization mode defines how much concurrency is allowed between
+  // commands in the sequence.
   enum class SynchronizationMode {
-    // Adds barriers between all commands recorded into the same execution scope
-    // (thunks sharing execution stream) and enforces completely serialized
-    // execution order that matches what would happen in a ThunkSequence.
+    // Serializes execution of all commands recorded into the command buffer
+    // by adding dependencies between them.
     kSerialize,
 
-    // Relies on buffer use analysis to insert barriers only between commands
-    // that have read-write conflicts into the same buffers. Conflicts are
-    // detected only between commands using the same stream id, and inter-stream
-    // synchronization is a user responsibility.
+    // Relies on the execution graph to insert dependencies between commands
+    // that have buffer or resource conflicts, building a DAG of commands.
     kAutomatic
   };
 
+  // A builder for lazy construction of a command buffer cmd sequence.
+  class Builder {
+   public:
+    void Append(std::unique_ptr<CommandBufferCmd> cmd);
+
+    template <typename T, typename... Args>
+    void Emplace(Args... args) {
+      Append(std::make_unique<T>(std::forward<Args>(args)...));
+    }
+
+    // TODO(b/406370928): Remove default argument and make sure we correctly
+    // propagate synchronization mode through the codebase.
+    CommandBufferCmdSequence Build(SynchronizationMode synchronization_mode =
+                                       SynchronizationMode::kSerialize) &&;
+
+   private:
+    std::vector<std::unique_ptr<CommandBufferCmd>> commands_;
+  };
+
   enum class RecordMode {
     // In exclusive mode no one else is recording commands into the command
     // buffer argument, and cmd sequence is responsible for updating command
@@ -309,16 +318,6 @@ class CommandBufferCmdSequence {
     kConditional
   };
 
-  explicit CommandBufferCmdSequence(SynchronizationMode synchronization_mode =
-                                        SynchronizationMode::kAutomatic);
-
-  void Append(std::unique_ptr<CommandBufferCmd> cmd);
-
-  template <typename T, typename... Args>
-  void Emplace(Args... args) {
-    Append(std::make_unique<T>(std::forward<Args>(args)...));
-  }
-
   // Prepares all commands added to a sequence.
   absl::Status Prepare(const Thunk::PrepareParams& params,
                        Thunk::ResourceRequestsInterface& resource_requests);
@@ -348,6 +347,10 @@ class CommandBufferCmdSequence {
   }
 
  private:
+  CommandBufferCmdSequence(
+      SynchronizationMode synchronization_mode,
+      std::vector<std::unique_ptr<CommandBufferCmd>> commands);
+
   SynchronizationMode synchronization_mode_;
 
   std::vector<std::unique_ptr<CommandBufferCmd>> commands_;
@@ -1017,7 +1020,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd {
  public:
   DynamicSliceFusionCmd(
       ExecutionStreamId execution_stream_id,
-      std::unique_ptr<CommandBufferCmdSequence> embedded_commands,
+      CommandBufferCmdSequence embedded_commands,
       std::vector<std::optional<BufferAllocation::Slice>> arguments,
       std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_,
       std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>>
           offsets,
@@ -1045,7 +1048,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd {
   bool IsNestedCommandBuffer() const final { return true; }
 
  private:
-  std::unique_ptr<CommandBufferCmdSequence> embedded_commands_;
+  CommandBufferCmdSequence embedded_commands_;
   std::vector<DynamicSliceThunk::SliceDef> slices_;
 
   std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_;
 
diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc
index 653649f186ae76..de10af9c3ae6c2 100644
--- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc
+++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc
@@ -50,22 +50,21 @@ limitations under the License.
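A minimal sketch of the new build flow (illustrative only, not part of this
change; `s0`, `slice_a`, `slice_b`, and `byte_length` are placeholders as in
the tests below):

    CommandBufferCmdSequence::Builder builder;
    // Emplace<T> constructs the command in place and appends it to the
    // pending sequence.
    builder.Emplace<MemcpyDeviceToDeviceCmd>(s0, slice_b, slice_a, byte_length);
    // Build() consumes the builder; until the DAG follow-ups land it
    // defaults to SynchronizationMode::kSerialize.
    CommandBufferCmdSequence commands = std::move(builder).Build();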
namespace xla::gpu { -// Appends command(s) converted from `thunk` to `cmd_sequence`. +// Appends command(s) converted from `thunk` to `cmd_sequence_builder`. static absl::Status AppendCommands( - CommandBufferCmdSequence& cmd_sequence, const Thunk& thunk, - CommandBufferCmdSequence::SynchronizationMode synchronization_mode); + CommandBufferCmdSequence::Builder& cmd_sequence_builder, + const Thunk& thunk); -// Appends command(s) converted from `sequence` to `cmd_sequence`. +// Appends command(s) converted from `sequence` to `cmd_sequence_builder`. static absl::Status AppendCommands( - CommandBufferCmdSequence& cmd_sequence, const ThunkSequence& sequence, - CommandBufferCmdSequence::SynchronizationMode synchronization_mode); + CommandBufferCmdSequence::Builder& cmd_sequence_builder, + const ThunkSequence& sequence); //===----------------------------------------------------------------------===// // Conversions from Thunk to Command //===----------------------------------------------------------------------===// using Command = std::unique_ptr; -using xla::BufferUse; static auto ArgsAccess(const std::vector& written) { absl::InlinedVector args_access; @@ -106,16 +105,12 @@ static absl::StatusOr Convert(const Memset32BitValueThunk& thunk) { thunk.destination(), thunk.value()); } -static absl::StatusOr Convert( - const WhileThunk& thunk, - CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { +static absl::StatusOr Convert(const WhileThunk& thunk) { TF_ASSIGN_OR_RETURN( CommandBufferCmdSequence cond_cmds, - ConvertToCommands(thunk.condition_thunk_sequence()->thunks(), - synchronization_mode)); + ConvertToCommands(thunk.condition_thunk_sequence()->thunks())); TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence body_cmds, - ConvertToCommands(thunk.body_thunk_sequence()->thunks(), - synchronization_mode)); + ConvertToCommands(thunk.body_thunk_sequence()->thunks())); return std::make_unique(thunk.execution_stream_id(), thunk.condition_result_buffer(), @@ -147,9 +142,7 @@ static absl::StatusOr Convert(const CublasLtMatmulThunk& thunk) { thunk.workspace().value()); } -static absl::StatusOr Convert( - const ConditionalThunk& thunk, - CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { +static absl::StatusOr Convert(const ConditionalThunk& thunk) { std::vector branch_cmds; branch_cmds.reserve(thunk.branch_thunks().size()); if (thunk.branch_index_is_bool()) { @@ -157,16 +150,13 @@ static absl::StatusOr Convert( // because the first branch is the "false" branch and the second is "true" CHECK_EQ(thunk.branch_thunks().size(), 2); TF_ASSIGN_OR_RETURN(branch_cmds.emplace_back(), - ConvertToCommands(thunk.branch_thunks()[1]->thunks(), - synchronization_mode)); + ConvertToCommands(thunk.branch_thunks()[1]->thunks())); TF_ASSIGN_OR_RETURN(branch_cmds.emplace_back(), - ConvertToCommands(thunk.branch_thunks()[0]->thunks(), - synchronization_mode)); + ConvertToCommands(thunk.branch_thunks()[0]->thunks())); } else { for (auto& branch_thunk : thunk.branch_thunks()) { - TF_ASSIGN_OR_RETURN( - CommandBufferCmdSequence cmds, - ConvertToCommands(branch_thunk->thunks(), synchronization_mode)); + TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence cmds, + ConvertToCommands(branch_thunk->thunks())); branch_cmds.emplace_back(std::move(cmds)); } } @@ -200,11 +190,9 @@ static absl::StatusOr Convert(const AllGatherStartThunk& thunk) { } static absl::StatusOr Convert(const DynamicSliceThunk& thunk) { - auto cmd_sequence = std::make_unique(); + CommandBufferCmdSequence::Builder builder; auto embed_thunk = 
thunk.get_embedded_thunk(); - TF_RETURN_IF_ERROR(AppendCommands( - *cmd_sequence, embed_thunk->thunks(), - CommandBufferCmdSequence::SynchronizationMode::kAutomatic)); + TF_RETURN_IF_ERROR(AppendCommands(builder, embed_thunk->thunks())); auto& thunk_fake_allocations = thunk.get_fake_allocations(); std::vector> fake_allocations; @@ -213,7 +201,7 @@ static absl::StatusOr Convert(const DynamicSliceThunk& thunk) { fake_allocations.push_back(std::make_unique(**it)); } return std::make_unique( - thunk.execution_stream_id(), std::move(cmd_sequence), + thunk.execution_stream_id(), std::move(builder).Build(), thunk.get_arguments(), std::move(fake_allocations), thunk.get_offsets(), thunk.get_orig_shapes(), thunk.get_sliced_shapes(), thunk.get_offset_byte_sizes()); @@ -264,19 +252,12 @@ static absl::StatusOr Convert(const Thunk& thunk) { return CopyMetadata(Convert(static_cast(thunk)), thunk); } -template -static absl::StatusOr Convert( - const Thunk& thunk, - CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { - return Convert(static_cast(thunk), synchronization_mode); -} - static absl::Status AppendCommands( - CommandBufferCmdSequence& cmd_sequence, const Thunk& thunk, - CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { + CommandBufferCmdSequence::Builder& cmd_sequence_builder, + const Thunk& thunk) { auto append = [&](absl::StatusOr command) -> absl::Status { if (command.ok()) { - cmd_sequence.Append(std::move(*command)); + cmd_sequence_builder.Append(std::move(*command)); return absl::OkStatus(); } return command.status(); @@ -284,7 +265,7 @@ static absl::Status AppendCommands( switch (thunk.kind()) { case Thunk::Kind::kConditional: - return append(Convert(thunk, synchronization_mode)); + return append(Convert(thunk)); case Thunk::Kind::kCopy: return append(Convert(thunk)); case Thunk::Kind::kCustomCall: @@ -314,7 +295,7 @@ static absl::Status AppendCommands( case Thunk::Kind::kReplicaId: return append(Convert(thunk)); case Thunk::Kind::kWhile: - return append(Convert(thunk, synchronization_mode)); + return append(Convert(thunk)); case Thunk::Kind::kCuDnn: return append(Convert(thunk)); case Thunk::Kind::kDynamicSlice: @@ -323,9 +304,9 @@ static absl::Status AppendCommands( // Sequential thunk does not have any special semantics and we simply inline // all nested thunks into command buffer. case Thunk::Kind::kSequential: - return AppendCommands(cmd_sequence, - static_cast(thunk).thunks(), - synchronization_mode); + return AppendCommands( + cmd_sequence_builder, + static_cast(thunk).thunks()); // Thunks that simply wait for stream events are no-op in the command buffer // context, as we convert async thunks to command dependency graph. @@ -350,22 +331,18 @@ static absl::Status AppendCommands( } static absl::Status AppendCommands( - CommandBufferCmdSequence& cmd_sequence, const ThunkSequence& sequence, - CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { + CommandBufferCmdSequence::Builder& cmd_sequence_builder, + const ThunkSequence& sequence) { for (const std::unique_ptr& thunk : sequence) - TF_RETURN_IF_ERROR( - AppendCommands(cmd_sequence, *thunk, synchronization_mode)); + TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence_builder, *thunk)); return absl::OkStatus(); } -// TODO(vuson): Add unit tests. 
absl::StatusOr ConvertToCommands( - const ThunkSequence& sequence, - CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { - CommandBufferCmdSequence cmd_sequence(synchronization_mode); - TF_RETURN_IF_ERROR( - AppendCommands(cmd_sequence, sequence, synchronization_mode)); - return cmd_sequence; + const ThunkSequence& sequence) { + CommandBufferCmdSequence::Builder cmd_sequence_builder; + TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence_builder, sequence)); + return std::move(cmd_sequence_builder).Build(); } } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h index 9622723803e752..fbe7ed9a016f5c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h @@ -28,8 +28,7 @@ namespace xla::gpu { // allow commands to run concurrently and insert barriers only when needed for // correctness. absl::StatusOr ConvertToCommands( - const ThunkSequence& sequence, - CommandBufferCmdSequence::SynchronizationMode synchronization_mode); + const ThunkSequence& sequence); } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index bb2f58cb6d584f..528a593db6d0a5 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "absl/functional/function_ref.h" @@ -144,10 +145,10 @@ TEST(CommandBufferCmdTest, SerializeExecution) { auto use0 = BufferUse(slice0, BufferUse::kRead); auto use1 = BufferUse(slice1, BufferUse::kRead); - CommandBufferCmdSequence commands( - CommandBufferCmdSequence::SynchronizationMode::kSerialize); - commands.Emplace(s0, BufferUseVector{use0}); - commands.Emplace(s0, BufferUseVector{use1}); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, BufferUseVector{use0}); + builder.Emplace(s0, BufferUseVector{use1}); + CommandBufferCmdSequence commands = std::move(builder).Build(); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -162,9 +163,10 @@ TEST(CommandBufferCmdTest, NoReadBarrier) { auto use0 = BufferUse(slice0, BufferUse::kRead); auto use1 = BufferUse(slice1, BufferUse::kRead); - CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUseVector{use0}); - commands.Emplace(s0, BufferUseVector{use1}); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, BufferUseVector{use0}); + builder.Emplace(s0, BufferUseVector{use1}); + CommandBufferCmdSequence commands = std::move(builder).Build(); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -179,9 +181,10 @@ TEST(CommandBufferCmdTest, NoWriteBarrier) { auto use0 = BufferUse(slice0, BufferUse::kWrite); auto use1 = BufferUse(slice1, BufferUse::kWrite); - CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUseVector{use0}); - commands.Emplace(s0, BufferUseVector{use1}); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, BufferUseVector{use0}); + builder.Emplace(s0, BufferUseVector{use1}); + CommandBufferCmdSequence commands = std::move(builder).Build(); // TODO(ezhulenev): Check that commands correctly infer dependencies. 
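With the synchronization mode dropped from its signature, converting a thunk
sequence to commands is now a single call (sketch; `thunk_sequence` stands in
for any `ThunkSequence` reference):

    TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence cmd_sequence,
                        ConvertToCommands(thunk_sequence));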
} @@ -198,10 +201,11 @@ TEST(CommandBufferCmdTest, WriteConflictBarrier) { auto use1 = BufferUse(slice0, BufferUse::kRead); auto use2 = BufferUse(slice1, BufferUse::kWrite); - CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUseVector{use0}); - commands.Emplace(s0, BufferUseVector{use1}); - commands.Emplace(s0, BufferUseVector{use2}); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, BufferUseVector{use0}); + builder.Emplace(s0, BufferUseVector{use1}); + builder.Emplace(s0, BufferUseVector{use2}); + CommandBufferCmdSequence commands = std::move(builder).Build(); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -228,8 +232,9 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(s0, slice_b, slice_a, byte_length); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, slice_b, slice_a, byte_length); + CommandBufferCmdSequence commands = std::move(builder).Build(); ServiceExecutableRunOptions run_options; se::StreamExecutorMemoryAllocator allocator(executor); @@ -281,10 +286,11 @@ TEST(CommandBufferCmdTest, LaunchCmd) { auto args_access = {BufferUse::kRead, MemoryAccess::kRead, BufferUse::kWrite}; // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Initialize command sequence and load device kernels. TF_ASSERT_OK_AND_ASSIGN(std::vector fatbin, diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index 5d5324815fba2a..2dbf17e1d7c32b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -154,8 +154,9 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(s0, slice_b, slice_a, byte_length); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, slice_b, slice_a, byte_length); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -208,8 +209,9 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(s0, slice_a); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, slice_a); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -250,8 +252,9 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); // Prepare commands sequence for constructing command buffer. 
- CommandBufferCmdSequence commands; - commands.Emplace(s0, slice_a, int32_t{84}); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, slice_a, int32_t{84}); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -299,8 +302,9 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersDisabledDuringProfiling) { std::make_unique(Thunk::ThunkInfo(), std::move(thunks)); // Prepare commands sequence for constructing command buffer that should not // be used. - CommandBufferCmdSequence commands; - commands.Emplace(s0, slice_a, int32_t{12}); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, slice_a, int32_t{12}); + CommandBufferCmdSequence commands = std::move(builder).Build(); constexpr bool kProfileCommandBuffersEnabled = false; // Construct a thunk with command sequence. @@ -353,8 +357,9 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersEnabledDuringProfiling) { std::make_unique(Thunk::ThunkInfo(), std::move(thunks)); // Prepare commands sequence for constructing command buffer that should not // be used. - CommandBufferCmdSequence commands; - commands.Emplace(s0, slice_a, int32_t{12}); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, slice_a, int32_t{12}); + CommandBufferCmdSequence commands = std::move(builder).Build(); constexpr bool kProfileCommandBuffersEnabled = true; // Construct a thunk with command sequence. @@ -396,9 +401,10 @@ TEST(CommandBufferThunkTest, Memset32CmdOnDifferentStreams) { BufferAllocation::Slice slice1(&alloc, 1 * sizeof(int32_t), sizeof(int32_t)); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(s0, slice0, int32_t{12}); - commands.Emplace(s1, slice1, int32_t{34}); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, slice0, int32_t{12}); + builder.Emplace(s1, slice1, int32_t{34}); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -448,10 +454,11 @@ TEST(CommandBufferThunkTest, LaunchCmd) { MemoryAccess::kWrite}; // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -545,10 +552,11 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { MemoryAccess::kWrite}; // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -662,10 +670,11 @@ TEST(CommandBufferThunkTest, GemmCmd) { ASSERT_TRUE(config.ok()); // Prepare commands sequence for constructing command buffer. 
- CommandBufferCmdSequence commands; - commands.Emplace(s0, config.value(), slice_lhs, slice_rhs, slice_out, - slice_workspace, - /*deterministic=*/true); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, config.value(), slice_lhs, slice_rhs, slice_out, + slice_workspace, + /*deterministic=*/true); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -788,11 +797,11 @@ TEST(CommandBufferThunkTest, DynamicSliceFusionCmd) { ASSERT_TRUE(config.ok()); // Prepare commands sequence for constructing command buffer. - std::unique_ptr embed_commands = - std::make_unique(); - embed_commands->Emplace(s0, config.value(), fake_slice_lhs, - slice_rhs, slice_out, slice_workspace, - /*deterministic=*/true); + CommandBufferCmdSequence::Builder embed_builder; + embed_builder.Emplace(s0, config.value(), fake_slice_lhs, slice_rhs, + slice_out, slice_workspace, + /*deterministic=*/true); + CommandBufferCmdSequence embed_commands = std::move(embed_builder).Build(); BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); @@ -818,10 +827,11 @@ TEST(CommandBufferThunkTest, DynamicSliceFusionCmd) { std::vector> offset_byte_sizes = { sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt}; - CommandBufferCmdSequence commands; - commands.Emplace( + CommandBufferCmdSequence::Builder builder; + builder.Emplace( s0, std::move(embed_commands), arguments, std::move(fake_allocations), offsets, orig_shapes, sliced_shapes, offset_byte_sizes); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -927,13 +937,14 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { ASSERT_TRUE(config.ok()); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace( + CommandBufferCmdSequence::Builder builder; + builder.Emplace( s0, config.value(), se::gpu::BlasLt::Epilogue::kDefault, 0, slice_a, slice_b, slice_c, slice_d, BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), slice_workspace); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1060,13 +1071,14 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { MemoryAccess::kWrite}; // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - commands.Emplace(s0, "AddI32", args_1, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + builder.Emplace(s0, "AddI32", args_1, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. 
CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1170,28 +1182,33 @@ TEST(CommandBufferThunkTest, CaseCmd) { BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); // Prepare commands sequence for branches. - std::vector branches(2); + std::vector branches_builder(2); auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, MemoryAccess::kWrite}; { // Case 0: b = a + a auto args = {slice_a, slice_a, slice_b}; - branches[0].Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); + branches_builder[0].Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); } { // Case 1: b = b + b auto args = {slice_b, slice_b, slice_b}; - branches[1].Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); + branches_builder[1].Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); } + std::vector branches(2); + branches[0] = std::move(branches_builder[0]).Build(); + branches[1] = std::move(branches_builder[1]).Build(); + // Prepare commands sequence for thunk. - CommandBufferCmdSequence commands; - commands.Emplace(s0, slice_i, false, std::move(branches)); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, slice_i, false, std::move(branches)); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1274,21 +1291,26 @@ TEST(CommandBufferThunkTest, WhileCmd) { MemoryAccess::kWrite}; // Prepare commands sequence for loop `cond`. - CommandBufferCmdSequence cond_commands; - cond_commands.Emplace(s0, "IncAndCmp", cond_args, cond_args_access, - LaunchDimensions(1, 1), - /*shmem_bytes=*/0); + CommandBufferCmdSequence::Builder cond_commands_builder; + cond_commands_builder.Emplace( + s0, "IncAndCmp", cond_args, cond_args_access, LaunchDimensions(1, 1), + /*shmem_bytes=*/0); + CommandBufferCmdSequence cond_commands = + std::move(cond_commands_builder).Build(); // Prepare commands sequence for loop `body`. - CommandBufferCmdSequence body_commands; - body_commands.Emplace(s0, "AddI32", body_args, body_args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); + CommandBufferCmdSequence::Builder body_commands_builder; + body_commands_builder.Emplace( + s0, "AddI32", body_args, body_args_access, LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + CommandBufferCmdSequence body_commands = + std::move(body_commands_builder).Build(); // Prepare commands sequence for thunk. - CommandBufferCmdSequence commands; - commands.Emplace(s0, slice_pred, std::move(cond_commands), - std::move(body_commands)); + CommandBufferCmdSequence::Builder builder; + builder.Emplace(s0, slice_pred, std::move(cond_commands), + std::move(body_commands)); + CommandBufferCmdSequence commands = std::move(builder).Build(); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index b51f5f4928568c..c3bed53294e36d 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -573,9 +573,12 @@ absl::Status IrEmitterUnnested::EmitCommandBufferThunk( ? 
CommandBufferCmdSequence::SynchronizationMode::kAutomatic : CommandBufferCmdSequence::SynchronizationMode::kSerialize; - TF_ASSIGN_OR_RETURN( - CommandBufferCmdSequence cmd_sequence, - ConvertToCommands(thunk_sequence->thunks(), synchronization_mode)); + // TODO(b/406370928): Use `synchronization_mode` to construct a command buffer + // cmd sequence with specified synchronization mode. + (void)synchronization_mode; + + TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence cmd_sequence, + ConvertToCommands(thunk_sequence->thunks())); AddThunkToThunkSequence(std::make_unique( std::move(cmd_sequence), Thunk::ThunkInfo::WithProfileAnnotation(instr), From 278759df28d1bc5297b1b980584a6abbad4ad88f Mon Sep 17 00:00:00 2001 From: Yin Zhang Date: Fri, 4 Apr 2025 15:31:13 -0700 Subject: [PATCH 0255/1324] [hlo op profile] Remove memory type size check and peak bandwidth removal to enable HBM analysis for GPU PiperOrigin-RevId: 744088860 --- tensorflow/core/profiler/convert/op_profile_builder.cc | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tensorflow/core/profiler/convert/op_profile_builder.cc b/tensorflow/core/profiler/convert/op_profile_builder.cc index fcb02c9227a938..8741b92df8bfff 100644 --- a/tensorflow/core/profiler/convert/op_profile_builder.cc +++ b/tensorflow/core/profiler/convert/op_profile_builder.cc @@ -164,16 +164,6 @@ void PopulateOpMetricsNode( const OpMetrics& op_metrics, double peak_gigaflops_per_second_per_core, std::vector peak_mem_gibibytes_per_second_per_core, uint64_t total_time_ps, Node* node) { - // TODO(dfinchel): remove this temporary change to avoid crash. - // This is only needed while we make an update to proto version that is not - // backwards compatible. - if (peak_mem_gibibytes_per_second_per_core.size() != - (MemBwType_MAX - MemBwType_MIN + 1)) { - peak_mem_gibibytes_per_second_per_core.clear(); - for (int i = MemBwType_MIN; i <= MemBwType_MAX; ++i) { - peak_mem_gibibytes_per_second_per_core.push_back(0); - } - } Metrics* metrics = node->mutable_metrics(); // The UI computes flops_rate = raw_flops / raw_time From 51aa01247b6aaab85a73e858bcd0562aa22b2584 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 4 Apr 2025 15:40:10 -0700 Subject: [PATCH 0256/1324] [IFRT] Check the uniqueness of IFRT devices IFRT devices must be unique and have unique device Ids within a single IFRT client. This is now tested explicitly to ensure that all implementations satisfy the requirement. PiperOrigin-RevId: 744091324 --- .../xla/backends/cpu/nanort/ifrt_client.cc | 30 ++++++++++++------- .../xla/xla/backends/cpu/nanort/ifrt_client.h | 9 ++---- third_party/xla/xla/python/ifrt/BUILD | 1 + .../xla/python/ifrt/client_impl_test_lib.cc | 4 +++ 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc index 7de59f236a215e..ed9fc1aa00cd63 100644 --- a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc +++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc @@ -1182,12 +1182,12 @@ ABSL_ATTRIBUTE_UNUSED char NanoMemory::ID = 'M'; // NOLINT // Device implementation. There is only one device so this doesn't do much. 
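The uniqueness invariant being enforced, written out as a standalone check
(sketch only; `client` is any `ifrt::Client*`):

    absl::flat_hash_set<ifrt::DeviceId> seen_device_ids;
    for (ifrt::Device* device : client->GetAllDevices()) {
      // Every device must carry a DeviceId that no other device in this
      // client uses.
      CHECK(seen_device_ids.insert(device->Id()).second)
          << "Duplicate device ID: " << device->Id();
    }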
class NanoDevice final : public llvm::RTTIExtends { public: - NanoDevice(NanoIfrtClient* client, ifrt::Memory* memory) - : client_(client), memory_(memory) {} + NanoDevice(NanoIfrtClient* client, ifrt::DeviceId id, ifrt::Memory* memory) + : client_(client), id_(id), memory_(memory) {} ifrt::Client* client() const override { return client_; } - ifrt::DeviceId Id() const override { return ifrt::DeviceId(0); } + ifrt::DeviceId Id() const override { return id_; } const ifrt::AttributeMap& Attributes() const override { static auto attributes = new ifrt::AttributeMap({}); @@ -1216,6 +1216,7 @@ class NanoDevice final : public llvm::RTTIExtends { private: NanoIfrtClient* client_; + ifrt::DeviceId id_; ifrt::Memory* memory_; }; @@ -1235,7 +1236,8 @@ std::shared_ptr NanoIfrtClient::CreateWithDevices( } std::shared_ptr NanoIfrtClient::default_sharding() const { - return ifrt::SingleDeviceSharding::Create(device_.get(), ifrt::MemoryKind{}); + return ifrt::SingleDeviceSharding::Create(devices_.front(), + ifrt::MemoryKind{}); } absl::StatusOr> @@ -1409,7 +1411,9 @@ absl::StatusOr NanoIfrtClient::LookupDevice( absl::StatusOr NanoIfrtClient::LookupAddressableDevice( int local_hardware_id) const { - return device_.get(); + TF_RET_CHECK(local_hardware_id >= 0); + TF_RET_CHECK(local_hardware_id < devices_.size()); + return devices_[local_hardware_id]; } ifrt::DeviceListRef NanoIfrtClient::MakeDeviceList( @@ -1435,11 +1439,17 @@ NanoIfrtClient::GetDefaultLayout(ifrt::DType dtype, NanoIfrtClient::NanoIfrtClient(int32_t num_devices) : compiler_(std::make_unique(this)), - memory_(std::make_unique(this)), - device_(std::make_unique(this, memory_.get())), - default_sharding_( - ifrt::SingleDeviceSharding::Create(device_.get(), memory_->Kind())), - devices_(num_devices, device_.get()) {} + memory_(std::make_unique(this)) { + owned_devices_.reserve(num_devices); + devices_.reserve(num_devices); + for (int i = 0; i < num_devices; ++i) { + owned_devices_.push_back( + std::make_unique(this, ifrt::DeviceId(i), memory_.get())); + devices_.push_back(owned_devices_.back().get()); + } + default_sharding_ = + ifrt::SingleDeviceSharding::Create(devices_.front(), memory_->Kind()); +} char NanoIfrtClient::ID = 'N'; // NOLINT diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h index d5c5fe7e226e08..14e133e2cc16cc 100644 --- a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h +++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h @@ -193,19 +193,14 @@ class NanoIfrtClient : public llvm::RTTIExtends { // details. std::unique_ptr compiler_; std::unique_ptr memory_; - std::unique_ptr device_; + std::vector> owned_devices_; // The default sharding for this client. When this sharding is used it // typically means that we can use an array's contents directly. std::shared_ptr default_sharding_; // Some of the ifrt::Client methods return a span of devices, so we need to - // keep storage for them here. Note that this may repeat the device_ pointer - // multiple times if this client is configured with multiple devices. This is - // mostly to make IFRT callers that expect sharded programs to run on multiple - // devices happy. This has the unusual property that we have multiple devices - // but a single device_id, but this seems to work fine and most documentation - // warns that devices may be repeated within a device list or sharding. + // keep storage for them here. 
std::vector devices_; }; diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 2ddee003c5a771..082adf338c0500 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -443,6 +443,7 @@ cc_library( ":test_util", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", + "@com_google_absl//absl/container:flat_hash_set", ], alwayslink = True, ) diff --git a/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc index 7bff86a9ffdcd7..7c76d5caae04ca 100644 --- a/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/container/flat_hash_set.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/test_util.h" #include "xla/tsl/platform/statusor.h" @@ -60,10 +61,13 @@ TEST(ClientImplTest, GetAllDevices) { EXPECT_GE(client->GetAllDevices().size(), client->device_count()); + absl::flat_hash_set seen_device_ids; for (Device* device : client->GetAllDevices()) { TF_ASSERT_OK_AND_ASSIGN(auto* looked_up_device, client->LookupDevice(device->Id())); EXPECT_EQ(device, looked_up_device); + EXPECT_TRUE(seen_device_ids.insert(device->Id()).second) + << "Duplicate device ID: " << device->Id(); } } From 17d22b5d94a11cb53d340534e0df502c4288c6d0 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 4 Apr 2025 15:43:28 -0700 Subject: [PATCH 0257/1324] Change shared_ptr -> unique_ptr. This also removes some of the quirks of the GPU implementation (like the ForClosure code). PiperOrigin-RevId: 744092320 --- third_party/xla/xla/pjrt/gpu/raw_buffer.cc | 6 +- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 75 ++-- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 1 + .../xla/pjrt/pjrt_stream_executor_client.cc | 373 ++++++++++-------- .../xla/pjrt/pjrt_stream_executor_client.h | 73 ++-- .../xla/xla/pjrt/tracked_device_buffer.cc | 33 +- .../xla/xla/pjrt/tracked_device_buffer.h | 13 +- 7 files changed, 319 insertions(+), 255 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/raw_buffer.cc b/third_party/xla/xla/pjrt/gpu/raw_buffer.cc index 7e6ba82da584fc..dbb30e0537fa19 100644 --- a/third_party/xla/xla/pjrt/gpu/raw_buffer.cc +++ b/third_party/xla/xla/pjrt/gpu/raw_buffer.cc @@ -52,15 +52,13 @@ CreateGPURawBuffer(PjRtBuffer* buffer) { if (!hold.ok()) { return hold.status(); } - const auto& device_buffer = hold.buffer(); - if (!device_buffer->device_memory()) { + if (!hold->device_memory()) { return absl::InvalidArgumentError( "Create raw buffer called on an invalid buffer"); } return tsl::MakeRef( se_client, se_buffer->memory_space(), - se_buffer->device()->local_device_state(), - device_buffer->device_memory()); + se_buffer->device()->local_device_state(), hold->device_memory()); } return std::nullopt; } diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 3ddd2c2caeba0a..df901c978c9e45 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -130,7 +130,7 @@ limitations under the License. 
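After this change a `ScopedHold` dereferences the uniquely owned
`TrackedDeviceBuffer`, and callers that need the underlying memory to outlive
the hold retain the ref-counted raw memory instead (sketch; `buffer` is
assumed to be a valid `PjRtStreamExecutorBuffer*`):

    auto hold = buffer->GetBufferWithUsageHold();
    if (!hold.ok()) return hold.status();
    // Keep only the device memory alive; the TrackedDeviceBuffer itself is
    // no longer a shareable object.
    tsl::RCReference<RawSEDeviceMemory> memory = hold->device_memory();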
#include "xla/util.h" namespace xla { -class AsyncHostToDeviceTransferManager +class GpuAsyncHostToDeviceTransferManager : public xla::PjRtClient::AsyncHostToDeviceTransferManager { public: static absl::StatusOr> @@ -145,7 +145,7 @@ class AsyncHostToDeviceTransferManager device_layouts->size(), shape_specs.size()); } absl::InlinedVector, 4> buffers; - absl::InlinedVector, 4> buffer_ptrs; + absl::InlinedVector, 4> buffer_ptrs; absl::InlinedVector, 4> definition_events; absl::InlinedVector device_shapes; @@ -190,18 +190,18 @@ class AsyncHostToDeviceTransferManager tensorflow::down_cast(buffer.get()); DCHECK(se_buffer); auto hold = se_buffer->GetBufferWithUsageHold(); - buffer_ptrs.push_back(hold.buffer()); + buffer_ptrs.push_back(hold->device_memory()); buffers.push_back(std::move(buffer)); } - return std::make_unique( + return std::make_unique( std::move(buffers), std::move(buffer_ptrs), std::move(definition_events), std::move(device_shapes), device); } - AsyncHostToDeviceTransferManager( + GpuAsyncHostToDeviceTransferManager( absl::InlinedVector, 4> buffers, - absl::InlinedVector, 4> buffer_ptrs, + absl::InlinedVector, 4> buffer_ptrs, absl::InlinedVector, 4> definition_events, absl::InlinedVector device_shapes, @@ -215,13 +215,13 @@ class AsyncHostToDeviceTransferManager device_(device) { buffer_sizes_.reserve(buffer_ptrs_.size()); for (const auto& ptr : buffer_ptrs_) { - DCHECK(ptr->device_memory()); - buffer_sizes_.push_back(ptr->device_memory()->mem().size()); + DCHECK(ptr); + buffer_sizes_.push_back(ptr->mem().size()); } last_transfer_started_.resize(buffer_ptrs_.size(), false); } - ~AsyncHostToDeviceTransferManager() override { + ~GpuAsyncHostToDeviceTransferManager() override { auto transfers_finished = [this]() { mu_.AssertHeld(); return transfers_in_flight_ == 0; @@ -252,7 +252,7 @@ class AsyncHostToDeviceTransferManager int buffer_index, const LiteralSlice& literal, absl::AnyInvocable on_done) override { tsl::profiler::TraceMe traceme( - "AsyncHostToDeviceTransferManager::TransferLiteralToBuffer"); + "GpuAsyncHostToDeviceTransferManager::TransferLiteralToBuffer"); auto* stream = device_->local_device_state()->host_to_device_stream(); auto* se_client = tensorflow::down_cast(device_->client()); @@ -261,7 +261,7 @@ class AsyncHostToDeviceTransferManager TransferManager* transfer_manager = se_client->client()->backend().transfer_manager(); - std::shared_ptr buffer; + tsl::RCReference buffer; { absl::MutexLock l(&mu_); @@ -275,13 +275,6 @@ class AsyncHostToDeviceTransferManager last_transfer_started_[buffer_index] = true; buffer = buffer_ptrs_[buffer_index]; DCHECK(buffer); - if (!buffer->device_memory()) { - return InvalidArgument( - "TransferLiteralToBuffer requested for buffer index %d which has " - "been donated. Async transfer of donated buffers is not supported " - "in SE:GPU", - buffer_index); - } ++transfers_in_flight_; } @@ -290,19 +283,20 @@ class AsyncHostToDeviceTransferManager // TODO(misard) assess if it would be preferable to introduce a heuristic to // put the transfer into the calling thread for small literals. 
auto transfer_h2d = [this, buffer_index, stream, transfer_manager, literal, - device_buffer = buffer.get(), + device = device_, device_buffer = buffer, local_device = std::move(device_->local_device_state()), on_done = std::move(on_done)]() mutable { tsl::profiler::TraceMe traceme( - "AsyncHostToDeviceTransferManager::TransferLiteralToBuffer::transfer_" + "GpuAsyncHostToDeviceTransferManager::TransferLiteralToBuffer::" + "transfer_" "h2d"); auto event = local_device->event_pool().AllocateEvent(stream->parent()); // Initiate linearization and transfer of the buffer on the stream. ShapedBuffer buffer = - device_buffer->AsShapedBuffer(device_shapes_[buffer_index]); + device_buffer->AsShapedBuffer(device, device_shapes_[buffer_index]); TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( stream, literal, buffer)); local_device->event_pool().ThenRecordEvent(stream, event.value()); @@ -375,14 +369,7 @@ class AsyncHostToDeviceTransferManager last_transfer_started_[buffer_index] = true; } DCHECK(buffer_ptrs_[buffer_index]); - if (!buffer_ptrs_[buffer_index]->device_memory()) { - return InvalidArgument( - "TransferRawDataToSubBuffer requested for buffer index %d which has " - "been donated. Async transfer of donated buffers is not supported " - "in SE:GPU", - buffer_index); - } - auto& buffer_memory = buffer_ptrs_[buffer_index]->device_memory()->mem(); + auto& buffer_memory = buffer_ptrs_[buffer_index]->mem(); se::DeviceMemoryBase sub_buffer; CHECK_LE(offset, buffer_memory.size()); CHECK_LE(transfer_size, buffer_memory.size() - offset); @@ -457,7 +444,7 @@ class AsyncHostToDeviceTransferManager absl::InlinedVector buffer_sizes_; // References to the underlying storage for all the buffers, which ensures // that the buffers can't be freed before all transfers complete. - absl::InlinedVector, 4> buffer_ptrs_ + absl::InlinedVector, 4> buffer_ptrs_ ABSL_GUARDED_BY(mu_); // True if the last transfer for a buffer has been initiated. Used to prevent // a client initiating another transfer after the last transfer has already @@ -487,7 +474,7 @@ class AsyncHostToDeviceTransferManager if (is_last_transfer) { // Drop our reference to the TrackedDeviceBuffer for this buffer. 
CHECK(buffer_ptrs_[buffer_index]); - buffer_ptrs_[buffer_index] = nullptr; + buffer_ptrs_[buffer_index] = tsl::RCReference(); CHECK_GT(remaining_buffer_count_, 0); --remaining_buffer_count_; definition_events_[buffer_index]->SetSequencingEvent(std::move(event), @@ -612,7 +599,7 @@ StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( PjRtDevice* device = memory_space->devices()[0]; auto* stream_executor_device = tensorflow::down_cast(device); - return xla::AsyncHostToDeviceTransferManager::Create( + return xla::GpuAsyncHostToDeviceTransferManager::Create( shape_specs, std::move(device_layouts), stream_executor_device, this, memory_space); } @@ -648,8 +635,8 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( return PjRtFuture<>(hold.status()); } - auto device_buffer = hold.buffer(); - if (!device_buffer->device_memory()) { + auto device_memory = hold->device_memory(); + if (!device_memory) { return PjRtFuture<>( InvalidArgument("Copy raw buffer called on an invalid buffer")); } @@ -658,6 +645,9 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( auto usage_event = std::make_shared(this->thread_pool()); + auto definition_events = hold->definition_events(); + auto first_definition_event = definition_events[0]; + // When using the ComputeSynchronized allocation model, retain a reference to // the device_buffer until the copy completes, to ensure that the buffer isn't // deleted or donated while it is still in use. The choice of retaining a @@ -668,7 +658,9 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( hold.ConvertUsageHold(stream, usage_event, /*reference_held=*/true); auto async_copy = [this, promise, offset, transfer_size, stream, local_device, - device_buffer, usage_event = std::move(usage_event)]( + owning_device_memory = std::move(device_memory), + definition_events = std::move(definition_events), + usage_event = std::move(usage_event)]( absl::StatusOr dst) mutable { absl::StatusOr event = local_device->event_pool().AllocateEvent(stream->parent()); @@ -677,14 +669,13 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( return; } - absl::Status defined_status = - device_buffer->definition_events()[0]->GetDefinedStatus(); + absl::Status defined_status = definition_events[0]->GetDefinedStatus(); if (!defined_status.ok()) { promise.Set(defined_status); return; } - auto& device_memory = device_buffer->device_memory()->mem(); + auto& device_memory = owning_device_memory->mem(); if (offset < 0 || offset > device_memory.size() || device_memory.size() - offset < transfer_size) { promise.Set( @@ -702,7 +693,8 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( sub_buffer = std::make_unique(device_memory); } - WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); + WaitForBufferDefinitionEventsOnStream(absl::MakeSpan(definition_events), + stream); if (transfer_size != 0) { if (should_stage_host_to_device_transfers() && @@ -751,7 +743,8 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( usage_event->SetSequencingEvent(std::move(event).value(), stream); auto callback_status = local_device->ThenExecuteCallback( - stream, [promise, device_buffer = std::move(device_buffer)]() mutable { + stream, [promise, owning_device_memory = + std::move(owning_device_memory)]() mutable { promise.Set(); }); if (!callback_status.ok()) { @@ -760,7 +753,7 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( } }; - device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( + 
first_definition_event->ExecuteOrAddToFutureTasks( absl::StrFormat("async_copy_raw_sub_buffer_to_host_%p", &async_copy), [this, dst, async_copy = std::move(async_copy)]() mutable { dst.OnReady([this, async_copy = std::move(async_copy)]( diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 71bdb41cab3564..b86a100d53c59f 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -1174,6 +1174,7 @@ TEST(StreamExecutorGpuClientTest, GetAllocatorStatsTest) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr buffer, client->BufferFromHostLiteral(literal, memory_space)); + TF_ASSERT_OK(buffer->GetReadyFuture().Await()); auto stats = device->GetAllocatorStats(); TF_ASSERT_OK(stats.status()); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 338312b2d208d0..4d687f66ccecc3 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -150,6 +150,11 @@ limitations under the License. namespace xla { +template +static std::function WrapClosureAsCopyable(T cb) { + return [state = std::make_shared(std::move(cb))]() { return (*state)(); }; +} + PjRtStreamExecutorMemorySpace::PjRtStreamExecutorMemorySpace( int id, PjRtDevice* device, absl::string_view kind, int kind_id) : id_(id), device_(device), kind_(kind), kind_id_(kind_id) { @@ -354,7 +359,7 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer, LocalDeviceState* stream_local_device, std::shared_ptr event, se::Stream* usage_stream, - std::vector>* + std::vector>* buffers_to_release = nullptr) { tsl::profiler::TraceMe traceme("RecordUsage"); bool retain_buffer_until_completion = @@ -366,9 +371,10 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer, LocalDeviceState::kSynchronous); if (retain_buffer_until_completion) { if (buffers_to_release) { - buffers_to_release->push_back(device_buffer.buffer()); + buffers_to_release->push_back(device_buffer->device_memory()); } else { - buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer()) + buffer_local_device + ->ThenRelease(usage_stream, device_buffer->device_memory()) .IgnoreError(); } } @@ -381,7 +387,6 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer, // had an event recorded. 
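The `WrapClosureAsCopyable` helper introduced at the top of this file exists
because `std::function` requires a copyable callable, while these closures
capture move-only state. With its stripped template brackets reconstructed,
the helper reads:

    template <typename T>
    static std::function<void()> WrapClosureAsCopyable(T cb) {
      // Stashing the move-only callable behind a shared_ptr makes the
      // wrapping lambda copyable, so it fits in a std::function.
      return [state = std::make_shared<T>(std::move(cb))]() {
        return (*state)();
      };
    }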
absl::Status AddDestinationBufferSynchronization( LocalDeviceState* local_device, - PjRtStreamExecutorBuffer::ScopedHold device_buffer, std::shared_ptr definition_event, se::Stream* copy_stream) { absl::StatusOr event_or = @@ -392,8 +397,6 @@ absl::Status AddDestinationBufferSynchronization( } definition_event->SetSequencingEvent(std::move(event_or).value(), copy_stream); - RecordUsage(std::move(device_buffer), local_device, local_device, - definition_event, copy_stream); return absl::OkStatus(); } @@ -517,7 +520,7 @@ AllocateDestinationBuffer( dst_buffer.memory_allocator()); dst_buffer.clear(); - auto dst_device_buffer = std::make_shared( + auto dst_device_buffer = std::make_unique( device, std::move(mem), definition_events); auto py_buffer = std::make_unique( @@ -528,7 +531,11 @@ AllocateDestinationBuffer( PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() { if (ok()) { - parent_->DropHold(type_, buffer().get()); + if (type_ == kDonation) { + parent_->DropDonationHold(std::move(buffer_)); + } else { + parent_->DropUsageOrExternalHold(type_, buffer_ptr_); + } } } @@ -537,33 +544,44 @@ PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other) type_(other.type_), state_(other.state_), status_(std::move(other.status_)), + buffer_ptr_(other.buffer_ptr_), buffer_(std::move(other.buffer_)) { // Preserve the invariant that status is invalid if buffer == nullptr. other.SetState(kMoved); } -void PjRtStreamExecutorBuffer::ScopedHold::Acquire( - absl::StatusOr>&& buffer_or) { +void PjRtStreamExecutorBuffer::ScopedHold::AcquireDonation( + absl::StatusOr> buffer_or) { CHECK(!ok()); if (buffer_or.ok()) { - buffer_ = buffer_or.value(); + buffer_ = std::move(buffer_or).value(); + buffer_ptr_ = buffer_.get(); SetState(kValid); } else { - status_ = buffer_or.status(); + status_ = std::move(buffer_or).status(); buffer_ = nullptr; + buffer_ptr_ = nullptr; SetState(kError); } // Check the invariant holds. - CHECK(!ok() || buffer_ != nullptr); + CHECK(!ok() || buffer_ptr_ != nullptr); } -PjRtStreamExecutorBuffer::ScopedHold::ForClosure -PjRtStreamExecutorBuffer::ScopedHold::ToClosure() { - CHECK(ok()); - ForClosure for_closure(parent_, type_, state_, std::move(status_), - std::move(buffer_)); - SetState(kReleased); - return for_closure; +void PjRtStreamExecutorBuffer::ScopedHold::AcquireUsageOrExternalReference( + absl::StatusOr buffer_or) { + CHECK(!ok()); + if (buffer_or.ok()) { + buffer_.reset(); + buffer_ptr_ = buffer_or.value(); + SetState(kValid); + } else { + status_ = std::move(buffer_or).status(); + buffer_.reset(); + buffer_ = nullptr; + SetState(kError); + } + // Check the invariant holds. 
+ CHECK(!ok() || buffer_ptr_ != nullptr); } void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold( @@ -571,7 +589,7 @@ void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold( bool reference_held) { CHECK(ok()); CHECK_EQ(type_, kUsage); - parent_->ConvertUsageHold(buffer().get(), usage_stream, std::move(event), + parent_->ConvertUsageHold(buffer(), usage_stream, std::move(event), reference_held); SetState(kConverted); } @@ -579,7 +597,7 @@ void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold( void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() { CHECK(ok()); CHECK_EQ(type_, kDonation); - parent_->ConfirmDonation(buffer().get()); + parent_->ConfirmDonation(buffer()); SetState(kDonated); } @@ -606,7 +624,8 @@ absl::StatusOr PjRtStreamExecutorBuffer::logical_on_device_shape() { AcquireHoldLocked(&device_buffer); } - WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); + WaitForBufferDefinitionEventsOnStream(device_buffer->definition_events(), + stream); ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_); absl::StatusOr event_or = local_device->event_pool().AllocateEvent(stream->parent()); @@ -664,15 +683,15 @@ class TrackedDeviceBufferExternalReference : public PjRtBuffer::ExternalReference { public: explicit TrackedDeviceBufferExternalReference( - std::shared_ptr tracked_device_buffer) - : tracked_device_buffer_(std::move(tracked_device_buffer)) { - data_ptr_ = tracked_device_buffer_->device_memory()->opaque(); + tsl::RCReference memory) + : memory_(std::move(memory)) { + data_ptr_ = memory_->opaque(); } ~TrackedDeviceBufferExternalReference() override = default; private: - std::shared_ptr tracked_device_buffer_; + tsl::RCReference memory_; }; absl::StatusOr> @@ -682,9 +701,8 @@ PjRtStreamExecutorBuffer::ReleaseDeviceMemoryOwnership( return InvalidArgument( "ReleaseDeviceMemoryOwnership allowed only for non-tuple"); } - TF_ASSIGN_OR_RETURN( - std::shared_ptr tracked_device_buffer, - Release(wait_for_operations_to_complete)); + TF_ASSIGN_OR_RETURN(tsl::RCReference tracked_device_buffer, + Release(wait_for_operations_to_complete)); std::unique_ptr ref; if (tracked_device_buffer) { @@ -709,7 +727,7 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(PjRtFuture<> dependency) { } // Copy all the data in the existing tracked_buffer. - auto original_definition_events = tracked_buffer->definition_events(); + const auto& original_definition_events = tracked_buffer->definition_events(); absl::InlinedVector, 4> definition_events; @@ -722,7 +740,7 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(PjRtFuture<> dependency) { original_definition_events.begin(), original_definition_events.end()); - auto new_device_buffer = std::make_shared( + auto new_device_buffer = std::make_unique( device(), tracked_buffer->device_memory(), std::move(definition_events)); // Make the new buffer which is identical to the old, except for the new @@ -881,6 +899,9 @@ PjRtStreamExecutorClient::BufferFromHostBufferInternal( } } + std::shared_ptr event = + device_buffer->definition_events()[0]; + // The host to device transfer is performed on a thread pool, mostly because // it includes linearization that may be slow. It is OK to capture the // py_buffer pointer because the py_buffer can't be deleted until all the @@ -889,8 +910,9 @@ PjRtStreamExecutorClient::BufferFromHostBufferInternal( // put the transfer into the calling thread for small literals. 
auto transfer_h2d = [local_client = client(), transfer_manager, local_device, data, size, - type, packed_size, movable_device_buffer{device_buffer.ToClosure()}, - device_shape, should_pack, py_buffer{py_buffer.get()}, + type, packed_size, event, + device_memory_owned = device_buffer->device_memory(), device_shape, + should_pack, py_buffer{py_buffer.get()}, on_device_shape{py_buffer->on_device_shape()}, staging_buffer{std::move(staging_buffer)}, on_done_with_host_buffer = @@ -898,17 +920,14 @@ PjRtStreamExecutorClient::BufferFromHostBufferInternal( ? std::make_shared>( std::move(on_done_with_host_buffer)) : nullptr, - host_buffer_semantics, transpose{std::move(transpose)}]() { - PjRtStreamExecutorBuffer::ScopedHold device_buffer( - movable_device_buffer); + host_buffer_semantics, transpose{std::move(transpose)}]() mutable { // This function uses TF_CHECK_OK and value() since we have no way // to report failures from a callback. However, the operations here are // unlikely to fail and not recoverable even if we were to fail: DMAs to // memory that has already been allocated, and a possible Event // allocation. - se::DeviceMemoryBase device_memory = - device_buffer->device_memory()->mem(); + se::DeviceMemoryBase device_memory = device_memory_owned->mem(); // If applicable on the backend, stage the transfer via host memory // allocated via the host_memory_allocator. On GPU, this is pinned @@ -947,11 +966,13 @@ PjRtStreamExecutorClient::BufferFromHostBufferInternal( &device_memory, data, packed_size)); } - std::shared_ptr event = - device_buffer->definition_events()[0]; TF_CHECK_OK(AddDestinationBufferSynchronization( - local_device, std::move(device_buffer), event, - local_device->host_to_device_stream())); + local_device, event, local_device->host_to_device_stream())); + + local_device + ->ThenRelease(local_device->host_to_device_stream(), + device_memory_owned) + .IgnoreError(); TF_CHECK_OK(local_device->ThenExecuteCallback( local_device->host_to_device_stream(), @@ -963,7 +984,9 @@ PjRtStreamExecutorClient::BufferFromHostBufferInternal( } })); }; - thread_pool()->Schedule(transfer_h2d); + thread_pool()->Schedule(WrapClosureAsCopyable(std::move(transfer_h2d))); + RecordUsage(std::move(device_buffer), local_device, local_device, event, + local_device->host_to_device_stream()); return std::unique_ptr(std::move(py_buffer)); } @@ -1032,7 +1055,7 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error, definition_event->SetDefinedStatus(error); // Create an empty buffer. - auto dummy_device_buffer = std::make_shared( + auto dummy_device_buffer = std::make_unique( device, tsl::RCReference(), absl::MakeSpan(&definition_event, 1)); @@ -1070,6 +1093,9 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal, py_buffer->GetBufferWithUsageHold()); CHECK(device_buffer.ok()); + std::shared_ptr event = + device_buffer->definition_events()[0]; + // The host to device transfer is performed on a thread pool, mostly because // it includes linearization that may be slow. It is OK to capture the // py_buffer pointer because the py_buffer can't be deleted until all the @@ -1077,10 +1103,10 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal, // TODO(misard) assess if it would be preferable to introduce a heuristic to // put the transfer into the calling thread for small literals. 
auto transfer_h2d = [local_client = client(), transfer_manager, local_device, - movable_device_buffer{device_buffer.ToClosure()}, - literal, py_buffer{py_buffer.get()}, - on_device_shape{py_buffer->on_device_shape()}]() { - PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer); + device_memory = device_buffer->device_memory(), device, + event, literal, py_buffer{py_buffer.get()}, + on_device_shape{ + py_buffer->on_device_shape()}]() mutable { // This function uses TF_CHECK_OK and value() since we have no way // to report failures from a callback. However, the operations here are // unlikely to fail and not recoverable even if we were to fail: DMAs to @@ -1089,14 +1115,15 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal, se::Stream* h2d_stream = local_device->host_to_device_stream(); - ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape); + ShapedBuffer buffer = + device_memory->AsShapedBuffer(device, on_device_shape); TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( h2d_stream, literal, buffer)); - std::shared_ptr event = - device_buffer->definition_events()[0]; - TF_CHECK_OK(AddDestinationBufferSynchronization( - local_device, std::move(device_buffer), event, h2d_stream)); + TF_CHECK_OK( + AddDestinationBufferSynchronization(local_device, event, h2d_stream)); + + local_device->ThenRelease(h2d_stream, device_memory).IgnoreError(); // This can sometimes catch the case where the literal memory has been // freed before the H2D transfer was issued. @@ -1104,7 +1131,9 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal, .IgnoreError(); // Can return error::Unimplemented QCHECK(h2d_stream->ok()); }; - thread_pool()->Schedule(transfer_h2d); + thread_pool()->Schedule(WrapClosureAsCopyable(std::move(transfer_h2d))); + RecordUsage(std::move(device_buffer), local_device, local_device, event, + local_device->host_to_device_stream()); return std::unique_ptr(std::move(py_buffer)); } @@ -1173,7 +1202,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer( definition_events.back()->SetSequencingEvent(std::move(event), definition_stream); - auto device_buffer = std::make_shared( + auto device_buffer = std::make_unique( device, std::move(buffer), definition_events); return std::unique_ptr(std::make_unique( shape, std::move(device_buffer), this, device, @@ -1317,7 +1346,7 @@ absl::Span PjRtStreamExecutorClient::memory_spaces() } PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer( - Shape on_device_shape, std::shared_ptr device_buffer, + Shape on_device_shape, std::unique_ptr device_buffer, PjRtClient* client, PjRtDevice* device, PjRtMemorySpace* memory_space) : client_(tensorflow::down_cast(client)), on_device_shape_(std::move(on_device_shape)), @@ -1350,10 +1379,10 @@ void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() { mu_.Await(absl::Condition(¬_in_donation_hold)); } -absl::StatusOr> +absl::StatusOr> PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { tsl::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release"); - std::shared_ptr device_buffer; + std::unique_ptr device_buffer; TrackedDeviceBuffer::StreamAndEventContainer events; { absl::MutexLock lock(&mu_); @@ -1362,7 +1391,7 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { // set device_buffer_ to nullptr before returning to this thread. 
WaitForOutstandingDonationHold(); if (device_buffer_ == nullptr) { - return std::shared_ptr(); + return tsl::RCReference(); } // Set device_buffer_ to null now so that no other // thread can add a hold while we are in WaitForOutstandingUsageHolds() @@ -1373,6 +1402,7 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { // the final set of usage events. events = device_buffer->LockUseAndTransferUsageEvents(); } + auto device_memory = device_buffer->device_memory(); LocalDeviceState* local_device_state = device_->local_device_state(); if (wait_for_operations_to_complete) { // Block the host until all usage events have completed. Usage events @@ -1448,7 +1478,7 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { local_device_state->cleanup_thread()->Schedule( [events_to_wait_for_in_a_different_thread = std::move(events_to_wait_for_in_a_different_thread), - local_device_state, device_buffer, block_stream]() mutable { + local_device_state, device_memory, block_stream]() mutable { for (const auto& event : events_to_wait_for_in_a_different_thread) { MaybeWaitForEventOnStream(event.get(), local_device_state, @@ -1456,20 +1486,20 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { } if (block_stream != nullptr) { TF_CHECK_OK(local_device_state->ThenExecuteCallback( - block_stream, [device_buffer]() { - // Drops device_buffer shared pointer. + block_stream, [device_memory]() { + // Drops device_memory shared pointer. })); } }); } else if (block_stream != nullptr) { TF_RETURN_IF_ERROR(local_device_state->ThenExecuteCallback( - block_stream, [device_buffer]() { - // Drops device_buffer shared pointer. + block_stream, [device_memory]() { + // Drops device_memory shared pointer. })); } } } - return device_buffer; + return device_memory; } void PjRtStreamExecutorBuffer::Delete() { @@ -1489,39 +1519,50 @@ bool PjRtStreamExecutorBuffer::IsDeleted() { return device_buffer_ == nullptr; } -absl::StatusOr> -PjRtStreamExecutorBuffer::GetBufferForHoldLocked(ScopedHold::Type type) { +absl::StatusOr +PjRtStreamExecutorBuffer::GetBufferForUsageOrExternalHoldLocked( + ScopedHold::Type type) { // All callers should have called WaitForOutstandingDonationHold(). CHECK_EQ(holds_[ScopedHold::kDonation], 0); - if (type == ScopedHold::kDonation) { - if (device_buffer_ == nullptr) { - return InvalidArgument("Donation requested for invalid buffer"); - } - if (holds_[ScopedHold::kExternalReference] > 0) { - return InvalidArgument( - "Donation requested for buffer with external reference"); - } - // First add the donation hold. - ++holds_[type]; - // Then wait for any usage holds to be dropped or converted. No new usage - // holds can be added until we drop the donation hold so this wait will - // complete eventually. - WaitForOutstandingUsageHolds(); - // Because we added a donation hold, nobody could release the buffer while - // we were waiting. - CHECK(device_buffer_ != nullptr); + if (device_buffer_ == nullptr) { + return InvalidArgument("Buffer has been deleted or donated."); } else { - if (device_buffer_ == nullptr) { - return InvalidArgument("Buffer has been deleted or donated."); - } else { - ++holds_[type]; - } + ++holds_[type]; } - return device_buffer_; + return device_buffer_.get(); +} + +absl::StatusOr> +PjRtStreamExecutorBuffer::GetBufferForDonationHoldLocked() { + // All callers should have called WaitForOutstandingDonationHold(). 
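// The empty callbacks enqueued above exist purely for their captures:
// scheduling [device_memory]() {} behind all outstanding work pins the
// ref-counted allocation until the stream reaches that point, at which moment
// the lambda, and with it the reference, is destroyed. The idiom in isolation
// (illustrative stream stand-in, not the LocalDeviceState API):
#include <functional>
#include <memory>
#include <vector>

struct AllocationSketch {};  // stands in for RawSEDeviceMemory

using FakeStream = std::vector<std::function<void()>>;  // runs in order

void KeepAliveUntilDrained(FakeStream& stream,
                           std::shared_ptr<AllocationSketch> memory) {
  // Intentionally empty body: running (and then destroying) this closure is
  // what finally releases the reference to `memory`.
  stream.push_back([memory]() {});
}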
+ CHECK_EQ(holds_[ScopedHold::kDonation], 0); + if (device_buffer_ == nullptr) { + return InvalidArgument("Donation requested for invalid buffer"); + } + if (holds_[ScopedHold::kExternalReference] > 0) { + return InvalidArgument( + "Donation requested for buffer with external reference"); + } + // First add the donation hold. + ++holds_[ScopedHold::kDonation]; + // Then wait for any usage holds to be dropped or converted. No new usage + // holds can be added until we drop the donation hold so this wait will + // complete eventually. + WaitForOutstandingUsageHolds(); + // Because we added a donation hold, nobody could release the buffer while + // we were waiting. + CHECK(device_buffer_ != nullptr); + return std::move(device_buffer_); } void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) { - hold->Acquire(GetBufferForHoldLocked(hold->type())); + if (hold->type() == ScopedHold::kDonation) { + hold->AcquireDonation(GetBufferForDonationHoldLocked()); + return; + } + + hold->AcquireUsageOrExternalReference( + GetBufferForUsageOrExternalHoldLocked(hold->type())); } void PjRtStreamExecutorBuffer::ConvertUsageHold( @@ -1542,7 +1583,6 @@ void PjRtStreamExecutorBuffer::ConfirmDonation( CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); CHECK_EQ(holds_[ScopedHold::kDonation], 1); holds_[ScopedHold::kDonation] = 0; - CHECK(device_buffer_.get() == device_buffer); // As a sanity check ensure no more usage events can be added to the buffer. device_buffer->LockUseAndTransferUsageEvents(); // Give up ownership of the device memory so we don't free it when the last @@ -1554,17 +1594,24 @@ void PjRtStreamExecutorBuffer::ConfirmDonation( } } -void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type, - TrackedDeviceBuffer* buffer) { +void PjRtStreamExecutorBuffer::DropUsageOrExternalHold( + ScopedHold::Type type, TrackedDeviceBuffer* buffer) { absl::MutexLock lock(&mu_); CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr); CHECK_GT(holds_[type], 0); --holds_[type]; - if (type == ScopedHold::kDonation) { - CHECK_EQ(holds_[ScopedHold::kDonation], 0); - CHECK_EQ(holds_[ScopedHold::kUsage], 0); - CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); - } +} + +void PjRtStreamExecutorBuffer::DropDonationHold( + std::unique_ptr buffer) { + absl::MutexLock lock(&mu_); + CHECK_EQ(device_buffer_.get(), nullptr); + device_buffer_ = std::move(buffer); + CHECK_GT(holds_[ScopedHold::kDonation], 0); + --holds_[ScopedHold::kDonation]; + CHECK_EQ(holds_[ScopedHold::kDonation], 0); + CHECK_EQ(holds_[ScopedHold::kUsage], 0); + CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); } PjRtFuture<> PjRtStreamExecutorBuffer::LazyToLiteral( @@ -1602,7 +1649,9 @@ PjRtFuture<> PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal) { TransferManager* transfer_manager = client_->client()->backend().transfer_manager(); - auto tracked_device_buffer = device_buffer.buffer(); + auto device_memory = device_buffer->device_memory(); + auto definition_events = device_buffer->definition_events(); + auto first_definition_event = definition_events[0]; // When using the ComputeSynchronized allocation model, retain a // reference to the device_buffer until the copy completes, to @@ -1617,7 +1666,10 @@ PjRtFuture<> PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal) { // to ToLiteral. 
device_buffer.ConvertUsageHold(stream, usage_event, /*reference_held=*/true); - auto async_to_literal = [usage_event, tracked_device_buffer, stream, + auto async_to_literal = [usage_event, + device_memory = std::move(device_memory), + definition_events = std::move(definition_events), + stream, device = device_, transfer_manager = std::move(transfer_manager), on_device_shape{on_device_shape_}, literal, promise, local_device]() mutable { @@ -1628,16 +1680,16 @@ PjRtFuture<> PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal) { return; } - absl::Status defined_status = - tracked_device_buffer->definition_events()[0]->GetDefinedStatus(); + absl::Status defined_status = definition_events[0]->GetDefinedStatus(); if (!defined_status.ok()) { promise.Set(defined_status); return; } - WaitForBufferDefinitionEventsOnStream(*tracked_device_buffer, stream); + WaitForBufferDefinitionEventsOnStream(absl::MakeSpan(definition_events), + stream); ShapedBuffer shaped_buffer = - tracked_device_buffer->AsShapedBuffer(on_device_shape); + device_memory->AsShapedBuffer(device, on_device_shape); GenericTransferManager::LiteralFromDeviceMetadata transfer_metadata; // We never call device functions from the `done` callback. @@ -1658,13 +1710,13 @@ PjRtFuture<> PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal) { local_device->event_pool().ThenRecordEvent(stream, event_or.value()); usage_event->SetSequencingEvent(std::move(event_or).value(), stream); - defined_status = local_device->ThenRelease(stream, tracked_device_buffer); + defined_status = local_device->ThenRelease(stream, device_memory); if (!defined_status.ok()) { promise.Set(defined_status); } }; - tracked_device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( + first_definition_event->ExecuteOrAddToFutureTasks( absl::StrFormat("async_to_literal_%p", literal), std::move(async_to_literal)); @@ -1731,7 +1783,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper( PjRtDevice* dst_device, LocalDeviceState* dst_local_device, PjRtMemorySpace* dst_memory_space, LocalDeviceState* transfer_local_device, LocalDeviceState* src_local_device, se::Stream* transfer_stream, - std::shared_ptr src_device_buffer) { + const TrackedDeviceBuffer& src_device_buffer) { TF_ASSIGN_OR_RETURN(std::unique_ptr py_buffer, AllocateDestinationBuffer( ShapeUtil::DeviceShapeToHostShape(on_device_shape_), @@ -1746,9 +1798,10 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper( dst_device_buffer->definition_events()[0]; // Copy the leaf buffers. - auto async_copy_to_device = [src_device_buffer, - dst_device_buffer = - std::move(dst_device_buffer.buffer()), + auto async_copy_to_device = [src_memory = src_device_buffer.device_memory(), + src_definition_events = + src_device_buffer.definition_events(), + dst_memory = dst_device_buffer->device_memory(), transfer_stream = std::move(transfer_stream), copy_event, on_device_shape{py_buffer->on_device_shape()}, @@ -1763,43 +1816,34 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper( VLOG(1) << "PjRtStreamExecutorBuffer::CopyToDeviceHelper::async_copy_to_device"; - absl::Status defined_status = - src_device_buffer->definition_events()[0]->GetDefinedStatus(); + absl::Status defined_status = src_definition_events[0]->GetDefinedStatus(); // Only proceeds to transfer when the buffer doesn't hold an error. 
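// (If the first definition event carries an error status, the branch below
// skips the device-to-device copy and the same status is forwarded to
// copy_event further down, so consumers of the destination buffer observe
// the failure.)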
if (defined_status.ok()) { - WaitForBufferDefinitionEventsOnStream(*src_device_buffer, + WaitForBufferDefinitionEventsOnStream(src_definition_events, transfer_stream); - ShapedBuffer src_buffer = - src_device_buffer->AsShapedBuffer(on_device_shape); - - ShapedBuffer dst_buffer = - dst_device_buffer->AsShapedBuffer(on_device_shape); - for (const auto& leaf : src_buffer.buffers().leaves()) { - const ShapeIndex& index = leaf.first; - const se::DeviceMemoryBase& input_buffer = leaf.second; - const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index); - CHECK_EQ(input_buffer.size(), output_buffer.size()); - if (input_buffer.size() != 0) { - auto status = transfer_local_device->ThenMemcpyDeviceToDevice( - transfer_stream, dst_local_device->compute_stream(), input_buffer, - output_buffer); - if (!status.ok()) { - LOG(ERROR) << "D2D memory copy failed due to: " << status; - StallStreamOnError(transfer_local_device, transfer_stream); - if (transfer_local_device == dst_local_device) { - // Some copies may have been enqueued before the error was - // returned, and StallStreamOnError only makes sure the - // destination device is ok, so make sure that the src buffer - // remains valid until after any transfers have completed. - auto status = src_local_device->ThenRelease( - transfer_stream, std::move(src_device_buffer)); - if (!status.ok()) { - LOG(ERROR) << "ThenRelease failed due to: " << status; - } + const se::DeviceMemoryBase& input_buffer = src_memory->mem(); + const se::DeviceMemoryBase& output_buffer = dst_memory->mem(); + CHECK_EQ(input_buffer.size(), output_buffer.size()); + if (input_buffer.size() != 0) { + auto status = transfer_local_device->ThenMemcpyDeviceToDevice( + transfer_stream, dst_local_device->compute_stream(), input_buffer, + output_buffer); + if (!status.ok()) { + LOG(ERROR) << "D2D memory copy failed due to: " << status; + StallStreamOnError(transfer_local_device, transfer_stream); + if (transfer_local_device == dst_local_device) { + // Some copies may have been enqueued before the error was + // returned, and StallStreamOnError only makes sure the + // destination device is ok, so make sure that the src buffer + // remains valid until after any transfers have completed. 
+ auto status = + src_local_device->ThenRelease(transfer_stream, src_memory); + if (!status.ok()) { + LOG(ERROR) << "ThenRelease failed due to: " << status; } - return; } + return; } } @@ -1817,16 +1861,15 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper( copy_event->SetDefinedStatus(defined_status); } - auto status = src_local_device->ThenRelease(transfer_stream, - std::move(src_device_buffer)); + auto status = + src_local_device->ThenRelease(transfer_stream, std::move(src_memory)); if (!status.ok()) { LOG(ERROR) << "ThenRelease failed due to: " << status; } }; - src_device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( - absl::StrFormat("async_copy_to_device_%p", - dst_device_buffer.buffer().get()), + src_device_buffer.definition_events()[0]->ExecuteOrAddToFutureTasks( + absl::StrFormat("async_copy_to_device_%p", dst_device_buffer.buffer()), std::move(async_copy_to_device)); RecordUsage(std::move(dst_device_buffer), transfer_local_device, @@ -1888,8 +1931,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceMemorySpace( std::shared_ptr>> buffer_and_event_or = CopyToDeviceHelper( dst_device, dst_local_device, dst_memory_space, transfer_local_device, - device_->local_device_state(), transfer_stream, - src_device_buffer.buffer()); + device_->local_device_state(), transfer_stream, *src_device_buffer); if (!buffer_and_event_or.ok()) { return buffer_and_event_or.status(); } @@ -1925,7 +1967,8 @@ void PjRtStreamExecutorBuffer::CopyToRemoteDevice( } PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() { - std::shared_ptr device_buffer; + absl::InlinedVector, 2> + definition_events; PjRtFuture<>::Promise definition_promise; { absl::MutexLock lock(&mu_); @@ -1934,25 +1977,27 @@ PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() { "GetReadyFuture() called on deleted or donated buffer")); } if (!definition_promise_) { - device_buffer = device_buffer_; + definition_events = device_buffer_->definition_events(); definition_promise_ = PjRtFuture<>::CreatePromise(); } definition_promise = definition_promise_; } - if (device_buffer) { + if (!definition_events.empty()) { LocalDeviceState* local_device_state = device_->local_device_state(); + auto first_definition_event = definition_events[0]; auto async_wait_for_events = - [device_buffer, local_device_state = std::move(local_device_state), + [definition_events = std::move(definition_events), + local_device_state = std::move(local_device_state), definition_promise]() mutable { std::unique_ptr stream; absl::Status defined_status = - device_buffer->definition_events()[0]->GetDefinedStatus(); + definition_events[0]->GetDefinedStatus(); if (!defined_status.ok()) { definition_promise.Set(defined_status); return; } - for (auto& event : device_buffer->definition_events()) { + for (auto& event : definition_events) { if (!event->IsComplete()) { if (stream == nullptr) { stream = local_device_state->BorrowStreamFromPool(); @@ -1969,8 +2014,7 @@ PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() { // saves significant time. 
auto status = stream_ptr->DoHostCallback( [definition_promise, stream_ptr, local_device_state, - event_with_status = - device_buffer->definition_events()[0]]() mutable { + event_with_status = definition_events[0]]() mutable { local_device_state->ReturnStreamToPool( std::unique_ptr(stream_ptr)); definition_promise.Set(event_with_status->GetDefinedStatus()); @@ -1983,11 +2027,10 @@ PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() { // All events are already complete; set the `definition_promise` // with the status of the buffer's first definition event which may // have error status to propagate. - definition_promise.Set( - device_buffer->definition_events()[0]->GetDefinedStatus()); + definition_promise.Set(definition_events[0]->GetDefinedStatus()); } }; - device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( + first_definition_event->ExecuteOrAddToFutureTasks( absl::StrFormat("async_wait_for_events_%p", &async_wait_for_events), std::move(async_wait_for_events)); } @@ -2157,7 +2200,7 @@ absl::StatusOr> OutputBufferHelper( ShapeTree> result_buffer, std::shared_ptr definition_event, PjRtClient* client, PjRtDevice* device, LocalDeviceState* local_device, - std::vector>& buffers_to_release) { + std::vector>& buffers_to_release) { if (result_buffer.shape().IsTuple()) { return absl::InternalError("OutputBufferHelper called on tuple."); } @@ -2165,7 +2208,7 @@ absl::StatusOr> OutputBufferHelper( for (auto& item : result_buffer) { buffers.push_back(std::move(item.second)); } - auto out_buffer = std::make_shared( + auto out_buffer = std::make_unique( device, std::move(buffers[0]), absl::Span>{ definition_event}); @@ -2906,7 +2949,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers( ShapeTree> result_buffer, std::shared_ptr definition_event, PjRtDevice* device, std::vector>& compute_callbacks, - std::vector>& buffers_to_release) + std::vector>& buffers_to_release) const { tsl::profiler::TraceMe traceme("MakeOutputBuffers"); std::vector> outputs; @@ -3047,7 +3090,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( auto definition_event = std::make_shared(client_->thread_pool()); definition_event->SetSequencingEvent(std::move(event_or).value(), stream); - std::vector> buffers_to_release; + std::vector> buffers_to_release; TF_ASSIGN_OR_RETURN( std::vector> outputs, MakeOutputBuffers(device_ordinal, options, std::move(result_buffer), diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 76b43fb985a684..f8fee79850bdef 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -571,12 +571,12 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { bool ok() const { return state_ == kValid; } // Access to the underlying device buffer storage. Requires this->ok(). - const std::shared_ptr& buffer() const { + TrackedDeviceBuffer* buffer() const { CHECK_EQ(state_, kValid); - CHECK_NE(buffer_, nullptr); - return buffer_; + CHECK_NE(buffer_ptr_, nullptr); + return buffer_ptr_; } - TrackedDeviceBuffer* operator->() const { return buffer().get(); } + TrackedDeviceBuffer* operator->() const { return buffer(); } const TrackedDeviceBuffer& operator*() const { return *buffer(); } // Converts the hold into a usage event. 
Only valid for holds of type @@ -614,35 +614,21 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { friend class PjRtStreamExecutorBuffer; friend class PjRtStreamExecutorClient; - // Helper struct that makes it possible to move a ScopedHold through a - // closure. - using ForClosure = - std::tuple>; - ScopedHold(PjRtStreamExecutorBuffer* parent, Type type) : parent_(parent), type_(type), state_(kUninitialized) {} - explicit ScopedHold(const ForClosure& closure_helper) - : parent_(std::get<0>(closure_helper)), - type_(std::get<1>(closure_helper)), - state_(std::get<2>(closure_helper)), - status_(std::get<3>(closure_helper)), - buffer_(std::get<4>(closure_helper)) { - // Check the buffer is not in an error state. - CHECK(status_.ok() && buffer_ != nullptr); - } // Sets buffer state. void SetState(State state) { state_ = state; } - // Sets buffer_ and status_. Called by parent_ to initialize the hold. - void Acquire( - absl::StatusOr>&& buffer_or); - // Releases the contents of *this, so *this can subsequently be - // deleted without releasing the parent's hold. Should be passed to the - // appropriate constructor of another ScopedHold, e.g., when a hold must be - // passed through a closure that is incompatible with std::move. - ForClosure ToClosure(); + // Acquires the unique ownership of the buffer. Called by parent_ to + // initialize the donation hold. + void AcquireDonation( + absl::StatusOr> buffer_or); + + // Acquires a non-owning reference of the buffer. Called by parent_ to + // initialize the usage or external reference hold. + void AcquireUsageOrExternalReference( + absl::StatusOr buffer_or); PjRtStreamExecutorBuffer* const parent_; const Type type_; @@ -651,11 +637,16 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { // buffer_.value() != nullptr. State state_; absl::Status status_; - std::shared_ptr buffer_; + // The non-owning pointer to the underlying buffer. It is not nullptr for + // all types of holds. + TrackedDeviceBuffer* buffer_ptr_ = nullptr; + // If it is a donation hold, `buffer_` will not be nullptr. Otherwise, it is + // a nullptr. + std::unique_ptr buffer_; }; PjRtStreamExecutorBuffer(Shape on_device_shape, - std::shared_ptr device_buffer, + std::unique_ptr device_buffer, PjRtClient* client, PjRtDevice* device, PjRtMemorySpace* memory_space); ~PjRtStreamExecutorBuffer() override; @@ -749,7 +740,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { // If the buffer was shared via an external reference it is the client's // responsibility that accesses via that reference do not interfere with // accesses via the buffer returned from Release. - absl::StatusOr> Release( + absl::StatusOr> Release( bool wait_for_operations_to_complete); absl::StatusOr> DonateWithControlDependency( @@ -769,7 +760,14 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { // an outstanding external hold. // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds() // must be called first.) - absl::StatusOr> GetBufferForHoldLocked( + absl::StatusOr> + GetBufferForDonationHoldLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Adds a hold of usage or external reference and returns non-owning + // device_buffer_. Returns an error if device_buffer_ is null. + // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds() + // must be called first.) + absl::StatusOr GetBufferForUsageOrExternalHoldLocked( ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Adds a hold of hold->type() and initializes `hold` with device_buffer_. 
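// The two acquire paths declared above reduce to one invariant: every hold
// keeps a raw buffer_ptr_ for its accessors, and only a donation hold also
// owns the buffer, so ownership can be consumed on ConfirmDonation or handed
// back via DropDonationHold. Boiled down to a toy (illustrative names, not
// the PJRT classes):
#include <memory>
#include <utility>

struct TrackedSketch {};  // stands in for TrackedDeviceBuffer

class HoldSketch {
 public:
  // Usage / external-reference hold: borrow only.
  void AcquireBorrowed(TrackedSketch* ptr) { ptr_ = ptr; }

  // Donation hold: take unique ownership, but keep ptr_ set so accessors
  // behave identically for every hold type.
  void AcquireOwned(std::unique_ptr<TrackedSketch> owned) {
    ptr_ = owned.get();
    owned_ = std::move(owned);
  }

  TrackedSketch* buffer() const { return ptr_; }

  // Hand ownership back, e.g. when a donation is abandoned.
  std::unique_ptr<TrackedSketch> ReleaseOwned() { return std::move(owned_); }

 private:
  TrackedSketch* ptr_ = nullptr;          // non-null for every valid hold
  std::unique_ptr<TrackedSketch> owned_;  // non-null only for donation holds
};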
@@ -793,7 +791,12 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { // Drops a hold without taking any other action. Does a sanity check that // buffer==device_buffer_ or device_buffer_==nullptr. - void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer); + void DropUsageOrExternalHold(ScopedHold::Type type, + TrackedDeviceBuffer* buffer); + + // Drops a hold without taking any other action. Does a sanity check that + // buffer==device_buffer_ or device_buffer_==nullptr. + void DropDonationHold(std::unique_ptr buffer); absl::StatusOr, std::shared_ptr>> @@ -802,7 +805,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { LocalDeviceState* transfer_local_device, LocalDeviceState* src_local_device, se::Stream* transfer_stream, - std::shared_ptr src_device_buffer); + const TrackedDeviceBuffer& src_device_buffer); absl::StatusOr> CopyToDeviceMemorySpace( PjRtDevice* dst_device, PjRtMemorySpace* dst_memory_space = nullptr); @@ -812,7 +815,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { PjRtMemorySpace* const memory_space_; mutable absl::Mutex mu_; - std::shared_ptr device_buffer_ ABSL_GUARDED_BY(mu_); + std::unique_ptr device_buffer_ ABSL_GUARDED_BY(mu_); // Count of holds on the buffer. std::array holds_ ABSL_GUARDED_BY(mu_); PjRtFuture<>::Promise definition_promise_ ABSL_GUARDED_BY(mu_); @@ -999,7 +1002,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { std::shared_ptr definition_event, PjRtDevice* device, std::vector>& compute_callbacks, - std::vector>& buffers_to_release) + std::vector>& buffers_to_release) const; absl::StatusOr ExecuteHelper( diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index 9241d470f4da0b..83d2231dd0f12c 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -187,6 +187,19 @@ void BufferSequencingEvent::ExecuteFutureTasks() { thread_pool_->Schedule(std::move(call_all_task_callbacks)); } +ShapedBuffer RawSEDeviceMemory::AsShapedBuffer( + PjRtDevice* device, const Shape& on_device_shape) const { + ShapedBuffer shaped_buffer(on_device_shape, device->local_device_id().value(), + device->local_hardware_id().value()); + ShapeTree::iterator iterator = + shaped_buffer.buffers().begin(); + CHECK(iterator != shaped_buffer.buffers().end()); + iterator->second = mem(); + ++iterator; + CHECK(iterator == shaped_buffer.buffers().end()); + return shaped_buffer; +} + class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory { public: AllocatedRawSEDeviceMemory(se::DeviceMemoryBase value, int device_ordinal, @@ -321,12 +334,20 @@ void GetDeviceBufferEvents( } } -void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer, - se::Stream* stream) { - absl::flat_hash_set events; - GetDeviceBufferEvents(buffer, /*get_usage_events=*/false, &events); - for (BufferSequencingEvent* event : events) { - event->WaitForEventOnStream(stream); +void WaitForBufferDefinitionEventsOnStream( + absl::Span> definition_events, + se::Stream* stream) { + if (definition_events.size() <= 1) { + for (const auto& event : definition_events) { + event->WaitForEventOnStream(stream); + } + } else { + absl::flat_hash_set events; + for (const auto& event : definition_events) { + if (events.emplace(event.get()).second) { + event->WaitForEventOnStream(stream); + } + } } } diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.h b/third_party/xla/xla/pjrt/tracked_device_buffer.h index 
c3a3be91aaf773..6af9e6fd5cdbc0 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.h +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.h @@ -210,6 +210,10 @@ class RawSEDeviceMemory : public tsl::ReferenceCounted { // buffer. virtual void UnsafeReleaseMemory() = 0; + // Builds a ShapedBuffer which points to mem() of shape on_device_shape. + ShapedBuffer AsShapedBuffer(PjRtDevice* device, + const Shape& on_device_shape) const; + static tsl::RCReference Create( se::DeviceMemoryBase value, PjRtLocalDeviceId device_id, se::DeviceMemoryAllocator* allocator); @@ -269,8 +273,8 @@ class TrackedDeviceBuffer { return device_memory_; } - absl::Span> definition_events() - const { + const absl::InlinedVector, 2>& + definition_events() const { return definition_events_; } absl::Span usage_events() const { @@ -341,8 +345,9 @@ void GetDeviceBufferEvents(const TrackedDeviceBuffer& buffer, absl::flat_hash_set* events); // Waits for all of the definition events in a buffer on 'stream'. -void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer, - se::Stream* stream); +void WaitForBufferDefinitionEventsOnStream( + absl::Span> definition_events, + se::Stream* stream); } // namespace xla From d5272b6ab7206c7ae8acbcbd226812861472f148 Mon Sep 17 00:00:00 2001 From: Chenguang Wang Date: Fri, 4 Apr 2025 16:18:16 -0700 Subject: [PATCH 0258/1324] Remove SplitMatchAndRewrite usages in TensorFlow. This was deleted in the upstream: https://github.com/llvm/llvm-project/commit/69f59d59cb02c06f1fac93ea5b19c2df9a684109 PiperOrigin-RevId: 744101541 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.cc | 35 +- .../compose_uniform_quantized_type_pass.cc | 58 +++- ...uniform_quantized_stablehlo_to_tfl_pass.cc | 327 +++++++++--------- .../mlir/lite/transforms/optimize_pass.cc | 16 +- .../passes/defer_activation_transpose.cc | 57 +-- .../passes/fold_constant_transpose.cc | 22 +- .../stablehlo/passes/insert_weight_param.cc | 50 ++- .../passes/nchw_convolution_to_nhwc.cc | 22 +- .../stablehlo/passes/quantization_patterns.cc | 45 +-- .../tensorflow/passes/add_dump_tensor_op.cc | 18 +- .../tensorflow/passes/cast_bf16_ops_to_f32.cc | 17 +- .../passes/remove_var_init_by_const.cc | 17 +- .../prepare_tpu_computation_for_tf_export.cc | 13 +- .../mlir/dtensor_layout_to_xla_sharding_op.cc | 42 +-- .../stablehlo_add_quant_dequant_conv.cpp | 13 +- 15 files changed, 395 insertions(+), 357 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index dfa9f5b094b949..ddb9a9d1017a63 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -2230,21 +2230,18 @@ namespace { // * The input's defining op is another tfl.reshape. // TODO(antiagainst): This pattern probably should be moved to the peephole // category, after we have the infra for peephole passes. -struct RemoveAdjacentReshape : public RewritePattern::SplitMatchAndRewrite { +struct RemoveAdjacentReshape : public RewritePattern { explicit RemoveAdjacentReshape(MLIRContext* context) - : RewritePattern::SplitMatchAndRewrite(ReshapeOp::getOperationName(), 1, - context) {} + : RewritePattern(ReshapeOp::getOperationName(), 1, context) {} - LogicalResult match(Operation* op) const override { - auto thisOp = cast(op); - auto prevOp = thisOp.getOperand(0).getDefiningOp(); - return isa_and_nonnull(prevOp) ? 
success() : failure(); - } - - void rewrite(Operation* op, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { auto thisOp = cast(op); - auto prevOp = cast(thisOp.getOperand(0).getDefiningOp()); - + auto prevOp = + dyn_cast_or_null(thisOp.getOperand(0).getDefiningOp()); + if (!prevOp) { + return failure(); + } // Replace // %1 = "tfl.reshape"(%0, %shape0) // %2 = "tfl.reshape"(%1, %shape1) @@ -2252,6 +2249,7 @@ struct RemoveAdjacentReshape : public RewritePattern::SplitMatchAndRewrite { // %2 = "tfl.reshape"(%0, %shape1) rewriter.replaceOpWithNewOp( op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1)); + return success(); } }; @@ -2963,12 +2961,12 @@ namespace { /// This pattern matches and remove a tfl.fake_quant if all the users of this op /// and itself have "minmax" attribute set. -struct DropFakeQuant : public RewritePattern::SplitMatchAndRewrite { +struct DropFakeQuant : public RewritePattern { explicit DropFakeQuant(MLIRContext* context) - : RewritePattern::SplitMatchAndRewrite(FakeQuantOp::getOperationName(), 1, - context) {} + : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {} - LogicalResult match(Operation* op) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { // We only match the op with valid "minmax" attribute. if (!HasValidMinMaxAttribute(op)) return failure(); @@ -2978,12 +2976,9 @@ struct DropFakeQuant : public RewritePattern::SplitMatchAndRewrite { for (auto* operand : fakeQuantOp.getResult().getUsers()) if (!HasValidMinMaxAttribute(operand)) return failure(); - return success(); - } - - void rewrite(Operation* op, PatternRewriter& rewriter) const override { // Replace the matched FakeQuantOp by its primary operand. rewriter.replaceOp(op, op->getOperand(0)); + return success(); } }; } // end anonymous namespace diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index c94fd5cd5fede4..4107859b7412af 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -427,11 +427,21 @@ class UniformDequantizeFunctionCallPattern { // %4 = stablehlo.uniform_dequantize %3 // Dequantize the output. // ``` class ComposeUniformQuantizedConvolutionOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ConvolutionOp op) const final { + LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::ConvolutionOp op) const { // Verify operands' types. for (Type operand_type : op.getOperandTypes()) { if (Type element_type = @@ -643,8 +653,7 @@ class ComposeUniformQuantizedConvolutionOp return success(); } - void rewrite(stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const final { + void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const { // Rewrite `call @uniform_quantize` -> `stablehlo.uniform_quantize`. 
auto input_i8_to_f32_convert_op = cast(op.getOperand(0).getDefiningOp()); @@ -881,10 +890,21 @@ class ComposeUniformQuantizedConvolutionOp // cast isn't present, the filter constant (%3) should be i8 quantized values // disguised in f32. class ComposeUniformQuantizedDotGeneralOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - LogicalResult match(stablehlo::DotGeneralOp op) const final { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::DotGeneralOp op) const { auto input_i8_to_f32_convert_op = TryCast(op.getOperand(0).getDefiningOp(), /*name=*/"input_i8_to_f32_convert_op"); @@ -988,8 +1008,7 @@ class ComposeUniformQuantizedDotGeneralOp return success(); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const final { + void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const { // Build uniform quantized type for input. auto input_i8_to_f32_convert_op = cast(op.getOperand(0).getDefiningOp()); @@ -1304,11 +1323,21 @@ class ComposeUniformQuantizedDotGeneralOp // %5 = stablehlo.uniform_dequantize %4 // Dequantize the output. // ``` class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DotGeneralOp op) const final { + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::DotGeneralOp op) const { // q1 - z1 if (failed(MatchQuantizedOperand(op.getOperand(0)))) { LLVM_DEBUG(llvm::dbgs() @@ -1365,8 +1394,7 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations return success(); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const final { + void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const { // Build uniform quantized type for input 1 (lhs). auto input1_zero_point_subtract_op = cast(op.getOperand(0).getDefiningOp()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index 13bc60b1648e3a..e2571e2ad84af5 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -444,15 +444,15 @@ int64_t GetConvolutionKernelInputFeatureDimension(bool is_depthwise) { // stablehlo.uniform_quantize -> tfl.quantize // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. class RewriteUniformQuantizeOp - : public OpRewritePattern< - stablehlo::UniformQuantizeOp>::SplitMatchAndRewrite { - using SplitMatchAndRewrite::SplitMatchAndRewrite; + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; // Determines whether the input and output types are compatible with // `tfl.quantize`. 
See the definition for the `QUANTIZE` kernel for the // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/quantize.cc#L105). - LogicalResult match(stablehlo::UniformQuantizeOp op) const override { + LogicalResult matchAndRewrite(stablehlo::UniformQuantizeOp op, + PatternRewriter& rewriter) const override { const Type input_element_type = GetElementType(op.getOperand()); if (!(input_element_type.isa() || IsI32F32UniformQuantizedType(input_element_type) || @@ -475,29 +475,25 @@ class RewriteUniformQuantizeOp return failure(); } - return success(); - } - - void rewrite(stablehlo::UniformQuantizeOp op, - PatternRewriter& rewriter) const override { Type output_type = *op->getResultTypes().begin(); rewriter.replaceOpWithNewOp( op, output_type, /*input=*/op.getOperand(), /*qtype=*/TypeAttr::get(output_type)); + return success(); } }; // stablehlo.uniform_dequantize -> tfl.dequantize class RewriteUniformDequantizeOp - : public OpRewritePattern< - stablehlo::UniformDequantizeOp>::SplitMatchAndRewrite { - using SplitMatchAndRewrite::SplitMatchAndRewrite; + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; // Determines whether the input and output types are compatible with // `tfl.dequantize`. See the definition for the `DEQUANTIZE` kernel for the // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/dequantize.cc#L52). - LogicalResult match(stablehlo::UniformDequantizeOp op) const override { + LogicalResult matchAndRewrite(stablehlo::UniformDequantizeOp op, + PatternRewriter& rewriter) const override { const auto input_storage_type = GetElementType(op.getOperand()) .cast() .getStorageType() @@ -518,13 +514,9 @@ class RewriteUniformDequantizeOp return failure(); } - return success(); - } - - void rewrite(stablehlo::UniformDequantizeOp op, - PatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp( op, /*resultTypes=*/op->getResultTypes(), /*input=*/op.getOperand()); + return success(); } }; @@ -563,17 +555,26 @@ class RewriteUniformDequantizeOp // * The filter tensor's rank is 2. The contracting dimension should be the // first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: // Sets benefit to 10 to make this pattern more preferred than smaller local // transformations like `stablehlo.transpose`->`tfl.transpose`, as this // pattern involves `stablehlo.transpose` in some cases. 
explicit RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp( MLIRContext* ctx) - : OpRewritePattern::SplitMatchAndRewrite( - ctx, /*benefit=*/10) {} + : OpRewritePattern(ctx, /*benefit=*/10) {} - LogicalResult match(stablehlo::DotGeneralOp op) const override { + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::DotGeneralOp op) const { const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = op.getDotDimensionNumbers(); const bool is_batch_matmul = !IsDotGeneralFullyConnected(op).value(); @@ -605,8 +606,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp has_i32_output); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { + void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const { const Type output_type = GetElementType(op.getResult()); const bool has_i32_output = IsI32F32UniformQuantizedType(output_type) || @@ -624,7 +624,6 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } } - private: static LogicalResult MatchDotGeneralToTflBatchMatmulOp( stablehlo::DotGeneralOp op, const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, @@ -1007,10 +1006,21 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp // * The filter tensor's format is `[0, 1, i, o]`. // * Not a depthwise convolution. class RewriteQuantizedConvolutionOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - LogicalResult match(stablehlo::ConvolutionOp op) const override { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::ConvolutionOp op) const { const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); const bool fuse_bias_constant = @@ -1056,8 +1066,7 @@ class RewriteQuantizedConvolutionOp return success(); } - void rewrite(stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const override { + void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const { const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); stablehlo::ConvDimensionNumbersAttr dimension_numbers = @@ -1148,7 +1157,6 @@ class RewriteQuantizedConvolutionOp } } - private: static LogicalResult MatchInput(Value input) { auto input_type = input.getType().cast(); if (const auto input_element_type = input_type.getElementType(); @@ -1481,16 +1489,15 @@ class RewriteQuantizedConvolutionOp // Rewrites quantized `stablehlo.transpose` to `tfl.transpose`. 
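// The migration recipe used for every pattern in this commit: the removed
// split match()/rewrite() pair either becomes a pair of private helpers
// invoked from a new matchAndRewrite(), or is folded into it directly, with
// every failed check returning before the first IR mutation. Schematically
// (hypothetical MyOp and IsFoldable, not ops from this file):
struct FoldMyOpSketch : public mlir::OpRewritePattern<MyOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      MyOp op, mlir::PatternRewriter& rewriter) const override {
    if (!IsFoldable(op)) return mlir::failure();  // former match() body
    rewriter.replaceOp(op, op.getOperand());      // former rewrite() body
    return mlir::success();                       // IR touched only on success
  }
};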
class RewriteQuantizedTransposeOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - - LogicalResult match(stablehlo::TransposeOp op) const override { - return success(IsOpFullyQuantized(op)); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::TransposeOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::TransposeOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } auto operand_type = op.getOperand().getType().cast(); const int64_t rank = operand_type.getRank(); ArrayRef shape(rank); @@ -1506,54 +1513,54 @@ class RewriteQuantizedTransposeOp rewriter.create(op.getLoc(), permutation_attr); rewriter.replaceOpWithNewOp(op, op.getOperand(), permutation); + return success(); } }; // Rewrites quantized stablehlo.reshape to tfl.reshape. class RewriteQuantizedReshapeOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ReshapeOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::ReshapeOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::ReshapeOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } rewriter.replaceOpWithNewOp( op, op.getOperand(), CreateI32ShapeConstantOp(op.getResult().getType(), op->getLoc(), rewriter)); + return success(); } }; class RewriteQuantizedDynamicReshapeOp - : public OpRewritePattern< - stablehlo::DynamicReshapeOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - - LogicalResult match(stablehlo::DynamicReshapeOp op) const override { - return success(IsQuantizedTensorType(op.getOperand().getType()) && - IsQuantizedTensorType(op.getResult().getType())); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::DynamicReshapeOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::DynamicReshapeOp op, + PatternRewriter& rewriter) const override { + if (!IsQuantizedTensorType(op.getOperand().getType()) || + !IsQuantizedTensorType(op.getResult().getType())) { + return failure(); + } rewriter.replaceOpWithNewOp(op, op.getOperand(), op.getOutputShape()); + return success(); } }; // Rewrites quantized stablehlo.select to tfl.select_v2. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
-class RewriteQuantizedSelectOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedSelectOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::SelectOp op) const override { + LogicalResult matchAndRewrite(stablehlo::SelectOp op, + PatternRewriter& rewriter) const override { if (!IsQuantizedTensorType(op.getOperand(1).getType())) { return failure(); } @@ -1563,52 +1570,47 @@ class RewriteQuantizedSelectOp if (!IsQuantizedTensorType(op.getResult().getType())) { return failure(); } - return success(); - } - - void rewrite(stablehlo::SelectOp op, - PatternRewriter& rewriter) const override { Value pred = op.getOperand(0); Value on_true = op.getOperand(1); Value on_false = op.getOperand(2); rewriter.replaceOpWithNewOp(op, pred, on_true, on_false); + return success(); } }; // Rewrites quantized stablehlo.concatenate to tfl.concatenation. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. class RewriteQuantizedConcatenateOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ConcatenateOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::ConcatenateOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::ConcatenateOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } Type output_type = op.getResult().getType(); uint32_t axis = CastI64ToI32(op.getDimension()).value(); rewriter.replaceOpWithNewOp( op, output_type, op.getOperands(), axis, /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + return success(); } }; // Rewrites quantized stablehlo.pad to tfl.padv2. // tfl.dilate is introduced in between when interior padding exists. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteQuantizedPadOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedPadOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::PadOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::PadOp op, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::PadOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } Value input = op.getOperand(); // If any of the interior padding is non-zero, operand should be dilated // first, and then padded. @@ -1639,6 +1641,7 @@ class RewriteQuantizedPadOp rewriter.create(op.getLoc(), padding_attr); rewriter.replaceOpWithNewOp(op, output_type, input, padding, constant_values); + return success(); } Value InsertDilateOp(stablehlo::PadOp op, PatternRewriter& rewriter) const { @@ -1673,17 +1676,15 @@ class RewriteQuantizedPadOp }; // Rewrites quantized stablehlo.slice to tfl.slice or tfl.strided_slice. 
-class RewriteQuantizedSliceOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedSliceOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::SliceOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::SliceOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::SliceOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } auto operand_type = op.getOperand().getType().cast(); Type output_type = op.getResult().getType(); const int64_t rank = operand_type.getRank(); @@ -1716,7 +1717,7 @@ class RewriteQuantizedSliceOp if (llvm::all_of(strides, [](int64_t stride) { return stride == 1; })) { rewriter.replaceOpWithNewOp( op, output_type, op.getOperand(), start_idx, slice_size); - return; + return success(); } SmallVector stride_i32 = CastI64ArrayToI32(strides).value(); @@ -1727,6 +1728,7 @@ class RewriteQuantizedSliceOp /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/0, /*new_axis_mask=*/0, /*shrink_axis_mask=*/0, /*offset=*/false); + return success(); } }; @@ -1736,17 +1738,15 @@ class RewriteQuantizedSliceOp // output rank. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. class RewriteQuantizedBroadcastInDimOp - : public OpRewritePattern< - stablehlo::BroadcastInDimOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - - LogicalResult match(stablehlo::BroadcastInDimOp op) const override { - return success(IsOpFullyQuantized(op)); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::BroadcastInDimOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } auto operand_type = op.getOperand().getType().cast(); auto output_type = op.getResult().getType().cast(); Value input = op.getOperand(); @@ -1773,6 +1773,7 @@ class RewriteQuantizedBroadcastInDimOp rewriter.replaceOpWithNewOp(op, output_type, input, shape); + return success(); } Value InsertTransposeOp(stablehlo::BroadcastInDimOp op, @@ -1834,10 +1835,20 @@ class RewriteQuantizedBroadcastInDimOp // Rewrites quantized stablehlo.reduce_window with max to tfl.max_pool_2d. class RewriteQuantizedReduceWindowOpWithMax - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: LogicalResult MatchBinaryReduceFunction(Region& function) const { Block& body = function.front(); if (body.getNumArguments() != 2) return failure(); @@ -1853,7 +1864,7 @@ class RewriteQuantizedReduceWindowOpWithMax reduce_op.getRhs() == body.getArgument(1)); } - LogicalResult match(stablehlo::ReduceWindowOp op) const override { + LogicalResult match(stablehlo::ReduceWindowOp op) const { // Check that the reduce-window is a max-reduce-window. 
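// (MatchBinaryReduceFunction above admits exactly a two-argument region
// whose body is a single maximum of the two block arguments followed by a
// return, i.e. the reduction computes max(a, b); anything else is rejected.)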
if (failed(MatchBinaryReduceFunction(op.getBody()))) { return failure(); @@ -1887,8 +1898,7 @@ class RewriteQuantizedReduceWindowOpWithMax return success(IsOpFullyQuantized(op)); } - void rewrite(stablehlo::ReduceWindowOp op, - PatternRewriter& rewriter) const override { + void rewrite(stablehlo::ReduceWindowOp op, PatternRewriter& rewriter) const { Type result_type = op.getResult(0).getType(); Value input = op.getOperand(0); // Ops with padding is rejected in matching function, so we can use the @@ -1929,12 +1939,12 @@ class RewriteQuantizedReduceWindowOpWithMax // Condition 3 - `offset_dims` should be the last dimensions of `output`. // Condition 4 - shape of slice should be same with shape of input on the // offset dimensions. -class RewriteQuantizedGatherOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedGatherOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::GatherOp op) const override { + LogicalResult matchAndRewrite(stablehlo::GatherOp op, + PatternRewriter& rewriter) const override { const Type input_type = op.getOperand().getType(); const Type output_type = op.getResult().getType(); if (!IsQuantizedTensorType(input_type) || @@ -2014,35 +2024,28 @@ class RewriteQuantizedGatherOp } } - return success(); - } - - void rewrite(stablehlo::GatherOp op, - PatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp( op, /*output=*/op.getResult().getType(), /*params=*/op.getOperand(), /*indices=*/op.getStartIndices()); + return success(); } }; // Rewrites quantized stablehlo.dynamic_slice to tfl.slice. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. class RewriteQuantizedDynamicSliceOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DynamicSliceOp op) const override { + LogicalResult matchAndRewrite(stablehlo::DynamicSliceOp op, + PatternRewriter& rewriter) const override { if (!IsQuantizedTensorType(op.getOperand().getType()) || - !IsQuantizedTensorType(op.getResult().getType())) { + !IsQuantizedTensorType(op.getResult().getType()) || + !quant::HasStaticShape(op.getOperand())) { return failure(); } - return success(quant::HasStaticShape(op.getOperand())); - } - - void rewrite(stablehlo::DynamicSliceOp op, - PatternRewriter& rewriter) const override { Type output = op.getResult().getType(); Value input = op.getOperand(); TensorType operand_type = input.getType().cast(); @@ -2098,20 +2101,20 @@ class RewriteQuantizedDynamicSliceOp auto size = rewriter.create(op.getLoc(), size_attr); rewriter.replaceOpWithNewOp(op, output, input, begin, size); + return success(); } }; -class RewriteQuantizedAddOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedAddOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::AddOp op) const override { - return success(IsI8F32UniformQuantizedType(GetElementType(op.getLhs())) && - IsI8F32UniformQuantizedType(GetElementType(op.getRhs()))); - } - - void rewrite(stablehlo::AddOp op, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::AddOp op, + PatternRewriter& rewriter) const override { + if (!IsI8F32UniformQuantizedType(GetElementType(op.getLhs())) 
|| + !IsI8F32UniformQuantizedType(GetElementType(op.getRhs()))) { + return failure(); + } TFL::QConstOp lhs_qconst_op; TFL::QConstOp rhs_qconst_op; @@ -2137,24 +2140,25 @@ class RewriteQuantizedAddOp lhs_qconst_op ? lhs_qconst_op : op.getOperand(0), rhs_qconst_op ? rhs_qconst_op : op.getOperand(1), /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + return success(); } }; // Rewrites quantized `stablehlo.constant` to `tfl.pseudo_qconst`. class RewriteQuantizedConstantOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - - LogicalResult match(stablehlo::ConstantOp op) const override { - return success(IsQuantizedTensorType(op.getOutput().getType())); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::ConstantOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::ConstantOp op, + PatternRewriter& rewriter) const override { + if (!IsQuantizedTensorType(op.getOutput().getType())) { + return failure(); + } rewriter.replaceOpWithNewOp( op, /*qtype=*/TypeAttr::get(op.getOutput().getType()), /*value=*/op.getValue()); + return success(); } }; @@ -2163,19 +2167,18 @@ class RewriteQuantizedConstantOp // `stablehlo.dot_general` op relies on existing passes for conversion of // StableHLO -> MHLO -> TF -> TFL. class RewriteHybridQuantizedDotGeneralOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DotGeneralOp op) const override { + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { // Lhs and result should not be quantized and rhs should be quantized. - return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && - IsQuantizedTensorType(op->getOperand(1).getType()) && - !IsQuantizedTensorType(op->getResult(0).getType())); - } - - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { + if (IsQuantizedTensorType(op->getOperand(0).getType()) || + !IsQuantizedTensorType(op->getOperand(1).getType()) || + IsQuantizedTensorType(op->getResult(0).getType())) { + return failure(); + } Value rhs = op.getRhs(); Type lhs_element_type = op.getLhs().getType().template cast().getElementType(); @@ -2185,6 +2188,7 @@ class RewriteHybridQuantizedDotGeneralOp op->getLoc(), /*output=*/dequantized_rhs_type, /*input=*/rhs); rewriter.replaceAllUsesExcept(rhs, dq.getOutput(), dq); + return success(); } }; @@ -2194,26 +2198,24 @@ class RewriteHybridQuantizedDotGeneralOp // Legalization of float `stablehlo.convolution` op relies on existing passes // for conversion of StableHLO -> MHLO -> TF -> TFL. class RewriteHybridQuantizedConvolutionOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: explicit RewriteHybridQuantizedConvolutionOp(MLIRContext* ctx) - : OpRewritePattern::SplitMatchAndRewrite( - ctx, /*benefit=*/5) {} + : OpRewritePattern(ctx, /*benefit=*/5) {} - LogicalResult match(stablehlo::ConvolutionOp op) const override { + LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { if (failed(MatchConvolutionFormat(op))) { LLVM_DEBUG(llvm::dbgs() << "Failed to match dimension format for convolution_op.\n"); return failure(); } // Lhs and result should not be quantized and rhs should be quantized. 
- return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && - IsQuantizedTensorType(op->getOperand(1).getType()) && - !IsQuantizedTensorType(op->getResult(0).getType())); - } - - void rewrite(stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const override { + if (IsQuantizedTensorType(op->getOperand(0).getType()) || + !IsQuantizedTensorType(op->getOperand(1).getType()) || + IsQuantizedTensorType(op->getResult(0).getType())) { + return failure(); + } const bool is_depthwise = IsDepthwiseConvolution(op); Operation* filter_op = op.getRhs().getDefiningOp(); @@ -2243,6 +2245,7 @@ class RewriteHybridQuantizedConvolutionOp op->getLoc(), /*output=*/dequantized_rhs_type, /*input=*/new_filter); rewriter.replaceAllUsesExcept(filter_op->getResult(0), dq.getOutput(), dq); + return success(); } private: diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc index 2816e8bb090f1a..1de8398dbd8bd9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc @@ -2516,9 +2516,11 @@ struct EliminateQDQPairs : public OpRewritePattern { // (HasRankAtLeast<2> $bias), // (IsDefinedByFullyConnectedOp $lhs)]>; struct UndoBroadcastFullyConnectedBiasAddWithQDQs - : public OpRewritePattern::SplitMatchAndRewrite { - using SplitMatchAndRewrite::SplitMatchAndRewrite; - LogicalResult match(TFL::AddOp add_op) const override { + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::AddOp add_op, + PatternRewriter &rewriter) const override { if (!add_op->hasOneUse()) { return failure(); } @@ -2557,13 +2559,6 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs return failure(); } - return success(); - } - - void rewrite(TFL::AddOp add_op, PatternRewriter &rewriter) const override { - auto dq_op = cast(add_op.getRhs().getDefiningOp()); - auto q_op = cast(dq_op.getInput().getDefiningOp()); - auto bias_op = cast(q_op.getInput().getDefiningOp()); auto new_bias = FlattenTo1D(bias_op.getValueAttr()); auto new_bias_type = new_bias.getType(); auto new_bias_op = rewriter.create( @@ -2588,6 +2583,7 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs // Remove old bias rewriter.eraseOp(bias_op); + return success(); } }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc index b16f8787e9ea35..e40c58e9f6f535 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include "absl/base/nullability.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -98,12 +99,12 @@ absl::Nullable SkipUpwardsOptionalBroadcastInDimOp( return op; } -class DeferActivationTransposeForAddOp - : public OpRewritePattern::SplitMatchAndRewrite { +class DeferActivationTransposeForAddOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(AddOp op) const override { + LogicalResult matchAndRewrite(AddOp op, + PatternRewriter& rewriter) const override { // Only supports the case for 2D convolution. const Value lhs = op.getOperand(0); if (!HasRankOf(lhs, /*rank=*/4)) return failure(); @@ -120,12 +121,13 @@ class DeferActivationTransposeForAddOp } // Match LHS permutation that converts: NHWC -> NCHW. - return IsTransposeOpWithPermuation(lhs.getDefiningOp(), - kNhwcToNchwPermutation); - } + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } - void rewrite(AddOp op, PatternRewriter& rewriter) const override { DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); } }; @@ -134,12 +136,12 @@ class DeferActivationTransposeForAddOp // to the result. The reduce function should be equivalent to // `stablehlo.maximum`, representing max pooling. class DeferActivationTransposeForMaxPoolReduceWindowOp - : public OpRewritePattern< - mlir::stablehlo::ReduceWindowOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(mlir::stablehlo::ReduceWindowOp op) const override { + LogicalResult matchAndRewrite(mlir::stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { if (failed(MatchMaxPoolReduceWindowOp(op))) return failure(); // Match only when the lhs is connected to a transpose. @@ -148,13 +150,12 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp if (!HasRankOf(lhs, /*rank=*/4)) return failure(); // Match input permutation that converts: NHWC -> NCHW. - return IsTransposeOpWithPermuation(lhs.getDefiningOp(), - kNhwcToNchwPermutation); - } + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } - // Pushes the transpose op at the input to the result. - void rewrite(mlir::stablehlo::ReduceWindowOp op, - PatternRewriter& rewriter) const override { + // Pushes the transpose op at the input to the result. auto transpose_op = cast(op.getOperand(0).getDefiningOp()); const auto result_type = mlir::cast(op.getResult(0).getType()); @@ -194,6 +195,7 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp rewriter); rewriter.replaceAllUsesWith(op.getResult(0), result_transpose_op); + return success(); } private: @@ -242,12 +244,12 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp // Rewrites `maximum(transpose(%rhs), %lhs)` patterns to // `transpose(maximum(%rhs, transpose(%lhs)))`. 
-class DeferActivationTransposeForMaxOp - : public OpRewritePattern::SplitMatchAndRewrite { +class DeferActivationTransposeForMaxOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(MaxOp op) const override { + LogicalResult matchAndRewrite(MaxOp op, + PatternRewriter& rewriter) const override { Value input = op.getOperand(0); if (!HasRankOf(input, /*rank=*/4)) return failure(); @@ -258,12 +260,13 @@ class DeferActivationTransposeForMaxOp return failure(); } - return IsTransposeOpWithPermuation(input.getDefiningOp(), - kNhwcToNchwPermutation); - } - - void rewrite(MaxOp op, PatternRewriter& rewriter) const override { + if (IsTransposeOpWithPermuation(input.getDefiningOp(), + kNhwcToNchwPermutation) + .failed()) { + return failure(); + } DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); } }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc index 79872b57e1574e..197fb1c868afb3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc @@ -116,12 +116,12 @@ class DenseElementsTransposer { }; class FoldTransposedConstantOp - : public OpRewritePattern< - mlir::stablehlo::TransposeOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(mlir::stablehlo::TransposeOp op) const override { + LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op, + PatternRewriter& rewriter) const override { Value operand = op.getOperand(); auto const_op = dyn_cast_or_null(operand.getDefiningOp()); @@ -133,14 +133,9 @@ class FoldTransposedConstantOp return failure(); } - return success( - mlir::isa_and_nonnull(const_op.getValue())); - } - - void rewrite(mlir::stablehlo::TransposeOp op, - PatternRewriter& rewriter) const override { - auto const_op = - cast(op.getOperand().getDefiningOp()); + if (!mlir::isa_and_nonnull(const_op.getValue())) { + return failure(); + } const auto value_attr = mlir::cast(const_op.getValue()); @@ -169,7 +164,8 @@ class FoldTransposedConstantOp combined_loc, new_value_attr); rewriter.replaceAllUsesWith(op, new_const_op); - }; + return success(); + } }; } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc index b3c309c76adb79..ac8649835f78ff 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc @@ -80,15 +80,13 @@ class InsertWeightParamPass // Inserts quantization parameters for weights for hybrid quantization of // `stablehlo.convolution` and `stablehlo.dot_general`. 
class InsertWeightParamPattern - : public OpTraitRewritePattern< - OpTrait::ConstantLike>::SplitMatchAndRewrite { + : public OpTraitRewritePattern { public: explicit InsertWeightParamPattern(MLIRContext* context) - : SplitMatchAndRewrite(Pattern::MatchTraitOpTypeTag(), - TypeID::get(), 1, context) { - } + : OpTraitRewritePattern(context) {} - LogicalResult match(Operation* op) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { if (op->getNumResults() != 1) { return failure(); } @@ -96,27 +94,11 @@ class InsertWeightParamPattern if (!type || !type.getElementType().isF32()) { return failure(); } - return success( - op->hasOneUse() && - IsWeightQuantizableFunction(*op->getUses().begin(), type.getRank())); - } - - // Checks if the operand is second operand of `tf.XlaCallModule` op for - // `stablehlo.convolution` or `stablehlo.dot_general` with fully_quantizable - // trait. - static bool IsWeightQuantizableFunction(OpOperand& operand, int64_t rank) { - if (operand.getOperandNumber() != 1) { - return false; - } - Operation* user = operand.getOwner(); - if (!IsWeightOnlyQuantizableOp(*user)) { - return false; + if (!op->hasOneUse() || + !IsWeightQuantizableFunction(*op->getUses().begin(), type.getRank())) { + return failure(); } - Method method = GetQuantizationMethodOrDefault(user); - return HasValidWeightOnlyPtqMethod(method.weight_only_ptq(), rank); - } - void rewrite(Operation* op, PatternRewriter& rewriter) const override { Operation* quantizable_op = *op->getUsers().begin(); DenseFPElementsAttr attr; matchPattern(op->getResult(0), m_Constant(&attr)); @@ -144,7 +126,7 @@ class InsertWeightParamPattern op->emitError( "Failed to get weight quantization parameters for weight-only " "quantization."); - return; + return failure(); } const Type expressed_type = op->getResult(0).getType(); @@ -157,6 +139,22 @@ class InsertWeightParamPattern auto dq = rewriter.create(op->getLoc(), expressed_type, q); quantizable_op->setOperand(1, dq.getResult()); + return success(); + } + + // Checks if the operand is second operand of `tf.XlaCallModule` op for + // `stablehlo.convolution` or `stablehlo.dot_general` with fully_quantizable + // trait. 
+ static bool IsWeightQuantizableFunction(OpOperand& operand, int64_t rank) { + if (operand.getOperandNumber() != 1) { + return false; + } + Operation* user = operand.getOwner(); + if (!IsWeightOnlyQuantizableOp(*user)) { + return false; + } + Method method = GetQuantizationMethodOrDefault(user); + return HasValidWeightOnlyPtqMethod(method.weight_only_ptq(), rank); } private: diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc index 3b2b20bc2e4c52..4bb871a56886b3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc @@ -48,12 +48,12 @@ class NchwConvolutionToNhwcPass // * Src dimension numbers: [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1] // * Dst dimension numbers: [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] class RewriteNchwConvolutionToNhwc - : public OpRewritePattern< - mlir::stablehlo::ConvolutionOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(mlir::stablehlo::ConvolutionOp op) const override { + LogicalResult matchAndRewrite(mlir::stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { // Handles 2D convolutions only. if (!HasRankOf(op.getOperand(0), /*rank=*/4) || !HasRankOf(op.getOperand(1), /*rank=*/4)) { @@ -63,13 +63,14 @@ class RewriteNchwConvolutionToNhwc if (!IsOpNotQuantized(op)) return failure(); const ConvDimensionNumbersAttr dimension_nums = op.getDimensionNumbers(); - return success(MatchInputDimensionNumbers(dimension_nums) && - MatchKernelDimensionNumbers(dimension_nums) && - MatchOutputDimensionNumbers(dimension_nums)); - } + const bool dimension_nums_matched = + MatchInputDimensionNumbers(dimension_nums) && + MatchKernelDimensionNumbers(dimension_nums) && + MatchOutputDimensionNumbers(dimension_nums); + if (!dimension_nums_matched) { + return failure(); + } - void rewrite(mlir::stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const override { // Transpose the input tensor: [b, f, 0, 1] => [b, 0, 1, f] Value input = op->getOperand(0); const TensorType new_input_tensor_type = GetTransposedTensorType( @@ -130,6 +131,7 @@ class RewriteNchwConvolutionToNhwc rewriter.getDenseI64ArrayAttr(kNhwcToNchwPermutation)); rewriter.replaceAllUsesWith(op, output_transpose_op); + return success(); } private: diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 232115e53d3219..d6a88055c8c855 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -668,16 +668,16 @@ void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( template >> -class XlaCallModuleOpToCallOp - : public OpRewritePattern::SplitMatchAndRewrite { +class XlaCallModuleOpToCallOp : public OpRewritePattern { public: explicit XlaCallModuleOpToCallOp( MLIRContext& ctx, const bool enable_per_channel_quantized_weight) - : OpRewritePattern::SplitMatchAndRewrite(&ctx), + : OpRewritePattern::OpRewritePattern(&ctx), enable_per_channel_quantized_weight_( enable_per_channel_quantized_weight) {} - LogicalResult match(TF::XlaCallModuleOp op) const override { + LogicalResult 
matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter& rewriter) const override { ModuleOp module_op = op->getParentOfType(); // Ignore ops without quantization method. @@ -698,22 +698,20 @@ class XlaCallModuleOpToCallOp return failure(); } Method quantization_method = GetQuantizationMethodOrDefault(op); - return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) - .match(entry_func_op, quantization_method); - } + if (FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + .match(entry_func_op, quantization_method) + .failed()) { + return failure(); + } - void rewrite(TF::XlaCallModuleOp xla_call_module_op, - PatternRewriter& rewriter) const override { // TODO: b/331145946 - Each quantization method should be valid // (GetQuantizationMethodOrDefault swallows invalid method attribute). Check // the validity in `match()`. Use accessors to achieve this. - const Method quantization_method = - GetQuantizationMethodOrDefault(xla_call_module_op); - ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - *rewriter.getContext(), rewriter, xla_call_module_op, + *rewriter.getContext(), rewriter, op, FuncBodyRewritePatternT(enable_per_channel_quantized_weight_), quantization_method); + return success(); } private: @@ -726,14 +724,22 @@ class XlaCallModuleOpToCallOp // Quantizes only when the nested region consists of ops whose quantization // parameters can be propagated from outside. class QuantizeOpWithRegionPattern - : public OpRewritePattern< - quantfork::DequantizeCastOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: explicit QuantizeOpWithRegionPattern(MLIRContext& ctx) - : OpRewritePattern::SplitMatchAndRewrite( - &ctx) {}; + : OpRewritePattern(&ctx) {}; - LogicalResult match(quantfork::DequantizeCastOp op) const final { + LogicalResult matchAndRewrite(quantfork::DequantizeCastOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(quantfork::DequantizeCastOp op) const { // Match only when there is one user of the dequantize op. if (!op.getResult().hasOneUse()) { return failure(); @@ -762,7 +768,7 @@ class QuantizeOpWithRegionPattern } void rewrite(quantfork::DequantizeCastOp op, - PatternRewriter& rewriter) const final { + PatternRewriter& rewriter) const { // Rewrite the floating-point ops to the quantized version, by fusing // preceding dequantize ops and succeding quantize ops. for (Operation* op_with_region : op.getResult().getUsers()) { @@ -849,7 +855,6 @@ class QuantizeOpWithRegionPattern } } - private: // Checks if an op is quantizable in a nested region. bool IsOpQuantizableInNestedRegion(Operation& op) const { return isa(op); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc index 57b8b23de72cc4..0b73b9c550b62a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc @@ -164,17 +164,25 @@ class AddDumpTensorOpPass }; template -class AddDumpTensorOp - : public OpRewritePattern::SplitMatchAndRewrite { +class AddDumpTensorOp : public OpRewritePattern { public: // Does not take ownership of context, which must refer to a valid value that // outlives this object. 
explicit AddDumpTensorOp(MLIRContext *context, DebuggerType debugger_type, std::string log_dir_path) - : OpRewritePattern::SplitMatchAndRewrite(context), + : OpRewritePattern(context), debugger_type_(debugger_type), log_dir_path_(std::move(log_dir_path)) {} + LogicalResult matchAndRewrite(LiftedOpT op, + PatternRewriter &rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + private: SmallVector CreateDumpAttributes( PatternRewriter &rewriter, const StringRef folder_name, @@ -204,7 +212,7 @@ class AddDumpTensorOp return symbol_table.insert(new_ref_func); } - LogicalResult match(LiftedOpT op) const override { + LogicalResult match(LiftedOpT op) const { if (!op->hasAttr(kQuantTraitAttrName) || op->getNumResults() != 1) { return failure(); } @@ -219,7 +227,7 @@ class AddDumpTensorOp return success(); } - void rewrite(LiftedOpT op, PatternRewriter &rewriter) const override { + void rewrite(LiftedOpT op, PatternRewriter &rewriter) const { // Only support ops with 1 results Value result = op->getResult(0); rewriter.setInsertionPointAfterValue(result); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc index 7fea73725af761..50d4030083d99b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc @@ -47,13 +47,22 @@ class CastBf16OpsToF32Pass void runOnOperation() override; }; -class CastBf16OpsToF32 : public RewritePattern::SplitMatchAndRewrite { +class CastBf16OpsToF32 : public RewritePattern { public: explicit CastBf16OpsToF32(MLIRContext* context) - : SplitMatchAndRewrite(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } private: - LogicalResult match(Operation* op) const override { + LogicalResult match(Operation* op) const { if (isa(op) || op->getName().hasTrait()) { return failure(); @@ -71,7 +80,7 @@ class CastBf16OpsToF32 : public RewritePattern::SplitMatchAndRewrite { return failure(); } - void rewrite(Operation* op, PatternRewriter& rewriter) const override { + void rewrite(Operation* op, PatternRewriter& rewriter) const { // Casts inputs of the operation. for (int i = 0; i < op->getNumOperands(); i++) { Value input = op->getOperand(i); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc index 33fbe5406040f7..ae3a25b32199e7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc @@ -62,28 +62,25 @@ class RemoveVariableInitializationByConstPass // pattern. `tf.VarHandleOp` and `tf.Const` are removed unless they are used by // other ops. struct RemoveVariableAssignmentByConst - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { // Inherit the constructors. 
- using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(TF::AssignVariableOp assign_op) const override { + LogicalResult matchAndRewrite(TF::AssignVariableOp assign_op, + PatternRewriter& rewriter) const override { Value resource_operand = assign_op.getOperand(0); Value assigned_value_operand = assign_op.getOperand(1); - if (isa(resource_operand.getDefiningOp()) && - isa(assigned_value_operand.getDefiningOp())) { - return success(); - } else { + if (!isa(resource_operand.getDefiningOp()) || + !isa(assigned_value_operand.getDefiningOp())) { return failure(); } - } - void rewrite(TF::AssignVariableOp assign_op, - PatternRewriter& rewriter) const override { // `TF::ConstOp` and `TF::VarHandleOp` are not manually erased. // `applyPatternsGreedily` performs dead code elimination and unsed // ops will be erased during the optimization. rewriter.eraseOp(assign_op); + return success(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc index 03e3072732b92f..8b5d2e0de1e26a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -67,21 +67,19 @@ class PrepareTpuComputationForTfExportPass }; class RewriteXlaHostComputeMlir - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(TF::_XlaHostComputeMlirOp op) const override { + LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op, + PatternRewriter& rewriter) const override { if (op.getManualSharding()) { // This rewrite does not support manual_sharding. It is expected that the // _XlaHostComputeMlirOp registered as an MlirXlaOpKernel will handle this // case later once the XlaBuilder graph reaches it. 
return failure(); } - return success(); - } - void rewrite(TF::_XlaHostComputeMlirOp op, - PatternRewriter& rewriter) const override { + llvm::SmallVector shape_attrs; shape_attrs.reserve(op.getNumResults()); for (Type ty : op.getResultTypes()) { @@ -141,6 +139,7 @@ class RewriteXlaHostComputeMlir op.getRecvKeyAttr(), /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate), /*tpu_core=*/rewriter.getI64IntegerAttr(0)); + return success(); } }; diff --git a/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc b/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc index 15b065cf0ec2f5..3235badda66bd0 100644 --- a/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc +++ b/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc @@ -43,16 +43,32 @@ namespace { using mlir::TF::DTensorLayout; class RemoveDTensorLayoutAfterConstOrBlockArgPattern - : public mlir::OpRewritePattern::SplitMatchAndRewrite { + : public mlir::OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult match(DTensorLayout layout_op) const override; - - void rewrite(DTensorLayout layout_op, - mlir::PatternRewriter& rewriter) const override { + mlir::LogicalResult matchAndRewrite( + DTensorLayout layout_op, mlir::PatternRewriter& rewriter) const override { + if (match(layout_op).failed()) { + return mlir::failure(); + } rewriter.replaceAllUsesWith(layout_op, layout_op.getInput()); rewriter.eraseOp(layout_op); + return mlir::success(); + } + + private: + mlir::LogicalResult match(DTensorLayout layout_op) const { + auto input = layout_op.getInput(); + if (mlir::isa(input)) { + return mlir::success(); + } + mlir::Operation* input_op = input.getDefiningOp(); + if (input_op != nullptr) { + return mlir::success(input_op->hasTrait()); + } else { + return layout_op->emitOpError() << "Can't find defining op for " << input; + } } }; @@ -63,20 +79,6 @@ class DTensorLayoutToXlaShardingOpPass void runOnOperation() override; }; -mlir::LogicalResult RemoveDTensorLayoutAfterConstOrBlockArgPattern::match( - DTensorLayout layout_op) const { - auto input = layout_op.getInput(); - if (mlir::isa(input)) { - return mlir::success(); - } - mlir::Operation* input_op = input.getDefiningOp(); - if (input_op != nullptr) { - return mlir::success(input_op->hasTrait()); - } else { - return layout_op->emitOpError() << "Can't find defining op for " << input; - } -} - void DTensorLayoutToXlaShardingOpPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); // Some patterns in tf2xla requires operands to be ConstantLike. 
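
Every hunk above moves a pattern off the removed OpRewritePattern<T>::SplitMatchAndRewrite base onto plain OpRewritePattern<T> with a single matchAndRewrite(). As a minimal sketch of the two shapes this migration produces (FooOp is a hypothetical single-result, single-operand op; the sketch is not taken from any patch here):

#include "mlir/IR/PatternMatch.h"

// Shape 1: fold the old match() body into matchAndRewrite(). Each failed
// check returns failure(); the old rewrite() body runs after all checks.
struct FusedPattern : public mlir::OpRewritePattern<FooOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      FooOp op, mlir::PatternRewriter& rewriter) const override {
    if (!op->hasOneUse()) return mlir::failure();  // former match() check
    rewriter.replaceOp(op, op->getOperand(0));     // former rewrite() body
    return mlir::success();
  }
};

// Shape 2: keep match()/rewrite() as private helpers and dispatch to them.
struct WrappedPattern : public mlir::OpRewritePattern<FooOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      FooOp op, mlir::PatternRewriter& rewriter) const override {
    if (match(op).failed()) return mlir::failure();
    rewrite(op, rewriter);
    return mlir::success();
  }

 private:
  mlir::LogicalResult match(FooOp op) const {
    return mlir::success(op->hasOneUse());
  }
  void rewrite(FooOp op, mlir::PatternRewriter& rewriter) const {
    rewriter.replaceOp(op, op->getOperand(0));
  }
};

Most hunks above take the first, fused shape; the second shape is used where the old two-phase structure is worth keeping, as in the AddDumpTensorOp and CastBf16OpsToF32 hunks.
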
diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_add_quant_dequant_conv.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_add_quant_dequant_conv.cpp
index 5b7dda9ff5075e..8fae3ddee4f585 100644
--- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_add_quant_dequant_conv.cpp
+++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_add_quant_dequant_conv.cpp
@@ -56,10 +56,11 @@ Type getQuantizedType(Location loc, PatternRewriter& rewriter,
 }
 
 struct AddQuantDeQuantAfterConvolutionOp final
-    : OpRewritePattern::SplitMatchAndRewrite {
-  using SplitMatchAndRewrite::SplitMatchAndRewrite;
+    : OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult match(stablehlo::ConvolutionOp op) const override {
+  LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op,
+                                PatternRewriter& rewriter) const override {
     // Match a stablehlo.convolution op if
     // 1. Its operands are defined by stablehlo.uniform_dequantize op,
     // 2. It has a single user.
@@ -77,11 +78,6 @@ struct AddQuantDeQuantAfterConvolutionOp final
 
     if (isa(*op->getUsers().begin())) return failure();
 
-    return success();
-  }
-
-  void rewrite(stablehlo::ConvolutionOp op,
-               PatternRewriter& rewriter) const override {
     auto* clonedConvOp = rewriter.clone(*op);
     auto convResultType =
         cast(clonedConvOp->getResult(0).getType());
@@ -94,6 +90,7 @@ struct AddQuantDeQuantAfterConvolutionOp final
         op.getLoc(), op.getType(),
         /*input=*/stablehloQuantizeOp.getResult());
     rewriter.replaceAllUsesWith(op, stablehloDeQuantizeOp.getResult());
+    return success();
   }
 };

From b9da8973215ad68ecee0ac1fa0d63b05c7391ba5 Mon Sep 17 00:00:00 2001
From: Emilio Cota
Date: Fri, 4 Apr 2025 16:19:41 -0700
Subject: [PATCH 0259/1324] Update xprof repo from tensorflow/profiler to
 openxla/xprof

The xprof repo now points to openxla/xprof:
https://github.com/openxla/xprof/commit/7bee47367747c

PiperOrigin-RevId: 744101924
---
 tensorflow/workspace2.bzl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl
index f43770537af437..41d25e5099886a 100644
--- a/tensorflow/workspace2.bzl
+++ b/tensorflow/workspace2.bzl
@@ -939,7 +939,7 @@ def _tf_repositories():
         name = "org_xprof",
         sha256 = "dec4889a6a5123fca0a775ba20f22717b2d0c3af1491f41bb52e1b502595271e",
         strip_prefix = "xprof-c3dbeb2c69b48163c6156d6f4a8c82ac34736f49",
-        urls = tf_mirror_urls("https://github.com/tensorflow/profiler/archive/c3dbeb2c69b48163c6156d6f4a8c82ac34736f49.zip"),
+        urls = tf_mirror_urls("https://github.com/openxla/xprof/archive/c3dbeb2c69b48163c6156d6f4a8c82ac34736f49.zip"),
     )
 
     # used for adding androidx.annotation dependencies in tflite android jni.

From 69b47b579bccc95d554ecbf2ac7824be4abec35d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 4 Apr 2025 17:01:53 -0700
Subject: [PATCH 0260/1324] Hlo-Diff is a semantic diff tool for HLO modules.

It compares the graph structure of two HLO modules, focusing on the
computational differences and ignoring irrelevant changes such as
instruction names, parameter ordering, and (in some instances) layouts.

The tool supports:
1. Diffing of large HLO modules (>100k nodes) in <1 minute.
2. Summarized output of diffs highlighting what has changed for updated
   HLO instructions.

Note that only XLA's HLO format is supported at the moment.
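
The instruction matches are stored bidirectionally; the hlo_gumgraph_mappings target added below depends on @boost//:bimap for this. A rough sketch of the lookup pattern a bimap enables, with plain strings standing in for the tool's actual node types (an assumption, since the mapping header itself is not quoted here):

// Sketch only: string keys stand in for the tool's real node types.
#include <cassert>
#include <string>

#include "boost/bimap.hpp"

int main() {
  using Mapping = boost::bimap<std::string, std::string>;
  Mapping matches;
  // One entry per matched pair: left = node in the first module,
  // right = node in the second module.
  matches.insert(Mapping::value_type("add.1", "add.42"));

  // The same structure answers lookups from either side.
  auto l = matches.left.find("add.1");    // first -> second
  auto r = matches.right.find("add.42");  // second -> first
  assert(l != matches.left.end() && l->second == "add.42");
  assert(r != matches.right.end() && r->second == "add.1");
  return 0;
}
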
PiperOrigin-RevId: 744111765 --- third_party/xla/tsl_workspace1.bzl | 3 + third_party/xla/tsl_workspace2.bzl | 7 + third_party/xla/workspace1.bzl | 3 + third_party/xla/workspace2.bzl | 7 + third_party/xla/xla/hlo/tools/hlo_diff/BUILD | 198 +++ .../xla/xla/hlo/tools/hlo_diff/graph/BUILD | 68 + .../hlo/tools/hlo_diff/graph/analysis/BUILD | 41 + .../graph/analysis/hlo_value_tracing.cc | 1234 +++++++++++++++++ .../graph/analysis/hlo_value_tracing.h | 194 +++ .../hlo/tools/hlo_diff/graph/hlo_gumgraph.cc | 331 +++++ .../hlo/tools/hlo_diff/graph/hlo_gumgraph.h | 161 +++ .../tools/hlo_diff/graph/hlo_gumgraph_node.h | 73 + .../tools/hlo_diff/graph/hlo_gumgraph_test.cc | 523 +++++++ .../xla/hlo/tools/hlo_diff/graph/utils/BUILD | 62 + .../hlo_diff/graph/utils/hlo_gumgraph_bfs.cc | 108 ++ .../hlo_diff/graph/utils/hlo_gumgraph_bfs.h | 74 + .../graph/utils/hlo_gumgraph_bfs_test.cc | 246 ++++ .../hlo_diff/graph/utils/hlo_gumgraph_dfs.cc | 81 ++ .../hlo_diff/graph/utils/hlo_gumgraph_dfs.h | 59 + .../graph/utils/hlo_gumgraph_dfs_test.cc | 233 ++++ .../xla/hlo/tools/hlo_diff/hlo_diff_eval.cc | 137 ++ .../xla/hlo/tools/hlo_diff/hlo_diff_eval.h | 70 + .../hlo/tools/hlo_diff/hlo_diff_eval_test.cc | 206 +++ .../xla/hlo/tools/hlo_diff/hlo_diff_main.cc | 259 ++++ .../xla/hlo/tools/hlo_diff/hlo_diff_result.cc | 128 ++ .../xla/hlo/tools/hlo_diff/hlo_diff_result.h | 65 + .../tools/hlo_diff/hlo_diff_result_test.cc | 301 ++++ .../hlo/tools/hlo_diff/hlo_diff_summary.cc | 308 ++++ .../xla/hlo/tools/hlo_diff/hlo_diff_summary.h | 97 ++ .../tools/hlo_diff/hlo_diff_summary_test.cc | 369 +++++ .../hlo/tools/hlo_diff/hlo_gumgraph_diff.cc | 117 ++ .../hlo/tools/hlo_diff/hlo_gumgraph_diff.h | 51 + .../tools/hlo_diff/hlo_gumgraph_diff_test.cc | 85 ++ .../tools/hlo_diff/hlo_gumgraph_mappings.h | 121 ++ .../xla/xla/hlo/tools/hlo_diff/matchers/BUILD | 115 ++ .../matchers/hlo_call_graph_matcher.cc | 389 ++++++ .../matchers/hlo_call_graph_matcher.h | 38 + .../matchers/hlo_call_graph_matcher_test.cc | 400 ++++++ .../matchers/hlo_computation_graph_matcher.cc | 401 ++++++ .../matchers/hlo_computation_graph_matcher.h | 35 + .../hlo_computation_graph_matcher_test.cc | 250 ++++ .../hlo_diff/matchers/hlo_gumgraph_matcher.cc | 511 +++++++ .../hlo_diff/matchers/hlo_gumgraph_matcher.h | 129 ++ .../matchers/hlo_gumgraph_matcher_test.cc | 551 ++++++++ .../xla/xla/hlo/tools/hlo_diff/render/BUILD | 76 + .../render/hlo_gumgraph_html_renderer.cc | 494 +++++++ .../render/hlo_gumgraph_html_renderer.h | 54 + .../render/hlo_gumgraph_renderer_util.cc | 140 ++ .../render/hlo_gumgraph_renderer_util.h | 78 ++ .../render/hlo_gumgraph_renderer_util_test.cc | 92 ++ .../render/hlo_gumgraph_text_renderer.cc | 271 ++++ .../render/hlo_gumgraph_text_renderer.h | 55 + .../xla/xla/hlo/tools/hlo_diff/utils/BUILD | 66 + .../hlo_diff/utils/connected_components.cc | 71 + .../hlo_diff/utils/connected_components.h | 52 + .../utils/connected_components_test.cc | 152 ++ .../hlo/tools/hlo_diff/utils/hlo_diff_util.h | 42 + .../xla/hlo/tools/hlo_diff/utils/test_util.cc | 138 ++ .../xla/hlo/tools/hlo_diff/utils/test_util.h | 68 + 59 files changed, 10688 insertions(+) create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/BUILD create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/BUILD create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/BUILD create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h create mode 100644 
third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/BUILD create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/BUILD create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util_test.cc create mode 100644 
third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/utils/BUILD create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components_test.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/utils/hlo_diff_util.h create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/utils/test_util.cc create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/utils/test_util.h diff --git a/third_party/xla/tsl_workspace1.bzl b/third_party/xla/tsl_workspace1.bzl index 2495080d804c42..aead12298027fb 100644 --- a/third_party/xla/tsl_workspace1.bzl +++ b/third_party/xla/tsl_workspace1.bzl @@ -2,6 +2,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") @@ -17,6 +18,8 @@ def workspace(with_rules_cc = True): closure_repositories() + boost_deps() + http_archive( name = "bazel_toolchains", sha256 = "294cdd859e57fcaf101d4301978c408c88683fbc46fbc1a3829da92afbea55fb", diff --git a/third_party/xla/tsl_workspace2.bzl b/third_party/xla/tsl_workspace2.bzl index 641efe4756eb0f..deaf8b0ebaef0c 100644 --- a/third_party/xla/tsl_workspace2.bzl +++ b/third_party/xla/tsl_workspace2.bzl @@ -608,6 +608,13 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/KhronosGroup/SPIRV-LLVM-Translator/archive/dad1f0eaab8047a4f73c50ed5f3d1694b78aae97.tar.gz"), ) + tf_http_archive( + name = "com_github_nelhage_rules_boost", + urls = tf_mirror_urls("https://github.com/nelhage/rules_boost/archive/5160325dbdc8c9e499f9d9917d913f35f1785d52.zip"), + strip_prefix = "rules_boost-5160325dbdc8c9e499f9d9917d913f35f1785d52", + sha256 = "feb4b1294684c79df7c1e08f1aec5da0da52021e33db59c88edbe86b4d1a017a", + ) + # buildifier: disable=unnamed-macro def workspace(): # Check the bazel version before executing any repository rules, in case diff --git a/third_party/xla/workspace1.bzl b/third_party/xla/workspace1.bzl index f04f6305d9b89c..b4c446195874b5 100644 --- a/third_party/xla/workspace1.bzl +++ b/third_party/xla/workspace1.bzl @@ -2,6 +2,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") load("//:tsl_workspace1.bzl", "tsl_workspace1") @@ -16,6 +17,8 @@ def workspace(): closure_repositories() + boost_deps() + http_archive( name = "bazel_toolchains", sha256 = "294cdd859e57fcaf101d4301978c408c88683fbc46fbc1a3829da92afbea55fb", diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index 7fde4ef61d29c8..a76093e926f022 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -158,6 +158,13 @@ def _tf_repositories(): strip_prefix = "pybind11_protobuf-80f3440cd8fee124e077e2e47a8a17b78b451363", ) + tf_http_archive( + name = "com_github_nelhage_rules_boost", + urls = 
tf_mirror_urls("https://github.com/nelhage/rules_boost/archive/5160325dbdc8c9e499f9d9917d913f35f1785d52.zip"), + strip_prefix = "rules_boost-5160325dbdc8c9e499f9d9917d913f35f1785d52", + sha256 = "feb4b1294684c79df7c1e08f1aec5da0da52021e33db59c88edbe86b4d1a017a", + ) + # buildifier: disable=function-docstring # buildifier: disable=unnamed-macro def workspace(): diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/BUILD new file mode 100644 index 00000000000000..9feb12fbcadbd7 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/BUILD @@ -0,0 +1,198 @@ +load("//xla:xla.default.bzl", "xla_cc_binary", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "hlo_gumgraph_mappings", + hdrs = ["hlo_gumgraph_mappings.h"], + deps = [ + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node", + "//xla/service:call_graph", + "@boost//:bimap", + ], +) + +cc_library( + name = "hlo_diff_result", + srcs = ["hlo_diff_result.cc"], + hdrs = ["hlo_diff_result.h"], + deps = [ + ":hlo_gumgraph_mappings", + "//xla/hlo/ir:hlo", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node", + "//xla/hlo/tools/hlo_diff/graph/utils:hlo_gumgraph_bfs", + "//xla/hlo/tools/hlo_diff/utils:hlo_diff_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + ], +) + +xla_cc_test( + name = "hlo_diff_result_test", + srcs = ["hlo_diff_result_test.cc"], + deps = [ + ":hlo_diff_result", + ":hlo_gumgraph_mappings", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/utils:test_util", + "//xla/service:hlo_module_config", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "hlo_diff_summary", + srcs = ["hlo_diff_summary.cc"], + hdrs = ["hlo_diff_summary.h"], + deps = [ + ":hlo_diff_result", + ":hlo_gumgraph_mappings", + "//xla/hlo/ir:hlo", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node", + "//xla/hlo/tools/hlo_diff/utils:connected_components", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:fingerprint", + ], +) + +xla_cc_test( + name = "hlo_diff_summary_test", + srcs = ["hlo_diff_summary_test.cc"], + deps = [ + ":hlo_diff_result", + ":hlo_diff_summary", + ":hlo_gumgraph_mappings", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/utils:test_util", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "hlo_gumgraph_diff", + srcs = ["hlo_gumgraph_diff.cc"], + hdrs = 
["hlo_gumgraph_diff.h"], + deps = [ + ":hlo_diff_eval", + ":hlo_diff_result", + ":hlo_diff_summary", + ":hlo_gumgraph_mappings", + "//xla/hlo/ir:hlo", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node", + "//xla/hlo/tools/hlo_diff/matchers:hlo_call_graph_matcher", + "//xla/hlo/tools/hlo_diff/matchers:hlo_computation_graph_matcher", + "//xla/hlo/tools/hlo_diff/matchers:hlo_gumgraph_matcher", + "//xla/service:call_graph", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +xla_cc_test( + name = "hlo_gumgraph_diff_test", + srcs = ["hlo_gumgraph_diff_test.cc"], + deps = [ + ":hlo_gumgraph_diff", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "hlo_diff_eval", + srcs = ["hlo_diff_eval.cc"], + hdrs = ["hlo_diff_eval.h"], + deps = [ + ":hlo_diff_result", + ":hlo_diff_summary", + ":hlo_gumgraph_mappings", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + ], +) + +xla_cc_test( + name = "hlo_diff_eval_test", + srcs = ["hlo_diff_eval_test.cc"], + deps = [ + ":hlo_diff_eval", + ":hlo_diff_result", + ":hlo_diff_summary", + ":hlo_gumgraph_mappings", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/utils:test_util", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_googletest//:gtest_main", + ], +) + +xla_cc_binary( + name = "hlo_diff", + srcs = ["hlo_diff_main.cc"], + deps = [ + ":hlo_diff_eval", + ":hlo_diff_result", + ":hlo_diff_summary", + ":hlo_gumgraph_diff", + "//xla:debug_options_flags", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/tools/hlo_diff/render:hlo_gumgraph_html_renderer", + "//xla/hlo/tools/hlo_diff/render:hlo_gumgraph_text_renderer", + "//xla/service:hlo_module_config", + "//xla/service:hlo_module_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:platform_port", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/graph/BUILD new file mode 100644 index 00000000000000..ee08b9f3e40ac7 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/BUILD @@ -0,0 +1,68 @@ +load("//xla:xla.default.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "hlo_gumgraph_node", + hdrs = ["hlo_gumgraph_node.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:call_graph", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "hlo_gumgraph", + srcs = ["hlo_gumgraph.cc"], + 
hdrs = ["hlo_gumgraph.h"], + deps = [ + ":hlo_gumgraph_node", + "//xla/hlo/ir:hlo", + "//xla/hlo/tools/hlo_diff/graph/analysis:hlo_value_tracing", + "//xla/hlo/tools/hlo_diff/graph/utils:hlo_gumgraph_dfs", + "//xla/hlo/tools/hlo_diff/utils:hlo_diff_util", + "//xla/service:call_graph", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:fingerprint", + ], +) + +xla_cc_test( + name = "hlo_gumgraph_test", + srcs = ["hlo_gumgraph_test.cc"], + deps = [ + ":hlo_gumgraph", + ":hlo_gumgraph_node", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service:hlo_module_config", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/BUILD new file mode 100644 index 00000000000000..233bc847801026 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/BUILD @@ -0,0 +1,41 @@ +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//xla/hlo/tools/hlo_diff:__subpackages__", + ], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "hlo_value_tracing", + srcs = ["hlo_value_tracing.cc"], + hdrs = ["hlo_value_tracing.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:call_graph", + "//xla/service:hlo_value", + "//xla/tsl/platform:errors", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc new file mode 100644 index 00000000000000..22db01d6f42657 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc @@ -0,0 +1,1234 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/call_graph.h"
+#include "xla/service/hlo_value.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace {
+// CalculatePostOrderSchedule traverses a module and assigns an ordinal to
+// each instruction based on the postorder dependency.
+int64_t CalculatePostOrderScheduleHelper(
+    const HloComputation* comp, int64_t start_ordinal,
+    absl::flat_hash_map* ordinal_map) {
+  int64_t ordinal = start_ordinal;
+  for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
+    if (instruction->opcode() == HloOpcode::kCall ||
+        instruction->opcode() == HloOpcode::kAsyncStart ||
+        instruction->opcode() == HloOpcode::kConditional) {
+      for (const HloComputation* called_computation :
+           instruction->called_computations()) {
+        ordinal = CalculatePostOrderScheduleHelper(called_computation, ordinal,
+                                                   ordinal_map);
+      }
+    }
+    if (instruction->opcode() == HloOpcode::kWhile) {
+      ordinal = CalculatePostOrderScheduleHelper(instruction->while_condition(),
+                                                 ordinal, ordinal_map);
+      ordinal = CalculatePostOrderScheduleHelper(instruction->while_body(),
+                                                 ordinal, ordinal_map);
+    }
+    // It's possible that in some unit tests the computation graph is not
+    // flattened (meaning we could have multiple callers for one computation).
+    // In that case the ordinal_map will see the instruction multiple times. We
+    // consider that case to be ok as it only shows up in unit tests.
+    ordinal_map->insert({instruction, ordinal++});
+  }
+  return ordinal;
+}
+
+absl::flat_hash_map CalculatePostOrderSchedule(
+    const HloModule& module) {
+  absl::flat_hash_map map;
+  CalculatePostOrderScheduleHelper(module.entry_computation(), 0, &map);
+  return map;
+}
+
+}  // namespace
+
+bool HloValueTracing::ValueIsDefinedAt(const HloInstruction* instruction,
+                                       const ShapeIndex& index) const {
+  const HloValueSet& value_set = GetValueSet(instruction, index);
+  if (value_set.values().size() != 1) {
+    return false;
+  }
+  return value_set.GetUniqueValue().defining_instruction() == instruction;
+}
+
+HloValueTracing::HloValueTracing(
+    const HloModule& module,
+    absl::flat_hash_set execution_threads)
+    : module_(module),
+      execution_threads_(std::move(execution_threads)),
+      call_graph_(CallGraph::Build(&module)) {}
+
+HloValue* HloValueTracing::NewHloValue(HloInstruction* instruction,
+                                       const ShapeIndex& index, bool is_phi) {
+  const int64_t value_id = next_value_id_++;
+  auto result =
+      values_.insert({value_id, std::make_unique(
+                                    value_id, instruction, index, is_phi)});
+  CHECK(result.second);
+
+  return result.first->second.get();
+}
+
+void HloValueTracing::DeleteMarkedValues() {
+  // Use a set to prevent deleting an id twice.
+  absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
+                                           value_ids_to_delete_.end());
+
+  for (HloValue::Id value_id : id_set) {
+    values_.erase(value_id);
+  }
+  value_ids_to_delete_.clear();
+}
+
+const HloValue& HloValueTracing::GetValue(HloValue::Id value_id) const {
+  DCHECK(values_.contains(value_id)) << "Value not found: " << value_id;
+  return *values_.find(value_id)->second;
+}
+
+HloValue& HloValueTracing::GetValue(HloValue::Id value_id) {
+  DCHECK(values_.contains(value_id)) << "Value not found: " << value_id;
+  return *values_.find(value_id)->second;
+}
+
+HloValueSet HloValueTracing::GetFlattenedValueSet(
+    const HloInstruction* instruction) const {
+  HloValueSet value_set;
+
+  const InstructionValueSet& value_set_tree =
+      GetInstructionValueSet(instruction);
+
+  std::vector<const HloValueSet*> all_sets;
+  for (auto& pair : value_set_tree) {
+    all_sets.push_back(&pair.second);
+  }
+  value_set.AssignUnionOf(all_sets);
+
+  return value_set;
+}
+
+const HloValueSet& HloValueTracing::GetValueSet(
+    const HloInstruction* instruction, const ShapeIndex& index) const {
+  return GetInstructionValueSet(instruction).element(index);
+}
+
+HloValueSet& HloValueTracing::GetValueSet(const HloInstruction* instruction,
+                                          const ShapeIndex& index) {
+  return *GetInstructionValueSet(instruction).mutable_element(index);
+}
+
+const HloValueSet& HloValueTracing::GetValueSet(
+    const HloPosition& position) const {
+  return GetValueSet(position.instruction, position.index);
+}
+
+HloValueSet& HloValueTracing::GetValueSet(const HloPosition& position) {
+  return GetValueSet(position.instruction, position.index);
+}
+
+bool HloValueTracing::UpdateSendValueSet(HloInstruction* send) {
+  CHECK_EQ(send->opcode(), HloOpcode::kSend);
+  bool changed = false;
+  // Send forwards the operand value to the output tuple at {0}.
+  for (auto& pair : GetInstructionValueSet(send->operand(0))) {
+    const ShapeIndex& operand_index = pair.first;
+    const HloValueSet& operand_value_set = pair.second;
+
+    ShapeIndex index = {0};
+    for (int64_t i : operand_index) {
+      index.push_back(i);
+    }
+
+    HloValueSet& value_set = GetValueSet(send, index);
+    if (value_set != operand_value_set) {
+      value_set = operand_value_set;
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+bool HloValueTracing::UpdateAsyncStartValueSet(HloInstruction* async_start) {
+  CHECK_EQ(async_start->opcode(), HloOpcode::kAsyncStart);
+  bool changed = false;
+  // AsyncStart forwards the operand values to element {0} of its output.
+  for (int64_t i = 0; i < async_start->operand_count(); ++i) {
+    const HloInstruction* operand = async_start->operand(i);
+    ShapeUtil::ForEachSubshape(
+        operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+          if (!subshape.IsArray()) {
+            return;
+          }
+          const HloValueSet& operand_value_set = GetValueSet(operand, index);
+
+          ShapeIndex output_index = {0, i};
+          output_index.insert(output_index.end(), index.begin(), index.end());
+
+          HloValueSet& value_set = GetValueSet(async_start, output_index);
+          if (value_set != operand_value_set) {
+            value_set = operand_value_set;
+            changed = true;
+          }
+        });
+  }
+  if (!HloInstruction::IsThreadIncluded(async_start->async_execution_thread(),
+                                        execution_threads_)) {
+    return changed;
+  }
+  // AsyncStart forwards the async wrapped computation root values to element
+  // {1} of its output.
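+  // (When the thread is included, async-start defines no values at {1}, so
+  // the wrapped root's values flow through the output unchanged.)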
+ HloInstruction* root = + async_start->async_wrapped_computation()->root_instruction(); + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!subshape.IsArray()) { + return; + } + const HloValueSet& root_value_set = GetValueSet(root, index); + + ShapeIndex output_index = {1}; + output_index.insert(output_index.end(), index.begin(), index.end()); + + HloValueSet& value_set = GetValueSet(async_start, output_index); + if (value_set != root_value_set) { + value_set = root_value_set; + changed = true; + } + }); + return changed; +} + +bool HloValueTracing::UpdateAsyncUpdateValueSet(HloInstruction* async_update) { + CHECK_EQ(async_update->opcode(), HloOpcode::kAsyncUpdate); + CHECK_EQ(async_update->shape(), async_update->operand(0)->shape()); + bool changed = false; + HloInstruction* root = + HloInstruction::IsThreadIncluded(async_update->async_execution_thread(), + execution_threads_) + ? async_update->async_wrapped_computation()->root_instruction() + : nullptr; + // AsyncUpdate forwards all of the operand values to corresponding elements of + // its output. + ShapeUtil::ForEachSubshape( + async_update->operand(0)->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (!subshape.IsArray()) { + return; + } + const HloValueSet& operand_value_set = + GetValueSet(async_update->operand(0), index); + + HloValueSet& value_set = GetValueSet(async_update, index); + CHECK_GE(index.size(), 0); + if (index[0] != 1) { + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } else if (root != nullptr) { + // If this subshape is an output (index {1}), we need to create the + // union with the async wrapped computation root. + ShapeIndex root_index(index.begin() + 1, index.end()); + const HloValueSet& root_value_set = GetValueSet(root, root_index); + changed |= + value_set.AssignUnionOf({&operand_value_set, &root_value_set}); + } else if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + }); + return changed; +} + +bool HloValueTracing::UpdateAsyncDoneValueSet(HloInstruction* async_done) { + CHECK_EQ(async_done->opcode(), HloOpcode::kAsyncDone); + bool changed = false; + HloInstruction* root = + HloInstruction::IsThreadIncluded(async_done->async_execution_thread(), + execution_threads_) + ? async_done->async_wrapped_computation()->root_instruction() + : nullptr; + // AsyncDone creates a union of the operand values at {1} and the async + // wrapped computation root to element {} of its output. + ShapeUtil::ForEachSubshape( + async_done->operand(0)->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (!subshape.IsArray() || index.front() != 1) { + return; + } + const HloValueSet& operand_value_set = + GetValueSet(async_done->operand(0), index); + + ShapeIndex output_index(index.begin() + 1, index.end()); + HloValueSet& value_set = GetValueSet(async_done, output_index); + if (root != nullptr) { + const HloValueSet& root_value_set = GetValueSet(root, output_index); + changed |= + value_set.AssignUnionOf({&operand_value_set, &root_value_set}); + } else if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + }); + return changed; +} + +bool HloValueTracing::UpdateCopyStartValueSet(HloInstruction* copy_start) { + CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart); + bool changed = false; + // CopyStart forwards the operand value to elements {0, 1} of its output. 
+  const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
+  HloValueSet& first_value_set = GetValueSet(copy_start, {0});
+  if (first_value_set != operand_value_set) {
+    first_value_set = operand_value_set;
+    changed = true;
+  }
+
+  HloValueSet& second_value_set = GetValueSet(copy_start, {1});
+  if (second_value_set != operand_value_set) {
+    second_value_set = operand_value_set;
+    changed = true;
+  }
+
+  return changed;
+}
+
+bool HloValueTracing::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
+  CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
+  bool changed = false;
+  // CopyDone forwards the operand value at {0} to element {} of its output.
+  const HloValueSet& operand_value_set =
+      GetValueSet(copy_done->operand(0), {0});
+  HloValueSet& value_set = GetValueSet(copy_done);
+  if (value_set != operand_value_set) {
+    value_set = operand_value_set;
+    changed = true;
+  }
+  return changed;
+}
+
+bool HloValueTracing::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
+  CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
+  bool changed = false;
+  // RecvDone forwards the operand value at {0} to element {0} of its output.
+  for (auto& pair : GetInstructionValueSet(recv_done)) {
+    ShapeIndex& index = pair.first;
+    HloValueSet& value_set = pair.second;
+
+    if (index.empty() || index[0] != 0) {
+      continue;
+    }
+
+    const HloValueSet& operand_value_set =
+        GetValueSet(recv_done->operand(0), index);
+    if (value_set != operand_value_set) {
+      value_set = operand_value_set;
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+bool HloValueTracing::UpdateCallValueSet(HloInstruction* call) {
+  CHECK_EQ(call->opcode(), HloOpcode::kCall);
+  InstructionValueSet& value_set = GetInstructionValueSet(call);
+  InstructionValueSet& root_value_set =
+      GetInstructionValueSet(call->to_apply()->root_instruction());
+  if (value_set != root_value_set) {
+    value_set = root_value_set;
+    return true;
+  }
+  return false;
+}
+
+bool HloValueTracing::UpdateConditionalValueSet(HloInstruction* conditional) {
+  CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
+  std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
+  for (int j = 0; j < conditional->branch_count(); ++j) {
+    inputs[j] = &GetInstructionValueSet(
+        conditional->branch_computation(j)->root_instruction());
+  }
+  return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
+}
+
+bool HloValueTracing::UpdateCopyValueSet(HloInstruction* copy) {
+  CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
+  bool changed = false;
+  for (auto& pair : GetInstructionValueSet(copy)) {
+    const ShapeIndex& index = pair.first;
+
+    HloValueSet& value_set = pair.second;
+    HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
+    if (value_set != operand_value_set) {
+      value_set = operand_value_set;
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+bool HloValueTracing::UpdateOptimizationBarrierValueSet(
+    HloInstruction* barrier) {
+  // Optimization barriers just forward their operand. Given that barriers can
+  // have a tuple operand, we iterate through its indexes, like for copies.
+  // Unlike copies though we also propagate the top-level value.
+  CHECK_EQ(barrier->opcode(), HloOpcode::kOptimizationBarrier);
+  bool changed = false;
+  for (auto& pair : GetInstructionValueSet(barrier)) {
+    const ShapeIndex& index = pair.first;
+    HloValueSet& value_set = pair.second;
+    HloValueSet& operand_value_set = GetValueSet(barrier->operand(0), index);
+    if (value_set != operand_value_set) {
+      value_set = operand_value_set;
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+bool HloValueTracing::UpdateDomainValueSet(HloInstruction* domain) {
+  // Domain instructions just forward their operand. Given that domains can
+  // have a tuple operand, we iterate through its indexes, like for copies.
+  // Unlike copies though we also propagate the top-level value.
+  CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
+  bool changed = false;
+  for (auto& pair : GetInstructionValueSet(domain)) {
+    const ShapeIndex& index = pair.first;
+    HloValueSet& value_set = pair.second;
+    HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
+    if (value_set != operand_value_set) {
+      value_set = operand_value_set;
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+bool HloValueTracing::UpdateAddDependencyValueSet(
+    HloInstruction* add_dependency) {
+  // AddDependency just forwards the value of its zero-th operand.
+  CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
+  const InstructionValueSet& operand_set =
+      GetInstructionValueSet(add_dependency->operand(0));
+  InstructionValueSet& add_dependency_set =
+      GetInstructionValueSet(add_dependency);
+  if (operand_set != add_dependency_set) {
+    add_dependency_set = operand_set;
+    return true;
+  }
+  return false;
+}
+
+bool HloValueTracing::UpdateGetTupleElementValueSet(HloInstruction* gte) {
+  CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
+  bool changed = false;
+  // The GetTupleElement instruction forwards the values from the specified
+  // tuple element.
+  for (auto& pair : GetInstructionValueSet(gte)) {
+    const ShapeIndex& index = pair.first;
+    HloValueSet& value_set = pair.second;
+
+    // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
+    // with the tuple element number prefixed.
+    ShapeIndex operand_index = {gte->tuple_index()};
+    for (int64_t i : index) {
+      operand_index.push_back(i);
+    }
+
+    HloValueSet& operand_value_set =
+        GetValueSet(gte->operand(0), operand_index);
+    if (value_set != operand_value_set) {
+      value_set = operand_value_set;
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+bool HloValueTracing::UpdateParameterValueSet(HloInstruction* parameter) {
+  CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
+  const CallGraphNode& call_graph_node =
+      call_graph_->GetNode(parameter->parent());
+
+  // Subcomputations called in a parallel context (eg, map) do not have
+  // dataflow from the caller operands.
+  if (call_graph_node.caller_callsites().empty()) {
+    return false;
+  }
+
+  std::vector<const InstructionValueSet*> inputs;
+  for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+    const HloOpcode opcode = callsite.instruction()->opcode();
+    if (opcode == HloOpcode::kCall || opcode == HloOpcode::kFusion) {
+      // The operand values of a call instruction are forwarded to the
+      // respective parameter instruction of the subcomputation.
+      inputs.push_back(&GetInstructionValueSet(
+          callsite.instruction()->operand(parameter->parameter_number())));
+    } else if (opcode == HloOpcode::kWhile) {
+      // In a while instruction, the while operand (ie, the init value) and the
+      // backedge are dataflow inputs to the parameter instruction. This is the
+      // case for parameters of both the body and condition computations.
+      CHECK_EQ(parameter->parameter_number(), 0);
+      inputs.push_back(
+          &GetInstructionValueSet(callsite.instruction()->operand(0)));
+      // If the parameter *is not* the root, its state will be updated by the
+      // root; otherwise don't consider its current state
+      // (InstructionValueSet) as we are recomputing it.
+      if (parameter !=
+          callsite.instruction()->while_body()->root_instruction()) {
+        inputs.push_back(&GetInstructionValueSet(
+            callsite.instruction()->while_body()->root_instruction()));
+      }
+    } else if (opcode == HloOpcode::kConditional) {
+      CHECK_EQ(parameter->parameter_number(), 0);
+      auto* conditional = callsite.instruction();
+      // Conditional has branch_count+1 operands. Operand 0 is the
+      // branch_index, operands 1 and onward are the arguments to the branch
+      // computations.
+      //
+      // If the parameter belongs to conditional's branch 0 computation, then
+      // operand 1 is forwarded to this parameter instruction. If the parameter
+      // belongs to conditional's branch 5 computation, then operand 6 is
+      // forwarded to this parameter instruction.
+      bool found_parent = false;
+      for (int j = 0; j < conditional->branch_count(); ++j) {
+        if (parameter->parent() == conditional->branch_computation(j)) {
+          inputs.push_back(
+              &GetInstructionValueSet(conditional->operand(j + 1)));
+          found_parent = true;
+          break;
+        }
+      }
+      CHECK(found_parent);
+    } else if (opcode == HloOpcode::kAsyncStart) {
+      inputs.push_back(&GetInstructionValueSet(
+          callsite.instruction()->operand(parameter->parameter_number())));
+    } else if (opcode == HloOpcode::kAsyncUpdate ||
+               opcode == HloOpcode::kAsyncDone) {
+      return GetInstructionValueSet(parameter).AssignUnionOf(
+          GetInstructionValueSet(callsite.instruction()->operand(0)),
+          {0, parameter->parameter_number()});
+    } else {
+      return false;
+    }
+  }
+
+  return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
+}
+
+bool HloValueTracing::UpdateTupleValueSet(HloInstruction* tuple) {
+  CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
+  bool changed = false;
+  for (int64_t i = 0; i < tuple->operands().size(); ++i) {
+    // Copy the value set(s) of each operand into the respective position in
+    // the kTuple instruction's value sets.
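+    // e.g. the value at shape index {0} of operand 1 lands at shape index
+    // {1, 0} of the tuple's value set.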
+ for (auto& pair : GetInstructionValueSet(tuple->operand(i))) { + const ShapeIndex& operand_index = pair.first; + HloValueSet& operand_value_set = pair.second; + + ShapeIndex index = {i}; + for (int64_t op_index : operand_index) { + index.push_back(op_index); + } + HloValueSet& value_set = GetValueSet(tuple, index); + + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + } + return changed; +} + +bool HloValueTracing::UpdateWhileValueSet(HloInstruction* xla_while) { + CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile); + const InstructionValueSet* const inputs[] = { + &GetInstructionValueSet(xla_while->while_body()->root_instruction()), + &GetInstructionValueSet(xla_while->operand(0))}; + return GetInstructionValueSet(xla_while).AssignUnionOf(inputs); +} + +bool HloValueTracing::UpdateFusionValueSet(HloInstruction* fusion) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + + InstructionValueSet& value_set = GetInstructionValueSet(fusion); + InstructionValueSet& root_value_set = GetInstructionValueSet( + fusion->called_computations().front()->root_instruction()); + if (value_set != root_value_set) { + value_set = root_value_set; + return true; + } + + return false; +} + +bool HloValueTracing::UpdateAllGatherStartValueSet( + HloInstruction* all_gather_start) { + CHECK_EQ(all_gather_start->opcode(), HloOpcode::kAllGatherStart); + bool changed = false; + // AllGatherStart forwards the operand values to element {0} of its output. + for (int64_t i = 0; i < all_gather_start->operand_count(); ++i) { + const HloValueSet& operand_value_set = + GetValueSet(all_gather_start->operand(i)); + + ShapeIndex output_index = {0}; + if (all_gather_start->operand_count() > 1) { + output_index.push_back(i); + } + + HloValueSet& value_set = GetValueSet(all_gather_start, output_index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + +bool HloValueTracing::UpdateAllGatherDoneValueSet( + HloInstruction* all_gather_done) { + CHECK_EQ(all_gather_done->opcode(), HloOpcode::kAllGatherDone); + bool changed = false; + // AllGatherDone forwards the operand value at {1} to its output. If the + // output is a tuple, then that tuple is defined by all-gather-done, so + // only update the value set for tuple leaf elements (arrays). + for (auto& pair : GetInstructionValueSet(all_gather_done)) { + const ShapeIndex& output_index = pair.first; + HloValueSet& value_set = pair.second; + + if (!ShapeUtil::GetSubshape(all_gather_done->shape(), output_index) + .IsArray()) { + continue; + } + ShapeIndex operand_index = {1}; + for (int64_t i : output_index) { + operand_index.push_back(i); + } + + const HloValueSet& operand_value_set = + GetValueSet(all_gather_done->operand(0), operand_index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + +bool HloValueTracing::UpdateAllReduceDoneValueSet( + HloInstruction* all_reduce_done) { + CHECK_EQ(all_reduce_done->opcode(), HloOpcode::kAllReduceDone); + bool changed = false; + // AllReduceDone forwards its only operand. 
+ for (auto& pair : GetInstructionValueSet(all_reduce_done)) { + const ShapeIndex& output_index = pair.first; + HloValueSet& value_set = pair.second; + + ShapeIndex operand_index = {}; + for (int64_t i : output_index) { + operand_index.push_back(i); + } + + const HloValueSet& operand_value_set = + GetValueSet(all_reduce_done->operand(0), operand_index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + +bool HloValueTracing::UpdateCollectivePermuteStartValueSet( + HloInstruction* collective_permute_start) { + CHECK_EQ(collective_permute_start->opcode(), + HloOpcode::kCollectivePermuteStart); + bool changed = false; + // CollectivePermuteStart forwards the operand value to element {0} of its + // output. + if (collective_permute_start->operand(0)->shape().IsTuple()) { + for (int i = 0; i < ShapeUtil::TupleElementCount( + collective_permute_start->operand(0)->shape()); + ++i) { + const HloValueSet& operand_value_set = + GetValueSet(collective_permute_start->operand(0), {i}); + HloValueSet& value_set = GetValueSet(collective_permute_start, {0, i}); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + } else { + const HloValueSet& operand_value_set = + GetValueSet(collective_permute_start->operand(0)); + HloValueSet& value_set = GetValueSet(collective_permute_start, {0}); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + +bool HloValueTracing::UpdateCollectivePermuteDoneValueSet( + HloInstruction* collective_permute_done) { + CHECK_EQ(collective_permute_done->opcode(), + HloOpcode::kCollectivePermuteDone); + bool changed = false; + // CollectivePermuteDone forwards the operand value at {1} to its output. + if (collective_permute_done->shape().IsTuple()) { + for (int i = 0; + i < ShapeUtil::TupleElementCount(collective_permute_done->shape()); + ++i) { + const HloValueSet& operand_value_set = + GetValueSet(collective_permute_done->operand(0), {1, i}); + HloValueSet& value_set = GetValueSet(collective_permute_done, {i}); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + } else { + const HloValueSet& operand_value_set = + GetValueSet(collective_permute_done->operand(0), {1}); + HloValueSet& value_set = GetValueSet(collective_permute_done); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + +bool HloValueTracing::UpdateInstructionValueSet(HloInstruction* instruction) { + // Recompute from operands. 
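+  // Dispatch to the opcode-specific update rule below. Opcodes without a
+  // rule keep the values defined in InitializeInstructionValueSets.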
+  bool changed = false;
+  switch (instruction->opcode()) {
+    case HloOpcode::kAddDependency:
+      changed = UpdateAddDependencyValueSet(instruction);
+      break;
+    case HloOpcode::kAllGatherStart:
+      changed = UpdateAllGatherStartValueSet(instruction);
+      break;
+    case HloOpcode::kAllGatherDone:
+      changed = UpdateAllGatherDoneValueSet(instruction);
+      break;
+    case HloOpcode::kAsyncStart:
+      changed = UpdateAsyncStartValueSet(instruction);
+      break;
+    case HloOpcode::kAsyncUpdate:
+      changed = UpdateAsyncUpdateValueSet(instruction);
+      break;
+    case HloOpcode::kAsyncDone:
+      changed = UpdateAsyncDoneValueSet(instruction);
+      break;
+    case HloOpcode::kDomain:
+      changed = UpdateDomainValueSet(instruction);
+      break;
+    case HloOpcode::kCopy:
+      changed = UpdateCopyValueSet(instruction);
+      break;
+    case HloOpcode::kGetTupleElement:
+      changed = UpdateGetTupleElementValueSet(instruction);
+      break;
+    case HloOpcode::kTuple:
+      changed = UpdateTupleValueSet(instruction);
+      break;
+    case HloOpcode::kParameter:
+      changed = UpdateParameterValueSet(instruction);
+      break;
+    case HloOpcode::kCall:
+      changed = UpdateCallValueSet(instruction);
+      break;
+    case HloOpcode::kWhile:
+      changed = UpdateWhileValueSet(instruction);
+      break;
+    case HloOpcode::kSend:
+      changed = UpdateSendValueSet(instruction);
+      break;
+    case HloOpcode::kRecvDone:
+      changed = UpdateRecvDoneValueSet(instruction);
+      break;
+    case HloOpcode::kCopyStart:
+      changed = UpdateCopyStartValueSet(instruction);
+      break;
+    case HloOpcode::kCopyDone:
+      changed = UpdateCopyDoneValueSet(instruction);
+      break;
+    case HloOpcode::kConditional:
+      changed = UpdateConditionalValueSet(instruction);
+      break;
+    case HloOpcode::kAllReduceDone:
+      changed = UpdateAllReduceDoneValueSet(instruction);
+      break;
+    case HloOpcode::kCollectivePermuteStart:
+      changed = UpdateCollectivePermuteStartValueSet(instruction);
+      break;
+    case HloOpcode::kCollectivePermuteDone:
+      changed = UpdateCollectivePermuteDoneValueSet(instruction);
+      break;
+    case HloOpcode::kOptimizationBarrier:
+      changed = UpdateOptimizationBarrierValueSet(instruction);
+      break;
+    case HloOpcode::kFusion:
+      changed = UpdateFusionValueSet(instruction);
+      break;
+    default:
+      break;
+  }
+
+  return changed;
+}
+
+void HloValueTracing::Propagate() {
+  using Work = std::pair<int64_t, HloInstruction*>;
+  // Avoid duplicating work by preferring work items early in the post order
+  // schedule. Intuitively, we start from entry parameters and propagate
+  // buffer updates throughout the module only once.
+  std::priority_queue<Work, std::vector<Work>, std::greater<Work>> worklist;
+  absl::flat_hash_set<const HloInstruction*> workset;
+  auto priority_map = CalculatePostOrderSchedule(module_);
+  auto add_to_worklist = [&priority_map, &worklist,
+                          &workset](HloInstruction* instruction) {
+    if (workset.insert(instruction).second) {
+      worklist.emplace(priority_map[instruction], instruction);
+    }
+  };
+
+  auto comps = module_.MakeComputationPostOrder();
+  for (HloComputation* computation : comps) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads_)) {
+      continue;
+    }
+    for (HloInstruction* instruction :
+         computation->MakeInstructionPostOrder()) {
+      add_to_worklist(instruction);
+    }
+  }
+
+  while (!worklist.empty()) {
+    HloInstruction* instruction = worklist.top().second;
+    worklist.pop();
+
+    workset.erase(workset.find(instruction));
+
+    if (!UpdateInstructionValueSet(instruction)) {
+      // No change to the instruction's value set.
+      continue;
+    }
+
+    // Instruction value was updated. Add users to work list if we haven't
+    // already.
+    for (HloInstruction* user : instruction->users()) {
+      add_to_worklist(user);
+
+      // If user sequentially calls a computation, then the respective
+      // parameter(s) of the computation need to be updated.
+      if (user->opcode() == HloOpcode::kConditional) {
+        // If operand 0 is the use of instruction, then no parameters need to
+        // be updated, since that is the branch_index of the conditional.
+        // If operand n+1 is the use of instruction, then branch_computation
+        // n's parameter needs to be updated.
+        //
+        // Note that the same instruction can be used in multiple branches'
+        // operands.
+        for (int j = 0; j < user->branch_count(); ++j) {
+          if (user->operand(j + 1) == instruction) {
+            add_to_worklist(
+                user->branch_computation(j)->parameter_instruction(0));
+          }
+        }
+      } else if (user->opcode() == HloOpcode::kAsyncUpdate ||
+                 user->opcode() == HloOpcode::kAsyncDone) {
+        if (HloInstruction::IsThreadIncluded(user->async_execution_thread(),
+                                             execution_threads_)) {
+          // For async update and async done, we cannot distinguish which
+          // parameter needs to be updated so add all to the worklist.
+          for (int64_t parameter_number = 0;
+               parameter_number <
+               user->async_wrapped_computation()->num_parameters();
+               ++parameter_number) {
+            add_to_worklist(
+                user->async_wrapped_computation()->parameter_instruction(
+                    parameter_number));
+          }
+        }
+      } else {
+        for (HloComputation* called_computation :
+             user->called_computations()) {
+          if (!HloInstruction::IsThreadIncluded(
+                  called_computation->execution_thread(),
+                  execution_threads_)) {
+            continue;
+          }
+          const CallGraphNode& call_graph_node =
+              call_graph_->GetNode(called_computation);
+          if (call_graph_node.context() == CallContext::kControlFlow ||
+              user->opcode() == HloOpcode::kFusion) {
+            for (int64_t operand_number : user->OperandIndices(instruction)) {
+              add_to_worklist(
+                  called_computation->parameter_instruction(operand_number));
+            }
+          }
+        }
+      }
+    }
+
+    // If instruction is a root instruction, then propagate out to any calling
+    // instruction and across any while backedge.
+    if (instruction == instruction->parent()->root_instruction()) {
+      const CallGraphNode& call_graph_node =
+          call_graph_->GetNode(instruction->parent());
+      for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+        if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
+          // Add the while itself, and the body and condition parameters.
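+          // A change to the body root flows back across the while backedge,
+          // so both computations' parameters must be revisited as well.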
+          add_to_worklist(callsite.instruction());
+          add_to_worklist(
+              callsite.instruction()->while_body()->parameter_instruction(0));
+          add_to_worklist(
+              callsite.instruction()->while_condition()->parameter_instruction(
+                  0));
+        } else if (call_graph_node.context() == CallContext::kControlFlow ||
+                   callsite.instruction()->opcode() ==
+                       HloOpcode::kConditional ||
+                   callsite.instruction()->opcode() == HloOpcode::kFusion) {
+          add_to_worklist(callsite.instruction());
+        }
+      }
+    }
+  }
+}
+
+const InstructionValueSet& HloValueTracing::GetInstructionValueSet(
+    const HloInstruction* instruction) const {
+  DCHECK(value_sets_.contains(instruction))
+      << "Instruction " << instruction->ToString() << " not found.";
+  return *value_sets_.find(instruction)->second;
+}
+
+InstructionValueSet& HloValueTracing::GetInstructionValueSet(
+    const HloInstruction* instruction) {
+  DCHECK(value_sets_.contains(instruction))
+      << "Instruction " << instruction->ToString() << " not found.";
+  return *value_sets_.find(instruction)->second;
+}
+
+absl::Status HloValueTracing::InitializeInstructionValueSets() {
+  for (const HloComputation* computation : module_.MakeComputationSorted()) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads_)) {
+      continue;
+    }
+    const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
+    for (HloInstruction* instruction :
+         computation->MakeInstructionPostOrder()) {
+      // Create an empty shape tree.
+      value_sets_.insert({instruction, std::make_unique<InstructionValueSet>(
+                                           &instruction->shape())});
+
+      // For each sub-shape of the instruction shape, add a new HloValue to
+      // its HloValueSet. should_define may be provided to define a subset of
+      // values.
+      auto define_all_values =
+          [this, &instruction](
+              absl::FunctionRef<bool(const ShapeIndex&)> should_define =
+                  [](const ShapeIndex&) { return true; }) {
+            for (auto& pair : GetInstructionValueSet(instruction)) {
+              const ShapeIndex& index = pair.first;
+
+              if (should_define(index)) {
+                HloValue* value =
+                    NewHloValue(instruction, index, /*is_phi=*/false);
+                GetValueSet(instruction, index).AddValue(value);
+              }
+            }
+          };
+
+      // Add a new HloValue to the HloValueSet corresponding to the given
+      // index of the instruction shape.
+      auto define_value_at = [this, &instruction](const ShapeIndex& index) {
+        HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
+        GetValueSet(instruction, index).AddValue(value);
+      };
+
+      switch (instruction->opcode()) {
+        case HloOpcode::kAddDependency:
+        case HloOpcode::kWhile:
+        case HloOpcode::kCall:
+        case HloOpcode::kConditional:
+        case HloOpcode::kGetTupleElement:
+        case HloOpcode::kDomain:
+        case HloOpcode::kOptimizationBarrier:
+        case HloOpcode::kCopy:
+        case HloOpcode::kFusion:
+          // These instructions define no values. The values in their output
+          // flow from their operands or from cross computation dataflow.
+          break;
+        case HloOpcode::kParameter: {
+          if (call_graph_node.context() == CallContext::kBoth) {
+            // We do not support a subcomputation that is called from both a
+            // parallel and sequential context. In this case, the parameter
+            // would both define a value and propagate a value from its
+            // caller. This limitation is not really a problem because the
+            // call graph is typically flattened.
+            return Unimplemented(
+                "Computation %s is called in both a parallel (eg, kMap) and "
+                "sequential (eg, kCall) context",
+                computation->name());
+          }
+          if (call_graph_node.caller_callsites().empty()) {
+            // Parameters of computations called in a parallel context (eg,
+            // map and reduce) as well as parameters of dead computations
+            // define all values in their output. Otherwise the values of the
+            // parameter come from the caller (eg, operands to the kCall
+            // instruction).
+            define_all_values();
+          } else {
+            HloOpcode caller_callsite_opcode =
+                call_graph_node.caller_callsites()
+                    .front()
+                    .instruction()
+                    ->opcode();
+            if (caller_callsite_opcode != HloOpcode::kFusion) {
+              define_all_values();
+            }
+          }
+          break;
+        }
+        case HloOpcode::kTuple:
+          // These instructions only define their top-level values. Any other
+          // values flow from their operands.
+          define_value_at(/*index=*/{});
+          break;
+        case HloOpcode::kAsyncStart: {
+          // AsyncStart produces a tuple of {{aliased operands}, {destination},
+          // contexts}. It defines all of the tuple-shaped values and the
+          // contexts.
+          // If the thread is excluded, then we don't track the contained
+          // dataflow, and define the destination values too.
+          bool thread_included = HloInstruction::IsThreadIncluded(
+              instruction->async_execution_thread(), execution_threads_);
+          define_all_values([&](const ShapeIndex& index) {
+            return ShapeUtil::GetSubshape(instruction->shape(), index)
+                       .IsTuple() ||
+                   (!thread_included && index.front() == 1) ||
+                   (index.front() > 1);
+          });
+          break;
+        }
+        case HloOpcode::kAsyncUpdate:
+          // AsyncUpdate produces a tuple of {{aliased operands},
+          // {destination}, contexts} where all of the array-typed values
+          // alias with the operand. So, only tuple-shaped values are defined
+          // by AsyncUpdate.
+          define_all_values([&](const ShapeIndex& index) {
+            return ShapeUtil::GetSubshape(instruction->shape(), index)
+                .IsTuple();
+          });
+          break;
+        case HloOpcode::kAsyncDone:
+          // AsyncDone's output aliases its input tuple element {1}. It
+          // defines all remaining tuple-shaped values.
+          define_all_values([&](const ShapeIndex& index) {
+            return ShapeUtil::GetSubshape(instruction->shape(), index)
+                .IsTuple();
+          });
+          break;
+        case HloOpcode::kCopyStart:
+          // CopyStart produces a tuple of {destination buffer, aliased
+          // operand, U32 context}.
+          define_value_at(/*index=*/{});
+          define_value_at(/*index=*/{2});
+          break;
+        case HloOpcode::kCopyDone:
+          // CopyDone consumes a tuple produced by CopyStart and produces an
+          // element. Its output aliases its input tuple element {0}.
+          break;
+        case HloOpcode::kAllGatherStart:
+          // AllGatherStart produces a tuple of
+          // {aliased operands, destination buffers}. If there is more than
+          // one operand, then both aliased operands and destination buffers
+          // will be tuples themselves. all-gather-start will define all
+          // tuples and all tuple leaves (arrays) in tuple sub-index 1
+          // (destination buffers).
+          define_all_values([&](const ShapeIndex& index) {
+            return ShapeUtil::GetSubshape(instruction->shape(), index)
+                       .IsTuple() ||
+                   index.front() == 1;
+          });
+          break;
+        case HloOpcode::kAllGatherDone:
+          // AllGatherDone's output aliases its input tuple element {1}.
+          if (instruction->shape().IsTuple()) {
+            define_value_at(/*index=*/{});
+          }
+          break;
+        case HloOpcode::kAllReduceDone:
+          // AllReduceDone's output aliases its input.
+          break;
+        case HloOpcode::kCollectivePermuteStart:
+          // CollectivePermuteStart produces a tuple of
+          // {aliased operand, destination buffer, contexts}, where the
+          // context data are optional.
+          define_value_at(/*index=*/{});
+          define_value_at(/*index=*/{1});
+          for (int i = 2; i < instruction->shape().tuple_shapes_size(); ++i) {
+            define_value_at(/*index=*/{i});
+          }
+
+          if (instruction->operand_count() > 1) {
+            CHECK_EQ(instruction->operand_count(), 4);
+            if (instruction->operand(1)->shape().IsTuple()) {
+              for (int i = 0; i < ShapeUtil::TupleElementCount(
+                                  instruction->operand(1)->shape());
+                   ++i) {
+                define_value_at(/*index=*/{1, i});
+              }
+            }
+          }
+          break;
+        case HloOpcode::kCollectivePermuteDone:
+          // CollectivePermuteDone's output aliases its input tuple element
+          // {1}.
+          if (instruction->shape().IsTuple()) {
+            define_value_at(/*index=*/{});
+          }
+          break;
+        case HloOpcode::kRecvDone:
+          // RecvDone produces a two-element tuple. Element zero aliases its
+          // input tuple element {0}; element one is a token.
+          define_value_at(/*index=*/{});
+          define_value_at(/*index=*/{1});
+          break;
+        case HloOpcode::kSend:
+          // Send produces a tuple of {aliased operand, U32 context, token},
+          // therefore only defines the top-level tuple and the tuple elements
+          // at {1} and {2}.
+          define_value_at(/*index=*/{});
+          define_value_at(/*index=*/{1});
+          define_value_at(/*index=*/{2});
+          break;
+        default:
+          define_all_values();
+          break;
+      }
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+/* static */
+absl::StatusOr<std::unique_ptr<HloValueTracing>> HloValueTracing::Run(
+    const HloModule& module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  auto hlo_value_tracing =
+      absl::WrapUnique(new HloValueTracing(module, execution_threads));
+
+  TF_RETURN_IF_ERROR(hlo_value_tracing->InitializeInstructionValueSets());
+  hlo_value_tracing->Propagate();
+
+  // Delete all values marked for deletion.
+  hlo_value_tracing->DeleteMarkedValues();
+
+  // Gather and set all non-definition positions of all values. Value
+  // deletion is rare, so just use a vector indexed by Value::Id rather than a
+  // map from Value::Id to positions. There should be very few holes in the
+  // vector, and lookup is faster.
+  std::vector<std::vector<HloPosition>> value_positions(
+      hlo_value_tracing->next_value_id_);
+  for (const HloComputation* computation : module.computations()) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads)) {
+      continue;
+    }
+    for (HloInstruction* instruction : computation->instructions()) {
+      for (const auto& pair :
+           hlo_value_tracing->GetInstructionValueSet(instruction)) {
+        const ShapeIndex& index = pair.first;
+        const HloValueSet& value_set = pair.second;
+        for (const HloValue* value : value_set.values()) {
+          if (value->defining_instruction() != instruction) {
+            value_positions[value->id()].push_back(
+                HloPosition{instruction, index});
+          }
+        }
+      }
+    }
+  }
+  for (auto& pair : hlo_value_tracing->values_) {
+    HloValue::Id value_id = pair.first;
+    HloValue& value = *pair.second;
+    value.SetPositions(value_positions[value_id]);
+  }
+
+  // Construct vector of values.
+  hlo_value_tracing->values_vector_.reserve(hlo_value_tracing->values_.size());
+  for (const auto& pair : hlo_value_tracing->values_) {
+    hlo_value_tracing->values_vector_.push_back(pair.second.get());
+  }
+  absl::c_sort(hlo_value_tracing->values_vector_, HloValue::IdLessThan);
+
+  return std::move(hlo_value_tracing);
+}
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h
new file mode 100644
index 00000000000000..3e5d744fbdbe05
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h
@@ -0,0 +1,194 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_GRAPH_ANALYSIS_HLO_VALUE_TRACING_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_GRAPH_ANALYSIS_HLO_VALUE_TRACING_H_
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/call_graph.h"
+#include "xla/service/hlo_value.h"
+#include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+
+// Analysis that traces the defining HLO instructions of HLO values used by
+// any instruction. This is largely based on HloDataflowAnalysis, with the
+// primary difference that HLO values are traced back through copy and
+// fusion instructions.
+class HloValueTracing {
+ public:
+  // Runs dataflow analysis on the given module.
+  static absl::StatusOr<std::unique_ptr<HloValueTracing>> Run(
+      const HloModule& module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads = {});
+
+  // Returns true if 'instruction' defines an HLO value at the given shape
+  // index of its output.
+  bool ValueIsDefinedAt(const HloInstruction* instruction,
+                        const ShapeIndex& index = {}) const;
+
+  // Returns the InstructionValueSet for the given instruction.
+  const InstructionValueSet& GetInstructionValueSet(
+      const HloInstruction* instruction) const;
+  InstructionValueSet& GetInstructionValueSet(
+      const HloInstruction* instruction);
+
+  // Returns all values that are contained in the output of this instruction
+  // in a flattened set.
+  HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const;
+
+  // Returns the HloValueSet for the given instruction at the given index or
+  // the given position.
+  const HloValueSet& GetValueSet(const HloInstruction* instruction,
+                                 const ShapeIndex& index = {}) const;
+  const HloValueSet& GetValueSet(const HloPosition& position) const;
+  HloValueSet& GetValueSet(const HloPosition& position);
+  HloValueSet& GetValueSet(const HloInstruction* instruction,
+                           const ShapeIndex& index = {});
+
+  // Returns the unique value in the HloValueSet at the given instruction and
+  // shape index. CHECKs if the value set does not contain exactly one value.
+  const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
+                                   const ShapeIndex& index = {}) const {
+    return GetValueSet(instruction, index).GetUniqueValue();
+  }
+  HloValue& GetUniqueValueAt(const HloInstruction* instruction,
+                             const ShapeIndex& index = {}) {
+    return GetValue(GetValueSet(instruction, index).GetUniqueValue().id());
+  }
+
+  // Returns the HloValue with the given Id.
+  const HloValue& GetValue(HloValue::Id value_id) const;
+  HloValue& GetValue(HloValue::Id value_id);
+
+  // Returns the total number of HloValues.
+  int64_t value_count() const { return values_.size(); }
+
+  // Returns a vector of all HloValues stably sorted by HloValue::Id.
+  const std::vector<HloValue*>& values() const { return values_vector_; }
+
+  // Returns the call graph used for computing the dataflow.
+  const CallGraph& call_graph() const { return *call_graph_; }
+
+  const HloModule& module() const { return module_; }
+
+ private:
+  HloValueTracing(const HloModule& module,
+                  absl::flat_hash_set<absl::string_view> execution_threads);
+
+  // Returns a new HloValue defined at the given instruction and shape index.
+  HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
+                        bool is_phi);
+
+  // Deletes all HloValues marked for deletion. Should be called after
+  // propagation is complete.
+  void DeleteMarkedValues();
+
+  // Constructs and initializes the InstructionValueSets of all instructions
+  // to contain exactly the HloValues defined by each instruction. These
+  // values can then be propagated throughout the HLO graph by calling
+  // Propagate.
+  absl::Status InitializeInstructionValueSets();
+
+  // Updates the value set of the given instruction based on the values
+  // flowing into the instruction (operands and cross-computation dataflow).
+  bool UpdateInstructionValueSet(HloInstruction* instruction);
+
+  // Updates the value set for a particular instruction type. Returns whether
+  // the instruction value set changed.
+  bool UpdateCallValueSet(HloInstruction* call);
+  bool UpdateConditionalValueSet(HloInstruction* conditional);
+  bool UpdateCopyValueSet(HloInstruction* copy);
+  bool UpdateDomainValueSet(HloInstruction* domain);
+  bool UpdateGetTupleElementValueSet(HloInstruction* gte);
+  bool UpdateParameterValueSet(HloInstruction* parameter);
+  // Async op propagation rules:
+  //  - Operand of async-start to parameter of async wrapped computation and
+  //    at index {0, operand_number} of async-start and async-update outputs.
+  //  - Root of async wrapped computation to index {1} of async-start and
+  //    async-update and index {} of async-done.
+  //  - The contexts in indices {2+} of async-start to the same indices of
+  //    async-update.
+  //
+  // As a result of this, the operands/outputs of async-start and async-done
+  // instructions share the same values as the parameters/roots of the async
+  // wrapped computation.
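+  // (These rules are implemented by the UpdateAsync*ValueSet methods
+  // declared below.)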
+  bool UpdateAsyncStartValueSet(HloInstruction* async_start);
+  bool UpdateAsyncUpdateValueSet(HloInstruction* async_update);
+  bool UpdateAsyncDoneValueSet(HloInstruction* async_done);
+  bool UpdateCopyStartValueSet(HloInstruction* copy_start);
+  bool UpdateCopyDoneValueSet(HloInstruction* copy_done);
+  bool UpdateOptimizationBarrierValueSet(HloInstruction* barrier);
+  bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
+  bool UpdateSendValueSet(HloInstruction* send);
+  bool UpdateTupleValueSet(HloInstruction* tuple);
+  bool UpdateFusionValueSet(HloInstruction* fusion);
+  bool UpdateWhileValueSet(HloInstruction* xla_while);
+  bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
+  bool UpdateAllGatherStartValueSet(HloInstruction* all_gather_start);
+  bool UpdateAllGatherDoneValueSet(HloInstruction* all_gather_done);
+  bool UpdateAllReduceDoneValueSet(HloInstruction* all_reduce_done);
+  bool UpdateCollectivePermuteStartValueSet(
+      HloInstruction* collective_permute_start);
+  bool UpdateCollectivePermuteDoneValueSet(
+      HloInstruction* collective_permute_done);
+
+  // Propagates the dataflow through the module. In particular, it propagates
+  // the HloValueSet from its defining instruction to the users of the
+  // instructions.
+  void Propagate();
+
+  const HloModule& module_;
+  const absl::flat_hash_set<absl::string_view> execution_threads_;
+
+  std::unique_ptr<CallGraph> call_graph_;
+
+  // The map of all HloValues in the module. We pass around pointers to the
+  // mapped HloValues, so the underlying container must keep them valid
+  // despite mutations touching other map entries.
+  absl::flat_hash_map<HloValue::Id, std::unique_ptr<HloValue>> values_;
+
+  // A map from instruction to InstructionValueSet.
+  absl::flat_hash_map<const HloInstruction*,
+                      std::unique_ptr<InstructionValueSet>>
+      value_sets_;
+
+  // Values marked for deletion during construction. We don't delete them
+  // immediately because references to them may remain in ValueSets
+  // temporarily during propagation. After construction, these values are
+  // deleted.
+  std::vector<HloValue::Id> value_ids_to_delete_;
+
+  // A vector containing all HloValues sorted by HloValue::Id.
+  std::vector<HloValue*> values_vector_;
+
+  // The Id to use for the next HloValue.
+  HloValue::Id next_value_id_ = 0;
+};
+}  // namespace xla
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_GRAPH_ANALYSIS_HLO_VALUE_TRACING_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.cc
new file mode 100644
index 00000000000000..6293539ad0119c
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.cc
@@ -0,0 +1,331 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_print_options.h" +#include "xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h" +#include "xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.h" +#include "xla/hlo/tools/hlo_diff/utils/hlo_diff_util.h" +#include "xla/service/call_graph.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "tsl/platform/fingerprint.h" + +namespace xla { +namespace hlo_diff { +namespace { + +// Adds an edge between the given parent and child nodes. +void AddEdge(HloInstructionNode* parent, HloInstructionNode* child) { + parent->children.push_back(child); + child->parents.push_back(parent); +} + +// Creates HloPrintOptions from the given fingerprint options. +HloPrintOptions CreateHloPrintOptions( + const HloGumgraphFingerprintOptions& fingerprint_options) { + HloPrintOptions hlo_print_options = + HloPrintOptions::Fingerprint() + .set_include_layout_in_shapes(false) + .set_print_subcomputation_mode( + HloPrintOptions::PrintSubcomputationMode::kOff) + .set_print_parameter_number(false); + if (fingerprint_options.ignore_shape) { + hlo_print_options.set_print_operand_shape(false); + hlo_print_options.set_print_result_shape(false); + } + return hlo_print_options; +} + +} // namespace + +std::pair HloGumgraph::AddNode( + const HloInstruction& instruction, int unique_node_index) { + auto node = std::make_unique(HloInstructionNode{ + .instruction = &instruction, .unique_node_index = unique_node_index}); + auto [new_node_it, inserted] = + instruction_to_node_.try_emplace(&instruction, std::move(node)); + return {new_node_it->second.get(), inserted}; +} + +absl::Status HloGumgraph::ConstructGraph(const HloModule& hlo_module) { + LOG(INFO) << "Constructing HloGumgraph"; + int unique_instruction_index = 0; + for (auto* computation : hlo_module.MakeComputationPostOrder()) { + for (auto* instruction : computation->MakeInstructionPostOrder()) { + std::pair node_and_inserted = + AddNode(*instruction, ++unique_instruction_index); + if (!node_and_inserted.second) { + return absl::InternalError(absl::StrCat( + "Instruction: ", instruction->name(), " already in the graph")); + } + + HloInstructionNode* node = node_and_inserted.first; + node->props.fingerprint = GetHloInstructionFingerprint( + instruction, CreateHloPrintOptions(fingerprint_options_)); + + switch (instruction->opcode()) { + case HloOpcode::kCall: + case HloOpcode::kFusion: + case HloOpcode::kWhile: { + // Connect Call, Fusion and While instruction's called computations + // parameters with the operands of the caller instructions to inline + // the called computation as they should match 1:1. 
+          for (auto* called_computation : instruction->called_computations()) {
+            for (int i = 0; i < instruction->operands().size(); ++i) {
+              HloInstructionNode* parent =
+                  GetNode(called_computation->parameter_instruction(i));
+              HloInstructionNode* child = GetNode(instruction->operands()[i]);
+              if (parent == nullptr || child == nullptr) {
+                // Report using the HloInstruction names: the graph nodes
+                // themselves may be null here.
+                return absl::InternalError(absl::StrFormat(
+                    "Instruction (%s) operand %s or the matching parameter of "
+                    "called computation (%s) not found in the graph",
+                    instruction->name(), instruction->operands()[i]->name(),
+                    called_computation->name()));
+              }
+              AddEdge(parent, child);
+            }
+          }
+          break;
+        }
+        case HloOpcode::kConditional: {
+          // Connect conditional instruction node with the predicate operand.
+          HloInstructionNode* pred_node = GetNode(instruction->operands()[0]);
+          if (pred_node == nullptr) {
+            return absl::InternalError(absl::StrFormat(
+                "Instruction (%s) operand: %s not found in the graph",
+                instruction->name(), instruction->operands()[0]->name()));
+          }
+          AddEdge(node, pred_node);
+
+          // Connect conditional instruction's branch computations parameters
+          // with the operands of the caller instructions to inline the branch
+          // computations.
+          for (int i = 0; i < instruction->branch_count(); ++i) {
+            HloComputation* branch_computation =
+                instruction->branch_computation(i);
+            HloInstructionNode* parent =
+                GetNode(branch_computation->parameter_instruction(0));
+            HloInstructionNode* child = GetNode(instruction->operands()[i + 1]);
+            if (parent == nullptr || child == nullptr) {
+              // As above, avoid dereferencing the possibly-null nodes.
+              return absl::InternalError(absl::StrFormat(
+                  "Instruction (%s) operand %s or the parameter of branch "
+                  "computation (%s) not found in the graph",
+                  instruction->name(), instruction->operands()[i + 1]->name(),
+                  branch_computation->name()));
+            }
+            AddEdge(parent, child);
+          }
+          break;
+        }
+        default: {
+          for (auto* operand : instruction->operands()) {
+            HloInstructionNode* child = GetNode(operand);
+            if (child == nullptr) {
+              return absl::InternalError(absl::StrFormat(
+                  "Instruction (%s) operand: %s not found in the graph",
+                  instruction->name(), operand->name()));
+            }
+            AddEdge(node, child);
+          }
+        }
+      }
+
+      // Connect the root instruction of the called computation with the
+      // caller instruction.
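+      // This inlines the callee's result: the caller node becomes a direct
+      // parent of the called computation's root node.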
+      for (auto* called_computation : instruction->called_computations()) {
+        HloInstructionNode* called_computation_root_node =
+            GetNode(called_computation->root_instruction());
+        if (called_computation_root_node == nullptr) {
+          return absl::InternalError(absl::StrFormat(
+              "Called computation (%s) root: %s not found in the graph",
+              called_computation->name(),
+              called_computation->root_instruction()->name()));
+        }
+        AddEdge(node, called_computation_root_node);
+      }
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+absl::StatusOr<std::vector<HloInstructionNode*>>
+HloGumgraph::PrecomputeGenerations() {
+  LOG(INFO) << "Precomputing generations";
+  std::vector<HloInstructionNode*> zero_indegrees;
+  absl::flat_hash_map<HloInstructionNode*, int64_t> indegrees;
+  for (const auto& [_, node] : instruction_to_node_) {
+    if (node->parents.empty()) {
+      zero_indegrees.push_back(node.get());
+      continue;
+    }
+
+    auto [it, inserted] = indegrees.insert({node.get(), node->parents.size()});
+    if (!inserted) {
+      return absl::InternalError(
+          absl::StrCat("Instruction: ", node->instruction->name(),
+                       " already inserted in indegree map"));
+    }
+  }
+  std::vector<HloInstructionNode*> init_zero_indegrees = zero_indegrees;
+  nodes_by_generation_.push_back({&root_});
+
+  int current_generation = 1;
+  while (!zero_indegrees.empty()) {
+    std::vector<HloInstructionNode*> current_generation_nodes =
+        std::move(zero_indegrees);
+    zero_indegrees = {};
+
+    for (int i = 0; i < current_generation_nodes.size(); ++i) {
+      current_generation_nodes[i]->props.generation = current_generation;
+      current_generation_nodes[i]->props.sibling_position = {
+          i, static_cast<int>(current_generation_nodes.size())};
+      for (HloInstructionNode* child : current_generation_nodes[i]->children) {
+        auto it = indegrees.find(child);
+        if (it == indegrees.end()) {
+          return absl::InternalError(
+              absl::StrCat("Instruction: ", child->instruction->name(),
+                           " not found in indegree map"));
+        }
+        --it->second;
+        if (it->second == 0) {
+          zero_indegrees.push_back(child);
+          indegrees.erase(it);
+        }
+      }
+    }
+    nodes_by_generation_.push_back(std::move(current_generation_nodes));
+    ++current_generation;
+  }
+
+  if (!indegrees.empty()) {
+    LOG(WARNING) << "Cycle detected in the graph.";
+    return absl::InternalError("Cycle detected in the graph");
+  }
+  return init_zero_indegrees;
+}
+
+void HloGumgraph::PrecomputeSizeAndHeight() {
+  LOG(INFO) << "Precomputing size and height";
+  // TODO(camillesun): Refactor this to use DFS.
+  for (auto it = nodes_by_generation_.rbegin();
+       it != nodes_by_generation_.rend(); ++it) {
+    for (HloInstructionNode* node : *it) {
+      int64_t height = 0;
+      uint64_t fingerprint = node->props.fingerprint;
+
+      for (const HloInstructionNode* child : node->children) {
+        height = std::max(height, child->props.height);
+        fingerprint = tsl::FingerprintCat64(fingerprint,
+                                            child->props.subgraph_fingerprint);
+      }
+
+      node->props.height = height + 1;
+      // TODO(b/365855856): graphs with different structures can share the
+      // same subgraph fingerprint, see test case
+      // PreComputationsWorksSubgraphFingerprint. This is unexpected.
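+      // Note: FingerprintCat64 is order-sensitive, so the combined value
+      // also depends on child order, not just the set of child fingerprints.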
+      node->props.subgraph_fingerprint = fingerprint;
+    }
+  }
+}
+
+absl::Status HloGumgraph::PrecomputeComputationFingerprint() {
+  LOG(INFO) << "Precomputing computation fingerprint";
+  TF_RETURN_IF_ERROR(call_graph_->VisitNodes([&](const CallGraphNode& node)
+                                                 -> absl::Status {
+    absl::flat_hash_map<const HloInstruction*, uint64_t> subgraph_fingerprint;
+    const HloComputation* computation = node.computation();
+    for (auto* instruction : computation->MakeInstructionPostOrder()) {
+      uint64_t fp = GetNode(instruction)->props.fingerprint;
+      for (const HloInstruction* operand : instruction->operands()) {
+        fp = tsl::FingerprintCat64(
+            fp, subgraph_fingerprint.at(GetNode(operand)->instruction));
+      }
+      subgraph_fingerprint[instruction] = fp;
+    }
+
+    computation_to_props_[computation] = CallGraphNodeProps{
+        .call_graph_node = &node,
+        .fingerprint =
+            subgraph_fingerprint.at(computation->root_instruction())};
+
+    return absl::OkStatus();
+  }));
+  return absl::OkStatus();
+}
+
+void HloGumgraph::PrecomputeDfsPosition() {
+  LOG(INFO) << "Precomputing DFS position";
+  std::vector<const HloInstructionNode*> pre_order_nodes =
+      GetAllNodesInDfsOrder(root_, DfsTraversalOrder::kPreOrder,
+                            GetNodeCount());
+  for (int i = 0; i < pre_order_nodes.size(); ++i) {
+    if (pre_order_nodes[i]->is_root) {
+      continue;
+    }
+    instruction_to_node_[pre_order_nodes[i]->instruction]
+        ->props.pre_order_graph_position = {
+        i, static_cast<int>(pre_order_nodes.size())};
+  }
+}
+
+absl::StatusOr<std::unique_ptr<HloGumgraph>> HloGumgraph::Create(
+    absl::Nonnull<const HloModule*> hlo_module,
+    const HloGumgraphFingerprintOptions& fingerprint_options) {
+  CHECK(hlo_module != nullptr) << "Expected a non-null hlo module";
+  CHECK(hlo_module->entry_computation() != nullptr)
+      << "Expected a non-null entry computation";
+
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(hlo_module);
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloValueTracing> hlo_value_tracing,
+                      HloValueTracing::Run(*hlo_module));
+  auto graph = absl::WrapUnique(
+      new HloGumgraph(*hlo_module, fingerprint_options, std::move(call_graph),
+                      std::move(hlo_value_tracing)));
+
+  TF_RETURN_IF_ERROR(graph->ConstructGraph(*hlo_module));
+  TF_ASSIGN_OR_RETURN(std::vector<HloInstructionNode*> zero_indegree_nodes,
+                      graph->PrecomputeGenerations());
+  for (auto* zero_indegree_node : zero_indegree_nodes) {
+    AddEdge(&graph->root_, zero_indegree_node);
+  }
+  graph->PrecomputeSizeAndHeight();
+  TF_RETURN_IF_ERROR(graph->PrecomputeComputationFingerprint());
+  graph->PrecomputeDfsPosition();
+
+  return graph;
+}
+
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h
new file mode 100644
index 00000000000000..5ff403774e0e9a
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h
@@ -0,0 +1,161 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_GRAPH_HLO_GUMGRAPH_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_GRAPH_HLO_GUMGRAPH_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/base/nullability.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "xla/hlo/ir/dfs_hlo_visitor.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+#include "xla/service/call_graph.h"
+
+namespace xla {
+namespace hlo_diff {
+
+// Options for computing the per-instruction/node fingerprint of an
+// HloGumgraph.
+struct HloGumgraphFingerprintOptions {
+  // Ignore shape when computing the instruction fingerprint.
+  bool ignore_shape = false;
+};
+
+// A directed acyclic graph representation of an HloModule with all called
+// computations inlined, i.e. each calling instruction is connected to the
+// called computation's root instruction.
+class HloGumgraph {
+ public:
+  // Instantiates a HloGumgraph from a HloModule, pre-processing and caching
+  // various graph properties such as height and sibling position per node.
+  static absl::StatusOr<std::unique_ptr<const HloGumgraph>> Create(
+      absl::Nonnull<const HloModule*> hlo_module,
+      const HloGumgraphFingerprintOptions& fingerprint_options = {});
+
+  // HloGumgraph is neither copyable nor movable as it can be really large.
+  HloGumgraph(const HloGumgraph&) = delete;
+  HloGumgraph& operator=(const HloGumgraph&) = delete;
+
+  // Returns the dummy root node, which is connected to all zero-indegree
+  // nodes in the graph. The dummy root is always connected to the entry
+  // computation's root instruction but might additionally be connected to
+  // other unreachable roots in the entry computation.
+  inline const HloInstructionNode& GetRoot() const { return root_; }
+
+  // Returns the graph node corresponding to the given HloInstruction, or
+  // nullptr if the instruction is not in the graph.
+  inline HloInstructionNode* GetNode(
+      absl::Nonnull<const HloInstruction*> instruction) const {
+    if (auto it = instruction_to_node_.find(instruction);
+        it != instruction_to_node_.end()) {
+      return it->second.get();
+    }
+    return nullptr;
+  }
+
+  // Returns all nodes in the graph excluding the dummy root node.
+  inline std::vector<HloInstructionNode*> AllNodes() const {
+    std::vector<HloInstructionNode*> nodes;
+    for (const auto& [_, node] : instruction_to_node_) {
+      nodes.push_back(node.get());
+    }
+    return nodes;
+  }
+
+  // Returns the number of nodes in the graph including the dummy root node.
+  inline int GetNodeCount() const { return instruction_to_node_.size() + 1; }
+
+  // Returns the properties of all computations in the graph.
+  inline const absl::flat_hash_map<const HloComputation*, CallGraphNodeProps>&
+  AllComputationProps() const {
+    return computation_to_props_;
+  }
+
+  // Returns the call graph of the HloModule.
+  const CallGraph& GetCallGraph() const { return *call_graph_; }
+
+  // Returns the HloValueTracing used to trace the HloValues used by
+  // instructions.
+  const HloValueTracing& GetHloValueTracing() const {
+    return *hlo_value_tracing_;
+  }
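+  // Example usage (a sketch; `module` is assumed to be a parsed and verified
+  // HloModule owned by the caller):
+  //
+  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<const HloGumgraph> graph,
+  //                       HloGumgraph::Create(module.get()));
+  //   const HloInstructionNode* node = graph->GetNode(
+  //       module->entry_computation()->root_instruction());
+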
+  // Returns the backing HloModule of the HloGumgraph.
+  const HloModule& GetHloModule() const { return hlo_module_; }
+
+ private:
+  explicit HloGumgraph(
+      const HloModule& hlo_module,
+      const HloGumgraphFingerprintOptions& fingerprint_options,
+      std::unique_ptr<const CallGraph> call_graph,
+      std::unique_ptr<const HloValueTracing> hlo_value_tracing)
+      : hlo_module_(hlo_module),
+        fingerprint_options_(fingerprint_options),
+        root_(
+            {.instruction = nullptr, .unique_node_index = 0, .is_root = true}),
+        call_graph_(std::move(call_graph)),
+        hlo_value_tracing_(std::move(hlo_value_tracing)) {}
+
+  // Adds a HloInstructionNode for the given HloInstruction to the graph.
+  // Returns a pair of the node and a boolean indicating whether the node was
+  // already in the graph.
+  std::pair<HloInstructionNode*, bool> AddNode(
+      const HloInstruction& instruction, int unique_node_index);
+
+  // Constructs the HloGumgraph from the given HloModule, connecting
+  // instruction operands and called computations.
+  absl::Status ConstructGraph(const HloModule& hlo_module);
+
+  // Precomputes the generation of each node in the graph. The generation of a
+  // node is the longest distance of the node from the root node; the root
+  // node's generation is 0. Additionally, returns all zero-indegree nodes.
+  absl::StatusOr<std::vector<HloInstructionNode*>> PrecomputeGenerations();
+
+  // Precomputes the size and height of each node in the graph.
+  void PrecomputeSizeAndHeight();
+
+  // Precomputes the fingerprint of each computation in the graph; all
+  // instructions in the computation are hashed to compute the fingerprint.
+  absl::Status PrecomputeComputationFingerprint();
+
+  // Precomputes the index of each node in a pre-order DFS traversal of the
+  // graph.
+  void PrecomputeDfsPosition();
+
+  const HloModule& hlo_module_;
+  // Stored by value: Create() accepts a defaulted temporary, so holding a
+  // reference here would dangle once Create() returns.
+  const HloGumgraphFingerprintOptions fingerprint_options_;
+  HloInstructionNode root_;
+  absl::flat_hash_map<const HloInstruction*,
+                      std::unique_ptr<HloInstructionNode>>
+      instruction_to_node_;
+  absl::flat_hash_map<const HloComputation*, CallGraphNodeProps>
+      computation_to_props_;
+  std::vector<std::vector<HloInstructionNode*>> nodes_by_generation_;
+  const std::unique_ptr<const CallGraph> call_graph_;
+  const std::unique_ptr<const HloValueTracing> hlo_value_tracing_;
+};
+
+}  // namespace hlo_diff
+}  // namespace xla
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_GRAPH_HLO_GUMGRAPH_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h
new file mode 100644
index 00000000000000..d5fa2d14a41496
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_GRAPH_HLO_GUMGRAPH_NODE_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_GRAPH_HLO_GUMGRAPH_NODE_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/call_graph.h"
+
+namespace xla {
+namespace hlo_diff {
+
+// Position of a node in a container of siblings.
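+// For example, the second of three siblings has {.index = 1, .size = 3}.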
+struct ListPosition {
+  int64_t index = 0;
+  int64_t size = 0;
+};
+
+// Properties of an instruction node in a HloGumgraph such as generation etc.
+struct HloInstructionNodeProps {
+  int64_t generation = 0;
+  int64_t height = 0;
+  uint64_t subgraph_fingerprint = 0;
+  uint64_t fingerprint = 0;
+  ListPosition sibling_position;
+  ListPosition pre_order_graph_position;
+};
+
+// Properties of a computation node in a HloGumgraph.
+struct CallGraphNodeProps {
+  const CallGraphNode* call_graph_node;
+  uint64_t fingerprint = 0;
+  absl::string_view GetName() const {
+    return call_graph_node->computation()->name();
+  }
+};
+
+// A node in a HloGumgraph representing an HLO instruction.
+// Only the dummy root node has no instruction.
+struct HloInstructionNode {
+  const HloInstruction* instruction;
+  int unique_node_index = 0;
+  std::vector<HloInstructionNode*> children;
+  std::vector<HloInstructionNode*> parents;
+  HloInstructionNodeProps props;
+  bool is_root = false;
+  absl::string_view GetName() const {
+    return is_root ? "root" : instruction->name();
+  }
+};
+
+}  // namespace hlo_diff
+}  // namespace xla
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_GRAPH_HLO_GUMGRAPH_NODE_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_test.cc
new file mode 100644
index 00000000000000..b022b51ea3fee1
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_test.cc
@@ -0,0 +1,523 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+
+#include <memory>
+#include <string>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+using ::testing::Field;
+using ::testing::FieldsAre;
+using ::testing::Pair;
+using ::testing::Pointee;
+using ::testing::Property;
+using ::testing::UnorderedElementsAre;
+
+class HloGumgraphTest : public HloTestBase {};
+
+const HloInstructionNode* SelectNodeByName(const HloGumgraph& graph,
+                                           absl::string_view name) {
+  const HloInstructionNode* result = nullptr;
+  for (const auto* node : graph.AllNodes()) {
+    if (!node->is_root && node->instruction->name() == name) {
+      result = node;
+      break;
+    }
+  }
+  return result;
+}
+
+// Returns true if the subgraph fingerprints of the roots are the same.
+bool FingerprintEqualTo(const HloGumgraph& first, const HloGumgraph& second) { + return first.GetRoot().props.subgraph_fingerprint == + second.GetRoot().props.subgraph_fingerprint; +} + +void AssertNode(const HloInstructionNode* actual_node, + absl::string_view expected_node_name, int expected_num_children, + int expected_num_parents) { + EXPECT_EQ(actual_node->instruction->name(), expected_node_name); + ASSERT_EQ(actual_node->children.size(), expected_num_children); + ASSERT_EQ(actual_node->parents.size(), expected_num_parents); +} + +TEST_F(HloGumgraphTest, CreateSimpleHloModuleWithoutFusionInstructionWorks) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | Add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz) +} +)")); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + + const auto* entry = graph->GetRoot().children[0]; + ASSERT_NO_FATAL_FAILURE(AssertNode(entry, "add_0", 2, 1)); + ASSERT_NO_FATAL_FAILURE(AssertNode(entry->children[0], "add_1", 2, 1)); + ASSERT_NO_FATAL_FAILURE(AssertNode(entry->children[1], "baz", 0, 1)); + ASSERT_NO_FATAL_FAILURE( + AssertNode(entry->children[0]->children[0], "foo", 0, 1)); + ASSERT_NO_FATAL_FAILURE( + AssertNode(entry->children[0]->children[1], "bar", 0, 1)); + + EXPECT_THAT( + graph->AllComputationProps(), + UnorderedElementsAre(Pair( + Pointee(Property(&HloComputation::name, "entry")), + Field(&CallGraphNodeProps::fingerprint, 10150663182810228731U)))); +} + +TEST_F(HloGumgraphTest, CreateHloModuleWithFusionInstructionWorks) { + // Create a module with entry computation containing the following structure: + // [Param p0] ---> [Param p2] ---> ┌-------┐ ┌----------┐ ┌------┐ + // | add.1 | ---> | fusion.1 | ---> | ROOT | + // [Param p1] ---> [Param p3] ---> └-------┘ └----------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +fused_computation.1 { + p2 = s32[32,16]{0,1:T(1,128)} parameter(0) + p3 = s32[32,16]{0,1:T(1,128)} parameter(1) + add.1 = s32[32,16]{0,1:T(1,128)} add(p2, p3) +} + +ENTRY entry { + p0 = s32[32,16]{0, 1:T(1,128)} parameter(0) + p1 = s32[32,16]{0,1:T(1,128)} parameter(1) + ROOT fusion.1 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.1 +} +)")); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + + const auto* entry = graph->GetRoot().children[0]; + ASSERT_NO_FATAL_FAILURE(AssertNode(entry, "fusion.1", 1, 1)); + ASSERT_NO_FATAL_FAILURE(AssertNode(entry->children[0], "add.1", 2, 1)); + ASSERT_NO_FATAL_FAILURE( + AssertNode(entry->children[0]->children[0], "p2", 1, 1)); + ASSERT_NO_FATAL_FAILURE( + AssertNode(entry->children[0]->children[0]->children[0], "p0", 0, 1)); + + EXPECT_THAT( + graph->AllComputationProps(), + UnorderedElementsAre( + Pair(Pointee(Property(&HloComputation::name, "entry")), + Field(&CallGraphNodeProps::fingerprint, 17918193494741257405U)), + Pair( + 
Pointee(Property(&HloComputation::name, "fused_computation.1")), + Field(&CallGraphNodeProps::fingerprint, 18256571801256786953U)))); +} + +TEST_F(HloGumgraphTest, CreateHloModuleWithConditionalInstructionWorks) { + // Create a module with entry computation containing the following structure: + // [constant.2] ---> [y] ---> [identity] ---> ┌-------------┐ + // | | ┌------┐ + // [constant.1] ---> [x] ---> [negate] -----> | conditional | ---> | ROOT | + // [constant] ------------------------------> └-------------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +Negate { + x = f32[] parameter(0) + ROOT negate = f32[] negate(x) +} + +Identity { + y = f32[] parameter(0) + ROOT identity = f32[] copy(y) +} + +ENTRY entry { + constant = pred[] constant(true) + constant.1 = f32[] constant(56) + constant.2 = f32[] constant(12) + ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity +} +)")); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + + const auto* entry = graph->GetRoot().children[0]; + ASSERT_NO_FATAL_FAILURE(AssertNode(entry, "conditional", 3, 1)); + ASSERT_NO_FATAL_FAILURE(AssertNode(entry->children[2], "identity", 1, 1)); + ASSERT_NO_FATAL_FAILURE(AssertNode(entry->children[1], "negate", 1, 1)); + ASSERT_NO_FATAL_FAILURE(AssertNode(entry->children[0], "constant", 0, 1)); + ASSERT_NO_FATAL_FAILURE( + AssertNode(entry->children[2]->children[0], "y", 1, 1)); + ASSERT_NO_FATAL_FAILURE( + AssertNode(entry->children[1]->children[0], "x", 1, 1)); + ASSERT_NO_FATAL_FAILURE(AssertNode( + entry->children[2]->children[0]->children[0], "constant.2", 0, 1)); + ASSERT_NO_FATAL_FAILURE(AssertNode( + entry->children[1]->children[0]->children[0], "constant.1", 0, 1)); + + EXPECT_THAT( + graph->AllComputationProps(), + UnorderedElementsAre( + Pair(Pointee(Property(&HloComputation::name, "entry")), + Field(&CallGraphNodeProps::fingerprint, 9646443073508437215U)), + Pair(Pointee(Property(&HloComputation::name, "Identity")), + Field(&CallGraphNodeProps::fingerprint, 7593821242743477274U)), + Pair( + Pointee(Property(&HloComputation::name, "Negate")), + Field(&CallGraphNodeProps::fingerprint, 11882609566947793238U)))); +} + +TEST_F(HloGumgraphTest, PreComputationsWorksWithoutShapeInFingerprint) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | Add_1 | + // ┌------------┐ ---> └-------┘ ---> ┌-------┐ ┌------┐ + // |Constant bar| | add_0 | ---> | ROOT | + // └------------┘ ------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, bar) +} +)")); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr graph, + HloGumgraph::Create(module.get(), {.ignore_shape = true})); + + const auto* entry = graph->GetRoot().children[0]; + EXPECT_THAT( + entry->props, + FieldsAre( + /*generation=*/1, + /*height=*/3, /*subgraph_fingerprint=*/8543065396480500811U, + /*fingerprint=*/7968662072287666665U, + /*sibling_position=*/FieldsAre(/*index=*/0, /*size=*/1), + /*pre_order_graph_position=*/FieldsAre(/*index=*/1, /*size=*/5))); + EXPECT_THAT( + entry->children[0]->props, 
+ FieldsAre( + /*generation=*/2, + /*height=*/2, /*subgraph_fingerprint=*/12467718903949982030U, + /*fingerprint=*/7968662072287666665U, + /*sibling_position=*/FieldsAre(/*index=*/0, /*size=*/1), + /*pre_order_graph_position=*/FieldsAre(/*index=*/3, /*size=*/5))); + EXPECT_THAT( + entry->children[1]->props, + FieldsAre( + /*generation=*/3, + /*height=*/1, /*subgraph_fingerprint=*/3183718271480206887U, + /*fingerprint=*/3183718271480206887U, + /*sibling_position=*/FieldsAre(/*index=*/1, /*size=*/2), + /*pre_order_graph_position=*/FieldsAre(/*index=*/2, /*size=*/5))); + EXPECT_THAT( + entry->children[0]->children[0]->props, + FieldsAre( + /*generation=*/3, + /*height=*/1, /*subgraph_fingerprint=*/856105463456541506U, + /*fingerprint=*/856105463456541506U, + /*sibling_position=*/FieldsAre(/*index=*/0, /*size=*/2), + /*pre_order_graph_position=*/FieldsAre(/*index=*/4, /*size=*/5))); + + EXPECT_THAT( + graph->AllComputationProps(), + UnorderedElementsAre( + Pair(Pointee(Property(&HloComputation::name, "entry")), + Field(&CallGraphNodeProps::fingerprint, 8543065396480500811U)))); +} + +TEST_F(HloGumgraphTest, PreComputationsWorksWithShapeInFingerprint) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | Add_1 | + // ┌------------┐ ---> └-------┘ ---> ┌-------┐ ┌------┐ + // |Constant bar| | add_0 | ---> | ROOT | + // └------------┘ ------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, bar) +} +)")); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr graph, + HloGumgraph::Create(module.get(), {.ignore_shape = false})); + + const auto* entry = graph->GetRoot().children[0]; + EXPECT_THAT( + entry->props, + FieldsAre( + /*generation=*/1, + /*height=*/3, /*subgraph_fingerprint=*/11491866794545709423U, + /*fingerprint=*/13023796333337170182U, + /*sibling_position=*/FieldsAre(/*index=*/0, /*size=*/1), + /*pre_order_graph_position=*/FieldsAre(/*index=*/1, /*size=*/5))); + + EXPECT_THAT( + entry->children[0]->props, + FieldsAre( + /*generation=*/2, + /*height=*/2, /*subgraph_fingerprint=*/11413025457497517292U, + /*fingerprint=*/13023796333337170182U, + /*sibling_position=*/FieldsAre(/*index=*/0, /*size=*/1), + /*pre_order_graph_position=*/FieldsAre(/*index=*/3, /*size=*/5))); + EXPECT_THAT( + entry->children[1]->props, + FieldsAre( + /*generation=*/3, + /*height=*/1, /*subgraph_fingerprint=*/18045659843081992748U, + /*fingerprint=*/18045659843081992748U, + /*sibling_position=*/FieldsAre(/*index=*/1, /*size=*/2), + /*pre_order_graph_position=*/FieldsAre(/*index=*/2, /*size=*/5))); + EXPECT_THAT( + entry->children[0]->children[0]->props, + FieldsAre( + /*generation=*/3, + /*height=*/1, /*subgraph_fingerprint=*/7851455295828926644U, + /*fingerprint=*/7851455295828926644U, + /*sibling_position=*/FieldsAre(/*index=*/0, /*size=*/2), + /*pre_order_graph_position=*/FieldsAre(/*index=*/4, /*size=*/5))); + + EXPECT_THAT( + graph->AllComputationProps(), + UnorderedElementsAre(Pair( + Pointee(Property(&HloComputation::name, "entry")), + Field(&CallGraphNodeProps::fingerprint, 11491866794545709423U)))); +} + +TEST_F(HloGumgraphTest, PreComputationsWorksMultiRoot) { + // Create a module with entry computation containing the following structure: 
+ // ┌--------┐ ┌-----------┐ + // ┌-----------┐ -----> | recv | --------> | recv-done | ---> ┌------┐ + // | after-all | └--------┘ └-----------┘ | ROOT | + // └-----------┘ -----> ┌--------┐ ┌-----------┐ ---> └------┘ + // ┌----------┐ | send | --------> | send-done | + // | constant | ------> └--------┘ └-----------┘ + // └----------┘ + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule( + R"(HloModule TwoSendRecvBothWayRecvFist_module, entry_computation_layout={()->(f32[], token[])} + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 + ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 + %constant = f32[] constant(2.1) + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%recv} + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 +} + +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + + EXPECT_EQ(SelectNodeByName(*graph, "recv")->props.generation, 2); + EXPECT_EQ(SelectNodeByName(*graph, "recv-done")->props.generation, 1); + EXPECT_EQ(SelectNodeByName(*graph, "send")->props.generation, 2); + EXPECT_EQ(SelectNodeByName(*graph, "send-done")->props.generation, 1); + EXPECT_EQ(SelectNodeByName(*graph, "token0")->props.generation, 3); + EXPECT_EQ(SelectNodeByName(*graph, "constant")->props.generation, 3); +} + +TEST_F(HloGumgraphTest, PreComputationsWorksSubgraphFingerprint) { + // Create left module with entry computation containing the following + // structure: + // [Const 0] ---> ┌-------┐ + // | add_0 | + // [Const 1] ---> └-------┘ ---> ┌-------┐ ┌------┐ + // | add_3 | ---> | ROOT | + // [Const 2] ---> ┌-------┐ ---> └-------┘ └------┘ + // | add_1 | + // [Const 3] ---> └-------┘ + // + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_l, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + constant.2 = f32[] constant(0) + constant.3 = f32[] constant(0) + add.0 = f32[] add(constant.0, constant.1) + add.1 = f32[] add(constant.2, constant.3) + add.3 = f32[] add(add.0, add.1) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_l, + HloGumgraph::Create(module_l.get())); + + // Create right module with entry computation containing the following + // structure: + // [Const 0] ---> ┌-------┐ + // | add_0 | + // ┌-------┐ ---> └-------┘ ---> ┌-------┐ ┌------┐ + // |Const 1| | add_3 | ---> | ROOT | + // └-------┘ ---> ┌-------┐ ---> └-------┘ └------┘ + // | add_1 | + // [Const 3] ---> └-------┘ + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_r, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + constant.3 = f32[] constant(0) + add.0 = f32[] add(constant.0, constant.1) + add.1 = f32[] add(constant.1, constant.3) + add.3 = f32[] add(add.0, add.1) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_r, + HloGumgraph::Create(module_r.get())); + + // TODO(b/365855856): The subgraph fingerprint should not be the same. 
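+  // Both roots combine identical operand fingerprints: every constant here is
+  // f32[] constant(0), so duplicating a constant on the left vs. sharing one
+  // on the right is invisible to the hash.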
+ // EXPECT_NE(graph_l->GetRoot().props.subgraph_fingerprint, + // graph_r->GetRoot().props.subgraph_fingerprint); + EXPECT_EQ(graph_l->GetRoot().props.subgraph_fingerprint, + graph_r->GetRoot().props.subgraph_fingerprint); +} + +using HloGumgraphDeathTest = HloGumgraphTest; + +TEST_F(HloGumgraphDeathTest, CreateWithNullHloModuleFails) { + // The `hlo_module` parameter is annotated nonnull, but we want to test the + // defensive null check. Use a variable instead of passing nullptr directly + // to avoid a `-Wnonnull` warning. + HloModule* null_hlo_module = nullptr; + ASSERT_DEATH(auto unused = HloGumgraph::Create(null_hlo_module), ""); +} + +TEST_F(HloGumgraphDeathTest, CreateWithNullEntryComputationFails) { + HloModule hlo_module("module", HloModuleConfig()); + + ASSERT_DEATH(auto unused = HloGumgraph::Create(&hlo_module), ""); +} + +TEST_F(HloGumgraphTest, CheckEqualityForIdenticalGraphs) { + // Create two identical modules with entry computation containing the + // following structure: + // [Param foo] ------> ┌-------┐ + // | Add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + const auto* hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto first_module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto second_module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr first_graph, + HloGumgraph::Create(first_module.get())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr second_graph, + HloGumgraph::Create(second_module.get())); + + EXPECT_TRUE(FingerprintEqualTo(*first_graph, *second_graph)); +} + +TEST_F(HloGumgraphTest, CheckEqualityForDifferentGraphs) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | Add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(auto first_module, ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz) +} +)")); + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | Add_1 | ---> ┌------------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | subtract_0 | ---> | ROOT | + // [Param baz] ---------------------> └------------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(auto second_module, ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + subtract_0 = f32[8,2048]{1,0:T(8,128)} subtract(add_1, baz) +} +)")); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr first_graph, + HloGumgraph::Create(first_module.get())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr second_graph, + 
HloGumgraph::Create(second_module.get())); + + EXPECT_FALSE(FingerprintEqualTo(*first_graph, *second_graph)); +} + +} // namespace +} // namespace hlo_diff +} // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/BUILD new file mode 100644 index 00000000000000..b9bd8229ac7fa8 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/BUILD @@ -0,0 +1,62 @@ +load("//xla:xla.default.bzl", "xla_cc_test") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//xla/hlo/tools/hlo_diff:__subpackages__", + ], + licenses = ["notice"], +) + +cc_library( + name = "hlo_gumgraph_bfs", + srcs = ["hlo_gumgraph_bfs.cc"], + hdrs = ["hlo_gumgraph_bfs.h"], + deps = [ + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "hlo_gumgraph_dfs", + srcs = ["hlo_gumgraph_dfs.cc"], + hdrs = ["hlo_gumgraph_dfs.h"], + deps = [ + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node", + "@com_google_absl//absl/functional:function_ref", + ], +) + +xla_cc_test( + name = "hlo_gumgraph_dfs_test", + srcs = ["hlo_gumgraph_dfs_test.cc"], + deps = [ + ":hlo_gumgraph_dfs", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +xla_cc_test( + name = "hlo_gumgraph_bfs_test", + srcs = ["hlo_gumgraph_bfs_test.cc"], + deps = [ + ":hlo_gumgraph_bfs", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.cc new file mode 100644 index 00000000000000..83375acc105084 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.cc @@ -0,0 +1,108 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.h"
+
+#include <cstdint>
+#include <queue>
+#include <vector>
+
+#include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
+#include "absl/types/span.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+// The visited set is a bit vector: bit (node_index % 64) of
+// visited[node_index / 64] records whether the node with the given
+// unique_node_index has been visited.
+bool GetVisited(const std::vector<uint64_t>& visited, int node_index) {
+  int index = node_index / 64;
+  CHECK_LT(index, visited.size());
+  return visited[index] & (1ull << (node_index % 64));
+}
+
+void SetVisited(std::vector<uint64_t>& visited, int node_index) {
+  int index = node_index / 64;
+  CHECK_LT(index, visited.size());
+  visited[index] |= (1ull << (node_index % 64));
+}
+}  // namespace
+
+void HloGumgraphBfs(
+    absl::Span<const HloInstructionNode* const> start_nodes,
+    absl::FunctionRef<bool(const HloInstructionNode&)> per_node_fn,
+    BfsTraversalDirection direction, int node_limit,
+    absl::FunctionRef<bool(const HloInstructionNode&)> expand_node_fn) {
+  std::queue<const HloInstructionNode*> nodes_to_expand;
+  std::vector<uint64_t> visited((node_limit + 63) / 64, 0);
+
+  for (const HloInstructionNode* start_node : start_nodes) {
+    CHECK(start_node != nullptr) << "Expected a non-null root node";
+    if (!per_node_fn(*start_node)) {
+      return;
+    }
+    if (expand_node_fn(*start_node)) {
+      nodes_to_expand.push(start_node);
+    }
+    SetVisited(visited, start_node->unique_node_index);
+  }
+
+  while (!nodes_to_expand.empty()) {
+    const HloInstructionNode* current_node = nodes_to_expand.front();
+    nodes_to_expand.pop();
+
+    std::vector<HloInstructionNode*> adjacent_nodes =
+        direction == BfsTraversalDirection::kForward ? current_node->children
+                                                     : current_node->parents;
+
+    for (auto* adjacent_node : adjacent_nodes) {
+      if (!GetVisited(visited, adjacent_node->unique_node_index)) {
+        if (!per_node_fn(*adjacent_node)) {
+          return;
+        }
+        if (expand_node_fn(*adjacent_node)) {
+          nodes_to_expand.push(adjacent_node);
+        }
+        SetVisited(visited, adjacent_node->unique_node_index);
+      }
+    }
+  }
+}
+
+void HloGumgraphBfs(
+    const HloInstructionNode& start_node,
+    absl::FunctionRef<bool(const HloInstructionNode&)> per_node_fn,
+    BfsTraversalDirection direction, int node_limit,
+    absl::FunctionRef<bool(const HloInstructionNode&)> expand_node_fn) {
+  HloGumgraphBfs(std::vector<const HloInstructionNode*>({&start_node}),
+                 per_node_fn, direction, node_limit, expand_node_fn);
+}
+
+std::vector<const HloInstructionNode*> GetAllNodesInBfsOrder(
+    const HloInstructionNode& root, BfsTraversalDirection direction,
+    int node_limit) {
+  std::vector<const HloInstructionNode*> subgraph;
+  HloGumgraphBfs(
+      root,
+      [&](const HloInstructionNode& node) {
+        subgraph.push_back(&node);
+        return true;
+      },
+      direction, node_limit);
+  return subgraph;
+}
+
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.h b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.h
new file mode 100644
index 00000000000000..aab6da1c30529e
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_GRAPH_UTILS_HLO_GUMGRAPH_BFS_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_GRAPH_UTILS_HLO_GUMGRAPH_BFS_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/functional/function_ref.h"
+#include "absl/types/span.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+
+namespace xla::hlo_diff {
+
+// Direction of the BFS traversal.
+enum class BfsTraversalDirection : std::int8_t { kForward, kReverse };
+
+// Performs a breadth first search of the HLO module starting from the
+// specified instruction node and calls the supplied per-node function for
+// each visited node.
+//
+// If per_node_fn returns false for a node, the BFS traversal terminates
+// immediately.
+//
+// The BFS traversal is performed in the specified direction.
+// kForward: start from the start node and traverse forward to the node's
+// children.
+// kReverse: start from the start node and traverse backwards to the node's
+// parents.
+//
+// The node_limit parameter should be set to the number of nodes in the
+// HloGumgraph, as it is used to track the visit state of each node during
+// traversal.
+//
+// If expand_node_fn returns false for a node, the children of that node
+// will not be visited.
+void HloGumgraphBfs(
+    const HloInstructionNode& start_node,
+    absl::FunctionRef<bool(const HloInstructionNode&)> per_node_fn,
+    BfsTraversalDirection direction, int node_limit,
+    absl::FunctionRef<bool(const HloInstructionNode&)> expand_node_fn =
+        [](const HloInstructionNode&) { return true; });
+
+// Breadth first search from multiple start nodes. See the comment on
+// HloGumgraphBfs above for more details.
+void HloGumgraphBfs(
+    absl::Span<const HloInstructionNode* const> start_nodes,
+    absl::FunctionRef<bool(const HloInstructionNode&)> per_node_fn,
+    BfsTraversalDirection direction, int node_limit,
+    absl::FunctionRef<bool(const HloInstructionNode&)> expand_node_fn =
+        [](const HloInstructionNode&) { return true; });
+
+// Returns all nodes reachable from the given node in BFS order. See the
+// comment on HloGumgraphBfs above for more details.
+std::vector<const HloInstructionNode*> GetAllNodesInBfsOrder(
+    const HloInstructionNode& root, BfsTraversalDirection direction,
+    int node_limit = 100000);
+
+}  // namespace xla::hlo_diff
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_GRAPH_UTILS_HLO_GUMGRAPH_BFS_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs_test.cc
new file mode 100644
index 00000000000000..cee91741ca8db6
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs_test.cc
@@ -0,0 +1,246 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.h" + +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { +namespace hlo_diff { +namespace { + +using ::testing::ElementsAre; + +class HloGumgraphBfsTest : public HloTestBase {}; + +TEST_F(HloGumgraphBfsTest, BfsForwardWorks) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( + HloModule module, is_scheduled=true + + ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz) + } + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + + std::vector visited_nodes; + const auto root = graph->GetRoot(); + HloGumgraphBfs( + root, + [&visited_nodes](const HloInstructionNode& node) { + visited_nodes.push_back(node.GetName()); + return true; + }, + BfsTraversalDirection::kForward, graph->GetNodeCount()); + + EXPECT_THAT(visited_nodes, + ElementsAre("root", "add_0", "add_1", "baz", "foo", "bar")); +} + +TEST_F(HloGumgraphBfsTest, BfsReverseWorks) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ ┌-------┐ ┌------┐ + // | add_0 | ---> | abs_0 | ---> | ROOT | + // [Constant bar] ---> └-------┘ └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( + HloModule module, is_scheduled=true + + ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + add_0 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + abs_0 = f32[8,2048]{1,0:T(8,128)} abs(f32[8,2048]{1,0:T(8,128)} %add_0) + } + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + + std::vector visited_nodes; + const auto root = graph->GetRoot(); + HloGumgraphBfs( + *root.children[0]->children[0]->children[0], + [&visited_nodes](const HloInstructionNode& node) { + visited_nodes.push_back(node.GetName()); + return true; + }, + BfsTraversalDirection::kReverse, graph->GetNodeCount()); + + EXPECT_THAT(visited_nodes, ElementsAre("foo", "add_0", "abs_0", "root")); +} + +TEST_F(HloGumgraphBfsTest, GetAllNodesWorks) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( + HloModule module, is_scheduled=true + + ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz) + } + )")); + 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + const auto root = graph->GetRoot(); + std::vector visited_nodes = + GetAllNodesInBfsOrder(root, BfsTraversalDirection::kForward); + std::vector string_views; + string_views.reserve(visited_nodes.size()); + for (const HloInstructionNode* node : visited_nodes) { + string_views.push_back(node->GetName()); + } + + EXPECT_THAT(string_views, + ElementsAre("root", "add_0", "add_1", "baz", "foo", "bar")); +} + +TEST_F(HloGumgraphBfsTest, BfsFromMultipleNodesWorks) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( + HloModule module, is_scheduled=true + + ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz) + } + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + + std::vector visited_nodes; + const auto root = graph->GetRoot(); + HloGumgraphBfs( + std::vector{root.children[0]->children[0], + root.children[0]}, + [&visited_nodes](const HloInstructionNode& node) { + visited_nodes.push_back(node.GetName()); + return true; + }, + BfsTraversalDirection::kForward, graph->GetNodeCount()); + + EXPECT_THAT(visited_nodes, + ElementsAre("add_1", "add_0", "foo", "bar", "baz")); +} + +TEST_F(HloGumgraphBfsTest, BfsStopExpandingWorks) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( + HloModule module, is_scheduled=true + + ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz) + } + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + + std::vector visited_nodes; + HloGumgraphBfs( + graph->GetRoot(), + [&](const HloInstructionNode& node) { + visited_nodes.push_back(node.GetName()); + return true; + }, + BfsTraversalDirection::kForward, 6, + [&](const HloInstructionNode& node) { + return node.GetName() != "add_1"; + }); + + EXPECT_THAT(visited_nodes, ElementsAre("root", "add_0", "add_1", "baz")); +} + +TEST_F(HloGumgraphBfsTest, BfsEarlyTerminationWorks) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( + HloModule module, is_scheduled=true + + ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} 
add(foo, bar)
+    add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz)
+  }
+  )"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph,
+                          HloGumgraph::Create(module.get()));
+
+  std::vector<absl::string_view> visited_nodes;
+  // This is an example of using the per_node_fn return value to stop a BFS
+  // traversal early, here after visiting 5 nodes.
+  HloGumgraphBfs(
+      graph->GetRoot(),
+      [&](const HloInstructionNode& node) {
+        visited_nodes.push_back(node.GetName());
+        return visited_nodes.size() < 5;
+      },
+      BfsTraversalDirection::kForward, graph->GetNodeCount());
+
+  EXPECT_THAT(visited_nodes,
+              ElementsAre("root", "add_0", "add_1", "baz", "foo"));
+}
+}  // namespace
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.cc
new file mode 100644
index 00000000000000..b02afcb74db78f
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.cc
@@ -0,0 +1,81 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/functional/function_ref.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+
+namespace xla::hlo_diff {
+namespace {
+
+enum class VisitState : uint8_t { kNew = 0, kVisiting = 1, kVisited = 2 };
+
+}  // namespace
+
+void HloGumgraphDfs(
+    const HloInstructionNode& start_node,
+    absl::FunctionRef<void(const HloInstructionNode&)> per_node_fn,
+    DfsTraversalOrder order, int node_limit,
+    absl::FunctionRef<bool(const HloInstructionNode&)> expand_node_fn) {
+  std::vector<VisitState> visited(node_limit);
+
+  std::vector<const HloInstructionNode*> stack = {&start_node};
+
+  while (!stack.empty()) {
+    const HloInstructionNode* node = stack.back();
+    VisitState& visit_state = visited[node->unique_node_index];
+
+    if (visit_state == VisitState::kNew) {
+      visit_state = VisitState::kVisiting;
+      if (order == DfsTraversalOrder::kPreOrder) {
+        per_node_fn(*node);
+      }
+    } else {
+      stack.pop_back();
+      if (visit_state == VisitState::kVisiting) {
+        visit_state = VisitState::kVisited;
+        if (order == DfsTraversalOrder::kPostOrder) {
+          per_node_fn(*node);
+        }
+      }
+      continue;
+    }
+
+    if (!expand_node_fn(*node)) {
+      continue;
+    }
+    for (auto* child : node->children) {
+      if (visited[child->unique_node_index] == VisitState::kNew) {
+        stack.push_back(child);
+      } else {
+        // Already fully visited, no need to visit again.
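+        // (A child in state kVisiting here would indicate a back edge, i.e. a
+        // cycle; the HloGumgraph is expected to be a DAG.)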
+      }
+    }
+  }
+}
+
+std::vector<const HloInstructionNode*> GetAllNodesInDfsOrder(
+    const HloInstructionNode& root, DfsTraversalOrder order, int node_limit) {
+  std::vector<const HloInstructionNode*> subgraph;
+  HloGumgraphDfs(
+      root, [&](const HloInstructionNode& node) { subgraph.push_back(&node); },
+      order, node_limit);
+  return subgraph;
+}
+
+}  // namespace xla::hlo_diff
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.h b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.h
new file mode 100644
index 00000000000000..c49a648f51d6c1
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_GRAPH_UTILS_HLO_GUMGRAPH_DFS_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_GRAPH_UTILS_HLO_GUMGRAPH_DFS_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/functional/function_ref.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+
+namespace xla::hlo_diff {
+
+// DFS traversal order: pre-order or post-order.
+enum class DfsTraversalOrder : std::int8_t { kPreOrder, kPostOrder };
+
+// Performs a depth first search of the HLO module starting from the specified
+// instruction node and calls the supplied per-node function for each visited
+// node.
+//
+// The traversal order determines whether the per-node function is invoked
+// before or after the children of the node are visited, i.e. pre-order or
+// post-order traversal.
+//
+// The node_limit parameter should be set to the number of nodes in the
+// HloGumgraph, as it is used to track the visit state of each node during
+// traversal.
+//
+// If expand_node_fn returns false for a node, the children of that node
+// will not be visited.
+void HloGumgraphDfs(
+    const HloInstructionNode& start_node,
+    absl::FunctionRef<void(const HloInstructionNode&)> per_node_fn,
+    DfsTraversalOrder order, int node_limit,
+    absl::FunctionRef<bool(const HloInstructionNode&)> expand_node_fn =
+        [](const HloInstructionNode&) { return true; });
+
+// Returns all nodes in the HLO module in DFS order starting from the provided
+// root node. See the comment on HloGumgraphDfs above for more details.
+std::vector<const HloInstructionNode*> GetAllNodesInDfsOrder(
+    const HloInstructionNode& root, DfsTraversalOrder order, int node_limit);
+
+}  // namespace xla::hlo_diff
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_GRAPH_UTILS_HLO_GUMGRAPH_DFS_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs_test.cc
new file mode 100644
index 00000000000000..5888d0d82ef55a
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs_test.cc
@@ -0,0 +1,233 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.h" + +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { +namespace hlo_diff { +namespace { + +using ::testing::ElementsAre; + +class HloGumgraphDfsTest : public HloTestBase {}; + +TEST_F(HloGumgraphDfsTest, DfsPreOrderWorks) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + const auto root = graph->GetRoot(); + std::vector visited_nodes; + HloGumgraphDfs( + root, + [&](const HloInstructionNode& node) { + visited_nodes.push_back(node.GetName()); + }, + DfsTraversalOrder::kPreOrder, graph->GetNodeCount()); + + EXPECT_THAT(visited_nodes, + ElementsAre("root", "add_0", "baz", "add_1", "bar", "foo")); +} + +TEST_F(HloGumgraphDfsTest, DfsPostOrderWorks) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ + // | add_1 | ---> ┌-------┐ ┌------┐ + // [Constant bar] ---> └-------┘ | add_0 | ---> | ROOT | + // [Param baz] ---------------------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + foo = f32[8,2048]{1,0:T(8,128)} parameter(0) + bar = f32[8,2048]{1,0:T(8,128)} constant(0) + baz = f32[8,2048]{1,0:T(8,128)} parameter(1) + add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar) + add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, + HloGumgraph::Create(module.get())); + const auto root = graph->GetRoot(); + std::vector visited_nodes; + HloGumgraphDfs( + root, + [&](const HloInstructionNode& node) { + visited_nodes.push_back(node.GetName()); + }, + DfsTraversalOrder::kPostOrder, graph->GetNodeCount()); + + EXPECT_THAT(visited_nodes, + ElementsAre("baz", "bar", "foo", "add_1", "add_0", "root")); +} + +TEST_F(HloGumgraphDfsTest, DfsPostOrderWorksForMultiplePathsFromRoot) { + // Create a module with entry computation containing the following structure: + // [Param foo] ------> ┌-------┐ ┌------┐ + // | | add_1 | ---> | ROOT | + // [copy_foo] -------> └-------┘ └------┘ + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, 
+TEST_F(HloGumgraphDfsTest, DfsPostOrderWorksForMultiplePathsFromRoot) {
+  // Create a module with entry computation containing the following structure:
+  // [Param foo] ------> ┌-------┐      ┌------┐
+  //                     | add_1 | ---> | ROOT |
+  // [copy_foo] -------> └-------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  foo = f32[8,2048]{1,0:T(8,128)} parameter(0)
+  copy_foo = f32[8,2048]{1,0:T(8,128)} copy(foo)
+  add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, copy_foo)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph,
+                          HloGumgraph::Create(module.get()));
+  const auto& root = graph->GetRoot();
+  std::vector<absl::string_view> visited_nodes;
+  HloGumgraphDfs(
+      root,
+      [&](const HloInstructionNode& node) {
+        visited_nodes.push_back(node.GetName());
+      },
+      DfsTraversalOrder::kPostOrder, graph->GetNodeCount());
+
+  EXPECT_THAT(visited_nodes, ElementsAre("foo", "copy_foo", "add_1", "root"));
+}
+
+TEST_F(HloGumgraphDfsTest, GetAllNodesWorks) {
+  // Create a module with entry computation containing the following structure:
+  // [Param foo] ------> ┌-------┐
+  //                     | add_1 | ---> ┌-------┐      ┌------┐
+  // [Constant bar] ---> └-------┘      | add_0 | ---> | ROOT |
+  // [Param baz] ---------------------> └-------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  foo = f32[8,2048]{1,0:T(8,128)} parameter(0)
+  bar = f32[8,2048]{1,0:T(8,128)} constant(0)
+  baz = f32[8,2048]{1,0:T(8,128)} parameter(1)
+  add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar)
+  add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph,
+                          HloGumgraph::Create(module.get()));
+  const auto& root = graph->GetRoot();
+  std::vector<const HloInstructionNode*> visited_nodes = GetAllNodesInDfsOrder(
+      root, DfsTraversalOrder::kPreOrder, graph->GetNodeCount());
+  std::vector<absl::string_view> string_views;
+  string_views.reserve(visited_nodes.size());
+  for (const HloInstructionNode* node : visited_nodes) {
+    string_views.push_back(node->GetName());
+  }
+
+  EXPECT_THAT(string_views,
+              ElementsAre("root", "add_0", "baz", "add_1", "bar", "foo"));
+}
+
+TEST_F(HloGumgraphDfsTest, DfsPreOrderStopExpandingWorks) {
+  // Create a module with entry computation containing the following structure:
+  // [Param foo] ------> ┌-------┐
+  //                     | add_1 | ---> ┌-------┐      ┌------┐
+  // [Constant bar] ---> └-------┘      | add_0 | ---> | ROOT |
+  // [Param baz] ---------------------> └-------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  foo = f32[8,2048]{1,0:T(8,128)} parameter(0)
+  bar = f32[8,2048]{1,0:T(8,128)} constant(0)
+  baz = f32[8,2048]{1,0:T(8,128)} parameter(1)
+  add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar)
+  add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph,
+                          HloGumgraph::Create(module.get()));
+
+  std::vector<absl::string_view> visited_nodes;
+  HloGumgraphDfs(
+      graph->GetRoot(),
+      [&](const HloInstructionNode& node) {
+        visited_nodes.push_back(node.GetName());
+      },
+      DfsTraversalOrder::kPreOrder, 6,
+      [](const HloInstructionNode& node) { return node.GetName() != "add_1"; });
+
+  EXPECT_THAT(visited_nodes, ElementsAre("root", "add_0", "baz", "add_1"));
+}
+
+TEST_F(HloGumgraphDfsTest, DfsPostOrderStopExpandingWorks) {
+  // Create a module with entry computation containing the following structure:
+  // [Param foo] ------> ┌-------┐
+  //                     | add_1 | ---> ┌-------┐      ┌------┐
+  // [Constant bar] ---> └-------┘      | add_0 | ---> | ROOT |
+  // [Param baz] ---------------------> └-------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  foo = f32[8,2048]{1,0:T(8,128)} parameter(0)
+  bar = f32[8,2048]{1,0:T(8,128)} constant(0)
+  baz = f32[8,2048]{1,0:T(8,128)} parameter(1)
+  add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar)
+  add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph,
+                          HloGumgraph::Create(module.get()));
+
+  std::vector<absl::string_view> visited_nodes;
+  HloGumgraphDfs(
+      graph->GetRoot(),
+      [&](const HloInstructionNode& node) {
+        visited_nodes.push_back(node.GetName());
+      },
+      DfsTraversalOrder::kPostOrder, 6,
+      [](const HloInstructionNode& node) { return node.GetName() != "add_1"; });
+
+  EXPECT_THAT(visited_nodes, ElementsAre("baz", "add_1", "add_0", "root"));
+}
+
+}  // namespace
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval.cc
new file mode 100644
index 00000000000000..38b0553a37624e
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval.cc
@@ -0,0 +1,137 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/hlo_diff_eval.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/log.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+
+namespace xla::hlo_diff {
+namespace {
+
+// Counts split-allegiance computations in the diff result.
+// Split allegiance is defined as: two computation nodes share the same
+// fingerprint and some instructions are matched, but other instructions in
+// the left computation are matched to instructions in a different computation
+// node in the right graph.
+// Returns a pair: the number of split-allegiance computations, and the
+// accumulated minimum number of mismatched instructions inside them.
+std::pair<int64_t, int64_t> CountSplitAllegiance(
+    const HloGumgraph& left, const HloGumgraph& right,
+    const DiffSummary& diff_summary) {
+  int64_t split_allegiance_computation_count = 0;
+  int64_t split_allegiance_instruction_count = 0;
+  for (auto const& [computation, computation_props] :
+       left.AllComputationProps()) {
+    if (auto it = diff_summary.computation_summary.find(computation);
+        it != diff_summary.computation_summary.end()) {
+      const ComputationSummary& cmi = it->second;
+      if (cmi.split_allegiance_instruction_count > 0 &&
+          right.AllComputationProps()
+                  .at(cmi.main_matched_computation)
+                  .fingerprint == computation_props.fingerprint) {
+        ++split_allegiance_computation_count;
+        split_allegiance_instruction_count +=
+            cmi.split_allegiance_instruction_count;
+      }
+    }
+  }
+  return std::make_pair(split_allegiance_computation_count,
+                        split_allegiance_instruction_count);
+}
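
For intuition, here is a toy, single-computation version of the split-allegiance count above. It assumes instructions and computations are plain strings, and that `match_target_computation` records, for each matched left instruction, which right computation its partner lives in; the dominant target is the main matched computation, and every match outside it counts as one mismatched instruction.

#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <utility>

namespace sketch {

// Returns {1 if this left computation splits its allegiance else 0,
//          number of matches that fall outside the dominant target}.
std::pair<int64_t, int64_t> CountSplitAllegianceForOneComputation(
    const std::map<std::string, std::string>& match_target_computation) {
  std::map<std::string, int64_t> per_computation;
  int64_t total = 0;
  for (const auto& [left_instruction, right_computation] :
       match_target_computation) {
    ++per_computation[right_computation];
    ++total;
  }
  int64_t dominant = 0;
  for (const auto& [computation, count] : per_computation) {
    dominant = std::max(dominant, count);
  }
  // Allegiance is split iff matches land in more than one right computation;
  // the mismatched instructions are everything outside the dominant target.
  const bool split = per_computation.size() > 1;
  return {split ? 1 : 0, total - dominant};
}

}  // namespace sketch
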
+// Counts split-allegiance-parental child matches in the diff result.
+// Split allegiance parental is defined as: two nodes are matched and share
+// the same number of children and the same children opcodes, but some of
+// their children are not matched to each other.
+int64_t CountSplitAllegianceParental(const HloGumgraph& left,
+                                     const HloGumgraph& right,
+                                     const HloGumgraphMappings& mappings) {
+  int64_t count = 0;
+  for (const auto& it : mappings.left_to_right_instruction_map.left) {
+    if (it.first->children.size() != it.second->children.size()) {
+      continue;
+    }
+    bool children_opcode_mismatch = false;
+    for (int i = 0; i < it.first->children.size(); ++i) {
+      if (it.first->children[i]->instruction->opcode() !=
+          it.second->children[i]->instruction->opcode()) {
+        children_opcode_mismatch = true;
+        break;
+      }
+    }
+    if (children_opcode_mismatch) {
+      continue;
+    }
+    for (int i = 0; i < it.first->children.size(); ++i) {
+      if (auto cit = mappings.left_to_right_instruction_map.left.find(
+              it.first->children[i]);
+          cit == mappings.left_to_right_instruction_map.left.end() ||
+          cit->second != it.second->children[i]) {
+        count++;
+        // LOG(INFO) << it.first->instruction->name() << " has split child: "
+        //           << it.first->children[i]->instruction->name();
+      }
+    }
+  }
+  return count;
+}
+
+}  // namespace
+
+std::unique_ptr<DiffEval> ComputeDiffEval(
+    const HloGumgraph& left, const HloGumgraph& right,
+    const HloGumgraphMappings& mappings, const DiffResult& diff_result,
+    const DiffSummary& diff_summary) {
+  LOG(INFO) << "Evaluating diff result";
+  auto eval = std::make_unique<DiffEval>();
+  auto [split_allegiance_computation_count,
+        split_allegiance_instruction_count] =
+      CountSplitAllegiance(left, right, diff_summary);
+  eval->num_split_allegiance_computation = split_allegiance_computation_count;
+  eval->num_split_allegiance_instruction = split_allegiance_instruction_count;
+  eval->num_split_allegiance_parental =
+      CountSplitAllegianceParental(left, right, mappings);
+
+  eval->len_left_unmatched =
+      diff_result.left_module_unmatched_instructions.size();
+  eval->len_right_unmatched =
+      diff_result.right_module_unmatched_instructions.size();
+  eval->len_changed = diff_result.changed_instructions.size();
+  eval->len_unchanged = diff_result.unchanged_instructions.size();
+
+  eval->left_node_count = left.GetNodeCount();
+  eval->right_node_count = right.GetNodeCount();
+
+  return eval;
+}
+
+void LogDiffEval(const DiffEval& diff_eval) {
+  LOG(INFO) << "Split Allegiance Computation: "
+            << diff_eval.num_split_allegiance_computation;
+  LOG(INFO) << "Split Allegiance Instruction: "
+            << diff_eval.num_split_allegiance_instruction;
+  LOG(INFO) << "Split Allegiance Parental: "
+            << diff_eval.num_split_allegiance_parental;
+}
+
+}  // namespace xla::hlo_diff
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval.h b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval.h
new file mode 100644
index 00000000000000..620a1a5e1219ed
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_HLO_DIFF_EVAL_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_HLO_DIFF_EVAL_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+
+namespace xla::hlo_diff {
+
+// Evaluation metrics for the diff result.
+struct DiffEval {
+  // Split allegiance is defined as: two computation nodes share the same
+  // fingerprint and some instructions are matched, but other instructions in
+  // the left computation are matched to instructions in a different
+  // computation node in the right graph.
+  int64_t num_split_allegiance_computation = 0;
+  int64_t num_split_allegiance_instruction = 0;
+  // Split allegiance parental is defined as: two nodes are matched and share
+  // the same number of children and the same children opcodes, but some of
+  // their children are not matched to each other.
+  int64_t num_split_allegiance_parental = 0;
+
+  // Size of the diff result.
+  int64_t len_left_unmatched = 0;
+  int64_t len_right_unmatched = 0;
+  int64_t len_changed = 0;
+  int64_t len_unchanged = 0;
+
+  // Graph node counts.
+  int64_t left_node_count = 0;
+  int64_t right_node_count = 0;
+};
+
+// Computes the diff evaluation metrics.
+// `left` and `right` are the original graphs.
+// `mappings` are the node mappings between the two graphs.
+// `diff_result` contains the edit script (insert/delete/change/move) created
+// from the node mappings.
+// `diff_summary` summarizes the computation-based repeated diff patterns.
+std::unique_ptr<DiffEval> ComputeDiffEval(
+    const HloGumgraph& left, const HloGumgraph& right,
+    const HloGumgraphMappings& mappings, const DiffResult& diff_result,
+    const DiffSummary& diff_summary);
+
+// Logs the diff evaluation metrics.
+void LogDiffEval(const DiffEval& diff_eval);
+
+}  // namespace xla::hlo_diff
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_HLO_DIFF_EVAL_H_
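
A hypothetical call site for the API declared above, assuming the graphs, mappings, diff result, and summary were already produced by the matching pipeline (as hlo_diff_main.cc does later in this patch):

#include <memory>

#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
#include "xla/hlo/tools/hlo_diff/hlo_diff_eval.h"
#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"

namespace sketch {

// ReportEval is a hypothetical helper, not part of the patch: it computes the
// evaluation metrics and writes them to the log in one step.
void ReportEval(const xla::hlo_diff::HloGumgraph& left,
                const xla::hlo_diff::HloGumgraph& right,
                const xla::hlo_diff::HloGumgraphMappings& mappings,
                const xla::hlo_diff::DiffResult& diff_result,
                const xla::hlo_diff::DiffSummary& diff_summary) {
  std::unique_ptr<xla::hlo_diff::DiffEval> eval = xla::hlo_diff::ComputeDiffEval(
      left, right, mappings, diff_result, diff_summary);
  xla::hlo_diff::LogDiffEval(*eval);
}

}  // namespace sketch
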
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval_test.cc
new file mode 100644
index 00000000000000..c5aed19f50814c
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval_test.cc
@@ -0,0 +1,206 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/hlo_diff_eval.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/hlo/tools/hlo_diff/utils/test_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+class HloDiffTest : public HloTestBase {};
+
+TEST_F(HloDiffTest, SplitAllegianceWorks) {
+  // Create two similar modules with entry computation containing the following
+  // structure:
+  // [Param p0]->[Param p2]->┌-------┐   ┌----------┐
+  //                         | add.1 |-->| fusion.1 |-->┌-------┐
+  // [Param p1]->[Param p3]->└-------┘   └----------┘   |       |   ┌------┐
+  //                                                    | add.3 |-->| ROOT |
+  // [Param p4]->[Param p6]->┌-------┐   ┌----------┐   |       |   └------┘
+  //                         | add.2 |-->| fusion.2 |-->└-------┘
+  // [Param p5]->[Param p7]->└-------┘   └----------┘
+  const char* hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  fused_computation.1 {
+    p2 = s32[32,16]{0,1:T(1,128)} parameter(0)
+    p3 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    add.1 = s32[32,16]{0,1:T(1,128)} add(p2, p3)
+  }
+
+  fused_computation.2 {
+    p6 = s32[32,16]{0,1:T(1,128)} parameter(0)
+    p7 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    add.2 = s32[32,16]{0,1:T(1,128)} add(p6, p7)
+  }
+
+  ENTRY entry {
+    p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+    p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    p4 = s32[32,16]{0, 1:T(1,128)} parameter(2)
+    p5 = s32[32,16]{0,1:T(1,128)} parameter(3)
+    fusion.1 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.1
+    fusion.2 = s32[32,16]{0,1:T(1,128)} fusion(p4,p5), kind=kLoop, calls=fused_computation.2
+    ROOT add.3 = s32[32,16]{0,1:T(1,128)} add(fusion.1, fusion.2)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  HloGumgraphMappings mappings;
+  // Map all nodes with the same name and then switch the mappings for add.1
+  // and add.2.
+  mappings.MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(),
+                                   MatcherType::kManual);
+  MatchAllNodesByName(*graph_l, *graph_r, mappings);
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "add.1"),
+                               GetNodeByName(*graph_r, "add.2"), mappings));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "add.2"),
+                               GetNodeByName(*graph_r, "add.1"), mappings));
+  // Construct the diff eval from the manually mapped node mappings.
+  std::unique_ptr<DiffResult> diff_result =
+      ConstructDiffResult(*graph_l, *graph_r, mappings);
+  std::unique_ptr<DiffSummary> diff_summary =
+      ConstructDiffSummary(*graph_l, *graph_r, mappings, *diff_result);
+  std::unique_ptr<DiffEval> diff_eval = ComputeDiffEval(
+      *graph_l, *graph_r, mappings, *diff_result, *diff_summary);
+
+  EXPECT_EQ(diff_eval->num_split_allegiance_computation, 2);
+  EXPECT_EQ(diff_eval->num_split_allegiance_instruction, 2);
+  // The following pairs are split-allegiance parental: the parents are
+  // matched, but their children are not matched to each other
+  // (add.1 is matched to add.2):
+  // fusion.1
+  //   add.1 -> add.1
+  // fusion.2
+  //   add.2 -> add.2
+  // add.1
+  //   param2 -> param6
+  //   param3 -> param7
+  // add.2
+  //   param6 -> param2
+  //   param7 -> param3
+  EXPECT_EQ(diff_eval->num_split_allegiance_parental, 6);
+}
+
+TEST_F(HloDiffTest, GraphNodeCountsWork) {
+  // Create a module with entry computation containing the following structure:
+  // [Param p0]->[Param p2]->┌-------┐   ┌----------┐
+  //                         | add.1 |-->| fusion.1 |-->┌-------┐
+  // [Param p1]->[Param p3]->└-------┘   └----------┘   |       |   ┌------┐
+  //                                                    | add.3 |-->| ROOT |
+  // [Param p4]->[Param p6]->┌-------┐   ┌----------┐   |       |   └------┘
+  //                         | add.2 |-->| fusion.2 |-->└-------┘
+  // [Param p5]->[Param p7]->└-------┘   └----------┘
+  const char* hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  fused_computation.1 {
+    p2 = s32[32,16]{0,1:T(1,128)} parameter(0)
+    p3 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    add.1 = s32[32,16]{0,1:T(1,128)} add(p2, p3)
+  }
+
+  fused_computation.2 {
+    p6 = s32[32,16]{0,1:T(1,128)} parameter(0)
+    p7 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    add.2 = s32[32,16]{0,1:T(1,128)} add(p6, p7)
+  }
+
+  ENTRY entry {
+    p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+    p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    p4 = s32[32,16]{0, 1:T(1,128)} parameter(2)
+    p5 = s32[32,16]{0,1:T(1,128)} parameter(3)
+    fusion.1 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.1
+    fusion.2 = s32[32,16]{0,1:T(1,128)} fusion(p4,p5), kind=kLoop, calls=fused_computation.2
+    ROOT add.3 = s32[32,16]{0,1:T(1,128)} add(fusion.1, fusion.2)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  std::unique_ptr<DiffEval> diff_eval =
+      ComputeDiffEval(*graph_l, *graph_r, {}, {}, {});
+
+  EXPECT_EQ(diff_eval->left_node_count, 14);
+  EXPECT_EQ(diff_eval->right_node_count, 14);
+}
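
The 14 in the test above decomposes as follows, under the assumption (consistent with the DFS tests earlier in this patch) that HloGumgraph adds one artificial root node on top of the module's instructions. This is an editor's sanity check, not code from the patch:

// entry: p0, p1, p4, p5, fusion.1, fusion.2, add.3  -> 7 instructions
// fused_computation.1: p2, p3, add.1                -> 3 instructions
// fused_computation.2: p6, p7, add.2                -> 3 instructions
// artificial root node (assumed)                    -> 1
constexpr int kExpectedNodeCount = 7 + 3 + 3 + 1;
static_assert(kExpectedNodeCount == 14, "matches the EXPECT_EQ above");
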
+TEST_F(HloDiffTest, DiffSizeWorks) {
+  // Create a module with entry computation containing the following structure:
+  // [Param p0]->┌-------┐
+  //             | add.1 |
+  // [Param p1]->└-------┘
+  const char* hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  ENTRY entry {
+    p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+    p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    ROOT add.1 = s32[32,16]{0,1:T(1,128)} add(p0, p1)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  DiffResult diff_result;
+  diff_result.left_module_unmatched_instructions.push_back(
+      graph_l->GetRoot().instruction);
+  diff_result.right_module_unmatched_instructions.push_back(
+      graph_r->GetRoot().instruction);
+  diff_result.changed_instructions.insert(
+      {graph_l->GetRoot().instruction, graph_r->GetRoot().instruction});
+  diff_result.unchanged_instructions.insert(
+      {graph_l->GetRoot().instruction, graph_r->GetRoot().instruction});
+  std::unique_ptr<DiffEval> diff_eval =
+      ComputeDiffEval(*graph_l, *graph_r, {}, diff_result, {});
+
+  EXPECT_EQ(diff_eval->len_left_unmatched, 1);
+  EXPECT_EQ(diff_eval->len_right_unmatched, 1);
+  EXPECT_EQ(diff_eval->len_changed, 1);
+  EXPECT_EQ(diff_eval->len_unchanged, 1);
+}
+
+}  // namespace
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc
new file mode 100644
index 00000000000000..b38cd4ae5aed01
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc
@@ -0,0 +1,259 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include <memory>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/parser/hlo_parser.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_eval.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.h"
+#include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h"
+#include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_module_util.h"
+#include "xla/tsl/platform/env.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/status.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/tsl/util/command_line_flags.h"
+#include "tsl/platform/init_main.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+const char* const kUsage = R"(
+Given two HLO Modules, compares the graph structure of the two HLO Modules and
+summarizes the differences in a human-readable format. The tool focuses on
+computational differences, ignoring irrelevant changes such as instruction
+names, parameter ordering, and (in some instances) layouts.
+
+  Usage:
+  bazel run hlo_diff -- \
+  --{first_hlo_snapshot,first_hlo_proto,first_hlo_module_proto,first_hlo_text}=path/to/first/binary_proto
+  --{second_hlo_snapshot,second_hlo_proto,second_hlo_module_proto,second_hlo_text}=path/to/second/binary_proto
+  [--ignore_shape_during_instruction_matching]
+  [--text_output=path/to/file/to/save/text]
+  [--html_output=path/to/file/to/save/html]
+
+The first and second HLO file paths are required flags. Optionally, the
+following flags can be used:
+
+If --ignore_shape_during_instruction_matching is specified, the tool ignores
+array/tensor shapes when matching instructions, allowing for more permissive
+matches.
+If --text_output is specified, the full diff result will be printed in text
+format and saved to the specified file.
+If --html_output is specified, the diff result will be rendered in HTML
+format and saved to the specified path.
+)";
+
+// Command line opts to this tool. See main() for descriptions of these
+// fields.
+struct Options {
+  struct HloPath {
+    std::string hlo_snapshot;
+    std::string hlo_proto;
+    std::string hlo_module_proto;
+    std::string hlo_text;
+  };
+
+  struct RenderOptions {
+    std::string text_output;
+    std::string html_output;
+  };
+
+  HloPath first;
+  HloPath second;
+  DiffOptions diff_options;
+  RenderOptions render_options;
+};
+
+absl::Status CheckGroupFlags(const Options::HloPath& hlo_path) {
+  int nonempty_options_amount = 0;
+  for (const auto& path : {hlo_path.hlo_snapshot, hlo_path.hlo_proto,
+                           hlo_path.hlo_module_proto, hlo_path.hlo_text}) {
+    if (!path.empty()) {
+      ++nonempty_options_amount;
+    }
+  }
+  return nonempty_options_amount == 1
+             ? absl::OkStatus()
+             : absl::FailedPreconditionError(
+                   "Can only specify one and only one of path flags.");
+}
+
+// Builds an HloModule from the HloModuleProto.
+absl::StatusOr<std::unique_ptr<HloModule>> BuildHloModule(
+    const HloModuleProto& hlo_module_proto) {
+  TF_ASSIGN_OR_RETURN(HloModuleConfig config,
+                      HloModule::CreateModuleConfigFromProto(
+                          hlo_module_proto, xla::GetDebugOptionsFromFlags()));
+  return HloModule::CreateFromProto(hlo_module_proto, config);
+}
+
+absl::StatusOr<std::unique_ptr<HloModule>> LoadHLOModule(
+    const Options::HloPath& hlo_path) {
+  if (!hlo_path.hlo_snapshot.empty()) {
+    HloSnapshot snapshot;
+    TF_CHECK_OK(tsl::ReadBinaryProto(tsl::Env::Default(),
+                                     hlo_path.hlo_snapshot, &snapshot))
+        << "Can't open, read, or parse HloSnapshot proto at "
+        << hlo_path.hlo_snapshot;
+    return BuildHloModule(snapshot.hlo().hlo_module());
+  }
+  if (!hlo_path.hlo_proto.empty()) {
+    return ReadModuleFromBinaryProtoFile(hlo_path.hlo_proto,
+                                         xla::GetDebugOptionsFromFlags());
+  }
+  if (!hlo_path.hlo_module_proto.empty()) {
+    return ReadModuleFromModuleBinaryProtofile(hlo_path.hlo_module_proto,
+                                               xla::GetDebugOptionsFromFlags());
+  }
+  if (!hlo_path.hlo_text.empty()) {
+    return ReadModuleFromHloTextFile(
+        hlo_path.hlo_text, xla::GetDebugOptionsFromFlags(),
+        xla::HloParserOptions().set_fill_shortform_constants_with_random_values(
+            false));
+  }
+
+  return absl::InvalidArgumentError("No hlo_path specified.");
+}
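
A hypothetical use of LoadHLOModule, relying on the declarations above; the path is an assumed example, and error handling is reduced to a log line:

// LoadExample is illustrative only: "/tmp/first.hlo" is an assumed example
// path containing HLO text. Exactly one HloPath member may be set, mirroring
// CheckGroupFlags above.
void LoadExample() {
  Options::HloPath path;
  path.hlo_text = "/tmp/first.hlo";
  absl::StatusOr<std::unique_ptr<HloModule>> module = LoadHLOModule(path);
  if (!module.ok()) {
    LOG(ERROR) << "Failed to load module: " << module.status();
  }
}
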
+// Runs the Gumgraph-algorithm-based diff and renders the diff results.
+absl::Status RunGumgraphDiff(HloModule& first_module, HloModule& second_module,
+                             const Options& opts) {
+  TF_RETURN_IF_ERROR(first_module.RemoveUnusedComputations());
+  TF_RETURN_IF_ERROR(second_module.RemoveUnusedComputations());
+
+  TF_ASSIGN_OR_RETURN(
+      auto hlo_gumgraph_diff,
+      ComputeDiff(first_module, second_module, opts.diff_options));
+  std::cout << "Diffing finished" << '\n';
+
+  const DiffResult& diff = *hlo_gumgraph_diff.diff_result;
+  const DiffSummary& diff_summary = *hlo_gumgraph_diff.diff_summary;
+  LogDiffResult(diff);
+  LogDiffEval(*hlo_gumgraph_diff.diff_eval);
+  std::ostringstream text;
+  RenderTextSummary(diff, text);
+  std::cout << text.str() << '\n';
+
+  const std::string& text_output = opts.render_options.text_output;
+  if (!text_output.empty()) {
+    std::ostringstream full_text;
+    RenderText(diff, full_text);
+    TF_RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), text_output,
+                                              full_text.str()));
+  }
+
+  const std::string& html_output = opts.render_options.html_output;
+  if (!html_output.empty()) {
+    std::ostringstream html;
+    RenderHtml(
+        diff, diff_summary,
+        [](const HloInstruction* left_inst, const HloInstruction* right_inst) {
+          return "";
+        },
+        [](absl::string_view op_name) { return std::nullopt; },
+        [](absl::string_view op_name) { return std::nullopt; }, html);
+    TF_RETURN_IF_ERROR(
+        tsl::WriteStringToFile(tsl::Env::Default(), html_output, html.str()));
+
+    std::cout << "The diff summary is saved to: " << html_output << '\n';
+  }
+
+  return absl::OkStatus();
+}
+
+void RealMain(const Options& opts) {
+  TF_CHECK_OK(CheckGroupFlags(opts.first))
+      << "Can only specify one and only one of --first_hlo_snapshot, "
+         "--first_hlo_proto, --first_hlo_module_proto, --first_hlo_text";
+  TF_CHECK_OK(CheckGroupFlags(opts.second))
+      << "Can only specify one and only one of --second_hlo_snapshot, "
+         "--second_hlo_proto, --second_hlo_module_proto, --second_hlo_text";
+
+  LOG(INFO) << "Loading first module";
+  absl::StatusOr<std::unique_ptr<HloModule>> first_module =
+      LoadHLOModule(opts.first);
+  TF_CHECK_OK(first_module.status()) << "Failed to build first HLO module";
+  LOG(INFO) << "Loaded first module";
+
+  LOG(INFO) << "Loading second module";
+  absl::StatusOr<std::unique_ptr<HloModule>> second_module =
+      LoadHLOModule(opts.second);
+  TF_CHECK_OK(second_module.status()) << "Failed to build second HLO module";
+  LOG(INFO) << "Loaded second module";
+
+  CHECK_OK(
+      RunGumgraphDiff(*first_module.value(), *second_module.value(), opts));
+}
+
+}  // namespace
+}  // namespace hlo_diff
+}  // namespace xla
+
+int main(int argc, char** argv) {
+  xla::hlo_diff::Options opts;
+  bool need_help = false;
+  const std::vector<tsl::Flag> flag_list = {
+      tsl::Flag("first_hlo_snapshot", &opts.first.hlo_snapshot,
+                "first HloSnapshot proto to compare"),
+      tsl::Flag("first_hlo_proto", &opts.first.hlo_proto,
+                "first XLA hlo proto to compare"),
+      tsl::Flag("first_hlo_module_proto", &opts.first.hlo_module_proto,
+                "first XLA hlo module proto to compare"),
+      tsl::Flag("first_hlo_text", &opts.first.hlo_text,
+                "first XLA hlo text to compare"),
+      tsl::Flag("second_hlo_snapshot", &opts.second.hlo_snapshot,
+                "second HloSnapshot proto to compare"),
+      tsl::Flag("second_hlo_proto", &opts.second.hlo_proto,
+                "second XLA hlo proto to compare"),
+      tsl::Flag("second_hlo_module_proto", &opts.second.hlo_module_proto,
+                "second XLA hlo module proto to compare"),
+      tsl::Flag("second_hlo_text", &opts.second.hlo_text,
+                "second XLA hlo text to compare"),
+      tsl::Flag("ignore_shape_during_instruction_matching",
                &opts.diff_options.fingerprint_options.ignore_shape,
+                "Ignore array/tensor shapes when matching instructions"),
+      tsl::Flag("text_output", &opts.render_options.text_output,
+                "file to save diff blocks as text"),
+      tsl::Flag("html_output", &opts.render_options.html_output,
+                "file to save an overview of the diff result as html"),
+      tsl::Flag("help", &need_help, "Prints this help message"),
+  };
+
+  std::string usage = tsl::Flags::Usage(argv[0], flag_list);
+  bool parse_ok = tsl::Flags::Parse(&argc, argv, flag_list);
+  tsl::port::InitMain(xla::hlo_diff::kUsage, &argc, &argv);
+  LOG_IF(QFATAL, argc != 1 || !parse_ok || need_help) << usage;
+  xla::hlo_diff::RealMain(opts);
+  return 0;
+}
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.cc
new file mode 100644
index 00000000000000..a6657831ea4289
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.cc
@@ -0,0 +1,128 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "xla/hlo/ir/hlo_print_options.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+#include "xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/hlo/tools/hlo_diff/utils/hlo_diff_util.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+bool IsChangedInstruction(const HloInstructionNode* left_node,
+                          const HloInstructionNode* right_node) {
+  uint64_t left_fingerprint = GetHloInstructionFingerprint(
+      left_node->instruction, HloPrintOptions::Fingerprint());
+  uint64_t right_fingerprint = GetHloInstructionFingerprint(
+      right_node->instruction, HloPrintOptions::Fingerprint());
+  return left_fingerprint != right_fingerprint;
+}
+
+}  // namespace
+std::unique_ptr<DiffResult> ConstructDiffResult(
+    const HloGumgraph& left_graph, const HloGumgraph& right_graph,
+    const HloGumgraphMappings& mappings) {
+  LOG(INFO) << "Constructing diff result";
+  const std::vector<const HloInstructionNode*> left_all_nodes =
+      GetAllNodesInBfsOrder(left_graph.GetRoot(),
+                            BfsTraversalDirection::kForward,
+                            left_graph.GetNodeCount());
+  const std::vector<const HloInstructionNode*> right_all_nodes =
+      GetAllNodesInBfsOrder(right_graph.GetRoot(),
+                            BfsTraversalDirection::kForward,
+                            right_graph.GetNodeCount());
+  auto diff_result = std::make_unique<DiffResult>();
+  for (const HloInstructionNode* left_node : left_all_nodes) {
+    if (left_node->is_root) {
+      continue;
+    }
+    diff_result->node_props.insert({left_node->instruction, left_node->props});
+    if (!mappings.InstructionMapContainsLeft(left_node)) {
+      diff_result->left_module_unmatched_instructions.push_back(
+          left_node->instruction);
+      continue;
+    }
+    const HloInstructionNode* right_node =
+        mappings.left_to_right_instruction_map.left.find(left_node)->second;
+    const HloInstructionNodeMappingProps& mapping_props =
+        mappings.left_to_right_instruction_map.left.find(left_node)->info;
+
+    if (IsChangedInstruction(left_node, right_node)) {
+      diff_result->changed_instructions[left_node->instruction] =
+          right_node->instruction;
+      diff_result->map_by[std::make_pair(left_node->instruction,
+                                         right_node->instruction)] =
+          mapping_props.matcher_type;
+      continue;
+    }
+    // If node position is unchanged, add to unchanged instructions.
+    if (mapping_props.unchanged) {
+      diff_result->unchanged_instructions[left_node->instruction] =
+          right_node->instruction;
+      diff_result->map_by[std::make_pair(left_node->instruction,
+                                         right_node->instruction)] =
+          mapping_props.matcher_type;
+      continue;
+    }
+    // TODO(b/369851244): Add moved instructions to diff result.
+    diff_result->unchanged_instructions[left_node->instruction] =
+        right_node->instruction;
+    diff_result->map_by[std::make_pair(left_node->instruction,
+                                       right_node->instruction)] =
+        mapping_props.matcher_type;
+  }
+
+  for (const HloInstructionNode* right_node : right_all_nodes) {
+    if (right_node->is_root) {
+      continue;
+    }
+    diff_result->node_props.insert(
+        {right_node->instruction, right_node->props});
+    if (!mappings.InstructionMapContainsRight(right_node)) {
+      diff_result->right_module_unmatched_instructions.push_back(
+          right_node->instruction);
+    }
+  }
+
+  return diff_result;
+}
+
+void LogDiffResult(const DiffResult& diff_result) {
+  LOG(INFO) << "Unmatched instructions in the left module: "
+            << diff_result.left_module_unmatched_instructions.size();
+  LOG(INFO) << "Unmatched instructions in the right module: "
+            << diff_result.right_module_unmatched_instructions.size();
+  LOG(INFO) << "Changed instructions: "
+            << diff_result.changed_instructions.size();
+  LOG(INFO) << "Moved instructions: " << diff_result.moved_instructions.size();
+  LOG(INFO) << "Unchanged instructions: "
+            << diff_result.unchanged_instructions.size();
+}
+
+}  // namespace hlo_diff
+}  // namespace xla
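
A hypothetical consumer of the DiffResult buckets populated above; field names follow the struct declared in hlo_diff_result.h just below.

#include <iostream>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"

namespace sketch {

// Prints one line per changed instruction pair. Illustrative only; any real
// reporting goes through the text/HTML renderers used by hlo_diff_main.cc.
void PrintChangedPairs(const xla::hlo_diff::DiffResult& diff) {
  for (const auto& [left, right] : diff.changed_instructions) {
    std::cout << left->name() << " -> " << right->name() << '\n';
  }
}

}  // namespace sketch
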
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.h b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.h
new file mode 100644
index 00000000000000..08a8e0d5ea2dcb
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_HLO_DIFF_RESULT_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_HLO_DIFF_RESULT_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+
+namespace xla {
+namespace hlo_diff {
+// Result of diffing the left and right HLO modules. Contains the matched and
+// unmatched instructions in the two modules.
+struct DiffResult {
+  // Matched instructions.
+  absl::flat_hash_map<const HloInstruction*, const HloInstruction*>
+      unchanged_instructions;
+  absl::flat_hash_map<const HloInstruction*, const HloInstruction*>
+      changed_instructions;
+  absl::flat_hash_map<const HloInstruction*, const HloInstruction*>
+      moved_instructions;
+
+  // Unmatched instructions.
+  std::vector<const HloInstruction*> left_module_unmatched_instructions;
+  std::vector<const HloInstruction*> right_module_unmatched_instructions;
+
+  // Debug info.
+  absl::flat_hash_map<std::pair<const HloInstruction*, const HloInstruction*>,
+                      MatcherType>
+      map_by;
+  absl::flat_hash_map<const HloInstruction*, HloInstructionNodeProps>
+      node_props;
+};
+
+// Constructs the diff result from the node mappings.
+std::unique_ptr<DiffResult> ConstructDiffResult(
+    const HloGumgraph& left_graph, const HloGumgraph& right_graph,
+    const HloGumgraphMappings& mappings);
+
+// Logs the diff result.
+void LogDiffResult(const DiffResult& diff_result);
+
+}  // namespace hlo_diff
+}  // namespace xla
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_HLO_DIFF_RESULT_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result_test.cc
new file mode 100644
index 00000000000000..12dff12ae38801
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result_test.cc
@@ -0,0 +1,301 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/parser/hlo_parser.h"
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/hlo/tools/hlo_diff/utils/test_util.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+using ::testing::Pair;
+using ::testing::Pointee;
+using ::testing::Property;
+using ::testing::UnorderedElementsAre;
+
+class HloDiffTest : public HloTestBase {};
+
+TEST_F(HloDiffTest, MatchedDifferentShapeMarkAsChanged) {
+  // Create left module with entry computation containing the following
+  // structure:
+  // [Param 0] ---> ┌-------┐
+  //                | add_0 |
+  // [Param 1] ---> └-------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  parameter.0 = f32[] parameter(0)
+  parameter.1 = f32[] parameter(1)
+  add.0 = f32[] add(parameter.0, parameter.1)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+
+  // Create right module with entry computation containing the following
+  // structure:
+  // [Param 0] ---> ┌-------┐
+  //                | add_0 |
+  // [Param 1] ---> └-------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  parameter.0 = f64[] parameter(0)
+  parameter.1 = f32[] parameter(1)
+  add.0 = f32[] add(parameter.0, parameter.1)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions(
+      GetNodeByName(*graph_l, "add.0"), GetNodeByName(*graph_r, "add.0"),
+      *mappings, /*position_unchanged=*/true));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "parameter.0"),
+                               GetNodeByName(*graph_r, "parameter.0"),
+                               *mappings, /*position_unchanged=*/true));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "parameter.1"),
+                               GetNodeByName(*graph_r, "parameter.1"),
+                               *mappings, /*position_unchanged=*/true));
+  auto diff_result = ConstructDiffResult(*graph_l, *graph_r, *mappings);
+
+  EXPECT_THAT(diff_result->changed_instructions,
+              UnorderedElementsAre(
+                  Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
+                       Pointee(Property(&HloInstruction::name, "parameter.0"))),
+                  Pair(Pointee(Property(&HloInstruction::name, "add.0")),
+                       Pointee(Property(&HloInstruction::name, "add.0")))));
+  EXPECT_THAT(diff_result->unchanged_instructions,
+              UnorderedElementsAre(Pair(
+                  Pointee(Property(&HloInstruction::name, "parameter.1")),
+                  Pointee(Property(&HloInstruction::name, "parameter.1")))));
+}
+
+TEST_F(HloDiffTest, MatchedDifferentFingerprintMarkAsChanged) {
+  // Create left module with entry computation containing the following
+  // structure:
+  // [Param 0] ---> ┌-------┐      ┌------┐
+  //                | add_0 | ---> | ROOT |
+  // [Param 1] ---> └-------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  parameter.0 = f32[] parameter(0)
+  parameter.1 = f32[] parameter(1)
+  add.0 = f32[] add(parameter.0, parameter.1)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+
+  // Create right module with entry computation containing the following
+  // structure:
+  // [Param 1] ---> ┌-------┐      ┌------┐
+  //                | add_0 | ---> | ROOT |
+  // [Param 0] ---> └-------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  parameter.0 = f32[] parameter(0)
+  parameter.1 = f32[] parameter(1)
+  add.0 = f32[] add(parameter.1, parameter.0)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions(
+      GetNodeByName(*graph_l, "add.0"), GetNodeByName(*graph_r, "add.0"),
+      *mappings, /*position_unchanged=*/true));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "parameter.0"),
+                               GetNodeByName(*graph_r, "parameter.1"),
+                               *mappings, /*position_unchanged=*/true));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "parameter.1"),
+                               GetNodeByName(*graph_r, "parameter.0"),
+                               *mappings, /*position_unchanged=*/true));
+  auto diff_result = ConstructDiffResult(*graph_l, *graph_r, *mappings);
+
+  EXPECT_THAT(
+      diff_result->changed_instructions,
+      UnorderedElementsAre(
+          Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
+               Pointee(Property(&HloInstruction::name, "parameter.1"))),
+          Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
+               Pointee(Property(&HloInstruction::name, "parameter.0")))));
+  EXPECT_THAT(diff_result->unchanged_instructions,
+              UnorderedElementsAre(
+                  Pair(Pointee(Property(&HloInstruction::name, "add.0")),
+                       Pointee(Property(&HloInstruction::name, "add.0")))));
+}
+
+TEST_F(HloDiffTest, UnmatchedInstructionsMarkAsUnmatched) {
+  // Create left module with entry computation containing the following
+  // structure:
+  // [Param 0] ---> ┌-------┐
+  //                | add_0 |
+  // [Param 1] ---> └-------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  parameter.0 = f32[] parameter(0)
+  parameter.1 = f32[] parameter(1)
+  add.0 = f32[] add(parameter.0, parameter.1)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+
+  // Create right module with entry computation containing the following
+  // structure:
+  // [Param 1] ---> ┌-------┐
+  //                | add_0 |
+  // [Param 0] ---> └-------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  parameter.0 = f32[] parameter(0)
+  parameter.1 = f32[] parameter(1)
+  add.0 = f32[] add(parameter.1, parameter.0)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions(
+      GetNodeByName(*graph_l, "add.0"), GetNodeByName(*graph_r, "add.0"),
+      *mappings, /*position_unchanged=*/true));
+  auto diff_result = ConstructDiffResult(*graph_l, *graph_r, *mappings);
+
+  EXPECT_THAT(diff_result->unchanged_instructions,
+              UnorderedElementsAre(
+                  Pair(Pointee(Property(&HloInstruction::name, "add.0")),
+                       Pointee(Property(&HloInstruction::name, "add.0")))));
+  EXPECT_THAT(diff_result->left_module_unmatched_instructions,
+              UnorderedElementsAre(
+                  Pointee(Property(&HloInstruction::name, "parameter.0")),
+                  Pointee(Property(&HloInstruction::name, "parameter.1"))));
+  EXPECT_THAT(diff_result->right_module_unmatched_instructions,
+              UnorderedElementsAre(
+                  Pointee(Property(&HloInstruction::name, "parameter.0")),
+                  Pointee(Property(&HloInstruction::name, "parameter.1"))));
+}
+
+TEST_F(HloDiffTest, ShortFormConstantsMatched) {
+  // Create left module with entry computation containing the following
+  // structure:
+  // [Param 0] -----> ┌-------┐
+  //                  | add_0 |
+  // [Const 2958] --> └-------┘
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module_l,
+      ParseAndReturnUnverifiedModule(
+          R"(
+HloModule module, is_scheduled=true

+ENTRY entry {
+  parameter.0 = s32[301]{0:T(512)} parameter(0)
+  constant.2958 = s32[301]{0:T(512)} constant({...})
+  add.0 = s32[301]{0:T(512)} add(parameter.0, constant.2958)
+}
+)",
+          HloModuleConfig(),
+          HloParserOptions().set_fill_shortform_constants_with_random_values(
+              false)));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+
+  // Create right module with entry computation containing the following
+  // structure:
+  // [Param 0] -----> ┌-------┐
+  //                  | add_0 |
+  // [Const 2958] --> └-------┘
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module_r,
+      ParseAndReturnUnverifiedModule(
+          R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  parameter.0 = s32[301]{0:T(512)} parameter(0)
+  constant.2958 = s32[301]{0:T(512)} constant({...})
+  add.0 = s32[301]{0:T(512)} add(parameter.0, constant.2958)
+}
+)",
+          HloModuleConfig(),
+          HloParserOptions().set_fill_shortform_constants_with_random_values(
+              false)));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions(
+      GetNodeByName(*graph_l, "add.0"), GetNodeByName(*graph_r, "add.0"),
+      *mappings, /*position_unchanged=*/true));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "parameter.0"),
+                               GetNodeByName(*graph_r, "parameter.0"),
+                               *mappings, /*position_unchanged=*/true));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "constant.2958"),
+                               GetNodeByName(*graph_r, "constant.2958"),
+                               *mappings, /*position_unchanged=*/true));
+  auto diff_result = ConstructDiffResult(*graph_l, *graph_r, *mappings);
+
+  EXPECT_THAT(
+      diff_result->unchanged_instructions,
+      UnorderedElementsAre(
+          Pair(Pointee(Property(&HloInstruction::name, "constant.2958")),
+               Pointee(Property(&HloInstruction::name, "constant.2958"))),
+          Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
+               Pointee(Property(&HloInstruction::name, "parameter.0"))),
+          Pair(Pointee(Property(&HloInstruction::name, "add.0")),
+               Pointee(Property(&HloInstruction::name, "add.0")))));
+}
+
+}  // namespace
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc
new file mode 100644
index 00000000000000..b69fad4be806f5
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc
@@ -0,0 +1,308 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+
+#include <algorithm>
+#include <cstdint>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include
+#include
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/hlo/tools/hlo_diff/utils/connected_components.h"
+#include "tsl/platform/fingerprint.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+// Returns the mapped instruction node of the given instruction in the given
+// direction. Returns nullptr if the instruction is not mapped.
+const HloInstructionNode* FindMappedInstructionNode(
+    const HloGumgraphMappings& mappings, const HloInstructionNode* instruction,
+    ComputationMappingDirection direction) {
+  switch (direction) {
+    case ComputationMappingDirection::kLeftToRight: {
+      auto it = mappings.left_to_right_instruction_map.left.find(instruction);
+      if (it != mappings.left_to_right_instruction_map.left.end()) {
+        return it->second;
+      }
+      break;
+    }
+    case ComputationMappingDirection::kRightToLeft: {
+      auto it = mappings.left_to_right_instruction_map.right.find(instruction);
+      if (it != mappings.left_to_right_instruction_map.right.end()) {
+        return it->second;
+      }
+      break;
+    }
+  }
+  return nullptr;
+}
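
FindMappedInstructionNode reads one of two views of the same mapping. A minimal sketch of such a two-view (bimap-style) container follows; this is not the real HloGumgraphMappings API, just the shape of the idea.

#include <optional>
#include <unordered_map>

namespace sketch {

// Two hash maps kept in sync by hand, exposing left-to-right and
// right-to-left lookups like the `.left` / `.right` views used above.
template <typename L, typename R>
class Bimap {
 public:
  // Inserts only if neither side is mapped yet; mirrors the semantics of
  // MapInstructionsIfAbsent in the tests.
  bool InsertIfAbsent(const L& l, const R& r) {
    if (left_to_right_.count(l) > 0 || right_to_left_.count(r) > 0) {
      return false;
    }
    left_to_right_.emplace(l, r);
    right_to_left_.emplace(r, l);
    return true;
  }
  std::optional<R> FindLeft(const L& l) const {
    auto it = left_to_right_.find(l);
    if (it == left_to_right_.end()) return std::nullopt;
    return it->second;
  }
  std::optional<L> FindRight(const R& r) const {
    auto it = right_to_left_.find(r);
    if (it == right_to_left_.end()) return std::nullopt;
    return it->second;
  }

 private:
  std::unordered_map<L, R> left_to_right_;
  std::unordered_map<R, L> right_to_left_;
};

}  // namespace sketch
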
+// Result of finding the main matched computation.
+struct MainMatchedComputationResult {
+  const HloComputation* main_matched_computation = nullptr;
+  const int max_matched_instruction_count = 0;
+  const int split_allegiance_instruction_count = 0;
+};
+
+// Returns the main matched computation of the given computation in the given
+// direction. A computation is considered the main matched computation if it
+// has the most matched instructions.
+MainMatchedComputationResult FindMainMatchedComputation(
+    const HloComputation* computation, const HloGumgraph& graph,
+    const HloGumgraphMappings& mappings,
+    ComputationMappingDirection direction) {
+  absl::flat_hash_map<const HloComputation*, int> matched_instruction_count;
+  int max_count = 0;
+  int mapped_instruction_count = 0;
+  const HloComputation* main_matched_computation = nullptr;
+  for (const HloInstruction* instruction : computation->instructions()) {
+    if (const HloInstructionNode* const mapped_instruction_node =
+            FindMappedInstructionNode(mappings, graph.GetNode(instruction),
+                                      direction);
+        mapped_instruction_node != nullptr) {
+      ++mapped_instruction_count;
+      const HloComputation* right_computation =
+          mapped_instruction_node->instruction->parent();
+      const int count = ++matched_instruction_count[right_computation];
+      if (count > max_count) {
+        max_count = count;
+        main_matched_computation = right_computation;
+      }
+    }
+  }
+  return {.main_matched_computation = main_matched_computation,
+          .max_matched_instruction_count = max_count,
+          .split_allegiance_instruction_count =
+              mapped_instruction_count - max_count};
+}
+
+uint64_t GetDiffTypeFingerprint(
+    const HloInstruction* instruction,
+    const absl::flat_hash_set<const HloInstruction*>& changed_instructions,
+    const absl::flat_hash_set<const HloInstruction*>& unmatched_instructions) {
+  if (changed_instructions.contains(instruction)) {
+    return DiffCode::kChanged;
+  }
+  if (unmatched_instructions.contains(instruction)) {
+    return DiffCode::kUnmatched;
+  }
+  return DiffCode::kUnchanged;
+}
+
+struct DiffFingerprint {
+  bool all_unchanged;
+  uint64_t diff_fingerprint;
+};
+
+DiffFingerprint ComputationDiffFingerprint(
+    const xla::HloComputation* computation,
+    const absl::flat_hash_set<const HloInstruction*>& changed_instructions,
+    const absl::flat_hash_set<const HloInstruction*>& unmatched_instructions) {
+  absl::flat_hash_map<const HloInstruction*, uint64_t> subgraph_fingerprint;
+  DiffFingerprint result;
+  bool all_unchanged = true;
+  for (auto* instruction : computation->MakeInstructionPostOrder()) {
+    uint64_t fp = static_cast<uint64_t>(instruction->opcode());
+    uint64_t diff_type_fp = GetDiffTypeFingerprint(
+        instruction, changed_instructions, unmatched_instructions);
+    all_unchanged = all_unchanged && (diff_type_fp == DiffCode::kUnchanged);
+    fp = tsl::FingerprintCat64(fp, diff_type_fp);
+    for (const HloInstruction* operand : instruction->operands()) {
+      fp = tsl::FingerprintCat64(fp, subgraph_fingerprint.at(operand));
+    }
+    // TODO(b/394201811): Make sure no fingerprint collision.
+    subgraph_fingerprint[instruction] = fp;
+  }
+  result.all_unchanged = all_unchanged;
+  result.diff_fingerprint =
+      subgraph_fingerprint.at(computation->root_instruction());
+  return result;
+}
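
The fingerprinting above folds each node's opcode, its diff code, and its operands' fingerprints together in post order, so a diff anywhere in a subtree perturbs the root fingerprint. A toy version follows, with a simple mixing function standing in for tsl::FingerprintCat64 (the real one has far better collision behavior).

#include <cstdint>
#include <unordered_map>
#include <vector>

namespace sketch {

struct Instr {
  uint64_t opcode = 0;
  uint64_t diff_code = 0;  // 0 = unchanged, 1 = changed, 2 = unmatched
  std::vector<const Instr*> operands;
};

// Stand-in mixing function; illustrative only.
inline uint64_t Cat64(uint64_t a, uint64_t b) {
  return a * 0x9E3779B97F4A7C15ULL + b;
}

// Memoized post-order subtree fingerprint, mirroring the structure of
// ComputationDiffFingerprint above.
uint64_t DiffFingerprint(const Instr* root,
                         std::unordered_map<const Instr*, uint64_t>& memo) {
  if (auto it = memo.find(root); it != memo.end()) return it->second;
  uint64_t fp = Cat64(root->opcode, root->diff_code);
  for (const Instr* operand : root->operands) {
    fp = Cat64(fp, DiffFingerprint(operand, memo));
  }
  return memo[root] = fp;
}

}  // namespace sketch
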
+// Splits the computations into left and right computations.
+ComputationGroup SplitComputations(
+    const std::vector<const HloComputation*>& computations,
+    const absl::flat_hash_map<const HloComputation*, ComputationSummary>&
+        computation_summaries) {
+  ComputationGroup result;
+  for (const HloComputation* computation : computations) {
+    if (auto it = computation_summaries.find(computation);
+        it != computation_summaries.end()) {
+      if (it->second.direction == ComputationMappingDirection::kLeftToRight) {
+        result.left_computations.push_back(computation);
+      } else {
+        result.right_computations.push_back(computation);
+      }
+    }
+  }
+  return result;
+}
+
+// Returns the connected components of the given computation summary.
+absl::flat_hash_map<uint64_t, std::vector<ComputationGroup>>
+FindConnectedComponents(
+    absl::flat_hash_map<const HloComputation*, ComputationSummary>
+        computation_summary) {
+  ConnectedComponentsFinder cc;
+  absl::flat_hash_map<uint64_t, std::vector<ComputationGroup>> result;
+  for (const auto& [computation, computation_match_info] :
+       computation_summary) {
+    if (computation_match_info.main_matched_computation != nullptr) {
+      cc.AddEdge(computation, computation_match_info.main_matched_computation);
+    }
+  }
+  std::vector<std::vector<const HloComputation*>> connected_component_groups =
+      cc.FindConnectedComponents();
+
+  for (const auto& component_group : connected_component_groups) {
+    bool all_unchanged = true;
+    for (const auto& computation : component_group) {
+      all_unchanged =
+          all_unchanged && computation_summary.at(computation).all_unchanged;
+    }
+    if (all_unchanged) {
+      continue;
+    }
+    std::vector<const HloComputation*> sorted_component_group(component_group);
+    std::sort(sorted_component_group.begin(), sorted_component_group.end(),
+              [&](const HloComputation* a, const HloComputation* b) {
+                return computation_summary.at(a).diff_fingerprint <
                       computation_summary.at(b).diff_fingerprint;
+              });
+    uint64_t fingerprint = 0;
+    for (const auto& computation : sorted_component_group) {
+      fingerprint = tsl::FingerprintCat64(
+          fingerprint, computation_summary.at(computation).diff_fingerprint);
+    }
+    result[fingerprint].push_back(
+        SplitComputations(sorted_component_group, computation_summary));
+  }
+  return result;
+}
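
Grouping mutually matched computations is a connected-components problem. The real helper lives in utils/connected_components.h and its API is not shown in this patch; below is a minimal union-find sketch of what such a finder needs to do.

#include <numeric>
#include <vector>

namespace sketch {

// Minimal union-find over dense integer node ids; AddEdge unions two nodes,
// Find returns a component representative. Illustrative only, not the
// ConnectedComponentsFinder API.
class UnionFind {
 public:
  explicit UnionFind(int n) : parent_(n) {
    std::iota(parent_.begin(), parent_.end(), 0);
  }
  int Find(int x) {
    while (parent_[x] != x) {
      parent_[x] = parent_[parent_[x]];  // path halving
      x = parent_[x];
    }
    return x;
  }
  void AddEdge(int a, int b) { parent_[Find(a)] = Find(b); }

 private:
  std::vector<int> parent_;
};

}  // namespace sketch
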
right_changed_instructions.insert(right); + } + left_unmatched_instructions.insert( + diff_result.left_module_unmatched_instructions.begin(), + diff_result.left_module_unmatched_instructions.end()); + right_unmatched_instructions.insert( + diff_result.right_module_unmatched_instructions.begin(), + diff_result.right_module_unmatched_instructions.end()); + summary->computation_summary.merge(SummarizeAllComputationsInGraph( + left_graph, mappings, diff_result, left_changed_instructions, + left_unmatched_instructions, ComputationMappingDirection::kLeftToRight)); + summary->computation_summary.merge(SummarizeAllComputationsInGraph( + right_graph, mappings, diff_result, right_changed_instructions, + right_unmatched_instructions, ComputationMappingDirection::kRightToLeft)); + + // Group the computations by their diff fingerprint. + summary->grouped_computations = + FindConnectedComponents(summary->computation_summary); + + return summary; +} + +void LogDiffSummary(const DiffSummary& diff_summary) { + // Log the connected components repeated more than 3 times. + LOG(INFO) << "Find Repeated Connected Components: "; + for (const auto& [fingerprint, computation_groups] : + diff_summary.grouped_computations) { + if (computation_groups.size() < 3) { + continue; + } + LOG(INFO) << computation_groups.size() + << " Repeated Connected Components Fingerprint: " << fingerprint; + int i = 0; + for (const auto& computation_group : computation_groups) { + ++i; + std::string computations_str; + for (const HloComputation* computation : + computation_group.left_computations) { + absl::StrAppend(&computations_str, + absl::StrFormat("L: %s, ", computation->name())); + } + for (const HloComputation* computation : + computation_group.right_computations) { + absl::StrAppend(&computations_str, + absl::StrFormat("R: %s, ", computation->name())); + } + LOG(INFO) << computations_str; + if (i >= 5) { + LOG(INFO) << "..."; + break; + } + } + } +} + +} // namespace hlo_diff +} // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h new file mode 100644 index 00000000000000..5b38e3c31940d5 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h @@ -0,0 +1,97 @@ +/* + * Copyright 2025 The OpenXLA Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef XLA_HLO_TOOLS_HLO_DIFF_HLO_DIFF_SUMMARY_H_ +#define XLA_HLO_TOOLS_HLO_DIFF_HLO_DIFF_SUMMARY_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" +#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h" +#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h" + +namespace xla { +namespace hlo_diff { + +enum DiffCode : uint8_t { + kUnchanged, + kChanged, + kUnmatched, +}; + +enum class ComputationMappingDirection : std::uint8_t { + kLeftToRight, + kRightToLeft +}; + +struct ComputationSummary { + ComputationMappingDirection direction; + // Computation in the other graph that has most instructions matched. + // Can be nullptr if no instructions are matched. + const HloComputation* main_matched_computation = nullptr; + + // Number of instructions that are mapped to instructions in the main matched + // computation. + int64_t max_matched_instruction_count = 0; + + // Number of instructions that are mapped to instructions in a different + // computation. + int64_t split_allegiance_instruction_count = 0; + + // Fingerprint of the computation including diff. + uint64_t diff_fingerprint = 0; + + // Whether all instructions in the computation are unchanged. + bool all_unchanged = true; +}; + +// A group of computations that are connected in the graph. +struct ComputationGroup { + std::vector left_computations; + std::vector right_computations; +}; + +// Summary of the diff result of the left and right HLO modules. +struct DiffSummary { + // Connected computations grouped by fingerprint. + absl::flat_hash_map> + grouped_computations; + + // Summary of each computation. + absl::flat_hash_map + computation_summary; +}; + +// Constructs the diff summary from the node mappings and diff result. +// `left_graph` and `right_graph` are the original graphs. +// `mappings` are the node mappings between the two graphs.. +// `diff_result` contains the edit script(insert/delete/change/move) created +// from the node mappings. +std::unique_ptr ConstructDiffSummary( + const HloGumgraph& left_graph, const HloGumgraph& right_graph, + const HloGumgraphMappings& mappings, const DiffResult& diff_result); + +// Logs the diff summary. +void LogDiffSummary(const DiffSummary& diff_summary); + +} // namespace hlo_diff +} // namespace xla + +#endif // XLA_HLO_TOOLS_HLO_DIFF_HLO_DIFF_SUMMARY_H_ diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc new file mode 100644 index 00000000000000..d1026c80b4ea5a --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc @@ -0,0 +1,369 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/hlo/tools/hlo_diff/utils/test_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+using ::testing::FieldsAre;
+using ::testing::Pair;
+using ::testing::Pointee;
+using ::testing::Property;
+using ::testing::UnorderedElementsAre;
+
+class HloDiffTest : public HloTestBase {};
+
+TEST_F(HloDiffTest, FindMainMatchedComputationWorks) {
+  // Create a module with entry computation containing the following structure:
+  // [Param p0]->[Param p2]->┌-------┐  ┌----------┐
+  //                         | add.1 |->| fusion.1 |->┌-------┐
+  // [Param p1]->[Param p3]->└-------┘  └----------┘  |       |  ┌------┐
+  //                                                  | add.3 |->| ROOT |
+  // [Param p4]->[Param p6]->┌-------┐  ┌----------┐  |       |  └------┘
+  //                         | add.2 |->| fusion.2 |->└-------┘
+  // [Param p5]->[Param p7]->└-------┘  └----------┘
+  const char* hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  fused_computation.1 {
+    p2 = s32[32,16]{0,1:T(1,128)} parameter(0)
+    p3 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    add.1 = s32[32,16]{0,1:T(1,128)} add(p2, p3)
+  }
+
+  fused_computation.2 {
+    p6 = s32[32,16]{0,1:T(1,128)} parameter(0)
+    p7 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    add.2 = s32[32,16]{0,1:T(1,128)} add(p6, p7)
+  }
+
+  ENTRY entry {
+    p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+    p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    p4 = s32[32,16]{0, 1:T(1,128)} parameter(2)
+    p5 = s32[32,16]{0,1:T(1,128)} parameter(3)
+    fusion.1 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.1
+    fusion.2 = s32[32,16]{0,1:T(1,128)} fusion(p4,p5), kind=kLoop, calls=fused_computation.2
+    ROOT add.3 = s32[32,16]{0,1:T(1,128)} add(fusion.1, fusion.2)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  HloGumgraphMappings mappings;
+  // Root nodes are matched by default before the matcher is called.
+  mappings.MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(),
+                                   MatcherType::kManual);
+  MatchAllNodesByName(*graph_l, *graph_r, mappings);
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "add.1"),
+                               GetNodeByName(*graph_r, "add.2"), mappings,
+                               /*position_unchanged=*/true));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "add.2"),
+                               GetNodeByName(*graph_r, "add.1"), mappings,
+                               /*position_unchanged=*/true));
+  std::unique_ptr<const DiffResult> diff_result =
+      ConstructDiffResult(*graph_l, *graph_r, mappings);
+  std::unique_ptr<const DiffSummary> diff_summary =
+      ConstructDiffSummary(*graph_l, *graph_r, mappings, *diff_result);
+  absl::flat_hash_map<const HloComputation*, ComputationSummary>
+      left_computation_summary;
+  for (const auto& [computation, _] : graph_l->AllComputationProps()) {
+    if (auto it = diff_summary->computation_summary.find(computation);
+        it != diff_summary->computation_summary.end()) {
+      left_computation_summary[computation] = it->second;
+    }
+  }
+  absl::flat_hash_map<const HloComputation*, ComputationSummary>
+      right_computation_summary;
+  for (const auto& [computation, _] : graph_r->AllComputationProps()) {
+    if (auto it = diff_summary->computation_summary.find(computation);
+        it != diff_summary->computation_summary.end()) {
+      right_computation_summary[computation] = it->second;
+    }
+  }
+
+  EXPECT_THAT(
+      left_computation_summary,
+      UnorderedElementsAre(
+          Pair(
+              Pointee(Property(&HloComputation::name, "entry")),
+              FieldsAre(/*direction=*/ComputationMappingDirection::kLeftToRight,
+                        /*main_matched_computation=*/
+                        Pointee(Property(&HloComputation::name, "entry")),
+                        /*max_matched_instruction_count=*/7,
+                        /*split_allegiance_instruction_count=*/0,
+                        /*diff_fingerprint=*/3570884195340145402U,
+                        /*all_unchanged=*/true)),
+          Pair(
+              Pointee(Property(&HloComputation::name, "fused_computation.1")),
+              FieldsAre(/*direction=*/ComputationMappingDirection::kLeftToRight,
+                        /*main_matched_computation=*/
+                        Pointee(Property(&HloComputation::name,
+                                         "fused_computation.1")),
+                        /*max_matched_instruction_count=*/2,
+                        /*split_allegiance_instruction_count=*/1,
+                        /*diff_fingerprint=*/2604941079081458563U,
+                        /*all_unchanged=*/true)),
+          Pair(
+              Pointee(Property(&HloComputation::name, "fused_computation.2")),
+              FieldsAre(/*direction=*/ComputationMappingDirection::kLeftToRight,
+                        /*main_matched_computation=*/
+                        Pointee(Property(&HloComputation::name,
+                                         "fused_computation.2")),
+                        /*max_matched_instruction_count=*/2,
+                        /*split_allegiance_instruction_count=*/1,
+                        /*diff_fingerprint=*/2604941079081458563U,
+                        /*all_unchanged=*/true))));
+  EXPECT_THAT(
+      right_computation_summary,
+      UnorderedElementsAre(
+          Pair(
+              Pointee(Property(&HloComputation::name, "entry")),
+              FieldsAre(/*direction=*/ComputationMappingDirection::kRightToLeft,
+                        /*main_matched_computation=*/
+                        Pointee(Property(&HloComputation::name, "entry")),
+                        /*max_matched_instruction_count=*/7,
+                        /*split_allegiance_instruction_count=*/0,
+                        /*diff_fingerprint=*/3570884195340145402U,
+                        /*all_unchanged=*/true)),
+          Pair(
+              Pointee(Property(&HloComputation::name, "fused_computation.1")),
+              FieldsAre(/*direction=*/ComputationMappingDirection::kRightToLeft,
+                        /*main_matched_computation=*/
+                        Pointee(Property(&HloComputation::name,
+                                         "fused_computation.1")),
+                        /*max_matched_instruction_count=*/2,
+                        /*split_allegiance_instruction_count=*/1,
+                        /*diff_fingerprint=*/2604941079081458563U,
+                        /*all_unchanged=*/true)),
+          Pair(
+              Pointee(Property(&HloComputation::name, "fused_computation.2")),
+              FieldsAre(/*direction=*/ComputationMappingDirection::kRightToLeft,
+                        /*main_matched_computation=*/
+                        Pointee(Property(&HloComputation::name,
+                                         "fused_computation.2")),
+                        /*max_matched_instruction_count=*/2,
+                        /*split_allegiance_instruction_count=*/1,
+                        /*diff_fingerprint=*/2604941079081458563U,
+                        /*all_unchanged=*/true))));
+}
+
+TEST_F(HloDiffTest, ComputationDiffFingerprintWorks) {
+  // Create left module with entry computation containing the following
+  // structure:
+  // [Param 0] ---> ┌-------┐
+  //                | add_0 |
+  // [Param 1] ---> └-------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(R"(
+  HloModule module, is_scheduled=true
+
+  ENTRY entry {
+    parameter.0 = f32[] parameter(0)
+    parameter.1 = f32[] parameter(1)
+    add.0 = f32[] add(parameter.0, parameter.1)
+  }
+  )"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+
+  // Create right module with entry computation containing the following
+  // structure:
+  // [Param 1] ---> ┌-------┐
+  //                | add_0 |
+  // [Param 0] ---> └-------┘
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(R"(
+  HloModule module, is_scheduled=true
+
+  ENTRY entry {
+    parameter.0 = f32[] parameter(0)
+    parameter.1 = f32[] parameter(1)
+    add.0 = f32[] add(parameter.1, parameter.0)
+  }
+  )"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  HloGumgraphMappings mappings;
+  ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions(
+      GetNodeByName(*graph_l, "add.0"), GetNodeByName(*graph_r, "add.0"),
+      mappings, /*position_unchanged=*/true));
+  std::unique_ptr<const DiffResult> diff_result =
+      ConstructDiffResult(*graph_l, *graph_r, mappings);
+  std::unique_ptr<const DiffSummary> diff_summary =
+      ConstructDiffSummary(*graph_l, *graph_r, mappings, *diff_result);
+  absl::flat_hash_map<const HloComputation*, ComputationSummary>
+      left_computation_summary;
+  for (const auto& [computation, _] : graph_l->AllComputationProps()) {
+    if (auto it = diff_summary->computation_summary.find(computation);
+        it != diff_summary->computation_summary.end()) {
+      left_computation_summary[computation] = it->second;
+    }
+  }
+  absl::flat_hash_map<const HloComputation*, ComputationSummary>
+      right_computation_summary;
+  for (const auto& [computation, _] : graph_r->AllComputationProps()) {
+    if (auto it = diff_summary->computation_summary.find(computation);
+        it != diff_summary->computation_summary.end()) {
+      right_computation_summary[computation] = it->second;
+    }
+  }
+  EXPECT_THAT(
+      left_computation_summary,
+      UnorderedElementsAre(Pair(
+          Pointee(Property(&HloComputation::name, "entry")),
+          FieldsAre(/*direction=*/ComputationMappingDirection::kLeftToRight,
+                    /*main_matched_computation=*/
+                    Pointee(Property(&HloComputation::name, "entry")),
+                    /*max_matched_instruction_count=*/1,
+                    /*split_allegiance_instruction_count=*/0,
+                    /*diff_fingerprint=*/13464792036913846758U,
+                    /*all_unchanged=*/false))));
+  EXPECT_THAT(
+      right_computation_summary,
+      UnorderedElementsAre(Pair(
+          Pointee(Property(&HloComputation::name, "entry")),
+          FieldsAre(/*direction=*/ComputationMappingDirection::kRightToLeft,
+                    /*main_matched_computation=*/
+                    Pointee(Property(&HloComputation::name, "entry")),
+                    /*max_matched_instruction_count=*/1,
+                    /*split_allegiance_instruction_count=*/0,
+                    /*diff_fingerprint=*/13464792036913846758U,
+                    /*all_unchanged=*/false))));
+  EXPECT_THAT(diff_summary->grouped_computations,
+              UnorderedElementsAre(Pair(
+                  2864899211444957078U,
+                  UnorderedElementsAre(FieldsAre(
+                      /*left_computations=*/UnorderedElementsAre(
+                          Pointee(Property(&HloComputation::name, "entry"))),
+                      /*right_computations=*/UnorderedElementsAre(Pointee(
+                          Property(&HloComputation::name, "entry"))))))));
+}
+
+TEST_F(HloDiffTest, FindConnectedComponentsWorks) {
+  // Create a module with entry computation containing the following structure:
+  // [Param p0]->[Param p2]->┌-------┐  ┌----------┐
+  //                         | add.1 |->| fusion.1 |->┌-------┐
+  // [Param p1]->[Param p3]->└-------┘  └----------┘  |       |  ┌------┐
+  //                                                  | add.3 |->| ROOT |
+  // [Param p4]->[Param p6]->┌-------┐  ┌----------┐  |       |  └------┘
+  //                         | add.2 |->| fusion.2 |->└-------┘
+  // [Param p5]->[Param p7]->└-------┘  └----------┘
+  const char* hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  fused_computation.1 {
+    p2 = s32[32,16]{0,1:T(1,128)} parameter(0)
+    p3 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    add.1 = s32[32,16]{0,1:T(1,128)} add(p2, p3)
+  }
+
+  fused_computation.2 {
+    p6 = s32[32,16]{0,1:T(1,128)} parameter(0)
+    p7 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    add.2 = s32[32,16]{0,1:T(1,128)} add(p6, p7)
+  }
+
+  ENTRY entry {
+    p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+    p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+    p4 = s32[32,16]{0, 1:T(1,128)} parameter(2)
+    p5 = s32[32,16]{0,1:T(1,128)} parameter(3)
+    fusion.1 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.1
+    fusion.2 = s32[32,16]{0,1:T(1,128)} fusion(p4,p5), kind=kLoop, calls=fused_computation.2
+    ROOT add.3 = s32[32,16]{0,1:T(1,128)} add(fusion.1, fusion.2)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  // Root nodes are matched by default before the matcher is called.
+  mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(),
+                                    MatcherType::kManual);
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "add.1"),
+                               GetNodeByName(*graph_r, "add.2"), *mappings,
+                               /*position_unchanged=*/true));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "add.2"),
+                               GetNodeByName(*graph_r, "add.1"), *mappings,
+                               /*position_unchanged=*/true));
+  ASSERT_NO_FATAL_FAILURE(
+      OverwriteMapInstructions(GetNodeByName(*graph_l, "add.3"),
+                               GetNodeByName(*graph_r, "add.3"), *mappings,
+                               /*position_unchanged=*/true));
+  std::unique_ptr<const DiffResult> diff_result =
+      ConstructDiffResult(*graph_l, *graph_r, *mappings);
+  std::unique_ptr<const DiffSummary> diff_summary =
+      ConstructDiffSummary(*graph_l, *graph_r, *mappings, *diff_result);
+  EXPECT_THAT(
+      diff_summary->grouped_computations,
+      UnorderedElementsAre(
+          Pair(2864899211444957078U,
+               UnorderedElementsAre(
+                   FieldsAre(/*left_computations=*/UnorderedElementsAre(
+                                 Pointee(Property(&HloComputation::name,
+                                                  "fused_computation.1"))),
+                             /*right_computations=*/UnorderedElementsAre(
+                                 Pointee(Property(&HloComputation::name,
+                                                  "fused_computation.2")))),
+                   FieldsAre(/*left_computations=*/UnorderedElementsAre(
+                                 Pointee(Property(&HloComputation::name,
+                                                  "fused_computation.2"))),
+                             /*right_computations=*/UnorderedElementsAre(
+                                 Pointee(Property(&HloComputation::name,
+                                                  "fused_computation.1")))))),
+          Pair(15473561031564762362U,
+               UnorderedElementsAre(FieldsAre(
+                   /*left_computations=*/UnorderedElementsAre(
+                       Pointee(Property(&HloComputation::name, "entry"))),
+                   /*right_computations=*/UnorderedElementsAre(
+                       Pointee(Property(&HloComputation::name, "entry"))))))));
+}
+
+}  // namespace
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.cc
new file mode 100644
index 00000000000000..5fa80ded11c392
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.cc
@@ -0,0 +1,117 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
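+
+// Matching pipeline sketch (a hedged editorial summary of FindMappings below,
+// not additional behavior): call graphs are matched first, then each matched
+// computation pair, and finally the greedy instruction-level matchers run:
+//
+//   MatchCallGraphs(left, right, *mappings);
+//   MatchComputationGraphs(left, right, node, matched_node, *mappings);
+//   for (auto& matcher : matchers) matcher->Match(*mappings);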
+
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_eval.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.h"
+#include "xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.h"
+#include "xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.h"
+#include "xla/service/call_graph.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+absl::StatusOr<std::unique_ptr<HloGumgraphMappings>> FindMappings(
+    const HloGumgraph& left, const HloGumgraph& right,
+    const MatchOptions& options = {}) {
+  LOG(INFO) << "Running Matchers";
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  mappings->MapInstructionsIfAbsent(&left.GetRoot(), &right.GetRoot(),
+                                    MatcherType::kManual);
+
+  MatchCallGraphs(left, right, *mappings);
+
+  TF_RETURN_IF_ERROR(left.GetCallGraph().VisitNodes(
+      [&](const CallGraphNode& node) {
+        if (auto it = mappings->left_to_right_computation_map.left.find(&node);
+            it != mappings->left_to_right_computation_map.left.end()) {
+          MatchComputationGraphs(left, right, node, *it->second, *mappings);
+        }
+        return absl::OkStatus();
+      },
+      /*visit_unreachable_nodes=*/true));
+
+  // Note: the matcher template arguments were lost in transit; the class
+  // names below are reconstructed from the MatcherType enum and may differ
+  // from the original patch.
+  std::vector<std::unique_ptr<HloGumgraphMatcher>> matchers;
+  matchers.push_back(
+      std::make_unique<GreedySubGraphExactMatcher>(&left, &right));
+  matchers.push_back(
+      std::make_unique<GreedyLimitedCandidatesBottomUpMatcher>(&left, &right));
+  if (options.use_top_down_matcher) {
+    matchers.push_back(std::make_unique<GreedyTopDownMatcher>(&left, &right));
+  }
+
+  for (auto& matcher : matchers) {
+    matcher->Match(*mappings);
+  }
+
+  return mappings;
+}
+}  // namespace
+
+absl::StatusOr<HloGumgraphDiffResults> ComputeDiff(const HloModule& left,
+                                                   const HloModule& right,
+                                                   const DiffOptions& options,
+                                                   bool run_eval) {
+  LOG(INFO) << "Initializing left module graph";
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<const HloGumgraph> left_graph,
+      HloGumgraph::Create(&left, options.fingerprint_options));
+  LOG(INFO) << "Initialized left module graph of size: "
+            << left_graph->GetNodeCount()
+            << " and height: " << left_graph->GetRoot().props.height;
+
+  LOG(INFO) << "Initializing right module graph";
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<const HloGumgraph> right_graph,
+      HloGumgraph::Create(&right, options.fingerprint_options));
+  LOG(INFO) << "Initialized right module graph of size: "
+            << right_graph->GetNodeCount()
+            << " and height: " << right_graph->GetRoot().props.height;
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloGumgraphMappings> mappings,
+                      FindMappings(*left_graph, *right_graph));
+
+  std::unique_ptr<const DiffResult> diff_result =
+      ConstructDiffResult(*left_graph, *right_graph, *mappings);
+  std::unique_ptr<const DiffSummary> diff_summary =
+      ConstructDiffSummary(*left_graph, *right_graph, *mappings, *diff_result);
+  std::unique_ptr<const DiffEval> diff_eval = nullptr;
+  if (run_eval) {
+    diff_eval = ComputeDiffEval(*left_graph, *right_graph, *mappings,
+                                *diff_result, *diff_summary);
+  }
+
+  return HloGumgraphDiffResults(
+      {std::move(diff_result), std::move(diff_summary), std::move(diff_eval)});
+}
+
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.h b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.h
new file mode 100644
index 00000000000000..9d8db2f2554626
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_HLO_GUMGRAPH_DIFF_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_HLO_GUMGRAPH_DIFF_H_
+
+#include <memory>
+
+#include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_eval.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+
+namespace xla {
+namespace hlo_diff {
+
+// Options for computing the diff between two HLO modules.
+struct DiffOptions {
+  HloGumgraphFingerprintOptions fingerprint_options;
+};
+
+struct HloGumgraphDiffResults {
+  std::unique_ptr<const DiffResult> diff_result;
+  std::unique_ptr<const DiffSummary> diff_summary;
+  std::unique_ptr<const DiffEval> diff_eval;
+};
+
+// Compares two HLO modules, then computes and returns their differences.
+absl::StatusOr<HloGumgraphDiffResults> ComputeDiff(
+    const HloModule& left, const HloModule& right,
+    const DiffOptions& options = {}, bool run_eval = false);
+
+}  // namespace hlo_diff
+}  // namespace xla
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_HLO_GUMGRAPH_DIFF_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff_test.cc
new file mode 100644
index 00000000000000..c24fe5b1629ab5
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff_test.cc
@@ -0,0 +1,85 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
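+
+// A minimal usage sketch of the API under test (hedged: module parsing is
+// elided; `ComputeDiff` is declared in hlo_gumgraph_diff.h above):
+//
+//   TF_ASSIGN_OR_RETURN(HloGumgraphDiffResults diff,
+//                       ComputeDiff(left_module, right_module));
+//   LogDiffSummary(*diff.diff_summary);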
+
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+class HloDiffTest : public HloTestBase {};
+
+TEST_F(HloDiffTest, ComputeDiffWorksWithoutEval) {
+  // Create a module with entry computation containing the following structure:
+  // [Param p0]->┌-------┐
+  //             | add.1 |
+  // [Param p1]->└-------┘
+  const char* hlo_string = R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+  p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  ROOT add.1 = s32[32,16]{0,1:T(1,128)} add(p0, p1)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto diff_result,
+      ComputeDiff(*module_l, *module_r, {}, /*run_eval=*/false));
+
+  EXPECT_NE(diff_result.diff_result, nullptr);
+  EXPECT_NE(diff_result.diff_summary, nullptr);
+  EXPECT_EQ(diff_result.diff_eval, nullptr);
+}
+
+TEST_F(HloDiffTest, ComputeDiffWorksWithEval) {
+  // Create a module with entry computation containing the following structure:
+  // [Param p0]->┌-------┐
+  //             | add.1 |
+  // [Param p1]->└-------┘
+  const char* hlo_string = R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+  p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  ROOT add.1 = s32[32,16]{0,1:T(1,128)} add(p0, p1)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(auto diff_result, ComputeDiff(*module_l, *module_r,
+                                                        {}, /*run_eval=*/true));
+
+  EXPECT_NE(diff_result.diff_result, nullptr);
+  EXPECT_NE(diff_result.diff_summary, nullptr);
+  EXPECT_NE(diff_result.diff_eval, nullptr);
+}
+
+}  // namespace
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h
new file mode 100644
index 00000000000000..4a97e924f9e0c8
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_HLO_GUMGRAPH_MAPPINGS_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_HLO_GUMGRAPH_MAPPINGS_H_
+
+#include <cstdint>
+
+#include "boost/bimap.hpp"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+#include "xla/service/call_graph.h"
+
+namespace xla::hlo_diff {
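+
+// Usage sketch for the mapping types defined below (hedged; `left_node` and
+// `right_node` are HloInstructionNodes from an HloGumgraph, construction
+// elided):
+//
+//   HloGumgraphMappings mappings;
+//   mappings.MapInstructionsIfAbsent(left_node, right_node,
+//                                    MatcherType::kManual);
+//   if (mappings.InstructionMapContainsLeft(left_node)) { /* matched */ }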
+
+// Type of matcher that matched two HloInstructionNodes.
+enum class MatcherType : std::uint8_t {
+  kNotSet,
+  kManual,
+  kComputationGraphExactFingerprintMatcher,
+  kComputationGraphExactSignatureMatcher,
+  kGreedySubGraphExactMatcher,
+  kGreedyDoubleCountedBottomUpMatcher,
+  kGreedyLimitedCandidatesBottomUpMatcher,
+  kGreedyLimitedCandidatesStaticSeedsBottomUpMatcher,
+  kGreedyTopDownMatcher,
+};
+
+// Computations with matching input parameters and output result are classified
+// as kSignature matches. kExact matches, on the other hand, are kSignature
+// matches that additionally have identical instructions in the computation
+// graph, i.e. the same computation fingerprint.
+enum class ComputationMatchType : std::uint8_t { kExact, kSignature };
+
+// Aggregated match characteristics of a mapped HloInstructionNode.
+struct HloInstructionNodeMappingProps {
+  bool unchanged = false;
+  MatcherType matcher_type = MatcherType::kNotSet;
+};
+
+// Aggregated match characteristics of a mapped CallGraphNode.
+struct HloCallGraphNodeMappingProps {
+  ComputationMatchType computation_match_type = ComputationMatchType::kExact;
+};
+
+using InstructionPair = boost::bimap<
+    const HloInstructionNode*, const HloInstructionNode*,
+    boost::bimaps::with_info<HloInstructionNodeMappingProps>>::value_type;
+
+using CallGraphNodePair = boost::bimap<
+    const CallGraphNode*, const CallGraphNode*,
+    boost::bimaps::with_info<HloCallGraphNodeMappingProps>>::value_type;
+
+// Mapped nodes between two HloGumgraphs.
+struct HloGumgraphMappings {
+  // Map between the left and right CallGraphNodes.
+  boost::bimap<const CallGraphNode*, const CallGraphNode*,
+               boost::bimaps::with_info<HloCallGraphNodeMappingProps>>
+      left_to_right_computation_map;
+
+  // A bi-directional map between the left and right HloInstructionNodes along
+  // with additional information about the mapping. Check out
+  // https://www.boost.org/doc/libs/1_79_0/libs/bimap/doc/html/boost_bimap/the_tutorial/additional_information.html
+  // for more details on the bimap API.
+  boost::bimap<const HloInstructionNode*, const HloInstructionNode*,
+               boost::bimaps::with_info<HloInstructionNodeMappingProps>>
+      left_to_right_instruction_map;
+
+  // Maps two nodes if they are not already mapped. Returns true if the
+  // mapping was performed.
+  inline bool MapInstructionsIfAbsent(const HloInstructionNode* left,
+                                      const HloInstructionNode* right,
+                                      const MatcherType matcher_type) {
+    auto [it, inserted] = left_to_right_instruction_map.insert(
+        InstructionPair(left, right, {.matcher_type = matcher_type}));
+
+    return inserted;
+  }
+
+  // Maps two CallGraphNodes if they are not already mapped. Returns true if
+  // the mapping was performed.
+  inline bool MapComputationsIfAbsent(
+      const CallGraphNode& left, const CallGraphNode& right,
+      const ComputationMatchType computation_match_type) {
+    auto [it, inserted] =
+        left_to_right_computation_map.insert(CallGraphNodePair(
+            &left, &right, {.computation_match_type = computation_match_type}));
+
+    return inserted;
+  }
+
+  // Returns true if the left node is mapped to a right node.
+  inline bool InstructionMapContainsLeft(
+      const HloInstructionNode* left_node) const {
+    return left_to_right_instruction_map.left.find(left_node) !=
+           left_to_right_instruction_map.left.end();
+  }
+
+  // Returns true if the right node is mapped to a left node.
+  inline bool InstructionMapContainsRight(
+      const HloInstructionNode* right_node) const {
+    return left_to_right_instruction_map.right.find(right_node) !=
+           left_to_right_instruction_map.right.end();
+  }
+};
+
+}  // namespace xla::hlo_diff
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_HLO_GUMGRAPH_MAPPINGS_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/BUILD
new file mode 100644
index 00000000000000..2e553b10e431bf
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/BUILD
@@ -0,0 +1,115 @@
+load("//xla:xla.default.bzl", "xla_cc_test")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = [
+        "//xla/hlo/tools/hlo_diff:__subpackages__",
+    ],
+    licenses = ["notice"],
+)
+
+cc_library(
+    name = "hlo_call_graph_matcher",
+    srcs = ["hlo_call_graph_matcher.cc"],
+    hdrs = ["hlo_call_graph_matcher.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/tools/hlo_diff:hlo_gumgraph_mappings",
+        "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph",
+        "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node",
+        "//xla/service:call_graph",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+xla_cc_test(
+    name = "hlo_call_graph_matcher_test",
+    srcs = ["hlo_call_graph_matcher_test.cc"],
+    deps = [
+        ":hlo_call_graph_matcher",
+        "//xla/hlo/testlib:verified_hlo_module",
+        "//xla/hlo/tools/hlo_diff:hlo_gumgraph_mappings",
+        "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph",
+        "//xla/hlo/tools/hlo_diff/utils:test_util",
+        "//xla/tests:hlo_test_base",
+        "//xla/tsl/platform:statusor",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "hlo_computation_graph_matcher",
+    srcs = ["hlo_computation_graph_matcher.cc"],
+    hdrs = ["hlo_computation_graph_matcher.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/tools/hlo_diff:hlo_gumgraph_mappings",
+        "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph",
+        "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node",
+        "//xla/service:call_graph",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/log",
+    ],
+)
+
+xla_cc_test(
+    name = "hlo_computation_graph_matcher_test",
+    srcs = ["hlo_computation_graph_matcher_test.cc"],
+    deps = [
+        ":hlo_computation_graph_matcher",
+        "//xla/hlo/testlib:verified_hlo_module",
+        "//xla/hlo/tools/hlo_diff:hlo_gumgraph_mappings",
+        "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph",
+        "//xla/hlo/tools/hlo_diff/utils:test_util",
+        "//xla/service:call_graph",
+        "//xla/tests:hlo_test_base",
+        "//xla/tsl/platform:statusor",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "hlo_gumgraph_matcher",
+    srcs = ["hlo_gumgraph_matcher.cc"],
+    hdrs = ["hlo_gumgraph_matcher.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/tools/hlo_diff:hlo_gumgraph_mappings",
+        "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph",
+        "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node",
+        "//xla/hlo/tools/hlo_diff/graph/utils:hlo_gumgraph_bfs",
+        "//xla/hlo/tools/hlo_diff/graph/utils:hlo_gumgraph_dfs",
+        "//xla/service:hlo_value",
+        "@com_google_absl//absl/base:nullability",
+        "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "hlo_gumgraph_matcher_test", + srcs = ["hlo_gumgraph_matcher_test.cc"], + deps = [ + ":hlo_gumgraph_matcher", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/tools/hlo_diff:hlo_gumgraph_mappings", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/utils:test_util", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.cc b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.cc new file mode 100644 index 00000000000000..e1355f02b822d2 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.cc @@ -0,0 +1,389 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h" +#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h" +#include "xla/service/call_graph.h" +#include "xla/shape.h" + +namespace xla::hlo_diff { +namespace { + +using VisitorFunction = absl::FunctionRef; + +// Sort the callees of the caller computation in order of the caller +// computation's post order instructions. +std::vector> +SortCalleesByCallerComputationPostOrder( + const absl::flat_hash_set& callees, + HloComputation::CachingPostOrder& cpo, const HloGumgraph& gumgraph) { + std::vector> callsites; + for (auto* instruction : cpo.PostOrder()) { + for (const auto* computation : instruction->called_computations()) { + const CallGraphNode* callee = + &gumgraph.GetCallGraph().GetNode(computation); + if (callees.contains(callee)) { + callsites.push_back(std::make_pair(instruction, callee)); + } + } + } + + return callsites; +} + +// Match left and right callee computations based on call_site instruction +// attributes: ex. op_code and instruction position in the caller computation +// post-order. 
+
+// Matches left and right callee computations based on callsite instruction
+// attributes, e.g. opcode and instruction position in the caller computation's
+// post order.
+void MapComputationCalleesWithSameFingerprintOrProgramShape(
+    const absl::flat_hash_set<const CallGraphNode*>& left_callees,
+    const absl::flat_hash_set<const CallGraphNode*>& right_callees,
+    HloComputation::CachingPostOrder& left_cpo,
+    HloComputation::CachingPostOrder& right_cpo, const HloGumgraph& left,
+    const HloGumgraph& right, HloGumgraphMappings& mappings,
+    const ComputationMatchType computation_match_type) {
+  // Don't attempt to match if there is a different number of callee
+  // computations, as it's difficult to disambiguate between the callees.
+  if (left_callees.size() != right_callees.size()) {
+    return;
+  }
+
+  // Match computations if there is exactly one callee on both sides.
+  if (left_callees.size() == 1 && right_callees.size() == 1) {
+    mappings.MapComputationsIfAbsent(*(*left_callees.begin()),
+                                     *(*right_callees.begin()),
+                                     computation_match_type);
+    return;
+  }
+
+  // For multiple callees, match them in the order of the caller computation's
+  // post-order instructions.
+  std::vector<std::pair<const HloInstruction*, const CallGraphNode*>>
+      left_callsites =
+          SortCalleesByCallerComputationPostOrder(left_callees, left_cpo, left);
+  std::vector<std::pair<const HloInstruction*, const CallGraphNode*>>
+      right_callsites = SortCalleesByCallerComputationPostOrder(
+          right_callees, right_cpo, right);
+
+  // Don't attempt to match if there is a different number of call sites, as
+  // it's difficult to disambiguate between the call sites.
+  if (left_callsites.size() != right_callsites.size()) {
+    return;
+  }
+
+  // Verify that all call sites' instruction opcodes and metadata op_names
+  // match.
+  for (int i = 0; i < left_callsites.size(); ++i) {
+    if (left_callsites[i].first->opcode() !=
+        right_callsites[i].first->opcode()) {
+      return;
+    }
+
+    if (left_callsites[i].first->metadata().op_name() !=
+        right_callsites[i].first->metadata().op_name()) {
+      return;
+    }
+  }
+
+  for (int i = 0; i < left_callsites.size(); ++i) {
+    mappings.MapComputationsIfAbsent(*left_callsites[i].second,
+                                     *right_callsites[i].second,
+                                     computation_match_type);
+  }
+}
+
+void MapCalledComputations(const HloInstruction* left_instruction,
+                           const HloInstruction* right_instruction,
+                           const HloGumgraph& left, const HloGumgraph& right,
+                           HloGumgraphMappings& mappings) {
+  for (int i = 0; i < left_instruction->called_computations().size(); ++i) {
+    mappings.MapComputationsIfAbsent(
+        left.GetCallGraph().GetNode(
+            left_instruction->called_computations().at(i)),
+        right.GetCallGraph().GetNode(
+            right_instruction->called_computations().at(i)),
+        ComputationMatchType::kSignature);
+  }
+}
+
+// Processes a single computation (CallGraphNode) in the left call graph. For
+// each called computation in this computation, attempts to find a matching
+// computation in the right call graph.
+void ProcessCallGraphNode(const CallGraphNode& left_computation,
+                          const HloGumgraph& left, const HloGumgraph& right,
+                          HloGumgraphMappings& mappings) {
+  // Only match called computations if the current computation is already
+  // matched.
+  auto it = mappings.left_to_right_computation_map.left.find(&left_computation);
+  if (it == mappings.left_to_right_computation_map.left.end() ||
+      left_computation.callees().empty()) {
+    return;
+  }
+
+  const CallGraphNode* right_computation = it->second;
+  HloComputation::CachingPostOrder left_cpo(left_computation.computation());
+  HloComputation::CachingPostOrder right_cpo(right_computation->computation());
+
+  // Phase 1: Match called computations to computations with matching
+  // computation fingerprints, i.e. exact matches.
+  absl::flat_hash_map<uint64_t, absl::flat_hash_set<const CallGraphNode*>>
+      left_callees_by_fingerprint, right_callees_by_fingerprint;
+  for (const HloComputation* callee : left_computation.callees()) {
+    CallGraphNodeProps left_props = left.AllComputationProps().at(callee);
+    left_callees_by_fingerprint[left_props.fingerprint].insert(
+        left_props.call_graph_node);
+  }
+  for (const HloComputation* callee : right_computation->callees()) {
+    CallGraphNodeProps right_props = right.AllComputationProps().at(callee);
+    right_callees_by_fingerprint[right_props.fingerprint].insert(
+        right_props.call_graph_node);
+  }
+
+  for (const auto& [fingerprint, left_callees] : left_callees_by_fingerprint) {
+    if (auto right_it = right_callees_by_fingerprint.find(fingerprint);
+        right_it != right_callees_by_fingerprint.end()) {
+      const absl::flat_hash_set<const CallGraphNode*>& right_callees =
+          right_it->second;
+      MapComputationCalleesWithSameFingerprintOrProgramShape(
+          left_callees, right_callees, left_cpo, right_cpo, left, right,
+          mappings, ComputationMatchType::kExact);
+    }
+  }
+
+  // Phase 2: Match left called computations to right computations if their
+  // callsite instructions have matching opcodes and metadata op_name.
+  absl::flat_hash_map<std::pair<HloOpcode, std::string>,
+                      std::vector<const HloInstruction*>>
+      left_instructions_by_op, right_instructions_by_op;
+  // First we filter out instructions whose called computations are already
+  // matched.
+  for (const HloInstruction* instruction : left_cpo.PostOrder()) {
+    bool all_called_computations_matched = true;
+    for (const HloComputation* callee : instruction->called_computations()) {
+      if (auto left_it = mappings.left_to_right_computation_map.left.find(
+              &left.GetCallGraph().GetNode(callee));
+          left_it == mappings.left_to_right_computation_map.left.end()) {
+        all_called_computations_matched = false;
+        break;
+      }
+    }
+
+    if (!all_called_computations_matched) {
+      std::pair<HloOpcode, std::string> op_code_and_name = std::make_pair(
+          instruction->opcode(), instruction->metadata().op_name());
+      left_instructions_by_op[op_code_and_name].push_back(instruction);
+    }
+  }
+
+  for (const HloInstruction* instruction : right_cpo.PostOrder()) {
+    bool all_called_computations_matched = true;
+    for (const HloComputation* callee : instruction->called_computations()) {
+      if (auto right_it = mappings.left_to_right_computation_map.right.find(
+              &right.GetCallGraph().GetNode(callee));
+          right_it == mappings.left_to_right_computation_map.right.end()) {
+        all_called_computations_matched = false;
+        break;
+      }
+    }
+
+    if (!all_called_computations_matched) {
+      std::pair<HloOpcode, std::string> op_code_and_name = std::make_pair(
+          instruction->opcode(), instruction->metadata().op_name());
+      right_instructions_by_op[op_code_and_name].push_back(instruction);
+    }
+  }
+
+  // Match called computations if their callsite instructions have matching
+  // opcodes and metadata op_name and there is exactly one called computation
+  // on both sides.
+  for (const auto& [op, left_instructions] : left_instructions_by_op) {
+    auto right_it = right_instructions_by_op.find(op);
+    if (right_it == right_instructions_by_op.end()) {
+      continue;
+    }
+
+    std::vector<const HloInstruction*> right_instructions = right_it->second;
+    if (left_instructions.size() == 1 && right_instructions.size() == 1) {
+      MapCalledComputations(left_instructions[0], right_instructions[0], left,
+                            right, mappings);
+    } else {
+      // Even if there are multiple call sites with matching opcodes and
+      // metadata op_name, we still attempt to match the called computations if
+      // they are of the same size, but only for While opcodes.
+      switch (op.first) {
+        case HloOpcode::kWhile: {
+          if (left_instructions.size() != right_instructions.size()) {
+            break;
+          }
+
+          for (int i = 0; i < left_instructions.size(); ++i) {
+            MapCalledComputations(left_instructions[i], right_instructions[i],
+                                  left, right, mappings);
+          }
+          break;
+        }
+        default:
+          break;
+      }
+    }
+  }
+
+  // Phase 3: Match children computations with matching opcode, metadata
+  // op_name and program shapes as signature matches.
+  absl::flat_hash_map<std::string, absl::flat_hash_set<const CallGraphNode*>>
+      unmatched_left_callees, unmatched_right_callees;
+  for (const HloComputation* callee : left_computation.callees()) {
+    if (auto left_it = mappings.left_to_right_computation_map.left.find(
+            &left.GetCallGraph().GetNode(callee));
+        left_it == mappings.left_to_right_computation_map.left.end()) {
+      const CallGraphNode& callee_node = left.GetCallGraph().GetNode(callee);
+      std::string opcode_and_name;
+      if (!callee_node.caller_callsites().empty()) {
+        const HloInstruction* caller_instruction =
+            callee_node.caller_callsites()[0].instruction();
+        opcode_and_name =
+            absl::StrCat(caller_instruction->opcode(),
+                         "::", caller_instruction->metadata().op_name());
+      } else {
+        LOG(WARNING) << "Callee node " << callee_node.computation()->name()
+                     << " has no caller callsites";
+      }
+      std::string opcode_name_shape = absl::StrCat(
+          opcode_and_name,
+          "::", callee->ComputeProgramShape(/*include_ids=*/false).ToString());
+      unmatched_left_callees[opcode_name_shape].insert(&callee_node);
+    }
+  }
+  for (const HloComputation* callee : right_computation->callees()) {
+    if (auto right_it = mappings.left_to_right_computation_map.right.find(
+            &right.GetCallGraph().GetNode(callee));
+        right_it == mappings.left_to_right_computation_map.right.end()) {
+      const CallGraphNode& callee_node = right.GetCallGraph().GetNode(callee);
+      std::string opcode_and_name;
+      if (callee_node.caller_callsites().size() == 1) {
+        const HloInstruction* caller_instruction =
+            callee_node.caller_callsites()[0].instruction();
+        opcode_and_name =
+            absl::StrCat(caller_instruction->opcode(),
+                         "::", caller_instruction->metadata().op_name());
+      }
+      std::string opcode_name_shape = absl::StrCat(
+          opcode_and_name,
+          "::", callee->ComputeProgramShape(/*include_ids=*/false).ToString());
+      unmatched_right_callees[opcode_name_shape].insert(&callee_node);
+    }
+  }
+
+  for (const auto& [shape, left_callees] : unmatched_left_callees) {
+    if (auto right_it = unmatched_right_callees.find(shape);
+        right_it != unmatched_right_callees.end()) {
+      const absl::flat_hash_set<const CallGraphNode*>&
+          program_shape_matched_right_callees = right_it->second;
+      MapComputationCalleesWithSameFingerprintOrProgramShape(
+          left_callees, program_shape_matched_right_callees, left_cpo,
+          right_cpo, left, right, mappings, ComputationMatchType::kSignature);
+    }
+  }
+}
+
+// Visits all CallGraphNodes in the call graph in BFS order.
+void VisitCallGraphNodesBfs(const CallGraph& call_graph,
+                            const CallGraphNode& root,
+                            VisitorFunction visit_fn) {
+  absl::flat_hash_set<const CallGraphNode*> visited;
+  std::queue<const CallGraphNode*> queue;
+  queue.push(&root);
+
+  while (!queue.empty()) {
+    const CallGraphNode* current_node = queue.front();
+    queue.pop();
+
+    if (!visited.insert(current_node).second) {
+      continue;
+    }
+
+    visit_fn(*current_node);
+
+    for (const HloComputation* callee : current_node->callees()) {
+      queue.push(&call_graph.GetNode(callee));
+    }
+  }
+}
+
+}  // namespace
+
+void MatchCallGraphs(const HloGumgraph& left, const HloGumgraph& right,
+                     HloGumgraphMappings& mappings) {
+  // Match the entry computations as signature matches. This optimizes for the
+  // common case, i.e. users comparing similar programs whose input/output
+  // parameters are often identical or very similar.
+  ComputationMatchType entry_computation_match_type =
+      ComputationMatchType::kSignature;
+  if (left.AllComputationProps()
+          .at(left.GetHloModule().entry_computation())
+          .fingerprint == right.AllComputationProps()
+                              .at(right.GetHloModule().entry_computation())
+                              .fingerprint) {
+    entry_computation_match_type = ComputationMatchType::kExact;
+  }
+  mappings.MapComputationsIfAbsent(
+      left.GetCallGraph().GetNode(left.GetHloModule().entry_computation()),
+      right.GetCallGraph().GetNode(right.GetHloModule().entry_computation()),
+      entry_computation_match_type);
+
+  // Traverse the call graph of the left HloGumgraph in BFS order. For each
+  // visited computation node in the left call graph we attempt to find a
+  // matching computation node in the right call graph. Two computation nodes
+  // are only matched if their parent computations are already matched.
+  VisitCallGraphNodesBfs(
+      left.GetCallGraph(),
+      left.GetCallGraph().GetNode(left.GetHloModule().entry_computation()),
+      [&](const CallGraphNode& left_node) {
+        return ProcessCallGraphNode(left_node, left, right, mappings);
+      });
+
+  int signature_match_count = 0, exact_match_count = 0;
+  for (auto it = mappings.left_to_right_computation_map.left.begin();
+       it != mappings.left_to_right_computation_map.left.end(); ++it) {
+    if (it->info.computation_match_type == ComputationMatchType::kSignature) {
+      ++signature_match_count;
+    } else {
+      ++exact_match_count;
+    }
+  }
+  LOG(INFO) << "Finished matching call graphs for "
+            << left.GetHloModule().name() << ": "
+            << left.GetCallGraph().nodes().size() << " and "
+            << right.GetHloModule().name() << ": "
+            << right.GetCallGraph().nodes().size()
+            << ". Total signature matched computations: "
+            << signature_match_count
+            << ". Total exact matched computations: " << exact_match_count;
}

+}  // namespace xla::hlo_diff
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.h b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.h
new file mode 100644
index 00000000000000..23759bc49b78a4
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_MATCHERS_HLO_CALL_GRAPH_MATCHER_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_MATCHERS_HLO_CALL_GRAPH_MATCHER_H_
+
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+
+namespace xla::hlo_diff {
+
+// Matches similar computations between two HloGumgraphs (HloModules).
+// Computations with semantically identical input and output parameters are
+// matched as `kSignature` matches. If a `kSignature` match additionally has
+// semantically identical instructions, it is classified as a `kExact` match.
+//
+// The matcher does not match the instructions within the matched computations;
+// that is the responsibility of the subsequent matchers.
+void MatchCallGraphs(const HloGumgraph& left, const HloGumgraph& right,
+                     HloGumgraphMappings& mappings);
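+
+// Usage sketch (hedged; graph construction is elided):
+//
+//   HloGumgraphMappings mappings;
+//   MatchCallGraphs(left_graph, right_graph, mappings);
+//   // mappings.left_to_right_computation_map now holds computation pairs.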
+
+}  // namespace xla::hlo_diff
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_MATCHERS_HLO_CALL_GRAPH_MATCHER_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher_test.cc
new file mode 100644
index 00000000000000..9bb1a36cfe73a5
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher_test.cc
@@ -0,0 +1,400 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/matchers/hlo_call_graph_matcher.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/hlo/tools/hlo_diff/utils/test_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla::hlo_diff {
+namespace {
+
+using ::testing::Pair;
+using ::testing::UnorderedElementsAre;
+
+class HloCallGraphMatcherTest : public HloTestBase {};
+
+TEST_F(HloCallGraphMatcherTest, ExactFingerprintMatches) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> left_module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+fused_computation.1 {
+  p2 = s32[32,16]{0,1:T(1,128)} parameter(0)
+  p3 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  add.1 = s32[32,16]{0,1:T(1,128)} add(p2, p3)
+}
+
+fused_computation.2 {
+  p4 = s32[32,16]{0,1:T(1,128)} parameter(0)
+  p5 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  add.2 = s32[32,16]{0,1:T(1,128)} add(p4, p5)
+}
+
+ENTRY entry {
+  p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+  p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  fusion.1 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.1
+  fusion.2 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.2
+  ROOT add = s32[32,16]{0,1:T(1,128)} add(fusion.1, fusion.2)
+}
+)"));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> right_module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+fused_computation.11 {
+  p21 = s32[32,16]{0,1:T(1,128)} parameter(0)
+  p31 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  add.11 = s32[32,16]{0,1:T(1,128)} add(p21, p31)
+}
+
+fused_computation.21 {
+  p41 = s32[32,16]{0,1:T(1,128)} parameter(0)
+  p51 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  add.21 = s32[32,16]{0,1:T(1,128)} add(p41, p51)
+}
+
+ENTRY entry {
+  p01 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+  p11 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  fusion.11 = s32[32,16]{0,1:T(1,128)} fusion(p01,p11), kind=kLoop, calls=fused_computation.11
+  fusion.21 = s32[32,16]{0,1:T(1,128)} fusion(p01,p11), kind=kLoop, calls=fused_computation.21
+  ROOT add11 = s32[32,16]{0,1:T(1,128)} add(fusion.11, fusion.21)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> left_gumgraph,
+                          HloGumgraph::Create(left_module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> right_gumgraph,
+                          HloGumgraph::Create(right_module.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+
+  MatchCallGraphs(*left_gumgraph, *right_gumgraph, *mappings);
+
+  auto matched_computations = ExtractMappedComputationNames(*mappings);
+  auto match_type = ExtractComputationMatchType(*mappings);
+  EXPECT_THAT(
+      matched_computations,
+      UnorderedElementsAre(Pair("fused_computation.1", "fused_computation.11"),
+                           Pair("fused_computation.2", "fused_computation.21"),
+                           Pair("entry", "entry")));
+  EXPECT_THAT(match_type,
+              UnorderedElementsAre(
+                  Pair("fused_computation.1", ComputationMatchType::kExact),
+                  Pair("fused_computation.2", ComputationMatchType::kExact),
+                  Pair("entry", ComputationMatchType::kExact)));
+}
+
+TEST_F(HloCallGraphMatcherTest, UnequalFingerprintMatchesNotMatched) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> left_module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+fused_computation.1 {
+  p2 = s32[32,16]{0,1:T(1,128)} parameter(0)
+  p3 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  add.1 = s32[32,16]{0,1:T(1,128)} add(p2, p3)
+}
+
+fused_computation.2 {
+  p4 = s32[32,16]{0,1:T(1,128)} parameter(0)
+  p5 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  add.2 = s32[32,16]{0,1:T(1,128)} add(p4, p5)
+}
+
+ENTRY entry {
+  p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+  p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  fusion.1 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.1
+  fusion.2 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.2
+  ROOT add = s32[32,16]{0,1:T(1,128)} add(fusion.1, fusion.2)
+}
+)"));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> right_module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+fused_computation.11 {
+  p21 = s32[32,16]{0,1:T(1,128)} parameter(0)
+  p31 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  add.11 = s32[32,16]{0,1:T(1,128)} add(p21, p31)
+}
+
+ENTRY entry {
+  p01 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+  p11 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  ROOT fusion.11 = s32[32,16]{0,1:T(1,128)} fusion(p01,p11), kind=kLoop, calls=fused_computation.11
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> left_gumgraph,
+                          HloGumgraph::Create(left_module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> right_gumgraph,
+                          HloGumgraph::Create(right_module.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+
+  MatchCallGraphs(*left_gumgraph, *right_gumgraph, *mappings);
+
+  auto matched_computations = ExtractMappedComputationNames(*mappings);
+  auto match_type = ExtractComputationMatchType(*mappings);
+  EXPECT_THAT(matched_computations,
+              UnorderedElementsAre(Pair("entry", "entry")));
+  EXPECT_THAT(match_type, UnorderedElementsAre(
+                              Pair("entry", ComputationMatchType::kSignature)));
+}
+
+TEST_F(HloCallGraphMatcherTest, MultipleWhileInstructionsMatched) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> left_module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+body_1 {
+  prev_1 = s32[2] parameter(0)
+  all-gather_1 = s32[4] all-gather(s32[2] prev_1), replica_groups={}, dimensions={0}, backend_config="{}"
+  ROOT slice_1 = s32[2] slice(all-gather_1), slice={[0:2]}
+}
+
+condition_1 {
+  prev_1 = s32[2] parameter(0)
+  constant_1 = pred[] constant(true)
+  ROOT copy_1 = pred[] copy(constant_1)
+}
+
+body_2 {
+  prev_2 = s32[2] parameter(0)
+  all-gather_2 = s32[4] all-gather(s32[2] prev_2), replica_groups={}, dimensions={0}, backend_config="{}"
+  ROOT slice_2 = s32[2] slice(all-gather_2), slice={[0:2]}
+}
+
+condition_2 {
+  prev_2 = s32[2] parameter(0)
+  constant_2 = pred[] constant(true)
+  ROOT copy_2 = pred[] copy(constant_2)
+}
+
+body_3 {
+  prev_3 = s32[2] parameter(0)
+  all-gather_3 = s32[4] all-gather(s32[2] prev_3), replica_groups={}, dimensions={0}, backend_config="{}"
+  ROOT slice_3 = s32[2] slice(all-gather_3), slice={[0:2]}
+}
+
+condition_3 {
+  prev_3 = s32[2] parameter(0)
+  constant_3 = pred[] constant(true)
+  ROOT copy_3 = pred[] copy(constant_3)
+}
+
+ENTRY entry {
+  constant = s32[2] constant({0,0})
+  while.1 = s32[2] while(s32[2] constant), condition=condition_1, body=body_1
+  while.2 = s32[2] while(s32[2] constant), condition=condition_2, body=body_2, metadata={op_name="while-activations"}
+  while.3 = s32[2] while(s32[2] constant), condition=condition_3, body=body_3
+  add.1 = s32[2] add(while.1, while.2)
+  ROOT add.2 = s32[2] add(add.1, while.3)
+}
+)"));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> right_module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+body_1 {
+  prev_1 = s32[2] parameter(0)
+  all-gather_1 = s32[4] all-gather(s32[2] prev_1), replica_groups={}, dimensions={0}, backend_config="{}"
+  ROOT slice_1 = s32[2] slice(all-gather_1), slice={[0:2]}
+}
+
+condition_1 {
+  prev_1 = s32[2] parameter(0)
+  constant_1 = pred[] constant(true)
+  ROOT copy_1 = pred[] copy(constant_1)
+}
+
+body_2 {
+  prev_2 = s32[2] parameter(0)
+  all-gather_2 = s32[4] all-gather(s32[2] prev_2), replica_groups={}, dimensions={0}, backend_config="{}"
+  ROOT slice_2 = s32[2] slice(all-gather_2), slice={[0:2]}
+}
+
+condition_2 {
+  prev_2 = s32[2] parameter(0)
+  constant_2 = pred[] constant(true)
+  ROOT copy_2 = pred[] copy(constant_2)
+}
+
+body_3 {
+  prev_3 = s32[2] parameter(0)
+  all-gather_3 = s32[4] all-gather(s32[2] prev_3), replica_groups={}, dimensions={0}, backend_config="{}"
+  ROOT slice_3 = s32[2] slice(all-gather_3), slice={[0:2]}
+}
+
+condition_3 {
+  prev_3 = s32[2] parameter(0)
+  constant_3 = pred[] constant(true)
+  ROOT copy_3 = pred[] copy(constant_3)
+}
+
+ENTRY entry {
+  constant = s32[2] constant({0,0})
+  while.1 = s32[2] while(s32[2] constant), condition=condition_1, body=body_1, metadata={op_name="while-activations"}
+  while.2 = s32[2] while(s32[2] constant), condition=condition_2, body=body_2
+  while.3 = s32[2] while(s32[2] constant), condition=condition_3, body=body_3
+  add.1 = s32[2] add(while.1, while.2)
+  ROOT add.2 = s32[2] add(add.1, while.3)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> left_gumgraph,
+                          HloGumgraph::Create(left_module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> right_gumgraph,
+                          HloGumgraph::Create(right_module.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+
+  MatchCallGraphs(*left_gumgraph, *right_gumgraph, *mappings);
+
+  auto matched_computations = ExtractMappedComputationNames(*mappings);
+  auto match_type = ExtractComputationMatchType(*mappings);
+  EXPECT_THAT(matched_computations,
+              UnorderedElementsAre(
+                  Pair("body_1", "body_2"), Pair("body_2", "body_1"),
+                  Pair("body_3", "body_3"), Pair("condition_1", "condition_2"),
+                  Pair("condition_2", "condition_1"),
+                  Pair("condition_3", "condition_3"), Pair("entry", "entry")));
+  EXPECT_THAT(match_type,
+              UnorderedElementsAre(
+                  Pair("body_1", ComputationMatchType::kSignature),
+                  Pair("body_2", ComputationMatchType::kSignature),
+                  Pair("body_3", ComputationMatchType::kSignature),
Pair("condition_1", ComputationMatchType::kSignature), + Pair("condition_2", ComputationMatchType::kSignature), + Pair("condition_3", ComputationMatchType::kSignature), + Pair("entry", ComputationMatchType::kExact))); +} + +TEST_F(HloCallGraphMatcherTest, ExactSignatureMatches) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr left_module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +fused_computation.1 { + p.2 = s32[32,16]{0,1:T(1,128)} parameter(0) + p.3 = s32[32,16]{0,1:T(1,128)} parameter(1) + add.1 = s32[32,16]{0,1:T(1,128)} add(p.2, p.3) +} + +fused_computation.2 { + p.4 = s32[] parameter(0) + p.5 = s32[] parameter(1) + add.2 = s32[] add(p.4, p.5) +} + +fused_computation.3 { + p.6 = s32[32,16]{0,1:T(1,128)} parameter(0) + p.7 = s32[32,16]{0,1:T(1,128)} parameter(1) + add.3 = s32[32,16]{0,1:T(1,128)} add(p.6, p.7) +} + +fused_computation.4 { + p.8 = s32[] parameter(0) + p.9 = s32[] parameter(1) + add.4 = s32[] add(p.8, p.9) +} + +ENTRY entry { + p.0 = s32[32,16]{0, 1:T(1,128)} parameter(0) + p.1 = s32[32,16]{0,1:T(1,128)} parameter(1) + p.2 = s32[] parameter(2) + p.3 = s32[] parameter(3) + fusion.1 = s32[32,16]{0,1:T(1,128)} fusion(p.0, p.1), kind=kLoop, calls=fused_computation.1 + fusion.2 = s32[] fusion(p.2, p.3), kind=kLoop, calls=fused_computation.2 + fusion.3 = s32[32,16]{0,1:T(1,128)} fusion(p.0, p.1), kind=kLoop, calls=fused_computation.3, metadata={op_name="add_fusion"} + fusion.4 = s32[] fusion(p.2, p.3), kind=kLoop, calls=fused_computation.4 +} +)")); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr right_module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +fused_computation.11 { + p.21 = s32[32,16]{0,1:T(1,128)} parameter(0) + p.31 = s32[32,16]{0,1:T(1,128)} parameter(1) + subtract.11 = s32[32,16]{0,1:T(1,128)} subtract(p.21, p.31) +} + +fused_computation.21 { + p.41 = s32[] parameter(0) + p.51 = s32[] parameter(1) + subtract.21 = s32[] subtract(p.41, p.51) +} + +fused_computation.31 { + p.21 = s32[32,16]{0,1:T(1,128)} parameter(0) + p.31 = s32[32,16]{0,1:T(1,128)} parameter(1) + subtract.11 = s32[32,16]{0,1:T(1,128)} subtract(p.21, p.31) +} + +fused_computation.41 { + p.41 = s32[] parameter(0) + p.51 = s32[] parameter(1) + subtract.21 = s32[] subtract(p.41, p.51) +} + +ENTRY entry { + p.01 = s32[32,16]{0, 1:T(1,128)} parameter(0) + p.11 = s32[32,16]{0,1:T(1,128)} parameter(1) + p.21 = s32[] parameter(2) + p.31 = s32[] parameter(3) + fusion.11 = s32[32,16]{0,1:T(1,128)} fusion(p.01,p.11), kind=kLoop, calls=fused_computation.11 + fusion.21 = s32[] fusion(p.21, p.31), kind=kLoop, calls=fused_computation.21 + fusion.31 = s32[32,16]{0,1:T(1,128)} fusion(p.01,p.11), kind=kLoop, calls=fused_computation.31, metadata={op_name="add_fusion"} + fusion.41 = s32[] fusion(p.21, p.31), kind=kLoop, calls=fused_computation.41 + +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr left_gumgraph, + HloGumgraph::Create(left_module.get())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr right_gumgraph, + HloGumgraph::Create(right_module.get())); + auto mappings = std::make_unique(); + + MatchCallGraphs(*left_gumgraph, *right_gumgraph, *mappings); + + auto matched_computations = ExtractMappedComputationNames(*mappings); + auto match_type = ExtractComputationMatchType(*mappings); + EXPECT_THAT( + matched_computations, + UnorderedElementsAre(Pair("fused_computation.1", "fused_computation.11"), + Pair("fused_computation.2", "fused_computation.21"), + Pair("fused_computation.3", "fused_computation.31"), + Pair("fused_computation.4", "fused_computation.41"), 
+ Pair("entry", "entry"))); + EXPECT_THAT(match_type, + UnorderedElementsAre( + Pair("fused_computation.1", ComputationMatchType::kSignature), + Pair("fused_computation.2", ComputationMatchType::kSignature), + Pair("fused_computation.3", ComputationMatchType::kSignature), + Pair("fused_computation.4", ComputationMatchType::kSignature), + Pair("entry", ComputationMatchType::kExact))); +} + +} // namespace +} // namespace xla::hlo_diff diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.cc b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.cc new file mode 100644 index 00000000000000..559ebccfd4ad05 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.cc @@ -0,0 +1,401 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_print_options.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h" +#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h" +#include "xla/service/call_graph.h" + +namespace xla { +namespace hlo_diff { +namespace { + +// Function to compute property match score between two instructions. +// Compares various properties of the instructions and returns a int score. +// Higher the score, more similar the instructions are. +using PropertyMatchesFn = absl::FunctionRef; + +// Returns true if all the users of the left instruction are matched to the +// right instruction users by fingerprint. +bool AllInstructionUsersAreMatched(const HloInstructionNode* left, + const HloInstructionNode* right) { + absl::flat_hash_set left_users, right_users; + for (const HloInstructionNode* user : left->parents) { + left_users.insert(user); + } + for (const HloInstructionNode* user : right->parents) { + right_users.insert(user); + } + + for (const HloInstructionNode* user : left_users) { + if (!right_users.contains(user)) { + return false; + } + } + return true; +} + +// Returns count of matched properties between two leaf instructions - i.e. +// parameter or constant. 
+// Returns count of matched properties between two leaf instructions - i.e.
+// parameter or constant.
+int LeafPropertyMatches(const HloInstructionNode* left,
+                        const HloInstructionNode* right) {
+  int match_score = 0;
+  if (left->instruction->shape().has_layout() &&
+      right->instruction->shape().has_layout() &&
+      (left->instruction->shape().layout() ==
+       right->instruction->shape().layout())) {
+    ++match_score;
+  }
+
+  if (left->instruction->has_sharding() &&
+      right->instruction->has_sharding() &&
+      (left->instruction->sharding() == right->instruction->sharding())) {
+    ++match_score;
+  }
+
+  if (!left->instruction->metadata().op_name().empty() &&
+      !right->instruction->metadata().op_name().empty() &&
+      (left->instruction->metadata().op_name() ==
+       right->instruction->metadata().op_name())) {
+    ++match_score;
+  }
+
+  if (!left->instruction->metadata().source_file().empty() &&
+      !right->instruction->metadata().source_file().empty() &&
+      (left->instruction->metadata().source_file() ==
+       right->instruction->metadata().source_file())) {
+    ++match_score;
+  }
+
+  if ((left->instruction->metadata().source_line() != 0) &&
+      (right->instruction->metadata().source_line() != 0) &&
+      (left->instruction->metadata().source_line() ==
+       right->instruction->metadata().source_line())) {
+    ++match_score;
+  }
+
+  if (AllInstructionUsersAreMatched(left, right)) {
+    ++match_score;
+  }
+
+  return match_score;
+}
+
+// Returns count of matched properties between two parameter instructions,
+// based on their sharding, layout, metadata, name and users.
+int ParamPropertyMatches(const HloInstructionNode* left,
+                         const HloInstructionNode* right) {
+  int match_score = LeafPropertyMatches(left, right);
+
+  // A Parameter's name and parameter number are typically consistently
+  // generated by the frameworks. But in some cases, the name and parameter
+  // numbers might be differently generated for the same instruction; if both
+  // are the same, we can be pretty confident that they are the same
+  // instruction.
+  if ((left->instruction->name() == right->instruction->name()) &&
+      (left->instruction->parameter_number() ==
+       right->instruction->parameter_number())) {
+    ++match_score;
+  }
+
+  return match_score;
+}
+
+// Returns count of matched properties between two constant instructions.
+int ConstantPropertyMatches(const HloInstructionNode* left,
+                            const HloInstructionNode* right) {
+  int match_score = LeafPropertyMatches(left, right);
+
+  // Use the canonical print options as a fingerprint that ignores float
+  // values.
+  if (left->instruction->ToString(HloPrintOptions::Canonical()) ==
+      right->instruction->ToString(HloPrintOptions::Canonical())) {
+    ++match_score;
+  }
+
+  if (left->parents.size() == right->parents.size()) {
+    ++match_score;
+  }
+  return match_score;
+}
+
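+// Worked example (for exposition only, not exercised by this code): two
+// constants `c1 = bf16[2] constant({1.1, 2.2})` and
+// `c9 = bf16[2] constant({1.1, 2.2})` print identically under
+// HloPrintOptions::Canonical() despite their different names, so
+// ConstantPropertyMatches(c1_node, c9_node) earns the canonical-string point
+// on top of any LeafPropertyMatches points.
+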
+// Match instructions with multiple match candidates using similarity measures.
+void MatchInstructionsWithMultipleCandidates(
+    const absl::flat_hash_set<const HloInstructionNode*>& left_instructions,
+    const absl::flat_hash_set<const HloInstructionNode*>& right_instructions,
+    HloGumgraphMappings& mappings, PropertyMatchesFn property_matches_fn,
+    const MatcherType& matcher_type) {
+  for (const HloInstructionNode* left : left_instructions) {
+    double max_match_score = 0.0;
+    std::vector<const HloInstructionNode*> right_candidates;
+    for (const HloInstructionNode* right : right_instructions) {
+      double similarity = property_matches_fn(left, right);
+      if (similarity > max_match_score) {
+        max_match_score = similarity;
+        right_candidates.clear();
+        right_candidates.push_back(right);
+      } else if (similarity == max_match_score) {
+        right_candidates.push_back(right);
+      }
+    }
+
+    // Avoid matching instructions with multiple candidates.
+    if (right_candidates.size() == 1) {
+      mappings.MapInstructionsIfAbsent(left, right_candidates[0],
+                                       matcher_type);
+    }
+  }
+}
+
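+// Tie-breaking example (hypothetical scores): if three right candidates score
+// {3, 3, 1} against a left instruction, the two top candidates tie,
+// right_candidates.size() == 2, and the instruction is deliberately left
+// unmatched for later matchers to resolve; with scores {4, 3, 1} the single
+// best candidate is mapped.
+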
+// Find optimal matches between the left and right leaf instructions - i.e.
+// parameter or constant.
+// This function is called when attempting to map two computations. The goal is
+// to establish a mapping between corresponding leaf instructions from the
+// 'left_instructions' and 'right_instructions' sets. These sets are derived
+// from the two computations being mapped.
+void MatchLeafInstructions(
+    const absl::flat_hash_set<const HloInstructionNode*>& left_instructions,
+    const absl::flat_hash_set<const HloInstructionNode*>& right_instructions,
+    HloGumgraphMappings& mappings, PropertyMatchesFn property_matches_fn,
+    const MatcherType& matcher_type) {
+  absl::flat_hash_set<const HloInstructionNode*> matched_instructions;
+
+  // Phase 0: Direct mapping if only one instruction in each set.
+  if (left_instructions.size() == 1 && right_instructions.size() == 1) {
+    mappings.MapInstructionsIfAbsent(*left_instructions.begin(),
+                                     *right_instructions.begin(),
+                                     matcher_type);
+    return;  // Early return after direct mapping.
+  }
+
+  // Phase 1: Map instructions with the same shape and metadata op name if it
+  // is specified. This name is often unique within a computation and specified
+  // by the frameworks. Note that for XLA generated computations, the metadata
+  // is not consistently specified.
+  for (const HloInstructionNode* left_instruction : left_instructions) {
+    if (left_instruction->instruction->metadata().op_name().empty()) {
+      continue;
+    }
+    int candidates_found = 0;
+    const HloInstructionNode* candidate = nullptr;
+
+    for (const HloInstructionNode* right_instruction : right_instructions) {
+      bool same_shape =
+          left_instruction->instruction->shape().ToString(
+              /*print_layout=*/false) ==
+          right_instruction->instruction->shape().ToString(
+              /*print_layout=*/false);
+      bool same_op_name =
+          left_instruction->instruction->metadata().op_name() ==
+          right_instruction->instruction->metadata().op_name();
+      if (same_shape && same_op_name) {
+        ++candidates_found;
+        candidate = right_instruction;
+      }
+    }
+
+    // Avoid matching instructions with multiple candidates.
+    if (candidates_found == 1) {
+      mappings.MapInstructionsIfAbsent(left_instruction, candidate,
+                                       matcher_type);
+      matched_instructions.insert(left_instruction);
+      matched_instructions.insert(candidate);
+    }
+  }
+
+  // Phase 2: Group instructions by shape.
+  // 2.1: Match unique instructions with the same shape.
+  // 2.2: Match instructions with multiple candidates using similarity
+  //      measures.
+  absl::flat_hash_map<std::string,
+                      absl::flat_hash_set<const HloInstructionNode*>>
+      left_instructions_by_shape;
+  for (const HloInstructionNode* instruction : left_instructions) {
+    if (!matched_instructions.contains(instruction)) {
+      left_instructions_by_shape[instruction->instruction->shape().ToString(
+                                     /*print_layout=*/false)]
+          .insert(instruction);
+    }
+  }
+
+  absl::flat_hash_map<std::string,
+                      absl::flat_hash_set<const HloInstructionNode*>>
+      right_instructions_by_shape;
+  for (const HloInstructionNode* instruction : right_instructions) {
+    if (!matched_instructions.contains(instruction)) {
+      right_instructions_by_shape[instruction->instruction->shape().ToString(
+                                      /*print_layout=*/false)]
+          .insert(instruction);
+    }
+  }
+
+  for (const auto& [shape, shape_left_instructions] :
+       left_instructions_by_shape) {
+    if (auto it = right_instructions_by_shape.find(shape);
+        it != right_instructions_by_shape.end()) {
+      absl::flat_hash_set<const HloInstructionNode*> shape_right_instructions =
+          it->second;
+      // Phase 2.1: Match unique instructions with the same shape.
+      if (shape_left_instructions.size() == 1 &&
+          shape_right_instructions.size() == 1) {
+        mappings.MapInstructionsIfAbsent(*shape_left_instructions.begin(),
+                                         *shape_right_instructions.begin(),
+                                         matcher_type);
+      } else {
+        // Phase 2.2: Match instructions with multiple candidates using
+        // similarity measures.
+        MatchInstructionsWithMultipleCandidates(
+            shape_left_instructions, shape_right_instructions, mappings,
+            property_matches_fn, matcher_type);
+      }
+    }
+  }
+}
+
+// Match parameter instructions between the left and right computations.
+void MatchComputationParams(const HloGumgraph& left, const HloGumgraph& right,
+                            const CallGraphNode& left_computation,
+                            const CallGraphNode& right_computation,
+                            HloGumgraphMappings& mappings,
+                            const MatcherType& matcher_type) {
+  absl::flat_hash_set<const HloInstructionNode*> left_params, right_params;
+  for (const HloInstruction* param :
+       left_computation.computation()->parameter_instructions()) {
+    left_params.insert(left.GetNode(param));
+  }
+  for (const HloInstruction* param :
+       right_computation.computation()->parameter_instructions()) {
+    right_params.insert(right.GetNode(param));
+  }
+
+  MatchLeafInstructions(left_params, right_params, mappings,
+                        std::ref(ParamPropertyMatches), matcher_type);
+}
+
+// Match constant instructions between the left and right computations.
+void MatchComputationConstants(const HloGumgraph& left,
+                               const HloGumgraph& right,
+                               const CallGraphNode& left_computation,
+                               const CallGraphNode& right_computation,
+                               HloGumgraphMappings& mappings,
+                               const MatcherType& matcher_type) {
+  absl::flat_hash_set<const HloInstructionNode*> left_constants,
+      right_constants;
+  for (const HloInstruction* instruction :
+       left_computation.computation()->instructions()) {
+    if (instruction->IsConstant()) {
+      left_constants.insert(left.GetNode(instruction));
+    }
+  }
+  for (const HloInstruction* instruction :
+       right_computation.computation()->instructions()) {
+    if (instruction->IsConstant()) {
+      right_constants.insert(right.GetNode(instruction));
+    }
+  }
+
+  MatchLeafInstructions(left_constants, right_constants, mappings,
+                        std::ref(ConstantPropertyMatches), matcher_type);
+}
+
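+// Phase walk-through for MatchLeafInstructions above (shapes hypothetical):
+// with left parameters {f32[10] a, f32[10] b, f32[20] c} and right parameters
+// {f32[10] x, f32[10] y, f32[20] z}, phase 2.1 maps c -> z outright since the
+// f32[20] shape group is a singleton on both sides, while the ambiguous
+// f32[10] group falls through to the property-score comparison of phase 2.2.
+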
+// Match the call site instruction and its operands for a matched left and
+// right computation.
+void MatchCallSites(const HloGumgraph& left, const HloGumgraph& right,
+                    const CallGraphNode& left_computation,
+                    const CallGraphNode& right_computation,
+                    HloGumgraphMappings& mappings) {
+  // Only match call sites if both computations are called from exactly one
+  // call site. In case a computation is called from multiple call sites, we
+  // cannot disambiguate between the call sites. The subsequent matchers
+  // should be able to find the matches between the call sites in such cases.
+  if (left_computation.caller_callsites().size() != 1 ||
+      right_computation.caller_callsites().size() != 1) {
+    return;
+  }
+
+  const CallSite& left_call_site =
+      *left_computation.caller_callsites().begin();
+  const CallSite& right_call_site =
+      *right_computation.caller_callsites().begin();
+
+  // Match the call site instruction.
+  mappings.MapInstructionsIfAbsent(
+      left.GetNode(left_call_site.instruction()),
+      right.GetNode(right_call_site.instruction()),
+      MatcherType::kComputationGraphExactSignatureMatcher);
+}
+
+}  // namespace
+
+void MatchComputationGraphs(const HloGumgraph& left, const HloGumgraph& right,
+                            const CallGraphNode& left_computation,
+                            const CallGraphNode& right_computation,
+                            HloGumgraphMappings& mappings) {
+  auto it =
+      mappings.left_to_right_computation_map.left.find(&left_computation);
+  if (it == mappings.left_to_right_computation_map.left.end()) {
+    return;
+  }
+
+  MatchCallSites(left, right, left_computation, right_computation, mappings);
+
+  // If the two computations are exact matches, we can match all
+  // instructions in the two computations.
+  if (it->info.computation_match_type == ComputationMatchType::kExact) {
+    auto left_instructions =
+        left_computation.computation()->MakeInstructionPostOrder();
+    auto right_instructions =
+        right_computation.computation()->MakeInstructionPostOrder();
+    if (left_instructions.size() != right_instructions.size()) {
+      LOG(WARNING) << "Computation size mismatch: Left computation: "
+                   << left_computation.computation()->name() << " has "
+                   << left_instructions.size()
+                   << " instructions and right computation: "
+                   << right_computation.computation()->name() << " has "
+                   << right_instructions.size() << " instructions";
+      return;
+    }
+
+    for (int i = 0; i < left_instructions.size(); ++i) {
+      mappings.MapInstructionsIfAbsent(
+          left.GetNode(left_instructions[i]),
+          right.GetNode(right_instructions[i]),
+          MatcherType::kComputationGraphExactFingerprintMatcher);
+    }
+  } else {
+    // If the two computations are signature matches, we can match the
+    // inputs (parameters, constants) and root instruction of the two
+    // computation graphs.
+    MatchComputationParams(left, right, left_computation, right_computation,
+                           mappings,
+                           MatcherType::kComputationGraphExactSignatureMatcher);
+    MatchComputationConstants(
+        left, right, left_computation, right_computation, mappings,
+        MatcherType::kComputationGraphExactSignatureMatcher);
+
+    if (left_computation.computation()->root_instruction()->opcode() ==
+        right_computation.computation()->root_instruction()->opcode()) {
+      mappings.MapInstructionsIfAbsent(
+          left.GetNode(left_computation.computation()->root_instruction()),
+          right.GetNode(right_computation.computation()->root_instruction()),
+          MatcherType::kComputationGraphExactSignatureMatcher);
+    }
+  }
+}
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.h b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.h
new file mode 100644
index 00000000000000..6989894760e3e2
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_MATCHERS_HLO_COMPUTATION_GRAPH_MATCHER_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_MATCHERS_HLO_COMPUTATION_GRAPH_MATCHER_H_
+
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/service/call_graph.h"
+
+namespace xla::hlo_diff {
+
+// Matches instructions of matched left and right computations in the left and
+// right HloGumgraphs.
+void MatchComputationGraphs(const HloGumgraph& left, const HloGumgraph& right,
+                            const CallGraphNode& left_computation,
+                            const CallGraphNode& right_computation,
+                            HloGumgraphMappings& mappings);
+
+}  // namespace xla::hlo_diff
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_MATCHERS_HLO_COMPUTATION_GRAPH_MATCHER_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher_test.cc
new file mode 100644
index 00000000000000..e97ed299e70f81
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher_test.cc
@@ -0,0 +1,250 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/matchers/hlo_computation_graph_matcher.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/hlo/tools/hlo_diff/utils/test_util.h"
+#include "xla/service/call_graph.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla::hlo_diff {
+namespace {
+
+using ::testing::Pair;
+using ::testing::UnorderedElementsAre;
+
+class HloComputationGraphMatcherTest : public HloTestBase {};
+
mappings->MapComputationsIfAbsent(left_entry_computation, + right_entry_computation, + ComputationMatchType::kSignature); + MatchComputationGraphs(*left_gumgraph, *right_gumgraph, + left_entry_computation, right_entry_computation, + *mappings); + + auto matched_params = ExtractMappedInstructionNames(*mappings); + EXPECT_THAT(matched_params, + UnorderedElementsAre(Pair("p1", "p2"), Pair("c1", "c2"), + Pair("add1", "add2"))); +} + +TEST_F(HloComputationGraphMatcherTest, MatchComputationParams) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr left_module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p21 = f32[10]{0} parameter(0), metadata={op_name="first-phase"} + p22 = f32[10]{0:T(128)} parameter(1), metadata={op_name="first-phase.multiple-matches", source_file="test.cc", source_line=43} + p23 = f32[20]{0} parameter(2) + p24 = f32[10]{0} parameter(3), metadata={source_file="test.cc", source_line=42} + p25 = f32[10]{0} parameter(4), sharding={maximal device=1} + p26 = f32[30]{0} parameter(5), sharding={maximal device=1} + + add21 = f32[10]{0} add(p21, p22) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr right_module, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p11 = f32[10]{0} parameter(0), metadata={op_name="first-phase"} + p12 = f32[10]{0} parameter(1), metadata={op_name="first-phase.multiple-matches"} + p13 = f32[10]{0} parameter(2), metadata={op_name="first-phase.multiple-matches"} + p14 = f32[20]{0} parameter(3) + p15 = f32[10]{0} parameter(4), metadata={source_file="test.cc", source_line=42} + p16 = f32[10]{0} parameter(5), sharding={maximal device=1} + p17 = f32[30]{0} parameter(6) + p18 = f32[10]{0:T(128)} parameter(7), metadata={source_file="test.cc", source_line=43} + + ROOT add22 = f32[10]{0} add(p11, p18) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr left_gumgraph, + HloGumgraph::Create(left_module.get())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr right_gumgraph, + HloGumgraph::Create(right_module.get())); + auto mappings = std::make_unique(); + const CallGraphNode& left_entry_computation = + left_gumgraph->GetCallGraph().GetNode(left_module->entry_computation()); + const CallGraphNode& right_entry_computation = + right_gumgraph->GetCallGraph().GetNode(right_module->entry_computation()); + + mappings->MapComputationsIfAbsent(left_entry_computation, + right_entry_computation, + ComputationMatchType::kSignature); + MatchComputationGraphs( + *left_gumgraph, *right_gumgraph, + left_gumgraph->GetCallGraph().GetNode(left_module->entry_computation()), + right_gumgraph->GetCallGraph().GetNode(right_module->entry_computation()), + *mappings); + + auto matched_params = ExtractMappedInstructionNames(*mappings); + EXPECT_THAT(matched_params, + UnorderedElementsAre(Pair("p21", "p11"), Pair("p22", "p18"), + Pair("p23", "p14"), Pair("p24", "p15"), + Pair("p25", "p16"), Pair("p26", "p17"), + Pair("add21", "add22"))); +} + +TEST_F(HloComputationGraphMatcherTest, MatchComputationConstants) { + const char* hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + c20 = bf16[2]{0} constant({1.1, 2.2}) + c21 = bf16[2]{0} constant({1.1, 2.2}) + c22 = bf16[2]{0} constant({1.1, 2.2}) + c23 = bf16[2]{0} constant({5.5, 6.6}) + c24 = u32[2]{0} constant({1, 2}), metadata={op_name="first-phase"} + c25 = bf16[1] constant(0.0), metadata={source_file="test.cc", source_line=42} + c26 = s32[4]{0} constant({1, 2, 3, 4}) + + add21 = bf16[2]{0} add(c22, c23) + add22 = bf16[2]{0} add(c22, c23) + add23 = 
+TEST_F(HloComputationGraphMatcherTest, MatchComputationConstants) {
+  const char* hlo_string = R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  c20 = bf16[2]{0} constant({1.1, 2.2})
+  c21 = bf16[2]{0} constant({1.1, 2.2})
+  c22 = bf16[2]{0} constant({1.1, 2.2})
+  c23 = bf16[2]{0} constant({5.5, 6.6})
+  c24 = u32[2]{0} constant({1, 2}), metadata={op_name="first-phase"}
+  c25 = bf16[1] constant(0.0), metadata={source_file="test.cc", source_line=42}
+  c26 = s32[4]{0} constant({1, 2, 3, 4})
+
+  add21 = bf16[2]{0} add(c22, c23)
+  add22 = bf16[2]{0} add(c22, c23)
+  add23 = bf16[2]{0} add(add21, add22)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> left_module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> right_module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> left_gumgraph,
+                          HloGumgraph::Create(left_module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> right_gumgraph,
+                          HloGumgraph::Create(right_module.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  const CallGraphNode& left_entry_computation =
+      left_gumgraph->GetCallGraph().GetNode(left_module->entry_computation());
+  const CallGraphNode& right_entry_computation =
+      right_gumgraph->GetCallGraph().GetNode(
+          right_module->entry_computation());
+
+  mappings->MapComputationsIfAbsent(left_entry_computation,
+                                    right_entry_computation,
+                                    ComputationMatchType::kSignature);
+  MatchComputationGraphs(*left_gumgraph, *right_gumgraph,
+                         left_entry_computation, right_entry_computation,
+                         *mappings);
+
+  auto matched_params = ExtractMappedInstructionNames(*mappings);
+  EXPECT_THAT(matched_params,
+              UnorderedElementsAre(Pair("c22", "c22"), Pair("c23", "c23"),
+                                   Pair("c24", "c24"), Pair("c25", "c25"),
+                                   Pair("c26", "c26"),
+                                   Pair("add23", "add23")));
+}
+
+TEST_F(HloComputationGraphMatcherTest,
+       ExactMatchComputationsInstructionsExactlyMatched) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> left_module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+fused_computation.1 {
+  p2 = s32[32,16]{0,1:T(1,128)} parameter(0)
+  p3 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  add.1 = s32[32,16]{0,1:T(1,128)} add(p2, p3)
+}
+
+ENTRY entry {
+  p0 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+  p1 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  ROOT fusion.1 = s32[32,16]{0,1:T(1,128)} fusion(p0,p1), kind=kLoop, calls=fused_computation.1
+}
+)"));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> right_module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+fused_computation.11 {
+  p21 = s32[32,16]{0,1:T(1,128)} parameter(0)
+  p31 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  add.11 = s32[32,16]{0,1:T(1,128)} add(p21, p31)
+}
+
+ENTRY entry {
+  p01 = s32[32,16]{0, 1:T(1,128)} parameter(0)
+  p11 = s32[32,16]{0,1:T(1,128)} parameter(1)
+  ROOT fusion.11 = s32[32,16]{0,1:T(1,128)} fusion(p01,p11), kind=kLoop, calls=fused_computation.11
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> left_gumgraph,
+                          HloGumgraph::Create(left_module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> right_gumgraph,
+                          HloGumgraph::Create(right_module.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  const CallGraphNode& left_fused_computation =
+      left_gumgraph->GetCallGraph().GetNode(
+          left_module->GetComputationWithName("fused_computation.1"));
+  const CallGraphNode& right_fused_computation =
+      right_gumgraph->GetCallGraph().GetNode(
+          right_module->GetComputationWithName("fused_computation.11"));
+
+  mappings->MapComputationsIfAbsent(left_fused_computation,
+                                    right_fused_computation,
+                                    ComputationMatchType::kExact);
+  MatchComputationGraphs(*left_gumgraph, *right_gumgraph,
+                         left_fused_computation, right_fused_computation,
+                         *mappings);
+
+  auto matched_params = ExtractMappedInstructionNames(*mappings);
+  EXPECT_THAT(matched_params,
+              UnorderedElementsAre(Pair("p2", "p21"), Pair("p3", "p31"),
+                                   Pair("add.1", "add.11"),
+                                   Pair("fusion.1", "fusion.11")));
+}
+
+}  // namespace
+}  // namespace xla::hlo_diff
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc
new file mode 100644
index 00000000000000..dd5f10256b1d9d
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc
@@ -0,0 +1,511 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <vector>
+
+#include "absl/base/nullability.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
+#include "xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_bfs.h"
+#include "xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/service/hlo_value.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+constexpr double kOperandsMatchScore = 0.75;
+constexpr double kFingerprintMatchScore = 0.5;
+constexpr double kOperandsFingerprintsMatchScore = 0.5;
+constexpr double kMetadataOpNameMatchScore = 0.1;
+constexpr double kMetadataSourceFileMatchScore = 0.1;
+constexpr double kMetadataSourceLineMatchScore = 0.1;
+
+struct NodePairSimilarity {
+  const HloInstructionNode* left;
+  const HloInstructionNode* right;
+  double similarity;
+};
+
+// Returns true if the two subgraphs have a diff.
+bool HasDiff(absl::Nonnull<const HloInstructionNode*> left,
+             int left_graph_size,
+             absl::Nonnull<const HloInstructionNode*> right,
+             int right_graph_size) {
+  if (left->props.subgraph_fingerprint != right->props.subgraph_fingerprint) {
+    return true;
+  }
+  // TODO(b/365855856): Make sure there's no hash collision before removing
+  // the following extra comparison code.
+  std::vector<const HloInstructionNode*> left_subgraph = GetAllNodesInBfsOrder(
+      *left, BfsTraversalDirection::kForward, left_graph_size);
+  std::vector<const HloInstructionNode*> right_subgraph =
+      GetAllNodesInBfsOrder(*right, BfsTraversalDirection::kForward,
+                            right_graph_size);
+  if (left_subgraph.size() != right_subgraph.size()) {
+    LOG(WARNING) << "Subgraph (" << left->instruction->name() << " vs "
+                 << right->instruction->name() << ") with same fingerprint "
+                 << left->props.subgraph_fingerprint
+                 << " but different size: " << left_subgraph.size() << " vs "
+                 << right_subgraph.size();
+    return true;
+  }
+  for (int i = 0; i < left_subgraph.size(); ++i) {
+    if (left_subgraph[i]->instruction->opcode() !=
+        right_subgraph[i]->instruction->opcode()) {
+      LOG(WARNING) << "Subgraph (" << left->instruction->name() << " vs "
+                   << right->instruction->name() << ") with same fingerprint "
+                   << left->props.subgraph_fingerprint << " and size "
+                   << left_subgraph.size() << " but has diff type at node "
+                   << i << ":" << left_subgraph[i]->instruction->name()
+                   << " vs " << right_subgraph[i]->instruction->name();
+      return true;
+    }
+  }
+  return false;
+}
+
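+// Illustrative note (numbers assumed): two subgraphs can in principle collide
+// on the 64-bit subgraph_fingerprint while differing structurally, e.g. a
+// 3-node and a 4-node subgraph hashing identically. The size and opcode
+// re-checks above catch such a collision and report a diff instead of
+// silently matching the subgraphs.
+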
+// Maps the two subgraphs starting from the given nodes.
+void MapSubgraph(absl::Nonnull<const HloInstructionNode*> left,
+                 int left_graph_size,
+                 absl::Nonnull<const HloInstructionNode*> right,
+                 int right_graph_size, const MatcherType matcher_type,
+                 HloGumgraphMappings& mappings) {
+  std::vector<const HloInstructionNode*> left_subgraph = GetAllNodesInBfsOrder(
+      *left, BfsTraversalDirection::kForward, left_graph_size);
+  std::vector<const HloInstructionNode*> right_subgraph =
+      GetAllNodesInBfsOrder(*right, BfsTraversalDirection::kForward,
+                            right_graph_size);
+  if (left_subgraph.size() != right_subgraph.size()) {
+    LOG(WARNING) << "Unable to map subgraphs due to size mismatch: "
+                 << left_subgraph.size() << " vs " << right_subgraph.size();
+    return;
+  }
+  for (int i = 0; i < left_subgraph.size(); ++i) {
+    mappings.MapInstructionsIfAbsent(left_subgraph[i], right_subgraph[i],
+                                     matcher_type);
+    // Mark all nodes except the root as unchanged.
+    if (i != 0) {
+      mappings.left_to_right_instruction_map.left.find(left_subgraph[i])
+          ->info.unchanged = true;
+    }
+  }
+}
+
+// Recursively matches the two nodes top down when the opcodes and the
+// position of the nodes in their parents' children list match.
+void RecursiveTopDownMatcher(const HloInstructionNode* left,
+                             const HloInstructionNode* right,
+                             const MatcherType matcher_type,
+                             HloGumgraphMappings& mappings) {
+  for (auto i = 0; i < left->children.size() && i < right->children.size();
+       ++i) {
+    const HloInstructionNode* left_child = left->children[i];
+    const HloInstructionNode* right_child = right->children[i];
+    // TODO(b/360878130) - Use fingerprint to compare nodes.
+    if (left_child->instruction->opcode() !=
+            right_child->instruction->opcode() ||
+        !(mappings.MapInstructionsIfAbsent(left_child, right_child,
+                                           matcher_type))) {
+      // Stop recursive matching if the nodes are not matched, or
+      // non-overwriting mapping failed.
+      continue;
+    }
+    RecursiveTopDownMatcher(left_child, right_child, matcher_type, mappings);
+  }
+}
+
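+// Worked example for the Dice score below (values assumed): the score is
+// 2 * |common| / (|left| + |right|) over the two truncated BFS node sets,
+// where a left node counts as common when its mapped partner appears in the
+// right set. With 10 nodes on each side and 8 already-mapped pairs, the score
+// is 2 * 8 / 20 = 0.8.
+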
+// DiceSim similarity score between two subgraphs. Subgraphs are limited to
+// the first max_subgraph_size nodes of a BFS starting from the given nodes.
+double DiceSimLimitedSubgraph(absl::Nonnull<const HloInstructionNode*> left,
+                              absl::Nonnull<const HloInstructionNode*> right,
+                              HloGumgraphMappings& mappings,
+                              int max_subgraph_size, int left_graph_size,
+                              int right_graph_size) {
+  absl::flat_hash_set<const HloInstructionNode*> left_nodes;
+  absl::flat_hash_set<const HloInstructionNode*> right_nodes;
+  HloGumgraphBfs(
+      *left,
+      [&](const HloInstructionNode& node) {
+        left_nodes.insert(&node);
+        return left_nodes.size() < max_subgraph_size;
+      },
+      BfsTraversalDirection::kForward, left_graph_size);
+  HloGumgraphBfs(
+      *right,
+      [&](const HloInstructionNode& node) {
+        right_nodes.insert(&node);
+        return right_nodes.size() < max_subgraph_size;
+      },
+      BfsTraversalDirection::kForward, right_graph_size);
+  int common = 0;
+  for (const HloInstructionNode* left_node : left_nodes) {
+    if (auto it = mappings.left_to_right_instruction_map.left.find(left_node);
+        it != mappings.left_to_right_instruction_map.left.end() &&
+        right_nodes.contains(it->second)) {
+      ++common;
+    }
+  }
+
+  return 2 * static_cast<double>(common) /
+         static_cast<double>(left_nodes.size() + right_nodes.size());
+}
+
+// A heuristic score based on the node attributes, calculated by comparing the
+// fingerprint and metadata of the nodes. This set of parameters together with
+// a min_similarity threshold of 0.75 works the best so far, and might need to
+// be tuned later.
+double NodeAttributesSimilarity(
+    absl::Nonnull<const HloInstructionNode*> left,
+    absl::Nonnull<const HloInstructionNode*> right) {
+  double sim_score = 0.0;
+
+  if (right->props.fingerprint == left->props.fingerprint) {
+    sim_score += kFingerprintMatchScore;
+  }
+
+  if (!left->instruction->metadata().op_name().empty() &&
+      left->instruction->metadata().op_name() ==
+          right->instruction->metadata().op_name()) {
+    sim_score += kMetadataOpNameMatchScore;
+    if (!left->instruction->metadata().source_file().empty() &&
+        left->instruction->metadata().source_file() ==
+            right->instruction->metadata().source_file()) {
+      sim_score += kMetadataSourceFileMatchScore;
+      if (left->instruction->metadata().source_line() != 0 &&
+          left->instruction->metadata().source_line() ==
+              right->instruction->metadata().source_line()) {
+        sim_score += kMetadataSourceLineMatchScore;
+      }
+    }
+  }
+
+  return sim_score;
+}
+
+// A heuristic score based on the ancestor subgraphs of the given nodes,
+// calculated by comparing the fingerprints of the ancestors of the nodes.
+double AncestorSubGraphSimilarity(const HloInstructionNode* left,
+                                  const HloInstructionNode* right,
+                                  const int candidate_traversal_limit,
+                                  int left_graph_size, int right_graph_size) {
+  absl::flat_hash_map<uint64_t, int> left_ancestor_fingerprints,
+      right_ancestor_fingerprints;
+  int left_traversal_count = 0;
+  HloGumgraphBfs(
+      *left,
+      [&](const HloInstructionNode& node) {
+        ++left_ancestor_fingerprints[node.props.fingerprint];
+        return ++left_traversal_count < candidate_traversal_limit;
+      },
+      BfsTraversalDirection::kReverse, left_graph_size);
+  int right_traversal_count = 0;
+  HloGumgraphBfs(
+      *right,
+      [&](const HloInstructionNode& node) {
+        ++right_ancestor_fingerprints[node.props.fingerprint];
+        return ++right_traversal_count < candidate_traversal_limit;
+      },
+      BfsTraversalDirection::kReverse, right_graph_size);
+
+  int matching_ancestors = 0;
+  for (const auto& [fingerprint, count] : left_ancestor_fingerprints) {
+    if (right_ancestor_fingerprints.contains(fingerprint)) {
+      matching_ancestors +=
+          std::min(count, right_ancestor_fingerprints[fingerprint]);
+    }
+  }
+
+  return 2.0 * static_cast<double>(matching_ancestors) /
+         static_cast<double>(left_traversal_count + right_traversal_count);
+}
+
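+// Worked example for AncestorSubGraphSimilarity above (values assumed): with
+// 50 ancestors traversed on each side and 30 fingerprints overlapping
+// (counted with multiplicity), the score is 2 * 30 / (50 + 50) = 0.6.
+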
+// Returns all HloValues used by the given instruction.
+std::vector<const HloValue*> GetAllValuesUsedByInstruction(
+    const HloInstruction* instruction, const HloGumgraph& gumgraph) {
+  if (instruction->opcode() == HloOpcode::kParameter) {
+    if (instruction->parent()->IsEntryComputation() ||
+        gumgraph.GetHloValueTracing().ValueIsDefinedAt(instruction)) {
+      return std::vector<const HloValue*>();
+    }
+
+    return gumgraph.GetHloValueTracing()
+        .GetFlattenedValueSet(instruction)
+        .values();
+  }
+
+  std::vector<const HloValue*> values_used_by_instruction;
+  for (const HloInstruction* operand : instruction->operands()) {
+    const HloValueSet operand_value_set =
+        gumgraph.GetHloValueTracing().GetFlattenedValueSet(operand);
+    for (const HloValue* value : operand_value_set.values()) {
+      absl::Span<const HloUse> uses = value->GetUses();
+      for (const HloUse& use : uses) {
+        if (use.instruction == instruction) {
+          values_used_by_instruction.push_back(value);
+          break;
+        }
+      }
+    }
+  }
+
+  return values_used_by_instruction;
+}
+
+// Returns a score reflecting whether all HloValues used by the left and right
+// nodes have their defining instructions matched, either in the current
+// mappings or by fingerprint.
+double AllOperandHloValuesMatchedScore(const HloInstructionNode* left_node,
+                                       const HloInstructionNode* right_node,
+                                       const HloGumgraph& left,
+                                       const HloGumgraph& right,
+                                       HloGumgraphMappings& mappings) {
+  std::vector<const HloValue*> left_hlo_values =
+      GetAllValuesUsedByInstruction(left_node->instruction, left);
+  std::vector<const HloValue*> right_hlo_values =
+      GetAllValuesUsedByInstruction(right_node->instruction, right);
+
+  if (left_hlo_values.empty() || right_hlo_values.empty() ||
+      left_hlo_values.size() != right_hlo_values.size()) {
+    return 0.0;
+  }
+
+  bool fingerprints_matched = true;
+  bool mappings_matched = true;
+  for (int i = 0; i < left_hlo_values.size(); ++i) {
+    if (!fingerprints_matched && !mappings_matched) {
+      // Stop if neither the fingerprints nor the mappings are matched.
+      break;
+    }
+
+    HloInstructionNode* left_hlo_value_node =
+        left.GetNode(left_hlo_values[i]->defining_instruction());
+    HloInstructionNode* right_hlo_value_node =
+        right.GetNode(right_hlo_values[i]->defining_instruction());
+    if (auto it = mappings.left_to_right_instruction_map.left.find(
+            left_hlo_value_node);
+        it == mappings.left_to_right_instruction_map.left.end() ||
+        it->second != right_hlo_value_node) {
+      mappings_matched = false;
+    }
+    if (left_hlo_value_node->props.fingerprint !=
+        right_hlo_value_node->props.fingerprint) {
+      fingerprints_matched = false;
+    }
+  }
+
+  if (mappings_matched) {
+    return kOperandsMatchScore;
+  }
+  if (fingerprints_matched) {
+    return kOperandsFingerprintsMatchScore;
+  }
+  return 0.0;
+}
+
+}  // namespace
+
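+// Illustrative note on the greedy strategy below (sizes hypothetical): once a
+// height-3 subgraph is matched 1:1, every node inside it is added to the
+// ignored set, so identical smaller subgraphs nested within it are never
+// re-examined or double-matched at heights 2..0.
+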
+ absl::flat_hash_map> + source_subgraphs; + HloGumgraphBfs( + left_.GetRoot(), + [&](const HloInstructionNode& node) { + if (!node.is_root) { + source_subgraphs[node.props.height].push_back(&node); + } + return true; + }, + BfsTraversalDirection::kForward, left_.GetNodeCount()); + absl::flat_hash_map> + target_subgraphs; + HloGumgraphBfs( + right_.GetRoot(), + [&](const HloInstructionNode& node) { + if (!node.is_root) { + target_subgraphs[node.props.height].push_back(&node); + } + return true; + }, + BfsTraversalDirection::kForward, right_.GetNodeCount()); + + absl::flat_hash_set ignored; + // Find exact match left-right subgraphs candidates greedly from high to low + // height. + for (int height = max_height; height >= 0; --height) { + if (!source_subgraphs.contains(height) || + !target_subgraphs.contains(height)) { + continue; + } + absl::flat_hash_set found; + // Find exact match left-right subgraph candidates at the current height. + for (const HloInstructionNode* source_node : source_subgraphs[height]) { + if (ignored.contains(source_node)) { + continue; + } + for (const HloInstructionNode* target_node : target_subgraphs[height]) { + if (ignored.contains(target_node)) { + continue; + } + if (HasDiff(source_node, left_.GetNodeCount(), target_node, + right_.GetNodeCount())) { + continue; + } + candidates[source_node].push_back(target_node); + candidates_reverse[target_node].push_back(source_node); + found.insert(source_node); + found.insert(target_node); + } + } + // Ignore all nodes in the subgraphs that matched in later traversals. + for (const HloInstructionNode* found_node : found) { + HloGumgraphBfs( + *found_node, + [&](const HloInstructionNode& node) { + ignored.insert(&node); + return true; + }, + BfsTraversalDirection::kForward, + std::max(left_.GetNodeCount(), right_.GetNodeCount())); + } + } + // Map 1:1 candidates. + for (auto& [left, right] : candidates) { + if (right.size() == 1 && candidates_reverse[right[0]].size() == 1) { + MapSubgraph(left, left_.GetNodeCount(), right[0], right_.GetNodeCount(), + type_, mappings); + } + } + + LOG(INFO) + << "Finished GreedySubGraphExactMatcher. Found left to right mappings: " + << mappings.left_to_right_instruction_map.size() - current_mapping_count; +} + +void GreedyLimitedCandidatesBottomUpMatcher::Match( + HloGumgraphMappings& mappings) const { + LOG(INFO) << "Running GreedyLimitedCandidatesBottomUpMatcher: matching " + "subgraphs that match based on Dice similarity"; + int current_mapping_count = mappings.left_to_right_instruction_map.size(); + std::vector left_postorder = GetAllNodesInDfsOrder( + left_.GetRoot(), DfsTraversalOrder::kPostOrder, left_.GetNodeCount()); + for (const HloInstructionNode* left_node : left_postorder) { + // Skip matched nodes or ones without children. + if (mappings.InstructionMapContainsLeft(left_node) || + left_node->children.empty()) { + continue; + } + + std::vector right_seeds; + int count = 0; + HloGumgraphBfs( + *left_node, + [&](const HloInstructionNode& node) { + if (auto it = mappings.left_to_right_instruction_map.left.find(&node); + it != mappings.left_to_right_instruction_map.left.end()) { + right_seeds.push_back(it->second); + } + // Don't pursue subgraphs with too many childrens. Allows us to visit + // deeper subgraphs without getting stuck on a single node with a + // large number of children. 
+void GreedyLimitedCandidatesBottomUpMatcher::Match(
+    HloGumgraphMappings& mappings) const {
+  LOG(INFO) << "Running GreedyLimitedCandidatesBottomUpMatcher: matching "
+               "subgraphs based on Dice similarity";
+  int current_mapping_count = mappings.left_to_right_instruction_map.size();
+  std::vector<const HloInstructionNode*> left_postorder =
+      GetAllNodesInDfsOrder(left_.GetRoot(), DfsTraversalOrder::kPostOrder,
+                            left_.GetNodeCount());
+  for (const HloInstructionNode* left_node : left_postorder) {
+    // Skip matched nodes or ones without children.
+    if (mappings.InstructionMapContainsLeft(left_node) ||
+        left_node->children.empty()) {
+      continue;
+    }
+
+    std::vector<const HloInstructionNode*> right_seeds;
+    int count = 0;
+    HloGumgraphBfs(
+        *left_node,
+        [&](const HloInstructionNode& node) {
+          if (auto it =
+                  mappings.left_to_right_instruction_map.left.find(&node);
+              it != mappings.left_to_right_instruction_map.left.end()) {
+            right_seeds.push_back(it->second);
+          }
+          // Don't pursue subgraphs with too many children. This allows us to
+          // visit deeper subgraphs without getting stuck on a single node
+          // with a large number of children.
+          if (node.children.size() > right_seeds_traversal_limit_ / 2) {
+            return false;
+          }
+          return ++count < right_seeds_traversal_limit_;
+        },
+        BfsTraversalDirection::kForward, left_.GetNodeCount());
+
+    // Find right candidates and the max similarity on the fly.
+    double max_similarity = 0;
+    const HloInstructionNode* right_candidate = nullptr;
+    count = 0;
+    HloGumgraphBfs(
+        right_seeds,
+        [&](const HloInstructionNode& node) {
+          if (!mappings.InstructionMapContainsRight(&node) &&
+              node.instruction->opcode() == left_node->instruction->opcode()) {
+            // Found a candidate. Calculate similarity.
+            double operands_match_similarity = AllOperandHloValuesMatchedScore(
+                left_node, &node, left_, right_, mappings);
+            double dice_sim = DiceSimLimitedSubgraph(
+                left_node, &node, mappings, max_dice_subgraph_size_,
+                left_.GetNodeCount(), right_.GetNodeCount());
+            double node_attributes_similarity =
+                NodeAttributesSimilarity(left_node, &node);
+            double ancestor_similarity = AncestorSubGraphSimilarity(
+                left_node, &node, max_ancestors_to_consider_,
+                left_.GetNodeCount(), right_.GetNodeCount());
+            // We give ancestor similarity a lower weight as it is a weaker
+            // signal in comparison to dice similarity and node attributes
+            // similarity.
+            double similarity = operands_match_similarity +
+                                node_attributes_similarity + dice_sim +
+                                ancestor_similarity / 2;
+            if (similarity > max_similarity) {
+              max_similarity = similarity;
+              right_candidate = &node;
+            }
+          }
+          return ++count < candidate_traversal_limit_;
+        },
+        BfsTraversalDirection::kReverse, right_.GetNodeCount());
+    if (max_similarity > min_similarity_) {
+      mappings.MapInstructionsIfAbsent(left_node, right_candidate, type_);
+    }
+  }
+  LOG(INFO) << "Finished GreedyLimitedCandidatesBottomUpMatcher. Total left "
+               "to right mappings: "
+            << mappings.left_to_right_instruction_map.size() -
+                   current_mapping_count;
+}
+
+void GreedyTopDownMatcher::Match(HloGumgraphMappings& mappings) const {
+  LOG(INFO) << "Running GreedyTopDownMatcher: matching unmatched nodes";
+  int current_mapping_count = mappings.left_to_right_instruction_map.size();
+  HloGumgraphDfs(
+      left_.GetRoot(),
+      [&](const HloInstructionNode& left_node) {
+        auto it = mappings.left_to_right_instruction_map.left.find(&left_node);
+        if (it == mappings.left_to_right_instruction_map.left.end()) {
+          return;
+        }
+
+        RecursiveTopDownMatcher(&left_node, it->second, type_, mappings);
+      },
+      DfsTraversalOrder::kPostOrder, left_.GetNodeCount());
+  LOG(INFO) << "Finished GreedyTopDownMatcher. Total left to right mappings: "
+            << mappings.left_to_right_instruction_map.size() -
+                   current_mapping_count;
+}
+
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.h b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.h
new file mode 100644
index 00000000000000..8d8f8f242ed638
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.h
@@ -0,0 +1,129 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_MATCHERS_HLO_GUMGRAPH_MATCHER_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_MATCHERS_HLO_GUMGRAPH_MATCHER_H_
+
+#include "absl/log/die_if_null.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+
+namespace xla {
+namespace hlo_diff {
+
+// Options allowing configuration of the instruction matching algorithm.
+struct MatchOptions {
+  bool use_top_down_matcher = true;
+};
+
+// Base class for all node matchers. Each matcher implements a unique algorithm
+// to match nodes between two HLO graphs. The base class standardizes input and
+// output types, ensuring seamless integration and compatibility within any
+// matcher sequence.
+class HloGumgraphMatcher {
+ public:
+  virtual ~HloGumgraphMatcher() = default;
+  virtual void Match(HloGumgraphMappings& mappings) const = 0;
+
+ protected:
+  explicit HloGumgraphMatcher(MatcherType type) : type_(type) {}
+  const MatcherType type_;
+};
+
+// Matcher that matches identical subgraphs starting with the tallest.
+class GreedySubGraphExactMatcher : public HloGumgraphMatcher {
+ public:
+  GreedySubGraphExactMatcher(const HloGumgraph* left, const HloGumgraph* right)
+      : HloGumgraphMatcher(MatcherType::kGreedySubGraphExactMatcher),
+        left_(*ABSL_DIE_IF_NULL(left)),
+        right_(*ABSL_DIE_IF_NULL(right)) {}
+  void Match(HloGumgraphMappings& mappings) const override;
+
+ private:
+  const HloGumgraph& left_;
+  const HloGumgraph& right_;
+};
+
+// Matcher that matches nodes bottom up by dice similarity. For each left
+// node, mappings of the already matched descendants are considered as seeds,
+// from which we traverse back the graph to find nodes with the same opcode as
+// candidates. The candidate with the highest similarity is chosen as the
+// match. Nodes mapped by this matcher in earlier iterations are also
+// considered as seeds for later iterations.
+//
+// Seeds: the number of nodes traversed to find seeds is limited.
+// Candidates: the number of nodes traversed to find candidates is limited.
+// Dice similarity: the number of nodes traversed in each subgraph is limited.
+class GreedyLimitedCandidatesBottomUpMatcher : public HloGumgraphMatcher {
+ public:
+  GreedyLimitedCandidatesBottomUpMatcher(const HloGumgraph* left,
+                                         const HloGumgraph* right,
+                                         double min_similarity = 1.2,
+                                         int max_dice_subgraph_size = 200,
+                                         int max_ancestors_to_consider = 100,
+                                         int right_seeds_traversal_limit = 40,
+                                         int candidate_traversal_limit = 200)
+      : HloGumgraphMatcher(
+            MatcherType::kGreedyLimitedCandidatesBottomUpMatcher),
+        left_(*ABSL_DIE_IF_NULL(left)),
+        right_(*ABSL_DIE_IF_NULL(right)),
+        min_similarity_(min_similarity),
+        max_dice_subgraph_size_(max_dice_subgraph_size),
+        max_ancestors_to_consider_(max_ancestors_to_consider),
+        right_seeds_traversal_limit_(right_seeds_traversal_limit),
+        candidate_traversal_limit_(candidate_traversal_limit) {}
+  void Match(HloGumgraphMappings& mappings) const override;
+
+ private:
+  const HloGumgraph& left_;
+  const HloGumgraph& right_;
+
+  // Minimum similarity to consider a match.
+  const double min_similarity_;
+
+  // Maximum size of the subgraph to consider when calculating dice
+  // similarity.
+  const int max_dice_subgraph_size_;
+
+  // Maximum number of ancestors to consider when calculating ancestor
+  // similarity.
+  const int max_ancestors_to_consider_;
+
+  // Maximum number of nodes to traverse to find right seeds.
+  const int right_seeds_traversal_limit_;
+
+  // Maximum number of nodes to traverse from seeds. Nodes with the same
+  // opcode are considered as candidates.
+  const int candidate_traversal_limit_;
+};
+
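+// Illustrative usage of the matchers in this header (hypothetical driver
+// code, not part of this change):
+//
+//   HloGumgraphMappings mappings;
+//   GreedySubGraphExactMatcher exact(&left, &right);
+//   GreedyLimitedCandidatesBottomUpMatcher bottom_up(&left, &right);
+//   GreedyTopDownMatcher top_down(&left, &right);
+//   exact.Match(mappings);
+//   bottom_up.Match(mappings);
+//   top_down.Match(mappings);
+//
+// Matchers only add mappings (MapInstructionsIfAbsent), so running the
+// higher-precision matchers first preserves their decisions.
+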
+// Matcher that matches nodes top down, following paths whose opcode
+// sequences match.
+class GreedyTopDownMatcher : public HloGumgraphMatcher {
+ public:
+  GreedyTopDownMatcher(const HloGumgraph* left, const HloGumgraph* right)
+      : HloGumgraphMatcher(MatcherType::kGreedyTopDownMatcher),
+        left_(*ABSL_DIE_IF_NULL(left)),
+        right_(*ABSL_DIE_IF_NULL(right)) {}
+  void Match(HloGumgraphMappings& mappings) const override;
+
+ private:
+  const HloGumgraph& left_;
+  const HloGumgraph& right_;
+};
+
+}  // namespace hlo_diff
+}  // namespace xla
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_MATCHERS_HLO_GUMGRAPH_MATCHER_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher_test.cc
new file mode 100644
index 00000000000000..fddb6ba87c4b8e
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher_test.cc
@@ -0,0 +1,551 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
+#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
+#include "xla/hlo/tools/hlo_diff/utils/test_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+using ::testing::Pair;
+using ::testing::UnorderedElementsAre;
+
+class HloMatcherTest : public HloTestBase {};
+
+TEST_F(HloMatcherTest, SubGraphExactMatcherEntryChange) {
+  // Create left module with entry computation containing the following
+  // structure:
+  // [Param foo_L] ------> ┌-------┐
+  //                       | add_1 | ---> ┌-------┐      ┌------┐
+  // [Constant bar_L] ---> └-------┘      | add_0 | ---> | ROOT |
+  // [Param baz_L] ---------------------> └-------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  foo_L = f32[8,2048]{1,0:T(8,128)} parameter(0)
+  bar_L = f32[8,2048]{1,0:T(8,128)} constant(0)
+  baz_L = f32[8,2048]{1,0:T(8,128)} parameter(1)
+  add_1 = f32[8,2048]{1,0:T(8,128)} add(foo_L, bar_L)
+  add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz_L)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+
+  // Create right module with entry computation containing the following
+  // structure:
+  // [Param foo_R] ------> ┌-------┐
+  //                       | add_1 | ---> ┌------------┐      ┌------┐
+  // [Constant bar_R] ---> └-------┘      | subtract_0 | ---> | ROOT |
+  // [Param baz_R] ---------------------> └------------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  foo_R = f32[8,2048]{1,0:T(8,128)} parameter(0)
+  bar_R = f32[8,2048]{1,0:T(8,128)} constant(0)
+  baz_R = f32[8,2048]{1,0:T(8,128)} parameter(1)
+  add_1 = f32[8,2048]{1,0:T(8,128)} add(foo_R, bar_R)
+  subtract_0 = f32[8,2048]{1,0:T(8,128)} subtract(add_1, baz_R)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  auto matcher = std::make_unique<GreedySubGraphExactMatcher>(graph_l.get(),
+                                                              graph_r.get());
+  // Root nodes are matched by default before the matcher is called.
+  mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(),
+                                    MatcherType::kManual);
+  matcher->Match(*mappings);
+  auto mapped_nodes = ExtractMappedInstructionNames(*mappings);
+
+  EXPECT_THAT(mapped_nodes, UnorderedElementsAre(
+                                Pair("add_1", "add_1"), Pair("foo_L", "foo_R"),
+                                Pair("bar_L", "bar_R"), Pair("baz_L", "baz_R"),
+                                Pair("root_L", "root_R")));
+}
+
+TEST_F(HloMatcherTest, SubGraphExactMatcherLeafChange) {
+  // Create left module with entry computation containing the following
+  // structure:
+  // [Param foo] ------> ┌-------┐
+  //                     | add_1 | ---> ┌-------┐      ┌------┐
+  // [Constant bar] ---> └-------┘      | add_0 | ---> | ROOT |
+  // [Param baz] ---------------------> └-------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_l,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  foo = f32[8,2048]{1,0:T(8,128)} parameter(0)
+  bar = f32[8,2048]{1,0:T(8,128)} constant(0)
+  baz = f32[8,2048]{1,0:T(8,128)} parameter(1)
+  add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar)
+  add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_l,
+                          HloGumgraph::Create(module_l.get()));
+
+  // Create right module with entry computation containing the following
+  // structure:
+  // [Param foo] ------> ┌-------┐
+  //                     | add_1 | ---> ┌-------┐      ┌------┐
+  // [Constant bar] ---> └-------┘      | add_0 | ---> | ROOT |
+  // [Constant baz] ------------------> └-------┘      └------┘
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_r,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  foo = f32[8,2048]{1,0:T(8,128)} parameter(0)
+  bar = f32[8,2048]{1,0:T(8,128)} constant(0)
+  baz = f32[8,2048]{1,0:T(8,128)} constant(1)
+  add_1 = f32[8,2048]{1,0:T(8,128)} add(foo, bar)
+  add_0 = f32[8,2048]{1,0:T(8,128)} add(add_1, baz)
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<const HloGumgraph> graph_r,
+                          HloGumgraph::Create(module_r.get()));
+  auto mappings = std::make_unique<HloGumgraphMappings>();
+  auto matcher = std::make_unique<GreedySubGraphExactMatcher>(graph_l.get(),
+                                                              graph_r.get());
+  // Root nodes are matched by default before the matcher is called.
+ mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(), + MatcherType::kManual); + matcher->Match(*mappings); + auto mapped_nodes = ExtractMappedInstructionNames(*mappings); + + EXPECT_THAT(mapped_nodes, UnorderedElementsAre( + Pair("add_1", "add_1"), Pair("foo", "foo"), + Pair("bar", "bar"), Pair("root_L", "root_R"))); +} + +TEST_F(HloMatcherTest, GreedyLimitedCandidatesBottomUpMatcher) { + // Create left module with entry computation containing the following + // structure: + // [Const 0] ---> ┌-------┐ + // | add_0 | --------> ┌-------┐ + // [Const 1] ---> └-------┘ | | ┌-------┐ + // | add_3 | ---> | | + // [Const 2] ---> ┌------------┐ | | | | ┌------┐ + // | subtract_1 | ---> └-------┘ | add_4 | ---> | ROOT | + // [Const 3] ---> └------------┘ | | └------┘ + // | | + // [Const 4] --------------------------------------> └-------┘ + // + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_l, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + constant.2 = f32[] constant(0) + constant.3 = f32[] constant(0) + constant.4 = f32[] constant(0) + add.0 = f32[] add(constant.0, constant.1) + subtract.1 = f32[] subtract(constant.2, constant.3) + add.3 = f32[] add(add.0, subtract.1) + add.4 = f32[] add(add.3, constant.4) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_l, + HloGumgraph::Create(module_l.get())); + + // Create right module with entry computation containing the following + // structure: + // [Const 0] ---> ┌-------┐ + // | add_0 | ---> ┌-------┐ + // [Const 1] ---> └-------┘ | | ┌-------┐ + // | add_3 | ---> | | + // [Const 2] ---> ┌-------┐ | | | | ┌------┐ + // | add_1 | ---> └-------┘ | add_4 | ---> | ROOT | + // [Const 3] ---> └-------┘ | | └------┘ + // | | + // [Const 4] ---------------------------------> └-------┘ + // + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_r, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + constant.2 = f32[] constant(0) + constant.3 = f32[] constant(0) + constant.4 = f32[] constant(0) + add.0 = f32[] add(constant.0, constant.1) + add.1 = f32[] add(constant.2, constant.3) + add.3 = f32[] add(add.0, add.1) + add.4 = f32[] add(add.3, constant.4) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_r, + HloGumgraph::Create(module_r.get())); + auto mappings = std::make_unique(); + ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions( + GetNodeByName(*graph_l, "constant.0"), + GetNodeByName(*graph_r, "constant.0"), *mappings)); + ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions( + GetNodeByName(*graph_l, "constant.1"), + GetNodeByName(*graph_r, "constant.1"), *mappings)); + ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions( + GetNodeByName(*graph_l, "constant.2"), + GetNodeByName(*graph_r, "constant.2"), *mappings)); + ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions( + GetNodeByName(*graph_l, "constant.3"), + GetNodeByName(*graph_r, "constant.3"), *mappings)); + ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions( + GetNodeByName(*graph_l, "constant.4"), + GetNodeByName(*graph_r, "constant.4"), *mappings)); + ASSERT_NO_FATAL_FAILURE( + OverwriteMapInstructions(GetNodeByName(*graph_l, "add.0"), + GetNodeByName(*graph_r, "add.0"), *mappings)); + auto matcher = std::make_unique( + graph_l.get(), graph_r.get()); + // Root nodes are matched by default before the matcher is called. 
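+  // With the leaf constants and add.0 seeded above, the bottom-up matcher
+  // should propagate matches towards the root (add.3, add.4), while
+  // subtract.1/add.1 are expected to stay unmatched due to differing opcodes.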
+ mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(), + MatcherType::kManual); + matcher->Match(*mappings); + auto mapped_nodes = ExtractMappedInstructionNames(*mappings); + + EXPECT_THAT( + mapped_nodes, + UnorderedElementsAre( + Pair("constant.0", "constant.0"), Pair("constant.1", "constant.1"), + Pair("constant.2", "constant.2"), Pair("constant.3", "constant.3"), + Pair("add.0", "add.0"), Pair("add.3", "add.3"), + Pair("constant.4", "constant.4"), Pair("add.4", "add.4"), + Pair("root_L", "root_R"))); +} + +TEST_F(HloMatcherTest, GreedyLimitedCandidatesBottomUpMatcherAmbiguousMatch) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_l, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + add.0 = f32[] add(constant.0, constant.1) + add.1 = f32[] add(constant.0, constant.1) + add.2 = f32[] add(add.0, constant.0) + subtract.1 = f32[] subtract(add.1, add.2) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_l, + HloGumgraph::Create(module_l.get())); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_r, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + add.10 = f32[] add(constant.0, constant.1) + add.11 = f32[] add(constant.0, constant.1) + add.12 = f32[] add(add.10, constant.0) + subtract.1 = f32[] subtract(add.11, add.12) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_r, + HloGumgraph::Create(module_r.get())); + auto mappings = std::make_unique(); + ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions( + GetNodeByName(*graph_l, "constant.0"), + GetNodeByName(*graph_r, "constant.0"), *mappings)); + ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions( + GetNodeByName(*graph_l, "constant.1"), + GetNodeByName(*graph_r, "constant.1"), *mappings)); + auto matcher = std::make_unique( + graph_l.get(), graph_r.get()); + // Root nodes are matched by default before the matcher is called. 
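+  // add.0/add.1 (left) and add.10/add.11 (right) are structurally identical
+  // candidates; the matcher is expected to use the surrounding users
+  // (add.2/add.12 and subtract.1) to resolve the ambiguity deterministically.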
+ mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(), + MatcherType::kManual); + matcher->Match(*mappings); + auto mapped_nodes = ExtractMappedInstructionNames(*mappings); + + EXPECT_THAT(mapped_nodes, + UnorderedElementsAre( + Pair("constant.0", "constant.0"), + Pair("constant.1", "constant.1"), Pair("add.0", "add.10"), + Pair("add.1", "add.11"), Pair("add.2", "add.12"), + Pair("subtract.1", "subtract.1"), Pair("root_L", "root_R"))); +} + +TEST_F(HloMatcherTest, GreedyLimitedCandidatesBottomUpMatcherHloValueTraced) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_l, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +fused_computation0 { + param_0 = f32[] parameter(0) + ROOT negate.0 = f32[] negate(param_0) +} + +fused_computation1 { + param_1 = f32[] parameter(0) + ROOT abs.0 = f32[] abs(param_1) +} + +ENTRY entry { + constant.0 = f32[] constant(0) + bitcast.0 = f32[] bitcast(constant.0) + copy.0 = f32[] copy(bitcast.0) + fusion.0 = f32[] fusion(bitcast.0), kind=kLoop, calls=fused_computation0 + fusion.1 = f32[] fusion(copy.0), kind=kLoop, calls=fused_computation1 + ROOT add.0 = f32[] add(fusion.0, fusion.1) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_l, + HloGumgraph::Create(module_l.get())); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_r, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +fused_computation0 { + param_0 = f32[] parameter(0) + ROOT negate.0 = f32[] negate(param_0) +} + +fused_computation1 { + param_1 = f32[] parameter(0) + ROOT abs.0 = f32[] abs(param_1) +} + +ENTRY entry { + constant.0 = f32[] constant(0) + bitcast.0 = f32[] bitcast(constant.0) + copy.0 = f32[] copy(bitcast.0) + fusion.0 = f32[] fusion(bitcast.0), kind=kLoop, calls=fused_computation1 + fusion.1 = f32[] fusion(copy.0), kind=kLoop, calls=fused_computation0 + ROOT add.0 = f32[] add(fusion.0, fusion.1) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_r, + HloGumgraph::Create(module_r.get())); + auto mappings = std::make_unique(); + ASSERT_NO_FATAL_FAILURE(OverwriteMapInstructions( + GetNodeByName(*graph_l, "constant.0"), + GetNodeByName(*graph_r, "constant.0"), *mappings)); + auto matcher = std::make_unique( + graph_l.get(), graph_r.get()); + // Root nodes are matched by default before the matcher is called. 
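+  // The two fusions swap their called computations between the modules, so
+  // fusion.0 is expected to still map to fusion.0 while the fused
+  // computations' parameters cross-map (param_0 <-> param_1), following the
+  // matched negate.0/abs.0 bodies.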
+ mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(), + MatcherType::kManual); + matcher->Match(*mappings); + auto mapped_nodes = ExtractMappedInstructionNames(*mappings); + + EXPECT_THAT(mapped_nodes, + UnorderedElementsAre( + Pair("constant.0", "constant.0"), + Pair("bitcast.0", "bitcast.0"), Pair("copy.0", "copy.0"), + Pair("fusion.0", "fusion.0"), Pair("fusion.1", "fusion.1"), + Pair("add.0", "add.0"), Pair("negate.0", "negate.0"), + Pair("abs.0", "abs.0"), Pair("param_0", "param_1"), + Pair("param_1", "param_0"), Pair("root_L", "root_R"))); +} + +TEST_F(HloMatcherTest, GreedyTopDownMatcherStopAtUnmatchedType) { + // Create left module with entry computation containing the following + // structure: + // [Const 0] ---> ┌-------┐ + // | add_0 | --------> ┌-------┐ + // [Const 1] ---> └-------┘ | | ┌-------┐ + // | add_3 | ---> | | + // [Const 2] ---> ┌------------┐ | | | | ┌------┐ + // | subtract_1 | ---> └-------┘ | add_4 | ---> | ROOT | + // [Const 3] ---> └------------┘ | | └------┘ + // | | + // [Const 4] --------------------------------------> └-------┘ + // + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_l, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + constant.2 = f32[] constant(0) + constant.3 = f32[] constant(0) + constant.4 = f32[] constant(0) + add.0 = f32[] add(constant.0, constant.1) + subtract.1 = f32[] subtract(constant.2, constant.3) + add.3 = f32[] add(add.0, subtract.1) + add.4 = f32[] add(add.3, constant.4) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_l, + HloGumgraph::Create(module_l.get())); + + // Create right module with entry computation containing the following + // structure: + // [Const 0] ---> ┌-------┐ + // | add_0 | ---> ┌-------┐ + // [Const 1] ---> └-------┘ | | ┌-------┐ + // | add_3 | ---> | | + // [Const 2] ---> ┌-------┐ | | | | ┌------┐ + // | add_1 | ---> └-------┘ | add_4 | ---> | ROOT | + // [Const 3] ---> └-------┘ | | └------┘ + // | | + // [Const 4] ---------------------------------> └-------┘ + // + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_r, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + constant.2 = f32[] constant(0) + constant.3 = f32[] constant(0) + constant.4 = f32[] constant(0) + add.0 = f32[] add(constant.0, constant.1) + add.1 = f32[] add(constant.2, constant.3) + add.3 = f32[] add(add.0, add.1) + add.4 = f32[] add(add.3, constant.4) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_r, + HloGumgraph::Create(module_r.get())); + auto mappings = std::make_unique(); + auto matcher = + std::make_unique(graph_l.get(), graph_r.get()); + // Root nodes are matched by default before the matcher is called. 
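+  // Starting from the matched roots, the top-down matcher should descend
+  // while opcodes agree and stop at subtract.1 vs add.1; constant.2 and
+  // constant.3 below that mismatch are expected to stay unmatched.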
+ mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(), + MatcherType::kManual); + matcher->Match(*mappings); + auto mapped_nodes = ExtractMappedInstructionNames(*mappings); + + EXPECT_THAT(mapped_nodes, + UnorderedElementsAre( + Pair("constant.0", "constant.0"), + Pair("constant.1", "constant.1"), Pair("add.0", "add.0"), + Pair("add.3", "add.3"), Pair("constant.4", "constant.4"), + Pair("add.4", "add.4"), Pair("root_L", "root_R"))); +} + +TEST_F(HloMatcherTest, GreedyTopDownMatcherStopAtMappedNode) { + // Create left module with entry computation containing the following + // structure: + // [const.0] ---> ┌-------┐ + // | add.0 | ---> ┌-------┐ + // [const.1] ---> └-------┘ | | ┌-------┐ + // | add.3 | ---> | | + // [const.2] ---> ┌-------┐ | | | | ┌------┐ + // | add.1 | ---> └-------┘ | add.4 | ---> | ROOT | + // [const.3] ---> └-------┘ | | └------┘ + // | | + // [const.4] ---------------------------------> └-------┘ + // + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_l, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + constant.2 = f32[] constant(0) + constant.3 = f32[] constant(0) + constant.4 = f32[] constant(0) + add.0 = f32[] add(constant.0, constant.1) + add.1 = f32[] add(constant.2, constant.3) + add.3 = f32[] add(add.0, add.1) + add.4 = f32[] add(add.3, constant.4) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_l, + HloGumgraph::Create(module_l.get())); + + // Create right module with entry computation containing the following + // structure: + // [const.0] ---> ┌-------┐ + // | add.0 | ---> ┌-------┐ + // [const.1] ---> └-------┘ | | ┌-------┐ + // | add.3 | ---> | | + // [const.2] ---> ┌-------┐ | | | | + // | add.1 | ---> └-------┘ | | ┌------┐ + // [const.3] ---> └-------┘ | add.4 | ---> | ROOT | + // | | └------┘ + // [const.4] ---> ┌-------┐ | | + // | add.2 | ------------------> | | + // [const.5] ---> └-------┘ └-------┘ + // + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_r, + ParseAndReturnVerifiedModule(R"( +HloModule module, is_scheduled=true + +ENTRY entry { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(0) + constant.2 = f32[] constant(0) + constant.3 = f32[] constant(0) + constant.4 = f32[] constant(0) + constant.5 = f32[] constant(0) + add.0 = f32[] add(constant.0, constant.1) + add.1 = f32[] add(constant.2, constant.3) + add.2 = f32[] add(constant.4, constant.5) + add.3 = f32[] add(add.0, add.1) + add.4 = f32[] add(add.3, add.2) +} +)")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_r, + HloGumgraph::Create(module_r.get())); + auto mappings = std::make_unique(); + ASSERT_NO_FATAL_FAILURE( + OverwriteMapInstructions(GetNodeByName(*graph_l, "add.4"), + GetNodeByName(*graph_r, "add.4"), *mappings)); + ASSERT_NO_FATAL_FAILURE( + OverwriteMapInstructions(GetNodeByName(*graph_l, "add.1"), + GetNodeByName(*graph_r, "add.2"), *mappings)); + auto matcher = + std::make_unique(graph_l.get(), graph_r.get()); + // Root nodes are matched by default before the matcher is called. 
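+  // add.1 <-> add.2 is seeded above, so the traversal should treat that pair
+  // as a boundary: left constant.4 stays unmatched, while add.1's operands
+  // map to add.2's operands (constant.2 -> constant.4, constant.3 ->
+  // constant.5).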
+ mappings->MapInstructionsIfAbsent(&graph_l->GetRoot(), &graph_r->GetRoot(), + MatcherType::kManual); + matcher->Match(*mappings); + auto mapped_nodes = ExtractMappedInstructionNames(*mappings); + + EXPECT_THAT( + mapped_nodes, + UnorderedElementsAre( + Pair("constant.0", "constant.0"), Pair("constant.1", "constant.1"), + Pair("add.0", "add.0"), Pair("constant.2", "constant.4"), + Pair("constant.3", "constant.5"), Pair("add.1", "add.2"), + Pair("add.3", "add.3"), Pair("add.4", "add.4"), + Pair("root_L", "root_R"))); +} + +} // namespace +} // namespace hlo_diff +} // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD new file mode 100644 index 00000000000000..f868bf87f3d83b --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD @@ -0,0 +1,76 @@ +load("//xla:xla.default.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "hlo_gumgraph_renderer_util", + srcs = ["hlo_gumgraph_renderer_util.cc"], + hdrs = ["hlo_gumgraph_renderer_util.h"], + deps = [ + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "hlo_gumgraph_renderer_util_test", + srcs = ["hlo_gumgraph_renderer_util_test.cc"], + deps = [ + ":hlo_gumgraph_renderer_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "hlo_gumgraph_text_renderer", + srcs = ["hlo_gumgraph_text_renderer.cc"], + hdrs = ["hlo_gumgraph_text_renderer.h"], + deps = [ + ":hlo_gumgraph_renderer_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/tools/hlo_diff:hlo_diff_result", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "hlo_gumgraph_html_renderer", + srcs = ["hlo_gumgraph_html_renderer.cc"], + hdrs = ["hlo_gumgraph_html_renderer.h"], + deps = [ + ":hlo_gumgraph_renderer_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/tools/hlo_diff:hlo_diff_result", + "//xla/hlo/tools/hlo_diff:hlo_diff_summary", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc new file mode 100644 index 00000000000000..fb788ca7745a0f --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc @@ -0,0 +1,494 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the 
License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+#include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+/*** HTML printing functions ***/
+
+// Prints the CSS styles for the HTML output, covering the section, list and
+// tooltip classes emitted by the printers below.
+std::string PrintCss() {
+  return R"(
+    <style>
+      .section { margin-bottom: 10px; }
+      .section > .header { font-weight: bold; }
+      .list > .item { margin-left: 15px; }
+      .tooltip { position: relative; text-decoration: underline dotted; }
+      .tooltip > .tooltiptext {
+        visibility: hidden;
+        position: absolute;
+        z-index: 1;
+        white-space: pre;
+        background-color: #555;
+        color: #fff;
+        padding: 5px;
+      }
+      .tooltip:hover > .tooltiptext { visibility: visible; }
+    </style>
+  )";
+}
+
+// Prints the div html block.
+std::string PrintDiv(absl::string_view content, absl::string_view class_name) {
+  return absl::StrFormat(R"(<div class="%s">%s</div>)", class_name, content);
+}
+
+// Prints the details html block.
+std::string PrintDetails(absl::string_view summary, absl::string_view content) {
+  return absl::StrFormat(R"(<details><summary>%s</summary>%s</details>)",
+                         summary, PrintDiv(content, "content"));
+}
+
+// Prints an html block with a header.
+std::string PrintSectionWithHeader(absl::string_view header,
+                                   absl::string_view content) {
+  return PrintDiv(
+      absl::StrCat(PrintDiv(header, "header"), PrintDiv(content, "content")),
+      "section");
+}
+
+// Prints a list of items.
+std::string PrintList(absl::Span<const std::string> items) {
+  return PrintDiv(absl::StrJoin(items, "",
+                                [](std::string* out, const auto& item) {
+                                  absl::StrAppend(out, PrintDiv(item, "item"));
+                                }),
+                  "list");
+}
+
+// Prints a link to the instruction in model explorer if url_generator returns
+// a non-empty url, otherwise returns the text directly.
+std::string PrintInstructionLink(const HloInstruction* left_inst,
+                                 const HloInstruction* right_inst,
+                                 absl::string_view text,
+                                 UrlGenerator url_generator) {
+  std::string url = url_generator(left_inst, right_inst);
+
+  if (url.empty()) {
+    return std::string(text);
+  }
+  return absl::StrFormat(R"(<a href="%s">%s</a>)", url, text);
+}
+
+std::string PrintTooltip(absl::string_view text,
+                         absl::string_view tooltip_text) {
+  return absl::StrFormat(
+      R"(<span class="tooltip">%s<span class="tooltiptext">%s</span></span>)",
+      text, tooltip_text);
+}
+
+/*** Summary logic ***/
+
+// The location of the instruction in the diff result.
+enum class InstructionLocation : std::uint8_t { kLeft, kRight };
+
+// Prints a list of instructions.
+std::string PrintInstructionsAsList(
+    absl::Span<const HloInstruction* const> instructions,
+    InstructionLocation location, bool name_only, UrlGenerator url_generator) {
+  std::vector<std::string> instructions_list;
+  for (const HloInstruction* inst : instructions) {
+    std::string link;
+    if (location == InstructionLocation::kLeft) {
+      link = PrintInstructionLink(inst, /*right_inst=*/nullptr,
+                                  InstructionToString(inst, name_only),
+                                  url_generator);
+    } else {
+      link = PrintInstructionLink(/*left_inst=*/nullptr, inst,
+                                  InstructionToString(inst, name_only),
+                                  url_generator);
+    }
+    instructions_list.push_back(link);
+  }
+  return PrintList(instructions_list);
+}
+
+// Prints a list of instruction pairs.
+std::string PrintInstructionPairsAsList(
+    absl::Span<const std::pair<const HloInstruction*, const HloInstruction*>>
+        instruction_pairs,
+    const std::function<std::string(const HloInstruction*,
+                                    const HloInstruction*)>&
+        instruction_pair_printer) {
+  std::vector<std::string> pair_list;
+  for (const auto& pair : instruction_pairs) {
+    pair_list.push_back(instruction_pair_printer(pair.first, pair.second));
+  }
+  return PrintList(pair_list);
+}
+
+// Prints unmatched instructions grouped by opcode, in descending order of the
+// number of instructions per opcode.
+std::string PrintUnmatchedInstructions(
+    absl::Span<const HloInstruction* const> instructions,
+    InstructionLocation location,
+    const absl::flat_hash_set<HloOpcode>& opcodes_to_ignore, bool name_only,
+    UrlGenerator url_generator) {
+  absl::flat_hash_map<HloOpcode, std::vector<const HloInstruction*>>
+      instructions_by_opcode = GroupInstructionsByOpcode(instructions);
+  std::vector<std::pair<HloOpcode, size_t>> opcode_counts;
+  for (const auto& [opcode, insts] : instructions_by_opcode) {
+    opcode_counts.push_back({opcode, insts.size()});
+  }
+  std::sort(opcode_counts.begin(), opcode_counts.end(),
+            [](const auto& a, const auto& b) { return a.second > b.second; });
+  std::stringstream ss;
+  for (auto cit = opcode_counts.begin(); cit != opcode_counts.end(); ++cit) {
+    if (opcodes_to_ignore.contains(cit->first)) {
+      continue;
+    }
+    ss << PrintDetails(
+        absl::StrFormat("%s (%d)", HloOpcodeString(cit->first), cit->second),
+        PrintInstructionsAsList(instructions_by_opcode[cit->first], location,
+                                name_only, url_generator));
+  }
+  return ss.str();
+}
+
+// Prints instruction pairs grouped by opcode, in descending order of the
+// number of instruction pairs per opcode.
+std::string PrintInstructionPairsByOpcode(
+    const absl::flat_hash_map<const HloInstruction*, const HloInstruction*>&
+        instructions,
+    const absl::flat_hash_set<HloOpcode>& opcodes_to_ignore,
+    const std::function<std::string(const HloInstruction*,
+                                    const HloInstruction*)>&
+        instruction_pair_printer) {
+  absl::flat_hash_map<
+      HloOpcode,
+      std::vector<std::pair<const HloInstruction*, const HloInstruction*>>>
+      instructions_by_opcode = GroupInstructionPairsByOpcode(instructions);
+  std::vector<std::pair<HloOpcode, size_t>> opcode_counts;
+  for (const auto& [opcode, insts] : instructions_by_opcode) {
+    opcode_counts.push_back({opcode, insts.size()});
+  }
+  std::sort(opcode_counts.begin(), opcode_counts.end(),
+            [](const auto& a, const auto& b) { return a.second > b.second; });
+  std::stringstream ss;
+  for (auto cit = opcode_counts.begin(); cit != opcode_counts.end(); ++cit) {
+    if (opcodes_to_ignore.contains(cit->first)) {
+      continue;
+    }
+    absl::string_view op_name = HloOpcodeString(cit->first);
+    ss << PrintDetails(
+        absl::StrFormat("%s (%d)", op_name, cit->second),
+        PrintInstructionPairsAsList(instructions_by_opcode.at(cit->first),
+                                    instruction_pair_printer));
+  }
+  return ss.str();
+}
+
+// Prints the summary of the changed instruction diff type.
+std::string PrintChangedInstructionDiffTypeSummary(
+    const HloInstruction* left_inst, const HloInstruction* right_inst,
+    ChangedInstructionDiffType diff_type) {
+  switch (diff_type) {
+    case ChangedInstructionDiffType::kShapeChange:
+      return absl::StrFormat(
+          "left: %s\nright: %s",
+          left_inst->shape().ToString(/*print_layout=*/true),
+          right_inst->shape().ToString(/*print_layout=*/true));
+    case ChangedInstructionDiffType::kLayoutChange:
+      return absl::StrFormat("left: %s\nright: %s",
+                             left_inst->shape().layout().ToString(),
+                             right_inst->shape().layout().ToString());
+    case ChangedInstructionDiffType::kMemorySpaceChange:
+      return absl::StrFormat("left: %d\nright: %d",
+                             left_inst->shape().layout().memory_space(),
+                             right_inst->shape().layout().memory_space());
+    case ChangedInstructionDiffType::kChangedOperandsNumber:
+      return absl::StrFormat("left: %d\nright: %d", left_inst->operand_count(),
+                             right_inst->operand_count());
+    case ChangedInstructionDiffType::kChangedOperandsShape: {
+      std::vector<std::string> operand_shape_diffs;
+      for (int64_t i = 0; i < left_inst->operand_count(); ++i) {
+        if (left_inst->operand(i)->shape() != right_inst->operand(i)->shape()) {
+          operand_shape_diffs.push_back(absl::StrFormat(
+              "operand %d (%s):\n  left: %s\n  right: %s", i,
+              HloOpcodeString(left_inst->operand(i)->opcode()),
+              left_inst->operand(i)->shape().ToString(/*print_layout=*/true),
+              right_inst->operand(i)->shape().ToString(/*print_layout=*/true)));
+        }
+      }
+      return absl::StrJoin(operand_shape_diffs, "\n");
+    }
+    case ChangedInstructionDiffType::kOpCodeChanged:
+      return absl::StrFormat("left: %s\nright: %s",
+                             HloOpcodeString(left_inst->opcode()),
+                             HloOpcodeString(right_inst->opcode()));
+    case ChangedInstructionDiffType::kConstantLiteralChanged:
+      return absl::StrFormat("left: %s\nright: %s",
+                             left_inst->literal().ToString(),
+                             right_inst->literal().ToString());
+    default:
+      return "Other changes";
+  }
+}
+
+// Prints changed instructions grouped by opcode, in descending order of the
+// number of instructions per opcode.
+std::string PrintChangedInstructions(
+    const absl::flat_hash_map<const HloInstruction*, const HloInstruction*>&
+        instructions,
+    const absl::flat_hash_set<HloOpcode>& opcodes_to_ignore,
+    UrlGenerator url_generator) {
+  auto decorated_printer = [&url_generator](const HloInstruction* left_inst,
+                                            const HloInstruction* right_inst) {
+    std::vector<ChangedInstructionDiffType> diff_types =
+        GetChangedInstructionDiffTypes(*left_inst, *right_inst);
+    return absl::StrFormat(
+        "%s have changed: %s",
+        PrintInstructionLink(
+            left_inst, right_inst,
+            absl::StrFormat(
+                "%s and %s", InstructionToString(left_inst, /*name_only=*/true),
+                InstructionToString(right_inst, /*name_only=*/true)),
+            url_generator),
+        absl::StrJoin(
+            diff_types, ", ",
+            [&left_inst, &right_inst](std::string* out,
+                                      const auto& diff_type) {
+              std::string diff_type_string =
+                  GetChangedInstructionDiffTypeString(diff_type);
+              absl::StrAppend(
+                  out,
+                  diff_type == ChangedInstructionDiffType::kOtherChange
+                      ? diff_type_string
+                      : PrintTooltip(diff_type_string,
+                                     PrintChangedInstructionDiffTypeSummary(
+                                         left_inst, right_inst, diff_type)));
+            }));
+  };
+  return PrintInstructionPairsByOpcode(instructions, opcodes_to_ignore,
+                                       decorated_printer);
+}
+
+// Prints unchanged instructions grouped by opcode, in descending order of the
+// number of instructions per opcode.
+std::string PrintUnchangedInstructions(
+    const absl::flat_hash_map<const HloInstruction*, const HloInstruction*>&
+        instructions,
+    const absl::flat_hash_set<HloOpcode>& opcodes_to_ignore,
+    UrlGenerator url_generator) {
+  auto simple_printer = [&url_generator](const HloInstruction* left_inst,
+                                         const HloInstruction* right_inst) {
+    return PrintInstructionLink(
+        left_inst, right_inst,
+        absl::StrFormat("%s and %s",
+                        InstructionToString(left_inst, /*name_only=*/true),
+                        InstructionToString(right_inst, /*name_only=*/true)),
+        url_generator);
+  };
+  return PrintInstructionPairsByOpcode(instructions, opcodes_to_ignore,
+                                       simple_printer);
+}
+
+std::string PrintUnmatchedMetricsDiff(
+    absl::Span<const HloInstruction* const> instructions,
+    GetOpMetricFn get_op_metrics, UrlGenerator url_generator) {
+  std::vector<std::pair<const HloInstruction*, double>> sorted_metrics_diff;
+  for (const HloInstruction* inst : instructions) {
+    if (auto metric = get_op_metrics(inst->name()); metric.has_value()) {
+      sorted_metrics_diff.push_back({inst, static_cast<double>(*metric)});
+    }
+  }
+
+  std::sort(sorted_metrics_diff.begin(), sorted_metrics_diff.end());
+  std::vector<std::string> metrics_diff_list;
+  metrics_diff_list.reserve(sorted_metrics_diff.size());
+  for (const auto& [inst, metrics_diff] : sorted_metrics_diff) {
+    metrics_diff_list.push_back(
+        absl::StrFormat("%s: %.2f (us)",
+                        PrintInstructionLink(inst, /*right_inst=*/nullptr,
+                                             inst->name(), url_generator),
+                        metrics_diff / 1e6));
+  }
+  return PrintList(metrics_diff_list);
+}
+
+std::string PrintMatchedMetricsDiff(
+    const absl::flat_hash_map<const HloInstruction*, const HloInstruction*>&
+        instructions,
+    GetOpMetricFn left_op_metrics, GetOpMetricFn right_op_metrics,
+    UrlGenerator url_generator) {
+  std::vector<std::pair<
+      std::pair<const HloInstruction*, const HloInstruction*>, double>>
+      sorted_metrics_diff;
+  for (const auto& [left_inst, right_inst] : instructions) {
+    auto left_metric = left_op_metrics(left_inst->name());
+    auto right_metric = right_op_metrics(right_inst->name());
+    if (left_metric.has_value() && right_metric.has_value()) {
+      sorted_metrics_diff.push_back(
+          {{left_inst, right_inst},
+           static_cast<double>(*left_metric - *right_metric)});
+    }
+  }
+  std::sort(sorted_metrics_diff.begin(), sorted_metrics_diff.end());
+  std::vector<std::string> metrics_diff_list;
+  metrics_diff_list.reserve(sorted_metrics_diff.size());
+  for (const auto& [inst_pair, metrics_diff] : sorted_metrics_diff) {
+    const auto& [left_inst, right_inst] = inst_pair;
+    metrics_diff_list.push_back(absl::StrFormat(
+        "%s: %.2f (us)",
+        PrintInstructionLink(
+            left_inst, right_inst,
+            absl::StrFormat("%s and %s", left_inst->name(),
+                            right_inst->name()),
+            url_generator),
+        metrics_diff / 1e6));
+  }
+  return PrintList(metrics_diff_list);
+}
+
+}  // namespace
+
+void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary,
+                UrlGenerator url_generator, GetOpMetricFn left_op_metrics,
+                GetOpMetricFn right_op_metrics, std::ostringstream& out) {
+  const absl::flat_hash_set<HloOpcode> ignored_opcodes(kIgnoredOpcodes.begin(),
+                                                       kIgnoredOpcodes.end());
+  out << PrintCss();
+
+  // Print full diff results
+  out << PrintSectionWithHeader(
+      "Full Diff Results",
+      absl::StrCat(
+          PrintDetails(
+              absl::StrFormat(
+                  "Unmatched Instructions (left) (%d)",
+                  diff_result.left_module_unmatched_instructions.size()),
+              PrintUnmatchedInstructions(
+                  diff_result.left_module_unmatched_instructions,
+                  InstructionLocation::kLeft, ignored_opcodes,
+                  /*name_only=*/false, url_generator)),
+          PrintDetails(
+              absl::StrFormat(
+                  "Unmatched Instructions (right) (%d)",
+                  diff_result.right_module_unmatched_instructions.size()),
+              PrintUnmatchedInstructions(
+                  diff_result.right_module_unmatched_instructions,
+                  InstructionLocation::kRight, ignored_opcodes,
+                  /*name_only=*/false, url_generator)),
+          PrintDetails(
+              absl::StrFormat("Changed Instructions (%d)",
+                              diff_result.changed_instructions.size()),
+              PrintChangedInstructions(diff_result.changed_instructions,
+                                       ignored_opcodes, url_generator))));
+
+  // Print profile metrics diff
+  out << PrintSectionWithHeader(
+      "Profile Metrics Diff",
+      absl::StrCat(
+          PrintDetails("Left Module Unmatched Instructions",
+                       PrintUnmatchedMetricsDiff(
+                           diff_result.left_module_unmatched_instructions,
+                           left_op_metrics, url_generator)),
+          PrintDetails("Right Module Unmatched Instructions",
+                       PrintUnmatchedMetricsDiff(
+                           diff_result.right_module_unmatched_instructions,
+                           right_op_metrics, url_generator)),
+          PrintDetails("Changed Instructions",
+                       PrintMatchedMetricsDiff(
+                           diff_result.changed_instructions, left_op_metrics,
+                           right_op_metrics, url_generator)),
+          PrintDetails("Unchanged Instructions",
+                       PrintMatchedMetricsDiff(
+                           diff_result.unchanged_instructions, left_op_metrics,
+                           right_op_metrics, url_generator))));
+}
+
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h
new file mode 100644
index 00000000000000..c2244629392fed
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_HTML_RENDERER_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_HTML_RENDERER_H_
+
+#include <cstdint>
+#include <optional>
+#include <sstream>
+#include <string>
+
+#include "absl/functional/function_ref.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+
+namespace xla {
+namespace hlo_diff {
+
+// A function that returns a visualization url for the given instruction pair.
+using UrlGenerator = absl::FunctionRef<std::string(const HloInstruction*,
+                                                   const HloInstruction*)>;
+
+// A function that returns the op metric for the given op name.
+using GetOpMetricFn =
+    absl::FunctionRef<std::optional<int64_t>(absl::string_view)>;
+
+// Renders the diff result in HTML format and writes it to the given output
+// stream. url_generator is used to link each rendered instruction (pair) to a
+// visualization url.
+void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary,
+                UrlGenerator url_generator, GetOpMetricFn left_op_metrics,
+                GetOpMetricFn right_op_metrics, std::ostringstream& out);
+
+}  // namespace hlo_diff
+}  // namespace xla
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_HTML_RENDERER_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.cc
new file mode 100644
index 00000000000000..d33d2099d49ee8
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.cc
@@ -0,0 +1,140 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h"
+
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_print_options.h"
+
+namespace xla {
+namespace hlo_diff {
+
+std::string InstructionToString(const HloInstruction* instr, bool name_only) {
+  if (name_only) {
+    return std::string(instr->name());
+  }
+  return instr->ToString(HloPrintOptions::ShortParsable());
+}
+
+std::vector<ChangedInstructionDiffType> GetChangedInstructionDiffTypes(
+    const HloInstruction& left, const HloInstruction& right) {
+  // Compare shapes, layouts and memory spaces
+  std::vector<ChangedInstructionDiffType> diff_types;
+  if (left.shape() != right.shape()) {
+    diff_types.push_back(ChangedInstructionDiffType::kShapeChange);
+
+    if (left.shape().IsArray() && right.shape().IsArray() &&
+        left.shape().has_layout() && right.shape().has_layout() &&
+        (left.shape().layout() != right.shape().layout())) {
+      diff_types.push_back(ChangedInstructionDiffType::kLayoutChange);
+      if (left.shape().layout().memory_space() !=
+          right.shape().layout().memory_space()) {
+        diff_types.push_back(ChangedInstructionDiffType::kMemorySpaceChange);
+      }
+    }
+  }
+
+  // Compare operand numbers and shapes
+  if (left.operand_count() != right.operand_count()) {
+    diff_types.push_back(ChangedInstructionDiffType::kChangedOperandsNumber);
+  } else {  // If operand numbers are the same, compare shapes
+    for (int64_t i = 0; i < left.operand_count(); ++i) {
+      if (left.operand(i)->shape() != right.operand(i)->shape()) {
+        diff_types.push_back(ChangedInstructionDiffType::kChangedOperandsShape);
+        break;
+      }
+    }
+  }
+
+  // Compare opcodes
+  if (left.opcode() != right.opcode()) {
+    diff_types.push_back(ChangedInstructionDiffType::kOpCodeChanged);
+  }
+
+  // Compare constants
+  if (left.IsConstant() && right.IsConstant()) {
+    if (left.literal() != right.literal()) {
+      diff_types.push_back(ChangedInstructionDiffType::kConstantLiteralChanged);
+    }
+  }
+
+  // If no diff type is found, return kOtherChange.
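+  // (e.g. a change only in metadata or operand order, which none of the
+  // checks above detect).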
+  if (diff_types.empty()) {
+    diff_types.push_back(ChangedInstructionDiffType::kOtherChange);
+  }
+
+  return diff_types;
+}
+
+std::string GetChangedInstructionDiffTypeString(
+    ChangedInstructionDiffType diff_type) {
+  switch (diff_type) {
+    case ChangedInstructionDiffType::kOtherChange:
+      return "kOtherChange";
+    case ChangedInstructionDiffType::kShapeChange:
+      return "kShapeChange";
+    case ChangedInstructionDiffType::kLayoutChange:
+      return "kLayoutChange";
+    case ChangedInstructionDiffType::kMemorySpaceChange:
+      return "kMemorySpaceChange";
+    case ChangedInstructionDiffType::kChangedOperandsNumber:
+      return "kChangedOperandsNumber";
+    case ChangedInstructionDiffType::kChangedOperandsShape:
+      return "kChangedOperandsShape";
+    case ChangedInstructionDiffType::kOpCodeChanged:
+      return "kOpCodeChanged";
+    case ChangedInstructionDiffType::kConstantLiteralChanged:
+      return "kConstantLiteralChanged";
+    default:
+      return "";
+  }
+}
+
+absl::flat_hash_map<HloOpcode, std::vector<const HloInstruction*>>
+GroupInstructionsByOpcode(
+    absl::Span<const HloInstruction* const> instructions) {
+  absl::flat_hash_map<HloOpcode, std::vector<const HloInstruction*>>
+      instructions_by_opcode;
+  for (const HloInstruction* inst : instructions) {
+    instructions_by_opcode[inst->opcode()].push_back(inst);
+  }
+  return instructions_by_opcode;
+}
+
+absl::flat_hash_map<
+    HloOpcode,
+    std::vector<std::pair<const HloInstruction*, const HloInstruction*>>>
+GroupInstructionPairsByOpcode(
+    const absl::flat_hash_map<const HloInstruction*, const HloInstruction*>&
+        instructions) {
+  absl::flat_hash_map<
+      HloOpcode,
+      std::vector<std::pair<const HloInstruction*, const HloInstruction*>>>
+      instructions_by_opcode;
+  for (const auto& pair : instructions) {
+    instructions_by_opcode[pair.first->opcode()].push_back(pair);
+  }
+  return instructions_by_opcode;
+}
+
+}  // namespace hlo_diff
+}  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h
new file mode 100644
index 00000000000000..01609c53f5b302
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2025 The OpenXLA Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_RENDERER_UTIL_H_
+#define XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_RENDERER_UTIL_H_
+
+#include <array>
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+
+namespace xla {
+namespace hlo_diff {
+
+// Prints the instruction to string.
+std::string InstructionToString(const HloInstruction* instr, bool name_only);
+
+// Enum representing the type of changes for a pair of changed instructions.
+enum class ChangedInstructionDiffType : uint8_t {
+  kOtherChange,
+  kShapeChange,
+  kLayoutChange,
+  kMemorySpaceChange,
+  kChangedOperandsNumber,
+  kChangedOperandsShape,
+  kOpCodeChanged,
+  kConstantLiteralChanged,
+};
+
+// Returns details on what exactly has changed for a pair of changed
+// instructions.
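+// For example (illustrative), comparing `x = f32[8] add(a, b)` against
+// `x = bf16[8] add(a, b)` with identical operands would yield
+// {kShapeChange}.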
+std::vector<ChangedInstructionDiffType> GetChangedInstructionDiffTypes(
+    const HloInstruction& left, const HloInstruction& right);
+
+// Converts the changed instruction diff type enum value to a string.
+std::string GetChangedInstructionDiffTypeString(
+    ChangedInstructionDiffType diff_type);
+
+// Opcodes to be ignored when printing summaries.
+inline constexpr auto kIgnoredOpcodes = std::array<HloOpcode, 6>(
+    {HloOpcode::kReshape, HloOpcode::kBitcast, HloOpcode::kPad,
+     HloOpcode::kCopyDone, HloOpcode::kCopyStart, HloOpcode::kGetTupleElement});
+
+// Groups the instructions by opcode.
+absl::flat_hash_map<HloOpcode, std::vector<const HloInstruction*>>
+GroupInstructionsByOpcode(
+    absl::Span<const HloInstruction* const> instructions);
+
+// Groups the instruction pairs by opcode.
+absl::flat_hash_map<
+    HloOpcode,
+    std::vector<std::pair<const HloInstruction*, const HloInstruction*>>>
+GroupInstructionPairsByOpcode(
+    const absl::flat_hash_map<const HloInstruction*, const HloInstruction*>&
+        instructions);
+
+}  // namespace hlo_diff
+}  // namespace xla
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_RENDERER_UTIL_H_
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util_test.cc
new file mode 100644
index 00000000000000..f5d7e60061dcec
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util_test.cc
@@ -0,0 +1,92 @@
+// Copyright 2025 The OpenXLA Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h" + +#include +#include + +#include +#include +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { +namespace hlo_diff { +namespace { + +using ::testing::Pair; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +class HloDiffRendererUtilTest : public HloTestBase {}; + +TEST_F(HloDiffRendererUtilTest, GroupInstructionsByOpcode) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule test_module + +ENTRY test_computation { + param1 = s32[10] parameter(0) + param2 = s32[10] parameter(1) + add = s32[10] add(param1, param2) + ROOT sub = s32[10] subtract(add, param2) +} + )")); + std::vector instructions; + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + instructions.push_back(instruction); + } + } + + EXPECT_THAT(GroupInstructionsByOpcode(instructions), + UnorderedElementsAre(Pair(HloOpcode::kParameter, SizeIs(2)), + Pair(HloOpcode::kAdd, SizeIs(1)), + Pair(HloOpcode::kSubtract, SizeIs(1)))); +} + +TEST_F(HloDiffRendererUtilTest, GroupInstructionPairsByOpcode) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule test_module + +ENTRY test_computation { + param1 = s32[10] parameter(0) + param2 = s32[10] parameter(1) + add = s32[10] add(param1, param2) + ROOT sub = s32[10] subtract(add, param2) +} + )")); + absl::flat_hash_map + instruction_map; + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + instruction_map[instruction] = instruction; + } + } + + EXPECT_THAT(GroupInstructionPairsByOpcode(instruction_map), + UnorderedElementsAre(Pair(HloOpcode::kParameter, SizeIs(2)), + Pair(HloOpcode::kAdd, SizeIs(1)), + Pair(HloOpcode::kSubtract, SizeIs(1)))); +} + +} // namespace +} // namespace hlo_diff +} // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.cc new file mode 100644 index 00000000000000..ee27ea2bd36aab --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.cc @@ -0,0 +1,271 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h" +#include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h" + +namespace xla { +namespace hlo_diff { +namespace { + +// Prints unmatched instructions grouped by opcode and print in a descending +// order of the number of instructions for each opcode. If top_n_opcodes or +// max_instructions_per_opcode is a negative number, all the instructions will +// be printed. +void PrintUnmatchedInstructions( + const absl::string_view header, + absl::Span instructions, + std::ostringstream& out, const RenderTextOptions& options) { + out << header; + if (options.top_n_opcodes >= 0) { + out << " (top " << options.top_n_opcodes << " frequent opcode)"; + } + if (!options.opcodes_to_ignore.empty()) { + out << " (ignoring " + << absl::StrJoin(options.opcodes_to_ignore, ", ", + [](std::string* out, const HloOpcode& opcode) { + absl::StrAppend(out, HloOpcodeString(opcode)); + }) + << ")"; + } + out << ":\n"; + + absl::flat_hash_map> + instructions_by_opcode = GroupInstructionsByOpcode(instructions); + std::vector> opcode_counts; + for (const auto& [opcode, insts] : instructions_by_opcode) { + opcode_counts.push_back({opcode, insts.size()}); + } + std::sort(opcode_counts.begin(), opcode_counts.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + // Print the top N most frequent opcodes + int i = 0; + for (auto cit = opcode_counts.begin(); + (options.top_n_opcodes < 0 || i < options.top_n_opcodes) && + cit != opcode_counts.end(); + ++cit) { + if (options.opcodes_to_ignore.contains(cit->first)) { + continue; + } + absl::string_view op_name = HloOpcodeString(cit->first); + out << " " << op_name << " (" << cit->second << "):\n"; + std::vector insts = + instructions_by_opcode[cit->first]; + // Print the M instructions for each opcode + int j = 0; + for (auto iit = insts.begin(); (options.max_instructions_per_opcode < 0 || + j < options.max_instructions_per_opcode) && + iit != insts.end(); + ++j, ++iit) { + out << " " << InstructionToString(*iit, options.name_only) << "\n"; + } + if (j < insts.size()) { + out << " ... and " << insts.size() - j << " more " << op_name + << " instructions\n"; + } + out << "\n"; + ++i; + } + if (i < opcode_counts.size()) { + out << " ... and " << opcode_counts.size() - i << " more opcodes\n"; + } + out << "\n"; +} + +// Prints changed or unchanged instructions grouped by opcode and print in a +// descending order of the number of instructions for each opcode. If +// top_n_opcodes or max_instructions_per_opcode is a negative number, all the +// instructions will be printed. 
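+//
+// Sample output shape (hypothetical instruction names):
+//   Changed Instructions (top 5 frequent opcode):
+//     add (2), top diff types: kShapeChange (2)
+//       add.1 and add.9 have changed: kShapeChange
+//       ...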
+void PrintChangedAndUnchangedInstructions( + absl::string_view header, + const absl::flat_hash_map& + instructions, + std::ostringstream& out, bool is_changed_pair, + const RenderTextOptions& options) { + out << header; + if (options.top_n_opcodes >= 0) { + out << " (top " << options.top_n_opcodes << " frequent opcode)"; + } + if (!options.opcodes_to_ignore.empty()) { + out << " (ignoring " + << absl::StrJoin(options.opcodes_to_ignore, ", ", + [](std::string* out, const HloOpcode& opcode) { + absl::StrAppend(out, HloOpcodeString(opcode)); + }) + << ")"; + } + out << ":\n"; + absl::flat_hash_map< + HloOpcode, + std::vector>> + instructions_by_opcode = GroupInstructionPairsByOpcode(instructions); + std::vector> opcode_counts; + for (const auto& [opcode, insts] : instructions_by_opcode) { + opcode_counts.push_back({opcode, insts.size()}); + } + std::sort(opcode_counts.begin(), opcode_counts.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + // Print the top N most frequent opcodes + int i = 0; + for (auto cit = opcode_counts.begin(); + (options.top_n_opcodes < 0 || i < options.top_n_opcodes) && + cit != opcode_counts.end(); + ++cit) { + if (options.opcodes_to_ignore.contains(cit->first)) { + continue; + } + absl::string_view op_name = HloOpcodeString(cit->first); + out << " " << op_name << " (" << cit->second << ")"; + if (is_changed_pair) { + // Count and sort the number of diff types for each opcode + absl::flat_hash_map diff_type_counts; + for (const auto& inst_pair : instructions_by_opcode[cit->first]) { + std::vector diff_types = + GetChangedInstructionDiffTypes(*inst_pair.first, *inst_pair.second); + for (const auto& diff_type : diff_types) { + diff_type_counts[diff_type]++; + } + } + std::vector> + diff_type_counts_vec(diff_type_counts.begin(), + diff_type_counts.end()); + std::sort( + diff_type_counts_vec.begin(), diff_type_counts_vec.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + + out << ", top diff types: " + << absl::StrJoin( + diff_type_counts_vec, ", ", + [](std::string* out, const auto& pair) { + absl::StrAppend( + out, GetChangedInstructionDiffTypeString(pair.first), + " (", pair.second, ")"); + }); + } + out << "\n"; + std::vector> insts = + instructions_by_opcode[cit->first]; + + // Print the M instructions for each opcode + int j = 0; + for (auto iit = insts.begin(); (options.max_instructions_per_opcode < 0 || + j < options.max_instructions_per_opcode) && + iit != insts.end(); + ++j, ++iit) { + if (is_changed_pair) { + std::vector diff_types = + GetChangedInstructionDiffTypes(*iit->first, *iit->second); + out << " " << InstructionToString(iit->first, /*name_only=*/true) + << " and " << InstructionToString(iit->second, /*name_only=*/true) + << " have changed: " + << absl::StrJoin( + diff_types, ", ", + [](std::string* out, const auto& diff_type) { + return absl::StrAppend( + out, GetChangedInstructionDiffTypeString(diff_type)); + }) + << "\n"; + if (!options.name_only) { + out << " Left: " + << InstructionToString(iit->first, /*name_only=*/false) << "\n"; + out << " Right: " + << InstructionToString(iit->second, /*name_only=*/false) << "\n"; + } + } else { + out << " " << InstructionToString(iit->first, options.name_only) + << "\n"; + } + } + if (j < insts.size()) { + out << " ... and " << insts.size() - j << " more " << op_name + << " instructions\n"; + } + out << "\n"; + ++i; + } + if (i < opcode_counts.size()) { + out << " ... 
and " << opcode_counts.size() - i << " more opcodes\n"; + } + out << "\n"; +} + +} // namespace + +void RenderText(const DiffResult& diff_result, std::ostringstream& out, + const RenderTextOptions& options) { + // Print unmatched instructions + PrintUnmatchedInstructions("Unmatched Instructions (left)", + diff_result.left_module_unmatched_instructions, + out, options); + + PrintUnmatchedInstructions("Unmatched Instructions (right)", + diff_result.right_module_unmatched_instructions, + out, options); + + // Print changed instructions (print both left and right) + PrintChangedAndUnchangedInstructions("Changed Instructions", + diff_result.changed_instructions, out, + true, options); + + if (options.print_unchanged_instructions) { + // Print unchanged instructions (print only the first instruction) + PrintChangedAndUnchangedInstructions("Unchanged Instructions", + diff_result.unchanged_instructions, + out, false, options); + } +} + +void RenderTextSummary(const DiffResult& diff_result, std::ostringstream& out) { + // Print a summary of the diff results + out << "Diff Summary:\n"; + out << " Unmatched instructions (left): " + << diff_result.left_module_unmatched_instructions.size() << "\n"; + out << " Unmatched instructions (right): " + << diff_result.right_module_unmatched_instructions.size() << "\n"; + out << " Changed instructions: " << diff_result.changed_instructions.size() + << "\n"; + out << " Unchanged instructions: " + << diff_result.unchanged_instructions.size() << "\n"; + out << "\n"; + + RenderTextOptions options = { + .top_n_opcodes = 5, + .max_instructions_per_opcode = 5, + .name_only = true, + .opcodes_to_ignore = absl::flat_hash_set( + kIgnoredOpcodes.begin(), kIgnoredOpcodes.end()), + .print_unchanged_instructions = false}; + RenderText(diff_result, out, options); +} + +} // namespace hlo_diff +} // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.h b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.h new file mode 100644 index 00000000000000..bd56b34e55e3c5 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.h @@ -0,0 +1,55 @@ +/* + * Copyright 2025 The OpenXLA Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_TEXT_RENDERER_H_ +#define XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_TEXT_RENDERER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h" + +namespace xla { +namespace hlo_diff { + +// Options for rendering the diff result to text. +struct RenderTextOptions { + // Print the top n opcodes. If negative, print all opcodes. + int top_n_opcodes = -1; + // Print the top n instructions per opcode. If negative, print all + // instructions. + int max_instructions_per_opcode = -1; + // If true, only print the instruction name. Otherwise, print the full details + // of the instruction. 
+ bool name_only = false; + // Opcodes to be ignored when printing summaries. + absl::flat_hash_set opcodes_to_ignore; + // If true, print the unchanged instructions. + bool print_unchanged_instructions = true; +}; + +// Renders the diff result to a text output stream. +void RenderText(const DiffResult& diff_result, std::ostringstream& out, + const RenderTextOptions& options = {}); + +// Renders the diff summary to a text output stream. +void RenderTextSummary(const DiffResult& diff_result, std::ostringstream& out); + +} // namespace hlo_diff +} // namespace xla + +#endif // XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_TEXT_RENDERER_H_ diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/utils/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/utils/BUILD new file mode 100644 index 00000000000000..c820faaaad25b6 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/utils/BUILD @@ -0,0 +1,66 @@ +load("//xla:xla.default.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "connected_components", + srcs = ["connected_components.cc"], + hdrs = ["connected_components.h"], + deps = [ + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +xla_cc_test( + name = "connected_components_test", + srcs = ["connected_components_test.cc"], + deps = [ + ":connected_components", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "hlo_diff_util", + hdrs = ["hlo_diff_util.h"], + deps = [ + "//xla/hlo/ir:hlo", + "@local_tsl//tsl/platform:fingerprint", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/tools/hlo_diff:hlo_gumgraph_mappings", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph", + "//xla/hlo/tools/hlo_diff/graph:hlo_gumgraph_node", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_for_library", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components.cc b/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components.cc new file mode 100644 index 00000000000000..064cca6827223e --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components.cc @@ -0,0 +1,71 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "xla/hlo/tools/hlo_diff/utils/connected_components.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "xla/hlo/ir/hlo_computation.h" + +namespace xla { +namespace hlo_diff { + +// Find the representative of the set (with path compression) +const HloComputation* ConnectedComponentsFinder::Find(const HloComputation* i) { + if (parent_.find(i) == parent_.end() || parent_[i] == i) { + parent_[i] = i; + return i; + } + return parent_[i] = Find(parent_[i]); // Path compression +} + +// Union the sets containing a and b (by making one parent the other) +void ConnectedComponentsFinder::Union(const HloComputation* a, + const HloComputation* b) { + const HloComputation* root_a = Find(a); + const HloComputation* root_b = Find(b); + if (root_a != root_b) { + parent_[root_a] = root_b; + } +} + +// Add an edge between two computations +void ConnectedComponentsFinder::AddEdge(const HloComputation* u, + const HloComputation* v) { + nodes_.insert(u); + nodes_.insert(v); + Union(u, v); +} + +// Find and return the connected components +std::vector> +ConnectedComponentsFinder::FindConnectedComponents() { + absl::flat_hash_map> + components; + for (const auto& node : nodes_) { + components[Find(node)].push_back(node); + } + + std::vector> result; + result.reserve(components.size()); + for (auto& [root, component_nodes] : components) { + result.push_back(component_nodes); + } + return result; +} + +} // namespace hlo_diff +} // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components.h b/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components.h new file mode 100644 index 00000000000000..61bef8aa2fc6da --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components.h @@ -0,0 +1,52 @@ +/* + * Copyright 2025 The OpenXLA Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_HLO_TOOLS_HLO_DIFF_UTILS_CONNECTED_COMPONENTS_H_ +#define XLA_HLO_TOOLS_HLO_DIFF_UTILS_CONNECTED_COMPONENTS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "xla/hlo/ir/hlo_computation.h" + +namespace xla { +namespace hlo_diff { + +// Finds the connected components in an undirected graph of HloComputations. 
+class ConnectedComponentsFinder { + public: + // Add an edge between two computations + void AddEdge(const HloComputation* u, const HloComputation* v); + + // Find and return the connected components + std::vector> FindConnectedComponents(); + + private: + // Find the representative of the set (with path compression) + const HloComputation* Find(const HloComputation* i); + + // Union the sets containing a and b (by making one parent the other) + void Union(const HloComputation* a, const HloComputation* b); + + absl::flat_hash_map parent_; + absl::flat_hash_set nodes_; +}; + +} // namespace hlo_diff +} // namespace xla + +#endif // XLA_HLO_TOOLS_HLO_DIFF_UTILS_CONNECTED_COMPONENTS_H_ diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components_test.cc new file mode 100644 index 00000000000000..2b6c0237506556 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components_test.cc @@ -0,0 +1,152 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/hlo/tools/hlo_diff/utils/connected_components.h" + +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/shape_util.h" + +namespace xla { +namespace hlo_diff { + +namespace { +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +// Helper function to create a simple HloComputation +std::unique_ptr MakeComputation(absl::string_view name) { + auto builder = HloComputation::Builder(name); + HloInstruction* x = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); + builder.AddInstruction(HloInstruction::CreateTuple({x})); + return builder.Build(); +} + +class ConnectedComponentsFinderTest : public ::testing::Test {}; + +TEST_F(ConnectedComponentsFinderTest, EmptyGraph) { + ConnectedComponentsFinder cc_finder; + EXPECT_THAT(cc_finder.FindConnectedComponents(), IsEmpty()); +} + +TEST_F(ConnectedComponentsFinderTest, SingleNodeNoEdges) { + ConnectedComponentsFinder cc_finder; + auto c1 = MakeComputation("c1"); + cc_finder.AddEdge(c1.get(), c1.get()); // Adding a self-loop, node exists + auto components = cc_finder.FindConnectedComponents(); + ASSERT_THAT(components.size(), 1); + EXPECT_THAT(components[0], UnorderedElementsAre(c1.get())); +} + +TEST_F(ConnectedComponentsFinderTest, TwoSeparateNodes) { + ConnectedComponentsFinder cc_finder; + auto c1 = MakeComputation("c1"); + auto c2 = MakeComputation("c2"); + // Don't add an edge between c1 and c2 + cc_finder.AddEdge(c1.get(), c1.get()); + cc_finder.AddEdge(c2.get(), c2.get()); + auto components = cc_finder.FindConnectedComponents(); + ASSERT_THAT(components.size(), 2); + EXPECT_THAT(components, UnorderedElementsAre(UnorderedElementsAre(c1.get()), + UnorderedElementsAre(c2.get()))); +} + +TEST_F(ConnectedComponentsFinderTest, TwoConnectedNodes) { + 
ConnectedComponentsFinder cc_finder; + auto c1 = MakeComputation("c1"); + auto c2 = MakeComputation("c2"); + cc_finder.AddEdge(c1.get(), c2.get()); + auto components = cc_finder.FindConnectedComponents(); + ASSERT_THAT(components.size(), 1); + EXPECT_THAT(components[0], UnorderedElementsAre(c1.get(), c2.get())); +} + +TEST_F(ConnectedComponentsFinderTest, ThreeNodesLinearConnection) { + ConnectedComponentsFinder cc_finder; + auto c1 = MakeComputation("c1"); + auto c2 = MakeComputation("c2"); + auto c3 = MakeComputation("c3"); + cc_finder.AddEdge(c1.get(), c2.get()); + cc_finder.AddEdge(c2.get(), c3.get()); + auto components = cc_finder.FindConnectedComponents(); + ASSERT_THAT(components.size(), 1); + EXPECT_THAT(components[0], + UnorderedElementsAre(c1.get(), c2.get(), c3.get())); +} + +TEST_F(ConnectedComponentsFinderTest, ThreeNodesTriangleConnection) { + ConnectedComponentsFinder cc_finder; + auto c1 = MakeComputation("c1"); + auto c2 = MakeComputation("c2"); + auto c3 = MakeComputation("c3"); + cc_finder.AddEdge(c1.get(), c2.get()); + cc_finder.AddEdge(c2.get(), c3.get()); + cc_finder.AddEdge(c3.get(), c1.get()); + auto components = cc_finder.FindConnectedComponents(); + ASSERT_THAT(components.size(), 1); + EXPECT_THAT(components[0], + UnorderedElementsAre(c1.get(), c2.get(), c3.get())); +} + +TEST_F(ConnectedComponentsFinderTest, MixedConnectedAndSeparate) { + ConnectedComponentsFinder cc_finder; + auto c1 = MakeComputation("c1"); + auto c2 = MakeComputation("c2"); + auto c3 = MakeComputation("c3"); + auto c4 = MakeComputation("c4"); + auto c5 = MakeComputation("c5"); + cc_finder.AddEdge(c1.get(), c2.get()); + cc_finder.AddEdge(c2.get(), c3.get()); + cc_finder.AddEdge(c4.get(), c4.get()); // c4 is separate + auto components = cc_finder.FindConnectedComponents(); + ASSERT_THAT(components.size(), 2); + EXPECT_THAT( + components, + UnorderedElementsAre(UnorderedElementsAre(c1.get(), c2.get(), c3.get()), + UnorderedElementsAre(c4.get()))); +} + +TEST_F(ConnectedComponentsFinderTest, LargerComponentOk) { + ConnectedComponentsFinder cc_finder; + auto c1 = MakeComputation("c1"); + auto c2 = MakeComputation("c2"); + auto c3 = MakeComputation("c3"); + auto c4 = MakeComputation("c4"); + auto c5 = MakeComputation("c5"); + auto c6 = MakeComputation("c6"); + cc_finder.AddEdge(c1.get(), c2.get()); + cc_finder.AddEdge(c2.get(), c3.get()); + cc_finder.AddEdge(c3.get(), c4.get()); + cc_finder.AddEdge(c4.get(), c1.get()); + cc_finder.AddEdge(c3.get(), c5.get()); + cc_finder.AddEdge(c6.get(), c6.get()); + + auto components = cc_finder.FindConnectedComponents(); + + ASSERT_THAT(components.size(), 2); + EXPECT_THAT(components, UnorderedElementsAre( + UnorderedElementsAre(c1.get(), c2.get(), c3.get(), + c4.get(), c5.get()), + UnorderedElementsAre(c6.get()))); +} + +} // namespace +} // namespace hlo_diff +} // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/utils/hlo_diff_util.h b/third_party/xla/xla/hlo/tools/hlo_diff/utils/hlo_diff_util.h new file mode 100644 index 00000000000000..b1d24e3c313b59 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/utils/hlo_diff_util.h @@ -0,0 +1,42 @@ +/* + * Copyright 2025 The OpenXLA Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_HLO_TOOLS_HLO_DIFF_UTILS_HLO_DIFF_UTIL_H_ +#define XLA_HLO_TOOLS_HLO_DIFF_UTILS_HLO_DIFF_UTIL_H_ + +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_print_options.h" +#include "tsl/platform/fingerprint.h" + +namespace xla::hlo_diff { + +inline uint64_t GetHloInstructionFingerprint( + const HloInstruction* instruction, + const HloPrintOptions& hlo_print_options) { + return tsl::Fingerprint64(instruction->ToString(hlo_print_options)); +} + +inline uint64_t GetHloInstructionFingerprint( + const HloInstruction* instruction) { + return GetHloInstructionFingerprint(instruction, + HloPrintOptions::Fingerprint()); +} + +} // namespace xla::hlo_diff + +#endif // XLA_HLO_TOOLS_HLO_DIFF_UTILS_HLO_DIFF_UTIL_H_ diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/utils/test_util.cc b/third_party/xla/xla/hlo/tools/hlo_diff/utils/test_util.cc new file mode 100644 index 00000000000000..ad771dc2f39a20 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/utils/test_util.cc @@ -0,0 +1,138 @@ +// Copyright 2025 The OpenXLA Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
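+
+// Note: these test-only helpers favor clarity over speed; for example,
+// GetNodeByName() and MatchAllNodesByName() do linear scans over all graph
+// nodes, which is fine for the small graphs used in unit tests.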
+ +#include "xla/hlo/tools/hlo_diff/utils/test_util.h" + +#include + +#include +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h" +#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h" + +namespace xla { +namespace hlo_diff { + +const HloInstructionNode* GetNodeByName(const HloGumgraph& graph, + absl::string_view name) { + for (const auto* node : graph.AllNodes()) { + if (!node->is_root && node->instruction->name() == name) { + return node; + } + } + return nullptr; +} + +void OverwriteMapInstructions(const HloInstructionNode* left, + const HloInstructionNode* right, + HloGumgraphMappings& mappings, + bool position_unchanged) { + ASSERT_NE(left, nullptr); + ASSERT_NE(right, nullptr); + if (auto it = mappings.left_to_right_instruction_map.left.find(left); + it != mappings.left_to_right_instruction_map.left.end()) { + mappings.left_to_right_instruction_map.left.erase(it); + } + + if (auto it = mappings.left_to_right_instruction_map.right.find(right); + it != mappings.left_to_right_instruction_map.right.end()) { + mappings.left_to_right_instruction_map.right.erase(it); + } + + mappings.left_to_right_instruction_map.insert( + InstructionPair(left, right, {.matcher_type = MatcherType::kManual})); + if (position_unchanged) { + mappings.left_to_right_instruction_map.left.find(left)->info.unchanged = + true; + } +} + +void MatchAllNodesByName(const HloGumgraph& left, const HloGumgraph& right, + HloGumgraphMappings& mappings) { + for (const auto* left_node : left.AllNodes()) { + if (left_node->is_root) { + continue; + } + const HloInstructionNode* right_node = nullptr; + for (const auto* node : right.AllNodes()) { + if (!node->is_root && + node->instruction->name() == left_node->instruction->name()) { + right_node = node; + break; + } + } + if (right_node != nullptr) { + mappings.MapInstructionsIfAbsent(left_node, right_node, + MatcherType::kManual); + } + } +} + +absl::flat_hash_map ExtractMappedInstructionNames( + const HloGumgraphMappings& mappings) { + absl::flat_hash_map mapped_nodes; + for (auto it = mappings.left_to_right_instruction_map.begin(); + it != mappings.left_to_right_instruction_map.end(); ++it) { + absl::string_view left_name = + it->left->is_root ? "root_L" : it->left->instruction->name(); + absl::string_view right_name = + it->right->is_root ? 
"root_R" : it->right->instruction->name(); + mapped_nodes[left_name] = right_name; + } + + return mapped_nodes; +} + +absl::flat_hash_map ExtractMappedComputationNames( + const HloGumgraphMappings& mappings) { + absl::flat_hash_map mapped_computations; + for (auto it = mappings.left_to_right_computation_map.left.begin(); + it != mappings.left_to_right_computation_map.left.end(); ++it) { + mapped_computations[it->first->computation()->name()] = + it->second->computation()->name(); + } + return mapped_computations; +} + +absl::flat_hash_map +ExtractComputationMatchType(const HloGumgraphMappings& mappings) { + absl::flat_hash_map computation_match_type; + for (auto it = mappings.left_to_right_computation_map.left.begin(); + it != mappings.left_to_right_computation_map.left.end(); ++it) { + computation_match_type[it->first->computation()->name()] = + it->info.computation_match_type; + } + return computation_match_type; +} + +absl::StatusOr GetInstructionByName(HloModule& module, + absl::string_view name) { + for (HloComputation* computation : module.computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == name) { + return instruction; + } + } + } + return absl::InvalidArgumentError("instruction not found"); +} + +} // namespace hlo_diff +} // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/utils/test_util.h b/third_party/xla/xla/hlo/tools/hlo_diff/utils/test_util.h new file mode 100644 index 00000000000000..411bf7433c0737 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/utils/test_util.h @@ -0,0 +1,68 @@ +/* + * Copyright 2025 The OpenXLA Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_HLO_TOOLS_HLO_DIFF_UTILS_TEST_UTIL_H_ +#define XLA_HLO_TOOLS_HLO_DIFF_UTILS_TEST_UTIL_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" +#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h" +#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h" + +namespace xla { +namespace hlo_diff { + +// Returns the node with the given name. +// Returns nullptr if the node is not found. +const HloInstructionNode* GetNodeByName(const HloGumgraph& graph, + absl::string_view name); + +// Map Nodes, overwriting existing mappings if they are different. +void OverwriteMapInstructions(const HloInstructionNode* left, + const HloInstructionNode* right, + HloGumgraphMappings& mappings, + bool position_unchanged = false); + +// Matches all node pairs with the same name. +void MatchAllNodesByName(const HloGumgraph& left, const HloGumgraph& right, + HloGumgraphMappings& mappings); + +// Extracts the mapped instruction names from the HloGumgraphMappings. 
+absl::flat_hash_map ExtractMappedInstructionNames(
+    const HloGumgraphMappings& mappings);
+
+// Extracts the mapped computation names from the HloGumgraphMappings.
+absl::flat_hash_map ExtractMappedComputationNames(
+    const HloGumgraphMappings& mappings);
+
+// Extracts the computation match type from the HloGumgraphMappings.
+absl::flat_hash_map
+ExtractComputationMatchType(const HloGumgraphMappings& mappings);
+
+// Returns the instruction with the given name.
+absl::StatusOr GetInstructionByName(HloModule& module,
+                                    absl::string_view name);
+
+}  // namespace hlo_diff
+}  // namespace xla
+
+#endif  // XLA_HLO_TOOLS_HLO_DIFF_UTILS_TEST_UTIL_H_

From fa0b9d7cc8c773466f58859cdfb34da5280fb2ce Mon Sep 17 00:00:00 2001
From: Blake Hechtman
Date: Fri, 4 Apr 2025 17:11:54 -0700
Subject: [PATCH 0261/1324] [XLA:LAYOUT_ASSIGNMENT] Add a way to reset the
 entry computation to the initial saved value.

PiperOrigin-RevId: 744113929
---
 .../xla/xla/service/layout_assignment.cc | 120 +++++++++---------
 .../xla/xla/service/layout_assignment.h  |   7 +
 2 files changed, 68 insertions(+), 59 deletions(-)

diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc
index 1fb7e247e92661..b7bea1e55704c5 100644
--- a/third_party/xla/xla/service/layout_assignment.cc
+++ b/third_party/xla/xla/service/layout_assignment.cc
@@ -710,7 +710,7 @@ absl::Status LayoutAssignment::AddMandatoryConstraints(
       ShapeLayout parameter_layout =
           constraints->computation_layout().parameter_layout(
               instruction->parameter_number());
-      // Allow some paramter/result layouts to be unset in the entry
+      // Allow some parameter/result layouts to be unset in the entry
       // computation.
       if (parameter_layout.AnyLayoutIsSet()) {
         // Clear out memory space in layout. Host offloader will do the
@@ -1386,7 +1386,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
   }

   if (instruction->opcode() == HloOpcode::kReshape) {
-    // Prefer the operand layout that makes the reshape an bitcast. If any
+    // Prefer the operand layout that makes the reshape a bitcast. If any
     // dimension bound is 1 in the operand shape, there may be several such
     // layouts. So if 'output_layout' is the default layout, try if the
     // reshape is a bitcast when using the same layout. This may avoid copy
@@ -2707,63 +2707,9 @@ absl::StatusOr LayoutAssignment::Run(
           entry_computation_layout_->AnyLayoutSet() ?
LayoutConstraint::kGivenPriority : LayoutConstraint::kDefaultPriority)); - for (int64_t i = 0; i < kNumberOfPropagationRounds; ++i) { - if (i > 0) { - LayoutConstraints* constraints = - mutable_computation_constraints(module->entry_computation()); - - bool changed = false; - module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, - const HloInputOutputAliasConfig::Alias& alias) { - const auto param = alias.parameter_number; - const auto& index = alias.parameter_index; - bool param_is_forced = - ShapeUtil::GetSubshape( - saved_entry_computation_layout_.parameter_shape(param), - index) - .has_layout(); - bool result_is_forced = - ShapeUtil::GetSubshape( - saved_entry_computation_layout_.result_shape(), - output_index) - .has_layout(); - Shape* param_shape = - ShapeUtil::GetMutableSubshape(module->entry_computation() - ->parameter_instruction(param) - ->mutable_shape(), - index); - Shape* result_shape = - ShapeUtil::GetMutableSubshape(module->entry_computation() - ->root_instruction() - ->mutable_shape(), - output_index); - if (param_is_forced && result_is_forced) { - return; - } - - if (param_shape->layout().minor_to_major() == - result_shape->layout().minor_to_major()) { - return; - } - changed = true; - if (!param_is_forced) { - *param_shape = *result_shape; - return; - } - *result_shape = *param_shape; - }); - if (changed) { - auto computed_program_shape = - module->entry_computation()->ComputeProgramShape(); - constraints->mutable_computation_constraint()->ResetComputationLayout( - ComputationLayout{ - module->entry_computation()->ComputeProgramShape(), false}, - LayoutConstraint::kGivenPriority, true, true); - *entry_computation_layout_ = - constraints->computation_constraint().computation_layout(); - } - } + bool changed = true; + for (int64_t i = 0; changed || i < kNumberOfPropagationRounds; ++i) { + changed = false; VLOG(1) << "Running " << (i == 0 ? 
"un" : "") << "constrained pass"; TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module, execution_threads)); for (auto* computation : computations_to_work) { @@ -2773,6 +2719,62 @@ absl::StatusOr LayoutAssignment::Run( RunOnComputation(constraints, channel_layout_constraints_)); } current_priority_ += 1; + auto* entry_constraint = + mutable_computation_constraints(module->entry_computation()) + ->mutable_computation_constraint() + ->mutable_computation_layout(); + TF_RETURN_IF_ERROR( + module->input_output_alias_config().ForEachAliasWithStatus( + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + const auto param = alias.parameter_number; + const auto& index = alias.parameter_index; + bool param_is_forced = + ShapeUtil::GetSubshape( + saved_entry_computation_layout_.parameter_shape(param), + index) + .has_layout(); + bool result_is_forced = + ShapeUtil::GetSubshape( + saved_entry_computation_layout_.result_shape(), + output_index) + .has_layout(); + if (param_is_forced && result_is_forced) { + return absl::OkStatus(); + } + auto* entry = module->entry_computation(); + TF_ASSIGN_OR_RETURN( + auto param_layout, + InferArrayLayout(entry->parameter_instruction(param), index)); + TF_ASSIGN_OR_RETURN( + auto result_layout, + InferArrayLayout(entry->root_instruction(), output_index)); + if (param_layout.minor_to_major() == + result_layout.minor_to_major()) { + return absl::OkStatus(); + } + changed = true; + if (!param_is_forced) { + entry_computation_layout_->mutable_parameter_layout(param) + ->ResetLayout(result_layout, index); + entry_computation_layout_->mutable_result_layout()->ResetLayout( + result_layout, output_index); + entry_constraint->mutable_parameter_layout(param)->ResetLayout( + result_layout, index); + entry_constraint->mutable_result_layout()->ResetLayout( + result_layout, output_index); + return absl::OkStatus(); + } + entry_computation_layout_->mutable_parameter_layout(param) + ->ResetLayout(param_layout, index); + entry_computation_layout_->mutable_result_layout()->ResetLayout( + param_layout, output_index); + entry_constraint->mutable_parameter_layout(param)->ResetLayout( + param_layout, index); + entry_constraint->mutable_result_layout()->ResetLayout( + param_layout, output_index); + return absl::OkStatus(); + })); } for (auto* computation : computations_to_work) { diff --git a/third_party/xla/xla/service/layout_assignment.h b/third_party/xla/xla/service/layout_assignment.h index b925821311e1d0..5c9a5a576906ec 100644 --- a/third_party/xla/xla/service/layout_assignment.h +++ b/third_party/xla/xla/service/layout_assignment.h @@ -171,6 +171,9 @@ class ComputationLayoutConstraint : public LayoutConstraint { const ComputationLayout& computation_layout() const { return computation_layout_; } + ComputationLayout* mutable_computation_layout() { + return &computation_layout_; + } void ResetComputationLayout(const ComputationLayout& layout, int64_t priority, bool prop_result_layout, bool prop_parameter_layout) { @@ -704,6 +707,10 @@ class LayoutAssignment : public HloModulePass { } } + void ResetEntryComputationLayout() { + *entry_computation_layout_ = saved_entry_computation_layout_; + } + // Adds constraints related to host Send/Recv instructions. absl::Status BuildHostChannelConstraints(HloComputation* computation); From ba268154571bffb32028cc9f1185241210455bd3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Apr 2025 17:24:33 -0700 Subject: [PATCH 0262/1324] Temporarily remove input pipeline analyzer. 
PiperOrigin-RevId: 744116422 --- tensorflow/core/profiler/convert/xplane_to_tool_names.cc | 4 +++- tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc b/tensorflow/core/profiler/convert/xplane_to_tool_names.cc index 9fb564899e5641..df3ccc129e5922 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tool_names.cc @@ -39,7 +39,9 @@ absl::StatusOr GetAvailableToolNames( tools.reserve(11); tools.push_back(is_cloud_vertex_ai ? "trace_viewer" : "trace_viewer@"); tools.push_back("overview_page"); - tools.push_back("input_pipeline_analyzer"); + // TODO(jonahweaver): Re-enable input_pipeline_analyzer when it is ready. + // b/407096031 + // tools.push_back("input_pipeline_analyzer"); tools.push_back("framework_op_stats"); tools.push_back("memory_profile"); tools.push_back("pod_viewer"); diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc b/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc index 83fa3111374622..d10e81519e563a 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc @@ -119,7 +119,9 @@ TEST_P(XPlaneToToolsTest, ToolsList) { std::vector expected_tools = { "trace_viewer", "overview_page", - "input_pipeline_analyzer", + // TODO(jonahweaver): Re-enable input_pipeline_analyzer when it is ready. + // b/407096031 + // "input_pipeline_analyzer", "framework_op_stats", "memory_profile", "pod_viewer", From f781546fbcf885fc4a6abf9b7b76b3371eece050 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Apr 2025 20:08:26 -0700 Subject: [PATCH 0263/1324] Improve `safe_reinterpret_cast` to cover more cases, and fix some uses of `reinterpret_cast`. 
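
A minimal sketch of the casts the improved helper accepts, mirroring the
updated tests (tsl::safe_reinterpret_cast is the API under change here):

    const int x = 42;
    // Pointer <-> byte-like pointer (char, unsigned char, std::byte):
    const std::byte* p = tsl::safe_reinterpret_cast<const std::byte*>(&x);
    // Pointer <-> std::uintptr_t / std::intptr_t:
    const std::intptr_t bits = tsl::safe_reinterpret_cast<std::intptr_t>(&x);
    // Casts between unrelated pointee types (e.g. int* to float*) do not
    // compile, unlike a raw reinterpret_cast.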
PiperOrigin-RevId: 744145231 --- third_party/xla/xla/BUILD | 4 + third_party/xla/xla/literal.cc | 118 ++++++++++-------- third_party/xla/xla/literal.h | 9 +- third_party/xla/xla/literal_test.cc | 48 ++++--- third_party/xla/xla/maybe_owning.h | 15 ++- .../xla/xla/tsl/util/safe_reinterpret_cast.h | 60 +++++---- .../tsl/util/safe_reinterpret_cast_test.cc | 39 ++++-- third_party/xla/xla/util.h | 6 +- third_party/xla/xla/util_test.cc | 11 ++ 9 files changed, 198 insertions(+), 112 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 3bfda5c5b5a1e5..61194f0a629e72 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -326,6 +326,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/util:safe_reinterpret_cast", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -628,6 +629,7 @@ cc_library( "//xla/tsl/platform:status", "//xla/tsl/platform:statusor", "//xla/tsl/util:byte_swap_array", + "//xla/tsl/util:safe_reinterpret_cast", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -666,6 +668,7 @@ xla_cc_test( "//xla/tsl/platform:statusor", "//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_main", + "//xla/tsl/util:safe_reinterpret_cast", "@com_google_absl//absl/base", "@com_google_absl//absl/hash", "@com_google_absl//absl/random", @@ -1232,6 +1235,7 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ "//xla/backends/cpu:alignment", + "//xla/tsl/util:safe_reinterpret_cast", "@com_google_absl//absl/base:dynamic_annotations", ], ) diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index 57dfd6c94e0d51..20803b2dec197a 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -56,6 +56,7 @@ limitations under the License. 
#include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/byte_swap_array.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -2269,24 +2270,27 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { break; case U1: *proto->mutable_u1s() = std::string( - reinterpret_cast(data().data()), size_bytes_dense()); + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); break; case U2: *proto->mutable_u2s() = std::string( - reinterpret_cast(data().data()), size_bytes_dense()); + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); break; case U4: *proto->mutable_u4s() = std::string( - reinterpret_cast(data().data()), size_bytes_dense()); + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); break; case U8: proto->set_u8s(static_cast(data().data()), element_count()); break; case U16: - *proto->mutable_u16s() = - std::string(reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_u16s() = std::string( + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); if (!kLittleEndian) { ConvertEndianShort(proto->mutable_u16s()); } @@ -2299,24 +2303,27 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { break; case S1: *proto->mutable_s1s() = std::string( - reinterpret_cast(data().data()), size_bytes_dense()); + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); break; case S2: *proto->mutable_s2s() = std::string( - reinterpret_cast(data().data()), size_bytes_dense()); + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); break; case S4: *proto->mutable_s4s() = std::string( - reinterpret_cast(data().data()), size_bytes_dense()); + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); break; case S8: proto->set_s8s(static_cast(data().data()), element_count()); break; case S16: - *proto->mutable_s16s() = - std::string(reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_s16s() = std::string( + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); if (!kLittleEndian) { ConvertEndianShort(proto->mutable_s16s()); } @@ -2328,62 +2335,71 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { CopyToRepeatedField(proto->mutable_s64s(), data()); break; case F4E2M1FN: - *proto->mutable_f4e2m1fns() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f4e2m1fns() = + std::string(tsl::safe_reinterpret_cast( + data().data()), + size_bytes_dense()); break; case F8E5M2: - *proto->mutable_f8e5m2s() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f8e5m2s() = + std::string(tsl::safe_reinterpret_cast( + data().data()), + size_bytes_dense()); break; case F8E4M3: - *proto->mutable_f8e4m3s() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f8e4m3s() = + std::string(tsl::safe_reinterpret_cast( + data().data()), + size_bytes_dense()); break; case F8E4M3FN: - *proto->mutable_f8e4m3fns() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f8e4m3fns() = + std::string(tsl::safe_reinterpret_cast( + data().data()), + size_bytes_dense()); break; case F8E4M3B11FNUZ: - *proto->mutable_f8e4m3b11fnuzs() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f8e4m3b11fnuzs() = + std::string(tsl::safe_reinterpret_cast( + data().data()), + 
size_bytes_dense()); break; case F8E5M2FNUZ: - *proto->mutable_f8e5m2fnuzs() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f8e5m2fnuzs() = + std::string(tsl::safe_reinterpret_cast( + data().data()), + size_bytes_dense()); break; case F8E4M3FNUZ: - *proto->mutable_f8e4m3fnuzs() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f8e4m3fnuzs() = + std::string(tsl::safe_reinterpret_cast( + data().data()), + size_bytes_dense()); break; case F8E3M4: - *proto->mutable_f8e3m4s() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f8e3m4s() = + std::string(tsl::safe_reinterpret_cast( + data().data()), + size_bytes_dense()); break; case F8E8M0FNU: - *proto->mutable_f8e8m0fnus() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f8e8m0fnus() = + std::string(tsl::safe_reinterpret_cast( + data().data()), + size_bytes_dense()); break; case F16: - *proto->mutable_f16s() = - std::string(reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_f16s() = std::string( + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); if (!kLittleEndian) { ConvertEndianShort(proto->mutable_f16s()); } break; case BF16: - *proto->mutable_bf16s() = - std::string(reinterpret_cast(data().data()), - size_bytes_dense()); + *proto->mutable_bf16s() = std::string( + tsl::safe_reinterpret_cast(data().data()), + size_bytes_dense()); if (!kLittleEndian) { ConvertEndianShort(proto->mutable_bf16s()); } @@ -2480,7 +2496,8 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { TF_RET_CHECK(data().size() * sizeof(int16_t) == s.size()); memcpy(untyped_data(), s.data(), s.size()); if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + ConvertEndianShort(tsl::safe_reinterpret_cast(untyped_data()), + s.size()); } break; } @@ -2513,7 +2530,8 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { TF_RET_CHECK(data().size() * sizeof(uint16_t) == s.size()); memcpy(untyped_data(), s.data(), s.size()); if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + ConvertEndianShort(tsl::safe_reinterpret_cast(untyped_data()), + s.size()); } break; } @@ -2597,7 +2615,8 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { TF_RET_CHECK(data().size() * sizeof(half) == s.size()); memcpy(untyped_data(), s.data(), s.size()); if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + ConvertEndianShort(tsl::safe_reinterpret_cast(untyped_data()), + s.size()); } break; } @@ -2606,7 +2625,8 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { TF_RET_CHECK(data().size() * sizeof(bfloat16) == s.size()); memcpy(untyped_data(), s.data(), s.size()); if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + ConvertEndianShort(tsl::safe_reinterpret_cast(untyped_data()), + s.size()); } break; } diff --git a/third_party/xla/xla/literal.h b/third_party/xla/xla/literal.h index 1690b36b69d994..0b7110b3331c06 100644 --- a/third_party/xla/xla/literal.h +++ b/third_party/xla/xla/literal.h @@ -58,6 +58,7 @@ limitations under the License. 
#include "xla/tsl/platform/logging.h" // IWYU pragma: keep #include "xla/tsl/platform/macros.h" #include "xla/tsl/platform/statusor.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -882,7 +883,7 @@ class LiteralBase { // Gets/sets the buffer holding dynamic sizes. const DynamicSizeType* dynamic_size_buffer() const { DCHECK(LayoutUtil::IsDenseArray(*subshape_)); - return reinterpret_cast( + return tsl::safe_reinterpret_cast( buffer() + dynamic_size_buffer_offset()); } DynamicSizeType* dynamic_size_buffer() { @@ -1823,8 +1824,8 @@ absl::Span LiteralBase::Piece::data() const { << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return absl::Span(reinterpret_cast(buffer()), - element_count()); + return absl::Span( + tsl::safe_reinterpret_cast(buffer()), element_count()); } template @@ -1841,7 +1842,7 @@ absl::Span LiteralBase::Piece::data() { << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return absl::Span(reinterpret_cast(buffer()), + return absl::Span(tsl::safe_reinterpret_cast(buffer()), element_count()); } diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index e7e49d1b7680ff..19b0eeb0b1949b 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -53,6 +53,7 @@ limitations under the License. #include "xla/tsl/platform/macros.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -1485,7 +1486,8 @@ TEST_F(LiteralUtilTest, F16) { // are in little endian format // TODO - modify if we make the data format machine endianness dependent Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); - const char* d1 = reinterpret_cast(m1.data().data()); + const char* const d1 = + tsl::safe_reinterpret_cast(m1.data().data()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -1498,12 +1500,16 @@ TEST_F(LiteralUtilTest, F16) { half h1(1.0f); half h2(2.0f); auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - const uint16_t* d2 = - reinterpret_cast(m2.data().data()); - EXPECT_EQ(d2[0], 0x3C00); - EXPECT_EQ(d2[1], 0x4000); - EXPECT_EQ(d2[2], 0x4000); - EXPECT_EQ(d2[3], 0x3C00); + const char* const d2 = + tsl::safe_reinterpret_cast(m2.data().data()); + EXPECT_EQ(d2[0], 0x00); + EXPECT_EQ(d2[1], 0x3C); + EXPECT_EQ(d2[2], 0x00); + EXPECT_EQ(d2[3], 0x40); + EXPECT_EQ(d2[4], 0x00); + EXPECT_EQ(d2[5], 0x40); + EXPECT_EQ(d2[6], 0x00); + EXPECT_EQ(d2[7], 0x3C); } TEST_F(LiteralUtilTest, Populate) { @@ -2073,8 +2079,9 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { std::vector int64_values = {1, 2, 3}; const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); - BorrowingLiteral literal(reinterpret_cast(int64_values.data()), - literal_shape); + BorrowingLiteral literal( + tsl::safe_reinterpret_cast(int64_values.data()), + literal_shape); EXPECT_EQ(literal.Get({0}), 1); EXPECT_EQ(literal.Get({1}), 2); @@ -2090,8 +2097,9 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { std::vector src_buf_ptrs; src_buf_ptrs.emplace_back( - reinterpret_cast(one_two_three.data())); - src_buf_ptrs.emplace_back(reinterpret_cast(hundred.data())); + 
tsl::safe_reinterpret_cast(one_two_three.data())); + src_buf_ptrs.emplace_back( + tsl::safe_reinterpret_cast(hundred.data())); auto literal_tuple = BorrowingLiteral( src_buf_ptrs, ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape})); @@ -2117,9 +2125,12 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromShapeTree) { Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, shape}); ShapeTree ptr_tree(nested_tuple); - *ptr_tree.mutable_element({0, 0}) = reinterpret_cast(data.data()); - *ptr_tree.mutable_element({0, 1}) = reinterpret_cast(data.data()); - *ptr_tree.mutable_element({1}) = reinterpret_cast(data.data()); + *ptr_tree.mutable_element({0, 0}) = + tsl::safe_reinterpret_cast(data.data()); + *ptr_tree.mutable_element({0, 1}) = + tsl::safe_reinterpret_cast(data.data()); + *ptr_tree.mutable_element({1}) = + tsl::safe_reinterpret_cast(data.data()); BorrowingLiteral literal(ptr_tree); @@ -2136,9 +2147,12 @@ TEST_F(LiteralUtilTest, MutableBorrowingLiteralFromShapeTree) { Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, shape}); ShapeTree ptr_tree(nested_tuple); - *ptr_tree.mutable_element({0, 0}) = reinterpret_cast(data.data()); - *ptr_tree.mutable_element({0, 1}) = reinterpret_cast(data.data()); - *ptr_tree.mutable_element({1}) = reinterpret_cast(data.data()); + *ptr_tree.mutable_element({0, 0}) = + tsl::safe_reinterpret_cast(data.data()); + *ptr_tree.mutable_element({0, 1}) = + tsl::safe_reinterpret_cast(data.data()); + *ptr_tree.mutable_element({1}) = + tsl::safe_reinterpret_cast(data.data()); MutableBorrowingLiteral literal(ptr_tree); diff --git a/third_party/xla/xla/maybe_owning.h b/third_party/xla/xla/maybe_owning.h index 4f32472ecb2f95..2b63a45543375d 100644 --- a/third_party/xla/xla/maybe_owning.h +++ b/third_party/xla/xla/maybe_owning.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "xla/tsl/util/safe_reinterpret_cast.h" + // A unique_ptr like class which may or may not have ownership of its pointer. // Uses least significant bit of the pointer to indicate ownership. template @@ -81,18 +83,21 @@ class MaybeOwning final { }; T* RemoveMask() const { - return reinterpret_cast(ptr_and_owning_bit_ & kPointerMask); + return tsl::safe_reinterpret_cast( + static_cast(ptr_and_owning_bit_ & kPointerMask)); } static intptr_t TakeUnique(std::unique_ptr unique) { T* released = unique.release(); - DCHECK_EQ(reinterpret_cast(released) & kOwningBitMask, 0); - return reinterpret_cast(released) | kOwningBitMask; + DCHECK_EQ(tsl::safe_reinterpret_cast(released) & kOwningBitMask, + 0); + return tsl::safe_reinterpret_cast(released) | kOwningBitMask; } static intptr_t Borrow(const T* borrowed) { - DCHECK_EQ(reinterpret_cast(borrowed) & kOwningBitMask, 0); - return reinterpret_cast(borrowed); + DCHECK_EQ(tsl::safe_reinterpret_cast(borrowed) & kOwningBitMask, + 0); + return tsl::safe_reinterpret_cast(borrowed); } void MaybeDeleteOwned() { diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h index 636bcf0a42ea9d..82aebaf8ecdc46 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h @@ -37,6 +37,28 @@ namespace tsl { namespace internal { +// IsByteLike::value is true if T is a byte-like type (char, unsigned char, +// or std::byte). 
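+//
+// Illustrative expectations: IsByteLike<unsigned char>::value and
+// IsByteLike<std::byte>::value are true; IsByteLike<int>::value is false.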
+template +struct IsByteLike : std::false_type {}; +template <> +struct IsByteLike : std::true_type {}; +template <> +struct IsByteLike : std::true_type {}; +template <> +struct IsByteLike : std::true_type {}; + +// IsCvByteLike::value is true if T is a possibly CV-qualified byte-like type +// (char, unsigned char, or std::byte). +template +struct IsCvByteLike : IsByteLike {}; +template +struct IsCvByteLike : IsByteLike {}; +template +struct IsCvByteLike : IsByteLike {}; +template +struct IsCvByteLike : IsByteLike {}; + // IsSafeCast::value is true if it is safe to reinterpret_cast a // value of type From to a value of type To. // @@ -49,33 +71,11 @@ struct IsSafeCast : std::false_type {}; template struct IsSafeCast : std::true_type {}; -// It's safe to cast a pointer to any character pointer. -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; - -// It's safe to cast a character pointer to a pointer to any type. -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; -template -struct IsSafeCast : std::true_type {}; +// It's safe to cast a pointer to/from a byte-like type. +template +struct IsSafeCast + : std::integral_constant::value || + IsCvByteLike::value> {}; // It's safe to cast a pointer to/from std::uintptr_t. template @@ -83,6 +83,12 @@ struct IsSafeCast : std::true_type {}; template struct IsSafeCast : std::true_type {}; +// It's safe to cast a pointer to/from std::intptr_t. +template +struct IsSafeCast : std::true_type {}; +template +struct IsSafeCast : std::true_type {}; + } // namespace internal // Like reinterpret_cast, but compiles only if it's safe. diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc index 648a3cea2f59c3..e804e64906393e 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc @@ -33,11 +33,10 @@ TEST(SafeReinterpretCast, CanCastPointerToFromConstCharPointer) { TEST(SafeReinterpretCast, CanCastPointerToFromConstBytePointer) { const int x = 42; - const ::std::byte* const char_p = - safe_reinterpret_cast(&x); + const std::byte* const char_p = safe_reinterpret_cast(&x); EXPECT_EQ( - char_p, // - reinterpret_cast( // REINTERPRET_CAST_OK=for testing. + char_p, // + reinterpret_cast( // REINTERPRET_CAST_OK=for testing. &x)); const int* const int_p = safe_reinterpret_cast(char_p); @@ -67,16 +66,40 @@ TEST(SafeReinterpretCast, CanCastPointerToFromMutableCharPointer) { EXPECT_EQ(int_p, &x); } +TEST(SafeReinterpretCast, CanCastBetweenByteLikePointers) { + char x = 'A'; + std::byte* const byte_p = safe_reinterpret_cast(&x); + EXPECT_EQ(byte_p, // + reinterpret_cast( // REINTERPRET_CAST_OK=for testing. + &x)); + + unsigned char* const unsigned_char_p = + safe_reinterpret_cast(&x); + EXPECT_EQ(unsigned_char_p, // + reinterpret_cast( // REINTERPRET_CAST_OK=for + // testing. 
+ &x)); +} + TEST(SafeReinterpretCast, CanCastPointerToFromStdUintptrT) { const int x = 42; - const ::std::uintptr_t uintptr_t_p = - safe_reinterpret_cast<::std::uintptr_t>(&x); + const std::uintptr_t uintptr_t_p = safe_reinterpret_cast(&x); EXPECT_EQ( - uintptr_t_p, // - reinterpret_cast<::std::uintptr_t>( // REINTERPRET_CAST_OK=for testing. + uintptr_t_p, // + reinterpret_cast( // REINTERPRET_CAST_OK=for testing. &x)); EXPECT_EQ(safe_reinterpret_cast(uintptr_t_p), &x); } +TEST(SafeReinterpretCast, CanCastPointerToFromStdIntptrT) { + const int x = 42; + const std::intptr_t intptr_t_p = safe_reinterpret_cast(&x); + EXPECT_EQ( + intptr_t_p, // + reinterpret_cast( // REINTERPRET_CAST_OK=for testing. + &x)); + EXPECT_EQ(safe_reinterpret_cast(intptr_t_p), &x); +} + } // namespace } // namespace tsl diff --git a/third_party/xla/xla/util.h b/third_party/xla/xla/util.h index 2a5f11b497b007..51b1e8f1c0bc78 100644 --- a/third_party/xla/xla/util.h +++ b/third_party/xla/xla/util.h @@ -50,6 +50,7 @@ limitations under the License. #include "xla/tsl/lib/math/math_util.h" #include "xla/tsl/platform/errors.h" // IWYU pragma: keep #include "xla/tsl/platform/logging.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/bfloat16.h" @@ -160,7 +161,8 @@ class ScopedLoggingTimer { template absl::Span CastToByteSlice(absl::Span slice) { return absl::Span( - reinterpret_cast(slice.data()), slice.size() * sizeof(T)); + tsl::safe_reinterpret_cast(slice.data()), + slice.size() * sizeof(T)); } // Casts a byte slice to a non-byte type T, checking that the original slice @@ -168,7 +170,7 @@ absl::Span CastToByteSlice(absl::Span slice) { template absl::Span CastByteSlice(absl::Span slice) { CHECK_EQ(0, slice.size() % sizeof(T)); - return absl::Span(reinterpret_cast(slice.data()), + return absl::Span(tsl::safe_reinterpret_cast(slice.data()), slice.size() / sizeof(T)); } diff --git a/third_party/xla/xla/util_test.cc b/third_party/xla/xla/util_test.cc index a192407eac118b..10d3ee6f2e98dc 100644 --- a/third_party/xla/xla/util_test.cc +++ b/third_party/xla/xla/util_test.cc @@ -411,6 +411,17 @@ TEST(UtilTest, MaybeOwningTestShared) { EXPECT_EQ(c1.get(), c2.get()); } +TEST(UtilTest, MaybeOwningTestSharedNoCharType) { + auto owner = std::make_unique(); + *owner = 42; + MaybeOwning i1(owner.get()); + MaybeOwning i2(owner.get()); + + EXPECT_EQ(*i1, 42); + EXPECT_EQ(*i2, 42); + EXPECT_EQ(i1.get(), i2.get()); +} + TEST(UtilTest, PrintAllFields) { // Here we are using one of the bool fields that has the default value to // false and ensuring that it is always printed. From 0359ec6c0b614c06176530893b7b03c2285b4b99 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 4 Apr 2025 20:57:38 -0700 Subject: [PATCH 0264/1324] Automated Code Change PiperOrigin-RevId: 744157085 --- .../xla/hlo/ir/dynamic_parameter_binding.cc | 3 +- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 38 ++++++++++--------- .../xla/xla/hlo/ir/hlo_instructions.cc | 8 ++-- third_party/xla/xla/hlo/ir/hlo_sharding.cc | 10 ++--- 4 files changed, 31 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.cc b/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.cc index 2b794b204a13c0..9803f33a37b116 100644 --- a/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.cc +++ b/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.cc @@ -98,7 +98,8 @@ absl::Status DynamicParameterBinding::Verify( computation.parameter_instruction(dynamic_dimension.parameter_num) ->shape(), dynamic_dimension.parameter_index) - .dimensions_size()); + .dimensions() + .size()); return absl::OkStatus(); }); } diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 3fafafa4ced082..550a474059b265 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -865,9 +865,10 @@ absl::StatusOr> HloInstruction::CreateFromProto( slice_sizes.resize(input->shape().tuple_shapes_size()); for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) { slice_sizes[i].resize( - input->shape().tuple_shapes(i).dimensions_size()); + input->shape().tuple_shapes(i).dimensions().size()); for (int j = 0; - j < input->shape().tuple_shapes(i).dimensions_size(); ++j) { + j < input->shape().tuple_shapes(i).dimensions().size(); + ++j) { CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index); slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index); proto_index += 1; @@ -885,9 +886,9 @@ absl::StatusOr> HloInstruction::CreateFromProto( input_start_indices->shape().tuple_shapes(i)); ++j) { slice_sizes[slice_sizes_count].resize( - input->shape().tuple_shapes(i).dimensions_size()); + input->shape().tuple_shapes(i).dimensions().size()); for (int k = 0; - k < input->shape().tuple_shapes(i).dimensions_size(); + k < input->shape().tuple_shapes(i).dimensions().size(); ++k) { CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index); slice_sizes[slice_sizes_count][k] = @@ -905,16 +906,16 @@ absl::StatusOr> HloInstruction::CreateFromProto( for (int i = 0; i < ShapeUtil::TupleElementCount(input_start_indices->shape()); ++i) { - slice_sizes[i].resize(input->shape().dimensions_size()); - for (int j = 0; j < input->shape().dimensions_size(); ++j) { + slice_sizes[i].resize(input->shape().dimensions().size()); + for (int j = 0; j < input->shape().dimensions().size(); ++j) { slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index); proto_index += 1; } } } else { slice_sizes.resize(1); - slice_sizes[0].resize(input->shape().dimensions_size()); - for (int j = 0; j < input->shape().dimensions_size(); ++j) { + slice_sizes[0].resize(input->shape().dimensions().size()); + for (int j = 0; j < input->shape().dimensions().size(); ++j) { slice_sizes[0][j] = proto.dynamic_slice_sizes(proto_index); proto_index += 1; } @@ -1065,8 +1066,8 @@ absl::StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); // TODO(b/118437727): Old form, make the check unconditional. 
if (proto.operand_ids_size() != 2 || - operands(1)->shape().dimensions_size() != 1) { - auto expected_operands = 1 + operands(0)->shape().dimensions_size(); + operands(1)->shape().dimensions().size() != 1) { + auto expected_operands = 1 + operands(0)->shape().dimensions().size(); TF_RET_CHECK(proto.operand_ids_size() == expected_operands) << "DynamicSlice instruction should have " << expected_operands << " operands, but has " << proto.operand_ids_size(); @@ -1084,8 +1085,8 @@ absl::StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); // TODO(b/118437727): Old form, make the check unconditional. if (proto.operand_ids_size() != 3 || - operands(2)->shape().dimensions_size() != 1) { - auto expected_operands = 2 + operands(0)->shape().dimensions_size(); + operands(2)->shape().dimensions().size() != 1) { + auto expected_operands = 2 + operands(0)->shape().dimensions().size(); TF_RET_CHECK(proto.operand_ids_size() == expected_operands) << "DynamicUpdateSlice instruction should have " << expected_operands << " operands, but has " @@ -2220,7 +2221,8 @@ HloInstruction::CreateBroadcastSequence( const Shape& output_shape, HloInstruction* operand, absl::FunctionRef)> adder) { CHECK(ShapeUtil::IsScalar(operand->shape()) || - operand->shape().dimensions_size() == output_shape.dimensions_size()); + operand->shape().dimensions().size() == + output_shape.dimensions().size()); Shape broadcast_shape = ShapeUtil::ChangeElementType( output_shape, operand->shape().element_type()); // Do explicit broadcast for scalar. @@ -2238,7 +2240,7 @@ HloInstruction::CreateBroadcastSequence( // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; std::vector reshaped_dimensions; - for (int i = 0; i < operand->shape().dimensions_size(); i++) { + for (int i = 0; i < operand->shape().dimensions().size(); i++) { if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand->shape().dimensions(i)); @@ -2298,7 +2300,7 @@ HloInstruction::CreateDynamicReshape( ShapeUtil::StaticExtentProduct(data_operand[0].shape())) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(data_operand[0].shape()); - CHECK_EQ(shape.dimensions_size(), dim_sizes.size()); + CHECK_EQ(shape.dimensions().size(), dim_sizes.size()); return std::make_unique(shape, data_operand, dim_sizes); } @@ -4937,11 +4939,11 @@ static UseKind OperandElementUse(const HloInstruction& instr, *instr.fused_expression_root()); case HloOpcode::kDot: // Matrix-vector dots do not reuse the matrix operand. 
- if (instr.shape().dimensions_size() <= 1) { + if (instr.shape().dimensions().size() <= 1) { if ((operand_num == 0 && - instr.operand(1)->shape().dimensions_size() <= 1) || + instr.operand(1)->shape().dimensions().size() <= 1) || (operand_num == 1 && - instr.operand(0)->shape().dimensions_size() <= 1)) { + instr.operand(0)->shape().dimensions().size() <= 1)) { return UseKind::kUse; } } diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index 806db537674589..f6d1a73b47d3a6 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -1516,7 +1516,7 @@ HloTransposeInstruction::HloTransposeInstruction( bool HloTransposeInstruction::IsRank2Transpose() const { return dimensions() == std::vector({1, 0}) && - shape().dimensions_size() == 2 && + shape().dimensions().size() == 2 && std::equal(shape().dimensions().begin(), shape().dimensions().end(), operand(0)->shape().dimensions().rbegin()); } @@ -1618,7 +1618,7 @@ HloMapInstruction::HloMapInstruction(const Shape& shape, AppendComputation(map_computation); // TODO(b/65689298) Remove code below once Map is generalized to accept // arbitrary map dimensions. - dimensions_.resize(shape.dimensions_size()); + dimensions_.resize(shape.dimensions().size()); std::iota(dimensions_.begin(), dimensions_.end(), 0); } @@ -1634,7 +1634,7 @@ bool HloMapInstruction::IsElementwiseImpl( const std::optional& operand_idx) const { if (!dimensions().empty()) { // Check that the map is executed in elementwise compatible dimensions. - if (dimensions().size() != shape().dimensions_size()) { + if (dimensions().size() != shape().dimensions().size()) { return false; } for (int i = 0; i < dimensions().size(); ++i) { @@ -3567,7 +3567,7 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { if (new_operands.size() == 2 && - new_operands[1]->shape().dimensions_size() == 1) { + new_operands[1]->shape().dimensions().size() == 1) { // TODO(b/118437727): Old form, remove this path. 
return std::make_unique<HloDynamicSliceInstruction>(
        shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.cc b/third_party/xla/xla/hlo/ir/hlo_sharding.cc
index c772f31573d123..883f36eb93d61a 100644
--- a/third_party/xla/xla/hlo/ir/hlo_sharding.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_sharding.cc
@@ -128,7 +128,7 @@ HloSharding HloSharding::AssignDevice(int64_t device_id,
 HloSharding HloSharding::Tile1D(const Shape& input_shape, int64_t num_tiles,
                                 absl::Span<const OpMetadata> metadata) {
-  CHECK_EQ(1, input_shape.dimensions_size());
+  CHECK_EQ(1, input_shape.dimensions().size());
   CHECK_GT(num_tiles, 1);
   absl::Span<const int64_t> dimensions(&num_tiles, 1);
   return HloSharding(TileAssignment(dimensions, dimensions, {0}),
@@ -577,9 +577,9 @@ std::vector<int64_t> HloSharding::TileOffsetForDevice(const Shape& shape,
   CHECK(!IsUnknown());
 
   if (maximal_) {
-    return std::vector<int64_t>(shape.dimensions_size(), 0);
+    return std::vector<int64_t>(shape.dimensions().size(), 0);
   }
-  CHECK_EQ(shape.dimensions_size(), TiledDataRank());
+  CHECK_EQ(shape.dimensions().size(), TiledDataRank());
   std::vector<int64_t> index = TileIndexForDevice(device);
   for (int64_t i = 0; i < index.size(); ++i) {
     const int64_t shape_dim = shape.dimensions(i);
@@ -600,7 +600,7 @@ std::vector<int64_t> HloSharding::TileLimitForDevice(const Shape& shape,
                                 shape.dimensions().end());
   }
 
-  CHECK_EQ(shape.dimensions_size(), TiledDataRank());
+  CHECK_EQ(shape.dimensions().size(), TiledDataRank());
   std::vector<int64_t> index = TileIndexForDevice(device);
   for (int64_t i = 0; i < index.size(); ++i) {
     const int64_t shape_dim = shape.dimensions(i);
@@ -776,7 +776,7 @@ absl::Status HloSharding::ValidateNonTuple(
   }
 
   // The tile assignment tensor must have the same rank as the tiled data rank.
-  if (shape.dimensions_size() != TiledDataRank()) {
+  if (shape.dimensions().size() != TiledDataRank()) {
     return tsl::errors::InvalidArgument(
         "Number of tile assignment dimensions (excluding subgroups) is "
         "different than the input rank. "
From 0219a2fd47b15290f31c91d31f5e08074312d59c Mon Sep 17 00:00:00 2001
From: Changhui Lin
Date: Fri, 4 Apr 2025 21:11:09 -0700
Subject: [PATCH 0265/1324] [NFC] Refactor `GetExecutableExtras()` to separate
 the compilation option update into another function `UpdateCompileOptions()`.

The `UpdateCompileOptions()` will be used by the `Compile()` implementation,
which returns an unloaded executable.

PiperOrigin-RevId: 744160124
---
 .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc    |  5 ++-
 .../xla/pjrt/pjrt_stream_executor_client.cc   | 41 +++++++++++++------
 .../xla/pjrt/pjrt_stream_executor_client.h    |  8 +++-
 3 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
index df901c978c9e45..12393ed9910c09 100644
--- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
+++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -905,8 +905,9 @@ StreamExecutorGpuClient::Load(std::unique_ptr<PjRtExecutable> executable) {
   CompileOptions compile_options = se_executable->compile_options();
   CompileOptions input_options = compile_options;
   TF_RETURN_IF_ERROR(compile_options.ApplyAllOptionOverrides());
-  TF_ASSIGN_OR_RETURN(ExecutableExtras extras,
-                      GetExecutableExtras(&compile_options));
+  TF_ASSIGN_OR_RETURN(
+      ExecutableExtras extras,
+      UpdateCompileOptionsAndGetExecutableExtras(&compile_options));
 
   // Load Executable from AOT compilation result.
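// A minimal sketch of the call pattern after this refactor (simplified; the
// real declarations appear in pjrt_stream_executor_client.h further below).
// Load and deserialize paths use the combined helper, while a Compile() path
// that returns an unloaded executable can run the option update alone:
//
//   UpdateCompileOptions(&options);  // option mutation only
//
//   TF_ASSIGN_OR_RETURN(ExecutableExtras extras,
//                       UpdateCompileOptionsAndGetExecutableExtras(&options));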
std::vector> local_executables; diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 4d687f66ccecc3..ad60e2ba4143a5 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -3414,15 +3414,7 @@ PjRtStreamExecutorLoadedExecutable::GetOutputMemoryKinds() const { return out; } -absl::StatusOr -PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) { - ExecutableExtras extras; - std::shared_ptr& device_assignment = - extras.device_assignment; - std::vector& - addressable_device_logical_ids = extras.addressable_device_logical_ids; - std::vector& addressable_devices = extras.addressable_devices; - +void PjRtStreamExecutorClient::UpdateCompileOptions(CompileOptions* options) { ExecutableBuildOptions& build_options = options->executable_build_options; if (!build_options.compile_thread_pool()) { build_options.set_compile_thread_pool(thread_pool()); @@ -3460,6 +3452,26 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) { build_options.set_layout_canonicalization_callback(layout_callback); + if (build_options.device_ordinal() < 0) { + build_options.set_device_ordinal(0); + } +} + +absl::StatusOr +PjRtStreamExecutorClient::UpdateCompileOptionsAndGetExecutableExtras( + CompileOptions* options) { + const int original_device_ordinal = + options->executable_build_options.device_ordinal(); + + UpdateCompileOptions(options); + + ExecutableExtras extras; + std::shared_ptr& device_assignment = + extras.device_assignment; + std::vector& + addressable_device_logical_ids = extras.addressable_device_logical_ids; + std::vector& addressable_devices = extras.addressable_devices; + int num_replicas; int num_partitions; TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions( @@ -3503,7 +3515,8 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) { device_assignment->ToString()); } - if (build_options.device_ordinal() < 0) { + ExecutableBuildOptions& build_options = options->executable_build_options; + if (original_device_ordinal < 0) { build_options.set_device_ordinal( addressable_devices.front()->local_hardware_id().value()); } @@ -3530,7 +3543,8 @@ PjRtStreamExecutorClient::CompileInternal( TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides()); - TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&options)); + TF_ASSIGN_OR_RETURN(ExecutableExtras extras, + UpdateCompileOptionsAndGetExecutableExtras(&options)); std::shared_ptr& device_assignment = extras.device_assignment; std::vector& @@ -3719,8 +3733,9 @@ PjRtStreamExecutorClient::LoadSerializedExecutable( "PjRtStreamExecutorClient::DeserializeExecutable"); VLOG(1) << "PjRtStreamExecutorClient::DeserializeExecutable"; - TF_ASSIGN_OR_RETURN(ExecutableExtras extras, - GetExecutableExtras(&compile_options)); + TF_ASSIGN_OR_RETURN( + ExecutableExtras extras, + UpdateCompileOptionsAndGetExecutableExtras(&compile_options)); std::shared_ptr& device_assignment = extras.device_assignment; std::vector& diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index f8fee79850bdef..80db814c8846fe 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -410,7 +410,13 @@ class PjRtStreamExecutorClient : public PjRtClient { addressable_device_logical_ids; std::vector addressable_devices; }; - absl::StatusOr 
GetExecutableExtras(CompileOptions* options);
+
+  // Updates `options` for compilation.
+  void UpdateCompileOptions(CompileOptions* options);
+
+  // Same as above, but also returns the executable extras.
+  absl::StatusOr<ExecutableExtras> UpdateCompileOptionsAndGetExecutableExtras(
+      CompileOptions* options);
 
   absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileInternal(
       const XlaComputation& computation,
From f76c73b671b72d098bb55e0f6da8bae00ac44c84 Mon Sep 17 00:00:00 2001
From: Farzin Houshmand
Date: Fri, 4 Apr 2025 21:20:20 -0700
Subject: [PATCH 0266/1324] [XLA:MSA] Fix a bug in ConsumeResource in the MSA
 algorithm.

The resource type is float and is initialized by calling GetInstructionElapsed
at each schedule time. However, since GetInstructionElapsed is measured in
seconds, the resource values are very small floats (e.g., 1e-10). When
consuming a resource, we perform floating-point operations on these small
values, which introduces error (e.g., when comparing against zero). This CL
scales the float resource values to the int64 type by a constant power-of-2
scaling factor.

PiperOrigin-RevId: 744161807
---
 .../memory_space_assignment/algorithm.cc      | 66 ++++++++++---------
 .../memory_space_assignment/algorithm.h       | 34 ++++++++--
 .../memory_space_assignment_test.cc           | 22 +++++++
 3 files changed, 86 insertions(+), 36 deletions(-)

diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
index 914dea385720e4..6c8bff8b26f21e 100644
--- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
@@ -3229,16 +3229,16 @@ bool AsynchronousCopyOrdering::ViolatesOrdering(int64_t exclusive_start_time,
 }
 
 bool AsynchronousCopyResource::ConsumeResource(
-    int64_t exclusive_start_time, int64_t end_time, float resource,
-    std::vector<std::pair<int64_t, float>>* delay_changes,
-    float resource_to_free) {
+    int64_t exclusive_start_time, int64_t end_time, int64_t resource,
+    std::vector<std::pair<int64_t, int64_t>>* delay_changes,
+    int64_t resource_to_free) {
   // Cache the pointers to the arrays to avoid the overhead of `operator[]`
   // size checks in hardened libc++.
   //
   // NOTE: Do not modify the vectors `initial_resources_` or `delay_` in this
   // function, otherwise the pointers will become dangling.
-  float* initial_resources_ptr = initial_resources_.data();
-  float* delay_ptr = delay_.data();
+  int64_t* initial_resources_scaled_ptr = initial_resources_scaled_.data();
+  int64_t* delay_ptr = delay_.data();
 
   std::list<AsynchronousCopy>::iterator current_copy = async_copies_.end();
   // In order to propagate the resource to the next scheduled copy, we iterate
@@ -3247,7 +3247,7 @@ bool AsynchronousCopyResource::ConsumeResource(
   // resource (and return false).
   while (true) {
     // resource is modified below. We save its initial value for logging below.
-    const float amount_requested = resource;
+    const int64_t amount_requested = resource;
 
     VLOG(3) << "Consume resource: start time_exclusive = "
             << exclusive_start_time << ", end time = " << end_time
@@ -3261,7 +3261,7 @@ bool AsynchronousCopyResource::ConsumeResource(
                    end_time);
 
     // Nothing to do if we're not adding or removing any resources.
-    if (resource == 0.0 && resource_to_free == 0.0) {
+    if (resource == 0 && resource_to_free == 0) {
       return true;
     }
 
@@ -3290,13 +3290,14 @@ bool AsynchronousCopyResource::ConsumeResource(
     // Check if this copy will push the next copy later in time (or if removing
     // the resource, check if the removal of this copy moves the next copy
     // earlier in time).
- std::optional delay_for_next_copy = std::nullopt; - float resource_freed = 0.0; + std::optional delay_for_next_copy = std::nullopt; + int64_t resource_freed = 0; for (int64_t time = ExclusiveToInclusiveStartTime(exclusive_start_time); time < end_time && resource != 0; ++time) { + int64_t initial_resource_scaled = initial_resources_scaled_ptr[time]; // Iterate over the logical times that this copy spans. Note that the // start and end time ranges are exclusive. - float used_resource = std::min(resource, initial_resources_ptr[time]); + int64_t used_resource = std::min(resource, initial_resource_scaled); if (next_copy != async_copies_.end() && next_copy->exclusive_start_time == InclusiveToExclusiveStartTime(time)) { @@ -3309,13 +3310,13 @@ bool AsynchronousCopyResource::ConsumeResource( if (!delay_for_next_copy.has_value()) { // Update the delay_ vector and resource_freed variable with the amount // that was freed when removing the copy. - float old_delay = delay_ptr[time]; - float old_resource = - std::max(0.0f, initial_resources_ptr[time] - old_delay); - float new_delay = std::max(0.0f, resource - resource_to_free); - float new_resource = - std::max(0.0f, initial_resources_ptr[time] - new_delay); - resource_freed += std::max(0.0f, new_resource - old_resource); + int64_t old_delay = delay_ptr[time]; + int64_t old_resource = + std::max(0, initial_resource_scaled - old_delay); + int64_t new_delay = std::max(0, resource - resource_to_free); + int64_t new_resource = + std::max(0, initial_resource_scaled - new_delay); + resource_freed += std::max(0, new_resource - old_resource); delay_ptr[time] = new_delay; if (delay_changes) { delay_changes->emplace_back(time, old_delay); @@ -3325,7 +3326,8 @@ bool AsynchronousCopyResource::ConsumeResource( resource -= used_resource; } - // If resource isn't satisfied by the end, we didn't have enough resources. + // If resource isn't satisfied by the end, we didn't have enough + // resources. if (resource > 0) { VLOG(3) << "Doesn't have enough resource; requested resource = " << amount_requested << "; leftover resources = " << resource; @@ -3340,14 +3342,15 @@ bool AsynchronousCopyResource::ConsumeResource( // removed. exclusive_start_time = next_copy->exclusive_start_time; end_time = next_copy->end_time; - resource = *delay_for_next_copy + next_copy->resource; + resource = + *delay_for_next_copy + GetScaledIntegerResource(next_copy->resource); current_copy = next_copy; } } void AsynchronousCopyResource::AddCopy(const AsynchronousCopy& copy) { - CHECK( - ConsumeResource(copy.exclusive_start_time, copy.end_time, copy.resource)); + CHECK(ConsumeResource(copy.exclusive_start_time, copy.end_time, + GetScaledIntegerResource(copy.resource))); // Find the iterator for the copy that would be right after this copy and put // this copy right before it in async_copies_. 
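// A worked illustration of the precision problem this change fixes (hedged;
// the exact residues depend on how each literal rounds to float). With
// elapsed times around 1e-9 seconds, accumulated float arithmetic rarely
// lands back on exact zero:
//
//   float a = 1e-10f, b = 2e-10f, c = 3e-10f;
//   float leftover = (a + b) - c;  // typically a tiny nonzero residue
//
// Multiplying by kCopyResourceIntScale = 2^50 only shifts a float's exponent,
// so the scaling itself introduces no rounding; only the final cast truncates
// the sub-2^-50 fraction:
//
//   int64_t scaled = GetScaledIntegerResource(8.71333e-09f);  // ~9.8e6
//   float back = GetDescaledFloatResource(scaled);            // ~8.71333e-09f
//
// Integer arithmetic on the scaled values then makes tests such as
// `resource == 0` exact.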
@@ -3413,10 +3416,11 @@ void AsynchronousCopyResource::RemoveCopy( CHECK(std::next(copy_it) == async_copies_.end() || std::next(copy_it)->exclusive_start_time > copy_it->exclusive_start_time); - CHECK(ConsumeResource(copy_it->exclusive_start_time, copy_it->end_time, - /*resource=*/0, - /*delay_changes=*/nullptr, - /*resource_to_free=*/copy_it->resource)); + CHECK(ConsumeResource( + copy_it->exclusive_start_time, copy_it->end_time, + /*resource=*/0, + /*delay_changes=*/nullptr, + /*resource_to_free=*/GetScaledIntegerResource(copy_it->resource))); // If the copy to be removed is the value pointed by async_copy_time_map_, we // make the next copy with the same start time to be pointed by // async_copy_time_map_. If there are no such copies, we remove the key for @@ -3437,10 +3441,11 @@ void AsynchronousCopyResource::RemoveCopy( bool AsynchronousCopyResource::HasEnoughResource(int64_t exclusive_start_time, int64_t end_time, float resource) { - std::vector> delay_changes; + std::vector> delay_changes; delay_changes.reserve(delay_.size()); bool result = - ConsumeResource(exclusive_start_time, end_time, resource, &delay_changes); + ConsumeResource(exclusive_start_time, end_time, + GetScaledIntegerResource(resource), &delay_changes); // Apply the delay changes in reverse order. This ensures that the original // value of each delay is restored. if (!delay_changes.empty()) { @@ -3454,11 +3459,12 @@ bool AsynchronousCopyResource::HasEnoughResource(int64_t exclusive_start_time, bool AsynchronousCopyResource::HasEnoughResourceMultiCheck( const std::vector& specs) { - std::vector> delay_changes; + std::vector> delay_changes; delay_changes.reserve(delay_.size()); bool result = absl::c_all_of(specs, [&](const ResourceSpec& spec) { return ConsumeResource(spec.exclusive_start_time, spec.end_time, - spec.resource, &delay_changes); + GetScaledIntegerResource(spec.resource), + &delay_changes); }); // Apply the delay changes in reverse order. This ensures that the original // value of each delay is restored. @@ -3492,7 +3498,7 @@ std::string AsynchronousCopyResource::Dump( for (int i = start_time; i < end_time; ++i) { time_dump_data.push_back({ initial_resources_[i], - delay_[i], + GetDescaledFloatResource(delay_[i]), available[i], /*overlapping_copies=*/{}, }); diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h index 78ebbca77c0112..b32340cb8927ae 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h @@ -185,7 +185,12 @@ class AsynchronousCopyResource { // The constructor needs the initial resources. explicit AsynchronousCopyResource(absl::Span initial_resources) : initial_resources_(initial_resources.begin(), initial_resources.end()), - delay_(initial_resources.size(), 0) {} + delay_(initial_resources.size(), 0) { + for (int i = 0; i < initial_resources.size(); ++i) { + initial_resources_scaled_.push_back( + GetScaledIntegerResource(initial_resources[i])); + } + } // Adds the given asynchronous copy and updates the current resources. CHECK // fails if there aren't enough resources to satisfy this copy (the caller @@ -204,13 +209,24 @@ class AsynchronousCopyResource { // order specified. 
bool HasEnoughResourceMultiCheck(const std::vector<ResourceSpec>& specs);
 
+  int64_t GetScaledIntegerResource(float resource) const {
+    float scaled_value = resource * kCopyResourceIntScale;
+    int64_t scaled_value_int = static_cast<int64_t>(scaled_value);
+    return scaled_value_int;
+  }
+
+  float GetDescaledFloatResource(int64_t scaled_resource) const {
+    return static_cast<float>(scaled_resource) / kCopyResourceIntScale;
+  }
+
   // This is only used for debugging and testing purposes, it returns the
   // currently available resource at each logical time.
   std::vector<float> GetCurrentResources() const {
     std::vector<float> current_resources(initial_resources_.begin(),
                                          initial_resources_.end());
     for (int i = 0; i < current_resources.size(); ++i) {
-      current_resources[i] -= std::min(current_resources[i], delay_[i]);
+      current_resources[i] -=
+          std::min(current_resources[i], GetDescaledFloatResource(delay_[i]));
     }
     return current_resources;
   }
@@ -220,6 +236,11 @@ class AsynchronousCopyResource {
   std::string Dump(int64_t start_time, int64_t end_time,
                    MemorySpace memory_space_filter) const;
 
+  // The scale factor to convert a float resource to an integer resource. Note
+  // that it is a power of 2 to avoid introducing noise when casting the scaled
+  // value to an int64_t.
+  static constexpr int64_t kCopyResourceIntScale = 1ULL << 50;
+
 private:
  // Internal helper method to implement adding/removing/checking resources.
  // ConsumeResource() may modify delay_. If delay_changes is not null,
@@ -227,9 +248,9 @@
  // delay_changes, allowing callers to undo any modifications by iterating over
  // the vector in reverse order.
   bool ConsumeResource(
-      int64_t exclusive_start_time, int64_t end_time, float resource,
-      std::vector<std::pair<int64_t, float>>* delay_changes = nullptr,
-      float resource_to_free = 0.0);
+      int64_t exclusive_start_time, int64_t end_time, int64_t resource,
+      std::vector<std::pair<int64_t, int64_t>>* delay_changes = nullptr,
+      int64_t resource_to_free = 0);
 
   // Same as the public RemoveCopy except it works on the async_copies_
   // iterator. Assumes copy_it points to the last copy for its start time;
@@ -253,7 +274,8 @@
   std::map<int64_t, std::list<AsynchronousCopy>::iterator> async_copy_time_map_;
 #endif
   std::vector<float> initial_resources_;
-  std::vector<float> delay_;
+  std::vector<int64_t> initial_resources_scaled_;
+  std::vector<int64_t> delay_;
 };
 
 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
index 6e87950496f0b9..982e87481ac26d 100644
--- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
@@ -9946,6 +9946,28 @@ TEST_F(AsynchronousCopyResourceTest, StartAtZeroAndRemove) {
             std::vector<float>({0.0, 0.0, 0.0, 0.0, 2.0}));
 }
 
+// The test below only works when the resource values are scaled to int64 to
+// avoid floating-point precision issues.
+TEST_F(AsynchronousCopyResourceTest, ConsumeResourceScaledIntegerResource) { + auto alternate_mem_space = MemorySpace::kAlternate; + AsynchronousCopyResource resource( + {5.71429e-10, 8.71333e-09, 8.71333e-09, 1.74267e-08, 1.74267e-08}); + AsynchronousCopy copy1{0, 2, 8.71333e-09, alternate_mem_space, 0}; + EXPECT_TRUE(resource.HasEnoughResource(0, 2, 8.71333e-09)); + resource.AddCopy(copy1); + + AsynchronousCopy copy2{0, 3, 4.35667e-09, alternate_mem_space, 1}; + EXPECT_TRUE(resource.HasEnoughResource(0, 3, 4.35667e-09)); + resource.AddCopy(copy2); + + AsynchronousCopy copy3{2, 4, 4.35667e-09, alternate_mem_space, 2}; + EXPECT_TRUE(resource.HasEnoughResource(2, 4, 4.35667e-09)); + resource.AddCopy(copy3); + + // This call to RemoveCopy should not cause a crash. + resource.RemoveCopy(copy1); +} + TEST_F(AsynchronousCopyResourceTest, OutOfOrderRemovalSameStartTime) { // time: 0 1 2 3 4 // resource: 2 2 2 2 2 From 51ef47d476b982b278a79d0d1885699678cffbfa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 5 Apr 2025 02:02:40 -0700 Subject: [PATCH 0267/1324] Update GraphDef version to 2188. PiperOrigin-RevId: 744208694 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index bd134534929aad..2497c8e30074fa 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2187 // Updated: 2025/4/4 +#define TF_GRAPH_DEF_VERSION 2188 // Updated: 2025/4/5 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From bb5f2318015a9c3cc1c29231b6aab92c3232863d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 5 Apr 2025 02:02:42 -0700 Subject: [PATCH 0268/1324] compat: Update forward compatibility horizon to 2025-04-05 PiperOrigin-RevId: 744208701 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index d6f9f971055063..ad80a4ac3acad9 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 4) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 5) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 40e061d00915f30909c2bf91a467aacac0b24e2b Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Sat, 5 Apr 2025 02:40:13 -0700 Subject: [PATCH 0269/1324] Support FP8 convolutions on Ada with cuDNN >= 9.8. FP8 convolutions fail on cuDNN 9.0, but work on 9.8. I haven't tried any cuDNN versions in between, so I enable them for cuDNN 9.8 and above. 
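A minimal sketch of the resulting gating predicate, assuming only what the
diffs below show (the helper name SupportsFp8Conv is illustrative and not part
of the change; the VersionInfo comparison and the compute-capability helpers
appear verbatim below):

bool SupportsFp8Conv(const se::CudaComputeCapability& cc,
                     const se::dnn::VersionInfo& dnn_version) {
  // cuDNN 9.8 and newer support FP8 convolutions from Ada (8.9) onward.
  if (dnn_version >= se::dnn::VersionInfo{9, 8, 0}) {
    return cc.IsAtLeastAda();
  }
  // Older cuDNN releases still require Hopper (9.0) or newer.
  return cc.IsAtLeastHopper();
}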
PiperOrigin-RevId: 744214591 --- .../xla/xla/service/gpu/nvptx_compiler.cc | 2 +- .../xla/xla/service/gpu/transforms/BUILD | 1 + .../service/gpu/transforms/conv_rewriter.cc | 38 +++- .../service/gpu/transforms/conv_rewriter.h | 10 +- .../transforms/cudnn_fused_conv_rewriter.cc | 11 +- .../cudnn_fused_conv_rewriter_test.cc | 200 ++++++++---------- .../cuda/cuda_compute_capability.h | 2 + 7 files changed, 135 insertions(+), 129 deletions(-) diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index fd74b5b5b06470..51b0a8ba3ce235 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -198,7 +198,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( if (!hlo_module->config() .debug_options() .xla_gpu_experimental_disable_binary_libraries()) { - pipeline.AddPass(cuda_compute_capability); + pipeline.AddPass(cuda_compute_capability, dnn_version); pipeline.AddPass(cuda_compute_capability, dnn_version, toolkit_version); pipeline.AddPass(); diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index dfd5bb6144e019..13cc6cb5719f21 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -804,6 +804,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/stream_executor:device_description", + "//xla/stream_executor:dnn", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc index 567d66ac7a0b0a..714c7561868192 100644 --- a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" @@ -54,9 +55,10 @@ namespace gpu { namespace { -absl::Status CheckTypes(HloInstruction* conv, - const se::GpuComputeCapability cc) { - auto valid_shape = [conv, &cc](const Shape& shape) -> absl::Status { +absl::Status CheckTypes(HloInstruction* conv, const se::GpuComputeCapability cc, + const se::dnn::VersionInfo dnn_version) { + auto valid_shape = [conv, &cc, + &dnn_version](const Shape& shape) -> absl::Status { PrimitiveType type = shape.element_type(); if (!primitive_util::IsFloatingPointType(type) && !primitive_util::IsIntegralType(type)) { @@ -81,6 +83,16 @@ absl::Status CheckTypes(HloInstruction* conv, "FP8 convolutions are only supported on CUDA GPUs, but got " "FP8 convolution on ROCm GPU: %s", conv->ToString()); + } + if (dnn_version >= se::dnn::VersionInfo{9, 8, 0}) { + if (!std::get(cc).IsAtLeastAda()) { + return Unimplemented( + "FP8 convolutions are only supported on CUDA GPUs with compute " + "capability at least 8.9, but got " + "FP8 convolution on GPU with compute capability %s: %s", + std::get(cc).ToString(), + conv->ToString()); + } } else if (!std::get(cc).IsAtLeastHopper()) { return Unimplemented( "FP8 convolutions are only supported on CUDA GPUs with compute " @@ -762,8 +774,9 @@ CudnnConvBackendConfig GetDefaultBackendConfig() { // Helper function to create a custom_call instruction to replace the given // conv instruction static absl::StatusOr CreateCustomCallHelper( - HloInstruction* conv, const se::GpuComputeCapability& cc) { - TF_RETURN_IF_ERROR(CheckTypes(conv, cc)); + HloInstruction* conv, const se::GpuComputeCapability& cc, + const se::dnn::VersionInfo& dnn_version) { + TF_RETURN_IF_ERROR(CheckTypes(conv, cc, dnn_version)); if (ConvolutionMatch m = MatchBackwardInput(conv)) { auto& [window, dnums, rhs] = *m; return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(), @@ -798,11 +811,12 @@ static absl::StatusOr CreateCustomCallHelper( // Tries to rewrite a single convolution into a call to cudnn/miopen. absl::StatusOr RunOnInstruction(HloInstruction* conv, - const se::GpuComputeCapability& cc) { + const se::GpuComputeCapability& cc, + const se::dnn::VersionInfo& dnn_version) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); TF_ASSIGN_OR_RETURN(HloInstruction * custom_call, - CreateCustomCallHelper(conv, cc)); + CreateCustomCallHelper(conv, cc, dnn_version)); if (custom_call == nullptr) { return false; } @@ -827,7 +841,8 @@ absl::StatusOr RunOnInstruction(HloInstruction* conv, // cudnn/miopen. // Returns true if it made any changes. 
absl::StatusOr RunOnComputation(HloComputation* computation, - const se::GpuComputeCapability& cc) { + const se::GpuComputeCapability& cc, + const se::dnn::VersionInfo dnn_version) { std::vector convs; for (auto* hlo : computation->instructions()) { if (HloPredicateIsOp(hlo)) { @@ -837,7 +852,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation, bool changed = false; for (HloInstruction* conv : convs) { - TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc)); + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc, dnn_version)); changed |= result; } return changed; @@ -851,8 +866,9 @@ absl::StatusOr ConvRewriter::Run( bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, - RunOnComputation(computation, compute_capability_)); + TF_ASSIGN_OR_RETURN( + bool result, + RunOnComputation(computation, compute_capability_, dnn_version_)); changed |= result; } XLA_VLOG_LINES(2, "ConvRewriter::Run(), after:\n" + module->ToString()); diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h index 5ad7e7111807fa..d7074a459012cb 100644 --- a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/dnn.h" namespace xla { namespace gpu { @@ -36,8 +37,10 @@ namespace gpu { class ConvRewriter : public HloModulePass { public: - explicit ConvRewriter(const se::GpuComputeCapability& compute_capability) - : compute_capability_(compute_capability) {}; + explicit ConvRewriter( + const se::GpuComputeCapability& compute_capability, + se::dnn::VersionInfo dnn_version = se::dnn::VersionInfo{}) + : compute_capability_(compute_capability), dnn_version_(dnn_version) {}; absl::string_view name() const override { return "conv-rewriter"; } @@ -49,7 +52,8 @@ class ConvRewriter : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: - se::GpuComputeCapability compute_capability_; + const se::GpuComputeCapability compute_capability_; + const se::dnn::VersionInfo dnn_version_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc index e5f033611f6745..71f30a90f6046d 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc @@ -832,8 +832,8 @@ CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution, } // Matches convolutions operating on FP8 inputs and filters and rewrites into a -// ForwardGraph Custom Call. For scaled FP8 convolutions on Hopper systems, the -// following steps are elided and rewritten into a ForwardGraph Custom Call: +// ForwardGraph Custom Call. For scaled FP8 convolutions, the following steps +// are elided and rewritten into a ForwardGraph Custom Call: // // 1. Cast the filter and input from FP8 to a wider type such as FP16 or FP32. // 2. 
Optionally unscale the filter and input by multiplying or dividing by
@@ -851,7 +851,12 @@ absl::StatusOr<bool> F8GraphConv(HloComputation* comp,
   if (toolkit_version < se::SemanticVersion{12, 0, 0}) {
     return false;
   }
-  if (!cc.IsAtLeast(se::CudaComputeCapability::kHopper)) {
+  if (!cc.IsAtLeastAda()) {
+    return false;
+  }
+  if (dnn_version < se::dnn::VersionInfo{9, 8, 0} && !cc.IsAtLeastHopper()) {
+    // Ada is not supported on older cuDNN versions; Hopper or later is
+    // required instead.
     return false;
   }
   for (auto instr : comp->MakeInstructionPostOrder()) {
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc
index e52961faad06ce..b0e0ac3a78c823 100644
--- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc
@@ -97,6 +97,15 @@ class CudnnFusedConvRewriterHloTest : public HloTestBase {
         .runtime_version();
   }
 
+  ConvRewriter GetConvRewriter() const {
+    return ConvRewriter(GetCudaComputeCapability(), GetDnnVersion());
+  }
+
+  CudnnFusedConvRewriter GetCudnnFusedConvRewriter() const {
+    return CudnnFusedConvRewriter(GetCudaComputeCapability(), GetDnnVersion(),
+                                  GetToolkitVersion());
+  }
+
   CudnnFusedConvRewriterHloTest()
       : HloTestBase(/*verifier_layout_sensitive=*/false,
                     /*allow_mixed_precision_in_hlo_verifier=*/false,
@@ -208,9 +217,13 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest {
   void TestF8(std::string pre_hlo_string, std::string custom_call_string,
               std::string serialized_graph_string) {
     if (!IsCuda()) return;
-    if (GetCudaComputeCapability().IsAtLeast(
-            se::CudaComputeCapability::kHopper)) {
-      // On Hopper and newer architectures, test numerical correctness and
+
+    bool fp8_supported = GetDnnVersion() >= se::dnn::VersionInfo{9, 8, 0}
+                             ? GetCudaComputeCapability().IsAtLeastAda()
+                             : GetCudaComputeCapability().IsAtLeastHopper();
+    LOG(INFO) << "fp8_supported: " << fp8_supported;
+    if (fp8_supported) {
+      // On Ada/Hopper and newer architectures, test numerical correctness and
       // verify the HLO of the Custom Call with operand and return layouts and
       // the serialized graph based on the full compiler pipeline.
       std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
@@ -1605,10 +1618,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) {
   })";
   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
 
-  ConvRewriter rewriter{GetCudaComputeCapability()};
+  ConvRewriter rewriter = GetConvRewriter();
   TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
+  CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter();
   TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
 
   SCOPED_TRACE(m->ToString());
@@ -1639,10 +1651,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
   })";
   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
 
-  ConvRewriter rewriter{GetCudaComputeCapability()};
+  ConvRewriter rewriter = GetConvRewriter();
   TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
+  CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter();
   TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
 
   // Simplify new `convert`'s that may be added to the graph.
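// The remaining hunks in this file repeat one mechanical substitution: each
// test's locally constructed passes
//
//   ConvRewriter rewriter{GetCudaComputeCapability()};
//   CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
//                                GetToolkitVersion()};
//
// become the fixture helpers added above, which also thread the cuDNN version
// into ConvRewriter:
//
//   ConvRewriter rewriter = GetConvRewriter();
//   CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter();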
@@ -1680,10 +1691,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1721,10 +1731,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1775,10 +1784,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1821,10 +1829,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. 
@@ -1876,10 +1883,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1916,10 +1922,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1957,10 +1962,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2007,7 +2011,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // elu fusion is only active on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), @@ -2056,10 +2060,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2109,7 +2112,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // relu6 fusion is only enabled on Ampere+. 
CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), @@ -2153,10 +2156,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2201,7 +2203,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // Leaky-relu fusion is only enabled on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), @@ -2248,10 +2250,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2296,10 +2297,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2336,10 +2336,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2375,10 +2374,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2414,10 +2412,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, 
ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2450,10 +2447,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2487,10 +2483,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2521,10 +2516,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2554,10 +2548,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2581,10 +2574,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2608,10 +2600,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter 
rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2636,10 +2627,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDuetoMultpleUser) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2666,10 +2656,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2698,10 +2687,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2741,10 +2729,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2784,10 +2771,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2822,10 +2808,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, 
m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2865,10 +2850,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -2912,10 +2896,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -2965,10 +2948,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. 
@@ -3021,10 +3003,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph, and fold @@ -3075,10 +3056,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph, and fold @@ -3139,10 +3119,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -3182,10 +3161,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter = GetConvRewriter(); TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - GetToolkitVersion()}; + CudnnFusedConvRewriter fuser = GetCudnnFusedConvRewriter(); TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability.h b/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability.h index 7c05ee995b6583..aa4865f12ef100 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability.h @@ -88,6 +88,8 @@ struct CudaComputeCapability { return major >= CudaComputeCapabilities::kAmpere; } + bool IsAtLeastAda() const { return IsAtLeast(8, 9); } + bool IsAtLeastHopper() const { return major >= CudaComputeCapabilities::kHopper; } From 6f8e2fcb6517db60e4197132897b5dd6470a40d5 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sat, 5 Apr 2025 03:27:35 -0700 Subject: [PATCH 0270/1324] Automated Code Change PiperOrigin-RevId: 744221338 --- tensorflow/core/tpu/tpu_execute.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/tpu/tpu_execute.cc b/tensorflow/core/tpu/tpu_execute.cc index d077a0bec849ff..12009787a1094c 100644 --- a/tensorflow/core/tpu/tpu_execute.cc +++ b/tensorflow/core/tpu/tpu_execute.cc @@ -134,10 +134,10 @@ absl::Status FixTupleTableAsync(se::Stream* stream, // "bounded_shape". bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { - if (dynamic_shape.dimensions_size() != bounded_shape.dimensions_size()) { + if (dynamic_shape.dimensions().size() != bounded_shape.dimensions().size()) { return false; } - for (int64_t i = 0; i < dynamic_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < dynamic_shape.dimensions().size(); ++i) { if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) { return false; } From f2ed54a1b09655f9221753fd87b930edef794c2c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 5 Apr 2025 09:11:53 -0700 Subject: [PATCH 0271/1324] Update reachability when control dep is updated PiperOrigin-RevId: 744268804 --- .../collectives/collectives_schedule_linearizer.cc | 3 +++ .../collectives/collectives_schedule_linearizer_test.cc | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc index 772992b4e1c5d5..0d49977df1bc69 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc @@ -75,6 +75,9 @@ absl::StatusOr CollectivesScheduleLinearizer::Run( if (prev_done && !reachability->IsConnected(start, prev_done)) { // If prev_done and start are independent, enforce ordering. TF_RETURN_IF_ERROR(prev_done->AddControlDependencyTo(next)); + // Adding control dependency does not update the reachability map. + reachability->UpdateReachabilityThroughInstruction(start); + VLOG(1) << "Adding control dependency from " << prev_done->ToString() << " to " << start->ToString(); changed = true; diff --git a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc index 2ef88bbea8cc4b..d93acafcffa692 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc @@ -252,6 +252,13 @@ ENTRY entry { if (instruction->name() == "c1") EXPECT_TRUE(found_i0); if (instruction->name() == "i0") found_i0 = true; } + // Calling MakeInstructionPostOrder() again to verify idempotence. 
+ auto post_order = module->entry_computation()->MakeInstructionPostOrder(); + found_i0 = false; + for (HloInstruction *instruction : post_order) { + if (instruction->name() == "c1") EXPECT_TRUE(found_i0); + if (instruction->name() == "i0") found_i0 = true; + } } } // namespace From 9385bc36c06e8a2cbe07d0a0730fcc32108d77f0 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Sat, 5 Apr 2025 19:05:04 -0700 Subject: [PATCH 0272/1324] Update tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc --- .../compiler/mlir/lite/transforms/default_quant_params.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 7acbb7d17240b8..995a878cfc47cf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -54,9 +54,7 @@ namespace { class DefaultQuantParamsPass : public impl::DefaultQuantParamsPassBase { public: - DefaultQuantParamsPass() - { - } + DefaultQuantParamsPass() {} explicit DefaultQuantParamsPass(double default_min, double default_max, bool is_signed) { From 3765146b92dbae03488caef5286d13230b32f992 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Sat, 5 Apr 2025 19:21:11 -0700 Subject: [PATCH 0273/1324] Integrate LLVM at llvm/llvm-project@69f59d59cb02 Updates LLVM usage to match [69f59d59cb02](https://github.com/llvm/llvm-project/commit/69f59d59cb02) PiperOrigin-RevId: 744348330 --- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 92 +------------------ third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/shardy/temporary.patch | 92 +------------------ .../xla/third_party/shardy/workspace.bzl | 4 +- 5 files changed, 16 insertions(+), 180 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 91166c32db1a11..4a58099b072de7 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" - LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" + LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" + LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index a292270454defc..4adb475a33423c 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,97 +1,15 @@ -diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 509398d..2e6ff58 100644 ---- a/third_party/llvm/generated.patch -+++ b/third_party/llvm/generated.patch -@@ -1 +1,23 @@ - Auto generated patch. Do not edit or delete it, even if empty. 
-+diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp -+--- a/clang/lib/Sema/SemaCXXScopeSpec.cpp -++++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp -+@@ -873,6 +873,7 @@ -+ DependentTemplateSpecializationTypeLoc SpecTL -+ = Builder.push(T); -+ SpecTL.setElaboratedKeywordLoc(SourceLocation()); -++ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); -+ SpecTL.setTemplateKeywordLoc(TemplateKWLoc); -+ SpecTL.setTemplateNameLoc(TemplateNameLoc); -+ SpecTL.setLAngleLoc(LAngleLoc); -+diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -++++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+@@ -1902,7 +1902,6 @@ -+ name = "inv_trigf_utils", -+ srcs = ["src/math/generic/inv_trigf_utils.cpp"], -+ hdrs = [ -+- "src/math/generic/atan_utils.h", -+ "src/math/generic/inv_trigf_utils.h", -+ ], -+ deps = [ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index fd9baec..91166c3 100644 +index 91166c3..4a58099 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" -- LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" -+ LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" -+ LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" +- LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" +- LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" ++ LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" ++ LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" tf_http_archive( name = name, -diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch -index 8b13789..90ca4ec 100755 ---- a/third_party/stablehlo/temporary.patch -+++ b/third_party/stablehlo/temporary.patch -@@ -1 +1,49 @@ -+diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -+--- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -++++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -+@@ -12,7 +12,7 @@ -+ return %2 : tensor<14x15x0x33xf64> -+ } -+ func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { -+- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> -++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> -+ return %cst : tensor<14x15x0x17xcomplex> -+ } -+ func.func private @expected() -> (tensor<14x15x0x33xf64> {mhlo.layout_mode = "default"}) { -+diff --ruN a/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -+--- stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -++++ stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -+@@ -12,7 +12,7 @@ -+ return %2 : tensor<14x15x0x33xf32> -+ } -+ func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { -+- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> -++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> -+ return %cst : tensor<14x15x0x17xcomplex> -+ } -+ func.func private @expected() -> (tensor<14x15x0x33xf32> {mhlo.layout_mode = 
"default"}) { -+diff --ruN a/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -+--- stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -++++ stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -+@@ -16,7 +16,7 @@ -+ return %cst : tensor<14x15x0x17xf32> -+ } -+ func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { -+- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> -++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> -+ return %cst : tensor<14x15x0x9xcomplex> -+ } -+ } -+diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -+--- stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -++++ stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -+@@ -16,7 +16,7 @@ -+ return %cst : tensor<14x15x0x17xf64> -+ } -+ func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { -+- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> -++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> -+ return %cst : tensor<14x15x0x9xcomplex> -+ } -+ } - diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 92b2397d3a6afc..4b7db9e5a69eb3 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "84e1e3a76e7b827a6d1621df0c649c3449da1540" - SHARDY_SHA256 = "ddb163c1a40466e320b882821e63ee7cb48e31e96bd05109daff38fca0086f21" + SHARDY_COMMIT = "98555add83dfaa334cd538a401b49130ecacb0d8" + SHARDY_SHA256 = "cafe90437597fedee14f57b3cccea63b689c254748df42c1be0105ed1d64f21f" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index a292270454defc..4adb475a33423c 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,97 +1,15 @@ -diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 509398d..2e6ff58 100644 ---- a/third_party/llvm/generated.patch -+++ b/third_party/llvm/generated.patch -@@ -1 +1,23 @@ - Auto generated patch. Do not edit or delete it, even if empty. 
-+diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp -+--- a/clang/lib/Sema/SemaCXXScopeSpec.cpp -++++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp -+@@ -873,6 +873,7 @@ -+ DependentTemplateSpecializationTypeLoc SpecTL -+ = Builder.push(T); -+ SpecTL.setElaboratedKeywordLoc(SourceLocation()); -++ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); -+ SpecTL.setTemplateKeywordLoc(TemplateKWLoc); -+ SpecTL.setTemplateNameLoc(TemplateNameLoc); -+ SpecTL.setLAngleLoc(LAngleLoc); -+diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -++++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+@@ -1902,7 +1902,6 @@ -+ name = "inv_trigf_utils", -+ srcs = ["src/math/generic/inv_trigf_utils.cpp"], -+ hdrs = [ -+- "src/math/generic/atan_utils.h", -+ "src/math/generic/inv_trigf_utils.h", -+ ], -+ deps = [ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index fd9baec..91166c3 100644 +index 91166c3..4a58099 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260" -- LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3" -+ LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" -+ LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" +- LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" +- LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" ++ LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" ++ LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" tf_http_archive( name = name, -diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch -index 8b13789..90ca4ec 100755 ---- a/third_party/stablehlo/temporary.patch -+++ b/third_party/stablehlo/temporary.patch -@@ -1 +1,49 @@ -+diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -+--- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -++++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -+@@ -12,7 +12,7 @@ -+ return %2 : tensor<14x15x0x33xf64> -+ } -+ func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { -+- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> -++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> -+ return %cst : tensor<14x15x0x17xcomplex> -+ } -+ func.func private @expected() -> (tensor<14x15x0x33xf64> {mhlo.layout_mode = "default"}) { -+diff --ruN a/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -+--- stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -++++ stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -+@@ -12,7 +12,7 @@ -+ return %2 : tensor<14x15x0x33xf32> -+ } -+ func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { -+- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> -++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> -+ return %cst : tensor<14x15x0x17xcomplex> -+ } -+ func.func private @expected() -> (tensor<14x15x0x33xf32> {mhlo.layout_mode = 
"default"}) { -+diff --ruN a/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -+--- stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -++++ stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -+@@ -16,7 +16,7 @@ -+ return %cst : tensor<14x15x0x17xf32> -+ } -+ func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { -+- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> -++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> -+ return %cst : tensor<14x15x0x9xcomplex> -+ } -+ } -+diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -+--- stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -++++ stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -+@@ -16,7 +16,7 @@ -+ return %cst : tensor<14x15x0x17xf64> -+ } -+ func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { -+- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> -++ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> -+ return %cst : tensor<14x15x0x9xcomplex> -+ } -+ } - diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index 92b2397d3a6afc..4b7db9e5a69eb3 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "84e1e3a76e7b827a6d1621df0c649c3449da1540" - SHARDY_SHA256 = "ddb163c1a40466e320b882821e63ee7cb48e31e96bd05109daff38fca0086f21" + SHARDY_COMMIT = "98555add83dfaa334cd538a401b49130ecacb0d8" + SHARDY_SHA256 = "cafe90437597fedee14f57b3cccea63b689c254748df42c1be0105ed1d64f21f" tf_http_archive( name = "shardy", From 9003f14f63943cbdd111c8376f1c986e23b87995 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 6 Apr 2025 02:03:30 -0700 Subject: [PATCH 0274/1324] Update GraphDef version to 2189. PiperOrigin-RevId: 744411986 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 2497c8e30074fa..366f8c03b00947 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2188 // Updated: 2025/4/5 +#define TF_GRAPH_DEF_VERSION 2189 // Updated: 2025/4/6 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From f152b89e8b2ea945ce8543f697285eb10990de91 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 6 Apr 2025 02:03:32 -0700 Subject: [PATCH 0275/1324] compat: Update forward compatibility horizon to 2025-04-06 PiperOrigin-RevId: 744411995 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index ad80a4ac3acad9..35cb75e8cad846 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. 
It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 5) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 6) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 8a25327043782efaaa6bf40dd3d874e5dcd02512 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Sun, 6 Apr 2025 06:02:47 -0700 Subject: [PATCH 0276/1324] Remove the topology check for compilation, allowing compilation on a client different from the client where the executable will run. PiperOrigin-RevId: 744447533 --- third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc | 6 ------ third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc | 4 ++-- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index 2c08169881592f..50e910f58e02ef 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -81,12 +81,6 @@ absl::Status IsValidTopologyAndClientForCompile( return absl::InvalidArgumentError( "SE:GPU compiler requires a GPU PjRtClient."); } - TF_ASSIGN_OR_RETURN(auto client_topology, client->GetTopologyDescription()); - - if (!IsSameTopology(topology, *client_topology)) { - return absl::UnimplementedError( - "SE:GPU compiler requires the topology same as the one in the client."); - } return absl::OkStatus(); } diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc index e653f15ec26985..a7a56b618be668 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc @@ -101,7 +101,7 @@ TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) { TF_ASSERT_OK_AND_ASSIGN(auto computation, GetXlaComputation(kProgram)); EXPECT_THAT(compiler.Compile(xla::CompileOptions(), computation, topology, client.get()), - StatusIs(absl::StatusCode::kUnimplemented)); + StatusIs(absl::StatusCode::kOk)); } TEST(StreamExecutorGpuCompilerTest, SuccessXla) { @@ -164,7 +164,7 @@ TEST(StreamExecutorGpuCompilerTest, TopologyNotSameMlir) { GetStreamExecutorGpuClient(GpuClientOptions())); EXPECT_THAT(compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology, client.get()), - StatusIs(absl::StatusCode::kUnimplemented)); + StatusIs(absl::StatusCode::kOk)); } TEST(StreamExecutorGpuCompilerTest, SuccessMlir) { From c7b462bd120f190fb83bad5172f30a74de1405bf Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 6 Apr 2025 11:17:50 -0700 Subject: [PATCH 0277/1324] Automated Code Change PiperOrigin-RevId: 744489182 --- .../compiler/mlir/lite/quantization/lite/toco_legacy/BUILD | 1 + .../mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD index c3945abc74f740..4f36cb7e7b3d4a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD @@ -102,6 +102,7 @@ cc_library( "//tensorflow/core/platform:logging", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@flatbuffers//:runtime_cc", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc index b2d6fe97280174..655c1e4deadf91 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers From 7350e4fa621739a3673fe132be4cb522cf642b5e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 6 Apr 2025 18:12:29 -0700 Subject: [PATCH 0278/1324] Automated Code Change PiperOrigin-RevId: 744545449 --- .../toco/graph_transformations/tests/identify_l2_pool_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc index ab487b4cf3bb28..136d8e26575e41 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include From b43757eb62315ae9d95ff82b1046777dd0296fee Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 6 Apr 2025 23:06:47 -0700 Subject: [PATCH 0279/1324] Automated Code Change PiperOrigin-RevId: 744596327 --- .../lite/delegates/hexagon/builders/activation_builder.cc | 2 -- .../lite/delegates/hexagon/builders/arg_min_max_builder.cc | 2 +- .../lite/delegates/hexagon/builders/arithmetic_builder.cc | 2 -- tensorflow/lite/delegates/hexagon/builders/cast_builder.cc | 2 -- tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc | 2 +- tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc | 1 - .../lite/delegates/hexagon/builders/hardswish_builder.cc | 2 -- .../delegates/hexagon/builders/l2_normalization_builder.cc | 2 -- tensorflow/lite/delegates/hexagon/builders/matmul_builder.cc | 2 +- tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc | 2 ++ .../lite/delegates/hexagon/builders/mirror_pad_builder.cc | 2 +- tensorflow/lite/delegates/hexagon/builders/neg_op_builder.cc | 2 +- tensorflow/lite/delegates/hexagon/builders/op_builder.cc | 3 +++ tensorflow/lite/delegates/hexagon/builders/pack_builder.cc | 2 -- tensorflow/lite/delegates/hexagon/builders/pad_builder.cc | 2 -- tensorflow/lite/delegates/hexagon/builders/pool_2d_builder.cc | 2 -- tensorflow/lite/delegates/hexagon/builders/quantize_builder.cc | 2 -- tensorflow/lite/delegates/hexagon/builders/reduce_builder.cc | 2 +- tensorflow/lite/delegates/hexagon/builders/reshape_builder.cc | 2 +- .../lite/delegates/hexagon/builders/resize_bilinear_builder.cc | 3 +++ .../hexagon/builders/resize_nearest_neighbor_builder.cc | 2 -- tensorflow/lite/delegates/hexagon/builders/rsqrt_builder.cc | 2 ++ tensorflow/lite/delegates/hexagon/builders/slice_builder.cc | 1 + tensorflow/lite/delegates/hexagon/builders/softmax_builder.cc | 2 -- .../lite/delegates/hexagon/builders/space_to_depth_builder.cc | 2 -- tensorflow/lite/delegates/hexagon/builders/split_builder.cc | 2 -- .../lite/delegates/hexagon/builders/squared_difference.cc | 2 ++ .../lite/delegates/hexagon/builders/strided_slice_builder.cc | 1 + 28 files changed, 21 insertions(+), 34 deletions(-) diff --git a/tensorflow/lite/delegates/hexagon/builders/activation_builder.cc b/tensorflow/lite/delegates/hexagon/builders/activation_builder.cc index 97202338826f2f..ee17b849706819 100644 --- a/tensorflow/lite/delegates/hexagon/builders/activation_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/activation_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.cc b/tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.cc index 0b3921f1b34622..f4b39a02357651 100644 --- a/tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.h" -#include +#include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/arithmetic_builder.cc b/tensorflow/lite/delegates/hexagon/builders/arithmetic_builder.cc index 4e888de5fc5eb3..c054b86735bee5 100644 --- a/tensorflow/lite/delegates/hexagon/builders/arithmetic_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/arithmetic_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "hexagon/hexagon_nn_ops.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/cast_builder.cc b/tensorflow/lite/delegates/hexagon/builders/cast_builder.cc index 7f624203dae9d0..d4b8adb105e92b 100644 --- a/tensorflow/lite/delegates/hexagon/builders/cast_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/cast_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc b/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc index 17c1ce63718662..0c17d2f0baae5f 100644 --- a/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include +#include #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc b/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc index 58c7bd76fb0239..744d048b699799 100644 --- a/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc +++ b/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include #include #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/hardswish_builder.cc b/tensorflow/lite/delegates/hexagon/builders/hardswish_builder.cc index 5e6ff2699fd1e2..dc42c8f51ed612 100644 --- a/tensorflow/lite/delegates/hexagon/builders/hardswish_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/hardswish_builder.cc @@ -15,8 +15,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/l2_normalization_builder.cc b/tensorflow/lite/delegates/hexagon/builders/l2_normalization_builder.cc index 9e87d4109dba51..e4bb336b6e369f 100644 --- a/tensorflow/lite/delegates/hexagon/builders/l2_normalization_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/l2_normalization_builder.cc @@ -16,8 +16,6 @@ limitations under the License. 
#include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/matmul_builder.cc b/tensorflow/lite/delegates/hexagon/builders/matmul_builder.cc index fa91b50808560e..c242ff8e7d11c8 100644 --- a/tensorflow/lite/delegates/hexagon/builders/matmul_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/matmul_builder.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include +#include #include "hexagon/hexagon_nn_ops.h" #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc b/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc index 9b6103fcc93536..772c52a7f6b4c9 100644 --- a/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/min_max_builder.h" +#include + #include "tensorflow/lite/core/c/common.h" namespace tflite { diff --git a/tensorflow/lite/delegates/hexagon/builders/mirror_pad_builder.cc b/tensorflow/lite/delegates/hexagon/builders/mirror_pad_builder.cc index 353b8a007d65fb..bcce11acd02edd 100644 --- a/tensorflow/lite/delegates/hexagon/builders/mirror_pad_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/mirror_pad_builder.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include +#include #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/neg_op_builder.cc b/tensorflow/lite/delegates/hexagon/builders/neg_op_builder.cc index 93511dc491dad0..715aa3955793c8 100644 --- a/tensorflow/lite/delegates/hexagon/builders/neg_op_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/neg_op_builder.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/neg_op_builder.h" -#include +#include namespace tflite { namespace delegates { diff --git a/tensorflow/lite/delegates/hexagon/builders/op_builder.cc b/tensorflow/lite/delegates/hexagon/builders/op_builder.cc index 91258a418fd326..a3cb4157a5b3eb 100644 --- a/tensorflow/lite/delegates/hexagon/builders/op_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/op_builder.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h" +#include +#include + #include "hexagon/hexagon_nn_ops.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/pack_builder.cc b/tensorflow/lite/delegates/hexagon/builders/pack_builder.cc index 7ccdb299d5d835..9d7cc75f7a9908 100644 --- a/tensorflow/lite/delegates/hexagon/builders/pack_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/pack_builder.cc @@ -16,8 +16,6 @@ limitations under the License. 
#include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/pad_builder.cc b/tensorflow/lite/delegates/hexagon/builders/pad_builder.cc index d49a3de4ab9b42..4047d438f309ca 100644 --- a/tensorflow/lite/delegates/hexagon/builders/pad_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/pad_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/pool_2d_builder.cc b/tensorflow/lite/delegates/hexagon/builders/pool_2d_builder.cc index 45529b68858c30..729d988c24935b 100644 --- a/tensorflow/lite/delegates/hexagon/builders/pool_2d_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/pool_2d_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/quantize_builder.cc b/tensorflow/lite/delegates/hexagon/builders/quantize_builder.cc index 6e653fd70e48fc..078f27161f34e1 100644 --- a/tensorflow/lite/delegates/hexagon/builders/quantize_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/quantize_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/reduce_builder.cc b/tensorflow/lite/delegates/hexagon/builders/reduce_builder.cc index a41a9fb23ee72e..38e3a2e6633de2 100644 --- a/tensorflow/lite/delegates/hexagon/builders/reduce_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/reduce_builder.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include +#include #include #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/reshape_builder.cc b/tensorflow/lite/delegates/hexagon/builders/reshape_builder.cc index 5946abff4d1fd8..58e2cc80f00605 100644 --- a/tensorflow/lite/delegates/hexagon/builders/reshape_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/reshape_builder.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include +#include #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.cc b/tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.cc index 5cdd5398de1b29..8c846b41595946 100644 --- a/tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.cc @@ -14,6 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.h" +#include +#include + #include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { diff --git a/tensorflow/lite/delegates/hexagon/builders/resize_nearest_neighbor_builder.cc b/tensorflow/lite/delegates/hexagon/builders/resize_nearest_neighbor_builder.cc index 7276e9ad4500d9..b21665f30e568d 100644 --- a/tensorflow/lite/delegates/hexagon/builders/resize_nearest_neighbor_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/resize_nearest_neighbor_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/rsqrt_builder.cc b/tensorflow/lite/delegates/hexagon/builders/rsqrt_builder.cc index ad52495f54eaa3..f31800edb01e6c 100644 --- a/tensorflow/lite/delegates/hexagon/builders/rsqrt_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/rsqrt_builder.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include +#include #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/slice_builder.cc b/tensorflow/lite/delegates/hexagon/builders/slice_builder.cc index 05dfd3ffeb070e..149106d4350983 100644 --- a/tensorflow/lite/delegates/hexagon/builders/slice_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/slice_builder.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/slice_builder.h" +#include #include #include "tensorflow/lite/kernels/internal/tensor.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/softmax_builder.cc b/tensorflow/lite/delegates/hexagon/builders/softmax_builder.cc index 9915512856a2d1..28165875621516 100644 --- a/tensorflow/lite/delegates/hexagon/builders/softmax_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/softmax_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/space_to_depth_builder.cc b/tensorflow/lite/delegates/hexagon/builders/space_to_depth_builder.cc index 65e3899b79fe8c..6426fc36a0770b 100644 --- a/tensorflow/lite/delegates/hexagon/builders/space_to_depth_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/space_to_depth_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/split_builder.cc b/tensorflow/lite/delegates/hexagon/builders/split_builder.cc index 6ea35f60114e18..a3a0254df5cd82 100644 --- a/tensorflow/lite/delegates/hexagon/builders/split_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/split_builder.cc @@ -16,8 +16,6 @@ limitations under the License. 
#include -#include - #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/squared_difference.cc b/tensorflow/lite/delegates/hexagon/builders/squared_difference.cc index b040aa0a12b993..51231f07fd79bc 100644 --- a/tensorflow/lite/delegates/hexagon/builders/squared_difference.cc +++ b/tensorflow/lite/delegates/hexagon/builders/squared_difference.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "hexagon/hexagon_nn_ops.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.cc b/tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.cc index 9eabf5334199eb..257e1910455e1d 100644 --- a/tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.h" +#include #include #include "tensorflow/lite/core/c/builtin_op_data.h" From 7d34fcf84dd1d6aa40165cf9cccb81233e339c3a Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 7 Apr 2025 00:06:42 -0700 Subject: [PATCH 0280/1324] Make the xla_compile target work without --config=cuda Half of the XLA CUDA runtime components were guarded by `if_cuda`, the other half by `if_cuda_is_configured`. This change unifies it and puts all CUDA/ROCm dependency behind if_{cuda|rocm}_is_configured. 
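For reference, the behavioral difference between the two guards is roughly the following (a sketch only; the real macros are generated into `@local_config_cuda//cuda:build_defs.bzl` at configure time, so the exact signatures, condition labels, and `cuda_is_configured()` helper shown here are assumptions):

```starlark
# Sketch: if_cuda defers the decision to build time via select(), so its
# dependencies are only linked when building with --config=cuda.
def if_cuda(if_true, if_false = []):
    return select({
        "@local_config_cuda//:is_cuda_enabled": if_true,
        "//conditions:default": if_false,
    })

# Sketch: if_cuda_is_configured is resolved at loading time, so its
# dependencies are linked whenever the workspace was configured for CUDA,
# independent of the build command line. cuda_is_configured() stands in
# for the generated configure-time constant.
def if_cuda_is_configured(if_true, if_false = []):
    return if_true if cuda_is_configured() else if_false
```

Mixing the two styles meant half of the runtime dependencies dropped out whenever --config=cuda was omitted, which is what broke the xla_compile target; guarding everything with if_{cuda|rocm}_is_configured makes the dependency set depend only on how the workspace was configured.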
PiperOrigin-RevId: 744608636 --- third_party/xla/xla/service/BUILD | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index ba631fa150c014..14220b62ef0f64 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -5,7 +5,6 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load( "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm", "if_rocm_is_configured", ) load( @@ -6219,23 +6218,23 @@ xla_cc_binary( "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:types", ] + if_cuda_is_configured([ + # keep sorted "//xla/service/gpu:executable_proto_cc", "//xla/service/gpu:gpu_compiler", "//xla/service/gpu:nvptx_compiler", "//xla/service/gpu:nvptx_compiler_impl", - "//xla/stream_executor/gpu:gpu_init", + "//xla/stream_executor/cuda:all_runtime", "//xla/stream_executor/cuda:cuda_platform", + "//xla/stream_executor/gpu:gpu_init", ]) + if_rocm_is_configured([ - "//xla/service/gpu:executable_proto_cc", - "//xla/service/gpu:gpu_compiler", + # keep sorted "//xla/service/gpu:amdgpu_compiler", "//xla/service/gpu:amdgpu_compiler_impl", + "//xla/service/gpu:executable_proto_cc", + "//xla/service/gpu:gpu_compiler", "//xla/stream_executor/gpu:gpu_init", - "//xla/stream_executor/rocm:rocm_platform", - ]) + if_cuda([ - "//xla/stream_executor/cuda:all_runtime", - ]) + if_rocm([ "//xla/stream_executor/rocm:all_runtime", + "//xla/stream_executor/rocm:rocm_platform", ]) + xla_internal(["tools:xsymbol_repository"]), ) From f9079b5d0d9b05f9050aae6a8f145cfc4d8bffe3 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Mon, 7 Apr 2025 00:31:13 -0700 Subject: [PATCH 0281/1324] [XLA:GPU] Do not rewrite reshape+transpose+reshape as slices+concatenate It does more harm than good as later passes can reason about transpose much better than slices, and emitters can handle transpose efficiently just fine. PiperOrigin-RevId: 744613962 --- .../simplifiers/algebraic_simplifier.cc | 1 + .../simplifiers/algebraic_simplifier.h | 9 +++++++ .../simplifiers/algebraic_simplifier_test.cc | 25 +++++++++++++++++++ .../xla/xla/service/gpu/gpu_compiler.cc | 11 +++++--- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index 59a0d5bd69cf0f..398d8dead09791 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -9034,6 +9034,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleTranspose( // the reshape/transpose combination can be interpreted as a space-to-depth // transformation.
if (!options_.is_layout_sensitive() && + options_.rewrite_reshape_transpose_as_slice_concatenate() && operand->opcode() == HloOpcode::kReshape && transpose->user_count() == 1 && HloOpcode::kReshape == transpose->users()[0]->opcode()) { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h index fc1489daffacc9..d39680ed40dbc8 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h @@ -338,6 +338,14 @@ class AlgebraicSimplifierOptions { enable_onednn_support_ = enable_onednn_support; } + bool rewrite_reshape_transpose_as_slice_concatenate() const { + return rewrite_reshape_transpose_as_slice_concatenate_; + } + + void set_rewrite_reshape_transpose_as_slice_concatenate(bool value) { + rewrite_reshape_transpose_as_slice_concatenate_ = value; + } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplifierOptions that can be later used in an @@ -393,6 +401,7 @@ class AlgebraicSimplifierOptions { false #endif // INTEL_MKL }; + bool rewrite_reshape_transpose_as_slice_concatenate_{true}; Metadata metadata_; }; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc index 2e86cd3d90694d..5d0d1705129942 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc @@ -6312,6 +6312,31 @@ ENTRY entry { .WithShapeEqualTo(&result_shape))); } +// Test that the transformation above doesn't happen when disabled. +TEST_F(AlgebraicSimplifierTest, DisabledTransposeReshapeToConcatSlice) { + const std::string& hlo_string = R"( +HloModule TransposeReshapeDepthToSpace + +ENTRY entry { + %param = f32[8,14,14,128] parameter(0) + %reshape.1 = f32[8,14,14,2,64] reshape(%param) + %transpose = transpose(%reshape.1), dimensions={0,1,3,2,4} + ROOT %reshape.2 = f32[8,28,14,64] reshape(%transpose) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options = default_options_; + options.set_rewrite_reshape_transpose_as_slice_concatenate(false); + AlgebraicSimplifier simplifier(options); + ASSERT_FALSE(simplifier.Run(module.get()).value()); + + Shape result_shape = ShapeUtil::MakeShape(F32, {8, 28, 14, 64}); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Reshape(m::Transpose(m::Reshape(m::Parameter(0)))))); +} + // Test that a depth-to-space transformation expressed as // reshape(transpose(reshape(op))) with a large number of chunks // is not rewritten. diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 65681113f4d6e2..3a9dd3a257b5cc 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -582,6 +582,10 @@ AlgebraicSimplifierOptions LayoutInsensitiveAlgebraicSimplifierOptions( } layout_insensitive_algsimp_opts .set_enable_unconditional_reduce_of_concat_replacement(false); + // GPU pipeline handles transposes better than slice+concatenate, so keep + // the transpose. 
+ layout_insensitive_algsimp_opts .set_rewrite_reshape_transpose_as_slice_concatenate(false); return layout_insensitive_algsimp_opts; } @@ -1742,11 +1746,10 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( HloPassPipeline& remove_no_op_reduce_precision_pipeline = pipeline.AddPass( "remove-no-op-reduce-precision-algebraic-simplifier"); - AlgebraicSimplifierOptions simplifier_options_{simplifier_options}; - simplifier_options_.set_enable_remove_no_op_reduce_precision(true); + AlgebraicSimplifierOptions options{simplifier_options}; + options.set_enable_remove_no_op_reduce_precision(true); remove_no_op_reduce_precision_pipeline - .AddPass>(simplifier_options_, - gpu_version); + .AddPass>(options, gpu_version); } pipeline.AddPass(/*is_layout_sensitive=*/true); From 3886ddd80cef6a6c03db5fe900d17dc684bed994 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Mon, 7 Apr 2025 01:05:07 -0700 Subject: [PATCH 0282/1324] [XLA:GPU] Fix int4 test for Ampere. On Ampere, multiply has an F32 return type. PiperOrigin-RevId: 744621796 --- .../gpu/codegen/triton/fusion_emitter_int4_device_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_int4_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_int4_device_test.cc index 391895ea67b268..69b37552c9a816 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_int4_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_int4_device_test.cc @@ -559,8 +559,9 @@ TEST_F(TritonTest, FuseMultiplyInPrologue) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + // On Ampere the multiply result type is f32, on Hopper it is bf16. EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( - CHECK: %[[multiply:.*]] = bf16[32,64,128]{2,1,0} multiply + CHECK: %[[multiply:.*]] = [[type:.*]][32,64,128]{2,1,0} multiply CHECK: %[[dot:.*]] = f32[32,128,256]{2,1,0} dot CHECK: ENTRY %main )")); From 2d7a9461402c6cbb7e45f7f9a315f610fc952dd5 Mon Sep 17 00:00:00 2001 From: "A.
Unique TensorFlower" Date: Mon, 7 Apr 2025 01:12:19 -0700 Subject: [PATCH 0283/1324] Automated Code Change PiperOrigin-RevId: 744623706 --- .../gpu/codegen/emitters/in_place_dynamic_update_slice.cc | 4 ++-- .../xla/xla/backends/gpu/codegen/emitters/transpose.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/in_place_dynamic_update_slice.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/in_place_dynamic_update_slice.cc index 83770fbd85c3cf..524313a47ea62a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/in_place_dynamic_update_slice.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/in_place_dynamic_update_slice.cc @@ -139,9 +139,9 @@ absl::Status InPlaceDynamicUpdateSliceFusion::EmitEntryFunction( auto start_indices = ProvideParameterRange(root_computation, dus_instr, dus_instr->first_index_operand_number(), - update_shape.dimensions_size(), {}, + update_shape.dimensions().size(), {}, call_targets, entry_function, nested_b); - for (int i = 0; i < update_shape.dimensions_size(); ++i) { + for (int i = 0; i < update_shape.dimensions().size(); ++i) { int64_t update_size = update_shape.dimensions(i); auto start_index = ClampIndex( start_indices[i], diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.cc index 5bb9d41b4f9e3d..1fcea8b9571855 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.cc @@ -91,7 +91,7 @@ TransposeSpec GetTransposeSpec(const HloTransposeInstruction* transpose) { // If the last dimension is transposed, add a size-1 B dimension. if (canonical_permutation.back() != canonical_output_shape.size() - 1) { - canonical_permutation.push_back(output_shape.dimensions_size()); + canonical_permutation.push_back(output_shape.dimensions().size()); canonical_output_shape.push_back(1); } int64_t dim_t1 = -1; From 88e53ae74c8da5116e0d8f021de57b937cb3808f Mon Sep 17 00:00:00 2001 From: Chase Riley Roberts Date: Mon, 7 Apr 2025 01:31:34 -0700 Subject: [PATCH 0284/1324] PR #24269: Fix for fusion wrapper on async computations Imported from GitHub PR https://github.com/openxla/xla/pull/24269 Previously, in JAX if you do a simple stream annotated computation with just a single instruction, i.e. ```python @compute_on('gpu_stream:1') @jax.jit def h(x, y): return x * y ``` Then the XLA compiler would fail with an error like ``` jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unsupported instruction opcode: multiply ``` This is because the fusion wrapper would not correctly look inside of async computations, leaving the instructions unfused. To fix this, we simply add `kCall` and `kAsyncStart` to the set of instructions that are treated recursively in the fusion wrapper. 
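In other words, the wrapper's dispatch after this change looks roughly like the following (condensed from the `fusion_wrapper_base.cc` hunk below; `WrapInstructionInFusion` is a stand-in name for the existing non-recursive wrapping path, and error handling is elided):

```cpp
// Condensed sketch of the per-instruction handler in FusionWrapperBase::Run.
// kCall and kAsyncStart are the newly recursed-into opcodes.
std::function<absl::Status(HloInstruction*)> handle_instruction;
handle_instruction = [&](HloInstruction* instruction) -> absl::Status {
  const HloOpcode opcode = instruction->opcode();
  if (opcode == HloOpcode::kConditional || opcode == HloOpcode::kWhile ||
      opcode == HloOpcode::kCall || opcode == HloOpcode::kAsyncStart) {
    // Recurse into the called computations so that a lone instruction in a
    // call/async body (like the single multiply above) gets wrapped too.
    for (HloComputation* computation : instruction->called_computations()) {
      for (HloInstruction* inner : computation->MakeInstructionPostOrder()) {
        TF_RETURN_IF_ERROR(handle_instruction(inner));
      }
    }
    return absl::OkStatus();
  }
  // Otherwise wrap fusible instructions into single-instruction fusions.
  return WrapInstructionInFusion(instruction);  // stand-in for the real path
};
```

Conditionals and while loops were already recursed into this way; the fix simply extends the same treatment to call and async-start bodies.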
Copybara import of the project: -- bef9c2829a585727094c270b6c7e1e05ce304819 by chaserileyroberts : Added fix for some fusion issues using compute_on -- 1d909f1fb4a5bd11053be35f15e42796f7adb643 by chaserileyroberts : Update fusion wrapper test -- 0e39915f444607c5518bb7a5ce4a168801bfee8d by chaserileyroberts : Added {0} to dtype descriptions -- 5e51b3897c45354e929b061d131f8b2a6efde36a by chaser : xla_cc_test -> xla_test -- 37df8115138ba07ed3f735787f62830d4abafffc by chaser : RunAndCompare -> RunAndCompareNoHloPasses Merging this change closes #24269 PiperOrigin-RevId: 744627904 --- .../codegen/emitters/fusion_wrapper_base.cc | 3 +- third_party/xla/xla/service/gpu/tests/BUILD | 6 +-- .../gpu/tests/async_kernel_launch_test.cc | 53 +++++++++++++++++++ .../gpu/transforms/fusion_wrapper_test.cc | 29 ++++++++++ 4 files changed, 87 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/codegen/emitters/fusion_wrapper_base.cc b/third_party/xla/xla/codegen/emitters/fusion_wrapper_base.cc index 759b038f40d0a0..9d0ef0b98da911 100644 --- a/third_party/xla/xla/codegen/emitters/fusion_wrapper_base.cc +++ b/third_party/xla/xla/codegen/emitters/fusion_wrapper_base.cc @@ -37,7 +37,8 @@ absl::StatusOr FusionWrapperBase::Run( std::function handle_instruction; handle_instruction = [&](HloInstruction* instruction) -> absl::Status { const HloOpcode opcode = instruction->opcode(); - if (opcode == HloOpcode::kConditional || opcode == HloOpcode::kWhile) { + if (opcode == HloOpcode::kConditional || opcode == HloOpcode::kWhile || + opcode == HloOpcode::kCall || opcode == HloOpcode::kAsyncStart) { for (auto* computation : instruction->called_computations()) { for (auto* inner_instruction : computation->MakeInstructionPostOrder()) { diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 501f0a379203f7..d107b979fe02e0 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -117,17 +117,17 @@ xla_test( ], ) -xla_cc_test( +xla_test( name = "async_kernel_launch_test", srcs = ["async_kernel_launch_test.cc"], + backends = ["gpu"], # "requires-net:external" tag allows uploading `xprof` results. 
- tags = if_google(["requires-net:external"]) + tf_cuda_tests_tags(), + tags = if_google(["requires-net:external"]), deps = [ "//xla:debug_options_flags", "//xla:literal", "//xla:literal_util", "//xla:xla_proto_cc", - "//xla/service:gpu_plugin", "//xla/service:hlo_module_config", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", diff --git a/third_party/xla/xla/service/gpu/tests/async_kernel_launch_test.cc b/third_party/xla/xla/service/gpu/tests/async_kernel_launch_test.cc index 7aac65a38121ce..4acb8ec495a783 100644 --- a/third_party/xla/xla/service/gpu/tests/async_kernel_launch_test.cc +++ b/third_party/xla/xla/service/gpu/tests/async_kernel_launch_test.cc @@ -77,5 +77,58 @@ TEST_F(AsyncKernelLaunchTest, BasicFusion) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_F(AsyncKernelLaunchTest, BasicAsyncComputation) { + const char* hlo_text = R"( + HloModule Test1 + + add_F32 { + lhs = f32[2]{0} parameter(0) + rhs = f32[2]{0} parameter(1) + ROOT add = f32[2]{0} add(lhs, rhs) + } + + ENTRY Test1 { + a = f32[2]{0} parameter(0) + b = f32[2]{0} parameter(1) + start = ((f32[2]{0}, f32[2]{0}), f32[2]{0}) call-start(a, b), to_apply=add_F32 + ROOT done = f32[2]{0} call-done(start) + } + )"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(AsyncKernelLaunchTest, ScheduledOverlappingAsyncComputations) { + const char* hlo_text = R"( + HloModule Test1 + + add { + lhs = f32[2]{0} parameter(0) + rhs = f32[2]{0} parameter(1) + ROOT add = f32[2]{0} add(lhs, rhs) + } + + mul { + lhs = f32[2]{0} parameter(0) + rhs = f32[2]{0} parameter(1) + ROOT mul = f32[2] multiply(lhs, rhs) + } + + ENTRY Test1 { + a = f32[2]{0} parameter(0) + b = f32[2]{0} parameter(1) + start = ((f32[2]{0}, f32[2]{0}), f32[2]{0}) call-start(a, b), to_apply=add, + frontend_attributes={_xla_stream_annotation="1", _scheduling_group_id="0"} + start.1 = ((f32[2]{0}, f32[2]{0}), f32[2]{0}) call-start(a, b), to_apply=mul, + frontend_attributes={_xla_stream_annotation="2", _scheduling_group_id="0"} + done = f32[2]{0} call-done(start) + done.1 = f32[2]{0} call-done(start.1) + ROOT result = f32[2]{0} add(done, done.1) + } + )"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc index 84421328cabd66..e1326bb27cfc47 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc @@ -236,6 +236,35 @@ TEST_F(FusionWrapperTest, WhileInFusion) { std::nullopt); } +TEST_F(FusionWrapperTest, AsyncComputationFusion) { + RunAndFilecheckHloRewrite(R"( + HloModule AsyncComputation + + mul { + a = f32[5] parameter(0) + ROOT b = f32[5] multiply(a, a) + } + + ENTRY %main { + parameter = f32[5] parameter(0) + start = ((f32[5]), f32[5]) call-start(parameter), to_apply=mul + ROOT done = f32[5] call-done(start) + })", + FusionWrapper(device_description()), R"( +//CHECK: %wrapped_multiply_computation {{.*}} { +//CHECK-NEXT: %[[P0:.*]] = {{.*}} parameter(0) +//CHECK-NEXT: ROOT {{.*}} multiply(%[[P0]], %[[P0]]) +//CHECK-NEXT: } +//CHECK: %mul {{.*}} { +//CHECK-NEXT: %[[P0:.*]] = {{.*}} parameter(0) +//CHECK-NEXT: ROOT {{.*}} fusion(%[[P0]]), kind=kLoop, calls=%wrapped_multiply_computation +//CHECK-NEXT: } +//CHECK: ENTRY %main {{.*}} { +//CHECK-NEXT: %[[P0:.*]] = {{.*}} parameter(0) +//CHECK-NEXT: {{.*}} 
call-start(%[[P0]]), to_apply=%mul +)"); +} + } // namespace } // namespace gpu } // namespace xla From d624ad359526bf12a897a0e84eba870df0c1552e Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 7 Apr 2025 01:32:39 -0700 Subject: [PATCH 0285/1324] [XLA:GPU] Rewrite and enable some tests in the port of the legacy matmul tests to the generic Triton emitter. Uncover a couple of things we need to fix: 1. We need to support some actual mixed type `dot`s (at least `f8xf8->f32`); 2. There is likely a bug in `bitcast` hoisting through `broadcast`s, since the HLO no longer verifies after hoisting in `DISABLED_NoTF32For8BitOrLessWithF32`. I suspect the element type might end up not being set correctly. PiperOrigin-RevId: 744628196 --- .../fusion_emitter_device_legacy_port_test.cc | 105 +++++++++++------- .../gpu/autotuning/gemm_fusion_autotuner.cc | 13 ++- .../xla/xla/service/gpu/backend_configs.proto | 2 +- .../xla/xla/service/gpu/gpu_compiler.cc | 13 ++- .../gpu/transforms/nest_gemm_fusion.cc | 10 +- 5 files changed, 89 insertions(+), 54 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index a8603e0b9b89d3..9ab252183f04e3 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -198,6 +198,8 @@ ENTRY e { EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); } +// TODO(bchetioui): there is already a change out to fix this, enable once it +// lands. TEST_F(TritonTest, DISABLED_TestGemmWithTrivialNonContractingDimension) { constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true @@ -3946,20 +3948,24 @@ ENTRY e { // This test could be modified to allow TF32 once this bug is fixed. // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. +// +// TODO(b/393299275): this test uncovers a bug in hoisting bitcasts through +// broadcasts (seems to generate a type mismatch). 
TEST_F(TritonTest, DISABLED_NoTF32For8BitOrLessWithF32) { const std::string hlo_text = R"( HloModule t triton_dot { parameter_0 = s32[11,24]{1,0} parameter(0) - broadcast.1747 = s32[11,24,128]{2,1,0} broadcast(parameter_0), - dimensions={0,1} parameter_1 = s32[11,24,128]{2,1,0} parameter(1) - compare.49 = pred[11,24,128]{2,1,0} compare(broadcast.1747, parameter_1), + broadcast = s32[11,24,128]{2,1,0} broadcast(parameter_0), + dimensions={0,1} + parameter_1 = s32[11,24,128]{2,1,0} parameter(1) + compare = pred[11,24,128]{2,1,0} compare(broadcast, parameter_1), direction=EQ - bitcast.4717 = pred[264,128]{1,0} bitcast(compare.49) - convert.142 = f32[264,128]{1,0} convert(bitcast.4717) + bitcast = pred[264,128]{1,0} bitcast(compare) + convert = f32[264,128]{1,0} convert(bitcast) parameter_2 = f32[128,8]{1,0} parameter(2) - ROOT dot.381 = f32[264,8]{1,0} dot(convert.142, parameter_2), + ROOT dot = f32[264,8]{1,0} dot(convert, parameter_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } @@ -3974,16 +3980,23 @@ ENTRY e { "split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} })"; - + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(hlo_text)); TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, hlo_text, "triton_dot", R"( + CreateTritonIrAndFileCheck(*module_and_metadata.computation, + module_and_metadata.block_level_parameters, + R"( CHECK: tt.dot CHECK-NOT: inputPrecision = tf32 )")); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompareNoHloPasses( + hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +// TODO(b/393299275): this test requires us to allow actual mixed type GEMMs +// in the lowering. We need to expand support tests and the lowering to model +// mixed types as needed. (f8e4m3fn x f8e4m3fn -> f32) TEST_F(TritonTest, DISABLED_Fp8LoweringIsSupportedPostHopper) { if (!GetCudaComputeCapability().IsAtLeastHopper()) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; @@ -4010,14 +4023,22 @@ ENTRY main { "num_stages":"4","num_warps":"4","num_ctas":"1"}}} })"; + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(hlo_text)); TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, hlo_text, "triton_dot", R"( + CreateTritonIrAndFileCheck(*module_and_metadata.computation, + module_and_metadata.block_level_parameters, + R"( CHECK: tt.dot {{.*}}{maxNumImpreciseAcc = 2147483647 : i32} : tensor<128x64xf8E4M3FN> * tensor<64x32xf8E4M3FN> -> tensor<128x32xf32> )")); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, + ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); } +// TODO(b/393299275): this test requires us to allow actual mixed type GEMMs +// in the lowering. We need to expand support tests and the lowering to model +// mixed types as needed. (f8e4m3fn x f8e4m3fn -> f32) TEST_F(TritonTest, DISABLED_BF16ToFP8EndToEnd) { if (!GetCudaComputeCapability().IsAtLeastHopper()) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; @@ -4048,6 +4069,9 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); } +// TODO(b/393299275): this test requires us to allow actual mixed type GEMMs +// in the lowering. We need to expand support tests and the lowering to model +// mixed types as needed. 
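The b/393299275 TODOs in this commit all gate on supporting genuinely mixed-type `dot`s such as `f8e4m3fn x f8e4m3fn -> f32`, where operands stay in an 8-bit format while accumulation happens in f32. As a rough, library-free sketch of that computation shape (the fake "f8" below is just a scaled `int8_t`, not the real e4m3 bit layout):

```cpp
// Loose illustration of a mixed-type dot: 8-bit storage, f32 accumulation.
#include <cstdint>
#include <iostream>
#include <vector>

constexpr float kScale = 0.25f;  // pretend per-tensor scale for the fake "f8"

float MixedTypeDot(const std::vector<int8_t>& lhs,
                   const std::vector<int8_t>& rhs) {
  float acc = 0.0f;  // the accumulator is f32, not the 8-bit storage type
  for (size_t i = 0; i < lhs.size(); ++i) {
    acc += (lhs[i] * kScale) * (rhs[i] * kScale);
  }
  return acc;
}

int main() {
  std::vector<int8_t> a = {4, 8, 12};  // represents 1.0, 2.0, 3.0
  std::vector<int8_t> b = {4, 4, 4};   // represents 1.0, 1.0, 1.0
  std::cout << MixedTypeDot(a, b) << "\n";  // 1*1 + 2*1 + 3*1 = 6
}
```

The point is only that the accumulator type, not the storage type, carries the precision; that is what the lowering has to model before the disabled FP8 tests that follow can be re-enabled.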
TEST_F(TritonTest, DISABLED_FP8ToFP8EndToEnd) { if (!GetCudaComputeCapability().IsAtLeastHopper()) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; @@ -4074,22 +4098,27 @@ ENTRY main { {"block_m":"32","block_n":"32","block_k":"32","split_k":"1", "num_stages":"1","num_warps":"4","num_ctas":"1"}}} })"; - + ASSERT_TRUE( + GetDebugOptionsForTest() + .xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms()); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); } // Test PreventMmaV3LoopUnrolling pass in order to keep compile time low. // See b/344841434. -TEST_F(TritonGemmTest, DISABLED_TestPreventMMAV3LoopUnrolling) { +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. +TEST_F(TritonGemmTest, TestPreventMMAV3LoopUnrolling) { if (GetCudaComputeCapability().major != se::CudaComputeCapability::kHopper) { GTEST_SKIP() << "wgmma instruction is only available on Hopper"; } const std::string hlo_text = R"( gemm_fusion_dot { - %p0 = f16[64,1024]{1,0} parameter(0) - %p1 = f16[1024,32,32]{2,1,0} parameter(1) - %bitcast.74246 = f16[1024,1024]{0,1} bitcast(f16[1024,32,32]{2,1,0} %p1) - ROOT %dot.1302 = f16[64,1024]{1,0} dot(f16[64,1024]{1,0} %p0, f16[1024,1024]{0,1} %bitcast.74246), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"} + p0 = f16[64,1024]{1,0} parameter(0) + p1 = f16[1024,32,32]{2,1,0} parameter(1) + bitcast = f16[1024,1024]{0,1} bitcast(p1) + ROOT dot = f16[64,1024]{1,0} dot(p0, bitcast), + lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY e { @@ -4103,19 +4132,22 @@ ENTRY e { "split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} })"; + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, - ParseAndReturnVerifiedModule(hlo_text)); - CompileAndOptionallyVerifyPtx(std::move(verified_module), + CompileAndOptionallyVerifyPtx(std::move(module_and_metadata.module), R"( R"( CHECK: $L__BB0_1: CHECK-NEXT: // begin inline asm CHECK-NEXT: .pragma "nounroll"; CHECK: wgmma -)"); +)", + /*run_optimization_passes=*/false); } -TEST_F(TritonGemmTest, DISABLED_WgmmaIsUsedForMemBoundShape) { +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. +TEST_F(TritonGemmTest, WgmmaIsUsedForMemBoundShape) { if (GetCudaComputeCapability().major != se::CudaComputeCapability::kHopper) { GTEST_SKIP() << "wgmma instruction is only available on Hopper"; } @@ -4139,21 +4171,21 @@ ENTRY e { "num_ctas":1}}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, - ParseAndReturnVerifiedModule(hlo_text)); - CompileAndOptionallyVerifyPtx(std::move(verified_module), R"( + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(hlo_text)); + + CompileAndOptionallyVerifyPtx(std::move(module_and_metadata.module), R"( CHECK: wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 -)"); +)", + /*run_optimization_passes=*/false); } -// Test presence of default matmul config information -// when gemm autotuner is not present in pipeline, -// (which is currently the case on rocm). 
-TEST_F(TritonGemmTest, DISABLED_TestNoAutotuner) { - if (std::holds_alternative( - GpuComputeCapability())) { - GTEST_SKIP() << "Autotuner is always in pipeline on Cuda."; - } +// Test presence of default matmul config information when the GEMM autotuner is +// not present in the compilation pipeline (which is always the case on ROCM). +// +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. +TEST_F(TritonGemmTest, TestNoAutotuner) { constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[30,30] parameter(0) @@ -4174,11 +4206,8 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_gemm +; CHECK-SAME: __triton_nested_gemm_fusion )"); - - EXPECT_TRUE(RunAndCompare(std::move(verified_module), - ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } } // namespace diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 6370fae5097a9a..bcc2a05fedd748 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -331,12 +331,6 @@ absl::StatusOr> TritonGemmAutotuneExtractor( *backend_config.mutable_triton_gemm_config() = config.ToProto(); TF_RETURN_IF_ERROR(cloned_dot_fusion->set_backend_config(gpu_config)); - if (debug_opts - .xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms()) { - NestGemmFusion nest_gemm_fusion(gpu_device_info.gpu_compute_capability()); - TF_RETURN_IF_ERROR(nest_gemm_fusion.Run(new_module.get()).status()); - } - if (config.split_k > 1) { TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config)); for (PrimitiveType type : @@ -356,6 +350,13 @@ absl::StatusOr> TritonGemmAutotuneExtractor( FusionWrapper fusion_wrapper(gpu_device_info); TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status()); } + + if (debug_opts + .xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms()) { + NestGemmFusion nest_gemm_fusion(gpu_device_info.gpu_compute_capability()); + TF_RETURN_IF_ERROR(nest_gemm_fusion.Run(new_module.get()).status()); + } + return new_module; } diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 515f672acbf3a9..329cb7f44c0cc4 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -209,7 +209,7 @@ message FusionBackendConfig { // present, we use the default Triton config. AutotuneResult.TritonGemmKey triton_gemm_config = 2; - // Only valid when kind is "__triton" or "__triton_nested_fusion_gemm". Code + // Only valid when kind is "__triton" or "__triton_nested_gemm_fusion". Code // generation of such fusions will fail if this field is not set. 
BlockLevelFusionConfig block_level_fusion_config = 6; diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 3a9dd3a257b5cc..be36336e74f25b 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1674,11 +1674,6 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( options.key_value_store, gpu_target_config.device_description.runtime_version())); - if (debug_options - .xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms()) { - pipeline.AddPass( - gpu_target_config.device_description.gpu_compute_capability()); - } // Inline back the calls which have better performance with cuBLAS. pipeline.AddPass( /*single_call_site=*/false, /*update_domain=*/false, @@ -1721,6 +1716,14 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // normalized again. add_float_normalization(pipeline); + // Match the location of this pass in `gemm_fusion_autotuner.cc` to make sure + // that there is no discrepancy. + if (debug_options + .xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms()) { + pipeline.AddPass( + gpu_target_config.device_description.gpu_compute_capability()); + } + // Clean up new_tuple described above. pipeline.AddPass(); diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index 792d38c8a62384..ec2530a6c098d1 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -779,10 +779,12 @@ class NestGemmFusionVisitor : public DfsHloRewriteVisitor { // switch the track. Thus it is on us to make sure that the generic emitter // will be able to handle the result. That is an early check to make sure // that that nesting did not produce an unsupported HLO. - if (!IsTritonSupportedComputation(*computation, compute_capability_)) { - return absl::InternalError(absl::StrCat("Computation of fusion ", - fusion->ToString(), - " is not supported by Triton.")); + CodegenDecision can_codegen_computation = + IsTritonSupportedComputation(*computation, compute_capability_); + if (!can_codegen_computation) { + return absl::InternalError(absl::StrCat( + "Computation of fusion ", fusion->ToString(), + " is not supported by Triton: ", can_codegen_computation.Explain())); } return absl::OkStatus(); } From d4ffd17fa69c46fa64337152b410b7aa37aeb09c Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 7 Apr 2025 01:35:15 -0700 Subject: [PATCH 0286/1324] Automated Code Change PiperOrigin-RevId: 744628857 --- .../backends/gpu/codegen/emitters/ir/xla_gpu_ops.cc | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.cc index c7d5c5f61880c5..c916b8e6dfbfcc 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.cc @@ -54,33 +54,22 @@ namespace { using llvm::ArrayRef; using mlir::AffineExpr; -using mlir::AffineMap; -using mlir::Block; using mlir::DenseI64ArrayAttr; using mlir::failure; -using mlir::getAffineConstantExpr; -using mlir::getAffineDimExpr; -using mlir::getAffineSymbolExpr; using mlir::Location; using mlir::LogicalResult; using mlir::MLIRContext; using mlir::OpAsmParser; using mlir::OpAsmPrinter; -using mlir::OpBuilder; using mlir::OperationState; using mlir::ParseResult; -using mlir::PatternRewriter; using mlir::RankedTensorType; -using mlir::Region; using mlir::SmallVector; using mlir::success; using mlir::Type; using mlir::TypeRange; -using mlir::Value; using mlir::ValueRange; -namespace arith = mlir::arith; - } // namespace //===----------------------------------------------------------------------===// From 1fe54384d883b6b5334b6a699cfbc139a8541bdd Mon Sep 17 00:00:00 2001 From: zoranjovanovic-ns <126815388+zoranjovanovic-ns@users.noreply.github.com> Date: Mon, 7 Apr 2025 01:39:56 -0700 Subject: [PATCH 0287/1324] PR #24573: [ROCm] Fixed issue with builtin.module error. Imported from GitHub PR https://github.com/openxla/xla/pull/24573 Copybara import of the project: -- 82ab4adb53e16f93a439c34bb18034d0a7201922 by Zoran Jovanovic : [ROCm] Fixed issue with builtin.module error. Merging this change closes #24573 PiperOrigin-RevId: 744630070 --- .../triton/compilation_pipeline_rocm.cc | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc index 9b75ab1f1eb9b7..21919310c67dfd 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc @@ -40,10 +40,6 @@ limitations under the License. namespace xla { namespace gpu { -// Value 0 for num_stages is used to represent AMD specific register -// file double buffering. 
-constexpr int kAmdDoubleBuffering = 0; - namespace ma = ::mlir::arith; namespace mm = ::mlir::math; namespace ml = ::mlir::LLVM; @@ -90,41 +86,49 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, pm->addPass(mt::gpu::createTritonGPUCoalesce()); pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm->addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); - pm->addPass(mlir::createTritonAMDGPUAccelerateMatmulPass(cc.gfx_version())); + // TODO ROCm Pass cc.gfx_version() after fixing issue with fmfa + pm->addPass(mlir::createTritonAMDGPUAccelerateMatmulPass()); pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); // TODO ROCm Check if we want to compare MI100 and greater pm->addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass()); pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); - pm->addPass(mlir::createTritonAMDGPUHoistLayoutConversionsPass()); + pm->addNestedPass( + mlir::createTritonAMDGPUHoistLayoutConversionsPass()); pm->addPass(mt::gpu::createTritonGPUFuseNestedLoops()); pm->addPass(mlir::createCSEPass()); pm->addPass(mlir::createLoopInvariantCodeMotionPass()); pm->addPass(mlir::createCanonicalizerPass()); - if (num_stages == kAmdDoubleBuffering && cc.has_amd_matrix_core()) { - pm->addPass(mlir::createTritonAMDGPUStreamPipelinePass( - num_stages, /*stream_prefetch=*/true, /*use_async_copy=*/true)); - pm->addPass(mlir::createTritonAMDGPUCoalesceAsyncCopyPass()); + if (cc.has_amd_matrix_core()) { + pm->addPass(mlir::createTritonAMDGPUStreamPipelinePass(num_stages)); + // TODO(ROCm) Modify when corresponding run time flags are introduced. + if (/*use_async_copy=*/false) { // Not enabled by default. + pm->addPass(mlir::createTritonAMDGPUCoalesceAsyncCopyPass()); + } pm->addPass(mlir::createCanonicalizerPass()); } - pm->addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass("default")); + if (/*(instruction_sched_variant=="none") == */ false) { + pm->addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass("none")); + } pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm->addPass(mt::gpu::createTritonGPUReduceDataDuplication()); - if (false) { // Not enabled by default. + if (/*(instruction_sched_variant=="none") == */ false) { pm->addPass(mlir::createTritonAMDGPUInThreadTransposePass()); pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); } - if (num_stages != kAmdDoubleBuffering) { + if (cc.has_amd_matrix_core()) { pm->addPass(mt::gpu::createTritonGPUReorderInstructions()); } - if (false) { // For upstream, this is enabled iff arch == gfx942. + if (/*(use_block_pingpong == "none") ==*/false) { pm->addPass(mlir::createTritonAMDGPUBlockPingpongPass(num_stages)); } - pm->addPass(mlir::createTritonAMDGPUCanonicalizePointersPass()); - pm->addPass(mlir::createCanonicalizerPass()); - pm->addPass(mlir::createTritonAMDGPUConvertToBufferOpsPass(arch_name)); + if (/*use_buffer_ops=*/false) { // Not enabled by default. 
+ pm->addPass(mlir::createTritonAMDGPUCanonicalizePointersPass()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createTritonAMDGPUConvertToBufferOpsPass(arch_name)); + } pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mlir::createCSEPass()); pm->addPass(mlir::createSymbolDCEPass()); @@ -146,8 +150,10 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mlir::createCSEPass()); pm->addPass(mlir::createSymbolDCEPass()); - pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass( - cc.gfx_version(), num_stages)); + if (/*(instruction_sched_variant=="none") == */ false) { + pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass( + cc.gfx_version(), num_stages)); + } pm->addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true)); // There is no clusters in ROCm for now. out_cluster_info.clusterDimX = 1; From c20123a9cfcc4ae6160137d0336de7c7cfbee73f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 01:42:35 -0700 Subject: [PATCH 0288/1324] Replace outdated select() on --cpu in tensorflow/BUILD and related files with platform API equivalent. PiperOrigin-RevId: 744630742 --- tensorflow/BUILD | 322 ++++++++++--------------- tensorflow/lite/delegates/gpu/BUILD | 10 +- tensorflow/lite/kernels/internal/BUILD | 4 +- tensorflow/tensorflow.bzl | 13 - tensorflow/workspace3.bzl | 10 + 5 files changed, 150 insertions(+), 209 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index bf7adc2e796ec4..12c43a0aaf8db0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -253,7 +253,7 @@ config_setting( config_setting( name = "android", constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], + ["@platforms//os:android"], [], ), values = if_oss( @@ -265,45 +265,45 @@ config_setting( config_setting( name = "android_x86", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_32", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "x86", ), visibility = ["//visibility:public"], ) config_setting( name = "android_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "x86_64", ), visibility = ["//visibility:public"], ) config_setting( name = "android_armeabi", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:armv6-m", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "armeabi", ), visibility = ["//visibility:public"], ) @@ -311,22 +311,28 @@ config_setting( # copybara:uncomment_begin(google-only) # config_setting( # name = "chromiumos_x86_64", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "k8"}, +# constraint_values = [ +# "@platforms//cpu:x86_64", +# "@platforms//os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_arm64", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "arm"}, +# constraint_values = [ +# "@platforms//cpu:aarch64", +# "@platforms//os:chromiumos", +# ], # 
visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_armv7", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "armeabi-v7a"}, +# constraint_values = [ +# "@platforms//cpu:armv7", +# "@platforms//os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # copybara:uncomment_end @@ -334,7 +340,7 @@ config_setting( config_setting( name = "emscripten", constraint_values = if_google( - ["//third_party/bazel_platforms/os:emscripten"], + ["@platforms//os:emscripten"], [], ), values = if_oss( @@ -346,57 +352,56 @@ config_setting( config_setting( name = "raspberry_pi_armeabi", + constraint_values = + [ + "@platforms//cpu:armv6-m", + "@platforms//os:linux", + ], values = { "crosstool_top": "@local_config_arm_compiler//:toolchain", - "cpu": "armeabi", }, visibility = ["//visibility:public"], ) config_setting( name = "android_arm", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:armv7", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "armeabi-v7a", ), visibility = ["//visibility:public"], ) config_setting( name = "android_arm64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:aarch64", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "arm64-v8a", ), visibility = ["//visibility:public"], ) -config_setting( - name = "android_mips", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "mips", - }, - visibility = ["//visibility:public"], -) - config_setting( name = "android_mips64", + constraint_values = + [ + "@platforms//cpu:mips64", + "@platforms//os:android", + ], values = { "crosstool_top": "//external:android/crosstool", - "cpu": "mips64", }, visibility = ["//visibility:public"], ) @@ -404,16 +409,10 @@ config_setting( # TODO(jakeharmon8): Remove in favor of TSL version config_setting( name = "windows", - # Internal builds query the target OS. - constraint_values = if_google( - ["//third_party/bazel_platforms/os:windows"], - [], - ), - # OSS builds query the CPU type. - values = if_oss( - {"cpu": "x64_windows"}, - {}, - ), + constraint_values = + [ + "@platforms//os:windows", + ], visibility = ["//visibility:public"], ) @@ -423,52 +422,28 @@ config_setting( visibility = ["//visibility:public"], ) -# Sometimes Bazel reports darwin_x86_64 as "darwin" and sometimes as -# "darwin_x86_64". The former shows up when building on a Mac x86_64 host for a Mac x86_64 target. -# The latter shows up when cross-compiling for Mac x86_64 from a Mac ARM machine and in internal -# Google builds. 
-config_setting( - name = "macos_x86_64_default", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), - values = { - "apple_platform_type": "macos", - "cpu": "darwin", - }, -) - config_setting( - name = "macos_x86_64_crosscompile", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), + name = "macos_x86_64", + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:macos", + ], values = { "apple_platform_type": "macos", - "cpu": "darwin_x86_64", }, -) - -selects.config_setting_group( - name = "macos_x86_64", - match_any = [ - ":macos_x86_64_default", - ":macos_x86_64_crosscompile", - ], visibility = ["//visibility:public"], ) config_setting( name = "macos_arm64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), + constraint_values = + [ + "@platforms//cpu:aarch64", + "@platforms//os:macos", + ], values = { "apple_platform_type": "macos", - "cpu": "darwin_arm64", }, visibility = ["//visibility:public"], ) @@ -486,7 +461,7 @@ selects.config_setting_group( config_setting( name = "ios", constraint_values = if_google( - ["//third_party/bazel_platforms/os:ios"], + ["@platforms//os:ios"], [], ), values = if_oss( @@ -499,41 +474,32 @@ config_setting( # TODO(jakeharmon8): Remove in favor of TSL version config_setting( name = "fuchsia", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:fuchsia"], - [], - ), - values = if_oss( - # TODO(b/149248802) When we have a Fuchsia Bazel SDK update to use the values it sets. - {"cpu": "fuchsia"}, - {}, - ), + constraint_values = + ["@platforms//os:fuchsia"], visibility = ["//visibility:public"], ) config_setting( name = "fuchsia_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:fuchsia"], - [], - ), - values = { - "cpu": "x86_64", - }, + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:fuchsia", + ], visibility = ["//visibility:public"], ) config_setting( name = "ios_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:ios"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:ios", + ], values = dict( if_oss( {"crosstool_top": "//tools/osx/crosstool:crosstool"}, ), - cpu = "ios_x86_64", ), visibility = ["//visibility:public"], ) @@ -541,7 +507,7 @@ config_setting( config_setting( name = "chromiumos", constraint_values = if_google( - ["//third_party/bazel_platforms/os:chromiumos"], + ["@platforms//os:chromiumos"], [], ), values = if_oss( @@ -553,31 +519,31 @@ config_setting( config_setting( name = "linux_aarch64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "aarch64"}, + constraint_values = + [ + "@platforms//cpu:aarch64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_armhf", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "armhf"}, + constraint_values = + [ + "@platforms//cpu:armv7e-mf", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "k8"}, + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) @@ -590,12 +556,12 @@ config_setting( # This condition takes precedence over :linux_x86_64 
config_setting( name = "linux_x86_64_no_sse", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], values = { - "cpu": "k8", "copt": "-mno-sse4.2", }, visibility = ["//visibility:public"], @@ -605,52 +571,52 @@ config_setting( # TODO(b/290533709): Remove this with PJRT build rule cleanup. config_setting( name = "linux_x86_64_with_weightwatcher", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], define_values = {"tensorflow_weightwatcher": "true"}, - values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_ppc64le", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "ppc"}, + constraint_values = + [ + "@platforms//cpu:ppc64le", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_s390x", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "s390x"}, + constraint_values = + [ + "@platforms//cpu:s390x", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_mips64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "mips64"}, + constraint_values = + [ + "@platforms//cpu:mips64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_riscv64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "riscv64"}, + constraint_values = + [ + "@platforms//cpu:riscv64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) @@ -670,45 +636,25 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "arm", - values = {"cpu": "arm"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "armeabi", - values = {"cpu": "armeabi"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "armeabi-v7a", - values = {"cpu": "armeabi-v7a"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "arm64-v8a", - values = {"cpu": "arm64-v8a"}, - visibility = ["//visibility:public"], -) - selects.config_setting_group( name = "arm_any", match_any = [ - ":arm", - ":armeabi", - ":armeabi-v7a", - ":arm64-v8a", - ":linux_aarch64", - ":linux_armhf", + "@platforms//cpu:aarch32", + "@platforms//cpu:aarch64", + "@platforms//cpu:armv6-m", + "@platforms//cpu:armv7", + "@platforms//cpu:armv7-m", + "@platforms//cpu:armv7e-m", + "@platforms//cpu:armv7e-mf", ], ) config_setting( name = "freebsd", - values = {"cpu": "freebsd"}, + constraint_values = [ + "@platforms//os:freebsd", + "@platforms//cpu:x86_64", + ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 421a6faebfbd1d..37c4df92b4dad8 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -70,14 +70,12 @@ config_setting( config_setting( name = "tflite_gpu_extra_gles_deps", - # copybara:uncomment_begin(google-only) - # constraint_values = [ - # "//third_party/bazel_platforms/os:linux", - # ], - # copybara:uncomment_end + constraint_values = [ + "//third_party/bazel_platforms/cpu:x86_64", + 
"//third_party/bazel_platforms/os:linux", + ], values = { "copt": "-DTFLITE_GPU_EXTRA_GLES_DEPS", - "cpu": "k8", }, ) diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index f21180e89ee6b8..353cdcdf23e417 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -2,7 +2,7 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("//tensorflow:tensorflow.bzl", "transitive_hdrs") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/lite:build_def.bzl", "tflite_copts") -load("//tensorflow/lite:special_rules.bzl", "tflite_extra_arm_config_settings", "tflite_portable_test_suite_combined") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite_combined") load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") package( @@ -134,7 +134,7 @@ selects.config_setting_group( match_any = [ ":arm32_any", ":aarch64_any", - ] + tflite_extra_arm_config_settings(), + ], ) selects.config_setting_group( diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 5e9adecd209dce..8ea4765974b40c 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -216,25 +216,12 @@ def if_android_arm64(a): "//conditions:default": [], }) -def if_android_mips(a): - return select({ - clean_dep("//tensorflow:android_mips"): a, - "//conditions:default": [], - }) - def if_not_android(a): return select({ clean_dep("//tensorflow:android"): [], "//conditions:default": a, }) -def if_not_android_mips_and_mips64(a): - return select({ - clean_dep("//tensorflow:android_mips"): [], - clean_dep("//tensorflow:android_mips64"): [], - "//conditions:default": a, - }) - def if_android(a): return select({ clean_dep("//tensorflow:android"): a, diff --git a/tensorflow/workspace3.bzl b/tensorflow/workspace3.bzl index 0818a86d302fc1..15fffd35ab7531 100644 --- a/tensorflow/workspace3.bzl +++ b/tensorflow/workspace3.bzl @@ -78,6 +78,16 @@ def workspace(): url = "https://github.com/bazelbuild/rules_jvm_external/archive/%s.zip" % RULES_JVM_EXTERNAL_TAG, ) + # Platforms + http_archive( + name = "platforms", + sha256 = "29742e87275809b5e598dc2f04d86960cc7a55b3067d97221c9abbc9926bff0f", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/platforms/releases/download/0.0.11/platforms-0.0.11.tar.gz", + "https://github.com/bazelbuild/platforms/releases/download/0.0.11/platforms-0.0.11.tar.gz", + ], + ) + # Load the raw llvm-project. llvm does not have build rules set up by default, # but provides a script for setting up build rules via overlays. llvm("llvm-raw") From 3df355522ffcec6705e619d26c8cf4918e845e4b Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 7 Apr 2025 01:49:36 -0700 Subject: [PATCH 0289/1324] Automated Code Change PiperOrigin-RevId: 744632397 --- .../backends/cpu/codegen/emitters/cpu_scatter_emitter.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc index 3daee405cdf5e3..fa2a2b70d86132 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc @@ -108,8 +108,8 @@ std::optional CpuScatterFusion::ComputeThreadIdToInputIndexing( Shape scatter_update_shape = scatter->scatter_updates().front()->shape(); auto root_shape = scatter->scatter_operands().front()->shape(); - SmallVector outer_dimension_partitions(root_shape.dimensions_size(), - 1); + SmallVector outer_dimension_partitions( + root_shape.dimensions().size(), 1); auto backend_config = fusion_->backend_config(); if (backend_config.ok() && !backend_config->outer_dimension_partitions().empty()) { @@ -294,7 +294,7 @@ absl::Status CpuScatterFusion::EmitEntryFunction( Value in_bounds = nested_b.create(1, b.getI1Type()); SmallVector update_offsets( - scatter_operands.front()->shape().dimensions_size(), c0); + scatter_operands.front()->shape().dimensions().size(), c0); for (int i = 0; i < scatter_indices->shape().dimensions(1); ++i) { SmallVector indices_tensor_indices = { update_id, b.create(i)}; From 53c6aed2429335652407d9de8c28d9c2ca29e4b3 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 7 Apr 2025 01:53:39 -0700 Subject: [PATCH 0290/1324] [XLA:GPU] Make tests match the latest state of `fusion_emitter_device_legacy_test.cc`. PiperOrigin-RevId: 744633726 --- .../triton/fusion_emitter_device_legacy_port_test.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index 9ab252183f04e3..c58c767ccbf8da 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -3946,12 +3946,9 @@ ENTRY e { .WithShape(BF16, {16, 40}, {1, 0}))); } -// This test could be modified to allow TF32 once this bug is fixed. -// TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. -// // TODO(b/393299275): this test uncovers a bug in hoisting bitcasts through // broadcasts (seems to generate a type mismatch). -TEST_F(TritonTest, DISABLED_NoTF32For8BitOrLessWithF32) { +TEST_F(TritonTest, DISABLED_UseTF32For8BitOrLessWithF32) { const std::string hlo_text = R"( HloModule t @@ -3987,7 +3984,7 @@ ENTRY e { module_and_metadata.block_level_parameters, R"( CHECK: tt.dot -CHECK-NOT: inputPrecision = tf32 +CHECK: inputPrecision = tf32 )")); EXPECT_TRUE(RunAndCompareNoHloPasses( From 6f49643fd9c62e5ccb597bcbfa691edfc7617a82 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 02:02:42 -0700 Subject: [PATCH 0291/1324] Update GraphDef version to 2190. 
PiperOrigin-RevId: 744636038
---
 tensorflow/core/public/version.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 366f8c03b00947..37debf7e5a16f6 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -93,7 +93,7 @@ limitations under the License.
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 2189 // Updated: 2025/4/6
+#define TF_GRAPH_DEF_VERSION 2190 // Updated: 2025/4/7
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//

From 31e8cf1abec9dd1f7adbc7d448ef299f32b63d49 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 7 Apr 2025 02:02:42 -0700
Subject: [PATCH 0292/1324] compat: Update forward compatibility horizon to 2025-04-07

PiperOrigin-RevId: 744636040
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 35cb75e8cad846..efbd58c3b04602 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -29,7 +29,7 @@
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 6)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 7)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None

From 71dea564c90d33e659f07643a2cebf4ec527a3ed Mon Sep 17 00:00:00 2001
From: Terry Sun
Date: Mon, 7 Apr 2025 03:31:06 -0700
Subject: [PATCH 0293/1324] PR #24491: [NVIDIA GPU] Collective-permute combiner ignores channel id when enabled

Imported from GitHub PR https://github.com/openxla/xla/pull/24491

This PR pipes the `xla_ignore_channel_id` flag to the collective-permute combiner, allowing the combiner to ignore channel id differences when the flag is on. 
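A self-contained sketch of what this change amounts to, using stand-in types rather than XLA's `HloInstruction`: the combiner buckets collective-permutes by a key, and once `channel_id` is dropped from that key, ops that differ only in channel id land in the same bucket and become combinable.

```cpp
// Stand-in types only; the real combiner keys on HloInstruction state.
#include <cstdint>
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

struct CollectivePermute {
  std::string name;
  std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
  std::optional<int64_t> channel_id;  // no longer part of the key
};

using Key = std::vector<std::pair<int64_t, int64_t>>;

int main() {
  std::vector<CollectivePermute> ops = {
      {"cp.0", {{0, 1}, {1, 0}}, 1},
      {"cp.1", {{0, 1}, {1, 0}}, 2},  // different channel_id, same pairs
      {"cp.2", {{1, 2}, {2, 1}}, 3},
  };
  std::map<Key, std::vector<std::string>> groups;
  for (const auto& op : ops) groups[op.source_target_pairs].push_back(op.name);
  for (const auto& [key, names] : groups) {
    std::cout << names.size() << " combinable op(s) in one group\n";
  }
  // cp.0 and cp.1 now share a group; keying on channel_id as well would have
  // kept them apart, which is what the updated test below asserts.
}
```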
Copybara import of the project: -- 01a961d67f3a2ada8856c2da920da52db87071b9 by Terry Sun : ignore channel id when enabled -- f03e91d4611999999e6919374889028c54c899b0 by Terry Sun : always ignore channel id Merging this change closes #24491 PiperOrigin-RevId: 744657932 --- .../collectives/collective_permute_combiner_test.cc | 13 ++++++++----- .../xla/xla/service/collective_permute_key.cc | 2 +- .../xla/xla/service/collective_permute_key.h | 3 +-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/collectives/collective_permute_combiner_test.cc b/third_party/xla/xla/hlo/transforms/collectives/collective_permute_combiner_test.cc index 1879654d99f84b..c860f6cb766c9c 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collective_permute_combiner_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_permute_combiner_test.cc @@ -301,7 +301,7 @@ ENTRY %CombineCollectivePermutes () -> (f32[256], f32[512], f32[2560], f32[1792] EXPECT_TRUE(changed); } -TEST_F(CollectivePermuteCombinerTest, ChannelIdPreventsCombining) { +TEST_F(CollectivePermuteCombinerTest, IgnoreChannelId) { const char* const hlo_string = R"( HloModule CombineCollectivePermutes, entry_computation_layout={()->(f32[256]{0}, f32[512]{0}, f32[2560]{0}, f32[1792]{0}, f32[1536]{0})} @@ -328,16 +328,19 @@ ENTRY %CombineCollectivePermutes () -> (f32[256], f32[512], f32[2560], f32[1792] ROOT %tuple = (f32[256]{0}, f32[512]{0}, f32[2560]{0}, f32[1792]{0}, f32[1536]{0}) tuple(f32[256]{0} %collective-permute, f32[512]{0} %collective-permute.1, f32[2560]{0} %collective-permute.2, f32[1792]{0} %collective-permute.3, f32[1536]{0} %collective-permute.4) })"; + HloModuleConfig config = GetModuleConfigForTest(); + auto opts = GetDebugOptionsForTest(); + opts.set_xla_ignore_channel_id(true); + config.set_debug_options(opts); TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string, config)); const int64_t total_count = 5; CollectivePermuteCombiner combine(1024 * 1024, kMaxCombineCount); ASSERT_EQ(CollectivePermuteCount(*module), total_count); TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); - // Expect two combined collective permute ops since there are two types of - // channel_id in HLO - EXPECT_EQ(CollectivePermuteCount(*module), 2); + // Expect one combined collective permute op since channel_id is ignored + EXPECT_EQ(CollectivePermuteCount(*module), 1); EXPECT_TRUE(changed); } diff --git a/third_party/xla/xla/service/collective_permute_key.cc b/third_party/xla/xla/service/collective_permute_key.cc index f6ba9f68514d73..320a7cb1903902 100644 --- a/third_party/xla/xla/service/collective_permute_key.cc +++ b/third_party/xla/xla/service/collective_permute_key.cc @@ -34,7 +34,7 @@ std::optional GetCollectivePermuteKey( } const auto* cp = Cast(instruction); - return CollectivePermuteKey{cp->source_target_pairs(), cp->channel_id()}; + return CollectivePermuteKey{cp->source_target_pairs()}; } } // namespace xla diff --git a/third_party/xla/xla/service/collective_permute_key.h b/third_party/xla/xla/service/collective_permute_key.h index ea0018633d88a1..3876aeaddc169b 100644 --- a/third_party/xla/xla/service/collective_permute_key.h +++ b/third_party/xla/xla/service/collective_permute_key.h @@ -32,8 +32,7 @@ namespace xla { // collective-permute instructions to be compatible with each other (and hence // be possible to combine the instructions). 
using CollectivePermuteKey = std::tuple< - /*source_target_pairs*/ std::vector>, - /*channel_id*/ std::optional>; + /*source_target_pairs*/ std::vector>>; std::optional GetCollectivePermuteKey( const HloInstruction* instruction); From c0f30fcb7cfa1d2055d1c7dd40b04260f820b8fe Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Mon, 7 Apr 2025 03:39:46 -0700 Subject: [PATCH 0294/1324] Reverts dd5064300419f85a890b2bc616c954fe0a3d8dc0 PiperOrigin-RevId: 744660090 --- third_party/xla/xla/debug_options_flags.cc | 10 +- third_party/xla/xla/service/gpu/BUILD | 2 - .../xla/xla/service/gpu/backend_configs.proto | 9 +- .../xla/service/gpu/backend_configs_test.cc | 1 - .../xla/xla/service/gpu/gpu_hlo_schedule.cc | 2 - .../xla/service/gpu/gpu_hlo_schedule_test.cc | 1 - .../xla/xla/service/gpu/transforms/BUILD | 37 ---- .../command_buffer_scheduling_test.cc | 16 +- .../gpu/transforms/schedule_postprocessing.cc | 158 ----------------- .../gpu/transforms/schedule_postprocessing.h | 50 ------ .../schedule_postprocessing_test.cc | 163 ------------------ third_party/xla/xla/xla.proto | 4 +- 12 files changed, 15 insertions(+), 438 deletions(-) delete mode 100644 third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc delete mode 100644 third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h delete mode 100644 third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 2c44ea0f19636d..73fa48e678914a 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -253,7 +253,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_split_k_autotuning(true); opts.set_xla_gpu_enable_reduction_epilogue_fusion(true); - opts.set_xla_gpu_enable_nccl_clique_optimization(false); opts.set_xla_gpu_cublas_fallback(true); opts.set_xla_gpu_cudnn_gemm_fusion_level(0); opts.set_xla_gpu_enable_while_loop_double_buffering(false); @@ -1919,12 +1918,9 @@ void MakeDebugOptionsFlags(std::vector* flag_list, &DebugOptions::set_xla_gpu_enable_reduction_epilogue_fusion), debug_options->xla_gpu_enable_reduction_epilogue_fusion(), "Enable fusion for reduction epilogues")); - flag_list->push_back( - tsl::Flag("xla_gpu_enable_nccl_clique_optimization", - bool_setter_for( - &DebugOptions::set_xla_gpu_enable_nccl_clique_optimization), - debug_options->xla_gpu_enable_nccl_clique_optimization(), - "Allow early return when acquiring NCCL cliques")); + flag_list->push_back(tsl::Flag("xla_gpu_enable_nccl_clique_optimization", + noop_flag_setter, false, + "[Deprecated, do not use].")); flag_list->push_back( tsl::Flag("xla_gpu_cublas_fallback", bool_setter_for(&DebugOptions::set_xla_gpu_cublas_fallback), diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index ab2b73adf48be0..ed3775413cbbce 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2158,7 +2158,6 @@ cc_library( "//xla/service/gpu/model:sol_latency_estimator", "//xla/service/gpu/transforms:async_collective_annotator", "//xla/service/gpu/transforms:pgle_accuracy_checker", - "//xla/service/gpu/transforms:schedule_postprocessing", "//xla/service/gpu/transforms:scheduling_instruction_annotator", "//xla/service/gpu/transforms/collectives:collective_ops_utils", "//xla/stream_executor:device_description", @@ -2201,7 +2200,6 @@ xla_test( "//xla/service:hlo_module_config", 
"//xla/service:latency_hiding_scheduler", "//xla/service:legalize_scheduling_annotations", - "//xla/service/gpu/transforms:schedule_postprocessing", "//xla/service/gpu/transforms:scheduling_instruction_annotator", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 329cb7f44c0cc4..2486e0b0f83d86 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -118,19 +118,16 @@ message BitcastBackendConfig { // Backend config for async collective operations. Note that for is_sync will // be false by default, so even if a backend config is not explicitly attached // to the HLOInstruction, getting the backend_config will yield a default valued -// proto which will have is_sync = false. Attribute no_parallel_custom_call -// asserts that an asynchronous collective operation does not execute in -// parallel with custom-calls, which can trigger device synchronization . This -// attribute will also be false by default and should lead to conversative -// runtime behavior. +// proto which will have is_sync = false. message CollectiveBackendConfig { bool is_sync = 1; - bool no_parallel_custom_call = 2; // Determines whether the collective op of interested has been pipelined // within a loop. bool is_pipelined = 3; // Cost model prediction. repeated ReificationCost reification_cost = 4; + + reserved 2; } // Backend config for cost model estimates. diff --git a/third_party/xla/xla/service/gpu/backend_configs_test.cc b/third_party/xla/xla/service/gpu/backend_configs_test.cc index 7883547f077dcb..16f05964536e71 100644 --- a/third_party/xla/xla/service/gpu/backend_configs_test.cc +++ b/third_party/xla/xla/service/gpu/backend_configs_test.cc @@ -59,7 +59,6 @@ TEST_F(BackendConfigsTest, DefaultCollectiveBackendConfig) { const auto& collective_backend_config = gpu_config.collective_backend_config(); EXPECT_THAT(collective_backend_config.is_sync(), IsFalse()); - EXPECT_THAT(collective_backend_config.no_parallel_custom_call(), IsFalse()); } TEST_F(BackendConfigsTest, DefaultGpuBackendConfigParseOpQueue) { diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index c050fad9b8d68a..aacd6d264475fa 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -59,7 +59,6 @@ limitations under the License. 
#include "xla/service/gpu/transforms/async_collective_annotator.h" #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" #include "xla/service/gpu/transforms/pgle_accuracy_checker.h" -#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include "xla/service/hlo_module_config.h" #include "xla/service/latency_hiding_scheduler.h" @@ -596,7 +595,6 @@ absl::Status RunLatencyHidingSchedulerPasses( std::move(estimator), std::move(async_tracker), std::move(scheduler_core), shape_size_in_bytes); pipeline.AddPass(); - pipeline.AddPass(); return pipeline.Run(module).status(); } diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index f2d3c1ebbedfda..375210385f89e7 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -43,7 +43,6 @@ limitations under the License. #include "xla/service/backend.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" -#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include "xla/service/hlo_module_config.h" #include "xla/service/latency_hiding_scheduler.h" diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 13cc6cb5719f21..08971084ef1ad7 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2835,43 +2835,6 @@ xla_cc_test( ], ) -cc_library( - name = "schedule_postprocessing", - srcs = ["schedule_postprocessing.cc"], - hdrs = ["schedule_postprocessing.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu/transforms/collectives:collective_ops_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "schedule_postprocessing_test", - srcs = ["schedule_postprocessing_test.cc"], - deps = [ - ":schedule_postprocessing", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", - "//xla/service/gpu:backend_configs_cc", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "scheduling_instruction_annotator", srcs = ["scheduling_instruction_annotator.cc"], diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 3c60b7f6ec2255..e74ff34ec0d625 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -209,7 +209,7 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) { %a = s32[4] parameter(0) %start = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": 
{"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start) })"; @@ -242,7 +242,7 @@ TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) { %start = (s32[2]{0}, s32[4]{0}) all-gather-start(%a), channel_id=555, replica_groups={{0,1}}, dimensions={0}, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} ROOT %done = s32[4]{0} all-gather-done(%start) })"; @@ -282,7 +282,7 @@ TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) { %start = ((s32[4]{0}), s32[2]{0}) reduce-scatter-start(%a), channel_id=555, replica_groups={{0,1}}, dimensions={0}, to_apply=add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} ROOT %done = s32[2]{0} reduce-scatter-done(%start) })"; @@ -321,7 +321,7 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByBitcast) { %a = s32[4] parameter(0) %start = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %bitcast = s32[4] bitcast(s32[4]{0} %a) ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start) })"; @@ -361,10 +361,10 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedAllReduceStart) { %a = s32[4] parameter(0) %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1) ROOT %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2) })"; @@ -418,11 +418,11 @@ TEST_F(CommandBufferSchedulingTest, DoNotCaptureUnmatchedAsyncDone) { %b = s32[] parameter(1) %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %c = s32[] custom-call(), custom_call_target="target" %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a), replica_groups={{0,1}}, to_apply=%add, - backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + backend_config={"collective_backend_config": {"is_sync":true}} %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1) %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2) %fusion = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc deleted file mode 100644 index 0fd39a27d0e3ba..00000000000000 --- a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/transforms/schedule_postprocessing.h" - -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_schedule.h" -#include "xla/hlo/utils/hlo_query.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { -// Maps a computation to a boolean that indicates whether the computation may -// invoke custom-calls directly or indirectly, which can eventually trigger gpu -// synchronization. -using CustomCallInComputation = - absl::flat_hash_map; - -// Returns whether the hlo may invoke custom-calls which may trigger gpu -// synchronization. Currently, we only check for custom-calls, because they are -// the only operations that can be parallel with asynchronous collectives -// operations in an hlo-schedule and may trigger gpu synchronization. -bool MayInvokeCustomCall( - const HloInstruction* hlo, - const CustomCallInComputation& custom_call_in_computation) { - if (HloPredicateIsOp(hlo)) { - return true; - } - - return absl::c_any_of( - hlo->called_computations(), [&](const HloComputation* callee) { - return custom_call_in_computation.find(callee)->second; - }); -} - -// Returns true if this is an asynchronous collective start operation, excluding -// P2P operations. -bool IsRelevantAsynchronousStart(const HloInstruction* hlo) { - return hlo_query::IsAsyncCollectiveStartOp(hlo, - /*include_send_recv=*/false) && - !IsGPUSyncCollective(*hlo); -} - -// Returns true if this is a collective done operation, excluding P2P -// operations. -bool IsRelevantAsynchronousDone(const HloInstruction* hlo) { - return hlo_query::IsAsyncCollectiveDoneOp(hlo, - /*include_send_recv=*/false); -} - -// For a given computation, finds all the asynchronous collective operations -// that aren't parallel with custom-calls and sets its no_parallel_custom_call -// attribute to true. Also records whether the given computation may invoke -// custom-calls. -absl::StatusOr ProcessComputation( - const HloSchedule& schedule, HloComputation* computation, - CustomCallInComputation& custom_call_in_computation) { - bool changed = false; - bool has_custom_call = false; - absl::flat_hash_set async_starts; - const HloInstructionSequence& sequence = schedule.sequence(computation); - - // Visit instructions in the sequence. Collect relevant asynchronous - // collective start ops. 
When we see a relevant asynchronous collective done
-  // op, remove the corresponding start op from the collection and set its
-  // attribute no_parallel_custom_call to true. When we see a custom-call, clear
-  // the start ops from the collection and keep their attribute
-  // no_parallel_custom_call as false.
-  const std::vector<HloInstruction*>& all_instructions =
-      sequence.instructions();
-  for (HloInstruction* hlo : all_instructions) {
-    if (MayInvokeCustomCall(hlo, custom_call_in_computation)) {
-      async_starts.clear();
-      has_custom_call = true;
-      continue;
-    }
-    if (IsRelevantAsynchronousStart(hlo)) {
-      async_starts.insert(hlo);
-      continue;
-    }
-
-    if (IsRelevantAsynchronousDone(hlo)) {
-      HloInstruction* async_start = hlo->mutable_operand(0);
-      if (async_starts.contains(async_start)) {
-        changed = true;
-        TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                            async_start->backend_config<GpuBackendConfig>());
-        CollectiveBackendConfig& collective_backend_config =
-            *gpu_config.mutable_collective_backend_config();
-        collective_backend_config.set_no_parallel_custom_call(true);
-        TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
-        async_starts.erase(async_start);
-      }
-    }
-  }
-
-  custom_call_in_computation[computation] = has_custom_call;
-  return changed;
-}
-
-}  // anonymous namespace
-
-absl::StatusOr<bool> SchedulePostprocessing::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  if (!module->has_schedule()) return false;
-  HloSchedule& schedule = module->schedule();
-  bool changed = false;
-  CustomCallInComputation custom_call_in_computation;
-
-  // We visit computations in the order of callees to callers, as information is
-  // propagated from callees to callers.
-  std::vector<HloComputation*> all_computations =
-      module->MakeComputationPostOrder(execution_threads);
-  for (auto iter = all_computations.begin(); iter != all_computations.end();
-       ++iter) {
-    HloComputation* computation = *iter;
-    if (computation->IsFusionComputation()) {
-      custom_call_in_computation[computation] = false;
-      continue;
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        bool result,
-        ProcessComputation(schedule, computation, custom_call_in_computation));
-    changed |= result;
-  }
-
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h
deleted file mode 100644
index d76faed7d260cc..00000000000000
--- a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
-#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/pass/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Amends a schedule result with the needed information to support a runtime
-// implementation. Currently, this pass refines the attribute
-// no_parallel_custom_call for asynchronous collective operations to support
-// runtime optimization, such as skipping rendezvous of all participating
-// threads for NCCL collective operations. In particular, it sets the attribute
-// value for collective-start operations with is_sync=false; it also keeps the
-// attribute value untouched for the operations with is_sync=true and for P2P
-// operations, assuming the runtime won't use those values.
-//
-class SchedulePostprocessing : public HloModulePass {
- public:
-  absl::string_view name() const override { return "schedule-postprocessing"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc
deleted file mode 100644
index 01659a11f6e66d..00000000000000
--- a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc
+++ /dev/null
@@ -1,163 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#include "xla/service/gpu/transforms/schedule_postprocessing.h" - -#include - -#include -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/parser/hlo_parser.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -using SchedulePostprocessingTest = HloTestBase; - -TEST_F(SchedulePostprocessingTest, SynchronousOpsNotChanged) { - constexpr absl::string_view kHloString = R"( - HloModule module, is_scheduled=true - - ENTRY entry { - pf32 = f32[1] parameter(0) - - all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}} - ROOT all-gather-done = f32[2] all-gather-done(all-gather-start) - } -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kHloString))); - SchedulePostprocessing pass; - TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); - EXPECT_FALSE(changed); -} - -TEST_F(SchedulePostprocessingTest, P2POpsNotChanged) { - constexpr absl::string_view kHloString = R"( - HloModule module, is_scheduled=true - - ENTRY main { - f0 = f32[] constant(0.0) - init = f32[1, 1024, 1024] broadcast(f0), dimensions={} - - after-all = token[] after-all() - recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=2, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}}" - } - recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=2 - ROOT recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0 - } -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kHloString))); - SchedulePostprocessing pass; - TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); - EXPECT_FALSE(changed); -} - -TEST_F(SchedulePostprocessingTest, AsynchronousOpsChanged) { - constexpr absl::string_view kHloString = R"( - HloModule module, is_scheduled=true - - ENTRY entry { - pf32 = f32[1] parameter(0) - pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call" - all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}} - ROOT all-gather-done = f32[2] all-gather-done(all-gather-start) - } -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kHloString))); - SchedulePostprocessing pass; - TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); - EXPECT_TRUE(changed); - - HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); - TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, - start->backend_config()); - const CollectiveBackendConfig& collective_backend_config = - gpu_config.collective_backend_config(); - EXPECT_TRUE(collective_backend_config.no_parallel_custom_call()); -} - -TEST_F(SchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { - constexpr absl::string_view kHloString = R"( - HloModule module, is_scheduled=true - - ENTRY entry { - pf32 = f32[1] parameter(0) - all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}} - pf32.2 = f32[1] custom-call(pf32), 
custom_call_target="my_custom_call" - all-gather-done = f32[2] all-gather-done(all-gather-start) - ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done) - } -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kHloString))); - SchedulePostprocessing pass; - TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); - EXPECT_FALSE(changed); - - HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); - TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, - start->backend_config()); - const CollectiveBackendConfig& collective_backend_config = - gpu_config.collective_backend_config(); - EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); -} - -TEST_F(SchedulePostprocessingTest, - AsynchronousOpsWithParallelNestedCustomcall) { - constexpr absl::string_view kHloString = R"( - HloModule module, is_scheduled=true - foo { - v = f32[1] parameter(0) - ROOT ret = f32[1] custom-call(v), custom_call_target="my_custom_call" - } - - ENTRY entry { - pf32 = f32[1] parameter(0) - all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}} - pf32.2 = f32[1] call(f32[1] pf32), to_apply=foo - all-gather-done = f32[2] all-gather-done(all-gather-start) - ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done) - } -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kHloString))); - SchedulePostprocessing pass; - TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); - EXPECT_FALSE(changed); - - HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); - TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, - start->backend_config()); - const CollectiveBackendConfig& collective_backend_config = - gpu_config.collective_backend_config(); - EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index bd7c0ccbd4b661..924701a09deda4 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -443,9 +443,6 @@ message DebugOptions { // threads. Setting to 0 (the default value) means no enforcement. bool xla_gpu_enable_llvm_module_compilation_parallelism = 268; - // Allow early return when acquiring NCCL cliques. - bool xla_gpu_enable_nccl_clique_optimization = 244; - // Enable NCCL communicator splitting. 
bool xla_gpu_enable_nccl_comm_splitting = 272; @@ -805,6 +802,7 @@ message DebugOptions { // go/keep-sorted end + reserved 244; // xla_gpu_enable_nccl_clique_optimization reserved 276; // xla_gpu_enable_nccl_per_stream_comms //--------------------------------------------------------------------------// From c44b307c5ee6a275d106ae36f971867573c7168c Mon Sep 17 00:00:00 2001 From: Harsha H S Date: Mon, 7 Apr 2025 03:52:59 -0700 Subject: [PATCH 0295/1324] PR #24578: [ROCm] Check for shared memory resource during triton fusion emitter Imported from GitHub PR https://github.com/openxla/xla/pull/24578 Copybara import of the project: -- 02e11e347ae7b026d09c83ccbf6f683ef70c52bf by Harsha HS : [ROCm] Check for shared memory resource during triton fusion emitter Merging this change closes #24578 PiperOrigin-RevId: 744663087 --- .../xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc | 3 +-- third_party/xla/xla/stream_executor/rocm/rocm_executor.cc | 2 ++ third_party/xla/xla/tools/hlo_opt/gpu_specs/mi200.txtpb | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index d2ca7eadac1dfa..f183568ae2a36d 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -1860,8 +1860,7 @@ absl::StatusOr CompileTritonToLLVM( const int shared_mem_bytes = triton_module->getAttrOfType("ttg.shared").getInt(); VLOG(2) << "Shared memory usage: " << shared_mem_bytes << " B"; - if (std::holds_alternative(cc) && - shared_mem_bytes > device_info.shared_memory_per_block_optin()) { + if (shared_mem_bytes > device_info.shared_memory_per_block_optin()) { return absl::ResourceExhaustedError(absl::StrFormat( "Shared memory size limit exceeded: requested %d, available: %d", shared_mem_bytes, device_info.shared_memory_per_block_optin())); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index ebc99ae75f5282..d4dde90953fa03 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -1118,6 +1118,8 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) { desc.set_shared_memory_per_core(GetMaxSharedMemoryPerCore(device).value()); desc.set_shared_memory_per_block(GetMaxSharedMemoryPerBlock(device).value()); + desc.set_shared_memory_per_block_optin( + GetMaxSharedMemoryPerBlock(device).value()); int core_count = GetMultiprocessorCount(device).value(); desc.set_core_count(core_count); desc.set_fpus_per_core(fpus_per_core(gcn_arch_name)); diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_specs/mi200.txtpb b/third_party/xla/xla/tools/hlo_opt/gpu_specs/mi200.txtpb index 3c882ac2dcbe56..e72fa777e2ac34 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_specs/mi200.txtpb +++ b/third_party/xla/xla/tools/hlo_opt/gpu_specs/mi200.txtpb @@ -16,6 +16,7 @@ gpu_device_info { threads_per_block_limit: 1024 threads_per_warp: 64 shared_memory_per_block: 65536 + shared_memory_per_block_optin: 65536 shared_memory_per_core: 65536 threads_per_core_limit: 2048 core_count: 110 From f9d3c26d8b4b7bb504247573536e04cbb8dfe269 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Mon, 7 Apr 2025 04:38:09 -0700 Subject: [PATCH 0296/1324] Add support for blockwise quantization to FC operator. 
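Blockwise quantization stores one half-precision scale per block of
`blocksize` consecutive input channels, rather than a single scale per
tensor or per output channel, and packs the int4 weights two to a byte with
the low nibble holding the even-indexed element. A weight then dequantizes
as real_value = scale[k / blocksize] * sign_extend_int4(nibble); the weights
are symmetric, so no weight zero point is applied. The snippet below is a
minimal editorial sketch of that per-row dequantization, mirroring the
reference path this change adds in fully_connected.cc; the helper name
DequantizeInt4Row is illustrative and not part of the change:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sign-extends a 4-bit value in [0, 15] to [-8, 7].
    inline int8_t SignExtendInt4(int8_t v) { return (v ^ 0x8) - 8; }

    // Dequantizes one output-channel row of blockwise-quantized int4
    // weights. `packed` points at the row's packed weights (two int4
    // values per byte, low nibble first); `scales` holds one float scale
    // per `blocksize` consecutive input channels.
    std::vector<float> DequantizeInt4Row(const uint8_t* packed,
                                         const float* scales,
                                         size_t input_channels,
                                         size_t blocksize) {
      std::vector<float> row(input_channels);
      for (size_t k = 0; k < input_channels; ++k) {
        const uint8_t byte = packed[k / 2];
        const int8_t nibble =
            static_cast<int8_t>((k % 2 == 0) ? (byte & 0xF) : (byte >> 4));
        row[k] = scales[k / blocksize] * SignExtendInt4(nibble);
      }
      return row;
    }

The actual kernel in the diff below additionally widens the fp16 scales to
float up front and pads each row to an even number of int4 elements (the k2
rounding); the sketch elides both details.
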
PiperOrigin-RevId: 744674517 --- .../lite/cmake/DownloadPThreadPool.cmake | 4 +- tensorflow/lite/core/c/common.cc | 17 +- tensorflow/lite/core/c/common.h | 16 + tensorflow/lite/core/interpreter_builder.cc | 14 + .../delegates/xnnpack/xnnpack_delegate.cc | 413 ++++++++++++------ tensorflow/lite/kernels/fully_connected.cc | 109 ++++- .../lite/tools/cmake/modules/xnnpack.cmake | 2 +- tensorflow/workspace2.bzl | 12 +- 8 files changed, 433 insertions(+), 154 deletions(-) diff --git a/tensorflow/lite/cmake/DownloadPThreadPool.cmake b/tensorflow/lite/cmake/DownloadPThreadPool.cmake index a38d1b319d8bc7..e12799e3231a31 100644 --- a/tensorflow/lite/cmake/DownloadPThreadPool.cmake +++ b/tensorflow/lite/cmake/DownloadPThreadPool.cmake @@ -19,8 +19,8 @@ PROJECT(pthreadpool-download NONE) INCLUDE(ExternalProject) ExternalProject_Add(pthreadpool - URL https://github.com/google/pthreadpool/archive/706a8ea9e4b8c2129718af195ddce7fc2573e719.zip - URL_HASH SHA256=2d56c31ebf6509d171d12ace2b543f6182ff0083ba674541515fc573738a3238 + URL https://github.com/google/pthreadpool/archive/b92447772365661680f486e39a91dfe6675adafc.zip + URL_HASH SHA256=745e56516d6a58d183eb33d9017732d87cff43ce9f78908906f9faa52633e421 SOURCE_DIR "${CMAKE_BINARY_DIR}/pthreadpool-source" BINARY_DIR "${CMAKE_BINARY_DIR}/pthreadpool" CONFIGURE_COMMAND "" diff --git a/tensorflow/lite/core/c/common.cc b/tensorflow/lite/core/c/common.cc index d458b1eb29b5ab..baa6282fd5b12e 100644 --- a/tensorflow/lite/core/c/common.cc +++ b/tensorflow/lite/core/c/common.cc @@ -113,14 +113,25 @@ TfLiteQuantization TfLiteQuantizationClone(const TfLiteQuantization& src) { case kTfLiteAffineQuantization: { dst.params = calloc(1, sizeof(TfLiteAffineQuantization)); const TfLiteAffineQuantization* const src_params = - (TfLiteAffineQuantization*)(src.params); + reinterpret_cast(src.params); TfLiteAffineQuantization* const dst_params = - (TfLiteAffineQuantization*)(dst.params); + reinterpret_cast(dst.params); dst_params->quantized_dimension = src_params->quantized_dimension; dst_params->scale = TfLiteFloatArrayCopy(src_params->scale); dst_params->zero_point = TfLiteIntArrayCopy(src_params->zero_point); break; } + case kTfLiteBlockwiseQuantization: { + dst.params = calloc(1, sizeof(TfLiteBlockwiseQuantization)); + const TfLiteBlockwiseQuantization* const src_params = + (TfLiteBlockwiseQuantization*)(src.params); + TfLiteBlockwiseQuantization* const dst_params = + (TfLiteBlockwiseQuantization*)(dst.params); + dst_params->blocksize = src_params->blocksize; + dst_params->scale = src_params->scale; + dst_params->zero_point = src_params->zero_point; + break; + } } return dst; } @@ -225,7 +236,7 @@ void TfLiteTensorDataFree(TfLiteTensor* t) { void TfLiteQuantizationFree(TfLiteQuantization* quantization) { if (quantization->type == kTfLiteAffineQuantization) { TfLiteAffineQuantization* q_params = - (TfLiteAffineQuantization*)(quantization->params); + reinterpret_cast(quantization->params); if (q_params->scale) { TfLiteFloatArrayFree(q_params->scale); q_params->scale = nullptr; diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 87a9b1a5075051..57caa3b759a35d 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -328,6 +328,8 @@ typedef enum TfLiteQuantizationType : int { /// Affine quantization (with support for per-channel quantization). /// Corresponds to TfLiteAffineQuantization. kTfLiteAffineQuantization = 1, + /// Blockwise quantization. 
+ kTfLiteBlockwiseQuantization = 2, } TfLiteQuantizationType; /// Structure specifying the quantization used by the tensor, if-any. @@ -353,6 +355,20 @@ typedef struct TfLiteAffineQuantization { int32_t quantized_dimension; } TfLiteAffineQuantization; +/// Parameters for blockwise quantization across the output channels dimension. +/// For a particular value in quantized_dimension, quantized values can be +/// converted back to float using: +/// `real_value = scale * (quantized_value - zero_point)` +typedef struct TfLiteBlockwiseQuantization { + // Index of the tensor containing the scales. + int32_t scale; + // Index of the tensor containing the zero points. + int32_t zero_point; + // Quantization blocksize. + int32_t blocksize; + int32_t quantized_dimension; +} TfLiteBlockwiseQuantization; + /// A union of pointers that points to memory for a given tensor. /// /// Do not access these members directly, if possible, use diff --git a/tensorflow/lite/core/interpreter_builder.cc b/tensorflow/lite/core/interpreter_builder.cc index 1c6cc8c2ac9dd9..8741022e3c2a70 100644 --- a/tensorflow/lite/core/interpreter_builder.cc +++ b/tensorflow/lite/core/interpreter_builder.cc @@ -407,6 +407,20 @@ TfLiteStatus InterpreterBuilder::ParseNodes( TfLiteStatus InterpreterBuilder::ParseQuantization( const QuantizationParameters* src_quantization, TfLiteQuantization* quantization, const std::vector& dims) { + // Blockwise quantization. + if (src_quantization && src_quantization->details_type() == + QuantizationDetails_BlockwiseQuantization) { + auto* src_quant = src_quantization->details_as_BlockwiseQuantization(); + quantization->type = kTfLiteBlockwiseQuantization; + auto* blockwise_quantization = + reinterpret_cast( + malloc(sizeof(TfLiteBlockwiseQuantization))); + blockwise_quantization->scale = src_quant->scales(); + blockwise_quantization->quantized_dimension = 0; + blockwise_quantization->blocksize = src_quant->block_size(); + quantization->params = reinterpret_cast(blockwise_quantization); + return kTfLiteOk; + } quantization->type = kTfLiteNoQuantization; quantization->params = nullptr; if (!src_quantization || !src_quantization->scale() || diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 7cb451e956256c..f5d3eaeff3b1ca 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -91,6 +91,96 @@ void CopyTensorDataInt32OrInt64(int64_t* dst, const TfLiteTensor& tensor, } } +bool CheckZeroPoint(TfLiteContext* context, const TfLiteTensor& tensor, int t, + const TfLiteIntArray* quantization_zero_point) { + if (quantization_zero_point == nullptr) { + TF_LITE_KERNEL_LOG(context, + "missing zero point quantization parameters for " + "%s tensor %d in XNNPACK delegate", + TfLiteTypeGetName(tensor.type), t); + return false; + } + return true; +} + +bool CheckFp16Scale(TfLiteContext* context, const TfLiteTensor& tensor, int t, + const TfLiteBlockwiseQuantization* quantization_params) { + const TfLiteTensor& scale = context->tensors[quantization_params->scale]; + int num_scales = NumElements(&scale); + std::vector dequantized_scale(num_scales); + DequantizeFloat16(reinterpret_cast(scale.data.data), + dequantized_scale.data(), num_scales); + for (int i = 0; i < num_scales; i++) { + if (!std::isnormal(dequantized_scale[i]) || dequantized_scale[i] <= 0.0f) { + TF_LITE_KERNEL_LOG(context, + "unsupported scale value (%f) in channel %d for " + "%s tensor %d in XNNPACK 
delegate", + dequantized_scale[i], i, + TfLiteTypeGetName(tensor.type), t); + return false; + } + } + return true; +} + +bool CheckFp32Scale(TfLiteContext* context, const TfLiteTensor& tensor, int t, + const TfLiteFloatArray* quantization_scale, + const TfLiteIntArray* quantization_zero_point) { + if (quantization_scale == nullptr) { + TF_LITE_KERNEL_LOG(context, + "missing scale quantization parameters for %s " + "tensor %d in XNNPACK delegate", + TfLiteTypeGetName(tensor.type), t); + return false; + } + if (quantization_zero_point != nullptr && + quantization_scale->size != quantization_zero_point->size) { + TF_LITE_KERNEL_LOG(context, + "mismatching number of scale (%d) and zero " + "point (%d) quantization parameters for %s " + "tensor %d in XNNPACK delegate", + quantization_scale->size, quantization_zero_point->size, + TfLiteTypeGetName(tensor.type), t); + return false; + } + for (int i = 0; i < quantization_scale->size; i++) { + const float scale = quantization_scale->data[i]; + if (!std::isnormal(scale) || scale <= 0.0f) { + TF_LITE_KERNEL_LOG(context, + "unsupported scale value (%f) in channel %d for " + "%s tensor %d in XNNPACK delegate", + scale, i, TfLiteTypeGetName(tensor.type), t); + return false; + } + } + return true; +} + +xnn_datatype CheckPerTensorQuantization( + TfLiteContext* context, const TfLiteTensor& tensor, int t, + const TfLiteFloatArray* quantization_scale, + const TfLiteIntArray* quantization_zero_point) { + // Per-tensor quantization parameters + if (kTfLiteInt8 != tensor.type) { + TF_LITE_KERNEL_LOG(context, + "unsupported per-tensor quantization scale " + "parameter for %s tensor %d in XNNPACK delegate", + TfLiteTypeGetName(tensor.type), t); + return xnn_datatype_invalid; + } + + const int zero_point = quantization_zero_point->data[0]; + if (zero_point < std::numeric_limits::min() || + zero_point > std::numeric_limits::max()) { + TF_LITE_KERNEL_LOG(context, + "unsupported zero-point value (%d) for INT8 " + "tensor %d in XNNPACK delegate", + zero_point, t); + return xnn_datatype_invalid; + } + return xnn_datatype_qint8; +} + xnn_datatype GetXNNPackDatatype(TfLiteContext* context, const TfLiteTensor& tensor, int t) { switch (tensor.type) { @@ -163,111 +253,106 @@ xnn_datatype GetXNNPackDatatype(TfLiteContext* context, } case kTfLiteInt8: case kTfLiteInt4: { - if (tensor.quantization.type != kTfLiteAffineQuantization) { - TF_LITE_KERNEL_LOG(context, - "unsupported quantization type %d for %s " - "tensor %d in XNNPACK delegate", - tensor.quantization.type, - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; - } - const auto quantization_params = - static_cast( - tensor.quantization.params); - if (quantization_params->scale == nullptr) { - TF_LITE_KERNEL_LOG(context, - "missing scale quantization parameters for %s " - "tensor %d in XNNPACK delegate", - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; - } - if (quantization_params->zero_point == nullptr) { - TF_LITE_KERNEL_LOG(context, - "missing zero point quantization parameters for " - "%s tensor %d in XNNPACK delegate", - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; - } - if (quantization_params->scale->size != - quantization_params->zero_point->size) { - TF_LITE_KERNEL_LOG(context, - "mismatching number of scale (%d) and zero " - "point (%d) quantization parameters for %s " - "tensor %d in XNNPACK delegate", - quantization_params->scale->size, - quantization_params->zero_point->size, - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; - 
} - - for (int i = 0; i < quantization_params->scale->size; i++) { - const float scale = quantization_params->scale->data[i]; - if (!std::isnormal(scale) || scale <= 0.0f) { - TF_LITE_KERNEL_LOG(context, - "unsupported scale value (%f) in channel %d for " - "%s tensor %d in XNNPACK delegate", - scale, i, TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; + switch (tensor.quantization.type) { + case kTfLiteAffineQuantization: { + const auto quantization_params = + static_cast( + tensor.quantization.params); + const auto quantization_scale = quantization_params->scale; + const auto quantization_zero_point = quantization_params->zero_point; + if (!CheckFp32Scale(context, tensor, t, quantization_scale, + quantization_zero_point)) { + return xnn_datatype_invalid; + } + if (quantization_scale->size == 1) { + return CheckPerTensorQuantization(context, tensor, t, + quantization_scale, + quantization_zero_point); + } + if (!CheckZeroPoint(context, tensor, t, quantization_zero_point)) { + return xnn_datatype_invalid; + } + if (NumDimensions(&tensor) >= 1 && + quantization_scale->size == + SizeOfDimension(&tensor, + quantization_params->quantized_dimension)) { + // Per-channel quantization parameters + for (int c = 0; + c < SizeOfDimension(&tensor, + quantization_params->quantized_dimension); + c++) { + if (quantization_params->zero_point->data[c] != 0 && + (tensor.type != kTfLiteInt4 && + quantization_params->zero_point->data[c] != 8)) { + TF_LITE_KERNEL_LOG(context, + "unsupported zero-point value %d in channel " + "%d of %s tensor %d in XNNPACK delegate", + quantization_params->zero_point->data[c], c, + TfLiteTypeGetName(tensor.type), t); + return xnn_datatype_invalid; + } + } + } else { + TF_LITE_KERNEL_LOG( + context, + "mismatching number of quantization parameters %d and outer " + "dimension %d for INT8 tensor %d in XNNPACK delegate", + quantization_params->scale->size, + SizeOfDimension(&tensor, + quantization_params->quantized_dimension), + t); + return xnn_datatype_invalid; + } + break; } - } - - if (quantization_params->scale->size == 1) { - // Per-tensor quantization parameters - if (kTfLiteInt8 != tensor.type) { - TF_LITE_KERNEL_LOG(context, - "unsupported per-tensor quantization scale " - "parameter for %s tensor %d in XNNPACK delegate", - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; + case kTfLiteBlockwiseQuantization: { + const auto quantization_params = + reinterpret_cast( + tensor.quantization.params); + if (!CheckFp16Scale(context, tensor, t, quantization_params)) { + return xnn_datatype_invalid; + } + int num_scales = + NumElements(&context->tensors[quantization_params->scale]); + int num_filter_elements = NumElements(&tensor); + if (num_filter_elements / num_scales != + quantization_params->blocksize) { + TF_LITE_KERNEL_LOG(context, + "Unsupported combination of filter elements %d " + "number of scales %d and blocksize %d " + "%s tensor %d in XNNPACK delegate", + num_filter_elements, num_scales, + quantization_params->blocksize, t); + return xnn_datatype_invalid; + } + break; } - - const int zero_point = quantization_params->zero_point->data[0]; - if (zero_point < std::numeric_limits::min() || - zero_point > std::numeric_limits::max()) { + default: TF_LITE_KERNEL_LOG(context, - "unsupported zero-point value (%d) for INT8 " + "unsupported quantization type %d for %s " "tensor %d in XNNPACK delegate", - zero_point, t); + tensor.quantization.type, + TfLiteTypeGetName(tensor.type), t); return xnn_datatype_invalid; - } - return xnn_datatype_qint8; 
- } else if (NumDimensions(&tensor) >= 1 && - quantization_params->scale->size == - SizeOfDimension( - &tensor, quantization_params->quantized_dimension)) { - // Per-channel quantization parameters - for (int c = 0; - c < - SizeOfDimension(&tensor, quantization_params->quantized_dimension); - c++) { - if (quantization_params->zero_point->data[c] != 0 && - (tensor.type != kTfLiteInt4 && - quantization_params->zero_point->data[c] != 8)) { - TF_LITE_KERNEL_LOG(context, - "unsupported zero-point value %d in channel " - "%d of %s tensor %d in XNNPACK delegate", - quantization_params->zero_point->data[c], c, - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; + } + + switch (tensor.type) { + case kTfLiteInt4: + switch (tensor.quantization.type) { + case kTfLiteAffineQuantization: + return xnn_datatype_qcint4; + case kTfLiteBlockwiseQuantization: + return xnn_datatype_qbint4; + default: + TF_LITE_KERNEL_LOG( + context, + "Unsupported quantization type %zu for INT4 tensor #%d", t); + return xnn_datatype_invalid; } - } - switch (tensor.type) { - case kTfLiteInt4: - return xnn_datatype_qcint4; - case kTfLiteInt8: - return xnn_datatype_qcint8; - default: - return xnn_datatype_invalid; - } - } else { - TF_LITE_KERNEL_LOG( - context, - "mismatching number of quantization parameters %d and outer " - "dimension %d for INT8 tensor %d in XNNPACK delegate", - quantization_params->scale->size, - SizeOfDimension(&tensor, quantization_params->quantized_dimension), - t); - return xnn_datatype_invalid; + case kTfLiteInt8: + return xnn_datatype_qcint8; + default: + return xnn_datatype_invalid; } break; } @@ -1118,6 +1203,20 @@ class Subgraph { ->quantized_dimension, dims.data(), data, XNN_INVALID_VALUE_ID, flags, &xnnpack_id); break; + case xnn_datatype_qbint4: { + const auto* quantization_params = + reinterpret_cast( + context->tensors[t].quantization.params); + const TfLiteTensor& scale_tensor = + context->tensors[quantization_params->scale]; + status = xnn_define_blockwise_quantized_tensor_value_v2( + subgraph.get(), datatype, 0, + reinterpret_cast(scale_tensor.data.data), + dims.size(), quantization_params->quantized_dimension, + quantization_params->blocksize, dims.data(), data, + XNN_INVALID_VALUE_ID, flags, xnn_datatype_fp16, &xnnpack_id); + break; + } default: status = xnn_define_tensor_value( subgraph.get(), datatype, dims.size(), dims.data(), data, @@ -2184,32 +2283,59 @@ class Subgraph { case kTfLiteInt8: if (delegate.support_signed_8bit_quantization() && (kTfLiteInt8 == tensor.type || kTfLiteInt4 == tensor.type)) { - if (tensor.quantization.type != kTfLiteAffineQuantization) { - TF_LITE_MAYBE_KERNEL_LOG( - context, - "unsupported quantization type %d in tensor #%d in node #%d", - tensor.quantization.type, tensor_index, node_index); - return kTfLiteError; - } - const TfLiteAffineQuantization* quantization_params = - static_cast( - tensor.quantization.params); - if (quantization_params->scale == nullptr) { - TF_LITE_MAYBE_KERNEL_LOG(context, - "missing scale quantization parameters in " - "tensor #%d in node #%d", - tensor_index, node_index); - return kTfLiteError; - } - if (quantization_params->scale->size > 1 && - quantization_params->quantized_dimension != - expected_quantized_dimension) { - TF_LITE_MAYBE_KERNEL_LOG( - context, - "unsupported quantized dimension %d in tensor #%d in node #%d", - quantization_params->quantized_dimension, tensor_index, - node_index); - return kTfLiteError; + switch (tensor.quantization.type) { + case kTfLiteAffineQuantization: { + const 
TfLiteAffineQuantization* quantization_params = + static_cast( + tensor.quantization.params); + if (quantization_params->scale == nullptr) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "missing scale quantization parameters in " + "tensor #%d in node #%d", + tensor_index, node_index); + return kTfLiteError; + } + if (quantization_params->scale->size > 1 && + quantization_params->quantized_dimension != + expected_quantized_dimension) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "unsupported quantized dimension %d in tensor #%d in node " + "#%d", + quantization_params->quantized_dimension, tensor_index, + node_index); + return kTfLiteError; + } + break; + } + case kTfLiteBlockwiseQuantization: { + const TfLiteBlockwiseQuantization* quantization_params = + reinterpret_cast( + tensor.quantization.params); + if (quantization_params->scale == kTfLiteOptionalTensor) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "missing scale quantization parameters in " + "tensor #%d in node #%d", + tensor_index, node_index); + return kTfLiteError; + } + if (quantization_params->blocksize % 32 != 0) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "Blocksize %zu must be multiple of 32 in " + "tensor #%d in node #%d", + quantization_params->blocksize, tensor_index, node_index); + return kTfLiteError; + } + break; + } + default: + TF_LITE_MAYBE_KERNEL_LOG( + context, + "unsupported quantization type %d in tensor #%d in node #%d", + tensor.quantization.type, tensor_index, node_index); } return kTfLiteOk; } @@ -4353,14 +4479,37 @@ class Subgraph { std::vector filter_dims( &filter_tensor.dims->data[0], &filter_tensor.dims->data[NumDimensions(&filter_tensor)]); - int32_t zero_point_value = filter_params->zero_point->data[0]; uint32_t kernel_id = XNN_INVALID_VALUE_ID; - status = xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, filter_datatype, zero_point_value, - filter_params->scale->data, filter_dims.size(), /*channel_dim=*/0, - filter_dims.data(), GetTensorData(&filter_tensor), - XNN_INVALID_VALUE_ID, - /*flags=*/0, &kernel_id); + switch (filter_datatype) { + case xnn_datatype_qcint4: + case xnn_datatype_qcint8: { + int32_t zero_point_value = filter_params->zero_point->data[0]; + status = xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, filter_datatype, zero_point_value, + filter_params->scale->data, filter_dims.size(), + /*channel_dim=*/0, filter_dims.data(), + GetTensorData(&filter_tensor), XNN_INVALID_VALUE_ID, + /*flags=*/0, &kernel_id); + break; + } + case xnn_datatype_qbint4: { + const auto* quantization_params = + reinterpret_cast( + tensors[node->inputs->data[1]].quantization.params); + const TfLiteTensor& scale_tensor = + tensors[quantization_params->scale]; + status = xnn_define_blockwise_quantized_tensor_value_v2( + subgraph, filter_datatype, 0, + reinterpret_cast(scale_tensor.data.data), + filter_dims.size(), quantization_params->quantized_dimension, + quantization_params->blocksize, filter_dims.data(), + GetTensorData(&filter_tensor), XNN_INVALID_VALUE_ID, + /*flags=*/0, xnn_datatype_fp16, &kernel_id); + break; + } + default: + return kTfLiteError; + } if (status != xnn_status_success) { TF_LITE_KERNEL_LOG( logging_context, "failed to update filter tensor %s node #%d", diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index 8bfb045bc1b477..b4620c202cd674 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -120,13 +122,15 @@ TfLiteStatus VerifyPerChannelQuantization(TfLiteContext* context, TfLiteStatus VerifyQuantizationZeroPoint(const TfLiteTensor* tensor, int expected_value) { - const auto* params = - reinterpret_cast(tensor->quantization.params); - if (params && params->zero_point && - std::any_of(params->zero_point->data, - params->zero_point->data + params->zero_point->size, - [expected_value](int v) { return v != expected_value; })) { - return kTfLiteError; + if (tensor->quantization.type == kTfLiteAffineQuantization) { + const auto* params = reinterpret_cast( + tensor->quantization.params); + if (params && params->zero_point && + std::any_of(params->zero_point->data, + params->zero_point->data + params->zero_point->size, + [expected_value](int v) { return v != expected_value; })) { + return kTfLiteError; + } } return kTfLiteOk; } @@ -947,6 +951,82 @@ struct SparseHybridFullyConnectedTask : cpu_backend_threadpool::Task { TfLiteTensor* output; }; +inline int8_t SignExtendInt4(int8_t value) { return (value ^ 0x8) - 8; } + +TfLiteStatus EvalBlockwise4Bit( + TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors, + TfLiteTensor* accum_scratch, TfLiteTensor* input_offsets, + TfLiteTensor* output) { + const auto quantization_params = + static_cast( + filter->quantization.params); + + const size_t blocksize = quantization_params->blocksize; + const size_t input_channels = filter->dims->data[1]; + const size_t output_channels = filter->dims->data[0]; + const size_t batch_size = NumElements(input) / input_channels; + const size_t num_blocks = input_channels / blocksize; + const TfLiteTensor& scale = context->tensors[quantization_params->scale]; + int num_scales = NumElements(&scale); + std::vector dequantized_scale(num_scales, 0); + const Eigen::half* half_data = reinterpret_cast( + GetTensorData(&scale)); + reference_ops::Dequantize(GetTensorShape(&scale), half_data, + GetTensorShape(&scale), dequantized_scale.data()); + float* output_ptr = GetTensorData(output); + memset(output_ptr, 0, NumElements(output) * sizeof(float)); + std::vector quant_data(NumElements(input)); + std::vector input_scales(batch_size); + std::vector input_zero_points(batch_size); + + const float* input_ptr = GetTensorData(input); + tensor_utils::BatchQuantizeFloats(input_ptr, batch_size, input_channels, + quant_data.data(), input_scales.data(), + input_zero_points.data(), + /*do_asymmetric=*/true); + + const float* bias_data = nullptr; + if (bias) { + bias_data = GetTensorData(bias); + } + const size_t k2 = (input_channels + 1) & 0xFFFFFFFFFFFFFFFE; + const uint8_t* kernel = GetTensorData(filter); + for (size_t mi = 0; mi < batch_size; mi++) { + for (size_t ni = 0; ni < output_channels; ni++) { + float kfsum = 0.0; + for (size_t bi = 0; bi < 
num_blocks; bi++) { + int32_t ksum = 0; + int32_t c_ref_acc = 0; + for (size_t ki = 0; ki < blocksize; ki++) { + const size_t k_index = bi * blocksize + ki; + const size_t nb_index = (ni * k2 + k_index) / 2; + const int8_t k_value = int8_t( + (k_index % 2 == 0) ? (kernel[nb_index] & static_cast(0xF)) + : (kernel[nb_index] >> 4)); + const int32_t kernel_value = SignExtendInt4(k_value); + ksum += kernel_value; + c_ref_acc += + static_cast(quant_data[mi * input_channels + k_index]) * + static_cast(kernel_value); + } + size_t scale_index = ni * num_blocks + bi; + float scale = dequantized_scale[scale_index]; + output_ptr[mi * output_channels + ni] += c_ref_acc * scale; + kfsum += scale * ksum; + } + output_ptr[mi * output_channels + ni] -= (input_zero_points[mi] * kfsum); + output_ptr[mi * output_channels + ni] *= input_scales[mi]; + if (bias_data != nullptr) { + output_ptr[mi * output_channels + ni] += bias_data[ni]; + } + } + } + return kTfLiteOk; +} + TfLiteStatus EvalHybridDense4Bit( TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, @@ -1295,9 +1375,18 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, TF_LITE_ENSURE_OK( context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets)); if (data->op_data_4bit) { - return EvalHybridDense4Bit(context, node, params, data, input, filter, - bias, input_quantized, scaling_factors, - accum_scratch, input_offsets, output); + switch (filter->quantization.type) { + case kTfLiteAffineQuantization: + return EvalHybridDense4Bit(context, node, params, data, input, filter, + bias, input_quantized, scaling_factors, + accum_scratch, input_offsets, output); + case kTfLiteBlockwiseQuantization: + return EvalBlockwise4Bit(context, node, params, data, input, filter, + bias, input_quantized, scaling_factors, + accum_scratch, input_offsets, output); + default: + return kTfLiteError; + } } TfLiteTensor* row_sums; TF_LITE_ENSURE_OK(context, diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index 18b9e115775c37..8737699e3eaa5a 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG e67c0fbc360903f921ff286a235c18d9e12c6df6 + GIT_TAG 42ed90ba36f14321df08712e7a36713de5b2f29b GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 41d25e5099886a..7465f38d4df828 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -160,9 +160,9 @@ def _tf_repositories(): # LINT.IfChange(xnnpack) tf_http_archive( name = "XNNPACK", - sha256 = "72e4368ff3e7bdefd8b43fc6e5708b8e9fada7a8302ba2362028832df6262c13", - strip_prefix = "XNNPACK-e67c0fbc360903f921ff286a235c18d9e12c6df6", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/e67c0fbc360903f921ff286a235c18d9e12c6df6.zip"), + sha256 = "a7e47b12fb8beb0177fbd49c8dfcb842709b5a50cdf2f5bf5ec5e33d8244fcfa", + strip_prefix = "XNNPACK-42ed90ba36f14321df08712e7a36713de5b2f29b", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/42ed90ba36f14321df08712e7a36713de5b2f29b.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) @@ -184,9 +184,9 @@ def _tf_repositories(): # LINT.IfChange(pthreadpool) 
tf_http_archive( name = "pthreadpool", - sha256 = "2d56c31ebf6509d171d12ace2b543f6182ff0083ba674541515fc573738a3238", - strip_prefix = "pthreadpool-706a8ea9e4b8c2129718af195ddce7fc2573e719", - urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/706a8ea9e4b8c2129718af195ddce7fc2573e719.zip"), + sha256 = "745e56516d6a58d183eb33d9017732d87cff43ce9f78908906f9faa52633e421", + strip_prefix = "pthreadpool-b92447772365661680f486e39a91dfe6675adafc", + urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/b92447772365661680f486e39a91dfe6675adafc.zip"), ) # LINT.ThenChange(//tensorflow/lite/cmake/DownloadPThreadPool.cmake) From 2a9268035df53e9ebea33e6aef3fba5bc0747cd9 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 7 Apr 2025 05:41:08 -0700 Subject: [PATCH 0297/1324] [xla:gpu] CommandBuffer: remove default SynchronizationMode argument PiperOrigin-RevId: 744689270 --- .../backends/gpu/runtime/command_buffer_cmd.h | 5 +- .../gpu/runtime/command_buffer_cmd_emitter.cc | 76 ++++++++++--------- .../gpu/runtime/command_buffer_cmd_emitter.h | 14 ++-- .../gpu/runtime/command_buffer_cmd_test.cc | 18 +++-- .../gpu/runtime/command_buffer_thunk_test.cc | 45 ++++++----- .../xla/service/gpu/ir_emitter_unnested.cc | 10 +-- 6 files changed, 92 insertions(+), 76 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index 510c5acfe02557..d414bf74288221 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -294,10 +294,7 @@ class CommandBufferCmdSequence { Append(std::make_unique(std::forward(args)...)); } - // TODO(b/406370928): Remove default argument and make sure we correctly - // propagate synchronization mode through the codebase. - CommandBufferCmdSequence Build(SynchronizationMode synchronization_mode = - SynchronizationMode::kSerialize) &&; + CommandBufferCmdSequence Build(SynchronizationMode synchronization_mode) &&; private: std::vector> commands_; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc index de10af9c3ae6c2..3aa4c001822d22 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc @@ -52,13 +52,13 @@ namespace xla::gpu { // Appends command(s) converted from `thunk` to `cmd_sequence_builder`. static absl::Status AppendCommands( - CommandBufferCmdSequence::Builder& cmd_sequence_builder, - const Thunk& thunk); + CommandBufferCmdSequence::Builder& cmd_sequence_builder, const Thunk& thunk, + const ConvertToCommandsOptions& options); // Appends command(s) converted from `sequence` to `cmd_sequence_builder`. 
static absl::Status AppendCommands( CommandBufferCmdSequence::Builder& cmd_sequence_builder, - const ThunkSequence& sequence); + const ThunkSequence& sequence, const ConvertToCommandsOptions& options); //===----------------------------------------------------------------------===// // Conversions from Thunk to Command @@ -105,12 +105,14 @@ static absl::StatusOr Convert(const Memset32BitValueThunk& thunk) { thunk.destination(), thunk.value()); } -static absl::StatusOr Convert(const WhileThunk& thunk) { +static absl::StatusOr Convert( + const WhileThunk& thunk, const ConvertToCommandsOptions& options) { TF_ASSIGN_OR_RETURN( CommandBufferCmdSequence cond_cmds, - ConvertToCommands(thunk.condition_thunk_sequence()->thunks())); - TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence body_cmds, - ConvertToCommands(thunk.body_thunk_sequence()->thunks())); + ConvertToCommands(thunk.condition_thunk_sequence()->thunks(), options)); + TF_ASSIGN_OR_RETURN( + CommandBufferCmdSequence body_cmds, + ConvertToCommands(thunk.body_thunk_sequence()->thunks(), options)); return std::make_unique(thunk.execution_stream_id(), thunk.condition_result_buffer(), @@ -142,21 +144,24 @@ static absl::StatusOr Convert(const CublasLtMatmulThunk& thunk) { thunk.workspace().value()); } -static absl::StatusOr Convert(const ConditionalThunk& thunk) { +static absl::StatusOr Convert( + const ConditionalThunk& thunk, const ConvertToCommandsOptions& options) { std::vector branch_cmds; branch_cmds.reserve(thunk.branch_thunks().size()); if (thunk.branch_index_is_bool()) { // For boolean predicates, we need to convert the branches in reverse order // because the first branch is the "false" branch and the second is "true" CHECK_EQ(thunk.branch_thunks().size(), 2); - TF_ASSIGN_OR_RETURN(branch_cmds.emplace_back(), - ConvertToCommands(thunk.branch_thunks()[1]->thunks())); - TF_ASSIGN_OR_RETURN(branch_cmds.emplace_back(), - ConvertToCommands(thunk.branch_thunks()[0]->thunks())); + TF_ASSIGN_OR_RETURN( + branch_cmds.emplace_back(), + ConvertToCommands(thunk.branch_thunks()[1]->thunks(), options)); + TF_ASSIGN_OR_RETURN( + branch_cmds.emplace_back(), + ConvertToCommands(thunk.branch_thunks()[0]->thunks(), options)); } else { for (auto& branch_thunk : thunk.branch_thunks()) { TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence cmds, - ConvertToCommands(branch_thunk->thunks())); + ConvertToCommands(branch_thunk->thunks(), options)); branch_cmds.emplace_back(std::move(cmds)); } } @@ -189,10 +194,11 @@ static absl::StatusOr Convert(const AllGatherStartThunk& thunk) { thunk.config(), thunk.buffers()); } -static absl::StatusOr Convert(const DynamicSliceThunk& thunk) { - CommandBufferCmdSequence::Builder builder; - auto embed_thunk = thunk.get_embedded_thunk(); - TF_RETURN_IF_ERROR(AppendCommands(builder, embed_thunk->thunks())); +static absl::StatusOr Convert( + const DynamicSliceThunk& thunk, const ConvertToCommandsOptions& options) { + TF_ASSIGN_OR_RETURN( + CommandBufferCmdSequence embedded_cmds, + ConvertToCommands(thunk.get_embedded_thunk()->thunks(), options)); auto& thunk_fake_allocations = thunk.get_fake_allocations(); std::vector> fake_allocations; @@ -201,7 +207,7 @@ static absl::StatusOr Convert(const DynamicSliceThunk& thunk) { fake_allocations.push_back(std::make_unique(**it)); } return std::make_unique( - thunk.execution_stream_id(), std::move(builder).Build(), + thunk.execution_stream_id(), std::move(embedded_cmds), thunk.get_arguments(), std::move(fake_allocations), thunk.get_offsets(), thunk.get_orig_shapes(), thunk.get_sliced_shapes(), 
thunk.get_offset_byte_sizes()); @@ -247,14 +253,16 @@ static absl::StatusOr CopyMetadata(absl::StatusOr cmd, return cmd; } -template -static absl::StatusOr Convert(const Thunk& thunk) { - return CopyMetadata(Convert(static_cast(thunk)), thunk); +template +static absl::StatusOr Convert(const Thunk& thunk, Args&&... args) { + return CopyMetadata(Convert(static_cast(thunk), + std::forward(args)...), + thunk); } static absl::Status AppendCommands( - CommandBufferCmdSequence::Builder& cmd_sequence_builder, - const Thunk& thunk) { + CommandBufferCmdSequence::Builder& cmd_sequence_builder, const Thunk& thunk, + const ConvertToCommandsOptions& options) { auto append = [&](absl::StatusOr command) -> absl::Status { if (command.ok()) { cmd_sequence_builder.Append(std::move(*command)); @@ -265,7 +273,7 @@ static absl::Status AppendCommands( switch (thunk.kind()) { case Thunk::Kind::kConditional: - return append(Convert(thunk)); + return append(Convert(thunk, options)); case Thunk::Kind::kCopy: return append(Convert(thunk)); case Thunk::Kind::kCustomCall: @@ -295,18 +303,18 @@ static absl::Status AppendCommands( case Thunk::Kind::kReplicaId: return append(Convert(thunk)); case Thunk::Kind::kWhile: - return append(Convert(thunk)); + return append(Convert(thunk, options)); case Thunk::Kind::kCuDnn: return append(Convert(thunk)); case Thunk::Kind::kDynamicSlice: - return append(Convert(thunk)); + return append(Convert(thunk, options)); // Sequential thunk does not have any special semantics and we simply inline // all nested thunks into command buffer. case Thunk::Kind::kSequential: - return AppendCommands( - cmd_sequence_builder, - static_cast(thunk).thunks()); + return AppendCommands(cmd_sequence_builder, + static_cast(thunk).thunks(), + options); // Thunks that simply wait for stream events are no-op in the command buffer // context, as we convert async thunks to command dependency graph. @@ -332,17 +340,17 @@ static absl::Status AppendCommands( static absl::Status AppendCommands( CommandBufferCmdSequence::Builder& cmd_sequence_builder, - const ThunkSequence& sequence) { + const ThunkSequence& sequence, const ConvertToCommandsOptions& options) { for (const std::unique_ptr& thunk : sequence) - TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence_builder, *thunk)); + TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence_builder, *thunk, options)); return absl::OkStatus(); } absl::StatusOr ConvertToCommands( - const ThunkSequence& sequence) { + const ThunkSequence& sequence, const ConvertToCommandsOptions& options) { CommandBufferCmdSequence::Builder cmd_sequence_builder; - TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence_builder, sequence)); - return std::move(cmd_sequence_builder).Build(); + TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence_builder, sequence, options)); + return std::move(cmd_sequence_builder).Build(options.synchronization_mode); } } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h index fbe7ed9a016f5c..f3a3781b39cc6e 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h @@ -22,13 +22,15 @@ limitations under the License. namespace xla::gpu { -// Converts thunk sequence to a command buffer cmd sequence. If -// `synchronization_mode` is kSerialize, we automatically insert barriers -// between all commands in a sequence. 
Otherwise we use buffer usage aliasing to -// allow commands to run concurrently and insert barriers only when needed for -// correctness. +// Options for converting from thunks to command buffer commands. +struct ConvertToCommandsOptions { + CommandBufferCmdSequence::SynchronizationMode synchronization_mode = + CommandBufferCmdSequence::SynchronizationMode::kSerialize; +}; + +// Converts thunk sequence to a command buffer cmd sequence. absl::StatusOr ConvertToCommands( - const ThunkSequence& sequence); + const ThunkSequence& sequence, const ConvertToCommandsOptions& options); } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index 528a593db6d0a5..892e61c62c5136 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -58,9 +58,13 @@ static se::StreamExecutor* GpuExecutor() { return platform->ExecutorForDevice(0).value(); } -// Give a short aliases to execution threads. +// Give a short alias to execution thread. static constexpr auto s0 = ExecutionStreamId(0); +// Give a short alias to synchronization mode. +static constexpr auto serialize = + CommandBufferCmdSequence::SynchronizationMode::kSerialize; + // A command buffer cmd for testing automatic barriers insertion by the command // buffer cmd sequence. We never execute this command, we need it only to pass // buffer usage vector to the command buffer cmd sequence. @@ -148,7 +152,7 @@ TEST(CommandBufferCmdTest, SerializeExecution) { CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, BufferUseVector{use0}); builder.Emplace(s0, BufferUseVector{use1}); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -166,7 +170,7 @@ TEST(CommandBufferCmdTest, NoReadBarrier) { CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, BufferUseVector{use0}); builder.Emplace(s0, BufferUseVector{use1}); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -184,7 +188,7 @@ TEST(CommandBufferCmdTest, NoWriteBarrier) { CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, BufferUseVector{use0}); builder.Emplace(s0, BufferUseVector{use1}); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -205,7 +209,7 @@ TEST(CommandBufferCmdTest, WriteConflictBarrier) { builder.Emplace(s0, BufferUseVector{use0}); builder.Emplace(s0, BufferUseVector{use1}); builder.Emplace(s0, BufferUseVector{use2}); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -234,7 +238,7 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { // Prepare commands sequence for constructing command buffer. 
CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_b, slice_a, byte_length); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); ServiceExecutableRunOptions run_options; se::StreamExecutorMemoryAllocator allocator(executor); @@ -290,7 +294,7 @@ TEST(CommandBufferCmdTest, LaunchCmd) { builder.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Initialize command sequence and load device kernels. TF_ASSERT_OK_AND_ASSIGN(std::vector fatbin, diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index 2dbf17e1d7c32b..5de0bd6ee0b910 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -79,6 +79,7 @@ using MemoryAccess = BufferUse::MemoryAccess; using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; namespace { + se::StreamExecutor* GpuExecutor() { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); @@ -129,6 +130,11 @@ bool IsAtLeastCuda12300(const se::StreamExecutor* executor) { // Give a short aliases to execution threads. constexpr auto s0 = ExecutionStreamId(0); constexpr auto s1 = ExecutionStreamId(1); + +// Give a short alias to synchronization mode. +static constexpr auto serialize = + CommandBufferCmdSequence::SynchronizationMode::kSerialize; + } // namespace TEST(CommandBufferThunkTest, MemcpyCmd) { @@ -156,7 +162,7 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_b, slice_a, byte_length); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -211,7 +217,7 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_a); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -254,7 +260,7 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_a, int32_t{84}); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -304,7 +310,7 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersDisabledDuringProfiling) { // be used. 
CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_a, int32_t{12}); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); constexpr bool kProfileCommandBuffersEnabled = false; // Construct a thunk with command sequence. @@ -359,7 +365,7 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersEnabledDuringProfiling) { // be used. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_a, int32_t{12}); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); constexpr bool kProfileCommandBuffersEnabled = true; // Construct a thunk with command sequence. @@ -404,7 +410,7 @@ TEST(CommandBufferThunkTest, Memset32CmdOnDifferentStreams) { CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice0, int32_t{12}); builder.Emplace(s1, slice1, int32_t{34}); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -458,7 +464,7 @@ TEST(CommandBufferThunkTest, LaunchCmd) { builder.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -556,7 +562,7 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { builder.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -674,7 +680,7 @@ TEST(CommandBufferThunkTest, GemmCmd) { builder.Emplace(s0, config.value(), slice_lhs, slice_rhs, slice_out, slice_workspace, /*deterministic=*/true); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -801,7 +807,8 @@ TEST(CommandBufferThunkTest, DynamicSliceFusionCmd) { embed_builder.Emplace(s0, config.value(), fake_slice_lhs, slice_rhs, slice_out, slice_workspace, /*deterministic=*/true); - CommandBufferCmdSequence embed_commands = std::move(embed_builder).Build(); + CommandBufferCmdSequence embed_commands = + std::move(embed_builder).Build(serialize); BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); @@ -831,7 +838,7 @@ TEST(CommandBufferThunkTest, DynamicSliceFusionCmd) { builder.Emplace( s0, std::move(embed_commands), arguments, std::move(fake_allocations), offsets, orig_shapes, sliced_shapes, offset_byte_sizes); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. 
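// Worth noting for the DynamicSliceFusionCmd test above: nested command
// sequences are finalized independently. The embedded GEMM commands get
// their own Build(serialize) call before being moved into the enclosing
// DynamicSliceFusionCmd, and the outer sequence is then built with its own
// synchronization mode.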
CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -944,7 +951,7 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), slice_workspace); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1078,7 +1085,7 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { builder.Emplace(s0, "AddI32", args_1, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1202,13 +1209,13 @@ TEST(CommandBufferThunkTest, CaseCmd) { } std::vector branches(2); - branches[0] = std::move(branches_builder[0]).Build(); - branches[1] = std::move(branches_builder[1]).Build(); + branches[0] = std::move(branches_builder[0]).Build(serialize); + branches[1] = std::move(branches_builder[1]).Build(serialize); // Prepare commands sequence for thunk. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_i, false, std::move(branches)); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1296,7 +1303,7 @@ TEST(CommandBufferThunkTest, WhileCmd) { s0, "IncAndCmp", cond_args, cond_args_access, LaunchDimensions(1, 1), /*shmem_bytes=*/0); CommandBufferCmdSequence cond_commands = - std::move(cond_commands_builder).Build(); + std::move(cond_commands_builder).Build(serialize); // Prepare commands sequence for loop `body`. CommandBufferCmdSequence::Builder body_commands_builder; @@ -1304,13 +1311,13 @@ TEST(CommandBufferThunkTest, WhileCmd) { s0, "AddI32", body_args, body_args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); CommandBufferCmdSequence body_commands = - std::move(body_commands_builder).Build(); + std::move(body_commands_builder).Build(serialize); // Prepare commands sequence for thunk. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_pred, std::move(cond_commands), std::move(body_commands)); - CommandBufferCmdSequence commands = std::move(builder).Build(); + CommandBufferCmdSequence commands = std::move(builder).Build(serialize); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index c3bed53294e36d..d9e09ee69ea464 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -573,12 +573,10 @@ absl::Status IrEmitterUnnested::EmitCommandBufferThunk( ? CommandBufferCmdSequence::SynchronizationMode::kAutomatic : CommandBufferCmdSequence::SynchronizationMode::kSerialize; - // TODO(b/406370928): Use `synchronization_mode` to construct a command buffer - // cmd sequence with specified synchronization mode. 
- (void)synchronization_mode; - - TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence cmd_sequence, - ConvertToCommands(thunk_sequence->thunks())); + TF_ASSIGN_OR_RETURN( + CommandBufferCmdSequence cmd_sequence, + ConvertToCommands(thunk_sequence->thunks(), + ConvertToCommandsOptions{synchronization_mode})); AddThunkToThunkSequence(std::make_unique( std::move(cmd_sequence), Thunk::ThunkInfo::WithProfileAnnotation(instr), From b3d680f5fc4abd188d0476f38cfee053ff4967a2 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Mon, 7 Apr 2025 08:03:09 -0700 Subject: [PATCH 0298/1324] [XLA:GPU/TMA] Move TritonXLAExtractInsertToTriton pass to happen during compiling Triton to LLVM. Consequently, tests also needed some updates. PiperOrigin-RevId: 744723981 --- .../gpu/codegen/triton/fusion_emitter.cc | 10 +- .../fusion_emitter_device_legacy_port_test.cc | 2 +- .../triton/fusion_emitter_device_test.cc | 220 +++++++----------- .../triton_xla_extract_insert_to_triton.mlir | 78 +++++++ third_party/xla/xla/service/gpu/tests/BUILD | 1 + .../xla/xla/service/gpu/tests/xla-opt.cc | 3 +- 6 files changed, 178 insertions(+), 136 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index f183568ae2a36d..8467a8d7174943 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -1360,11 +1360,12 @@ absl::StatusOr CreateTileOp( const Shape& shape = tiled_hlo.hlo()->shape(); TF_ASSIGN_OR_RETURN(Type expected_element_type, TritonType(b, shape.element_type())); + Type storage_type = StorageType(b, expected_element_type); auto result_type = mtx::TiledTensorType::get( b.getContext(), padded_tile_sizes, llvm::ArrayRef(shape.dimensions().data(), shape.dimensions().size()), - expected_element_type); + storage_type); return b.create( result_type, parent_base_ptr, ptr_offsets, @@ -1698,9 +1699,6 @@ absl::StatusOr CreateTritonModule( mlir::PassManager pm(&mlir_context); - // TODO(b/315957220): Pass device and tma_flag to the pass. - pm.addPass(mlir::triton::xla::CreateTritonXLAExtractInsertToTritonPass()); - pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); if (mlir::failed(pm.run(triton_module.get()))) { @@ -1800,6 +1798,10 @@ absl::StatusOr CompileTritonToLLVM( mlir::PassManager pm(&mlir_context); pm.enableVerifier(should_verify); + // TODO(b/315957220): Propagate TMA flag once it's supported. + pm.addPass(mlir::triton::xla::CreateTritonXLAExtractInsertToTritonPass( + device_info, /*tma_enabled=*/false)); + // Lower affine expressions into arithmetic ops. 
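// For orientation, the pass ordering after this move (a sketch inferred
// from the hunks above, for illustration only; names are the passes as
// added there):
//
//   CreateTritonModule:   createCanonicalizerPass() -> createCSEPass()
//   CompileTritonToLLVM:  CreateTritonXLAExtractInsertToTritonPass(
//                             device_info, /*tma_enabled=*/false)
//                         -> createLowerAffinePass() -> remaining lowering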
pm.addPass(mlir::createLowerAffinePass()); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index c58c767ccbf8da..c07585d8b6e052 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -269,7 +269,7 @@ ENTRY e { CreateTritonIrAndFileCheck(*module_and_metadata.computation, module_and_metadata.block_level_parameters, R"( -CHECK: %[[LOAD:.*]] = tt.load %{{.*}} {{.*}} : !tt.ptr> +CHECK: %[[LOAD:.*]] = triton_xla.extract {{.*}} : tensor<2x2xi8> to tensor<16x16xi8> CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1> CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1> )")); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index b8c353bbeeb689..ffe7033ef1ede8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -147,11 +147,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "fused_computation", R"( -CHECK-COUNT-1: tt.load +CHECK-COUNT-1: triton_xla.extract CHECK: %[[ABS:.*]] = math.absf CHECK: %[[REDUCE:.*]] = "tt.reduce"(%[[ABS:.*]]) <{axis = 1 : i32}> -CHECK: tt.store %{{.*}}, %[[REDUCE]] : !tt.ptr> -CHECK: tt.store %{{.*}}, %[[ABS]] : !tt.ptr> +CHECK: triton_xla.insert %[[REDUCE]] {{.*}} : tensor<64xf32> into tensor<128xf32> +CHECK: triton_xla.insert %[[ABS]] {{.*}} : tensor<64x512xf32> into tensor<128x512xf32> )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); } @@ -189,11 +189,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "fused_computation", R"( -CHECK-COUNT-1: tt.load +CHECK-COUNT-1: triton_xla.extract CHECK: %[[ABS:.*]] = math.absf CHECK: %[[REDUCE:.*]] = "tt.reduce"(%[[ABS:.*]]) <{axis = 0 : i32}> -CHECK: tt.store %{{.*}}, %[[REDUCE]] : !tt.ptr -CHECK: tt.store %{{.*}}, %[[ABS]] : !tt.ptr> +CHECK: tensor.insert %[[REDUCE]] {{.*}} : tensor +CHECK: triton_xla.insert %[[ABS]] {{.*}} : tensor<512xf32> into tensor<512xf32> )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); } @@ -224,14 +224,6 @@ ENTRY entry_computation { "num_ctas":"1", "num_stages":"1"}}} })"; - TF_EXPECT_OK( - CreateTritonIrAndFileCheck(this, kHloText, "fused_computation", R"( -CHECK-COUNT-1: tt.load -CHECK: tt.store -CHECK-SAME: {boundaryCheck = array} -CHECK: tt.store -CHECK-SAME: {boundaryCheck = array} -)")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); } @@ -260,14 +252,6 @@ ENTRY entry_computation { "num_ctas":"1", "num_stages":"1"}}} })"; - TF_EXPECT_OK( - CreateTritonIrAndFileCheck(this, kHloText, "fused_computation", R"( -CHECK-COUNT-1: tt.load -CHECK: tt.store -CHECK-SAME: {boundaryCheck = array} -CHECK: tt.store -CHECK-NOT: {boundaryCheck = array} -)")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); } @@ -305,9 +289,9 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "fused_computation", R"( -CHECK-COUNT-1: tt.load +CHECK-COUNT-1: triton_xla.extract CHECK: tt.reduce -CHECK-COUNT-2: tt.store +CHECK-COUNT-2: triton_xla.insert )")); 
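// The recurring rewrite these test expectations encode (a sketch assembled
// from the CHECK lines above, not verbatim compiler output): tiled reads
// and writes that previously lowered through block pointers
// (tt.make_tensor_ptr + tt.load / tt.store) are now expressed directly on
// tensors, e.g.
//
//   %t0 = triton_xla.tile %p0[%pid, %c0][%c1, %c1]
//       : !triton_xla.tiled_tensor<1x128|125x127xf32>
//   %tile = triton_xla.extract %t0[%c0, %c0]
//       : tensor<125x127xf32> to tensor<1x128xf32>
//   triton_xla.insert %result into %t1[%c0, %c0]
//       : tensor<1x128xf32> into tensor<125x127xf32>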
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); } @@ -423,7 +407,7 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_reduction_computation", R"( ; Make sure input reduction tile is padded with a neutral value. -CHECK: %[[LOAD:.*]] = tt.load +CHECK: %[[LOAD:.*]] = triton_xla.extract CHECK: %[[RANGE:.*]] = tt.make_range CHECK: %[[EXPAND:.*]] = tt.expand_dims %[[RANGE]] CHECK: %[[BROADCAST:.*]] = tt.broadcast %[[EXPAND]] @@ -472,21 +456,14 @@ ENTRY main { "num_stages":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> -CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { -CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 -CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 -CHECK-DAG: %[[C127:.*]] = arith.constant 127 : i64 +CHECK: func.func @triton_fn(%[[P0:.*]]: {{.*}}, %[[P1:.*]]: {{.*}}) +CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 CHECK-DAG: %[[PID_I64:.*]] = arith.extsi %[[PID]] : i32 to i64 CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID_I64]] : i64 to index -CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C125]], %[[PID_I64]] : i64 -CHECK-DAG: %[[OFFSET_IDX:.*]] = xla.apply_indexing #indexing_map(%[[PID_INDEX]]) -CHECK-DAG: %[[OFFSET_I64:.*]] = arith.index_castui %[[OFFSET_IDX]] : index to i64 -CHECK-DAG: %[[BASE_PTR_LOAD:.*]] = tt.addptr %[[P0]], %[[OFFSET_I64]] : !tt.ptr, i64 -CHECK-DAG: tt.make_tensor_ptr %[[BASE_PTR_LOAD]], [%[[SUB]], %[[C127]]], {{.*}} [%[[ZERO]], %[[ZERO]]] {order = array} : > -CHECK-NEXT: tt.load -CHECK-SAME: {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> +CHECK-DAG: triton_xla.tile %[[P0]][%[[PID_INDEX]], %[[C0]]][%[[C1]], %[[C1]]] {layout = array} : !triton_xla.tiled_tensor<1x128|125x127xf32> +CHECK-NEXT: triton_xla.extract CHECK: tt.reduce CHECK-NEXT: ^bb0(%[[ARG2:[^:]*]]: f32, %[[ARG3:[^:]*]]: f32): CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG2]], %[[ARG3]] : f32 @@ -494,11 +471,9 @@ CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 CHECK-NEXT: }) : (tensor<1x128xf32>) -> tensor<1xf32> CHECK: arith.mulf CHECK-SAME: tensor<1x128xf32> -CHECK-DAG: %[[BASE_PTR_STORE:.*]] = tt.addptr %[[P1]], %[[OFFSET_I64]] : !tt.ptr, i64 -CHECK: tt.make_tensor_ptr %[[BASE_PTR_STORE]], [%[[SUB]], %[[C127]]], {{.*}} [%[[ZERO]], %[[ZERO]]] {order = array} : > -CHECK-NEXT: tt.store -CHECK-SAME: {boundaryCheck = array} : !tt.ptr> -CHECK: tt.return +CHECK: triton_xla.tile %[[P1]][%[[PID_INDEX]], %[[C0]]][%[[C1]], %[[C1]]] {layout = array} : !triton_xla.tiled_tensor<1x128|125x127xf32> +CHECK-NEXT: triton_xla.insert +CHECK: return CHECK: } )")); } @@ -540,34 +515,28 @@ ENTRY main { "num_stages":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> -CHECK: tt.func @triton_fn( -CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr -CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr -CHECK-SAME: %[[P2:[A-Za-z0-9_]*]]: !tt.ptr -CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 -CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 -CHECK-DAG: %[[C127:.*]] = arith.constant 127 : i64 +CHECK: func.func @triton_fn( +CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: tensor<125x127xf32> +CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: tensor<127xf32> 
+CHECK-SAME: %[[P2:[A-Za-z0-9_]*]]: tensor<125x127xf32> +CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 CHECK-DAG: %[[PID_I64:.*]] = arith.extsi %[[PID]] : i32 to i64 CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID_I64]] : i64 to index -CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C125]], %[[PID_I64]] : i64 -CHECK-DAG: %[[OFFSET_IDX:.*]] = xla.apply_indexing #indexing_map(%[[PID_INDEX]]) -CHECK-DAG: %[[OFFSET_I64:.*]] = arith.index_castui %[[OFFSET_IDX]] : index to i64 -CHECK-DAG: %[[BASE_PTR0_LOAD:.*]] = tt.addptr %[[P0]], %[[OFFSET_I64]] : !tt.ptr, i64 -CHECK-DAG: tt.make_tensor_ptr %[[BASE_PTR0_LOAD]], [%[[SUB]], %[[C127]]], {{.*}} [%[[ZERO]], %[[ZERO]]] {order = array} : > -CHECK-NEXT: tt.load {{.*}} : !tt.ptr> -CHECK-DAG: tt.make_tensor_ptr %[[P1]], [%[[C127]]], {{.*}} [%[[ZERO]]] {order = array} : > -CHECK-NEXT: tt.load {{.*}} : !tt.ptr> +CHECK-DAG: %[[TILE_0:.*]] = triton_xla.tile %[[P0]][%[[PID_INDEX]], %[[C0]]][%[[C1]], %[[C1]]] {layout = array} : !triton_xla.tiled_tensor<1x128|125x127xf32> +CHECK-DAG: triton_xla.extract %[[TILE_0]][%[[C0]], %[[C0]]] : tensor<125x127xf32> to tensor<1x128xf32> +CHECK-DAG: %[[TILE_1:.*]] = triton_xla.tile %[[P1]][%[[C0]]][%[[C1]]] {layout = array} : !triton_xla.tiled_tensor<128|127xf32> +CHECK-DAG: triton_xla.extract %[[TILE_1]][%[[C0]]] : tensor<127xf32> to tensor<128xf32> CHECK: tt.reduce CHECK-NEXT: ^bb0(%[[ARG3:[^:]*]]: f32, %[[ARG4:[^:]*]]: f32): CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 CHECK-NEXT: }) : (tensor<1x128xf32>) -> tensor<1xf32> CHECK: arith.mulf -CHECK-DAG: %[[BASE_PTR2_LOAD:.*]] = tt.addptr %[[P2]], %[[OFFSET_I64]] : !tt.ptr, i64 -CHECK-DAG: tt.make_tensor_ptr %[[BASE_PTR2_LOAD]], [%[[SUB]], %[[C127]]], {{.*}} [%[[ZERO]], %[[ZERO]]] {order = array} : > -CHECK-DAG: tt.store {{.*}} : !tt.ptr> + +CHECK-DAG: %[[TILE_2:.*]] = triton_xla.tile %[[P2]][%[[PID_INDEX]], %[[C0]]][%[[C1]], %[[C1]]] {layout = array} : !triton_xla.tiled_tensor<1x128|125x127xf32> +CHECK-DAG: triton_xla.insert {{.*}} into %[[TILE_2]][%[[C0]], %[[C0]]] : tensor<1x128xf32> into tensor<125x127xf32> )")); } @@ -615,39 +584,30 @@ ENTRY main { "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249]"> CHECK: #[[MAP1:.*]] = #xla.indexing_map<"(d0) -> (d0 mod 125), domain: d0 in [0, 1249]"> -CHECK: #[[MAP2:.*]] = #xla.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 1249]"> -CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { -CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 -CHECK-DAG: %[[C10:.*]] = arith.constant 10 : i64 -CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 -CHECK-DAG: %[[C127:.*]] = arith.constant 127 : i64 +CHECK: func.func @triton_fn(%[[P0:.*]]: {{.*}}, %[[P1:.*]]: {{.*}}, %[[P2:.*]]: {{.*}}, %[[P3:.*]]: {{.*}}) +CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 CHECK-DAG: %[[PID_I64:.*]] = arith.extsi %[[PID]] : i32 to i64 CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID_I64]] : i64 to index CHECK-DAG: %[[ROW_INDEX:.*]] = xla.apply_indexing #[[MAP]](%[[PID_INDEX]] CHECK-DAG: %[[COL_INDEX:.*]] = xla.apply_indexing 
#[[MAP1]](%[[PID_INDEX]] -CHECK-DAG: %[[ROW_64:.*]] = arith.index_castui %[[ROW_INDEX]] : index to i64 -CHECK-DAG: %[[COL_64:.*]] = arith.index_castui %[[COL_INDEX]] : index to i64 -CHECK-DAG: %[[ROW_SUB:.*]] = arith.subi %[[C10]], %[[ROW_64]] : i64 -CHECK-DAG: %[[COL_SUB:.*]] = arith.subi %[[C125]], %[[COL_64]] : i64 -CHECK-DAG: %[[OFFSET_IDX:.*]] = xla.apply_indexing #[[MAP2]](%[[PID_INDEX]]) -CHECK-DAG: %[[OFFSET_I64:.*]] = arith.index_castui %[[OFFSET_IDX]] : index to i64 -CHECK-DAG: %[[BASE_PTR0_LOAD:.*]] = tt.addptr %[[P0]], %[[OFFSET_I64]] : !tt.ptr, i64 -CHECK-DAG: tt.make_tensor_ptr %[[BASE_PTR0_LOAD]], [%[[ROW_SUB]], %[[COL_SUB]], %[[C127]]], {{.*}} [%[[ZERO]], %[[ZERO]], %[[ZERO]]] {order = array} : > -CHECK-NEXT: tt.load {{.*}} : !tt.ptr> -CHECK-DAG: tt.make_tensor_ptr %[[P1]], [%[[C127]]], {{.*}} [%[[ZERO]]] {order = array} : > -CHECK-NEXT: tt.load {{.*}} : !tt.ptr> -CHECK-DAG: %[[BASE_PTR2_LOAD:.*]] = tt.addptr %[[P2]], %[[PID_I64]] : !tt.ptr, i64 -CHECK-DAG: tt.make_tensor_ptr %[[BASE_PTR2_LOAD]], [%[[ROW_SUB]], %[[COL_SUB]]], {{.*}} [%[[ZERO]], %[[ZERO]]] {order = array} : > -CHECK-NEXT: tt.load {{.*}} : !tt.ptr> + +CHECK-DAG: triton_xla.tile %[[P0]][%[[ROW_INDEX]], %[[COL_INDEX]], %[[C0]]][%[[C1]], %[[C1]], %[[C1]]] {layout = array} : !triton_xla.tiled_tensor<1x1x128|10x125x127xf32> +CHECK-NEXT: triton_xla.extract {{.*}} : tensor<10x125x127xf32> to tensor<1x1x128xf32> +CHECK-DAG: triton_xla.tile %[[P1]][%[[C0]]][%[[C1]]] {layout = array} : !triton_xla.tiled_tensor<128|127xf32> +CHECK-NEXT: triton_xla.extract {{.*}} : tensor<127xf32> to tensor<128xf32> + +CHECK-DAG: triton_xla.tile %[[P2]][%[[ROW_INDEX]], %[[COL_INDEX]]][%[[C1]], %[[C1]]] {layout = array} : !triton_xla.tiled_tensor<1x1|10x125xf32> +CHECK-NEXT: triton_xla.extract {{.*}} : tensor<10x125xf32> to tensor<1x1xf32> CHECK: tt.reduce CHECK-NEXT: ^bb0(%[[ARG4:[^:]*]]: f32, %[[ARG5:[^:]*]]: f32): CHECK-NEXT: %[[MAX:.*]] = arith.maximumf %[[ARG4]], %[[ARG5]] : f32 CHECK-NEXT: tt.reduce.return %[[MAX]] : f32 CHECK-NEXT: }) : (tensor<1x1x128xf32>) -> tensor<1x1xf32> -CHECK-DAG: %[[BASE_PTR3_STORE:.*]] = tt.addptr %[[P3]], %[[OFFSET_I64]] : !tt.ptr, i64 -CHECK-DAG: tt.make_tensor_ptr %[[BASE_PTR3_STORE]], [%[[ROW_SUB]], %[[COL_SUB]], %[[C127]]], {{.*}} [%[[ZERO]], %[[ZERO]], %[[ZERO]]] {order = array} : > -CHECK-NEXT: tt.store {{.*}} : !tt.ptr> + +CHECK-DAG: triton_xla.tile %[[P3]][%[[ROW_INDEX]], %[[COL_INDEX]], %[[C0]]][%[[C1]], %[[C1]], %[[C1]]] {layout = array} : !triton_xla.tiled_tensor<1x1x128|10x125x127xf32> +CHECK-NEXT: triton_xla.insert {{.*}} : tensor<1x1x128xf32> into tensor<10x125x127xf32> )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -834,13 +794,13 @@ ENTRY main { "triton_softmax_computation", R"( // CHECK: #xla.indexing_map<"(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047]"> // CHECK: #xla.indexing_map<"(d0) -> (d0 mod 32), domain: d0 in [0, 2047]"> -// CHECK-LABEL: tt.func @triton_fn( -// CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr -// CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr -// CHECK-SAME: %[[P2:[A-Za-z0-9_]*]]: !tt.ptr -// CHECK-DAG: tt.load {{.*}} : !tt.ptr -// CHECK-DAG: tt.load {{.*}} : !tt.ptr> -// CHECK: tt.store {{.*}} : !tt.ptr> +// CHECK-LABEL: func.func @triton_fn( +// CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: tensor<64x32x16xf32> +// CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: tensor +// CHECK-SAME: %[[P2:[A-Za-z0-9_]*]]: tensor<64x32x16xf32> +// CHECK-DAG: tensor.extract {{.*}} : tensor +// CHECK-DAG: triton_xla.extract {{.*}} : tensor<64x32x16xf32> to tensor<1x1x16xf32> +// CHECK: 
triton_xla.insert {{.*}} : tensor<1x1x16xf32> into tensor<64x32x16xf32> )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1033,15 +993,15 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_reduction_computation", R"( -CHECK: tt.func @triton_fn(%[[P0:[A-Za-z0-9_]*]]: !tt.ptr -CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr -CHECK-SAME: %[[P2:[A-Za-z0-9_]*]]: !tt.ptr -CHECK-DAG: tt.load {{.*}} : !tt.ptr> -CHECK-DAG: tt.load {{.*}} : !tt.ptr> +CHECK: func.func @triton_fn(%[[P0:[A-Za-z0-9_]*]]: tensor<125x127xf32> +CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: tensor<125xf32> +CHECK-SAME: %[[P2:[A-Za-z0-9_]*]]: tensor<125xf32> +CHECK-DAG: triton_xla.extract {{.*}} : tensor<125xf32> to tensor<1xf32> +CHECK-DAG: triton_xla.extract {{.*}} : tensor<125x127xf32> to tensor<1x128xf32> CHECK: tt.reduce CHECK: (tensor<1x128xf32>) -> tensor<1xf32> CHECK: arith.mulf {{.*}} tensor<1xf32> -CHECK: tt.store {{.*}} : !tt.ptr> +CHECK: triton_xla.insert {{.*}} : tensor<1xf32> into tensor<125xf32> )")); } @@ -1270,11 +1230,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load +CHECK: triton_xla.extract CHECK-NOT: tt.trans CHECK: tt.reshape CHECK-NOT: tt.trans -CHECK: tt.store +CHECK: triton_xla.insert )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1301,11 +1261,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load +CHECK: triton_xla.extract CHECK: tt.trans CHECK: tt.reshape CHECK-NOT: tt.trans -CHECK: tt.store +CHECK: triton_xla.insert )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1332,11 +1292,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load +CHECK: triton_xla.extract CHECK-NOT: tt.trans CHECK: tt.reshape CHECK: tt.trans -CHECK: tt.store +CHECK: triton_xla.insert )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1364,11 +1324,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load +CHECK: triton_xla.extract CHECK: tt.trans CHECK: tt.reshape CHECK: tt.trans -CHECK: tt.store +CHECK: triton_xla.insert )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1395,11 +1355,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load +CHECK: triton_xla.extract CHECK: tt.trans CHECK-NOT: tt.reshape CHECK-NOT: tt.trans -CHECK: tt.store +CHECK: triton_xla.insert )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1441,10 +1401,10 @@ ENTRY main { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load +CHECK: triton_xla.extract CHECK: tt.reduce CHECK: tt.broadcast -CHECK: tt.store +CHECK: triton_xla.insert )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1504,7 +1464,7 @@ ENTRY main { CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( CHECK: %[[CASTED_OUT:.*]] = arith.extui CHECK-SAME: tensor<4xi1> to tensor<4xi8> -CHECK: tt.store {{.*}} %[[CASTED_OUT]] +CHECK: triton_xla.insert %[[CASTED_OUT]] )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1543,7 +1503,7 @@ ENTRY main { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: 
%[[I8_PARAM:.*]] = tt.load {{.*}} : !tt.ptr> +CHECK: %[[I8_PARAM:.*]] = triton_xla.extract {{.*}} : tensor<15xi8> to tensor<4xi8> CHECK: arith.trunci %[[I8_PARAM]] : tensor<4xi8> to tensor<4xi1> )")); @@ -1573,7 +1533,7 @@ ENTRY main { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: %[[TILE:.*]] = tt.load {{.*}} : !tt.ptr> +CHECK: %[[TILE:.*]] = triton_xla.extract {{.*}} : tensor<15x7x3xf32> to tensor<8x4x1xf32> CHECK: tt.trans %[[TILE]] {order = array} : tensor<8x4x1xf32> -> tensor<1x8x4xf32> )")); @@ -1606,11 +1566,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "fused_computation", R"( -CHECK: %[[TILE:.*]] = tt.load {{.*}} : !tt.ptr> -CHECK-NOT: tt.load +CHECK: %[[TILE:.*]] = triton_xla.extract {{.*}} : tensor<15x7x3xf32> to tensor<8x4x1xf32> +CHECK-NOT: triton_xla.extract CHECK: %[[ABS:.*]] = math.absf %[[TILE]] CHECK: tt.trans %[[ABS]] {order = array} : tensor<8x4x1xf32> -> tensor<1x8x4xf32> -CHECK-COUNT-2: tt.store +CHECK-COUNT-2: triton_xla.insert )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1739,7 +1699,7 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load +CHECK: triton_xla.extract )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -1777,10 +1737,10 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load {{.*}} !tt.ptr +CHECK: tensor.extract {{.*}} tensor CHECK: tt.extern_elementwise {{.*}} (f32) -> f32 CHECK: arith.subf {{.*}} f32 -CHECK: tt.load {{.*}} !tt.ptr +CHECK: tensor.extract {{.*}} tensor CHECK: tt.extern_elementwise {{.*}} (f32) -> f32 CHECK: arith.subf {{.*}} f32 CHECK: arith.addf {{.*}} f32 @@ -1788,7 +1748,7 @@ CHECK: arith.mulf {{.*}} f32 CHECK: arith.divf {{.*}} f32 CHECK: arith.truncf {{.*}} f32 to bf16 CHECK: arith.subf {{.*}} bf16 -CHECK: tt.store {{.*}} !tt.ptr +CHECK: tensor.insert {{.*}} tensor )")); EXPECT_TRUE(RunAndCompareNoHloPasses( @@ -1829,11 +1789,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load +CHECK: tensor.extract CHECK: tt.splat CHECK: arith.addf CHECK: tt.reduce -CHECK: tt.store {{.*}} !tt.ptr +CHECK: tensor.insert {{.*}} tensor )")); EXPECT_TRUE(RunAndCompareNoHloPasses( @@ -1869,7 +1829,7 @@ CHECK: tt.reshape CHECK: tt.reduce{{.*}}axis = 0 CHECK-NOT: tt.reshape CHECK: tt.reduce{{.*}}axis = 0 -CHECK: tt.store {{.*}} !tt.ptr +CHECK: tensor.insert {{.*}} tensor )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{0, 0})); @@ -1909,11 +1869,11 @@ ENTRY entry_computation { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.load +CHECK: triton_xla.extract CHECK: tt.reshape CHECK: tt.reduce CHECK: tt.reduce -CHECK: tt.store +CHECK: triton_xla.insert )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); @@ -2161,21 +2121,21 @@ ENTRY entry { "num_stages":"1"}}} })"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "fdot", R"( -CHECK: tt.func @triton_fn(%[[ARG0:[A-Za-z0-9_]*]]: !tt.ptr -CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: !tt.ptr -CHECK-SAME: %[[ARG2:[A-Za-z0-9_]*]]: !tt.ptr +CHECK: func.func @triton_fn(%[[ARG0:[A-Za-z0-9_]*]]: tensor<32x256xf32> +CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<256x512xf32> +CHECK-SAME: %[[ARG2:[A-Za-z0-9_]*]]: tensor<32x512xf32> CHECK-DAG: %[[C0:.*]] = 
arith.constant 0 : i64 CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i64 CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i64 CHECK: {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C8]] step %[[C1]] CHECK-SAME: iter_args({{.*}}) -> (tensor<16x64xf32>) : i64 { -CHECK-DAG: tt.addptr %[[ARG0]] -CHECK-DAG: tt.addptr %[[ARG1]] +CHECK-DAG: triton_xla.tile %[[ARG0]] +CHECK-DAG: triton_xla.tile %[[ARG1]] CHECK-DAG: arith.subf {{.*}} : tensor<16x32xf32> CHECK-DAG: math.absf {{.*}} : tensor<32x64xf32> CHECK-DAG: tt.dot {{.*}} tensor<16x32xf32> * tensor<32x64xf32> -> tensor<16x64xf32> CHECK: scf.yield {{.*}} : tensor<16x64xf32> -CHECK-COUNT-1: tt.store +CHECK-COUNT-1: triton_xla.insert )")); EXPECT_TRUE(RunAndCompareNoHloPasses( kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir index c0367fcf859649..a4d1ce9e9f59b5 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir @@ -93,3 +93,81 @@ func.func @incompatible_tma_shapes(%arg0: tensor<1000x1000xbf16>, // CHECK-TMA: tt.load // CHECK-TMA: tt.advance // CHECK-TMA: tt.store + +// ----- + +#indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]"> +module { + func.func @slice_with_tiling_that_needs_padding_has_boundary_checks( + %arg0: tensor<64xf32>, %arg1: tensor<63xf32>, %arg2: tensor<63xf32>) + -> (tensor<63xf32>, tensor<63xf32>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.extsi %0 : i32 to i64 + %2 = arith.index_castui %1 : i64 to index + %3 = xla.apply_indexing #indexing_map(%2) + %tiled_tensor = triton_xla.tile %arg0[%3][%c1] {layout = array} + : !triton_xla.tiled_tensor<32|64xf32> + %extracted_tile = triton_xla.extract %tiled_tensor[%c0] + : tensor<64xf32> to tensor<32xf32> + %4 = math.absf %extracted_tile : tensor<32xf32> + %5 = arith.subf %cst, %4 : tensor<32xf32> + %tiled_tensor_0 = triton_xla.tile %arg1[%3][%c1] {layout = array} + : !triton_xla.tiled_tensor<32|63xf32> + %inserted_tile = triton_xla.insert %5 into %tiled_tensor_0[%c0] + : tensor<32xf32> into tensor<63xf32> + %tiled_tensor_1 = triton_xla.tile %arg2[%3][%c1] {layout = array} + : !triton_xla.tiled_tensor<32|63xf32> + %inserted_tile_2 = triton_xla.insert %4 into %tiled_tensor_1[%c0] + : tensor<32xf32> into tensor<63xf32> + return %inserted_tile, %inserted_tile_2 : tensor<63xf32>, tensor<63xf32> + } +} + +// CHECK-LABEL: func @slice_with_tiling_that_needs_padding_has_boundary_checks +// CHECK-COUNT-1: tt.load +// CHECK: tt.store +// CHECK-SAME: boundaryCheck = array +// CHECK: tt.store +// CHECK-SAME: boundaryCheck = array + +// ----- + +#indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]"> +module { + func.func @slice_with_extra_output_that_can_reuse_tile_due_to_padding( + %arg0: tensor<64xf32>, %arg1: tensor<63xf32>, %arg2: tensor<64xf32>) + -> (tensor<63xf32>, tensor<64xf32>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.extsi %0 : i32 to i64 + %2 = arith.index_castui %1 : i64 to index 
+ %3 = xla.apply_indexing #indexing_map(%2)
+ %tiled_tensor = triton_xla.tile %arg0[%3][%c1] {layout = array}
+ : !triton_xla.tiled_tensor<32|64xf32>
+ %extracted_tile = triton_xla.extract %tiled_tensor[%c0]
+ : tensor<64xf32> to tensor<32xf32>
+ %4 = math.absf %extracted_tile : tensor<32xf32>
+ %5 = arith.subf %cst, %4 : tensor<32xf32>
+ %tiled_tensor_0 = triton_xla.tile %arg1[%3][%c1] {layout = array}
+ : !triton_xla.tiled_tensor<32|63xf32>
+ %inserted_tile = triton_xla.insert %5 into %tiled_tensor_0[%c0]
+ : tensor<32xf32> into tensor<63xf32>
+ %tiled_tensor_1 = triton_xla.tile %arg2[%3][%c1] {layout = array}
+ : !triton_xla.tiled_tensor<32|64xf32>
+ %inserted_tile_2 = triton_xla.insert %4 into %tiled_tensor_1[%c0]
+ : tensor<32xf32> into tensor<64xf32>
+ return %inserted_tile, %inserted_tile_2 : tensor<63xf32>, tensor<64xf32>
+ }
+}
+
+// CHECK-LABEL: func @slice_with_extra_output_that_can_reuse_tile_due_to_padding
+// CHECK-COUNT-1: tt.load
+// CHECK: tt.store
+// CHECK-SAME: boundaryCheck = array
+// CHECK: tt.store
+// CHECK-NOT: boundaryCheck = array
diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD
index d107b979fe02e0..ff6bd494ba3d95 100644
--- a/third_party/xla/xla/service/gpu/tests/BUILD
+++ b/third_party/xla/xla/service/gpu/tests/BUILD
@@ -642,6 +642,7 @@ lit_test_suite(
# "//xla/backends/gpu/codegen/emitters/transforms:passes",
# "//xla/backends/gpu/codegen/triton/ir:triton_xla",
# "//xla/backends/gpu/codegen/triton/transforms:passes",
+# "//xla/codegen/emitters/ir:xla",
# "//xla/codegen/emitters/transforms:passes",
# "@triton//:AllPassesAndDialects",
# "@triton//third_party/amd:TestAMDRangeAnalysis",
diff --git a/third_party/xla/xla/service/gpu/tests/xla-opt.cc b/third_party/xla/xla/service/gpu/tests/xla-opt.cc
index 358490f6798d44..f4b3fcc1c71839 100644
--- a/third_party/xla/xla/service/gpu/tests/xla-opt.cc
+++ b/third_party/xla/xla/service/gpu/tests/xla-opt.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "xla/backends/gpu/codegen/emitters/transforms/passes.h"
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h"
#include "xla/backends/gpu/codegen/triton/transforms/passes.h"
+#include "xla/codegen/emitters/ir/xla_ops.h"
#include "xla/codegen/emitters/transforms/passes.h"
#include "third_party/triton/bin/RegisterTritonDialects.h"
@@ -28,7 +29,7 @@ int main(int argc, char **argv) {
mlir::registerAllExtensions(registry);
registerTritonDialects(registry); // This registers all passes as well.
registry.insert();
+ mlir::triton::xla::XlaTritonDialect, xla::XlaDialect>();
mlir::triton::xla::registerTritonXlaTransformsPasses();
xla::emitters::registerTransformsPasses();
xla::gpu::registerGpuFusionTransformsPasses();

From bc3e70bcb9d1cbe7e3231cf99c8235fc9a9e5b3a Mon Sep 17 00:00:00 2001
From: Chase Riley Roberts
Date: Mon, 7 Apr 2025 08:26:17 -0700
Subject: [PATCH 0299/1324] PR #24550: Support custom call stream assignment

Imported from GitHub PR https://github.com/openxla/xla/pull/24550

Previously, in JAX if a user tried to use `compute_on("gpu_stream:#")` on a computation that included a custom call, that custom call would not run on the specified stream, leading to errors. To fix this, we simply use `GetStreamForExecution` instead of just relying on `params.stream` to decide which stream to pass to the FFI. This is the same logic that other thunks already have.
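For reference, the stream lookup shared with other thunks is roughly the
following (a minimal sketch; the exact call sites are in the diff below):

    TF_ASSIGN_OR_RETURN(
        se::Stream * stream,
        GetStreamForExecution(Thunk::execution_stream_id(), params));
    // `stream` resolves to one of `params.additional_compute_streams` for a
    // non-default execution stream id, and to `params.stream` otherwise.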
We need to make this edit twice since there are two codepaths a custom call can follow, once via `ExecuteFfiHandler` and another in `ExecuteCustomCall`. Copybara import of the project: -- 7ee230f7abf744fa40ad86e0521e9b8b9f7cc80e by chaser : Support custom call stream assignment Merging this change closes #24550 PiperOrigin-RevId: 744730610 --- .../backends/gpu/runtime/custom_call_thunk.cc | 10 ++++-- .../gpu/runtime/custom_call_thunk_test.cc | 32 +++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/custom_call_thunk.cc index d45931c4b758ec..662f61c0710fa1 100644 --- a/third_party/xla/xla/backends/gpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/custom_call_thunk.cc @@ -136,8 +136,11 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) { } } + TF_ASSIGN_OR_RETURN( + se::Stream * stream, + GetStreamForExecution(Thunk::execution_stream_id(), params)); XlaCustomCallStatus custom_call_status; - call_target_(params.stream, buffers.data(), opaque_.data(), opaque_.size(), + call_target_(stream, buffers.data(), opaque_.data(), opaque_.size(), &custom_call_status); auto message = CustomCallStatusGetMessage(&custom_call_status); if (message) { @@ -246,10 +249,13 @@ absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { } absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { + TF_ASSIGN_OR_RETURN( + se::Stream * stream, + GetStreamForExecution(Thunk::execution_stream_id(), params)); if (bundle_.has_value()) { return ExecuteFfiHandler( params.collective_params ? params.collective_params->run_id : RunId{-1}, - bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE, params.stream, + bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE, stream, params.ffi_execution_context, params.buffer_allocations); } return ExecuteCustomCall(params); diff --git a/third_party/xla/xla/backends/gpu/runtime/custom_call_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/custom_call_thunk_test.cc index 4770e7e68dc80d..18c1fdb242833e 100644 --- a/third_party/xla/xla/backends/gpu/runtime/custom_call_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/custom_call_thunk_test.cc @@ -71,5 +71,37 @@ TEST(CustomCallThunkTest, SimpleCustomCall) { EXPECT_TRUE(was_called); } +TEST(CustomCallThunkTest, CustomCallOnCustomStream) { + // Whitebox test to ensure that custom calls respect execution_stream_id + // assignments. + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr extra_stream, + executor->CreateStream()); + // Setup the additional streams. + Thunk::ExecutionStreamIdMap additional_compute_streams = {}; + additional_compute_streams[ExecutionStreamId(1)] = extra_stream.get(); + se::StreamExecutorMemoryAllocator allocator(executor); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + ServiceExecutableRunOptions(), BufferAllocations({}, 0, &allocator), + stream.get(), stream.get(), nullptr, nullptr, additional_compute_streams); + + CustomCallThunk::CustomCallTarget target = + [&](se::Stream* stream_in_callback, void** args, const char* target_name, + size_t num_args, XlaCustomCallStatus* status) { + // We should be launching on the extra stream and not the default one. 
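+ // (The routing is set up above: additional_compute_streams maps
+ // ExecutionStreamId(1) to extra_stream, and set_execution_stream_id(1)
+ // below opts the thunk into that stream via GetStreamForExecution.)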
+ EXPECT_THAT(stream_in_callback, ::testing::Eq(extra_stream.get())); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, CustomCallThunk::Create(Thunk::ThunkInfo(), "target_name", + target, {}, {}, "")); + // Setting this tells the thunk to dispatch on one of the additional streams. + thunk->set_execution_stream_id(ExecutionStreamId(1)); + EXPECT_THAT(thunk->ExecuteOnStream(Thunk::ExecuteParams(params)), + ::tsl::testing::IsOk()); +} + } // namespace } // namespace xla::gpu From bdd49be7ee756c05f6f6c4db05d8c707a32da449 Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Mon, 7 Apr 2025 08:35:02 -0700 Subject: [PATCH 0300/1324] Move `BufferComparatorKernel` behind `GpuKernelRegistry`. * Moves `BufferComparator` logic into `backends/gpu/runtime` since it's a runtime component. * Defines trait for the `BufferComparator` kernel in `stream_executor/gpu/` * Moves the implementations of this kernel into `stream_executor/{cuda|rocm}` and registers them with the registry for each `xla::primitive_util::{Integral|FloatingPoint}Type`. * Makes `BufferComparator` retrieve the kernel by using the kernel registry. * Add the kernel implementations as dependencies to the `all_runtime` targets for CUDA and ROCm. PiperOrigin-RevId: 744732993 --- .../xla/xla/backends/gpu/runtime/BUILD | 43 +++++++++++ .../gpu/runtime}/buffer_comparator.cc | 49 ++++-------- .../gpu/runtime}/buffer_comparator.h | 12 +-- .../gpu/runtime}/buffer_comparator_test.cc | 3 +- third_party/xla/xla/primitive_util.h | 35 +++++++++ third_party/xla/xla/service/gpu/BUILD | 75 ------------------- .../xla/xla/service/gpu/autotuning/BUILD | 6 +- .../gpu/autotuning/conv_algorithm_picker.cc | 2 +- .../gpu/autotuning/gemm_algorithm_picker.cc | 2 +- .../gpu/autotuning/gemm_fusion_autotuner.cc | 2 +- .../xla/xla/service/gpu/transforms/BUILD | 2 +- .../triton_fusion_numerics_verifier.cc | 2 +- .../xla/xla/stream_executor/cuda/BUILD | 25 +++++++ .../cuda/buffer_comparator_kernel_cuda.cu.cc | 46 ++++++++++++ third_party/xla/xla/stream_executor/gpu/BUILD | 11 +++ .../gpu/buffer_comparator_kernel.h | 41 ++++++++++ .../gpu/buffer_comparator_kernel_lib.cu.h} | 67 ++++++++++------- .../xla/xla/stream_executor/rocm/BUILD | 25 +++++++ .../rocm/buffer_comparator_kernel_rocm.cu.cc | 45 +++++++++++ 19 files changed, 336 insertions(+), 157 deletions(-) rename third_party/xla/xla/{service/gpu => backends/gpu/runtime}/buffer_comparator.cc (82%) rename third_party/xla/xla/{service/gpu => backends/gpu/runtime}/buffer_comparator.h (85%) rename third_party/xla/xla/{service/gpu => backends/gpu/runtime}/buffer_comparator_test.cc (99%) create mode 100644 third_party/xla/xla/stream_executor/cuda/buffer_comparator_kernel_cuda.cu.cc create mode 100644 third_party/xla/xla/stream_executor/gpu/buffer_comparator_kernel.h rename third_party/xla/xla/{service/gpu/buffer_comparator.cu.cc => stream_executor/gpu/buffer_comparator_kernel_lib.cu.h} (64%) create mode 100644 third_party/xla/xla/stream_executor/rocm/buffer_comparator_kernel_rocm.cu.cc diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 471c54cecde82b..5fb985c794f176 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -1348,3 +1348,46 @@ cc_library( "@com_google_absl//absl/synchronization", ], ) + +cc_library( + name = "buffer_comparator", + srcs = ["buffer_comparator.cc"], + hdrs = ["buffer_comparator.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/service/gpu:launch_dimensions", + 
"//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:buffer_comparator_kernel", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "@com_google_absl//absl/status:statusor", + "@eigen_archive//:eigen3", + ], +) + +xla_test( + name = "buffer_comparator_test", + srcs = ["buffer_comparator_test.cc"], + backends = ["gpu"], + disabled_backends = [], + deps = [ + ":buffer_comparator", + "//xla:shape_util", + "//xla:types", + "//xla/service:platform_util", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cc b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc similarity index 82% rename from third_party/xla/xla/service/gpu/buffer_comparator.cc rename to third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc index 5c4ea65739938f..7a782fc0be6b95 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/buffer_comparator.h" +#include "xla/backends/gpu/runtime/buffer_comparator.h" #include #include #include -#include #include #include +#include "absl/status/statusor.h" #include "Eigen/Core" #include "xla/primitive_util.h" #include "xla/service/gpu/launch_dimensions.h" @@ -29,23 +29,15 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_handle.h" -#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/gpu/buffer_comparator_kernel.h" +#include "xla/stream_executor/gpu/gpu_kernel_registry.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/typed_kernel_factory.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { -template -using ComparisonKernelT = - se::TypedKernel, se::DeviceMemory, - float, uint64_t, se::DeviceMemory>; - struct ComparisonParams { double relative_tol = 0.1; bool verbose = true; @@ -59,9 +51,7 @@ struct ComparisonParams { // // Returns `true` if two buffers are equal, `false` otherwise. 
template -static absl::StatusOr DeviceCompare(absl::string_view kernel_name, - void* kernel_symbol, - const ComparisonParams& params) { +static absl::StatusOr DeviceCompare(const ComparisonParams& params) { se::StreamExecutor* executor = params.stream->parent(); se::DeviceMemoryHandle out(executor, executor->AllocateScalar()); @@ -78,11 +68,10 @@ static absl::StatusOr DeviceCompare(absl::string_view kernel_name, uint64_t buffer_size = current_typed.ElementCount(); TF_ASSIGN_OR_RETURN( - ComparisonKernelT comparison_kernel, - (se::TypedKernelFactory< - se::DeviceMemory, se::DeviceMemory, float, - uint64_t, se::DeviceMemory>::Create(executor, kernel_name, - kernel_symbol))); + auto comparison_kernel, + stream_executor::gpu::GpuKernelRegistry::GetGlobalRegistry() + .LoadKernel>( + executor)); const se::DeviceDescription& gpu_device_info = executor->GetDeviceDescription(); @@ -162,12 +151,9 @@ static absl::StatusOr HostCompare(const ComparisonParams& params) { template static absl::StatusOr CompareEqualParameterized( - absl::string_view kernel_name, void* kernel_symbol, const ComparisonParams& params) { XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual"); - TF_ASSIGN_OR_RETURN( - bool result, DeviceCompare(kernel_name, kernel_symbol, params)); - + TF_ASSIGN_OR_RETURN(bool result, DeviceCompare(params)); if (result) { return true; } @@ -185,21 +171,12 @@ absl::StatusOr BufferComparator::CompareEqual( ComparisonParams params{relative_tol_, verbose_, &shape_, stream, current, expected}; - void* kernel_symbol = buffer_comparator::comparison_fn(shape_.element_type()); - if (kernel_symbol == nullptr) { - return Unimplemented("Unimplemented element type for device kernel"); - } - - std::string kernel_name = absl::StrCat( - primitive_util::LowercasePrimitiveTypeName(shape_.element_type()), - "_comparison"); - auto do_compare = [&](auto cst_type) { using ElementT = primitive_util::NativeTypeOf; using ComparisonT = - std::conditional_t, double, float>; - return CompareEqualParameterized( - kernel_name, kernel_symbol, params); + std::conditional_t, + double, float>; + return CompareEqualParameterized(params); }; if (primitive_util::IsFloatingPointType(shape_.element_type())) { diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.h b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.h similarity index 85% rename from third_party/xla/xla/service/gpu/buffer_comparator.h rename to third_party/xla/xla/backends/gpu/runtime/buffer_comparator.h index 76e605bb24adb5..d3b016b7ca1112 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.h +++ b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ -#define XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#ifndef XLA_BACKENDS_GPU_RUNTIME_BUFFER_COMPARATOR_H_ +#define XLA_BACKENDS_GPU_RUNTIME_BUFFER_COMPARATOR_H_ #include "absl/status/statusor.h" #include "xla/shape.h" @@ -50,12 +50,6 @@ class BufferComparator { bool verbose_; // whether to print out error message on mismatch }; -namespace buffer_comparator { - -// Returns a pointer to CUDA C++ device function implementing comparison. 
-void* comparison_fn(xla::PrimitiveType type); - -} // namespace buffer_comparator } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#endif // XLA_BACKENDS_GPU_RUNTIME_BUFFER_COMPARATOR_H_ diff --git a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/buffer_comparator_test.cc rename to third_party/xla/xla/backends/gpu/runtime/buffer_comparator_test.cc index acd7f60ac94fc7..942a951b50f533 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/buffer_comparator.h" +#include "xla/backends/gpu/runtime/buffer_comparator.h" #include #include @@ -33,7 +33,6 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/types.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/status.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/primitive_util.h b/third_party/xla/xla/primitive_util.h index bb6d542afd5058..4aced00bdf2cd6 100644 --- a/third_party/xla/xla/primitive_util.h +++ b/third_party/xla/xla/primitive_util.h @@ -646,6 +646,41 @@ constexpr R PrimitiveTypeSwitch(F&& f, PrimitiveType type) { LOG(FATAL) << "unhandled type " << type; } +template +constexpr void IntegralTypeForEach(F&& f) { + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); +} + +template +constexpr void FloatingPointTypeForEach(F&& f) { + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); + std::forward(f)(PrimitiveTypeConstant()); +} + namespace internal { // Returns the number of bits in the native type for a given primitive type if diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index ed3775413cbbce..fee0cc5d0f6fda 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -6,7 +6,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", - "rocm_copts", ) load("//xla:xla.default.bzl", "xla_cc_test", "xla_internal") load( @@ -2495,80 +2494,6 @@ xla_cc_test( ], ) -cc_library( - name = "buffer_comparator", 
- srcs = if_gpu_is_configured(["buffer_comparator.cc"]), - hdrs = if_gpu_is_configured(["buffer_comparator.h"]), - deps = if_gpu_is_configured([ - # keep sorted - ":buffer_comparator_kernel", - ":gpu_asm_opts_util", - ":launch_dimensions", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/service:hlo_module_config", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_handle", - "//xla/stream_executor:kernel", - "//xla/stream_executor:stream", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor:typed_kernel_factory", - "//xla/stream_executor/gpu:asm_compiler", - "@com_google_absl//absl/base", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", - ]) + if_rocm_is_configured([ - # keep sorted - "@local_config_rocm//rocm:rocm_config", - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -gpu_kernel_library( - name = "buffer_comparator_kernel", - srcs = if_gpu_is_configured(["buffer_comparator.cu.cc"]), - copts = rocm_copts(), - deps = [ - "//xla:shape_util", - "//xla:types", - ] + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -xla_test( - name = "buffer_comparator_test", - srcs = if_gpu_is_configured(["buffer_comparator_test.cc"]), - backends = ["gpu"], - deps = [ - ":stream_executor_util", - "//xla:shape_util", - "//xla:types", - "//xla/service:hlo_module_config", - "//xla/service:platform_util", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:device_memory_handle", - "//xla/stream_executor:platform", - "//xla/stream_executor:platform_manager", - "//xla/stream_executor:stream", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - ] + if_gpu_is_configured([ - ":buffer_comparator", - "//xla/stream_executor:device_memory", - ]), -) - cc_library( name = "buffer_sharing", srcs = ["buffer_sharing.cc"], diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index c778f212a0a794..47a9c7c02d8ffd 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -126,6 +126,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/backends/gpu/runtime:buffer_comparator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", @@ -142,7 +143,6 @@ cc_library( "//xla/service:hlo_module_config", "//xla/service:shaped_buffer", "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:gpu_float_support", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_indexing_utils", @@ -301,11 +301,11 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_proto_cc", + "//xla/backends/gpu/runtime:buffer_comparator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:stream_executor_util", @@ -529,7 +529,7 @@ cc_library( "@local_tsl//tsl/platform:numbers", ] + if_cuda_is_configured([ # 
keep sorted - "//xla/service/gpu:buffer_comparator", + "//xla/backends/gpu/runtime:buffer_comparator", "//xla/stream_executor/gpu:redzone_allocator", "@local_config_cuda//cuda:cudnn_header", ]), diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc index abb04413acfd2f..c616588c92bddc 100644 --- a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc @@ -83,7 +83,7 @@ limitations under the License. #else #include "third_party/gpus/cudnn/cudnn_ops_infer.h" #endif // CUDNN_VERSION >= 90000 -#include "xla/service/gpu/buffer_comparator.h" +#include "xla/backends/gpu/runtime/buffer_comparator.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #endif diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc index fabeb049d66d47..90087127341c81 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc @@ -34,12 +34,12 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/autotuning.pb.h" +#include "xla/backends/gpu/runtime/buffer_comparator.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/stream_executor_util.h" diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index bcc2a05fedd748..0cb9682053a86b 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -42,6 +42,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" +#include "xla/backends/gpu/runtime/buffer_comparator.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" @@ -65,7 +66,6 @@ limitations under the License. 
#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/autotuning/dot_search_space.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/gpu_float_support.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/kernels/custom_kernel.h" diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 08971084ef1ad7..bd31c186229aa6 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -3269,6 +3269,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_proto_cc", + "//xla/backends/gpu/runtime:buffer_comparator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:dump", @@ -3277,7 +3278,6 @@ cc_library( "//xla/service:hlo_module_config", "//xla/service:shaped_buffer", "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu/autotuning:autotuner_compile_util", "//xla/service/gpu/autotuning:autotuner_util", diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc index 04e4da53f0fc08..4e58bee68d251d 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/backends/gpu/runtime/buffer_comparator.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -36,7 +37,6 @@ limitations under the License. 
#include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/priority_fusion.h" diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 4dd29c210fc903..80be22c8899a39 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1117,6 +1117,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":buffer_comparator_kernel_cuda", ":cublas_plugin", ":cuda_platform", ":cudnn_plugin", @@ -1971,3 +1972,27 @@ xla_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cuda_library( + name = "buffer_comparator_kernel_cuda", + srcs = [ + "buffer_comparator_kernel_cuda.cu.cc", + "//xla/stream_executor/gpu:buffer_comparator_kernel_lib.cu.h", + ], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":cuda_platform_id", + "//xla:shape_util", + "//xla:types", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor/gpu:buffer_comparator_kernel", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "//xla/stream_executor/platform:initialize", + "@local_config_cuda//cuda:cuda_headers", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/stream_executor/cuda/buffer_comparator_kernel_cuda.cu.cc b/third_party/xla/xla/stream_executor/cuda/buffer_comparator_kernel_cuda.cu.cc new file mode 100644 index 00000000000000..59e2fd8e7b9304 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/buffer_comparator_kernel_cuda.cu.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/gpus/cuda/include/device_launch_parameters.h" +#include "xla/primitive_util.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/gpu/buffer_comparator_kernel_lib.cu.h" +#include "xla/stream_executor/platform/initialize.h" + +namespace stream_executor::cuda { + +// Comparison kernel code: compare two buffers of +// fp8/bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the +// relative error does not exceed the passed rel_error_threshold. Write the +// number of mismatches into out parameter mismatch_count. 
+
+namespace {
+
+static void RegisterBufferComparatorKernelCudaImpl() {
+  auto register_kernel = [&](auto primitive_type_constant) {
+    gpu::RegisterBufferComparatorKernelParametrized<
+        xla::primitive_util::NativeTypeOf<primitive_type_constant>>(
+        stream_executor::cuda::kCudaPlatformId);
+  };
+  xla::primitive_util::IntegralTypeForEach(register_kernel);
+  xla::primitive_util::FloatingPointTypeForEach(register_kernel);
+}
+
+}  // namespace
+}  // namespace stream_executor::cuda
+
+STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(
+    RegisterBufferComparatorKernelCuda,
+    stream_executor::cuda::RegisterBufferComparatorKernelCudaImpl());
diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD
index 780acdabb537c5..75378d3960246b 100644
--- a/third_party/xla/xla/stream_executor/gpu/BUILD
+++ b/third_party/xla/xla/stream_executor/gpu/BUILD
@@ -853,3 +853,14 @@ xla_cc_test(
         "@com_google_googletest//:gtest_main",
     ],
 )
+
+cc_library(
+    name = "buffer_comparator_kernel",
+    hdrs = ["buffer_comparator_kernel.h"],
+    deps = [
+        "//xla/stream_executor:device_memory",
+        "//xla/stream_executor:kernel",
+    ],
+)
+
+exports_files(["buffer_comparator_kernel_lib.cu.h"])
diff --git a/third_party/xla/xla/stream_executor/gpu/buffer_comparator_kernel.h b/third_party/xla/xla/stream_executor/gpu/buffer_comparator_kernel.h
new file mode 100644
index 00000000000000..84eb98ff6e2a69
--- /dev/null
+++ b/third_party/xla/xla/stream_executor/gpu/buffer_comparator_kernel.h
@@ -0,0 +1,41 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_STREAM_EXECUTOR_GPU_BUFFER_COMPARATOR_KERNEL_H_
+#define XLA_STREAM_EXECUTOR_GPU_BUFFER_COMPARATOR_KERNEL_H_
+
+#include 
+
+#include 
+
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/kernel.h"
+
+namespace stream_executor::gpu {
+
+// Defines a trait for the BufferComparator kernel that can be used to register
+// and look up the kernel in the GPU kernel registry.
+template <typename T>
+struct BufferComparatorKernel {
+  using KernelType =
+      stream_executor::TypedKernel<stream_executor::DeviceMemory<T>,
+                                   stream_executor::DeviceMemory<T>,
+                                   float, uint64_t,
+                                   stream_executor::DeviceMemory<uint64_t>>;
+};
+
+}  // namespace stream_executor::gpu
+
+#endif  // XLA_STREAM_EXECUTOR_GPU_BUFFER_COMPARATOR_KERNEL_H_
diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc b/third_party/xla/xla/stream_executor/gpu/buffer_comparator_kernel_lib.cu.h
similarity index 64%
rename from third_party/xla/xla/service/gpu/buffer_comparator.cu.cc
rename to third_party/xla/xla/stream_executor/gpu/buffer_comparator_kernel_lib.cu.h
index 08563ce2732070..bdba30bb238d5b 100644
--- a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc
+++ b/third_party/xla/xla/stream_executor/gpu/buffer_comparator_kernel_lib.cu.h
@@ -1,4 +1,4 @@
-/* Copyright 2018 The OpenXLA Authors.
+/* Copyright 2025 The OpenXLA Authors.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,21 +13,22 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#ifndef XLA_STREAM_EXECUTOR_GPU_BUFFER_COMPARATOR_KERNEL_LIB_CU_H_
+#define XLA_STREAM_EXECUTOR_GPU_BUFFER_COMPARATOR_KERNEL_LIB_CU_H_
+
+#include 
+
 #include 
 #include 
 #include 

 #include "xla/primitive_util.h"
+#include "xla/stream_executor/gpu/buffer_comparator_kernel.h"
+#include "xla/stream_executor/gpu/gpu_kernel_registry.h"
+#include "xla/stream_executor/kernel_spec.h"
 #include "xla/types.h"

-namespace xla::gpu::buffer_comparator {
-
-// Comparison kernel code: compare two buffers of
-// fp8/bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the
-// relative error does not exceed the passed rel_error_threshold. Write the
-// number of mismatches into out parameter mismatch_count.
-
-namespace {
+namespace stream_executor::gpu {

 // NaN's are considered equal, and for half's we clamp all numbers to largest
 // and smallest numbers representable to avoid miscomparisons due to overflows.
@@ -106,26 +107,38 @@ __global__ void xla_int_comparison(T* buffer_a, T* buffer_b,
   atomicAdd(mismatch_count, 1);
 }

-}  // namespace
+template <typename T>
+void RegisterBufferComparatorKernelParametrized(Platform::Id platform_id) {
+  void* kernel_symbol = nullptr;
+  constexpr xla::PrimitiveType p_type =
+      xla::primitive_util::NativeToPrimitiveType<T>();

-void* comparison_fn(const xla::PrimitiveType type) {
-  if (xla::primitive_util::IsFloatingPointType(type)) {
-    return primitive_util::FloatingPointTypeSwitch<void*>(
-        [](auto cst_type) {
-          using native_type = primitive_util::NativeTypeOf<cst_type>;
-          return reinterpret_cast<void*>(&xla_fp_comparison<native_type>);
-        },
-        type);
+  if constexpr (xla::primitive_util::IsIntegralType(p_type)) {
+    kernel_symbol = absl::bit_cast<void*>(&xla_int_comparison<T>);
+  } else if constexpr (xla::primitive_util::IsFloatingPointType(p_type)) {
+    kernel_symbol = absl::bit_cast<void*>(&xla_fp_comparison<T>);
+  } else {
+    LOG(FATAL) << "Failed to register buffer comparator kernel for type "
+               << xla::primitive_util::LowercasePrimitiveTypeName(p_type);
+    return;
   }
-  if (xla::primitive_util::IsIntegralType(type)) {
-    return primitive_util::IntegralTypeSwitch<void*>(
-        [](auto cst_type) {
-          using native_type = primitive_util::NativeTypeOf<cst_type>;
-          return reinterpret_cast<void*>(&xla_int_comparison<native_type>);
-        },
-        type);
+  std::string kernel_name = absl::StrCat(
+      xla::primitive_util::LowercasePrimitiveTypeName(p_type), "_comparison");
+
+  stream_executor::MultiKernelLoaderSpec spec(5);
+  spec.AddInProcessSymbol(kernel_symbol, kernel_name);
+
+  absl::Status result =
+      stream_executor::gpu::GpuKernelRegistry::GetGlobalRegistry()
+          .RegisterKernel<BufferComparatorKernel<T>>(platform_id, spec);
+
+  if (!result.ok()) {
+    LOG(FATAL) << "Failed to register buffer comparator kernel for type "
+               << xla::primitive_util::LowercasePrimitiveTypeName(p_type)
+               << ": " << result;
   }
-  return nullptr;
 }

-}  // namespace xla::gpu::buffer_comparator
+}  // namespace stream_executor::gpu
+
+#endif  // XLA_STREAM_EXECUTOR_GPU_BUFFER_COMPARATOR_KERNEL_LIB_CU_H_
diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD
index 97c0f2bcbe912a..dbac48b0c136d6 100644
--- a/third_party/xla/xla/stream_executor/rocm/BUILD
+++ b/third_party/xla/xla/stream_executor/rocm/BUILD
@@ -824,6 +824,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps =
        [
            ":amdhipblaslt_plugin",
+           ":buffer_comparator_kernel_rocm",
            ":hipfft_plugin",
            ":miopen_plugin",
            ":rocblas_plugin",
@@ -1073,3 +1074,27 @@ cc_library(
         "@local_tsl//tsl/platform:casts",
     ],
 )
+
+rocm_library(
+    name = "buffer_comparator_kernel_rocm",
+    srcs = [
+        "buffer_comparator_kernel_rocm.cu.cc",
+        "//xla/stream_executor/gpu:buffer_comparator_kernel_lib.cu.h",
+    ],
+    # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"],
+    tags = [
+        "gpu",
+        "rocm-only",
+    ],
+    deps = [
+        ":rocm_platform_id",
+        "//xla:shape_util",
+        "//xla:types",
+        "//xla/stream_executor:kernel_spec",
+        "//xla/stream_executor/gpu:buffer_comparator_kernel",
+        "//xla/stream_executor/gpu:gpu_kernel_registry",
+        "//xla/stream_executor/platform:initialize",
+        "@local_config_rocm//rocm:rocm_headers",
+    ],
+    alwayslink = 1,
+)
diff --git a/third_party/xla/xla/stream_executor/rocm/buffer_comparator_kernel_rocm.cu.cc b/third_party/xla/xla/stream_executor/rocm/buffer_comparator_kernel_rocm.cu.cc
new file mode 100644
index 00000000000000..ac6b084541be6a
--- /dev/null
+++ b/third_party/xla/xla/stream_executor/rocm/buffer_comparator_kernel_rocm.cu.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/primitive_util.h"
+#include "xla/stream_executor/rocm/rocm_platform_id.h"
+#include "xla/stream_executor/gpu/buffer_comparator_kernel_lib.cu.h"
+#include "xla/stream_executor/platform/initialize.h"
+
+namespace stream_executor::rocm {
+
+// Comparison kernel code: compare two buffers of
+// fp8/bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the
+// relative error does not exceed the passed rel_error_threshold. Write the
+// number of mismatches into out parameter mismatch_count.
+
+namespace {
+
+static void RegisterBufferComparatorKernelRocmImpl() {
+  auto register_kernel = [&](auto primitive_type_constant) {
+    gpu::RegisterBufferComparatorKernelParametrized<
+        xla::primitive_util::NativeTypeOf<primitive_type_constant>>(
+        stream_executor::rocm::kROCmPlatformId);
+  };
+  xla::primitive_util::IntegralTypeForEach(register_kernel);
+  xla::primitive_util::FloatingPointTypeForEach(register_kernel);
+}
+
+}  // namespace
+}  // namespace stream_executor::rocm
+
+STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(
+    RegisterBufferComparatorKernelRocm,
+    stream_executor::rocm::RegisterBufferComparatorKernelRocmImpl());

From d7dab308b82754ecb7146986b061c0c3d137c438 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 7 Apr 2025 08:36:17 -0700
Subject: [PATCH 0301/1324] Refactor delegation of unary and binary ops.
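
Collapses the per-operator visitors for elementwise ops (VisitAbsNode,
VisitAddNode, VisitCeilNode, VisitDivNode, ...) into two generic visitors,
VisitUnaryNode and VisitBinaryNode, which map each TFLite builtin onto a
single xnn_define_unary / xnn_define_binary call. A minimal, self-contained
sketch of the dispatch pattern (the enums and names below are illustrative
stand-ins, not the actual delegate or XNNPack symbols):

    #include <cstdio>

    enum BuiltinOp { kAbs, kCeil, kAdd };                    // TFLite builtins
    enum UnaryOp { kUnaryAbs, kUnaryCeil, kUnaryInvalid };   // XNNPack ops

    // One mapping function replaces a hand-written visitor per builtin; ops
    // that do not map to a unary operator fall through to the binary path.
    UnaryOp ToUnaryOp(BuiltinOp op) {
      switch (op) {
        case kAbs:  return kUnaryAbs;
        case kCeil: return kUnaryCeil;
        default:    return kUnaryInvalid;
      }
    }

    int main() { std::printf("%d\n", ToUnaryOp(kCeil)); }  // prints 1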
PiperOrigin-RevId: 744733352 --- .../delegates/xnnpack/xnnpack_delegate.cc | 1759 +++++------------ 1 file changed, 485 insertions(+), 1274 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index f5d3eaeff3b1ca..3372324fa4f560 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -2838,16 +2838,43 @@ class Subgraph { #endif switch (registration->builtin_code) { case kTfLiteBuiltinAbs: - return VisitAbsNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinAdd: { - const TfLiteAddParams* add_params = - static_cast(node->builtin_data); + case kTfLiteBuiltinCeil: + case kTfLiteBuiltinCos: + case kTfLiteBuiltinDequantize: + case kTfLiteBuiltinElu: + case kTfLiteBuiltinFloor: + case kTfLiteBuiltinGelu: + case kTfLiteBuiltinHardSwish: + case kTfLiteBuiltinLeakyRelu: + case kTfLiteBuiltinLogistic: + case kTfLiteBuiltinNeg: + case kTfLiteBuiltinQuantize: + case kTfLiteBuiltinRelu: + case kTfLiteBuiltinRelu6: + case kTfLiteBuiltinReluN1To1: + case kTfLiteBuiltinRound: + case kTfLiteBuiltinRsqrt: + case kTfLiteBuiltinSin: + case kTfLiteBuiltinSqrt: + case kTfLiteBuiltinSquare: + case kTfLiteBuiltinTanh: + return VisitUnaryNode(subgraph, delegate, logging_context, node_index, + node, (BuiltinOperator)registration->builtin_code, + context->tensors, input_output_tensors); + + case kTfLiteBuiltinAdd: + case kTfLiteBuiltinDiv: + case kTfLiteBuiltinMaximum: + case kTfLiteBuiltinMinimum: + case kTfLiteBuiltinMul: + case kTfLiteBuiltinPrelu: + case kTfLiteBuiltinSquaredDifference: + case kTfLiteBuiltinSub: + return VisitBinaryNode( + subgraph, delegate, logging_context, node_index, node, + (BuiltinOperator)registration->builtin_code, context->tensors, + quasi_static_tensors, input_output_tensors); - return VisitAddNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, add_params, - input_output_tensors); - } case kTfLiteBuiltinAssignVariable: return VisitAssignVariableNode(subgraph, delegate, logging_context, node_index, node, context->tensors, @@ -2868,9 +2895,6 @@ class Subgraph { node_index, node, context->tensors, batchmatmul_params, input_output_tensors); } - case kTfLiteBuiltinCeil: - return VisitCeilNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinConcatenation: { const TfLiteConcatenationParams* concat_params = static_cast(node->builtin_data); @@ -2886,9 +2910,6 @@ class Subgraph { node, context->tensors, conv_params, quasi_static_tensors, input_output_tensors); } - case kTfLiteBuiltinCos: - return VisitCosNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinDepthwiseConv2d: { const TfLiteDepthwiseConvParams* dwconv_params = static_cast(node->builtin_data); @@ -2906,21 +2927,6 @@ class Subgraph { subgraph, delegate, logging_context, node_index, node, context->tensors, depth_to_space_params, input_output_tensors); } - case kTfLiteBuiltinDequantize: - return VisitDequantizeNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); - case kTfLiteBuiltinDiv: { - const TfLiteDivParams* div_params = - static_cast(node->builtin_data); - - return VisitDivNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, div_params, - input_output_tensors); - 
} - case kTfLiteBuiltinElu: - return VisitEluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinExpandDims: return VisitExpandDimsNode(subgraph, delegate, logging_context, node_index, node, context->tensors, @@ -2943,28 +2949,6 @@ class Subgraph { fc_params, quasi_static_tensors, input_output_tensors); } - case kTfLiteBuiltinFloor: - return VisitFloorNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinGelu: - return VisitGeluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinHardSwish: - return VisitHardSwishNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); - case kTfLiteBuiltinLeakyRelu: { - const TfLiteLeakyReluParams* leaky_relu_params = - static_cast(node->builtin_data); - - return VisitLeakyReluNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - leaky_relu_params, input_output_tensors); - } - case kTfLiteBuiltinLogistic: - return VisitLogisticNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); case kTfLiteBuiltinMaxPool2d: { const TfLitePoolParams* pool_params = static_cast(node->builtin_data); @@ -2982,9 +2966,6 @@ class Subgraph { context->tensors, reducer_params, input_output_tensors); } - case kTfLiteBuiltinMaximum: - return VisitMaximumNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinMean: { const TfLiteReducerParams* reducer_params = static_cast(node->builtin_data); @@ -2993,48 +2974,13 @@ class Subgraph { context->tensors, reducer_params, input_output_tensors); } - case kTfLiteBuiltinMinimum: - return VisitMinimumNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinMul: { - const TfLiteMulParams* mul_params = - static_cast(node->builtin_data); - - return VisitMulNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, mul_params, - input_output_tensors); - } - case kTfLiteBuiltinNeg: - return VisitNegNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinPad: return VisitPadNode(subgraph, delegate, logging_context, node_index, node, context->tensors, input_output_tensors); - case kTfLiteBuiltinPrelu: - return VisitPreluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, quasi_static_tensors, - input_output_tensors); - case kTfLiteBuiltinQuantize: - return VisitQuantizeNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); case kTfLiteBuiltinReadVariable: return VisitReadVariableNode(subgraph, delegate, logging_context, node_index, node, context->tensors, input_output_tensors); - case kTfLiteBuiltinRelu: - return VisitReluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, 0.0f, - std::numeric_limits::infinity(), - input_output_tensors); - case kTfLiteBuiltinReluN1To1: - return VisitReluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, -1.0f, 1.0f, - input_output_tensors); - case kTfLiteBuiltinRelu6: - return VisitReluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, 0.0f, 6.0f, - input_output_tensors); case kTfLiteBuiltinReshape: { const 
TfLiteReshapeParams* reshape_params = static_cast(node->builtin_data); @@ -3051,15 +2997,6 @@ class Subgraph { node_index, node, context->tensors, resize_params, input_output_tensors); } - case kTfLiteBuiltinRound: - return VisitRoundNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinRsqrt: - return VisitRsqrtNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinSin: - return VisitSinNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinSlice: return VisitSliceNode(subgraph, delegate, logging_context, node_index, node, context->tensors, input_output_tensors); @@ -3086,16 +3023,6 @@ class Subgraph { node, context->tensors, split_params, input_output_tensors); } - case kTfLiteBuiltinSqrt: - return VisitSqrtNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinSquare: - return VisitSquareNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinSquaredDifference: - return VisitSquaredDifferenceNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); case kTfLiteBuiltinStridedSlice: { const auto* params = static_cast(node->builtin_data); @@ -3103,17 +3030,6 @@ class Subgraph { node_index, node, context->tensors, params, input_output_tensors); } - case kTfLiteBuiltinSub: { - const TfLiteSubParams* sub_params = - static_cast(node->builtin_data); - - return VisitSubNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, sub_params, - input_output_tensors); - } - case kTfLiteBuiltinTanh: - return VisitTanhNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinTranspose: { return VisitTransposeNode(subgraph, delegate, logging_context, node_index, node, context->tensors, @@ -3199,112 +3115,6 @@ class Subgraph { } } - static TfLiteStatus VisitAbsNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_ABS, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_abs( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_ABS), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitAddNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, const TfLiteAddParams* add_params, - const 
std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_ADD, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor, - node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input1_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_ADD, node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor, - node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input2_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], - BuiltinOperator_ADD, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - if (input1_tensor.type != input2_tensor.type || - input1_tensor.type != output_tensor.type) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unsupported mixed types in ADD operator #%d", - node_index); - return kTfLiteError; - } - const float scale_min = 1.0f / 1024.0f; - const float scale_max = 256.0f; - TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( - logging_context, input1_tensor, output_tensor, scale_min, scale_max, - BuiltinOperator_ADD, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( - logging_context, input2_tensor, output_tensor, scale_min, scale_max, - BuiltinOperator_ADD, node_index)); - - float output_min = -std::numeric_limits::infinity(); - float output_max = +std::numeric_limits::infinity(); - if (add_params != nullptr) { - TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( - logging_context, node_index, add_params->activation, &output_min, - &output_max)); - } - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_add2( - subgraph, output_min, output_max, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_ADD), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitAssignVariableNode( xnn_subgraph_t subgraph, Delegate& delegate, TfLiteContext* logging_context, int node_index, const TfLiteNode* node, @@ -3600,39 +3410,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitCeilNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_CEIL, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - 
logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_ceiling( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_CEIL), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitConcatenationNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -3942,39 +3719,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitCosNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_COS, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloatType( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloatType( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_unary( - subgraph, xnn_unary_cosine, /*params=*/nullptr, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_COS), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitDepthwiseConv2DNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -4141,91 +3885,198 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitDequantizeNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_DEQUANTIZE, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorQInt8OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_DEQUANTIZE, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_convert( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - 
EnumNameBuiltinOperator(BuiltinOperator_DEQUANTIZE),
-                           node_index);
-        return kTfLiteError;
-      }
-    }
-
-    return kTfLiteOk;
-  }
-
-  static TfLiteStatus VisitDivNode(
+  static TfLiteStatus VisitBinaryNode(
       xnn_subgraph_t subgraph, const Delegate& delegate,
       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
-      const TfLiteTensor* tensors, const TfLiteDivParams* div_params,
+      tflite::BuiltinOperator op_type, const TfLiteTensor* tensors,
+      const std::unordered_set<int>& quasi_static_tensors,
       const std::unordered_map<int, uint32_t>& input_output_tensors) {
-    TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs(
-        logging_context, node, 2, 1, BuiltinOperator_DIV, node_index));
-
-    const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
-    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
-        logging_context, input1_tensor, node->inputs->data[0], node_index));
+    // Get the input and output tensors.
+    TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs(logging_context, node, 2, 1,
+                                                   op_type, node_index));
+    const int input1_id = node->inputs->data[0];
+    const int input2_id = node->inputs->data[1];
+    const int output_id = node->outputs->data[0];
+    const TfLiteTensor& input1_tensor = tensors[input1_id];
+    const TfLiteTensor& input2_tensor = tensors[input2_id];
+    const TfLiteTensor& output_tensor = tensors[output_id];
+
+    // Check the input shapes.
     TF_LITE_ENSURE_STATUS(CheckTensorShape(
         logging_context, input1_tensor, /*min_num_dims=*/0,
-        /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0],
-        BuiltinOperator_DIV, node_index));
-
-    const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
-    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
-        logging_context, input2_tensor, node->inputs->data[1], node_index));
+        /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, input1_id, op_type, node_index));
     TF_LITE_ENSURE_STATUS(CheckTensorShape(
         logging_context, input2_tensor, /*min_num_dims=*/0,
-        /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1],
-        BuiltinOperator_DIV, node_index));
-
-    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
-    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
-        logging_context, output_tensor, node->outputs->data[0], node_index));
-
-    float output_min = -std::numeric_limits<float>::infinity();
-    float output_max = +std::numeric_limits<float>::infinity();
-    if (div_params != nullptr) {
-      TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
-          logging_context, node_index, div_params->activation, &output_min,
-          &output_max));
+        /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, input2_id, op_type, node_index));
+
+    // Check the input/output tensor types.
+    switch (op_type) {
+      case BuiltinOperator_ADD:
+      case BuiltinOperator_MUL:
+      case BuiltinOperator_SUB:
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type(
+            delegate, logging_context, input1_tensor, input1_id, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type(
+            delegate, logging_context, input2_tensor, input2_id, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type(
+            delegate, logging_context, output_tensor, output_id, node_index));
+        if (input1_tensor.type != input2_tensor.type ||
+            input1_tensor.type != output_tensor.type) {
+          TF_LITE_MAYBE_KERNEL_LOG(
+              logging_context, "unsupported mixed types in %s operator #%d",
+              EnumNameBuiltinOperator(op_type), node_index);
+          return kTfLiteError;
+        }
+        break;
+      case BuiltinOperator_DIV:
+      case BuiltinOperator_MAXIMUM:
+      case BuiltinOperator_MINIMUM:
+      case BuiltinOperator_PRELU:
+      case BuiltinOperator_SQUARED_DIFFERENCE:
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
+            logging_context, input1_tensor, input1_id, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
+            logging_context, input2_tensor, input2_id, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
+            logging_context, output_tensor, output_id, node_index));
+        break;
+      default:
+        TF_LITE_KERNEL_LOG(
+            logging_context,
+            "failed to delegate %s node #%d as a binary operator",
+            EnumNameBuiltinOperator(op_type), node_index);
+        return kTfLiteError;
+    }
+
+    // Extract any op-specific params.
+    float output_min = -std::numeric_limits<float>::infinity();
+    float output_max = +std::numeric_limits<float>::infinity();
+    switch (op_type) {
+      case BuiltinOperator_ADD: {
+        const float scale_min = 1.0f / 1024.0f;
+        const float scale_max = 256.0f;
+        TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale(
+            logging_context, input1_tensor, output_tensor, scale_min, scale_max,
+            op_type, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale(
+            logging_context, input2_tensor, output_tensor, scale_min, scale_max,
+            op_type, node_index));
+        const TfLiteAddParams* add_params =
+            static_cast<const TfLiteAddParams*>(node->builtin_data);
+        if (add_params != nullptr) {
+          TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
+              logging_context, node_index, add_params->activation, &output_min,
+              &output_max));
+        }
+        break;
+      }
+      case BuiltinOperator_DIV: {
+        const TfLiteDivParams* div_params =
+            static_cast<const TfLiteDivParams*>(node->builtin_data);
+        if (div_params != nullptr) {
+          TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
+              logging_context, node_index, div_params->activation, &output_min,
+              &output_max));
+        }
+        break;
+      }
+      case BuiltinOperator_MUL: {
+        const float scale_min = 1.0f / 65536.0f;
+        const float scale_max = 256.0f;
+        TF_LITE_ENSURE_STATUS(CheckTensorsInputProductOutputScale(
+            logging_context, input1_tensor, input2_tensor, output_tensor,
+            scale_min, scale_max, op_type, node_index));
+        const TfLiteMulParams* mul_params =
+            static_cast<const TfLiteMulParams*>(node->builtin_data);
+        if (mul_params != nullptr) {
+          TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
+              logging_context, node_index, mul_params->activation, &output_min,
+              &output_max));
+        }
+        break;
+      }
+      case BuiltinOperator_PRELU:
+        if (quasi_static_tensors.count(input2_id) == 0) {
+          TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
+              logging_context, input2_tensor, input2_id, op_type, node_index));
+        }
+        break;
+      case BuiltinOperator_SUB: {
+        const float scale_min = 1.0f / 1024.0f;
+        const float scale_max = 256.0f;
+        TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale(
+            logging_context, input1_tensor, output_tensor, scale_min, scale_max,
+            op_type, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale(
+            logging_context, input2_tensor, output_tensor, scale_min, scale_max,
+            op_type, node_index));
+        const TfLiteSubParams* sub_params =
+            static_cast<const TfLiteSubParams*>(node->builtin_data);
+        if (sub_params != nullptr) {
+          TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
+              logging_context, node_index, sub_params->activation, &output_min,
+              &output_max));
+        }
+        break;
+      }
+      default:
+        break;
     }

     if (subgraph != nullptr) {
-      const xnn_status status = xnn_define_divide(
-          subgraph, output_min, output_max,
-          /*input1_id=*/input_output_tensors.at(node->inputs->data[0]),
-          /*input2_id=*/input_output_tensors.at(node->inputs->data[1]),
-          /*output_id=*/input_output_tensors.at(node->outputs->data[0]),
-          /*flags=*/0);
+      // Set up the binary op params.
+      struct xnn_binary_params params;
+      params.output_min = output_min;
+      params.output_max = output_max;
+
+      // Set the binary op type and any special params associated with it.
+      enum xnn_binary_operator binary_op_type = xnn_binary_invalid;
+      switch (op_type) {
+        case BuiltinOperator_ADD:
+          binary_op_type = xnn_binary_add;
+          break;
+        case BuiltinOperator_DIV:
+          binary_op_type = xnn_binary_divide;
+          break;
+        case BuiltinOperator_MAXIMUM:
+          binary_op_type = xnn_binary_maximum;
+          break;
+        case BuiltinOperator_MINIMUM:
+          binary_op_type = xnn_binary_minimum;
+          break;
+        case BuiltinOperator_MUL:
+          binary_op_type = xnn_binary_multiply;
+          break;
+        case BuiltinOperator_PRELU:
+          binary_op_type = xnn_binary_prelu;
+          break;
+        case BuiltinOperator_SQUARED_DIFFERENCE:
+          binary_op_type = xnn_binary_squared_difference;
+          break;
+        case BuiltinOperator_SUB:
+          binary_op_type = xnn_binary_subtract;
+          break;
+        default:
+          TF_LITE_KERNEL_LOG(
+              logging_context,
+              "failed to delegate %s node #%d as a binary operator",
+              EnumNameBuiltinOperator(op_type), node_index);
+          return kTfLiteError;
+      }
+
+      // Create the subgraph node.
+      const xnn_status status =
+          xnn_define_binary(subgraph, binary_op_type, &params,
+                            /*input1_id=*/input_output_tensors.at(input1_id),
+                            /*input2_id=*/input_output_tensors.at(input2_id),
+                            /*output_id=*/input_output_tensors.at(output_id),
+                            /*flags=*/0);
       if (status != xnn_status_success) {
-        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d",
-                           EnumNameBuiltinOperator(BuiltinOperator_DIV),
-                           node_index);
+        TF_LITE_KERNEL_LOG(
+            logging_context,
+            "failed to delegate %s node #%d (binary_op_type=%i, status=%i)",
+            EnumNameBuiltinOperator(op_type), node_index,
+            binary_op_type, status);
         return kTfLiteError;
       }
     }
@@ -4233,33 +4084,257 @@ class Subgraph {
     return kTfLiteOk;
   }

-  static TfLiteStatus VisitEluNode(
+  static TfLiteStatus VisitUnaryNode(
       xnn_subgraph_t subgraph, const Delegate& delegate,
       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
-      const TfLiteTensor* tensors,
+      tflite::BuiltinOperator op_type, const TfLiteTensor* tensors,
       const std::unordered_map<int, uint32_t>& input_output_tensors) {
-    TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs(
-        logging_context, node, 1, 1, BuiltinOperator_ELU, node_index));
-
-    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
-    TF_LITE_ENSURE_STATUS(
-        CheckTensorFloat32OrQInt8Type(delegate, logging_context, input_tensor,
-                                      node->inputs->data[0], node_index));
-
-    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
-    TF_LITE_ENSURE_STATUS(
-        CheckTensorFloat32OrQInt8Type(delegate, logging_context, output_tensor,
-                                      node->outputs->data[0], node_index));
+    // Get the input and output tensors.
+    TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs(logging_context, node, 1, 1,
+                                                   op_type, node_index));
+    const int input_id = node->inputs->data[0];
+    const int output_id = node->outputs->data[0];
+    const TfLiteTensor& input_tensor = tensors[input_id];
+    const TfLiteTensor& output_tensor = tensors[output_id];
+
+    // Check the input tensor shape.
+    TF_LITE_ENSURE_STATUS(CheckTensorShape(
+        logging_context, input_tensor, /*min_num_dims=*/0,
+        /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, input_id, op_type, node_index));
+
+    // Check the input/output tensor types.
+    switch (op_type) {
+      case BuiltinOperator_ABS:
+      case BuiltinOperator_CEIL:
+      case BuiltinOperator_COS:
+      case BuiltinOperator_FLOOR:
+      case BuiltinOperator_GELU:
+      case BuiltinOperator_HARD_SWISH:
+      case BuiltinOperator_NEG:
+      case BuiltinOperator_RELU_N1_TO_1:
+      case BuiltinOperator_RELU:
+      case BuiltinOperator_RELU6:
+      case BuiltinOperator_ROUND:
+      case BuiltinOperator_RSQRT:
+      case BuiltinOperator_SIN:
+      case BuiltinOperator_SQRT:
+      case BuiltinOperator_SQUARE:
+      case BuiltinOperator_TANH:
+        TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
+            logging_context, input_tensor, input_id, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
+            logging_context, output_tensor, output_id, node_index));
+        break;
+      case BuiltinOperator_DEQUANTIZE:
+        TF_LITE_ENSURE_STATUS(CheckTensorQInt8OrQUInt8Type(
+            delegate, logging_context, input_tensor, input_id, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
+            logging_context, output_tensor, output_id, node_index));
+        break;
+      case BuiltinOperator_ELU:
+      case BuiltinOperator_LOGISTIC:
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQInt8Type(
+            delegate, logging_context, input_tensor, input_id, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQInt8Type(
+            delegate, logging_context, output_tensor, output_id, node_index));
+        break;
+      case BuiltinOperator_LEAKY_RELU: {
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type(
+            delegate, logging_context, input_tensor, input_id, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type(
+            delegate, logging_context, output_tensor, output_id, node_index));
+        const TfLiteLeakyReluParams* leaky_relu_params =
+            static_cast<const TfLiteLeakyReluParams*>(node->builtin_data);
+        if (!std::isnormal(leaky_relu_params->alpha) ||
+            leaky_relu_params->alpha == 0.0f) {
+          TF_LITE_MAYBE_KERNEL_LOG(
+              logging_context, "unsupported alpha %g in LEAKY_RELU node #%d",
+              leaky_relu_params->alpha, node_index);
+          return kTfLiteError;
+        }
+        const float input_scale =
+            GetTensorScaleOrDefault(input_tensor, std::nanf(""));
+        const float output_scale =
+            GetTensorScaleOrDefault(output_tensor, std::nanf(""));
+        if (std::isnormal(input_scale) && std::isnormal(output_scale)) {
+          const float positive_scale = input_scale / output_scale;
+          if (positive_scale < 1.0f / 256.0f || positive_scale > 128.0f) {
+            TF_LITE_MAYBE_KERNEL_LOG(
+                logging_context,
+                "unsupported positive input-to-output scale "
+                "%g in LEAKY_RELU node #%d",
+                positive_scale, node_index);
+            return kTfLiteError;
+          }
+          const float negative_scale =
+              positive_scale * leaky_relu_params->alpha;
+          if (negative_scale < -127.99609375f || negative_scale > 128.0f ||
+              std::fabs(negative_scale) < 1.0f / 256.0f) {
+            TF_LITE_MAYBE_KERNEL_LOG(
+                logging_context,
+                "unsupported negative input-to-output scale "
+                "%g in LEAKY_RELU node #%d",
+                negative_scale, node_index);
+            return kTfLiteError;
+          }
+        }
+        break;
+      }
+      case BuiltinOperator_QUANTIZE: {
+        TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type(
+            delegate, logging_context, input_tensor, input_id, node_index));
+        TF_LITE_ENSURE_STATUS(CheckTensorQInt8OrQUInt8Type(
+            delegate, logging_context, output_tensor, output_id, node_index));
+        const xnn_datatype input_datatype =
+            GetXNNPackDatatype(logging_context, input_tensor, input_id);
+        const xnn_datatype output_datatype =
+            GetXNNPackDatatype(logging_context, output_tensor, output_id);
+        bool supported_combination = false;
+        switch (input_datatype) {
+          case xnn_datatype_fp32:
+            supported_combination = true;
+            break;
+          case xnn_datatype_qint8:
+          case xnn_datatype_quint8:
+            if (input_datatype == output_datatype) {
+              const float input_scale =
+                  GetTensorScaleOrDefault(input_tensor, std::nanf(""));
+              const float output_scale =
+                  GetTensorScaleOrDefault(output_tensor, std::nanf(""));
+              const float input_output_scale = input_scale / output_scale;
+              if (input_output_scale < 1.0f / 256.0f ||
+                  input_output_scale > 128.0f) {
+                TF_LITE_MAYBE_KERNEL_LOG(
+                    logging_context,
+                    "unsupported input-to-output scale in QUANTIZE node #%d",
+                    node_index);
+                return kTfLiteError;
+              }
+              supported_combination = true;
+            }
+            break;
+          default:
+            break;
+        }
+        if (!supported_combination) {
+          TF_LITE_MAYBE_KERNEL_LOG(
+              logging_context,
+              "unsupported combination of input type (%s) and "
+              "output type (%s) in QUANTIZE node #%d",
+              TfLiteTypeGetName(input_tensor.type),
+              TfLiteTypeGetName(output_tensor.type), node_index);
+          return kTfLiteError;
+        }
+        break;
+      }
+      default:
+        TF_LITE_KERNEL_LOG(
+            logging_context,
+            "failed to delegate %s node #%d as a unary operator",
+            EnumNameBuiltinOperator(op_type), node_index);
+        return kTfLiteError;
+    }

     if (subgraph != nullptr) {
-      const xnn_status status = xnn_define_elu(
-          subgraph, /*alpha=*/1.0f,
-          /*input_id=*/input_output_tensors.at(node->inputs->data[0]),
-          /*output_id=*/input_output_tensors.at(node->outputs->data[0]),
-          /*flags=*/0);
+      // Set up the unary op params.
+      union xnn_unary_params params;
+
+      // Set the unary op type and any special params associated with it.
+      enum xnn_unary_operator unary_op_type = xnn_unary_invalid;
+      switch (op_type) {
+        case BuiltinOperator_ABS:
+          unary_op_type = xnn_unary_abs;
+          break;
+        case BuiltinOperator_CEIL:
+          unary_op_type = xnn_unary_ceiling;
+          break;
+        case BuiltinOperator_COS:
+          unary_op_type = xnn_unary_cosine;
+          break;
+        case BuiltinOperator_DEQUANTIZE:
+        case BuiltinOperator_QUANTIZE:
+          unary_op_type = xnn_unary_convert;
+          break;
+        case BuiltinOperator_ELU:
+          unary_op_type = xnn_unary_elu;
+          params.elu.alpha = 1.0f;
+          break;
+        case BuiltinOperator_FLOOR:
+          unary_op_type = xnn_unary_floor;
+          break;
+        case BuiltinOperator_GELU: {
+          const TfLiteGeluParams* gelu_params =
+              static_cast<const TfLiteGeluParams*>(node->builtin_data);
+          unary_op_type =
+              gelu_params->approximate ? xnn_unary_approxgelu : xnn_unary_gelu;
+          break;
+        }
+        case BuiltinOperator_HARD_SWISH:
+          unary_op_type = xnn_unary_hardswish;
+          break;
+        case BuiltinOperator_LEAKY_RELU: {
+          const TfLiteLeakyReluParams* leaky_relu_params =
+              static_cast<const TfLiteLeakyReluParams*>(node->builtin_data);
+          params.leaky_relu.negative_slope = leaky_relu_params->alpha;
+          unary_op_type = xnn_unary_leaky_relu;
+          break;
+        }
+        case BuiltinOperator_LOGISTIC:
+          unary_op_type = xnn_unary_sigmoid;
+          break;
+        case BuiltinOperator_NEG:
+          unary_op_type = xnn_unary_negate;
+          break;
+        case BuiltinOperator_RELU:
+          params.clamp.min = 0.0f;
+          params.clamp.max = std::numeric_limits<float>::infinity();
+          unary_op_type = xnn_unary_clamp;
+          break;
+        case BuiltinOperator_RELU_N1_TO_1:
+          params.clamp.min = -1.0f;
+          params.clamp.max = 1.0f;
+          unary_op_type = xnn_unary_clamp;
+          break;
+        case BuiltinOperator_RELU6:
+          params.clamp.min = 0.0f;
+          params.clamp.max = 6.0f;
+          unary_op_type = xnn_unary_clamp;
+          break;
+        case BuiltinOperator_ROUND:
+          unary_op_type = xnn_unary_bankers_rounding;
+          break;
+        case BuiltinOperator_RSQRT:
+          unary_op_type = xnn_unary_reciprocal_square_root;
+          break;
+        case BuiltinOperator_SIN:
+          unary_op_type = xnn_unary_sine;
+          break;
+        case BuiltinOperator_SQRT:
+          unary_op_type = xnn_unary_square_root;
+          break;
+        case BuiltinOperator_SQUARE:
+          unary_op_type = xnn_unary_square;
+          break;
+        case BuiltinOperator_TANH:
+          unary_op_type = xnn_unary_tanh;
+          break;
+        default:
+          TF_LITE_KERNEL_LOG(
+              logging_context,
+              "failed to delegate %s node #%d as a unary operator",
+              EnumNameBuiltinOperator(op_type), node_index);
+          return kTfLiteError;
+      }
+
+      // Create the subgraph node.
+      const xnn_status status =
+          xnn_define_unary(subgraph, unary_op_type, &params,
+                           /*input_id=*/input_output_tensors.at(input_id),
+                           /*output_id=*/input_output_tensors.at(output_id),
+                           /*flags=*/0);
       if (status != xnn_status_success) {
         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d",
-                           EnumNameBuiltinOperator(BuiltinOperator_ELU),
+                           EnumNameBuiltinOperator(op_type),
                            node_index);
         return kTfLiteError;
       }
@@ -4534,233 +4609,24 @@ class Subgraph {
         return kTfLiteError;
       }
     } else {
-      const xnn_status status = xnn_define_fully_connected(
-          subgraph, output_min, output_max,
-          /*input_id=*/input_output_tensors.at(node->inputs->data[0]),
-          /*filter_id=*/input_output_tensors.at(node->inputs->data[1]),
-          /*bias_id=*/bias_tensor_id >= 0
-              ? input_output_tensors.at(bias_tensor_id)
-              : XNN_INVALID_VALUE_ID,
-          /*output_id=*/input_output_tensors.at(node->outputs->data[0]),
-          /*flags=*/fc_params->keep_num_dims ?
0 - : XNN_FLAG_TENSORFLOW_RESHAPE_2D); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG( - logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_FULLY_CONNECTED), - node_index); - return kTfLiteError; - } - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitFloorNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_FLOOR, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_floor( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_FLOOR), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitGeluNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_GELU, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - const TfLiteGeluParams* gelu_params = - static_cast(node->builtin_data); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_unary( - subgraph, - /*type=*/gelu_params->approximate ? 
xnn_unary_approxgelu - : xnn_unary_gelu, - /*params=*/nullptr, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_GELU), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitHardSwishNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_HARD_SWISH, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_hardswish( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_HARD_SWISH), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitLeakyReluNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const TfLiteLeakyReluParams* leaky_relu_params, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_LEAKY_RELU, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - if (!std::isnormal(leaky_relu_params->alpha) || - leaky_relu_params->alpha == 0.0f) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unsupported alpha %g in LEAKY_RELU node #%d", - leaky_relu_params->alpha, node_index); - return kTfLiteError; - } - - const float input_scale = - GetTensorScaleOrDefault(input_tensor, std::nanf("")); - const float output_scale = - GetTensorScaleOrDefault(output_tensor, std::nanf("")); - if (std::isnormal(input_scale) && std::isnormal(output_scale)) { - const float positive_scale = input_scale / output_scale; - if (positive_scale < 1.0f / 256.0f || positive_scale > 128.0f) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unsupported positive input-to-output scale " - "%g in LEAKY_RELU node #%d", - positive_scale, node_index); - return kTfLiteError; - } - - const float negative_scale = positive_scale * leaky_relu_params->alpha; - if (negative_scale < -127.99609375f || negative_scale > 128.0f || - std::fabs(negative_scale) < 1.0f / 256.0f) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unsupported 
negative input-to-output scale " - "%g in LEAKY_RELU node #%d", - negative_scale, node_index); - return kTfLiteError; - } - } - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_leaky_relu( - subgraph, leaky_relu_params->alpha, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_LEAKY_RELU), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitLogisticNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_LOGISTIC, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_sigmoid( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_LOGISTIC), - node_index); - return kTfLiteError; + const xnn_status status = xnn_define_fully_connected( + subgraph, output_min, output_max, + /*input_id=*/input_output_tensors.at(node->inputs->data[0]), + /*filter_id=*/input_output_tensors.at(node->inputs->data[1]), + /*bias_id=*/bias_tensor_id >= 0 + ? input_output_tensors.at(bias_tensor_id) + : XNN_INVALID_VALUE_ID, + /*output_id=*/input_output_tensors.at(node->outputs->data[0]), + /*flags=*/fc_params->keep_num_dims + ? 
0 + : XNN_FLAG_TENSORFLOW_RESHAPE_2D); + if (status != xnn_status_success) { + TF_LITE_KERNEL_LOG( + logging_context, "failed to delegate %s node #%d", + EnumNameBuiltinOperator(BuiltinOperator_FULLY_CONNECTED), + node_index); + return kTfLiteError; + } } } @@ -4892,44 +4758,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitMaximumNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_MAXIMUM, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input1_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input2_tensor, node->inputs->data[1], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_maximum2( - subgraph, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_MAXIMUM), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitMediaPipeDeconvolutionNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -5166,140 +4994,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitMinimumNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_MINIMUM, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input1_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input2_tensor, node->inputs->data[1], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_minimum2( - subgraph, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_MINIMUM), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitMulNode( - xnn_subgraph_t 
subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, const TfLiteMulParams* mul_params, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_MUL, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor, - node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input1_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_MUL, node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor, - node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input2_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], - BuiltinOperator_MUL, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - const float scale_min = 1.0f / 65536.0f; - const float scale_max = 256.0f; - TF_LITE_ENSURE_STATUS(CheckTensorsInputProductOutputScale( - logging_context, input1_tensor, input2_tensor, output_tensor, scale_min, - scale_max, BuiltinOperator_MUL, node_index)); - - float output_min = -std::numeric_limits::infinity(); - float output_max = +std::numeric_limits::infinity(); - if (mul_params != nullptr) { - TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( - logging_context, node_index, mul_params->activation, &output_min, - &output_max)); - } - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_multiply2( - subgraph, output_min, output_max, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_MUL), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitNegNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_NEG, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_negate( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node 
#%d", - EnumNameBuiltinOperator(BuiltinOperator_NEG), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitPadNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -5382,139 +5076,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitPreluNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_set& quasi_static_tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_PRELU, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input_tensor, 1, XNN_MAX_TENSOR_DIMS, - node->inputs->data[0], BuiltinOperator_PRELU, node_index)); - - const TfLiteTensor& slope_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, slope_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckSlopeTensorShape( - logging_context, slope_tensor, node->inputs->data[1], - BuiltinOperator_PRELU, node_index)); - if (quasi_static_tensors.count(node->inputs->data[1]) == 0) { - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, slope_tensor, node->inputs->data[1], - BuiltinOperator_PRELU, node_index)); - } - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, output_tensor, 1, XNN_MAX_TENSOR_DIMS, - node->outputs->data[0], BuiltinOperator_PRELU, node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_prelu( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*slope_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_PRELU), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitQuantizeNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_QUANTIZE, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorQInt8OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_QUANTIZE, node_index)); - - const xnn_datatype 
input_datatype = GetXNNPackDatatype( - logging_context, input_tensor, node->inputs->data[0]); - const xnn_datatype output_datatype = GetXNNPackDatatype( - logging_context, output_tensor, node->outputs->data[0]); - bool supported_combination = false; - switch (input_datatype) { - case xnn_datatype_fp32: - supported_combination = true; - break; - case xnn_datatype_qint8: - case xnn_datatype_quint8: - if (input_datatype == output_datatype) { - const float input_scale = - GetTensorScaleOrDefault(input_tensor, std::nanf("")); - const float output_scale = - GetTensorScaleOrDefault(output_tensor, std::nanf("")); - const float input_output_scale = input_scale / output_scale; - if (input_output_scale < 1.0f / 256.0f || - input_output_scale > 128.0f) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported input-to-output scale in QUANTIZE node #%d", - node_index); - return kTfLiteError; - } - supported_combination = true; - } - break; - default: - break; - } - if (!supported_combination) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unsupported combination of input type (%s) and " - "output type (%s) in QUANTIZE node #%d", - TfLiteTypeGetName(input_tensor.type), - TfLiteTypeGetName(output_tensor.type), - node_index); - return kTfLiteError; - } - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_convert( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_QUANTIZE), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitReadVariableNode( xnn_subgraph_t subgraph, Delegate& delegate, TfLiteContext* logging_context, int node_index, const TfLiteNode* node, @@ -5541,49 +5102,16 @@ class Subgraph { resource_tensor_id, &tensors[node->outputs->data[0]], logging_context); } else { - const xnn_status status = xnn_define_copy( - subgraph, input_output_tensors.at(resource_tensor_id), - input_output_tensors.at(output_tensor_id), 0 /* flags */); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG( - logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_READ_VARIABLE), node_index); - return kTfLiteError; - } - } - return kTfLiteOk; - } - - static TfLiteStatus VisitReluNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, float output_min, float output_max, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_RELU, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_clamp( - subgraph, output_min, output_max, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); + const xnn_status status = xnn_define_copy( + subgraph, 
input_output_tensors.at(resource_tensor_id), + input_output_tensors.at(output_tensor_id), 0 /* flags */); if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_RELU), - node_index); + TF_LITE_KERNEL_LOG( + logging_context, "failed to delegate %s node #%d", + EnumNameBuiltinOperator(BuiltinOperator_READ_VARIABLE), node_index); return kTfLiteError; } } - return kTfLiteOk; } @@ -5780,72 +5308,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitRoundNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_ROUND, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_bankers_rounding( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_ROUND), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitSinNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_SIN, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloatType( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloatType( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_unary( - subgraph, xnn_unary_sine, /*params=*/nullptr, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SIN), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitSliceNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -6121,78 +5583,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitSquareNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_SQUARE, 
node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_SQUARE, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_square( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SQUARE), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitTanhNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_TANH, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_tanh( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_TANH), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitTransposeNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -6240,119 +5630,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitSqrtNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_SQRT, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_square_root( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - 
TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SQRT), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitRsqrtNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_RSQRT, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_reciprocal_square_root( - subgraph, /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_RSQRT), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitSquaredDifferenceNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_SQUARED_DIFFERENCE, - node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input1_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_SQUARED_DIFFERENCE, node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input2_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input2_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], - BuiltinOperator_SQUARED_DIFFERENCE, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_squared_difference( - subgraph, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG( - logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SQUARED_DIFFERENCE), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitStridedSliceNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -6956,72 +6233,6 @@ class 
Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitSubNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, const TfLiteSubParams* sub_params, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_SUB, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor, - node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input1_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_SUB, node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor, - node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input2_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], - BuiltinOperator_SUB, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - const float scale_min = 1.0f / 1024.0f; - const float scale_max = 256.0f; - TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( - logging_context, input1_tensor, output_tensor, scale_min, scale_max, - BuiltinOperator_SUB, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( - logging_context, input2_tensor, output_tensor, scale_min, scale_max, - BuiltinOperator_SUB, node_index)); - - float output_min = -std::numeric_limits::infinity(); - float output_max = +std::numeric_limits::infinity(); - if (sub_params != nullptr) { - TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( - logging_context, node_index, sub_params->activation, &output_min, - &output_max)); - } - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_subtract( - subgraph, output_min, output_max, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SUB), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitTransposeConvNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, From 646deb57e9b42af4dfd6f335afa3adf692c87762 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 09:06:32 -0700 Subject: [PATCH 0302/1324] Remove Diff Eval logging as eval runs are not calculated. 
PiperOrigin-RevId: 744742094 --- third_party/xla/xla/hlo/tools/hlo_diff/BUILD | 1 - third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc | 2 -- 2 files changed, 3 deletions(-) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/BUILD index 9feb12fbcadbd7..3c7639a89de630 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_diff/BUILD @@ -172,7 +172,6 @@ xla_cc_binary( name = "hlo_diff", srcs = ["hlo_diff_main.cc"], deps = [ - ":hlo_diff_eval", ":hlo_diff_result", ":hlo_diff_summary", ":hlo_gumgraph_diff", diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc index b38cd4ae5aed01..4d072bbc46177c 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc @@ -29,7 +29,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/parser/hlo_parser.h" -#include "xla/hlo/tools/hlo_diff/hlo_diff_eval.h" #include "xla/hlo/tools/hlo_diff/hlo_diff_result.h" #include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h" #include "xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.h" @@ -160,7 +159,6 @@ absl::Status RunGumgraphDiff(HloModule& first_module, HloModule& second_module, const DiffResult& diff = *hlo_gumgraph_diff.diff_result; const DiffSummary& diff_summary = *hlo_gumgraph_diff.diff_summary; LogDiffResult(diff); - LogDiffEval(*hlo_gumgraph_diff.diff_eval); std::ostringstream text; RenderTextSummary(diff, text); std::cout << text.str() << '\n'; From 276ebb155ff94adf8ece553dcf0a3d2e455d6e4e Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Mon, 7 Apr 2025 11:47:32 -0500 Subject: [PATCH 0303/1324] [mlir][tosa] Fix up lit tests (#90573) This fixes a couple of failing lit tests due to tosa llvm updates Change-Id: I00136b6454f742673703138cabf798c64382bd8d Signed-off-by: Tai Ly --- .../mlir/tosa/tests/tf-to-tosa-pipeline.mlir | 10 +++++----- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 18 +++++++++--------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index da75738e80d1cf..502ec7e74ade06 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -117,12 +117,12 @@ func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) // CHECK-LABEL: func.func @test_floor_div( // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xi32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> // CHECK: %[[VAL_5:.*]] = tosa.int_div %[[VAL_0]], %[[VAL_1]] : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_0]], 
%[[VAL_1]], %[[VAL_2]] : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor) -> tensor<13x21x3xi32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]], %[[VAL_2]] : (tensor<13x1x3xi32>, tensor<13x21x3xi32>, tensor) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]], %[[VAL_2]] : (tensor<13x1x3xi32>, tensor<13x21x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> // CHECK: %[[VAL_8:.*]] = tosa.equal %[[VAL_0]], %[[VAL_7]] : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_9:.*]] = tosa.logical_not %[[VAL_8]] : (tensor<13x21x3xi1>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_10:.*]] = tosa.greater %[[VAL_3]], %[[VAL_6]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 080db037fc3215..8623a886b353eb 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -102,7 +102,7 @@ func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform : tensor<16x2x2x8xi8>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<16xi8>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<16xi32>}> // CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAL_6:.*]] = tosa.conv2d %arg0, %[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_5]] {acc_type = i32, dilation = array, pad = array, stride = array} // CHECK: %[[VAL_7:.*]] = tosa.rescale %[[VAL_6]] @@ -440,7 +440,7 @@ func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-LABEL: test_exp_qi8 // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<256xi8>}> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<256xi8>}> // CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] func.func @test_exp_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { %0 = "tfl.exp"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> @@ -451,7 +451,7 @@ func.func @test_exp_qi8(%arg0: tensor<13x21x3x!quant.uniform> -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<513xi16>}> // CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] func.func @test_exp_qi16(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { %0 = "tfl.exp"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> @@ -483,16 +483,16 @@ func.func @test_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*x // CHECK-LABEL: func.func @test_floor_div( // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xi32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor<13x21x3xi32> { -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> -// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> 
!tosa.shape<3> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_5]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xi32> // CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_0]], %[[VAL_6]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>) -> tensor<13x21x3xi32> // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_5]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xi32> -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_0]], %[[VAL_8]], %[[VAL_2]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>, tensor) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_0]], %[[VAL_8]], %[[VAL_2]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> // CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_5]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xi32> -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_7]], %[[VAL_2]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>, tensor) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_7]], %[[VAL_2]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> // CHECK: %[[VAL_12:.*]] = tosa.equal %[[VAL_0]], %[[VAL_11]] : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_13:.*]] = tosa.logical_not %[[VAL_12]] : (tensor<13x21x3xi1>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_14:.*]] = tosa.greater %[[VAL_3]], %[[VAL_9]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> From f4d5c9db907b348ef51310d492f74705c03911a8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 09:39:08 -0700 Subject: [PATCH 0304/1324] Add delegation of `BuiltinOperator_EXP` to XNNPACK. 
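EXP reuses the delegate's existing element-wise unary plumbing: the builtin is
accepted by the node checker alongside the other unary ops and lowered through
the generic unary visitor. The only substantive new wiring is the
builtin-to-XNNPACK op mapping, quoted here verbatim from the
xnnpack_delegate.cc hunk below (the surrounding switch and the
`unary_op_type` variable already exist there):

    case BuiltinOperator_EXP:
      unary_op_type = xnn_unary_exp;
      break;

The unary element-wise tests pick up the new case by listing
BuiltinOperator_EXP in all_unary_ops.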
PiperOrigin-RevId: 744751517 --- .../xnnpack/unary_elementwise_test.cc | 19 ++++++++++--------- .../delegates/xnnpack/xnnpack_delegate.cc | 5 +++++ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc b/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc index 4986ad9707b302..afd77fad0607ad 100644 --- a/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc +++ b/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc @@ -142,15 +142,16 @@ TEST_P(UnaryTest, MultiThreading) { } BuiltinOperator all_unary_ops[] = { - BuiltinOperator_ABS, BuiltinOperator_CEIL, - BuiltinOperator_COS, BuiltinOperator_ELU, - BuiltinOperator_FLOOR, BuiltinOperator_GELU, - BuiltinOperator_NEG, BuiltinOperator_HARD_SWISH, - BuiltinOperator_RELU, BuiltinOperator_RELU6, - BuiltinOperator_RELU_N1_TO_1, BuiltinOperator_ROUND, - BuiltinOperator_RSQRT, BuiltinOperator_SIN, - BuiltinOperator_SQRT, BuiltinOperator_SQUARE, - BuiltinOperator_TANH, BuiltinOperator_LOGISTIC, + BuiltinOperator_ABS, BuiltinOperator_CEIL, + BuiltinOperator_COS, BuiltinOperator_ELU, + BuiltinOperator_EXP, BuiltinOperator_FLOOR, + BuiltinOperator_GELU, BuiltinOperator_NEG, + BuiltinOperator_HARD_SWISH, BuiltinOperator_RELU, + BuiltinOperator_RELU6, BuiltinOperator_RELU_N1_TO_1, + BuiltinOperator_ROUND, BuiltinOperator_RSQRT, + BuiltinOperator_SIN, BuiltinOperator_SQRT, + BuiltinOperator_SQUARE, BuiltinOperator_TANH, + BuiltinOperator_LOGISTIC, }; INSTANTIATE_TEST_SUITE_P( diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 3372324fa4f560..7d32e5c9a450c7 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -2842,6 +2842,7 @@ class Subgraph { case kTfLiteBuiltinCos: case kTfLiteBuiltinDequantize: case kTfLiteBuiltinElu: + case kTfLiteBuiltinExp: case kTfLiteBuiltinFloor: case kTfLiteBuiltinGelu: case kTfLiteBuiltinHardSwish: @@ -4107,6 +4108,7 @@ class Subgraph { case BuiltinOperator_ABS: case BuiltinOperator_CEIL: case BuiltinOperator_COS: + case BuiltinOperator_EXP: case BuiltinOperator_FLOOR: case BuiltinOperator_GELU: case BuiltinOperator_HARD_SWISH: @@ -4259,6 +4261,9 @@ class Subgraph { unary_op_type = xnn_unary_elu; params.elu.alpha = 1.0f; break; + case BuiltinOperator_EXP: + unary_op_type = xnn_unary_exp; + break; case BuiltinOperator_FLOOR: unary_op_type = xnn_unary_floor; break; From 3c42f578e34f1a429d6524970bd52ae20e874723 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 7 Apr 2025 10:14:27 -0700 Subject: [PATCH 0305/1324] PR #24599: [GPU] Print contents of command buffer thunks. Imported from GitHub PR https://github.com/openxla/xla/pull/24599 This makes thunk_sequence.txt dumps more informative. Copybara import of the project: -- e369633fb1e9a32dcacd3c922f8d2de3a15b69a5 by Ilia Sergachev : [GPU] Print contents of command buffer thunks. This makes thunk_sequence.txt dumps more informative. 
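Concretely, the override added to command_buffer_thunk.cc (quoted verbatim
from the diff in this patch) starts a fresh line and delegates to the nested
thunk sequence at one extra indent level:

    std::string CommandBufferThunk::ToString(int indent) const {
      std::string result = "\n";
      absl::StrAppend(&result, thunks_->ToString(indent + 1));
      return result;
    }

As a result, the commands recorded into a command buffer appear indented
beneath the command buffer thunk's own entry; the new test checks that an
indented kMemset32BitValue line shows up in the output.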
fix test

Merging this change closes #24599

PiperOrigin-RevId: 744764038
---
 .../backends/gpu/runtime/command_buffer_thunk.cc |  7 +++++++
 .../backends/gpu/runtime/command_buffer_thunk.h  |  2 ++
 .../gpu/runtime/command_buffer_thunk_test.cc     | 16 ++++++++++++++++
 3 files changed, 25 insertions(+)

diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc
index 9e2db69aa8106a..a13f590696506f 100644
--- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc
+++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc
@@ -348,4 +348,11 @@ void CommandBufferThunk::ForAllThunks(
     thunks_->ForAllThunks(fn);
   }
 }
+
+std::string CommandBufferThunk::ToString(int indent) const {
+  std::string result = "\n";
+  absl::StrAppend(&result, thunks_->ToString(indent + 1));
+  return result;
+}
+
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.h
index 63fe04a682ef26..c8f248b67be198 100644
--- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.h
+++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.h
@@ -57,6 +57,8 @@ class CommandBufferThunk : public Thunk {
   void ForAllThunks(absl::FunctionRef<void(const Thunk*)> fn) const override;
 
+  std::string ToString(int indent) const override;
+
  private:
   // Command buffer instantiated on a `se::StreamExecutor` instance, and
   // auxiliary state required for efficient command buffer updates.
diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc
index 5de0bd6ee0b910..c759af5746bf23 100644
--- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc
+++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc
@@ -1463,4 +1463,20 @@ ENTRY main.49 {
   EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-3, 2e-3}));
 }
 
+TEST(CommandBufferThunkTest, ToStringPrintsNestedThunks) {
+  BufferAllocation alloc_a(/*index=*/0, /*size=*/4, /*color=*/0);
+  BufferAllocation::Slice slice_a(&alloc_a, /*offset=*/0, /*size=*/4);
+  CommandBufferCmdSequence::Builder builder;
+  builder.Emplace<Memset32Cmd>(s0, slice_a, int32_t{42});
+  CommandBufferCmdSequence commands = std::move(builder).Build(serialize);
+  std::vector<std::unique_ptr<Thunk>> thunks;
+  thunks.emplace_back(
+      std::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), 42, slice_a));
+  CommandBufferThunk thunk(
+      std::move(commands), Thunk::ThunkInfo(),
+      std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks)));
+  EXPECT_TRUE(
+      absl::StrContains(thunk.ToString(/*indent=*/1), " kMemset32BitValue"));
+}
+
 }  // namespace xla::gpu

From 5ebed7532d62de661bef04fe01b666bb69b098de Mon Sep 17 00:00:00 2001
From: Junwhan Ahn
Date: Mon, 7 Apr 2025 10:31:13 -0700
Subject: [PATCH 0306/1324] Fix a missing default initialization for the cached hash of `BasicDeviceList`

The constructor of `std::atomic` didn't perform value initialization until C++20.

PiperOrigin-RevId: 744770279
---
 third_party/xla/xla/python/ifrt/basic_device_list.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/xla/xla/python/ifrt/basic_device_list.h b/third_party/xla/xla/python/ifrt/basic_device_list.h
index 89a9db8b010f22..5231487547cbc2 100644
--- a/third_party/xla/xla/python/ifrt/basic_device_list.h
+++ b/third_party/xla/xla/python/ifrt/basic_device_list.h
@@ -101,7 +101,7 @@ class BasicDeviceList : public llvm::RTTIExtends<BasicDeviceList, DeviceList> {
   // Cached hash.
0 indicates the hash needs to be computed and cached.
   // May be written multiple times with the same non-zero value.
   static constexpr uint64_t kUnsetHash = 0;
-  mutable std::atomic<uint64_t> hash_;
+  mutable std::atomic<uint64_t> hash_ = kUnsetHash;
 };
 
 }  // namespace ifrt

From 17a7c51c6a2d23a8945f0c13deeadcb31c42cb05 Mon Sep 17 00:00:00 2001
From: Nicolas Perez
Date: Mon, 7 Apr 2025 10:45:29 -0700
Subject: [PATCH 0307/1324] Return error if tpu topology is not available when getting number of cores per chip.

This function is added to prevent returning cores_per_chip=4 when targeting hardware with 2 SCs.

PiperOrigin-RevId: 744775665
---
 third_party/xla/xla/stream_executor/tpu/BUILD | 1 +
 .../xla/xla/stream_executor/tpu/tpu_library_init_fns.inc | 1 +
 third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h | 7 +++++++
 3 files changed, 9 insertions(+)

diff --git a/third_party/xla/xla/stream_executor/tpu/BUILD b/third_party/xla/xla/stream_executor/tpu/BUILD
index 3623c2b86a6423..09f7da009fcbdf 100644
--- a/third_party/xla/xla/stream_executor/tpu/BUILD
+++ b/third_party/xla/xla/stream_executor/tpu/BUILD
@@ -176,6 +176,7 @@ cc_library(
     deps = [
         ":c_api_decl",
         ":libtftpu_header",
+        "@com_google_absl//absl/status:statusor",
     ],
     alwayslink = True,
 )
diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_library_init_fns.inc b/third_party/xla/xla/stream_executor/tpu/tpu_library_init_fns.inc
index 1ab49e97110a46..07fa04a011494a 100644
--- a/third_party/xla/xla/stream_executor/tpu/tpu_library_init_fns.inc
+++ b/third_party/xla/xla/stream_executor/tpu/tpu_library_init_fns.inc
@@ -75,6 +75,7 @@ absl::Status SetTpuOpsStructFns(void* library_handle) {  // TENSORFLOW_STATUS_OK
   TFTPU_SET_FN(ops_api_fn, TpuTopology_AvailableCoreCount);
   TFTPU_SET_FN(ops_api_fn, TpuTopology_AvailableCoresPerChip);
+  TFTPU_SET_FN(ops_api_fn, TpuTopology_MaybeAvailableCoresPerChip);
   TFTPU_SET_FN(ops_api_fn, TpuNetUtil_RecycleUnusedPort);
   TFTPU_SET_FN(ops_api_fn, TpuCompile_IsTpuCompilationEnabled);
   TFTPU_SET_FN(ops_api_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h b/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h
index 48eff9a389b0c6..fbc8fcec6962d9 100644
--- a/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h
+++ b/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h
@@ -20,6 +20,7 @@ limitations under the License.
 #include
 #include
 
+#include "absl/status/statusor.h"
 #include "xla/stream_executor/tpu/c_api_decl.h"
 #include "xla/stream_executor/tpu/libtftpu.h"
 
@@ -436,6 +437,11 @@ TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
 TpuCoreTypeEnum tpu_core_type);
 
 TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoresPerChip(
     TpuCoreTypeEnum tpu_core_type);
 
+// Returns the number of cores per chip, or an error if the TPU system is not
+// available.
+TFTPU_CAPI_EXPORT absl::StatusOr<int> TpuTopology_MaybeAvailableCoresPerChip(
+    TpuCoreTypeEnum tpu_core_type);
+
 // Recycle unused service port.
TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port); @@ -802,6 +808,7 @@ struct TfTpu_OpsApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoresPerChip); + TFTPU_ADD_FN_IN_STRUCT(TpuTopology_MaybeAvailableCoresPerChip); TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort); TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey); TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey); From 1a0bcb7be2d255fa46f996d2c2911ba94590db8d Mon Sep 17 00:00:00 2001 From: Vamsi Manchala Date: Mon, 7 Apr 2025 11:08:53 -0700 Subject: [PATCH 0308/1324] [MLIR] Add additional optimize patterns for GeluOp approximation. The existing patterns look for pow(x, 3). Adding additional patterns that check mul(mul(x, x), x) and mul(x, mul(x, x)) PiperOrigin-RevId: 744784056 --- .../compiler/mlir/lite/tests/optimize.mlir | 83 ++++++++++++ .../mlir/lite/transforms/optimize_patterns.td | 122 ++++++++++++++++++ 2 files changed, 205 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 6ad231ed9f54bc..83b82d50fc064f 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -3718,6 +3718,46 @@ func.func @gelu_approximate(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> } +func.func @gelu_approximate_with_mul(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%99, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%6, %5) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + +func.func @gelu_approximate_with_mul2(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%arg0, %99) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : 
(tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%6, %5) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + func.func @gelu_approximate1(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst = arith.constant dense<0.797884583> : tensor %cst_0 = arith.constant dense<5.000000e-01> : tensor @@ -3738,6 +3778,49 @@ func.func @gelu_approximate1(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> } +func.func @gelu_approximate1_with_mul(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_2 = arith.constant dense<3.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%99, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%5, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%arg0, %6) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + + +func.func @gelu_approximate1_with_mul1(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_2 = arith.constant dense<3.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%arg0, %99) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> 
tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%5, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%arg0, %6) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + func.func @gelu_approximate_no_match(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst = arith.constant dense<0.797884583> : tensor %cst_0 = arith.constant dense<5.000000e-01> : tensor diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 2e08d0f89959a9..1e97e8f42584b6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -1397,6 +1397,67 @@ def MatchGeluApproximate : Pat< (HasOneUse $pow_out), ]>; +// Alternate pattern for GeluApproximate to match mul(x, mul(x, x)). +// 0.5 * x * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(x, mul(x, x)) ) ) ) +def MatchGeluApproximate_Mul1 : Pat< + (TFL_MulOp + (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out $arg0, + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + +// Alternate pattern for GeluApproximate to match mul(mul(x, x), x). 
+// 0.5 * x * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(mul(x, x), x) ) ) ) +def MatchGeluApproximate_Mul2 : Pat< + (TFL_MulOp + (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), + $arg0, TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + // Alternate pattern for GeluApproximate (see different order for mul), replaces // x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * pow( x, 3 ) ) ) ) ) def MatchGeluApproximate1 : Pat< @@ -1426,6 +1487,67 @@ def MatchGeluApproximate1 : Pat< (HasOneUse $pow_out), ]>; +// Alternate pattern for GeluApproximate1 to match mul(x, mul(x, x)). +// x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(x, mul(x, x)) ) ) ) ) +def MatchGeluApproximate1_Mul1 : Pat< + (TFL_MulOp $arg0, + (TFL_MulOp:$mul_out + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out $arg0, + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + +// Alternate pattern for GeluApproximate1 to match mul(mul(x, x), x). 
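+// Unlike the MatchGeluApproximate variants above, in the GeluApproximate1
+// variants the 0.5 constant scales the (1 + tanh(...)) term rather than x,
+// matching the other association of the two outer multiplies.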
+// x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(mul(x, x), x) ) ) ) ) +def MatchGeluApproximate1_Mul2 : Pat< + (TFL_MulOp $arg0, + (TFL_MulOp:$mul_out + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), + $arg0, TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + // For Gelu, replaces // 0.5 * x * ( 1 + erf( x * sqrt_1_2 ) ) def MatchGelu : Pat< From b84c9cc9c9c6c98420e986322f9ed9a89d4c536e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 11:23:15 -0700 Subject: [PATCH 0309/1324] Bumping up libtpu version to pick correct versioned nightlies PiperOrigin-RevId: 744788989 --- tensorflow/tools/pip_package/setup.py.tpl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/pip_package/setup.py.tpl b/tensorflow/tools/pip_package/setup.py.tpl index c6191b05cec9d1..e2495eed665c43 100644 --- a/tensorflow/tools/pip_package/setup.py.tpl +++ b/tensorflow/tools/pip_package/setup.py.tpl @@ -55,8 +55,8 @@ from setuptools.dist import Distribution _VERSION = '0.0.0' # Update this version when a new libtpu stable version is released. -LATEST_RELEASE_LIBTPU_VERSION = '0.0.11' -NEXT_LIBTPU_VERSION = '0.0.12' +LATEST_RELEASE_LIBTPU_VERSION = '0.0.12' +NEXT_LIBTPU_VERSION = '0.0.13' # We use the same setup.py for all tensorflow_* packages and for the nightly # equivalents (tf_nightly_*). 
The package is controlled from the argument line From 99a29e5d0c31012b281df1fd4f95d9d11400a825 Mon Sep 17 00:00:00 2001 From: Daniel Sosa Date: Mon, 7 Apr 2025 11:23:21 -0700 Subject: [PATCH 0310/1324] remove obsolete code PiperOrigin-RevId: 744789021 --- tensorflow/lite/tools/pip_package/BUILD | 40 ---- tensorflow/lite/tools/pip_package/utils/BUILD | 23 -- .../utils/manylinux_compliance_test.py | 128 ----------- .../utils/py_manylinux_compliance_test.bzl | 27 --- .../lite/tools/pip_package/utils/py_wheel.bzl | 88 -------- .../tools/pip_package/utils/wheel_builder.py | 199 ------------------ tensorflow/opensource_only.files | 2 - 7 files changed, 507 deletions(-) delete mode 100644 tensorflow/lite/tools/pip_package/BUILD delete mode 100644 tensorflow/lite/tools/pip_package/utils/BUILD delete mode 100644 tensorflow/lite/tools/pip_package/utils/manylinux_compliance_test.py delete mode 100644 tensorflow/lite/tools/pip_package/utils/py_manylinux_compliance_test.bzl delete mode 100644 tensorflow/lite/tools/pip_package/utils/py_wheel.bzl delete mode 100644 tensorflow/lite/tools/pip_package/utils/wheel_builder.py diff --git a/tensorflow/lite/tools/pip_package/BUILD b/tensorflow/lite/tools/pip_package/BUILD deleted file mode 100644 index a7c4c254e8613c..00000000000000 --- a/tensorflow/lite/tools/pip_package/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tensorflow/lite/tools/pip_package/utils:py_manylinux_compliance_test.bzl", "verify_manylinux_compliance_test") -load("//tensorflow/lite/tools/pip_package/utils:py_wheel.bzl", "py_wheel") - -package( - default_visibility = ["//visibility:private"], -) - -MANYLINUX_X86_64_TAG = "manylinux_2_17_x86_64" - -genrule( - name = "setup_py", - srcs = ["//tensorflow/lite/tools/pip_package:setup_with_binary.py"], - outs = ["setup.py"], - cmd = "cat $< > $@", -) - -py_wheel( - name = "litert_wheel", - srcs = [ - "//tensorflow/lite/experimental/genai:pywrap_genai_ops.so", - "//tensorflow/lite/profiling/proto:model_runtime_info_py", - "//tensorflow/lite/profiling/proto:profiling_info_py", - "//tensorflow/lite/python:interpreter", - "//tensorflow/lite/python:schema_py", - "//tensorflow/lite/python/interpreter_wrapper:_pywrap_tensorflow_interpreter_wrapper.so", - "//tensorflow/lite/python/metrics:metrics_interface", - "//tensorflow/lite/python/metrics:metrics_portable.py", - ], - platform_name = MANYLINUX_X86_64_TAG, - setup_py = ":setup_py", - version = "1.1.2", -) - -verify_manylinux_compliance_test( - name = "manylinux_compliance_test", - aarch64_compliance_tag = "manylinux_2_17_aarch64", - ppc64le_compliance_tag = "manylinux_2_17_ppc64le", - wheel = ":litert_wheel", - x86_64_compliance_tag = MANYLINUX_X86_64_TAG, -) diff --git a/tensorflow/lite/tools/pip_package/utils/BUILD b/tensorflow/lite/tools/pip_package/utils/BUILD deleted file mode 100644 index 32762c6253e962..00000000000000 --- a/tensorflow/lite/tools/pip_package/utils/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tensorflow:pytype.default.bzl", "pytype_strict_binary") - -package( - default_visibility = ["//visibility:public"], -) - -filegroup( - name = "manylinux_compliance_test", - srcs = ["manylinux_compliance_test.py"], - visibility = ["//visibility:public"], -) - -pytype_strict_binary( - name = "wheel_builder", - srcs = [ - "wheel_builder.py", - ], - main = "wheel_builder.py", - deps = [ - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", - ], -) diff --git a/tensorflow/lite/tools/pip_package/utils/manylinux_compliance_test.py b/tensorflow/lite/tools/pip_package/utils/manylinux_compliance_test.py 
deleted file mode 100644 index 3ff09f3cb691cb..00000000000000 --- a/tensorflow/lite/tools/pip_package/utils/manylinux_compliance_test.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2025 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import argparse -import io -import platform -import re -import sys -from auditwheel import main_show - - -def parse_args(): - """Arguments parser.""" - parser = argparse.ArgumentParser( - description="Helper for manylinux compliance verification", - fromfile_prefix_chars="@", - ) - parser.add_argument( - "--wheel-path", required=True, help="Path of the wheel, mandatory" - ) - parser.add_argument( - "--aarch64-compliance-tag", - required=True, - help="ManyLinux compliance tag for aarch64", - ) - parser.add_argument( - "--x86_64-compliance-tag", - required=True, - help="ManyLinux compliance tag for x86_64", - ) - parser.add_argument( - "--ppc64le-compliance-tag", - required=True, - help="ManyLinux compliance tag for ppc64le", - ) - return parser.parse_args() - - -def get_auditwheel_output(wheel_path: str) -> None: - """Run "auditwheel show" on the wheel and return the output. - - Args: - wheel_path: path of the wheel file - - Returns: - "auditwheel show" output - """ - stringio = io.StringIO() - previous_stdout = sys.stdout - sys.stdout = stringio - - auditwheel_parser = argparse.ArgumentParser( - description="Cross-distro Python wheels." - ) - sub_parsers = auditwheel_parser.add_subparsers(metavar="command", dest="cmd") - main_show.configure_parser(sub_parsers) - auditwheel_args = argparse.Namespace( - WHEEL_FILE=wheel_path, - verbose=1, - ) - main_show.execute(args=auditwheel_args, p=auditwheel_parser) - - sys.stdout = previous_stdout - return stringio.getvalue() - - -def verify_manylinux_compliance( - auditwheel_log: str, - compliance_tag: str, -) -> None: - """Verify manylinux compliance. - - Args: - auditwheel_log: "auditwheel show" execution results - compliance_tag: manyLinux compliance tag - - Raises: - RuntimeError: if the wheel is not manyLinux compliant. - """ - regex = r'platform tag to\s+"{}"'.format(compliance_tag) - alt_regex = regex.replace("2014", "_2_17") - if not ( - re.search(regex, auditwheel_log) or re.search(alt_regex, auditwheel_log) - ): - raise RuntimeError( - ("The wheel is not compliant with the tag {tag}.\n{result}").format( - tag=compliance_tag, result=auditwheel_log - ) - ) - - -def test_manylinux_compliance(args): - machine_type = platform.uname().machine - supported_machine_types = ["x86_64", "aarch64", "ppc64le"] - if machine_type not in supported_machine_types: - raise RuntimeError( - "Unsupported machine type {machine_type}. 
The supported are:" - " {supported_types}".format( - machine_type=machine_type, supported_types=supported_machine_types - ) - ) - if machine_type == "x86_64": - compliance_tag = args.x86_64_compliance_tag - elif machine_type == "aarch64": - compliance_tag = args.aarch64_compliance_tag - else: - compliance_tag = args.ppc64le_compliance_tag - auditwheel_output = get_auditwheel_output(args.wheel_path) - verify_manylinux_compliance( - auditwheel_output, - compliance_tag, - ) - - -if __name__ == "__main__": - test_manylinux_compliance(parse_args()) diff --git a/tensorflow/lite/tools/pip_package/utils/py_manylinux_compliance_test.bzl b/tensorflow/lite/tools/pip_package/utils/py_manylinux_compliance_test.bzl deleted file mode 100644 index f7718ea36e8909..00000000000000 --- a/tensorflow/lite/tools/pip_package/utils/py_manylinux_compliance_test.bzl +++ /dev/null @@ -1,27 +0,0 @@ -""" Macros for manylinux compliance verification test. """ - -load("//tensorflow:strict.default.bzl", "py_strict_test") - -def verify_manylinux_compliance_test( - name, - wheel, - aarch64_compliance_tag, - x86_64_compliance_tag, - ppc64le_compliance_tag, - test_tags = []): - py_strict_test( - name = name, - srcs = [Label("//tensorflow/lite/tools/pip_package/utils:manylinux_compliance_test")], - data = [ - wheel, - ], - deps = ["@pypi_auditwheel//:pkg"], - args = [ - "--wheel-path=$(location {})".format(wheel), - "--aarch64-compliance-tag={}".format(aarch64_compliance_tag), - "--x86_64-compliance-tag={}".format(x86_64_compliance_tag), - "--ppc64le-compliance-tag={}".format(ppc64le_compliance_tag), - ], - main = "manylinux_compliance_test.py", - tags = ["manual"] + test_tags, - ) diff --git a/tensorflow/lite/tools/pip_package/utils/py_wheel.bzl b/tensorflow/lite/tools/pip_package/utils/py_wheel.bzl deleted file mode 100644 index 824e6f464c5c33..00000000000000 --- a/tensorflow/lite/tools/pip_package/utils/py_wheel.bzl +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2025 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Rule to build a python wheel generically. - -This rule is used to build a python wheel from a list of source files. It takes a list of source -files, a setup.py file, and a version string as input. It then uses a python script, -wheel_builder.py, to generate the wheel file. The wheel builder binary is responsible for preparing -the build environment and calling the setuptools command to generate the wheel file. 
-""" - -load( - "@python_version_repo//:py_version.bzl", - "HERMETIC_PYTHON_VERSION", -) - -def _get_full_wheel_name(wheel_name, version, platform_name): - python_version = HERMETIC_PYTHON_VERSION.replace(".", "") - wheel_version = version.replace("-dev", ".dev").replace("-", "") - return "{wheel_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl".format( - wheel_name = wheel_name, - wheel_version = wheel_version, - python_version = python_version, - wheel_platform_tag = platform_name, - ) - -def _py_wheel_impl(ctx): - executable = ctx.executable.wheel_binary - filelist_lists = [src.files.to_list() for src in ctx.attr.srcs] - filelist = [f for filelist in filelist_lists for f in filelist] - wheel_name = _get_full_wheel_name("ai_edge_litert", ctx.attr.version, ctx.attr.platform_name) - output_file = ctx.actions.declare_file("dist/{wheel_name}".format(wheel_name = wheel_name)) - - args = ctx.actions.args() - args.add("--setup_py", ctx.file.setup_py.path) - args.add("--output", output_file.dirname) - args.add("--version", ctx.attr.version) - - for f in filelist: - args.add("--src", f.path) - - if ctx.attr.platform_name: - args.add("--platform", ctx.attr.platform_name) - - args.set_param_file_format("flag_per_line") - args.use_param_file("@%s", use_always = False) - - ctx.actions.run( - mnemonic = "WheelBuilder", - arguments = [args], - inputs = filelist + [ctx.file.setup_py], - outputs = [output_file], - executable = executable, - ) - return [DefaultInfo(files = depset(direct = [output_file]))] - -py_wheel = rule( - implementation = _py_wheel_impl, - attrs = { - "srcs": attr.label_list( - allow_files = True, - ), - "pyproject": attr.label( - allow_single_file = [".toml"], - ), - "setup_py": attr.label( - allow_single_file = [".py"], - mandatory = True, - ), - "platform_name": attr.string(), - "version": attr.string(mandatory = True), - "wheel_binary": attr.label( - default = Label("//tensorflow/lite/tools/pip_package/utils:wheel_builder"), - executable = True, - cfg = "exec", - ), - }, -) diff --git a/tensorflow/lite/tools/pip_package/utils/wheel_builder.py b/tensorflow/lite/tools/pip_package/utils/wheel_builder.py deleted file mode 100644 index 7d749fce0d7180..00000000000000 --- a/tensorflow/lite/tools/pip_package/utils/wheel_builder.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2025 The Tensorflow Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""This script is used to build a python wheel from a list of source files. - -It takes a list of source files, a setup.py file, and a version string as input. -It then uses a python script, wheel_builder.py, to generate the wheel file. The -wheel builder binary is responsible for preparing the build environment and -calling the setuptools command to generate the wheel file. - -args: - --setup_py: Path to the setup.py file. - --output: Output directory for the wheel. - --version: Version of the wheel. - --src: List of source files for the wheel. - --platform: Platform name to be passed to build module. 
- -output: - A python wheel file is created in the output directory. The name of the wheel - file is based on various factors, including the version and platform. -""" - -import argparse -import glob -import os -import shutil -import subprocess -import sys -import tomllib -from typing import Optional - - -def parse_args() -> argparse.Namespace: - """Arguments parser.""" - parser = argparse.ArgumentParser( - description="Helper for building python wheel from pyproject.toml", - fromfile_prefix_chars="@", - ) - parser.add_argument("--pyproject", help="location of pyproject.toml file") - parser.add_argument("--setup_py", help="location of setup.py file") - parser.add_argument("--output", help="output directory") - parser.add_argument("--version", help="version of the wheel") - parser.add_argument( - "--src", help="single source file for the wheel", action="append" - ) - parser.add_argument( - "--platform", - required=True, - help="Platform name to be passed to build module", - ) - return parser.parse_args() - - -def get_project_name(pyproject_path: str) -> str: - with open(pyproject_path, "rb") as f: - pyproject = tomllib.load(f) - try: - return pyproject["project"]["name"] - except KeyError as e: - raise ValueError( - "Invalid pyproject.toml file. Please check the project name." - " Dynamically generated project names are not supported." - ) from e - - -def create_empty_init_files(dst_dir: str) -> None: - """Create __init__.py files.""" - dir_list = [f for f in os.scandir(dst_dir) if f.is_dir()] - for dir_name in dir_list: - with open(os.path.join(dir_name, "__init__.py"), "w"): - pass - - -def create_init_files(dst_dir: str, meta_dict: Optional[dict[str, str]] = None): - create_empty_init_files(dst_dir) - - if meta_dict: - with open(os.path.join(dst_dir, "__init__.py"), "w") as f: - for key, value in meta_dict.items(): - f.write(f"{key} = \"{value}\"\n") - - -def construct_meta_dict(args) -> dict[str, str]: - return { - "__version__": args.version, - } - - -def prepare_build_tree(tree_path, args, project_name: str): - """Prepares the build tree for the wheel build. - - Args: - tree_path: Path to the build tree. - args: Command line arguments. - project_name: Name of the project. - """ - src_dir = os.path.join(tree_path, project_name.replace("-", "_")) - os.makedirs(src_dir) - - shutil.copyfile(args.setup_py, os.path.join(tree_path, "setup.py")) - - for src in args.src: - shutil.copyfile(src, os.path.join(src_dir, os.path.basename(src))) - - meta_dict = construct_meta_dict(args) - - create_init_files(src_dir, meta_dict) - - -def build_pyproject_wheel( - buildtree_path: str, platform_name: Optional[str] = None -): - """Builds a python wheel from a pyproject.toml file. - - Args: - buildtree_path: Path to the build tree. - platform_name: Platform name to be passed to build module. - """ - env = os.environ.copy() - - command = [ - sys.executable, - "-m", - "build", - "-w", - "-o", - os.getcwd(), - ] - - if platform_name: - command.append( - # This is due to setuptools not making it possible to pass the - # platform name as a dynamic pyproject.toml property. - f"--config-setting=--build-option=--plat-name={platform_name}" - ) - - subprocess.run( - command, - check=True, - cwd=buildtree_path, - env=env, - ) - - -def build_setup_py_wheel( - buildtree_path: str, - output_dir: str, - version: str, - platform_name: Optional[str] = None, -): - """Builds a python wheel from a setup.py file. - - Args: - buildtree_path: Path to the build tree. - output_dir: Output directory for the wheel. 
- version: Version of the wheel. - platform_name: Platform name to be passed to build module. - """ - env = os.environ.copy() - - env["PROJECT_NAME"] = "ai_edge_litert" - env["PACKAGE_VERSION"] = version - - command = [ - sys.executable, - f"{buildtree_path}/setup.py", - "bdist_wheel", - f"--plat-name={platform_name}", - ] - - subprocess.run( - command, - check=True, - cwd=buildtree_path, - env=env, - ) - - for filename in glob.glob(os.path.join(buildtree_path, "dist/*.whl")): - shutil.copy(filename, output_dir) - - -if __name__ == "__main__": - build_dir = os.path.join(os.getcwd(), "wheel_build") - arg_data = parse_args() - - prepare_build_tree(build_dir, arg_data, "ai_edge_litert") - build_setup_py_wheel( - build_dir, arg_data.output, arg_data.version, arg_data.platform - ) diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 1857c0f9019ebc..da69497a97f2b8 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -127,8 +127,6 @@ tf_staging/tensorflow/lite/signature_runner.h: tf_staging/tensorflow/lite/special_rules.bzl: tf_staging/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/AndroidManifest.xml: tf_staging/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/build_defs.bzl: -tf_staging/tensorflow/lite/tools/pip_package/BUILD: -tf_staging/tensorflow/lite/tools/pip_package/utils/BUILD: tf_staging/tensorflow/lite/tools/verifier.h: tf_staging/tensorflow/lite/tools/verifier_internal.h: tf_staging/tensorflow/python/autograph/core/config:.py From 21c67be2959c9c0f1cc66617a1179f3642a803b2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 11:27:25 -0700 Subject: [PATCH 0311/1324] Cleanup forwarding headers in tensorflow/core/profiler/protobuf folder PiperOrigin-RevId: 744790303 --- tensorflow/core/BUILD | 2 -- tensorflow/core/profiler/convert/BUILD | 2 -- .../convert/xplane_to_op_metrics_db_test.cc | 1 - .../convert/xplane_to_step_events_test.cc | 1 - .../profiler/convert/xplane_to_step_stats.cc | 1 - tensorflow/core/profiler/protobuf/BUILD | 18 ------------------ tensorflow/core/profiler/protobuf/xplane.proto | 5 ----- 7 files changed, 30 deletions(-) delete mode 100644 tensorflow/core/profiler/protobuf/xplane.proto diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index cda9e305cc4c60..acab945bf8c5ca 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -179,7 +179,6 @@ tf_proto_library( "//tensorflow/core/example:protos_all", "//tensorflow/core/framework:protos_all", "//tensorflow/core/lib/core:error_codes_proto", - "//tensorflow/core/profiler/protobuf:xplane_proto", "//tensorflow/core/profiler:profiler_options_proto", "//tensorflow/core/protobuf:error_codes_proto_impl", "//tensorflow/core/protobuf:for_core_protos", @@ -1482,7 +1481,6 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc_impl", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc_impl", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc_impl", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc_impl", "//tensorflow/core/protobuf:autotuning_proto_cc_impl", "//tensorflow/core/protobuf:conv_autotuning_proto_cc_impl", ":protos_all_cc_impl", diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 4ef2c038ee8511..978b075dea7733 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ 
b/tensorflow/core/profiler/convert/BUILD @@ -51,7 +51,6 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:op_metrics_db_utils", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", @@ -533,7 +532,6 @@ tf_cc_test( ":xplane_to_step_events", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:event_span", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc index a93cb82e98ed7b..ee05ed31341025 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc index d02c231659e353..5c97b045622a56 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "xla/tsl/profiler/utils/group_events.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/event_span.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" diff --git a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc index dbe625dde2091f..8645876955bb1f 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc @@ -30,7 +30,6 @@ limitations under the License. 
#include "xla/tsl/profiler/utils/math_utils.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/framework/step_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/gpu_event_stats.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD index 6d949f5b67d504..fcdfed4c4711f9 100644 --- a/tensorflow/core/profiler/protobuf/BUILD +++ b/tensorflow/core/profiler/protobuf/BUILD @@ -26,19 +26,6 @@ package_group( ], ) -tf_proto_library( - name = "xplane_proto", - srcs = ["xplane.proto"], - make_default_target_header_only = True, - protodeps = [ - "@local_tsl//tsl/profiler/protobuf:xplane_proto", - ], - visibility = [":friends"], - exports = [ - "@local_tsl//tsl/profiler/protobuf:xplane_proto", - ], -) - # This is needed because of how tf_android_core_proto_sources parses proto paths. exports_files( srcs = ["xplane.proto"], @@ -280,11 +267,6 @@ tf_proto_library( ) # copybara:uncomment_begin(google-only) -# py_proto_library( -# name = "xplane_py_pb2", -# visibility = [":friends"], -# deps = [":xplane_proto"], -# ) # # py_proto_library( # name = "memory_viewer_preprocess_py_pb2", diff --git a/tensorflow/core/profiler/protobuf/xplane.proto b/tensorflow/core/profiler/protobuf/xplane.proto deleted file mode 100644 index 69655b76d3e189..00000000000000 --- a/tensorflow/core/profiler/protobuf/xplane.proto +++ /dev/null @@ -1,5 +0,0 @@ -syntax = "proto3"; - -package tensorflow.profiler.empty; - -import public "tsl/profiler/protobuf/xplane.proto"; From ac8c58bd168fab8ac8b4070af52e5e378d744d14 Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Mon, 7 Apr 2025 11:38:48 -0700 Subject: [PATCH 0312/1324] Integrate latest pywrap rules (mac fixes and support for pybind_extension submodules, including recursive ones). Also enable pywrap rules on MacOS PiperOrigin-RevId: 744794481 --- ci/official/envs/macos_arm64 | 2 +- tensorflow/python/BUILD | 58 ++++--- tensorflow/python/_pywrap_tensorflow.def | 7 +- tensorflow/tf_exported_symbols.lds | 7 + .../py/rules_pywrap/pybind_extension.py.tpl | 23 +-- .../py/rules_pywrap/pywrap.impl.bzl | 149 +++++++++++++----- 6 files changed, 167 insertions(+), 79 deletions(-) diff --git a/ci/official/envs/macos_arm64 b/ci/official/envs/macos_arm64 index c789a2dc2d0990..b92f49590f033a 100644 --- a/ci/official/envs/macos_arm64 +++ b/ci/official/envs/macos_arm64 @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config release_macos_arm64" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --repo_env=USE_PYWRAP_RULES=True --config release_macos_arm64" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 TFCI_BUILD_PIP_PACKAGE_WHEEL_NAME_ARG="--repo_env=WHEEL_NAME=tensorflow" TFCI_INDEX_HTML_ENABLE=1 diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 87b00da9dc7503..fa305cd653fc41 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1490,29 +1490,6 @@ cc_library( alwayslink = True, ) -cc_library( - name = "_pywrap_lib_filter", - deps = if_pywrap( - if_true = [ - "@pybind11_abseil//pybind11_abseil:absl_casters", - "@pybind11_abseil//pybind11_abseil:import_status_module", - "@pybind11_abseil//pybind11_abseil:status_casters", - "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ], - ), -) - -cc_library( - name = "_pywrap_lib_exclusion_filter", - deps = if_pywrap( - if_true = [ - "@com_google_protobuf//:protobuf", - "@com_google_protobuf//:protobuf_lite", - "@zlib//:zlib", - ], - ), -) - cc_library( name = "_cuda_deps_filter", srcs = [], @@ -1531,7 +1508,7 @@ pywrap_library( # buildifier: disable=unsorted-dict-items # @unsorted-dict-items common_lib_filters = { - "tensorflow/libtensorflow_framework.so.2": select({ + "tensorflow/tensorflow_framework": select({ "//tensorflow:windows": [], "//conditions:default": [ "//tensorflow:tensorflow_framework_pywrap_filter", @@ -1542,7 +1519,7 @@ pywrap_library( "@com_google_absl//absl/random", ], }), - "tensorflow/libtensorflow_cc.so.2": select({ + "tensorflow/tensorflow_cc": select({ "//tensorflow:windows": [], "//conditions:default": ["//tensorflow:tensorflow_cc_pywrap_filter"], }), @@ -1550,8 +1527,13 @@ pywrap_library( # buildifier: disable=unsorted-dict-items # @unsorted-dict-items common_lib_linkopts = { - "tensorflow/libtensorflow_framework.so.2": select({ + "tensorflow/tensorflow_framework": select({ "//tensorflow:windows": [], + "//tensorflow:macos": [ + "-lpthread", + "-ldl", + "-lm", + ], "//conditions:default": [ "-z defs", "-lpthread", @@ -1559,7 +1541,7 @@ pywrap_library( "-lm", ], }), - "tensorflow/libtensorflow_cc.so.2": select({ + "tensorflow/tensorflow_cc": select({ "//tensorflow:windows": [ "-DEFAULTLIB:ws2_32.lib", "-DEFAULTLIB:advapi32.lib", @@ -1567,6 +1549,11 @@ pywrap_library( "-DEFAULTLIB:Normaliz.lib", "-DEFAULTLIB:ntdll.lib", ], + "//tensorflow:macos": [ + "-lpthread", + "-ldl", + "-lm", + ], "//conditions:default": [ "-z defs", "-lpthread", @@ -1578,14 +1565,23 @@ pywrap_library( # buildifier: disable=unsorted-dict-items # @unsorted-dict-items common_lib_version_scripts = { - "tensorflow/libtensorflow_cc.so.2": "//tensorflow:tf_version_script.lds", + "tensorflow/tensorflow_cc": select({ + "//tensorflow:windows": None, + "//tensorflow:macos": "//tensorflow:tf_exported_symbols.lds", + "//conditions:default": "//tensorflow:tf_version_script.lds", + }), + }, + # buildifier: disable=unsorted-dict-items + # @unsorted-dict-items + common_lib_versions = { + "tensorflow/tensorflow_framework": "2", + "tensorflow/tensorflow_cc": "2", }, enable_common_lib_starlark_only_filter = select({ "//tensorflow:windows": False, + "//tensorflow:macos": False, "//conditions:default": True, }), - pywrap_lib_exclusion_filter = ":_pywrap_lib_exclusion_filter", - pywrap_lib_filter = ":_pywrap_lib_filter", starlark_only_deps = 
[ "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", "//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization", @@ -1676,7 +1672,7 @@ pywrap_library( pywrap_common_library( name = "pywrap_tensorflow_framework", dep = ":_pywrap_tensorflow", - filter_name = "libtensorflow_framework.so.2", + filter_name = "tensorflow_framework", ) pywrap_common_library( diff --git a/tensorflow/python/_pywrap_tensorflow.def b/tensorflow/python/_pywrap_tensorflow.def index 26e29918b1f848..70a2818ba4835e 100644 --- a/tensorflow/python/_pywrap_tensorflow.def +++ b/tensorflow/python/_pywrap_tensorflow.def @@ -1239,4 +1239,9 @@ EXPORTS ?QuantizeModel@CalibrationWrapper@calibration_wrapper@tflite@@QEAAPEAU_object@@HH_NHH00@Z ?CreateTypeInfoAndReturnTypeIdImpl@AsyncValue@tsl@@CAGAEBUTypeInfo@12@@Z ?pywrap_library_dependency_symbol@python@tensorflow@@YAXXZ - ?NotifyAvailable@AsyncValue@tsl@@IEAAXVState@12@@Z \ No newline at end of file + ?NotifyAvailable@AsyncValue@tsl@@IEAAXVState@12@@Z + ??1CondVar@lts_20230802@absl@@QEAA@XZ + ?Lock@Mutex@lts_20230802@absl@@QEAAXXZ + ?Signal@CondVar@lts_20230802@absl@@QEAAXXZ + ?Unlock@Mutex@lts_20230802@absl@@QEAAXXZ + ?Wait@CondVar@lts_20230802@absl@@QEAAXPEAVMutex@23@@Z diff --git a/tensorflow/tf_exported_symbols.lds b/tensorflow/tf_exported_symbols.lds index 37c9e6445e4f29..d27dd48a697971 100644 --- a/tensorflow/tf_exported_symbols.lds +++ b/tensorflow/tf_exported_symbols.lds @@ -14,4 +14,11 @@ *tsl* *lite* *TFL* +*TfLite* *quantization* +*mlir*detail* +*mlir*func* +*mlir*TF* +*mlir*shape* +*mlir*scf* +*mlir*quant* diff --git a/third_party/xla/third_party/py/rules_pywrap/pybind_extension.py.tpl b/third_party/xla/third_party/py/rules_pywrap/pybind_extension.py.tpl index fb225d16aa0d8e..6554f6ce915854 100644 --- a/third_party/xla/third_party/py/rules_pywrap/pybind_extension.py.tpl +++ b/third_party/xla/third_party/py/rules_pywrap/pybind_extension.py.tpl @@ -1,12 +1,16 @@ -def __update_globals(pywrap_m): - if hasattr(pywrap_m, '__all__'): - all_names = pywrap_m.__all__ - else: - all_names = [name for name in dir(pywrap_m) if not name.startswith('_')] +from sys import modules +from types import ModuleType - extra_names = [] # template_val - all_names.extend(extra_names) - globals().update({name: getattr(pywrap_m, name) for name in all_names}) + +def __update_globals(new_import_path, pywrap_m): + all_names = pywrap_m.__all__ if hasattr(pywrap_m, '__all__') else dir( + pywrap_m) + modules[new_import_path] = pywrap_m + for name in all_names: + sub_pywrap = getattr(pywrap_m, name) + if isinstance(sub_pywrap, ModuleType): + sub_name = sub_pywrap.__name__[len(pywrap_m.__name__):] + __update_globals(new_import_path + sub_name, sub_pywrap) def __try_import(): @@ -16,7 +20,7 @@ def __try_import(): for import_path in imports_paths: try: pywrap_m = __import__(import_path, fromlist=["*"]) - __update_globals(pywrap_m) + __update_globals(__name__, pywrap_m) return except ImportError as e: exceptions.append(str(e)) @@ -27,4 +31,5 @@ def __try_import(): Could not import original test/binary location, import paths tried: {imports_paths}. 
Previous exceptions: {exceptions}""", last_exception) + __try_import() diff --git a/third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl b/third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl index d1efec9d2df8d2..222b7d47c5283d 100644 --- a/third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl +++ b/third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl @@ -39,6 +39,7 @@ def pywrap_library( pywrap_lib_filter = None, pywrap_lib_exclusion_filter = None, common_lib_filters = {}, + common_lib_versions = {}, common_lib_version_scripts = {}, common_lib_def_files_or_filters = {}, common_lib_linkopts = {}, @@ -138,6 +139,7 @@ def pywrap_library( linkopts = common_lib_linkopts.get(common_lib_full_name, []) ver_script = common_lib_version_scripts.get(common_lib_full_name, None) common_cc_binary_name = "%s" % common_lib_name + common_import_name, win_import_library_name = _construct_common_binary( common_cc_binary_name, common_deps + [":%s" % common_split_name], @@ -149,7 +151,8 @@ def pywrap_library( binaries_data.values(), common_lib_pkg, ver_script, - data = [":%s" % common_split_name], + [":%s" % common_split_name], + common_lib_versions.get(common_lib_full_name, ""), ) actual_binaries_data = binaries_data actual_common_deps = common_deps @@ -274,20 +277,55 @@ def _construct_common_binary( dependency_common_lib_packages, dependent_common_lib_package, version_script, - data): - actual_linkopts = _construct_linkopt_soname(name) + _construct_linkopt_rpaths( + data, + version = ""): + version_str = ".{}".format(version) if version else version + linux_binary_name = "lib{}.so{}".format(name, version_str) + win_binary_name = "{}{}.dll".format(name, version_str) + darwin_binary_name = "lib{}{}.dylib".format(name, version_str) + + actual_version_script = None + if version_script: + actual_version_script = "{}_version_script".format(name) + native.alias( + name = actual_version_script, + actual = version_script, + ) + actual_version_script = ":{}".format(actual_version_script) + + linux_linkopts = _construct_linkopt_soname( + linux_binary_name, + False, + ) + _construct_linkopt_rpaths( dependency_common_lib_packages, dependent_common_lib_package, - ) + _construct_linkopt_version_script(version_script) + False, + ) + _construct_linkopt_version_script(actual_version_script, False) native.cc_binary( - name = name, - deps = deps + ([version_script] if version_script else []), + name = linux_binary_name, + deps = deps + ([actual_version_script] if actual_version_script else []), linkstatic = True, linkshared = True, linkopts = linkopts + select({ "@bazel_tools//src/conditions:windows": [], - "//conditions:default": actual_linkopts, + "@bazel_tools//src/conditions:darwin": [], + "//conditions:default": linux_linkopts, + }), + testonly = testonly, + compatible_with = compatible_with, + local_defines = local_defines, + ) + + native.cc_binary( + name = win_binary_name, + deps = deps, + linkstatic = True, + linkshared = True, + linkopts = linkopts + select({ + "@bazel_tools//src/conditions:windows": [], + "@bazel_tools//src/conditions:darwin": [], + "//conditions:default": [], }), testonly = testonly, compatible_with = compatible_with, @@ -295,19 +333,53 @@ def _construct_common_binary( local_defines = local_defines, ) - if_lib_name = "%s_if_lib" % name + darwin_linkopts = _construct_linkopt_soname( + darwin_binary_name, + True, + ) + _construct_linkopt_rpaths( + dependency_common_lib_packages, + dependent_common_lib_package, + True, + ) + 
_construct_linkopt_version_script(actual_version_script, True) + + native.cc_binary( + name = darwin_binary_name, + deps = deps + ([actual_version_script] if actual_version_script else []), + linkstatic = True, + linkshared = True, + linkopts = linkopts + select({ + "@bazel_tools//src/conditions:windows": [], + "@bazel_tools//src/conditions:darwin": darwin_linkopts, + "//conditions:default": [], + }), + testonly = testonly, + compatible_with = compatible_with, + local_defines = local_defines, + ) + + if_lib_name = "{}{}_if_lib".format(name, version_str) native.filegroup( name = if_lib_name, - srcs = [":%s" % name], + srcs = [":%s" % win_binary_name], output_group = "interface_library", testonly = testonly, compatible_with = compatible_with, ) + native.alias( + name = name, + actual = select({ + "@bazel_tools//src/conditions:windows": ":%s" % win_binary_name, + "@bazel_tools//src/conditions:darwin": ":%s" % darwin_binary_name, + "//conditions:default": ":%s" % linux_binary_name, + }), + ) + import_name = "%s_import" % name + native.cc_import( name = import_name, - shared_library = ":%s" % name, + shared_library = "%s" % name, interface_library = select({ "@bazel_tools//src/conditions:windows": ":%s" % if_lib_name, "//conditions:default": None, @@ -399,7 +471,7 @@ def _pywrap_common_split_library_impl(ctx): else: libs_to_include = filters.common_lib_filters[ctx.attr.common_lib_full_name] - user_link_flags = {} + user_link_flags = [] dynamic_lib_filter = filters.dynamic_lib_filter default_runfiles = ctx.runfiles() for pw in pywrap_infos: @@ -410,8 +482,7 @@ def _pywrap_common_split_library_impl(ctx): continue if include_all_not_excluded or (li in libs_to_include) or li in dynamic_lib_filter: split_linker_inputs.append(li) - for user_link_flag in li.user_link_flags: - user_link_flags[user_link_flag] = True + user_link_flags.extend(li.user_link_flags) if not pw_runfiles_merged: default_runfiles = default_runfiles.merge(pw.default_runfiles) pw_runfiles_merged = True @@ -419,7 +490,7 @@ def _pywrap_common_split_library_impl(ctx): return _construct_split_library_cc_info( ctx, split_linker_inputs, - list(user_link_flags.keys()), + user_link_flags, [], default_runfiles, ctx.attr.collect_objects, @@ -463,7 +534,7 @@ def _construct_split_library_cc_info( linker_input = cc_common.create_linker_input( owner = ctx.label, libraries = depset(direct = dependency_libraries), - user_link_flags = depset(direct = user_link_flags), + user_link_flags = user_link_flags, ) linking_context = cc_common.create_linking_context( @@ -680,6 +751,8 @@ def pybind_extension( linkopts = [], starlark_only = False, **kwargs): + # For backward compatibility that I don't want to mess with + _ignore = [additional_exported_symbols] cc_library_name = "_%s_cc_library" % name native.cc_library( name = cc_library_name, @@ -693,9 +766,15 @@ def pybind_extension( local_defines = ["PROTOBUF_USE_DLLS", "ABSL_CONSUME_DLL"], linkopts = linkopts + select({ "@bazel_tools//src/conditions:windows": [], + "@bazel_tools//src/conditions:darwin": _construct_linkopt_rpaths( + common_lib_packages + [native.package_name()], + native.package_name(), + True, + ), "//conditions:default": _construct_linkopt_rpaths( common_lib_packages + [native.package_name()], native.package_name(), + False, ), }), **kwargs @@ -715,7 +794,6 @@ def pybind_extension( name = name, deps = ["%s" % cc_library_name], common_lib_packages = common_lib_packages, - additional_exported_symbols = additional_exported_symbols, starlark_only = starlark_only, testonly = testonly, 
compatible_with = compatible_with, @@ -729,9 +807,6 @@ def _pywrap_info_wrapper_impl(ctx): py_stub = ctx.actions.declare_file("%s.py" % ctx.attr.name) substitutions = {} - - additional_exported_symbols = ctx.attr.additional_exported_symbols - py_pkgs = [] for pkg in ctx.attr.common_lib_packages: if pkg: @@ -741,10 +816,6 @@ def _pywrap_info_wrapper_impl(ctx): val = "imports_paths = %s # template_val" % py_pkgs substitutions["imports_paths = [] # template_val"] = val - if additional_exported_symbols: - val = "extra_names = %s # template_val" % additional_exported_symbols - substitutions["extra_names = [] # template_val"] = val - ctx.actions.expand_template( template = ctx.file.py_stub_src, output = py_stub, @@ -776,10 +847,6 @@ _pywrap_info_wrapper = rule( allow_single_file = True, default = Label("//third_party/py/rules_pywrap:pybind_extension.py.tpl"), ), - "additional_exported_symbols": attr.string_list( - mandatory = False, - default = [], - ), "starlark_only": attr.bool(mandatory = False, default = False), }, implementation = _pywrap_info_wrapper_impl, @@ -1081,18 +1148,25 @@ def _construct_inverse_common_lib_filters(common_lib_filters): inverse_common_lib_filters[new_common_lib_k] = common_lib_k return inverse_common_lib_filters -def _construct_linkopt_soname(name): +def _construct_linkopt_soname(name, darwin): soname = name.rsplit("/", 1)[1] if "/" in name else name - soname = soname if name.startswith("lib") else ("lib%s" % soname) - if ".so" not in name: - soname += ".so" - return ["-Wl,-soname,%s" % soname] - -def _construct_linkopt_rpaths(dependency_lib_packages, dependent_lib_package): + soname = soname if name.startswith("lib") else ("lib{}".format(soname)) + extension = ".so" + arg_name = "-soname" + if darwin: + extension = ".dylib" + arg_name = "-install_name" + soname = "@rpath/" + soname + if extension not in name: + soname += extension + return ["-Wl,{},{}".format(arg_name, soname)] + +def _construct_linkopt_rpaths(dependency_lib_packages, dependent_lib_package, darwin): linkopts = {} + origin = "@loader_path" if darwin else "$$ORIGIN" for dependency_lib_package in dependency_lib_packages: origin_pkg = _construct_rpath(dependency_lib_package, dependent_lib_package) - linkopts["-rpath,'$$ORIGIN/%s'" % origin_pkg] = True + linkopts["-rpath,'{}/{}'".format(origin, origin_pkg)] = True return ["-Wl," + ",".join(linkopts.keys())] if linkopts else [] def _construct_rpath(dependency_lib_package, dependent_lib_package): @@ -1111,10 +1185,11 @@ def _construct_rpath(dependency_lib_package, dependent_lib_package): return levels_up + remaining_pkg -def _construct_linkopt_version_script(version_script): +def _construct_linkopt_version_script(version_script, darwin): if not version_script: return [] - return ["-Wl,--version-script,$(location {})".format(version_script)] + arg_name = "-exported_symbols_list" if darwin else "--version-script" + return ["-Wl,{},$(location {})".format(arg_name, version_script)] def _generated_common_win_def_file_impl(ctx): win_raw_def_file_name = "%s.gen.def" % ctx.attr.name From 922c204928f0eaf3bc009141909cbb75ce78fe1b Mon Sep 17 00:00:00 2001 From: pizzud Date: Mon, 7 Apr 2025 11:58:49 -0700 Subject: [PATCH 0313/1324] hlo_hardware_independent_test_base: Mutex-guard default_device_assignment_. 
//xla/service/gpu:gpu_compiler_test contains a test case checking that parallel
compilation doesn't hit race conditions, and it occasionally fails due to a
read-write race when multiple threads call ParseAndReturnVerifiedModule at once
and wind up resetting the default DeviceAssignment pointer. TSAN of course
catches this immediately.

PiperOrigin-RevId: 744801747
---
 third_party/xla/xla/hlo/testlib/BUILD              |  1 +
 .../testlib/hlo_hardware_independent_test_base.cc  |  2 ++
 .../testlib/hlo_hardware_independent_test_base.h   | 13 +++++++++++--
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/hlo/testlib/BUILD b/third_party/xla/xla/hlo/testlib/BUILD
index e3556d7fe1cafd..549a30afa218f0 100644
--- a/third_party/xla/xla/hlo/testlib/BUILD
+++ b/third_party/xla/xla/hlo/testlib/BUILD
@@ -86,6 +86,7 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
     ],
 )
diff --git a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc
index aed11655a5e01b..134a41678351c9 100644
--- a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc
+++ b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_replace.h"
 #include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
 #include "xla/debug_options_flags.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -145,6 +146,7 @@ HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
     std::function shape_size_fn) const {
   HloModuleConfig config_with_device_assignment = config;
   if (!config.has_static_device_assignment()) {
+    absl::MutexLock ml(&device_assignment_mu_);
     default_device_assignment_ =
         std::make_unique<DeviceAssignment>(GetDefaultDeviceAssignment(
             config.replica_count(), config.num_partitions()));
diff --git a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h
index 4f101d43b30887..d92eb0ab2da75b 100644
--- a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h
+++ b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h
@@ -26,10 +26,12 @@ limitations under the License.
 #include

 #include "absl/base/attributes.h"
+#include "absl/base/thread_annotations.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
@@ -208,7 +210,10 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
   // options (e.g. disabling additional passes).
   virtual DebugOptions GetDebugOptionsForTest() const;

-  void TearDown() override { default_device_assignment_.reset(); }
+  void TearDown() override {
+    absl::MutexLock ml(&device_assignment_mu_);
+    default_device_assignment_.reset();
+  }

   // Gets an HloModuleConfig with options appropriate for tests.
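  // When the caller does not supply a static device assignment, a default one
  // is created and cached in default_device_assignment_; both that creation
  // and the reset in TearDown() take device_assignment_mu_ first.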
HloModuleConfig GetModuleConfigForTest( int64_t replica_count = 1, int64_t num_partitions = 1, @@ -220,6 +225,7 @@ class HloHardwareIndependentTestBase : public ::testing::Test { if (device_assignment.has_value()) { config.set_static_device_assignment(*device_assignment); } else { + absl::MutexLock ml(&device_assignment_mu_); default_device_assignment_ = std::make_unique( GetDefaultDeviceAssignment(replica_count, num_partitions)); config.set_static_device_assignment(*default_device_assignment_); @@ -305,7 +311,10 @@ class HloHardwareIndependentTestBase : public ::testing::Test { bool allow_mixed_precision_in_hlo_verifier_; HloPredicate instruction_can_change_layout_func_; std::unique_ptr hlo_verifier_; - mutable std::unique_ptr default_device_assignment_; + mutable absl::Mutex device_assignment_mu_; + mutable std::unique_ptr default_device_assignment_ + ABSL_GUARDED_BY(device_assignment_mu_) + ABSL_PT_GUARDED_BY(device_assignment_mu_); }; } // namespace xla From 8887605a04cf792591941cc2e00d48f9018b2ac8 Mon Sep 17 00:00:00 2001 From: pizzud Date: Mon, 7 Apr 2025 12:30:14 -0700 Subject: [PATCH 0314/1324] gpu_compiler_test: Ensure HLOs that would use SortRewriter compile deviceless. Reaching SortRewriter deviceless could be worth triggering a compilation failure in debug mode, but adding the necessary #ifdef NDEBUG chains feels icky and against XLA's general policy of avoiding #ifdefs, especially when the test also needs them. PiperOrigin-RevId: 744812796 --- third_party/xla/xla/service/gpu/BUILD | 4 ++ .../xla/xla/service/gpu/gpu_compiler_test.cc | 52 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index fee0cc5d0f6fda..1884521528cce5 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1685,9 +1685,13 @@ xla_test( "//xla/tsl/lib/monitoring:collected_metrics", "//xla/tsl/lib/monitoring:collection_registry", "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", + "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_sink", + "@com_google_absl//absl/log:scoped_mock_log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 06af1b5b077283..feaba89cb931b6 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -21,15 +21,19 @@ limitations under the License. #include #include #include +#include #include #include #include #include #include +#include "absl/base/log_severity.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/log/log_sink.h" +#include "absl/log/scoped_mock_log.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -66,6 +70,7 @@ limitations under the License. 
#include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/threadpool.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" @@ -83,10 +88,12 @@ namespace { namespace m = ::xla::match; +using ::testing::EndsWith; using ::testing::IsEmpty; using ::testing::IsSupersetOf; using ::testing::Matches; using ::testing::Not; +using ::testing::StartsWith; using ::testing::TempDir; class GpuCompilerTest : public HloTestBase { @@ -1925,6 +1932,51 @@ TEST_F(GpuCompilerTest, DynamicSliceFusionReduceScatterMultipleBuffers) { ::tsl::testing::IsOkAndHolds(true)); } +TEST_F(GpuCompilerTest, CompilingSortsWorksWithoutDevice) { + constexpr absl::string_view kHlo = R"( +HloModule TestModule + +%compare { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT +} + +ENTRY %main { + %input = f32[1000] parameter(0) + ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare +})"; + + HloModuleConfig config; + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_cub_radix_sort(true); + + std::string target_file; + ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&target_file)); + TF_ASSERT_OK(tsl::WriteTextProto( + tsl::Env::Default(), target_file, + Compiler::TargetConfig(backend().default_stream_executor()).ToProto())); + debug_options.set_xla_gpu_target_config_filename(target_file); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHlo, config)); + // absl::ScopedMockLog only works if we're actually using ABSL logging, and + // TSL supports a homegrown logging implementation, so we should only check + // the log is emitted when ABSL logging is used. + absl::ScopedMockLog mock_log(absl::MockLogDefault::kIgnoreUnexpected); + if constexpr (std::is_same_v) { + EXPECT_CALL(mock_log, + Log(absl::LogSeverity::kWarning, EndsWith("/gpu_compiler.cc"), + StartsWith("Using fallback sort algorithm"))); + } + // StartCapturingLogs has to be called even if we expect not to capture any + // logs. + mock_log.StartCapturingLogs(); + TF_ASSERT_OK(backend().compiler()->RunHloPasses(std::move(module), nullptr, + GetAllocator())); +} + } // namespace } // namespace gpu } // namespace xla From e2adf40e07ec58d5c5e9aea90a11a3fc9d54e346 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 7 Apr 2025 12:45:39 -0700 Subject: [PATCH 0315/1324] Cache the hash of `xla::ifrt::HloSharding` `xla::HloSharding`'s hash function isn't cheap because its current implementation unrolls iota tile assignment into a regular tile assignment. Since sharding objects are immutable, it is safe to cache its hash value. The cached value check uses the same pattern as `BasicDeviceList`. 
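In sketch form (this condenses the change in the diff below; all names are
from the actual code):

  uint64_t hash = hash_.load(std::memory_order_relaxed);
  if (hash == kUnsetHash) {  // kUnsetHash == 0 means "not computed yet".
    hash = absl::HashOf(devices_, memory_kind_, xla_hlo_sharding_);
    if (hash == kUnsetHash) ++hash;  // Keep 0 reserved for "unset".
    hash_.store(hash, std::memory_order_relaxed);
  }

Relaxed ordering is enough here: a racing thread can only recompute and store
the same value, since the hashed state is immutable.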
PiperOrigin-RevId: 744818035 --- third_party/xla/xla/python/pjrt_ifrt/BUILD | 4 ++-- .../xla/xla/python/pjrt_ifrt/xla_sharding.cc | 14 +++++++++++--- .../xla/xla/python/pjrt_ifrt/xla_sharding.h | 8 +++++++- .../xla/python/pjrt_ifrt/xla_sharding_test.cc | 18 +++++++++++++++++- 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 18f9bef9b329f5..ad87618f331ddf 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -41,8 +41,8 @@ cc_library( "//xla/pjrt:pjrt_executable", "//xla/python/ifrt", "//xla/python/ifrt:serdes", - "//xla/tsl/concurrency:ref_count", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -168,10 +168,10 @@ xla_cc_test( "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt:device_test_util", "//xla/python/ifrt:tuple_impl_test_lib", - "//xla/tsl/concurrency:ref_count", "//xla/tsl/platform:errors", "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc index 0dc54be5e1e429..565349f745fe30 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/xla_sharding.h" #include +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -40,7 +42,6 @@ limitations under the License. #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/shape_util.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -426,8 +427,15 @@ std::string HloSharding::DebugString() const { } void HloSharding::Hash(absl::HashState state) const { - absl::HashState::combine(std::move(state), devices_, memory_kind_, - xla_hlo_sharding_); + uint64_t hash = hash_.load(std::memory_order_relaxed); + if (hash == kUnsetHash) { + hash = absl::HashOf(devices_, memory_kind_, xla_hlo_sharding_); + if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) { + ++hash; + } + hash_.store(hash, std::memory_order_relaxed); + } + absl::HashState::combine(std::move(state), hash); } std::vector TEST_HloShardingIndexDomainsSlowPath( diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.h b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.h index 46d799bfaa9a70..cef55eb90148f6 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.h +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_PYTHON_PJRT_IFRT_XLA_SHARDING_H_ #define XLA_PYTHON_PJRT_IFRT_XLA_SHARDING_H_ +#include +#include #include #include #include @@ -31,7 +33,6 @@ limitations under the License. 
#include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/tsl/concurrency/ref_count.h" namespace xla { namespace ifrt { @@ -106,6 +107,11 @@ class HloSharding final void Hash(absl::HashState state) const override; xla::HloSharding xla_hlo_sharding_; + + // Cached hash. 0 indicates the hash needs to be computed and cached. + // May be written multiple times with the same non-zero value. + static constexpr uint64_t kUnsetHash = 0; + mutable std::atomic hash_ = kUnsetHash; }; // Test only: returns `HloSharding::IndexDomains()`, using `xla::HloSharding` diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc index 8b6a522c6a7321..a7285df4afdca2 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/hash/hash_testing.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" @@ -32,7 +33,6 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/platform/statusor.h" @@ -923,6 +923,22 @@ TEST_P(HloShardingTest, DisassembleFailsWithDynamicShape) { HasSubstr("can only disassemble static shape"))); } +TEST_P(HloShardingTest, Hash) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ + HloSharding::Create(GetDevices({0, 1, 2, 3, 4, 5}), MemoryKind(), + xla::HloSharding::Replicate()), + HloSharding::Create(GetDevices({0}), MemoryKind(), + xla::HloSharding::Replicate()), + HloSharding::Create(GetDevices({0}), MemoryKind("pinned_host"), + xla::HloSharding::Replicate()), + HloSharding::Create(GetDevices({0, 1, 2, 3, 4, 5}), MemoryKind(), + xla::HloSharding::AssignDevice(/*device_id=*/0)), + HloSharding::Create(GetDevices({0, 1, 2, 3, 4, 5}), MemoryKind(), + xla::HloSharding::PartialTile(xla::TileAssignment( + xla::IotaTileAssignment::Create({2, 3})))), + })); +} + INSTANTIATE_TEST_SUITE_P(NumDevices, HloShardingTest, testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, From c0902f1fde3e05004c0a35c931d23aec4a4101d6 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 7 Apr 2025 13:22:17 -0700 Subject: [PATCH 0316/1324] [XLA:GPU] Enable `TritonGemmContractionDims` tests using the generic Triton emitter. Make sure to disable the autotuner in the test, since it isn't necessary. Also add a `TODO` to move these tests to a deviceless test file. PiperOrigin-RevId: 744829684 --- .../fusion_emitter_device_legacy_port_test.cc | 62 ++++++++++--------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index c07585d8b6e052..52f5a728a69892 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -3827,16 +3827,20 @@ ENTRY e { /*run_hlo_passes=*/false)); } +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. 
We should move all the +// `TritonGemmContractionDims` tests. class TritonGemmContractionDims : public TritonGemmTest { public: DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_ensure_minor_dot_contraction_dims(true); + debug_options.set_xla_gpu_autotune_level(0); return debug_options; } }; -TEST_F(TritonGemmContractionDims, DISABLED_TritonDotForceContractionDims_1_0) { +TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } @@ -3844,9 +3848,9 @@ TEST_F(TritonGemmContractionDims, DISABLED_TritonDotForceContractionDims_1_0) { HloModule m ENTRY e { - parameter.0 = bf16[16,40]{1,0} parameter(0) - parameter.1 = bf16[40,32]{1,0} parameter(1) - ROOT dot.31472 = bf16[16,32]{1,0} dot(parameter.0, parameter.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + p0 = bf16[16,40]{1,0} parameter(0) + p1 = bf16[40,32]{1,0} parameter(1) + ROOT dot = bf16[16,32]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -3856,13 +3860,12 @@ ENTRY e { ->root_instruction() ->fused_instructions_computation() ->root_instruction(), - GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 40}, {1, 0}), - m::Op().WithShape(BF16, {40, 32}, {0, 1})) + GmockMatch(m::Dot(m::Fusion().WithShape(BF16, {16, 40}, {1, 0}), + m::Fusion().WithShape(BF16, {40, 32}, {0, 1})) .WithShape(BF16, {16, 32}, {1, 0}))); } -TEST_F(TritonGemmContractionDims, - DISABLED_TritonDotForceContractionDims_1_2_1_2) { +TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } @@ -3870,9 +3873,9 @@ TEST_F(TritonGemmContractionDims, HloModule m ENTRY e { - parameter_0 = bf16[32,4,36]{2,1,0} parameter(0) - parameter_1 = bf16[40,4,36]{2,1,0} parameter(1) - ROOT dot.16450 = bf16[4,32,40]{2,1,0} dot(parameter_0, parameter_1), + p0 = bf16[32,4,36]{2,1,0} parameter(0) + p1 = bf16[40,4,36]{2,1,0} parameter(1) + ROOT dot = bf16[4,32,40]{2,1,0} dot(p0, p1), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2} })"; @@ -3882,17 +3885,17 @@ ENTRY e { // The contracting dims were already minor, so the layout is unchanged // (non-major batch dims are fine). 
- EXPECT_THAT(module->entry_computation() - ->root_instruction() - ->fused_instructions_computation() - ->root_instruction(), - GmockMatch(m::Dot(m::Op().WithShape(BF16, {32, 4, 36}, {2, 1, 0}), - m::Op().WithShape(BF16, {40, 4, 36}, {2, 1, 0})) - .WithShape(BF16, {4, 32, 40}, {2, 1, 0}))); + EXPECT_THAT( + module->entry_computation() + ->root_instruction() + ->fused_instructions_computation() + ->root_instruction(), + GmockMatch(m::Dot(m::Fusion().WithShape(BF16, {32, 4, 36}, {2, 1, 0}), + m::Fusion().WithShape(BF16, {40, 4, 36}, {2, 1, 0})) + .WithShape(BF16, {4, 32, 40}, {2, 1, 0}))); } -TEST_F(TritonGemmContractionDims, - DISABLED_TritonDotForceContractionDims_1_2_0_1) { +TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } @@ -3906,7 +3909,6 @@ ENTRY e { lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); @@ -3917,12 +3919,12 @@ ENTRY e { ->root_instruction() ->fused_instructions_computation() ->root_instruction(), - GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 16, 48}, {2, 1, 0}), - m::Op().WithShape(BF16, {16, 48, 32}, {1, 2, 0})) + GmockMatch(m::Dot(m::Fusion().WithShape(BF16, {16, 16, 48}, {2, 1, 0}), + m::Fusion().WithShape(BF16, {16, 48, 32}, {1, 2, 0})) .WithShape(BF16, {16, 16, 32}, {2, 1, 0}))); } -TEST_F(TritonGemmContractionDims, DISABLED_TritonDotForceContractionDims_1_1) { +TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } @@ -3930,19 +3932,19 @@ TEST_F(TritonGemmContractionDims, DISABLED_TritonDotForceContractionDims_1_1) { HloModule m ENTRY e { - parameter_0 = bf16[16,32]{1,0} parameter(0) - parameter_1 = bf16[40,32]{0,1} parameter(1) - ROOT dot.15148 = bf16[16,40]{1,0} dot(parameter_0, parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + p0 = bf16[16,32]{1,0} parameter(0) + p1 = bf16[40,32]{0,1} parameter(1) + ROOT dot = bf16[16,40]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); EXPECT_THAT(module->entry_computation() ->root_instruction() ->fused_instructions_computation() ->root_instruction(), - GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 32}, {1, 0}), - m::Op().WithShape(BF16, {32, 40}, {1, 0})) + GmockMatch(m::Dot(m::Fusion().WithShape(BF16, {16, 32}, {1, 0}), + m::Fusion().WithShape(BF16, {32, 40}, {1, 0})) .WithShape(BF16, {16, 40}, {1, 0}))); } From 5b53ff373bdae00b52dabd9558641bbcba7aad8b Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 7 Apr 2025 14:14:04 -0700 Subject: [PATCH 0317/1324] Remove outdates haswell references in tensorflow/BUILD PiperOrigin-RevId: 744845971 --- tensorflow/BUILD | 6 ------ tensorflow/core/common_runtime/eager/BUILD | 1 - tensorflow/core/common_runtime/gpu/BUILD | 1 - tensorflow/tensorflow.bzl | 1 - 4 files changed, 9 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 12c43a0aaf8db0..21103e3c9dd7e9 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -547,12 +547,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "haswell", - values = {"cpu": "haswell"}, - visibility = ["//visibility:public"], -) - # This condition takes precedence over :linux_x86_64 config_setting( name = "linux_x86_64_no_sse", diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index f2ada4a8ab854e..7d2d3148818736 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -267,7 +267,6 @@ tf_cuda_library( clean_dep("//tensorflow:linux_x86_64_with_weightwatcher"): [], ( clean_dep("//tensorflow:linux_x86_64"), - clean_dep("//tensorflow:haswell"), ): [ "//tensorflow/core", "//tensorflow/core/framework:resource_base", diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index 56edf351c30cd3..7e2db9b67c3615 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -216,7 +216,6 @@ tf_cuda_library( clean_dep("//tensorflow:linux_x86_64_with_weightwatcher"): [], ( clean_dep("//tensorflow:linux_x86_64"), - clean_dep("//tensorflow:haswell"), ): [ "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/jit:flags", diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 8ea4765974b40c..f64e111f78b4a3 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -319,7 +319,6 @@ def if_not_fuchsia(a): def if_linux_x86_64(a): return select({ clean_dep("//tensorflow:linux_x86_64"): a, - clean_dep("//tensorflow:haswell"): a, "//conditions:default": [], }) From c071e2c71b6f050517b1a4840d345218e23406f5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 14:39:21 -0700 Subject: [PATCH 0318/1324] [HLO Diff] Remove dead code from HloValueTracing. 
PiperOrigin-RevId: 744854654 --- .../graph/analysis/hlo_value_tracing.cc | 14 ---------- .../graph/analysis/hlo_value_tracing.h | 27 ------------------- 2 files changed, 41 deletions(-) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc index 22db01d6f42657..7e1e25695267e0 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc @@ -122,11 +122,6 @@ void HloValueTracing::DeleteMarkedValues() { value_ids_to_delete_.clear(); } -const HloValue& HloValueTracing::GetValue(HloValue::Id value_id) const { - DCHECK(values_.contains(value_id)) << "Value not found: " << value_id; - return *values_.find(value_id)->second; -} - HloValue& HloValueTracing::GetValue(HloValue::Id value_id) { DCHECK(values_.contains(value_id)) << "Value not found: " << value_id; return *values_.find(value_id)->second; @@ -158,15 +153,6 @@ HloValueSet& HloValueTracing::GetValueSet(const HloInstruction* instruction, return *GetInstructionValueSet(instruction).mutable_element(index); } -const HloValueSet& HloValueTracing::GetValueSet( - const HloPosition& position) const { - return GetValueSet(position.instruction, position.index); -} - -HloValueSet& HloValueTracing::GetValueSet(const HloPosition& position) { - return GetValueSet(position.instruction, position.index); -} - bool HloValueTracing::UpdateSendValueSet(HloInstruction* send) { CHECK_EQ(send->opcode(), HloOpcode::kSend); bool changed = false; diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h index 3e5d744fbdbe05..6cabdf4f838787 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.h @@ -65,38 +65,12 @@ class HloValueTracing { // given position. const HloValueSet& GetValueSet(const HloInstruction* instruction, const ShapeIndex& index = {}) const; - const HloValueSet& GetValueSet(const HloPosition& position) const; - HloValueSet& GetValueSet(const HloPosition& position); HloValueSet& GetValueSet(const HloInstruction* instruction, const ShapeIndex& index = {}); - // Returns the unique value in the HloValueSet at the given instruction and - // shape index. CHECKs if the value set does not contain a exactly one value. - const HloValue& GetUniqueValueAt(const HloInstruction* instruction, - const ShapeIndex& index = {}) const { - return GetValueSet(instruction, index).GetUniqueValue(); - } - HloValue& GetUniqueValueAt(const HloInstruction* instruction, - const ShapeIndex& index = {}) { - return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); - } - // Returns the HloValue with the given Id. - const HloValue& GetValue(HloValue::Id value_id) const; HloValue& GetValue(HloValue::Id value_id); - // Returns the total number of HloValues. - int64_t value_count() const { return values_.size(); } - - // Returns a vector of all HloValues stabily sorted by HloValue::Id. - const std::vector& values() const { return values_vector_; } - - // Returns the call graph used for computing the dataflow. 
- const CallGraph& call_graph() const { return *call_graph_; } - - - const HloModule& module() const { return module_; } - private: HloValueTracing(const HloModule& module, absl::flat_hash_set execution_threads); @@ -123,7 +97,6 @@ class HloValueTracing { bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); - bool UpdateCustomCallValueSet(HloInstruction* custom_call); bool UpdateDomainValueSet(HloInstruction* domain); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); From 94afdde16d9ee7ddaf32dd873ba69f42c5b38379 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 14:59:33 -0700 Subject: [PATCH 0319/1324] Remove UpdateCustomCallValueSet declaration as the definition was removed earlier. PiperOrigin-RevId: 744861348 --- third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.h | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.h index 5509e137043333..c89df918c9aabb 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.h +++ b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.h @@ -289,7 +289,6 @@ class HloDataflowAnalysis { bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); - bool UpdateCustomCallValueSet(HloInstruction* custom_call); bool UpdateDomainValueSet(HloInstruction* domain); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); From 106d75d07857c2930d872ddf1f8a1772a500957c Mon Sep 17 00:00:00 2001 From: Daniel Chen Date: Mon, 7 Apr 2025 15:03:36 -0700 Subject: [PATCH 0320/1324] Print repetitive pattern groups. 
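This change also introduces the `GraphUrlGenerator` interface (see the new
header below) and switches `RenderHtml` to take it as a nullable pointer. A
minimal sketch of an implementation follows; the `FakeUrlGenerator` class and
the example.com URL scheme are hypothetical, not part of this change:

  #include <string>

  #include "absl/strings/str_cat.h"
  #include "absl/strings/string_view.h"
  #include "xla/hlo/ir/hlo_computation.h"
  #include "xla/hlo/ir/hlo_instruction.h"
  #include "xla/hlo/tools/hlo_diff/render/graph_url_generator.h"

  class FakeUrlGenerator : public xla::hlo_diff::GraphUrlGenerator {
   public:
    // Builds a "?left=...&right=..." query from the two selected node ids.
    std::string Generate(absl::string_view left_selected_node_id,
                         absl::string_view right_selected_node_id) override {
      return absl::StrCat("https://example.com/graph?left=",
                          left_selected_node_id,
                          "&right=", right_selected_node_id);
    }
    // The instruction and computation overloads reuse the id-based one.
    std::string Generate(const xla::HloInstruction* left_inst,
                         const xla::HloInstruction* right_inst) override {
      return Generate(left_inst ? left_inst->name() : "",
                      right_inst ? right_inst->name() : "");
    }
    std::string Generate(const xla::HloComputation* left_comp,
                         const xla::HloComputation* right_comp) override {
      return Generate(left_comp ? left_comp->name() : "",
                      right_comp ? right_comp->name() : "");
    }
  };

A pointer to such a generator can then be passed to `RenderHtml`; passing
`nullptr` (as `hlo_diff_main.cc` now does) renders plain text instead of
links.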
PiperOrigin-RevId: 744862678 --- .../xla/hlo/tools/hlo_diff/hlo_diff_main.cc | 5 +- .../xla/hlo/tools/hlo_diff/hlo_diff_summary.h | 2 +- .../xla/xla/hlo/tools/hlo_diff/render/BUILD | 10 ++ .../hlo_diff/render/graph_url_generator.h | 48 ++++++ .../render/hlo_gumgraph_html_renderer.cc | 159 +++++++++++++++--- .../render/hlo_gumgraph_html_renderer.h | 13 +- 6 files changed, 201 insertions(+), 36 deletions(-) create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/graph_url_generator.h diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc index 4d072bbc46177c..89c4c734ce8172 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc @@ -175,10 +175,7 @@ absl::Status RunGumgraphDiff(HloModule& first_module, HloModule& second_module, if (!html_output.empty()) { std::ostringstream html; RenderHtml( - diff, diff_summary, - [](const HloInstruction* left_inst, const HloInstruction* right_inst) { - return ""; - }, + diff, diff_summary, nullptr, [](absl::string_view op_name) { return std::nullopt; }, [](absl::string_view op_name) { return std::nullopt; }, html); TF_RETURN_IF_ERROR( diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h index 5b38e3c31940d5..a5dfed6c855ace 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h @@ -62,7 +62,7 @@ struct ComputationSummary { bool all_unchanged = true; }; -// A group of computations that are connected in the graph. +// A group of left and right computations that form a diff pattern. struct ComputationGroup { std::vector left_computations; std::vector right_computations; diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD index f868bf87f3d83b..9168981c0a8cc6 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD @@ -61,6 +61,7 @@ cc_library( srcs = ["hlo_gumgraph_html_renderer.cc"], hdrs = ["hlo_gumgraph_html_renderer.h"], deps = [ + ":graph_url_generator", ":hlo_gumgraph_renderer_util", "//xla/hlo/ir:hlo", "//xla/hlo/tools/hlo_diff:hlo_diff_result", @@ -74,3 +75,12 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "graph_url_generator", + hdrs = ["graph_url_generator.h"], + deps = [ + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/graph_url_generator.h b/third_party/xla/xla/hlo/tools/hlo_diff/render/graph_url_generator.h new file mode 100644 index 00000000000000..f318836f0c5531 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/graph_url_generator.h @@ -0,0 +1,48 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_HLO_TOOLS_HLO_DIFF_RENDER_GRAPH_URL_GENERATOR_H_ +#define XLA_HLO_TOOLS_HLO_DIFF_RENDER_GRAPH_URL_GENERATOR_H_ + +#include + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla { +namespace hlo_diff { + +// A helper class to generate a url to the graph visualization. +class GraphUrlGenerator { + public: + virtual ~GraphUrlGenerator() = default; + + // Generates a url to the graph visualization for the given selected nodes. + virtual std::string Generate(absl::string_view left_selected_node_id, + absl::string_view right_selected_node_id) = 0; + + // Generates a url to the graph visualization for the given instruction pair. + virtual std::string Generate(const HloInstruction* left_inst, + const HloInstruction* right_inst) = 0; + + // Generates a url to the graph visualization for the given computation pair. + virtual std::string Generate(const HloComputation* left_comp, + const HloComputation* right_comp) = 0; +}; + +} // namespace hlo_diff +} // namespace xla + +#endif // XLA_HLO_TOOLS_HLO_DIFF_RENDER_GRAPH_URL_GENERATOR_H_ diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc index fb788ca7745a0f..95500a2c3718de 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc @@ -34,6 +34,7 @@ #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/tools/hlo_diff/hlo_diff_result.h" #include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h" +#include "xla/hlo/tools/hlo_diff/render/graph_url_generator.h" #include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h" namespace xla { @@ -60,7 +61,7 @@ std::string PrintCss() { .section > .content { font-size: 14px; } - + details { margin: 0; padding: 0; @@ -73,9 +74,9 @@ std::string PrintCss() { background-color: #eee; } details > .content { - padding-left: 10px; + padding-left: 20px; } - + .list { margin: 0; padding: 0; @@ -83,7 +84,12 @@ std::string PrintCss() { .list > .item:hover { background-color: #eee; } - + + .attributes-list { + margin: 0; + padding: 0; + } + .tooltip { position: relative; display: inline-block; @@ -122,7 +128,7 @@ std::string PrintCss() { opacity: 1; } - )"; + )"; } // Prints the div html block. @@ -153,20 +159,21 @@ std::string PrintList(absl::Span items) { "list"); } -// Prints a link to the instruction in model explorer if url_generator is not -// null, otherwise returns the text directly. -std::string PrintInstructionLink(const HloInstruction* left_inst, - const HloInstruction* right_inst, - absl::string_view text, - UrlGenerator url_generator) { - std::string url = url_generator(left_inst, right_inst); +// Prints a list of attribute items. +std::string PrintAttributesList(absl::Span items) { + return PrintDiv(absl::StrJoin(items, "", + [](std::string* out, const auto& item) { + absl::StrAppend(out, PrintDiv(item, "item")); + }), + "attributes-list"); +} - if (url.empty()) { - return std::string(text); - } +// Prints a link to the given url. +std::string PrintLink(absl::string_view text, absl::string_view url) { return absl::StrFormat("%s", url, text); } +// Prints a span with a tooltip. 
std::string PrintTooltip(absl::string_view text, absl::string_view tooltip_text) { return absl::StrFormat( @@ -176,13 +183,46 @@ std::string PrintTooltip(absl::string_view text, /*** Summary logic ***/ +// Prints a link to the instruction in model explorer if url_generator is not +// null, otherwise returns the text directly. +std::string PrintInstructionLink(const HloInstruction* left_inst, + const HloInstruction* right_inst, + absl::string_view text, + GraphUrlGenerator* url_generator) { + if (url_generator == nullptr) { + return std::string(text); + } + std::string url = url_generator->Generate(left_inst, right_inst); + if (url.empty()) { + return std::string(text); + } + return PrintLink(text, url); +} + +// Prints a link to the computation in model explorer if url_generator is not +// null, otherwise returns the text directly. +std::string PrintComputationLink(const HloComputation* left_comp, + const HloComputation* right_comp, + absl::string_view text, + GraphUrlGenerator* url_generator) { + if (url_generator == nullptr) { + return std::string(text); + } + std::string maybe_url = url_generator->Generate(left_comp, right_comp); + if (maybe_url.empty()) { + return std::string(text); + } + return PrintLink(text, maybe_url); +} + // The location of the instruction in the diff result. enum class InstructionLocation : std::uint8_t { kLeft, kRight }; // Prints a list of instructions. std::string PrintInstructionsAsList( absl::Span instructions, - InstructionLocation location, bool name_only, UrlGenerator url_generator) { + InstructionLocation location, bool name_only, + GraphUrlGenerator* url_generator) { std::vector instructions_list; for (const HloInstruction* inst : instructions) { std::string link; @@ -220,7 +260,7 @@ std::string PrintUnmatchedInstructions( absl::Span instructions, InstructionLocation location, const absl::flat_hash_set& opcodes_to_ignore, bool name_only, - UrlGenerator url_generator) { + GraphUrlGenerator* url_generator) { absl::flat_hash_map> instructions_by_opcode = GroupInstructionsByOpcode(instructions); std::vector> opcode_counts; @@ -328,7 +368,7 @@ std::string PrintChangedInstructions( const absl::flat_hash_map& instructions, const absl::flat_hash_set& opcodes_to_ignore, - UrlGenerator url_generator) { + GraphUrlGenerator* url_generator) { auto decorated_printer = [&url_generator](const HloInstruction* left_inst, const HloInstruction* right_inst) { std::vector diff_types = @@ -365,7 +405,7 @@ std::string PrintUnchangedInstructions( const absl::flat_hash_map& instructions, const absl::flat_hash_set& opcodes_to_ignore, - UrlGenerator url_generator) { + GraphUrlGenerator* url_generator) { auto simple_printer = [&url_generator](const HloInstruction* left_inst, const HloInstruction* right_inst) { return PrintInstructionLink( @@ -381,7 +421,7 @@ std::string PrintUnchangedInstructions( std::string PrintUnmatchedMetricsDiff( absl::Span instructions, - GetOpMetricFn get_op_metrics, UrlGenerator url_generator) { + GetOpMetricFn get_op_metrics, GraphUrlGenerator* url_generator) { std::vector> sorted_metrics_diff; for (const HloInstruction* inst : instructions) { if (auto metric = get_op_metrics(inst->name()); metric.has_value()) { @@ -405,7 +445,7 @@ std::string PrintMatchedMetricsDiff( const absl::flat_hash_map& instructions, GetOpMetricFn left_op_metrics, GetOpMetricFn right_op_metrics, - UrlGenerator url_generator) { + GraphUrlGenerator* url_generator) { std::vector, double>> sorted_metrics_diff; @@ -433,10 +473,80 @@ std::string PrintMatchedMetricsDiff( return 
PrintList(metrics_diff_list); } +// Summarize a diff group. +std::string SummarizeDiffGroup( + absl::Span computation_groups) { + if (computation_groups.size() > 1) { + return absl::StrFormat("Summarized %d computations with the same diff", + computation_groups.size()); + } + return "A single computation has unique diff"; +} + +// Prints the summary of the repetitive computation groups. +std::string PrintRepetitiveComputationGroups(const DiffSummary& diff_summary, + GraphUrlGenerator* url_generator) { + // Sort the computation groups by the number of computations in each group in + // descending order. + std::vector> sorted_computation_groups; + for (const auto& [_, computation_groups] : + diff_summary.grouped_computations) { + sorted_computation_groups.push_back(computation_groups); + } + std::sort( + sorted_computation_groups.begin(), sorted_computation_groups.end(), + [](absl::Span a, + absl::Span b) { return a.size() > b.size(); }); + std::string computation_group_list; + int i = 0; + for (const auto& computation_groups : sorted_computation_groups) { + if (computation_groups.empty()) { + continue; + } + const ComputationGroup& sample = computation_groups[0]; + // We only print the one-to-one mapping for now. + if (sample.left_computations.size() != 1 || + sample.right_computations.size() != 1) { + continue; + } + std::vector computation_pair_list; + for (const ComputationGroup& computation_group : computation_groups) { + if (computation_group.left_computations.size() != 1 || + computation_group.right_computations.size() != 1) { + continue; + } + const HloComputation* left_computation = + computation_group.left_computations[0]; + const HloComputation* right_computation = + computation_group.right_computations[0]; + computation_pair_list.push_back(PrintComputationLink( + left_computation, right_computation, + absl::StrFormat("%s and %s", left_computation->name(), + right_computation->name()), + url_generator)); + } + absl::StrAppend( + &computation_group_list, + PrintDetails( + absl::StrFormat("Group %d: %s (Sample: %s → %s)", ++i, + SummarizeDiffGroup(computation_groups), + sample.left_computations[0]->name(), + sample.right_computations[0]->name()), + PrintAttributesList( + {absl::StrFormat( + "Instruction count: %d → %d", + sample.left_computations[0]->instruction_count(), + sample.right_computations[0]->instruction_count()), + PrintDetails("Instances", + PrintList(computation_pair_list))}))); + } + return computation_group_list; +} + } // namespace void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary, - UrlGenerator url_generator, GetOpMetricFn left_op_metrics, + GraphUrlGenerator* url_generator, GetOpMetricFn left_op_metrics, GetOpMetricFn right_op_metrics, std::ostringstream& out) { const absl::flat_hash_set ignored_opcodes(kIgnoredOpcodes.begin(), kIgnoredOpcodes.end()); @@ -488,6 +598,11 @@ void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary, PrintMatchedMetricsDiff( diff_result.unchanged_instructions, left_op_metrics, right_op_metrics, url_generator)))); + + // Print repetitive computation groups + out << PrintSectionWithHeader( + "Group of computations with the same diff", + PrintRepetitiveComputationGroups(diff_summary, url_generator)); } } // namespace hlo_diff diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h index c2244629392fed..9076a5b09270e1 100644 --- 
a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h @@ -20,21 +20,16 @@ #include #include #include -#include #include "absl/functional/function_ref.h" #include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/tools/hlo_diff/hlo_diff_result.h" #include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h" +#include "xla/hlo/tools/hlo_diff/render/graph_url_generator.h" namespace xla { namespace hlo_diff { -// A function that returns a visualization url for the given instruction pair. -using UrlGenerator = absl::FunctionRef; - // A function that returns the op metric for the given op name. using GetOpMetricFn = absl::FunctionRef(absl::string_view)>; @@ -42,10 +37,10 @@ using GetOpMetricFn = // Renders the diff result in HTML format, and writes the result to the given // output stream. -// url_generator can be specified which is used to link an url to each generated -// diff result. +// url_generator can be specified which is used to link an url to each +// generated diff result. void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary, - UrlGenerator url_generator, GetOpMetricFn left_op_metrics, + GraphUrlGenerator* url_generator, GetOpMetricFn left_op_metrics, GetOpMetricFn right_op_metrics, std::ostringstream& out); } // namespace hlo_diff From 82e545d3a156c3cb6853a26bf238560d53e85d5a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 15:12:35 -0700 Subject: [PATCH 0321/1324] Allow `safe_reinterpret_cast` to cast between same pointer types. Currently, `safe_reinterpret_cast(p)`, where `p` has type `T*`, results in a compiler error that the template partial specialization is ambiguous. Rewrite the implementation to avoid this error. PiperOrigin-RevId: 744865869 --- third_party/xla/xla/tsl/util/safe_reinterpret_cast.h | 10 ++++------ .../xla/xla/tsl/util/safe_reinterpret_cast_test.cc | 10 ++++++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h index 82aebaf8ecdc46..67089a8100deae 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h @@ -67,15 +67,13 @@ struct IsCvByteLike : IsByteLike {}; template struct IsSafeCast : std::false_type {}; -// It's safe to cast a type to itself. -template -struct IsSafeCast : std::true_type {}; - -// It's safe to cast a pointer to/from a byte-like type. +// It's safe to cast a pointer to/from a byte-like type, or to/from the same +// type. template struct IsSafeCast : std::integral_constant::value || - IsCvByteLike::value> {}; + IsCvByteLike::value || + std::is_same_v> {}; // It's safe to cast a pointer to/from std::uintptr_t. 
template diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc index e804e64906393e..6deecb32681800 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc @@ -101,5 +101,15 @@ TEST(SafeReinterpretCast, CanCastPointerToFromStdIntptrT) { EXPECT_EQ(safe_reinterpret_cast(intptr_t_p), &x); } +TEST(SafeReinterpretCast, CanCastPointerToFromSameType) { + const int x = 42; + const int* const int_p = safe_reinterpret_cast(&x); + EXPECT_EQ(int_p, &x); + + char y = 'A'; + char* const char_p = safe_reinterpret_cast(&y); + EXPECT_EQ(char_p, &y); +} + } // namespace } // namespace tsl From 8381f09c77795712be3f9ee85c226d47bf756c99 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 15:57:12 -0700 Subject: [PATCH 0322/1324] Fix an unnecessary `const_cast`. The `TfAllocatorAdapter` ctor doesn't actually need a mutable `Platform` object. Changing the parameter type to `const Platform*` prevents it from accidentally mutating the object and thus makes the code safer and more readable. PiperOrigin-RevId: 744880532 --- third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 2 +- .../xla/stream_executor/integrations/tf_allocator_adapter.cc | 2 +- .../xla/xla/stream_executor/integrations/tf_allocator_adapter.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc index 155a1cf07de8b2..0a0577bea80455 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc @@ -931,7 +931,7 @@ TfrtGpuDevice::TfrtGpuDevice(Options&& options) CHECK_OK(executor_->CreateStream().status()) << "Failed to create stream"; se_allocator_ = std::make_unique( - allocator_.get(), const_cast(executor_->GetPlatform())); + allocator_.get(), executor_->GetPlatform()); } void TfrtGpuDevice::SetClient(PjRtClient* client) { diff --git a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc index 1bc6d86ecec386..9cfedfff5549b9 100644 --- a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc +++ b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc @@ -39,7 +39,7 @@ TfAllocatorAdapter::TfAllocatorAdapter(tsl::Allocator *wrapped, Stream *stream) stream_(stream) {} TfAllocatorAdapter::TfAllocatorAdapter(tsl::Allocator *wrapped, - Platform *platform) + const Platform *platform) : DeviceMemoryAllocator(platform), wrapped_(wrapped), stream_(nullptr) {} TfAllocatorAdapter::~TfAllocatorAdapter() {} diff --git a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.h b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.h index 7b45db652e8132..0396cd5625839d 100644 --- a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.h +++ b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.h @@ -53,7 +53,7 @@ class TfAllocatorAdapter : public DeviceMemoryAllocator { TfAllocatorAdapter(tsl::Allocator *wrapped, Stream *stream); // Constructor for the cases where `stream` can not be provided. 
-  TfAllocatorAdapter(tsl::Allocator *wrapped, Platform *platform);
+  TfAllocatorAdapter(tsl::Allocator *wrapped, const Platform *platform);
 
   ~TfAllocatorAdapter() override;
 
From 3caa9b28e7aeecdb7209bea921f11761570bdfbf Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Mon, 7 Apr 2025 16:21:30 -0700
Subject: [PATCH 0323/1324] Add a common AbstractTrackedDeviceBuffer type which
 can be used by an AbstractLocalPjRtBuffer to allow a unified implementation
 of donation logic.

PiperOrigin-RevId: 744888717
---
 third_party/xla/xla/pjrt/BUILD                |  18 ++
 .../pjrt/abstract_tracked_device_buffer.cc    | 247 ++++++++++++++++
 .../xla/pjrt/abstract_tracked_device_buffer.h | 242 +++++++++++++++
 .../xla/pjrt/pjrt_stream_executor_client.cc   | 278 +++---------------
 .../xla/pjrt/pjrt_stream_executor_client.h    | 222 ++------------
 .../xla/xla/pjrt/tracked_device_buffer.cc     |   7 +
 .../xla/xla/pjrt/tracked_device_buffer.h      |   6 +-
 7 files changed, 573 insertions(+), 447 deletions(-)
 create mode 100644 third_party/xla/xla/pjrt/abstract_tracked_device_buffer.cc
 create mode 100644 third_party/xla/xla/pjrt/abstract_tracked_device_buffer.h

diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD
index f498d64d1d2647..dfd717ac84a7ca 100644
--- a/third_party/xla/xla/pjrt/BUILD
+++ b/third_party/xla/xla/pjrt/BUILD
@@ -84,11 +84,28 @@ xla_cc_test(
     ],
 )
 
+cc_library(
+    name = "abstract_tracked_device_buffer",
+    srcs = ["abstract_tracked_device_buffer.cc"],
+    hdrs = ["abstract_tracked_device_buffer.h"],
+    deps = [
+        ":pjrt_client",
+        ":pjrt_future",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/synchronization",
+        "@local_tsl//tsl/profiler/lib:traceme",
+    ],
+)
+
 cc_library(
     name = "tracked_device_buffer",
     srcs = ["tracked_device_buffer.cc"],
     hdrs = ["tracked_device_buffer.h"],
     deps = [
+        ":abstract_tracked_device_buffer",
         ":event_pool",
         ":pjrt_client",
         ":pjrt_common",
@@ -504,6 +521,7 @@ cc_library(
     hdrs = ["pjrt_stream_executor_client.h"],
     visibility = internal_visibility(["//xla:friends"]),
     deps = [
+        ":abstract_tracked_device_buffer",
         ":event_pool",
         ":host_callback",
         ":host_memory_spaces",
diff --git a/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.cc b/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.cc
new file mode 100644
index 00000000000000..a0e99346a94e05
--- /dev/null
+++ b/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.cc
@@ -0,0 +1,247 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "xla/pjrt/abstract_tracked_device_buffer.h" + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla { + +CommonPjRtBuffer::CommonPjRtBuffer( + std::unique_ptr device_buffer) + : device_buffer_(std::move(device_buffer)) { + for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) { + holds_[i] = 0; + } +} + +CommonPjRtBuffer::~CommonPjRtBuffer() { + for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) { + CHECK_EQ(holds_[i], 0) << "Non-zero type " << i << " hold on destruction."; + } +} + +void CommonPjRtBuffer::WaitForOutstandingUsageHolds() { + auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return holds_[ScopedHold::kUsage] == 0; + }; + mu_.Await(absl::Condition(¬_in_usage_hold)); +} + +void CommonPjRtBuffer::WaitForOutstandingDonationHold() { + auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return holds_[ScopedHold::kDonation] == 0; + }; + mu_.Await(absl::Condition(¬_in_donation_hold)); +} + +absl::StatusOr +CommonPjRtBuffer::GetBufferForUsageOrExternalHoldLocked(ScopedHold::Type type) { + // All callers should have called WaitForOutstandingDonationHold(). + CHECK_EQ(holds_[ScopedHold::kDonation], 0); + if (device_buffer_ == nullptr) { + return absl::InvalidArgumentError("Buffer has been deleted or donated."); + } else { + ++holds_[type]; + } + return device_buffer_.get(); +} + +absl::StatusOr> +CommonPjRtBuffer::GetBufferForDonationHoldLocked() { + // All callers should have called WaitForOutstandingDonationHold(). + CHECK_EQ(holds_[ScopedHold::kDonation], 0); + if (device_buffer_ == nullptr) { + return absl::InvalidArgumentError("Donation requested for invalid buffer"); + } + if (holds_[ScopedHold::kExternalReference] > 0) { + return absl::InvalidArgumentError( + "Donation requested for buffer with external reference"); + } + // First add the donation hold. + ++holds_[ScopedHold::kDonation]; + // Then wait for any usage holds to be dropped or converted. No new usage + // holds can be added until we drop the donation hold so this wait will + // complete eventually. + WaitForOutstandingUsageHolds(); + // Because we added a donation hold, nobody could release the buffer while + // we were waiting. 
+ CHECK(device_buffer_ != nullptr); + return std::move(device_buffer_); +} + +void CommonPjRtBuffer::AcquireHoldLocked(ScopedHold* hold) { + if (hold->type() == ScopedHold::kDonation) { + hold->AcquireDonation(GetBufferForDonationHoldLocked()); + return; + } + + hold->AcquireUsageOrExternalReference( + GetBufferForUsageOrExternalHoldLocked(hold->type())); +} + +void CommonPjRtBuffer::DropUsageOrExternalHold( + ScopedHold::Type type, AbstractTrackedDeviceBuffer* buffer) { + absl::MutexLock lock(&mu_); + CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr); + CHECK_GT(holds_[type], 0); + --holds_[type]; +} + +void CommonPjRtBuffer::DropDonationHold( + std::unique_ptr buffer) { + absl::MutexLock lock(&mu_); + CHECK_EQ(device_buffer_.get(), nullptr); + device_buffer_ = std::move(buffer); + CHECK_GT(holds_[ScopedHold::kDonation], 0); + --holds_[ScopedHold::kDonation]; + CHECK_EQ(holds_[ScopedHold::kDonation], 0); + CHECK_EQ(holds_[ScopedHold::kUsage], 0); + CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); +} + +absl::Status CommonPjRtBuffer::ScopedHold::status() const { + // Lazily create absl::Status values only when they are requested. + switch (state_) { + case kUninitialized: + return absl::InvalidArgumentError("Buffer has not been initialized"); + case kValid: + return absl::OkStatus(); + case kMoved: + return absl::InvalidArgumentError("Buffer has been moved."); + case kConverted: + return absl::InvalidArgumentError("Buffer has been converted"); + case kReleased: + return absl::InvalidArgumentError("Buffer has been released"); + case kDonated: + return absl::InvalidArgumentError("Buffer has been donated"); + case kError: + return status_; + default: + CHECK(false) << "Unexpected state value " << state_; + } +} + +void CommonPjRtBuffer::ScopedHold::DropHold() { + if (ok()) { + if (type_ == kDonation) { + parent_->DropDonationHold(std::move(buffer_)); + } else { + parent_->DropUsageOrExternalHold(type_, buffer_ptr_); + } + } +} + +CommonPjRtBuffer::ScopedHold::~ScopedHold() { DropHold(); } + +CommonPjRtBuffer::ScopedHold::ScopedHold(ScopedHold&& other) + : parent_(other.parent_), + type_(other.type_), + state_(other.state_), + status_(std::move(other.status_)), + buffer_ptr_(other.buffer_ptr_), + buffer_(std::move(other.buffer_)) { + // Preserve the invariant that status is invalid if buffer == nullptr. + other.SetState(kMoved); +} + +void CommonPjRtBuffer::ScopedHold::AcquireDonation( + absl::StatusOr> buffer_or) { + CHECK(!ok()); + if (buffer_or.ok()) { + buffer_ = std::move(buffer_or).value(); + buffer_ptr_ = buffer_.get(); + SetState(kValid); + } else { + status_ = std::move(buffer_or).status(); + buffer_ = nullptr; + buffer_ptr_ = nullptr; + SetState(kError); + } + // Check the invariant holds. + CHECK(!ok() || buffer_ptr_ != nullptr); +} + +void CommonPjRtBuffer::ScopedHold::AcquireUsageOrExternalReference( + absl::StatusOr buffer_or) { + CHECK(!ok()); + if (buffer_or.ok()) { + buffer_.reset(); + buffer_ptr_ = buffer_or.value(); + SetState(kValid); + } else { + status_ = std::move(buffer_or).status(); + buffer_.reset(); + buffer_ = nullptr; + SetState(kError); + } + // Check the invariant holds. 
+ CHECK(!ok() || buffer_ptr_ != nullptr); +} + +void CommonPjRtBuffer::ScopedHold::ConfirmDonation() { + CHECK(ok()); + CHECK_EQ(type(), kDonation); + parent()->ConfirmDonation(buffer()); + SetState(kDonated); +} + +void CommonPjRtBuffer::ConfirmDonation( + AbstractTrackedDeviceBuffer* device_buffer) { + absl::MutexLock lock(&mu_); + CHECK_EQ(holds_[ScopedHold::kUsage], 0); + CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); + CHECK_EQ(holds_[ScopedHold::kDonation], 1); + holds_[ScopedHold::kDonation] = 0; + device_buffer->ConfirmDonation(); +} + +std::unique_ptr CommonPjRtBuffer::ReleaseBuffer() { + absl::MutexLock lock(&mu_); + { + tsl::profiler::TraceMe t1("Wait for donation holds"); + // We first wait for a donation hold to complete if there is one in + // progress. If the donation succeeds via ConfirmDonation() then it will + // set device_buffer_ to nullptr before returning to this thread. + WaitForOutstandingDonationHold(); + } + if (device_buffer_ == nullptr) { + // Buffer has been deleted. + return nullptr; + } + // Return device_buffer_ by move which also sets it to nullptr, so + // that no other thread can add a hold while we are in + // WaitForOutstandingUsageHolds() below. + auto buffer = std::move(device_buffer_); + + tsl::profiler::TraceMe t2("Wait for usage holds"); + WaitForOutstandingUsageHolds(); + return buffer; +} + +bool CommonPjRtBuffer::IsDeleted() { + absl::MutexLock lock(&mu_); + return device_buffer_ == nullptr; +} + +} // namespace xla diff --git a/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.h b/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.h new file mode 100644 index 00000000000000..5fd453db673f66 --- /dev/null +++ b/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.h @@ -0,0 +1,242 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_ABSTRACT_TRACKED_DEVICE_BUFFER_H_ +#define XLA_PJRT_ABSTRACT_TRACKED_DEVICE_BUFFER_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" + +namespace xla { + +class AbstractTrackedDeviceBuffer { + public: + virtual ~AbstractTrackedDeviceBuffer() = default; + + // Only to be called by ScopedHold to mark a successful donation. + virtual void ConfirmDonation() = 0; +}; + +class CommonPjRtBuffer : public PjRtBuffer { + public: + // Helper class to retain a "hold" on a CommonPjRtBuffer. A ScopedHold + // may not outlive its parent CommonPjRtBuffer. + // + // There are three types of hold, as follows: + // + // 1) Usage hold: a transient hold while an operation using the buffer is + // being enqueued to the runtime. + // A client acquires a usage hold by calling + // CommonPjRtBuffer::GetBufferWithHold(kUsage) or the convenience + // wrapper GetBufferWithUsageHold(). 
If the enqueue completes successfully the + // hold should be released using a call to ConvertUsageHold. If the ScopedHold + // is deleted without ConvertUsageHold being called, e.g., on error, the hold + // is dropped. It is legal to drop a usage hold instead of calling + // ConvertUsageHold, even if the buffer was successfully enqueued, as long as + // the client ensures that all necessary synchronization has been done. + // + // 2) External hold: a potentially long-lived hold while the buffer is being + // shared by an external framework, e.g., NumPy. + // A client acquires an external hold by calling + // CommonPjRtBuffer::GetBufferWithHold(kExternal) or the convenience + // wrapper GetBufferWithExternalReference and releases it by deleting the + // ScopedHold. The external framework should not modify the underlying buffer + // unless it is confident via its own synchronization that modifications do + // not race with reads from the CommonPjRtBuffer. + // + // 3) Donation hold: a transient hold while an execution that donates the + // buffer is being enqueued to the runtime. + // A client acquires a donation hold by calling + // CommonPjRtBuffer::GetBufferWithHold(kDonation). If the enqueue + // completes successfully the hold should be released using a call to + // ConfirmDonation after which the buffer is invalid. If the ScopedHold is + // deleted without ConfirmDonation being called, e.g., on error, the hold is + // dropped and the buffer remains valid. If the buffer is successfully + // enqueued the client *must* call ConfirmDonation. + // + // Donation holds behave like exclusive write locks: when a donation hold + // has been acquired, any attempt to acquire another hold of any type will + // block until the donation hold is dropped or confirmed. Acquiring a donation + // hold will fail with an error if there is any outstanding external hold, and + // will block if there are any outstanding usage holds until those holds are + // dropped or converted. + // + // Calls to CommonPjRtBuffer::Release (and transitively to + // CommonPjRtBuffer::Delete() and ~CommonPjRtBuffer()) will + // block until all usage and donation holds are either deleted or + // converted/confirmed. + class ScopedHold { + public: + enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue }; + // Use a State enum instead of encoding the state in an error absl::Status + // to avoid creating absl::Status values in non-error cases. Creating a + // absl::Status entails several allocations and can add O(us) to every use + // of a hold. + enum State { + kUninitialized = 0, + kValid, + kMoved, + kConverted, + kReleased, + kDonated, + kError + }; + + ~ScopedHold(); + ScopedHold(ScopedHold&& other); + ScopedHold(const ScopedHold&) = delete; + ScopedHold& operator=(const ScopedHold&) = delete; + + Type type() const { return type_; } + + absl::Status status() const; + bool ok() const { return state_ == kValid; } + + // Access to the underlying device buffer storage. Requires this->ok(). + AbstractTrackedDeviceBuffer* buffer() const { + CHECK_EQ(state_, kValid); + CHECK_NE(buffer_ptr_, nullptr); + return buffer_ptr_; + } + CommonPjRtBuffer* parent() const { return parent_; } + + // Confirms that the buffer was successfully donated to an execution. + // Only valid for holds of type kDonation. Causes the buffer to become + // invalid. + void ConfirmDonation(); + + protected: + ScopedHold(CommonPjRtBuffer* parent, Type type) + : parent_(parent), type_(type), state_(kUninitialized) {} + + // Sets buffer state. 
+ void SetState(State state) { state_ = state; } + + private: + friend class CommonPjRtBuffer; + + // Acquires the unique ownership of the buffer. Called by parent_ to + // initialize the donation hold. + void AcquireDonation( + absl::StatusOr> buffer_or); + + // Acquires a non-owning reference of the buffer. Called by parent_ to + // initialize the usage or external reference hold. + void AcquireUsageOrExternalReference( + absl::StatusOr buffer_or); + + // Drops this hold. It resets `holds_` counters. If it is a donation hold + // and an error occurs, it returns the device buffer to the + // CommonPjRtBuffer. + void DropHold(); + + CommonPjRtBuffer* const parent_; + const Type type_; + + // There is an invariant that if ok() then buffer_.value() != nullptr. + State state_; + absl::Status status_; + // The non-owning pointer to the underlying buffer. It is not nullptr for + // all types of holds. + AbstractTrackedDeviceBuffer* buffer_ptr_ = nullptr; + // If it is a donation hold, `buffer_` will not be nullptr. Otherwise, it is + // a nullptr. + std::unique_ptr buffer_; + }; + + bool IsDeleted() override; + + protected: + explicit CommonPjRtBuffer( + std::unique_ptr device_buffer); + ~CommonPjRtBuffer() override; + + // Blocks in mu_.Await until there are no more usage holds. + void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Blocks in mu_.Await until there is no donation hold. + void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Adds a donation hold and returns device_buffer_. Returns an error if + // device_buffer_ is null, or if a donation hold was requested when there is + // an outstanding external hold. + // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds() + // must be called first.) + absl::StatusOr> + GetBufferForDonationHoldLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Adds a hold of usage or external reference and returns non-owning + // device_buffer_. Returns an error if device_buffer_ is null. + // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds() + // must be called first.) + absl::StatusOr + GetBufferForUsageOrExternalHoldLocked(ScopedHold::Type type) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Adds a hold of hold->type() and initializes `hold` with device_buffer_. + // Initializes hold with an error if device_buffer_ is null, or if a donation + // hold was requested when there is an outstanding external hold. + // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds() + // must be called first.) + void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Drops a hold without taking any other action. Does a sanity check that + // buffer==device_buffer_ or device_buffer_==nullptr. + void DropUsageOrExternalHold(ScopedHold::Type type, + AbstractTrackedDeviceBuffer* buffer); + + // Drops a hold without taking any other action. Does a sanity check that + // buffer==device_buffer_ or device_buffer_==nullptr. + void DropDonationHold(std::unique_ptr buffer); + + // Drops a donation hold and makes *this invalid for further use. Does a + // sanity check that buffer==device_buffer_. Called after device_buffer_ was + // successfully donated to an execution. 
+ void ConfirmDonation(AbstractTrackedDeviceBuffer* device_buffer); + + void DecrementUsage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + CHECK_GT(holds_[ScopedHold::kUsage], 0); + --holds_[ScopedHold::kUsage]; + } + + std::unique_ptr ReleaseBuffer() + ABSL_LOCKS_EXCLUDED(mu_); + + AbstractTrackedDeviceBuffer* device_buffer() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return device_buffer_.get(); + } + + mutable absl::Mutex mu_; + PjRtFuture<>::Promise definition_promise_ ABSL_GUARDED_BY(mu_); + + private: + std::unique_ptr device_buffer_ + ABSL_GUARDED_BY(mu_); + // Count of holds on the buffer. + std::array holds_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // XLA_PJRT_ABSTRACT_TRACKED_DEVICE_BUFFER_H_ diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index ad60e2ba4143a5..e31ed5dc48cc68 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -103,6 +103,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_print_options.h" #include "xla/layout.h" #include "xla/literal.h" +#include "xla/pjrt/abstract_tracked_device_buffer.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/event_pool.h" #include "xla/pjrt/host_callback.h" @@ -529,78 +530,16 @@ AllocateDestinationBuffer( return py_buffer; } -PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() { - if (ok()) { - if (type_ == kDonation) { - parent_->DropDonationHold(std::move(buffer_)); - } else { - parent_->DropUsageOrExternalHold(type_, buffer_ptr_); - } - } -} - -PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other) - : parent_(other.parent_), - type_(other.type_), - state_(other.state_), - status_(std::move(other.status_)), - buffer_ptr_(other.buffer_ptr_), - buffer_(std::move(other.buffer_)) { - // Preserve the invariant that status is invalid if buffer == nullptr. - other.SetState(kMoved); -} - -void PjRtStreamExecutorBuffer::ScopedHold::AcquireDonation( - absl::StatusOr> buffer_or) { - CHECK(!ok()); - if (buffer_or.ok()) { - buffer_ = std::move(buffer_or).value(); - buffer_ptr_ = buffer_.get(); - SetState(kValid); - } else { - status_ = std::move(buffer_or).status(); - buffer_ = nullptr; - buffer_ptr_ = nullptr; - SetState(kError); - } - // Check the invariant holds. - CHECK(!ok() || buffer_ptr_ != nullptr); -} - -void PjRtStreamExecutorBuffer::ScopedHold::AcquireUsageOrExternalReference( - absl::StatusOr buffer_or) { - CHECK(!ok()); - if (buffer_or.ok()) { - buffer_.reset(); - buffer_ptr_ = buffer_or.value(); - SetState(kValid); - } else { - status_ = std::move(buffer_or).status(); - buffer_.reset(); - buffer_ = nullptr; - SetState(kError); - } - // Check the invariant holds. 
- CHECK(!ok() || buffer_ptr_ != nullptr); -} - void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold( se::Stream* usage_stream, std::shared_ptr event, bool reference_held) { CHECK(ok()); - CHECK_EQ(type_, kUsage); - parent_->ConvertUsageHold(buffer(), usage_stream, std::move(event), - reference_held); + CHECK_EQ(type(), kUsage); + parent()->ConvertUsageHold(buffer(), usage_stream, std::move(event), + reference_held); SetState(kConverted); } -void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() { - CHECK(ok()); - CHECK_EQ(type_, kDonation); - parent_->ConfirmDonation(buffer()); - SetState(kDonated); -} - bool PjRtStreamExecutorBuffer::IsOnCpu() const { return memory_space() != nullptr && memory_space()->kind() == PinnedHostMemorySpace::kKind; @@ -612,16 +551,11 @@ absl::StatusOr PjRtStreamExecutorBuffer::logical_on_device_shape() { } auto* local_device = device_->local_device_state(); auto* stream = local_device->GetDeviceToHostStream(); - ScopedHold device_buffer(this, ScopedHold::kUsage); - { - absl::MutexLock lock(&mu_); - // We can't perform any other action while a donation hold is in progress. - WaitForOutstandingDonationHold(); - if (device_buffer_ == nullptr) { - return InvalidArgument( - "logical_on_device_shape() called on deleted or donated buffer"); - } - AcquireHoldLocked(&device_buffer); + auto device_buffer = GetBufferWithUsageHold(); + if (!device_buffer.ok()) { + return InvalidArgument( + "logical_on_device_shape() called on deleted or donated buffer: %s", + device_buffer.status().ToString()); } WaitForBufferDefinitionEventsOnStream(device_buffer->definition_events(), @@ -1348,60 +1282,26 @@ absl::Span PjRtStreamExecutorClient::memory_spaces() PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer( Shape on_device_shape, std::unique_ptr device_buffer, PjRtClient* client, PjRtDevice* device, PjRtMemorySpace* memory_space) - : client_(tensorflow::down_cast(client)), + : CommonPjRtBuffer(std::move(device_buffer)), + client_(tensorflow::down_cast(client)), on_device_shape_(std::move(on_device_shape)), device_(tensorflow::down_cast(device)), - memory_space_(memory_space), - device_buffer_(std::move(device_buffer)) { - for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) { - holds_[i] = 0; - } -} + memory_space_(memory_space) {} PjRtStreamExecutorBuffer::~PjRtStreamExecutorBuffer() { Delete(); - for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) { - CHECK_EQ(holds_[i], 0); - } -} - -void PjRtStreamExecutorBuffer::WaitForOutstandingUsageHolds() { - auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return holds_[ScopedHold::kUsage] == 0; - }; - mu_.Await(absl::Condition(¬_in_usage_hold)); -} - -void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() { - auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return holds_[ScopedHold::kDonation] == 0; - }; - mu_.Await(absl::Condition(¬_in_donation_hold)); } absl::StatusOr> PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { tsl::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release"); - std::unique_ptr device_buffer; - TrackedDeviceBuffer::StreamAndEventContainer events; - { - absl::MutexLock lock(&mu_); - // We first wait for a donation hold to complete if there is one in - // progress. If the donation succeeds via ConfirmDonation() then it will - // set device_buffer_ to nullptr before returning to this thread. 
- WaitForOutstandingDonationHold(); - if (device_buffer_ == nullptr) { - return tsl::RCReference(); - } - // Set device_buffer_ to null now so that no other - // thread can add a hold while we are in WaitForOutstandingUsageHolds() - // below. - std::swap(device_buffer_, device_buffer); - WaitForOutstandingUsageHolds(); - // Now that all holds have completed and no more can be added, we can get - // the final set of usage events. - events = device_buffer->LockUseAndTransferUsageEvents(); + std::unique_ptr device_buffer( + static_cast(ReleaseBuffer().release())); + if (device_buffer == nullptr) { + return tsl::RCReference(); } + TrackedDeviceBuffer::StreamAndEventContainer events = + device_buffer->LockUseAndTransferUsageEvents(); auto device_memory = device_buffer->device_memory(); LocalDeviceState* local_device_state = device_->local_device_state(); if (wait_for_operations_to_complete) { @@ -1514,104 +1414,13 @@ void PjRtStreamExecutorBuffer::Delete() { TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status()); } -bool PjRtStreamExecutorBuffer::IsDeleted() { - absl::MutexLock lock(&mu_); - return device_buffer_ == nullptr; -} - -absl::StatusOr -PjRtStreamExecutorBuffer::GetBufferForUsageOrExternalHoldLocked( - ScopedHold::Type type) { - // All callers should have called WaitForOutstandingDonationHold(). - CHECK_EQ(holds_[ScopedHold::kDonation], 0); - if (device_buffer_ == nullptr) { - return InvalidArgument("Buffer has been deleted or donated."); - } else { - ++holds_[type]; - } - return device_buffer_.get(); -} - -absl::StatusOr> -PjRtStreamExecutorBuffer::GetBufferForDonationHoldLocked() { - // All callers should have called WaitForOutstandingDonationHold(). - CHECK_EQ(holds_[ScopedHold::kDonation], 0); - if (device_buffer_ == nullptr) { - return InvalidArgument("Donation requested for invalid buffer"); - } - if (holds_[ScopedHold::kExternalReference] > 0) { - return InvalidArgument( - "Donation requested for buffer with external reference"); - } - // First add the donation hold. - ++holds_[ScopedHold::kDonation]; - // Then wait for any usage holds to be dropped or converted. No new usage - // holds can be added until we drop the donation hold so this wait will - // complete eventually. - WaitForOutstandingUsageHolds(); - // Because we added a donation hold, nobody could release the buffer while - // we were waiting. 
- CHECK(device_buffer_ != nullptr); - return std::move(device_buffer_); -} - -void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) { - if (hold->type() == ScopedHold::kDonation) { - hold->AcquireDonation(GetBufferForDonationHoldLocked()); - return; - } - - hold->AcquireUsageOrExternalReference( - GetBufferForUsageOrExternalHoldLocked(hold->type())); -} - void PjRtStreamExecutorBuffer::ConvertUsageHold( TrackedDeviceBuffer* buffer, se::Stream* usage_stream, std::shared_ptr event, bool reference_held) { absl::MutexLock lock(&mu_); - CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr); + CHECK(device_buffer() == buffer || device_buffer() == nullptr); buffer->AddUsageEvent(usage_stream, std::move(event), reference_held); - CHECK_GT(holds_[ScopedHold::kUsage], 0); - --holds_[ScopedHold::kUsage]; -} - -void PjRtStreamExecutorBuffer::ConfirmDonation( - TrackedDeviceBuffer* device_buffer) { - { - absl::MutexLock lock(&mu_); - CHECK_EQ(holds_[ScopedHold::kUsage], 0); - CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); - CHECK_EQ(holds_[ScopedHold::kDonation], 1); - holds_[ScopedHold::kDonation] = 0; - // As a sanity check ensure no more usage events can be added to the buffer. - device_buffer->LockUseAndTransferUsageEvents(); - // Give up ownership of the device memory so we don't free it when the last - // reference to device_buffer_ goes away. - device_buffer->ReleaseDeviceMemory(); - // Make *this invalid so it can't be used again. Any threads blocking in - // Release or GetBufferWithHold will see an invalid buffer and return. - device_buffer_.reset(); - } -} - -void PjRtStreamExecutorBuffer::DropUsageOrExternalHold( - ScopedHold::Type type, TrackedDeviceBuffer* buffer) { - absl::MutexLock lock(&mu_); - CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr); - CHECK_GT(holds_[type], 0); - --holds_[type]; -} - -void PjRtStreamExecutorBuffer::DropDonationHold( - std::unique_ptr buffer) { - absl::MutexLock lock(&mu_); - CHECK_EQ(device_buffer_.get(), nullptr); - device_buffer_ = std::move(buffer); - CHECK_GT(holds_[ScopedHold::kDonation], 0); - --holds_[ScopedHold::kDonation]; - CHECK_EQ(holds_[ScopedHold::kDonation], 0); - CHECK_EQ(holds_[ScopedHold::kUsage], 0); - CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); + DecrementUsage(); } PjRtFuture<> PjRtStreamExecutorBuffer::LazyToLiteral( @@ -1630,16 +1439,11 @@ PjRtFuture<> PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal) { } LocalDeviceState* local_device = device_->local_device_state(); se::Stream* stream = local_device->GetDeviceToHostStream(); - ScopedHold device_buffer(this, ScopedHold::kUsage); - { - absl::MutexLock lock(&mu_); - // We can't perform any other action while a donation hold is in progress. 
- WaitForOutstandingDonationHold(); - if (device_buffer_ == nullptr) { - return PjRtFuture<>(InvalidArgument( - "CopyToHostAsync() called on deleted or donated buffer")); - } - AcquireHoldLocked(&device_buffer); + auto device_buffer = GetBufferWithUsageHold(); + if (!device_buffer.ok()) { + return PjRtFuture<>( + InvalidArgument("ToLiteral() called on deleted or donated buffer: %s", + device_buffer.status().ToString())); } auto promise = PjRtFuture<>::CreatePromise(); @@ -1740,11 +1544,11 @@ PjRtFuture<> PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal) { absl::StatusOr PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes() const { absl::MutexLock lock(&mu_); - if (device_buffer_ == nullptr || !device_buffer_->device_memory()) { + if (device_buffer() == nullptr || !device_buffer()->device_memory()) { return InvalidArgument( "GetOnDeviceSizeInBytes called on deleted or donated buffer"); } - return device_buffer_->device_memory()->mem().size(); + return device_buffer()->device_memory()->mem().size(); } PjRtFuture<> PjRtStreamExecutorBuffer::CopyRawToHost(void* dst, int64_t offset, @@ -1758,15 +1562,6 @@ PjRtFuture<> PjRtStreamExecutorBuffer::CopyRawToHostFuture( return client_->CopyRawSubBufferToHost(this, dst, offset, transfer_size); } -absl::StatusOr PjRtStreamExecutorBuffer::AsShapedBuffer() const { - absl::MutexLock lock(&mu_); - if (device_buffer_ == nullptr) { - return InvalidArgument( - "Attempted to fetch value of invalid/deleted buffer."); - } - return device_buffer_->AsShapedBuffer(on_device_shape_); -} - PjRtStreamExecutorBuffer::ScopedHold PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) { absl::MutexLock lock(&mu_); @@ -1915,16 +1710,11 @@ PjRtStreamExecutorBuffer::CopyToDeviceMemorySpace( se::Stream* transfer_stream = transfer_local_device->GetDeviceToDeviceStream(); - ScopedHold src_device_buffer(this, ScopedHold::kUsage); - { - absl::MutexLock lock(&mu_); - // We can't perform any other action while a donation hold is in progress. - WaitForOutstandingDonationHold(); - if (device_buffer_ == nullptr) { - return InvalidArgument( - "CopyToDevice called on deleted or donated buffer"); - } - AcquireHoldLocked(&src_device_buffer); + auto src_device_buffer = GetBufferWithUsageHold(); + if (!src_device_buffer.ok()) { + return InvalidArgument( + "CopyToDevice() called on deleted or donated buffer: %s", + src_device_buffer.status().ToString()); } absl::StatusOr, @@ -1972,12 +1762,12 @@ PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() { PjRtFuture<>::Promise definition_promise; { absl::MutexLock lock(&mu_); - if (device_buffer_ == nullptr) { + if (device_buffer() == nullptr) { return PjRtFuture<>(InvalidArgument( "GetReadyFuture() called on deleted or donated buffer")); } if (!definition_promise_) { - definition_events = device_buffer_->definition_events(); + definition_events = device_buffer()->definition_events(); definition_promise_ = PjRtFuture<>::CreatePromise(); } definition_promise = definition_promise_; diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 80db814c8846fe..5f1cd62924ed4d 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -48,6 +48,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/literal.h" +#include "xla/pjrt/abstract_tracked_device_buffer.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" @@ -481,110 +482,10 @@ class PjRtStreamExecutorClient : public PjRtClient { absl::StatusOr DevicesToDeviceAssignment( absl::Span> devices); -class PjRtStreamExecutorBuffer : public PjRtBuffer { +class PjRtStreamExecutorBuffer : public CommonPjRtBuffer { public: - // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A ScopedHold - // may not outlive its parent PjRtStreamExecutorBuffer. - // - // There are three types of hold, as follows: - // - // 1) Usage hold: a transient hold while an operation using the buffer is - // being enqueued onto a stream. - // A client acquires a usage hold by calling - // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience - // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully the - // hold should be released using a call to ConvertUsageHold. If the ScopedHold - // is deleted without ConvertUsageHold being called, e.g., on error, the hold - // is dropped. It is legal to drop a usage hold instead of calling - // ConvertUsageHold, even if the buffer was successfully enqueued, as long as - // the client ensures that all necessary synchronization has been done. - // - // 2) External hold: a potentially long-lived hold while the buffer is being - // shared by an external framework, e.g., NumPy. - // A client acquires an external hold by calling - // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternal) or the convenience - // wrapper GetBufferWithExternalReference and releases it by deleting the - // ScopedHold. The external framework should not modify the underlying buffer - // unless it is confident via its own synchronization that modifications do - // not race with reads from the PjRtStreamExecutorBuffer. - // - // 3) Donation hold: a transient hold while an execution that donates the - // buffer is being enqueued onto the compute stream. - // A client acquires a donation hold by calling - // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue - // completes successfully the hold should be released using a call to - // ConfirmDonation after which the buffer is invalid. If the ScopedHold is - // deleted without ConfirmDonation being called, e.g., on error, the hold is - // dropped and the buffer remains valid. If the buffer is successfully - // enqueued the client *must* call ConfirmDonation. - // - // Donation holds behave like exclusive write locks: when a donation hold - // has been acquired, any attempt to acquire another hold of any type will - // block until the donation hold is dropped or confirmed. Acquiring a donation - // hold will fail with an error if there is any outstanding external hold, and - // will block if there are any outstanding usage holds until those holds are - // dropped or converted. - // - // Calls to PjRtStreamExecutorBuffer::Release (and transitively to - // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will - // block until all usage and donation holds are either deleted or - // converted/confirmed. - class ScopedHold { + class ScopedHold : public CommonPjRtBuffer::ScopedHold { public: - enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue }; - // Use a State enum instead of encoding the state in an error absl::Status - // to avoid creating absl::Status values in non-error cases. 
Creating a - // absl::Status entails several allocations and can add O(us) to every use - // of a hold. - enum State { - kUninitialized = 0, - kValid, - kMoved, - kConverted, - kReleased, - kDonated, - kError - }; - - ~ScopedHold(); - ScopedHold(ScopedHold&& other); - ScopedHold(const ScopedHold&) = delete; - ScopedHold& operator=(const ScopedHold&) = delete; - - Type type() const { return type_; } - - absl::Status status() const { - // Lazily create absl::Status values only when they are requested. - switch (state_) { - case kUninitialized: - return InvalidArgument("Buffer has not been initialized"); - case kValid: - return absl::OkStatus(); - case kMoved: - return InvalidArgument("Buffer has been moved."); - case kConverted: - return InvalidArgument("Buffer has been converted"); - case kReleased: - return InvalidArgument("Buffer has been released"); - case kDonated: - return InvalidArgument("Buffer has been donated"); - case kError: - return status_; - default: - CHECK(false) << "Unexpected state value " << state_; - } - } - bool ok() const { return state_ == kValid; } - - // Access to the underlying device buffer storage. Requires this->ok(). - TrackedDeviceBuffer* buffer() const { - CHECK_EQ(state_, kValid); - CHECK_NE(buffer_ptr_, nullptr); - return buffer_ptr_; - } - TrackedDeviceBuffer* operator->() const { return buffer(); } - const TrackedDeviceBuffer& operator*() const { return *buffer(); } - // Converts the hold into a usage event. Only valid for holds of type // kUsage. // @@ -599,58 +500,23 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { std::shared_ptr event, bool reference_held); - // Confirms that the buffer was successfully donated to an execution. - // Only valid for holds of type kDonation. Causes the buffer to become - // invalid. - void ConfirmDonation(); - - // Adds the held device buffers in order to 'iterator'. Used to add the - // buffers to an ExecutionInput. We require but do not verify that - // 'iterator' when passed in is pointing to a sub-tuple of the - // ExecutionInput whose on_device_shape matches that of the - // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run - // out of bounds. Donates the device buffers if the hold type is kDonation, - // otherwise retains ownership of the device buffers. - void AddToInput(ShapeTree::iterator* iterator, - const ShapeTree::iterator& end, - ExecutionInput* execution_input, - se::DeviceMemoryAllocator* allocator) const; + TrackedDeviceBuffer* buffer() const { + return static_cast( + CommonPjRtBuffer::ScopedHold::buffer()); + } + TrackedDeviceBuffer* operator->() const { return buffer(); } + const TrackedDeviceBuffer& operator*() const { return *buffer(); } + + PjRtStreamExecutorBuffer* parent() const { + return static_cast( + CommonPjRtBuffer::ScopedHold::parent()); + } private: + using CommonPjRtBuffer::ScopedHold::ScopedHold; friend class PjRtStreamExecutorBuffer; friend class PjRtStreamExecutorClient; - - ScopedHold(PjRtStreamExecutorBuffer* parent, Type type) - : parent_(parent), type_(type), state_(kUninitialized) {} - - // Sets buffer state. - void SetState(State state) { state_ = state; } - - // Acquires the unique ownership of the buffer. Called by parent_ to - // initialize the donation hold. - void AcquireDonation( - absl::StatusOr> buffer_or); - - // Acquires a non-owning reference of the buffer. Called by parent_ to - // initialize the usage or external reference hold. 
- void AcquireUsageOrExternalReference( - absl::StatusOr buffer_or); - - PjRtStreamExecutorBuffer* const parent_; - const Type type_; - - // There is an invariant that if ok() then - // buffer_.value() != nullptr. - State state_; - absl::Status status_; - // The non-owning pointer to the underlying buffer. It is not nullptr for - // all types of holds. - TrackedDeviceBuffer* buffer_ptr_ = nullptr; - // If it is a donation hold, `buffer_` will not be nullptr. Otherwise, it is - // a nullptr. - std::unique_ptr buffer_; }; - PjRtStreamExecutorBuffer(Shape on_device_shape, std::unique_ptr device_buffer, PjRtClient* client, PjRtDevice* device, @@ -663,6 +529,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete; const Shape& on_device_shape() const override { return on_device_shape_; } + absl::StatusOr logical_on_device_shape() override; PjRtMemorySpace* memory_space() const override { return memory_space_; } PjRtStreamExecutorDevice* device() const override { return device_; } @@ -705,12 +572,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { // external framework drops the reference. void Delete() override; - bool IsDeleted() override; - - // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The - // PjRtBuffer retains ownership of the device buffers. - absl::StatusOr AsShapedBuffer() const; - // Returns a hold on the TrackedDeviceBuffer holding the device // buffers. See comment on ScopedHold. ScopedHold GetBufferWithHold(ScopedHold::Type type); @@ -755,33 +616,10 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { private: friend class PjRtClient; - // Blocks in mu_.Await until there are no more usage holds. - void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Blocks in mu_.Await until there is no donation hold. - void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Adds a hold of 'type' and returns device_buffer_. Returns an error if - // device_buffer_ is null, or if a donation hold was requested when there is - // an outstanding external hold. - // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds() - // must be called first.) - absl::StatusOr> - GetBufferForDonationHoldLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Adds a hold of usage or external reference and returns non-owning - // device_buffer_. Returns an error if device_buffer_ is null. - // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds() - // must be called first.) - absl::StatusOr GetBufferForUsageOrExternalHoldLocked( - ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Adds a hold of hold->type() and initializes `hold` with device_buffer_. - // Initializes hold with an error if device_buffer_ is null, or if a donation - // hold was requested when there is an outstanding external hold. - // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds() - // must be called first.) - void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + TrackedDeviceBuffer* device_buffer() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return static_cast(CommonPjRtBuffer::device_buffer()); + } // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity // check that buffer==device_buffer_ or device_buffer_==nullptr. 
Called after @@ -790,20 +628,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { std::shared_ptr event, bool reference_held); - // Drops a donation hold and makes *this invalid for further use. Does a - // sanity check that buffer==device_buffer_. Called after device_buffer_ was - // successfully donated to an execution. - void ConfirmDonation(TrackedDeviceBuffer* device_buffer); - - // Drops a hold without taking any other action. Does a sanity check that - // buffer==device_buffer_ or device_buffer_==nullptr. - void DropUsageOrExternalHold(ScopedHold::Type type, - TrackedDeviceBuffer* buffer); - - // Drops a hold without taking any other action. Does a sanity check that - // buffer==device_buffer_ or device_buffer_==nullptr. - void DropDonationHold(std::unique_ptr buffer); - absl::StatusOr, std::shared_ptr>> CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device, @@ -819,12 +643,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { const Shape on_device_shape_; PjRtStreamExecutorDevice* const device_; PjRtMemorySpace* const memory_space_; - - mutable absl::Mutex mu_; - std::unique_ptr device_buffer_ ABSL_GUARDED_BY(mu_); - // Count of holds on the buffer. - std::array holds_ ABSL_GUARDED_BY(mu_); - PjRtFuture<>::Promise definition_promise_ ABSL_GUARDED_BY(mu_); }; // Allocates the device buffers for a buffer that will be used as the diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index 83d2231dd0f12c..8be84a770ef4d8 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -286,6 +286,13 @@ void TrackedDeviceBuffer::ReleaseDeviceMemory() { device_memory_ = tsl::RCReference(); } +void TrackedDeviceBuffer::ConfirmDonation() { + // As a sanity check ensure no more usage events can be added to the buffer. + LockUseAndTransferUsageEvents(); + // Release the memory so that no new usage is possible. + ReleaseDeviceMemory(); +} + void TrackedDeviceBuffer::AddUsageEvent( se::Stream* usage_stream, std::shared_ptr event, bool reference_held) { diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.h b/third_party/xla/xla/pjrt/tracked_device_buffer.h index 6af9e6fd5cdbc0..5e47adaacad987 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.h +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" +#include "xla/pjrt/abstract_tracked_device_buffer.h" #include "xla/pjrt/event_pool.h" #include "xla/pjrt/pjrt_client.h" #include "xla/service/executable.h" @@ -229,7 +230,7 @@ class RawSEDeviceMemory : public tsl::ReferenceCounted { // owns all of the device memory in the tuple. It also tracks the definition and // usage of the memory on streams, to allow for synchronized usage and deletion // of memory under all of the allocation model semantics. -class TrackedDeviceBuffer { +class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { public: // Helper object to keep track of usage of the buffer on streams. struct StreamAndEvent { @@ -285,6 +286,9 @@ class TrackedDeviceBuffer { // buffer is passed to a computation that aliases its inputs to outputs. void ReleaseDeviceMemory(); + // Only to be called by ScopedHold to mark a successful donation. + void ConfirmDonation() override; + // Indicates that the buffer has been used on a stream. 
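   // (For context, a sketch assuming the hold API from
   // pjrt_stream_executor_client.h: a usage hold obtained with
   // GetBufferWithUsageHold() typically ends in
   //   hold.ConvertUsageHold(usage_stream, event, reference_held);
   // which records the usage event here.)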
// // usage_stream: a stream that the buffer was used on. From e12c49d7764b5443c31f888f7bf9b87695fc876b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 16:42:39 -0700 Subject: [PATCH 0324/1324] This CL adds missing headers (IWYU) to prepare for upcoming change that will remove gtl/stl-util.h from a protobuf header. PiperOrigin-RevId: 744895422 --- tensorflow/core/platform/tensor_coding.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc index b5aa5ffe150c8e..53328afe0bcf79 100644 --- a/tensorflow/core/platform/tensor_coding.cc +++ b/tensorflow/core/platform/tensor_coding.cc @@ -27,6 +27,7 @@ limitations under the License. #if defined(TENSORFLOW_PROTOBUF_USES_CORD) #include "strings/cord_varint.h" +#include "util/gtl/stl_util.h" #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) namespace tensorflow { From 08b70b43ca9d89a6166659294fa53268072926e4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 16:46:16 -0700 Subject: [PATCH 0325/1324] Fix some unnecessary uses of `const_cast`. The `topology` fields in `PJRT_TopologyDescription_GetDeviceDescriptions_Args` and `PJRT_TopologyDescription_PlatformName_Args` point to an object we don't intend to change. Therefore we can change their type to `const PJRT_TopologyDescription*` to make the code safer and easier to understand - no more worry that someone might be mutating the description object. PiperOrigin-RevId: 744896385 --- third_party/xla/xla/pjrt/c/CHANGELOG.md | 7 +++++++ third_party/xla/xla/pjrt/c/pjrt_c_api.h | 6 +++--- third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc | 4 ++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/CHANGELOG.md b/third_party/xla/xla/pjrt/c/CHANGELOG.md index b6f85b496f6a10..9b8a09a7c392c2 100644 --- a/third_party/xla/xla/pjrt/c/CHANGELOG.md +++ b/third_party/xla/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,12 @@ # PJRT C API changelog +## 0.68 + +* Changed the type of ``topology`` in + ``PJRT_TopologyDescription_PlatformName_Args`` and + ``PJRT_TopologyDescription_GetDeviceDescriptions_Args`` from + ``PJRT_TopologyDescription*`` to ``const PJRT_TopologyDescription*``. + ## 0.67 * Added ``PJRT_Client_DmaMap`` and ``PJRT_Client_DmaUnmap``. diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api.h b/third_party/xla/xla/pjrt/c/pjrt_c_api.h index 85fcd6511dd174..18e86b8b401b95 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api.h @@ -82,7 +82,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 67 +#define PJRT_API_MINOR 68 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -2237,7 +2237,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_PlatformVersion( struct PJRT_TopologyDescription_PlatformName_Args { size_t struct_size; PJRT_Extension_Base* extension_start; - PJRT_TopologyDescription* topology; + const PJRT_TopologyDescription* topology; // `platform_name` has the same lifetime as `topology`. It is owned by // `topology`. 
const char* platform_name; // out @@ -2253,7 +2253,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_PlatformName( struct PJRT_TopologyDescription_GetDeviceDescriptions_Args { size_t struct_size; PJRT_Extension_Base* extension_start; - PJRT_TopologyDescription* topology; + const PJRT_TopologyDescription* topology; // Has the same lifetime as topology. PJRT_DeviceDescription* const* descriptions; // out size_t num_descriptions; // out diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc index ee2c6ad60242fe..121a50f5699d92 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -1074,7 +1074,7 @@ absl::string_view PlatformName(const PJRT_Api* api, PJRT_TopologyDescription_PlatformName_Args args; args.struct_size = PJRT_TopologyDescription_PlatformName_Args_STRUCT_SIZE; args.extension_start = nullptr; - args.topology = const_cast(topo_desc); + args.topology = topo_desc; LogFatalIfPjrtError(api->PJRT_TopologyDescription_PlatformName(&args), api); return {args.platform_name, args.platform_name_size}; } @@ -1085,7 +1085,7 @@ absl::Span DeviceDescriptions( args.struct_size = PJRT_TopologyDescription_GetDeviceDescriptions_Args_STRUCT_SIZE; args.extension_start = nullptr; - args.topology = const_cast(topo_desc); + args.topology = topo_desc; LogFatalIfPjrtError( api->PJRT_TopologyDescription_GetDeviceDescriptions(&args), api); return {args.descriptions, args.num_descriptions}; From fcdd7fb8b4fe18be64ed0bc7e88ce9415630832e Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 7 Apr 2025 17:59:01 -0700 Subject: [PATCH 0326/1324] Migrate deprecated target `pattern_matcher_gmock` for internal users. There may still be users of the old target externally, so I'll hold off from deleting the deprecated target for the broader migration. 
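For reference, only the include path changes for migrated users; the matchers
themselves are identical. A minimal sketch (the instruction and the pattern
are illustrative, not taken from this change):

    #include "xla/hlo/testlib/pattern_matcher_gmock.h"  // new location
    #include "xla/service/pattern_matcher.h"

    namespace m = ::xla::match;

    // Inside a test body, given an HloInstruction* root:
    EXPECT_THAT(root, ::xla::GmockMatch(m::Add(m::Parameter(0),
                                               m::Parameter(1))));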
PiperOrigin-RevId: 744915367 --- third_party/xla/xla/service/BUILD | 11 +++++++++++ .../xla/xla/service/all_reduce_promotion_test.cc | 2 +- .../xla/xla/service/all_reduce_reassociate_test.cc | 2 +- .../xla/xla/service/all_reduce_simplifier_test.cc | 2 +- .../xla/xla/service/change_op_data_type_test.cc | 2 +- third_party/xla/xla/service/dynamic_padder_test.cc | 2 +- third_party/xla/xla/service/gpu/BUILD | 2 +- .../xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc | 2 +- third_party/xla/xla/service/hlo_computation_test.cc | 2 +- .../xla/xla/service/hlo_creation_utils_test.cc | 2 +- third_party/xla/xla/service/hlo_cse_test.cc | 2 +- third_party/xla/xla/service/hlo_instruction_test.cc | 2 +- third_party/xla/xla/service/layout_assignment_test.cc | 2 +- .../xla/xla/service/pattern_matcher_gmock_test.cc | 2 +- 14 files changed, 24 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 14220b62ef0f64..68b110d741389c 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -155,6 +155,7 @@ xla_cc_test( ":pattern_matcher_gmock", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", @@ -196,6 +197,7 @@ xla_cc_test( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -891,6 +893,7 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/tests:xla_internal_test_main", "@local_tsl//tsl/platform:test", ], @@ -936,6 +939,7 @@ xla_cc_test( "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service/gpu:backend_configs_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2145,6 +2149,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", @@ -2626,6 +2631,7 @@ xla_cc_test( "//xla:window_util", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", @@ -3359,6 +3365,7 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier", "//xla/hlo/transforms/simplifiers:hlo_dce", @@ -3638,6 +3645,7 @@ xla_cc_test( "//xla:test_helpers", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -4214,6 +4222,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -4292,6 +4301,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + 
"//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", @@ -6190,6 +6200,7 @@ xla_cc_test( ":change_op_data_type", ":pattern_matcher", ":pattern_matcher_gmock", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/service/all_reduce_promotion_test.cc b/third_party/xla/xla/service/all_reduce_promotion_test.cc index 86d5fde6eb71c5..ad3578755c9d75 100644 --- a/third_party/xla/xla/service/all_reduce_promotion_test.cc +++ b/third_party/xla/xla/service/all_reduce_promotion_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/all_reduce_reassociate_test.cc b/third_party/xla/xla/service/all_reduce_reassociate_test.cc index c0a91a93be215c..d14ec3acdf3d74 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate_test.cc +++ b/third_party/xla/xla/service/all_reduce_reassociate_test.cc @@ -27,9 +27,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/all_reduce_simplifier_test.cc b/third_party/xla/xla/service/all_reduce_simplifier_test.cc index 7048bf20a61639..6850ef1d11b315 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier_test.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier_test.cc @@ -21,9 +21,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/change_op_data_type_test.cc b/third_party/xla/xla/service/change_op_data_type_test.cc index 2bd746b4bc6bdb..24b59e1c460811 100644 --- a/third_party/xla/xla/service/change_op_data_type_test.cc +++ b/third_party/xla/xla/service/change_op_data_type_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include "absl/types/span.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/dynamic_padder_test.cc b/third_party/xla/xla/service/dynamic_padder_test.cc index 13c754482e7995..65ac2512efdcee 100644 --- a/third_party/xla/xla/service/dynamic_padder_test.cc +++ b/third_party/xla/xla/service/dynamic_padder_test.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" @@ -42,7 +43,6 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/dynamic_dimension_inference.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 1884521528cce5..a9ad63be9dcca9 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2259,10 +2259,10 @@ xla_cc_test( "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc index 7f990879906365..40e4092f4f2405 100644 --- a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc @@ -33,10 +33,10 @@ limitations under the License. #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc index 4103dcac62d8d8..ccefc268cb46b3 100644 --- a/third_party/xla/xla/service/hlo_computation_test.cc +++ b/third_party/xla/xla/service/hlo_computation_test.cc @@ -30,10 +30,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/hlo_creation_utils_test.cc b/third_party/xla/xla/service/hlo_creation_utils_test.cc index debabe09c3c51e..4e8f19f031d157 100644 --- a/third_party/xla/xla/service/hlo_creation_utils_test.cc +++ b/third_party/xla/xla/service/hlo_creation_utils_test.cc @@ -26,11 +26,11 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/hlo_cse_test.cc b/third_party/xla/xla/service/hlo_cse_test.cc index a0de5d978a6dda..f33d68bb9ef285 100644 --- a/third_party/xla/xla/service/hlo_cse_test.cc +++ b/third_party/xla/xla/service/hlo_cse_test.cc @@ -27,11 +27,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index 72b46c7861d572..c97003ced8b129 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -39,12 +39,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/layout_util.h" #include "xla/literal_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/hlo.pb.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/layout_assignment_test.cc b/third_party/xla/xla/service/layout_assignment_test.cc index 263efab73ec405..7e9d55e75bb0fb 100644 --- a/third_party/xla/xla/service/layout_assignment_test.cc +++ b/third_party/xla/xla/service/layout_assignment_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/layout.h" #include "xla/layout_util.h" @@ -38,7 +39,6 @@ limitations under the License. #include "xla/service/computation_layout.h" #include "xla/service/logical_buffer.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/pattern_matcher_gmock_test.cc b/third_party/xla/xla/service/pattern_matcher_gmock_test.cc index c0a279537f686d..899f909c983df0 100644 --- a/third_party/xla/xla/service/pattern_matcher_gmock_test.cc +++ b/third_party/xla/xla/service/pattern_matcher_gmock_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "xla/service/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include #include From 8f7cb951ca296d817524714395e0d85e0fcf765f Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Mon, 7 Apr 2025 19:31:56 -0700 Subject: [PATCH 0327/1324] [IFRT] Test a `DeviceAssignment` returned from `ifrt::Client::GetDefaultDeviceAssignment()` Adds a test to check if `ifrt::Client::GetDefaultDeviceAssignment()` returns a valid `DeviceAssignment`. Fixes `GetDefaultDeviceAssignment()` implementation in NanoRt. PiperOrigin-RevId: 744936168 --- third_party/xla/xla/backends/cpu/nanort/BUILD | 1 + .../xla/backends/cpu/nanort/ifrt_client.cc | 45 ++++++++++++++----- third_party/xla/xla/python/ifrt/BUILD | 1 + .../xla/python/ifrt/client_impl_test_lib.cc | 32 +++++++++++++ 4 files changed, 67 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/nanort/BUILD b/third_party/xla/xla/backends/cpu/nanort/BUILD index 78b111fcc9588f..e7965bcfac6fcf 100644 --- a/third_party/xla/xla/backends/cpu/nanort/BUILD +++ b/third_party/xla/xla/backends/cpu/nanort/BUILD @@ -162,6 +162,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc index ed9fc1aa00cd63..ce69d79f708aa7 100644 --- a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc +++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" @@ -791,10 +792,10 @@ class NanoExecutable final GetInputShardings(program_shape, computation)); TF_ASSIGN_OR_RETURN(auto proto_output_shardings, GetOutputShardings(program_shape, computation)); - auto input_shardings = - IfrtShardingsFromProto(client, proto_input_shardings); - auto output_shardings = - IfrtShardingsFromProto(client, proto_output_shardings); + TF_ASSIGN_OR_RETURN(auto input_shardings, + IfrtShardingsFromProto(client, proto_input_shardings)); + TF_ASSIGN_OR_RETURN(auto output_shardings, + IfrtShardingsFromProto(client, proto_output_shardings)); return absl::WrapUnique(new NanoExecutable( client, std::move(computation), std::move(program_shape), @@ -977,8 +978,9 @@ class NanoExecutable final // Converts an OpSharding proto (from an HLO Instruction) to an ifrt // sharding. - static std::vector> IfrtShardingsFromProto( - NanoIfrtClient* client, absl::Span shardings) { + static absl::StatusOr>> + IfrtShardingsFromProto(NanoIfrtClient* client, + absl::Span shardings) { std::vector> result; result.reserve(shardings.size()); for (const auto& sharding : shardings) { @@ -991,10 +993,13 @@ class NanoExecutable final for (const auto dim : sharding.tile_assignment_dimensions()) { num_tiles *= dim; } - // Repeat the device for each tile. We only have one device anyway so - // just used the first. 
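+      // The size check below replaces the single-device shortcut: each tile
+      // is now assigned its own device via devices().subspan(0, num_tiles).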
+ if (num_tiles > client->devices().size()) { + return absl::InvalidArgumentError(absl::StrFormat( + "Sharding has %d tiles, but only %d devices are available.", + num_tiles, client->devices().size())); + } auto device_list = ifrt::BasicDeviceList::Create( - ifrt::BasicDeviceList::Devices(num_tiles, client->devices()[0])); + client->devices().subspan(0, num_tiles)); auto xla_sharding = *HloSharding::FromProto(sharding); result.push_back(ifrt::HloSharding::Create( std::move(device_list), client->devices()[0]->Memories()[0]->Kind(), @@ -1401,7 +1406,20 @@ absl::Span NanoIfrtClient::GetAllDevices() const { absl::StatusOr NanoIfrtClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) const { - return ifrt::DeviceAssignment(num_replicas, num_partitions); + if (num_replicas < 1 || num_partitions < 1) { + return absl::InvalidArgumentError( + absl::StrFormat("Requested device assignment is invalid: %d replicas, " + "%d partitions", + num_replicas, num_partitions)); + } else if (num_replicas * num_partitions > devices_.size()) { + return absl::InvalidArgumentError(absl::StrFormat( + "Requested device assignment is too large for the number of devices " + "available: %d vs. %d", + num_replicas * num_partitions, devices_.size())); + } + ifrt::DeviceAssignment device_assignment(num_replicas, num_partitions); + device_assignment.FillIota(0); + return device_assignment; } absl::StatusOr NanoIfrtClient::LookupDevice( @@ -1411,8 +1429,11 @@ absl::StatusOr NanoIfrtClient::LookupDevice( absl::StatusOr NanoIfrtClient::LookupAddressableDevice( int local_hardware_id) const { - TF_RET_CHECK(local_hardware_id >= 0); - TF_RET_CHECK(local_hardware_id < devices_.size()); + if (local_hardware_id < 0 || local_hardware_id >= devices_.size()) { + return absl::InvalidArgumentError( + absl::StrFormat("Device id %d is out of range [0, %d)", + local_hardware_id, devices_.size())); + } return devices_[local_hardware_id]; } diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 082adf338c0500..fb4ca796153459 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -441,6 +441,7 @@ cc_library( deps = [ ":ifrt", ":test_util", + "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc index 7c76d5caae04ca..719079a77fdc7b 100644 --- a/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/test_util.h" +#include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" @@ -27,6 +28,7 @@ using ::testing::IsEmpty; using ::testing::Not; using ::testing::NotNull; using ::testing::SizeIs; +using ::tsl::testing::IsOk; TEST(ClientImplTest, RuntimeTypeAndPlatform) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); @@ -76,6 +78,36 @@ TEST(ClientImplTest, DefaultCompiler) { EXPECT_THAT(client->GetDefaultCompiler(), NotNull()); } +TEST(ClientImplTest, DefaultDeviceAssignment) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + { + TF_ASSERT_OK_AND_ASSIGN( + auto device_assignment, + client->GetDefaultDeviceAssignment(client->device_count(), 1)); + EXPECT_EQ(device_assignment.replica_count(), client->device_count()); + EXPECT_EQ(device_assignment.computation_count(), 1); + for (int i = 0; i < device_assignment.replica_count(); ++i) { + for (int j = 0; j < device_assignment.computation_count(); ++j) { + EXPECT_THAT(client->LookupDevice(DeviceId(device_assignment(i, j))), + IsOk()); + } + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto device_assignment, + client->GetDefaultDeviceAssignment(1, client->device_count())); + EXPECT_EQ(device_assignment.replica_count(), 1); + EXPECT_EQ(device_assignment.computation_count(), client->device_count()); + for (int i = 0; i < device_assignment.replica_count(); ++i) { + for (int j = 0; j < device_assignment.computation_count(); ++j) { + EXPECT_THAT(client->LookupDevice(DeviceId(device_assignment(i, j))), + IsOk()); + } + } + } +} + } // namespace } // namespace ifrt } // namespace xla From ced99a2ec109a12ec22fbbea87c46ef378658287 Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Mon, 7 Apr 2025 20:00:46 -0700 Subject: [PATCH 0328/1324] Address previous FP8-related TODOs in jaxlib/XLA. The ml_dtype requirement in JAX was updated to version 0.5.0+ (on Mar 20, 2025) - commit 4b7ead4 This update allows us to address previous FP8-related TODOs in jaxlib/XLA. PiperOrigin-RevId: 744943824 --- third_party/xla/xla/python/ifrt/dtype.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/dtype.cc b/third_party/xla/xla/python/ifrt/dtype.cc index 5694e1886d1cda..fdc45de6bcfdb3 100644 --- a/third_party/xla/xla/python/ifrt/dtype.cc +++ b/third_party/xla/xla/python/ifrt/dtype.cc @@ -146,9 +146,8 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { CASE(C64); CASE(C128); CASE(F4E2M1FN); - // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - // CASE(F8E3M4); - // CASE(F8E4M3); + CASE(F8E3M4); + CASE(F8E4M3); CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); @@ -196,9 +195,8 @@ DTypeProto DType::ToProto() const { CASE(C64); CASE(C128); CASE(F4E2M1FN); - // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - // CASE(F8E3M4); - // CASE(F8E4M3); + CASE(F8E3M4); + CASE(F8E4M3); CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); From 50e871f1605e5257c3bbcd333dfeaa07846daa44 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 7 Apr 2025 21:05:16 -0700 Subject: [PATCH 0329/1324] PR #24464: Roll forward "PR #22292: [GPU] Support cuDNN explicit CUDA graph construction." Imported from GitHub PR https://github.com/openxla/xla/pull/24464 The problem that made the original PR get reverted was fixed in cuDNN frontend v1.11.0. 
Copybara import of the project: -- 806011b93a550fd0baed9a78c4af48dbae3ba2b3 by Ilia Sergachev : Roll forward "PR #22292: [GPU] Support cuDNN explicit CUDA graph construction." Merging this change closes #24464 PiperOrigin-RevId: 744961713 --- .../gpu/runtime/command_buffer_cmd.cc | 9 +- third_party/xla/xla/stream_executor/BUILD | 1 + .../xla/xla/stream_executor/command_buffer.h | 10 ++ .../xla/xla/stream_executor/cuda/BUILD | 1 + .../cuda/cuda_command_buffer.cc | 18 +++ .../cuda/cuda_command_buffer.h | 7 ++ .../xla/xla/stream_executor/cuda/cuda_dnn.cc | 62 ++++++++-- .../xla/xla/stream_executor/cuda/cuda_dnn.h | 9 ++ third_party/xla/xla/stream_executor/dnn.h | 6 + third_party/xla/xla/stream_executor/gpu/BUILD | 4 + .../stream_executor/gpu/gpu_command_buffer.cc | 47 ++++++++ .../stream_executor/gpu/gpu_command_buffer.h | 17 ++- .../gpu/gpu_command_buffer_test.cc | 112 ++++++++++++++++++ .../rocm/rocm_command_buffer.h | 11 ++ 14 files changed, 304 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 8d75b609ed1b56..4e8a67d0324e82 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -1153,7 +1153,14 @@ absl::StatusOr CuDnnCmd::Record( VLOG(5) << " Arg: " << arg << ": " << buf.opaque(); operands.push_back(buf); } - + TF_ASSIGN_OR_RETURN( + const bool supports_explicit, + graph_->get()->SupportsExplicitCommandBufferConstruction()); + if (supports_explicit) { + return RecordedCommands::Create(command_buffer->DnnGraph( + *graph_->get(), *execute_params.stream, + absl::Span(operands), {})); + } return RecordedCommands::Create(AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { return graph_->get()->Execute( diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 2a2abc9926e7a7..c346625a666ea3 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -607,6 +607,7 @@ cc_library( deps = [ ":bit_pattern", ":device_memory", + ":dnn", ":kernel", ":launch_dim", "//xla/tsl/platform:errors", diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index d7f4660d05c653..15dbdcb52c1447 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/stream_executor/bit_pattern.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" @@ -165,6 +166,15 @@ class CommandBuffer { const BitPattern& bit_pattern, size_t num_elements) = 0; + // Adds a DNN graph launch command. + virtual absl::StatusOr DnnGraph( + dnn::DnnGraph&, Stream&, absl::Span operands, + absl::Span dependencies) = 0; + + // Updates a DNN graph command. 
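+  // (Presumably keyed by the Command* handle returned from the adding
+  // overload above.)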
+ virtual absl::Status DnnGraph(const Command*, dnn::DnnGraph&, Stream&, + absl::Span operands) = 0; + //--------------------------------------------------------------------------// // Command buffer condtitional commands API //--------------------------------------------------------------------------// diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 80be22c8899a39..472f8e2bb04e6b 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1348,6 +1348,7 @@ cc_library( "//xla/stream_executor:bit_pattern", "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_memory", + "//xla/stream_executor:dnn", "//xla/stream_executor:kernel", "//xla/stream_executor:launch_dim", "//xla/stream_executor:semantic_version", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc index 550e8732859010..8e94a896e25188 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_kernel.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_command_buffer.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/scoped_update_mode.h" @@ -391,6 +392,23 @@ absl::Status CudaCommandBuffer::UpdateMemcpyD2DNode( "Failed to set memcpy d2d node params"); } +absl::Status CudaCommandBuffer::PopulateDnnGraphNode( + dnn::DnnGraph& dnn_graph, Stream& stream, + absl::Span operands) { + return dnn_graph.PopulateOrUpdateRawCommandBuffer(stream, operands, graph_, + false); +} + +absl::Status CudaCommandBuffer::UpdateDnnGraphNode( + dnn::DnnGraph& dnn_graph, Stream& stream, + absl::Span operands, GraphNodeHandle node_handle) { + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuGraphChildGraphNodeGetGraph(ToCudaGraphHandle(node_handle), &graph_))); + is_owned_graph_ = false; + return dnn_graph.PopulateOrUpdateRawCommandBuffer(stream, operands, graph_, + true); +} + absl::StatusOr CudaCommandBuffer::CreateChildNode( absl::Span dependencies, const CommandBuffer& nested) { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h index 5dc8e494e8c6e8..06b95380a91978 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h @@ -123,6 +123,13 @@ class CudaCommandBuffer final : public GpuCommandBuffer { DeviceMemoryBase source, uint64_t size) override; + absl::Status PopulateDnnGraphNode( + dnn::DnnGraph&, Stream&, absl::Span operands) override; + + absl::Status UpdateDnnGraphNode(dnn::DnnGraph&, Stream&, + absl::Span operands, + GraphNodeHandle) override; + absl::StatusOr CreateChildNode( absl::Span dependencies, const CommandBuffer& nested) override; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 53a18f564c8c71..d7135689aa9448 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -8128,12 +8128,11 @@ absl::Status CudnnGraph::Build(dnn::DnnSupport& dnn_support, 
RETURN_CUDNN_FRONTEND_STATUS(graph_.build_plans(cudnn->handle())); } -absl::Status CudnnGraph::Execute(Stream& stream, - absl::Span operands, - int64_t local_device_ordinal) const { - std::unordered_map tensor_to_ptr_map; +CudnnGraph::VariantPack CudnnGraph::PackOperands( + absl::Span operands, DeviceMemoryBase& workspace, + std::optional local_device_ordinal) const { + CudnnGraph::VariantPack tensor_to_ptr_map; absl::Span operands_without_workspace = operands; - DeviceMemoryBase workspace; if (graph_.get_workspace_size() > 0) { workspace = operands.back(); CHECK_EQ(graph_.get_workspace_size(), workspace.size()); @@ -8147,19 +8146,66 @@ absl::Status CudnnGraph::Execute(Stream& stream, } if (dropout_rng_offset_increment_ > 0) { - UpdateDropoutState(local_device_ordinal); + CHECK(local_device_ordinal.has_value()); + UpdateDropoutState(*local_device_ordinal); tensor_to_ptr_map[next_uid()] = (void*)&dropout_rng_seed_; tensor_to_ptr_map[next_uid()] = - (void*)¤t_dropout_rng_offset_[local_device_ordinal]; + (void*)¤t_dropout_rng_offset_[*local_device_ordinal]; } + return tensor_to_ptr_map; +} + +absl::Status CudnnGraph::Execute(Stream& stream, + absl::Span operands, + int64_t local_device_ordinal) const { + DeviceMemoryBase workspace; + VariantPack tensor_to_ptr_map = + PackOperands(operands, workspace, local_device_ordinal); + const CudnnSupport& dnn_support = static_cast(*stream.parent()->AsDnn()); - auto cudnn = dnn_support.cudnn_->GetHandle(stream.parent(), &stream); + CudnnHandle cudnn = dnn_support.cudnn_->GetHandle(stream.parent(), &stream); + RETURN_CUDNN_FRONTEND_STATUS( graph_.execute(cudnn.handle(), tensor_to_ptr_map, workspace.opaque())); } +absl::StatusOr CudnnGraph::SupportsExplicitCommandBufferConstruction() + const { + std::vector notes; + RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.get_behavior_notes(notes)); + bool result = absl::c_any_of(notes, [](cudnn_frontend::BehaviorNote_t n) { + return n == cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API; + }); + if (!result) { + VLOG(5) << "Graph does not support CUDA graph native API:\n" + << graph_.print(); + } + return result; +} + +absl::Status CudnnGraph::PopulateOrUpdateRawCommandBuffer( + Stream& stream, absl::Span operands, + RawCommandBufferHandle cuda_graph, bool do_update) { + DeviceMemoryBase workspace; + VariantPack tensor_to_ptr_map = PackOperands(operands, workspace); + + const CudnnSupport& dnn_support = + static_cast(*stream.parent()->AsDnn()); + CudnnHandle cudnn = dnn_support.cudnn_->GetHandle(stream.parent(), &stream); + + if (do_update) { + RETURN_CUDNN_FRONTEND_STATUS( + graph_.update_cuda_graph(cudnn.handle(), tensor_to_ptr_map, + workspace.opaque(), (cudaGraph_t)cuda_graph)); + } else { + RETURN_CUDNN_FRONTEND_STATUS(graph_.populate_cuda_graph( + cudnn.handle(), tensor_to_ptr_map, workspace.opaque(), + (cudaGraph_t)cuda_graph)); + } +} + } // namespace gpu void initialize_cudnn() { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index 4e214cf4e11ac3..f2338142ea45ec 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -75,12 +75,21 @@ class CudnnGraph : public dnn::DnnGraph { current_dropout_rng_offset_[local_device_ordinal] += dropout_rng_offset_increment_; } + absl::StatusOr SupportsExplicitCommandBufferConstruction() + const override; + absl::Status PopulateOrUpdateRawCommandBuffer( + Stream&, absl::Span operands, RawCommandBufferHandle, + bool do_update) 
override; private: cudnn_frontend::graph::Graph graph_; int64_t dropout_rng_seed_; mutable std::vector current_dropout_rng_offset_; int64_t dropout_rng_offset_increment_ = 0; + using VariantPack = std::unordered_map; + VariantPack PackOperands( + absl::Span operands, DeviceMemoryBase& workspace, + std::optional local_device_ordinal = std::nullopt) const; }; // cudnn-library based DNN support. For details on overridden interface diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index e644eb81e34775..6e3fa58c757632 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -1105,6 +1105,12 @@ class DnnGraph { int64_t local_device_ordinal) const = 0; virtual void InitDropoutState(int64_t local_device_count, int64_t seed, int64_t increment) = 0; + virtual absl::StatusOr SupportsExplicitCommandBufferConstruction() + const = 0; + using RawCommandBufferHandle = void*; + virtual absl::Status PopulateOrUpdateRawCommandBuffer( + Stream&, absl::Span operands, RawCommandBufferHandle, + bool do_update) = 0; }; using LazyDnnGraph = std::unique_ptr; diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 75378d3960246b..6d03c89b6c081e 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -171,6 +171,7 @@ gpu_only_cc_library( "//xla/stream_executor:bit_pattern", "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_memory", + "//xla/stream_executor:dnn", "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", @@ -645,6 +646,7 @@ xla_test( "//xla/service:platform_util", "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_memory", + "//xla/stream_executor:dnn", "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", @@ -671,8 +673,10 @@ xla_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", + "@cudnn_frontend_archive//:cudnn_frontend", ] + if_cuda([ "//xla/stream_executor/cuda:cuda_platform", + "//xla/stream_executor/cuda:cudnn_plugin", ]) + if_rocm([ "//xla/stream_executor/rocm:rocm_platform", ]), diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 76d19bef40cfcc..e1289b3abdfc93 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "xla/stream_executor/bit_pattern.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -332,6 +333,52 @@ absl::Status GpuCommandBuffer::Memset(const Command* command, return UpdateMemsetNode(gpu_command->handle, *dst, bit_pattern, num_elements); } +absl::StatusOr GpuCommandBuffer::DnnGraph( + dnn::DnnGraph& dnn_graph, Stream& stream, + absl::Span operands, + absl::Span dependencies) { + TF_RETURN_IF_ERROR(CheckNotFinalized()); + + if (state_ == State::kCreate) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr nested, + stream.parent()->CreateCommandBuffer(CommandBuffer::Mode::kNested)); + GpuCommandBuffer& nested_gpu = + tensorflow::down_cast(*nested); + TF_RETURN_IF_ERROR( + nested_gpu.PopulateDnnGraphNode(dnn_graph, stream, operands)); + Dependencies barrier = dependencies.empty() + ? GetAutoDependencies() + : ToGraphNodeDependencies(dependencies); + TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, + CreateChildNode(barrier, *nested)); + return AppendCommand(handle); + } + + if (state_ == State::kUpdate) { + Command& command = *commands_[update_state_.command_idx++]; + TF_RETURN_IF_ERROR(DnnGraph(&command, dnn_graph, stream, operands)); + return &command; + } + + return UnsupportedStateError(state_); +} + +absl::Status GpuCommandBuffer::DnnGraph(const Command* command, + dnn::DnnGraph& dnn_graph, + Stream& stream, + absl::Span operands) { + auto* gpu_command = tsl::down_cast(command); + TF_ASSIGN_OR_RETURN( + std::unique_ptr nested, + stream.parent()->CreateCommandBuffer(CommandBuffer::Mode::kNested)); + GpuCommandBuffer& nested_gpu = + tensorflow::down_cast(*nested); + TF_RETURN_IF_ERROR(nested_gpu.UpdateDnnGraphNode(dnn_graph, stream, operands, + gpu_command->handle)); + return UpdateChildNode(gpu_command->handle, *nested); +} + //--------------------------------------------------------------------------// // Command buffer condtitional commands API //--------------------------------------------------------------------------// diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 3feef005f981da..14f9da3e57b556 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -33,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/gpu/scoped_update_mode.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/casts.h" @@ -161,6 +162,13 @@ class GpuCommandBuffer : public CommandBuffer { DeviceMemory index, std::vector branches, absl::Span dependencies) override; + absl::StatusOr DnnGraph( + dnn::DnnGraph&, Stream&, absl::Span operands, + absl::Span dependencies) override; + + absl::Status DnnGraph(const Command*, dnn::DnnGraph&, Stream&, + absl::Span operands) override; + absl::StatusOr Case( DeviceMemory index, std::vector branches, absl::Span dependencies) override; @@ -320,7 +328,7 @@ class GpuCommandBuffer : public CommandBuffer { DeviceMemoryBase destination, BitPattern bit_pattern, size_t num_elements) = 0; - // Updates an existing memset node. Note that `node_handle` needs to be refer + // Updates an existing memset node. 
Note that `node_handle` needs to refer // to a node created by `CreateMemsetNode`. virtual absl::Status UpdateMemsetNode(GraphNodeHandle node_handle, DeviceMemoryBase destination, @@ -337,6 +345,13 @@ class GpuCommandBuffer : public CommandBuffer { DeviceMemoryBase source, uint64_t size) = 0; + virtual absl::Status PopulateDnnGraphNode( + dnn::DnnGraph&, Stream&, absl::Span operands) = 0; + + virtual absl::Status UpdateDnnGraphNode(dnn::DnnGraph&, Stream&, + absl::Span operands, + GraphNodeHandle) = 0; + // Adds a new nested command buffer node to the graph. virtual absl::StatusOr CreateChildNode( absl::Span dependencies, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index 352d08a5aabc12..ffa9d26a31e740 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -24,10 +25,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/ascii.h" #include "absl/types/span.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/cuda/cuda_dnn.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -47,6 +51,9 @@ limitations under the License. namespace stream_executor::gpu { +using ::testing::Each; +using ::tsl::testing::IsOkAndHolds; + static Platform* GpuPlatform() { auto name = absl::AsciiStrToUpper( xla::PlatformUtil::CanonicalPlatformName("gpu").value()); @@ -784,6 +791,111 @@ TEST(GpuCommandBufferTest, DISABLED_WhileNestedConditional) { ASSERT_EQ(dst, expected); } +TEST(GpuCommandBufferTest, CuDnnExplicitConstructionAndUpdateWork) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + dnn::DnnSupport& dnn_support = *executor->AsDnn(); + + if (dnn_support.GetVersion().value_or(dnn::VersionInfo{0, 0, 0}) < + dnn::VersionInfo(9, 7, 0)) { + GTEST_SKIP() << "Requires cuDNN 9.7.0 or later."; + } + + constexpr int dim_size = 32; + constexpr int total_elements = dim_size * dim_size; + + CudnnGraph graph([]() { + cudnn_frontend::graph::Graph graph; + graph.set_compute_data_type(cudnn_frontend::DataType_t::INT32); + std::shared_ptr lhs = + graph.tensor(cudnn_frontend::graph::Tensor_attributes() + .set_dim({1, dim_size, dim_size}) + .set_stride({dim_size * dim_size, dim_size, 1}) + .set_data_type(cudnn_frontend::DataType_t::INT8) + .set_uid(1)); + std::shared_ptr rhs = + graph.tensor_like(lhs); + rhs->set_uid(2); + graph.matmul(lhs, rhs, cudnn_frontend::graph::Matmul_attributes()) + ->set_output(true) + .set_data_type(cudnn_frontend::DataType_t::INT32) + .set_uid(3); + return graph; + }()); + TF_ASSERT_OK(graph.Prepare(dnn_support, NumericOptions{})); + TF_ASSERT_OK(graph.Build(dnn_support, /*plan_id=*/std::nullopt)); + EXPECT_THAT(graph.SupportsExplicitCommandBufferConstruction(), + IsOkAndHolds(true)); + + DeviceMemory input = executor->AllocateArray(total_elements); + 
TF_ASSERT_OK(stream->MemZero(&input, input.size())); + DeviceMemory output0 = + executor->AllocateArray(total_elements); + DeviceMemoryBase workspace; + std::vector operands; + operands.reserve(4); + operands.push_back(input); // multiplying the input by itself + operands.push_back(input); + operands.push_back(output0); + if (graph.Graph().get_workspace_size() > 0) { + workspace = executor->Allocate(graph.Graph().get_workspace_size()); + operands.push_back(workspace); + } + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr cmd_buffer, + executor->CreateCommandBuffer(primary)); + TF_ASSERT_OK( + cmd_buffer + ->DnnGraph(graph, *stream, absl::Span(operands), {}) + .status()); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + std::vector host_buffer(output0.ElementCount()); + + // Initialize and check the output before execution. + TF_ASSERT_OK(stream->Memset32(&output0, 123, output0.size())); + TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output0, output0.size())); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + EXPECT_THAT(host_buffer, Each(123)); + + // Run the computation. + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); + + // Check the output after execution. + TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output0, output0.size())); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + EXPECT_THAT(host_buffer, Each(0)); + + // Swap the output buffer. + DeviceMemory output1 = + executor->AllocateArray(total_elements); + operands[2] = output1; + executor->Deallocate(&output0); + + // Initialize and check the output before execution. + TF_ASSERT_OK(stream->Memset32(&output1, 456, output1.size())); + TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output1, output1.size())); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + EXPECT_THAT(host_buffer, Each(456)); + + // Update the command buffer to write into the new output buffer. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK( + cmd_buffer + ->DnnGraph(graph, *stream, absl::Span(operands), {}) + .status()); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + // Run the computation. + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); + + // Check the output after execution. + TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output1, output1.size())); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + EXPECT_THAT(host_buffer, Each(0)); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h index 299fcf93d66fdc..2d58554b1f8413 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h @@ -106,6 +106,17 @@ class RocmCommandBuffer : public GpuCommandBuffer { DeviceMemoryBase source, uint64_t size) override; + absl::Status PopulateDnnGraphNode( + dnn::DnnGraph&, Stream&, absl::Span operands) override { + return absl::UnimplementedError("Not implemented."); + } + + absl::Status UpdateDnnGraphNode(dnn::DnnGraph&, Stream&, + absl::Span operands, + GraphNodeHandle) override { + return absl::UnimplementedError("Not implemented."); + } + absl::StatusOr CreateChildNode( absl::Span dependencies, const CommandBuffer& nested) override; From 78f46ec25f561c07032002e041f821a1d8b43e69 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 7 Apr 2025 21:24:19 -0700 Subject: [PATCH 0330/1324] Add global shape for easier debugging PiperOrigin-RevId: 744966730 --- third_party/xla/xla/python/ifrt/ir/sharding_param.cc | 4 ++-- third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/ir/sharding_param.cc b/third_party/xla/xla/python/ifrt/ir/sharding_param.cc index e40ac373e8d3bd..04df160a9e9022 100644 --- a/third_party/xla/xla/python/ifrt/ir/sharding_param.cc +++ b/third_party/xla/xla/python/ifrt/ir/sharding_param.cc @@ -274,8 +274,8 @@ ShardingParam::LocalShapeFromGlobalShape( if (global_shape[i] % num_shards[i] != 0) { return absl::InvalidArgumentError(absl::StrCat( "Global shape is not divisible by the number of shards in dimension ", - i, ". Global size: ", global_shape[i], - ", number of shards: ", num_shards[i], ".")); + i, ". Global shape: [", absl::StrJoin(global_shape, ","), + "], number of shards: ", num_shards[i], ".")); } local_shape.push_back(global_shape[i] / num_shards[i]); } diff --git a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir index 4fef0876dc8bb8..29977603221a7d 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir @@ -166,7 +166,7 @@ module @missing_entry_function #device = #ifrt #sharding = #ifrt.sharding_param<2x1 to [0] on 2> module @non_divisible_global_shape attributes {ifrt.num_devices = 2} { - // expected-error@+1 {{Global shape is not divisible by the number of shards in dimension 0. Global size: 3, number of shards: 2}} + // expected-error@+1 {{Global shape is not divisible by the number of shards in dimension 0. Global shape: [3,2], number of shards: 2.}} func.func @main( %arg0: tensor<3x2xi32> {ifrt.sharding = #sharding, ifrt.devices = #device}) From d3328901ea11a35cbc5436f54ff25f9c4e529a26 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 21:47:58 -0700 Subject: [PATCH 0331/1324] Automated Code Change PiperOrigin-RevId: 744972752 --- third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc | 1 - third_party/xla/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc index a3e36d1a5d1d63..2309c12cb35f90 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include -#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc index 15c56fe3e330a5..8dff15655a6a2a 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include #include -#include "absl/functional/bind_front.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/client.h" From a680792a496d6664d4bff29dc5b58ca6c29f8e3a Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Mon, 7 Apr 2025 22:03:29 -0700 Subject: [PATCH 0332/1324] Move TFL::StridedSliceOp and TFL::SliceOp rank constraints to runtime checks This relaxes these constraints for other non-runtime backends. PiperOrigin-RevId: 744976915 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.cc | 32 +++++++++++++------- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 6 ++-- tensorflow/compiler/mlir/lite/tests/ops.mlir | 8 +++++ 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index ddb9a9d1017a63..0b8cda7329f11d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -4442,6 +4442,27 @@ int64_t TransposeConvOp::GetArithmeticCount(Operation* op) { // StridedSliceOp //===----------------------------------------------------------------------===// +bool VerifyStridedSliceOpInputRankConstraints(StridedSliceOp op) { + auto ranked_input_type = + mlir::dyn_cast(op.getInput().getType()); + + // If input is unranked, there is nothing else to be verified. + if (!ranked_input_type) return true; + const int num_input_dims = ranked_input_type.getRank(); + + // The kernel will reshape the input tensor with new axis, it only supports + // this reshaped tensor up to 5D. + const uint32_t ellipsis_mask = op.getEllipsisMask(); + const uint32_t new_axis_mask = op.getNewAxisMask(); + int num_added_axis = 0; + for (int i = 0; i < 8; ++i) { + if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) { + num_added_axis++; + } + } + return (num_input_dims + num_added_axis <= 5); +} + LogicalResult StridedSliceOp::verify() { StridedSliceOp op = *this; auto ranked_input_type = @@ -4468,17 +4489,6 @@ LogicalResult StridedSliceOp::verify() { if (strides_type.getDimSize(0) > num_input_dims) return failure(); } - // The kernel will reshape the input tensor with new axis, it only supports - // this reshaped tensor up to 5D. 
- uint32_t ellipsis_mask = op.getEllipsisMask(); - uint32_t new_axis_mask = op.getNewAxisMask(); - int num_added_axis = 0; - for (int i = 0; i < 8; ++i) { - if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) { - num_added_axis++; - } - } - if (num_input_dims + num_added_axis > 5) return failure(); return success(); } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 8ffd327e94583b..87f6754b4cc9da 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -2430,7 +2430,8 @@ def TFL_SliceOp : TFL_Op<"slice", [ TFL_TCresVTEtIsSameAsOp<0, 0>>, Pure, SameOperandsAndResultsScale, - TFL_OperandHasRankAtMost<0, 5>, + TFL_RuntimePredOpTrait<"input must have rank at most 5", + TFL_OperandHasRankAtMostPred<0, 5>>, TFL_OperandHasRankAtMost<1, 1>, TFL_OperandHasRankAtMost<2, 1>]> { let summary = "Return a slice from 'input'."; @@ -3979,7 +3980,8 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultsScale, - TFL_OperandHasRankAtMost<0, 5>, + TFL_RuntimePredOpTrait<"input (with new_axis) must have rank at most 5", + CPred<"TFL::VerifyStridedSliceOpInputRankConstraints(llvm::cast($_op))">>, TFL_OperandHasRank<1, 1>, TFL_OperandHasRank<2, 1>, TFL_OperandHasRank<3, 1> diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 772a462747d01a..c2096033859fd4 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1759,6 +1759,14 @@ func.func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %a // ----- +func.func @testStridedSliceWithInvalidInputRank(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x1x1x2x2x5xf32> { + // expected-error @+1 {{op failed to verify that input (with new_axis) must have rank at most 5}} + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 6 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x1x1x2x2x5xf32> + func.return %0 : tensor<1x1x1x2x2x5xf32> +} + +// ----- + // CHECK-LABEL: testOneHot func.func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xf32> { // CHECK: "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) <{axis = -1 : i32}> : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32> From ca9a6ab219aa2ae84dc664e68f10b15c2c7ebc76 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Mon, 7 Apr 2025 22:36:01 -0700 Subject: [PATCH 0333/1324] Refactor spmd dot handler on the `PadWithZero`. 
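Both code paths previously called PadWithZero() on both operands inside the branch that decides which operand to reshard; the shared padding is hoisted out of the branch so each operand is resharded first and padded exactly once. A condensed sketch of the restructuring (the condition and reshard targets are abbreviated for illustration; see the diff for the exact expressions):

  // Before: padding duplicated across branches.
  if (lhs_is_smaller) {
    lhs = lhs.Reshard(matching_sharding).PadWithZero();
    rhs = rhs.PadWithZero();
  } else {
    lhs = lhs.PadWithZero();
    rhs = rhs.Reshard(matching_sharding).PadWithZero();
  }

  // After: branches only pick which side to reshard; padding happens once.
  if (lhs_is_smaller) {
    lhs = lhs.Reshard(matching_sharding);
  } else {
    rhs = rhs.Reshard(matching_sharding);
  }
  lhs = lhs.PadWithZero();
  rhs = rhs.PadWithZero();

Also renames rhs_constracting_fully_partitioned to rhs_contracting_fully_partitioned and sizes the lhs_contracting_dims reserve() calls by dims_mapping.contracting_dims.size() instead of the operand rank.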
PiperOrigin-RevId: 744984992 --- .../xla/xla/service/spmd/dot_handler.cc | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc index 7bfe5d4c19a515..b9105c4157d9b9 100644 --- a/third_party/xla/xla/service/spmd/dot_handler.cc +++ b/third_party/xla/xla/service/spmd/dot_handler.cc @@ -633,7 +633,7 @@ std::optional GetWindowedEinsumConfiguration( computation_time_in_ms = visitor->GetComputationTimeInMilliSec(dot); std::vector lhs_contracting_dims; - lhs_contracting_dims.reserve(new_lhs.base_shape().dimensions_size()); + lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size()); for (const auto& cd : dims_mapping.contracting_dims) { lhs_contracting_dims.push_back(cd.lhs); } @@ -1989,16 +1989,17 @@ absl::StatusOr PartitionBaseCase( // on the other side. if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) { - lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithZero(); - rhs = rhs.PadWithZero(); + lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); } else { - lhs = lhs.PadWithZero(); - rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithZero(); + rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); } + + lhs = lhs.PadWithZero(); + rhs = rhs.PadWithZero(); TF_ASSIGN_OR_RETURN( auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window)); std::vector lhs_contracting_dims; - lhs_contracting_dims.reserve(lhs.base_shape().dimensions_size()); + lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size()); for (const auto& cd : dims_mapping.contracting_dims) { lhs_contracting_dims.push_back(cd.lhs); } @@ -2064,7 +2065,7 @@ absl::StatusOr PartitionBaseCase( // If LHS and output are replicated, we compare the cost of all-gather on RHS // vs all-reduce on the output. 
- const bool rhs_constracting_fully_partitioned = + const bool rhs_contracting_fully_partitioned = (rhs_contracting_partitions == num_partitions) && lhs.sharding().IsReplicated() && ShapeUtil::ElementsIn(rhs.base_shape()) > @@ -2075,21 +2076,22 @@ absl::StatusOr PartitionBaseCase( ShapeUtil::ElementsIn(lhs.base_shape()) > ShapeUtil::ElementsIn(output_base_shape); - if (rhs_constracting_fully_partitioned) { - lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithZero(); - rhs = rhs.PadWithZero(); + if (rhs_contracting_fully_partitioned) { + lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); } else if (lhs_contracting_fully_partitioned) { - lhs = lhs.PadWithZero(); - rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithZero(); + rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); } else { return nullptr; } + lhs = lhs.PadWithZero(); + rhs = rhs.PadWithZero(); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window)); std::vector lhs_contracting_dims; - lhs_contracting_dims.reserve(lhs.base_shape().dimensions_size()); + lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size()); for (const auto& cd : dims_mapping.contracting_dims) { lhs_contracting_dims.push_back(cd.lhs); } @@ -2863,8 +2865,7 @@ absl::StatusOr PartitionDotGroupOnContractingImpl( } lhs_skipped_dims.push_back(i); } - lhs = lhs.PadWithZero( - /*left_padded_dims=*/{}, lhs_skipped_dims); + lhs = lhs.PadWithZero(/*left_padded_dims=*/{}, lhs_skipped_dims); std::vector rhs_skipped_dims; for (int64_t i = 0; i < rhs.base_shape().dimensions_size(); ++i) { if (absl::c_linear_search(rhs_dims, i)) { @@ -2872,8 +2873,7 @@ absl::StatusOr PartitionDotGroupOnContractingImpl( } rhs_skipped_dims.push_back(i); } - rhs = rhs.PadWithZero( - /*left_padded_dims=*/{}, rhs_skipped_dims); + rhs = rhs.PadWithZero(/*left_padded_dims=*/{}, rhs_skipped_dims); top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding); lhs.hlo()->set_sharding(lhs_grouped.sharding); top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding); From d95400cccca9a4acf76fd9e9ec547fc8e3525cf3 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 7 Apr 2025 22:46:18 -0700 Subject: [PATCH 0334/1324] [xla:gpu] NFC: improve helper function for nested gemm/concat fusions. Follow-up from cl/743505373. PiperOrigin-RevId: 744987309 --- .../gpu/transforms/nest_gemm_fusion.cc | 82 +++++++++---------- 1 file changed, 40 insertions(+), 42 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index ec2530a6c098d1..776d0c83f3c724 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -68,13 +68,13 @@ limitations under the License. namespace xla::gpu { namespace { -// Fuses the given instructions together. The instructions are expected to be -// passed in def-before-use order. The resulting fusion has a single root -// instruction, which is the last instructions in the input span. We only -// replace the uses of the root in 'consumer', and leave other users alone. -absl::Status FuseInstructionsForConsumer( - absl::Span instructions, HloInstruction& consumer) { - HloComputation::Builder builder(instructions.back()->name()); + +// Creates a fusion for instructions starting from 'root' and returns it. 
+absl::StatusOr FuseInstructionsFromRoot(HloInstruction& root) { + std::vector instructions = + root.parent()->MakeInstructionPostOrderFrom(root); + + HloComputation::Builder builder(root.name()); absl::flat_hash_map old_to_new_mapping; @@ -108,27 +108,37 @@ absl::Status FuseInstructionsForConsumer( old_to_new_mapping[instruction] = builder.AddInstruction( instruction->CloneWithNewOperands(instruction->shape(), new_operands)); } - - HloInstruction* old_root = instructions.back(); - old_to_new_mapping[old_root]->MarkAsRoot(); + old_to_new_mapping[&root]->MarkAsRoot(); HloComputation* computation = - old_root->GetModule()->AddComputationAndUnifyNamesAndIds( - builder.Build(), /*is_entry=*/false); + root.GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), + /*is_entry=*/false); HloInstruction* fusion = - old_root->parent()->AddInstruction(HloInstruction::CreateFusion( - old_root->shape(), HloInstruction::FusionKind::kCustom, parameters, + root.parent()->AddInstruction(HloInstruction::CreateFusion( + root.shape(), HloInstruction::FusionKind::kCustom, parameters, computation)); fusion->GetModule()->SetAndUniquifyInstrName(fusion, "block_fusion"); + return fusion; +} + +// Fuses the instructions starting from 'root' for 'consumer'. Other users of +// 'root' are not affected. Annotates fusion with `kTritonNestedGemmFusionKind`. +absl::Status FuseInstructionsForConsumer(HloInstruction& root, + HloInstruction& consumer) { + CHECK(absl::c_count(consumer.operands(), &root) != 0) + << "Consumer " << consumer.ToString() << " does not use root " + << root.ToString(); + + TF_ASSIGN_OR_RETURN(HloInstruction * fusion, FuseInstructionsFromRoot(root)); + TF_ASSIGN_OR_RETURN(auto gpu_config, fusion->backend_config()); - FusionBackendConfig& backend_config = - *gpu_config.mutable_fusion_backend_config(); - backend_config.set_kind(std::string(kTritonNestedGemmFusionKind)); + gpu_config.mutable_fusion_backend_config()->set_kind( + std::string(kTritonNestedGemmFusionKind)); TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); - for (int64_t operand_index : consumer.OperandIndices(old_root)) { + for (int64_t operand_index : consumer.OperandIndices(&root)) { TF_RETURN_IF_ERROR(consumer.ReplaceOperandWith(operand_index, fusion)); } @@ -173,12 +183,12 @@ absl::Status AnnotateDotOperandNestedFusionImpl( block_level_parameters.num_ctas = config.num_ctas; block_level_parameters.num_stages = config.num_stages; - TF_ASSIGN_OR_RETURN(auto backend_config, + TF_ASSIGN_OR_RETURN(auto gpu_config, nested_fusion.backend_config()); - *backend_config.mutable_fusion_backend_config() + *gpu_config.mutable_fusion_backend_config() ->mutable_block_level_fusion_config() = block_level_parameters.ToBlockLevelFusionConfig(); - TF_RETURN_IF_ERROR(nested_fusion.set_backend_config(backend_config)); + TF_RETURN_IF_ERROR(nested_fusion.set_backend_config(gpu_config)); return absl::OkStatus(); } @@ -298,21 +308,11 @@ absl::StatusOr GetTritonGemmConfig( // and annotates them with `kTritonNestedGemmFusionKind`. 
absl::Status FuseAndAnnotateConcatOperands(HloComputation* computation) { for (HloInstruction* instr : computation->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kConcatenate) { - for (int i = 0; i < instr->operand_count(); ++i) { - TF_RETURN_IF_ERROR(FuseInstructionsForConsumer( - computation->MakeInstructionPostOrderFrom( - *instr->mutable_operand(i)), - *instr)); - HloInstruction* new_operand = instr->mutable_operand(i); - TF_ASSIGN_OR_RETURN(auto gpu_config, - new_operand->backend_config()); - FusionBackendConfig& backend_config = - *gpu_config.mutable_fusion_backend_config(); - backend_config.clear_triton_gemm_config(); - backend_config.set_kind(std::string(kTritonNestedGemmFusionKind)); - TF_RETURN_IF_ERROR(new_operand->set_backend_config(gpu_config)); - } + if (instr->opcode() != HloOpcode::kConcatenate) { + continue; + } + for (HloInstruction* operand : instr->mutable_operands()) { + TF_RETURN_IF_ERROR(FuseInstructionsForConsumer(*operand, *instr)); } } return absl::OkStatus(); @@ -333,17 +333,15 @@ absl::Status MakeNestedFusionFromGemmFusion(HloFusionInstruction* fusion, TF_RETURN_IF_ERROR(FuseAndAnnotateConcatOperands(computation)); // Left-hand side of the dot. - TF_RETURN_IF_ERROR(FuseInstructionsForConsumer( - computation->MakeInstructionPostOrderFrom(*dot->mutable_operand(0)), - *dot)); + TF_RETURN_IF_ERROR( + FuseInstructionsForConsumer(*dot->mutable_operand(0), *dot)); TF_RETURN_IF_ERROR(AnnotateDotLhsNestedFusion( *::xla::Cast(dot->mutable_operand(0)), *dot, config)); // Right-hand side of the dot. - TF_RETURN_IF_ERROR(FuseInstructionsForConsumer( - computation->MakeInstructionPostOrderFrom(*dot->mutable_operand(1)), - *dot)); + TF_RETURN_IF_ERROR( + FuseInstructionsForConsumer(*dot->mutable_operand(1), *dot)); TF_RETURN_IF_ERROR(AnnotateDotRhsNestedFusion( *::xla::Cast(dot->mutable_operand(1)), *dot, config)); From 531d1fd0bc24fa2678e2fc332036c70009d7e9a7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 22:52:52 -0700 Subject: [PATCH 0335/1324] Automated Code Change PiperOrigin-RevId: 744988839 --- .../mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc index f2ad3562a9f3fd..d07f3178552f76 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include + #include "llvm/ADT/SmallVector.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" From afa14754739afcf5c3319ebf012992f85115d22f Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 7 Apr 2025 23:22:35 -0700 Subject: [PATCH 0336/1324] Reverts c2032e84aec1063e48007e528b53793c5e368b1c PiperOrigin-RevId: 744995845 --- third_party/xla/xla/hlo/analysis/hlo_reachability.cc | 7 ++----- third_party/xla/xla/hlo/analysis/hlo_reachability.h | 3 +-- .../collectives/collectives_schedule_linearizer.cc | 5 ++--- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_reachability.cc b/third_party/xla/xla/hlo/analysis/hlo_reachability.cc index d059ece78afa7f..a843013b9eb31d 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_reachability.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_reachability.cc @@ -110,14 +110,11 @@ std::unique_ptr HloReachabilityMap::BuildWithRestrictions( } std::unique_ptr HloReachabilityMap::Build( - const HloComputation* computation, - const std::vector& po_instructions) { + const HloComputation* computation) { HloComputation::ChannelDependencies channel_dependencies = computation->ComputeChannelDependencies(); std::vector instructions = - po_instructions.empty() - ? computation->MakeInstructionPostOrder(channel_dependencies) - : po_instructions; + computation->MakeInstructionPostOrder(channel_dependencies); auto result = std::make_unique(instructions); auto get_bit_set = [&](const HloInstruction* instruction) -> BitSet& { diff --git a/third_party/xla/xla/hlo/analysis/hlo_reachability.h b/third_party/xla/xla/hlo/analysis/hlo_reachability.h index 74418975b22d7a..68faafe43cab15 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_reachability.h +++ b/third_party/xla/xla/hlo/analysis/hlo_reachability.h @@ -55,8 +55,7 @@ class HloReachabilityMap { // dependencies (operands) and control dependencies are considered for // reachability. Trivially an instruction is reachable from itself. static std::unique_ptr Build( - const HloComputation* computation, - const std::vector& po_instructions = {}); + const HloComputation* computation); // Similar to the above Build operation except that it tries to identify // paths between instructions that do not contain control instructions diff --git a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc index 0d49977df1bc69..f5acfa5a8672c3 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc @@ -43,15 +43,14 @@ absl::StatusOr CollectivesScheduleLinearizer::Run( module->MakeNonfusionComputations(execution_threads)) { std::unique_ptr reachability; HloInstruction* prev_done = nullptr; - auto post_order = computation->MakeInstructionPostOrder(); - for (HloInstruction* inst : post_order) { + for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { auto* next = DynCast(inst); if (!next) { continue; } // Build reachability map on demand if we actually see collectives. if (!reachability) { - reachability = HloReachabilityMap::Build(computation, post_order); + reachability = HloReachabilityMap::Build(computation); } // Derive the 'start' and 'done' peers of this instruction. 
For non-async // variants of collectives, they are the same as this instruction. For From 9b0bf248c816911a6f7d4c04a5be19c576cef4ff Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 23:29:41 -0700 Subject: [PATCH 0337/1324] Automated Code Change PiperOrigin-RevId: 744997474 --- tensorflow/lite/core/async/task_internal_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/lite/core/async/task_internal_test.cc b/tensorflow/lite/core/async/task_internal_test.cc index b0dc1ae385917f..68e8004fa5d434 100644 --- a/tensorflow/lite/core/async/task_internal_test.cc +++ b/tensorflow/lite/core/async/task_internal_test.cc @@ -14,8 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/core/async/task_internal.h" -#include - #include #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" From 7a15faa4828de40c1409845f8f031431092c97d2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Apr 2025 23:48:20 -0700 Subject: [PATCH 0338/1324] Automated Code Change PiperOrigin-RevId: 745002346 --- third_party/xla/xla/backends/gpu/codegen/triton/BUILD | 10 ++++++++++ .../xla/backends/gpu/codegen/triton/emitter_helpers.cc | 1 + .../xla/backends/gpu/codegen/triton/emitter_helpers.h | 1 + .../xla/backends/gpu/codegen/triton/fusion_emitter.cc | 2 -- .../triton/fusion_emitter_device_legacy_port_test.cc | 1 + .../triton/fusion_emitter_device_legacy_test.cc | 1 + .../gpu/codegen/triton/fusion_emitter_large_test.cc | 2 ++ .../gpu/codegen/triton/fusion_emitter_legacy_matmul.cc | 2 ++ .../codegen/triton/fusion_emitter_parametrized_test.cc | 1 + .../gpu/codegen/triton/kernel_name_tracer_cuda.cc | 1 + .../xla/xla/backends/gpu/codegen/triton/support.cc | 1 - .../xla/xla/backends/gpu/codegen/triton/test_utils.cc | 2 ++ .../xla/xla/backends/gpu/codegen/triton/test_utils.h | 2 ++ .../xla/backends/gpu/codegen/triton/tma_utils_test.cc | 1 + 14 files changed, 25 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 72c937aa373ac1..f5789d3f3c3677 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -98,6 +98,7 @@ cc_library( deps = [ "//xla:literal", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/codegen:emitter_loc_op_builder", "//xla/hlo/ir:hlo", @@ -296,6 +297,7 @@ cc_library( deps = [ ":dot_algorithms", ":emitter_helpers", + "//xla:autotuning_proto_cc", "//xla:comparison_util", "//xla:literal", "//xla:shape_util", @@ -310,6 +312,7 @@ cc_library( "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/mlir_hlo:transformation_helpers", "//xla/service:algorithm_util", + "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_indexing_utils", @@ -519,6 +522,7 @@ xla_test( ":test_utils", "//xla:autotuning_proto_cc", "//xla:error_spec", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", @@ -563,6 +567,7 @@ xla_test( ":test_utils", "//xla:autotuning_proto_cc", "//xla:error_spec", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", @@ -753,6 +758,7 @@ cc_library( "//xla/backends/profiler/gpu:cupti_tracer", "//xla/tsl/profiler/utils:time_utils", 
"@com_google_absl//absl/algorithm:container", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -783,6 +789,8 @@ cc_library( ":fusion_emitter", "//xla:shape_util", "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/testlib:filecheck", @@ -841,6 +849,7 @@ xla_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", ], ) @@ -1002,6 +1011,7 @@ xla_cc_test( deps = [ ":tma_utils", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/codegen:emitter_loc_op_builder", "//xla/service:hlo_module_config", "//xla/service/llvm_ir:llvm_util", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc index e231a2c6a9e934..f7124577c5a061 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" #include "triton/Dialect/Triton/IR/Dialect.h" diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h index 20d2babaf5f3d1..093e6cfdaea680 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h @@ -42,6 +42,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/tsl/platform/status.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" namespace xla::gpu::triton { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 8467a8d7174943..575dee4b8f0503 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -17,11 +17,9 @@ limitations under the License. #include #include -#include #include #include #include -#include // NOLINT(build/c++11): required to interface with LLVM #include #include #include diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index 52f5a728a69892..11e9fb9569b73b 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -58,6 +58,7 @@ limitations under the License. 
#include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/path.h" namespace xla { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc index e33846f91983fa..9feda522639a66 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc @@ -52,6 +52,7 @@ limitations under the License. #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/path.h" namespace xla { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_large_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_large_test.cc index 4c3979fdf34887..f0eea382bc50b4 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_large_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_large_test.cc @@ -14,9 +14,11 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include "absl/log/check.h" +#include "absl/strings/string_view.h" #include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc index e473078c81737e..8a9635baead4fa 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc @@ -53,6 +53,7 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Support/LLVM.h" +#include "xla/autotuning.pb.h" #include "xla/backends/gpu/codegen/triton/dot_algorithms.h" #include "xla/backends/gpu/codegen/triton/emitter_helpers.h" #include "xla/codegen/emitter_loc_op_builder.h" @@ -72,6 +73,7 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/transforms/transformation_helpers.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_indexing_utils.h" diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc index b250de906e2daa..7dc1de19dc85d1 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include #include +#include #include #include diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/kernel_name_tracer_cuda.cc b/third_party/xla/xla/backends/gpu/codegen/triton/kernel_name_tracer_cuda.cc index 8fad2d4644453a..e6941772374a2b 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/kernel_name_tracer_cuda.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/kernel_name_tracer_cuda.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/cupti_tracer.h" #include "xla/tsl/profiler/utils/time_utils.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace xla::gpu { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index 6852852469608b..50b723d6d21267 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/backends/gpu/codegen/triton/support.h" -#include #include #include #include diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc index 615bb56f5adbf1..997a5fb4c450bc 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc @@ -57,6 +57,8 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" namespace xla::gpu { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h index 899483fd942f8e..62c6d09c1af620 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.h @@ -39,6 +39,8 @@ limitations under the License. #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" namespace xla::gpu { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils_test.cc index e339e63c102932..be4e0b6fb9b480 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils_test.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "xla/stream_executor/gpu/tma_metadata.h" #include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" From 81d74c0cbabd0be35ae34438725a0b66a9183c5e Mon Sep 17 00:00:00 2001 From: Theotime Combes Date: Tue, 8 Apr 2025 00:45:35 -0700 Subject: [PATCH 0339/1324] [XLA:GPU] Add triton support test for fft PiperOrigin-RevId: 745018851 --- .../backends/gpu/codegen/triton/support.cc | 1 - .../gpu/codegen/triton/support_test.cc | 95 ++++++++++++++++++- 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index 50b723d6d21267..fff08e4d0da7b7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -591,7 +591,6 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { case HloOpcode::kDynamicReshape: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kFft: case HloOpcode::kGather: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index cae593137e5454..63f742904f604d 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -2557,6 +2557,99 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(AllDevicesToTest()), ::testing::Bool()), CholeskyTestName); +class FftTest : public TritonSupportTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(FftTest, FFT) { + auto [data_type, cc] = GetParam(); + + const std::string hlo_text = R"( + ENTRY triton_computation { + parameter = $0[16,16] parameter(0) + ROOT fft_op = $0[16,16] fft(parameter), fft_type=FFT, fft_length={16} + })"; + + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(hlo_text, data_type, HloOpcode::kFft)); + + RunSupportTest(std::move(ti), {4, 4}, cc); +} + +TEST_P(FftTest, IFFT) { + auto [data_type, cc] = GetParam(); + + const std::string hlo_text = R"( + ENTRY triton_computation { + parameter = $0[16,16] parameter(0) + ROOT fft_op = $0[16,16] fft(parameter), fft_type=IFFT, fft_length={16} + })"; + + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(hlo_text, data_type, HloOpcode::kFft)); + + RunSupportTest(std::move(ti), {4, 4}, cc); +} + +TEST_P(FftTest, RFFT) { + auto [data_type, cc] = GetParam(); + const std::string complex_data_type_str = + primitive_util::LowercasePrimitiveTypeName(data_type); + // Real type matching the complex type for real -> complex conversion. 
+ const std::string real_data_type_str = + primitive_util::LowercasePrimitiveTypeName( + primitive_util::ComplexComponentType(data_type)); + + const std::string hlo_text = absl::Substitute( + R"( + ENTRY triton_computation { + parameter = $0[16,16,32] parameter(0) + ROOT fft_op = $1[16,16,17] fft(parameter), fft_type=RFFT, fft_length={16,32} + })", + real_data_type_str, complex_data_type_str); + + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(hlo_text, data_type, HloOpcode::kFft)); + + RunSupportTest(std::move(ti), {4, 4, 4}, cc); +} + +TEST_P(FftTest, IRFFT) { + auto [data_type, cc] = GetParam(); + const std::string complex_data_type_str = + primitive_util::LowercasePrimitiveTypeName(data_type); + // Real type matching the complex type for complex -> real conversion. + const std::string real_data_type_str = + primitive_util::LowercasePrimitiveTypeName( + primitive_util::ComplexComponentType(data_type)); + + const std::string hlo_text = absl::Substitute( + R"( + ENTRY triton_computation { + parameter = $0[16,16,32,33] parameter(0) + ROOT fft_op = $1[16,16,32,64] fft(parameter), fft_type=IRFFT, fft_length={16,32,64} + })", + complex_data_type_str, real_data_type_str); + + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(hlo_text, data_type, HloOpcode::kFft)); + + RunSupportTest(std::move(ti), {4, 4, 4, 4}, cc); +} + +INSTANTIATE_TEST_SUITE_P( + FftTestSuite, FftTest, + // FFT takes a complex type either as input, output or both. When there is a + // complex <-> real conversion, the real type can be directly inferred from + // the complex type (C64 <-> F32, C128 <-> F64). + ::testing::Combine(::testing::ValuesIn({C64, C128}), + ::testing::ValuesIn(AllDevicesToTest())), + TritonSupportTestTypeAndDeviceToString); + constexpr std::array kUnsupportedOps = { // clang-format off // go/keep-sorted start @@ -2566,7 +2659,6 @@ constexpr std::array kUnsupportedOps = { HloOpcode::kDynamicReshape, HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, - HloOpcode::kFft, HloOpcode::kGather, HloOpcode::kGetTupleElement, HloOpcode::kInfeed, @@ -2625,6 +2717,7 @@ absl::flat_hash_set AllTestedOpcodes() { ret.emplace(HloOpcode::kCustomCall); ret.emplace(HloOpcode::kDomain); ret.emplace(HloOpcode::kDot); + ret.emplace(HloOpcode::kFft); ret.emplace(HloOpcode::kGetDimensionSize); ret.emplace(HloOpcode::kReverse); ret.emplace(HloOpcode::kRngBitGenerator); From c209a806a225151c836773ebfe75fbe16c1b4fa7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 00:59:48 -0700 Subject: [PATCH 0340/1324] Replace outdated select() on "cpu": "fuchsia" with platform API equivalent, and fix x86 constraint value in android_x86 config_setting in compiler/xla/tsl/BUILD PiperOrigin-RevId: 745022649 --- third_party/xla/xla/tsl/BUILD | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/third_party/xla/xla/tsl/BUILD b/third_party/xla/xla/tsl/BUILD index 3cdbbb17edf2b5..a230f93e9e5fc8 100644 --- a/third_party/xla/xla/tsl/BUILD +++ b/third_party/xla/xla/tsl/BUILD @@ -399,15 +399,7 @@ config_setting( config_setting( name = "fuchsia", - constraint_values = if_google( - ["@platforms//os:fuchsia"], - [], - ), - values = if_oss( - # TODO(b/149248802) When we have a Fuchsia Bazel SDK update to use the values it sets. 
- {"cpu": "fuchsia"}, - {}, - ), + constraint_values = ["@platforms//os:fuchsia"], visibility = ["//visibility:public"], ) @@ -416,7 +408,7 @@ config_setting( name = "android_x86", constraint_values = [ - ":x86_any", + "@platforms//cpu:x86_32", "@platforms//os:android", ], values = dict( @@ -442,15 +434,6 @@ selects.config_setting_group( visibility = ["//visibility:public"], ) -selects.config_setting_group( - name = "x86_any", - match_any = [ - "@platforms//cpu:x86_32", - "@platforms//cpu:x86_64", - ], - visibility = ["//visibility:public"], -) - selects.config_setting_group( name = "linux_any", match_any = [ From ffebb34ec0947940e6060d150f057211167df653 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Tue, 8 Apr 2025 01:09:34 -0700 Subject: [PATCH 0341/1324] `PadWithZero` only considers contracting dimensions when we partition the standard/ragged dot along contracting dimensions. When we partition dot operation along contracting dimensions, we need to 1. assign zero to the padding before dot 2. insert all-reduce after dot We only need to `PadWithZero` on the contracting dimensions. For other dimensions (batch, non-contracting dimensions), the padding can be any value. PiperOrigin-RevId: 745025806 --- .../xla/xla/service/spmd/dot_handler.cc | 19 ++--------- .../xla/xla/service/spmd/spmd_partitioner.cc | 19 +++++++++-- .../xla/xla/service/spmd/spmd_partitioner.h | 7 ++++ .../xla/service/spmd/spmd_partitioner_test.cc | 34 +++++++++++++++++-- 4 files changed, 59 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc index b9105c4157d9b9..9ffd37a1040fa1 100644 --- a/third_party/xla/xla/service/spmd/dot_handler.cc +++ b/third_party/xla/xla/service/spmd/dot_handler.cc @@ -2858,22 +2858,9 @@ absl::StatusOr PartitionDotGroupOnContractingImpl( lhs = lhs.Reshard(lhs_sharding); } // Mask out invalid data. 
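  // Zero is the only safe padding value on the contracting dimensions: the
  // partial dot sums over them, and the all-reduce inserted after the dot
  // combines the per-partition partial sums, so padded entries must
  // contribute exactly zero. Padding along batch and non-contracting
  // dimensions can hold any value, which is why only the contracting
  // dimensions are padded below.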
-    std::vector<int64_t> lhs_skipped_dims;
-    for (int64_t i = 0; i < lhs.base_shape().dimensions_size(); ++i) {
-      if (absl::c_linear_search(lhs_dims, i)) {
-        continue;
-      }
-      lhs_skipped_dims.push_back(i);
-    }
-    lhs = lhs.PadWithZero(/*left_padded_dims=*/{}, lhs_skipped_dims);
-    std::vector<int64_t> rhs_skipped_dims;
-    for (int64_t i = 0; i < rhs.base_shape().dimensions_size(); ++i) {
-      if (absl::c_linear_search(rhs_dims, i)) {
-        continue;
-      }
-      rhs_skipped_dims.push_back(i);
-    }
-    rhs = rhs.PadWithZero(/*left_padded_dims=*/{}, rhs_skipped_dims);
+    lhs = lhs.PadWithZeroOnSpecifiedDims(lhs_dims);
+    rhs = rhs.PadWithZeroOnSpecifiedDims(rhs_dims);
+
     top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
     lhs.hlo()->set_sharding(lhs_grouped.sharding);
     top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
index b1c2d3ab974a22..40593280510c31 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
@@ -750,6 +750,19 @@ PartitionedHlo PartitionedHlo::PadWithZero(
   return PadWithValue(zero, left_padded_dims, skipped_dims);
 }
 
+PartitionedHlo PartitionedHlo::PadWithZeroOnSpecifiedDims(
+    absl::Span<const int64_t> dims,
+    absl::Span<const int64_t> left_padded_dims) const {
+  std::vector<int64_t> skipped_dims;
+  skipped_dims.reserve(base_shape_.dimensions_size() - dims.size());
+  for (int64_t i = 0; i < base_shape_.dimensions_size(); ++i) {
+    if (!absl::c_linear_search(dims, i)) {
+      skipped_dims.push_back(i);
+    }
+  }
+  return PadWithZero(left_padded_dims, skipped_dims);
+}
+
 std::optional<PartitionedHlo::WindowedInputShardReturnValue>
 PartitionedHlo::ReshardAsWindowedInput(const Window& window,
                                        const HloSharding& target,
@@ -5008,8 +5021,10 @@ absl::Status SpmdPartitioningVisitor::HandleRaggedDot(HloInstruction* hlo) {
   }
 
   if (!sharded_lhs_contracting_dims.empty()) {
-    lhs = lhs.PadWithZero();
-    rhs = rhs.PadWithZero();
+    lhs =
+        lhs.PadWithZeroOnSpecifiedDims(dot_dnums.lhs_contracting_dimensions());
+    rhs =
+        rhs.PadWithZeroOnSpecifiedDims(dot_dnums.rhs_contracting_dimensions());
   }
 
   HloInstruction* phlo;
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.h b/third_party/xla/xla/service/spmd/spmd_partitioner.h
index a0994bd6433085..4e9532bb7c9c94 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner.h
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner.h
@@ -493,9 +493,16 @@ class PartitionedHlo {
       absl::Span<const int64_t> left_padded_dims = {},
       absl::Span<const int64_t> skipped_dims = {}) const;
 
+  // Same as PadWithValue with zero as the pad value.
   PartitionedHlo PadWithZero(absl::Span<const int64_t> left_padded_dims = {},
                              absl::Span<const int64_t> skipped_dims = {}) const;
 
+  // PadWithZero considers all dimensions except the skipped dimensions.
+  // PadWithZeroOnSpecifiedDims considers only the specified dimensions.
+  PartitionedHlo PadWithZeroOnSpecifiedDims(
+      absl::Span<const int64_t> dims,
+      absl::Span<const int64_t> left_padded_dims = {}) const;
+
   // Returns the SPMD instruction.
HloInstruction* hlo() const { return hlo_; } diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index ad2863654444a2..4d409615c973e6 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -15807,7 +15807,7 @@ ENTRY entry { EXPECT_THAT(module->entry_computation()->root_instruction(), ragged_dot); } -TEST_P(SpmdPartitioningTest, RaggedDotBatchMode) { +TEST_P(SpmdPartitioningTest, RaggedDotBatchModeWithPadding) { absl::string_view hlo_string = R"( HloModule module @@ -15824,7 +15824,6 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/16)); - LOG(ERROR) << module->ToString(); auto param0 = AllOf(op::Parameter(0), op::Shape("f32[8,16,32]")); auto param0_pad = AllOf(op::Select(_, param0, op::Broadcast(op::Constant())), @@ -15845,6 +15844,37 @@ ENTRY entry { EXPECT_THAT(replica_groups[0].replica_ids(), ::testing::ElementsAre(0, 1)); } +TEST_P(SpmdPartitioningTest, RaggedDotBatchModeWithoutPadding) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + a = f32[15,31,64] parameter(0), sharding={devices=[2,2,2,2]<=[2,2,2,2]T(0,1,3,2) last_tile_dim_replicate} + b = f32[15,64,7] parameter(1), sharding={devices=[2,2,2,2]<=[2,2,2,2]T(0,3,2,1) last_tile_dim_replicate} + c = u32[4] parameter(2), sharding={devices=[4,4]<=[16] last_tile_dim_replicate} + ROOT dot = f32[15,31,7] ragged-dot(a, b, c), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + lhs_ragged_dims={0}, + sharding={devices=[2,2,2,2]<=[16] last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/16)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[8,16,32]")); + auto param1 = AllOf(op::Parameter(1), op::Shape("f32[8,32,4]")); + auto dot = AllOf(op::Dot(param0, param1), op::Shape("f32[8,16,4]")); + auto all_reduce = AllOf(op::AllReduce(dot), op::Shape("f32[8,16,4]")); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, all_reduce); + + auto replica_groups = Cast(root)->replica_groups(); + EXPECT_EQ(replica_groups.size(), 8); + EXPECT_THAT(replica_groups[0].replica_ids(), ::testing::ElementsAre(0, 1)); +} + } // namespace } // namespace spmd } // namespace xla From 17dfefba077c2b5d9705860a28147deba4aeb813 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 8 Apr 2025 09:34:53 +0100 Subject: [PATCH 0342/1324] [Tosa] Add support for legalizing scatter_nd (#72370) This commit legalizes tf(l).scatter_nd to tosa.scatter in a similar approach to the existing gather_nd support. Specifically, inputs are rewritten to the expected formats by tosa.scatter and finally the output of tosa.scatter is reshaped to the expected output of scatter_nd. tosa.scatter does not support duplicated indices while TF does. Therefore, we restrict legalization of scatter_nd to a constant indices tensor only and check that the provided indices are unique. 
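Concretely, the uniqueness restriction reduces to checking that the
row-major flattenings of the constant ND indices into the result shape are
pairwise distinct. A minimal, self-contained C++ sketch of that check
(illustrative names and simplified types; the real helper is
checkUniqueConstantScatterIndices in legalize_utils.cc further below, which
operates on MLIR ShapedType/ElementsAttr instead):

  #include <algorithm>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // `index_data` holds consecutive groups of `nd` coordinates, one group
  // per scattered slice; `result_shape` is the scatter result shape.
  bool IndicesAreUnique(const std::vector<int32_t>& index_data,
                        const std::vector<int64_t>& result_shape, int nd) {
    // Row-major strides over the first `nd` result dimensions.
    std::vector<int64_t> strides(nd, 1);
    for (int i = nd - 2; i >= 0; --i) {
      strides[i] = strides[i + 1] * result_shape[i + 1];
    }
    // Flatten each ND coordinate to a single integer.
    std::vector<int64_t> flattened;
    for (std::size_t i = 0; i < index_data.size(); i += nd) {
      int64_t flat = 0;
      for (int d = 0; d < nd; ++d) flat += index_data[i + d] * strides[d];
      flattened.push_back(flat);
    }
    // After sorting, any duplicate index becomes adjacent.
    std::sort(flattened.begin(), flattened.end());
    return std::adjacent_find(flattened.begin(), flattened.end()) ==
           flattened.end();
  }

For example, with result_shape = {2, 2, 4, 4} and nd = 3 the strides come
out as {8, 4, 1}, which is exactly the coefficient tensor that the
test_scatter_nd_reshape test below checks for.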
Change-Id: Id0c6c4df1aa09807bc8f47c430924846f4670ba8 Signed-off-by: Luke Hutton --- .../mlir/tosa/tests/tf-to-tosa-pipeline.mlir | 14 ++ .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 65 +++++ .../mlir/tosa/transforms/legalize_common.cc | 224 ++++++++++++++++++ .../mlir/tosa/transforms/legalize_common.h | 6 + .../mlir/tosa/transforms/legalize_tf.cc | 18 ++ .../mlir/tosa/transforms/legalize_tfl.cc | 18 ++ .../mlir/tosa/transforms/legalize_utils.cc | 31 +++ .../mlir/tosa/transforms/legalize_utils.h | 8 + 8 files changed, 384 insertions(+) diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 502ec7e74ade06..99d5bd23688a8c 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -1053,6 +1053,20 @@ func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>) -> tensor<6x7x21x3xf32> { func.return %1 : tensor<6x7x21x3xf32> } +// ----- + +// CHECK-LABEL: test_scatter_nd +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x224x512xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2:.*]] {axis = 1 : i32} : (tensor<1x2xi32>) +// CHECK-DAG: %[[VAR4:.*]] = tosa.scatter %[[VAR1:.*]], %[[VAR3:.*]], %arg0 : (tensor<1x224x512xf32>, tensor<1x1xi32>, tensor<1x1x512xf32>) +// CHECK: return %[[VAR4]] +func.func @test_scatter_nd(%arg0: tensor<1x1x512xf32>) -> tensor<1x224x512xf32> { + %shape = "tf.Const"() {device = "", value = dense<[1, 224, 512]> : tensor<3xi32>} : () -> tensor<3xi32> + %indices = "tf.Const"() {device = "", value = dense<[[[0, 0]]]>: tensor<1x1x2xi32>} : () -> tensor<1x1x2xi32> + %1 = "tf.ScatterNd"(%indices, %arg0, %shape) {device = ""} : (tensor<1x1x2xi32>, tensor<1x1x512xf32>, tensor<3xi32>) -> tensor<1x224x512xf32> + func.return %1 : tensor<1x224x512xf32> +} // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 8623a886b353eb..1bc67dd3088e71 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -3050,6 +3050,71 @@ func.func @test_sparse_to_dense(%arg0 : tensor, %arg1 : tensor) // ----- +// CHECK-LABEL: test_scatter_nd +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x224x512xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2:.*]] {axis = 1 : i32} : (tensor<1x2xi32>) +// CHECK-DAG: %[[VAR5:.*]] = tosa.scatter %[[VAR1:.*]], %[[VAR3:.*]], %arg0 : (tensor<1x224x512xf32>, tensor<1x1xi32>, tensor<1x1x512xf32>) +// CHECK: return %[[VAR5]] +func.func @test_scatter_nd(%arg0: tensor<1x1x512xf32>) -> tensor<1x224x512xf32> { + %shape = "tfl.pseudo_const"() <{value = dense<[1, 224, 512]> : tensor<3xi32>}> : () -> tensor<3xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[0, 0]]]> : tensor<1x1x2xi32>}> : () -> tensor<1x1x2xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<1x1x2xi32>, tensor<1x1x512xf32>, tensor<3xi32>) -> tensor<1x224x512xf32> + func.return %0 : tensor<1x224x512xf32> +} + +// ----- + +// CHECK-LABEL: test_scatter_nd_reshape +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}8, 4, 1]]> : tensor<1x3xi32>}> : () +// CHECK-DAG: 
%[[VAR2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x16x4xf32>}> : () +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<{{\[\[}}0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3], [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3]]> : tensor<8x3xi32>}> : () +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[NEW_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 8, 4]> : tensor<3xindex>} +// CHECK-DAG: %[[NEW_SHAPE1:.*]] = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} +// CHECK-DAG: %[[NEW_SHAPE2:.*]] = tosa.const_shape {values = dense<[2, 2, 4, 4]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg0, %[[NEW_SHAPE]] : (tensor<2x2x2x4xf32>, !tosa.shape<3>) +// CHECK-DAG: %[[VAR5:.*]] = tosa.mul %[[VAR3]], %[[VAR1]], %[[SHIFT]] : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) +// CHECK-DAG: %[[VAR6:.*]] = tosa.reduce_sum %[[VAR5]] {axis = 1 : i32} : (tensor<8x3xi32>) +// CHECK-DAG: %[[VAR7:.*]] = tosa.reshape %[[VAR6]], %[[NEW_SHAPE1]] : (tensor<8x1xi32>, !tosa.shape<2>) +// CHECK-DAG: %[[VAR8:.*]] = tosa.scatter %[[VAR2]], %[[VAR7]], %[[VAR4]] : (tensor<1x16x4xf32>, tensor<1x8xi32>, tensor<1x8x4xf32>) +// CHECK-DAG: %[[VAR9:.*]] = tosa.reshape %[[VAR8]], %[[NEW_SHAPE2]] : (tensor<1x16x4xf32>, !tosa.shape<4>) +// CHECK-DAG: return %[[VAR9]] +func.func @test_scatter_nd_reshape(%arg0: tensor<2x2x2x4xf32>) -> tensor<2x2x4x4xf32> { + %shape = "tfl.pseudo_const"() <{value = dense<[2, 2, 4, 4]> : tensor<4xi32>}> : () -> tensor<4xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[[0, 0, 0], [0, 0, 1]], [[0, 0, 2], [0, 0, 3]]], [[[1, 0, 0], [1, 0, 1]], [[1, 0, 2], [1, 0, 3]]]]> : tensor<2x2x2x3xi32>}> : () -> tensor<2x2x2x3xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<2x2x2x3xi32>, tensor<2x2x2x4xf32>, tensor<4xi32>) -> tensor<2x2x4x4xf32> + func.return %0 : tensor<2x2x4x4xf32> +} + +// ----- + +// CHECK-LABEL: test_scatter_nd_qi8 +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x224x512xi8>}> : () -> tensor<1x224x512x!quant.uniform> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2:.*]] {axis = 1 : i32} : (tensor<1x2xi32>) +// CHECK-DAG: %[[VAR4:.*]] = tosa.scatter %[[VAR1:.*]], %[[VAR3:.*]], %arg0 : (tensor<1x224x512x!quant.uniform>, tensor<1x1xi32>, tensor<1x1x512x!quant.uniform>) +// CHECK: return %[[VAR4]] +func.func @test_scatter_nd_qi8(%arg0: tensor<1x1x512x!quant.uniform>) -> tensor<1x224x512x!quant.uniform> { + %shape = "tfl.pseudo_const"() <{value = dense<[1, 224, 512]> : tensor<3xi32>}> : () -> tensor<3xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[0, 0]]]> : tensor<1x1x2xi32>}> : () -> tensor<1x1x2xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<1x1x2xi32>, tensor<1x1x512x!quant.uniform>, tensor<3xi32>) -> tensor<1x224x512x!quant.uniform> + func.return %0 : tensor<1x224x512x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_scatter_nd_duplicate_indices +// CHECK: tfl.scatter_nd +func.func @test_scatter_nd_duplicate_indices(%arg0: tensor<2x2x2x4xf32>) -> tensor<2x2x4x4xf32> { + %shape = "tfl.pseudo_const"() <{value = dense<[2, 2, 4, 4]> : tensor<4xi32>}> : () -> tensor<4xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[[0, 0, 0], [0, 0, 1]], [[0, 0, 2], [0, 0, 3]]], [[[1, 0, 0], [1, 0, 0]], [[1, 0, 2], [1, 0, 3]]]]> : tensor<2x2x2x3xi32>}> : () -> tensor<2x2x2x3xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : 
(tensor<2x2x2x3xi32>, tensor<2x2x2x4xf32>, tensor<4xi32>) -> tensor<2x2x4x4xf32> + func.return %0 : tensor<2x2x4x4xf32> +} + +// ----- + // CHECK-LABEL: @test_arg_max func.func @test_arg_max(%arg0: tensor<13x21x3xf32>) -> tensor<*xi32> { // CHECK: %[[ARGMAX:.+]] = tosa.argmax %arg0 {axis = 1 : i32} diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 5881990eb0d955..2cdace08b6d193 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -46,6 +46,7 @@ limitations under the License. #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project #include "mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -4407,6 +4408,229 @@ std::optional convertGatherNdOp(PatternRewriter& rewriter, Operation* op, .getResult(); } +std::optional convertScatterNdOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value indices_value, + Value updates_value, + Value shape_value) { + auto const result_type = dyn_cast(result_value.getType()); + auto const indices_type = dyn_cast(indices_value.getType()); + auto const updates_type = dyn_cast(updates_value.getType()); + auto const shape_type = dyn_cast(shape_value.getType()); + + if (!result_type || !indices_type || !updates_type || !shape_type) { + (void)rewriter.notifyMatchFailure( + op, "input/output types must be ranked tensor type"); + return std::nullopt; + } + + // Don't support variable indices yet since we cannot check uniqueness + // of indices in this case + Operation* indices_op = indices_value.getDefiningOp(); + if (!indices_op || !llvm::isa(indices_op)) { + (void)rewriter.notifyMatchFailure(op, "indices must be a constant tensor"); + return std::nullopt; + } + + Type indices_elmt_type = indices_type.getElementType(); + if (!indices_elmt_type.isInteger(32)) { + (void)rewriter.notifyMatchFailure(op, "indices expected to be int32"); + return std::nullopt; + } + + // The tosa scatter operation only supports unique indices, so if there + // are duplicates, we cannot legalize + tosa::ConstOp const_indices = cast(indices_op); + ElementsAttr const_data = const_indices.getValues(); + if (!checkUniqueConstantScatterIndices(indices_type, result_type, + const_data)) { + (void)rewriter.notifyMatchFailure(op, "index values must be unique"); + return std::nullopt; + } + + // N: number of batches + // Always 1 for ScatterND + // + // Because TOSA's SCATTER operator already uses the symbol 'N' for + // the number of batches, we will use the symbol 'ND' to specify the + // number of dimensions that are sliced from input instead of'N' in + // the TF MLIR documentation. + // + // ND: indices.shape[-1] + // + // W: number of indices in each batch + // Computed as: + // product(indices.shape[0:-1]) (all but the last dimension) + // + // K: range of each index + // Computed as: + // product(result.shape[0:ND-1]) + // + // C: number of channels for each index + // Computed as: + // product(result.shape[ND:]) + // + // The updates tensor needs to be reshaped, but not transposed, to move + // the dimensions into [N, W, C] order. 
+ // + // Indices needs to be put in the form of [N, W], but a simple flattening + // will not suffice, because the indices need to index into the [W]-shape + // updates vector instead. + // + // To flatten the coordinates, first reshape indices to a [W, ND] matrix, + // where the matrix now represents W ND-dimensional coordinates into the + // updates tensor. + // + // From here, we take each of the ND dimensions and multiply it with + // the size of the next updates dimension (or 1 for the last + // dimension), then sum all these together with a reduce_sum + // operator. This is exactly the same mathematics as one would use + // flatten the indices of an N-dimensional row-major array into a + // 1-D array in C. + // + // More precisely, do an element-wise multiply with [updates.shape[1 + // .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a + // [W]-shaped tensor, then trivially reshape to [N=1, W] to be + // compatible with the SCATTER operator's shape. + // + // Then perform the tosa.SCATTER() operation. + // + // Now we have result = [N, K, C]. + // + // Reshape with a single, simple reshape to the final output shape + // provided by shape_value. + + const unsigned int input_output_rank = result_type.getShape().size(); + const unsigned int indices_rank = indices_type.getShape().size(); + + const unsigned int ND = indices_type.getShape()[indices_rank - 1]; + + if (ND > input_output_rank) { + (void)rewriter.notifyMatchFailure( + op, "size of last dimension of indices must be <= input/output rank"); + return std::nullopt; + } + + // Calculate N, K, W, C. (N is always 1) + auto const indices_shape_begin{indices_type.getShape().begin()}; + auto const result_shape_begin{result_type.getShape().begin()}; + auto const accumulate_func = [](auto const& a_, auto const& b_) { + return a_ * b_; + }; + + const unsigned int N = 1; + const unsigned int W = std::accumulate(indices_shape_begin, + indices_shape_begin + indices_rank - 1, + 1, accumulate_func); + const unsigned int K = std::accumulate( + result_shape_begin, result_shape_begin + ND, 1, accumulate_func); + const unsigned int C = std::accumulate(result_shape_begin + ND, + result_shape_begin + input_output_rank, + 1, accumulate_func); + + SmallVector tosa_indices_shape({N, W}); + SmallVector indices_matrix_shape({W, ND}); + SmallVector tosa_input_shape({N, W, C}); + SmallVector tosa_values_in_out_shape({N, K, C}); + + // Flatten the updates tensor to an [N, W] matrix. + auto input_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(tosa_input_shape)); + auto tosa_input_reshape_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(tosa_input_shape, + result_type.getElementType()), + updates_value, input_shape_value); + + // Flatten the indices tensor to an [W, ND] matrix. 
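+  // (Worked example: with a result shape of [2, 2, 4, 4] and ND == 3, the
+  // flattening coefficients computed below come out as {8, 4, 1}, so an
+  // index (i, j, k) maps to the flat position 8*i + 4*j + k; this is the
+  // coefficient tensor checked by the test_scatter_nd_reshape test above.)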
+ auto indices_matrix_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(indices_matrix_shape)); + auto indices_matrix_reshape_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(indices_matrix_shape, + indices_elmt_type), + indices_value, indices_matrix_shape_value); + + SmallVector flattened_coeff_vec; + for (int i = 1; i < ND; i++) { + flattened_coeff_vec.push_back(result_type.getShape()[i]); + } + flattened_coeff_vec.push_back(1); + for (int i = ND - 1; i > 0; i--) { + flattened_coeff_vec[i - 1] *= flattened_coeff_vec[i]; + } + std::optional flattened_coeff_value = getConstTensor( + rewriter, op, flattened_coeff_vec, + {static_cast(flattened_coeff_vec.size())}); + + if (!flattened_coeff_value) { + (void)rewriter.notifyMatchFailure( + op, "failed to calculate flattened coeff value"); + return std::nullopt; + } + + // Multiply the coefficients by the coordinates + Value mul_x = indices_matrix_reshape_op.getResult(); + Value mul_y = flattened_coeff_value.value(); + RankedTensorType mul_type = tensorflow::GetTypeFromTFTensorShape( + indices_matrix_shape, indices_type.getElementType()); + if (EqualizeRanks(rewriter, op->getLoc(), mul_x, mul_y).failed()) { + (void)rewriter.notifyMatchFailure( + op, "failed to broadcast coefficients over the coordinates"); + return std::nullopt; + } + auto flattened_indices_mul_op = CreateMulOpAndInfer( + rewriter, op, mul_type, mul_x, mul_y); + + // Sum up the products of the coefficients and coordinates + auto flattened_indices_reduce_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(tosa_indices_shape, + indices_type.getElementType()), + flattened_indices_mul_op.getResult(), rewriter.getI32IntegerAttr(1)); + + // And reshape to [N, W] + auto tosa_indices_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(tosa_indices_shape)); + auto tosa_indices_reshape_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(tosa_indices_shape, + indices_type.getElementType()), + flattened_indices_reduce_op.getResult(), tosa_indices_shape_value); + + // Scatter_nd has no input tensor, use a zero tensor + Type const_element_type = updates_type.getElementType(); + auto const_type = + RankedTensorType::get(tosa_values_in_out_shape, const_element_type); + if (mlir::isa(const_element_type)) { + auto quant_type = dyn_cast(const_element_type); + const_element_type = quant_type.getStorageType(); + } + auto const_storage_type = + RankedTensorType::get(tosa_values_in_out_shape, const_element_type); + auto const_attr = DenseElementsAttr::get( + const_storage_type, rewriter.getZeroAttr(const_element_type)); + Value tosa_values_in = + rewriter.create(op->getLoc(), const_type, const_attr); + + // Now the scatter op itself + auto tosa_scatter_op = CreateOpAndInfer( + rewriter, op->getLoc(), result_type, tosa_values_in, + tosa_indices_reshape_op.getResult(), tosa_input_reshape_op.getResult()); + + // Finally, reshape back to the expected output shape. + auto reshape_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(result_type.getShape())); + return CreateOpAndInfer(rewriter, op->getLoc(), result_type, + tosa_scatter_op.getResult(), + reshape_shape_value) + .getResult(); +} + // Lowers OneHot operator to a sequence of TOSA ops. 
std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value indices_value, diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index 5ddcec25e821f9..9b118ad6e73335 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -298,6 +298,12 @@ std::optional convertGatherNdOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value params_value, Value indices_value); +// Lowers ScatterNd operator to a sequence of TOSA ops. +std::optional convertScatterNdOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value indices_value, + Value updates_value, Value shape_value); + // Lowers OneHot operator to a sequence of TOSA ops. std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value indices_value, diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index b355829547f0c3..9578d0ecf8c0aa 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -131,6 +131,7 @@ DECL_CONVERT_OP(ResizeNearestNeighbor); DECL_CONVERT_OP(Gather); DECL_CONVERT_OP(GatherV2); DECL_CONVERT_OP(GatherNd); +DECL_CONVERT_OP(ScatterNd); DECL_CONVERT_OP(SelectV2); DECL_CONVERT_OP(SpaceToDepth); DECL_CONVERT_OP(DepthToSpace); @@ -2001,6 +2002,22 @@ LogicalResult ConvertTFGatherNdOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFScatterNdOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_scatternd_op = cast(op); + + const std::optional result = convertScatterNdOp( + rewriter, op, tfl_scatternd_op.getResult(), tfl_scatternd_op.getIndices(), + tfl_scatternd_op.getUpdates(), tfl_scatternd_op.getShape()); + + if (!result) { + return failure(); + } + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + LogicalResult ConvertTFSelectV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_sel_op = cast(op); @@ -2620,6 +2637,7 @@ void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) { patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); + patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index acb17d1e8e450e..f4877c5c44aac7 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -196,6 +196,7 @@ DECL_CONVERT_OP(Const); DECL_CONVERT_OP(QConst); DECL_CONVERT_OP(Gather); DECL_CONVERT_OP(GatherNd); +DECL_CONVERT_OP(ScatterNd); DECL_CONVERT_OP(SparseToDense); DECL_CONVERT_OP(OneHot); DECL_CONVERT_OP(ArgMax); @@ -4341,6 +4342,22 @@ LogicalResult ConvertTFLGatherNdOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFLScatterNdOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_scatternd_op = cast(op); + + const std::optional result = convertScatterNdOp( + rewriter, op, tfl_scatternd_op.getResult(), tfl_scatternd_op.getIndices(), + tfl_scatternd_op.getUpdates(), tfl_scatternd_op.getShape()); + + if (!result) { + return failure(); + } + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + LogicalResult 
ConvertTFLSparseToDenseOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_sparse_to_dense_op = cast(op); @@ -4998,6 +5015,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLConst); DEF_PATTERN_INSERT(TFLQConst); DEF_PATTERN_INSERT(TFLGatherNd); + DEF_PATTERN_INSERT(TFLScatterNd); DEF_PATTERN_INSERT(TFLSparseToDense); DEF_PATTERN_INSERT(Constant); DEF_PATTERN_INSERT(TFLOneHot); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index 7b688ea3adf8d2..29cc8208d3fa2b 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -1413,5 +1413,36 @@ LogicalResult broadcastLowRankTensor(PatternRewriter& rewriter, Operation* op, return success(); } +bool checkUniqueConstantScatterIndices(ShapedType indices_type, + ShapedType result_type, + ElementsAttr const_data) { + llvm::ArrayRef const indices_shape = indices_type.getShape(); + const unsigned int indices_rank = indices_shape.size(); + const unsigned int result_rank = result_type.getRank(); + const unsigned int last_dim_size = indices_shape[indices_rank - 1]; + + // Reconstruct each index from the unshaped constant data array and + // calculate the corresponding flattened index + auto const const_data_range = const_data.getValues(); + assert((const_data_range.size() % last_dim_size == 0) && + "Constant data length should be a multiple of indices_shape[-1]"); + + std::vector flattened_indices; + flattened_indices.reserve(const_data_range.size() / last_dim_size); + for (auto beg = const_data_range.begin(); beg < const_data_range.end(); + beg += last_dim_size) { + std::vector current_single_index(result_rank); + std::copy(beg, beg + last_dim_size, current_single_index.begin()); + const uint64_t f_index{ + ElementsAttr::getFlattenedIndex(result_type, current_single_index)}; + flattened_indices.push_back(f_index); + } + + // If adjacent flattened values are found, there are non-unique indices + std::sort(flattened_indices.begin(), flattened_indices.end()); + return std::adjacent_find(flattened_indices.begin(), + flattened_indices.end()) == flattened_indices.end(); +} + } // namespace tosa } // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index 40d3e9f974e7a0..443054e5bf9f01 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -207,6 +207,14 @@ Value getInputSlicedToItsUsedSize(PatternRewriter& rewriter, Operation* op, // Check if scale32 mode is used for given output_element_type bool isScale32(mlir::quant::UniformQuantizedType output_element_type); +// Checks if the multi-dimensional indices supplied by a constant tensor +// are unique. This is a useful check for legalizations to tosa.scatter +// which requires indices are unique, while in TF/TFLite they may be +// non-unique. +bool checkUniqueConstantScatterIndices(ShapedType indices_type, + ShapedType result_type, + ElementsAttr const_data); + // Applies a set of patterns greedily to the specified function, then applies // a cleanup to guarantee the function contract and constants are valid. 
This // means patterns can performed shape inference while not altering immutable From 2bdd806bb7eb7b8810cfe4c5b4b3f02b03bb9989 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 8 Apr 2025 01:29:45 -0700 Subject: [PATCH 0343/1324] Make `MakeBatchPointers` use the kernel registry - Moves `MakeBatchPointers` logic into backends/gpu/runtime since it's a runtime component. - Defines trait for the MakeBatchPointers kernel in stream_executor/gpu/ - Moves the implementations of this kernel into stream_executor/{cuda|rocm} and registers them with the registry. - Makes `MakeBatchPointers` retrieve the kernel by using the kernel registry. - Add the kernel implementations as dependencies to the `all_runtime` targets for CUDA and ROCm. PiperOrigin-RevId: 745031209 --- .../xla/xla/backends/gpu/runtime/BUILD | 53 ++++++++++++- .../backends/gpu/runtime/cholesky_thunk.cc | 6 +- .../gpu/runtime/make_batch_pointers.cc | 57 ++++++++++++++ .../gpu/runtime}/make_batch_pointers.h | 6 +- .../gpu/runtime/make_batch_pointers_test.cc | 78 +++++++++++++++++++ .../gpu/runtime/triangular_solve_thunk.cc | 4 +- third_party/xla/xla/service/gpu/BUILD | 36 --------- .../xla/service/gpu/make_batch_pointers.cc | 75 ------------------ .../xla/xla/stream_executor/cuda/BUILD | 19 +++++ .../make_batch_pointers_kernel_cuda.cu.cc | 45 +++++++++++ third_party/xla/xla/stream_executor/gpu/BUILD | 9 +++ .../gpu/make_batch_pointers_kernel.h | 36 +++++++++ .../xla/xla/stream_executor/rocm/BUILD | 23 ++++++ .../make_batch_pointers_kernel_rocm.cu.cc} | 23 ++++-- .../stream_executor/rocm/rocm_helpers.cu.cc | 25 +----- 15 files changed, 344 insertions(+), 151 deletions(-) create mode 100644 third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc rename third_party/xla/xla/{service/gpu => backends/gpu/runtime}/make_batch_pointers.h (92%) create mode 100644 third_party/xla/xla/backends/gpu/runtime/make_batch_pointers_test.cc delete mode 100644 third_party/xla/xla/service/gpu/make_batch_pointers.cc create mode 100644 third_party/xla/xla/stream_executor/cuda/make_batch_pointers_kernel_cuda.cu.cc create mode 100644 third_party/xla/xla/stream_executor/gpu/make_batch_pointers_kernel.h rename third_party/xla/xla/{service/gpu/make_batch_pointers.cu.cc => stream_executor/rocm/make_batch_pointers_kernel_rocm.cu.cc} (57%) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 5fb985c794f176..c766e7a7bb45d3 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -288,6 +288,7 @@ cc_library( hdrs = if_gpu_is_configured(["cholesky_thunk.h"]), deps = if_gpu_is_configured([ # keep sorted + ":make_batch_pointers", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -295,7 +296,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:make_batch_pointers", "//xla/stream_executor:blas", "//xla/stream_executor:device_memory", "//xla/stream_executor:gpu_solver_context", @@ -307,7 +307,10 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", - ]), + ]) + [ + "//xla/tsl/platform:errors", + "@com_google_absl//absl/log:check", + ], ) cc_library( @@ -1218,6 +1221,7 @@ cc_library( hdrs = if_gpu_is_configured(["triangular_solve_thunk.h"]), deps = if_gpu_is_configured([ # keep sorted + ":make_batch_pointers", "//xla:status_macros", "//xla:types", "//xla:util", @@ -1226,7 +1230,6 @@ 
cc_library( "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:make_batch_pointers", "//xla/stream_executor:blas", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", @@ -1391,3 +1394,47 @@ xla_test( "@local_tsl//tsl/platform:test", ], ) + +cc_library( + name = "make_batch_pointers", + srcs = ["make_batch_pointers.cc"], + hdrs = ["make_batch_pointers.h"], + deps = [ + "//xla:types", + "//xla:util", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/stream_executor/gpu:make_batch_pointers_kernel", + "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "make_batch_pointers_test", + srcs = ["make_batch_pointers_test.cc"], + backends = ["gpu"], + deps = [ + ":make_batch_pointers", + "//xla/service:platform_util", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/backends/gpu/runtime/cholesky_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/cholesky_thunk.cc index 8ff71f6b01ba0f..7562386bca4399 100644 --- a/third_party/xla/xla/backends/gpu/runtime/cholesky_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/cholesky_thunk.cc @@ -21,19 +21,19 @@ limitations under the License. #include #include "absl/functional/any_invocable.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "xla/backends/gpu/runtime/make_batch_pointers.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/make_batch_pointers.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu_solver_context.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/platform/errors.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc new file mode 100644 index 00000000000000..bff8fe2bb77c48 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc @@ -0,0 +1,57 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/gpu/runtime/make_batch_pointers.h" + +#include + +#include "absl/status/status.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_kernel_registry.h" +#include "xla/stream_executor/gpu/make_batch_pointers_kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla::gpu { + +absl::Status MakeBatchPointers(se::Stream* stream, + se::DeviceMemoryBase base_ptr, + size_t stride_bytes, size_t n, + se::DeviceMemoryBase ptrs_out) { + se::StreamExecutor* executor = stream->parent(); + size_t threads_per_block = [&] { + if (executor->GetPlatform()->id() == + stream_executor::rocm::kROCmPlatformId) { + return 256; + } else { + return 128; + } + }(); + + TF_ASSIGN_OR_RETURN( + auto kernel, + stream_executor::gpu::GpuKernelRegistry::GetGlobalRegistry() + .LoadKernel(executor)); + + return kernel.Launch(se::ThreadDim(threads_per_block, 1, 1), + se::BlockDim(CeilOfRatio(n, threads_per_block), 1, 1), + stream, base_ptr, stride_bytes, n, ptrs_out); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/make_batch_pointers.h b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.h similarity index 92% rename from third_party/xla/xla/service/gpu/make_batch_pointers.h rename to third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.h index 6e437fafdcb6aa..0ebc10d2e68198 100644 --- a/third_party/xla/xla/service/gpu/make_batch_pointers.h +++ b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_MAKE_BATCH_POINTERS_H_ -#define XLA_SERVICE_GPU_MAKE_BATCH_POINTERS_H_ +#ifndef XLA_BACKENDS_GPU_RUNTIME_MAKE_BATCH_POINTERS_H_ +#define XLA_BACKENDS_GPU_RUNTIME_MAKE_BATCH_POINTERS_H_ #include @@ -56,4 +56,4 @@ absl::Status MakeBatchPointers(se::Stream* stream, } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_MAKE_BATCH_POINTERS_H_ +#endif // XLA_BACKENDS_GPU_RUNTIME_MAKE_BATCH_POINTERS_H_ diff --git a/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers_test.cc b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers_test.cc new file mode 100644 index 00000000000000..96d454ec20096a --- /dev/null +++ b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers_test.cc @@ -0,0 +1,78 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/backends/gpu/runtime/make_batch_pointers.h" + +#include +#include + +#include +#include +#include "absl/status/statusor.h" +#include "xla/service/platform_util.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { +using ::testing::ElementsAreArray; +using tsl::testing::IsOk; + +static absl::StatusOr GpuExecutor() { + TF_ASSIGN_OR_RETURN(stream_executor::Platform * platform, + PlatformUtil::GetDefaultPlatform()); + return platform->ExecutorForDevice(0); +} + +TEST(MakeBatchPointersTest, Basic) { + TF_ASSERT_OK_AND_ASSIGN(stream_executor::StreamExecutor * executor, + GpuExecutor()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + + // We don't care what `base` points to, we only need a pointer to a buffer + // that we can use as a base. + stream_executor::DeviceMemory base = executor->AllocateScalar(); + stream_executor::DeviceMemory ptrs_out = + executor->AllocateArray(8); + + constexpr int kStride = 13; + constexpr int kN = 8; + + EXPECT_THAT(MakeBatchPointers(stream.get(), base, kStride, kN, ptrs_out), + IsOk()); + + std::array result = {}; + + EXPECT_THAT( + executor->SynchronousMemcpy(result.data(), ptrs_out, kN * sizeof(void*)), + IsOk()); + + std::array expected = { + base.base() + 0 * kStride, base.base() + 1 * kStride, + base.base() + 2 * kStride, base.base() + 3 * kStride, + base.base() + 4 * kStride, base.base() + 5 * kStride, + base.base() + 6 * kStride, base.base() + 7 * kStride, + }; + + EXPECT_THAT(result, ElementsAreArray(expected)); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/triangular_solve_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/triangular_solve_thunk.cc index e7f9768fd7fe1d..d4642b58026e67 100644 --- a/third_party/xla/xla/backends/gpu/runtime/triangular_solve_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/triangular_solve_thunk.cc @@ -20,17 +20,15 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "xla/backends/gpu/runtime/make_batch_pointers.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/make_batch_pointers.h" #include "xla/status_macros.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index a9ad63be9dcca9..8823527f8faf1f 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2,7 +2,6 @@ # GPU-specific components in XLA service implementation. 
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", @@ -2696,41 +2695,6 @@ xla_cc_test( ], ) -cc_library( - name = "make_batch_pointers", - srcs = if_gpu_is_configured(["make_batch_pointers.cc"]), - hdrs = if_gpu_is_configured(["make_batch_pointers.h"]), - deps = [ - "//xla:types", - "//xla:util", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:kernel", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor:stream", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor:typed_kernel_factory", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/tsl/platform:errors", - "//xla/tsl/platform:statusor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - ":make_batch_pointers_kernel", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:rocm_helpers", - ]), -) - -cuda_library( - name = "make_batch_pointers_kernel", - srcs = if_cuda_is_configured(["make_batch_pointers.cu.cc"]), - deps = [ - "@local_config_cuda//cuda:cuda_headers", # build_cleaner: keep - ], -) - tsl_gpu_library( name = "runtime_intrinsics", srcs = ["runtime_intrinsics.cc"], diff --git a/third_party/xla/xla/service/gpu/make_batch_pointers.cc b/third_party/xla/xla/service/gpu/make_batch_pointers.cc deleted file mode 100644 index ad569593a84924..00000000000000 --- a/third_party/xla/xla/service/gpu/make_batch_pointers.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/gpu/make_batch_pointers.h" - -#include - -#include "absl/status/status.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/typed_kernel_factory.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/util.h" - -#if TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_stream.h" -namespace stream_executor::gpu { - -extern void rocm_MakeBatchPointers(void* stream, char* base, int stride, int n, - void** ptrs_out); - -} // namespace stream_executor::gpu -#endif - -namespace xla::gpu { - -namespace make_batch_pointers { -void* kernel(); // returns a pointer to a CUDA C++ device function -} // namespace make_batch_pointers - -absl::Status MakeBatchPointers(se::Stream* stream, - se::DeviceMemoryBase base_ptr, - size_t stride_bytes, size_t n, - se::DeviceMemoryBase ptrs_out) { -#if TENSORFLOW_USE_ROCM - stream_executor::gpu::rocm_MakeBatchPointers( - se::gpu::AsGpuStreamValue(stream), - reinterpret_cast(base_ptr.opaque()), stride_bytes, n, - reinterpret_cast(ptrs_out.opaque())); -#else - se::StreamExecutor* executor = stream->parent(); - static constexpr size_t kThreads = 128; - - TF_ASSIGN_OR_RETURN( - auto kernel, - (se::TypedKernelFactory< - se::DeviceMemoryBase, size_t, size_t, - se::DeviceMemoryBase>::Create(executor, "make_batch_pointers", - make_batch_pointers::kernel()))); - - TF_RETURN_IF_ERROR(kernel.Launch(se::ThreadDim(kThreads, 1, 1), - se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), - stream, base_ptr, stride_bytes, n, - ptrs_out)); -#endif - return absl::OkStatus(); -} - -} // namespace xla::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 472f8e2bb04e6b..83377e18a4bf6a 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1122,6 +1122,7 @@ cc_library( ":cuda_platform", ":cudnn_plugin", ":cufft_plugin", + ":make_batch_pointers_kernel_cuda", "//xla/tsl/cuda:cusolver", "//xla/tsl/cuda:cusparse", "//xla/tsl/cuda:tensorrt_rpath", @@ -1997,3 +1998,21 @@ cuda_library( ], alwayslink = 1, ) + +cuda_library( + name = "make_batch_pointers_kernel_cuda", + srcs = ["make_batch_pointers_kernel_cuda.cu.cc"], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":cuda_platform_id", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "//xla/stream_executor/gpu:make_batch_pointers_kernel", + "@com_google_absl//absl/base", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/stream_executor/cuda/make_batch_pointers_kernel_cuda.cu.cc b/third_party/xla/xla/stream_executor/cuda/make_batch_pointers_kernel_cuda.cu.cc new file mode 100644 index 00000000000000..fbb41248b8f3ab --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/make_batch_pointers_kernel_cuda.cu.cc @@ -0,0 +1,45 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/base/casts.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/gpu/gpu_kernel_registry.h" +#include "xla/stream_executor/gpu/make_batch_pointers_kernel.h" +#include "xla/stream_executor/kernel_spec.h" + +namespace stream_executor::cuda { +namespace { +__global__ void MakeBatchPointers(char* base, size_t stride, size_t n, + void** ptrs_out) { + size_t idx = size_t(threadIdx.x) + size_t(blockIdx.x) * size_t(blockDim.x); + if (idx >= n) return; + ptrs_out[idx] = base + idx * stride; +} +} // namespace + +} // namespace stream_executor::cuda + +GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY( + MakeBatchPointersKernelCuda, stream_executor::gpu::MakeBatchPointersKernel, + stream_executor::cuda::kCudaPlatformId, ([] { + stream_executor::MultiKernelLoaderSpec spec(4); + spec.AddInProcessSymbol( + absl::bit_cast(&stream_executor::cuda::MakeBatchPointers), + + "make_batch_pointers"); + return spec; + })); diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 6d03c89b6c081e..22747215ae3b7a 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -868,3 +868,12 @@ cc_library( ) exports_files(["buffer_comparator_kernel_lib.cu.h"]) + +cc_library( + name = "make_batch_pointers_kernel", + hdrs = ["make_batch_pointers_kernel.h"], + deps = [ + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + ], +) diff --git a/third_party/xla/xla/stream_executor/gpu/make_batch_pointers_kernel.h b/third_party/xla/xla/stream_executor/gpu/make_batch_pointers_kernel.h new file mode 100644 index 00000000000000..1a8ed7322cfd79 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/make_batch_pointers_kernel.h @@ -0,0 +1,36 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_MAKE_BATCH_POINTERS_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_GPU_MAKE_BATCH_POINTERS_KERNEL_H_ + +#include + +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel.h" + +namespace stream_executor::gpu { + +// Defines a trait for the MakeBatchPointers kernel that can be used to register +// and look up the kernel in the GPU kernel registry. 
+struct MakeBatchPointersKernel {
+  using KernelType =
+      stream_executor::TypedKernel<DeviceMemoryBase, size_t, size_t,
+                                   DeviceMemoryBase>;
+};
+
+}  // namespace stream_executor::gpu
+
+#endif  // XLA_STREAM_EXECUTOR_GPU_MAKE_BATCH_POINTERS_KERNEL_H_
diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD
index dbac48b0c136d6..d42871f9cb2b95 100644
--- a/third_party/xla/xla/stream_executor/rocm/BUILD
+++ b/third_party/xla/xla/stream_executor/rocm/BUILD
@@ -826,6 +826,7 @@ cc_library(
         ":amdhipblaslt_plugin",
         ":buffer_comparator_kernel_rocm",
         ":hipfft_plugin",
+        ":make_batch_pointers_kernel_rocm",
         ":miopen_plugin",
         ":rocblas_plugin",
         ":rocm_helpers",
@@ -1098,3 +1099,25 @@ rocm_library(
     ],
     alwayslink = 1,
 )
+
+rocm_library(
+    name = "make_batch_pointers_kernel_rocm",
+    srcs = ["make_batch_pointers_kernel_rocm.cu.cc"],
+    # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"],
+    tags = [
+        "gpu",
+        "rocm-only",
+    ],
+    deps = [
+        ":rocm_platform_id",
+        "//xla/stream_executor:kernel_spec",
+        "//xla/stream_executor/gpu:gpu_kernel_registry",
+        "//xla/stream_executor/gpu:make_batch_pointers_kernel",
+        "//xla/stream_executor/platform:initialize",
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@local_config_rocm//rocm:rocm_headers",
+    ],
+    alwayslink = 1,
+)
diff --git a/third_party/xla/xla/service/gpu/make_batch_pointers.cu.cc b/third_party/xla/xla/stream_executor/rocm/make_batch_pointers_kernel_rocm.cu.cc
similarity index 57%
rename from third_party/xla/xla/service/gpu/make_batch_pointers.cu.cc
rename to third_party/xla/xla/stream_executor/rocm/make_batch_pointers_kernel_rocm.cu.cc
index 344f8ecc214e12..110d2e46e10016 100644
--- a/third_party/xla/xla/service/gpu/make_batch_pointers.cu.cc
+++ b/third_party/xla/xla/stream_executor/rocm/make_batch_pointers_kernel_rocm.cu.cc
@@ -15,7 +15,13 @@ limitations under the License.
 
 #include <cstddef>
 
-namespace xla::gpu {
+#include "absl/base/casts.h"
+#include "xla/stream_executor/gpu/gpu_kernel_registry.h"
+#include "xla/stream_executor/gpu/make_batch_pointers_kernel.h"
+#include "xla/stream_executor/kernel_spec.h"
+#include "xla/stream_executor/rocm/rocm_platform_id.h"
+
+namespace stream_executor::rocm {
 namespace {
 __global__ void MakeBatchPointers(char* base, size_t stride, size_t n,
                                   void** ptrs_out) {
@@ -24,9 +30,14 @@ __global__ void MakeBatchPointers(char* base, size_t stride, size_t n,
   ptrs_out[idx] = base + idx * stride;
 }
 }  // namespace
+}  // namespace stream_executor::rocm
 
-namespace make_batch_pointers {
-void* kernel() { return reinterpret_cast<void*>(MakeBatchPointers); }
-}  // namespace make_batch_pointers
-
-}  // namespace xla::gpu
+GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY(
+    MakeBatchPointersKernelRocm, stream_executor::gpu::MakeBatchPointersKernel,
+    stream_executor::rocm::kROCmPlatformId, ([] {
+      stream_executor::MultiKernelLoaderSpec spec(4);
+      spec.AddInProcessSymbol(
+          absl::bit_cast<void*>(&stream_executor::rocm::MakeBatchPointers),
+          "make_batch_pointers");
+      return spec;
+    }));
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_helpers.cu.cc b/third_party/xla/xla/stream_executor/rocm/rocm_helpers.cu.cc
index ae287aa4025fb2..d836d10a20aff8 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_helpers.cu.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_helpers.cu.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include #include -#include +#include +#include + namespace stream_executor { namespace gpu { @@ -55,27 +57,6 @@ void rocm_Broadcast_fp32(void* stream, float* dst, int dst_stride, int batches, dst_stride, batches, src, size); } -// GPU kernel to populate an array of pointers: -// -// [base + stride * i for i in range(n)]. -// -__global__ void __xla_MakeBatchPointers(char* base, int stride, int n, - void** ptrs_out) { - int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx >= n) return; - ptrs_out[idx] = base + idx * stride; -} - -void rocm_MakeBatchPointers(void* stream, char* base, int stride, int n, - void** ptrs_out) { - const int threads_per_block = 256; - hipLaunchKernelGGL( - __xla_MakeBatchPointers, - dim3((n + threads_per_block - 1) / threads_per_block, 1, 1), - dim3(threads_per_block, 1, 1), 0, (hipStream_t)stream, base, stride, n, - ptrs_out); -} - __device__ float sigmoid(float x) { if (x > 0) return 1. / (1. + __expf(-x)); From fd5597be8c09f7e06b5d0e4780a87bab53c2f03e Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 8 Apr 2025 01:31:30 -0700 Subject: [PATCH 0344/1324] [XLA:GPU][NFC] Extract the whole fusion when `SymbolicTileAnalysis` fails in `NestGemmFusion`. PiperOrigin-RevId: 745031647 --- third_party/xla/xla/service/gpu/transforms/BUILD | 1 + .../xla/xla/service/gpu/transforms/nest_gemm_fusion.cc | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index bd31c186229aa6..b66f9871093b5c 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2350,6 +2350,7 @@ cc_library( "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/stream_executor:device_description", + "//xla/tools:hlo_extractor", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index 776d0c83f3c724..9580087f9490d6 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -61,6 +62,7 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tools/hlo_extractor.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" @@ -225,10 +227,12 @@ absl::StatusOr> FindOutputTileSizesForEpilogue( SymbolicTileAnalysisOrError analysis_or = SymbolicTileAnalysis::AnalyzeComputation(*computation, ctx); if (std::holds_alternative(analysis_or)) { + std::unique_ptr extracted_computation_module = + ExtractModule(computation->FusionInstruction()); return absl::InternalError( absl::StrCat("Failed to analyze the computation (", std::get(analysis_or).Explain(), - "): ", computation->ToString())); + "): ", extracted_computation_module->ToString())); } auto& analysis = std::get(analysis_or); From ece6c59755a498b3d463c141336c453377727d2d Mon Sep 17 00:00:00 2001 From: Ruturaj Vaidya Date: Tue, 8 Apr 2025 01:53:09 -0700 Subject: [PATCH 0345/1324] PR #24533: Fix fp8 tests Imported from GitHub PR https://github.com/openxla/xla/pull/24533 We observed failures in internal CI for FP8 types `F8E4M3FNUZ` and `F8E5M2FNUZ`. These types should not be implicitly upcast during dot operation operand conversion. This patch extends the existing check that prevents conversions between FP8 types by including these two additional variants. @xla-rotation @draganmladjenovic can you please take a look? Copybara import of the project: -- fa874600a85f1baed83c01cc245ceb581f75b078 by Ruturaj4 : Fix fp8 tests Merging this change closes #24533 PiperOrigin-RevId: 745038314 --- .../gpu/transforms/dot_operand_converter.cc | 3 ++- .../transforms/dot_operand_converter_test.cc | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc index d15e6779503045..4f2606d9ca6574 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc @@ -39,7 +39,8 @@ bool DotOperandConverter::InstructionMatchesPattern( } // Exclude conversions between FP8 types. 
PiperOrigin-RevId: 745038314
---
 .../gpu/transforms/dot_operand_converter.cc   |  3 ++-
 .../transforms/dot_operand_converter_test.cc  | 17 +++++++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc
index d15e6779503045..4f2606d9ca6574 100644
--- a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc
+++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc
@@ -39,7 +39,8 @@ bool DotOperandConverter::InstructionMatchesPattern(
   }
 
   // Exclude conversions between FP8 types.
-  absl::flat_hash_set<PrimitiveType> non_converting = {F8E4M3FN, F8E5M2};
+  absl::flat_hash_set<PrimitiveType> non_converting = {F8E4M3FN, F8E5M2,
+                                                       F8E4M3FNUZ, F8E5M2FNUZ};
   if (non_converting.contains(lhs_type) && non_converting.contains(rhs_type)) {
     return false;
   }
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc
index be05b6767abbfd..03992378cc45a5 100644
--- a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc
+++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc
@@ -124,6 +124,23 @@ TEST_F(DotOperandConverterTest, NoConvertFromF8toF8) {
   EXPECT_FALSE(upcasted);
 }
 
+TEST_F(DotOperandConverterTest, NoConvertFromF8FNUZtoF8FNUZ) {
+  absl::string_view module_string = R"(
+  HloModule module
+
+  ENTRY main {
+    p0 = f8e4m3fnuz[2,3]{1,0} parameter(0)
+    p1 = f8e5m2fnuz[3,2]{1,0} parameter(1)
+    ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+     rhs_contracting_dims={0}
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
+                          DotOperandConverter().Run(module.get()));
+  EXPECT_FALSE(upcasted);
+}
+
 TEST_F(DotOperandConverterTest, CompilerOptimizesUsingDotOperandConverter) {
   absl::string_view module_string = R"(
   HloModule module

From 34b772277c6cce743c98f401eb066f0a33ac36ef Mon Sep 17 00:00:00 2001
From: Sergey Kozub
Date: Tue, 8 Apr 2025 01:54:39 -0700
Subject: [PATCH 0346/1324] PR #23493: [NFC] Fix "SyntaxWarning: invalid escape
 sequence"

Imported from GitHub PR https://github.com/openxla/xla/pull/23493

The warning frequently appears during compilation:
.../external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc:227: SyntaxWarning: invalid escape sequence '\.'

Copybara import of the project:

--
aaded2a363907eb207d4165b94c4455223115402 by Sergey Kozub :
[NFC] Fix "SyntaxWarning: invalid escape sequence"

Merging this change closes #23493

PiperOrigin-RevId: 745038828
---
 .../crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl | 2 +-
 .../gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl  | 2 +-
 third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl | 2 +-
 .../crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl | 2 +-
 .../gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl  | 2 +-
 .../gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl         | 2 +-
 6 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
index 35bb5ec2b4228a..50945644019e9f 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
@@ -224,7 +224,7 @@ def InvokeNvcc(argv, log=False):
   # Unfortunately, there are other options that have -c prefix too.
   # So allowing only those look like C/C++ files.
   src_files = [f for f in src_files if
-               re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+               re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
   srcs = ' '.join(src_files)
   out = ' -o ' + out_file[0]
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
index e3c99b8298494a..559a360767017e 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
@@ -163,7 +163,7 @@ def InvokeHipcc(argv, log=False):
   # Unfortunately, there are other options that have -c prefix too.
   # So allowing only those look like C/C++ files.
src_files = [f for f in src_files if - re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] srcs = ' '.join(src_files) out = ' -o ' + out_file[0] diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index e3c99b8298494a..559a360767017e 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -163,7 +163,7 @@ def InvokeHipcc(argv, log=False): # Unfortunately, there are other options that have -c prefix too. # So allowing only those look like C/C++ files. src_files = [f for f in src_files if - re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] srcs = ' '.join(src_files) out = ' -o ' + out_file[0] diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl index eb3a1d8c8ddf02..59d150b82fa34e 100644 --- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl +++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -102,7 +102,7 @@ def InvokeNvcc(argv, log=False): """ src_files = [f for f in argv if - re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] if len(src_files) == 0: raise Error('No source files found for cuda compilation.') diff --git a/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 35bb5ec2b4228a..50945644019e9f 100755 --- a/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -224,7 +224,7 @@ def InvokeNvcc(argv, log=False): # Unfortunately, there are other options that have -c prefix too. # So allowing only those look like C/C++ files. src_files = [f for f in src_files if - re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] srcs = ' '.join(src_files) out = ' -o ' + out_file[0] diff --git a/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index e3c99b8298494a..559a360767017e 100755 --- a/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -163,7 +163,7 @@ def InvokeHipcc(argv, log=False): # Unfortunately, there are other options that have -c prefix too. # So allowing only those look like C/C++ files. 
src_files = [f for f in src_files if - re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] srcs = ' '.join(src_files) out = ' -o ' + out_file[0] diff --git a/third_party/xla/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/xla/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl index eb3a1d8c8ddf02..59d150b82fa34e 100644 --- a/third_party/xla/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl +++ b/third_party/xla/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -102,7 +102,7 @@ def InvokeNvcc(argv, log=False): """ src_files = [f for f in argv if - re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] if len(src_files) == 0: raise Error('No source files found for cuda compilation.') From 7cd06ecbe5964aa2192671fc6b1645afe46ca3b8 Mon Sep 17 00:00:00 2001 From: mmakevic-amd Date: Tue, 8 Apr 2025 01:59:35 -0700 Subject: [PATCH 0347/1324] PR #23988: [ROCm] Fix rocm_stream_test Imported from GitHub PR https://github.com/openxla/xla/pull/23988 The test was failing with: ``` [2024-11-18T10:57:59.388Z] [ RUN ] RocmStreamTest.WaitForEvent [2024-11-18T10:57:59.388Z] external/local_xla/xla/stream_executor/rocm/rocm_stream_test.cc:258: Failure [2024-11-18T10:57:59.388Z] Value of: callback_called [2024-11-18T10:57:59.388Z] Actual: true [2024-11-18T10:57:59.388Z] Expected: false [2024-11-18T10:57:59.388Z] [ FAILED ] RocmStreamTest.WaitForEvent (0 ms) ``` Removing this expect_false check to avoid timing issues with callback Copybara import of the project: -- 1130979d98b182022fffb961dcfa28aa8aada3d9 by Milica Makevic : Remove expect_false check from WaitForEvent test Merging this change closes #23988 PiperOrigin-RevId: 745040464 --- third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc index ce272494b64895..59fe94efa9094f 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc @@ -254,7 +254,6 @@ TEST_F(RocmStreamTest, WaitForEvent) { stream->DoHostCallback([&callback_called]() { callback_called = true; }), IsOk()); - EXPECT_FALSE(callback_called); EXPECT_THAT(stream->RecordEvent(&event), IsOk()); EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); EXPECT_TRUE(callback_called); From 0e2240d6fcf23a3e71bf9489bb1b557cd9bd5548 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 02:02:42 -0700 Subject: [PATCH 0348/1324] Update GraphDef version to 2191. PiperOrigin-RevId: 745041468 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 37debf7e5a16f6..c8bc4044c228d2 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2190 // Updated: 2025/4/7 +#define TF_GRAPH_DEF_VERSION 2191 // Updated: 2025/4/8 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 65075e19e319e09b100e78594e58a74cd1f20298 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 8 Apr 2025 02:02:46 -0700 Subject: [PATCH 0349/1324] compat: Update forward compatibility horizon to 2025-04-08 PiperOrigin-RevId: 745041496 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index efbd58c3b04602..3b77d06cffd2fe 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 7) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 8) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From f6786556f8861710e1abbbaccc8e344c9e512f3b Mon Sep 17 00:00:00 2001 From: Chunyu Jin Date: Tue, 8 Apr 2025 02:46:43 -0700 Subject: [PATCH 0350/1324] PR #24741: [ROCm] a patch for C64 and C128 from HLO to MLIR Imported from GitHub PR https://github.com/openxla/xla/pull/24741 Hi @xla-rotation, would you please kindly help to review this PR? This PR addresses the following issue related to HLO to MLIR for C64 and C128 as we got ``` [ RUN ] DotTests/ParametricDotTest.TestC64/1x23x1_MajorToMinorTF LLVM ERROR: Failed to infer result type(s): "arith.constant"(...) {} : () -> ( ??? ) [ RUN ] DotTests/ParametricDotTest.TestC128/1x23x1_MajorToMinorTF LLVM ERROR: Failed to infer result type(s): "arith.constant"(...) {} : () -> ( ??? ) ``` Copybara import of the project: -- d35eaa73de32d4c7480644807155a1551d1bf6fc by cj401-amd : a patch for C64 and C128 from HLO to MLIR -- 888e6a88ad3c994a001c1d6060de182c456d6430 by cj401-amd : add unit test for the patch of c64 and c128 Merging this change closes #24741 PiperOrigin-RevId: 745054817 --- .../codegen/emitters/elemental_hlo_to_mlir.cc | 24 +++++++- .../emitters/elemental_hlo_to_mlir_test.cc | 58 +++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc index 567736cfa8f1da..6d848e2912f34c 100644 --- a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc @@ -502,6 +502,11 @@ absl::StatusOr EmitMulAdd(Value lhs, Value rhs, Value accumulator, return b.create(accumulator, b.create(lhs, rhs)); } + if (primitive_util::IsComplexType(result_element_type)) { + // Handle complex types (e.g., C64, C128) + Value mul = b.create(accumulator_type, lhs, rhs); + return b.create(accumulator_type, accumulator, mul); + } return b.create(accumulator, b.create(lhs, rhs)); } @@ -518,8 +523,23 @@ absl::StatusOr> EmitDotLoop( const mlir::Type accumulator_type = result_element_type.isBF16() ? b.getF32Type() : result_element_type; - Value accum_init_value = - b.create(b.getZeroAttr(accumulator_type)).getResult(); + Value accum_init_value; + if (auto complex_ty = accumulator_type.dyn_cast()) { + // For complex, build real-zero and imag-zero separately: + mlir::Type element_ty = complex_ty.getElementType(); + + // E.g. 
PiperOrigin-RevId: 745054817
---
 .../codegen/emitters/elemental_hlo_to_mlir.cc | 24 +++++++-
 .../emitters/elemental_hlo_to_mlir_test.cc    | 58 +++++++++++++++++++
 2 files changed, 80 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc
index 567736cfa8f1da..6d848e2912f34c 100644
--- a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc
+++ b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc
@@ -502,6 +502,11 @@ absl::StatusOr<Value> EmitMulAdd(Value lhs, Value rhs, Value accumulator,
     return b.create<arith::AddIOp>(accumulator,
                                    b.create<arith::MulIOp>(lhs, rhs));
   }
+  if (primitive_util::IsComplexType(result_element_type)) {
+    // Handle complex types (e.g., C64, C128)
+    Value mul = b.create<complex::MulOp>(accumulator_type, lhs, rhs);
+    return b.create<complex::AddOp>(accumulator_type, accumulator, mul);
+  }
   return b.create<arith::AddFOp>(accumulator,
                                  b.create<arith::MulFOp>(lhs, rhs));
 }
@@ -518,8 +523,23 @@ absl::StatusOr<SmallVector<Value>> EmitDotLoop(
   const mlir::Type accumulator_type =
       result_element_type.isBF16() ? b.getF32Type() : result_element_type;
-  Value accum_init_value =
-      b.create<arith::ConstantOp>(b.getZeroAttr(accumulator_type)).getResult();
+  Value accum_init_value;
+  if (auto complex_ty = accumulator_type.dyn_cast<mlir::ComplexType>()) {
+    // For complex, build real-zero and imag-zero separately:
+    mlir::Type element_ty = complex_ty.getElementType();
+
+    // E.g. float zero
+    auto real_zero = b.create<arith::ConstantOp>(b.getZeroAttr(element_ty));
+    auto imag_zero = b.create<arith::ConstantOp>(b.getZeroAttr(element_ty));
+
+    // Create a complex from these two scalars
+    accum_init_value =
+        b.create<complex::CreateOp>(complex_ty, real_zero, imag_zero);
+  } else {
+    // For non-complex, just build a float or integer zero directly
+    accum_init_value =
+        b.create<arith::ConstantOp>(b.getZeroAttr(accumulator_type));
+  }
 
   // For convolutions with `batch_group_count` > 1, there is an additional
   // symbol for LHS (group id) - ignore it for RHS.
diff --git a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir_test.cc
index 701814f0440b5f..6ed498fa7dd350 100644
--- a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir_test.cc
+++ b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir_test.cc
@@ -1765,6 +1765,64 @@ TEST_F(ElementalHloToMlirTest, BroadcastSelect) {
   )"));
 }
 
+TEST_F(ElementalHloToMlirTest, DotC64) {
+  TF_EXPECT_OK(Run(
+      R"(
+HloModule c64_dot_test
+
+ENTRY main {
+  p0 = c64[4] parameter(0)
+  p1 = c64[4] parameter(1)
+  dot = c64[] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+  ROOT out = c64[] add(dot, dot)
+}
+      )",
+      R"(
+      // CHECK: func.func private @main_out(
+      // CHECK-SAME: %[[ARG0:.*]]: tensor<4xcomplex<f32>>,
+      // CHECK-SAME: %[[ARG1:.*]]: tensor<4xcomplex<f32>>
+      // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+      // CHECK: %[[INIT:.*]] = complex.create %[[CST0]], %[[CST0]] : complex<f32>
+      // CHECK: %[[DOTRESULT:.*]] = scf.for {{.*}} = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}} = %[[INIT]]) -> (complex<f32>) {
+      // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[ARG0]][{{.*}}]
+      // CHECK: %[[EXTRACTED0:.*]] = tensor.extract %[[ARG1]][{{.*}}]
+      // CHECK: %[[MUL:.*]] = complex.mul %[[EXTRACTED]], %[[EXTRACTED0]]
+      // CHECK: %[[NEXTACC:.*]] = complex.add {{.*}}, %[[MUL]]
+      // CHECK: scf.yield %[[NEXTACC]]
+      // CHECK: %[[OUT:.*]] = complex.add %[[DOTRESULT]], %[[DOTRESULT]]
+      // CHECK: return %[[OUT]]
+      )"));
+}
+
+TEST_F(ElementalHloToMlirTest, DotC128) {
+  TF_EXPECT_OK(Run(
+      R"(
+HloModule c128_dot_test
+
+ENTRY main {
+  p0 = c128[3] parameter(0)
+  p1 = c128[3] parameter(1)
+  dot = c128[] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+  ROOT out = c128[] add(dot, dot)
+}
+      )",
+      R"(
+      // CHECK: func.func private @main_out(
+      // CHECK-SAME: %[[ARG0:.*]]: tensor<3xcomplex<f64>>,
+      // CHECK-SAME: %[[ARG1:.*]]: tensor<3xcomplex<f64>>
+      // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f64
+      // CHECK: %[[INIT:.*]] = complex.create %[[CST0]], %[[CST0]] : complex<f64>
+      // CHECK: %[[DOTRESULT:.*]] = scf.for {{.*}} = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}} = %[[INIT]]) -> (complex<f64>) {
+      // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[ARG0]][{{.*}}]
+      // CHECK: %[[EXTRACTED0:.*]] = tensor.extract %[[ARG1]][{{.*}}]
+      // CHECK: %[[MUL:.*]] = complex.mul %[[EXTRACTED]], %[[EXTRACTED0]]
+      // CHECK: %[[NEXTACC:.*]] = complex.add {{.*}}, %[[MUL]]
+      // CHECK: scf.yield %[[NEXTACC]]
+      // CHECK: %[[OUT:.*]] = complex.add %[[DOTRESULT]], %[[DOTRESULT]]
+      // CHECK: return %[[OUT]]
+      )"));
+}
+
 }  // namespace
 }  // namespace emitters
 }  // namespace xla

From efba2e44df81619a475d5f0037d21294d1af5e78 Mon Sep 17 00:00:00 2001
From: TJ Xu
Date: Tue, 8 Apr 2025 03:02:58 -0700
Subject: [PATCH 0351/1324] PR #24731: [NVIDIA GPU] Assign collective memory to
 both IO of collective instructions when user buffer is enabled

Imported from GitHub PR https://github.com/openxla/xla/pull/24731

Currently, when user buffers are enabled, the buffer assigner doesn't assign
collective memory space to the inputs of collectives that are not in-place
ops, i.e. all-gather or reduce-scatter. This prevents NCCL from enabling
NVLS + user buffers for these collectives.
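The rule itself is small. Below is a self-contained C++ sketch of it;
`HloValueLike`/`HloUseLike` are minimal stand-ins for XLA's
`HloValue`/`HloUse`, and the opcode list is deliberately abbreviated:

```c++
#include <string>
#include <vector>

// Minimal stand-ins for XLA's dataflow classes, for illustration only.
struct HloUseLike { std::string opcode; };
struct HloValueLike {
  std::string producer_opcode;
  std::vector<HloUseLike> uses;
  int color = 0;
};

constexpr int kCollectiveMemorySpaceColor = 1;

bool IsCollectiveOpcode(const std::string& opcode) {
  return opcode == "all-gather-start" || opcode == "reduce-scatter" ||
         opcode == "all-reduce-start";  // abbreviated list
}

// A value now lands in collective memory either when it is produced by a
// collective (the old rule) or when any of its uses feeds one (the new rule),
// which covers the inputs of out-of-place collectives such as all-gather.
void ColorValue(HloValueLike& value) {
  bool feeds_collective = false;
  for (const HloUseLike& use : value.uses) {
    if (IsCollectiveOpcode(use.opcode)) feeds_collective = true;
  }
  if (IsCollectiveOpcode(value.producer_opcode) || feeds_collective) {
    value.color = kCollectiveMemorySpaceColor;
  }
}
```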
Copybara import of the project:

--
efe843b4e90f8850887b3ea97d8bfa32092ef59c by TJ Xu :
Assign collective memory to both IO of collective instructions when user
buffer is enabled

Merging this change closes #24731

PiperOrigin-RevId: 745059570
---
 .../service/gpu/gpu_memory_space_assignment.h | 25 ++++++++---
 .../xla/xla/tests/collective_ops_e2e_test.cc  | 45 +++++++++++++++++++
 2 files changed, 64 insertions(+), 6 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h
index 9dc80eb45a7c64..32d1906c30dc07 100644
--- a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h
+++ b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h
@@ -49,15 +49,28 @@ inline BufferAssigner::Colorer CollectiveColorer() {
         HloOpcode::kCollectivePermuteDone,
         HloOpcode::kAllToAll,
       };
+      auto is_collective_memory_instr = [&](const HloInstruction* instr) {
+        return kSupportedOpcodes->contains(instr->opcode()) ||
+               // opcode or async wrapped opcode is in kSupportedOpcodes.
+               ((instr->opcode() == HloOpcode::kAsyncStart ||
+                 instr->opcode() == HloOpcode::kAsyncDone) &&
+                kSupportedOpcodes->contains(instr->async_wrapped_opcode()));
+      };
+      auto has_collective_memory_in_uses = [&](const HloValue* input_alias) {
+        // If any use is a collective instruction, we must color the value to
+        // use collective memory space.
+        for (auto& use : input_alias->GetUses()) {
+          if (is_collective_memory_instr(use.instruction)) {
+            return true;
+          }
+        }
+        return false;
+      };
       for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
         auto& buffer = alias_analysis->GetBufferContainingValue(*value);
         for (const auto& alias : buffer.values()) {
-          // opcode or async wrapped opcode is in kSupportedOpcodes.
- if (kSupportedOpcodes->contains(alias->instruction()->opcode()) || - ((alias->instruction()->opcode() == HloOpcode::kAsyncStart || - alias->instruction()->opcode() == HloOpcode::kAsyncDone) && - kSupportedOpcodes->contains( - alias->instruction()->async_wrapped_opcode()))) { + if (is_collective_memory_instr(alias->instruction()) || + has_collective_memory_in_uses(alias)) { value->set_color(kCollectiveMemorySpaceColor); } } diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 3fcf93f83c3517..870c383c7961dd 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -3058,5 +3058,50 @@ ENTRY main { } } +TEST_F(CollectiveOpsTestE2E, AllgatherMemspaceWithNcclUserBuffer) { + absl::string_view hlo_string = R"( +HloModule AllgatherMemspaceWithNcclUserBuffer, entry_computation_layout={(bf16[1024,1024]{1,0},bf16[1024,1024]{1,0})->bf16[4096,1024]{1,0}}, num_partitions=4 + +ENTRY main { + Arg_1 = bf16[1024,1024]{1,0} parameter(0) + Arg_2 = bf16[1024,1024]{1,0} parameter(1) + + add = bf16[1024,1024]{1,0} add(Arg_1, Arg_2) + all-gather-start = (bf16[1024,1024]{1,0},bf16[4096,1024]{1,0}) all-gather-start(add), dimensions={0} + all-gather-done = bf16[4096,1024]{1,0} all-gather-done(all-gather-start) + + ROOT add2 = bf16[4096,1024]{1,0} add(all-gather-done, all-gather-done) +} // main +)"; + + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 4; + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } + + HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions); + auto opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_enable_nccl_user_buffers(true); + opts.add_xla_disable_hlo_passes("gpu-convert-async-collectives-to-sync"); + + config.set_debug_options(opts); + config.set_use_spmd_partitioning(false); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string, config)); + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + CreateExecutable(std::move(module), /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN(const HloModule* const executable_module, + test_runner().HloModuleFromWrapped(executable.get())); + HloInstruction* ag_start = + FindInstructions(executable_module, HloOpcode::kAllGatherStart)[0]; + // Both ag and its producer should have collective memory space 1 + EXPECT_EQ(ag_start->shape().tuple_shapes()[1].layout().memory_space(), 1); + EXPECT_EQ(ag_start->operand(0)->shape().layout().memory_space(), 1); +} + } // namespace } // namespace xla From c31c7bbdd82053eacfc95dbb35966202b58fa8cf Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 8 Apr 2025 03:14:25 -0700 Subject: [PATCH 0352/1324] PR #24709: Disable dynamic memcpy fusion annotation. Imported from GitHub PR https://github.com/openxla/xla/pull/24709 This was reported to cause segfaults in the dynamic copy thunk. Disabling the annotation disables the creation of any such thunks. Note: there are no test cases because the tests I wrote before are incomplete (which I hadn't noticed). In `GpuCopyTest`, I intended to test this functionality end-to-end, but I accidentally left out triggering. I'll fix this when re-enabling this. Copybara import of the project: -- c1095c3261d1f6164f23aaeefdcc70abe0dd9a40 by Johannes Reifferscheid : Disable dynamic memcpy fusion annotation. 
This was reported to cause segfaults in the dynamic copy thunk. Disabling
the annotation disables the creation of any such thunks.

Merging this change closes #24709

PiperOrigin-RevId: 745063095
---
 third_party/xla/xla/service/gpu/BUILD                       | 1 -
 third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc | 2 --
 2 files changed, 3 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 8823527f8faf1f..867c9a53949d84 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -1302,7 +1302,6 @@ cc_library(
         "//xla/service:hlo_cost_analysis",
         "//xla/service:pattern_matcher",
         "//xla/service/gpu/transforms:fusion_block_level_rewriter",
-        "//xla/service/gpu/transforms:fusion_dynamic_memcpy_rewriter",
        "//xla/stream_executor:device_description",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
diff --git a/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc
index 71afe0365794fc..28302ad6ffe6aa 100644
--- a/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc
+++ b/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc
@@ -29,7 +29,6 @@ limitations under the License.
 #include "xla/hlo/pass/hlo_pass_pipeline.h"
 #include "xla/layout_util.h"
 #include "xla/service/gpu/transforms/fusion_block_level_rewriter.h"
-#include "xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter.h"
 #include "xla/service/hlo_cost_analysis.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/stream_executor/device_description.h"
@@ -132,7 +131,6 @@ HloPassPipeline FusionDispatchPipeline(
   HloPassPipeline pipeline("fusion-dispatch-pipeline");
   pipeline.AddPass<FusionBlockLevelRewriter>(device_description, shape_size_fn,
                                              std::move(try_rewrite_fusion_if));
-  pipeline.AddPass<FusionDynamicMemcpyRewriter>();
   return pipeline;
 }

From 434d527ff4a182a5f9cd3c4e9b0131730465d7ac Mon Sep 17 00:00:00 2001
From: Alexander Belyaev
Date: Tue, 8 Apr 2025 03:20:16 -0700
Subject: [PATCH 0353/1324] [XLA:GPU][Emitters] Improve compile time by
 limiting the size of subgraphs.

At the moment, we can create subgraphs with O(10^4) ops, which leads to long
compile time. Also, this CL restricts the inliner, so that we don't inline the
functions back to create huge functions again.
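As a rough illustration of the capping idea only (the name
`kMaxOpsPerSubgraph`, the flat op list, and the constant below are
assumptions for the sketch, not taken from this CL):

```c++
#include <cstddef>
#include <vector>

// Hypothetical op handle; the real partitioner works on an HLO computation.
using OpId = int;
constexpr std::size_t kMaxOpsPerSubgraph = 1000;

// Greedily cut the op stream into subgraphs of bounded size, so that no
// single generated function has to hold O(10^4) ops.
std::vector<std::vector<OpId>> PartitionWithSizeCap(
    const std::vector<OpId>& ops) {
  std::vector<std::vector<OpId>> subgraphs;
  for (OpId op : ops) {
    if (subgraphs.empty() || subgraphs.back().size() >= kMaxOpsPerSubgraph) {
      subgraphs.emplace_back();  // start a fresh subgraph once the cap is hit
    }
    subgraphs.back().push_back(op);
  }
  return subgraphs;
}
```

A real partitioner must also respect data dependencies when it cuts; the
sketch only shows the size budget, which is what bounds both subgraph size
and the inliner's ability to recreate huge functions.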
PiperOrigin-RevId: 745064490 --- .../loop/large_loop_slow_compile_time.hlo | 16838 ++++++++++++++++ .../emitters/computation_partitioner.cc | 22 +- .../codegen/emitters/ir/tests/inlining.mlir | 23 +- .../xla/codegen/emitters/ir/xla_dialect.cc | 47 +- 4 files changed, 16896 insertions(+), 34 deletions(-) create mode 100644 third_party/xla/xla/backends/gpu/codegen/emitters/tests/loop/large_loop_slow_compile_time.hlo diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/tests/loop/large_loop_slow_compile_time.hlo b/third_party/xla/xla/backends/gpu/codegen/emitters/tests/loop/large_loop_slow_compile_time.hlo new file mode 100644 index 00000000000000..3f29d4a320890b --- /dev/null +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/tests/loop/large_loop_slow_compile_time.hlo @@ -0,0 +1,16838 @@ +// RUN: fusion_to_mlir %s | FileCheck %s --check-prefix=CHECK-PARTITIONED-HLO +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s + +HloModule m + +%fused_multiply { + %constant_84353_39 = f32[] constant(-1.99999988) + %broadcast.244407.384 = f32[1280,1280]{1,0} broadcast(%constant_84353_39), dimensions={} + %constant_84358_39 = f32[] constant(-0.954499722) + %broadcast.244408.1024 = f32[1280,1280]{1,0} broadcast(%constant_84358_39), dimensions={} + %param_0.210024 = u64[1280]{0} parameter(0) + %broadcast.244410.17671 = u64[1280,1280]{1,0} broadcast(%param_0.210024), dimensions={0} + %iota.3149.17671 = u64[1280,1280]{1,0} iota(), iota_dimension=1 + %add.244064.17671 = u64[1280,1280]{1,0} add(%broadcast.244410.17671, %iota.3149.17671) + %constant_39483_353 = u64[] constant(32) + %broadcast.244411.4867 = u64[1280,1280]{1,0} broadcast(%constant_39483_353), dimensions={} + %shift-right-logical.113832.4867 = u64[1280,1280]{1,0} shift-right-logical(%add.244064.17671, %broadcast.244411.4867) + %convert.3610.4865 = u32[1280,1280]{1,0} convert(%shift-right-logical.113832.4867) + %constant_158097_1 = u32[] constant(2326687384) + %broadcast.244412.44 = u32[1280,1280]{1,0} broadcast(%constant_158097_1), dimensions={} + %add.244065.37 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.244412.44) + %convert.3611.12801 = u32[1280,1280]{1,0} convert(%add.244064.17671) + %constant_158104_1 = u32[] constant(3182756316) + %broadcast.244413.113 = u32[1280,1280]{1,0} broadcast(%constant_158104_1), dimensions={} + %add.244067.99 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.244413.113) + %add.244068.35 = u32[1280,1280]{1,0} add(%add.244065.37, %add.244067.99) + %constant_39484_252 = u32[] constant(13) + %broadcast.244414.6272 = u32[1280,1280]{1,0} broadcast(%constant_39484_252), dimensions={} + %shift-left.107747.31 = u32[1280,1280]{1,0} shift-left(%add.244067.99, %broadcast.244414.6272) + %constant_39485_252 = u32[] constant(19) + %broadcast.244415.6016 = u32[1280,1280]{1,0} broadcast(%constant_39485_252), dimensions={} + %shift-right-logical.113833.29 = u32[1280,1280]{1,0} shift-right-logical(%add.244067.99, %broadcast.244415.6016) + %or.113347.29 = u32[1280,1280]{1,0} or(%shift-left.107747.31, %shift-right-logical.113833.29) + %xor.119908.27 = u32[1280,1280]{1,0} xor(%add.244068.35, %or.113347.29) + %add.244069.5 = u32[1280,1280]{1,0} add(%add.244068.35, %xor.119908.27) + %constant_39486_382 = u32[] constant(15) + %broadcast.244416.5760 = u32[1280,1280]{1,0} broadcast(%constant_39486_382), dimensions={} + %shift-left.107748.9 = u32[1280,1280]{1,0} shift-left(%xor.119908.27, %broadcast.244416.5760) + %constant_39487_382 = 
u32[] constant(17) + %broadcast.244417.5760 = u32[1280,1280]{1,0} broadcast(%constant_39487_382), dimensions={} + %shift-right-logical.113834.9 = u32[1280,1280]{1,0} shift-right-logical(%xor.119908.27, %broadcast.244417.5760) + %or.113348.7 = u32[1280,1280]{1,0} or(%shift-left.107748.9, %shift-right-logical.113834.9) + %xor.119909.5 = u32[1280,1280]{1,0} xor(%add.244069.5, %or.113348.7) + %add.244070.3 = u32[1280,1280]{1,0} add(%add.244069.5, %xor.119909.5) + %constant_39488_447 = u32[] constant(26) + %broadcast.244418.4352 = u32[1280,1280]{1,0} broadcast(%constant_39488_447), dimensions={} + %shift-left.107749.5 = u32[1280,1280]{1,0} shift-left(%xor.119909.5, %broadcast.244418.4352) + %constant_39489_447 = u32[] constant(6) + %broadcast.244419.4352 = u32[1280,1280]{1,0} broadcast(%constant_39489_447), dimensions={} + %shift-right-logical.113835.5 = u32[1280,1280]{1,0} shift-right-logical(%xor.119909.5, %broadcast.244419.4352) + %or.113349.3 = u32[1280,1280]{1,0} or(%shift-left.107749.5, %shift-right-logical.113835.5) + %xor.119911.3 = u32[1280,1280]{1,0} xor(%add.244070.3, %or.113349.3) + %add.244072.3 = u32[1280,1280]{1,0} add(%add.244070.3, %xor.119911.3) + %add.244073.7 = u32[1280,1280]{1,0} add(%add.244072.3, %broadcast.244413.113) + %shift-left.107751.5 = u32[1280,1280]{1,0} shift-left(%xor.119911.3, %broadcast.244419.4352) + %shift-right-logical.113836.5 = u32[1280,1280]{1,0} shift-right-logical(%xor.119911.3, %broadcast.244418.4352) + %or.113351.3 = u32[1280,1280]{1,0} or(%shift-left.107751.5, %shift-right-logical.113836.5) + %xor.119912.3 = u32[1280,1280]{1,0} xor(%add.244072.3, %or.113351.3) + %constant_217763_1 = u32[] constant(751465631) + %broadcast.244425.5 = u32[1280,1280]{1,0} broadcast(%constant_217763_1), dimensions={} + %add.244074.5 = u32[1280,1280]{1,0} add(%xor.119912.3, %broadcast.244425.5) + %add.244075.5 = u32[1280,1280]{1,0} add(%add.244073.7, %add.244074.5) + %shift-left.107752.9 = u32[1280,1280]{1,0} shift-left(%add.244074.5, %broadcast.244417.5760) + %shift-right-logical.113838.9 = u32[1280,1280]{1,0} shift-right-logical(%add.244074.5, %broadcast.244416.5760) + %or.113352.7 = u32[1280,1280]{1,0} or(%shift-left.107752.9, %shift-right-logical.113838.9) + %xor.119913.5 = u32[1280,1280]{1,0} xor(%add.244075.5, %or.113352.7) + %add.244077.3 = u32[1280,1280]{1,0} add(%add.244075.5, %xor.119913.5) + %constant_39492_187 = u32[] constant(29) + %broadcast.244428.2304 = u32[1280,1280]{1,0} broadcast(%constant_39492_187), dimensions={} + %shift-left.107753.9 = u32[1280,1280]{1,0} shift-left(%xor.119913.5, %broadcast.244428.2304) + %constant_39493_187 = u32[] constant(3) + %broadcast.244429.2304 = u32[1280,1280]{1,0} broadcast(%constant_39493_187), dimensions={} + %shift-right-logical.113839.9 = u32[1280,1280]{1,0} shift-right-logical(%xor.119913.5, %broadcast.244429.2304) + %or.113353.7 = u32[1280,1280]{1,0} or(%shift-left.107753.9, %shift-right-logical.113839.9) + %xor.119914.5 = u32[1280,1280]{1,0} xor(%add.244077.3, %or.113353.7) + %add.244078.3 = u32[1280,1280]{1,0} add(%add.244077.3, %xor.119914.5) + %constant_39494_317 = u32[] constant(16) + %broadcast.244430.4608 = u32[1280,1280]{1,0} broadcast(%constant_39494_317), dimensions={} + %shift-left.107754.9 = u32[1280,1280]{1,0} shift-left(%xor.119914.5, %broadcast.244430.4608) + %shift-right-logical.113840.9 = u32[1280,1280]{1,0} shift-right-logical(%xor.119914.5, %broadcast.244430.4608) + %or.113354.7 = u32[1280,1280]{1,0} or(%shift-left.107754.9, %shift-right-logical.113840.9) + %xor.119916.5 = u32[1280,1280]{1,0} 
xor(%add.244078.3, %or.113354.7) + %add.244079.3 = u32[1280,1280]{1,0} add(%add.244078.3, %xor.119916.5) + %constant_158106_1 = u32[] constant(751465630) + %broadcast.244432.24 = u32[1280,1280]{1,0} broadcast(%constant_158106_1), dimensions={} + %add.244080.7 = u32[1280,1280]{1,0} add(%add.244079.3, %broadcast.244432.24) + %constant_39495_187 = u32[] constant(24) + %broadcast.244433.2816 = u32[1280,1280]{1,0} broadcast(%constant_39495_187), dimensions={} + %shift-left.107755.11 = u32[1280,1280]{1,0} shift-left(%xor.119916.5, %broadcast.244433.2816) + %constant_39496_187 = u32[] constant(8) + %broadcast.244434.2816 = u32[1280,1280]{1,0} broadcast(%constant_39496_187), dimensions={} + %shift-right-logical.113841.11 = u32[1280,1280]{1,0} shift-right-logical(%xor.119916.5, %broadcast.244434.2816) + %or.113356.9 = u32[1280,1280]{1,0} or(%shift-left.107755.11, %shift-right-logical.113841.11) + %xor.119917.7 = u32[1280,1280]{1,0} xor(%add.244079.3, %or.113356.9) + %constant_217764_1 = u32[] constant(2326687386) + %broadcast.244435.5 = u32[1280,1280]{1,0} broadcast(%constant_217764_1), dimensions={} + %add.244081.5 = u32[1280,1280]{1,0} add(%xor.119917.7, %broadcast.244435.5) + %add.244083.5 = u32[1280,1280]{1,0} add(%add.244080.7, %add.244081.5) + %shift-left.107756.9 = u32[1280,1280]{1,0} shift-left(%add.244081.5, %broadcast.244414.6272) + %shift-right-logical.113842.9 = u32[1280,1280]{1,0} shift-right-logical(%add.244081.5, %broadcast.244415.6016) + %or.113357.7 = u32[1280,1280]{1,0} or(%shift-left.107756.9, %shift-right-logical.113842.9) + %xor.119918.5 = u32[1280,1280]{1,0} xor(%add.244083.5, %or.113357.7) + %add.244087.3 = u32[1280,1280]{1,0} add(%add.244083.5, %xor.119918.5) + %shift-left.107757.9 = u32[1280,1280]{1,0} shift-left(%xor.119918.5, %broadcast.244416.5760) + %shift-right-logical.113843.9 = u32[1280,1280]{1,0} shift-right-logical(%xor.119918.5, %broadcast.244417.5760) + %or.113358.7 = u32[1280,1280]{1,0} or(%shift-left.107757.9, %shift-right-logical.113843.9) + %xor.119919.5 = u32[1280,1280]{1,0} xor(%add.244087.3, %or.113358.7) + %add.244088.3 = u32[1280,1280]{1,0} add(%add.244087.3, %xor.119919.5) + %shift-left.107758.7 = u32[1280,1280]{1,0} shift-left(%xor.119919.5, %broadcast.244418.4352) + %shift-right-logical.113844.7 = u32[1280,1280]{1,0} shift-right-logical(%xor.119919.5, %broadcast.244419.4352) + %or.113359.5 = u32[1280,1280]{1,0} or(%shift-left.107758.7, %shift-right-logical.113844.7) + %xor.119921.3 = u32[1280,1280]{1,0} xor(%add.244088.3, %or.113359.5) + %add.244089.3 = u32[1280,1280]{1,0} add(%add.244088.3, %xor.119921.3) + %add.244090.7 = u32[1280,1280]{1,0} add(%add.244089.3, %broadcast.244412.44) + %shift-left.107759.7 = u32[1280,1280]{1,0} shift-left(%xor.119921.3, %broadcast.244419.4352) + %shift-right-logical.113845.7 = u32[1280,1280]{1,0} shift-right-logical(%xor.119921.3, %broadcast.244418.4352) + %or.113360.5 = u32[1280,1280]{1,0} or(%shift-left.107759.7, %shift-right-logical.113845.7) + %xor.119922.3 = u32[1280,1280]{1,0} xor(%add.244089.3, %or.113360.5) + %constant_217765_1 = u32[] constant(3182756319) + %broadcast.244445.5 = u32[1280,1280]{1,0} broadcast(%constant_217765_1), dimensions={} + %add.244092.5 = u32[1280,1280]{1,0} add(%xor.119922.3, %broadcast.244445.5) + %add.244093.5 = u32[1280,1280]{1,0} add(%add.244090.7, %add.244092.5) + %shift-left.107761.9 = u32[1280,1280]{1,0} shift-left(%add.244092.5, %broadcast.244417.5760) + %shift-right-logical.113846.9 = u32[1280,1280]{1,0} shift-right-logical(%add.244092.5, %broadcast.244416.5760) + %or.113361.7 
= u32[1280,1280]{1,0} or(%shift-left.107761.9, %shift-right-logical.113846.9) + %xor.119923.5 = u32[1280,1280]{1,0} xor(%add.244093.5, %or.113361.7) + %add.244094.3 = u32[1280,1280]{1,0} add(%add.244093.5, %xor.119923.5) + %shift-left.107762.9 = u32[1280,1280]{1,0} shift-left(%xor.119923.5, %broadcast.244428.2304) + %shift-right-logical.113847.9 = u32[1280,1280]{1,0} shift-right-logical(%xor.119923.5, %broadcast.244429.2304) + %or.113362.7 = u32[1280,1280]{1,0} or(%shift-left.107762.9, %shift-right-logical.113847.9) + %xor.119924.5 = u32[1280,1280]{1,0} xor(%add.244094.3, %or.113362.7) + %add.244095.3 = u32[1280,1280]{1,0} add(%add.244094.3, %xor.119924.5) + %shift-left.107763.9 = u32[1280,1280]{1,0} shift-left(%xor.119924.5, %broadcast.244430.4608) + %shift-right-logical.113848.9 = u32[1280,1280]{1,0} shift-right-logical(%xor.119924.5, %broadcast.244430.4608) + %or.113363.7 = u32[1280,1280]{1,0} or(%shift-left.107763.9, %shift-right-logical.113848.9) + %xor.119925.5 = u32[1280,1280]{1,0} xor(%add.244095.3, %or.113363.7) + %add.244097.3 = u32[1280,1280]{1,0} add(%add.244095.3, %xor.119925.5) + %add.244098.7 = u32[1280,1280]{1,0} add(%add.244097.3, %broadcast.244413.113) + %shift-left.107764.11 = u32[1280,1280]{1,0} shift-left(%xor.119925.5, %broadcast.244433.2816) + %shift-right-logical.113849.11 = u32[1280,1280]{1,0} shift-right-logical(%xor.119925.5, %broadcast.244434.2816) + %or.113364.9 = u32[1280,1280]{1,0} or(%shift-left.107764.11, %shift-right-logical.113849.11) + %xor.119926.7 = u32[1280,1280]{1,0} xor(%add.244097.3, %or.113364.9) + %constant_217766_1 = u32[] constant(751465634) + %broadcast.244457.5 = u32[1280,1280]{1,0} broadcast(%constant_217766_1), dimensions={} + %add.244099.5 = u32[1280,1280]{1,0} add(%xor.119926.7, %broadcast.244457.5) + %add.244100.5 = u32[1280,1280]{1,0} add(%add.244098.7, %add.244099.5) + %shift-left.107766.9 = u32[1280,1280]{1,0} shift-left(%add.244099.5, %broadcast.244414.6272) + %shift-right-logical.113850.9 = u32[1280,1280]{1,0} shift-right-logical(%add.244099.5, %broadcast.244415.6016) + %or.113366.7 = u32[1280,1280]{1,0} or(%shift-left.107766.9, %shift-right-logical.113850.9) + %xor.119927.5 = u32[1280,1280]{1,0} xor(%add.244100.5, %or.113366.7) + %add.244102.3 = u32[1280,1280]{1,0} add(%add.244100.5, %xor.119927.5) + %shift-left.107767.9 = u32[1280,1280]{1,0} shift-left(%xor.119927.5, %broadcast.244416.5760) + %shift-right-logical.113851.9 = u32[1280,1280]{1,0} shift-right-logical(%xor.119927.5, %broadcast.244417.5760) + %or.113367.7 = u32[1280,1280]{1,0} or(%shift-left.107767.9, %shift-right-logical.113851.9) + %xor.119928.5 = u32[1280,1280]{1,0} xor(%add.244102.3, %or.113367.7) + %add.244103.3 = u32[1280,1280]{1,0} add(%add.244102.3, %xor.119928.5) + %shift-left.107768.5 = u32[1280,1280]{1,0} shift-left(%xor.119928.5, %broadcast.244418.4352) + %shift-right-logical.113852.5 = u32[1280,1280]{1,0} shift-right-logical(%xor.119928.5, %broadcast.244419.4352) + %or.113368.3 = u32[1280,1280]{1,0} or(%shift-left.107768.5, %shift-right-logical.113852.5) + %xor.119929.3 = u32[1280,1280]{1,0} xor(%add.244103.3, %or.113368.3) + %add.244104.3 = u32[1280,1280]{1,0} add(%add.244103.3, %xor.119929.3) + %add.244105.17 = u32[1280,1280]{1,0} add(%add.244104.3, %broadcast.244432.24) + %shift-left.107769.5 = u32[1280,1280]{1,0} shift-left(%xor.119929.3, %broadcast.244419.4352) + %shift-right-logical.113853.5 = u32[1280,1280]{1,0} shift-right-logical(%xor.119929.3, %broadcast.244418.4352) + %or.113369.3 = u32[1280,1280]{1,0} or(%shift-left.107769.5, 
%shift-right-logical.113853.5) + %xor.119931.15 = u32[1280,1280]{1,0} xor(%add.244104.3, %or.113369.3) + %constant_217767_1 = u32[] constant(2326687389) + %broadcast.244467.19 = u32[1280,1280]{1,0} broadcast(%constant_217767_1), dimensions={} + %add.244106.19 = u32[1280,1280]{1,0} add(%xor.119931.15, %broadcast.244467.19) + %xor.119932.17 = u32[1280,1280]{1,0} xor(%add.244105.17, %add.244106.19) + %constant_39500_122 = u32[] constant(9) + %broadcast.244468.1920 = u32[1280,1280]{1,0} broadcast(%constant_39500_122), dimensions={} + %shift-right-logical.113854.15 = u32[1280,1280]{1,0} shift-right-logical(%xor.119932.17, %broadcast.244468.1920) + %constant_39501_122 = u32[] constant(1065353216) + %broadcast.244469.1664 = u32[1280,1280]{1,0} broadcast(%constant_39501_122), dimensions={} + %or.113371.13 = u32[1280,1280]{1,0} or(%shift-right-logical.113854.15, %broadcast.244469.1664) + %bitcast-convert.5669.11 = f32[1280,1280]{1,0} bitcast-convert(%or.113371.13) + %constant_84831_250 = f32[] constant(-1) + %broadcast.244470.1152 = f32[1280,1280]{1,0} broadcast(%constant_84831_250), dimensions={} + %add.244108.9 = f32[1280,1280]{1,0} add(%bitcast-convert.5669.11, %broadcast.244470.1152) + %constant_84836_39 = f32[] constant(1.90899944) + %broadcast.244471.896 = f32[1280,1280]{1,0} broadcast(%constant_84836_39), dimensions={} + %multiply.25583.7 = f32[1280,1280]{1,0} multiply(%add.244108.9, %broadcast.244471.896) + %add.244112.5 = f32[1280,1280]{1,0} add(%multiply.25583.7, %broadcast.244408.1024) + %maximum.3601.3 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.244112.5) + %abs.1483.3 = f32[1280,1280]{1,0} abs(%maximum.3601.3) + %constant_23_60 = f32[] constant(1) + %broadcast.244475.384 = f32[1280,1280]{1,0} broadcast(%constant_23_60), dimensions={} + %compare.7114.3 = pred[1280,1280]{1,0} compare(%abs.1483.3, %broadcast.244475.384), direction=EQ + %constant_39504_52 = f32[] constant(inf) + %broadcast.244476.1152 = f32[1280,1280]{1,0} broadcast(%constant_39504_52), dimensions={} + %multiply.25584.9 = f32[1280,1280]{1,0} multiply(%maximum.3601.3, %broadcast.244476.1152) + %negate.4471.5 = f32[1280,1280]{1,0} negate(%maximum.3601.3) + %multiply.25585.5 = f32[1280,1280]{1,0} multiply(%maximum.3601.3, %negate.4471.5) + %log-plus-one.1483.3 = f32[1280,1280]{1,0} log-plus-one(%multiply.25585.5) + %negate.4472.4 = f32[1280,1280]{1,0} negate(%log-plus-one.1483.3) + %constant_39505_52 = f32[] constant(5) + %broadcast.244477.384 = f32[1280,1280]{1,0} broadcast(%constant_39505_52), dimensions={} + %compare.7115.3 = pred[1280,1280]{1,0} compare(%negate.4472.4, %broadcast.244477.384), direction=LT + %constant_39506_52 = f32[] constant(1.50140941) + %broadcast.244478.896 = f32[1280,1280]{1,0} broadcast(%constant_39506_52), dimensions={} + %constant_39507_52 = f32[] constant(2.83297682) + %broadcast.244479.896 = f32[1280,1280]{1,0} broadcast(%constant_39507_52), dimensions={} + %select.20357.7 = f32[1280,1280]{1,0} select(%compare.7115.3, %broadcast.244478.896, %broadcast.244479.896) + %constant_39508_52 = f32[] constant(0.246640727) + %broadcast.244480.1408 = f32[1280,1280]{1,0} broadcast(%constant_39508_52), dimensions={} + %constant_39509_52 = f32[] constant(1.00167406) + %broadcast.244481.1408 = f32[1280,1280]{1,0} broadcast(%constant_39509_52), dimensions={} + %select.20358.11 = f32[1280,1280]{1,0} select(%compare.7115.3, %broadcast.244480.1408, %broadcast.244481.1408) + %constant_39510_52 = f32[] constant(-0.00417768164) + %broadcast.244482.640 = f32[1280,1280]{1,0} 
broadcast(%constant_39510_52), dimensions={} + %constant_39511_52 = f32[] constant(0.00943887047) + %broadcast.244483.640 = f32[1280,1280]{1,0} broadcast(%constant_39511_52), dimensions={} + %select.20359.5 = f32[1280,1280]{1,0} select(%compare.7115.3, %broadcast.244482.640, %broadcast.244483.640) + %constant_39512_52 = f32[] constant(-0.00125372503) + %broadcast.244484.640 = f32[1280,1280]{1,0} broadcast(%constant_39512_52), dimensions={} + %constant_39513_52 = f32[] constant(-0.0076224613) + %broadcast.244485.640 = f32[1280,1280]{1,0} broadcast(%constant_39513_52), dimensions={} + %select.20360.5 = f32[1280,1280]{1,0} select(%compare.7115.3, %broadcast.244484.640, %broadcast.244485.640) + %constant_39514_52 = f32[] constant(0.00021858087) + %broadcast.244486.384 = f32[1280,1280]{1,0} broadcast(%constant_39514_52), dimensions={} + %constant_39515_52 = f32[] constant(0.00573950773) + %broadcast.244487.384 = f32[1280,1280]{1,0} broadcast(%constant_39515_52), dimensions={} + %select.20361.3 = f32[1280,1280]{1,0} select(%compare.7115.3, %broadcast.244486.384, %broadcast.244487.384) + %constant_39516_52 = f32[] constant(-4.39150654e-06) + %broadcast.244488.384 = f32[1280,1280]{1,0} broadcast(%constant_39516_52), dimensions={} + %constant_39517_52 = f32[] constant(-0.00367342844) + %broadcast.244489.384 = f32[1280,1280]{1,0} broadcast(%constant_39517_52), dimensions={} + %select.20362.3 = f32[1280,1280]{1,0} select(%compare.7115.3, %broadcast.244488.384, %broadcast.244489.384) + %constant_39518_52 = f32[] constant(-3.5233877e-06) + %broadcast.244490.384 = f32[1280,1280]{1,0} broadcast(%constant_39518_52), dimensions={} + %constant_39519_52 = f32[] constant(0.00134934322) + %broadcast.244491.384 = f32[1280,1280]{1,0} broadcast(%constant_39519_52), dimensions={} + %select.20363.3 = f32[1280,1280]{1,0} select(%compare.7115.3, %broadcast.244490.384, %broadcast.244491.384) + %constant_39520_52 = f32[] constant(3.43273939e-07) + %broadcast.244492.384 = f32[1280,1280]{1,0} broadcast(%constant_39520_52), dimensions={} + %constant_39521_52 = f32[] constant(0.000100950558) + %broadcast.244493.384 = f32[1280,1280]{1,0} broadcast(%constant_39521_52), dimensions={} + %select.20364.3 = f32[1280,1280]{1,0} select(%compare.7115.3, %broadcast.244492.384, %broadcast.244493.384) + %constant_39522_52 = f32[] constant(2.81022636e-08) + %broadcast.244494.384 = f32[1280,1280]{1,0} broadcast(%constant_39522_52), dimensions={} + %constant_39523_52 = f32[] constant(-0.000200214257) + %broadcast.244495.384 = f32[1280,1280]{1,0} broadcast(%constant_39523_52), dimensions={} + %select.20365.3 = f32[1280,1280]{1,0} select(%compare.7115.3, %broadcast.244494.384, %broadcast.244495.384) + %constant_84838_52 = f32[] constant(-2.5) + %broadcast.244496.640 = f32[1280,1280]{1,0} broadcast(%constant_84838_52), dimensions={} + %add.244113.5 = f32[1280,1280]{1,0} add(%negate.4472.4, %broadcast.244496.640) + %sqrt.1483.5 = f32[1280,1280]{1,0} sqrt(%negate.4472.4) + %constant_84839_52 = f32[] constant(-3) + %broadcast.244498.640 = f32[1280,1280]{1,0} broadcast(%constant_84839_52), dimensions={} + %add.244114.5 = f32[1280,1280]{1,0} add(%sqrt.1483.5, %broadcast.244498.640) + %select.20366.3 = f32[1280,1280]{1,0} select(%compare.7115.3, %add.244113.5, %add.244114.5) + %multiply.25586.1 = f32[1280,1280]{1,0} multiply(%select.20365.3, %select.20366.3) + %add.244115.1 = f32[1280,1280]{1,0} add(%select.20364.3, %multiply.25586.1) + %multiply.25587.1 = f32[1280,1280]{1,0} multiply(%add.244115.1, %select.20366.3) + %add.244117.1 = 
f32[1280,1280]{1,0} add(%select.20363.3, %multiply.25587.1) + %multiply.25588.1 = f32[1280,1280]{1,0} multiply(%add.244117.1, %select.20366.3) + %add.244118.1 = f32[1280,1280]{1,0} add(%select.20362.3, %multiply.25588.1) + %multiply.25589.1 = f32[1280,1280]{1,0} multiply(%add.244118.1, %select.20366.3) + %add.244119.1 = f32[1280,1280]{1,0} add(%select.20361.3, %multiply.25589.1) + %multiply.25590.1 = f32[1280,1280]{1,0} multiply(%add.244119.1, %select.20366.3) + %add.244120.3 = f32[1280,1280]{1,0} add(%select.20360.5, %multiply.25590.1) + %multiply.25591.1 = f32[1280,1280]{1,0} multiply(%add.244120.3, %select.20366.3) + %add.244122.3 = f32[1280,1280]{1,0} add(%select.20359.5, %multiply.25591.1) + %multiply.25592.7 = f32[1280,1280]{1,0} multiply(%add.244122.3, %select.20366.3) + %add.244123.9 = f32[1280,1280]{1,0} add(%select.20358.11, %multiply.25592.7) + %multiply.25593.7 = f32[1280,1280]{1,0} multiply(%add.244123.9, %select.20366.3) + %add.244124.7 = f32[1280,1280]{1,0} add(%select.20357.7, %multiply.25593.7) + %multiply.25594.7 = f32[1280,1280]{1,0} multiply(%add.244124.7, %maximum.3601.3) + %select.20367.7 = f32[1280,1280]{1,0} select(%compare.7114.3, %multiply.25584.9, %multiply.25594.7) + %constant_39526_43 = f32[] constant(1.41421354) + %broadcast.244500.640 = f32[1280,1280]{1,0} broadcast(%constant_39526_43), dimensions={} + %multiply.25595.5 = f32[1280,1280]{1,0} multiply(%select.20367.7, %broadcast.244500.640) + %constant_84330_39 = f32[] constant(1.99999988) + %broadcast.244501.384 = f32[1280,1280]{1,0} broadcast(%constant_84330_39), dimensions={} + %clamp.1127.3 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.25595.5, %broadcast.244501.384) + %constant_106_2 = f32[] constant(0.0317758434) + %broadcast.244502.1 = f32[1280,1280]{1,0} broadcast(%constant_106_2), dimensions={} + %multiply.25596.1 = f32[1280,1280]{1,0} multiply(%clamp.1127.3, %broadcast.244502.1) + %constant_158523_1_clone_1 = u32[] constant(2951773949) + %broadcast.244607.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_158523_1_clone_1), dimensions={} + %add.244174.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.244607.44.clone.1) + %constant_158531_1_clone_1 = u32[] constant(1662691728) + %broadcast.244608.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_158531_1_clone_1), dimensions={} + %add.244175.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.244608.113.clone.1) + %add.244176.35.clone.1 = u32[1280,1280]{1,0} add(%add.244174.37.clone.1, %add.244175.99.clone.1) + %shift-left.107794.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.244175.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.113876.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.244175.99.clone.1, %broadcast.244415.6016) + %or.113397.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107794.31.clone.1, %shift-right-logical.113876.29.clone.1) + %xor.119958.27.clone.1 = u32[1280,1280]{1,0} xor(%add.244176.35.clone.1, %or.113397.29.clone.1) + %add.244177.5.clone.1 = u32[1280,1280]{1,0} add(%add.244176.35.clone.1, %xor.119958.27.clone.1) + %shift-left.107796.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119958.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.113877.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119958.27.clone.1, %broadcast.244417.5760) + %or.113398.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107796.9.clone.1, %shift-right-logical.113877.9.clone.1) + %xor.119959.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244177.5.clone.1, 
%or.113398.7.clone.1) + %add.244178.3.clone.1 = u32[1280,1280]{1,0} add(%add.244177.5.clone.1, %xor.119959.5.clone.1) + %shift-left.107797.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119959.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.113878.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119959.5.clone.1, %broadcast.244419.4352) + %or.113399.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107797.5.clone.1, %shift-right-logical.113878.5.clone.1) + %xor.119961.3.clone.1 = u32[1280,1280]{1,0} xor(%add.244178.3.clone.1, %or.113399.3.clone.1) + %add.244179.3.clone.1 = u32[1280,1280]{1,0} add(%add.244178.3.clone.1, %xor.119961.3.clone.1) + %add.244180.7.clone.1 = u32[1280,1280]{1,0} add(%add.244179.3.clone.1, %broadcast.244608.113.clone.1) + %shift-left.107798.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119961.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.113879.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119961.3.clone.1, %broadcast.244418.4352) + %or.113401.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107798.5.clone.1, %shift-right-logical.113879.5.clone.1) + %xor.119962.3.clone.1 = u32[1280,1280]{1,0} xor(%add.244179.3.clone.1, %or.113401.3.clone.1) + %constant_217773_1_clone_1 = u32[] constant(3611020472) + %broadcast.244618.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217773_1_clone_1), dimensions={} + %add.244181.5.clone.1 = u32[1280,1280]{1,0} add(%xor.119962.3.clone.1, %broadcast.244618.5.clone.1) + %add.244182.5.clone.1 = u32[1280,1280]{1,0} add(%add.244180.7.clone.1, %add.244181.5.clone.1) + %shift-left.107799.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.244181.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.113880.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.244181.5.clone.1, %broadcast.244416.5760) + %or.113402.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107799.9.clone.1, %shift-right-logical.113880.9.clone.1) + %xor.119963.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244182.5.clone.1, %or.113402.7.clone.1) + %add.244184.3.clone.1 = u32[1280,1280]{1,0} add(%add.244182.5.clone.1, %xor.119963.5.clone.1) + %shift-left.107801.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119963.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.113881.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119963.5.clone.1, %broadcast.244429.2304) + %or.113403.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107801.9.clone.1, %shift-right-logical.113881.9.clone.1) + %xor.119964.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244184.3.clone.1, %or.113403.7.clone.1) + %add.244185.3.clone.1 = u32[1280,1280]{1,0} add(%add.244184.3.clone.1, %xor.119964.5.clone.1) + %shift-left.107802.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119964.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.113882.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119964.5.clone.1, %broadcast.244430.4608) + %or.113404.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107802.9.clone.1, %shift-right-logical.113882.9.clone.1) + %xor.119966.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244185.3.clone.1, %or.113404.7.clone.1) + %add.244186.3.clone.1 = u32[1280,1280]{1,0} add(%add.244185.3.clone.1, %xor.119966.5.clone.1) + %constant_158541_1_clone_1 = u32[] constant(3611020471) + %broadcast.244625.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_158541_1_clone_1), dimensions={} + %add.244187.7.clone.1 = u32[1280,1280]{1,0} add(%add.244186.3.clone.1, %broadcast.244625.24.clone.1) + %shift-left.107803.11.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.119966.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.113883.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119966.5.clone.1, %broadcast.244434.2816) + %or.113406.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107803.11.clone.1, %shift-right-logical.113883.11.clone.1) + %xor.119967.7.clone.1 = u32[1280,1280]{1,0} xor(%add.244186.3.clone.1, %or.113406.9.clone.1) + %constant_217774_1_clone_1 = u32[] constant(2951773951) + %broadcast.244628.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217774_1_clone_1), dimensions={} + %add.244188.5.clone.1 = u32[1280,1280]{1,0} add(%xor.119967.7.clone.1, %broadcast.244628.5.clone.1) + %add.244189.5.clone.1 = u32[1280,1280]{1,0} add(%add.244187.7.clone.1, %add.244188.5.clone.1) + %shift-left.107804.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.244188.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.113884.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.244188.5.clone.1, %broadcast.244415.6016) + %or.113407.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107804.9.clone.1, %shift-right-logical.113884.9.clone.1) + %xor.119968.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244189.5.clone.1, %or.113407.7.clone.1) + %add.244190.3.clone.1 = u32[1280,1280]{1,0} add(%add.244189.5.clone.1, %xor.119968.5.clone.1) + %shift-left.107805.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119968.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.113886.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119968.5.clone.1, %broadcast.244417.5760) + %or.113408.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107805.9.clone.1, %shift-right-logical.113886.9.clone.1) + %xor.119969.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244190.3.clone.1, %or.113408.7.clone.1) + %add.244191.3.clone.1 = u32[1280,1280]{1,0} add(%add.244190.3.clone.1, %xor.119969.5.clone.1) + %shift-left.107806.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119969.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.113887.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119969.5.clone.1, %broadcast.244419.4352) + %or.113409.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107806.7.clone.1, %shift-right-logical.113887.7.clone.1) + %xor.119971.3.clone.1 = u32[1280,1280]{1,0} xor(%add.244191.3.clone.1, %or.113409.5.clone.1) + %add.244192.3.clone.1 = u32[1280,1280]{1,0} add(%add.244191.3.clone.1, %xor.119971.3.clone.1) + %add.244193.7.clone.1 = u32[1280,1280]{1,0} add(%add.244192.3.clone.1, %broadcast.244607.44.clone.1) + %shift-left.107807.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119971.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.113888.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119971.3.clone.1, %broadcast.244418.4352) + %or.113410.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107807.7.clone.1, %shift-right-logical.113888.7.clone.1) + %xor.119972.3.clone.1 = u32[1280,1280]{1,0} xor(%add.244192.3.clone.1, %or.113410.5.clone.1) + %constant_217775_1_clone_1 = u32[] constant(1662691731) + %broadcast.244638.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217775_1_clone_1), dimensions={} + %add.244194.5.clone.1 = u32[1280,1280]{1,0} add(%xor.119972.3.clone.1, %broadcast.244638.5.clone.1) + %add.244195.5.clone.1 = u32[1280,1280]{1,0} add(%add.244193.7.clone.1, %add.244194.5.clone.1) + %shift-left.107808.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.244194.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.113889.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.244194.5.clone.1, 
%broadcast.244416.5760) + %or.113411.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107808.9.clone.1, %shift-right-logical.113889.9.clone.1) + %xor.119973.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244195.5.clone.1, %or.113411.7.clone.1) + %add.244196.3.clone.1 = u32[1280,1280]{1,0} add(%add.244195.5.clone.1, %xor.119973.5.clone.1) + %shift-left.107809.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119973.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.113891.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119973.5.clone.1, %broadcast.244429.2304) + %or.113412.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107809.9.clone.1, %shift-right-logical.113891.9.clone.1) + %xor.119974.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244196.3.clone.1, %or.113412.7.clone.1) + %add.244197.3.clone.1 = u32[1280,1280]{1,0} add(%add.244196.3.clone.1, %xor.119974.5.clone.1) + %shift-left.107811.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119974.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.113892.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119974.5.clone.1, %broadcast.244430.4608) + %or.113413.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107811.9.clone.1, %shift-right-logical.113892.9.clone.1) + %xor.119975.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244197.3.clone.1, %or.113413.7.clone.1) + %add.244198.3.clone.1 = u32[1280,1280]{1,0} add(%add.244197.3.clone.1, %xor.119975.5.clone.1) + %add.244199.7.clone.1 = u32[1280,1280]{1,0} add(%add.244198.3.clone.1, %broadcast.244608.113.clone.1) + %shift-left.107812.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119975.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.113893.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119975.5.clone.1, %broadcast.244434.2816) + %or.113414.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107812.11.clone.1, %shift-right-logical.113893.11.clone.1) + %xor.119976.7.clone.1 = u32[1280,1280]{1,0} xor(%add.244198.3.clone.1, %or.113414.9.clone.1) + %constant_217776_1_clone_1 = u32[] constant(3611020475) + %broadcast.244648.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217776_1_clone_1), dimensions={} + %add.244200.5.clone.1 = u32[1280,1280]{1,0} add(%xor.119976.7.clone.1, %broadcast.244648.5.clone.1) + %add.244201.5.clone.1 = u32[1280,1280]{1,0} add(%add.244199.7.clone.1, %add.244200.5.clone.1) + %shift-left.107813.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.244200.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.113894.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.244200.5.clone.1, %broadcast.244415.6016) + %or.113415.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107813.9.clone.1, %shift-right-logical.113894.9.clone.1) + %xor.119977.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244201.5.clone.1, %or.113415.7.clone.1) + %add.244202.3.clone.1 = u32[1280,1280]{1,0} add(%add.244201.5.clone.1, %xor.119977.5.clone.1) + %shift-left.107814.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119977.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.113896.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119977.5.clone.1, %broadcast.244417.5760) + %or.113416.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107814.9.clone.1, %shift-right-logical.113896.9.clone.1) + %xor.119978.5.clone.1 = u32[1280,1280]{1,0} xor(%add.244202.3.clone.1, %or.113416.7.clone.1) + %add.244203.3.clone.1 = u32[1280,1280]{1,0} add(%add.244202.3.clone.1, %xor.119978.5.clone.1) + %shift-left.107816.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119978.5.clone.1, 
%broadcast.244418.4352) + %shift-right-logical.113897.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119978.5.clone.1, %broadcast.244419.4352) + %or.113417.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107816.5.clone.1, %shift-right-logical.113897.5.clone.1) + %xor.119979.3.clone.1 = u32[1280,1280]{1,0} xor(%add.244203.3.clone.1, %or.113417.3.clone.1) + %add.244204.3.clone.1 = u32[1280,1280]{1,0} add(%add.244203.3.clone.1, %xor.119979.3.clone.1) + %add.244205.17.clone.1 = u32[1280,1280]{1,0} add(%add.244204.3.clone.1, %broadcast.244625.24.clone.1) + %shift-left.107817.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.119979.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.113898.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119979.3.clone.1, %broadcast.244418.4352) + %or.113418.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.107817.5.clone.1, %shift-right-logical.113898.5.clone.1) + %xor.119980.15.clone.1 = u32[1280,1280]{1,0} xor(%add.244204.3.clone.1, %or.113418.3.clone.1) + %constant_217777_1_clone_1 = u32[] constant(2951773954) + %broadcast.244658.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217777_1_clone_1), dimensions={} + %add.244206.19.clone.1 = u32[1280,1280]{1,0} add(%xor.119980.15.clone.1, %broadcast.244658.19.clone.1) + %xor.119981.17.clone.1 = u32[1280,1280]{1,0} xor(%add.244205.17.clone.1, %add.244206.19.clone.1) + %shift-right-logical.113899.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.119981.17.clone.1, %broadcast.244468.1920) + %or.113419.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.113899.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5671.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.113419.13.clone.1) + %add.244207.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5671.11.clone.1, %broadcast.244470.1152) + %multiply.25611.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.244207.9.clone.1, %broadcast.244471.896) + %add.244208.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.25611.7.clone.1, %broadcast.244408.1024) + %maximum.3603.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.244208.5.clone.1) + %abs.1485.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3603.3.clone.1) + %compare.7118.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1485.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.25612.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3603.3.clone.1, %broadcast.244476.1152) + %negate.4475.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3603.3.clone.1) + %multiply.25613.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3603.3.clone.1, %negate.4475.5.clone.1) + %log-plus-one.1485.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.25613.5.clone.1) + %negate.4476.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1485.3.clone.1) + %compare.7119.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4476.4.clone.1, %broadcast.244477.384), direction=LT + %select.20379.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20380.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20381.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20382.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20383.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) 
+ %select.20384.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20385.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20386.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20387.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.244209.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4476.4.clone.1, %broadcast.244496.640) + %sqrt.1485.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4476.4.clone.1) + %add.244210.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1485.5.clone.1, %broadcast.244498.640) + %select.20388.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7119.3.clone.1, %add.244209.5.clone.1, %add.244210.5.clone.1) + %multiply.25614.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20387.3.clone.1, %select.20388.3.clone.1) + %add.244211.1.clone.1 = f32[1280,1280]{1,0} add(%select.20386.3.clone.1, %multiply.25614.1.clone.1) + %multiply.25615.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.244211.1.clone.1, %select.20388.3.clone.1) + %add.244212.1.clone.1 = f32[1280,1280]{1,0} add(%select.20385.3.clone.1, %multiply.25615.1.clone.1) + %multiply.25616.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.244212.1.clone.1, %select.20388.3.clone.1) + %add.244213.1.clone.1 = f32[1280,1280]{1,0} add(%select.20384.3.clone.1, %multiply.25616.1.clone.1) + %multiply.25617.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.244213.1.clone.1, %select.20388.3.clone.1) + %add.244214.1.clone.1 = f32[1280,1280]{1,0} add(%select.20383.3.clone.1, %multiply.25617.1.clone.1) + %multiply.25618.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.244214.1.clone.1, %select.20388.3.clone.1) + %add.244216.3.clone.1 = f32[1280,1280]{1,0} add(%select.20382.5.clone.1, %multiply.25618.1.clone.1) + %multiply.25619.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.244216.3.clone.1, %select.20388.3.clone.1) + %add.244219.3.clone.1 = f32[1280,1280]{1,0} add(%select.20381.5.clone.1, %multiply.25619.1.clone.1) + %multiply.25620.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.244219.3.clone.1, %select.20388.3.clone.1) + %add.244220.9.clone.1 = f32[1280,1280]{1,0} add(%select.20380.11.clone.1, %multiply.25620.7.clone.1) + %multiply.25621.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.244220.9.clone.1, %select.20388.3.clone.1) + %add.244221.7.clone.1 = f32[1280,1280]{1,0} add(%select.20379.7.clone.1, %multiply.25621.7.clone.1) + %multiply.25622.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.244221.7.clone.1, %maximum.3603.3.clone.1) + %select.20389.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7118.3.clone.1, %multiply.25612.9.clone.1, %multiply.25622.7.clone.1) + %multiply.25623.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20389.7.clone.1, %broadcast.244500.640) + %clamp.1129.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.25623.5.clone.1, %broadcast.244501.384) + %multiply.25624.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1129.3.clone.1, %broadcast.244502.1) + %constant_177603_1_clone_1 = u32[] constant(3731615297) + %broadcast.252810.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_177603_1_clone_1), dimensions={} + %add.248884.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.252810.44.clone.1) + %constant_177614_1_clone_1 = u32[] constant(961749960) + %broadcast.252811.113.clone.1 = u32[1280,1280]{1,0} 
broadcast(%constant_177614_1_clone_1), dimensions={} + %add.248886.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.252811.113.clone.1) + %add.248887.35.clone.1 = u32[1280,1280]{1,0} add(%add.248884.37.clone.1, %add.248886.99.clone.1) + %shift-left.109834.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248886.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116024.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248886.99.clone.1, %broadcast.244415.6016) + %or.115553.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109834.31.clone.1, %shift-right-logical.116024.29.clone.1) + %xor.122103.27.clone.1 = u32[1280,1280]{1,0} xor(%add.248887.35.clone.1, %or.115553.29.clone.1) + %add.248888.5.clone.1 = u32[1280,1280]{1,0} add(%add.248887.35.clone.1, %xor.122103.27.clone.1) + %shift-left.109835.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122103.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116025.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122103.27.clone.1, %broadcast.244417.5760) + %or.115555.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109835.9.clone.1, %shift-right-logical.116025.9.clone.1) + %xor.122104.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248888.5.clone.1, %or.115555.7.clone.1) + %add.248889.3.clone.1 = u32[1280,1280]{1,0} add(%add.248888.5.clone.1, %xor.122104.5.clone.1) + %shift-left.109836.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122104.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116026.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122104.5.clone.1, %broadcast.244419.4352) + %or.115556.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109836.5.clone.1, %shift-right-logical.116026.5.clone.1) + %xor.122105.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248889.3.clone.1, %or.115556.3.clone.1) + %add.248891.3.clone.1 = u32[1280,1280]{1,0} add(%add.248889.3.clone.1, %xor.122105.3.clone.1) + %add.248892.7.clone.1 = u32[1280,1280]{1,0} add(%add.248891.3.clone.1, %broadcast.252811.113.clone.1) + %shift-left.109837.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122105.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116027.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122105.3.clone.1, %broadcast.244418.4352) + %or.115557.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109837.5.clone.1, %shift-right-logical.116027.5.clone.1) + %xor.122106.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248891.3.clone.1, %or.115557.3.clone.1) + %constant_218297_1_clone_1 = u32[] constant(4243183188) + %broadcast.252821.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218297_1_clone_1), dimensions={} + %add.248893.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122106.3.clone.1, %broadcast.252821.5.clone.1) + %add.248894.5.clone.1 = u32[1280,1280]{1,0} add(%add.248892.7.clone.1, %add.248893.5.clone.1) + %shift-left.109839.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248893.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116028.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248893.5.clone.1, %broadcast.244416.5760) + %or.115558.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109839.9.clone.1, %shift-right-logical.116028.9.clone.1) + %xor.122107.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248894.5.clone.1, %or.115558.7.clone.1) + %add.248895.3.clone.1 = u32[1280,1280]{1,0} add(%add.248894.5.clone.1, %xor.122107.5.clone.1) + %shift-left.109840.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122107.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116029.9.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.122107.5.clone.1, %broadcast.244429.2304) + %or.115559.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109840.9.clone.1, %shift-right-logical.116029.9.clone.1) + %xor.122108.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248895.3.clone.1, %or.115559.7.clone.1) + %add.248897.3.clone.1 = u32[1280,1280]{1,0} add(%add.248895.3.clone.1, %xor.122108.5.clone.1) + %shift-left.109841.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122108.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116030.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122108.5.clone.1, %broadcast.244430.4608) + %or.115560.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109841.9.clone.1, %shift-right-logical.116030.9.clone.1) + %xor.122110.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248897.3.clone.1, %or.115560.7.clone.1) + %add.248901.3.clone.1 = u32[1280,1280]{1,0} add(%add.248897.3.clone.1, %xor.122110.5.clone.1) + %constant_177616_1_clone_1 = u32[] constant(4243183187) + %broadcast.252828.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_177616_1_clone_1), dimensions={} + %add.248902.7.clone.1 = u32[1280,1280]{1,0} add(%add.248901.3.clone.1, %broadcast.252828.24.clone.1) + %shift-left.109842.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122110.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116031.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122110.5.clone.1, %broadcast.244434.2816) + %or.115561.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109842.11.clone.1, %shift-right-logical.116031.11.clone.1) + %xor.122111.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248901.3.clone.1, %or.115561.9.clone.1) + %constant_218298_1_clone_1 = u32[] constant(3731615299) + %broadcast.252831.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218298_1_clone_1), dimensions={} + %add.248903.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122111.7.clone.1, %broadcast.252831.5.clone.1) + %add.248904.5.clone.1 = u32[1280,1280]{1,0} add(%add.248902.7.clone.1, %add.248903.5.clone.1) + %shift-left.109844.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248903.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116032.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248903.5.clone.1, %broadcast.244415.6016) + %or.115562.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109844.9.clone.1, %shift-right-logical.116032.9.clone.1) + %xor.122112.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248904.5.clone.1, %or.115562.7.clone.1) + %add.248906.3.clone.1 = u32[1280,1280]{1,0} add(%add.248904.5.clone.1, %xor.122112.5.clone.1) + %shift-left.109845.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122112.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116033.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122112.5.clone.1, %broadcast.244417.5760) + %or.115563.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109845.9.clone.1, %shift-right-logical.116033.9.clone.1) + %xor.122113.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248906.3.clone.1, %or.115563.7.clone.1) + %add.248907.3.clone.1 = u32[1280,1280]{1,0} add(%add.248906.3.clone.1, %xor.122113.5.clone.1) + %shift-left.109846.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122113.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116034.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122113.5.clone.1, %broadcast.244419.4352) + %or.115565.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109846.7.clone.1, %shift-right-logical.116034.7.clone.1) + %xor.122115.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248907.3.clone.1, %or.115565.5.clone.1) 
+ %add.248908.3.clone.1 = u32[1280,1280]{1,0} add(%add.248907.3.clone.1, %xor.122115.3.clone.1) + %add.248909.7.clone.1 = u32[1280,1280]{1,0} add(%add.248908.3.clone.1, %broadcast.252810.44.clone.1) + %shift-left.109847.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122115.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116035.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122115.3.clone.1, %broadcast.244418.4352) + %or.115566.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109847.7.clone.1, %shift-right-logical.116035.7.clone.1) + %xor.122116.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248908.3.clone.1, %or.115566.5.clone.1) + %constant_218299_1_clone_1 = u32[] constant(961749963) + %broadcast.252841.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218299_1_clone_1), dimensions={} + %add.248911.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122116.3.clone.1, %broadcast.252841.5.clone.1) + %add.248912.5.clone.1 = u32[1280,1280]{1,0} add(%add.248909.7.clone.1, %add.248911.5.clone.1) + %shift-left.109848.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248911.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116036.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248911.5.clone.1, %broadcast.244416.5760) + %or.115567.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109848.9.clone.1, %shift-right-logical.116036.9.clone.1) + %xor.122117.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248912.5.clone.1, %or.115567.7.clone.1) + %add.248913.3.clone.1 = u32[1280,1280]{1,0} add(%add.248912.5.clone.1, %xor.122117.5.clone.1) + %shift-left.109849.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122117.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116037.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122117.5.clone.1, %broadcast.244429.2304) + %or.115568.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109849.9.clone.1, %shift-right-logical.116037.9.clone.1) + %xor.122118.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248913.3.clone.1, %or.115568.7.clone.1) + %add.248914.3.clone.1 = u32[1280,1280]{1,0} add(%add.248913.3.clone.1, %xor.122118.5.clone.1) + %shift-left.109850.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122118.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116038.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122118.5.clone.1, %broadcast.244430.4608) + %or.115570.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109850.9.clone.1, %shift-right-logical.116038.9.clone.1) + %xor.122120.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248914.3.clone.1, %or.115570.7.clone.1) + %add.248916.3.clone.1 = u32[1280,1280]{1,0} add(%add.248914.3.clone.1, %xor.122120.5.clone.1) + %add.248917.7.clone.1 = u32[1280,1280]{1,0} add(%add.248916.3.clone.1, %broadcast.252811.113.clone.1) + %shift-left.109851.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122120.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116039.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122120.5.clone.1, %broadcast.244434.2816) + %or.115571.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109851.11.clone.1, %shift-right-logical.116039.11.clone.1) + %xor.122121.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248916.3.clone.1, %or.115571.9.clone.1) + %constant_218300_1_clone_1 = u32[] constant(4243183191) + %broadcast.252851.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218300_1_clone_1), dimensions={} + %add.248918.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122121.7.clone.1, %broadcast.252851.5.clone.1) + %add.248919.5.clone.1 = u32[1280,1280]{1,0} add(%add.248917.7.clone.1, 
%add.248918.5.clone.1) + %shift-left.109852.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248918.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116040.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248918.5.clone.1, %broadcast.244415.6016) + %or.115572.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109852.9.clone.1, %shift-right-logical.116040.9.clone.1) + %xor.122122.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248919.5.clone.1, %or.115572.7.clone.1) + %add.248920.3.clone.1 = u32[1280,1280]{1,0} add(%add.248919.5.clone.1, %xor.122122.5.clone.1) + %shift-left.109854.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122122.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116041.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122122.5.clone.1, %broadcast.244417.5760) + %or.115573.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109854.9.clone.1, %shift-right-logical.116041.9.clone.1) + %xor.122123.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248920.3.clone.1, %or.115573.7.clone.1) + %add.248922.3.clone.1 = u32[1280,1280]{1,0} add(%add.248920.3.clone.1, %xor.122123.5.clone.1) + %shift-left.109855.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122123.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116042.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122123.5.clone.1, %broadcast.244419.4352) + %or.115575.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109855.5.clone.1, %shift-right-logical.116042.5.clone.1) + %xor.122125.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248922.3.clone.1, %or.115575.3.clone.1) + %add.248926.3.clone.1 = u32[1280,1280]{1,0} add(%add.248922.3.clone.1, %xor.122125.3.clone.1) + %add.248927.17.clone.1 = u32[1280,1280]{1,0} add(%add.248926.3.clone.1, %broadcast.252828.24.clone.1) + %shift-left.109856.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122125.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116043.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122125.3.clone.1, %broadcast.244418.4352) + %or.115576.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109856.5.clone.1, %shift-right-logical.116043.5.clone.1) + %xor.122126.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248926.3.clone.1, %or.115576.3.clone.1) + %constant_218301_1_clone_1 = u32[] constant(3731615302) + %broadcast.252861.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218301_1_clone_1), dimensions={} + %add.248928.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122126.15.clone.1, %broadcast.252861.19.clone.1) + %xor.122127.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248927.17.clone.1, %add.248928.19.clone.1) + %shift-right-logical.116044.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122127.17.clone.1, %broadcast.244468.1920) + %or.115577.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116044.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5765.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115577.13.clone.1) + %add.248929.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5765.11.clone.1, %broadcast.244470.1152) + %multiply.26575.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248929.9.clone.1, %broadcast.244471.896) + %add.248931.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26575.7.clone.1, %broadcast.244408.1024) + %maximum.3697.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248931.5.clone.1) + %abs.1547.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3697.3.clone.1) + %compare.7242.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1547.3.clone.1, %broadcast.244475.384), direction=EQ + 
%multiply.26576.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3697.3.clone.1, %broadcast.244476.1152) + %negate.4599.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3697.3.clone.1) + %multiply.26577.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3697.3.clone.1, %negate.4599.5.clone.1) + %log-plus-one.1547.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26577.5.clone.1) + %negate.4600.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1547.3.clone.1) + %compare.7243.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4600.4.clone.1, %broadcast.244477.384), direction=LT + %select.21103.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21104.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21105.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21106.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21107.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21108.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21109.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21110.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21111.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248932.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4600.4.clone.1, %broadcast.244496.640) + %sqrt.1547.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4600.4.clone.1) + %add.248933.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1547.5.clone.1, %broadcast.244498.640) + %select.21112.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7243.3.clone.1, %add.248932.5.clone.1, %add.248933.5.clone.1) + %multiply.26578.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21111.3.clone.1, %select.21112.3.clone.1) + %add.248934.1.clone.1 = f32[1280,1280]{1,0} add(%select.21110.3.clone.1, %multiply.26578.1.clone.1) + %multiply.26579.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248934.1.clone.1, %select.21112.3.clone.1) + %add.248936.1.clone.1 = f32[1280,1280]{1,0} add(%select.21109.3.clone.1, %multiply.26579.1.clone.1) + %multiply.26580.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248936.1.clone.1, %select.21112.3.clone.1) + %add.248937.1.clone.1 = f32[1280,1280]{1,0} add(%select.21108.3.clone.1, %multiply.26580.1.clone.1) + %multiply.26581.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248937.1.clone.1, %select.21112.3.clone.1) + %add.248938.1.clone.1 = f32[1280,1280]{1,0} add(%select.21107.3.clone.1, %multiply.26581.1.clone.1) + %multiply.26582.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248938.1.clone.1, %select.21112.3.clone.1) + %add.248939.3.clone.1 = f32[1280,1280]{1,0} add(%select.21106.5.clone.1, %multiply.26582.1.clone.1) + %multiply.26583.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248939.3.clone.1, %select.21112.3.clone.1) + %add.248941.3.clone.1 = f32[1280,1280]{1,0} add(%select.21105.5.clone.1, %multiply.26583.1.clone.1) + %multiply.26584.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248941.3.clone.1, %select.21112.3.clone.1) + %add.248942.9.clone.1 = f32[1280,1280]{1,0} 
add(%select.21104.11.clone.1, %multiply.26584.7.clone.1) + %multiply.26585.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248942.9.clone.1, %select.21112.3.clone.1) + %add.248943.7.clone.1 = f32[1280,1280]{1,0} add(%select.21103.7.clone.1, %multiply.26585.7.clone.1) + %multiply.26586.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248943.7.clone.1, %maximum.3697.3.clone.1) + %select.21113.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7242.3.clone.1, %multiply.26576.9.clone.1, %multiply.26586.7.clone.1) + %multiply.26587.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21113.7.clone.1, %broadcast.244500.640) + %clamp.1191.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26587.5.clone.1, %broadcast.244501.384) + %multiply.26588.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1191.3.clone.1, %broadcast.244502.1) + %constant_177814_1_clone_1 = u32[] constant(3930035116) + %broadcast.252896.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_177814_1_clone_1), dimensions={} + %add.248944.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.252896.44.clone.1) + %constant_177828_1_clone_1 = u32[] constant(595127978) + %broadcast.252897.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_177828_1_clone_1), dimensions={} + %add.248945.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.252897.113.clone.1) + %add.248947.35.clone.1 = u32[1280,1280]{1,0} add(%add.248944.37.clone.1, %add.248945.99.clone.1) + %shift-left.109857.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248945.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116045.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248945.99.clone.1, %broadcast.244415.6016) + %or.115578.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109857.31.clone.1, %shift-right-logical.116045.29.clone.1) + %xor.122128.27.clone.1 = u32[1280,1280]{1,0} xor(%add.248947.35.clone.1, %or.115578.29.clone.1) + %add.248951.5.clone.1 = u32[1280,1280]{1,0} add(%add.248947.35.clone.1, %xor.122128.27.clone.1) + %shift-left.109859.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122128.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116046.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122128.27.clone.1, %broadcast.244417.5760) + %or.115580.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109859.9.clone.1, %shift-right-logical.116046.9.clone.1) + %xor.122129.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248951.5.clone.1, %or.115580.7.clone.1) + %add.248952.3.clone.1 = u32[1280,1280]{1,0} add(%add.248951.5.clone.1, %xor.122129.5.clone.1) + %shift-left.109860.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122129.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116047.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122129.5.clone.1, %broadcast.244419.4352) + %or.115581.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109860.5.clone.1, %shift-right-logical.116047.5.clone.1) + %xor.122130.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248952.3.clone.1, %or.115581.3.clone.1) + %add.248953.3.clone.1 = u32[1280,1280]{1,0} add(%add.248952.3.clone.1, %xor.122130.3.clone.1) + %add.248954.7.clone.1 = u32[1280,1280]{1,0} add(%add.248953.3.clone.1, %broadcast.252897.113.clone.1) + %shift-left.109861.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122130.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116048.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122130.3.clone.1, %broadcast.244418.4352) + %or.115582.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109861.5.clone.1, 
%shift-right-logical.116048.5.clone.1) + %xor.122131.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248953.3.clone.1, %or.115582.3.clone.1) + %constant_218302_1_clone_1 = u32[] constant(3533072093) + %broadcast.252907.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218302_1_clone_1), dimensions={} + %add.248956.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122131.3.clone.1, %broadcast.252907.5.clone.1) + %add.248957.5.clone.1 = u32[1280,1280]{1,0} add(%add.248954.7.clone.1, %add.248956.5.clone.1) + %shift-left.109862.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248956.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116049.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248956.5.clone.1, %broadcast.244416.5760) + %or.115583.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109862.9.clone.1, %shift-right-logical.116049.9.clone.1) + %xor.122132.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248957.5.clone.1, %or.115583.7.clone.1) + %add.248958.3.clone.1 = u32[1280,1280]{1,0} add(%add.248957.5.clone.1, %xor.122132.5.clone.1) + %shift-left.109864.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122132.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116050.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122132.5.clone.1, %broadcast.244429.2304) + %or.115584.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109864.9.clone.1, %shift-right-logical.116050.9.clone.1) + %xor.122133.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248958.3.clone.1, %or.115584.7.clone.1) + %add.248959.3.clone.1 = u32[1280,1280]{1,0} add(%add.248958.3.clone.1, %xor.122133.5.clone.1) + %shift-left.109865.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122133.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116051.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122133.5.clone.1, %broadcast.244430.4608) + %or.115585.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109865.9.clone.1, %shift-right-logical.116051.9.clone.1) + %xor.122135.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248959.3.clone.1, %or.115585.7.clone.1) + %add.248961.3.clone.1 = u32[1280,1280]{1,0} add(%add.248959.3.clone.1, %xor.122135.5.clone.1) + %constant_177832_1_clone_1 = u32[] constant(3533072092) + %broadcast.252916.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_177832_1_clone_1), dimensions={} + %add.248962.7.clone.1 = u32[1280,1280]{1,0} add(%add.248961.3.clone.1, %broadcast.252916.24.clone.1) + %shift-left.109866.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122135.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116052.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122135.5.clone.1, %broadcast.244434.2816) + %or.115586.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109866.11.clone.1, %shift-right-logical.116052.11.clone.1) + %xor.122136.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248961.3.clone.1, %or.115586.9.clone.1) + %constant_218303_1_clone_1 = u32[] constant(3930035118) + %broadcast.252922.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218303_1_clone_1), dimensions={} + %add.248963.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122136.7.clone.1, %broadcast.252922.5.clone.1) + %add.248964.5.clone.1 = u32[1280,1280]{1,0} add(%add.248962.7.clone.1, %add.248963.5.clone.1) + %shift-left.109867.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248963.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116053.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248963.5.clone.1, %broadcast.244415.6016) + %or.115587.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109867.9.clone.1, 
%shift-right-logical.116053.9.clone.1) + %xor.122137.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248964.5.clone.1, %or.115587.7.clone.1) + %add.248966.3.clone.1 = u32[1280,1280]{1,0} add(%add.248964.5.clone.1, %xor.122137.5.clone.1) + %shift-left.109869.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122137.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116054.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122137.5.clone.1, %broadcast.244417.5760) + %or.115588.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109869.9.clone.1, %shift-right-logical.116054.9.clone.1) + %xor.122138.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248966.3.clone.1, %or.115588.7.clone.1) + %add.248967.3.clone.1 = u32[1280,1280]{1,0} add(%add.248966.3.clone.1, %xor.122138.5.clone.1) + %shift-left.109870.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122138.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116055.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122138.5.clone.1, %broadcast.244419.4352) + %or.115590.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109870.7.clone.1, %shift-right-logical.116055.7.clone.1) + %xor.122140.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248967.3.clone.1, %or.115590.5.clone.1) + %add.248968.3.clone.1 = u32[1280,1280]{1,0} add(%add.248967.3.clone.1, %xor.122140.3.clone.1) + %add.248969.7.clone.1 = u32[1280,1280]{1,0} add(%add.248968.3.clone.1, %broadcast.252896.44.clone.1) + %shift-left.109871.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122140.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116056.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122140.3.clone.1, %broadcast.244418.4352) + %or.115591.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109871.7.clone.1, %shift-right-logical.116056.7.clone.1) + %xor.122141.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248968.3.clone.1, %or.115591.5.clone.1) + %constant_218304_1_clone_1 = u32[] constant(595127981) + %broadcast.252942.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218304_1_clone_1), dimensions={} + %add.248970.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122141.3.clone.1, %broadcast.252942.5.clone.1) + %add.248972.5.clone.1 = u32[1280,1280]{1,0} add(%add.248969.7.clone.1, %add.248970.5.clone.1) + %shift-left.109872.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248970.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116057.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248970.5.clone.1, %broadcast.244416.5760) + %or.115592.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109872.9.clone.1, %shift-right-logical.116057.9.clone.1) + %xor.122142.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248972.5.clone.1, %or.115592.7.clone.1) + %add.248976.3.clone.1 = u32[1280,1280]{1,0} add(%add.248972.5.clone.1, %xor.122142.5.clone.1) + %shift-left.109873.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122142.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116058.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122142.5.clone.1, %broadcast.244429.2304) + %or.115593.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109873.9.clone.1, %shift-right-logical.116058.9.clone.1) + %xor.122143.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248976.3.clone.1, %or.115593.7.clone.1) + %add.248977.3.clone.1 = u32[1280,1280]{1,0} add(%add.248976.3.clone.1, %xor.122143.5.clone.1) + %shift-left.109874.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122143.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116059.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122143.5.clone.1, 
%broadcast.244430.4608) + %or.115595.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109874.9.clone.1, %shift-right-logical.116059.9.clone.1) + %xor.122145.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248977.3.clone.1, %or.115595.7.clone.1) + %add.248978.3.clone.1 = u32[1280,1280]{1,0} add(%add.248977.3.clone.1, %xor.122145.5.clone.1) + %add.248979.7.clone.1 = u32[1280,1280]{1,0} add(%add.248978.3.clone.1, %broadcast.252897.113.clone.1) + %shift-left.109875.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122145.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116060.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122145.5.clone.1, %broadcast.244434.2816) + %or.115596.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109875.11.clone.1, %shift-right-logical.116060.11.clone.1) + %xor.122146.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248978.3.clone.1, %or.115596.9.clone.1) + %constant_218305_1_clone_1 = u32[] constant(3533072096) + %broadcast.252954.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218305_1_clone_1), dimensions={} + %add.248981.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122146.7.clone.1, %broadcast.252954.5.clone.1) + %add.248982.5.clone.1 = u32[1280,1280]{1,0} add(%add.248979.7.clone.1, %add.248981.5.clone.1) + %shift-left.109876.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248981.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116061.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248981.5.clone.1, %broadcast.244415.6016) + %or.115597.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109876.9.clone.1, %shift-right-logical.116061.9.clone.1) + %xor.122147.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248982.5.clone.1, %or.115597.7.clone.1) + %add.248983.3.clone.1 = u32[1280,1280]{1,0} add(%add.248982.5.clone.1, %xor.122147.5.clone.1) + %shift-left.109877.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122147.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116062.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122147.5.clone.1, %broadcast.244417.5760) + %or.115598.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109877.9.clone.1, %shift-right-logical.116062.9.clone.1) + %xor.122148.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248983.3.clone.1, %or.115598.7.clone.1) + %add.248984.3.clone.1 = u32[1280,1280]{1,0} add(%add.248983.3.clone.1, %xor.122148.5.clone.1) + %shift-left.109878.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122148.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116063.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122148.5.clone.1, %broadcast.244419.4352) + %or.115600.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109878.5.clone.1, %shift-right-logical.116063.5.clone.1) + %xor.122150.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248984.3.clone.1, %or.115600.3.clone.1) + %add.248986.3.clone.1 = u32[1280,1280]{1,0} add(%add.248984.3.clone.1, %xor.122150.3.clone.1) + %add.248987.17.clone.1 = u32[1280,1280]{1,0} add(%add.248986.3.clone.1, %broadcast.252916.24.clone.1) + %shift-left.109879.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122150.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116064.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122150.3.clone.1, %broadcast.244418.4352) + %or.115601.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109879.5.clone.1, %shift-right-logical.116064.5.clone.1) + %xor.122151.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248986.3.clone.1, %or.115601.3.clone.1) + %constant_218306_1_clone_1 = u32[] constant(3930035121) + %broadcast.252964.19.clone.1 = 
u32[1280,1280]{1,0} broadcast(%constant_218306_1_clone_1), dimensions={} + %add.248988.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122151.15.clone.1, %broadcast.252964.19.clone.1) + %xor.122152.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248987.17.clone.1, %add.248988.19.clone.1) + %shift-right-logical.116065.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122152.17.clone.1, %broadcast.244468.1920) + %or.115602.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116065.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5766.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115602.13.clone.1) + %add.248989.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5766.11.clone.1, %broadcast.244470.1152) + %multiply.26589.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248989.9.clone.1, %broadcast.244471.896) + %add.248991.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26589.7.clone.1, %broadcast.244408.1024) + %maximum.3698.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248991.5.clone.1) + %abs.1548.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3698.3.clone.1) + %compare.7244.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1548.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26590.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3698.3.clone.1, %broadcast.244476.1152) + %negate.4601.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3698.3.clone.1) + %multiply.26591.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3698.3.clone.1, %negate.4601.5.clone.1) + %log-plus-one.1548.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26591.5.clone.1) + %negate.4602.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1548.3.clone.1) + %compare.7245.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4602.4.clone.1, %broadcast.244477.384), direction=LT + %select.21114.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21115.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21116.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21117.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21118.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21119.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21120.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21121.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21122.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248992.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4602.4.clone.1, %broadcast.244496.640) + %sqrt.1548.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4602.4.clone.1) + %add.248993.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1548.5.clone.1, %broadcast.244498.640) + %select.21123.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7245.3.clone.1, %add.248992.5.clone.1, %add.248993.5.clone.1) + %multiply.26592.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21122.3.clone.1, %select.21123.3.clone.1) + %add.248994.1.clone.1 = f32[1280,1280]{1,0} add(%select.21121.3.clone.1, %multiply.26592.1.clone.1) + 
%multiply.26593.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248994.1.clone.1, %select.21123.3.clone.1) + %add.248995.1.clone.1 = f32[1280,1280]{1,0} add(%select.21120.3.clone.1, %multiply.26593.1.clone.1) + %multiply.26594.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248995.1.clone.1, %select.21123.3.clone.1) + %add.248997.1.clone.1 = f32[1280,1280]{1,0} add(%select.21119.3.clone.1, %multiply.26594.1.clone.1) + %multiply.26595.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248997.1.clone.1, %select.21123.3.clone.1) + %add.249000.1.clone.1 = f32[1280,1280]{1,0} add(%select.21118.3.clone.1, %multiply.26595.1.clone.1) + %multiply.26596.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249000.1.clone.1, %select.21123.3.clone.1) + %add.249001.3.clone.1 = f32[1280,1280]{1,0} add(%select.21117.5.clone.1, %multiply.26596.1.clone.1) + %multiply.26597.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249001.3.clone.1, %select.21123.3.clone.1) + %add.249002.3.clone.1 = f32[1280,1280]{1,0} add(%select.21116.5.clone.1, %multiply.26597.1.clone.1) + %multiply.26598.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249002.3.clone.1, %select.21123.3.clone.1) + %add.249003.9.clone.1 = f32[1280,1280]{1,0} add(%select.21115.11.clone.1, %multiply.26598.7.clone.1) + %multiply.26599.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249003.9.clone.1, %select.21123.3.clone.1) + %add.249004.7.clone.1 = f32[1280,1280]{1,0} add(%select.21114.7.clone.1, %multiply.26599.7.clone.1) + %multiply.26600.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249004.7.clone.1, %maximum.3698.3.clone.1) + %select.21124.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7244.3.clone.1, %multiply.26590.9.clone.1, %multiply.26600.7.clone.1) + %multiply.26601.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21124.7.clone.1, %broadcast.244500.640) + %clamp.1192.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26601.5.clone.1, %broadcast.244501.384) + %multiply.26602.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1192.3.clone.1, %broadcast.244502.1) + %constant_187357_1_clone_1 = u32[] constant(3486608628) + %broadcast.257058.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_187357_1_clone_1), dimensions={} + %add.251290.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.257058.44.clone.1) + %constant_187364_1_clone_1 = u32[] constant(2943763298) + %broadcast.257059.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_187364_1_clone_1), dimensions={} + %add.251291.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.257059.113.clone.1) + %add.251293.35.clone.1 = u32[1280,1280]{1,0} add(%add.251290.37.clone.1, %add.251291.99.clone.1) + %shift-left.110880.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251291.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117156.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251291.99.clone.1, %broadcast.244415.6016) + %or.116665.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110880.31.clone.1, %shift-right-logical.117156.29.clone.1) + %xor.123236.27.clone.1 = u32[1280,1280]{1,0} xor(%add.251293.35.clone.1, %or.116665.29.clone.1) + %add.251294.5.clone.1 = u32[1280,1280]{1,0} add(%add.251293.35.clone.1, %xor.123236.27.clone.1) + %shift-left.110881.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123236.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117157.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123236.27.clone.1, %broadcast.244417.5760) + %or.116666.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110881.9.clone.1, 
%shift-right-logical.117157.9.clone.1) + %xor.123237.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251294.5.clone.1, %or.116666.7.clone.1) + %add.251295.3.clone.1 = u32[1280,1280]{1,0} add(%add.251294.5.clone.1, %xor.123237.5.clone.1) + %shift-left.110882.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123237.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117159.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123237.5.clone.1, %broadcast.244419.4352) + %or.116667.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110882.5.clone.1, %shift-right-logical.117159.5.clone.1) + %xor.123238.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251295.3.clone.1, %or.116667.3.clone.1) + %add.251296.3.clone.1 = u32[1280,1280]{1,0} add(%add.251295.3.clone.1, %xor.123238.3.clone.1) + %add.251297.7.clone.1 = u32[1280,1280]{1,0} add(%add.251296.3.clone.1, %broadcast.257059.113.clone.1) + %shift-left.110883.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123238.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117160.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123238.3.clone.1, %broadcast.244418.4352) + %or.116668.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110883.5.clone.1, %shift-right-logical.117160.5.clone.1) + %xor.123239.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251296.3.clone.1, %or.116668.3.clone.1) + %constant_218559_1_clone_1 = u32[] constant(2071344205) + %broadcast.257069.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218559_1_clone_1), dimensions={} + %add.251299.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123239.3.clone.1, %broadcast.257069.5.clone.1) + %add.251303.5.clone.1 = u32[1280,1280]{1,0} add(%add.251297.7.clone.1, %add.251299.5.clone.1) + %shift-left.110884.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251299.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117161.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251299.5.clone.1, %broadcast.244416.5760) + %or.116670.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110884.9.clone.1, %shift-right-logical.117161.9.clone.1) + %xor.123240.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251303.5.clone.1, %or.116670.7.clone.1) + %add.251304.3.clone.1 = u32[1280,1280]{1,0} add(%add.251303.5.clone.1, %xor.123240.5.clone.1) + %shift-left.110885.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123240.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117162.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123240.5.clone.1, %broadcast.244429.2304) + %or.116671.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110885.9.clone.1, %shift-right-logical.117162.9.clone.1) + %xor.123241.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251304.3.clone.1, %or.116671.7.clone.1) + %add.251305.3.clone.1 = u32[1280,1280]{1,0} add(%add.251304.3.clone.1, %xor.123241.5.clone.1) + %shift-left.110886.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123241.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117164.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123241.5.clone.1, %broadcast.244430.4608) + %or.116672.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110886.9.clone.1, %shift-right-logical.117164.9.clone.1) + %xor.123242.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251305.3.clone.1, %or.116672.7.clone.1) + %add.251306.3.clone.1 = u32[1280,1280]{1,0} add(%add.251305.3.clone.1, %xor.123242.5.clone.1) + %constant_187366_1_clone_1 = u32[] constant(2071344204) + %broadcast.257076.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_187366_1_clone_1), dimensions={} + %add.251308.7.clone.1 = u32[1280,1280]{1,0} 
add(%add.251306.3.clone.1, %broadcast.257076.24.clone.1) + %shift-left.110887.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123242.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117165.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123242.5.clone.1, %broadcast.244434.2816) + %or.116673.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110887.11.clone.1, %shift-right-logical.117165.11.clone.1) + %xor.123243.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251306.3.clone.1, %or.116673.9.clone.1) + %constant_218560_1_clone_1 = u32[] constant(3486608630) + %broadcast.257079.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218560_1_clone_1), dimensions={} + %add.251309.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123243.7.clone.1, %broadcast.257079.5.clone.1) + %add.251310.5.clone.1 = u32[1280,1280]{1,0} add(%add.251308.7.clone.1, %add.251309.5.clone.1) + %shift-left.110888.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251309.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117166.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251309.5.clone.1, %broadcast.244415.6016) + %or.116675.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110888.9.clone.1, %shift-right-logical.117166.9.clone.1) + %xor.123244.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251310.5.clone.1, %or.116675.7.clone.1) + %add.251311.3.clone.1 = u32[1280,1280]{1,0} add(%add.251310.5.clone.1, %xor.123244.5.clone.1) + %shift-left.110889.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123244.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117167.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123244.5.clone.1, %broadcast.244417.5760) + %or.116676.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110889.9.clone.1, %shift-right-logical.117167.9.clone.1) + %xor.123245.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251311.3.clone.1, %or.116676.7.clone.1) + %add.251313.3.clone.1 = u32[1280,1280]{1,0} add(%add.251311.3.clone.1, %xor.123245.5.clone.1) + %shift-left.110890.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123245.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117169.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123245.5.clone.1, %broadcast.244419.4352) + %or.116677.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110890.7.clone.1, %shift-right-logical.117169.7.clone.1) + %xor.123246.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251313.3.clone.1, %or.116677.5.clone.1) + %add.251314.3.clone.1 = u32[1280,1280]{1,0} add(%add.251313.3.clone.1, %xor.123246.3.clone.1) + %add.251315.7.clone.1 = u32[1280,1280]{1,0} add(%add.251314.3.clone.1, %broadcast.257058.44.clone.1) + %shift-left.110891.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123246.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117170.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123246.3.clone.1, %broadcast.244418.4352) + %or.116678.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110891.7.clone.1, %shift-right-logical.117170.7.clone.1) + %xor.123247.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251314.3.clone.1, %or.116678.5.clone.1) + %constant_218561_1_clone_1 = u32[] constant(2943763301) + %broadcast.257089.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218561_1_clone_1), dimensions={} + %add.251316.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123247.3.clone.1, %broadcast.257089.5.clone.1) + %add.251318.5.clone.1 = u32[1280,1280]{1,0} add(%add.251315.7.clone.1, %add.251316.5.clone.1) + %shift-left.110892.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251316.5.clone.1, %broadcast.244417.5760) + 
%shift-right-logical.117171.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251316.5.clone.1, %broadcast.244416.5760) + %or.116680.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110892.9.clone.1, %shift-right-logical.117171.9.clone.1) + %xor.123248.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251318.5.clone.1, %or.116680.7.clone.1) + %add.251319.3.clone.1 = u32[1280,1280]{1,0} add(%add.251318.5.clone.1, %xor.123248.5.clone.1) + %shift-left.110893.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123248.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117172.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123248.5.clone.1, %broadcast.244429.2304) + %or.116681.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110893.9.clone.1, %shift-right-logical.117172.9.clone.1) + %xor.123249.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251319.3.clone.1, %or.116681.7.clone.1) + %add.251320.3.clone.1 = u32[1280,1280]{1,0} add(%add.251319.3.clone.1, %xor.123249.5.clone.1) + %shift-left.110894.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123249.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117173.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123249.5.clone.1, %broadcast.244430.4608) + %or.116682.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110894.9.clone.1, %shift-right-logical.117173.9.clone.1) + %xor.123250.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251320.3.clone.1, %or.116682.7.clone.1) + %add.251321.3.clone.1 = u32[1280,1280]{1,0} add(%add.251320.3.clone.1, %xor.123250.5.clone.1) + %add.251322.7.clone.1 = u32[1280,1280]{1,0} add(%add.251321.3.clone.1, %broadcast.257059.113.clone.1) + %shift-left.110895.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123250.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117174.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123250.5.clone.1, %broadcast.244434.2816) + %or.116683.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110895.11.clone.1, %shift-right-logical.117174.11.clone.1) + %xor.123251.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251321.3.clone.1, %or.116683.9.clone.1) + %constant_218562_1_clone_1 = u32[] constant(2071344208) + %broadcast.257099.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218562_1_clone_1), dimensions={} + %add.251324.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123251.7.clone.1, %broadcast.257099.5.clone.1) + %add.251327.5.clone.1 = u32[1280,1280]{1,0} add(%add.251322.7.clone.1, %add.251324.5.clone.1) + %shift-left.110896.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251324.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117175.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251324.5.clone.1, %broadcast.244415.6016) + %or.116684.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110896.9.clone.1, %shift-right-logical.117175.9.clone.1) + %xor.123252.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251327.5.clone.1, %or.116684.7.clone.1) + %add.251328.3.clone.1 = u32[1280,1280]{1,0} add(%add.251327.5.clone.1, %xor.123252.5.clone.1) + %shift-left.110897.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123252.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117176.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123252.5.clone.1, %broadcast.244417.5760) + %or.116685.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110897.9.clone.1, %shift-right-logical.117176.9.clone.1) + %xor.123253.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251328.3.clone.1, %or.116685.7.clone.1) + %add.251329.3.clone.1 = u32[1280,1280]{1,0} add(%add.251328.3.clone.1, %xor.123253.5.clone.1) + 
%shift-left.110898.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123253.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117177.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123253.5.clone.1, %broadcast.244419.4352) + %or.116686.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110898.5.clone.1, %shift-right-logical.117177.5.clone.1) + %xor.123254.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251329.3.clone.1, %or.116686.3.clone.1) + %add.251330.3.clone.1 = u32[1280,1280]{1,0} add(%add.251329.3.clone.1, %xor.123254.3.clone.1) + %add.251331.17.clone.1 = u32[1280,1280]{1,0} add(%add.251330.3.clone.1, %broadcast.257076.24.clone.1) + %shift-left.110899.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123254.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117178.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123254.3.clone.1, %broadcast.244418.4352) + %or.116687.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110899.5.clone.1, %shift-right-logical.117178.5.clone.1) + %xor.123255.15.clone.1 = u32[1280,1280]{1,0} xor(%add.251330.3.clone.1, %or.116687.3.clone.1) + %constant_218563_1_clone_1 = u32[] constant(3486608633) + %broadcast.257109.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218563_1_clone_1), dimensions={} + %add.251332.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123255.15.clone.1, %broadcast.257109.19.clone.1) + %xor.123256.17.clone.1 = u32[1280,1280]{1,0} xor(%add.251331.17.clone.1, %add.251332.19.clone.1) + %shift-right-logical.117179.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123256.17.clone.1, %broadcast.244468.1920) + %or.116688.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117179.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5813.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116688.13.clone.1) + %add.251333.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5813.11.clone.1, %broadcast.244470.1152) + %multiply.27069.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251333.9.clone.1, %broadcast.244471.896) + %add.251334.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27069.7.clone.1, %broadcast.244408.1024) + %maximum.3745.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.251334.5.clone.1) + %abs.1579.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3745.3.clone.1) + %compare.7320.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1579.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27070.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3745.3.clone.1, %broadcast.244476.1152) + %negate.4663.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3745.3.clone.1) + %multiply.27071.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3745.3.clone.1, %negate.4663.5.clone.1) + %log-plus-one.1579.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27071.5.clone.1) + %negate.4664.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1579.3.clone.1) + %compare.7321.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4664.4.clone.1, %broadcast.244477.384), direction=LT + %select.21455.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21456.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21457.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21458.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21459.3.clone.1 = 
f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21460.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21461.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21462.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21463.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.251335.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4664.4.clone.1, %broadcast.244496.640) + %sqrt.1579.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4664.4.clone.1) + %add.251336.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1579.5.clone.1, %broadcast.244498.640) + %select.21464.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7321.3.clone.1, %add.251335.5.clone.1, %add.251336.5.clone.1) + %multiply.27072.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21463.3.clone.1, %select.21464.3.clone.1) + %add.251337.1.clone.1 = f32[1280,1280]{1,0} add(%select.21462.3.clone.1, %multiply.27072.1.clone.1) + %multiply.27073.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251337.1.clone.1, %select.21464.3.clone.1) + %add.251338.1.clone.1 = f32[1280,1280]{1,0} add(%select.21461.3.clone.1, %multiply.27073.1.clone.1) + %multiply.27074.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251338.1.clone.1, %select.21464.3.clone.1) + %add.251339.1.clone.1 = f32[1280,1280]{1,0} add(%select.21460.3.clone.1, %multiply.27074.1.clone.1) + %multiply.27075.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251339.1.clone.1, %select.21464.3.clone.1) + %add.251340.1.clone.1 = f32[1280,1280]{1,0} add(%select.21459.3.clone.1, %multiply.27075.1.clone.1) + %multiply.27076.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251340.1.clone.1, %select.21464.3.clone.1) + %add.251341.3.clone.1 = f32[1280,1280]{1,0} add(%select.21458.5.clone.1, %multiply.27076.1.clone.1) + %multiply.27077.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251341.3.clone.1, %select.21464.3.clone.1) + %add.251342.3.clone.1 = f32[1280,1280]{1,0} add(%select.21457.5.clone.1, %multiply.27077.1.clone.1) + %multiply.27078.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251342.3.clone.1, %select.21464.3.clone.1) + %add.251343.9.clone.1 = f32[1280,1280]{1,0} add(%select.21456.11.clone.1, %multiply.27078.7.clone.1) + %multiply.27079.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251343.9.clone.1, %select.21464.3.clone.1) + %add.251344.7.clone.1 = f32[1280,1280]{1,0} add(%select.21455.7.clone.1, %multiply.27079.7.clone.1) + %multiply.27080.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251344.7.clone.1, %maximum.3745.3.clone.1) + %select.21465.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7320.3.clone.1, %multiply.27070.9.clone.1, %multiply.27080.7.clone.1) + %multiply.27081.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21465.7.clone.1, %broadcast.244500.640) + %clamp.1223.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27081.5.clone.1, %broadcast.244501.384) + %multiply.27082.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1223.3.clone.1, %broadcast.244502.1) + %constant_177052_1_clone_1 = u32[] constant(196262093) + %broadcast.252577.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_177052_1_clone_1), dimensions={} + %add.248749.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.252577.44.clone.1) + %constant_177063_1_clone_1 = u32[] 
constant(157763338) + %broadcast.252578.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_177063_1_clone_1), dimensions={} + %add.248750.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.252578.113.clone.1) + %add.248752.35.clone.1 = u32[1280,1280]{1,0} add(%add.248749.37.clone.1, %add.248750.99.clone.1) + %shift-left.109762.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248750.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115961.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248750.99.clone.1, %broadcast.244415.6016) + %or.115485.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109762.31.clone.1, %shift-right-logical.115961.29.clone.1) + %xor.122030.27.clone.1 = u32[1280,1280]{1,0} xor(%add.248752.35.clone.1, %or.115485.29.clone.1) + %add.248753.5.clone.1 = u32[1280,1280]{1,0} add(%add.248752.35.clone.1, %xor.122030.27.clone.1) + %shift-left.109764.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122030.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115962.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122030.27.clone.1, %broadcast.244417.5760) + %or.115486.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109764.9.clone.1, %shift-right-logical.115962.9.clone.1) + %xor.122031.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248753.5.clone.1, %or.115486.7.clone.1) + %add.248754.3.clone.1 = u32[1280,1280]{1,0} add(%add.248753.5.clone.1, %xor.122031.5.clone.1) + %shift-left.109765.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122031.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115963.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122031.5.clone.1, %broadcast.244419.4352) + %or.115487.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109765.5.clone.1, %shift-right-logical.115963.5.clone.1) + %xor.122032.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248754.3.clone.1, %or.115487.3.clone.1) + %add.248755.3.clone.1 = u32[1280,1280]{1,0} add(%add.248754.3.clone.1, %xor.122032.3.clone.1) + %add.248757.7.clone.1 = u32[1280,1280]{1,0} add(%add.248755.3.clone.1, %broadcast.252578.113.clone.1) + %shift-left.109766.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122032.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115964.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122032.3.clone.1, %broadcast.244418.4352) + %or.115488.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109766.5.clone.1, %shift-right-logical.115964.5.clone.1) + %xor.122033.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248755.3.clone.1, %or.115488.3.clone.1) + %constant_218282_1_clone_1 = u32[] constant(419750942) + %broadcast.252590.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218282_1_clone_1), dimensions={} + %add.248758.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122033.3.clone.1, %broadcast.252590.5.clone.1) + %add.248759.5.clone.1 = u32[1280,1280]{1,0} add(%add.248757.7.clone.1, %add.248758.5.clone.1) + %shift-left.109767.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248758.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115965.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248758.5.clone.1, %broadcast.244416.5760) + %or.115489.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109767.9.clone.1, %shift-right-logical.115965.9.clone.1) + %xor.122034.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248759.5.clone.1, %or.115489.7.clone.1) + %add.248760.3.clone.1 = u32[1280,1280]{1,0} add(%add.248759.5.clone.1, %xor.122034.5.clone.1) + %shift-left.109769.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122034.5.clone.1, 
%broadcast.244428.2304) + %shift-right-logical.115966.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122034.5.clone.1, %broadcast.244429.2304) + %or.115490.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109769.9.clone.1, %shift-right-logical.115966.9.clone.1) + %xor.122035.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248760.3.clone.1, %or.115490.7.clone.1) + %add.248761.3.clone.1 = u32[1280,1280]{1,0} add(%add.248760.3.clone.1, %xor.122035.5.clone.1) + %shift-left.109770.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122035.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115967.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122035.5.clone.1, %broadcast.244430.4608) + %or.115491.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109770.9.clone.1, %shift-right-logical.115967.9.clone.1) + %xor.122036.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248761.3.clone.1, %or.115491.7.clone.1) + %add.248763.3.clone.1 = u32[1280,1280]{1,0} add(%add.248761.3.clone.1, %xor.122036.5.clone.1) + %constant_177065_1_clone_1 = u32[] constant(419750941) + %broadcast.252597.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_177065_1_clone_1), dimensions={} + %add.248767.7.clone.1 = u32[1280,1280]{1,0} add(%add.248763.3.clone.1, %broadcast.252597.24.clone.1) + %shift-left.109771.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122036.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115968.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122036.5.clone.1, %broadcast.244434.2816) + %or.115492.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109771.11.clone.1, %shift-right-logical.115968.11.clone.1) + %xor.122037.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248763.3.clone.1, %or.115492.9.clone.1) + %constant_218283_1_clone_1 = u32[] constant(196262095) + %broadcast.252600.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218283_1_clone_1), dimensions={} + %add.248768.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122037.7.clone.1, %broadcast.252600.5.clone.1) + %add.248769.5.clone.1 = u32[1280,1280]{1,0} add(%add.248767.7.clone.1, %add.248768.5.clone.1) + %shift-left.109772.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248768.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115969.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248768.5.clone.1, %broadcast.244415.6016) + %or.115493.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109772.9.clone.1, %shift-right-logical.115969.9.clone.1) + %xor.122038.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248769.5.clone.1, %or.115493.7.clone.1) + %add.248770.3.clone.1 = u32[1280,1280]{1,0} add(%add.248769.5.clone.1, %xor.122038.5.clone.1) + %shift-left.109773.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122038.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115970.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122038.5.clone.1, %broadcast.244417.5760) + %or.115495.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109773.9.clone.1, %shift-right-logical.115970.9.clone.1) + %xor.122039.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248770.3.clone.1, %or.115495.7.clone.1) + %add.248772.3.clone.1 = u32[1280,1280]{1,0} add(%add.248770.3.clone.1, %xor.122039.5.clone.1) + %shift-left.109774.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122039.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115971.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122039.5.clone.1, %broadcast.244419.4352) + %or.115496.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109774.7.clone.1, %shift-right-logical.115971.7.clone.1) + 
%xor.122040.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248772.3.clone.1, %or.115496.5.clone.1) + %add.248773.3.clone.1 = u32[1280,1280]{1,0} add(%add.248772.3.clone.1, %xor.122040.3.clone.1) + %add.248774.7.clone.1 = u32[1280,1280]{1,0} add(%add.248773.3.clone.1, %broadcast.252577.44.clone.1) + %shift-left.109775.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122040.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115972.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122040.3.clone.1, %broadcast.244418.4352) + %or.115498.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109775.7.clone.1, %shift-right-logical.115972.7.clone.1) + %xor.122041.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248773.3.clone.1, %or.115498.5.clone.1) + %constant_218284_1_clone_1 = u32[] constant(157763341) + %broadcast.252612.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218284_1_clone_1), dimensions={} + %add.248775.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122041.3.clone.1, %broadcast.252612.5.clone.1) + %add.248777.5.clone.1 = u32[1280,1280]{1,0} add(%add.248774.7.clone.1, %add.248775.5.clone.1) + %shift-left.109776.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248775.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115973.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248775.5.clone.1, %broadcast.244416.5760) + %or.115499.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109776.9.clone.1, %shift-right-logical.115973.9.clone.1) + %xor.122042.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248777.5.clone.1, %or.115499.7.clone.1) + %add.248778.3.clone.1 = u32[1280,1280]{1,0} add(%add.248777.5.clone.1, %xor.122042.5.clone.1) + %shift-left.109777.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122042.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115974.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122042.5.clone.1, %broadcast.244429.2304) + %or.115500.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109777.9.clone.1, %shift-right-logical.115974.9.clone.1) + %xor.122043.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248778.3.clone.1, %or.115500.7.clone.1) + %add.248779.3.clone.1 = u32[1280,1280]{1,0} add(%add.248778.3.clone.1, %xor.122043.5.clone.1) + %shift-left.109779.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122043.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115975.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122043.5.clone.1, %broadcast.244430.4608) + %or.115501.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109779.9.clone.1, %shift-right-logical.115975.9.clone.1) + %xor.122044.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248779.3.clone.1, %or.115501.7.clone.1) + %add.248780.3.clone.1 = u32[1280,1280]{1,0} add(%add.248779.3.clone.1, %xor.122044.5.clone.1) + %add.248782.7.clone.1 = u32[1280,1280]{1,0} add(%add.248780.3.clone.1, %broadcast.252578.113.clone.1) + %shift-left.109780.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122044.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115976.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122044.5.clone.1, %broadcast.244434.2816) + %or.115502.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109780.11.clone.1, %shift-right-logical.115976.11.clone.1) + %xor.122045.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248780.3.clone.1, %or.115502.9.clone.1) + %constant_218285_1_clone_1 = u32[] constant(419750945) + %broadcast.252622.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218285_1_clone_1), dimensions={} + %add.248783.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122045.7.clone.1, 
%broadcast.252622.5.clone.1) + %add.248784.5.clone.1 = u32[1280,1280]{1,0} add(%add.248782.7.clone.1, %add.248783.5.clone.1) + %shift-left.109781.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248783.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115977.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248783.5.clone.1, %broadcast.244415.6016) + %or.115503.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109781.9.clone.1, %shift-right-logical.115977.9.clone.1) + %xor.122046.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248784.5.clone.1, %or.115503.7.clone.1) + %add.248785.3.clone.1 = u32[1280,1280]{1,0} add(%add.248784.5.clone.1, %xor.122046.5.clone.1) + %shift-left.109782.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122046.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115978.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122046.5.clone.1, %broadcast.244417.5760) + %or.115504.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109782.9.clone.1, %shift-right-logical.115978.9.clone.1) + %xor.122047.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248785.3.clone.1, %or.115504.7.clone.1) + %add.248786.3.clone.1 = u32[1280,1280]{1,0} add(%add.248785.3.clone.1, %xor.122047.5.clone.1) + %shift-left.109784.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122047.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115979.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122047.5.clone.1, %broadcast.244419.4352) + %or.115505.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109784.5.clone.1, %shift-right-logical.115979.5.clone.1) + %xor.122048.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248786.3.clone.1, %or.115505.3.clone.1) + %add.248788.3.clone.1 = u32[1280,1280]{1,0} add(%add.248786.3.clone.1, %xor.122048.3.clone.1) + %add.248792.17.clone.1 = u32[1280,1280]{1,0} add(%add.248788.3.clone.1, %broadcast.252597.24.clone.1) + %shift-left.109785.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122048.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115980.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122048.3.clone.1, %broadcast.244418.4352) + %or.115506.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109785.5.clone.1, %shift-right-logical.115980.5.clone.1) + %xor.122051.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248788.3.clone.1, %or.115506.3.clone.1) + %constant_218286_1_clone_1 = u32[] constant(196262098) + %broadcast.252634.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218286_1_clone_1), dimensions={} + %add.248793.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122051.15.clone.1, %broadcast.252634.19.clone.1) + %xor.122052.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248792.17.clone.1, %add.248793.19.clone.1) + %shift-right-logical.115981.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122052.17.clone.1, %broadcast.244468.1920) + %or.115507.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115981.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5762.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115507.13.clone.1) + %add.248794.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5762.11.clone.1, %broadcast.244470.1152) + %multiply.26557.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248794.9.clone.1, %broadcast.244471.896) + %add.248795.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26557.7.clone.1, %broadcast.244408.1024) + %maximum.3694.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248795.5.clone.1) + %abs.1546.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3694.3.clone.1) + %compare.7240.3.clone.1 = 
pred[1280,1280]{1,0} compare(%abs.1546.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26558.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3694.3.clone.1, %broadcast.244476.1152) + %negate.4597.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3694.3.clone.1) + %multiply.26559.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3694.3.clone.1, %negate.4597.5.clone.1) + %log-plus-one.1546.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26559.5.clone.1) + %negate.4598.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1546.3.clone.1) + %compare.7241.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4598.4.clone.1, %broadcast.244477.384), direction=LT + %select.21092.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21093.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21094.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21095.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21096.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21097.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21098.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21099.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21100.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248797.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4598.4.clone.1, %broadcast.244496.640) + %sqrt.1546.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4598.4.clone.1) + %add.248798.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1546.5.clone.1, %broadcast.244498.640) + %select.21101.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7241.3.clone.1, %add.248797.5.clone.1, %add.248798.5.clone.1) + %multiply.26560.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21100.3.clone.1, %select.21101.3.clone.1) + %add.248799.1.clone.1 = f32[1280,1280]{1,0} add(%select.21099.3.clone.1, %multiply.26560.1.clone.1) + %multiply.26561.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248799.1.clone.1, %select.21101.3.clone.1) + %add.248800.1.clone.1 = f32[1280,1280]{1,0} add(%select.21098.3.clone.1, %multiply.26561.1.clone.1) + %multiply.26562.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248800.1.clone.1, %select.21101.3.clone.1) + %add.248802.1.clone.1 = f32[1280,1280]{1,0} add(%select.21097.3.clone.1, %multiply.26562.1.clone.1) + %multiply.26563.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248802.1.clone.1, %select.21101.3.clone.1) + %add.248803.1.clone.1 = f32[1280,1280]{1,0} add(%select.21096.3.clone.1, %multiply.26563.1.clone.1) + %multiply.26564.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248803.1.clone.1, %select.21101.3.clone.1) + %add.248804.3.clone.1 = f32[1280,1280]{1,0} add(%select.21095.5.clone.1, %multiply.26564.1.clone.1) + %multiply.26565.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248804.3.clone.1, %select.21101.3.clone.1) + %add.248805.3.clone.1 = f32[1280,1280]{1,0} add(%select.21094.5.clone.1, %multiply.26565.1.clone.1) + %multiply.26566.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248805.3.clone.1, 
%select.21101.3.clone.1) + %add.248807.9.clone.1 = f32[1280,1280]{1,0} add(%select.21093.11.clone.1, %multiply.26566.7.clone.1) + %multiply.26567.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248807.9.clone.1, %select.21101.3.clone.1) + %add.248808.7.clone.1 = f32[1280,1280]{1,0} add(%select.21092.7.clone.1, %multiply.26567.7.clone.1) + %multiply.26568.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248808.7.clone.1, %maximum.3694.3.clone.1) + %select.21102.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7240.3.clone.1, %multiply.26558.9.clone.1, %multiply.26568.7.clone.1) + %multiply.26569.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21102.7.clone.1, %broadcast.244500.640) + %clamp.1190.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26569.5.clone.1, %broadcast.244501.384) + %multiply.26570.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1190.3.clone.1, %broadcast.244502.1) + %constant_192244_1_clone_1 = u32[] constant(2225837459) + %broadcast.259151.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_192244_1_clone_1), dimensions={} + %add.252505.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.259151.44.clone.1) + %constant_192251_1_clone_1 = u32[] constant(1598632266) + %broadcast.259152.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_192251_1_clone_1), dimensions={} + %add.252506.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.259152.113.clone.1) + %add.252507.35.clone.1 = u32[1280,1280]{1,0} add(%add.252505.37.clone.1, %add.252506.99.clone.1) + %shift-left.111400.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252506.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117702.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252506.99.clone.1, %broadcast.244415.6016) + %or.117230.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111400.31.clone.1, %shift-right-logical.117702.29.clone.1) + %xor.123785.27.clone.1 = u32[1280,1280]{1,0} xor(%add.252507.35.clone.1, %or.117230.29.clone.1) + %add.252508.5.clone.1 = u32[1280,1280]{1,0} add(%add.252507.35.clone.1, %xor.123785.27.clone.1) + %shift-left.111401.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123785.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117703.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123785.27.clone.1, %broadcast.244417.5760) + %or.117231.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111401.9.clone.1, %shift-right-logical.117703.9.clone.1) + %xor.123786.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252508.5.clone.1, %or.117231.7.clone.1) + %add.252509.3.clone.1 = u32[1280,1280]{1,0} add(%add.252508.5.clone.1, %xor.123786.5.clone.1) + %shift-left.111402.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123786.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117705.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123786.5.clone.1, %broadcast.244419.4352) + %or.117232.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111402.5.clone.1, %shift-right-logical.117705.5.clone.1) + %xor.123787.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252509.3.clone.1, %or.117232.3.clone.1) + %add.252510.3.clone.1 = u32[1280,1280]{1,0} add(%add.252509.3.clone.1, %xor.123787.3.clone.1) + %add.252511.7.clone.1 = u32[1280,1280]{1,0} add(%add.252510.3.clone.1, %broadcast.259152.113.clone.1) + %shift-left.111403.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123787.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117706.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123787.3.clone.1, %broadcast.244418.4352) + 
%or.117233.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111403.5.clone.1, %shift-right-logical.117706.5.clone.1) + %xor.123788.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252510.3.clone.1, %or.117233.3.clone.1) + %constant_218695_1_clone_1 = u32[] constant(3224616708) + %broadcast.259162.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218695_1_clone_1), dimensions={} + %add.252512.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123788.3.clone.1, %broadcast.259162.5.clone.1) + %add.252513.5.clone.1 = u32[1280,1280]{1,0} add(%add.252511.7.clone.1, %add.252512.5.clone.1) + %shift-left.111404.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252512.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117707.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252512.5.clone.1, %broadcast.244416.5760) + %or.117234.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111404.9.clone.1, %shift-right-logical.117707.9.clone.1) + %xor.123789.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252513.5.clone.1, %or.117234.7.clone.1) + %add.252514.3.clone.1 = u32[1280,1280]{1,0} add(%add.252513.5.clone.1, %xor.123789.5.clone.1) + %shift-left.111405.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123789.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117708.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123789.5.clone.1, %broadcast.244429.2304) + %or.117235.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111405.9.clone.1, %shift-right-logical.117708.9.clone.1) + %xor.123790.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252514.3.clone.1, %or.117235.7.clone.1) + %add.252515.3.clone.1 = u32[1280,1280]{1,0} add(%add.252514.3.clone.1, %xor.123790.5.clone.1) + %shift-left.111406.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123790.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117710.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123790.5.clone.1, %broadcast.244430.4608) + %or.117236.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111406.9.clone.1, %shift-right-logical.117710.9.clone.1) + %xor.123791.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252515.3.clone.1, %or.117236.7.clone.1) + %add.252516.3.clone.1 = u32[1280,1280]{1,0} add(%add.252515.3.clone.1, %xor.123791.5.clone.1) + %constant_192253_1_clone_1 = u32[] constant(3224616707) + %broadcast.259169.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_192253_1_clone_1), dimensions={} + %add.252517.7.clone.1 = u32[1280,1280]{1,0} add(%add.252516.3.clone.1, %broadcast.259169.24.clone.1) + %shift-left.111407.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123791.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117711.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123791.5.clone.1, %broadcast.244434.2816) + %or.117237.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111407.11.clone.1, %shift-right-logical.117711.11.clone.1) + %xor.123792.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252516.3.clone.1, %or.117237.9.clone.1) + %constant_218696_1_clone_1 = u32[] constant(2225837461) + %broadcast.259172.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218696_1_clone_1), dimensions={} + %add.252518.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123792.7.clone.1, %broadcast.259172.5.clone.1) + %add.252519.5.clone.1 = u32[1280,1280]{1,0} add(%add.252517.7.clone.1, %add.252518.5.clone.1) + %shift-left.111408.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252518.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117712.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252518.5.clone.1, %broadcast.244415.6016) + 
%or.117238.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111408.9.clone.1, %shift-right-logical.117712.9.clone.1) + %xor.123793.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252519.5.clone.1, %or.117238.7.clone.1) + %add.252520.3.clone.1 = u32[1280,1280]{1,0} add(%add.252519.5.clone.1, %xor.123793.5.clone.1) + %shift-left.111409.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123793.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117713.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123793.5.clone.1, %broadcast.244417.5760) + %or.117239.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111409.9.clone.1, %shift-right-logical.117713.9.clone.1) + %xor.123794.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252520.3.clone.1, %or.117239.7.clone.1) + %add.252521.3.clone.1 = u32[1280,1280]{1,0} add(%add.252520.3.clone.1, %xor.123794.5.clone.1) + %shift-left.111410.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123794.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117715.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123794.5.clone.1, %broadcast.244419.4352) + %or.117240.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111410.7.clone.1, %shift-right-logical.117715.7.clone.1) + %xor.123795.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252521.3.clone.1, %or.117240.5.clone.1) + %add.252522.3.clone.1 = u32[1280,1280]{1,0} add(%add.252521.3.clone.1, %xor.123795.3.clone.1) + %add.252523.7.clone.1 = u32[1280,1280]{1,0} add(%add.252522.3.clone.1, %broadcast.259151.44.clone.1) + %shift-left.111411.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123795.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117716.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123795.3.clone.1, %broadcast.244418.4352) + %or.117241.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111411.7.clone.1, %shift-right-logical.117716.7.clone.1) + %xor.123796.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252522.3.clone.1, %or.117241.5.clone.1) + %constant_218697_1_clone_1 = u32[] constant(1598632269) + %broadcast.259182.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218697_1_clone_1), dimensions={} + %add.252524.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123796.3.clone.1, %broadcast.259182.5.clone.1) + %add.252525.5.clone.1 = u32[1280,1280]{1,0} add(%add.252523.7.clone.1, %add.252524.5.clone.1) + %shift-left.111412.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252524.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117717.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252524.5.clone.1, %broadcast.244416.5760) + %or.117242.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111412.9.clone.1, %shift-right-logical.117717.9.clone.1) + %xor.123797.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252525.5.clone.1, %or.117242.7.clone.1) + %add.252526.3.clone.1 = u32[1280,1280]{1,0} add(%add.252525.5.clone.1, %xor.123797.5.clone.1) + %shift-left.111413.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123797.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117718.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123797.5.clone.1, %broadcast.244429.2304) + %or.117243.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111413.9.clone.1, %shift-right-logical.117718.9.clone.1) + %xor.123798.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252526.3.clone.1, %or.117243.7.clone.1) + %add.252527.3.clone.1 = u32[1280,1280]{1,0} add(%add.252526.3.clone.1, %xor.123798.5.clone.1) + %shift-left.111414.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123798.5.clone.1, %broadcast.244430.4608) + 
%shift-right-logical.117720.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123798.5.clone.1, %broadcast.244430.4608) + %or.117244.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111414.9.clone.1, %shift-right-logical.117720.9.clone.1) + %xor.123799.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252527.3.clone.1, %or.117244.7.clone.1) + %add.252528.3.clone.1 = u32[1280,1280]{1,0} add(%add.252527.3.clone.1, %xor.123799.5.clone.1) + %add.252529.7.clone.1 = u32[1280,1280]{1,0} add(%add.252528.3.clone.1, %broadcast.259152.113.clone.1) + %shift-left.111415.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123799.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117721.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123799.5.clone.1, %broadcast.244434.2816) + %or.117245.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111415.11.clone.1, %shift-right-logical.117721.11.clone.1) + %xor.123800.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252528.3.clone.1, %or.117245.9.clone.1) + %constant_218698_1_clone_1 = u32[] constant(3224616711) + %broadcast.259192.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218698_1_clone_1), dimensions={} + %add.252530.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123800.7.clone.1, %broadcast.259192.5.clone.1) + %add.252531.5.clone.1 = u32[1280,1280]{1,0} add(%add.252529.7.clone.1, %add.252530.5.clone.1) + %shift-left.111416.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252530.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117722.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252530.5.clone.1, %broadcast.244415.6016) + %or.117246.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111416.9.clone.1, %shift-right-logical.117722.9.clone.1) + %xor.123801.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252531.5.clone.1, %or.117246.7.clone.1) + %add.252532.3.clone.1 = u32[1280,1280]{1,0} add(%add.252531.5.clone.1, %xor.123801.5.clone.1) + %shift-left.111417.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123801.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117723.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123801.5.clone.1, %broadcast.244417.5760) + %or.117247.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111417.9.clone.1, %shift-right-logical.117723.9.clone.1) + %xor.123802.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252532.3.clone.1, %or.117247.7.clone.1) + %add.252533.3.clone.1 = u32[1280,1280]{1,0} add(%add.252532.3.clone.1, %xor.123802.5.clone.1) + %shift-left.111418.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123802.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117724.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123802.5.clone.1, %broadcast.244419.4352) + %or.117248.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111418.5.clone.1, %shift-right-logical.117724.5.clone.1) + %xor.123803.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252533.3.clone.1, %or.117248.3.clone.1) + %add.252534.3.clone.1 = u32[1280,1280]{1,0} add(%add.252533.3.clone.1, %xor.123803.3.clone.1) + %add.252535.17.clone.1 = u32[1280,1280]{1,0} add(%add.252534.3.clone.1, %broadcast.259169.24.clone.1) + %shift-left.111419.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123803.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117725.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123803.3.clone.1, %broadcast.244418.4352) + %or.117249.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111419.5.clone.1, %shift-right-logical.117725.5.clone.1) + %xor.123804.15.clone.1 = u32[1280,1280]{1,0} xor(%add.252534.3.clone.1, 
%or.117249.3.clone.1) + %constant_218699_1_clone_1 = u32[] constant(2225837464) + %broadcast.259202.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218699_1_clone_1), dimensions={} + %add.252536.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123804.15.clone.1, %broadcast.259202.19.clone.1) + %xor.123805.17.clone.1 = u32[1280,1280]{1,0} xor(%add.252535.17.clone.1, %add.252536.19.clone.1) + %shift-right-logical.117726.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123805.17.clone.1, %broadcast.244468.1920) + %or.117250.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117726.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5837.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117250.13.clone.1) + %add.252537.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5837.11.clone.1, %broadcast.244470.1152) + %multiply.27310.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252537.9.clone.1, %broadcast.244471.896) + %add.252538.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27310.7.clone.1, %broadcast.244408.1024) + %maximum.3769.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.252538.5.clone.1) + %abs.1595.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3769.3.clone.1) + %compare.7352.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1595.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27311.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3769.3.clone.1, %broadcast.244476.1152) + %negate.4695.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3769.3.clone.1) + %multiply.27312.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3769.3.clone.1, %negate.4695.5.clone.1) + %log-plus-one.1595.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27312.5.clone.1) + %negate.4696.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1595.3.clone.1) + %compare.7353.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4696.4.clone.1, %broadcast.244477.384), direction=LT + %select.21652.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21653.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21654.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21655.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21656.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21657.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21658.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21659.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21660.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.252539.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4696.4.clone.1, %broadcast.244496.640) + %sqrt.1595.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4696.4.clone.1) + %add.252540.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1595.5.clone.1, %broadcast.244498.640) + %select.21661.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7353.3.clone.1, %add.252539.5.clone.1, %add.252540.5.clone.1) + %multiply.27313.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21660.3.clone.1, 
%select.21661.3.clone.1) + %add.252541.1.clone.1 = f32[1280,1280]{1,0} add(%select.21659.3.clone.1, %multiply.27313.1.clone.1) + %multiply.27314.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252541.1.clone.1, %select.21661.3.clone.1) + %add.252542.1.clone.1 = f32[1280,1280]{1,0} add(%select.21658.3.clone.1, %multiply.27314.1.clone.1) + %multiply.27315.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252542.1.clone.1, %select.21661.3.clone.1) + %add.252543.1.clone.1 = f32[1280,1280]{1,0} add(%select.21657.3.clone.1, %multiply.27315.1.clone.1) + %multiply.27316.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252543.1.clone.1, %select.21661.3.clone.1) + %add.252544.1.clone.1 = f32[1280,1280]{1,0} add(%select.21656.3.clone.1, %multiply.27316.1.clone.1) + %multiply.27317.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252544.1.clone.1, %select.21661.3.clone.1) + %add.252545.3.clone.1 = f32[1280,1280]{1,0} add(%select.21655.5.clone.1, %multiply.27317.1.clone.1) + %multiply.27318.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252545.3.clone.1, %select.21661.3.clone.1) + %add.252546.3.clone.1 = f32[1280,1280]{1,0} add(%select.21654.5.clone.1, %multiply.27318.1.clone.1) + %multiply.27319.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252546.3.clone.1, %select.21661.3.clone.1) + %add.252547.9.clone.1 = f32[1280,1280]{1,0} add(%select.21653.11.clone.1, %multiply.27319.7.clone.1) + %multiply.27320.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252547.9.clone.1, %select.21661.3.clone.1) + %add.252548.7.clone.1 = f32[1280,1280]{1,0} add(%select.21652.7.clone.1, %multiply.27320.7.clone.1) + %multiply.27321.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252548.7.clone.1, %maximum.3769.3.clone.1) + %select.21662.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7352.3.clone.1, %multiply.27311.9.clone.1, %multiply.27321.7.clone.1) + %multiply.27322.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21662.7.clone.1, %broadcast.244500.640) + %clamp.1239.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27322.5.clone.1, %broadcast.244501.384) + %multiply.27323.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1239.3.clone.1, %broadcast.244502.1) + %constant_176829_1_clone_1 = u32[] constant(109760836) + %broadcast.252474.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_176829_1_clone_1), dimensions={} + %add.248687.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.252474.44.clone.1) + %constant_176836_1_clone_1 = u32[] constant(546401704) + %broadcast.252475.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_176836_1_clone_1), dimensions={} + %add.248689.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.252475.113.clone.1) + %add.248692.35.clone.1 = u32[1280,1280]{1,0} add(%add.248687.37.clone.1, %add.248689.99.clone.1) + %shift-left.109740.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248689.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115940.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248689.99.clone.1, %broadcast.244415.6016) + %or.115462.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109740.31.clone.1, %shift-right-logical.115940.29.clone.1) + %xor.122009.27.clone.1 = u32[1280,1280]{1,0} xor(%add.248692.35.clone.1, %or.115462.29.clone.1) + %add.248693.5.clone.1 = u32[1280,1280]{1,0} add(%add.248692.35.clone.1, %xor.122009.27.clone.1) + %shift-left.109741.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122009.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115941.9.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.122009.27.clone.1, %broadcast.244417.5760) + %or.115463.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109741.9.clone.1, %shift-right-logical.115941.9.clone.1) + %xor.122010.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248693.5.clone.1, %or.115463.7.clone.1) + %add.248694.3.clone.1 = u32[1280,1280]{1,0} add(%add.248693.5.clone.1, %xor.122010.5.clone.1) + %shift-left.109742.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122010.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115942.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122010.5.clone.1, %broadcast.244419.4352) + %or.115464.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109742.5.clone.1, %shift-right-logical.115942.5.clone.1) + %xor.122011.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248694.3.clone.1, %or.115464.3.clone.1) + %add.248695.3.clone.1 = u32[1280,1280]{1,0} add(%add.248694.3.clone.1, %xor.122011.3.clone.1) + %add.248697.7.clone.1 = u32[1280,1280]{1,0} add(%add.248695.3.clone.1, %broadcast.252475.113.clone.1) + %shift-left.109743.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122011.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115943.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122011.3.clone.1, %broadcast.244418.4352) + %or.115465.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109743.5.clone.1, %shift-right-logical.115943.5.clone.1) + %xor.122012.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248695.3.clone.1, %or.115465.3.clone.1) + %constant_218277_1_clone_1 = u32[] constant(1036691255) + %broadcast.252485.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218277_1_clone_1), dimensions={} + %add.248698.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122012.3.clone.1, %broadcast.252485.5.clone.1) + %add.248699.5.clone.1 = u32[1280,1280]{1,0} add(%add.248697.7.clone.1, %add.248698.5.clone.1) + %shift-left.109744.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248698.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115944.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248698.5.clone.1, %broadcast.244416.5760) + %or.115466.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109744.9.clone.1, %shift-right-logical.115944.9.clone.1) + %xor.122013.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248699.5.clone.1, %or.115466.7.clone.1) + %add.248700.3.clone.1 = u32[1280,1280]{1,0} add(%add.248699.5.clone.1, %xor.122013.5.clone.1) + %shift-left.109745.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122013.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115945.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122013.5.clone.1, %broadcast.244429.2304) + %or.115467.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109745.9.clone.1, %shift-right-logical.115945.9.clone.1) + %xor.122014.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248700.3.clone.1, %or.115467.7.clone.1) + %add.248702.3.clone.1 = u32[1280,1280]{1,0} add(%add.248700.3.clone.1, %xor.122014.5.clone.1) + %shift-left.109746.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122014.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115946.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122014.5.clone.1, %broadcast.244430.4608) + %or.115468.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109746.9.clone.1, %shift-right-logical.115946.9.clone.1) + %xor.122015.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248702.3.clone.1, %or.115468.7.clone.1) + %add.248703.3.clone.1 = u32[1280,1280]{1,0} add(%add.248702.3.clone.1, %xor.122015.5.clone.1) + %constant_176838_1_clone_1 = u32[] constant(1036691254) + 
%broadcast.252492.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_176838_1_clone_1), dimensions={} + %add.248704.7.clone.1 = u32[1280,1280]{1,0} add(%add.248703.3.clone.1, %broadcast.252492.24.clone.1) + %shift-left.109747.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122015.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115947.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122015.5.clone.1, %broadcast.244434.2816) + %or.115469.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109747.11.clone.1, %shift-right-logical.115947.11.clone.1) + %xor.122016.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248703.3.clone.1, %or.115469.9.clone.1) + %constant_218278_1_clone_1 = u32[] constant(109760838) + %broadcast.252495.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218278_1_clone_1), dimensions={} + %add.248705.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122016.7.clone.1, %broadcast.252495.5.clone.1) + %add.248707.5.clone.1 = u32[1280,1280]{1,0} add(%add.248704.7.clone.1, %add.248705.5.clone.1) + %shift-left.109748.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248705.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115948.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248705.5.clone.1, %broadcast.244415.6016) + %or.115470.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109748.9.clone.1, %shift-right-logical.115948.9.clone.1) + %xor.122017.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248707.5.clone.1, %or.115470.7.clone.1) + %add.248708.3.clone.1 = u32[1280,1280]{1,0} add(%add.248707.5.clone.1, %xor.122017.5.clone.1) + %shift-left.109749.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122017.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115949.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122017.5.clone.1, %broadcast.244417.5760) + %or.115471.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109749.9.clone.1, %shift-right-logical.115949.9.clone.1) + %xor.122018.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248708.3.clone.1, %or.115471.7.clone.1) + %add.248709.3.clone.1 = u32[1280,1280]{1,0} add(%add.248708.3.clone.1, %xor.122018.5.clone.1) + %shift-left.109750.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122018.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115950.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122018.5.clone.1, %broadcast.244419.4352) + %or.115473.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109750.7.clone.1, %shift-right-logical.115950.7.clone.1) + %xor.122019.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248709.3.clone.1, %or.115473.5.clone.1) + %add.248710.3.clone.1 = u32[1280,1280]{1,0} add(%add.248709.3.clone.1, %xor.122019.3.clone.1) + %add.248711.7.clone.1 = u32[1280,1280]{1,0} add(%add.248710.3.clone.1, %broadcast.252474.44.clone.1) + %shift-left.109751.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122019.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115951.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122019.3.clone.1, %broadcast.244418.4352) + %or.115474.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109751.7.clone.1, %shift-right-logical.115951.7.clone.1) + %xor.122020.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248710.3.clone.1, %or.115474.5.clone.1) + %constant_218279_1_clone_1 = u32[] constant(546401707) + %broadcast.252505.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218279_1_clone_1), dimensions={} + %add.248713.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122020.3.clone.1, %broadcast.252505.5.clone.1) + %add.248717.5.clone.1 = u32[1280,1280]{1,0} 
add(%add.248711.7.clone.1, %add.248713.5.clone.1) + %shift-left.109752.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248713.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115952.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248713.5.clone.1, %broadcast.244416.5760) + %or.115476.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109752.9.clone.1, %shift-right-logical.115952.9.clone.1) + %xor.122021.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248717.5.clone.1, %or.115476.7.clone.1) + %add.248718.3.clone.1 = u32[1280,1280]{1,0} add(%add.248717.5.clone.1, %xor.122021.5.clone.1) + %shift-left.109754.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122021.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115953.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122021.5.clone.1, %broadcast.244429.2304) + %or.115477.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109754.9.clone.1, %shift-right-logical.115953.9.clone.1) + %xor.122022.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248718.3.clone.1, %or.115477.7.clone.1) + %add.248719.3.clone.1 = u32[1280,1280]{1,0} add(%add.248718.3.clone.1, %xor.122022.5.clone.1) + %shift-left.109755.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122022.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115954.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122022.5.clone.1, %broadcast.244430.4608) + %or.115478.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109755.9.clone.1, %shift-right-logical.115954.9.clone.1) + %xor.122023.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248719.3.clone.1, %or.115478.7.clone.1) + %add.248720.3.clone.1 = u32[1280,1280]{1,0} add(%add.248719.3.clone.1, %xor.122023.5.clone.1) + %add.248722.7.clone.1 = u32[1280,1280]{1,0} add(%add.248720.3.clone.1, %broadcast.252475.113.clone.1) + %shift-left.109756.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122023.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115955.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122023.5.clone.1, %broadcast.244434.2816) + %or.115479.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109756.11.clone.1, %shift-right-logical.115955.11.clone.1) + %xor.122024.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248720.3.clone.1, %or.115479.9.clone.1) + %constant_218280_1_clone_1 = u32[] constant(1036691258) + %broadcast.252515.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218280_1_clone_1), dimensions={} + %add.248723.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122024.7.clone.1, %broadcast.252515.5.clone.1) + %add.248724.5.clone.1 = u32[1280,1280]{1,0} add(%add.248722.7.clone.1, %add.248723.5.clone.1) + %shift-left.109757.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248723.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115956.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248723.5.clone.1, %broadcast.244415.6016) + %or.115480.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109757.9.clone.1, %shift-right-logical.115956.9.clone.1) + %xor.122025.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248724.5.clone.1, %or.115480.7.clone.1) + %add.248725.3.clone.1 = u32[1280,1280]{1,0} add(%add.248724.5.clone.1, %xor.122025.5.clone.1) + %shift-left.109759.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122025.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115957.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122025.5.clone.1, %broadcast.244417.5760) + %or.115481.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109759.9.clone.1, %shift-right-logical.115957.9.clone.1) + %xor.122026.5.clone.1 = 
u32[1280,1280]{1,0} xor(%add.248725.3.clone.1, %or.115481.7.clone.1) + %add.248727.3.clone.1 = u32[1280,1280]{1,0} add(%add.248725.3.clone.1, %xor.122026.5.clone.1) + %shift-left.109760.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122026.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115958.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122026.5.clone.1, %broadcast.244419.4352) + %or.115482.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109760.5.clone.1, %shift-right-logical.115958.5.clone.1) + %xor.122027.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248727.3.clone.1, %or.115482.3.clone.1) + %add.248728.3.clone.1 = u32[1280,1280]{1,0} add(%add.248727.3.clone.1, %xor.122027.3.clone.1) + %add.248729.17.clone.1 = u32[1280,1280]{1,0} add(%add.248728.3.clone.1, %broadcast.252492.24.clone.1) + %shift-left.109761.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122027.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115959.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122027.3.clone.1, %broadcast.244418.4352) + %or.115483.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109761.5.clone.1, %shift-right-logical.115959.5.clone.1) + %xor.122028.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248728.3.clone.1, %or.115483.3.clone.1) + %constant_218281_1_clone_1 = u32[] constant(109760841) + %broadcast.252528.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218281_1_clone_1), dimensions={} + %add.248730.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122028.15.clone.1, %broadcast.252528.19.clone.1) + %xor.122029.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248729.17.clone.1, %add.248730.19.clone.1) + %shift-right-logical.115960.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122029.17.clone.1, %broadcast.244468.1920) + %or.115484.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115960.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5761.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115484.13.clone.1) + %add.248732.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5761.11.clone.1, %broadcast.244470.1152) + %multiply.26542.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248732.9.clone.1, %broadcast.244471.896) + %add.248733.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26542.7.clone.1, %broadcast.244408.1024) + %maximum.3693.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248733.5.clone.1) + %abs.1545.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3693.3.clone.1) + %compare.7238.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1545.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26543.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3693.3.clone.1, %broadcast.244476.1152) + %negate.4595.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3693.3.clone.1) + %multiply.26544.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3693.3.clone.1, %negate.4595.5.clone.1) + %log-plus-one.1545.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26544.5.clone.1) + %negate.4596.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1545.3.clone.1) + %compare.7239.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4596.4.clone.1, %broadcast.244477.384), direction=LT + %select.21081.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21082.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21083.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + 
%select.21084.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21085.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21086.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21087.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21088.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21089.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248734.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4596.4.clone.1, %broadcast.244496.640) + %sqrt.1545.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4596.4.clone.1) + %add.248735.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1545.5.clone.1, %broadcast.244498.640) + %select.21090.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7239.3.clone.1, %add.248734.5.clone.1, %add.248735.5.clone.1) + %multiply.26545.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21089.3.clone.1, %select.21090.3.clone.1) + %add.248736.1.clone.1 = f32[1280,1280]{1,0} add(%select.21088.3.clone.1, %multiply.26545.1.clone.1) + %multiply.26546.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248736.1.clone.1, %select.21090.3.clone.1) + %add.248738.1.clone.1 = f32[1280,1280]{1,0} add(%select.21087.3.clone.1, %multiply.26546.1.clone.1) + %multiply.26547.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248738.1.clone.1, %select.21090.3.clone.1) + %add.248742.1.clone.1 = f32[1280,1280]{1,0} add(%select.21086.3.clone.1, %multiply.26547.1.clone.1) + %multiply.26548.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248742.1.clone.1, %select.21090.3.clone.1) + %add.248743.1.clone.1 = f32[1280,1280]{1,0} add(%select.21085.3.clone.1, %multiply.26548.1.clone.1) + %multiply.26549.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248743.1.clone.1, %select.21090.3.clone.1) + %add.248744.3.clone.1 = f32[1280,1280]{1,0} add(%select.21084.5.clone.1, %multiply.26549.1.clone.1) + %multiply.26551.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248744.3.clone.1, %select.21090.3.clone.1) + %add.248745.3.clone.1 = f32[1280,1280]{1,0} add(%select.21083.5.clone.1, %multiply.26551.1.clone.1) + %multiply.26552.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248745.3.clone.1, %select.21090.3.clone.1) + %add.248747.9.clone.1 = f32[1280,1280]{1,0} add(%select.21082.11.clone.1, %multiply.26552.7.clone.1) + %multiply.26553.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248747.9.clone.1, %select.21090.3.clone.1) + %add.248748.7.clone.1 = f32[1280,1280]{1,0} add(%select.21081.7.clone.1, %multiply.26553.7.clone.1) + %multiply.26554.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248748.7.clone.1, %maximum.3693.3.clone.1) + %select.21091.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7238.3.clone.1, %multiply.26543.9.clone.1, %multiply.26554.7.clone.1) + %multiply.26555.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21091.7.clone.1, %broadcast.244500.640) + %clamp.1189.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26555.5.clone.1, %broadcast.244501.384) + %multiply.26556.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1189.3.clone.1, %broadcast.244502.1) + %constant_186813_1_clone_1 = u32[] constant(3048909039) + %broadcast.256825.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_186813_1_clone_1), 
dimensions={} + %add.251151.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.256825.44.clone.1) + %constant_186820_1_clone_1 = u32[] constant(4279400643) + %broadcast.256826.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_186820_1_clone_1), dimensions={} + %add.251152.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.256826.113.clone.1) + %add.251153.35.clone.1 = u32[1280,1280]{1,0} add(%add.251151.37.clone.1, %add.251152.99.clone.1) + %shift-left.110820.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251152.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117081.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251152.99.clone.1, %broadcast.244415.6016) + %or.116601.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110820.31.clone.1, %shift-right-logical.117081.29.clone.1) + %xor.123172.27.clone.1 = u32[1280,1280]{1,0} xor(%add.251153.35.clone.1, %or.116601.29.clone.1) + %add.251154.5.clone.1 = u32[1280,1280]{1,0} add(%add.251153.35.clone.1, %xor.123172.27.clone.1) + %shift-left.110821.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123172.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117082.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123172.27.clone.1, %broadcast.244417.5760) + %or.116602.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110821.9.clone.1, %shift-right-logical.117082.9.clone.1) + %xor.123173.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251154.5.clone.1, %or.116602.7.clone.1) + %add.251155.3.clone.1 = u32[1280,1280]{1,0} add(%add.251154.5.clone.1, %xor.123173.5.clone.1) + %shift-left.110822.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123173.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117084.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123173.5.clone.1, %broadcast.244419.4352) + %or.116603.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110822.5.clone.1, %shift-right-logical.117084.5.clone.1) + %xor.123174.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251155.3.clone.1, %or.116603.3.clone.1) + %add.251156.3.clone.1 = u32[1280,1280]{1,0} add(%add.251155.3.clone.1, %xor.123174.3.clone.1) + %add.251157.7.clone.1 = u32[1280,1280]{1,0} add(%add.251156.3.clone.1, %broadcast.256826.113.clone.1) + %shift-left.110823.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123174.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117085.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123174.3.clone.1, %broadcast.244418.4352) + %or.116604.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110823.5.clone.1, %shift-right-logical.117085.5.clone.1) + %xor.123175.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251156.3.clone.1, %or.116604.3.clone.1) + %constant_218544_1_clone_1 = u32[] constant(1366936567) + %broadcast.256836.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218544_1_clone_1), dimensions={} + %add.251158.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123175.3.clone.1, %broadcast.256836.5.clone.1) + %add.251159.5.clone.1 = u32[1280,1280]{1,0} add(%add.251157.7.clone.1, %add.251158.5.clone.1) + %shift-left.110824.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251158.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117086.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251158.5.clone.1, %broadcast.244416.5760) + %or.116605.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110824.9.clone.1, %shift-right-logical.117086.9.clone.1) + %xor.123176.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251159.5.clone.1, %or.116605.7.clone.1) + %add.251160.3.clone.1 = u32[1280,1280]{1,0} 
add(%add.251159.5.clone.1, %xor.123176.5.clone.1) + %shift-left.110825.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123176.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117087.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123176.5.clone.1, %broadcast.244429.2304) + %or.116606.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110825.9.clone.1, %shift-right-logical.117087.9.clone.1) + %xor.123178.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251160.3.clone.1, %or.116606.7.clone.1) + %add.251161.3.clone.1 = u32[1280,1280]{1,0} add(%add.251160.3.clone.1, %xor.123178.5.clone.1) + %shift-left.110826.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123178.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117089.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123178.5.clone.1, %broadcast.244430.4608) + %or.116607.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110826.9.clone.1, %shift-right-logical.117089.9.clone.1) + %xor.123179.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251161.3.clone.1, %or.116607.7.clone.1) + %add.251162.3.clone.1 = u32[1280,1280]{1,0} add(%add.251161.3.clone.1, %xor.123179.5.clone.1) + %constant_186822_1_clone_1 = u32[] constant(1366936566) + %broadcast.256845.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_186822_1_clone_1), dimensions={} + %add.251163.7.clone.1 = u32[1280,1280]{1,0} add(%add.251162.3.clone.1, %broadcast.256845.24.clone.1) + %shift-left.110827.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123179.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117090.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123179.5.clone.1, %broadcast.244434.2816) + %or.116608.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110827.11.clone.1, %shift-right-logical.117090.11.clone.1) + %xor.123180.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251162.3.clone.1, %or.116608.9.clone.1) + %constant_218545_1_clone_1 = u32[] constant(3048909041) + %broadcast.256848.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218545_1_clone_1), dimensions={} + %add.251164.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123180.7.clone.1, %broadcast.256848.5.clone.1) + %add.251165.5.clone.1 = u32[1280,1280]{1,0} add(%add.251163.7.clone.1, %add.251164.5.clone.1) + %shift-left.110828.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251164.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117091.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251164.5.clone.1, %broadcast.244415.6016) + %or.116609.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110828.9.clone.1, %shift-right-logical.117091.9.clone.1) + %xor.123181.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251165.5.clone.1, %or.116609.7.clone.1) + %add.251166.3.clone.1 = u32[1280,1280]{1,0} add(%add.251165.5.clone.1, %xor.123181.5.clone.1) + %shift-left.110829.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123181.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117092.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123181.5.clone.1, %broadcast.244417.5760) + %or.116610.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110829.9.clone.1, %shift-right-logical.117092.9.clone.1) + %xor.123182.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251166.3.clone.1, %or.116610.7.clone.1) + %add.251168.3.clone.1 = u32[1280,1280]{1,0} add(%add.251166.3.clone.1, %xor.123182.5.clone.1) + %shift-left.110830.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123182.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117094.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123182.5.clone.1, 
%broadcast.244419.4352) + %or.116611.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110830.7.clone.1, %shift-right-logical.117094.7.clone.1) + %xor.123183.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251168.3.clone.1, %or.116611.5.clone.1) + %add.251169.3.clone.1 = u32[1280,1280]{1,0} add(%add.251168.3.clone.1, %xor.123183.3.clone.1) + %add.251170.7.clone.1 = u32[1280,1280]{1,0} add(%add.251169.3.clone.1, %broadcast.256825.44.clone.1) + %shift-left.110831.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123183.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117095.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123183.3.clone.1, %broadcast.244418.4352) + %or.116612.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110831.7.clone.1, %shift-right-logical.117095.7.clone.1) + %xor.123184.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251169.3.clone.1, %or.116612.5.clone.1) + %constant_218546_1_clone_1 = u32[] constant(4279400646) + %broadcast.256858.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218546_1_clone_1), dimensions={} + %add.251171.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123184.3.clone.1, %broadcast.256858.5.clone.1) + %add.251172.5.clone.1 = u32[1280,1280]{1,0} add(%add.251170.7.clone.1, %add.251171.5.clone.1) + %shift-left.110832.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251171.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117096.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251171.5.clone.1, %broadcast.244416.5760) + %or.116613.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110832.9.clone.1, %shift-right-logical.117096.9.clone.1) + %xor.123185.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251172.5.clone.1, %or.116613.7.clone.1) + %add.251173.3.clone.1 = u32[1280,1280]{1,0} add(%add.251172.5.clone.1, %xor.123185.5.clone.1) + %shift-left.110833.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123185.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117097.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123185.5.clone.1, %broadcast.244429.2304) + %or.116614.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110833.9.clone.1, %shift-right-logical.117097.9.clone.1) + %xor.123186.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251173.3.clone.1, %or.116614.7.clone.1) + %add.251174.3.clone.1 = u32[1280,1280]{1,0} add(%add.251173.3.clone.1, %xor.123186.5.clone.1) + %shift-left.110834.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123186.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117098.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123186.5.clone.1, %broadcast.244430.4608) + %or.116615.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110834.9.clone.1, %shift-right-logical.117098.9.clone.1) + %xor.123187.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251174.3.clone.1, %or.116615.7.clone.1) + %add.251175.3.clone.1 = u32[1280,1280]{1,0} add(%add.251174.3.clone.1, %xor.123187.5.clone.1) + %add.251176.7.clone.1 = u32[1280,1280]{1,0} add(%add.251175.3.clone.1, %broadcast.256826.113.clone.1) + %shift-left.110835.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123187.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117099.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123187.5.clone.1, %broadcast.244434.2816) + %or.116616.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110835.11.clone.1, %shift-right-logical.117099.11.clone.1) + %xor.123188.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251175.3.clone.1, %or.116616.9.clone.1) + %constant_218547_1_clone_1 = u32[] constant(1366936570) + %broadcast.256870.5.clone.1 = 
u32[1280,1280]{1,0} broadcast(%constant_218547_1_clone_1), dimensions={} + %add.251177.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123188.7.clone.1, %broadcast.256870.5.clone.1) + %add.251178.5.clone.1 = u32[1280,1280]{1,0} add(%add.251176.7.clone.1, %add.251177.5.clone.1) + %shift-left.110836.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251177.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117100.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251177.5.clone.1, %broadcast.244415.6016) + %or.116617.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110836.9.clone.1, %shift-right-logical.117100.9.clone.1) + %xor.123189.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251178.5.clone.1, %or.116617.7.clone.1) + %add.251179.3.clone.1 = u32[1280,1280]{1,0} add(%add.251178.5.clone.1, %xor.123189.5.clone.1) + %shift-left.110837.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123189.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117101.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123189.5.clone.1, %broadcast.244417.5760) + %or.116618.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110837.9.clone.1, %shift-right-logical.117101.9.clone.1) + %xor.123190.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251179.3.clone.1, %or.116618.7.clone.1) + %add.251180.3.clone.1 = u32[1280,1280]{1,0} add(%add.251179.3.clone.1, %xor.123190.5.clone.1) + %shift-left.110838.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123190.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117102.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123190.5.clone.1, %broadcast.244419.4352) + %or.116619.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110838.5.clone.1, %shift-right-logical.117102.5.clone.1) + %xor.123191.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251180.3.clone.1, %or.116619.3.clone.1) + %add.251181.3.clone.1 = u32[1280,1280]{1,0} add(%add.251180.3.clone.1, %xor.123191.3.clone.1) + %add.251182.17.clone.1 = u32[1280,1280]{1,0} add(%add.251181.3.clone.1, %broadcast.256845.24.clone.1) + %shift-left.110839.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123191.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117104.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123191.3.clone.1, %broadcast.244418.4352) + %or.116620.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110839.5.clone.1, %shift-right-logical.117104.5.clone.1) + %xor.123192.15.clone.1 = u32[1280,1280]{1,0} xor(%add.251181.3.clone.1, %or.116620.3.clone.1) + %constant_218548_1_clone_1 = u32[] constant(3048909044) + %broadcast.256880.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218548_1_clone_1), dimensions={} + %add.251183.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123192.15.clone.1, %broadcast.256880.19.clone.1) + %xor.123193.17.clone.1 = u32[1280,1280]{1,0} xor(%add.251182.17.clone.1, %add.251183.19.clone.1) + %shift-right-logical.117105.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123193.17.clone.1, %broadcast.244468.1920) + %or.116621.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117105.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5810.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116621.13.clone.1) + %add.251184.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5810.11.clone.1, %broadcast.244470.1152) + %multiply.27049.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251184.9.clone.1, %broadcast.244471.896) + %add.251185.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27049.7.clone.1, %broadcast.244408.1024) + %maximum.3742.3.clone.1 = f32[1280,1280]{1,0} 
maximum(%broadcast.244408.1024, %add.251185.5.clone.1) + %abs.1578.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3742.3.clone.1) + %compare.7318.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1578.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27050.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3742.3.clone.1, %broadcast.244476.1152) + %negate.4661.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3742.3.clone.1) + %multiply.27052.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3742.3.clone.1, %negate.4661.5.clone.1) + %log-plus-one.1578.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27052.5.clone.1) + %negate.4662.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1578.3.clone.1) + %compare.7319.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4662.4.clone.1, %broadcast.244477.384), direction=LT + %select.21444.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21445.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21446.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21447.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21448.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21449.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21450.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21451.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21452.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.251186.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4662.4.clone.1, %broadcast.244496.640) + %sqrt.1578.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4662.4.clone.1) + %add.251187.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1578.5.clone.1, %broadcast.244498.640) + %select.21453.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7319.3.clone.1, %add.251186.5.clone.1, %add.251187.5.clone.1) + %multiply.27054.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21452.3.clone.1, %select.21453.3.clone.1) + %add.251188.1.clone.1 = f32[1280,1280]{1,0} add(%select.21451.3.clone.1, %multiply.27054.1.clone.1) + %multiply.27055.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251188.1.clone.1, %select.21453.3.clone.1) + %add.251189.1.clone.1 = f32[1280,1280]{1,0} add(%select.21450.3.clone.1, %multiply.27055.1.clone.1) + %multiply.27056.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251189.1.clone.1, %select.21453.3.clone.1) + %add.251190.1.clone.1 = f32[1280,1280]{1,0} add(%select.21449.3.clone.1, %multiply.27056.1.clone.1) + %multiply.27057.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251190.1.clone.1, %select.21453.3.clone.1) + %add.251191.1.clone.1 = f32[1280,1280]{1,0} add(%select.21448.3.clone.1, %multiply.27057.1.clone.1) + %multiply.27058.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251191.1.clone.1, %select.21453.3.clone.1) + %add.251192.3.clone.1 = f32[1280,1280]{1,0} add(%select.21447.5.clone.1, %multiply.27058.1.clone.1) + %multiply.27059.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251192.3.clone.1, %select.21453.3.clone.1) + %add.251193.3.clone.1 = 
f32[1280,1280]{1,0} add(%select.21446.5.clone.1, %multiply.27059.1.clone.1) + %multiply.27060.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251193.3.clone.1, %select.21453.3.clone.1) + %add.251194.9.clone.1 = f32[1280,1280]{1,0} add(%select.21445.11.clone.1, %multiply.27060.7.clone.1) + %multiply.27061.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251194.9.clone.1, %select.21453.3.clone.1) + %add.251195.7.clone.1 = f32[1280,1280]{1,0} add(%select.21444.7.clone.1, %multiply.27061.7.clone.1) + %multiply.27062.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251195.7.clone.1, %maximum.3742.3.clone.1) + %select.21454.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7318.3.clone.1, %multiply.27050.9.clone.1, %multiply.27062.7.clone.1) + %multiply.27063.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21454.7.clone.1, %broadcast.244500.640) + %clamp.1222.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27063.5.clone.1, %broadcast.244501.384) + %multiply.27064.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1222.3.clone.1, %broadcast.244502.1) + %constant_176597_1_clone_1 = u32[] constant(3351390897) + %broadcast.252388.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_176597_1_clone_1), dimensions={} + %add.248642.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.252388.44.clone.1) + %constant_176604_1_clone_1 = u32[] constant(3630472852) + %broadcast.252389.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_176604_1_clone_1), dimensions={} + %add.248643.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.252389.113.clone.1) + %add.248644.35.clone.1 = u32[1280,1280]{1,0} add(%add.248642.37.clone.1, %add.248643.99.clone.1) + %shift-left.109720.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248643.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115919.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248643.99.clone.1, %broadcast.244415.6016) + %or.115441.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109720.31.clone.1, %shift-right-logical.115919.29.clone.1) + %xor.121988.27.clone.1 = u32[1280,1280]{1,0} xor(%add.248644.35.clone.1, %or.115441.29.clone.1) + %add.248645.5.clone.1 = u32[1280,1280]{1,0} add(%add.248644.35.clone.1, %xor.121988.27.clone.1) + %shift-left.109721.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121988.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115920.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121988.27.clone.1, %broadcast.244417.5760) + %or.115442.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109721.9.clone.1, %shift-right-logical.115920.9.clone.1) + %xor.121989.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248645.5.clone.1, %or.115442.7.clone.1) + %add.248646.3.clone.1 = u32[1280,1280]{1,0} add(%add.248645.5.clone.1, %xor.121989.5.clone.1) + %shift-left.109722.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121989.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115921.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121989.5.clone.1, %broadcast.244419.4352) + %or.115443.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109722.5.clone.1, %shift-right-logical.115921.5.clone.1) + %xor.121990.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248646.3.clone.1, %or.115443.3.clone.1) + %add.248647.3.clone.1 = u32[1280,1280]{1,0} add(%add.248646.3.clone.1, %xor.121990.3.clone.1) + %add.248648.7.clone.1 = u32[1280,1280]{1,0} add(%add.248647.3.clone.1, %broadcast.252389.113.clone.1) + %shift-left.109723.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121990.3.clone.1, 
%broadcast.244419.4352) + %shift-right-logical.115922.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121990.3.clone.1, %broadcast.244418.4352) + %or.115444.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109723.5.clone.1, %shift-right-logical.115922.5.clone.1) + %xor.121991.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248647.3.clone.1, %or.115444.3.clone.1) + %constant_218272_1_clone_1 = u32[] constant(74952704) + %broadcast.252399.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218272_1_clone_1), dimensions={} + %add.248649.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121991.3.clone.1, %broadcast.252399.5.clone.1) + %add.248650.5.clone.1 = u32[1280,1280]{1,0} add(%add.248648.7.clone.1, %add.248649.5.clone.1) + %shift-left.109724.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248649.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115923.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248649.5.clone.1, %broadcast.244416.5760) + %or.115445.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109724.9.clone.1, %shift-right-logical.115923.9.clone.1) + %xor.121992.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248650.5.clone.1, %or.115445.7.clone.1) + %add.248651.3.clone.1 = u32[1280,1280]{1,0} add(%add.248650.5.clone.1, %xor.121992.5.clone.1) + %shift-left.109725.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121992.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115924.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121992.5.clone.1, %broadcast.244429.2304) + %or.115446.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109725.9.clone.1, %shift-right-logical.115924.9.clone.1) + %xor.121993.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248651.3.clone.1, %or.115446.7.clone.1) + %add.248652.3.clone.1 = u32[1280,1280]{1,0} add(%add.248651.3.clone.1, %xor.121993.5.clone.1) + %shift-left.109726.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121993.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115925.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121993.5.clone.1, %broadcast.244430.4608) + %or.115447.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109726.9.clone.1, %shift-right-logical.115925.9.clone.1) + %xor.121994.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248652.3.clone.1, %or.115447.7.clone.1) + %add.248653.3.clone.1 = u32[1280,1280]{1,0} add(%add.248652.3.clone.1, %xor.121994.5.clone.1) + %constant_176606_1_clone_1 = u32[] constant(74952703) + %broadcast.252406.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_176606_1_clone_1), dimensions={} + %add.248654.7.clone.1 = u32[1280,1280]{1,0} add(%add.248653.3.clone.1, %broadcast.252406.24.clone.1) + %shift-left.109727.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121994.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115926.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121994.5.clone.1, %broadcast.244434.2816) + %or.115448.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109727.11.clone.1, %shift-right-logical.115926.11.clone.1) + %xor.121995.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248653.3.clone.1, %or.115448.9.clone.1) + %constant_218273_1_clone_1 = u32[] constant(3351390899) + %broadcast.252409.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218273_1_clone_1), dimensions={} + %add.248655.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121995.7.clone.1, %broadcast.252409.5.clone.1) + %add.248657.5.clone.1 = u32[1280,1280]{1,0} add(%add.248654.7.clone.1, %add.248655.5.clone.1) + %shift-left.109728.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248655.5.clone.1, 
%broadcast.244414.6272) + %shift-right-logical.115927.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248655.5.clone.1, %broadcast.244415.6016) + %or.115449.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109728.9.clone.1, %shift-right-logical.115927.9.clone.1) + %xor.121996.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248657.5.clone.1, %or.115449.7.clone.1) + %add.248658.3.clone.1 = u32[1280,1280]{1,0} add(%add.248657.5.clone.1, %xor.121996.5.clone.1) + %shift-left.109729.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121996.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115928.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121996.5.clone.1, %broadcast.244417.5760) + %or.115450.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109729.9.clone.1, %shift-right-logical.115928.9.clone.1) + %xor.121997.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248658.3.clone.1, %or.115450.7.clone.1) + %add.248659.3.clone.1 = u32[1280,1280]{1,0} add(%add.248658.3.clone.1, %xor.121997.5.clone.1) + %shift-left.109730.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121997.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115929.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121997.5.clone.1, %broadcast.244419.4352) + %or.115451.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109730.7.clone.1, %shift-right-logical.115929.7.clone.1) + %xor.121998.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248659.3.clone.1, %or.115451.5.clone.1) + %add.248660.3.clone.1 = u32[1280,1280]{1,0} add(%add.248659.3.clone.1, %xor.121998.3.clone.1) + %add.248661.7.clone.1 = u32[1280,1280]{1,0} add(%add.248660.3.clone.1, %broadcast.252388.44.clone.1) + %shift-left.109731.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121998.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115930.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121998.3.clone.1, %broadcast.244418.4352) + %or.115452.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109731.7.clone.1, %shift-right-logical.115930.7.clone.1) + %xor.121999.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248660.3.clone.1, %or.115452.5.clone.1) + %constant_218274_1_clone_1 = u32[] constant(3630472855) + %broadcast.252419.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218274_1_clone_1), dimensions={} + %add.248662.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121999.3.clone.1, %broadcast.252419.5.clone.1) + %add.248663.5.clone.1 = u32[1280,1280]{1,0} add(%add.248661.7.clone.1, %add.248662.5.clone.1) + %shift-left.109732.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248662.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115931.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248662.5.clone.1, %broadcast.244416.5760) + %or.115453.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109732.9.clone.1, %shift-right-logical.115931.9.clone.1) + %xor.122000.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248663.5.clone.1, %or.115453.7.clone.1) + %add.248664.3.clone.1 = u32[1280,1280]{1,0} add(%add.248663.5.clone.1, %xor.122000.5.clone.1) + %shift-left.109733.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122000.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115932.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122000.5.clone.1, %broadcast.244429.2304) + %or.115454.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109733.9.clone.1, %shift-right-logical.115932.9.clone.1) + %xor.122001.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248664.3.clone.1, %or.115454.7.clone.1) + %add.248665.3.clone.1 = u32[1280,1280]{1,0} add(%add.248664.3.clone.1, 
%xor.122001.5.clone.1) + %shift-left.109734.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122001.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115933.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122001.5.clone.1, %broadcast.244430.4608) + %or.115455.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109734.9.clone.1, %shift-right-logical.115933.9.clone.1) + %xor.122002.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248665.3.clone.1, %or.115455.7.clone.1) + %add.248666.3.clone.1 = u32[1280,1280]{1,0} add(%add.248665.3.clone.1, %xor.122002.5.clone.1) + %add.248667.7.clone.1 = u32[1280,1280]{1,0} add(%add.248666.3.clone.1, %broadcast.252389.113.clone.1) + %shift-left.109735.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122002.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115934.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122002.5.clone.1, %broadcast.244434.2816) + %or.115456.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109735.11.clone.1, %shift-right-logical.115934.11.clone.1) + %xor.122003.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248666.3.clone.1, %or.115456.9.clone.1) + %constant_218275_1_clone_1 = u32[] constant(74952707) + %broadcast.252429.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218275_1_clone_1), dimensions={} + %add.248668.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122003.7.clone.1, %broadcast.252429.5.clone.1) + %add.248669.5.clone.1 = u32[1280,1280]{1,0} add(%add.248667.7.clone.1, %add.248668.5.clone.1) + %shift-left.109736.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248668.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115935.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248668.5.clone.1, %broadcast.244415.6016) + %or.115457.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109736.9.clone.1, %shift-right-logical.115935.9.clone.1) + %xor.122004.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248669.5.clone.1, %or.115457.7.clone.1) + %add.248670.3.clone.1 = u32[1280,1280]{1,0} add(%add.248669.5.clone.1, %xor.122004.5.clone.1) + %shift-left.109737.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122004.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115936.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122004.5.clone.1, %broadcast.244417.5760) + %or.115458.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109737.9.clone.1, %shift-right-logical.115936.9.clone.1) + %xor.122005.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248670.3.clone.1, %or.115458.7.clone.1) + %add.248671.3.clone.1 = u32[1280,1280]{1,0} add(%add.248670.3.clone.1, %xor.122005.5.clone.1) + %shift-left.109738.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122005.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115937.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122005.5.clone.1, %broadcast.244419.4352) + %or.115459.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109738.5.clone.1, %shift-right-logical.115937.5.clone.1) + %xor.122006.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248671.3.clone.1, %or.115459.3.clone.1) + %add.248672.3.clone.1 = u32[1280,1280]{1,0} add(%add.248671.3.clone.1, %xor.122006.3.clone.1) + %add.248673.17.clone.1 = u32[1280,1280]{1,0} add(%add.248672.3.clone.1, %broadcast.252406.24.clone.1) + %shift-left.109739.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122006.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115938.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122006.3.clone.1, %broadcast.244418.4352) + %or.115460.3.clone.1 = u32[1280,1280]{1,0} 
or(%shift-left.109739.5.clone.1, %shift-right-logical.115938.5.clone.1) + %xor.122007.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248672.3.clone.1, %or.115460.3.clone.1) + %constant_218276_1_clone_1 = u32[] constant(3351390902) + %broadcast.252439.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218276_1_clone_1), dimensions={} + %add.248674.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122007.15.clone.1, %broadcast.252439.19.clone.1) + %xor.122008.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248673.17.clone.1, %add.248674.19.clone.1) + %shift-right-logical.115939.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122008.17.clone.1, %broadcast.244468.1920) + %or.115461.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115939.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5760.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115461.13.clone.1) + %add.248675.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5760.11.clone.1, %broadcast.244470.1152) + %multiply.26528.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248675.9.clone.1, %broadcast.244471.896) + %add.248676.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26528.7.clone.1, %broadcast.244408.1024) + %maximum.3692.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248676.5.clone.1) + %abs.1544.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3692.3.clone.1) + %compare.7236.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1544.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26529.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3692.3.clone.1, %broadcast.244476.1152) + %negate.4593.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3692.3.clone.1) + %multiply.26530.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3692.3.clone.1, %negate.4593.5.clone.1) + %log-plus-one.1544.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26530.5.clone.1) + %negate.4594.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1544.3.clone.1) + %compare.7237.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4594.4.clone.1, %broadcast.244477.384), direction=LT + %select.21070.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21071.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21072.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21073.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21074.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21075.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21076.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21077.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21078.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248677.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4594.4.clone.1, %broadcast.244496.640) + %sqrt.1544.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4594.4.clone.1) + %add.248678.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1544.5.clone.1, %broadcast.244498.640) + %select.21079.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7237.3.clone.1, 
%add.248677.5.clone.1, %add.248678.5.clone.1) + %multiply.26531.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21078.3.clone.1, %select.21079.3.clone.1) + %add.248679.1.clone.1 = f32[1280,1280]{1,0} add(%select.21077.3.clone.1, %multiply.26531.1.clone.1) + %multiply.26532.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248679.1.clone.1, %select.21079.3.clone.1) + %add.248680.1.clone.1 = f32[1280,1280]{1,0} add(%select.21076.3.clone.1, %multiply.26532.1.clone.1) + %multiply.26533.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248680.1.clone.1, %select.21079.3.clone.1) + %add.248681.1.clone.1 = f32[1280,1280]{1,0} add(%select.21075.3.clone.1, %multiply.26533.1.clone.1) + %multiply.26534.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248681.1.clone.1, %select.21079.3.clone.1) + %add.248682.1.clone.1 = f32[1280,1280]{1,0} add(%select.21074.3.clone.1, %multiply.26534.1.clone.1) + %multiply.26535.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248682.1.clone.1, %select.21079.3.clone.1) + %add.248683.3.clone.1 = f32[1280,1280]{1,0} add(%select.21073.5.clone.1, %multiply.26535.1.clone.1) + %multiply.26536.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248683.3.clone.1, %select.21079.3.clone.1) + %add.248684.3.clone.1 = f32[1280,1280]{1,0} add(%select.21072.5.clone.1, %multiply.26536.1.clone.1) + %multiply.26537.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248684.3.clone.1, %select.21079.3.clone.1) + %add.248685.9.clone.1 = f32[1280,1280]{1,0} add(%select.21071.11.clone.1, %multiply.26537.7.clone.1) + %multiply.26538.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248685.9.clone.1, %select.21079.3.clone.1) + %add.248686.7.clone.1 = f32[1280,1280]{1,0} add(%select.21070.7.clone.1, %multiply.26538.7.clone.1) + %multiply.26539.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248686.7.clone.1, %maximum.3692.3.clone.1) + %select.21080.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7236.3.clone.1, %multiply.26529.9.clone.1, %multiply.26539.7.clone.1) + %multiply.26540.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21080.7.clone.1, %broadcast.244500.640) + %clamp.1188.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26540.5.clone.1, %broadcast.244501.384) + %multiply.26541.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1188.3.clone.1, %broadcast.244502.1) + %constant_194669_1_clone_1 = u32[] constant(4088501695) + %broadcast.260211.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_194669_1_clone_1), dimensions={} + %add.253106.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.260211.44.clone.1) + %constant_194676_1_clone_1 = u32[] constant(3766148499) + %broadcast.260212.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_194676_1_clone_1), dimensions={} + %add.253107.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.260212.113.clone.1) + %add.253108.35.clone.1 = u32[1280,1280]{1,0} add(%add.253106.37.clone.1, %add.253107.99.clone.1) + %shift-left.111666.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253107.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117960.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253107.99.clone.1, %broadcast.244415.6016) + %or.117499.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111666.31.clone.1, %shift-right-logical.117960.29.clone.1) + %xor.124059.27.clone.1 = u32[1280,1280]{1,0} xor(%add.253108.35.clone.1, %or.117499.29.clone.1) + %add.253109.5.clone.1 = u32[1280,1280]{1,0} add(%add.253108.35.clone.1, %xor.124059.27.clone.1) + %shift-left.111667.9.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.124059.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117961.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124059.27.clone.1, %broadcast.244417.5760) + %or.117500.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111667.9.clone.1, %shift-right-logical.117961.9.clone.1) + %xor.124060.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253109.5.clone.1, %or.117500.7.clone.1) + %add.253110.3.clone.1 = u32[1280,1280]{1,0} add(%add.253109.5.clone.1, %xor.124060.5.clone.1) + %shift-left.111668.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124060.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117962.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124060.5.clone.1, %broadcast.244419.4352) + %or.117502.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111668.5.clone.1, %shift-right-logical.117962.5.clone.1) + %xor.124061.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253110.3.clone.1, %or.117502.3.clone.1) + %add.253111.3.clone.1 = u32[1280,1280]{1,0} add(%add.253110.3.clone.1, %xor.124061.3.clone.1) + %add.253112.7.clone.1 = u32[1280,1280]{1,0} add(%add.253111.3.clone.1, %broadcast.260212.113.clone.1) + %shift-left.111669.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124061.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117963.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124061.3.clone.1, %broadcast.244418.4352) + %or.117503.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111669.5.clone.1, %shift-right-logical.117963.5.clone.1) + %xor.124062.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253111.3.clone.1, %or.117503.3.clone.1) + %constant_218766_1_clone_1 = u32[] constant(135943159) + %broadcast.260224.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218766_1_clone_1), dimensions={} + %add.253114.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124062.3.clone.1, %broadcast.260224.5.clone.1) + %add.253115.5.clone.1 = u32[1280,1280]{1,0} add(%add.253112.7.clone.1, %add.253114.5.clone.1) + %shift-left.111671.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253114.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117964.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253114.5.clone.1, %broadcast.244416.5760) + %or.117504.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111671.9.clone.1, %shift-right-logical.117964.9.clone.1) + %xor.124063.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253115.5.clone.1, %or.117504.7.clone.1) + %add.253116.3.clone.1 = u32[1280,1280]{1,0} add(%add.253115.5.clone.1, %xor.124063.5.clone.1) + %shift-left.111672.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124063.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117965.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124063.5.clone.1, %broadcast.244429.2304) + %or.117505.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111672.9.clone.1, %shift-right-logical.117965.9.clone.1) + %xor.124064.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253116.3.clone.1, %or.117505.7.clone.1) + %add.253117.3.clone.1 = u32[1280,1280]{1,0} add(%add.253116.3.clone.1, %xor.124064.5.clone.1) + %shift-left.111673.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124064.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117966.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124064.5.clone.1, %broadcast.244430.4608) + %or.117507.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111673.9.clone.1, %shift-right-logical.117966.9.clone.1) + %xor.124065.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253117.3.clone.1, %or.117507.7.clone.1) + %add.253118.3.clone.1 = u32[1280,1280]{1,0} 
add(%add.253117.3.clone.1, %xor.124065.5.clone.1) + %constant_194678_1_clone_1 = u32[] constant(135943158) + %broadcast.260231.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_194678_1_clone_1), dimensions={} + %add.253119.7.clone.1 = u32[1280,1280]{1,0} add(%add.253118.3.clone.1, %broadcast.260231.24.clone.1) + %shift-left.111674.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124065.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117967.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124065.5.clone.1, %broadcast.244434.2816) + %or.117508.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111674.11.clone.1, %shift-right-logical.117967.11.clone.1) + %xor.124066.7.clone.1 = u32[1280,1280]{1,0} xor(%add.253118.3.clone.1, %or.117508.9.clone.1) + %constant_218768_1_clone_1 = u32[] constant(4088501697) + %broadcast.260234.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218768_1_clone_1), dimensions={} + %add.253120.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124066.7.clone.1, %broadcast.260234.5.clone.1) + %add.253121.5.clone.1 = u32[1280,1280]{1,0} add(%add.253119.7.clone.1, %add.253120.5.clone.1) + %shift-left.111676.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253120.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117968.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253120.5.clone.1, %broadcast.244415.6016) + %or.117509.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111676.9.clone.1, %shift-right-logical.117968.9.clone.1) + %xor.124067.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253121.5.clone.1, %or.117509.7.clone.1) + %add.253122.3.clone.1 = u32[1280,1280]{1,0} add(%add.253121.5.clone.1, %xor.124067.5.clone.1) + %shift-left.111677.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124067.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117969.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124067.5.clone.1, %broadcast.244417.5760) + %or.117510.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111677.9.clone.1, %shift-right-logical.117969.9.clone.1) + %xor.124068.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253122.3.clone.1, %or.117510.7.clone.1) + %add.253123.3.clone.1 = u32[1280,1280]{1,0} add(%add.253122.3.clone.1, %xor.124068.5.clone.1) + %shift-left.111678.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124068.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117970.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124068.5.clone.1, %broadcast.244419.4352) + %or.117512.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111678.7.clone.1, %shift-right-logical.117970.7.clone.1) + %xor.124069.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253123.3.clone.1, %or.117512.5.clone.1) + %add.253124.3.clone.1 = u32[1280,1280]{1,0} add(%add.253123.3.clone.1, %xor.124069.3.clone.1) + %add.253125.7.clone.1 = u32[1280,1280]{1,0} add(%add.253124.3.clone.1, %broadcast.260211.44.clone.1) + %shift-left.111679.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124069.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117972.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124069.3.clone.1, %broadcast.244418.4352) + %or.117513.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111679.7.clone.1, %shift-right-logical.117972.7.clone.1) + %xor.124070.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253124.3.clone.1, %or.117513.5.clone.1) + %constant_218770_1_clone_1 = u32[] constant(3766148502) + %broadcast.260246.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218770_1_clone_1), dimensions={} + %add.253126.5.clone.1 = u32[1280,1280]{1,0} 
add(%xor.124070.3.clone.1, %broadcast.260246.5.clone.1) + %add.253127.5.clone.1 = u32[1280,1280]{1,0} add(%add.253125.7.clone.1, %add.253126.5.clone.1) + %shift-left.111680.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253126.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117973.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253126.5.clone.1, %broadcast.244416.5760) + %or.117514.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111680.9.clone.1, %shift-right-logical.117973.9.clone.1) + %xor.124071.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253127.5.clone.1, %or.117514.7.clone.1) + %add.253128.3.clone.1 = u32[1280,1280]{1,0} add(%add.253127.5.clone.1, %xor.124071.5.clone.1) + %shift-left.111681.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124071.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117974.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124071.5.clone.1, %broadcast.244429.2304) + %or.117515.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111681.9.clone.1, %shift-right-logical.117974.9.clone.1) + %xor.124072.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253128.3.clone.1, %or.117515.7.clone.1) + %add.253129.3.clone.1 = u32[1280,1280]{1,0} add(%add.253128.3.clone.1, %xor.124072.5.clone.1) + %shift-left.111682.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124072.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117975.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124072.5.clone.1, %broadcast.244430.4608) + %or.117516.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111682.9.clone.1, %shift-right-logical.117975.9.clone.1) + %xor.124073.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253129.3.clone.1, %or.117516.7.clone.1) + %add.253130.3.clone.1 = u32[1280,1280]{1,0} add(%add.253129.3.clone.1, %xor.124073.5.clone.1) + %add.253131.7.clone.1 = u32[1280,1280]{1,0} add(%add.253130.3.clone.1, %broadcast.260212.113.clone.1) + %shift-left.111683.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124073.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117977.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124073.5.clone.1, %broadcast.244434.2816) + %or.117517.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111683.11.clone.1, %shift-right-logical.117977.11.clone.1) + %xor.124074.7.clone.1 = u32[1280,1280]{1,0} xor(%add.253130.3.clone.1, %or.117517.9.clone.1) + %constant_218772_1_clone_1 = u32[] constant(135943162) + %broadcast.260261.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218772_1_clone_1), dimensions={} + %add.253132.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124074.7.clone.1, %broadcast.260261.5.clone.1) + %add.253133.5.clone.1 = u32[1280,1280]{1,0} add(%add.253131.7.clone.1, %add.253132.5.clone.1) + %shift-left.111684.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253132.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117978.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253132.5.clone.1, %broadcast.244415.6016) + %or.117518.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111684.9.clone.1, %shift-right-logical.117978.9.clone.1) + %xor.124075.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253133.5.clone.1, %or.117518.7.clone.1) + %add.253134.3.clone.1 = u32[1280,1280]{1,0} add(%add.253133.5.clone.1, %xor.124075.5.clone.1) + %shift-left.111686.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124075.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117979.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124075.5.clone.1, %broadcast.244417.5760) + %or.117519.7.clone.1 = u32[1280,1280]{1,0} 
or(%shift-left.111686.9.clone.1, %shift-right-logical.117979.9.clone.1) + %xor.124076.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253134.3.clone.1, %or.117519.7.clone.1) + %add.253135.3.clone.1 = u32[1280,1280]{1,0} add(%add.253134.3.clone.1, %xor.124076.5.clone.1) + %shift-left.111687.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124076.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117980.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124076.5.clone.1, %broadcast.244419.4352) + %or.117520.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111687.5.clone.1, %shift-right-logical.117980.5.clone.1) + %xor.124077.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253135.3.clone.1, %or.117520.3.clone.1) + %add.253136.3.clone.1 = u32[1280,1280]{1,0} add(%add.253135.3.clone.1, %xor.124077.3.clone.1) + %add.253137.17.clone.1 = u32[1280,1280]{1,0} add(%add.253136.3.clone.1, %broadcast.260231.24.clone.1) + %shift-left.111688.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124077.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117982.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124077.3.clone.1, %broadcast.244418.4352) + %or.117522.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111688.5.clone.1, %shift-right-logical.117982.5.clone.1) + %xor.124078.15.clone.1 = u32[1280,1280]{1,0} xor(%add.253136.3.clone.1, %or.117522.3.clone.1) + %constant_218774_1_clone_1 = u32[] constant(4088501700) + %broadcast.260273.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218774_1_clone_1), dimensions={} + %add.253138.19.clone.1 = u32[1280,1280]{1,0} add(%xor.124078.15.clone.1, %broadcast.260273.19.clone.1) + %xor.124079.17.clone.1 = u32[1280,1280]{1,0} xor(%add.253137.17.clone.1, %add.253138.19.clone.1) + %shift-right-logical.117983.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124079.17.clone.1, %broadcast.244468.1920) + %or.117523.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117983.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5849.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117523.13.clone.1) + %add.253139.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5849.11.clone.1, %broadcast.244470.1152) + %multiply.27430.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253139.9.clone.1, %broadcast.244471.896) + %add.253140.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27430.7.clone.1, %broadcast.244408.1024) + %maximum.3781.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.253140.5.clone.1) + %abs.1603.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3781.3.clone.1) + %compare.7368.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1603.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27431.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3781.3.clone.1, %broadcast.244476.1152) + %negate.4711.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3781.3.clone.1) + %multiply.27432.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3781.3.clone.1, %negate.4711.5.clone.1) + %log-plus-one.1603.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27432.5.clone.1) + %negate.4712.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1603.3.clone.1) + %compare.7369.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4712.4.clone.1, %broadcast.244477.384), direction=LT + %select.21740.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21741.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21742.5.clone.1 = 
f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21743.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21744.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21745.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21746.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21747.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21748.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.253141.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4712.4.clone.1, %broadcast.244496.640) + %sqrt.1603.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4712.4.clone.1) + %add.253142.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1603.5.clone.1, %broadcast.244498.640) + %select.21749.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7369.3.clone.1, %add.253141.5.clone.1, %add.253142.5.clone.1) + %multiply.27433.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21748.3.clone.1, %select.21749.3.clone.1) + %add.253143.1.clone.1 = f32[1280,1280]{1,0} add(%select.21747.3.clone.1, %multiply.27433.1.clone.1) + %multiply.27434.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253143.1.clone.1, %select.21749.3.clone.1) + %add.253144.1.clone.1 = f32[1280,1280]{1,0} add(%select.21746.3.clone.1, %multiply.27434.1.clone.1) + %multiply.27435.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253144.1.clone.1, %select.21749.3.clone.1) + %add.253145.1.clone.1 = f32[1280,1280]{1,0} add(%select.21745.3.clone.1, %multiply.27435.1.clone.1) + %multiply.27436.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253145.1.clone.1, %select.21749.3.clone.1) + %add.253146.1.clone.1 = f32[1280,1280]{1,0} add(%select.21744.3.clone.1, %multiply.27436.1.clone.1) + %multiply.27437.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253146.1.clone.1, %select.21749.3.clone.1) + %add.253147.3.clone.1 = f32[1280,1280]{1,0} add(%select.21743.5.clone.1, %multiply.27437.1.clone.1) + %multiply.27438.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253147.3.clone.1, %select.21749.3.clone.1) + %add.253148.3.clone.1 = f32[1280,1280]{1,0} add(%select.21742.5.clone.1, %multiply.27438.1.clone.1) + %multiply.27439.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253148.3.clone.1, %select.21749.3.clone.1) + %add.253149.9.clone.1 = f32[1280,1280]{1,0} add(%select.21741.11.clone.1, %multiply.27439.7.clone.1) + %multiply.27440.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253149.9.clone.1, %select.21749.3.clone.1) + %add.253150.7.clone.1 = f32[1280,1280]{1,0} add(%select.21740.7.clone.1, %multiply.27440.7.clone.1) + %multiply.27441.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253150.7.clone.1, %maximum.3781.3.clone.1) + %select.21750.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7368.3.clone.1, %multiply.27431.9.clone.1, %multiply.27441.7.clone.1) + %multiply.27442.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21750.7.clone.1, %broadcast.244500.640) + %clamp.1247.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27442.5.clone.1, %broadcast.244501.384) + %multiply.27443.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1247.3.clone.1, %broadcast.244502.1) + %constant_176386_1_clone_1 = u32[] constant(866984013) 
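
Each repeated block in this hunk ends the same way: the RNG output is shifted right, OR-ed with an exponent mask, bitcast-converted to f32 to obtain a uniform value, and pushed through a two-branch polynomial (the select chains keyed on the `direction=LT` compare) before the closing clamp and multiply. That shape is the standard uniform-to-Gaussian transform via the inverse error function: with w = -log1p(-x^2), one 9-coefficient polynomial in w - 2.5 is used when w < 5 and another in sqrt(w) - 3 otherwise, the `direction=EQ` compare on abs(x) handles the |x| = 1 pole separately, and the result is scaled by sqrt(2). Below is a minimal NumPy sketch of that tail. The shift count, exponent mask, affine constants, and coefficient tables are assumptions taken from the widely used single-precision erfinv fit; the dump only references them through broadcasts (%broadcast.244468..., %broadcast.244502...) defined outside this excerpt, so none of these values are read from the hunk itself.

import numpy as np

def bits_to_unit_interval(bits):
    # bits: np.uint32 array from the RNG rounds above. Keep the top 23
    # bits as a mantissa and OR in an exponent of zero (0x3F800000,
    # assumed), giving floats in [1, 2); subtracting 1 yields [0, 1).
    f = ((bits >> np.uint32(9)) | np.uint32(0x3F800000)).view(np.float32)
    return f - np.float32(1.0)

def erfinv_f32(x):
    # Two-branch polynomial approximation of erfinv, mirroring the
    # select/multiply/add (Horner) chain in the dump; coefficients are
    # the standard single-precision fit, assumed rather than extracted.
    w = -np.log1p(-x * x)
    small = w < 5.0
    w = np.where(small, w - 2.5, np.sqrt(w) - 3.0)
    cs = [2.81022636e-08, 3.43273939e-07, -3.5233877e-06, -4.39150654e-06,
          2.1858087e-04, -1.25372503e-03, -4.17768164e-03, 2.46640727e-01,
          1.50140941e+00]
    cb = [-2.00214257e-04, 1.00950558e-04, 1.34934322e-03, -3.67342844e-03,
          5.73950773e-03, -7.6224613e-03, 9.43887047e-03, 1.00167406e+00,
          2.83297682e+00]
    p = np.where(small, cs[0], cb[0])
    for a, b in zip(cs[1:], cb[1:]):
        p = p * w + np.where(small, a, b)
    return p * x  # the dump likewise multiplies by x at the end

def normal_from_bits(bits):
    u = bits_to_unit_interval(bits)      # uniform in [0, 1)
    x = 2.0 * u - 1.0                    # uniform in [-1, 1)
    # The dump additionally clamps x away from the -1 endpoint (the
    # maximum/clamp steps), omitted here for brevity.
    return np.sqrt(2.0) * erfinv_f32(x)  # sqrt(2)*erfinv(x): inverse normal CDF

The final multiply by %broadcast.244502.1 in each block would then be a per-use scale applied to the resulting standard-normal samples.
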
+ %broadcast.252287.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_176386_1_clone_1), dimensions={} + %add.248592.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.252287.44.clone.1) + %constant_176393_1_clone_1 = u32[] constant(280297754) + %broadcast.252289.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_176393_1_clone_1), dimensions={} + %add.248593.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.252289.113.clone.1) + %add.248595.35.clone.1 = u32[1280,1280]{1,0} add(%add.248592.37.clone.1, %add.248593.99.clone.1) + %shift-left.109700.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248593.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115898.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248593.99.clone.1, %broadcast.244415.6016) + %or.115420.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109700.31.clone.1, %shift-right-logical.115898.29.clone.1) + %xor.121967.27.clone.1 = u32[1280,1280]{1,0} xor(%add.248595.35.clone.1, %or.115420.29.clone.1) + %add.248596.5.clone.1 = u32[1280,1280]{1,0} add(%add.248595.35.clone.1, %xor.121967.27.clone.1) + %shift-left.109701.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121967.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115899.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121967.27.clone.1, %broadcast.244417.5760) + %or.115421.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109701.9.clone.1, %shift-right-logical.115899.9.clone.1) + %xor.121968.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248596.5.clone.1, %or.115421.7.clone.1) + %add.248597.3.clone.1 = u32[1280,1280]{1,0} add(%add.248596.5.clone.1, %xor.121968.5.clone.1) + %shift-left.109702.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121968.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115900.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121968.5.clone.1, %broadcast.244419.4352) + %or.115422.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109702.5.clone.1, %shift-right-logical.115900.5.clone.1) + %xor.121969.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248597.3.clone.1, %or.115422.3.clone.1) + %add.248598.3.clone.1 = u32[1280,1280]{1,0} add(%add.248597.3.clone.1, %xor.121969.3.clone.1) + %add.248600.7.clone.1 = u32[1280,1280]{1,0} add(%add.248598.3.clone.1, %broadcast.252289.113.clone.1) + %shift-left.109703.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121969.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115901.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121969.3.clone.1, %broadcast.244418.4352) + %or.115423.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109703.5.clone.1, %shift-right-logical.115901.5.clone.1) + %xor.121970.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248598.3.clone.1, %or.115423.3.clone.1) + %constant_218267_1_clone_1 = u32[] constant(952711822) + %broadcast.252302.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218267_1_clone_1), dimensions={} + %add.248601.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121970.3.clone.1, %broadcast.252302.5.clone.1) + %add.248602.5.clone.1 = u32[1280,1280]{1,0} add(%add.248600.7.clone.1, %add.248601.5.clone.1) + %shift-left.109704.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248601.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115902.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248601.5.clone.1, %broadcast.244416.5760) + %or.115424.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109704.9.clone.1, %shift-right-logical.115902.9.clone.1) + %xor.121971.5.clone.1 = u32[1280,1280]{1,0} 
xor(%add.248602.5.clone.1, %or.115424.7.clone.1) + %add.248603.3.clone.1 = u32[1280,1280]{1,0} add(%add.248602.5.clone.1, %xor.121971.5.clone.1) + %shift-left.109705.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121971.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115903.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121971.5.clone.1, %broadcast.244429.2304) + %or.115425.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109705.9.clone.1, %shift-right-logical.115903.9.clone.1) + %xor.121972.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248603.3.clone.1, %or.115425.7.clone.1) + %add.248604.3.clone.1 = u32[1280,1280]{1,0} add(%add.248603.3.clone.1, %xor.121972.5.clone.1) + %shift-left.109706.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121972.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115904.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121972.5.clone.1, %broadcast.244430.4608) + %or.115426.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109706.9.clone.1, %shift-right-logical.115904.9.clone.1) + %xor.121973.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248604.3.clone.1, %or.115426.7.clone.1) + %add.248606.3.clone.1 = u32[1280,1280]{1,0} add(%add.248604.3.clone.1, %xor.121973.5.clone.1) + %constant_176395_1_clone_1 = u32[] constant(952711821) + %broadcast.252312.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_176395_1_clone_1), dimensions={} + %add.248609.7.clone.1 = u32[1280,1280]{1,0} add(%add.248606.3.clone.1, %broadcast.252312.24.clone.1) + %shift-left.109707.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121973.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115905.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121973.5.clone.1, %broadcast.244434.2816) + %or.115427.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109707.11.clone.1, %shift-right-logical.115905.11.clone.1) + %xor.121974.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248606.3.clone.1, %or.115427.9.clone.1) + %constant_218268_1_clone_1 = u32[] constant(866984015) + %broadcast.252315.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218268_1_clone_1), dimensions={} + %add.248610.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121974.7.clone.1, %broadcast.252315.5.clone.1) + %add.248611.5.clone.1 = u32[1280,1280]{1,0} add(%add.248609.7.clone.1, %add.248610.5.clone.1) + %shift-left.109708.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248610.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115906.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248610.5.clone.1, %broadcast.244415.6016) + %or.115428.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109708.9.clone.1, %shift-right-logical.115906.9.clone.1) + %xor.121975.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248611.5.clone.1, %or.115428.7.clone.1) + %add.248612.3.clone.1 = u32[1280,1280]{1,0} add(%add.248611.5.clone.1, %xor.121975.5.clone.1) + %shift-left.109709.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121975.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115907.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121975.5.clone.1, %broadcast.244417.5760) + %or.115429.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109709.9.clone.1, %shift-right-logical.115907.9.clone.1) + %xor.121976.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248612.3.clone.1, %or.115429.7.clone.1) + %add.248613.3.clone.1 = u32[1280,1280]{1,0} add(%add.248612.3.clone.1, %xor.121976.5.clone.1) + %shift-left.109710.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121976.5.clone.1, %broadcast.244418.4352) + 
%shift-right-logical.115908.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121976.5.clone.1, %broadcast.244419.4352) + %or.115430.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109710.7.clone.1, %shift-right-logical.115908.7.clone.1) + %xor.121977.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248613.3.clone.1, %or.115430.5.clone.1) + %add.248614.3.clone.1 = u32[1280,1280]{1,0} add(%add.248613.3.clone.1, %xor.121977.3.clone.1) + %add.248615.7.clone.1 = u32[1280,1280]{1,0} add(%add.248614.3.clone.1, %broadcast.252287.44.clone.1) + %shift-left.109711.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121977.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115909.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121977.3.clone.1, %broadcast.244418.4352) + %or.115431.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109711.7.clone.1, %shift-right-logical.115909.7.clone.1) + %xor.121978.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248614.3.clone.1, %or.115431.5.clone.1) + %constant_218269_1_clone_1 = u32[] constant(280297757) + %broadcast.252327.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218269_1_clone_1), dimensions={} + %add.248616.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121978.3.clone.1, %broadcast.252327.5.clone.1) + %add.248617.5.clone.1 = u32[1280,1280]{1,0} add(%add.248615.7.clone.1, %add.248616.5.clone.1) + %shift-left.109712.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248616.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115910.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248616.5.clone.1, %broadcast.244416.5760) + %or.115432.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109712.9.clone.1, %shift-right-logical.115910.9.clone.1) + %xor.121979.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248617.5.clone.1, %or.115432.7.clone.1) + %add.248618.3.clone.1 = u32[1280,1280]{1,0} add(%add.248617.5.clone.1, %xor.121979.5.clone.1) + %shift-left.109713.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121979.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115911.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121979.5.clone.1, %broadcast.244429.2304) + %or.115433.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109713.9.clone.1, %shift-right-logical.115911.9.clone.1) + %xor.121980.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248618.3.clone.1, %or.115433.7.clone.1) + %add.248619.3.clone.1 = u32[1280,1280]{1,0} add(%add.248618.3.clone.1, %xor.121980.5.clone.1) + %shift-left.109714.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121980.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115912.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121980.5.clone.1, %broadcast.244430.4608) + %or.115434.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109714.9.clone.1, %shift-right-logical.115912.9.clone.1) + %xor.121981.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248619.3.clone.1, %or.115434.7.clone.1) + %add.248620.3.clone.1 = u32[1280,1280]{1,0} add(%add.248619.3.clone.1, %xor.121981.5.clone.1) + %add.248621.7.clone.1 = u32[1280,1280]{1,0} add(%add.248620.3.clone.1, %broadcast.252289.113.clone.1) + %shift-left.109715.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121981.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115913.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121981.5.clone.1, %broadcast.244434.2816) + %or.115435.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109715.11.clone.1, %shift-right-logical.115913.11.clone.1) + %xor.121982.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248620.3.clone.1, %or.115435.9.clone.1) 
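
The u32 portions of each block are the rounds of a counter-based PRNG in the Threefry family: the state pair is updated by an add, a fixed-distance rotation (spelled as shift-left / shift-right-logical / or with complementary shift amounts, including the 16/16 pair, which is a rotation by 16), and an xor, with key-derived constants injected every few rounds via the constant/broadcast/add steps. The near-consecutive constants in this hunk (for example 904930612, 904930613, 904930616) are consistent with per-injection key-plus-step offsets. For orientation, here is a reference Threefry-2x32 sketch (after Salmon et al., "Parallel Random Numbers: As Easy as 1, 2, 3", SC'11); it illustrates the round structure only and is not claimed to bit-match this module's key schedule or round count.

M = 0xFFFFFFFF  # work in Python ints, masking to 32 bits

def rotl(v, d):
    # Rotate-left built from the same shl/shr/or idiom as the HLO above.
    return ((v << d) | (v >> (32 - d))) & M

def threefry2x32(key, counter, rounds=20):
    # key, counter: pairs of 32-bit ints. R is the published rotation
    # schedule; ks is the key schedule with the Threefry parity constant.
    R = (13, 15, 26, 6, 17, 29, 16, 24)
    ks = (key[0], key[1], key[0] ^ key[1] ^ 0x1BD11BDA)
    x0 = (counter[0] + ks[0]) & M
    x1 = (counter[1] + ks[1]) & M
    for r in range(rounds):
        x0 = (x0 + x1) & M         # add
        x1 = rotl(x1, R[r % 8])    # fixed-distance rotation
        x1 ^= x0                   # xor
        if r % 4 == 3:             # inject round keys every four rounds
            s = r // 4 + 1
            x0 = (x0 + ks[s % 3]) & M
            x1 = (x1 + ks[(s + 1) % 3] + s) & M
    return x0, x1

Because the function is keyed on a counter, each element of the 1280x1280 tensors here can be generated independently and reproducibly, which is why the dump evaluates the whole round sequence elementwise on broadcast counters rather than threading sequential RNG state.
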
+ %constant_218270_1_clone_1 = u32[] constant(952711825) + %broadcast.252342.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218270_1_clone_1), dimensions={} + %add.248622.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121982.7.clone.1, %broadcast.252342.5.clone.1) + %add.248623.5.clone.1 = u32[1280,1280]{1,0} add(%add.248621.7.clone.1, %add.248622.5.clone.1) + %shift-left.109716.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248622.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115914.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248622.5.clone.1, %broadcast.244415.6016) + %or.115436.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109716.9.clone.1, %shift-right-logical.115914.9.clone.1) + %xor.121983.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248623.5.clone.1, %or.115436.7.clone.1) + %add.248624.3.clone.1 = u32[1280,1280]{1,0} add(%add.248623.5.clone.1, %xor.121983.5.clone.1) + %shift-left.109717.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121983.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115915.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121983.5.clone.1, %broadcast.244417.5760) + %or.115437.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109717.9.clone.1, %shift-right-logical.115915.9.clone.1) + %xor.121984.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248624.3.clone.1, %or.115437.7.clone.1) + %add.248625.3.clone.1 = u32[1280,1280]{1,0} add(%add.248624.3.clone.1, %xor.121984.5.clone.1) + %shift-left.109718.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121984.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115916.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121984.5.clone.1, %broadcast.244419.4352) + %or.115438.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109718.5.clone.1, %shift-right-logical.115916.5.clone.1) + %xor.121985.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248625.3.clone.1, %or.115438.3.clone.1) + %add.248626.3.clone.1 = u32[1280,1280]{1,0} add(%add.248625.3.clone.1, %xor.121985.3.clone.1) + %add.248627.17.clone.1 = u32[1280,1280]{1,0} add(%add.248626.3.clone.1, %broadcast.252312.24.clone.1) + %shift-left.109719.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121985.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115917.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121985.3.clone.1, %broadcast.244418.4352) + %or.115439.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109719.5.clone.1, %shift-right-logical.115917.5.clone.1) + %xor.121986.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248626.3.clone.1, %or.115439.3.clone.1) + %constant_218271_1_clone_1 = u32[] constant(866984018) + %broadcast.252353.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218271_1_clone_1), dimensions={} + %add.248628.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121986.15.clone.1, %broadcast.252353.19.clone.1) + %xor.121987.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248627.17.clone.1, %add.248628.19.clone.1) + %shift-right-logical.115918.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121987.17.clone.1, %broadcast.244468.1920) + %or.115440.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115918.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5759.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115440.13.clone.1) + %add.248629.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5759.11.clone.1, %broadcast.244470.1152) + %multiply.26514.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248629.9.clone.1, %broadcast.244471.896) + %add.248630.5.clone.1 = f32[1280,1280]{1,0} 
add(%multiply.26514.7.clone.1, %broadcast.244408.1024) + %maximum.3691.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248630.5.clone.1) + %abs.1543.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3691.3.clone.1) + %compare.7234.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1543.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26515.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3691.3.clone.1, %broadcast.244476.1152) + %negate.4591.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3691.3.clone.1) + %multiply.26516.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3691.3.clone.1, %negate.4591.5.clone.1) + %log-plus-one.1543.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26516.5.clone.1) + %negate.4592.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1543.3.clone.1) + %compare.7235.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4592.4.clone.1, %broadcast.244477.384), direction=LT + %select.21059.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21060.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21061.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21062.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21063.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21064.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21065.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21066.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21067.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248631.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4592.4.clone.1, %broadcast.244496.640) + %sqrt.1543.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4592.4.clone.1) + %add.248632.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1543.5.clone.1, %broadcast.244498.640) + %select.21068.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7235.3.clone.1, %add.248631.5.clone.1, %add.248632.5.clone.1) + %multiply.26517.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21067.3.clone.1, %select.21068.3.clone.1) + %add.248633.1.clone.1 = f32[1280,1280]{1,0} add(%select.21066.3.clone.1, %multiply.26517.1.clone.1) + %multiply.26518.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248633.1.clone.1, %select.21068.3.clone.1) + %add.248634.1.clone.1 = f32[1280,1280]{1,0} add(%select.21065.3.clone.1, %multiply.26518.1.clone.1) + %multiply.26519.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248634.1.clone.1, %select.21068.3.clone.1) + %add.248635.1.clone.1 = f32[1280,1280]{1,0} add(%select.21064.3.clone.1, %multiply.26519.1.clone.1) + %multiply.26520.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248635.1.clone.1, %select.21068.3.clone.1) + %add.248637.1.clone.1 = f32[1280,1280]{1,0} add(%select.21063.3.clone.1, %multiply.26520.1.clone.1) + %multiply.26521.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248637.1.clone.1, %select.21068.3.clone.1) + %add.248638.3.clone.1 = f32[1280,1280]{1,0} add(%select.21062.5.clone.1, %multiply.26521.1.clone.1) + %multiply.26522.1.clone.1 = 
f32[1280,1280]{1,0} multiply(%add.248638.3.clone.1, %select.21068.3.clone.1) + %add.248639.3.clone.1 = f32[1280,1280]{1,0} add(%select.21061.5.clone.1, %multiply.26522.1.clone.1) + %multiply.26523.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248639.3.clone.1, %select.21068.3.clone.1) + %add.248640.9.clone.1 = f32[1280,1280]{1,0} add(%select.21060.11.clone.1, %multiply.26523.7.clone.1) + %multiply.26524.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248640.9.clone.1, %select.21068.3.clone.1) + %add.248641.7.clone.1 = f32[1280,1280]{1,0} add(%select.21059.7.clone.1, %multiply.26524.7.clone.1) + %multiply.26525.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248641.7.clone.1, %maximum.3691.3.clone.1) + %select.21069.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7234.3.clone.1, %multiply.26515.9.clone.1, %multiply.26525.7.clone.1) + %multiply.26526.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21069.7.clone.1, %broadcast.244500.640) + %clamp.1187.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26526.5.clone.1, %broadcast.244501.384) + %multiply.26527.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1187.3.clone.1, %broadcast.244502.1) + %constant_186570_1_clone_1 = u32[] constant(658820832) + %broadcast.256722.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_186570_1_clone_1), dimensions={} + %add.251101.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.256722.44.clone.1) + %constant_186577_1_clone_1 = u32[] constant(157675534) + %broadcast.256723.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_186577_1_clone_1), dimensions={} + %add.251102.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.256723.113.clone.1) + %add.251103.35.clone.1 = u32[1280,1280]{1,0} add(%add.251101.37.clone.1, %add.251102.99.clone.1) + %shift-left.110797.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251102.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117056.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251102.99.clone.1, %broadcast.244415.6016) + %or.116576.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110797.31.clone.1, %shift-right-logical.117056.29.clone.1) + %xor.123151.27.clone.1 = u32[1280,1280]{1,0} xor(%add.251103.35.clone.1, %or.116576.29.clone.1) + %add.251104.5.clone.1 = u32[1280,1280]{1,0} add(%add.251103.35.clone.1, %xor.123151.27.clone.1) + %shift-left.110798.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123151.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117057.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123151.27.clone.1, %broadcast.244417.5760) + %or.116577.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110798.9.clone.1, %shift-right-logical.117057.9.clone.1) + %xor.123152.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251104.5.clone.1, %or.116577.7.clone.1) + %add.251106.3.clone.1 = u32[1280,1280]{1,0} add(%add.251104.5.clone.1, %xor.123152.5.clone.1) + %shift-left.110800.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123152.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117059.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123152.5.clone.1, %broadcast.244419.4352) + %or.116578.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110800.5.clone.1, %shift-right-logical.117059.5.clone.1) + %xor.123153.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251106.3.clone.1, %or.116578.3.clone.1) + %add.251107.3.clone.1 = u32[1280,1280]{1,0} add(%add.251106.3.clone.1, %xor.123153.3.clone.1) + %add.251108.7.clone.1 = u32[1280,1280]{1,0} add(%add.251107.3.clone.1, 
%broadcast.256723.113.clone.1) + %shift-left.110801.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123153.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117060.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123153.3.clone.1, %broadcast.244418.4352) + %or.116579.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110801.5.clone.1, %shift-right-logical.117060.5.clone.1) + %xor.123154.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251107.3.clone.1, %or.116579.3.clone.1) + %constant_218539_1_clone_1 = u32[] constant(904930613) + %broadcast.256733.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218539_1_clone_1), dimensions={} + %add.251109.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123154.3.clone.1, %broadcast.256733.5.clone.1) + %add.251111.5.clone.1 = u32[1280,1280]{1,0} add(%add.251108.7.clone.1, %add.251109.5.clone.1) + %shift-left.110802.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251109.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117061.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251109.5.clone.1, %broadcast.244416.5760) + %or.116581.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110802.9.clone.1, %shift-right-logical.117061.9.clone.1) + %xor.123155.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251111.5.clone.1, %or.116581.7.clone.1) + %add.251112.3.clone.1 = u32[1280,1280]{1,0} add(%add.251111.5.clone.1, %xor.123155.5.clone.1) + %shift-left.110803.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123155.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117062.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123155.5.clone.1, %broadcast.244429.2304) + %or.116582.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110803.9.clone.1, %shift-right-logical.117062.9.clone.1) + %xor.123156.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251112.3.clone.1, %or.116582.7.clone.1) + %add.251113.3.clone.1 = u32[1280,1280]{1,0} add(%add.251112.3.clone.1, %xor.123156.5.clone.1) + %shift-left.110805.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123156.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117064.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123156.5.clone.1, %broadcast.244430.4608) + %or.116583.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110805.9.clone.1, %shift-right-logical.117064.9.clone.1) + %xor.123157.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251113.3.clone.1, %or.116583.7.clone.1) + %add.251114.3.clone.1 = u32[1280,1280]{1,0} add(%add.251113.3.clone.1, %xor.123157.5.clone.1) + %constant_186579_1_clone_1 = u32[] constant(904930612) + %broadcast.256740.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_186579_1_clone_1), dimensions={} + %add.251115.7.clone.1 = u32[1280,1280]{1,0} add(%add.251114.3.clone.1, %broadcast.256740.24.clone.1) + %shift-left.110806.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123157.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117065.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123157.5.clone.1, %broadcast.244434.2816) + %or.116584.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110806.11.clone.1, %shift-right-logical.117065.11.clone.1) + %xor.123158.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251114.3.clone.1, %or.116584.9.clone.1) + %constant_218540_1_clone_1 = u32[] constant(658820834) + %broadcast.256743.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218540_1_clone_1), dimensions={} + %add.251117.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123158.7.clone.1, %broadcast.256743.5.clone.1) + %add.251120.5.clone.1 = u32[1280,1280]{1,0} add(%add.251115.7.clone.1, 
%add.251117.5.clone.1) + %shift-left.110807.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251117.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117066.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251117.5.clone.1, %broadcast.244415.6016) + %or.116586.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110807.9.clone.1, %shift-right-logical.117066.9.clone.1) + %xor.123159.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251120.5.clone.1, %or.116586.7.clone.1) + %add.251121.3.clone.1 = u32[1280,1280]{1,0} add(%add.251120.5.clone.1, %xor.123159.5.clone.1) + %shift-left.110808.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123159.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117067.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123159.5.clone.1, %broadcast.244417.5760) + %or.116587.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110808.9.clone.1, %shift-right-logical.117067.9.clone.1) + %xor.123160.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251121.3.clone.1, %or.116587.7.clone.1) + %add.251122.3.clone.1 = u32[1280,1280]{1,0} add(%add.251121.3.clone.1, %xor.123160.5.clone.1) + %shift-left.110810.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123160.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117069.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123160.5.clone.1, %broadcast.244419.4352) + %or.116588.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110810.7.clone.1, %shift-right-logical.117069.7.clone.1) + %xor.123161.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251122.3.clone.1, %or.116588.5.clone.1) + %add.251123.3.clone.1 = u32[1280,1280]{1,0} add(%add.251122.3.clone.1, %xor.123161.3.clone.1) + %add.251124.7.clone.1 = u32[1280,1280]{1,0} add(%add.251123.3.clone.1, %broadcast.256722.44.clone.1) + %shift-left.110811.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123161.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117070.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123161.3.clone.1, %broadcast.244418.4352) + %or.116589.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110811.7.clone.1, %shift-right-logical.117070.7.clone.1) + %xor.123162.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251123.3.clone.1, %or.116589.5.clone.1) + %constant_218541_1_clone_1 = u32[] constant(157675537) + %broadcast.256753.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218541_1_clone_1), dimensions={} + %add.251125.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123162.3.clone.1, %broadcast.256753.5.clone.1) + %add.251126.5.clone.1 = u32[1280,1280]{1,0} add(%add.251124.7.clone.1, %add.251125.5.clone.1) + %shift-left.110812.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251125.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117071.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251125.5.clone.1, %broadcast.244416.5760) + %or.116591.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110812.9.clone.1, %shift-right-logical.117071.9.clone.1) + %xor.123163.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251126.5.clone.1, %or.116591.7.clone.1) + %add.251127.3.clone.1 = u32[1280,1280]{1,0} add(%add.251126.5.clone.1, %xor.123163.5.clone.1) + %shift-left.110813.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123163.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117072.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123163.5.clone.1, %broadcast.244429.2304) + %or.116592.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110813.9.clone.1, %shift-right-logical.117072.9.clone.1) + %xor.123164.5.clone.1 = u32[1280,1280]{1,0} 
xor(%add.251127.3.clone.1, %or.116592.7.clone.1) + %add.251128.3.clone.1 = u32[1280,1280]{1,0} add(%add.251127.3.clone.1, %xor.123164.5.clone.1) + %shift-left.110814.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123164.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117073.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123164.5.clone.1, %broadcast.244430.4608) + %or.116593.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110814.9.clone.1, %shift-right-logical.117073.9.clone.1) + %xor.123165.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251128.3.clone.1, %or.116593.7.clone.1) + %add.251129.3.clone.1 = u32[1280,1280]{1,0} add(%add.251128.3.clone.1, %xor.123165.5.clone.1) + %add.251130.7.clone.1 = u32[1280,1280]{1,0} add(%add.251129.3.clone.1, %broadcast.256723.113.clone.1) + %shift-left.110815.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123165.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117074.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123165.5.clone.1, %broadcast.244434.2816) + %or.116594.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110815.11.clone.1, %shift-right-logical.117074.11.clone.1) + %xor.123166.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251129.3.clone.1, %or.116594.9.clone.1) + %constant_218542_1_clone_1 = u32[] constant(904930616) + %broadcast.256763.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218542_1_clone_1), dimensions={} + %add.251131.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123166.7.clone.1, %broadcast.256763.5.clone.1) + %add.251132.5.clone.1 = u32[1280,1280]{1,0} add(%add.251130.7.clone.1, %add.251131.5.clone.1) + %shift-left.110816.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251131.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117075.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251131.5.clone.1, %broadcast.244415.6016) + %or.116596.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110816.9.clone.1, %shift-right-logical.117075.9.clone.1) + %xor.123167.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251132.5.clone.1, %or.116596.7.clone.1) + %add.251133.3.clone.1 = u32[1280,1280]{1,0} add(%add.251132.5.clone.1, %xor.123167.5.clone.1) + %shift-left.110817.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123167.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117076.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123167.5.clone.1, %broadcast.244417.5760) + %or.116597.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110817.9.clone.1, %shift-right-logical.117076.9.clone.1) + %xor.123168.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251133.3.clone.1, %or.116597.7.clone.1) + %add.251134.3.clone.1 = u32[1280,1280]{1,0} add(%add.251133.3.clone.1, %xor.123168.5.clone.1) + %shift-left.110818.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123168.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117077.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123168.5.clone.1, %broadcast.244419.4352) + %or.116598.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110818.5.clone.1, %shift-right-logical.117077.5.clone.1) + %xor.123169.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251134.3.clone.1, %or.116598.3.clone.1) + %add.251135.3.clone.1 = u32[1280,1280]{1,0} add(%add.251134.3.clone.1, %xor.123169.3.clone.1) + %add.251136.17.clone.1 = u32[1280,1280]{1,0} add(%add.251135.3.clone.1, %broadcast.256740.24.clone.1) + %shift-left.110819.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123169.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117079.5.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.123169.3.clone.1, %broadcast.244418.4352) + %or.116599.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110819.5.clone.1, %shift-right-logical.117079.5.clone.1) + %xor.123170.15.clone.1 = u32[1280,1280]{1,0} xor(%add.251135.3.clone.1, %or.116599.3.clone.1) + %constant_218543_1_clone_1 = u32[] constant(658820837) + %broadcast.256773.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218543_1_clone_1), dimensions={} + %add.251137.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123170.15.clone.1, %broadcast.256773.19.clone.1) + %xor.123171.17.clone.1 = u32[1280,1280]{1,0} xor(%add.251136.17.clone.1, %add.251137.19.clone.1) + %shift-right-logical.117080.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123171.17.clone.1, %broadcast.244468.1920) + %or.116600.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117080.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5809.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116600.13.clone.1) + %add.251138.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5809.11.clone.1, %broadcast.244470.1152) + %multiply.27027.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251138.9.clone.1, %broadcast.244471.896) + %add.251139.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27027.7.clone.1, %broadcast.244408.1024) + %maximum.3741.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.251139.5.clone.1) + %abs.1577.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3741.3.clone.1) + %compare.7314.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1577.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27029.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3741.3.clone.1, %broadcast.244476.1152) + %negate.4659.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3741.3.clone.1) + %multiply.27030.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3741.3.clone.1, %negate.4659.5.clone.1) + %log-plus-one.1577.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27030.5.clone.1) + %negate.4660.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1577.3.clone.1) + %compare.7315.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4660.4.clone.1, %broadcast.244477.384), direction=LT + %select.21433.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21434.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21435.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21436.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21437.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21438.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21439.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21440.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21441.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.251140.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4660.4.clone.1, %broadcast.244496.640) + %sqrt.1577.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4660.4.clone.1) + %add.251141.5.clone.1 = f32[1280,1280]{1,0} 
add(%sqrt.1577.5.clone.1, %broadcast.244498.640) + %select.21442.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7315.3.clone.1, %add.251140.5.clone.1, %add.251141.5.clone.1) + %multiply.27032.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21441.3.clone.1, %select.21442.3.clone.1) + %add.251142.1.clone.1 = f32[1280,1280]{1,0} add(%select.21440.3.clone.1, %multiply.27032.1.clone.1) + %multiply.27033.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251142.1.clone.1, %select.21442.3.clone.1) + %add.251143.1.clone.1 = f32[1280,1280]{1,0} add(%select.21439.3.clone.1, %multiply.27033.1.clone.1) + %multiply.27035.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251143.1.clone.1, %select.21442.3.clone.1) + %add.251144.1.clone.1 = f32[1280,1280]{1,0} add(%select.21438.3.clone.1, %multiply.27035.1.clone.1) + %multiply.27036.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251144.1.clone.1, %select.21442.3.clone.1) + %add.251145.1.clone.1 = f32[1280,1280]{1,0} add(%select.21437.3.clone.1, %multiply.27036.1.clone.1) + %multiply.27038.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251145.1.clone.1, %select.21442.3.clone.1) + %add.251146.3.clone.1 = f32[1280,1280]{1,0} add(%select.21436.5.clone.1, %multiply.27038.1.clone.1) + %multiply.27039.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251146.3.clone.1, %select.21442.3.clone.1) + %add.251148.3.clone.1 = f32[1280,1280]{1,0} add(%select.21435.5.clone.1, %multiply.27039.1.clone.1) + %multiply.27041.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251148.3.clone.1, %select.21442.3.clone.1) + %add.251149.9.clone.1 = f32[1280,1280]{1,0} add(%select.21434.11.clone.1, %multiply.27041.7.clone.1) + %multiply.27042.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251149.9.clone.1, %select.21442.3.clone.1) + %add.251150.7.clone.1 = f32[1280,1280]{1,0} add(%select.21433.7.clone.1, %multiply.27042.7.clone.1) + %multiply.27044.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251150.7.clone.1, %maximum.3741.3.clone.1) + %select.21443.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7314.3.clone.1, %multiply.27029.9.clone.1, %multiply.27044.7.clone.1) + %multiply.27045.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21443.7.clone.1, %broadcast.244500.640) + %clamp.1221.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27045.5.clone.1, %broadcast.244501.384) + %multiply.27047.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1221.3.clone.1, %broadcast.244502.1) + %constant_175835_1_clone_1 = u32[] constant(1103031926) + %broadcast.252052.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_175835_1_clone_1), dimensions={} + %add.248448.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.252052.44.clone.1) + %constant_175842_1_clone_1 = u32[] constant(3567742604) + %broadcast.252053.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_175842_1_clone_1), dimensions={} + %add.248449.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.252053.113.clone.1) + %add.248450.35.clone.1 = u32[1280,1280]{1,0} add(%add.248448.37.clone.1, %add.248449.99.clone.1) + %shift-left.109631.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248449.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115831.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248449.99.clone.1, %broadcast.244415.6016) + %or.115357.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109631.31.clone.1, %shift-right-logical.115831.29.clone.1) + %xor.121904.27.clone.1 = u32[1280,1280]{1,0} xor(%add.248450.35.clone.1, %or.115357.29.clone.1) + %add.248451.5.clone.1 = 
u32[1280,1280]{1,0} add(%add.248450.35.clone.1, %xor.121904.27.clone.1) + %shift-left.109632.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121904.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115833.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121904.27.clone.1, %broadcast.244417.5760) + %or.115358.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109632.9.clone.1, %shift-right-logical.115833.9.clone.1) + %xor.121905.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248451.5.clone.1, %or.115358.7.clone.1) + %add.248452.3.clone.1 = u32[1280,1280]{1,0} add(%add.248451.5.clone.1, %xor.121905.5.clone.1) + %shift-left.109633.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121905.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115834.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121905.5.clone.1, %broadcast.244419.4352) + %or.115359.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109633.5.clone.1, %shift-right-logical.115834.5.clone.1) + %xor.121906.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248452.3.clone.1, %or.115359.3.clone.1) + %add.248453.3.clone.1 = u32[1280,1280]{1,0} add(%add.248452.3.clone.1, %xor.121906.3.clone.1) + %add.248454.7.clone.1 = u32[1280,1280]{1,0} add(%add.248453.3.clone.1, %broadcast.252053.113.clone.1) + %shift-left.109635.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121906.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115835.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121906.3.clone.1, %broadcast.244418.4352) + %or.115360.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109635.5.clone.1, %shift-right-logical.115835.5.clone.1) + %xor.121907.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248453.3.clone.1, %or.115360.3.clone.1) + %constant_218252_1_clone_1 = u32[] constant(2395511585) + %broadcast.252063.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218252_1_clone_1), dimensions={} + %add.248455.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121907.3.clone.1, %broadcast.252063.5.clone.1) + %add.248456.5.clone.1 = u32[1280,1280]{1,0} add(%add.248454.7.clone.1, %add.248455.5.clone.1) + %shift-left.109636.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248455.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115836.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248455.5.clone.1, %broadcast.244416.5760) + %or.115361.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109636.9.clone.1, %shift-right-logical.115836.9.clone.1) + %xor.121908.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248456.5.clone.1, %or.115361.7.clone.1) + %add.248457.3.clone.1 = u32[1280,1280]{1,0} add(%add.248456.5.clone.1, %xor.121908.5.clone.1) + %shift-left.109637.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121908.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115838.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121908.5.clone.1, %broadcast.244429.2304) + %or.115362.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109637.9.clone.1, %shift-right-logical.115838.9.clone.1) + %xor.121909.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248457.3.clone.1, %or.115362.7.clone.1) + %add.248458.3.clone.1 = u32[1280,1280]{1,0} add(%add.248457.3.clone.1, %xor.121909.5.clone.1) + %shift-left.109638.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121909.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115839.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121909.5.clone.1, %broadcast.244430.4608) + %or.115363.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109638.9.clone.1, %shift-right-logical.115839.9.clone.1) + 
%xor.121910.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248458.3.clone.1, %or.115363.7.clone.1) + %add.248459.3.clone.1 = u32[1280,1280]{1,0} add(%add.248458.3.clone.1, %xor.121910.5.clone.1) + %constant_175844_1_clone_1 = u32[] constant(2395511584) + %broadcast.252070.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_175844_1_clone_1), dimensions={} + %add.248460.7.clone.1 = u32[1280,1280]{1,0} add(%add.248459.3.clone.1, %broadcast.252070.24.clone.1) + %shift-left.109639.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121910.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115840.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121910.5.clone.1, %broadcast.244434.2816) + %or.115364.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109639.11.clone.1, %shift-right-logical.115840.11.clone.1) + %xor.121911.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248459.3.clone.1, %or.115364.9.clone.1) + %constant_218253_1_clone_1 = u32[] constant(1103031928) + %broadcast.252073.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218253_1_clone_1), dimensions={} + %add.248461.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121911.7.clone.1, %broadcast.252073.5.clone.1) + %add.248462.5.clone.1 = u32[1280,1280]{1,0} add(%add.248460.7.clone.1, %add.248461.5.clone.1) + %shift-left.109640.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248461.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115841.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248461.5.clone.1, %broadcast.244415.6016) + %or.115365.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109640.9.clone.1, %shift-right-logical.115841.9.clone.1) + %xor.121912.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248462.5.clone.1, %or.115365.7.clone.1) + %add.248463.3.clone.1 = u32[1280,1280]{1,0} add(%add.248462.5.clone.1, %xor.121912.5.clone.1) + %shift-left.109641.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121912.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115842.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121912.5.clone.1, %broadcast.244417.5760) + %or.115366.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109641.9.clone.1, %shift-right-logical.115842.9.clone.1) + %xor.121913.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248463.3.clone.1, %or.115366.7.clone.1) + %add.248464.3.clone.1 = u32[1280,1280]{1,0} add(%add.248463.3.clone.1, %xor.121913.5.clone.1) + %shift-left.109642.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121913.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115843.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121913.5.clone.1, %broadcast.244419.4352) + %or.115367.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109642.7.clone.1, %shift-right-logical.115843.7.clone.1) + %xor.121914.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248464.3.clone.1, %or.115367.5.clone.1) + %add.248465.3.clone.1 = u32[1280,1280]{1,0} add(%add.248464.3.clone.1, %xor.121914.3.clone.1) + %add.248466.7.clone.1 = u32[1280,1280]{1,0} add(%add.248465.3.clone.1, %broadcast.252052.44.clone.1) + %shift-left.109643.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121914.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115844.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121914.3.clone.1, %broadcast.244418.4352) + %or.115368.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109643.7.clone.1, %shift-right-logical.115844.7.clone.1) + %xor.121915.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248465.3.clone.1, %or.115368.5.clone.1) + %constant_218254_1_clone_1 = u32[] constant(3567742607) + 
%broadcast.252083.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218254_1_clone_1), dimensions={}
+ %add.248467.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121915.3.clone.1, %broadcast.252083.5.clone.1)
+ %add.248468.5.clone.1 = u32[1280,1280]{1,0} add(%add.248466.7.clone.1, %add.248467.5.clone.1)
+ %shift-left.109645.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248467.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.115845.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248467.5.clone.1, %broadcast.244416.5760)
+ %or.115369.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109645.9.clone.1, %shift-right-logical.115845.9.clone.1)
+ %xor.121916.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248468.5.clone.1, %or.115369.7.clone.1)
+ %add.248469.3.clone.1 = u32[1280,1280]{1,0} add(%add.248468.5.clone.1, %xor.121916.5.clone.1)
[... several hundred further "+" lines of the same machine-generated HLO, restored to one instruction per line as above and elided here. The hunk repeats a single template: ThreeFry-style counter-based RNG rounds over u32[1280,1280] tensors (add, a rotate built from paired shift-left/shift-right-logical and or, then xor, with broadcast round-key constants injected every few rounds); conversion of the resulting bits to f32 uniforms via shift-right-logical, or, bitcast-convert, add, and multiply against fixed broadcasts; and a two-branch inverse-erf polynomial (log-plus-one, sqrt, a Horner chain over selected coefficient sets, clamp, final scale), consistent with sampling a normal distribution. The elided span holds the tail of one such block, three complete blocks keyed by the u32 constant pairs (1438835833, 3720491603), (2935165355, 1815324084), and (2496889964, 4113454960), and the head of a block keyed by (1467111404, 4217064098), which resumes just below. ...]
+ %shift-left.109593.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121870.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.115792.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121870.5.clone.1, %broadcast.244417.5760)
+ %or.115323.7.clone.1 =
u32[1280,1280]{1,0} or(%shift-left.109593.9.clone.1, %shift-right-logical.115792.9.clone.1) + %xor.121871.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248356.3.clone.1, %or.115323.7.clone.1) + %add.248357.3.clone.1 = u32[1280,1280]{1,0} add(%add.248356.3.clone.1, %xor.121871.5.clone.1) + %shift-left.109595.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121871.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115793.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121871.5.clone.1, %broadcast.244419.4352) + %or.115324.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109595.7.clone.1, %shift-right-logical.115793.7.clone.1) + %xor.121872.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248357.3.clone.1, %or.115324.5.clone.1) + %add.248358.3.clone.1 = u32[1280,1280]{1,0} add(%add.248357.3.clone.1, %xor.121872.3.clone.1) + %add.248359.7.clone.1 = u32[1280,1280]{1,0} add(%add.248358.3.clone.1, %broadcast.251863.44.clone.1) + %shift-left.109596.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121872.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115794.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121872.3.clone.1, %broadcast.244418.4352) + %or.115325.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109596.7.clone.1, %shift-right-logical.115794.7.clone.1) + %xor.121873.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248358.3.clone.1, %or.115325.5.clone.1) + %constant_218244_1_clone_1 = u32[] constant(4217064101) + %broadcast.251896.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218244_1_clone_1), dimensions={} + %add.248361.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121873.3.clone.1, %broadcast.251896.5.clone.1) + %add.248362.5.clone.1 = u32[1280,1280]{1,0} add(%add.248359.7.clone.1, %add.248361.5.clone.1) + %shift-left.109597.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248361.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115795.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248361.5.clone.1, %broadcast.244416.5760) + %or.115326.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109597.9.clone.1, %shift-right-logical.115795.9.clone.1) + %xor.121874.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248362.5.clone.1, %or.115326.7.clone.1) + %add.248363.3.clone.1 = u32[1280,1280]{1,0} add(%add.248362.5.clone.1, %xor.121874.5.clone.1) + %shift-left.109598.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121874.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115796.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121874.5.clone.1, %broadcast.244429.2304) + %or.115327.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109598.9.clone.1, %shift-right-logical.115796.9.clone.1) + %xor.121875.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248363.3.clone.1, %or.115327.7.clone.1) + %add.248364.3.clone.1 = u32[1280,1280]{1,0} add(%add.248363.3.clone.1, %xor.121875.5.clone.1) + %shift-left.109600.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121875.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115798.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121875.5.clone.1, %broadcast.244430.4608) + %or.115328.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109600.9.clone.1, %shift-right-logical.115798.9.clone.1) + %xor.121876.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248364.3.clone.1, %or.115328.7.clone.1) + %add.248366.3.clone.1 = u32[1280,1280]{1,0} add(%add.248364.3.clone.1, %xor.121876.5.clone.1) + %add.248367.7.clone.1 = u32[1280,1280]{1,0} add(%add.248366.3.clone.1, %broadcast.251864.113.clone.1) + %shift-left.109601.11.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.121876.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115799.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121876.5.clone.1, %broadcast.244434.2816) + %or.115330.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109601.11.clone.1, %shift-right-logical.115799.11.clone.1) + %xor.121877.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248366.3.clone.1, %or.115330.9.clone.1) + %constant_218245_1_clone_1 = u32[] constant(3086485144) + %broadcast.251909.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218245_1_clone_1), dimensions={} + %add.248368.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121877.7.clone.1, %broadcast.251909.5.clone.1) + %add.248369.5.clone.1 = u32[1280,1280]{1,0} add(%add.248367.7.clone.1, %add.248368.5.clone.1) + %shift-left.109602.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248368.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115800.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248368.5.clone.1, %broadcast.244415.6016) + %or.115331.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109602.9.clone.1, %shift-right-logical.115800.9.clone.1) + %xor.121878.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248369.5.clone.1, %or.115331.7.clone.1) + %add.248370.3.clone.1 = u32[1280,1280]{1,0} add(%add.248369.5.clone.1, %xor.121878.5.clone.1) + %shift-left.109603.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121878.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115801.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121878.5.clone.1, %broadcast.244417.5760) + %or.115332.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109603.9.clone.1, %shift-right-logical.115801.9.clone.1) + %xor.121879.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248370.3.clone.1, %or.115332.7.clone.1) + %add.248372.3.clone.1 = u32[1280,1280]{1,0} add(%add.248370.3.clone.1, %xor.121879.5.clone.1) + %shift-left.109605.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121879.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115803.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121879.5.clone.1, %broadcast.244419.4352) + %or.115333.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109605.5.clone.1, %shift-right-logical.115803.5.clone.1) + %xor.121880.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248372.3.clone.1, %or.115333.3.clone.1) + %add.248376.3.clone.1 = u32[1280,1280]{1,0} add(%add.248372.3.clone.1, %xor.121880.3.clone.1) + %add.248377.17.clone.1 = u32[1280,1280]{1,0} add(%add.248376.3.clone.1, %broadcast.251883.24.clone.1) + %shift-left.109606.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121880.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115804.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121880.3.clone.1, %broadcast.244418.4352) + %or.115334.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109606.5.clone.1, %shift-right-logical.115804.5.clone.1) + %xor.121881.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248376.3.clone.1, %or.115334.3.clone.1) + %constant_218246_1_clone_1 = u32[] constant(1467111409) + %broadcast.251923.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218246_1_clone_1), dimensions={} + %add.248378.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121881.15.clone.1, %broadcast.251923.19.clone.1) + %xor.121882.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248377.17.clone.1, %add.248378.19.clone.1) + %shift-right-logical.115805.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121882.17.clone.1, %broadcast.244468.1920) + %or.115335.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115805.15.clone.1, 
%broadcast.244469.1664) + %bitcast-convert.5754.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115335.13.clone.1) + %add.248379.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5754.11.clone.1, %broadcast.244470.1152) + %multiply.26456.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248379.9.clone.1, %broadcast.244471.896) + %add.248381.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26456.7.clone.1, %broadcast.244408.1024) + %maximum.3686.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248381.5.clone.1) + %abs.1540.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3686.3.clone.1) + %compare.7228.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1540.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26457.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3686.3.clone.1, %broadcast.244476.1152) + %negate.4585.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3686.3.clone.1) + %multiply.26458.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3686.3.clone.1, %negate.4585.5.clone.1) + %log-plus-one.1540.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26458.5.clone.1) + %negate.4586.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1540.3.clone.1) + %compare.7229.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4586.4.clone.1, %broadcast.244477.384), direction=LT + %select.21021.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21022.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21023.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21024.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21025.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21026.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21027.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21028.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21029.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248382.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4586.4.clone.1, %broadcast.244496.640) + %sqrt.1540.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4586.4.clone.1) + %add.248383.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1540.5.clone.1, %broadcast.244498.640) + %select.21030.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7229.3.clone.1, %add.248382.5.clone.1, %add.248383.5.clone.1) + %multiply.26459.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21029.3.clone.1, %select.21030.3.clone.1) + %add.248384.1.clone.1 = f32[1280,1280]{1,0} add(%select.21028.3.clone.1, %multiply.26459.1.clone.1) + %multiply.26460.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248384.1.clone.1, %select.21030.3.clone.1) + %add.248386.1.clone.1 = f32[1280,1280]{1,0} add(%select.21027.3.clone.1, %multiply.26460.1.clone.1) + %multiply.26461.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248386.1.clone.1, %select.21030.3.clone.1) + %add.248387.1.clone.1 = f32[1280,1280]{1,0} add(%select.21026.3.clone.1, %multiply.26461.1.clone.1) + %multiply.26462.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248387.1.clone.1, 
%select.21030.3.clone.1) + %add.248388.1.clone.1 = f32[1280,1280]{1,0} add(%select.21025.3.clone.1, %multiply.26462.1.clone.1) + %multiply.26463.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248388.1.clone.1, %select.21030.3.clone.1) + %add.248389.3.clone.1 = f32[1280,1280]{1,0} add(%select.21024.5.clone.1, %multiply.26463.1.clone.1) + %multiply.26464.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248389.3.clone.1, %select.21030.3.clone.1) + %add.248391.3.clone.1 = f32[1280,1280]{1,0} add(%select.21023.5.clone.1, %multiply.26464.1.clone.1) + %multiply.26466.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248391.3.clone.1, %select.21030.3.clone.1) + %add.248392.9.clone.1 = f32[1280,1280]{1,0} add(%select.21022.11.clone.1, %multiply.26466.7.clone.1) + %multiply.26467.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248392.9.clone.1, %select.21030.3.clone.1) + %add.248393.7.clone.1 = f32[1280,1280]{1,0} add(%select.21021.7.clone.1, %multiply.26467.7.clone.1) + %multiply.26468.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248393.7.clone.1, %maximum.3686.3.clone.1) + %select.21031.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7228.3.clone.1, %multiply.26457.9.clone.1, %multiply.26468.7.clone.1) + %multiply.26469.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21031.7.clone.1, %broadcast.244500.640) + %clamp.1184.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26469.5.clone.1, %broadcast.244501.384) + %multiply.26470.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1184.3.clone.1, %broadcast.244502.1) + %constant_195878_1_clone_1 = u32[] constant(3825517637) + %broadcast.260736.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_195878_1_clone_1), dimensions={} + %add.253396.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.260736.44.clone.1) + %constant_195885_1_clone_1 = u32[] constant(586838930) + %broadcast.260737.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_195885_1_clone_1), dimensions={} + %add.253397.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.260737.113.clone.1) + %add.253398.35.clone.1 = u32[1280,1280]{1,0} add(%add.253396.37.clone.1, %add.253397.99.clone.1) + %shift-left.111800.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253397.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.118107.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253397.99.clone.1, %broadcast.244415.6016) + %or.117638.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111800.31.clone.1, %shift-right-logical.118107.29.clone.1) + %xor.124199.27.clone.1 = u32[1280,1280]{1,0} xor(%add.253398.35.clone.1, %or.117638.29.clone.1) + %add.253399.5.clone.1 = u32[1280,1280]{1,0} add(%add.253398.35.clone.1, %xor.124199.27.clone.1) + %shift-left.111801.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124199.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.118108.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124199.27.clone.1, %broadcast.244417.5760) + %or.117639.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111801.9.clone.1, %shift-right-logical.118108.9.clone.1) + %xor.124201.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253399.5.clone.1, %or.117639.7.clone.1) + %add.253400.3.clone.1 = u32[1280,1280]{1,0} add(%add.253399.5.clone.1, %xor.124201.5.clone.1) + %shift-left.111802.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124201.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.118109.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124201.5.clone.1, %broadcast.244419.4352) + %or.117641.3.clone.1 = 
u32[1280,1280]{1,0} or(%shift-left.111802.5.clone.1, %shift-right-logical.118109.5.clone.1) + %xor.124202.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253400.3.clone.1, %or.117641.3.clone.1) + %add.253401.3.clone.1 = u32[1280,1280]{1,0} add(%add.253400.3.clone.1, %xor.124202.3.clone.1) + %add.253402.7.clone.1 = u32[1280,1280]{1,0} add(%add.253401.3.clone.1, %broadcast.260737.113.clone.1) + %shift-left.111803.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124202.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.118110.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124202.3.clone.1, %broadcast.244418.4352) + %or.117642.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111803.5.clone.1, %shift-right-logical.118110.5.clone.1) + %xor.124203.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253401.3.clone.1, %or.117642.3.clone.1) + %constant_218802_1_clone_1 = u32[] constant(3710888974) + %broadcast.260747.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218802_1_clone_1), dimensions={} + %add.253403.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124203.3.clone.1, %broadcast.260747.5.clone.1) + %add.253404.5.clone.1 = u32[1280,1280]{1,0} add(%add.253402.7.clone.1, %add.253403.5.clone.1) + %shift-left.111804.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253403.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.118111.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253403.5.clone.1, %broadcast.244416.5760) + %or.117643.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111804.9.clone.1, %shift-right-logical.118111.9.clone.1) + %xor.124204.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253404.5.clone.1, %or.117643.7.clone.1) + %add.253405.3.clone.1 = u32[1280,1280]{1,0} add(%add.253404.5.clone.1, %xor.124204.5.clone.1) + %shift-left.111805.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124204.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.118112.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124204.5.clone.1, %broadcast.244429.2304) + %or.117644.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111805.9.clone.1, %shift-right-logical.118112.9.clone.1) + %xor.124206.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253405.3.clone.1, %or.117644.7.clone.1) + %add.253406.3.clone.1 = u32[1280,1280]{1,0} add(%add.253405.3.clone.1, %xor.124206.5.clone.1) + %shift-left.111806.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124206.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.118113.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124206.5.clone.1, %broadcast.244430.4608) + %or.117646.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111806.9.clone.1, %shift-right-logical.118113.9.clone.1) + %xor.124207.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253406.3.clone.1, %or.117646.7.clone.1) + %add.253407.3.clone.1 = u32[1280,1280]{1,0} add(%add.253406.3.clone.1, %xor.124207.5.clone.1) + %constant_195887_1_clone_1 = u32[] constant(3710888973) + %broadcast.260754.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_195887_1_clone_1), dimensions={} + %add.253408.7.clone.1 = u32[1280,1280]{1,0} add(%add.253407.3.clone.1, %broadcast.260754.24.clone.1) + %shift-left.111807.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124207.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.118114.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124207.5.clone.1, %broadcast.244434.2816) + %or.117647.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111807.11.clone.1, %shift-right-logical.118114.11.clone.1) + %xor.124208.7.clone.1 = u32[1280,1280]{1,0} xor(%add.253407.3.clone.1, 
%or.117647.9.clone.1) + %constant_218803_1_clone_1 = u32[] constant(3825517639) + %broadcast.260757.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218803_1_clone_1), dimensions={} + %add.253409.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124208.7.clone.1, %broadcast.260757.5.clone.1) + %add.253410.5.clone.1 = u32[1280,1280]{1,0} add(%add.253408.7.clone.1, %add.253409.5.clone.1) + %shift-left.111808.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253409.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.118115.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253409.5.clone.1, %broadcast.244415.6016) + %or.117648.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111808.9.clone.1, %shift-right-logical.118115.9.clone.1) + %xor.124209.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253410.5.clone.1, %or.117648.7.clone.1) + %add.253411.3.clone.1 = u32[1280,1280]{1,0} add(%add.253410.5.clone.1, %xor.124209.5.clone.1) + %shift-left.111809.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124209.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.118116.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124209.5.clone.1, %broadcast.244417.5760) + %or.117649.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111809.9.clone.1, %shift-right-logical.118116.9.clone.1) + %xor.124211.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253411.3.clone.1, %or.117649.7.clone.1) + %add.253412.3.clone.1 = u32[1280,1280]{1,0} add(%add.253411.3.clone.1, %xor.124211.5.clone.1) + %shift-left.111810.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124211.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.118117.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124211.5.clone.1, %broadcast.244419.4352) + %or.117650.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111810.7.clone.1, %shift-right-logical.118117.7.clone.1) + %xor.124212.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253412.3.clone.1, %or.117650.5.clone.1) + %add.253413.3.clone.1 = u32[1280,1280]{1,0} add(%add.253412.3.clone.1, %xor.124212.3.clone.1) + %add.253414.7.clone.1 = u32[1280,1280]{1,0} add(%add.253413.3.clone.1, %broadcast.260736.44.clone.1) + %shift-left.111811.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124212.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.118118.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124212.3.clone.1, %broadcast.244418.4352) + %or.117651.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111811.7.clone.1, %shift-right-logical.118118.7.clone.1) + %xor.124213.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253413.3.clone.1, %or.117651.5.clone.1) + %constant_218804_1_clone_1 = u32[] constant(586838933) + %broadcast.260767.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218804_1_clone_1), dimensions={} + %add.253415.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124213.3.clone.1, %broadcast.260767.5.clone.1) + %add.253416.5.clone.1 = u32[1280,1280]{1,0} add(%add.253414.7.clone.1, %add.253415.5.clone.1) + %shift-left.111812.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253415.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.118119.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253415.5.clone.1, %broadcast.244416.5760) + %or.117652.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111812.9.clone.1, %shift-right-logical.118119.9.clone.1) + %xor.124214.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253416.5.clone.1, %or.117652.7.clone.1) + %add.253417.3.clone.1 = u32[1280,1280]{1,0} add(%add.253416.5.clone.1, %xor.124214.5.clone.1) + %shift-left.111813.9.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.124214.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.118120.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124214.5.clone.1, %broadcast.244429.2304) + %or.117653.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111813.9.clone.1, %shift-right-logical.118120.9.clone.1) + %xor.124216.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253417.3.clone.1, %or.117653.7.clone.1) + %add.253418.3.clone.1 = u32[1280,1280]{1,0} add(%add.253417.3.clone.1, %xor.124216.5.clone.1) + %shift-left.111814.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124216.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.118121.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124216.5.clone.1, %broadcast.244430.4608) + %or.117654.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111814.9.clone.1, %shift-right-logical.118121.9.clone.1) + %xor.124217.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253418.3.clone.1, %or.117654.7.clone.1) + %add.253419.3.clone.1 = u32[1280,1280]{1,0} add(%add.253418.3.clone.1, %xor.124217.5.clone.1) + %add.253420.7.clone.1 = u32[1280,1280]{1,0} add(%add.253419.3.clone.1, %broadcast.260737.113.clone.1) + %shift-left.111815.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124217.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.118122.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124217.5.clone.1, %broadcast.244434.2816) + %or.117656.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111815.11.clone.1, %shift-right-logical.118122.11.clone.1) + %xor.124218.7.clone.1 = u32[1280,1280]{1,0} xor(%add.253419.3.clone.1, %or.117656.9.clone.1) + %constant_218805_1_clone_1 = u32[] constant(3710888977) + %broadcast.260777.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218805_1_clone_1), dimensions={} + %add.253421.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124218.7.clone.1, %broadcast.260777.5.clone.1) + %add.253422.5.clone.1 = u32[1280,1280]{1,0} add(%add.253420.7.clone.1, %add.253421.5.clone.1) + %shift-left.111816.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253421.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.118123.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253421.5.clone.1, %broadcast.244415.6016) + %or.117657.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111816.9.clone.1, %shift-right-logical.118123.9.clone.1) + %xor.124219.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253422.5.clone.1, %or.117657.7.clone.1) + %add.253423.3.clone.1 = u32[1280,1280]{1,0} add(%add.253422.5.clone.1, %xor.124219.5.clone.1) + %shift-left.111817.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124219.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.118124.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124219.5.clone.1, %broadcast.244417.5760) + %or.117658.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111817.9.clone.1, %shift-right-logical.118124.9.clone.1) + %xor.124220.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253423.3.clone.1, %or.117658.7.clone.1) + %add.253425.3.clone.1 = u32[1280,1280]{1,0} add(%add.253423.3.clone.1, %xor.124220.5.clone.1) + %shift-left.111818.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124220.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.118125.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124220.5.clone.1, %broadcast.244419.4352) + %or.117659.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111818.5.clone.1, %shift-right-logical.118125.5.clone.1) + %xor.124221.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253425.3.clone.1, %or.117659.3.clone.1) + %add.253428.3.clone.1 = 
u32[1280,1280]{1,0} add(%add.253425.3.clone.1, %xor.124221.3.clone.1) + %add.253429.17.clone.1 = u32[1280,1280]{1,0} add(%add.253428.3.clone.1, %broadcast.260754.24.clone.1) + %shift-left.111819.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124221.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.118126.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124221.3.clone.1, %broadcast.244418.4352) + %or.117661.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111819.5.clone.1, %shift-right-logical.118126.5.clone.1) + %xor.124222.15.clone.1 = u32[1280,1280]{1,0} xor(%add.253428.3.clone.1, %or.117661.3.clone.1) + %constant_218806_1_clone_1 = u32[] constant(3825517642) + %broadcast.260787.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218806_1_clone_1), dimensions={} + %add.253430.19.clone.1 = u32[1280,1280]{1,0} add(%xor.124222.15.clone.1, %broadcast.260787.19.clone.1) + %xor.124223.17.clone.1 = u32[1280,1280]{1,0} xor(%add.253429.17.clone.1, %add.253430.19.clone.1) + %shift-right-logical.118127.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124223.17.clone.1, %broadcast.244468.1920) + %or.117662.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.118127.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5855.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117662.13.clone.1) + %add.253431.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5855.11.clone.1, %broadcast.244470.1152) + %multiply.27492.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253431.9.clone.1, %broadcast.244471.896) + %add.253433.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27492.7.clone.1, %broadcast.244408.1024) + %maximum.3787.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.253433.5.clone.1) + %abs.1607.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3787.3.clone.1) + %compare.7376.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1607.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27493.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3787.3.clone.1, %broadcast.244476.1152) + %negate.4719.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3787.3.clone.1) + %multiply.27494.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3787.3.clone.1, %negate.4719.5.clone.1) + %log-plus-one.1607.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27494.5.clone.1) + %negate.4720.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1607.3.clone.1) + %compare.7377.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4720.4.clone.1, %broadcast.244477.384), direction=LT + %select.21784.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21785.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21786.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21787.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21788.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21789.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21790.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21791.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + 
%select.21792.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.253434.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4720.4.clone.1, %broadcast.244496.640) + %sqrt.1607.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4720.4.clone.1) + %add.253435.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1607.5.clone.1, %broadcast.244498.640) + %select.21793.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7377.3.clone.1, %add.253434.5.clone.1, %add.253435.5.clone.1) + %multiply.27495.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21792.3.clone.1, %select.21793.3.clone.1) + %add.253436.1.clone.1 = f32[1280,1280]{1,0} add(%select.21791.3.clone.1, %multiply.27495.1.clone.1) + %multiply.27496.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253436.1.clone.1, %select.21793.3.clone.1) + %add.253438.1.clone.1 = f32[1280,1280]{1,0} add(%select.21790.3.clone.1, %multiply.27496.1.clone.1) + %multiply.27497.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253438.1.clone.1, %select.21793.3.clone.1) + %add.253439.1.clone.1 = f32[1280,1280]{1,0} add(%select.21789.3.clone.1, %multiply.27497.1.clone.1) + %multiply.27498.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253439.1.clone.1, %select.21793.3.clone.1) + %add.253440.1.clone.1 = f32[1280,1280]{1,0} add(%select.21788.3.clone.1, %multiply.27498.1.clone.1) + %multiply.27499.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253440.1.clone.1, %select.21793.3.clone.1) + %add.253441.3.clone.1 = f32[1280,1280]{1,0} add(%select.21787.5.clone.1, %multiply.27499.1.clone.1) + %multiply.27500.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253441.3.clone.1, %select.21793.3.clone.1) + %add.253443.3.clone.1 = f32[1280,1280]{1,0} add(%select.21786.5.clone.1, %multiply.27500.1.clone.1) + %multiply.27501.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253443.3.clone.1, %select.21793.3.clone.1) + %add.253444.9.clone.1 = f32[1280,1280]{1,0} add(%select.21785.11.clone.1, %multiply.27501.7.clone.1) + %multiply.27502.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253444.9.clone.1, %select.21793.3.clone.1) + %add.253445.7.clone.1 = f32[1280,1280]{1,0} add(%select.21784.7.clone.1, %multiply.27502.7.clone.1) + %multiply.27503.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253445.7.clone.1, %maximum.3787.3.clone.1) + %select.21794.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7376.3.clone.1, %multiply.27493.9.clone.1, %multiply.27503.7.clone.1) + %multiply.27504.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21794.7.clone.1, %broadcast.244500.640) + %clamp.1251.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27504.5.clone.1, %broadcast.244501.384) + %multiply.27505.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1251.3.clone.1, %broadcast.244502.1) + %constant_175166_1_clone_1 = u32[] constant(4012746161) + %broadcast.251768.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_175166_1_clone_1), dimensions={} + %add.248279.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.251768.44.clone.1) + %constant_175173_1_clone_1 = u32[] constant(3152204928) + %broadcast.251770.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_175173_1_clone_1), dimensions={} + %add.248280.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.251770.113.clone.1) + %add.248281.35.clone.1 = u32[1280,1280]{1,0} add(%add.248279.37.clone.1, %add.248280.99.clone.1) + %shift-left.109560.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248280.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115756.29.clone.1 = 
u32[1280,1280]{1,0} shift-right-logical(%add.248280.99.clone.1, %broadcast.244415.6016) + %or.115293.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109560.31.clone.1, %shift-right-logical.115756.29.clone.1) + %xor.121840.27.clone.1 = u32[1280,1280]{1,0} xor(%add.248281.35.clone.1, %or.115293.29.clone.1) + %add.248282.5.clone.1 = u32[1280,1280]{1,0} add(%add.248281.35.clone.1, %xor.121840.27.clone.1) + %shift-left.109561.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121840.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115758.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121840.27.clone.1, %broadcast.244417.5760) + %or.115294.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109561.9.clone.1, %shift-right-logical.115758.9.clone.1) + %xor.121841.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248282.5.clone.1, %or.115294.7.clone.1) + %add.248283.3.clone.1 = u32[1280,1280]{1,0} add(%add.248282.5.clone.1, %xor.121841.5.clone.1) + %shift-left.109562.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121841.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115759.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121841.5.clone.1, %broadcast.244419.4352) + %or.115295.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109562.5.clone.1, %shift-right-logical.115759.5.clone.1) + %xor.121842.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248283.3.clone.1, %or.115295.3.clone.1) + %add.248284.3.clone.1 = u32[1280,1280]{1,0} add(%add.248283.3.clone.1, %xor.121842.3.clone.1) + %add.248285.7.clone.1 = u32[1280,1280]{1,0} add(%add.248284.3.clone.1, %broadcast.251770.113.clone.1) + %shift-left.109563.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121842.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115760.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121842.3.clone.1, %broadcast.244418.4352) + %or.115296.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109563.5.clone.1, %shift-right-logical.115760.5.clone.1) + %xor.121843.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248284.3.clone.1, %or.115296.3.clone.1) + %constant_218237_1_clone_1 = u32[] constant(1327393516) + %broadcast.251782.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218237_1_clone_1), dimensions={} + %add.248286.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121843.3.clone.1, %broadcast.251782.5.clone.1) + %add.248287.5.clone.1 = u32[1280,1280]{1,0} add(%add.248285.7.clone.1, %add.248286.5.clone.1) + %shift-left.109564.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248286.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115761.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248286.5.clone.1, %broadcast.244416.5760) + %or.115297.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109564.9.clone.1, %shift-right-logical.115761.9.clone.1) + %xor.121844.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248287.5.clone.1, %or.115297.7.clone.1) + %add.248288.3.clone.1 = u32[1280,1280]{1,0} add(%add.248287.5.clone.1, %xor.121844.5.clone.1) + %shift-left.109565.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121844.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115763.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121844.5.clone.1, %broadcast.244429.2304) + %or.115298.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109565.9.clone.1, %shift-right-logical.115763.9.clone.1) + %xor.121845.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248288.3.clone.1, %or.115298.7.clone.1) + %add.248289.3.clone.1 = u32[1280,1280]{1,0} add(%add.248288.3.clone.1, %xor.121845.5.clone.1) + %shift-left.109566.9.clone.1 = 
u32[1280,1280]{1,0} shift-left(%xor.121845.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115764.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121845.5.clone.1, %broadcast.244430.4608) + %or.115299.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109566.9.clone.1, %shift-right-logical.115764.9.clone.1) + %xor.121847.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248289.3.clone.1, %or.115299.7.clone.1) + %add.248290.3.clone.1 = u32[1280,1280]{1,0} add(%add.248289.3.clone.1, %xor.121847.5.clone.1) + %constant_175175_1_clone_1 = u32[] constant(1327393515) + %broadcast.251789.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_175175_1_clone_1), dimensions={} + %add.248291.7.clone.1 = u32[1280,1280]{1,0} add(%add.248290.3.clone.1, %broadcast.251789.24.clone.1) + %shift-left.109567.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121847.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115765.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121847.5.clone.1, %broadcast.244434.2816) + %or.115300.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109567.11.clone.1, %shift-right-logical.115765.11.clone.1) + %xor.121848.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248290.3.clone.1, %or.115300.9.clone.1) + %constant_218238_1_clone_1 = u32[] constant(4012746163) + %broadcast.251792.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218238_1_clone_1), dimensions={} + %add.248292.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121848.7.clone.1, %broadcast.251792.5.clone.1) + %add.248293.5.clone.1 = u32[1280,1280]{1,0} add(%add.248291.7.clone.1, %add.248292.5.clone.1) + %shift-left.109568.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248292.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115766.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248292.5.clone.1, %broadcast.244415.6016) + %or.115301.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109568.9.clone.1, %shift-right-logical.115766.9.clone.1) + %xor.121849.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248293.5.clone.1, %or.115301.7.clone.1) + %add.248294.3.clone.1 = u32[1280,1280]{1,0} add(%add.248293.5.clone.1, %xor.121849.5.clone.1) + %shift-left.109570.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121849.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115767.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121849.5.clone.1, %broadcast.244417.5760) + %or.115302.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109570.9.clone.1, %shift-right-logical.115767.9.clone.1) + %xor.121850.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248294.3.clone.1, %or.115302.7.clone.1) + %add.248295.3.clone.1 = u32[1280,1280]{1,0} add(%add.248294.3.clone.1, %xor.121850.5.clone.1) + %shift-left.109571.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121850.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115768.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121850.5.clone.1, %broadcast.244419.4352) + %or.115303.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109571.7.clone.1, %shift-right-logical.115768.7.clone.1) + %xor.121851.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248295.3.clone.1, %or.115303.5.clone.1) + %add.248296.3.clone.1 = u32[1280,1280]{1,0} add(%add.248295.3.clone.1, %xor.121851.3.clone.1) + %add.248298.7.clone.1 = u32[1280,1280]{1,0} add(%add.248296.3.clone.1, %broadcast.251768.44.clone.1) + %shift-left.109572.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121851.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115769.7.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.121851.3.clone.1, %broadcast.244418.4352) + %or.115304.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109572.7.clone.1, %shift-right-logical.115769.7.clone.1) + %xor.121852.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248296.3.clone.1, %or.115304.5.clone.1) + %constant_218239_1_clone_1 = u32[] constant(3152204931) + %broadcast.251804.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218239_1_clone_1), dimensions={} + %add.248301.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121852.3.clone.1, %broadcast.251804.5.clone.1) + %add.248302.5.clone.1 = u32[1280,1280]{1,0} add(%add.248298.7.clone.1, %add.248301.5.clone.1) + %shift-left.109573.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248301.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115770.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248301.5.clone.1, %broadcast.244416.5760) + %or.115305.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109573.9.clone.1, %shift-right-logical.115770.9.clone.1) + %xor.121853.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248302.5.clone.1, %or.115305.7.clone.1) + %add.248303.3.clone.1 = u32[1280,1280]{1,0} add(%add.248302.5.clone.1, %xor.121853.5.clone.1) + %shift-left.109575.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121853.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115771.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121853.5.clone.1, %broadcast.244429.2304) + %or.115306.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109575.9.clone.1, %shift-right-logical.115771.9.clone.1) + %xor.121854.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248303.3.clone.1, %or.115306.7.clone.1) + %add.248304.3.clone.1 = u32[1280,1280]{1,0} add(%add.248303.3.clone.1, %xor.121854.5.clone.1) + %shift-left.109576.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121854.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115773.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121854.5.clone.1, %broadcast.244430.4608) + %or.115307.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109576.9.clone.1, %shift-right-logical.115773.9.clone.1) + %xor.121855.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248304.3.clone.1, %or.115307.7.clone.1) + %add.248306.3.clone.1 = u32[1280,1280]{1,0} add(%add.248304.3.clone.1, %xor.121855.5.clone.1) + %add.248307.7.clone.1 = u32[1280,1280]{1,0} add(%add.248306.3.clone.1, %broadcast.251770.113.clone.1) + %shift-left.109577.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121855.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115774.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121855.5.clone.1, %broadcast.244434.2816) + %or.115308.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109577.11.clone.1, %shift-right-logical.115774.11.clone.1) + %xor.121856.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248306.3.clone.1, %or.115308.9.clone.1) + %constant_218240_1_clone_1 = u32[] constant(1327393519) + %broadcast.251814.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218240_1_clone_1), dimensions={} + %add.248308.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121856.7.clone.1, %broadcast.251814.5.clone.1) + %add.248309.5.clone.1 = u32[1280,1280]{1,0} add(%add.248307.7.clone.1, %add.248308.5.clone.1) + %shift-left.109578.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248308.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115775.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248308.5.clone.1, %broadcast.244415.6016) + %or.115309.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109578.9.clone.1, 
%shift-right-logical.115775.9.clone.1) + %xor.121857.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248309.5.clone.1, %or.115309.7.clone.1) + %add.248311.3.clone.1 = u32[1280,1280]{1,0} add(%add.248309.5.clone.1, %xor.121857.5.clone.1) + %shift-left.109580.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121857.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115776.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121857.5.clone.1, %broadcast.244417.5760) + %or.115310.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109580.9.clone.1, %shift-right-logical.115776.9.clone.1) + %xor.121858.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248311.3.clone.1, %or.115310.7.clone.1) + %add.248312.3.clone.1 = u32[1280,1280]{1,0} add(%add.248311.3.clone.1, %xor.121858.5.clone.1) + %shift-left.109581.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121858.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115778.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121858.5.clone.1, %broadcast.244419.4352) + %or.115311.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109581.5.clone.1, %shift-right-logical.115778.5.clone.1) + %xor.121859.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248312.3.clone.1, %or.115311.3.clone.1) + %add.248313.3.clone.1 = u32[1280,1280]{1,0} add(%add.248312.3.clone.1, %xor.121859.3.clone.1) + %add.248314.17.clone.1 = u32[1280,1280]{1,0} add(%add.248313.3.clone.1, %broadcast.251789.24.clone.1) + %shift-left.109582.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121859.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115779.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121859.3.clone.1, %broadcast.244418.4352) + %or.115312.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109582.5.clone.1, %shift-right-logical.115779.5.clone.1) + %xor.121860.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248313.3.clone.1, %or.115312.3.clone.1) + %constant_218241_1_clone_1 = u32[] constant(4012746166) + %broadcast.251826.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218241_1_clone_1), dimensions={} + %add.248316.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121860.15.clone.1, %broadcast.251826.19.clone.1) + %xor.121861.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248314.17.clone.1, %add.248316.19.clone.1) + %shift-right-logical.115780.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121861.17.clone.1, %broadcast.244468.1920) + %or.115313.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115780.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5753.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115313.13.clone.1) + %add.248317.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5753.11.clone.1, %broadcast.244470.1152) + %multiply.26442.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248317.9.clone.1, %broadcast.244471.896) + %add.248318.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26442.7.clone.1, %broadcast.244408.1024) + %maximum.3685.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248318.5.clone.1) + %abs.1539.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3685.3.clone.1) + %compare.7226.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1539.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26443.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3685.3.clone.1, %broadcast.244476.1152) + %negate.4583.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3685.3.clone.1) + %multiply.26444.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3685.3.clone.1, %negate.4583.5.clone.1) + %log-plus-one.1539.3.clone.1 = f32[1280,1280]{1,0} 
log-plus-one(%multiply.26444.5.clone.1) + %negate.4584.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1539.3.clone.1) + %compare.7227.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4584.4.clone.1, %broadcast.244477.384), direction=LT + %select.21005.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21006.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21007.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21008.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21009.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21010.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21011.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21013.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21018.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248319.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4584.4.clone.1, %broadcast.244496.640) + %sqrt.1539.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4584.4.clone.1) + %add.248320.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1539.5.clone.1, %broadcast.244498.640) + %select.21019.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7227.3.clone.1, %add.248319.5.clone.1, %add.248320.5.clone.1) + %multiply.26445.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21018.3.clone.1, %select.21019.3.clone.1) + %add.248322.1.clone.1 = f32[1280,1280]{1,0} add(%select.21013.3.clone.1, %multiply.26445.1.clone.1) + %multiply.26446.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248322.1.clone.1, %select.21019.3.clone.1) + %add.248326.1.clone.1 = f32[1280,1280]{1,0} add(%select.21011.3.clone.1, %multiply.26446.1.clone.1) + %multiply.26447.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248326.1.clone.1, %select.21019.3.clone.1) + %add.248327.1.clone.1 = f32[1280,1280]{1,0} add(%select.21010.3.clone.1, %multiply.26447.1.clone.1) + %multiply.26448.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248327.1.clone.1, %select.21019.3.clone.1) + %add.248328.1.clone.1 = f32[1280,1280]{1,0} add(%select.21009.3.clone.1, %multiply.26448.1.clone.1) + %multiply.26449.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248328.1.clone.1, %select.21019.3.clone.1) + %add.248329.3.clone.1 = f32[1280,1280]{1,0} add(%select.21008.5.clone.1, %multiply.26449.1.clone.1) + %multiply.26450.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248329.3.clone.1, %select.21019.3.clone.1) + %add.248331.3.clone.1 = f32[1280,1280]{1,0} add(%select.21007.5.clone.1, %multiply.26450.1.clone.1) + %multiply.26451.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248331.3.clone.1, %select.21019.3.clone.1) + %add.248332.9.clone.1 = f32[1280,1280]{1,0} add(%select.21006.11.clone.1, %multiply.26451.7.clone.1) + %multiply.26452.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248332.9.clone.1, %select.21019.3.clone.1) + %add.248333.7.clone.1 = f32[1280,1280]{1,0} add(%select.21005.7.clone.1, %multiply.26452.7.clone.1) + %multiply.26453.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248333.7.clone.1, %maximum.3685.3.clone.1) + 
%select.21020.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7226.3.clone.1, %multiply.26443.9.clone.1, %multiply.26453.7.clone.1) + %multiply.26454.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21020.7.clone.1, %broadcast.244500.640) + %clamp.1183.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26454.5.clone.1, %broadcast.244501.384) + %multiply.26455.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1183.3.clone.1, %broadcast.244502.1) + %constant_186144_1_clone_1 = u32[] constant(2626188011) + %broadcast.256535.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_186144_1_clone_1), dimensions={} + %add.250981.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.256535.44.clone.1) + %constant_186151_1_clone_1 = u32[] constant(3739757680) + %broadcast.256536.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_186151_1_clone_1), dimensions={} + %add.250982.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.256536.113.clone.1) + %add.250983.35.clone.1 = u32[1280,1280]{1,0} add(%add.250981.37.clone.1, %add.250982.99.clone.1) + %shift-left.110750.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250982.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117013.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250982.99.clone.1, %broadcast.244415.6016) + %or.116526.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110750.31.clone.1, %shift-right-logical.117013.29.clone.1) + %xor.123102.27.clone.1 = u32[1280,1280]{1,0} xor(%add.250983.35.clone.1, %or.116526.29.clone.1) + %add.250984.5.clone.1 = u32[1280,1280]{1,0} add(%add.250983.35.clone.1, %xor.123102.27.clone.1) + %shift-left.110751.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123102.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117014.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123102.27.clone.1, %broadcast.244417.5760) + %or.116527.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110751.9.clone.1, %shift-right-logical.117014.9.clone.1) + %xor.123104.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250984.5.clone.1, %or.116527.7.clone.1) + %add.250985.3.clone.1 = u32[1280,1280]{1,0} add(%add.250984.5.clone.1, %xor.123104.5.clone.1) + %shift-left.110752.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123104.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117015.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123104.5.clone.1, %broadcast.244419.4352) + %or.116528.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110752.5.clone.1, %shift-right-logical.117015.5.clone.1) + %xor.123105.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250985.3.clone.1, %or.116528.3.clone.1) + %add.250986.3.clone.1 = u32[1280,1280]{1,0} add(%add.250985.3.clone.1, %xor.123105.3.clone.1) + %add.250987.7.clone.1 = u32[1280,1280]{1,0} add(%add.250986.3.clone.1, %broadcast.256536.113.clone.1) + %shift-left.110753.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123105.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117016.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123105.3.clone.1, %broadcast.244418.4352) + %or.116529.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110753.5.clone.1, %shift-right-logical.117016.5.clone.1) + %xor.123106.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250986.3.clone.1, %or.116529.3.clone.1) + %constant_218526_1_clone_1 = u32[] constant(1504789826) + %broadcast.256546.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218526_1_clone_1), dimensions={} + %add.250988.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123106.3.clone.1, 
%broadcast.256546.5.clone.1) + %add.250989.5.clone.1 = u32[1280,1280]{1,0} add(%add.250987.7.clone.1, %add.250988.5.clone.1) + %shift-left.110755.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250988.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117017.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250988.5.clone.1, %broadcast.244416.5760) + %or.116531.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110755.9.clone.1, %shift-right-logical.117017.9.clone.1) + %xor.123107.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250989.5.clone.1, %or.116531.7.clone.1) + %add.250990.3.clone.1 = u32[1280,1280]{1,0} add(%add.250989.5.clone.1, %xor.123107.5.clone.1) + %shift-left.110756.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123107.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117018.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123107.5.clone.1, %broadcast.244429.2304) + %or.116532.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110756.9.clone.1, %shift-right-logical.117018.9.clone.1) + %xor.123109.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250990.3.clone.1, %or.116532.7.clone.1) + %add.250991.3.clone.1 = u32[1280,1280]{1,0} add(%add.250990.3.clone.1, %xor.123109.5.clone.1) + %shift-left.110757.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123109.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117019.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123109.5.clone.1, %broadcast.244430.4608) + %or.116533.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110757.9.clone.1, %shift-right-logical.117019.9.clone.1) + %xor.123110.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250991.3.clone.1, %or.116533.7.clone.1) + %add.250993.3.clone.1 = u32[1280,1280]{1,0} add(%add.250991.3.clone.1, %xor.123110.5.clone.1) + %constant_186153_1_clone_1 = u32[] constant(1504789825) + %broadcast.256555.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_186153_1_clone_1), dimensions={} + %add.250996.7.clone.1 = u32[1280,1280]{1,0} add(%add.250993.3.clone.1, %broadcast.256555.24.clone.1) + %shift-left.110758.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123110.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117020.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123110.5.clone.1, %broadcast.244434.2816) + %or.116534.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110758.11.clone.1, %shift-right-logical.117020.11.clone.1) + %xor.123111.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250993.3.clone.1, %or.116534.9.clone.1) + %constant_218528_1_clone_1 = u32[] constant(2626188013) + %broadcast.256558.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218528_1_clone_1), dimensions={} + %add.250997.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123111.7.clone.1, %broadcast.256558.5.clone.1) + %add.250998.5.clone.1 = u32[1280,1280]{1,0} add(%add.250996.7.clone.1, %add.250997.5.clone.1) + %shift-left.110760.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250997.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117021.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250997.5.clone.1, %broadcast.244415.6016) + %or.116536.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110760.9.clone.1, %shift-right-logical.117021.9.clone.1) + %xor.123112.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250998.5.clone.1, %or.116536.7.clone.1) + %add.250999.3.clone.1 = u32[1280,1280]{1,0} add(%add.250998.5.clone.1, %xor.123112.5.clone.1) + %shift-left.110761.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123112.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117022.9.clone.1 = 
u32[1280,1280]{1,0} shift-right-logical(%xor.123112.5.clone.1, %broadcast.244417.5760) + %or.116537.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110761.9.clone.1, %shift-right-logical.117022.9.clone.1) + %xor.123114.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250999.3.clone.1, %or.116537.7.clone.1) + %add.251001.3.clone.1 = u32[1280,1280]{1,0} add(%add.250999.3.clone.1, %xor.123114.5.clone.1) + %shift-left.110762.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123114.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117023.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123114.5.clone.1, %broadcast.244419.4352) + %or.116538.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110762.7.clone.1, %shift-right-logical.117023.7.clone.1) + %xor.123115.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251001.3.clone.1, %or.116538.5.clone.1) + %add.251002.3.clone.1 = u32[1280,1280]{1,0} add(%add.251001.3.clone.1, %xor.123115.3.clone.1) + %add.251003.7.clone.1 = u32[1280,1280]{1,0} add(%add.251002.3.clone.1, %broadcast.256535.44.clone.1) + %shift-left.110763.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123115.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117024.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123115.3.clone.1, %broadcast.244418.4352) + %or.116539.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110763.7.clone.1, %shift-right-logical.117024.7.clone.1) + %xor.123116.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251002.3.clone.1, %or.116539.5.clone.1) + %constant_218530_1_clone_1 = u32[] constant(3739757683) + %broadcast.256573.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218530_1_clone_1), dimensions={} + %add.251004.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123116.3.clone.1, %broadcast.256573.5.clone.1) + %add.251006.5.clone.1 = u32[1280,1280]{1,0} add(%add.251003.7.clone.1, %add.251004.5.clone.1) + %shift-left.110764.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251004.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117025.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251004.5.clone.1, %broadcast.244416.5760) + %or.116541.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110764.9.clone.1, %shift-right-logical.117025.9.clone.1) + %xor.123117.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251006.5.clone.1, %or.116541.7.clone.1) + %add.251007.3.clone.1 = u32[1280,1280]{1,0} add(%add.251006.5.clone.1, %xor.123117.5.clone.1) + %shift-left.110765.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123117.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117026.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123117.5.clone.1, %broadcast.244429.2304) + %or.116542.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110765.9.clone.1, %shift-right-logical.117026.9.clone.1) + %xor.123118.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251007.3.clone.1, %or.116542.7.clone.1) + %add.251008.3.clone.1 = u32[1280,1280]{1,0} add(%add.251007.3.clone.1, %xor.123118.5.clone.1) + %shift-left.110766.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123118.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117027.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123118.5.clone.1, %broadcast.244430.4608) + %or.116543.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110766.9.clone.1, %shift-right-logical.117027.9.clone.1) + %xor.123119.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251008.3.clone.1, %or.116543.7.clone.1) + %add.251009.3.clone.1 = u32[1280,1280]{1,0} add(%add.251008.3.clone.1, %xor.123119.5.clone.1) + %add.251011.7.clone.1 = u32[1280,1280]{1,0} 
add(%add.251009.3.clone.1, %broadcast.256536.113.clone.1) + %shift-left.110767.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123119.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117028.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123119.5.clone.1, %broadcast.244434.2816) + %or.116544.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110767.11.clone.1, %shift-right-logical.117028.11.clone.1) + %xor.123120.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251009.3.clone.1, %or.116544.9.clone.1) + %constant_218532_1_clone_1 = u32[] constant(1504789829) + %broadcast.256585.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218532_1_clone_1), dimensions={} + %add.251012.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123120.7.clone.1, %broadcast.256585.5.clone.1) + %add.251013.5.clone.1 = u32[1280,1280]{1,0} add(%add.251011.7.clone.1, %add.251012.5.clone.1) + %shift-left.110768.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251012.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117029.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251012.5.clone.1, %broadcast.244415.6016) + %or.116546.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110768.9.clone.1, %shift-right-logical.117029.9.clone.1) + %xor.123121.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251013.5.clone.1, %or.116546.7.clone.1) + %add.251014.3.clone.1 = u32[1280,1280]{1,0} add(%add.251013.5.clone.1, %xor.123121.5.clone.1) + %shift-left.110770.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123121.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117030.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123121.5.clone.1, %broadcast.244417.5760) + %or.116547.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110770.9.clone.1, %shift-right-logical.117030.9.clone.1) + %xor.123122.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251014.3.clone.1, %or.116547.7.clone.1) + %add.251015.3.clone.1 = u32[1280,1280]{1,0} add(%add.251014.3.clone.1, %xor.123122.5.clone.1) + %shift-left.110771.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123122.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117031.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123122.5.clone.1, %broadcast.244419.4352) + %or.116548.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110771.5.clone.1, %shift-right-logical.117031.5.clone.1) + %xor.123124.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251015.3.clone.1, %or.116548.3.clone.1) + %add.251017.3.clone.1 = u32[1280,1280]{1,0} add(%add.251015.3.clone.1, %xor.123124.3.clone.1) + %add.251021.17.clone.1 = u32[1280,1280]{1,0} add(%add.251017.3.clone.1, %broadcast.256555.24.clone.1) + %shift-left.110772.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123124.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117032.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123124.3.clone.1, %broadcast.244418.4352) + %or.116549.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110772.5.clone.1, %shift-right-logical.117032.5.clone.1) + %xor.123125.15.clone.1 = u32[1280,1280]{1,0} xor(%add.251017.3.clone.1, %or.116549.3.clone.1) + %constant_218533_1_clone_1 = u32[] constant(2626188016) + %broadcast.256596.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218533_1_clone_1), dimensions={} + %add.251022.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123125.15.clone.1, %broadcast.256596.19.clone.1) + %xor.123126.17.clone.1 = u32[1280,1280]{1,0} xor(%add.251021.17.clone.1, %add.251022.19.clone.1) + %shift-right-logical.117033.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123126.17.clone.1, 
%broadcast.244468.1920) + %or.116550.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117033.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5807.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116550.13.clone.1) + %add.251023.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5807.11.clone.1, %broadcast.244470.1152) + %multiply.26997.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251023.9.clone.1, %broadcast.244471.896) + %add.251024.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26997.7.clone.1, %broadcast.244408.1024) + %maximum.3739.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.251024.5.clone.1) + %abs.1575.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3739.3.clone.1) + %compare.7307.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1575.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26998.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3739.3.clone.1, %broadcast.244476.1152) + %negate.4655.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3739.3.clone.1) + %multiply.26999.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3739.3.clone.1, %negate.4655.5.clone.1) + %log-plus-one.1575.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26999.5.clone.1) + %negate.4656.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1575.3.clone.1) + %compare.7308.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4656.4.clone.1, %broadcast.244477.384), direction=LT + %select.21411.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21412.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21413.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21414.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21415.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21416.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21417.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21418.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21419.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.251026.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4656.4.clone.1, %broadcast.244496.640) + %sqrt.1575.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4656.4.clone.1) + %add.251027.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1575.5.clone.1, %broadcast.244498.640) + %select.21420.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7308.3.clone.1, %add.251026.5.clone.1, %add.251027.5.clone.1) + %multiply.27000.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21419.3.clone.1, %select.21420.3.clone.1) + %add.251028.1.clone.1 = f32[1280,1280]{1,0} add(%select.21418.3.clone.1, %multiply.27000.1.clone.1) + %multiply.27001.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251028.1.clone.1, %select.21420.3.clone.1) + %add.251029.1.clone.1 = f32[1280,1280]{1,0} add(%select.21417.3.clone.1, %multiply.27001.1.clone.1) + %multiply.27002.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251029.1.clone.1, %select.21420.3.clone.1) + %add.251031.1.clone.1 = f32[1280,1280]{1,0} 
add(%select.21416.3.clone.1, %multiply.27002.1.clone.1) + %multiply.27003.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251031.1.clone.1, %select.21420.3.clone.1) + %add.251032.1.clone.1 = f32[1280,1280]{1,0} add(%select.21415.3.clone.1, %multiply.27003.1.clone.1) + %multiply.27004.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251032.1.clone.1, %select.21420.3.clone.1) + %add.251033.3.clone.1 = f32[1280,1280]{1,0} add(%select.21414.5.clone.1, %multiply.27004.1.clone.1) + %multiply.27005.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251033.3.clone.1, %select.21420.3.clone.1) + %add.251034.3.clone.1 = f32[1280,1280]{1,0} add(%select.21413.5.clone.1, %multiply.27005.1.clone.1) + %multiply.27006.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251034.3.clone.1, %select.21420.3.clone.1) + %add.251036.9.clone.1 = f32[1280,1280]{1,0} add(%select.21412.11.clone.1, %multiply.27006.7.clone.1) + %multiply.27007.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251036.9.clone.1, %select.21420.3.clone.1) + %add.251037.7.clone.1 = f32[1280,1280]{1,0} add(%select.21411.7.clone.1, %multiply.27007.7.clone.1) + %multiply.27008.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251037.7.clone.1, %maximum.3739.3.clone.1) + %select.21421.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7307.3.clone.1, %multiply.26998.9.clone.1, %multiply.27008.7.clone.1) + %multiply.27009.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21421.7.clone.1, %broadcast.244500.640) + %clamp.1219.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27009.5.clone.1, %broadcast.244501.384) + %multiply.27010.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1219.3.clone.1, %broadcast.244502.1) + %constant_174615_1_clone_1 = u32[] constant(385371153) + %broadcast.251538.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_174615_1_clone_1), dimensions={} + %add.248145.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.251538.44.clone.1) + %constant_174622_1_clone_1 = u32[] constant(4274959127) + %broadcast.251539.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_174622_1_clone_1), dimensions={} + %add.248146.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.251539.113.clone.1) + %add.248147.35.clone.1 = u32[1280,1280]{1,0} add(%add.248145.37.clone.1, %add.248146.99.clone.1) + %shift-left.109500.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248146.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115687.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248146.99.clone.1, %broadcast.244415.6016) + %or.115218.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109500.31.clone.1, %shift-right-logical.115687.29.clone.1) + %xor.121770.27.clone.1 = u32[1280,1280]{1,0} xor(%add.248147.35.clone.1, %or.115218.29.clone.1) + %add.248149.5.clone.1 = u32[1280,1280]{1,0} add(%add.248147.35.clone.1, %xor.121770.27.clone.1) + %shift-left.109501.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121770.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115688.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121770.27.clone.1, %broadcast.244417.5760) + %or.115219.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109501.9.clone.1, %shift-right-logical.115688.9.clone.1) + %xor.121771.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248149.5.clone.1, %or.115219.7.clone.1) + %add.248150.3.clone.1 = u32[1280,1280]{1,0} add(%add.248149.5.clone.1, %xor.121771.5.clone.1) + %shift-left.109502.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121771.5.clone.1, %broadcast.244418.4352) + 
%shift-right-logical.115689.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121771.5.clone.1, %broadcast.244419.4352) + %or.115220.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109502.5.clone.1, %shift-right-logical.115689.5.clone.1) + %xor.121773.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248150.3.clone.1, %or.115220.3.clone.1) + %add.248151.3.clone.1 = u32[1280,1280]{1,0} add(%add.248150.3.clone.1, %xor.121773.3.clone.1) + %add.248152.7.clone.1 = u32[1280,1280]{1,0} add(%add.248151.3.clone.1, %broadcast.251539.113.clone.1) + %shift-left.109503.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121773.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115690.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121773.3.clone.1, %broadcast.244418.4352) + %or.115221.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109503.5.clone.1, %shift-right-logical.115690.5.clone.1) + %xor.121774.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248151.3.clone.1, %or.115221.3.clone.1) + %constant_218222_1_clone_1 = u32[] constant(4092060893) + %broadcast.251549.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218222_1_clone_1), dimensions={} + %add.248154.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121774.3.clone.1, %broadcast.251549.5.clone.1) + %add.248155.5.clone.1 = u32[1280,1280]{1,0} add(%add.248152.7.clone.1, %add.248154.5.clone.1) + %shift-left.109504.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248154.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115691.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248154.5.clone.1, %broadcast.244416.5760) + %or.115223.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109504.9.clone.1, %shift-right-logical.115691.9.clone.1) + %xor.121775.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248155.5.clone.1, %or.115223.7.clone.1) + %add.248156.3.clone.1 = u32[1280,1280]{1,0} add(%add.248155.5.clone.1, %xor.121775.5.clone.1) + %shift-left.109505.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121775.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115692.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121775.5.clone.1, %broadcast.244429.2304) + %or.115224.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109505.9.clone.1, %shift-right-logical.115692.9.clone.1) + %xor.121776.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248156.3.clone.1, %or.115224.7.clone.1) + %add.248157.3.clone.1 = u32[1280,1280]{1,0} add(%add.248156.3.clone.1, %xor.121776.5.clone.1) + %shift-left.109506.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121776.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115693.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121776.5.clone.1, %broadcast.244430.4608) + %or.115225.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109506.9.clone.1, %shift-right-logical.115693.9.clone.1) + %xor.121778.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248157.3.clone.1, %or.115225.7.clone.1) + %add.248159.3.clone.1 = u32[1280,1280]{1,0} add(%add.248157.3.clone.1, %xor.121778.5.clone.1) + %constant_174624_1_clone_1 = u32[] constant(4092060892) + %broadcast.251562.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_174624_1_clone_1), dimensions={} + %add.248160.7.clone.1 = u32[1280,1280]{1,0} add(%add.248159.3.clone.1, %broadcast.251562.24.clone.1) + %shift-left.109507.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121778.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115694.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121778.5.clone.1, %broadcast.244434.2816) + %or.115226.9.clone.1 = u32[1280,1280]{1,0} 
or(%shift-left.109507.11.clone.1, %shift-right-logical.115694.11.clone.1) + %xor.121779.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248159.3.clone.1, %or.115226.9.clone.1) + %constant_218223_1_clone_1 = u32[] constant(385371155) + %broadcast.251565.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218223_1_clone_1), dimensions={} + %add.248161.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121779.7.clone.1, %broadcast.251565.5.clone.1) + %add.248162.5.clone.1 = u32[1280,1280]{1,0} add(%add.248160.7.clone.1, %add.248161.5.clone.1) + %shift-left.109508.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248161.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115695.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248161.5.clone.1, %broadcast.244415.6016) + %or.115228.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109508.9.clone.1, %shift-right-logical.115695.9.clone.1) + %xor.121780.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248162.5.clone.1, %or.115228.7.clone.1) + %add.248163.3.clone.1 = u32[1280,1280]{1,0} add(%add.248162.5.clone.1, %xor.121780.5.clone.1) + %shift-left.109509.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121780.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115696.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121780.5.clone.1, %broadcast.244417.5760) + %or.115229.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109509.9.clone.1, %shift-right-logical.115696.9.clone.1) + %xor.121781.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248163.3.clone.1, %or.115229.7.clone.1) + %add.248165.3.clone.1 = u32[1280,1280]{1,0} add(%add.248163.3.clone.1, %xor.121781.5.clone.1) + %shift-left.109510.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121781.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115697.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121781.5.clone.1, %broadcast.244419.4352) + %or.115230.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109510.7.clone.1, %shift-right-logical.115697.7.clone.1) + %xor.121783.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248165.3.clone.1, %or.115230.5.clone.1) + %add.248169.3.clone.1 = u32[1280,1280]{1,0} add(%add.248165.3.clone.1, %xor.121783.3.clone.1) + %add.248170.7.clone.1 = u32[1280,1280]{1,0} add(%add.248169.3.clone.1, %broadcast.251538.44.clone.1) + %shift-left.109511.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121783.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115698.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121783.3.clone.1, %broadcast.244418.4352) + %or.115231.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109511.7.clone.1, %shift-right-logical.115698.7.clone.1) + %xor.121784.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248169.3.clone.1, %or.115231.5.clone.1) + %constant_218224_1_clone_1 = u32[] constant(4274959130) + %broadcast.251575.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218224_1_clone_1), dimensions={} + %add.248171.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121784.3.clone.1, %broadcast.251575.5.clone.1) + %add.248172.5.clone.1 = u32[1280,1280]{1,0} add(%add.248170.7.clone.1, %add.248171.5.clone.1) + %shift-left.109512.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248171.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115699.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248171.5.clone.1, %broadcast.244416.5760) + %or.115233.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109512.9.clone.1, %shift-right-logical.115699.9.clone.1) + %xor.121785.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248172.5.clone.1, %or.115233.7.clone.1) + 
%add.248174.3.clone.1 = u32[1280,1280]{1,0} add(%add.248172.5.clone.1, %xor.121785.5.clone.1) + %shift-left.109513.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121785.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115700.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121785.5.clone.1, %broadcast.244429.2304) + %or.115234.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109513.9.clone.1, %shift-right-logical.115700.9.clone.1) + %xor.121786.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248174.3.clone.1, %or.115234.7.clone.1) + %add.248175.3.clone.1 = u32[1280,1280]{1,0} add(%add.248174.3.clone.1, %xor.121786.5.clone.1) + %shift-left.109514.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121786.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115701.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121786.5.clone.1, %broadcast.244430.4608) + %or.115235.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109514.9.clone.1, %shift-right-logical.115701.9.clone.1) + %xor.121787.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248175.3.clone.1, %or.115235.7.clone.1) + %add.248176.3.clone.1 = u32[1280,1280]{1,0} add(%add.248175.3.clone.1, %xor.121787.5.clone.1) + %add.248177.7.clone.1 = u32[1280,1280]{1,0} add(%add.248176.3.clone.1, %broadcast.251539.113.clone.1) + %shift-left.109515.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121787.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115702.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121787.5.clone.1, %broadcast.244434.2816) + %or.115236.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109515.11.clone.1, %shift-right-logical.115702.11.clone.1) + %xor.121788.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248176.3.clone.1, %or.115236.9.clone.1) + %constant_218225_1_clone_1 = u32[] constant(4092060896) + %broadcast.251585.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218225_1_clone_1), dimensions={} + %add.248179.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121788.7.clone.1, %broadcast.251585.5.clone.1) + %add.248180.5.clone.1 = u32[1280,1280]{1,0} add(%add.248177.7.clone.1, %add.248179.5.clone.1) + %shift-left.109516.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248179.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115703.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248179.5.clone.1, %broadcast.244415.6016) + %or.115238.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109516.9.clone.1, %shift-right-logical.115703.9.clone.1) + %xor.121789.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248180.5.clone.1, %or.115238.7.clone.1) + %add.248181.3.clone.1 = u32[1280,1280]{1,0} add(%add.248180.5.clone.1, %xor.121789.5.clone.1) + %shift-left.109517.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121789.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115704.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121789.5.clone.1, %broadcast.244417.5760) + %or.115239.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109517.9.clone.1, %shift-right-logical.115704.9.clone.1) + %xor.121790.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248181.3.clone.1, %or.115239.7.clone.1) + %add.248182.3.clone.1 = u32[1280,1280]{1,0} add(%add.248181.3.clone.1, %xor.121790.5.clone.1) + %shift-left.109518.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121790.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115705.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121790.5.clone.1, %broadcast.244419.4352) + %or.115240.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109518.5.clone.1, 
%shift-right-logical.115705.5.clone.1) + %xor.121791.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248182.3.clone.1, %or.115240.3.clone.1) + %add.248184.3.clone.1 = u32[1280,1280]{1,0} add(%add.248182.3.clone.1, %xor.121791.3.clone.1) + %add.248185.17.clone.1 = u32[1280,1280]{1,0} add(%add.248184.3.clone.1, %broadcast.251562.24.clone.1) + %shift-left.109519.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121791.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115706.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121791.3.clone.1, %broadcast.244418.4352) + %or.115241.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109519.5.clone.1, %shift-right-logical.115706.5.clone.1) + %xor.121793.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248184.3.clone.1, %or.115241.3.clone.1) + %constant_218226_1_clone_1 = u32[] constant(385371158) + %broadcast.251595.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218226_1_clone_1), dimensions={} + %add.248186.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121793.15.clone.1, %broadcast.251595.19.clone.1) + %xor.121794.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248185.17.clone.1, %add.248186.19.clone.1) + %shift-right-logical.115707.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121794.17.clone.1, %broadcast.244468.1920) + %or.115242.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115707.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5750.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115242.13.clone.1) + %add.248187.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5750.11.clone.1, %broadcast.244470.1152) + %multiply.26424.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248187.9.clone.1, %broadcast.244471.896) + %add.248188.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26424.7.clone.1, %broadcast.244408.1024) + %maximum.3682.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248188.5.clone.1) + %abs.1538.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3682.3.clone.1) + %compare.7224.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1538.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26425.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3682.3.clone.1, %broadcast.244476.1152) + %negate.4581.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3682.3.clone.1) + %multiply.26426.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3682.3.clone.1, %negate.4581.5.clone.1) + %log-plus-one.1538.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26426.5.clone.1) + %negate.4582.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1538.3.clone.1) + %compare.7225.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4582.4.clone.1, %broadcast.244477.384), direction=LT + %select.20993.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20994.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20996.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20997.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20998.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20999.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21000.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %broadcast.244490.384, 
%broadcast.244491.384) + %select.21001.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21002.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248190.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4582.4.clone.1, %broadcast.244496.640) + %sqrt.1538.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4582.4.clone.1) + %add.248194.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1538.5.clone.1, %broadcast.244498.640) + %select.21003.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7225.3.clone.1, %add.248190.5.clone.1, %add.248194.5.clone.1) + %multiply.26427.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21002.3.clone.1, %select.21003.3.clone.1) + %add.248195.1.clone.1 = f32[1280,1280]{1,0} add(%select.21001.3.clone.1, %multiply.26427.1.clone.1) + %multiply.26428.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248195.1.clone.1, %select.21003.3.clone.1) + %add.248196.1.clone.1 = f32[1280,1280]{1,0} add(%select.21000.3.clone.1, %multiply.26428.1.clone.1) + %multiply.26429.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248196.1.clone.1, %select.21003.3.clone.1) + %add.248197.1.clone.1 = f32[1280,1280]{1,0} add(%select.20999.3.clone.1, %multiply.26429.1.clone.1) + %multiply.26430.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248197.1.clone.1, %select.21003.3.clone.1) + %add.248199.1.clone.1 = f32[1280,1280]{1,0} add(%select.20998.3.clone.1, %multiply.26430.1.clone.1) + %multiply.26431.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248199.1.clone.1, %select.21003.3.clone.1) + %add.248200.3.clone.1 = f32[1280,1280]{1,0} add(%select.20997.5.clone.1, %multiply.26431.1.clone.1) + %multiply.26432.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248200.3.clone.1, %select.21003.3.clone.1) + %add.248201.3.clone.1 = f32[1280,1280]{1,0} add(%select.20996.5.clone.1, %multiply.26432.1.clone.1) + %multiply.26433.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248201.3.clone.1, %select.21003.3.clone.1) + %add.248202.9.clone.1 = f32[1280,1280]{1,0} add(%select.20994.11.clone.1, %multiply.26433.7.clone.1) + %multiply.26434.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248202.9.clone.1, %select.21003.3.clone.1) + %add.248204.7.clone.1 = f32[1280,1280]{1,0} add(%select.20993.7.clone.1, %multiply.26434.7.clone.1) + %multiply.26435.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248204.7.clone.1, %maximum.3682.3.clone.1) + %select.21004.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7224.3.clone.1, %multiply.26425.9.clone.1, %multiply.26435.7.clone.1) + %multiply.26436.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21004.7.clone.1, %broadcast.244500.640) + %clamp.1182.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26436.5.clone.1, %broadcast.244501.384) + %multiply.26437.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1182.3.clone.1, %broadcast.244502.1) + %constant_191448_1_clone_1 = u32[] constant(2491264431) + %broadcast.258815.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_191448_1_clone_1), dimensions={} + %add.252310.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.258815.44.clone.1) + %constant_191457_1_clone_1 = u32[] constant(3255467134) + %broadcast.258816.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_191457_1_clone_1), dimensions={} + %add.252311.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.258816.113.clone.1) + %add.252312.35.clone.1 = u32[1280,1280]{1,0} add(%add.252310.37.clone.1, %add.252311.99.clone.1) + 
%shift-left.111320.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252311.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117602.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252311.99.clone.1, %broadcast.244415.6016) + %or.117140.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111320.31.clone.1, %shift-right-logical.117602.29.clone.1) + %xor.123700.27.clone.1 = u32[1280,1280]{1,0} xor(%add.252312.35.clone.1, %or.117140.29.clone.1) + %add.252313.5.clone.1 = u32[1280,1280]{1,0} add(%add.252312.35.clone.1, %xor.123700.27.clone.1) + %shift-left.111321.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123700.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117603.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123700.27.clone.1, %broadcast.244417.5760) + %or.117141.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111321.9.clone.1, %shift-right-logical.117603.9.clone.1) + %xor.123701.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252313.5.clone.1, %or.117141.7.clone.1) + %add.252314.3.clone.1 = u32[1280,1280]{1,0} add(%add.252313.5.clone.1, %xor.123701.5.clone.1) + %shift-left.111322.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123701.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117605.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123701.5.clone.1, %broadcast.244419.4352) + %or.117142.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111322.5.clone.1, %shift-right-logical.117605.5.clone.1) + %xor.123702.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252314.3.clone.1, %or.117142.3.clone.1) + %add.252315.3.clone.1 = u32[1280,1280]{1,0} add(%add.252314.3.clone.1, %xor.123702.3.clone.1) + %add.252316.7.clone.1 = u32[1280,1280]{1,0} add(%add.252315.3.clone.1, %broadcast.258816.113.clone.1) + %shift-left.111323.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123702.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117606.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123702.3.clone.1, %broadcast.244418.4352) + %or.117143.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111323.5.clone.1, %shift-right-logical.117606.5.clone.1) + %xor.123703.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252315.3.clone.1, %or.117143.3.clone.1) + %constant_218675_1_clone_1 = u32[] constant(1302776332) + %broadcast.258827.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218675_1_clone_1), dimensions={} + %add.252317.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123703.3.clone.1, %broadcast.258827.5.clone.1) + %add.252318.5.clone.1 = u32[1280,1280]{1,0} add(%add.252316.7.clone.1, %add.252317.5.clone.1) + %shift-left.111324.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252317.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117607.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252317.5.clone.1, %broadcast.244416.5760) + %or.117145.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111324.9.clone.1, %shift-right-logical.117607.9.clone.1) + %xor.123704.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252318.5.clone.1, %or.117145.7.clone.1) + %add.252319.3.clone.1 = u32[1280,1280]{1,0} add(%add.252318.5.clone.1, %xor.123704.5.clone.1) + %shift-left.111325.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123704.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117608.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123704.5.clone.1, %broadcast.244429.2304) + %or.117146.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111325.9.clone.1, %shift-right-logical.117608.9.clone.1) + %xor.123705.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252319.3.clone.1, 
%or.117146.7.clone.1) + %add.252321.3.clone.1 = u32[1280,1280]{1,0} add(%add.252319.3.clone.1, %xor.123705.5.clone.1) + %shift-left.111326.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123705.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117610.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123705.5.clone.1, %broadcast.244430.4608) + %or.117147.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111326.9.clone.1, %shift-right-logical.117610.9.clone.1) + %xor.123706.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252321.3.clone.1, %or.117147.7.clone.1) + %add.252322.3.clone.1 = u32[1280,1280]{1,0} add(%add.252321.3.clone.1, %xor.123706.5.clone.1) + %constant_191459_1_clone_1 = u32[] constant(1302776331) + %broadcast.258841.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_191459_1_clone_1), dimensions={} + %add.252323.7.clone.1 = u32[1280,1280]{1,0} add(%add.252322.3.clone.1, %broadcast.258841.24.clone.1) + %shift-left.111327.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123706.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117611.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123706.5.clone.1, %broadcast.244434.2816) + %or.117148.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111327.11.clone.1, %shift-right-logical.117611.11.clone.1) + %xor.123707.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252322.3.clone.1, %or.117148.9.clone.1) + %constant_218676_1_clone_1 = u32[] constant(2491264433) + %broadcast.258847.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218676_1_clone_1), dimensions={} + %add.252324.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123707.7.clone.1, %broadcast.258847.5.clone.1) + %add.252325.5.clone.1 = u32[1280,1280]{1,0} add(%add.252323.7.clone.1, %add.252324.5.clone.1) + %shift-left.111328.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252324.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117612.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252324.5.clone.1, %broadcast.244415.6016) + %or.117149.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111328.9.clone.1, %shift-right-logical.117612.9.clone.1) + %xor.123708.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252325.5.clone.1, %or.117149.7.clone.1) + %add.252326.3.clone.1 = u32[1280,1280]{1,0} add(%add.252325.5.clone.1, %xor.123708.5.clone.1) + %shift-left.111329.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123708.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117613.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123708.5.clone.1, %broadcast.244417.5760) + %or.117150.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111329.9.clone.1, %shift-right-logical.117613.9.clone.1) + %xor.123709.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252326.3.clone.1, %or.117150.7.clone.1) + %add.252327.3.clone.1 = u32[1280,1280]{1,0} add(%add.252326.3.clone.1, %xor.123709.5.clone.1) + %shift-left.111330.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123709.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117615.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123709.5.clone.1, %broadcast.244419.4352) + %or.117151.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111330.7.clone.1, %shift-right-logical.117615.7.clone.1) + %xor.123710.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252327.3.clone.1, %or.117151.5.clone.1) + %add.252328.3.clone.1 = u32[1280,1280]{1,0} add(%add.252327.3.clone.1, %xor.123710.3.clone.1) + %add.252329.7.clone.1 = u32[1280,1280]{1,0} add(%add.252328.3.clone.1, %broadcast.258815.44.clone.1) + %shift-left.111331.7.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.123710.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117616.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123710.3.clone.1, %broadcast.244418.4352) + %or.117152.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111331.7.clone.1, %shift-right-logical.117616.7.clone.1) + %xor.123711.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252328.3.clone.1, %or.117152.5.clone.1) + %constant_218677_1_clone_1 = u32[] constant(3255467137) + %broadcast.258863.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218677_1_clone_1), dimensions={} + %add.252330.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123711.3.clone.1, %broadcast.258863.5.clone.1) + %add.252331.5.clone.1 = u32[1280,1280]{1,0} add(%add.252329.7.clone.1, %add.252330.5.clone.1) + %shift-left.111332.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252330.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117617.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252330.5.clone.1, %broadcast.244416.5760) + %or.117153.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111332.9.clone.1, %shift-right-logical.117617.9.clone.1) + %xor.123712.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252331.5.clone.1, %or.117153.7.clone.1) + %add.252332.3.clone.1 = u32[1280,1280]{1,0} add(%add.252331.5.clone.1, %xor.123712.5.clone.1) + %shift-left.111333.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123712.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117618.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123712.5.clone.1, %broadcast.244429.2304) + %or.117155.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111333.9.clone.1, %shift-right-logical.117618.9.clone.1) + %xor.123713.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252332.3.clone.1, %or.117155.7.clone.1) + %add.252333.3.clone.1 = u32[1280,1280]{1,0} add(%add.252332.3.clone.1, %xor.123713.5.clone.1) + %shift-left.111334.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123713.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117620.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123713.5.clone.1, %broadcast.244430.4608) + %or.117156.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111334.9.clone.1, %shift-right-logical.117620.9.clone.1) + %xor.123714.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252333.3.clone.1, %or.117156.7.clone.1) + %add.252334.3.clone.1 = u32[1280,1280]{1,0} add(%add.252333.3.clone.1, %xor.123714.5.clone.1) + %add.252335.7.clone.1 = u32[1280,1280]{1,0} add(%add.252334.3.clone.1, %broadcast.258816.113.clone.1) + %shift-left.111335.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123714.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117621.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123714.5.clone.1, %broadcast.244434.2816) + %or.117157.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111335.11.clone.1, %shift-right-logical.117621.11.clone.1) + %xor.123715.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252334.3.clone.1, %or.117157.9.clone.1) + %constant_218678_1_clone_1 = u32[] constant(1302776335) + %broadcast.258873.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218678_1_clone_1), dimensions={} + %add.252336.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123715.7.clone.1, %broadcast.258873.5.clone.1) + %add.252337.5.clone.1 = u32[1280,1280]{1,0} add(%add.252335.7.clone.1, %add.252336.5.clone.1) + %shift-left.111336.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252336.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117622.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252336.5.clone.1, 
%broadcast.244415.6016) + %or.117158.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111336.9.clone.1, %shift-right-logical.117622.9.clone.1) + %xor.123716.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252337.5.clone.1, %or.117158.7.clone.1) + %add.252338.3.clone.1 = u32[1280,1280]{1,0} add(%add.252337.5.clone.1, %xor.123716.5.clone.1) + %shift-left.111337.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123716.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117623.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123716.5.clone.1, %broadcast.244417.5760) + %or.117160.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111337.9.clone.1, %shift-right-logical.117623.9.clone.1) + %xor.123717.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252338.3.clone.1, %or.117160.7.clone.1) + %add.252339.3.clone.1 = u32[1280,1280]{1,0} add(%add.252338.3.clone.1, %xor.123717.5.clone.1) + %shift-left.111338.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123717.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117624.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123717.5.clone.1, %broadcast.244419.4352) + %or.117161.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111338.5.clone.1, %shift-right-logical.117624.5.clone.1) + %xor.123718.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252339.3.clone.1, %or.117161.3.clone.1) + %add.252341.3.clone.1 = u32[1280,1280]{1,0} add(%add.252339.3.clone.1, %xor.123718.3.clone.1) + %add.252342.17.clone.1 = u32[1280,1280]{1,0} add(%add.252341.3.clone.1, %broadcast.258841.24.clone.1) + %shift-left.111339.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123718.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117625.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123718.3.clone.1, %broadcast.244418.4352) + %or.117162.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111339.5.clone.1, %shift-right-logical.117625.5.clone.1) + %xor.123719.15.clone.1 = u32[1280,1280]{1,0} xor(%add.252341.3.clone.1, %or.117162.3.clone.1) + %constant_218679_1_clone_1 = u32[] constant(2491264436) + %broadcast.258885.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218679_1_clone_1), dimensions={} + %add.252343.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123719.15.clone.1, %broadcast.258885.19.clone.1) + %xor.123720.17.clone.1 = u32[1280,1280]{1,0} xor(%add.252342.17.clone.1, %add.252343.19.clone.1) + %shift-right-logical.117626.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123720.17.clone.1, %broadcast.244468.1920) + %or.117163.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117626.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5833.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117163.13.clone.1) + %add.252344.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5833.11.clone.1, %broadcast.244470.1152) + %multiply.27278.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252344.9.clone.1, %broadcast.244471.896) + %add.252345.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27278.7.clone.1, %broadcast.244408.1024) + %maximum.3765.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.252345.5.clone.1) + %abs.1593.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3765.3.clone.1) + %compare.7348.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1593.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27279.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3765.3.clone.1, %broadcast.244476.1152) + %negate.4691.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3765.3.clone.1) + %multiply.27280.5.clone.1 = f32[1280,1280]{1,0} 
multiply(%maximum.3765.3.clone.1, %negate.4691.5.clone.1) + %log-plus-one.1593.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27280.5.clone.1) + %negate.4692.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1593.3.clone.1) + %compare.7349.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4692.4.clone.1, %broadcast.244477.384), direction=LT + %select.21630.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21631.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21632.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21633.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21634.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21635.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21636.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21637.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21638.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.252346.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4692.4.clone.1, %broadcast.244496.640) + %sqrt.1593.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4692.4.clone.1) + %add.252347.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1593.5.clone.1, %broadcast.244498.640) + %select.21639.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7349.3.clone.1, %add.252346.5.clone.1, %add.252347.5.clone.1) + %multiply.27281.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21638.3.clone.1, %select.21639.3.clone.1) + %add.252348.1.clone.1 = f32[1280,1280]{1,0} add(%select.21637.3.clone.1, %multiply.27281.1.clone.1) + %multiply.27282.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252348.1.clone.1, %select.21639.3.clone.1) + %add.252349.1.clone.1 = f32[1280,1280]{1,0} add(%select.21636.3.clone.1, %multiply.27282.1.clone.1) + %multiply.27283.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252349.1.clone.1, %select.21639.3.clone.1) + %add.252350.1.clone.1 = f32[1280,1280]{1,0} add(%select.21635.3.clone.1, %multiply.27283.1.clone.1) + %multiply.27284.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252350.1.clone.1, %select.21639.3.clone.1) + %add.252351.1.clone.1 = f32[1280,1280]{1,0} add(%select.21634.3.clone.1, %multiply.27284.1.clone.1) + %multiply.27285.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252351.1.clone.1, %select.21639.3.clone.1) + %add.252352.3.clone.1 = f32[1280,1280]{1,0} add(%select.21633.5.clone.1, %multiply.27285.1.clone.1) + %multiply.27286.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252352.3.clone.1, %select.21639.3.clone.1) + %add.252353.3.clone.1 = f32[1280,1280]{1,0} add(%select.21632.5.clone.1, %multiply.27286.1.clone.1) + %multiply.27287.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252353.3.clone.1, %select.21639.3.clone.1) + %add.252354.9.clone.1 = f32[1280,1280]{1,0} add(%select.21631.11.clone.1, %multiply.27287.7.clone.1) + %multiply.27288.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252354.9.clone.1, %select.21639.3.clone.1) + %add.252355.7.clone.1 = f32[1280,1280]{1,0} add(%select.21630.7.clone.1, 
[Flattened XLA HLO dump fragment: the patch hunk here is machine-generated HLO text whose newlines were lost, fusing several hundred "+"-prefixed instructions over u32[1280,1280] and f32[1280,1280] tensors into a few multi-kilobyte lines. The window opens mid-instruction at "...%multiply.27288.7.clone.1)" with the epilogue of a preceding sampling block (%multiply.27289 through %clamp.1237 and %multiply.27291.1.clone.1) and cuts off right after "%add.253012.1.clone.1 = f32[1280,1280]{1,0} add(%select.21736.3.clone.1, %multiply.27415.1.clone.1)". In between sit four complete repetitions of one two-stage template:

1. A Threefry2x32 counter-based RNG. Each block adds a per-block key pair to the shared counter tensors %convert.3610.4865 and %convert.3611.12801, then runs rounds of add, rotate (built as or(shift-left(x, r), shift-right-logical(x, 32 - r)), with the rotation amounts carried by %broadcast.244414 through %broadcast.244434), and xor, re-injecting the keys plus a round counter (k+1, k+2, ..., k+5) after every four rounds. The identification is exact: in every block the third key equals k0 xor k1 xor 0x1BD11BDA, the Threefry parity constant. The final word becomes a uniform f32 via shift-right-logical by %broadcast.244468.1920, or with %broadcast.244469.1664, bitcast-convert to f32, add %broadcast.244470.1152, multiply by %broadcast.244471.896, add %broadcast.244408.1024, then maximum against %broadcast.244408.1024.

2. An inverse-erf polynomial that maps the uniform value u to a clamped normal sample: w = negate(log-plus-one(u * negate(u))); compare(w, %broadcast.244477.384), direction=LT selects between two nine-coefficient sets (%broadcast.244478 .. %broadcast.244495) and between the arguments w + %broadcast.244496.640 and sqrt(w) + %broadcast.244498.640; a Horner chain of alternating multiply/add evaluates the polynomial and multiplies by u; compare(abs(u), %broadcast.244475.384), direction=EQ routes the |u| = 1 edge case through multiply by %broadcast.244476.1152; the result is multiplied by %broadcast.244500.640, clamped to [%broadcast.244407.384, %broadcast.244501.384], and multiplied by %broadcast.244502.1.

Per-block parameters (key pair k0, k1; derived k2; uniform bits; final product):
  - k0 = 2599510339, k1 = 191910730, k2 = 2320509395; %or.115217 -> %bitcast-convert.5749 -> %clamp.1181 -> %multiply.26423
  - k0 = 2718060488, k1 = 3294696415, k2 = 2108835789; %or.116475 -> %bitcast-convert.5804 -> %clamp.1218 -> %multiply.26992
  - k0 = 3829579796, k1 = 645380412, k2 = 3655623410; %or.115192 -> %bitcast-convert.5748 -> %clamp.1180 -> %multiply.26409
  - k0 = 708617427, k1 = 1243070488, k2 = 2080014097; %or.117448 -> %bitcast-convert.5846 -> inverse-erf chain truncated at the window boundary

The dump continues in the same pattern beyond this fragment.]
%multiply.27416.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253012.1.clone.1, %select.21738.3.clone.1) + %add.253013.1.clone.1 = f32[1280,1280]{1,0} add(%select.21735.3.clone.1, %multiply.27416.1.clone.1) + %multiply.27417.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253013.1.clone.1, %select.21738.3.clone.1) + %add.253015.1.clone.1 = f32[1280,1280]{1,0} add(%select.21734.3.clone.1, %multiply.27417.1.clone.1) + %multiply.27418.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253015.1.clone.1, %select.21738.3.clone.1) + %add.253016.1.clone.1 = f32[1280,1280]{1,0} add(%select.21733.3.clone.1, %multiply.27418.1.clone.1) + %multiply.27419.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253016.1.clone.1, %select.21738.3.clone.1) + %add.253017.3.clone.1 = f32[1280,1280]{1,0} add(%select.21732.5.clone.1, %multiply.27419.1.clone.1) + %multiply.27420.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253017.3.clone.1, %select.21738.3.clone.1) + %add.253018.3.clone.1 = f32[1280,1280]{1,0} add(%select.21731.5.clone.1, %multiply.27420.1.clone.1) + %multiply.27421.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253018.3.clone.1, %select.21738.3.clone.1) + %add.253019.9.clone.1 = f32[1280,1280]{1,0} add(%select.21730.11.clone.1, %multiply.27421.7.clone.1) + %multiply.27422.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253019.9.clone.1, %select.21738.3.clone.1) + %add.253021.7.clone.1 = f32[1280,1280]{1,0} add(%select.21729.7.clone.1, %multiply.27422.7.clone.1) + %multiply.27423.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253021.7.clone.1, %maximum.3778.3.clone.1) + %select.21739.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7366.3.clone.1, %multiply.27413.9.clone.1, %multiply.27423.7.clone.1) + %multiply.27424.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21739.7.clone.1, %broadcast.244500.640) + %clamp.1246.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27424.5.clone.1, %broadcast.244501.384) + %multiply.27425.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1246.3.clone.1, %broadcast.244502.1) + %constant_173929_1_clone_1 = u32[] constant(1260932913) + %broadcast.251246.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_173929_1_clone_1), dimensions={} + %add.247979.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.251246.44.clone.1) + %constant_173936_1_clone_1 = u32[] constant(1413264670) + %broadcast.251247.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_173936_1_clone_1), dimensions={} + %add.247981.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.251247.113.clone.1) + %add.247985.35.clone.1 = u32[1280,1280]{1,0} add(%add.247979.37.clone.1, %add.247981.99.clone.1) + %shift-left.109440.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247981.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115616.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247981.99.clone.1, %broadcast.244415.6016) + %or.115147.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109440.31.clone.1, %shift-right-logical.115616.29.clone.1) + %xor.121695.27.clone.1 = u32[1280,1280]{1,0} xor(%add.247985.35.clone.1, %or.115147.29.clone.1) + %add.247986.5.clone.1 = u32[1280,1280]{1,0} add(%add.247985.35.clone.1, %xor.121695.27.clone.1) + %shift-left.109441.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121695.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115617.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121695.27.clone.1, %broadcast.244417.5760) + %or.115148.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109441.9.clone.1, 
%shift-right-logical.115617.9.clone.1) + %xor.121696.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247986.5.clone.1, %or.115148.7.clone.1) + %add.247987.3.clone.1 = u32[1280,1280]{1,0} add(%add.247986.5.clone.1, %xor.121696.5.clone.1) + %shift-left.109442.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121696.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115619.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121696.5.clone.1, %broadcast.244419.4352) + %or.115149.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109442.5.clone.1, %shift-right-logical.115619.5.clone.1) + %xor.121698.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247987.3.clone.1, %or.115149.3.clone.1) + %add.247988.3.clone.1 = u32[1280,1280]{1,0} add(%add.247987.3.clone.1, %xor.121698.3.clone.1) + %add.247990.7.clone.1 = u32[1280,1280]{1,0} add(%add.247988.3.clone.1, %broadcast.251247.113.clone.1) + %shift-left.109443.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121698.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115620.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121698.3.clone.1, %broadcast.244418.4352) + %or.115150.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109443.5.clone.1, %shift-right-logical.115620.5.clone.1) + %xor.121699.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247988.3.clone.1, %or.115150.3.clone.1) + %constant_218207_1_clone_1 = u32[] constant(80077302) + %broadcast.251257.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218207_1_clone_1), dimensions={} + %add.247991.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121699.3.clone.1, %broadcast.251257.5.clone.1) + %add.247992.5.clone.1 = u32[1280,1280]{1,0} add(%add.247990.7.clone.1, %add.247991.5.clone.1) + %shift-left.109444.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247991.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115621.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247991.5.clone.1, %broadcast.244416.5760) + %or.115151.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109444.9.clone.1, %shift-right-logical.115621.9.clone.1) + %xor.121700.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247992.5.clone.1, %or.115151.7.clone.1) + %add.247993.3.clone.1 = u32[1280,1280]{1,0} add(%add.247992.5.clone.1, %xor.121700.5.clone.1) + %shift-left.109445.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121700.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115622.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121700.5.clone.1, %broadcast.244429.2304) + %or.115152.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109445.9.clone.1, %shift-right-logical.115622.9.clone.1) + %xor.121701.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247993.3.clone.1, %or.115152.7.clone.1) + %add.247995.3.clone.1 = u32[1280,1280]{1,0} add(%add.247993.3.clone.1, %xor.121701.5.clone.1) + %shift-left.109446.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121701.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115624.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121701.5.clone.1, %broadcast.244430.4608) + %or.115153.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109446.9.clone.1, %shift-right-logical.115624.9.clone.1) + %xor.121703.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247995.3.clone.1, %or.115153.7.clone.1) + %add.247996.3.clone.1 = u32[1280,1280]{1,0} add(%add.247995.3.clone.1, %xor.121703.5.clone.1) + %constant_173938_1_clone_1 = u32[] constant(80077301) + %broadcast.251264.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_173938_1_clone_1), dimensions={} + %add.247997.7.clone.1 = u32[1280,1280]{1,0} 
add(%add.247996.3.clone.1, %broadcast.251264.24.clone.1) + %shift-left.109447.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121703.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115625.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121703.5.clone.1, %broadcast.244434.2816) + %or.115154.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109447.11.clone.1, %shift-right-logical.115625.11.clone.1) + %xor.121704.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247996.3.clone.1, %or.115154.9.clone.1) + %constant_218208_1_clone_1 = u32[] constant(1260932915) + %broadcast.251267.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218208_1_clone_1), dimensions={} + %add.247998.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121704.7.clone.1, %broadcast.251267.5.clone.1) + %add.248000.5.clone.1 = u32[1280,1280]{1,0} add(%add.247997.7.clone.1, %add.247998.5.clone.1) + %shift-left.109448.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247998.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115626.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247998.5.clone.1, %broadcast.244415.6016) + %or.115155.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109448.9.clone.1, %shift-right-logical.115626.9.clone.1) + %xor.121705.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248000.5.clone.1, %or.115155.7.clone.1) + %add.248001.3.clone.1 = u32[1280,1280]{1,0} add(%add.248000.5.clone.1, %xor.121705.5.clone.1) + %shift-left.109449.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121705.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115627.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121705.5.clone.1, %broadcast.244417.5760) + %or.115156.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109449.9.clone.1, %shift-right-logical.115627.9.clone.1) + %xor.121706.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248001.3.clone.1, %or.115156.7.clone.1) + %add.248002.3.clone.1 = u32[1280,1280]{1,0} add(%add.248001.3.clone.1, %xor.121706.5.clone.1) + %shift-left.109450.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121706.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115629.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121706.5.clone.1, %broadcast.244419.4352) + %or.115157.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109450.7.clone.1, %shift-right-logical.115629.7.clone.1) + %xor.121708.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248002.3.clone.1, %or.115157.5.clone.1) + %add.248003.3.clone.1 = u32[1280,1280]{1,0} add(%add.248002.3.clone.1, %xor.121708.3.clone.1) + %add.248004.7.clone.1 = u32[1280,1280]{1,0} add(%add.248003.3.clone.1, %broadcast.251246.44.clone.1) + %shift-left.109451.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121708.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115630.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121708.3.clone.1, %broadcast.244418.4352) + %or.115158.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109451.7.clone.1, %shift-right-logical.115630.7.clone.1) + %xor.121709.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248003.3.clone.1, %or.115158.5.clone.1) + %constant_218209_1_clone_1 = u32[] constant(1413264673) + %broadcast.251277.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218209_1_clone_1), dimensions={} + %add.248006.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121709.3.clone.1, %broadcast.251277.5.clone.1) + %add.248010.5.clone.1 = u32[1280,1280]{1,0} add(%add.248004.7.clone.1, %add.248006.5.clone.1) + %shift-left.109452.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248006.5.clone.1, %broadcast.244417.5760) + 
%shift-right-logical.115631.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248006.5.clone.1, %broadcast.244416.5760) + %or.115159.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109452.9.clone.1, %shift-right-logical.115631.9.clone.1) + %xor.121710.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248010.5.clone.1, %or.115159.7.clone.1) + %add.248011.3.clone.1 = u32[1280,1280]{1,0} add(%add.248010.5.clone.1, %xor.121710.5.clone.1) + %shift-left.109453.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121710.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115632.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121710.5.clone.1, %broadcast.244429.2304) + %or.115160.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109453.9.clone.1, %shift-right-logical.115632.9.clone.1) + %xor.121711.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248011.3.clone.1, %or.115160.7.clone.1) + %add.248012.3.clone.1 = u32[1280,1280]{1,0} add(%add.248011.3.clone.1, %xor.121711.5.clone.1) + %shift-left.109454.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121711.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115633.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121711.5.clone.1, %broadcast.244430.4608) + %or.115161.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109454.9.clone.1, %shift-right-logical.115633.9.clone.1) + %xor.121712.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248012.3.clone.1, %or.115161.7.clone.1) + %add.248013.3.clone.1 = u32[1280,1280]{1,0} add(%add.248012.3.clone.1, %xor.121712.5.clone.1) + %add.248015.7.clone.1 = u32[1280,1280]{1,0} add(%add.248013.3.clone.1, %broadcast.251247.113.clone.1) + %shift-left.109455.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121712.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115634.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121712.5.clone.1, %broadcast.244434.2816) + %or.115162.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109455.11.clone.1, %shift-right-logical.115634.11.clone.1) + %xor.121713.7.clone.1 = u32[1280,1280]{1,0} xor(%add.248013.3.clone.1, %or.115162.9.clone.1) + %constant_218210_1_clone_1 = u32[] constant(80077305) + %broadcast.251287.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218210_1_clone_1), dimensions={} + %add.248016.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121713.7.clone.1, %broadcast.251287.5.clone.1) + %add.248017.5.clone.1 = u32[1280,1280]{1,0} add(%add.248015.7.clone.1, %add.248016.5.clone.1) + %shift-left.109456.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.248016.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115635.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.248016.5.clone.1, %broadcast.244415.6016) + %or.115163.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109456.9.clone.1, %shift-right-logical.115635.9.clone.1) + %xor.121714.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248017.5.clone.1, %or.115163.7.clone.1) + %add.248018.3.clone.1 = u32[1280,1280]{1,0} add(%add.248017.5.clone.1, %xor.121714.5.clone.1) + %shift-left.109457.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121714.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115636.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121714.5.clone.1, %broadcast.244417.5760) + %or.115164.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109457.9.clone.1, %shift-right-logical.115636.9.clone.1) + %xor.121715.5.clone.1 = u32[1280,1280]{1,0} xor(%add.248018.3.clone.1, %or.115164.7.clone.1) + %add.248020.3.clone.1 = u32[1280,1280]{1,0} add(%add.248018.3.clone.1, %xor.121715.5.clone.1) + 
%shift-left.109458.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121715.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115637.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121715.5.clone.1, %broadcast.244419.4352) + %or.115165.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109458.5.clone.1, %shift-right-logical.115637.5.clone.1) + %xor.121716.3.clone.1 = u32[1280,1280]{1,0} xor(%add.248020.3.clone.1, %or.115165.3.clone.1) + %add.248021.3.clone.1 = u32[1280,1280]{1,0} add(%add.248020.3.clone.1, %xor.121716.3.clone.1) + %add.248022.17.clone.1 = u32[1280,1280]{1,0} add(%add.248021.3.clone.1, %broadcast.251264.24.clone.1) + %shift-left.109459.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121716.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115639.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121716.3.clone.1, %broadcast.244418.4352) + %or.115166.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109459.5.clone.1, %shift-right-logical.115639.5.clone.1) + %xor.121718.15.clone.1 = u32[1280,1280]{1,0} xor(%add.248021.3.clone.1, %or.115166.3.clone.1) + %constant_218211_1_clone_1 = u32[] constant(1260932918) + %broadcast.251297.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218211_1_clone_1), dimensions={} + %add.248023.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121718.15.clone.1, %broadcast.251297.19.clone.1) + %xor.121719.17.clone.1 = u32[1280,1280]{1,0} xor(%add.248022.17.clone.1, %add.248023.19.clone.1) + %shift-right-logical.115640.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121719.17.clone.1, %broadcast.244468.1920) + %or.115167.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115640.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5747.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115167.13.clone.1) + %add.248025.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5747.11.clone.1, %broadcast.244470.1152) + %multiply.26382.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248025.9.clone.1, %broadcast.244471.896) + %add.248026.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26382.7.clone.1, %broadcast.244408.1024) + %maximum.3679.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.248026.5.clone.1) + %abs.1535.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3679.3.clone.1) + %compare.7218.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1535.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26383.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3679.3.clone.1, %broadcast.244476.1152) + %negate.4575.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3679.3.clone.1) + %multiply.26384.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3679.3.clone.1, %negate.4575.5.clone.1) + %log-plus-one.1535.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26384.5.clone.1) + %negate.4576.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1535.3.clone.1) + %compare.7219.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4576.4.clone.1, %broadcast.244477.384), direction=LT + %select.20950.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20951.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20952.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20953.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20954.3.clone.1 = 
f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20955.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20956.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20957.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20958.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.248027.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4576.4.clone.1, %broadcast.244496.640) + %sqrt.1535.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4576.4.clone.1) + %add.248028.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1535.5.clone.1, %broadcast.244498.640) + %select.20959.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7219.3.clone.1, %add.248027.5.clone.1, %add.248028.5.clone.1) + %multiply.26385.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20958.3.clone.1, %select.20959.3.clone.1) + %add.248029.1.clone.1 = f32[1280,1280]{1,0} add(%select.20957.3.clone.1, %multiply.26385.1.clone.1) + %multiply.26386.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248029.1.clone.1, %select.20959.3.clone.1) + %add.248031.1.clone.1 = f32[1280,1280]{1,0} add(%select.20956.3.clone.1, %multiply.26386.1.clone.1) + %multiply.26387.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248031.1.clone.1, %select.20959.3.clone.1) + %add.248034.1.clone.1 = f32[1280,1280]{1,0} add(%select.20955.3.clone.1, %multiply.26387.1.clone.1) + %multiply.26388.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248034.1.clone.1, %select.20959.3.clone.1) + %add.248035.1.clone.1 = f32[1280,1280]{1,0} add(%select.20954.3.clone.1, %multiply.26388.1.clone.1) + %multiply.26389.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248035.1.clone.1, %select.20959.3.clone.1) + %add.248036.3.clone.1 = f32[1280,1280]{1,0} add(%select.20953.5.clone.1, %multiply.26389.1.clone.1) + %multiply.26390.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.248036.3.clone.1, %select.20959.3.clone.1) + %add.248037.3.clone.1 = f32[1280,1280]{1,0} add(%select.20952.5.clone.1, %multiply.26390.1.clone.1) + %multiply.26391.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248037.3.clone.1, %select.20959.3.clone.1) + %add.248038.9.clone.1 = f32[1280,1280]{1,0} add(%select.20951.11.clone.1, %multiply.26391.7.clone.1) + %multiply.26392.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248038.9.clone.1, %select.20959.3.clone.1) + %add.248039.7.clone.1 = f32[1280,1280]{1,0} add(%select.20950.7.clone.1, %multiply.26392.7.clone.1) + %multiply.26393.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.248039.7.clone.1, %maximum.3679.3.clone.1) + %select.20960.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7218.3.clone.1, %multiply.26383.9.clone.1, %multiply.26393.7.clone.1) + %multiply.26394.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20960.7.clone.1, %broadcast.244500.640) + %clamp.1179.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26394.5.clone.1, %broadcast.244501.384) + %multiply.26395.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1179.3.clone.1, %broadcast.244502.1) + %constant_185361_1_clone_1 = u32[] constant(1322156039) + %broadcast.256214.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_185361_1_clone_1), dimensions={} + %add.250788.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.256214.44.clone.1) + %constant_185368_1_clone_1 = u32[] 
constant(1657033467) + %broadcast.256215.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_185368_1_clone_1), dimensions={} + %add.250789.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.256215.113.clone.1) + %add.250790.35.clone.1 = u32[1280,1280]{1,0} add(%add.250788.37.clone.1, %add.250789.99.clone.1) + %shift-left.110660.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250789.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116916.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250789.99.clone.1, %broadcast.244415.6016) + %or.116434.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110660.31.clone.1, %shift-right-logical.116916.29.clone.1) + %xor.123003.27.clone.1 = u32[1280,1280]{1,0} xor(%add.250790.35.clone.1, %or.116434.29.clone.1) + %add.250791.5.clone.1 = u32[1280,1280]{1,0} add(%add.250790.35.clone.1, %xor.123003.27.clone.1) + %shift-left.110661.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123003.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116917.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123003.27.clone.1, %broadcast.244417.5760) + %or.116435.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110661.9.clone.1, %shift-right-logical.116917.9.clone.1) + %xor.123004.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250791.5.clone.1, %or.116435.7.clone.1) + %add.250792.3.clone.1 = u32[1280,1280]{1,0} add(%add.250791.5.clone.1, %xor.123004.5.clone.1) + %shift-left.110662.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123004.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116918.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123004.5.clone.1, %broadcast.244419.4352) + %or.116436.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110662.5.clone.1, %shift-right-logical.116918.5.clone.1) + %xor.123005.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250792.3.clone.1, %or.116436.3.clone.1) + %add.250793.3.clone.1 = u32[1280,1280]{1,0} add(%add.250792.3.clone.1, %xor.123005.3.clone.1) + %add.250794.7.clone.1 = u32[1280,1280]{1,0} add(%add.250793.3.clone.1, %broadcast.256215.113.clone.1) + %shift-left.110663.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123005.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116920.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123005.3.clone.1, %broadcast.244418.4352) + %or.116437.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110663.5.clone.1, %shift-right-logical.116920.5.clone.1) + %xor.123006.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250793.3.clone.1, %or.116437.3.clone.1) + %constant_218504_1_clone_1 = u32[] constant(937150759) + %broadcast.256225.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218504_1_clone_1), dimensions={} + %add.250795.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123006.3.clone.1, %broadcast.256225.5.clone.1) + %add.250796.5.clone.1 = u32[1280,1280]{1,0} add(%add.250794.7.clone.1, %add.250795.5.clone.1) + %shift-left.110664.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250795.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116921.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250795.5.clone.1, %broadcast.244416.5760) + %or.116438.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110664.9.clone.1, %shift-right-logical.116921.9.clone.1) + %xor.123007.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250796.5.clone.1, %or.116438.7.clone.1) + %add.250797.3.clone.1 = u32[1280,1280]{1,0} add(%add.250796.5.clone.1, %xor.123007.5.clone.1) + %shift-left.110665.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123007.5.clone.1, 
%broadcast.244428.2304) + %shift-right-logical.116922.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123007.5.clone.1, %broadcast.244429.2304) + %or.116439.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110665.9.clone.1, %shift-right-logical.116922.9.clone.1) + %xor.123008.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250797.3.clone.1, %or.116439.7.clone.1) + %add.250798.3.clone.1 = u32[1280,1280]{1,0} add(%add.250797.3.clone.1, %xor.123008.5.clone.1) + %shift-left.110666.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123008.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116923.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123008.5.clone.1, %broadcast.244430.4608) + %or.116440.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110666.9.clone.1, %shift-right-logical.116923.9.clone.1) + %xor.123009.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250798.3.clone.1, %or.116440.7.clone.1) + %add.250799.3.clone.1 = u32[1280,1280]{1,0} add(%add.250798.3.clone.1, %xor.123009.5.clone.1) + %constant_185370_1_clone_1 = u32[] constant(937150758) + %broadcast.256232.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_185370_1_clone_1), dimensions={} + %add.250800.7.clone.1 = u32[1280,1280]{1,0} add(%add.250799.3.clone.1, %broadcast.256232.24.clone.1) + %shift-left.110667.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123009.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116925.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123009.5.clone.1, %broadcast.244434.2816) + %or.116441.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110667.11.clone.1, %shift-right-logical.116925.11.clone.1) + %xor.123010.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250799.3.clone.1, %or.116441.9.clone.1) + %constant_218505_1_clone_1 = u32[] constant(1322156041) + %broadcast.256235.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218505_1_clone_1), dimensions={} + %add.250801.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123010.7.clone.1, %broadcast.256235.5.clone.1) + %add.250802.5.clone.1 = u32[1280,1280]{1,0} add(%add.250800.7.clone.1, %add.250801.5.clone.1) + %shift-left.110668.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250801.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116926.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250801.5.clone.1, %broadcast.244415.6016) + %or.116442.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110668.9.clone.1, %shift-right-logical.116926.9.clone.1) + %xor.123011.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250802.5.clone.1, %or.116442.7.clone.1) + %add.250803.3.clone.1 = u32[1280,1280]{1,0} add(%add.250802.5.clone.1, %xor.123011.5.clone.1) + %shift-left.110669.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123011.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116927.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123011.5.clone.1, %broadcast.244417.5760) + %or.116443.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110669.9.clone.1, %shift-right-logical.116927.9.clone.1) + %xor.123012.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250803.3.clone.1, %or.116443.7.clone.1) + %add.250804.3.clone.1 = u32[1280,1280]{1,0} add(%add.250803.3.clone.1, %xor.123012.5.clone.1) + %shift-left.110670.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123012.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116928.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123012.5.clone.1, %broadcast.244419.4352) + %or.116444.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110670.7.clone.1, %shift-right-logical.116928.7.clone.1) + 
%xor.123015.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250804.3.clone.1, %or.116444.5.clone.1) + %add.250805.3.clone.1 = u32[1280,1280]{1,0} add(%add.250804.3.clone.1, %xor.123015.3.clone.1) + %add.250806.7.clone.1 = u32[1280,1280]{1,0} add(%add.250805.3.clone.1, %broadcast.256214.44.clone.1) + %shift-left.110671.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123015.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116930.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123015.3.clone.1, %broadcast.244418.4352) + %or.116445.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110671.7.clone.1, %shift-right-logical.116930.7.clone.1) + %xor.123016.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250805.3.clone.1, %or.116445.5.clone.1) + %constant_218506_1_clone_1 = u32[] constant(1657033470) + %broadcast.256245.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218506_1_clone_1), dimensions={} + %add.250807.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123016.3.clone.1, %broadcast.256245.5.clone.1) + %add.250809.5.clone.1 = u32[1280,1280]{1,0} add(%add.250806.7.clone.1, %add.250807.5.clone.1) + %shift-left.110672.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250807.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116931.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250807.5.clone.1, %broadcast.244416.5760) + %or.116446.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110672.9.clone.1, %shift-right-logical.116931.9.clone.1) + %xor.123017.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250809.5.clone.1, %or.116446.7.clone.1) + %add.250812.3.clone.1 = u32[1280,1280]{1,0} add(%add.250809.5.clone.1, %xor.123017.5.clone.1) + %shift-left.110673.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123017.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116932.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123017.5.clone.1, %broadcast.244429.2304) + %or.116447.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110673.9.clone.1, %shift-right-logical.116932.9.clone.1) + %xor.123018.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250812.3.clone.1, %or.116447.7.clone.1) + %add.250813.3.clone.1 = u32[1280,1280]{1,0} add(%add.250812.3.clone.1, %xor.123018.5.clone.1) + %shift-left.110674.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123018.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116933.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123018.5.clone.1, %broadcast.244430.4608) + %or.116448.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110674.9.clone.1, %shift-right-logical.116933.9.clone.1) + %xor.123019.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250813.3.clone.1, %or.116448.7.clone.1) + %add.250814.3.clone.1 = u32[1280,1280]{1,0} add(%add.250813.3.clone.1, %xor.123019.5.clone.1) + %add.250815.7.clone.1 = u32[1280,1280]{1,0} add(%add.250814.3.clone.1, %broadcast.256215.113.clone.1) + %shift-left.110675.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123019.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116935.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123019.5.clone.1, %broadcast.244434.2816) + %or.116449.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110675.11.clone.1, %shift-right-logical.116935.11.clone.1) + %xor.123020.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250814.3.clone.1, %or.116449.9.clone.1) + %constant_218507_1_clone_1 = u32[] constant(937150762) + %broadcast.256255.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218507_1_clone_1), dimensions={} + %add.250817.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123020.7.clone.1, 
%broadcast.256255.5.clone.1) + %add.250818.5.clone.1 = u32[1280,1280]{1,0} add(%add.250815.7.clone.1, %add.250817.5.clone.1) + %shift-left.110676.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250817.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116936.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250817.5.clone.1, %broadcast.244415.6016) + %or.116450.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110676.9.clone.1, %shift-right-logical.116936.9.clone.1) + %xor.123021.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250818.5.clone.1, %or.116450.7.clone.1) + %add.250819.3.clone.1 = u32[1280,1280]{1,0} add(%add.250818.5.clone.1, %xor.123021.5.clone.1) + %shift-left.110677.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123021.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116937.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123021.5.clone.1, %broadcast.244417.5760) + %or.116451.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110677.9.clone.1, %shift-right-logical.116937.9.clone.1) + %xor.123022.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250819.3.clone.1, %or.116451.7.clone.1) + %add.250820.3.clone.1 = u32[1280,1280]{1,0} add(%add.250819.3.clone.1, %xor.123022.5.clone.1) + %shift-left.110678.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123022.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116938.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123022.5.clone.1, %broadcast.244419.4352) + %or.116452.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110678.5.clone.1, %shift-right-logical.116938.5.clone.1) + %xor.123024.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250820.3.clone.1, %or.116452.3.clone.1) + %add.250822.3.clone.1 = u32[1280,1280]{1,0} add(%add.250820.3.clone.1, %xor.123024.3.clone.1) + %add.250823.17.clone.1 = u32[1280,1280]{1,0} add(%add.250822.3.clone.1, %broadcast.256232.24.clone.1) + %shift-left.110679.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123024.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116939.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123024.3.clone.1, %broadcast.244418.4352) + %or.116453.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110679.5.clone.1, %shift-right-logical.116939.5.clone.1) + %xor.123025.15.clone.1 = u32[1280,1280]{1,0} xor(%add.250822.3.clone.1, %or.116453.3.clone.1) + %constant_218508_1_clone_1 = u32[] constant(1322156044) + %broadcast.256265.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218508_1_clone_1), dimensions={} + %add.250824.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123025.15.clone.1, %broadcast.256265.19.clone.1) + %xor.123026.17.clone.1 = u32[1280,1280]{1,0} xor(%add.250823.17.clone.1, %add.250824.19.clone.1) + %shift-right-logical.116940.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123026.17.clone.1, %broadcast.244468.1920) + %or.116454.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116940.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5803.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116454.13.clone.1) + %add.250825.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5803.11.clone.1, %broadcast.244470.1152) + %multiply.26965.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250825.9.clone.1, %broadcast.244471.896) + %add.250827.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26965.7.clone.1, %broadcast.244408.1024) + %maximum.3735.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.250827.5.clone.1) + %abs.1573.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3735.3.clone.1) + %compare.7302.3.clone.1 = 
pred[1280,1280]{1,0} compare(%abs.1573.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26966.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3735.3.clone.1, %broadcast.244476.1152) + %negate.4651.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3735.3.clone.1) + %multiply.26967.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3735.3.clone.1, %negate.4651.5.clone.1) + %log-plus-one.1573.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26967.5.clone.1) + %negate.4652.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1573.3.clone.1) + %compare.7303.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4652.4.clone.1, %broadcast.244477.384), direction=LT + %select.21389.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21390.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21391.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21392.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21393.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21394.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21395.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21396.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21397.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.250828.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4652.4.clone.1, %broadcast.244496.640) + %sqrt.1573.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4652.4.clone.1) + %add.250829.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1573.5.clone.1, %broadcast.244498.640) + %select.21398.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7303.3.clone.1, %add.250828.5.clone.1, %add.250829.5.clone.1) + %multiply.26968.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21397.3.clone.1, %select.21398.3.clone.1) + %add.250830.1.clone.1 = f32[1280,1280]{1,0} add(%select.21396.3.clone.1, %multiply.26968.1.clone.1) + %multiply.26969.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250830.1.clone.1, %select.21398.3.clone.1) + %add.250831.1.clone.1 = f32[1280,1280]{1,0} add(%select.21395.3.clone.1, %multiply.26969.1.clone.1) + %multiply.26970.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250831.1.clone.1, %select.21398.3.clone.1) + %add.250833.1.clone.1 = f32[1280,1280]{1,0} add(%select.21394.3.clone.1, %multiply.26970.1.clone.1) + %multiply.26971.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250833.1.clone.1, %select.21398.3.clone.1) + %add.250837.1.clone.1 = f32[1280,1280]{1,0} add(%select.21393.3.clone.1, %multiply.26971.1.clone.1) + %multiply.26972.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250837.1.clone.1, %select.21398.3.clone.1) + %add.250838.3.clone.1 = f32[1280,1280]{1,0} add(%select.21392.5.clone.1, %multiply.26972.1.clone.1) + %multiply.26973.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250838.3.clone.1, %select.21398.3.clone.1) + %add.250839.3.clone.1 = f32[1280,1280]{1,0} add(%select.21391.5.clone.1, %multiply.26973.1.clone.1) + %multiply.26974.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250839.3.clone.1, 
%select.21398.3.clone.1) + %add.250840.9.clone.1 = f32[1280,1280]{1,0} add(%select.21390.11.clone.1, %multiply.26974.7.clone.1) + %multiply.26975.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250840.9.clone.1, %select.21398.3.clone.1) + %add.250842.7.clone.1 = f32[1280,1280]{1,0} add(%select.21389.7.clone.1, %multiply.26975.7.clone.1) + %multiply.26976.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250842.7.clone.1, %maximum.3735.3.clone.1) + %select.21399.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7302.3.clone.1, %multiply.26966.9.clone.1, %multiply.26976.7.clone.1) + %multiply.26977.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21399.7.clone.1, %broadcast.244500.640) + %clamp.1217.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26977.5.clone.1, %broadcast.244501.384) + %multiply.26978.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1217.3.clone.1, %broadcast.244502.1) + %constant_173378_1_clone_1 = u32[] constant(2492626316) + %broadcast.251013.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_173378_1_clone_1), dimensions={} + %add.247845.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.251013.44.clone.1) + %constant_173385_1_clone_1 = u32[] constant(377826067) + %broadcast.251014.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_173385_1_clone_1), dimensions={} + %add.247846.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.251014.113.clone.1) + %add.247847.35.clone.1 = u32[1280,1280]{1,0} add(%add.247845.37.clone.1, %add.247846.99.clone.1) + %shift-left.109380.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247846.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115541.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247846.99.clone.1, %broadcast.244415.6016) + %or.115080.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109380.31.clone.1, %shift-right-logical.115541.29.clone.1) + %xor.121629.27.clone.1 = u32[1280,1280]{1,0} xor(%add.247847.35.clone.1, %or.115080.29.clone.1) + %add.247848.5.clone.1 = u32[1280,1280]{1,0} add(%add.247847.35.clone.1, %xor.121629.27.clone.1) + %shift-left.109381.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121629.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115542.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121629.27.clone.1, %broadcast.244417.5760) + %or.115081.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109381.9.clone.1, %shift-right-logical.115542.9.clone.1) + %xor.121630.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247848.5.clone.1, %or.115081.7.clone.1) + %add.247849.3.clone.1 = u32[1280,1280]{1,0} add(%add.247848.5.clone.1, %xor.121630.5.clone.1) + %shift-left.109382.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121630.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115544.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121630.5.clone.1, %broadcast.244419.4352) + %or.115082.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109382.5.clone.1, %shift-right-logical.115544.5.clone.1) + %xor.121631.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247849.3.clone.1, %or.115082.3.clone.1) + %add.247850.3.clone.1 = u32[1280,1280]{1,0} add(%add.247849.3.clone.1, %xor.121631.3.clone.1) + %add.247851.7.clone.1 = u32[1280,1280]{1,0} add(%add.247850.3.clone.1, %broadcast.251014.113.clone.1) + %shift-left.109383.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121631.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115545.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121631.3.clone.1, %broadcast.244418.4352) + 
%or.115083.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109383.5.clone.1, %shift-right-logical.115545.5.clone.1) + %xor.121632.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247850.3.clone.1, %or.115083.3.clone.1) + %constant_218176_1_clone_1 = u32[] constant(2579907910) + %broadcast.251026.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218176_1_clone_1), dimensions={} + %add.247852.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121632.3.clone.1, %broadcast.251026.5.clone.1) + %add.247853.5.clone.1 = u32[1280,1280]{1,0} add(%add.247851.7.clone.1, %add.247852.5.clone.1) + %shift-left.109384.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247852.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115546.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247852.5.clone.1, %broadcast.244416.5760) + %or.115084.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109384.9.clone.1, %shift-right-logical.115546.9.clone.1) + %xor.121633.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247853.5.clone.1, %or.115084.7.clone.1) + %add.247855.3.clone.1 = u32[1280,1280]{1,0} add(%add.247853.5.clone.1, %xor.121633.5.clone.1) + %shift-left.109385.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121633.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115547.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121633.5.clone.1, %broadcast.244429.2304) + %or.115085.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109385.9.clone.1, %shift-right-logical.115547.9.clone.1) + %xor.121634.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247855.3.clone.1, %or.115085.7.clone.1) + %add.247856.3.clone.1 = u32[1280,1280]{1,0} add(%add.247855.3.clone.1, %xor.121634.5.clone.1) + %shift-left.109386.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121634.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115549.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121634.5.clone.1, %broadcast.244430.4608) + %or.115086.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109386.9.clone.1, %shift-right-logical.115549.9.clone.1) + %xor.121635.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247856.3.clone.1, %or.115086.7.clone.1) + %add.247857.3.clone.1 = u32[1280,1280]{1,0} add(%add.247856.3.clone.1, %xor.121635.5.clone.1) + %constant_173387_1_clone_1 = u32[] constant(2579907909) + %broadcast.251033.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_173387_1_clone_1), dimensions={} + %add.247858.7.clone.1 = u32[1280,1280]{1,0} add(%add.247857.3.clone.1, %broadcast.251033.24.clone.1) + %shift-left.109387.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121635.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115550.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121635.5.clone.1, %broadcast.244434.2816) + %or.115087.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109387.11.clone.1, %shift-right-logical.115550.11.clone.1) + %xor.121636.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247857.3.clone.1, %or.115087.9.clone.1) + %constant_218177_1_clone_1 = u32[] constant(2492626318) + %broadcast.251036.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218177_1_clone_1), dimensions={} + %add.247859.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121636.7.clone.1, %broadcast.251036.5.clone.1) + %add.247860.5.clone.1 = u32[1280,1280]{1,0} add(%add.247858.7.clone.1, %add.247859.5.clone.1) + %shift-left.109388.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247859.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115551.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247859.5.clone.1, %broadcast.244415.6016) + 
+ %or.115089.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109388.9.clone.1, %shift-right-logical.115551.9.clone.1)
+ %xor.121637.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247860.5.clone.1, %or.115089.7.clone.1)
+ %add.247861.3.clone.1 = u32[1280,1280]{1,0} add(%add.247860.5.clone.1, %xor.121637.5.clone.1)
+ %shift-left.109389.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121637.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.115552.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121637.5.clone.1, %broadcast.244417.5760)

[The remaining several hundred machine-generated HLO instructions of this hunk are elided; only their structure is recoverable. The hunk repeats one pattern five times over u32[1280,1280] and f32[1280,1280] tensors. Each repetition is a 20-round counter-based PRNG in the Threefry-2x32 shape: add/rotate/xor rounds with rotations 13, 15, 26, 6, 17, 29, 16, 24 and key injections stepping +1 through +5, with visible key pairs (2005361984, 2881587123), (3938287682, 2876408776), (328282474, 3490391876), and (575593931, 3302378801), plus a fifth block whose key constants are defined above this hunk. The two output words of each block are xor-folded, converted to a uniform float in [-1, 1) via shift-right-logical / or / bitcast-convert, and mapped through a 9-term polynomial inverse-erf approximation (log-plus-one, sqrt, and select/multiply/add Horner chains) ending in a multiply, clamp, multiply sequence. The final repetition is cut off mid-instruction at the end of the section.]
u32[1280,1280]{1,0} add(%add.247743.3.clone.1, %xor.121586.5.clone.1) + %constant_172961_1_clone_1 = u32[] constant(4249455392) + %broadcast.250842.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_172961_1_clone_1), dimensions={} + %add.247745.7.clone.1 = u32[1280,1280]{1,0} add(%add.247744.3.clone.1, %broadcast.250842.24.clone.1) + %shift-left.109347.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121586.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115505.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121586.5.clone.1, %broadcast.244434.2816) + %or.115037.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109347.11.clone.1, %shift-right-logical.115505.11.clone.1) + %xor.121587.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247744.3.clone.1, %or.115037.9.clone.1) + %constant_218167_1_clone_1 = u32[] constant(575593933) + %broadcast.250845.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218167_1_clone_1), dimensions={} + %add.247746.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121587.7.clone.1, %broadcast.250845.5.clone.1) + %add.247747.5.clone.1 = u32[1280,1280]{1,0} add(%add.247745.7.clone.1, %add.247746.5.clone.1) + %shift-left.109348.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247746.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115506.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247746.5.clone.1, %broadcast.244415.6016) + %or.115039.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109348.9.clone.1, %shift-right-logical.115506.9.clone.1) + %xor.121589.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247747.5.clone.1, %or.115039.7.clone.1) + %add.247749.3.clone.1 = u32[1280,1280]{1,0} add(%add.247747.5.clone.1, %xor.121589.5.clone.1) + %shift-left.109349.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121589.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115507.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121589.5.clone.1, %broadcast.244417.5760) + %or.115040.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109349.9.clone.1, %shift-right-logical.115507.9.clone.1) + %xor.121590.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247749.3.clone.1, %or.115040.7.clone.1) + %add.247753.3.clone.1 = u32[1280,1280]{1,0} add(%add.247749.3.clone.1, %xor.121590.5.clone.1) + %shift-left.109350.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121590.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115508.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121590.5.clone.1, %broadcast.244419.4352) + %or.115041.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109350.7.clone.1, %shift-right-logical.115508.7.clone.1) + %xor.121591.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247753.3.clone.1, %or.115041.5.clone.1) + %add.247754.3.clone.1 = u32[1280,1280]{1,0} add(%add.247753.3.clone.1, %xor.121591.3.clone.1) + %add.247755.7.clone.1 = u32[1280,1280]{1,0} add(%add.247754.3.clone.1, %broadcast.250824.44.clone.1) + %shift-left.109351.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121591.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115509.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121591.3.clone.1, %broadcast.244418.4352) + %or.115042.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109351.7.clone.1, %shift-right-logical.115509.7.clone.1) + %xor.121592.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247754.3.clone.1, %or.115042.5.clone.1) + %constant_218168_1_clone_1 = u32[] constant(3302378804) + %broadcast.250855.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218168_1_clone_1), dimensions={} + %add.247756.5.clone.1 = 
u32[1280,1280]{1,0} add(%xor.121592.3.clone.1, %broadcast.250855.5.clone.1) + %add.247758.5.clone.1 = u32[1280,1280]{1,0} add(%add.247755.7.clone.1, %add.247756.5.clone.1) + %shift-left.109352.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247756.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115510.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247756.5.clone.1, %broadcast.244416.5760) + %or.115044.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109352.9.clone.1, %shift-right-logical.115510.9.clone.1) + %xor.121594.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247758.5.clone.1, %or.115044.7.clone.1) + %add.247759.3.clone.1 = u32[1280,1280]{1,0} add(%add.247758.5.clone.1, %xor.121594.5.clone.1) + %shift-left.109353.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121594.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115511.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121594.5.clone.1, %broadcast.244429.2304) + %or.115045.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109353.9.clone.1, %shift-right-logical.115511.9.clone.1) + %xor.121595.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247759.3.clone.1, %or.115045.7.clone.1) + %add.247760.3.clone.1 = u32[1280,1280]{1,0} add(%add.247759.3.clone.1, %xor.121595.5.clone.1) + %shift-left.109354.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121595.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115512.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121595.5.clone.1, %broadcast.244430.4608) + %or.115046.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109354.9.clone.1, %shift-right-logical.115512.9.clone.1) + %xor.121596.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247760.3.clone.1, %or.115046.7.clone.1) + %add.247761.3.clone.1 = u32[1280,1280]{1,0} add(%add.247760.3.clone.1, %xor.121596.5.clone.1) + %add.247763.7.clone.1 = u32[1280,1280]{1,0} add(%add.247761.3.clone.1, %broadcast.250825.113.clone.1) + %shift-left.109355.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121596.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115513.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121596.5.clone.1, %broadcast.244434.2816) + %or.115047.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109355.11.clone.1, %shift-right-logical.115513.11.clone.1) + %xor.121597.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247761.3.clone.1, %or.115047.9.clone.1) + %constant_218169_1_clone_1 = u32[] constant(4249455396) + %broadcast.250865.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218169_1_clone_1), dimensions={} + %add.247764.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121597.7.clone.1, %broadcast.250865.5.clone.1) + %add.247765.5.clone.1 = u32[1280,1280]{1,0} add(%add.247763.7.clone.1, %add.247764.5.clone.1) + %shift-left.109356.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247764.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115514.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247764.5.clone.1, %broadcast.244415.6016) + %or.115049.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109356.9.clone.1, %shift-right-logical.115514.9.clone.1) + %xor.121599.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247765.5.clone.1, %or.115049.7.clone.1) + %add.247766.3.clone.1 = u32[1280,1280]{1,0} add(%add.247765.5.clone.1, %xor.121599.5.clone.1) + %shift-left.109357.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121599.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115515.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121599.5.clone.1, %broadcast.244417.5760) + %or.115050.7.clone.1 = 
u32[1280,1280]{1,0} or(%shift-left.109357.9.clone.1, %shift-right-logical.115515.9.clone.1) + %xor.121600.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247766.3.clone.1, %or.115050.7.clone.1) + %add.247768.3.clone.1 = u32[1280,1280]{1,0} add(%add.247766.3.clone.1, %xor.121600.5.clone.1) + %shift-left.109358.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121600.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115516.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121600.5.clone.1, %broadcast.244419.4352) + %or.115051.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109358.5.clone.1, %shift-right-logical.115516.5.clone.1) + %xor.121601.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247768.3.clone.1, %or.115051.3.clone.1) + %add.247769.3.clone.1 = u32[1280,1280]{1,0} add(%add.247768.3.clone.1, %xor.121601.3.clone.1) + %add.247770.17.clone.1 = u32[1280,1280]{1,0} add(%add.247769.3.clone.1, %broadcast.250842.24.clone.1) + %shift-left.109359.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121601.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115517.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121601.3.clone.1, %broadcast.244418.4352) + %or.115052.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109359.5.clone.1, %shift-right-logical.115517.5.clone.1) + %xor.121602.15.clone.1 = u32[1280,1280]{1,0} xor(%add.247769.3.clone.1, %or.115052.3.clone.1) + %constant_218170_1_clone_1 = u32[] constant(575593936) + %broadcast.250875.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218170_1_clone_1), dimensions={} + %add.247771.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121602.15.clone.1, %broadcast.250875.19.clone.1) + %xor.121603.17.clone.1 = u32[1280,1280]{1,0} xor(%add.247770.17.clone.1, %add.247771.19.clone.1) + %shift-right-logical.115518.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121603.17.clone.1, %broadcast.244468.1920) + %or.115054.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115518.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5742.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115054.13.clone.1) + %add.247772.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5742.11.clone.1, %broadcast.244470.1152) + %multiply.26334.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247772.9.clone.1, %broadcast.244471.896) + %add.247774.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26334.7.clone.1, %broadcast.244408.1024) + %maximum.3674.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.247774.5.clone.1) + %abs.1532.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3674.3.clone.1) + %compare.7212.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1532.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26335.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3674.3.clone.1, %broadcast.244476.1152) + %negate.4569.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3674.3.clone.1) + %multiply.26336.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3674.3.clone.1, %negate.4569.5.clone.1) + %log-plus-one.1532.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26336.5.clone.1) + %negate.4570.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1532.3.clone.1) + %compare.7213.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4570.4.clone.1, %broadcast.244477.384), direction=LT + %select.20917.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20918.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + 
%select.20919.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20920.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20921.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20922.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20923.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20924.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20925.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.247778.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4570.4.clone.1, %broadcast.244496.640) + %sqrt.1532.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4570.4.clone.1) + %add.247779.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1532.5.clone.1, %broadcast.244498.640) + %select.20926.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7213.3.clone.1, %add.247778.5.clone.1, %add.247779.5.clone.1) + %multiply.26337.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20925.3.clone.1, %select.20926.3.clone.1) + %add.247780.1.clone.1 = f32[1280,1280]{1,0} add(%select.20924.3.clone.1, %multiply.26337.1.clone.1) + %multiply.26338.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247780.1.clone.1, %select.20926.3.clone.1) + %add.247781.1.clone.1 = f32[1280,1280]{1,0} add(%select.20923.3.clone.1, %multiply.26338.1.clone.1) + %multiply.26339.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247781.1.clone.1, %select.20926.3.clone.1) + %add.247783.1.clone.1 = f32[1280,1280]{1,0} add(%select.20922.3.clone.1, %multiply.26339.1.clone.1) + %multiply.26340.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247783.1.clone.1, %select.20926.3.clone.1) + %add.247784.1.clone.1 = f32[1280,1280]{1,0} add(%select.20921.3.clone.1, %multiply.26340.1.clone.1) + %multiply.26341.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247784.1.clone.1, %select.20926.3.clone.1) + %add.247785.3.clone.1 = f32[1280,1280]{1,0} add(%select.20920.5.clone.1, %multiply.26341.1.clone.1) + %multiply.26342.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247785.3.clone.1, %select.20926.3.clone.1) + %add.247786.3.clone.1 = f32[1280,1280]{1,0} add(%select.20919.5.clone.1, %multiply.26342.1.clone.1) + %multiply.26343.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247786.3.clone.1, %select.20926.3.clone.1) + %add.247788.9.clone.1 = f32[1280,1280]{1,0} add(%select.20918.11.clone.1, %multiply.26343.7.clone.1) + %multiply.26344.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247788.9.clone.1, %select.20926.3.clone.1) + %add.247789.7.clone.1 = f32[1280,1280]{1,0} add(%select.20917.7.clone.1, %multiply.26344.7.clone.1) + %multiply.26345.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247789.7.clone.1, %maximum.3674.3.clone.1) + %select.20927.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7212.3.clone.1, %multiply.26335.9.clone.1, %multiply.26345.7.clone.1) + %multiply.26346.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20927.7.clone.1, %broadcast.244500.640) + %clamp.1176.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26346.5.clone.1, %broadcast.244501.384) + %multiply.26347.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1176.3.clone.1, %broadcast.244502.1) + %constant_196332_1_clone_1 = 
u32[] constant(391327767) + %broadcast.260927.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_196332_1_clone_1), dimensions={} + %add.253509.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.260927.44.clone.1) + %constant_196339_1_clone_1 = u32[] constant(2492804280) + %broadcast.260928.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_196339_1_clone_1), dimensions={} + %add.253510.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.260928.113.clone.1) + %add.253511.35.clone.1 = u32[1280,1280]{1,0} add(%add.253509.37.clone.1, %add.253510.99.clone.1) + %shift-left.111840.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253510.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.118149.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253510.99.clone.1, %broadcast.244415.6016) + %or.117688.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111840.31.clone.1, %shift-right-logical.118149.29.clone.1) + %xor.124249.27.clone.1 = u32[1280,1280]{1,0} xor(%add.253511.35.clone.1, %or.117688.29.clone.1) + %add.253513.5.clone.1 = u32[1280,1280]{1,0} add(%add.253511.35.clone.1, %xor.124249.27.clone.1) + %shift-left.111841.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124249.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.118150.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124249.27.clone.1, %broadcast.244417.5760) + %or.117689.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111841.9.clone.1, %shift-right-logical.118150.9.clone.1) + %xor.124250.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253513.5.clone.1, %or.117689.7.clone.1) + %add.253514.3.clone.1 = u32[1280,1280]{1,0} add(%add.253513.5.clone.1, %xor.124250.5.clone.1) + %shift-left.111842.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124250.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.118151.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124250.5.clone.1, %broadcast.244419.4352) + %or.117691.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111842.5.clone.1, %shift-right-logical.118151.5.clone.1) + %xor.124251.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253514.3.clone.1, %or.117691.3.clone.1) + %add.253515.3.clone.1 = u32[1280,1280]{1,0} add(%add.253514.3.clone.1, %xor.124251.3.clone.1) + %add.253516.7.clone.1 = u32[1280,1280]{1,0} add(%add.253515.3.clone.1, %broadcast.260928.113.clone.1) + %shift-left.111843.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124251.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.118152.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124251.3.clone.1, %broadcast.244418.4352) + %or.117692.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111843.5.clone.1, %shift-right-logical.118152.5.clone.1) + %xor.124252.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253515.3.clone.1, %or.117692.3.clone.1) + %constant_218812_1_clone_1 = u32[] constant(2551646070) + %broadcast.260938.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218812_1_clone_1), dimensions={} + %add.253518.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124252.3.clone.1, %broadcast.260938.5.clone.1) + %add.253519.5.clone.1 = u32[1280,1280]{1,0} add(%add.253516.7.clone.1, %add.253518.5.clone.1) + %shift-left.111844.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253518.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.118153.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253518.5.clone.1, %broadcast.244416.5760) + %or.117693.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111844.9.clone.1, %shift-right-logical.118153.9.clone.1) + 
%xor.124253.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253519.5.clone.1, %or.117693.7.clone.1) + %add.253520.3.clone.1 = u32[1280,1280]{1,0} add(%add.253519.5.clone.1, %xor.124253.5.clone.1) + %shift-left.111845.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124253.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.118154.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124253.5.clone.1, %broadcast.244429.2304) + %or.117694.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111845.9.clone.1, %shift-right-logical.118154.9.clone.1) + %xor.124254.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253520.3.clone.1, %or.117694.7.clone.1) + %add.253521.3.clone.1 = u32[1280,1280]{1,0} add(%add.253520.3.clone.1, %xor.124254.5.clone.1) + %shift-left.111846.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124254.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.118156.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124254.5.clone.1, %broadcast.244430.4608) + %or.117696.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111846.9.clone.1, %shift-right-logical.118156.9.clone.1) + %xor.124255.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253521.3.clone.1, %or.117696.7.clone.1) + %add.253522.3.clone.1 = u32[1280,1280]{1,0} add(%add.253521.3.clone.1, %xor.124255.5.clone.1) + %constant_196341_1_clone_1 = u32[] constant(2551646069) + %broadcast.260945.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_196341_1_clone_1), dimensions={} + %add.253524.7.clone.1 = u32[1280,1280]{1,0} add(%add.253522.3.clone.1, %broadcast.260945.24.clone.1) + %shift-left.111847.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124255.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.118157.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124255.5.clone.1, %broadcast.244434.2816) + %or.117697.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111847.11.clone.1, %shift-right-logical.118157.11.clone.1) + %xor.124256.7.clone.1 = u32[1280,1280]{1,0} xor(%add.253522.3.clone.1, %or.117697.9.clone.1) + %constant_218813_1_clone_1 = u32[] constant(391327769) + %broadcast.260949.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218813_1_clone_1), dimensions={} + %add.253528.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124256.7.clone.1, %broadcast.260949.5.clone.1) + %add.253529.5.clone.1 = u32[1280,1280]{1,0} add(%add.253524.7.clone.1, %add.253528.5.clone.1) + %shift-left.111848.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253528.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.118158.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253528.5.clone.1, %broadcast.244415.6016) + %or.117698.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111848.9.clone.1, %shift-right-logical.118158.9.clone.1) + %xor.124257.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253529.5.clone.1, %or.117698.7.clone.1) + %add.253530.3.clone.1 = u32[1280,1280]{1,0} add(%add.253529.5.clone.1, %xor.124257.5.clone.1) + %shift-left.111849.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124257.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.118159.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124257.5.clone.1, %broadcast.244417.5760) + %or.117699.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111849.9.clone.1, %shift-right-logical.118159.9.clone.1) + %xor.124258.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253530.3.clone.1, %or.117699.7.clone.1) + %add.253531.3.clone.1 = u32[1280,1280]{1,0} add(%add.253530.3.clone.1, %xor.124258.5.clone.1) + %shift-left.111850.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124258.5.clone.1, 
%broadcast.244418.4352) + %shift-right-logical.118161.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124258.5.clone.1, %broadcast.244419.4352) + %or.117700.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111850.7.clone.1, %shift-right-logical.118161.7.clone.1) + %xor.124259.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253531.3.clone.1, %or.117700.5.clone.1) + %add.253533.3.clone.1 = u32[1280,1280]{1,0} add(%add.253531.3.clone.1, %xor.124259.3.clone.1) + %add.253534.7.clone.1 = u32[1280,1280]{1,0} add(%add.253533.3.clone.1, %broadcast.260927.44.clone.1) + %shift-left.111851.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124259.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.118162.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124259.3.clone.1, %broadcast.244418.4352) + %or.117701.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111851.7.clone.1, %shift-right-logical.118162.7.clone.1) + %xor.124260.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253533.3.clone.1, %or.117701.5.clone.1) + %constant_218814_1_clone_1 = u32[] constant(2492804283) + %broadcast.260960.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218814_1_clone_1), dimensions={} + %add.253535.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124260.3.clone.1, %broadcast.260960.5.clone.1) + %add.253536.5.clone.1 = u32[1280,1280]{1,0} add(%add.253534.7.clone.1, %add.253535.5.clone.1) + %shift-left.111852.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253535.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.118163.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253535.5.clone.1, %broadcast.244416.5760) + %or.117702.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111852.9.clone.1, %shift-right-logical.118163.9.clone.1) + %xor.124261.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253536.5.clone.1, %or.117702.7.clone.1) + %add.253538.3.clone.1 = u32[1280,1280]{1,0} add(%add.253536.5.clone.1, %xor.124261.5.clone.1) + %shift-left.111853.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124261.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.118164.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124261.5.clone.1, %broadcast.244429.2304) + %or.117703.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111853.9.clone.1, %shift-right-logical.118164.9.clone.1) + %xor.124262.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253538.3.clone.1, %or.117703.7.clone.1) + %add.253539.3.clone.1 = u32[1280,1280]{1,0} add(%add.253538.3.clone.1, %xor.124262.5.clone.1) + %shift-left.111854.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124262.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.118166.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124262.5.clone.1, %broadcast.244430.4608) + %or.117704.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111854.9.clone.1, %shift-right-logical.118166.9.clone.1) + %xor.124263.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253539.3.clone.1, %or.117704.7.clone.1) + %add.253540.3.clone.1 = u32[1280,1280]{1,0} add(%add.253539.3.clone.1, %xor.124263.5.clone.1) + %add.253541.7.clone.1 = u32[1280,1280]{1,0} add(%add.253540.3.clone.1, %broadcast.260928.113.clone.1) + %shift-left.111855.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124263.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.118167.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124263.5.clone.1, %broadcast.244434.2816) + %or.117706.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111855.11.clone.1, %shift-right-logical.118167.11.clone.1) + %xor.124264.7.clone.1 = u32[1280,1280]{1,0} 
xor(%add.253540.3.clone.1, %or.117706.9.clone.1) + %constant_218815_1_clone_1 = u32[] constant(2551646073) + %broadcast.260970.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218815_1_clone_1), dimensions={} + %add.253543.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124264.7.clone.1, %broadcast.260970.5.clone.1) + %add.253544.5.clone.1 = u32[1280,1280]{1,0} add(%add.253541.7.clone.1, %add.253543.5.clone.1) + %shift-left.111856.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253543.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.118168.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253543.5.clone.1, %broadcast.244415.6016) + %or.117707.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111856.9.clone.1, %shift-right-logical.118168.9.clone.1) + %xor.124265.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253544.5.clone.1, %or.117707.7.clone.1) + %add.253545.3.clone.1 = u32[1280,1280]{1,0} add(%add.253544.5.clone.1, %xor.124265.5.clone.1) + %shift-left.111857.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124265.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.118169.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124265.5.clone.1, %broadcast.244417.5760) + %or.117708.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111857.9.clone.1, %shift-right-logical.118169.9.clone.1) + %xor.124266.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253545.3.clone.1, %or.117708.7.clone.1) + %add.253546.3.clone.1 = u32[1280,1280]{1,0} add(%add.253545.3.clone.1, %xor.124266.5.clone.1) + %shift-left.111858.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124266.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.118171.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124266.5.clone.1, %broadcast.244419.4352) + %or.117709.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111858.5.clone.1, %shift-right-logical.118171.5.clone.1) + %xor.124267.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253546.3.clone.1, %or.117709.3.clone.1) + %add.253547.3.clone.1 = u32[1280,1280]{1,0} add(%add.253546.3.clone.1, %xor.124267.3.clone.1) + %add.253549.17.clone.1 = u32[1280,1280]{1,0} add(%add.253547.3.clone.1, %broadcast.260945.24.clone.1) + %shift-left.111859.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124267.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.118172.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124267.3.clone.1, %broadcast.244418.4352) + %or.117711.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111859.5.clone.1, %shift-right-logical.118172.5.clone.1) + %xor.124268.15.clone.1 = u32[1280,1280]{1,0} xor(%add.253547.3.clone.1, %or.117711.3.clone.1) + %constant_218816_1_clone_1 = u32[] constant(391327772) + %broadcast.260982.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218816_1_clone_1), dimensions={} + %add.253552.19.clone.1 = u32[1280,1280]{1,0} add(%xor.124268.15.clone.1, %broadcast.260982.19.clone.1) + %xor.124269.17.clone.1 = u32[1280,1280]{1,0} xor(%add.253549.17.clone.1, %add.253552.19.clone.1) + %shift-right-logical.118173.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124269.17.clone.1, %broadcast.244468.1920) + %or.117712.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.118173.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5857.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117712.13.clone.1) + %add.253553.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5857.11.clone.1, %broadcast.244470.1152) + %multiply.27520.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253553.9.clone.1, %broadcast.244471.896) + %add.253554.5.clone.1 
= f32[1280,1280]{1,0} add(%multiply.27520.7.clone.1, %broadcast.244408.1024) + %maximum.3789.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.253554.5.clone.1) + %abs.1609.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3789.3.clone.1) + %compare.7380.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1609.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27521.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3789.3.clone.1, %broadcast.244476.1152) + %negate.4723.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3789.3.clone.1) + %multiply.27522.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3789.3.clone.1, %negate.4723.5.clone.1) + %log-plus-one.1609.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27522.5.clone.1) + %negate.4724.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1609.3.clone.1) + %compare.7381.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4724.4.clone.1, %broadcast.244477.384), direction=LT + %select.21806.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21807.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21808.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21809.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21810.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21811.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21812.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21813.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21814.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.253555.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4724.4.clone.1, %broadcast.244496.640) + %sqrt.1609.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4724.4.clone.1) + %add.253556.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1609.5.clone.1, %broadcast.244498.640) + %select.21815.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7381.3.clone.1, %add.253555.5.clone.1, %add.253556.5.clone.1) + %multiply.27523.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21814.3.clone.1, %select.21815.3.clone.1) + %add.253557.1.clone.1 = f32[1280,1280]{1,0} add(%select.21813.3.clone.1, %multiply.27523.1.clone.1) + %multiply.27524.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253557.1.clone.1, %select.21815.3.clone.1) + %add.253558.1.clone.1 = f32[1280,1280]{1,0} add(%select.21812.3.clone.1, %multiply.27524.1.clone.1) + %multiply.27525.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253558.1.clone.1, %select.21815.3.clone.1) + %add.253559.1.clone.1 = f32[1280,1280]{1,0} add(%select.21811.3.clone.1, %multiply.27525.1.clone.1) + %multiply.27526.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253559.1.clone.1, %select.21815.3.clone.1) + %add.253560.1.clone.1 = f32[1280,1280]{1,0} add(%select.21810.3.clone.1, %multiply.27526.1.clone.1) + %multiply.27527.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253560.1.clone.1, %select.21815.3.clone.1) + %add.253561.3.clone.1 = f32[1280,1280]{1,0} add(%select.21809.5.clone.1, %multiply.27527.1.clone.1) + 
%multiply.27528.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253561.3.clone.1, %select.21815.3.clone.1) + %add.253562.3.clone.1 = f32[1280,1280]{1,0} add(%select.21808.5.clone.1, %multiply.27528.1.clone.1) + %multiply.27529.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253562.3.clone.1, %select.21815.3.clone.1) + %add.253563.9.clone.1 = f32[1280,1280]{1,0} add(%select.21807.11.clone.1, %multiply.27529.7.clone.1) + %multiply.27530.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253563.9.clone.1, %select.21815.3.clone.1) + %add.253564.7.clone.1 = f32[1280,1280]{1,0} add(%select.21806.7.clone.1, %multiply.27530.7.clone.1) + %multiply.27531.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253564.7.clone.1, %maximum.3789.3.clone.1) + %select.21816.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7380.3.clone.1, %multiply.27521.9.clone.1, %multiply.27531.7.clone.1) + %multiply.27532.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21816.7.clone.1, %broadcast.244500.640) + %clamp.1253.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27532.5.clone.1, %broadcast.244501.384) + %multiply.27533.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1253.3.clone.1, %broadcast.244502.1) + %constant_172720_1_clone_1 = u32[] constant(4241515503) + %broadcast.250726.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_172720_1_clone_1), dimensions={} + %add.247676.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.250726.44.clone.1) + %constant_172727_1_clone_1 = u32[] constant(1055489444) + %broadcast.250727.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_172727_1_clone_1), dimensions={} + %add.247677.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.250727.113.clone.1) + %add.247678.35.clone.1 = u32[1280,1280]{1,0} add(%add.247676.37.clone.1, %add.247677.99.clone.1) + %shift-left.109320.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247677.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115477.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247677.99.clone.1, %broadcast.244415.6016) + %or.115005.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109320.31.clone.1, %shift-right-logical.115477.29.clone.1) + %xor.121554.27.clone.1 = u32[1280,1280]{1,0} xor(%add.247678.35.clone.1, %or.115005.29.clone.1) + %add.247679.5.clone.1 = u32[1280,1280]{1,0} add(%add.247678.35.clone.1, %xor.121554.27.clone.1) + %shift-left.109321.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121554.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115478.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121554.27.clone.1, %broadcast.244417.5760) + %or.115006.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109321.9.clone.1, %shift-right-logical.115478.9.clone.1) + %xor.121555.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247679.5.clone.1, %or.115006.7.clone.1) + %add.247680.3.clone.1 = u32[1280,1280]{1,0} add(%add.247679.5.clone.1, %xor.121555.5.clone.1) + %shift-left.109322.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121555.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115479.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121555.5.clone.1, %broadcast.244419.4352) + %or.115007.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109322.5.clone.1, %shift-right-logical.115479.5.clone.1) + %xor.121556.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247680.3.clone.1, %or.115007.3.clone.1) + %add.247681.3.clone.1 = u32[1280,1280]{1,0} add(%add.247680.3.clone.1, %xor.121556.3.clone.1) + %add.247682.7.clone.1 = u32[1280,1280]{1,0} 
add(%add.247681.3.clone.1, %broadcast.250727.113.clone.1) + %shift-left.109323.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121556.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115480.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121556.3.clone.1, %broadcast.244418.4352) + %or.115008.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109323.5.clone.1, %shift-right-logical.115480.5.clone.1) + %xor.121557.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247681.3.clone.1, %or.115008.3.clone.1) + %constant_218161_1_clone_1 = u32[] constant(3655861650) + %broadcast.250739.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218161_1_clone_1), dimensions={} + %add.247683.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121557.3.clone.1, %broadcast.250739.5.clone.1) + %add.247684.5.clone.1 = u32[1280,1280]{1,0} add(%add.247682.7.clone.1, %add.247683.5.clone.1) + %shift-left.109324.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247683.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115481.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247683.5.clone.1, %broadcast.244416.5760) + %or.115009.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109324.9.clone.1, %shift-right-logical.115481.9.clone.1) + %xor.121559.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247684.5.clone.1, %or.115009.7.clone.1) + %add.247685.3.clone.1 = u32[1280,1280]{1,0} add(%add.247684.5.clone.1, %xor.121559.5.clone.1) + %shift-left.109325.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121559.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115482.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121559.5.clone.1, %broadcast.244429.2304) + %or.115010.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109325.9.clone.1, %shift-right-logical.115482.9.clone.1) + %xor.121560.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247685.3.clone.1, %or.115010.7.clone.1) + %add.247686.3.clone.1 = u32[1280,1280]{1,0} add(%add.247685.3.clone.1, %xor.121560.5.clone.1) + %shift-left.109326.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121560.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115483.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121560.5.clone.1, %broadcast.244430.4608) + %or.115011.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109326.9.clone.1, %shift-right-logical.115483.9.clone.1) + %xor.121561.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247686.3.clone.1, %or.115011.7.clone.1) + %add.247687.3.clone.1 = u32[1280,1280]{1,0} add(%add.247686.3.clone.1, %xor.121561.5.clone.1) + %constant_172729_1_clone_1 = u32[] constant(3655861649) + %broadcast.250750.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_172729_1_clone_1), dimensions={} + %add.247688.7.clone.1 = u32[1280,1280]{1,0} add(%add.247687.3.clone.1, %broadcast.250750.24.clone.1) + %shift-left.109327.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121561.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115484.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121561.5.clone.1, %broadcast.244434.2816) + %or.115012.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109327.11.clone.1, %shift-right-logical.115484.11.clone.1) + %xor.121562.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247687.3.clone.1, %or.115012.9.clone.1) + %constant_218162_1_clone_1 = u32[] constant(4241515505) + %broadcast.250753.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218162_1_clone_1), dimensions={} + %add.247689.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121562.7.clone.1, %broadcast.250753.5.clone.1) + %add.247690.5.clone.1 = u32[1280,1280]{1,0} 
add(%add.247688.7.clone.1, %add.247689.5.clone.1) + %shift-left.109328.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247689.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115485.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247689.5.clone.1, %broadcast.244415.6016) + %or.115014.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109328.9.clone.1, %shift-right-logical.115485.9.clone.1) + %xor.121564.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247690.5.clone.1, %or.115014.7.clone.1) + %add.247691.3.clone.1 = u32[1280,1280]{1,0} add(%add.247690.5.clone.1, %xor.121564.5.clone.1) + %shift-left.109329.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121564.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115486.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121564.5.clone.1, %broadcast.244417.5760) + %or.115015.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109329.9.clone.1, %shift-right-logical.115486.9.clone.1) + %xor.121565.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247691.3.clone.1, %or.115015.7.clone.1) + %add.247692.3.clone.1 = u32[1280,1280]{1,0} add(%add.247691.3.clone.1, %xor.121565.5.clone.1) + %shift-left.109330.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121565.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115487.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121565.5.clone.1, %broadcast.244419.4352) + %or.115016.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109330.7.clone.1, %shift-right-logical.115487.7.clone.1) + %xor.121566.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247692.3.clone.1, %or.115016.5.clone.1) + %add.247693.3.clone.1 = u32[1280,1280]{1,0} add(%add.247692.3.clone.1, %xor.121566.3.clone.1) + %add.247694.7.clone.1 = u32[1280,1280]{1,0} add(%add.247693.3.clone.1, %broadcast.250726.44.clone.1) + %shift-left.109331.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121566.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115488.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121566.3.clone.1, %broadcast.244418.4352) + %or.115017.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109331.7.clone.1, %shift-right-logical.115488.7.clone.1) + %xor.121567.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247693.3.clone.1, %or.115017.5.clone.1) + %constant_218163_1_clone_1 = u32[] constant(1055489447) + %broadcast.250763.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218163_1_clone_1), dimensions={} + %add.247695.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121567.3.clone.1, %broadcast.250763.5.clone.1) + %add.247696.5.clone.1 = u32[1280,1280]{1,0} add(%add.247694.7.clone.1, %add.247695.5.clone.1) + %shift-left.109332.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247695.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115489.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247695.5.clone.1, %broadcast.244416.5760) + %or.115019.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109332.9.clone.1, %shift-right-logical.115489.9.clone.1) + %xor.121569.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247696.5.clone.1, %or.115019.7.clone.1) + %add.247697.3.clone.1 = u32[1280,1280]{1,0} add(%add.247696.5.clone.1, %xor.121569.5.clone.1) + %shift-left.109333.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121569.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115490.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121569.5.clone.1, %broadcast.244429.2304) + %or.115020.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109333.9.clone.1, %shift-right-logical.115490.9.clone.1) + %xor.121570.5.clone.1 = 
u32[1280,1280]{1,0} xor(%add.247697.3.clone.1, %or.115020.7.clone.1) + %add.247698.3.clone.1 = u32[1280,1280]{1,0} add(%add.247697.3.clone.1, %xor.121570.5.clone.1) + %shift-left.109334.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121570.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115491.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121570.5.clone.1, %broadcast.244430.4608) + %or.115021.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109334.9.clone.1, %shift-right-logical.115491.9.clone.1) + %xor.121571.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247698.3.clone.1, %or.115021.7.clone.1) + %add.247700.3.clone.1 = u32[1280,1280]{1,0} add(%add.247698.3.clone.1, %xor.121571.5.clone.1) + %add.247703.7.clone.1 = u32[1280,1280]{1,0} add(%add.247700.3.clone.1, %broadcast.250727.113.clone.1) + %shift-left.109335.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121571.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115492.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121571.5.clone.1, %broadcast.244434.2816) + %or.115022.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109335.11.clone.1, %shift-right-logical.115492.11.clone.1) + %xor.121572.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247700.3.clone.1, %or.115022.9.clone.1) + %constant_218164_1_clone_1 = u32[] constant(3655861653) + %broadcast.250777.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218164_1_clone_1), dimensions={} + %add.247704.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121572.7.clone.1, %broadcast.250777.5.clone.1) + %add.247705.5.clone.1 = u32[1280,1280]{1,0} add(%add.247703.7.clone.1, %add.247704.5.clone.1) + %shift-left.109336.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247704.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115493.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247704.5.clone.1, %broadcast.244415.6016) + %or.115024.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109336.9.clone.1, %shift-right-logical.115493.9.clone.1) + %xor.121574.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247705.5.clone.1, %or.115024.7.clone.1) + %add.247706.3.clone.1 = u32[1280,1280]{1,0} add(%add.247705.5.clone.1, %xor.121574.5.clone.1) + %shift-left.109337.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121574.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115494.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121574.5.clone.1, %broadcast.244417.5760) + %or.115025.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109337.9.clone.1, %shift-right-logical.115494.9.clone.1) + %xor.121575.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247706.3.clone.1, %or.115025.7.clone.1) + %add.247708.3.clone.1 = u32[1280,1280]{1,0} add(%add.247706.3.clone.1, %xor.121575.5.clone.1) + %shift-left.109338.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121575.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115495.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121575.5.clone.1, %broadcast.244419.4352) + %or.115026.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109338.5.clone.1, %shift-right-logical.115495.5.clone.1) + %xor.121576.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247708.3.clone.1, %or.115026.3.clone.1) + %add.247709.3.clone.1 = u32[1280,1280]{1,0} add(%add.247708.3.clone.1, %xor.121576.3.clone.1) + %add.247710.17.clone.1 = u32[1280,1280]{1,0} add(%add.247709.3.clone.1, %broadcast.250750.24.clone.1) + %shift-left.109339.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121576.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115496.5.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.121576.3.clone.1, %broadcast.244418.4352)
+ %or.115027.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109339.5.clone.1, %shift-right-logical.115496.5.clone.1)
+ %xor.121577.15.clone.1 = u32[1280,1280]{1,0} xor(%add.247709.3.clone.1, %or.115027.3.clone.1)
+ %constant_218165_1_clone_1 = u32[] constant(4241515508)
+ %broadcast.250789.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218165_1_clone_1), dimensions={}
+ %add.247711.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121577.15.clone.1, %broadcast.250789.19.clone.1)
+ %xor.121578.17.clone.1 = u32[1280,1280]{1,0} xor(%add.247710.17.clone.1, %add.247711.19.clone.1)
+ %shift-right-logical.115497.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121578.17.clone.1, %broadcast.244468.1920)
+ %or.115029.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115497.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5741.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115029.13.clone.1)
+ %add.247713.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5741.11.clone.1, %broadcast.244470.1152)
+ %multiply.26320.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247713.9.clone.1, %broadcast.244471.896)
+ %add.247714.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26320.7.clone.1, %broadcast.244408.1024)
+ %maximum.3673.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.247714.5.clone.1)
+ %abs.1531.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3673.3.clone.1)
+ %compare.7210.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1531.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26321.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3673.3.clone.1, %broadcast.244476.1152)
+ %negate.4567.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3673.3.clone.1)
+ %multiply.26322.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3673.3.clone.1, %negate.4567.5.clone.1)
+ %log-plus-one.1531.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26322.5.clone.1)
+ %negate.4568.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1531.3.clone.1)
+ %compare.7211.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4568.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.20906.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.20907.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.20908.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.20909.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.20910.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.20911.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.20912.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.20913.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.20914.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.247715.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4568.4.clone.1, %broadcast.244496.640)
+ %sqrt.1531.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4568.4.clone.1)
+ %add.247716.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1531.5.clone.1, %broadcast.244498.640)
+ %select.20915.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7211.3.clone.1, %add.247715.5.clone.1, %add.247716.5.clone.1)
+ %multiply.26323.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20914.3.clone.1, %select.20915.3.clone.1)
+ %add.247718.1.clone.1 = f32[1280,1280]{1,0} add(%select.20913.3.clone.1, %multiply.26323.1.clone.1)
+ %multiply.26324.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247718.1.clone.1, %select.20915.3.clone.1)
+ %add.247719.1.clone.1 = f32[1280,1280]{1,0} add(%select.20912.3.clone.1, %multiply.26324.1.clone.1)
+ %multiply.26325.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247719.1.clone.1, %select.20915.3.clone.1)
+ %add.247720.1.clone.1 = f32[1280,1280]{1,0} add(%select.20911.3.clone.1, %multiply.26325.1.clone.1)
+ %multiply.26326.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247720.1.clone.1, %select.20915.3.clone.1)
+ %add.247721.1.clone.1 = f32[1280,1280]{1,0} add(%select.20910.3.clone.1, %multiply.26326.1.clone.1)
+ %multiply.26327.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247721.1.clone.1, %select.20915.3.clone.1)
+ %add.247722.3.clone.1 = f32[1280,1280]{1,0} add(%select.20909.5.clone.1, %multiply.26327.1.clone.1)
+ %multiply.26328.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247722.3.clone.1, %select.20915.3.clone.1)
+ %add.247724.3.clone.1 = f32[1280,1280]{1,0} add(%select.20908.5.clone.1, %multiply.26328.1.clone.1)
+ %multiply.26329.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247724.3.clone.1, %select.20915.3.clone.1)
+ %add.247728.9.clone.1 = f32[1280,1280]{1,0} add(%select.20907.11.clone.1, %multiply.26329.7.clone.1)
+ %multiply.26330.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247728.9.clone.1, %select.20915.3.clone.1)
+ %add.247729.7.clone.1 = f32[1280,1280]{1,0} add(%select.20906.7.clone.1, %multiply.26330.7.clone.1)
+ %multiply.26331.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247729.7.clone.1, %maximum.3673.3.clone.1)
+ %select.20916.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7210.3.clone.1, %multiply.26321.9.clone.1, %multiply.26331.7.clone.1)
+ %multiply.26332.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20916.7.clone.1, %broadcast.244500.640)
+ %clamp.1175.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26332.5.clone.1, %broadcast.244501.384)
+ %multiply.26333.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1175.3.clone.1, %broadcast.244502.1)
+ %constant_184902_1_clone_1 = u32[] constant(1079347404)
+ %broadcast.255999.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_184902_1_clone_1), dimensions={}
+ %add.250683.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.255999.44.clone.1)
+ %constant_184912_1_clone_1 = u32[] constant(344521199)
+ %broadcast.256000.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_184912_1_clone_1), dimensions={}
+ %add.250687.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.256000.113.clone.1)
+ %add.250688.35.clone.1 = u32[1280,1280]{1,0} add(%add.250683.37.clone.1, %add.250687.99.clone.1)
+ %shift-left.110618.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250687.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116866.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250687.99.clone.1, %broadcast.244415.6016)
+ %or.116391.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110618.31.clone.1, %shift-right-logical.116866.29.clone.1)
+ %xor.122961.27.clone.1 = u32[1280,1280]{1,0} xor(%add.250688.35.clone.1, %or.116391.29.clone.1)
+ %add.250689.5.clone.1 = u32[1280,1280]{1,0} add(%add.250688.35.clone.1, %xor.122961.27.clone.1)
+ %shift-left.110619.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122961.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116867.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122961.27.clone.1, %broadcast.244417.5760)
+ %or.116392.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110619.9.clone.1, %shift-right-logical.116867.9.clone.1)
+ %xor.122962.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250689.5.clone.1, %or.116392.7.clone.1)
+ %add.250690.3.clone.1 = u32[1280,1280]{1,0} add(%add.250689.5.clone.1, %xor.122962.5.clone.1)
+ %shift-left.110621.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122962.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116868.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122962.5.clone.1, %broadcast.244419.4352)
+ %or.116393.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110621.5.clone.1, %shift-right-logical.116868.5.clone.1)
+ %xor.122963.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250690.3.clone.1, %or.116393.3.clone.1)
+ %add.250692.3.clone.1 = u32[1280,1280]{1,0} add(%add.250690.3.clone.1, %xor.122963.3.clone.1)
+ %add.250693.7.clone.1 = u32[1280,1280]{1,0} add(%add.250692.3.clone.1, %broadcast.256000.113.clone.1)
+ %shift-left.110622.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122963.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116870.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122963.3.clone.1, %broadcast.244418.4352)
+ %or.116394.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110622.5.clone.1, %shift-right-logical.116870.5.clone.1)
+ %xor.122964.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250692.3.clone.1, %or.116394.3.clone.1)
+ %constant_218494_1_clone_1 = u32[] constant(1326213882)
+ %broadcast.256010.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218494_1_clone_1), dimensions={}
+ %add.250694.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122964.3.clone.1, %broadcast.256010.5.clone.1)
+ %add.250695.5.clone.1 = u32[1280,1280]{1,0} add(%add.250693.7.clone.1, %add.250694.5.clone.1)
+ %shift-left.110623.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250694.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.116871.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250694.5.clone.1, %broadcast.244416.5760)
+ %or.116395.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110623.9.clone.1, %shift-right-logical.116871.9.clone.1)
+ %xor.122965.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250695.5.clone.1, %or.116395.7.clone.1)
+ %add.250697.3.clone.1 = u32[1280,1280]{1,0} add(%add.250695.5.clone.1, %xor.122965.5.clone.1)
+ %shift-left.110624.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122965.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.116872.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122965.5.clone.1, %broadcast.244429.2304)
+ %or.116396.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110624.9.clone.1, %shift-right-logical.116872.9.clone.1)
+ %xor.122966.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250697.3.clone.1, %or.116396.7.clone.1)
+ %add.250698.3.clone.1 = u32[1280,1280]{1,0} add(%add.250697.3.clone.1, %xor.122966.5.clone.1)
+ %shift-left.110626.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122966.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.116873.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122966.5.clone.1, %broadcast.244430.4608)
+ %or.116397.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110626.9.clone.1, %shift-right-logical.116873.9.clone.1)
+ %xor.122967.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250698.3.clone.1, %or.116397.7.clone.1)
+ %add.250699.3.clone.1 = u32[1280,1280]{1,0} add(%add.250698.3.clone.1, %xor.122967.5.clone.1)
+ %constant_184916_1_clone_1 = u32[] constant(1326213881)
+ %broadcast.256019.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_184916_1_clone_1), dimensions={}
+ %add.250700.7.clone.1 = u32[1280,1280]{1,0} add(%add.250699.3.clone.1, %broadcast.256019.24.clone.1)
+ %shift-left.110627.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122967.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.116875.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122967.5.clone.1, %broadcast.244434.2816)
+ %or.116398.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110627.11.clone.1, %shift-right-logical.116875.11.clone.1)
+ %xor.122968.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250699.3.clone.1, %or.116398.9.clone.1)
+ %constant_218495_1_clone_1 = u32[] constant(1079347406)
+ %broadcast.256022.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218495_1_clone_1), dimensions={}
+ %add.250702.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122968.7.clone.1, %broadcast.256022.5.clone.1)
+ %add.250703.5.clone.1 = u32[1280,1280]{1,0} add(%add.250700.7.clone.1, %add.250702.5.clone.1)
+ %shift-left.110628.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250702.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116876.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250702.5.clone.1, %broadcast.244415.6016)
+ %or.116399.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110628.9.clone.1, %shift-right-logical.116876.9.clone.1)
+ %xor.122969.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250703.5.clone.1, %or.116399.7.clone.1)
+ %add.250704.3.clone.1 = u32[1280,1280]{1,0} add(%add.250703.5.clone.1, %xor.122969.5.clone.1)
+ %shift-left.110629.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122969.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116877.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122969.5.clone.1, %broadcast.244417.5760)
+ %or.116400.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110629.9.clone.1, %shift-right-logical.116877.9.clone.1)
+ %xor.122970.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250704.3.clone.1, %or.116400.7.clone.1)
+ %add.250705.3.clone.1 = u32[1280,1280]{1,0} add(%add.250704.3.clone.1, %xor.122970.5.clone.1)
+ %shift-left.110630.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122970.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116878.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122970.5.clone.1, %broadcast.244419.4352)
+ %or.116401.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110630.7.clone.1, %shift-right-logical.116878.7.clone.1)
+ %xor.122971.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250705.3.clone.1, %or.116401.5.clone.1)
+ %add.250706.3.clone.1 = u32[1280,1280]{1,0} add(%add.250705.3.clone.1, %xor.122971.3.clone.1)
+ %add.250708.7.clone.1 = u32[1280,1280]{1,0} add(%add.250706.3.clone.1, %broadcast.255999.44.clone.1)
+ %shift-left.110631.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122971.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116880.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122971.3.clone.1, %broadcast.244418.4352)
+ %or.116402.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110631.7.clone.1, %shift-right-logical.116880.7.clone.1)
+ %xor.122972.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250706.3.clone.1, %or.116402.5.clone.1)
+ %constant_218496_1_clone_1 = u32[] constant(344521202)
+ %broadcast.256032.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218496_1_clone_1), dimensions={}
+ %add.250712.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122972.3.clone.1, %broadcast.256032.5.clone.1)
+ %add.250713.5.clone.1 = u32[1280,1280]{1,0} add(%add.250708.7.clone.1, %add.250712.5.clone.1)
+ %shift-left.110632.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250712.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.116881.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250712.5.clone.1, %broadcast.244416.5760)
+ %or.116403.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110632.9.clone.1, %shift-right-logical.116881.9.clone.1)
+ %xor.122973.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250713.5.clone.1, %or.116403.7.clone.1)
+ %add.250714.3.clone.1 = u32[1280,1280]{1,0} add(%add.250713.5.clone.1, %xor.122973.5.clone.1)
+ %shift-left.110633.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122973.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.116882.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122973.5.clone.1, %broadcast.244429.2304)
+ %or.116404.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110633.9.clone.1, %shift-right-logical.116882.9.clone.1)
+ %xor.122974.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250714.3.clone.1, %or.116404.7.clone.1)
+ %add.250715.3.clone.1 = u32[1280,1280]{1,0} add(%add.250714.3.clone.1, %xor.122974.5.clone.1)
+ %shift-left.110634.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122974.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.116883.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122974.5.clone.1, %broadcast.244430.4608)
+ %or.116405.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110634.9.clone.1, %shift-right-logical.116883.9.clone.1)
+ %xor.122975.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250715.3.clone.1, %or.116405.7.clone.1)
+ %add.250717.3.clone.1 = u32[1280,1280]{1,0} add(%add.250715.3.clone.1, %xor.122975.5.clone.1)
+ %add.250718.7.clone.1 = u32[1280,1280]{1,0} add(%add.250717.3.clone.1, %broadcast.256000.113.clone.1)
+ %shift-left.110635.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122975.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.116885.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122975.5.clone.1, %broadcast.244434.2816)
+ %or.116406.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110635.11.clone.1, %shift-right-logical.116885.11.clone.1)
+ %xor.122976.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250717.3.clone.1, %or.116406.9.clone.1)
+ %constant_218497_1_clone_1 = u32[] constant(1326213885)
+ %broadcast.256044.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218497_1_clone_1), dimensions={}
+ %add.250719.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122976.7.clone.1, %broadcast.256044.5.clone.1)
+ %add.250720.5.clone.1 = u32[1280,1280]{1,0} add(%add.250718.7.clone.1, %add.250719.5.clone.1)
+ %shift-left.110636.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250719.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116886.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250719.5.clone.1, %broadcast.244415.6016)
+ %or.116407.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110636.9.clone.1, %shift-right-logical.116886.9.clone.1)
+ %xor.122977.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250720.5.clone.1, %or.116407.7.clone.1)
+ %add.250722.3.clone.1 = u32[1280,1280]{1,0} add(%add.250720.5.clone.1, %xor.122977.5.clone.1)
+ %shift-left.110637.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122977.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116887.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122977.5.clone.1, %broadcast.244417.5760)
+ %or.116408.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110637.9.clone.1, %shift-right-logical.116887.9.clone.1)
+ %xor.122978.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250722.3.clone.1, %or.116408.7.clone.1)
+ %add.250723.3.clone.1 = u32[1280,1280]{1,0} add(%add.250722.3.clone.1, %xor.122978.5.clone.1)
+ %shift-left.110638.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122978.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116888.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122978.5.clone.1, %broadcast.244419.4352)
+ %or.116409.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110638.5.clone.1, %shift-right-logical.116888.5.clone.1)
+ %xor.122979.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250723.3.clone.1, %or.116409.3.clone.1)
+ %add.250724.3.clone.1 = u32[1280,1280]{1,0} add(%add.250723.3.clone.1, %xor.122979.3.clone.1)
+ %add.250725.17.clone.1 = u32[1280,1280]{1,0} add(%add.250724.3.clone.1, %broadcast.256019.24.clone.1)
+ %shift-left.110639.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122979.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116889.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122979.3.clone.1, %broadcast.244418.4352)
+ %or.116410.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110639.5.clone.1, %shift-right-logical.116889.5.clone.1)
+ %xor.122980.15.clone.1 = u32[1280,1280]{1,0} xor(%add.250724.3.clone.1, %or.116410.3.clone.1)
+ %constant_218498_1_clone_1 = u32[] constant(1079347409)
+ %broadcast.256054.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218498_1_clone_1), dimensions={}
+ %add.250727.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122980.15.clone.1, %broadcast.256054.19.clone.1)
+ %xor.122981.17.clone.1 = u32[1280,1280]{1,0} xor(%add.250725.17.clone.1, %add.250727.19.clone.1)
+ %shift-right-logical.116890.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122981.17.clone.1, %broadcast.244468.1920)
+ %or.116411.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116890.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5801.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116411.13.clone.1)
+ %add.250728.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5801.11.clone.1, %broadcast.244470.1152)
+ %multiply.26937.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250728.9.clone.1, %broadcast.244471.896)
+ %add.250729.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26937.7.clone.1, %broadcast.244408.1024)
+ %maximum.3733.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.250729.5.clone.1)
+ %abs.1571.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3733.3.clone.1)
+ %compare.7298.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1571.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26938.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3733.3.clone.1, %broadcast.244476.1152)
+ %negate.4647.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3733.3.clone.1)
+ %multiply.26939.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3733.3.clone.1, %negate.4647.5.clone.1)
+ %log-plus-one.1571.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26939.5.clone.1)
+ %negate.4648.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1571.3.clone.1)
+ %compare.7299.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4648.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21367.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21368.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21369.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21370.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21371.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21372.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21373.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21374.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21375.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.250730.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4648.4.clone.1, %broadcast.244496.640)
+ %sqrt.1571.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4648.4.clone.1)
+ %add.250731.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1571.5.clone.1, %broadcast.244498.640)
+ %select.21376.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7299.3.clone.1, %add.250730.5.clone.1, %add.250731.5.clone.1)
+ %multiply.26940.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21375.3.clone.1, %select.21376.3.clone.1)
+ %add.250733.1.clone.1 = f32[1280,1280]{1,0} add(%select.21374.3.clone.1, %multiply.26940.1.clone.1)
+ %multiply.26941.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250733.1.clone.1, %select.21376.3.clone.1)
+ %add.250736.1.clone.1 = f32[1280,1280]{1,0} add(%select.21373.3.clone.1, %multiply.26941.1.clone.1)
+ %multiply.26942.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250736.1.clone.1, %select.21376.3.clone.1)
+ %add.250737.1.clone.1 = f32[1280,1280]{1,0} add(%select.21372.3.clone.1, %multiply.26942.1.clone.1)
+ %multiply.26943.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250737.1.clone.1, %select.21376.3.clone.1)
+ %add.250738.1.clone.1 = f32[1280,1280]{1,0} add(%select.21371.3.clone.1, %multiply.26943.1.clone.1)
+ %multiply.26944.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250738.1.clone.1, %select.21376.3.clone.1)
+ %add.250739.3.clone.1 = f32[1280,1280]{1,0} add(%select.21370.5.clone.1, %multiply.26944.1.clone.1)
+ %multiply.26945.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250739.3.clone.1, %select.21376.3.clone.1)
+ %add.250740.3.clone.1 = f32[1280,1280]{1,0} add(%select.21369.5.clone.1, %multiply.26945.1.clone.1)
+ %multiply.26946.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250740.3.clone.1, %select.21376.3.clone.1)
+ %add.250741.9.clone.1 = f32[1280,1280]{1,0} add(%select.21368.11.clone.1, %multiply.26946.7.clone.1)
+ %multiply.26947.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250741.9.clone.1, %select.21376.3.clone.1)
+ %add.250742.7.clone.1 = f32[1280,1280]{1,0} add(%select.21367.7.clone.1, %multiply.26947.7.clone.1)
+ %multiply.26948.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250742.7.clone.1, %maximum.3733.3.clone.1)
+ %select.21377.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7298.3.clone.1, %multiply.26938.9.clone.1, %multiply.26948.7.clone.1)
+ %multiply.26949.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21377.7.clone.1, %broadcast.244500.640)
+ %clamp.1215.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26949.5.clone.1, %broadcast.244501.384)
+ %multiply.26950.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1215.3.clone.1, %broadcast.244502.1)
+ %constant_172169_1_clone_1 = u32[] constant(1166513390)
+ %broadcast.250491.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_172169_1_clone_1), dimensions={}
+ %add.247532.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.250491.44.clone.1)
+ %constant_172176_1_clone_1 = u32[] constant(1392346075)
+ %broadcast.250492.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_172176_1_clone_1), dimensions={}
+ %add.247534.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.250492.113.clone.1)
+ %add.247535.35.clone.1 = u32[1280,1280]{1,0} add(%add.247532.37.clone.1, %add.247534.99.clone.1)
+ %shift-left.109248.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247534.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.115414.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247534.99.clone.1, %broadcast.244415.6016)
+ %or.114936.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109248.31.clone.1, %shift-right-logical.115414.29.clone.1)
+ %xor.121481.27.clone.1 = u32[1280,1280]{1,0} xor(%add.247535.35.clone.1, %or.114936.29.clone.1)
+ %add.247536.5.clone.1 = u32[1280,1280]{1,0} add(%add.247535.35.clone.1, %xor.121481.27.clone.1)
+ %shift-left.109249.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121481.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.115415.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121481.27.clone.1, %broadcast.244417.5760)
+ %or.114937.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109249.9.clone.1, %shift-right-logical.115415.9.clone.1)
+ %xor.121482.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247536.5.clone.1, %or.114937.7.clone.1)
+ %add.247537.3.clone.1 = u32[1280,1280]{1,0} add(%add.247536.5.clone.1, %xor.121482.5.clone.1)
+ %shift-left.109250.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121482.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.115416.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121482.5.clone.1, %broadcast.244419.4352)
+ %or.114938.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109250.5.clone.1, %shift-right-logical.115416.5.clone.1)
+ %xor.121483.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247537.3.clone.1, %or.114938.3.clone.1)
+ %add.247538.3.clone.1 = u32[1280,1280]{1,0} add(%add.247537.3.clone.1, %xor.121483.3.clone.1)
+ %add.247540.7.clone.1 = u32[1280,1280]{1,0} add(%add.247538.3.clone.1, %broadcast.250492.113.clone.1)
+ %shift-left.109251.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121483.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.115417.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121483.3.clone.1, %broadcast.244418.4352)
+ %or.114939.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109251.5.clone.1, %shift-right-logical.115417.5.clone.1)
+ %xor.121484.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247538.3.clone.1, %or.114939.3.clone.1)
+ %constant_218146_1_clone_1 = u32[] constant(212536560)
+ %broadcast.250502.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218146_1_clone_1), dimensions={}
+ %add.247544.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121484.3.clone.1, %broadcast.250502.5.clone.1)
+ %add.247545.5.clone.1 = u32[1280,1280]{1,0} add(%add.247540.7.clone.1, %add.247544.5.clone.1)
+ %shift-left.109253.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247544.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.115418.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247544.5.clone.1, %broadcast.244416.5760)
+ %or.114940.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109253.9.clone.1, %shift-right-logical.115418.9.clone.1)
+ %xor.121485.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247545.5.clone.1, %or.114940.7.clone.1)
+ %add.247546.3.clone.1 = u32[1280,1280]{1,0} add(%add.247545.5.clone.1, %xor.121485.5.clone.1)
+ %shift-left.109254.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121485.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.115419.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121485.5.clone.1, %broadcast.244429.2304)
+ %or.114941.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109254.9.clone.1, %shift-right-logical.115419.9.clone.1)
+ %xor.121486.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247546.3.clone.1, %or.114941.7.clone.1)
+ %add.247547.3.clone.1 = u32[1280,1280]{1,0} add(%add.247546.3.clone.1, %xor.121486.5.clone.1)
+ %shift-left.109255.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121486.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.115420.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121486.5.clone.1, %broadcast.244430.4608)
+ %or.114942.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109255.9.clone.1, %shift-right-logical.115420.9.clone.1)
+ %xor.121487.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247547.3.clone.1, %or.114942.7.clone.1)
+ %add.247549.3.clone.1 = u32[1280,1280]{1,0} add(%add.247547.3.clone.1, %xor.121487.5.clone.1)
+ %constant_172178_1_clone_1 = u32[] constant(212536559)
+ %broadcast.250509.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_172178_1_clone_1), dimensions={}
+ %add.247550.7.clone.1 = u32[1280,1280]{1,0} add(%add.247549.3.clone.1, %broadcast.250509.24.clone.1)
+ %shift-left.109256.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121487.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.115421.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121487.5.clone.1, %broadcast.244434.2816)
+ %or.114944.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109256.11.clone.1, %shift-right-logical.115421.11.clone.1)
+ %xor.121488.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247549.3.clone.1, %or.114944.9.clone.1)
+ %constant_218147_1_clone_1 = u32[] constant(1166513392)
+ %broadcast.250512.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218147_1_clone_1), dimensions={}
+ %add.247551.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121488.7.clone.1, %broadcast.250512.5.clone.1)
+ %add.247552.5.clone.1 = u32[1280,1280]{1,0} add(%add.247550.7.clone.1, %add.247551.5.clone.1)
+ %shift-left.109258.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247551.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.115422.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247551.5.clone.1, %broadcast.244415.6016)
+ %or.114945.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109258.9.clone.1, %shift-right-logical.115422.9.clone.1)
+ %xor.121489.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247552.5.clone.1, %or.114945.7.clone.1)
+ %add.247554.3.clone.1 = u32[1280,1280]{1,0} add(%add.247552.5.clone.1, %xor.121489.5.clone.1)
+ %shift-left.109259.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121489.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.115423.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121489.5.clone.1, %broadcast.244417.5760)
+ %or.114947.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109259.9.clone.1, %shift-right-logical.115423.9.clone.1)
+ %xor.121490.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247554.3.clone.1, %or.114947.7.clone.1)
+ %add.247555.3.clone.1 = u32[1280,1280]{1,0} add(%add.247554.3.clone.1, %xor.121490.5.clone.1)
+ %shift-left.109260.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121490.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.115424.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121490.5.clone.1, %broadcast.244419.4352)
+ %or.114948.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109260.7.clone.1, %shift-right-logical.115424.7.clone.1)
+ %xor.121491.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247555.3.clone.1, %or.114948.5.clone.1)
+ %add.247556.3.clone.1 = u32[1280,1280]{1,0} add(%add.247555.3.clone.1, %xor.121491.3.clone.1)
+ %add.247557.7.clone.1 = u32[1280,1280]{1,0} add(%add.247556.3.clone.1, %broadcast.250491.44.clone.1)
+ %shift-left.109261.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121491.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.115425.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121491.3.clone.1, %broadcast.244418.4352)
+ %or.114949.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109261.7.clone.1, %shift-right-logical.115425.7.clone.1)
+ %xor.121492.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247556.3.clone.1, %or.114949.5.clone.1)
+ %constant_218148_1_clone_1 = u32[] constant(1392346078)
+ %broadcast.250522.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218148_1_clone_1), dimensions={}
+ %add.247559.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121492.3.clone.1, %broadcast.250522.5.clone.1)
+ %add.247560.5.clone.1 = u32[1280,1280]{1,0} add(%add.247557.7.clone.1, %add.247559.5.clone.1)
+ %shift-left.109263.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247559.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.115426.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247559.5.clone.1, %broadcast.244416.5760)
+ %or.114950.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109263.9.clone.1, %shift-right-logical.115426.9.clone.1)
+ %xor.121493.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247560.5.clone.1, %or.114950.7.clone.1)
+ %add.247561.3.clone.1 = u32[1280,1280]{1,0} add(%add.247560.5.clone.1, %xor.121493.5.clone.1)
+ %shift-left.109264.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121493.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.115427.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121493.5.clone.1, %broadcast.244429.2304)
+ %or.114951.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109264.9.clone.1, %shift-right-logical.115427.9.clone.1)
+ %xor.121494.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247561.3.clone.1, %or.114951.7.clone.1)
+ %add.247562.3.clone.1 = u32[1280,1280]{1,0} add(%add.247561.3.clone.1, %xor.121494.5.clone.1)
+ %shift-left.109265.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121494.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.115428.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121494.5.clone.1, %broadcast.244430.4608)
+ %or.114952.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109265.9.clone.1, %shift-right-logical.115428.9.clone.1)
+ %xor.121495.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247562.3.clone.1, %or.114952.7.clone.1)
+ %add.247563.3.clone.1 = u32[1280,1280]{1,0} add(%add.247562.3.clone.1, %xor.121495.5.clone.1)
+ %add.247565.7.clone.1 = u32[1280,1280]{1,0} add(%add.247563.3.clone.1, %broadcast.250492.113.clone.1)
+ %shift-left.109266.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121495.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.115429.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121495.5.clone.1, %broadcast.244434.2816)
+ %or.114953.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109266.11.clone.1, %shift-right-logical.115429.11.clone.1)
+ %xor.121496.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247563.3.clone.1, %or.114953.9.clone.1)
+ %constant_218149_1_clone_1 = u32[] constant(212536563)
+ %broadcast.250532.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218149_1_clone_1), dimensions={}
+ %add.247569.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121496.7.clone.1, %broadcast.250532.5.clone.1)
+ %add.247570.5.clone.1 = u32[1280,1280]{1,0} add(%add.247565.7.clone.1, %add.247569.5.clone.1)
+ %shift-left.109268.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247569.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.115430.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247569.5.clone.1, %broadcast.244415.6016)
+ %or.114954.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109268.9.clone.1, %shift-right-logical.115430.9.clone.1)
+ %xor.121497.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247570.5.clone.1, %or.114954.7.clone.1)
+ %add.247571.3.clone.1 = u32[1280,1280]{1,0} add(%add.247570.5.clone.1, %xor.121497.5.clone.1)
+ %shift-left.109269.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121497.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.115431.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121497.5.clone.1, %broadcast.244417.5760)
+ %or.114955.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109269.9.clone.1, %shift-right-logical.115431.9.clone.1)
+ %xor.121500.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247571.3.clone.1, %or.114955.7.clone.1)
+ %add.247572.3.clone.1 = u32[1280,1280]{1,0} add(%add.247571.3.clone.1, %xor.121500.5.clone.1)
+ %shift-left.109270.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121500.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.115432.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121500.5.clone.1, %broadcast.244419.4352)
+ %or.114956.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109270.5.clone.1, %shift-right-logical.115432.5.clone.1)
+ %xor.121501.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247572.3.clone.1, %or.114956.3.clone.1)
+ %add.247574.3.clone.1 = u32[1280,1280]{1,0} add(%add.247572.3.clone.1, %xor.121501.3.clone.1)
+ %add.247575.17.clone.1 = u32[1280,1280]{1,0} add(%add.247574.3.clone.1, %broadcast.250509.24.clone.1)
+ %shift-left.109271.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121501.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.115433.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121501.3.clone.1, %broadcast.244418.4352)
+ %or.114957.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109271.5.clone.1, %shift-right-logical.115433.5.clone.1)
+ %xor.121502.15.clone.1 = u32[1280,1280]{1,0} xor(%add.247574.3.clone.1, %or.114957.3.clone.1)
+ %constant_218150_1_clone_1 = u32[] constant(1166513395)
+ %broadcast.250542.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218150_1_clone_1), dimensions={}
+ %add.247576.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121502.15.clone.1, %broadcast.250542.19.clone.1)
+ %xor.121503.17.clone.1 = u32[1280,1280]{1,0} xor(%add.247575.17.clone.1, %add.247576.19.clone.1)
+ %shift-right-logical.115434.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121503.17.clone.1, %broadcast.244468.1920)
+ %or.114958.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115434.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5738.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114958.13.clone.1)
+ %add.247577.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5738.11.clone.1, %broadcast.244470.1152)
+ %multiply.26302.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247577.9.clone.1, %broadcast.244471.896)
+ %add.247579.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26302.7.clone.1, %broadcast.244408.1024)
+ %maximum.3670.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.247579.5.clone.1)
+ %abs.1530.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3670.3.clone.1)
+ %compare.7208.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1530.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26303.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3670.3.clone.1, %broadcast.244476.1152)
+ %negate.4565.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3670.3.clone.1)
+ %multiply.26304.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3670.3.clone.1, %negate.4565.5.clone.1)
+ %log-plus-one.1530.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26304.5.clone.1)
+ %negate.4566.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1530.3.clone.1)
+ %compare.7209.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4566.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.20895.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.20896.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.20897.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.20898.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.20899.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.20900.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.20901.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.20902.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.20903.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.247580.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4566.4.clone.1, %broadcast.244496.640)
+ %sqrt.1530.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4566.4.clone.1)
+ %add.247581.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1530.5.clone.1, %broadcast.244498.640)
+ %select.20904.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7209.3.clone.1, %add.247580.5.clone.1, %add.247581.5.clone.1)
+ %multiply.26305.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20903.3.clone.1, %select.20904.3.clone.1)
+ %add.247582.1.clone.1 = f32[1280,1280]{1,0} add(%select.20902.3.clone.1, %multiply.26305.1.clone.1)
+ %multiply.26306.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247582.1.clone.1, %select.20904.3.clone.1)
+ %add.247584.1.clone.1 = f32[1280,1280]{1,0} add(%select.20901.3.clone.1, %multiply.26306.1.clone.1)
+ %multiply.26307.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247584.1.clone.1, %select.20904.3.clone.1)
+ %add.247585.1.clone.1 = f32[1280,1280]{1,0} add(%select.20900.3.clone.1, %multiply.26307.1.clone.1)
+ %multiply.26308.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247585.1.clone.1, %select.20904.3.clone.1)
+ %add.247586.1.clone.1 = f32[1280,1280]{1,0} add(%select.20899.3.clone.1, %multiply.26308.1.clone.1)
+ %multiply.26309.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247586.1.clone.1, %select.20904.3.clone.1)
+ %add.247587.3.clone.1 = f32[1280,1280]{1,0} add(%select.20898.5.clone.1, %multiply.26309.1.clone.1)
+ %multiply.26310.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247587.3.clone.1, %select.20904.3.clone.1)
+ %add.247588.3.clone.1 = f32[1280,1280]{1,0} add(%select.20897.5.clone.1, %multiply.26310.1.clone.1)
+ %multiply.26311.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247588.3.clone.1, %select.20904.3.clone.1)
+ %add.247590.9.clone.1 = f32[1280,1280]{1,0} add(%select.20896.11.clone.1, %multiply.26311.7.clone.1)
+ %multiply.26312.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247590.9.clone.1, %select.20904.3.clone.1)
+ %add.247594.7.clone.1 = f32[1280,1280]{1,0} add(%select.20895.7.clone.1, %multiply.26312.7.clone.1)
+ %multiply.26313.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247594.7.clone.1, %maximum.3670.3.clone.1)
+ %select.20905.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7208.3.clone.1, %multiply.26303.9.clone.1, %multiply.26313.7.clone.1)
+ %multiply.26314.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20905.7.clone.1, %broadcast.244500.640)
+ %clamp.1174.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26314.5.clone.1, %broadcast.244501.384)
+ %multiply.26315.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1174.3.clone.1, %broadcast.244502.1)
+ %constant_191024_1_clone_1 = u32[] constant(1759838114)
+ %broadcast.258643.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_191024_1_clone_1), dimensions={}
+ %add.252195.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.258643.44.clone.1)
+ %constant_191031_1_clone_1 = u32[] constant(525301722)
+ %broadcast.258644.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_191031_1_clone_1), dimensions={}
+ %add.252196.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.258644.113.clone.1)
+ %add.252197.35.clone.1 = u32[1280,1280]{1,0} add(%add.252195.37.clone.1, %add.252196.99.clone.1)
+ %shift-left.111280.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252196.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117560.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252196.99.clone.1, %broadcast.244415.6016)
+ %or.117090.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111280.31.clone.1, %shift-right-logical.117560.29.clone.1)
+ %xor.123651.27.clone.1 = u32[1280,1280]{1,0} xor(%add.252197.35.clone.1, %or.117090.29.clone.1)
+ %add.252199.5.clone.1 = u32[1280,1280]{1,0} add(%add.252197.35.clone.1, %xor.123651.27.clone.1)
+ %shift-left.111281.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123651.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117561.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123651.27.clone.1, %broadcast.244417.5760)
+ %or.117091.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111281.9.clone.1, %shift-right-logical.117561.9.clone.1)
+ %xor.123652.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252199.5.clone.1, %or.117091.7.clone.1)
+ %add.252200.3.clone.1 = u32[1280,1280]{1,0} add(%add.252199.5.clone.1, %xor.123652.5.clone.1)
+ %shift-left.111282.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123652.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117562.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123652.5.clone.1, %broadcast.244419.4352)
+ %or.117092.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111282.5.clone.1, %shift-right-logical.117562.5.clone.1)
+ %xor.123653.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252200.3.clone.1, %or.117092.3.clone.1)
+ %add.252201.3.clone.1 = u32[1280,1280]{1,0} add(%add.252200.3.clone.1, %xor.123653.3.clone.1)
+ %add.252202.7.clone.1 = u32[1280,1280]{1,0} add(%add.252201.3.clone.1, %broadcast.258644.113.clone.1)
+ %shift-left.111283.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123653.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117563.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123653.3.clone.1, %broadcast.244418.4352)
+ %or.117093.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111283.5.clone.1, %shift-right-logical.117563.5.clone.1)
+ %xor.123655.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252201.3.clone.1, %or.117093.3.clone.1)
+ %constant_218665_1_clone_1 = u32[] constant(1819972515)
+ %broadcast.258654.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218665_1_clone_1), dimensions={}
+ %add.252204.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123655.3.clone.1, %broadcast.258654.5.clone.1)
+ %add.252205.5.clone.1 = u32[1280,1280]{1,0} add(%add.252202.7.clone.1, %add.252204.5.clone.1)
+ %shift-left.111284.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252204.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117564.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252204.5.clone.1, %broadcast.244416.5760)
+ %or.117095.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111284.9.clone.1, %shift-right-logical.117564.9.clone.1)
+ %xor.123656.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252205.5.clone.1, %or.117095.7.clone.1)
+ %add.252206.3.clone.1 = u32[1280,1280]{1,0} add(%add.252205.5.clone.1, %xor.123656.5.clone.1)
+ %shift-left.111285.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123656.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117565.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123656.5.clone.1, %broadcast.244429.2304)
+ %or.117096.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111285.9.clone.1, %shift-right-logical.117565.9.clone.1)
+ %xor.123657.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252206.3.clone.1, %or.117096.7.clone.1)
+ %add.252207.3.clone.1 = u32[1280,1280]{1,0} add(%add.252206.3.clone.1, %xor.123657.5.clone.1)
+ %shift-left.111286.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123657.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117566.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123657.5.clone.1, %broadcast.244430.4608)
+ %or.117097.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111286.9.clone.1, %shift-right-logical.117566.9.clone.1)
+ %xor.123658.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252207.3.clone.1, %or.117097.7.clone.1)
+ %add.252209.3.clone.1 = u32[1280,1280]{1,0} add(%add.252207.3.clone.1, %xor.123658.5.clone.1)
+ %constant_191033_1_clone_1 = u32[] constant(1819972514)
+ %broadcast.258661.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_191033_1_clone_1), dimensions={}
+ %add.252210.7.clone.1 = u32[1280,1280]{1,0} add(%add.252209.3.clone.1, %broadcast.258661.24.clone.1)
+ %shift-left.111287.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123658.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117567.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123658.5.clone.1, %broadcast.244434.2816)
+ %or.117098.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111287.11.clone.1, %shift-right-logical.117567.11.clone.1)
+ %xor.123660.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252209.3.clone.1, %or.117098.9.clone.1)
+ %constant_218666_1_clone_1 = u32[] constant(1759838116)
+ %broadcast.258664.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218666_1_clone_1), dimensions={}
+ %add.252211.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123660.7.clone.1, %broadcast.258664.5.clone.1)
+ %add.252212.5.clone.1 = u32[1280,1280]{1,0} add(%add.252210.7.clone.1, %add.252211.5.clone.1)
+ %shift-left.111288.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252211.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117568.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252211.5.clone.1, %broadcast.244415.6016)
+ %or.117099.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111288.9.clone.1, %shift-right-logical.117568.9.clone.1)
+ %xor.123661.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252212.5.clone.1, %or.117099.7.clone.1)
+ %add.252213.3.clone.1 = u32[1280,1280]{1,0} add(%add.252212.5.clone.1, %xor.123661.5.clone.1)
+ %shift-left.111289.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123661.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117569.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123661.5.clone.1, %broadcast.244417.5760)
+ %or.117100.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111289.9.clone.1, %shift-right-logical.117569.9.clone.1)
+ %xor.123662.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252213.3.clone.1, %or.117100.7.clone.1)
+ %add.252215.3.clone.1 = u32[1280,1280]{1,0} add(%add.252213.3.clone.1, %xor.123662.5.clone.1)
+ %shift-left.111290.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123662.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117570.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123662.5.clone.1, %broadcast.244419.4352)
+ %or.117101.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111290.7.clone.1, %shift-right-logical.117570.7.clone.1)
+ %xor.123663.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252215.3.clone.1, %or.117101.5.clone.1)
+ %add.252219.3.clone.1 = u32[1280,1280]{1,0} add(%add.252215.3.clone.1, %xor.123663.3.clone.1)
+ %add.252220.7.clone.1 = u32[1280,1280]{1,0} add(%add.252219.3.clone.1, %broadcast.258643.44.clone.1)
+ %shift-left.111291.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123663.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117571.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123663.3.clone.1, %broadcast.244418.4352)
+ %or.117102.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111291.7.clone.1, %shift-right-logical.117571.7.clone.1)
+ %xor.123665.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252219.3.clone.1, %or.117102.5.clone.1)
+ %constant_218667_1_clone_1 = u32[] constant(525301725)
+ %broadcast.258674.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218667_1_clone_1), dimensions={}
+ %add.252221.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123665.3.clone.1, %broadcast.258674.5.clone.1)
+ %add.252222.5.clone.1 = u32[1280,1280]{1,0} add(%add.252220.7.clone.1, %add.252221.5.clone.1)
+ %shift-left.111292.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252221.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117572.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252221.5.clone.1, %broadcast.244416.5760)
+ %or.117103.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111292.9.clone.1, %shift-right-logical.117572.9.clone.1)
+ %xor.123666.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252222.5.clone.1, %or.117103.7.clone.1)
+ %add.252224.3.clone.1 = u32[1280,1280]{1,0} add(%add.252222.5.clone.1, %xor.123666.5.clone.1)
+ %shift-left.111293.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123666.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117573.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123666.5.clone.1, %broadcast.244429.2304)
+ %or.117105.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111293.9.clone.1, %shift-right-logical.117573.9.clone.1)
+ %xor.123667.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252224.3.clone.1, %or.117105.7.clone.1)
+ %add.252225.3.clone.1 = u32[1280,1280]{1,0} add(%add.252224.3.clone.1, %xor.123667.5.clone.1)
+ %shift-left.111294.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123667.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117574.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123667.5.clone.1, %broadcast.244430.4608)
+ %or.117106.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111294.9.clone.1, %shift-right-logical.117574.9.clone.1)
+ %xor.123668.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252225.3.clone.1, %or.117106.7.clone.1)
+ %add.252226.3.clone.1 = u32[1280,1280]{1,0} add(%add.252225.3.clone.1, %xor.123668.5.clone.1)
+ %add.252227.7.clone.1 = u32[1280,1280]{1,0} add(%add.252226.3.clone.1, %broadcast.258644.113.clone.1)
+ %shift-left.111295.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123668.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117575.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123668.5.clone.1, %broadcast.244434.2816)
+ %or.117107.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111295.11.clone.1, %shift-right-logical.117575.11.clone.1)
+ %xor.123669.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252226.3.clone.1, %or.117107.9.clone.1)
+ %constant_218668_1_clone_1 = u32[] constant(1819972518)
+ %broadcast.258684.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218668_1_clone_1), dimensions={}
+ %add.252229.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123669.7.clone.1, %broadcast.258684.5.clone.1)
+ %add.252230.5.clone.1 = u32[1280,1280]{1,0} add(%add.252227.7.clone.1, %add.252229.5.clone.1)
+ %shift-left.111296.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252229.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117576.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252229.5.clone.1, %broadcast.244415.6016)
+ %or.117108.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111296.9.clone.1, %shift-right-logical.117576.9.clone.1)
+ %xor.123670.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252230.5.clone.1, %or.117108.7.clone.1)
+ %add.252231.3.clone.1 = u32[1280,1280]{1,0} add(%add.252230.5.clone.1, %xor.123670.5.clone.1)
+ %shift-left.111297.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123670.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117577.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123670.5.clone.1, %broadcast.244417.5760)
+ %or.117110.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111297.9.clone.1, %shift-right-logical.117577.9.clone.1)
+ %xor.123671.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252231.3.clone.1, %or.117110.7.clone.1)
+ %add.252232.3.clone.1 = u32[1280,1280]{1,0} add(%add.252231.3.clone.1, %xor.123671.5.clone.1)
+ %shift-left.111298.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123671.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117578.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123671.5.clone.1, %broadcast.244419.4352)
+ %or.117111.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111298.5.clone.1, %shift-right-logical.117578.5.clone.1)
+ %xor.123672.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252232.3.clone.1, %or.117111.3.clone.1)
+ %add.252234.3.clone.1 = u32[1280,1280]{1,0} add(%add.252232.3.clone.1, %xor.123672.3.clone.1)
+ %add.252235.17.clone.1 = u32[1280,1280]{1,0} add(%add.252234.3.clone.1, %broadcast.258661.24.clone.1)
+ %shift-left.111299.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123672.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117579.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123672.3.clone.1, %broadcast.244418.4352)
+ %or.117112.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111299.5.clone.1, %shift-right-logical.117579.5.clone.1)
+ %xor.123673.15.clone.1 = u32[1280,1280]{1,0} xor(%add.252234.3.clone.1, %or.117112.3.clone.1)
+ %constant_218669_1_clone_1 = u32[] constant(1759838119)
+ %broadcast.258694.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218669_1_clone_1), dimensions={}
+ %add.252236.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123673.15.clone.1, %broadcast.258694.19.clone.1)
+ %xor.123675.17.clone.1 = u32[1280,1280]{1,0} xor(%add.252235.17.clone.1, %add.252236.19.clone.1)
+ %shift-right-logical.117580.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123675.17.clone.1, %broadcast.244468.1920)
+ %or.117113.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117580.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5831.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117113.13.clone.1)
+ %add.252237.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5831.11.clone.1, %broadcast.244470.1152)
+ %multiply.27250.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252237.9.clone.1, %broadcast.244471.896)
+ %add.252238.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27250.7.clone.1, %broadcast.244408.1024)
+ %maximum.3763.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.252238.5.clone.1)
+ %abs.1591.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3763.3.clone.1)
+ %compare.7344.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1591.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.27251.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3763.3.clone.1, %broadcast.244476.1152)
+ %negate.4687.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3763.3.clone.1)
+ %multiply.27252.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3763.3.clone.1, %negate.4687.5.clone.1)
+ %log-plus-one.1591.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27252.5.clone.1)
+ %negate.4688.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1591.3.clone.1)
+ %compare.7345.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4688.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21608.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21609.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21610.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21611.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21612.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21613.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21614.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21615.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21616.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.252240.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4688.4.clone.1, %broadcast.244496.640)
+ %sqrt.1591.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4688.4.clone.1)
+ %add.252244.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1591.5.clone.1, %broadcast.244498.640)
+ %select.21617.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7345.3.clone.1, %add.252240.5.clone.1, %add.252244.5.clone.1)
+ %multiply.27253.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21616.3.clone.1, %select.21617.3.clone.1)
+ %add.252245.1.clone.1 = f32[1280,1280]{1,0} add(%select.21615.3.clone.1, %multiply.27253.1.clone.1)
+ %multiply.27254.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252245.1.clone.1, %select.21617.3.clone.1)
+ %add.252246.1.clone.1 = f32[1280,1280]{1,0} add(%select.21614.3.clone.1, %multiply.27254.1.clone.1)
+ %multiply.27255.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252246.1.clone.1, %select.21617.3.clone.1)
+ %add.252247.1.clone.1 = f32[1280,1280]{1,0} add(%select.21613.3.clone.1, %multiply.27255.1.clone.1)
+ %multiply.27256.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252247.1.clone.1, %select.21617.3.clone.1)
+ %add.252249.1.clone.1 = f32[1280,1280]{1,0} add(%select.21612.3.clone.1, %multiply.27256.1.clone.1)
+ %multiply.27257.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252249.1.clone.1, %select.21617.3.clone.1)
+ %add.252250.3.clone.1 = f32[1280,1280]{1,0} add(%select.21611.5.clone.1, %multiply.27257.1.clone.1)
+ %multiply.27258.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252250.3.clone.1, %select.21617.3.clone.1)
+ %add.252251.3.clone.1 = f32[1280,1280]{1,0} add(%select.21610.5.clone.1, %multiply.27258.1.clone.1)
+ %multiply.27259.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252251.3.clone.1, %select.21617.3.clone.1)
+ %add.252252.9.clone.1 = f32[1280,1280]{1,0} add(%select.21609.11.clone.1, %multiply.27259.7.clone.1)
+ %multiply.27260.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252252.9.clone.1, %select.21617.3.clone.1)
+ %add.252254.7.clone.1 = f32[1280,1280]{1,0} add(%select.21608.7.clone.1, %multiply.27260.7.clone.1)
+ %multiply.27261.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252254.7.clone.1, %maximum.3763.3.clone.1)
+ %select.21618.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7344.3.clone.1, %multiply.27251.9.clone.1, %multiply.27261.7.clone.1)
+ %multiply.27262.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21618.7.clone.1, %broadcast.244500.640)
+ %clamp.1235.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27262.5.clone.1, %broadcast.244501.384)
+ %multiply.27263.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1235.3.clone.1, %broadcast.244502.1)
+ %constant_171958_1_clone_1 = u32[] constant(3820892366)
+ %broadcast.250405.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_171958_1_clone_1), dimensions={}
+ %add.247482.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.250405.44.clone.1)
+ %constant_171965_1_clone_1 = u32[] constant(3482744791)
+ %broadcast.250406.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_171965_1_clone_1), dimensions={}
+ %add.247484.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.250406.113.clone.1)
+ %add.247485.35.clone.1 = u32[1280,1280]{1,0} add(%add.247482.37.clone.1, %add.247484.99.clone.1)
+ %shift-left.109224.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247484.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.115393.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247484.99.clone.1, %broadcast.244415.6016)
+ %or.114913.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109224.31.clone.1, %shift-right-logical.115393.29.clone.1)
+ %xor.121460.27.clone.1 = u32[1280,1280]{1,0} xor(%add.247485.35.clone.1, %or.114913.29.clone.1)
+ %add.247486.5.clone.1 = u32[1280,1280]{1,0} add(%add.247485.35.clone.1, %xor.121460.27.clone.1)
+ %shift-left.109225.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121460.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.115394.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121460.27.clone.1, %broadcast.244417.5760)
+ %or.114914.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109225.9.clone.1, %shift-right-logical.115394.9.clone.1)
+ %xor.121461.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247486.5.clone.1, %or.114914.7.clone.1)
+ %add.247487.3.clone.1 = u32[1280,1280]{1,0} add(%add.247486.5.clone.1, %xor.121461.5.clone.1)
+ %shift-left.109226.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121461.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.115395.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121461.5.clone.1, %broadcast.244419.4352)
+ %or.114915.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109226.5.clone.1, %shift-right-logical.115395.5.clone.1)
+ %xor.121462.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247487.3.clone.1, %or.114915.3.clone.1)
+ %add.247488.3.clone.1 = u32[1280,1280]{1,0} add(%add.247487.3.clone.1, %xor.121462.3.clone.1)
+ %add.247489.7.clone.1 = u32[1280,1280]{1,0} add(%add.247488.3.clone.1, %broadcast.250406.113.clone.1)
+ %shift-left.109228.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121462.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.115396.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121462.3.clone.1, %broadcast.244418.4352)
+ %or.114916.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109228.5.clone.1, %shift-right-logical.115396.5.clone.1)
+ %xor.121463.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247488.3.clone.1, %or.114916.3.clone.1)
+ %constant_218141_1_clone_1 = u32[] constant(939089092)
+ %broadcast.250416.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218141_1_clone_1), dimensions={}
+ %add.247490.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121463.3.clone.1, %broadcast.250416.5.clone.1)
+ %add.247491.5.clone.1 = u32[1280,1280]{1,0} add(%add.247489.7.clone.1, %add.247490.5.clone.1)
+ %shift-left.109229.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247490.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.115397.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247490.5.clone.1, %broadcast.244416.5760)
+ %or.114917.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109229.9.clone.1, %shift-right-logical.115397.9.clone.1)
+ %xor.121464.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247491.5.clone.1, %or.114917.7.clone.1)
+ %add.247492.3.clone.1 = u32[1280,1280]{1,0} add(%add.247491.5.clone.1, %xor.121464.5.clone.1)
+ %shift-left.109230.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121464.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.115398.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121464.5.clone.1, %broadcast.244429.2304)
+ %or.114918.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109230.9.clone.1, %shift-right-logical.115398.9.clone.1)
+ %xor.121465.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247492.3.clone.1, %or.114918.7.clone.1)
+ %add.247493.3.clone.1 = u32[1280,1280]{1,0} add(%add.247492.3.clone.1, %xor.121465.5.clone.1)
+ %shift-left.109231.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121465.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.115399.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121465.5.clone.1, %broadcast.244430.4608)
+ %or.114919.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109231.9.clone.1, %shift-right-logical.115399.9.clone.1)
+ %xor.121466.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247493.3.clone.1, %or.114919.7.clone.1)
+ %add.247494.3.clone.1 = u32[1280,1280]{1,0} add(%add.247493.3.clone.1, %xor.121466.5.clone.1)
+ %constant_171967_1_clone_1 = u32[] constant(939089091)
+ %broadcast.250423.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_171967_1_clone_1), dimensions={}
+ %add.247495.7.clone.1 = u32[1280,1280]{1,0} add(%add.247494.3.clone.1, %broadcast.250423.24.clone.1)
+ %shift-left.109233.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121466.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.115400.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121466.5.clone.1, %broadcast.244434.2816)
+ %or.114920.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109233.11.clone.1, %shift-right-logical.115400.11.clone.1)
+ %xor.121467.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247494.3.clone.1, %or.114920.9.clone.1)
+ %constant_218142_1_clone_1 = u32[] constant(3820892368)
+ %broadcast.250426.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218142_1_clone_1), dimensions={}
+ %add.247496.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121467.7.clone.1, %broadcast.250426.5.clone.1)
+ %add.247497.5.clone.1 = u32[1280,1280]{1,0} add(%add.247495.7.clone.1, %add.247496.5.clone.1)
+ %shift-left.109234.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247496.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.115401.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247496.5.clone.1, %broadcast.244415.6016)
+ %or.114922.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109234.9.clone.1, %shift-right-logical.115401.9.clone.1)
+ %xor.121468.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247497.5.clone.1, %or.114922.7.clone.1)
+ %add.247498.3.clone.1 = u32[1280,1280]{1,0} add(%add.247497.5.clone.1, %xor.121468.5.clone.1)
+ %shift-left.109235.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121468.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.115402.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121468.5.clone.1, %broadcast.244417.5760)
+ %or.114923.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109235.9.clone.1, %shift-right-logical.115402.9.clone.1)
+ %xor.121469.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247498.3.clone.1, %or.114923.7.clone.1)
+ %add.247499.3.clone.1 = u32[1280,1280]{1,0} add(%add.247498.3.clone.1, %xor.121469.5.clone.1)
+ %shift-left.109236.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121469.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.115403.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121469.5.clone.1, %broadcast.244419.4352)
+ %or.114925.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109236.7.clone.1, %shift-right-logical.115403.7.clone.1)
+ %xor.121470.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247499.3.clone.1, %or.114925.5.clone.1)
+ %add.247500.3.clone.1 = u32[1280,1280]{1,0} add(%add.247499.3.clone.1, %xor.121470.3.clone.1)
+ %add.247501.7.clone.1 = u32[1280,1280]{1,0} add(%add.247500.3.clone.1, %broadcast.250405.44.clone.1)
+ %shift-left.109238.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121470.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.115404.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121470.3.clone.1, %broadcast.244418.4352)
+ %or.114926.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109238.7.clone.1, %shift-right-logical.115404.7.clone.1)
+ %xor.121471.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247500.3.clone.1, %or.114926.5.clone.1)
+ %constant_218143_1_clone_1 = u32[] constant(3482744794)
+ %broadcast.250436.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218143_1_clone_1), dimensions={}
+ %add.247502.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121471.3.clone.1, %broadcast.250436.5.clone.1)
+ %add.247503.5.clone.1 = u32[1280,1280]{1,0} add(%add.247501.7.clone.1, %add.247502.5.clone.1)
+ %shift-left.109239.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247502.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.115405.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247502.5.clone.1, %broadcast.244416.5760)
+ %or.114927.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109239.9.clone.1, %shift-right-logical.115405.9.clone.1)
+ %xor.121472.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247503.5.clone.1, %or.114927.7.clone.1)
+ %add.247504.3.clone.1 = u32[1280,1280]{1,0} add(%add.247503.5.clone.1, %xor.121472.5.clone.1)
+ %shift-left.109240.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121472.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.115406.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121472.5.clone.1, %broadcast.244429.2304)
+ %or.114928.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109240.9.clone.1, %shift-right-logical.115406.9.clone.1)
+ %xor.121473.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247504.3.clone.1, %or.114928.7.clone.1)
+ %add.247505.3.clone.1 = u32[1280,1280]{1,0} add(%add.247504.3.clone.1, %xor.121473.5.clone.1)
+ %shift-left.109241.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121473.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.115407.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121473.5.clone.1, %broadcast.244430.4608)
+ %or.114929.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109241.9.clone.1, %shift-right-logical.115407.9.clone.1)
+ %xor.121474.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247505.3.clone.1, %or.114929.7.clone.1)
+ %add.247506.3.clone.1 = u32[1280,1280]{1,0} add(%add.247505.3.clone.1, %xor.121474.5.clone.1)
+ %add.247507.7.clone.1 = u32[1280,1280]{1,0} add(%add.247506.3.clone.1, %broadcast.250406.113.clone.1)
+ %shift-left.109243.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121474.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.115408.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121474.5.clone.1, %broadcast.244434.2816)
+ %or.114930.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109243.11.clone.1, %shift-right-logical.115408.11.clone.1)
+ %xor.121475.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247506.3.clone.1, %or.114930.9.clone.1)
+ %constant_218144_1_clone_1 = u32[] constant(939089095)
+ %broadcast.250446.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218144_1_clone_1), dimensions={}
+ %add.247508.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121475.7.clone.1, %broadcast.250446.5.clone.1)
+ %add.247509.5.clone.1 = u32[1280,1280]{1,0} add(%add.247507.7.clone.1, %add.247508.5.clone.1)
+ %shift-left.109244.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247508.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.115409.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247508.5.clone.1, %broadcast.244415.6016)
+ %or.114931.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109244.9.clone.1, %shift-right-logical.115409.9.clone.1)
+ %xor.121476.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247509.5.clone.1, %or.114931.7.clone.1)
+ %add.247510.3.clone.1 =
u32[1280,1280]{1,0} add(%add.247509.5.clone.1, %xor.121476.5.clone.1) + %shift-left.109245.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121476.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115410.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121476.5.clone.1, %broadcast.244417.5760) + %or.114932.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109245.9.clone.1, %shift-right-logical.115410.9.clone.1) + %xor.121477.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247510.3.clone.1, %or.114932.7.clone.1) + %add.247511.3.clone.1 = u32[1280,1280]{1,0} add(%add.247510.3.clone.1, %xor.121477.5.clone.1) + %shift-left.109246.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121477.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115411.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121477.5.clone.1, %broadcast.244419.4352) + %or.114933.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109246.5.clone.1, %shift-right-logical.115411.5.clone.1) + %xor.121478.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247511.3.clone.1, %or.114933.3.clone.1) + %add.247512.3.clone.1 = u32[1280,1280]{1,0} add(%add.247511.3.clone.1, %xor.121478.3.clone.1) + %add.247513.17.clone.1 = u32[1280,1280]{1,0} add(%add.247512.3.clone.1, %broadcast.250423.24.clone.1) + %shift-left.109247.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121478.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115412.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121478.3.clone.1, %broadcast.244418.4352) + %or.114934.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109247.5.clone.1, %shift-right-logical.115412.5.clone.1) + %xor.121479.15.clone.1 = u32[1280,1280]{1,0} xor(%add.247512.3.clone.1, %or.114934.3.clone.1) + %constant_218145_1_clone_1 = u32[] constant(3820892371) + %broadcast.250456.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218145_1_clone_1), dimensions={} + %add.247514.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121479.15.clone.1, %broadcast.250456.19.clone.1) + %xor.121480.17.clone.1 = u32[1280,1280]{1,0} xor(%add.247513.17.clone.1, %add.247514.19.clone.1) + %shift-right-logical.115413.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121480.17.clone.1, %broadcast.244468.1920) + %or.114935.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115413.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5737.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114935.13.clone.1) + %add.247516.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5737.11.clone.1, %broadcast.244470.1152) + %multiply.26288.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247516.9.clone.1, %broadcast.244471.896) + %add.247519.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26288.7.clone.1, %broadcast.244408.1024) + %maximum.3669.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.247519.5.clone.1) + %abs.1529.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3669.3.clone.1) + %compare.7206.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1529.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26289.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3669.3.clone.1, %broadcast.244476.1152) + %negate.4563.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3669.3.clone.1) + %multiply.26290.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3669.3.clone.1, %negate.4563.5.clone.1) + %log-plus-one.1529.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26290.5.clone.1) + %negate.4564.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1529.3.clone.1) + %compare.7207.3.clone.1 = 
pred[1280,1280]{1,0} compare(%negate.4564.4.clone.1, %broadcast.244477.384), direction=LT + %select.20884.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20885.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20886.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20887.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20888.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20889.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20890.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20891.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20892.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.247520.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4564.4.clone.1, %broadcast.244496.640) + %sqrt.1529.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4564.4.clone.1) + %add.247521.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1529.5.clone.1, %broadcast.244498.640) + %select.20893.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7207.3.clone.1, %add.247520.5.clone.1, %add.247521.5.clone.1) + %multiply.26291.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20892.3.clone.1, %select.20893.3.clone.1) + %add.247522.1.clone.1 = f32[1280,1280]{1,0} add(%select.20891.3.clone.1, %multiply.26291.1.clone.1) + %multiply.26292.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247522.1.clone.1, %select.20893.3.clone.1) + %add.247524.1.clone.1 = f32[1280,1280]{1,0} add(%select.20890.3.clone.1, %multiply.26292.1.clone.1) + %multiply.26293.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247524.1.clone.1, %select.20893.3.clone.1) + %add.247525.1.clone.1 = f32[1280,1280]{1,0} add(%select.20889.3.clone.1, %multiply.26293.1.clone.1) + %multiply.26294.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247525.1.clone.1, %select.20893.3.clone.1) + %add.247526.1.clone.1 = f32[1280,1280]{1,0} add(%select.20888.3.clone.1, %multiply.26294.1.clone.1) + %multiply.26295.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247526.1.clone.1, %select.20893.3.clone.1) + %add.247527.3.clone.1 = f32[1280,1280]{1,0} add(%select.20887.5.clone.1, %multiply.26295.1.clone.1) + %multiply.26296.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247527.3.clone.1, %select.20893.3.clone.1) + %add.247529.3.clone.1 = f32[1280,1280]{1,0} add(%select.20886.5.clone.1, %multiply.26296.1.clone.1) + %multiply.26297.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247529.3.clone.1, %select.20893.3.clone.1) + %add.247530.9.clone.1 = f32[1280,1280]{1,0} add(%select.20885.11.clone.1, %multiply.26297.7.clone.1) + %multiply.26298.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247530.9.clone.1, %select.20893.3.clone.1) + %add.247531.7.clone.1 = f32[1280,1280]{1,0} add(%select.20884.7.clone.1, %multiply.26298.7.clone.1) + %multiply.26299.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247531.7.clone.1, %maximum.3669.3.clone.1) + %select.20894.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7206.3.clone.1, %multiply.26289.9.clone.1, %multiply.26299.7.clone.1) + 
%multiply.26300.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20894.7.clone.1, %broadcast.244500.640) + %clamp.1173.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26300.5.clone.1, %broadcast.244501.384) + %multiply.26301.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1173.3.clone.1, %broadcast.244502.1) + %constant_184351_1_clone_1 = u32[] constant(2451525708) + %broadcast.255768.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_184351_1_clone_1), dimensions={} + %add.250547.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.255768.44.clone.1) + %constant_184361_1_clone_1 = u32[] constant(2332708934) + %broadcast.255769.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_184361_1_clone_1), dimensions={} + %add.250549.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.255769.113.clone.1) + %add.250552.35.clone.1 = u32[1280,1280]{1,0} add(%add.250547.37.clone.1, %add.250549.99.clone.1) + %shift-left.110547.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250549.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116801.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250549.99.clone.1, %broadcast.244415.6016) + %or.116316.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110547.31.clone.1, %shift-right-logical.116801.29.clone.1) + %xor.122886.27.clone.1 = u32[1280,1280]{1,0} xor(%add.250552.35.clone.1, %or.116316.29.clone.1) + %add.250553.5.clone.1 = u32[1280,1280]{1,0} add(%add.250552.35.clone.1, %xor.122886.27.clone.1) + %shift-left.110548.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122886.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116802.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122886.27.clone.1, %broadcast.244417.5760) + %or.116317.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110548.9.clone.1, %shift-right-logical.116802.9.clone.1) + %xor.122887.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250553.5.clone.1, %or.116317.7.clone.1) + %add.250554.3.clone.1 = u32[1280,1280]{1,0} add(%add.250553.5.clone.1, %xor.122887.5.clone.1) + %shift-left.110549.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122887.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116803.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122887.5.clone.1, %broadcast.244419.4352) + %or.116318.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110549.5.clone.1, %shift-right-logical.116803.5.clone.1) + %xor.122888.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250554.3.clone.1, %or.116318.3.clone.1) + %add.250555.3.clone.1 = u32[1280,1280]{1,0} add(%add.250554.3.clone.1, %xor.122888.3.clone.1) + %add.250556.7.clone.1 = u32[1280,1280]{1,0} add(%add.250555.3.clone.1, %broadcast.255769.113.clone.1) + %shift-left.110551.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122888.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116804.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122888.3.clone.1, %broadcast.244418.4352) + %or.116319.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110551.5.clone.1, %shift-right-logical.116804.5.clone.1) + %xor.122890.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250555.3.clone.1, %or.116319.3.clone.1) + %constant_218479_1_clone_1 = u32[] constant(46406609) + %broadcast.255779.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218479_1_clone_1), dimensions={} + %add.250557.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122890.3.clone.1, %broadcast.255779.5.clone.1) + %add.250558.5.clone.1 = u32[1280,1280]{1,0} add(%add.250556.7.clone.1, %add.250557.5.clone.1) + 
%shift-left.110552.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250557.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116805.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250557.5.clone.1, %broadcast.244416.5760) + %or.116320.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110552.9.clone.1, %shift-right-logical.116805.9.clone.1) + %xor.122891.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250558.5.clone.1, %or.116320.7.clone.1) + %add.250559.3.clone.1 = u32[1280,1280]{1,0} add(%add.250558.5.clone.1, %xor.122891.5.clone.1) + %shift-left.110553.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122891.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116806.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122891.5.clone.1, %broadcast.244429.2304) + %or.116321.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110553.9.clone.1, %shift-right-logical.116806.9.clone.1) + %xor.122892.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250559.3.clone.1, %or.116321.7.clone.1) + %add.250560.3.clone.1 = u32[1280,1280]{1,0} add(%add.250559.3.clone.1, %xor.122892.5.clone.1) + %shift-left.110554.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122892.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116807.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122892.5.clone.1, %broadcast.244430.4608) + %or.116322.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110554.9.clone.1, %shift-right-logical.116807.9.clone.1) + %xor.122893.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250560.3.clone.1, %or.116322.7.clone.1) + %add.250561.3.clone.1 = u32[1280,1280]{1,0} add(%add.250560.3.clone.1, %xor.122893.5.clone.1) + %constant_184365_1_clone_1 = u32[] constant(46406608) + %broadcast.255786.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_184365_1_clone_1), dimensions={} + %add.250562.7.clone.1 = u32[1280,1280]{1,0} add(%add.250561.3.clone.1, %broadcast.255786.24.clone.1) + %shift-left.110555.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122893.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116808.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122893.5.clone.1, %broadcast.244434.2816) + %or.116323.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110555.11.clone.1, %shift-right-logical.116808.11.clone.1) + %xor.122895.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250561.3.clone.1, %or.116323.9.clone.1) + %constant_218480_1_clone_1 = u32[] constant(2451525710) + %broadcast.255789.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218480_1_clone_1), dimensions={} + %add.250563.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122895.7.clone.1, %broadcast.255789.5.clone.1) + %add.250564.5.clone.1 = u32[1280,1280]{1,0} add(%add.250562.7.clone.1, %add.250563.5.clone.1) + %shift-left.110556.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250563.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116809.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250563.5.clone.1, %broadcast.244415.6016) + %or.116325.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110556.9.clone.1, %shift-right-logical.116809.9.clone.1) + %xor.122896.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250564.5.clone.1, %or.116325.7.clone.1) + %add.250565.3.clone.1 = u32[1280,1280]{1,0} add(%add.250564.5.clone.1, %xor.122896.5.clone.1) + %shift-left.110557.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122896.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116810.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122896.5.clone.1, %broadcast.244417.5760) + %or.116326.7.clone.1 = 
u32[1280,1280]{1,0} or(%shift-left.110557.9.clone.1, %shift-right-logical.116810.9.clone.1) + %xor.122897.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250565.3.clone.1, %or.116326.7.clone.1) + %add.250566.3.clone.1 = u32[1280,1280]{1,0} add(%add.250565.3.clone.1, %xor.122897.5.clone.1) + %shift-left.110558.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122897.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116811.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122897.5.clone.1, %broadcast.244419.4352) + %or.116327.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110558.7.clone.1, %shift-right-logical.116811.7.clone.1) + %xor.122898.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250566.3.clone.1, %or.116327.5.clone.1) + %add.250567.3.clone.1 = u32[1280,1280]{1,0} add(%add.250566.3.clone.1, %xor.122898.3.clone.1) + %add.250568.7.clone.1 = u32[1280,1280]{1,0} add(%add.250567.3.clone.1, %broadcast.255768.44.clone.1) + %shift-left.110559.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122898.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116812.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122898.3.clone.1, %broadcast.244418.4352) + %or.116328.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110559.7.clone.1, %shift-right-logical.116812.7.clone.1) + %xor.122900.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250567.3.clone.1, %or.116328.5.clone.1) + %constant_218481_1_clone_1 = u32[] constant(2332708937) + %broadcast.255799.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218481_1_clone_1), dimensions={} + %add.250569.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122900.3.clone.1, %broadcast.255799.5.clone.1) + %add.250570.5.clone.1 = u32[1280,1280]{1,0} add(%add.250568.7.clone.1, %add.250569.5.clone.1) + %shift-left.110561.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250569.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116813.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250569.5.clone.1, %broadcast.244416.5760) + %or.116330.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110561.9.clone.1, %shift-right-logical.116813.9.clone.1) + %xor.122901.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250570.5.clone.1, %or.116330.7.clone.1) + %add.250571.3.clone.1 = u32[1280,1280]{1,0} add(%add.250570.5.clone.1, %xor.122901.5.clone.1) + %shift-left.110562.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122901.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116815.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122901.5.clone.1, %broadcast.244429.2304) + %or.116331.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110562.9.clone.1, %shift-right-logical.116815.9.clone.1) + %xor.122902.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250571.3.clone.1, %or.116331.7.clone.1) + %add.250572.3.clone.1 = u32[1280,1280]{1,0} add(%add.250571.3.clone.1, %xor.122902.5.clone.1) + %shift-left.110563.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122902.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116816.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122902.5.clone.1, %broadcast.244430.4608) + %or.116332.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110563.9.clone.1, %shift-right-logical.116816.9.clone.1) + %xor.122903.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250572.3.clone.1, %or.116332.7.clone.1) + %add.250573.3.clone.1 = u32[1280,1280]{1,0} add(%add.250572.3.clone.1, %xor.122903.5.clone.1) + %add.250574.7.clone.1 = u32[1280,1280]{1,0} add(%add.250573.3.clone.1, %broadcast.255769.113.clone.1) + %shift-left.110564.11.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.122903.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116817.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122903.5.clone.1, %broadcast.244434.2816) + %or.116333.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110564.11.clone.1, %shift-right-logical.116817.11.clone.1) + %xor.122905.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250573.3.clone.1, %or.116333.9.clone.1) + %constant_218482_1_clone_1 = u32[] constant(46406612) + %broadcast.255809.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218482_1_clone_1), dimensions={} + %add.250575.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122905.7.clone.1, %broadcast.255809.5.clone.1) + %add.250576.5.clone.1 = u32[1280,1280]{1,0} add(%add.250574.7.clone.1, %add.250575.5.clone.1) + %shift-left.110566.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250575.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116818.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250575.5.clone.1, %broadcast.244415.6016) + %or.116335.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110566.9.clone.1, %shift-right-logical.116818.9.clone.1) + %xor.122906.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250576.5.clone.1, %or.116335.7.clone.1) + %add.250577.3.clone.1 = u32[1280,1280]{1,0} add(%add.250576.5.clone.1, %xor.122906.5.clone.1) + %shift-left.110567.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122906.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116819.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122906.5.clone.1, %broadcast.244417.5760) + %or.116336.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110567.9.clone.1, %shift-right-logical.116819.9.clone.1) + %xor.122907.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250577.3.clone.1, %or.116336.7.clone.1) + %add.250578.3.clone.1 = u32[1280,1280]{1,0} add(%add.250577.3.clone.1, %xor.122907.5.clone.1) + %shift-left.110568.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122907.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116820.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122907.5.clone.1, %broadcast.244419.4352) + %or.116337.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110568.5.clone.1, %shift-right-logical.116820.5.clone.1) + %xor.122908.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250578.3.clone.1, %or.116337.3.clone.1) + %add.250579.3.clone.1 = u32[1280,1280]{1,0} add(%add.250578.3.clone.1, %xor.122908.3.clone.1) + %add.250580.17.clone.1 = u32[1280,1280]{1,0} add(%add.250579.3.clone.1, %broadcast.255786.24.clone.1) + %shift-left.110569.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122908.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116822.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122908.3.clone.1, %broadcast.244418.4352) + %or.116338.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110569.5.clone.1, %shift-right-logical.116822.5.clone.1) + %xor.122909.15.clone.1 = u32[1280,1280]{1,0} xor(%add.250579.3.clone.1, %or.116338.3.clone.1) + %constant_218483_1_clone_1 = u32[] constant(2451525713) + %broadcast.255819.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218483_1_clone_1), dimensions={} + %add.250581.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122909.15.clone.1, %broadcast.255819.19.clone.1) + %xor.122910.17.clone.1 = u32[1280,1280]{1,0} xor(%add.250580.17.clone.1, %add.250581.19.clone.1) + %shift-right-logical.116823.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122910.17.clone.1, %broadcast.244468.1920) + %or.116340.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116823.15.clone.1, 
%broadcast.244469.1664) + %bitcast-convert.5798.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116340.13.clone.1) + %add.250582.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5798.11.clone.1, %broadcast.244470.1152) + %multiply.26918.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250582.9.clone.1, %broadcast.244471.896) + %add.250583.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26918.7.clone.1, %broadcast.244408.1024) + %maximum.3730.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.250583.5.clone.1) + %abs.1570.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3730.3.clone.1) + %compare.7296.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1570.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26919.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3730.3.clone.1, %broadcast.244476.1152) + %negate.4645.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3730.3.clone.1) + %multiply.26920.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3730.3.clone.1, %negate.4645.5.clone.1) + %log-plus-one.1570.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26920.5.clone.1) + %negate.4646.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1570.3.clone.1) + %compare.7297.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4646.4.clone.1, %broadcast.244477.384), direction=LT + %select.21356.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21357.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21358.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21359.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21360.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21361.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21362.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21363.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21364.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.250584.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4646.4.clone.1, %broadcast.244496.640) + %sqrt.1570.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4646.4.clone.1) + %add.250585.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1570.5.clone.1, %broadcast.244498.640) + %select.21365.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7297.3.clone.1, %add.250584.5.clone.1, %add.250585.5.clone.1) + %multiply.26921.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21364.3.clone.1, %select.21365.3.clone.1) + %add.250586.1.clone.1 = f32[1280,1280]{1,0} add(%select.21363.3.clone.1, %multiply.26921.1.clone.1) + %multiply.26922.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250586.1.clone.1, %select.21365.3.clone.1) + %add.250587.1.clone.1 = f32[1280,1280]{1,0} add(%select.21362.3.clone.1, %multiply.26922.1.clone.1) + %multiply.26923.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250587.1.clone.1, %select.21365.3.clone.1) + %add.250588.1.clone.1 = f32[1280,1280]{1,0} add(%select.21361.3.clone.1, %multiply.26923.1.clone.1) + %multiply.26924.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250588.1.clone.1, 
%select.21365.3.clone.1) + %add.250589.1.clone.1 = f32[1280,1280]{1,0} add(%select.21360.3.clone.1, %multiply.26924.1.clone.1) + %multiply.26925.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250589.1.clone.1, %select.21365.3.clone.1) + %add.250590.3.clone.1 = f32[1280,1280]{1,0} add(%select.21359.5.clone.1, %multiply.26925.1.clone.1) + %multiply.26926.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250590.3.clone.1, %select.21365.3.clone.1) + %add.250591.3.clone.1 = f32[1280,1280]{1,0} add(%select.21358.5.clone.1, %multiply.26926.1.clone.1) + %multiply.26927.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250591.3.clone.1, %select.21365.3.clone.1) + %add.250592.9.clone.1 = f32[1280,1280]{1,0} add(%select.21357.11.clone.1, %multiply.26927.7.clone.1) + %multiply.26928.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250592.9.clone.1, %select.21365.3.clone.1) + %add.250593.7.clone.1 = f32[1280,1280]{1,0} add(%select.21356.7.clone.1, %multiply.26928.7.clone.1) + %multiply.26930.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250593.7.clone.1, %maximum.3730.3.clone.1) + %select.21366.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7296.3.clone.1, %multiply.26919.9.clone.1, %multiply.26930.7.clone.1) + %multiply.26931.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21366.7.clone.1, %broadcast.244500.640) + %clamp.1214.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26931.5.clone.1, %broadcast.244501.384) + %multiply.26932.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1214.3.clone.1, %broadcast.244502.1) + %constant_171715_1_clone_1 = u32[] constant(4125501039) + %broadcast.250289.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_171715_1_clone_1), dimensions={} + %add.247437.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.250289.44.clone.1) + %constant_171722_1_clone_1 = u32[] constant(1840320350) + %broadcast.250291.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_171722_1_clone_1), dimensions={} + %add.247438.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.250291.113.clone.1) + %add.247439.35.clone.1 = u32[1280,1280]{1,0} add(%add.247437.37.clone.1, %add.247438.99.clone.1) + %shift-left.109200.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247438.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115372.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247438.99.clone.1, %broadcast.244415.6016) + %or.114892.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109200.31.clone.1, %shift-right-logical.115372.29.clone.1) + %xor.121439.27.clone.1 = u32[1280,1280]{1,0} xor(%add.247439.35.clone.1, %or.114892.29.clone.1) + %add.247440.5.clone.1 = u32[1280,1280]{1,0} add(%add.247439.35.clone.1, %xor.121439.27.clone.1) + %shift-left.109201.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121439.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115373.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121439.27.clone.1, %broadcast.244417.5760) + %or.114893.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109201.9.clone.1, %shift-right-logical.115373.9.clone.1) + %xor.121440.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247440.5.clone.1, %or.114893.7.clone.1) + %add.247441.3.clone.1 = u32[1280,1280]{1,0} add(%add.247440.5.clone.1, %xor.121440.5.clone.1) + %shift-left.109203.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121440.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115374.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121440.5.clone.1, %broadcast.244419.4352) + %or.114894.3.clone.1 = 
u32[1280,1280]{1,0} or(%shift-left.109203.5.clone.1, %shift-right-logical.115374.5.clone.1) + %xor.121441.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247441.3.clone.1, %or.114894.3.clone.1) + %add.247442.3.clone.1 = u32[1280,1280]{1,0} add(%add.247441.3.clone.1, %xor.121441.3.clone.1) + %add.247443.7.clone.1 = u32[1280,1280]{1,0} add(%add.247442.3.clone.1, %broadcast.250291.113.clone.1) + %shift-left.109204.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121441.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115375.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121441.3.clone.1, %broadcast.244418.4352) + %or.114895.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109204.5.clone.1, %shift-right-logical.115375.5.clone.1) + %xor.121442.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247442.3.clone.1, %or.114895.3.clone.1) + %constant_218131_1_clone_1 = u32[] constant(2206610156) + %broadcast.250308.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218131_1_clone_1), dimensions={} + %add.247444.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121442.3.clone.1, %broadcast.250308.5.clone.1) + %add.247445.5.clone.1 = u32[1280,1280]{1,0} add(%add.247443.7.clone.1, %add.247444.5.clone.1) + %shift-left.109205.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247444.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115376.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247444.5.clone.1, %broadcast.244416.5760) + %or.114896.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109205.9.clone.1, %shift-right-logical.115376.9.clone.1) + %xor.121443.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247445.5.clone.1, %or.114896.7.clone.1) + %add.247446.3.clone.1 = u32[1280,1280]{1,0} add(%add.247445.5.clone.1, %xor.121443.5.clone.1) + %shift-left.109206.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121443.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115377.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121443.5.clone.1, %broadcast.244429.2304) + %or.114897.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109206.9.clone.1, %shift-right-logical.115377.9.clone.1) + %xor.121444.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247446.3.clone.1, %or.114897.7.clone.1) + %add.247447.3.clone.1 = u32[1280,1280]{1,0} add(%add.247446.3.clone.1, %xor.121444.5.clone.1) + %shift-left.109208.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121444.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115378.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121444.5.clone.1, %broadcast.244430.4608) + %or.114898.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109208.9.clone.1, %shift-right-logical.115378.9.clone.1) + %xor.121445.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247447.3.clone.1, %or.114898.7.clone.1) + %add.247448.3.clone.1 = u32[1280,1280]{1,0} add(%add.247447.3.clone.1, %xor.121445.5.clone.1) + %constant_171724_1_clone_1 = u32[] constant(2206610155) + %broadcast.250322.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_171724_1_clone_1), dimensions={} + %add.247449.7.clone.1 = u32[1280,1280]{1,0} add(%add.247448.3.clone.1, %broadcast.250322.24.clone.1) + %shift-left.109209.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121445.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115379.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121445.5.clone.1, %broadcast.244434.2816) + %or.114899.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109209.11.clone.1, %shift-right-logical.115379.11.clone.1) + %xor.121446.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247448.3.clone.1, 
%or.114899.9.clone.1) + %constant_218133_1_clone_1 = u32[] constant(4125501041) + %broadcast.250328.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218133_1_clone_1), dimensions={} + %add.247450.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121446.7.clone.1, %broadcast.250328.5.clone.1) + %add.247451.5.clone.1 = u32[1280,1280]{1,0} add(%add.247449.7.clone.1, %add.247450.5.clone.1) + %shift-left.109210.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247450.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115380.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247450.5.clone.1, %broadcast.244415.6016) + %or.114900.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109210.9.clone.1, %shift-right-logical.115380.9.clone.1) + %xor.121447.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247451.5.clone.1, %or.114900.7.clone.1) + %add.247452.3.clone.1 = u32[1280,1280]{1,0} add(%add.247451.5.clone.1, %xor.121447.5.clone.1) + %shift-left.109211.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121447.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115381.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121447.5.clone.1, %broadcast.244417.5760) + %or.114901.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109211.9.clone.1, %shift-right-logical.115381.9.clone.1) + %xor.121448.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247452.3.clone.1, %or.114901.7.clone.1) + %add.247453.3.clone.1 = u32[1280,1280]{1,0} add(%add.247452.3.clone.1, %xor.121448.5.clone.1) + %shift-left.109213.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121448.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115382.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121448.5.clone.1, %broadcast.244419.4352) + %or.114902.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109213.7.clone.1, %shift-right-logical.115382.7.clone.1) + %xor.121449.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247453.3.clone.1, %or.114902.5.clone.1) + %add.247454.3.clone.1 = u32[1280,1280]{1,0} add(%add.247453.3.clone.1, %xor.121449.3.clone.1) + %add.247455.7.clone.1 = u32[1280,1280]{1,0} add(%add.247454.3.clone.1, %broadcast.250289.44.clone.1) + %shift-left.109214.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121449.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115383.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121449.3.clone.1, %broadcast.244418.4352) + %or.114903.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109214.7.clone.1, %shift-right-logical.115383.7.clone.1) + %xor.121450.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247454.3.clone.1, %or.114903.5.clone.1) + %constant_218135_1_clone_1 = u32[] constant(1840320353) + %broadcast.250348.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218135_1_clone_1), dimensions={} + %add.247456.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121450.3.clone.1, %broadcast.250348.5.clone.1) + %add.247457.5.clone.1 = u32[1280,1280]{1,0} add(%add.247455.7.clone.1, %add.247456.5.clone.1) + %shift-left.109215.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247456.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115384.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247456.5.clone.1, %broadcast.244416.5760) + %or.114904.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109215.9.clone.1, %shift-right-logical.115384.9.clone.1) + %xor.121451.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247457.5.clone.1, %or.114904.7.clone.1) + %add.247458.3.clone.1 = u32[1280,1280]{1,0} add(%add.247457.5.clone.1, %xor.121451.5.clone.1) + %shift-left.109216.9.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.121451.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115385.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121451.5.clone.1, %broadcast.244429.2304) + %or.114905.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109216.9.clone.1, %shift-right-logical.115385.9.clone.1) + %xor.121452.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247458.3.clone.1, %or.114905.7.clone.1) + %add.247459.3.clone.1 = u32[1280,1280]{1,0} add(%add.247458.3.clone.1, %xor.121452.5.clone.1) + %shift-left.109218.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121452.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115386.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121452.5.clone.1, %broadcast.244430.4608) + %or.114906.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109218.9.clone.1, %shift-right-logical.115386.9.clone.1) + %xor.121453.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247459.3.clone.1, %or.114906.7.clone.1) + %add.247460.3.clone.1 = u32[1280,1280]{1,0} add(%add.247459.3.clone.1, %xor.121453.5.clone.1) + %add.247461.7.clone.1 = u32[1280,1280]{1,0} add(%add.247460.3.clone.1, %broadcast.250291.113.clone.1) + %shift-left.109219.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121453.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115387.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121453.5.clone.1, %broadcast.244434.2816) + %or.114907.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109219.11.clone.1, %shift-right-logical.115387.11.clone.1) + %xor.121454.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247460.3.clone.1, %or.114907.9.clone.1) + %constant_218137_1_clone_1 = u32[] constant(2206610159) + %broadcast.250360.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218137_1_clone_1), dimensions={} + %add.247462.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121454.7.clone.1, %broadcast.250360.5.clone.1) + %add.247464.5.clone.1 = u32[1280,1280]{1,0} add(%add.247461.7.clone.1, %add.247462.5.clone.1) + %shift-left.109220.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247462.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115388.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247462.5.clone.1, %broadcast.244415.6016) + %or.114908.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109220.9.clone.1, %shift-right-logical.115388.9.clone.1) + %xor.121455.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247464.5.clone.1, %or.114908.7.clone.1) + %add.247465.3.clone.1 = u32[1280,1280]{1,0} add(%add.247464.5.clone.1, %xor.121455.5.clone.1) + %shift-left.109221.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121455.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115389.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121455.5.clone.1, %broadcast.244417.5760) + %or.114909.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109221.9.clone.1, %shift-right-logical.115389.9.clone.1) + %xor.121456.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247465.3.clone.1, %or.114909.7.clone.1) + %add.247466.3.clone.1 = u32[1280,1280]{1,0} add(%add.247465.3.clone.1, %xor.121456.5.clone.1) + %shift-left.109222.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121456.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115390.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121456.5.clone.1, %broadcast.244419.4352) + %or.114910.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109222.5.clone.1, %shift-right-logical.115390.5.clone.1) + %xor.121457.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247466.3.clone.1, %or.114910.3.clone.1) + %add.247467.3.clone.1 = 
u32[1280,1280]{1,0} add(%add.247466.3.clone.1, %xor.121457.3.clone.1) + %add.247468.17.clone.1 = u32[1280,1280]{1,0} add(%add.247467.3.clone.1, %broadcast.250322.24.clone.1) + %shift-left.109223.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121457.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115391.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121457.3.clone.1, %broadcast.244418.4352) + %or.114911.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109223.5.clone.1, %shift-right-logical.115391.5.clone.1) + %xor.121458.15.clone.1 = u32[1280,1280]{1,0} xor(%add.247467.3.clone.1, %or.114911.3.clone.1) + %constant_218139_1_clone_1 = u32[] constant(4125501044) + %broadcast.250370.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218139_1_clone_1), dimensions={} + %add.247469.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121458.15.clone.1, %broadcast.250370.19.clone.1) + %xor.121459.17.clone.1 = u32[1280,1280]{1,0} xor(%add.247468.17.clone.1, %add.247469.19.clone.1) + %shift-right-logical.115392.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121459.17.clone.1, %broadcast.244468.1920) + %or.114912.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115392.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5736.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114912.13.clone.1) + %add.247470.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5736.11.clone.1, %broadcast.244470.1152) + %multiply.26274.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247470.9.clone.1, %broadcast.244471.896) + %add.247471.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26274.7.clone.1, %broadcast.244408.1024) + %maximum.3668.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.247471.5.clone.1) + %abs.1528.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3668.3.clone.1) + %compare.7204.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1528.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26275.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3668.3.clone.1, %broadcast.244476.1152) + %negate.4561.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3668.3.clone.1) + %multiply.26276.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3668.3.clone.1, %negate.4561.5.clone.1) + %log-plus-one.1528.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26276.5.clone.1) + %negate.4562.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1528.3.clone.1) + %compare.7205.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4562.4.clone.1, %broadcast.244477.384), direction=LT + %select.20873.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20874.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20875.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20876.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20877.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20878.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20879.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20880.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + 
%select.20881.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.247472.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4562.4.clone.1, %broadcast.244496.640) + %sqrt.1528.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4562.4.clone.1) + %add.247473.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1528.5.clone.1, %broadcast.244498.640) + %select.20882.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7205.3.clone.1, %add.247472.5.clone.1, %add.247473.5.clone.1) + %multiply.26277.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20881.3.clone.1, %select.20882.3.clone.1) + %add.247474.1.clone.1 = f32[1280,1280]{1,0} add(%select.20880.3.clone.1, %multiply.26277.1.clone.1) + %multiply.26278.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247474.1.clone.1, %select.20882.3.clone.1) + %add.247475.1.clone.1 = f32[1280,1280]{1,0} add(%select.20879.3.clone.1, %multiply.26278.1.clone.1) + %multiply.26279.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247475.1.clone.1, %select.20882.3.clone.1) + %add.247476.1.clone.1 = f32[1280,1280]{1,0} add(%select.20878.3.clone.1, %multiply.26279.1.clone.1) + %multiply.26280.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247476.1.clone.1, %select.20882.3.clone.1) + %add.247477.1.clone.1 = f32[1280,1280]{1,0} add(%select.20877.3.clone.1, %multiply.26280.1.clone.1) + %multiply.26281.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247477.1.clone.1, %select.20882.3.clone.1) + %add.247478.3.clone.1 = f32[1280,1280]{1,0} add(%select.20876.5.clone.1, %multiply.26281.1.clone.1) + %multiply.26282.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247478.3.clone.1, %select.20882.3.clone.1) + %add.247479.3.clone.1 = f32[1280,1280]{1,0} add(%select.20875.5.clone.1, %multiply.26282.1.clone.1) + %multiply.26283.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247479.3.clone.1, %select.20882.3.clone.1) + %add.247480.9.clone.1 = f32[1280,1280]{1,0} add(%select.20874.11.clone.1, %multiply.26283.7.clone.1) + %multiply.26284.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247480.9.clone.1, %select.20882.3.clone.1) + %add.247481.7.clone.1 = f32[1280,1280]{1,0} add(%select.20873.7.clone.1, %multiply.26284.7.clone.1) + %multiply.26285.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247481.7.clone.1, %maximum.3668.3.clone.1) + %select.20883.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7204.3.clone.1, %multiply.26275.9.clone.1, %multiply.26285.7.clone.1) + %multiply.26286.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20883.7.clone.1, %broadcast.244500.640) + %clamp.1172.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26286.5.clone.1, %broadcast.244501.384) + %multiply.26287.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1172.3.clone.1, %broadcast.244502.1) + %constant_193907_1_clone_1 = u32[] constant(210688267) + %broadcast.259886.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_193907_1_clone_1), dimensions={} + %add.252913.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.259886.44.clone.1) + %constant_193914_1_clone_1 = u32[] constant(3327797064) + %broadcast.259887.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_193914_1_clone_1), dimensions={} + %add.252914.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.259887.113.clone.1) + %add.252915.35.clone.1 = u32[1280,1280]{1,0} add(%add.252913.37.clone.1, %add.252914.99.clone.1) + %shift-left.111580.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252914.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117876.29.clone.1 = 
+ [... several hundred machine-generated XLA HLO instructions elided: the dump is
+ the same u32[1280,1280] block repeated for successive key/counter constant
+ pairs (3506824601/210688269, 2417465807/1581799846, 662923858/107869931,
+ 2274471178/1719103117, ...). Each block mixes two counter tensors through
+ add / shift-left / shift-right-logical / or / xor rotation rounds, converts
+ the mixed bits to uniform f32 via shift-right-logical, or with an exponent
+ mask, bitcast-convert, and an add/multiply rescale, then maps uniform to
+ Gaussian through a log-plus-one / sqrt branch feeding a select-driven
+ polynomial (an inverse-error-function approximation), finishing with clamp
+ and a final scale multiply. ...]
%multiply.26252.7.clone.1) + %multiply.26253.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247287.7.clone.1, %maximum.3664.3.clone.1) + %select.20861.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7200.3.clone.1, %multiply.26243.9.clone.1, %multiply.26253.7.clone.1) + %multiply.26254.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20861.7.clone.1, %broadcast.244500.640) + %clamp.1170.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26254.5.clone.1, %broadcast.244501.384) + %multiply.26255.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1170.3.clone.1, %broadcast.244502.1) + %constant_190473_1_clone_1 = u32[] constant(2952426392) + %broadcast.258389.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_190473_1_clone_1), dimensions={} + %add.252060.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.258389.44.clone.1) + %constant_190480_1_clone_1 = u32[] constant(3094339752) + %broadcast.258390.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_190480_1_clone_1), dimensions={} + %add.252061.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.258390.113.clone.1) + %add.252062.35.clone.1 = u32[1280,1280]{1,0} add(%add.252060.37.clone.1, %add.252061.99.clone.1) + %shift-left.111220.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252061.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117488.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252061.99.clone.1, %broadcast.244415.6016) + %or.117020.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111220.31.clone.1, %shift-right-logical.117488.29.clone.1) + %xor.123576.27.clone.1 = u32[1280,1280]{1,0} xor(%add.252062.35.clone.1, %or.117020.29.clone.1) + %add.252063.5.clone.1 = u32[1280,1280]{1,0} add(%add.252062.35.clone.1, %xor.123576.27.clone.1) + %shift-left.111221.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123576.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117489.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123576.27.clone.1, %broadcast.244417.5760) + %or.117021.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111221.9.clone.1, %shift-right-logical.117489.9.clone.1) + %xor.123577.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252063.5.clone.1, %or.117021.7.clone.1) + %add.252065.3.clone.1 = u32[1280,1280]{1,0} add(%add.252063.5.clone.1, %xor.123577.5.clone.1) + %shift-left.111222.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123577.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117490.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123577.5.clone.1, %broadcast.244419.4352) + %or.117022.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111222.5.clone.1, %shift-right-logical.117490.5.clone.1) + %xor.123578.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252065.3.clone.1, %or.117022.3.clone.1) + %add.252066.3.clone.1 = u32[1280,1280]{1,0} add(%add.252065.3.clone.1, %xor.123578.3.clone.1) + %add.252067.7.clone.1 = u32[1280,1280]{1,0} add(%add.252066.3.clone.1, %broadcast.258390.113.clone.1) + %shift-left.111223.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123578.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117491.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123578.3.clone.1, %broadcast.244418.4352) + %or.117023.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111223.5.clone.1, %shift-right-logical.117491.5.clone.1) + %xor.123580.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252066.3.clone.1, %or.117023.3.clone.1) + %constant_218650_1_clone_1 = u32[] constant(205818091) + %broadcast.258400.5.clone.1 = 
u32[1280,1280]{1,0} broadcast(%constant_218650_1_clone_1), dimensions={} + %add.252068.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123580.3.clone.1, %broadcast.258400.5.clone.1) + %add.252070.5.clone.1 = u32[1280,1280]{1,0} add(%add.252067.7.clone.1, %add.252068.5.clone.1) + %shift-left.111224.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252068.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117492.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252068.5.clone.1, %broadcast.244416.5760) + %or.117024.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111224.9.clone.1, %shift-right-logical.117492.9.clone.1) + %xor.123581.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252070.5.clone.1, %or.117024.7.clone.1) + %add.252071.3.clone.1 = u32[1280,1280]{1,0} add(%add.252070.5.clone.1, %xor.123581.5.clone.1) + %shift-left.111225.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123581.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117493.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123581.5.clone.1, %broadcast.244429.2304) + %or.117025.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111225.9.clone.1, %shift-right-logical.117493.9.clone.1) + %xor.123582.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252071.3.clone.1, %or.117025.7.clone.1) + %add.252072.3.clone.1 = u32[1280,1280]{1,0} add(%add.252071.3.clone.1, %xor.123582.5.clone.1) + %shift-left.111226.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123582.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117494.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123582.5.clone.1, %broadcast.244430.4608) + %or.117026.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111226.9.clone.1, %shift-right-logical.117494.9.clone.1) + %xor.123583.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252072.3.clone.1, %or.117026.7.clone.1) + %add.252073.3.clone.1 = u32[1280,1280]{1,0} add(%add.252072.3.clone.1, %xor.123583.5.clone.1) + %constant_190482_1_clone_1 = u32[] constant(205818090) + %broadcast.258409.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_190482_1_clone_1), dimensions={} + %add.252075.7.clone.1 = u32[1280,1280]{1,0} add(%add.252073.3.clone.1, %broadcast.258409.24.clone.1) + %shift-left.111227.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123583.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117496.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123583.5.clone.1, %broadcast.244434.2816) + %or.117027.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111227.11.clone.1, %shift-right-logical.117496.11.clone.1) + %xor.123585.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252073.3.clone.1, %or.117027.9.clone.1) + %constant_218651_1_clone_1 = u32[] constant(2952426394) + %broadcast.258412.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218651_1_clone_1), dimensions={} + %add.252076.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123585.7.clone.1, %broadcast.258412.5.clone.1) + %add.252077.5.clone.1 = u32[1280,1280]{1,0} add(%add.252075.7.clone.1, %add.252076.5.clone.1) + %shift-left.111228.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252076.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117497.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252076.5.clone.1, %broadcast.244415.6016) + %or.117028.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111228.9.clone.1, %shift-right-logical.117497.9.clone.1) + %xor.123586.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252077.5.clone.1, %or.117028.7.clone.1) + %add.252078.3.clone.1 = u32[1280,1280]{1,0} add(%add.252077.5.clone.1, %xor.123586.5.clone.1) + 
%shift-left.111229.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123586.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117498.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123586.5.clone.1, %broadcast.244417.5760) + %or.117029.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111229.9.clone.1, %shift-right-logical.117498.9.clone.1) + %xor.123587.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252078.3.clone.1, %or.117029.7.clone.1) + %add.252079.3.clone.1 = u32[1280,1280]{1,0} add(%add.252078.3.clone.1, %xor.123587.5.clone.1) + %shift-left.111230.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123587.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117499.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123587.5.clone.1, %broadcast.244419.4352) + %or.117030.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111230.7.clone.1, %shift-right-logical.117499.7.clone.1) + %xor.123588.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252079.3.clone.1, %or.117030.5.clone.1) + %add.252081.3.clone.1 = u32[1280,1280]{1,0} add(%add.252079.3.clone.1, %xor.123588.3.clone.1) + %add.252085.7.clone.1 = u32[1280,1280]{1,0} add(%add.252081.3.clone.1, %broadcast.258389.44.clone.1) + %shift-left.111231.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123588.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117501.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123588.3.clone.1, %broadcast.244418.4352) + %or.117031.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111231.7.clone.1, %shift-right-logical.117501.7.clone.1) + %xor.123590.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252081.3.clone.1, %or.117031.5.clone.1) + %constant_218652_1_clone_1 = u32[] constant(3094339755) + %broadcast.258422.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218652_1_clone_1), dimensions={} + %add.252086.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123590.3.clone.1, %broadcast.258422.5.clone.1) + %add.252087.5.clone.1 = u32[1280,1280]{1,0} add(%add.252085.7.clone.1, %add.252086.5.clone.1) + %shift-left.111232.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252086.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117502.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252086.5.clone.1, %broadcast.244416.5760) + %or.117032.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111232.9.clone.1, %shift-right-logical.117502.9.clone.1) + %xor.123591.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252087.5.clone.1, %or.117032.7.clone.1) + %add.252088.3.clone.1 = u32[1280,1280]{1,0} add(%add.252087.5.clone.1, %xor.123591.5.clone.1) + %shift-left.111233.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123591.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117503.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123591.5.clone.1, %broadcast.244429.2304) + %or.117033.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111233.9.clone.1, %shift-right-logical.117503.9.clone.1) + %xor.123592.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252088.3.clone.1, %or.117033.7.clone.1) + %add.252090.3.clone.1 = u32[1280,1280]{1,0} add(%add.252088.3.clone.1, %xor.123592.5.clone.1) + %shift-left.111234.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123592.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117504.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123592.5.clone.1, %broadcast.244430.4608) + %or.117034.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111234.9.clone.1, %shift-right-logical.117504.9.clone.1) + %xor.123593.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252090.3.clone.1, 
%or.117034.7.clone.1) + %add.252091.3.clone.1 = u32[1280,1280]{1,0} add(%add.252090.3.clone.1, %xor.123593.5.clone.1) + %add.252092.7.clone.1 = u32[1280,1280]{1,0} add(%add.252091.3.clone.1, %broadcast.258390.113.clone.1) + %shift-left.111235.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123593.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117506.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123593.5.clone.1, %broadcast.244434.2816) + %or.117035.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111235.11.clone.1, %shift-right-logical.117506.11.clone.1) + %xor.123594.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252091.3.clone.1, %or.117035.9.clone.1) + %constant_218653_1_clone_1 = u32[] constant(205818094) + %broadcast.258434.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218653_1_clone_1), dimensions={} + %add.252093.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123594.7.clone.1, %broadcast.258434.5.clone.1) + %add.252095.5.clone.1 = u32[1280,1280]{1,0} add(%add.252092.7.clone.1, %add.252093.5.clone.1) + %shift-left.111236.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252093.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117507.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252093.5.clone.1, %broadcast.244415.6016) + %or.117036.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111236.9.clone.1, %shift-right-logical.117507.9.clone.1) + %xor.123595.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252095.5.clone.1, %or.117036.7.clone.1) + %add.252096.3.clone.1 = u32[1280,1280]{1,0} add(%add.252095.5.clone.1, %xor.123595.5.clone.1) + %shift-left.111237.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123595.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117508.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123595.5.clone.1, %broadcast.244417.5760) + %or.117037.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111237.9.clone.1, %shift-right-logical.117508.9.clone.1) + %xor.123596.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252096.3.clone.1, %or.117037.7.clone.1) + %add.252097.3.clone.1 = u32[1280,1280]{1,0} add(%add.252096.3.clone.1, %xor.123596.5.clone.1) + %shift-left.111238.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123596.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117509.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123596.5.clone.1, %broadcast.244419.4352) + %or.117038.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111238.5.clone.1, %shift-right-logical.117509.5.clone.1) + %xor.123597.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252097.3.clone.1, %or.117038.3.clone.1) + %add.252098.3.clone.1 = u32[1280,1280]{1,0} add(%add.252097.3.clone.1, %xor.123597.3.clone.1) + %add.252100.17.clone.1 = u32[1280,1280]{1,0} add(%add.252098.3.clone.1, %broadcast.258409.24.clone.1) + %shift-left.111239.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123597.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117511.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123597.3.clone.1, %broadcast.244418.4352) + %or.117039.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111239.5.clone.1, %shift-right-logical.117511.5.clone.1) + %xor.123598.15.clone.1 = u32[1280,1280]{1,0} xor(%add.252098.3.clone.1, %or.117039.3.clone.1) + %constant_218654_1_clone_1 = u32[] constant(2952426397) + %broadcast.258444.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218654_1_clone_1), dimensions={} + %add.252101.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123598.15.clone.1, %broadcast.258444.19.clone.1) + %xor.123600.17.clone.1 = u32[1280,1280]{1,0} 
xor(%add.252100.17.clone.1, %add.252101.19.clone.1) + %shift-right-logical.117512.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123600.17.clone.1, %broadcast.244468.1920) + %or.117040.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117512.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5828.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117040.13.clone.1) + %add.252102.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5828.11.clone.1, %broadcast.244470.1152) + %multiply.27232.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252102.9.clone.1, %broadcast.244471.896) + %add.252103.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27232.7.clone.1, %broadcast.244408.1024) + %maximum.3760.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.252103.5.clone.1) + %abs.1590.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3760.3.clone.1) + %compare.7342.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1590.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27233.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3760.3.clone.1, %broadcast.244476.1152) + %negate.4685.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3760.3.clone.1) + %multiply.27234.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3760.3.clone.1, %negate.4685.5.clone.1) + %log-plus-one.1590.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27234.5.clone.1) + %negate.4686.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1590.3.clone.1) + %compare.7343.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4686.4.clone.1, %broadcast.244477.384), direction=LT + %select.21597.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21598.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21599.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21600.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21601.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21602.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21603.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21604.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21605.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.252104.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4686.4.clone.1, %broadcast.244496.640) + %sqrt.1590.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4686.4.clone.1) + %add.252106.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1590.5.clone.1, %broadcast.244498.640) + %select.21606.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7343.3.clone.1, %add.252104.5.clone.1, %add.252106.5.clone.1) + %multiply.27235.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21605.3.clone.1, %select.21606.3.clone.1) + %add.252109.1.clone.1 = f32[1280,1280]{1,0} add(%select.21604.3.clone.1, %multiply.27235.1.clone.1) + %multiply.27236.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252109.1.clone.1, %select.21606.3.clone.1) + %add.252110.1.clone.1 = f32[1280,1280]{1,0} add(%select.21603.3.clone.1, %multiply.27236.1.clone.1) + 
%multiply.27237.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252110.1.clone.1, %select.21606.3.clone.1) + %add.252111.1.clone.1 = f32[1280,1280]{1,0} add(%select.21602.3.clone.1, %multiply.27237.1.clone.1) + %multiply.27238.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252111.1.clone.1, %select.21606.3.clone.1) + %add.252112.1.clone.1 = f32[1280,1280]{1,0} add(%select.21601.3.clone.1, %multiply.27238.1.clone.1) + %multiply.27239.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252112.1.clone.1, %select.21606.3.clone.1) + %add.252113.3.clone.1 = f32[1280,1280]{1,0} add(%select.21600.5.clone.1, %multiply.27239.1.clone.1) + %multiply.27240.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252113.3.clone.1, %select.21606.3.clone.1) + %add.252114.3.clone.1 = f32[1280,1280]{1,0} add(%select.21599.5.clone.1, %multiply.27240.1.clone.1) + %multiply.27241.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252114.3.clone.1, %select.21606.3.clone.1) + %add.252115.9.clone.1 = f32[1280,1280]{1,0} add(%select.21598.11.clone.1, %multiply.27241.7.clone.1) + %multiply.27242.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252115.9.clone.1, %select.21606.3.clone.1) + %add.252116.7.clone.1 = f32[1280,1280]{1,0} add(%select.21597.7.clone.1, %multiply.27242.7.clone.1) + %multiply.27243.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252116.7.clone.1, %maximum.3760.3.clone.1) + %select.21607.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7342.3.clone.1, %multiply.27233.9.clone.1, %multiply.27243.7.clone.1) + %multiply.27244.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21607.7.clone.1, %broadcast.244500.640) + %clamp.1234.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27244.5.clone.1, %broadcast.244501.384) + %multiply.27245.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1234.3.clone.1, %broadcast.244502.1) + %constant_170738_1_clone_1 = u32[] constant(1424143128) + %broadcast.249861.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_170738_1_clone_1), dimensions={} + %add.247180.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.249861.44.clone.1) + %constant_170745_1_clone_1 = u32[] constant(187212167) + %broadcast.249862.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_170745_1_clone_1), dimensions={} + %add.247181.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.249862.113.clone.1) + %add.247183.35.clone.1 = u32[1280,1280]{1,0} add(%add.247180.37.clone.1, %add.247181.99.clone.1) + %shift-left.109092.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247181.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115260.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247181.99.clone.1, %broadcast.244415.6016) + %or.114787.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109092.31.clone.1, %shift-right-logical.115260.29.clone.1) + %xor.121334.27.clone.1 = u32[1280,1280]{1,0} xor(%add.247183.35.clone.1, %or.114787.29.clone.1) + %add.247184.5.clone.1 = u32[1280,1280]{1,0} add(%add.247183.35.clone.1, %xor.121334.27.clone.1) + %shift-left.109094.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121334.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115262.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121334.27.clone.1, %broadcast.244417.5760) + %or.114788.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109094.9.clone.1, %shift-right-logical.115262.9.clone.1) + %xor.121335.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247184.5.clone.1, %or.114788.7.clone.1) + %add.247185.3.clone.1 = u32[1280,1280]{1,0} add(%add.247184.5.clone.1, 
%xor.121335.5.clone.1) + %shift-left.109095.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121335.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115263.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121335.5.clone.1, %broadcast.244419.4352) + %or.114789.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109095.5.clone.1, %shift-right-logical.115263.5.clone.1) + %xor.121336.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247185.3.clone.1, %or.114789.3.clone.1) + %add.247186.3.clone.1 = u32[1280,1280]{1,0} add(%add.247185.3.clone.1, %xor.121336.3.clone.1) + %add.247188.7.clone.1 = u32[1280,1280]{1,0} add(%add.247186.3.clone.1, %broadcast.249862.113.clone.1) + %shift-left.109096.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121336.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115264.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121336.3.clone.1, %broadcast.244418.4352) + %or.114790.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109096.5.clone.1, %shift-right-logical.115264.5.clone.1) + %xor.121337.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247186.3.clone.1, %or.114790.3.clone.1) + %constant_218106_1_clone_1 = u32[] constant(1142622534) + %broadcast.249875.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218106_1_clone_1), dimensions={} + %add.247189.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121337.3.clone.1, %broadcast.249875.5.clone.1) + %add.247190.5.clone.1 = u32[1280,1280]{1,0} add(%add.247188.7.clone.1, %add.247189.5.clone.1) + %shift-left.109097.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247189.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115265.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247189.5.clone.1, %broadcast.244416.5760) + %or.114791.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109097.9.clone.1, %shift-right-logical.115265.9.clone.1) + %xor.121338.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247190.5.clone.1, %or.114791.7.clone.1) + %add.247191.3.clone.1 = u32[1280,1280]{1,0} add(%add.247190.5.clone.1, %xor.121338.5.clone.1) + %shift-left.109099.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121338.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115266.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121338.5.clone.1, %broadcast.244429.2304) + %or.114792.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109099.9.clone.1, %shift-right-logical.115266.9.clone.1) + %xor.121339.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247191.3.clone.1, %or.114792.7.clone.1) + %add.247193.3.clone.1 = u32[1280,1280]{1,0} add(%add.247191.3.clone.1, %xor.121339.5.clone.1) + %shift-left.109100.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121339.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115267.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121339.5.clone.1, %broadcast.244430.4608) + %or.114793.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109100.9.clone.1, %shift-right-logical.115267.9.clone.1) + %xor.121340.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247193.3.clone.1, %or.114793.7.clone.1) + %add.247194.3.clone.1 = u32[1280,1280]{1,0} add(%add.247193.3.clone.1, %xor.121340.5.clone.1) + %constant_170747_1_clone_1 = u32[] constant(1142622533) + %broadcast.249886.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_170747_1_clone_1), dimensions={} + %add.247195.7.clone.1 = u32[1280,1280]{1,0} add(%add.247194.3.clone.1, %broadcast.249886.24.clone.1) + %shift-left.109101.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121340.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115268.11.clone.1 = 
u32[1280,1280]{1,0} shift-right-logical(%xor.121340.5.clone.1, %broadcast.244434.2816) + %or.114794.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109101.11.clone.1, %shift-right-logical.115268.11.clone.1) + %xor.121341.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247194.3.clone.1, %or.114794.9.clone.1) + %constant_218107_1_clone_1 = u32[] constant(1424143130) + %broadcast.249889.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218107_1_clone_1), dimensions={} + %add.247196.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121341.7.clone.1, %broadcast.249889.5.clone.1) + %add.247197.5.clone.1 = u32[1280,1280]{1,0} add(%add.247195.7.clone.1, %add.247196.5.clone.1) + %shift-left.109102.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247196.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115269.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247196.5.clone.1, %broadcast.244415.6016) + %or.114795.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109102.9.clone.1, %shift-right-logical.115269.9.clone.1) + %xor.121342.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247197.5.clone.1, %or.114795.7.clone.1) + %add.247199.3.clone.1 = u32[1280,1280]{1,0} add(%add.247197.5.clone.1, %xor.121342.5.clone.1) + %shift-left.109104.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121342.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115270.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121342.5.clone.1, %broadcast.244417.5760) + %or.114796.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109104.9.clone.1, %shift-right-logical.115270.9.clone.1) + %xor.121343.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247199.3.clone.1, %or.114796.7.clone.1) + %add.247203.3.clone.1 = u32[1280,1280]{1,0} add(%add.247199.3.clone.1, %xor.121343.5.clone.1) + %shift-left.109105.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121343.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115272.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121343.5.clone.1, %broadcast.244419.4352) + %or.114797.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109105.7.clone.1, %shift-right-logical.115272.7.clone.1) + %xor.121344.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247203.3.clone.1, %or.114797.5.clone.1) + %add.247204.3.clone.1 = u32[1280,1280]{1,0} add(%add.247203.3.clone.1, %xor.121344.3.clone.1) + %add.247205.7.clone.1 = u32[1280,1280]{1,0} add(%add.247204.3.clone.1, %broadcast.249861.44.clone.1) + %shift-left.109106.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121344.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115273.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121344.3.clone.1, %broadcast.244418.4352) + %or.114798.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109106.7.clone.1, %shift-right-logical.115273.7.clone.1) + %xor.121345.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247204.3.clone.1, %or.114798.5.clone.1) + %constant_218108_1_clone_1 = u32[] constant(187212170) + %broadcast.249901.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218108_1_clone_1), dimensions={} + %add.247206.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121345.3.clone.1, %broadcast.249901.5.clone.1) + %add.247208.5.clone.1 = u32[1280,1280]{1,0} add(%add.247205.7.clone.1, %add.247206.5.clone.1) + %shift-left.109107.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247206.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115274.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247206.5.clone.1, %broadcast.244416.5760) + %or.114799.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109107.9.clone.1, 
%shift-right-logical.115274.9.clone.1) + %xor.121346.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247208.5.clone.1, %or.114799.7.clone.1) + %add.247209.3.clone.1 = u32[1280,1280]{1,0} add(%add.247208.5.clone.1, %xor.121346.5.clone.1) + %shift-left.109109.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121346.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115275.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121346.5.clone.1, %broadcast.244429.2304) + %or.114800.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109109.9.clone.1, %shift-right-logical.115275.9.clone.1) + %xor.121347.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247209.3.clone.1, %or.114800.7.clone.1) + %add.247210.3.clone.1 = u32[1280,1280]{1,0} add(%add.247209.3.clone.1, %xor.121347.5.clone.1) + %shift-left.109110.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121347.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115277.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121347.5.clone.1, %broadcast.244430.4608) + %or.114801.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109110.9.clone.1, %shift-right-logical.115277.9.clone.1) + %xor.121348.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247210.3.clone.1, %or.114801.7.clone.1) + %add.247211.3.clone.1 = u32[1280,1280]{1,0} add(%add.247210.3.clone.1, %xor.121348.5.clone.1) + %add.247213.7.clone.1 = u32[1280,1280]{1,0} add(%add.247211.3.clone.1, %broadcast.249862.113.clone.1) + %shift-left.109111.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121348.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115278.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121348.5.clone.1, %broadcast.244434.2816) + %or.114802.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109111.11.clone.1, %shift-right-logical.115278.11.clone.1) + %xor.121349.7.clone.1 = u32[1280,1280]{1,0} xor(%add.247211.3.clone.1, %or.114802.9.clone.1) + %constant_218109_1_clone_1 = u32[] constant(1142622537) + %broadcast.249915.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218109_1_clone_1), dimensions={} + %add.247214.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121349.7.clone.1, %broadcast.249915.5.clone.1) + %add.247215.5.clone.1 = u32[1280,1280]{1,0} add(%add.247213.7.clone.1, %add.247214.5.clone.1) + %shift-left.109112.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.247214.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115279.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.247214.5.clone.1, %broadcast.244415.6016) + %or.114803.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109112.9.clone.1, %shift-right-logical.115279.9.clone.1) + %xor.121350.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247215.5.clone.1, %or.114803.7.clone.1) + %add.247216.3.clone.1 = u32[1280,1280]{1,0} add(%add.247215.5.clone.1, %xor.121350.5.clone.1) + %shift-left.109113.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121350.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115280.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121350.5.clone.1, %broadcast.244417.5760) + %or.114804.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109113.9.clone.1, %shift-right-logical.115280.9.clone.1) + %xor.121351.5.clone.1 = u32[1280,1280]{1,0} xor(%add.247216.3.clone.1, %or.114804.7.clone.1) + %add.247218.3.clone.1 = u32[1280,1280]{1,0} add(%add.247216.3.clone.1, %xor.121351.5.clone.1) + %shift-left.109114.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121351.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115282.5.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.121351.5.clone.1, %broadcast.244419.4352) + %or.114805.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109114.5.clone.1, %shift-right-logical.115282.5.clone.1) + %xor.121352.3.clone.1 = u32[1280,1280]{1,0} xor(%add.247218.3.clone.1, %or.114805.3.clone.1) + %add.247219.3.clone.1 = u32[1280,1280]{1,0} add(%add.247218.3.clone.1, %xor.121352.3.clone.1) + %add.247220.17.clone.1 = u32[1280,1280]{1,0} add(%add.247219.3.clone.1, %broadcast.249886.24.clone.1) + %shift-left.109115.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121352.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115283.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121352.3.clone.1, %broadcast.244418.4352) + %or.114806.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.109115.5.clone.1, %shift-right-logical.115283.5.clone.1) + %xor.121353.15.clone.1 = u32[1280,1280]{1,0} xor(%add.247219.3.clone.1, %or.114806.3.clone.1) + %constant_218110_1_clone_1 = u32[] constant(1424143133) + %broadcast.249927.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218110_1_clone_1), dimensions={} + %add.247221.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121353.15.clone.1, %broadcast.249927.19.clone.1) + %xor.121354.17.clone.1 = u32[1280,1280]{1,0} xor(%add.247220.17.clone.1, %add.247221.19.clone.1) + %shift-right-logical.115284.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121354.17.clone.1, %broadcast.244468.1920) + %or.114807.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115284.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5731.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114807.13.clone.1) + %add.247222.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5731.11.clone.1, %broadcast.244470.1152) + %multiply.26228.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247222.9.clone.1, %broadcast.244471.896) + %add.247224.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26228.7.clone.1, %broadcast.244408.1024) + %maximum.3663.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.247224.5.clone.1) + %abs.1525.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3663.3.clone.1) + %compare.7198.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1525.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26229.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3663.3.clone.1, %broadcast.244476.1152) + %negate.4555.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3663.3.clone.1) + %multiply.26230.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3663.3.clone.1, %negate.4555.5.clone.1) + %log-plus-one.1525.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26230.5.clone.1) + %negate.4556.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1525.3.clone.1) + %compare.7199.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4556.4.clone.1, %broadcast.244477.384), direction=LT + %select.20840.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20841.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20842.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20843.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20844.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20845.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, 
%broadcast.244488.384, %broadcast.244489.384) + %select.20846.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20847.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20848.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.247228.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4556.4.clone.1, %broadcast.244496.640) + %sqrt.1525.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4556.4.clone.1) + %add.247229.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1525.5.clone.1, %broadcast.244498.640) + %select.20849.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7199.3.clone.1, %add.247228.5.clone.1, %add.247229.5.clone.1) + %multiply.26231.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20848.3.clone.1, %select.20849.3.clone.1) + %add.247230.1.clone.1 = f32[1280,1280]{1,0} add(%select.20847.3.clone.1, %multiply.26231.1.clone.1) + %multiply.26232.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247230.1.clone.1, %select.20849.3.clone.1) + %add.247231.1.clone.1 = f32[1280,1280]{1,0} add(%select.20846.3.clone.1, %multiply.26232.1.clone.1) + %multiply.26233.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247231.1.clone.1, %select.20849.3.clone.1) + %add.247233.1.clone.1 = f32[1280,1280]{1,0} add(%select.20845.3.clone.1, %multiply.26233.1.clone.1) + %multiply.26234.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247233.1.clone.1, %select.20849.3.clone.1) + %add.247234.1.clone.1 = f32[1280,1280]{1,0} add(%select.20844.3.clone.1, %multiply.26234.1.clone.1) + %multiply.26235.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247234.1.clone.1, %select.20849.3.clone.1) + %add.247235.3.clone.1 = f32[1280,1280]{1,0} add(%select.20843.5.clone.1, %multiply.26235.1.clone.1) + %multiply.26236.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.247235.3.clone.1, %select.20849.3.clone.1) + %add.247236.3.clone.1 = f32[1280,1280]{1,0} add(%select.20842.5.clone.1, %multiply.26236.1.clone.1) + %multiply.26237.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247236.3.clone.1, %select.20849.3.clone.1) + %add.247238.9.clone.1 = f32[1280,1280]{1,0} add(%select.20841.11.clone.1, %multiply.26237.7.clone.1) + %multiply.26238.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247238.9.clone.1, %select.20849.3.clone.1) + %add.247239.7.clone.1 = f32[1280,1280]{1,0} add(%select.20840.7.clone.1, %multiply.26238.7.clone.1) + %multiply.26239.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.247239.7.clone.1, %maximum.3663.3.clone.1) + %select.20850.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7198.3.clone.1, %multiply.26229.9.clone.1, %multiply.26239.7.clone.1) + %multiply.26240.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20850.7.clone.1, %broadcast.244500.640) + %clamp.1169.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26240.5.clone.1, %broadcast.244501.384) + %multiply.26241.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1169.3.clone.1, %broadcast.244502.1) + %constant_183930_1_clone_1 = u32[] constant(396441042) + %broadcast.255570.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_183930_1_clone_1), dimensions={} + %add.250428.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.255570.44.clone.1) + %constant_183937_1_clone_1 = u32[] constant(1703058310) + %broadcast.255571.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_183937_1_clone_1), dimensions={} + %add.250429.99.clone.1 = u32[1280,1280]{1,0} 
add(%convert.3611.12801, %broadcast.255571.113.clone.1) + %add.250430.35.clone.1 = u32[1280,1280]{1,0} add(%add.250428.37.clone.1, %add.250429.99.clone.1) + %shift-left.110500.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250429.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116755.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250429.99.clone.1, %broadcast.244415.6016) + %or.116266.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110500.31.clone.1, %shift-right-logical.116755.29.clone.1) + %xor.122836.27.clone.1 = u32[1280,1280]{1,0} xor(%add.250430.35.clone.1, %or.116266.29.clone.1) + %add.250431.5.clone.1 = u32[1280,1280]{1,0} add(%add.250430.35.clone.1, %xor.122836.27.clone.1) + %shift-left.110501.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122836.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116756.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122836.27.clone.1, %broadcast.244417.5760) + %or.116267.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110501.9.clone.1, %shift-right-logical.116756.9.clone.1) + %xor.122837.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250431.5.clone.1, %or.116267.7.clone.1) + %add.250433.3.clone.1 = u32[1280,1280]{1,0} add(%add.250431.5.clone.1, %xor.122837.5.clone.1) + %shift-left.110502.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122837.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116757.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122837.5.clone.1, %broadcast.244419.4352) + %or.116268.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110502.5.clone.1, %shift-right-logical.116757.5.clone.1) + %xor.122838.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250433.3.clone.1, %or.116268.3.clone.1) + %add.250434.3.clone.1 = u32[1280,1280]{1,0} add(%add.250433.3.clone.1, %xor.122838.3.clone.1) + %add.250435.7.clone.1 = u32[1280,1280]{1,0} add(%add.250434.3.clone.1, %broadcast.255571.113.clone.1) + %shift-left.110503.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122838.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116758.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122838.3.clone.1, %broadcast.244418.4352) + %or.116269.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110503.5.clone.1, %shift-right-logical.116758.5.clone.1) + %xor.122840.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250434.3.clone.1, %or.116269.3.clone.1) + %constant_218469_1_clone_1 = u32[] constant(1777513871) + %broadcast.255583.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218469_1_clone_1), dimensions={} + %add.250436.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122840.3.clone.1, %broadcast.255583.5.clone.1) + %add.250438.5.clone.1 = u32[1280,1280]{1,0} add(%add.250435.7.clone.1, %add.250436.5.clone.1) + %shift-left.110504.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250436.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116759.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250436.5.clone.1, %broadcast.244416.5760) + %or.116270.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110504.9.clone.1, %shift-right-logical.116759.9.clone.1) + %xor.122841.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250438.5.clone.1, %or.116270.7.clone.1) + %add.250439.3.clone.1 = u32[1280,1280]{1,0} add(%add.250438.5.clone.1, %xor.122841.5.clone.1) + %shift-left.110505.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122841.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116760.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122841.5.clone.1, %broadcast.244429.2304) + %or.116271.7.clone.1 = 
u32[1280,1280]{1,0} or(%shift-left.110505.9.clone.1, %shift-right-logical.116760.9.clone.1) + %xor.122842.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250439.3.clone.1, %or.116271.7.clone.1) + %add.250440.3.clone.1 = u32[1280,1280]{1,0} add(%add.250439.3.clone.1, %xor.122842.5.clone.1) + %shift-left.110506.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122842.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116761.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122842.5.clone.1, %broadcast.244430.4608) + %or.116272.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110506.9.clone.1, %shift-right-logical.116761.9.clone.1) + %xor.122843.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250440.3.clone.1, %or.116272.7.clone.1) + %add.250441.3.clone.1 = u32[1280,1280]{1,0} add(%add.250440.3.clone.1, %xor.122843.5.clone.1) + %constant_183939_1_clone_1 = u32[] constant(1777513870) + %broadcast.255591.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_183939_1_clone_1), dimensions={} + %add.250443.7.clone.1 = u32[1280,1280]{1,0} add(%add.250441.3.clone.1, %broadcast.255591.24.clone.1) + %shift-left.110507.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122843.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116762.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122843.5.clone.1, %broadcast.244434.2816) + %or.116273.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110507.11.clone.1, %shift-right-logical.116762.11.clone.1) + %xor.122845.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250441.3.clone.1, %or.116273.9.clone.1) + %constant_218470_1_clone_1 = u32[] constant(396441044) + %broadcast.255597.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218470_1_clone_1), dimensions={} + %add.250444.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122845.7.clone.1, %broadcast.255597.5.clone.1) + %add.250445.5.clone.1 = u32[1280,1280]{1,0} add(%add.250443.7.clone.1, %add.250444.5.clone.1) + %shift-left.110508.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250444.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116764.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250444.5.clone.1, %broadcast.244415.6016) + %or.116275.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110508.9.clone.1, %shift-right-logical.116764.9.clone.1) + %xor.122846.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250445.5.clone.1, %or.116275.7.clone.1) + %add.250446.3.clone.1 = u32[1280,1280]{1,0} add(%add.250445.5.clone.1, %xor.122846.5.clone.1) + %shift-left.110509.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122846.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116765.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122846.5.clone.1, %broadcast.244417.5760) + %or.116276.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110509.9.clone.1, %shift-right-logical.116765.9.clone.1) + %xor.122847.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250446.3.clone.1, %or.116276.7.clone.1) + %add.250447.3.clone.1 = u32[1280,1280]{1,0} add(%add.250446.3.clone.1, %xor.122847.5.clone.1) + %shift-left.110511.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122847.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116766.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122847.5.clone.1, %broadcast.244419.4352) + %or.116277.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110511.7.clone.1, %shift-right-logical.116766.7.clone.1) + %xor.122848.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250447.3.clone.1, %or.116277.5.clone.1) + %add.250449.3.clone.1 = u32[1280,1280]{1,0} add(%add.250447.3.clone.1, 
+ [elided: machine-generated XLA HLO, several near-identical unrolled blocks. Each block is a
+  ThreeFry-2x32-style counter-based PRNG over u32[1280,1280] operands -- add/xor/rotate rounds
+  built from shift-left, shift-right-logical and or (including the self-inverse rotate-by-16,
+  where one broadcast constant serves both shift directions), with per-block key/counter
+  constant pairs such as 2641164229/3440052210, 912190203/953710832, 162910663/507276444 and
+  3274149152/879129337 injected between rounds -- followed by the uniform-to-normal tail:
+  shift-right-logical(9), or with an f32 exponent mask, bitcast-convert to f32, rescale into
+  (-1, 1), an inverse-erf polynomial of select/multiply/add chains over log-plus-one and sqrt
+  (branching on a compare against a cutoff and on |x| == 1), then clamp and a final scaling
+  multiply]
%or.116253.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110488.9.clone.1, %shift-right-logical.116739.9.clone.1) + %xor.122821.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250395.5.clone.1, %or.116253.7.clone.1) + %add.250396.3.clone.1 = u32[1280,1280]{1,0} add(%add.250395.5.clone.1, %xor.122821.5.clone.1) + %shift-left.110489.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122821.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116740.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122821.5.clone.1, %broadcast.244417.5760) + %or.116254.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110489.9.clone.1, %shift-right-logical.116740.9.clone.1) + %xor.122822.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250396.3.clone.1, %or.116254.7.clone.1) + %add.250397.3.clone.1 = u32[1280,1280]{1,0} add(%add.250396.3.clone.1, %xor.122822.5.clone.1) + %shift-left.110490.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122822.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116741.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122822.5.clone.1, %broadcast.244419.4352) + %or.116255.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110490.7.clone.1, %shift-right-logical.116741.7.clone.1) + %xor.122823.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250397.3.clone.1, %or.116255.5.clone.1) + %add.250398.3.clone.1 = u32[1280,1280]{1,0} add(%add.250397.3.clone.1, %xor.122823.3.clone.1) + %add.250399.7.clone.1 = u32[1280,1280]{1,0} add(%add.250398.3.clone.1, %broadcast.255466.44.clone.1) + %shift-left.110491.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122823.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116742.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122823.3.clone.1, %broadcast.244418.4352) + %or.116256.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110491.7.clone.1, %shift-right-logical.116742.7.clone.1) + %xor.122824.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250398.3.clone.1, %or.116256.5.clone.1) + %constant_218466_1_clone_1 = u32[] constant(879129340) + %broadcast.255500.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218466_1_clone_1), dimensions={} + %add.250400.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122824.3.clone.1, %broadcast.255500.5.clone.1) + %add.250401.5.clone.1 = u32[1280,1280]{1,0} add(%add.250399.7.clone.1, %add.250400.5.clone.1) + %shift-left.110492.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250400.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116744.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250400.5.clone.1, %broadcast.244416.5760) + %or.116257.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110492.9.clone.1, %shift-right-logical.116744.9.clone.1) + %xor.122825.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250401.5.clone.1, %or.116257.7.clone.1) + %add.250402.3.clone.1 = u32[1280,1280]{1,0} add(%add.250401.5.clone.1, %xor.122825.5.clone.1) + %shift-left.110493.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122825.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116745.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122825.5.clone.1, %broadcast.244429.2304) + %or.116258.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110493.9.clone.1, %shift-right-logical.116745.9.clone.1) + %xor.122826.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250402.3.clone.1, %or.116258.7.clone.1) + %add.250403.3.clone.1 = u32[1280,1280]{1,0} add(%add.250402.3.clone.1, %xor.122826.5.clone.1) + %shift-left.110494.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122826.5.clone.1, %broadcast.244430.4608) + 
%shift-right-logical.116746.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122826.5.clone.1, %broadcast.244430.4608) + %or.116259.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110494.9.clone.1, %shift-right-logical.116746.9.clone.1) + %xor.122827.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250403.3.clone.1, %or.116259.7.clone.1) + %add.250404.3.clone.1 = u32[1280,1280]{1,0} add(%add.250403.3.clone.1, %xor.122827.5.clone.1) + %add.250405.7.clone.1 = u32[1280,1280]{1,0} add(%add.250404.3.clone.1, %broadcast.255468.113.clone.1) + %shift-left.110495.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122827.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116747.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122827.5.clone.1, %broadcast.244434.2816) + %or.116260.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110495.11.clone.1, %shift-right-logical.116747.11.clone.1) + %xor.122828.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250404.3.clone.1, %or.116260.9.clone.1) + %constant_218467_1_clone_1 = u32[] constant(3968918535) + %broadcast.255510.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218467_1_clone_1), dimensions={} + %add.250406.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122828.7.clone.1, %broadcast.255510.5.clone.1) + %add.250407.5.clone.1 = u32[1280,1280]{1,0} add(%add.250405.7.clone.1, %add.250406.5.clone.1) + %shift-left.110496.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250406.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116749.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250406.5.clone.1, %broadcast.244415.6016) + %or.116261.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110496.9.clone.1, %shift-right-logical.116749.9.clone.1) + %xor.122831.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250407.5.clone.1, %or.116261.7.clone.1) + %add.250408.3.clone.1 = u32[1280,1280]{1,0} add(%add.250407.5.clone.1, %xor.122831.5.clone.1) + %shift-left.110497.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122831.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116750.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122831.5.clone.1, %broadcast.244417.5760) + %or.116262.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110497.9.clone.1, %shift-right-logical.116750.9.clone.1) + %xor.122832.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250408.3.clone.1, %or.116262.7.clone.1) + %add.250409.3.clone.1 = u32[1280,1280]{1,0} add(%add.250408.3.clone.1, %xor.122832.5.clone.1) + %shift-left.110498.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122832.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116751.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122832.5.clone.1, %broadcast.244419.4352) + %or.116263.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110498.5.clone.1, %shift-right-logical.116751.5.clone.1) + %xor.122833.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250409.3.clone.1, %or.116263.3.clone.1) + %add.250410.3.clone.1 = u32[1280,1280]{1,0} add(%add.250409.3.clone.1, %xor.122833.3.clone.1) + %add.250411.17.clone.1 = u32[1280,1280]{1,0} add(%add.250410.3.clone.1, %broadcast.255485.24.clone.1) + %shift-left.110499.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122833.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116752.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122833.3.clone.1, %broadcast.244418.4352) + %or.116264.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110499.5.clone.1, %shift-right-logical.116752.5.clone.1) + %xor.122834.15.clone.1 = u32[1280,1280]{1,0} xor(%add.250410.3.clone.1, 
%or.116264.3.clone.1) + %constant_218468_1_clone_1 = u32[] constant(3274149157) + %broadcast.255522.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218468_1_clone_1), dimensions={} + %add.250412.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122834.15.clone.1, %broadcast.255522.19.clone.1) + %xor.122835.17.clone.1 = u32[1280,1280]{1,0} xor(%add.250411.17.clone.1, %add.250412.19.clone.1) + %shift-right-logical.116754.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122835.17.clone.1, %broadcast.244468.1920) + %or.116265.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116754.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5795.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116265.13.clone.1) + %add.250413.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5795.11.clone.1, %broadcast.244470.1152) + %multiply.26875.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250413.9.clone.1, %broadcast.244471.896) + %add.250414.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26875.7.clone.1, %broadcast.244408.1024) + %maximum.3727.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.250414.5.clone.1) + %abs.1567.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3727.3.clone.1) + %compare.7285.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1567.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26876.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3727.3.clone.1, %broadcast.244476.1152) + %negate.4639.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3727.3.clone.1) + %multiply.26877.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3727.3.clone.1, %negate.4639.5.clone.1) + %log-plus-one.1567.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26877.5.clone.1) + %negate.4640.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1567.3.clone.1) + %compare.7286.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4640.4.clone.1, %broadcast.244477.384), direction=LT + %select.21323.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21324.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21325.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21326.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21327.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21328.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21329.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21330.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21331.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.250415.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4640.4.clone.1, %broadcast.244496.640) + %sqrt.1567.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4640.4.clone.1) + %add.250416.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1567.5.clone.1, %broadcast.244498.640) + %select.21332.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7286.3.clone.1, %add.250415.5.clone.1, %add.250416.5.clone.1) + %multiply.26878.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21331.3.clone.1, 
%select.21332.3.clone.1) + %add.250417.1.clone.1 = f32[1280,1280]{1,0} add(%select.21330.3.clone.1, %multiply.26878.1.clone.1) + %multiply.26879.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250417.1.clone.1, %select.21332.3.clone.1) + %add.250418.1.clone.1 = f32[1280,1280]{1,0} add(%select.21329.3.clone.1, %multiply.26879.1.clone.1) + %multiply.26880.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250418.1.clone.1, %select.21332.3.clone.1) + %add.250419.1.clone.1 = f32[1280,1280]{1,0} add(%select.21328.3.clone.1, %multiply.26880.1.clone.1) + %multiply.26881.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250419.1.clone.1, %select.21332.3.clone.1) + %add.250420.1.clone.1 = f32[1280,1280]{1,0} add(%select.21327.3.clone.1, %multiply.26881.1.clone.1) + %multiply.26882.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250420.1.clone.1, %select.21332.3.clone.1) + %add.250421.3.clone.1 = f32[1280,1280]{1,0} add(%select.21326.5.clone.1, %multiply.26882.1.clone.1) + %multiply.26883.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250421.3.clone.1, %select.21332.3.clone.1) + %add.250422.3.clone.1 = f32[1280,1280]{1,0} add(%select.21325.5.clone.1, %multiply.26883.1.clone.1) + %multiply.26884.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250422.3.clone.1, %select.21332.3.clone.1) + %add.250423.9.clone.1 = f32[1280,1280]{1,0} add(%select.21324.11.clone.1, %multiply.26884.7.clone.1) + %multiply.26885.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250423.9.clone.1, %select.21332.3.clone.1) + %add.250425.7.clone.1 = f32[1280,1280]{1,0} add(%select.21323.7.clone.1, %multiply.26885.7.clone.1) + %multiply.26886.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250425.7.clone.1, %maximum.3727.3.clone.1) + %select.21333.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7285.3.clone.1, %multiply.26876.9.clone.1, %multiply.26886.7.clone.1) + %multiply.26887.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21333.7.clone.1, %broadcast.244500.640) + %clamp.1211.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26887.5.clone.1, %broadcast.244501.384) + %multiply.26888.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1211.3.clone.1, %broadcast.244502.1) + %constant_169744_1_clone_1 = u32[] constant(2598801053) + %broadcast.249427.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_169744_1_clone_1), dimensions={} + %add.246951.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.249427.44.clone.1) + %constant_169751_1_clone_1 = u32[] constant(1962997650) + %broadcast.249429.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_169751_1_clone_1), dimensions={} + %add.246952.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.249429.113.clone.1) + %add.246953.35.clone.1 = u32[1280,1280]{1,0} add(%add.246951.37.clone.1, %add.246952.99.clone.1) + %shift-left.108980.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246952.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115140.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246952.99.clone.1, %broadcast.244415.6016) + %or.114669.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108980.31.clone.1, %shift-right-logical.115140.29.clone.1) + %xor.121222.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246953.35.clone.1, %or.114669.29.clone.1) + %add.246955.5.clone.1 = u32[1280,1280]{1,0} add(%add.246953.35.clone.1, %xor.121222.27.clone.1) + %shift-left.108981.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121222.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115141.9.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.121222.27.clone.1, %broadcast.244417.5760) + %or.114670.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108981.9.clone.1, %shift-right-logical.115141.9.clone.1) + %xor.121223.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246955.5.clone.1, %or.114670.7.clone.1) + %add.246958.3.clone.1 = u32[1280,1280]{1,0} add(%add.246955.5.clone.1, %xor.121223.5.clone.1) + %shift-left.108982.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121223.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115142.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121223.5.clone.1, %broadcast.244419.4352) + %or.114672.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108982.5.clone.1, %shift-right-logical.115142.5.clone.1) + %xor.121224.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246958.3.clone.1, %or.114672.3.clone.1) + %add.246959.3.clone.1 = u32[1280,1280]{1,0} add(%add.246958.3.clone.1, %xor.121224.3.clone.1) + %add.246960.7.clone.1 = u32[1280,1280]{1,0} add(%add.246959.3.clone.1, %broadcast.249429.113.clone.1) + %shift-left.108983.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121224.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115143.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121224.3.clone.1, %broadcast.244418.4352) + %or.114673.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108983.5.clone.1, %shift-right-logical.115143.5.clone.1) + %xor.121225.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246959.3.clone.1, %or.114673.3.clone.1) + %constant_218081_1_clone_1 = u32[] constant(4097270486) + %broadcast.249446.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218081_1_clone_1), dimensions={} + %add.246961.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121225.3.clone.1, %broadcast.249446.5.clone.1) + %add.246962.5.clone.1 = u32[1280,1280]{1,0} add(%add.246960.7.clone.1, %add.246961.5.clone.1) + %shift-left.108984.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246961.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115144.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246961.5.clone.1, %broadcast.244416.5760) + %or.114674.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108984.9.clone.1, %shift-right-logical.115144.9.clone.1) + %xor.121227.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246962.5.clone.1, %or.114674.7.clone.1) + %add.246963.3.clone.1 = u32[1280,1280]{1,0} add(%add.246962.5.clone.1, %xor.121227.5.clone.1) + %shift-left.108985.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121227.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115145.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121227.5.clone.1, %broadcast.244429.2304) + %or.114675.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108985.9.clone.1, %shift-right-logical.115145.9.clone.1) + %xor.121228.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246963.3.clone.1, %or.114675.7.clone.1) + %add.246964.3.clone.1 = u32[1280,1280]{1,0} add(%add.246963.3.clone.1, %xor.121228.5.clone.1) + %shift-left.108986.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121228.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115146.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121228.5.clone.1, %broadcast.244430.4608) + %or.114677.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108986.9.clone.1, %shift-right-logical.115146.9.clone.1) + %xor.121229.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246964.3.clone.1, %or.114677.7.clone.1) + %add.246965.3.clone.1 = u32[1280,1280]{1,0} add(%add.246964.3.clone.1, %xor.121229.5.clone.1) + %constant_169753_1_clone_1 = u32[] constant(4097270485) + 
%broadcast.249460.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_169753_1_clone_1), dimensions={} + %add.246966.7.clone.1 = u32[1280,1280]{1,0} add(%add.246965.3.clone.1, %broadcast.249460.24.clone.1) + %shift-left.108987.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121229.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115147.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121229.5.clone.1, %broadcast.244434.2816) + %or.114678.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108987.11.clone.1, %shift-right-logical.115147.11.clone.1) + %xor.121230.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246965.3.clone.1, %or.114678.9.clone.1) + %constant_218082_1_clone_1 = u32[] constant(2598801055) + %broadcast.249466.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218082_1_clone_1), dimensions={} + %add.246967.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121230.7.clone.1, %broadcast.249466.5.clone.1) + %add.246968.5.clone.1 = u32[1280,1280]{1,0} add(%add.246966.7.clone.1, %add.246967.5.clone.1) + %shift-left.108988.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246967.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115148.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246967.5.clone.1, %broadcast.244415.6016) + %or.114679.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108988.9.clone.1, %shift-right-logical.115148.9.clone.1) + %xor.121232.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246968.5.clone.1, %or.114679.7.clone.1) + %add.246969.3.clone.1 = u32[1280,1280]{1,0} add(%add.246968.5.clone.1, %xor.121232.5.clone.1) + %shift-left.108989.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121232.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115149.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121232.5.clone.1, %broadcast.244417.5760) + %or.114680.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108989.9.clone.1, %shift-right-logical.115149.9.clone.1) + %xor.121233.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246969.3.clone.1, %or.114680.7.clone.1) + %add.246970.3.clone.1 = u32[1280,1280]{1,0} add(%add.246969.3.clone.1, %xor.121233.5.clone.1) + %shift-left.108990.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121233.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115150.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121233.5.clone.1, %broadcast.244419.4352) + %or.114682.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108990.7.clone.1, %shift-right-logical.115150.7.clone.1) + %xor.121234.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246970.3.clone.1, %or.114682.5.clone.1) + %add.246971.3.clone.1 = u32[1280,1280]{1,0} add(%add.246970.3.clone.1, %xor.121234.3.clone.1) + %add.246972.7.clone.1 = u32[1280,1280]{1,0} add(%add.246971.3.clone.1, %broadcast.249427.44.clone.1) + %shift-left.108991.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121234.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115151.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121234.3.clone.1, %broadcast.244418.4352) + %or.114683.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108991.7.clone.1, %shift-right-logical.115151.7.clone.1) + %xor.121235.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246971.3.clone.1, %or.114683.5.clone.1) + %constant_218083_1_clone_1 = u32[] constant(1962997653) + %broadcast.249485.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218083_1_clone_1), dimensions={} + %add.246973.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121235.3.clone.1, %broadcast.249485.5.clone.1) + %add.246974.5.clone.1 = u32[1280,1280]{1,0} 
add(%add.246972.7.clone.1, %add.246973.5.clone.1) + %shift-left.108992.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246973.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115152.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246973.5.clone.1, %broadcast.244416.5760) + %or.114684.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108992.9.clone.1, %shift-right-logical.115152.9.clone.1) + %xor.121236.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246974.5.clone.1, %or.114684.7.clone.1) + %add.246975.3.clone.1 = u32[1280,1280]{1,0} add(%add.246974.5.clone.1, %xor.121236.5.clone.1) + %shift-left.108993.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121236.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115153.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121236.5.clone.1, %broadcast.244429.2304) + %or.114685.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108993.9.clone.1, %shift-right-logical.115153.9.clone.1) + %xor.121237.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246975.3.clone.1, %or.114685.7.clone.1) + %add.246976.3.clone.1 = u32[1280,1280]{1,0} add(%add.246975.3.clone.1, %xor.121237.5.clone.1) + %shift-left.108994.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121237.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115154.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121237.5.clone.1, %broadcast.244430.4608) + %or.114687.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108994.9.clone.1, %shift-right-logical.115154.9.clone.1) + %xor.121238.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246976.3.clone.1, %or.114687.7.clone.1) + %add.246977.3.clone.1 = u32[1280,1280]{1,0} add(%add.246976.3.clone.1, %xor.121238.5.clone.1) + %add.246978.7.clone.1 = u32[1280,1280]{1,0} add(%add.246977.3.clone.1, %broadcast.249429.113.clone.1) + %shift-left.108995.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121238.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115155.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121238.5.clone.1, %broadcast.244434.2816) + %or.114688.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108995.11.clone.1, %shift-right-logical.115155.11.clone.1) + %xor.121239.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246977.3.clone.1, %or.114688.9.clone.1) + %constant_218084_1_clone_1 = u32[] constant(4097270489) + %broadcast.249495.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218084_1_clone_1), dimensions={} + %add.246979.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121239.7.clone.1, %broadcast.249495.5.clone.1) + %add.246980.5.clone.1 = u32[1280,1280]{1,0} add(%add.246978.7.clone.1, %add.246979.5.clone.1) + %shift-left.108996.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246979.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115156.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246979.5.clone.1, %broadcast.244415.6016) + %or.114689.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108996.9.clone.1, %shift-right-logical.115156.9.clone.1) + %xor.121240.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246980.5.clone.1, %or.114689.7.clone.1) + %add.246981.3.clone.1 = u32[1280,1280]{1,0} add(%add.246980.5.clone.1, %xor.121240.5.clone.1) + %shift-left.108997.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121240.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115157.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121240.5.clone.1, %broadcast.244417.5760) + %or.114690.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108997.9.clone.1, %shift-right-logical.115157.9.clone.1) + %xor.121242.5.clone.1 = 
u32[1280,1280]{1,0} xor(%add.246981.3.clone.1, %or.114690.7.clone.1) + %add.246982.3.clone.1 = u32[1280,1280]{1,0} add(%add.246981.3.clone.1, %xor.121242.5.clone.1) + %shift-left.108998.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121242.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115158.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121242.5.clone.1, %broadcast.244419.4352) + %or.114691.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108998.5.clone.1, %shift-right-logical.115158.5.clone.1) + %xor.121243.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246982.3.clone.1, %or.114691.3.clone.1) + %add.246983.3.clone.1 = u32[1280,1280]{1,0} add(%add.246982.3.clone.1, %xor.121243.3.clone.1) + %add.246984.17.clone.1 = u32[1280,1280]{1,0} add(%add.246983.3.clone.1, %broadcast.249460.24.clone.1) + %shift-left.108999.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121243.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115159.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121243.3.clone.1, %broadcast.244418.4352) + %or.114692.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108999.5.clone.1, %shift-right-logical.115159.5.clone.1) + %xor.121244.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246983.3.clone.1, %or.114692.3.clone.1) + %constant_218085_1_clone_1 = u32[] constant(2598801058) + %broadcast.249505.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218085_1_clone_1), dimensions={} + %add.246985.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121244.15.clone.1, %broadcast.249505.19.clone.1) + %xor.121245.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246984.17.clone.1, %add.246985.19.clone.1) + %shift-right-logical.115160.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121245.17.clone.1, %broadcast.244468.1920) + %or.114693.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115160.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5726.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114693.13.clone.1) + %add.246986.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5726.11.clone.1, %broadcast.244470.1152) + %multiply.26182.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246986.9.clone.1, %broadcast.244471.896) + %add.246987.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26182.7.clone.1, %broadcast.244408.1024) + %maximum.3658.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246987.5.clone.1) + %abs.1522.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3658.3.clone.1) + %compare.7192.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1522.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26183.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3658.3.clone.1, %broadcast.244476.1152) + %negate.4549.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3658.3.clone.1) + %multiply.26184.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3658.3.clone.1, %negate.4549.5.clone.1) + %log-plus-one.1522.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26184.5.clone.1) + %negate.4550.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1522.3.clone.1) + %compare.7193.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4550.4.clone.1, %broadcast.244477.384), direction=LT + %select.20807.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20808.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20809.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) 
+ %select.20810.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20811.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20812.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20813.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20814.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20815.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.246988.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4550.4.clone.1, %broadcast.244496.640) + %sqrt.1522.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4550.4.clone.1) + %add.246989.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1522.5.clone.1, %broadcast.244498.640) + %select.20816.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7193.3.clone.1, %add.246988.5.clone.1, %add.246989.5.clone.1) + %multiply.26185.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20815.3.clone.1, %select.20816.3.clone.1) + %add.246990.1.clone.1 = f32[1280,1280]{1,0} add(%select.20814.3.clone.1, %multiply.26185.1.clone.1) + %multiply.26186.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246990.1.clone.1, %select.20816.3.clone.1) + %add.246991.1.clone.1 = f32[1280,1280]{1,0} add(%select.20813.3.clone.1, %multiply.26186.1.clone.1) + %multiply.26187.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246991.1.clone.1, %select.20816.3.clone.1) + %add.246992.1.clone.1 = f32[1280,1280]{1,0} add(%select.20812.3.clone.1, %multiply.26187.1.clone.1) + %multiply.26188.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246992.1.clone.1, %select.20816.3.clone.1) + %add.246993.1.clone.1 = f32[1280,1280]{1,0} add(%select.20811.3.clone.1, %multiply.26188.1.clone.1) + %multiply.26189.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246993.1.clone.1, %select.20816.3.clone.1) + %add.246994.3.clone.1 = f32[1280,1280]{1,0} add(%select.20810.5.clone.1, %multiply.26189.1.clone.1) + %multiply.26190.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246994.3.clone.1, %select.20816.3.clone.1) + %add.246995.3.clone.1 = f32[1280,1280]{1,0} add(%select.20809.5.clone.1, %multiply.26190.1.clone.1) + %multiply.26191.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246995.3.clone.1, %select.20816.3.clone.1) + %add.246996.9.clone.1 = f32[1280,1280]{1,0} add(%select.20808.11.clone.1, %multiply.26191.7.clone.1) + %multiply.26192.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246996.9.clone.1, %select.20816.3.clone.1) + %add.246998.7.clone.1 = f32[1280,1280]{1,0} add(%select.20807.7.clone.1, %multiply.26192.7.clone.1) + %multiply.26193.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246998.7.clone.1, %maximum.3658.3.clone.1) + %select.20817.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7192.3.clone.1, %multiply.26183.9.clone.1, %multiply.26193.7.clone.1) + %multiply.26194.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20817.7.clone.1, %broadcast.244500.640) + %clamp.1166.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26194.5.clone.1, %broadcast.244501.384) + %multiply.26195.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1166.3.clone.1, %broadcast.244502.1) + %constant_190241_1_clone_1 = u32[] constant(3290012183) + %broadcast.258286.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_190241_1_clone_1), 
dimensions={} + %add.251997.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.258286.44.clone.1) + %constant_190248_1_clone_1 = u32[] constant(4257276142) + %broadcast.258287.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_190248_1_clone_1), dimensions={} + %add.251998.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.258287.113.clone.1) + %add.252000.35.clone.1 = u32[1280,1280]{1,0} add(%add.251997.37.clone.1, %add.251998.99.clone.1) + %shift-left.111200.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251998.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117463.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251998.99.clone.1, %broadcast.244415.6016) + %or.116999.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111200.31.clone.1, %shift-right-logical.117463.29.clone.1) + %xor.123552.27.clone.1 = u32[1280,1280]{1,0} xor(%add.252000.35.clone.1, %or.116999.29.clone.1) + %add.252001.5.clone.1 = u32[1280,1280]{1,0} add(%add.252000.35.clone.1, %xor.123552.27.clone.1) + %shift-left.111201.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123552.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117464.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123552.27.clone.1, %broadcast.244417.5760) + %or.117000.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111201.9.clone.1, %shift-right-logical.117464.9.clone.1) + %xor.123553.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252001.5.clone.1, %or.117000.7.clone.1) + %add.252002.3.clone.1 = u32[1280,1280]{1,0} add(%add.252001.5.clone.1, %xor.123553.5.clone.1) + %shift-left.111202.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123553.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117465.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123553.5.clone.1, %broadcast.244419.4352) + %or.117001.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111202.5.clone.1, %shift-right-logical.117465.5.clone.1) + %xor.123554.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252002.3.clone.1, %or.117001.3.clone.1) + %add.252003.3.clone.1 = u32[1280,1280]{1,0} add(%add.252002.3.clone.1, %xor.123554.3.clone.1) + %add.252004.7.clone.1 = u32[1280,1280]{1,0} add(%add.252003.3.clone.1, %broadcast.258287.113.clone.1) + %shift-left.111203.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123554.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117466.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123554.3.clone.1, %broadcast.244418.4352) + %or.117002.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111203.5.clone.1, %shift-right-logical.117466.5.clone.1) + %xor.123555.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252003.3.clone.1, %or.117002.3.clone.1) + %constant_218645_1_clone_1 = u32[] constant(570974500) + %broadcast.258297.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218645_1_clone_1), dimensions={} + %add.252006.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123555.3.clone.1, %broadcast.258297.5.clone.1) + %add.252010.5.clone.1 = u32[1280,1280]{1,0} add(%add.252004.7.clone.1, %add.252006.5.clone.1) + %shift-left.111204.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252006.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117467.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252006.5.clone.1, %broadcast.244416.5760) + %or.117003.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111204.9.clone.1, %shift-right-logical.117467.9.clone.1) + %xor.123556.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252010.5.clone.1, %or.117003.7.clone.1) + %add.252011.3.clone.1 = u32[1280,1280]{1,0} 
add(%add.252010.5.clone.1, %xor.123556.5.clone.1) + %shift-left.111205.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123556.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117468.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123556.5.clone.1, %broadcast.244429.2304) + %or.117004.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111205.9.clone.1, %shift-right-logical.117468.9.clone.1) + %xor.123557.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252011.3.clone.1, %or.117004.7.clone.1) + %add.252012.3.clone.1 = u32[1280,1280]{1,0} add(%add.252011.3.clone.1, %xor.123557.5.clone.1) + %shift-left.111206.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123557.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117469.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123557.5.clone.1, %broadcast.244430.4608) + %or.117005.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111206.9.clone.1, %shift-right-logical.117469.9.clone.1) + %xor.123558.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252012.3.clone.1, %or.117005.7.clone.1) + %add.252013.3.clone.1 = u32[1280,1280]{1,0} add(%add.252012.3.clone.1, %xor.123558.5.clone.1) + %constant_190250_1_clone_1 = u32[] constant(570974499) + %broadcast.258304.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_190250_1_clone_1), dimensions={} + %add.252015.7.clone.1 = u32[1280,1280]{1,0} add(%add.252013.3.clone.1, %broadcast.258304.24.clone.1) + %shift-left.111207.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123558.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117471.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123558.5.clone.1, %broadcast.244434.2816) + %or.117006.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111207.11.clone.1, %shift-right-logical.117471.11.clone.1) + %xor.123559.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252013.3.clone.1, %or.117006.9.clone.1) + %constant_218646_1_clone_1 = u32[] constant(3290012185) + %broadcast.258307.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218646_1_clone_1), dimensions={} + %add.252016.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123559.7.clone.1, %broadcast.258307.5.clone.1) + %add.252017.5.clone.1 = u32[1280,1280]{1,0} add(%add.252015.7.clone.1, %add.252016.5.clone.1) + %shift-left.111208.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252016.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117472.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252016.5.clone.1, %broadcast.244415.6016) + %or.117007.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111208.9.clone.1, %shift-right-logical.117472.9.clone.1) + %xor.123560.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252017.5.clone.1, %or.117007.7.clone.1) + %add.252018.3.clone.1 = u32[1280,1280]{1,0} add(%add.252017.5.clone.1, %xor.123560.5.clone.1) + %shift-left.111209.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123560.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117473.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123560.5.clone.1, %broadcast.244417.5760) + %or.117008.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111209.9.clone.1, %shift-right-logical.117473.9.clone.1) + %xor.123561.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252018.3.clone.1, %or.117008.7.clone.1) + %add.252020.3.clone.1 = u32[1280,1280]{1,0} add(%add.252018.3.clone.1, %xor.123561.5.clone.1) + %shift-left.111210.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123561.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117474.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123561.5.clone.1, 
%broadcast.244419.4352) + %or.117009.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111210.7.clone.1, %shift-right-logical.117474.7.clone.1) + %xor.123562.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252020.3.clone.1, %or.117009.5.clone.1) + %add.252021.3.clone.1 = u32[1280,1280]{1,0} add(%add.252020.3.clone.1, %xor.123562.3.clone.1) + %add.252022.7.clone.1 = u32[1280,1280]{1,0} add(%add.252021.3.clone.1, %broadcast.258286.44.clone.1) + %shift-left.111211.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123562.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117476.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123562.3.clone.1, %broadcast.244418.4352) + %or.117010.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111211.7.clone.1, %shift-right-logical.117476.7.clone.1) + %xor.123563.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252021.3.clone.1, %or.117010.5.clone.1) + %constant_218647_1_clone_1 = u32[] constant(4257276145) + %broadcast.258317.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218647_1_clone_1), dimensions={} + %add.252023.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123563.3.clone.1, %broadcast.258317.5.clone.1) + %add.252025.5.clone.1 = u32[1280,1280]{1,0} add(%add.252022.7.clone.1, %add.252023.5.clone.1) + %shift-left.111212.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252023.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117477.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252023.5.clone.1, %broadcast.244416.5760) + %or.117011.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111212.9.clone.1, %shift-right-logical.117477.9.clone.1) + %xor.123566.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252025.5.clone.1, %or.117011.7.clone.1) + %add.252026.3.clone.1 = u32[1280,1280]{1,0} add(%add.252025.5.clone.1, %xor.123566.5.clone.1) + %shift-left.111213.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123566.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117478.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123566.5.clone.1, %broadcast.244429.2304) + %or.117012.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111213.9.clone.1, %shift-right-logical.117478.9.clone.1) + %xor.123567.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252026.3.clone.1, %or.117012.7.clone.1) + %add.252027.3.clone.1 = u32[1280,1280]{1,0} add(%add.252026.3.clone.1, %xor.123567.5.clone.1) + %shift-left.111214.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123567.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117479.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123567.5.clone.1, %broadcast.244430.4608) + %or.117013.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111214.9.clone.1, %shift-right-logical.117479.9.clone.1) + %xor.123568.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252027.3.clone.1, %or.117013.7.clone.1) + %add.252028.3.clone.1 = u32[1280,1280]{1,0} add(%add.252027.3.clone.1, %xor.123568.5.clone.1) + %add.252029.7.clone.1 = u32[1280,1280]{1,0} add(%add.252028.3.clone.1, %broadcast.258287.113.clone.1) + %shift-left.111215.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123568.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117481.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123568.5.clone.1, %broadcast.244434.2816) + %or.117014.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111215.11.clone.1, %shift-right-logical.117481.11.clone.1) + %xor.123569.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252028.3.clone.1, %or.117014.9.clone.1) + %constant_218648_1_clone_1 = u32[] constant(570974503) + %broadcast.258327.5.clone.1 = 
u32[1280,1280]{1,0} broadcast(%constant_218648_1_clone_1), dimensions={} + %add.252031.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123569.7.clone.1, %broadcast.258327.5.clone.1) + %add.252035.5.clone.1 = u32[1280,1280]{1,0} add(%add.252029.7.clone.1, %add.252031.5.clone.1) + %shift-left.111216.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252031.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117482.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252031.5.clone.1, %broadcast.244415.6016) + %or.117015.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111216.9.clone.1, %shift-right-logical.117482.9.clone.1) + %xor.123570.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252035.5.clone.1, %or.117015.7.clone.1) + %add.252036.3.clone.1 = u32[1280,1280]{1,0} add(%add.252035.5.clone.1, %xor.123570.5.clone.1) + %shift-left.111217.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123570.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117483.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123570.5.clone.1, %broadcast.244417.5760) + %or.117016.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111217.9.clone.1, %shift-right-logical.117483.9.clone.1) + %xor.123571.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252036.3.clone.1, %or.117016.7.clone.1) + %add.252037.3.clone.1 = u32[1280,1280]{1,0} add(%add.252036.3.clone.1, %xor.123571.5.clone.1) + %shift-left.111218.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123571.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117484.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123571.5.clone.1, %broadcast.244419.4352) + %or.117017.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111218.5.clone.1, %shift-right-logical.117484.5.clone.1) + %xor.123572.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252037.3.clone.1, %or.117017.3.clone.1) + %add.252038.3.clone.1 = u32[1280,1280]{1,0} add(%add.252037.3.clone.1, %xor.123572.3.clone.1) + %add.252040.17.clone.1 = u32[1280,1280]{1,0} add(%add.252038.3.clone.1, %broadcast.258304.24.clone.1) + %shift-left.111219.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123572.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117486.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123572.3.clone.1, %broadcast.244418.4352) + %or.117018.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111219.5.clone.1, %shift-right-logical.117486.5.clone.1) + %xor.123573.15.clone.1 = u32[1280,1280]{1,0} xor(%add.252038.3.clone.1, %or.117018.3.clone.1) + %constant_218649_1_clone_1 = u32[] constant(3290012188) + %broadcast.258337.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218649_1_clone_1), dimensions={} + %add.252041.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123573.15.clone.1, %broadcast.258337.19.clone.1) + %xor.123575.17.clone.1 = u32[1280,1280]{1,0} xor(%add.252040.17.clone.1, %add.252041.19.clone.1) + %shift-right-logical.117487.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123575.17.clone.1, %broadcast.244468.1920) + %or.117019.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117487.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5827.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117019.13.clone.1) + %add.252042.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5827.11.clone.1, %broadcast.244470.1152) + %multiply.27218.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252042.9.clone.1, %broadcast.244471.896) + %add.252043.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27218.7.clone.1, %broadcast.244408.1024) + %maximum.3759.3.clone.1 = f32[1280,1280]{1,0} 
maximum(%broadcast.244408.1024, %add.252043.5.clone.1)
[... several hundred generated HLO lines elided; they repeat two motifs over u32[1280,1280] and f32[1280,1280] operands, each instance with fresh %constant_*/%broadcast.* scalars: (1) counter-based RNG rounds (add, rotate spelled as or(shift-left, shift-right-logical), xor, with periodic key injections from the broadcast constants), finished by a shift-right-logical, an or with an exponent mask, and a bitcast-convert that turn the random bits into f32 uniforms; and (2) an inverse-erf transform (negate(log-plus-one(multiply(x, negate(x)))), a compare feeding nine select pairs that choose between two coefficient sets, a Horner multiply/add chain with a sqrt branch, a special case for |x| = 1, and a final clamp) that maps each uniform to a Gaussian sample. ...]
%shift-right-logical.117874.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123955.3.clone.1, %broadcast.244418.4352) + %or.117400.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111579.5.clone.1, %shift-right-logical.117874.5.clone.1) + %xor.123957.15.clone.1 = u32[1280,1280]{1,0} xor(%add.252898.3.clone.1, %or.117400.3.clone.1) + %constant_218735_1_clone_1 = u32[] constant(3522665283) + %broadcast.259842.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218735_1_clone_1), dimensions={} + %add.252900.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123957.15.clone.1, %broadcast.259842.19.clone.1) + %xor.123958.17.clone.1 = u32[1280,1280]{1,0} xor(%add.252899.17.clone.1, %add.252900.19.clone.1) + %shift-right-logical.117875.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123958.17.clone.1, %broadcast.244468.1920) + %or.117401.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117875.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5844.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117401.13.clone.1) + %add.252901.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5844.11.clone.1, %broadcast.244470.1152) + %multiply.27384.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252901.9.clone.1, %broadcast.244471.896) + %add.252902.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27384.7.clone.1, %broadcast.244408.1024) + %maximum.3776.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.252902.5.clone.1) + %abs.1600.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3776.3.clone.1) + %compare.7362.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1600.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27385.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3776.3.clone.1, %broadcast.244476.1152) + %negate.4705.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3776.3.clone.1) + %multiply.27386.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3776.3.clone.1, %negate.4705.5.clone.1) + %log-plus-one.1600.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27386.5.clone.1) + %negate.4706.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1600.3.clone.1) + %compare.7363.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4706.4.clone.1, %broadcast.244477.384), direction=LT + %select.21707.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21708.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21709.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21710.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21711.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21712.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21713.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21714.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21715.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.252903.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4706.4.clone.1, %broadcast.244496.640) + %sqrt.1600.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4706.4.clone.1) + 
%add.252904.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1600.5.clone.1, %broadcast.244498.640) + %select.21716.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7363.3.clone.1, %add.252903.5.clone.1, %add.252904.5.clone.1) + %multiply.27387.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21715.3.clone.1, %select.21716.3.clone.1) + %add.252905.1.clone.1 = f32[1280,1280]{1,0} add(%select.21714.3.clone.1, %multiply.27387.1.clone.1) + %multiply.27388.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252905.1.clone.1, %select.21716.3.clone.1) + %add.252906.1.clone.1 = f32[1280,1280]{1,0} add(%select.21713.3.clone.1, %multiply.27388.1.clone.1) + %multiply.27389.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252906.1.clone.1, %select.21716.3.clone.1) + %add.252907.1.clone.1 = f32[1280,1280]{1,0} add(%select.21712.3.clone.1, %multiply.27389.1.clone.1) + %multiply.27390.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252907.1.clone.1, %select.21716.3.clone.1) + %add.252908.1.clone.1 = f32[1280,1280]{1,0} add(%select.21711.3.clone.1, %multiply.27390.1.clone.1) + %multiply.27391.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252908.1.clone.1, %select.21716.3.clone.1) + %add.252909.3.clone.1 = f32[1280,1280]{1,0} add(%select.21710.5.clone.1, %multiply.27391.1.clone.1) + %multiply.27392.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252909.3.clone.1, %select.21716.3.clone.1) + %add.252910.3.clone.1 = f32[1280,1280]{1,0} add(%select.21709.5.clone.1, %multiply.27392.1.clone.1) + %multiply.27393.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252910.3.clone.1, %select.21716.3.clone.1) + %add.252911.9.clone.1 = f32[1280,1280]{1,0} add(%select.21708.11.clone.1, %multiply.27393.7.clone.1) + %multiply.27394.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252911.9.clone.1, %select.21716.3.clone.1) + %add.252912.7.clone.1 = f32[1280,1280]{1,0} add(%select.21707.7.clone.1, %multiply.27394.7.clone.1) + %multiply.27395.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252912.7.clone.1, %maximum.3776.3.clone.1) + %select.21717.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7362.3.clone.1, %multiply.27385.9.clone.1, %multiply.27395.7.clone.1) + %multiply.27396.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21717.7.clone.1, %broadcast.244500.640) + %clamp.1244.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27396.5.clone.1, %broadcast.244501.384) + %multiply.27397.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1244.3.clone.1, %broadcast.244502.1) + %constant_169075_1_clone_1 = u32[] constant(737138696) + %broadcast.249135.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_169075_1_clone_1), dimensions={} + %add.246785.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.249135.44.clone.1) + %constant_169082_1_clone_1 = u32[] constant(2005701892) + %broadcast.249136.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_169082_1_clone_1), dimensions={} + %add.246786.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.249136.113.clone.1) + %add.246787.35.clone.1 = u32[1280,1280]{1,0} add(%add.246785.37.clone.1, %add.246786.99.clone.1) + %shift-left.108920.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246786.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.115070.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246786.99.clone.1, %broadcast.244415.6016) + %or.114598.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108920.31.clone.1, %shift-right-logical.115070.29.clone.1) + %xor.121147.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246787.35.clone.1, 
%or.114598.29.clone.1) + %add.246788.5.clone.1 = u32[1280,1280]{1,0} add(%add.246787.35.clone.1, %xor.121147.27.clone.1) + %shift-left.108921.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121147.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.115071.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121147.27.clone.1, %broadcast.244417.5760) + %or.114599.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108921.9.clone.1, %shift-right-logical.115071.9.clone.1) + %xor.121148.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246788.5.clone.1, %or.114599.7.clone.1) + %add.246789.3.clone.1 = u32[1280,1280]{1,0} add(%add.246788.5.clone.1, %xor.121148.5.clone.1) + %shift-left.108922.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121148.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115073.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121148.5.clone.1, %broadcast.244419.4352) + %or.114600.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108922.5.clone.1, %shift-right-logical.115073.5.clone.1) + %xor.121149.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246789.3.clone.1, %or.114600.3.clone.1) + %add.246790.3.clone.1 = u32[1280,1280]{1,0} add(%add.246789.3.clone.1, %xor.121149.3.clone.1) + %add.246791.7.clone.1 = u32[1280,1280]{1,0} add(%add.246790.3.clone.1, %broadcast.249136.113.clone.1) + %shift-left.108923.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121149.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115074.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121149.3.clone.1, %broadcast.244418.4352) + %or.114601.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108923.5.clone.1, %shift-right-logical.115074.5.clone.1) + %xor.121150.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246790.3.clone.1, %or.114601.3.clone.1) + %constant_218066_1_clone_1 = u32[] constant(1202869975) + %broadcast.249146.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218066_1_clone_1), dimensions={} + %add.246792.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121150.3.clone.1, %broadcast.249146.5.clone.1) + %add.246793.5.clone.1 = u32[1280,1280]{1,0} add(%add.246791.7.clone.1, %add.246792.5.clone.1) + %shift-left.108924.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246792.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115075.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246792.5.clone.1, %broadcast.244416.5760) + %or.114602.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108924.9.clone.1, %shift-right-logical.115075.9.clone.1) + %xor.121152.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246793.5.clone.1, %or.114602.7.clone.1) + %add.246794.3.clone.1 = u32[1280,1280]{1,0} add(%add.246793.5.clone.1, %xor.121152.5.clone.1) + %shift-left.108925.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121152.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115076.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121152.5.clone.1, %broadcast.244429.2304) + %or.114603.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108925.9.clone.1, %shift-right-logical.115076.9.clone.1) + %xor.121153.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246794.3.clone.1, %or.114603.7.clone.1) + %add.246795.3.clone.1 = u32[1280,1280]{1,0} add(%add.246794.3.clone.1, %xor.121153.5.clone.1) + %shift-left.108926.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121153.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115078.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121153.5.clone.1, %broadcast.244430.4608) + %or.114604.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108926.9.clone.1, 
%shift-right-logical.115078.9.clone.1) + %xor.121154.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246795.3.clone.1, %or.114604.7.clone.1) + %add.246796.3.clone.1 = u32[1280,1280]{1,0} add(%add.246795.3.clone.1, %xor.121154.5.clone.1) + %constant_169084_1_clone_1 = u32[] constant(1202869974) + %broadcast.249153.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_169084_1_clone_1), dimensions={} + %add.246797.7.clone.1 = u32[1280,1280]{1,0} add(%add.246796.3.clone.1, %broadcast.249153.24.clone.1) + %shift-left.108927.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121154.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115079.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121154.5.clone.1, %broadcast.244434.2816) + %or.114605.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108927.11.clone.1, %shift-right-logical.115079.11.clone.1) + %xor.121155.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246796.3.clone.1, %or.114605.9.clone.1) + %constant_218067_1_clone_1 = u32[] constant(737138698) + %broadcast.249156.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218067_1_clone_1), dimensions={} + %add.246798.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121155.7.clone.1, %broadcast.249156.5.clone.1) + %add.246799.5.clone.1 = u32[1280,1280]{1,0} add(%add.246797.7.clone.1, %add.246798.5.clone.1) + %shift-left.108928.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246798.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115080.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246798.5.clone.1, %broadcast.244415.6016) + %or.114606.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108928.9.clone.1, %shift-right-logical.115080.9.clone.1) + %xor.121157.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246799.5.clone.1, %or.114606.7.clone.1) + %add.246800.3.clone.1 = u32[1280,1280]{1,0} add(%add.246799.5.clone.1, %xor.121157.5.clone.1) + %shift-left.108929.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121157.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115081.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121157.5.clone.1, %broadcast.244417.5760) + %or.114607.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108929.9.clone.1, %shift-right-logical.115081.9.clone.1) + %xor.121158.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246800.3.clone.1, %or.114607.7.clone.1) + %add.246801.3.clone.1 = u32[1280,1280]{1,0} add(%add.246800.3.clone.1, %xor.121158.5.clone.1) + %shift-left.108930.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121158.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115082.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121158.5.clone.1, %broadcast.244419.4352) + %or.114608.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108930.7.clone.1, %shift-right-logical.115082.7.clone.1) + %xor.121159.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246801.3.clone.1, %or.114608.5.clone.1) + %add.246802.3.clone.1 = u32[1280,1280]{1,0} add(%add.246801.3.clone.1, %xor.121159.3.clone.1) + %add.246803.7.clone.1 = u32[1280,1280]{1,0} add(%add.246802.3.clone.1, %broadcast.249135.44.clone.1) + %shift-left.108931.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121159.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115083.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121159.3.clone.1, %broadcast.244418.4352) + %or.114609.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108931.7.clone.1, %shift-right-logical.115083.7.clone.1) + %xor.121160.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246802.3.clone.1, %or.114609.5.clone.1) + %constant_218068_1_clone_1 = u32[] 
constant(2005701895) + %broadcast.249166.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218068_1_clone_1), dimensions={} + %add.246804.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121160.3.clone.1, %broadcast.249166.5.clone.1) + %add.246805.5.clone.1 = u32[1280,1280]{1,0} add(%add.246803.7.clone.1, %add.246804.5.clone.1) + %shift-left.108932.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246804.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115084.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246804.5.clone.1, %broadcast.244416.5760) + %or.114610.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108932.9.clone.1, %shift-right-logical.115084.9.clone.1) + %xor.121161.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246805.5.clone.1, %or.114610.7.clone.1) + %add.246806.3.clone.1 = u32[1280,1280]{1,0} add(%add.246805.5.clone.1, %xor.121161.5.clone.1) + %shift-left.108933.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121161.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115085.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121161.5.clone.1, %broadcast.244429.2304) + %or.114611.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108933.9.clone.1, %shift-right-logical.115085.9.clone.1) + %xor.121162.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246806.3.clone.1, %or.114611.7.clone.1) + %add.246807.3.clone.1 = u32[1280,1280]{1,0} add(%add.246806.3.clone.1, %xor.121162.5.clone.1) + %shift-left.108934.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121162.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115086.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121162.5.clone.1, %broadcast.244430.4608) + %or.114612.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108934.9.clone.1, %shift-right-logical.115086.9.clone.1) + %xor.121163.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246807.3.clone.1, %or.114612.7.clone.1) + %add.246808.3.clone.1 = u32[1280,1280]{1,0} add(%add.246807.3.clone.1, %xor.121163.5.clone.1) + %add.246809.7.clone.1 = u32[1280,1280]{1,0} add(%add.246808.3.clone.1, %broadcast.249136.113.clone.1) + %shift-left.108935.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121163.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115088.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121163.5.clone.1, %broadcast.244434.2816) + %or.114613.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108935.11.clone.1, %shift-right-logical.115088.11.clone.1) + %xor.121164.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246808.3.clone.1, %or.114613.9.clone.1) + %constant_218069_1_clone_1 = u32[] constant(1202869978) + %broadcast.249176.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218069_1_clone_1), dimensions={} + %add.246810.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121164.7.clone.1, %broadcast.249176.5.clone.1) + %add.246811.5.clone.1 = u32[1280,1280]{1,0} add(%add.246809.7.clone.1, %add.246810.5.clone.1) + %shift-left.108936.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246810.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115089.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246810.5.clone.1, %broadcast.244415.6016) + %or.114614.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108936.9.clone.1, %shift-right-logical.115089.9.clone.1) + %xor.121165.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246811.5.clone.1, %or.114614.7.clone.1) + %add.246812.3.clone.1 = u32[1280,1280]{1,0} add(%add.246811.5.clone.1, %xor.121165.5.clone.1) + %shift-left.108937.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121165.5.clone.1, %broadcast.244416.5760) + 
%shift-right-logical.115090.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121165.5.clone.1, %broadcast.244417.5760) + %or.114615.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108937.9.clone.1, %shift-right-logical.115090.9.clone.1) + %xor.121167.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246812.3.clone.1, %or.114615.7.clone.1) + %add.246813.3.clone.1 = u32[1280,1280]{1,0} add(%add.246812.3.clone.1, %xor.121167.5.clone.1) + %shift-left.108938.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121167.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115091.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121167.5.clone.1, %broadcast.244419.4352) + %or.114616.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108938.5.clone.1, %shift-right-logical.115091.5.clone.1) + %xor.121168.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246813.3.clone.1, %or.114616.3.clone.1) + %add.246814.3.clone.1 = u32[1280,1280]{1,0} add(%add.246813.3.clone.1, %xor.121168.3.clone.1) + %add.246815.17.clone.1 = u32[1280,1280]{1,0} add(%add.246814.3.clone.1, %broadcast.249153.24.clone.1) + %shift-left.108939.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121168.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115093.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121168.3.clone.1, %broadcast.244418.4352) + %or.114617.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108939.5.clone.1, %shift-right-logical.115093.5.clone.1) + %xor.121169.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246814.3.clone.1, %or.114617.3.clone.1) + %constant_218070_1_clone_1 = u32[] constant(737138701) + %broadcast.249186.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218070_1_clone_1), dimensions={} + %add.246816.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121169.15.clone.1, %broadcast.249186.19.clone.1) + %xor.121170.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246815.17.clone.1, %add.246816.19.clone.1) + %shift-right-logical.115094.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121170.17.clone.1, %broadcast.244468.1920) + %or.114618.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115094.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5723.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114618.13.clone.1) + %add.246817.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5723.11.clone.1, %broadcast.244470.1152) + %multiply.26140.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246817.9.clone.1, %broadcast.244471.896) + %add.246818.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26140.7.clone.1, %broadcast.244408.1024) + %maximum.3655.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246818.5.clone.1) + %abs.1519.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3655.3.clone.1) + %compare.7186.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1519.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26141.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3655.3.clone.1, %broadcast.244476.1152) + %negate.4543.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3655.3.clone.1) + %multiply.26142.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3655.3.clone.1, %negate.4543.5.clone.1) + %log-plus-one.1519.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26142.5.clone.1) + %negate.4544.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1519.3.clone.1) + %compare.7187.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4544.4.clone.1, %broadcast.244477.384), direction=LT + %select.20774.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %broadcast.244478.896, 
%broadcast.244479.896) + %select.20775.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20776.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20777.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20778.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20779.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20780.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20781.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20782.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.246819.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4544.4.clone.1, %broadcast.244496.640) + %sqrt.1519.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4544.4.clone.1) + %add.246820.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1519.5.clone.1, %broadcast.244498.640) + %select.20783.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7187.3.clone.1, %add.246819.5.clone.1, %add.246820.5.clone.1) + %multiply.26143.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20782.3.clone.1, %select.20783.3.clone.1) + %add.246821.1.clone.1 = f32[1280,1280]{1,0} add(%select.20781.3.clone.1, %multiply.26143.1.clone.1) + %multiply.26144.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246821.1.clone.1, %select.20783.3.clone.1) + %add.246822.1.clone.1 = f32[1280,1280]{1,0} add(%select.20780.3.clone.1, %multiply.26144.1.clone.1) + %multiply.26145.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246822.1.clone.1, %select.20783.3.clone.1) + %add.246823.1.clone.1 = f32[1280,1280]{1,0} add(%select.20779.3.clone.1, %multiply.26145.1.clone.1) + %multiply.26146.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246823.1.clone.1, %select.20783.3.clone.1) + %add.246824.1.clone.1 = f32[1280,1280]{1,0} add(%select.20778.3.clone.1, %multiply.26146.1.clone.1) + %multiply.26147.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246824.1.clone.1, %select.20783.3.clone.1) + %add.246825.3.clone.1 = f32[1280,1280]{1,0} add(%select.20777.5.clone.1, %multiply.26147.1.clone.1) + %multiply.26148.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246825.3.clone.1, %select.20783.3.clone.1) + %add.246826.3.clone.1 = f32[1280,1280]{1,0} add(%select.20776.5.clone.1, %multiply.26148.1.clone.1) + %multiply.26149.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246826.3.clone.1, %select.20783.3.clone.1) + %add.246827.9.clone.1 = f32[1280,1280]{1,0} add(%select.20775.11.clone.1, %multiply.26149.7.clone.1) + %multiply.26150.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246827.9.clone.1, %select.20783.3.clone.1) + %add.246828.7.clone.1 = f32[1280,1280]{1,0} add(%select.20774.7.clone.1, %multiply.26150.7.clone.1) + %multiply.26151.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246828.7.clone.1, %maximum.3655.3.clone.1) + %select.20784.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7186.3.clone.1, %multiply.26141.9.clone.1, %multiply.26151.7.clone.1) + %multiply.26152.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20784.7.clone.1, %broadcast.244500.640) + %clamp.1163.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26152.5.clone.1, 
%broadcast.244501.384) + %multiply.26153.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1163.3.clone.1, %broadcast.244502.1) + %constant_182936_1_clone_1 = u32[] constant(1379923113) + %broadcast.255148.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_182936_1_clone_1), dimensions={} + %add.250199.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.255148.44.clone.1) + %constant_182943_1_clone_1 = u32[] constant(1461771423) + %broadcast.255149.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_182943_1_clone_1), dimensions={} + %add.250200.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.255149.113.clone.1) + %add.250201.35.clone.1 = u32[1280,1280]{1,0} add(%add.250199.37.clone.1, %add.250200.99.clone.1) + %shift-left.110396.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250200.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116634.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250200.99.clone.1, %broadcast.244415.6016) + %or.116152.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110396.31.clone.1, %shift-right-logical.116634.29.clone.1) + %xor.122724.27.clone.1 = u32[1280,1280]{1,0} xor(%add.250201.35.clone.1, %or.116152.29.clone.1) + %add.250202.5.clone.1 = u32[1280,1280]{1,0} add(%add.250201.35.clone.1, %xor.122724.27.clone.1) + %shift-left.110397.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122724.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116635.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122724.27.clone.1, %broadcast.244417.5760) + %or.116153.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110397.9.clone.1, %shift-right-logical.116635.9.clone.1) + %xor.122725.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250202.5.clone.1, %or.116153.7.clone.1) + %add.250203.3.clone.1 = u32[1280,1280]{1,0} add(%add.250202.5.clone.1, %xor.122725.5.clone.1) + %shift-left.110398.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122725.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116636.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122725.5.clone.1, %broadcast.244419.4352) + %or.116154.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110398.5.clone.1, %shift-right-logical.116636.5.clone.1) + %xor.122726.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250203.3.clone.1, %or.116154.3.clone.1) + %add.250205.3.clone.1 = u32[1280,1280]{1,0} add(%add.250203.3.clone.1, %xor.122726.3.clone.1) + %add.250209.7.clone.1 = u32[1280,1280]{1,0} add(%add.250205.3.clone.1, %broadcast.255149.113.clone.1) + %shift-left.110399.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122726.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116637.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122726.3.clone.1, %broadcast.244418.4352) + %or.116156.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110399.5.clone.1, %shift-right-logical.116637.5.clone.1) + %xor.122727.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250205.3.clone.1, %or.116156.3.clone.1) + %constant_218444_1_clone_1 = u32[] constant(516830189) + %broadcast.255159.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218444_1_clone_1), dimensions={} + %add.250210.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122727.3.clone.1, %broadcast.255159.5.clone.1) + %add.250211.5.clone.1 = u32[1280,1280]{1,0} add(%add.250209.7.clone.1, %add.250210.5.clone.1) + %shift-left.110400.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250210.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116638.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250210.5.clone.1, 
%broadcast.244416.5760) + %or.116157.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110400.9.clone.1, %shift-right-logical.116638.9.clone.1) + %xor.122728.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250211.5.clone.1, %or.116157.7.clone.1) + %add.250212.3.clone.1 = u32[1280,1280]{1,0} add(%add.250211.5.clone.1, %xor.122728.5.clone.1) + %shift-left.110401.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122728.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116639.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122728.5.clone.1, %broadcast.244429.2304) + %or.116158.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110401.9.clone.1, %shift-right-logical.116639.9.clone.1) + %xor.122729.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250212.3.clone.1, %or.116158.7.clone.1) + %add.250214.3.clone.1 = u32[1280,1280]{1,0} add(%add.250212.3.clone.1, %xor.122729.5.clone.1) + %shift-left.110402.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122729.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116640.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122729.5.clone.1, %broadcast.244430.4608) + %or.116159.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110402.9.clone.1, %shift-right-logical.116640.9.clone.1) + %xor.122730.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250214.3.clone.1, %or.116159.7.clone.1) + %add.250215.3.clone.1 = u32[1280,1280]{1,0} add(%add.250214.3.clone.1, %xor.122730.5.clone.1) + %constant_182945_1_clone_1 = u32[] constant(516830188) + %broadcast.255166.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_182945_1_clone_1), dimensions={} + %add.250216.7.clone.1 = u32[1280,1280]{1,0} add(%add.250215.3.clone.1, %broadcast.255166.24.clone.1) + %shift-left.110403.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122730.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116641.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122730.5.clone.1, %broadcast.244434.2816) + %or.116160.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110403.11.clone.1, %shift-right-logical.116641.11.clone.1) + %xor.122731.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250215.3.clone.1, %or.116160.9.clone.1) + %constant_218445_1_clone_1 = u32[] constant(1379923115) + %broadcast.255169.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218445_1_clone_1), dimensions={} + %add.250217.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122731.7.clone.1, %broadcast.255169.5.clone.1) + %add.250219.5.clone.1 = u32[1280,1280]{1,0} add(%add.250216.7.clone.1, %add.250217.5.clone.1) + %shift-left.110405.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250217.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116642.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250217.5.clone.1, %broadcast.244415.6016) + %or.116161.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110405.9.clone.1, %shift-right-logical.116642.9.clone.1) + %xor.122732.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250219.5.clone.1, %or.116161.7.clone.1) + %add.250220.3.clone.1 = u32[1280,1280]{1,0} add(%add.250219.5.clone.1, %xor.122732.5.clone.1) + %shift-left.110406.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122732.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116643.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122732.5.clone.1, %broadcast.244417.5760) + %or.116162.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110406.9.clone.1, %shift-right-logical.116643.9.clone.1) + %xor.122734.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250220.3.clone.1, %or.116162.7.clone.1) + %add.250221.3.clone.1 = 
u32[1280,1280]{1,0} add(%add.250220.3.clone.1, %xor.122734.5.clone.1) + %shift-left.110407.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122734.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116644.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122734.5.clone.1, %broadcast.244419.4352) + %or.116163.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110407.7.clone.1, %shift-right-logical.116644.7.clone.1) + %xor.122735.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250221.3.clone.1, %or.116163.5.clone.1) + %add.250222.3.clone.1 = u32[1280,1280]{1,0} add(%add.250221.3.clone.1, %xor.122735.3.clone.1) + %add.250224.7.clone.1 = u32[1280,1280]{1,0} add(%add.250222.3.clone.1, %broadcast.255148.44.clone.1) + %shift-left.110408.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122735.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116645.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122735.3.clone.1, %broadcast.244418.4352) + %or.116164.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110408.7.clone.1, %shift-right-logical.116645.7.clone.1) + %xor.122736.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250222.3.clone.1, %or.116164.5.clone.1) + %constant_218446_1_clone_1 = u32[] constant(1461771426) + %broadcast.255179.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218446_1_clone_1), dimensions={} + %add.250225.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122736.3.clone.1, %broadcast.255179.5.clone.1) + %add.250226.5.clone.1 = u32[1280,1280]{1,0} add(%add.250224.7.clone.1, %add.250225.5.clone.1) + %shift-left.110410.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250225.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116646.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250225.5.clone.1, %broadcast.244416.5760) + %or.116166.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110410.9.clone.1, %shift-right-logical.116646.9.clone.1) + %xor.122737.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250226.5.clone.1, %or.116166.7.clone.1) + %add.250227.3.clone.1 = u32[1280,1280]{1,0} add(%add.250226.5.clone.1, %xor.122737.5.clone.1) + %shift-left.110411.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122737.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116647.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122737.5.clone.1, %broadcast.244429.2304) + %or.116167.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110411.9.clone.1, %shift-right-logical.116647.9.clone.1) + %xor.122739.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250227.3.clone.1, %or.116167.7.clone.1) + %add.250228.3.clone.1 = u32[1280,1280]{1,0} add(%add.250227.3.clone.1, %xor.122739.5.clone.1) + %shift-left.110412.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122739.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116648.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122739.5.clone.1, %broadcast.244430.4608) + %or.116168.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110412.9.clone.1, %shift-right-logical.116648.9.clone.1) + %xor.122740.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250228.3.clone.1, %or.116168.7.clone.1) + %add.250230.3.clone.1 = u32[1280,1280]{1,0} add(%add.250228.3.clone.1, %xor.122740.5.clone.1) + %add.250234.7.clone.1 = u32[1280,1280]{1,0} add(%add.250230.3.clone.1, %broadcast.255149.113.clone.1) + %shift-left.110413.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122740.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116649.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122740.5.clone.1, %broadcast.244434.2816) + %or.116169.9.clone.1 = 
u32[1280,1280]{1,0} or(%shift-left.110413.11.clone.1, %shift-right-logical.116649.11.clone.1) + %xor.122741.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250230.3.clone.1, %or.116169.9.clone.1) + %constant_218447_1_clone_1 = u32[] constant(516830192) + %broadcast.255189.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218447_1_clone_1), dimensions={} + %add.250235.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122741.7.clone.1, %broadcast.255189.5.clone.1) + %add.250236.5.clone.1 = u32[1280,1280]{1,0} add(%add.250234.7.clone.1, %add.250235.5.clone.1) + %shift-left.110415.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250235.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116650.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250235.5.clone.1, %broadcast.244415.6016) + %or.116171.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110415.9.clone.1, %shift-right-logical.116650.9.clone.1) + %xor.122742.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250236.5.clone.1, %or.116171.7.clone.1) + %add.250237.3.clone.1 = u32[1280,1280]{1,0} add(%add.250236.5.clone.1, %xor.122742.5.clone.1) + %shift-left.110416.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122742.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116651.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122742.5.clone.1, %broadcast.244417.5760) + %or.116172.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110416.9.clone.1, %shift-right-logical.116651.9.clone.1) + %xor.122744.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250237.3.clone.1, %or.116172.7.clone.1) + %add.250239.3.clone.1 = u32[1280,1280]{1,0} add(%add.250237.3.clone.1, %xor.122744.5.clone.1) + %shift-left.110417.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122744.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116652.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122744.5.clone.1, %broadcast.244419.4352) + %or.116173.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110417.5.clone.1, %shift-right-logical.116652.5.clone.1) + %xor.122745.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250239.3.clone.1, %or.116173.3.clone.1) + %add.250240.3.clone.1 = u32[1280,1280]{1,0} add(%add.250239.3.clone.1, %xor.122745.3.clone.1) + %add.250241.17.clone.1 = u32[1280,1280]{1,0} add(%add.250240.3.clone.1, %broadcast.255166.24.clone.1) + %shift-left.110418.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122745.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116653.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122745.3.clone.1, %broadcast.244418.4352) + %or.116174.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110418.5.clone.1, %shift-right-logical.116653.5.clone.1) + %xor.122746.15.clone.1 = u32[1280,1280]{1,0} xor(%add.250240.3.clone.1, %or.116174.3.clone.1) + %constant_218448_1_clone_1 = u32[] constant(1379923118) + %broadcast.255199.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218448_1_clone_1), dimensions={} + %add.250242.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122746.15.clone.1, %broadcast.255199.19.clone.1) + %xor.122747.17.clone.1 = u32[1280,1280]{1,0} xor(%add.250241.17.clone.1, %add.250242.19.clone.1) + %shift-right-logical.116654.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122747.17.clone.1, %broadcast.244468.1920) + %or.116176.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116654.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5791.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116176.13.clone.1) + %add.250244.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5791.11.clone.1, 
%broadcast.244470.1152) + %multiply.26843.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250244.9.clone.1, %broadcast.244471.896) + %add.250245.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26843.7.clone.1, %broadcast.244408.1024) + %maximum.3723.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.250245.5.clone.1) + %abs.1565.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3723.3.clone.1) + %compare.7280.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1565.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26844.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3723.3.clone.1, %broadcast.244476.1152) + %negate.4635.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3723.3.clone.1) + %multiply.26845.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3723.3.clone.1, %negate.4635.5.clone.1) + %log-plus-one.1565.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26845.5.clone.1) + %negate.4636.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1565.3.clone.1) + %compare.7281.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4636.4.clone.1, %broadcast.244477.384), direction=LT + %select.21301.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21302.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21303.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21304.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21305.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21306.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21307.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21308.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21309.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.250246.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4636.4.clone.1, %broadcast.244496.640) + %sqrt.1565.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4636.4.clone.1) + %add.250247.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1565.5.clone.1, %broadcast.244498.640) + %select.21310.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7281.3.clone.1, %add.250246.5.clone.1, %add.250247.5.clone.1) + %multiply.26846.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21309.3.clone.1, %select.21310.3.clone.1) + %add.250249.1.clone.1 = f32[1280,1280]{1,0} add(%select.21308.3.clone.1, %multiply.26846.1.clone.1) + %multiply.26847.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250249.1.clone.1, %select.21310.3.clone.1) + %add.250250.1.clone.1 = f32[1280,1280]{1,0} add(%select.21307.3.clone.1, %multiply.26847.1.clone.1) + %multiply.26848.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250250.1.clone.1, %select.21310.3.clone.1) + %add.250251.1.clone.1 = f32[1280,1280]{1,0} add(%select.21306.3.clone.1, %multiply.26848.1.clone.1) + %multiply.26849.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250251.1.clone.1, %select.21310.3.clone.1) + %add.250252.1.clone.1 = f32[1280,1280]{1,0} add(%select.21305.3.clone.1, %multiply.26849.1.clone.1) + %multiply.26850.1.clone.1 = f32[1280,1280]{1,0} 
multiply(%add.250252.1.clone.1, %select.21310.3.clone.1) + %add.250253.3.clone.1 = f32[1280,1280]{1,0} add(%select.21304.5.clone.1, %multiply.26850.1.clone.1) + %multiply.26851.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250253.3.clone.1, %select.21310.3.clone.1) + %add.250255.3.clone.1 = f32[1280,1280]{1,0} add(%select.21303.5.clone.1, %multiply.26851.1.clone.1) + %multiply.26852.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250255.3.clone.1, %select.21310.3.clone.1) + %add.250258.9.clone.1 = f32[1280,1280]{1,0} add(%select.21302.11.clone.1, %multiply.26852.7.clone.1) + %multiply.26853.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250258.9.clone.1, %select.21310.3.clone.1) + %add.250259.7.clone.1 = f32[1280,1280]{1,0} add(%select.21301.7.clone.1, %multiply.26853.7.clone.1) + %multiply.26854.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250259.7.clone.1, %maximum.3723.3.clone.1) + %select.21311.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7280.3.clone.1, %multiply.26844.9.clone.1, %multiply.26854.7.clone.1) + %multiply.26855.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21311.7.clone.1, %broadcast.244500.640) + %clamp.1209.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26855.5.clone.1, %broadcast.244501.384) + %multiply.26856.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1209.3.clone.1, %broadcast.244502.1) + %constant_168524_1_clone_1 = u32[] constant(552660468) + %broadcast.248904.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_168524_1_clone_1), dimensions={} + %add.246636.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.248904.44.clone.1) + %constant_168531_1_clone_1 = u32[] constant(385509064) + %broadcast.248905.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_168531_1_clone_1), dimensions={} + %add.246637.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.248905.113.clone.1) + %add.246638.35.clone.1 = u32[1280,1280]{1,0} add(%add.246636.37.clone.1, %add.246637.99.clone.1) + %shift-left.108860.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246637.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.114995.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246637.99.clone.1, %broadcast.244415.6016) + %or.114531.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108860.31.clone.1, %shift-right-logical.114995.29.clone.1) + %xor.121080.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246638.35.clone.1, %or.114531.29.clone.1) + %add.246639.5.clone.1 = u32[1280,1280]{1,0} add(%add.246638.35.clone.1, %xor.121080.27.clone.1) + %shift-left.108861.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121080.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.114996.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121080.27.clone.1, %broadcast.244417.5760) + %or.114532.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108861.9.clone.1, %shift-right-logical.114996.9.clone.1) + %xor.121081.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246639.5.clone.1, %or.114532.7.clone.1) + %add.246640.3.clone.1 = u32[1280,1280]{1,0} add(%add.246639.5.clone.1, %xor.121081.5.clone.1) + %shift-left.108862.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121081.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114998.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121081.5.clone.1, %broadcast.244419.4352) + %or.114533.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108862.5.clone.1, %shift-right-logical.114998.5.clone.1) + %xor.121082.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246640.3.clone.1, %or.114533.3.clone.1) 
+ %add.246641.3.clone.1 = u32[1280,1280]{1,0} add(%add.246640.3.clone.1, %xor.121082.3.clone.1) + %add.246642.7.clone.1 = u32[1280,1280]{1,0} add(%add.246641.3.clone.1, %broadcast.248905.113.clone.1) + %shift-left.108863.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121082.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114999.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121082.3.clone.1, %broadcast.244418.4352) + %or.114534.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108863.5.clone.1, %shift-right-logical.114999.5.clone.1) + %xor.121083.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246641.3.clone.1, %or.114534.3.clone.1) + %constant_218051_1_clone_1 = u32[] constant(769364199) + %broadcast.248915.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218051_1_clone_1), dimensions={} + %add.246643.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121083.3.clone.1, %broadcast.248915.5.clone.1) + %add.246644.5.clone.1 = u32[1280,1280]{1,0} add(%add.246642.7.clone.1, %add.246643.5.clone.1) + %shift-left.108864.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246643.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115000.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246643.5.clone.1, %broadcast.244416.5760) + %or.114535.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108864.9.clone.1, %shift-right-logical.115000.9.clone.1) + %xor.121084.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246644.5.clone.1, %or.114535.7.clone.1) + %add.246645.3.clone.1 = u32[1280,1280]{1,0} add(%add.246644.5.clone.1, %xor.121084.5.clone.1) + %shift-left.108865.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121084.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115001.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121084.5.clone.1, %broadcast.244429.2304) + %or.114536.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108865.9.clone.1, %shift-right-logical.115001.9.clone.1) + %xor.121085.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246645.3.clone.1, %or.114536.7.clone.1) + %add.246647.3.clone.1 = u32[1280,1280]{1,0} add(%add.246645.3.clone.1, %xor.121085.5.clone.1) + %shift-left.108866.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121085.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115003.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121085.5.clone.1, %broadcast.244430.4608) + %or.114538.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108866.9.clone.1, %shift-right-logical.115003.9.clone.1) + %xor.121086.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246647.3.clone.1, %or.114538.7.clone.1) + %add.246650.3.clone.1 = u32[1280,1280]{1,0} add(%add.246647.3.clone.1, %xor.121086.5.clone.1) + %constant_168533_1_clone_1 = u32[] constant(769364198) + %broadcast.248924.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_168533_1_clone_1), dimensions={} + %add.246651.7.clone.1 = u32[1280,1280]{1,0} add(%add.246650.3.clone.1, %broadcast.248924.24.clone.1) + %shift-left.108867.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121086.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115004.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121086.5.clone.1, %broadcast.244434.2816) + %or.114539.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108867.11.clone.1, %shift-right-logical.115004.11.clone.1) + %xor.121087.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246650.3.clone.1, %or.114539.9.clone.1) + %constant_218052_1_clone_1 = u32[] constant(552660470) + %broadcast.248927.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218052_1_clone_1), dimensions={} + 
%add.246652.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121087.7.clone.1, %broadcast.248927.5.clone.1) + %add.246653.5.clone.1 = u32[1280,1280]{1,0} add(%add.246651.7.clone.1, %add.246652.5.clone.1) + %shift-left.108868.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246652.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115005.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246652.5.clone.1, %broadcast.244415.6016) + %or.114540.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108868.9.clone.1, %shift-right-logical.115005.9.clone.1) + %xor.121088.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246653.5.clone.1, %or.114540.7.clone.1) + %add.246655.3.clone.1 = u32[1280,1280]{1,0} add(%add.246653.5.clone.1, %xor.121088.5.clone.1) + %shift-left.108869.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121088.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115006.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121088.5.clone.1, %broadcast.244417.5760) + %or.114541.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108869.9.clone.1, %shift-right-logical.115006.9.clone.1) + %xor.121089.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246655.3.clone.1, %or.114541.7.clone.1) + %add.246656.3.clone.1 = u32[1280,1280]{1,0} add(%add.246655.3.clone.1, %xor.121089.5.clone.1) + %shift-left.108870.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121089.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115007.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121089.5.clone.1, %broadcast.244419.4352) + %or.114543.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108870.7.clone.1, %shift-right-logical.115007.7.clone.1) + %xor.121090.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246656.3.clone.1, %or.114543.5.clone.1) + %add.246657.3.clone.1 = u32[1280,1280]{1,0} add(%add.246656.3.clone.1, %xor.121090.3.clone.1) + %add.246658.7.clone.1 = u32[1280,1280]{1,0} add(%add.246657.3.clone.1, %broadcast.248904.44.clone.1) + %shift-left.108871.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121090.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115008.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121090.3.clone.1, %broadcast.244418.4352) + %or.114544.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108871.7.clone.1, %shift-right-logical.115008.7.clone.1) + %xor.121091.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246657.3.clone.1, %or.114544.5.clone.1) + %constant_218053_1_clone_1 = u32[] constant(385509067) + %broadcast.248937.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218053_1_clone_1), dimensions={} + %add.246660.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121091.3.clone.1, %broadcast.248937.5.clone.1) + %add.246661.5.clone.1 = u32[1280,1280]{1,0} add(%add.246658.7.clone.1, %add.246660.5.clone.1) + %shift-left.108872.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246660.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.115009.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246660.5.clone.1, %broadcast.244416.5760) + %or.114545.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108872.9.clone.1, %shift-right-logical.115009.9.clone.1) + %xor.121092.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246661.5.clone.1, %or.114545.7.clone.1) + %add.246662.3.clone.1 = u32[1280,1280]{1,0} add(%add.246661.5.clone.1, %xor.121092.5.clone.1) + %shift-left.108873.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121092.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.115010.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121092.5.clone.1, %broadcast.244429.2304) + 
%or.114546.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108873.9.clone.1, %shift-right-logical.115010.9.clone.1) + %xor.121093.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246662.3.clone.1, %or.114546.7.clone.1) + %add.246663.3.clone.1 = u32[1280,1280]{1,0} add(%add.246662.3.clone.1, %xor.121093.5.clone.1) + %shift-left.108874.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121093.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.115011.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121093.5.clone.1, %broadcast.244430.4608) + %or.114548.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108874.9.clone.1, %shift-right-logical.115011.9.clone.1) + %xor.121094.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246663.3.clone.1, %or.114548.7.clone.1) + %add.246665.3.clone.1 = u32[1280,1280]{1,0} add(%add.246663.3.clone.1, %xor.121094.5.clone.1) + %add.246666.7.clone.1 = u32[1280,1280]{1,0} add(%add.246665.3.clone.1, %broadcast.248905.113.clone.1) + %shift-left.108875.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121094.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.115013.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121094.5.clone.1, %broadcast.244434.2816) + %or.114549.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108875.11.clone.1, %shift-right-logical.115013.11.clone.1) + %xor.121095.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246665.3.clone.1, %or.114549.9.clone.1) + %constant_218054_1_clone_1 = u32[] constant(769364202) + %broadcast.248949.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218054_1_clone_1), dimensions={} + %add.246667.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121095.7.clone.1, %broadcast.248949.5.clone.1) + %add.246668.5.clone.1 = u32[1280,1280]{1,0} add(%add.246666.7.clone.1, %add.246667.5.clone.1) + %shift-left.108876.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246667.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.115014.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246667.5.clone.1, %broadcast.244415.6016) + %or.114550.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108876.9.clone.1, %shift-right-logical.115014.9.clone.1) + %xor.121096.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246668.5.clone.1, %or.114550.7.clone.1) + %add.246669.3.clone.1 = u32[1280,1280]{1,0} add(%add.246668.5.clone.1, %xor.121096.5.clone.1) + %shift-left.108877.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121096.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.115015.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121096.5.clone.1, %broadcast.244417.5760) + %or.114551.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108877.9.clone.1, %shift-right-logical.115015.9.clone.1) + %xor.121097.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246669.3.clone.1, %or.114551.7.clone.1) + %add.246671.3.clone.1 = u32[1280,1280]{1,0} add(%add.246669.3.clone.1, %xor.121097.5.clone.1) + %shift-left.108878.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121097.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.115016.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121097.5.clone.1, %broadcast.244419.4352) + %or.114553.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108878.5.clone.1, %shift-right-logical.115016.5.clone.1) + %xor.121098.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246671.3.clone.1, %or.114553.3.clone.1) + %add.246675.3.clone.1 = u32[1280,1280]{1,0} add(%add.246671.3.clone.1, %xor.121098.3.clone.1) + %add.246676.17.clone.1 = u32[1280,1280]{1,0} add(%add.246675.3.clone.1, %broadcast.248924.24.clone.1) + %shift-left.108879.5.clone.1 = 
u32[1280,1280]{1,0} shift-left(%xor.121098.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.115018.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121098.3.clone.1, %broadcast.244418.4352) + %or.114554.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108879.5.clone.1, %shift-right-logical.115018.5.clone.1) + %xor.121099.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246675.3.clone.1, %or.114554.3.clone.1) + %constant_218055_1_clone_1 = u32[] constant(552660473) + %broadcast.248959.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218055_1_clone_1), dimensions={} + %add.246677.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121099.15.clone.1, %broadcast.248959.19.clone.1) + %xor.121100.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246676.17.clone.1, %add.246677.19.clone.1) + %shift-right-logical.115019.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121100.17.clone.1, %broadcast.244468.1920) + %or.114555.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.115019.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5720.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114555.13.clone.1) + %add.246678.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5720.11.clone.1, %broadcast.244470.1152) + %multiply.26122.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246678.9.clone.1, %broadcast.244471.896) + %add.246680.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26122.7.clone.1, %broadcast.244408.1024) + %maximum.3652.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246680.5.clone.1) + %abs.1518.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3652.3.clone.1) + %compare.7184.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1518.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26123.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3652.3.clone.1, %broadcast.244476.1152) + %negate.4541.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3652.3.clone.1) + %multiply.26124.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3652.3.clone.1, %negate.4541.5.clone.1) + %log-plus-one.1518.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26124.5.clone.1) + %negate.4542.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1518.3.clone.1) + %compare.7185.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4542.4.clone.1, %broadcast.244477.384), direction=LT + %select.20763.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20764.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20765.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20766.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20767.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20768.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20769.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20770.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20771.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.246681.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4542.4.clone.1, 
%broadcast.244496.640) + %sqrt.1518.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4542.4.clone.1) + %add.246682.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1518.5.clone.1, %broadcast.244498.640) + %select.20772.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7185.3.clone.1, %add.246681.5.clone.1, %add.246682.5.clone.1) + %multiply.26125.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20771.3.clone.1, %select.20772.3.clone.1) + %add.246683.1.clone.1 = f32[1280,1280]{1,0} add(%select.20770.3.clone.1, %multiply.26125.1.clone.1) + %multiply.26126.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246683.1.clone.1, %select.20772.3.clone.1) + %add.246685.1.clone.1 = f32[1280,1280]{1,0} add(%select.20769.3.clone.1, %multiply.26126.1.clone.1) + %multiply.26127.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246685.1.clone.1, %select.20772.3.clone.1) + %add.246686.1.clone.1 = f32[1280,1280]{1,0} add(%select.20768.3.clone.1, %multiply.26127.1.clone.1) + %multiply.26128.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246686.1.clone.1, %select.20772.3.clone.1) + %add.246687.1.clone.1 = f32[1280,1280]{1,0} add(%select.20767.3.clone.1, %multiply.26128.1.clone.1) + %multiply.26129.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246687.1.clone.1, %select.20772.3.clone.1) + %add.246688.3.clone.1 = f32[1280,1280]{1,0} add(%select.20766.5.clone.1, %multiply.26129.1.clone.1) + %multiply.26130.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246688.3.clone.1, %select.20772.3.clone.1) + %add.246690.3.clone.1 = f32[1280,1280]{1,0} add(%select.20765.5.clone.1, %multiply.26130.1.clone.1) + %multiply.26131.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246690.3.clone.1, %select.20772.3.clone.1) + %add.246691.9.clone.1 = f32[1280,1280]{1,0} add(%select.20764.11.clone.1, %multiply.26131.7.clone.1) + %multiply.26132.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246691.9.clone.1, %select.20772.3.clone.1) + %add.246692.7.clone.1 = f32[1280,1280]{1,0} add(%select.20763.7.clone.1, %multiply.26132.7.clone.1) + %multiply.26133.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246692.7.clone.1, %maximum.3652.3.clone.1) + %select.20773.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7184.3.clone.1, %multiply.26123.9.clone.1, %multiply.26133.7.clone.1) + %multiply.26134.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20773.7.clone.1, %broadcast.244500.640) + %clamp.1162.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26134.5.clone.1, %broadcast.244501.384) + %multiply.26135.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1162.3.clone.1, %broadcast.244502.1) + %constant_190030_1_clone_1 = u32[] constant(2898079861) + %broadcast.258200.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_190030_1_clone_1), dimensions={} + %add.251947.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.258200.44.clone.1) + %constant_190037_1_clone_1 = u32[] constant(1982886832) + %broadcast.258201.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_190037_1_clone_1), dimensions={} + %add.251948.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.258201.113.clone.1) + %add.251950.35.clone.1 = u32[1280,1280]{1,0} add(%add.251947.37.clone.1, %add.251948.99.clone.1) + %shift-left.111176.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251948.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117438.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251948.99.clone.1, %broadcast.244415.6016) + %or.116976.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111176.31.clone.1, 
%shift-right-logical.117438.29.clone.1) + %xor.123531.27.clone.1 = u32[1280,1280]{1,0} xor(%add.251950.35.clone.1, %or.116976.29.clone.1) + %add.251951.5.clone.1 = u32[1280,1280]{1,0} add(%add.251950.35.clone.1, %xor.123531.27.clone.1) + %shift-left.111177.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123531.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117439.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123531.27.clone.1, %broadcast.244417.5760) + %or.116977.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111177.9.clone.1, %shift-right-logical.117439.9.clone.1) + %xor.123532.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251951.5.clone.1, %or.116977.7.clone.1) + %add.251952.3.clone.1 = u32[1280,1280]{1,0} add(%add.251951.5.clone.1, %xor.123532.5.clone.1) + %shift-left.111178.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123532.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117440.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123532.5.clone.1, %broadcast.244419.4352) + %or.116978.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111178.5.clone.1, %shift-right-logical.117440.5.clone.1) + %xor.123533.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251952.3.clone.1, %or.116978.3.clone.1) + %add.251953.3.clone.1 = u32[1280,1280]{1,0} add(%add.251952.3.clone.1, %xor.123533.3.clone.1) + %add.251954.7.clone.1 = u32[1280,1280]{1,0} add(%add.251953.3.clone.1, %broadcast.258201.113.clone.1) + %shift-left.111179.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123533.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117441.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123533.3.clone.1, %broadcast.244418.4352) + %or.116979.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111179.5.clone.1, %shift-right-logical.117441.5.clone.1) + %xor.123534.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251953.3.clone.1, %or.116979.3.clone.1) + %constant_218640_1_clone_1 = u32[] constant(3244054560) + %broadcast.258211.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218640_1_clone_1), dimensions={} + %add.251955.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123534.3.clone.1, %broadcast.258211.5.clone.1) + %add.251956.5.clone.1 = u32[1280,1280]{1,0} add(%add.251954.7.clone.1, %add.251955.5.clone.1) + %shift-left.111180.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251955.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117442.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251955.5.clone.1, %broadcast.244416.5760) + %or.116981.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111180.9.clone.1, %shift-right-logical.117442.9.clone.1) + %xor.123535.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251956.5.clone.1, %or.116981.7.clone.1) + %add.251957.3.clone.1 = u32[1280,1280]{1,0} add(%add.251956.5.clone.1, %xor.123535.5.clone.1) + %shift-left.111181.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123535.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117443.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123535.5.clone.1, %broadcast.244429.2304) + %or.116982.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111181.9.clone.1, %shift-right-logical.117443.9.clone.1) + %xor.123536.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251957.3.clone.1, %or.116982.7.clone.1) + %add.251958.3.clone.1 = u32[1280,1280]{1,0} add(%add.251957.3.clone.1, %xor.123536.5.clone.1) + %shift-left.111182.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123536.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117444.9.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.123536.5.clone.1, %broadcast.244430.4608) + %or.116983.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111182.9.clone.1, %shift-right-logical.117444.9.clone.1) + %xor.123537.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251958.3.clone.1, %or.116983.7.clone.1) + %add.251959.3.clone.1 = u32[1280,1280]{1,0} add(%add.251958.3.clone.1, %xor.123537.5.clone.1) + %constant_190039_1_clone_1 = u32[] constant(3244054559) + %broadcast.258218.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_190039_1_clone_1), dimensions={} + %add.251960.7.clone.1 = u32[1280,1280]{1,0} add(%add.251959.3.clone.1, %broadcast.258218.24.clone.1) + %shift-left.111183.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123537.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117446.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123537.5.clone.1, %broadcast.244434.2816) + %or.116984.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111183.11.clone.1, %shift-right-logical.117446.11.clone.1) + %xor.123538.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251959.3.clone.1, %or.116984.9.clone.1) + %constant_218641_1_clone_1 = u32[] constant(2898079863) + %broadcast.258221.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218641_1_clone_1), dimensions={} + %add.251961.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123538.7.clone.1, %broadcast.258221.5.clone.1) + %add.251962.5.clone.1 = u32[1280,1280]{1,0} add(%add.251960.7.clone.1, %add.251961.5.clone.1) + %shift-left.111185.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251961.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117447.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251961.5.clone.1, %broadcast.244415.6016) + %or.116986.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111185.9.clone.1, %shift-right-logical.117447.9.clone.1) + %xor.123539.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251962.5.clone.1, %or.116986.7.clone.1) + %add.251963.3.clone.1 = u32[1280,1280]{1,0} add(%add.251962.5.clone.1, %xor.123539.5.clone.1) + %shift-left.111186.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123539.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117448.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123539.5.clone.1, %broadcast.244417.5760) + %or.116987.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111186.9.clone.1, %shift-right-logical.117448.9.clone.1) + %xor.123540.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251963.3.clone.1, %or.116987.7.clone.1) + %add.251964.3.clone.1 = u32[1280,1280]{1,0} add(%add.251963.3.clone.1, %xor.123540.5.clone.1) + %shift-left.111187.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123540.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117449.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123540.5.clone.1, %broadcast.244419.4352) + %or.116988.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111187.7.clone.1, %shift-right-logical.117449.7.clone.1) + %xor.123541.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251964.3.clone.1, %or.116988.5.clone.1) + %add.251965.3.clone.1 = u32[1280,1280]{1,0} add(%add.251964.3.clone.1, %xor.123541.3.clone.1) + %add.251966.7.clone.1 = u32[1280,1280]{1,0} add(%add.251965.3.clone.1, %broadcast.258200.44.clone.1) + %shift-left.111188.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123541.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117451.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123541.3.clone.1, %broadcast.244418.4352) + %or.116989.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111188.7.clone.1, 
%shift-right-logical.117451.7.clone.1) + %xor.123542.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251965.3.clone.1, %or.116989.5.clone.1) + %constant_218642_1_clone_1 = u32[] constant(1982886835) + %broadcast.258231.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218642_1_clone_1), dimensions={} + %add.251967.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123542.3.clone.1, %broadcast.258231.5.clone.1) + %add.251968.5.clone.1 = u32[1280,1280]{1,0} add(%add.251966.7.clone.1, %add.251967.5.clone.1) + %shift-left.111190.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251967.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117452.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251967.5.clone.1, %broadcast.244416.5760) + %or.116990.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111190.9.clone.1, %shift-right-logical.117452.9.clone.1) + %xor.123543.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251968.5.clone.1, %or.116990.7.clone.1) + %add.251969.3.clone.1 = u32[1280,1280]{1,0} add(%add.251968.5.clone.1, %xor.123543.5.clone.1) + %shift-left.111191.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123543.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117453.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123543.5.clone.1, %broadcast.244429.2304) + %or.116991.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111191.9.clone.1, %shift-right-logical.117453.9.clone.1) + %xor.123544.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251969.3.clone.1, %or.116991.7.clone.1) + %add.251970.3.clone.1 = u32[1280,1280]{1,0} add(%add.251969.3.clone.1, %xor.123544.5.clone.1) + %shift-left.111192.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123544.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117454.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123544.5.clone.1, %broadcast.244430.4608) + %or.116992.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111192.9.clone.1, %shift-right-logical.117454.9.clone.1) + %xor.123545.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251970.3.clone.1, %or.116992.7.clone.1) + %add.251971.3.clone.1 = u32[1280,1280]{1,0} add(%add.251970.3.clone.1, %xor.123545.5.clone.1) + %add.251972.7.clone.1 = u32[1280,1280]{1,0} add(%add.251971.3.clone.1, %broadcast.258201.113.clone.1) + %shift-left.111193.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123545.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117456.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123545.5.clone.1, %broadcast.244434.2816) + %or.116993.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111193.11.clone.1, %shift-right-logical.117456.11.clone.1) + %xor.123546.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251971.3.clone.1, %or.116993.9.clone.1) + %constant_218643_1_clone_1 = u32[] constant(3244054563) + %broadcast.258241.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218643_1_clone_1), dimensions={} + %add.251973.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123546.7.clone.1, %broadcast.258241.5.clone.1) + %add.251974.5.clone.1 = u32[1280,1280]{1,0} add(%add.251972.7.clone.1, %add.251973.5.clone.1) + %shift-left.111195.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251973.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117457.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251973.5.clone.1, %broadcast.244415.6016) + %or.116994.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111195.9.clone.1, %shift-right-logical.117457.9.clone.1) + %xor.123547.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251974.5.clone.1, %or.116994.7.clone.1) + %add.251975.3.clone.1 = 
u32[1280,1280]{1,0} add(%add.251974.5.clone.1, %xor.123547.5.clone.1) + %shift-left.111196.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123547.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117458.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123547.5.clone.1, %broadcast.244417.5760) + %or.116995.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111196.9.clone.1, %shift-right-logical.117458.9.clone.1) + %xor.123548.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251975.3.clone.1, %or.116995.7.clone.1) + %add.251976.3.clone.1 = u32[1280,1280]{1,0} add(%add.251975.3.clone.1, %xor.123548.5.clone.1) + %shift-left.111197.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123548.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117459.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123548.5.clone.1, %broadcast.244419.4352) + %or.116996.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111197.5.clone.1, %shift-right-logical.117459.5.clone.1) + %xor.123549.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251976.3.clone.1, %or.116996.3.clone.1) + %add.251977.3.clone.1 = u32[1280,1280]{1,0} add(%add.251976.3.clone.1, %xor.123549.3.clone.1) + %add.251978.17.clone.1 = u32[1280,1280]{1,0} add(%add.251977.3.clone.1, %broadcast.258218.24.clone.1) + %shift-left.111198.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123549.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117461.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123549.3.clone.1, %broadcast.244418.4352) + %or.116997.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111198.5.clone.1, %shift-right-logical.117461.5.clone.1) + %xor.123550.15.clone.1 = u32[1280,1280]{1,0} xor(%add.251977.3.clone.1, %or.116997.3.clone.1) + %constant_218644_1_clone_1 = u32[] constant(2898079866) + %broadcast.258251.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218644_1_clone_1), dimensions={} + %add.251979.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123550.15.clone.1, %broadcast.258251.19.clone.1) + %xor.123551.17.clone.1 = u32[1280,1280]{1,0} xor(%add.251978.17.clone.1, %add.251979.19.clone.1) + %shift-right-logical.117462.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123551.17.clone.1, %broadcast.244468.1920) + %or.116998.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117462.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5826.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116998.13.clone.1) + %add.251980.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5826.11.clone.1, %broadcast.244470.1152) + %multiply.27204.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251980.9.clone.1, %broadcast.244471.896) + %add.251982.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27204.7.clone.1, %broadcast.244408.1024) + %maximum.3758.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.251982.5.clone.1) + %abs.1588.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3758.3.clone.1) + %compare.7338.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1588.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27205.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3758.3.clone.1, %broadcast.244476.1152) + %negate.4681.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3758.3.clone.1) + %multiply.27206.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3758.3.clone.1, %negate.4681.5.clone.1) + %log-plus-one.1588.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27206.5.clone.1) + %negate.4682.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1588.3.clone.1) + %compare.7339.3.clone.1 = 
pred[1280,1280]{1,0} compare(%negate.4682.4.clone.1, %broadcast.244477.384), direction=LT + %select.21570.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21571.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21572.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21573.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21574.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21575.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21576.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21577.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21578.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.251985.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4682.4.clone.1, %broadcast.244496.640) + %sqrt.1588.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4682.4.clone.1) + %add.251986.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1588.5.clone.1, %broadcast.244498.640) + %select.21579.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7339.3.clone.1, %add.251985.5.clone.1, %add.251986.5.clone.1) + %multiply.27207.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21578.3.clone.1, %select.21579.3.clone.1) + %add.251987.1.clone.1 = f32[1280,1280]{1,0} add(%select.21577.3.clone.1, %multiply.27207.1.clone.1) + %multiply.27208.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251987.1.clone.1, %select.21579.3.clone.1) + %add.251988.1.clone.1 = f32[1280,1280]{1,0} add(%select.21576.3.clone.1, %multiply.27208.1.clone.1) + %multiply.27209.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251988.1.clone.1, %select.21579.3.clone.1) + %add.251990.1.clone.1 = f32[1280,1280]{1,0} add(%select.21575.3.clone.1, %multiply.27209.1.clone.1) + %multiply.27210.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251990.1.clone.1, %select.21579.3.clone.1) + %add.251991.1.clone.1 = f32[1280,1280]{1,0} add(%select.21574.3.clone.1, %multiply.27210.1.clone.1) + %multiply.27211.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251991.1.clone.1, %select.21579.3.clone.1) + %add.251992.3.clone.1 = f32[1280,1280]{1,0} add(%select.21573.5.clone.1, %multiply.27211.1.clone.1) + %multiply.27212.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251992.3.clone.1, %select.21579.3.clone.1) + %add.251993.3.clone.1 = f32[1280,1280]{1,0} add(%select.21572.5.clone.1, %multiply.27212.1.clone.1) + %multiply.27213.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251993.3.clone.1, %select.21579.3.clone.1) + %add.251995.9.clone.1 = f32[1280,1280]{1,0} add(%select.21571.11.clone.1, %multiply.27213.7.clone.1) + %multiply.27214.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251995.9.clone.1, %select.21579.3.clone.1) + %add.251996.7.clone.1 = f32[1280,1280]{1,0} add(%select.21570.7.clone.1, %multiply.27214.7.clone.1) + %multiply.27215.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251996.7.clone.1, %maximum.3758.3.clone.1) + %select.21580.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7338.3.clone.1, %multiply.27205.9.clone.1, %multiply.27215.7.clone.1) + 
%multiply.27216.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21580.7.clone.1, %broadcast.244500.640) + %clamp.1232.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27216.5.clone.1, %broadcast.244501.384) + %multiply.27217.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1232.3.clone.1, %broadcast.244502.1) + %constant_168292_1_clone_1 = u32[] constant(3723649444) + %broadcast.248799.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_168292_1_clone_1), dimensions={} + %add.246590.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.248799.44.clone.1) + %constant_168299_1_clone_1 = u32[] constant(3203058576) + %broadcast.248800.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_168299_1_clone_1), dimensions={} + %add.246591.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.248800.113.clone.1) + %add.246592.35.clone.1 = u32[1280,1280]{1,0} add(%add.246590.37.clone.1, %add.246591.99.clone.1) + %shift-left.108840.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246591.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.114972.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246591.99.clone.1, %broadcast.244415.6016) + %or.114506.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108840.31.clone.1, %shift-right-logical.114972.29.clone.1) + %xor.121055.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246592.35.clone.1, %or.114506.29.clone.1) + %add.246593.5.clone.1 = u32[1280,1280]{1,0} add(%add.246592.35.clone.1, %xor.121055.27.clone.1) + %shift-left.108841.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121055.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.114973.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121055.27.clone.1, %broadcast.244417.5760) + %or.114507.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108841.9.clone.1, %shift-right-logical.114973.9.clone.1) + %xor.121056.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246593.5.clone.1, %or.114507.7.clone.1) + %add.246594.3.clone.1 = u32[1280,1280]{1,0} add(%add.246593.5.clone.1, %xor.121056.5.clone.1) + %shift-left.108842.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121056.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114974.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121056.5.clone.1, %broadcast.244419.4352) + %or.114508.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108842.5.clone.1, %shift-right-logical.114974.5.clone.1) + %xor.121058.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246594.3.clone.1, %or.114508.3.clone.1) + %add.246596.3.clone.1 = u32[1280,1280]{1,0} add(%add.246594.3.clone.1, %xor.121058.3.clone.1) + %add.246597.7.clone.1 = u32[1280,1280]{1,0} add(%add.246596.3.clone.1, %broadcast.248800.113.clone.1) + %shift-left.108843.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121058.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114975.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121058.3.clone.1, %broadcast.244418.4352) + %or.114509.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108843.5.clone.1, %shift-right-logical.114975.5.clone.1) + %xor.121059.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246596.3.clone.1, %or.114509.3.clone.1) + %constant_218046_1_clone_1 = u32[] constant(2026484207) + %broadcast.248810.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218046_1_clone_1), dimensions={} + %add.246598.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121059.3.clone.1, %broadcast.248810.5.clone.1) + %add.246599.5.clone.1 = u32[1280,1280]{1,0} add(%add.246597.7.clone.1, %add.246598.5.clone.1) + 
%shift-left.108844.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246598.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114976.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246598.5.clone.1, %broadcast.244416.5760) + %or.114510.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108844.9.clone.1, %shift-right-logical.114976.9.clone.1) + %xor.121060.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246599.5.clone.1, %or.114510.7.clone.1) + %add.246600.3.clone.1 = u32[1280,1280]{1,0} add(%add.246599.5.clone.1, %xor.121060.5.clone.1) + %shift-left.108845.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121060.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114977.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121060.5.clone.1, %broadcast.244429.2304) + %or.114511.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108845.9.clone.1, %shift-right-logical.114977.9.clone.1) + %xor.121061.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246600.3.clone.1, %or.114511.7.clone.1) + %add.246601.3.clone.1 = u32[1280,1280]{1,0} add(%add.246600.3.clone.1, %xor.121061.5.clone.1) + %shift-left.108846.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121061.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114978.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121061.5.clone.1, %broadcast.244430.4608) + %or.114513.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108846.9.clone.1, %shift-right-logical.114978.9.clone.1) + %xor.121063.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246601.3.clone.1, %or.114513.7.clone.1) + %add.246602.3.clone.1 = u32[1280,1280]{1,0} add(%add.246601.3.clone.1, %xor.121063.5.clone.1) + %constant_168301_1_clone_1 = u32[] constant(2026484206) + %broadcast.248817.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_168301_1_clone_1), dimensions={} + %add.246603.7.clone.1 = u32[1280,1280]{1,0} add(%add.246602.3.clone.1, %broadcast.248817.24.clone.1) + %shift-left.108847.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121063.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114979.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121063.5.clone.1, %broadcast.244434.2816) + %or.114514.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108847.11.clone.1, %shift-right-logical.114979.11.clone.1) + %xor.121064.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246602.3.clone.1, %or.114514.9.clone.1) + %constant_218047_1_clone_1 = u32[] constant(3723649446) + %broadcast.248820.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218047_1_clone_1), dimensions={} + %add.246604.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121064.7.clone.1, %broadcast.248820.5.clone.1) + %add.246605.5.clone.1 = u32[1280,1280]{1,0} add(%add.246603.7.clone.1, %add.246604.5.clone.1) + %shift-left.108848.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246604.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114980.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246604.5.clone.1, %broadcast.244415.6016) + %or.114515.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108848.9.clone.1, %shift-right-logical.114980.9.clone.1) + %xor.121065.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246605.5.clone.1, %or.114515.7.clone.1) + %add.246606.3.clone.1 = u32[1280,1280]{1,0} add(%add.246605.5.clone.1, %xor.121065.5.clone.1) + %shift-left.108849.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121065.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114981.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121065.5.clone.1, %broadcast.244417.5760) + %or.114516.7.clone.1 = 
u32[1280,1280]{1,0} or(%shift-left.108849.9.clone.1, %shift-right-logical.114981.9.clone.1) + %xor.121066.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246606.3.clone.1, %or.114516.7.clone.1) + %add.246607.3.clone.1 = u32[1280,1280]{1,0} add(%add.246606.3.clone.1, %xor.121066.5.clone.1) + %shift-left.108850.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121066.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114982.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121066.5.clone.1, %broadcast.244419.4352) + %or.114518.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108850.7.clone.1, %shift-right-logical.114982.7.clone.1) + %xor.121068.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246607.3.clone.1, %or.114518.5.clone.1) + %add.246608.3.clone.1 = u32[1280,1280]{1,0} add(%add.246607.3.clone.1, %xor.121068.3.clone.1) + %add.246609.7.clone.1 = u32[1280,1280]{1,0} add(%add.246608.3.clone.1, %broadcast.248799.44.clone.1) + %shift-left.108851.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121068.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114983.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121068.3.clone.1, %broadcast.244418.4352) + %or.114519.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108851.7.clone.1, %shift-right-logical.114983.7.clone.1) + %xor.121069.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246608.3.clone.1, %or.114519.5.clone.1) + %constant_218048_1_clone_1 = u32[] constant(3203058579) + %broadcast.248830.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218048_1_clone_1), dimensions={} + %add.246610.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121069.3.clone.1, %broadcast.248830.5.clone.1) + %add.246611.5.clone.1 = u32[1280,1280]{1,0} add(%add.246609.7.clone.1, %add.246610.5.clone.1) + %shift-left.108852.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246610.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114984.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246610.5.clone.1, %broadcast.244416.5760) + %or.114520.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108852.9.clone.1, %shift-right-logical.114984.9.clone.1) + %xor.121070.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246611.5.clone.1, %or.114520.7.clone.1) + %add.246612.3.clone.1 = u32[1280,1280]{1,0} add(%add.246611.5.clone.1, %xor.121070.5.clone.1) + %shift-left.108853.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121070.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114985.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121070.5.clone.1, %broadcast.244429.2304) + %or.114521.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108853.9.clone.1, %shift-right-logical.114985.9.clone.1) + %xor.121071.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246612.3.clone.1, %or.114521.7.clone.1) + %add.246613.3.clone.1 = u32[1280,1280]{1,0} add(%add.246612.3.clone.1, %xor.121071.5.clone.1) + %shift-left.108854.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121071.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114986.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121071.5.clone.1, %broadcast.244430.4608) + %or.114523.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108854.9.clone.1, %shift-right-logical.114986.9.clone.1) + %xor.121073.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246613.3.clone.1, %or.114523.7.clone.1) + %add.246614.3.clone.1 = u32[1280,1280]{1,0} add(%add.246613.3.clone.1, %xor.121073.5.clone.1) + %add.246616.7.clone.1 = u32[1280,1280]{1,0} add(%add.246614.3.clone.1, %broadcast.248800.113.clone.1) + %shift-left.108855.11.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.121073.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114988.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121073.5.clone.1, %broadcast.244434.2816) + %or.114524.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108855.11.clone.1, %shift-right-logical.114988.11.clone.1) + %xor.121074.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246614.3.clone.1, %or.114524.9.clone.1) + %constant_218049_1_clone_1 = u32[] constant(2026484210) + %broadcast.248842.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218049_1_clone_1), dimensions={} + %add.246617.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121074.7.clone.1, %broadcast.248842.5.clone.1) + %add.246618.5.clone.1 = u32[1280,1280]{1,0} add(%add.246616.7.clone.1, %add.246617.5.clone.1) + %shift-left.108856.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246617.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114989.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246617.5.clone.1, %broadcast.244415.6016) + %or.114525.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108856.9.clone.1, %shift-right-logical.114989.9.clone.1) + %xor.121075.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246618.5.clone.1, %or.114525.7.clone.1) + %add.246619.3.clone.1 = u32[1280,1280]{1,0} add(%add.246618.5.clone.1, %xor.121075.5.clone.1) + %shift-left.108857.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121075.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114990.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121075.5.clone.1, %broadcast.244417.5760) + %or.114526.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108857.9.clone.1, %shift-right-logical.114990.9.clone.1) + %xor.121076.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246619.3.clone.1, %or.114526.7.clone.1) + %add.246620.3.clone.1 = u32[1280,1280]{1,0} add(%add.246619.3.clone.1, %xor.121076.5.clone.1) + %shift-left.108858.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121076.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114991.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121076.5.clone.1, %broadcast.244419.4352) + %or.114528.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108858.5.clone.1, %shift-right-logical.114991.5.clone.1) + %xor.121077.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246620.3.clone.1, %or.114528.3.clone.1) + %add.246621.3.clone.1 = u32[1280,1280]{1,0} add(%add.246620.3.clone.1, %xor.121077.3.clone.1) + %add.246622.17.clone.1 = u32[1280,1280]{1,0} add(%add.246621.3.clone.1, %broadcast.248817.24.clone.1) + %shift-left.108859.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121077.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114993.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121077.3.clone.1, %broadcast.244418.4352) + %or.114529.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108859.5.clone.1, %shift-right-logical.114993.5.clone.1) + %xor.121078.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246621.3.clone.1, %or.114529.3.clone.1) + %constant_218050_1_clone_1 = u32[] constant(3723649449) + %broadcast.248862.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218050_1_clone_1), dimensions={} + %add.246623.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121078.15.clone.1, %broadcast.248862.19.clone.1) + %xor.121079.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246622.17.clone.1, %add.246623.19.clone.1) + %shift-right-logical.114994.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121079.17.clone.1, %broadcast.244468.1920) + %or.114530.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114994.15.clone.1, 
%broadcast.244469.1664) + %bitcast-convert.5719.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114530.13.clone.1) + %add.246624.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5719.11.clone.1, %broadcast.244470.1152) + %multiply.26108.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246624.9.clone.1, %broadcast.244471.896) + %add.246625.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26108.7.clone.1, %broadcast.244408.1024) + %maximum.3651.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246625.5.clone.1) + %abs.1517.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3651.3.clone.1) + %compare.7182.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1517.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26109.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3651.3.clone.1, %broadcast.244476.1152) + %negate.4539.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3651.3.clone.1) + %multiply.26110.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3651.3.clone.1, %negate.4539.5.clone.1) + %log-plus-one.1517.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26110.5.clone.1) + %negate.4540.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1517.3.clone.1) + %compare.7183.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4540.4.clone.1, %broadcast.244477.384), direction=LT + %select.20752.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20753.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20754.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20755.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20756.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20757.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20758.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20759.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20760.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.246626.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4540.4.clone.1, %broadcast.244496.640) + %sqrt.1517.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4540.4.clone.1) + %add.246627.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1517.5.clone.1, %broadcast.244498.640) + %select.20761.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7183.3.clone.1, %add.246626.5.clone.1, %add.246627.5.clone.1) + %multiply.26111.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20760.3.clone.1, %select.20761.3.clone.1) + %add.246628.1.clone.1 = f32[1280,1280]{1,0} add(%select.20759.3.clone.1, %multiply.26111.1.clone.1) + %multiply.26112.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246628.1.clone.1, %select.20761.3.clone.1) + %add.246629.1.clone.1 = f32[1280,1280]{1,0} add(%select.20758.3.clone.1, %multiply.26112.1.clone.1) + %multiply.26113.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246629.1.clone.1, %select.20761.3.clone.1) + %add.246630.1.clone.1 = f32[1280,1280]{1,0} add(%select.20757.3.clone.1, %multiply.26113.1.clone.1) + %multiply.26114.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246630.1.clone.1, 
%select.20761.3.clone.1) + %add.246631.1.clone.1 = f32[1280,1280]{1,0} add(%select.20756.3.clone.1, %multiply.26114.1.clone.1) + %multiply.26115.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246631.1.clone.1, %select.20761.3.clone.1) + %add.246632.3.clone.1 = f32[1280,1280]{1,0} add(%select.20755.5.clone.1, %multiply.26115.1.clone.1) + %multiply.26116.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246632.3.clone.1, %select.20761.3.clone.1) + %add.246633.3.clone.1 = f32[1280,1280]{1,0} add(%select.20754.5.clone.1, %multiply.26116.1.clone.1) + %multiply.26117.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246633.3.clone.1, %select.20761.3.clone.1) + %add.246634.9.clone.1 = f32[1280,1280]{1,0} add(%select.20753.11.clone.1, %multiply.26117.7.clone.1) + %multiply.26118.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246634.9.clone.1, %select.20761.3.clone.1) + %add.246635.7.clone.1 = f32[1280,1280]{1,0} add(%select.20752.7.clone.1, %multiply.26118.7.clone.1) + %multiply.26119.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246635.7.clone.1, %maximum.3651.3.clone.1) + %select.20762.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7182.3.clone.1, %multiply.26109.9.clone.1, %multiply.26119.7.clone.1) + %multiply.26120.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20762.7.clone.1, %broadcast.244500.640) + %clamp.1161.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26120.5.clone.1, %broadcast.244501.384) + %multiply.26121.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1161.3.clone.1, %broadcast.244502.1) + %constant_182688_1_clone_1 = u32[] constant(4003218824) + %broadcast.255039.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_182688_1_clone_1), dimensions={} + %add.250139.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.255039.44.clone.1) + %constant_182695_1_clone_1 = u32[] constant(1756142685) + %broadcast.255040.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_182695_1_clone_1), dimensions={} + %add.250140.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.255040.113.clone.1) + %add.250141.35.clone.1 = u32[1280,1280]{1,0} add(%add.250139.37.clone.1, %add.250140.99.clone.1) + %shift-left.110372.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250140.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116613.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250140.99.clone.1, %broadcast.244415.6016) + %or.116127.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110372.31.clone.1, %shift-right-logical.116613.29.clone.1) + %xor.122699.27.clone.1 = u32[1280,1280]{1,0} xor(%add.250141.35.clone.1, %or.116127.29.clone.1) + %add.250142.5.clone.1 = u32[1280,1280]{1,0} add(%add.250141.35.clone.1, %xor.122699.27.clone.1) + %shift-left.110373.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122699.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116614.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122699.27.clone.1, %broadcast.244417.5760) + %or.116128.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110373.9.clone.1, %shift-right-logical.116614.9.clone.1) + %xor.122700.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250142.5.clone.1, %or.116128.7.clone.1) + %add.250144.3.clone.1 = u32[1280,1280]{1,0} add(%add.250142.5.clone.1, %xor.122700.5.clone.1) + %shift-left.110374.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122700.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116615.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122700.5.clone.1, %broadcast.244419.4352) + %or.116129.3.clone.1 = 
u32[1280,1280]{1,0} or(%shift-left.110374.5.clone.1, %shift-right-logical.116615.5.clone.1) + %xor.122701.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250144.3.clone.1, %or.116129.3.clone.1) + %add.250145.3.clone.1 = u32[1280,1280]{1,0} add(%add.250144.3.clone.1, %xor.122701.3.clone.1) + %add.250146.7.clone.1 = u32[1280,1280]{1,0} add(%add.250145.3.clone.1, %broadcast.255040.113.clone.1) + %shift-left.110375.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122701.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116616.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122701.3.clone.1, %broadcast.244418.4352) + %or.116131.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110375.5.clone.1, %shift-right-logical.116616.5.clone.1) + %xor.122702.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250145.3.clone.1, %or.116131.3.clone.1) + %constant_218439_1_clone_1 = u32[] constant(2648818192) + %broadcast.255059.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218439_1_clone_1), dimensions={} + %add.250147.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122702.3.clone.1, %broadcast.255059.5.clone.1) + %add.250149.5.clone.1 = u32[1280,1280]{1,0} add(%add.250146.7.clone.1, %add.250147.5.clone.1) + %shift-left.110376.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250147.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116617.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250147.5.clone.1, %broadcast.244416.5760) + %or.116132.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110376.9.clone.1, %shift-right-logical.116617.9.clone.1) + %xor.122703.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250149.5.clone.1, %or.116132.7.clone.1) + %add.250150.3.clone.1 = u32[1280,1280]{1,0} add(%add.250149.5.clone.1, %xor.122703.5.clone.1) + %shift-left.110377.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122703.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116618.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122703.5.clone.1, %broadcast.244429.2304) + %or.116133.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110377.9.clone.1, %shift-right-logical.116618.9.clone.1) + %xor.122704.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250150.3.clone.1, %or.116133.7.clone.1) + %add.250151.3.clone.1 = u32[1280,1280]{1,0} add(%add.250150.3.clone.1, %xor.122704.5.clone.1) + %shift-left.110378.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122704.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116619.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122704.5.clone.1, %broadcast.244430.4608) + %or.116134.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110378.9.clone.1, %shift-right-logical.116619.9.clone.1) + %xor.122705.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250151.3.clone.1, %or.116134.7.clone.1) + %add.250152.3.clone.1 = u32[1280,1280]{1,0} add(%add.250151.3.clone.1, %xor.122705.5.clone.1) + %constant_182697_1_clone_1 = u32[] constant(2648818191) + %broadcast.255073.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_182697_1_clone_1), dimensions={} + %add.250153.7.clone.1 = u32[1280,1280]{1,0} add(%add.250152.3.clone.1, %broadcast.255073.24.clone.1) + %shift-left.110380.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122705.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116620.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122705.5.clone.1, %broadcast.244434.2816) + %or.116135.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110380.11.clone.1, %shift-right-logical.116620.11.clone.1) + %xor.122706.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250152.3.clone.1, 
%or.116135.9.clone.1) + %constant_218440_1_clone_1 = u32[] constant(4003218826) + %broadcast.255079.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218440_1_clone_1), dimensions={} + %add.250155.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122706.7.clone.1, %broadcast.255079.5.clone.1) + %add.250159.5.clone.1 = u32[1280,1280]{1,0} add(%add.250153.7.clone.1, %add.250155.5.clone.1) + %shift-left.110381.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250155.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116621.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250155.5.clone.1, %broadcast.244415.6016) + %or.116136.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110381.9.clone.1, %shift-right-logical.116621.9.clone.1) + %xor.122707.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250159.5.clone.1, %or.116136.7.clone.1) + %add.250160.3.clone.1 = u32[1280,1280]{1,0} add(%add.250159.5.clone.1, %xor.122707.5.clone.1) + %shift-left.110382.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122707.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116622.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122707.5.clone.1, %broadcast.244417.5760) + %or.116137.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110382.9.clone.1, %shift-right-logical.116622.9.clone.1) + %xor.122709.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250160.3.clone.1, %or.116137.7.clone.1) + %add.250161.3.clone.1 = u32[1280,1280]{1,0} add(%add.250160.3.clone.1, %xor.122709.5.clone.1) + %shift-left.110383.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122709.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116623.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122709.5.clone.1, %broadcast.244419.4352) + %or.116138.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110383.7.clone.1, %shift-right-logical.116623.7.clone.1) + %xor.122710.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250161.3.clone.1, %or.116138.5.clone.1) + %add.250162.3.clone.1 = u32[1280,1280]{1,0} add(%add.250161.3.clone.1, %xor.122710.3.clone.1) + %add.250164.7.clone.1 = u32[1280,1280]{1,0} add(%add.250162.3.clone.1, %broadcast.255039.44.clone.1) + %shift-left.110385.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122710.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116624.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122710.3.clone.1, %broadcast.244418.4352) + %or.116139.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110385.7.clone.1, %shift-right-logical.116624.7.clone.1) + %xor.122711.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250162.3.clone.1, %or.116139.5.clone.1) + %constant_218441_1_clone_1 = u32[] constant(1756142688) + %broadcast.255093.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218441_1_clone_1), dimensions={} + %add.250165.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122711.3.clone.1, %broadcast.255093.5.clone.1) + %add.250166.5.clone.1 = u32[1280,1280]{1,0} add(%add.250164.7.clone.1, %add.250165.5.clone.1) + %shift-left.110386.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250165.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116625.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250165.5.clone.1, %broadcast.244416.5760) + %or.116141.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110386.9.clone.1, %shift-right-logical.116625.9.clone.1) + %xor.122712.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250166.5.clone.1, %or.116141.7.clone.1) + %add.250167.3.clone.1 = u32[1280,1280]{1,0} add(%add.250166.5.clone.1, %xor.122712.5.clone.1) + %shift-left.110387.9.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.122712.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116626.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122712.5.clone.1, %broadcast.244429.2304) + %or.116142.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110387.9.clone.1, %shift-right-logical.116626.9.clone.1) + %xor.122714.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250167.3.clone.1, %or.116142.7.clone.1) + %add.250169.3.clone.1 = u32[1280,1280]{1,0} add(%add.250167.3.clone.1, %xor.122714.5.clone.1) + %shift-left.110388.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122714.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116627.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122714.5.clone.1, %broadcast.244430.4608) + %or.116143.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110388.9.clone.1, %shift-right-logical.116627.9.clone.1) + %xor.122715.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250169.3.clone.1, %or.116143.7.clone.1) + %add.250170.3.clone.1 = u32[1280,1280]{1,0} add(%add.250169.3.clone.1, %xor.122715.5.clone.1) + %add.250171.7.clone.1 = u32[1280,1280]{1,0} add(%add.250170.3.clone.1, %broadcast.255040.113.clone.1) + %shift-left.110390.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122715.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116628.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122715.5.clone.1, %broadcast.244434.2816) + %or.116144.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110390.11.clone.1, %shift-right-logical.116628.11.clone.1) + %xor.122716.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250170.3.clone.1, %or.116144.9.clone.1) + %constant_218442_1_clone_1 = u32[] constant(2648818195) + %broadcast.255103.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218442_1_clone_1), dimensions={} + %add.250172.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122716.7.clone.1, %broadcast.255103.5.clone.1) + %add.250174.5.clone.1 = u32[1280,1280]{1,0} add(%add.250171.7.clone.1, %add.250172.5.clone.1) + %shift-left.110391.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250172.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116629.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250172.5.clone.1, %broadcast.244415.6016) + %or.116146.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110391.9.clone.1, %shift-right-logical.116629.9.clone.1) + %xor.122717.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250174.5.clone.1, %or.116146.7.clone.1) + %add.250175.3.clone.1 = u32[1280,1280]{1,0} add(%add.250174.5.clone.1, %xor.122717.5.clone.1) + %shift-left.110392.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122717.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116630.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122717.5.clone.1, %broadcast.244417.5760) + %or.116147.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110392.9.clone.1, %shift-right-logical.116630.9.clone.1) + %xor.122719.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250175.3.clone.1, %or.116147.7.clone.1) + %add.250176.3.clone.1 = u32[1280,1280]{1,0} add(%add.250175.3.clone.1, %xor.122719.5.clone.1) + %shift-left.110393.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122719.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116631.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122719.5.clone.1, %broadcast.244419.4352) + %or.116148.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110393.5.clone.1, %shift-right-logical.116631.5.clone.1) + %xor.122720.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250176.3.clone.1, %or.116148.3.clone.1) + %add.250177.3.clone.1 = 
u32[1280,1280]{1,0} add(%add.250176.3.clone.1, %xor.122720.3.clone.1) + %add.250178.17.clone.1 = u32[1280,1280]{1,0} add(%add.250177.3.clone.1, %broadcast.255073.24.clone.1) + %shift-left.110395.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122720.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116632.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122720.3.clone.1, %broadcast.244418.4352) + %or.116149.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110395.5.clone.1, %shift-right-logical.116632.5.clone.1) + %xor.122721.15.clone.1 = u32[1280,1280]{1,0} xor(%add.250177.3.clone.1, %or.116149.3.clone.1) + %constant_218443_1_clone_1 = u32[] constant(4003218829) + %broadcast.255113.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218443_1_clone_1), dimensions={} + %add.250180.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122721.15.clone.1, %broadcast.255113.19.clone.1) + %xor.122722.17.clone.1 = u32[1280,1280]{1,0} xor(%add.250178.17.clone.1, %add.250180.19.clone.1) + %shift-right-logical.116633.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122722.17.clone.1, %broadcast.244468.1920) + %or.116151.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116633.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5790.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116151.13.clone.1) + %add.250184.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5790.11.clone.1, %broadcast.244470.1152) + %multiply.26829.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250184.9.clone.1, %broadcast.244471.896) + %add.250185.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26829.7.clone.1, %broadcast.244408.1024) + %maximum.3722.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.250185.5.clone.1) + %abs.1564.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3722.3.clone.1) + %compare.7277.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1564.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26830.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3722.3.clone.1, %broadcast.244476.1152) + %negate.4633.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3722.3.clone.1) + %multiply.26831.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3722.3.clone.1, %negate.4633.5.clone.1) + %log-plus-one.1564.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26831.5.clone.1) + %negate.4634.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1564.3.clone.1) + %compare.7279.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4634.4.clone.1, %broadcast.244477.384), direction=LT + %select.21290.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21291.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21292.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21293.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21294.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21295.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21296.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21297.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + 
%select.21298.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.250186.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4634.4.clone.1, %broadcast.244496.640) + %sqrt.1564.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4634.4.clone.1) + %add.250187.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1564.5.clone.1, %broadcast.244498.640) + %select.21299.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7279.3.clone.1, %add.250186.5.clone.1, %add.250187.5.clone.1) + %multiply.26832.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21298.3.clone.1, %select.21299.3.clone.1) + %add.250189.1.clone.1 = f32[1280,1280]{1,0} add(%select.21297.3.clone.1, %multiply.26832.1.clone.1) + %multiply.26833.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250189.1.clone.1, %select.21299.3.clone.1) + %add.250190.1.clone.1 = f32[1280,1280]{1,0} add(%select.21296.3.clone.1, %multiply.26833.1.clone.1) + %multiply.26834.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250190.1.clone.1, %select.21299.3.clone.1) + %add.250191.1.clone.1 = f32[1280,1280]{1,0} add(%select.21295.3.clone.1, %multiply.26834.1.clone.1) + %multiply.26835.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250191.1.clone.1, %select.21299.3.clone.1) + %add.250192.1.clone.1 = f32[1280,1280]{1,0} add(%select.21294.3.clone.1, %multiply.26835.1.clone.1) + %multiply.26836.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250192.1.clone.1, %select.21299.3.clone.1) + %add.250194.3.clone.1 = f32[1280,1280]{1,0} add(%select.21293.5.clone.1, %multiply.26836.1.clone.1) + %multiply.26837.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250194.3.clone.1, %select.21299.3.clone.1) + %add.250195.3.clone.1 = f32[1280,1280]{1,0} add(%select.21292.5.clone.1, %multiply.26837.1.clone.1) + %multiply.26838.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250195.3.clone.1, %select.21299.3.clone.1) + %add.250196.9.clone.1 = f32[1280,1280]{1,0} add(%select.21291.11.clone.1, %multiply.26838.7.clone.1) + %multiply.26839.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250196.9.clone.1, %select.21299.3.clone.1) + %add.250197.7.clone.1 = f32[1280,1280]{1,0} add(%select.21290.7.clone.1, %multiply.26839.7.clone.1) + %multiply.26840.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250197.7.clone.1, %maximum.3722.3.clone.1) + %select.21300.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7277.3.clone.1, %multiply.26830.9.clone.1, %multiply.26840.7.clone.1) + %multiply.26841.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21300.7.clone.1, %broadcast.244500.640) + %clamp.1208.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26841.5.clone.1, %broadcast.244501.384) + %multiply.26842.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1208.3.clone.1, %broadcast.244502.1) + %constant_168081_1_clone_1 = u32[] constant(623170565) + %broadcast.248713.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_168081_1_clone_1), dimensions={} + %add.246536.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.248713.44.clone.1) + %constant_168088_1_clone_1 = u32[] constant(1375940360) + %broadcast.248714.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_168088_1_clone_1), dimensions={} + %add.246537.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.248714.113.clone.1) + %add.246538.35.clone.1 = u32[1280,1280]{1,0} add(%add.246536.37.clone.1, %add.246537.99.clone.1) + %shift-left.108820.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246537.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.114951.29.clone.1 = 
u32[1280,1280]{1,0} shift-right-logical(%add.246537.99.clone.1, %broadcast.244415.6016) + %or.114481.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108820.31.clone.1, %shift-right-logical.114951.29.clone.1) + %xor.121030.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246538.35.clone.1, %or.114481.29.clone.1) + %add.246540.5.clone.1 = u32[1280,1280]{1,0} add(%add.246538.35.clone.1, %xor.121030.27.clone.1) + %shift-left.108821.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121030.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.114952.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121030.27.clone.1, %broadcast.244417.5760) + %or.114482.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108821.9.clone.1, %shift-right-logical.114952.9.clone.1) + %xor.121031.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246540.5.clone.1, %or.114482.7.clone.1) + %add.246544.3.clone.1 = u32[1280,1280]{1,0} add(%add.246540.5.clone.1, %xor.121031.5.clone.1) + %shift-left.108822.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121031.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114953.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121031.5.clone.1, %broadcast.244419.4352) + %or.114483.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108822.5.clone.1, %shift-right-logical.114953.5.clone.1) + %xor.121033.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246544.3.clone.1, %or.114483.3.clone.1) + %add.246545.3.clone.1 = u32[1280,1280]{1,0} add(%add.246544.3.clone.1, %xor.121033.3.clone.1) + %add.246546.7.clone.1 = u32[1280,1280]{1,0} add(%add.246545.3.clone.1, %broadcast.248714.113.clone.1) + %shift-left.108823.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121033.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114954.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121033.3.clone.1, %broadcast.244418.4352) + %or.114484.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108823.5.clone.1, %shift-right-logical.114954.5.clone.1) + %xor.121034.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246545.3.clone.1, %or.114484.3.clone.1) + %constant_218041_1_clone_1 = u32[] constant(1828118744) + %broadcast.248724.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218041_1_clone_1), dimensions={} + %add.246547.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121034.3.clone.1, %broadcast.248724.5.clone.1) + %add.246549.5.clone.1 = u32[1280,1280]{1,0} add(%add.246546.7.clone.1, %add.246547.5.clone.1) + %shift-left.108824.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246547.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114955.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246547.5.clone.1, %broadcast.244416.5760) + %or.114485.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108824.9.clone.1, %shift-right-logical.114955.9.clone.1) + %xor.121035.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246549.5.clone.1, %or.114485.7.clone.1) + %add.246550.3.clone.1 = u32[1280,1280]{1,0} add(%add.246549.5.clone.1, %xor.121035.5.clone.1) + %shift-left.108825.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121035.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114956.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121035.5.clone.1, %broadcast.244429.2304) + %or.114486.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108825.9.clone.1, %shift-right-logical.114956.9.clone.1) + %xor.121036.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246550.3.clone.1, %or.114486.7.clone.1) + %add.246551.3.clone.1 = u32[1280,1280]{1,0} add(%add.246550.3.clone.1, %xor.121036.5.clone.1) + %shift-left.108826.9.clone.1 = 
u32[1280,1280]{1,0} shift-left(%xor.121036.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114957.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121036.5.clone.1, %broadcast.244430.4608) + %or.114488.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108826.9.clone.1, %shift-right-logical.114957.9.clone.1) + %xor.121038.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246551.3.clone.1, %or.114488.7.clone.1) + %add.246552.3.clone.1 = u32[1280,1280]{1,0} add(%add.246551.3.clone.1, %xor.121038.5.clone.1) + %constant_168090_1_clone_1 = u32[] constant(1828118743) + %broadcast.248731.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_168090_1_clone_1), dimensions={} + %add.246554.7.clone.1 = u32[1280,1280]{1,0} add(%add.246552.3.clone.1, %broadcast.248731.24.clone.1) + %shift-left.108827.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121038.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114958.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121038.5.clone.1, %broadcast.244434.2816) + %or.114489.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108827.11.clone.1, %shift-right-logical.114958.11.clone.1) + %xor.121039.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246552.3.clone.1, %or.114489.9.clone.1) + %constant_218042_1_clone_1 = u32[] constant(623170567) + %broadcast.248734.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218042_1_clone_1), dimensions={} + %add.246555.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121039.7.clone.1, %broadcast.248734.5.clone.1) + %add.246556.5.clone.1 = u32[1280,1280]{1,0} add(%add.246554.7.clone.1, %add.246555.5.clone.1) + %shift-left.108828.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246555.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114959.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246555.5.clone.1, %broadcast.244415.6016) + %or.114490.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108828.9.clone.1, %shift-right-logical.114959.9.clone.1) + %xor.121040.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246556.5.clone.1, %or.114490.7.clone.1) + %add.246557.3.clone.1 = u32[1280,1280]{1,0} add(%add.246556.5.clone.1, %xor.121040.5.clone.1) + %shift-left.108829.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121040.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114960.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121040.5.clone.1, %broadcast.244417.5760) + %or.114491.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108829.9.clone.1, %shift-right-logical.114960.9.clone.1) + %xor.121041.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246557.3.clone.1, %or.114491.7.clone.1) + %add.246559.3.clone.1 = u32[1280,1280]{1,0} add(%add.246557.3.clone.1, %xor.121041.5.clone.1) + %shift-left.108830.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121041.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114961.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121041.5.clone.1, %broadcast.244419.4352) + %or.114493.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108830.7.clone.1, %shift-right-logical.114961.7.clone.1) + %xor.121043.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246559.3.clone.1, %or.114493.5.clone.1) + %add.246560.3.clone.1 = u32[1280,1280]{1,0} add(%add.246559.3.clone.1, %xor.121043.3.clone.1) + %add.246561.7.clone.1 = u32[1280,1280]{1,0} add(%add.246560.3.clone.1, %broadcast.248713.44.clone.1) + %shift-left.108831.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121043.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114962.7.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.121043.3.clone.1, %broadcast.244418.4352) + %or.114494.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108831.7.clone.1, %shift-right-logical.114962.7.clone.1) + %xor.121044.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246560.3.clone.1, %or.114494.5.clone.1) + %constant_218043_1_clone_1 = u32[] constant(1375940363) + %broadcast.248744.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218043_1_clone_1), dimensions={} + %add.246562.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121044.3.clone.1, %broadcast.248744.5.clone.1) + %add.246563.5.clone.1 = u32[1280,1280]{1,0} add(%add.246561.7.clone.1, %add.246562.5.clone.1) + %shift-left.108832.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246562.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114963.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246562.5.clone.1, %broadcast.244416.5760) + %or.114495.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108832.9.clone.1, %shift-right-logical.114963.9.clone.1) + %xor.121045.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246563.5.clone.1, %or.114495.7.clone.1) + %add.246565.3.clone.1 = u32[1280,1280]{1,0} add(%add.246563.5.clone.1, %xor.121045.5.clone.1) + %shift-left.108833.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121045.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114964.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121045.5.clone.1, %broadcast.244429.2304) + %or.114496.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108833.9.clone.1, %shift-right-logical.114964.9.clone.1) + %xor.121046.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246565.3.clone.1, %or.114496.7.clone.1) + %add.246568.3.clone.1 = u32[1280,1280]{1,0} add(%add.246565.3.clone.1, %xor.121046.5.clone.1) + %shift-left.108834.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121046.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114965.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121046.5.clone.1, %broadcast.244430.4608) + %or.114498.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108834.9.clone.1, %shift-right-logical.114965.9.clone.1) + %xor.121048.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246568.3.clone.1, %or.114498.7.clone.1) + %add.246569.3.clone.1 = u32[1280,1280]{1,0} add(%add.246568.3.clone.1, %xor.121048.5.clone.1) + %add.246570.7.clone.1 = u32[1280,1280]{1,0} add(%add.246569.3.clone.1, %broadcast.248714.113.clone.1) + %shift-left.108835.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121048.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114966.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121048.5.clone.1, %broadcast.244434.2816) + %or.114499.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108835.11.clone.1, %shift-right-logical.114966.11.clone.1) + %xor.121049.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246569.3.clone.1, %or.114499.9.clone.1) + %constant_218044_1_clone_1 = u32[] constant(1828118747) + %broadcast.248754.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218044_1_clone_1), dimensions={} + %add.246571.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121049.7.clone.1, %broadcast.248754.5.clone.1) + %add.246572.5.clone.1 = u32[1280,1280]{1,0} add(%add.246570.7.clone.1, %add.246571.5.clone.1) + %shift-left.108836.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246571.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114967.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246571.5.clone.1, %broadcast.244415.6016) + %or.114500.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108836.9.clone.1, 
%shift-right-logical.114967.9.clone.1) + %xor.121050.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246572.5.clone.1, %or.114500.7.clone.1) + %add.246573.3.clone.1 = u32[1280,1280]{1,0} add(%add.246572.5.clone.1, %xor.121050.5.clone.1) + %shift-left.108837.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121050.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114968.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121050.5.clone.1, %broadcast.244417.5760) + %or.114501.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108837.9.clone.1, %shift-right-logical.114968.9.clone.1) + %xor.121051.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246573.3.clone.1, %or.114501.7.clone.1) + %add.246574.3.clone.1 = u32[1280,1280]{1,0} add(%add.246573.3.clone.1, %xor.121051.5.clone.1) + %shift-left.108838.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121051.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114969.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121051.5.clone.1, %broadcast.244419.4352) + %or.114503.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108838.5.clone.1, %shift-right-logical.114969.5.clone.1) + %xor.121052.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246574.3.clone.1, %or.114503.3.clone.1) + %add.246575.3.clone.1 = u32[1280,1280]{1,0} add(%add.246574.3.clone.1, %xor.121052.3.clone.1) + %add.246576.17.clone.1 = u32[1280,1280]{1,0} add(%add.246575.3.clone.1, %broadcast.248731.24.clone.1) + %shift-left.108839.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121052.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114970.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121052.3.clone.1, %broadcast.244418.4352) + %or.114504.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108839.5.clone.1, %shift-right-logical.114970.5.clone.1) + %xor.121053.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246575.3.clone.1, %or.114504.3.clone.1) + %constant_218045_1_clone_1 = u32[] constant(623170570) + %broadcast.248764.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218045_1_clone_1), dimensions={} + %add.246577.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121053.15.clone.1, %broadcast.248764.19.clone.1) + %xor.121054.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246576.17.clone.1, %add.246577.19.clone.1) + %shift-right-logical.114971.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121054.17.clone.1, %broadcast.244468.1920) + %or.114505.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114971.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5718.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114505.13.clone.1) + %add.246578.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5718.11.clone.1, %broadcast.244470.1152) + %multiply.26094.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246578.9.clone.1, %broadcast.244471.896) + %add.246579.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26094.7.clone.1, %broadcast.244408.1024) + %maximum.3650.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246579.5.clone.1) + %abs.1516.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3650.3.clone.1) + %compare.7180.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1516.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26095.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3650.3.clone.1, %broadcast.244476.1152) + %negate.4537.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3650.3.clone.1) + %multiply.26096.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3650.3.clone.1, %negate.4537.5.clone.1) + %log-plus-one.1516.3.clone.1 = f32[1280,1280]{1,0} 
log-plus-one(%multiply.26096.5.clone.1) + %negate.4538.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1516.3.clone.1) + %compare.7181.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4538.4.clone.1, %broadcast.244477.384), direction=LT + %select.20741.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20742.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20743.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20744.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20745.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20746.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20747.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20748.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20749.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.246580.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4538.4.clone.1, %broadcast.244496.640) + %sqrt.1516.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4538.4.clone.1) + %add.246581.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1516.5.clone.1, %broadcast.244498.640) + %select.20750.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7181.3.clone.1, %add.246580.5.clone.1, %add.246581.5.clone.1) + %multiply.26097.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20749.3.clone.1, %select.20750.3.clone.1) + %add.246582.1.clone.1 = f32[1280,1280]{1,0} add(%select.20748.3.clone.1, %multiply.26097.1.clone.1) + %multiply.26098.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246582.1.clone.1, %select.20750.3.clone.1) + %add.246583.1.clone.1 = f32[1280,1280]{1,0} add(%select.20747.3.clone.1, %multiply.26098.1.clone.1) + %multiply.26099.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246583.1.clone.1, %select.20750.3.clone.1) + %add.246584.1.clone.1 = f32[1280,1280]{1,0} add(%select.20746.3.clone.1, %multiply.26099.1.clone.1) + %multiply.26100.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246584.1.clone.1, %select.20750.3.clone.1) + %add.246585.1.clone.1 = f32[1280,1280]{1,0} add(%select.20745.3.clone.1, %multiply.26100.1.clone.1) + %multiply.26101.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246585.1.clone.1, %select.20750.3.clone.1) + %add.246586.3.clone.1 = f32[1280,1280]{1,0} add(%select.20744.5.clone.1, %multiply.26101.1.clone.1) + %multiply.26102.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246586.3.clone.1, %select.20750.3.clone.1) + %add.246587.3.clone.1 = f32[1280,1280]{1,0} add(%select.20743.5.clone.1, %multiply.26102.1.clone.1) + %multiply.26103.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246587.3.clone.1, %select.20750.3.clone.1) + %add.246588.9.clone.1 = f32[1280,1280]{1,0} add(%select.20742.11.clone.1, %multiply.26103.7.clone.1) + %multiply.26104.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246588.9.clone.1, %select.20750.3.clone.1) + %add.246589.7.clone.1 = f32[1280,1280]{1,0} add(%select.20741.7.clone.1, %multiply.26104.7.clone.1) + %multiply.26105.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246589.7.clone.1, %maximum.3650.3.clone.1) + 
%select.20751.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7180.3.clone.1, %multiply.26095.9.clone.1, %multiply.26105.7.clone.1) + %multiply.26106.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20751.7.clone.1, %broadcast.244500.640) + %clamp.1160.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26106.5.clone.1, %broadcast.244501.384) + %multiply.26107.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1160.3.clone.1, %broadcast.244502.1) + %constant_196564_1_clone_1 = u32[] constant(3909802021) + %broadcast.261019.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_196564_1_clone_1), dimensions={} + %add.253565.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.261019.44.clone.1) + %constant_196571_1_clone_1 = u32[] constant(983326418) + %broadcast.261020.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_196571_1_clone_1), dimensions={} + %add.253566.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.261020.113.clone.1) + %add.253567.35.clone.1 = u32[1280,1280]{1,0} add(%add.253565.37.clone.1, %add.253566.99.clone.1) + %shift-left.111860.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253566.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.118174.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253566.99.clone.1, %broadcast.244415.6016) + %or.117713.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111860.31.clone.1, %shift-right-logical.118174.29.clone.1) + %xor.124270.27.clone.1 = u32[1280,1280]{1,0} xor(%add.253567.35.clone.1, %or.117713.29.clone.1) + %add.253568.5.clone.1 = u32[1280,1280]{1,0} add(%add.253567.35.clone.1, %xor.124270.27.clone.1) + %shift-left.111861.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124270.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.118175.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124270.27.clone.1, %broadcast.244417.5760) + %or.117714.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111861.9.clone.1, %shift-right-logical.118175.9.clone.1) + %xor.124271.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253568.5.clone.1, %or.117714.7.clone.1) + %add.253569.3.clone.1 = u32[1280,1280]{1,0} add(%add.253568.5.clone.1, %xor.124271.5.clone.1) + %shift-left.111862.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124271.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.118176.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124271.5.clone.1, %broadcast.244419.4352) + %or.117716.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111862.5.clone.1, %shift-right-logical.118176.5.clone.1) + %xor.124272.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253569.3.clone.1, %or.117716.3.clone.1) + %add.253570.3.clone.1 = u32[1280,1280]{1,0} add(%add.253569.3.clone.1, %xor.124272.3.clone.1) + %add.253571.7.clone.1 = u32[1280,1280]{1,0} add(%add.253570.3.clone.1, %broadcast.261020.113.clone.1) + %shift-left.111863.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124272.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.118177.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124272.3.clone.1, %broadcast.244418.4352) + %or.117717.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111863.5.clone.1, %shift-right-logical.118177.5.clone.1) + %xor.124273.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253570.3.clone.1, %or.117717.3.clone.1) + %constant_218817_1_clone_1 = u32[] constant(3360136494) + %broadcast.261032.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218817_1_clone_1), dimensions={} + %add.253572.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124273.3.clone.1, 
%broadcast.261032.5.clone.1) + %add.253573.5.clone.1 = u32[1280,1280]{1,0} add(%add.253571.7.clone.1, %add.253572.5.clone.1) + %shift-left.111864.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253572.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.118178.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253572.5.clone.1, %broadcast.244416.5760) + %or.117718.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111864.9.clone.1, %shift-right-logical.118178.9.clone.1) + %xor.124274.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253573.5.clone.1, %or.117718.7.clone.1) + %add.253574.3.clone.1 = u32[1280,1280]{1,0} add(%add.253573.5.clone.1, %xor.124274.5.clone.1) + %shift-left.111865.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124274.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.118179.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124274.5.clone.1, %broadcast.244429.2304) + %or.117719.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111865.9.clone.1, %shift-right-logical.118179.9.clone.1) + %xor.124275.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253574.3.clone.1, %or.117719.7.clone.1) + %add.253575.3.clone.1 = u32[1280,1280]{1,0} add(%add.253574.3.clone.1, %xor.124275.5.clone.1) + %shift-left.111866.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124275.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.118181.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124275.5.clone.1, %broadcast.244430.4608) + %or.117721.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111866.9.clone.1, %shift-right-logical.118181.9.clone.1) + %xor.124276.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253575.3.clone.1, %or.117721.7.clone.1) + %add.253576.3.clone.1 = u32[1280,1280]{1,0} add(%add.253575.3.clone.1, %xor.124276.5.clone.1) + %constant_196573_1_clone_1 = u32[] constant(3360136493) + %broadcast.261044.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_196573_1_clone_1), dimensions={} + %add.253577.7.clone.1 = u32[1280,1280]{1,0} add(%add.253576.3.clone.1, %broadcast.261044.24.clone.1) + %shift-left.111867.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124276.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.118182.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124276.5.clone.1, %broadcast.244434.2816) + %or.117722.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111867.11.clone.1, %shift-right-logical.118182.11.clone.1) + %xor.124277.7.clone.1 = u32[1280,1280]{1,0} xor(%add.253576.3.clone.1, %or.117722.9.clone.1) + %constant_218818_1_clone_1 = u32[] constant(3909802023) + %broadcast.261047.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218818_1_clone_1), dimensions={} + %add.253578.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124277.7.clone.1, %broadcast.261047.5.clone.1) + %add.253580.5.clone.1 = u32[1280,1280]{1,0} add(%add.253577.7.clone.1, %add.253578.5.clone.1) + %shift-left.111868.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253578.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.118183.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253578.5.clone.1, %broadcast.244415.6016) + %or.117723.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111868.9.clone.1, %shift-right-logical.118183.9.clone.1) + %xor.124278.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253580.5.clone.1, %or.117723.7.clone.1) + %add.253581.3.clone.1 = u32[1280,1280]{1,0} add(%add.253580.5.clone.1, %xor.124278.5.clone.1) + %shift-left.111869.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124278.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.118184.9.clone.1 = 
u32[1280,1280]{1,0} shift-right-logical(%xor.124278.5.clone.1, %broadcast.244417.5760) + %or.117724.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111869.9.clone.1, %shift-right-logical.118184.9.clone.1) + %xor.124280.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253581.3.clone.1, %or.117724.7.clone.1) + %add.253582.3.clone.1 = u32[1280,1280]{1,0} add(%add.253581.3.clone.1, %xor.124280.5.clone.1) + %shift-left.111870.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124280.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.118186.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124280.5.clone.1, %broadcast.244419.4352) + %or.117725.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111870.7.clone.1, %shift-right-logical.118186.7.clone.1) + %xor.124281.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253582.3.clone.1, %or.117725.5.clone.1) + %add.253583.3.clone.1 = u32[1280,1280]{1,0} add(%add.253582.3.clone.1, %xor.124281.3.clone.1) + %add.253584.7.clone.1 = u32[1280,1280]{1,0} add(%add.253583.3.clone.1, %broadcast.261019.44.clone.1) + %shift-left.111871.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124281.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.118187.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124281.3.clone.1, %broadcast.244418.4352) + %or.117726.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111871.7.clone.1, %shift-right-logical.118187.7.clone.1) + %xor.124282.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253583.3.clone.1, %or.117726.5.clone.1) + %constant_218819_1_clone_1 = u32[] constant(983326421) + %broadcast.261059.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218819_1_clone_1), dimensions={} + %add.253585.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124282.3.clone.1, %broadcast.261059.5.clone.1) + %add.253586.5.clone.1 = u32[1280,1280]{1,0} add(%add.253584.7.clone.1, %add.253585.5.clone.1) + %shift-left.111872.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253585.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.118188.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253585.5.clone.1, %broadcast.244416.5760) + %or.117727.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111872.9.clone.1, %shift-right-logical.118188.9.clone.1) + %xor.124283.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253586.5.clone.1, %or.117727.7.clone.1) + %add.253587.3.clone.1 = u32[1280,1280]{1,0} add(%add.253586.5.clone.1, %xor.124283.5.clone.1) + %shift-left.111873.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124283.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.118189.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124283.5.clone.1, %broadcast.244429.2304) + %or.117728.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111873.9.clone.1, %shift-right-logical.118189.9.clone.1) + %xor.124284.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253587.3.clone.1, %or.117728.7.clone.1) + %add.253588.3.clone.1 = u32[1280,1280]{1,0} add(%add.253587.3.clone.1, %xor.124284.5.clone.1) + %shift-left.111874.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124284.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.118191.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124284.5.clone.1, %broadcast.244430.4608) + %or.117729.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111874.9.clone.1, %shift-right-logical.118191.9.clone.1) + %xor.124285.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253588.3.clone.1, %or.117729.7.clone.1) + %add.253589.3.clone.1 = u32[1280,1280]{1,0} add(%add.253588.3.clone.1, %xor.124285.5.clone.1) + %add.253590.7.clone.1 = u32[1280,1280]{1,0} 
add(%add.253589.3.clone.1, %broadcast.261020.113.clone.1) + %shift-left.111875.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124285.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.118192.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124285.5.clone.1, %broadcast.244434.2816) + %or.117730.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111875.11.clone.1, %shift-right-logical.118192.11.clone.1) + %xor.124286.7.clone.1 = u32[1280,1280]{1,0} xor(%add.253589.3.clone.1, %or.117730.9.clone.1) + %constant_218820_1_clone_1 = u32[] constant(3360136497) + %broadcast.261071.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218820_1_clone_1), dimensions={} + %add.253591.5.clone.1 = u32[1280,1280]{1,0} add(%xor.124286.7.clone.1, %broadcast.261071.5.clone.1) + %add.253592.5.clone.1 = u32[1280,1280]{1,0} add(%add.253590.7.clone.1, %add.253591.5.clone.1) + %shift-left.111876.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.253591.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.118193.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.253591.5.clone.1, %broadcast.244415.6016) + %or.117731.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111876.9.clone.1, %shift-right-logical.118193.9.clone.1) + %xor.124287.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253592.5.clone.1, %or.117731.7.clone.1) + %add.253593.3.clone.1 = u32[1280,1280]{1,0} add(%add.253592.5.clone.1, %xor.124287.5.clone.1) + %shift-left.111877.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124287.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.118194.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124287.5.clone.1, %broadcast.244417.5760) + %or.117732.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111877.9.clone.1, %shift-right-logical.118194.9.clone.1) + %xor.124288.5.clone.1 = u32[1280,1280]{1,0} xor(%add.253593.3.clone.1, %or.117732.7.clone.1) + %add.253594.3.clone.1 = u32[1280,1280]{1,0} add(%add.253593.3.clone.1, %xor.124288.5.clone.1) + %shift-left.111878.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124288.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.118196.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124288.5.clone.1, %broadcast.244419.4352) + %or.117733.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111878.5.clone.1, %shift-right-logical.118196.5.clone.1) + %xor.124289.3.clone.1 = u32[1280,1280]{1,0} xor(%add.253594.3.clone.1, %or.117733.3.clone.1) + %add.253595.3.clone.1 = u32[1280,1280]{1,0} add(%add.253594.3.clone.1, %xor.124289.3.clone.1) + %add.253596.17.clone.1 = u32[1280,1280]{1,0} add(%add.253595.3.clone.1, %broadcast.261044.24.clone.1) + %shift-left.111879.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.124289.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.118197.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124289.3.clone.1, %broadcast.244418.4352) + %or.117734.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111879.5.clone.1, %shift-right-logical.118197.5.clone.1) + %xor.124290.15.clone.1 = u32[1280,1280]{1,0} xor(%add.253595.3.clone.1, %or.117734.3.clone.1) + %constant_218821_1_clone_1 = u32[] constant(3909802026) + %broadcast.261085.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218821_1_clone_1), dimensions={} + %add.253597.19.clone.1 = u32[1280,1280]{1,0} add(%xor.124290.15.clone.1, %broadcast.261085.19.clone.1) + %xor.124291.17.clone.1 = u32[1280,1280]{1,0} xor(%add.253596.17.clone.1, %add.253597.19.clone.1) + %shift-right-logical.118198.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.124291.17.clone.1, 
%broadcast.244468.1920) + %or.117735.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.118198.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5858.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117735.13.clone.1) + %add.253598.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5858.11.clone.1, %broadcast.244470.1152) + %multiply.27534.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253598.9.clone.1, %broadcast.244471.896) + %add.253600.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27534.7.clone.1, %broadcast.244408.1024) + %maximum.3790.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.253600.5.clone.1) + %abs.1610.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3790.3.clone.1) + %compare.7382.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1610.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27535.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3790.3.clone.1, %broadcast.244476.1152) + %negate.4725.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3790.3.clone.1) + %multiply.27536.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3790.3.clone.1, %negate.4725.5.clone.1) + %log-plus-one.1610.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27536.5.clone.1) + %negate.4726.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1610.3.clone.1) + %compare.7383.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4726.4.clone.1, %broadcast.244477.384), direction=LT + %select.21817.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21818.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21819.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21820.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21821.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21822.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21823.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21824.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21825.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.253601.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4726.4.clone.1, %broadcast.244496.640) + %sqrt.1610.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4726.4.clone.1) + %add.253602.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1610.5.clone.1, %broadcast.244498.640) + %select.21826.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7383.3.clone.1, %add.253601.5.clone.1, %add.253602.5.clone.1) + %multiply.27537.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21825.3.clone.1, %select.21826.3.clone.1) + %add.253603.1.clone.1 = f32[1280,1280]{1,0} add(%select.21824.3.clone.1, %multiply.27537.1.clone.1) + %multiply.27538.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253603.1.clone.1, %select.21826.3.clone.1) + %add.253604.1.clone.1 = f32[1280,1280]{1,0} add(%select.21823.3.clone.1, %multiply.27538.1.clone.1) + %multiply.27539.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253604.1.clone.1, %select.21826.3.clone.1) + %add.253605.1.clone.1 = f32[1280,1280]{1,0} 
add(%select.21822.3.clone.1, %multiply.27539.1.clone.1) + %multiply.27540.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253605.1.clone.1, %select.21826.3.clone.1) + %add.253606.1.clone.1 = f32[1280,1280]{1,0} add(%select.21821.3.clone.1, %multiply.27540.1.clone.1) + %multiply.27541.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253606.1.clone.1, %select.21826.3.clone.1) + %add.253607.3.clone.1 = f32[1280,1280]{1,0} add(%select.21820.5.clone.1, %multiply.27541.1.clone.1) + %multiply.27542.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.253607.3.clone.1, %select.21826.3.clone.1) + %add.253608.3.clone.1 = f32[1280,1280]{1,0} add(%select.21819.5.clone.1, %multiply.27542.1.clone.1) + %multiply.27543.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253608.3.clone.1, %select.21826.3.clone.1) + %add.253609.9.clone.1 = f32[1280,1280]{1,0} add(%select.21818.11.clone.1, %multiply.27543.7.clone.1) + %multiply.27544.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253609.9.clone.1, %select.21826.3.clone.1) + %add.253610.7.clone.1 = f32[1280,1280]{1,0} add(%select.21817.7.clone.1, %multiply.27544.7.clone.1) + %multiply.27545.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.253610.7.clone.1, %maximum.3790.3.clone.1) + %select.21827.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7382.3.clone.1, %multiply.27535.9.clone.1, %multiply.27545.7.clone.1) + %multiply.27546.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21827.7.clone.1, %broadcast.244500.640) + %clamp.1254.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27546.5.clone.1, %broadcast.244501.384) + %multiply.27547.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1254.3.clone.1, %broadcast.244502.1) + %constant_167833_1_clone_1 = u32[] constant(3727816689) + %broadcast.248614.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_167833_1_clone_1), dimensions={} + %add.246476.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.248614.44.clone.1) + %constant_167845_1_clone_1 = u32[] constant(589999196) + %broadcast.248616.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_167845_1_clone_1), dimensions={} + %add.246477.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.248616.113.clone.1) + %add.246479.35.clone.1 = u32[1280,1280]{1,0} add(%add.246476.37.clone.1, %add.246477.99.clone.1) + %shift-left.108800.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246477.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.114930.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246477.99.clone.1, %broadcast.244415.6016) + %or.114456.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108800.31.clone.1, %shift-right-logical.114930.29.clone.1) + %xor.121005.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246479.35.clone.1, %or.114456.29.clone.1) + %add.246480.5.clone.1 = u32[1280,1280]{1,0} add(%add.246479.35.clone.1, %xor.121005.27.clone.1) + %shift-left.108801.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121005.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.114931.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121005.27.clone.1, %broadcast.244417.5760) + %or.114457.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108801.9.clone.1, %shift-right-logical.114931.9.clone.1) + %xor.121006.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246480.5.clone.1, %or.114457.7.clone.1) + %add.246481.3.clone.1 = u32[1280,1280]{1,0} add(%add.246480.5.clone.1, %xor.121006.5.clone.1) + %shift-left.108802.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121006.5.clone.1, %broadcast.244418.4352) + 
%shift-right-logical.114932.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121006.5.clone.1, %broadcast.244419.4352) + %or.114458.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108802.5.clone.1, %shift-right-logical.114932.5.clone.1) + %xor.121008.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246481.3.clone.1, %or.114458.3.clone.1) + %add.246482.3.clone.1 = u32[1280,1280]{1,0} add(%add.246481.3.clone.1, %xor.121008.3.clone.1) + %add.246484.7.clone.1 = u32[1280,1280]{1,0} add(%add.246482.3.clone.1, %broadcast.248616.113.clone.1) + %shift-left.108803.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121008.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114933.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121008.3.clone.1, %broadcast.244418.4352) + %or.114459.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108803.5.clone.1, %shift-right-logical.114933.5.clone.1) + %xor.121009.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246482.3.clone.1, %or.114459.3.clone.1) + %constant_218036_1_clone_1 = u32[] constant(3872014456) + %broadcast.248630.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218036_1_clone_1), dimensions={} + %add.246485.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121009.3.clone.1, %broadcast.248630.5.clone.1) + %add.246486.5.clone.1 = u32[1280,1280]{1,0} add(%add.246484.7.clone.1, %add.246485.5.clone.1) + %shift-left.108804.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246485.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114934.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246485.5.clone.1, %broadcast.244416.5760) + %or.114460.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108804.9.clone.1, %shift-right-logical.114934.9.clone.1) + %xor.121010.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246486.5.clone.1, %or.114460.7.clone.1) + %add.246487.3.clone.1 = u32[1280,1280]{1,0} add(%add.246486.5.clone.1, %xor.121010.5.clone.1) + %shift-left.108805.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121010.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114935.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121010.5.clone.1, %broadcast.244429.2304) + %or.114461.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108805.9.clone.1, %shift-right-logical.114935.9.clone.1) + %xor.121011.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246487.3.clone.1, %or.114461.7.clone.1) + %add.246488.3.clone.1 = u32[1280,1280]{1,0} add(%add.246487.3.clone.1, %xor.121011.5.clone.1) + %shift-left.108806.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121011.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114936.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121011.5.clone.1, %broadcast.244430.4608) + %or.114463.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108806.9.clone.1, %shift-right-logical.114936.9.clone.1) + %xor.121013.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246488.3.clone.1, %or.114463.7.clone.1) + %add.246490.3.clone.1 = u32[1280,1280]{1,0} add(%add.246488.3.clone.1, %xor.121013.5.clone.1) + %constant_167847_1_clone_1 = u32[] constant(3872014455) + %broadcast.248639.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_167847_1_clone_1), dimensions={} + %add.246494.7.clone.1 = u32[1280,1280]{1,0} add(%add.246490.3.clone.1, %broadcast.248639.24.clone.1) + %shift-left.108807.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121013.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114937.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121013.5.clone.1, %broadcast.244434.2816) + %or.114464.9.clone.1 = u32[1280,1280]{1,0} 
or(%shift-left.108807.11.clone.1, %shift-right-logical.114937.11.clone.1) + %xor.121014.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246490.3.clone.1, %or.114464.9.clone.1) + %constant_218037_1_clone_1 = u32[] constant(3727816691) + %broadcast.248642.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218037_1_clone_1), dimensions={} + %add.246495.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121014.7.clone.1, %broadcast.248642.5.clone.1) + %add.246496.5.clone.1 = u32[1280,1280]{1,0} add(%add.246494.7.clone.1, %add.246495.5.clone.1) + %shift-left.108808.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246495.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114938.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246495.5.clone.1, %broadcast.244415.6016) + %or.114465.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108808.9.clone.1, %shift-right-logical.114938.9.clone.1) + %xor.121015.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246496.5.clone.1, %or.114465.7.clone.1) + %add.246497.3.clone.1 = u32[1280,1280]{1,0} add(%add.246496.5.clone.1, %xor.121015.5.clone.1) + %shift-left.108809.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121015.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114939.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121015.5.clone.1, %broadcast.244417.5760) + %or.114466.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108809.9.clone.1, %shift-right-logical.114939.9.clone.1) + %xor.121016.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246497.3.clone.1, %or.114466.7.clone.1) + %add.246499.3.clone.1 = u32[1280,1280]{1,0} add(%add.246497.3.clone.1, %xor.121016.5.clone.1) + %shift-left.108810.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121016.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114940.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121016.5.clone.1, %broadcast.244419.4352) + %or.114468.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108810.7.clone.1, %shift-right-logical.114940.7.clone.1) + %xor.121018.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246499.3.clone.1, %or.114468.5.clone.1) + %add.246500.3.clone.1 = u32[1280,1280]{1,0} add(%add.246499.3.clone.1, %xor.121018.3.clone.1) + %add.246501.7.clone.1 = u32[1280,1280]{1,0} add(%add.246500.3.clone.1, %broadcast.248614.44.clone.1) + %shift-left.108811.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121018.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114941.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121018.3.clone.1, %broadcast.244418.4352) + %or.114469.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108811.7.clone.1, %shift-right-logical.114941.7.clone.1) + %xor.121019.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246500.3.clone.1, %or.114469.5.clone.1) + %constant_218038_1_clone_1 = u32[] constant(589999199) + %broadcast.248656.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218038_1_clone_1), dimensions={} + %add.246502.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121019.3.clone.1, %broadcast.248656.5.clone.1) + %add.246504.5.clone.1 = u32[1280,1280]{1,0} add(%add.246501.7.clone.1, %add.246502.5.clone.1) + %shift-left.108812.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246502.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114942.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246502.5.clone.1, %broadcast.244416.5760) + %or.114470.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108812.9.clone.1, %shift-right-logical.114942.9.clone.1) + %xor.121020.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246504.5.clone.1, %or.114470.7.clone.1) + 
%add.246505.3.clone.1 = u32[1280,1280]{1,0} add(%add.246504.5.clone.1, %xor.121020.5.clone.1)
+ %shift-left.108813.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121020.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114943.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121020.5.clone.1, %broadcast.244429.2304)
+ %or.114471.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108813.9.clone.1, %shift-right-logical.114943.9.clone.1)
+ %xor.121021.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246505.3.clone.1, %or.114471.7.clone.1)
+ %add.246506.3.clone.1 = u32[1280,1280]{1,0} add(%add.246505.3.clone.1, %xor.121021.5.clone.1)
+ %shift-left.108814.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121021.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114944.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121021.5.clone.1, %broadcast.244430.4608)
+ %or.114473.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108814.9.clone.1, %shift-right-logical.114944.9.clone.1)
+ %xor.121023.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246506.3.clone.1, %or.114473.7.clone.1)
+ %add.246507.3.clone.1 = u32[1280,1280]{1,0} add(%add.246506.3.clone.1, %xor.121023.5.clone.1)
+ %add.246509.7.clone.1 = u32[1280,1280]{1,0} add(%add.246507.3.clone.1, %broadcast.248616.113.clone.1)
+ %shift-left.108815.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121023.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114945.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121023.5.clone.1, %broadcast.244434.2816)
+ %or.114474.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108815.11.clone.1, %shift-right-logical.114945.11.clone.1)
+ %xor.121024.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246507.3.clone.1, %or.114474.9.clone.1)
+ %constant_218039_1_clone_1 = u32[] constant(3872014459)
+ %broadcast.248668.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218039_1_clone_1), dimensions={}
+ %add.246510.5.clone.1 = u32[1280,1280]{1,0} add(%xor.121024.7.clone.1, %broadcast.248668.5.clone.1)
+ %add.246511.5.clone.1 = u32[1280,1280]{1,0} add(%add.246509.7.clone.1, %add.246510.5.clone.1)
+ %shift-left.108816.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246510.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114946.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246510.5.clone.1, %broadcast.244415.6016)
+ %or.114475.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108816.9.clone.1, %shift-right-logical.114946.9.clone.1)
+ %xor.121025.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246511.5.clone.1, %or.114475.7.clone.1)
+ %add.246512.3.clone.1 = u32[1280,1280]{1,0} add(%add.246511.5.clone.1, %xor.121025.5.clone.1)
+ %shift-left.108817.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121025.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114947.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121025.5.clone.1, %broadcast.244417.5760)
+ %or.114476.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108817.9.clone.1, %shift-right-logical.114947.9.clone.1)
+ %xor.121026.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246512.3.clone.1, %or.114476.7.clone.1)
+ %add.246513.3.clone.1 = u32[1280,1280]{1,0} add(%add.246512.3.clone.1, %xor.121026.5.clone.1)
+ %shift-left.108818.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121026.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114948.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121026.5.clone.1, %broadcast.244419.4352)
+ %or.114478.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108818.5.clone.1, %shift-right-logical.114948.5.clone.1)
+ %xor.121027.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246513.3.clone.1, %or.114478.3.clone.1)
+ %add.246515.3.clone.1 = u32[1280,1280]{1,0} add(%add.246513.3.clone.1, %xor.121027.3.clone.1)
+ %add.246519.17.clone.1 = u32[1280,1280]{1,0} add(%add.246515.3.clone.1, %broadcast.248639.24.clone.1)
+ %shift-left.108819.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.121027.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114949.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121027.3.clone.1, %broadcast.244418.4352)
+ %or.114479.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108819.5.clone.1, %shift-right-logical.114949.5.clone.1)
+ %xor.121028.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246515.3.clone.1, %or.114479.3.clone.1)
+ %constant_218040_1_clone_1 = u32[] constant(3727816694)
+ %broadcast.248678.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218040_1_clone_1), dimensions={}
+ %add.246520.19.clone.1 = u32[1280,1280]{1,0} add(%xor.121028.15.clone.1, %broadcast.248678.19.clone.1)
+ %xor.121029.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246519.17.clone.1, %add.246520.19.clone.1)
+ %shift-right-logical.114950.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.121029.17.clone.1, %broadcast.244468.1920)
+ %or.114480.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114950.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5717.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114480.13.clone.1)
+ %add.246521.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5717.11.clone.1, %broadcast.244470.1152)
+ %multiply.26080.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246521.9.clone.1, %broadcast.244471.896)
+ %add.246522.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26080.7.clone.1, %broadcast.244408.1024)
+ %maximum.3649.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246522.5.clone.1)
+ %abs.1515.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3649.3.clone.1)
+ %compare.7178.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1515.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26081.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3649.3.clone.1, %broadcast.244476.1152)
+ %negate.4535.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3649.3.clone.1)
+ %multiply.26082.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3649.3.clone.1, %negate.4535.5.clone.1)
+ %log-plus-one.1515.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26082.5.clone.1)
+ %negate.4536.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1515.3.clone.1)
+ %compare.7179.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4536.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.20730.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.20731.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.20732.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.20733.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.20734.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.20735.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.20736.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.20737.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.20738.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.246524.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4536.4.clone.1, %broadcast.244496.640)
+ %sqrt.1515.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4536.4.clone.1)
+ %add.246525.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1515.5.clone.1, %broadcast.244498.640)
+ %select.20739.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7179.3.clone.1, %add.246524.5.clone.1, %add.246525.5.clone.1)
+ %multiply.26083.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20738.3.clone.1, %select.20739.3.clone.1)
+ %add.246526.1.clone.1 = f32[1280,1280]{1,0} add(%select.20737.3.clone.1, %multiply.26083.1.clone.1)
+ %multiply.26084.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246526.1.clone.1, %select.20739.3.clone.1)
+ %add.246527.1.clone.1 = f32[1280,1280]{1,0} add(%select.20736.3.clone.1, %multiply.26084.1.clone.1)
+ %multiply.26085.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246527.1.clone.1, %select.20739.3.clone.1)
+ %add.246529.1.clone.1 = f32[1280,1280]{1,0} add(%select.20735.3.clone.1, %multiply.26085.1.clone.1)
+ %multiply.26086.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246529.1.clone.1, %select.20739.3.clone.1)
+ %add.246530.1.clone.1 = f32[1280,1280]{1,0} add(%select.20734.3.clone.1, %multiply.26086.1.clone.1)
+ %multiply.26087.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246530.1.clone.1, %select.20739.3.clone.1)
+ %add.246531.3.clone.1 = f32[1280,1280]{1,0} add(%select.20733.5.clone.1, %multiply.26087.1.clone.1)
+ %multiply.26088.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246531.3.clone.1, %select.20739.3.clone.1)
+ %add.246532.3.clone.1 = f32[1280,1280]{1,0} add(%select.20732.5.clone.1, %multiply.26088.1.clone.1)
+ %multiply.26089.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246532.3.clone.1, %select.20739.3.clone.1)
+ %add.246534.9.clone.1 = f32[1280,1280]{1,0} add(%select.20731.11.clone.1, %multiply.26089.7.clone.1)
+ %multiply.26090.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246534.9.clone.1, %select.20739.3.clone.1)
+ %add.246535.7.clone.1 = f32[1280,1280]{1,0} add(%select.20730.7.clone.1, %multiply.26090.7.clone.1)
+ %multiply.26091.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246535.7.clone.1, %maximum.3649.3.clone.1)
+ %select.20740.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7178.3.clone.1, %multiply.26081.9.clone.1, %multiply.26091.7.clone.1)
+ %multiply.26092.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20740.7.clone.1, %broadcast.244500.640)
+ %clamp.1159.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26092.5.clone.1, %broadcast.244501.384)
+ %multiply.26093.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1159.3.clone.1, %broadcast.244502.1)
+ %constant_182477_1_clone_1 = u32[] constant(437258332)
+ %broadcast.254933.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_182477_1_clone_1), dimensions={}
+ %add.250091.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.254933.44.clone.1)
+ %constant_182484_1_clone_1 = u32[] constant(2304223561)
+ %broadcast.254934.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_182484_1_clone_1), dimensions={}
+ %add.250092.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.254934.113.clone.1)
+ %add.250093.35.clone.1 = u32[1280,1280]{1,0} add(%add.250091.37.clone.1, %add.250092.99.clone.1)
+ %shift-left.110348.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250092.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116590.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250092.99.clone.1, %broadcast.244415.6016)
+ %or.116102.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110348.31.clone.1, %shift-right-logical.116590.29.clone.1)
+ %xor.122674.27.clone.1 = u32[1280,1280]{1,0} xor(%add.250093.35.clone.1, %or.116102.29.clone.1)
+ %add.250094.5.clone.1 = u32[1280,1280]{1,0} add(%add.250093.35.clone.1, %xor.122674.27.clone.1)
+ %shift-left.110349.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122674.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116591.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122674.27.clone.1, %broadcast.244417.5760)
+ %or.116103.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110349.9.clone.1, %shift-right-logical.116591.9.clone.1)
+ %xor.122675.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250094.5.clone.1, %or.116103.7.clone.1)
+ %add.250095.3.clone.1 = u32[1280,1280]{1,0} add(%add.250094.5.clone.1, %xor.122675.5.clone.1)
+ %shift-left.110350.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122675.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116592.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122675.5.clone.1, %broadcast.244419.4352)
+ %or.116104.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110350.5.clone.1, %shift-right-logical.116592.5.clone.1)
+ %xor.122676.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250095.3.clone.1, %or.116104.3.clone.1)
+ %add.250096.3.clone.1 = u32[1280,1280]{1,0} add(%add.250095.3.clone.1, %xor.122676.3.clone.1)
+ %add.250097.7.clone.1 = u32[1280,1280]{1,0} add(%add.250096.3.clone.1, %broadcast.254934.113.clone.1)
+ %shift-left.110351.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122676.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116593.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122676.3.clone.1, %broadcast.244418.4352)
+ %or.116106.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110351.5.clone.1, %shift-right-logical.116593.5.clone.1)
+ %xor.122677.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250096.3.clone.1, %or.116106.3.clone.1)
+ %constant_218434_1_clone_1 = u32[] constant(2291579600)
+ %broadcast.254946.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218434_1_clone_1), dimensions={}
+ %add.250098.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122677.3.clone.1, %broadcast.254946.5.clone.1)
+ %add.250099.5.clone.1 = u32[1280,1280]{1,0} add(%add.250097.7.clone.1, %add.250098.5.clone.1)
+ %shift-left.110352.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250098.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.116595.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250098.5.clone.1, %broadcast.244416.5760)
+ %or.116107.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110352.9.clone.1, %shift-right-logical.116595.9.clone.1)
+ %xor.122678.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250099.5.clone.1, %or.116107.7.clone.1)
+ %add.250100.3.clone.1 = u32[1280,1280]{1,0} add(%add.250099.5.clone.1, %xor.122678.5.clone.1)
+ %shift-left.110353.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122678.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.116596.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122678.5.clone.1, %broadcast.244429.2304)
+ %or.116108.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110353.9.clone.1, %shift-right-logical.116596.9.clone.1)
+ %xor.122679.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250100.3.clone.1, %or.116108.7.clone.1)
+ %add.250101.3.clone.1 = u32[1280,1280]{1,0} add(%add.250100.3.clone.1, %xor.122679.5.clone.1)
+ %shift-left.110355.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122679.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.116597.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122679.5.clone.1, %broadcast.244430.4608)
+ %or.116109.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110355.9.clone.1, %shift-right-logical.116597.9.clone.1)
+ %xor.122680.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250101.3.clone.1, %or.116109.7.clone.1)
+ %add.250102.3.clone.1 = u32[1280,1280]{1,0} add(%add.250101.3.clone.1, %xor.122680.5.clone.1)
+ %constant_182486_1_clone_1 = u32[] constant(2291579599)
+ %broadcast.254953.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_182486_1_clone_1), dimensions={}
+ %add.250103.7.clone.1 = u32[1280,1280]{1,0} add(%add.250102.3.clone.1, %broadcast.254953.24.clone.1)
+ %shift-left.110356.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122680.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.116598.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122680.5.clone.1, %broadcast.244434.2816)
+ %or.116110.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110356.11.clone.1, %shift-right-logical.116598.11.clone.1)
+ %xor.122681.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250102.3.clone.1, %or.116110.9.clone.1)
+ %constant_218435_1_clone_1 = u32[] constant(437258334)
+ %broadcast.254956.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218435_1_clone_1), dimensions={}
+ %add.250104.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122681.7.clone.1, %broadcast.254956.5.clone.1)
+ %add.250105.5.clone.1 = u32[1280,1280]{1,0} add(%add.250103.7.clone.1, %add.250104.5.clone.1)
+ %shift-left.110357.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250104.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116599.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250104.5.clone.1, %broadcast.244415.6016)
+ %or.116111.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110357.9.clone.1, %shift-right-logical.116599.9.clone.1)
+ %xor.122682.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250105.5.clone.1, %or.116111.7.clone.1)
+ %add.250106.3.clone.1 = u32[1280,1280]{1,0} add(%add.250105.5.clone.1, %xor.122682.5.clone.1)
+ %shift-left.110358.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122682.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116600.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122682.5.clone.1, %broadcast.244417.5760)
+ %or.116112.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110358.9.clone.1, %shift-right-logical.116600.9.clone.1)
+ %xor.122684.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250106.3.clone.1, %or.116112.7.clone.1)
+ %add.250107.3.clone.1 = u32[1280,1280]{1,0} add(%add.250106.3.clone.1, %xor.122684.5.clone.1)
+ %shift-left.110360.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122684.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116601.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122684.5.clone.1, %broadcast.244419.4352)
+ %or.116113.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110360.7.clone.1, %shift-right-logical.116601.7.clone.1)
+ %xor.122685.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250107.3.clone.1, %or.116113.5.clone.1)
+ %add.250108.3.clone.1 = u32[1280,1280]{1,0} add(%add.250107.3.clone.1, %xor.122685.3.clone.1)
+ %add.250109.7.clone.1 = u32[1280,1280]{1,0} add(%add.250108.3.clone.1, %broadcast.254933.44.clone.1)
+ %shift-left.110361.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122685.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116602.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122685.3.clone.1, %broadcast.244418.4352)
+ %or.116114.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110361.7.clone.1, %shift-right-logical.116602.7.clone.1)
+ %xor.122686.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250108.3.clone.1, %or.116114.5.clone.1)
+ %constant_218436_1_clone_1 = u32[] constant(2304223564)
+ %broadcast.254968.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218436_1_clone_1), dimensions={}
+ %add.250110.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122686.3.clone.1, %broadcast.254968.5.clone.1)
+ %add.250111.5.clone.1 = u32[1280,1280]{1,0} add(%add.250109.7.clone.1, %add.250110.5.clone.1)
+ %shift-left.110362.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250110.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.116603.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250110.5.clone.1, %broadcast.244416.5760)
+ %or.116116.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110362.9.clone.1, %shift-right-logical.116603.9.clone.1)
+ %xor.122687.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250111.5.clone.1, %or.116116.7.clone.1)
+ %add.250112.3.clone.1 = u32[1280,1280]{1,0} add(%add.250111.5.clone.1, %xor.122687.5.clone.1)
+ %shift-left.110363.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122687.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.116604.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122687.5.clone.1, %broadcast.244429.2304)
+ %or.116117.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110363.9.clone.1, %shift-right-logical.116604.9.clone.1)
+ %xor.122689.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250112.3.clone.1, %or.116117.7.clone.1)
+ %add.250113.3.clone.1 = u32[1280,1280]{1,0} add(%add.250112.3.clone.1, %xor.122689.5.clone.1)
+ %shift-left.110365.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122689.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.116605.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122689.5.clone.1, %broadcast.244430.4608)
+ %or.116118.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110365.9.clone.1, %shift-right-logical.116605.9.clone.1)
+ %xor.122690.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250113.3.clone.1, %or.116118.7.clone.1)
+ %add.250114.3.clone.1 = u32[1280,1280]{1,0} add(%add.250113.3.clone.1, %xor.122690.5.clone.1)
+ %add.250115.7.clone.1 = u32[1280,1280]{1,0} add(%add.250114.3.clone.1, %broadcast.254934.113.clone.1)
+ %shift-left.110366.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122690.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.116606.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122690.5.clone.1, %broadcast.244434.2816)
+ %or.116119.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110366.11.clone.1, %shift-right-logical.116606.11.clone.1)
+ %xor.122691.7.clone.1 = u32[1280,1280]{1,0} xor(%add.250114.3.clone.1, %or.116119.9.clone.1)
+ %constant_218437_1_clone_1 = u32[] constant(2291579603)
+ %broadcast.254981.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218437_1_clone_1), dimensions={}
+ %add.250116.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122691.7.clone.1, %broadcast.254981.5.clone.1)
+ %add.250117.5.clone.1 = u32[1280,1280]{1,0} add(%add.250115.7.clone.1, %add.250116.5.clone.1)
+ %shift-left.110367.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.250116.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116608.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.250116.5.clone.1, %broadcast.244415.6016)
+ %or.116121.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110367.9.clone.1, %shift-right-logical.116608.9.clone.1)
+ %xor.122692.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250117.5.clone.1, %or.116121.7.clone.1)
+ %add.250118.3.clone.1 = u32[1280,1280]{1,0} add(%add.250117.5.clone.1, %xor.122692.5.clone.1)
+ %shift-left.110368.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122692.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116609.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122692.5.clone.1, %broadcast.244417.5760)
+ %or.116122.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110368.9.clone.1, %shift-right-logical.116609.9.clone.1)
+ %xor.122694.5.clone.1 = u32[1280,1280]{1,0} xor(%add.250118.3.clone.1, %or.116122.7.clone.1)
+ %add.250119.3.clone.1 = u32[1280,1280]{1,0} add(%add.250118.3.clone.1, %xor.122694.5.clone.1)
+ %shift-left.110370.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122694.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116610.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122694.5.clone.1, %broadcast.244419.4352)
+ %or.116123.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110370.5.clone.1, %shift-right-logical.116610.5.clone.1)
+ %xor.122695.3.clone.1 = u32[1280,1280]{1,0} xor(%add.250119.3.clone.1, %or.116123.3.clone.1)
+ %add.250120.3.clone.1 = u32[1280,1280]{1,0} add(%add.250119.3.clone.1, %xor.122695.3.clone.1)
+ %add.250121.17.clone.1 = u32[1280,1280]{1,0} add(%add.250120.3.clone.1, %broadcast.254953.24.clone.1)
+ %shift-left.110371.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122695.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116611.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122695.3.clone.1, %broadcast.244418.4352)
+ %or.116124.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110371.5.clone.1, %shift-right-logical.116611.5.clone.1)
+ %xor.122696.15.clone.1 = u32[1280,1280]{1,0} xor(%add.250120.3.clone.1, %or.116124.3.clone.1)
+ %constant_218438_1_clone_1 = u32[] constant(437258337)
+ %broadcast.254995.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218438_1_clone_1), dimensions={}
+ %add.250122.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122696.15.clone.1, %broadcast.254995.19.clone.1)
+ %xor.122697.17.clone.1 = u32[1280,1280]{1,0} xor(%add.250121.17.clone.1, %add.250122.19.clone.1)
+ %shift-right-logical.116612.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122697.17.clone.1, %broadcast.244468.1920)
+ %or.116126.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116612.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5789.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116126.13.clone.1)
+ %add.250123.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5789.11.clone.1, %broadcast.244470.1152)
+ %multiply.26815.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250123.9.clone.1, %broadcast.244471.896)
+ %add.250124.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26815.7.clone.1, %broadcast.244408.1024)
+ %maximum.3721.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.250124.5.clone.1)
+ %abs.1563.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3721.3.clone.1)
+ %compare.7275.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1563.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26816.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3721.3.clone.1, %broadcast.244476.1152)
+ %negate.4631.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3721.3.clone.1)
+ %multiply.26817.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3721.3.clone.1, %negate.4631.5.clone.1)
+ %log-plus-one.1563.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26817.5.clone.1)
+ %negate.4632.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1563.3.clone.1)
+ %compare.7276.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4632.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21279.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21280.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21281.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21282.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21283.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21284.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21285.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21286.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21287.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.250125.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4632.4.clone.1, %broadcast.244496.640)
+ %sqrt.1563.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4632.4.clone.1)
+ %add.250126.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1563.5.clone.1, %broadcast.244498.640)
+ %select.21288.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7276.3.clone.1, %add.250125.5.clone.1, %add.250126.5.clone.1)
+ %multiply.26818.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21287.3.clone.1, %select.21288.3.clone.1)
+ %add.250127.1.clone.1 = f32[1280,1280]{1,0} add(%select.21286.3.clone.1, %multiply.26818.1.clone.1)
+ %multiply.26819.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250127.1.clone.1, %select.21288.3.clone.1)
+ %add.250128.1.clone.1 = f32[1280,1280]{1,0} add(%select.21285.3.clone.1, %multiply.26819.1.clone.1)
+ %multiply.26820.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250128.1.clone.1, %select.21288.3.clone.1)
+ %add.250129.1.clone.1 = f32[1280,1280]{1,0} add(%select.21284.3.clone.1, %multiply.26820.1.clone.1)
+ %multiply.26821.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250129.1.clone.1, %select.21288.3.clone.1)
+ %add.250131.1.clone.1 = f32[1280,1280]{1,0} add(%select.21283.3.clone.1, %multiply.26821.1.clone.1)
+ %multiply.26822.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250131.1.clone.1, %select.21288.3.clone.1)
+ %add.250134.3.clone.1 = f32[1280,1280]{1,0} add(%select.21282.5.clone.1, %multiply.26822.1.clone.1)
+ %multiply.26823.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.250134.3.clone.1, %select.21288.3.clone.1)
+ %add.250135.3.clone.1 = f32[1280,1280]{1,0} add(%select.21281.5.clone.1, %multiply.26823.1.clone.1)
+ %multiply.26824.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250135.3.clone.1, %select.21288.3.clone.1)
+ %add.250136.9.clone.1 = f32[1280,1280]{1,0} add(%select.21280.11.clone.1, %multiply.26824.7.clone.1)
+ %multiply.26825.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250136.9.clone.1, %select.21288.3.clone.1)
+ %add.250137.7.clone.1 = f32[1280,1280]{1,0} add(%select.21279.7.clone.1, %multiply.26825.7.clone.1)
+ %multiply.26826.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250137.7.clone.1, %maximum.3721.3.clone.1)
+ %select.21289.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7275.3.clone.1, %multiply.26816.9.clone.1, %multiply.26826.7.clone.1)
+ %multiply.26827.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21289.7.clone.1, %broadcast.244500.640)
+ %clamp.1207.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26827.5.clone.1, %broadcast.244501.384)
+ %multiply.26828.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1207.3.clone.1, %broadcast.244502.1)
+ %constant_167282_1_clone_1 = u32[] constant(4076729066)
+ %broadcast.248377.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_167282_1_clone_1), dimensions={}
+ %add.246341.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.248377.44.clone.1)
+ %constant_167294_1_clone_1 = u32[] constant(10925196)
+ %broadcast.248378.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_167294_1_clone_1), dimensions={}
+ %add.246342.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.248378.113.clone.1)
+ %add.246343.35.clone.1 = u32[1280,1280]{1,0} add(%add.246341.37.clone.1, %add.246342.99.clone.1)
+ %shift-left.108734.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246342.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114867.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246342.99.clone.1, %broadcast.244415.6016)
+ %or.114387.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108734.31.clone.1, %shift-right-logical.114867.29.clone.1)
+ %xor.120932.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246343.35.clone.1, %or.114387.29.clone.1)
+ %add.246345.5.clone.1 = u32[1280,1280]{1,0} add(%add.246343.35.clone.1, %xor.120932.27.clone.1)
+ %shift-left.108735.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120932.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114868.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120932.27.clone.1, %broadcast.244417.5760)
+ %or.114388.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108735.9.clone.1, %shift-right-logical.114868.9.clone.1)
+ %xor.120933.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246345.5.clone.1, %or.114388.7.clone.1)
+ %add.246346.3.clone.1 = u32[1280,1280]{1,0} add(%add.246345.5.clone.1, %xor.120933.5.clone.1)
+ %shift-left.108737.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120933.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114869.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120933.5.clone.1, %broadcast.244419.4352)
+ %or.114389.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108737.5.clone.1, %shift-right-logical.114869.5.clone.1)
+ %xor.120934.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246346.3.clone.1, %or.114389.3.clone.1)
+ %add.246347.3.clone.1 = u32[1280,1280]{1,0} add(%add.246346.3.clone.1, %xor.120934.3.clone.1)
+ %add.246348.7.clone.1 = u32[1280,1280]{1,0} add(%add.246347.3.clone.1, %broadcast.248378.113.clone.1)
+ %shift-left.108738.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120934.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114870.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120934.3.clone.1, %broadcast.244418.4352)
+ %or.114390.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108738.5.clone.1, %shift-right-logical.114870.5.clone.1)
+ %xor.120935.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246347.3.clone.1, %or.114390.3.clone.1)
+ %constant_218021_1_clone_1 = u32[] constant(3918159293)
+ %broadcast.248388.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218021_1_clone_1), dimensions={}
+ %add.246350.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120935.3.clone.1, %broadcast.248388.5.clone.1)
+ %add.246351.5.clone.1 = u32[1280,1280]{1,0} add(%add.246348.7.clone.1, %add.246350.5.clone.1)
+ %shift-left.108739.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246350.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114871.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246350.5.clone.1, %broadcast.244416.5760)
+ %or.114391.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108739.9.clone.1, %shift-right-logical.114871.9.clone.1)
+ %xor.120936.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246351.5.clone.1, %or.114391.7.clone.1)
+ %add.246352.3.clone.1 = u32[1280,1280]{1,0} add(%add.246351.5.clone.1, %xor.120936.5.clone.1)
+ %shift-left.108740.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120936.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114872.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120936.5.clone.1, %broadcast.244429.2304)
+ %or.114393.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108740.9.clone.1, %shift-right-logical.114872.9.clone.1)
+ %xor.120937.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246352.3.clone.1, %or.114393.7.clone.1)
+ %add.246353.3.clone.1 = u32[1280,1280]{1,0} add(%add.246352.3.clone.1, %xor.120937.5.clone.1)
+ %shift-left.108742.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120937.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114873.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120937.5.clone.1, %broadcast.244430.4608)
+ %or.114394.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108742.9.clone.1, %shift-right-logical.114873.9.clone.1)
+ %xor.120938.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246353.3.clone.1, %or.114394.7.clone.1)
+ %add.246354.3.clone.1 = u32[1280,1280]{1,0} add(%add.246353.3.clone.1, %xor.120938.5.clone.1)
+ %constant_167296_1_clone_1 = u32[] constant(3918159292)
+ %broadcast.248395.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_167296_1_clone_1), dimensions={}
+ %add.246356.7.clone.1 = u32[1280,1280]{1,0} add(%add.246354.3.clone.1, %broadcast.248395.24.clone.1)
+ %shift-left.108743.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120938.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114874.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120938.5.clone.1, %broadcast.244434.2816)
+ %or.114396.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108743.11.clone.1, %shift-right-logical.114874.11.clone.1)
+ %xor.120939.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246354.3.clone.1, %or.114396.9.clone.1)
+ %constant_218022_1_clone_1 = u32[] constant(4076729068)
+ %broadcast.248398.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218022_1_clone_1), dimensions={}
+ %add.246360.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120939.7.clone.1, %broadcast.248398.5.clone.1)
+ %add.246361.5.clone.1 = u32[1280,1280]{1,0} add(%add.246356.7.clone.1, %add.246360.5.clone.1)
+ %shift-left.108744.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246360.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114875.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246360.5.clone.1, %broadcast.244415.6016)
+ %or.114397.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108744.9.clone.1, %shift-right-logical.114875.9.clone.1)
+ %xor.120940.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246361.5.clone.1, %or.114397.7.clone.1)
+ %add.246362.3.clone.1 = u32[1280,1280]{1,0} add(%add.246361.5.clone.1, %xor.120940.5.clone.1)
+ %shift-left.108745.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120940.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114876.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120940.5.clone.1, %broadcast.244417.5760)
+ %or.114398.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108745.9.clone.1, %shift-right-logical.114876.9.clone.1)
+ %xor.120941.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246362.3.clone.1, %or.114398.7.clone.1)
+ %add.246363.3.clone.1 = u32[1280,1280]{1,0} add(%add.246362.3.clone.1, %xor.120941.5.clone.1)
+ %shift-left.108746.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120941.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114877.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120941.5.clone.1, %broadcast.244419.4352)
+ %or.114399.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108746.7.clone.1, %shift-right-logical.114877.7.clone.1)
+ %xor.120942.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246363.3.clone.1, %or.114399.5.clone.1)
+ %add.246365.3.clone.1 = u32[1280,1280]{1,0} add(%add.246363.3.clone.1, %xor.120942.3.clone.1)
+ %add.246366.7.clone.1 = u32[1280,1280]{1,0} add(%add.246365.3.clone.1, %broadcast.248377.44.clone.1)
+ %shift-left.108747.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120942.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114878.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120942.3.clone.1, %broadcast.244418.4352)
+ %or.114400.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108747.7.clone.1, %shift-right-logical.114878.7.clone.1)
+ %xor.120943.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246365.3.clone.1, %or.114400.5.clone.1)
+ %constant_218023_1_clone_1 = u32[] constant(10925199)
+ %broadcast.248408.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218023_1_clone_1), dimensions={}
+ %add.246367.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120943.3.clone.1, %broadcast.248408.5.clone.1)
+ %add.246368.5.clone.1 = u32[1280,1280]{1,0} add(%add.246366.7.clone.1, %add.246367.5.clone.1)
+ %shift-left.108748.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246367.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114879.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246367.5.clone.1, %broadcast.244416.5760)
+ %or.114401.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108748.9.clone.1, %shift-right-logical.114879.9.clone.1)
+ %xor.120944.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246368.5.clone.1, %or.114401.7.clone.1)
+ %add.246370.3.clone.1 = u32[1280,1280]{1,0} add(%add.246368.5.clone.1, %xor.120944.5.clone.1)
+ %shift-left.108749.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120944.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114880.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120944.5.clone.1, %broadcast.244429.2304)
+ %or.114402.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108749.9.clone.1, %shift-right-logical.114880.9.clone.1)
+ %xor.120945.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246370.3.clone.1, %or.114402.7.clone.1)
+ %add.246371.3.clone.1 = u32[1280,1280]{1,0} add(%add.246370.3.clone.1, %xor.120945.5.clone.1)
+ %shift-left.108750.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120945.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114881.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120945.5.clone.1, %broadcast.244430.4608)
+ %or.114403.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108750.9.clone.1, %shift-right-logical.114881.9.clone.1)
+ %xor.120946.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246371.3.clone.1, %or.114403.7.clone.1)
+ %add.246372.3.clone.1 = u32[1280,1280]{1,0} add(%add.246371.3.clone.1, %xor.120946.5.clone.1)
+ %add.246373.7.clone.1 = u32[1280,1280]{1,0} add(%add.246372.3.clone.1, %broadcast.248378.113.clone.1)
+ %shift-left.108752.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120946.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114882.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120946.5.clone.1, %broadcast.244434.2816)
+ %or.114404.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108752.11.clone.1, %shift-right-logical.114882.11.clone.1)
+ %xor.120949.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246372.3.clone.1, %or.114404.9.clone.1)
+ %constant_218024_1_clone_1 = u32[] constant(3918159296)
+ %broadcast.248418.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218024_1_clone_1), dimensions={}
+ %add.246375.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120949.7.clone.1, %broadcast.248418.5.clone.1)
+ %add.246376.5.clone.1 = u32[1280,1280]{1,0} add(%add.246373.7.clone.1, %add.246375.5.clone.1)
+ %shift-left.108753.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246375.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114883.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246375.5.clone.1, %broadcast.244415.6016)
+ %or.114405.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108753.9.clone.1, %shift-right-logical.114883.9.clone.1)
+ %xor.120950.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246376.5.clone.1, %or.114405.7.clone.1)
+ %add.246377.3.clone.1 = u32[1280,1280]{1,0} add(%add.246376.5.clone.1, %xor.120950.5.clone.1)
+ %shift-left.108754.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120950.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114884.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120950.5.clone.1, %broadcast.244417.5760)
+ %or.114406.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108754.9.clone.1, %shift-right-logical.114884.9.clone.1)
+ %xor.120951.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246377.3.clone.1, %or.114406.7.clone.1)
+ %add.246378.3.clone.1 = u32[1280,1280]{1,0} add(%add.246377.3.clone.1, %xor.120951.5.clone.1)
+ %shift-left.108755.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120951.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114885.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120951.5.clone.1, %broadcast.244419.4352)
+ %or.114407.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108755.5.clone.1, %shift-right-logical.114885.5.clone.1)
+ %xor.120952.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246378.3.clone.1, %or.114407.3.clone.1)
+ %add.246379.3.clone.1 = u32[1280,1280]{1,0} add(%add.246378.3.clone.1, %xor.120952.3.clone.1)
+ %add.246381.17.clone.1 = u32[1280,1280]{1,0} add(%add.246379.3.clone.1, %broadcast.248395.24.clone.1)
+ %shift-left.108757.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120952.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114886.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120952.3.clone.1, %broadcast.244418.4352)
+ %or.114408.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108757.5.clone.1, %shift-right-logical.114886.5.clone.1)
+ %xor.120953.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246379.3.clone.1, %or.114408.3.clone.1)
+ %constant_218025_1_clone_1 = u32[] constant(4076729071)
+ %broadcast.248428.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218025_1_clone_1), dimensions={}
+ %add.246384.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120953.15.clone.1, %broadcast.248428.19.clone.1)
+ %xor.120954.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246381.17.clone.1, %add.246384.19.clone.1)
+ %shift-right-logical.114887.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120954.17.clone.1, %broadcast.244468.1920)
+ %or.114409.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114887.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5714.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114409.13.clone.1)
+ %add.246385.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5714.11.clone.1, %broadcast.244470.1152)
+ %multiply.26062.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246385.9.clone.1, %broadcast.244471.896)
+ %add.246386.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26062.7.clone.1, %broadcast.244408.1024)
+ %maximum.3646.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246386.5.clone.1)
+ %abs.1514.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3646.3.clone.1)
+ %compare.7176.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1514.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26063.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3646.3.clone.1, %broadcast.244476.1152)
+ %negate.4533.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3646.3.clone.1)
+ %multiply.26064.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3646.3.clone.1, %negate.4533.5.clone.1)
+ %log-plus-one.1514.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26064.5.clone.1)
+ %negate.4534.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1514.3.clone.1)
+ %compare.7177.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4534.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.20719.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.20720.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.20721.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.20722.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.20723.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.20724.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.20725.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.20726.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.20727.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.246387.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4534.4.clone.1, %broadcast.244496.640)
+ %sqrt.1514.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4534.4.clone.1)
+ %add.246388.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1514.5.clone.1, %broadcast.244498.640)
+ %select.20728.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7177.3.clone.1, %add.246387.5.clone.1, %add.246388.5.clone.1)
+ %multiply.26065.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20727.3.clone.1, %select.20728.3.clone.1)
+ %add.246389.1.clone.1 = f32[1280,1280]{1,0} add(%select.20726.3.clone.1, %multiply.26065.1.clone.1)
+ %multiply.26066.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246389.1.clone.1, %select.20728.3.clone.1)
+ %add.246390.1.clone.1 = f32[1280,1280]{1,0} add(%select.20725.3.clone.1, %multiply.26066.1.clone.1)
+ %multiply.26067.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246390.1.clone.1, %select.20728.3.clone.1)
+ %add.246391.1.clone.1 = f32[1280,1280]{1,0} add(%select.20724.3.clone.1, %multiply.26067.1.clone.1)
+ %multiply.26068.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246391.1.clone.1, %select.20728.3.clone.1)
+ %add.246392.1.clone.1 = f32[1280,1280]{1,0} add(%select.20723.3.clone.1, %multiply.26068.1.clone.1)
+ %multiply.26069.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246392.1.clone.1, %select.20728.3.clone.1)
+ %add.246393.3.clone.1 = f32[1280,1280]{1,0} add(%select.20722.5.clone.1, %multiply.26069.1.clone.1)
+ %multiply.26070.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246393.3.clone.1, %select.20728.3.clone.1)
+ %add.246394.3.clone.1 = f32[1280,1280]{1,0} add(%select.20721.5.clone.1, %multiply.26070.1.clone.1)
+ %multiply.26071.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246394.3.clone.1, %select.20728.3.clone.1)
+ %add.246395.9.clone.1 = f32[1280,1280]{1,0} add(%select.20720.11.clone.1, %multiply.26071.7.clone.1)
+ %multiply.26072.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246395.9.clone.1, %select.20728.3.clone.1)
+ %add.246396.7.clone.1 = f32[1280,1280]{1,0} add(%select.20719.7.clone.1, %multiply.26072.7.clone.1)
+ %multiply.26073.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246396.7.clone.1, %maximum.3646.3.clone.1)
+ %select.20729.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7176.3.clone.1, %multiply.26063.9.clone.1, %multiply.26073.7.clone.1)
+ %multiply.26074.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20729.7.clone.1, %broadcast.244500.640)
+ %clamp.1158.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26074.5.clone.1, %broadcast.244501.384)
+ %multiply.26075.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1158.3.clone.1, %broadcast.244502.1)
+ %constant_189782_1_clone_1 = u32[] constant(1759179695)
+ %broadcast.258099.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_189782_1_clone_1), dimensions={}
+ %add.251902.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.258099.44.clone.1)
+ %constant_189789_1_clone_1 = u32[] constant(154785442)
+ %broadcast.258100.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_189789_1_clone_1), dimensions={}
+ %add.251903.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.258100.113.clone.1)
+ %add.251904.35.clone.1 = u32[1280,1280]{1,0} add(%add.251902.37.clone.1, %add.251903.99.clone.1)
+ %shift-left.111152.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251903.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117413.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251903.99.clone.1, %broadcast.244415.6016)
+ %or.116951.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111152.31.clone.1, %shift-right-logical.117413.29.clone.1)
+ %xor.123510.27.clone.1 = u32[1280,1280]{1,0} xor(%add.251904.35.clone.1, %or.116951.29.clone.1)
+ %add.251905.5.clone.1 = u32[1280,1280]{1,0} add(%add.251904.35.clone.1, %xor.123510.27.clone.1)
+ %shift-left.111153.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123510.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117414.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123510.27.clone.1, %broadcast.244417.5760)
+ %or.116952.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111153.9.clone.1, %shift-right-logical.117414.9.clone.1)
+ %xor.123511.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251905.5.clone.1, %or.116952.7.clone.1)
+ %add.251906.3.clone.1 = u32[1280,1280]{1,0} add(%add.251905.5.clone.1, %xor.123511.5.clone.1)
+ %shift-left.111154.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123511.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117415.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123511.5.clone.1, %broadcast.244419.4352)
+ %or.116953.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111154.5.clone.1, %shift-right-logical.117415.5.clone.1)
+ %xor.123512.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251906.3.clone.1, %or.116953.3.clone.1)
+ %add.251907.3.clone.1 = u32[1280,1280]{1,0} add(%add.251906.3.clone.1, %xor.123512.3.clone.1)
+ %add.251908.7.clone.1 = u32[1280,1280]{1,0} add(%add.251907.3.clone.1, %broadcast.258100.113.clone.1)
+ %shift-left.111155.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123512.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117416.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123512.3.clone.1, %broadcast.244418.4352)
+ %or.116954.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111155.5.clone.1, %shift-right-logical.117416.5.clone.1)
+ %xor.123513.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251907.3.clone.1, %or.116954.3.clone.1)
+ %constant_218635_1_clone_1 = u32[] constant(2050113240)
+ %broadcast.258110.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218635_1_clone_1), dimensions={}
+ %add.251909.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123513.3.clone.1, %broadcast.258110.5.clone.1)
+ %add.251910.5.clone.1 = u32[1280,1280]{1,0} add(%add.251908.7.clone.1, %add.251909.5.clone.1)
+ %shift-left.111156.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251909.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117417.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251909.5.clone.1, %broadcast.244416.5760)
+ %or.116956.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111156.9.clone.1, %shift-right-logical.117417.9.clone.1)
+ %xor.123514.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251910.5.clone.1, %or.116956.7.clone.1)
+ %add.251911.3.clone.1 = u32[1280,1280]{1,0} add(%add.251910.5.clone.1, %xor.123514.5.clone.1)
+ %shift-left.111157.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123514.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117418.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123514.5.clone.1, %broadcast.244429.2304)
+ %or.116957.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111157.9.clone.1, %shift-right-logical.117418.9.clone.1)
+ %xor.123515.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251911.3.clone.1, %or.116957.7.clone.1)
+ %add.251912.3.clone.1 = u32[1280,1280]{1,0} add(%add.251911.3.clone.1, %xor.123515.5.clone.1)
+ %shift-left.111158.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123515.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117419.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123515.5.clone.1, %broadcast.244430.4608)
+ %or.116958.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111158.9.clone.1, %shift-right-logical.117419.9.clone.1)
+ %xor.123516.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251912.3.clone.1, %or.116958.7.clone.1)
+ %add.251913.3.clone.1 = u32[1280,1280]{1,0} add(%add.251912.3.clone.1, %xor.123516.5.clone.1)
+ %constant_189791_1_clone_1 = u32[] constant(2050113239)
+ %broadcast.258119.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_189791_1_clone_1), dimensions={}
+ %add.251914.7.clone.1 = u32[1280,1280]{1,0} add(%add.251913.3.clone.1, %broadcast.258119.24.clone.1)
+ %shift-left.111160.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123516.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117421.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123516.5.clone.1, %broadcast.244434.2816)
+ %or.116959.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111160.11.clone.1, %shift-right-logical.117421.11.clone.1)
+ %xor.123517.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251913.3.clone.1, %or.116959.9.clone.1)
+ %constant_218636_1_clone_1 = u32[] constant(1759179697)
+ %broadcast.258122.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218636_1_clone_1), dimensions={}
+ %add.251915.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123517.7.clone.1, %broadcast.258122.5.clone.1)
+ %add.251916.5.clone.1 = u32[1280,1280]{1,0} add(%add.251914.7.clone.1, %add.251915.5.clone.1)
+ %shift-left.111161.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251915.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117422.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251915.5.clone.1, %broadcast.244415.6016)
+ %or.116961.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111161.9.clone.1, %shift-right-logical.117422.9.clone.1)
+ %xor.123518.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251916.5.clone.1, %or.116961.7.clone.1)
+ %add.251917.3.clone.1 = u32[1280,1280]{1,0} add(%add.251916.5.clone.1, %xor.123518.5.clone.1)
+ %shift-left.111162.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123518.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117423.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123518.5.clone.1, %broadcast.244417.5760)
+ %or.116962.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111162.9.clone.1, %shift-right-logical.117423.9.clone.1)
+ %xor.123519.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251917.3.clone.1, %or.116962.7.clone.1)
+ %add.251918.3.clone.1 = u32[1280,1280]{1,0} add(%add.251917.3.clone.1, %xor.123519.5.clone.1)
+ %shift-left.111163.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123519.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117424.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123519.5.clone.1, %broadcast.244419.4352)
+ %or.116963.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111163.7.clone.1, %shift-right-logical.117424.7.clone.1)
+ %xor.123520.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251918.3.clone.1, %or.116963.5.clone.1)
+ %add.251919.3.clone.1 = u32[1280,1280]{1,0} add(%add.251918.3.clone.1, %xor.123520.3.clone.1)
+ %add.251920.7.clone.1 = u32[1280,1280]{1,0} add(%add.251919.3.clone.1, %broadcast.258099.44.clone.1)
+ %shift-left.111165.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123520.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117426.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123520.3.clone.1, %broadcast.244418.4352)
+ %or.116964.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111165.7.clone.1, %shift-right-logical.117426.7.clone.1)
+ %xor.123521.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251919.3.clone.1, %or.116964.5.clone.1)
+ %constant_218637_1_clone_1 = u32[] constant(154785445)
+ %broadcast.258137.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218637_1_clone_1), dimensions={}
+ %add.251921.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123521.3.clone.1, %broadcast.258137.5.clone.1)
+ %add.251922.5.clone.1 = u32[1280,1280]{1,0} add(%add.251920.7.clone.1, %add.251921.5.clone.1)
+ %shift-left.111166.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251921.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117427.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251921.5.clone.1, %broadcast.244416.5760)
+ %or.116965.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111166.9.clone.1, %shift-right-logical.117427.9.clone.1)
+ %xor.123522.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251922.5.clone.1, %or.116965.7.clone.1)
+ %add.251923.3.clone.1 = u32[1280,1280]{1,0} add(%add.251922.5.clone.1, %xor.123522.5.clone.1)
+ %shift-left.111167.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123522.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117428.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123522.5.clone.1, %broadcast.244429.2304)
+ %or.116966.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111167.9.clone.1, %shift-right-logical.117428.9.clone.1)
+ %xor.123523.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251923.3.clone.1, %or.116966.7.clone.1)
+ %add.251924.3.clone.1 = u32[1280,1280]{1,0} add(%add.251923.3.clone.1, %xor.123523.5.clone.1)
+ %shift-left.111168.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123523.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117429.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123523.5.clone.1, %broadcast.244430.4608)
+ %or.116967.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111168.9.clone.1, %shift-right-logical.117429.9.clone.1)
+ %xor.123524.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251924.3.clone.1, %or.116967.7.clone.1)
+ %add.251925.3.clone.1 = u32[1280,1280]{1,0} add(%add.251924.3.clone.1, %xor.123524.5.clone.1)
+ %add.251926.7.clone.1 = u32[1280,1280]{1,0} add(%add.251925.3.clone.1, %broadcast.258100.113.clone.1)
+ %shift-left.111170.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123524.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117431.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123524.5.clone.1, %broadcast.244434.2816)
+ %or.116968.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111170.11.clone.1, %shift-right-logical.117431.11.clone.1)
+ %xor.123525.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251925.3.clone.1, %or.116968.9.clone.1)
+ %constant_218638_1_clone_1 = u32[] constant(2050113243)
+ %broadcast.258149.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218638_1_clone_1), dimensions={}
+ %add.251927.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123525.7.clone.1, %broadcast.258149.5.clone.1)
+ %add.251928.5.clone.1 = u32[1280,1280]{1,0} add(%add.251926.7.clone.1, %add.251927.5.clone.1)
+ %shift-left.111171.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251927.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117432.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251927.5.clone.1, %broadcast.244415.6016)
+ %or.116969.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111171.9.clone.1, %shift-right-logical.117432.9.clone.1)
+ %xor.123526.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251928.5.clone.1, %or.116969.7.clone.1)
+ %add.251930.3.clone.1 = u32[1280,1280]{1,0} add(%add.251928.5.clone.1, %xor.123526.5.clone.1)
+ %shift-left.111172.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123526.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117433.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123526.5.clone.1, %broadcast.244417.5760)
+ %or.116971.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111172.9.clone.1, %shift-right-logical.117433.9.clone.1)
+ %xor.123527.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251930.3.clone.1, %or.116971.7.clone.1)
+ %add.251931.3.clone.1 = u32[1280,1280]{1,0} add(%add.251930.3.clone.1, %xor.123527.5.clone.1)
+ %shift-left.111173.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123527.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117434.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123527.5.clone.1, %broadcast.244419.4352)
+ %or.116972.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111173.5.clone.1, %shift-right-logical.117434.5.clone.1)
+ %xor.123528.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251931.3.clone.1, %or.116972.3.clone.1)
+ %add.251932.3.clone.1 = u32[1280,1280]{1,0} add(%add.251931.3.clone.1, %xor.123528.3.clone.1)
+ %add.251933.17.clone.1 = u32[1280,1280]{1,0} add(%add.251932.3.clone.1, %broadcast.258119.24.clone.1)
+ %shift-left.111175.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123528.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117436.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123528.3.clone.1, %broadcast.244418.4352)
+ %or.116973.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111175.5.clone.1, %shift-right-logical.117436.5.clone.1)
+ %xor.123529.15.clone.1 = u32[1280,1280]{1,0} xor(%add.251932.3.clone.1, %or.116973.3.clone.1)
+ %constant_218639_1_clone_1 = u32[] constant(1759179700)
+ %broadcast.258160.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218639_1_clone_1), dimensions={}
+ %add.251934.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123529.15.clone.1, %broadcast.258160.19.clone.1)
+ %xor.123530.17.clone.1 = u32[1280,1280]{1,0} xor(%add.251933.17.clone.1, %add.251934.19.clone.1)
+ %shift-right-logical.117437.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123530.17.clone.1, %broadcast.244468.1920)
+ %or.116974.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117437.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5825.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116974.13.clone.1)
+ %add.251935.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5825.11.clone.1, %broadcast.244470.1152)
+ %multiply.27190.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251935.9.clone.1, %broadcast.244471.896)
+ %add.251936.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27190.7.clone.1, %broadcast.244408.1024)
+ %maximum.3757.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.251936.5.clone.1)
+ %abs.1587.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3757.3.clone.1)
+ %compare.7336.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1587.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.27191.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3757.3.clone.1, %broadcast.244476.1152)
+ %negate.4679.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3757.3.clone.1)
+ %multiply.27192.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3757.3.clone.1, %negate.4679.5.clone.1)
+ %log-plus-one.1587.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27192.5.clone.1)
+ %negate.4680.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1587.3.clone.1)
+ %compare.7337.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4680.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21554.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21555.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21556.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21557.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21558.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21559.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21560.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21561.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21562.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.251937.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4680.4.clone.1, %broadcast.244496.640)
+ %sqrt.1587.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4680.4.clone.1)
+ %add.251938.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1587.5.clone.1, %broadcast.244498.640)
+ %select.21564.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7337.3.clone.1, %add.251937.5.clone.1, %add.251938.5.clone.1)
+ %multiply.27193.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21562.3.clone.1, %select.21564.3.clone.1)
+ %add.251939.1.clone.1 = f32[1280,1280]{1,0} add(%select.21561.3.clone.1, %multiply.27193.1.clone.1)
+ %multiply.27194.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251939.1.clone.1, %select.21564.3.clone.1)
+ %add.251940.1.clone.1 = f32[1280,1280]{1,0} add(%select.21560.3.clone.1, %multiply.27194.1.clone.1)
+ %multiply.27195.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251940.1.clone.1, %select.21564.3.clone.1)
+ %add.251941.1.clone.1 = f32[1280,1280]{1,0} add(%select.21559.3.clone.1, %multiply.27195.1.clone.1)
+ %multiply.27196.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251941.1.clone.1, %select.21564.3.clone.1)
+ %add.251942.1.clone.1 = f32[1280,1280]{1,0} add(%select.21558.3.clone.1, %multiply.27196.1.clone.1)
+ %multiply.27197.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251942.1.clone.1, %select.21564.3.clone.1)
+ %add.251943.3.clone.1 = f32[1280,1280]{1,0} add(%select.21557.5.clone.1, %multiply.27197.1.clone.1)
+ %multiply.27198.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251943.3.clone.1, %select.21564.3.clone.1)
+ %add.251944.3.clone.1 = f32[1280,1280]{1,0} add(%select.21556.5.clone.1, %multiply.27198.1.clone.1)
+ %multiply.27199.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251944.3.clone.1, %select.21564.3.clone.1)
+ %add.251945.9.clone.1 = f32[1280,1280]{1,0} add(%select.21555.11.clone.1, %multiply.27199.7.clone.1)
+ %multiply.27200.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251945.9.clone.1, %select.21564.3.clone.1)
+ %add.251946.7.clone.1 = f32[1280,1280]{1,0} add(%select.21554.7.clone.1, %multiply.27200.7.clone.1)
+ %multiply.27201.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251946.7.clone.1, %maximum.3757.3.clone.1)
+ %select.21569.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7336.3.clone.1, %multiply.27191.9.clone.1, %multiply.27201.7.clone.1)
+ %multiply.27202.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21569.7.clone.1, %broadcast.244500.640)
+ %clamp.1231.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27202.5.clone.1, %broadcast.244501.384)
+ %multiply.27203.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1231.3.clone.1, %broadcast.244502.1)
+ %constant_167071_1_clone_1 = u32[] constant(2639504033)
+ %broadcast.248291.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_167071_1_clone_1), dimensions={}
+ %add.246278.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.248291.44.clone.1)
+ %constant_167078_1_clone_1 = u32[] constant(3763626449)
+ %broadcast.248292.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_167078_1_clone_1), dimensions={}
+ %add.246279.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.248292.113.clone.1)
+ %add.246281.35.clone.1 = u32[1280,1280]{1,0} add(%add.246278.37.clone.1, %add.246279.99.clone.1)
+ %shift-left.108710.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246279.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114846.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246279.99.clone.1, %broadcast.244415.6016)
+ %or.114364.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108710.31.clone.1, %shift-right-logical.114846.29.clone.1)
+ %xor.120911.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246281.35.clone.1, %or.114364.29.clone.1)
+ %add.246285.5.clone.1 = u32[1280,1280]{1,0} add(%add.246281.35.clone.1, %xor.120911.27.clone.1)
+ %shift-left.108712.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120911.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114847.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120911.27.clone.1, %broadcast.244417.5760)
+ %or.114365.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108712.9.clone.1, %shift-right-logical.114847.9.clone.1)
+ %xor.120912.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246285.5.clone.1, %or.114365.7.clone.1)
+ %add.246286.3.clone.1 = u32[1280,1280]{1,0} add(%add.246285.5.clone.1, %xor.120912.5.clone.1)
+ %shift-left.108713.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120912.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114848.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120912.5.clone.1, %broadcast.244419.4352)
+ %or.114366.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108713.5.clone.1, %shift-right-logical.114848.5.clone.1)
+ %xor.120913.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246286.3.clone.1, %or.114366.3.clone.1)
+ %add.246287.3.clone.1 = u32[1280,1280]{1,0} add(%add.246286.3.clone.1, %xor.120913.3.clone.1)
+ %add.246288.7.clone.1 = u32[1280,1280]{1,0} add(%add.246287.3.clone.1, %broadcast.248292.113.clone.1)
+ %shift-left.108714.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120913.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114849.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120913.3.clone.1, %broadcast.244418.4352)
+ %or.114367.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108714.5.clone.1, %shift-right-logical.114849.5.clone.1)
+ %xor.120914.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246287.3.clone.1, %or.114367.3.clone.1)
+ %constant_218016_1_clone_1 = u32[] constant(1725356203)
+ %broadcast.248302.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218016_1_clone_1), dimensions={}
+ %add.246290.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120914.3.clone.1, %broadcast.248302.5.clone.1)
+ %add.246291.5.clone.1 = u32[1280,1280]{1,0} add(%add.246288.7.clone.1, %add.246290.5.clone.1)
+ %shift-left.108715.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246290.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114850.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246290.5.clone.1, %broadcast.244416.5760)
+ %or.114368.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108715.9.clone.1, %shift-right-logical.114850.9.clone.1)
+ %xor.120915.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246291.5.clone.1, %or.114368.7.clone.1)
+ %add.246292.3.clone.1 = u32[1280,1280]{1,0} add(%add.246291.5.clone.1, %xor.120915.5.clone.1)
+ %shift-left.108717.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120915.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114851.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120915.5.clone.1, %broadcast.244429.2304)
+ %or.114369.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108717.9.clone.1, %shift-right-logical.114851.9.clone.1)
+ %xor.120916.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246292.3.clone.1, %or.114369.7.clone.1)
+ %add.246293.3.clone.1 = u32[1280,1280]{1,0} add(%add.246292.3.clone.1, %xor.120916.5.clone.1)
+ %shift-left.108718.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120916.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114852.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120916.5.clone.1, %broadcast.244430.4608)
+ %or.114371.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108718.9.clone.1, %shift-right-logical.114852.9.clone.1)
+ %xor.120917.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246293.3.clone.1, %or.114371.7.clone.1)
+ %add.246295.3.clone.1 = u32[1280,1280]{1,0} add(%add.246293.3.clone.1, %xor.120917.5.clone.1)
+ %constant_167080_1_clone_1 = u32[] constant(1725356202)
+ %broadcast.248309.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_167080_1_clone_1), dimensions={}
+ %add.246296.7.clone.1 = u32[1280,1280]{1,0} add(%add.246295.3.clone.1, %broadcast.248309.24.clone.1)
+ %shift-left.108719.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120917.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114853.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120917.5.clone.1, %broadcast.244434.2816)
+ %or.114372.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108719.11.clone.1, %shift-right-logical.114853.11.clone.1)
+ %xor.120918.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246295.3.clone.1, %or.114372.9.clone.1)
+ %constant_218017_1_clone_1 = u32[] constant(2639504035)
+ %broadcast.248312.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218017_1_clone_1), dimensions={}
+ %add.246297.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120918.7.clone.1, %broadcast.248312.5.clone.1)
+ %add.246298.5.clone.1 = u32[1280,1280]{1,0} add(%add.246296.7.clone.1, %add.246297.5.clone.1)
+ %shift-left.108720.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246297.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114854.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246297.5.clone.1, %broadcast.244415.6016)
+ %or.114374.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108720.9.clone.1, %shift-right-logical.114854.9.clone.1)
+ %xor.120919.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246298.5.clone.1, %or.114374.7.clone.1)
+ %add.246300.3.clone.1 = u32[1280,1280]{1,0} add(%add.246298.5.clone.1, %xor.120919.5.clone.1)
+ %shift-left.108721.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120919.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114855.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120919.5.clone.1, %broadcast.244417.5760)
+ %or.114375.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108721.9.clone.1, %shift-right-logical.114855.9.clone.1)
+ %xor.120920.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246300.3.clone.1, %or.114375.7.clone.1)
+ %add.246301.3.clone.1 = u32[1280,1280]{1,0} add(%add.246300.3.clone.1, %xor.120920.5.clone.1)
+ %shift-left.108722.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120920.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114856.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120920.5.clone.1, %broadcast.244419.4352)
+ %or.114376.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108722.7.clone.1, %shift-right-logical.114856.7.clone.1)
+ %xor.120921.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246301.3.clone.1, %or.114376.5.clone.1)
+ %add.246302.3.clone.1 = u32[1280,1280]{1,0} add(%add.246301.3.clone.1,
%xor.120921.3.clone.1) + %add.246303.7.clone.1 = u32[1280,1280]{1,0} add(%add.246302.3.clone.1, %broadcast.248291.44.clone.1) + %shift-left.108723.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120921.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114857.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120921.3.clone.1, %broadcast.244418.4352) + %or.114377.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108723.7.clone.1, %shift-right-logical.114857.7.clone.1) + %xor.120922.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246302.3.clone.1, %or.114377.5.clone.1) + %constant_218018_1_clone_1 = u32[] constant(3763626452) + %broadcast.248322.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218018_1_clone_1), dimensions={} + %add.246304.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120922.3.clone.1, %broadcast.248322.5.clone.1) + %add.246306.5.clone.1 = u32[1280,1280]{1,0} add(%add.246303.7.clone.1, %add.246304.5.clone.1) + %shift-left.108724.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246304.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114858.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246304.5.clone.1, %broadcast.244416.5760) + %or.114378.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108724.9.clone.1, %shift-right-logical.114858.9.clone.1) + %xor.120923.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246306.5.clone.1, %or.114378.7.clone.1) + %add.246310.3.clone.1 = u32[1280,1280]{1,0} add(%add.246306.5.clone.1, %xor.120923.5.clone.1) + %shift-left.108725.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120923.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114859.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120923.5.clone.1, %broadcast.244429.2304) + %or.114379.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108725.9.clone.1, %shift-right-logical.114859.9.clone.1) + %xor.120924.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246310.3.clone.1, %or.114379.7.clone.1) + %add.246311.3.clone.1 = u32[1280,1280]{1,0} add(%add.246310.3.clone.1, %xor.120924.5.clone.1) + %shift-left.108727.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120924.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114860.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120924.5.clone.1, %broadcast.244430.4608) + %or.114380.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108727.9.clone.1, %shift-right-logical.114860.9.clone.1) + %xor.120925.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246311.3.clone.1, %or.114380.7.clone.1) + %add.246312.3.clone.1 = u32[1280,1280]{1,0} add(%add.246311.3.clone.1, %xor.120925.5.clone.1) + %add.246313.7.clone.1 = u32[1280,1280]{1,0} add(%add.246312.3.clone.1, %broadcast.248292.113.clone.1) + %shift-left.108728.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120925.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114861.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120925.5.clone.1, %broadcast.244434.2816) + %or.114381.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108728.11.clone.1, %shift-right-logical.114861.11.clone.1) + %xor.120926.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246312.3.clone.1, %or.114381.9.clone.1) + %constant_218019_1_clone_1 = u32[] constant(1725356206) + %broadcast.248332.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218019_1_clone_1), dimensions={} + %add.246315.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120926.7.clone.1, %broadcast.248332.5.clone.1) + %add.246316.5.clone.1 = u32[1280,1280]{1,0} add(%add.246313.7.clone.1, %add.246315.5.clone.1) + %shift-left.108729.9.clone.1 = 
u32[1280,1280]{1,0} shift-left(%add.246315.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114862.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246315.5.clone.1, %broadcast.244415.6016) + %or.114382.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108729.9.clone.1, %shift-right-logical.114862.9.clone.1) + %xor.120927.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246316.5.clone.1, %or.114382.7.clone.1) + %add.246317.3.clone.1 = u32[1280,1280]{1,0} add(%add.246316.5.clone.1, %xor.120927.5.clone.1) + %shift-left.108730.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120927.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114863.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120927.5.clone.1, %broadcast.244417.5760) + %or.114383.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108730.9.clone.1, %shift-right-logical.114863.9.clone.1) + %xor.120928.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246317.3.clone.1, %or.114383.7.clone.1) + %add.246318.3.clone.1 = u32[1280,1280]{1,0} add(%add.246317.3.clone.1, %xor.120928.5.clone.1) + %shift-left.108732.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120928.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114864.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120928.5.clone.1, %broadcast.244419.4352) + %or.114384.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108732.5.clone.1, %shift-right-logical.114864.5.clone.1) + %xor.120929.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246318.3.clone.1, %or.114384.3.clone.1) + %add.246320.3.clone.1 = u32[1280,1280]{1,0} add(%add.246318.3.clone.1, %xor.120929.3.clone.1) + %add.246321.17.clone.1 = u32[1280,1280]{1,0} add(%add.246320.3.clone.1, %broadcast.248309.24.clone.1) + %shift-left.108733.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120929.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114865.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120929.3.clone.1, %broadcast.244418.4352) + %or.114385.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108733.5.clone.1, %shift-right-logical.114865.5.clone.1) + %xor.120930.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246320.3.clone.1, %or.114385.3.clone.1) + %constant_218020_1_clone_1 = u32[] constant(2639504038) + %broadcast.248342.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218020_1_clone_1), dimensions={} + %add.246322.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120930.15.clone.1, %broadcast.248342.19.clone.1) + %xor.120931.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246321.17.clone.1, %add.246322.19.clone.1) + %shift-right-logical.114866.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120931.17.clone.1, %broadcast.244468.1920) + %or.114386.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114866.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5713.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114386.13.clone.1) + %add.246323.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5713.11.clone.1, %broadcast.244470.1152) + %multiply.26048.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246323.9.clone.1, %broadcast.244471.896) + %add.246325.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26048.7.clone.1, %broadcast.244408.1024) + %maximum.3645.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246325.5.clone.1) + %abs.1513.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3645.3.clone.1) + %compare.7174.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1513.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26049.9.clone.1 = f32[1280,1280]{1,0} 
multiply(%maximum.3645.3.clone.1, %broadcast.244476.1152) + %negate.4531.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3645.3.clone.1) + %multiply.26050.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3645.3.clone.1, %negate.4531.5.clone.1) + %log-plus-one.1513.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26050.5.clone.1) + %negate.4532.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1513.3.clone.1) + %compare.7175.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4532.4.clone.1, %broadcast.244477.384), direction=LT + %select.20708.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20709.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20710.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20711.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20712.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20713.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20714.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20715.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20716.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.246326.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4532.4.clone.1, %broadcast.244496.640) + %sqrt.1513.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4532.4.clone.1) + %add.246327.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1513.5.clone.1, %broadcast.244498.640) + %select.20717.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7175.3.clone.1, %add.246326.5.clone.1, %add.246327.5.clone.1) + %multiply.26051.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20716.3.clone.1, %select.20717.3.clone.1) + %add.246328.1.clone.1 = f32[1280,1280]{1,0} add(%select.20715.3.clone.1, %multiply.26051.1.clone.1) + %multiply.26052.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246328.1.clone.1, %select.20717.3.clone.1) + %add.246329.1.clone.1 = f32[1280,1280]{1,0} add(%select.20714.3.clone.1, %multiply.26052.1.clone.1) + %multiply.26053.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246329.1.clone.1, %select.20717.3.clone.1) + %add.246331.1.clone.1 = f32[1280,1280]{1,0} add(%select.20713.3.clone.1, %multiply.26053.1.clone.1) + %multiply.26054.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246331.1.clone.1, %select.20717.3.clone.1) + %add.246335.1.clone.1 = f32[1280,1280]{1,0} add(%select.20712.3.clone.1, %multiply.26054.1.clone.1) + %multiply.26055.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246335.1.clone.1, %select.20717.3.clone.1) + %add.246336.3.clone.1 = f32[1280,1280]{1,0} add(%select.20711.5.clone.1, %multiply.26055.1.clone.1) + %multiply.26056.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246336.3.clone.1, %select.20717.3.clone.1) + %add.246337.3.clone.1 = f32[1280,1280]{1,0} add(%select.20710.5.clone.1, %multiply.26056.1.clone.1) + %multiply.26057.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246337.3.clone.1, %select.20717.3.clone.1) + %add.246338.9.clone.1 = f32[1280,1280]{1,0} add(%select.20709.11.clone.1, %multiply.26057.7.clone.1) + 
%multiply.26058.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246338.9.clone.1, %select.20717.3.clone.1) + %add.246340.7.clone.1 = f32[1280,1280]{1,0} add(%select.20708.7.clone.1, %multiply.26058.7.clone.1) + %multiply.26059.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246340.7.clone.1, %maximum.3645.3.clone.1) + %select.20718.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7174.3.clone.1, %multiply.26049.9.clone.1, %multiply.26059.7.clone.1) + %multiply.26060.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20718.7.clone.1, %broadcast.244500.640) + %clamp.1157.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26060.5.clone.1, %broadcast.244501.384) + %multiply.26061.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1157.3.clone.1, %broadcast.244502.1) + %constant_181926_1_clone_1 = u32[] constant(2048903110) + %broadcast.254700.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_181926_1_clone_1), dimensions={} + %add.249942.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.254700.44.clone.1) + %constant_181933_1_clone_1 = u32[] constant(733765324) + %broadcast.254701.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_181933_1_clone_1), dimensions={} + %add.249943.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.254701.113.clone.1) + %add.249944.35.clone.1 = u32[1280,1280]{1,0} add(%add.249942.37.clone.1, %add.249943.99.clone.1) + %shift-left.110280.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249943.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116515.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249943.99.clone.1, %broadcast.244415.6016) + %or.116034.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110280.31.clone.1, %shift-right-logical.116515.29.clone.1) + %xor.122601.27.clone.1 = u32[1280,1280]{1,0} xor(%add.249944.35.clone.1, %or.116034.29.clone.1) + %add.249945.5.clone.1 = u32[1280,1280]{1,0} add(%add.249944.35.clone.1, %xor.122601.27.clone.1) + %shift-left.110281.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122601.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116516.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122601.27.clone.1, %broadcast.244417.5760) + %or.116035.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110281.9.clone.1, %shift-right-logical.116516.9.clone.1) + %xor.122602.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249945.5.clone.1, %or.116035.7.clone.1) + %add.249947.3.clone.1 = u32[1280,1280]{1,0} add(%add.249945.5.clone.1, %xor.122602.5.clone.1) + %shift-left.110282.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122602.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116517.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122602.5.clone.1, %broadcast.244419.4352) + %or.116036.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110282.5.clone.1, %shift-right-logical.116517.5.clone.1) + %xor.122603.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249947.3.clone.1, %or.116036.3.clone.1) + %add.249950.3.clone.1 = u32[1280,1280]{1,0} add(%add.249947.3.clone.1, %xor.122603.3.clone.1) + %add.249951.7.clone.1 = u32[1280,1280]{1,0} add(%add.249950.3.clone.1, %broadcast.254701.113.clone.1) + %shift-left.110283.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122603.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116518.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122603.3.clone.1, %broadcast.244418.4352) + %or.116037.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110283.5.clone.1, %shift-right-logical.116518.5.clone.1) + 
%xor.122604.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249950.3.clone.1, %or.116037.3.clone.1) + %constant_218419_1_clone_1 = u32[] constant(1249018577) + %broadcast.254711.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218419_1_clone_1), dimensions={} + %add.249952.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122604.3.clone.1, %broadcast.254711.5.clone.1) + %add.249953.5.clone.1 = u32[1280,1280]{1,0} add(%add.249951.7.clone.1, %add.249952.5.clone.1) + %shift-left.110284.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249952.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116520.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249952.5.clone.1, %broadcast.244416.5760) + %or.116038.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110284.9.clone.1, %shift-right-logical.116520.9.clone.1) + %xor.122605.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249953.5.clone.1, %or.116038.7.clone.1) + %add.249955.3.clone.1 = u32[1280,1280]{1,0} add(%add.249953.5.clone.1, %xor.122605.5.clone.1) + %shift-left.110285.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122605.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116521.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122605.5.clone.1, %broadcast.244429.2304) + %or.116039.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110285.9.clone.1, %shift-right-logical.116521.9.clone.1) + %xor.122606.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249955.3.clone.1, %or.116039.7.clone.1) + %add.249956.3.clone.1 = u32[1280,1280]{1,0} add(%add.249955.3.clone.1, %xor.122606.5.clone.1) + %shift-left.110286.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122606.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116522.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122606.5.clone.1, %broadcast.244430.4608) + %or.116040.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110286.9.clone.1, %shift-right-logical.116522.9.clone.1) + %xor.122607.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249956.3.clone.1, %or.116040.7.clone.1) + %add.249957.3.clone.1 = u32[1280,1280]{1,0} add(%add.249956.3.clone.1, %xor.122607.5.clone.1) + %constant_181935_1_clone_1 = u32[] constant(1249018576) + %broadcast.254718.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_181935_1_clone_1), dimensions={} + %add.249958.7.clone.1 = u32[1280,1280]{1,0} add(%add.249957.3.clone.1, %broadcast.254718.24.clone.1) + %shift-left.110287.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122607.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116523.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122607.5.clone.1, %broadcast.244434.2816) + %or.116041.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110287.11.clone.1, %shift-right-logical.116523.11.clone.1) + %xor.122608.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249957.3.clone.1, %or.116041.9.clone.1) + %constant_218420_1_clone_1 = u32[] constant(2048903112) + %broadcast.254721.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218420_1_clone_1), dimensions={} + %add.249960.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122608.7.clone.1, %broadcast.254721.5.clone.1) + %add.249961.5.clone.1 = u32[1280,1280]{1,0} add(%add.249958.7.clone.1, %add.249960.5.clone.1) + %shift-left.110288.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249960.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116524.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249960.5.clone.1, %broadcast.244415.6016) + %or.116042.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110288.9.clone.1, %shift-right-logical.116524.9.clone.1) + 
%xor.122609.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249961.5.clone.1, %or.116042.7.clone.1) + %add.249962.3.clone.1 = u32[1280,1280]{1,0} add(%add.249961.5.clone.1, %xor.122609.5.clone.1) + %shift-left.110289.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122609.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116525.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122609.5.clone.1, %broadcast.244417.5760) + %or.116043.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110289.9.clone.1, %shift-right-logical.116525.9.clone.1) + %xor.122610.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249962.3.clone.1, %or.116043.7.clone.1) + %add.249963.3.clone.1 = u32[1280,1280]{1,0} add(%add.249962.3.clone.1, %xor.122610.5.clone.1) + %shift-left.110290.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122610.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116526.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122610.5.clone.1, %broadcast.244419.4352) + %or.116044.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110290.7.clone.1, %shift-right-logical.116526.7.clone.1) + %xor.122611.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249963.3.clone.1, %or.116044.5.clone.1) + %add.249965.3.clone.1 = u32[1280,1280]{1,0} add(%add.249963.3.clone.1, %xor.122611.3.clone.1) + %add.249966.7.clone.1 = u32[1280,1280]{1,0} add(%add.249965.3.clone.1, %broadcast.254700.44.clone.1) + %shift-left.110291.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122611.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116527.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122611.3.clone.1, %broadcast.244418.4352) + %or.116046.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110291.7.clone.1, %shift-right-logical.116527.7.clone.1) + %xor.122612.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249965.3.clone.1, %or.116046.5.clone.1) + %constant_218421_1_clone_1 = u32[] constant(733765327) + %broadcast.254731.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218421_1_clone_1), dimensions={} + %add.249967.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122612.3.clone.1, %broadcast.254731.5.clone.1) + %add.249968.5.clone.1 = u32[1280,1280]{1,0} add(%add.249966.7.clone.1, %add.249967.5.clone.1) + %shift-left.110292.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249967.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116528.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249967.5.clone.1, %broadcast.244416.5760) + %or.116047.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110292.9.clone.1, %shift-right-logical.116528.9.clone.1) + %xor.122613.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249968.5.clone.1, %or.116047.7.clone.1) + %add.249969.3.clone.1 = u32[1280,1280]{1,0} add(%add.249968.5.clone.1, %xor.122613.5.clone.1) + %shift-left.110293.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122613.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116530.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122613.5.clone.1, %broadcast.244429.2304) + %or.116049.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110293.9.clone.1, %shift-right-logical.116530.9.clone.1) + %xor.122614.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249969.3.clone.1, %or.116049.7.clone.1) + %add.249971.3.clone.1 = u32[1280,1280]{1,0} add(%add.249969.3.clone.1, %xor.122614.5.clone.1) + %shift-left.110294.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122614.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116531.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122614.5.clone.1, %broadcast.244430.4608) + 
%or.116050.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110294.9.clone.1, %shift-right-logical.116531.9.clone.1) + %xor.122615.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249971.3.clone.1, %or.116050.7.clone.1) + %add.249975.3.clone.1 = u32[1280,1280]{1,0} add(%add.249971.3.clone.1, %xor.122615.5.clone.1) + %add.249976.7.clone.1 = u32[1280,1280]{1,0} add(%add.249975.3.clone.1, %broadcast.254701.113.clone.1) + %shift-left.110295.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122615.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116532.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122615.5.clone.1, %broadcast.244434.2816) + %or.116051.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110295.11.clone.1, %shift-right-logical.116532.11.clone.1) + %xor.122616.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249975.3.clone.1, %or.116051.9.clone.1) + %constant_218422_1_clone_1 = u32[] constant(1249018580) + %broadcast.254741.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218422_1_clone_1), dimensions={} + %add.249977.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122616.7.clone.1, %broadcast.254741.5.clone.1) + %add.249978.5.clone.1 = u32[1280,1280]{1,0} add(%add.249976.7.clone.1, %add.249977.5.clone.1) + %shift-left.110296.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249977.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116533.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249977.5.clone.1, %broadcast.244415.6016) + %or.116052.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110296.9.clone.1, %shift-right-logical.116533.9.clone.1) + %xor.122617.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249978.5.clone.1, %or.116052.7.clone.1) + %add.249980.3.clone.1 = u32[1280,1280]{1,0} add(%add.249978.5.clone.1, %xor.122617.5.clone.1) + %shift-left.110297.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122617.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116535.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122617.5.clone.1, %broadcast.244417.5760) + %or.116053.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110297.9.clone.1, %shift-right-logical.116535.9.clone.1) + %xor.122618.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249980.3.clone.1, %or.116053.7.clone.1) + %add.249981.3.clone.1 = u32[1280,1280]{1,0} add(%add.249980.3.clone.1, %xor.122618.5.clone.1) + %shift-left.110298.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122618.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116536.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122618.5.clone.1, %broadcast.244419.4352) + %or.116054.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110298.5.clone.1, %shift-right-logical.116536.5.clone.1) + %xor.122619.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249981.3.clone.1, %or.116054.3.clone.1) + %add.249982.3.clone.1 = u32[1280,1280]{1,0} add(%add.249981.3.clone.1, %xor.122619.3.clone.1) + %add.249983.17.clone.1 = u32[1280,1280]{1,0} add(%add.249982.3.clone.1, %broadcast.254718.24.clone.1) + %shift-left.110299.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122619.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116537.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122619.3.clone.1, %broadcast.244418.4352) + %or.116055.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110299.5.clone.1, %shift-right-logical.116537.5.clone.1) + %xor.122620.15.clone.1 = u32[1280,1280]{1,0} xor(%add.249982.3.clone.1, %or.116055.3.clone.1) + %constant_218423_1_clone_1 = u32[] constant(2048903115) + %broadcast.254751.19.clone.1 = u32[1280,1280]{1,0} 
broadcast(%constant_218423_1_clone_1), dimensions={} + %add.249985.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122620.15.clone.1, %broadcast.254751.19.clone.1) + %xor.122621.17.clone.1 = u32[1280,1280]{1,0} xor(%add.249983.17.clone.1, %add.249985.19.clone.1) + %shift-right-logical.116538.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122621.17.clone.1, %broadcast.244468.1920) + %or.116056.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116538.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5786.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116056.13.clone.1) + %add.249986.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5786.11.clone.1, %broadcast.244470.1152) + %multiply.26797.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249986.9.clone.1, %broadcast.244471.896) + %add.249987.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26797.7.clone.1, %broadcast.244408.1024) + %maximum.3718.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.249987.5.clone.1) + %abs.1562.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3718.3.clone.1) + %compare.7273.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1562.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26798.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3718.3.clone.1, %broadcast.244476.1152) + %negate.4629.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3718.3.clone.1) + %multiply.26799.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3718.3.clone.1, %negate.4629.5.clone.1) + %log-plus-one.1562.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26799.5.clone.1) + %negate.4630.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1562.3.clone.1) + %compare.7274.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4630.4.clone.1, %broadcast.244477.384), direction=LT + %select.21268.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21269.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21270.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21271.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21272.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21273.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21274.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21275.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21276.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.249988.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4630.4.clone.1, %broadcast.244496.640) + %sqrt.1562.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4630.4.clone.1) + %add.249990.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1562.5.clone.1, %broadcast.244498.640) + %select.21277.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7274.3.clone.1, %add.249988.5.clone.1, %add.249990.5.clone.1) + %multiply.26800.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21276.3.clone.1, %select.21277.3.clone.1) + %add.249991.1.clone.1 = f32[1280,1280]{1,0} add(%select.21275.3.clone.1, %multiply.26800.1.clone.1) + 
%multiply.26801.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249991.1.clone.1, %select.21277.3.clone.1) + %add.249992.1.clone.1 = f32[1280,1280]{1,0} add(%select.21274.3.clone.1, %multiply.26801.1.clone.1) + %multiply.26802.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249992.1.clone.1, %select.21277.3.clone.1) + %add.249993.1.clone.1 = f32[1280,1280]{1,0} add(%select.21273.3.clone.1, %multiply.26802.1.clone.1) + %multiply.26803.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249993.1.clone.1, %select.21277.3.clone.1) + %add.249994.1.clone.1 = f32[1280,1280]{1,0} add(%select.21272.3.clone.1, %multiply.26803.1.clone.1) + %multiply.26804.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249994.1.clone.1, %select.21277.3.clone.1) + %add.249996.3.clone.1 = f32[1280,1280]{1,0} add(%select.21271.5.clone.1, %multiply.26804.1.clone.1) + %multiply.26805.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249996.3.clone.1, %select.21277.3.clone.1) + %add.250000.3.clone.1 = f32[1280,1280]{1,0} add(%select.21270.5.clone.1, %multiply.26805.1.clone.1) + %multiply.26806.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250000.3.clone.1, %select.21277.3.clone.1) + %add.250001.9.clone.1 = f32[1280,1280]{1,0} add(%select.21269.11.clone.1, %multiply.26806.7.clone.1) + %multiply.26807.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250001.9.clone.1, %select.21277.3.clone.1) + %add.250002.7.clone.1 = f32[1280,1280]{1,0} add(%select.21268.7.clone.1, %multiply.26807.7.clone.1) + %multiply.26808.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.250002.7.clone.1, %maximum.3718.3.clone.1) + %select.21278.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7273.3.clone.1, %multiply.26798.9.clone.1, %multiply.26808.7.clone.1) + %multiply.26809.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21278.7.clone.1, %broadcast.244500.640) + %clamp.1206.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26809.5.clone.1, %broadcast.244501.384) + %multiply.26810.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1206.3.clone.1, %broadcast.244502.1) + %constant_166861_1_clone_1 = u32[] constant(2479642663) + %broadcast.248189.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_166861_1_clone_1), dimensions={} + %add.246228.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.248189.44.clone.1) + %constant_166868_1_clone_1 = u32[] constant(3220179177) + %broadcast.248191.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_166868_1_clone_1), dimensions={} + %add.246229.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.248191.113.clone.1) + %add.246230.35.clone.1 = u32[1280,1280]{1,0} add(%add.246228.37.clone.1, %add.246229.99.clone.1) + %shift-left.108687.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246229.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.114825.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246229.99.clone.1, %broadcast.244415.6016) + %or.114343.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108687.31.clone.1, %shift-right-logical.114825.29.clone.1) + %xor.120890.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246230.35.clone.1, %or.114343.29.clone.1) + %add.246231.5.clone.1 = u32[1280,1280]{1,0} add(%add.246230.35.clone.1, %xor.120890.27.clone.1) + %shift-left.108688.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120890.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.114826.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120890.27.clone.1, %broadcast.244417.5760) + %or.114344.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108688.9.clone.1, 
%shift-right-logical.114826.9.clone.1) + %xor.120891.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246231.5.clone.1, %or.114344.7.clone.1) + %add.246232.3.clone.1 = u32[1280,1280]{1,0} add(%add.246231.5.clone.1, %xor.120891.5.clone.1) + %shift-left.108689.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120891.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114827.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120891.5.clone.1, %broadcast.244419.4352) + %or.114345.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108689.5.clone.1, %shift-right-logical.114827.5.clone.1) + %xor.120892.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246232.3.clone.1, %or.114345.3.clone.1) + %add.246233.3.clone.1 = u32[1280,1280]{1,0} add(%add.246232.3.clone.1, %xor.120892.3.clone.1) + %add.246234.7.clone.1 = u32[1280,1280]{1,0} add(%add.246233.3.clone.1, %broadcast.248191.113.clone.1) + %shift-left.108690.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120892.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114828.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120892.3.clone.1, %broadcast.244418.4352) + %or.114346.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108690.5.clone.1, %shift-right-logical.114828.5.clone.1) + %xor.120893.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246233.3.clone.1, %or.114346.3.clone.1) + %constant_218011_1_clone_1 = u32[] constant(938298133) + %broadcast.248201.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218011_1_clone_1), dimensions={} + %add.246235.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120893.3.clone.1, %broadcast.248201.5.clone.1) + %add.246236.5.clone.1 = u32[1280,1280]{1,0} add(%add.246234.7.clone.1, %add.246235.5.clone.1) + %shift-left.108692.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246235.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114829.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246235.5.clone.1, %broadcast.244416.5760) + %or.114347.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108692.9.clone.1, %shift-right-logical.114829.9.clone.1) + %xor.120894.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246236.5.clone.1, %or.114347.7.clone.1) + %add.246237.3.clone.1 = u32[1280,1280]{1,0} add(%add.246236.5.clone.1, %xor.120894.5.clone.1) + %shift-left.108693.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120894.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114830.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120894.5.clone.1, %broadcast.244429.2304) + %or.114348.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108693.9.clone.1, %shift-right-logical.114830.9.clone.1) + %xor.120895.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246237.3.clone.1, %or.114348.7.clone.1) + %add.246238.3.clone.1 = u32[1280,1280]{1,0} add(%add.246237.3.clone.1, %xor.120895.5.clone.1) + %shift-left.108694.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120895.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114831.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120895.5.clone.1, %broadcast.244430.4608) + %or.114349.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108694.9.clone.1, %shift-right-logical.114831.9.clone.1) + %xor.120896.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246238.3.clone.1, %or.114349.7.clone.1) + %add.246239.3.clone.1 = u32[1280,1280]{1,0} add(%add.246238.3.clone.1, %xor.120896.5.clone.1) + %constant_166870_1_clone_1 = u32[] constant(938298132) + %broadcast.248208.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_166870_1_clone_1), dimensions={} + %add.246240.7.clone.1 = u32[1280,1280]{1,0} 
add(%add.246239.3.clone.1, %broadcast.248208.24.clone.1) + %shift-left.108695.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120896.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114832.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120896.5.clone.1, %broadcast.244434.2816) + %or.114350.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108695.11.clone.1, %shift-right-logical.114832.11.clone.1) + %xor.120897.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246239.3.clone.1, %or.114350.9.clone.1) + %constant_218012_1_clone_1 = u32[] constant(2479642665) + %broadcast.248211.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218012_1_clone_1), dimensions={} + %add.246241.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120897.7.clone.1, %broadcast.248211.5.clone.1) + %add.246242.5.clone.1 = u32[1280,1280]{1,0} add(%add.246240.7.clone.1, %add.246241.5.clone.1) + %shift-left.108696.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246241.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114833.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246241.5.clone.1, %broadcast.244415.6016) + %or.114351.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108696.9.clone.1, %shift-right-logical.114833.9.clone.1) + %xor.120898.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246242.5.clone.1, %or.114351.7.clone.1) + %add.246243.3.clone.1 = u32[1280,1280]{1,0} add(%add.246242.5.clone.1, %xor.120898.5.clone.1) + %shift-left.108697.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120898.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114834.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120898.5.clone.1, %broadcast.244417.5760) + %or.114352.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108697.9.clone.1, %shift-right-logical.114834.9.clone.1) + %xor.120899.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246243.3.clone.1, %or.114352.7.clone.1) + %add.246244.3.clone.1 = u32[1280,1280]{1,0} add(%add.246243.3.clone.1, %xor.120899.5.clone.1) + %shift-left.108698.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120899.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114835.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120899.5.clone.1, %broadcast.244419.4352) + %or.114353.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108698.7.clone.1, %shift-right-logical.114835.7.clone.1) + %xor.120900.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246244.3.clone.1, %or.114353.5.clone.1) + %add.246245.3.clone.1 = u32[1280,1280]{1,0} add(%add.246244.3.clone.1, %xor.120900.3.clone.1) + %add.246246.7.clone.1 = u32[1280,1280]{1,0} add(%add.246245.3.clone.1, %broadcast.248189.44.clone.1) + %shift-left.108699.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120900.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114836.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120900.3.clone.1, %broadcast.244418.4352) + %or.114354.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108699.7.clone.1, %shift-right-logical.114836.7.clone.1) + %xor.120901.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246245.3.clone.1, %or.114354.5.clone.1) + %constant_218013_1_clone_1 = u32[] constant(3220179180) + %broadcast.248223.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218013_1_clone_1), dimensions={} + %add.246247.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120901.3.clone.1, %broadcast.248223.5.clone.1) + %add.246248.5.clone.1 = u32[1280,1280]{1,0} add(%add.246246.7.clone.1, %add.246247.5.clone.1) + %shift-left.108700.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246247.5.clone.1, %broadcast.244417.5760) + 
%shift-right-logical.114837.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246247.5.clone.1, %broadcast.244416.5760) + %or.114355.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108700.9.clone.1, %shift-right-logical.114837.9.clone.1) + %xor.120902.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246248.5.clone.1, %or.114355.7.clone.1) + %add.246249.3.clone.1 = u32[1280,1280]{1,0} add(%add.246248.5.clone.1, %xor.120902.5.clone.1) + %shift-left.108702.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120902.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114838.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120902.5.clone.1, %broadcast.244429.2304) + %or.114356.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108702.9.clone.1, %shift-right-logical.114838.9.clone.1) + %xor.120903.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246249.3.clone.1, %or.114356.7.clone.1) + %add.246250.3.clone.1 = u32[1280,1280]{1,0} add(%add.246249.3.clone.1, %xor.120903.5.clone.1) + %shift-left.108703.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120903.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114839.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120903.5.clone.1, %broadcast.244430.4608) + %or.114357.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108703.9.clone.1, %shift-right-logical.114839.9.clone.1) + %xor.120904.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246250.3.clone.1, %or.114357.7.clone.1) + %add.246251.3.clone.1 = u32[1280,1280]{1,0} add(%add.246250.3.clone.1, %xor.120904.5.clone.1) + %add.246252.7.clone.1 = u32[1280,1280]{1,0} add(%add.246251.3.clone.1, %broadcast.248191.113.clone.1) + %shift-left.108704.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120904.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114840.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120904.5.clone.1, %broadcast.244434.2816) + %or.114358.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108704.11.clone.1, %shift-right-logical.114840.11.clone.1) + %xor.120905.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246251.3.clone.1, %or.114358.9.clone.1) + %constant_218014_1_clone_1 = u32[] constant(938298136) + %broadcast.248238.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218014_1_clone_1), dimensions={} + %add.246253.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120905.7.clone.1, %broadcast.248238.5.clone.1) + %add.246254.5.clone.1 = u32[1280,1280]{1,0} add(%add.246252.7.clone.1, %add.246253.5.clone.1) + %shift-left.108705.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246253.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114841.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246253.5.clone.1, %broadcast.244415.6016) + %or.114359.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108705.9.clone.1, %shift-right-logical.114841.9.clone.1) + %xor.120906.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246254.5.clone.1, %or.114359.7.clone.1) + %add.246255.3.clone.1 = u32[1280,1280]{1,0} add(%add.246254.5.clone.1, %xor.120906.5.clone.1) + %shift-left.108707.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120906.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114842.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120906.5.clone.1, %broadcast.244417.5760) + %or.114360.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108707.9.clone.1, %shift-right-logical.114842.9.clone.1) + %xor.120907.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246255.3.clone.1, %or.114360.7.clone.1) + %add.246257.3.clone.1 = u32[1280,1280]{1,0} add(%add.246255.3.clone.1, %xor.120907.5.clone.1) + 
%shift-left.108708.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120907.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114843.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120907.5.clone.1, %broadcast.244419.4352) + %or.114361.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108708.5.clone.1, %shift-right-logical.114843.5.clone.1) + %xor.120908.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246257.3.clone.1, %or.114361.3.clone.1) + %add.246260.3.clone.1 = u32[1280,1280]{1,0} add(%add.246257.3.clone.1, %xor.120908.3.clone.1) + %add.246261.17.clone.1 = u32[1280,1280]{1,0} add(%add.246260.3.clone.1, %broadcast.248208.24.clone.1) + %shift-left.108709.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120908.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114844.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120908.3.clone.1, %broadcast.244418.4352) + %or.114362.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108709.5.clone.1, %shift-right-logical.114844.5.clone.1) + %xor.120909.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246260.3.clone.1, %or.114362.3.clone.1) + %constant_218015_1_clone_1 = u32[] constant(2479642668) + %broadcast.248250.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218015_1_clone_1), dimensions={} + %add.246262.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120909.15.clone.1, %broadcast.248250.19.clone.1) + %xor.120910.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246261.17.clone.1, %add.246262.19.clone.1) + %shift-right-logical.114845.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120910.17.clone.1, %broadcast.244468.1920) + %or.114363.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114845.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5712.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114363.13.clone.1) + %add.246263.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5712.11.clone.1, %broadcast.244470.1152) + %multiply.26034.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246263.9.clone.1, %broadcast.244471.896) + %add.246265.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26034.7.clone.1, %broadcast.244408.1024) + %maximum.3644.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246265.5.clone.1) + %abs.1512.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3644.3.clone.1) + %compare.7172.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1512.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26035.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3644.3.clone.1, %broadcast.244476.1152) + %negate.4529.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3644.3.clone.1) + %multiply.26036.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3644.3.clone.1, %negate.4529.5.clone.1) + %log-plus-one.1512.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26036.5.clone.1) + %negate.4530.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1512.3.clone.1) + %compare.7173.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4530.4.clone.1, %broadcast.244477.384), direction=LT + %select.20697.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20698.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20699.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20700.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20701.3.clone.1 = 
f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.20702.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.20703.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.20704.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.20705.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.246266.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4530.4.clone.1, %broadcast.244496.640)
+ %sqrt.1512.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4530.4.clone.1)
+ %add.246267.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1512.5.clone.1, %broadcast.244498.640)
+ %select.20706.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7173.3.clone.1, %add.246266.5.clone.1, %add.246267.5.clone.1)
+ %multiply.26037.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20705.3.clone.1, %select.20706.3.clone.1)
+ %add.246268.1.clone.1 = f32[1280,1280]{1,0} add(%select.20704.3.clone.1, %multiply.26037.1.clone.1)
+ %multiply.26038.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246268.1.clone.1, %select.20706.3.clone.1)
+ %add.246270.1.clone.1 = f32[1280,1280]{1,0} add(%select.20703.3.clone.1, %multiply.26038.1.clone.1)
+ %multiply.26039.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246270.1.clone.1, %select.20706.3.clone.1)
+ %add.246271.1.clone.1 = f32[1280,1280]{1,0} add(%select.20702.3.clone.1, %multiply.26039.1.clone.1)
+ %multiply.26040.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246271.1.clone.1, %select.20706.3.clone.1)
+ %add.246272.1.clone.1 = f32[1280,1280]{1,0} add(%select.20701.3.clone.1, %multiply.26040.1.clone.1)
+ %multiply.26041.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246272.1.clone.1, %select.20706.3.clone.1)
+ %add.246273.3.clone.1 = f32[1280,1280]{1,0} add(%select.20700.5.clone.1, %multiply.26041.1.clone.1)
+ %multiply.26042.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246273.3.clone.1, %select.20706.3.clone.1)
+ %add.246275.3.clone.1 = f32[1280,1280]{1,0} add(%select.20699.5.clone.1, %multiply.26042.1.clone.1)
+ %multiply.26043.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246275.3.clone.1, %select.20706.3.clone.1)
+ %add.246276.9.clone.1 = f32[1280,1280]{1,0} add(%select.20698.11.clone.1, %multiply.26043.7.clone.1)
+ %multiply.26044.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246276.9.clone.1, %select.20706.3.clone.1)
+ %add.246277.7.clone.1 = f32[1280,1280]{1,0} add(%select.20697.7.clone.1, %multiply.26044.7.clone.1)
+ %multiply.26045.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246277.7.clone.1, %maximum.3644.3.clone.1)
+ %select.20707.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7172.3.clone.1, %multiply.26035.9.clone.1, %multiply.26045.7.clone.1)
+ %multiply.26046.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20707.7.clone.1, %broadcast.244500.640)
+ %clamp.1156.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26046.5.clone.1, %broadcast.244501.384)
+ %multiply.26047.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1156.3.clone.1, %broadcast.244502.1)
+ %constant_193448_1_clone_1 = u32[] constant(3979012375)
+ %broadcast.259680.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_193448_1_clone_1), dimensions={}
+ %add.252799.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.259680.44.clone.1)
+ %constant_193455_1_clone_1 = u32[] constant(824634254)
+ %broadcast.259681.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_193455_1_clone_1), dimensions={}
+ %add.252801.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.259681.113.clone.1)
+ %add.252802.35.clone.1 = u32[1280,1280]{1,0} add(%add.252799.37.clone.1, %add.252801.99.clone.1)
+ %shift-left.111534.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252801.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117834.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252801.99.clone.1, %broadcast.244415.6016)
+ %or.117358.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111534.31.clone.1, %shift-right-logical.117834.29.clone.1)
+ %xor.123911.27.clone.1 = u32[1280,1280]{1,0} xor(%add.252802.35.clone.1, %or.117358.29.clone.1)
+ %add.252803.5.clone.1 = u32[1280,1280]{1,0} add(%add.252802.35.clone.1, %xor.123911.27.clone.1)
+ %shift-left.111535.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123911.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117835.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123911.27.clone.1, %broadcast.244417.5760)
+ %or.117359.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111535.9.clone.1, %shift-right-logical.117835.9.clone.1)
+ %xor.123912.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252803.5.clone.1, %or.117359.7.clone.1)
+ %add.252804.3.clone.1 = u32[1280,1280]{1,0} add(%add.252803.5.clone.1, %xor.123912.5.clone.1)
+ %shift-left.111537.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123912.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117836.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123912.5.clone.1, %broadcast.244419.4352)
+ %or.117360.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111537.5.clone.1, %shift-right-logical.117836.5.clone.1)
+ %xor.123913.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252804.3.clone.1, %or.117360.3.clone.1)
+ %add.252806.3.clone.1 = u32[1280,1280]{1,0} add(%add.252804.3.clone.1, %xor.123913.3.clone.1)
+ %add.252807.7.clone.1 = u32[1280,1280]{1,0} add(%add.252806.3.clone.1, %broadcast.259681.113.clone.1)
+ %shift-left.111538.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123913.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117837.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123913.3.clone.1, %broadcast.244418.4352)
+ %or.117361.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111538.5.clone.1, %shift-right-logical.117837.5.clone.1)
+ %xor.123914.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252806.3.clone.1, %or.117361.3.clone.1)
+ %constant_218725_1_clone_1 = u32[] constant(3353155908)
+ %broadcast.259691.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218725_1_clone_1), dimensions={}
+ %add.252808.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123914.3.clone.1, %broadcast.259691.5.clone.1)
+ %add.252809.5.clone.1 = u32[1280,1280]{1,0} add(%add.252807.7.clone.1, %add.252808.5.clone.1)
+ %shift-left.111539.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252808.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117838.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252808.5.clone.1, %broadcast.244416.5760)
+ %or.117362.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111539.9.clone.1, %shift-right-logical.117838.9.clone.1)
+ %xor.123915.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252809.5.clone.1, %or.117362.7.clone.1)
+ %add.252810.3.clone.1 = u32[1280,1280]{1,0} add(%add.252809.5.clone.1, %xor.123915.5.clone.1)
+ %shift-left.111540.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123915.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117839.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123915.5.clone.1, %broadcast.244429.2304)
+ %or.117363.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111540.9.clone.1, %shift-right-logical.117839.9.clone.1)
+ %xor.123916.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252810.3.clone.1, %or.117363.7.clone.1)
+ %add.252812.3.clone.1 = u32[1280,1280]{1,0} add(%add.252810.3.clone.1, %xor.123916.5.clone.1)
+ %shift-left.111542.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123916.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117840.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123916.5.clone.1, %broadcast.244430.4608)
+ %or.117364.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111542.9.clone.1, %shift-right-logical.117840.9.clone.1)
+ %xor.123917.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252812.3.clone.1, %or.117364.7.clone.1)
+ %add.252816.3.clone.1 = u32[1280,1280]{1,0} add(%add.252812.3.clone.1, %xor.123917.5.clone.1)
+ %constant_193457_1_clone_1 = u32[] constant(3353155907)
+ %broadcast.259703.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_193457_1_clone_1), dimensions={}
+ %add.252817.7.clone.1 = u32[1280,1280]{1,0} add(%add.252816.3.clone.1, %broadcast.259703.24.clone.1)
+ %shift-left.111543.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123917.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117841.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123917.5.clone.1, %broadcast.244434.2816)
+ %or.117365.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111543.11.clone.1, %shift-right-logical.117841.11.clone.1)
+ %xor.123918.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252816.3.clone.1, %or.117365.9.clone.1)
+ %constant_218726_1_clone_1 = u32[] constant(3979012377)
+ %broadcast.259709.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218726_1_clone_1), dimensions={}
+ %add.252818.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123918.7.clone.1, %broadcast.259709.5.clone.1)
+ %add.252819.5.clone.1 = u32[1280,1280]{1,0} add(%add.252817.7.clone.1, %add.252818.5.clone.1)
+ %shift-left.111544.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252818.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117842.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252818.5.clone.1, %broadcast.244415.6016)
+ %or.117366.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111544.9.clone.1, %shift-right-logical.117842.9.clone.1)
+ %xor.123919.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252819.5.clone.1, %or.117366.7.clone.1)
+ %add.252821.3.clone.1 = u32[1280,1280]{1,0} add(%add.252819.5.clone.1, %xor.123919.5.clone.1)
+ %shift-left.111545.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123919.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117843.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123919.5.clone.1, %broadcast.244417.5760)
+ %or.117367.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111545.9.clone.1, %shift-right-logical.117843.9.clone.1)
+ %xor.123920.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252821.3.clone.1, %or.117367.7.clone.1)
+ %add.252822.3.clone.1 = u32[1280,1280]{1,0} add(%add.252821.3.clone.1, %xor.123920.5.clone.1)
+ %shift-left.111546.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123920.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117844.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123920.5.clone.1, %broadcast.244419.4352)
+ %or.117368.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111546.7.clone.1, %shift-right-logical.117844.7.clone.1)
+ %xor.123921.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252822.3.clone.1, %or.117368.5.clone.1)
+ %add.252823.3.clone.1 = u32[1280,1280]{1,0} add(%add.252822.3.clone.1, %xor.123921.3.clone.1)
+ %add.252824.7.clone.1 = u32[1280,1280]{1,0} add(%add.252823.3.clone.1, %broadcast.259680.44.clone.1)
+ %shift-left.111547.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123921.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117845.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123921.3.clone.1, %broadcast.244418.4352)
+ %or.117369.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111547.7.clone.1, %shift-right-logical.117845.7.clone.1)
+ %xor.123922.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252823.3.clone.1, %or.117369.5.clone.1)
+ %constant_218727_1_clone_1 = u32[] constant(824634257)
+ %broadcast.259727.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218727_1_clone_1), dimensions={}
+ %add.252826.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123922.3.clone.1, %broadcast.259727.5.clone.1)
+ %add.252827.5.clone.1 = u32[1280,1280]{1,0} add(%add.252824.7.clone.1, %add.252826.5.clone.1)
+ %shift-left.111548.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252826.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117846.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252826.5.clone.1, %broadcast.244416.5760)
+ %or.117370.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111548.9.clone.1, %shift-right-logical.117846.9.clone.1)
+ %xor.123923.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252827.5.clone.1, %or.117370.7.clone.1)
+ %add.252828.3.clone.1 = u32[1280,1280]{1,0} add(%add.252827.5.clone.1, %xor.123923.5.clone.1)
+ %shift-left.111549.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123923.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117847.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123923.5.clone.1, %broadcast.244429.2304)
+ %or.117371.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111549.9.clone.1, %shift-right-logical.117847.9.clone.1)
+ %xor.123924.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252828.3.clone.1, %or.117371.7.clone.1)
+ %add.252829.3.clone.1 = u32[1280,1280]{1,0} add(%add.252828.3.clone.1, %xor.123924.5.clone.1)
+ %shift-left.111550.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123924.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117848.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123924.5.clone.1, %broadcast.244430.4608)
+ %or.117372.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111550.9.clone.1, %shift-right-logical.117848.9.clone.1)
+ %xor.123925.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252829.3.clone.1, %or.117372.7.clone.1)
+ %add.252831.3.clone.1 = u32[1280,1280]{1,0} add(%add.252829.3.clone.1, %xor.123925.5.clone.1)
+ %add.252832.7.clone.1 = u32[1280,1280]{1,0} add(%add.252831.3.clone.1, %broadcast.259681.113.clone.1)
+ %shift-left.111552.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123925.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117849.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123925.5.clone.1, %broadcast.244434.2816)
+ %or.117373.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111552.11.clone.1, %shift-right-logical.117849.11.clone.1)
+ %xor.123926.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252831.3.clone.1, %or.117373.9.clone.1)
+ %constant_218728_1_clone_1 = u32[] constant(3353155911)
+ %broadcast.259738.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218728_1_clone_1), dimensions={}
+ %add.252833.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123926.7.clone.1, %broadcast.259738.5.clone.1)
+ %add.252834.5.clone.1 = u32[1280,1280]{1,0} add(%add.252832.7.clone.1, %add.252833.5.clone.1)
+ %shift-left.111553.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252833.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117850.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252833.5.clone.1, %broadcast.244415.6016)
+ %or.117374.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111553.9.clone.1, %shift-right-logical.117850.9.clone.1)
+ %xor.123927.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252834.5.clone.1, %or.117374.7.clone.1)
+ %add.252835.3.clone.1 = u32[1280,1280]{1,0} add(%add.252834.5.clone.1, %xor.123927.5.clone.1)
+ %shift-left.111554.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123927.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117851.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123927.5.clone.1, %broadcast.244417.5760)
+ %or.117375.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111554.9.clone.1, %shift-right-logical.117851.9.clone.1)
+ %xor.123928.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252835.3.clone.1, %or.117375.7.clone.1)
+ %add.252837.3.clone.1 = u32[1280,1280]{1,0} add(%add.252835.3.clone.1, %xor.123928.5.clone.1)
+ %shift-left.111555.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123928.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117852.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123928.5.clone.1, %broadcast.244419.4352)
+ %or.117377.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111555.5.clone.1, %shift-right-logical.117852.5.clone.1)
+ %xor.123929.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252837.3.clone.1, %or.117377.3.clone.1)
+ %add.252841.3.clone.1 = u32[1280,1280]{1,0} add(%add.252837.3.clone.1, %xor.123929.3.clone.1)
+ %add.252842.17.clone.1 = u32[1280,1280]{1,0} add(%add.252841.3.clone.1, %broadcast.259703.24.clone.1)
+ %shift-left.111557.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123929.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117853.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123929.3.clone.1, %broadcast.244418.4352)
+ %or.117378.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111557.5.clone.1, %shift-right-logical.117853.5.clone.1)
+ %xor.123930.15.clone.1 = u32[1280,1280]{1,0} xor(%add.252841.3.clone.1, %or.117378.3.clone.1)
+ %constant_218730_1_clone_1 = u32[] constant(3979012380)
+ %broadcast.259748.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218730_1_clone_1), dimensions={}
+ %add.252843.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123930.15.clone.1, %broadcast.259748.19.clone.1)
+ %xor.123933.17.clone.1 = u32[1280,1280]{1,0} xor(%add.252842.17.clone.1, %add.252843.19.clone.1)
+ %shift-right-logical.117854.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123933.17.clone.1, %broadcast.244468.1920)
+ %or.117380.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117854.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5843.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117380.13.clone.1)
+ %add.252844.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5843.11.clone.1, %broadcast.244470.1152)
+ %multiply.27370.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252844.9.clone.1, %broadcast.244471.896)
+ %add.252846.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27370.7.clone.1, %broadcast.244408.1024)
+ %maximum.3775.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.252846.5.clone.1)
+ %abs.1599.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3775.3.clone.1)
+ %compare.7360.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1599.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.27371.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3775.3.clone.1, %broadcast.244476.1152)
+ %negate.4703.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3775.3.clone.1)
+ %multiply.27372.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3775.3.clone.1, %negate.4703.5.clone.1)
+ %log-plus-one.1599.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27372.5.clone.1)
+ %negate.4704.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1599.3.clone.1)
+ %compare.7361.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4704.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21696.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21697.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21698.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21699.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21700.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21701.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21702.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21703.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21704.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.252847.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4704.4.clone.1, %broadcast.244496.640)
+ %sqrt.1599.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4704.4.clone.1)
+ %add.252848.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1599.5.clone.1, %broadcast.244498.640)
+ %select.21705.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7361.3.clone.1, %add.252847.5.clone.1, %add.252848.5.clone.1)
+ %multiply.27373.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21704.3.clone.1, %select.21705.3.clone.1)
+ %add.252849.1.clone.1 = f32[1280,1280]{1,0} add(%select.21703.3.clone.1, %multiply.27373.1.clone.1)
+ %multiply.27374.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252849.1.clone.1, %select.21705.3.clone.1)
+ %add.252851.1.clone.1 = f32[1280,1280]{1,0} add(%select.21702.3.clone.1, %multiply.27374.1.clone.1)
+ %multiply.27375.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252851.1.clone.1, %select.21705.3.clone.1)
+ %add.252852.1.clone.1 = f32[1280,1280]{1,0} add(%select.21701.3.clone.1, %multiply.27375.1.clone.1)
+ %multiply.27376.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252852.1.clone.1, %select.21705.3.clone.1)
+ %add.252853.1.clone.1 = f32[1280,1280]{1,0} add(%select.21700.3.clone.1, %multiply.27376.1.clone.1)
+ %multiply.27377.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252853.1.clone.1, %select.21705.3.clone.1)
+ %add.252854.3.clone.1 = f32[1280,1280]{1,0} add(%select.21699.5.clone.1, %multiply.27377.1.clone.1)
+ %multiply.27378.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252854.3.clone.1, %select.21705.3.clone.1)
+ %add.252856.3.clone.1 = f32[1280,1280]{1,0} add(%select.21698.5.clone.1, %multiply.27378.1.clone.1)
+ %multiply.27379.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252856.3.clone.1, %select.21705.3.clone.1)
+ %add.252857.9.clone.1 = f32[1280,1280]{1,0} add(%select.21697.11.clone.1, %multiply.27379.7.clone.1)
+ %multiply.27380.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252857.9.clone.1, %select.21705.3.clone.1)
+ %add.252858.7.clone.1 = f32[1280,1280]{1,0} add(%select.21696.7.clone.1, %multiply.27380.7.clone.1)
+ %multiply.27381.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252858.7.clone.1, %maximum.3775.3.clone.1)
+ %select.21706.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7360.3.clone.1, %multiply.27371.9.clone.1, %multiply.27381.7.clone.1)
+ %multiply.27382.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21706.7.clone.1, %broadcast.244500.640)
+ %clamp.1243.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27382.5.clone.1, %broadcast.244501.384)
+ %multiply.27383.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1243.3.clone.1, %broadcast.244502.1)
+ %constant_166629_1_clone_1 = u32[] constant(4239085431)
+ %broadcast.248096.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_166629_1_clone_1), dimensions={}
+ %add.246176.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.248096.44.clone.1)
+ %constant_166636_1_clone_1 = u32[] constant(142116863)
+ %broadcast.248097.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_166636_1_clone_1), dimensions={}
+ %add.246177.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.248097.113.clone.1)
+ %add.246178.35.clone.1 = u32[1280,1280]{1,0} add(%add.246176.37.clone.1, %add.246177.99.clone.1)
+ %shift-left.108663.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246177.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114804.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246177.99.clone.1, %broadcast.244415.6016)
+ %or.114322.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108663.31.clone.1, %shift-right-logical.114804.29.clone.1)
+ %xor.120869.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246178.35.clone.1, %or.114322.29.clone.1)
+ %add.246179.5.clone.1 = u32[1280,1280]{1,0} add(%add.246178.35.clone.1, %xor.120869.27.clone.1)
+ %shift-left.108664.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120869.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114805.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120869.27.clone.1, %broadcast.244417.5760)
+ %or.114323.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108664.9.clone.1, %shift-right-logical.114805.9.clone.1)
+ %xor.120870.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246179.5.clone.1, %or.114323.7.clone.1)
+ %add.246180.3.clone.1 = u32[1280,1280]{1,0} add(%add.246179.5.clone.1, %xor.120870.5.clone.1)
+ %shift-left.108665.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120870.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114806.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120870.5.clone.1, %broadcast.244419.4352)
+ %or.114324.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108665.5.clone.1, %shift-right-logical.114806.5.clone.1)
+ %xor.120871.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246180.3.clone.1, %or.114324.3.clone.1)
+ %add.246181.3.clone.1 = u32[1280,1280]{1,0} add(%add.246180.3.clone.1, %xor.120871.3.clone.1)
+ %add.246182.7.clone.1 = u32[1280,1280]{1,0} add(%add.246181.3.clone.1, %broadcast.248097.113.clone.1)
+ %shift-left.108667.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120871.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114807.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120871.3.clone.1, %broadcast.244418.4352)
+ %or.114325.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108667.5.clone.1, %shift-right-logical.114807.5.clone.1)
+ %xor.120872.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246181.3.clone.1, %or.114325.3.clone.1)
+ %constant_218006_1_clone_1 = u32[] constant(4009939795)
+ %broadcast.248107.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218006_1_clone_1), dimensions={}
+ %add.246183.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120872.3.clone.1, %broadcast.248107.5.clone.1)
+ %add.246184.5.clone.1 = u32[1280,1280]{1,0} add(%add.246182.7.clone.1, %add.246183.5.clone.1)
+ %shift-left.108668.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246183.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114808.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246183.5.clone.1, %broadcast.244416.5760)
+ %or.114326.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108668.9.clone.1, %shift-right-logical.114808.9.clone.1)
+ %xor.120873.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246184.5.clone.1, %or.114326.7.clone.1)
+ %add.246185.3.clone.1 = u32[1280,1280]{1,0} add(%add.246184.5.clone.1, %xor.120873.5.clone.1)
+ %shift-left.108669.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120873.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114809.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120873.5.clone.1, %broadcast.244429.2304)
+ %or.114327.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108669.9.clone.1, %shift-right-logical.114809.9.clone.1)
+ %xor.120874.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246185.3.clone.1, %or.114327.7.clone.1)
+ %add.246186.3.clone.1 = u32[1280,1280]{1,0} add(%add.246185.3.clone.1, %xor.120874.5.clone.1)
+ %shift-left.108670.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120874.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114810.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120874.5.clone.1, %broadcast.244430.4608)
+ %or.114328.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108670.9.clone.1, %shift-right-logical.114810.9.clone.1)
+ %xor.120875.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246186.3.clone.1, %or.114328.7.clone.1)
+ %add.246187.3.clone.1 = u32[1280,1280]{1,0} add(%add.246186.3.clone.1, %xor.120875.5.clone.1)
+ %constant_166638_1_clone_1 = u32[] constant(4009939794)
+ %broadcast.248116.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_166638_1_clone_1), dimensions={}
+ %add.246188.7.clone.1 = u32[1280,1280]{1,0} add(%add.246187.3.clone.1, %broadcast.248116.24.clone.1)
+ %shift-left.108671.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120875.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114811.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120875.5.clone.1, %broadcast.244434.2816)
+ %or.114329.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108671.11.clone.1, %shift-right-logical.114811.11.clone.1)
+ %xor.120876.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246187.3.clone.1, %or.114329.9.clone.1)
+ %constant_218007_1_clone_1 = u32[] constant(4239085433)
+ %broadcast.248119.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218007_1_clone_1), dimensions={}
+ %add.246189.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120876.7.clone.1, %broadcast.248119.5.clone.1)
+ %add.246190.5.clone.1 = u32[1280,1280]{1,0} add(%add.246188.7.clone.1, %add.246189.5.clone.1)
+ %shift-left.108672.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246189.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114812.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246189.5.clone.1, %broadcast.244415.6016)
+ %or.114330.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108672.9.clone.1, %shift-right-logical.114812.9.clone.1)
+ %xor.120877.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246190.5.clone.1, %or.114330.7.clone.1)
+ %add.246191.3.clone.1 = u32[1280,1280]{1,0} add(%add.246190.5.clone.1, %xor.120877.5.clone.1)
+ %shift-left.108673.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120877.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114813.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120877.5.clone.1, %broadcast.244417.5760)
+ %or.114331.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108673.9.clone.1, %shift-right-logical.114813.9.clone.1)
+ %xor.120878.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246191.3.clone.1, %or.114331.7.clone.1)
+ %add.246192.3.clone.1 = u32[1280,1280]{1,0} add(%add.246191.3.clone.1, %xor.120878.5.clone.1)
+ %shift-left.108674.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120878.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114814.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120878.5.clone.1, %broadcast.244419.4352)
+ %or.114332.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108674.7.clone.1, %shift-right-logical.114814.7.clone.1)
+ %xor.120879.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246192.3.clone.1, %or.114332.5.clone.1)
+ %add.246193.3.clone.1 = u32[1280,1280]{1,0} add(%add.246192.3.clone.1, %xor.120879.3.clone.1)
+ %add.246194.7.clone.1 = u32[1280,1280]{1,0} add(%add.246193.3.clone.1, %broadcast.248096.44.clone.1)
+ %shift-left.108675.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120879.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114815.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120879.3.clone.1, %broadcast.244418.4352)
+ %or.114333.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108675.7.clone.1, %shift-right-logical.114815.7.clone.1)
+ %xor.120880.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246193.3.clone.1, %or.114333.5.clone.1)
+ %constant_218008_1_clone_1 = u32[] constant(142116866)
+ %broadcast.248129.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218008_1_clone_1), dimensions={}
+ %add.246195.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120880.3.clone.1, %broadcast.248129.5.clone.1)
+ %add.246196.5.clone.1 = u32[1280,1280]{1,0} add(%add.246194.7.clone.1, %add.246195.5.clone.1)
+ %shift-left.108677.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246195.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114816.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246195.5.clone.1, %broadcast.244416.5760)
+ %or.114334.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108677.9.clone.1, %shift-right-logical.114816.9.clone.1)
+ %xor.120881.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246196.5.clone.1, %or.114334.7.clone.1)
+ %add.246197.3.clone.1 = u32[1280,1280]{1,0} add(%add.246196.5.clone.1, %xor.120881.5.clone.1)
+ %shift-left.108678.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120881.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114817.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120881.5.clone.1, %broadcast.244429.2304)
+ %or.114335.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108678.9.clone.1, %shift-right-logical.114817.9.clone.1)
+ %xor.120882.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246197.3.clone.1, %or.114335.7.clone.1)
+ %add.246198.3.clone.1 = u32[1280,1280]{1,0} add(%add.246197.3.clone.1, %xor.120882.5.clone.1)
+ %shift-left.108679.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120882.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114818.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120882.5.clone.1, %broadcast.244430.4608)
+ %or.114336.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108679.9.clone.1, %shift-right-logical.114818.9.clone.1)
+ %xor.120883.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246198.3.clone.1, %or.114336.7.clone.1)
+ %add.246200.3.clone.1 = u32[1280,1280]{1,0} add(%add.246198.3.clone.1, %xor.120883.5.clone.1)
+ %add.246201.7.clone.1 = u32[1280,1280]{1,0} add(%add.246200.3.clone.1, %broadcast.248097.113.clone.1)
+ %shift-left.108680.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120883.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114819.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120883.5.clone.1, %broadcast.244434.2816)
+ %or.114337.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108680.11.clone.1, %shift-right-logical.114819.11.clone.1)
+ %xor.120884.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246200.3.clone.1, %or.114337.9.clone.1)
+ %constant_218009_1_clone_1 = u32[] constant(4009939798)
+ %broadcast.248141.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218009_1_clone_1), dimensions={}
+ %add.246203.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120884.7.clone.1, %broadcast.248141.5.clone.1)
+ %add.246204.5.clone.1 = u32[1280,1280]{1,0} add(%add.246201.7.clone.1, %add.246203.5.clone.1)
+ %shift-left.108682.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246203.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114820.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246203.5.clone.1, %broadcast.244415.6016)
+ %or.114338.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108682.9.clone.1, %shift-right-logical.114820.9.clone.1)
+ %xor.120885.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246204.5.clone.1, %or.114338.7.clone.1)
+ %add.246206.3.clone.1 = u32[1280,1280]{1,0} add(%add.246204.5.clone.1, %xor.120885.5.clone.1)
+ %shift-left.108683.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120885.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114821.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120885.5.clone.1, %broadcast.244417.5760)
+ %or.114339.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108683.9.clone.1, %shift-right-logical.114821.9.clone.1)
+ %xor.120886.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246206.3.clone.1, %or.114339.7.clone.1)
+ %add.246207.3.clone.1 = u32[1280,1280]{1,0} add(%add.246206.3.clone.1, %xor.120886.5.clone.1)
+ %shift-left.108684.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120886.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114822.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120886.5.clone.1, %broadcast.244419.4352)
+ %or.114340.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108684.5.clone.1, %shift-right-logical.114822.5.clone.1)
+ %xor.120887.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246207.3.clone.1, %or.114340.3.clone.1)
+ %add.246209.3.clone.1 = u32[1280,1280]{1,0} add(%add.246207.3.clone.1, %xor.120887.3.clone.1)
+ %add.246210.17.clone.1 = u32[1280,1280]{1,0} add(%add.246209.3.clone.1, %broadcast.248116.24.clone.1)
+ %shift-left.108685.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120887.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114823.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120887.3.clone.1, %broadcast.244418.4352)
+ %or.114341.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108685.5.clone.1, %shift-right-logical.114823.5.clone.1)
+ %xor.120888.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246209.3.clone.1, %or.114341.3.clone.1)
+ %constant_218010_1_clone_1 = u32[] constant(4239085436)
+ %broadcast.248151.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218010_1_clone_1), dimensions={}
+ %add.246212.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120888.15.clone.1, %broadcast.248151.19.clone.1)
+ %xor.120889.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246210.17.clone.1, %add.246212.19.clone.1)
+ %shift-right-logical.114824.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120889.17.clone.1, %broadcast.244468.1920)
+ %or.114342.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114824.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5711.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114342.13.clone.1)
+ %add.246213.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5711.11.clone.1, %broadcast.244470.1152)
+ %multiply.26020.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246213.9.clone.1, %broadcast.244471.896)
+ %add.246215.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26020.7.clone.1, %broadcast.244408.1024)
+ %maximum.3643.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246215.5.clone.1)
+ %abs.1511.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3643.3.clone.1)
+ %compare.7170.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1511.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26021.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3643.3.clone.1, %broadcast.244476.1152)
+ %negate.4527.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3643.3.clone.1)
+ %multiply.26022.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3643.3.clone.1, %negate.4527.5.clone.1)
+ %log-plus-one.1511.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26022.5.clone.1)
+ %negate.4528.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1511.3.clone.1)
+ %compare.7171.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4528.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.20686.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.20687.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.20688.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.20689.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.20690.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.20691.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.20692.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.20693.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.20694.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.246216.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4528.4.clone.1, %broadcast.244496.640)
+ %sqrt.1511.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4528.4.clone.1)
+ %add.246218.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1511.5.clone.1, %broadcast.244498.640)
+ %select.20695.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7171.3.clone.1, %add.246216.5.clone.1, %add.246218.5.clone.1)
+ %multiply.26023.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20694.3.clone.1, %select.20695.3.clone.1)
+ %add.246219.1.clone.1 = f32[1280,1280]{1,0} add(%select.20693.3.clone.1, %multiply.26023.1.clone.1)
+ %multiply.26024.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246219.1.clone.1, %select.20695.3.clone.1)
+ %add.246221.1.clone.1 = f32[1280,1280]{1,0} add(%select.20692.3.clone.1, %multiply.26024.1.clone.1)
+ %multiply.26025.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246221.1.clone.1, %select.20695.3.clone.1)
+ %add.246222.1.clone.1 = f32[1280,1280]{1,0} add(%select.20691.3.clone.1, %multiply.26025.1.clone.1)
+ %multiply.26026.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246222.1.clone.1, %select.20695.3.clone.1)
+ %add.246223.1.clone.1 = f32[1280,1280]{1,0} add(%select.20690.3.clone.1, %multiply.26026.1.clone.1)
+ %multiply.26027.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246223.1.clone.1, %select.20695.3.clone.1)
+ %add.246224.3.clone.1 = f32[1280,1280]{1,0} add(%select.20689.5.clone.1, %multiply.26027.1.clone.1)
+ %multiply.26028.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246224.3.clone.1, %select.20695.3.clone.1)
+ %add.246225.3.clone.1 = f32[1280,1280]{1,0} add(%select.20688.5.clone.1, %multiply.26028.1.clone.1)
+ %multiply.26029.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246225.3.clone.1, %select.20695.3.clone.1)
+ %add.246226.9.clone.1 = f32[1280,1280]{1,0} add(%select.20687.11.clone.1, %multiply.26029.7.clone.1)
+ %multiply.26030.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246226.9.clone.1, %select.20695.3.clone.1)
+ %add.246227.7.clone.1 = f32[1280,1280]{1,0} add(%select.20686.7.clone.1, %multiply.26030.7.clone.1)
+ %multiply.26031.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246227.7.clone.1, %maximum.3643.3.clone.1)
+ %select.20696.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7170.3.clone.1, %multiply.26021.9.clone.1, %multiply.26031.7.clone.1)
+ %multiply.26032.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20696.7.clone.1, %broadcast.244500.640)
+ %clamp.1155.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26032.5.clone.1, %broadcast.244501.384)
+ %multiply.26033.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1155.3.clone.1, %broadcast.244502.1)
+ %constant_181713_1_clone_1 = u32[] constant(1982119221)
+ %broadcast.254614.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_181713_1_clone_1), dimensions={}
+ %add.249897.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.254614.44.clone.1)
+ %constant_181723_1_clone_1 = u32[] constant(1790761844)
+ %broadcast.254615.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_181723_1_clone_1), dimensions={}
+ %add.249898.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.254615.113.clone.1)
+ %add.249899.35.clone.1 = u32[1280,1280]{1,0} add(%add.249897.37.clone.1, %add.249898.99.clone.1)
+ %shift-left.110260.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249898.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116490.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249898.99.clone.1, %broadcast.244415.6016)
+ %or.116011.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110260.31.clone.1, %shift-right-logical.116490.29.clone.1)
+ %xor.122580.27.clone.1 = u32[1280,1280]{1,0} xor(%add.249899.35.clone.1, %or.116011.29.clone.1)
+ %add.249900.5.clone.1 = u32[1280,1280]{1,0} add(%add.249899.35.clone.1, %xor.122580.27.clone.1)
+ %shift-left.110261.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122580.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116491.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122580.27.clone.1, %broadcast.244417.5760)
+ %or.116012.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110261.9.clone.1, %shift-right-logical.116491.9.clone.1)
+ %xor.122581.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249900.5.clone.1, %or.116012.7.clone.1)
+ %add.249901.3.clone.1 = u32[1280,1280]{1,0} add(%add.249900.5.clone.1, %xor.122581.5.clone.1)
+ %shift-left.110262.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122581.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116492.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122581.5.clone.1, %broadcast.244419.4352)
+ %or.116013.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110262.5.clone.1, %shift-right-logical.116492.5.clone.1)
+ %xor.122582.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249901.3.clone.1, %or.116013.3.clone.1)
+ %add.249902.3.clone.1 = u32[1280,1280]{1,0} add(%add.249901.3.clone.1, %xor.122582.3.clone.1)
+ %add.249903.7.clone.1 = u32[1280,1280]{1,0} add(%add.249902.3.clone.1, %broadcast.254615.113.clone.1)
+ %shift-left.110263.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122582.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116493.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122582.3.clone.1, %broadcast.244418.4352)
+ %or.116014.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110263.5.clone.1, %shift-right-logical.116493.5.clone.1)
+ %xor.122583.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249902.3.clone.1, %or.116014.3.clone.1)
+ %constant_218414_1_clone_1 = u32[] constant(122256796)
+ %broadcast.254625.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218414_1_clone_1), dimensions={}
+ %add.249904.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122583.3.clone.1, %broadcast.254625.5.clone.1)
+ %add.249905.5.clone.1 = u32[1280,1280]{1,0} add(%add.249903.7.clone.1, %add.249904.5.clone.1)
+ %shift-left.110264.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249904.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.116495.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249904.5.clone.1, %broadcast.244416.5760)
+ %or.116015.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110264.9.clone.1, %shift-right-logical.116495.9.clone.1)
+ %xor.122584.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249905.5.clone.1, %or.116015.7.clone.1)
+ %add.249906.3.clone.1 = u32[1280,1280]{1,0} add(%add.249905.5.clone.1, %xor.122584.5.clone.1)
+ %shift-left.110265.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122584.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.116496.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122584.5.clone.1, %broadcast.244429.2304)
+ %or.116016.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110265.9.clone.1, %shift-right-logical.116496.9.clone.1)
+ %xor.122585.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249906.3.clone.1, %or.116016.7.clone.1)
+ %add.249907.3.clone.1 = u32[1280,1280]{1,0} add(%add.249906.3.clone.1, %xor.122585.5.clone.1)
+ %shift-left.110266.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122585.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.116497.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122585.5.clone.1, %broadcast.244430.4608)
+ %or.116017.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110266.9.clone.1, %shift-right-logical.116497.9.clone.1)
+ %xor.122586.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249907.3.clone.1, %or.116017.7.clone.1)
+ %add.249908.3.clone.1 = u32[1280,1280]{1,0} add(%add.249907.3.clone.1, %xor.122586.5.clone.1)
+ %constant_181725_1_clone_1 = u32[] constant(122256795)
+ %broadcast.254632.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_181725_1_clone_1), dimensions={}
+ %add.249909.7.clone.1 = u32[1280,1280]{1,0} add(%add.249908.3.clone.1, %broadcast.254632.24.clone.1)
+ %shift-left.110267.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122586.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.116498.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122586.5.clone.1, %broadcast.244434.2816)
+ %or.116018.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110267.11.clone.1, %shift-right-logical.116498.11.clone.1)
+ %xor.122587.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249908.3.clone.1, %or.116018.9.clone.1)
+ %constant_218415_1_clone_1 = u32[] constant(1982119223)
+ %broadcast.254635.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218415_1_clone_1), dimensions={}
+ %add.249910.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122587.7.clone.1, %broadcast.254635.5.clone.1)
+ %add.249911.5.clone.1 = u32[1280,1280]{1,0} add(%add.249909.7.clone.1, %add.249910.5.clone.1)
+ %shift-left.110268.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249910.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116499.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249910.5.clone.1, %broadcast.244415.6016)
+ %or.116019.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110268.9.clone.1, %shift-right-logical.116499.9.clone.1)
+ %xor.122588.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249911.5.clone.1, %or.116019.7.clone.1)
+ %add.249912.3.clone.1 = u32[1280,1280]{1,0} add(%add.249911.5.clone.1, %xor.122588.5.clone.1)
+ %shift-left.110269.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122588.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116500.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122588.5.clone.1, %broadcast.244417.5760)
+ %or.116020.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110269.9.clone.1, %shift-right-logical.116500.9.clone.1)
+ %xor.122589.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249912.3.clone.1, %or.116020.7.clone.1)
+ %add.249913.3.clone.1 = u32[1280,1280]{1,0} add(%add.249912.3.clone.1, %xor.122589.5.clone.1)
+ %shift-left.110270.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122589.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116501.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122589.5.clone.1, %broadcast.244419.4352)
+ %or.116021.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110270.7.clone.1, %shift-right-logical.116501.7.clone.1)
+ %xor.122590.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249913.3.clone.1, %or.116021.5.clone.1)
+ %add.249914.3.clone.1 = u32[1280,1280]{1,0} add(%add.249913.3.clone.1, %xor.122590.3.clone.1)
+ %add.249916.7.clone.1 = u32[1280,1280]{1,0} add(%add.249914.3.clone.1, %broadcast.254614.44.clone.1)
+ %shift-left.110271.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122590.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116502.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122590.3.clone.1, %broadcast.244418.4352)
+ %or.116022.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110271.7.clone.1, %shift-right-logical.116502.7.clone.1)
+ %xor.122591.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249914.3.clone.1, %or.116022.5.clone.1)
+ %constant_218416_1_clone_1 = u32[] constant(1790761847)
+ %broadcast.254645.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218416_1_clone_1), dimensions={}
+ %add.249917.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122591.3.clone.1, %broadcast.254645.5.clone.1)
+ %add.249918.5.clone.1 = u32[1280,1280]{1,0} add(%add.249916.7.clone.1, %add.249917.5.clone.1)
+ %shift-left.110272.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249917.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.116503.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249917.5.clone.1, %broadcast.244416.5760)
+ %or.116024.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110272.9.clone.1, %shift-right-logical.116503.9.clone.1)
+ %xor.122592.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249918.5.clone.1, %or.116024.7.clone.1)
+ %add.249919.3.clone.1 = u32[1280,1280]{1,0} add(%add.249918.5.clone.1, %xor.122592.5.clone.1)
+ %shift-left.110273.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122592.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.116505.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122592.5.clone.1, %broadcast.244429.2304)
+ %or.116025.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110273.9.clone.1, %shift-right-logical.116505.9.clone.1)
+ %xor.122593.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249919.3.clone.1, %or.116025.7.clone.1)
+ %add.249920.3.clone.1 = u32[1280,1280]{1,0} add(%add.249919.3.clone.1, %xor.122593.5.clone.1)
+ %shift-left.110274.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122593.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.116506.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122593.5.clone.1, %broadcast.244430.4608)
+ %or.116027.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110274.9.clone.1, %shift-right-logical.116506.9.clone.1)
+ %xor.122594.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249920.3.clone.1, %or.116027.7.clone.1)
+ %add.249921.3.clone.1 = u32[1280,1280]{1,0} add(%add.249920.3.clone.1, %xor.122594.5.clone.1)
+ %add.249922.7.clone.1 = u32[1280,1280]{1,0} add(%add.249921.3.clone.1, %broadcast.254615.113.clone.1)
+ %shift-left.110275.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122594.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.116507.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122594.5.clone.1, %broadcast.244434.2816)
+ %or.116028.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110275.11.clone.1, %shift-right-logical.116507.11.clone.1)
+ %xor.122595.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249921.3.clone.1, %or.116028.9.clone.1)
+ %constant_218417_1_clone_1 = u32[] constant(122256799)
+ %broadcast.254655.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218417_1_clone_1), dimensions={}
+ %add.249923.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122595.7.clone.1, %broadcast.254655.5.clone.1)
+ %add.249924.5.clone.1 = u32[1280,1280]{1,0} add(%add.249922.7.clone.1, %add.249923.5.clone.1)
+ %shift-left.110276.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249923.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116508.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249923.5.clone.1, %broadcast.244415.6016)
+ %or.116029.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110276.9.clone.1, %shift-right-logical.116508.9.clone.1)
+ %xor.122596.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249924.5.clone.1, %or.116029.7.clone.1)
+ %add.249925.3.clone.1 = u32[1280,1280]{1,0} add(%add.249924.5.clone.1, %xor.122596.5.clone.1)
+ %shift-left.110277.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122596.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116510.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122596.5.clone.1, %broadcast.244417.5760)
+ %or.116030.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110277.9.clone.1, %shift-right-logical.116510.9.clone.1)
+ %xor.122597.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249925.3.clone.1, %or.116030.7.clone.1)
+ %add.249926.3.clone.1 = u32[1280,1280]{1,0} add(%add.249925.3.clone.1, %xor.122597.5.clone.1)
+ %shift-left.110278.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122597.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116511.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122597.5.clone.1, %broadcast.244419.4352)
+ %or.116031.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110278.5.clone.1, %shift-right-logical.116511.5.clone.1)
+ %xor.122598.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249926.3.clone.1, %or.116031.3.clone.1)
+ %add.249927.3.clone.1 = u32[1280,1280]{1,0} add(%add.249926.3.clone.1, %xor.122598.3.clone.1)
+ %add.249928.17.clone.1 = u32[1280,1280]{1,0} add(%add.249927.3.clone.1, %broadcast.254632.24.clone.1)
+ %shift-left.110279.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122598.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116512.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122598.3.clone.1, %broadcast.244418.4352)
+ %or.116032.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110279.5.clone.1, %shift-right-logical.116512.5.clone.1)
+ %xor.122599.15.clone.1 = u32[1280,1280]{1,0} xor(%add.249927.3.clone.1, %or.116032.3.clone.1)
+ %constant_218418_1_clone_1 = u32[] constant(1982119226)
+ %broadcast.254665.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218418_1_clone_1), dimensions={}
+ %add.249929.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122599.15.clone.1, %broadcast.254665.19.clone.1)
+ %xor.122600.17.clone.1 = u32[1280,1280]{1,0} xor(%add.249928.17.clone.1, %add.249929.19.clone.1)
+ %shift-right-logical.116513.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122600.17.clone.1, %broadcast.244468.1920)
+ %or.116033.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116513.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5785.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116033.13.clone.1)
+ %add.249930.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5785.11.clone.1, %broadcast.244470.1152)
+ %multiply.26783.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249930.9.clone.1, %broadcast.244471.896)
+ %add.249931.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26783.7.clone.1, %broadcast.244408.1024)
+ %maximum.3717.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.249931.5.clone.1)
+ %abs.1561.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3717.3.clone.1)
+ %compare.7271.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1561.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26784.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3717.3.clone.1, %broadcast.244476.1152)
+ %negate.4627.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3717.3.clone.1)
+ %multiply.26785.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3717.3.clone.1, %negate.4627.5.clone.1)
+ %log-plus-one.1561.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26785.5.clone.1)
+ %negate.4628.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1561.3.clone.1)
+ %compare.7272.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4628.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21257.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21258.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21259.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21260.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21261.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21262.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21263.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21264.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21265.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.249932.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4628.4.clone.1, %broadcast.244496.640)
+ %sqrt.1561.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4628.4.clone.1)
+ %add.249933.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1561.5.clone.1, %broadcast.244498.640)
+ %select.21266.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7272.3.clone.1, %add.249932.5.clone.1, %add.249933.5.clone.1)
+ %multiply.26786.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21265.3.clone.1, %select.21266.3.clone.1)
+ %add.249934.1.clone.1 = f32[1280,1280]{1,0} add(%select.21264.3.clone.1, %multiply.26786.1.clone.1)
+ %multiply.26787.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249934.1.clone.1, %select.21266.3.clone.1)
+ %add.249935.1.clone.1 = f32[1280,1280]{1,0} add(%select.21263.3.clone.1, %multiply.26787.1.clone.1)
+ %multiply.26788.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249935.1.clone.1, %select.21266.3.clone.1)
+ %add.249936.1.clone.1 = f32[1280,1280]{1,0} add(%select.21262.3.clone.1, %multiply.26788.1.clone.1)
+ %multiply.26789.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249936.1.clone.1, %select.21266.3.clone.1)
+ %add.249937.1.clone.1 = f32[1280,1280]{1,0} add(%select.21261.3.clone.1, %multiply.26789.1.clone.1)
+ %multiply.26790.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249937.1.clone.1, %select.21266.3.clone.1)
+ %add.249938.3.clone.1 = f32[1280,1280]{1,0} add(%select.21260.5.clone.1, %multiply.26790.1.clone.1)
+ %multiply.26791.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249938.3.clone.1, %select.21266.3.clone.1)
+ %add.249939.3.clone.1 = f32[1280,1280]{1,0} add(%select.21259.5.clone.1, %multiply.26791.1.clone.1)
+ %multiply.26792.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249939.3.clone.1, %select.21266.3.clone.1)
+ %add.249940.9.clone.1 = f32[1280,1280]{1,0} add(%select.21258.11.clone.1, %multiply.26792.7.clone.1)
+ %multiply.26793.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249940.9.clone.1, %select.21266.3.clone.1)
+ %add.249941.7.clone.1 = f32[1280,1280]{1,0} add(%select.21257.7.clone.1, %multiply.26793.7.clone.1)
+ %multiply.26794.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249941.7.clone.1, %maximum.3717.3.clone.1)
+ %select.21267.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7271.3.clone.1, %multiply.26784.9.clone.1, %multiply.26794.7.clone.1)
+ %multiply.26795.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21267.7.clone.1, %broadcast.244500.640)
+ %clamp.1205.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26795.5.clone.1, %broadcast.244501.384)
+ %multiply.26796.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1205.3.clone.1, %broadcast.244502.1)
+ %constant_166078_1_clone_1 = u32[] constant(1224719244)
+ %broadcast.247863.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_166078_1_clone_1), dimensions={}
+ %add.246052.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.247863.44.clone.1)
+ %constant_166085_1_clone_1 = u32[] constant(3697509252)
+ %broadcast.247864.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_166085_1_clone_1), dimensions={}
+ %add.246053.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.247864.113.clone.1)
+ %add.246054.35.clone.1 = u32[1280,1280]{1,0} add(%add.246052.37.clone.1, %add.246053.99.clone.1)
+ %shift-left.108600.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246053.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114739.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246053.99.clone.1, %broadcast.244415.6016)
+ %or.114259.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108600.31.clone.1, %shift-right-logical.114739.29.clone.1)
+ %xor.120806.27.clone.1 = u32[1280,1280]{1,0} xor(%add.246054.35.clone.1, %or.114259.29.clone.1)
+ %add.246056.5.clone.1 = u32[1280,1280]{1,0} add(%add.246054.35.clone.1, %xor.120806.27.clone.1)
+ %shift-left.108601.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120806.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114740.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120806.27.clone.1, %broadcast.244417.5760)
+ %or.114260.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108601.9.clone.1, %shift-right-logical.114740.9.clone.1)
+ %xor.120807.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246056.5.clone.1, %or.114260.7.clone.1)
+ %add.246057.3.clone.1 = u32[1280,1280]{1,0} add(%add.246056.5.clone.1, %xor.120807.5.clone.1)
+ %shift-left.108602.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120807.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114741.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120807.5.clone.1, %broadcast.244419.4352)
+ %or.114261.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108602.5.clone.1, %shift-right-logical.114741.5.clone.1)
+ %xor.120808.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246057.3.clone.1, %or.114261.3.clone.1)
+ %add.246058.3.clone.1 = u32[1280,1280]{1,0} add(%add.246057.3.clone.1, %xor.120808.3.clone.1)
+ %add.246059.7.clone.1 = u32[1280,1280]{1,0} add(%add.246058.3.clone.1, %broadcast.247864.113.clone.1)
+ %shift-left.108603.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120808.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114742.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120808.3.clone.1, %broadcast.244418.4352)
+ %or.114262.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108603.5.clone.1, %shift-right-logical.114742.5.clone.1)
+ %xor.120809.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246058.3.clone.1, %or.114262.3.clone.1)
+ %constant_217986_1_clone_1 = u32[] constant(2404197331)
+ %broadcast.247880.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217986_1_clone_1), dimensions={}
+ %add.246060.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120809.3.clone.1, %broadcast.247880.5.clone.1)
+ %add.246062.5.clone.1 = u32[1280,1280]{1,0} add(%add.246059.7.clone.1, %add.246060.5.clone.1)
+ %shift-left.108604.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246060.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114743.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246060.5.clone.1, %broadcast.244416.5760)
+ %or.114263.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108604.9.clone.1, %shift-right-logical.114743.9.clone.1)
+ %xor.120810.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246062.5.clone.1, %or.114263.7.clone.1)
+ %add.246066.3.clone.1 = u32[1280,1280]{1,0} add(%add.246062.5.clone.1, %xor.120810.5.clone.1)
+ %shift-left.108605.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120810.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114744.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120810.5.clone.1, %broadcast.244429.2304)
+ %or.114264.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108605.9.clone.1, %shift-right-logical.114744.9.clone.1)
+ %xor.120811.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246066.3.clone.1, %or.114264.7.clone.1)
+ %add.246067.3.clone.1 = u32[1280,1280]{1,0} add(%add.246066.3.clone.1, %xor.120811.5.clone.1)
+ %shift-left.108606.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120811.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114745.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120811.5.clone.1, %broadcast.244430.4608)
+ %or.114265.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108606.9.clone.1, %shift-right-logical.114745.9.clone.1)
+ %xor.120812.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246067.3.clone.1, %or.114265.7.clone.1)
+ %add.246068.3.clone.1 = u32[1280,1280]{1,0} add(%add.246067.3.clone.1, %xor.120812.5.clone.1)
+ %constant_166087_1_clone_1 = u32[] constant(2404197330)
+ %broadcast.247887.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_166087_1_clone_1), dimensions={}
+ %add.246069.7.clone.1 = u32[1280,1280]{1,0} add(%add.246068.3.clone.1, %broadcast.247887.24.clone.1)
+ %shift-left.108607.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120812.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114746.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120812.5.clone.1, %broadcast.244434.2816)
+ %or.114266.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108607.11.clone.1, %shift-right-logical.114746.11.clone.1)
+ %xor.120813.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246068.3.clone.1, %or.114266.9.clone.1)
+ %constant_217988_1_clone_1 = u32[] constant(1224719246)
+ %broadcast.247890.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217988_1_clone_1), dimensions={}
+ %add.246071.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120813.7.clone.1, %broadcast.247890.5.clone.1)
+ %add.246072.5.clone.1 = u32[1280,1280]{1,0} add(%add.246069.7.clone.1, %add.246071.5.clone.1)
+ %shift-left.108608.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246071.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114747.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246071.5.clone.1, %broadcast.244415.6016)
+ %or.114267.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108608.9.clone.1, %shift-right-logical.114747.9.clone.1)
+ %xor.120814.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246072.5.clone.1, %or.114267.7.clone.1)
+ %add.246073.3.clone.1 = u32[1280,1280]{1,0} add(%add.246072.5.clone.1, %xor.120814.5.clone.1)
+ %shift-left.108609.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120814.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114748.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120814.5.clone.1, %broadcast.244417.5760)
+ %or.114268.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108609.9.clone.1, %shift-right-logical.114748.9.clone.1)
+ %xor.120815.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246073.3.clone.1, %or.114268.7.clone.1)
+ %add.246074.3.clone.1 = u32[1280,1280]{1,0} add(%add.246073.3.clone.1, %xor.120815.5.clone.1)
+ %shift-left.108610.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120815.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114749.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120815.5.clone.1, %broadcast.244419.4352)
+ %or.114269.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108610.7.clone.1, %shift-right-logical.114749.7.clone.1)
+ %xor.120816.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246074.3.clone.1, %or.114269.5.clone.1)
+ %add.246076.3.clone.1 = u32[1280,1280]{1,0} add(%add.246074.3.clone.1, %xor.120816.3.clone.1)
+ %add.246077.7.clone.1 = u32[1280,1280]{1,0} add(%add.246076.3.clone.1, %broadcast.247863.44.clone.1)
+ %shift-left.108611.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120816.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114750.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120816.3.clone.1, %broadcast.244418.4352)
+ %or.114270.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108611.7.clone.1, %shift-right-logical.114750.7.clone.1)
+ %xor.120817.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246076.3.clone.1, %or.114270.5.clone.1)
+ %constant_217990_1_clone_1 = u32[] constant(3697509255)
+ %broadcast.247900.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217990_1_clone_1), dimensions={}
+ %add.246078.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120817.3.clone.1, %broadcast.247900.5.clone.1)
+ %add.246079.5.clone.1 = u32[1280,1280]{1,0} add(%add.246077.7.clone.1, %add.246078.5.clone.1)
+ %shift-left.108612.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246078.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114751.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246078.5.clone.1, %broadcast.244416.5760)
+ %or.114271.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108612.9.clone.1, %shift-right-logical.114751.9.clone.1)
+ %xor.120818.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246079.5.clone.1, %or.114271.7.clone.1)
+ %add.246081.3.clone.1 = u32[1280,1280]{1,0} add(%add.246079.5.clone.1, %xor.120818.5.clone.1)
+ %shift-left.108613.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120818.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114752.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120818.5.clone.1, %broadcast.244429.2304)
+ %or.114272.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108613.9.clone.1, %shift-right-logical.114752.9.clone.1)
+ %xor.120819.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246081.3.clone.1, %or.114272.7.clone.1)
+ %add.246082.3.clone.1 = u32[1280,1280]{1,0} add(%add.246081.3.clone.1, %xor.120819.5.clone.1)
+ %shift-left.108614.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120819.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114753.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120819.5.clone.1, %broadcast.244430.4608)
+ %or.114273.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108614.9.clone.1, %shift-right-logical.114753.9.clone.1)
+ %xor.120820.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246082.3.clone.1, %or.114273.7.clone.1)
+ %add.246083.3.clone.1 = u32[1280,1280]{1,0} add(%add.246082.3.clone.1, %xor.120820.5.clone.1)
+ %add.246084.7.clone.1 = u32[1280,1280]{1,0} add(%add.246083.3.clone.1, %broadcast.247864.113.clone.1)
+ %shift-left.108615.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120820.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114754.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120820.5.clone.1, %broadcast.244434.2816)
+ %or.114274.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108615.11.clone.1, %shift-right-logical.114754.11.clone.1)
+ %xor.120821.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246083.3.clone.1, %or.114274.9.clone.1)
+ %constant_217992_1_clone_1 = u32[] constant(2404197334)
+ %broadcast.247910.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217992_1_clone_1), dimensions={}
+ %add.246085.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120821.7.clone.1, %broadcast.247910.5.clone.1)
+ %add.246087.5.clone.1 = u32[1280,1280]{1,0} add(%add.246084.7.clone.1, %add.246085.5.clone.1)
+ %shift-left.108616.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246085.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114755.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246085.5.clone.1, %broadcast.244415.6016)
+ %or.114275.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108616.9.clone.1, %shift-right-logical.114755.9.clone.1)
+ %xor.120822.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246087.5.clone.1, %or.114275.7.clone.1)
+ %add.246090.3.clone.1 = u32[1280,1280]{1,0} add(%add.246087.5.clone.1, %xor.120822.5.clone.1)
+ %shift-left.108617.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120822.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114756.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120822.5.clone.1, %broadcast.244417.5760)
+ %or.114276.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108617.9.clone.1, %shift-right-logical.114756.9.clone.1)
+ %xor.120823.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246090.3.clone.1, %or.114276.7.clone.1)
+ %add.246091.3.clone.1 = u32[1280,1280]{1,0} add(%add.246090.3.clone.1, %xor.120823.5.clone.1)
+ %shift-left.108618.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120823.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114757.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120823.5.clone.1, %broadcast.244419.4352)
+ %or.114277.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108618.5.clone.1, %shift-right-logical.114757.5.clone.1)
+ %xor.120824.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246091.3.clone.1, %or.114277.3.clone.1)
+ %add.246092.3.clone.1 = u32[1280,1280]{1,0} add(%add.246091.3.clone.1, %xor.120824.3.clone.1)
+ %add.246093.17.clone.1 = u32[1280,1280]{1,0} add(%add.246092.3.clone.1, %broadcast.247887.24.clone.1)
+ %shift-left.108619.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120824.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114758.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120824.3.clone.1, %broadcast.244418.4352)
+ %or.114278.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108619.5.clone.1, %shift-right-logical.114758.5.clone.1)
+ %xor.120825.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246092.3.clone.1, %or.114278.3.clone.1)
+ %constant_217994_1_clone_1 = u32[] constant(1224719249)
+ %broadcast.247920.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217994_1_clone_1), dimensions={}
+ %add.246094.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120825.15.clone.1, %broadcast.247920.19.clone.1)
+ %xor.120826.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246093.17.clone.1, %add.246094.19.clone.1)
+ %shift-right-logical.114759.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120826.17.clone.1, %broadcast.244468.1920)
+ %or.114279.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114759.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5708.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114279.13.clone.1)
+ %add.246095.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5708.11.clone.1, %broadcast.244470.1152)
+ %multiply.26002.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246095.9.clone.1, %broadcast.244471.896)
+ %add.246096.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26002.7.clone.1, %broadcast.244408.1024)
+ %maximum.3640.3.clone.1 = f32[1280,1280]{1,0}
maximum(%broadcast.244408.1024, %add.246096.5.clone.1) + %abs.1510.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3640.3.clone.1) + %compare.7168.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1510.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26003.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3640.3.clone.1, %broadcast.244476.1152) + %negate.4525.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3640.3.clone.1) + %multiply.26004.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3640.3.clone.1, %negate.4525.5.clone.1) + %log-plus-one.1510.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26004.5.clone.1) + %negate.4526.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1510.3.clone.1) + %compare.7169.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4526.4.clone.1, %broadcast.244477.384), direction=LT + %select.20675.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20676.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20677.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20678.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20679.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20680.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20681.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20682.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20683.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.246097.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4526.4.clone.1, %broadcast.244496.640) + %sqrt.1510.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4526.4.clone.1) + %add.246098.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1510.5.clone.1, %broadcast.244498.640) + %select.20684.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7169.3.clone.1, %add.246097.5.clone.1, %add.246098.5.clone.1) + %multiply.26005.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20683.3.clone.1, %select.20684.3.clone.1) + %add.246099.1.clone.1 = f32[1280,1280]{1,0} add(%select.20682.3.clone.1, %multiply.26005.1.clone.1) + %multiply.26006.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246099.1.clone.1, %select.20684.3.clone.1) + %add.246100.1.clone.1 = f32[1280,1280]{1,0} add(%select.20681.3.clone.1, %multiply.26006.1.clone.1) + %multiply.26007.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246100.1.clone.1, %select.20684.3.clone.1) + %add.246101.1.clone.1 = f32[1280,1280]{1,0} add(%select.20680.3.clone.1, %multiply.26007.1.clone.1) + %multiply.26008.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246101.1.clone.1, %select.20684.3.clone.1) + %add.246102.1.clone.1 = f32[1280,1280]{1,0} add(%select.20679.3.clone.1, %multiply.26008.1.clone.1) + %multiply.26009.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246102.1.clone.1, %select.20684.3.clone.1) + %add.246103.3.clone.1 = f32[1280,1280]{1,0} add(%select.20678.5.clone.1, %multiply.26009.1.clone.1) + %multiply.26010.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246103.3.clone.1, %select.20684.3.clone.1) + %add.246104.3.clone.1 = 
f32[1280,1280]{1,0} add(%select.20677.5.clone.1, %multiply.26010.1.clone.1) + %multiply.26011.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246104.3.clone.1, %select.20684.3.clone.1) + %add.246105.9.clone.1 = f32[1280,1280]{1,0} add(%select.20676.11.clone.1, %multiply.26011.7.clone.1) + %multiply.26012.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246105.9.clone.1, %select.20684.3.clone.1) + %add.246106.7.clone.1 = f32[1280,1280]{1,0} add(%select.20675.7.clone.1, %multiply.26012.7.clone.1) + %multiply.26013.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246106.7.clone.1, %maximum.3640.3.clone.1) + %select.20685.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7168.3.clone.1, %multiply.26003.9.clone.1, %multiply.26013.7.clone.1) + %multiply.26014.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20685.7.clone.1, %broadcast.244500.640) + %clamp.1154.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26014.5.clone.1, %broadcast.244501.384) + %multiply.26015.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1154.3.clone.1, %broadcast.244502.1) + %constant_189231_1_clone_1 = u32[] constant(1760383797) + %broadcast.257864.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_189231_1_clone_1), dimensions={} + %add.251753.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.257864.44.clone.1) + %constant_189238_1_clone_1 = u32[] constant(3661218658) + %broadcast.257865.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_189238_1_clone_1), dimensions={} + %add.251754.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.257865.113.clone.1) + %add.251755.35.clone.1 = u32[1280,1280]{1,0} add(%add.251753.37.clone.1, %add.251754.99.clone.1) + %shift-left.111080.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251754.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.117350.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251754.99.clone.1, %broadcast.244415.6016) + %or.116876.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111080.31.clone.1, %shift-right-logical.117350.29.clone.1) + %xor.123435.27.clone.1 = u32[1280,1280]{1,0} xor(%add.251755.35.clone.1, %or.116876.29.clone.1) + %add.251756.5.clone.1 = u32[1280,1280]{1,0} add(%add.251755.35.clone.1, %xor.123435.27.clone.1) + %shift-left.111081.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123435.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.117351.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123435.27.clone.1, %broadcast.244417.5760) + %or.116877.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111081.9.clone.1, %shift-right-logical.117351.9.clone.1) + %xor.123436.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251756.5.clone.1, %or.116877.7.clone.1) + %add.251757.3.clone.1 = u32[1280,1280]{1,0} add(%add.251756.5.clone.1, %xor.123436.5.clone.1) + %shift-left.111082.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123436.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117352.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123436.5.clone.1, %broadcast.244419.4352) + %or.116878.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111082.5.clone.1, %shift-right-logical.117352.5.clone.1) + %xor.123437.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251757.3.clone.1, %or.116878.3.clone.1) + %add.251758.3.clone.1 = u32[1280,1280]{1,0} add(%add.251757.3.clone.1, %xor.123437.3.clone.1) + %add.251759.7.clone.1 = u32[1280,1280]{1,0} add(%add.251758.3.clone.1, %broadcast.257865.113.clone.1) + %shift-left.111083.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123437.3.clone.1, 
%broadcast.244419.4352) + %shift-right-logical.117353.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123437.3.clone.1, %broadcast.244418.4352) + %or.116879.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111083.5.clone.1, %shift-right-logical.117353.5.clone.1) + %xor.123438.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251758.3.clone.1, %or.116879.3.clone.1) + %constant_218620_1_clone_1 = u32[] constant(2835712910) + %broadcast.257875.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218620_1_clone_1), dimensions={} + %add.251760.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123438.3.clone.1, %broadcast.257875.5.clone.1) + %add.251761.5.clone.1 = u32[1280,1280]{1,0} add(%add.251759.7.clone.1, %add.251760.5.clone.1) + %shift-left.111085.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251760.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117354.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251760.5.clone.1, %broadcast.244416.5760) + %or.116881.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111085.9.clone.1, %shift-right-logical.117354.9.clone.1) + %xor.123439.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251761.5.clone.1, %or.116881.7.clone.1) + %add.251762.3.clone.1 = u32[1280,1280]{1,0} add(%add.251761.5.clone.1, %xor.123439.5.clone.1) + %shift-left.111086.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123439.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117355.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123439.5.clone.1, %broadcast.244429.2304) + %or.116882.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111086.9.clone.1, %shift-right-logical.117355.9.clone.1) + %xor.123441.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251762.3.clone.1, %or.116882.7.clone.1) + %add.251763.3.clone.1 = u32[1280,1280]{1,0} add(%add.251762.3.clone.1, %xor.123441.5.clone.1) + %shift-left.111087.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123441.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117356.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123441.5.clone.1, %broadcast.244430.4608) + %or.116883.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111087.9.clone.1, %shift-right-logical.117356.9.clone.1) + %xor.123442.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251763.3.clone.1, %or.116883.7.clone.1) + %add.251764.3.clone.1 = u32[1280,1280]{1,0} add(%add.251763.3.clone.1, %xor.123442.5.clone.1) + %constant_189240_1_clone_1 = u32[] constant(2835712909) + %broadcast.257882.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_189240_1_clone_1), dimensions={} + %add.251765.7.clone.1 = u32[1280,1280]{1,0} add(%add.251764.3.clone.1, %broadcast.257882.24.clone.1) + %shift-left.111088.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123442.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117357.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123442.5.clone.1, %broadcast.244434.2816) + %or.116884.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111088.11.clone.1, %shift-right-logical.117357.11.clone.1) + %xor.123443.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251764.3.clone.1, %or.116884.9.clone.1) + %constant_218621_1_clone_1 = u32[] constant(1760383799) + %broadcast.257885.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218621_1_clone_1), dimensions={} + %add.251766.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123443.7.clone.1, %broadcast.257885.5.clone.1) + %add.251767.5.clone.1 = u32[1280,1280]{1,0} add(%add.251765.7.clone.1, %add.251766.5.clone.1) + %shift-left.111090.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251766.5.clone.1, 
%broadcast.244414.6272) + %shift-right-logical.117358.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251766.5.clone.1, %broadcast.244415.6016) + %or.116886.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111090.9.clone.1, %shift-right-logical.117358.9.clone.1) + %xor.123444.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251767.5.clone.1, %or.116886.7.clone.1) + %add.251768.3.clone.1 = u32[1280,1280]{1,0} add(%add.251767.5.clone.1, %xor.123444.5.clone.1) + %shift-left.111091.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123444.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117359.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123444.5.clone.1, %broadcast.244417.5760) + %or.116887.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111091.9.clone.1, %shift-right-logical.117359.9.clone.1) + %xor.123446.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251768.3.clone.1, %or.116887.7.clone.1) + %add.251769.3.clone.1 = u32[1280,1280]{1,0} add(%add.251768.3.clone.1, %xor.123446.5.clone.1) + %shift-left.111092.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123446.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117360.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123446.5.clone.1, %broadcast.244419.4352) + %or.116888.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111092.7.clone.1, %shift-right-logical.117360.7.clone.1) + %xor.123447.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251769.3.clone.1, %or.116888.5.clone.1) + %add.251770.3.clone.1 = u32[1280,1280]{1,0} add(%add.251769.3.clone.1, %xor.123447.3.clone.1) + %add.251771.7.clone.1 = u32[1280,1280]{1,0} add(%add.251770.3.clone.1, %broadcast.257864.44.clone.1) + %shift-left.111093.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123447.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117361.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123447.3.clone.1, %broadcast.244418.4352) + %or.116889.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111093.7.clone.1, %shift-right-logical.117361.7.clone.1) + %xor.123448.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251770.3.clone.1, %or.116889.5.clone.1) + %constant_218622_1_clone_1 = u32[] constant(3661218661) + %broadcast.257895.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218622_1_clone_1), dimensions={} + %add.251772.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123448.3.clone.1, %broadcast.257895.5.clone.1) + %add.251773.5.clone.1 = u32[1280,1280]{1,0} add(%add.251771.7.clone.1, %add.251772.5.clone.1) + %shift-left.111095.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251772.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.117362.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251772.5.clone.1, %broadcast.244416.5760) + %or.116890.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111095.9.clone.1, %shift-right-logical.117362.9.clone.1) + %xor.123449.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251773.5.clone.1, %or.116890.7.clone.1) + %add.251775.3.clone.1 = u32[1280,1280]{1,0} add(%add.251773.5.clone.1, %xor.123449.5.clone.1) + %shift-left.111096.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123449.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.117363.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123449.5.clone.1, %broadcast.244429.2304) + %or.116891.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111096.9.clone.1, %shift-right-logical.117363.9.clone.1) + %xor.123451.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251775.3.clone.1, %or.116891.7.clone.1) + %add.251778.3.clone.1 = u32[1280,1280]{1,0} add(%add.251775.3.clone.1, 
%xor.123451.5.clone.1) + %shift-left.111097.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123451.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.117364.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123451.5.clone.1, %broadcast.244430.4608) + %or.116892.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111097.9.clone.1, %shift-right-logical.117364.9.clone.1) + %xor.123452.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251778.3.clone.1, %or.116892.7.clone.1) + %add.251779.3.clone.1 = u32[1280,1280]{1,0} add(%add.251778.3.clone.1, %xor.123452.5.clone.1) + %add.251780.7.clone.1 = u32[1280,1280]{1,0} add(%add.251779.3.clone.1, %broadcast.257865.113.clone.1) + %shift-left.111098.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123452.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.117365.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123452.5.clone.1, %broadcast.244434.2816) + %or.116893.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111098.11.clone.1, %shift-right-logical.117365.11.clone.1) + %xor.123453.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251779.3.clone.1, %or.116893.9.clone.1) + %constant_218623_1_clone_1 = u32[] constant(2835712913) + %broadcast.257905.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218623_1_clone_1), dimensions={} + %add.251781.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123453.7.clone.1, %broadcast.257905.5.clone.1) + %add.251783.5.clone.1 = u32[1280,1280]{1,0} add(%add.251780.7.clone.1, %add.251781.5.clone.1) + %shift-left.111100.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251781.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.117366.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251781.5.clone.1, %broadcast.244415.6016) + %or.116894.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111100.9.clone.1, %shift-right-logical.117366.9.clone.1) + %xor.123454.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251783.5.clone.1, %or.116894.7.clone.1) + %add.251784.3.clone.1 = u32[1280,1280]{1,0} add(%add.251783.5.clone.1, %xor.123454.5.clone.1) + %shift-left.111101.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123454.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.117367.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123454.5.clone.1, %broadcast.244417.5760) + %or.116896.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111101.9.clone.1, %shift-right-logical.117367.9.clone.1) + %xor.123456.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251784.3.clone.1, %or.116896.7.clone.1) + %add.251785.3.clone.1 = u32[1280,1280]{1,0} add(%add.251784.3.clone.1, %xor.123456.5.clone.1) + %shift-left.111102.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123456.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.117368.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123456.5.clone.1, %broadcast.244419.4352) + %or.116897.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111102.5.clone.1, %shift-right-logical.117368.5.clone.1) + %xor.123457.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251785.3.clone.1, %or.116897.3.clone.1) + %add.251786.3.clone.1 = u32[1280,1280]{1,0} add(%add.251785.3.clone.1, %xor.123457.3.clone.1) + %add.251788.17.clone.1 = u32[1280,1280]{1,0} add(%add.251786.3.clone.1, %broadcast.257882.24.clone.1) + %shift-left.111103.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123457.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.117369.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123457.3.clone.1, %broadcast.244418.4352) + %or.116898.3.clone.1 = u32[1280,1280]{1,0} 
or(%shift-left.111103.5.clone.1, %shift-right-logical.117369.5.clone.1) + %xor.123458.15.clone.1 = u32[1280,1280]{1,0} xor(%add.251786.3.clone.1, %or.116898.3.clone.1) + %constant_218624_1_clone_1 = u32[] constant(1760383802) + %broadcast.257915.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218624_1_clone_1), dimensions={} + %add.251789.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123458.15.clone.1, %broadcast.257915.19.clone.1) + %xor.123459.17.clone.1 = u32[1280,1280]{1,0} xor(%add.251788.17.clone.1, %add.251789.19.clone.1) + %shift-right-logical.117370.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123459.17.clone.1, %broadcast.244468.1920) + %or.116899.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117370.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5822.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116899.13.clone.1) + %add.251790.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5822.11.clone.1, %broadcast.244470.1152) + %multiply.27172.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251790.9.clone.1, %broadcast.244471.896) + %add.251791.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27172.7.clone.1, %broadcast.244408.1024) + %maximum.3754.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.251791.5.clone.1) + %abs.1586.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3754.3.clone.1) + %compare.7334.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1586.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.27173.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3754.3.clone.1, %broadcast.244476.1152) + %negate.4677.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3754.3.clone.1) + %multiply.27174.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3754.3.clone.1, %negate.4677.5.clone.1) + %log-plus-one.1586.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27174.5.clone.1) + %negate.4678.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1586.3.clone.1) + %compare.7335.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4678.4.clone.1, %broadcast.244477.384), direction=LT + %select.21542.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21543.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21544.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21545.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21547.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21548.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21549.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21550.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21551.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.251793.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4678.4.clone.1, %broadcast.244496.640) + %sqrt.1586.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4678.4.clone.1) + %add.251794.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1586.5.clone.1, %broadcast.244498.640) + %select.21552.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7335.3.clone.1, 
%add.251793.5.clone.1, %add.251794.5.clone.1) + %multiply.27175.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21551.3.clone.1, %select.21552.3.clone.1) + %add.251795.1.clone.1 = f32[1280,1280]{1,0} add(%select.21550.3.clone.1, %multiply.27175.1.clone.1) + %multiply.27176.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251795.1.clone.1, %select.21552.3.clone.1) + %add.251796.1.clone.1 = f32[1280,1280]{1,0} add(%select.21549.3.clone.1, %multiply.27176.1.clone.1) + %multiply.27177.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251796.1.clone.1, %select.21552.3.clone.1) + %add.251797.1.clone.1 = f32[1280,1280]{1,0} add(%select.21548.3.clone.1, %multiply.27177.1.clone.1) + %multiply.27178.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251797.1.clone.1, %select.21552.3.clone.1) + %add.251799.1.clone.1 = f32[1280,1280]{1,0} add(%select.21547.3.clone.1, %multiply.27178.1.clone.1) + %multiply.27179.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251799.1.clone.1, %select.21552.3.clone.1) + %add.251803.3.clone.1 = f32[1280,1280]{1,0} add(%select.21545.5.clone.1, %multiply.27179.1.clone.1) + %multiply.27180.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251803.3.clone.1, %select.21552.3.clone.1) + %add.251804.3.clone.1 = f32[1280,1280]{1,0} add(%select.21544.5.clone.1, %multiply.27180.1.clone.1) + %multiply.27181.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251804.3.clone.1, %select.21552.3.clone.1) + %add.251805.9.clone.1 = f32[1280,1280]{1,0} add(%select.21543.11.clone.1, %multiply.27181.7.clone.1) + %multiply.27182.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251805.9.clone.1, %select.21552.3.clone.1) + %add.251806.7.clone.1 = f32[1280,1280]{1,0} add(%select.21542.7.clone.1, %multiply.27182.7.clone.1) + %multiply.27183.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251806.7.clone.1, %maximum.3754.3.clone.1) + %select.21553.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7334.3.clone.1, %multiply.27173.9.clone.1, %multiply.27183.7.clone.1) + %multiply.27184.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21553.7.clone.1, %broadcast.244500.640) + %clamp.1230.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27184.5.clone.1, %broadcast.244501.384) + %multiply.27185.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1230.3.clone.1, %broadcast.244502.1) + %constant_165867_1_clone_1 = u32[] constant(2880272204) + %broadcast.247764.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_165867_1_clone_1), dimensions={} + %add.245992.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.247764.44.clone.1) + %constant_165874_1_clone_1 = u32[] constant(2024700297) + %broadcast.247765.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_165874_1_clone_1), dimensions={} + %add.245993.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.247765.113.clone.1) + %add.245994.35.clone.1 = u32[1280,1280]{1,0} add(%add.245992.37.clone.1, %add.245993.99.clone.1) + %shift-left.108579.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245993.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.114714.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245993.99.clone.1, %broadcast.244415.6016) + %or.114238.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108579.31.clone.1, %shift-right-logical.114714.29.clone.1) + %xor.120785.27.clone.1 = u32[1280,1280]{1,0} xor(%add.245994.35.clone.1, %or.114238.29.clone.1) + %add.245996.5.clone.1 = u32[1280,1280]{1,0} add(%add.245994.35.clone.1, %xor.120785.27.clone.1) + %shift-left.108580.9.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.120785.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.114715.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120785.27.clone.1, %broadcast.244417.5760) + %or.114239.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108580.9.clone.1, %shift-right-logical.114715.9.clone.1) + %xor.120786.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245996.5.clone.1, %or.114239.7.clone.1) + %add.245997.3.clone.1 = u32[1280,1280]{1,0} add(%add.245996.5.clone.1, %xor.120786.5.clone.1) + %shift-left.108581.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120786.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114716.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120786.5.clone.1, %broadcast.244419.4352) + %or.114240.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108581.5.clone.1, %shift-right-logical.114716.5.clone.1) + %xor.120787.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245997.3.clone.1, %or.114240.3.clone.1) + %add.245998.3.clone.1 = u32[1280,1280]{1,0} add(%add.245997.3.clone.1, %xor.120787.3.clone.1) + %add.245999.7.clone.1 = u32[1280,1280]{1,0} add(%add.245998.3.clone.1, %broadcast.247765.113.clone.1) + %shift-left.108583.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120787.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114717.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120787.3.clone.1, %broadcast.244418.4352) + %or.114241.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108583.5.clone.1, %shift-right-logical.114717.5.clone.1) + %xor.120788.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245998.3.clone.1, %or.114241.3.clone.1) + %constant_217976_1_clone_1 = u32[] constant(3369211168) + %broadcast.247777.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217976_1_clone_1), dimensions={} + %add.246001.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120788.3.clone.1, %broadcast.247777.5.clone.1) + %add.246002.5.clone.1 = u32[1280,1280]{1,0} add(%add.245999.7.clone.1, %add.246001.5.clone.1) + %shift-left.108584.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246001.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114718.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246001.5.clone.1, %broadcast.244416.5760) + %or.114242.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108584.9.clone.1, %shift-right-logical.114718.9.clone.1) + %xor.120789.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246002.5.clone.1, %or.114242.7.clone.1) + %add.246003.3.clone.1 = u32[1280,1280]{1,0} add(%add.246002.5.clone.1, %xor.120789.5.clone.1) + %shift-left.108585.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120789.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114719.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120789.5.clone.1, %broadcast.244429.2304) + %or.114243.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108585.9.clone.1, %shift-right-logical.114719.9.clone.1) + %xor.120790.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246003.3.clone.1, %or.114243.7.clone.1) + %add.246004.3.clone.1 = u32[1280,1280]{1,0} add(%add.246003.3.clone.1, %xor.120790.5.clone.1) + %shift-left.108586.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120790.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114721.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120790.5.clone.1, %broadcast.244430.4608) + %or.114244.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108586.9.clone.1, %shift-right-logical.114721.9.clone.1) + %xor.120791.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246004.3.clone.1, %or.114244.7.clone.1) + %add.246006.3.clone.1 = 
u32[1280,1280]{1,0} add(%add.246004.3.clone.1, %xor.120791.5.clone.1) + %constant_165876_1_clone_1 = u32[] constant(3369211167) + %broadcast.247784.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_165876_1_clone_1), dimensions={} + %add.246007.7.clone.1 = u32[1280,1280]{1,0} add(%add.246006.3.clone.1, %broadcast.247784.24.clone.1) + %shift-left.108587.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120791.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114722.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120791.5.clone.1, %broadcast.244434.2816) + %or.114245.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108587.11.clone.1, %shift-right-logical.114722.11.clone.1) + %xor.120792.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246006.3.clone.1, %or.114245.9.clone.1) + %constant_217978_1_clone_1 = u32[] constant(2880272206) + %broadcast.247787.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217978_1_clone_1), dimensions={} + %add.246008.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120792.7.clone.1, %broadcast.247787.5.clone.1) + %add.246009.5.clone.1 = u32[1280,1280]{1,0} add(%add.246007.7.clone.1, %add.246008.5.clone.1) + %shift-left.108588.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246008.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114723.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246008.5.clone.1, %broadcast.244415.6016) + %or.114246.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108588.9.clone.1, %shift-right-logical.114723.9.clone.1) + %xor.120793.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246009.5.clone.1, %or.114246.7.clone.1) + %add.246010.3.clone.1 = u32[1280,1280]{1,0} add(%add.246009.5.clone.1, %xor.120793.5.clone.1) + %shift-left.108589.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120793.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114724.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120793.5.clone.1, %broadcast.244417.5760) + %or.114247.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108589.9.clone.1, %shift-right-logical.114724.9.clone.1) + %xor.120794.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246010.3.clone.1, %or.114247.7.clone.1) + %add.246012.3.clone.1 = u32[1280,1280]{1,0} add(%add.246010.3.clone.1, %xor.120794.5.clone.1) + %shift-left.108590.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120794.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114726.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120794.5.clone.1, %broadcast.244419.4352) + %or.114248.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108590.7.clone.1, %shift-right-logical.114726.7.clone.1) + %xor.120795.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246012.3.clone.1, %or.114248.5.clone.1) + %add.246016.3.clone.1 = u32[1280,1280]{1,0} add(%add.246012.3.clone.1, %xor.120795.3.clone.1) + %add.246017.7.clone.1 = u32[1280,1280]{1,0} add(%add.246016.3.clone.1, %broadcast.247764.44.clone.1) + %shift-left.108591.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120795.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114727.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120795.3.clone.1, %broadcast.244418.4352) + %or.114249.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108591.7.clone.1, %shift-right-logical.114727.7.clone.1) + %xor.120796.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246016.3.clone.1, %or.114249.5.clone.1) + %constant_217980_1_clone_1 = u32[] constant(2024700300) + %broadcast.247798.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217980_1_clone_1), dimensions={} + %add.246018.5.clone.1 = 
u32[1280,1280]{1,0} add(%xor.120796.3.clone.1, %broadcast.247798.5.clone.1) + %add.246019.5.clone.1 = u32[1280,1280]{1,0} add(%add.246017.7.clone.1, %add.246018.5.clone.1) + %shift-left.108592.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246018.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114728.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246018.5.clone.1, %broadcast.244416.5760) + %or.114250.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108592.9.clone.1, %shift-right-logical.114728.9.clone.1) + %xor.120797.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246019.5.clone.1, %or.114250.7.clone.1) + %add.246021.3.clone.1 = u32[1280,1280]{1,0} add(%add.246019.5.clone.1, %xor.120797.5.clone.1) + %shift-left.108593.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120797.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114729.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120797.5.clone.1, %broadcast.244429.2304) + %or.114251.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108593.9.clone.1, %shift-right-logical.114729.9.clone.1) + %xor.120798.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246021.3.clone.1, %or.114251.7.clone.1) + %add.246022.3.clone.1 = u32[1280,1280]{1,0} add(%add.246021.3.clone.1, %xor.120798.5.clone.1) + %shift-left.108594.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120798.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114731.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120798.5.clone.1, %broadcast.244430.4608) + %or.114252.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108594.9.clone.1, %shift-right-logical.114731.9.clone.1) + %xor.120799.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246022.3.clone.1, %or.114252.7.clone.1) + %add.246023.3.clone.1 = u32[1280,1280]{1,0} add(%add.246022.3.clone.1, %xor.120799.5.clone.1) + %add.246024.7.clone.1 = u32[1280,1280]{1,0} add(%add.246023.3.clone.1, %broadcast.247765.113.clone.1) + %shift-left.108595.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120799.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114732.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120799.5.clone.1, %broadcast.244434.2816) + %or.114253.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108595.11.clone.1, %shift-right-logical.114732.11.clone.1) + %xor.120800.7.clone.1 = u32[1280,1280]{1,0} xor(%add.246023.3.clone.1, %or.114253.9.clone.1) + %constant_217982_1_clone_1 = u32[] constant(3369211171) + %broadcast.247809.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217982_1_clone_1), dimensions={} + %add.246026.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120800.7.clone.1, %broadcast.247809.5.clone.1) + %add.246027.5.clone.1 = u32[1280,1280]{1,0} add(%add.246024.7.clone.1, %add.246026.5.clone.1) + %shift-left.108596.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.246026.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114733.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.246026.5.clone.1, %broadcast.244415.6016) + %or.114254.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108596.9.clone.1, %shift-right-logical.114733.9.clone.1) + %xor.120801.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246027.5.clone.1, %or.114254.7.clone.1) + %add.246028.3.clone.1 = u32[1280,1280]{1,0} add(%add.246027.5.clone.1, %xor.120801.5.clone.1) + %shift-left.108597.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120801.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114734.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120801.5.clone.1, %broadcast.244417.5760) + %or.114255.7.clone.1 = 
u32[1280,1280]{1,0} or(%shift-left.108597.9.clone.1, %shift-right-logical.114734.9.clone.1) + %xor.120802.5.clone.1 = u32[1280,1280]{1,0} xor(%add.246028.3.clone.1, %or.114255.7.clone.1) + %add.246029.3.clone.1 = u32[1280,1280]{1,0} add(%add.246028.3.clone.1, %xor.120802.5.clone.1) + %shift-left.108598.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120802.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114736.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120802.5.clone.1, %broadcast.244419.4352) + %or.114256.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108598.5.clone.1, %shift-right-logical.114736.5.clone.1) + %xor.120803.3.clone.1 = u32[1280,1280]{1,0} xor(%add.246029.3.clone.1, %or.114256.3.clone.1) + %add.246031.3.clone.1 = u32[1280,1280]{1,0} add(%add.246029.3.clone.1, %xor.120803.3.clone.1) + %add.246032.17.clone.1 = u32[1280,1280]{1,0} add(%add.246031.3.clone.1, %broadcast.247784.24.clone.1) + %shift-left.108599.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120803.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114737.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120803.3.clone.1, %broadcast.244418.4352) + %or.114257.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108599.5.clone.1, %shift-right-logical.114737.5.clone.1) + %xor.120804.15.clone.1 = u32[1280,1280]{1,0} xor(%add.246031.3.clone.1, %or.114257.3.clone.1) + %constant_217984_1_clone_1 = u32[] constant(2880272209) + %broadcast.247819.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217984_1_clone_1), dimensions={} + %add.246033.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120804.15.clone.1, %broadcast.247819.19.clone.1) + %xor.120805.17.clone.1 = u32[1280,1280]{1,0} xor(%add.246032.17.clone.1, %add.246033.19.clone.1) + %shift-right-logical.114738.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120805.17.clone.1, %broadcast.244468.1920) + %or.114258.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114738.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5707.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114258.13.clone.1) + %add.246034.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5707.11.clone.1, %broadcast.244470.1152) + %multiply.25987.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246034.9.clone.1, %broadcast.244471.896) + %add.246035.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.25987.7.clone.1, %broadcast.244408.1024) + %maximum.3639.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.246035.5.clone.1) + %abs.1509.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3639.3.clone.1) + %compare.7166.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1509.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.25988.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3639.3.clone.1, %broadcast.244476.1152) + %negate.4523.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3639.3.clone.1) + %multiply.25989.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3639.3.clone.1, %negate.4523.5.clone.1) + %log-plus-one.1509.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.25989.5.clone.1) + %negate.4524.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1509.3.clone.1) + %compare.7167.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4524.4.clone.1, %broadcast.244477.384), direction=LT + %select.20664.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20665.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + 
%select.20666.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20667.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20668.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20669.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20670.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20671.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20672.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.246037.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4524.4.clone.1, %broadcast.244496.640) + %sqrt.1509.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4524.4.clone.1) + %add.246041.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1509.5.clone.1, %broadcast.244498.640) + %select.20673.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7167.3.clone.1, %add.246037.5.clone.1, %add.246041.5.clone.1) + %multiply.25990.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20672.3.clone.1, %select.20673.3.clone.1) + %add.246042.1.clone.1 = f32[1280,1280]{1,0} add(%select.20671.3.clone.1, %multiply.25990.1.clone.1) + %multiply.25991.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246042.1.clone.1, %select.20673.3.clone.1) + %add.246043.1.clone.1 = f32[1280,1280]{1,0} add(%select.20670.3.clone.1, %multiply.25991.1.clone.1) + %multiply.25992.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246043.1.clone.1, %select.20673.3.clone.1) + %add.246044.1.clone.1 = f32[1280,1280]{1,0} add(%select.20669.3.clone.1, %multiply.25992.1.clone.1) + %multiply.25993.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246044.1.clone.1, %select.20673.3.clone.1) + %add.246046.1.clone.1 = f32[1280,1280]{1,0} add(%select.20668.3.clone.1, %multiply.25993.1.clone.1) + %multiply.25994.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246046.1.clone.1, %select.20673.3.clone.1) + %add.246047.3.clone.1 = f32[1280,1280]{1,0} add(%select.20667.5.clone.1, %multiply.25994.1.clone.1) + %multiply.25995.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.246047.3.clone.1, %select.20673.3.clone.1) + %add.246048.3.clone.1 = f32[1280,1280]{1,0} add(%select.20666.5.clone.1, %multiply.25995.1.clone.1) + %multiply.25996.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246048.3.clone.1, %select.20673.3.clone.1) + %add.246049.9.clone.1 = f32[1280,1280]{1,0} add(%select.20665.11.clone.1, %multiply.25996.7.clone.1) + %multiply.25997.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246049.9.clone.1, %select.20673.3.clone.1) + %add.246051.7.clone.1 = f32[1280,1280]{1,0} add(%select.20664.7.clone.1, %multiply.25997.7.clone.1) + %multiply.25998.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.246051.7.clone.1, %maximum.3639.3.clone.1) + %select.20674.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7166.3.clone.1, %multiply.25988.9.clone.1, %multiply.25998.7.clone.1) + %multiply.25999.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20674.7.clone.1, %broadcast.244500.640) + %clamp.1153.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.25999.5.clone.1, %broadcast.244501.384) + %multiply.26001.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1153.3.clone.1, %broadcast.244502.1) + %constant_181484_1_clone_1 = 
u32[] constant(2866377858) + %broadcast.254519.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_181484_1_clone_1), dimensions={} + %add.249846.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.254519.44.clone.1) + %constant_181491_1_clone_1 = u32[] constant(2908281955) + %broadcast.254521.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_181491_1_clone_1), dimensions={} + %add.249847.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.254521.113.clone.1) + %add.249849.35.clone.1 = u32[1280,1280]{1,0} add(%add.249846.37.clone.1, %add.249847.99.clone.1) + %shift-left.110240.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249847.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116466.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249847.99.clone.1, %broadcast.244415.6016) + %or.115990.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110240.31.clone.1, %shift-right-logical.116466.29.clone.1) + %xor.122557.27.clone.1 = u32[1280,1280]{1,0} xor(%add.249849.35.clone.1, %or.115990.29.clone.1) + %add.249850.5.clone.1 = u32[1280,1280]{1,0} add(%add.249849.35.clone.1, %xor.122557.27.clone.1) + %shift-left.110241.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122557.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116467.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122557.27.clone.1, %broadcast.244417.5760) + %or.115991.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110241.9.clone.1, %shift-right-logical.116467.9.clone.1) + %xor.122558.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249850.5.clone.1, %or.115991.7.clone.1) + %add.249851.3.clone.1 = u32[1280,1280]{1,0} add(%add.249850.5.clone.1, %xor.122558.5.clone.1) + %shift-left.110242.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122558.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116468.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122558.5.clone.1, %broadcast.244419.4352) + %or.115992.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110242.5.clone.1, %shift-right-logical.116468.5.clone.1) + %xor.122560.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249851.3.clone.1, %or.115992.3.clone.1) + %add.249852.3.clone.1 = u32[1280,1280]{1,0} add(%add.249851.3.clone.1, %xor.122560.3.clone.1) + %add.249854.7.clone.1 = u32[1280,1280]{1,0} add(%add.249852.3.clone.1, %broadcast.254521.113.clone.1) + %shift-left.110243.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122560.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116469.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122560.3.clone.1, %broadcast.244418.4352) + %or.115993.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110243.5.clone.1, %shift-right-logical.116469.5.clone.1) + %xor.122561.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249852.3.clone.1, %or.115993.3.clone.1) + %constant_218409_1_clone_1 = u32[] constant(475050812) + %broadcast.254539.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218409_1_clone_1), dimensions={} + %add.249855.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122561.3.clone.1, %broadcast.254539.5.clone.1) + %add.249856.5.clone.1 = u32[1280,1280]{1,0} add(%add.249854.7.clone.1, %add.249855.5.clone.1) + %shift-left.110244.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249855.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116470.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249855.5.clone.1, %broadcast.244416.5760) + %or.115994.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110244.9.clone.1, %shift-right-logical.116470.9.clone.1) + 
[HLO dump, condensed: this span closes the first of several identical sampling blocks over u32[1280,1280] tensors. Each remaining mixing round is an add, a 32-bit rotate written as shift-left / shift-right-logical / or, and a xor; key-injection adds use the broadcast constants 475050811, 2866377860, 2908281958, 475050815 and 2866377863. The mixed words are then converted to floats (shift-right-logical, or with a broadcast, bitcast-convert to f32[1280,1280], affine rescale) and pushed through an inverse-erf evaluation: %compare.7270 tests negate(log-plus-one(x * negate(x))) against a broadcast threshold, %select.21246 through %select.21255 pick the matching coefficient set per Horner step, a nine-step multiply/add chain follows, and the result is clamped (%clamp.1204) and scaled (%multiply.26782). A second block opens with key constants 3292448747 and 64167697 added to the counter tensors %convert.3610.4865 and %convert.3611.12801, and runs through its first key reinjection.]
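The unrolled rounds are easier to follow in scalar form. Below is a minimal NumPy sketch of the round pattern, assuming the standard Threefry-2x32 schedule of Salmon et al. (SC'11); the actual rotation amounts sit in broadcast constants defined before this excerpt, so the schedule and the key/counter plumbing here are assumptions, not values read off the dump.

import numpy as np

# One Threefry-2x32 block: each HLO "round" above is an add, a rotate
# (spelled shift-left | shift-right-logical | or), and a xor; every fourth
# round re-injects a key word together with the round counter.
ROTATIONS = (13, 15, 26, 6, 17, 29, 16, 24)  # standard Threefry-2x32 schedule

def u32(x):
    # 0-d uint32 array so NumPy wraps all arithmetic mod 2**32 silently.
    return np.asarray(x, dtype=np.uint32)

def rotl32(x, r):
    # 32-bit rotate left, matching the shift/or idiom in the dump.
    return (x << np.uint32(r)) | (x >> np.uint32(32 - r))

def threefry_2x32(key0, key1, ctr0, ctr1, rounds=20):
    # Third key word is the xor of the first two with Threefry's parity constant.
    ks = [u32(key0), u32(key1), u32(key0) ^ u32(key1) ^ u32(0x1BD11BDA)]
    x0 = u32(ctr0) + ks[0]  # initial key injection
    x1 = u32(ctr1) + ks[1]
    for d in range(rounds):
        x0 = x0 + x1
        x1 = rotl32(x1, ROTATIONS[d % 8])
        x1 = x1 ^ x0
        if d % 4 == 3:      # key injection with the round counter folded in
            j = d // 4 + 1
            x0 = x0 + ks[j % 3]
            x1 = x1 + ks[(j + 1) % 3] + u32(j)
    return x0, x1

If the identification is right, the second block's bit stream would be threefry_2x32(3292448747, 64167697, c0, c1), with c0/c1 the per-element counters held in %convert.3610.4865 and %convert.3611.12801.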
[HLO dump, condensed: the second block's remaining rounds repeat the add / rotate / xor pattern, with key-injection constants 3694969633, 3694969632, 3292448749, 64167700, 3694969636 and 3292448752 (stepped copies of the base keys and a derived third word, consistent with a Threefry-style key schedule). The same float conversion (%bitcast-convert.5706) and inverse-erf chain follow (%compare.7164/%compare.7165, %select.20653 through %select.20663, %clamp.1152, scale %multiply.25986). A third block then opens with key constants 398121435 and 1738677882 added to the same counter tensors.]
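After the rounds, each block turns the mixed words into floats via shift-right-logical, or with a broadcast, and bitcast-convert. That is the usual mantissa trick; a sketch, assuming the or'd broadcast holds 0x3F800000 (the bit pattern of 1.0f), which is not visible in this excerpt:

import numpy as np

def bits_to_unit_float(bits):
    # Keep the top 23 random bits as the mantissa of a float in [1.0, 2.0),
    # then shift down to [0.0, 1.0). The shift width (9 = sign + exponent
    # bits) and the 0x3F800000 operand are assumptions about the broadcast
    # constants defined earlier in the module.
    bits = np.asarray(bits, dtype=np.uint32)
    mantissa = bits >> np.uint32(9)
    return (mantissa | np.uint32(0x3F800000)).view(np.float32) - np.float32(1.0)

The add/multiply broadcasts that follow in the dump then map [0, 1) onto the open interval (-1, 1) expected by the inverse-erf step.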
[HLO dump, condensed: the third block's rounds mirror the previous two, with key-injection constants 1808391292, 1808391291, 398121437, 1738677885, 1808391295 and 398121440, then the float conversion (%bitcast-convert.5851) and the inverse-erf chain (%compare.7372/%compare.7373, %select.21762 through %select.21772, %multiply.27471), ending in the clamp %clamp.1249.]
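The compare/select/Horner chains are a branch-free single-precision inverse error function in the style of Giles's erfinv approximation: one comparison on w = -log1p(-x^2) chooses between a central and a tail coefficient set, nine select/multiply/add steps evaluate the polynomial, and the result is multiplied back by x. A sketch with the published coefficients (the dump's broadcast operands are defined elsewhere, so the exact values here are assumptions):

import numpy as np

# Branch-free erfinv in the style of Giles ("Approximating the erfinv
# function", GPU Computing Gems, 2011): the selects in the dump pick a
# coefficient per Horner step; coefficients below are the published ones.
_CENTRAL = (2.81022636e-08, 3.43273939e-07, -3.5233877e-06, -4.39150654e-06,
            0.00021858087, -0.00125372503, -0.00417768164, 0.246640727,
            1.50140941)
_TAIL = (-0.000200214257, 0.000100950558, 0.00134934322, -0.00367342844,
         0.00573950773, -0.0076224613, 0.00943887047, 1.00167406, 2.83297682)

def erfinv(x):
    x = np.asarray(x, dtype=np.float32)
    w = -np.log1p(-x * x)                    # negate(log-plus-one(x * -x))
    central = w < np.float32(5.0)            # the compare ... direction=LT
    w = np.where(central, w - np.float32(2.5), np.sqrt(w) - np.float32(3.0))
    p = np.where(central, _CENTRAL[0], _TAIL[0]).astype(np.float32)
    for ca, cb in zip(_CENTRAL[1:], _TAIL[1:]):
        p = p * w + np.where(central, ca, cb).astype(np.float32)
    return p * x                             # |x| == 1 (w = inf) is handled
                                             # by a separate select in the dump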
[HLO dump, condensed: %clamp.1249 is scaled by %multiply.27472, closing the third block. A fourth block follows in full: key constants 3110233334 and 3521699035, mixing rounds with injection constants 1935385592, 1935385591, 3110233336, 3521699038, 1935385595 and 3110233339, float conversion (%bitcast-convert.5705), and the inverse-erf chain (%compare.7162/%compare.7163, %select.20642 through %select.20652, %clamp.1151, scale %multiply.25972). A fifth block opens with key constants 4107876913 and 467309358; the excerpt breaks off mid-instruction at %xor.122535.3.clone.1.]
%or.115971.3.clone.1) + %add.249790.3.clone.1 = u32[1280,1280]{1,0} add(%add.249788.3.clone.1, %xor.122535.3.clone.1) + %add.249794.7.clone.1 = u32[1280,1280]{1,0} add(%add.249790.3.clone.1, %broadcast.254402.113.clone.1) + %shift-left.110219.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122535.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116448.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122535.3.clone.1, %broadcast.244418.4352) + %or.115972.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110219.5.clone.1, %shift-right-logical.116448.5.clone.1) + %xor.122536.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249790.3.clone.1, %or.115972.3.clone.1) + %constant_218404_1_clone_1 = u32[] constant(4107449030) + %broadcast.254412.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218404_1_clone_1), dimensions={} + %add.249795.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122536.3.clone.1, %broadcast.254412.5.clone.1) + %add.249796.5.clone.1 = u32[1280,1280]{1,0} add(%add.249794.7.clone.1, %add.249795.5.clone.1) + %shift-left.110221.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249795.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116449.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249795.5.clone.1, %broadcast.244416.5760) + %or.115973.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110221.9.clone.1, %shift-right-logical.116449.9.clone.1) + %xor.122537.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249796.5.clone.1, %or.115973.7.clone.1) + %add.249797.3.clone.1 = u32[1280,1280]{1,0} add(%add.249796.5.clone.1, %xor.122537.5.clone.1) + %shift-left.110222.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122537.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116450.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122537.5.clone.1, %broadcast.244429.2304) + %or.115974.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110222.9.clone.1, %shift-right-logical.116450.9.clone.1) + %xor.122538.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249797.3.clone.1, %or.115974.7.clone.1) + %add.249799.3.clone.1 = u32[1280,1280]{1,0} add(%add.249797.3.clone.1, %xor.122538.5.clone.1) + %shift-left.110223.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122538.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116451.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122538.5.clone.1, %broadcast.244430.4608) + %or.115975.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110223.9.clone.1, %shift-right-logical.116451.9.clone.1) + %xor.122540.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249799.3.clone.1, %or.115975.7.clone.1) + %add.249800.3.clone.1 = u32[1280,1280]{1,0} add(%add.249799.3.clone.1, %xor.122540.5.clone.1) + %constant_181282_1_clone_1 = u32[] constant(4107449029) + %broadcast.254421.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_181282_1_clone_1), dimensions={} + %add.249801.7.clone.1 = u32[1280,1280]{1,0} add(%add.249800.3.clone.1, %broadcast.254421.24.clone.1) + %shift-left.110224.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122540.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116452.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122540.5.clone.1, %broadcast.244434.2816) + %or.115976.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110224.11.clone.1, %shift-right-logical.116452.11.clone.1) + %xor.122541.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249800.3.clone.1, %or.115976.9.clone.1) + %constant_218405_1_clone_1 = u32[] constant(4107876915) + %broadcast.254424.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218405_1_clone_1), 
dimensions={} + %add.249802.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122541.7.clone.1, %broadcast.254424.5.clone.1) + %add.249804.5.clone.1 = u32[1280,1280]{1,0} add(%add.249801.7.clone.1, %add.249802.5.clone.1) + %shift-left.110226.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249802.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116453.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249802.5.clone.1, %broadcast.244415.6016) + %or.115977.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110226.9.clone.1, %shift-right-logical.116453.9.clone.1) + %xor.122542.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249804.5.clone.1, %or.115977.7.clone.1) + %add.249805.3.clone.1 = u32[1280,1280]{1,0} add(%add.249804.5.clone.1, %xor.122542.5.clone.1) + %shift-left.110227.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122542.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116454.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122542.5.clone.1, %broadcast.244417.5760) + %or.115978.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110227.9.clone.1, %shift-right-logical.116454.9.clone.1) + %xor.122543.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249805.3.clone.1, %or.115978.7.clone.1) + %add.249806.3.clone.1 = u32[1280,1280]{1,0} add(%add.249805.3.clone.1, %xor.122543.5.clone.1) + %shift-left.110228.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122543.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116455.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122543.5.clone.1, %broadcast.244419.4352) + %or.115979.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110228.7.clone.1, %shift-right-logical.116455.7.clone.1) + %xor.122544.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249806.3.clone.1, %or.115979.5.clone.1) + %add.249807.3.clone.1 = u32[1280,1280]{1,0} add(%add.249806.3.clone.1, %xor.122544.3.clone.1) + %add.249809.7.clone.1 = u32[1280,1280]{1,0} add(%add.249807.3.clone.1, %broadcast.254401.44.clone.1) + %shift-left.110229.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122544.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116456.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122544.3.clone.1, %broadcast.244418.4352) + %or.115980.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110229.7.clone.1, %shift-right-logical.116456.7.clone.1) + %xor.122545.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249807.3.clone.1, %or.115980.5.clone.1) + %constant_218406_1_clone_1 = u32[] constant(467309361) + %broadcast.254439.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218406_1_clone_1), dimensions={} + %add.249810.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122545.3.clone.1, %broadcast.254439.5.clone.1) + %add.249811.5.clone.1 = u32[1280,1280]{1,0} add(%add.249809.7.clone.1, %add.249810.5.clone.1) + %shift-left.110231.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249810.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116457.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249810.5.clone.1, %broadcast.244416.5760) + %or.115981.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110231.9.clone.1, %shift-right-logical.116457.9.clone.1) + %xor.122546.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249811.5.clone.1, %or.115981.7.clone.1) + %add.249812.3.clone.1 = u32[1280,1280]{1,0} add(%add.249811.5.clone.1, %xor.122546.5.clone.1) + %shift-left.110232.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122546.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116458.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122546.5.clone.1, 
%broadcast.244429.2304) + %or.115982.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110232.9.clone.1, %shift-right-logical.116458.9.clone.1) + %xor.122547.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249812.3.clone.1, %or.115982.7.clone.1) + %add.249813.3.clone.1 = u32[1280,1280]{1,0} add(%add.249812.3.clone.1, %xor.122547.5.clone.1) + %shift-left.110233.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122547.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116459.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122547.5.clone.1, %broadcast.244430.4608) + %or.115983.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110233.9.clone.1, %shift-right-logical.116459.9.clone.1) + %xor.122548.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249813.3.clone.1, %or.115983.7.clone.1) + %add.249815.3.clone.1 = u32[1280,1280]{1,0} add(%add.249813.3.clone.1, %xor.122548.5.clone.1) + %add.249819.7.clone.1 = u32[1280,1280]{1,0} add(%add.249815.3.clone.1, %broadcast.254402.113.clone.1) + %shift-left.110234.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122548.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116460.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122548.5.clone.1, %broadcast.244434.2816) + %or.115984.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110234.11.clone.1, %shift-right-logical.116460.11.clone.1) + %xor.122550.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249815.3.clone.1, %or.115984.9.clone.1) + %constant_218407_1_clone_1 = u32[] constant(4107449033) + %broadcast.254451.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218407_1_clone_1), dimensions={} + %add.249820.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122550.7.clone.1, %broadcast.254451.5.clone.1) + %add.249821.5.clone.1 = u32[1280,1280]{1,0} add(%add.249819.7.clone.1, %add.249820.5.clone.1) + %shift-left.110236.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249820.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116461.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249820.5.clone.1, %broadcast.244415.6016) + %or.115985.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110236.9.clone.1, %shift-right-logical.116461.9.clone.1) + %xor.122551.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249821.5.clone.1, %or.115985.7.clone.1) + %add.249822.3.clone.1 = u32[1280,1280]{1,0} add(%add.249821.5.clone.1, %xor.122551.5.clone.1) + %shift-left.110237.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122551.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116462.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122551.5.clone.1, %broadcast.244417.5760) + %or.115986.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110237.9.clone.1, %shift-right-logical.116462.9.clone.1) + %xor.122552.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249822.3.clone.1, %or.115986.7.clone.1) + %add.249824.3.clone.1 = u32[1280,1280]{1,0} add(%add.249822.3.clone.1, %xor.122552.5.clone.1) + %shift-left.110238.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122552.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116463.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122552.5.clone.1, %broadcast.244419.4352) + %or.115987.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110238.5.clone.1, %shift-right-logical.116463.5.clone.1) + %xor.122553.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249824.3.clone.1, %or.115987.3.clone.1) + %add.249825.3.clone.1 = u32[1280,1280]{1,0} add(%add.249824.3.clone.1, %xor.122553.3.clone.1) + %add.249826.17.clone.1 = u32[1280,1280]{1,0} add(%add.249825.3.clone.1, %broadcast.254421.24.clone.1) + 
+ %shift-left.110239.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122553.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116464.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122553.3.clone.1, %broadcast.244418.4352)
+ %or.115988.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110239.5.clone.1, %shift-right-logical.116464.5.clone.1)
+ %xor.122555.15.clone.1 = u32[1280,1280]{1,0} xor(%add.249825.3.clone.1, %or.115988.3.clone.1)
+ %constant_218408_1_clone_1 = u32[] constant(4107876918)
+ %broadcast.254464.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218408_1_clone_1), dimensions={}
+ %add.249827.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122555.15.clone.1, %broadcast.254464.19.clone.1)
+ %xor.122556.17.clone.1 = u32[1280,1280]{1,0} xor(%add.249826.17.clone.1, %add.249827.19.clone.1)
+ %shift-right-logical.116465.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122556.17.clone.1, %broadcast.244468.1920)
+ %or.115989.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116465.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5783.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115989.13.clone.1)
+ %add.249829.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5783.11.clone.1, %broadcast.244470.1152)
+ %multiply.26755.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249829.9.clone.1, %broadcast.244471.896)
+ %add.249830.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26755.7.clone.1, %broadcast.244408.1024)
+ %maximum.3715.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.249830.5.clone.1)
+ %abs.1559.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3715.3.clone.1)
+ %compare.7267.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1559.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26756.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3715.3.clone.1, %broadcast.244476.1152)
+ %negate.4623.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3715.3.clone.1)
+ %multiply.26757.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3715.3.clone.1, %negate.4623.5.clone.1)
+ %log-plus-one.1559.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26757.5.clone.1)
+ %negate.4624.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1559.3.clone.1)
+ %compare.7268.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4624.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21235.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21236.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21237.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21238.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21239.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21240.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21241.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21242.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21243.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.249831.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4624.4.clone.1, %broadcast.244496.640)
+ %sqrt.1559.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4624.4.clone.1)
+ %add.249832.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1559.5.clone.1, %broadcast.244498.640)
+ %select.21244.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7268.3.clone.1, %add.249831.5.clone.1, %add.249832.5.clone.1)
+ %multiply.26758.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21243.3.clone.1, %select.21244.3.clone.1)
+ %add.249834.1.clone.1 = f32[1280,1280]{1,0} add(%select.21242.3.clone.1, %multiply.26758.1.clone.1)
+ %multiply.26759.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249834.1.clone.1, %select.21244.3.clone.1)
+ %add.249835.1.clone.1 = f32[1280,1280]{1,0} add(%select.21241.3.clone.1, %multiply.26759.1.clone.1)
+ %multiply.26760.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249835.1.clone.1, %select.21244.3.clone.1)
+ %add.249836.1.clone.1 = f32[1280,1280]{1,0} add(%select.21240.3.clone.1, %multiply.26760.1.clone.1)
+ %multiply.26761.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249836.1.clone.1, %select.21244.3.clone.1)
+ %add.249837.1.clone.1 = f32[1280,1280]{1,0} add(%select.21239.3.clone.1, %multiply.26761.1.clone.1)
+ %multiply.26762.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249837.1.clone.1, %select.21244.3.clone.1)
+ %add.249838.3.clone.1 = f32[1280,1280]{1,0} add(%select.21238.5.clone.1, %multiply.26762.1.clone.1)
+ %multiply.26763.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249838.3.clone.1, %select.21244.3.clone.1)
+ %add.249840.3.clone.1 = f32[1280,1280]{1,0} add(%select.21237.5.clone.1, %multiply.26763.1.clone.1)
+ %multiply.26764.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249840.3.clone.1, %select.21244.3.clone.1)
+ %add.249844.9.clone.1 = f32[1280,1280]{1,0} add(%select.21236.11.clone.1, %multiply.26764.7.clone.1)
+ %multiply.26765.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249844.9.clone.1, %select.21244.3.clone.1)
+ %add.249845.7.clone.1 = f32[1280,1280]{1,0} add(%select.21235.7.clone.1, %multiply.26765.7.clone.1)
+ %multiply.26766.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249845.7.clone.1, %maximum.3715.3.clone.1)
+ %select.21245.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7267.3.clone.1, %multiply.26756.9.clone.1, %multiply.26766.7.clone.1)
+ %multiply.26767.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21245.7.clone.1, %broadcast.244500.640)
+ %clamp.1203.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26767.5.clone.1, %broadcast.244501.384)
+ %multiply.26768.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1203.3.clone.1, %broadcast.244502.1)
+ %constant_164857_1_clone_1 = u32[] constant(1739548742)
+ %broadcast.247340.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164857_1_clone_1), dimensions={}
+ %add.245744.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.247340.44.clone.1)
+ %constant_164864_1_clone_1 = u32[] constant(2992964051)
+ %broadcast.247341.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164864_1_clone_1), dimensions={}
+ %add.245745.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.247341.113.clone.1)
+ %add.245746.35.clone.1 = u32[1280,1280]{1,0} add(%add.245744.37.clone.1, %add.245745.99.clone.1)
+ %shift-left.108460.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245745.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114593.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245745.99.clone.1, %broadcast.244415.6016)
+ %or.114121.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108460.31.clone.1, %shift-right-logical.114593.29.clone.1)
+ %xor.120673.27.clone.1 = u32[1280,1280]{1,0} xor(%add.245746.35.clone.1, %or.114121.29.clone.1)
+ %add.245748.5.clone.1 = u32[1280,1280]{1,0} add(%add.245746.35.clone.1, %xor.120673.27.clone.1)
+ %shift-left.108461.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120673.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114594.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120673.27.clone.1, %broadcast.244417.5760)
+ %or.114122.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108461.9.clone.1, %shift-right-logical.114594.9.clone.1)
+ %xor.120674.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245748.5.clone.1, %or.114122.7.clone.1)
+ %add.245749.3.clone.1 = u32[1280,1280]{1,0} add(%add.245748.5.clone.1, %xor.120674.5.clone.1)
+ %shift-left.108462.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120674.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114595.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120674.5.clone.1, %broadcast.244419.4352)
+ %or.114123.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108462.5.clone.1, %shift-right-logical.114595.5.clone.1)
+ %xor.120676.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245749.3.clone.1, %or.114123.3.clone.1)
+ %add.245750.3.clone.1 = u32[1280,1280]{1,0} add(%add.245749.3.clone.1, %xor.120676.3.clone.1)
+ %add.245751.7.clone.1 = u32[1280,1280]{1,0} add(%add.245750.3.clone.1, %broadcast.247341.113.clone.1)
+ %shift-left.108463.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120676.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114596.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120676.3.clone.1, %broadcast.244418.4352)
+ %or.114124.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108463.5.clone.1, %shift-right-logical.114596.5.clone.1)
+ %xor.120677.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245750.3.clone.1, %or.114124.3.clone.1)
+ %constant_217944_1_clone_1 = u32[] constant(3457905232)
+ %broadcast.247351.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217944_1_clone_1), dimensions={}
+ %add.245752.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120677.3.clone.1, %broadcast.247351.5.clone.1)
+ %add.245753.5.clone.1 = u32[1280,1280]{1,0} add(%add.245751.7.clone.1, %add.245752.5.clone.1)
+ %shift-left.108464.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245752.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114597.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245752.5.clone.1, %broadcast.244416.5760)
+ %or.114126.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108464.9.clone.1, %shift-right-logical.114597.9.clone.1)
+ %xor.120678.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245753.5.clone.1, %or.114126.7.clone.1)
+ %add.245754.3.clone.1 = u32[1280,1280]{1,0} add(%add.245753.5.clone.1, %xor.120678.5.clone.1)
+ %shift-left.108465.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120678.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114598.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120678.5.clone.1, %broadcast.244429.2304)
+ %or.114127.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108465.9.clone.1, %shift-right-logical.114598.9.clone.1)
+ %xor.120679.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245754.3.clone.1, %or.114127.7.clone.1)
+ %add.245755.3.clone.1 = u32[1280,1280]{1,0} add(%add.245754.3.clone.1, %xor.120679.5.clone.1)
+ %shift-left.108466.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120679.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114599.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120679.5.clone.1, %broadcast.244430.4608)
+ %or.114128.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108466.9.clone.1, %shift-right-logical.114599.9.clone.1)
+ %xor.120681.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245755.3.clone.1, %or.114128.7.clone.1)
+ %add.245756.3.clone.1 = u32[1280,1280]{1,0} add(%add.245755.3.clone.1, %xor.120681.5.clone.1)
+ %constant_164866_1_clone_1 = u32[] constant(3457905231)
+ %broadcast.247360.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164866_1_clone_1), dimensions={}
+ %add.245757.7.clone.1 = u32[1280,1280]{1,0} add(%add.245756.3.clone.1, %broadcast.247360.24.clone.1)
+ %shift-left.108468.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120681.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114600.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120681.5.clone.1, %broadcast.244434.2816)
+ %or.114129.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108468.11.clone.1, %shift-right-logical.114600.11.clone.1)
+ %xor.120682.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245756.3.clone.1, %or.114129.9.clone.1)
+ %constant_217945_1_clone_1 = u32[] constant(1739548744)
+ %broadcast.247363.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217945_1_clone_1), dimensions={}
+ %add.245758.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120682.7.clone.1, %broadcast.247363.5.clone.1)
+ %add.245759.5.clone.1 = u32[1280,1280]{1,0} add(%add.245757.7.clone.1, %add.245758.5.clone.1)
+ %shift-left.108469.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245758.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114601.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245758.5.clone.1, %broadcast.244415.6016)
+ %or.114131.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108469.9.clone.1, %shift-right-logical.114601.9.clone.1)
+ %xor.120683.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245759.5.clone.1, %or.114131.7.clone.1)
+ %add.245760.3.clone.1 = u32[1280,1280]{1,0} add(%add.245759.5.clone.1, %xor.120683.5.clone.1)
+ %shift-left.108470.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120683.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114602.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120683.5.clone.1, %broadcast.244417.5760)
+ %or.114132.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108470.9.clone.1, %shift-right-logical.114602.9.clone.1)
+ %xor.120684.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245760.3.clone.1, %or.114132.7.clone.1)
+ %add.245761.3.clone.1 = u32[1280,1280]{1,0} add(%add.245760.3.clone.1, %xor.120684.5.clone.1)
+ %shift-left.108471.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120684.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114603.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120684.5.clone.1, %broadcast.244419.4352)
+ %or.114133.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108471.7.clone.1, %shift-right-logical.114603.7.clone.1)
+ %xor.120685.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245761.3.clone.1, %or.114133.5.clone.1)
+ %add.245762.3.clone.1 = u32[1280,1280]{1,0} add(%add.245761.3.clone.1, %xor.120685.3.clone.1)
+ %add.245763.7.clone.1 = u32[1280,1280]{1,0} add(%add.245762.3.clone.1, %broadcast.247340.44.clone.1)
+ %shift-left.108473.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120685.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114604.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120685.3.clone.1, %broadcast.244418.4352)
+ %or.114134.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108473.7.clone.1, %shift-right-logical.114604.7.clone.1)
+ %xor.120686.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245762.3.clone.1, %or.114134.5.clone.1)
+ %constant_217946_1_clone_1 = u32[] constant(2992964054)
+ %broadcast.247373.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217946_1_clone_1), dimensions={}
+ %add.245764.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120686.3.clone.1, %broadcast.247373.5.clone.1)
+ %add.245765.5.clone.1 = u32[1280,1280]{1,0} add(%add.245763.7.clone.1, %add.245764.5.clone.1)
+ %shift-left.108474.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245764.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114605.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245764.5.clone.1, %broadcast.244416.5760)
+ %or.114136.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108474.9.clone.1, %shift-right-logical.114605.9.clone.1)
+ %xor.120687.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245765.5.clone.1, %or.114136.7.clone.1)
+ %add.245766.3.clone.1 = u32[1280,1280]{1,0} add(%add.245765.5.clone.1, %xor.120687.5.clone.1)
+ %shift-left.108475.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120687.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114606.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120687.5.clone.1, %broadcast.244429.2304)
+ %or.114137.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108475.9.clone.1, %shift-right-logical.114606.9.clone.1)
+ %xor.120688.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245766.3.clone.1, %or.114137.7.clone.1)
+ %add.245767.3.clone.1 = u32[1280,1280]{1,0} add(%add.245766.3.clone.1, %xor.120688.5.clone.1)
+ %shift-left.108476.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120688.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114607.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120688.5.clone.1, %broadcast.244430.4608)
+ %or.114138.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108476.9.clone.1, %shift-right-logical.114607.9.clone.1)
+ %xor.120689.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245767.3.clone.1, %or.114138.7.clone.1)
+ %add.245768.3.clone.1 = u32[1280,1280]{1,0} add(%add.245767.3.clone.1, %xor.120689.5.clone.1)
+ %add.245769.7.clone.1 = u32[1280,1280]{1,0} add(%add.245768.3.clone.1, %broadcast.247341.113.clone.1)
+ %shift-left.108478.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120689.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114608.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120689.5.clone.1, %broadcast.244434.2816)
+ %or.114139.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108478.11.clone.1, %shift-right-logical.114608.11.clone.1)
+ %xor.120691.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245768.3.clone.1, %or.114139.9.clone.1)
+ %constant_217948_1_clone_1 = u32[] constant(3457905235)
+ %broadcast.247385.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217948_1_clone_1), dimensions={}
+ %add.245770.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120691.7.clone.1, %broadcast.247385.5.clone.1)
+ %add.245771.5.clone.1 = u32[1280,1280]{1,0} add(%add.245769.7.clone.1, %add.245770.5.clone.1)
+ %shift-left.108479.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245770.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114609.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245770.5.clone.1, %broadcast.244415.6016)
+ %or.114140.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108479.9.clone.1, %shift-right-logical.114609.9.clone.1)
+ %xor.120692.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245771.5.clone.1, %or.114140.7.clone.1)
+ %add.245772.3.clone.1 = u32[1280,1280]{1,0} add(%add.245771.5.clone.1, %xor.120692.5.clone.1)
+ %shift-left.108480.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120692.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114610.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120692.5.clone.1, %broadcast.244417.5760)
+ %or.114141.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108480.9.clone.1, %shift-right-logical.114610.9.clone.1)
+ %xor.120693.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245772.3.clone.1, %or.114141.7.clone.1)
+ %add.245773.3.clone.1 = u32[1280,1280]{1,0} add(%add.245772.3.clone.1, %xor.120693.5.clone.1)
+ %shift-left.108481.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120693.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114611.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120693.5.clone.1, %broadcast.244419.4352)
+ %or.114142.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108481.5.clone.1, %shift-right-logical.114611.5.clone.1)
+ %xor.120694.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245773.3.clone.1, %or.114142.3.clone.1)
+ %add.245774.3.clone.1 = u32[1280,1280]{1,0} add(%add.245773.3.clone.1, %xor.120694.3.clone.1)
+ %add.245775.17.clone.1 = u32[1280,1280]{1,0} add(%add.245774.3.clone.1, %broadcast.247360.24.clone.1)
+ %shift-left.108483.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120694.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114612.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120694.3.clone.1, %broadcast.244418.4352)
+ %or.114143.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108483.5.clone.1, %shift-right-logical.114612.5.clone.1)
+ %xor.120696.15.clone.1 = u32[1280,1280]{1,0} xor(%add.245774.3.clone.1, %or.114143.3.clone.1)
+ %constant_217949_1_clone_1 = u32[] constant(1739548747)
+ %broadcast.247395.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217949_1_clone_1), dimensions={}
+ %add.245776.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120696.15.clone.1, %broadcast.247395.19.clone.1)
+ %xor.120697.17.clone.1 = u32[1280,1280]{1,0} xor(%add.245775.17.clone.1, %add.245776.19.clone.1)
+ %shift-right-logical.114613.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120697.17.clone.1, %broadcast.244468.1920)
+ %or.114144.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114613.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5702.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114144.13.clone.1)
+ %add.245777.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5702.11.clone.1, %broadcast.244470.1152)
+ %multiply.25933.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245777.9.clone.1, %broadcast.244471.896)
+ %add.245779.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.25933.7.clone.1, %broadcast.244408.1024)
+ %maximum.3634.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.245779.5.clone.1)
+ %abs.1506.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3634.3.clone.1)
+ %compare.7160.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1506.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.25935.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3634.3.clone.1, %broadcast.244476.1152)
+ %negate.4517.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3634.3.clone.1)
+ %multiply.25936.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3634.3.clone.1, %negate.4517.5.clone.1)
+ %log-plus-one.1506.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.25936.5.clone.1)
+ %negate.4518.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1506.3.clone.1)
+ %compare.7161.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4518.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.20631.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.20632.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.20633.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.20634.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.20635.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.20636.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.20637.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.20638.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.20639.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.245782.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4518.4.clone.1, %broadcast.244496.640)
+ %sqrt.1506.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4518.4.clone.1)
+ %add.245783.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1506.5.clone.1, %broadcast.244498.640)
+ %select.20640.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7161.3.clone.1, %add.245782.5.clone.1, %add.245783.5.clone.1)
+ %multiply.25938.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20639.3.clone.1, %select.20640.3.clone.1)
+ %add.245784.1.clone.1 = f32[1280,1280]{1,0} add(%select.20638.3.clone.1, %multiply.25938.1.clone.1)
+ %multiply.25939.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245784.1.clone.1, %select.20640.3.clone.1)
+ %add.245785.1.clone.1 = f32[1280,1280]{1,0} add(%select.20637.3.clone.1, %multiply.25939.1.clone.1)
+ %multiply.25941.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245785.1.clone.1, %select.20640.3.clone.1)
+ %add.245787.1.clone.1 = f32[1280,1280]{1,0} add(%select.20636.3.clone.1, %multiply.25941.1.clone.1)
+ %multiply.25942.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245787.1.clone.1, %select.20640.3.clone.1)
+ %add.245788.1.clone.1 = f32[1280,1280]{1,0} add(%select.20635.3.clone.1, %multiply.25942.1.clone.1)
+ %multiply.25944.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245788.1.clone.1, %select.20640.3.clone.1)
+ %add.245789.3.clone.1 = f32[1280,1280]{1,0} add(%select.20634.5.clone.1, %multiply.25944.1.clone.1)
+ %multiply.25945.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245789.3.clone.1, %select.20640.3.clone.1)
+ %add.245790.3.clone.1 = f32[1280,1280]{1,0} add(%select.20633.5.clone.1, %multiply.25945.1.clone.1)
+ %multiply.25947.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245790.3.clone.1, %select.20640.3.clone.1)
+ %add.245792.9.clone.1 = f32[1280,1280]{1,0} add(%select.20632.11.clone.1, %multiply.25947.7.clone.1)
+ %multiply.25949.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245792.9.clone.1, %select.20640.3.clone.1)
+ %add.245793.7.clone.1 = f32[1280,1280]{1,0} add(%select.20631.7.clone.1, %multiply.25949.7.clone.1)
+ %multiply.25950.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245793.7.clone.1, %maximum.3634.3.clone.1)
+ %select.20641.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7160.3.clone.1, %multiply.25935.9.clone.1, %multiply.25950.7.clone.1)
+ %multiply.25952.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20641.7.clone.1, %broadcast.244500.640)
+ %clamp.1150.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.25952.5.clone.1, %broadcast.244501.384)
+ %multiply.25954.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1150.3.clone.1, %broadcast.244502.1)
+ %constant_189020_1_clone_1 = u32[] constant(2412923323)
+ %broadcast.257777.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_189020_1_clone_1), dimensions={}
+ %add.251705.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.257777.44.clone.1)
+ %constant_189027_1_clone_1 = u32[] constant(1171010214)
+ %broadcast.257779.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_189027_1_clone_1), dimensions={}
+ %add.251706.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.257779.113.clone.1)
+ %add.251707.35.clone.1 = u32[1280,1280]{1,0} add(%add.251705.37.clone.1, %add.251706.99.clone.1)
+ %shift-left.111060.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251706.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117329.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251706.99.clone.1, %broadcast.244415.6016)
+ %or.116853.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111060.31.clone.1, %shift-right-logical.117329.29.clone.1)
+ %xor.123410.27.clone.1 = u32[1280,1280]{1,0} xor(%add.251707.35.clone.1, %or.116853.29.clone.1)
+ %add.251709.5.clone.1 = u32[1280,1280]{1,0} add(%add.251707.35.clone.1, %xor.123410.27.clone.1)
+ %shift-left.111061.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123410.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117330.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123410.27.clone.1, %broadcast.244417.5760)
+ %or.116854.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111061.9.clone.1, %shift-right-logical.117330.9.clone.1)
+ %xor.123411.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251709.5.clone.1, %or.116854.7.clone.1)
+ %add.251710.3.clone.1 = u32[1280,1280]{1,0} add(%add.251709.5.clone.1, %xor.123411.5.clone.1)
+ %shift-left.111062.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123411.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117331.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123411.5.clone.1, %broadcast.244419.4352)
+ %or.116855.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111062.5.clone.1, %shift-right-logical.117331.5.clone.1)
+ %xor.123412.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251710.3.clone.1, %or.116855.3.clone.1)
+ %add.251711.3.clone.1 = u32[1280,1280]{1,0} add(%add.251710.3.clone.1, %xor.123412.3.clone.1)
+ %add.251712.7.clone.1 = u32[1280,1280]{1,0} add(%add.251711.3.clone.1, %broadcast.257779.113.clone.1)
+ %shift-left.111063.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123412.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117332.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123412.3.clone.1, %broadcast.244418.4352)
+ %or.116856.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111063.5.clone.1, %shift-right-logical.117332.5.clone.1)
+ %xor.123413.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251711.3.clone.1, %or.116856.3.clone.1)
+ %constant_218615_1_clone_1 = u32[] constant(3520028872)
+ %broadcast.257789.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218615_1_clone_1), dimensions={}
+ %add.251713.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123413.3.clone.1, %broadcast.257789.5.clone.1)
+ %add.251715.5.clone.1 = u32[1280,1280]{1,0} add(%add.251712.7.clone.1, %add.251713.5.clone.1)
+ %shift-left.111064.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251713.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117333.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251713.5.clone.1, %broadcast.244416.5760)
+ %or.116857.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111064.9.clone.1, %shift-right-logical.117333.9.clone.1)
+ %xor.123414.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251715.5.clone.1, %or.116857.7.clone.1)
+ %add.251718.3.clone.1 = u32[1280,1280]{1,0} add(%add.251715.5.clone.1, %xor.123414.5.clone.1)
+ %shift-left.111065.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123414.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117334.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123414.5.clone.1, %broadcast.244429.2304)
+ %or.116858.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111065.9.clone.1, %shift-right-logical.117334.9.clone.1)
+ %xor.123416.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251718.3.clone.1, %or.116858.7.clone.1)
+ %add.251719.3.clone.1 = u32[1280,1280]{1,0} add(%add.251718.3.clone.1, %xor.123416.5.clone.1)
+ %shift-left.111066.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123416.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117335.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123416.5.clone.1, %broadcast.244430.4608)
+ %or.116859.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111066.9.clone.1, %shift-right-logical.117335.9.clone.1)
+ %xor.123417.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251719.3.clone.1, %or.116859.7.clone.1)
+ %add.251720.3.clone.1 = u32[1280,1280]{1,0} add(%add.251719.3.clone.1, %xor.123417.5.clone.1)
+ %constant_189029_1_clone_1 = u32[] constant(3520028871)
+ %broadcast.257796.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_189029_1_clone_1), dimensions={}
+ %add.251721.7.clone.1 = u32[1280,1280]{1,0} add(%add.251720.3.clone.1, %broadcast.257796.24.clone.1)
+ %shift-left.111067.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123417.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117336.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123417.5.clone.1, %broadcast.244434.2816)
+ %or.116860.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111067.11.clone.1, %shift-right-logical.117336.11.clone.1)
+ %xor.123418.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251720.3.clone.1, %or.116860.9.clone.1)
+ %constant_218616_1_clone_1 = u32[] constant(2412923325)
+ %broadcast.257799.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218616_1_clone_1), dimensions={}
+ %add.251722.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123418.7.clone.1, %broadcast.257799.5.clone.1)
+ %add.251723.5.clone.1 = u32[1280,1280]{1,0} add(%add.251721.7.clone.1, %add.251722.5.clone.1)
+ %shift-left.111068.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251722.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117337.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251722.5.clone.1, %broadcast.244415.6016)
+ %or.116861.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111068.9.clone.1, %shift-right-logical.117337.9.clone.1)
+ %xor.123419.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251723.5.clone.1, %or.116861.7.clone.1)
+ %add.251724.3.clone.1 = u32[1280,1280]{1,0} add(%add.251723.5.clone.1, %xor.123419.5.clone.1)
+ %shift-left.111069.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123419.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117338.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123419.5.clone.1, %broadcast.244417.5760)
+ %or.116862.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111069.9.clone.1, %shift-right-logical.117338.9.clone.1)
+ %xor.123421.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251724.3.clone.1, %or.116862.7.clone.1)
+ %add.251725.3.clone.1 = u32[1280,1280]{1,0} add(%add.251724.3.clone.1, %xor.123421.5.clone.1)
+ %shift-left.111070.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123421.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117339.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123421.5.clone.1, %broadcast.244419.4352)
+ %or.116863.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111070.7.clone.1, %shift-right-logical.117339.7.clone.1)
+ %xor.123422.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251725.3.clone.1, %or.116863.5.clone.1)
+ %add.251726.3.clone.1 = u32[1280,1280]{1,0} add(%add.251725.3.clone.1, %xor.123422.3.clone.1)
+ %add.251727.7.clone.1 = u32[1280,1280]{1,0} add(%add.251726.3.clone.1, %broadcast.257777.44.clone.1)
+ %shift-left.111071.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123422.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117340.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123422.3.clone.1, %broadcast.244418.4352)
+ %or.116864.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111071.7.clone.1, %shift-right-logical.117340.7.clone.1)
+ %xor.123423.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251726.3.clone.1, %or.116864.5.clone.1)
+ %constant_218617_1_clone_1 = u32[] constant(1171010217)
+ %broadcast.257809.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218617_1_clone_1), dimensions={}
+ %add.251728.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123423.3.clone.1, %broadcast.257809.5.clone.1)
+ %add.251729.5.clone.1 = u32[1280,1280]{1,0} add(%add.251727.7.clone.1, %add.251728.5.clone.1)
+ %shift-left.111072.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251728.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117341.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251728.5.clone.1, %broadcast.244416.5760)
+ %or.116865.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111072.9.clone.1, %shift-right-logical.117341.9.clone.1)
+ %xor.123424.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251729.5.clone.1, %or.116865.7.clone.1)
+ %add.251730.3.clone.1 = u32[1280,1280]{1,0} add(%add.251729.5.clone.1, %xor.123424.5.clone.1)
+ %shift-left.111073.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123424.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117342.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123424.5.clone.1, %broadcast.244429.2304)
+ %or.116866.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111073.9.clone.1, %shift-right-logical.117342.9.clone.1)
+ %xor.123426.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251730.3.clone.1, %or.116866.7.clone.1)
+ %add.251731.3.clone.1 = u32[1280,1280]{1,0} add(%add.251730.3.clone.1, %xor.123426.5.clone.1)
+ %shift-left.111074.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123426.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117343.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123426.5.clone.1, %broadcast.244430.4608)
+ %or.116867.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111074.9.clone.1, %shift-right-logical.117343.9.clone.1)
+ %xor.123427.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251731.3.clone.1, %or.116867.7.clone.1)
+ %add.251732.3.clone.1 = u32[1280,1280]{1,0} add(%add.251731.3.clone.1, %xor.123427.5.clone.1)
+ %add.251733.7.clone.1 = u32[1280,1280]{1,0} add(%add.251732.3.clone.1, %broadcast.257779.113.clone.1)
+ %shift-left.111075.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123427.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117344.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123427.5.clone.1, %broadcast.244434.2816)
+ %or.116868.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111075.11.clone.1, %shift-right-logical.117344.11.clone.1)
+ %xor.123428.7.clone.1 = u32[1280,1280]{1,0} xor(%add.251732.3.clone.1, %or.116868.9.clone.1)
+ %constant_218618_1_clone_1 = u32[] constant(3520028875)
+ %broadcast.257819.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218618_1_clone_1), dimensions={}
+ %add.251734.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123428.7.clone.1, %broadcast.257819.5.clone.1)
+ %add.251735.5.clone.1 = u32[1280,1280]{1,0} add(%add.251733.7.clone.1, %add.251734.5.clone.1)
+ %shift-left.111076.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.251734.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117345.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.251734.5.clone.1, %broadcast.244415.6016)
+ %or.116869.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111076.9.clone.1, %shift-right-logical.117345.9.clone.1)
+ %xor.123429.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251735.5.clone.1, %or.116869.7.clone.1)
+ %add.251736.3.clone.1 = u32[1280,1280]{1,0} add(%add.251735.5.clone.1, %xor.123429.5.clone.1)
+ %shift-left.111077.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123429.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117346.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123429.5.clone.1, %broadcast.244417.5760)
+ %or.116871.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111077.9.clone.1, %shift-right-logical.117346.9.clone.1)
+ %xor.123431.5.clone.1 = u32[1280,1280]{1,0} xor(%add.251736.3.clone.1, %or.116871.7.clone.1)
+ %add.251737.3.clone.1 = u32[1280,1280]{1,0} add(%add.251736.3.clone.1, %xor.123431.5.clone.1)
+ %shift-left.111078.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123431.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117347.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123431.5.clone.1, %broadcast.244419.4352)
+ %or.116872.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111078.5.clone.1, %shift-right-logical.117347.5.clone.1)
+ %xor.123432.3.clone.1 = u32[1280,1280]{1,0} xor(%add.251737.3.clone.1, %or.116872.3.clone.1)
+ %add.251738.3.clone.1 = u32[1280,1280]{1,0} add(%add.251737.3.clone.1, %xor.123432.3.clone.1)
+ %add.251739.17.clone.1 = u32[1280,1280]{1,0} add(%add.251738.3.clone.1, %broadcast.257796.24.clone.1)
+ %shift-left.111079.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123432.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117348.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123432.3.clone.1, %broadcast.244418.4352)
+ %or.116873.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111079.5.clone.1, %shift-right-logical.117348.5.clone.1)
+ %xor.123433.15.clone.1 = u32[1280,1280]{1,0} xor(%add.251738.3.clone.1, %or.116873.3.clone.1)
+ %constant_218619_1_clone_1 = u32[] constant(2412923328)
+ %broadcast.257829.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218619_1_clone_1), dimensions={}
+ %add.251740.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123433.15.clone.1, %broadcast.257829.19.clone.1)
+ %xor.123434.17.clone.1 = u32[1280,1280]{1,0} xor(%add.251739.17.clone.1, %add.251740.19.clone.1)
+ %shift-right-logical.117349.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123434.17.clone.1, %broadcast.244468.1920)
+ %or.116874.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117349.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5821.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.116874.13.clone.1)
+ %add.251741.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5821.11.clone.1, %broadcast.244470.1152)
+ %multiply.27158.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251741.9.clone.1, %broadcast.244471.896)
+ %add.251742.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27158.7.clone.1, %broadcast.244408.1024)
+ %maximum.3753.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.251742.5.clone.1)
+ %abs.1585.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3753.3.clone.1)
+ %compare.7332.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1585.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.27159.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3753.3.clone.1, %broadcast.244476.1152)
+ %negate.4675.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3753.3.clone.1)
+ %multiply.27160.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3753.3.clone.1, %negate.4675.5.clone.1)
+ %log-plus-one.1585.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27160.5.clone.1)
+ %negate.4676.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1585.3.clone.1)
+ %compare.7333.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4676.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21526.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21528.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21529.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21531.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21532.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21534.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21535.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21537.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21538.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.251743.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4676.4.clone.1, %broadcast.244496.640)
+ %sqrt.1585.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4676.4.clone.1)
+ %add.251744.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1585.5.clone.1, %broadcast.244498.640)
+ %select.21540.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7333.3.clone.1, %add.251743.5.clone.1, %add.251744.5.clone.1)
+ %multiply.27161.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21538.3.clone.1, %select.21540.3.clone.1)
+ %add.251745.1.clone.1 = f32[1280,1280]{1,0} add(%select.21537.3.clone.1, %multiply.27161.1.clone.1)
+ %multiply.27162.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251745.1.clone.1, %select.21540.3.clone.1)
+ %add.251746.1.clone.1 = f32[1280,1280]{1,0} add(%select.21535.3.clone.1, %multiply.27162.1.clone.1)
+ %multiply.27163.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251746.1.clone.1, %select.21540.3.clone.1)
+ %add.251747.1.clone.1 = f32[1280,1280]{1,0} add(%select.21534.3.clone.1, %multiply.27163.1.clone.1)
+ %multiply.27164.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251747.1.clone.1, %select.21540.3.clone.1)
+ %add.251748.1.clone.1 = f32[1280,1280]{1,0} add(%select.21532.3.clone.1, %multiply.27164.1.clone.1)
+ %multiply.27165.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251748.1.clone.1, %select.21540.3.clone.1)
+ %add.251749.3.clone.1 = f32[1280,1280]{1,0} add(%select.21531.5.clone.1, %multiply.27165.1.clone.1)
+ %multiply.27166.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.251749.3.clone.1, %select.21540.3.clone.1)
+ %add.251750.3.clone.1 = f32[1280,1280]{1,0} add(%select.21529.5.clone.1, %multiply.27166.1.clone.1)
+ %multiply.27167.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251750.3.clone.1, %select.21540.3.clone.1)
+ %add.251751.9.clone.1 = f32[1280,1280]{1,0} add(%select.21528.11.clone.1, %multiply.27167.7.clone.1)
+ %multiply.27168.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251751.9.clone.1, %select.21540.3.clone.1)
+ %add.251752.7.clone.1 = f32[1280,1280]{1,0} add(%select.21526.7.clone.1, %multiply.27168.7.clone.1)
+ %multiply.27169.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.251752.7.clone.1, %maximum.3753.3.clone.1)
+ %select.21541.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7332.3.clone.1, %multiply.27159.9.clone.1, %multiply.27169.7.clone.1)
+ %multiply.27170.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21541.7.clone.1, %broadcast.244500.640)
+ %clamp.1229.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27170.5.clone.1, %broadcast.244501.384)
+ %multiply.27171.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1229.3.clone.1, %broadcast.244502.1)
+ %constant_164647_1_clone_1 = u32[] constant(2003854816)
+ %broadcast.247235.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164647_1_clone_1), dimensions={}
+ %add.245697.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.247235.44.clone.1)
+ %constant_164654_1_clone_1 = u32[] constant(2270062255)
+ %broadcast.247236.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164654_1_clone_1), dimensions={}
+ %add.245700.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.247236.113.clone.1)
+ %add.245701.35.clone.1 = u32[1280,1280]{1,0} add(%add.245697.37.clone.1, %add.245700.99.clone.1)
+ %shift-left.108440.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245700.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114572.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245700.99.clone.1, %broadcast.244415.6016)
+ %or.114096.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108440.31.clone.1, %shift-right-logical.114572.29.clone.1)
+ %xor.120648.27.clone.1 = u32[1280,1280]{1,0} xor(%add.245701.35.clone.1, %or.114096.29.clone.1)
+ %add.245702.5.clone.1 = u32[1280,1280]{1,0} add(%add.245701.35.clone.1, %xor.120648.27.clone.1)
+ %shift-left.108441.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120648.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114573.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120648.27.clone.1, %broadcast.244417.5760)
+ %or.114097.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108441.9.clone.1, %shift-right-logical.114573.9.clone.1)
+ %xor.120649.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245702.5.clone.1, %or.114097.7.clone.1)
+ %add.245703.3.clone.1 = u32[1280,1280]{1,0} add(%add.245702.5.clone.1, %xor.120649.5.clone.1)
+ %shift-left.108442.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120649.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114574.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120649.5.clone.1, %broadcast.244419.4352)
+ %or.114098.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108442.5.clone.1, %shift-right-logical.114574.5.clone.1)
+ %xor.120651.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245703.3.clone.1, %or.114098.3.clone.1)
+ %add.245704.3.clone.1 = u32[1280,1280]{1,0} add(%add.245703.3.clone.1, %xor.120651.3.clone.1)
+ %add.245705.7.clone.1 = u32[1280,1280]{1,0} add(%add.245704.3.clone.1, %broadcast.247236.113.clone.1)
+ %shift-left.108443.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120651.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114575.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120651.3.clone.1, %broadcast.244418.4352)
+ %or.114099.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108443.5.clone.1, %shift-right-logical.114575.5.clone.1)
+ %xor.120652.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245704.3.clone.1, %or.114099.3.clone.1)
+ %constant_217939_1_clone_1 = u32[] constant(3958315158)
+ %broadcast.247246.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217939_1_clone_1), dimensions={}
+ %add.245706.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120652.3.clone.1, %broadcast.247246.5.clone.1)
+ %add.245707.5.clone.1 = u32[1280,1280]{1,0} add(%add.245705.7.clone.1, %add.245706.5.clone.1)
+ %shift-left.108444.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245706.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114576.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245706.5.clone.1, %broadcast.244416.5760)
+ %or.114101.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108444.9.clone.1, %shift-right-logical.114576.9.clone.1)
+ %xor.120653.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245707.5.clone.1, %or.114101.7.clone.1)
+ %add.245708.3.clone.1 = u32[1280,1280]{1,0} add(%add.245707.5.clone.1, %xor.120653.5.clone.1)
+ %shift-left.108445.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120653.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114577.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120653.5.clone.1, %broadcast.244429.2304)
+ %or.114102.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108445.9.clone.1, %shift-right-logical.114577.9.clone.1)
+ %xor.120654.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245708.3.clone.1, %or.114102.7.clone.1)
+ %add.245709.3.clone.1 = u32[1280,1280]{1,0} add(%add.245708.3.clone.1, %xor.120654.5.clone.1)
+ %shift-left.108446.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120654.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114578.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120654.5.clone.1, %broadcast.244430.4608)
+ %or.114103.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108446.9.clone.1, %shift-right-logical.114578.9.clone.1)
+ %xor.120656.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245709.3.clone.1, %or.114103.7.clone.1)
+ %add.245710.3.clone.1 = u32[1280,1280]{1,0} add(%add.245709.3.clone.1, %xor.120656.5.clone.1)
+ %constant_164656_1_clone_1 = u32[] constant(3958315157)
+ %broadcast.247253.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164656_1_clone_1), dimensions={}
+ %add.245711.7.clone.1 = u32[1280,1280]{1,0} add(%add.245710.3.clone.1, %broadcast.247253.24.clone.1)
+ %shift-left.108447.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120656.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114579.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120656.5.clone.1, %broadcast.244434.2816)
+ %or.114104.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108447.11.clone.1, %shift-right-logical.114579.11.clone.1)
+ %xor.120657.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245710.3.clone.1,
%or.114104.9.clone.1) + %constant_217940_1_clone_1 = u32[] constant(2003854818) + %broadcast.247256.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217940_1_clone_1), dimensions={} + %add.245712.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120657.7.clone.1, %broadcast.247256.5.clone.1) + %add.245713.5.clone.1 = u32[1280,1280]{1,0} add(%add.245711.7.clone.1, %add.245712.5.clone.1) + %shift-left.108448.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245712.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114580.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245712.5.clone.1, %broadcast.244415.6016) + %or.114106.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108448.9.clone.1, %shift-right-logical.114580.9.clone.1) + %xor.120658.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245713.5.clone.1, %or.114106.7.clone.1) + %add.245714.3.clone.1 = u32[1280,1280]{1,0} add(%add.245713.5.clone.1, %xor.120658.5.clone.1) + %shift-left.108449.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120658.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114581.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120658.5.clone.1, %broadcast.244417.5760) + %or.114107.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108449.9.clone.1, %shift-right-logical.114581.9.clone.1) + %xor.120659.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245714.3.clone.1, %or.114107.7.clone.1) + %add.245715.3.clone.1 = u32[1280,1280]{1,0} add(%add.245714.3.clone.1, %xor.120659.5.clone.1) + %shift-left.108450.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120659.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114582.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120659.5.clone.1, %broadcast.244419.4352) + %or.114108.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108450.7.clone.1, %shift-right-logical.114582.7.clone.1) + %xor.120660.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245715.3.clone.1, %or.114108.5.clone.1) + %add.245716.3.clone.1 = u32[1280,1280]{1,0} add(%add.245715.3.clone.1, %xor.120660.3.clone.1) + %add.245717.7.clone.1 = u32[1280,1280]{1,0} add(%add.245716.3.clone.1, %broadcast.247235.44.clone.1) + %shift-left.108451.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120660.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114583.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120660.3.clone.1, %broadcast.244418.4352) + %or.114109.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108451.7.clone.1, %shift-right-logical.114583.7.clone.1) + %xor.120661.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245716.3.clone.1, %or.114109.5.clone.1) + %constant_217941_1_clone_1 = u32[] constant(2270062258) + %broadcast.247266.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217941_1_clone_1), dimensions={} + %add.245718.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120661.3.clone.1, %broadcast.247266.5.clone.1) + %add.245719.5.clone.1 = u32[1280,1280]{1,0} add(%add.245717.7.clone.1, %add.245718.5.clone.1) + %shift-left.108452.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245718.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114584.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245718.5.clone.1, %broadcast.244416.5760) + %or.114111.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108452.9.clone.1, %shift-right-logical.114584.9.clone.1) + %xor.120662.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245719.5.clone.1, %or.114111.7.clone.1) + %add.245720.3.clone.1 = u32[1280,1280]{1,0} add(%add.245719.5.clone.1, %xor.120662.5.clone.1) + %shift-left.108453.9.clone.1 = u32[1280,1280]{1,0} 
shift-left(%xor.120662.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114585.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120662.5.clone.1, %broadcast.244429.2304) + %or.114112.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108453.9.clone.1, %shift-right-logical.114585.9.clone.1) + %xor.120663.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245720.3.clone.1, %or.114112.7.clone.1) + %add.245721.3.clone.1 = u32[1280,1280]{1,0} add(%add.245720.3.clone.1, %xor.120663.5.clone.1) + %shift-left.108454.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120663.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114586.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120663.5.clone.1, %broadcast.244430.4608) + %or.114113.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108454.9.clone.1, %shift-right-logical.114586.9.clone.1) + %xor.120664.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245721.3.clone.1, %or.114113.7.clone.1) + %add.245722.3.clone.1 = u32[1280,1280]{1,0} add(%add.245721.3.clone.1, %xor.120664.5.clone.1) + %add.245723.7.clone.1 = u32[1280,1280]{1,0} add(%add.245722.3.clone.1, %broadcast.247236.113.clone.1) + %shift-left.108455.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120664.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114587.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120664.5.clone.1, %broadcast.244434.2816) + %or.114114.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108455.11.clone.1, %shift-right-logical.114587.11.clone.1) + %xor.120666.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245722.3.clone.1, %or.114114.9.clone.1) + %constant_217942_1_clone_1 = u32[] constant(3958315161) + %broadcast.247278.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217942_1_clone_1), dimensions={} + %add.245724.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120666.7.clone.1, %broadcast.247278.5.clone.1) + %add.245725.5.clone.1 = u32[1280,1280]{1,0} add(%add.245723.7.clone.1, %add.245724.5.clone.1) + %shift-left.108456.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245724.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114588.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245724.5.clone.1, %broadcast.244415.6016) + %or.114115.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108456.9.clone.1, %shift-right-logical.114588.9.clone.1) + %xor.120667.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245725.5.clone.1, %or.114115.7.clone.1) + %add.245726.3.clone.1 = u32[1280,1280]{1,0} add(%add.245725.5.clone.1, %xor.120667.5.clone.1) + %shift-left.108457.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120667.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114589.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120667.5.clone.1, %broadcast.244417.5760) + %or.114116.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108457.9.clone.1, %shift-right-logical.114589.9.clone.1) + %xor.120668.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245726.3.clone.1, %or.114116.7.clone.1) + %add.245728.3.clone.1 = u32[1280,1280]{1,0} add(%add.245726.3.clone.1, %xor.120668.5.clone.1) + %shift-left.108458.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120668.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114590.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120668.5.clone.1, %broadcast.244419.4352) + %or.114117.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108458.5.clone.1, %shift-right-logical.114590.5.clone.1) + %xor.120669.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245728.3.clone.1, %or.114117.3.clone.1) + %add.245729.3.clone.1 = 
u32[1280,1280]{1,0} add(%add.245728.3.clone.1, %xor.120669.3.clone.1) + %add.245730.17.clone.1 = u32[1280,1280]{1,0} add(%add.245729.3.clone.1, %broadcast.247253.24.clone.1) + %shift-left.108459.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120669.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114591.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120669.3.clone.1, %broadcast.244418.4352) + %or.114118.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108459.5.clone.1, %shift-right-logical.114591.5.clone.1) + %xor.120671.15.clone.1 = u32[1280,1280]{1,0} xor(%add.245729.3.clone.1, %or.114118.3.clone.1) + %constant_217943_1_clone_1 = u32[] constant(2003854821) + %broadcast.247298.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217943_1_clone_1), dimensions={} + %add.245731.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120671.15.clone.1, %broadcast.247298.19.clone.1) + %xor.120672.17.clone.1 = u32[1280,1280]{1,0} xor(%add.245730.17.clone.1, %add.245731.19.clone.1) + %shift-right-logical.114592.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120672.17.clone.1, %broadcast.244468.1920) + %or.114119.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114592.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5701.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114119.13.clone.1) + %add.245732.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5701.11.clone.1, %broadcast.244470.1152) + %multiply.25916.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245732.9.clone.1, %broadcast.244471.896) + %add.245733.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.25916.7.clone.1, %broadcast.244408.1024) + %maximum.3633.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.245733.5.clone.1) + %abs.1505.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3633.3.clone.1) + %compare.7158.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1505.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.25917.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3633.3.clone.1, %broadcast.244476.1152) + %negate.4515.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3633.3.clone.1) + %multiply.25918.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3633.3.clone.1, %negate.4515.5.clone.1) + %log-plus-one.1505.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.25918.5.clone.1) + %negate.4516.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1505.3.clone.1) + %compare.7159.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4516.4.clone.1, %broadcast.244477.384), direction=LT + %select.20620.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20621.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20622.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20623.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20624.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20625.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.20626.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20627.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + 
%select.20628.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.245734.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4516.4.clone.1, %broadcast.244496.640) + %sqrt.1505.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4516.4.clone.1) + %add.245735.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1505.5.clone.1, %broadcast.244498.640) + %select.20629.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7159.3.clone.1, %add.245734.5.clone.1, %add.245735.5.clone.1) + %multiply.25919.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20628.3.clone.1, %select.20629.3.clone.1) + %add.245736.1.clone.1 = f32[1280,1280]{1,0} add(%select.20627.3.clone.1, %multiply.25919.1.clone.1) + %multiply.25920.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245736.1.clone.1, %select.20629.3.clone.1) + %add.245737.1.clone.1 = f32[1280,1280]{1,0} add(%select.20626.3.clone.1, %multiply.25920.1.clone.1) + %multiply.25921.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245737.1.clone.1, %select.20629.3.clone.1) + %add.245738.1.clone.1 = f32[1280,1280]{1,0} add(%select.20625.3.clone.1, %multiply.25921.1.clone.1) + %multiply.25922.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245738.1.clone.1, %select.20629.3.clone.1) + %add.245739.1.clone.1 = f32[1280,1280]{1,0} add(%select.20624.3.clone.1, %multiply.25922.1.clone.1) + %multiply.25923.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245739.1.clone.1, %select.20629.3.clone.1) + %add.245740.3.clone.1 = f32[1280,1280]{1,0} add(%select.20623.5.clone.1, %multiply.25923.1.clone.1) + %multiply.25924.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245740.3.clone.1, %select.20629.3.clone.1) + %add.245741.3.clone.1 = f32[1280,1280]{1,0} add(%select.20622.5.clone.1, %multiply.25924.1.clone.1) + %multiply.25926.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245741.3.clone.1, %select.20629.3.clone.1) + %add.245742.9.clone.1 = f32[1280,1280]{1,0} add(%select.20621.11.clone.1, %multiply.25926.7.clone.1) + %multiply.25927.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245742.9.clone.1, %select.20629.3.clone.1) + %add.245743.7.clone.1 = f32[1280,1280]{1,0} add(%select.20620.7.clone.1, %multiply.25927.7.clone.1) + %multiply.25929.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245743.7.clone.1, %maximum.3633.3.clone.1) + %select.20630.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7158.3.clone.1, %multiply.25917.9.clone.1, %multiply.25929.7.clone.1) + %multiply.25930.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20630.7.clone.1, %broadcast.244500.640) + %clamp.1149.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.25930.5.clone.1, %broadcast.244501.384) + %multiply.25932.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1149.3.clone.1, %broadcast.244502.1) + %constant_180706_1_clone_1 = u32[] constant(951851306) + %broadcast.254166.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_180706_1_clone_1), dimensions={} + %add.249648.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.254166.44.clone.1) + %constant_180713_1_clone_1 = u32[] constant(3477382136) + %broadcast.254167.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_180713_1_clone_1), dimensions={} + %add.249650.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.254167.113.clone.1) + %add.249651.35.clone.1 = u32[1280,1280]{1,0} add(%add.249648.37.clone.1, %add.249650.99.clone.1) + %shift-left.110144.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249650.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.116377.29.clone.1 = 
u32[1280,1280]{1,0} shift-right-logical(%add.249650.99.clone.1, %broadcast.244415.6016) + %or.115906.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110144.31.clone.1, %shift-right-logical.116377.29.clone.1) + %xor.122457.27.clone.1 = u32[1280,1280]{1,0} xor(%add.249651.35.clone.1, %or.115906.29.clone.1) + %add.249652.5.clone.1 = u32[1280,1280]{1,0} add(%add.249651.35.clone.1, %xor.122457.27.clone.1) + %shift-left.110146.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122457.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.116379.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122457.27.clone.1, %broadcast.244417.5760) + %or.115907.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110146.9.clone.1, %shift-right-logical.116379.9.clone.1) + %xor.122458.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249652.5.clone.1, %or.115907.7.clone.1) + %add.249653.3.clone.1 = u32[1280,1280]{1,0} add(%add.249652.5.clone.1, %xor.122458.5.clone.1) + %shift-left.110147.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122458.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116380.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122458.5.clone.1, %broadcast.244419.4352) + %or.115908.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110147.5.clone.1, %shift-right-logical.116380.5.clone.1) + %xor.122460.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249653.3.clone.1, %or.115908.3.clone.1) + %add.249654.3.clone.1 = u32[1280,1280]{1,0} add(%add.249653.3.clone.1, %xor.122460.3.clone.1) + %add.249656.7.clone.1 = u32[1280,1280]{1,0} add(%add.249654.3.clone.1, %broadcast.254167.113.clone.1) + %shift-left.110148.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122460.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116381.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122460.3.clone.1, %broadcast.244418.4352) + %or.115909.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110148.5.clone.1, %shift-right-logical.116381.5.clone.1) + %xor.122461.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249654.3.clone.1, %or.115909.3.clone.1) + %constant_218389_1_clone_1 = u32[] constant(3962151177) + %broadcast.254177.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218389_1_clone_1), dimensions={} + %add.249660.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122461.3.clone.1, %broadcast.254177.5.clone.1) + %add.249661.5.clone.1 = u32[1280,1280]{1,0} add(%add.249656.7.clone.1, %add.249660.5.clone.1) + %shift-left.110149.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249660.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116382.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249660.5.clone.1, %broadcast.244416.5760) + %or.115910.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110149.9.clone.1, %shift-right-logical.116382.9.clone.1) + %xor.122462.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249661.5.clone.1, %or.115910.7.clone.1) + %add.249662.3.clone.1 = u32[1280,1280]{1,0} add(%add.249661.5.clone.1, %xor.122462.5.clone.1) + %shift-left.110151.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122462.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116384.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122462.5.clone.1, %broadcast.244429.2304) + %or.115911.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110151.9.clone.1, %shift-right-logical.116384.9.clone.1) + %xor.122463.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249662.3.clone.1, %or.115911.7.clone.1) + %add.249663.3.clone.1 = u32[1280,1280]{1,0} add(%add.249662.3.clone.1, %xor.122463.5.clone.1) + %shift-left.110152.9.clone.1 = 
u32[1280,1280]{1,0} shift-left(%xor.122463.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116385.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122463.5.clone.1, %broadcast.244430.4608) + %or.115912.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110152.9.clone.1, %shift-right-logical.116385.9.clone.1) + %xor.122465.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249663.3.clone.1, %or.115912.7.clone.1) + %add.249665.3.clone.1 = u32[1280,1280]{1,0} add(%add.249663.3.clone.1, %xor.122465.5.clone.1) + %constant_180715_1_clone_1 = u32[] constant(3962151176) + %broadcast.254184.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_180715_1_clone_1), dimensions={} + %add.249666.7.clone.1 = u32[1280,1280]{1,0} add(%add.249665.3.clone.1, %broadcast.254184.24.clone.1) + %shift-left.110153.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122465.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116386.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122465.5.clone.1, %broadcast.244434.2816) + %or.115913.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110153.11.clone.1, %shift-right-logical.116386.11.clone.1) + %xor.122466.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249665.3.clone.1, %or.115913.9.clone.1) + %constant_218390_1_clone_1 = u32[] constant(951851308) + %broadcast.254187.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218390_1_clone_1), dimensions={} + %add.249667.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122466.7.clone.1, %broadcast.254187.5.clone.1) + %add.249668.5.clone.1 = u32[1280,1280]{1,0} add(%add.249666.7.clone.1, %add.249667.5.clone.1) + %shift-left.110154.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249667.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116387.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249667.5.clone.1, %broadcast.244415.6016) + %or.115914.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110154.9.clone.1, %shift-right-logical.116387.9.clone.1) + %xor.122467.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249668.5.clone.1, %or.115914.7.clone.1) + %add.249670.3.clone.1 = u32[1280,1280]{1,0} add(%add.249668.5.clone.1, %xor.122467.5.clone.1) + %shift-left.110156.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122467.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116389.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122467.5.clone.1, %broadcast.244417.5760) + %or.115915.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110156.9.clone.1, %shift-right-logical.116389.9.clone.1) + %xor.122468.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249670.3.clone.1, %or.115915.7.clone.1) + %add.249671.3.clone.1 = u32[1280,1280]{1,0} add(%add.249670.3.clone.1, %xor.122468.5.clone.1) + %shift-left.110157.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122468.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116390.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122468.5.clone.1, %broadcast.244419.4352) + %or.115916.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110157.7.clone.1, %shift-right-logical.116390.7.clone.1) + %xor.122469.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249671.3.clone.1, %or.115916.5.clone.1) + %add.249672.3.clone.1 = u32[1280,1280]{1,0} add(%add.249671.3.clone.1, %xor.122469.3.clone.1) + %add.249673.7.clone.1 = u32[1280,1280]{1,0} add(%add.249672.3.clone.1, %broadcast.254166.44.clone.1) + %shift-left.110158.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122469.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116391.7.clone.1 = u32[1280,1280]{1,0} 
shift-right-logical(%xor.122469.3.clone.1, %broadcast.244418.4352) + %or.115917.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110158.7.clone.1, %shift-right-logical.116391.7.clone.1) + %xor.122470.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249672.3.clone.1, %or.115917.5.clone.1) + %constant_218391_1_clone_1 = u32[] constant(3477382139) + %broadcast.254197.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218391_1_clone_1), dimensions={} + %add.249675.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122470.3.clone.1, %broadcast.254197.5.clone.1) + %add.249676.5.clone.1 = u32[1280,1280]{1,0} add(%add.249673.7.clone.1, %add.249675.5.clone.1) + %shift-left.110159.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249675.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.116392.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249675.5.clone.1, %broadcast.244416.5760) + %or.115918.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110159.9.clone.1, %shift-right-logical.116392.9.clone.1) + %xor.122471.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249676.5.clone.1, %or.115918.7.clone.1) + %add.249677.3.clone.1 = u32[1280,1280]{1,0} add(%add.249676.5.clone.1, %xor.122471.5.clone.1) + %shift-left.110161.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122471.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.116393.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122471.5.clone.1, %broadcast.244429.2304) + %or.115919.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110161.9.clone.1, %shift-right-logical.116393.9.clone.1) + %xor.122472.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249677.3.clone.1, %or.115919.7.clone.1) + %add.249678.3.clone.1 = u32[1280,1280]{1,0} add(%add.249677.3.clone.1, %xor.122472.5.clone.1) + %shift-left.110162.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122472.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.116394.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122472.5.clone.1, %broadcast.244430.4608) + %or.115920.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110162.9.clone.1, %shift-right-logical.116394.9.clone.1) + %xor.122473.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249678.3.clone.1, %or.115920.7.clone.1) + %add.249679.3.clone.1 = u32[1280,1280]{1,0} add(%add.249678.3.clone.1, %xor.122473.5.clone.1) + %add.249681.7.clone.1 = u32[1280,1280]{1,0} add(%add.249679.3.clone.1, %broadcast.254167.113.clone.1) + %shift-left.110163.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122473.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.116395.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122473.5.clone.1, %broadcast.244434.2816) + %or.115921.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110163.11.clone.1, %shift-right-logical.116395.11.clone.1) + %xor.122475.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249679.3.clone.1, %or.115921.9.clone.1) + %constant_218392_1_clone_1 = u32[] constant(3962151180) + %broadcast.254207.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218392_1_clone_1), dimensions={} + %add.249684.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122475.7.clone.1, %broadcast.254207.5.clone.1) + %add.249685.5.clone.1 = u32[1280,1280]{1,0} add(%add.249681.7.clone.1, %add.249684.5.clone.1) + %shift-left.110164.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249684.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.116396.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249684.5.clone.1, %broadcast.244415.6016) + %or.115922.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110164.9.clone.1, 
%shift-right-logical.116396.9.clone.1) + %xor.122476.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249685.5.clone.1, %or.115922.7.clone.1) + %add.249686.3.clone.1 = u32[1280,1280]{1,0} add(%add.249685.5.clone.1, %xor.122476.5.clone.1) + %shift-left.110165.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122476.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.116397.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122476.5.clone.1, %broadcast.244417.5760) + %or.115923.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110165.9.clone.1, %shift-right-logical.116397.9.clone.1) + %xor.122477.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249686.3.clone.1, %or.115923.7.clone.1) + %add.249687.3.clone.1 = u32[1280,1280]{1,0} add(%add.249686.3.clone.1, %xor.122477.5.clone.1) + %shift-left.110166.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122477.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.116398.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122477.5.clone.1, %broadcast.244419.4352) + %or.115924.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110166.5.clone.1, %shift-right-logical.116398.5.clone.1) + %xor.122478.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249687.3.clone.1, %or.115924.3.clone.1) + %add.249688.3.clone.1 = u32[1280,1280]{1,0} add(%add.249687.3.clone.1, %xor.122478.3.clone.1) + %add.249689.17.clone.1 = u32[1280,1280]{1,0} add(%add.249688.3.clone.1, %broadcast.254184.24.clone.1) + %shift-left.110167.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122478.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.116399.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122478.3.clone.1, %broadcast.244418.4352) + %or.115925.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110167.5.clone.1, %shift-right-logical.116399.5.clone.1) + %xor.122480.15.clone.1 = u32[1280,1280]{1,0} xor(%add.249688.3.clone.1, %or.115925.3.clone.1) + %constant_218393_1_clone_1 = u32[] constant(951851311) + %broadcast.254217.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218393_1_clone_1), dimensions={} + %add.249690.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122480.15.clone.1, %broadcast.254217.19.clone.1) + %xor.122481.17.clone.1 = u32[1280,1280]{1,0} xor(%add.249689.17.clone.1, %add.249690.19.clone.1) + %shift-right-logical.116400.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122481.17.clone.1, %broadcast.244468.1920) + %or.115926.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116400.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5780.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115926.13.clone.1) + %add.249691.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5780.11.clone.1, %broadcast.244470.1152) + %multiply.26737.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249691.9.clone.1, %broadcast.244471.896) + %add.249692.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26737.7.clone.1, %broadcast.244408.1024) + %maximum.3712.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.249692.5.clone.1) + %abs.1558.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3712.3.clone.1) + %compare.7265.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1558.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.26738.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3712.3.clone.1, %broadcast.244476.1152) + %negate.4621.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3712.3.clone.1) + %multiply.26739.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3712.3.clone.1, %negate.4621.5.clone.1) + %log-plus-one.1558.3.clone.1 = f32[1280,1280]{1,0} 
log-plus-one(%multiply.26739.5.clone.1) + %negate.4622.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1558.3.clone.1) + %compare.7266.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4622.4.clone.1, %broadcast.244477.384), direction=LT + %select.21224.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.21225.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.21226.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.21227.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.21228.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.21229.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %broadcast.244488.384, %broadcast.244489.384) + %select.21230.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.21231.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.21232.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.249693.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4622.4.clone.1, %broadcast.244496.640) + %sqrt.1558.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4622.4.clone.1) + %add.249694.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1558.5.clone.1, %broadcast.244498.640) + %select.21233.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7266.3.clone.1, %add.249693.5.clone.1, %add.249694.5.clone.1) + %multiply.26740.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21232.3.clone.1, %select.21233.3.clone.1) + %add.249695.1.clone.1 = f32[1280,1280]{1,0} add(%select.21231.3.clone.1, %multiply.26740.1.clone.1) + %multiply.26741.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249695.1.clone.1, %select.21233.3.clone.1) + %add.249696.1.clone.1 = f32[1280,1280]{1,0} add(%select.21230.3.clone.1, %multiply.26741.1.clone.1) + %multiply.26742.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249696.1.clone.1, %select.21233.3.clone.1) + %add.249697.1.clone.1 = f32[1280,1280]{1,0} add(%select.21229.3.clone.1, %multiply.26742.1.clone.1) + %multiply.26743.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249697.1.clone.1, %select.21233.3.clone.1) + %add.249698.1.clone.1 = f32[1280,1280]{1,0} add(%select.21228.3.clone.1, %multiply.26743.1.clone.1) + %multiply.26744.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249698.1.clone.1, %select.21233.3.clone.1) + %add.249699.3.clone.1 = f32[1280,1280]{1,0} add(%select.21227.5.clone.1, %multiply.26744.1.clone.1) + %multiply.26745.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249699.3.clone.1, %select.21233.3.clone.1) + %add.249700.3.clone.1 = f32[1280,1280]{1,0} add(%select.21226.5.clone.1, %multiply.26745.1.clone.1) + %multiply.26746.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249700.3.clone.1, %select.21233.3.clone.1) + %add.249701.9.clone.1 = f32[1280,1280]{1,0} add(%select.21225.11.clone.1, %multiply.26746.7.clone.1) + %multiply.26747.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249701.9.clone.1, %select.21233.3.clone.1) + %add.249702.7.clone.1 = f32[1280,1280]{1,0} add(%select.21224.7.clone.1, %multiply.26747.7.clone.1) + %multiply.26748.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249702.7.clone.1, %maximum.3712.3.clone.1) + 
%select.21234.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7265.3.clone.1, %multiply.26738.9.clone.1, %multiply.26748.7.clone.1) + %multiply.26749.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21234.7.clone.1, %broadcast.244500.640) + %clamp.1202.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26749.5.clone.1, %broadcast.244501.384) + %multiply.26750.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1202.3.clone.1, %broadcast.244502.1) + %constant_164415_1_clone_1 = u32[] constant(187457133) + %broadcast.247149.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164415_1_clone_1), dimensions={} + %add.245637.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.247149.44.clone.1) + %constant_164422_1_clone_1 = u32[] constant(1046480133) + %broadcast.247150.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164422_1_clone_1), dimensions={} + %add.245638.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.247150.113.clone.1) + %add.245639.35.clone.1 = u32[1280,1280]{1,0} add(%add.245637.37.clone.1, %add.245638.99.clone.1) + %shift-left.108420.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245638.99.clone.1, %broadcast.244414.6272) + %shift-right-logical.114549.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245638.99.clone.1, %broadcast.244415.6016) + %or.114071.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108420.31.clone.1, %shift-right-logical.114549.29.clone.1) + %xor.120623.27.clone.1 = u32[1280,1280]{1,0} xor(%add.245639.35.clone.1, %or.114071.29.clone.1) + %add.245641.5.clone.1 = u32[1280,1280]{1,0} add(%add.245639.35.clone.1, %xor.120623.27.clone.1) + %shift-left.108421.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120623.27.clone.1, %broadcast.244416.5760) + %shift-right-logical.114550.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120623.27.clone.1, %broadcast.244417.5760) + %or.114072.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108421.9.clone.1, %shift-right-logical.114550.9.clone.1) + %xor.120624.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245641.5.clone.1, %or.114072.7.clone.1) + %add.245642.3.clone.1 = u32[1280,1280]{1,0} add(%add.245641.5.clone.1, %xor.120624.5.clone.1) + %shift-left.108422.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120624.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114552.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120624.5.clone.1, %broadcast.244419.4352) + %or.114073.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108422.5.clone.1, %shift-right-logical.114552.5.clone.1) + %xor.120626.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245642.3.clone.1, %or.114073.3.clone.1) + %add.245643.3.clone.1 = u32[1280,1280]{1,0} add(%add.245642.3.clone.1, %xor.120626.3.clone.1) + %add.245644.7.clone.1 = u32[1280,1280]{1,0} add(%add.245643.3.clone.1, %broadcast.247150.113.clone.1) + %shift-left.108423.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120626.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114553.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120626.3.clone.1, %broadcast.244418.4352) + %or.114074.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108423.5.clone.1, %shift-right-logical.114553.5.clone.1) + %xor.120627.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245643.3.clone.1, %or.114074.3.clone.1) + %constant_217934_1_clone_1 = u32[] constant(782057651) + %broadcast.247160.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217934_1_clone_1), dimensions={} + %add.245645.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120627.3.clone.1, 
%broadcast.247160.5.clone.1) + %add.245647.5.clone.1 = u32[1280,1280]{1,0} add(%add.245644.7.clone.1, %add.245645.5.clone.1) + %shift-left.108424.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245645.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114554.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245645.5.clone.1, %broadcast.244416.5760) + %or.114076.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108424.9.clone.1, %shift-right-logical.114554.9.clone.1) + %xor.120628.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245647.5.clone.1, %or.114076.7.clone.1) + %add.245651.3.clone.1 = u32[1280,1280]{1,0} add(%add.245647.5.clone.1, %xor.120628.5.clone.1) + %shift-left.108425.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120628.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114555.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120628.5.clone.1, %broadcast.244429.2304) + %or.114077.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108425.9.clone.1, %shift-right-logical.114555.9.clone.1) + %xor.120629.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245651.3.clone.1, %or.114077.7.clone.1) + %add.245652.3.clone.1 = u32[1280,1280]{1,0} add(%add.245651.3.clone.1, %xor.120629.5.clone.1) + %shift-left.108426.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120629.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114556.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120629.5.clone.1, %broadcast.244430.4608) + %or.114078.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108426.9.clone.1, %shift-right-logical.114556.9.clone.1) + %xor.120631.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245652.3.clone.1, %or.114078.7.clone.1) + %add.245653.3.clone.1 = u32[1280,1280]{1,0} add(%add.245652.3.clone.1, %xor.120631.5.clone.1) + %constant_164424_1_clone_1 = u32[] constant(782057650) + %broadcast.247167.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164424_1_clone_1), dimensions={} + %add.245654.7.clone.1 = u32[1280,1280]{1,0} add(%add.245653.3.clone.1, %broadcast.247167.24.clone.1) + %shift-left.108427.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120631.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114557.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120631.5.clone.1, %broadcast.244434.2816) + %or.114079.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108427.11.clone.1, %shift-right-logical.114557.11.clone.1) + %xor.120632.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245653.3.clone.1, %or.114079.9.clone.1) + %constant_217935_1_clone_1 = u32[] constant(187457135) + %broadcast.247170.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217935_1_clone_1), dimensions={} + %add.245656.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120632.7.clone.1, %broadcast.247170.5.clone.1) + %add.245657.5.clone.1 = u32[1280,1280]{1,0} add(%add.245654.7.clone.1, %add.245656.5.clone.1) + %shift-left.108428.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245656.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114558.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245656.5.clone.1, %broadcast.244415.6016) + %or.114081.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108428.9.clone.1, %shift-right-logical.114558.9.clone.1) + %xor.120633.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245657.5.clone.1, %or.114081.7.clone.1) + %add.245658.3.clone.1 = u32[1280,1280]{1,0} add(%add.245657.5.clone.1, %xor.120633.5.clone.1) + %shift-left.108429.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120633.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114559.9.clone.1 = 
u32[1280,1280]{1,0} shift-right-logical(%xor.120633.5.clone.1, %broadcast.244417.5760) + %or.114082.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108429.9.clone.1, %shift-right-logical.114559.9.clone.1) + %xor.120634.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245658.3.clone.1, %or.114082.7.clone.1) + %add.245659.3.clone.1 = u32[1280,1280]{1,0} add(%add.245658.3.clone.1, %xor.120634.5.clone.1) + %shift-left.108430.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120634.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114560.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120634.5.clone.1, %broadcast.244419.4352) + %or.114083.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108430.7.clone.1, %shift-right-logical.114560.7.clone.1) + %xor.120635.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245659.3.clone.1, %or.114083.5.clone.1) + %add.245661.3.clone.1 = u32[1280,1280]{1,0} add(%add.245659.3.clone.1, %xor.120635.3.clone.1) + %add.245662.7.clone.1 = u32[1280,1280]{1,0} add(%add.245661.3.clone.1, %broadcast.247149.44.clone.1) + %shift-left.108431.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120635.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114561.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120635.3.clone.1, %broadcast.244418.4352) + %or.114084.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108431.7.clone.1, %shift-right-logical.114561.7.clone.1) + %xor.120636.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245661.3.clone.1, %or.114084.5.clone.1) + %constant_217936_1_clone_1 = u32[] constant(1046480136) + %broadcast.247180.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217936_1_clone_1), dimensions={} + %add.245663.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120636.3.clone.1, %broadcast.247180.5.clone.1) + %add.245664.5.clone.1 = u32[1280,1280]{1,0} add(%add.245662.7.clone.1, %add.245663.5.clone.1) + %shift-left.108432.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245663.5.clone.1, %broadcast.244417.5760) + %shift-right-logical.114562.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245663.5.clone.1, %broadcast.244416.5760) + %or.114086.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108432.9.clone.1, %shift-right-logical.114562.9.clone.1) + %xor.120637.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245664.5.clone.1, %or.114086.7.clone.1) + %add.245666.3.clone.1 = u32[1280,1280]{1,0} add(%add.245664.5.clone.1, %xor.120637.5.clone.1) + %shift-left.108433.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120637.5.clone.1, %broadcast.244428.2304) + %shift-right-logical.114563.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120637.5.clone.1, %broadcast.244429.2304) + %or.114087.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108433.9.clone.1, %shift-right-logical.114563.9.clone.1) + %xor.120638.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245666.3.clone.1, %or.114087.7.clone.1) + %add.245667.3.clone.1 = u32[1280,1280]{1,0} add(%add.245666.3.clone.1, %xor.120638.5.clone.1) + %shift-left.108434.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120638.5.clone.1, %broadcast.244430.4608) + %shift-right-logical.114565.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120638.5.clone.1, %broadcast.244430.4608) + %or.114088.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108434.9.clone.1, %shift-right-logical.114565.9.clone.1) + %xor.120639.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245667.3.clone.1, %or.114088.7.clone.1) + %add.245668.3.clone.1 = u32[1280,1280]{1,0} add(%add.245667.3.clone.1, %xor.120639.5.clone.1) + %add.245669.7.clone.1 = u32[1280,1280]{1,0} 
add(%add.245668.3.clone.1, %broadcast.247150.113.clone.1) + %shift-left.108435.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120639.5.clone.1, %broadcast.244433.2816) + %shift-right-logical.114566.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120639.5.clone.1, %broadcast.244434.2816) + %or.114089.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108435.11.clone.1, %shift-right-logical.114566.11.clone.1) + %xor.120641.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245668.3.clone.1, %or.114089.9.clone.1) + %constant_217937_1_clone_1 = u32[] constant(782057654) + %broadcast.247190.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217937_1_clone_1), dimensions={} + %add.245670.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120641.7.clone.1, %broadcast.247190.5.clone.1) + %add.245672.5.clone.1 = u32[1280,1280]{1,0} add(%add.245669.7.clone.1, %add.245670.5.clone.1) + %shift-left.108436.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245670.5.clone.1, %broadcast.244414.6272) + %shift-right-logical.114567.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245670.5.clone.1, %broadcast.244415.6016) + %or.114090.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108436.9.clone.1, %shift-right-logical.114567.9.clone.1) + %xor.120642.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245672.5.clone.1, %or.114090.7.clone.1) + %add.245676.3.clone.1 = u32[1280,1280]{1,0} add(%add.245672.5.clone.1, %xor.120642.5.clone.1) + %shift-left.108437.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120642.5.clone.1, %broadcast.244416.5760) + %shift-right-logical.114568.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120642.5.clone.1, %broadcast.244417.5760) + %or.114091.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108437.9.clone.1, %shift-right-logical.114568.9.clone.1) + %xor.120643.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245676.3.clone.1, %or.114091.7.clone.1) + %add.245677.3.clone.1 = u32[1280,1280]{1,0} add(%add.245676.3.clone.1, %xor.120643.5.clone.1) + %shift-left.108438.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120643.5.clone.1, %broadcast.244418.4352) + %shift-right-logical.114569.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120643.5.clone.1, %broadcast.244419.4352) + %or.114092.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108438.5.clone.1, %shift-right-logical.114569.5.clone.1) + %xor.120644.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245677.3.clone.1, %or.114092.3.clone.1) + %add.245678.3.clone.1 = u32[1280,1280]{1,0} add(%add.245677.3.clone.1, %xor.120644.3.clone.1) + %add.245679.17.clone.1 = u32[1280,1280]{1,0} add(%add.245678.3.clone.1, %broadcast.247167.24.clone.1) + %shift-left.108439.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120644.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114570.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120644.3.clone.1, %broadcast.244418.4352) + %or.114093.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108439.5.clone.1, %shift-right-logical.114570.5.clone.1) + %xor.120646.15.clone.1 = u32[1280,1280]{1,0} xor(%add.245678.3.clone.1, %or.114093.3.clone.1) + %constant_217938_1_clone_1 = u32[] constant(187457138) + %broadcast.247200.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217938_1_clone_1), dimensions={} + %add.245681.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120646.15.clone.1, %broadcast.247200.19.clone.1) + %xor.120647.17.clone.1 = u32[1280,1280]{1,0} xor(%add.245679.17.clone.1, %add.245681.19.clone.1) + %shift-right-logical.114571.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120647.17.clone.1, 
%broadcast.244468.1920)
+ %or.114094.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114571.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5700.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114094.13.clone.1)
+ %add.245682.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5700.11.clone.1, %broadcast.244470.1152)
+ %multiply.25901.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245682.9.clone.1, %broadcast.244471.896)
+ %add.245683.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.25901.7.clone.1, %broadcast.244408.1024)
+ %maximum.3632.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.245683.5.clone.1)
+ %abs.1504.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3632.3.clone.1)
+ %compare.7156.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1504.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.25902.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3632.3.clone.1, %broadcast.244476.1152)
+ %negate.4513.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3632.3.clone.1)
+ %multiply.25903.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3632.3.clone.1, %negate.4513.5.clone.1)
+ %log-plus-one.1504.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.25903.5.clone.1)
+ %negate.4514.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1504.3.clone.1)
+ %compare.7157.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4514.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.20609.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.20610.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.20611.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.20612.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.20613.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.20614.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.20615.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.20616.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.20617.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.245684.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4514.4.clone.1, %broadcast.244496.640)
+ %sqrt.1504.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4514.4.clone.1)
+ %add.245686.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1504.5.clone.1, %broadcast.244498.640)
+ %select.20618.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7157.3.clone.1, %add.245684.5.clone.1, %add.245686.5.clone.1)
+ %multiply.25904.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20617.3.clone.1, %select.20618.3.clone.1)
+ %add.245687.1.clone.1 = f32[1280,1280]{1,0} add(%select.20616.3.clone.1, %multiply.25904.1.clone.1)
+ %multiply.25905.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245687.1.clone.1, %select.20618.3.clone.1)
+ %add.245688.1.clone.1 = f32[1280,1280]{1,0} add(%select.20615.3.clone.1, %multiply.25905.1.clone.1)
+ %multiply.25906.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245688.1.clone.1, %select.20618.3.clone.1)
+ %add.245689.1.clone.1 = f32[1280,1280]{1,0} add(%select.20614.3.clone.1, %multiply.25906.1.clone.1)
+ %multiply.25907.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245689.1.clone.1, %select.20618.3.clone.1)
+ %add.245691.1.clone.1 = f32[1280,1280]{1,0} add(%select.20613.3.clone.1, %multiply.25907.1.clone.1)
+ %multiply.25908.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245691.1.clone.1, %select.20618.3.clone.1)
+ %add.245692.3.clone.1 = f32[1280,1280]{1,0} add(%select.20612.5.clone.1, %multiply.25908.1.clone.1)
+ %multiply.25909.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245692.3.clone.1, %select.20618.3.clone.1)
+ %add.245693.3.clone.1 = f32[1280,1280]{1,0} add(%select.20611.5.clone.1, %multiply.25909.1.clone.1)
+ %multiply.25910.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245693.3.clone.1, %select.20618.3.clone.1)
+ %add.245694.9.clone.1 = f32[1280,1280]{1,0} add(%select.20610.11.clone.1, %multiply.25910.7.clone.1)
+ %multiply.25911.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245694.9.clone.1, %select.20618.3.clone.1)
+ %add.245695.7.clone.1 = f32[1280,1280]{1,0} add(%select.20609.7.clone.1, %multiply.25911.7.clone.1)
+ %multiply.25912.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245695.7.clone.1, %maximum.3632.3.clone.1)
+ %select.20619.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7156.3.clone.1, %multiply.25902.9.clone.1, %multiply.25912.7.clone.1)
+ %multiply.25913.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20619.7.clone.1, %broadcast.244500.640)
+ %clamp.1148.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.25913.5.clone.1, %broadcast.244501.384)
+ %multiply.25914.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1148.3.clone.1, %broadcast.244502.1)
+ %constant_192897_1_clone_1 = u32[] constant(749277645)
+ %broadcast.259469.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_192897_1_clone_1), dimensions={}
+ %add.252670.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.259469.44.clone.1)
+ %constant_192904_1_clone_1 = u32[] constant(3963949667)
+ %broadcast.259471.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_192904_1_clone_1), dimensions={}
+ %add.252671.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.259471.113.clone.1)
+ %add.252672.35.clone.1 = u32[1280,1280]{1,0} add(%add.252670.37.clone.1, %add.252671.99.clone.1)
+ %shift-left.111463.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252671.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117771.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252671.99.clone.1, %broadcast.244415.6016)
+ %or.117293.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111463.31.clone.1, %shift-right-logical.117771.29.clone.1)
+ %xor.123848.27.clone.1 = u32[1280,1280]{1,0} xor(%add.252672.35.clone.1, %or.117293.29.clone.1)
+ %add.252673.5.clone.1 = u32[1280,1280]{1,0} add(%add.252672.35.clone.1, %xor.123848.27.clone.1)
+ %shift-left.111464.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123848.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117772.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123848.27.clone.1, %broadcast.244417.5760)
+ %or.117294.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111464.9.clone.1, %shift-right-logical.117772.9.clone.1)
+ %xor.123849.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252673.5.clone.1, %or.117294.7.clone.1)
+ %add.252675.3.clone.1 = u32[1280,1280]{1,0} add(%add.252673.5.clone.1, %xor.123849.5.clone.1)
+ %shift-left.111465.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123849.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117773.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123849.5.clone.1, %broadcast.244419.4352)
+ %or.117295.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111465.5.clone.1, %shift-right-logical.117773.5.clone.1)
+ %xor.123850.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252675.3.clone.1, %or.117295.3.clone.1)
+ %add.252676.3.clone.1 = u32[1280,1280]{1,0} add(%add.252675.3.clone.1, %xor.123850.3.clone.1)
+ %add.252677.7.clone.1 = u32[1280,1280]{1,0} add(%add.252676.3.clone.1, %broadcast.259471.113.clone.1)
+ %shift-left.111467.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123850.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117774.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123850.3.clone.1, %broadcast.244418.4352)
+ %or.117296.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111467.5.clone.1, %shift-right-logical.117774.5.clone.1)
+ %xor.123851.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252676.3.clone.1, %or.117296.3.clone.1)
+ %constant_218710_1_clone_1 = u32[] constant(3678214261)
+ %broadcast.259481.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218710_1_clone_1), dimensions={}
+ %add.252678.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123851.3.clone.1, %broadcast.259481.5.clone.1)
+ %add.252679.5.clone.1 = u32[1280,1280]{1,0} add(%add.252677.7.clone.1, %add.252678.5.clone.1)
+ %shift-left.111468.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252678.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117775.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252678.5.clone.1, %broadcast.244416.5760)
+ %or.117297.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111468.9.clone.1, %shift-right-logical.117775.9.clone.1)
+ %xor.123852.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252679.5.clone.1, %or.117297.7.clone.1)
+ %add.252681.3.clone.1 = u32[1280,1280]{1,0} add(%add.252679.5.clone.1, %xor.123852.5.clone.1)
+ %shift-left.111469.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123852.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117776.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123852.5.clone.1, %broadcast.244429.2304)
+ %or.117298.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111469.9.clone.1, %shift-right-logical.117776.9.clone.1)
+ %xor.123853.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252681.3.clone.1, %or.117298.7.clone.1)
+ %add.252684.3.clone.1 = u32[1280,1280]{1,0} add(%add.252681.3.clone.1, %xor.123853.5.clone.1)
+ %shift-left.111470.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123853.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117777.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123853.5.clone.1, %broadcast.244430.4608)
+ %or.117299.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111470.9.clone.1, %shift-right-logical.117777.9.clone.1)
+ %xor.123854.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252684.3.clone.1, %or.117299.7.clone.1)
+ %add.252685.3.clone.1 = u32[1280,1280]{1,0} add(%add.252684.3.clone.1, %xor.123854.5.clone.1)
+ %constant_192906_1_clone_1 = u32[] constant(3678214260)
+ %broadcast.259488.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_192906_1_clone_1), dimensions={}
+ %add.252686.7.clone.1 = u32[1280,1280]{1,0} add(%add.252685.3.clone.1, %broadcast.259488.24.clone.1)
+ %shift-left.111471.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123854.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117778.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123854.5.clone.1, %broadcast.244434.2816)
+ %or.117300.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111471.11.clone.1, %shift-right-logical.117778.11.clone.1)
+ %xor.123855.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252685.3.clone.1, %or.117300.9.clone.1)
+ %constant_218711_1_clone_1 = u32[] constant(749277647)
+ %broadcast.259491.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218711_1_clone_1), dimensions={}
+ %add.252687.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123855.7.clone.1, %broadcast.259491.5.clone.1)
+ %add.252688.5.clone.1 = u32[1280,1280]{1,0} add(%add.252686.7.clone.1, %add.252687.5.clone.1)
+ %shift-left.111472.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252687.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117779.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252687.5.clone.1, %broadcast.244415.6016)
+ %or.117301.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111472.9.clone.1, %shift-right-logical.117779.9.clone.1)
+ %xor.123856.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252688.5.clone.1, %or.117301.7.clone.1)
+ %add.252689.3.clone.1 = u32[1280,1280]{1,0} add(%add.252688.5.clone.1, %xor.123856.5.clone.1)
+ %shift-left.111473.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123856.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117780.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123856.5.clone.1, %broadcast.244417.5760)
+ %or.117302.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111473.9.clone.1, %shift-right-logical.117780.9.clone.1)
+ %xor.123857.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252689.3.clone.1, %or.117302.7.clone.1)
+ %add.252690.3.clone.1 = u32[1280,1280]{1,0} add(%add.252689.3.clone.1, %xor.123857.5.clone.1)
+ %shift-left.111474.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123857.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117781.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123857.5.clone.1, %broadcast.244419.4352)
+ %or.117303.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111474.7.clone.1, %shift-right-logical.117781.7.clone.1)
+ %xor.123858.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252690.3.clone.1, %or.117303.5.clone.1)
+ %add.252691.3.clone.1 = u32[1280,1280]{1,0} add(%add.252690.3.clone.1, %xor.123858.3.clone.1)
+ %add.252692.7.clone.1 = u32[1280,1280]{1,0} add(%add.252691.3.clone.1, %broadcast.259469.44.clone.1)
+ %shift-left.111475.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123858.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117782.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123858.3.clone.1, %broadcast.244418.4352)
+ %or.117304.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111475.7.clone.1, %shift-right-logical.117782.7.clone.1)
+ %xor.123859.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252691.3.clone.1, %or.117304.5.clone.1)
+ %constant_218712_1_clone_1 = u32[] constant(3963949670)
+ %broadcast.259501.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218712_1_clone_1), dimensions={}
+ %add.252693.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123859.3.clone.1, %broadcast.259501.5.clone.1)
+ %add.252694.5.clone.1 = u32[1280,1280]{1,0} add(%add.252692.7.clone.1, %add.252693.5.clone.1)
+ %shift-left.111477.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252693.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.117783.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252693.5.clone.1, %broadcast.244416.5760)
+ %or.117305.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111477.9.clone.1, %shift-right-logical.117783.9.clone.1)
+ %xor.123860.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252694.5.clone.1, %or.117305.7.clone.1)
+ %add.252695.3.clone.1 = u32[1280,1280]{1,0} add(%add.252694.5.clone.1, %xor.123860.5.clone.1)
+ %shift-left.111478.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123860.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.117784.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123860.5.clone.1, %broadcast.244429.2304)
+ %or.117306.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111478.9.clone.1, %shift-right-logical.117784.9.clone.1)
+ %xor.123861.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252695.3.clone.1, %or.117306.7.clone.1)
+ %add.252696.3.clone.1 = u32[1280,1280]{1,0} add(%add.252695.3.clone.1, %xor.123861.5.clone.1)
+ %shift-left.111479.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123861.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.117785.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123861.5.clone.1, %broadcast.244430.4608)
+ %or.117307.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111479.9.clone.1, %shift-right-logical.117785.9.clone.1)
+ %xor.123862.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252696.3.clone.1, %or.117307.7.clone.1)
+ %add.252697.3.clone.1 = u32[1280,1280]{1,0} add(%add.252696.3.clone.1, %xor.123862.5.clone.1)
+ %add.252698.7.clone.1 = u32[1280,1280]{1,0} add(%add.252697.3.clone.1, %broadcast.259471.113.clone.1)
+ %shift-left.111480.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123862.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.117786.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123862.5.clone.1, %broadcast.244434.2816)
+ %or.117308.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111480.11.clone.1, %shift-right-logical.117786.11.clone.1)
+ %xor.123863.7.clone.1 = u32[1280,1280]{1,0} xor(%add.252697.3.clone.1, %or.117308.9.clone.1)
+ %constant_218713_1_clone_1 = u32[] constant(3678214264)
+ %broadcast.259511.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218713_1_clone_1), dimensions={}
+ %add.252699.5.clone.1 = u32[1280,1280]{1,0} add(%xor.123863.7.clone.1, %broadcast.259511.5.clone.1)
+ %add.252700.5.clone.1 = u32[1280,1280]{1,0} add(%add.252698.7.clone.1, %add.252699.5.clone.1)
+ %shift-left.111482.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.252699.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.117787.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.252699.5.clone.1, %broadcast.244415.6016)
+ %or.117309.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111482.9.clone.1, %shift-right-logical.117787.9.clone.1)
+ %xor.123864.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252700.5.clone.1, %or.117309.7.clone.1)
+ %add.252701.3.clone.1 = u32[1280,1280]{1,0} add(%add.252700.5.clone.1, %xor.123864.5.clone.1)
+ %shift-left.111483.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123864.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.117788.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123864.5.clone.1, %broadcast.244417.5760)
+ %or.117310.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111483.9.clone.1, %shift-right-logical.117788.9.clone.1)
+ %xor.123865.5.clone.1 = u32[1280,1280]{1,0} xor(%add.252701.3.clone.1, %or.117310.7.clone.1)
+ %add.252702.3.clone.1 = u32[1280,1280]{1,0} add(%add.252701.3.clone.1, %xor.123865.5.clone.1)
+ %shift-left.111484.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123865.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.117789.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123865.5.clone.1, %broadcast.244419.4352)
+ %or.117311.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111484.5.clone.1, %shift-right-logical.117789.5.clone.1)
+ %xor.123866.3.clone.1 = u32[1280,1280]{1,0} xor(%add.252702.3.clone.1, %or.117311.3.clone.1)
+ %add.252703.3.clone.1 = u32[1280,1280]{1,0} add(%add.252702.3.clone.1, %xor.123866.3.clone.1)
+ %add.252704.17.clone.1 = u32[1280,1280]{1,0} add(%add.252703.3.clone.1, %broadcast.259488.24.clone.1)
+ %shift-left.111485.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.123866.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.117790.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123866.3.clone.1, %broadcast.244418.4352)
+ %or.117312.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.111485.5.clone.1, %shift-right-logical.117790.5.clone.1)
+ %xor.123867.15.clone.1 = u32[1280,1280]{1,0} xor(%add.252703.3.clone.1, %or.117312.3.clone.1)
+ %constant_218714_1_clone_1 = u32[] constant(749277650)
+ %broadcast.259521.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218714_1_clone_1), dimensions={}
+ %add.252705.19.clone.1 = u32[1280,1280]{1,0} add(%xor.123867.15.clone.1, %broadcast.259521.19.clone.1)
+ %xor.123868.17.clone.1 = u32[1280,1280]{1,0} xor(%add.252704.17.clone.1, %add.252705.19.clone.1)
+ %shift-right-logical.117791.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.123868.17.clone.1, %broadcast.244468.1920)
+ %or.117313.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.117791.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5840.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.117313.13.clone.1)
+ %add.252706.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5840.11.clone.1, %broadcast.244470.1152)
+ %multiply.27352.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252706.9.clone.1, %broadcast.244471.896)
+ %add.252707.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.27352.7.clone.1, %broadcast.244408.1024)
+ %maximum.3772.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.252707.5.clone.1)
+ %abs.1598.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3772.3.clone.1)
+ %compare.7358.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1598.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.27353.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3772.3.clone.1, %broadcast.244476.1152)
+ %negate.4701.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3772.3.clone.1)
+ %multiply.27354.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3772.3.clone.1, %negate.4701.5.clone.1)
+ %log-plus-one.1598.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.27354.5.clone.1)
+ %negate.4702.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1598.3.clone.1)
+ %compare.7359.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4702.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21685.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21686.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21687.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21688.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21689.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21690.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21691.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21692.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21693.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.252708.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4702.4.clone.1, %broadcast.244496.640)
+ %sqrt.1598.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4702.4.clone.1)
+ %add.252709.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1598.5.clone.1, %broadcast.244498.640)
+ %select.21694.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7359.3.clone.1, %add.252708.5.clone.1, %add.252709.5.clone.1)
+ %multiply.27355.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21693.3.clone.1, %select.21694.3.clone.1)
+ %add.252710.1.clone.1 = f32[1280,1280]{1,0} add(%select.21692.3.clone.1, %multiply.27355.1.clone.1)
+ %multiply.27356.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252710.1.clone.1, %select.21694.3.clone.1)
+ %add.252712.1.clone.1 = f32[1280,1280]{1,0} add(%select.21691.3.clone.1, %multiply.27356.1.clone.1)
+ %multiply.27357.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252712.1.clone.1, %select.21694.3.clone.1)
+ %add.252713.1.clone.1 = f32[1280,1280]{1,0} add(%select.21690.3.clone.1, %multiply.27357.1.clone.1)
+ %multiply.27358.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252713.1.clone.1, %select.21694.3.clone.1)
+ %add.252714.1.clone.1 = f32[1280,1280]{1,0} add(%select.21689.3.clone.1, %multiply.27358.1.clone.1)
+ %multiply.27359.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252714.1.clone.1, %select.21694.3.clone.1)
+ %add.252715.3.clone.1 = f32[1280,1280]{1,0} add(%select.21688.5.clone.1, %multiply.27359.1.clone.1)
+ %multiply.27360.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.252715.3.clone.1, %select.21694.3.clone.1)
+ %add.252716.3.clone.1 = f32[1280,1280]{1,0} add(%select.21687.5.clone.1, %multiply.27360.1.clone.1)
+ %multiply.27361.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252716.3.clone.1, %select.21694.3.clone.1)
+ %add.252717.9.clone.1 = f32[1280,1280]{1,0} add(%select.21686.11.clone.1, %multiply.27361.7.clone.1)
+ %multiply.27362.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252717.9.clone.1, %select.21694.3.clone.1)
+ %add.252718.7.clone.1 = f32[1280,1280]{1,0} add(%select.21685.7.clone.1, %multiply.27362.7.clone.1)
+ %multiply.27363.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.252718.7.clone.1, %maximum.3772.3.clone.1)
+ %select.21695.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7358.3.clone.1, %multiply.27353.9.clone.1, %multiply.27363.7.clone.1)
+ %multiply.27364.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21695.7.clone.1, %broadcast.244500.640)
+ %clamp.1242.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.27364.5.clone.1, %broadcast.244501.384)
+ %multiply.27365.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1242.3.clone.1, %broadcast.244502.1)
+ %constant_164188_1_clone_1 = u32[] constant(961610870)
+ %broadcast.247063.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164188_1_clone_1), dimensions={}
+ %add.245577.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.247063.44.clone.1)
+ %constant_164211_1_clone_1 = u32[] constant(4021814316)
+ %broadcast.247064.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164211_1_clone_1), dimensions={}
+ %add.245578.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.247064.113.clone.1)
+ %add.245579.35.clone.1 = u32[1280,1280]{1,0} add(%add.245577.37.clone.1, %add.245578.99.clone.1)
+ %shift-left.108400.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245578.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114524.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245578.99.clone.1, %broadcast.244415.6016)
+ %or.114049.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108400.31.clone.1, %shift-right-logical.114524.29.clone.1)
+ %xor.120598.27.clone.1 = u32[1280,1280]{1,0} xor(%add.245579.35.clone.1, %or.114049.29.clone.1)
+ %add.245581.5.clone.1 = u32[1280,1280]{1,0} add(%add.245579.35.clone.1, %xor.120598.27.clone.1)
+ %shift-left.108401.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120598.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114525.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120598.27.clone.1, %broadcast.244417.5760)
+ %or.114050.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108401.9.clone.1, %shift-right-logical.114525.9.clone.1)
+ %xor.120599.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245581.5.clone.1, %or.114050.7.clone.1)
+ %add.245582.3.clone.1 = u32[1280,1280]{1,0} add(%add.245581.5.clone.1, %xor.120599.5.clone.1)
+ %shift-left.108402.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120599.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114527.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120599.5.clone.1, %broadcast.244419.4352)
+ %or.114051.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108402.5.clone.1, %shift-right-logical.114527.5.clone.1)
+ %xor.120601.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245582.3.clone.1, %or.114051.3.clone.1)
+ %add.245583.3.clone.1 = u32[1280,1280]{1,0} add(%add.245582.3.clone.1, %xor.120601.3.clone.1)
+ %add.245584.7.clone.1 = u32[1280,1280]{1,0} add(%add.245583.3.clone.1, %broadcast.247064.113.clone.1)
+ %shift-left.108403.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120601.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114528.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120601.3.clone.1, %broadcast.244418.4352)
+ %or.114052.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108403.5.clone.1, %shift-right-logical.114528.5.clone.1)
+ %xor.120602.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245583.3.clone.1, %or.114052.3.clone.1)
+ %constant_217929_1_clone_1 = u32[] constant(3443006337)
+ %broadcast.247074.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217929_1_clone_1), dimensions={}
+ %add.245586.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120602.3.clone.1, %broadcast.247074.5.clone.1)
+ %add.245587.5.clone.1 = u32[1280,1280]{1,0} add(%add.245584.7.clone.1, %add.245586.5.clone.1)
+ %shift-left.108404.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245586.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114529.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245586.5.clone.1, %broadcast.244416.5760)
+ %or.114053.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108404.9.clone.1, %shift-right-logical.114529.9.clone.1)
+ %xor.120603.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245587.5.clone.1, %or.114053.7.clone.1)
+ %add.245588.3.clone.1 = u32[1280,1280]{1,0} add(%add.245587.5.clone.1, %xor.120603.5.clone.1)
+ %shift-left.108405.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120603.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114530.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120603.5.clone.1, %broadcast.244429.2304)
+ %or.114054.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108405.9.clone.1, %shift-right-logical.114530.9.clone.1)
+ %xor.120604.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245588.3.clone.1, %or.114054.7.clone.1)
+ %add.245589.3.clone.1 = u32[1280,1280]{1,0} add(%add.245588.3.clone.1, %xor.120604.5.clone.1)
+ %shift-left.108406.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120604.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114531.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120604.5.clone.1, %broadcast.244430.4608)
+ %or.114055.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108406.9.clone.1, %shift-right-logical.114531.9.clone.1)
+ %xor.120606.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245589.3.clone.1, %or.114055.7.clone.1)
+ %add.245591.3.clone.1 = u32[1280,1280]{1,0} add(%add.245589.3.clone.1, %xor.120606.5.clone.1)
+ %constant_164213_1_clone_1 = u32[] constant(3443006336)
+ %broadcast.247081.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_164213_1_clone_1), dimensions={}
+ %add.245592.7.clone.1 = u32[1280,1280]{1,0} add(%add.245591.3.clone.1, %broadcast.247081.24.clone.1)
+ %shift-left.108407.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120606.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114532.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120606.5.clone.1, %broadcast.244434.2816)
+ %or.114056.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108407.11.clone.1, %shift-right-logical.114532.11.clone.1)
+ %xor.120607.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245591.3.clone.1, %or.114056.9.clone.1)
+ %constant_217930_1_clone_1 = u32[] constant(961610872)
+ %broadcast.247084.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217930_1_clone_1), dimensions={}
+ %add.245593.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120607.7.clone.1, %broadcast.247084.5.clone.1)
+ %add.245594.5.clone.1 = u32[1280,1280]{1,0} add(%add.245592.7.clone.1, %add.245593.5.clone.1)
+ %shift-left.108408.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245593.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114533.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245593.5.clone.1, %broadcast.244415.6016)
+ %or.114057.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108408.9.clone.1, %shift-right-logical.114533.9.clone.1)
+ %xor.120608.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245594.5.clone.1, %or.114057.7.clone.1)
+ %add.245595.3.clone.1 = u32[1280,1280]{1,0} add(%add.245594.5.clone.1, %xor.120608.5.clone.1)
+ %shift-left.108409.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120608.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114534.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120608.5.clone.1, %broadcast.244417.5760)
+ %or.114058.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108409.9.clone.1, %shift-right-logical.114534.9.clone.1)
+ %xor.120609.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245595.3.clone.1, %or.114058.7.clone.1)
+ %add.245597.3.clone.1 = u32[1280,1280]{1,0} add(%add.245595.3.clone.1, %xor.120609.5.clone.1)
+ %shift-left.108410.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120609.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114535.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120609.5.clone.1, %broadcast.244419.4352)
+ %or.114059.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108410.7.clone.1, %shift-right-logical.114535.7.clone.1)
+ %xor.120610.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245597.3.clone.1, %or.114059.5.clone.1)
+ %add.245601.3.clone.1 = u32[1280,1280]{1,0} add(%add.245597.3.clone.1, %xor.120610.3.clone.1)
+ %add.245602.7.clone.1 = u32[1280,1280]{1,0} add(%add.245601.3.clone.1, %broadcast.247063.44.clone.1)
+ %shift-left.108411.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120610.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114537.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120610.3.clone.1, %broadcast.244418.4352)
+ %or.114060.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108411.7.clone.1, %shift-right-logical.114537.7.clone.1)
+ %xor.120611.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245601.3.clone.1, %or.114060.5.clone.1)
+ %constant_217931_1_clone_1 = u32[] constant(4021814319)
+ %broadcast.247094.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217931_1_clone_1), dimensions={}
+ %add.245603.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120611.3.clone.1, %broadcast.247094.5.clone.1)
+ %add.245604.5.clone.1 = u32[1280,1280]{1,0} add(%add.245602.7.clone.1, %add.245603.5.clone.1)
+ %shift-left.108412.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245603.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114538.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245603.5.clone.1, %broadcast.244416.5760)
+ %or.114061.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108412.9.clone.1, %shift-right-logical.114538.9.clone.1)
+ %xor.120612.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245604.5.clone.1, %or.114061.7.clone.1)
+ %add.245606.3.clone.1 = u32[1280,1280]{1,0} add(%add.245604.5.clone.1, %xor.120612.5.clone.1)
+ %shift-left.108413.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120612.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114539.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120612.5.clone.1, %broadcast.244429.2304)
+ %or.114062.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108413.9.clone.1, %shift-right-logical.114539.9.clone.1)
+ %xor.120613.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245606.3.clone.1, %or.114062.7.clone.1)
+ %add.245607.3.clone.1 = u32[1280,1280]{1,0} add(%add.245606.3.clone.1, %xor.120613.5.clone.1)
+ %shift-left.108414.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120613.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114540.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120613.5.clone.1, %broadcast.244430.4608)
+ %or.114063.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108414.9.clone.1, %shift-right-logical.114540.9.clone.1)
+ %xor.120614.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245607.3.clone.1, %or.114063.7.clone.1)
+ %add.245608.3.clone.1 = u32[1280,1280]{1,0} add(%add.245607.3.clone.1, %xor.120614.5.clone.1)
+ %add.245609.7.clone.1 = u32[1280,1280]{1,0} add(%add.245608.3.clone.1, %broadcast.247064.113.clone.1)
+ %shift-left.108415.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120614.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114542.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120614.5.clone.1, %broadcast.244434.2816)
+ %or.114064.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108415.11.clone.1, %shift-right-logical.114542.11.clone.1)
+ %xor.120616.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245608.3.clone.1, %or.114064.9.clone.1)
+ %constant_217932_1_clone_1 = u32[] constant(3443006340)
+ %broadcast.247104.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217932_1_clone_1), dimensions={}
+ %add.245611.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120616.7.clone.1, %broadcast.247104.5.clone.1)
+ %add.245612.5.clone.1 = u32[1280,1280]{1,0} add(%add.245609.7.clone.1, %add.245611.5.clone.1)
+ %shift-left.108416.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245611.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114543.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245611.5.clone.1, %broadcast.244415.6016)
+ %or.114065.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108416.9.clone.1, %shift-right-logical.114543.9.clone.1)
+ %xor.120617.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245612.5.clone.1, %or.114065.7.clone.1)
+ %add.245613.3.clone.1 = u32[1280,1280]{1,0} add(%add.245612.5.clone.1, %xor.120617.5.clone.1)
+ %shift-left.108417.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120617.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114544.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120617.5.clone.1, %broadcast.244417.5760)
+ %or.114066.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108417.9.clone.1, %shift-right-logical.114544.9.clone.1)
+ %xor.120618.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245613.3.clone.1, %or.114066.7.clone.1)
+ %add.245614.3.clone.1 = u32[1280,1280]{1,0} add(%add.245613.3.clone.1, %xor.120618.5.clone.1)
+ %shift-left.108418.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120618.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114545.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120618.5.clone.1, %broadcast.244419.4352)
+ %or.114067.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108418.5.clone.1, %shift-right-logical.114545.5.clone.1)
+ %xor.120619.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245614.3.clone.1, %or.114067.3.clone.1)
+ %add.245616.3.clone.1 = u32[1280,1280]{1,0} add(%add.245614.3.clone.1, %xor.120619.3.clone.1)
+ %add.245617.17.clone.1 = u32[1280,1280]{1,0} add(%add.245616.3.clone.1, %broadcast.247081.24.clone.1)
+ %shift-left.108419.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120619.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114547.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120619.3.clone.1, %broadcast.244418.4352)
+ %or.114068.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108419.5.clone.1, %shift-right-logical.114547.5.clone.1)
+ %xor.120621.15.clone.1 = u32[1280,1280]{1,0} xor(%add.245616.3.clone.1, %or.114068.3.clone.1)
+ %constant_217933_1_clone_1 = u32[] constant(961610875)
+ %broadcast.247114.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217933_1_clone_1), dimensions={}
+ %add.245618.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120621.15.clone.1, %broadcast.247114.19.clone.1)
+ %xor.120622.17.clone.1 = u32[1280,1280]{1,0} xor(%add.245617.17.clone.1, %add.245618.19.clone.1)
+ %shift-right-logical.114548.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120622.17.clone.1, %broadcast.244468.1920)
+ %or.114069.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114548.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5699.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114069.13.clone.1)
+ %add.245619.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5699.11.clone.1, %broadcast.244470.1152)
+ %multiply.25887.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245619.9.clone.1, %broadcast.244471.896)
+ %add.245620.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.25887.7.clone.1, %broadcast.244408.1024)
+ %maximum.3631.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.245620.5.clone.1)
+ %abs.1503.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3631.3.clone.1)
+ %compare.7154.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1503.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.25888.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3631.3.clone.1, %broadcast.244476.1152)
+ %negate.4511.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3631.3.clone.1)
+ %multiply.25889.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3631.3.clone.1, %negate.4511.5.clone.1)
+ %log-plus-one.1503.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.25889.5.clone.1)
+ %negate.4512.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1503.3.clone.1)
+ %compare.7155.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4512.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.20598.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.20599.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.20600.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.20601.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.20602.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.20603.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.20604.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.20605.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.20606.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.245622.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4512.4.clone.1, %broadcast.244496.640)
+ %sqrt.1503.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4512.4.clone.1)
+ %add.245626.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1503.5.clone.1, %broadcast.244498.640)
+ %select.20607.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7155.3.clone.1, %add.245622.5.clone.1, %add.245626.5.clone.1)
+ %multiply.25890.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20606.3.clone.1, %select.20607.3.clone.1)
+ %add.245627.1.clone.1 = f32[1280,1280]{1,0} add(%select.20605.3.clone.1, %multiply.25890.1.clone.1)
+ %multiply.25891.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245627.1.clone.1, %select.20607.3.clone.1)
+ %add.245628.1.clone.1 = f32[1280,1280]{1,0} add(%select.20604.3.clone.1, %multiply.25891.1.clone.1)
+ %multiply.25892.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245628.1.clone.1, %select.20607.3.clone.1)
+ %add.245629.1.clone.1 = f32[1280,1280]{1,0} add(%select.20603.3.clone.1, %multiply.25892.1.clone.1)
+ %multiply.25893.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245629.1.clone.1, %select.20607.3.clone.1)
+ %add.245631.1.clone.1 = f32[1280,1280]{1,0} add(%select.20602.3.clone.1, %multiply.25893.1.clone.1)
+ %multiply.25894.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245631.1.clone.1, %select.20607.3.clone.1)
+ %add.245632.3.clone.1 = f32[1280,1280]{1,0} add(%select.20601.5.clone.1, %multiply.25894.1.clone.1)
+ %multiply.25895.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245632.3.clone.1, %select.20607.3.clone.1)
+ %add.245633.3.clone.1 = f32[1280,1280]{1,0} add(%select.20600.5.clone.1, %multiply.25895.1.clone.1)
+ %multiply.25896.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245633.3.clone.1, %select.20607.3.clone.1)
+ %add.245634.9.clone.1 = f32[1280,1280]{1,0} add(%select.20599.11.clone.1, %multiply.25896.7.clone.1)
+ %multiply.25897.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245634.9.clone.1, %select.20607.3.clone.1)
+ %add.245636.7.clone.1 = f32[1280,1280]{1,0} add(%select.20598.7.clone.1, %multiply.25897.7.clone.1)
+ %multiply.25898.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245636.7.clone.1, %maximum.3631.3.clone.1)
+ %select.20608.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7154.3.clone.1, %multiply.25888.9.clone.1, %multiply.25898.7.clone.1)
+ %multiply.25899.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20608.7.clone.1, %broadcast.244500.640)
+ %clamp.1147.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.25899.5.clone.1, %broadcast.244501.384)
+ %multiply.25900.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1147.3.clone.1, %broadcast.244502.1)
+ %constant_180474_1_clone_1 = u32[] constant(594086741)
+ %broadcast.254080.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_180474_1_clone_1), dimensions={}
+ %add.249588.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.254080.44.clone.1)
+ %constant_180481_1_clone_1 = u32[] constant(3794055208)
+ %broadcast.254081.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_180481_1_clone_1), dimensions={}
+ %add.249590.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.254081.113.clone.1)
+ %add.249591.35.clone.1 = u32[1280,1280]{1,0} add(%add.249588.37.clone.1, %add.249590.99.clone.1)
+ %shift-left.110121.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249590.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116352.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249590.99.clone.1, %broadcast.244415.6016)
+ %or.115885.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110121.31.clone.1, %shift-right-logical.116352.29.clone.1)
+ %xor.122432.27.clone.1 = u32[1280,1280]{1,0} xor(%add.249591.35.clone.1, %or.115885.29.clone.1)
+ %add.249592.5.clone.1 = u32[1280,1280]{1,0} add(%add.249591.35.clone.1, %xor.122432.27.clone.1)
+ %shift-left.110122.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122432.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116354.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122432.27.clone.1, %broadcast.244417.5760)
+ %or.115886.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110122.9.clone.1, %shift-right-logical.116354.9.clone.1)
+ %xor.122433.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249592.5.clone.1, %or.115886.7.clone.1)
+ %add.249593.3.clone.1 = u32[1280,1280]{1,0} add(%add.249592.5.clone.1, %xor.122433.5.clone.1)
+ %shift-left.110123.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122433.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116355.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122433.5.clone.1, %broadcast.244419.4352)
+ %or.115887.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110123.5.clone.1, %shift-right-logical.116355.5.clone.1)
+ %xor.122434.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249593.3.clone.1, %or.115887.3.clone.1)
+ %add.249595.3.clone.1 = u32[1280,1280]{1,0} add(%add.249593.3.clone.1, %xor.122434.3.clone.1)
+ %add.249596.7.clone.1 = u32[1280,1280]{1,0} add(%add.249595.3.clone.1, %broadcast.254081.113.clone.1)
+ %shift-left.110124.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122434.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116356.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122434.3.clone.1, %broadcast.244418.4352)
+ %or.115888.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110124.5.clone.1, %shift-right-logical.116356.5.clone.1)
+ %xor.122435.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249595.3.clone.1, %or.115888.3.clone.1)
+ %constant_218381_1_clone_1 = u32[] constant(3667697832)
+ %broadcast.254091.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218381_1_clone_1), dimensions={}
+ %add.249597.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122435.3.clone.1, %broadcast.254091.5.clone.1)
+ %add.249598.5.clone.1 = u32[1280,1280]{1,0} add(%add.249596.7.clone.1, %add.249597.5.clone.1)
+ %shift-left.110126.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249597.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.116357.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249597.5.clone.1, %broadcast.244416.5760)
+ %or.115889.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110126.9.clone.1, %shift-right-logical.116357.9.clone.1)
+ %xor.122436.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249598.5.clone.1, %or.115889.7.clone.1)
+ %add.249600.3.clone.1 = u32[1280,1280]{1,0} add(%add.249598.5.clone.1, %xor.122436.5.clone.1)
+ %shift-left.110127.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122436.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.116359.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122436.5.clone.1, %broadcast.244429.2304)
+ %or.115890.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110127.9.clone.1, %shift-right-logical.116359.9.clone.1)
+ %xor.122437.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249600.3.clone.1, %or.115890.7.clone.1)
+ %add.249601.3.clone.1 = u32[1280,1280]{1,0} add(%add.249600.3.clone.1, %xor.122437.5.clone.1)
+ %shift-left.110128.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122437.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.116360.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122437.5.clone.1, %broadcast.244430.4608)
+ %or.115891.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110128.9.clone.1, %shift-right-logical.116360.9.clone.1)
+ %xor.122438.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249601.3.clone.1, %or.115891.7.clone.1)
+ %add.249602.3.clone.1 = u32[1280,1280]{1,0} add(%add.249601.3.clone.1, %xor.122438.5.clone.1)
+ %constant_180483_1_clone_1 = u32[] constant(3667697831)
+ %broadcast.254098.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_180483_1_clone_1), dimensions={}
+ %add.249603.7.clone.1 = u32[1280,1280]{1,0} add(%add.249602.3.clone.1, %broadcast.254098.24.clone.1)
+ %shift-left.110129.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122438.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.116361.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122438.5.clone.1, %broadcast.244434.2816)
+ %or.115892.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110129.11.clone.1, %shift-right-logical.116361.11.clone.1)
+ %xor.122441.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249602.3.clone.1, %or.115892.9.clone.1)
+ %constant_218383_1_clone_1 = u32[] constant(594086743)
+ %broadcast.254101.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218383_1_clone_1), dimensions={}
+ %add.249604.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122441.7.clone.1, %broadcast.254101.5.clone.1)
+ %add.249606.5.clone.1 = u32[1280,1280]{1,0} add(%add.249603.7.clone.1, %add.249604.5.clone.1)
+ %shift-left.110131.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249604.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116362.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249604.5.clone.1, %broadcast.244415.6016)
+ %or.115893.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110131.9.clone.1, %shift-right-logical.116362.9.clone.1)
+ %xor.122442.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249606.5.clone.1, %or.115893.7.clone.1)
+ %add.249610.3.clone.1 = u32[1280,1280]{1,0} add(%add.249606.5.clone.1, %xor.122442.5.clone.1)
+ %shift-left.110132.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122442.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116364.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122442.5.clone.1, %broadcast.244417.5760)
+ %or.115894.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110132.9.clone.1, %shift-right-logical.116364.9.clone.1)
+ %xor.122443.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249610.3.clone.1, %or.115894.7.clone.1)
+ %add.249611.3.clone.1 = u32[1280,1280]{1,0} add(%add.249610.3.clone.1, %xor.122443.5.clone.1)
+ %shift-left.110133.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122443.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116365.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122443.5.clone.1, %broadcast.244419.4352)
+ %or.115895.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110133.7.clone.1, %shift-right-logical.116365.7.clone.1)
+ %xor.122444.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249611.3.clone.1, %or.115895.5.clone.1)
+ %add.249612.3.clone.1 = u32[1280,1280]{1,0} add(%add.249611.3.clone.1, %xor.122444.3.clone.1)
+ %add.249613.7.clone.1 = u32[1280,1280]{1,0} add(%add.249612.3.clone.1, %broadcast.254080.44.clone.1)
+ %shift-left.110134.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122444.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116366.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122444.3.clone.1, %broadcast.244418.4352)
+ %or.115896.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110134.7.clone.1, %shift-right-logical.116366.7.clone.1)
+ %xor.122445.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249612.3.clone.1, %or.115896.5.clone.1)
+ %constant_218385_1_clone_1 = u32[] constant(3794055211)
+ %broadcast.254111.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218385_1_clone_1), dimensions={}
+ %add.249615.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122445.3.clone.1, %broadcast.254111.5.clone.1)
+ %add.249616.5.clone.1 = u32[1280,1280]{1,0} add(%add.249613.7.clone.1, %add.249615.5.clone.1)
+ %shift-left.110136.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249615.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.116367.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249615.5.clone.1, %broadcast.244416.5760)
+ %or.115897.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110136.9.clone.1, %shift-right-logical.116367.9.clone.1)
+ %xor.122446.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249616.5.clone.1, %or.115897.7.clone.1)
+ %add.249617.3.clone.1 = u32[1280,1280]{1,0} add(%add.249616.5.clone.1, %xor.122446.5.clone.1)
+ %shift-left.110137.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122446.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.116368.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122446.5.clone.1, %broadcast.244429.2304)
+ %or.115898.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110137.9.clone.1, %shift-right-logical.116368.9.clone.1)
+ %xor.122447.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249617.3.clone.1, %or.115898.7.clone.1)
+ %add.249618.3.clone.1 = u32[1280,1280]{1,0} add(%add.249617.3.clone.1, %xor.122447.5.clone.1)
+ %shift-left.110138.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122447.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.116369.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122447.5.clone.1, %broadcast.244430.4608)
+ %or.115899.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110138.9.clone.1, %shift-right-logical.116369.9.clone.1)
+ %xor.122448.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249618.3.clone.1, %or.115899.7.clone.1)
+ %add.249620.3.clone.1 = u32[1280,1280]{1,0} add(%add.249618.3.clone.1, %xor.122448.5.clone.1)
+ %add.249621.7.clone.1 = u32[1280,1280]{1,0} add(%add.249620.3.clone.1, %broadcast.254081.113.clone.1)
+ %shift-left.110139.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122448.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.116370.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122448.5.clone.1, %broadcast.244434.2816)
+ %or.115900.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110139.11.clone.1, %shift-right-logical.116370.11.clone.1)
+ %xor.122450.7.clone.1 = u32[1280,1280]{1,0} xor(%add.249620.3.clone.1, %or.115900.9.clone.1)
+ %constant_218387_1_clone_1 = u32[] constant(3667697835)
+ %broadcast.254121.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218387_1_clone_1), dimensions={}
+ %add.249622.5.clone.1 = u32[1280,1280]{1,0} add(%xor.122450.7.clone.1, %broadcast.254121.5.clone.1)
+ %add.249623.5.clone.1 = u32[1280,1280]{1,0} add(%add.249621.7.clone.1, %add.249622.5.clone.1)
+ %shift-left.110140.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.249622.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.116371.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.249622.5.clone.1, %broadcast.244415.6016)
+ %or.115901.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110140.9.clone.1, %shift-right-logical.116371.9.clone.1)
+ %xor.122451.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249623.5.clone.1, %or.115901.7.clone.1)
+ %add.249625.3.clone.1 = u32[1280,1280]{1,0} add(%add.249623.5.clone.1, %xor.122451.5.clone.1)
+ %shift-left.110141.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122451.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.116372.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122451.5.clone.1, %broadcast.244417.5760)
+ %or.115902.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110141.9.clone.1, %shift-right-logical.116372.9.clone.1)
+ %xor.122452.5.clone.1 = u32[1280,1280]{1,0} xor(%add.249625.3.clone.1, %or.115902.7.clone.1)
+ %add.249626.3.clone.1 = u32[1280,1280]{1,0} add(%add.249625.3.clone.1, %xor.122452.5.clone.1)
+ %shift-left.110142.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122452.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.116374.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122452.5.clone.1, %broadcast.244419.4352)
+ %or.115903.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110142.5.clone.1, %shift-right-logical.116374.5.clone.1)
+ %xor.122453.3.clone.1 = u32[1280,1280]{1,0} xor(%add.249626.3.clone.1, %or.115903.3.clone.1)
+ %add.249627.3.clone.1 = u32[1280,1280]{1,0} add(%add.249626.3.clone.1, %xor.122453.3.clone.1)
+ %add.249628.17.clone.1 = u32[1280,1280]{1,0} add(%add.249627.3.clone.1, %broadcast.254098.24.clone.1)
+ %shift-left.110143.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.122453.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.116375.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122453.3.clone.1, %broadcast.244418.4352)
+ %or.115904.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.110143.5.clone.1, %shift-right-logical.116375.5.clone.1)
+ %xor.122455.15.clone.1 = u32[1280,1280]{1,0} xor(%add.249627.3.clone.1, %or.115904.3.clone.1)
+ %constant_218388_1_clone_1 = u32[] constant(594086746)
+ %broadcast.254131.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_218388_1_clone_1), dimensions={}
+ %add.249629.19.clone.1 = u32[1280,1280]{1,0} add(%xor.122455.15.clone.1, %broadcast.254131.19.clone.1)
+ %xor.122456.17.clone.1 = u32[1280,1280]{1,0} xor(%add.249628.17.clone.1, %add.249629.19.clone.1)
+ %shift-right-logical.116376.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.122456.17.clone.1, %broadcast.244468.1920)
+ %or.115905.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.116376.15.clone.1, %broadcast.244469.1664)
+ %bitcast-convert.5779.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.115905.13.clone.1)
+ %add.249631.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5779.11.clone.1, %broadcast.244470.1152)
+ %multiply.26723.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249631.9.clone.1, %broadcast.244471.896)
+ %add.249635.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.26723.7.clone.1, %broadcast.244408.1024)
+ %maximum.3711.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.249635.5.clone.1)
+ %abs.1557.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3711.3.clone.1)
+ %compare.7263.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1557.3.clone.1, %broadcast.244475.384), direction=EQ
+ %multiply.26724.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3711.3.clone.1, %broadcast.244476.1152)
+ %negate.4619.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3711.3.clone.1)
+ %multiply.26725.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3711.3.clone.1, %negate.4619.5.clone.1)
+ %log-plus-one.1557.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.26725.5.clone.1)
+ %negate.4620.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1557.3.clone.1)
+ %compare.7264.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4620.4.clone.1, %broadcast.244477.384), direction=LT
+ %select.21213.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %broadcast.244478.896, %broadcast.244479.896)
+ %select.21214.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408)
+ %select.21215.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %broadcast.244482.640, %broadcast.244483.640)
+ %select.21216.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %broadcast.244484.640, %broadcast.244485.640)
+ %select.21217.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %broadcast.244486.384, %broadcast.244487.384)
+ %select.21218.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %broadcast.244488.384, %broadcast.244489.384)
+ %select.21219.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %broadcast.244490.384, %broadcast.244491.384)
+ %select.21220.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %broadcast.244492.384, %broadcast.244493.384)
+ %select.21221.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %broadcast.244494.384, %broadcast.244495.384)
+ %add.249636.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4620.4.clone.1, %broadcast.244496.640)
+ %sqrt.1557.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4620.4.clone.1)
+ %add.249637.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1557.5.clone.1, %broadcast.244498.640)
+ %select.21222.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7264.3.clone.1, %add.249636.5.clone.1, %add.249637.5.clone.1)
+ %multiply.26726.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.21221.3.clone.1, %select.21222.3.clone.1)
+ %add.249638.1.clone.1 = f32[1280,1280]{1,0} add(%select.21220.3.clone.1, %multiply.26726.1.clone.1)
+ %multiply.26727.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249638.1.clone.1, %select.21222.3.clone.1)
+ %add.249640.1.clone.1 = f32[1280,1280]{1,0} add(%select.21219.3.clone.1, %multiply.26727.1.clone.1)
+ %multiply.26728.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249640.1.clone.1, %select.21222.3.clone.1)
+ %add.249641.1.clone.1 = f32[1280,1280]{1,0} add(%select.21218.3.clone.1, %multiply.26728.1.clone.1)
+ %multiply.26729.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249641.1.clone.1, %select.21222.3.clone.1)
+ %add.249642.1.clone.1 = f32[1280,1280]{1,0} add(%select.21217.3.clone.1, %multiply.26729.1.clone.1)
+ %multiply.26730.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249642.1.clone.1, %select.21222.3.clone.1)
+ %add.249643.3.clone.1 = f32[1280,1280]{1,0} add(%select.21216.5.clone.1, %multiply.26730.1.clone.1)
+ %multiply.26731.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.249643.3.clone.1, %select.21222.3.clone.1)
+ %add.249645.3.clone.1 = f32[1280,1280]{1,0} add(%select.21215.5.clone.1, %multiply.26731.1.clone.1)
+ %multiply.26732.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249645.3.clone.1, %select.21222.3.clone.1)
+ %add.249646.9.clone.1 = f32[1280,1280]{1,0} add(%select.21214.11.clone.1, %multiply.26732.7.clone.1)
+ %multiply.26733.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249646.9.clone.1, %select.21222.3.clone.1)
+ %add.249647.7.clone.1 = f32[1280,1280]{1,0} add(%select.21213.7.clone.1, %multiply.26733.7.clone.1)
+ %multiply.26734.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.249647.7.clone.1, %maximum.3711.3.clone.1)
+ %select.21223.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7263.3.clone.1, %multiply.26724.9.clone.1, %multiply.26734.7.clone.1)
+ %multiply.26735.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.21223.7.clone.1, %broadcast.244500.640)
+ %clamp.1201.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.26735.5.clone.1, %broadcast.244501.384)
+ %multiply.26736.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1201.3.clone.1, %broadcast.244502.1)
+ %constant_163637_1_clone_1 = u32[] constant(1982261340)
+ %broadcast.246792.44.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_163637_1_clone_1), dimensions={}
+ %add.245442.37.clone.1 = u32[1280,1280]{1,0} add(%convert.3610.4865, %broadcast.246792.44.clone.1)
+ %constant_163660_1_clone_1 = u32[] constant(4196244297)
+ %broadcast.246793.113.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_163660_1_clone_1), dimensions={}
+ %add.245443.99.clone.1 = u32[1280,1280]{1,0} add(%convert.3611.12801, %broadcast.246793.113.clone.1)
+ %add.245444.35.clone.1 = u32[1280,1280]{1,0} add(%add.245442.37.clone.1, %add.245443.99.clone.1)
+ %shift-left.108340.31.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245443.99.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114449.29.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245443.99.clone.1, %broadcast.244415.6016)
+ %or.113982.29.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108340.31.clone.1, %shift-right-logical.114449.29.clone.1)
+ %xor.120531.27.clone.1 = u32[1280,1280]{1,0} xor(%add.245444.35.clone.1, %or.113982.29.clone.1)
+ %add.245445.5.clone.1 = u32[1280,1280]{1,0} add(%add.245444.35.clone.1, %xor.120531.27.clone.1)
+ %shift-left.108341.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120531.27.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114450.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120531.27.clone.1, %broadcast.244417.5760)
+ %or.113983.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108341.9.clone.1, %shift-right-logical.114450.9.clone.1)
+ %xor.120532.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245445.5.clone.1, %or.113983.7.clone.1)
+ %add.245447.3.clone.1 = u32[1280,1280]{1,0} add(%add.245445.5.clone.1, %xor.120532.5.clone.1)
+ %shift-left.108342.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120532.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114452.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120532.5.clone.1, %broadcast.244419.4352)
+ %or.113984.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108342.5.clone.1, %shift-right-logical.114452.5.clone.1)
+ %xor.120533.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245447.3.clone.1, %or.113984.3.clone.1)
+ %add.245448.3.clone.1 = u32[1280,1280]{1,0} add(%add.245447.3.clone.1, %xor.120533.3.clone.1)
+ %add.245449.7.clone.1 = u32[1280,1280]{1,0} add(%add.245448.3.clone.1, %broadcast.246793.113.clone.1)
+ %shift-left.108343.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120533.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114453.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120533.3.clone.1, %broadcast.244418.4352)
+ %or.113985.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108343.5.clone.1, %shift-right-logical.114453.5.clone.1)
+ %xor.120534.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245448.3.clone.1, %or.113985.3.clone.1)
+ %constant_217914_1_clone_1 = u32[] constant(2548721872)
+ %broadcast.246810.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217914_1_clone_1), dimensions={}
+ %add.245450.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120534.3.clone.1, %broadcast.246810.5.clone.1)
+ %add.245452.5.clone.1 = u32[1280,1280]{1,0} add(%add.245449.7.clone.1, %add.245450.5.clone.1)
+ %shift-left.108344.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245450.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114454.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245450.5.clone.1, %broadcast.244416.5760)
+ %or.113987.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108344.9.clone.1, %shift-right-logical.114454.9.clone.1)
+ %xor.120535.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245452.5.clone.1, %or.113987.7.clone.1)
+ %add.245453.3.clone.1 = u32[1280,1280]{1,0} add(%add.245452.5.clone.1, %xor.120535.5.clone.1)
+ %shift-left.108345.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120535.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114455.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120535.5.clone.1, %broadcast.244429.2304)
+ %or.113988.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108345.9.clone.1, %shift-right-logical.114455.9.clone.1)
+ %xor.120536.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245453.3.clone.1, %or.113988.7.clone.1)
+ %add.245454.3.clone.1 = u32[1280,1280]{1,0} add(%add.245453.3.clone.1, %xor.120536.5.clone.1)
+ %shift-left.108346.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120536.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114456.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120536.5.clone.1, %broadcast.244430.4608)
+ %or.113989.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108346.9.clone.1, %shift-right-logical.114456.9.clone.1)
+ %xor.120537.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245454.3.clone.1, %or.113989.7.clone.1)
+ %add.245455.3.clone.1 = u32[1280,1280]{1,0} add(%add.245454.3.clone.1, %xor.120537.5.clone.1)
+ %constant_163662_1_clone_1 = u32[] constant(2548721871)
+ %broadcast.246824.24.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_163662_1_clone_1), dimensions={}
+ %add.245457.7.clone.1 = u32[1280,1280]{1,0} add(%add.245455.3.clone.1, %broadcast.246824.24.clone.1)
+ %shift-left.108347.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120537.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114457.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120537.5.clone.1, %broadcast.244434.2816)
+ %or.113990.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108347.11.clone.1, %shift-right-logical.114457.11.clone.1)
+ %xor.120538.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245455.3.clone.1, %or.113990.9.clone.1)
+ %constant_217915_1_clone_1 = u32[] constant(1982261342)
+ %broadcast.246828.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217915_1_clone_1), dimensions={}
+ %add.245458.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120538.7.clone.1, %broadcast.246828.5.clone.1)
+ %add.245459.5.clone.1 = u32[1280,1280]{1,0} add(%add.245457.7.clone.1, %add.245458.5.clone.1)
+ %shift-left.108348.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245458.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114458.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245458.5.clone.1, %broadcast.244415.6016)
+ %or.113992.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108348.9.clone.1, %shift-right-logical.114458.9.clone.1)
+ %xor.120539.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245459.5.clone.1, %or.113992.7.clone.1)
+ %add.245460.3.clone.1 = u32[1280,1280]{1,0} add(%add.245459.5.clone.1, %xor.120539.5.clone.1)
+ %shift-left.108349.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120539.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114459.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120539.5.clone.1, %broadcast.244417.5760)
+ %or.113993.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108349.9.clone.1, %shift-right-logical.114459.9.clone.1)
+ %xor.120540.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245460.3.clone.1, %or.113993.7.clone.1)
+ %add.245461.3.clone.1 = u32[1280,1280]{1,0} add(%add.245460.3.clone.1, %xor.120540.5.clone.1)
+ %shift-left.108350.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120540.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114460.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120540.5.clone.1, %broadcast.244419.4352)
+ %or.113994.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108350.7.clone.1, %shift-right-logical.114460.7.clone.1)
+ %xor.120541.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245461.3.clone.1, %or.113994.5.clone.1)
+ %add.245463.3.clone.1 = u32[1280,1280]{1,0} add(%add.245461.3.clone.1, %xor.120541.3.clone.1)
+ %add.245467.7.clone.1 = u32[1280,1280]{1,0} add(%add.245463.3.clone.1, %broadcast.246792.44.clone.1)
+ %shift-left.108351.7.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120541.3.clone.1, %broadcast.244419.4352)
+ %shift-right-logical.114462.7.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120541.3.clone.1, %broadcast.244418.4352)
+ %or.113995.5.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108351.7.clone.1, %shift-right-logical.114462.7.clone.1)
+ %xor.120542.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245463.3.clone.1, %or.113995.5.clone.1)
+ %constant_217916_1_clone_1 = u32[] constant(4196244300)
+ %broadcast.246840.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217916_1_clone_1), dimensions={}
+ %add.245468.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120542.3.clone.1, %broadcast.246840.5.clone.1)
+ %add.245469.5.clone.1 = u32[1280,1280]{1,0} add(%add.245467.7.clone.1, %add.245468.5.clone.1)
+ %shift-left.108352.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245468.5.clone.1, %broadcast.244417.5760)
+ %shift-right-logical.114463.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245468.5.clone.1, %broadcast.244416.5760)
+ %or.113997.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108352.9.clone.1, %shift-right-logical.114463.9.clone.1)
+ %xor.120543.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245469.5.clone.1, %or.113997.7.clone.1)
+ %add.245470.3.clone.1 = u32[1280,1280]{1,0} add(%add.245469.5.clone.1, %xor.120543.5.clone.1)
+ %shift-left.108353.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120543.5.clone.1, %broadcast.244428.2304)
+ %shift-right-logical.114464.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120543.5.clone.1, %broadcast.244429.2304)
+ %or.113998.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108353.9.clone.1, %shift-right-logical.114464.9.clone.1)
+ %xor.120544.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245470.3.clone.1, %or.113998.7.clone.1)
+ %add.245472.3.clone.1 = u32[1280,1280]{1,0} add(%add.245470.3.clone.1, %xor.120544.5.clone.1)
+ %shift-left.108354.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120544.5.clone.1, %broadcast.244430.4608)
+ %shift-right-logical.114465.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120544.5.clone.1, %broadcast.244430.4608)
+ %or.113999.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108354.9.clone.1, %shift-right-logical.114465.9.clone.1)
+ %xor.120545.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245472.3.clone.1, %or.113999.7.clone.1)
+ %add.245473.3.clone.1 = u32[1280,1280]{1,0} add(%add.245472.3.clone.1, %xor.120545.5.clone.1)
+ %add.245474.7.clone.1 = u32[1280,1280]{1,0} add(%add.245473.3.clone.1, %broadcast.246793.113.clone.1)
+ %shift-left.108355.11.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120545.5.clone.1, %broadcast.244433.2816)
+ %shift-right-logical.114467.11.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120545.5.clone.1, %broadcast.244434.2816)
+ %or.114000.9.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108355.11.clone.1, %shift-right-logical.114467.11.clone.1)
+ %xor.120546.7.clone.1 = u32[1280,1280]{1,0} xor(%add.245473.3.clone.1, %or.114000.9.clone.1)
+ %constant_217917_1_clone_1 = u32[] constant(2548721875)
+ %broadcast.246850.5.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217917_1_clone_1), dimensions={}
+ %add.245475.5.clone.1 = u32[1280,1280]{1,0} add(%xor.120546.7.clone.1, %broadcast.246850.5.clone.1)
+ %add.245477.5.clone.1 = u32[1280,1280]{1,0} add(%add.245474.7.clone.1, %add.245475.5.clone.1)
+ %shift-left.108356.9.clone.1 = u32[1280,1280]{1,0} shift-left(%add.245475.5.clone.1, %broadcast.244414.6272)
+ %shift-right-logical.114468.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%add.245475.5.clone.1, %broadcast.244415.6016)
+ %or.114002.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108356.9.clone.1, %shift-right-logical.114468.9.clone.1)
+ %xor.120547.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245477.5.clone.1, %or.114002.7.clone.1)
+ %add.245478.3.clone.1 = u32[1280,1280]{1,0} add(%add.245477.5.clone.1, %xor.120547.5.clone.1)
+ %shift-left.108357.9.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120547.5.clone.1, %broadcast.244416.5760)
+ %shift-right-logical.114469.9.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120547.5.clone.1, %broadcast.244417.5760)
+ %or.114003.7.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108357.9.clone.1, %shift-right-logical.114469.9.clone.1)
+ %xor.120548.5.clone.1 = u32[1280,1280]{1,0} xor(%add.245478.3.clone.1, %or.114003.7.clone.1)
+ %add.245479.3.clone.1 = u32[1280,1280]{1,0} add(%add.245478.3.clone.1, %xor.120548.5.clone.1)
+ %shift-left.108358.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120548.5.clone.1, %broadcast.244418.4352)
+ %shift-right-logical.114470.5.clone.1 = u32[1280,1280]{1,0}
shift-right-logical(%xor.120548.5.clone.1, %broadcast.244419.4352) + %or.114004.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108358.5.clone.1, %shift-right-logical.114470.5.clone.1) + %xor.120549.3.clone.1 = u32[1280,1280]{1,0} xor(%add.245479.3.clone.1, %or.114004.3.clone.1) + %add.245480.3.clone.1 = u32[1280,1280]{1,0} add(%add.245479.3.clone.1, %xor.120549.3.clone.1) + %add.245482.17.clone.1 = u32[1280,1280]{1,0} add(%add.245480.3.clone.1, %broadcast.246824.24.clone.1) + %shift-left.108359.5.clone.1 = u32[1280,1280]{1,0} shift-left(%xor.120549.3.clone.1, %broadcast.244419.4352) + %shift-right-logical.114472.5.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120549.3.clone.1, %broadcast.244418.4352) + %or.114005.3.clone.1 = u32[1280,1280]{1,0} or(%shift-left.108359.5.clone.1, %shift-right-logical.114472.5.clone.1) + %xor.120550.15.clone.1 = u32[1280,1280]{1,0} xor(%add.245480.3.clone.1, %or.114005.3.clone.1) + %constant_217918_1_clone_1 = u32[] constant(1982261345) + %broadcast.246862.19.clone.1 = u32[1280,1280]{1,0} broadcast(%constant_217918_1_clone_1), dimensions={} + %add.245483.19.clone.1 = u32[1280,1280]{1,0} add(%xor.120550.15.clone.1, %broadcast.246862.19.clone.1) + %xor.120551.17.clone.1 = u32[1280,1280]{1,0} xor(%add.245482.17.clone.1, %add.245483.19.clone.1) + %shift-right-logical.114473.15.clone.1 = u32[1280,1280]{1,0} shift-right-logical(%xor.120551.17.clone.1, %broadcast.244468.1920) + %or.114006.13.clone.1 = u32[1280,1280]{1,0} or(%shift-right-logical.114473.15.clone.1, %broadcast.244469.1664) + %bitcast-convert.5696.11.clone.1 = f32[1280,1280]{1,0} bitcast-convert(%or.114006.13.clone.1) + %add.245484.9.clone.1 = f32[1280,1280]{1,0} add(%bitcast-convert.5696.11.clone.1, %broadcast.244470.1152) + %multiply.25869.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245484.9.clone.1, %broadcast.244471.896) + %add.245485.5.clone.1 = f32[1280,1280]{1,0} add(%multiply.25869.7.clone.1, %broadcast.244408.1024) + %maximum.3628.3.clone.1 = f32[1280,1280]{1,0} maximum(%broadcast.244408.1024, %add.245485.5.clone.1) + %abs.1502.3.clone.1 = f32[1280,1280]{1,0} abs(%maximum.3628.3.clone.1) + %compare.7152.3.clone.1 = pred[1280,1280]{1,0} compare(%abs.1502.3.clone.1, %broadcast.244475.384), direction=EQ + %multiply.25870.9.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3628.3.clone.1, %broadcast.244476.1152) + %negate.4509.5.clone.1 = f32[1280,1280]{1,0} negate(%maximum.3628.3.clone.1) + %multiply.25871.5.clone.1 = f32[1280,1280]{1,0} multiply(%maximum.3628.3.clone.1, %negate.4509.5.clone.1) + %log-plus-one.1502.3.clone.1 = f32[1280,1280]{1,0} log-plus-one(%multiply.25871.5.clone.1) + %negate.4510.4.clone.1 = f32[1280,1280]{1,0} negate(%log-plus-one.1502.3.clone.1) + %compare.7153.3.clone.1 = pred[1280,1280]{1,0} compare(%negate.4510.4.clone.1, %broadcast.244477.384), direction=LT + %select.20587.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, %broadcast.244478.896, %broadcast.244479.896) + %select.20588.11.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, %broadcast.244480.1408, %broadcast.244481.1408) + %select.20589.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, %broadcast.244482.640, %broadcast.244483.640) + %select.20590.5.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, %broadcast.244484.640, %broadcast.244485.640) + %select.20591.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, %broadcast.244486.384, %broadcast.244487.384) + %select.20592.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, 
%broadcast.244488.384, %broadcast.244489.384) + %select.20593.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, %broadcast.244490.384, %broadcast.244491.384) + %select.20594.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, %broadcast.244492.384, %broadcast.244493.384) + %select.20595.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, %broadcast.244494.384, %broadcast.244495.384) + %add.245486.5.clone.1 = f32[1280,1280]{1,0} add(%negate.4510.4.clone.1, %broadcast.244496.640) + %sqrt.1502.5.clone.1 = f32[1280,1280]{1,0} sqrt(%negate.4510.4.clone.1) + %add.245488.5.clone.1 = f32[1280,1280]{1,0} add(%sqrt.1502.5.clone.1, %broadcast.244498.640) + %select.20596.3.clone.1 = f32[1280,1280]{1,0} select(%compare.7153.3.clone.1, %add.245486.5.clone.1, %add.245488.5.clone.1) + %multiply.25872.1.clone.1 = f32[1280,1280]{1,0} multiply(%select.20595.3.clone.1, %select.20596.3.clone.1) + %add.245492.1.clone.1 = f32[1280,1280]{1,0} add(%select.20594.3.clone.1, %multiply.25872.1.clone.1) + %multiply.25873.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245492.1.clone.1, %select.20596.3.clone.1) + %add.245493.1.clone.1 = f32[1280,1280]{1,0} add(%select.20593.3.clone.1, %multiply.25873.1.clone.1) + %multiply.25874.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245493.1.clone.1, %select.20596.3.clone.1) + %add.245494.1.clone.1 = f32[1280,1280]{1,0} add(%select.20592.3.clone.1, %multiply.25874.1.clone.1) + %multiply.25875.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245494.1.clone.1, %select.20596.3.clone.1) + %add.245495.1.clone.1 = f32[1280,1280]{1,0} add(%select.20591.3.clone.1, %multiply.25875.1.clone.1) + %multiply.25876.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245495.1.clone.1, %select.20596.3.clone.1) + %add.245497.3.clone.1 = f32[1280,1280]{1,0} add(%select.20590.5.clone.1, %multiply.25876.1.clone.1) + %multiply.25877.1.clone.1 = f32[1280,1280]{1,0} multiply(%add.245497.3.clone.1, %select.20596.3.clone.1) + %add.245498.3.clone.1 = f32[1280,1280]{1,0} add(%select.20589.5.clone.1, %multiply.25877.1.clone.1) + %multiply.25878.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245498.3.clone.1, %select.20596.3.clone.1) + %add.245499.9.clone.1 = f32[1280,1280]{1,0} add(%select.20588.11.clone.1, %multiply.25878.7.clone.1) + %multiply.25879.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245499.9.clone.1, %select.20596.3.clone.1) + %add.245500.7.clone.1 = f32[1280,1280]{1,0} add(%select.20587.7.clone.1, %multiply.25879.7.clone.1) + %multiply.25880.7.clone.1 = f32[1280,1280]{1,0} multiply(%add.245500.7.clone.1, %maximum.3628.3.clone.1) + %select.20597.7.clone.1 = f32[1280,1280]{1,0} select(%compare.7152.3.clone.1, %multiply.25870.9.clone.1, %multiply.25880.7.clone.1) + %multiply.25881.5.clone.1 = f32[1280,1280]{1,0} multiply(%select.20597.7.clone.1, %broadcast.244500.640) + %clamp.1146.3.clone.1 = f32[1280,1280]{1,0} clamp(%broadcast.244407.384, %multiply.25881.5.clone.1, %broadcast.244501.384) + %multiply.25882.1.clone.1 = f32[1280,1280]{1,0} multiply(%clamp.1146.3.clone.1, %broadcast.244502.1) + ROOT %tuple.3392 = (f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=5*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=10*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=15*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, 
/*index=20*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=25*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=30*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=35*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=40*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=45*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=50*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=55*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=60*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=65*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=70*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=75*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=80*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=85*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=90*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}) tuple(%multiply.25596.1, %multiply.25624.1.clone.1, %multiply.26588.1.clone.1, %multiply.26602.1.clone.1, %multiply.27082.1.clone.1, /*index=5*/%multiply.26570.1.clone.1, %multiply.27323.1.clone.1, %multiply.26556.1.clone.1, %multiply.27064.1.clone.1, %multiply.26541.1.clone.1, /*index=10*/%multiply.27443.1.clone.1, %multiply.26527.1.clone.1, %multiply.27047.1.clone.1, %multiply.26509.1.clone.1, %multiply.27305.1.clone.1, /*index=15*/%multiply.26489.1.clone.1, %multiply.27026.1.clone.1, %multiply.26470.1.clone.1, %multiply.27505.1.clone.1, %multiply.26455.1.clone.1, /*index=20*/%multiply.27010.1.clone.1, %multiply.26437.1.clone.1, %multiply.27291.1.clone.1, %multiply.26423.1.clone.1, %multiply.26992.1.clone.1, /*index=25*/%multiply.26409.1.clone.1, %multiply.27425.1.clone.1, %multiply.26395.1.clone.1, %multiply.26978.1.clone.1, %multiply.26376.1.clone.1, /*index=30*/%multiply.27277.1.clone.1, %multiply.26361.1.clone.1, %multiply.26964.1.clone.1, %multiply.26347.1.clone.1, %multiply.27533.1.clone.1, /*index=35*/%multiply.26333.1.clone.1, %multiply.26950.1.clone.1, %multiply.26315.1.clone.1, %multiply.27263.1.clone.1, %multiply.26301.1.clone.1, /*index=40*/%multiply.26932.1.clone.1, %multiply.26287.1.clone.1, %multiply.27411.1.clone.1, %multiply.26273.1.clone.1, %multiply.26917.1.clone.1, /*index=45*/%multiply.26255.1.clone.1, %multiply.27245.1.clone.1, %multiply.26241.1.clone.1, %multiply.26902.1.clone.1, %multiply.26227.1.clone.1, /*index=50*/%multiply.27487.1.clone.1, %multiply.26213.1.clone.1, %multiply.26888.1.clone.1, %multiply.26195.1.clone.1, %multiply.27231.1.clone.1, /*index=55*/%multiply.26181.1.clone.1, %multiply.26870.1.clone.1, %multiply.26167.1.clone.1, %multiply.27397.1.clone.1, %multiply.26153.1.clone.1, /*index=60*/%multiply.26856.1.clone.1, %multiply.26135.1.clone.1, 
%multiply.27217.1.clone.1, %multiply.26121.1.clone.1, %multiply.26842.1.clone.1, /*index=65*/%multiply.26107.1.clone.1, %multiply.27547.1.clone.1, %multiply.26093.1.clone.1, %multiply.26828.1.clone.1, %multiply.26075.1.clone.1, /*index=70*/%multiply.27203.1.clone.1, %multiply.26061.1.clone.1, %multiply.26810.1.clone.1, %multiply.26047.1.clone.1, %multiply.27383.1.clone.1, /*index=75*/%multiply.26033.1.clone.1, %multiply.26796.1.clone.1, %multiply.26015.1.clone.1, %multiply.27185.1.clone.1, %multiply.26001.1.clone.1, /*index=80*/%multiply.26782.1.clone.1, %multiply.25986.1.clone.1, %multiply.27472.1.clone.1, %multiply.25972.1.clone.1, %multiply.26768.1.clone.1, /*index=85*/%multiply.25954.1.clone.1, %multiply.27171.1.clone.1, %multiply.25932.1.clone.1, %multiply.26750.1.clone.1, %multiply.25914.1.clone.1, /*index=90*/%multiply.27365.1.clone.1, %multiply.25900.1.clone.1, %multiply.26736.1.clone.1, %multiply.25882.1.clone.1)
+}
+// CHECK-PARTITIONED-HLO-COUNT-100: func.func private
+// CHECK-COUNT-100: func.func private
\ No newline at end of file
diff --git a/third_party/xla/xla/codegen/emitters/computation_partitioner.cc b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc
index 910f7ad7e81f58..aefcaf0ca5be95 100644
--- a/third_party/xla/xla/codegen/emitters/computation_partitioner.cc
+++ b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc
@@ -169,6 +169,8 @@ bool IsEvaluatedMoreThanOnce(const HloInstruction* instr) {
 using SubgraphId = int;

+constexpr int kMaxHloOpsPerSubgraph = 2000;
+
 // HloSubgraphData is associated with a single HLO instruction and contains
 // the necessary information to partition the computation into subgraphs.
 struct HloSubgraphData {
@@ -180,6 +182,8 @@ struct HloSubgraphData {
   SubgraphId subgraph_id = -1;
   // Whether the instruction is a root of the subgraph.
   bool is_root = false;
+  // Number of users.
+  int num_users = 0;
 };

 PartitionedComputation::PartitionedComputation(
@@ -197,6 +201,7 @@ PartitionedComputation::PartitionedComputation(
   SubgraphId subgraph_count = 0;

   std::vector<HloSubgraphData> id_to_subgraph_data(pre_order.size());
+  std::vector<int> num_ops_per_subgraph;

   // Iterate over the use-def chains and check if the instruction should be
   // placed in a separate function.
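+  // In addition to the root/indexing criteria below, subgraphs are capped at
+  // kMaxHloOpsPerSubgraph instructions; num_ops_per_subgraph tracks the
+  // running size so an oversized subgraph is split at the next instruction.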
for (auto [instr_index, instr] : llvm::enumerate(pre_order)) { @@ -208,12 +213,26 @@ PartitionedComputation::PartitionedComputation( is_subgraph_root(instr) || instr_subgraph_data.user_subgraph_ids.size() != 1 || instr_subgraph_data.indexings.size() > 1; - if (instr_subgraph_data.is_root) { + bool is_large_subgraph = + instr_subgraph_data.subgraph_id > -1 && + num_ops_per_subgraph[instr_subgraph_data.subgraph_id] >= + kMaxHloOpsPerSubgraph; + if (instr_subgraph_data.is_root || is_large_subgraph) { instr_subgraph_data.subgraph_id = subgraph_count++; instr_subgraph_data.indexings.clear(); + num_ops_per_subgraph.push_back(1); } else { instr_subgraph_data.subgraph_id = *instr_subgraph_data.user_subgraph_ids.begin(); + ++num_ops_per_subgraph.at(instr_subgraph_data.subgraph_id); + } + if (num_ops_per_subgraph.at(instr_subgraph_data.subgraph_id) > + kMaxHloOpsPerSubgraph && + instr_subgraph_data.num_users == 1) { + instr_subgraph_data.subgraph_id = subgraph_count++; + instr_subgraph_data.is_root = true; + instr_subgraph_data.indexings.clear(); + num_ops_per_subgraph.push_back(1); } auto operands_indexing = ComputeOperandIndexingMaps(instr, mlir_context); // Iterate over the operands and add the func_ids of the current instruction @@ -222,6 +241,7 @@ PartitionedComputation::PartitionedComputation( llvm::zip(instr->operands(), operands_indexing)) { auto& operand_subgraph_data = id_to_subgraph_data[instr_to_id[operand_instr]]; + ++operand_subgraph_data.num_users; IndexingMap instr_indexing = instr_subgraph_data.indexings.empty() ? IndexingMap::GetUndefined() : *instr_subgraph_data.indexings.begin(); diff --git a/third_party/xla/xla/codegen/emitters/ir/tests/inlining.mlir b/third_party/xla/xla/codegen/emitters/ir/tests/inlining.mlir index 98c18c91b76a30..9263a7470e90f9 100644 --- a/third_party/xla/xla/codegen/emitters/ir/tests/inlining.mlir +++ b/third_party/xla/xla/codegen/emitters/ir/tests/inlining.mlir @@ -147,11 +147,11 @@ module { // ----- module { - func.func private @fib0(%start : f32) -> f32 { + func.func private @fib0(%start : f32) -> f32 attributes {no_compute = true} { %zero = arith.constant 0.0 : f32 return %zero : f32 } - func.func private @fib1(%start : f32) -> f32 { + func.func private @fib1(%start : f32) -> f32 attributes {no_compute = true} { return %start : f32 } func.func private @fib2(%start : f32) -> f32 { @@ -202,16 +202,17 @@ module { } // CHECK-LABEL: module { +// CHECK-NOT: fib0 +// CHECK: func.func private @fib2 +// CHECK-NOT: fib1 +// CHECK-NOT: fib3 +// CHECK-NOT: fib4 +// CHECK-NOT: fib5 +// CHECK-NOT: fib6 +// CHECK-NOT: fib7 + // CHECK: @caller -// CHECK: arith.constant 0.000000e+00 -// CHECK: xla.pure_call @fib5 -// CHECK: arith.addf -// CHECK: arith.addf -// CHECK: arith.addf -// CHECK: arith.addf -// CHECK: xla.pure_call @fib5 -// CHECK: arith.addf -// CHECK: arith.addf +// CHECK-COUNT-8: xla.pure_call @fib2 // ----- diff --git a/third_party/xla/xla/codegen/emitters/ir/xla_dialect.cc b/third_party/xla/xla/codegen/emitters/ir/xla_dialect.cc index 47e22e6e4427de..6085f85bb10823 100644 --- a/third_party/xla/xla/codegen/emitters/ir/xla_dialect.cc +++ b/third_party/xla/xla/codegen/emitters/ir/xla_dialect.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
 ==============================================================================*/

+#include <cstdint>
+
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/TypeSwitch.h"  // IWYU pragma: keep
@@ -34,6 +36,10 @@ limitations under the License.
 namespace xla {
 namespace {

+constexpr int64_t kMaxFuncSize = 4000;
+
+int64_t GetNumOps(mlir::Block& block) { return block.getOperations().size(); }
+
 struct XlaInlinerInterface : public mlir::DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
   // Returns true if the given operation 'callable', that implements the
@@ -47,11 +53,6 @@ struct XlaInlinerInterface : public mlir::DialectInlinerInterface {
                        bool wouldBeCloned) const final {
     if (call->hasAttr("noinline")) return false;
     if (callable->hasAttr(emitters::kHasNoCompute)) return true;
-    if (!wouldBeCloned) {
-      // If no duplicate would be created, 'call' is likely the only caller of
-      // 'callable'.
-      return true;
-    }
     // Otherwise, inline only if the called function is small. We could
     // theoretically also inline if there is no other caller in the function
     // that contains the callee that has a call path to the callable, but that
@@ -64,36 +65,38 @@ struct XlaInlinerInterface : public mlir::DialectInlinerInterface {
     if (!pure_call_op) {
       return false;
     }
-    auto region = func_op.getCallableRegion();
-    if (!region) {
+    auto callable_region = func_op.getCallableRegion();
+    if (!callable_region) {
       return false;
     }
     llvm::SmallDenseSet<llvm::StringRef> callee_calls;
-    for (auto call : region->getOps<PureCallOp>()) {
-      callee_calls.insert(call.getCallee());
+    for (auto callee_call : callable_region->getOps<PureCallOp>()) {
+      callee_calls.insert(callee_call.getCallee());
     }
     // If true, then the callee and the caller call the same third function.
     bool contains_call_to_same_function = false;
     // The number of calls to the callee in the caller.
     int num_calls_in_caller = 0;
-    for (auto neighbor_call : call->getParentRegion()->getOps<PureCallOp>()) {
-      contains_call_to_same_function |=
-          callee_calls.contains(neighbor_call.getCallee());
-      if (neighbor_call.getCallee() == pure_call_op.getCallee()) {
-        ++num_calls_in_caller;
+    if (!wouldBeCloned) {
+      num_calls_in_caller = 1;
+    } else {
+      for (auto neighbor_call : call->getParentRegion()->getOps<PureCallOp>()) {
+        contains_call_to_same_function |=
+            callee_calls.contains(neighbor_call.getCallee());
+        if (neighbor_call.getCallee() == pure_call_op.getCallee()) {
+          ++num_calls_in_caller;
+        }
       }
     }
     if (num_calls_in_caller > 1) return false;
-    if (contains_call_to_same_function) return true;
-
-    constexpr int kMaxOperationsToInline = 8;
-    int num_ops = 0;
-    region->front().walk([&](mlir::Operation* op) { ++num_ops; });
-
-    // Don't inline functions with more than `kMaxOperationsToInline` ops.
-    return num_ops <= kMaxOperationsToInline;
+    // Don't inline a function if the caller would become too big after
+    // inlining.
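+    // Rough size estimate, assuming single-block function bodies: the callee
+    // body is cloned once per call site that survives in the caller.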
+    int num_ops = num_calls_in_caller * GetNumOps(callable_region->front()) +
+                  GetNumOps(call->getParentRegion()->front());
+    if (num_ops > kMaxFuncSize) return false;
+    return !wouldBeCloned || contains_call_to_same_function;
   }

   // Returns true if the given operation 'op', that is registered to this

From 963353b90944a8874e752a6d01ebac48cac763b8 Mon Sep 17 00:00:00 2001
From: Dragan Mladjenovic
Date: Tue, 8 Apr 2025 03:22:26 -0700
Subject: [PATCH 0354/1324] PR #23967: [ROCm][MLIR] Support native fp8
 conversion instructions where possible

Imported from GitHub PR https://github.com/openxla/xla/pull/23967

Copybara import of the project:

--
fdf5a888b8878afd5f1e74d9b7a78fbcef3303fe by Dragan Mladjenovic :

[ROCm][MLIR] Support native fp8 conversion instructions where possible

--
3d9ed18d06a94a3024b19dc7db8727790fc7f1f1 by Dragan Mladjenovic :

Split convert-float pass back to convert-float-nvidia and convert-float-amd

--
c0b62fbce0b097a5217f5d1cabd8ff647c6ba04a by Dragan Mladjenovic :

Pick better insertion point

--
e455c119adbcd163f7ec4563b5470fc7587ec4ea by Dragan Mladjenovic :

Address LINT warnings

--
590e8a5b54987422c306f7fb0df1a00b1b4d16a9 by Dragan Mladjenovic :

Avoid using ifdefs

--
2defba5b795273b0326ac483230333ca78f3f775 by Dragan Mladjenovic :

Inline MaybeCreateConvertFloatPass

--
979d1f7440eb1dc52b83c8403993cdba168553e4 by Dragan Mladjenovic :

Tone down usage of auto + trim usings

--
8df7c6c6edfd13641a28967a2d3ee9083ac48d9a by Dragan Mladjenovic :

Add missing files; remove the gpu tag

Merging this change closes #23967

PiperOrigin-RevId: 745065146
---
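For a concrete flavor of the rewrite (an illustrative sketch distilled from
the tests added below; the SSA names and the undef/false operands are
schematic, not literal pass output): on gfx942 a scalar f32-to-fp8
truncation such as

  %r = arith.truncf %x : f32 to f8E4M3FNUZ

is lowered through the packed conversion intrinsic, roughly:

  // Pack %x (and an unused second lane) into the low byte of an i32.
  %p = llvm.call_intrinsic "llvm.amdgcn.cvt.pk.fp8.f32"(%x, %u0, %u1, %false)
      : (f32, f32, i32, i1) -> i32
  // Carve out the low byte and cast it back to the fp8 type.
  %lo = llvm.trunc %p : i32 to i8
  %r = builtin.unrealized_conversion_cast %lo : i8 to f8E4M3FNUZ

where %u0 and %u1 stand for llvm.mlir.undef placeholders and %false selects
the low half of the packed result.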
 .../xla/backends/gpu/codegen/emitters/BUILD   |   1 +
 .../gpu/codegen/emitters/emitter_base.cc      |  21 +-
 .../gpu/codegen/emitters/transforms/BUILD     |   1 +
 .../emitters/transforms/convert_float_amd.cc  | 578 ++++++++++++++++++
 .../transforms/convert_float_nvidia.cc        |  23 -
 .../gpu/codegen/emitters/transforms/passes.h  |   6 +-
 .../gpu/codegen/emitters/transforms/passes.td |  17 +-
 .../transforms/tests/convert_float_amd.mlir   | 399 ++++++++++++
 .../xla/service/gpu/llvm_gpu_backend/BUILD    |  15 +
 .../gpu/llvm_gpu_backend/nvptx_backend.cc     |  35 --
 .../gpu/llvm_gpu_backend/nvptx_backend.h      |   6 +-
 .../gpu/llvm_gpu_backend/ptx_version_util.cc  |  54 ++
 .../gpu/llvm_gpu_backend/ptx_version_util.h   |  30 +
 .../xla/stream_executor/device_description.h  |   4 +-
 14 files changed, 1119 insertions(+), 71 deletions(-)
 create mode 100644 third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc
 create mode 100644 third_party/xla/xla/backends/gpu/codegen/emitters/transforms/tests/convert_float_amd.mlir
 create mode 100644 third_party/xla/xla/service/gpu/llvm_gpu_backend/ptx_version_util.cc
 create mode 100644 third_party/xla/xla/service/gpu/llvm_gpu_backend/ptx_version_util.h

diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD
index cc9eeb8f8ff295..b0e0d66fcdaa5c 100644
--- a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD
+++ b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD
@@ -71,6 +71,7 @@ cc_library(
         "//xla/service/gpu:kernel_reuse_cache",
         "//xla/service/gpu:launch_dimensions",
         "//xla/service/gpu:target_util",
+        "//xla/service/gpu/llvm_gpu_backend:ptx_version_util",
         "//xla/service/llvm_ir:llvm_util",
         "//xla/stream_executor:device_description",
         "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler",
diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc
index 694aa7dca258fd..4d393924bfb774 100644
--- a/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc
@@ -105,6 +105,7 @@ limitations under the License.
 #include "xla/service/gpu/kernel_arguments.h"
 #include "xla/service/gpu/kernel_reuse_cache.h"
 #include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/llvm_gpu_backend/ptx_version_util.h"
 #include "xla/service/gpu/target_util.h"
 #include "xla/service/llvm_ir/llvm_util.h"
 #include "xla/shape.h"
@@ -656,9 +657,23 @@ void AddLoweringPasses(mlir::OpPassManager& pm,
   pm.addPass(mlir::createCSEPass());

   // This pass has to run before `ExpandFloatOpsPass`.
-  auto maybe_convert_fp8 = MaybeCreateConvertFloatNvidiaPass(device);
-  if (maybe_convert_fp8.has_value()) {
-    pm.addPass(std::move(*maybe_convert_fp8));
+  if (auto* cc = std::get_if<se::CudaComputeCapability>(
+          &device.gpu_compute_capability())) {
+    se::SemanticVersion ptx_version =
+        nvptx::DetermineHighestSupportedPtxVersionFromCudaVersion(
+            device.runtime_version());
+
+    // FP8 conversion intrinsics are available on sm89 since PTX 8.1.
+    // Older PTX versions only support FP8 conversion on sm90.
+    if ((ptx_version >= se::SemanticVersion(8, 1, 0) && cc->IsAtLeast(8, 9)) ||
+        (ptx_version >= se::SemanticVersion(7, 8, 0) && cc->IsAtLeast(9, 0))) {
+      pm.addPass(CreateConvertFloatNvidiaPass());
+    }
+  } else if (auto* cc = std::get_if<se::RocmComputeCapability>(
+                 &device.gpu_compute_capability())) {
+    if (cc->has_fp8_support()) {
+      pm.addPass(CreateConvertFloatAMDPass(*cc));
+    }
   }

   pm.addPass(emitters::CreateExpandFloatOpsPass());
diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD
index c8560ef0227f6f..ef97d62edccc0c 100644
--- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD
+++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD
@@ -38,6 +38,7 @@ gentbl_cc_library(
 cc_library(
     name = "passes",
     srcs = [
+        "convert_float_amd.cc",
         "convert_float_nvidia.cc",
         "convert_index_type.cc",
         "fuse_loops.cc",
diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc
new file mode 100644
index 00000000000000..5fd0d4c5a9429b
--- /dev/null
+++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc
@@ -0,0 +1,578 @@
+==============================================================================*/
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "llvm/ADT/APFloat.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/backends/gpu/codegen/emitters/transforms/passes.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_CONVERTFLOATAMDPASS
+#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc"
+
+namespace {
+
+namespace LLVM = ::mlir::LLVM;
+namespace arith = ::mlir::arith;
+namespace vector = ::mlir::vector;
+
+template <typename OpTy>
+struct Fp8OpRewritePattern : public mlir::OpRewritePattern<OpTy> {
+  using FixedVectorValue = mlir::TypedValue<mlir::FixedVectorType>;
+  using FloatValue = mlir::TypedValue<mlir::FloatType>;
+  Fp8OpRewritePattern(mlir::MLIRContext* context, bool nativeNanooFp8)
+      : mlir::OpRewritePattern<OpTy>(context),
+        nativeNanooFp8_(nativeNanooFp8) {}
+  bool isFp8(const mlir::Type& type) const {
+    return nativeNanooFp8_ ? llvm::isa<mlir::Float8E4M3FNUZType>(type)
+                           : llvm::isa<mlir::Float8E4M3FNType>(type);
+  }
+  bool isBf8(const mlir::Type& type) const {
+    return nativeNanooFp8_ ? llvm::isa<mlir::Float8E5M2FNUZType>(type)
+                           : llvm::isa<mlir::Float8E5M2Type>(type);
+  }
+
+ private:
+  bool nativeNanooFp8_;
+};
+
+struct RewriteFp8TruncFPattern : public Fp8OpRewritePattern<arith::TruncFOp> {
+  using Fp8OpRewritePattern<arith::TruncFOp>::Fp8OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      arith::TruncFOp op, mlir::PatternRewriter& rewriter) const override {
+    auto src = mlir::cast<FloatValue>(op.getOperand());
+    auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
+    if (!isFp8(dst_ty) && !isBf8(dst_ty)) {
+      return rewriter.notifyMatchFailure(op, "unsupported float conversion");
+    }
+
+    auto match = MatchBuildVector(op, src, dst_ty);
+
+    if (match) {
+      auto [inputs, output] = *match;
+      rewriter.setInsertionPointAfter(output.getDefiningOp());
+      mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+      rewriter.replaceOp(
+          output.getDefiningOp(),
+          EmitVectorizedTruncToF8Intrinsic(inputs, output.getType(), b));
+      return mlir::success();
+    }
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    rewriter.replaceOp(op, EmitTruncToF8Intrinsic(src, dst_ty, b));
+    return mlir::success();
+  }
+
+  std::optional<std::tuple<llvm::SmallVector<mlir::Value>, FixedVectorValue>>
+  MatchBuildVector(arith::TruncFOp op, FloatValue value,
+                   mlir::FloatType to_ty) const {
+    auto matchPos = [](vector::InsertOp insert, size_t* pos) -> bool {
+      llvm::APInt ap_pos;
+      auto position = insert.getMixedPosition();
+      if (position.size() != 1) {
+        return false;
+      }
+      if (auto attr = mlir::dyn_cast<mlir::Attribute>(position[0])) {
+        if (!mlir::matchPattern(attr, mlir::m_ConstantInt(&ap_pos))) {
+          return false;
+        }
+      } else {
+        if (!mlir::matchPattern(mlir::cast<mlir::Value>(position[0]),
+                                mlir::m_ConstantInt(&ap_pos))) {
+          return false;
+        }
+      }
+
+      *pos = ap_pos.getZExtValue();
+      return true;
+    };
+
+    if (!op->hasOneUse()) {
+      return std::nullopt;
+    }
+
+    size_t pos;
+    auto insert = mlir::dyn_cast<vector::InsertOp>(op->use_begin()->getOwner());
+    if (!insert || insert.getSource() != op->getResult(0) ||
+        !matchPos(insert, &pos) || !insert.getDest().hasOneUse()) {
+      return std::nullopt;
+    }
+
+    mlir::Value vector = insert.getDest();
+
+    size_t element_count =
+        mlir::cast<FixedVectorValue>(vector).getType().getNumElements();
+
+    if (!llvm::isPowerOf2_64(element_count) || element_count == 1) {
+      return std::nullopt;
+    }
+
+    llvm::SmallVector<mlir::Value> inputs(element_count);
+
+    auto addInput = [&](mlir::Value input, size_t index) -> bool {
+      if (index >= element_count) {
+        return false;
+      }
+      if (inputs[index]) {
+        return false;
+      }
+      inputs[index] = input;
+      return true;
+    };
+
+    addInput(value, pos);
+
+    mlir::Value input;
+    mlir::Operation* to_match = vector.getDefiningOp();
+    while (mlir::matchPattern(to_match, mlir::m_Op<vector::InsertOp>(
+                                            mlir::m_Op<arith::TruncFOp>(
+                                                mlir::matchers::m_Any(&input)),
+                                            mlir::matchers::m_Any(&vector))) &&
+           matchPos(mlir::cast<vector::InsertOp>(to_match), &pos) &&
+           vector.hasOneUse()) {
+      if (!addInput(input, pos)) {
+        return std::nullopt;
+      }
+      to_match = vector.getDefiningOp();
+    }
+
+    while (
+        insert->hasOneUse() &&
+        mlir::matchPattern(
+            insert->use_begin()->getOwner(),
+            mlir::m_Op<vector::InsertOp>(
+                mlir::m_Op<arith::TruncFOp>(mlir::matchers::m_Any(&input)),
+                mlir::matchers::m_Val(insert->getResult(0)))) &&
+        matchPos(mlir::cast<vector::InsertOp>(insert->use_begin()->getOwner()),
+                 &pos) &&
+        input.getType() == value.getType()) {
+      if (!addInput(input, pos)) {
+        return std::nullopt;
+      }
+      insert = mlir::cast<vector::InsertOp>(insert->use_begin()->getOwner());
+    }
+
+    if (llvm::any_of(inputs, [](mlir::Value input) { return !input; })) {
+      return std::nullopt;
+    }
+    return std::make_tuple(std::move(inputs),
+                           mlir::cast<FixedVectorValue>(insert->getResult(0)));
+  }
+
+  mlir::Value EmitVectorizedTruncToF8Intrinsic(
+      llvm::SmallVector<mlir::Value>& inputs, mlir::FixedVectorType to_ty,
+      mlir::ImplicitLocOpBuilder& b) const {
+    assert(isFp8(to_ty.getElementType()) || isBf8(to_ty.getElementType()));
+
+    mlir::FloatType f32_ty = b.getF32Type();
+    mlir::IntegerType i32_ty = b.getI32Type();
+    mlir::IntegerType i8_ty = b.getI8Type();
+    mlir::IntegerType i1_ty = b.getI1Type();
+
+    llvm::transform(inputs, inputs.begin(), [&](mlir::Value v) -> mlir::Value {
+      if (v.getType().getIntOrFloatBitWidth() < f32_ty.getWidth()) {
+        return b.create<arith::ExtFOp>(f32_ty, v);
+      } else if (v.getType() != f32_ty) {
+        return b.create<arith::TruncFOp>(f32_ty, v);
+      } else {
+        return v;
+      }
+    });
+
+    mlir::StringAttr cvtIntr = b.getStringAttr(
+        isFp8(to_ty.getElementType()) ? "llvm.amdgcn.cvt.pk.fp8.f32"
+                                      : "llvm.amdgcn.cvt.pk.bf8.f32");
"llvm.amdgcn.cvt.pk.fp8.f32" + : "llvm.amdgcn.cvt.pk.bf8.f32"); + + size_t num_elements = to_ty.getNumElements(); + assert(num_elements == inputs.size() && + (num_elements == 2 || num_elements % 4 == 0)); + + size_t num_chunks = (num_elements + 2) / 4; + + mlir::Type chunks_ty = LLVM::getFixedVectorType(i32_ty, num_chunks); + mlir::Value chunks = b.create(chunks_ty); + bool pos = false; + for (size_t i = 0; i < inputs.size() / 2; i++) { + mlir::Value chunk_pos = b.create(i32_ty, 2 * i / 4); + mlir::Value chunk = b.create(chunks, chunk_pos); + LLVM::CallIntrinsicOp cvtOp = b.create( + i32_ty, cvtIntr, + mlir::ValueRange{inputs[2 * i], inputs[2 * i + 1], chunk, + b.create(i1_ty, pos)}); + chunks = b.create(chunks, cvtOp.getResult(0), + chunk_pos); + pos ^= true; + } + + if (num_elements == 2) { + return b + .create( + to_ty, + mlir::ValueRange{b.create( + LLVM::getFixedVectorType(i8_ty, num_elements), + b.create( + b.create( + LLVM::getFixedVectorType(b.getI16Type(), 2), chunks), + b.create(i32_ty, 0)))}) + .getResult(0); + } + + return b + .create( + to_ty, mlir::ValueRange{b.create( + LLVM::getFixedVectorType(i8_ty, num_elements), chunks)}) + .getResult(0); + } + + mlir::Value EmitTruncToF8Intrinsic(mlir::Value value, mlir::FloatType to_ty, + mlir::ImplicitLocOpBuilder& b) const { + assert(isFp8(to_ty) || isBf8(to_ty)); + + mlir::FloatType f32_ty = b.getF32Type(); + mlir::IntegerType i32_ty = b.getI32Type(); + if (value.getType().getIntOrFloatBitWidth() < f32_ty.getWidth()) { + value = b.create(f32_ty, value); + } else if (value.getType() != f32_ty) { + value = b.create(f32_ty, value); + } + + mlir::StringAttr cvtIntr = + b.getStringAttr(isFp8(to_ty) ? "llvm.amdgcn.cvt.pk.fp8.f32" + : "llvm.amdgcn.cvt.pk.bf8.f32"); + + LLVM::CallIntrinsicOp cvtOp = b.create( + i32_ty, cvtIntr, + mlir::ValueRange{value, b.create(f32_ty), + b.create(i32_ty), + b.create(b.getI1Type(), 0)}); + mlir::Value res = + b.create(b.getI8Type(), cvtOp.getResults()); + return b + .create(to_ty, mlir::ValueRange{res}) + .getResult(0); + } +}; + +struct RewriteFp8ExtFPattern : public Fp8OpRewritePattern { + using Fp8OpRewritePattern::Fp8OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + arith::ExtFOp op, mlir::PatternRewriter& rewriter) const override { + auto src = mlir::cast(op.getOperand()); + auto dst_ty = mlir::cast(op.getType()); + if (!isFp8(src.getType()) && !isBf8(src.getType())) { + return rewriter.notifyMatchFailure(op, "unsupported float conversion"); + } + + auto match = MatchDecomposeVector(op, src, dst_ty); + + if (match) { + auto [input, outputs] = *match; + if (mlir::Operation* input_op = input.getDefiningOp()) { + rewriter.setInsertionPointAfter(input_op); + } else { + rewriter.setInsertionPointToStart( + mlir::cast(input).getOwner()); + } + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto new_outputs = EmitVectorizedExtFromF8Intrinsic( + input, mlir::cast(outputs[0].getType()), b); + for (auto [old_value, new_value] : + llvm::zip_equal(outputs, new_outputs)) { + rewriter.replaceOp(old_value.getDefiningOp(), new_value); + } + + return mlir::success(); + } + + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + rewriter.replaceOp(op, EmitExtFromF8Intrinsic(src, dst_ty, b)); + return mlir::success(); + } + std::optional>> + MatchDecomposeVector(arith::ExtFOp op, FloatValue value, + mlir::FloatType to_ty) const { + auto matchPos = [](vector::ExtractOp extract, size_t* pos) -> bool { + llvm::APInt ap_pos; + auto position = extract.getMixedPosition(); + if (position.size() != 1) { + 
+        return false;
+      }
+      if (auto attr = mlir::dyn_cast<mlir::Attribute>(position[0])) {
+        if (!mlir::matchPattern(attr, mlir::m_ConstantInt(&ap_pos))) {
+          return false;
+        }
+      } else {
+        if (!mlir::matchPattern(mlir::cast<mlir::Value>(position[0]),
+                                mlir::m_ConstantInt(&ap_pos))) {
+          return false;
+        }
+      }
+      *pos = ap_pos.getZExtValue();
+      return true;
+    };
+
+    size_t pos;
+    auto extract = value.getDefiningOp<vector::ExtractOp>();
+    if (!extract || !extract->hasOneUse() || !matchPos(extract, &pos)) {
+      return std::nullopt;
+    }
+
+    mlir::Value vector = extract.getVector();
+
+    size_t element_count =
+        mlir::cast<FixedVectorValue>(vector).getType().getNumElements();
+
+    if (!llvm::isPowerOf2_64(element_count) || element_count == 1) {
+      return std::nullopt;
+    }
+
+    llvm::SmallVector<mlir::Value> outputs(element_count);
+
+    auto addOutput = [&](mlir::Value output, size_t index) -> bool {
+      if (index >= element_count) {
+        return false;
+      }
+      if (outputs[index]) {
+        return false;
+      }
+      outputs[index] = output;
+      return true;
+    };
+
+    for (const mlir::OpOperand& use : vector.getUses()) {
+      extract = mlir::dyn_cast<vector::ExtractOp>(use.getOwner());
+      if (!extract || !extract->hasOneUse() || extract.getVector() != vector ||
+          !matchPos(extract, &pos)) {
+        return std::nullopt;
+      }
+      auto extf =
+          mlir::dyn_cast<arith::ExtFOp>(extract->use_begin()->getOwner());
+      if (!extf || extf.getType() != to_ty || extf.getOperand() != extract) {
+        return std::nullopt;
+      }
+      if (!addOutput(extf, pos)) {
+        return std::nullopt;
+      }
+    }
+
+    if (llvm::any_of(outputs, [](mlir::Value output) { return !output; })) {
+      return std::nullopt;
+    }
+    return std::make_tuple(mlir::cast<FixedVectorValue>(vector),
+                           std::move(outputs));
+  }
+
+  mlir::Value ConvertFromFloat(mlir::Value v, mlir::FloatType to_ty,
+                               mlir::ImplicitLocOpBuilder& b) const {
+    mlir::FloatType f32_ty = b.getF32Type();
+    mlir::IntegerType i32_ty = b.getI32Type();
+    if (to_ty == f32_ty) {
+      return v;
+    }
+
+    if (to_ty.getWidth() > f32_ty.getWidth()) {
+      return b.create<arith::ExtFOp>(to_ty, v);
+    }
+
+    if (to_ty.isBF16()) {
+      return b.create<LLVM::BitcastOp>(
+          to_ty,
+          b.create<LLVM::TruncOp>(
+              b.getI16Type(),
+              b.create<LLVM::LShrOp>(b.create<LLVM::BitcastOp>(i32_ty, v),
+                                     b.create<LLVM::ConstantOp>(i32_ty, 16))));
+    }
+
+    assert(to_ty.getWidth() < f32_ty.getWidth());
+    return b.create<arith::TruncFOp>(to_ty, v);
+  }
+
+  llvm::SmallVector<mlir::Value> EmitVectorizedExtFromF8Intrinsic(
+      FixedVectorValue value, mlir::FloatType to_ty,
+      mlir::ImplicitLocOpBuilder& b) const {
+    mlir::FloatType f32_ty = b.getF32Type();
+    mlir::IntegerType i32_ty = b.getI32Type();
+    mlir::IntegerType i16_ty = b.getI16Type();
+    mlir::IntegerType i8_ty = b.getI8Type();
+    mlir::IntegerType i1_ty = b.getI1Type();
+    mlir::Value zero_cst = b.create<LLVM::ConstantOp>(i32_ty, 0);
+    mlir::Value one_cst = b.create<LLVM::ConstantOp>(i32_ty, 1);
+
+    size_t num_elements = value.getType().getNumElements();
+    assert(num_elements == 2 || num_elements % 4 == 0);
+
+    size_t num_chunks = (num_elements + 2) / 4;
+    mlir::Type chunks_ty = LLVM::getFixedVectorType(i32_ty, num_chunks);
+    mlir::Value chunks;
+
+    if (num_elements == 2) {
+      chunks = b.create<LLVM::BitcastOp>(
+          chunks_ty,
+          b.create<LLVM::InsertElementOp>(
+              b.create<LLVM::UndefOp>(LLVM::getFixedVectorType(i16_ty, 2)),
+              b.create<LLVM::BitcastOp>(
+                  i16_ty, b.create<mlir::UnrealizedConversionCastOp>(
+                               LLVM::getFixedVectorType(i8_ty, num_elements),
+                               mlir::ValueRange{value})
+                              .getResult(0)),
+              zero_cst));
+    } else {
+      chunks = b.create<LLVM::BitcastOp>(
+          chunks_ty, b.create<mlir::UnrealizedConversionCastOp>(
+                         LLVM::getFixedVectorType(i8_ty, num_elements),
+                         mlir::ValueRange{value})
+                         .getResult(0));
+    }
+
+    llvm::SmallVector<mlir::Value> results;
+    mlir::StringAttr cvtIntr = b.getStringAttr(
+        isFp8(value.getType().getElementType()) ? "llvm.amdgcn.cvt.pk.f32.fp8"
+                                                : "llvm.amdgcn.cvt.pk.f32.bf8");
"llvm.amdgcn.cvt.pk.f32.fp8" + : "llvm.amdgcn.cvt.pk.f32.bf8"); + mlir::Type result_ty = LLVM::getFixedVectorType(f32_ty, 2); + LLVM::FastmathFlagsAttr flags = + LLVM::FastmathFlagsAttr::get(b.getContext(), LLVM::FastmathFlags::ninf); + for (size_t i = 0; i < num_elements / 2; i++) { + mlir::Value chunk_pos = b.create(i32_ty, (2 * i) / 4); + mlir::Value chunk = b.create(chunks, chunk_pos); + LLVM::CallIntrinsicOp cvtOp = b.create( + result_ty, cvtIntr, + mlir::ValueRange{ + chunk, b.create(i1_ty, ((2 * i) % 4) != 0)}, + flags); + + results.push_back( + b.create(cvtOp.getResult(0), zero_cst)); + results.push_back( + b.create(cvtOp.getResult(0), one_cst)); + } + + if (to_ty.isF16()) { + result_ty = LLVM::getFixedVectorType(b.getF16Type(), 2); + cvtIntr = b.getStringAttr("llvm.amdgcn.cvt.pkrtz"); + for (size_t i = 0; i < num_elements / 2; i++) { + LLVM::CallIntrinsicOp cvtOp = b.create( + result_ty, cvtIntr, + mlir::ValueRange{results[2 * i], results[2 * i + 1]}, flags); + + results[2 * i] = + b.create(cvtOp.getResult(0), zero_cst); + results[2 * i + 1] = + b.create(cvtOp.getResult(0), one_cst); + } + } else if (to_ty != f32_ty) { + llvm::transform(results, results.begin(), + [&](mlir::Value v) -> mlir::Value { + return ConvertFromFloat(v, to_ty, b); + }); + } + + return results; + } + + mlir::Value EmitExtFromF8Intrinsic(mlir::Value value, mlir::FloatType to_ty, + mlir::ImplicitLocOpBuilder& b) const { + assert(isFp8(value.getType()) || isBf8(value.getType())); + + mlir::FloatType f32_ty = b.getF32Type(); + mlir::IntegerType i32_ty = b.getI32Type(); + mlir::IntegerType i8_ty = b.getI8Type(); + mlir::Value zero_cst = b.create(i32_ty, 0); + // Emulate anyext + mlir::Value input = b.create( + i32_ty, b.create( + b.create(LLVM::getFixedVectorType(i8_ty, 4)), + b.create( + i8_ty, mlir::ValueRange{value}) + .getResult(0), + zero_cst)); + mlir::StringAttr cvtIntr = + b.getStringAttr(isFp8(value.getType()) ? 
"llvm.amdgcn.cvt.f32.fp8" + : "llvm.amdgcn.cvt.f32.bf8"); + LLVM::FastmathFlagsAttr flags = + LLVM::FastmathFlagsAttr::get(b.getContext(), LLVM::FastmathFlags::ninf); + LLVM::CallIntrinsicOp cvtOp = b.create( + mlir::TypeRange{f32_ty}, cvtIntr, mlir::ValueRange{input, zero_cst}, + flags); + + return ConvertFromFloat(cvtOp.getResult(0), to_ty, b); + } +}; + +class ConvertFloatAMDPass + : public impl::ConvertFloatAMDPassBase { + public: + explicit ConvertFloatAMDPass(const ConvertFloatAMDPassOptions& options) + : ConvertFloatAMDPassBase(options) {} + + explicit ConvertFloatAMDPass(const se::RocmComputeCapability& cc) : cc_(cc) {} + + void runOnOperation() override { + if (!gpu_device_info_.empty()) { + se::GpuDeviceInfoProto device_info; + CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_, + &device_info)); + cc_ = se::DeviceDescription(device_info).rocm_compute_capability(); + } + mlir::RewritePatternSet patterns(&getContext()); + bool nativeNanooFp8 = cc_.has_nanoo_fp8_support(); + patterns.add( + &getContext(), nativeNanooFp8); + if (mlir::failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { + signalPassFailure(); + } + } + + private: + se::RocmComputeCapability cc_; +}; + +} // namespace + +std::unique_ptr CreateConvertFloatAMDPass( + const std::string& gpu_device_info) { + ConvertFloatAMDPassOptions options; + options.gpu_device_info_ = gpu_device_info; + return std::make_unique(options); +} + +std::unique_ptr CreateConvertFloatAMDPass( + const se::RocmComputeCapability& cc) { + return std::make_unique(cc); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_nvidia.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_nvidia.cc index 3816a077b075b8..ff280e5a38059f 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_nvidia.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_nvidia.cc @@ -31,12 +31,7 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/backends/gpu/codegen/emitters/transforms/passes.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/semantic_version.h" -#ifdef GOOGLE_CUDA -#include "xla/service/gpu/llvm_gpu_backend/nvptx_backend.h" -#endif namespace xla { namespace gpu { @@ -258,23 +253,5 @@ std::unique_ptr CreateConvertFloatNvidiaPass() { return std::make_unique(); } -std::optional> MaybeCreateConvertFloatNvidiaPass( - const se::DeviceDescription& device_description) { -#ifdef GOOGLE_CUDA - se::SemanticVersion ptx_version = - nvptx::DetermineHighestSupportedPtxVersionFromCudaVersion( - device_description.runtime_version()); - se::CudaComputeCapability cc = device_description.cuda_compute_capability(); - - // FP8 conversion intrinsics are available on sm89 since ptx 8.1 - // Older ptx versions only support FP8 conversion for sm90 - if ((ptx_version >= se::SemanticVersion(8, 1, 0) && cc.IsAtLeast(8, 9)) || - (ptx_version >= se::SemanticVersion(7, 8, 0) && cc.IsAtLeast(9, 0))) { - return CreateConvertFloatNvidiaPass(); - } -#endif - return std::nullopt; -} - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/passes.h b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/passes.h index 8046bb4ba06d4f..0afe833dc305e1 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/passes.h +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/passes.h @@ -34,8 +34,10 @@ namespace gpu { #include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc" std::unique_ptr CreateConvertFloatNvidiaPass(); -std::optional> MaybeCreateConvertFloatNvidiaPass( - const se::DeviceDescription& device_description); +std::unique_ptr CreateConvertFloatAMDPass( + const std::string& gpu_device_info = ""); +std::unique_ptr CreateConvertFloatAMDPass( + const se::RocmComputeCapability& cc); std::unique_ptr CreateConvertIndexTypePass(); std::unique_ptr CreateOptimizeLoopsPass(); std::unique_ptr CreateFuseLoopsPass(); diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/passes.td b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/passes.td index 3ac2f9ab763a30..c1e964aca7813e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/passes.td +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/passes.td @@ -37,16 +37,29 @@ def ConvertIndexTypePass : Pass<"xla-gpu-convert-index-type", "mlir::ModuleOp"> } def ConvertFloatNvidiaPass : Pass<"xla-gpu-convert-float-nvidia", "mlir::ModuleOp"> { - let summary = "Convert floating point types using NVidia intrinsics."; + let summary = "Convert floating point types using NVPTX intrinsics."; let dependentDialects = [ "mlir::LLVM::LLVMDialect", "mlir::arith::ArithDialect", ]; - let constructor = "CreateConvertFloatNvidiaPass()"; } +def ConvertFloatAMDPass : Pass<"xla-gpu-convert-float-amd", "mlir::ModuleOp"> { + let summary = "Convert floating point types using AMDGCN intrinsics."; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect", + "mlir::arith::ArithDialect", + ]; + let options = [ + Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"", + "Serialized stream_executor::GPUDeviceInfo proto.">, + ]; + let constructor = "CreateConvertFloatAMDPass()"; +} + def FuseLoopsPass : Pass<"xla-gpu-fuse-loops", "mlir::func::FuncOp"> { let summary = "Fuse xla_gpu.loop."; let description = [{ diff --git 
a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/tests/convert_float_amd.mlir b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/tests/convert_float_amd.mlir new file mode 100644 index 00000000000000..5c453542be8d06 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/tests/convert_float_amd.mlir @@ -0,0 +1,399 @@ +// RUN: emitters_opt %s -split-input-file -xla-gpu-convert-float-amd="gpu_device_info='rocm_compute_capability {gcn_arch_name: \"gfx942:sramecc+:xnack\"}'" -canonicalize | FileCheck %s + +module { + func.func @intr_f16_to_f8(%arg0: f16) -> (f8E4M3FNUZ, f8E5M2FNUZ) { + %a = arith.truncf %arg0 : f16 to f8E4M3FNUZ + %b = arith.truncf %arg0 : f16 to f8E5M2FNUZ + return %a, %b : f8E4M3FNUZ, f8E5M2FNUZ + } +} + +// CHECK-LABEL: @intr_f16_to_f8 +// CHECK: arith.extf %{{.+}} : f16 to f32 +// CHECK: llvm.amdgcn.cvt.pk.fp8.f32 +// CHECK: llvm.amdgcn.cvt.pk.bf8.f32 + +// ----- + +module { + func.func @intr_bf16_to_f8(%arg0: bf16) -> (f8E4M3FNUZ, f8E5M2FNUZ) { + %a = arith.truncf %arg0 : bf16 to f8E4M3FNUZ + %b = arith.truncf %arg0 : bf16 to f8E5M2FNUZ + return %a, %b : f8E4M3FNUZ, f8E5M2FNUZ + } +} + +// CHECK-LABEL: @intr_bf16_to_f8 +// CHECK: arith.extf %{{.+}} : bf16 to f32 +// CHECK: llvm.amdgcn.cvt.pk.fp8.f32 +// CHECK: llvm.amdgcn.cvt.pk.bf8.f32 + +// ----- + +module { + func.func @intr_f32_to_f8(%arg0: f32) -> (f8E4M3FNUZ, f8E5M2FNUZ) { + %a = arith.truncf %arg0 : f32 to f8E4M3FNUZ + %b = arith.truncf %arg0 : f32 to f8E5M2FNUZ + return %a, %b : f8E4M3FNUZ, f8E5M2FNUZ + } +} + +// CHECK-LABEL: @intr_f32_to_f8 +// CHECK: llvm.amdgcn.cvt.pk.fp8.f32 +// CHECK: llvm.amdgcn.cvt.pk.bf8.f32 + +// ----- + +module { + func.func @intr_f64_to_f8(%arg0: f64) -> (f8E4M3FNUZ, f8E5M2FNUZ) { + %a = arith.truncf %arg0 : f64 to f8E4M3FNUZ + %b = arith.truncf %arg0 : f64 to f8E5M2FNUZ + return %a, %b : f8E4M3FNUZ, f8E5M2FNUZ + } +} + +// CHECK-LABEL: @intr_f64_to_f8 +// CHECK: arith.truncf %{{.+}} : f64 to f32 +// CHECK: llvm.amdgcn.cvt.pk.fp8.f32 +// CHECK: arith.truncf %{{.+}} : f64 to f32 +// CHECK: llvm.amdgcn.cvt.pk.bf8.f32 + +// ----- + +module { + func.func @intr_f8_to_f16(%arg0: f8E4M3FNUZ, %arg1: f8E5M2FNUZ) -> (f16, f16) { + %a = arith.extf %arg0 : f8E4M3FNUZ to f16 + %b = arith.extf %arg1 : f8E5M2FNUZ to f16 + return %a, %b : f16, f16 + } +} + +// CHECK-LABEL: @intr_f8_to_f16 +// CHECK: llvm.amdgcn.cvt.f32.fp8 +// CHECK: llvm.amdgcn.cvt.f32.bf8 +// CHECK: arith.truncf %{{.+}} : f32 to f16 + +// ----- + +module { + func.func @intr_f8_to_bf16(%arg0: f8E4M3FNUZ, %arg1: f8E5M2FNUZ) -> (bf16, bf16) { + %a = arith.extf %arg0 : f8E4M3FNUZ to bf16 + %b = arith.extf %arg1 : f8E5M2FNUZ to bf16 + return %a, %b : bf16, bf16 + } +} + +// CHECK-LABEL: @intr_f8_to_bf16 +// CHECK: llvm.amdgcn.cvt.f32.fp8 +// CHECK: llvm.amdgcn.cvt.f32.bf8 +// CHECK: llvm.bitcast %{{.+}} : f32 to i32 +// CHECK: llvm.bitcast %{{.+}} : i16 to bf16 +// CHECK-NOT: arith.truncf %{{.+}} : f32 to bf16 + +// ----- + +module { + func.func @intr_f8_to_f32(%arg0: f8E4M3FNUZ, %arg1: f8E5M2FNUZ) -> (f32, f32) { + %a = arith.extf %arg0 : f8E4M3FNUZ to f32 + %b = arith.extf %arg1 : f8E5M2FNUZ to f32 + return %a, %b : f32, f32 + } +} + +// CHECK-LABEL: @intr_f8_to_f32 +// CHECK: llvm.amdgcn.cvt.f32.fp8 +// CHECK: llvm.amdgcn.cvt.f32.bf8 + + +// ----- + +module { + func.func @intr_f8_to_f64(%arg0: f8E4M3FNUZ, %arg1: f8E5M2FNUZ) -> (f64, f64) { + %a = arith.extf %arg0 : f8E4M3FNUZ to f64 + %b = arith.extf %arg1 : f8E5M2FNUZ to f64 + return %a, %b : f64, f64 + } +} + +// CHECK-LABEL: 
@intr_f8_to_f64 +// CHECK: llvm.amdgcn.cvt.f32.fp8 +// CHECK: arith.extf %{{.+}} : f32 to f64 +// CHECK: llvm.amdgcn.cvt.f32.bf8 +// CHECK: arith.extf %{{.+}} : f32 to f64 + +// ----- + +module { + func.func @intr_f16_to_4f8(%arg0: f16, %arg1: f16, %arg2: f16, %arg3: f16) -> (vector<4xf8E4M3FNUZ>, vector<4xf8E5M2FNUZ>) { + %a0 = arith.truncf %arg0 : f16 to f8E4M3FNUZ + %a1 = arith.truncf %arg1 : f16 to f8E4M3FNUZ + %a2 = arith.truncf %arg2 : f16 to f8E4M3FNUZ + %a3 = arith.truncf %arg3 : f16 to f8E4M3FNUZ + %b0 = arith.truncf %arg0 : f16 to f8E5M2FNUZ + %b1 = arith.truncf %arg1 : f16 to f8E5M2FNUZ + %b2 = arith.truncf %arg2 : f16 to f8E5M2FNUZ + %b3 = arith.truncf %arg3 : f16 to f8E5M2FNUZ + %a_init = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FNUZ> + %b_init = arith.constant dense<0.000000e+00> : vector<4xf8E5M2FNUZ> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0 = vector.insert %a0, %a_init [%c0] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %1 = vector.insert %a1, %0 [%c1] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %2 = vector.insert %a2, %1 [%c2] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %a = vector.insert %a3, %2 [%c3] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %3 = vector.insert %b0, %b_init [%c0] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %4 = vector.insert %b1, %3 [%c1] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %5 = vector.insert %b2, %4 [%c2] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %b = vector.insert %b3, %5 [%c3] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + return %a, %b : vector<4xf8E4M3FNUZ>, vector<4xf8E5M2FNUZ> + } +} + +// CHECK-LABEL: @intr_f16_to_4f8 +// CHECK-COUNT-4: arith.extf %{{.+}} : f16 to f32 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.fp8.f32 +// CHECK-COUNT-4: arith.extf %{{.+}} : f16 to f32 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.bf8.f32 + +// ----- + +module { + func.func @intr_bf16_to_4f8(%arg0: bf16, %arg1: bf16, %arg2: bf16, %arg3: bf16) -> (vector<4xf8E4M3FNUZ>, vector<4xf8E5M2FNUZ>) { + %a0 = arith.truncf %arg0 : bf16 to f8E4M3FNUZ + %a1 = arith.truncf %arg1 : bf16 to f8E4M3FNUZ + %a2 = arith.truncf %arg2 : bf16 to f8E4M3FNUZ + %a3 = arith.truncf %arg3 : bf16 to f8E4M3FNUZ + %b0 = arith.truncf %arg0 : bf16 to f8E5M2FNUZ + %b1 = arith.truncf %arg1 : bf16 to f8E5M2FNUZ + %b2 = arith.truncf %arg2 : bf16 to f8E5M2FNUZ + %b3 = arith.truncf %arg3 : bf16 to f8E5M2FNUZ + %a_init = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FNUZ> + %b_init = arith.constant dense<0.000000e+00> : vector<4xf8E5M2FNUZ> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0 = vector.insert %a0, %a_init [%c0] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %1 = vector.insert %a1, %0 [%c1] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %2 = vector.insert %a2, %1 [%c2] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %a = vector.insert %a3, %2 [%c3] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %3 = vector.insert %b0, %b_init [%c0] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %4 = vector.insert %b1, %3 [%c1] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %5 = vector.insert %b2, %4 [%c2] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %b = vector.insert %b3, %5 [%c3] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + return %a, %b : vector<4xf8E4M3FNUZ>, vector<4xf8E5M2FNUZ> + } +} + +// CHECK-LABEL: @intr_bf16_to_4f8 +// CHECK-COUNT-4: arith.extf %{{.+}} : bf16 to f32 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.fp8.f32 +// CHECK-COUNT-4: arith.extf %{{.+}} : bf16 to f32 +// 
CHECK-COUNT-2: llvm.amdgcn.cvt.pk.bf8.f32 + +// ----- + +module { + func.func @intr_f32_to_4f8(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> (vector<4xf8E4M3FNUZ>, vector<4xf8E5M2FNUZ>) { + %a0 = arith.truncf %arg0 : f32 to f8E4M3FNUZ + %a1 = arith.truncf %arg1 : f32 to f8E4M3FNUZ + %a2 = arith.truncf %arg2 : f32 to f8E4M3FNUZ + %a3 = arith.truncf %arg3 : f32 to f8E4M3FNUZ + %b0 = arith.truncf %arg0 : f32 to f8E5M2FNUZ + %b1 = arith.truncf %arg1 : f32 to f8E5M2FNUZ + %b2 = arith.truncf %arg2 : f32 to f8E5M2FNUZ + %b3 = arith.truncf %arg3 : f32 to f8E5M2FNUZ + %a_init = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FNUZ> + %b_init = arith.constant dense<0.000000e+00> : vector<4xf8E5M2FNUZ> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0 = vector.insert %a0, %a_init [%c0] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %1 = vector.insert %a1, %0 [%c1] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %2 = vector.insert %a2, %1 [%c2] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %a = vector.insert %a3, %2 [%c3] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %3 = vector.insert %b0, %b_init [%c0] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %4 = vector.insert %b1, %3 [%c1] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %5 = vector.insert %b2, %4 [%c2] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %b = vector.insert %b3, %5 [%c3] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + return %a, %b : vector<4xf8E4M3FNUZ>, vector<4xf8E5M2FNUZ> + } +} + +// CHECK-LABEL: @intr_f32_to_4f8 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.fp8.f32 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.bf8.f32 + +// ----- + +module { + func.func @intr_f64_to_4f8(%arg0: f64, %arg1: f64, %arg2: f64, %arg3: f64) -> (vector<4xf8E4M3FNUZ>, vector<4xf8E5M2FNUZ>) { + %a0 = arith.truncf %arg0 : f64 to f8E4M3FNUZ + %a1 = arith.truncf %arg1 : f64 to f8E4M3FNUZ + %a2 = arith.truncf %arg2 : f64 to f8E4M3FNUZ + %a3 = arith.truncf %arg3 : f64 to f8E4M3FNUZ + %b0 = arith.truncf %arg0 : f64 to f8E5M2FNUZ + %b1 = arith.truncf %arg1 : f64 to f8E5M2FNUZ + %b2 = arith.truncf %arg2 : f64 to f8E5M2FNUZ + %b3 = arith.truncf %arg3 : f64 to f8E5M2FNUZ + %a_init = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FNUZ> + %b_init = arith.constant dense<0.000000e+00> : vector<4xf8E5M2FNUZ> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0 = vector.insert %a0, %a_init [%c0] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %1 = vector.insert %a1, %0 [%c1] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %2 = vector.insert %a2, %1 [%c2] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %a = vector.insert %a3, %2 [%c3] : f8E4M3FNUZ into vector<4xf8E4M3FNUZ> + %3 = vector.insert %b0, %b_init [%c0] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %4 = vector.insert %b1, %3 [%c1] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %5 = vector.insert %b2, %4 [%c2] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + %b = vector.insert %b3, %5 [%c3] : f8E5M2FNUZ into vector<4xf8E5M2FNUZ> + return %a, %b : vector<4xf8E4M3FNUZ>, vector<4xf8E5M2FNUZ> + } +} + +// CHECK-LABEL: @intr_f64_to_4f8 +// CHECK-COUNT-4: arith.truncf %{{.+}} : f64 to f32 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.fp8.f32 +// CHECK-COUNT-4: arith.truncf %{{.+}} : f64 to f32 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.bf8.f32 + +// ----- + +module { + func.func @intr_4f8_to_f16(%arg0: vector<4xf8E4M3FNUZ>, %arg1: vector<4xf8E5M2FNUZ>) -> (f16, f16, f16, f16, f16, f16, f16, f16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 
1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %a0 = vector.extract %arg0[%c0] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a1 = vector.extract %arg0[%c1] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a2 = vector.extract %arg0[%c2] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a3 = vector.extract %arg0[%c3] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %b0 = vector.extract %arg1[%c0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b1 = vector.extract %arg1[%c1] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b2 = vector.extract %arg1[%c2] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b3 = vector.extract %arg1[%c3] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %0 = arith.extf %a0 : f8E4M3FNUZ to f16 + %1 = arith.extf %a1 : f8E4M3FNUZ to f16 + %2 = arith.extf %a2 : f8E4M3FNUZ to f16 + %3 = arith.extf %a3 : f8E4M3FNUZ to f16 + %4 = arith.extf %b0 : f8E5M2FNUZ to f16 + %5 = arith.extf %b1 : f8E5M2FNUZ to f16 + %6 = arith.extf %b2 : f8E5M2FNUZ to f16 + %7 = arith.extf %b3 : f8E5M2FNUZ to f16 + return %0, %1, %2, %3, %4, %5, %6, %7 : f16, f16, f16, f16, f16, f16, f16, f16 + } +} + +// CHECK-LABEL: @intr_4f8_to_f16 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.f32.fp8 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pkrtz +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.f32.bf8 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pkrtz + +// ----- + +module { + func.func @intr_4f8_to_bf16(%arg0: vector<4xf8E4M3FNUZ>, %arg1: vector<4xf8E5M2FNUZ>) -> (bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %a0 = vector.extract %arg0[%c0] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a1 = vector.extract %arg0[%c1] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a2 = vector.extract %arg0[%c2] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a3 = vector.extract %arg0[%c3] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %b0 = vector.extract %arg1[%c0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b1 = vector.extract %arg1[%c1] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b2 = vector.extract %arg1[%c2] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b3 = vector.extract %arg1[%c3] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %0 = arith.extf %a0 : f8E4M3FNUZ to bf16 + %1 = arith.extf %a1 : f8E4M3FNUZ to bf16 + %2 = arith.extf %a2 : f8E4M3FNUZ to bf16 + %3 = arith.extf %a3 : f8E4M3FNUZ to bf16 + %4 = arith.extf %b0 : f8E5M2FNUZ to bf16 + %5 = arith.extf %b1 : f8E5M2FNUZ to bf16 + %6 = arith.extf %b2 : f8E5M2FNUZ to bf16 + %7 = arith.extf %b3 : f8E5M2FNUZ to bf16 + return %0, %1, %2, %3, %4, %5, %6, %7 : bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16 + } +} + +// CHECK-LABEL: @intr_4f8_to_bf16 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.f32.fp8 +// CHECK-COUNT-8: llvm.bitcast +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.f32.bf8 +// CHECK-COUNT-8: llvm.bitcast +// CHECK-NOT: arith.truncf %{{.+}} : f32 to bf16 + +// ----- + +module { + func.func @intr_4f8_to_f32(%arg0: vector<4xf8E4M3FNUZ>, %arg1: vector<4xf8E5M2FNUZ>) -> (f32, f32, f32, f32, f32, f32, f32, f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %a0 = vector.extract %arg0[%c0] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a1 = vector.extract %arg0[%c1] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a2 = vector.extract %arg0[%c2] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a3 = vector.extract %arg0[%c3] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %b0 = vector.extract %arg1[%c0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b1 = vector.extract %arg1[%c1] : 
f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b2 = vector.extract %arg1[%c2] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b3 = vector.extract %arg1[%c3] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %0 = arith.extf %a0 : f8E4M3FNUZ to f32 + %1 = arith.extf %a1 : f8E4M3FNUZ to f32 + %2 = arith.extf %a2 : f8E4M3FNUZ to f32 + %3 = arith.extf %a3 : f8E4M3FNUZ to f32 + %4 = arith.extf %b0 : f8E5M2FNUZ to f32 + %5 = arith.extf %b1 : f8E5M2FNUZ to f32 + %6 = arith.extf %b2 : f8E5M2FNUZ to f32 + %7 = arith.extf %b3 : f8E5M2FNUZ to f32 + return %0, %1, %2, %3, %4, %5, %6, %7 : f32, f32, f32, f32, f32, f32, f32, f32 + } +} + +// CHECK-LABEL: @intr_4f8_to_f32 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.f32.fp8 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.f32.bf8 + +// ----- + +module { + func.func @intr_4f8_to_f64(%arg0: vector<4xf8E4M3FNUZ>, %arg1: vector<4xf8E5M2FNUZ>) -> (f64, f64, f64, f64, f64, f64, f64, f64) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %a0 = vector.extract %arg0[%c0] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a1 = vector.extract %arg0[%c1] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a2 = vector.extract %arg0[%c2] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %a3 = vector.extract %arg0[%c3] : f8E4M3FNUZ from vector<4xf8E4M3FNUZ> + %b0 = vector.extract %arg1[%c0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b1 = vector.extract %arg1[%c1] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b2 = vector.extract %arg1[%c2] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %b3 = vector.extract %arg1[%c3] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> + %0 = arith.extf %a0 : f8E4M3FNUZ to f64 + %1 = arith.extf %a1 : f8E4M3FNUZ to f64 + %2 = arith.extf %a2 : f8E4M3FNUZ to f64 + %3 = arith.extf %a3 : f8E4M3FNUZ to f64 + %4 = arith.extf %b0 : f8E5M2FNUZ to f64 + %5 = arith.extf %b1 : f8E5M2FNUZ to f64 + %6 = arith.extf %b2 : f8E5M2FNUZ to f64 + %7 = arith.extf %b3 : f8E5M2FNUZ to f64 + return %0, %1, %2, %3, %4, %5, %6, %7 : f64, f64, f64, f64, f64, f64, f64, f64 + } +} + +// CHECK-LABEL: @intr_4f8_to_f64 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.f32.fp8 +// CHECK-COUNT-4: arith.extf %{{.+}} : f32 to f64 +// CHECK-COUNT-2: llvm.amdgcn.cvt.pk.f32.bf8 +// CHECK-COUNT-4: arith.extf %{{.+}} : f32 to f64 \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index 2e1d9c2c469aab..e5e70db9a6b8a7 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -59,6 +59,20 @@ cc_library( ], ) +cc_library( + name = "ptx_version_util", + srcs = [ + "ptx_version_util.cc", + ], + hdrs = [ + "ptx_version_util.h", + ], + deps = [ + "//xla/stream_executor:semantic_version", + "@com_google_absl//absl/status:statusor", + ], +) + cc_library( name = "nvptx_backend", srcs = [ @@ -75,6 +89,7 @@ cc_library( ":llvm_gpu_backend", ":load_ir_module", ":nvptx_libdevice_path", + ":ptx_version_util", "//xla:util", "//xla:xla_proto_cc", "//xla/service/gpu:metrics", diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc index 857b8d2521bdca..001eefea345b77 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc @@ -330,39 +330,4 @@ absl::StatusOr CompileToPtx( } return ptx; } - -namespace { -constexpr stream_executor::SemanticVersion 
kFallbackPtxVersion{6, 5, 0}; -constexpr stream_executor::SemanticVersion kMaxPtxVersion{8, 7, 0}; -} // namespace - -stream_executor::SemanticVersion -DetermineHighestSupportedPtxVersionFromCudaVersion( - stream_executor::SemanticVersion cuda_version) { - if (cuda_version < stream_executor::SemanticVersion{11, 0, 0}) { - // For everything below CUDA 11 we just fall back to PTX 6.5. - // We don't support CUDA below 11 anymore. - return kFallbackPtxVersion; - } - - // Mapping determined from - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#release-notes - // Examples: - // CUDA 11.0 -> PTX 7.0 - // CUDA 11.1 -> PTX 7.1 - // CUDA 12.0 -> PTX 8.0 - // CUDA 12.4 -> PTX 8.4 - // This versioning scheme is valid until CUDA 12.6 - if (cuda_version < stream_executor::SemanticVersion{12, 6, 0}) { - return {cuda_version.major() - 4, cuda_version.minor(), 0}; - } - // CUDA 12.6 -> PTX 8.5 - // CUDA 12.8 -> PTX 8.7 - if (cuda_version < stream_executor::SemanticVersion{12, 9, 0}) { - return {cuda_version.major() - 4, cuda_version.minor() - 1, 0}; - } - - // Return maximum known PTX version. - return kMaxPtxVersion; -} } // namespace xla::gpu::nvptx diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.h index 38711afcb52b01..525641d90b46cd 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.h +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "llvm/IR/Module.h" #include "llvm/Target/TargetMachine.h" +#include "xla/service/gpu/llvm_gpu_backend/ptx_version_util.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/semantic_version.h" #include "xla/xla.pb.h" @@ -48,11 +49,6 @@ absl::StatusOr CompileToPtx( const DebugOptions& debug_options, std::function configure_target = nullptr); -// Determine PTX version from CUDA version. -stream_executor::SemanticVersion -DetermineHighestSupportedPtxVersionFromCudaVersion( - stream_executor::SemanticVersion cuda_version); - // Returns the LLVM command line flags that we use for compilation. std::vector GetNVPTXBackendOptions( const DebugOptions& debug_options); diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/ptx_version_util.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/ptx_version_util.cc new file mode 100644 index 00000000000000..a8791598be6c61 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/ptx_version_util.cc @@ -0,0 +1,54 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/llvm_gpu_backend/ptx_version_util.h" + +namespace xla::gpu::nvptx { + +namespace { +constexpr stream_executor::SemanticVersion kFallbackPtxVersion{6, 5, 0}; +constexpr stream_executor::SemanticVersion kMaxPtxVersion{8, 7, 0}; +} // namespace + +stream_executor::SemanticVersion +DetermineHighestSupportedPtxVersionFromCudaVersion( + stream_executor::SemanticVersion cuda_version) { + if (cuda_version < stream_executor::SemanticVersion{11, 0, 0}) { + // For everything below CUDA 11 we just fall back to PTX 6.5. + // We don't support CUDA below 11 anymore. + return kFallbackPtxVersion; + } + + // Mapping determined from + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#release-notes + // Examples: + // CUDA 11.0 -> PTX 7.0 + // CUDA 11.1 -> PTX 7.1 + // CUDA 12.0 -> PTX 8.0 + // CUDA 12.4 -> PTX 8.4 + // This versioning scheme is valid until CUDA 12.6 + if (cuda_version < stream_executor::SemanticVersion{12, 6, 0}) { + return {cuda_version.major() - 4, cuda_version.minor(), 0}; + } + // CUDA 12.6 -> PTX 8.5 + // CUDA 12.8 -> PTX 8.7 + if (cuda_version < stream_executor::SemanticVersion{12, 9, 0}) { + return {cuda_version.major() - 4, cuda_version.minor() - 1, 0}; + } + + // Return maximum known PTX version. + return kMaxPtxVersion; +} +} // namespace xla::gpu::nvptx diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/ptx_version_util.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/ptx_version_util.h new file mode 100644 index 00000000000000..ecfb0c1a1736a4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/ptx_version_util.h @@ -0,0 +1,30 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_LLVM_GPU_BACKEND_PTX_VERSION_UTIL_H_ +#define XLA_SERVICE_GPU_LLVM_GPU_BACKEND_PTX_VERSION_UTIL_H_ + +#include "xla/stream_executor/semantic_version.h" + +namespace xla::gpu::nvptx { + +// Determine PTX version from CUDA version. 
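+// A minimal usage sketch (illustrative only, assuming the release-notes
+// mapping implemented in ptx_version_util.cc):
+//
+//   using stream_executor::SemanticVersion;
+//   // CUDA 11.0..12.5 map to PTX {major - 4, minor}: 12.4 -> 8.4.
+//   SemanticVersion ptx =
+//       DetermineHighestSupportedPtxVersionFromCudaVersion(
+//           SemanticVersion{12, 4, 0});
+//   // CUDA 12.6..12.8 map to {major - 4, minor - 1}: 12.8 -> 8.7.
+//   // CUDA 12.9 and above clamp to the maximum known PTX version (8.7),
+//   // and anything below CUDA 11 falls back to PTX 6.5.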
+stream_executor::SemanticVersion +DetermineHighestSupportedPtxVersionFromCudaVersion( + stream_executor::SemanticVersion cuda_version); + +} // namespace xla::gpu::nvptx + +#endif // XLA_SERVICE_GPU_LLVM_GPU_BACKEND_PTX_VERSION_UTIL_H_
diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h
index 333cb0c4c4319d..1c5365be2b9f52 100644
--- a/third_party/xla/xla/stream_executor/device_description.h
+++ b/third_party/xla/xla/stream_executor/device_description.h
@@ -127,7 +127,9 @@ class RocmComputeCapability { return has_ocp_fp8_support() || has_nanoo_fp8_support(); } - bool has_ocp_fp8_support() const { return gfx1200() || gfx1201(); } + bool has_ocp_fp8_support() const { + return gfx1200() || gfx1201() || gfx_version() == "gfx950"; + } bool has_nanoo_fp8_support() const { return gfx_version() == "gfx942"; }
From 95f587274149f29066c18a4e80bddb3b82fb0d9e Mon Sep 17 00:00:00 2001
From: Mikhail Goncharov
Date: Tue, 8 Apr 2025 03:36:41 -0700
Subject: [PATCH 0355/1324] [XLA:GPU] preserve types while hoisting bitcasts
Some operations, like compare, have different types of arguments and results; keep the original type while hoisting bitcasts, assuming that bitcasts do not convert the type of the operand.
PiperOrigin-RevId: 745069068
--- third_party/xla/xla/service/gpu/transforms/BUILD | 1 + .../service/gpu/transforms/nest_gemm_fusion.cc | 15 +++++++++++++-- .../gpu/transforms/nest_gemm_fusion_test.cc | 4 +--- 3 files changed, 15 insertions(+), 5 deletions(-)
diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD
index b66f9871093b5c..c5bf2734ff976e 100644
--- a/third_party/xla/xla/service/gpu/transforms/BUILD
+++ b/third_party/xla/xla/service/gpu/transforms/BUILD
@@ -2335,6 +2335,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/backends/gpu/codegen/triton:support", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass",
diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc
index 9580087f9490d6..ba80a524e6998f 100644
--- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc
+++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc
@@ -66,6 +66,7 @@ limitations under the License. #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla::gpu { @@ -552,6 +553,7 @@ absl::StatusOr CalculateBroadcastOutputReshape( // outside of the computation. // Returns the new shapes of affected instructions in order of traversal from // users to producers. +// Assumes that the bitcast does not convert the type of the operand. absl::StatusOr>> PlanHoistBitcastToCallers(const HloInstruction* bitcast) { // Check that all producers only affect the bitcast. If there are any // other users, we cannot hoist the bitcast as it would require hoisting the // producers downward. 
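// As an illustration of the hazard the type check below guards against
// (shapes here are assumed for the example, not taken from a real module):
//
//   %cmp = pred[4,8] compare(f32[4,8] %a, f32[4,8] %b), direction=LT
//   %bc  = pred[32] bitcast(%cmp)
//
// The compare's operands and result have different element types, so
// assigning the bitcast's full shape to every producer would retype the
// f32 operands to pred. Only the dimensions may propagate; each
// instruction keeps its own element type.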
HloInstructionSet producers = GetProducerSet(bitcast); TF_RETURN_IF_ERROR(VerifyIsClosedProducerSet(producers, bitcast)); + if (bitcast->shape().element_type() != + bitcast->operand(0)->shape().element_type()) { + return absl::UnimplementedError( + absl::StrCat("Hoisting bitcast with type conversion is not supported: ", + bitcast->ToString())); + } HloInstructionMap to_update; auto set_shape = [&](const absl::Span instructions, const Shape& shape) -> absl::Status { for (HloInstruction* instruction : instructions) { auto it = to_update.find(instruction); + // Only update the dimensions keeping the type intact. + Shape updated_shape = ShapeUtil::MakeShape( + instruction->shape().element_type(), shape.dimensions()); if (it == to_update.end()) { - to_update.emplace(instruction, shape); - } else if (it->second != shape) { + to_update.emplace(instruction, updated_shape); + } else if (it->second != updated_shape) { return absl::FailedPreconditionError(absl::StrCat( "Conflicting shape assignment for ", instruction->ToString(), " got ", it->second.ToString(), " and ", shape.ToString()));
diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc
index baffca3da9b6cd..c5f1836652bce8 100644
--- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc
@@ -486,9 +486,7 @@ ENTRY e { EXPECT_THAT(result, StatusIs(absl::StatusCode::kInternal)) << result.status(); } -// TODO(b/393299275): correctly hoist bitcast through compare. -// Fails with: "... [Unknown]: Expected comparison type UNSIGNED.". -TEST_F(NestGemmFusionTest, DISABLED_BitcastsAreHoistedPastCompare) { +TEST_F(NestGemmFusionTest, BitcastsAreHoistedPastCompare) { absl::string_view hlo = R"( HloModule t
From cbb94f2fd8c420af343638edc2f99a963b3d8c6f Mon Sep 17 00:00:00 2001
From: Terry Sun
Date: Tue, 8 Apr 2025 03:43:11 -0700
Subject: [PATCH 0356/1324] PR #24473: [NVIDIA GPU] Build topology for multi-host fast-interconnect domain
Imported from GitHub PR https://github.com/openxla/xla/pull/24473
Slice is a TPU-derived concept, representing a fast-interconnect domain. It has been used interchangeably with host on GPU. With multi-host fast interconnect, a slice is a concept beyond a single host. Leveraging the utility functions introduced in https://github.com/openxla/xla/pull/23320, this PR builds the topology for multi-host fast-interconnect domains.
Copybara import of the project:
-- 161a94332cd1195d864b6544b90c2060fa82f4ce by Terry Sun : build topo with fabric uuid
-- edd8f08b4d36a97135f8c65da9cac6671b8c1ca4 by Terry Sun : better doc string and more tests
-- b89d61d0e8cb0d0630252c13cad2f14a335c4510 by Terry Sun : improve proto doc string
Merging this change closes #24473
PiperOrigin-RevId: 745070925
--- .../xla/xla/pjrt/distributed/protocol.proto | 3 +- .../xla/xla/pjrt/distributed/topology_util.cc | 32 ++++++-- .../pjrt/distributed/topology_util_test.cc | 74 +++++++++++++++++++ 3 files changed, 100 insertions(+), 9 deletions(-)
diff --git a/third_party/xla/xla/pjrt/distributed/protocol.proto b/third_party/xla/xla/pjrt/distributed/protocol.proto
index 9d65bae39e4a24..3c73c130734bc8 100644
--- a/third_party/xla/xla/pjrt/distributed/protocol.proto
+++ b/third_party/xla/xla/pjrt/distributed/protocol.proto
@@ -58,7 +58,8 @@ message DeviceProto { int32 global_device_id = 4; // Globally unique ID number. // Devices with the same slice_index are connected by fast network, e.g. - // NVLink on GPUs. 
+ // NVLink on GPUs. Note that fast-interconnect can be cross-host, i.e. a + // slice may include multiple hosts. int32 slice_index = 5; // Store vendor-specific compute capability. diff --git a/third_party/xla/xla/pjrt/distributed/topology_util.cc b/third_party/xla/xla/pjrt/distributed/topology_util.cc index 62a80ac692a811..a21b3052188aeb 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util.cc @@ -72,6 +72,19 @@ bool SameLocalTopology(const LocalTopologyProto& a, return true; } +// Returns true if all devices have a valid fabric_uuid. +bool HasFabricUuid(absl::Span local_topologies) { + for (const LocalTopologyProto& local : local_topologies) { + for (const DeviceProto& device : local.devices()) { + if (device.fabric_uuid().empty() || + device.fabric_uuid() == "00000000-0000-0000-0000-000000000000/0") { + return false; + } + } + } + return true; +} + } // namespace // Exists on Linux systems. Unique per OS kernel restart. @@ -174,25 +187,28 @@ absl::StatusOr BuildGlobalTopology( } } } else { - // Assign local devices of the same host to the same slice_index. - absl::flat_hash_map boot_id_to_slice_index; + // Assign local devices of the same fabric_uuid/boot_id to the same + // slice_index. + const bool has_fabric_uuid = HasFabricUuid(local_topologies); + absl::flat_hash_map id_to_slice_index; for (LocalTopologyProto& local : local_topologies) { if (local.has_slice_index()) { return InvalidArgument( "Either all of or none of the local topologies " "should explicitly set slice_index"); } - // Every new boot_id seen is treated as a new host/slice. - auto [it, _] = boot_id_to_slice_index.try_emplace( - local.boot_id(), boot_id_to_slice_index.size()); for (DeviceProto& device : *local.mutable_devices()) { + // Each new fabric_uuid/boot_id seen is treated as a new slice. + auto [it, _] = id_to_slice_index.try_emplace( + has_fabric_uuid ? 
device.fabric_uuid() : local.boot_id(), + id_to_slice_index.size()); device.set_slice_index(it->second); } } if (VLOG_IS_ON(10)) { - for (auto it = boot_id_to_slice_index.begin(); - it != boot_id_to_slice_index.end(); ++it) { - LOG(INFO) << "BuildGlobalTopology boot_id_to_slice_index " << it->first + for (auto it = id_to_slice_index.begin(); it != id_to_slice_index.end(); + ++it) { + LOG(INFO) << "BuildGlobalTopology id_to_slice_index " << it->first << "->" << it->second; } } diff --git a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc index ea9665936dc2f5..a911c779758522 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc @@ -55,6 +55,80 @@ TEST(TopologyTest, BuildGlobalTopology) { EXPECT_EQ(global.nodes()[1].devices_size(), 2); } +TEST(TopologyTest, BuildGlobalTopologyWithFabricUuid) { + std::vector locals(2); + DeviceProto* d0 = locals[0].add_devices(); + d0->set_local_device_ordinal(0); + d0->set_fabric_uuid("00000000-0000-0000-0000-000000000001/0"); + DeviceProto* d1 = locals[0].add_devices(); + d1->set_local_device_ordinal(1); + d1->set_fabric_uuid("00000000-0000-0000-0000-000000000001/0"); + DeviceProto* d2 = locals[1].add_devices(); + d2->set_local_device_ordinal(0); + d2->set_fabric_uuid("00000000-0000-0000-0000-000000000001/0"); + DeviceProto* d3 = locals[1].add_devices(); + d3->set_local_device_ordinal(1); + d3->set_fabric_uuid("00000000-0000-0000-0000-000000000001/0"); + + TF_ASSERT_OK_AND_ASSIGN( + GlobalTopologyProto global, + BuildGlobalTopology(absl::Span(locals), + /*assign_global_device_ids=*/true)); + EXPECT_EQ(global.nodes_size(), 2); + EXPECT_EQ(global.nodes()[0].devices_size(), 2); + EXPECT_EQ(global.nodes()[1].devices_size(), 2); + EXPECT_EQ(global.nodes()[0].devices()[0].slice_index(), 0); + EXPECT_EQ(global.nodes()[0].devices()[1].slice_index(), 0); + EXPECT_EQ(global.nodes()[1].devices()[0].slice_index(), 0); + EXPECT_EQ(global.nodes()[1].devices()[1].slice_index(), 0); +} + +TEST(TopologyTest, BuildGlobalTopologyMultipleFabricUuid) { + std::vector locals(4); + DeviceProto* d0 = locals[0].add_devices(); + d0->set_local_device_ordinal(0); + d0->set_fabric_uuid("00000000-0000-0000-0000-000000000001/0"); + DeviceProto* d1 = locals[0].add_devices(); + d1->set_local_device_ordinal(1); + d1->set_fabric_uuid("00000000-0000-0000-0000-000000000001/0"); + DeviceProto* d2 = locals[1].add_devices(); + d2->set_local_device_ordinal(0); + d2->set_fabric_uuid("00000000-0000-0000-0000-000000000001/0"); + DeviceProto* d3 = locals[1].add_devices(); + d3->set_local_device_ordinal(1); + d3->set_fabric_uuid("00000000-0000-0000-0000-000000000001/0"); + DeviceProto* d4 = locals[2].add_devices(); + d4->set_local_device_ordinal(0); + d4->set_fabric_uuid("00000000-0000-0000-0000-000000000002/0"); + DeviceProto* d5 = locals[2].add_devices(); + d5->set_local_device_ordinal(1); + d5->set_fabric_uuid("00000000-0000-0000-0000-000000000002/0"); + DeviceProto* d6 = locals[3].add_devices(); + d6->set_local_device_ordinal(0); + d6->set_fabric_uuid("00000000-0000-0000-0000-000000000002/0"); + DeviceProto* d7 = locals[3].add_devices(); + d7->set_local_device_ordinal(1); + d7->set_fabric_uuid("00000000-0000-0000-0000-000000000002/0"); + + TF_ASSERT_OK_AND_ASSIGN( + GlobalTopologyProto global, + BuildGlobalTopology(absl::Span(locals), + /*assign_global_device_ids=*/true)); + EXPECT_EQ(global.nodes_size(), 4); + 
EXPECT_EQ(global.nodes()[0].devices_size(), 2); + EXPECT_EQ(global.nodes()[1].devices_size(), 2); + EXPECT_EQ(global.nodes()[2].devices_size(), 2); + EXPECT_EQ(global.nodes()[3].devices_size(), 2); + EXPECT_EQ(global.nodes()[0].devices()[0].slice_index(), 0); + EXPECT_EQ(global.nodes()[0].devices()[1].slice_index(), 0); + EXPECT_EQ(global.nodes()[1].devices()[0].slice_index(), 0); + EXPECT_EQ(global.nodes()[1].devices()[1].slice_index(), 0); + EXPECT_EQ(global.nodes()[2].devices()[0].slice_index(), 1); + EXPECT_EQ(global.nodes()[2].devices()[1].slice_index(), 1); + EXPECT_EQ(global.nodes()[3].devices()[0].slice_index(), 1); + EXPECT_EQ(global.nodes()[3].devices()[1].slice_index(), 1); +} + TEST(TopologyTest, ExchangeTopology) { int num_nodes = 2; std::vector locals(num_nodes); From 08477a09bed93205171922dc9fd42b2742a57829 Mon Sep 17 00:00:00 2001 From: Shaogang Wang Date: Tue, 8 Apr 2025 03:44:25 -0700 Subject: [PATCH 0357/1324] PR #24750: Fix a multi-thread race issue when enabling command buffer cublasLT cmd Imported from GitHub PR https://github.com/openxla/xla/pull/24750 Copybara import of the project: -- d948053d2f16e79b85d986872a0d1b0815ecc23b by Shawn Wang : fix cublas race -- 9fbe0d230be77ac3a906d70a376d4e6ab6df3a97 by Shawn Wang : fix Merging this change closes #24750 PiperOrigin-RevId: 745071253 --- .../backends/gpu/runtime/command_buffer_cmd.cc | 17 +++++++++++++---- .../backends/gpu/runtime/command_buffer_cmd.h | 6 ++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 4e8a67d0324e82..51c8bdfeebfc0a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -992,10 +992,15 @@ CublasLtCmd::CublasLtCmd( absl::StatusOr CublasLtCmd::GetMatmulPlan( const se::Stream* stream) { - auto it = matmul_plans_cache_.find(stream); - if (it != matmul_plans_cache_.end()) return it->second.get(); + { + absl::MutexLock lock(&matmul_plans_cache_mutex_); + auto it = matmul_plans_cache_.find(stream); + if (it != matmul_plans_cache_.end()) return it->second.get(); + } TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan( stream, gemm_config_, epilogue_)); + + absl::MutexLock lock(&matmul_plans_cache_mutex_); auto [it_insert, _] = matmul_plans_cache_.emplace(stream, std::move(plan)); return it_insert->second.get(); } @@ -1004,13 +1009,17 @@ absl::StatusOr CublasLtCmd::GetMatmulAlgorithm(const se::Stream* stream, const se::gpu::BlasLt::MatmulPlan* plan, int64_t max_workspace) { - auto it = matmul_algorithm_cache_.find(plan); - if (it != matmul_algorithm_cache_.end()) return it->second; + { + absl::MutexLock lock(&matmul_algorithm_cache_mutex_); + auto it = matmul_algorithm_cache_.find(plan); + if (it != matmul_algorithm_cache_.end()) return it->second; + } TF_ASSIGN_OR_RETURN( auto algorithms, plan->GetAlgorithms(stream, /*max_algorithm_count*/ 128, /*max_workspace_size*/ max_workspace)); TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); + absl::MutexLock lock(&matmul_algorithm_cache_mutex_); auto [it_insert, _] = matmul_algorithm_cache_.emplace(plan, algorithms[algorithm_idx_]); return it_insert->second; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index d414bf74288221..d65c1a82f0fb5b 100644 --- 
a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -710,12 +710,14 @@ class CublasLtCmd : public TracedCommandBufferCmd { const se::Stream* stream, const se::gpu::BlasLt::MatmulPlan* plan, int64_t max_workspace); + absl::Mutex matmul_plans_cache_mutex_; absl::flat_hash_map - matmul_plans_cache_; + matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_); + absl::Mutex matmul_algorithm_cache_mutex_; absl::flat_hash_map - matmul_algorithm_cache_; + matmul_algorithm_cache_ ABSL_GUARDED_BY(matmul_algorithm_cache_mutex_); const GemmConfig gemm_config_; const se::gpu::BlasLt::Epilogue epilogue_; From c60735b04bb4900cf7842b3ea1429c2de73b5bbb Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Tue, 8 Apr 2025 12:33:36 +0100 Subject: [PATCH 0358/1324] [mlir][tosa] Add int8 and int16 legalization of the LOG op (#90456) --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 28 +++++++- .../mlir/tosa/transforms/legalize_tfl.cc | 67 +++++++++++++++++++ .../tosa/transforms/tfl_legalize_patterns.td | 1 - 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 1bc67dd3088e71..c802d3c6e9033e 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -934,9 +934,31 @@ func.func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_log // CHECK: %[[VAR0:.*]] = tosa.log %arg0 -func.func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { - %0 = "tfl.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> +func.func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tfl.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_log_qi8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<256xi8>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_log_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.log"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_log_qi16 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_log_qi16(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.log"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> } // ----- diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index f4877c5c44aac7..61a9f54fd2f2e7 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -211,6 +211,7 @@ DECL_CONVERT_OP(LogicalOr); DECL_CONVERT_OP(Pow); DECL_CONVERT_OP(BroadcastTo); DECL_CONVERT_OP(Exp); +DECL_CONVERT_OP(Log); #undef DECL_CONVERT_OP @@ -4889,6 +4890,72 @@ LogicalResult ConvertTFLExpOp::matchAndRewrite( return success(); } 
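+// A hedged sketch of what the 8-bit table construction below computes.
+// LogTableEntry is an assumed standalone helper, not a TFLite/TOSA API;
+// it only mirrors the dequantize -> log_func -> requantize recipe applied
+// to each of the 256 input codes:
+//
+//   #include <algorithm>
+//   #include <cmath>
+//   #include <cstdint>
+//
+//   int8_t LogTableEntry(int8_t code, float in_scale, int32_t in_zp,
+//                        float out_scale, int32_t out_zp, float out_min) {
+//     // Dequantize; non-positive inputs clamp to the representable
+//     // minimum of the output domain (output_min below).
+//     float x = in_scale * static_cast<float>(code - in_zp);
+//     float y = x <= 0.0f ? out_min : std::log(x);
+//     // Requantize into int8.
+//     int32_t q = static_cast<int32_t>(std::lround(y / out_scale)) + out_zp;
+//     return static_cast<int8_t>(std::clamp(q, -128, 127));
+//   }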
+LogicalResult ConvertTFLLogOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_log_op = cast(op); + + RankedTensorType input_type = + dyn_cast(tfl_log_op.getX().getType()); + RankedTensorType output_type = + dyn_cast(tfl_log_op.getResult().getType()); + + if (!input_type || !output_type) { + return rewriter.notifyMatchFailure( + op, "input/output are not all a ranked tensor"); + } + + mlir::quant::UniformQuantizedType input_qtype = + dyn_cast_or_null( + input_type.getElementType()); + mlir::quant::UniformQuantizedType output_qtype = + dyn_cast_or_null( + output_type.getElementType()); + + if ((input_qtype == nullptr) != (output_qtype == nullptr)) { + return rewriter.notifyMatchFailure( + op, + "input/output tensor should be all quantized or all floating-point"); + } + + // Quantization case + if (input_qtype && output_qtype) { + const float output_min = + ((input_qtype.getStorageTypeIntegralWidth() == 8 ? -128 : -32768) - + output_qtype.getZeroPoint()) * + static_cast(output_qtype.getScale()); + + auto log_func = [&](float x) -> float { + if (x <= 0.0f) { + return output_min; + } + return std::log(x); + }; + + Value table_const; + if (input_qtype.getStorageTypeIntegralWidth() == 8) { + table_const = getTosaConst8bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), log_func); + } else if (input_qtype.getStorageTypeIntegralWidth() == 16) { + table_const = getTosaConst16bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), log_func); + } else { + return rewriter.notifyMatchFailure( + op, "only quantized int8 and int16 are supported"); + } + + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_log_op.getX(), table_const); + return success(); + } + + CreateReplaceOpAndInfer(rewriter, op, tfl_log_op.getType(), + tfl_log_op.getX()); + + return success(); +} + LogicalResult LegalizeTFL::initialize(MLIRContext* context) { RewritePatternSet patterns(context); mlir::tosa::populateLegalizeTFLPatterns(context, patterns); diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td index 1ed4d67a5f2bb7..b0141dcaf9fa13 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td +++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td @@ -29,7 +29,6 @@ include "mlir/Dialect/Tosa/IR/TosaOps.td" def ConvertTFLAbsOp : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>; def ConvertTFLCeilOp : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>; def ConvertTFLFloorOp : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>; -def ConvertTFLLogOp : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>; def ConvertTFLLogicalNotOp : Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>; // Removing the quant.stats op for unquantized models. From 4a33cf549d4b4c18eb1165b9f0a201984cc6ad8f Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Tue, 8 Apr 2025 04:19:13 -0700 Subject: [PATCH 0359/1324] Extract TMA metadata from the Triton module if it's available. 
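The descriptors are recorded during Triton lowering as tt.tma_descriptor
argument attributes on the kernel's entry function; after compilation they
are read back from the LLVM-dialect module and turned into stream_executor
TMA descriptors keyed by argument index. A condensed sketch of the
consuming loop added in this change (error plumbing elided):

  // For each entry-function argument carrying a tt.tma_descriptor
  // attribute, build a 2D descriptor and record it under the arg index.
  stream_executor::gpu::TmaMetadata tma_metadata;
  TF_ASSIGN_OR_RETURN(
      auto tma_desc,
      Create2DTmaDescriptor(attr.getGlobalShape(), attr.getBlockShape(),
                            attr.getElementByteSize()));
  tma_metadata.arg_index_to_tma_info.insert({idx, tma_desc});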
PiperOrigin-RevId: 745081701 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 1 + .../gpu/codegen/triton/fusion_emitter.cc | 67 +++++--- .../gpu/codegen/triton/fusion_emitter.h | 9 +- .../triton/fusion_emitter_deviceless_test.cc | 2 +- .../triton/fusion_emitter_legacy_matmul.cc | 17 +- .../triton/fusion_emitter_legacy_matmul.h | 15 +- .../fusion_emitter_legacy_matmul_stub.cc | 13 +- .../gpu/codegen/triton/fusion_emitter_stub.cc | 3 +- .../backends/gpu/codegen/triton/test_utils.cc | 6 +- .../backends/gpu/codegen/triton/tma_utils.cc | 50 ++---- .../backends/gpu/codegen/triton/tma_utils.h | 22 +-- .../gpu/codegen/triton/tma_utils_test.cc | 148 +++--------------- ...riton_xla_extract_insert_to_triton_pass.cc | 17 +- 13 files changed, 120 insertions(+), 250 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index f5789d3f3c3677..0f6c0f451e6d59 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -185,6 +185,7 @@ cc_library( ":emitter_helpers", ":fusion_emitter_legacy_matmul", ":support", + ":tma_utils", "//xla:autotuning_proto_cc", "//xla:permutation_util", "//xla:shape_util", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 575dee4b8f0503..0265fcea8553a5 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -91,6 +91,7 @@ limitations under the License. #include "xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h" #include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" #include "xla/backends/gpu/codegen/triton/support.h" +#include "xla/backends/gpu/codegen/triton/tma_utils.h" #include "xla/backends/gpu/codegen/triton/transforms/passes.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/codegen/emitters/elemental_hlo_to_mlir.h" @@ -129,6 +130,7 @@ limitations under the License. #include "xla/stream_executor/gpu/tma_metadata.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tools/hlo_decomposer.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla.pb.h" @@ -1533,6 +1535,7 @@ absl::Status CreateInternalError(absl::string_view message, } // Legacy emitter works with tt.func. New emitter works with func.func. +// TODO(393299275): Remove legacy optionality once migration is complete. void AppendFuncArgType(EmitterLocOpBuilder& b, absl::Span dims, absl::string_view fusion_kind, Type ir_type, SmallVector& fn_arg_types) { @@ -1548,6 +1551,7 @@ void AppendFuncArgType(EmitterLocOpBuilder& b, absl::Span dims, // Only needed for the new emitter since we are using func.func instead of // tt.func. +// TODO(393299275): Remove legacy optionality once migration is complete. void AppendFuncResultType(EmitterLocOpBuilder& b, absl::string_view fusion_kind, absl::Span dims, Type ir_type, SmallVector& fn_result_types) { @@ -1559,6 +1563,7 @@ void AppendFuncResultType(EmitterLocOpBuilder& b, absl::string_view fusion_kind, } // Legacy emitter works with tt.func. New emitter works with func.func. +// TODO(393299275): Remove legacy optionality once migration is complete. 
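+// Condensed shape of the dichotomy handled here and in EmitReturnOp below
+// (bodies elided; for orientation only):
+//
+//   if (fusion_kind == kTritonGemmFusionKind) {
+//     // Legacy path: build an mt::FuncOp with no results and terminate
+//     // the body with tt.return.
+//   } else {
+//     // Generic path: build an mlir::func::FuncOp carrying result types
+//     // and terminate with func.return.
+//   }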
mlir::FunctionOpInterface CreateFuncOp(EmitterLocOpBuilder& b, absl::string_view fn_name, absl::string_view fusion_kind, @@ -1579,6 +1584,7 @@ mlir::FunctionOpInterface CreateFuncOp(EmitterLocOpBuilder& b, } // Legacy emitter works with tt.return. New emitter works with func.return. +// TODO(393299275): Remove legacy optionality once migration is complete. void EmitReturnOp(EmitterLocOpBuilder& b, absl::string_view fusion_kind, SmallVector insert_results) { if (fusion_kind == kTritonGemmFusionKind) { @@ -1588,7 +1594,34 @@ void EmitReturnOp(EmitterLocOpBuilder& b, absl::string_view fusion_kind, } } -absl::StatusOr CreateTritonModule( +absl::StatusOr ExtractTmaMetadata( + mlir::ModuleOp triton_module, absl::string_view kernel_name) { + stream_executor::gpu::TmaMetadata tma_metadata; + SmallVector func_ops; + for (auto func : triton_module.getOps()) { + // Custom calls will also match to LLVMFuncOp, so we are only interested in + // the entry function. + if (func.getName().str() == kernel_name) { + func_ops.push_back(func); + } + } + CHECK_EQ(func_ops.size(), 1) + << "Expected a single LLVMFuncOp in the module for the entry function."; + + for (auto [idx, arg] : llvm::enumerate(func_ops[0].getArguments())) { + if (auto attr = func_ops[0].getArgAttrOfType( + idx, "tt.tma_descriptor")) { + TF_ASSIGN_OR_RETURN( + auto tma_desc, + Create2DTmaDescriptor(attr.getGlobalShape(), attr.getBlockShape(), + attr.getElementByteSize())); + tma_metadata.arg_index_to_tma_info.insert({idx, tma_desc}); + } + } + return tma_metadata; +} + +absl::StatusOr> CreateTritonModule( absl::string_view fn_name, const HloFusionInstruction* fusion, const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, @@ -1650,10 +1683,7 @@ absl::StatusOr CreateTritonModule( std::string libdevice_path = GetLibdevicePath(fusion->GetModule()->config(), device_info); - // It's okay for tma_metadata to be empty; it's only populated when used - // explicitly. SmallVector insert_results; - std::optional tma_metadata = std::nullopt; if (fusion_kind == kTritonGemmFusionKind) { // If the generic Triton emitter is enabled, we should never go through the // legacy MatMul emitter. @@ -1662,9 +1692,8 @@ absl::StatusOr CreateTritonModule( "The generic Triton emitter is enabled, but the legacy MatMul " "emitter is being used."); } - TF_ASSIGN_OR_RETURN(tma_metadata, - EmitMatMul(b, libdevice_path, device_info, fusion, fn, - block_level_parameters)); + TF_RETURN_IF_ERROR(EmitMatMul(b, libdevice_path, device_info, fusion, fn, + block_level_parameters)); } else if (fusion_kind == kTritonFusionKind || fusion_kind == kTritonNestedGemmFusionKind) { TF_ASSIGN_OR_RETURN(insert_results, @@ -1722,7 +1751,7 @@ absl::StatusOr CreateTritonModule( .xla_gpu_unsupported_annotate_with_emitter_loc())); } - return TritonModule{std::move(triton_module), tma_metadata}; + return std::move(triton_module); } absl::StatusOr TritonWrapper( @@ -1741,7 +1770,7 @@ absl::StatusOr TritonWrapper( } } - TF_ASSIGN_OR_RETURN(auto triton_module, + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef triton_module, CreateTritonModule(fn_name, fusion, device_info, block_level_parameters, mlir_context)); @@ -1751,14 +1780,10 @@ absl::StatusOr TritonWrapper( // Compile Triton kernel to LLVM. 
const HloModule* hlo_module = fusion->GetModule(); - TF_ASSIGN_OR_RETURN( - TritonWrapperResult result, - CompileTritonToLLVM(fn_name, *hlo_module, device_info, - block_level_parameters, triton_module.module.get(), - llvm_module, mlir_context, - /*is_xla_fusion=*/true)); - result.tma_metadata = triton_module.tma_metadata; - return result; + return CompileTritonToLLVM(fn_name, *hlo_module, device_info, + block_level_parameters, triton_module.get(), + llvm_module, mlir_context, + /*is_xla_fusion=*/true); } absl::StatusOr CompileTritonToLLVM( @@ -1931,7 +1956,13 @@ absl::StatusOr CompileTritonToLLVM( cluster_info.clusterDimY == 1 && cluster_info.clusterDimZ == 1); } - return {{shared_mem_bytes, cluster_dim}}; + + // It's okay for tma_metadata to be empty; it's only populated when used + // explicitly. + TF_ASSIGN_OR_RETURN(stream_executor::gpu::TmaMetadata tma_metadata, + ExtractTmaMetadata(triton_module, kernel_name)); + + return {{shared_mem_bytes, cluster_dim, tma_metadata}}; } } // namespace gpu diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.h b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.h index 3cbd9a5039c34d..6cfa3f2a69c3cd 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.h @@ -61,13 +61,6 @@ struct TritonWrapperResult { std::optional tma_metadata; }; -// A wrapper containing a Triton module and optional TmaMetadata, which must be -// extracted from compile-time and passed to the runtime. -struct TritonModule { - mlir::OwningOpRef module; - std::optional tma_metadata; -}; - // Load the MLIR dialects required for Triton IR generation. void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context); @@ -82,7 +75,7 @@ absl::StatusOr TritonWrapper( // Creates the initial Triton module for the given fusion. Visible for testing, // use TritonWrapper instead. -absl::StatusOr CreateTritonModule( +absl::StatusOr> CreateTritonModule( absl::string_view fn_name, const HloFusionInstruction* fusion, const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_deviceless_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_deviceless_test.cc index d25b4d5717edcc..e4118beb8407df 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_deviceless_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_deviceless_test.cc @@ -112,7 +112,7 @@ TEST_F(AnnotationsTest, Annotations) { TestGpuDeviceInfo::RTXA6000DeviceInfo(), block_level_parameters, context)); - std::string annotated_ir = DumpTritonIR(triton_module.module.get(), true); + std::string annotated_ir = DumpTritonIR(triton_module.get(), true); if constexpr (EmitterLocOpBuilder::kSourceLocationSupported) { EXPECT_THAT(RunFileCheck(annotated_ir, R"( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc index 8a9635baead4fa..98870b577d36e8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc @@ -1521,7 +1521,6 @@ ConstHloInstructionSet ScopeInputs(const TritonFusionAnalysis& analysis, return result; } - // This is a heuristic that serves as a proxy for register usage and code size. 
// // We have noticed that tilings with very long LLVM IR code are both slow to @@ -1902,14 +1901,12 @@ absl::Status EmitForLoopBody(EmitterLocOpBuilder& b, // Use tiling and execution parameters from 'config'. BlockLevelParameters are // ignored. // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. -absl::StatusOr> EmitMatMul( - EmitterLocOpBuilder& b, absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, mlir::FunctionOpInterface fn, - const BlockLevelParameters&) { - // TODO b/315957220: Populate tma_metadata. - stream_executor::gpu::TmaMetadata tma_metadata; - +absl::Status EmitMatMul(EmitterLocOpBuilder& b, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloFusionInstruction* fusion, + mlir::FunctionOpInterface fn, + const BlockLevelParameters&) { TF_ASSIGN_OR_RETURN(TritonGemmConfig config, GetTritonGemmConfig(fusion)); TF_ASSIGN_OR_RETURN(auto analysis, TritonFusionAnalysis::Execute( @@ -2029,7 +2026,7 @@ absl::StatusOr> EmitMatMul( b.create(tensor_pointer, values_out[producer], boundary_checks, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); } - return tma_metadata; + return absl::OkStatus(); } absl::StatusOr GetMatMulLaunchDimensions( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h index 99d7e10fdf9d9a..23d5484f9650a7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h @@ -16,8 +16,7 @@ limitations under the License. #ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_FUSION_EMITTER_LEGACY_MATMUL_H_ #define XLA_BACKENDS_GPU_CODEGEN_TRITON_FUSION_EMITTER_LEGACY_MATMUL_H_ -#include - +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -29,7 +28,6 @@ limitations under the License. #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/gpu/tma_metadata.h" namespace xla::gpu { @@ -41,11 +39,12 @@ absl::StatusOr GetMatMulLaunchDimensions( // Use tiling and execution parameters from 'config'. BlockLevelParameters are // ignored. // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. -absl::StatusOr> EmitMatMul( - EmitterLocOpBuilder& builder, absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, mlir::FunctionOpInterface fn, - const BlockLevelParameters&); +absl::Status EmitMatMul(EmitterLocOpBuilder& builder, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloFusionInstruction* fusion, + mlir::FunctionOpInterface fn, + const BlockLevelParameters&); } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul_stub.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul_stub.cc index eb845dab9830da..9b78b00de3bd4e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul_stub.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul_stub.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -28,7 +27,6 @@ limitations under the License. #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/gpu/tma_metadata.h" namespace xla::gpu { @@ -39,11 +37,12 @@ absl::StatusOr GetMatMulLaunchDimensions( return absl::UnimplementedError("not supported for this build configuration"); } -absl::StatusOr> EmitMatMul( - EmitterLocOpBuilder& builder, absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, mlir::FunctionOpInterface fn, - const BlockLevelParameters&) { +absl::Status EmitMatMul(EmitterLocOpBuilder& builder, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloFusionInstruction* fusion, + mlir::FunctionOpInterface fn, + const BlockLevelParameters&) { return absl::UnimplementedError("not supported for this build configuration"); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub.cc index 19ed1aab38c432..692726d8f4afaa 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_stub.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Pass/PassManager.h" @@ -62,7 +63,7 @@ absl::StatusOr TritonWrapper( return absl::UnimplementedError("not supported for this build configuration"); } -absl::StatusOr CreateTritonModule( +absl::StatusOr> CreateTritonModule( absl::string_view fn_name, const HloFusionInstruction* fusion, const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc index 997a5fb4c450bc..006d686f73eabd 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/test_utils.cc @@ -33,7 +33,9 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" #include "xla/backends/gpu/codegen/triton/fusion_emitter.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -119,14 +121,14 @@ absl::Status CreateTritonIrAndFileCheck( mlir::MLIRContext context; TF_ASSIGN_OR_RETURN( - auto triton_module, + mlir::OwningOpRef triton_module, CreateTritonModule("triton_fn", fusion, TestGpuDeviceInfo::RTXA6000DeviceInfo(), block_level_parameters, context)); std::string out; llvm::raw_string_ostream os(out); - triton_module.module->print(os); + triton_module->print(os); TF_ASSIGN_OR_RETURN(bool succeeded, RunFileCheck(out, filecheck_pattern)); if (!succeeded) { return absl::InternalError("FileCheck failed."); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.cc b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.cc index 518490d358b1d0..eba4718b72de4f 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.cc @@ -16,58 +16,42 @@ limitations under the License. #include "xla/backends/gpu/codegen/triton/tma_utils.h" #include -#include #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "xla/codegen/emitter_loc_op_builder.h" -#include "xla/shape.h" #include "xla/stream_executor/gpu/tma_metadata.h" #include "xla/tsl/platform/statusor.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" namespace xla::gpu { -namespace mt = ::mlir::triton; - using ::llvm::SmallVector; -using ::mlir::RankedTensorType; -using ::mlir::Type; -using ::mlir::Value; using ::stream_executor::gpu::TmaDescriptor; -using ::stream_executor::gpu::TmaMetadata; // Returns a TmaDescriptor for a 2D tensor to be emitted in Triton. 
// // This function follows the defaults and logic found in fill2DTMADescriptor in // @triton/third_party/nvidia/backend/cuda_utils.cc absl::StatusOr Create2DTmaDescriptor( - Shape global_shape, llvm::ArrayRef block_shape, - Type element_type) { - if (global_shape.dimensions().size() != 2) { + llvm::ArrayRef global_shape, llvm::ArrayRef block_shape, + int element_byte_size) { + if (global_shape.size() != 2) { return absl::InvalidArgumentError("expected 2D global shape"); } if (block_shape.size() != 2) { return absl::InvalidArgumentError("expected 2D block shape"); } - int byte_width = element_type.getIntOrFloatBitWidth() / 8; SmallVector global_dims = { - static_cast(global_shape.dimensions(1)), - static_cast(global_shape.dimensions(0))}; - auto global_strides = {global_dims[0] * byte_width}; + static_cast(global_shape[1]), + static_cast(global_shape[0])}; + auto global_strides = {global_dims[0] * element_byte_size}; SmallVector box_dims = {static_cast(block_shape[1]), static_cast(block_shape[0])}; SmallVector element_strides = {1, 1}; TmaDescriptor::TmaSwizzle swizzle; - uint32_t contig_dim_size_in_byte = byte_width * box_dims[0]; + uint32_t contig_dim_size_in_byte = element_byte_size * box_dims[0]; if (contig_dim_size_in_byte >= 128) { swizzle = TmaDescriptor::TmaSwizzle::k128B; } else if (contig_dim_size_in_byte >= 64) { @@ -79,30 +63,14 @@ absl::StatusOr Create2DTmaDescriptor( "continguous dimension size too small"); } if (contig_dim_size_in_byte > 128) { - box_dims[0] = 128 / byte_width; + box_dims[0] = 128 / element_byte_size; } TF_ASSIGN_OR_RETURN( auto tma_desc, TmaDescriptor::Create( global_dims, global_strides, box_dims, element_strides, - byte_width, TmaDescriptor::TmaInterleave::kNone, + element_byte_size, TmaDescriptor::TmaInterleave::kNone, swizzle, TmaDescriptor::TmaL2Promotion::k128B)); return tma_desc; } -Value EmitTmaDescriptor(EmitterLocOpBuilder& b, Value arg, - RankedTensorType tensor_type) { - auto desc_type = mt::TensorDescType::get(b.getContext(), tensor_type); - return b.create(desc_type, arg); -} - -void RewriteFunctionForTma(EmitterLocOpBuilder& b, mlir::triton::FuncOp fn, - std::optional tma_metadata) { - if (!tma_metadata.has_value()) { - return; - } - for (auto& [parameter_number, _] : tma_metadata->arg_index_to_tma_info) { - fn.setArgAttr(parameter_number, "tt.nv_tma_desc", b.getI32IntegerAttr(1)); - } -} - } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.h b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.h index 3a6382f5ea75ab..aa5be3fe90127e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils.h @@ -17,35 +17,17 @@ limitations under the License. #define XLA_BACKENDS_GPU_CODEGEN_TRITON_TMA_UTILS_H_ #include -#include #include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "xla/codegen/emitter_loc_op_builder.h" -#include "xla/shape.h" #include "xla/stream_executor/gpu/tma_metadata.h" -#include "triton/Dialect/Triton/IR/Dialect.h" namespace xla::gpu { // Returns a TmaDescriptor for a 2D tensor to be emitted in Triton. absl::StatusOr Create2DTmaDescriptor( - Shape global_shape, llvm::ArrayRef block_shape, - mlir::Type element_type); - -// Emit a TmaDescriptor for the given argument & tensor type. It can then be -// used to load a tensor using the DescriptorLoadOp. 
-mlir::Value EmitTmaDescriptor(EmitterLocOpBuilder& b, mlir::Value arg, - mlir::RankedTensorType tensor_type); - -// Loading arguments by TMA changes the kernel signature and must be updated -// appropriately. -void RewriteFunctionForTma( - EmitterLocOpBuilder& b, mlir::triton::FuncOp fn, - std::optional tma_metadata); + llvm::ArrayRef global_shape, llvm::ArrayRef block_shape, + int element_byte_size); } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils_test.cc index be4e0b6fb9b480..8aaded460f601e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/tma_utils_test.cc @@ -16,47 +16,23 @@ limitations under the License. #include "xla/backends/gpu/codegen/triton/tma_utils.h" #include -#include -#include #include #include #include "absl/status/status.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OwningOpRef.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "xla/codegen/emitter_loc_op_builder.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/tma_metadata.h" #include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/platform/statusor.h" -#include "xla/xla_data.pb.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" namespace xla::gpu { namespace { using ::absl::StatusCode; using ::llvm::SmallVector; -using ::mlir::RankedTensorType; -using ::mlir::Type; -using ::mlir::Value; -using ::mlir::triton::FuncOp; -using ::mlir::triton::PointerType; using ::stream_executor::gpu::TmaDescriptor; -using ::stream_executor::gpu::TmaMetadata; using ::testing::ElementsAre; using ::testing::HasSubstr; using ::tsl::testing::StatusIs; @@ -64,12 +40,12 @@ using ::tsl::testing::StatusIs; TEST(Create2DTmaDescriptorTest, ValidInputReturnCorrectDescriptor) { mlir::MLIRContext mlir_context; mlir::Builder b(&mlir_context); - Shape global_shape = ShapeUtil::MakeShape(F32, {256, 128}); + llvm::SmallVector global_shape = {256, 128}; llvm::SmallVector block_shape = {64, 32}; - mlir::Type element_type = b.getF32Type(); + int element_byte_size = 4; TF_ASSERT_OK_AND_ASSIGN( TmaDescriptor tma_desc, - Create2DTmaDescriptor(global_shape, block_shape, element_type)); + Create2DTmaDescriptor(global_shape, block_shape, element_byte_size)); EXPECT_EQ(tma_desc.element_size(), 4); EXPECT_EQ(tma_desc.num_dimensions(), 2); EXPECT_THAT(tma_desc.global_dims(), ElementsAre(128, 256)); @@ -85,119 +61,37 @@ TEST(Create2DTmaDescriptorTest, ValidInputReturnCorrectDescriptor) { TEST(Create2DTmaDescriptorTest, BadGlobalShapeFailsGracefully) { mlir::MLIRContext mlir_context; mlir::Builder b(&mlir_context); - Shape global_shape = ShapeUtil::MakeShape(F32, {128}); + llvm::SmallVector global_shape = {128}; llvm::SmallVector block_shape = {128, 128}; - mlir::Type element_type = b.getF32Type(); - EXPECT_THAT(Create2DTmaDescriptor(global_shape, block_shape, element_type), - StatusIs(StatusCode::kInvalidArgument, - HasSubstr("expected 2D global shape"))); + int element_byte_size = 4; + EXPECT_THAT( + 
Create2DTmaDescriptor(global_shape, block_shape, element_byte_size), + StatusIs(StatusCode::kInvalidArgument, + HasSubstr("expected 2D global shape"))); } TEST(Create2DTmaDescriptorTest, BadBlockShapeFailsGracefully) { mlir::MLIRContext mlir_context; mlir::Builder b(&mlir_context); - Shape global_shape = ShapeUtil::MakeShape(F32, {128, 128}); + llvm::SmallVector global_shape = {128, 128}; llvm::SmallVector block_shape = {128}; - mlir::Type element_type = b.getF32Type(); - EXPECT_THAT(Create2DTmaDescriptor(global_shape, block_shape, element_type), - StatusIs(StatusCode::kInvalidArgument, - HasSubstr("expected 2D block shape"))); + int element_byte_size = 4; + EXPECT_THAT( + Create2DTmaDescriptor(global_shape, block_shape, element_byte_size), + StatusIs(StatusCode::kInvalidArgument, + HasSubstr("expected 2D block shape"))); } TEST(Create2DTmaDescriptorTest, SmallBlockShapeFailsGracefully) { mlir::MLIRContext mlir_context; mlir::Builder b(&mlir_context); - Shape global_shape = ShapeUtil::MakeShape(F32, {128, 128}); + llvm::SmallVector global_shape = {128, 128}; llvm::SmallVector block_shape = {128, 2}; - mlir::Type element_type = b.getF32Type(); - EXPECT_THAT(Create2DTmaDescriptor(global_shape, block_shape, element_type), - StatusIs(StatusCode::kFailedPrecondition, - HasSubstr("dimension size too small"))); -} - -class TmaUtilsFixture : public testing::Test { - public: - void SetUp() override { - mlir_context_.loadDialect(); - std::string fn_name = "test_fn"; - auto loc = mlir::NameLoc::get(b_.getStringAttr(fn_name)); - triton_module_ = llvm_ir::CreateMlirModuleOp(loc); - b_.setInsertionPointToEnd(triton_module_->getBody()); - } - - EmitterLocOpBuilder GetEmitterLocOpBuilder() { - return EmitterLocOpBuilder(mlir::NameLoc::get(b_.getStringAttr("test_fn")), - b_); - } - - FuncOp CreateTestFunction(EmitterLocOpBuilder& b) { - std::string fn_name = "test_fn"; - SmallVector fn_arg_types{ - PointerType::get(b.getF32Type(), mlir::NVVM::kGlobalMemorySpace), - PointerType::get(b.getF32Type(), mlir::NVVM::kGlobalMemorySpace), - PointerType::get(b.getF32Type(), mlir::NVVM::kGlobalMemorySpace)}; - auto func_type = b.getFunctionType(fn_arg_types, std::nullopt); - FuncOp fn = b.create(fn_name, func_type); - b.setInsertionPointToStart(fn.addEntryBlock()); - return fn; - } - - protected: - mlir::MLIRContext mlir_context_; - mlir::OpBuilder b_{&mlir_context_}; - mlir::OwningOpRef triton_module_; -}; - -TEST_F(TmaUtilsFixture, - EmitTmaDescriptor_ValidInputReturnsCorrectTmaDescriptor) { - EmitterLocOpBuilder b = GetEmitterLocOpBuilder(); - FuncOp fn = CreateTestFunction(b); - Value arg = fn.getArgument(0); - RankedTensorType tensor_type = - RankedTensorType::get({128, 128}, b.getF32Type()); - Value tma_desc = EmitTmaDescriptor(b, arg, tensor_type); - EXPECT_EQ(tma_desc.getType(), - mlir::triton::TensorDescType::get(b.getContext(), tensor_type)); -} - -TEST_F(TmaUtilsFixture, - RewriteFunctionForTma_TmaDescriptorsSetCorrectTmaAttribute) { - EmitterLocOpBuilder b = GetEmitterLocOpBuilder(); - FuncOp fn = CreateTestFunction(b); - TmaMetadata tma_metadata; - TF_ASSERT_OK_AND_ASSIGN( - auto tma_desc, - TmaDescriptor::Create({128, 128}, {128}, {64, 64}, {1, 1}, 4)); - tma_metadata.arg_index_to_tma_info.insert({0, tma_desc}); - TF_ASSERT_OK_AND_ASSIGN( - tma_desc, TmaDescriptor::Create({128, 128}, {128}, {64, 64}, {1, 1}, 4)); - tma_metadata.arg_index_to_tma_info.insert({2, tma_desc}); - - RewriteFunctionForTma(b, fn, tma_metadata); - EXPECT_EQ(fn.getArgAttr(0, "tt.nv_tma_desc"), b_.getI32IntegerAttr(1)); - 
EXPECT_FALSE(fn.getArgAttr(1, "tt.nv_tma_desc")); - EXPECT_EQ(fn.getArgAttr(2, "tt.nv_tma_desc"), b_.getI32IntegerAttr(1)); -} - -TEST_F(TmaUtilsFixture, - RewriteFunctionForTma_NoTmaMetadataDoesNotSetTmaAttribute) { - EmitterLocOpBuilder b = GetEmitterLocOpBuilder(); - FuncOp fn = CreateTestFunction(b); - RewriteFunctionForTma(b, fn, std::nullopt); - EXPECT_FALSE(fn.getArgAttr(0, "tt.nv_tma_desc")); - EXPECT_FALSE(fn.getArgAttr(1, "tt.nv_tma_desc")); - EXPECT_FALSE(fn.getArgAttr(2, "tt.nv_tma_desc")); -} - -TEST_F(TmaUtilsFixture, - RewriteFunctionForTma_EmptyTmaMetadataDoesNotSetTmaAttribute) { - EmitterLocOpBuilder b = GetEmitterLocOpBuilder(); - FuncOp fn = CreateTestFunction(b); - TmaMetadata tma_metadata; - RewriteFunctionForTma(b, fn, tma_metadata); - EXPECT_FALSE(fn.getArgAttr(0, "tt.nv_tma_desc")); - EXPECT_FALSE(fn.getArgAttr(1, "tt.nv_tma_desc")); - EXPECT_FALSE(fn.getArgAttr(2, "tt.nv_tma_desc")); + int element_byte_size = 4; + EXPECT_THAT( + Create2DTmaDescriptor(global_shape, block_shape, element_byte_size), + StatusIs(StatusCode::kFailedPrecondition, + HasSubstr("dimension size too small"))); } } // namespace diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc index e7f064d64ba535..0e9dcc6d87cfb4 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc @@ -15,9 +15,7 @@ limitations under the License. #include -#include #include -#include #include #include #include @@ -51,7 +49,6 @@ limitations under the License. 
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/backends/gpu/codegen/triton/emitter_helpers.h" #include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" -#include "xla/backends/gpu/codegen/triton/tma_utils.h" #include "xla/backends/gpu/codegen/triton/transforms/passes.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/codegen/emitters/ir/xla_ops.h" @@ -373,8 +370,13 @@ struct RewriteTile : mlir::OpRewritePattern { op.getTensor()) .getResult(0); - auto reinterpret_tensor_desc = xg::EmitTmaDescriptor( - builder, cast_to_tensor_ptr_type, tiled_tensor_type.getTileType()); + auto reinterpret_tensor_desc = + builder + .create( + mlir::triton::TensorDescType::get( + builder.getContext(), tiled_tensor_type.getTileType()), + cast_to_tensor_ptr_type) + .getResult(); // !tt.tensordesc -> tiled_tensor auto cast_desc_ptr_to_tiled_tensor_ptr_type = @@ -622,10 +624,11 @@ struct TritonXLAExtractInsertToTritonPass &device_info)); device_description = stream_executor::DeviceDescription(device_info); } - tma_enabled = tma_enabled_; + if (tma_enabled_.hasValue()) { + tma_enabled = tma_enabled_.getValue(); + } mlir::MLIRContext* mlir_context = &getContext(); - mlir::RewritePatternSet tile_pattern_set(mlir_context); tile_pattern_set.add(mlir_context, &device_description, tma_enabled); From be0984edd903638773fcd79bd6dea7f0e21b74a7 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 8 Apr 2025 04:26:18 -0700 Subject: [PATCH 0360/1324] =?UTF-8?q?[XLA:GPU]=C2=A0Activate=20a=20few=20m?= =?UTF-8?q?ore=20tests=20in=20`fusion=5Femitter=5Fdevice=5Flegacy=5Fport?= =?UTF-8?q?=5Ftest.cc`.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also identify a few more `TODO`s on the path to general enablement. PiperOrigin-RevId: 745083617 --- .../fusion_emitter_device_legacy_port_test.cc | 333 +++++++++++------- .../gpu/model/symbolic_tile_analysis.cc | 4 + 2 files changed, 211 insertions(+), 126 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index 11e9fb9569b73b..e9b44186b362a5 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -3135,82 +3135,88 @@ ENTRY e { /*run_hlo_passes=*/false)); } +// TODO(b/393299275): enable once `NestGemmFusion` data type propagation is +// fixed. At the moment, the data type is not propagated correctly and that +// causes a miscompile. 
TEST_F(CompareTest, DISABLED_SplitKNontrivialBitcast) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloTextRef = R"( -HloModule module, is_scheduled=true - -triton_gemm_dot.5316 { - parameter_1 = bf16[16,4,128]{2,1,0} parameter(1) - bitcast.2 = bf16[16,512]{1,0} bitcast(parameter_1) - parameter_0 = s8[512,96]{1,0} parameter(0) - convert.4 = bf16[512,96]{1,0} convert(parameter_0) - ROOT dot.0 = bf16[16,96]{1,0} dot(bitcast.2, convert.4), +HloModule module + +dot { + p0 = s8[512,96]{1,0} parameter(0) + convert = bf16[512,96]{1,0} convert(p0) + p1 = bf16[16,4,128]{2,1,0} parameter(1) + bitcast = bf16[16,512]{1,0} bitcast(p1) + ROOT dot = bf16[16,96]{1,0} dot(bitcast, convert), lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY entry { - parameter_0.1 = s8[96,4,128]{2,1,0} parameter(0) - bitcast.6 = s8[512,96]{1,0} bitcast(parameter_0.1) - parameter_1.1 = bf16[16,4,128]{2,1,0} parameter(1) - ROOT triton_gemm_dot.5316 = bf16[16,96]{1,0} fusion(bitcast.6, parameter_1.1), - kind=kCustom, calls=triton_gemm_dot.5316, + p0 = s8[512,96]{1,0} parameter(0) + p1 = bf16[16,4,128]{2,1,0} parameter(1) + ROOT dot = bf16[16,96]{1,0} fusion(p0, p1), + kind=kCustom, calls=dot, backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":256, "split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} })"; - const std::string kHloTextSplitK = R"( -HloModule module, is_scheduled=true - -triton_gemm_dot.5316 { - parameter_1 = bf16[16,4,128]{2,1,0} parameter(1) - bitcast.2 = bf16[16,512]{1,0} bitcast(parameter_1) - bitcast.17 = bf16[16,16,32]{2,1,0} bitcast(bitcast.2) - parameter_0 = s8[512,96]{1,0} parameter(0) - convert.4 = bf16[512,96]{1,0} convert(parameter_0) - bitcast.18 = bf16[16,32,96]{2,1,0} bitcast(convert.4) - ROOT dot.4 = bf16[16,16,96]{2,1,0} dot(bitcast.17, bitcast.18), + const std::string kHloTextTest = R"( +HloModule module + +dot { + p0 = s8[512,96]{1,0} parameter(0) + convert = bf16[512,96]{1,0} convert(p0) + p1 = bf16[16,4,128]{2,1,0} parameter(1) + bitcast_p1 = bf16[16,16,32]{2,1,0} bitcast(p1) + bitcast_convert = bf16[16,32,96]{2,1,0} bitcast(convert) + ROOT dot = bf16[16,16,96]{2,1,0} dot(bitcast_p1, bitcast_convert), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} } -triton_gemm_dot.5316.reduce_sub_computation.clone { - rhs.1 = f32[] parameter(1) - lhs.1 = f32[] parameter(0) - ROOT add.1 = f32[] add(lhs.1, rhs.1) +reducer { + rhs = f32[] parameter(1) + lhs = f32[] parameter(0) + ROOT add = f32[] add(lhs, rhs) } -fused_computation { - param_0.2 = bf16[16,16,96]{2,1,0} parameter(0) - convert.19 = f32[16,16,96]{2,1,0} convert(param_0.2) - constant_1 = bf16[] constant(0) - convert.18 = f32[] convert(constant_1) - reduce.1 = f32[16,96]{1,0} reduce(convert.19, convert.18), - dimensions={0}, to_apply=triton_gemm_dot.5316.reduce_sub_computation.clone - ROOT convert.17 = bf16[16,96]{1,0} convert(reduce.1) +split_k_reducer { + p0 = bf16[16,16,96]{2,1,0} parameter(0) + convert = f32[16,16,96]{2,1,0} convert(p0) + c0 = f32[] constant(0) + reduce = f32[16,96]{1,0} reduce(convert, c0), + dimensions={0}, to_apply=reducer + ROOT output = bf16[16,96]{1,0} convert(reduce) } ENTRY entry { - parameter_0.1 = s8[96,4,128]{2,1,0} parameter(0) - bitcast.6 = s8[512,96]{1,0} bitcast(parameter_0.1) - parameter_1.1 = bf16[16,4,128]{2,1,0} parameter(1) - triton_gemm_dot.5316 = bf16[16,16,96]{2,1,0} fusion(bitcast.6, parameter_1.1), - kind=kCustom, 
calls=triton_gemm_dot.5316,
+  p0 = s8[512,96]{1,0} parameter(0)
+  p1 = bf16[16,4,128]{2,1,0} parameter(1)
+  dot = bf16[16,16,96]{2,1,0} fusion(p0, p1),
+    kind=kCustom, calls=dot,
     backend_config={"fusion_backend_config": {kind: "__triton_gemm",
       triton_gemm_config: {"block_m":64,"block_n":32,"block_k":32,
                            "split_k":16,"num_stages":1,"num_warps":4,
                            "num_ctas":1}}}
-  ROOT fusion.1 = bf16[16,96]{1,0} fusion(triton_gemm_dot.5316),
-    kind=kLoop, calls=fused_computation
+  ROOT output = bf16[16,96]{1,0} fusion(dot), kind=kLoop, calls=split_k_reducer
 })";
+  TF_ASSERT_OK_AND_ASSIGN(
+      ModuleAndNestedFusionMetadata test_module_and_metadata,
+      GetModuleAndNestedFusionMetadata(kHloTextTest));
 
-  EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextSplitK,
-                                      ErrorSpec{/*aabs=*/2, /*arel=*/1e-2},
-                                      /*run_hlo_passes=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata ref_module_and_metadata,
+                          GetModuleAndNestedFusionMetadata(kHloTextRef));
+
+  EXPECT_TRUE(
+      RunAndCompareTwoModules(std::move(ref_module_and_metadata.module),
+                              std::move(test_module_and_metadata.module),
+                              ErrorSpec{/*aabs=*/2, /*arel=*/1e-2},
+                              /*run_hlo_passes=*/false));
 }
 
 // This is based on gemm_fusion_test.cc/SplitKTest.SupportsIndivisible.
@@ -3294,6 +3300,7 @@ ENTRY entry_computation {
                                       /*run_hlo_passes=*/false));
 }
 
+// TODO(b/393299275): transform this test once padding derivation is fixed.
 TEST_F(CompareTest, DISABLED_SupportsSplitKWithIndivisibleKUsingPaddingEqual1) {
   constexpr absl::string_view kHloTextRef = R"(
 HloModule extracted, entry_computation_layout={(f16[1,8,4,1023]{3,2,1,0}, f16[1,1023,128]{2,1,0})->f16[1,8,4,128]{3,2,1,0}}
@@ -3372,24 +3379,41 @@ ENTRY entry_computation {
                                       /*run_hlo_passes=*/false));
 }
 
+// TODO(b/393299275): symbolic tile derivation fails for one of the padded
+// operands, with indexing map
+//   (d0, d1, d2, d3) -> (d1, d0 * 64 + d3)
+//   domain: d0 in [0, 15]
+//           d1 in [0, 31]
+//           d2 in [0, 127]
+//           d3 in [0, 63]
+//   d0 * 64 + d3 in [0, 1018]
+// While the expression should be processed without any issue, padding
+// introduces a non-redundant pre-existing constraint d0 * 64 + d3 in [0, 1018],
+// which causes the derivation to be rejected. The reason for this is that it's
+// not quite clear how to handle these pre-existing constraints in the general
+// sense. But with respect to HLO specifically and symbolic tile analysis, we
+// could probably decide to just drop them from symbolic tile derivation: the
+// reason for that is that offset constraints are handled via
+// `tile_offsets_indexing` anyway, and that is all that should be relevant as
+// far as I know. We can probably let the caller decide to drop pre-existing
+// constraints.
TEST_F(CompareTest, DISABLED_SupportsSplitKWithIndivisibleKUsingPaddingEqual5) { constexpr absl::string_view kHloTextRef = R"( -HloModule extracted, entry_computation_layout={(f16[1,8,4,1019]{3,2,1,0}, f16[1,1019,128]{2,1,0})->f16[1,8,4,128]{3,2,1,0}} +HloModule extracted -triton_gemm_dot.7103_computation.clone { - parameter_0.499 = f16[1,8,4,1019]{3,2,1,0} parameter(0) - bitcast.7923 = f16[32,1019]{1,0} bitcast(parameter_0.499) - parameter_1.499 = f16[1,1019,128]{2,1,0} parameter(1) - bitcast.7924 = f16[1019,128]{1,0} bitcast(parameter_1.499) - dot.9350 = f16[32,128]{1,0} dot(bitcast.7923, bitcast.7924), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT bitcast.7925 = f16[1,8,4,128]{3,2,1,0} bitcast(dot.9350) +dot { + p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) + bitcast_p0 = f16[32,1019]{1,0} bitcast(p0) + p1 = f16[1,1019,128]{2,1,0} parameter(1) + bitcast_p1 = f16[1019,128]{1,0} bitcast(p1) + dot = f16[32,128]{1,0} dot(bitcast_p0, bitcast_p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT bitcast_dot = f16[1,8,4,128]{3,2,1,0} bitcast(dot) } ENTRY entry_computation { p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) p1 = f16[1,1019,128]{2,1,0} parameter(1) - ROOT triton_gemm_dot.7103 = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), kind=kCustom, - calls=triton_gemm_dot.7103_computation.clone, + ROOT dot = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=dot, backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256", "split_k":"1","num_stages":"1","num_warps":"4", @@ -3397,77 +3421,86 @@ ENTRY entry_computation { } )"; - constexpr absl::string_view kHloTextSplitK = R"( -HloModule extracted, entry_computation_layout={(f16[1,8,4,1019]{3,2,1,0}, f16[1,1019,128]{2,1,0})->f16[1,8,4,128]{3,2,1,0}} + constexpr absl::string_view kHloTextTest = R"( +HloModule extracted -triton_gemm_dot.7103_computation.clone { - parameter_0.499 = f16[1,8,4,1019]{3,2,1,0} parameter(0) - bitcast.7923 = f16[32,1019]{1,0} bitcast(parameter_0.499) - constant = f16[] constant(0) - pad = f16[32,1024]{1,0} pad(bitcast.7923, constant), padding=0_0x0_5 - bitcast = f16[32,16,64]{2,1,0} bitcast(pad) - parameter_1.499 = f16[1,1019,128]{2,1,0} parameter(1) - bitcast.7924 = f16[1019,128]{1,0} bitcast(parameter_1.499) - constant.1 = f16[] constant(0) - pad.1 = f16[1024,128]{1,0} pad(bitcast.7924, constant.1), padding=0_5x0_0 - bitcast.1 = f16[16,64,128]{2,1,0} bitcast(pad.1) - dot.1 = f16[16,32,128]{2,1,0} dot(bitcast, bitcast.1), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} - ROOT bitcast.7925.clone = f16[16,1,8,4,128]{4,3,2,1,0} bitcast(dot.1) +split_k_dot { + p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) + bitcast_p0 = f16[32,1019]{1,0} bitcast(p0) + c0 = f16[] constant(0) + pad_p0 = f16[32,1024]{1,0} pad(bitcast_p0, c0), padding=0_0x0_5 + bitcast_pad_p0 = f16[32,16,64]{2,1,0} bitcast(pad_p0) + p1 = f16[1,1019,128]{2,1,0} parameter(1) + bitcast_p1 = f16[1019,128]{1,0} bitcast(p1) + pad_p1 = f16[1024,128]{1,0} pad(bitcast_p1, c0), padding=0_5x0_0 + bitcast_pad_p1 = f16[16,64,128]{2,1,0} bitcast(pad_p1) + dot = f16[16,32,128]{2,1,0} dot(bitcast_pad_p0, bitcast_pad_p1), + lhs_batch_dims={1}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} + ROOT bitcast.7925.clone = f16[16,1,8,4,128]{4,3,2,1,0} bitcast(dot) } -triton_gemm_dot.7103.reduce_sub_computation.clone { - lhs.1 = f32[] parameter(0) - rhs.1 = f32[] parameter(1) - add.2 = f32[] add(lhs.1, rhs.1) - convert.13 = 
f16[] convert(add.2) - ROOT convert.12 = f32[] convert(convert.13) +reducer { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + add = f32[] add(lhs, rhs) + convert = f16[] convert(add) + ROOT output = f32[] convert(convert) } -fused_computation.1 { - param_0.5 = f16[16,1,8,4,128]{4,3,2,1,0} parameter(0) - convert.16 = f32[16,1,8,4,128]{4,3,2,1,0} convert(param_0.5) - constant.3 = f16[] constant(0) - convert.15 = f32[] convert(constant.3) - reduce.1 = f32[1,8,4,128]{3,2,1,0} reduce(convert.16, convert.15), dimensions={0}, to_apply=triton_gemm_dot.7103.reduce_sub_computation.clone - ROOT convert.14 = f16[1,8,4,128]{3,2,1,0} convert(reduce.1) +split_k_reducer { + p0 = f16[16,1,8,4,128]{4,3,2,1,0} parameter(0) + convert = f32[16,1,8,4,128]{4,3,2,1,0} convert(p0) + c0 = f32[] constant(0) + reduce = f32[1,8,4,128]{3,2,1,0} reduce(convert, c0), dimensions={0}, to_apply=reducer + ROOT output = f16[1,8,4,128]{3,2,1,0} convert(reduce) } ENTRY entry_computation { p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) p1 = f16[1,1019,128]{2,1,0} parameter(1) - triton_gemm_dot.7103 = f16[16,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, - calls=triton_gemm_dot.7103_computation.clone, + dot = f16[16,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=split_k_dot, backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"64","block_n":"32","block_k":"32", "split_k":"16","num_stages":"1","num_warps":"4", "num_ctas":"1"}}} - ROOT fusion.1 = f16[1,8,4,128]{3,2,1,0} fusion(triton_gemm_dot.7103), kind=kLoop, calls=fused_computation.1 + ROOT fusion = f16[1,8,4,128]{3,2,1,0} fusion(dot), kind=kLoop, + calls=split_k_reducer } )"; + TF_ASSERT_OK_AND_ASSIGN( + ModuleAndNestedFusionMetadata test_module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloTextTest)); - EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextSplitK, - ErrorSpec{/*aabs=*/4e-2, /*arel=*/2e-2}, - /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata ref_module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloTextRef)); + + EXPECT_TRUE( + RunAndCompareTwoModules(std::move(ref_module_and_metadata.module), + std::move(test_module_and_metadata.module), + ErrorSpec{/*aabs=*/4e-2, /*arel=*/2e-2}, + /*run_hlo_passes=*/false)); } -TEST_F(CompareTest, DISABLED_NonMajorMostOutputBatchWorksCorrectly) { +TEST_F(CompareTest, NonMajorMostOutputBatchWorksCorrectly) { const std::string kHloTextTest = R"( HloModule m -triton_gemm_dot.6 { - parameter_1 = f32[32,50,104]{2,1,0} parameter(1) - parameter_0 = s8[32,26,104]{2,1,0} parameter(0) - convert.22 = f32[32,26,104]{2,1,0} convert(parameter_0) - ROOT dot.127 = f32[32,50,26]{2,0,1} dot(parameter_1, convert.22), +dot { + p0 = pred[32,26,104]{2,1,0} parameter(0) + p1 = f32[32,50,104]{2,1,0} parameter(1) + convert = f32[32,26,104]{2,1,0} convert(p0) + ROOT dot = f32[32,50,26]{2,0,1} dot(p1, convert), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2} } ENTRY e { - p0 = s8[32,26,104]{2,1,0} parameter(0) + p0 = pred[32,26,104]{2,1,0} parameter(0) p1 = f32[32,50,104]{2,1,0} parameter(1) - ROOT triton_gemm_dot.6 = f32[32,50,26]{2,0,1} fusion(p0, p1), - kind=kCustom, calls=triton_gemm_dot.6, + ROOT dot = f32[32,50,26]{2,0,1} fusion(p0, p1), + kind=kCustom, calls=dot, backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":16,"block_k":32, "split_k":1,"num_stages":1,"num_warps":4, @@ -3477,39 +3510,49 @@ ENTRY e { const std::string 
kHloTextRef = R"( HloModule m -%triton_gemm_dot.127 { - %parameter_1.1 = f32[32,50,104]{2,1,0} parameter(1) - %parameter_0.1 = s8[32,26,104]{2,1,0} parameter(0) - %convert.0 = f32[32,26,104]{2,1,0} convert(%parameter_0.1) - ROOT %dot.0 = f32[32,50,26]{2,1,0} dot(%parameter_1.1, %convert.0), +dot { + p0 = pred[32,26,104]{2,1,0} parameter(0) + p1 = f32[32,50,104]{2,1,0} parameter(1) + convert = f32[32,26,104]{2,1,0} convert(p0) + ROOT dot = f32[32,50,26]{2,1,0} dot(p1, convert), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2} } -%fused_computation { - %param_0.1 = f32[32,50,26]{2,1,0} parameter(0) - %transpose.1 = f32[50,32,26]{2,1,0} transpose(%param_0.1), dimensions={1,0,2} - ROOT %bitcast.7 = f32[32,50,26]{2,0,1} bitcast(%transpose.1) +loop_fusion { + p0 = f32[32,50,26]{2,1,0} parameter(0) + transpose = f32[50,32,26]{2,1,0} transpose(p0), dimensions={1,0,2} + ROOT bitcast = f32[32,50,26]{2,0,1} bitcast(transpose) } ENTRY e { - %parameter_0 = s8[32,26,104]{2,1,0} parameter(0) - %parameter_1 = f32[32,50,104]{2,1,0} parameter(1) - %triton_gemm_dot.127 = f32[32,50,26]{2,1,0} fusion(%parameter_0, %parameter_1), - kind=kCustom, calls=%triton_gemm_dot.127, + p0 = pred[32,26,104]{2,1,0} parameter(0) + p1 = f32[32,50,104]{2,1,0} parameter(1) + dot = f32[32,50,26]{2,1,0} fusion(p0, p1), + kind=kCustom, calls=dot, backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":128,"block_k":64, "split_k":1,"num_stages":2,"num_warps":4, "num_ctas":1}}} - ROOT %fusion.1 = f32[32,50,26]{2,0,1} fusion(%triton_gemm_dot.127), kind=kLoop, calls=%fused_computation + ROOT fusion = f32[32,50,26]{2,0,1} fusion(dot), kind=kLoop, + calls=loop_fusion })"; - EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, - ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, - /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN( + ModuleAndNestedFusionMetadata test_module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloTextTest)); + + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata ref_module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloTextRef)); + + EXPECT_TRUE( + RunAndCompareTwoModules(std::move(ref_module_and_metadata.module), + std::move(test_module_and_metadata.module), + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); } -TEST_F(CompareTest, DISABLED_TritonDotFusionCanHaveOnlyRHSParameter) { +TEST_F(CompareTest, TritonDotFusionCanHaveOnlyRHSParameter) { const std::string kHloTextTest = R"( HloModule m, is_scheduled=true @@ -3545,13 +3588,19 @@ ENTRY e { backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[63,92]{1,0} get-tuple-element((f32[63,92]{1,0}, s8[0]{0}) gemm), index=0 })"; + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloTextTest)); - EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr ref_module, + ParseAndReturnVerifiedModule(kHloTextRef)); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(ref_module), + std::move(module_and_metadata.module), ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}, /*run_hlo_passes=*/false)); } -TEST_F(CompareTest, 
DISABLED_TritonDotFusionCanHaveNoParametersAtAll) { +TEST_F(CompareTest, TritonDotFusionCanHaveNoParametersAtAll) { const std::string kHloTextTest = R"( HloModule m, is_scheduled=true @@ -3588,12 +3637,19 @@ ENTRY triton_gemm___computation { ROOT get-tuple-element = f32[11,45]{1,0} get-tuple-element((f32[11,45]{1,0}, s8[0]{0}) gemm), index=0 })"; - EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloTextTest)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr ref_module, + ParseAndReturnVerifiedModule(kHloTextRef)); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(ref_module), + std::move(module_and_metadata.module), ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, /*run_hlo_passes=*/false)); } -TEST_F(CompareTest, DISABLED_TritonDotFusionCanHaveManyParameters) { +TEST_F(CompareTest, TritonDotFusionCanHaveManyParameters) { const std::string kHloTextTest = R"( HloModule m @@ -3705,12 +3761,19 @@ ENTRY e { ROOT get-tuple-element = f32[32,57]{0,1} get-tuple-element((f32[32,57]{0,1}, s8[0]{0}) gemm), index=0 })"; - EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloTextTest)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr ref_module, + ParseAndReturnVerifiedModule(kHloTextRef)); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(ref_module), + std::move(module_and_metadata.module), ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4}, /*run_hlo_passes=*/false)); } -TEST_F(CompareTest, DISABLED_PredToBF16ConversionWorks) { +TEST_F(CompareTest, PredToBF16ConversionWorks) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } @@ -3740,6 +3803,9 @@ ENTRY e { "num_ctas":"1"}}} })"; + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloTextTest)); + const std::string kHloTextRef = R"( HloModule m, is_scheduled=true @@ -3765,11 +3831,19 @@ ENTRY e { ROOT get-tuple-element = bf16[92,63]{1,0} get-tuple-element((bf16[92,63]{1,0}, s8[0]{0}) gemm), index=0 })"; - EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest, - ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr ref_module, + ParseAndReturnVerifiedModule(kHloTextRef)); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(ref_module), + std::move(module_and_metadata.module), + ErrorSpec{/*aabs=*/0, /*arel=*/0}, /*run_hlo_passes=*/false)); } +// TODO(b/393299275): symbolic tile analysis fails to derive a tile for one +// outer parameter here. However, we shouldn't be deriving this tile anyway, +// and the underlying indexing map is incorrect. This requires a fix in +// symbolic tile derivation. 
TEST_F(CompareTest, DISABLED_DifferentLayoutsAreSupportedInOneScope) {
   const std::string kHloTextTest = R"(
 triton_dot {
@@ -3803,6 +3877,9 @@ ENTRY e {
                             "num_ctas":"1"}}}
 })";
 
+  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata,
+                          GetModuleAndNestedFusionMetadata(kHloTextTest));
+
   const std::string kHloTextRef = R"(
 ENTRY e {
   p1 = f16[3,3,2,16]{1,3,2,0} parameter(1)
@@ -3823,7 +3900,11 @@ ENTRY e {
     lhs_contracting_dims={1}, rhs_contracting_dims={0}
 })";
 
-  EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest,
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr ref_module,
+                          ParseAndReturnVerifiedModule(kHloTextRef));
+
+  EXPECT_TRUE(RunAndCompareTwoModules(std::move(ref_module),
+                                      std::move(module_and_metadata.module),
                                       ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4},
                                       /*run_hlo_passes=*/false));
 }
diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc
index 993fe46973343d..03d63e57ff9b1b 100644
--- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc
+++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc
@@ -556,6 +556,10 @@ absl::StatusOr GetRealRootIndex(
     tiled_operand = std::make_unique(
         &operand.instruction(), std::move(operand_indexing_map));
   }
+
+  // TODO(b/393299275): propagation to operands is not correct when nesting,
+  // because we derive something all the way to the parameters that are
+  // outside the fusion. We should not derive anything for those operands.
   auto [operand_tiled_hlo, inserted] =
       tiled_hlo_instructions_set.Insert(std::move(tiled_operand));
   tiled_hlo_instruction->AppendOperand(operand_tiled_hlo);
 

From 84226b41af6ffc3032a4a536c6c23fe82a18c74b Mon Sep 17 00:00:00 2001
From: Alexander Lyashuk
Date: Tue, 8 Apr 2025 05:18:16 -0700
Subject: [PATCH 0361/1324] [XLA:GPU] Have a pass that splits GEMM K dimension
 for better compute utilization.

Currently, splitK is done by the autotuner; the plan is to remove it from
there so that the autotuner doesn't rewrite the graph.

The pass is off by default; use --xla_gpu_experimental_enable_split_k_rewrite
to enable it (it also makes sense to set
--xla_gpu_enable_split_k_autotuning=false together with it).

The heuristic is to be updated.
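For illustration, a hand-written sketch of the rewrite (not literal output of
the pass; the split factor of 16 below is an assumption, the heuristic derives
the factor from the dot shape and the core count):

  dot(f32[16,10240], f32[10240,128]) is rewritten to
    reshape operands -> f32[16,16,640] and f32[16,640,128]
    dot(...), lhs_batch_dims={1}, lhs_contracting_dims={2},
              rhs_batch_dims={0}, rhs_contracting_dims={1} -> f32[16,16,128]
    reduce(add) over the new batch dimension -> f32[16,128]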
PiperOrigin-RevId: 745097715
---
 third_party/xla/xla/debug_options_flags.cc    |   9 +-
 third_party/xla/xla/service/gpu/BUILD         |   1 +
 .../xla/xla/service/gpu/gpu_compiler.cc       |   8 +-
 .../xla/xla/service/gpu/transforms/BUILD      |  41 +++
 .../service/gpu/transforms/splitk_rewriter.cc | 339 ++++++++++++++++++
 .../service/gpu/transforms/splitk_rewriter.h  |  50 +++
 .../gpu/transforms/splitk_rewriter_test.cc    | 126 +++++++
 third_party/xla/xla/shape_util.cc             |   7 +
 third_party/xla/xla/shape_util.h              |   3 +
 third_party/xla/xla/util.cc                   |   4 +-
 third_party/xla/xla/xla.proto                 |   6 +-
 11 files changed, 587 insertions(+), 7 deletions(-)
 create mode 100644 third_party/xla/xla/service/gpu/transforms/splitk_rewriter.cc
 create mode 100644 third_party/xla/xla/service/gpu/transforms/splitk_rewriter.h
 create mode 100644 third_party/xla/xla/service/gpu/transforms/splitk_rewriter_test.cc

diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index 73fa48e678914a..9ea257825cc70a 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -340,6 +340,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_experimental_enable_sync_collective_combining(false);
   opts.set_xla_unsupported_crash_on_hlo_pass_silent_hlo_change(false);
   opts.set_xla_unsupported_crash_on_hlo_pass_noop_change(false);
+  opts.set_xla_gpu_experimental_enable_split_k_rewrite(false);
   return opts;
 }
@@ -1911,7 +1912,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list,
       bool_setter_for(&DebugOptions::set_xla_gpu_enable_split_k_autotuning),
       debug_options->xla_gpu_enable_split_k_autotuning(),
       "Enable split_k autotuning for triton gemms."));
-
   flag_list->push_back(tsl::Flag(
       "xla_gpu_enable_reduction_epilogue_fusion",
       bool_setter_for(
@@ -2356,6 +2356,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list,
       "If non empty will interpret this variable as a path for performance "
       "tables for matmuls. Expects `xla.gpu.DeviceHloInstructionProfiles` "
       "proto."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_experimental_enable_split_k_rewrite",
+      bool_setter_for(
+          &DebugOptions::set_xla_gpu_experimental_enable_split_k_rewrite),
+      debug_options->xla_gpu_experimental_enable_split_k_rewrite(),
+      "Enable the pass that splits GEMMs that underutilize the GPU by "
+      "splitting the K dimension using a heuristic."));
 }  // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 867c9a53949d84..373e98aff4d515 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -1547,6 +1547,7 @@ cc_library(
         "//xla/service/gpu/transforms:scatter_slice_simplifier",
         "//xla/service/gpu/transforms:softmax_rewriter_triton",
         "//xla/service/gpu/transforms:sort_rewriter",
+        "//xla/service/gpu/transforms:splitk_rewriter",
         "//xla/service/gpu/transforms:stream_attribute_annotator",
         "//xla/service/gpu/transforms:stream_attribute_async_wrapper",
         "//xla/service/gpu/transforms:topk_specializer",
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index be36336e74f25b..af4f084c2c9805 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -232,6 +232,7 @@ limitations under the License.
#include "xla/service/gpu/transforms/scatter_slice_simplifier.h" #include "xla/service/gpu/transforms/softmax_rewriter_triton.h" #include "xla/service/gpu/transforms/sort_rewriter.h" +#include "xla/service/gpu/transforms/splitk_rewriter.h" #include "xla/service/gpu/transforms/stream_attribute_annotator.h" #include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h" #include "xla/service/gpu/transforms/topk_specializer.h" @@ -676,6 +677,10 @@ absl::Status RunOptimizationPasses( pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(gpu_target_config.device_description); + pipeline.AddPass(); + pipeline.AddPass(); + HloPredicate upcaster_filter = [&](const HloInstruction* instr) { const auto* cuda_cc = std::get_if( &gpu_target_config.device_description.gpu_compute_capability()); @@ -685,9 +690,6 @@ absl::Status RunOptimizationPasses( } return !gpu::IsMatrixMultiplication(*instr); }; - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(upcaster_filter); pipeline.AddPass(upcaster_filter); diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index c5bf2734ff976e..77daf1e18153f7 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -3572,3 +3572,44 @@ xla_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "splitk_rewriter", + srcs = ["splitk_rewriter.cc"], + hdrs = ["splitk_rewriter.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "//xla/stream_executor:device_description", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "splitk_rewriter_test", + srcs = ["splitk_rewriter_test.cc"], + deps = [ + ":splitk_rewriter", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:filecheck", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:test_utils", + "//xla/tsl/platform:statusor", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/service/gpu/transforms/splitk_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/splitk_rewriter.cc new file mode 100644 index 00000000000000..5dbb8218c3e9b3 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/splitk_rewriter.cc @@ -0,0 +1,339 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/splitk_rewriter.h"
+
+#include
+#include
+#include
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal_util.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+struct DotDimensions {
+  int64_t b;  // batch dimensions
+  int64_t m;  // lhs non-contracting dimensions
+  int64_t n;  // rhs non-contracting dimensions
+  int64_t k;  // contracting dimensions
+  int64_t lhs_el_size_in_bits;
+  int64_t rhs_el_size_in_bits;
+  int64_t result_el_size_in_bits;
+};
+
+DotDimensions GetDotDimensions(const HloInstruction* dot) {
+  const Shape& lhs_shape = dot->operand(0)->shape();
+  const Shape& rhs_shape = dot->operand(1)->shape();
+  DotDimensionNumbers dnums = dot->dot_dimension_numbers();
+
+  auto product_dimensions = [](const Shape& shape,
+                               absl::Span dimensions) {
+    return absl::c_accumulate(dimensions, static_cast(1),
+                              [&](int64_t product, int64_t dimension) {
+                                return product * shape.dimensions(dimension);
+                              });
+  };
+
+  auto get_side_size = [](const HloInstruction* instr) {
+    while (instr->IsElementwise() && instr->operand_count() == 1) {
+      instr = instr->operand(0);
+    }
+    return ShapeUtil::ElementSizeInBits(instr->shape());
+  };
+
+  return DotDimensions{
+      /*.b = */ product_dimensions(lhs_shape, dnums.lhs_batch_dimensions()),
+      /*.m = */
+      product_dimensions(
+          lhs_shape, GetNonContractingDims(lhs_shape.dimensions().size(),
+                                           dnums.lhs_contracting_dimensions(),
+                                           dnums.lhs_batch_dimensions())),
+      /*.n = */
+      product_dimensions(
+          rhs_shape, GetNonContractingDims(rhs_shape.dimensions().size(),
+                                           dnums.rhs_contracting_dimensions(),
+                                           dnums.rhs_batch_dimensions())),
+      /*.k = */
+      product_dimensions(lhs_shape, dnums.lhs_contracting_dimensions()),
+      /* .lhs_el_size_in_bits = */ get_side_size(dot->operand(0)),
+      /* .rhs_el_size_in_bits = */ get_side_size(dot->operand(1)),
+      /* .result_el_size_in_bits = */
+      ShapeUtil::ElementSizeInBits(dot->shape()),
+  };
+}
+
+size_t ChooseSplitK(const DotDimensions& dims, int num_cores) {
+  // Compute the computational intensity in FLOPs per 256 bits of memory I/O
+  // (instead of FLOPs per byte to avoid the need for floating point).
+  size_t computational_intensity =
+      256 * dims.m * dims.n * dims.k /
+      (dims.m * dims.k * dims.lhs_el_size_in_bits +
+       dims.n * dims.k * dims.rhs_el_size_in_bits +
+       dims.m * dims.n * dims.result_el_size_in_bits);
+
+  // The constants below were tuned in the following way:
+  // 1. Generated random GEMM kernels.
+  //    * M, N, K and B dimensions are exponentially distributed between 1 and
+  //      200000.
+  //    * M, N and K are rounded up to a multiple of 16.
+  //    * B is set to 1 in half of the samples.
+  //    * Use combinations (that make sense) of s32, s8, s4, fp8, bf16, f32 and
+  //      f16 as op and result types.
+  // 2. Each of these kernels was run on H100 with exhaustive tiling search
+  //    enabled.
+  // 3. The best values of the constants were picked using brute force search.
+  // 4. Two functions were used as a loss function, converging to the same
+  //    result (performance of the best splitK was taken as 1.0, and
+  //    performance of the other splitK values as a fraction of it):
+  //    * Geomean.
+  //    * Mean square loss.
+  constexpr int64_t kIntensityThreshold = 240;
+  // The minimum K dimension size for the dot after splitting.
+  constexpr int64_t kMemoryBoundMinK = 768;
+  constexpr int64_t kComputeBoundMinK = 1220;
+  constexpr size_t kMaxSplitK = 128;
+  // The target number of tiles, num_cores×1.55, was tuned to be the best, but
+  // let's keep it at a more sane-looking 1.5.
+  const int64_t kTargetNumTiles = num_cores + num_cores / 2;
+  const int64_t kMTileSize = 64;
+  const int64_t kNTileSize = 128;
+
+  VLOG(3) << "ChooseSplitK(), b=" << dims.b << " m=" << dims.m
+          << " n=" << dims.n << " k=" << dims.k
+          << " lhs_sz=" << dims.lhs_el_size_in_bits
+          << " rhs_size=" << dims.rhs_el_size_in_bits
+          << " result_size=" << dims.result_el_size_in_bits
+          << " intensity=" << computational_intensity;
+
+  if (computational_intensity < kIntensityThreshold) {
+    // Assume memory throughput bound, choose as high a splitK as possible, but
+    // keep the resulting K >= kMemoryBoundMinK.
+    size_t splitk = std::min(
+        kMaxSplitK, size_t{1} << Log2Ceiling(static_cast(
+                        std::max(int64_t{1}, dims.k / kMemoryBoundMinK))));
+    VLOG(3) << "Memory throughput bound, splitK=" << splitk;
+    return splitk;
+  }
+
+  // Assume compute bound, try to fill target number of tiles.
+  const int64_t m_tiles = CeilOfRatio(dims.m, kMTileSize);
+  const int64_t n_tiles = CeilOfRatio(dims.n, kNTileSize);
+  const int64_t num_tiles = dims.b * m_tiles * n_tiles;
+  const uint64_t max_splitk = 1 << Log2Floor(static_cast(std::max(
+                                  int64_t{1}, dims.k / kComputeBoundMinK)));
+  const uint64_t desired_splitk = CeilOfRatio(kTargetNumTiles, num_tiles);
+  const size_t splitk = 1 << Log2Ceiling(std::min(max_splitk, desired_splitk));
+
+  VLOG(3) << "Compute throughput bound, m_tiles=" << m_tiles
+          << " n_tiles=" << n_tiles << " num_tiles=" << num_tiles
+          << " max_splitk=" << max_splitk
+          << " desired_splitk=" << desired_splitk << " splitk=" << splitk;
+  return splitk;
+}
+
+// Pads the given instruction with zeros along the given dimension to the given
+// size.
+HloInstruction* PadInstruction(HloInstruction* instr, int64_t dimension_idx,
+                               int64_t new_dimension_size) {
+  HloComputation* computation = instr->parent();
+  const PrimitiveType element_type = instr->shape().element_type();
+  HloInstruction* zero = computation->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+  PaddingConfig padding_config =
+      MakeNoPaddingConfig(instr->shape().dimensions().size());
+  padding_config.mutable_dimensions(dimension_idx)
+      ->set_edge_padding_low(new_dimension_size -
+                             instr->shape().dimensions(dimension_idx));
+  Shape new_shape = instr->shape();
+  new_shape.set_dimensions(dimension_idx, new_dimension_size);
+  return computation->AddInstruction(
+      HloInstruction::CreatePad(new_shape, instr, zero, padding_config));
+}
+
+// The contracting dimension index becomes the new batch (split) dimension, and
+// all dimensions after it are shifted by 1.
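+// For example (an illustrative sketch, not pass output): an f32[16,10240]
+// operand with contracting dimension index 1 and split_k=16 becomes
+// f32[16,16,640], with dimension 1 as the new split (batch) dimension.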
+HloInstruction* SplitKOperand(HloInstruction* operand,
+                              int64_t contracting_dimension_idx,
+                              int64_t split_k) {
+  // If the K dimension is not divisible by split_k, we need to pad it.
+  const int64_t src_k = operand->shape().dimensions(contracting_dimension_idx);
+  const bool needs_padding = src_k % split_k != 0;
+  if (needs_padding) {
+    const int64_t padded_k = RoundUpTo(src_k, split_k);
+    operand = PadInstruction(operand, contracting_dimension_idx, padded_k);
+  }
+  const Shape& old_shape = operand->shape();
+
+  // Copy the existing shape to keep all the non-dimension/non-layout fields of
+  // the shape (element size in bits, etc.).
+  Shape new_shape = old_shape;
+  new_shape.clear_dimensions();
+  for (int64_t i = 0; i < old_shape.dimensions().size(); ++i) {
+    const int64_t old_dim = old_shape.dimensions(i);
+    if (i == contracting_dimension_idx) {
+      new_shape.add_dimensions(split_k);
+      new_shape.add_dimensions(old_dim / split_k);
+    } else {
+      new_shape.add_dimensions(old_dim);
+    }
+  }
+
+  // Update the layout so that the physical layout is preserved (i.e. the
+  // splitK dimension goes right before the contracting dimension, and all
+  // remaining dimensions are kept).
+  if (new_shape.layout().minor_to_major_size() > 0) {
+    new_shape.mutable_layout()->clear_minor_to_major();
+    for (int64_t dim_idx : old_shape.layout().minor_to_major()) {
+      if (dim_idx >= contracting_dimension_idx) {
+        new_shape.mutable_layout()->add_minor_to_major(dim_idx + 1);
+      }
+      if (dim_idx <= contracting_dimension_idx) {
+        new_shape.mutable_layout()->add_minor_to_major(dim_idx);
+      }
+    }
+  }
+
+  // Now reshape into the "new_shape".
+  return operand->parent()->AddInstruction(
+      HloInstruction::CreateReshape(new_shape, operand));
+}
+
+// Sums/reduces the tensor along the given dimension.
+absl::StatusOr ReduceDimension(HloInstruction* instr,
+                               int64_t dimension_idx) {
+  HloComputation* computation = instr->parent();
+  const PrimitiveType element_type = instr->shape().element_type();
+  HloInstruction* zero = computation->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+  return MakeReduceHlo(instr, zero, {dimension_idx}, HloOpcode::kAdd,
+                       &instr->metadata());
+}
+
+absl::StatusOr SplitKDimensionOfDot(HloDotInstruction* src_dot,
+                                    size_t split_k) {
+  // "split_k" is the number of chunks the K dimension is split into.
+  const int64_t lhs_k_idx =
+      src_dot->dot_dimension_numbers().lhs_contracting_dimensions(0);
+  const int64_t rhs_k_idx =
+      src_dot->dot_dimension_numbers().rhs_contracting_dimensions(0);
+  // The operands' K dimensions are split into [split_k, K/split_k] (shifting
+  // right all the dimensions after it).
+  HloInstruction* lhs =
+      SplitKOperand(src_dot->mutable_operand(0), lhs_k_idx, split_k);
+  HloInstruction* rhs =
+      SplitKOperand(src_dot->mutable_operand(1), rhs_k_idx, split_k);
+
+  // Update the dot's dimension numbers accordingly (shifting right all the
+  // dimensions starting from the K dimension and inserting new batch dims).
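+  // For example (illustrative, mirroring the FileCheck expectations in
+  // splitk_rewriter_test.cc): lhs_contracting_dims={1} and
+  // rhs_contracting_dims={0} become lhs_batch_dims={1},
+  // lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}.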
+ DotDimensionNumbers new_dnums = src_dot->dot_dimension_numbers(); + auto shift_dimension = [](tsl::protobuf::RepeatedField* dims, + int64_t idx) { + absl::c_for_each(*dims, [idx](int64_t& dim) { + if (dim >= idx) dim++; + }); + }; + shift_dimension(new_dnums.mutable_lhs_contracting_dimensions(), lhs_k_idx); + shift_dimension(new_dnums.mutable_rhs_contracting_dimensions(), rhs_k_idx); + shift_dimension(new_dnums.mutable_lhs_batch_dimensions(), lhs_k_idx); + shift_dimension(new_dnums.mutable_rhs_batch_dimensions(), rhs_k_idx); + new_dnums.mutable_lhs_batch_dimensions()->Add(lhs_k_idx); + new_dnums.mutable_rhs_batch_dimensions()->Add(rhs_k_idx); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_dot, + MakeDotHlo(lhs, rhs, new_dnums, src_dot->precision_config(), + src_dot->shape().element_type(), {}, {}, + &src_dot->metadata())); + + // Reduce along the new batch dimension. + const int64_t splitk_dim_idx = new_dnums.lhs_batch_dimensions_size() - 1; + TF_ASSIGN_OR_RETURN(HloInstruction * reduced_dot, + ReduceDimension(new_dot, splitk_dim_idx)); + *reduced_dot->mutable_shape()->mutable_layout() = src_dot->shape().layout(); + return reduced_dot; +} + +class SplitkRewriterVisitor : public DfsHloRewriteVisitor { + public: + explicit SplitkRewriterVisitor(se::DeviceDescription device_description) + : device_description_(device_description) {} + + private: + absl::Status HandleDot(HloInstruction* instr) override { + HloDotInstruction* dot = DynCast(instr); + if (dot->sparse_operands()) return absl::OkStatus(); + if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() != 1 || + dot->dot_dimension_numbers().rhs_contracting_dimensions_size() != 1) { + // In theory we could support it, but it's rare and adds complexity. + return absl::OkStatus(); + } + const size_t split_k = + ChooseSplitK(GetDotDimensions(dot), device_description_.core_count()); + if (split_k == 1) return absl::OkStatus(); + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + SplitKDimensionOfDot(dot, split_k)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_dot)); + return absl::OkStatus(); + } + + se::DeviceDescription device_description_; +}; + +} // namespace + +absl::StatusOr SplitkRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + if (!module->config() + .debug_options() + .xla_gpu_experimental_enable_split_k_rewrite()) { + return false; + } + + bool changed = false; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + SplitkRewriterVisitor visitor(device_description_); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + changed |= visitor.changed(); + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/splitk_rewriter.h b/third_party/xla/xla/service/gpu/transforms/splitk_rewriter.h new file mode 100644 index 00000000000000..2a06bc9cc16b85 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/splitk_rewriter.h @@ -0,0 +1,50 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SPLITK_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SPLITK_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/pass/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites dot instructions that don't fully utilize cores but have a long K
+// dimension. For such dots, the input tensors are split along the K dimension
+// (forming a new batch dimension) and the resulting dot is reduced along the
+// new batch dimension.
+class SplitkRewriter : public HloModulePass {
+ public:
+  explicit SplitkRewriter(se::DeviceDescription device_description)
+      : device_description_(device_description) {}
+
+ private:
+  absl::string_view name() const override { return "splitk-rewriter"; }
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+  se::DeviceDescription device_description_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // XLA_SERVICE_GPU_TRANSFORMS_SPLITK_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/splitk_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/splitk_rewriter_test.cc
new file mode 100644
index 00000000000000..bc32790c468dd9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/splitk_rewriter_test.cc
@@ -0,0 +1,126 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/splitk_rewriter.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/pass/hlo_pass_interface.h"
+#include "xla/hlo/testlib/filecheck.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/test_utils.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class SplitkRewriterTest : public HloTestBase {
+ public:
+  SplitkRewriterTest()
+      : rewriter_(se::DeviceDescription(
+            ParseTextProto<se::GpuDeviceInfoProto>(
+                "core_count: 132")
+                .value())) {}
+
+ protected:
+  SplitkRewriter rewriter_;
+};
+
+TEST_F(SplitkRewriterTest, SmallNonContractingDimensionCauseSplitK) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY test {
+  lhs = f32[16,10240]{1,0} parameter(0)
+  rhs = f32[10240,128]{1,0} parameter(1)
+  ROOT dot = f32[16,128]{1,0} dot(lhs, rhs),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  module->mutable_config()
+      .mutable_debug_options()
+      .set_xla_gpu_experimental_enable_split_k_rewrite(true);
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          rewriter_.HloModulePass::Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_TRUE(RunFileCheck(module->ToString(), R"(
+CHECK: dot({{.*}}), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+CHECK: ROOT {{.*}} = f32[16,128]{1,0} reduce
+  )")
+                  .value_or(false));
+}
+
+TEST_F(SplitkRewriterTest, PaddingIsInserted) {
+  // Huge K dimension to trigger 128 which is the largest possible splitK
+  // (hoping to make the test less fragile as the heuristic changes).
+  const char* hlo_string = R"(
+  HloModule module
+
+  ENTRY test {
+    lhs = f32[16,102401]{1,0} parameter(0)
+    rhs = f32[102401,128]{1,0} parameter(1)
+    ROOT dot = f32[16,128]{1,0} dot(lhs, rhs),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  module->mutable_config()
+      .mutable_debug_options()
+      .set_xla_gpu_experimental_enable_split_k_rewrite(true);
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          rewriter_.HloModulePass::Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_TRUE(RunFileCheck(module->ToString(), R"(
+CHECK: f32[16,102528]{1,0} pad({{.*}}), padding=0_0x127_0
+  )")
+                  .value_or(false));
+}
+
+TEST_F(SplitkRewriterTest, NoSplitKIfEnoughWork) {
+  // Huge K dimension to trigger 128 which is the largest possible splitK
+  // (hoping to make the test less fragile as the heuristic changes).
+  const char* hlo_string = R"(
+  HloModule module
+
+  ENTRY test {
+    lhs = f32[1024,10240]{1,0} parameter(0)
+    rhs = f32[10240,2048]{1,0} parameter(1)
+    ROOT dot = f32[1024,2048]{1,0} dot(lhs, rhs),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  module->mutable_config()
+      .mutable_debug_options()
+      .set_xla_gpu_experimental_enable_split_k_rewrite(true);
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          rewriter_.HloModulePass::Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc
index 42508bdb5f32ea..944b7c44488d8a 100644
--- a/third_party/xla/xla/shape_util.cc
+++ b/third_party/xla/xla/shape_util.cc
@@ -2135,6 +2135,13 @@ std::optional<std::vector<int64_t>> ShapeUtil::ByteStrides(
   return strides;
 }
 
+/*static*/ int64_t ShapeUtil::ElementSizeInBits(const Shape& shape) {
+  if (shape.has_layout() && shape.layout().element_size_in_bits() != 0) {
+    return shape.layout().element_size_in_bits();
+  }
+  return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) * CHAR_BIT;
+}
+
 /*static*/ int64_t ShapeUtil::ArraySize(const Shape& shape) {
   CHECK(LayoutUtil::IsDenseArray(shape));
   if (shape.layout().tiles().empty()) {
diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h
index ea04f3b6bc95c7..f39fdad114cd50 100644
--- a/third_party/xla/xla/shape_util.h
+++ b/third_party/xla/xla/shape_util.h
@@ -1038,6 +1038,9 @@ class ShapeUtil {
   static std::optional<std::vector<int64_t>> ByteStrides(
       const Shape& shape);
 
+  // Returns the size of the tensor element in bits.
+  static int64_t ElementSizeInBits(const Shape& shape);
+
   // Returns the array size in bytes (layout/tiling required), all paddings are
   // included.
   static int64_t ArraySize(const Shape& shape);
diff --git a/third_party/xla/xla/util.cc b/third_party/xla/xla/util.cc
index c23d78f1535c8a..8720698dcb6b11 100644
--- a/third_party/xla/xla/util.cc
+++ b/third_party/xla/xla/util.cc
@@ -350,8 +350,8 @@ void LogLines(absl::LogSeverity sev, absl::string_view text, const char* fname,
 }
 
 int64_t Product(absl::Span<const int64_t> xs) {
-  return std::accumulate(xs.begin(), xs.end(), static_cast<int64_t>(1),
-                         std::multiplies<int64_t>());
+  return absl::c_accumulate(xs, static_cast<int64_t>(1),
+                            std::multiplies<int64_t>());
 }
 
 std::vector<int64_t> ElemwiseProduct(absl::Span<const int64_t> a,
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index 924701a09deda4..8ab86b274a21f4 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -552,6 +552,10 @@ message DebugOptions {
   // Pre-existing block-level fusions are left unmodified.
   bool xla_gpu_experimental_enable_fusion_block_level_rewriter = 334;
 
+  // Enable the pass that splits GEMMs that underutilize the GPU by splitting
+  // the K dimension, using a heuristic.
+  bool xla_gpu_experimental_enable_split_k_rewrite = 386;
+
   // Enable fusion for the subchannel dequantisation sequences like
   // [x,z]param -> [x,y,z]broadcast -> [x*y,z]bitcast -> multiply -> dot.
   // Performance can be worse, because some block sizes / split-k > 1 is
@@ -1198,7 +1202,7 @@ message DebugOptions {
 
   // Note: when adding a new flag, please add it to one of the hardware-specific
   // or hardware-agnostic sections at the top of this proto message.
-  // Next id: 386
+  // Next id: 387
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
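[Editor's note — illustration, not part of the patch series] The split-k rewrite above is just a reshape plus index bookkeeping: pad K to a multiple of split_k, split it into [split_k, K/split_k], and shift every later dimension index right by one. A minimal, self-contained C++ sketch of that bookkeeping (SplitKShape is an illustrative helper, not an XLA function):

#include <cstdint>
#include <vector>

// Mirrors what SplitKOperand does to a shape: pad K up to a multiple of
// split_k, then replace the K dimension with [split_k, K/split_k]. Every
// pre-existing dimension index >= k_idx shifts right by one.
std::vector<int64_t> SplitKShape(std::vector<int64_t> dims, int64_t k_idx,
                                 int64_t split_k) {
  const int64_t padded_k = (dims[k_idx] + split_k - 1) / split_k * split_k;
  dims[k_idx] = padded_k / split_k;
  dims.insert(dims.begin() + k_idx, split_k);
  return dims;
}

// Example matching the PaddingIsInserted test above: a f32[16,102401] LHS with
// split_k = 128 pads K by 127 (to 102528) and becomes [16, 128, 801].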
From 9e138b8194d6bfeb5b0e12022953f103441d6d2b Mon Sep 17 00:00:00 2001 From: Mikhail Goncharov Date: Tue, 8 Apr 2025 07:31:02 -0700 Subject: [PATCH 0362/1324] [XLA:GPU] update and enable some of the legacy matmul emitter tests for the generic dot emitter PiperOrigin-RevId: 745136615 --- .../fusion_emitter_device_legacy_port_test.cc | 297 ++++++++---------- 1 file changed, 135 insertions(+), 162 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index e9b44186b362a5..3e0dd108960baf 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -394,7 +394,7 @@ CHECK-NOT: mma )"); } -TEST_F(TritonGemmTest, DISABLED_DebugOptionsArePropagated) { +TEST_F(TritonGemmTest, DebugOptionsArePropagated) { constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[30,30] parameter(0) @@ -423,8 +423,8 @@ ENTRY e { EXPECT_EQ(paths.size(), 1); } -TEST_F(TritonGemmTest, DISABLED_DotWithPredFromCompareProducesCorrectResult) { - const std::string hlo_text = R"( +TEST_F(TritonGemmTest, DotWithPredFromCompareProducesCorrectResult) { + constexpr absl::string_view kHloText = R"( triton_dot { parameter_0 = s32[4,128]{1,0} parameter(0) broadcast.255 = s32[4,128,64]{2,1,0} broadcast(parameter_0), dimensions={0,1} @@ -442,11 +442,14 @@ ENTRY main { p2 = bf16[64,256]{0,1} parameter(2) ROOT gemm_fusion_dot.0 = bf16[512,256]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=triton_dot, backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"64","block_n":"128","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4","num_ctas":"1"}}} })"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloText)); + EXPECT_TRUE( + RunAndCompareNoHloPasses(module_and_metadata.module->ToString(), + ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } -TEST_F(TritonGemmTest, DISABLED_UseTensorCoresForF32OnAmpere) { +TEST_F(TritonGemmTest, UseTensorCoresForF32OnAmpere) { constexpr absl::string_view kHloText = R"( triton_gemm_r { parameter_0 = f16[80,15]{1,0} parameter(0) @@ -466,10 +469,9 @@ ENTRY e { "split_k":1,"num_stages":1,"num_warps":2, "num_ctas":1}}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, - ParseAndReturnVerifiedModule(kHloText)); - - CompileAndOptionallyVerifyPtx(std::move(verified_module), + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloText)); + CompileAndOptionallyVerifyPtx(std::move(module_and_metadata.module), R"( CHECK: mma )"); @@ -585,7 +587,7 @@ ENTRY e { } TEST_F(TritonGemmTest, DISABLED_MultipleDims) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -596,7 +598,7 @@ ENTRY e { lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK-NEXT: parameter ; CHECK-NEXT: parameter @@ -605,11 +607,11 @@ ENTRY e { ; CHECK-PTX-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } 
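[Editor's note — illustration, not part of the patch series] A recurring mechanical change in this commit is the migration from `const std::string hlo_text` to `constexpr absl::string_view kHloText`. A hedged sketch of why this is preferable (standalone, illustrative names):

#include "absl/strings/string_view.h"

// A constexpr string_view binds to the raw-string literal at compile time:
// no std::string heap allocation and no static initializer runs per test.
constexpr absl::string_view kExampleHloText = R"(
HloModule example

ENTRY e {
  p0 = f32[4,4] parameter(0)
  ROOT n = f32[4,4] negate(p0)
})";
static_assert(!kExampleHloText.empty());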
-TEST_F(TritonGemmTest, DISABLED_PredWithBF16DotProducesCorrectResult) {
-  const std::string hlo_text = R"(
+TEST_F(TritonGemmTest, PredWithBF16DotProducesCorrectResult) {
+  constexpr absl::string_view kHloText = R"(
 triton_dot {
   p0 = pred[8,640]{1,0} parameter(0)
   cvt = bf16[8,640]{1,0} convert(pred[8,640]{1,0} p0)
@@ -627,11 +629,15 @@ ENTRY e {
     "num_ctas":1}}}
 })";
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata,
+                          GetModuleAndNestedFusionMetadata(kHloText));
+  EXPECT_TRUE(
+      RunAndCompareNoHloPasses(module_and_metadata.module->ToString(),
+                               ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
 TEST_F(TritonGemmTest, DISABLED_NoPadding) {
-  const char* hlo_text = R"(
+  constexpr absl::string_view kHloText = R"(
 HloModule t
 
 ENTRY e {
@@ -642,7 +648,7 @@ ENTRY e {
     lhs_contracting_dims={1}, rhs_contracting_dims={0}
 })";
 
-  MatchOptimizedHlo(hlo_text, R"(
+  MatchOptimizedHlo(kHloText, R"(
 ; CHECK: ENTRY
 ; CHECK-NEXT: parameter
 ; CHECK-NEXT: parameter
 ; CHECK-NEXT: fusion
 ; CHECK-SAME: kind=kCustom
 ; CHECK-PTX-SAME: "block_m":
 ; CHECK-NOT: pad
 ; CHECK-NOT: slice
 )");
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
 TEST_F(TritonGemmTest, DISABLED_S8xS8) {
-  const std::string hlo_text = R"(
+  constexpr absl::string_view kHloText = R"(
 HloModule t
 
 ENTRY f {
@@ -667,11 +673,11 @@ ENTRY f {
   ROOT z = s32[1024,1024]{1,0} dot(x, y),
     lhs_contracting_dims={1}, rhs_contracting_dims={0}
 })";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
 TEST_F(TritonGemmTest, DISABLED_SplitLhsNoncontractingTransposeRhs) {
-  const std::string hlo_text = R"(
+  constexpr absl::string_view kHloText = R"(
 HloModule t
 
 ENTRY e {
@@ -683,7 +689,7 @@ ENTRY e {
     lhs_contracting_dims={1}, rhs_contracting_dims={2}
 })";
 
-  MatchOptimizedHlo(hlo_text, R"(
+  MatchOptimizedHlo(kHloText, R"(
 ; CHECK: ENTRY
 ; CHECK-NEXT: parameter
 ; CHECK-NEXT: parameter
 ; CHECK-NEXT: fusion
 ; CHECK-SAME: kind=kCustom
 ; CHECK-PTX-SAME: "block_m":
 )");
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/0, /*arel=*/0}));
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0}));
 }
 
 TEST_F(TritonGemmTest, DISABLED_SplitLhsNoncontracting) {
-  const std::string hlo_text = R"(
-HloModule t
+  constexpr absl::string_view kHloText = R"(
+HloModule t
 
 ENTRY e {
   p0 = f32[72,72] parameter(0)
@@ -710,7 +716,7 @@ ENTRY e {
     lhs_contracting_dims={1}, rhs_contracting_dims={0}
 })";
 
-  MatchOptimizedHlo(hlo_text, R"(
+  MatchOptimizedHlo(kHloText, R"(
 ; CHECK: ENTRY
 ; CHECK-NEXT: parameter
 ; CHECK-NEXT: parameter
 ; CHECK-NEXT: fusion
 ; CHECK-SAME: kind=kCustom
 ; CHECK-PTX-SAME: "block_m":
 )");
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
-TEST_F(TritonGemmTest, DISABLED_SplitAndTransposeLhsExecutesCorrectly) {
+TEST_F(TritonGemmTest, SplitAndTransposeLhsExecutesCorrectly) {
   constexpr absl::string_view kHloText = R"(
 HloModule m
 
@@ -743,12 +749,12 @@ ENTRY e {
 ; CHECK-NEXT: ROOT
 ; CHECK-SAME: fusion
 ; CHECK-SAME: kind=kCustom
+; CHECK-SAME: backend_config={{.*}}"kind":"__triton_nested_gemm_fusion"
 )");
-
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
-TEST_F(TritonGemmTest, DISABLED_NondefaultOperandLayoutIsSupported) {
+TEST_F(TritonGemmTest, 
NondefaultOperandLayoutIsSupported) { // TODO(bchetioui): reenable when b/285866137 is fixed. #ifndef NDEBUG GTEST_SKIP() << "This test times out when -UNDEBUG is set."; @@ -775,7 +781,7 @@ ENTRY r { } TEST_F(TritonGemmTest, DISABLED_DoNotFuseSplitRhsContractingTranspose) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -788,7 +794,7 @@ ENTRY e { lhs_contracting_dims={1}, rhs_contracting_dims={1} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK: transpose ; CHECK: fusion @@ -796,11 +802,11 @@ ENTRY e { ; CHECK-PTX-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTest, DISABLED_DoNotFuseSplitLhsContractingTranspose) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -813,7 +819,7 @@ ENTRY e { lhs_contracting_dims={1}, rhs_contracting_dims={1} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK: transpose ; CHECK: fusion @@ -821,11 +827,11 @@ ENTRY e { ; CHECK-PTX-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTest, DISABLED_BatchF32F16) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -837,7 +843,7 @@ ENTRY e { lhs_batch_dims={0}, rhs_batch_dims={0} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK-NEXT: parameter ; CHECK-NEXT: parameter @@ -846,11 +852,11 @@ ENTRY e { ; CHECK-PTX-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); } TEST_F(TritonGemmTest, DISABLED_NonMajorMostInputBatchWorksCorrectly) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -862,7 +868,7 @@ ENTRY e { lhs_batch_dims={1}, rhs_batch_dims={1} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK-NEXT: parameter ; CHECK-NEXT: parameter @@ -871,11 +877,11 @@ ENTRY e { ; CHECK-PTX-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTest, DISABLED_BatchTransposeF32F16) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -888,7 +894,7 @@ ENTRY e { lhs_batch_dims={0}, rhs_batch_dims={0} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK-NEXT: parameter ; CHECK-NEXT: parameter @@ -897,11 +903,11 @@ ENTRY e { ; CHECK-PTX-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); } TEST_F(TritonGemmTest, DISABLED_DoNotFuseArbitraryReshape) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -914,7 +920,7 @@ ENTRY e { rhs_batch_dims={0}, rhs_contracting_dims={1} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK: 
f32[5,3,4]{2,1,0} bitcast ; CHECK: fusion @@ -922,10 +928,10 @@ ENTRY e { ; CHECK-PTX-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); } -TEST_F(TritonGemmTest, DISABLED_MultipleBatchRequireSeparateTranspose) { +TEST_F(TritonGemmTest, MultipleBatchRequireSeparateTranspose) { constexpr absl::string_view kHloText = R"( HloModule m @@ -943,6 +949,7 @@ ENTRY e { ; CHECK: transpose( ; CHECK: bitcast( ; CHECK: kCustom +; CHECK-SAME: backend_config={{.*}}"kind":"__triton_nested_gemm_fusion" )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); @@ -966,6 +973,7 @@ ENTRY e { ; CHECK-NOT: concatenate ; CHECK: fusion ; CHECK-SAME: kind=kCustom +; CHECK-SAME: backend_config={{.*}}"kind":"__triton_nested_gemm_fusion" )"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -991,6 +999,7 @@ ENTRY e { ; CHECK-NOT: concatenate ; CHECK: fusion ; CHECK-SAME: kind=kCustom +; CHECK-SAME: backend_config={{.*}}"kind":"__triton_nested_gemm_fusion" )"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -1200,9 +1209,8 @@ ENTRY e { )"); } -TEST_F( - TritonGemmTest, - DISABLED_BroadcastsOfTriviallySizedNonContractingDimensionsAreSupported) { +TEST_F(TritonGemmTest, + BroadcastsOfTriviallySizedNonContractingDimensionsAreSupported) { EXPECT_TRUE(RunAndCompare(R"( f { p0 = f32[64,6464] parameter(0) @@ -1221,7 +1229,9 @@ e { p1 = f32[16,6464] parameter(1) p2 = f32[64] parameter(2) f = f32[1,16,64] fusion(p0, p1, p2), - kind=kCustom, calls=f, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + kind=kCustom, calls=f, backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config": {"block_m":"16","block_n":"16","block_k":"64","split_k":"1", + "num_stages":"1","num_warps":"4","num_ctas":"1"}}} })", ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } @@ -1249,7 +1259,7 @@ e { } TEST_F(TritonGemmTest, DISABLED_DoF32F32) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -1259,13 +1269,13 @@ ENTRY e { lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-PTX-SAME: block_m )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTest, DISABLED_DoAddConstantToScalarAndBroadcastThat) { @@ -1273,7 +1283,7 @@ TEST_F(TritonGemmTest, DISABLED_DoAddConstantToScalarAndBroadcastThat) { GpuComputeCapability())) { GTEST_SKIP() << "Not using autotuner on ROCM yet."; } - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -1286,15 +1296,15 @@ ENTRY e { lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion({{.*}} kind=kCustom, {{.*}}block_m )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTest, DISABLED_SameInput) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -1305,14 +1315,14 @@ ENTRY e { })"; // The fusion has separate parameters for each scope. 
- MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK: %[[p0:.*]] = pred[5,5]{1,0} parameter(0) ; CHECK: fusion(%[[p0]], %[[p0]]), kind=kCustom ; CHECK-PTX-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } TEST_F(TritonGemmTest, DISABLED_DynamicSliceIsSupportedInLhsEndToEnd) { @@ -1590,7 +1600,7 @@ TEST_F(TritonGemmTest, if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -1601,7 +1611,7 @@ ENTRY e { ROOT d = bf16[2,384,20] dot(concat, z), lhs_contracting_dims={2}, rhs_contracting_dims={0} })"; - MatchOptimizedHlo(hlo_text, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK: concatenate ; CHECK: ROOT @@ -1610,7 +1620,7 @@ ENTRY e { ; CHECK-SAME: "block_m" )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTest, DISABLED_BroadcastOfScalarWorksCorrectly) { @@ -1665,7 +1675,7 @@ ENTRY e { .status()); } -TEST_F(TritonGemmTest, DISABLED_BinaryOperationWithSmallInputsIsFused) { +TEST_F(TritonGemmTest, BinaryOperationWithSmallInputsIsFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -1687,11 +1697,9 @@ ENTRY e { module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2})); } -TEST_F(TritonGemmTest, DISABLED_BinaryOperationWithLargeInputsIsNotFused) { +TEST_F(TritonGemmTest, BinaryOperationWithLargeInputsIsNotFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -1716,13 +1724,11 @@ ENTRY e { ; CHECK: ENTRY ; CHECK: kLoop ; CHECK: kCustom +; CHECK-SAME: backend_config={{.*}}"kind":"__triton_nested_gemm_fusion" )"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, - DISABLED_ParametersWithDifferentLayoutsAreSupportedInOneScope) { +TEST_F(TritonGemmTest, ParametersWithDifferentLayoutsAreSupportedInOneScope) { constexpr absl::string_view kHloText = R"( ENTRY e { p0 = s8[5,3] parameter(0) @@ -1741,11 +1747,9 @@ ENTRY e { module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } -TEST_F(TritonGemmTest, DISABLED_BinaryOperationOnLargeParametersIsFused) { +TEST_F(TritonGemmTest, BinaryOperationOnLargeParametersIsFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -1766,11 +1770,12 @@ ENTRY e { module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_LinkingLibdeviceTwiceWorks) { +TEST_F(TritonGemmTest, LinkingLibdeviceTwiceWorks) { + // TODO(b/393299275): This test looks weird. It's testing the whole + // optimization pipeline end-to-end to check that linking libdevice twice + // works? 
rewrite this to just use post-optimization HLO constexpr absl::string_view kHloText = R"( ENTRY e { p0 = s8[7,3] parameter(0) @@ -1801,7 +1806,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); } -TEST_F(TritonGemmTest, DISABLED_BroadcastOfScalarParameterIsFused) { +TEST_F(TritonGemmTest, BroadcastOfScalarParameterIsFused) { constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[64,256] parameter(0) @@ -1822,7 +1827,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_BroadcastOfScalarConstantIsFused) { +TEST_F(TritonGemmTest, BroadcastOfScalarConstantIsFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -1841,8 +1846,6 @@ ENTRY e { module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); } TEST_F(TritonGemmTest, DISABLED_DoubleBroadcastOfScalarConstantIsHandled) { @@ -1893,7 +1896,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } -TEST_F(TritonGemmTest, DISABLED_AlwaysFuseScalarConstantAtBroadcastInput) { +TEST_F(TritonGemmTest, AlwaysFuseScalarConstantAtBroadcastInput) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } @@ -1918,12 +1921,11 @@ ENTRY e { ; CHECK: ROOT ; CHECK: ENTRY ; CHECK: kCustom +; CHECK-SAME: backend_config={{.*}}"kind":"__triton_nested_gemm_fusion" )"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_BroadcastOfVectorParameterIsFused) { +TEST_F(TritonGemmTest, BroadcastOfVectorParameterIsFused) { constexpr absl::string_view kHloText = R"( triton_dot { p0 = f16[75] parameter(0) @@ -1945,8 +1947,6 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloText)); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); } TEST_F(TritonGemmTest, DISABLED_FuseConcatenation) { @@ -2170,7 +2170,7 @@ ENTRY e { /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_SineOutputIsNotFused) { +TEST_F(TritonGemmTest, SineOutputIsNotFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -2189,11 +2189,9 @@ ENTRY e { GmockMatch(m::Sin( m::Fusion(m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom)))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2})); } -TEST_F(TritonGemmTest, DISABLED_SliceInputIsFused) { +TEST_F(TritonGemmTest, SliceInputIsFused) { constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[97,121] parameter(0) @@ -2210,8 +2208,6 @@ ENTRY e { module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTest, DISABLED_SliceInputWithReshapeIsFused) { @@ -2236,7 +2232,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_NestedSlicingWorks) { +TEST_F(TritonGemmTest, NestedSlicingWorks) { constexpr absl::string_view kHloText = R"( ENTRY e { p1 = f32[6,24] parameter(1) @@ -2254,11 +2250,9 @@ ENTRY e { module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) 
                     .WithFusionKind(HloInstruction::FusionKind::kCustom)));
-
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-3}));
 }
 
-TEST_F(TritonGemmTest, DISABLED_SlicedBatchDimensionIsSupported) {
+TEST_F(TritonGemmTest, SlicedBatchDimensionIsSupported) {
   constexpr absl::string_view kHloText = R"(
 ENTRY e {
   p0 = f16[3,3,256] parameter(0)
@@ -2278,8 +2272,6 @@ ENTRY e {
       module->entry_computation()->root_instruction(),
       GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
                      .WithFusionKind(HloInstruction::FusionKind::kCustom)));
-
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
 TEST_F(TritonGemmTestWithSplitK,
@@ -2375,7 +2367,8 @@ ENTRY e {
   EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2}));
 }
 
-TEST_F(TritonGemmTest, DISABLED_OutputFusionExecutesCorrectly) {
+TEST_F(TritonGemmTest, OutputFusionExecutesCorrectly) {
+  // TODO(b/393299275): is this test useful?
   if (!SupportsBF16(GpuComputeCapability())) {
     GTEST_SKIP() << "BF16 not supported.";
   }
@@ -2473,7 +2466,7 @@ ENTRY e {
   EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
-TEST_F(TritonGemmTest, DISABLED_SupportPredParametersUsedInExpressions) {
+TEST_F(TritonGemmTest, SupportPredParametersUsedInExpressions) {
   constexpr absl::string_view kHloText = R"(
 ENTRY e {
   p = pred[2,2]{1,0} parameter(0)
@@ -2496,31 +2489,11 @@ ENTRY e {
       GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
                            m::Parameter())
                      .WithFusionKind(HloInstruction::FusionKind::kCustom)));
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-3}));
-}
-
-TEST_F(TritonGemmTest, DISABLED_Naming) {
-  const char* hlo_text = R"(
-HloModule t
-
-ENTRY e {
-  p0 = f16[15,19] parameter(0)
-  p1 = s8[19,17] parameter(1)
-  cp1 = f16[19,17] convert(p1)
-  ROOT r = f16[15,17] dot(p0, cp1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK: %gemm_fusion_r_computation (
-; CHECK: ROOT %gemm_fusion_r
-; CHECK-SAME: kCustom
-)");
 }
 
 TEST_F(TritonGemmTest,
        DISABLED_LowerDotWithLhsWithoutNonContractingDimThroughTriton) {
-  const std::string hlo_text = R"(
+  constexpr absl::string_view kHloText = R"(
 HloModule t
 
 ENTRY e {
@@ -2531,18 +2504,18 @@ ENTRY e {
 })";
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          GetOptimizedModule(hlo_text));
+                          GetOptimizedModule(kHloText));
   EXPECT_THAT(
       module->entry_computation()->root_instruction(),
       GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
                      .WithFusionKind(HloInstruction::FusionKind::kCustom)));
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
 TEST_F(TritonGemmTest,
        DISABLED_LowerDotWithRhsWithoutNonContractingDimThroughTriton) {
-  const std::string hlo_text = R"(
+  constexpr absl::string_view kHloText = R"(
 HloModule t
 
 ENTRY e {
@@ -2553,13 +2526,13 @@ ENTRY e {
 })";
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          GetOptimizedModule(hlo_text));
+                          GetOptimizedModule(kHloText));
   EXPECT_THAT(
       module->entry_computation()->root_instruction(),
       GmockMatch(m::Fusion(m::Parameter(), m::Parameter())
                      .WithFusionKind(HloInstruction::FusionKind::kCustom)));
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
 // This group of tests compares GPU results of dots already rewritten
 using CompareTest = TritonGemmTest;
 
TEST_F(CompareTest, DISABLED_DifferentTilingsProduceSameResult) { - const char* hlo_text_ref = R"( + constexpr absl::string_view hlo_text_ref = R"( HloModule t triton_dot { @@ -2589,7 +2562,7 @@ ENTRY e { "num_ctas":1}}} })"; - const char* hlo_text_triton = R"( + constexpr absl::string_view hlo_text_triton = R"( HloModule t triton_dot { @@ -2616,7 +2589,7 @@ ENTRY e { } TEST_F(CompareTest, DISABLED_F16) { - const char* hlo_text_ref = R"( + constexpr absl::string_view hlo_text_ref = R"( HloModule r ENTRY e { @@ -2629,7 +2602,7 @@ ENTRY e { } )"; - const char* hlo_text_triton = R"( + constexpr absl::string_view hlo_text_triton = R"( HloModule t triton_dot { @@ -2656,7 +2629,7 @@ ENTRY e { } TEST_F(CompareTest, DISABLED_F32) { - const char* hlo_text_ref = R"( + constexpr absl::string_view hlo_text_ref = R"( HloModule r ENTRY e { @@ -2669,7 +2642,7 @@ ENTRY e { } )"; - const char* hlo_text_triton = R"( + constexpr absl::string_view hlo_text_triton = R"( HloModule t triton_dot { @@ -2696,7 +2669,7 @@ ENTRY e { } TEST_F(CompareTest, DISABLED_F32WithTrivialNonContractingDimension) { - const char* hlo_text_ref = R"( + constexpr absl::string_view hlo_text_ref = R"( HloModule r ENTRY e { @@ -2709,7 +2682,7 @@ ENTRY e { } )"; - const char* hlo_text_triton = R"( + constexpr absl::string_view hlo_text_triton = R"( HloModule t triton_dot { @@ -2739,7 +2712,7 @@ TEST_F(CompareTest, DISABLED_BF16TransposedLHS) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } - const char* hlo_text_ref = R"( + constexpr absl::string_view hlo_text_ref = R"( HloModule r ENTRY e { @@ -2752,7 +2725,7 @@ ENTRY e { } )"; - const char* hlo_text_triton = R"( + constexpr absl::string_view hlo_text_triton = R"( HloModule t triton_dot { @@ -2865,7 +2838,7 @@ ENTRY e { } TEST_F(CompareTest, DISABLED_F16TransposedRHS) { - const char* hlo_text_ref = R"( + constexpr absl::string_view hlo_text_ref = R"( HloModule r ENTRY e { @@ -2878,7 +2851,7 @@ ENTRY e { } )"; - const char* hlo_text_triton = R"( + constexpr absl::string_view hlo_text_triton = R"( HloModule t triton_dot { @@ -2905,7 +2878,7 @@ ENTRY e { } TEST_F(CompareTest, DISABLED_F32TransposedBoth) { - const char* hlo_text_ref = R"( + constexpr absl::string_view hlo_text_ref = R"( HloModule r ENTRY e { @@ -2918,7 +2891,7 @@ ENTRY e { } )"; - const char* hlo_text_triton = R"( + constexpr absl::string_view hlo_text_triton = R"( HloModule t triton_dot { @@ -2948,7 +2921,7 @@ TEST_F(CompareTest, DISABLED_S8BF16) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } - const char* hlo_text_ref = R"( + constexpr absl::string_view hlo_text_ref = R"( HloModule r fused_computation { @@ -2967,7 +2940,7 @@ ENTRY e { } )"; - const char* hlo_text_triton = R"( + constexpr absl::string_view hlo_text_triton = R"( HloModule t triton_dot { @@ -2998,7 +2971,7 @@ TEST_F(CompareTest, DISABLED_SplitK) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string hlo_text_ref = R"( + constexpr absl::string_view hlo_text_ref = R"( HloModule t, is_scheduled=true triton_gemm_r { @@ -3021,7 +2994,7 @@ ENTRY e { "num_ctas":1}}} })"; - const std::string hlo_text_splitk = R"( + constexpr absl::string_view hlo_text_splitk = R"( HloModule t, is_scheduled=true triton_gemm_r { @@ -4033,7 +4006,7 @@ ENTRY e { // TODO(b/393299275): this test uncovers a bug in hoisting bitcasts through // broadcasts (seems to generate a type mismatch). 
TEST_F(TritonTest, DISABLED_UseTF32For8BitOrLessWithF32) { - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t triton_dot { @@ -4062,7 +4035,7 @@ ENTRY e { "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, - GetModuleAndNestedFusionMetadata(hlo_text)); + GetModuleAndNestedFusionMetadata(kHloText)); TF_ASSERT_OK( CreateTritonIrAndFileCheck(*module_and_metadata.computation, module_and_metadata.block_level_parameters, @@ -4072,7 +4045,7 @@ CHECK: inputPrecision = tf32 )")); EXPECT_TRUE(RunAndCompareNoHloPasses( - hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } // TODO(b/393299275): this test requires us to allow actual mixed type GEMMs @@ -4082,7 +4055,7 @@ TEST_F(TritonTest, DISABLED_Fp8LoweringIsSupportedPostHopper) { if (!GetCudaComputeCapability().IsAtLeastHopper()) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; } - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t triton_dot { @@ -4105,7 +4078,7 @@ ENTRY main { })"; TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, - GetModuleAndNestedFusionMetadata(hlo_text)); + GetModuleAndNestedFusionMetadata(kHloText)); TF_ASSERT_OK( CreateTritonIrAndFileCheck(*module_and_metadata.computation, module_and_metadata.block_level_parameters, @@ -4113,7 +4086,7 @@ ENTRY main { CHECK: tt.dot {{.*}}{maxNumImpreciseAcc = 2147483647 : i32} : tensor<128x64xf8E4M3FN> * tensor<64x32xf8E4M3FN> -> tensor<128x32xf32> )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); } @@ -4125,7 +4098,7 @@ TEST_F(TritonTest, DISABLED_BF16ToFP8EndToEnd) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; } - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t triton_dot { @@ -4147,7 +4120,7 @@ ENTRY main { "num_stages":"1","num_warps":"4","num_ctas":"1"}}} })"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); } // TODO(b/393299275): this test requires us to allow actual mixed type GEMMs @@ -4158,7 +4131,7 @@ TEST_F(TritonTest, DISABLED_FP8ToFP8EndToEnd) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; } - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( HloModule t triton_dot { @@ -4182,7 +4155,7 @@ ENTRY main { ASSERT_TRUE( GetDebugOptionsForTest() .xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms()); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); } // Test PreventMmaV3LoopUnrolling pass in order to keep compile time low. 
@@ -4193,7 +4166,7 @@ TEST_F(TritonGemmTest, TestPreventMMAV3LoopUnrolling) { if (GetCudaComputeCapability().major != se::CudaComputeCapability::kHopper) { GTEST_SKIP() << "wgmma instruction is only available on Hopper"; } - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( gemm_fusion_dot { p0 = f16[64,1024]{1,0} parameter(0) p1 = f16[1024,32,32]{2,1,0} parameter(1) @@ -4214,7 +4187,7 @@ ENTRY e { "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, - GetModuleAndNestedFusionMetadata(hlo_text)); + GetModuleAndNestedFusionMetadata(kHloText)); CompileAndOptionallyVerifyPtx(std::move(module_and_metadata.module), R"( R"( @@ -4232,7 +4205,7 @@ TEST_F(TritonGemmTest, WgmmaIsUsedForMemBoundShape) { if (GetCudaComputeCapability().major != se::CudaComputeCapability::kHopper) { GTEST_SKIP() << "wgmma instruction is only available on Hopper"; } - const std::string hlo_text = R"( + constexpr absl::string_view kHloText = R"( gemm_fusion_dot { p0 = s8[128,128]{1,0} parameter(0) p1 = bf16[128,16]{1,0} parameter(1) @@ -4253,7 +4226,7 @@ ENTRY e { })"; TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, - GetModuleAndNestedFusionMetadata(hlo_text)); + GetModuleAndNestedFusionMetadata(kHloText)); CompileAndOptionallyVerifyPtx(std::move(module_and_metadata.module), R"( CHECK: wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 From 0b72d6951af280893cbf7d8a9679279a4b0b000f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 8 Apr 2025 07:42:39 -0700 Subject: [PATCH 0363/1324] [xla:collectives] Add an OnDestroy callback to Collectives and cleanup defunct communicator cliques We prefer an on-destroy callback to passing collectives as a `std::shared_ptr` (and relying on `std::weak_ptr` to track the lifetime), because it can be later implemented as a C API, in case we'd want to add support for passing external collectives implementations to XLA via PJRT APIs. 
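[Editor's note — illustration, not part of the commit] A minimal sketch of the callback design this message argues for, versus the rejected weak_ptr approach (hypothetical `Notifying` class, not the XLA API; the weak_ptr alternative would key caches on std::weak_ptr and have to lock() and purge expired entries on every lookup, which maps poorly onto a C API):

#include <utility>
#include <vector>

#include "absl/functional/any_invocable.h"

// Plain callbacks translate naturally to a C ABI: a function pointer plus an
// opaque user-data pointer. The destructor notifies every registered cache.
class Notifying {
 public:
  ~Notifying() {
    for (auto& callback : on_destroy_) callback();
  }
  void AddOnDestroyCallback(absl::AnyInvocable<void()> callback) {
    on_destroy_.push_back(std::move(callback));
  }

 private:
  std::vector<absl::AnyInvocable<void()>> on_destroy_;
};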
PiperOrigin-RevId: 745140099
---
 .../xla/xla/backends/cpu/collectives/BUILD    | 18 +++++
 .../backends/cpu/collectives/cpu_cliques.cc   | 31 ++++++++-
 .../cpu/collectives/cpu_cliques_test.cc       | 65 +++++++++++++++++++
 third_party/xla/xla/core/collectives/BUILD    |  2 +
 .../xla/xla/core/collectives/collectives.cc   | 35 ++++++++++
 .../xla/xla/core/collectives/collectives.h    | 34 +++++++++-
 6 files changed, 181 insertions(+), 4 deletions(-)
 create mode 100644 third_party/xla/xla/backends/cpu/collectives/cpu_cliques_test.cc
 create mode 100644 third_party/xla/xla/core/collectives/collectives.cc

diff --git a/third_party/xla/xla/backends/cpu/collectives/BUILD b/third_party/xla/xla/backends/cpu/collectives/BUILD
index dce3c36a19ac32..bca2f4d6a7f852 100644
--- a/third_party/xla/xla/backends/cpu/collectives/BUILD
+++ b/third_party/xla/xla/backends/cpu/collectives/BUILD
@@ -68,6 +68,24 @@ cc_library(
     ],
 )
 
+xla_cc_test(
+    name = "cpu_cliques_test",
+    srcs = ["cpu_cliques_test.cc"],
+    deps = [
+        ":cpu_clique_key",
+        ":cpu_cliques",
+        ":in_process_collectives",
+        "//xla:util",
+        "//xla/core/collectives:rank_id",
+        "//xla/service:global_device_id",
+        "//xla/tsl/platform:statusor",
+        "//xla/tsl/platform:test",
+        "//xla/tsl/platform:test_main",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 cc_library(
     name = "cpu_collectives",
     srcs = ["cpu_collectives.cc"],
diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.cc b/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.cc
index 8edfcaa145ae94..0d2192e3f710f0 100644
--- a/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.cc
+++ b/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.cc
@@ -55,9 +55,12 @@ struct ThreadSafeClique {
 // cliques, CPU cliques are not lockable, and we create communicators lazily
 // when needed.
 struct ProcessCpuCliques {
+  using Key = std::pair<CpuCollectives*, CpuCliqueKey>;
+
   absl::Mutex mu;
-  absl::node_hash_map<CpuCliqueKey, ThreadSafeClique> map ABSL_GUARDED_BY(mu);
+  absl::node_hash_map<Key, ThreadSafeClique> map ABSL_GUARDED_BY(mu);
 };
+
 } // namespace
 
 // Returns process-local CPU cliques.
@@ -66,6 +69,17 @@ static ProcessCpuCliques& GetProcessCpuCliques() {
   return *cliques;
 }
 
+// Erases cliques constructed from a given instance of CpuCollectives.
+static void EraseProcessCpuCliques(CpuCollectives* collectives) {
+  VLOG(3) << "Erase process CPU cliques for collectives: " << collectives;
+  ProcessCpuCliques& cliques = GetProcessCpuCliques();
+
+  absl::MutexLock lock(&cliques.mu);
+  absl::erase_if(cliques.map, [collectives](const auto& entry) {
+    return entry.first.first == collectives;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 
 // TODO(b/380457503): Consider switching to a lockable CPU clique model similar
@@ -73,14 +87,25 @@
 absl::StatusOr<Communicator*> AcquireCommunicator(
     CpuCollectives* collectives, const CpuCliqueKey& clique_key, RankId rank) {
   VLOG(3) << "Acquire communicator for clique key " << clique_key.ToString()
-          << " and rank " << rank;
+          << " and rank " << rank << " from collectives: " << collectives;
 
   ProcessCpuCliques& cliques = GetProcessCpuCliques();
 
   // Synchronize access to the process cliques.
  ThreadSafeClique& thread_safe_clique = [&]() -> ThreadSafeClique& {
     absl::MutexLock lock(&cliques.mu);
-    auto [it, emplaced] = cliques.map.try_emplace(clique_key, clique_key);
+    auto [it, emplaced] = cliques.map.try_emplace(
+        std::make_pair(collectives, clique_key), clique_key);
+
+    // If we created a new clique, register a callback to erase it when the
+    // collectives instance is destroyed.
+    if (emplaced) {
+      VLOG(3) << "Created a new clique for clique key " << clique_key.ToString()
+              << " and collectives: " << collectives;
+      collectives->AddOnDestroyCallback(
+          [collectives] { EraseProcessCpuCliques(collectives); });
+    }
+
     return it->second;
   }();
 
diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_cliques_test.cc b/third_party/xla/xla/backends/cpu/collectives/cpu_cliques_test.cc
new file mode 100644
index 00000000000000..f01f8b38aa7f56
--- /dev/null
+++ b/third_party/xla/xla/backends/cpu/collectives/cpu_cliques_test.cc
@@ -0,0 +1,65 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/cpu/collectives/cpu_cliques.h"
+
+#include <memory>
+
+#include "absl/types/span.h"
+#include "xla/backends/cpu/collectives/cpu_clique_key.h"
+#include "xla/backends/cpu/collectives/in_process_collectives.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/service/global_device_id.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/tsl/platform/test.h"
+
+namespace xla::cpu {
+namespace {
+
+TEST(CpuCliques, InvalidateAcquiredCommunicators) {
+  GlobalDeviceId d0(0);
+  GlobalDeviceId d1(1);
+
+  CpuCliqueKey clique_key({d0, d1});
+
+  auto collectives0 = std::make_unique<InProcessCollectives>();
+  auto collectives1 = std::make_unique<InProcessCollectives>();
+
+  // Check that the communicator instance is cached.
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto* comm0, AcquireCommunicator(&*collectives0, clique_key, RankId(0)));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto* comm1, AcquireCommunicator(&*collectives0, clique_key, RankId(0)));
+  EXPECT_EQ(comm0, comm1);
+
+  // Destroy communicators created for `collectives0`.
+  collectives0.reset();
+
+  // Acquire communicator from a new instance of collectives.
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto* comm2, AcquireCommunicator(&*collectives1, clique_key, RankId(0)));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto* comm3, AcquireCommunicator(&*collectives1, clique_key, RankId(0)));
+  EXPECT_EQ(comm2, comm3);
+
+  // Check that we acquired new communicators.
+  EXPECT_NE(comm0, comm2);
+
+  // Destroy communicators created for `collectives1`.
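+  // (Editorial note, not part of the original patch: resetting `collectives1`
+  // runs its on-destroy callbacks, which erase its entries from the
+  // process-wide clique map, so no stale communicators stay cached.)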
+  collectives1.reset();
+}
+
+} // namespace
+} // namespace xla::cpu
diff --git a/third_party/xla/xla/core/collectives/BUILD b/third_party/xla/xla/core/collectives/BUILD
index 9e2f3325c9ece5..d199bcaad52964 100644
--- a/third_party/xla/xla/core/collectives/BUILD
+++ b/third_party/xla/xla/core/collectives/BUILD
@@ -31,12 +31,14 @@ cc_library(
 
 cc_library(
     name = "collectives",
+    srcs = ["collectives.cc"],
     hdrs = ["collectives.h"],
     deps = [
        ":clique_id",
        ":clique_key",
        ":communicator",
        ":rank_id",
+        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
    ],
diff --git a/third_party/xla/xla/core/collectives/collectives.cc b/third_party/xla/xla/core/collectives/collectives.cc
new file mode 100644
index 00000000000000..98fe2e0a5fdbeb
--- /dev/null
+++ b/third_party/xla/xla/core/collectives/collectives.cc
@@ -0,0 +1,35 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/core/collectives/collectives.h"
+
+#include <utility>
+
+#include "absl/functional/any_invocable.h"
+
+namespace xla {
+
+Collectives::~Collectives() { NotifyOnDestroyCallbacks(); }
+
+void Collectives::AddOnDestroyCallback(absl::AnyInvocable<void()> callback) {
+  on_destroy_callbacks_.push_back(std::move(callback));
+}
+
+void Collectives::NotifyOnDestroyCallbacks() {
+  auto callbacks = std::move(on_destroy_callbacks_);
+  for (auto& callback : callbacks) callback();
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/core/collectives/collectives.h b/third_party/xla/xla/core/collectives/collectives.h
index 009842972e8978..95a40d608128d7 100644
--- a/third_party/xla/xla/core/collectives/collectives.h
+++ b/third_party/xla/xla/core/collectives/collectives.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "absl/functional/any_invocable.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "xla/core/collectives/clique_id.h"
@@ -42,7 +43,7 @@ namespace xla {
 // XLA:GPU device-initiated collective operations are implemented using
 // NVSHMEM.
 class Collectives {
  public:
-  virtual ~Collectives() = default;
+  virtual ~Collectives();
 
   // A base class for the device that the collectives are running on, i.e. in
   // XLA:GPU this is the GPU device (StreamExecutor).
@@ -79,7 +80,38 @@ class Collectives {
   virtual absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
   SplitCommunicators(absl::Span<const Communicator* const> comms,
                     int32_t color, absl::Span<const RankId> keys,
                     const Config& config) = 0;
+
+  // A Collectives instance can be ephemeral and used only for a small number
+  // of XLA program executions. XLA backends that rely on the collectives
+  // instances as a part of the cache key can be notified when the collectives
+  // instance is destroyed, so that they can invalidate the cache entries.
+  //
+  // After the on-destroy callback is invoked, XLA backends must not use any
+  // of the communicators created by the collectives instance.
+  //
+  // It is the XLA client's responsibility (e.g. 
Pathways) to guarantee that
+  // the collectives instance stays alive until all the XLA program
+  // executions that use it are finished.
+  void AddOnDestroyCallback(absl::AnyInvocable<void()> callback);
+
+ protected:
+  Collectives() = default;
+  Collectives(Collectives&&) = default;
+  Collectives& operator=(Collectives&&) = default;
+
+  // Notifies all registered callbacks that the collectives instance is
+  // about to be destroyed.
+  //
+  // IMPORTANT: Because callbacks are invoked from the base class destructor,
+  // they will be called after the derived class is destroyed. If it is
+  // important to call callbacks before the derived class is destroyed, the
+  // derived class should call it explicitly in its own destructor.
+  void NotifyOnDestroyCallbacks();
+
+ private:
+  std::vector<absl::AnyInvocable<void()>> on_destroy_callbacks_;
 };
 
 } // namespace xla
+
 #endif // XLA_CORE_COLLECTIVES_COLLECTIVES_H_
From e35467e55e6df020c540920a5cd9c2fa76331bfd Mon Sep 17 00:00:00 2001
From: Berkin Ilbeyi
Date: Tue, 8 Apr 2025 08:08:41 -0700
Subject: [PATCH 0364/1324] [XLA] Implement replica group deduplication for
 HloReplicationAnalysis.

This helps compile times when the replica groups are very large.

PiperOrigin-RevId: 745149068
---
 .../hlo/analysis/hlo_replication_analysis.cc  | 109 +++++++++++++---
 .../hlo/analysis/hlo_replication_analysis.h   |  18 ++-
 2 files changed, 108 insertions(+), 19 deletions(-)

diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc
index 64e38cb469fff9..dfa5f7584b2d47 100644
--- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc
+++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc
@@ -106,7 +106,10 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated(
     bool cross_partition_spmd,
     const absl::flat_hash_map<const HloInstruction*,
                               ShapeTree<HloReplication>>& hlo_replication,
-    bool support_partial_replication) {
+    bool support_partial_replication,
+    const absl::flat_hash_map<const HloInstruction*,
+                              std::optional<HloReplication>*>&
+        replica_group_dedup_map) {
   const auto merge_operand_replication = [&hlo_replication](
                                              const HloInstruction* inst) {
     HloReplication replication = HloReplication::ReplicatedOnAllDevices();
@@ -121,19 +124,9 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated(
     return replication;
   };
 
-  if (hlo->opcode() == HloOpcode::kAllReduce ||
-      hlo->opcode() == HloOpcode::kAllGather) {
-    // All-reduce/all-gather returns same values across partitions/replicas as
-    // long as its operands are replicated.
-    HloReplication replication = merge_operand_replication(hlo);
-    if (replication.IsReplicatedOnAllDevices()) {
-      return replication;
-    }
+  auto calculate_all_reduce_all_gather_replication = [&](const HloInstruction*
+                                                             hlo) {
     if (!hlo->channel_id().has_value()) {
-      // This is cross-replica-only.
-      if (cross_partition_spmd) {
-        return replication;
-      }
       if (hlo->replica_groups().empty() || hlo->replica_groups().size() == 1) {
         return HloReplication::ReplicatedOnAllDevices();
       }
@@ -148,9 +141,8 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated(
         device_sets_per_replica[0].push_back(device_set);
       }
       return HloReplication::PartiallyReplicated(device_sets_per_replica);
-    } else {
-      return HloReplication::UniqueOnAllDevices();
     }
+    return HloReplication::UniqueOnAllDevices();
   } else {
     bool global_id;
     if (hlo->opcode() == HloOpcode::kAllReduce) {
@@ -191,10 +183,38 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated(
       }
       if (hlo->replica_groups().empty() || hlo->replica_groups().size() == 1) {
         return HloReplication::ReplicatedOnAllDevices();
-      } else {
-        return HloReplication::UniqueOnAllDevices();
       }
+      return HloReplication::UniqueOnAllDevices();
     }
+  };
+
+  if (hlo->opcode() == HloOpcode::kAllReduce ||
+      hlo->opcode() == HloOpcode::kAllGather) {
+    // All-reduce/all-gather returns same values across partitions/replicas as
+    // long as its operands are replicated.
+    HloReplication replication = merge_operand_replication(hlo);
+    if (replication.IsReplicatedOnAllDevices()) {
+      return replication;
+    }
+    // This is cross-replica-only.
+    if (!hlo->channel_id().has_value() && cross_partition_spmd) {
+      return replication;
+    }
+
+    // To save compile time on very large replica groups, check first if the
+    // replica group dedup map has an entry already populated with the
+    // replication and if so return that.
+    auto unique_replication_it = replica_group_dedup_map.find(hlo);
+    if (unique_replication_it == replica_group_dedup_map.end()) {
+      VLOG(1) << "No dedup entry for " << hlo->name();
+      return calculate_all_reduce_all_gather_replication(hlo);
+    }
+    std::optional<HloReplication>* unique_replication =
+        unique_replication_it->second;
+    if (!unique_replication->has_value()) {
+      *unique_replication = calculate_all_reduce_all_gather_replication(hlo);
+    }
+    return **unique_replication;
+  }
   if (hlo->HasSideEffectNoRecurse()) {
     return HloReplication::UniqueOnAllDevices();
   }
@@ -452,7 +472,7 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
         *shape_tree.mutable_element(index) =
             DetermineHloInstructionIsReplicated(
                 inst, index, cross_partition_spmd_, hlo_replication_,
-                support_partial_replication_);
+                support_partial_replication_, replica_group_dedup_map_);
       });
       changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
     }
@@ -521,6 +541,57 @@ absl::Status HloReplicationAnalysis::ComputeHloReplication() {
   return absl::OkStatus();
 }
 
+void HloReplicationAnalysis::BuildReplicaGroupDedupMap() {
+  std::vector<std::vector<const HloInstruction*>> dedupable_instructions;
+  for (const HloComputation* computation :
+       module_->MakeNonfusionComputations()) {
+    for (const HloInstruction* instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kAllReduce ||
+          instruction->opcode() == HloOpcode::kAllGather) {
+        auto dedupable_it = absl::c_find_if(
+            dedupable_instructions,
+            [&](const std::vector<const HloInstruction*>& insts) {
+              const HloInstruction* other = insts.at(0);
+              auto use_global_device_ids = [&](const HloInstruction* inst) {
+                if (inst->opcode() == HloOpcode::kAllReduce) {
+                  return Cast<HloAllReduceInstruction>(inst)
+                      ->use_global_device_ids();
+                }
+                return Cast<HloAllGatherInstruction>(inst)
+                    ->use_global_device_ids();
+              };
+              // The existence of channel ids, global device ids and the replica
+              // groups can affect whether the instruction is replicated. So
+              // include these in the dedup cache key.
+              return instruction->channel_id().has_value() ==
+                         other->channel_id().has_value() &&
+                     use_global_device_ids(instruction) ==
+                         use_global_device_ids(other) &&
+                     absl::c_equal(
+                         instruction->replica_groups(), other->replica_groups(),
+                         [](const ReplicaGroup& a, const ReplicaGroup& b) {
+                           return absl::c_equal(a.replica_ids(),
+                                                b.replica_ids());
+                         });
+            });
+        if (dedupable_it == dedupable_instructions.end()) {
+          dedupable_instructions.push_back({instruction});
+        } else {
+          dedupable_it->push_back(instruction);
+        }
+      }
+    }
+  }
+
+  unique_replications_.reserve(dedupable_instructions.size());
+  for (auto& insts : dedupable_instructions) {
+    unique_replications_.push_back(std::nullopt);
+    for (const HloInstruction* inst : insts) {
+      replica_group_dedup_map_[inst] = &unique_replications_.back();
+    }
+  }
+}
+
 bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
     const HloInstruction* inst, const ShapeIndex& index) const {
  auto it = hlo_replication_.find(inst);
@@ -572,6 +643,7 @@ HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd,
  auto analysis = absl::WrapUnique(new HloReplicationAnalysis(
      module, cross_partition_spmd, loops_known_with_same_iterations,
      /*support_partial_replication=*/false));
+  analysis->BuildReplicaGroupDedupMap();
  TF_RETURN_IF_ERROR(analysis->ComputeHloReplication());
  return analysis;
}
@@ -583,6 +655,7 @@ HloReplicationAnalysis::RunWithPartialReplication(const HloModule* module,
  auto analysis = absl::WrapUnique(
      new HloReplicationAnalysis(module, cross_partition_spmd, &empty,
                                 /*support_partial_replication=*/true));
+  analysis->BuildReplicaGroupDedupMap();
  TF_RETURN_IF_ERROR(analysis->ComputeHloReplication());
  return analysis;
}
diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h
index f71fb0850568b4..80bc62eb27a098 100644
--- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h
+++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h
@@ -111,7 +111,10 @@ class HloReplicationAnalysis {
      bool cross_partition_spmd,
      const absl::flat_hash_map<const HloInstruction*,
                                ShapeTree<HloReplication>>& hlo_replication,
-      bool support_partial_replication);
+      bool support_partial_replication,
+      const absl::flat_hash_map<const HloInstruction*,
+                                std::optional<HloReplication>*>&
+          replica_group_dedup_map);

  HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd,
                         const absl::flat_hash_set<const HloInstruction*>*
@@ -130,6 +133,12 @@ class HloReplicationAnalysis {
  bool ComputeHloReplicationOnComputation(const HloComputation* computation,
                                          bool mark_everything_not_replicated);

+  // Builds the replica group dedup map, which allows replication calculations
+  // to be cached across all-reduce/all-gather ops that share the same replica
+  // groups. This can significantly reduce compile time when replica groups
+  // are very large.
+  void BuildReplicaGroupDedupMap();
+
  const HloModule* module_;

  // If true, run this replication analysis for replicated values across
@@ -155,6 +164,13 @@ class HloReplicationAnalysis {
  // partitions at each shape index.
  absl::flat_hash_map<const HloInstruction*, ShapeTree<HloReplication>>
      hlo_replication_;
+
+  // Replications for all-reduce/all-gather ops with the same replica groups
+  // are usually identical. We use the following data structures to memoize
+  // the replications for instructions with identical replica groups.
+  absl::flat_hash_map<const HloInstruction*, std::optional<HloReplication>*>
+      replica_group_dedup_map_;
+  std::vector<std::optional<HloReplication>> unique_replications_;
};

}  // namespace xla

From fa8543f0e0cc98cbdcd50d80e5bb2d3be36ff3ca Mon Sep 17 00:00:00 2001
From: "A.
Unique TensorFlower" Date: Tue, 8 Apr 2025 08:27:37 -0700 Subject: [PATCH 0365/1324] Instrument batching_delay_msecs and queueing_delay_msecs request cost dimensions in BatchResourceBase, and export them in QueryCostExt. PiperOrigin-RevId: 745155244 --- .../batching_util/batch_resource_base.cc | 112 +++++++++++---- .../batching_util/batch_resource_base.h | 16 ++- .../batching_util/batch_resource_base_test.cc | 135 +++++++++++++++++- 3 files changed, 231 insertions(+), 32 deletions(-) diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index 545baba8110c10..84722b65043a21 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -1054,39 +1054,35 @@ void BatchResourceBase::ProcessFuncBatch( batch->task(batch->num_tasks() - 1).captured_inputs; args.insert(args.end(), captured_inputs.begin(), captured_inputs.end()); - uint64 current_time = EnvTime::NowNanos(); - for (int i = 0; i < batch->num_tasks(); ++i) { - RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3, - model_name, last_task_context->op_kernel().name(), - processed_size); - RecordBatchDelayUsV2((current_time - batch->task(i).start_time) * 1e-3, - model_name, last_task_context->op_kernel().name(), - processed_size); - } + RecordBatchDelayMetrics( + *batch, model_name, op_name, processed_size, + /*batch_schedule_time=*/absl::FromUnixNanos(EnvTime::NowNanos()), + GetBatchTimeout()); + // Releases the cleanup method here, because the callback of the function // library runtime will handle it now. finally.release(); - ProcessFuncBatchImpl( - last_task, args, &combined_outputs, [&](const absl::Status& run_status) { - absl::Status final_status; - auto run_finally = gtl::MakeCleanup([&]() { - // We do the cleanup here as an optimization, so that - // it runs in the underlying TF inter-op threadpool. - // Running it in the threadpool, let's the ensuing - // ops be scheduled faster, because the executor will - // add them to the front of the threadpool's task - // queue rather than the end. - cleanup_fn(final_status); - }); - final_status = run_status; - if (!final_status.ok()) { - return; - } - if (last_task.forced_warmup_batch_size == 0) { - final_status = SplitOutputTensors(combined_outputs, batch.get(), - unbatched_tasks); - } - }); + ProcessFuncBatchImpl(last_task, args, &combined_outputs, + [&](const absl::Status& run_status) { + absl::Status final_status; + auto run_finally = gtl::MakeCleanup([&]() { + // We do the cleanup here as an optimization, so that + // it runs in the underlying TF inter-op threadpool. + // Running it in the threadpool, let's the ensuing + // ops be scheduled faster, because the executor will + // add them to the front of the threadpool's task + // queue rather than the end. + cleanup_fn(final_status); + }); + final_status = run_status; + if (!final_status.ok()) { + return; + } + if (last_task.forced_warmup_batch_size == 0) { + final_status = SplitOutputTensors( + combined_outputs, batch.get(), unbatched_tasks); + } + }); } // Processes a batch of one or more BatchTask entries. 
@@ -1248,6 +1244,17 @@ absl::Status BatchResourceBase::LookupOrCreateBatcherQueue(
  return absl::OkStatus();
}

+std::optional<absl::Duration> BatchResourceBase::GetBatchTimeout() const {
+  if (batcher_) {
+    return absl::Microseconds(batcher_queue_options_.batch_timeout_micros);
+  }
+  if (adaptive_batcher_) {
+    return absl::Microseconds(
+        adaptive_batcher_queue_options_.batch_timeout_micros);
+  }
+  return std::nullopt;
+}
+
 void BatchResourceBase::SplitBatchCostsAndRecordMetrics(
     const std::string& model_name, const std::string& op_name,
     const std::vector<std::unique_ptr<CostMeasurement>>&
@@ -1324,5 +1331,50 @@
  }
}

+void BatchResourceBase::RecordBatchDelayMetrics(
+    const BatchResourceBase::BatchT& batch, const std::string& model_name,
+    const std::string& op_name, int64_t processed_size,
+    absl::Time batch_schedule_time,
+    std::optional<absl::Duration> batch_timeout) {
+  absl::Time earliest_task_start_time = absl::InfiniteFuture();
+  for (int i = 0; i < batch.num_tasks(); ++i) {
+    earliest_task_start_time =
+        std::min(earliest_task_start_time,
+                 absl::FromUnixNanos(batch.task(i).start_time));
+  }
+  for (int i = 0; i < batch.num_tasks(); ++i) {
+    const BatchResourceBase::BatchTask& task = batch.task(i);
+
+    const absl::Time start_time = absl::FromUnixNanos(task.start_time);
+    const absl::Duration total_scheduler_delay =
+        batch_schedule_time - start_time;
+    RecordBatchDelayUs(absl::ToInt64Microseconds(total_scheduler_delay),
+                       model_name, op_name, processed_size);
+    RecordBatchDelayUsV2(absl::ToInt64Microseconds(total_scheduler_delay),
+                         model_name, op_name, processed_size);
+
+    RequestCost* request_cost = task.request_cost;
+    // Skip recording the cost if the request_cost is null.
+    if (!request_cost) continue;
+
+    // The duration from when the task was enqueued until the earliest task in
+    // its batch has been in the queue for batch_timeout (i.e. until the task
+    // is eligible to be scheduled into a batch, regardless of the number of
+    // tasks in the queue) is considered batching delay; the remaining time in
+    // the queue is considered queueing delay.
+    const absl::Duration remaining_batch_timeout =
+        std::max(earliest_task_start_time +
+                     batch_timeout.value_or(absl::ZeroDuration()) - start_time,
+                 absl::ZeroDuration());
+    const absl::Duration batching_delay =
+        std::min(remaining_batch_timeout, total_scheduler_delay);
+    const absl::Duration queueing_delay =
+        total_scheduler_delay - batching_delay;
+    request_cost->RecordMetrics(
+        {{"batching_delay_msecs", absl::ToDoubleMilliseconds(batching_delay)},
+         {"queueing_delay_msecs", absl::ToDoubleMilliseconds(queueing_delay)}});
+  }
+}
+
}  // namespace serving
}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h
index 633724de79a21b..54e83c82367f3e 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.h
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h
@@ -20,6 +20,7 @@ limitations under the License.
#include
#include
#include
+#include <optional>
#include
#include
#include
@@ -27,6 +28,7 @@ limitations under the License.
#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/blocking_counter.h" +#include "absl/time/time.h" #include "xla/tsl/platform/criticality.h" #include "tensorflow/core/common_runtime/cost_measurement_registry.h" #include "tensorflow/core/common_runtime/request_cost.h" @@ -82,7 +84,7 @@ class BatchResourceBase : public ResourceBase { // Note input from one batch-op invocation is valid and considered a // specialized `slice`. struct BatchTask : public tensorflow::serving::BatchTask { - BatchTask() : criticality_val(tsl::criticality::GetCriticality()){}; + BatchTask() : criticality_val(tsl::criticality::GetCriticality()) {}; // A unique ID to identify this invocation of Batch. int64_t guid; @@ -274,6 +276,14 @@ class BatchResourceBase : public ResourceBase { batch_cost_measurements, int64_t processed_size, BatchT& batch); + // Records information about the delay between a task being registered and + // that task being scheduled into a batch. + static void RecordBatchDelayMetrics( + const BatchResourceBase::BatchT& batch, const std::string& model_name, + const std::string& op_name, int64_t processed_size, + absl::Time batch_schedule_time, + std::optional batch_timeout); + private: // Implementation of calling the process batch function. virtual void ProcessFuncBatchImpl( @@ -346,6 +356,10 @@ class BatchResourceBase : public ResourceBase { const string& op_name, BatcherQueueT** queue); + // Returns the batch timeout for the configured scheduler, or nullopt if the + // scheduler does not have such a parameter. + std::optional GetBatchTimeout() const; + SessionMetadata session_metadata_; absl::Mutex outstanding_batch_mu_; diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc index b3e2548b58d326..ee912719fba873 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/tsl/platform/criticality.h" @@ -129,10 +130,12 @@ class TestGcuCostMeasurement : public CostMeasurement { REGISTER_COST_MEASUREMENT("test_gcu", TestGcuCostMeasurement); std::unique_ptr MakeBatchTask( - const int64_t task_size, RequestCost* request_cost) { + const int64_t task_size, RequestCost* request_cost, + absl::Time start_time = absl::UnixEpoch()) { auto task = std::make_unique(); task->inputs.push_back(Tensor(DT_DOUBLE, TensorShape({task_size, 1}))); task->request_cost = request_cost; + task->start_time = absl::ToUnixNanos(start_time); return task; } @@ -418,6 +421,136 @@ TEST(SplitBatchCostsAndRecordMetricsTest, GlobalBatchStatsProcessedSize) { original_cumulative_processed_size + 4); } +TEST(RecordBatchDelayMetricsTest, + TwoRequestsWithNoQueueingDelayAndSchedulingAtBatchTimeout) { + const absl::Duration batch_timeout = absl::Seconds(1); + const absl::Duration task2_delay = batch_timeout / 4; + const absl::Time task1_start_time = absl::Now(); + const absl::Time task2_start_time = task1_start_time + task2_delay; + const absl::Time batch_schedule_time = task1_start_time + batch_timeout; + + BatchResourceBase::BatchT batch; + RequestCost cost1, cost2; + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost1, task1_start_time)); + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost2, task2_start_time)); + batch.Close(); + + BatchResourceBase::RecordBatchDelayMetrics( + batch, "model_name", "op_name", /*processed_size=*/20, + batch_schedule_time, batch_timeout); + + EXPECT_THAT( + batch.task(0).request_cost->GetMetrics(), + UnorderedElementsAre(Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(batch_timeout)), + Pair("queueing_delay_msecs", 0))); + EXPECT_THAT(batch.task(1).request_cost->GetMetrics(), + UnorderedElementsAre( + Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(batch_timeout - task2_delay)), + Pair("queueing_delay_msecs", 0))); +} + +TEST(RecordBatchDelayMetricsTest, + TwoRequestsWithNoQueueingDelayAndSchedulingAfterSecondRequest) { + const absl::Duration batch_timeout = absl::Seconds(1); + const absl::Duration task2_delay = batch_timeout / 4; + const absl::Duration scheduling_delay = batch_timeout / 10; + const absl::Time task1_start_time = absl::Now(); + const absl::Time task2_start_time = task1_start_time + task2_delay; + const absl::Time batch_schedule_time = + task1_start_time + task2_delay + scheduling_delay; + + BatchResourceBase::BatchT batch; + RequestCost cost1, cost2; + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost1, task1_start_time)); + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost2, task2_start_time)); + batch.Close(); + + BatchResourceBase::RecordBatchDelayMetrics( + batch, "model_name", "op_name", /*processed_size=*/20, + batch_schedule_time, batch_timeout); + + EXPECT_THAT( + batch.task(0).request_cost->GetMetrics(), + UnorderedElementsAre( + Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(task2_delay + scheduling_delay)), + Pair("queueing_delay_msecs", 0))); + EXPECT_THAT( + batch.task(1).request_cost->GetMetrics(), + UnorderedElementsAre(Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(scheduling_delay)), + Pair("queueing_delay_msecs", 0))); +} + +TEST(RecordBatchDelayMetricsTest, TwoRequestWithQueueingDelay) { + const absl::Duration batch_timeout = absl::Seconds(1); + const absl::Duration task2_delay = batch_timeout / 4; + const 
absl::Duration queueing_delay = 5 * batch_timeout; + const absl::Time task1_start_time = absl::Now(); + const absl::Time task2_start_time = task1_start_time + task2_delay; + const absl::Time batch_schedule_time = + task1_start_time + batch_timeout + queueing_delay; + + BatchResourceBase::BatchT batch; + RequestCost cost1, cost2; + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost1, task1_start_time)); + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost2, task2_start_time)); + batch.Close(); + + BatchResourceBase::RecordBatchDelayMetrics( + batch, "model_name", "op_name", /*processed_size=*/20, + batch_schedule_time, batch_timeout); + + EXPECT_THAT( + batch.task(0).request_cost->GetMetrics(), + UnorderedElementsAre(Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(batch_timeout)), + Pair("queueing_delay_msecs", + absl::ToDoubleMilliseconds(queueing_delay)))); + EXPECT_THAT(batch.task(1).request_cost->GetMetrics(), + UnorderedElementsAre( + Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(batch_timeout - task2_delay)), + Pair("queueing_delay_msecs", + absl::ToDoubleMilliseconds(queueing_delay)))); +} + +TEST(RecordBatchDelayMetricsTest, + TwoRequestsWithQueueingDelayAndSecondArrivingAfterBatchTimeout) { + const absl::Duration batch_timeout = absl::Seconds(1); + const absl::Duration task2_delay = 3 * batch_timeout; + const absl::Duration queueing_delay = 5 * batch_timeout; + const absl::Time task1_start_time = absl::Now(); + const absl::Time task2_start_time = task1_start_time + task2_delay; + const absl::Time batch_schedule_time = + task1_start_time + task2_delay + queueing_delay; + + BatchResourceBase::BatchT batch; + RequestCost cost1, cost2; + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost1, task1_start_time)); + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost2, task2_start_time)); + batch.Close(); + + BatchResourceBase::RecordBatchDelayMetrics( + batch, "model_name", "op_name", /*processed_size=*/20, + batch_schedule_time, batch_timeout); + + EXPECT_THAT(batch.task(0).request_cost->GetMetrics(), + UnorderedElementsAre( + Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(batch_timeout)), + Pair("queueing_delay_msecs", + absl::ToDoubleMilliseconds(task2_delay - batch_timeout + + queueing_delay)))); + EXPECT_THAT( + batch.task(1).request_cost->GetMetrics(), + UnorderedElementsAre(Pair("batching_delay_msecs", 0), + Pair("queueing_delay_msecs", + absl::ToDoubleMilliseconds(queueing_delay)))); +} + class BatchResourceBaseTest : public ::testing::Test { protected: // Like BatchResourceBase but overrides abstract methods, one of which From 24ef432ec8b80df4394afd98d9d0d13fd20c5f95 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Tue, 8 Apr 2025 08:38:58 -0700 Subject: [PATCH 0366/1324] [XLA] Allow multihost runner to use entry computation layouts from the module To do that, pass --use_layouts_from_hlo_module flag. Without the flag, default layouts were always used. 
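An illustrative invocation (the binary name and module path here are
placeholders; only the flag itself comes from this change):

  hlo_runner_main --use_layouts_from_hlo_module /tmp/single_device.hlo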
PiperOrigin-RevId: 745159025 --- .../xla/xla/tools/multihost_hlo_runner/BUILD | 2 +- .../data/single_device.hlo | 6 +++--- .../functional_hlo_runner.cc | 20 ++++++++++++++++--- .../functional_hlo_runner.h | 6 +++++- .../functional_hlo_runner_test.cc | 16 +++++++++++++++ .../multihost_hlo_runner/hlo_runner_main.cc | 5 +++++ 6 files changed, 47 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index c3c26a89ea0f06..d8d10a07b29e6b 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -139,6 +139,7 @@ cc_library( deps = [ "//xla:literal", "//xla:literal_util", + "//xla:shape_layout", "//xla:shape_util", "//xla:status_macros", "//xla:util", @@ -183,7 +184,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@llvm-project//mlir:FuncExtensions", "@local_tsl//tsl/platform:protobuf", diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/data/single_device.hlo b/third_party/xla/xla/tools/multihost_hlo_runner/data/single_device.hlo index c569f98abb1796..8af659c29e42d8 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/data/single_device.hlo +++ b/third_party/xla/xla/tools/multihost_hlo_runner/data/single_device.hlo @@ -1,6 +1,6 @@ -HloModule f +HloModule f, entry_computation_layout={(f32[2,2]{0,1})->f32[2,2]{0,1}} ENTRY f { - arg = f32[2,2]{1,0} parameter(0) - ROOT add_result = f32[2,2]{1,0} add(arg, arg) + arg = f32[2,2]{0,1} parameter(0) + ROOT add_result = f32[2,2]{0,1} add(arg, arg) } diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index b308cd9ac5ff56..bc534ca0fed0b6 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -65,6 +66,7 @@ limitations under the License. 
#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_module_util.h" +#include "xla/shape_layout.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tests/test_utils.h" @@ -792,10 +794,22 @@ absl::Status FunctionalHloRunner::PrepareHloModuleForCompilation( } CompileOptions FunctionalHloRunner::CompleteCompileOptions( - const HloModule& hlo_module, CompileOptions compile_options) { + const HloModule& hlo_module, CompileOptions compile_options, + const PreprocessingOptions& preproc_options) { ParameterType parameter_type = GetParameterType(hlo_module); compile_options.parameter_is_tupled_arguments = (parameter_type == ParameterType::kOneTupleOfArrays); + if (preproc_options.use_layouts_from_hlo_module) { + const ComputationLayout& layout = hlo_module.entry_computation_layout(); + std::vector parameter_shapes; + parameter_shapes.reserve(layout.parameter_count()); + for (const ShapeLayout& shape_layout : layout.parameter_layouts()) { + parameter_shapes.push_back(shape_layout.shape()); + } + compile_options.argument_layouts = std::move(parameter_shapes); + compile_options.executable_build_options.set_result_layout( + layout.result_shape()); + } return compile_options; } @@ -838,7 +852,7 @@ FunctionalHloRunner::Compile(PjRtClient& client, HloModule* hlo_module, TF_RETURN_IF_ERROR(PrepareHloModuleForCompilation(hlo_module, debug_options, preproc_options)); CompileOptions modified_compile_options = - CompleteCompileOptions(*hlo_module, compile_options); + CompleteCompileOptions(*hlo_module, compile_options, preproc_options); return ConvertAndCallCompiler( preproc_options.compile_as_stablehlo, hlo_module, @@ -856,7 +870,7 @@ absl::StatusOr> FunctionalHloRunner::Compile( TF_RETURN_IF_ERROR(PrepareHloModuleForCompilation(hlo_module, debug_options, preproc_options)); CompileOptions modified_compile_options = - CompleteCompileOptions(*hlo_module, compile_options); + CompleteCompileOptions(*hlo_module, compile_options, preproc_options); return ConvertAndCallCompiler( preproc_options.compile_as_stablehlo, hlo_module, diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index e615ae67db3999..40dd0d868940c2 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -238,6 +238,9 @@ class FunctionalHloRunner { // compilation. bool compile_as_stablehlo = false; + // Use layouts from the HLO module. + bool use_layouts_from_hlo_module = false; + // Should we flatten all while loops? bool flatten_while_loop() const { return while_execution_count.has_value(); @@ -418,7 +421,8 @@ class FunctionalHloRunner { // This would ideally be private, but we need it for the implementation of // MultihostHloRunner. 
static CompileOptions CompleteCompileOptions(const HloModule& hlo_module, - CompileOptions compile_options); + CompileOptions compile_options, + const PreprocessingOptions&); static absl::Status DumpOutput( const FunctionalHloRunner::PerDeviceLiteralVecType& output, diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index ef567b35b834ba..9007e427adf28b 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -386,6 +386,22 @@ void CompileAndFilecheck( } } +TEST_F(FunctionalHloRunnerTest, KeepLayoutsFromHloModule) { + FunctionalHloRunner::PreprocessingOptions preproc_options; + preproc_options.use_layouts_from_hlo_module = true; + + CompileAndFilecheck(GetHloPath("single_device.hlo"), + // Check that non-standard layouts are preserved. + R"( +// CHECK: entry_computation_layout={(f32[2,2]{0,1})->f32[2,2]{0,1}} +// CHECK: f32[2,2]{0,1} parameter(0) +// CHECK: ROOT {{.*}} = f32[2,2]{0,1} +)", + preproc_options, + FunctionalHloRunner::HloPassesMode::kStandardCompile, + /*num_partitions=*/1); +} + TEST_F(FunctionalHloRunnerTest, CanCompileWithoutHavingEnoughGpus) { CompileAndFilecheck(GetHloPath("sharded_16_devices.hlo"), // Check that the sharding was done correctly. diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc index d20d901b991186..5bb13542731772 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc @@ -98,6 +98,7 @@ struct HloRunnerConfig { int32_t while_execution_count = -1; bool remove_infeed_outfeed = true; bool compile_as_stablehlo = false; + bool use_layouts_from_hlo_module = false; int32_t num_repeats = 1; std::string execution_options_path = ""; int64_t gpu_client_initialization_timeout_sec = 300; @@ -361,6 +362,10 @@ int main(int argc, char** argv) { tsl::Flag("compile_as_stablehlo", &opts.compile_as_stablehlo, "If set, convert the module to StableHLO before passing to " "PjRt for compilation."), + tsl::Flag("use_layouts_from_hlo_module", + &opts.use_layouts_from_hlo_module, + "If set, use layouts from the HLO module's " + "entry_computation_layout."), tsl::Flag("num_repeats", &opts.num_repeats, "Repeatedly execute the HLO for this many times."), tsl::Flag("execution_options_path", &opts.execution_options_path, From 9aabdd952447afda24b246b7bf230c9cca76a990 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 8 Apr 2025 08:49:50 -0700 Subject: [PATCH 0367/1324] Direct HLO -> StableHLO Conversion Largely NFC for users of existing HLO->MHLO APIs since all exit points are plugged with a StableHLO->MHLO conversion. However, users who are converting HLO->StableHLO will pay less of an overhead since there will be fewer and fewer ops that require MHLO->StableHLO conversion. 
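A minimal sketch of how a legacy exit point can be plugged, assuming the
existing mhlo StableHLO-to-HLO legalization pass (the pass-manager wiring
below is illustrative, not the literal code of this change):

  // Import now emits StableHLO; legacy HLO->MHLO entry points append a
  // StableHLO->MHLO legalization so their observable output is unchanged.
  mlir::PassManager pm(module->getContext());
  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
  if (mlir::failed(pm.run(*module))) {
    return absl::InternalError("StableHLO -> MHLO conversion failed");
  }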
PiperOrigin-RevId: 745162439 --- .../mlir/tf2xla/tests/legalize-tf.mlir | 32 +- .../mlir/tf2xla/transforms/xla_legalize_tf.cc | 25 +- third_party/stablehlo/temporary.patch | 27 + .../xla/third_party/stablehlo/temporary.patch | 27 + .../xla/xla/hlo/translate/hlo_to_mhlo/BUILD | 5 + .../hlo_to_mhlo/hlo_function_importer.cc | 457 ++- .../hlo_to_mhlo/hlo_function_importer.h | 3 + .../hlo_to_mhlo/hlo_module_importer.cc | 40 +- .../hlo_to_mhlo/hlo_module_importer.h | 4 +- .../translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc | 28 +- .../translate/hlo_to_mhlo/hlo_to_mlir_hlo.h | 10 +- .../hlo_to_mhlo/location_importer.cc | 5 +- .../translate/hlo_to_mhlo/location_importer.h | 4 +- .../hlo_to_mhlo/stack_location_utils.cc | 4 +- .../hlo_to_mhlo/stack_location_utils.h | 4 +- .../xla/hlo/translate/hlo_to_mhlo/tests/BUILD | 1 + .../tests/import_emit_stablehlo.hlo | 2788 +++++++++++++++++ .../hlo/translate/hlo_to_mhlo/translate.cc | 10 +- .../xla/hlo/translate/hlo_to_mhlo/translate.h | 4 +- .../hlo_to_mhlo/translate_registration.cc | 11 +- .../xla/xla/hlo/translate/stablehlo.cc | 6 +- .../xla/hlo/translate/xla_translate_main.cc | 5 +- 22 files changed, 3273 insertions(+), 227 deletions(-) create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 92754a181e8551..c54bef4f4a6947 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -91,12 +91,12 @@ func.func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8x // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988> // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01> - // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3 - // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[MEAN]] + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %arg3, %[[ALPHA]] + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[MEAN]], %[[BETA]] // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] - // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4 - // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %arg4, %[[ALPHA]] + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[CORRECTED_VAR]], %[[BETA]] // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[MEAN]], %[[VAR]] @@ -430,8 +430,7 @@ func.func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: // CHECK-LABEL: func @biasAdd_default func.func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, 
tensor<32xi32>) -> tensor<1x32x10x32xi32> @@ -443,8 +442,7 @@ func.func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) // CHECK-LABEL: func @biasAdd_NHWC func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> @@ -456,8 +454,7 @@ func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> // CHECK-LABEL: func @biasAdd_NCHW func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> @@ -469,8 +466,7 @@ func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> // CHECK-LABEL: func @biasAdd_dynamic func.func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor, tensor) -> tensor @@ -482,8 +478,7 @@ func.func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> // CHECK-LABEL: func @biasAdd_partial_dynamic func.func @biasAdd_partial_dynamic(%arg0: tensor, %arg1: tensor<512xi32>) -> tensor { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] // CHECK: %[[CAST:.+]] = tensor.cast %[[RESULT]] : tensor to tensor @@ -1797,7 +1792,7 @@ func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> // CHECK-LABEL: func @relu func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK: chlo.broadcast_maximum %arg0, %[[ZERO]] {broadcast_dimensions = 
array} : (tensor<1xi32>, tensor) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> func.return %0: tensor<1xi32> } @@ -1807,7 +1802,7 @@ func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unsigned func.func @relu_unsigned(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_maximum %arg0, %[[ZERO]] {broadcast_dimensions = array} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor func.return %0: tensor } @@ -1882,7 +1877,7 @@ func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> te // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32> - // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = array} : (tensor, tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ABS]], %[[ONE]] {broadcast_dimensions = array} : (tensor<4x10xf32>, tensor) -> tensor<4x10xf32> // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32> // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32> @@ -1958,8 +1953,7 @@ func.func @select_batch_dynamic_r1(%arg0: tensor, %arg1: tensor // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[HEAD]] : tensor<1xindex>, tensor<1xindex> // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor) { - // CHECK-NEXT: %[[SHAPE1E:.*]] = shape.to_extent_tensor %[[SHAPE1]] : tensor<3xindex> -> tensor<3xindex> - // CHECK-NEXT: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1E]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor, tensor<3xindex>) -> tensor + // CHECK-NEXT: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor, tensor<3xindex>) -> tensor // CHECK-NEXT: %[[SELECT:.*]] = mhlo.select %[[BCAST]], %arg1, %arg2 : tensor, tensor // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc index f5364586ec73c9..6db7ccec27c710 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc @@ -154,6 +154,24 @@ mlir::LogicalResult ApplyPatterns(Operation *op, RewritePatternSet &patterns, return result; } +mlir::LogicalResult StablehloToMhlo(Operation *op) { + RewritePatternSet patterns(op->getContext()); + stablehlo::StablehloToHloTypeConverter converter; + stablehlo::populateStablehloToHloPatterns(&patterns, &converter, + op->getContext()); + ConversionTarget target(*op->getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + stablehlo::StablehloToHloTypeConverter shlo_converter; + stablehlo::populateStablehloToHloPatterns(&patterns, &shlo_converter, + patterns.getContext()); + stablehlo::registerFuncOpsForTypeConversion(target, patterns, shlo_converter); 
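+  // Partial conversion: ops outside the conversion target (anything that is
+  // not StableHLO) are left untouched; all remaining StableHLO ops are
+  // rewritten to MHLO.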
+ if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + return op->emitError("TF2XLA failed to convert StableHLO to MHLO"); + } + return success(); +} + /// When `tf2xla_fallback_device_type` is not `None`, also uses legalization /// patterns from TF2XLA fallback for provided device type (see /// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is @@ -216,7 +234,12 @@ LogicalResult legalizeTF(Operation *op, bool legalize_chlo, // canonicalization pattern to pattern list to enable multi-hop lowering. chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context); - return ApplyPatterns(op, patterns, legalize_chlo); + if (failed(ApplyPatterns(op, patterns, legalize_chlo))) { + return failure(); + } + + // HLO->MLIR raises to StableHLO, but users of this pass expect MHLO. + return StablehloToMhlo(op); } // Performs the lowering to XLA dialect. diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 90ca4ec1d0d819..a6f4b05d72cedb 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -46,4 +46,31 @@ diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehl return %cst : tensor<14x15x0x9xcomplex> } } +diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp +--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp ++++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp +@@ -1539,8 +1539,8 @@ + + void populateStablehloHloImportCanonicalizationPatterns( + MLIRContext *context, RewritePatternSet *patterns) { +- patterns->add( +- context); ++ patterns->add(context); + } + + std::unique_ptr createStablehloAggressiveSimplificationPass( +diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td ++++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +@@ -366,7 +366,8 @@ + (StableHLO_ReshapeOpWithShape $reshape, $operand)>; + + // Pattern: reshape(X, [X.shape]) -> X +-def : Pat<(StableHLO_ReshapeOp:$reshape $operand), ++def ReshapeIsNoop ++ : Pat<(StableHLO_ReshapeOp:$reshape $operand), + (replaceWithValue $operand), + [(TypesEqual $reshape, $operand)]>; + diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 90ca4ec1d0d819..a6f4b05d72cedb 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -46,4 +46,31 @@ diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehl return %cst : tensor<14x15x0x9xcomplex> } } +diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp +--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp ++++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp +@@ -1539,8 +1539,8 @@ + + void populateStablehloHloImportCanonicalizationPatterns( + MLIRContext *context, RewritePatternSet *patterns) { +- patterns->add( +- context); ++ 
patterns->add(context); + } + + std::unique_ptr createStablehloAggressiveSimplificationPass( +diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td ++++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +@@ -366,7 +366,8 @@ + (StableHLO_ReshapeOpWithShape $reshape, $operand)>; + + // Pattern: reshape(X, [X.shape]) -> X +-def : Pat<(StableHLO_ReshapeOp:$reshape $operand), ++def ReshapeIsNoop ++ : Pat<(StableHLO_ReshapeOp:$reshape $operand), + (replaceWithValue $operand), + [(TypesEqual $reshape, $operand)]>; + diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD index d7a4ecc61d5e98..d62735b16d7a5a 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -26,6 +26,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@stablehlo//:stablehlo_ops", ], ) @@ -119,6 +120,7 @@ cc_library( "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", "@stablehlo//:base", + "@stablehlo//:stablehlo_ops", ], ) @@ -137,6 +139,7 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", + "//xla/mlir_hlo:mhlo_passes", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", @@ -144,7 +147,9 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc index 44d298bd69f5a8..eef67b1c3ef9da 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -37,7 +37,6 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/LogicalResult.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -46,17 +45,21 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Region.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" #include "stablehlo/dialect/Base.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -141,6 +144,36 @@ ArrayRef FlattenTupleSharding(const HloSharding& sharding) { return sharding; } +// Returns true if changed. 
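+// Folds get_tuple_element(tuple(x0, ..., xn), i) -> xi for the MHLO and
+// StableHLO variants of both ops, erasing the folded op on success.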
+bool FoldGetTupleElementOfTuple(Operation* op) { + int64_t idx; + if (auto getTupleElementOp = + llvm::dyn_cast(op)) { + idx = getTupleElementOp.getIndex(); + } else if (auto getTupleElementOp = + llvm::dyn_cast(op)) { + idx = getTupleElementOp.getIndex(); + } else { + llvm::report_fatal_error("Unexpected op for tuple folding: " + + op->getName().getStringRef()); + } + + if (auto tupleOp = op->getOperand(0).getDefiningOp()) { + llvm::SmallVector new_operand{tupleOp.getOperand(idx)}; + op->replaceAllUsesWith(new_operand); + op->erase(); + return true; + } + if (auto tupleOp = + op->getOperand(0).getDefiningOp()) { + llvm::SmallVector new_operand{tupleOp.getOperand(idx)}; + op->replaceAllUsesWith(new_operand); + op->erase(); + return true; + } + return false; +} + // Clean up the GetTupleElementOp, created during the flattening of // tuple arguments and return values, if eligible for folding. Removal of // get-tuple-element can transitively make the defining TupleOp dead to be @@ -152,13 +185,10 @@ void CleanUpTupleOps(mlir::Block* block, mlir::OpBuilder* builder) { while (changed) { changed = false; for (Operation& op : llvm::make_early_inc_range(block->getOperations())) { - if (llvm::isa(op)) { - folded_results.clear(); - if (failed(builder->tryFold(&op, folded_results))) continue; - op.replaceAllUsesWith(folded_results); - op.erase(); - changed = true; - } else if (llvm::isa(op) && + if (llvm::isa(op)) { + changed = FoldGetTupleElementOfTuple(&op); + } else if (llvm::isa(op) && mlir::isOpTriviallyDead(&op)) { op.erase(); changed = true; @@ -167,12 +197,62 @@ void CleanUpTupleOps(mlir::Block* block, mlir::OpBuilder* builder) { } } -Operation* createReturnOp(mlir::OpBuilder& builder, mlir::Location loc, - mlir::ValueRange operands, bool is_func) { - if (is_func) { - return builder.create(loc, operands); +llvm::ArrayRef ToArrayRef(absl::Span span) { + return llvm::ArrayRef(span.data(), span.size()); +} + +Operation* CreateReturnOp(mlir::OpBuilder& builder, mlir::Location loc, + mlir::ValueRange operands, + mlir::Dialect* parent_dialect) { + LLVM_DEBUG(llvm::dbgs() << "CreateReturnOp: " + << parent_dialect->getNamespace() << '\n'); + if (llvm::isa(parent_dialect)) { + return builder.create(loc, operands); + } + if (llvm::isa(parent_dialect)) { + return builder.create(loc, operands); + } + return builder.create(loc, operands); +} + +bool HasMhloTokenType(mlir::TypeRange types) { + bool use_mhlo = false; + for (auto type : types) { + if (!use_mhlo) { + type.walk([&](Type type) { + use_mhlo |= llvm::isa(type); + if (use_mhlo) return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); + }); + } + } + return use_mhlo; +} + +Operation* WrapInTuple(mlir::OpBuilder* builder, Operation* op) { + // TODO(b/408024772) ToStablehlo: Make StableHLO only once tokens migrated. + + if (HasMhloTokenType(op->getResultTypes())) { + return builder->create(op->getLoc(), op->getResults()); } - return builder.create(loc, operands); + LLVM_DEBUG(llvm::dbgs() << "WrapInTuple: " << op->getName() + << op->getResultTypes() << '\n'); + return builder->create(op->getLoc(), + op->getResults()); +} + +Operation* GetTupleElementOp(mlir::OpBuilder* builder, Value value, + int64_t index, + llvm::SmallVector&& attributes) { + // TODO(b/408024772) ToStablehlo: Inline once tokens migrated. 
+ attributes.push_back( + builder->getNamedAttr("index", builder->getI32IntegerAttr(index))); + if (HasMhloTokenType(value.getType())) { + return builder->create( + value.getLoc(), value, builder->getI32IntegerAttr(index)); + } + return builder->create( + value.getLoc(), value, builder->getI32IntegerAttr(index)); } // Creates an array of zeros like the given MLIR type, if type has bounded @@ -197,7 +277,7 @@ absl::StatusOr createConstantZeroLike(mlir::Value operand, << type << '\n'); if (type.hasStaticShape()) return builder - ->create(loc, builder->getZeroAttr(type)) + ->create(loc, builder->getZeroAttr(type)) ->getResult(0); // Note: Currently this only supports a single bounded dimension. @@ -215,17 +295,17 @@ absl::StatusOr createConstantZeroLike(mlir::Value operand, input_shape.dimensions().end()); auto padded_type = mlir::RankedTensorType::get(padded_dims, type.getElementType()); - auto padded_constant = builder->create( + auto padded_constant = builder->create( loc, builder->getZeroAttr(padded_type)); // Get or Set the dimensions size based on the operand type. - auto dim_size = builder->create( + auto dim_size = builder->create( loc, operand, builder->getI64IntegerAttr(bounded_dim)); std::vector operands = {padded_constant->getResult(0), dim_size}; std::vector attributes{builder->getNamedAttr( "dimension", builder->getI64IntegerAttr(bounded_dim))}; - return builder->create(loc, type, operands, - attributes); + return builder->create( + loc, type, operands, attributes); } } // namespace @@ -233,7 +313,9 @@ absl::StatusOr createConstantZeroLike(mlir::Value operand, void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands( mlir::Operation* op, llvm::ArrayRef implicit_operands) { assert((mlir::dyn_cast(*op) || - mlir::dyn_cast(*op)) && + mlir::dyn_cast(*op) || + mlir::dyn_cast(*op) || + mlir::dyn_cast(*op)) && "Unexpected mlir op in " "HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands!"); @@ -251,7 +333,8 @@ static bool IsNestedTupleInData(Type type) { auto tuple_type = type.dyn_cast(); if (!tuple_type) return false; - assert(tuple_type.getType(1).isa() && + assert((llvm::isa(tuple_type.getType(1)) || + llvm::isa(tuple_type.getType(1))) && "Infeed: Non token type"); auto data_type = tuple_type.getType(0); @@ -291,17 +374,16 @@ void HloFunctionImporter::FlattenTupleType( void HloFunctionImporter::FlattenTupleValue( mlir::OpBuilder* func_builder, mlir::Location loc, Value value, llvm::SmallVectorImpl& flattened_values) { - auto tuple_type = value.getType().dyn_cast(); + auto tuple_type = llvm::dyn_cast(value.getType()); if (!tuple_type) { flattened_values.push_back(value); return; } - int flattenIdx = 0; - for (auto child_type : tuple_type.getTypes()) { - auto sub_value = func_builder->create( - loc, child_type, value, func_builder->getI32IntegerAttr(flattenIdx++)); - FlattenTupleValue(func_builder, loc, sub_value, flattened_values); + for (int64_t flattenIdx = 0; flattenIdx < tuple_type.size(); ++flattenIdx) { + auto sub_value = GetTupleElementOp(func_builder, value, flattenIdx, {}); + FlattenTupleValue(func_builder, loc, sub_value->getResult(0), + flattened_values); } } @@ -424,7 +506,7 @@ absl::StatusOr HloFunctionImporter::ImportAsFunc( // NOTE: since we are flattening args, all arguments will share the same // location as the tuple parameter instruction. 
function.getArgument(arg_index).setLoc( - mlir::mhlo::GenerateInstructionLocation(instruction, context_)); + mlir::hlo::GenerateInstructionLocation(instruction, context_)); ++arg_index; } } else { @@ -452,7 +534,7 @@ absl::StatusOr HloFunctionImporter::ImportAsFunc( } } function.getArgument(arg_index).setLoc( - mlir::mhlo::GenerateInstructionLocation(instruction, context_)); + mlir::hlo::GenerateInstructionLocation(instruction, context_)); ++arg_index; } } @@ -587,7 +669,7 @@ absl::Status HloFunctionImporter::ImportInstructions( continue; } - // For each tuple-typed computation parameter, create a mhlo::TupleOp + // For each tuple-typed computation parameter, create a TupleOp // value in the region body, using the already flattened values in // 'arguments'. For example: With computation parameters: [tuple, // tuple] We have, 'arguments' = [T1 arg1, T2 arg2, T3 arg3] and @@ -600,6 +682,7 @@ absl::Status HloFunctionImporter::ImportInstructions( arguments.begin() + flatten_idx, arguments.begin() + flatten_idx + flattened_arg_type.size())); + // TODO(b/408024772) ToStablehlo: CreateTupleValue auto tupleVal = CreateTupleValue(&builder, loc, sub_args, orig_tuple_arg_type); effective_arguments.push_back(tupleVal); @@ -620,9 +703,10 @@ absl::Status HloFunctionImporter::ImportInstructions( // Flatten tuples in results of this region. llvm::SmallVector flattened_return_operands; FlattenTupleValue(&builder, loc, result, flattened_return_operands); - createReturnOp(builder, loc, flattened_return_operands, is_func); + CreateReturnOp(builder, loc, flattened_return_operands, + block->getParentOp()->getDialect()); } else { - createReturnOp(builder, loc, result, is_func); + CreateReturnOp(builder, loc, result, block->getParentOp()->getDialect()); } CleanUpTupleOps(block, &builder); @@ -670,7 +754,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( : instruction_shape; TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType(shape, *builder_)); - mlir::Location loc = mlir::mhlo::GenerateInstructionLocation( + mlir::Location loc = mlir::hlo::GenerateInstructionLocation( instruction, func_builder->getContext()); llvm::SmallVector attributes; @@ -707,7 +791,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_); if (!attr.ok()) return attr.status(); mlir::Operation* new_operation = - func_builder->create(loc, attr.value()); + func_builder->create(loc, attr.value()); for (auto attr : attributes) { new_operation->setAttr(attr.getName(), attr.getValue()); } @@ -715,7 +799,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kIota: { return func_builder - ->create( + ->create( loc, result_type, func_builder->getI64IntegerAttr( Cast(instruction)->iota_dimension())) @@ -741,6 +825,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( if (instruction->opcode() == HloOpcode::kAsyncStart) { auto bundle_result_type = mlir::mhlo::AsyncBundleType::get( context_, result_type.cast().getTypes()); + // XLA Feature -- MHLO Only return func_builder ->create(loc, bundle_result_type, operands, attributes) @@ -748,12 +833,14 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } else if (instruction->opcode() == HloOpcode::kAsyncUpdate) { auto bundle_result_type = mlir::mhlo::AsyncBundleType::get( context_, result_type.cast().getTypes()); + // XLA Feature -- MHLO Only return func_builder ->create(loc, bundle_result_type, operands, attributes) .getOperation(); } else { assert(instruction->opcode() 
== HloOpcode::kAsyncDone); + // XLA Feature -- MHLO Only return func_builder ->create(loc, result_type, operands, attributes) @@ -763,12 +850,12 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( case HloOpcode::kBroadcast: { // Note that the HLO broadcast is more powerful than the XLA broadcast // op. BroadcastInDim offers a superset of the HLO op's functionality. - attributes.push_back( - builder_->getNamedAttr("broadcast_dimensions", - ConvertDimensions(instruction->dimensions()))); + attributes.push_back(builder_->getNamedAttr( + "broadcast_dimensions", + ConvertArray(ToArrayRef(instruction->dimensions())))); return func_builder - ->create(loc, result_type, operands, - attributes) + ->create(loc, result_type, + operands, attributes) .getOperation(); } @@ -786,15 +873,15 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( FlattenTupleType(result_type, flattened_ret_types); auto op = func_builder - ->create( + ->create( loc, flattened_ret_types, operands, attributes) .getOperation(); return CreateTupleFromOpResults(func_builder, loc, op, result_type); } else if (instruction->opcode() == HloOpcode::kBatchNormInference) { return func_builder - ->create(loc, result_type, - operands, attributes) + ->create( + loc, result_type, operands, attributes) .getOperation(); } else { assert(instruction->opcode() == HloOpcode::kBatchNormTraining); @@ -804,7 +891,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( FlattenTupleType(result_type, flattened_ret_types); auto op = func_builder - ->create( + ->create( loc, flattened_ret_types, operands, attributes) .getOperation(); @@ -825,6 +912,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } // Consider consolidating DotOps together. if (DotIsDefault(instruction) && !dot->sparse_operands()) { + // TODO(b/408024772) ToStablehlo: Convert[PrecisionConfig|DotAlgorithm] return func_builder ->create(loc, result_type, operands, attributes) .getOperation(); @@ -835,6 +923,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(), builder_))); if (!dot->sparse_operands()) { + // TODO(b/408024772) ToStablehlo: ConvertDotDimensionNumbers return func_builder ->create(loc, result_type, operands, attributes) @@ -848,6 +937,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( descriptor.index() == 0 ? 
"lhs_sparsity" : "rhs_sparsity", sparsity)); } + // XLA Feature -- MHLO Only return func_builder ->create(loc, result_type, operands, attributes) @@ -869,12 +959,12 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "call_target_name", builder_->getStringAttr("ragged_all_to_all"))); attributes.push_back(builder_->getNamedAttr( "api_version", - mlir::mhlo::CustomCallApiVersionAttr::get( + mlir::stablehlo::CustomCallApiVersionAttr::get( builder_->getContext(), - mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI))); + mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI))); return func_builder - ->create(loc, result_type, operands, - attributes) + ->create(loc, result_type, operands, + attributes) .getOperation(); } case HloOpcode::kRaggedDot: { @@ -885,6 +975,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "ragged_dot_dimension_numbers", ConvertRaggedDotDimensionNumbers( instruction->ragged_dot_dimension_numbers(), builder_))); + // XLA Feature -- MHLO Only return func_builder ->create(loc, result_type, operands, attributes) @@ -946,7 +1037,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( std::stoi( frontend_attributes_map.find("composite.version")->second)); - new_operation = func_builder->create( + new_operation = func_builder->create( loc, result_type, operands); new_operation->setAttr("name", name); new_operation->setAttr("composite_attributes", composite_attributes); @@ -1005,6 +1096,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( if (collective_broadcast->channel_id().has_value()) attributes.push_back(ConvertChannelHandle( collective_broadcast->channel_id().value(), builder_)); + // TODO(b/408024772) ToStablehlo: ConvertChannelHandle return func_builder ->create(loc, result_type, operands, attributes) @@ -1018,6 +1110,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( if (collective_permute->channel_id().has_value()) attributes.push_back(ConvertChannelHandle( collective_permute->channel_id().value(), builder_)); + // TODO(b/408024772) ToStablehlo: ConvertChannelHandle return func_builder ->create(loc, result_type, operands, attributes) @@ -1123,6 +1216,9 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "output_operand_aliases", ConvertOutputOperandAliasing(instruction->output_operand_aliasing(), builder_))); + + // TODO(b/408024772) ToStablehlo: Special handling needed for CC schedules + // that aren't NONE and convert CC attrs to StableHLO return func_builder ->create(loc, result_type, operands, attributes) @@ -1135,6 +1231,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( compare->operand(0)->shape().element_type()); if (compare->type() != default_type) attributes.push_back(ConvertComparisonType(compare->type())); + // TODO(b/408024772) ToStableHLO: ConvertComparison[Direction|Type] return func_builder ->create(loc, result_type, operands, attributes) @@ -1145,8 +1242,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "lower", builder_->getBoolAttr(instruction->cholesky_options().lower()))); return func_builder - ->create(loc, result_type, operands, - attributes) + ->create(loc, result_type, operands, + attributes) .getOperation(); } case HloOpcode::kGather: { @@ -1165,6 +1262,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "indices_are_sorted", builder_->getBoolAttr(gather_instruction->indices_are_sorted()))); + // TODO(b/408024772) ToStableHLO: ConvertGatherDimensionNumbers return func_builder ->create(loc, result_type, operands, attributes) 
.getOperation(); @@ -1174,14 +1272,14 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( instruction->dynamic_slice_sizes().begin(), instruction->dynamic_slice_sizes().end()); return func_builder - ->create( + ->create( loc, result_type, operands[0], - llvm::ArrayRef(operands).drop_front(), Convert(slice_sizes)) + llvm::ArrayRef(operands).drop_front(), ConvertArray(slice_sizes)) .getOperation(); } case HloOpcode::kDynamicUpdateSlice: { return func_builder - ->create( + ->create( loc, result_type, operands[0], operands[1], llvm::ArrayRef(operands.begin() + 2, operands.end())) .getOperation(); @@ -1207,6 +1305,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( llvm::SmallVector flattened_ret_types; FlattenTupleType(result_type, flattened_ret_types); + // TODO(b/408024772) ToStableHLO: result_type needs to be StableHLO Token auto op = func_builder->create( loc, flattened_ret_types, operands, attributes); @@ -1226,6 +1325,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( FlattenTupleValue(func_builder, loc, operands[0], flattened_operands); flattened_operands.push_back(operands[1]); + // TODO(b/408024772) ToStableHLO: result_type needs to be StableHLO Token auto op = func_builder->create( loc, result_type, flattened_operands, attributes); @@ -1247,10 +1347,10 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } return func_builder - ->create(loc, result_type, operands[0], - operands[1], Convert(edge_padding_low), - Convert(edge_padding_high), - Convert(interior_padding)) + ->create( + loc, result_type, operands[0], operands[1], + ConvertArray(edge_padding_low), ConvertArray(edge_padding_high), + ConvertArray(interior_padding)) .getOperation(); } case HloOpcode::kScatter: { @@ -1268,6 +1368,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( llvm::SmallVector flattened_types; FlattenTupleType(result_type, flattened_types); + // TODO(b/408024772) ToStableHLO: ConvertScatterDimensionNumbers auto scatter_op = func_builder->create( loc, flattened_types, operands, attributes); TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(), @@ -1288,13 +1389,13 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( padding.push_back(dim.padding_low()); padding.push_back(dim.padding_high()); } - attributes.push_back( - builder_->getNamedAttr("window_strides", Convert(window_strides))); - attributes.push_back(builder_->getNamedAttr("window_dimensions", - Convert(window_dimensions))); + attributes.push_back(builder_->getNamedAttr( + "window_strides", ConvertArray(window_strides))); + attributes.push_back(builder_->getNamedAttr( + "window_dimensions", ConvertArray(window_dimensions))); attributes.push_back(ConvertPadding(padding)); auto select_scatter_op = - func_builder->create( + func_builder->create( loc, result_type, operands, attributes); TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(), &select_scatter_op.getSelect())); @@ -1306,17 +1407,17 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( attributes.push_back(builder_->getNamedAttr( "dimension", builder_->getI64IntegerAttr(instruction->dimension()))); return func_builder - ->create(loc, result_type, operands, - attributes) + ->create(loc, result_type, + operands, attributes) .getOperation(); } case HloOpcode::kSlice: { return func_builder - ->create( + ->create( loc, result_type, operands[0], - ConvertDimensions(instruction->slice_starts()), - ConvertDimensions(instruction->slice_limits()), - ConvertDimensions(instruction->slice_strides())) + 
ConvertArray(instruction->slice_starts()), + ConvertArray(instruction->slice_limits()), + ConvertArray(instruction->slice_strides())) .getOperation(); } case HloOpcode::kSort: { @@ -1327,7 +1428,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( return_types = llvm::to_vector<6>(tuple_ty.getTypes()); } - auto sort_op = func_builder->create( + auto sort_op = func_builder->create( loc, return_types, operands, builder_->getI64IntegerAttr(sort_instruction->sort_dimension()), builder_->getBoolAttr(sort_instruction->is_stable())); @@ -1339,19 +1440,16 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( return sort_op.getOperation(); } - return func_builder - ->create(loc, result_type, sort_op.getResults()) - .getOperation(); + return WrapInTuple(func_builder, sort_op); } case HloOpcode::kTopK: { auto topk_instruction = Cast(instruction); + // XLA Feature -- MHLO Only auto topk_op = func_builder->create( loc, result_type.dyn_cast().getTypes(), operands[0], builder_->getI64IntegerAttr(topk_instruction->k()), builder_->getBoolAttr(topk_instruction->largest())); - return func_builder - ->create(loc, result_type, topk_op.getResults()) - .getOperation(); + return WrapInTuple(func_builder, topk_op); } case HloOpcode::kCopyStart: { return ImportCopyStart(instruction, loc, operands, attributes, @@ -1400,7 +1498,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( assert(rets.size() == 1); FlattenTupleType(rets[0], flattened_ret_types); - auto op = func_builder->create( + auto op = func_builder->create( loc, flattened_ret_types, flattened_operands[0], attributes); TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(), &op.getTrueBranch())); @@ -1427,7 +1525,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( FlattenTupleType(rets[0], flattened_ret_types); int num_branches = instruction->branch_count(); - auto op = func_builder->create( + auto op = func_builder->create( loc, flattened_ret_types, flattened_operands[0], attributes, num_branches); for (const auto& index_and_computation : @@ -1448,7 +1546,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kConcatenate: { return func_builder - ->create( + ->create( loc, result_type, operands, builder_->getI64IntegerAttr(instruction->concatenate_dimension())) .getOperation(); @@ -1471,13 +1569,11 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( ConvertChannelHandle(all_gather->channel_id().value(), builder_)); if (all_gather->use_global_device_ids()) attributes.push_back(ConvertUseGlobalDeviceIds(builder_)); + // TODO(b/408024772) ToStablehlo: ConvertChannelHandleToStablehlo auto all_gather_op = func_builder->create( loc, result_types, operands, attributes); if (result_tuple_ty) { - return func_builder - ->create(loc, result_type, - all_gather_op.getResults()) - .getOperation(); + return WrapInTuple(func_builder, all_gather_op); } return all_gather_op.getOperation(); } @@ -1505,19 +1601,19 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( ConvertChannelHandle(all_reduce->channel_id().value(), builder_)); if (all_reduce->use_global_device_ids()) attributes.push_back(ConvertUseGlobalDeviceIds(builder_)); + // TODO(b/408024772) ToStablehlo: ConvertChannelHandleToStablehlo auto all_reduce_op = func_builder->create( loc, result_types, operands, attributes); TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(), &all_reduce_op.getComputation())); if (result_tuple_ty) { - return func_builder - ->create(loc, result_type, - all_reduce_op.getResults()) - 
.getOperation();
+        return WrapInTuple(func_builder, all_reduce_op);
       }
       return all_reduce_op.getOperation();
     }
     case HloOpcode::kAllReduceStart: {
+      // TODO(b/408024772) ToStablehlo: Special handling needed for
+      // AllReduceStart.
       auto appendRegion = [&](mlir::mhlo::AllReduceOp all_reduce_sync) {
         TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
                                           &all_reduce_sync.getComputation()));
@@ -1529,6 +1625,8 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
                                symbol_table_);
     }
     case HloOpcode::kAllReduceDone: {
+      // TODO(b/408024772) ToStablehlo: Special handling needed for
+      // AllReduceDone.
       return ImportAsyncOpDone(instruction, loc, operands, attributes,
                                result_type, func_builder);
     }
@@ -1562,6 +1660,9 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
         return_types = llvm::to_vector<4>(result_tuple_ty.getTypes());
       }
+      // TODO(b/408024772) ToStablehlo: mhlo AllToAll and StableHLO AllToAll
+      // currently differ for multiple arguments; we need to fix this.
+      // Additionally, ConvertChannelHandle changes are needed.
       auto result = func_builder->create<mlir::mhlo::AllToAllOp>(
           loc, return_types, operands, nullptr, nullptr, nullptr,
           replica_groups_attr);
@@ -1574,9 +1675,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       }
       if (result_tuple_ty) {
-        return func_builder
-            ->create<mlir::mhlo::TupleOp>(loc, result_type, result.getResults())
-            .getOperation();
+        return WrapInTuple(func_builder, result);
       }
       result.setSplitDimension(all_to_all->split_dimension().value());
@@ -1593,6 +1692,8 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
         return_types = llvm::to_vector<6>(tuple_ty.getTypes());
       }
+      // TODO(b/408024772) ToStablehlo: ReduceOp builder is different in
+      // StableHLO
       auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
           loc, return_types, llvm::ArrayRef(operands).take_front(num_inputs),
           llvm::ArrayRef(operands).drop_front(num_inputs),
@@ -1608,11 +1709,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
         return reduce.getOperation();
       }
-      mlir::Operation* operation =
-          func_builder
-              ->create<mlir::mhlo::TupleOp>(loc, result_type,
-                                            reduce.getResults())
-              .getOperation();
+      mlir::Operation* operation = WrapInTuple(func_builder, reduce);
       for (auto attr : attributes) {
         operation->setAttr(attr.getName(), attr.getValue());
       }
@@ -1620,27 +1717,27 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
     }
     case HloOpcode::kReverse: {
       return func_builder
-          ->create<mlir::mhlo::ReverseOp>(
+          ->create<mlir::stablehlo::ReverseOp>(
               loc, result_type, operands[0],
-              ConvertDimensions(instruction->dimensions()))
+              ConvertArray(ToArrayRef(instruction->dimensions())))
           .getOperation();
     }
     case HloOpcode::kRng: {
-      auto shape = func_builder->create<mlir::mhlo::ConstantOp>(
+      auto shape = func_builder->create<mlir::stablehlo::ConstantOp>(
           loc, Convert(result_type.cast<mlir::RankedTensorType>().getShape()));
       switch (instruction->random_distribution()) {
         case RNG_UNIFORM:
           return func_builder
-              ->create<mlir::mhlo::RngOp>(
+              ->create<mlir::stablehlo::RngOp>(
                   loc, result_type, operands[0], operands[1], shape,
-                  ::mlir::mhlo::RngDistribution::UNIFORM)
+                  ::mlir::stablehlo::RngDistribution::UNIFORM)
               .getOperation();
         case RNG_NORMAL:
           return func_builder
-              ->create<mlir::mhlo::RngOp>(loc, result_type, operands[0],
-                                          operands[1], shape,
-                                          ::mlir::mhlo::RngDistribution::NORMAL)
+              ->create<mlir::stablehlo::RngOp>(
+                  loc, result_type, operands[0], operands[1], shape,
+                  ::mlir::stablehlo::RngDistribution::NORMAL)
               .getOperation();
         default:
@@ -1652,13 +1749,13 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
     case HloOpcode::kRngBitGenerator: {
       // HloRngBitGeneratorInstruction can have two kinds of shapes, (1)
       // tuple(output_state, output_data), and (2) output_data.
- // mhlo::RngBitGeneratorOp has only one shape, (output_state, + // stablehlo::RngBitGeneratorOp has only one shape, (output_state, // output_data). auto rng_op = Cast(instruction); - auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get( + auto algorithm_attr = mlir::stablehlo::RngAlgorithmAttr::get( builder_->getContext(), - *mlir::mhlo::symbolizeRngAlgorithm(rng_op->algorithm())); + *mlir::stablehlo::symbolizeRngAlgorithm(rng_op->algorithm())); attributes.push_back( builder_->getNamedAttr("rng_algorithm", algorithm_attr)); @@ -1683,15 +1780,17 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } CHECK_EQ(flattened_ret_types.size(), 2); - auto op = func_builder->create( + auto op = func_builder->create( loc, flattened_ret_types, operands[0], attributes); if (rng_op->shape().IsArray()) { return op.getOperation(); } + // TODO(b/408024772) ToStablehlo: CreateTupleFromOpResults return CreateTupleFromOpResults(func_builder, loc, op.getOperation(), result_type); } case HloOpcode::kRngGetAndUpdateState: { + // XLA Feature -- MHLO Only return func_builder ->create( loc, result_type, @@ -1706,7 +1805,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( FlattenTupleType(operands[0].getType(), flattened_operand_types); FlattenTupleValue(func_builder, loc, operands[0], flattened_operands); - auto op = func_builder->create( + auto op = func_builder->create( loc, flattened_operand_types, flattened_operands); TF_RETURN_IF_ERROR( @@ -1717,28 +1816,24 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( operands[0].getType()); } case HloOpcode::kGetTupleElement: { - attributes.push_back(builder_->getNamedAttr( - "index", builder_->getIntegerAttr(builder_->getIntegerType(32), - instruction->tuple_index()))); - return func_builder - ->create(loc, result_type, operands, - attributes) - .getOperation(); + return GetTupleElementOp(func_builder, operands[0], + instruction->tuple_index(), + std::move(attributes)); }; case HloOpcode::kGetDimensionSize: { attributes.push_back(builder_->getNamedAttr( "dimension", builder_->getI64IntegerAttr(instruction->dimension()))); return func_builder - ->create(loc, result_type, operands, - attributes) + ->create(loc, result_type, + operands, attributes) .getOperation(); }; case HloOpcode::kTranspose: { attributes.push_back(builder_->getNamedAttr( - "permutation", ConvertDimensions(instruction->dimensions()))); + "permutation", ConvertArray(ToArrayRef(instruction->dimensions())))); return func_builder - ->create(loc, result_type, operands, - attributes) + ->create(loc, result_type, operands, + attributes) .getOperation(); } case HloOpcode::kTriangularSolve: { @@ -1753,17 +1848,17 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "unit_diagonal", builder_->getBoolAttr( instruction->triangular_solve_options().unit_diagonal()))); - auto transpose_a = mlir::mhlo::TransposeAttr::get( + auto transpose_a = mlir::stablehlo::TransposeAttr::get( builder_->getContext(), - mlir::mhlo::symbolizeTranspose( + mlir::stablehlo::symbolizeTranspose( TriangularSolveOptions::Transpose_Name( instruction->triangular_solve_options().transpose_a())) .value()); attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a)); return func_builder - ->create(loc, result_type, operands, - attributes) + ->create(loc, result_type, + operands, attributes) .getOperation(); } case HloOpcode::kReduceScatter: { @@ -1778,6 +1873,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( reduce_scatter->channel_id().value(), builder_)); if 
(reduce_scatter->use_global_device_ids()) attributes.push_back(ConvertUseGlobalDeviceIds(builder_)); + // TODO(b/408024772) ToStablehlo: ConvertChannelHandle auto reduce_scatter_op = func_builder->create( loc, result_type, operands, attributes); @@ -1802,16 +1898,16 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( padding.push_back(dim.padding_low()); padding.push_back(dim.padding_high()); } - attributes.push_back(builder_->getNamedAttr("window_dimensions", - ConvertDimensions(sizes))); attributes.push_back( - builder_->getNamedAttr("window_strides", ConvertDimensions(strides))); - attributes.push_back(builder_->getNamedAttr( - "base_dilations", ConvertDimensions(base_dilations))); + builder_->getNamedAttr("window_dimensions", ConvertArray(sizes))); + attributes.push_back( + builder_->getNamedAttr("window_strides", ConvertArray(strides))); attributes.push_back(builder_->getNamedAttr( - "window_dilations", ConvertDimensions(win_dilations))); + "base_dilations", ConvertArray(base_dilations))); + attributes.push_back(builder_->getNamedAttr("window_dilations", + ConvertArray(win_dilations))); attributes.push_back(ConvertPadding(padding)); - auto reduce = func_builder->create( + auto reduce = func_builder->create( loc, return_types, operands, attributes); TF_RETURN_IF_ERROR( ImportAsRegion(*instruction->to_apply(), &reduce.getBody())); @@ -1821,14 +1917,12 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( return reduce.getOperation(); } - return func_builder - ->create(loc, result_type, reduce.getResults()) - .getOperation(); + return WrapInTuple(func_builder, reduce); } case HloOpcode::kMap: { - auto op = func_builder->create( + auto op = func_builder->create( loc, result_type, operands, - ConvertDimensions(instruction->dimensions())); + ConvertArray(ToArrayRef(instruction->dimensions()))); TF_RETURN_IF_ERROR( ImportAsRegion(*instruction->to_apply(), &op.getComputation())); return op.getOperation(); @@ -1882,14 +1976,14 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( auto convert_op_return_type = mlir::cast(lhs.getType()) .clone(mlir::getElementTypeOrSelf(rhs)); - lhs = func_builder->create( + lhs = func_builder->create( loc, convert_op_return_type, lhs); } else if (primitive_util::CastPreservesValues(rhs_element_type, lhs_element_type)) { auto convert_op_return_type = mlir::cast(rhs.getType()) .clone(mlir::getElementTypeOrSelf(lhs)); - rhs = func_builder->create( + rhs = func_builder->create( loc, convert_op_return_type, rhs); } else { return InvalidArgument( @@ -1898,6 +1992,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( instruction->operand(0)->shape().ToString(), instruction->operand(1)->shape().ToString()); } + // TODO(b/408024772) ToStablehlo: ConvertPrecisionConfig, + // ConvertConvDimensionNumbers return func_builder ->create( loc, result_type, std::vector{lhs, rhs}, @@ -1905,6 +2001,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( .getOperation(); } + // TODO(b/408024772) ToStablehlo: ConvertPrecisionConfig, + // ConvertConvDimensionNumbers return func_builder ->create(loc, result_type, operands, attributes) @@ -1912,19 +2010,20 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kFft: { - auto fft_type = mlir::mhlo::FftTypeAttr::get( - builder_->getContext(), - mlir::mhlo::symbolizeFftType(FftType_Name(instruction->fft_type())) - .value()); + auto fft_type = mlir::stablehlo::FftTypeAttr::get( + builder_->getContext(), mlir::stablehlo::symbolizeFftType( + 
FftType_Name(instruction->fft_type()))
+              .value());
       std::vector<int64_t> fft_length(instruction->fft_length().begin(),
                                       instruction->fft_length().end());
       attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
       attributes.push_back(
-          builder_->getNamedAttr("fft_length", Convert(fft_length)));
+          builder_->getNamedAttr("fft_length", ConvertArray(fft_length)));
       return func_builder
-          ->create<mlir::mhlo::FftOp>(loc, result_type, operands, attributes)
+          ->create<mlir::stablehlo::FftOp>(loc, result_type, operands,
+                                           attributes)
           .getOperation();
     }
@@ -1935,11 +2034,13 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       // arith::OrIOp instead.
       if (instruction->shape().element_type() == PRED) {
         return func_builder
-            ->create<mlir::mhlo::OrOp>(loc, result_type, operands, attributes)
+            ->create<mlir::stablehlo::OrOp>(loc, result_type, operands,
+                                            attributes)
             .getOperation();
       } else {
         return func_builder
-            ->create<mlir::mhlo::AddOp>(loc, result_type, operands, attributes)
+            ->create<mlir::stablehlo::AddOp>(loc, result_type, operands,
+                                             attributes)
             .getOperation();
       }
     }
@@ -1947,11 +2048,16 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       // HLO AfterAll ops without any token input are used to just create a
       // token. MHLO has a special op CreateToken for this case.
       if (instruction->operands().empty()) {
+        // TODO(b/408024772) ToStablehlo: result_type needs to be a
+        // stablehlo::TokenType.
+        // Also, replace CreateTokenOp usage with AfterAllOp.
         return func_builder
             ->create<mlir::mhlo::CreateTokenOp>(loc, result_type, operands,
                                                 attributes)
             .getOperation();
       } else {
+        // TODO(b/408024772) ToStablehlo: result_type needs to be a
+        // stablehlo::TokenType.
         return func_builder
             ->create<mlir::mhlo::AfterAllOp>(loc, result_type, operands,
                                              attributes)
             .getOperation();
       }
     }
@@ -1969,7 +2075,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
               : nullptr;
       if (!integer_type || integer_type.getWidth() != 1) {
         // Simple case: 1-1 mapping.
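        // Illustrative sketch, assuming a hypothetical f32 -> f64 convert
        // instruction: the simple case below imports as a single op, e.g.
        //   %0 = stablehlo.convert %arg0 : (tensor<4xf32>) -> tensor<4xf64>
        // while the i1 special case further down imports as a NE compare
        // against a zero constant of the operand's shape.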
- return {func_builder->create( + return {func_builder->create( loc, result_type, operands, attributes)}; } @@ -1980,10 +2086,11 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( createConstantZeroLike(operands[0], input_shape, func_builder, loc)); std::vector compare_operands = {operands[0], zero}; std::vector attributes = {builder_->getNamedAttr( - "comparison_direction", mlir::mhlo::ComparisonDirectionAttr::get( - func_builder->getContext(), - mlir::mhlo::ComparisonDirection::NE))}; - return {func_builder->create( + "comparison_direction", + mlir::stablehlo::ComparisonDirectionAttr::get( + func_builder->getContext(), + mlir::stablehlo::ComparisonDirection::NE))}; + return {func_builder->create( loc, result_type, compare_operands, attributes)}; } case HloOpcode::kOptimizationBarrier: { @@ -1992,7 +2099,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( FlattenTupleType(operands[0].getType(), flattened_operand_types); FlattenTupleValue(func_builder, loc, operands[0], flattened_operands); - auto op = func_builder->create( + auto op = func_builder->create( loc, flattened_operand_types, flattened_operands); return CreateTupleFromOpResults(func_builder, loc, op.getOperation(), @@ -2027,23 +2134,24 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "entry_metadata", ConvertSharding(*(*entry_metadata)->sharding(), builder_))); + // XLA Feature -- MHLO Only return func_builder ->create(loc, result_type, operands, attributes) .getOperation(); } -#define NO_ATTRIBUTE_CASE(hlo_op_code, mlir_op) \ - case HloOpcode::hlo_op_code: { \ - return func_builder \ - ->create(loc, result_type, operands, attributes) \ - .getOperation(); \ +#define NO_ATTRIBUTE_CASE(hlo_op_code, mlir_op) \ + case HloOpcode::hlo_op_code: { \ + return func_builder \ + ->create(loc, result_type, operands, \ + attributes) \ + .getOperation(); \ } // broadcast dimensions are never added here because they don't exist as // part of the HLO instruction. They are only a convenience in the XLA // builder API. NO_ATTRIBUTE_CASE(kAbs, AbsOp); - NO_ATTRIBUTE_CASE(kAddDependency, AddDependencyOp); NO_ATTRIBUTE_CASE(kAnd, AndOp); NO_ATTRIBUTE_CASE(kAtan2, Atan2Op); NO_ATTRIBUTE_CASE(kBitcastConvert, BitcastConvertOp); @@ -2067,8 +2175,6 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( NO_ATTRIBUTE_CASE(kReal, RealOp); NO_ATTRIBUTE_CASE(kRemainder, RemOp); NO_ATTRIBUTE_CASE(kReplicaId, ReplicaIdOp); - NO_ATTRIBUTE_CASE(kStochasticConvert, StochasticConvertOp); - NO_ATTRIBUTE_CASE(kErf, ErfOp); // The dimensions attribute is not present on the HLO Reshape // instruction. If dimensions are non-default, the XLA builder // implements it as a separate transpose. @@ -2081,12 +2187,27 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( NO_ATTRIBUTE_CASE(kShiftRightLogical, ShiftRightLogicalOp); NO_ATTRIBUTE_CASE(kSign, SignOp); NO_ATTRIBUTE_CASE(kSubtract, SubtractOp); - NO_ATTRIBUTE_CASE(kTuple, TupleOp); NO_ATTRIBUTE_CASE(kXor, XorOp); - NO_ATTRIBUTE_CASE(kCopy, CopyOp); #undef NO_ATTRIBUTE_CASE +#define NO_ATTRIBUTE_CASE_MHLO(hlo_op_code, mlir_op) \ + case HloOpcode::hlo_op_code: { \ + return func_builder \ + ->create(loc, result_type, operands, attributes) \ + .getOperation(); \ + } + NO_ATTRIBUTE_CASE_MHLO(kAddDependency, AddDependencyOp); + NO_ATTRIBUTE_CASE_MHLO(kCopy, CopyOp); + NO_ATTRIBUTE_CASE_MHLO(kErf, ErfOp); + NO_ATTRIBUTE_CASE_MHLO(kStochasticConvert, StochasticConvertOp); + // TODO(b/408024772) ToStablehlo: Once all tokens are stablehlo.token move + // to NO_ATTRIBUTE_CASE. 
+ NO_ATTRIBUTE_CASE_MHLO(kTuple, TupleOp); + +#undef NO_ATTRIBUTE_CASE_MHLO + +// TODO(b/408024772) ToStablehlo: ConvertResultAccuracy #define RESULT_ACCURACY_CASE(hlo_op_code, mlir_op) \ case HloOpcode::hlo_op_code: { \ if (instruction->has_result_accuracy()) { \ @@ -2132,6 +2253,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "output_operand_aliasing", ConvertOutputOperandAliasing(instruction->output_operand_aliasing(), builder_))); + // XLA Feature -- MHLO Only auto fusion = func_builder->create( loc, flattened_ret_types, flattened_operands, attributes); TF_RETURN_IF_ERROR( @@ -2142,6 +2264,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( result_type); } case HloOpcode::kBitcast: { + // XLA Feature -- MHLO Only auto bitcast = func_builder->create( loc, result_type, operands, attributes); // Store the source and result layout as attributes. Although the MHLO @@ -2153,7 +2276,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( return bitcast.getOperation(); } case HloOpcode::kReducePrecision: { - auto op = func_builder->create( + auto op = func_builder->create( loc, result_type, operands[0], attributes); op.setExponentBitsAttr(func_builder->getIntegerAttr( func_builder->getI32Type(), instruction->exponent_bits())); @@ -2200,7 +2323,10 @@ HloFunctionImporter::ImportInstructionWithLayout( LLVM_DEBUG(llvm::dbgs() << " instruction skipped.\n"); return op; } - LLVM_DEBUG(llvm::dbgs() << " imported: " << *op << '\n'); + + // Print generic in debug since module may be invalid while printing. + LLVM_DEBUG( + op->print(llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm())); // See MlirToHloConversionOptions for more about layouts. // @@ -2284,6 +2410,11 @@ mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions( dimensions); } +mlir::DenseI64ArrayAttr HloFunctionImporter::ConvertArray( + llvm::ArrayRef elements) { + return builder_->getDenseI64ArrayAttr(elements); +} + mlir::DenseIntElementsAttr HloFunctionImporter::Convert( llvm::ArrayRef elements) { return DenseIntElementsAttr::get( diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h index ec0ffa87ad1fbc..386dff23fe83e6 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" @@ -208,6 +209,8 @@ class HloFunctionImporter { mlir::DenseIntElementsAttr ConvertDimensions( absl::Span op_dimensions); + mlir::DenseI64ArrayAttr ConvertArray(llvm::ArrayRef elements); + // Converts Array ref to an DenseIntElementsAttr. mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc index b0ae41e090912e..0a65ba01b4a3e3 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -19,17 +19,21 @@ limitations under the License. 
#include "absl/status/status.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/PassManager.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" @@ -38,17 +42,31 @@ namespace xla { HloModuleImporter::HloModuleImporter(mlir::ModuleOp module, bool import_all_computation, - bool flatten_computation_args_result) + bool flatten_computation_args_result, + bool emit_stablehlo) : import_all_computation_(import_all_computation), flatten_computation_args_result_(flatten_computation_args_result), symbol_table_(module), + emit_stablehlo_(emit_stablehlo), builder_(module.getContext()) { module.getContext()->loadDialect(); module.getContext()->loadDialect(); module.getContext()->loadDialect(); + module.getContext()->loadDialect(); module.getContext()->loadDialect(); } +namespace { +absl::Status ConvertToMhlo(mlir::ModuleOp module) { + mlir::PassManager pm(module.getContext()); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + if (failed(pm.run(module))) { + return absl::InternalError("Failed to convert to MHLO"); + } + return absl::OkStatus(); +} +} // namespace + absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { auto module = llvm::cast(symbol_table_.getOp()); module.setName(hlo_module.name()); @@ -68,11 +86,16 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { if (!import_all_computation_) { // Only import the entry computation, any reachable one will be imported // unless turned into a region operation. 
- return HloFunctionImporter::ImportAsFunc( - *hlo_module.entry_computation(), symbol_table_, &function_map_, - &builder_, - /*is_main*/ true, flatten_computation_args_result_) - .status(); + TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc( + *hlo_module.entry_computation(), symbol_table_, + &function_map_, &builder_, + /*is_main*/ true, flatten_computation_args_result_) + .status()); + // Convert all ops to MHLO + if (!emit_stablehlo_) { + TF_RETURN_IF_ERROR(ConvertToMhlo(module)); + } + return absl::OkStatus(); } auto* module_entry_computation = hlo_module.entry_computation(); @@ -89,6 +112,11 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { hlo_module, module, flatten_computation_args_result_, builder_); TF_RETURN_IF_ERROR(ImportLayoutModes( hlo_module, module, flatten_computation_args_result_, builder_)); + + // Convert all ops to MHLO + if (!emit_stablehlo_) { + TF_RETURN_IF_ERROR(ConvertToMhlo(module)); + } return absl::OkStatus(); } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h index 8937f673035a23..0abfe50b93e0b3 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h @@ -38,7 +38,8 @@ class HloModuleImporter { public: explicit HloModuleImporter(mlir::ModuleOp module, bool import_all_computation = false, - bool flatten_computation_args_result = false); + bool flatten_computation_args_result = false, + bool emit_stablehlo = false); // Import the HloModule into the MLIR Module. absl::Status Import(const xla::HloModule& module); @@ -50,6 +51,7 @@ class HloModuleImporter { bool import_all_computation_; bool flatten_computation_args_result_; mlir::SymbolTable symbol_table_; + bool emit_stablehlo_; mlir::Builder builder_; // Map for tracking which MLIR function map to which HLO Computation. 
This diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc index 39293742e068c1..9db6d66a06768d 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc @@ -33,43 +33,47 @@ namespace xla { absl::StatusOr> ConvertHloToMlirHlo( mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module, - bool import_all_computations, bool flatten_computation_args_result) { + bool import_all_computations, bool flatten_computation_args_result, + bool emit_stablehlo) { mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); - TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(*module, hlo_module, - import_all_computations, - flatten_computation_args_result)); + TF_RETURN_IF_ERROR( + ConvertHloToMlirHlo(*module, hlo_module, import_all_computations, + flatten_computation_args_result, emit_stablehlo)); return module; } absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModuleProto const* hlo_module_proto, bool import_all_computation, - bool flatten_computation_args_result) { + bool flatten_computation_args_result, + bool emit_stablehlo) { mlir::BaseScopedDiagnosticHandler diag_handler(module.getContext()); return HloModuleImporter(module, import_all_computation, - flatten_computation_args_result) + flatten_computation_args_result, emit_stablehlo) .Import(*hlo_module_proto); } absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, const xla::HloModule* hlo_module, bool import_all_computation, - bool flatten_computation_args_result) { + bool flatten_computation_args_result, + bool emit_stablehlo) { mlir::BaseScopedDiagnosticHandler diag_handler(module.getContext()); return HloModuleImporter(module, import_all_computation, - flatten_computation_args_result) + flatten_computation_args_result, emit_stablehlo) .Import(*hlo_module); } absl::StatusOr> ConvertHloToMlirHlo( mlir::MLIRContext& ctx, const xla::HloModule* hlo_module, - bool import_all_computations, bool flatten_computation_args_result) { + bool import_all_computations, bool flatten_computation_args_result, + bool emit_stablehlo) { mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); - TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(*module, hlo_module, - import_all_computations, - flatten_computation_args_result)); + TF_RETURN_IF_ERROR( + ConvertHloToMlirHlo(*module, hlo_module, import_all_computations, + flatten_computation_args_result, emit_stablehlo)); return module; } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h index 2489106527569b..cf1bab0b56e22c 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h @@ -42,12 +42,13 @@ class HloModuleProto; absl::StatusOr> ConvertHloToMlirHlo( mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module, bool import_all_computations = false, - bool flatten_computation_args_result = false); + bool flatten_computation_args_result = false, bool emit_stablehlo = false); absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModuleProto const* hlo_module, bool import_all_computations = false, - bool flatten_computation_args_result = false); + bool flatten_computation_args_result = false, + bool emit_stablehlo = false); // Converts an HLO module to a MLIR module in HLO dialect. 
// @@ -59,12 +60,13 @@ absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, absl::StatusOr> ConvertHloToMlirHlo( mlir::MLIRContext& ctx, const xla::HloModule* hlo_module, bool import_all_computations = false, - bool flatten_computation_args_result = false); + bool flatten_computation_args_result = false, bool emit_stablehlo = false); absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, const xla::HloModule* hlo_module, bool import_all_computations = false, - bool flatten_computation_args_result = false); + bool flatten_computation_args_result = false, + bool emit_stablehlo = false); } // namespace xla diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc index f752a4b2577a40..48169f4d749361 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc @@ -26,9 +26,8 @@ limitations under the License. #include "xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h" namespace mlir { -namespace mhlo { +namespace hlo { -// TODO(herhut): Refactor the format. mlir::Location GenerateInstructionLocation( const xla::HloInstruction* instruction, mlir::MLIRContext* context) { mlir::Builder b(context); @@ -61,5 +60,5 @@ mlir::Location GenerateInstructionLocation( instruction->metadata().source_line(), 0)}); } -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h index 0137fa446b024a..06174be31d89b8 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h @@ -21,14 +21,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" namespace mlir { -namespace mhlo { +namespace hlo { // Returns an MLIR Location generated from HLO Instruction. Uses instruction // metadata if present or instruction name. mlir::Location GenerateInstructionLocation( const xla::HloInstruction* instruction, mlir::MLIRContext* context); -} // namespace mhlo +} // namespace hlo } // namespace mlir #endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc index 9448af4acb5eeb..9b53614c0c8c33 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" namespace mlir { -namespace mhlo { +namespace hlo { mlir::Location GetLocationFromFrameIndex(int frame_id, mlir::Builder& builder, const xla::HloModule* hlo_module) { std::vector stack_locations; @@ -57,5 +57,5 @@ mlir::Location GetLocationFromFrameIndex(int frame_id, mlir::Builder& builder, return mlir::CallSiteLoc::get(stack_locations[0], stack_locations_ref.drop_front()); } -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h index f5210d558d9152..3fec1b2fc879a1 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h @@ -22,13 +22,13 @@ limitations under the License. #include "xla/service/hlo.pb.h" namespace mlir { -namespace mhlo { +namespace hlo { // Construct MLIR location from frame index. // Returns unknown location if frame is not presented. mlir::Location GetLocationFromFrameIndex(int frame_id, mlir::Builder &builder, const xla::HloModule *hlo_module); -} // namespace mhlo +} // namespace hlo } // namespace mlir #endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD index a407328a2ada0b..51fdb9712fae54 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD @@ -26,6 +26,7 @@ lit_test_suite( "import_async.hlo", "import_bounded_dynamism.hlo", "import_entry_computation_layout.hlo", + "import_emit_stablehlo.hlo", "layouts_and_names.hlo", "location.hlo", "module_attributes.hlo", diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo new file mode 100644 index 00000000000000..d76a2b2fe51f21 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo @@ -0,0 +1,2788 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// minimized and named to reflect the test intent. + +// Regnerate them using the following command: +// $ TFILE=/path/to/import_emit_stablehlo.hlo +// $ DELIM="Hlo Module" # Remove the space in the middle when running cmd. This comment needs the space since the source file is regex matched. 
+// $ xla-translate $TFILE -hlo-text-to-mlir-hlo --emit-stablehlo --split-input-file --hlo-import-all-computations | \ +// third_party/llvm/llvm-project/mlir/utils/generate-test-checks.py --source $TFILE --source_delim_regex="$DELIM" --starts_from_scope=0 -i + +// RUN: xla-translate %s -hlo-text-to-mlir-hlo --emit-stablehlo --split-input-file --hlo-import-all-computations | FileCheck %s + +// CHECK-LABEL: module @foo attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: return %[[VAL_0]] : tensor +// CHECK: } +// CHECK: } +HloModule foo, entry_computation_layout={(pred[])->pred[]} + +ENTRY %main.2 (Arg_0.1: pred[]) -> pred[] { + ROOT %Arg_0.1 = pred[] parameter(0) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xi1>) -> tensor<2xi1> { +// CHECK: %[[VAL_1:.*]] = stablehlo.xor %[[VAL_0]], %[[VAL_0]] : tensor<2xi1> +// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(pred[2]{0})->pred[2]{0}} + +ENTRY %main.3 (Arg_0.1: pred[2]) -> pred[2] { + %Arg_0.1 = pred[2] parameter(0) + ROOT %xor.2 = pred[2] xor(%Arg_0.1, %Arg_0.1) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token { +// CHECK: %[[VAL_2:.*]] = mhlo.after_all %[[VAL_0]], %[[VAL_1]] {xla_shape = "token[]"} : !mhlo.token +// CHECK: return %[[VAL_2]] : !mhlo.token +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(token[], token[])->token[]} + +ENTRY %main.4 (Arg_0.1: token[], Arg_1.2: token[]) -> token[] { + %Arg_0.1 = token[] parameter(0) + %Arg_1.2 = token[] parameter(1) + ROOT %after-all.3 = token[] after-all(%Arg_0.1, %Arg_1.2) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main() -> !mhlo.token { +// CHECK: %[[VAL_0:.*]] = mhlo.create_token {xla_shape = "token[]"} : !mhlo.token +// CHECK: return %[[VAL_0]] : !mhlo.token +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={()->token[]} + +ENTRY %main.2 () -> token[] { + ROOT %after-all.1 = token[] after-all() +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.2(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_1]] : tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<5xf32> { +// CHECK: %[[VAL_4:.*]] = "mhlo.reduce_scatter"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2], [1, 3]]> : tensor<2x2xi64>, scatter_dimension = 0 : i64}> ({ +// CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): +// CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor +// 
CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: }) : (tensor<10xf32>) -> tensor<5xf32> +// CHECK: return %[[VAL_4]] : tensor<5xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[10]{0})->f32[5]{0}} + +%region_0.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %Arg_1.4 = f32[] parameter(1) + ROOT %maximum.5 = f32[] maximum(%Arg_0.3, %Arg_1.4) +} + +ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[5] { + %Arg_0.1 = f32[10] parameter(0) + ROOT %reduce-scatter.6 = f32[5] reduce-scatter(%Arg_0.1), channel_id=5, replica_groups={{0,2},{1,3}}, dimensions={0}, to_apply=%region_0.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> { +// CHECK: %[[VAL_1:.*]] = "mhlo.all_gather"(%[[VAL_0]]) <{all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> : (tensor<128x32xf32>) -> tensor<128x128xf32> +// CHECK: return %[[VAL_1]] : tensor<128x128xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}} + +ENTRY %main.3 (Arg_0.1: f32[128,32]) -> f32[128,128] { + %Arg_0.1 = f32[128,32] parameter(0) + ROOT %all-gather.2 = f32[128,128] all-gather(%Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> { +// CHECK: %[[VAL_1:.*]] = "mhlo.all_gather"(%[[VAL_0]]) <{all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<128x32xf32>) -> tensor<128x128xf32> +// CHECK: return %[[VAL_1]] : tensor<128x128xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}} + +ENTRY %main.3 (Arg_0.1: f32[128,32]) -> f32[128,128] { + %Arg_0.1 = f32[128,32] parameter(0) + ROOT %all-gather.2 = f32[128,128] all-gather(%Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<8x2xf32>, %[[VAL_1:.*]]: tensor<8x4xf32>) -> tuple, tensor<8x16xf32>> { +// CHECK: %[[VAL_2:.*]]:2 = "mhlo.all_gather"(%[[VAL_0]], %[[VAL_1]]) <{all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) +// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_2]]#0, %[[VAL_2]]#1 {xla_shape = "(f32[8,8]{1,0}, f32[8,16]{1,0})"} : tuple, tensor<8x16xf32>> +// CHECK: return %[[VAL_3]] : tuple, tensor<8x16xf32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[8,2]{1,0}, f32[8,4]{1,0})->(f32[8,8]{1,0}, f32[8,16]{1,0})} + +ENTRY %main.10 (Arg_0.1: f32[8,2], Arg_1.2: f32[8,4]) -> (f32[8,8], f32[8,16]) { + %Arg_0.1 
= f32[8,2] parameter(0) + %Arg_1.2 = f32[8,4] parameter(1) + %tuple.3 = (f32[8,2], f32[8,4]) tuple(%Arg_0.1, %Arg_1.2) + %get-tuple-element.4 = f32[8,2] get-tuple-element(%tuple.3), index=0 + %get-tuple-element.5 = f32[8,4] get-tuple-element(%tuple.3), index=1 + %all-gather.6 = (f32[8,8], f32[8,16]) all-gather(%get-tuple-element.4, %get-tuple-element.5), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true + %get-tuple-element.7 = f32[8,8] get-tuple-element(%all-gather.6), index=0 + %get-tuple-element.8 = f32[8,16] get-tuple-element(%all-gather.6), index=1 + ROOT %tuple.9 = (f32[8,8], f32[8,16]) tuple(%get-tuple-element.7, %get-tuple-element.8) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.2(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_1]] : tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<10xf32> { +// CHECK: %[[VAL_4:.*]] = "mhlo.all_reduce"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> ({ +// CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): +// CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor +// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: }) : (tensor<10xf32>) -> tensor<10xf32> +// CHECK: return %[[VAL_4]] : tensor<10xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} + +%region_0.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %Arg_1.4 = f32[] parameter(1) + ROOT %maximum.5 = f32[] maximum(%Arg_0.3, %Arg_1.4) +} + +ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] { + %Arg_0.1 = f32[10] parameter(0) + ROOT %all-reduce.6 = f32[10] all-reduce(%Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=%region_0.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.2(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_1]] : tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<10xf32> { +// CHECK: %[[VAL_4:.*]] = "mhlo.all_reduce"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, -1], [1, 3, 5, 6]]> : tensor<2x4xi64>}> ({ +// CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): +// CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor +// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: }) : (tensor<10xf32>) -> tensor<10xf32> +// CHECK: return %[[VAL_4]] : tensor<10xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} + +%region_0.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %Arg_1.4 = f32[] parameter(1) + ROOT %maximum.5 = f32[] maximum(%Arg_0.3, %Arg_1.4) +} + +ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] { + %Arg_0.1 = f32[10] parameter(0) + ROOT %all-reduce.6 = f32[10] all-reduce(%Arg_0.1), channel_id=5, 
replica_groups={{0,2,4},{1,3,5,6}}, to_apply=%region_0.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.2(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_1]] : tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<10xf32> { +// CHECK: %[[VAL_4:.*]] = "mhlo.all_reduce"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({ +// CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): +// CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor +// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: }) : (tensor<10xf32>) -> tensor<10xf32> +// CHECK: return %[[VAL_4]] : tensor<10xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} + +%region_0.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %Arg_1.4 = f32[] parameter(1) + ROOT %maximum.5 = f32[] maximum(%Arg_0.3, %Arg_1.4) +} + +ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] { + %Arg_0.1 = f32[10] parameter(0) + ROOT %all-reduce.6 = f32[10] all-reduce(%Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_0.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.6(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<8xf32>, %[[VAL_4:.*]]: tensor) -> tuple, tensor> { +// CHECK: %[[VAL_5:.*]]:2 = "mhlo.all_reduce"(%[[VAL_3]], %[[VAL_4]]) <{replica_groups = dense<> : tensor<0x0xi64>}> ({ +// CHECK: ^bb0(%[[VAL_6:.*]]: tensor, %[[VAL_7:.*]]: tensor): +// CHECK: %[[VAL_8:.*]] = stablehlo.add %[[VAL_6]], %[[VAL_7]] : tensor +// CHECK: mhlo.return %[[VAL_8]] : tensor +// CHECK: }) : (tensor<8xf32>, tensor) -> (tensor<8xf32>, tensor) +// CHECK: %[[VAL_9:.*]] = mhlo.tuple %[[VAL_10:.*]]#0, %[[VAL_10]]#1 {xla_shape = "(f32[8]{0}, f32[])"} : tuple, tensor> +// CHECK: return %[[VAL_9]] : tuple, tensor> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[8]{0}, f32[])->(f32[8]{0}, f32[])} + +%region_0.6 (Arg_0.7: f32[], Arg_1.8: f32[]) -> f32[] { + %Arg_0.7 = f32[] parameter(0) + %Arg_1.8 = f32[] parameter(1) + ROOT %add.9 = f32[] add(%Arg_0.7, %Arg_1.8) +} + +ENTRY %main.14 (Arg_0.1: f32[8], Arg_1.2: f32[]) -> (f32[8], f32[]) { + %Arg_0.1 = f32[8] parameter(0) + %Arg_1.2 = f32[] parameter(1) + %tuple.3 = (f32[8], f32[]) tuple(%Arg_0.1, %Arg_1.2) + %get-tuple-element.4 = f32[8] get-tuple-element(%tuple.3), index=0 + %get-tuple-element.5 = f32[] get-tuple-element(%tuple.3), index=1 + %all-reduce.10 = (f32[8], f32[]) all-reduce(%get-tuple-element.4, %get-tuple-element.5), replica_groups={}, to_apply=%region_0.6 + %get-tuple-element.11 = f32[8] get-tuple-element(%all-reduce.10), index=0 + %get-tuple-element.12 = f32[] get-tuple-element(%all-reduce.10), index=1 + ROOT %tuple.13 = (f32[8], f32[]) 
tuple(%get-tuple-element.11, %get-tuple-element.12) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.2(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_1]] : tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<5xf32> { +// CHECK: %[[VAL_4:.*]] = "mhlo.reduce_scatter"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2], [1, 3]]> : tensor<2x2xi64>, scatter_dimension = 0 : i64}> ({ +// CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): +// CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor +// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: }) : (tensor<10xf32>) -> tensor<5xf32> +// CHECK: return %[[VAL_4]] : tensor<5xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[10]{0})->f32[5]{0}} + +%region_0.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %Arg_1.4 = f32[] parameter(1) + ROOT %maximum.5 = f32[] maximum(%Arg_0.3, %Arg_1.4) +} + +ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[5] { + %Arg_0.1 = f32[10] parameter(0) + ROOT %reduce-scatter.6 = f32[5] reduce-scatter(%Arg_0.1), channel_id=5, replica_groups={{0,2},{1,3}}, dimensions={0}, to_apply=%region_0.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.2(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_1]] : tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<5xf32> { +// CHECK: %[[VAL_4:.*]] = "mhlo.reduce_scatter"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2], [1, 3]]> : tensor<2x2xi64>, scatter_dimension = 0 : i64, use_global_device_ids}> ({ +// CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): +// CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor +// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: }) : (tensor<10xf32>) -> tensor<5xf32> +// CHECK: return %[[VAL_4]] : tensor<5xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[10]{0})->f32[5]{0}} + +%region_0.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %Arg_1.4 = f32[] parameter(1) + ROOT %maximum.5 = f32[] maximum(%Arg_0.3, %Arg_1.4) +} + +ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[5] { + %Arg_0.1 = f32[10] parameter(0) + ROOT %reduce-scatter.6 = f32[5] reduce-scatter(%Arg_0.1), channel_id=5, replica_groups={{0,2},{1,3}}, use_global_device_ids=true, dimensions={0}, to_apply=%region_0.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x2x2x2xf32>, %[[VAL_1:.*]]: tensor<2xf32>, %[[VAL_2:.*]]: tensor<2xf32>, %[[VAL_3:.*]]: tensor<2xf32>, %[[VAL_4:.*]]: tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { +// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]], %[[VAL_7:.*]] = 
"stablehlo.batch_norm_grad"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) <{epsilon = 1.000000e-03 : f32, feature_index = 0 : i64}> : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) +// CHECK: %[[VAL_8:.*]] = mhlo.tuple %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] {xla_shape = "(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})"} : tuple, tensor<2xf32>, tensor<2xf32>> +// CHECK: return %[[VAL_8]] : tuple, tensor<2xf32>, tensor<2xf32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}, f32[2]{0}, f32[2,2,2,2]{3,2,1,0})->(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})} + +ENTRY %main.11 (Arg_0.1: f32[2,2,2,2], Arg_1.2: f32[2], Arg_2.3: f32[2], Arg_3.4: f32[2], Arg_4.5: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) { + %Arg_0.1 = f32[2,2,2,2] parameter(0) + %Arg_1.2 = f32[2] parameter(1) + %Arg_2.3 = f32[2] parameter(2) + %Arg_3.4 = f32[2] parameter(3) + %Arg_4.5 = f32[2,2,2,2] parameter(4) + %batch-norm-grad.6 = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4, %Arg_4.5), epsilon=0.001, feature_index=0 + %get-tuple-element.7 = f32[2,2,2,2] get-tuple-element(%batch-norm-grad.6), index=0 + %get-tuple-element.8 = f32[2] get-tuple-element(%batch-norm-grad.6), index=1 + %get-tuple-element.9 = f32[2] get-tuple-element(%batch-norm-grad.6), index=2 + ROOT %tuple.10 = (f32[2,2,2,2], f32[2], f32[2]) tuple(%get-tuple-element.7, %get-tuple-element.8, %get-tuple-element.9) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x2x2x2xf32>, %[[VAL_1:.*]]: tensor<2xf32>, %[[VAL_2:.*]]: tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { +// CHECK: %[[VAL_3:.*]], %[[VAL_4:.*]], %[[VAL_5:.*]] = "stablehlo.batch_norm_training"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) +// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})"} : tuple, tensor<2xf32>, tensor<2xf32>> +// CHECK: return %[[VAL_6]] : tuple, tensor<2xf32>, tensor<2xf32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})->(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})} + +ENTRY %main.9 (Arg_0.1: f32[2,2,2,2], Arg_1.2: f32[2], Arg_2.3: f32[2]) -> (f32[2,2,2,2], f32[2], f32[2]) { + %Arg_0.1 = f32[2,2,2,2] parameter(0) + %Arg_1.2 = f32[2] parameter(1) + %Arg_2.3 = f32[2] parameter(2) + %batch-norm-training.4 = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(%Arg_0.1, %Arg_1.2, %Arg_2.3), epsilon=0.001, feature_index=3 + %get-tuple-element.5 = f32[2,2,2,2] get-tuple-element(%batch-norm-training.4), index=0 + %get-tuple-element.6 = f32[2] get-tuple-element(%batch-norm-training.4), index=1 + %get-tuple-element.7 = f32[2] get-tuple-element(%batch-norm-training.4), index=2 + ROOT %tuple.8 = (f32[2,2,2,2], f32[2], f32[2]) tuple(%get-tuple-element.5, %get-tuple-element.6, %get-tuple-element.7) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} 
{ +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xf32>, %[[VAL_1:.*]]: tensor<4xf32>, %[[VAL_2:.*]]: tensor<4xi32>, %[[VAL_3:.*]]: tensor<4xi32>) -> tuple<tensor<4xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>> { +// CHECK: %[[VAL_4:.*]] = stablehlo.atan2 %[[VAL_0]], %[[VAL_1]] : tensor<4xf32> +// CHECK: %[[VAL_5:.*]] = stablehlo.shift_left %[[VAL_2]], %[[VAL_3]] : tensor<4xi32> +// CHECK: %[[VAL_6:.*]] = stablehlo.shift_right_arithmetic %[[VAL_2]], %[[VAL_3]] : tensor<4xi32> +// CHECK: %[[VAL_7:.*]] = stablehlo.shift_right_logical %[[VAL_2]], %[[VAL_3]] : tensor<4xi32> +// CHECK: %[[VAL_8:.*]] = mhlo.tuple %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] {xla_shape = "(f32[4]{0}, s32[4]{0}, s32[4]{0}, s32[4]{0})"} : tuple<tensor<4xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>> +// CHECK: return %[[VAL_8]] : tuple<tensor<4xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[4]{0}, f32[4]{0}, s32[4]{0}, s32[4]{0})->(f32[4]{0}, s32[4]{0}, s32[4]{0}, s32[4]{0})} + +ENTRY %main.10 (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: s32[4], Arg_3.4: s32[4]) -> (f32[4], s32[4], s32[4], s32[4]) { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + %atan2.5 = f32[4] atan2(%Arg_0.1, %Arg_1.2) + %Arg_2.3 = s32[4] parameter(2) + %Arg_3.4 = s32[4] parameter(3) + %shift-left.6 = s32[4] shift-left(%Arg_2.3, %Arg_3.4) + %shift-right-arithmetic.7 = s32[4] shift-right-arithmetic(%Arg_2.3, %Arg_3.4) + %shift-right-logical.8 = s32[4] shift-right-logical(%Arg_2.3, %Arg_3.4) + ROOT %tuple.9 = (f32[4], s32[4], s32[4], s32[4]) tuple(%atan2.5, %shift-left.6, %shift-right-arithmetic.7, %shift-right-logical.8) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = stablehlo.bitcast_convert %[[VAL_0]] : (tensor<2xi32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[2]{0})->f32[2]{0}} + +ENTRY %main.3 (Arg_0.1: s32[2]) -> f32[2] { + %Arg_0.1 = s32[2] parameter(0) + ROOT %bitcast-convert.2 = f32[2] bitcast-convert(%Arg_0.1) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xi32>) -> tensor<1x2x3x4xi32> { +// CHECK: %[[VAL_1:.*]] = stablehlo.broadcast_in_dim %[[VAL_0]], dims = [3] : (tensor<4xi32>) -> tensor<1x2x3x4xi32> +// CHECK: return %[[VAL_1]] : tensor<1x2x3x4xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[4]{0})->s32[1,2,3,4]{3,2,1,0}} + +ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[1,2,3,4] { + %Arg_0.1 = s32[4] parameter(0) + ROOT %broadcast.2 = s32[1,2,3,4] broadcast(%Arg_0.1), dimensions={3} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<1xf32>) -> tensor<1x10xf32> { +// CHECK: %[[VAL_1:.*]] = stablehlo.broadcast_in_dim %[[VAL_0]], dims = [0] : (tensor<1xf32>) -> tensor<1x10xf32> +// CHECK: return %[[VAL_1]] : tensor<1x10xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[1]{0})->f32[1,10]{1,0}} + +ENTRY
%main.3 (Arg_0.1: f32[1]) -> f32[1,10] { + %Arg_0.1 = f32[1] parameter(0) + ROOT %broadcast.2 = f32[1,10] broadcast(%Arg_0.1), dimensions={0} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main() -> !mhlo.token { +// CHECK: %[[VAL_0:.*]] = mhlo.create_token {xla_shape = "token[]"} : !mhlo.token +// CHECK: return %[[VAL_0]] : !mhlo.token +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={()->token[]} + +ENTRY %main.2 () -> token[] { + ROOT %after-all.1 = token[] after-all() +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @empty_callee.2() -> tuple<> { +// CHECK: %[[VAL_0:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: return %[[VAL_0]] : tuple<> +// CHECK: } +// CHECK: func.func @main(%[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = call @empty_callee.2() {xla_shape = "()"} : () -> tuple<> +// CHECK: return %[[VAL_1]] : tensor<4xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[4]{0})->s32[4]{0}} + +%empty_callee.2 () -> () { + ROOT %tuple.3 = () tuple() +} + +ENTRY %main.5 (Arg_0.1: s32[4]) -> s32[4] { + ROOT %Arg_0.1 = s32[4] parameter(0) + %call.4 = () call(), to_apply=%empty_callee.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @callee.2(%[[VAL_0:.*]]: tensor<4xi32>, %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor<4xi32> +// CHECK: return %[[VAL_2]] : tensor<4xi32> +// CHECK: } +// CHECK: func.func private @callee.7(%[[VAL_3:.*]]: tensor<4xi32>, %[[VAL_4:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_5:.*]] = stablehlo.add %[[VAL_3]], %[[VAL_4]] : tensor<4xi32> +// CHECK: return %[[VAL_5]] : tensor<4xi32> +// CHECK: } +// CHECK: func.func @main(%[[VAL_6:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_7:.*]] = call @callee.2(%[[VAL_6]], %[[VAL_6]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: %[[VAL_8:.*]] = call @callee.7(%[[VAL_7]], %[[VAL_7]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return %[[VAL_8]] : tensor<4xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[4]{0})->s32[4]{0}} + +%callee.2 (Arg_0.3: s32[4], Arg_1.4: s32[4]) -> s32[4] { + %Arg_0.3 = s32[4] parameter(0) + %Arg_1.4 = s32[4] parameter(1) + ROOT %add.5 = s32[4] add(%Arg_0.3, %Arg_1.4) +} + +%callee.7 (Arg_0.8: s32[4], Arg_1.9: s32[4]) -> s32[4] { + %Arg_0.8 = s32[4] parameter(0) + %Arg_1.9 = s32[4] parameter(1) + ROOT %add.10 = s32[4] add(%Arg_0.8, %Arg_1.9) +} + +ENTRY %main.12 (Arg_0.1: s32[4]) -> s32[4] { + %Arg_0.1 = s32[4] parameter(0) + %call.6 = s32[4] call(%Arg_0.1, %Arg_0.1), to_apply=%callee.2 + ROOT %call.11 = s32[4] call(%call.6, %call.6), to_apply=%callee.7 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @callee.2(%[[VAL_0:.*]]: tensor<4xi32>, %[[VAL_1:.*]]: tensor<4xi32>) -> tuple, 
tensor<4xi32>> { +// CHECK: %[[VAL_2:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor<4xi32> +// CHECK: %[[VAL_3:.*]] = stablehlo.multiply %[[VAL_0]], %[[VAL_1]] : tensor<4xi32> +// CHECK: %[[VAL_4:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_3]] {xla_shape = "(s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xi32>> +// CHECK: return %[[VAL_4]] : tuple, tensor<4xi32>> +// CHECK: } +// CHECK: func.func @main(%[[VAL_5:.*]]: tensor<4xi32>) -> tuple, tensor<4xi32>> { +// CHECK: %[[VAL_6:.*]] = call @callee.2(%[[VAL_5]], %[[VAL_5]]) {xla_shape = "(s32[4]{0}, s32[4]{0})"} : (tensor<4xi32>, tensor<4xi32>) -> tuple, tensor<4xi32>> +// CHECK: %[[VAL_7:.*]] = stablehlo.get_tuple_element %[[VAL_6]][0] : (tuple, tensor<4xi32>>) -> tensor<4xi32> +// CHECK: %[[VAL_8:.*]] = stablehlo.get_tuple_element %[[VAL_6]][1] : (tuple, tensor<4xi32>>) -> tensor<4xi32> +// CHECK: %[[VAL_9:.*]] = mhlo.tuple %[[VAL_7]], %[[VAL_8]] {xla_shape = "(s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xi32>> +// CHECK: return %[[VAL_9]] : tuple, tensor<4xi32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[4]{0})->(s32[4]{0}, s32[4]{0})} + +%callee.2 (Arg_0.3: s32[4], Arg_1.4: s32[4]) -> (s32[4], s32[4]) { + %Arg_0.3 = s32[4] parameter(0) + %Arg_1.4 = s32[4] parameter(1) + %add.5 = s32[4] add(%Arg_0.3, %Arg_1.4) + %multiply.6 = s32[4] multiply(%Arg_0.3, %Arg_1.4) + ROOT %tuple.7 = (s32[4], s32[4]) tuple(%add.5, %multiply.6) +} + +ENTRY %main.12 (Arg_0.1: s32[4]) -> (s32[4], s32[4]) { + %Arg_0.1 = s32[4] parameter(0) + %call.8 = (s32[4], s32[4]) call(%Arg_0.1, %Arg_0.1), to_apply=%callee.2 + %get-tuple-element.9 = s32[4] get-tuple-element(%call.8), index=0 + %get-tuple-element.10 = s32[4] get-tuple-element(%call.8), index=1 + ROOT %tuple.11 = (s32[4], s32[4]) tuple(%get-tuple-element.9, %get-tuple-element.10) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { +// CHECK: %[[VAL_1:.*]] = mhlo.cosine %[[VAL_0]] : tensor<1x16x16x3xf32> +// CHECK: return %[[VAL_1]] : tensor<1x16x16x3xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[1,16,16,3]{3,2,1,0})->f32[1,16,16,3]{3,2,1,0}} + +ENTRY %main.3 (Arg_0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { + %Arg_0.1 = f32[1,16,16,3] parameter(0) + ROOT %cosine.2 = f32[1,16,16,3] cosine(%Arg_0.1) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { +// CHECK: %[[VAL_1:.*]] = mhlo.sine %[[VAL_0]] : tensor<1x16x16x3xf32> +// CHECK: return %[[VAL_1]] : tensor<1x16x16x3xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[1,16,16,3]{3,2,1,0})->f32[1,16,16,3]{3,2,1,0}} + +ENTRY %main.3 (Arg_0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { + %Arg_0.1 = f32[1,16,16,3] parameter(0) + ROOT %sine.2 = f32[1,16,16,3] sine(%Arg_0.1) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = mhlo.exponential %[[VAL_0]] {result_accuracy = #mhlo.result_accuracy>} : tensor +// CHECK: 
return %[[VAL_1]] : tensor<f32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[])->f32[]} + +ENTRY %main.3 (Arg_0.1: f32[]) -> f32[] { + %Arg_0.1 = f32[] parameter(0) + ROOT %exponential.2 = f32[] exponential(%Arg_0.1), result_accuracy={tolerance={atol=0,rtol=0,ulps=10}} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> { +// CHECK: %[[VAL_1:.*]] = "mhlo.collective_broadcast"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle<handle = 1, type = 0>, replica_groups = dense<{{\[\[}}0, 1], [2, 3]]> : tensor<2x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> +// CHECK: return %[[VAL_1]] : tensor<128x32xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} + +ENTRY %main.3 (Arg_0.1: f32[128,32]) -> f32[128,32] { + %Arg_0.1 = f32[128,32] parameter(0) + ROOT %collective-broadcast.2 = f32[128,32] collective-broadcast(%Arg_0.1), channel_id=1, replica_groups={{0,1},{2,3}} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> { +// CHECK: %[[VAL_1:.*]] = "mhlo.collective_permute"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle<handle = 1, type = 0>, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> +// CHECK: return %[[VAL_1]] : tensor<128x32xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} + +ENTRY %main.3 (Arg_0.1: f32[128,32]) -> f32[128,32] { + %Arg_0.1 = f32[128,32] parameter(0) + ROOT %collective-permute.2 = f32[128,32] collective-permute(%Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<5x2xf32>, %[[VAL_1:.*]]: tensor<5x5xf32>, %[[VAL_2:.*]]: tensor<5x7xf32>) -> tensor<5x14xf32> { +// CHECK: %[[VAL_3:.*]] = stablehlo.concatenate %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], dim = 1 : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32> +// CHECK: return %[[VAL_3]] : tensor<5x14xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[5,2]{1,0}, f32[5,5]{1,0}, f32[5,7]{1,0})->f32[5,14]{1,0}} + +ENTRY %main.5 (Arg_0.1: f32[5,2], Arg_1.2: f32[5,5], Arg_2.3: f32[5,7]) -> f32[5,14] { + %Arg_0.1 = f32[5,2] parameter(0) + %Arg_1.2 = f32[5,5] parameter(1) + %Arg_2.3 = f32[5,7] parameter(2) + ROOT %concatenate.4 = f32[5,14] concatenate(%Arg_0.1, %Arg_1.2, %Arg_2.3), dimensions={1} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main() -> tuple<> { +// CHECK: %[[VAL_0:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: return %[[VAL_0]] : tuple<> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={()->()} + +ENTRY %main.2 () -> () { + ROOT %tuple.1 = () tuple() +} + +// ----- + +// CHECK-LABEL: module @main
attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<100x26x26x32xf32>, %[[VAL_1:.*]]: tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { +// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> +// CHECK: return %[[VAL_2]] : tensor<100x28x28x1xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[100,26,26,32]{3,2,1,0}, f32[3,3,1,32]{3,2,1,0})->f32[100,28,28,1]{3,2,1,0}} + +ENTRY %main.4 (Arg_0.1: f32[100,26,26,32], Arg_1.2: f32[3,3,1,32]) -> f32[100,28,28,1] { + %Arg_0.1 = f32[100,26,26,32] parameter(0) + %Arg_1.2 = f32[3,3,1,32] parameter(1) + ROOT %convolution.3 = f32[100,28,28,1] convolution(%Arg_0.1, %Arg_1.2), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01oi->b01f +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<100x26x26x32xi8>, %[[VAL_1:.*]]: tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> { +// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> +// CHECK: return %[[VAL_2]] : tensor<100x28x28x1xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s8[100,26,26,32]{3,2,1,0}, s8[3,3,1,32]{3,2,1,0})->s32[100,28,28,1]{3,2,1,0}} + +ENTRY %main.4 (Arg_0.1: s8[100,26,26,32], Arg_1.2: s8[3,3,1,32]) -> s32[100,28,28,1] { + %Arg_0.1 = s8[100,26,26,32] parameter(0) + %Arg_1.2 = s8[3,3,1,32] parameter(1) + ROOT %convolution.3 = s32[100,28,28,1] convolution(%Arg_0.1, %Arg_1.2), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01oi->b01f +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<100x26x26x32xi8>, %[[VAL_1:.*]]: tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> { +// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> +// CHECK: return %[[VAL_2]] : tensor<100x28x28x1xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s8[100,26,26,32]{3,2,1,0}, s8[3,3,1,32]{3,2,1,0})->s32[100,28,28,1]{3,2,1,0}} + +ENTRY %main.4 (Arg_0.1: s8[100,26,26,32], Arg_1.2: s8[3,3,1,32]) -> s32[100,28,28,1] { + %Arg_0.1 = s8[100,26,26,32] parameter(0) + %Arg_1.2 = s8[3,3,1,32]
parameter(1) + ROOT %convolution.3 = s32[100,28,28,1] convolution(%Arg_0.1, %Arg_1.2), window={size=3x3 pad=2_2x2_2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = stablehlo.convert %[[VAL_0]] : (tensor<2xi32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[2]{0})->f32[2]{0}} + +ENTRY %main.3 (Arg_0.1: s32[2]) -> f32[2] { + %Arg_0.1 = s32[2] parameter(0) + ROOT %convert.2 = f32[2] convert(%Arg_0.1) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = stablehlo.convert %[[VAL_0]] : (tensor<2xf32>) -> tensor<2xf8E5M2> +// CHECK: %[[VAL_2:.*]] = stablehlo.convert %[[VAL_1]] : (tensor<2xf8E5M2>) -> tensor<2xf8E4M3FN> +// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[VAL_2]] : (tensor<2xf8E4M3FN>) -> tensor<2xf8E4M3FNUZ> +// CHECK: %[[VAL_4:.*]] = stablehlo.convert %[[VAL_3]] : (tensor<2xf8E4M3FNUZ>) -> tensor<2xf8E5M2FNUZ> +// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[VAL_4]] : (tensor<2xf8E5M2FNUZ>) -> tensor<2xf8E4M3> +// CHECK: %[[VAL_6:.*]] = stablehlo.convert %[[VAL_5]] : (tensor<2xf8E4M3>) -> tensor<2xf8E3M4> +// CHECK: %[[VAL_7:.*]] = stablehlo.convert %[[VAL_6]] : (tensor<2xf8E3M4>) -> tensor<2xf4E2M1FN> +// CHECK: %[[VAL_8:.*]] = stablehlo.convert %[[VAL_7]] : (tensor<2xf4E2M1FN>) -> tensor<2xf32> +// CHECK: return %[[VAL_8]] : tensor<2xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2]{0})->f32[2]{0}} + +ENTRY %main.10 (Arg_0.1: f32[2]) -> f32[2] { + %Arg_0.1 = f32[2] parameter(0) + %convert.2 = f8e5m2[2] convert(%Arg_0.1) + %convert.3 = f8e4m3fn[2] convert(%convert.2) + %convert.4 = f8e4m3fnuz[2] convert(%convert.3) + %convert.5 = f8e5m2fnuz[2] convert(%convert.4) + %convert.6 = f8e4m3[2] convert(%convert.5) + %convert.7 = f8e3m4[2] convert(%convert.6) + %convert.8 = f4e2m1fn[2] convert(%convert.7) + ROOT %convert.9 = f32[2] convert(%convert.8) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<5x5xf32>, %[[VAL_1:.*]]: tensor<5x5xui32>) -> tensor<5x5xi8> { +// CHECK: %[[VAL_2:.*]] = "mhlo.stochastic_convert"(%[[VAL_0]], %[[VAL_1]]) : (tensor<5x5xf32>, tensor<5x5xui32>) -> tensor<5x5xi8> +// CHECK: return %[[VAL_2]] : tensor<5x5xi8> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[5,5]{1,0}, u32[5,5]{1,0})->s8[5,5]{1,0}} + +ENTRY %main.4 (Arg_0.1: f32[5,5], Arg_1.2: u32[5,5]) -> s8[5,5] { + %Arg_0.1 = f32[5,5] parameter(0) + %Arg_1.2 = u32[5,5] parameter(1) + ROOT %stochastic-convert.3 = s8[5,5] stochastic-convert(%Arg_0.1, %Arg_1.2) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = 
mhlo.copy %[[VAL_0]] : tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[2]{0})->s32[2]{0}} + +ENTRY %main.3 (Arg_0.1: s32[2]) -> s32[2] { + %Arg_0.1 = s32[2] parameter(0) + ROOT %copy.2 = s32[2] copy(%Arg_0.1) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @sum.2(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<10xf32> { +// CHECK: %[[VAL_4:.*]] = "mhlo.all_reduce"(%[[VAL_3]]) <{replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> ({ +// CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): +// CHECK: %[[VAL_7:.*]] = stablehlo.add %[[VAL_5]], %[[VAL_6]] : tensor +// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: }) : (tensor<10xf32>) -> tensor<10xf32> +// CHECK: return %[[VAL_4]] : tensor<10xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} + +%sum.2 (x.3: f32[], y.4: f32[]) -> f32[] { + %x.3 = f32[] parameter(0) + %y.4 = f32[] parameter(1) + ROOT %add.5 = f32[] add(%x.3, %y.4) +} + +ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] { + %Arg_0.1 = f32[10] parameter(0) + ROOT %all-reduce.6 = f32[10] all-reduce(%Arg_0.1), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=%sum.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32> { +// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @SetBound(%[[VAL_0]]) {backend_config = "", mhlo.literal = dense<1> : tensor} : (tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK: return %[[VAL_1]] : tensor<2x3xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,3]{1,0})->f32[2,3]{1,0}} + +ENTRY %main.3 (Arg_0.1: f32[2,3]) -> f32[2,3] { + %Arg_0.1 = f32[2,3] parameter(0) + ROOT %custom-call.2 = f32[2,3] custom-call(%Arg_0.1), custom_call_target="SetBound", literal=s32[] 1 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<6xf32>, %[[VAL_1:.*]]: tensor<6xf32>, %[[VAL_2:.*]]: tensor<3xi32>, %[[VAL_3:.*]]: tensor<3xi32>, %[[VAL_4:.*]]: tensor<3xi32>, %[[VAL_5:.*]]: tensor<3xi32>) -> tensor<6xf32> { +// CHECK: %[[VAL_6:.*]] = stablehlo.custom_call @ragged_all_to_all(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) {api_version = 4 : i32, backend_config = {channel_id = 1 : i64, replica_groups = dense<{{\[\[}}0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32> +// CHECK: return %[[VAL_6]] : tensor<6xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[6]{0}, f32[6]{0}, s32[3]{0}, s32[3]{0}, s32[3]{0}, /*index=5*/s32[3]{0})->f32[6]{0}} + +ENTRY %main.8 (Arg_0.1: f32[6], Arg_1.2: f32[6], Arg_2.3: s32[3], Arg_3.4: s32[3], Arg_4.5: s32[3], Arg_5.6: s32[3]) -> f32[6] { + %Arg_0.1 = f32[6] 
parameter(0) + %Arg_1.2 = f32[6] parameter(1) + %Arg_2.3 = s32[3] parameter(2) + %Arg_3.4 = s32[3] parameter(3) + %Arg_4.5 = s32[3] parameter(4) + %Arg_5.6 = s32[3] parameter(5) + ROOT %ragged-all-to-all.7 = f32[6] ragged-all-to-all(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4, %Arg_4.5, /*index=5*/%Arg_5.6), channel_id=1, replica_groups={{0,1,2}} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @top_k_gt_comparator.5(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_4:.*]] = mhlo.compare GT, %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_4]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_5:.*]]: tensor<16x256xbf16>, %[[VAL_6:.*]]: tensor, %[[VAL_7:.*]]: tensor<16x256xi32>, %[[VAL_8:.*]]: tensor) -> tuple, tensor<16x4xi32>> { +// CHECK: %[[VAL_9:.*]]:2 = "stablehlo.sort"(%[[VAL_5]], %[[VAL_7]]) <{dimension = 1 : i64, is_stable = false}> ({ +// CHECK: ^bb0(%[[VAL_10:.*]]: tensor, %[[VAL_11:.*]]: tensor, %[[VAL_12:.*]]: tensor, %[[VAL_13:.*]]: tensor): +// CHECK: %[[VAL_14:.*]] = mhlo.compare GT, %[[VAL_10]], %[[VAL_11]] : (tensor, tensor) -> tensor +// CHECK: stablehlo.return %[[VAL_14]] : tensor +// CHECK: }) : (tensor<16x256xbf16>, tensor<16x256xi32>) -> (tensor<16x256xbf16>, tensor<16x256xi32>) +// CHECK: %[[VAL_15:.*]] = stablehlo.slice %[[VAL_16:.*]]#0 [0:16, 0:4] : (tensor<16x256xbf16>) -> tensor<16x4xbf16> +// CHECK: %[[VAL_17:.*]] = stablehlo.slice %[[VAL_16]]#1 [0:16, 0:4] : (tensor<16x256xi32>) -> tensor<16x4xi32> +// CHECK: %[[VAL_18:.*]] = mhlo.tuple %[[VAL_15]], %[[VAL_17]] {xla_shape = "(bf16[16,4]{1,0}, s32[16,4]{1,0})"} : tuple, tensor<16x4xi32>> +// CHECK: return %[[VAL_18]] : tuple, tensor<16x4xi32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(bf16[16,256]{1,0}, s32[], s32[16,256]{1,0}, bf16[])->(bf16[16,4]{1,0}, s32[16,4]{1,0})} + +%top_k_gt_comparator.5 (Arg_0.6: bf16[], Arg_1.7: bf16[], Arg_2.8: s32[], Arg_3.9: s32[]) -> pred[] { + %Arg_2.8 = s32[] parameter(2) + %Arg_3.9 = s32[] parameter(3) + %Arg_0.6 = bf16[] parameter(0) + %Arg_1.7 = bf16[] parameter(1) + ROOT %compare.10 = pred[] compare(%Arg_0.6, %Arg_1.7), direction=GT +} + +ENTRY %main.20 (Arg_0.1: bf16[16,256], Arg_1.2: s32[], Arg_2.3: s32[16,256], Arg_3.4: bf16[]) -> (bf16[16,4], s32[16,4]) { + %Arg_1.2 = s32[] parameter(1) + %Arg_3.4 = bf16[] parameter(3) + %Arg_0.1 = bf16[16,256] parameter(0) + %Arg_2.3 = s32[16,256] parameter(2) + %sort.11 = (bf16[16,256], s32[16,256]) sort(%Arg_0.1, %Arg_2.3), dimensions={1}, to_apply=%top_k_gt_comparator.5 + %get-tuple-element.12 = bf16[16,256] get-tuple-element(%sort.11), index=0 + %slice.13 = bf16[16,4] slice(%get-tuple-element.12), slice={[0:16], [0:4]} + %get-tuple-element.14 = s32[16,256] get-tuple-element(%sort.11), index=1 + %slice.15 = s32[16,4] slice(%get-tuple-element.14), slice={[0:16], [0:4]} + %tuple.16 = (bf16[16,4], s32[16,4]) tuple(%slice.13, %slice.15) + %get-tuple-element.17 = bf16[16,4] get-tuple-element(%tuple.16), index=0 + %get-tuple-element.18 = s32[16,4] get-tuple-element(%tuple.16), index=1 + ROOT %tuple.19 = (bf16[16,4], s32[16,4]) tuple(%get-tuple-element.17, %get-tuple-element.18) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, 
mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @top_k_gt_comparator.5(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_4:.*]] = mhlo.compare GT, %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_4]] : tensor +// CHECK: } +// CHECK: func.func private @top_k_gt_comparator.14(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor, %[[VAL_7:.*]]: tensor, %[[VAL_8:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_9:.*]] = mhlo.compare GT, %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_9]] : tensor +// CHECK: } +// CHECK: func.func @main(%[[VAL_10:.*]]: tensor<16x256xbf16>, %[[VAL_11:.*]]: tensor, %[[VAL_12:.*]]: tensor<16x256xi32>, %[[VAL_13:.*]]: tensor) -> tuple, tensor<16x4xi32>> { +// CHECK: %[[VAL_14:.*]] = mhlo.custom_call @PartialReduce(%[[VAL_10]], %[[VAL_12]], %[[VAL_13]], %[[VAL_11]]) {backend_config = "{\22log2_reduction\22: 1, \22reduction_dim\22: 1, \22to_apply_type\22: \22comparator\22, \22top_k\22: 4, \22recall_target\22: 0.949218}", called_computations = [@top_k_gt_comparator.5], xla_shape = "(bf16[16,128]{1,0}, s32[16,128]{1,0})"} : (tensor<16x256xbf16>, tensor<16x256xi32>, tensor, tensor) -> tuple, tensor<16x128xi32>> +// CHECK: %[[VAL_15:.*]] = stablehlo.get_tuple_element %[[VAL_14]][0] : (tuple, tensor<16x128xi32>>) -> tensor<16x128xbf16> +// CHECK: %[[VAL_16:.*]] = stablehlo.get_tuple_element %[[VAL_14]][1] : (tuple, tensor<16x128xi32>>) -> tensor<16x128xi32> +// CHECK: %[[VAL_17:.*]]:2 = "stablehlo.sort"(%[[VAL_15]], %[[VAL_16]]) <{dimension = 1 : i64, is_stable = false}> ({ +// CHECK: ^bb0(%[[VAL_18:.*]]: tensor, %[[VAL_19:.*]]: tensor, %[[VAL_20:.*]]: tensor, %[[VAL_21:.*]]: tensor): +// CHECK: %[[VAL_22:.*]] = mhlo.compare GT, %[[VAL_18]], %[[VAL_19]] : (tensor, tensor) -> tensor +// CHECK: stablehlo.return %[[VAL_22]] : tensor +// CHECK: }) : (tensor<16x128xbf16>, tensor<16x128xi32>) -> (tensor<16x128xbf16>, tensor<16x128xi32>) +// CHECK: %[[VAL_23:.*]] = stablehlo.slice %[[VAL_24:.*]]#0 [0:16, 0:4] : (tensor<16x128xbf16>) -> tensor<16x4xbf16> +// CHECK: %[[VAL_25:.*]] = stablehlo.slice %[[VAL_24]]#1 [0:16, 0:4] : (tensor<16x128xi32>) -> tensor<16x4xi32> +// CHECK: %[[VAL_26:.*]] = mhlo.tuple %[[VAL_23]], %[[VAL_25]] {xla_shape = "(bf16[16,4]{1,0}, s32[16,4]{1,0})"} : tuple, tensor<16x4xi32>> +// CHECK: return %[[VAL_26]] : tuple, tensor<16x4xi32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(bf16[16,256]{1,0}, s32[], s32[16,256]{1,0}, bf16[])->(bf16[16,4]{1,0}, s32[16,4]{1,0})} + +%top_k_gt_comparator.5 (Arg_0.6: bf16[], Arg_1.7: bf16[], Arg_2.8: s32[], Arg_3.9: s32[]) -> pred[] { + %Arg_2.8 = s32[] parameter(2) + %Arg_3.9 = s32[] parameter(3) + %Arg_0.6 = bf16[] parameter(0) + %Arg_1.7 = bf16[] parameter(1) + ROOT %compare.10 = pred[] compare(%Arg_0.6, %Arg_1.7), direction=GT +} + +%top_k_gt_comparator.14 (Arg_0.15: bf16[], Arg_1.16: bf16[], Arg_2.17: s32[], Arg_3.18: s32[]) -> pred[] { + %Arg_2.17 = s32[] parameter(2) + %Arg_3.18 = s32[] parameter(3) + %Arg_0.15 = bf16[] parameter(0) + %Arg_1.16 = bf16[] parameter(1) + ROOT %compare.19 = pred[] compare(%Arg_0.15, %Arg_1.16), direction=GT +} + +ENTRY %main.29 (Arg_0.1: bf16[16,256], Arg_1.2: s32[], Arg_2.3: s32[16,256], Arg_3.4: bf16[]) -> (bf16[16,4], s32[16,4]) { + %Arg_0.1 = bf16[16,256] parameter(0) + %Arg_2.3 = s32[16,256] parameter(2) + %Arg_3.4 = bf16[] parameter(3) + %Arg_1.2 = s32[] parameter(1) + %custom-call.11 = (bf16[16,128], s32[16,128]) 
custom-call(%Arg_0.1, %Arg_2.3, %Arg_3.4, %Arg_1.2), custom_call_target="PartialReduce", called_computations={%top_k_gt_comparator.5}, backend_config={"log2_reduction": 1, "reduction_dim": 1, "to_apply_type": "comparator", "top_k": 4, "recall_target": 0.949218} + %get-tuple-element.12 = bf16[16,128] get-tuple-element(%custom-call.11), index=0 + %get-tuple-element.13 = s32[16,128] get-tuple-element(%custom-call.11), index=1 + %sort.20 = (bf16[16,128], s32[16,128]) sort(%get-tuple-element.12, %get-tuple-element.13), dimensions={1}, to_apply=%top_k_gt_comparator.14 + %get-tuple-element.21 = bf16[16,128] get-tuple-element(%sort.20), index=0 + %slice.22 = bf16[16,4] slice(%get-tuple-element.21), slice={[0:16], [0:4]} + %get-tuple-element.23 = s32[16,128] get-tuple-element(%sort.20), index=1 + %slice.24 = s32[16,4] slice(%get-tuple-element.23), slice={[0:16], [0:4]} + %tuple.25 = (bf16[16,4], s32[16,4]) tuple(%slice.22, %slice.24) + %get-tuple-element.26 = bf16[16,4] get-tuple-element(%tuple.25), index=0 + %get-tuple-element.27 = s32[16,4] get-tuple-element(%tuple.25), index=1 + ROOT %tuple.28 = (bf16[16,4], s32[16,4]) tuple(%get-tuple-element.26, %get-tuple-element.27) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>, %[[VAL_1:.*]]: tensor<5x5xf32>) -> tensor<1x2x3xf32> { +// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {backend_config = "bar", custom_call_schedule = #mhlo<custom_call_schedule LATEST>, has_side_effect = true} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> +// CHECK: return %[[VAL_2]] : tensor<1x2x3xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,3]{1,0}, f32[5,5]{1,0})->f32[1,2,3]{2,1,0}} + +ENTRY %main.4 (Arg_0.1: f32[2,3], Arg_1.2: f32[5,5]) -> f32[1,2,3] { + %Arg_0.1 = f32[2,3] parameter(0) + %Arg_1.2 = f32[5,5] parameter(1) + ROOT %custom-call.3 = f32[1,2,3] custom-call(%Arg_0.1, %Arg_1.2), custom_call_target="foo", custom_call_has_side_effect=true, schedule=SCHEDULE_LATEST, backend_config="bar" +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>, %[[VAL_1:.*]]: tensor<5x5xf32>) -> tensor<1x2x3xf32> { +// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {backend_config = "bar", custom_call_schedule = #mhlo<custom_call_schedule EARLIEST>, has_side_effect = true} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> +// CHECK: return %[[VAL_2]] : tensor<1x2x3xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,3]{1,0}, f32[5,5]{1,0})->f32[1,2,3]{2,1,0}} + +ENTRY %main.4 (Arg_0.1: f32[2,3], Arg_1.2: f32[5,5]) -> f32[1,2,3] { + %Arg_0.1 = f32[2,3] parameter(0) + %Arg_1.2 = f32[5,5] parameter(1) + ROOT %custom-call.3 = f32[1,2,3] custom-call(%Arg_0.1, %Arg_1.2), custom_call_target="foo", custom_call_has_side_effect=true, schedule=SCHEDULE_EARLIEST, backend_config="bar" +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>) -> tuple<tensor<2x3xf32>> { +// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @foo(%[[VAL_0]]) {backend_config = "", xla_shape = 
"(f32[2,3]{1,0})"} : (tensor<2x3xf32>) -> tuple> +// CHECK: return %[[VAL_1]] : tuple> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,3]{1,0})->(f32[2,3]{1,0})} + +ENTRY %main.3 (Arg_0.1: f32[2,3]) -> (f32[2,3]) { + %Arg_0.1 = f32[2,3] parameter(0) + ROOT %custom-call.2 = (f32[2,3]) custom-call(%Arg_0.1), custom_call_target="foo" +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> { +// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @foo(%[[VAL_0]]) {backend_config = "", xla_shape = "(f32[2,3]{1,0}, f16[4,5]{1,0})"} : (tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> +// CHECK: return %[[VAL_1]] : tuple, tensor<4x5xf16>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,3]{1,0})->(f32[2,3]{1,0}, f16[4,5]{1,0})} + +ENTRY %main.3 (Arg_0.1: f32[2,3]) -> (f32[2,3], f16[4,5]) { + %Arg_0.1 = f32[2,3] parameter(0) + ROOT %custom-call.2 = (f32[2,3], f16[4,5]) custom-call(%Arg_0.1), custom_call_target="foo" +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> { +// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @foo(%[[VAL_0]]) {backend_config = "", xla_shape = "(f32[2,3]{1,0}, f16[4,5]{1,0})"} : (tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> +// CHECK: %[[VAL_2:.*]] = stablehlo.get_tuple_element %[[VAL_1]][0] : (tuple, tensor<4x5xf16>>) -> tensor<2x3xf32> +// CHECK: %[[VAL_3:.*]] = stablehlo.get_tuple_element %[[VAL_1]][1] : (tuple, tensor<4x5xf16>>) -> tensor<4x5xf16> +// CHECK: %[[VAL_4:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_3]] {xla_shape = "(f32[2,3]{1,0}, f16[4,5]{1,0})"} : tuple, tensor<4x5xf16>> +// CHECK: return %[[VAL_4]] : tuple, tensor<4x5xf16>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,3]{1,0})->(f32[2,3]{1,0}, f16[4,5]{1,0})} + +ENTRY %main.6 (Arg_0.1: f32[2,3]) -> (f32[2,3], f16[4,5]) { + %Arg_0.1 = f32[2,3] parameter(0) + %custom-call.2 = (f32[2,3], f16[4,5]) custom-call(%Arg_0.1), custom_call_target="foo" + %get-tuple-element.3 = f32[2,3] get-tuple-element(%custom-call.2), index=0 + %get-tuple-element.4 = f16[4,5] get-tuple-element(%custom-call.2), index=1 + ROOT %tuple.5 = (f32[2,3], f16[4,5]) tuple(%get-tuple-element.3, %get-tuple-element.4) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi8>, %[[VAL_1:.*]]: tensor<3xi8>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "mhlo.dot"(%[[VAL_0]], %[[VAL_1]]) <{precision_config = [#mhlo, #mhlo]}> : (tensor<3xi8>, tensor<3xi8>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s8[3]{0}, s8[3]{0})->s64[]} + +ENTRY %main.4 (Arg_0.1: s8[3], Arg_1.2: s8[3]) -> s64[] { + %Arg_0.1 = s8[3] parameter(0) + %Arg_1.2 = s8[3] parameter(1) + ROOT %dot.3 = s64[] dot(%Arg_0.1, %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, 
mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi4>, %[[VAL_1:.*]]: tensor<3xi4>) -> tensor<i8> { +// CHECK: %[[VAL_2:.*]] = "mhlo.dot"(%[[VAL_0]], %[[VAL_1]]) <{precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}> : (tensor<3xi4>, tensor<3xi4>) -> tensor<i8> +// CHECK: return %[[VAL_2]] : tensor<i8> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s4[3]{0}, s4[3]{0})->s8[]} + +ENTRY %main.4 (Arg_0.1: s4[3], Arg_1.2: s4[3]) -> s8[] { + %Arg_0.1 = s4[3] parameter(0) + %Arg_1.2 = s4[3] parameter(1) + ROOT %dot.3 = s8[] dot(%Arg_0.1, %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xui4>, %[[VAL_1:.*]]: tensor<3xui4>) -> tensor<ui8> { +// CHECK: %[[VAL_2:.*]] = "mhlo.dot"(%[[VAL_0]], %[[VAL_1]]) <{precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}> : (tensor<3xui4>, tensor<3xui4>) -> tensor<ui8> +// CHECK: return %[[VAL_2]] : tensor<ui8> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(u4[3]{0}, u4[3]{0})->u8[]} + +ENTRY %main.4 (Arg_0.1: u4[3], Arg_1.2: u4[3]) -> u8[] { + %Arg_0.1 = u4[3] parameter(0) + %Arg_1.2 = u4[3] parameter(1) + ROOT %dot.3 = u8[] dot(%Arg_0.1, %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x2x2xi8>, %[[VAL_1:.*]]: tensor<2x2x3xi8>) -> tensor<2x2x3xi32> { +// CHECK: %[[VAL_2:.*]] = "mhlo.dot_general"(%[[VAL_0]], %[[VAL_1]]) <{dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}> : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32> +// CHECK: return %[[VAL_2]] : tensor<2x2x3xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s8[2,2,2]{2,1,0}, s8[2,2,3]{2,1,0})->s32[2,2,3]{2,1,0}} + +ENTRY %main.4 (Arg_0.1: s8[2,2,2], Arg_1.2: s8[2,2,3]) -> s32[2,2,3] { + %Arg_0.1 = s8[2,2,2] parameter(0) + %Arg_1.2 = s8[2,2,3] parameter(1) + ROOT %dot.3 = s32[2,2,3] dot(%Arg_0.1, %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<10x16xbf16>, %[[VAL_1:.*]]: tensor<32x20xbf16>, %[[VAL_2:.*]]: tensor<10x2xui16>) -> tensor<10x20xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.sparse_dot"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, lhs_sparsity = #mhlo.sparsity<dimension = 1, n = 2, m = 4>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}> : (tensor<10x16xbf16>, tensor<32x20xbf16>, tensor<10x2xui16>) -> tensor<10x20xf32> +// CHECK: return %[[VAL_3]] : tensor<10x20xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(bf16[10,16]{1,0}, bf16[32,20]{1,0}, u16[10,2]{1,0})->f32[10,20]{1,0}} + +ENTRY %main.5 (Arg_0.1: bf16[10,16], Arg_1.2: bf16[32,20], Arg_2.3: u16[10,2]) -> f32[10,20] { + %Arg_0.1 = bf16[10,16] parameter(0) + %Arg_1.2 = bf16[32,20] parameter(1) + %Arg_2.3 = u16[10,2] parameter(2) + ROOT %dot.4 = f32[10,20] dot(%Arg_0.1, %Arg_1.2, %Arg_2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +} +
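+// NOTE: In the sparse dot test above, the HLO annotation sparsity=L.1@2:4 marks the LHS as 2:4 structured-sparse along dimension 1 (at most 2 non-zero values in each group of 4), with the u16[10,2] operand carrying the sparsity metadata; per the reconstructed CHECK line this is expected to import as lhs_sparsity = #mhlo.sparsity<dimension = 1, n = 2, m = 4> on mhlo.sparse_dot (treat the exact attribute spelling as an assumption, not a verified printer output).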
+// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>, %[[VAL_1:.*]]: tensor<4x5xi32>) -> tensor<3x5xi32> { +// CHECK: %[[VAL_2:.*]] = "mhlo.dot"(%[[VAL_0]], %[[VAL_1]]) <{precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}> {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> +// CHECK: %[[VAL_3:.*]] = stablehlo.transpose %[[VAL_2]], dims = [0, 1] : (tensor<3x5xi32>) -> tensor<3x5xi32> +// CHECK: return %[[VAL_3]] : tensor<3x5xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[3,4]{1,0}, s32[4,5]{1,0})->s32[3,5]{1,0}} + +ENTRY %main.5 (Arg_0.1: s32[3,4], Arg_1.2: s32[4,5]) -> s32[3,5] { + %Arg_0.1 = s32[3,4] parameter(0) + %Arg_1.2 = s32[4,5] parameter(1) + %dot.3 = s32[3,5] dot(%Arg_0.1, %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"} + ROOT %transpose.4 = s32[3,5] transpose(%dot.3), dimensions={0,1} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>> { +// CHECK: %[[VAL_1:.*]] = stablehlo.fft %[[VAL_0]], type = RFFT, length = [9] : (tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>> +// CHECK: return %[[VAL_1]] : tensor<3x5xcomplex<f32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[3,9]{1,0})->c64[3,5]{1,0}} + +ENTRY %main.3 (Arg_0.1: f32[3,9]) -> c64[3,5] { + %Arg_0.1 = f32[3,9] parameter(0) + ROOT %fft.2 = c64[3,5] fft(%Arg_0.1), fft_type=RFFT, fft_length={9} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<200x100x300xf32>, %[[VAL_1:.*]]: tensor<10x2xi32>) -> tensor<10x300xf32> { +// CHECK: %[[VAL_2:.*]] = "mhlo.gather"(%[[VAL_0]], %[[VAL_1]]) <{dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = true, slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>}> : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> +// CHECK: return %[[VAL_2]] : tensor<10x300xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[200,100,300]{2,1,0}, s32[10,2]{1,0})->f32[10,300]{1,0}} + +ENTRY %main.4 (Arg_0.1: f32[200,100,300], Arg_1.2: s32[10,2]) -> f32[10,300] { + %Arg_0.1 = f32[200,100,300] parameter(0) + %Arg_1.2 = s32[10,2] parameter(1) + ROOT %gather.3 = f32[10,300] gather(%Arg_0.1, %Arg_1.2), offset_dims={1}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1,300}, indices_are_sorted=true +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<200x100x300xf32>, %[[VAL_1:.*]]: tensor<100x200x1xi32>) -> tensor<100x200x300xf32> { +// CHECK: %[[VAL_2:.*]] = "mhlo.gather"(%[[VAL_0]], %[[VAL_1]]) <{dimension_numbers = #mhlo.gather<offset_dims = [2], operand_batching_dims = [0, 1], start_indices_batching_dims = [1, 0], start_index_map = [2], index_vector_dim = 2>, indices_are_sorted = true, slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>}> : (tensor<200x100x300xf32>, tensor<100x200x1xi32>) -> tensor<100x200x300xf32> +// CHECK: 
return %[[VAL_2]] : tensor<100x200x300xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[200,100,300]{2,1,0}, s32[100,200,1]{2,1,0})->f32[100,200,300]{2,1,0}} + +ENTRY %main.4 (Arg_0.1: f32[200,100,300], Arg_1.2: s32[100,200,1]) -> f32[100,200,300] { + %Arg_0.1 = f32[200,100,300] parameter(0) + %Arg_1.2 = s32[100,200,1] parameter(1) + ROOT %gather.3 = f32[100,200,300] gather(%Arg_0.1, %Arg_1.2), offset_dims={2}, collapsed_slice_dims={}, start_index_map={2}, operand_batching_dims={0,1}, start_indices_batching_dims={1,0}, index_vector_dim=2, slice_sizes={1,1,300}, indices_are_sorted=true +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x2xf32>, %[[VAL_1:.*]]: tensor<i32>) -> tensor<i32> { +// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<2> : tensor<i32> +// CHECK: return %[[VAL_2]] : tensor<i32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[4,2]{1,0}, s32[])->s32[]} + +ENTRY %main.4 (Arg_0.1: f32[4,2], Arg_1.2: s32[]) -> s32[] { + %Arg_0.1 = f32[4,2] parameter(0) + %Arg_1.2 = s32[] parameter(1) + ROOT %constant.3 = s32[] constant(2) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor<i32>) -> tensor<4x?xf32, #mhlo.type_extensions<bounds = [?, 4]>> { +// CHECK: %[[VAL_2:.*]] = stablehlo.set_dimension_size %[[VAL_0]], %[[VAL_1]], dim = 1 : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x?xf32, #mhlo.type_extensions<bounds = [?, 4]>> +// CHECK: return %[[VAL_2]] : tensor<4x?xf32, #mhlo.type_extensions<bounds = [?, 4]>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[4,4]{1,0}, s32[])->f32[4,<=4]{1,0}} + +ENTRY %main.4 (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] { + %Arg_0.1 = f32[4,4] parameter(0) + %Arg_1.2 = s32[] parameter(1) + ROOT %set-dimension-size.3 = f32[4,<=4] set-dimension-size(%Arg_0.1, %Arg_1.2), dimensions={1} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> { +// CHECK: %[[VAL_1:.*]] = stablehlo.get_tuple_element %[[VAL_0]][0] : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> +// CHECK: return %[[VAL_1]] : tensor<f32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={((f32[], s32[]))->f32[]} + +ENTRY %main.3 (Arg_0.1: (f32[], s32[])) -> f32[] { + %Arg_0.1 = (f32[], s32[]) parameter(0) + ROOT %get-tuple-element.2 = f32[] get-tuple-element(%Arg_0.1), index=0 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> tuple<tuple<tensor<3x3xi32>, tensor<i1>>, !mhlo.token> { +// CHECK: %[[VAL_1:.*]]:3 = "mhlo.infeed"(%[[VAL_0]]) <{infeed_config = "foobar", layout = {{\[\[}}1, 0], []]}> : (!mhlo.token) -> (tensor<3x3xi32>, tensor<i1>, !mhlo.token) +// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]]#0, %[[VAL_1]]#1 {xla_shape = "(s32[3,3]{1,0}, pred[])"} : tuple<tensor<3x3xi32>, tensor<i1>> +// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_1]]#2 {xla_shape = "((s32[3,3]{1,0}, pred[]), token[])"} : tuple<tuple<tensor<3x3xi32>, tensor<i1>>, !mhlo.token>
+// CHECK: return %[[VAL_3]] : tuple<tuple<tensor<3x3xi32>, tensor<i1>>, !mhlo.token> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(token[])->((s32[3,3]{1,0}, pred[]), token[])} + +ENTRY %main.9 (Arg_0.1: token[]) -> ((s32[3,3], pred[]), token[]) { + %Arg_0.1 = token[] parameter(0) + %infeed.2 = ((s32[3,3], pred[]), token[]) infeed(%Arg_0.1), infeed_config="foobar" + %get-tuple-element.3 = (s32[3,3], pred[]) get-tuple-element(%infeed.2), index=0 + %get-tuple-element.4 = s32[3,3] get-tuple-element(%get-tuple-element.3), index=0 + %get-tuple-element.5 = pred[] get-tuple-element(%get-tuple-element.3), index=1 + %tuple.7 = (s32[3,3], pred[]) tuple(%get-tuple-element.4, %get-tuple-element.5) + %get-tuple-element.6 = token[] get-tuple-element(%infeed.2), index=1 + ROOT %tuple.8 = ((s32[3,3], pred[]), token[]) tuple(%tuple.7, %get-tuple-element.6) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false, mhlo.xla_entry_computation_result_layout = [dense<[0, 1]> : tensor<2xindex>], mhlo.xla_entry_computation_result_tiles = {{\[\[}}]]} { +// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> tensor<3x3xi32> { +// CHECK: %[[VAL_1:.*]]:2 = "mhlo.infeed"(%[[VAL_0]]) <{infeed_config = "foobar", layout = {{\[\[}}1, 0]]}> : (!mhlo.token) -> (tensor<3x3xi32>, !mhlo.token) +// CHECK: return %[[VAL_1]]#0 : tensor<3x3xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(token[])->s32[3,3]{0,1}} + +ENTRY %main.6 (Arg_0.1: token[]) -> s32[3,3] { + %Arg_0.1 = token[] parameter(0) + %infeed.2 = ((s32[3,3]), token[]) infeed(%Arg_0.1), infeed_config="foobar" + %get-tuple-element.3 = (s32[3,3]) get-tuple-element(%infeed.2), index=0 + ROOT %get-tuple-element.4 = s32[3,3] get-tuple-element(%get-tuple-element.3), index=0 + %get-tuple-element.5 = token[] get-tuple-element(%infeed.2), index=1 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> !mhlo.token { +// CHECK: %[[VAL_1:.*]] = "mhlo.infeed"(%[[VAL_0]]) <{infeed_config = "foobar", layout = []}> : (!mhlo.token) -> !mhlo.token +// CHECK: return %[[VAL_1]] : !mhlo.token +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(token[])->token[]} + +ENTRY %main.4 (Arg_0.1: token[]) -> token[] { + %Arg_0.1 = token[] parameter(0) + %infeed.2 = ((), token[]) infeed(%Arg_0.1), infeed_config="foobar" + ROOT %get-tuple-element.3 = token[] get-tuple-element(%infeed.2), index=1 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main() -> tensor<1x10xf32> { +// CHECK: %[[VAL_0:.*]] = stablehlo.iota dim = 0 : tensor<10xf32> +// CHECK: %[[VAL_1:.*]] = stablehlo.reshape %[[VAL_0]] : (tensor<10xf32>) -> tensor<1x10xf32> +// CHECK: return %[[VAL_1]] : tensor<1x10xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={()->f32[1,10]{1,0}} + +ENTRY %main.3 () -> f32[1,10] { + %iota.1 = f32[10] iota(), iota_dimension=0 + ROOT %reshape.2 = f32[1,10] reshape(%iota.1) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = 
+// CHECK: func.func private @region_0.3(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[VAL_2:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor<f32>
+// CHECK: return %[[VAL_2]] : tensor<f32>
+// CHECK: }
+// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<4xf32>, %[[VAL_4:.*]]: tensor<4xf32>) -> tensor<4xf32> {
+// CHECK: %[[VAL_5:.*]] = "stablehlo.map"(%[[VAL_3]], %[[VAL_4]]) <{dimensions = array<i64: 0>}> ({
+// CHECK: ^bb0(%[[VAL_6:.*]]: tensor<f32>, %[[VAL_7:.*]]: tensor<f32>):
+// CHECK: %[[VAL_8:.*]] = stablehlo.add %[[VAL_6]], %[[VAL_7]] : tensor<f32>
+// CHECK: stablehlo.return %[[VAL_8]] : tensor<f32>
+// CHECK: }) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[VAL_5]] : tensor<4xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[4]{0}, f32[4]{0})->f32[4]{0}}
+
+%region_0.3 (Arg_0.4: f32[], Arg_1.5: f32[]) -> f32[] {
+ %Arg_0.4 = f32[] parameter(0)
+ %Arg_1.5 = f32[] parameter(1)
+ ROOT %add.6 = f32[] add(%Arg_0.4, %Arg_1.5)
+}
+
+ENTRY %main.8 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
+ %Arg_0.1 = f32[4] parameter(0)
+ %Arg_1.2 = f32[4] parameter(1)
+ ROOT %map.7 = f32[4] map(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%region_0.3
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xf32>, %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xf32> {
+// CHECK: return %[[VAL_0]] : tensor<4xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[4]{0}, s32[4]{0})->f32[4]{0}}
+
+ENTRY %main.3 (Arg_0.1: f32[4], Arg_1.2: s32[4]) -> f32[4] {
+ ROOT %Arg_0.1 = f32[4] parameter(0)
+ %Arg_1.2 = s32[4] parameter(1)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token {
+// CHECK: %[[VAL_2:.*]] = "mhlo.outfeed"(%[[VAL_0]], %[[VAL_1]]) <{outfeed_config = "foobar"}> {xla_shape = "token[]"} : (tensor<3xi32>, !mhlo.token) -> !mhlo.token
+// CHECK: return %[[VAL_2]] : !mhlo.token
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(s32[3]{0}, token[])->token[]}
+
+ENTRY %main.5 (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] {
+ %Arg_0.1 = s32[3] parameter(0)
+ %tuple.3 = (s32[3]) tuple(%Arg_0.1)
+ %Arg_1.2 = token[] parameter(1)
+ ROOT %outfeed.4 = token[] outfeed(%tuple.3, %Arg_1.2), outfeed_shape=(s32[3]{0}), outfeed_config="foobar"
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x2xi32>, %[[VAL_1:.*]]: !mhlo.token) -> (!mhlo.token {mhlo.sharding = "{{\{\{}}devices=[2,1]0,1}, {maximal device=0}}"}) {
+// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @Sharding(%[[VAL_0]]) {backend_config = "", mhlo.sharding = "{devices=[1,2]0,1}"} : (tensor<3x2xi32>) -> tensor<3x2xi32>
+// CHECK: %[[VAL_3:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[VAL_2]]) {backend_config = "", mhlo.sharding = "{devices=[1,2]0,1}"} : (tensor<3x2xi32>) -> tensor<6x2xi32>
+// CHECK: %[[VAL_4:.*]] = "mhlo.outfeed"(%[[VAL_3]], %[[VAL_1]]) <{outfeed_config = "foobar"}> {mhlo.sharding = "{{\{\{}}devices=[2,1]0,1}, {maximal device=0}}", xla_shape = "token[]"} : (tensor<6x2xi32>, !mhlo.token) -> !mhlo.token
xla_shape = "token[]"} : (tensor<6x2xi32>, !mhlo.token) -> !mhlo.token +// CHECK: return %[[VAL_4]] : !mhlo.token +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[3,2]{1,0}, token[])->token[]} + +ENTRY %main.7 (Arg_0.1: s32[3,2], Arg_1.2: token[]) -> token[] { + %Arg_0.1 = s32[3,2] parameter(0) + %custom-call.3 = s32[3,2] custom-call(%Arg_0.1), custom_call_target="Sharding", sharding={devices=[1,2]0,1} + %custom-call.4 = s32[6,2] custom-call(%custom-call.3), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,2]0,1} + %tuple.5 = (s32[6,2]) tuple(%custom-call.4) + %Arg_1.2 = token[] parameter(1) + ROOT %outfeed.6 = token[] outfeed(%tuple.5, %Arg_1.2), outfeed_shape=(s32[6,2]{1,0}), outfeed_config="foobar", sharding={{devices=[2,1]0,1}, {maximal device=0}} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi32>, %[[VAL_1:.*]]: tensor<3xi32>, %[[VAL_2:.*]]: !mhlo.token) -> !mhlo.token { +// CHECK: %[[VAL_3:.*]] = "mhlo.outfeed"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{outfeed_config = "foobar"}> {xla_shape = "token[]"} : (tensor<3xi32>, tensor<3xi32>, !mhlo.token) -> !mhlo.token +// CHECK: return %[[VAL_3]] : !mhlo.token +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[3]{0}, s32[3]{0}, token[])->token[]} + +ENTRY %main.6 (Arg_0.1: s32[3], Arg_1.2: s32[3], Arg_2.3: token[]) -> token[] { + %Arg_0.1 = s32[3] parameter(0) + %Arg_1.2 = s32[3] parameter(1) + %tuple.4 = (s32[3], s32[3]) tuple(%Arg_0.1, %Arg_1.2) + %Arg_2.3 = token[] parameter(2) + ROOT %outfeed.5 = token[] outfeed(%tuple.4, %Arg_2.3), outfeed_shape=(s32[3]{0}, s32[3]{0}), outfeed_config="foobar" +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> !mhlo.token { +// CHECK: %[[VAL_1:.*]] = "mhlo.outfeed"(%[[VAL_0]]) <{outfeed_config = "foobar"}> {xla_shape = "token[]"} : (!mhlo.token) -> !mhlo.token +// CHECK: return %[[VAL_1]] : !mhlo.token +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(token[])->token[]} + +ENTRY %main.4 (Arg_0.1: token[]) -> token[] { + %tuple.2 = () tuple() + %Arg_0.1 = token[] parameter(0) + ROOT %outfeed.3 = token[] outfeed(%tuple.2, %Arg_0.1), outfeed_shape=(), outfeed_config="foobar" +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x6xf32>, %[[VAL_1:.*]]: tensor) -> tensor<13x19xf32> { +// CHECK: %[[VAL_2:.*]] = stablehlo.pad %[[VAL_0]], %[[VAL_1]], low = [2, 3], high = [4, 5], interior = [1, 1] : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32> +// CHECK: return %[[VAL_2]] : tensor<13x19xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[4,6]{1,0}, f32[])->f32[13,19]{1,0}} + +ENTRY %main.4 (Arg_0.1: f32[4,6], Arg_1.2: f32[]) -> f32[13,19] { + %Arg_0.1 = f32[4,6] parameter(0) + %Arg_1.2 = f32[] parameter(1) + ROOT %pad.3 = f32[13,19] pad(%Arg_0.1, %Arg_1.2), padding=2_4_1x3_5_1 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, 
+// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token> {
+// CHECK: %[[VAL_1:.*]]:2 = "mhlo.recv"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle<handle = 5, type = 3>, is_host_transfer = true}> : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token)
+// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]]#0, %[[VAL_1]]#1 {xla_shape = "(s32[3,4]{1,0}, token[])"} : tuple<tensor<3x4xi32>, !mhlo.token>
+// CHECK: return %[[VAL_2]] : tuple<tensor<3x4xi32>, !mhlo.token>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])}
+
+ENTRY %main.7 (Arg_0.1: token[]) -> (s32[3,4], token[]) {
+ %Arg_0.1 = token[] parameter(0)
+ %recv.2 = (s32[3,4], u32[], token[]) recv(%Arg_0.1), channel_id=5, is_host_transfer=true
+ %recv-done.3 = (s32[3,4], token[]) recv-done(%recv.2), channel_id=5, is_host_transfer=true
+ %get-tuple-element.4 = s32[3,4] get-tuple-element(%recv-done.3), index=0
+ %get-tuple-element.5 = token[] get-tuple-element(%recv-done.3), index=1
+ ROOT %tuple.6 = (s32[3,4], token[]) tuple(%get-tuple-element.4, %get-tuple-element.5)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token> {
+// CHECK: %[[VAL_1:.*]]:2 = "mhlo.recv"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle<handle = 5, type = 1>, is_host_transfer = false}> : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token)
+// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]]#0, %[[VAL_1]]#1 {xla_shape = "(s32[3,4]{1,0}, token[])"} : tuple<tensor<3x4xi32>, !mhlo.token>
+// CHECK: return %[[VAL_2]] : tuple<tensor<3x4xi32>, !mhlo.token>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])}
+
+ENTRY %main.7 (Arg_0.1: token[]) -> (s32[3,4], token[]) {
+ %Arg_0.1 = token[] parameter(0)
+ %recv.2 = (s32[3,4], u32[], token[]) recv(%Arg_0.1), channel_id=5
+ %recv-done.3 = (s32[3,4], token[]) recv-done(%recv.2), channel_id=5
+ %get-tuple-element.4 = s32[3,4] get-tuple-element(%recv-done.3), index=0
+ %get-tuple-element.5 = token[] get-tuple-element(%recv-done.3), index=1
+ ROOT %tuple.6 = (s32[3,4], token[]) tuple(%get-tuple-element.4, %get-tuple-element.5)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> !mhlo.token {
+// CHECK: %[[VAL_1:.*]] = "mhlo.recv"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle<handle = 5, type = 1>, is_host_transfer = false}> : (!mhlo.token) -> !mhlo.token
+// CHECK: return %[[VAL_1]] : !mhlo.token
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(token[])->token[]}
+
+ENTRY %main.6 (Arg_0.1: token[]) -> token[] {
+ %Arg_0.1 = token[] parameter(0)
+ %recv.2 = ((), u32[], token[]) recv(%Arg_0.1), channel_id=5
+ %recv-done.3 = ((), token[]) recv-done(%recv.2), channel_id=5
+ %get-tuple-element.4 = () get-tuple-element(%recv-done.3), index=0
+ ROOT %get-tuple-element.5 = token[] get-tuple-element(%recv-done.3), index=1
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func private @region_0.5(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<i32>, %[[VAL_2:.*]]: tensor<f32>, %[[VAL_3:.*]]: tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
+// CHECK: %[[VAL_4:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_2]] : tensor<f32>
+// CHECK: %[[VAL_5:.*]] = stablehlo.maximum %[[VAL_1]], %[[VAL_3]] : tensor<i32>
+// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[], s32[])"} : tuple<tensor<f32>, tensor<i32>>
+// CHECK: return %[[VAL_6]] : tuple<tensor<f32>, tensor<i32>>
+// CHECK: }
+// CHECK: func.func @main(%[[VAL_7:.*]]: tensor<1x10xf32>, %[[VAL_8:.*]]: tensor<1x10xi32>, %[[VAL_9:.*]]: tensor<f32>, %[[VAL_10:.*]]: tensor<i32>) -> tuple<tensor<1xf32>, tensor<1xi32>> {
+// CHECK: %[[VAL_11:.*]]:2 = mhlo.reduce(%[[VAL_7]] init: %[[VAL_9]]), (%[[VAL_8]] init: %[[VAL_10]]) across dimensions = [1] : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
+// CHECK: reducer(%[[VAL_12:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_14:.*]]: tensor<i32>, %[[VAL_15:.*]]: tensor<i32>) {
+// CHECK: %[[VAL_16:.*]] = stablehlo.maximum %[[VAL_12]], %[[VAL_13]] : tensor<f32>
+// CHECK: %[[VAL_17:.*]] = stablehlo.maximum %[[VAL_14]], %[[VAL_15]] : tensor<i32>
+// CHECK: mhlo.return %[[VAL_16]], %[[VAL_17]] : tensor<f32>, tensor<i32>
+// CHECK: }
+// CHECK: %[[VAL_18:.*]] = mhlo.tuple %[[VAL_11]]#0, %[[VAL_11]]#1 {xla_shape = "(f32[1]{0}, s32[1]{0})"} : tuple<tensor<1xf32>, tensor<1xi32>>
+// CHECK: return %[[VAL_18]] : tuple<tensor<1xf32>, tensor<1xi32>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[1,10]{1,0}, s32[1,10]{1,0}, f32[], s32[])->(f32[1]{0}, s32[1]{0})}
+
+%region_0.5 (Arg_0.6: f32[], Arg_1.7: s32[], Arg_2.8: f32[], Arg_3.9: s32[]) -> (f32[], s32[]) {
+ %Arg_0.6 = f32[] parameter(0)
+ %Arg_2.8 = f32[] parameter(2)
+ %maximum.10 = f32[] maximum(%Arg_0.6, %Arg_2.8)
+ %Arg_1.7 = s32[] parameter(1)
+ %Arg_3.9 = s32[] parameter(3)
+ %maximum.11 = s32[] maximum(%Arg_1.7, %Arg_3.9)
+ ROOT %tuple.12 = (f32[], s32[]) tuple(%maximum.10, %maximum.11)
+}
+
+ENTRY %main.17 (Arg_0.1: f32[1,10], Arg_1.2: s32[1,10], Arg_2.3: f32[], Arg_3.4: s32[]) -> (f32[1], s32[1]) {
+ %Arg_0.1 = f32[1,10] parameter(0)
+ %Arg_1.2 = s32[1,10] parameter(1)
+ %Arg_2.3 = f32[] parameter(2)
+ %Arg_3.4 = s32[] parameter(3)
+ %reduce.13 = (f32[1], s32[1]) reduce(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4), dimensions={1}, to_apply=%region_0.5
+ %get-tuple-element.14 = f32[1] get-tuple-element(%reduce.13), index=0
+ %get-tuple-element.15 = s32[1] get-tuple-element(%reduce.13), index=1
+ ROOT %tuple.16 = (f32[1], s32[1]) tuple(%get-tuple-element.14, %get-tuple-element.15)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func private @region_0.3(%[[VAL_0:.*]]: tensor<i32>, %[[VAL_1:.*]]: tensor<i32>) -> tensor<i32> {
+// CHECK: %[[VAL_2:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_1]] : tensor<i32>
+// CHECK: return %[[VAL_2]] : tensor<i32>
+// CHECK: }
+// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<2x17x31x7xi32>) -> tensor<2x5x8x7xi32> {
+// CHECK: %[[VAL_4:.*]] = stablehlo.constant dense<-2147483648> : tensor<i32>
+// CHECK: %[[VAL_5:.*]] = "stablehlo.reduce_window"(%[[VAL_3]], %[[VAL_4]]) <{base_dilations = array<i64: 1, 1, 1, 1>, padding = dense<{{\[\[}}0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 2, 2, 1>, window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 4, 4, 1>}> ({
+// CHECK: ^bb0(%[[VAL_6:.*]]: tensor<i32>, %[[VAL_7:.*]]: tensor<i32>):
+// CHECK: %[[VAL_8:.*]] = stablehlo.maximum %[[VAL_6]], %[[VAL_7]] : tensor<i32>
+// CHECK: stablehlo.return %[[VAL_8]] : tensor<i32>
+// CHECK: }) : (tensor<2x17x31x7xi32>, tensor<i32>) -> tensor<2x5x8x7xi32>
+// CHECK: return %[[VAL_5]] : tensor<2x5x8x7xi32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(s32[2,17,31,7]{3,2,1,0})->s32[2,5,8,7]{3,2,1,0}}
+
+%region_0.3 (Arg_0.4: s32[], Arg_1.5: s32[]) -> s32[] {
+ %Arg_0.4 = s32[] parameter(0)
+ %Arg_1.5 = s32[] parameter(1)
+ ROOT %maximum.6 = s32[] maximum(%Arg_0.4, %Arg_1.5)
+}
+
+ENTRY %main.8 (Arg_0.1: s32[2,17,31,7]) -> s32[2,5,8,7] {
+ %Arg_0.1 = s32[2,17,31,7] parameter(0)
+ %constant.2 = s32[] constant(-2147483648)
+ ROOT %reduce-window.7 = s32[2,5,8,7] reduce-window(%Arg_0.1, %constant.2), window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1}, to_apply=%region_0.3
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xf32>) -> tensor<1x2xf32> {
+// CHECK: %[[VAL_1:.*]] = stablehlo.reshape %[[VAL_0]] : (tensor<2xf32>) -> tensor<1x2xf32>
+// CHECK: return %[[VAL_1]] : tensor<1x2xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[2]{0})->f32[1,2]{1,0}}
+
+ENTRY %main.3 (Arg_0.1: f32[2]) -> f32[1,2] {
+ %Arg_0.1 = f32[2] parameter(0)
+ ROOT %reshape.2 = f32[1,2] reshape(%Arg_0.1)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
+// CHECK: %[[VAL_1:.*]] = stablehlo.reverse %[[VAL_0]], dims = [1, 2] : tensor<10x11x12x13xf32>
+// CHECK: return %[[VAL_1]] : tensor<10x11x12x13xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[10,11,12,13]{3,2,1,0})->f32[10,11,12,13]{3,2,1,0}}
+
+ENTRY %main.3 (Arg_0.1: f32[10,11,12,13]) -> f32[10,11,12,13] {
+ %Arg_0.1 = f32[10,11,12,13] parameter(0)
+ ROOT %reverse.2 = f32[10,11,12,13] reverse(%Arg_0.1), dimensions={1,2}
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) -> tensor<2x3x5xf32> {
+// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
+// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
+// CHECK: %[[VAL_4:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], distribution = NORMAL : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
+// CHECK: return %[[VAL_4]] : tensor<2x3x5xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[], f32[])->f32[2,3,5]{2,1,0}}
+
+ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[2,3,5] {
+ %constant.3 = s64[3] constant({2, 3, 5})
+ %Arg_0.1 = f32[] parameter(0)
+ %Arg_1.2 = f32[] parameter(1)
+ ROOT %rng.4 = f32[2,3,5] rng(%Arg_0.1, %Arg_1.2), distribution=rng_normal
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main() -> tensor<2x3x5xf32> {
+// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
+// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
+// CHECK: %[[VAL_4:.*]] = stablehlo.rng %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], distribution = UNIFORM : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
+// CHECK: return %[[VAL_4]] : tensor<2x3x5xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={()->f32[2,3,5]{2,1,0}}
+
+ENTRY %main.5 () -> f32[2,3,5] {
+ %constant.3 = s64[3] constant({2, 3, 5})
+ %constant.1 = f32[] constant(0)
+ %constant.2 = f32[] constant(1)
+ ROOT %rng.4 = f32[2,3,5] rng(%constant.1, %constant.2), distribution=rng_uniform
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func private @region_0.4(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[VAL_2:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor<f32>
+// CHECK: return %[[VAL_2]] : tensor<f32>
+// CHECK: }
+// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<200x100x300xf32>, %[[VAL_4:.*]]: tensor<10x2xi32>, %[[VAL_5:.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
+// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) <{indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
+// CHECK: ^bb0(%[[VAL_7:.*]]: tensor<f32>, %[[VAL_8:.*]]: tensor<f32>):
+// CHECK: %[[VAL_9:.*]] = stablehlo.add %[[VAL_7]], %[[VAL_8]] : tensor<f32>
+// CHECK: mhlo.return %[[VAL_9]] : tensor<f32>
+// CHECK: }) : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32>
+// CHECK: return %[[VAL_6]] : tensor<200x100x300xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[200,100,300]{2,1,0}, s32[10,2]{1,0}, f32[10,300]{1,0})->f32[200,100,300]{2,1,0}}
+
+%region_0.4 (Arg_0.5: f32[], Arg_1.6: f32[]) -> f32[] {
+ %Arg_0.5 = f32[] parameter(0)
+ %Arg_1.6 = f32[] parameter(1)
+ ROOT %add.7 = f32[] add(%Arg_0.5, %Arg_1.6)
+}
+
+ENTRY %main.9 (Arg_0.1: f32[200,100,300], Arg_1.2: s32[10,2], Arg_2.3: f32[10,300]) -> f32[200,100,300] {
+ %Arg_0.1 = f32[200,100,300] parameter(0)
+ %Arg_1.2 = s32[10,2] parameter(1)
+ %Arg_2.3 = f32[10,300] parameter(2)
+ ROOT %scatter.8 = f32[200,100,300] scatter(%Arg_0.1, %Arg_1.2, %Arg_2.3), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=%region_0.4
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func private @region_0.4(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[VAL_2:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor<f32>
+// CHECK: return %[[VAL_2]] : tensor<f32>
+// CHECK: }
+// CHECK: func.func @main(%[[VAL_3:.*]]: tensor<200x100x300xf32>, %[[VAL_4:.*]]: tensor<100x200x1xi32>, %[[VAL_5:.*]]: tensor<100x200x300xf32>) -> tensor<200x100x300xf32> {
+// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) <{indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [2], input_batching_dims = [0, 1], scatter_indices_batching_dims = [1, 0], scatter_dims_to_operand_dims = [2], index_vector_dim = 2>, unique_indices = true}> ({
+// CHECK: ^bb0(%[[VAL_7:.*]]: tensor<f32>, %[[VAL_8:.*]]: tensor<f32>):
+// CHECK: %[[VAL_9:.*]] = stablehlo.add %[[VAL_7]], %[[VAL_8]] : tensor<f32>
+// CHECK: mhlo.return %[[VAL_9]] : tensor<f32>
+// CHECK: }) : (tensor<200x100x300xf32>, tensor<100x200x1xi32>, tensor<100x200x300xf32>) -> tensor<200x100x300xf32>
+// CHECK: return %[[VAL_6]] : tensor<200x100x300xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[200,100,300]{2,1,0}, s32[100,200,1]{2,1,0}, f32[100,200,300]{2,1,0})->f32[200,100,300]{2,1,0}}
+
+%region_0.4 (Arg_0.5: f32[], Arg_1.6: f32[]) -> f32[] {
+ %Arg_0.5 = f32[] parameter(0)
+ %Arg_1.6 = f32[] parameter(1)
+ ROOT %add.7 = f32[] add(%Arg_0.5, %Arg_1.6)
+}
+
+ENTRY %main.9 (Arg_0.1: f32[200,100,300], Arg_1.2: s32[100,200,1], Arg_2.3: f32[100,200,300]) -> f32[200,100,300] {
+ %Arg_0.1 = f32[200,100,300] parameter(0)
+ %Arg_1.2 = s32[100,200,1] parameter(1)
+ %Arg_2.3 = f32[100,200,300] parameter(2)
+ ROOT %scatter.8 = f32[200,100,300] scatter(%Arg_0.1, %Arg_1.2, %Arg_2.3), update_window_dims={2}, inserted_window_dims={}, scatter_dims_to_operand_dims={2}, input_batching_dims={0,1}, scatter_indices_batching_dims={1,0}, index_vector_dim=2, indices_are_sorted=true, unique_indices=true, to_apply=%region_0.4
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func private @region_0.4(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>, %[[VAL_2:.*]]: tensor<f32>, %[[VAL_3:.*]]: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>> {
+// CHECK: %[[VAL_4:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor<f32>
+// CHECK: %[[VAL_5:.*]] = stablehlo.add %[[VAL_2]], %[[VAL_3]] : tensor<f32>
+// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[], f32[])"} : tuple<tensor<f32>, tensor<f32>>
+// CHECK: return %[[VAL_6]] : tuple<tensor<f32>, tensor<f32>>
+// CHECK: }
+// CHECK: func.func @main(%[[VAL_7:.*]]: tensor<200x100x300xf32>, %[[VAL_8:.*]]: tensor<10x2xi64>, %[[VAL_9:.*]]: tensor<10x300xf32>) -> tuple<tensor<200x100x300xf32>, tensor<200x100x300xf32>> {
+// CHECK: %[[VAL_10:.*]]:2 = "mhlo.scatter"(%[[VAL_7]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_9]]) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false}> ({
+// CHECK: ^bb0(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_12:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>, %[[VAL_14:.*]]: tensor<f32>):
+// CHECK: %[[VAL_15:.*]] = stablehlo.add %[[VAL_11]], %[[VAL_12]] : tensor<f32>
+// CHECK: %[[VAL_16:.*]] = stablehlo.add %[[VAL_13]], %[[VAL_14]] : tensor<f32>
+// CHECK: mhlo.return %[[VAL_15]], %[[VAL_16]] : tensor<f32>, tensor<f32>
+// CHECK: }) : (tensor<200x100x300xf32>, tensor<200x100x300xf32>, tensor<10x2xi64>, tensor<10x300xf32>, tensor<10x300xf32>) -> (tensor<200x100x300xf32>, tensor<200x100x300xf32>)
+// CHECK: %[[VAL_17:.*]] = mhlo.tuple %[[VAL_18:.*]]#0, %[[VAL_18]]#1 {xla_shape = "(f32[200,100,300]{2,1,0}, f32[200,100,300]{2,1,0})"} : tuple<tensor<200x100x300xf32>, tensor<200x100x300xf32>>
+// CHECK: return %[[VAL_17]] : tuple<tensor<200x100x300xf32>, tensor<200x100x300xf32>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[200,100,300]{2,1,0}, s64[10,2]{1,0}, f32[10,300]{1,0})->(f32[200,100,300]{2,1,0}, f32[200,100,300]{2,1,0})}
+
+%region_0.4 (Arg_0.5: f32[], Arg_1.6: f32[], Arg_2.7: f32[], Arg_3.8: f32[]) -> (f32[], f32[]) {
+ %Arg_0.5 = f32[] parameter(0)
+ %Arg_1.6 = f32[] parameter(1)
+ %add.9 = f32[] add(%Arg_0.5, %Arg_1.6)
+ %Arg_2.7 = f32[] parameter(2)
+ %Arg_3.8 = f32[] parameter(3)
+ %add.10 = f32[] add(%Arg_2.7, %Arg_3.8)
+ ROOT %tuple.11 = (f32[], f32[]) tuple(%add.9, %add.10)
+}
+
+ENTRY %main.16 (Arg_0.1: f32[200,100,300], Arg_1.2: s64[10,2], Arg_2.3: f32[10,300]) -> (f32[200,100,300], f32[200,100,300]) {
+ %Arg_0.1 = f32[200,100,300] parameter(0)
+ %Arg_1.2 = s64[10,2] parameter(1)
+ %Arg_2.3 = f32[10,300] parameter(2)
+ %scatter.12 = (f32[200,100,300], f32[200,100,300]) scatter(%Arg_0.1, %Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_2.3), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%region_0.4
+ %get-tuple-element.13 = f32[200,100,300] get-tuple-element(%scatter.12), index=0
+ %get-tuple-element.14 = f32[200,100,300] get-tuple-element(%scatter.12), index=1
+ ROOT %tuple.15 = (f32[200,100,300], f32[200,100,300]) tuple(%get-tuple-element.13, %get-tuple-element.14)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<i1>, %[[VAL_1:.*]]: tensor<2x3xi32>, %[[VAL_2:.*]]: tensor<2x3xi32>) -> tensor<2x3xi32> {
+// CHECK: %[[VAL_3:.*]] = stablehlo.broadcast_in_dim %[[VAL_0]], dims = [] : (tensor<i1>) -> tensor<2x3xi1>
+// CHECK: %[[VAL_4:.*]] = stablehlo.select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : tensor<2x3xi1>, tensor<2x3xi32>
+// CHECK: return %[[VAL_4]] : tensor<2x3xi32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(pred[], s32[2,3]{1,0}, s32[2,3]{1,0})->s32[2,3]{1,0}}
+
+ENTRY %main.6 (Arg_0.1: pred[], Arg_1.2: s32[2,3], Arg_2.3: s32[2,3]) -> s32[2,3] {
+ %Arg_0.1 = pred[] parameter(0)
+ %broadcast.4 = pred[2,3] broadcast(%Arg_0.1), dimensions={}
+ %Arg_1.2 = s32[2,3] parameter(1)
+ %Arg_2.3 = s32[2,3] parameter(2)
+ ROOT %select.5 = s32[2,3] select(%broadcast.4, %Arg_1.2, %Arg_2.3)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func private @region_0.4(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) -> tensor<i1> {
+// CHECK: %[[VAL_2:.*]] = mhlo.compare GE, %[[VAL_0]], %[[VAL_1]], TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: return %[[VAL_2]] : tensor<i1>
+// CHECK: }
+// CHECK: func.func private @region_1.8(%[[VAL_3:.*]]: tensor<f32>, %[[VAL_4:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[VAL_5:.*]] = stablehlo.add %[[VAL_3]], %[[VAL_4]] : tensor<f32>
+// CHECK: return %[[VAL_5]] : tensor<f32>
+// CHECK: }
+// CHECK: func.func @main(%[[VAL_6:.*]]: tensor<10x24x24x64xf32>, %[[VAL_7:.*]]: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
+// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_9:.*]] = "stablehlo.select_and_scatter"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 2, 2, 1>}> ({
+// CHECK: ^bb0(%[[VAL_10:.*]]: tensor<f32>, %[[VAL_11:.*]]: tensor<f32>):
+// CHECK: %[[VAL_12:.*]] = mhlo.compare GE, %[[VAL_10]], %[[VAL_11]], TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: stablehlo.return %[[VAL_12]] : tensor<i1>
+// CHECK: }, {
+// CHECK: ^bb0(%[[VAL_13:.*]]: tensor<f32>, %[[VAL_14:.*]]: tensor<f32>):
+// CHECK: %[[VAL_15:.*]] = stablehlo.add %[[VAL_13]], %[[VAL_14]] : tensor<f32>
+// CHECK: stablehlo.return %[[VAL_15]] : tensor<f32>
+// CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
+// CHECK: return %[[VAL_9]] : tensor<10x24x24x64xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[10,24,24,64]{3,2,1,0}, f32[10,12,12,64]{3,2,1,0})->f32[10,24,24,64]{3,2,1,0}}
+
+%region_0.4 (Arg_0.5: f32[], Arg_1.6: f32[]) -> pred[] {
+ %Arg_0.5 = f32[] parameter(0)
+ %Arg_1.6 = f32[] parameter(1)
+ ROOT %compare.7 = pred[] compare(%Arg_0.5, %Arg_1.6), direction=GE, type=TOTALORDER
+}
+
+%region_1.8 (Arg_0.9: f32[], Arg_1.10: f32[]) -> f32[] {
+ %Arg_0.9 = f32[] parameter(0)
+ %Arg_1.10 = f32[] parameter(1)
+ ROOT %add.11 = f32[] add(%Arg_0.9, %Arg_1.10)
+}
+
+ENTRY %main.13 (Arg_0.1: f32[10,24,24,64], Arg_1.2: f32[10,12,12,64]) -> f32[10,24,24,64] {
+ %Arg_0.1 = f32[10,24,24,64] parameter(0)
+ %Arg_1.2 = f32[10,12,12,64] parameter(1)
+ %constant.3 = f32[] constant(0)
+ ROOT %select-and-scatter.12 = f32[10,24,24,64] select-and-scatter(%Arg_0.1, %Arg_1.2, %constant.3), window={size=1x2x2x1 stride=1x2x2x1}, select=%region_0.4, scatter=%region_1.8
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token {
+// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle<handle = 5, type = 2>, is_host_transfer = true}> {xla_shape = "token[]"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
+// CHECK: return %[[VAL_2]] : !mhlo.token
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]}
+
+ENTRY %main.5 (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] {
+ %Arg_0.1 = s32[3,4] parameter(0)
+ %Arg_1.2 = token[] parameter(1)
+ %send.3 = (s32[3,4], u32[], token[]) send(%Arg_0.1, %Arg_1.2), channel_id=5, is_host_transfer=true
+ ROOT %send-done.4 = token[] send-done(%send.3), channel_id=5, is_host_transfer=true
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token {
+// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle<handle = 5, type = 1>, is_host_transfer = false}> {xla_shape = "token[]"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
+// CHECK: return %[[VAL_2]] : !mhlo.token
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]}
+
+ENTRY %main.5 (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] {
+ %Arg_0.1 = s32[3,4] parameter(0)
+ %Arg_1.2 = token[] parameter(1)
+ %send.3 = (s32[3,4], u32[], token[]) send(%Arg_0.1, %Arg_1.2), channel_id=5
+ ROOT %send-done.4 = token[] send-done(%send.3), channel_id=5
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> !mhlo.token {
+// CHECK: %[[VAL_1:.*]] = "mhlo.send"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle<handle = 5, type = 1>, is_host_transfer = false}> {xla_shape = "token[]"} : (!mhlo.token) -> !mhlo.token
+// CHECK: return %[[VAL_1]] : !mhlo.token
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(token[])->token[]}
+
+ENTRY %main.5 (Arg_0.1: token[]) -> token[] {
+ %tuple.2 = () tuple()
+ %Arg_0.1 = token[] parameter(0)
+ %send.3 = ((), u32[], token[]) send(%tuple.2, %Arg_0.1), channel_id=5
+ ROOT %send-done.4 = token[] send-done(%send.3), channel_id=5
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor<i32>) -> tensor<4x?xf32, #mhlo.type_extensions<bounds = [?, 4]>> {
+// CHECK: %[[VAL_2:.*]] = stablehlo.set_dimension_size %[[VAL_0]], %[[VAL_1]], dim = 1 : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x?xf32, #mhlo.type_extensions<bounds = [?, 4]>>
+// CHECK: return %[[VAL_2]] : tensor<4x?xf32, #mhlo.type_extensions<bounds = [?, 4]>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[4,4]{1,0}, s32[])->f32[4,<=4]{1,0}}
+
+ENTRY %main.4 (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] {
+ %Arg_0.1 = f32[4,4] parameter(0)
+ %Arg_1.2 = s32[] parameter(1)
+ ROOT %set-dimension-size.3 = f32[4,<=4] set-dimension-size(%Arg_0.1, %Arg_1.2), dimensions={1}
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>) -> tensor<1x2xi32> {
+// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:4:2] : (tensor<3x4xi32>) -> tensor<1x2xi32>
+// CHECK: return %[[VAL_1]] : tensor<1x2xi32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(s32[3,4]{1,0})->s32[1,2]{1,0}}
+
+ENTRY %main.3 (Arg_0.1: s32[3,4]) -> s32[1,2] {
+ %Arg_0.1 = s32[3,4] parameter(0)
+ ROOT %slice.2 = s32[1,2] slice(%Arg_0.1), slice={[1:2:1], [0:4:2]}
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>, %[[VAL_1:.*]]: tensor<i64>, %[[VAL_2:.*]]: tensor<i64>) -> tensor<1x4xi32> {
+// CHECK: %[[VAL_3:.*]] = stablehlo.dynamic_slice %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], sizes = [1, 4] : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
+// CHECK: return %[[VAL_3]] : tensor<1x4xi32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(s32[3,4]{1,0}, s64[], s64[])->s32[1,4]{1,0}}
+
+ENTRY %main.5 (Arg_0.1: s32[3,4], Arg_1.2: s64[], Arg_2.3: s64[]) -> s32[1,4] {
+ %Arg_0.1 = s32[3,4] parameter(0)
+ %Arg_1.2 = s64[] parameter(1)
+ %Arg_2.3 = s64[] parameter(2)
+ ROOT %dynamic-slice.4 = s32[1,4] dynamic-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3), dynamic_slice_sizes={1,4}
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false, mhlo.xla_entry_computation_result_layout = [dense<[2, 3, 0, 1]> : tensor<4xindex>], mhlo.xla_entry_computation_result_tiles = {{\[\[}}]]} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
+// CHECK: %[[VAL_1:.*]] = stablehlo.transpose %[[VAL_0]], dims = [1, 0, 3, 2] : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+// CHECK: return %[[VAL_1]] : tensor<2x1x4x3xi32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(s32[1,2,3,4]{3,2,1,0})->s32[2,1,4,3]{2,3,0,1}}
+
+ENTRY %main.3 (Arg_0.1: s32[1,2,3,4]) -> s32[2,1,4,3] {
+ %Arg_0.1 = s32[1,2,3,4] parameter(0)
+ ROOT %transpose.2 = s32[2,1,4,3] transpose(%Arg_0.1), dimensions={1,0,3,2}
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor<4x3xf32>) -> tensor<4x3xf32> {
+// CHECK: %[[VAL_2:.*]] = "stablehlo.triangular_solve"(%[[VAL_0]], %[[VAL_1]]) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = true}> : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
+// CHECK: return %[[VAL_2]] : tensor<4x3xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[4,4]{1,0}, f32[4,3]{1,0})->f32[4,3]{1,0}}
+
+ENTRY %main.4 (Arg_0.1: f32[4,4], Arg_1.2: f32[4,3]) -> f32[4,3] {
+ %Arg_0.1 = f32[4,4] parameter(0)
+ %Arg_1.2 = f32[4,3] parameter(1)
+ ROOT %triangular-solve.3 = f32[4,3] triangular-solve(%Arg_0.1, %Arg_1.2), left_side=true, lower=true, unit_diagonal=true, transpose_a=NO_TRANSPOSE
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
+// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_0]], %[[VAL_1]] {xla_shape = "(f32[], s32[])"} : tuple<tensor<f32>, tensor<i32>>
+// CHECK: return %[[VAL_2]] : tuple<tensor<f32>, tensor<i32>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[], s32[])->(f32[], s32[])}
+
+ENTRY %main.4 (Arg_0.1: f32[], Arg_1.2: s32[]) -> (f32[], s32[]) {
+ %Arg_0.1 = f32[] parameter(0)
+ %Arg_1.2 = s32[] parameter(1)
+ ROOT %tuple.3 = (f32[], s32[]) tuple(%Arg_0.1, %Arg_1.2)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xf32>, %[[VAL_1:.*]]: tensor<4xi32>) -> tuple<tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>> {
+// CHECK: %[[VAL_2:.*]] = mhlo.exponential_minus_one %[[VAL_0]] : tensor<4xf32>
+// CHECK: %[[VAL_3:.*]] = mhlo.log_plus_one %[[VAL_0]] : tensor<4xf32>
+// CHECK: %[[VAL_4:.*]] = stablehlo.not %[[VAL_1]] : tensor<4xi32>
+// CHECK: %[[VAL_5:.*]] = stablehlo.popcnt %[[VAL_1]] : tensor<4xi32>
+// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[4]{0}, f32[4]{0}, s32[4]{0}, s32[4]{0})"} : tuple<tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>>
+// CHECK: return %[[VAL_6]] : tuple<tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[4]{0}, s32[4]{0})->(f32[4]{0}, f32[4]{0}, s32[4]{0}, s32[4]{0})}
+
+ENTRY %main.8 (Arg_0.1: f32[4], Arg_1.2: s32[4]) -> (f32[4], f32[4], s32[4], s32[4]) {
+ %Arg_0.1 = f32[4] parameter(0)
+ %exponential-minus-one.3 = f32[4] exponential-minus-one(%Arg_0.1)
+ %log-plus-one.4 = f32[4] log-plus-one(%Arg_0.1)
+ %Arg_1.2 = s32[4] parameter(1)
+ %not.5 = s32[4] not(%Arg_1.2)
+ %popcnt.6 = s32[4] popcnt(%Arg_1.2)
+ ROOT %tuple.7 = (f32[4], f32[4], s32[4], s32[4]) tuple(%exponential-minus-one.3, %log-plus-one.4, %not.5, %popcnt.6)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xi1>, %[[VAL_1:.*]]: tensor<4xi1>) -> tensor<4xi1> {
+// CHECK: %[[VAL_2:.*]] = stablehlo.xor %[[VAL_0]], %[[VAL_1]] : tensor<4xi1>
+// CHECK: return %[[VAL_2]] : tensor<4xi1>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(pred[4]{0}, pred[4]{0})->pred[4]{0}}
+
+ENTRY %main.4 (Arg_0.1: pred[4], Arg_1.2: pred[4]) -> pred[4] {
+ %Arg_0.1 = pred[4] parameter(0)
+ %Arg_1.2 = pred[4] parameter(1)
+ ROOT %xor.3 = pred[4] xor(%Arg_0.1, %Arg_1.2)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<16x16xf32>, %[[VAL_1:.*]]: tensor<16x16xi32>) -> tuple<> {
+// CHECK: %[[VAL_2:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<>
+// CHECK: return %[[VAL_2]] : tuple<>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[16,16]{1,0}, s32[16,16]{1,0})->()}
+
+ENTRY %main.4 (Arg_0.1: f32[16,16], Arg_1.2: s32[16,16]) -> () {
+ %Arg_0.1 = f32[16,16] parameter(0)
+ %Arg_1.2 = s32[16,16] parameter(1)
+ ROOT %tuple.3 = () tuple()
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<16x16xf32>) -> tuple<> {
+// CHECK: %[[VAL_1:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<>
+// CHECK: return %[[VAL_1]] : tuple<>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[16,16]{1,0})->()}
+
+ENTRY %main.3 (Arg_0.1: f32[16,16]) -> () {
+ %Arg_0.1 = f32[16,16] parameter(0)
+ ROOT %tuple.2 = () tuple()
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<16x16xf32>) -> (tensor<16x16xf32> {mhlo.sharding = "{devices=[1,2]0,1}"}) {
+// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @Sharding(%[[VAL_0]]) {backend_config = "", mhlo.sharding = "{devices=[1,2]0,1}"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK: return %[[VAL_1]] : tensor<16x16xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[16,16]{1,0})->f32[16,16]{1,0}}
+
+ENTRY %main.3 (Arg_0.1: f32[16,16]) -> f32[16,16] {
+ %Arg_0.1 = f32[16,16] parameter(0)
+ ROOT %custom-call.2 = f32[16,16] custom-call(%Arg_0.1), custom_call_target="Sharding", sharding={devices=[1,2]0,1}
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func private @foo.3(%[[VAL_0:.*]]: tensor<2x3xf32>, %[[VAL_1:.*]]: tensor<5x5xf32>) -> tensor<2x3xf32> {
+// CHECK: return %[[VAL_0]] : tensor<2x3xf32>
+// CHECK: }
+// CHECK: func.func @main(%[[VAL_2:.*]]: tensor<2x3xf32>, %[[VAL_3:.*]]: tensor<5x5xf32>) -> tensor<2x3xf32> {
+// CHECK: %[[VAL_4:.*]] = mhlo.custom_call @foo(%[[VAL_2]], %[[VAL_3]]) {backend_config = "", called_computations = [@foo.3]} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<2x3xf32>
+// CHECK: return %[[VAL_4]] : tensor<2x3xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[2,3]{1,0}, f32[5,5]{1,0})->f32[2,3]{1,0}}
+
+%foo.3 (Arg_0.4: f32[2,3], Arg_1.5: f32[5,5]) -> f32[2,3] {
+ ROOT %Arg_0.4 = f32[2,3] parameter(0)
+ %Arg_1.5 = f32[5,5] parameter(1)
+}
+
+ENTRY %main.7 (Arg_0.1: f32[2,3], Arg_1.2: f32[5,5]) -> f32[2,3] {
+ %Arg_0.1 = f32[2,3] parameter(0)
+ %Arg_1.2 = f32[5,5] parameter(1)
+ ROOT %custom-call.6 = f32[2,3] custom-call(%Arg_0.1, %Arg_1.2), custom_call_target="foo", called_computations={%foo.3}
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xcomplex<f32>>, %[[VAL_1:.*]]: tensor<2xcomplex<f64>>) -> tuple<tensor<2xf32>, tensor<2xf64>> {
+// CHECK: %[[VAL_2:.*]] = stablehlo.abs %[[VAL_0]] : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+// CHECK: %[[VAL_3:.*]] = stablehlo.abs %[[VAL_1]] : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
+// CHECK: %[[VAL_4:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_3]] {xla_shape = "(f32[2]{0}, f64[2]{0})"} : tuple<tensor<2xf32>, tensor<2xf64>>
+// CHECK: return %[[VAL_4]] : tuple<tensor<2xf32>, tensor<2xf64>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(c64[2]{0}, c128[2]{0})->(f32[2]{0}, f64[2]{0})}
+
+ENTRY %main.6 (Arg_0.1: c64[2], Arg_1.2: c128[2]) -> (f32[2], f64[2]) {
+ %Arg_0.1 = c64[2] parameter(0)
+ %abs.3 = f32[2] abs(%Arg_0.1)
+ %Arg_1.2 = c128[2] parameter(1)
+ %abs.4 = f64[2] abs(%Arg_1.2)
+ ROOT %tuple.5 = (f32[2], f64[2]) tuple(%abs.3, %abs.4)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xui8>) -> tensor<4xui8> {
+// CHECK: %[[VAL_1:.*]] = stablehlo.not %[[VAL_0]] : tensor<4xui8>
+// CHECK: return %[[VAL_1]] : tensor<4xui8>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(u8[4]{0})->u8[4]{0}}
+
+ENTRY %main.3 (Arg_0.1: u8[4]) -> u8[4] {
+ %Arg_0.1 = u8[4] parameter(0)
+ ROOT %not.2 = u8[4] not(%Arg_0.1)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xi32>) -> tensor<4xi32> {
+// CHECK: %[[VAL_1:.*]] = stablehlo.not %[[VAL_0]] : tensor<4xi32>
+// CHECK: return %[[VAL_1]] : tensor<4xi32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(s32[4]{0})->s32[4]{0}}
+
+ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[4] {
+ %Arg_0.1 = s32[4] parameter(0)
+ ROOT %not.2 = s32[4] not(%Arg_0.1)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xf32>, %[[VAL_1:.*]]: tensor<i32>) -> tensor<?xf32, #mhlo.type_extensions<bounds = [4]>> {
+// CHECK: %[[VAL_2:.*]] = stablehlo.set_dimension_size %[[VAL_0]], %[[VAL_1]], dim = 0 : (tensor<4xf32>, tensor<i32>) -> tensor<?xf32, #mhlo.type_extensions<bounds = [4]>>
+// CHECK: return %[[VAL_2]] : tensor<?xf32, #mhlo.type_extensions<bounds = [4]>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[4]{0}, s32[])->f32[<=4]{0}}
+
+ENTRY %main.4 (Arg_0.1: f32[4], Arg_1.2: s32[]) -> f32[<=4] {
+ %Arg_0.1 = f32[4] parameter(0)
+ %Arg_1.2 = s32[] parameter(1)
+ ROOT %set-dimension-size.3 = f32[<=4] set-dimension-size(%Arg_0.1, %Arg_1.2), dimensions={0}
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>, %[[VAL_1:.*]]: !mhlo.token) -> tuple<tensor<3x4xf32>, !mhlo.token> {
+// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle<handle = 1, type = 2>, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_rendezvous = "channel_dtoh_0"}, xla_shape = "token[]"} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token
+// CHECK: %[[VAL_3:.*]]:2 = "mhlo.recv"(%[[VAL_2]]) <{channel_handle = #mhlo.channel_handle<handle = 2, type = 3>, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_rendezvous = "channel_htod_0"}} : (!mhlo.token) -> (tensor<3x4xf32>, !mhlo.token)
+// CHECK: %[[VAL_4:.*]] = mhlo.tuple %[[VAL_3]]#0, %[[VAL_3]]#1 {xla_shape = "(f32[3,4]{1,0}, token[])"} : tuple<tensor<3x4xf32>, !mhlo.token>
+// CHECK: return %[[VAL_4]] : tuple<tensor<3x4xf32>, !mhlo.token>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[3,4]{1,0}, token[])->(f32[3,4]{1,0}, token[])}
+
+ENTRY %main.10 (Arg_0.1: f32[3,4], Arg_1.2: token[]) -> (f32[3,4], token[]) {
+ %Arg_0.1 = f32[3,4] parameter(0)
+ %Arg_1.2 = token[] parameter(1)
+ %send.3 = (f32[3,4], u32[], token[]) send(%Arg_0.1, %Arg_1.2), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_rendezvous="channel_dtoh_0"}
+ %send-done.4 = token[] send-done(%send.3), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_rendezvous="channel_dtoh_0"}
+ %recv.5 = (f32[3,4], u32[], token[]) recv(%send-done.4), channel_id=2, is_host_transfer=true, frontend_attributes={_xla_host_transfer_rendezvous="channel_htod_0"}
+ %recv-done.6 = (f32[3,4], token[]) recv-done(%recv.5), channel_id=2, is_host_transfer=true, frontend_attributes={_xla_host_transfer_rendezvous="channel_htod_0"}
+ %get-tuple-element.7 = f32[3,4] get-tuple-element(%recv-done.6), index=0, frontend_attributes={_xla_host_transfer_rendezvous="channel_htod_0"}
+ %get-tuple-element.8 = token[] get-tuple-element(%recv-done.6), index=1, frontend_attributes={_xla_host_transfer_rendezvous="channel_htod_0"}
+ ROOT %tuple.9 = (f32[3,4], token[]) tuple(%get-tuple-element.7, %get-tuple-element.8)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token {
+// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle<handle = 1, type = 2>, is_host_transfer = true}> {xla_shape = "token[]"} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token
+// CHECK: return %[[VAL_2]] : !mhlo.token
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[3,4]{1,0}, token[])->token[]}
+
+ENTRY %main.5 (Arg_0.1: f32[3,4], Arg_1.2: token[]) -> token[] {
+ %Arg_0.1 = f32[3,4] parameter(0)
+ %Arg_1.2 = token[] parameter(1)
+ %send.3 = (f32[3,4], u32[], token[]) send(%Arg_0.1, %Arg_1.2), channel_id=1, is_host_transfer=true
+ ROOT %send-done.4 = token[] send-done(%send.3), channel_id=1, is_host_transfer=true
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token {
+// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle<handle = 1, type = 2>, is_host_transfer = true}> {xla_shape = "token[]"} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token
+// CHECK: return %[[VAL_2]] : !mhlo.token
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[3,4]{1,0}, token[])->token[]}
+
+ENTRY %main.5 (Arg_0.1: f32[3,4], Arg_1.2: token[]) -> token[] {
+ %Arg_0.1 = f32[3,4] parameter(0)
+ %Arg_1.2 = token[] parameter(1)
+ %send.3 = (f32[3,4], u32[], token[]) send(%Arg_0.1, %Arg_1.2), channel_id=1, is_host_transfer=true
+ ROOT %send-done.4 = token[] send-done(%send.3), channel_id=1, is_host_transfer=true
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xui64>) -> tuple<tensor<3xui64>, tensor<2x2xui32>> {
+// CHECK: %[[VAL_1:.*]], %[[VAL_2:.*]] = stablehlo.rng_bit_generator %[[VAL_0]], algorithm = PHILOX : (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>)
+// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_1]], %[[VAL_2]] {xla_shape = "(u64[3]{0}, u32[2,2]{1,0})"} : tuple<tensor<3xui64>, tensor<2x2xui32>>
+// CHECK: return %[[VAL_3]] : tuple<tensor<3xui64>, tensor<2x2xui32>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(u64[3]{0})->(u64[3]{0}, u32[2,2]{1,0})}
+
+ENTRY %main.6 (Arg_0.1: u64[3]) -> (u64[3], u32[2,2]) {
+ %Arg_0.1 = u64[3] parameter(0)
+ %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(%Arg_0.1), algorithm=rng_philox
+ %get-tuple-element.3 = u64[3] get-tuple-element(%rng-bit-generator.2), index=0
+ %get-tuple-element.4 = u32[2,2] get-tuple-element(%rng-bit-generator.2), index=1
+ ROOT %tuple.5 = (u64[3], u32[2,2]) tuple(%get-tuple-element.3, %get-tuple-element.4)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>) -> tensor<3x4xf32> {
+// CHECK: %[[VAL_1:.*]] = mhlo.cbrt %[[VAL_0]] : tensor<3x4xf32>
+// CHECK: return %[[VAL_1]] : tensor<3x4xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[3,4]{1,0})->f32[3,4]{1,0}}
+
+ENTRY %main.3 (Arg_0.1: f32[3,4]) -> f32[3,4] {
+ %Arg_0.1 = f32[3,4] parameter(0)
+ ROOT %cbrt.2 = f32[3,4] cbrt(%Arg_0.1)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>) -> tensor<3x4xf32> {
+// CHECK: %[[VAL_1:.*]] = stablehlo.reduce_precision %[[VAL_0]], format = e8m10 : tensor<3x4xf32>
+// CHECK: return %[[VAL_1]] : tensor<3x4xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[3,4]{1,0})->f32[3,4]{1,0}}
+
+ENTRY %main.3 (Arg_0.1: f32[3,4]) -> f32[3,4] {
+ %Arg_0.1 = f32[3,4] parameter(0)
+ ROOT %reduce-precision.2 = f32[3,4] reduce-precision(%Arg_0.1), exponent_bits=8, mantissa_bits=10
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>) -> tensor<3x4x1xf32> {
+// CHECK: %[[VAL_1:.*]] = mhlo.bitcast %[[VAL_0]] {result_layout = dense<[2, 1, 0]> : tensor<3xindex>, source_layout = dense<[1, 0]> : tensor<2xindex>} : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
+// CHECK: return %[[VAL_1]] : tensor<3x4x1xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[3,4]{1,0})->f32[3,4,1]{2,1,0}}
+
+ENTRY %main.3 (Arg_0.1: f32[3,4]) -> f32[3,4,1] {
+ %Arg_0.1 = f32[3,4] parameter(0)
+ ROOT %bitcast.2 = f32[3,4,1] bitcast(%Arg_0.1)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor<3x4xf32>) -> tuple<tensor<4x4xf32>, tensor<3x4xf32>> {
+// CHECK: %[[VAL_2:.*]]:2 = stablehlo.optimization_barrier %[[VAL_0]], %[[VAL_1]] : tensor<4x4xf32>, tensor<3x4xf32>
+// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_2]]#0, %[[VAL_2]]#1 {xla_shape = "(f32[4,4]{1,0}, f32[3,4]{1,0})"} : tuple<tensor<4x4xf32>, tensor<3x4xf32>>
+// CHECK: return %[[VAL_3]] : tuple<tensor<4x4xf32>, tensor<3x4xf32>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[4,4]{1,0}, f32[3,4]{1,0})->(f32[4,4]{1,0}, f32[3,4]{1,0})}
+
+ENTRY %main.8 (Arg_0.1: f32[4,4], Arg_1.2: f32[3,4]) -> (f32[4,4], f32[3,4]) {
+ %Arg_0.1 = f32[4,4] parameter(0)
+ %Arg_1.2 = f32[3,4] parameter(1)
+ %tuple.3 = (f32[4,4], f32[3,4]) tuple(%Arg_0.1, %Arg_1.2), sharding={{replicated}, {devices=[1,2]<=[2]}}
+ %opt-barrier.4 = (f32[4,4], f32[3,4]) opt-barrier(%tuple.3), sharding={{replicated}, {devices=[1,2]<=[2]}}
+ %get-tuple-element.5 = f32[4,4] get-tuple-element(%opt-barrier.4), index=0, sharding={replicated}
+ %get-tuple-element.6 = f32[3,4] get-tuple-element(%opt-barrier.4), index=1, sharding={devices=[1,2]<=[2]}
+ ROOT %tuple.7 = (f32[4,4], f32[3,4]) tuple(%get-tuple-element.5, %get-tuple-element.6)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main() -> tensor<ui32> {
+// CHECK: %[[VAL_0:.*]] = stablehlo.partition_id : tensor<ui32>
+// CHECK: return %[[VAL_0]] : tensor<ui32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={()->u32[]}
+
+ENTRY %main.2 () -> u32[] {
+ ROOT %partition-id.1 = u32[] partition-id()
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<ui32>) -> tensor<ui32> {
+// CHECK: %[[VAL_1:.*]] = "mhlo.domain"(%[[VAL_0]]) <{entry_metadata = "{maximal device=1}", exit_metadata = "{}", kind = #mhlo<kind sharding>}> : (tensor<ui32>) -> tensor<ui32>
+// CHECK: return %[[VAL_1]] : tensor<ui32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(u32[])->u32[]}
+
+ENTRY %main.3 (Arg_0.1: u32[]) -> u32[] {
+ %Arg_0.1 = u32[] parameter(0)
+ ROOT %domain.2 = u32[] domain(%Arg_0.1), domain={kind="sharding", entry={maximal device=1}, exit={}}
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor<3x4xf32>) -> tensor<3x4xf32> {
+// CHECK: %[[VAL_2:.*]] = "stablehlo.triangular_solve"(%[[VAL_0]], %[[VAL_1]]) <{left_side = false, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: return %[[VAL_2]] : tensor<3x4xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[4,4]{1,0}, f32[3,4]{1,0})->f32[3,4]{1,0}}
+
+ENTRY %main.4 (Arg_0.1: f32[4,4], Arg_1.2: f32[3,4]) -> f32[3,4] {
+ %Arg_0.1 = f32[4,4] parameter(0)
+ %Arg_1.2 = f32[3,4] parameter(1)
+ ROOT %triangular-solve.3 = f32[3,4] triangular-solve(%Arg_0.1, %Arg_1.2), lower=true, transpose_a=NO_TRANSPOSE
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func private @region_0.5(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<i32>, %[[VAL_2:.*]]: tensor<f32>, %[[VAL_3:.*]]: tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
+// CHECK: %[[VAL_4:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_2]] : tensor<f32>
+// CHECK: %[[VAL_5:.*]] = stablehlo.add %[[VAL_1]], %[[VAL_3]] : tensor<i32>
+// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[], s32[])"} : tuple<tensor<f32>, tensor<i32>>
+// CHECK: return %[[VAL_6]] : tuple<tensor<f32>, tensor<i32>>
+// CHECK: }
+// CHECK: func.func @main(%[[VAL_7:.*]]: tensor<4x2xf32>, %[[VAL_8:.*]]: tensor<4x2xi32>, %[[VAL_9:.*]]: tensor<f32>, %[[VAL_10:.*]]: tensor<i32>) -> tuple<tensor<2x2xf32>, tensor<2x2xi32>> {
+// CHECK: %[[VAL_11:.*]]:2 = "stablehlo.reduce_window"(%[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]]) <{base_dilations = array<i64: 1, 1>, padding = dense<{{\[\[}}2, 2], [0, 0]]> : tensor<2x2xi64>, window_dilations = array<i64: 1, 1>, window_dimensions = array<i64: 5, 1>, window_strides = array<i64: 3, 1>}> ({
+// CHECK: ^bb0(%[[VAL_12:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<i32>, %[[VAL_14:.*]]: tensor<f32>, %[[VAL_15:.*]]: tensor<i32>):
+// CHECK: %[[VAL_16:.*]] = stablehlo.add %[[VAL_12]], %[[VAL_14]] : tensor<f32>
+// CHECK: %[[VAL_17:.*]] = stablehlo.add %[[VAL_13]], %[[VAL_15]] : tensor<i32>
+// CHECK: stablehlo.return %[[VAL_16]], %[[VAL_17]] : tensor<f32>, tensor<i32>
+// CHECK: }) : (tensor<4x2xf32>, tensor<4x2xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
+// CHECK: %[[VAL_18:.*]] = mhlo.tuple %[[VAL_19:.*]]#0, %[[VAL_19]]#1 {xla_shape = "(f32[2,2]{1,0}, s32[2,2]{1,0})"} : tuple<tensor<2x2xf32>, tensor<2x2xi32>>
+// CHECK: return %[[VAL_18]] : tuple<tensor<2x2xf32>, tensor<2x2xi32>>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[4,2]{1,0}, s32[4,2]{1,0}, f32[], s32[])->(f32[2,2]{1,0}, s32[2,2]{1,0})}
+
+%region_0.5 (Arg_0.6: f32[], Arg_1.7: s32[], Arg_2.8: f32[], Arg_3.9: s32[]) -> (f32[], s32[]) {
+ %Arg_0.6 = f32[] parameter(0)
+ %Arg_2.8 = f32[] parameter(2)
+ %add.10 = f32[] add(%Arg_0.6, %Arg_2.8)
+ %Arg_1.7 = s32[] parameter(1)
+ %Arg_3.9 = s32[] parameter(3)
+ %add.11 = s32[] add(%Arg_1.7, %Arg_3.9)
+ ROOT %tuple.12 = (f32[], s32[]) tuple(%add.10, %add.11)
+}
+
+ENTRY %main.17 (Arg_0.1: f32[4,2], Arg_1.2: s32[4,2], Arg_2.3: f32[], Arg_3.4: s32[]) -> (f32[2,2], s32[2,2]) {
+ %Arg_0.1 = f32[4,2] parameter(0)
+ %Arg_1.2 = s32[4,2] parameter(1)
+ %Arg_2.3 = f32[] parameter(2)
+ %Arg_3.4 = s32[] parameter(3)
+ %reduce-window.13 = (f32[2,2], s32[2,2]) reduce-window(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4), window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=%region_0.5
+ %get-tuple-element.14 = f32[2,2] get-tuple-element(%reduce-window.13), index=0
+ %get-tuple-element.15 = s32[2,2] get-tuple-element(%reduce-window.13), index=1
+ ROOT %tuple.16 = (f32[2,2], s32[2,2]) tuple(%get-tuple-element.14, %get-tuple-element.15)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: %[[VAL_1:.*]] = stablehlo.round_nearest_even %[[VAL_0]] : tensor<2xf32>
+// CHECK: return %[[VAL_1]] : tensor<2xf32>
+// CHECK: }
+// CHECK: }
+HloModule main, entry_computation_layout={(f32[2]{0})->f32[2]{0}}
+
+ENTRY %main.3 (Arg_0.1: f32[2]) -> f32[2] {
+ %Arg_0.1 = f32[2] parameter(0)
+ ROOT %round-nearest-even.2 = f32[2] round-nearest-even(%Arg_0.1)
+}
+
+// -----
+
+// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
+// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: %[[VAL_1:.*]] = mhlo.tan %[[VAL_0]] : tensor<2xf32>
+// CHECK: return %[[VAL_1]] : tensor<2xf32>
+// CHECK: }
+// CHECK: }
entry_computation_layout={(f32[2]{0})->f32[2]{0}} + +ENTRY %main.3 (Arg_0.1: f32[2]) -> f32[2] { + %Arg_0.1 = f32[2] parameter(0) + ROOT %tan.2 = f32[2] tan(%Arg_0.1) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>) -> tuple<tensor<4x2xf32>, tensor<4x2xi32>> { +// CHECK: %[[VAL_1:.*]], %[[VAL_2:.*]] = mhlo.topk(%[[VAL_0]], k = 2) : tensor<4x4xf32> -> (tensor<4x2xf32>, tensor<4x2xi32>) +// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_1]], %[[VAL_2]] {xla_shape = "(f32[4,2]{1,0}, s32[4,2]{1,0})"} : tuple<tensor<4x2xf32>, tensor<4x2xi32>> +// CHECK: return %[[VAL_3]] : tuple<tensor<4x2xf32>, tensor<4x2xi32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[4,4]{1,0})->(f32[4,2]{1,0}, s32[4,2]{1,0})} + +ENTRY %main.6 (Arg_0.1: f32[4,4]) -> (f32[4,2], s32[4,2]) { + %Arg_0.1 = f32[4,4] parameter(0) + %topk.2 = (f32[4,2], s32[4,2]) topk(%Arg_0.1), k=2, largest=true + %get-tuple-element.3 = f32[4,2] get-tuple-element(%topk.2), index=0 + %get-tuple-element.4 = s32[4,2] get-tuple-element(%topk.2), index=1 + ROOT %tuple.5 = (f32[4,2], s32[4,2]) tuple(%get-tuple-element.3, %get-tuple-element.4) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tuple<tensor<1x1xf32>, tensor<2x3xf32>>, %[[VAL_1:.*]]: tensor<5x5xf32>) -> tuple<> { +// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {backend_config = "", output_operand_aliases = [#mhlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = [1]>], xla_shape = "(f32[2,3]{1,0})"} : (tuple<tensor<1x1xf32>, tensor<2x3xf32>>, tensor<5x5xf32>) -> tuple<tensor<2x3xf32>> +// CHECK: %[[VAL_3:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: return %[[VAL_3]] : tuple<> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={((f32[1,1]{1,0}, f32[2,3]{1,0}), f32[5,5]{1,0})->()} + +ENTRY %main.5 (Arg_0.1: (f32[1,1], f32[2,3]), Arg_1.2: f32[5,5]) -> () { + %Arg_0.1 = (f32[1,1], f32[2,3]) parameter(0) + %Arg_1.2 = f32[5,5] parameter(1) + %custom-call.3 = (f32[2,3]) custom-call(%Arg_0.1, %Arg_1.2), custom_call_target="foo", output_to_operand_aliasing={{0}: (0, {1})} + ROOT %tuple.4 = () tuple() +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tuple<tensor<1x1xf32>, tensor<2x3xf32>>, %[[VAL_1:.*]]: tensor<5x5xf32>) -> tuple<> { +// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {backend_config = "", output_operand_aliases = [#mhlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = [1]>]} : (tuple<tensor<1x1xf32>, tensor<2x3xf32>>, tensor<5x5xf32>) -> tensor<2x3xf32> +// CHECK: %[[VAL_3:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: return %[[VAL_3]] : tuple<> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={((f32[1,1]{1,0}, f32[2,3]{1,0}), f32[5,5]{1,0})->()} + +ENTRY %main.5 (Arg_0.1: (f32[1,1], f32[2,3]), Arg_1.2: f32[5,5]) -> () { + %Arg_0.1 = (f32[1,1], f32[2,3]) parameter(0) + %Arg_1.2 = f32[5,5] parameter(1) + %custom-call.3 = f32[2,3] custom-call(%Arg_0.1, %Arg_1.2), custom_call_target="foo", output_to_operand_aliasing={{}: (0, {1})} + ROOT %tuple.4 = () tuple() +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [],
mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>) -> tensor<3x4xf32> { +// CHECK: %[[VAL_1:.*]] = mhlo.create_token {xla_shape = "token[]"} : !mhlo.token +// CHECK: %[[VAL_2:.*]] = mhlo.add_dependency %[[VAL_0]], %[[VAL_1]] : (tensor<3x4xf32>, !mhlo.token) -> tensor<3x4xf32> +// CHECK: return %[[VAL_2]] : tensor<3x4xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[3,4]{1,0})->f32[3,4]{1,0}} + +ENTRY %main.4 (Arg_0.1: f32[3,4]) -> f32[3,4] { + %Arg_0.1 = f32[3,4] parameter(0) + %after-all.2 = token[] after-all() + ROOT %add-dependency.3 = f32[3,4] add-dependency(%Arg_0.1, %after-all.2) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>) -> tensor<3x4xf32> attributes {execution_thread = "test_thread"} { +// CHECK: %[[VAL_1:.*]] = mhlo.create_token {xla_shape = "token[]"} : !mhlo.token +// CHECK: %[[VAL_2:.*]] = mhlo.add_dependency %[[VAL_0]], %[[VAL_1]] : (tensor<3x4xf32>, !mhlo.token) -> tensor<3x4xf32> +// CHECK: return %[[VAL_2]] : tensor<3x4xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[3,4]{1,0})->f32[3,4]{1,0}} + +ENTRY %main.4 (Arg_0.1: f32[3,4]) -> f32[3,4] { + %Arg_0.1 = f32[3,4] parameter(0) + %after-all.2 = token[] after-all() + ROOT %add-dependency.3 = f32[3,4] add-dependency(%Arg_0.1, %after-all.2) +}, execution_thread="test_thread" + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x2xi32>) -> tensor<2x2xi32> { +// CHECK: %[[VAL_1:.*]] = "mhlo.all_to_all"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle<handle = 1, type = 0>, concat_dimension = 1 : i64, replica_groups = dense<{{\[\[}}1, 2], [0, 3]]> : tensor<2x2xi64>, split_count = 2 : i64, split_dimension = 1 : i64}> : (tensor<2x2xi32>) -> tensor<2x2xi32> +// CHECK: return %[[VAL_1]] : tensor<2x2xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[2,2]{1,0})->s32[2,2]{1,0}} + +ENTRY %main.3 (Arg_0.1: s32[2,2]) -> s32[2,2] { + %Arg_0.1 = s32[2,2] parameter(0) + ROOT %all-to-all.2 = s32[2,2] all-to-all(%Arg_0.1), channel_id=1, replica_groups={{1,2},{0,3}}, dimensions={1} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x4xf32>, %[[VAL_1:.*]]: tensor<128x4xf32>) -> tuple<tensor<128x4xf32>, tensor<128x4xf32>> { +// CHECK: %[[VAL_2:.*]]:2 = "mhlo.all_to_all"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle<handle = 1, type = 0>, replica_groups = dense<{{\[\[}}0, 1]]> : tensor<1x2xi64>}> : (tensor<128x4xf32>, tensor<128x4xf32>) -> (tensor<128x4xf32>, tensor<128x4xf32>) +// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_2]]#0, %[[VAL_2]]#1 {xla_shape = "(f32[128,4]{1,0}, f32[128,4]{1,0})"} : tuple<tensor<128x4xf32>, tensor<128x4xf32>> +// CHECK: return %[[VAL_3]] : tuple<tensor<128x4xf32>, tensor<128x4xf32>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[128,4]{1,0}, f32[128,4]{1,0})->(f32[128,4]{1,0}, f32[128,4]{1,0})} + +ENTRY %main.7 (Arg_0.1: f32[128,4], Arg_1.2: f32[128,4]) -> (f32[128,4], f32[128,4]) { + %Arg_0.1 = f32[128,4] parameter(0) +
%Arg_1.2 = f32[128,4] parameter(1) + %all-to-all.3 = (f32[128,4], f32[128,4]) all-to-all(%Arg_0.1, %Arg_1.2), channel_id=1, replica_groups={{0,1}} + %get-tuple-element.4 = f32[128,4] get-tuple-element(%all-to-all.3), index=0 + %get-tuple-element.5 = f32[128,4] get-tuple-element(%all-to-all.3), index=1 + ROOT %tuple.6 = (f32[128,4], f32[128,4]) tuple(%get-tuple-element.4, %get-tuple-element.5) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>, %[[VAL_1:.*]]: tensor<5x5xf32>) -> tensor<1x2x3xf32> { +// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {api_version = 4 : i32, backend_config = {user_attr0 = 123 : i32, user_attr1 = dense<42> : tensor<i32>}, has_side_effect = true} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> +// CHECK: return %[[VAL_2]] : tensor<1x2x3xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,3]{1,0}, f32[5,5]{1,0})->f32[1,2,3]{2,1,0}} + +ENTRY %main.4 (Arg_0.1: f32[2,3], Arg_1.2: f32[5,5]) -> f32[1,2,3] { + %Arg_0.1 = f32[2,3] parameter(0) + %Arg_1.2 = f32[5,5] parameter(1) + ROOT %custom-call.3 = f32[1,2,3] custom-call(%Arg_0.1, %Arg_1.2), custom_call_target="foo", custom_call_has_side_effect=true, api_version=API_VERSION_TYPED_FFI, backend_config={user_attr0 = 123 : i32, user_attr1 = dense<42> : tensor<i32>} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<f32> {mhlo.parameter_replication = [true]}, %[[VAL_1:.*]]: tuple<tensor<2x4xf32>, tuple<tensor<2x4xf32>>> {mhlo.parameter_replication = [false, true]}) -> tensor<f32> { +// CHECK: return %[[VAL_0]] : tensor<f32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0})))->f32[]} + +ENTRY %main.3 (Arg_0.1: f32[], Arg_1.2: (f32[2,4], (f32[2,4]))) -> f32[] { + ROOT %Arg_0.1 = f32[] parameter(0), parameter_replication={true} + %Arg_1.2 = (f32[2,4], (f32[2,4])) parameter(1), parameter_replication={false,true} +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<?x784xf32>) -> tensor<?x784xf32> { +// CHECK: %[[VAL_1:.*]] = stablehlo.abs %[[VAL_0]] : tensor<?x784xf32> +// CHECK: return %[[VAL_1]] : tensor<?x784xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[?,784]{1,0})->f32[?,784]{1,0}} + +ENTRY %main.3 (Arg_0.1: f32[?,784]) -> f32[?,784] { + %Arg_0.1 = f32[?,784] parameter(0) + ROOT %abs.2 = f32[?,784] abs(%Arg_0.1) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.3(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> { +// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_3:.*]] = mhlo.compare NE, %[[VAL_0]], %[[VAL_2]] : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: %[[VAL_4:.*]] = mhlo.compare NE, %[[VAL_1]], %[[VAL_2]] : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: %[[VAL_5:.*]] = stablehlo.or %[[VAL_3]], %[[VAL_4]] : tensor<i1> +// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<1.000000e+00> :
tensor<f32> +// CHECK: %[[VAL_7:.*]] = stablehlo.select %[[VAL_5]], %[[VAL_6]], %[[VAL_2]] : tensor<i1>, tensor<f32> +// CHECK: return %[[VAL_7]] : tensor<f32> +// CHECK: } +// CHECK: func.func @main(%[[VAL_8:.*]]: tensor<2x2xf32>) -> tuple<tensor<i1>> { +// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_10:.*]] = mhlo.reduce(%[[VAL_8]] init: %[[VAL_9]]) across dimensions = [0, 1] : (tensor<2x2xf32>, tensor<f32>) -> tensor<f32> +// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_12:.*]]: tensor<f32>) { +// CHECK: %[[VAL_13:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_14:.*]] = mhlo.compare NE, %[[VAL_11]], %[[VAL_13]] : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: %[[VAL_15:.*]] = mhlo.compare NE, %[[VAL_12]], %[[VAL_13]] : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: %[[VAL_16:.*]] = stablehlo.or %[[VAL_14]], %[[VAL_15]] : tensor<i1> +// CHECK: %[[VAL_17:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32> +// CHECK: %[[VAL_18:.*]] = stablehlo.select %[[VAL_16]], %[[VAL_17]], %[[VAL_13]] : tensor<i1>, tensor<f32> +// CHECK: mhlo.return %[[VAL_18]] : tensor<f32> +// CHECK: } +// CHECK: %[[VAL_19:.*]] = mhlo.compare NE, %[[VAL_10]], %[[VAL_9]] : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: %[[VAL_20:.*]] = mhlo.tuple %[[VAL_19]] {xla_shape = "(pred[])"} : tuple<tensor<i1>> +// CHECK: return %[[VAL_20]] : tuple<tensor<i1>> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,2]{1,0})->(pred[])} + +%region_0.3 (Arg_0.4: f32[], Arg_1.5: f32[]) -> f32[] { + %Arg_0.4 = f32[] parameter(0) + %constant.7 = f32[] constant(0) + %compare.8 = pred[] compare(%Arg_0.4, %constant.7), direction=NE + %Arg_1.5 = f32[] parameter(1) + %compare.9 = pred[] compare(%Arg_1.5, %constant.7), direction=NE + %or.10 = pred[] or(%compare.8, %compare.9) + %constant.6 = f32[] constant(1) + ROOT %select.11 = f32[] select(%or.10, %constant.6, %constant.7) +} + +ENTRY %main.15 (Arg_0.1: f32[2,2]) -> (pred[]) { + %Arg_0.1 = f32[2,2] parameter(0) + %constant.2 = f32[] constant(0) + %reduce.12 = f32[] reduce(%Arg_0.1, %constant.2), dimensions={0,1}, to_apply=%region_0.3 + %compare.13 = pred[] compare(%reduce.12, %constant.2), direction=NE + ROOT %tuple.14 = (pred[]) tuple(%compare.13) +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.2(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> { +// CHECK: return %[[VAL_0]] : tensor<f32> +// CHECK: } +// CHECK: func.func @main(%[[VAL_2:.*]]: tensor<f32>) -> tensor<f32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.all_reduce"(%[[VAL_2]]) <{replica_groups = dense<{{\[\[}}0], [1]]> : tensor<2x1xi64>}> ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>): +// CHECK: mhlo.return %[[VAL_4]] : tensor<f32> +// CHECK: }) : (tensor<f32>) -> tensor<f32> +// CHECK: return %[[VAL_3]] : tensor<f32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[])->f32[]} + +%region_0.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { + ROOT %Arg_0.3 = f32[] parameter(0) + %Arg_1.4 = f32[] parameter(1) +} + +ENTRY %main.6 (Arg_0.1: f32[]) -> f32[] { + %Arg_0.1 = f32[] parameter(0) + ROOT %all-reduce.5 = f32[] all-reduce(%Arg_0.1), replica_groups={{0},{1}}, to_apply=%region_0.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.2(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) ->
tensor<f32> { +// CHECK: return %[[VAL_0]] : tensor<f32> +// CHECK: } +// CHECK: func.func @main(%[[VAL_2:.*]]: tensor<4x16xf32>) -> tensor<4x4xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.reduce_scatter"(%[[VAL_2]]) <{channel_handle = #mhlo.channel_handle<handle = 1, type = 0>, replica_groups = dense<{{\[\[}}0, 1, 2, 3]]> : tensor<1x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>): +// CHECK: mhlo.return %[[VAL_4]] : tensor<f32> +// CHECK: }) : (tensor<4x16xf32>) -> tensor<4x4xf32> +// CHECK: return %[[VAL_3]] : tensor<4x4xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[4,16]{1,0})->f32[4,4]{1,0}} + +%region_0.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { + ROOT %Arg_0.3 = f32[] parameter(0) + %Arg_1.4 = f32[] parameter(1) +} + +ENTRY %main.6 (Arg_0.1: f32[4,16]) -> f32[4,4] { + %Arg_0.1 = f32[4,16] parameter(0) + ROOT %reduce-scatter.5 = f32[4,4] reduce-scatter(%Arg_0.1), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, dimensions={1}, to_apply=%region_0.2 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.3(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> { +// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_3:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_2]] : tensor<f32> +// CHECK: return %[[VAL_3]] : tensor<f32> +// CHECK: } +// CHECK: func.func @main(%[[VAL_4:.*]]: tensor<2x17x31x7xf32>, %[[VAL_5:.*]]: tensor<f32>) -> tensor<2x16x30x7xf32> { +// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_4]], %[[VAL_5]]) <{base_dilations = array<i64: 1, 1, 1, 1>, padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 1, 1, 1>}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: tensor<f32>, %[[VAL_8:.*]]: tensor<f32>): +// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_7]], %[[VAL_9]] : tensor<f32> +// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32> +// CHECK: }) : (tensor<2x17x31x7xf32>, tensor<f32>) -> tensor<2x16x30x7xf32> +// CHECK: return %[[VAL_6]] : tensor<2x16x30x7xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[2,17,31,7]{3,2,1,0}, f32[])->f32[2,16,30,7]{3,2,1,0}} + +%region_0.3 (Arg_0.4: f32[], Arg_1.5: f32[]) -> f32[] { + %Arg_1.5 = f32[] parameter(1) + %Arg_0.4 = f32[] parameter(0) + %constant.6 = f32[] constant(0) + ROOT %maximum.7 = f32[] maximum(%Arg_0.4, %constant.6) +} + +ENTRY %main.9 (Arg_0.1: f32[2,17,31,7], Arg_1.2: f32[]) -> f32[2,16,30,7] { + %Arg_0.1 = f32[2,17,31,7] parameter(0) + %Arg_1.2 = f32[] parameter(1) + ROOT %reduce-window.8 = f32[2,16,30,7] reduce-window(%Arg_0.1, %Arg_1.2), window={size=1x2x2x1}, to_apply=%region_0.3 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.4(%[[VAL_0:.*]]: tensor<i32>, %[[VAL_1:.*]]: tensor<i32>) -> tensor<i32> { +// CHECK: return %[[VAL_1]] : tensor<i32> +// CHECK: } +// CHECK: func.func @main(%[[VAL_2:.*]]: tensor<3xi32>, %[[VAL_3:.*]]: tensor<1x1xi32>, %[[VAL_4:.*]]: tensor<1xi32>) -> tensor<3xi32> { +// CHECK: %[[VAL_5:.*]] = "mhlo.scatter"(%[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices =
false}> ({ +// CHECK: ^bb0(%[[VAL_6:.*]]: tensor<i32>, %[[VAL_7:.*]]: tensor<i32>): +// CHECK: mhlo.return %[[VAL_7]] : tensor<i32> +// CHECK: }) : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: return %[[VAL_5]] : tensor<3xi32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(s32[3]{0}, s32[1,1]{1,0}, s32[1]{0})->s32[3]{0}} + +%region_0.4 (Arg_0.5: s32[], Arg_1.6: s32[]) -> s32[] { + %Arg_0.5 = s32[] parameter(0) + ROOT %Arg_1.6 = s32[] parameter(1) +} + +ENTRY %main.8 (Arg_0.1: s32[3], Arg_1.2: s32[1,1], Arg_2.3: s32[1]) -> s32[3] { + %Arg_0.1 = s32[3] parameter(0) + %Arg_1.2 = s32[1,1] parameter(1) + %Arg_2.3 = s32[1] parameter(2) + ROOT %scatter.7 = s32[3] scatter(%Arg_0.1, %Arg_1.2, %Arg_2.3), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_0.4 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func private @region_0.4(%[[VAL_0:.*]]: tensor<f32>, %[[VAL_1:.*]]: tensor<f32>) -> tensor<i1> { +// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_3:.*]] = mhlo.compare GE, %[[VAL_0]], %[[VAL_2]], TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: return %[[VAL_3]] : tensor<i1> +// CHECK: } +// CHECK: func.func private @region_1.9(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>) -> tensor<f32> { +// CHECK: return %[[VAL_5]] : tensor<f32> +// CHECK: } +// CHECK: func.func @main(%[[VAL_6:.*]]: tensor<10x24x24x64xf32>, %[[VAL_7:.*]]: tensor<10x23x23x64xf32>, %[[VAL_8:.*]]: tensor<f32>) -> tensor<10x24x24x64xf32> { +// CHECK: %[[VAL_9:.*]] = "stablehlo.select_and_scatter"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 1, 1, 1>}> ({ +// CHECK: ^bb0(%[[VAL_10:.*]]: tensor<f32>, %[[VAL_11:.*]]: tensor<f32>): +// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_13:.*]] = mhlo.compare GE, %[[VAL_10]], %[[VAL_12]], TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: stablehlo.return %[[VAL_13]] : tensor<i1> +// CHECK: }, { +// CHECK: ^bb0(%[[VAL_14:.*]]: tensor<f32>, %[[VAL_15:.*]]: tensor<f32>): +// CHECK: stablehlo.return %[[VAL_15]] : tensor<f32> +// CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32> +// CHECK: return %[[VAL_9]] : tensor<10x24x24x64xf32> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[10,24,24,64]{3,2,1,0}, f32[10,23,23,64]{3,2,1,0}, f32[])->f32[10,24,24,64]{3,2,1,0}} + +%region_0.4 (Arg_0.5: f32[], Arg_1.6: f32[]) -> pred[] { + %Arg_1.6 = f32[] parameter(1) + %Arg_0.5 = f32[] parameter(0) + %constant.7 = f32[] constant(0) + ROOT %compare.8 = pred[] compare(%Arg_0.5, %constant.7), direction=GE, type=TOTALORDER +} + +%region_1.9 (Arg_0.10: f32[], Arg_1.11: f32[]) -> f32[] { + %Arg_0.10 = f32[] parameter(0) + ROOT %Arg_1.11 = f32[] parameter(1) +} + +ENTRY %main.13 (Arg_0.1: f32[10,24,24,64], Arg_1.2: f32[10,23,23,64], Arg_2.3: f32[]) -> f32[10,24,24,64] { + %Arg_0.1 = f32[10,24,24,64] parameter(0) + %Arg_1.2 = f32[10,23,23,64] parameter(1) + %Arg_2.3 = f32[] parameter(2) + ROOT %select-and-scatter.12 = f32[10,24,24,64] select-and-scatter(%Arg_0.1, %Arg_1.2, %Arg_2.3), window={size=1x2x2x1}, select=%region_0.4, scatter=%region_1.9 +} + +// ----- + +// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic
= false, mhlo.use_auto_spmd_partitioning = false} { +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<16x16xf32>, %[[VAL_1:.*]]: tensor<16x16xi32>) -> tuple<> { +// CHECK: %[[VAL_2:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: return %[[VAL_2]] : tuple<> +// CHECK: } +// CHECK: } +HloModule main, entry_computation_layout={(f32[16,16]{1,0}, s32[16,16]{1,0})->()} + +ENTRY %main.4 (Arg_0.1: f32[16,16], Arg_1.2: s32[16,16]) -> () { + %Arg_0.1 = f32[16,16] parameter(0) + %Arg_1.2 = s32[16,16] parameter(1) + ROOT %tuple.3 = () tuple() +} diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc index e203008f7187d7..1e2d46b2da7885 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc @@ -54,7 +54,8 @@ bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) { mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction( llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations, bool flatten_computation_args_result) { + bool import_all_computations, bool flatten_computation_args_result, + bool emit_stablehlo) { mlir::OwningOpRef<mlir::ModuleOp> module = llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context)); @@ -67,7 +68,7 @@ mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction( auto status = ConvertHloToMlirHlo( module.get(), hlo_proto.mutable_hlo_module(), import_all_computations, - flatten_computation_args_result); + flatten_computation_args_result, emit_stablehlo); if (!status.ok()) { module->emitError("Hlo module import failed: ") << status.message(); return nullptr; @@ -78,7 +79,8 @@ mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction( mlir::OwningOpRef<mlir::ModuleOp> HloTextToMlirHloTranslateFunction( llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations, bool flatten_computation_args_result) { + bool import_all_computations, bool flatten_computation_args_result, + bool emit_stablehlo) { mlir::OwningOpRef<mlir::ModuleOp> module = llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context)); @@ -93,7 +95,7 @@ mlir::OwningOpRef<mlir::ModuleOp> HloTextToMlirHloTranslateFunction( auto hlo_module = std::move(hlo_module_error.value()); auto status = ConvertHloToMlirHlo(*module, hlo_module.get(), import_all_computations, - flatten_computation_args_result); + flatten_computation_args_result, emit_stablehlo); if (!status.ok()) { module->emitError("HLO Module import failed: ") << status.message(); return nullptr; diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h index 69b06bf9a60813..4ad30ebb7f9705 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h @@ -40,7 +40,7 @@ namespace xla { mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction( llvm::StringRef input, mlir::MLIRContext* context, bool import_all_computations = false, - bool flatten_computation_args_result = false); + bool flatten_computation_args_result = false, bool emit_stablehlo = false); // Converts a HloModule stored in text form for a file with the given // `input_filename` into a MHLO module.
Creates MLIR entities into the given @@ -54,7 +54,7 @@ mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction( mlir::OwningOpRef<mlir::ModuleOp> HloTextToMlirHloTranslateFunction( llvm::StringRef input, mlir::MLIRContext* context, bool import_all_computations = false, - bool flatten_computation_args_result = false); + bool flatten_computation_args_result = false, bool emit_stablehlo = false); // Converts a HloModuleProto stored in the file with the given `input_filename` // into a StableHLO module. Creates MLIR entities into the given MLIR `context`. diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc index c1ef2a37675c83..10e04967f1b044 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc @@ -30,18 +30,25 @@ llvm::cl::opt<bool> import_all_computations( llvm::cl::opt<bool> flatten_computation_args_result( "hlo-flatten-computation-args-result", llvm::cl::desc("Enable flattening computation arguments and results.")); + +// NOLINTNEXTLINE +llvm::cl::opt<bool> emit_stablehlo( + "emit-stablehlo", + llvm::cl::desc("Allow a mix of MHLO and StableHLO ops in the output.")); } // namespace static mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslate( llvm::StringRef input, mlir::MLIRContext* context) { return xla::HloToMlirHloTranslateFunction( - input, context, import_all_computations, flatten_computation_args_result); + input, context, import_all_computations, flatten_computation_args_result, + emit_stablehlo); } static mlir::OwningOpRef<mlir::ModuleOp> HloTextToMlirHloTranslate( llvm::StringRef input, mlir::MLIRContext* context) { return xla::HloTextToMlirHloTranslateFunction( - input, context, import_all_computations, flatten_computation_args_result); + input, context, import_all_computations, flatten_computation_args_result, + emit_stablehlo); } static mlir::OwningOpRef<mlir::ModuleOp> HloToStablehloTranslate( diff --git a/third_party/xla/xla/hlo/translate/stablehlo.cc b/third_party/xla/xla/hlo/translate/stablehlo.cc index d62fcee25dabd8..f7642181220077 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo.cc @@ -150,7 +150,8 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToStablehlo( llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), /*import_all_computation=*/true, - /*flatten_computation_args_result=*/true) + /*flatten_computation_args_result=*/true, + /*emit_stablehlo=*/true) .Import(*hlo_module)); TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); return std::move(mlir_module); @@ -162,7 +163,8 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToStablehlo( llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), /*import_all_computation=*/true, - /*flatten_computation_args_result=*/true) + /*flatten_computation_args_result=*/true, + /*emit_stablehlo=*/true) .Import(*hlo_module_proto)); TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); return std::move(mlir_module); diff --git a/third_party/xla/xla/hlo/translate/xla_translate_main.cc b/third_party/xla/xla/hlo/translate/xla_translate_main.cc index 86d44d08b46e08..530a89546f8f7b 100644 --- a/third_party/xla/xla/hlo/translate/xla_translate_main.cc +++ b/third_party/xla/xla/hlo/translate/xla_translate_main.cc @@ -101,8 +101,9 @@ int main(int argc, char** argv) { }; if (splitInputFile) { - if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer, -
output->os()))) + if (failed(mlir::splitAndProcessBuffer( + std::move(input), processBuffer, output->os(), + mlir::kDefaultSplitMarker, mlir::kDefaultSplitMarker))) return 1; } else { if (failed(processBuffer(std::move(input), output->os()))) return 1; From 2a53a30bfb777a3284d48ca5771893dfa2180dfd Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Tue, 8 Apr 2025 08:58:05 -0700 Subject: [PATCH 0368/1324] PR #24834: Define macros for is_cuda_configured and is_rocm_configured Imported from GitHub PR https://github.com/openxla/xla/pull/24834 This is a split from #23997 Copybara import of the project: -- d69f6e3aaa12d37606ca1840fe6b0e397d3b7d02 by Shraiysh Vaishay : Define macros for is_cuda_configured and is_rocm_configured This is a split from #23997 Merging this change closes #24834 PiperOrigin-RevId: 745165425 --- third_party/gpus/cuda/build_defs.bzl.tpl | 6 ++++++ third_party/gpus/rocm/build_defs.bzl.tpl | 6 ++++++ third_party/xla/third_party/gpus/cuda/build_defs.bzl.tpl | 6 ++++++ third_party/xla/third_party/gpus/rocm/build_defs.bzl.tpl | 6 ++++++ 4 files changed, 24 insertions(+) diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index 6c1b68ffb77bcf..b5767abf5be9e6 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -94,6 +94,12 @@ def if_cuda_is_configured(x, no_cuda = []): return select({"//conditions:default": x}) return select({"//conditions:default": no_cuda}) +def is_cuda_configured(): + """ + Returns True if CUDA is configured. False otherwise. + """ + return %{cuda_is_configured} + def if_cuda_newer_than(wanted_ver, if_true, if_false = []): """Tests if CUDA was enabled during the configured process and if the configured version is at least `wanted_ver`. `wanted_ver` needs diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl index d327083e4dc8ea..4c3a85c1730c6b 100644 --- a/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/gpus/rocm/build_defs.bzl.tpl @@ -64,6 +64,12 @@ def if_rocm_is_configured(if_true, if_false = []): return select({"//conditions:default": if_true}) return select({"//conditions:default": if_false}) +def is_rocm_configured(): + """ + Returns True if ROCm is configured. False otherwise. + """ + return %{rocm_is_configured} + def rocm_hipblaslt(): return %{rocm_is_configured} and %{rocm_hipblaslt} diff --git a/third_party/xla/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/xla/third_party/gpus/cuda/build_defs.bzl.tpl index 6c1b68ffb77bcf..b5767abf5be9e6 100644 --- a/third_party/xla/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/xla/third_party/gpus/cuda/build_defs.bzl.tpl @@ -94,6 +94,12 @@ def if_cuda_is_configured(x, no_cuda = []): return select({"//conditions:default": x}) return select({"//conditions:default": no_cuda}) +def is_cuda_configured(): + """ + Returns True if CUDA is configured. False otherwise. + """ + return %{cuda_is_configured} + def if_cuda_newer_than(wanted_ver, if_true, if_false = []): """Tests if CUDA was enabled during the configured process and if the configured version is at least `wanted_ver`. 
`wanted_ver` needs diff --git a/third_party/xla/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/xla/third_party/gpus/rocm/build_defs.bzl.tpl index d327083e4dc8ea..4c3a85c1730c6b 100644 --- a/third_party/xla/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/xla/third_party/gpus/rocm/build_defs.bzl.tpl @@ -64,6 +64,12 @@ def if_rocm_is_configured(if_true, if_false = []): return select({"//conditions:default": if_true}) return select({"//conditions:default": if_false}) +def is_rocm_configured(): + """ + Returns True if ROCm is configured. False otherwise. + """ + return %{rocm_is_configured} + def rocm_hipblaslt(): return %{rocm_is_configured} and %{rocm_hipblaslt} From 37c09a07f39c63bc23c05bfbd0582cd729d2d52e Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Tue, 8 Apr 2025 09:09:54 -0700 Subject: [PATCH 0369/1324] PR #17637: [XLA:CPU][oneDNN] Alias result to addend when feasible Imported from GitHub PR https://github.com/openxla/xla/pull/17637 This PR adds a new oneDNN post-op type that will perform in-place addition whenever there is no broadcast overhead. It also includes matmul and convolution tests to verify the functionality. Copybara import of the project: -- dec5edc7a80c20427e950f5fea33e50c6b12be21 by Akhil Goel : Refactor and add new post-op -- 4da2d7430fdf65be8a129951560d8bf31aa37a55 by Akhil Goel : Wrap long lines -- eb8cfafc4e8020dc9c18e22da33b9352c825619b by Akhil Goel : Address review comments. Merging this change closes #17637 PiperOrigin-RevId: 745170284 --- .../xla/xla/service/cpu/onednn_config.proto | 3 +- .../cpu/onednn_contraction_rewriter.cc | 19 ++++++- .../xla/xla/service/cpu/onednn_convolution.cc | 6 ++ .../xla/xla/service/cpu/onednn_matmul.cc | 6 ++ .../xla/xla/service/cpu/onednn_util.cc | 3 + .../cpu/tests/onednn_convolution_test.cc | 19 +++++++ .../service/cpu/tests/onednn_matmul_test.cc | 57 ++++++++++++++++++- 7 files changed, 110 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/cpu/onednn_config.proto b/third_party/xla/xla/service/cpu/onednn_config.proto index 44829a6857f1f9..ab534a8f00689c 100644 --- a/third_party/xla/xla/service/cpu/onednn_config.proto +++ b/third_party/xla/xla/service/cpu/onednn_config.proto @@ -36,7 +36,7 @@ message OneDnnOptimizationConfig { } message OneDnnFusionConfig { - // These enum needs to be mapped to oneDNN enum for post_op algorithm. + // This enum needs to be mapped to oneDNN enum for post_op algorithm. // TODO(intel-tf): Add kinds supported by oneDNN. enum FusionKind { UNDEFINED = 0; @@ -50,6 +50,7 @@ message OneDnnFusionConfig { ELU = 8; RELU6 = 9; SIGMOID = 10; + SUM = 11; // This represents in-place accumulation. } repeated FusionKind ops = 1; // To avoid protobuf failures for specific decimal values, diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc index ed9a9e4b0f7964..0cefc32d0b663b 100644 --- a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc @@ -724,8 +724,11 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { } // Validate addend for fusion. 
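+ // The addend's operand index and user count are recorded up front: the + // SUM (in-place accumulation) rewrite below is only legal when the addend + // has a single user and can be aliased with the custom call's output.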
+ auto addend_user_count = addend->user_count(); + auto addend_idx = -1; if (IsSupportedType(addend->shape().element_type()) && IsOperandFusible(addend, contraction)) { + addend_idx = new_operands.size(); new_operands.push_back(addend); } else { return absl::OkStatus(); } @@ -736,6 +739,10 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { contraction->shape(), new_operands))); auto backend_config = custom_call->backend_config<BackendConfig>(); + bool can_fuse_sum = + (ShapeUtil::Equal(custom_call->shape(), addend->shape()) && + addend_user_count == 1 && + custom_call->output_operand_aliasing().empty()); auto fusions_config = GetFusionsConfig(&backend_config); auto optimization_config = GetOptimizationsConfig(&backend_config); // TODO(intel-tf): Here, we allow 1D addends only when they are the first @@ -745,9 +752,15 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { (ShapeUtil::TrueNumDimensions(addend->shape()) == 1) ? (fusions_config->ops().empty() ? OneDnnFusionConfig::BIAS : OneDnnFusionConfig::UNDEFINED) - : OneDnnFusionConfig::BINARY_ADD; + : can_fuse_sum ? OneDnnFusionConfig::SUM + : OneDnnFusionConfig::BINARY_ADD; if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus(); + // Alias output buffers to addend for in-place accumulation + if (kind == OneDnnFusionConfig::SUM) { + custom_call->set_output_to_operand_aliasing({{{}, {addend_idx, {}}}}); + } + fusions_config->add_ops(kind); if (optional_addend_broadcast) { @@ -1154,6 +1167,10 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { auto scratch_add = AddScratch(custom_call); if (scratch_add.ok()) { custom_call = *scratch_add; + auto aliases = custom_call->output_operand_aliasing(); + if (!aliases.empty()) { + custom_call->set_output_to_operand_aliasing({{{0}, aliases[0].second}}); + } } else { VLOG(2) << scratch_add.status(); } diff --git a/third_party/xla/xla/service/cpu/onednn_convolution.cc b/third_party/xla/xla/service/cpu/onednn_convolution.cc index 2bda9a0003bfe5..f6f546f3d975cf 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution.cc +++ b/third_party/xla/xla/service/cpu/onednn_convolution.cc @@ -266,6 +266,12 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( std::vector<memory::desc> fused_mds; std::vector<void*> fused_bufs; for (int64_t i = 0; i < num_fused_operands; ++i) { + // Skip the MemrefInfo object for the SUM operand, as oneDNN does not + // require an input and performs in-place accumulation. + if (conv_config.fusions().ops(i) == OneDnnFusionConfig::SUM) { + arg_indx++; + continue; + } MemrefInfo operand_minfo(args[arg_indx++]); memory::desc mem_desc = operand_minfo.GetOneDnnMemDesc(); if (mem_desc.get_ndims() == new_res_md.get_ndims()) { diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.cc b/third_party/xla/xla/service/cpu/onednn_matmul.cc index 0e1a6a6d92981a..91d5979473bc62 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul.cc @@ -283,6 +283,12 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( std::vector<memory::desc> fused_mds; std::vector<void*> fused_bufs; for (int64_t i = 0; i < num_fused_operands; ++i) { + // Skip the MemrefInfo object for the SUM operand, as oneDNN does not + // require an input and performs in-place accumulation.
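+ // The contraction rewriter has already aliased the result buffer to this + // addend, so the sum post-op accumulates into the output allocation in + // place.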
+ if (matmul_config.fusions().ops(i) == OneDnnFusionConfig::SUM) { + arg_indx++; + continue; + } MemrefInfo operand_minfo(args[arg_indx++]); fused_mds.push_back(operand_minfo.GetOneDnnMemDesc()); fused_bufs.push_back(operand_minfo.Data()); diff --git a/third_party/xla/xla/service/cpu/onednn_util.cc b/third_party/xla/xla/service/cpu/onednn_util.cc index c1e727cc73be15..8cb22022c04b5c 100644 --- a/third_party/xla/xla/service/cpu/onednn_util.cc +++ b/third_party/xla/xla/service/cpu/onednn_util.cc @@ -67,6 +67,9 @@ dnnl::post_ops PopulateOneDnnPostOps( case OneDnnFusionConfig::SIGMOID: post_ops.append_eltwise(dnnl::algorithm::eltwise_logistic, 0.f, 0.f); break; + case OneDnnFusionConfig::SUM: + post_ops.append_sum(); + break; case OneDnnFusionConfig::BIAS: { *bias_md = fused_mds.at(fused_operand_idx); if (fused_operands_ref) { diff --git a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc index 869ec371b82409..e9a04ff5d4e115 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc @@ -345,6 +345,25 @@ TEST_P(ConvolutionTest, ToeplitzConstrcutionTest) { RunCompareAndMatchOptimizedHlo(outline, {"BINARY_ADD"}); } +TEST_P(ConvolutionTest, Conv2DWithSumTest) { + const absl::string_view outline = R"( + HloModule convolution.test.with.sum + ENTRY convolution.test.with.sum { + arg0.1 = $dtype[1,22,22,1] parameter(0) + arg0.2 = $dtype[1,11,11,1] parameter(1) + constant.3 = $dtype[] constant(1) + broadcast.4 = $dtype[8,8,1,1] broadcast(constant.3), dimensions={} + convolution.0 = $dtype[1,11,11,1] convolution(arg0.1, broadcast.4), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + ROOT add.10 = $dtype[1,11,11,1] add(convolution.0, arg0.2) + })"; + + // Optimized HLO must match "SUM" only for precisions that support Elementwise + // Add operations + RunCompareAndMatchOptimizedHlo(outline, + {(dtype_ == BF16) ? 
"BINARY_ADD" : "SUM"}); +} + TEST_P(ConvolutionTest, Conv2DWithBiasAndTanhTest) { const absl::string_view outline = R"( HloModule convolution.bias.tanh.test diff --git a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc index 8a89a6656dd337..c2053ee8851eba 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc @@ -64,6 +64,17 @@ class MatmulTest : public HloTestBase { ; CHECK-DAG: } ; CHECK: } )"; + const char* fused_matmul_sum_ = R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["SUM"] + ; CHECK-DAG: } + ; CHECK-DAG: } + ; CHECK: } + )"; const char* matmul_rewrite_str_ = R"( ; CHECK: custom_call_target="__onednn$matmul", ; CHECK: backend_config={ @@ -267,7 +278,51 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter1) { })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); - MatchOptimizedHlo(matmul_module_str, fused_matmul_binary_add_); + MatchOptimizedHlo(matmul_module_str, fused_matmul_sum_); +} + +TEST_F(MatmulTest, SimpleTestF32Add2Dots) { + const char* matmul_module_str = R"( + HloModule matmul.biasadd.test.f32 + + ENTRY matmul.biasadd.test.f32 { + arg0.1 = f32[32,32,40,30] parameter(0) + arg0.2 = f32[32,32,30,40] parameter(1) + arg0.3 = f32[32,32,40,40] parameter(2) + arg0.4 = f32[32,32,40,40] parameter(3) + dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, + lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.8 = f32[32,32,40,40] dot(arg0.3, arg0.4), lhs_batch_dims={0,1}, + lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT add.10 = f32[32,32,40,40] add(dot.7, dot.8) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_sum_); +} + +TEST_F(MatmulTest, SimpleTestF16Add2Dots) { + if (!IsSupportedType(PrimitiveType::F16)) { + GTEST_SKIP() << "CPU does not support F16."; + } + + const char* matmul_module_str = R"( + HloModule matmul.biasadd.test.f16 + + ENTRY matmul.biasadd.test.f16 { + arg0.1 = f16[32,64,128] parameter(0) + arg0.2 = f16[32,128,64] parameter(1) + arg0.3 = f16[32,64,64] parameter(2) + arg0.4 = f16[32,64,64] parameter(3) + dot.7 = f16[32,64,64] dot(arg0.1, arg0.2), lhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} + dot.8 = f16[32,64,64] dot(arg0.3, arg0.4), lhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} + ROOT add.10 = f16[32,64,64] add(dot.7, dot.8) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_sum_); } TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2) { From f161b1fc93f8afd9d0126dfccdafbd04fabc46fb Mon Sep 17 00:00:00 2001 From: mmakevic-amd Date: Tue, 8 Apr 2025 09:11:56 -0700 Subject: [PATCH 0370/1324] PR #23976: [ROCm] Fix ragged_all_to_all_kernel_test Imported from GitHub PR https://github.com/openxla/xla/pull/23976 Test was failing with mismatch error: ``` xla/service/gpu/kernels/ragged_all_to_all_kernel_test.cc:153: Failure Expected equality of these values: output_results Which is: { { 2, 3, 4, 5, -0.372549, -0.372549, -0.372549, -0.372549, 8, 9, 10, 11, 12, 13, -0.372549, -0.372549 }, { 
-0.372549, -0.372549, 0, 1, -0.372549, -0.372549, -0.372549, -0.372549, -0.372549, -0.372549, 6, 7, 8, 9, -0.372549, -0.372549 } } expected_output_results Which is: { { 2, 3, 4, 5, 0, 0, 0, 0, 8, 9, 10, 11, 12, 13, 0, 0 }, { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 6, 7, 8, 9, 0, 0 } } ``` This PR initializes output_buffers to 0 before `RunRaggedAllToAllKernel` Copybara import of the project: -- dc0b44ef70a2a9c000996fa5ea47c6e9cbe4382e by Milica Makevic : Initialize output_buffers Merging this change closes #23976 PiperOrigin-RevId: 745171039 --- .../xla/service/gpu/kernels/ragged_all_to_all_kernel_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_test.cc index 9015afab958dd1..c7f9266504e346 100644 --- a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_test.cc @@ -90,6 +90,8 @@ TEST_F(RaggedAllToAllKernelTest, SimpleKernelTest) { for (int64_t i = 0; i < num_outputs; ++i) { output_buffers.emplace_back(executor, executor->AllocateArray<T>(n)); ASSERT_TRUE(!output_buffers[i].memory().is_null()); + TF_ASSERT_OK( + stream->MemZero(output_buffers[i].memory_ptr(), n * sizeof(T))); } stream_executor::DeviceMemoryHandle input_offsets_buffer( From 1a3aac4984bc67416ab22880252a5726cdcc2443 Mon Sep 17 00:00:00 2001 From: jparkerh Date: Tue, 8 Apr 2025 09:25:31 -0700 Subject: [PATCH 0371/1324] move header include to the correct location PiperOrigin-RevId: 745175684 --- third_party/xla/xla/pjrt/cpu/cpu_client.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h index 8b1f708c444e30..77106bc68d9262 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.h +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h @@ -1,4 +1,3 @@ -#include "xla/pjrt/async_work_runner.h" /* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -46,6 +45,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/literal.h" +#include "xla/pjrt/async_work_runner.h" #include "xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h" #include "xla/pjrt/cpu/cpu_device.h" #include "xla/pjrt/cpu/cpu_event.h" @@ -56,7 +56,6 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" -#include "xla/pjrt/plugin/xla_cpu/cpu_device_description.h" #include "xla/pjrt/plugin/xla_cpu/cpu_topology_description.h" #include "xla/pjrt/transpose.h" #include "xla/service/buffer_assignment.h" From 9f47a27bccb3c3bf887ee13da4271630ce1f9b71 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 8 Apr 2025 09:55:29 -0700 Subject: [PATCH 0372/1324] [XLA:GPU] Add `IsMultiHostTopology` utility function. The utility is heuristic based until the device topology is available at compile time.
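A minimal usage sketch of the new helper (illustrative only: the caller, the threshold constants, and the per-host GPU counts cited in the comments are assumptions, not part of this change):

```
#include <cstdint>

#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h"
#include "xla/service/hlo_module_config.h"
#include "xla/stream_executor/device_description.h"

// Hypothetical caller: pick a combiner threshold from the heuristic.
// E.g. 16 partitions x 1 replica counts as multi-host on Hopper (more than
// the 8 devices assumed per host) but single-host on Ampere (16 per host).
int64_t PickCombineThresholdBytes(
    const xla::HloModuleConfig& config,
    const stream_executor::DeviceDescription& device) {
  constexpr int64_t kMultiHostBytes = 64 << 20;   // hypothetical value
  constexpr int64_t kSingleHostBytes = 32 << 20;  // hypothetical value
  return xla::gpu::IsMultiHostTopology(config, device) ? kMultiHostBytes
                                                       : kSingleHostBytes;
}
```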
PiperOrigin-RevId: 745186253 --- .../service/gpu/transforms/collectives/BUILD | 13 ++++ .../collectives/collective_ops_utils.cc | 14 ++++ .../collectives/collective_ops_utils.h | 7 ++ .../collectives/collective_ops_utils_test.cc | 68 +++++++++++++++++++ .../cuda/cuda_compute_capability.h | 2 + 5 files changed, 104 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/BUILD b/third_party/xla/xla/service/gpu/transforms/collectives/BUILD index d6f76342cdeb6d..aeb588f65671da 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/collectives/BUILD @@ -54,6 +54,7 @@ cc_library( hdrs = ["collective_ops_utils.h"], deps = [ "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) +xla_cc_test( + name = "collective_ops_utils_test", + srcs = ["collective_ops_utils_test.cc"], + deps = [ + ":collective_ops_utils", + "//xla/service:hlo_module_config", + "//xla/stream_executor:device_description", + "//xla/stream_executor/cuda:cuda_compute_capability", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "gpu_collective_combiner_utils", srcs = ["gpu_collective_combiner_utils.cc"], diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc index bdec66b97a3c13..140b20462df012 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc @@ -21,6 +21,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/hlo_module_config.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" @@ -69,5 +70,18 @@ absl::StatusOr<GPUCommunicationType> CommunicationType( return GPUCommunicationType::UNDEFINED; } +bool IsMultiHostTopology(const HloModuleConfig& config, + const se::DeviceDescription& device_description) { + // TODO: b/390095346 - Use topology information once available at compile + // time. + if (device_description.cuda_compute_capability().IsHopper()) { + return config.num_partitions() * config.replica_count() > 8; + } + if (device_description.cuda_compute_capability().IsAmpere()) { + return config.num_partitions() * config.replica_count() > 16; + } + return false; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h index 859ba4bb5cb504..88c9ee37a2550a 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h @@ -19,6 +19,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -38,6 +39,12 @@ absl::StatusOr CommunicationType( // Returns true if instruction is a synchronous collective op. bool IsGPUSyncCollective(const HloInstruction& instr); +// Returns true if the topology is multi-host. Currently this function is +// heuristic based. Will return false on any platform other than Hopper and +// Ampere. +bool IsMultiHostTopology(const HloModuleConfig& config, + const se::DeviceDescription& device_description); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc new file mode 100644 index 00000000000000..ebcd83b7d07325 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" + +#include +#include "xla/service/hlo_module_config.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { +namespace { + +bool IsMultiHostTopology(se::CudaComputeCapability compute_capability, + int num_partitions, int replica_count) { + HloModuleConfig config; + config.set_num_partitions(num_partitions); + config.set_replica_count(replica_count); + se::DeviceDescription device_description; + device_description.set_gpu_compute_capability(compute_capability); + return xla::gpu::IsMultiHostTopology(config, device_description); +} + +TEST(IsMultiHostTopologyTest, SingleHostSingleDevice) { + EXPECT_FALSE(IsMultiHostTopology(se::CudaComputeCapability::Ampere(), + /*num_partitions=*/1, /*replica_count=*/1)); + EXPECT_FALSE(IsMultiHostTopology(se::CudaComputeCapability::Hopper(), + /*num_partitions=*/1, /*replica_count=*/1)); +} + +TEST(IsMultiHostTopologyTest, SingleHostMultiDevices) { + EXPECT_FALSE(IsMultiHostTopology(se::CudaComputeCapability::Ampere(), + /*num_partitions=*/16, /*replica_count=*/1)); + EXPECT_FALSE(IsMultiHostTopology(se::CudaComputeCapability::Ampere(), + /*num_partitions=*/1, /*replica_count=*/16)); + EXPECT_FALSE(IsMultiHostTopology(se::CudaComputeCapability::Hopper(), + /*num_partitions=*/8, /*replica_count=*/1)); + EXPECT_FALSE(IsMultiHostTopology(se::CudaComputeCapability::Hopper(), + /*num_partitions=*/1, /*replica_count=*/8)); +} + +TEST(IsMultiHostTopologyTest, MultiHosts) { + EXPECT_TRUE(IsMultiHostTopology(se::CudaComputeCapability::Ampere(), + /*num_partitions=*/32, /*replica_count=*/1)); + EXPECT_TRUE(IsMultiHostTopology(se::CudaComputeCapability::Ampere(), + /*num_partitions=*/1, 
/*replica_count=*/32)); + EXPECT_TRUE(IsMultiHostTopology(se::CudaComputeCapability::Hopper(), + /*num_partitions=*/16, /*replica_count=*/1)); + EXPECT_TRUE(IsMultiHostTopology(se::CudaComputeCapability::Hopper(), + /*num_partitions=*/1, /*replica_count=*/16)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability.h b/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability.h index aa4865f12ef100..109ee3ce10519d 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_compute_capability.h @@ -98,6 +98,8 @@ struct CudaComputeCapability { return major >= CudaComputeCapabilities::kBlackwell; } + bool IsAmpere() const { return major == CudaComputeCapabilities::kAmpere; } + bool IsHopper() const { return major == CudaComputeCapabilities::kHopper; } bool IsBlackwell() const { From 45b5ed4bb9ea452efac074ff2b07be8fa03e65ae Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 8 Apr 2025 09:58:32 -0700 Subject: [PATCH 0373/1324] [XLA:GPU] Extract collective combiner passes into a helper function. PiperOrigin-RevId: 745187417 --- .../xla/xla/service/gpu/gpu_compiler.cc | 73 +++++++++---------- 1 file changed, 34 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index af4f084c2c9805..e822d46aaae140 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1158,43 +1158,44 @@ void AddDoubleBufferingPasses(const HloModule& module, } } -absl::Status RunPostFusionPasses( - HloModule* hlo_module, const se::DeviceDescription& device_description, - int pointer_size, const int combine_threshold_count) { - const DebugOptions& opts = hlo_module->config().debug_options(); +constexpr int kCombineThresholdCount = 256; - HloPassPipeline pipeline("post-fusion optimization"); - pipeline.AddPass(); - if (hlo_module->config() - .debug_options() - .xla_gpu_experimental_enable_sync_collective_combining()) { +void AddCollectiveCombinerPasses( + HloPassPipeline& pipeline, const HloModule& module, + const se::DeviceDescription& device_description, int pointer_size) { + const DebugOptions& opts = module.config().debug_options(); + + if (opts.xla_gpu_experimental_enable_sync_collective_combining()) { pipeline.AddPass(device_description, pointer_size); } + pipeline.AddPass( - device_description, - /*default_combine_threshold_in_bytes=*/kDefaultAllGatherCombineThreshold, - /*combine_threshold_in_bytes=*/ - opts.xla_gpu_all_gather_combine_threshold_bytes(), - combine_threshold_count, - /*combine_by_dim=*/opts.xla_gpu_enable_all_gather_combine_by_dim(), - /*combine_different_dtypes=*/true, /*pointer_size=*/pointer_size); + device_description, kDefaultAllGatherCombineThreshold, + opts.xla_gpu_all_gather_combine_threshold_bytes(), kCombineThresholdCount, + opts.xla_gpu_enable_all_gather_combine_by_dim(), + /*combine_different_dtypes=*/true, pointer_size); pipeline.AddPass( device_description, kDefaultAllReduceCombineThreshold, - opts.xla_gpu_all_reduce_combine_threshold_bytes(), - combine_threshold_count, /*pointer_size=*/pointer_size); + opts.xla_gpu_all_reduce_combine_threshold_bytes(), kCombineThresholdCount, + pointer_size); pipeline.AddPass( - device_description, /*default_combine_threshold_in_bytes=*/ - kDefaultReduceScatterCombineThreshold, - /*combine_threshold_in_bytes=*/ + device_description, 
kDefaultReduceScatterCombineThreshold, opts.xla_gpu_reduce_scatter_combine_threshold_bytes(), - combine_threshold_count, - /*combine_by_dim=*/opts.xla_gpu_enable_reduce_scatter_combine_by_dim(), - /*pointer_size=*/pointer_size); + kCombineThresholdCount, + opts.xla_gpu_enable_reduce_scatter_combine_by_dim(), pointer_size); pipeline.AddPass( - /*combine_threshold_in_bytes=*/ opts.xla_gpu_collective_permute_combine_threshold_bytes(), - combine_threshold_count); + kCombineThresholdCount); +} + +absl::Status RunPostFusionPasses( + HloModule* hlo_module, const se::DeviceDescription& device_description, + int pointer_size) { + HloPassPipeline pipeline("post-fusion optimization"); + pipeline.AddPass(); + AddCollectiveCombinerPasses(pipeline, *hlo_module, device_description, + pointer_size); pipeline.AddPass(); @@ -1316,21 +1317,17 @@ absl::Status RunAsyncDotPasses(HloModule* hlo_module) { absl::Status RunDynamicSliceFusionPasses( HloModule* hlo_module, se::Platform::Id platform_id, - const se::DeviceDescription& device_description, int64_t pointer_size, - const int combine_threshold_count) { + const se::DeviceDescription& device_description, int64_t pointer_size) { const DebugOptions& opts = hlo_module->config().debug_options(); if (opts.xla_gpu_enable_dynamic_slice_fusion()) { HloPassPipeline pipeline("dynamic-slice"); TF_ASSIGN_OR_RETURN(se::Platform * platform, se::PlatformManager::PlatformWithId(platform_id)); pipeline.AddPass( - device_description, /*default_combine_threshold_in_bytes=*/ - kDefaultReduceScatterCombineThreshold, - /*combine_threshold_in_bytes=*/ + device_description, kDefaultReduceScatterCombineThreshold, opts.xla_gpu_reduce_scatter_combine_threshold_bytes(), - /*combine_threshold_count=*/combine_threshold_count, - /*combine_by_dim=*/opts.xla_gpu_enable_reduce_scatter_combine_by_dim(), - /*pointer_size=*/pointer_size); + kCombineThresholdCount, + opts.xla_gpu_enable_reduce_scatter_combine_by_dim(), pointer_size); pipeline.AddPass(platform->Name()); pipeline.AddPass([](const HloInstruction* instr) { if (!IsDynamicSliceFusion(instr)) { @@ -1416,19 +1413,17 @@ absl::Status GpuCompiler::OptimizeHloModule( hlo_module, stream_exec, options, gpu_target_config, thread_pool.get_mutable())); - const int combine_threshold_count = 256; - // This is a "low effort, high impact" fusion that should be run first. TF_RETURN_IF_ERROR(RunDynamicSliceFusionPasses( hlo_module, /*platform_id=*/PlatformId(), /*device_description=*/gpu_target_config.device_description, - /*pointer_size=*/pointer_size_, combine_threshold_count)); + /*pointer_size=*/pointer_size_)); TF_RETURN_IF_ERROR(RunFusionPasses(hlo_module, gpu_target_config, thread_pool.get_mutable(), ShapeSizeBytesFunction())); - TF_RETURN_IF_ERROR(RunPostFusionPasses( - hlo_module, device_description, pointer_size_, combine_threshold_count)); + TF_RETURN_IF_ERROR( + RunPostFusionPasses(hlo_module, device_description, pointer_size_)); TF_RETURN_IF_ERROR(RunAsyncCollectivesConversionPasses(hlo_module)); TF_RETURN_IF_ERROR(RunPostFusionSimplificationPasses( hlo_module, layout_insensitive_algsimp_opts, gpu_version, From df2fba15cfad039d3981ee3911c2151863ed14c5 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 8 Apr 2025 10:12:40 -0700 Subject: [PATCH 0374/1324] [XLA:GPU] Clean up `TritonGemmTest.BroadcastsOfTriviallySizedNonContractingDimensionsAreSupported`. 
PiperOrigin-RevId: 745193603 --- .../triton/fusion_emitter_device_legacy_port_test.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index 3e0dd108960baf..bbc4a7e9651e96 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -1211,7 +1211,7 @@ ENTRY e { TEST_F(TritonGemmTest, BroadcastsOfTriviallySizedNonContractingDimensionsAreSupported) { - EXPECT_TRUE(RunAndCompare(R"( + constexpr absl::string_view kHloText = R"( f { p0 = f32[64,6464] parameter(0) p1 = f32[16,6464] parameter(1) @@ -1232,8 +1232,13 @@ e { kind=kCustom, calls=f, backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config": {"block_m":"16","block_n":"16","block_k":"64","split_k":"1", "num_stages":"1","num_warps":"4","num_ctas":"1"}}} -})", - ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +})"; + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloText)); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(std::move(module_and_metadata.module), + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTest, From e5c782bcdc2eee3e79997d07e58f5f02e2e67c34 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 8 Apr 2025 10:26:10 -0700 Subject: [PATCH 0375/1324] Use standard TestBufferDonationClashes for PjRtStreamExecutorLoadedExecutable as well. PiperOrigin-RevId: 745199390 --- .../xla/pjrt/pjrt_stream_executor_client.cc | 29 +++---------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index e31ed5dc48cc68..6a8a97eb674121 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -2523,8 +2523,8 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution( absl::Span donated_params = ParametersThatMustBeDonated(executable_idx); auto donate_it = donated_params.begin(); - absl::flat_hash_set used_buffers; - absl::flat_hash_set donated_buffers; + absl::flat_hash_map> donation_clashes; + donation_clashes.reserve(argument_handles.size()); for (int i = 0; i < argument_handles.size(); ++i) { auto* handle = tensorflow::down_cast(argument_handles[i]); @@ -2541,29 +2541,8 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution( if (must_donate) { ++donate_it; } - bool already_used = !used_buffers.emplace(handle).second; - bool already_donated = - must_donate ? !donated_buffers.emplace(handle).second - : donated_buffers.find(handle) != donated_buffers.end(); - if (must_donate && already_donated) { - return InvalidArgument( - "Attempt to donate the same buffer twice in Execute() (second use: " - "flattened argument %d, replica %d). " - "Toy example for this bug: `f(donate(a), donate(a))`.", - i, replica); - } else if (must_donate && already_used) { - return InvalidArgument( - "Attempt to donate a buffer which is also used by the same call to " - "Execute() (second use: flattened argument %d, replica %d). 
" - "Toy example for this bug: `f(a, donate(a))`.", - i, replica); - } else if (already_donated) { - return InvalidArgument( - "Attempt to use a buffer that was previously donated in the same " - "call to Execute() (second use: flattened argument %d, replica %d). " - "Toy example for this bug: `f(donate(a), a)`.", - i, replica); - } + TF_RETURN_IF_ERROR(TestBufferDonationClashes( + handle, donation_clashes, must_donate, i, replica, partition)); device_buffers->emplace_back(handle->GetBufferWithHold( must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation : PjRtStreamExecutorBuffer::ScopedHold::kUsage)); From 9893e57107fffed9e2e34bb1eab26d4c05f90a98 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 8 Apr 2025 10:36:17 -0700 Subject: [PATCH 0376/1324] Update AbstractTfrtCpuBuffer to use the new unified ScopedHold logic. PiperOrigin-RevId: 745203548 --- third_party/xla/xla/pjrt/cpu/BUILD | 2 + .../xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc | 123 +++++------------- .../xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h | 107 +++++---------- third_party/xla/xla/pjrt/cpu/cpu_client.cc | 18 ++- .../xla/pjrt/cpu/tracked_cpu_device_buffer.h | 5 +- 5 files changed, 79 insertions(+), 176 deletions(-) diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index dd26df16f693a6..299bb4f33f2d1b 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -34,6 +34,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/backends/cpu:alignment", + "//xla/pjrt:abstract_tracked_device_buffer", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base", @@ -77,6 +78,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/pjrt:abstract_tracked_device_buffer", "//xla/pjrt:async_work_runner", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_future", diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index 45978171c6e855..157fc5def33c5e 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -40,6 +40,7 @@ limitations under the License. 
#include "xla/cpu_function_runtime.h" #include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/pjrt/abstract_tracked_device_buffer.h" #include "xla/pjrt/async_work_runner.h" #include "xla/pjrt/cpu/cpu_event.h" #include "xla/pjrt/cpu/tracked_cpu_device_buffer.h" @@ -122,8 +123,8 @@ ShapedBuffer AsShapedBuffer(int device_ordinal, const Shape& on_device_shape, AbstractTfrtCpuBuffer::AbstractTfrtCpuBuffer( Shape on_device_shape, std::unique_ptr tracked_device_buffer) - : on_device_shape_(std::move(on_device_shape)), - tracked_device_buffer_(std::move(tracked_device_buffer)) {} + : CommonPjRtBuffer(std::move(tracked_device_buffer)), + on_device_shape_(std::move(on_device_shape)) {} AbstractTfrtCpuBuffer::~AbstractTfrtCpuBuffer() { AbstractTfrtCpuBuffer::Delete(); @@ -167,9 +168,10 @@ absl::StatusOr> AbstractTfrtCpuBuffer::AcquireExternalReference() { class ScopedExternalReference : public PjRtBuffer::ExternalReference { public: - explicit ScopedExternalReference(AbstractTfrtCpuBuffer* buffer, - tsl::AsyncValueRef data) - : buffer_(buffer), data_(std::move(data)) { + explicit ScopedExternalReference(AbstractTfrtCpuBuffer::ScopedHold hold) + : external_reference_(std::move(hold)), + data_(external_reference_->buffer()) { + DCHECK(external_reference_.type() == ScopedHold::kExternalReference); DCHECK(data_); // We need to wait for the memory to be allocated before sharing it with // external frameworks like NumPy. @@ -178,30 +180,18 @@ AbstractTfrtCpuBuffer::AcquireExternalReference() { data_ptr_ = data_->untyped_data(); } - ~ScopedExternalReference() override { buffer_->DropExternalReference(); } + ~ScopedExternalReference() override = default; private: - AbstractTfrtCpuBuffer* buffer_ = nullptr; + AbstractTfrtCpuBuffer::ScopedHold external_reference_; // Keep a reference to the underlying data used. Note that it is still // users' responsibility to synchronize reads and writes to the data. tsl::AsyncValueRef data_; }; - absl::MutexLock lock(&mu_); - if (tracked_device_buffer_ == nullptr) { - return InvalidArgument("Buffer has been deleted or donated."); - } - - ++external_reference_counter_; - - return {std::make_unique( - this, tracked_device_buffer_->buffer())}; -} - -void AbstractTfrtCpuBuffer::DropExternalReference() { - absl::MutexLock lock(&mu_); - CHECK_GT(external_reference_counter_, 0); - --external_reference_counter_; + ScopedHold hold = GetBufferWithHold(ScopedHold::kExternalReference); + TF_RETURN_IF_ERROR(hold.status()); + return {std::make_unique(std::move(hold))}; } class TrackedCpuDeviceBufferExternalReference @@ -209,10 +199,10 @@ class TrackedCpuDeviceBufferExternalReference public: explicit TrackedCpuDeviceBufferExternalReference( std::unique_ptr tracked_device_buffer) - : tracked_device_buffer_(std::move(tracked_device_buffer)) { + : device_buffer_(std::move(tracked_device_buffer)) { // We need to wait for the memory to be allocated before sharing it with // external frameworks like NumPy. 
- const auto& buffer = tracked_device_buffer_->buffer(); + const auto& buffer = device_buffer_->buffer(); tsl::BlockUntilReady(buffer); CHECK(buffer.IsConcrete()); data_ptr_ = buffer->untyped_data(); @@ -221,7 +211,7 @@ class TrackedCpuDeviceBufferExternalReference ~TrackedCpuDeviceBufferExternalReference() override = default; private: - std::unique_ptr tracked_device_buffer_; + std::unique_ptr device_buffer_; }; absl::StatusOr> @@ -243,29 +233,10 @@ AbstractTfrtCpuBuffer::ReleaseDeviceMemoryOwnership( return ref; } -void AbstractTfrtCpuBuffer::CommitDonation() { - absl::MutexLock lock(&mu_); - CHECK(pending_donation_); - CHECK(!tracked_device_buffer_); - pending_donation_ = false; -} - -void AbstractTfrtCpuBuffer::AbortDonation( - std::unique_ptr device_buffer) { - absl::MutexLock lock(&mu_); - CHECK(pending_donation_); - CHECK(!tracked_device_buffer_); - pending_donation_ = false; - tracked_device_buffer_ = std::move(device_buffer); -} - void AbstractTfrtCpuBuffer::Delete() { - std::unique_ptr device_buffer; - { - absl::MutexLock lock(&mu_); - device_buffer = ReleaseBufferLocked(); - if (device_buffer == nullptr) return; - } + std::unique_ptr device_buffer( + static_cast(ReleaseBuffer().release())); + if (device_buffer == nullptr) return; // Now that all holds have completed and no more can be added, we can get // the final set of usage events. @@ -286,27 +257,10 @@ void AbstractTfrtCpuBuffer::Delete() { }); } -bool AbstractTfrtCpuBuffer::IsDeleted() { - absl::MutexLock lock(&mu_); - return tracked_device_buffer_ == nullptr; -} - -std::unique_ptr -AbstractTfrtCpuBuffer::ReleaseBufferLocked() { - auto condition = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) { - return !pending_donation_; - }; - mu_.Await(absl::Condition(&condition)); - return std::move(tracked_device_buffer_); -} - absl::StatusOr> AbstractTfrtCpuBuffer::Release(bool wait_for_operations_to_complete) { - std::unique_ptr device_buffer; - { - absl::MutexLock lock(&mu_); - device_buffer = ReleaseBufferLocked(); - } + std::unique_ptr device_buffer( + static_cast(ReleaseBuffer().release())); if (device_buffer == nullptr) return {nullptr}; absl::InlinedVector, 4> events; @@ -334,33 +288,26 @@ AbstractTfrtCpuBuffer::Release(bool wait_for_operations_to_complete) { TrackedCpuDeviceBuffer* AbstractTfrtCpuBuffer::AcquireUsage( tsl::AsyncValueRef usage_event) { absl::MutexLock lock(&mu_); - if (!tracked_device_buffer_) { + if (!device_buffer()) { return nullptr; } - tracked_device_buffer_->AddUsageEvents(absl::MakeSpan(&usage_event, 1)); - return tracked_device_buffer_.get(); + device_buffer()->AddUsageEvents(absl::MakeSpan(&usage_event, 1)); + return device_buffer(); } -absl::StatusOr -AbstractTfrtCpuBuffer::AcquireDonation() { +AbstractTfrtCpuBuffer::ScopedHold AbstractTfrtCpuBuffer::GetBufferWithHold( + ScopedHold::Type type) { absl::MutexLock lock(&mu_); + // Ensure that at most one donation hold can be in progress at a time. + WaitForOutstandingDonationHold(); + ScopedHold hold(this, type); + AcquireHoldLocked(&hold); + return hold; +} - if (tracked_device_buffer_ == nullptr) { - return InvalidArgument("Donation requested for invalid buffer"); - } - - if (external_reference_counter_ > 0) { - return InvalidArgument( - "Donation requested for buffer with external reference"); - } - - CHECK(!pending_donation_); - pending_donation_ = true; - - // Swap out `tracked_device_buffer_` so that no one can acquire a usage event - // after this point. 
- return DonationTransaction(this, std::move(tracked_device_buffer_)); +AbstractTfrtCpuBuffer::ScopedHold AbstractTfrtCpuBuffer::AcquireDonation() { + return GetBufferWithHold(ScopedHold::kDonation); } PjRtFuture<> AbstractTfrtCpuBuffer::DoAsyncWorkOnBuffer( @@ -553,11 +500,11 @@ PjRtFuture<> AbstractTfrtCpuBuffer::GetReadyFuture() { tsl::AsyncValueRef definition_event; { absl::MutexLock lock(&mu_); - if (!tracked_device_buffer_) { + if (!device_buffer()) { return PjRtFuture<>(InvalidArgument( "GetReadyFuture() called on deleted or donated buffer")); } - definition_event = tracked_device_buffer_->definition_event(); + definition_event = device_buffer()->definition_event(); } DCHECK(definition_event); diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index a107d70f5cc0c3..0d0a18ee2b8606 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -36,6 +36,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/literal.h" +#include "xla/pjrt/abstract_tracked_device_buffer.h" #include "xla/pjrt/async_work_runner.h" #include "xla/pjrt/cpu/cpu_event.h" #include "xla/pjrt/cpu/tracked_cpu_device_buffer.h" @@ -76,8 +77,25 @@ class MarkEventReadyOnExit { tsl::AsyncValueRef event_; }; -class AbstractTfrtCpuBuffer : public PjRtBuffer { +class AbstractTfrtCpuBuffer : public CommonPjRtBuffer { public: + class ScopedHold : public CommonPjRtBuffer::ScopedHold { + public: + TrackedCpuDeviceBuffer* buffer() const { + return static_cast( + CommonPjRtBuffer::ScopedHold::buffer()); + } + TrackedCpuDeviceBuffer* operator->() const { return buffer(); } + const TrackedCpuDeviceBuffer& operator*() const { return *buffer(); } + AbstractTfrtCpuBuffer* parent() const { + return static_cast( + CommonPjRtBuffer::ScopedHold::parent()); + } + + private: + using CommonPjRtBuffer::ScopedHold::ScopedHold; + friend class AbstractTfrtCpuBuffer; + }; AbstractTfrtCpuBuffer( Shape on_device_shape, std::unique_ptr tracked_device_buffer); @@ -107,8 +125,6 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { void Delete() override; - bool IsDeleted() override; - void CopyToRemoteDevice(PjRtFuture serialized_descriptor, RemoteSendCallback on_done) override { on_done(Unimplemented("CopyToRemoteDevice not implemented."), @@ -127,56 +143,12 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { TrackedCpuDeviceBuffer* AcquireUsage( tsl::AsyncValueRef usage_event); - // A helper class for managing a pending donation. It should be committed upon - // success. Otherwise, the donated buffer is returned to the - // AbstractTfrtCpuBuffer. - class DonationTransaction { - public: - explicit DonationTransaction( - AbstractTfrtCpuBuffer* buffer, - std::unique_ptr device_buffer) - : buffer_(buffer), device_buffer_(std::move(device_buffer)) { - CHECK(buffer_); - } - DonationTransaction(const DonationTransaction&) = delete; - DonationTransaction& operator=(const DonationTransaction&) = delete; - DonationTransaction(DonationTransaction&&) = default; - DonationTransaction& operator=(DonationTransaction&& other) noexcept { - Abort(); - - buffer_ = other.buffer_; - device_buffer_ = std::move(other.device_buffer_); - return *this; - } - - ~DonationTransaction() { Abort(); } - - // Commit the donation. The rvalue ref qualifier is used to ensure the - // semantic that it can be committed at most once. 
- void Commit() && { - buffer_->CommitDonation(); - device_buffer_.reset(); - } - - TrackedCpuDeviceBuffer* device_buffer() const { - return device_buffer_.get(); - } - - private: - void Abort() { - if (device_buffer_) buffer_->AbortDonation(std::move(device_buffer_)); - } - - AbstractTfrtCpuBuffer* buffer_ = nullptr; - std::unique_ptr device_buffer_; - }; - // Acquires the device buffer for exclusive donation. The caller of this // method is expected to use the usage events and definition events to // serialize this donation with previous usages. After this method is called, // calls to AcquireUsage() will fail. Returns error status if the buffer is // already donated or there is outstanding external references. - absl::StatusOr AcquireDonation(); + ScopedHold AcquireDonation(); // A helper function for PjRtClient::BufferFromHostLiteral. Copy the literal // to the current buffer asynchronously. `avs` is used to signal when the copy @@ -215,9 +187,19 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { const Shape& shape, AsyncWorkRunner* async_work_runner, absl::Mutex* transpose_mu, TransposePlanCache* transpose_cache); + // Returns a hold on the TrackedTfrtTpuDeviceBuffer holding the device + // buffers. See comment on ScopedHold. + ScopedHold GetBufferWithHold(ScopedHold::Type type); + protected: virtual absl::string_view buffer_name() const = 0; + TrackedCpuDeviceBuffer* device_buffer() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return static_cast( + CommonPjRtBuffer::device_buffer()); + } + PjRtFuture<> ToLiteralHelper(MutableLiteralBase* literal, AsyncWorkRunner* async_work_runner); @@ -243,17 +225,6 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { on_device_shape_.tuple_shapes_size() == 0; } - void DropExternalReference(); - - // Commits the pending donation by setting `pending_donation_` to false. - // `pending_donation_` must be true before calling this method. - void CommitDonation(); - - // Aborts the pending donation by returning the donated buffer, and setting - // `pending_donation_` to false. `pending_donation_` must be true before - // calling this method. - void AbortDonation(std::unique_ptr device_buffer); - // Similar to Delete, drops the buffer's reference to its associated device // memory, leaving the buffer in an invalid state, but returns the // TrackedCpuDeviceBuffer rather than freeing the device memory, so that @@ -271,25 +242,7 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { absl::StatusOr> Release( bool wait_for_operations_to_complete); - // Releases the device buffer by returning a unique_ptr of it. If there is - // outstanding donation or usage holds, this method blocks until those holds - // are committed or dropped. - std::unique_ptr ReleaseBufferLocked() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - const Shape on_device_shape_; - - mutable absl::Mutex mu_; - std::unique_ptr tracked_device_buffer_ - ABSL_GUARDED_BY(mu_); - // Count of external references on the buffer. - int external_reference_counter_ ABSL_GUARDED_BY(mu_) = 0; - - // `pending_donation_` indicates whether a donation is pending. The destructor - // of the AbstractTfrtCpuBuffer will wait for a pending donation, as the - // donation might fail. Note that concurrent calls to AcquireUsage() and - // AcquireDonation() might fail even if the pending donation is aborted later. 
- bool pending_donation_ ABSL_GUARDED_BY(mu_) = false; }; class AbstractAsyncHostToHostMemoryTransferManager diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index cf6a24d9195bd9..bde9dab1e928a5 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -1488,8 +1488,7 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( auto execute_event = tsl::MakeConstructedAsyncValueRef(); MarkEventReadyOnExit ready_on_exit(execute_event); - absl::InlinedVector - donation_transactions; + absl::InlinedVector donation_transactions; absl::InlinedVector, 4> tracked_buffers; @@ -1526,8 +1525,8 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( tfrt_buffer, donation_clashes, must_donate, i, replica, partition)); if (must_donate) { ++donate_it; - absl::StatusOr - donation_transaction = tfrt_buffer->AcquireDonation(); + TfrtCpuBuffer::ScopedHold donation_transaction = + tfrt_buffer->AcquireDonation(); // On CPU, we allow donation to succeed by introducing a copy. This was // added when enabling buffer donation on CPU since it turned out that a // number of users were holding external references to buffers that were @@ -1537,15 +1536,14 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( // After acquiring the buffer for donation, we retrieve the dependent // usage events. Note that we don't need any locking here as // AcquireDonation() is supposed to synchronize with other usages. - for (const auto& ev : - donation_transaction->device_buffer()->UsageEvents()) { + for (const auto& ev : donation_transaction->UsageEvents()) { if (!ev.IsAvailable()) { input_deps.push_back(ev.CopyRCRef()); } } - tracked_buffer = donation_transaction->device_buffer(); + tracked_buffer = donation_transaction.buffer(); tracked_buffers.emplace_back(/*can_donate=*/true, tracked_buffer); - donation_transactions.push_back(std::move(*donation_transaction)); + donation_transactions.push_back(std::move(donation_transaction)); return absl::OkStatus(); } } @@ -1760,7 +1758,7 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( } for (auto& donation_transaction : donation_transactions) { - std::move(donation_transaction).Commit(); + std::move(donation_transaction).ConfirmDonation(); } // Forward errors (if any) after executing compute function or thunks. @@ -1909,7 +1907,7 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( } for (auto& donation_transaction : donation_transactions) { - std::move(donation_transaction).Commit(); + std::move(donation_transaction).ConfirmDonation(); } if (!status.ok()) { diff --git a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h index 7bbad71a8172dd..0eaf5cf68a4e41 100644 --- a/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/tracked_cpu_device_buffer.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/pjrt/abstract_tracked_device_buffer.h" #include "xla/pjrt/cpu/cpu_event.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -86,7 +87,7 @@ class CpuDeviceMemoryOwned : public CpuDeviceMemory { // or multiple memory regions for a tuple buffers. It also tracks the definition // and usage of the memory to allow for synchronized usage and deletion of CPU // memory. This class is thread-compatible. 
-class TrackedCpuDeviceBuffer { +class TrackedCpuDeviceBuffer : public AbstractTrackedDeviceBuffer { public: // For non-tuple, takes a single buffer. // For tuple, takes the leaf buffers. Tuple index table created internally. @@ -151,6 +152,8 @@ class TrackedCpuDeviceBuffer { // buffer is passed to a computation that aliases its inputs to outputs. void ReleaseDeviceMemory(); + void ConfirmDonation() override { ReleaseDeviceMemory(); } + bool owns_buffers_; // If non-tuple, `buffers_` contains 1 buffer; otherwise all leaf buffers. From 97da08e3228f9094d978a62e9c100e08a41049ac Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 10:51:21 -0700 Subject: [PATCH 0377/1324] Fixes undefined behaviors in `bonsai_pjrt_c_api`. It's undefined behavior to `reinterpret_cast` an `A*` to a `B*` where `A` and `B` are unrelated and neither is a `char`/`unsigned char`/`std::byte` type. In particular, we cannot `reinterpret_cast` a pointer to a PJRT extension struct `Foo` to a `PJRT_Extension_Base*`, as `Foo` is not related to `PJRT_Extension_Base` as far as the compiler is concerned. Usually, the fix is to make `Foo` derive from `PJRT_Extension_Base`. However, the code needs to work as C, and C doesn't support inheritance. Therefore we make `Foo` contain a `PJRT_Extension_Base` variable as its first field instead. PiperOrigin-RevId: 745209597 --- .../xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 8 +++++--- .../xla/xla/pjrt/c/pjrt_c_api_helpers.cc | 8 +++++--- .../xla/pjrt/c/pjrt_c_api_layouts_extension.h | 4 +--- ...pjrt_c_api_memory_descriptions_extension.h | 4 +--- .../pjrt/c/pjrt_c_api_profiler_extension.h | 4 +--- .../pjrt/c/pjrt_c_api_raw_buffer_extension.h | 4 +--- .../pjrt/c/pjrt_c_api_raw_buffer_internal.cc | 8 +++++--- .../xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 19 +++++++++++-------- 8 files changed, 30 insertions(+), 29 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index b00a40ce78bf51..6980d736a57104 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -319,9 +319,11 @@ PLUGIN_Profiler_Api profiler_api{ }; PJRT_Profiler_Extension profiler_extension{ - /*struct_size=*/PJRT_Profiler_Extension_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Profiler, - /*next=*/nullptr, + PJRT_Extension_Base{ + /*struct_size=*/PJRT_Profiler_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Profiler, + /*next=*/nullptr, + }, /*profiler_api=*/&profiler_api, }; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc index 121a50f5699d92..1b64f384984041 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -1120,9 +1120,11 @@ PJRT_Profiler_Extension CreatePjrtProfilerExtension( traceme_name, tsl::profiler::ContextType::kPjrtLibraryCall); int64_t traceme_context_id = producer.GetContextId(); PJRT_Profiler_Extension profiler_extension{ - /*struct_size=*/PJRT_Profiler_Extension_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Profiler, - /*next=*/nullptr, + PJRT_Extension_Base{ + /*struct_size=*/PJRT_Profiler_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Profiler, + /*next=*/nullptr, + }, /*profiler_api=*/nullptr, /*traceme_context_id=*/traceme_context_id, }; diff --git 
a/third_party/xla/xla/pjrt/c/pjrt_c_api_layouts_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_layouts_extension.h index 4ab5a1f6dbe829..403101800198d7 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_layouts_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_layouts_extension.h @@ -111,9 +111,7 @@ typedef PJRT_Error* PJRT_Layouts_PJRT_Client_GetDefaultLayout( // --------------------------- Extension entrypoint ---------------------------- typedef struct PJRT_Layouts_Extension { - size_t struct_size; - PJRT_Extension_Type type; - PJRT_Extension_Base* next; + PJRT_Extension_Base base; PJRT_Layouts_MemoryLayout_Destroy* PJRT_Layouts_MemoryLayout_Destroy; PJRT_Layouts_MemoryLayout_Serialize* PJRT_Layouts_MemoryLayout_Serialize; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h index 9f5fb6fc7246d0..06e24b6dab176b 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h @@ -69,9 +69,7 @@ typedef PJRT_Error* PJRT_MemoryDescription_Kind( PJRT_MemoryDescription_Kind_Args* args); typedef struct PJRT_MemoryDescriptions_Extension { - size_t struct_size; - PJRT_Extension_Type type; - PJRT_Extension_Base* next; + PJRT_Extension_Base base; PJRT_DeviceDescription_MemoryDescriptions* PJRT_DeviceDescription_MemoryDescriptions; PJRT_MemoryDescription_Kind* PJRT_MemoryDescription_Kind; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h index 35222ed40d5eb2..26d1d3a387f4e0 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h @@ -29,9 +29,7 @@ extern "C" { #define PJRT_API_PROFILER_EXTENSION_VERSION 1 typedef struct PJRT_Profiler_Extension { - size_t struct_size; - PJRT_Extension_Type type; - PJRT_Extension_Base* next; + PJRT_Extension_Base base; // can be nullptr if PJRT_Profiler_Extension is used as an args extension PLUGIN_Profiler_Api* profiler_api; // valid only when used as an args extension diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_extension.h index 50e943f712ca5c..69d9c1ee80801b 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_extension.h @@ -110,9 +110,7 @@ typedef PJRT_Error* PJRT_RawBuffer_CopyRawHostToDevice( #define _PJRT_API_STRUCT_FIELD(fn_type) fn_type* fn_type typedef struct PJRT_RawBuffer_Extension { - size_t struct_size; - PJRT_Extension_Type type; - PJRT_Extension_Base* next; + PJRT_Extension_Base base; _PJRT_API_STRUCT_FIELD(PJRT_RawBuffer_CreateRawAliasOfBuffer); _PJRT_API_STRUCT_FIELD(PJRT_RawBuffer_Destroy); _PJRT_API_STRUCT_FIELD(PJRT_RawBuffer_GetOnDeviceSizeInBytes); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_internal.cc index f0a0deccb8fa36..a7be1a18c74cdc 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_internal.cc @@ -98,9 +98,11 @@ PJRT_Error* PJRT_RawBuffer_CopyRawDeviceToHost( PJRT_RawBuffer_Extension CreateRawBufferExtension(PJRT_Extension_Base* next) { return { - /*struct_size=*/PJRT_RawBuffer_Extension_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_RawBuffer, 
- /*next=*/next, + PJRT_Extension_Base{ + /*struct_size=*/PJRT_RawBuffer_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_RawBuffer, + /*next=*/next, + }, /*PJRT_RawBuffer_CreateRawAliasOfBuffer=*/ pjrt::PJRT_RawBuffer_CreateRawAliasOfBuffer, /*PJRT_RawBuffer_Destroy=*/pjrt::PJRT_RawBuffer_Destroy, diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 04c52242b4dbda..d47228e2617ce8 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -2831,9 +2831,11 @@ PJRT_Api CreatePjrtApi(PJRT_Client_Create* create_fn, PJRT_Layouts_Extension CreateLayoutsExtension(PJRT_Extension_Base* next) { return PJRT_Layouts_Extension{ - /*struct_size=*/PJRT_Layouts_Extension_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type_Layouts, - /*next=*/next, + PJRT_Extension_Base{ + /*struct_size=*/PJRT_Layouts_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type_Layouts, + /*next=*/next, + }, /*PJRT_Layouts_MemoryLayout_Destroy=*/ pjrt::PJRT_Layouts_MemoryLayout_Destroy, /*PJRT_Layouts_MemoryLayout_Serialize=*/ @@ -2848,14 +2850,15 @@ PJRT_Layouts_Extension CreateLayoutsExtension(PJRT_Extension_Base* next) { PJRT_MemoryDescriptions_Extension CreateMemoryDescriptionsExtension( PJRT_Extension_Base* next) { return PJRT_MemoryDescriptions_Extension{ - /*struct_size=*/PJRT_MemoryDescriptions_Extension_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type_MemoryDescriptions, - /*next=*/next, + PJRT_Extension_Base{ + /*struct_size=*/PJRT_MemoryDescriptions_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type_MemoryDescriptions, + /*next=*/next, + }, /*PJRT_DeviceDescription_MemorySpaces=*/ pjrt::PJRT_DeviceDescription_MemoryDescriptions, /*PJRT_MemoryDescription_Kind=*/ pjrt::PJRT_MemoryDescription_Kind}; } } // namespace pjrt From c31ed8cc3809a72c1ae7b9dd00860534c981b451 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 11:04:45 -0700 Subject: [PATCH 0378/1324] Make set_parent() private in HLO instruction in order to help enforce the constraint that an instruction I has a parent P IFF I is in P. Merge Evaluate and EvaluateWithSubstitutions in HloEvaluator. HloEvaluator already had a mechanism for caching values. We can just substitute into that cached result. This avoids calling instruction->set_parent() on a cloned computation in the hlo evaluator (violating the constraint described above).
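A minimal sketch of the merged call shape (the `Evaluate` signature is taken from the header change below; the wrapper function, its name, and its setup are illustrative assumptions, not part of this patch):

```c++
#include "absl/status/statusor.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/literal.h"

// Instead of the removed EvaluateWithSubstitutions(root, {{param, &value}}),
// the substitution map is now the last argument of Evaluate(); the evaluator
// seeds its result cache with the substituted literals rather than cloning
// instructions and calling set_parent() on the clones.
absl::StatusOr<xla::Literal> EvaluateWithParamValue(
    xla::HloEvaluator& evaluator, const xla::HloInstruction* root,
    const xla::HloInstruction* param, const xla::Literal& value) {
  return evaluator.Evaluate(
      root, /*precomputed_analyses=*/{},
      /*recursively_evaluate_nonconstant_operands=*/true,
      /*substitutions=*/{{param, &value}});
}
```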
Modifies existing usage of EvaluateWithSubstitutions PiperOrigin-RevId: 745215007 --- .../xla/backends/gpu/runtime/copy_thunk.cc | 9 +-- .../xla/hlo/analysis/while_loop_analysis.cc | 8 +- third_party/xla/xla/hlo/evaluator/BUILD | 1 - .../xla/xla/hlo/evaluator/hlo_evaluator.cc | 80 ++++--------------- .../xla/xla/hlo/evaluator/hlo_evaluator.h | 20 +++-- .../xla/hlo/evaluator/hlo_evaluator_test.cc | 34 ++++---- third_party/xla/xla/hlo/ir/hlo_instruction.h | 8 +- .../xla/xla/service/hlo_verifier_test.cc | 19 ----- 8 files changed, 57 insertions(+), 122 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/copy_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/copy_thunk.cc index ae7cc514eacf25..eb9a73f8884f7d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/copy_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/copy_thunk.cc @@ -227,11 +227,10 @@ absl::StatusOr EvaluateDynamicOffsets( Literal induction_variable_literal(offset.induction_variable->shape()); TF_RETURN_IF_ERROR( induction_variable_literal.SetIntegralAsS64({}, induction_variable)); - TF_ASSIGN_OR_RETURN( - Literal array_index_literal, - evaluator.EvaluateWithSubstitutions( - offset.offset, - {{offset.induction_variable, &induction_variable_literal}}, true)); + TF_ASSIGN_OR_RETURN(Literal array_index_literal, + evaluator.Evaluate(offset.offset, {}, true, + {{offset.induction_variable, + &induction_variable_literal}})); std::optional array_index = LiteralUtil::LiteralAsScalarInt64(array_index_literal); diff --git a/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc b/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc index 02c672ba3dd1df..5756e73ff3e9d0 100644 --- a/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc @@ -879,8 +879,8 @@ optional ComputeWhileLoopTripCount(const HloInstruction* while_op, for (int64_t trip_count = 0; trip_count != max_brute_force_iters + 1; ++trip_count) { - absl::StatusOr result = evaluator.EvaluateWithSubstitutions( - while_cond_root, {{while_cond_indvar, &indvar_iter_val}}); + absl::StatusOr result = evaluator.Evaluate( + while_cond_root, {}, false, {{while_cond_indvar, &indvar_iter_val}}); if (!result.ok()) { VLOG(2) << "Couldn't evaluate while cond: " << result.status(); return nullopt; @@ -893,8 +893,8 @@ optional ComputeWhileLoopTripCount(const HloInstruction* while_op, // Calculate the value of the induction variable after one iteration of the // loop, and check whether the while condition is true with this new value. 
absl::StatusOr indvar_next_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}}); + evaluator.Evaluate(while_body_indvar_update, {}, false, + {{while_body_indvar, &indvar_iter_val}}); if (!indvar_next_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable update: " << indvar_next_result.status(); diff --git a/third_party/xla/xla/hlo/evaluator/BUILD b/third_party/xla/xla/hlo/evaluator/BUILD index d328996528f74b..76dd06e7066e07 100644 --- a/third_party/xla/xla/hlo/evaluator/BUILD +++ b/third_party/xla/xla/hlo/evaluator/BUILD @@ -56,7 +56,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/analysis:tuple_points_to_analysis", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", "//xla/service:call_graph", "//xla/service:compilation_environments", "//xla/service:dynamic_dimension_inference", diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index 530312a87bae1e..eaf8792d3405ad 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -59,10 +59,10 @@ limitations under the License. #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_query.h" #include "xla/index_util.h" #include "xla/layout.h" #include "xla/layout_util.h" @@ -965,9 +965,17 @@ absl::StatusOr HloEvaluator::Evaluate( absl::StatusOr HloEvaluator::Evaluate( const HloInstruction* instruction, PrecomputedAnalyses precomputed_analyses, - bool recursively_evaluate_nonconstant_operands) { + bool recursively_evaluate_nonconstant_operands, + const absl::flat_hash_map& + substitutions) { ScopedEvaluateState evaluate_state(&state_); + // Use the substitutions to manually set instructions results to a specific + // value. + for (const auto& [substituted_instr, literal_value] : substitutions) { + SetEvaluatedLiteralFor(substituted_instr, literal_value->Clone()); + } + call_graph_cache_.reset(); tuple_points_to_analysis_cache_.reset(); auto enable_partial_evaluation_cleanup = @@ -998,63 +1006,6 @@ bool HloEvaluator::TryEvaluate(const HloInstruction* instruction, return true; } -absl::StatusOr HloEvaluator::EvaluateWithSubstitutions( - const HloInstruction* instruction, - const absl::flat_hash_map& - substitutions, - bool recursively_evaluate_nonconstant_operands) { - auto value = substitutions.find(instruction); - if (value != substitutions.end()) { - return value->second->Clone(); - } - - std::vector> owned_operands; - for (const HloInstruction* operand : instruction->operands()) { - auto it = substitutions.find(operand); - if (it == substitutions.end()) { - if (recursively_evaluate_nonconstant_operands) { - TF_ASSIGN_OR_RETURN(Literal value, - EvaluateWithSubstitutions( - operand, substitutions, - recursively_evaluate_nonconstant_operands)); - owned_operands.push_back(HloInstruction::CreateConstant(value.Clone())); - } else { - if (!operand->IsConstant()) { - VLOG(2) << "EvaluateWithSubstitutions called when not all operands " - "are constant. 
Consider calling it with " - "`recursively_evaluate_non_constant_operands` true."; - } - owned_operands.push_back(operand->Clone()); - } - } else { - owned_operands.push_back( - HloInstruction::CreateConstant(it->second->Clone())); - } - } - - std::vector operands; - operands.reserve(owned_operands.size()); - for (auto& operand : owned_operands) { - operands.push_back(operand.get()); - } - - std::unique_ptr cloned_instruction = - instruction->CloneWithNewOperands(instruction->shape(), operands); - // TODO(phawkins): it's unfortunate that we need to call set_parent() here, - // since it violates the invariant that an instruction has a parent iff it is - // in a computation. - // It's probably better to avoid constructing new instructions here in the - // first place. - cloned_instruction->set_parent( - const_cast(instruction->parent())); - auto result = Evaluate(cloned_instruction.get()); - - // Undo the parent change, since it will confuse code that expects the - // instruction to be in a computation. - cloned_instruction->set_parent(nullptr); - - return result; -} absl::StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { @@ -1241,10 +1192,13 @@ absl::Status HloEvaluator::EvaluateInternal( } if (!recursively_evaluate_nonconstant_operands) { - if (!hlo_query::AllOperandsAreConstants(*instruction)) { - return absl::FailedPreconditionError( - absl::StrCat("Not all operands are constants. Instruction: ", - instruction->ToString())); + for (const HloInstruction* operand : instruction->operands()) { + if (!IsAlreadyEvaluated(operand, shape_index)) { + return absl::FailedPreconditionError( + absl::StrCat("Not all operands are constants or have known " + "results. Instruction: ", + instruction->ToString())); + } } } else { if (instruction->opcode() == HloOpcode::kGetTupleElement) { diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h index bd1890be10086d..e665869ab5479f 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h @@ -148,26 +148,24 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // The caller may pass in non-null `precomputed_analyses` to avoid // recomputation during evaluation; the caller must ensure that any // precomputed analyses were performed on the module containing `instruction`. + // The optional `substitutions` map can be used to substitute the given + // literals for any instruction in the evaluation graph, usually some of the + // instruction's operands. + // + // For example, given instruction = op(A, B, C) and the map + // {A = x, C = y}, this evaluates op(x, B, y). absl::StatusOr Evaluate( const HloInstruction* instruction, PrecomputedAnalyses precomputed_analyses = {}, - bool recursively_evaluate_nonconstant_operands = false); + bool recursively_evaluate_nonconstant_operands = false, + const absl::flat_hash_map& + substitutions = {}); // Same as Evaluate, except returning false on error and accepts an output // pointer. bool TryEvaluate(const HloInstruction* instruction, Literal* result, bool recursively_evaluate_nonconstant_operands = false); - // Evaluates a single HLO instruction, substituting the given literals for - // some of the instruction's operands. - // - // For example, given instruction = op(A, B, C) and the map - // {A = x, C = y}, this evaluates op(x, B, y). 
- absl::StatusOr EvaluateWithSubstitutions( - const HloInstruction* instruction, - const absl::flat_hash_map& - substitutions, - bool recursively_evaluate_nonconstant_operands = false); absl::StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, const Literal& lhs, diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc index 610bec193491b7..9a3eddc718dbeb 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -3312,10 +3312,10 @@ TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) { HloEvaluator evaluator; Literal param0_literal = LiteralUtil::CreateR1({1, 2, 3, 4}); Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); - TF_ASSERT_OK_AND_ASSIGN( - Literal result, - evaluator.EvaluateWithSubstitutions( - add, {{param0, ¶m0_literal}, {square, &square_literal}})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + evaluator.Evaluate(add, {}, false, + {{param0, ¶m0_literal}, + {square, &square_literal}})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR1({11, 22, 33, 44}), result)); } @@ -3337,10 +3337,11 @@ TEST_F(HloEvaluatorTest, EvaluateWithSubstitutionsRecursive) { HloInstruction* param = module->entry_computation()->parameter_instruction(0); TF_ASSERT_OK_AND_ASSIGN( auto result, - evaluator_.EvaluateWithSubstitutions( + evaluator_.Evaluate( /*instruction=*/module->entry_computation()->root_instruction(), - /*substitutions=*/{{param, ¶m_value}}, - /*recursively_evaluate_nonconstant_operands=*/true)); + /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true, + /*substitutions=*/{{param, ¶m_value}})); EXPECT_EQ(result, LiteralUtil::CreateR0(PrimitiveType::S32, 1 + 2 + 3)); } @@ -3362,10 +3363,11 @@ TEST_F(HloEvaluatorTest, HloInstruction* param = module->entry_computation()->parameter_instruction(0); TF_ASSERT_OK_AND_ASSIGN( Literal result, - evaluator_.EvaluateWithSubstitutions( + evaluator_.Evaluate( /*instruction=*/module->entry_computation()->root_instruction(), - /*substitutions=*/{{param, ¶m_value}}, - /*recursively_evaluate_nonconstant_operands=*/true)); + /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true, + /*substitutions=*/{{param, ¶m_value}})); EXPECT_EQ(result, LiteralUtil::CreateR0(PrimitiveType::S32, 4 + 1 + 2 + 1)); } @@ -3389,7 +3391,7 @@ TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) { Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); TF_ASSERT_OK_AND_ASSIGN( Literal result, - evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}})); + evaluator.Evaluate(add, {}, false, {{square, &square_literal}})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR1({11, 22, 33, 44}), result)); } @@ -3405,8 +3407,9 @@ TEST_P(HloEvaluatorBf16Test, EvaluateSubstitutedInstruction) { HloEvaluator evaluator; Literal literal = LiteralUtil::CreateR1({10, 20, 30, 40}); - TF_ASSERT_OK_AND_ASSIGN(Literal result, evaluator.EvaluateWithSubstitutions( - param, {{param, &literal}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + evaluator.Evaluate(param, {}, false, {{param, &literal}})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR1({10, 20, 30, 40}), result)); } @@ -3426,8 +3429,9 @@ TEST_F(HloEvaluatorTest, EvaluateWithSubstitutionsLiteralBase) { BorrowingLiteral literal(reinterpret_cast(int64_values), literal_shape); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result, 
evaluator.EvaluateWithSubstitutions( - square, {{param0, &literal}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + evaluator.Evaluate(square, {}, false, {{param0, &literal}})); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({1, 4, 9}), result)); } diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 4f0df16edb34f3..689a79825633fe 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -2026,10 +2026,7 @@ class HloInstruction { } const OpMetadata& metadata() const { return *metadata_; } - // Set/get the computation containing this instruction. set_parent should only - // be called by HloComputation methods which add/remove instructions to - // computations. - void set_parent(HloComputation* computation) { parent_ = computation; } + // Get the computation containing this instruction. const HloComputation* parent() const { return parent_; } HloComputation* parent() { return parent_; } @@ -2432,6 +2429,9 @@ class HloInstruction { bool ignore_channel_id_values, bool ignore_commutative_operand_order) const; + // Set the computation containing this instruction. + void set_parent(HloComputation* computation) { parent_ = computation; } + // Implementation for non-common logic of PrintExtraAttributes. virtual void PrintExtraAttributesImpl(AttributePrinter& printer, const HloPrintOptions& options) const {} diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc index 671d1ec6e5f567..969ab5fca269fc 100644 --- a/third_party/xla/xla/service/hlo_verifier_test.cc +++ b/third_party/xla/xla/service/hlo_verifier_test.cc @@ -102,25 +102,6 @@ class HloVerifierTestLayoutFusion : public HloTestBase { /*allow_mixed_precision_in_hlo_verifier=*/false) {} }; -TEST_F(HloVerifierTest, NullInstructionParent) { - HloComputation::Builder builder(TestName()); - const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - HloInstruction* param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param")); - HloInstruction* negate = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateUnverifiedModule(); - module->AddEntryComputation(builder.Build()); - - TF_ASSERT_OK(verifier().Run(module.get()).status()); - - negate->set_parent(nullptr); - - auto status = verifier().Run(module.get()).status(); - ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), HasSubstr("has a null parent pointer")); -} - TEST_F(HloVerifierTest, DifferentOperandParents) { HloComputation::Builder builder(TestName()); const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); From 93457bd92f9b52c5b0c1153346dfe398e5e5feb7 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 8 Apr 2025 11:06:37 -0700 Subject: [PATCH 0379/1324] Use `std::function` default value for default post-processors. There is no need to wrap the `std::function` into a `std::optional`. 
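As a self-contained illustration of the rationale (a standalone sketch, not part of this patch): a default-constructed `std::function` is empty and tests as `false`, so it already expresses "no post-processor" without an `std::optional` wrapper, matching the `if (postprocess_rotated)` style checks in the diff below.

```c++
#include <functional>
#include <iostream>

using Postprocessor = std::function<int(int)>;

// An empty std::function is contextually convertible to false, so callers
// can test it directly instead of unwrapping an std::optional first.
int Run(int value, const Postprocessor& post = {}) {
  return post ? post(value) : value;
}

int main() {
  std::cout << Run(1) << "\n";                                // prints 1
  std::cout << Run(1, [](int v) { return v + 41; }) << "\n";  // prints 42
}
```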
PiperOrigin-RevId: 745215732 --- .../xla/xla/service/collective_pipeliner.cc | 41 +++++++++---------- .../xla/xla/service/collective_pipeliner.h | 12 +++--- .../xla/service/collective_pipeliner_test.cc | 28 ++++++------- .../xla/xla/service/gpu/gpu_compiler.cc | 18 ++++---- .../xla/xla/service/gpu/gpu_p2p_pipeliner.cc | 3 +- .../gpu_collective_combiner_utils_test.cc | 12 +++--- 6 files changed, 55 insertions(+), 59 deletions(-) diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 3b384c46a0988b..f76a80e6e1d141 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -726,8 +726,7 @@ absl::StatusOr CloneBackwardChain( int64_t& next_scheduling_id, absl::flat_hash_map& annotation_map, LoopVariantParameterInfo* loop_variant_parameter_info = nullptr, - CollectivePipeliner::HloPostprocessor postprocess_pipelined_ops = - std::nullopt) { + CollectivePipeliner::HloPostprocessor postprocess_pipelined_ops = {}) { std::vector to_clone(move_info.formatting_ops.begin(), move_info.formatting_ops.end()); to_clone.push_back(move_info.collectives_to_move[0]); @@ -748,9 +747,9 @@ absl::StatusOr CloneBackwardChain( annotation_map); } clone_map[chain_op] = cloned; - if (postprocess_pipelined_ops.has_value()) { + if (postprocess_pipelined_ops) { TF_RETURN_IF_ERROR( - (*postprocess_pipelined_ops)(cloned, /*new_while_instr=*/nullptr)); + postprocess_pipelined_ops(cloned, /*new_while_instr=*/nullptr)); } last_cloned = cloned; if (loop_variant_parameter_info != nullptr && @@ -1949,9 +1948,9 @@ absl::Status TransformLoopForward( CollectivePipeliner::kInsertedByPreviousStep)); } - if (post_processing_fn.has_value()) { + if (post_processing_fn) { TF_RETURN_IF_ERROR( - (*post_processing_fn)(processed, /*new_while_instr=*/nullptr)); + post_processing_fn(processed, /*new_while_instr=*/nullptr)); } InstructionMap cloned_map = pipelined_values_map; @@ -1966,9 +1965,9 @@ absl::Status TransformLoopForward( annotation_map); } cloned_map[formatting_op] = processed; - if (post_processing_fn.has_value()) { + if (post_processing_fn) { TF_RETURN_IF_ERROR( - (*post_processing_fn)(processed, /*new_while_instr=*/nullptr)); + post_processing_fn(processed, /*new_while_instr=*/nullptr)); } } return processed; @@ -2788,13 +2787,13 @@ static absl::Status TransformLoopBackward( next_channel_id, next_scheduling_id, annotation_map, /*loop_variant_parameter_info=*/nullptr, post_processing_fn)); - if (post_processing_fn.has_value()) { - TF_RETURN_IF_ERROR((*post_processing_fn)(new_init_operands[idx], - /*new_while_instr=*/nullptr)); + if (post_processing_fn) { + TF_RETURN_IF_ERROR(post_processing_fn(new_init_operands[idx], + /*new_while_instr=*/nullptr)); } - if (postprocess_peeled.has_value()) { - TF_RETURN_IF_ERROR(postprocess_peeled.value()( - new_init_operands[idx], /*new_while_instr=*/nullptr)); + if (postprocess_peeled) { + TF_RETURN_IF_ERROR(postprocess_peeled(new_init_operands[idx], + /*new_while_instr=*/nullptr)); } } ConstantValue next_loop_iteration = @@ -2848,13 +2847,13 @@ static absl::Status TransformLoopBackward( next_scheduling_id, annotation_map, &loop_variant_parameter_info, post_processing_fn)); - if (post_processing_fn.has_value()) { + if (post_processing_fn) { TF_RETURN_IF_ERROR( - (*post_processing_fn)(cloned_instr, /*new_while_instr=*/nullptr)); + post_processing_fn(cloned_instr, /*new_while_instr=*/nullptr)); } - if (postprocess_rotated.has_value()) { - 
TF_RETURN_IF_ERROR(postprocess_rotated.value()( - cloned_instr, /*new_while_instr=*/nullptr)); + if (postprocess_rotated) { + TF_RETURN_IF_ERROR( + postprocess_rotated(cloned_instr, /*new_while_instr=*/nullptr)); } } else { auto new_operands = @@ -2991,10 +2990,10 @@ static absl::Status TransformLoopBackward( HloInstruction* cloned_instr = while_loop->parent()->AddInstruction( instr->CloneWithNewOperands(instr->shape(), new_operands)); - if (postprocess_peeled_trailing_op.has_value()) { + if (postprocess_peeled_trailing_op) { CHECK_NE(new_while_loop, nullptr); TF_RETURN_IF_ERROR( - postprocess_peeled_trailing_op.value()(cloned_instr, new_while_loop)); + postprocess_peeled_trailing_op(cloned_instr, new_while_loop)); } TF_RETURN_IF_ERROR(UpdateControlDependencies(instr, cloned_instr, diff --git a/third_party/xla/xla/service/collective_pipeliner.h b/third_party/xla/xla/service/collective_pipeliner.h index bc3ba3212bb348..a8f6bb2823c99d 100644 --- a/third_party/xla/xla/service/collective_pipeliner.h +++ b/third_party/xla/xla/service/collective_pipeliner.h @@ -70,8 +70,8 @@ class CollectivePipeliner : public HloModulePass { // before and after the loop, and rotated instructions. The new while op is // only passed for the peeled trailing ops when the new while op was already // created. - using HloPostprocessor = std::optional>; + using HloPostprocessor = std::function; struct Config { int64_t level_to_operate_on = 0; @@ -104,14 +104,14 @@ class CollectivePipeliner : public HloModulePass { // pipelined. This is currently only used to support kBackward pipelining. bool should_allow_control_dependencies = false; // TODO(b/399476667): Consolidate these postprocessing functions. - HloPostprocessor postprocess_backward_peeled_op = std::nullopt; - HloPostprocessor postprocess_backward_rotated_op = std::nullopt; - HloPostprocessor postprocess_backward_peeled_trailing_op = std::nullopt; + HloPostprocessor postprocess_backward_peeled_op; + HloPostprocessor postprocess_backward_rotated_op; + HloPostprocessor postprocess_backward_peeled_trailing_op; // Determines whether a loop invariant instruction can be considered // in the pipelining chain. bool should_add_loop_invariant_op_in_chain = false; // Postprocessing hook which runs for every successfully pipelined op. 
- HloPostprocessor postprocess_pipelined_ops = std::nullopt; + HloPostprocessor postprocess_pipelined_ops; int64_t collective_size_threshold_to_stop_sinking = INT64_MAX; }; static const char* const kInsertedByPreviousStep; diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index b4ff0b276ecc12..983ba64dc85fca 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -82,12 +82,10 @@ absl::StatusOr RunOptimizer( HloPredicate reuse_pipelined_op_buffer = HloPredicateTrue, HloPredicate should_allow_loop_variant_parameter_in_chain = HloPredicateFalse, - CollectivePipeliner::HloPostprocessor postprocess_backward_peeled = - std::nullopt, - CollectivePipeliner::HloPostprocessor postprocess_backward_rotated = - std::nullopt, + CollectivePipeliner::HloPostprocessor postprocess_backward_peeled = {}, + CollectivePipeliner::HloPostprocessor postprocess_backward_rotated = {}, CollectivePipeliner::HloPostprocessor postprocess_backward_peeled_trailing = - std::nullopt, + {}, bool should_add_loop_invariant_op_in_chain = false, int64_t collective_size_threshold_to_stop_sinking = INT64_MAX) { CollectivePipeliner::Config config = { @@ -105,7 +103,7 @@ absl::StatusOr RunOptimizer( /*should_allow_control_dependencies=*/false, postprocess_backward_peeled, postprocess_backward_rotated, postprocess_backward_peeled_trailing, should_add_loop_invariant_op_in_chain, - /*postprocess_pipelined_ops=*/std::nullopt, + /*postprocess_pipelined_ops=*/{}, collective_size_threshold_to_stop_sinking}; HloPassPipeline pass("optimizer"); pass.AddPass(/*layout_sensitive=*/false, @@ -3175,9 +3173,9 @@ ENTRY entry { /*acceptable_formatting=*/HloPredicateTrue, /*reuse_pipelined_op_buffer=*/HloPredicateTrue, /*should_allow_loop_variant_parameter_in_chain=*/HloPredicateTrue, - /*postprocess_backward_peeled=*/std::nullopt, - /*postprocess_backward_rotated=*/std::nullopt, - /*postprocess_backward_peeled_trailing=*/std::nullopt, + /*postprocess_backward_peeled=*/{}, + /*postprocess_backward_rotated=*/{}, + /*postprocess_backward_peeled_trailing=*/{}, /*should_add_loop_invariant_op_in_chain=*/true) .value()); XLA_VLOG_LINES(1, module->ToString()); @@ -3206,9 +3204,9 @@ ENTRY entry { /*acceptable_formatting=*/HloPredicateTrue, /*reuse_pipelined_op_buffer=*/HloPredicateTrue, /*should_allow_loop_variant_parameter_in_chain=*/HloPredicateTrue, - /*postprocess_backward_peeled=*/std::nullopt, - /*postprocess_backward_rotated=*/std::nullopt, - /*postprocess_backward_peeled_trailing=*/std::nullopt, + /*postprocess_backward_peeled=*/{}, + /*postprocess_backward_rotated=*/{}, + /*postprocess_backward_peeled_trailing=*/{}, /*should_add_loop_invariant_op_in_chain=*/false) .value()); } @@ -3598,9 +3596,9 @@ ENTRY entry { /*acceptable_formatting=*/HloPredicateIsNotOp, /*reuse_pipelined_op_buffer=*/HloPredicateTrue, /*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse, - /*postprocess_backward_peeled=*/std::nullopt, - /*postprocess_backward_rotated=*/std::nullopt, - /*postprocess_backward_peeled_trailing=*/std::nullopt, + /*postprocess_backward_peeled=*/{}, + /*postprocess_backward_rotated=*/{}, + /*postprocess_backward_peeled_trailing=*/{}, /*should_add_loop_invariant_op_in_chain=*/false, /*collective_size_threshold_to_stop_sinking=*/1024) .value()); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 
e822d46aaae140..7c9340d6ef38ca 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -932,9 +932,9 @@ absl::Status RunCollectiveOptimizationPasses( /*reuse_pipelined_op_buffer=*/HloPredicateFalse, /*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse, /*should_allow_control_dependencies=*/false, - /*postprocess_backward_peeled_op=*/std::nullopt, - /*postprocess_backward_rotated_op=*/std::nullopt, - /*postprocess_backward_peeled_trailing_op=*/std::nullopt, + /*postprocess_backward_peeled_op=*/{}, + /*postprocess_backward_rotated_op=*/{}, + /*postprocess_backward_peeled_trailing_op=*/{}, /*should_add_loop_invariant_op_in_chain=*/false, /*postprocess_pipelined_ops=*/AppendPipelinedInstruction, }; @@ -956,9 +956,9 @@ absl::Status RunCollectiveOptimizationPasses( /*reuse_pipelined_op_buffer=*/HloPredicateFalse, /*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse, /*should_allow_control_dependencies=*/false, - /*postprocess_backward_peeled_op=*/std::nullopt, - /*postprocess_backward_rotated_op=*/std::nullopt, - /*postprocess_backward_peeled_trailing_op=*/std::nullopt, + /*postprocess_backward_peeled_op=*/{}, + /*postprocess_backward_rotated_op=*/{}, + /*postprocess_backward_peeled_trailing_op=*/{}, /*should_add_loop_invariant_op_in_chain=*/true, /*postprocess_pipelined_ops=*/AppendPipelinedInstruction, }; @@ -980,9 +980,9 @@ absl::Status RunCollectiveOptimizationPasses( /*reuse_pipelined_op_buffer=*/HloPredicateFalse, /*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse, /*should_allow_control_dependencies=*/false, - /*postprocess_backward_peeled_op=*/std::nullopt, - /*postprocess_backward_rotated_op=*/std::nullopt, - /*postprocess_backward_peeled_trailing_op=*/std::nullopt, + /*postprocess_backward_peeled_op=*/{}, + /*postprocess_backward_rotated_op=*/{}, + /*postprocess_backward_peeled_trailing_op=*/{}, /*should_add_loop_invariant_op_in_chain=*/false, /*postprocess_pipelined_ops=*/AppendPipelinedInstruction, }; diff --git a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner.cc b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner.cc index 7a6151aabd5ce4..456bc1bf79faa6 100644 --- a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner.cc +++ b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner.cc @@ -547,8 +547,7 @@ absl::StatusOr<bool> GpuP2PPipeliner::Run( PostprocessPeeledP2P; CollectivePipeliner::HloPostprocessor postprocess_backward_rotated_op = PostprocessRotatedP2P; - CollectivePipeliner::HloPostprocessor - postprocess_backward_peeled_trailing_op = std::nullopt; + CollectivePipeliner::HloPostprocessor postprocess_backward_peeled_trailing_op; // If partial send/recv pipelining is enabled, collect send/recv instructions // for post-processing.
diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils_test.cc index 371a464afefedf..9993985c72ab84 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils_test.cc @@ -332,9 +332,9 @@ TEST_F(CollectiveCombinerUtilsTest, /*reuse_pipelined_op_buffer=*/HloPredicateFalse, /*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse, /*should_allow_control_dependencies=*/false, - /*postprocess_backward_peeled_op=*/std::nullopt, - /*postprocess_backward_rotated_op=*/std::nullopt, - /*postprocess_backward_peeled_trailing_op=*/std::nullopt, + /*postprocess_backward_peeled_op=*/{}, + /*postprocess_backward_rotated_op=*/{}, + /*postprocess_backward_peeled_trailing_op=*/{}, /*should_add_loop_invariant_op_in_chain=*/true, }; config.postprocess_pipelined_ops = AppendPipelinedInstruction; @@ -427,9 +427,9 @@ TEST_F(CollectiveCombinerUtilsTest, /*reuse_pipelined_op_buffer=*/HloPredicateFalse, /*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse, /*should_allow_control_dependencies=*/false, - /*postprocess_backward_peeled_op=*/std::nullopt, - /*postprocess_backward_rotated_op=*/std::nullopt, - /*postprocess_backward_peeled_trailing_op=*/std::nullopt, + /*postprocess_backward_peeled_op=*/{}, + /*postprocess_backward_rotated_op=*/{}, + /*postprocess_backward_peeled_trailing_op=*/{}, /*should_add_loop_invariant_op_in_chain=*/true, }; config.postprocess_pipelined_ops = AppendPipelinedInstruction; From 6977a3272627118004592a0f1958c3d723a5eee0 Mon Sep 17 00:00:00 2001 From: Chase Riley Roberts Date: Tue, 8 Apr 2025 11:06:45 -0700 Subject: [PATCH 0380/1324] PR #24429: Fixed failing NCCL Group test on internal CI Imported from GitHub PR https://github.com/openxla/xla/pull/24429 Previously, this test was failing with ``` INTERNAL: NCCL operation ncclRecv( recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, source_rank->value(), comm_, se::gpu::AsGpuStreamValue(stream)) failed: invalid argument (run with NCCL_DEBUG=WARN for details). Last NCCL warning(error) log entry (may be unrelated) 'Recv : invalid root 3 (root should be in the 0..1 range)'. ``` This was caused by a combination of the bad channel ids and the missing `replica_count` attribute.
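To make the failure concrete, here is a minimal sketch of the peer-rank bounds check behind that NCCL error (illustrative only; `PeerRankIsValid` and the constants are assumptions, not actual NCCL code). With the partition count left at its default, the communicator only has ranks 0..1, so the rank 3 referenced by `source_target_pairs` is rejected before `ncclRecv` is issued; declaring 4 partitions makes it valid:

```c++
#include <cstdio>

// Sketch of an NCCL-style validity check: a send/recv peer must name an
// existing rank in the communicator.
bool PeerRankIsValid(int peer_rank, int nranks) {
  return peer_rank >= 0 && peer_rank < nranks;
}

int main() {
  const int referenced_peer = 3;  // Highest rank in {{0,1},{1,2},{2,3},{3,0}}.
  // Default communicator size: rank 3 is out of range ("invalid root 3").
  std::printf("nranks=2: %s\n",
              PeerRankIsValid(referenced_peer, 2) ? "ok" : "invalid root");
  // With num_partitions=4 the same peer rank is accepted.
  std::printf("nranks=4: %s\n",
              PeerRankIsValid(referenced_peer, 4) ? "ok" : "invalid root");
  return 0;
}
```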
Copybara import of the project: -- f904b6c985db68ebb2d98f8e8a2a4d2da45834e5 by chaserileyroberts : Set replica_count=4, removed channel_ids -- 67702da461c3f9bac803cfb01bf9352fec46c610 by chaser : replica_count -> num_partitions Merging this change closes #24429 PiperOrigin-RevId: 745215789 --- third_party/xla/xla/tests/nccl_group_execution_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/tests/nccl_group_execution_test.cc b/third_party/xla/xla/tests/nccl_group_execution_test.cc index 433e187e20501c..aef4335c642588 100644 --- a/third_party/xla/xla/tests/nccl_group_execution_test.cc +++ b/third_party/xla/xla/tests/nccl_group_execution_test.cc @@ -127,13 +127,13 @@ XLA_TEST_F(NcclGroupExecutionTest, NcclGroupSendRecvNoWhileLoop) { XLA_TEST_F(NcclGroupExecutionTest, BidirectionalCommunication) { const absl::string_view kModuleStr = R"( - HloModule module_main, entry_computation_layout={()->(u32[], u32[])} + HloModule module_main, entry_computation_layout={()->(u32[], u32[])}, num_partitions=4 bidirectional_ring { a = u32[] parameter(0) - start = (u32[], u32[]) collective-permute-start(a), channel_id=2, source_target_pairs={{0,1},{1,2},{2,3},{3,0}} + start = (u32[], u32[]) collective-permute-start(a), source_target_pairs={{0,1},{1,2},{2,3},{3,0}} done = u32[] collective-permute-done(start) - start.1 = (u32[], u32[]) collective-permute-start(a), channel_id=1, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} + start.1 = (u32[], u32[]) collective-permute-start(a), source_target_pairs={{0,3},{1,0},{2,1},{3,2}} done.1 = u32[] collective-permute-done(start.1) ROOT tuple = (u32[], u32[]) tuple(done, done.1) } From 649a47160e01543607fad5fe0b487292050f6453 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 11:12:08 -0700 Subject: [PATCH 0381/1324] Fix clang-tidy in gpusolver_rewriter and associated build rule. PiperOrigin-RevId: 745218065 --- third_party/xla/xla/service/gpu/transforms/BUILD | 7 ++++++- .../xla/service/gpu/transforms/gpusolver_rewriter.cc | 10 +++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 77daf1e18153f7..e21bfedccb6dec 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -2070,7 +2070,12 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@com_google_absl//absl/functional:any_invocable", - ]), + ]) + [ + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + ], ) cc_library( diff --git a/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc index b960cbf5ea1f5f..693ebc5158cd7b 100644 --- a/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc @@ -17,11 +17,16 @@ limitations under the License. #include #include +#include #include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" @@ -35,11 +40,10 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/gpu_solver_context.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { From 9b091c859ae40e715a49f0f678958d414ef6a88e Mon Sep 17 00:00:00 2001 From: Chase Riley Roberts Date: Tue, 8 Apr 2025 12:34:08 -0700 Subject: [PATCH 0382/1324] PR #24794: Remove scheduling annotations on the cloned computation for collective groups Imported from GitHub PR https://github.com/openxla/xla/pull/24794 Previously, when users tried to combine collective groups with scheduling groups, they could run into issues like this: Example code ```python @jax.jit def bidir_comms(a): b = jax.lax.ppermute(a, "i", perm_up) c = jax.lax.ppermute(a, "i", perm_down) return b, c @jax.jit @partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i')) def groups(a): # Running the collective groups under a scheduling group. with set_xla_metadata( _scheduling_group_id='1'): with set_xla_metadata(_collectives_group="", inlineable="false"): b, c = bidir_comms(a) return b + c ``` Would crash with an error like ``` jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1. ``` This is because the instructions within the async computation also included the scheduling annotations. When the LHS (latency-hiding scheduler) looked within this computation for scheduling, it would see all of the communication operations labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue. This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` applies frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`. Copybara import of the project: -- a266d26a7f28b008869be12a90b7a4fdf61f219c by chaser : Remove scheduling annotations on the cloned computation Merging this change closes #24794 PiperOrigin-RevId: 745248689 --- ...xplicit_collectives_group_async_wrapper.cc | 3 ++ ...it_collectives_group_async_wrapper_test.cc | 39 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.cc index 3748bf5aab02fa..42aa4b8af87fb5 100644 --- a/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.cc @@ -43,6 +43,9 @@ absl::StatusOr CreateCollectivesGroupAsyncPair(HloInstruction* instr) { HloComputation* computation = instr->parent(); auto new_computation = instr->GetModule()->AddEmbeddedComputation( instr->to_apply()->Clone("collectives_group")); + for (auto inner_instruction : new_computation->instructions()) { + inner_instruction->erase_frontend_attribute(kXlaSchedulingGroupIdAttr); + } // Get the shapes for the original instruction.
std::vector<Shape> parameter_shapes(instr->operand_count()); for (int i = 0; i < instr->operand_count(); ++i) { diff --git a/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper_test.cc index 51939f19039ab1..014c7f9ab2d323 100644 --- a/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper_test.cc @@ -64,6 +64,45 @@ TEST_F(ExplicitCollectivesGroupAsyncWrapperTest, AnnotatedOpIsWrapped) { ASSERT_TRUE(mutated); } +TEST_F(ExplicitCollectivesGroupAsyncWrapperTest, + RemoveSchedulingGroupAnnotation) { + const absl::string_view hlo_string = R"( + HloModule composite + comms { + a = f32[1] parameter(0) + x = f32[1] all-gather(a), replica_groups={}, dimensions={0}, frontend_attributes={_scheduling_group_id="1"} + y = f32[1] collective-permute(a), source_target_pairs={{0,1}}, frontend_attributes={_scheduling_group_id="1"} + ROOT result = (f32[1], f32[1]) tuple(x, y) + } + + ENTRY main { + b = f32[1] parameter(0) + ROOT c = (f32[1], f32[1]) call(b), to_apply=comms, frontend_attributes={_collectives_group="", _scheduling_group_id="1"} + } + )"; + + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + ExplicitCollectivesGroupAsyncWrapper wrapper_pass; + + TF_ASSERT_OK_AND_ASSIGN(bool mutated, wrapper_pass.Run(module.get())); + // Assert that the scheduling annotation is removed within the cloned + // computation, but remains on the async operations. + absl::StatusOr<bool> filecheck_result = RunFileCheck(module->ToString({}), R"( + // CHECK: %comms.collectives_group {{.*}} { + // CHECK-NEXT: %{{.*}} parameter(0) + // CHECK-NEXT: %{{.*}} all-gather({{.*}}), replica_groups={}, dimensions={0} + // CHECK-NEXT: %{{.*}} collective-permute({{.*}}), source_target_pairs={{[{][{]0,1[}][}]}} + // CHECK: ENTRY %main {{.*}} + // CHECK-NEXT: %[[P0:.*]] = {{.*}} parameter(0) + // CHECK-NEXT: %[[P1:.*]] = {{.*}} async-start(%[[P0]]), async_execution_thread="explicit", calls=%comms.collectives_group, frontend_attributes={_collectives_group="",_scheduling_group_id="1"} + // CHECK-NEXT: ROOT %{{.*}} async-done(%[[P1]]), frontend_attributes={_collectives_group="",_scheduling_group_id="1"} + )"); + TF_ASSERT_OK(filecheck_result.status()); + EXPECT_TRUE(*filecheck_result); + ASSERT_TRUE(mutated); +} + TEST_F(ExplicitCollectivesGroupAsyncWrapperTest, ManyCollectivesGroups) { // This test calls the same collectives group computation twice, so the // computation is cloned so it can be used with many async instructions. From 77efc477cd82a02aacb85eafa95ae610f6689a85 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 8 Apr 2025 12:49:42 -0700 Subject: [PATCH 0383/1324] [XLA:GPU] Make `fusion_emitter_deviceless_test` truly deviceless. Previously, it was still running on GPU for some reason. Also simplify some of the test code while we're at it.
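For reference, a minimal sketch of the deviceless pattern this change adopts (the fixture name and HLO below are illustrative, not taken from the updated test): deriving from `HloHardwareIndependentTestBase` lets a test parse and inspect HLO entirely on the host, with no `backend()` or stream executor involved.

```c++
#include <gtest/gtest.h>

#include "absl/strings/string_view.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/tsl/platform/statusor.h"

namespace xla {
namespace {

using DevicelessExampleTest = HloHardwareIndependentTestBase;

TEST_F(DevicelessExampleTest, ParsesHloWithoutAGpu) {
  constexpr absl::string_view kHloText = R"(
HloModule m

ENTRY e {
  ROOT p = f32[8,8] parameter(0)
})";
  // Parsing and verification run on the host; no GPU is ever touched.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  EXPECT_NE(module->entry_computation()->root_instruction(), nullptr);
}

}  // namespace
}  // namespace xla
```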
PiperOrigin-RevId: 745254948 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 6 +- .../triton/fusion_emitter_deviceless_test.cc | 84 ++++++++----------- 2 files changed, 36 insertions(+), 54 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 0f6c0f451e6d59..fc18849ddd810b 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -469,19 +469,19 @@ xla_cc_test( ], ) -xla_test( +xla_cc_test( name = "fusion_emitter_deviceless_test", srcs = ["fusion_emitter_deviceless_test.cc"], - backends = ["gpu"], + tags = ["no_oss"], # Doesn't pass in OSS when building with the `fusion_emitter_stub`. deps = [ ":fusion_emitter", "//xla/codegen:emitter_loc_op_builder", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:verified_hlo_module", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", - "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_deviceless_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_deviceless_test.cc index e4118beb8407df..c10ed8ca50aeff 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_deviceless_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_deviceless_test.cc @@ -27,35 +27,26 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" -#include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" #include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/platform/statusor.h" -#if defined(PLATFORM_GOOGLE) -#else - -#endif namespace xla::gpu { namespace { using ::tsl::testing::IsOkAndHolds; using ::xla::gpu::ir_emitter_triton_internal::DumpTritonIR; -class AnnotationsTest : public GpuCodegenTest { +class AnnotationsTest : public HloHardwareIndependentTestBase { public: - const stream_executor::GpuComputeCapability& GpuComputeComp() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .gpu_compute_capability(); - } DebugOptions GetDebugOptionsForTest() const override { - DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + DebugOptions debug_options = + HloHardwareIndependentTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_unsupported_annotate_with_emitter_loc(true); return debug_options; } @@ -63,54 +54,45 @@ class AnnotationsTest : public GpuCodegenTest { TEST_F(AnnotationsTest, Annotations) { static constexpr absl::string_view kHloText = R"( - HloModule Annotations - - triton_dot { - p0 = f32[8,8] parameter(0) - p1 = f32[8,8] parameter(1) - ROOT dot = f32[8,8] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x3 - } +HloModule Annotations + +triton_dot { + p0 = f32[8,8] parameter(0) 
+ p1 = f32[8,8] parameter(1) + ROOT dot = f32[8,8] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} - ENTRY e { - p0 = f32[8,8]{1, 0} parameter(0) - p1 = f32[8,8]{1, 0} parameter(1) - ROOT _ = f32[8,8] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - { - "block_m":32, - "block_n":32, - "block_k":32, - "split_k":1, - "num_stages":1, - "num_warps":1, - "num_ctas":1 - } - } +ENTRY e { + p0 = f32[8,8]{1, 0} parameter(0) + p1 = f32[8,8]{1, 0} parameter(1) + ROOT _ = f32[8,8] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + { + "block_m":32, + "block_n":32, + "block_k":32, + "split_k":1, + "num_stages":1, + "num_warps":1, + "num_ctas":1 } } - } - )"; + } +})"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - auto* comp = module->GetComputationWithName("triton_dot"); - EXPECT_NE(comp, nullptr); - auto fusion_backend_config = comp->FusionInstruction() - ->backend_config<GpuBackendConfig>() - ->fusion_backend_config(); - BlockLevelParameters block_level_parameters = - BlockLevelParameters::FromBlockLevelFusionConfig( - fusion_backend_config.block_level_fusion_config()); - - auto* fusion = Cast<HloFusionInstruction>(comp->FusionInstruction()); + auto* fusion = Cast<HloFusionInstruction>( + module->entry_computation()->root_instruction()); mlir::MLIRContext context; TF_ASSERT_OK_AND_ASSIGN( auto triton_module, CreateTritonModule("triton_fn", fusion, TestGpuDeviceInfo::RTXA6000DeviceInfo(), - block_level_parameters, context)); + BlockLevelParameters(), context)); std::string annotated_ir = DumpTritonIR(triton_module.get(), true); @@ -127,7 +109,7 @@ TEST_F(AnnotationsTest, Annotations) { } } -using TritonEmitterDevicelessTest = GpuCodegenTest; +using TritonEmitterDevicelessTest = HloHardwareIndependentTestBase; TEST_F(TritonEmitterDevicelessTest, FailsGracefullyIfNumWarpsIsMissing) { constexpr absl::string_view kHloText = R"( From d0f757be92f8eeeb5af727170e32b1c8e93272b9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 12:56:36 -0700 Subject: [PATCH 0384/1324] [Repository Update] Remove lib to gpu_cost_analysis from Repository files. PiperOrigin-RevId: 745257376 --- tensorflow/core/profiler/convert/BUILD | 1 - tensorflow/core/profiler/convert/repository.h | 1 - 2 files changed, 2 deletions(-) diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 978b075dea7733..64675cee9b6c90 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -1013,7 +1013,6 @@ cc_library( "@local_xla//xla/tsl/platform:statusor", "@local_xla//xla/tsl/profiler/utils:file_system_utils", "@org_xprof//xprof/utils:hlo_module_map", - "@org_xprof//xprof/utils:xprof_gpu_cost_analysis", ], ) diff --git a/tensorflow/core/profiler/convert/repository.h b/tensorflow/core/profiler/convert/repository.h index 9c5ee606552711..2db0ad41777384 100644 --- a/tensorflow/core/profiler/convert/repository.h +++ b/tensorflow/core/profiler/convert/repository.h @@ -34,7 +34,6 @@ limitations under the License. #include "tsl/platform/path.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "xprof/utils/hlo_module_map.h" // from @org_xprof -#include "xprof/utils/xprof_gpu_cost_analysis.h" // from @org_xprof namespace tensorflow { namespace profiler { From 052fe0dd51473d48c2eef8845c7f5081c3f4fb74 Mon Sep 17 00:00:00 2001 From: "A.
Unique TensorFlower" Date: Tue, 8 Apr 2025 13:14:08 -0700 Subject: [PATCH 0385/1324] Replace legacy struct providers with modern ones PiperOrigin-RevId: 745263819 --- tensorflow/tensorflow.bzl | 12 +++++++----- third_party/xla/xla/tsl/tsl.bzl | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index f64e111f78b4a3..df9f64736bb7fd 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -2217,6 +2217,8 @@ def tf_custom_op_library_additional_deps_impl(): clean_dep("//tensorflow/core:reader_base"), ] +CollectedDepsInfo = provider("CollectedDepsInfo", fields = ["tf_collected_deps"]) + # Traverse the dependency graph along the "deps" attribute of the # target and return a struct with one field called 'tf_collected_deps'. # tf_collected_deps will be the union of the deps of the current target @@ -2232,9 +2234,9 @@ def _collect_deps_aspect_impl(target, ctx): all_deps += ctx.rule.attr.roots for dep in all_deps: direct.append(dep.label) - if hasattr(dep, "tf_collected_deps"): - transitive.append(dep.tf_collected_deps) - return struct(tf_collected_deps = depset(direct = direct, transitive = transitive)) + if CollectedDepsInfo in dep: + transitive.append(dep[CollectedDepsInfo].tf_collected_deps) + return CollectedDepsInfo(tf_collected_deps = depset(direct = direct, transitive = transitive)) collect_deps_aspect = aspect( attr_aspects = ["deps", "data", "roots"], @@ -2253,9 +2255,9 @@ def _check_deps_impl(ctx): required_deps = ctx.attr.required_deps disallowed_deps = ctx.attr.disallowed_deps for input_dep in ctx.attr.deps: - if not hasattr(input_dep, "tf_collected_deps"): + if CollectedDepsInfo not in input_dep: continue - collected_deps = sets.make(input_dep.tf_collected_deps.to_list()) + collected_deps = sets.make(input_dep[CollectedDepsInfo].tf_collected_deps.to_list()) for disallowed_dep in disallowed_deps: if sets.contains(collected_deps, disallowed_dep.label): fail( diff --git a/third_party/xla/xla/tsl/tsl.bzl b/third_party/xla/xla/tsl/tsl.bzl index 188a956618592c..bb1dd9a19ba2d8 100644 --- a/third_party/xla/xla/tsl/tsl.bzl +++ b/third_party/xla/xla/tsl/tsl.bzl @@ -406,6 +406,8 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs register_extension_info(extension = tsl_gpu_library, label_regex_for_dep = "{extension_name}") +CollectedDepsInfo = provider("CollectedDepsInfo", fields = ["tf_collected_deps"]) + # Traverse the dependency graph along the "deps" attribute of the # target and return a struct with one field called 'tf_collected_deps'. 
# tf_collected_deps will be the union of the deps of the current target @@ -421,9 +423,9 @@ def _collect_deps_aspect_impl(target, ctx): # buildifier: disable=unused-variab all_deps += ctx.rule.attr.roots for dep in all_deps: direct.append(dep.label) - if hasattr(dep, "tf_collected_deps"): - transitive.append(dep.tf_collected_deps) - return struct(tf_collected_deps = depset(direct = direct, transitive = transitive)) + if CollectedDepsInfo in dep: + transitive.append(dep[CollectedDepsInfo].tf_collected_deps) + return CollectedDepsInfo(tf_collected_deps = depset(direct = direct, transitive = transitive)) collect_deps_aspect = aspect( attr_aspects = ["deps", "data", "roots"], @@ -442,9 +444,9 @@ def _check_deps_impl(ctx): required_deps = ctx.attr.required_deps disallowed_deps = ctx.attr.disallowed_deps for input_dep in ctx.attr.deps: - if not hasattr(input_dep, "tf_collected_deps"): + if CollectedDepsInfo not in input_dep: continue - collected_deps = sets.make(input_dep.tf_collected_deps.to_list()) + collected_deps = sets.make(input_dep[CollectedDepsInfo].tf_collected_deps.to_list()) for disallowed_dep in disallowed_deps: if sets.contains(collected_deps, disallowed_dep.label): fail( From abf78621aa69e04bfff25ec1590015673a093589 Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Tue, 8 Apr 2025 13:35:19 -0700 Subject: [PATCH 0386/1324] Replace `string const&` with `const string&` PiperOrigin-RevId: 745272373 --- third_party/xla/xla/hlo/tools/hlo_translate.cc | 2 +- .../hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc | 2 +- third_party/xla/xla/pjrt/local_device_state.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/tools/hlo_translate.cc b/third_party/xla/xla/hlo/tools/hlo_translate.cc index f0dddd06638843..eab8ff6f5abbf3 100644 --- a/third_party/xla/xla/hlo/tools/hlo_translate.cc +++ b/third_party/xla/xla/hlo/tools/hlo_translate.cc @@ -113,7 +113,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GetModuleFromHLOText( } absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GetModuleFromHLOProto( - std::string const& content, mlir::MLIRContext* context) { + const std::string& content, mlir::MLIRContext* context) { xla::HloProto hlo_proto; if (!LoadHloProto(content, &hlo_proto)) return absl::InvalidArgumentError(kLoadHloError); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 372970d2b47e25..ec7e331d45a4ff 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -355,7 +355,7 @@ Attribute encodePrecisionConfig(ArrayAttr precisionConfigAttr) { template <typename FailedToConvertTy> LogicalResult notifyConversionFailure(ConversionPatternRewriter& rewriter, Operation* op, - std::string const& errorMessage, + const std::string& errorMessage, FailedToConvertTy ty) { return rewriter.notifyMatchFailure( op, [=](Diagnostic& diag) { diag << errorMessage << ": " << ty; }); diff --git a/third_party/xla/xla/pjrt/local_device_state.cc b/third_party/xla/xla/pjrt/local_device_state.cc index 881ff4613dac4c..5996c8d25f74d6 100644 --- a/third_party/xla/xla/pjrt/local_device_state.cc +++ b/third_party/xla/xla/pjrt/local_device_state.cc @@ -78,7 +78,7 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, int num_device_to_device_streams = stream_options.has_value() ?
stream_options->num_device_to_device_streams : kNumDeviceToDeviceStreams; - auto create_stream = [executor, &stream_options](std::string const& name) { + auto create_stream = [executor, &stream_options](const std::string& name) { std::unique_ptr<se::Stream> stream; if (stream_options.has_value()) { stream = executor->CreateStream(stream_options->priority).value(); From 964bd66a3b13f91d51d9671d3f5692d0edc929af Mon Sep 17 00:00:00 2001 From: jparkerh Date: Tue, 8 Apr 2025 13:50:17 -0700 Subject: [PATCH 0387/1324] fix includes in se_gpu compilation tests add missing includes and remove unused includes in the se_gpu compilation testing. breaking this out of another larger cl to limit review burden. PiperOrigin-RevId: 745278022 --- third_party/xla/xla/pjrt/gpu/BUILD | 8 +++----- .../xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc | 2 +- third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 7c418f7595acb4..d5db0f22fdbfe8 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -413,6 +413,7 @@ xla_test( ":gpu_topology", ":se_gpu_pjrt_client", ":se_gpu_pjrt_compiler_impl", + ":se_gpu_topology_description", "//xla:literal", "//xla:literal_util", "//xla/hlo/builder:xla_computation", @@ -423,8 +424,8 @@ xla_test( "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", - "//xla/service:platform_util", "//xla/tests:literal_test_util", + "//xla/tsl/platform:status_matchers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -433,7 +434,6 @@ xla_test( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], ) @@ -456,8 +456,8 @@ xla_test( "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/service:compiler", "//xla/tests:literal_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", @@ -466,8 +466,6 @@ xla_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc index 8c0b9bc6d3a182..f18c509108498c 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc @@ -40,8 +40,8 @@ limitations under the License. #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" #include "xla/service/compiler.h" #include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/platform/casts.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc index a7a56b618be668..d7f27f7e6b0c95 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc @@ -36,13 +36,13 @@ limitations under the License.
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/pjrt/gpu/gpu_topology.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "xla/pjrt/gpu/se_gpu_topology_description.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" #include "xla/tests/literal_test_util.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/status_matchers.h" namespace xla { namespace { From 23b9e6410a628b7fda98af11b7e0aee5a6f24b4a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 8 Apr 2025 13:51:19 -0700 Subject: [PATCH 0388/1324] [xla] Rename VariantVisitor to Overload and move to xla/service Prepare for eventual absl::Overload migration and allow using Overload in other backends. PiperOrigin-RevId: 745278444 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 2 +- .../gpu/codegen/triton/support_legacy.cc | 54 +++++++++---------- .../xla/xla/backends/gpu/runtime/BUILD | 2 +- .../backends/gpu/runtime/conditional_thunk.cc | 17 +++--- third_party/xla/xla/service/BUILD | 5 ++ third_party/xla/xla/service/gpu/BUILD | 9 +--- .../xla/xla/service/gpu/autotuning/BUILD | 4 +- .../gpu/autotuning/gemm_algorithm_picker.cc | 24 ++++----- .../autotuning/gemm_algorithm_picker_test.cc | 4 +- .../gpu/cublas_padding_requirements.cc | 4 +- .../xla/xla/service/gpu/float_support_test.cc | 10 ++-- third_party/xla/xla/service/gpu/model/BUILD | 2 +- .../service/gpu/model/interpolator_test.cc | 17 +++--- .../xla/xla/service/gpu/transforms/BUILD | 2 +- .../transforms/command_buffer_scheduling.cc | 4 +- .../{gpu/variant_visitor.h => overload.h} | 16 +++--- .../xla/xla/tools/matmul_perf_table_gen.cc | 6 +-- 17 files changed, 89 insertions(+), 93 deletions(-) rename third_party/xla/xla/service/{gpu/variant_visitor.h => overload.h} (78%) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index fc18849ddd810b..b1a34cfbeb91ff 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -899,10 +899,10 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:algorithm_util", "//xla/service:instruction_fusion", + "//xla/service:overload", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_indexing_utils", - "//xla/service/gpu:variant_visitor", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:errors", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc index c503d18c215671..f6d4d4ee15c0f7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc @@ -30,7 +30,7 @@ limitations under the License. 
#include "xla/layout.h" #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/variant_visitor.h" +#include "xla/service/overload.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" #include "tsl/platform/tensor_float_32_utils.h" @@ -67,38 +67,34 @@ bool IsTritonSupportedDotOutputType( case F32: return true; case F8E5M2: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return cc.IsAtLeastAmpere(); - }, - [](const se::RocmComputeCapability& cc) { - return false; - }}, - gpu_version); + return std::visit( + Overload{[](const se::CudaComputeCapability& cc) { + return cc.IsAtLeastAmpere(); + }, + [](const se::RocmComputeCapability& cc) { return false; }}, + gpu_version); case F8E4M3FN: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return cc.IsAtLeastHopper(); - }, - [](const se::RocmComputeCapability& cc) { - return false; - }}, - gpu_version); + return std::visit( + Overload{[](const se::CudaComputeCapability& cc) { + return cc.IsAtLeastHopper(); + }, + [](const se::RocmComputeCapability& cc) { return false; }}, + gpu_version); case BF16: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return true; - }, - [](const se::RocmComputeCapability& cc) { - return cc.has_bf16_dtype_support(); - }}, - gpu_version); + return std::visit( + Overload{[](const se::CudaComputeCapability& cc) { return true; }, + [](const se::RocmComputeCapability& cc) { + return cc.has_bf16_dtype_support(); + }}, + gpu_version); case S32: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return cc.IsAtLeastAmpere(); - }, - [](const se::RocmComputeCapability& cc) { - return false; - }}, - gpu_version); + return std::visit( + Overload{[](const se::CudaComputeCapability& cc) { + return cc.IsAtLeastAmpere(); + }, + [](const se::RocmComputeCapability& cc) { return false; }}, + gpu_version); default: return false; } diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index c766e7a7bb45d3..372b2eaf0b59f6 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -418,7 +418,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:buffer_assignment", - "//xla/service/gpu:variant_visitor", + "//xla/service:overload", "//xla/stream_executor:device_memory", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:stream_executor_h", diff --git a/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.cc index e85708f63b1e82..d5e6301eff4cc0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.cc @@ -27,7 +27,7 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/variant_visitor.h" +#include "xla/service/overload.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/memory_allocation.h" @@ -120,16 +120,15 @@ absl::Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { &stream, blocked.message()); } - int32_t branch_index = std::visit( - VariantVisitor{[](int32_t* branch_index) { return *branch_index; }, - [](bool* pred) { return *pred ? 0 : 1; }}, - branch_index_or_pred); - - absl::string_view branch_kind = - std::visit(VariantVisitor{[](int32_t*) { return "index"; }, - [](bool*) { return "pred"; }}, + int32_t branch_index = + std::visit(Overload{[](int32_t* branch_index) { return *branch_index; }, + [](bool* pred) { return *pred ? 0 : 1; }}, branch_index_or_pred); + absl::string_view branch_kind = std::visit( + Overload{[](int32_t*) { return "index"; }, [](bool*) { return "pred"; }}, + branch_index_or_pred); + VLOG(3) << "ConditionalThunk: branch_index=" << branch_index << " (kind: " << branch_kind << ")"; diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 68b110d741389c..2d01ea4d6fa595 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -574,6 +574,11 @@ cc_library( ]), ) +cc_library( + name = "overload", + hdrs = ["overload.h"], +) + xla_cc_test( name = "dump_test", srcs = ["dump_test.cc"], diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 373e98aff4d515..5dec0050d0be87 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -746,11 +746,6 @@ gpu_kernel_library( ]), ) -cc_library( - name = "variant_visitor", - hdrs = ["variant_visitor.h"], -) - build_cub_sort_kernels( name = "cub_sort_kernel", srcs = if_gpu_is_configured(["cub_sort_kernel.cu.cc"]), @@ -1150,10 +1145,10 @@ cc_library( srcs = ["cublas_padding_requirements.cc"], hdrs = ["cublas_padding_requirements.h"], deps = [ - ":variant_visitor", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service:overload", "//xla/stream_executor:device_description", ], ) @@ -2581,9 +2576,9 @@ xla_test( "gpu_b200", ], deps = [ - ":variant_visitor", "//xla:error_spec", "//xla:xla_proto_cc", + "//xla/service:overload", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 47a9c7c02d8ffd..58433609cb3361 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -305,11 +305,11 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_module_config", + "//xla/service:overload", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:stream_executor_util", - "//xla/service/gpu:variant_visitor", "//xla/stream_executor:blas", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", @@ -452,9 +452,9 @@ xla_test( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/service:overload", "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:variant_visitor", 
"//xla/service/gpu/transforms:gemm_rewriter", "//xla/stream_executor:device_description", "//xla/stream_executor:platform", diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc index 90087127341c81..c323def8b13a2b 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc @@ -43,7 +43,7 @@ limitations under the License. #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/gpu/variant_visitor.h" +#include "xla/service/overload.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/blas.h" @@ -437,17 +437,17 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, auto old_algorithm = backend_config.selected_algorithm(); bool update_algorithm = IsCublasLtMatmulF8(*gemm) || - std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - // We only set the 'algorithm' field on - // non-Ampere architectures, as for Ampere - // it's ignored in any case. - return !cc.IsAtLeast( - se::CudaComputeCapability::kAmpere); - }, - [](const se::RocmComputeCapability&) { - return true; // TODO: not decided yet - }}, - config.GetGpuComputeCapability()); + std::visit( + Overload{[](const se::CudaComputeCapability& cc) { + // We only set the 'algorithm' field on + // non-Ampere architectures, as for Ampere + // it's ignored in any case. + return !cc.IsAtLeast(se::CudaComputeCapability::kAmpere); + }, + [](const se::RocmComputeCapability&) { + return true; // TODO: not decided yet + }}, + config.GetGpuComputeCapability()); if (update_algorithm) { int64_t new_algorithm{}; diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc index ba8f0dc7d59e8c..003abe52bddf5d 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" -#include "xla/service/gpu/variant_visitor.h" +#include "xla/service/overload.h" #include "xla/service/pattern_matcher.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" @@ -74,7 +74,7 @@ class GemmAlgorithmPickerTest : public HloTestBase, bool blas_get_version = name.rfind("BlasGetVersion") == 0; std::visit( - VariantVisitor{ + Overload{ [&](const se::CudaComputeCapability& cc) { if (!blas_get_version && cc.IsAtLeastAmpere()) { GTEST_SKIP() diff --git a/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc b/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc index 7227f94c4bb72e..3e606f93cff619 100644 --- a/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc +++ b/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/variant_visitor.h" +#include "xla/service/overload.h" #include "xla/shape.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" @@ -33,7 +33,7 @@ namespace { bool DimensionRequiresPadding(const int64_t size, const PrimitiveType data_type, const se::GpuComputeCapability& gpu_cc) { return std::visit( - VariantVisitor{ + Overload{ [&](const se::CudaComputeCapability& cc) { for (const auto& req : CublasPaddingRequirements) { if (cc.IsAtLeast(req.min_compute_capability) && diff --git a/third_party/xla/xla/service/gpu/float_support_test.cc b/third_party/xla/xla/service/gpu/float_support_test.cc index 79f343f90c1a4c..fea048a09159ed 100644 --- a/third_party/xla/xla/service/gpu/float_support_test.cc +++ b/third_party/xla/xla/service/gpu/float_support_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "xla/error_spec.h" -#include "xla/service/gpu/variant_visitor.h" +#include "xla/service/overload.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" @@ -75,10 +75,10 @@ ENTRY e { TEST_F(FloatSupportTestWithTriton, MixedTypeDotWithBF16IsNotUpcasted) { bool skip_test = std::visit( - VariantVisitor{[](const se::CudaComputeCapability& cc) { - return !cc.IsAtLeast(se::CudaComputeCapability::kAmpere); - }, - [](const se::RocmComputeCapability&) { return true; }}, + Overload{[](const se::CudaComputeCapability& cc) { + return !cc.IsAtLeast(se::CudaComputeCapability::kAmpere); + }, + [](const se::RocmComputeCapability&) { return true; }}, GetGpuComputeCapability()); if (skip_test) { diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index acbabc297c174a..6199e65a54afdf 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -1056,7 +1056,7 @@ xla_cc_test( srcs = ["interpolator_test.cc"], deps = [ ":interpolator", - "//xla/service/gpu:variant_visitor", + "//xla/service:overload", "@com_google_absl//absl/log", "@com_google_googletest//:gtest_main", ], diff --git a/third_party/xla/xla/service/gpu/model/interpolator_test.cc b/third_party/xla/xla/service/gpu/model/interpolator_test.cc index b5289b2412aa25..0dbcb2bce16c02 100644 --- a/third_party/xla/xla/service/gpu/model/interpolator_test.cc +++ b/third_party/xla/xla/service/gpu/model/interpolator_test.cc @@ -27,7 +27,7 @@ limitations under the License. 
#include #include "absl/log/log.h" -#include "xla/service/gpu/variant_visitor.h" +#include "xla/service/overload.h" namespace xla::gpu { namespace { @@ -119,17 +119,16 @@ TEST_P(EuclideanNN2DInterpolatorTest, ReturnsNearestNeighbour) { std::array plane_point = point.first; int val = point.second; std::visit( - VariantVisitor{ - [&](const std::unique_ptr>& - nn) { return nn->Add(plane_point, val); }, - [&](const std::unique_ptr< - EuclideanComplementInterpolator>& comp) { - return comp->Add(plane_point, val); - }}, + Overload{[&](const std::unique_ptr>& + nn) { return nn->Add(plane_point, val); }, + [&](const std::unique_ptr< + EuclideanComplementInterpolator>& comp) { + return comp->Add(plane_point, val); + }}, interpolator); } std::visit( - VariantVisitor{ + Overload{ [&](const std::unique_ptr>& nn) { EXPECT_EQ(nn->Eval(param.eval_point), param.expected_value); }, diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index e21bfedccb6dec..f93c98fa6f03f4 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -698,11 +698,11 @@ cc_library( "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service:overload", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:variant_visitor", "//xla/stream_executor:device_description", "//xla/stream_executor:semantic_version", "@com_google_absl//absl/algorithm:container", diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index 216af6041b416e..22dfa3823b082e 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -44,7 +44,7 @@ limitations under the License. #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/variant_visitor.h" +#include "xla/service/overload.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -880,7 +880,7 @@ absl::StatusOr CommandBufferScheduling::Run( erase(kRequireConditionals); // on-device control flow }; - std::visit(VariantVisitor{erase_cuda, erase_rocm}, + std::visit(Overload{erase_cuda, erase_rocm}, device_description_.gpu_compute_capability()); auto order = module->MakeComputationPostOrder(); diff --git a/third_party/xla/xla/service/gpu/variant_visitor.h b/third_party/xla/xla/service/overload.h similarity index 78% rename from third_party/xla/xla/service/gpu/variant_visitor.h rename to third_party/xla/xla/service/overload.h index c4ff4aa89b3fd5..f5f7e2cb03e152 100644 --- a/third_party/xla/xla/service/gpu/variant_visitor.h +++ b/third_party/xla/xla/service/overload.h @@ -13,22 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef XLA_SERVICE_GPU_VARIANT_VISITOR_H_ -#define XLA_SERVICE_GPU_VARIANT_VISITOR_H_ +#ifndef XLA_SERVICE_OVERLOAD_H_ +#define XLA_SERVICE_OVERLOAD_H_ + +namespace xla { -namespace xla::gpu { // This structure is used to support C++17 overload pattern as described in // https://en.cppreference.com/w/cpp/utility/variant/visit // // TODO(b/319202112): Replace with absl::Overload once abs lts_2024_XXX is // tagged. template <typename... Ts> -struct VariantVisitor : Ts... { +struct Overload : Ts... { using Ts::operator()...; }; + template <typename... Ts> -VariantVisitor(Ts...) -> VariantVisitor<Ts...>; +Overload(Ts...) -> Overload<Ts...>; -} // namespace xla::gpu +} // namespace xla -#endif // XLA_SERVICE_GPU_VARIANT_VISITOR_H_ +#endif // XLA_SERVICE_OVERLOAD_H_ diff --git a/third_party/xla/xla/tools/matmul_perf_table_gen.cc b/third_party/xla/xla/tools/matmul_perf_table_gen.cc index 787ceecd31337c..588dd10b989f02 100644 --- a/third_party/xla/xla/tools/matmul_perf_table_gen.cc +++ b/third_party/xla/xla/tools/matmul_perf_table_gen.cc @@ -69,11 +69,11 @@ namespace { constexpr size_t kNumProfilingRuns = 5; template <typename... Ts> -struct VariantVisitor : Ts... { +struct Overload : Ts... { using Ts::operator()...; }; template <typename... Ts> -VariantVisitor(Ts...) -> VariantVisitor<Ts...>; +Overload(Ts...) -> Overload<Ts...>; struct StaticSpec { int b; @@ -252,7 +252,7 @@ std::vector GetExplicitSpecs( for (int i = 0; i < entry_specs.size(); i++) { const EntrySpec& entry_spec = entry_specs[i]; std::visit( - VariantVisitor{ + Overload{ [&specs](const PathSpec& spec) { std::string hlo; CHECK_OK(tsl::ReadFileToString(tsl::Env::Default(), spec.filepath, From 94db88a64eb56d1432a3dea0aacb75801aab82d3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 13:51:52 -0700 Subject: [PATCH 0389/1324] Update regex for TF version to accept strings like `2.20.0-dev0+selfbuilt` PiperOrigin-RevId: 745278722 --- tensorflow/lite/objc/tests/TFLInterpreterTests.m | 15 +++++++-------- .../lite/swift/Tests/TensorFlowLiteTests.swift | 4 ++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/objc/tests/TFLInterpreterTests.m b/tensorflow/lite/objc/tests/TFLInterpreterTests.m index 45c8280af039ce..c0963d1d5a3a08 100644 --- a/tensorflow/lite/objc/tests/TFLInterpreterTests.m +++ b/tensorflow/lite/objc/tests/TFLInterpreterTests.m @@ -20,9 +20,9 @@ /** * Regular expression for TensorFlow Lite runtime version string, e.g. "1.14.0", "0.1.2-alpha.1", - * "0.3.4-beta2", "1.14.0-rc.3". + * "0.3.4-beta2", "1.14.0-rc.3", "2.20.0-dev0+selfbuilt". */ -static NSString *const kTFLVersionRegex = @"^\\d+\\.\\d+\\.\\d+(-[a-zA-Z0-9.-]+)?$"; +static NSString *const kTFLVersionRegex = @"^\\d+\\.\\d+\\.\\d+(-[a-zA-Z0-9.-]+)?(\\+\\w+)?$"; /** Float model resource name. */ static NSString *const kAddFloatModelResourceName = @"add"; @@ -89,12 +89,11 @@ - (void)tearDown { #pragma mark - Tests -// Disable below test for now to validate TFLite ios tests. -// - (void)testTFLVersion { -// NSLog(@"TFLVersion: %@", TFLVersion); -// NSRange range = [TFLVersion rangeOfString:kTFLVersionRegex options:NSRegularExpressionSearch]; -// XCTAssertNotEqual(range.location, NSNotFound); -// } +- (void)testTFLVersion { + NSLog(@"TFLVersion: %@", TFLVersion); + NSRange range = [TFLVersion rangeOfString:kTFLVersionRegex options:NSRegularExpressionSearch]; + XCTAssertNotEqual(range.location, NSNotFound); +} - (void)testSuccessfulFullRunAddFloatModel { // Shape for both input and output tensor.
diff --git a/tensorflow/lite/swift/Tests/TensorFlowLiteTests.swift b/tensorflow/lite/swift/Tests/TensorFlowLiteTests.swift index f0b2302f722da4..cdb43b5328fec6 100644 --- a/tensorflow/lite/swift/Tests/TensorFlowLiteTests.swift +++ b/tensorflow/lite/swift/Tests/TensorFlowLiteTests.swift @@ -20,9 +20,9 @@ class TensorFlowLiteTests: XCTestCase { func testRuntime_Version() { #if swift(>=5.0) - let pattern = #"^(\d+)\.(\d+)\.(\d+)([+-][-.0-9A-Za-z]+)?$"# + let pattern = #"^(\d+)\.(\d+)\.(\d+)([+-][-.0-9A-Za-z]+)?(\+\w+)?$"# #else - let pattern = "^(\\d+)\\.(\\d+)\\.(\\d+)([+-][-.0-9A-Za-z]+)?$" + let pattern = "^(\\d+)\\.(\\d+)\\.(\\d+)([+-][-.0-9A-Za-z]+)?(\\+\\w+)?$" #endif // swift(>=5.0) XCTAssertNotNil(TensorFlowLite.Runtime.version.range(of: pattern, options: .regularExpression)) } From a172d908bbb55790ea7fc08ab171fdbd1be152a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 13:56:38 -0700 Subject: [PATCH 0390/1324] Make `safe_reinterpret_cast` support `__restrict` pointers, as some XLA code uses them. PiperOrigin-RevId: 745280665 --- .../xla/xla/tsl/util/safe_reinterpret_cast.h | 14 ++++++++++ .../tsl/util/safe_reinterpret_cast_test.cc | 28 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h index 67089a8100deae..085f7014e459cd 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h @@ -75,6 +75,20 @@ struct IsSafeCast IsCvByteLike::value || std::is_same_v> {}; +// If __restrict is a macro, we assume that the compiler doesn't support +// the __restrict keyword (e.g. when the code is compiled for iOS). Otherwise, +// we make safe_reinterpret_cast ignore the __restrict qualifier. +#ifndef __restrict // If __restrict is not a macro. + +template <typename From, typename To> +struct IsSafeCast<From* __restrict, To> : IsSafeCast<From*, To> {}; +template <typename From, typename To> +struct IsSafeCast<From, To* __restrict> : IsSafeCast<From, To*> {}; +template <typename From, typename To> +struct IsSafeCast<From* __restrict, To* __restrict> : IsSafeCast<From*, To*> {}; + +#endif // __restrict + // It's safe to cast a pointer to/from std::uintptr_t. template struct IsSafeCast : std::true_type {}; diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc index 6deecb32681800..881cd73b0f4c11 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc @@ -111,5 +111,33 @@ TEST(SafeReinterpretCast, CanCastPointerToFromSameType) { EXPECT_EQ(char_p, &y); } +TEST(SafeReinterpretCast, CanCastPointerToRestrictPointer) { + const int x = 42; + const char* __restrict const char_p = + safe_reinterpret_cast<const char* __restrict>(&x); + EXPECT_EQ(char_p, // + reinterpret_cast<const char*>( // REINTERPRET_CAST_OK=for testing. + &x)); +} + +TEST(SafeReinterpretCast, CanCastRestrictPointerToPointer) { + const int x = 42; + const int* __restrict const int_p = &x; + const char* const char_p = safe_reinterpret_cast<const char*>(int_p); + EXPECT_EQ(char_p, // + reinterpret_cast<const char*>( // REINTERPRET_CAST_OK=for testing. + &x)); +} + +TEST(SafeReinterpretCast, CanCastRestrictPointerToRestrictPointer) { + const int x = 42; + const int* __restrict const int_p = &x; + const char* __restrict const char_p = + safe_reinterpret_cast<const char* __restrict>(int_p); + EXPECT_EQ(char_p, // + reinterpret_cast<const char*>( // REINTERPRET_CAST_OK=for testing.
+                &x));
+}
+
 }  // namespace
 }  // namespace tsl

From 0ff0b5020f778872e13aeb0c5a7ccae62b36149f Mon Sep 17 00:00:00 2001
From: Charles Alaras
Date: Tue, 8 Apr 2025 14:01:36 -0700
Subject: [PATCH 0391/1324] Add comparator to `GetSortedEvents`

PiperOrigin-RevId: 745282474
---
 third_party/xla/xla/tsl/profiler/utils/BUILD            | 1 -
 third_party/xla/xla/tsl/profiler/utils/timespan_test.cc | 6 +++++-
 third_party/xla/xla/tsl/profiler/utils/xplane_utils.h   | 9 +++++++--
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/third_party/xla/xla/tsl/profiler/utils/BUILD b/third_party/xla/xla/tsl/profiler/utils/BUILD
index 3fbffc13ce1dde..8ebac922e660c0 100644
--- a/third_party/xla/xla/tsl/profiler/utils/BUILD
+++ b/third_party/xla/xla/tsl/profiler/utils/BUILD
@@ -221,7 +221,6 @@ cc_library(
         ":xplane_visitor",
         "//xla/tsl/platform:types",
         "//xla/tsl/util:stats_calculator_portable",
-        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
diff --git a/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc b/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc
index 986e91740c78ac..608c63f522b458 100644
--- a/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc
+++ b/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc
@@ -82,10 +82,14 @@ TEST(TimespanTests, InstantSpanNonInstantSpanOverlappedDuration) {
 
 TEST(TimespanTests, Operators) {
   EXPECT_LT(Timespan(11, 0), Timespan(12, 0));
-  EXPECT_LT(Timespan(12, 1), Timespan(12, 0));
+  // Instants nest within larger timespans.
+  EXPECT_LT(Timespan(12, 1), Timespan(12, 0));
   EXPECT_FALSE(Timespan(12, 0) < Timespan(12, 1));
+  EXPECT_FALSE(Timespan(12, 0) < Timespan(11, 0));
+
+  // Instants with the same beginning are considered equivalent.
   EXPECT_FALSE(Timespan(12, 0) < Timespan(12, 0));
 
   EXPECT_FALSE(Timespan(12, 0) == Timespan(12, 1));
diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h
index ef93de69236290..374d87111f6e2b 100644
--- a/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h
+++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h
@@ -21,7 +21,6 @@ limitations under the License.
 #include
 #include
 
-#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
@@ -143,7 +142,13 @@ std::vector<Event> GetSortedEvents(Plane& plane,
     line.ForEachEvent(
         [&events](auto event) { events.emplace_back(std::move(event)); });
   });
-  absl::c_sort(events);
+  std::sort(events.begin(), events.end(), [](const Event& a, const Event& b) {
+    const tsl::profiler::Timespan a_span = a.GetTimespan();
+    const tsl::profiler::Timespan b_span = b.GetTimespan();
+    if (a_span.begin_ps() < b_span.begin_ps()) return true;
+    if (a_span.begin_ps() > b_span.begin_ps()) return false;
+    return a_span.duration_ps() < b_span.duration_ps();
+  });
   return events;
 }

From 34b1aebe7fbcd3350d9b7e7dcf361016b03dc0cd Mon Sep 17 00:00:00 2001
From: Daniel Chen
Date: Tue, 8 Apr 2025 14:04:07 -0700
Subject: [PATCH 0392/1324] Add copy button for full HLO instruction text
 format in the HTML output.
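The copied instruction text is embedded in an HTML attribute, so it must be
entity-escaped first. The snippet below is a minimal sketch of that
requirement (illustrative only; EscapeForAttr is a hypothetical stand-in for
the EscapeStringForHtmlAttribute helper added in this change):

#include <cassert>
#include <string>

// Hypothetical stand-in for EscapeStringForHtmlAttribute; the five
// replacements mirror the standard XML entities used in the change below.
std::string EscapeForAttr(const std::string& s) {
  std::string out;
  for (char c : s) {
    switch (c) {
      case '&':  out += "&amp;";  break;
      case '<':  out += "&lt;";   break;
      case '>':  out += "&gt;";   break;
      case '"':  out += "&quot;"; break;
      case '\'': out += "&#39;";  break;
      default:   out += c;        break;
    }
  }
  return out;
}

int main() {
  // HLO text routinely contains quotes and angle brackets, so an unescaped
  // attribute value would terminate the surrounding tag early.
  assert(EscapeForAttr("a<b & c>\"d\"") == "a&lt;b &amp; c&gt;&quot;d&quot;");
  return 0;
}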
PiperOrigin-RevId: 745283477
---
 .../render/hlo_gumgraph_html_renderer.cc      | 189 ++++++++++++------
 1 file changed, 128 insertions(+), 61 deletions(-)

diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc
index 95500a2c3718de..f1e27bea457db2 100644
--- a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc
@@ -45,7 +45,7 @@ namespace {
 
 // Prints the CSS styles for the HTML output.
 std::string PrintCss() {
-  return R"(
+  return R"html(
-  )";
+  )html";
 }
 
+// Prints JavaScript for the HTML output.
+std::string PrintJavascript() {
+  return R"html(
+
+  )html";
+}
+
+// Escapes the string for an html attribute.
+std::string EscapeStringForHtmlAttribute(absl::string_view str) {
+  std::string escaped_str;
+  for (char c : str) {
+    switch (c) {
+      case '&':
+        absl::StrAppend(&escaped_str, "&amp;");
+        break;
+      case '<':
+        absl::StrAppend(&escaped_str, "&lt;");
+        break;
+      case '>':
+        absl::StrAppend(&escaped_str, "&gt;");
+        break;
+      case '"':
+        absl::StrAppend(&escaped_str, "&quot;");
+        break;
+      case '\'':
+        absl::StrAppend(&escaped_str, "&#39;");
+        break;
+      default:
+        absl::StrAppend(&escaped_str, absl::string_view(&c, 1));
+        break;
+    }
+  }
+  return escaped_str;
+}
 
 // Prints the div html block.
 std::string PrintDiv(absl::string_view content, absl::string_view class_name) {
-  return absl::StrFormat("<div class=\"%s\">%s</div>", class_name, content);
+  return absl::StrFormat(R"html(<div class="%s">%s</div>)html", class_name,
+                         content);
 }
 
 // Prints the detail html block.
 std::string PrintDetails(absl::string_view summary,
                          absl::string_view content) {
-  return absl::StrFormat(R"(<details><summary>%s</summary>%s</details>)",
-                         summary, PrintDiv(content, "content"));
+  return absl::StrFormat(
+      R"html(<details><summary>%s</summary>%s</details>)html", summary,
+      PrintDiv(content, "content"));
+}
+
+// Prints a link to the given url.
+std::string PrintLink(absl::string_view text, absl::string_view url) {
+  return absl::StrFormat("<a href=\"%s\">%s</a>", url, text);
 }
 
 // Prints a html block with a header.
@@ -168,51 +225,82 @@ std::string PrintAttributesList(absl::Span<const std::string> items) {
                    "attributes-list");
 }
 
-// Prints a link to the given url.
-std::string PrintLink(absl::string_view text, absl::string_view url) {
-  return absl::StrFormat("<a href=\"%s\">%s</a>", url, text);
-}
-
 // Prints a span with a tooltip.
 std::string PrintTooltip(absl::string_view text,
                          absl::string_view tooltip_text) {
   return absl::StrFormat(
-      R"(<span class="tooltip">%s<span class="tooltiptext">%s</span></span>)",
+      R"html(<span class="tooltip">%s<span class="tooltiptext">%s</span></span>)html",
       text, tooltip_text);
 }
 
+// Prints a click-to-copy button.
+std::string PrintClickToCopyButton(absl::string_view text,
+                                   absl::string_view content) {
+  return absl::StrFormat(
+      R"html(%s)html",
+      EscapeStringForHtmlAttribute(content),
+      PrintTooltip(text, "Click to copy"));
+}
+
 /*** Summary logic ***/
 
-// Prints a link to the instruction in model explorer if url_generator is not
-// null, otherwise returns the text directly.
-std::string PrintInstructionLink(const HloInstruction* left_inst,
+// Prints the instruction name and a click-to-copy button that copies the
+// text format.
+std::string PrintInstruction(const HloInstruction& inst) {
+  return absl::StrFormat("%s (%s)", inst.name(),
+                         PrintClickToCopyButton("text", inst.ToString()));
+}
+
+// Prints a pair of instructions. If url_generator is not null, a link to the
+// pair of instructions in model explorer will be printed.
+std::string PrintInstructionPair(const HloInstruction* left_inst,
                                  const HloInstruction* right_inst,
-                                 absl::string_view text,
                                  GraphUrlGenerator* url_generator) {
+  std::vector<std::string> instructions;
+  if (left_inst != nullptr) {
+    instructions.push_back(PrintInstruction(*left_inst));
+  }
+  if (right_inst != nullptr) {
+    instructions.push_back(PrintInstruction(*right_inst));
+  }
+  std::string text = absl::StrJoin(instructions, " ↔ ");
   if (url_generator == nullptr) {
-    return std::string(text);
+    return text;
   }
   std::string url = url_generator->Generate(left_inst, right_inst);
   if (url.empty()) {
-    return std::string(text);
+    return text;
   }
-  return PrintLink(text, url);
+  return absl::StrCat(text, " (", PrintLink("Model Explorer", url), ")");
+}
+
+// Prints the computation name and a click-to-copy button that copies the
+// text format.
+std::string PrintComputation(const HloComputation& comp) {
+  return absl::StrFormat("%s (%s)", comp.name(),
+                         PrintClickToCopyButton("text", comp.ToString()));
 }
 
-// Prints a link to the computation in model explorer if url_generator is not
-// null, otherwise returns the text directly.
-std::string PrintComputationLink(const HloComputation* left_comp,
+// Prints a pair of computations. If url_generator is not null, a link to the
+// pair of computations in model explorer will be printed.
+std::string PrintComputationPair(const HloComputation* left_comp,
                                  const HloComputation* right_comp,
-                                 absl::string_view text,
                                  GraphUrlGenerator* url_generator) {
+  std::vector<std::string> computations;
+  if (left_comp != nullptr) {
+    computations.push_back(PrintComputation(*left_comp));
+  }
+  if (right_comp != nullptr) {
+    computations.push_back(PrintComputation(*right_comp));
+  }
+  std::string text = absl::StrJoin(computations, " ↔ ");
   if (url_generator == nullptr) {
-    return std::string(text);
+    return text;
   }
-  std::string maybe_url = url_generator->Generate(left_comp, right_comp);
-  if (maybe_url.empty()) {
-    return std::string(text);
+  std::string url = url_generator->Generate(left_comp, right_comp);
+  if (url.empty()) {
+    return text;
   }
-  return PrintLink(text, maybe_url);
+  return absl::StrCat(text, " (", PrintLink("Model Explorer", url), ")");
 }
 
 // The location of the instruction in the diff result.
@@ -227,13 +315,9 @@ std::string PrintInstructionsAsList(
   for (const HloInstruction* inst : instructions) {
     std::string link;
     if (location == InstructionLocation::kLeft) {
-      link = PrintInstructionLink(inst, /*right_inst=*/nullptr,
-                                  InstructionToString(inst, name_only),
-                                  url_generator);
+      link = PrintInstructionPair(inst, /*right_inst=*/nullptr, url_generator);
     } else {
-      link = PrintInstructionLink(/*left_inst=*/nullptr, inst,
-                                  InstructionToString(inst, name_only),
-                                  url_generator);
+      link = PrintInstructionPair(/*left_inst=*/nullptr, inst, url_generator);
     }
     instructions_list.push_back(link);
   }
@@ -375,12 +459,7 @@ std::string PrintChangedInstructions(
       GetChangedInstructionDiffTypes(*left_inst, *right_inst);
   return absl::StrFormat(
       "%s have changed: %s",
-      PrintInstructionLink(
-          left_inst, right_inst,
-          absl::StrFormat(
-              "%s and %s", InstructionToString(left_inst, /*name_only=*/true),
-              InstructionToString(right_inst, /*name_only=*/true)),
-          url_generator),
+      PrintInstructionPair(left_inst, right_inst, url_generator),
       absl::StrJoin(
           diff_types, ", ",
           [&left_inst, &right_inst](std::string* out, const auto& diff_type) {
@@ -408,12 +487,7 @@ std::string PrintUnchangedInstructions(
     GraphUrlGenerator* url_generator) {
   auto simple_printer = [&url_generator](const HloInstruction* left_inst,
                                          const HloInstruction* right_inst) {
-    return PrintInstructionLink(
-        left_inst, right_inst,
-        absl::StrFormat("%s and %s",
-                        InstructionToString(left_inst, /*name_only=*/true),
-                        InstructionToString(right_inst, /*name_only=*/true)),
-        url_generator);
+    return PrintInstructionPair(left_inst, right_inst, url_generator);
   };
   return PrintInstructionPairsByOpcode(instructions, opcodes_to_ignore,
                                        simple_printer);
@@ -432,11 +506,10 @@ std::string PrintUnmatchedMetricsDiff(
   std::sort(sorted_metrics_diff.begin(), sorted_metrics_diff.end());
   std::vector<std::string> metrics_diff_list(sorted_metrics_diff.size());
   for (const auto& [inst, metrics_diff] : sorted_metrics_diff) {
-    metrics_diff_list.push_back(
-        absl::StrFormat("%s: %.2f (us)",
-                        PrintInstructionLink(inst, /*right_inst=*/nullptr,
-                                             inst->name(), url_generator),
-                        metrics_diff / 1e6));
+    metrics_diff_list.push_back(absl::StrFormat(
+        "%s: %.2f (us)",
+        PrintInstructionPair(inst, /*right_inst=*/nullptr, url_generator),
+        metrics_diff / 1e6));
   }
   return PrintList(metrics_diff_list);
 }
@@ -464,10 +537,7 @@ std::string PrintMatchedMetricsDiff(
     const auto& [left_inst, right_inst] = inst_pair;
     metrics_diff_list.push_back(absl::StrFormat(
         "%s: %.2f (us)",
-        PrintInstructionLink(
-            left_inst, right_inst,
-            absl::StrFormat("%s and %s", left_inst->name(), right_inst->name()),
-            url_generator),
+        PrintInstructionPair(left_inst, right_inst, url_generator),
         metrics_diff / 1e6));
   }
   return PrintList(metrics_diff_list);
@@ -519,11 +589,8 @@ std::string PrintRepetitiveComputationGroups(const DiffSummary& diff_summary,
         computation_group.left_computations[0];
     const HloComputation* right_computation =
         computation_group.right_computations[0];
-    computation_pair_list.push_back(PrintComputationLink(
-        left_computation, right_computation,
-        absl::StrFormat("%s and %s", left_computation->name(),
-                        right_computation->name()),
-        url_generator));
+    computation_pair_list.push_back(PrintComputationPair(
+        left_computation, right_computation, url_generator));
   }
   absl::StrAppend(
       &computation_group_list,
@@ -550,7 +617,7 @@ void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary,
                 GetOpMetricFn right_op_metrics, std::ostringstream& out) {
   const absl::flat_hash_set<HloOpcode> ignored_opcodes(kIgnoredOpcodes.begin(),
                                                        kIgnoredOpcodes.end());
-  out << PrintCss();
+  out << PrintCss() << PrintJavascript();
 
   // Print full diff results
   out << PrintSectionWithHeader(

From c0093bfb21f0e1157e1892cc8273e10831077174 Mon Sep 17 00:00:00 2001
From: Ezekiel Calubaquib
Date: Tue, 8 Apr 2025 14:06:24 -0700
Subject: [PATCH 0393/1324] Add python 3.13 requirements and conditions for TF

The code below creates a condition to run specific Python.h code in cpp for
python3.13, as the following functions are deprecated:
_PyArg_NoKeywords (removed)
_PyObject_VisitManagedDict (renamed to PyObject_VisitManagedDict)
_PyObject_ClearManagedDict (renamed to PyObject_ClearManagedDict)

PiperOrigin-RevId: 745284443
---
 WORKSPACE                                   |   1 +
 .../requirements_updater/requirements.in    |   4 +-
 requirements_lock_3_13.txt                  | 842 ++++++++++++++++++
 tensorflow/python/eager/pywrap_tensor.cc    |  16 +
 .../tools/toolchains/python/python_repo.bzl |   2 +-
 .../tools/toolchains/python/python_repo.bzl |   2 +-
 6 files changed, 863 insertions(+), 4 deletions(-)
 create mode 100644 requirements_lock_3_13.txt

diff --git a/WORKSPACE b/WORKSPACE
index 445f974b094333..e42663c6922986 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -43,6 +43,7 @@ python_init_repositories(
         "3.10": "//:requirements_lock_3_10.txt",
         "3.11": "//:requirements_lock_3_11.txt",
         "3.12": "//:requirements_lock_3_12.txt",
+        "3.13": "//:requirements_lock_3_13.txt",
     },
 )
 
diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in
index 0cfbaf22f820b1..f63fa5ccc52934 100644
--- a/ci/official/requirements_updater/requirements.in
+++ b/ci/official/requirements_updater/requirements.in
@@ -28,7 +28,7 @@ requests >= 2.31.0
 packaging==23.2
 setuptools==70.0.0
 jax==0.4.7
-zstandard=0.23.0
+zstandard==0.23.0
 # NVIDIA CUDA dependencies
 # Note that the wheels are downloaded only when the targets in bazel command
 # contain dependencies on these wheels.
@@ -44,7 +44,7 @@ nvidia-cusparse-cu12 == 12.5.1.3
 nvidia-nccl-cu12 == 2.25.1
 nvidia-nvjitlink-cu12 == 12.5.82
 # The dependencies below are needed for TF wheel testing.
-tensorflow-io-gcs-filesystem==0.37.1 +tensorflow-io-gcs-filesystem==0.37.1 ; python_version <= "3.12" libclang >= 13.0.0 google_pasta ~= 0.2 flatbuffers ~= 24.3.25 diff --git a/requirements_lock_3_13.txt b/requirements_lock_3_13.txt new file mode 100644 index 00000000000000..a03c65b0b2486c --- /dev/null +++ b/requirements_lock_3_13.txt @@ -0,0 +1,842 @@ +# +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: +# +# bazel run //ci/official/requirements_updater:requirements.update +# +absl-py==2.2.1 \ + --hash=sha256:4c7bc50d42d021c12d4f31b7001167925e0bd71ade853069f64af410f5565ff9 \ + --hash=sha256:ca8209abd5005ae6e700ef36e2edc84ad5338678f95625a3f15275410a89ffbc + # via + # dm-tree + # keras-nightly + # tb-nightly +astor==0.7.1 \ + --hash=sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d \ + --hash=sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e + # via -r ci/official/requirements_updater/requirements.in +astunparse==1.6.3 \ + --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ + --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 + # via -r ci/official/requirements_updater/requirements.in +attrs==25.3.0 \ + --hash=sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3 \ + --hash=sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b + # via dm-tree +auditwheel==6.3.0 \ + --hash=sha256:05c70a234fa14c140aa6d9076135d9550962d95849911b8d5d0419a3add09f00 \ + --hash=sha256:31cbd8045d4ff6776f79bef328b5fd563e5ecc8ae82ea34b6fe5e76efe2a84eb + # via -r ci/official/requirements_updater/requirements.in +certifi==2025.1.31 \ + --hash=sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651 \ + --hash=sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe + # via requests +charset-normalizer==3.4.1 \ + --hash=sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537 \ + --hash=sha256:01732659ba9b5b873fc117534143e4feefecf3b2078b0a6a2e925271bb6f4cfa \ + --hash=sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a \ + --hash=sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294 \ + --hash=sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b \ + --hash=sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd \ + --hash=sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601 \ + --hash=sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd \ + --hash=sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4 \ + --hash=sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d \ + --hash=sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2 \ + --hash=sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313 \ + --hash=sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd \ + --hash=sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa \ + --hash=sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8 \ + --hash=sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1 \ + --hash=sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2 \ + --hash=sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496 \ + --hash=sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d \ + 
--hash=sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b \ + --hash=sha256:2fb9bd477fdea8684f78791a6de97a953c51831ee2981f8e4f583ff3b9d9687e \ + --hash=sha256:311f30128d7d333eebd7896965bfcfbd0065f1716ec92bd5638d7748eb6f936a \ + --hash=sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4 \ + --hash=sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca \ + --hash=sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78 \ + --hash=sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408 \ + --hash=sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5 \ + --hash=sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3 \ + --hash=sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f \ + --hash=sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a \ + --hash=sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765 \ + --hash=sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6 \ + --hash=sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146 \ + --hash=sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6 \ + --hash=sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9 \ + --hash=sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd \ + --hash=sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c \ + --hash=sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f \ + --hash=sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545 \ + --hash=sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176 \ + --hash=sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770 \ + --hash=sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824 \ + --hash=sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f \ + --hash=sha256:7974a0b5ecd505609e3b19742b60cee7aa2aa2fb3151bc917e6e2646d7667dcf \ + --hash=sha256:7a4f97a081603d2050bfaffdefa5b02a9ec823f8348a572e39032caa8404a487 \ + --hash=sha256:7b1bef6280950ee6c177b326508f86cad7ad4dff12454483b51d8b7d673a2c5d \ + --hash=sha256:7d053096f67cd1241601111b698f5cad775f97ab25d81567d3f59219b5f1adbd \ + --hash=sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b \ + --hash=sha256:807f52c1f798eef6cf26beb819eeb8819b1622ddfeef9d0977a8502d4db6d534 \ + --hash=sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f \ + --hash=sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b \ + --hash=sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9 \ + --hash=sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd \ + --hash=sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125 \ + --hash=sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9 \ + --hash=sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de \ + --hash=sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11 \ + --hash=sha256:97f68b8d6831127e4787ad15e6757232e14e12060bec17091b85eb1486b91d8d \ + --hash=sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35 \ + --hash=sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f \ + --hash=sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda \ + 
--hash=sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7 \ + --hash=sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a \ + --hash=sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971 \ + --hash=sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8 \ + --hash=sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41 \ + --hash=sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d \ + --hash=sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f \ + --hash=sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757 \ + --hash=sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a \ + --hash=sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886 \ + --hash=sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77 \ + --hash=sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76 \ + --hash=sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247 \ + --hash=sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85 \ + --hash=sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb \ + --hash=sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7 \ + --hash=sha256:dccbe65bd2f7f7ec22c4ff99ed56faa1e9f785482b9bbd7c717e26fd723a1d1e \ + --hash=sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6 \ + --hash=sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037 \ + --hash=sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1 \ + --hash=sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e \ + --hash=sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807 \ + --hash=sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407 \ + --hash=sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c \ + --hash=sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12 \ + --hash=sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3 \ + --hash=sha256:f30bf9fd9be89ecb2360c7d94a711f00c09b976258846efe40db3d05828e8089 \ + --hash=sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd \ + --hash=sha256:fc54db6c8593ef7d4b2a331b58653356cf04f67c960f584edb7c3d8c97e8f39e \ + --hash=sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00 \ + --hash=sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616 + # via requests +dill==0.3.7 \ + --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ + --hash=sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03 + # via -r ci/official/requirements_updater/requirements.in +dm-tree==0.1.9 \ + --hash=sha256:12f4cc6cd52a39aa38ff31577b6d79b6136a9a89273a876bf62335c9f65c27bf \ + --hash=sha256:1ae3cbff592bb3f2e197f5a8030de4a94e292e6cdd85adeea0b971d07a1b85f2 \ + --hash=sha256:2334cfe9d2ed4293f9f1c7aefba0657deaab9ea74b5fadd966f6d01d9b6b42d9 \ + --hash=sha256:294dc1cecf87552a45cdd5ddb215e7f5295a5a47c46f1f0a0463c3dd02a527d7 \ + --hash=sha256:54d5616015412311df154908069fcf2c2d8786f6088a2ae3554d186cdf2b1e15 \ + --hash=sha256:5d5b28ee2e461b6af65330c143806a6d0945dcabbb8d22d2ba863e6dabd9254e \ + --hash=sha256:6893fcdc5cf1a4f459cfc383526d35d42e7c671ae565d7e429a2f2cb2cb93e89 \ + --hash=sha256:7d7d784afaeb4b67d87d858261aaf02503939ddc1f09c4cca70728f9892ab004 \ + 
--hash=sha256:80c43417814b1181d3367b335460bfdd30b79ee187a64220e11f6ddd093a4b15 \ + --hash=sha256:831699d2c60a1b38776a193b7143ae0acad0a687d87654e6d3342584166816bc \ + --hash=sha256:9020a5ce256fcc83aa4bc190cc96dd66e87685db0a6e501b0c06aa492c2e38fc \ + --hash=sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b \ + --hash=sha256:a8d20eeab7fde77a3ed71f07716021eb0edfb4812a128eb381d108af3a310257 \ + --hash=sha256:b06e7a5da1c31a82521a60060573527e8d24b9920fdd20b2ec86f08412737598 \ + --hash=sha256:cfa33c2e028155810ad1b4e11928707bf47489516763a86e79cab2954d23bf68 \ + --hash=sha256:d05622d074353cf434049206e53c12147903a048c4bd7d77f2800d427413ad78 \ + --hash=sha256:e1f5d1e96b3a7de22b25b13a5eb30f41f8cf9c02dd4479a24920de99e780903c \ + --hash=sha256:e660d1779ddcbd1348410d08f67db4870d413a3ec4ba8b4b045bd5ce4bd8f35c \ + --hash=sha256:e97c34fcb44941c36b7ee81dcdbceba0fbe728bddcc77e5837ab2eb665bcbff8 \ + --hash=sha256:f68b0efad76703dd4648586c75618a48cdd671b68c3266fe980e323c15423607 + # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in +gast==0.4.0 \ + --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ + --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 + # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in +grpcio==1.71.0 \ + --hash=sha256:0ab8b2864396663a5b0b0d6d79495657ae85fa37dcb6498a2669d067c65c11ea \ + --hash=sha256:0fa05ee31a20456b13ae49ad2e5d585265f71dd19fbd9ef983c28f926d45d0a7 \ + --hash=sha256:0ff35c8d807c1c7531d3002be03221ff9ae15712b53ab46e2a0b4bb271f38537 \ + --hash=sha256:1be857615e26a86d7363e8a163fade914595c81fec962b3d514a4b1e8760467b \ + --hash=sha256:20e8f653abd5ec606be69540f57289274c9ca503ed38388481e98fa396ed0b41 \ + --hash=sha256:22c3bc8d488c039a199f7a003a38cb7635db6656fa96437a8accde8322ce2366 \ + --hash=sha256:24e867651fc67717b6f896d5f0cac0ec863a8b5fb7d6441c2ab428f52c651c6b \ + --hash=sha256:2b85f7820475ad3edec209d3d89a7909ada16caab05d3f2e08a7e8ae3200a55c \ + --hash=sha256:39983a9245d37394fd59de71e88c4b295eb510a3555e0a847d9965088cdbd033 \ + --hash=sha256:3d081e859fb1ebe176de33fc3adb26c7d46b8812f906042705346b314bde32c3 \ + --hash=sha256:469f42a0b410883185eab4689060a20488a1a0a00f8bbb3cbc1061197b4c5a79 \ + --hash=sha256:47be9584729534660416f6d2a3108aaeac1122f6b5bdbf9fd823e11fe6fbaa29 \ + --hash=sha256:4be74ddeeb92cc87190e0e376dbc8fc7736dbb6d3d454f2fa1f5be1dee26b9d7 \ + --hash=sha256:4dd0dfbe4d5eb1fcfec9490ca13f82b089a309dc3678e2edabc144051270a66e \ + --hash=sha256:5b08d03ace7aca7b2fadd4baf291139b4a5f058805a8327bfe9aece7253b6d67 \ + --hash=sha256:63e41b91032f298b3e973b3fa4093cbbc620c875e2da7b93e249d4728b54559a \ + --hash=sha256:652350609332de6dac4ece254e5d7e1ff834e203d6afb769601f286886f6f3a8 \ + --hash=sha256:693bc706c031aeb848849b9d1c6b63ae6bcc64057984bb91a542332b75aa4c3d \ + --hash=sha256:74258dce215cb1995083daa17b379a1a5a87d275387b7ffe137f1d5131e2cfbb \ + --hash=sha256:789d5e2a3a15419374b7b45cd680b1e83bbc1e52b9086e49308e2c0b5bbae6e3 \ + 
--hash=sha256:7c9c80ac6091c916db81131d50926a93ab162a7e97e4428ffc186b6e80d6dda4 \ + --hash=sha256:7d6ac9481d9d0d129224f6d5934d5832c4b1cddb96b59e7eba8416868909786a \ + --hash=sha256:85da336e3649a3d2171e82f696b5cad2c6231fdd5bad52616476235681bee5b3 \ + --hash=sha256:8700a2a57771cc43ea295296330daaddc0d93c088f0a35cc969292b6db959bf3 \ + --hash=sha256:8997d6785e93308f277884ee6899ba63baafa0dfb4729748200fcc537858a509 \ + --hash=sha256:9182e0063112e55e74ee7584769ec5a0b4f18252c35787f48738627e23a62b97 \ + --hash=sha256:9b91879d6da1605811ebc60d21ab6a7e4bae6c35f6b63a061d61eb818c8168f6 \ + --hash=sha256:a2242d6950dc892afdf9e951ed7ff89473aaf744b7d5727ad56bdaace363722b \ + --hash=sha256:a371e6b6a5379d3692cc4ea1cb92754d2a47bdddeee755d3203d1f84ae08e03e \ + --hash=sha256:a76d39b5fafd79ed604c4be0a869ec3581a172a707e2a8d7a4858cb05a5a7637 \ + --hash=sha256:ad9f30838550695b5eb302add33f21f7301b882937460dd24f24b3cc5a95067a \ + --hash=sha256:b2266862c5ad664a380fbbcdbdb8289d71464c42a8c29053820ee78ba0119e5d \ + --hash=sha256:b78a99cd1ece4be92ab7c07765a0b038194ded2e0a26fd654591ee136088d8d7 \ + --hash=sha256:c200cb6f2393468142eb50ab19613229dcc7829b5ccee8b658a36005f6669fdd \ + --hash=sha256:c30f393f9d5ff00a71bb56de4aa75b8fe91b161aeb61d39528db6b768d7eac69 \ + --hash=sha256:c6a0a28450c16809f94e0b5bfe52cabff63e7e4b97b44123ebf77f448534d07d \ + --hash=sha256:cebc1b34ba40a312ab480ccdb396ff3c529377a2fce72c45a741f7215bfe8379 \ + --hash=sha256:d2c170247315f2d7e5798a22358e982ad6eeb68fa20cf7a820bb74c11f0736e7 \ + --hash=sha256:d35a95f05a8a2cbe8e02be137740138b3b2ea5f80bd004444e4f9a1ffc511e32 \ + --hash=sha256:d5170929109450a2c031cfe87d6716f2fae39695ad5335d9106ae88cc32dc84c \ + --hash=sha256:d6aa986318c36508dc1d5001a3ff169a15b99b9f96ef5e98e13522c506b37eef \ + --hash=sha256:d6de81c9c00c8a23047136b11794b3584cdc1460ed7cbc10eada50614baa1444 \ + --hash=sha256:dc1a1231ed23caac1de9f943d031f1bc38d0f69d2a3b243ea0d664fc1fbd7fec \ + --hash=sha256:e6beeea5566092c5e3c4896c6d1d307fb46b1d4bdf3e70c8340b190a69198594 \ + --hash=sha256:e6d8de076528f7c43a2f576bc311799f89d795aa6c9b637377cc2b1616473804 \ + --hash=sha256:e6f83a583ed0a5b08c5bc7a3fe860bb3c2eac1f03f1f63e0bc2091325605d2b7 \ + --hash=sha256:f250ff44843d9a0615e350c77f890082102a0318d66a99540f54769c8766ab73 \ + --hash=sha256:f71574afdf944e6652203cd1badcda195b2a27d9c83e6d88dc1ce3cfb73b31a5 \ + --hash=sha256:f903017db76bf9cc2b2d8bdd37bf04b505bbccad6be8a81e1542206875d0e9db \ + --hash=sha256:f9a412f55bb6e8f3bb000e020dbc1e709627dcb3a56f6431fa7076b4c1aab0db \ + --hash=sha256:f9c30c464cb2ddfbc2ddf9400287701270fdc0f14be5f08a1e3939f1e749b455 + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly +h5py==3.13.0 \ + --hash=sha256:10894c55d46df502d82a7a4ed38f9c3fdbcb93efb42e25d275193e093071fade \ + --hash=sha256:1870e46518720023da85d0895a1960ff2ce398c5671eac3b1a41ec696b7105c3 \ + --hash=sha256:21daf38171753899b5905f3d82c99b0b1ec2cbbe282a037cad431feb620e62ec \ + --hash=sha256:22ffe2a25770a2d67213a1b94f58006c14dce06933a42d2aaa0318c5868d1508 \ + --hash=sha256:337af114616f3656da0c83b68fcf53ecd9ce9989a700b0883a6e7c483c3235d4 \ + --hash=sha256:357e6dc20b101a805ccfd0024731fbaf6e8718c18c09baf3b5e4e9d198d13fca \ + --hash=sha256:477c58307b6b9a2509c59c57811afb9f598aedede24a67da808262dfa0ee37b4 \ + --hash=sha256:4f97ecde7ac6513b21cd95efdfc38dc6d19f96f6ca6f2a30550e94e551458e0a \ + --hash=sha256:5540daee2b236d9569c950b417f13fd112d51d78b4c43012de05774908dff3f5 \ + --hash=sha256:560e71220dc92dfa254b10a4dcb12d56b574d2d87e095db20466b32a93fec3f9 \ + 
--hash=sha256:56dd172d862e850823c4af02dc4ddbc308f042b85472ffdaca67f1598dff4a57 \ + --hash=sha256:57c4c74f627c616f02b7aec608a8c706fe08cb5b0ba7c08555a4eb1dde20805a \ + --hash=sha256:782ff0ac39f455f21fd1c8ebc007328f65f43d56718a89327eec76677ebf238a \ + --hash=sha256:82690e89c72b85addf4fc4d5058fb1e387b6c14eb063b0b879bf3f42c3b93c35 \ + --hash=sha256:851ae3a8563d87a5a0dc49c2e2529c75b8842582ccaefbf84297d2cfceeacd61 \ + --hash=sha256:8a8e38ef4ceb969f832cc230c0cf808c613cc47e31e768fd7b1106c55afa1cb8 \ + --hash=sha256:9c82ece71ed1c2b807b6628e3933bc6eae57ea21dac207dca3470e3ceaaf437c \ + --hash=sha256:be949b46b7388074c5acae017fbbe3e5ba303fd9daaa52157fdfef30bbdacadd \ + --hash=sha256:c10f061764d8dce0a9592ce08bfd5f243a00703325c388f1086037e5d619c5f1 \ + --hash=sha256:d2cf6a231a07c14acd504a945a6e9ec115e0007f675bde5e0de30a4dc8d86a31 \ + --hash=sha256:d571644958c5e19a61c793d8d23cd02479572da828e333498c9acc463f4a3997 \ + --hash=sha256:d6f13f9b5ce549448c01e4dfe08ea8d1772e6078799af2c1c8d09e941230a90d \ + --hash=sha256:e520ec76de00943dd017c8ea3f354fa1d2f542eac994811943a8faedf2a7d5cb \ + --hash=sha256:e79d8368cd9295045956bfb436656bea3f915beaa11d342e9f79f129f5178763 \ + --hash=sha256:f35640e81b03c02a88b8bf99fb6a9d3023cc52f7c627694db2f379e0028f2868 \ + --hash=sha256:fb267ce4b83f9c42560e9ff4d30f60f7ae492eacf9c7ede849edf8c1b860e16b + # via + # -r ci/official/requirements_updater/requirements.in + # keras-nightly +idna==3.10 \ + --hash=sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9 \ + --hash=sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3 + # via requests +jax==0.4.7 \ + --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 + # via -r ci/official/requirements_updater/requirements.in +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d + # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b + # via -r ci/official/requirements_updater/requirements.in +markdown==3.7 \ + --hash=sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2 \ + --hash=sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803 + # via tb-nightly +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb + # via rich +markupsafe==3.0.2 \ + 
--hash=sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4 \ + --hash=sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30 \ + --hash=sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0 \ + --hash=sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9 \ + --hash=sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396 \ + --hash=sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13 \ + --hash=sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028 \ + --hash=sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca \ + --hash=sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557 \ + --hash=sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832 \ + --hash=sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0 \ + --hash=sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b \ + --hash=sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579 \ + --hash=sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a \ + --hash=sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c \ + --hash=sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff \ + --hash=sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c \ + --hash=sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22 \ + --hash=sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094 \ + --hash=sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb \ + --hash=sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e \ + --hash=sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5 \ + --hash=sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a \ + --hash=sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d \ + --hash=sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a \ + --hash=sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b \ + --hash=sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8 \ + --hash=sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225 \ + --hash=sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c \ + --hash=sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144 \ + --hash=sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f \ + --hash=sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87 \ + --hash=sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d \ + --hash=sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93 \ + --hash=sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf \ + --hash=sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158 \ + --hash=sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84 \ + --hash=sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb \ + --hash=sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48 \ + --hash=sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171 \ + --hash=sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c \ + --hash=sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6 \ + 
--hash=sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd \ + --hash=sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d \ + --hash=sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1 \ + --hash=sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d \ + --hash=sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca \ + --hash=sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a \ + --hash=sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29 \ + --hash=sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe \ + --hash=sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798 \ + --hash=sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c \ + --hash=sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8 \ + --hash=sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f \ + --hash=sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f \ + --hash=sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a \ + --hash=sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178 \ + --hash=sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0 \ + --hash=sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79 \ + --hash=sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430 \ + --hash=sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50 + # via werkzeug +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab 
\ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 + # via + # -r ci/official/requirements_updater/requirements.in + # jax + # keras-nightly +namex==0.0.8 \ + --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ + --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 + # via keras-nightly +numpy==2.1.3 \ + --hash=sha256:016d0f6f5e77b0f0d45d77387ffa4bb89816b57c835580c3ce8e099ef830befe \ + --hash=sha256:02135ade8b8a84011cbb67dc44e07c58f28575cf9ecf8ab304e51c05528c19f0 \ + --hash=sha256:08788d27a5fd867a663f6fc753fd7c3ad7e92747efc73c53bca2f19f8bc06f48 \ + --hash=sha256:0d30c543f02e84e92c4b1f415b7c6b5326cbe45ee7882b6b77db7195fb971e3a \ + --hash=sha256:0fa14563cc46422e99daef53d725d0c326e99e468a9320a240affffe87852564 \ + --hash=sha256:13138eadd4f4da03074851a698ffa7e405f41a0845a6b1ad135b81596e4e9958 \ + --hash=sha256:14e253bd43fc6b37af4921b10f6add6925878a42a0c5fe83daee390bca80bc17 \ + --hash=sha256:15cb89f39fa6d0bdfb600ea24b250e5f1a3df23f901f51c8debaa6a5d122b2f0 \ + --hash=sha256:17ee83a1f4fef3c94d16dc1802b998668b5419362c8a4f4e8a491de1b41cc3ee \ + --hash=sha256:2312b2aa89e1f43ecea6da6ea9a810d06aae08321609d8dc0d0eda6d946a541b \ + --hash=sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4 \ + --hash=sha256:3522b0dfe983a575e6a9ab3a4a4dfe156c3e428468ff08ce582b9bb6bd1d71d4 \ + --hash=sha256:4394bc0dbd074b7f9b52024832d16e019decebf86caf909d94f6b3f77a8ee3b6 \ + --hash=sha256:45966d859916ad02b779706bb43b954281db43e185015df6eb3323120188f9e4 \ + --hash=sha256:4d1167c53b93f1f5d8a139a742b3c6f4d429b54e74e6b57d0eff40045187b15d \ + --hash=sha256:4f2015dfe437dfebbfce7c85c7b53d81ba49e71ba7eadbf1df40c915af75979f \ + --hash=sha256:50ca6aba6e163363f132b5c101ba078b8cbd3fa92c7865fd7d4d62d9779ac29f \ + --hash=sha256:50d18c4358a0a8a53f12a8ba9d772ab2d460321e6a93d6064fc22443d189853f \ + --hash=sha256:5641516794ca9e5f8a4d17bb45446998c6554704d888f86df9b200e66bdcce56 \ + --hash=sha256:576a1c1d25e9e02ed7fa5477f30a127fe56debd53b8d2c89d5578f9857d03ca9 \ + --hash=sha256:6a4825252fcc430a182ac4dee5a505053d262c807f8a924603d411f6718b88fd \ + --hash=sha256:72dcc4a35a8515d83e76b58fdf8113a5c969ccd505c8a946759b24e3182d1f23 \ + --hash=sha256:747641635d3d44bcb380d950679462fae44f54b131be347d5ec2bce47d3df9ed \ + --hash=sha256:762479be47a4863e261a840e8e01608d124ee1361e48b96916f38b119cfda04a \ + --hash=sha256:78574ac2d1a4a02421f25da9559850d59457bac82f2b8d7a44fe83a64f770098 \ + --hash=sha256:825656d0743699c529c5943554d223c021ff0494ff1442152ce887ef4f7561a1 \ + --hash=sha256:8637dcd2caa676e475503d1f8fdb327bc495554e10838019651b76d17b98e512 \ + --hash=sha256:96fe52fcdb9345b7cd82ecd34547fca4321f7656d500eca497eb7ea5a926692f \ + --hash=sha256:973faafebaae4c0aaa1a1ca1ce02434554d67e628b8d805e61f874b84e136b09 \ + --hash=sha256:996bb9399059c5b82f76b53ff8bb686069c05acc94656bb259b1d63d04a9506f \ + --hash=sha256:a38c19106902bb19351b83802531fea19dee18e5b37b36454f27f11ff956f7fc \ + --hash=sha256:a6b46587b14b888e95e4a24d7b13ae91fa22386c199ee7b418f449032b2fa3b8 \ + --hash=sha256:a9f7f672a3388133335589cfca93ed468509cb7b93ba3105fce780d04a6576a0 \ + --hash=sha256:aa08e04e08aaf974d4458def539dece0d28146d866a39da5639596f4921fd761 \ + --hash=sha256:b0df3635b9c8ef48bd3be5f862cf71b0a4716fa0e702155c45067c6b711ddcef \ + --hash=sha256:b47fbb433d3260adcd51eb54f92a2ffbc90a4595f8970ee00e064c644ac788f5 \ + 
--hash=sha256:baed7e8d7481bfe0874b566850cb0b85243e982388b7b23348c6db2ee2b2ae8e \ + --hash=sha256:bc6f24b3d1ecc1eebfbf5d6051faa49af40b03be1aaa781ebdadcbc090b4539b \ + --hash=sha256:c006b607a865b07cd981ccb218a04fc86b600411d83d6fc261357f1c0966755d \ + --hash=sha256:c181ba05ce8299c7aa3125c27b9c2167bca4a4445b7ce73d5febc411ca692e43 \ + --hash=sha256:c7662f0e3673fe4e832fe07b65c50342ea27d989f92c80355658c7f888fcc83c \ + --hash=sha256:c80e4a09b3d95b4e1cac08643f1152fa71a0a821a2d4277334c88d54b2219a41 \ + --hash=sha256:c894b4305373b9c5576d7a12b473702afdf48ce5369c074ba304cc5ad8730dff \ + --hash=sha256:d7aac50327da5d208db2eec22eb11e491e3fe13d22653dce51b0f4109101b408 \ + --hash=sha256:d89dd2b6da69c4fff5e39c28a382199ddedc3a5be5390115608345dec660b9e2 \ + --hash=sha256:d9beb777a78c331580705326d2367488d5bc473b49a9bc3036c154832520aca9 \ + --hash=sha256:dc258a761a16daa791081d026f0ed4399b582712e6fc887a95af09df10c5ca57 \ + --hash=sha256:e14e26956e6f1696070788252dcdff11b4aca4c3e8bd166e0df1bb8f315a67cb \ + --hash=sha256:e6988e90fcf617da2b5c78902fe8e668361b43b4fe26dbf2d7b0f8034d4cafb9 \ + --hash=sha256:e711e02f49e176a01d0349d82cb5f05ba4db7d5e7e0defd026328e5cfb3226d3 \ + --hash=sha256:ea4dedd6e394a9c180b33c2c872b92f7ce0f8e7ad93e9585312b0c5a04777a4a \ + --hash=sha256:ecc76a9ba2911d8d37ac01de72834d8849e55473457558e12995f4cd53e778e0 \ + --hash=sha256:f55ba01150f52b1027829b50d70ef1dafd9821ea82905b63936668403c3b471e \ + --hash=sha256:f653490b33e9c3a4c1c01d41bc2aef08f9475af51146e4a7710c450cf9761598 \ + --hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4 + # via + # -r ci/official/requirements_updater/requirements.in + # dm-tree + # h5py + # jax + # keras-nightly + # ml-dtypes + # opt-einsum + # scipy + # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + 
--hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via + # -r ci/official/requirements_updater/requirements.in + # jax +packaging==23.2 \ + --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ + --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 + # via + # -r ci/official/requirements_updater/requirements.in + # auditwheel + # tb-nightly +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa + # via -r ci/official/requirements_updater/requirements.in +protobuf==6.30.2 \ + --hash=sha256:0eb523c550a66a09a0c20f86dd554afbf4d32b02af34ae53d93268c1f73bc65b \ + --hash=sha256:35c859ae076d8c56054c25b59e5e59638d86545ed6e2b6efac6be0b6ea3ba048 \ + --hash=sha256:4f6c687ae8efae6cf6093389a596548214467778146b7245e886f35e1485315d \ + --hash=sha256:50f32cc9fd9cb09c783ebc275611b4f19dfdfb68d1ee55d2f0c7fa040df96815 \ + --hash=sha256:524afedc03b31b15586ca7f64d877a98b184f007180ce25183d1a5cb230ee72b \ + --hash=sha256:7653c99774f73fe6b9301b87da52af0e69783a2e371e8b599b3e9cb4da4b12b9 \ + --hash=sha256:acec579c39c88bd8fbbacab1b8052c793efe83a0a5bd99db4a31423a25c0a0e2 \ + --hash=sha256:ae86b030e69a98e08c77beab574cbcb9fff6d031d57209f574a5aea1445f4b51 \ + 
--hash=sha256:b12ef7df7b9329886e66404bef5e9ce6a26b54069d7f7436a0853ccdeb91c103 + # via tb-nightly +psutil==7.0.0 \ + --hash=sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25 \ + --hash=sha256:1e744154a6580bc968a0195fd25e80432d3afec619daf145b9e5ba16cc1d688e \ + --hash=sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91 \ + --hash=sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da \ + --hash=sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34 \ + --hash=sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553 \ + --hash=sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456 \ + --hash=sha256:84df4eb63e16849689f76b1ffcb36db7b8de703d1bc1fe41773db487621b6c17 \ + --hash=sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993 \ + --hash=sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99 + # via portpicker +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 + # via auditwheel +pygments==2.19.1 \ + --hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \ + --hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c + # via rich +requests==2.32.3 \ + --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ + --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 + # via -r ci/official/requirements_updater/requirements.in +rich==14.0.0 \ + --hash=sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0 \ + --hash=sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725 + # via keras-nightly +scipy==1.15.2 \ + --hash=sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf \ + --hash=sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11 \ + --hash=sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37 \ + --hash=sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d \ + --hash=sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0 \ + --hash=sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8 \ + --hash=sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af \ + --hash=sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40 \ + --hash=sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9 \ + --hash=sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971 \ + --hash=sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d \ + --hash=sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737 \ + --hash=sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e \ + --hash=sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32 \ + --hash=sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53 \ + --hash=sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1 \ + --hash=sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d \ + --hash=sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e \ + --hash=sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776 \ + --hash=sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5 \ + 
--hash=sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462 \ + --hash=sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274 \ + --hash=sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301 \ + --hash=sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3 \ + --hash=sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58 \ + --hash=sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4 \ + --hash=sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa \ + --hash=sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9 \ + --hash=sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27 \ + --hash=sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9 \ + --hash=sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f \ + --hash=sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655 \ + --hash=sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20 \ + --hash=sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65 \ + --hash=sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93 \ + --hash=sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828 \ + --hash=sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd \ + --hash=sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f \ + --hash=sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec \ + --hash=sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb \ + --hash=sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6 \ + --hash=sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded \ + --hash=sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e \ + --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ + --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ + --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db + # via + # -r ci/official/requirements_updater/requirements.in + # jax +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via + # astunparse + # google-pasta + # tb-nightly +tb-nightly==2.19.0a20250218 \ + --hash=sha256:7c7fea911a9e113e7d40fa9aed96168840e2443c5ada52fba5bc3645ec6e206f + # via -r ci/official/requirements_updater/requirements.in +tblib==2.0.0 \ + --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ + --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7 + # via -r ci/official/requirements_updater/requirements.in +tensorboard-data-server==0.7.2 \ + --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ + --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ + --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 + # via tb-nightly +termcolor==2.3.0 \ + --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ + --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a + # via -r ci/official/requirements_updater/requirements.in +typing-extensions==4.8.0 \ + --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ + 
--hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef + # via -r ci/official/requirements_updater/requirements.in +urllib3==2.3.0 \ + --hash=sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df \ + --hash=sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d + # via requests +werkzeug==3.1.3 \ + --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ + --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 + # via tb-nightly +wheel==0.41.3 \ + --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ + --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841 + # via + # -r ci/official/requirements_updater/requirements.in + # astunparse +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c 
\ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 + # via + # -r ci/official/requirements_updater/requirements.in + # dm-tree +zstandard==0.23.0 \ + --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ + --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ + --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \ + --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \ + 
--hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \ + --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \ + --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \ + --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \ + --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \ + --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \ + --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c \ + --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \ + --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \ + --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \ + --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \ + --hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \ + --hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \ + --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \ + --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \ + --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \ + --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \ + --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \ + --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \ + --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \ + --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \ + --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \ + --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \ + --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \ + --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \ + --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \ + --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \ + --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \ + --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \ + --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \ + --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \ + --hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \ + --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \ + --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \ + --hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \ + --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \ + --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \ + --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \ + --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \ + --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \ + --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \ + --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \ + 
--hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \ + --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \ + --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \ + --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \ + --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \ + --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \ + --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \ + --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \ + --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \ + --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \ + --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \ + --hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \ + --hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \ + --hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \ + --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \ + --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \ + --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \ + --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \ + --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \ + --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \ + --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \ + --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \ + --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \ + --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \ + --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \ + --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \ + --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \ + --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \ + --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \ + --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \ + --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \ + --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \ + --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \ + --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \ + --hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \ + --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \ + --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \ + --hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \ + --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \ + --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \ + --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \ + --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \ + 
--hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \ + --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \ + --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \ + --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \ + --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \ + --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \ + --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ + --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ + --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 + # via -r ci/official/requirements_updater/requirements.in + +# The following packages are considered to be unsafe in a requirements file: +setuptools==70.0.0 \ + --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ + --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 5bfa389e92a08a..e52aa0b799b1fe 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -686,7 +686,19 @@ static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value, // Function `_copy_to_device`. static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args, PyObject* kwds) { +#if PY_VERSION_HEX <= 0x030C0000 // <= Python 3.12 if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr; +#else + const char* keyname = "copy_to_device"; + if (kwds != NULL && PyDict_Size(kwds) > 0) { + PyErr_SetString(PyExc_TypeError, "Function does not accept keyword args."); + return nullptr; + } + + if (!PyArg_ParseTuple(args, "s", &keyname)) { + return nullptr; + } +#endif const char* device_name = nullptr; if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName, @@ -873,6 +885,8 @@ static int EagerTensor_traverse(PyObject* self, visitproc visit, void* arg) { #if PY_VERSION_HEX < 0x030C0000 // < Python 3.12 PyObject*& dict = *_PyObject_GetDictPtr(self); Py_VISIT(dict); +#elif PY_VERSION_HEX >= 0x030D0000 // >= Python 3.13 + PyObject_VisitManagedDict(self, visit, arg); #else _PyObject_VisitManagedDict(self, visit, arg); #endif // PY_VERSION_HEX < 0x030C0000 @@ -896,6 +910,8 @@ extern int EagerTensor_clear(PyObject* self) { #if PY_VERSION_HEX < 0x030C0000 // < Python 3.12 PyObject*& dict = *_PyObject_GetDictPtr(self); Py_CLEAR(dict); +#elif PY_VERSION_HEX >= 0x030D0000 // >= Python 3.13 + PyObject_ClearManagedDict(self); #else _PyObject_ClearManagedDict(self); #endif // PY_VERSION_HEX < 0x030C0000 diff --git a/tensorflow/tools/toolchains/python/python_repo.bzl b/tensorflow/tools/toolchains/python/python_repo.bzl index 47fe64d7b7b039..2af9b29d7af20b 100644 --- a/tensorflow/tools/toolchains/python/python_repo.bzl +++ b/tensorflow/tools/toolchains/python/python_repo.bzl @@ -7,7 +7,7 @@ Defaults to 3.10. To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu" """ -VERSIONS = ["3.9", "3.10", "3.11", "3.12"] +VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] DEFAULT_VERSION = "3.11" WARNING = """ TF_PYTHON_VERSION environment variable was not set correctly; using Python {}. 
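
The pywrap_tensor.cc hunks above track two CPython changes: `_PyArg_NoKeywords` is no longer usable from the public headers after Python 3.12, and the private `_PyObject_VisitManagedDict`/`_PyObject_ClearManagedDict` helpers become the public `PyObject_VisitManagedDict`/`PyObject_ClearManagedDict` in Python 3.13. A minimal standalone sketch of the same three-way gating follows; `TraverseSelf` is a hypothetical tp_traverse slot written for illustration only and is not part of this patch.

#include <Python.h>

// Hypothetical tp_traverse slot mirroring the version guards used in
// pywrap_tensor.cc above.
static int TraverseSelf(PyObject* self, visitproc visit, void* arg) {
#if PY_VERSION_HEX < 0x030C0000  // < Python 3.12: reach the dict directly.
  PyObject*& dict = *_PyObject_GetDictPtr(self);
  Py_VISIT(dict);
#elif PY_VERSION_HEX >= 0x030D0000  // >= Python 3.13: public managed-dict API.
  PyObject_VisitManagedDict(self, visit, arg);
#else  // Python 3.12: private helper, renamed without the underscore in 3.13.
  _PyObject_VisitManagedDict(self, visit, arg);
#endif
  return 0;
}
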
diff --git a/third_party/xla/tools/toolchains/python/python_repo.bzl b/third_party/xla/tools/toolchains/python/python_repo.bzl index 47fe64d7b7b039..2af9b29d7af20b 100644 --- a/third_party/xla/tools/toolchains/python/python_repo.bzl +++ b/third_party/xla/tools/toolchains/python/python_repo.bzl @@ -7,7 +7,7 @@ Defaults to 3.10. To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu" """ -VERSIONS = ["3.9", "3.10", "3.11", "3.12"] +VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] DEFAULT_VERSION = "3.11" WARNING = """ TF_PYTHON_VERSION environment variable was not set correctly; using Python {}. From 3f0e03a0a78cf6f1e7e702f14030136440cce655 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 14:12:01 -0700 Subject: [PATCH 0394/1324] Make `xla/tsl/platform/default:all` buildable. PiperOrigin-RevId: 745286599 --- third_party/xla/xla/tsl/platform/default/BUILD | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/xla/xla/tsl/platform/default/BUILD b/third_party/xla/xla/tsl/platform/default/BUILD index 9202de8697df6e..08e109c6b57eca 100644 --- a/third_party/xla/xla/tsl/platform/default/BUILD +++ b/third_party/xla/xla/tsl/platform/default/BUILD @@ -153,9 +153,9 @@ cc_library( "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform", @@ -163,7 +163,6 @@ cc_library( "@local_tsl//tsl/platform:context", "@local_tsl//tsl/platform:cord", "@local_tsl//tsl/platform:denormal", - "@local_tsl//tsl/platform:load_library", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", From 884af6db46eb09928f52c1e21fe8f0eda55eec33 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 8 Apr 2025 14:29:36 -0700 Subject: [PATCH 0395/1324] Add debug callback partitioner. PiperOrigin-RevId: 745292986 --- third_party/xla/xla/python/BUILD | 14 ++++ .../xla/python/debug_callback_partitioner.cc | 84 +++++++++++++++++++ .../xla/python/debug_callback_partitioner.h | 37 ++++++++ third_party/xla/xla/python/version.h | 2 +- 4 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/python/debug_callback_partitioner.cc create mode 100644 third_party/xla/xla/python/debug_callback_partitioner.h diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 6f7941d35e15b9..fcec573557ad2c 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -191,6 +191,20 @@ cc_library( ], ) +cc_library( + name = "debug_callback_partitioner", + srcs = ["debug_callback_partitioner.cc"], + hdrs = ["debug_callback_partitioner.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_sharding", + "//xla/service:custom_call_sharding_helper", + "//xla/service/spmd:spmd_partitioner", + "@com_google_absl//absl/status", + ], +) + cc_library( name = "ops", srcs = ["ops.cc"], diff --git a/third_party/xla/xla/python/debug_callback_partitioner.cc b/third_party/xla/xla/python/debug_callback_partitioner.cc new file mode 100644 index 00000000000000..fe6e4d6b6f28b7 --- /dev/null +++ b/third_party/xla/xla/python/debug_callback_partitioner.cc @@ -0,0 +1,84 @@ +/* Copyright 2025 The OpenXLA Authors. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/python/debug_callback_partitioner.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/service/spmd/spmd_partitioner.h"
+#include "xla/shape.h"
+
+namespace xla {
+
+absl::Status DebugCallbackCustomCallPartitioner::Partition(
+    spmd::SpmdPartitioningVisitor* partitioner, HloInstruction* hlo) const {
+  // Cast HloInstruction to HloCustomCallInstruction.
+  const HloCustomCallInstruction* custom_call =
+      Cast<HloCustomCallInstruction>(hlo);
+
+  // Initialize partitioned operands and shapes.
+  const int64_t num_operands = hlo->operand_count();
+  std::vector<HloInstruction*> partitioned_operands;
+  partitioned_operands.reserve(num_operands);
+  std::vector<Shape> partitioned_shapes_with_layout_constraints;
+  partitioned_shapes_with_layout_constraints.reserve(num_operands);
+
+  // Loop through and get partitioned operands and shapes.
+  for (size_t i = 0; i < num_operands; ++i) {
+    // For each operand, get partitioned hlo.
+    spmd::PartitionedHlo partitioned_operand =
+        partitioner->GetPartitionedHlo(hlo->operand(i));
+    partitioned_operands.push_back(partitioned_operand.hlo());
+    Shape partitioned_shape_with_layout_constraint =
+        partitioned_operand.hlo()->shape();
+    (*partitioned_shape_with_layout_constraint.mutable_layout()) =
+        custom_call->operand_shapes_with_layout()[i].layout();
+    partitioned_shapes_with_layout_constraints.push_back(
+        partitioned_shape_with_layout_constraint);
+  }
+
+  // Create new custom call with partitioned operands.
+  std::unique_ptr<HloInstruction> partitioned_instruction =
+      HloInstruction::CreateCustomCall(
+          hlo->shape(), partitioned_operands, custom_call->custom_call_target(),
+          partitioned_shapes_with_layout_constraints, custom_call->opaque(),
+          custom_call->api_version());
+  auto partitioned_custom_call =
+      Cast<HloCustomCallInstruction>(partitioned_instruction.get());
+  partitioned_custom_call->set_custom_call_has_side_effect(
+      custom_call->custom_call_has_side_effect());
+  HloInstruction* partitioned_hlo = partitioner->builder()->AddInstruction(
+      std::move(partitioned_instruction));
+  partitioned_hlo->set_sharding(HloSharding::Replicate());
+
+  spmd::PartitionedHlo result_partitioned =
+      spmd::PartitionedHlo(partitioned_hlo, hlo->shape(),
+                           partitioner->MakePartitioningState())
+          .Reshard(hlo->sharding());
+  partitioner->SetPartitionedHlo(hlo, result_partitioned);
+
+  return absl::OkStatus();
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/python/debug_callback_partitioner.h b/third_party/xla/xla/python/debug_callback_partitioner.h
new file mode 100644
index 00000000000000..02f163689b2157
--- /dev/null
+++ b/third_party/xla/xla/python/debug_callback_partitioner.h
@@ -0,0 +1,37 @@
+/* Copyright 2025 The OpenXLA Authors.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_DEBUG_CALLBACK_PARTITIONER_H_ +#define XLA_PYTHON_DEBUG_CALLBACK_PARTITIONER_H_ + +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/custom_call_sharding_helper.h" + +namespace xla { + +// Partition the custom call according to XLA partitioning. Currently only used +// by `jax.debug.callback`. +// TODO(b/409338207): Pass additional metadata to the custom call e.g., +// partition id. +class DebugCallbackCustomCallPartitioner : public CustomCallPartitioner { + public: + absl::Status Partition(spmd::SpmdPartitioningVisitor* partitioner, + HloInstruction* hlo) const override; +}; + +} // namespace xla + +#endif // XLA_PYTHON_DEBUG_CALLBACK_PARTITIONER_H_ diff --git a/third_party/xla/xla/python/version.h b/third_party/xla/xla/python/version.h index a8891ef4b22cc0..a575c7ccb68c01 100644 --- a/third_party/xla/xla/python/version.h +++ b/third_party/xla/xla/python/version.h @@ -18,6 +18,6 @@ limitations under the License. // An increasing version number to protect jax code against breaking changes. // In JAX, reference this via jax._src.lib.ifrt_version. -#define JAX_IFRT_VERSION_NUMBER 2 +#define JAX_IFRT_VERSION_NUMBER 3 #endif // XLA_PYTHON_VERSION_H_ From aa6fc518fc1bbc7db3ce0de6290c57cf511d604d Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 8 Apr 2025 14:47:39 -0700 Subject: [PATCH 0396/1324] Bring back `xla_gpu_enable_nccl_comm_splitting`. PiperOrigin-RevId: 745299224 --- third_party/xla/xla/xla.proto | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 8ab86b274a21f4..6e10808a81b60d 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -443,6 +443,9 @@ message DebugOptions { // threads. Setting to 0 (the default value) means no enforcement. bool xla_gpu_enable_llvm_module_compilation_parallelism = 268; + // DEPRECATED: This flag is a no-op. + bool xla_gpu_enable_nccl_clique_optimization = 244 [deprecated = true]; + // Enable NCCL communicator splitting. bool xla_gpu_enable_nccl_comm_splitting = 272; @@ -806,7 +809,6 @@ message DebugOptions { // go/keep-sorted end - reserved 244; // xla_gpu_enable_nccl_clique_optimization reserved 276; // xla_gpu_enable_nccl_per_stream_comms //--------------------------------------------------------------------------// From 5ed404cafebe4ccb9cf04c83b88e30ce7360e1bc Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Tue, 8 Apr 2025 14:50:47 -0700 Subject: [PATCH 0397/1324] Removed unused coordination service registration code. 
PiperOrigin-RevId: 745300499 --- third_party/xla/xla/pjrt/distributed/BUILD | 1 - .../distributed_runtime/coordination/BUILD | 33 ++++--------- .../coordination/coordination_service.cc | 8 ++-- .../coordination/coordination_service.h | 47 +------------------ .../tsl/distributed_runtime/preemption/BUILD | 1 - 5 files changed, 12 insertions(+), 78 deletions(-) diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index 69bee92bb3812c..a288d5fd8c9fa3 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -26,7 +26,6 @@ cc_library( "//xla:types", "//xla:util", "//xla/tsl/distributed_runtime/coordination:coordination_service", - "//xla/tsl/distributed_runtime/coordination:coordination_service_impl", "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", "//xla/tsl/protobuf:coordination_config_proto_cc", diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index 992ab062588c6d..22ac47cdc8834d 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -50,33 +50,14 @@ cc_library( cc_library( name = "coordination_service", - hdrs = ["coordination_service.h"], - deps = [ - ":coordination_client", - "//xla/tsl/platform:macros", - "//xla/tsl/platform:status", - "//xla/tsl/protobuf:coordination_config_proto_cc", - "//xla/tsl/protobuf:coordination_service_proto_cc", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/time", - ], -) - -# Keeping the implementation as a separate build target. -# This is an alwayslink library for statically registering "standalone" implementation. -# Other implementations of the service will be provided in the future. -cc_library( - name = "coordination_service_impl", srcs = ["coordination_service.cc"], + hdrs = ["coordination_service.h"], deps = [ ":coordination_client", - ":coordination_service", ":coordination_service_error_util", "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/platform:env", + "//xla/tsl/platform:macros", "//xla/tsl/platform:status", "//xla/tsl/protobuf:coordination_config_proto_cc", "//xla/tsl/protobuf:coordination_service_proto_cc", @@ -90,15 +71,20 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:random", ], - alwayslink = 1, ) +# TODO(mwhittaker): Remove this once tensorflow only relies on the +# coordination_service BUILD rule. 
+cc_library(name = "coordination_service_impl")
+
 tf_proto_library(
     name = "test_device_proto",
     testonly = 1,
@@ -116,7 +102,6 @@ tsl_cc_test(
         ":coordination_client",
         ":coordination_service",
         ":coordination_service_error_util",
-        ":coordination_service_impl",
         ":test_device_proto_cc",
         "//xla/tsl/distributed_runtime:call_options",
         "//xla/tsl/lib/core:status_test_util",
@@ -226,7 +211,6 @@ tsl_cc_test(
         ":coordination_client",
         ":coordination_service",
         ":coordination_service_agent",
-        ":coordination_service_impl",
        "//xla/tsl/distributed_runtime/rpc:async_service_interface",
        "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client",
        "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl",
@@ -258,7 +242,6 @@ tsl_cc_test(
         ":coordination_client",
         ":coordination_service",
         ":coordination_service_agent",
-        ":coordination_service_impl",
        "//xla/tsl/distributed_runtime/rpc:async_service_interface",
        "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client",
        "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl",
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc
index 1ecb499fd4adcf..445040905d84d0 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc
@@ -2166,8 +2166,9 @@ void CoordinationServiceStandaloneImpl::CompleteShutdownAfterBarrier(
 }
 }  // namespace
 
-std::unique_ptr<CoordinationServiceInterface> EnableCoordinationService(
-    Env* env, const CoordinationServiceConfig& config,
+std::unique_ptr<CoordinationServiceInterface>
+CoordinationServiceInterface::EnableCoordinationService(
+    Env* env, const tensorflow::CoordinationServiceConfig& config,
     std::unique_ptr<CoordinationClientCache> cache) {
   return std::make_unique<CoordinationServiceStandaloneImpl>(env, config,
                                                              std::move(cache));
@@ -2204,7 +2205,4 @@ bool CoordinationServiceStandaloneImpl::IsClientPollingForError() const {
   return client_polling_for_error_;
 }
 
-// Register standalone coordination service implementation.
-REGISTER_COORDINATION_SERVICE("standalone", EnableCoordinationService);
-
 }  // namespace tsl
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h
index 293c1073784df2..d7e42d48981301 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h
@@ -21,8 +21,6 @@ limitations under the License.
 #include
 #include
 #include
-#include
-#include
 #include
 
 #include "absl/log/log.h"
@@ -31,7 +29,6 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "absl/time/time.h"
 #include "xla/tsl/distributed_runtime/coordination/coordination_client.h"
-#include "xla/tsl/platform/macros.h"
 #include "xla/tsl/platform/status.h"
 #include "xla/tsl/protobuf/coordination_config.pb.h"
 #include "xla/tsl/protobuf/coordination_service.pb.h"
@@ -39,19 +36,6 @@ namespace tsl {
 class Env;
 
-// Static registration for coordination service implementations.
-#define REGISTER_COORDINATION_SERVICE(service_type_name, factory_fn)  \
-  REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(__COUNTER__, service_type_name, \
-                                            factory_fn)
-#define REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(counter, service_type_name, \
-                                                  factory_fn)                 \
-  static bool static_coordination_service_##counter TF_ATTRIBUTE_UNUSED =     \
-      []() {                                                                  \
-        ::tsl::CoordinationServiceInterface::RegisterCoordinationService(     \
-            service_type_name, std::move(factory_fn));                        \
-        return true;                                                          \
-      }()
-
 // Coordination service is used for controlling and coordinating distributed
 // execution in a cluster of multiple tasks.
 //
@@ -70,11 +54,6 @@ class Env;
 // tasks. Each task interacts with the service through CoordinationServiceAgent.
 class CoordinationServiceInterface {
  public:
-  using CoordinationServiceFactory =
-      std::function<std::unique_ptr<CoordinationServiceInterface>(
-          Env* env, const tensorflow::CoordinationServiceConfig& config,
-          std::unique_ptr<CoordinationClientCache> cache)>;
-
   using StatusOrValueCallback = std::function&)>;
   using BarrierCallback = std::function;
@@ -83,27 +62,10 @@ class CoordinationServiceInterface {
 
   virtual ~CoordinationServiceInterface() = default;
 
-  static void RegisterCoordinationService(
-      std::string_view service_type_name,
-      CoordinationServiceFactory factory_fn) {
-    auto factories = GetCoordinationServiceFactories();
-    factories->emplace(service_type_name, factory_fn);
-  }
-
   static std::unique_ptr<CoordinationServiceInterface>
   EnableCoordinationService(Env* env,
                             const tensorflow::CoordinationServiceConfig& config,
-                            std::unique_ptr<CoordinationClientCache> cache) {
-    const auto* factories = GetCoordinationServiceFactories();
-    auto factories_iter = factories->find(config.service_type());
-    if (factories_iter == factories->end()) {
-      LOG(ERROR) << "No coordination service factory found for service type "
-                 << config.service_type();
-      return nullptr;
-    }
-    auto service = factories_iter->second(env, config, std::move(cache));
-    return service;
-  }
+                            std::unique_ptr<CoordinationClientCache> cache);
 
   // This function is invoked after each task's local devices are appended in a
   // deterministic order during WaitForAllTasks(). This is useful to convert the
@@ -305,13 +267,6 @@ class CoordinationServiceInterface {
 
   virtual const tensorflow::DeviceInfo& ListClusterDevices() = 0;
   virtual uint64_t GetServiceIncarnation() = 0;
-
-  static std::unordered_map<std::string, CoordinationServiceFactory>*
-  GetCoordinationServiceFactories() {
-    static auto* coordination_service_factories =
-        new std::unordered_map<std::string, CoordinationServiceFactory>();
-    return coordination_service_factories;
-  }
 };
 
 }  // namespace tsl
diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD
index f74fb55ed05f23..1766d18104e209 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD
+++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD
@@ -80,7 +80,6 @@ tsl_cc_test(
         "//xla/tsl/distributed_runtime/coordination:coordination_client",
         "//xla/tsl/distributed_runtime/coordination:coordination_service",
         "//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
-        "//xla/tsl/distributed_runtime/coordination:coordination_service_impl",
        "//xla/tsl/distributed_runtime/rpc:async_service_interface",
        "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client",
        "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl",

From a3689329712ea22dfd1fa6e17abb2aaa5e7f3e4d Mon Sep 17 00:00:00 2001
From: Ionel Gog
Date: Tue, 8 Apr 2025 14:54:31 -0700
Subject: [PATCH 0398/1324] [IFRT] Pass SDY meshes using ifrt.sdy.meshes s.t.
 it is not removed during IFRT versioning.
PiperOrigin-RevId: 745301792 --- third_party/xla/xla/python/ifrt/ir/constants.h | 3 +++ .../ifrt/ir/tests/ifrt_compile_atom_program.mlir | 3 +-- .../ir/transforms/ifrt_compile_atom_program_pass.cc | 12 ++++-------- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/ir/constants.h b/third_party/xla/xla/python/ifrt/ir/constants.h index 512b22259fdc03..99621f5f693b80 100644 --- a/third_party/xla/xla/python/ifrt/ir/constants.h +++ b/third_party/xla/xla/python/ifrt/ir/constants.h @@ -61,6 +61,9 @@ inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName = // partitioned by the Sdy partitioner. inline constexpr llvm::StringLiteral kIsSdyPartitioned = "ifrt.is_sdy_partitioned"; +// Name of the StringAttr set on the ModuleOp to store meshes SDY uses. +inline constexpr llvm::StringLiteral kIfrtSdyMeshesRoundTripAttr = + "ifrt.sdy.meshes"; inline constexpr llvm::StringLiteral kCalleeMainFuncName = "main"; diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir index 22257730e01d5e..6a544019e15b44 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir @@ -32,8 +32,7 @@ module @call_hlo { #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> // CHECK-LABEL: @call_hlo_sdy_lowered module @call_hlo_sdy_lowered attributes { - mhlo.frontend_attributes = { - xla.sdy.meshes ="{mesh = #sdy.mesh<[\\\22x\\\22=2]>}"}} { + ifrt.sdy.meshes ="{mesh = #sdy.mesh<[\\\22x\\\22=2]>}"} { func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { // CHECK: ifrt.CallLoadedExecutable @fake_component__fake_method_1(%arg0) %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc index 322c03013558ac..556b5088ec9c76 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc @@ -116,12 +116,8 @@ void IfrtCompileAtomProgramPass::runOnOperation() { llvm::DenseMap call_to_compile_futures; mlir::ModuleOp module_op = getOperation(); - mlir::Attribute meshes_round_trip_attr; - // TODO: icgog - This attribute will be deleted in the IFRT -> VIFRT - // legalization. Fix in order to be able to use Sdy with VIFRT. - if (auto front_end_attr = xla::sdy::getFrontendAttrs(module_op)) { - meshes_round_trip_attr = front_end_attr.get(xla::sdy::kMeshesRoundTripAttr); - } + mlir::Attribute sdy_meshes_round_trip_attr = + module_op->getAttr(kIfrtSdyMeshesRoundTripAttr); // Stash the errors in a MapVector, which maintains the order in which they // are encountered. We do not emit an error within the walk because atom @@ -156,7 +152,7 @@ void IfrtCompileAtomProgramPass::runOnOperation() { if (call_op->hasAttr(kIsSdyPartitioned)) { // Add the meshes roundtrip attribute to the callee module if the // atom program was partitioned with sdy. 
-      if (!meshes_round_trip_attr) {
+      if (!sdy_meshes_round_trip_attr) {
         call_op_to_error.try_emplace(
             call_op,
             "requires meshes roundtrip attribute to be set on the "
@@ -166,7 +162,7 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
         }
         xla::sdy::setFrontendAttribute(
             callee_module, xla::sdy::kMeshesRoundTripAttr,
-            meshes_round_trip_attr, /*escapeAttr=*/false);
+            sdy_meshes_round_trip_attr, /*escapeAttr=*/false);
       }
 
       absl::StatusOr compile_future =

From c96bd328c99f5412378f8a7c0a68ff48e5b1a426 Mon Sep 17 00:00:00 2001
From: Bixia Zheng
Date: Tue, 8 Apr 2025 15:01:02 -0700
Subject: [PATCH 0399/1324] [xla:copy_insertion] Generalize the handling of
 pipelined Send/Recv to handle asynchronous operations that produce
 non-copyable results.

PiperOrigin-RevId: 745303985
---
 third_party/xla/xla/service/BUILD             |   2 +-
 third_party/xla/xla/service/copy_insertion.cc | 226 +++++++++++-------
 third_party/xla/xla/service/copy_insertion.h  |   6 +-
 3 files changed, 140 insertions(+), 94 deletions(-)

diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index 2d01ea4d6fa595..4241ca51b3da33 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -3952,10 +3952,10 @@ cc_library(
        "//xla/hlo/analysis:hlo_ordering",
        "//xla/hlo/analysis:hlo_reachability",
        "//xla/hlo/ir:hlo",
+        "//xla/hlo/ir:ptrvec",
        "//xla/hlo/pass:hlo_pass",
        "//xla/hlo/transforms/simplifiers:hlo_dce",
        "//xla/hlo/transforms/simplifiers:tuple_simplifier",
-        "//xla/hlo/utils:hlo_query",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc
index 6d8ec0ba5c3509..d2b7f1a73f00d6 100644
--- a/third_party/xla/xla/service/copy_insertion.cc
+++ b/third_party/xla/xla/service/copy_insertion.cc
@@ -49,6 +49,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/ptrvec.h"
 #include "xla/hlo/transforms/simplifiers/hlo_dce.h"
 #include "xla/hlo/transforms/simplifiers/tuple_simplifier.h"
 #include "xla/map_util.h"
@@ -191,20 +192,30 @@ DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
   return std::make_pair(from_deep_copy, to_deep_copy);
 }
 
-bool IsSendRecv(const HloInstruction* instruction) {
-  return instruction->opcode() == HloOpcode::kSend ||
-         instruction->opcode() == HloOpcode::kRecv;
-}
-
-bool IsSendRecvDone(const HloInstruction* instruction) {
-  return instruction->opcode() == HloOpcode::kSendDone ||
-         instruction->opcode() == HloOpcode::kRecvDone;
+// Returns true if the instruction produces non-copyable results.
+//
+// Currently, only asynchronous start ops produce non-copyable results, and
+// the whole result is non-copyable.
+bool IsNonCopyable(const HloInstruction* instruction) {
+  // Currently, the verifier only allows the pipelining of Send/Recv. As such,
+  // here we only handle the ops allowed by
+  // HloDataflowAnalysis::IsAsynchronousOperationStart that pass through their
+  // operand for now. For the ops that don't pass through their operand, we
+  // need to add a copy of the operand for the straight-line case in order to
+  // allow all ops in HloDataflowAnalysis::IsAsynchronousOperationStart.
+  HloOpcode opcode = instruction->opcode();
+  return opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv ||
+         opcode == HloOpcode::kCopyStart;
 }
 
-bool IsSendRecvInInit(const HloInstruction* init, const ShapeIndex& index) {
+// Returns true if the value at the given index in the while init is
+// non-copyable.
+bool IsNonCopyableInWhileInit(const HloInstruction* while_init,
+                              const ShapeIndex& index) {
   if (index.empty()) return false;
   int64_t i = index.front();
-  return i < init->operand_count() && IsSendRecv(init->operand(i));
+  return i < while_init->operand_count() &&
+         IsNonCopyable(while_init->operand(i));
 }
 
 // Compute the indices of the loop state which need copies in order to avoid
@@ -223,9 +234,9 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
   for (auto& pair : *indices_to_copy) {
     const ShapeIndex& index = pair.first;
     bool& should_copy = pair.second;
-    if (IsSendRecvInInit(init, index)) {
-      // Do not copy partially pipelined send/recv ops. The required copies will
-      // be inserted specifically for the send/recv ops.
+    if (IsNonCopyableInWhileInit(init, index)) {
+      // Do not copy non-copyable values; instead, we will add copies for
+      // transitioning into and out of non-copyable values.
       should_copy = false;
       continue;
     } else if (dataflow.GetValueSet(init, index).values().size() > 1 ||
@@ -2018,51 +2029,107 @@ absl::Status CopyInsertion::AddCopiesForConditional(
   return absl::OkStatus();
 }
 
-HloInstruction* FindAsyncSendRecvDoneInWhileBody(
-    const HloComputation* while_body, const HloInstruction* start_op) {
-  // Partially pipelined send/recv must have a single user.
-  if (start_op->user_count() != 1) return nullptr;
-  HloInstruction* unique_user = start_op->users().front();
-  // Send/recv must be consumed by send/recv-done op or be passed through the
-  // loop.
-  if (IsSendRecvDone(unique_user)) return unique_user;
+// If `chain_start` is the head of a chain of non-copyable ops inside a while
+// loop, and part of the chain is rotated to the next iteration, returns the
+// chain end in the rotated part. Otherwise, returns nullptr.
+HloInstruction* FindEndOpForRotatedNonCopyableChain(
+    const HloComputation* while_body, const HloInstruction* chain_start) {
+  // Non-copyable op must have a single user.
+  if (chain_start->user_count() != 1) return nullptr;
+  HloInstruction* unique_user = chain_start->users().front();
   if (unique_user->opcode() != HloOpcode::kTuple || !unique_user->IsRoot())
     return nullptr;
-  int64_t index = unique_user->operand_index(start_op);
+  int64_t index = unique_user->operand_index(chain_start);
   for (const HloInstruction* it :
       while_body->parameter_instruction(0)->users()) {
    const auto* gte = DynCast<HloGetTupleElementInstruction>(it);
    if (gte->tuple_index() == index) {
-      CHECK_EQ(gte->user_count(), 1) << "send/recv in next loop iteration must "
                                        "be consumed by unique send/recv-done.";
+      CHECK_EQ(gte->user_count(), 1)
+          << "non-copyable value in next loop iteration must "
+             "be consumed by unique instruction.";
      HloInstruction* next_unique_user = gte->users().front();
-      if (IsSendRecvDone(next_unique_user)) return next_unique_user;
+      if (HloDataflowAnalysis::IsAsynchronousOperationDone(
+              next_unique_user->opcode())) {
+        return next_unique_user;
+      }
+      break;
    }
  }
  return nullptr;
 }
 
-// Add copies for partially pipelined async send/recv. Copies are added before
-// starting to send and after finishing to recv. This is to prevent overlapping
-// live times of the buffers. The control flow edges from the added copy to the
-// recv or send-done operation guarantee disjoint live times of the buffers.
-// Note that we have anchor these control flow edges to the copies as the send
-// and recv-done ops are aliasing.
+// Adds copies for transitioning between copyable and non-copyable values for a
+// chain that starts with `chain_start` and whose rotated part, carried into
+// the next iteration, ends with `chain_end`.
+absl::Status AddCopiesForNonCopyableTransitionsRotatedCase(
+    HloInstruction* chain_start, HloInstruction* chain_end) {
+  HloComputation* while_body = chain_start->parent();
+  // Handle aliasing input for the op, where we transition from copyable to
+  // non-copyable.
+  if (!chain_start->operands().empty()) {
+    // A chain_start may have multiple operands, but we assume only the first
+    // operand is a buffer aliasing with the output, which is true currently.
+    HloInstruction* operand = chain_start->mutable_operand(0);
+    HloInstruction* copied_operand =
+        while_body->AddInstruction(HloInstruction::CreateUnary(
+            operand->shape(), HloOpcode::kCopy, operand));
+    TF_RETURN_IF_ERROR(operand->ReplaceUseWith(chain_start, copied_operand));
+    TF_RETURN_IF_ERROR(chain_end->AddControlDependencyTo(copied_operand));
+  }
+
+  // The chain_end is rotated and semantically paired with the chain_start of
+  // the previous iteration. We add a control dependency from the chain_end to
+  // the chain_start in the same lexical iteration to guarantee disjoint live
+  // times of the buffers involved.
+  TF_RETURN_IF_ERROR(chain_end->AddControlDependencyTo(chain_start));
+
+  // If chain_end has users, insert copies for the result produced by the
+  // chain_end with aliasing input and output buffers, where we transition from
+  // non-copyable to copyable.
+  PtrVec<HloInstruction*> users = chain_end->users();
+  if (users.empty()) return absl::OkStatus();
+
+  ShapeTree<HloInstruction*> copies_added(chain_end->shape());
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * copy,
+      while_body->DeepCopyInstruction(chain_end, /*indices_to_copy=*/nullptr,
+                                      &copies_added));
+  for (auto [shape_index, instr] : copies_added) {
+    if (instr != nullptr)
+      TF_RETURN_IF_ERROR(instr->AddControlDependencyTo(chain_start));
+  }
+  for (HloInstruction* it : users) {
+    TF_RETURN_IF_ERROR(chain_end->ReplaceUseWith(it, copy));
+  }
+  return absl::OkStatus();
+}
+
+// Adds the needed copies for transitioning into and out of non-copyable
+// values, to prevent overlapping live times of buffers. This is needed when
+// the unique user of the non-copyable op is rotated (also called pipelined) in
+// a while-loop. In particular, if a non-copyable op has an input aliasing with
+// its output, such as async Send, we make a copy of its input to transition
+// from copyable to non-copyable. If a non-copyable op's unique user produces
+// an output aliasing with its input, such as async Recv, we make a copy of the
+// output produced by the unique user, to transition out of non-copyable to
+// copyable. We also add control-flow edges between the copies and the
+// non-copyable op to guarantee disjoint live times of the buffers involved.
+//
+// Using async Send and Recv as examples, here is the transformation:
//
// Before:
//
// kParameter kParameter
// | |
// kSendDone kRecvDone (end of a non-copyable chain)
// |
// ... consumer
//
// producer ...
// | -// kSend kRecv -// | | -// (body root) (body root) +// kSend kRecv (start of a non-copyable op) +// | | +// (body root) (body root) // // // After: @@ -2080,73 +2147,53 @@ HloInstruction* FindAsyncSendRecvDoneInWhileBody( // | | // (body root) (body root) // -absl::Status CopyInsertion::AddCopiesForAsyncSendRecv( - const HloAliasAnalysis& alias_analysis, HloInstruction* start_op) { - // If start op has multiple users, this must be the synchronous use of - // send/recv. - // TODO(b/369589022): Disambiguate sync and async use of send/recv. - if (start_op->users().size() != 1) return absl::OkStatus(); +absl::Status CopyInsertion::AddCopiesForNonCopyableTransitions( + const HloAliasAnalysis& alias_analysis, HloInstruction* chain_start) { + if (chain_start->users().empty()) { + return absl::OkStatus(); + } + + // Currently non-copyable ops can have at most one user. + if (chain_start->users().size() != 1) { + return absl::InvalidArgumentError( + "Non-copyable op must have a single user."); + } + HloInstruction* unique_user = chain_start->users().front(); // If start feeds directly into done, the live time is contained and we don't // need to add any copies. - HloInstruction* unique_user = start_op->users().front(); - const HloOpcode done_opcode = start_op->opcode() == HloOpcode::kSend - ? HloOpcode::kSendDone - : HloOpcode::kRecvDone; - if (unique_user->opcode() == done_opcode) { + if (HloDataflowAnalysis::IsAsynchronousOperationDone(unique_user->opcode())) { return absl::OkStatus(); } - HloComputation* parent = start_op->parent(); - // If a Send is feeded into a pipelined while-loop, we need to make a copy - // of the Send operand and use it in the Send. - if (start_op->opcode() == HloOpcode::kSend && + HloComputation* parent = chain_start->parent(); + // If a start op with an operand is fed into a pipelined while-loop, we + // need to make a copy of the operand and use the copy in the start op. + if (chain_start->operand_count() > 0 && unique_user->opcode() == HloOpcode::kTuple && unique_user->users().size() == 1 && unique_user->users().front()->opcode() == HloOpcode::kWhile) { - HloInstruction* operand = start_op->mutable_operand(0); + HloInstruction* operand = chain_start->mutable_operand(0); HloInstruction* copied_operand = parent->AddInstruction(HloInstruction::CreateUnary( operand->shape(), HloOpcode::kCopy, operand)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(start_op, copied_operand)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(chain_start, copied_operand)); return absl::OkStatus(); } - // For other cases that send/recv are outside of the while loop, live times - // are disjoint. No copies are needed. + // For other cases where a non-copyable chain is outside of the while loop, + // live times are disjoint. No copies are needed. if (parent->caller_instructions(HloOpcode::kWhile).empty()) { return absl::OkStatus(); } - // Handle send case. - HloInstruction* done_op = FindAsyncSendRecvDoneInWhileBody(parent, start_op); - // TODO(b/369589022): Disambiguate sync and async use of send/recv. 
-  if (done_op == nullptr) return absl::OkStatus();
-  if (start_op->opcode() == HloOpcode::kSend) {
-    HloInstruction* operand = start_op->mutable_operand(0);
-    HloInstruction* copied_operand =
-        parent->AddInstruction(HloInstruction::CreateUnary(
-            operand->shape(), HloOpcode::kCopy, operand));
-    TF_RETURN_IF_ERROR(operand->ReplaceUseWith(start_op, copied_operand));
-    TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(copied_operand));
-    return absl::OkStatus();
-  }
+  // For async start ops, the end of the chain is the async done op.
+  HloInstruction* chain_end =
+      FindEndOpForRotatedNonCopyableChain(parent, chain_start);
+  if (chain_end)
+    return AddCopiesForNonCopyableTransitionsRotatedCase(chain_start,
+                                                         chain_end);
 
-  // Handle recv case.
-  CHECK_EQ(start_op->opcode(), HloOpcode::kRecv);
-  PtrVec<HloInstruction*> done_op_users = done_op->users();
-  ShapeTree<HloInstruction*> copies_added(done_op->shape());
-  TF_ASSIGN_OR_RETURN(HloInstruction * done_op_copy,
-                      parent->DeepCopyInstruction(
-                          done_op, /*indices_to_copy=*/nullptr, &copies_added));
-  for (auto [shape_index, instr] : copies_added) {
-    if (instr != nullptr)
-      TF_RETURN_IF_ERROR(instr->AddControlDependencyTo(start_op));
-  }
-  TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(start_op));
-  for (HloInstruction* it : done_op_users) {
-    TF_RETURN_IF_ERROR(done_op->ReplaceUseWith(it, done_op_copy));
-  }
   return absl::OkStatus();
 }
 
@@ -2170,10 +2217,9 @@ absl::Status CopyInsertion::AddCopiesToResolveInterference(
      } else if (instruction->opcode() == HloOpcode::kConditional) {
        TF_RETURN_IF_ERROR(
            AddCopiesForConditional(*alias_analysis, instruction));
-      } else if (IsSendRecv(instruction)) {
-        // TODO(b/371225893): Generalize this to all async collectives.
+      } else if (IsNonCopyable(instruction)) {
        TF_RETURN_IF_ERROR(
-            AddCopiesForAsyncSendRecv(*alias_analysis, instruction));
+            AddCopiesForNonCopyableTransitions(*alias_analysis, instruction));
      } else {
        // When an operand is a tuple, we avoid copying the operand multiple
        // times by recording and checking the operand number of operands that
diff --git a/third_party/xla/xla/service/copy_insertion.h b/third_party/xla/xla/service/copy_insertion.h
index 0b2ba86e3ef3eb..e580405341c87a 100644
--- a/third_party/xla/xla/service/copy_insertion.h
+++ b/third_party/xla/xla/service/copy_insertion.h
@@ -107,9 +107,9 @@ class CopyInsertion : public HloModulePass {
   virtual absl::Status AddCopiesForConditional(
       const HloAliasAnalysis& alias_analysis, HloInstruction* conditional);
 
-  // Add copies for async send/recv instructions.
-  absl::Status AddCopiesForAsyncSendRecv(const HloAliasAnalysis& alias_analysis,
-                                         HloInstruction* async);
+  // Adds copies for transitioning into and out of non-copyable values.
+  absl::Status AddCopiesForNonCopyableTransitions(
+      const HloAliasAnalysis& alias_analysis, HloInstruction* chain_start);
 
   // Backend specific function that decides whether an instruction can share
   // buffer with its operand.

From e6e9c35ac4556fa0ba91e1eaaf705b90cf036126 Mon Sep 17 00:00:00 2001
From: "A.
Unique TensorFlower" Date: Tue, 8 Apr 2025 15:02:56 -0700 Subject: [PATCH 0400/1324] Roll TF's XNNPACK version to include https://github.com/google/XNNPACK/commit/ece21c589be842fbeaee297b0d668194d6f3a35b PiperOrigin-RevId: 745304841 --- tensorflow/lite/tools/cmake/modules/xnnpack.cmake | 2 +- tensorflow/workspace2.bzl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index 8737699e3eaa5a..6a60e1e1f9fde9 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG 42ed90ba36f14321df08712e7a36713de5b2f29b + GIT_TAG ece21c589be842fbeaee297b0d668194d6f3a35b GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 7465f38d4df828..eab74a3f526060 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -160,9 +160,9 @@ def _tf_repositories(): # LINT.IfChange(xnnpack) tf_http_archive( name = "XNNPACK", - sha256 = "a7e47b12fb8beb0177fbd49c8dfcb842709b5a50cdf2f5bf5ec5e33d8244fcfa", - strip_prefix = "XNNPACK-42ed90ba36f14321df08712e7a36713de5b2f29b", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/42ed90ba36f14321df08712e7a36713de5b2f29b.zip"), + sha256 = "f25179a30775d9918670fb5fb07cd8e80c2ae0a8f4ec450a6d6c496d159ba66b", + strip_prefix = "XNNPACK-ece21c589be842fbeaee297b0d668194d6f3a35b", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/ece21c589be842fbeaee297b0d668194d6f3a35b.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) From 51b69c97e602b2bd53f2d1c4ab98eef662701782 Mon Sep 17 00:00:00 2001 From: Daniel Chen Date: Tue, 8 Apr 2025 15:10:22 -0700 Subject: [PATCH 0401/1324] Create a result summary for diff patterns. 
PiperOrigin-RevId: 745308067 --- third_party/xla/xla/hlo/tools/hlo_diff/BUILD | 1 + .../hlo/tools/hlo_diff/hlo_diff_eval_test.cc | 4 +- .../xla/hlo/tools/hlo_diff/hlo_diff_result.cc | 4 +- .../xla/hlo/tools/hlo_diff/hlo_diff_result.h | 7 +- .../hlo/tools/hlo_diff/hlo_diff_summary.cc | 92 +++++++++++++++++-- .../xla/hlo/tools/hlo_diff/hlo_diff_summary.h | 27 +++++- .../tools/hlo_diff/hlo_diff_summary_test.cc | 68 ++++++++------ .../xla/xla/hlo/tools/hlo_diff/render/BUILD | 2 + .../render/hlo_gumgraph_html_renderer.cc | 57 +++++++----- .../render/hlo_gumgraph_renderer_util.cc | 3 +- .../render/hlo_gumgraph_renderer_util.h | 4 +- .../render/hlo_gumgraph_renderer_util_test.cc | 5 +- .../render/hlo_gumgraph_text_renderer.cc | 2 +- 13 files changed, 201 insertions(+), 75 deletions(-) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/BUILD index 3c7639a89de630..a4798f00f51ebf 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_diff/BUILD @@ -37,6 +37,7 @@ cc_library( "//xla/hlo/tools/hlo_diff/graph/utils:hlo_gumgraph_bfs", "//xla/hlo/tools/hlo_diff/utils:hlo_diff_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", ], diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval_test.cc index c5aed19f50814c..36c5b58932d694 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval_test.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_eval_test.cc @@ -184,9 +184,9 @@ TEST_F(HloDiffTest, DiffSizeWorks) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph_r, HloGumgraph::Create(module_r.get())); DiffResult diff_result; - diff_result.left_module_unmatched_instructions.push_back( + diff_result.left_module_unmatched_instructions.insert( graph_l->GetRoot().instruction); - diff_result.right_module_unmatched_instructions.push_back( + diff_result.right_module_unmatched_instructions.insert( graph_r->GetRoot().instruction); diff_result.changed_instructions.insert( {graph_l->GetRoot().instruction, graph_r->GetRoot().instruction}); diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.cc index a6657831ea4289..1bae0144583d9b 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.cc @@ -63,7 +63,7 @@ std::unique_ptr ConstructDiffResult( } diff_result->node_props.insert({left_node->instruction, left_node->props}); if (!mappings.InstructionMapContainsLeft(left_node)) { - diff_result->left_module_unmatched_instructions.push_back( + diff_result->left_module_unmatched_instructions.insert( left_node->instruction); continue; } @@ -104,7 +104,7 @@ std::unique_ptr ConstructDiffResult( diff_result->node_props.insert( {right_node->instruction, right_node->props}); if (!mappings.InstructionMapContainsRight(right_node)) { - diff_result->right_module_unmatched_instructions.push_back( + diff_result->right_module_unmatched_instructions.insert( right_node->instruction); } } diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.h b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.h index 08a8e0d5ea2dcb..ff5ed8bf9d3f4a 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.h +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_result.h @@ -19,9 +19,9 @@ 
#include #include -#include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" #include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h" @@ -40,8 +40,9 @@ struct DiffResult { moved_instructions; // Unmatched instructions. - std::vector left_module_unmatched_instructions; - std::vector right_module_unmatched_instructions; + absl::flat_hash_set left_module_unmatched_instructions; + absl::flat_hash_set + right_module_unmatched_instructions; // Debug info. absl::flat_hash_map, diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc index b69fad4be806f5..a9803ff7eff4b4 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -210,6 +211,53 @@ FindConnectedComponents( return result; } +DiffMetrics GetDiffMetrics(const ComputationGroup& computation_group, + const DiffResult& diff_result) { + DiffMetrics result; + for (const HloComputation* computation : + computation_group.left_computations) { + for (const HloInstruction* instruction : computation->instructions()) { + if (diff_result.changed_instructions.contains(instruction)) { + ++result.changed_instruction_count; + } else if (diff_result.left_module_unmatched_instructions.contains( + instruction)) { + ++result.left_unmatched_instruction_count; + } + } + } + for (const HloComputation* computation : + computation_group.right_computations) { + for (const HloInstruction* instruction : computation->instructions()) { + if (diff_result.changed_instructions.contains(instruction)) { + ++result.changed_instruction_count; + } else if (diff_result.right_module_unmatched_instructions.contains( + instruction)) { + ++result.right_unmatched_instruction_count; + } + } + } + return result; +} + +std::vector FindComputationDiffPatterns( + const absl::flat_hash_map& + computation_summary, + const DiffResult& diff_result) { + std::vector result; + absl::flat_hash_map> + connected_components = FindConnectedComponents(computation_summary); + for (const auto& [fingerprint, computation_groups] : connected_components) { + ComputationDiffPattern diff_pattern; + diff_pattern.fingerprint = fingerprint; + diff_pattern.computation_groups = computation_groups; + diff_pattern.diff_metrics = + GetDiffMetrics(computation_groups[0], diff_result); + result.push_back(std::move(diff_pattern)); + } + return result; +} + +// Summarizes all computations in the given graph. absl::flat_hash_map SummarizeAllComputationsInGraph( const HloGumgraph& graph, const HloGumgraphMappings& mappings, @@ -265,24 +313,25 @@ std::unique_ptr ConstructDiffSummary( right_unmatched_instructions, ComputationMappingDirection::kRightToLeft)); // Group the computations by their diff fingerprint. - summary->grouped_computations = - FindConnectedComponents(summary->computation_summary); + summary->computation_diff_patterns = + FindComputationDiffPatterns(summary->computation_summary, diff_result); return summary; } void LogDiffSummary(const DiffSummary& diff_summary) { // Log the connected components repeated more than 3 times. 
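LogDiffSummary below only reports patterns whose computation group recurs at least three times. For callers that want those patterns programmatically rather than as log lines, a small helper in the same spirit might look like this (a hypothetical sketch, not part of the patch):

    #include <cstddef>
    #include <vector>
    #include "absl/types/span.h"

    std::vector<const ComputationDiffPattern*> RepeatedPatterns(
        absl::Span<const ComputationDiffPattern> patterns, size_t min_repeats) {
      std::vector<const ComputationDiffPattern*> repeated;
      for (const ComputationDiffPattern& pattern : patterns) {
        if (pattern.computation_groups.size() >= min_repeats) {
          repeated.push_back(&pattern);
        }
      }
      return repeated;
    }

    // e.g. RepeatedPatterns(diff_summary.computation_diff_patterns, 3)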
- LOG(INFO) << "Find Repeated Connected Components: "; - for (const auto& [fingerprint, computation_groups] : - diff_summary.grouped_computations) { - if (computation_groups.size() < 3) { + LOG(INFO) << "Find Repeated Diff Patterns: "; + for (const ComputationDiffPattern& diff_pattern : + diff_summary.computation_diff_patterns) { + if (diff_pattern.computation_groups.size() < 3) { continue; } - LOG(INFO) << computation_groups.size() - << " Repeated Connected Components Fingerprint: " << fingerprint; + LOG(INFO) << diff_pattern.computation_groups.size() + << " Repeated Diff Pattern Fingerprint: " + << diff_pattern.fingerprint; int i = 0; - for (const auto& computation_group : computation_groups) { + for (const auto& computation_group : diff_pattern.computation_groups) { ++i; std::string computations_str; for (const HloComputation* computation : @@ -304,5 +353,30 @@ void LogDiffSummary(const DiffSummary& diff_summary) { } } +void PrintTo(const ComputationDiffPattern& diff_pattern, std::ostream* os) { + *os << "{ fingerprint: " << diff_pattern.fingerprint; + for (const auto& computation_group : diff_pattern.computation_groups) { + *os << ", computation_groups: " + << "{ L: "; + for (const HloComputation* computation : + computation_group.left_computations) { + *os << absl::StrFormat("%s ", computation->name()); + } + *os << ", R: "; + for (const HloComputation* computation : + computation_group.right_computations) { + *os << absl::StrFormat("%s ", computation->name()); + } + *os << " }"; + } + *os << ", diff_metrics: {" + << diff_pattern.diff_metrics.changed_instruction_count << " changed, " + << diff_pattern.diff_metrics.left_unmatched_instruction_count + << " left unmatched, " + << diff_pattern.diff_metrics.right_unmatched_instruction_count + << " right unmatched }"; + *os << " }"; +} + } // namespace hlo_diff } // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h index a5dfed6c855ace..a68aecc8182fcb 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include "absl/container/flat_hash_map.h" @@ -68,11 +70,28 @@ struct ComputationGroup { std::vector right_computations; }; -// Summary of the diff result of the left and right HLO modules. +// Metrics of the diff pattern. +struct DiffMetrics { + int64_t changed_instruction_count = 0; + int64_t left_unmatched_instruction_count = 0; + int64_t right_unmatched_instruction_count = 0; +}; + +// A computation diff pattern is multiple groups of computations that have the +// same diff. +struct ComputationDiffPattern { + uint64_t fingerprint = 0; + std::vector computation_groups; + DiffMetrics diff_metrics; +}; + +// Teach the gunit to print the diff pattern. +void PrintTo(const ComputationDiffPattern& diff_pattern, std::ostream* os); + +// Summary of the diff result of the left and right HLO modules. struct DiffSummary { - // Connected computations grouped by fingerprint. - absl::flat_hash_map> - grouped_computations; + // The computation diff patterns found in the diff result. + std::vector computation_diff_patterns; // Summary of each computation. 
absl::flat_hash_map diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc index d1026c80b4ea5a..df9700938f434e 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary_test.cc @@ -267,14 +267,19 @@ TEST_F(HloDiffTest, ComputationDiffFingerprintWorks) { /*split_allegiance_instruction=*/0, /*diff_fingerprint=*/13464792036913846758U, /*all_unchanged=*/false)))); - EXPECT_THAT(diff_summary->grouped_computations, - UnorderedElementsAre(Pair( - 2864899211444957078U, + EXPECT_THAT(diff_summary->computation_diff_patterns, + UnorderedElementsAre(FieldsAre( + /*fingerprint=*/2864899211444957078U, + /*computation_groups=*/ UnorderedElementsAre(FieldsAre( /*left_computations=*/UnorderedElementsAre( Pointee(Property(&HloComputation::name, "entry"))), - /*right_computations=*/UnorderedElementsAre(Pointee( - Property(&HloComputation::name, "entry")))))))); + /*right_computations=*/UnorderedElementsAre( + Pointee(Property(&HloComputation::name, "entry"))))), + /*diff_metrics=*/ + FieldsAre(/*changed_instruction_count=*/0, + /*left_unmatched_instruction_count=*/2, + /*right_unmatched_instruction_count=*/2)))); } TEST_F(HloDiffTest, FindConnectedComponentsWorks) { @@ -340,28 +345,39 @@ TEST_F(HloDiffTest, FindConnectedComponentsWorks) { std::unique_ptr diff_summary = ConstructDiffSummary(*graph_l, *graph_r, *mappings, *diff_result); EXPECT_THAT( - diff_summary->grouped_computations, + diff_summary->computation_diff_patterns, UnorderedElementsAre( - Pair(2864899211444957078U, - UnorderedElementsAre( - FieldsAre(/*left_computations=*/UnorderedElementsAre( - Pointee(Property(&HloComputation::name, - "fused_computation.1"))), - /*right_computations=*/UnorderedElementsAre( - Pointee(Property(&HloComputation::name, - "fused_computation.2")))), - FieldsAre(/*left_computations=*/UnorderedElementsAre( - Pointee(Property(&HloComputation::name, - "fused_computation.2"))), - /*right_computations=*/UnorderedElementsAre( - Pointee(Property(&HloComputation::name, - "fused_computation.1")))))), - Pair(15473561031564762362U, - UnorderedElementsAre(FieldsAre( - /*left_computations=*/UnorderedElementsAre( - Pointee(Property(&HloComputation::name, "entry"))), - /*right_computations=*/UnorderedElementsAre( - Pointee(Property(&HloComputation::name, "entry")))))))); + FieldsAre( + /*fingerprint=*/2864899211444957078U, + /*computation_groups=*/ + UnorderedElementsAre( + FieldsAre(/*left_computations=*/UnorderedElementsAre( + Pointee(Property(&HloComputation::name, + "fused_computation.1"))), + /*right_computations=*/UnorderedElementsAre( + Pointee(Property(&HloComputation::name, + "fused_computation.2")))), + FieldsAre(/*left_computations=*/UnorderedElementsAre( + Pointee(Property(&HloComputation::name, + "fused_computation.2"))), + /*right_computations=*/UnorderedElementsAre( + Pointee(Property(&HloComputation::name, + "fused_computation.1"))))), + /*diff_metrics=*/ + FieldsAre(/*changed_instruction_count=*/0, + /*left_unmatched_instruction_count=*/2, + /*right_unmatched_instruction_count=*/2)), + FieldsAre(/*fingerprint=*/15473561031564762362U, + /*computation_groups=*/ + UnorderedElementsAre(FieldsAre( + /*left_computations=*/UnorderedElementsAre( + Pointee(Property(&HloComputation::name, "entry"))), + /*right_computations=*/UnorderedElementsAre(Pointee( + Property(&HloComputation::name, "entry"))))), + /*diff_metrics=*/ + 
FieldsAre(/*changed_instruction_count=*/0, + /*left_unmatched_instruction_count=*/6, + /*right_unmatched_instruction_count=*/6)))); } } // namespace diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD index 9168981c0a8cc6..b451a0a2615c71 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD @@ -22,6 +22,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:span", ], ) @@ -36,6 +37,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_googletest//:gtest_main", ], ) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc index f1e27bea457db2..36b3cf4da50053 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc @@ -341,7 +341,7 @@ std::string PrintInstructionPairsAsList( // Prints unmatched instructions grouped by opcode and print in a descending // order of the number of instructions for each opcode. std::string PrintUnmatchedInstructions( - absl::Span instructions, + const absl::flat_hash_set& instructions, InstructionLocation location, const absl::flat_hash_set& opcodes_to_ignore, bool name_only, GraphUrlGenerator* url_generator) { @@ -494,7 +494,7 @@ std::string PrintUnchangedInstructions( } std::string PrintUnmatchedMetricsDiff( - absl::Span instructions, + const absl::flat_hash_set& instructions, GetOpMetricFn get_op_metrics, GraphUrlGenerator* url_generator) { std::vector> sorted_metrics_diff; for (const HloInstruction* inst : instructions) { @@ -543,44 +543,45 @@ std::string PrintMatchedMetricsDiff( return PrintList(metrics_diff_list); } -// Summarize a diff group. -std::string SummarizeDiffGroup( - absl::Span computation_groups) { - if (computation_groups.size() > 1) { +// Summarize a diff pattern. +std::string SummarizeDiffPattern(const ComputationDiffPattern& diff_pattern) { + if (diff_pattern.computation_groups.size() > 1) { return absl::StrFormat("Summarized %d computations with the same diff", - computation_groups.size()); + diff_pattern.computation_groups.size()); } return "A single computation has unique diff"; } -// Prints the summary of the repetitive computation groups. -std::string PrintRepetitiveComputationGroups(const DiffSummary& diff_summary, - GraphUrlGenerator* url_generator) { - // Sort the computation groups by the number of computations in each group in +// Prints the summary of the repetitive diff patterns. +std::string PrintRepetitiveDiffPatterns( + absl::Span diff_patterns, + GraphUrlGenerator* url_generator) { + // Sort the diff patterns by the number of computations in each group in // descending order. 
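One design note on the rewritten sort below: it copies every ComputationDiffPattern, vectors included, into sorted_diff_patterns before ordering. A pointer-based variant, sketched here under the same types, would avoid those copies; the copy-based version the patch keeps is simpler and fine at typical pattern counts:

    #include <algorithm>
    #include <vector>
    #include "absl/types/span.h"

    std::vector<const ComputationDiffPattern*> SortByRepetition(
        absl::Span<const ComputationDiffPattern> diff_patterns) {
      std::vector<const ComputationDiffPattern*> sorted;
      sorted.reserve(diff_patterns.size());
      for (const ComputationDiffPattern& pattern : diff_patterns) {
        sorted.push_back(&pattern);
      }
      // Most-repeated pattern first, matching the ordering used below.
      std::sort(sorted.begin(), sorted.end(),
                [](const ComputationDiffPattern* a,
                   const ComputationDiffPattern* b) {
                  return a->computation_groups.size() >
                         b->computation_groups.size();
                });
      return sorted;
    }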
- std::vector> sorted_computation_groups; - for (const auto& [_, computation_groups] : - diff_summary.grouped_computations) { - sorted_computation_groups.push_back(computation_groups); + std::vector sorted_diff_patterns; + for (const ComputationDiffPattern& diff_pattern : diff_patterns) { + sorted_diff_patterns.push_back(diff_pattern); } std::sort( - sorted_computation_groups.begin(), sorted_computation_groups.end(), - [](absl::Span a, - absl::Span b) { return a.size() > b.size(); }); + sorted_diff_patterns.begin(), sorted_diff_patterns.end(), + [](const ComputationDiffPattern& a, const ComputationDiffPattern& b) { + return a.computation_groups.size() > b.computation_groups.size(); + }); std::string computation_group_list; int i = 0; - for (const auto& computation_groups : sorted_computation_groups) { - if (computation_groups.empty()) { + for (const auto& diff_pattern : sorted_diff_patterns) { + if (diff_pattern.computation_groups.empty()) { continue; } - const ComputationGroup& sample = computation_groups[0]; + const ComputationGroup& sample = diff_pattern.computation_groups[0]; // We only print the one-to-one mapping for now. if (sample.left_computations.size() != 1 || sample.right_computations.size() != 1) { continue; } std::vector computation_pair_list; - for (const ComputationGroup& computation_group : computation_groups) { + for (const ComputationGroup& computation_group : + diff_pattern.computation_groups) { if (computation_group.left_computations.size() != 1 || computation_group.right_computations.size() != 1) { continue; @@ -596,7 +597,7 @@ std::string PrintRepetitiveComputationGroups(const DiffSummary& diff_summary, &computation_group_list, PrintDetails( absl::StrFormat("Group %d: %s (Sample: %s → %s)", ++i, - SummarizeDiffGroup(computation_groups), + SummarizeDiffPattern(diff_pattern), sample.left_computations[0]->name(), sample.right_computations[0]->name()), PrintAttributesList( @@ -604,6 +605,13 @@ std::string PrintRepetitiveComputationGroups(const DiffSummary& diff_summary, "Instruction count: %d → %d", sample.left_computations[0]->instruction_count(), sample.right_computations[0]->instruction_count()), + absl::StrFormat( + "Diff summary: %d changed, %d left unmatched, %d right " + "unmatched", + diff_pattern.diff_metrics.changed_instruction_count, + diff_pattern.diff_metrics.left_unmatched_instruction_count, + diff_pattern.diff_metrics + .right_unmatched_instruction_count), PrintDetails("Instances", PrintList(computation_pair_list))}))); } @@ -669,7 +677,8 @@ void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary, // Print repetitive computation groups out << PrintSectionWithHeader( "Group of computations with the same diff", - PrintRepetitiveComputationGroups(diff_summary, url_generator)); + PrintRepetitiveDiffPatterns(diff_summary.computation_diff_patterns, + url_generator)); } } // namespace hlo_diff diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.cc index d33d2099d49ee8..fb7df45a0e181d 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.cc @@ -20,6 +20,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -111,7 +112,7 @@ std::string GetChangedInstructionDiffTypeString( 
absl::flat_hash_map> GroupInstructionsByOpcode( - absl::Span instructions) { + const absl::flat_hash_set& instructions) { absl::flat_hash_map> instructions_by_opcode; for (const HloInstruction* inst : instructions) { diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h index 01609c53f5b302..91d3ed79b802ab 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h @@ -24,6 +24,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -62,7 +63,8 @@ inline constexpr auto kIgnoredOpcodes = std::array( // Groups the instructions by opcode. absl::flat_hash_map> -GroupInstructionsByOpcode(absl::Span instructions); +GroupInstructionsByOpcode( + const absl::flat_hash_set& instructions); // Groups the instruction pairs by opcode. absl::flat_hash_map< diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util_test.cc index f5d7e60061dcec..2306646662668d 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util_test.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util_test.cc @@ -20,6 +20,7 @@ #include #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/verified_hlo_module.h" @@ -48,10 +49,10 @@ ENTRY test_computation { ROOT sub = s32[10] subtract(add, param2) } )")); - std::vector instructions; + absl::flat_hash_set instructions; for (const HloComputation* computation : module->computations()) { for (const HloInstruction* instruction : computation->instructions()) { - instructions.push_back(instruction); + instructions.insert(instruction); } } diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.cc index ee27ea2bd36aab..dc03fbe4b04e42 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_text_renderer.cc @@ -43,7 +43,7 @@ namespace { // be printed. void PrintUnmatchedInstructions( const absl::string_view header, - absl::Span instructions, + const absl::flat_hash_set& instructions, std::ostringstream& out, const RenderTextOptions& options) { out << header; if (options.top_n_opcodes >= 0) { From 840accdf9bc4323b4f3edcc4c6a60f339fc4f6e5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 15:13:06 -0700 Subject: [PATCH 0402/1324] Add `collect_symlink_data_aspect` to search for the symlinked files in the target runfiles. PiperOrigin-RevId: 745309073 --- .../xla/third_party/py/python_wheel.bzl | 61 +++++++++++++++---- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/third_party/xla/third_party/py/python_wheel.bzl b/third_party/xla/third_party/py/python_wheel.bzl index ad33f44c3cc7d9..5bb045321d1439 100644 --- a/third_party/xla/third_party/py/python_wheel.bzl +++ b/third_party/xla/third_party/py/python_wheel.bzl @@ -114,14 +114,14 @@ Examples: --repo_env=ML_WHEEL_BUILD_DATE=20250107 2. 
release wheel version: 2.19.0 Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release -3. release candidate wheel version: 2.19.0-rc1 +3. release candidate wheel version: 2.19.0rc1 Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release - --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1 -4. custom wheel version: 2.19.0.dev20250107+cbe478fc5-custom + --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1 +4. custom wheel version: 2.19.0.dev20250107+cbe478fc5custom Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) - --repo_env=ML_WHEEL_VERSION_SUFFIX=-custom + --repo_env=ML_WHEEL_VERSION_SUFFIX=custom 5. snapshot wheel version: 2.19.0.dev0+selfbuilt Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=snapshot @@ -168,11 +168,12 @@ def _collect_data_aspect_impl(_, ctx): if hasattr(ctx.rule.attr, "data"): for data in ctx.rule.attr.data: for f in data.files.to_list(): - if not any([f.path.endswith(ext) for ext in extensions]): + if not f.owner.package: continue - if "pypi" in f.path: - continue - files[f] = True + for ext in extensions: + if f.extension == ext: + files[f] = True + break if hasattr(ctx.rule.attr, "deps"): for dep in ctx.rule.attr.deps: @@ -187,17 +188,50 @@ collect_data_aspect = aspect( attr_aspects = ["deps"], attrs = { "_extensions": attr.string_list( - default = [".so", ".pyd", ".pyi", ".dll", ".dylib", ".lib", ".pd"], + default = ["so", "pyd", "pyi", "dll", "dylib", "lib", "pd"], + ), + }, +) + +def _collect_symlink_data_aspect_impl(_, ctx): + files = {} + symlink_extensions = ctx.attr._symlink_extensions + if not hasattr(ctx.rule.attr, "deps"): + return [FilePathInfo(files = depset(files.keys()))] + for dep in ctx.rule.attr.deps: + if not (dep[DefaultInfo].default_runfiles and + dep[DefaultInfo].default_runfiles.files): + continue + for file in dep[DefaultInfo].default_runfiles.files.to_list(): + if not file.owner.package: + continue + for ext in symlink_extensions: + if file.extension == ext: + files[file] = True + break + + return [FilePathInfo(files = depset(files.keys()))] + +collect_symlink_data_aspect = aspect( + implementation = _collect_symlink_data_aspect_impl, + attr_aspects = ["symlink_deps"], + attrs = { + "_symlink_extensions": attr.string_list( + default = ["pyi", "lib", "pd"], ), }, ) def _collect_data_files_impl(ctx): - files = [] + files = {} for dep in ctx.attr.deps: - files.extend((dep[FilePathInfo].files.to_list())) + for f in dep[FilePathInfo].files.to_list(): + files[f] = True + for symlink_dep in ctx.attr.symlink_deps: + for f in symlink_dep[FilePathInfo].files.to_list(): + files[f] = True return [DefaultInfo(files = depset( - files, + files.keys(), ))] collect_data_files = rule( @@ -206,6 +240,9 @@ collect_data_files = rule( "deps": attr.label_list( aspects = [collect_data_aspect], ), + "symlink_deps": attr.label_list( + aspects = [collect_symlink_data_aspect], + ), }, ) From 01ee40f5dcb41590f83afe357776fc941b61a867 Mon Sep 17 00:00:00 2001 From: jparkerh Date: Tue, 8 Apr 2025 15:37:25 -0700 Subject: [PATCH 0403/1324] add serialization test for PjRtCompiler we previously tested the case of client->Compile and serialize, but we're missing a test case for the compiler->Compile and serialize, which is relied upon by Pathways. adding a quick test for this case. 
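Condensed, the round-trip the new test exercises (CUDA topology and MLIR module setup elided; this mirrors the code added below rather than introducing anything new):

    // Compile via the compiler, not client->Compile, then round-trip.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<PjRtExecutable> executable,
        compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology,
                         client.get()));
    TF_ASSERT_OK_AND_ASSIGN(std::string serialized,
                            executable->SerializeExecutable());
    ASSERT_FALSE(serialized.empty());
    // Deserializing through the client must yield a runnable executable.
    TF_ASSERT_OK_AND_ASSIGN(auto loaded, client->LoadSerializedExecutable(
                                             serialized, std::nullopt,
                                             xla::LoadOptions()));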
PiperOrigin-RevId: 745317367 --- third_party/xla/xla/pjrt/gpu/BUILD | 2 +- .../xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc | 67 +++++++++++++++++-- 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index d5db0f22fdbfe8..b4b0367d1e8b40 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -426,6 +426,7 @@ xla_test( "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/tests:literal_test_util", "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -434,7 +435,6 @@ xla_test( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc index d7f27f7e6b0c95..464e775cbb4dad 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "xla/pjrt/gpu/se_gpu_pjrt_compiler.h" #include +#include +#include #include #include @@ -43,12 +45,15 @@ limitations under the License. #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { using ::tsl::testing::StatusIs; +constexpr absl::string_view kFakeDeviceName = "Fake_device"; + constexpr absl::string_view kProgram = R"(HloModule Computation ENTRY Computation() -> s32[] { @@ -83,7 +88,8 @@ std::shared_ptr GetGpuTopology( TEST(StreamExecutorGpuCompilerTest, NoClientXla) { StreamExecutorGpuCompiler compiler; StreamExecutorGpuTopologyDescription topology( - CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2, 10)); + CudaId(), CudaName(), + GetGpuTopology({0, 1}, kFakeDeviceName, 1, 1, 2, 10)); TF_ASSERT_OK_AND_ASSIGN(auto computation, GetXlaComputation(kProgram)); EXPECT_THAT(compiler.Compile(xla::CompileOptions(), computation, topology, @@ -94,7 +100,8 @@ TEST(StreamExecutorGpuCompilerTest, NoClientXla) { TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) { StreamExecutorGpuCompiler compiler; StreamExecutorGpuTopologyDescription topology( - CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2, 10)); + CudaId(), CudaName(), + GetGpuTopology({0, 1}, kFakeDeviceName, 1, 1, 2, 10)); TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -118,8 +125,9 @@ TEST(StreamExecutorGpuCompilerTest, SuccessXla) { TF_ASSERT_OK_AND_ASSIGN(auto loaded_executable, client->Load(std::move(executable), load_options)); - TF_ASSERT_OK_AND_ASSIGN( - auto result, loaded_executable->Execute(/*argument_handles=*/{{}}, {})); + TF_ASSERT_OK_AND_ASSIGN(auto result, + loaded_executable->Execute( + /*argument_handles=*/{{}}, /*options=*/{})); ASSERT_EQ(result.size(), 1); std::vector>& result_buffers = result[0]; @@ -140,7 +148,8 @@ TEST(StreamExecutorGpuCompilerTest, NoClientMlir) { mlir::parseSourceString(mlir_str, &context); StreamExecutorGpuTopologyDescription topology( - CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2, 10)); + CudaId(), CudaName(), + GetGpuTopology({0, 1}, kFakeDeviceName, 1, 1, 2, 10)); EXPECT_THAT( 
compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology, @@ -158,7 +167,8 @@ TEST(StreamExecutorGpuCompilerTest, TopologyNotSameMlir) { mlir::parseSourceString(mlir_str, &context); StreamExecutorGpuTopologyDescription topology( - CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2, 10)); + CudaId(), CudaName(), + GetGpuTopology({0, 1}, kFakeDeviceName, 1, 1, 2, 10)); TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -187,8 +197,51 @@ TEST(StreamExecutorGpuCompilerTest, SuccessMlir) { TF_ASSERT_OK_AND_ASSIGN(auto loaded_executable, client->Load(std::move(executable), load_options)); + TF_ASSERT_OK_AND_ASSIGN(auto result, + loaded_executable->Execute( + /*argument_handles=*/{{}}, /*options=*/{})); + + ASSERT_EQ(result.size(), 1); + std::vector>& result_buffers = result[0]; + ASSERT_EQ(result_buffers.size(), 1); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr result_literal, + result_buffers[0]->ToLiteralSync()); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR0(2), *result_literal)); +} + +TEST(StreamExecutorGpuCompilerTest, SuccessMlirCanBeSerialized) { + StreamExecutorGpuCompiler compiler; + + mlir::MLIRContext context; + context.loadDialect(); + + auto mlir_module = + mlir::parseSourceString(mlir_str, &context); + + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + + StreamExecutorGpuTopologyDescription topology( + CudaId(), CudaName(), + GetGpuTopology({0, 1}, kFakeDeviceName, 1, 1, 2, 10)); + TF_ASSERT_OK_AND_ASSIGN( - auto result, loaded_executable->Execute(/*argument_handles=*/{{}}, {})); + std::unique_ptr executable, + compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology, + client.get())); + + TF_ASSERT_OK_AND_ASSIGN(std::string serialized, + executable->SerializeExecutable()); + ASSERT_FALSE(serialized.empty()); + + TF_ASSERT_OK_AND_ASSIGN(auto loaded_executable_from_serialized, + client->LoadSerializedExecutable( + serialized, std::nullopt, xla::LoadOptions())); + + TF_ASSERT_OK_AND_ASSIGN(auto result, + loaded_executable_from_serialized->Execute( + /*argument_handles=*/{{}}, /*options=*/{})); ASSERT_EQ(result.size(), 1); std::vector>& result_buffers = result[0]; From 96b4c85c6ef12cce6b1871c4122e8e26292c911d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 15:47:21 -0700 Subject: [PATCH 0404/1324] Fix test of __ANDROID__ All other preprocessor tests of __ANDROID__ test if it is defined. Do that here too. PiperOrigin-RevId: 745321103 --- tensorflow/lite/core/api/tensor_utils.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/core/api/tensor_utils.cc b/tensorflow/lite/core/api/tensor_utils.cc index 18a643c78dc272..c5052c78f840cd 100644 --- a/tensorflow/lite/core/api/tensor_utils.cc +++ b/tensorflow/lite/core/api/tensor_utils.cc @@ -33,8 +33,8 @@ TfLiteStatus ResetVariableTensor(TfLiteTensor* tensor) { } // TODO(b/139446230): Provide a platform header to better handle these // specific scenarios. -#if __ANDROID__ || defined(__x86_64__) || defined(__i386__) || \ - defined(__i386) || defined(__x86__) || defined(__X86__) || \ +#if defined(__ANDROID__) || defined(__x86_64__) || defined(__i386__) || \ + defined(__i386) || defined(__x86__) || defined(__X86__) || \ defined(_X86_) || defined(_M_IX86) || defined(_M_X64) memset(tensor->data.raw, value, tensor->bytes); #else From 749598dbebe157190e7bf1a20adcbee0c6986671 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 8 Apr 2025 17:06:45 -0700 Subject: [PATCH 0405/1324] Computing merged HloReplication objects for shapes with many replicas or devices can be slow, wrap calls to HloReplication::Merge using a cache to avoid re-computing. PiperOrigin-RevId: 745347064 --- .../hlo/analysis/hlo_replication_analysis.cc | 68 ++++++++++++------- .../hlo/analysis/hlo_replication_analysis.h | 34 +++++++++- 2 files changed, 77 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc index dfa5f7584b2d47..cb8da1d1a2cf75 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc @@ -109,20 +109,26 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( bool support_partial_replication, const absl::flat_hash_map*>& - replica_group_dedup_map) { - const auto merge_operand_replication = [&hlo_replication]( - const HloInstruction* inst) { - HloReplication replication = HloReplication::ReplicatedOnAllDevices(); - for (auto operand : inst->operands()) { - auto operand_it = hlo_replication.find(operand); - if (operand_it == hlo_replication.end()) { - replication = replication.Merge(HloReplication::UniqueOnAllDevices()); - } else { - replication = replication.Merge(operand_it->second.element({})); - } - } - return replication; - }; + replica_group_dedup_map, + absl::flat_hash_map, + HloReplication>& replication_merge_map) { + const auto merge_operand_replication = + [&hlo_replication, &replication_merge_map](const HloInstruction* inst) { + HloReplication replication = HloReplication::ReplicatedOnAllDevices(); + for (auto operand : inst->operands()) { + auto operand_it = hlo_replication.find(operand); + if (operand_it == hlo_replication.end()) { + replication = MergeReplications( + replication, HloReplication::UniqueOnAllDevices(), + replication_merge_map); + } else { + replication = + MergeReplications(replication, operand_it->second.element({}), + replication_merge_map); + } + } + return replication; + }; auto calculate_all_reduce_all_gather_replication = [&](const HloInstruction* hlo) { @@ -325,15 +331,15 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( return true; } bool updated = false; - it->second.ForEachMutableElement( - [&](const ShapeIndex& index, HloReplication* element) { - HloReplication new_replication = - element->Merge(to_combine.element(index)); - if (!element->Equal(new_replication)) { - *element = std::move(new_replication); - updated = true; - } - }); + it->second.ForEachMutableElement([&](const ShapeIndex& index, + HloReplication* element) { + HloReplication new_replication = MergeReplications( + *element, to_combine.element(index), replication_merge_map_); + if (!element->Equal(new_replication)) { + *element = std::move(new_replication); + updated = true; + } + }); return updated; }; // Assigns or combines source's shape tree to dest. 
Returns if anything is @@ -472,7 +478,8 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( *shape_tree.mutable_element(index) = DetermineHloInstructionIsReplicated( inst, index, cross_partition_spmd_, hlo_replication_, - support_partial_replication_, replica_group_dedup_map_); + support_partial_replication_, replica_group_dedup_map_, + replication_merge_map_); }); changed |= assign_or_combine_shapetree(std::move(shape_tree), inst); } @@ -762,6 +769,14 @@ HloReplicationAnalysis::HloReplication::Merge( } } +HloReplicationAnalysis::HloReplication::HloReplication( + const std::pair& merge_pair) { + auto merged_replication = merge_pair.first.Merge(merge_pair.second); + state_ = merged_replication.state_; + device_set_root_per_replica_ = + std::move(merged_replication.device_set_root_per_replica_); +} + bool HloReplicationAnalysis::HloReplication::Equal( const HloReplication& other) const { if (state_ != other.state_) { @@ -777,6 +792,11 @@ bool HloReplicationAnalysis::HloReplication::Equal( return true; } +bool HloReplicationAnalysis::HloReplication::operator==( + const HloReplicationAnalysis::HloReplication& rhs) const { + return Equal(rhs); +} + bool HloReplicationAnalysis::HloReplication::IsReplicatedOnAllDevices() const { return state_ == State::kReplicatedOnAllDevices; } diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h index 80bc62eb27a098..5a9057ebf89cab 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h @@ -16,7 +16,10 @@ limitations under the License. #ifndef XLA_HLO_ANALYSIS_HLO_REPLICATION_ANALYSIS_H_ #define XLA_HLO_ANALYSIS_HLO_REPLICATION_ANALYSIS_H_ +#include #include +#include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -76,14 +79,25 @@ class HloReplicationAnalysis { HloReplication(); HloReplication(const HloReplication& other) = default; HloReplication(HloReplication&& other) = default; + // Create a new HloReplication that is the merge of two other HloReplication + // objects using the Merge() method, useful for lazy construction with + // try_emplace. + explicit HloReplication( + const std::pair& merge_pair); HloReplication& operator=(HloReplication&& other) = default; HloReplication Merge(const HloReplication& other) const; bool Equal(const HloReplication& other) const; + bool operator==(const HloReplication& rhs) const; bool IsReplicatedOnAllDevices() const; bool IsUniqueOnAllDevices() const; bool IsReplicatedWithinSubgroup(absl::Span device_ids) const; std::string ToString() const; + template + friend H AbslHashValue(H h, const HloReplication& r) { + return H::combine(std::move(h), r.state_, r.device_set_root_per_replica_); + } + private: enum class State { kReplicatedOnAllDevices = 0, @@ -114,7 +128,23 @@ class HloReplicationAnalysis { bool support_partial_replication, const absl::flat_hash_map*>& - replica_group_dedup_map); + replica_group_dedup_map, + absl::flat_hash_map, + HloReplication>& replication_merge_map); + + static HloReplication MergeReplications( + const HloReplication& replication_a, const HloReplication& replication_b, + absl::flat_hash_map, + HloReplication>& replication_merge_map) { + std::pair key = {replication_a, + replication_b}; + + // Look replication pair up in map: if not found we pass the pair to an + // overloaded constructor of HloReplication which constructs and returns + // a merged HloReplication. 
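The try_emplace call that follows carries the whole caching trick: the mapped HloReplication is only constructed, through the new pair-taking constructor above that performs the Merge, when the key is not already present. A standalone model of the idiom with simplified types (the real map keys on a pair of HloReplication values):

    #include <utility>
    #include "absl/container/flat_hash_map.h"

    struct Merged {
      // Stands in for the Merge() done by HloReplication's new constructor.
      explicit Merged(const std::pair<int, int>& key)
          : value(key.first + key.second) {}
      int value;
    };

    int CachedMerge(int a, int b,
                    absl::flat_hash_map<std::pair<int, int>, Merged>& cache) {
      // try_emplace constructs Merged from the key only on a cache miss;
      // repeated (a, b) pairs reuse the cached result without re-merging.
      auto [iter, inserted] = cache.try_emplace(std::make_pair(a, b),
                                                std::make_pair(a, b));
      (void)inserted;
      return iter->second.value;
    }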
+ auto [iter, inserted] = replication_merge_map.try_emplace(key, key); + return iter->second; + } HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd, const absl::flat_hash_set* @@ -170,6 +200,8 @@ class HloReplicationAnalysis { // replications for instructions with identical replica groups. absl::flat_hash_map*> replica_group_dedup_map_; + absl::flat_hash_map, HloReplication> + replication_merge_map_; std::vector> unique_replications_; }; From aed1f9d5c4abbba48f492903c90f010fb36d9145 Mon Sep 17 00:00:00 2001 From: Laura Pak Date: Tue, 8 Apr 2025 17:11:53 -0700 Subject: [PATCH 0406/1324] Fork fake_quant_utils within TF to use tf_quantization_lib (fork using TF Quant Dialect) PiperOrigin-RevId: 745348683 --- .../mlir/quantization/tensorflow/utils/BUILD | 19 +++ .../tensorflow/utils/temp_fake_quant_utils.cc | 73 ++++++++ .../tensorflow/utils/temp_fake_quant_utils.h | 160 ++++++++++++++++++ 3 files changed, 252 insertions(+) create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index fcd42b88cc30c9..dba0230f5d6d1f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -23,6 +23,25 @@ cc_library( ], ) +cc_library( + name = "temp_fake_quant_utils", + srcs = ["temp_fake_quant_utils.cc"], + hdrs = [ + "temp_fake_quant_utils.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "tf_quantize_op_utils", srcs = ["tf_quantize_op_utils.cc"], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc new file mode 100644 index 00000000000000..bcde1612898a17 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc @@ -0,0 +1,73 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Copied and modified from +// //third_party/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc +#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h" + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" + +namespace mlir { +namespace tf_quant { + +// Three instances of the rule to cover the three different types of +// TF::FakeQuant operators +using PreparePerTensorFakeQuant = ConvertFakeQuantOpToQuantOps< + TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false, + FetchConstantMinMaxInputs>; + +using PreparePerChannelFakeQuant = ConvertFakeQuantOpToQuantOps< + TF::FakeQuantWithMinMaxVarsPerChannelOp, /*PerAxis=*/true, + FetchConstantMinMaxInputs>; + +using PreparePerTensorFakeQuantWithMinMaxArgs = ConvertFakeQuantOpToQuantOps< + TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false, + FetchMinMaxAttrs>; + +// Removes the wrapper of the tf.FakeQuant* ops and creates the quant.qcast +// and quant.dcast pairs before tf.FakeQuant* ops are being foled. +LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext* ctx, + bool use_fake_quant_num_bits) { + OpBuilder builder(func); + + // Insert the quant.qcast/quant.dcast ops in place of the tf.FakeQuant* ops to + // preserve the quantization parameters. + func.walk([&](Operation* op) { + if (auto fake_quant = llvm::dyn_cast(op)) { + (void)PreparePerTensorFakeQuantWithMinMaxArgs(use_fake_quant_num_bits) + .matchAndRewrite(fake_quant, builder); + } else if (auto fake_quant = + llvm::dyn_cast(op)) { + (void)PreparePerTensorFakeQuant(use_fake_quant_num_bits) + .matchAndRewrite(fake_quant, builder); + } else if (auto fake_quant = + llvm::dyn_cast( + op)) { + (void)PreparePerChannelFakeQuant(use_fake_quant_num_bits) + .matchAndRewrite(fake_quant, builder); + } + }); + + return success(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h new file mode 100644 index 00000000000000..84119aa38b4a66 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h @@ -0,0 +1,160 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TF-Quant transformation +// passes to work with tf.FakeQuant* ops. 
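A hedged usage sketch for the ConvertFakeQuantOps entry point implemented above; the pass plumbing shown here is an assumption for illustration and is not part of this patch:

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Support/LogicalResult.h"
    #include "tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h"

    // Runs the QDQ conversion over every function in a module, before any
    // constant folding can fold the tf.FakeQuant* ops away.
    mlir::LogicalResult ConvertAllFakeQuantOps(mlir::ModuleOp module) {
      for (auto func : module.getOps<mlir::func::FuncOp>()) {
        if (mlir::failed(mlir::tf_quant::ConvertFakeQuantOps(
                func, module.getContext(), /*use_fake_quant_num_bits=*/false))) {
          return mlir::failure();
        }
      }
      return mlir::success();
    }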
Copied and modified from +// //third_party/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TEMP_FAKE_QUANT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TEMP_FAKE_QUANT_UTILS_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" + +namespace mlir { +namespace tf_quant { + +template +struct FetchMinMaxAttrs { + using AttrType = FloatAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + min_value = tf_op.getMinAttr(); + max_value = tf_op.getMaxAttr(); + return true; // Successfully matched and fetched. + } +}; + +template +struct FetchConstantMinMaxInputs { + using AttrType = DenseFPElementsAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + Value min = tf_op.getMin(), max = tf_op.getMax(); + if (auto min_id = min.getDefiningOp()) { + min = min_id.getInput(); + } + if (auto max_id = max.getDefiningOp()) { + max = max_id.getInput(); + } + + if (!matchPattern(min, m_Constant(&min_value))) { + return false; + } + if (!matchPattern(max, m_Constant(&max_value))) { + return false; + } + return true; // Successfully matched and fetched. + } +}; + +// Inserts a "quant.qcast" and "quant.dcast" op pair (QDQs) in place of the +// tf.FakeQyantWithMinMax{Vars|VarsPerChannel|Args}Op +// before the op being constant folded. Since the constant +// folding logic will use a "arith.constant" op to replace the +// "tf.FakeQuantWithMinMaxVarsOp", the "quant.qcast" op is used to preserve +// the quantization parameters as a TypeAttr and "quant.dcast" op used to +// convert the output type to the next op. Here are the transformations: +// +// input min cst max cst input +// \ | | | +// \ (tf.Identity) (tf.Identity) => quant.qcast +// \ | | | +// tf.FakeQuantWithMinMaxVars quant.dcast +// | | +// +// Warns if the (most likely unwanted, currently not quite correctly handled) +// case of back-to-back tf.FakeQuant occurs +// +// tf.FakeQuant* +// | +// tf.FakeQuant* +// +template +class ConvertFakeQuantOpToQuantOps { + public: + explicit ConvertFakeQuantOpToQuantOps(bool use_fake_quant_num_bits) + : use_fake_quant_num_bits_(use_fake_quant_num_bits) {} + + FetchMinMax fetch_min_max_; + + using FetchAttrType = typename FetchMinMax::AttrType; + LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, + OpBuilder &rewriter) const { + if (tf_op.getNumBits() != 8) { + return failure(); + } + + // Extract the min/max constant values from the operands. We also consider + // a special case that there are tf.Identity ops between the min/max + // constants and the tf.FakeQuantWithMinMaxVarsOp. 
+ FetchAttrType min_value, max_value; + if (!fetch_min_max_(tf_op, min_value, max_value)) { + return failure(); + } + + Value input = tf_op.getInputs(); + int quant_dim = -1; + auto input_type = mlir::cast(input.getType()); + if (PerAxis) { + if (!input_type.hasRank()) { + tf_op.emitError("The input should have known rank for per-channel op."); + return failure(); + } + // This is a special case that the quant_dim is the last dimensions. + quant_dim = input_type.getRank() - 1; + } + // Use the min/max from the operands and the num_bits and narrow_range + // attribute to create the quantization parameter for the new quantize op. + rewriter.setInsertionPointAfter(tf_op.getOperation()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); + BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); + Type res_type = tf_op.getType(); + TypeAttr qtype = tf_quant::GetQuantizedTypeAttr( + rewriter, input_type, min_value, max_value, quant_dim, num_bits, + narrow_range, /*is_signed=*/true, /*legacy_float_scale=*/false, + use_fake_quant_num_bits_); + if (!qtype) { + return failure(); + } + + // Finally, use the quantization parameter to create the quantize and + // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp + // and its users. + auto quantize = rewriter.create( + tf_op.getLoc(), qtype.getValue(), input); + auto dequantize = rewriter.create( + tf_op.getLoc(), res_type, quantize.getResult()); + tf_op.getOutputs().replaceAllUsesWith(dequantize); + + return success(); + } + + bool use_fake_quant_num_bits_; +}; + +// Removes the wrapper of the tf.FakeQuant* ops and creates the quant.qcast +// and quant.dcast pairs before tf.FakeQuant* ops are being folded. +LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext *ctx, + bool use_fake_quant_num_bits); + +} // namespace tf_quant +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TEMP_FAKE_QUANT_UTILS_H_ From 2a9943a53d88f66d31fb6c12e4d98f53f23aecef Mon Sep 17 00:00:00 2001 From: Jun Jiang Date: Tue, 8 Apr 2025 17:26:55 -0700 Subject: [PATCH 0407/1324] Fix build error in //tensorflow/lite/ios:TensorFlowLiteSelectTfOps_framework PiperOrigin-RevId: 745352608 --- tensorflow/core/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index acab945bf8c5ca..c3d6b934ea232f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1038,6 +1038,7 @@ cc_library( "//tensorflow/core/public:release_version", "//tensorflow/core/util:onednn_env_vars", "//tensorflow/core/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:safe_reinterpret_cast", ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(), alwayslink = 1, ) From 18054263da68aca100b67347e6050bfe3755b409 Mon Sep 17 00:00:00 2001 From: Tongfei Guo Date: Tue, 8 Apr 2025 18:05:34 -0700 Subject: [PATCH 0408/1324] [XLA] Check array operator() access is within bounds. PiperOrigin-RevId: 745363282 --- third_party/xla/xla/array.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h index 0bec1540e95f48..fbe6cb944b1adb 100644 --- a/third_party/xla/xla/array.h +++ b/third_party/xla/xla/array.h @@ -365,6 +365,9 @@ class Array { const T&>::type operator()(Dims... dims) const { CHECK_EQ(sizeof...(dims), num_dimensions()); + // Check each index is within the bounds of the array. 
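The bounds check added on the next lines uses a C++17 comma fold over an immediately-invoked lambda, which runs the lambda body once per element of the dims pack, left to right, with i tracking the current dimension. A standalone model of the idiom (hypothetical helper, assert standing in for DCHECK_LT; not XLA code):

    #include <cassert>

    template <typename... Dims>
    void CheckAllInBounds(const long long* sizes, Dims... dims) {
      long long i = 0;
      // Comma fold: one lambda invocation per pack element, in order.
      ([&] { assert(dims < sizes[i++]); }(), ...);
    }

    // long long sizes[] = {4, 5, 6};
    // CheckAllInBounds(sizes, 3, 4, 5);  // all in bounds
    // CheckAllInBounds(sizes, 3, 9, 5);  // fires on the second index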
+ int64_t i = 0; + ([&] { DCHECK_LT(dims, sizes_[i++]); }(), ...); // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{ From 7506b90231ac25b1685d571540ff45a51faa56c8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 18:11:13 -0700 Subject: [PATCH 0409/1324] Integrate LLVM at llvm/llvm-project@cd54cb062bba Updates LLVM usage to match [cd54cb062bba](https://github.com/llvm/llvm-project/commit/cd54cb062bba) PiperOrigin-RevId: 745364812 --- .../convert_control_to_data_outputs.mlir | 7 +- .../convert_control_to_data_outputs.cc | 4 +- third_party/llvm/generated.patch | 769 +++++++++++++- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 944 +++++++++++++++++- third_party/shardy/workspace.bzl | 4 +- .../triton/llvm_integration/cl744822685.patch | 13 + .../triton/llvm_integration/series.bzl | 1 + .../xla/third_party/shardy/temporary.patch | 944 +++++++++++++++++- .../xla/third_party/shardy/workspace.bzl | 4 +- .../triton/llvm_integration/cl744822685.patch | 13 + .../triton/llvm_integration/series.bzl | 1 + .../deallocation/transforms/buffer_reuse.cc | 4 +- 13 files changed, 2666 insertions(+), 46 deletions(-) create mode 100644 third_party/triton/llvm_integration/cl744822685.patch create mode 100644 third_party/xla/third_party/triton/llvm_integration/cl744822685.patch diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir index 5f59e35498151e..abff7aeb61a2d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir @@ -656,7 +656,6 @@ func.func @incomplete_composite_devices_while_body(%arg0: !tf_res, %arg1: !tf_re %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () // CHECK: [[exe]]{{.*}}"tf.Identity" - // CHECK-NOT: "tf.Identity" // CHECK: tf_executor.fetch tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control } @@ -816,11 +815,11 @@ func.func @tpu_execute_with_non_resource_operands(%arg0: !tf_res {tf._composite_ func.func @double_tpu_execute_while_body(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (!tf_res, !tf_res, tensor) { - // CHECK: "tf.Identity" %graph:3 = tf_executor.graph { // CHECK: {{.*}}, [[ctrl1:%.*]] = tf_executor.island wraps "tf.Identity" // CHECK: {{.*}}, [[ctrl2:%.*]] = tf_executor.island wraps "tf.Identity" // CHECK: "tf.Identity" + // CHECK: "tf.Identity" %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str // CHECK: [[exe_ctrl1:%.*]] = tf_executor.island([[ctrl1]]) wraps "tf.TPUExecuteAndUpdateVariables" %exe_control1 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { @@ -887,9 +886,9 @@ func.func @tpu_executes_on_same_device_while_body(%arg0: !tf_res, %arg1: !tf_res %arg2: tensor) -> (!tf_res, !tf_res, tensor) { %graph:3 = tf_executor.graph { - // CHECK: "tf.Identity" // CHECK: {{.*}}, [[id_ctrl:%.*]] = tf_executor.island wraps "tf.Identity" // CHECK: "tf.Identity" + // CHECK: "tf.Identity" %key, %key_control = tf_executor.island wraps "tf.Const"() {value = 
dense<"">: !tf_str} : () -> !tf_str // CHECK: [[exe_ctrl1:%.*]] = tf_executor.island([[id_ctrl]]) wraps "tf.TPUExecuteAndUpdateVariables" %exe_control1 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { @@ -911,8 +910,8 @@ func.func @tpu_executes_on_same_device_while_body(%arg0: !tf_res, %arg1: !tf_res %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control1, %exe_control2) wraps "tf.NoOp"() : () -> () - // CHECK: "tf.Identity"(%arg3) // CHECK: tf_executor.island([[exe_ctrl1]], [[exe_ctrl2]]) wraps "tf.Identity" + // CHECK: "tf.Identity"(%arg4) // CHECK: "tf.Identity"(%arg5) // CHECK-NEXT: tf_executor.fetch tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index dd2f4b7309bf48..d63ace094451a6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -424,7 +424,7 @@ void ChainResourceOps( for (auto class_iter = resource_equivalence_classes.begin(); class_iter != resource_equivalence_classes.end(); ++class_iter) { // Only visit one element per class, the leader. - if (!class_iter->isLeader()) continue; + if (!(*class_iter)->isLeader()) continue; // Create chain source and sink identity islands for current equivalence // class. @@ -445,7 +445,7 @@ void ChainResourceOps( // by `class_iter`). Keep track of ops that have already been processed. llvm::SmallDenseSet processed_ops; for (auto member_iter = - resource_equivalence_classes.member_begin(*class_iter); + resource_equivalence_classes.member_begin(**class_iter); member_iter != resource_equivalence_classes.member_end(); ++member_iter) { ResourceAndDevice resource_and_device = *member_iter; diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 2e6ff5801f349f..97282ecaf5c6ac 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,23 +1,748 @@ Auto generated patch. Do not edit or delete it, even if empty. 
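The convert_control_to_data_outputs.cc hunk above tracks an llvm::EquivalenceClasses API change in this LLVM integrate: class iterators now yield pointers to the ECValue, so isLeader() takes an extra dereference and member_begin() receives the dereferenced class. A minimal standalone sketch of the new iteration idiom follows; it assumes the LLVM headers at the pinned commit cd54cb062bba and is not part of any patch in this series.

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::EquivalenceClasses<int> ec;
  ec.unionSets(1, 2);  // class {1, 2}
  ec.unionSets(2, 3);  // class {1, 2, 3}
  ec.unionSets(7, 8);  // class {7, 8}
  for (auto it = ec.begin(), end = ec.end(); it != end; ++it) {
    // Class iterators now dereference to a pointer, hence (*it)->isLeader().
    if (!(*it)->isLeader()) continue;  // visit each class exactly once
    // member_begin() now takes the ECValue itself, hence the **it.
    for (auto member = ec.member_begin(**it); member != ec.member_end();
         ++member)
      llvm::outs() << *member << ' ';
    llvm::outs() << '\n';
  }
  return 0;
}

The same migration appears again below in shardy's sharding_group_import.cc.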
-diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp ---- a/clang/lib/Sema/SemaCXXScopeSpec.cpp -+++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp -@@ -873,6 +873,7 @@ - DependentTemplateSpecializationTypeLoc SpecTL - = Builder.push(T); - SpecTL.setElaboratedKeywordLoc(SourceLocation()); -+ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); - SpecTL.setTemplateKeywordLoc(TemplateKWLoc); - SpecTL.setTemplateNameLoc(TemplateNameLoc); - SpecTL.setLAngleLoc(LAngleLoc); -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -@@ -1902,7 +1902,6 @@ - name = "inv_trigf_utils", - srcs = ["src/math/generic/inv_trigf_utils.cpp"], - hdrs = [ -- "src/math/generic/atan_utils.h", - "src/math/generic/inv_trigf_utils.h", - ], - deps = [ +diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp +--- a/clang/lib/AST/ASTContext.cpp ++++ b/clang/lib/AST/ASTContext.cpp +@@ -7011,7 +7011,7 @@ + getCanonicalTemplateArgument(subst->getArgumentPack()); + return getSubstTemplateTemplateParmPack( + canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), +- subst->getFinal(), subst->getIndex()); ++ subst->getIndex(), subst->getFinal()); + } + case TemplateName::DeducedTemplate: { + assert(IgnoreDeduced == false); +diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h +--- a/clang/lib/Sema/TreeTransform.h ++++ b/clang/lib/Sema/TreeTransform.h +@@ -7765,17 +7765,23 @@ + NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); + NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); + +- typedef TemplateArgumentLocContainerIterator< +- DependentTemplateSpecializationTypeLoc> ArgIterator; +- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), +- ArgIterator(TL, TL.getNumArgs()), +- NewTemplateArgs)) ++ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); ++ ++ if (getDerived().TransformTemplateArguments(ArgsRange.begin(), ++ ArgsRange.end(), NewTemplateArgs)) + return QualType(); ++ bool TemplateArgumentsChanged = !llvm::equal( ++ ArgsRange, NewTemplateArgs.arguments(), ++ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { ++ return A.getArgument().structurallyEquals(B.getArgument()); ++ }); + + const DependentTemplateStorage &DTN = T->getDependentTemplateName(); + + QualType Result = TL.getType(); +- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { ++ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || ++ TemplateArgumentsChanged) { + TemplateName Name = getDerived().RebuildTemplateName( + SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), + /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, +diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp +--- a/clang/lib/Serialization/ASTReaderStmt.cpp ++++ b/clang/lib/Serialization/ASTReaderStmt.cpp +@@ -2229,6 +2229,7 @@ + E->PackIndex = Record.readInt(); + else + E->PackIndex = 0; ++ E->Final = CurrentUnpackingBits->getNextBit(); + E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); + E->Replacement = Record.readSubExpr(); + } +diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp +--- 
a/clang/lib/Serialization/ASTWriterStmt.cpp ++++ b/clang/lib/Serialization/ASTWriterStmt.cpp +@@ -2229,6 +2229,7 @@ + CurrentPackingBits.addBit((bool)E->getPackIndex()); + if (auto PackIndex = E->getPackIndex()) + Record.push_back(*PackIndex + 1); ++ CurrentPackingBits.addBit(E->getFinal()); + + Record.AddSourceLocation(E->getNameLoc()); + Record.AddStmt(E->getReplacement()); +diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h +--- a/clang/test/CodeGen/include/cuda.h ++++ b/clang/test/CodeGen/include/cuda.h +@@ -1,194 +0,0 @@ +-/* Minimal declarations for CUDA support. Testing purposes only. +- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h +- */ +-#pragma once +- +-// Make this file work with nvcc, for testing compatibility. +- +-#ifndef __NVCC__ +-#define __constant__ __attribute__((constant)) +-#define __device__ __attribute__((device)) +-#define __global__ __attribute__((global)) +-#define __host__ __attribute__((host)) +-#define __shared__ __attribute__((shared)) +-#define __managed__ __attribute__((managed)) +-#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) +- +-struct dim3 { +- unsigned x, y, z; +- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} +-}; +- +-// Host- and device-side placement new overloads. +-void *operator new(__SIZE_TYPE__, void *p) { return p; } +-void *operator new[](__SIZE_TYPE__, void *p) { return p; } +-__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } +-__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } +- +-#define CUDA_VERSION 10100 +- +-struct char1 { +- char x; +- __host__ __device__ char1(char x = 0) : x(x) {} +-}; +-struct char2 { +- char x, y; +- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} +-}; +-struct char4 { +- char x, y, z, w; +- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} +-}; +- +-struct uchar1 { +- unsigned char x; +- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} +-}; +-struct uchar2 { +- unsigned char x, y; +- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} +-}; +-struct uchar4 { +- unsigned char x, y, z, w; +- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} +-}; +- +-struct short1 { +- short x; +- __host__ __device__ short1(short x = 0) : x(x) {} +-}; +-struct short2 { +- short x, y; +- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} +-}; +-struct short4 { +- short x, y, z, w; +- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} +-}; +- +-struct ushort1 { +- unsigned short x; +- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} +-}; +-struct ushort2 { +- unsigned short x, y; +- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} +-}; +-struct ushort4 { +- unsigned short x, y, z, w; +- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} +-}; +- +-struct int1 { +- int x; +- __host__ __device__ int1(int x = 0) : x(x) {} +-}; +-struct int2 { +- int x, y; +- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} +-}; +-struct int4 { +- int x, y, z, w; +- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), 
y(y), z(z), w(w) {} +-}; +- +-struct uint1 { +- unsigned x; +- __host__ __device__ uint1(unsigned x = 0) : x(x) {} +-}; +-struct uint2 { +- unsigned x, y; +- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} +-}; +-struct uint3 { +- unsigned x, y, z; +- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} +-}; +-struct uint4 { +- unsigned x, y, z, w; +- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} +-}; +- +-struct longlong1 { +- long long x; +- __host__ __device__ longlong1(long long x = 0) : x(x) {} +-}; +-struct longlong2 { +- long long x, y; +- __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} +-}; +-struct longlong4 { +- long long x, y, z, w; +- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} +-}; +- +-struct ulonglong1 { +- unsigned long long x; +- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} +-}; +-struct ulonglong2 { +- unsigned long long x, y; +- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} +-}; +-struct ulonglong4 { +- unsigned long long x, y, z, w; +- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} +-}; +- +-struct float1 { +- float x; +- __host__ __device__ float1(float x = 0) : x(x) {} +-}; +-struct float2 { +- float x, y; +- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} +-}; +-struct float4 { +- float x, y, z, w; +- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} +-}; +- +-struct double1 { +- double x; +- __host__ __device__ double1(double x = 0) : x(x) {} +-}; +-struct double2 { +- double x, y; +- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} +-}; +-struct double4 { +- double x, y, z, w; +- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} +-}; +- +-typedef unsigned long long cudaTextureObject_t; +-typedef unsigned long long cudaSurfaceObject_t; +- +-enum cudaTextureReadMode { +- cudaReadModeNormalizedFloat, +- cudaReadModeElementType +-}; +- +-enum cudaSurfaceBoundaryMode { +- cudaBoundaryModeZero, +- cudaBoundaryModeClamp, +- cudaBoundaryModeTrap +-}; +- +-enum { +- cudaTextureType1D, +- cudaTextureType2D, +- cudaTextureType3D, +- cudaTextureTypeCubemap, +- cudaTextureType1DLayered, +- cudaTextureType2DLayered, +- cudaTextureTypeCubemapLayered +-}; +- +-struct textureReference {}; +-template +-struct __attribute__((device_builtin_texture_type)) texture +- : public textureReference {}; +- +-#endif // !__NVCC__ +diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h +--- a/clang/test/CodeGen/Inputs/cuda.h ++++ b/clang/test/CodeGen/Inputs/cuda.h +@@ -0,0 +1,194 @@ ++/* Minimal declarations for CUDA support. Testing purposes only. ++ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h ++ */ ++#pragma once ++ ++// Make this file work with nvcc, for testing compatibility. 
++ ++#ifndef __NVCC__ ++#define __constant__ __attribute__((constant)) ++#define __device__ __attribute__((device)) ++#define __global__ __attribute__((global)) ++#define __host__ __attribute__((host)) ++#define __shared__ __attribute__((shared)) ++#define __managed__ __attribute__((managed)) ++#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) ++ ++struct dim3 { ++ unsigned x, y, z; ++ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} ++}; ++ ++// Host- and device-side placement new overloads. ++void *operator new(__SIZE_TYPE__, void *p) { return p; } ++void *operator new[](__SIZE_TYPE__, void *p) { return p; } ++__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } ++__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } ++ ++#define CUDA_VERSION 10100 ++ ++struct char1 { ++ char x; ++ __host__ __device__ char1(char x = 0) : x(x) {} ++}; ++struct char2 { ++ char x, y; ++ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} ++}; ++struct char4 { ++ char x, y, z, w; ++ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++struct uchar1 { ++ unsigned char x; ++ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} ++}; ++struct uchar2 { ++ unsigned char x, y; ++ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} ++}; ++struct uchar4 { ++ unsigned char x, y, z, w; ++ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++struct short1 { ++ short x; ++ __host__ __device__ short1(short x = 0) : x(x) {} ++}; ++struct short2 { ++ short x, y; ++ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} ++}; ++struct short4 { ++ short x, y, z, w; ++ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++struct ushort1 { ++ unsigned short x; ++ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} ++}; ++struct ushort2 { ++ unsigned short x, y; ++ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} ++}; ++struct ushort4 { ++ unsigned short x, y, z, w; ++ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++struct int1 { ++ int x; ++ __host__ __device__ int1(int x = 0) : x(x) {} ++}; ++struct int2 { ++ int x, y; ++ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} ++}; ++struct int4 { ++ int x, y, z, w; ++ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++struct uint1 { ++ unsigned x; ++ __host__ __device__ uint1(unsigned x = 0) : x(x) {} ++}; ++struct uint2 { ++ unsigned x, y; ++ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} ++}; ++struct uint3 { ++ unsigned x, y, z; ++ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} ++}; ++struct uint4 { ++ unsigned x, y, z, w; ++ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++struct longlong1 { ++ long long x; ++ __host__ __device__ longlong1(long long x = 0) : x(x) {} ++}; ++struct longlong2 { ++ long long x, y; ++ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} ++}; ++struct longlong4 { ++ long long x, 
y, z, w; ++ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++struct ulonglong1 { ++ unsigned long long x; ++ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} ++}; ++struct ulonglong2 { ++ unsigned long long x, y; ++ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} ++}; ++struct ulonglong4 { ++ unsigned long long x, y, z, w; ++ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++struct float1 { ++ float x; ++ __host__ __device__ float1(float x = 0) : x(x) {} ++}; ++struct float2 { ++ float x, y; ++ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} ++}; ++struct float4 { ++ float x, y, z, w; ++ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++struct double1 { ++ double x; ++ __host__ __device__ double1(double x = 0) : x(x) {} ++}; ++struct double2 { ++ double x, y; ++ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} ++}; ++struct double4 { ++ double x, y, z, w; ++ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} ++}; ++ ++typedef unsigned long long cudaTextureObject_t; ++typedef unsigned long long cudaSurfaceObject_t; ++ ++enum cudaTextureReadMode { ++ cudaReadModeNormalizedFloat, ++ cudaReadModeElementType ++}; ++ ++enum cudaSurfaceBoundaryMode { ++ cudaBoundaryModeZero, ++ cudaBoundaryModeClamp, ++ cudaBoundaryModeTrap ++}; ++ ++enum { ++ cudaTextureType1D, ++ cudaTextureType2D, ++ cudaTextureType3D, ++ cudaTextureTypeCubemap, ++ cudaTextureType1DLayered, ++ cudaTextureType2DLayered, ++ cudaTextureTypeCubemapLayered ++}; ++ ++struct textureReference {}; ++template ++struct __attribute__((device_builtin_texture_type)) texture ++ : public textureReference {}; ++ ++#endif // !__NVCC__ +diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu +--- a/clang/test/CodeGen/nvptx-surface.cu ++++ b/clang/test/CodeGen/nvptx-surface.cu +@@ -1,6 +1,6 @@ + // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s + // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s +-#include "include/cuda.h" ++#include "Inputs/cuda.h" + + #include "__clang_cuda_texture_intrinsics.h" + +diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp +--- a/clang/test/SemaTemplate/dependent-names.cpp ++++ b/clang/test/SemaTemplate/dependent-names.cpp +@@ -458,3 +458,12 @@ + }; + int f(b ba) { return ba.add<0>(); } + } ++ ++namespace TransformDependentTemplates { ++ template struct Test1 { ++ template ++ using Arg = typename T::template Arg; ++ void f(Arg); ++ void f(Arg); ++ }; ++} // namespace TransformDependentTemplates +diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +@@ -15391,12 +15391,20 @@ + + if (E->State == TreeEntry::SplitVectorize) { + Res = FindLastInst(); ++ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { ++ for (auto *E : Entries) { ++ auto *I = 
dyn_cast_or_null(E->VectorizedValue); ++ if (!I) ++ I = &getLastInstructionInBundle(E); ++ if (Res->comesBefore(I)) ++ Res = I; ++ } ++ } + return *Res; + } + + // Set insertpoint for gathered loads to the very first load. +- if (E->State != TreeEntry::SplitVectorize && +- GatheredLoadsEntriesFirst.has_value() && ++ if (GatheredLoadsEntriesFirst.has_value() && + E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && + E->getOpcode() == Instruction::Load) { + Res = FindFirstInst(); +diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp ++++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +@@ -2590,6 +2590,14 @@ + if (R.mayWriteToMemory() && !InterleaveR) + return; + ++ // Do not narrow interleave groups if there are VectorPointer recipes and ++ // the plan was unrolled. The recipe implicitly uses VF from ++ // VPTransformState. ++ // TODO: Remove restriction once the VF for the VectorPointer offset is ++ // modeled explicitly as operand. ++ if (isa(&R) && Plan.getUF() > 1) ++ return; ++ + // All other ops are allowed, but we reject uses that cannot be converted + // when checking all allowed consumers (store interleave groups) below. + if (!InterleaveR) +diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +--- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll ++++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +@@ -66,3 +66,91 @@ + exit: + ret void + } ++ ++define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { ++; CHECK-LABEL: define void @test_2xi64_with_wide_load( ++; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { ++; CHECK-NEXT: [[ENTRY:.*]]: ++; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] ++; CHECK: [[VECTOR_PH]]: ++; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] ++; CHECK: [[VECTOR_BODY]]: ++; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] ++; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 ++; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] ++; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 ++; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 ++; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 ++; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 ++; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 [[INDEX]], 1 ++; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 ++; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] ++; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] ++; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP8]], align 8 ++; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> ++; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> ++; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 ++; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> ++; CHECK-NEXT: 
[[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> ++; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] ++; CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] ++; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] ++; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] ++; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> ++; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> ++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 ++; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> ++; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> ++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 ++; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 ++; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 ++; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] ++; CHECK: [[MIDDLE_BLOCK]]: ++; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]] ++; CHECK: [[SCALAR_PH]]: ++; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] ++; CHECK-NEXT: br label %[[LOOP:.*]] ++; CHECK: [[LOOP]]: ++; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] ++; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] ++; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 ++; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 ++; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] ++; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 ++; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] ++; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 ++; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 ++; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] ++; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 ++; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] ++; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 ++; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 ++; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 ++; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] ++; CHECK: [[EXIT]]: ++; CHECK-NEXT: ret void ++; ++entry: ++ br label %loop ++ ++loop: ++ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] ++ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv ++ %l.factor = load i64, ptr %arrayidx, align 8 ++ %1 = shl nsw i64 %iv, 1 ++ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 ++ %l.0 = load i64, ptr %data.0, align 8 ++ %mul.0 = mul i64 %l.factor, %l.0 ++ store i64 %mul.0, ptr %data.0, align 8 ++ %3 = or disjoint i64 %1, 1 ++ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 ++ %l.1 = load i64, ptr %data.1, align 8 ++ %mul.1 = mul i64 %l.factor, %l.1 ++ store i64 %mul.1, ptr %data.1, align 8 ++ %iv.next = add nuw nsw i64 %iv, 1 ++ %ec = icmp eq i64 %iv.next, 100 ++ br i1 %ec, label %exit, label %loop ++ ++exit: ++ ret void ++} +diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll 
b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +--- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll ++++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +@@ -0,0 +1,99 @@ ++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 ++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s ++ ++define void @test(ptr %0, <8 x i8> %1) { ++; CHECK-LABEL: define void @test( ++; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { ++; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 ++; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 ++; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 ++; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 ++; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 ++; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> ++; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 ++; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> ++; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> ++; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) ++; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 ++; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> ++; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) ++; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) ++; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] ++; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 ++; CHECK-NEXT: ret void ++; ++ %3 = load i8, ptr %0, align 2 ++ %4 = getelementptr i8, ptr %0, i64 13442 ++ %5 = load i8, ptr %4, align 2 ++ %6 = or i8 %5, %3 ++ %7 = getelementptr i8, ptr %0, i64 13550 ++ store i8 %6, ptr %7, align 2 ++ %8 = extractelement <8 x i8> %1, i64 0 ++ %9 = or i8 %5, %8 ++ %10 = getelementptr i8, ptr %0, i64 13542 ++ store i8 %9, ptr %10, align 2 ++ %11 = getelementptr i8, ptr %0, i64 13438 ++ %12 = load i8, ptr %11, align 2 ++ %13 = or i8 %12, %3 ++ %14 = getelementptr i8, ptr %0, i64 13546 ++ store i8 %13, ptr %14, align 2 ++ %15 = extractelement <8 x i8> %1, i64 2 ++ %16 = or i8 %12, %15 ++ %17 = getelementptr i8, ptr %0, i64 13538 ++ store i8 %16, ptr %17, align 2 ++ %18 = getelementptr i8, ptr %0, i64 13440 ++ %19 = load i8, ptr %18, align 4 ++ %20 = or i8 %19, %3 ++ %21 = getelementptr i8, ptr %0, i64 13548 ++ store i8 %20, ptr %21, align 4 ++ %22 = extractelement <8 x i8> %1, i64 4 ++ %23 = or i8 %19, %22 ++ %24 = getelementptr i8, ptr %0, i64 13540 ++ store i8 %23, ptr %24, align 4 ++ %25 = getelementptr i8, ptr %0, i64 13436 ++ %26 = load i8, ptr %25, align 4 ++ %27 = getelementptr i8, ptr %0, i64 13444 ++ %28 = load i8, ptr %27, align 4 ++ %29 = or i8 %28, %26 ++ %30 = getelementptr i8, ptr %0, i64 13544 ++ store i8 %29, ptr %30, align 4 ++ %31 = or i8 %26, %8 ++ %32 = getelementptr i8, ptr %0, i64 13536 ++ store i8 %31, ptr %32, align 4 ++ %33 = getelementptr i8, ptr %0, i64 13443 ++ %34 = load i8, ptr %33, align 1 ++ %35 = or i8 %34, %3 ++ %36 = getelementptr i8, ptr %0, i64 13551 ++ store i8 %35, ptr %36, align 1 ++ %37 = 
extractelement <8 x i8> %1, i64 7 ++ %38 = or i8 %34, %37 ++ %39 = getelementptr i8, ptr %0, i64 13543 ++ store i8 %38, ptr %39, align 1 ++ %40 = getelementptr i8, ptr %0, i64 13439 ++ %41 = load i8, ptr %40, align 1 ++ %42 = or i8 %41, %3 ++ %43 = getelementptr i8, ptr %0, i64 13547 ++ store i8 %42, ptr %43, align 1 ++ %44 = extractelement <8 x i8> %1, i64 3 ++ %45 = or i8 %41, %44 ++ %46 = getelementptr i8, ptr %0, i64 13539 ++ store i8 %45, ptr %46, align 1 ++ %47 = getelementptr i8, ptr %0, i64 13441 ++ %48 = load i8, ptr %47, align 1 ++ %49 = or i8 %48, %3 ++ %50 = getelementptr i8, ptr %0, i64 13549 ++ store i8 %49, ptr %50, align 1 ++ %51 = extractelement <8 x i8> %1, i64 5 ++ %52 = or i8 %48, %51 ++ %53 = getelementptr i8, ptr %0, i64 13541 ++ store i8 %52, ptr %53, align 1 ++ %54 = getelementptr i8, ptr %0, i64 13437 ++ %55 = load i8, ptr %54, align 1 ++ %56 = or i8 %55, %3 ++ %57 = getelementptr i8, ptr %0, i64 13545 ++ store i8 %56, ptr %57, align 1 ++ %58 = or i8 %55, %8 ++ %59 = getelementptr i8, ptr %0, i64 13537 ++ store i8 %58, ptr %59, align 1 ++ ret void ++} +diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +--- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp ++++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +@@ -151,9 +151,10 @@ + MachineModuleInfoWrapperPass *MMIWP = + new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); + +- legacy::PassManager PassMgrF; + SmallString<1024> Buf; + llvm::raw_svector_ostream OS(Buf); ++ legacy::PassManager PassMgrF; ++ + AsmPrinter *Printer = + addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); + PassMgrF.run(*M); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 4a58099b072de7..c3bcd538474076 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" - LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" + LLVM_COMMIT = "cd54cb062bba9c90a8f3723bf66caa7effbcf259" + LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 4adb475a33423c..e0644f1eee41da 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,15 +1,949 @@ +diff --git a/shardy/dialect/sdy/transforms/export/passes.td b/shardy/dialect/sdy/transforms/export/passes.td +index 58c9f74..64cfe7f 100644 +--- a/shardy/dialect/sdy/transforms/export/passes.td ++++ b/shardy/dialect/sdy/transforms/export/passes.td +@@ -114,8 +114,8 @@ def TempExplicitReshardsForOptimizationsPass : Pass<"sdy-temp-explicit-reshards- + This pass is a temporary solution until we can enable the + `sdy-insert-explicit-reshards` pass by default. + +- It allows us to insert explicit reshards on specific operations for +- optimizations. ++ It allows us to improve specific use cases where the partitioner does the ++ sub-optimal thing. 
+ }]; + } + +diff --git a/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc b/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc +index b20b794..0642e3c 100644 +--- a/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc ++++ b/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc +@@ -29,7 +29,6 @@ limitations under the License. + #include "mlir/Support/LLVM.h" + #include "shardy/dialect/sdy/ir/dialect.h" + #include "shardy/dialect/sdy/ir/utils.h" +-#include "shardy/dialect/sdy/transforms/export/explicit_reshards_util.h" + #include "shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep + #include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h" + #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h" +@@ -236,9 +235,6 @@ struct TempExplicitReshardsForOptimizationsPass + [&](stablehlo::DotGeneralOp dotGeneralOp) { + processDot(dotGeneralOp, rewriter, symbolTable); + }); +- if (op->getName().getStringRef().str() == "mhlo.ragged_dot") { +- insertExplicitReshardsOnOp(op, rewriter, symbolTable); +- } + }); + } + }; +diff --git a/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir b/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir +index 48bcbcb..117954c 100644 +--- a/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir ++++ b/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir +@@ -1,8 +1,7 @@ +-// RUN: sdy_opt %s -allow-unregistered-dialect -sdy-temp-explicit-reshards-for-optimizations | FileCheck %s ++// RUN: sdy_opt %s -sdy-temp-explicit-reshards-for-optimizations | FileCheck %s + + sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]> + sdy.mesh @other_mesh = <["x"=2, "y"=2]> +-sdy.mesh @mesh_abcd = <["a"=2, "b"=2, "c"=2, "d"=2]> + + // CHECK-LABEL: func @reshard_dot_result_to_match_lhs + func.func @reshard_dot_result_to_match_lhs( +@@ -317,77 +316,3 @@ func.func @dot_result_conflicting_sharding_mismatch_with_reduction_axes_3( + (tensor<4x2x32xf32>, tensor<2x32x8xf32>) -> tensor<4x8xf32> + return %0 : tensor<4x8xf32> + } +- +-// CHECK-LABEL: func @ragged_dot_mode_non_contracting +-func.func @ragged_dot_mode_non_contracting( +- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, +- %arg1: tensor<4x16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]>}, +- %arg2: tensor<16x4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}]>}) -> tensor<16x32x8xf32> { +- // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh_abcd, [{"a"}, {}, {"c"}]> : tensor<16x32x64xf32> +- // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_abcd, [{}, {"a"}, {"c"}, {"d"}]> : tensor<4x16x64x8xf32> +- // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{"a"}, {}]> : tensor<16x4xi32> +- +- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%[[RESHARD0]], %[[RESHARD1]], %[[RESHARD2]]) <{ +- // CHECK: }> +- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {}, {"d"}]>]> +- +- // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {}, {"d"}]> : tensor<16x32x8xf32> +- // CHECK: %[[RESHARD3:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh_abcd, [{"a"}, {"b"}, {"c"}]> : tensor<16x32x8xf32> +- // CHECK: return %[[RESHARD3]] : tensor<16x32x8xf32> +- %0 = "mhlo.ragged_dot"(%arg0, 
%arg1, %arg2) <{ragged_dot_dimension_numbers = +- #mhlo.ragged_dot, +- lhs_ragged_dimensions = [1], rhs_group_dimensions = [0]>}> +- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [m, i, l, k], [i, m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=4} reduction={l} need_replication={j, m}>, +- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>]>} +- : (tensor<16x32x64xf32>, tensor<4x16x64x8xf32>, tensor<16x4xi32>) -> tensor<16x32x8xf32> +- return %0 : tensor<16x32x8xf32> +-} +- +-// CHECK-LABEL: func @ragged_dot_mode_contracting +-func.func @ragged_dot_mode_contracting( +- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, +- %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, +- %arg2: tensor<16x4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}]>}) -> tensor<4x16x32x8xf32> { +- // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh_abcd, [{"a"}, {"b"}, {}]> : tensor<16x32x64xf32> +- // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_abcd, [{"a"}, {}, {"d"}]> : tensor<16x64x8xf32> +- // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{"a"}, {}]> : tensor<16x4xi32> +- +- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%[[RESHARD0]], %[[RESHARD1]], %[[RESHARD2]]) <{ +- // CHECK: }> +- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{}, {"a"}, {"b"}, {"d"}]>]> +- +- // CHECK: %[[RESHARD3:.*]] = sdy.reshard %[[RAGGED_DOT]] <@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]> : tensor<4x16x32x8xf32> +- // CHECK: return %[[RESHARD3]] : tensor<4x16x32x8xf32> +- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = +- #mhlo.ragged_dot, +- lhs_ragged_dimensions = [2]>}> +- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [i, m])->([m, i, j, k]) {i=16, j=32, k=8, l=64, m=4} need_replication={l, m}>, +- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]>]>} +- : (tensor<16x32x64xf32>, tensor<16x64x8xf32>, tensor<16x4xi32>) -> tensor<4x16x32x8xf32> +- return %0 : tensor<4x16x32x8xf32> +-} +- +-// CHECK-LABEL: func @ragged_dot_mode_batch +-func.func @ragged_dot_mode_batch( +- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, +- %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>}, +- %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> tensor<16x32x8xf32> { +- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ +- // CHECK: }> +- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]> +- // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {"b"}, {"d"}]> : tensor<16x32x8xf32> +- // CHECK: return %[[ALL_REDUCE]] : tensor<16x32x8xf32> +- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = +- #mhlo.ragged_dot, +- lhs_ragged_dimensions = [0]>}> +- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=1} reduction={l}>, +- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>} +- : (tensor<16x32x64xf32>, tensor<16x64x8xf32>, tensor<4xi32>) -> tensor<16x32x8xf32> +- return %0 : tensor<16x32x8xf32> +-} +diff --git a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc +index 6cfed8f..4061903 100644 +--- 
a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc ++++ b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc +@@ -117,8 +117,8 @@ GroupIdToShardingGroups unifyShardingGroups( + int64_t reindexId = 0; + SmallDenseMap reindexMap; + for (const auto& group : shardingGroupEquivalences) { +- if (group.isLeader()) { +- reindexMap[group.getData()] = reindexId++; ++ if (group->isLeader()) { ++ reindexMap[group->getData()] = reindexId++; + } + } + +diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir +index 97099a1..6a711ae 100644 +--- a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir ++++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir +@@ -64,8 +64,8 @@ func.func @sharding_groups_reindexes_ids(%arg0: tensor<4xf32>, %arg1: tensor<4xf + + // CHECK-LABEL: sharding_groups_reindex_ordering_matches_min_element_ordering + func.func @sharding_groups_reindex_ordering_matches_min_element_ordering(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) { +- // CHECK: sdy.sharding_group %arg0 group_id=1 : tensor<4xf32> +- // CHECK: sdy.sharding_group %arg1 group_id=0 : tensor<4xf32> ++ // CHECK: sdy.sharding_group %arg0 group_id=0 : tensor<4xf32> ++ // CHECK: sdy.sharding_group %arg1 group_id=1 : tensor<4xf32> + // CHECK: sdy.sharding_group %arg2 group_id=2 : tensor<4xf32> + sdy.sharding_group %arg0 group_id = 567 : tensor<4xf32> + sdy.sharding_group %arg0 group_id = 23 : tensor<4xf32> +diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch +index 2e6ff58..97282ec 100644 +--- a/third_party/llvm/generated.patch ++++ b/third_party/llvm/generated.patch +@@ -1,23 +1,748 @@ + Auto generated patch. Do not edit or delete it, even if empty. 
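The sharding_group_import.cc hunk above applies the same pointer-yielding EquivalenceClasses iterators to renumber sharding groups densely, and the swapped group_id expectations in sharding_group_import.mlir suggest the leader iteration order also changed with the new implementation. A hedged sketch of that reindexing pattern follows; the function name and the int64_t group ids are illustrative, not taken from the patch.

#include <cstdint>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"

// Maps each group leader to a dense id 0..N-1 in class-iteration order.
llvm::SmallDenseMap<int64_t, int64_t> reindexGroups(
    const llvm::EquivalenceClasses<int64_t> &groups) {
  llvm::SmallDenseMap<int64_t, int64_t> reindexMap;
  int64_t nextId = 0;
  for (const auto &group : groups)
    if (group->isLeader())  // one dense id per equivalence class
      reindexMap[group->getData()] = nextId++;
  return reindexMap;
}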
+-diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp +---- a/clang/lib/Sema/SemaCXXScopeSpec.cpp +-+++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp +-@@ -873,6 +873,7 @@ +- DependentTemplateSpecializationTypeLoc SpecTL +- = Builder.push(T); +- SpecTL.setElaboratedKeywordLoc(SourceLocation()); +-+ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); +- SpecTL.setTemplateKeywordLoc(TemplateKWLoc); +- SpecTL.setTemplateNameLoc(TemplateNameLoc); +- SpecTL.setLAngleLoc(LAngleLoc); +-diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +-+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +-@@ -1902,7 +1902,6 @@ +- name = "inv_trigf_utils", +- srcs = ["src/math/generic/inv_trigf_utils.cpp"], +- hdrs = [ +-- "src/math/generic/atan_utils.h", +- "src/math/generic/inv_trigf_utils.h", +- ], +- deps = [ ++diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp ++--- a/clang/lib/AST/ASTContext.cpp +++++ b/clang/lib/AST/ASTContext.cpp ++@@ -7011,7 +7011,7 @@ ++ getCanonicalTemplateArgument(subst->getArgumentPack()); ++ return getSubstTemplateTemplateParmPack( ++ canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), ++- subst->getFinal(), subst->getIndex()); +++ subst->getIndex(), subst->getFinal()); ++ } ++ case TemplateName::DeducedTemplate: { ++ assert(IgnoreDeduced == false); ++diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h ++--- a/clang/lib/Sema/TreeTransform.h +++++ b/clang/lib/Sema/TreeTransform.h ++@@ -7765,17 +7765,23 @@ ++ NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); ++ NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); ++ ++- typedef TemplateArgumentLocContainerIterator< ++- DependentTemplateSpecializationTypeLoc> ArgIterator; ++- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), ++- ArgIterator(TL, TL.getNumArgs()), ++- NewTemplateArgs)) +++ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); +++ +++ if (getDerived().TransformTemplateArguments(ArgsRange.begin(), +++ ArgsRange.end(), NewTemplateArgs)) ++ return QualType(); +++ bool TemplateArgumentsChanged = !llvm::equal( +++ ArgsRange, NewTemplateArgs.arguments(), +++ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { +++ return A.getArgument().structurallyEquals(B.getArgument()); +++ }); ++ ++ const DependentTemplateStorage &DTN = T->getDependentTemplateName(); ++ ++ QualType Result = TL.getType(); ++- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { +++ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || +++ TemplateArgumentsChanged) { ++ TemplateName Name = getDerived().RebuildTemplateName( ++ SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), ++ /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, ++diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ++--- a/clang/lib/Serialization/ASTReaderStmt.cpp +++++ b/clang/lib/Serialization/ASTReaderStmt.cpp ++@@ -2229,6 +2229,7 @@ ++ E->PackIndex = Record.readInt(); ++ else ++ E->PackIndex = 0; +++ E->Final = CurrentUnpackingBits->getNextBit(); ++ E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); ++ E->Replacement = Record.readSubExpr(); ++ } ++diff -ruN --strip-trailing-cr 
a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ++--- a/clang/lib/Serialization/ASTWriterStmt.cpp +++++ b/clang/lib/Serialization/ASTWriterStmt.cpp ++@@ -2229,6 +2229,7 @@ ++ CurrentPackingBits.addBit((bool)E->getPackIndex()); ++ if (auto PackIndex = E->getPackIndex()) ++ Record.push_back(*PackIndex + 1); +++ CurrentPackingBits.addBit(E->getFinal()); ++ ++ Record.AddSourceLocation(E->getNameLoc()); ++ Record.AddStmt(E->getReplacement()); ++diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h ++--- a/clang/test/CodeGen/include/cuda.h +++++ b/clang/test/CodeGen/include/cuda.h ++@@ -1,194 +0,0 @@ ++-/* Minimal declarations for CUDA support. Testing purposes only. ++- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h ++- */ ++-#pragma once ++- ++-// Make this file work with nvcc, for testing compatibility. ++- ++-#ifndef __NVCC__ ++-#define __constant__ __attribute__((constant)) ++-#define __device__ __attribute__((device)) ++-#define __global__ __attribute__((global)) ++-#define __host__ __attribute__((host)) ++-#define __shared__ __attribute__((shared)) ++-#define __managed__ __attribute__((managed)) ++-#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) ++- ++-struct dim3 { ++- unsigned x, y, z; ++- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} ++-}; ++- ++-// Host- and device-side placement new overloads. ++-void *operator new(__SIZE_TYPE__, void *p) { return p; } ++-void *operator new[](__SIZE_TYPE__, void *p) { return p; } ++-__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } ++-__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } ++- ++-#define CUDA_VERSION 10100 ++- ++-struct char1 { ++- char x; ++- __host__ __device__ char1(char x = 0) : x(x) {} ++-}; ++-struct char2 { ++- char x, y; ++- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} ++-}; ++-struct char4 { ++- char x, y, z, w; ++- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct uchar1 { ++- unsigned char x; ++- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} ++-}; ++-struct uchar2 { ++- unsigned char x, y; ++- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} ++-}; ++-struct uchar4 { ++- unsigned char x, y, z, w; ++- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct short1 { ++- short x; ++- __host__ __device__ short1(short x = 0) : x(x) {} ++-}; ++-struct short2 { ++- short x, y; ++- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} ++-}; ++-struct short4 { ++- short x, y, z, w; ++- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct ushort1 { ++- unsigned short x; ++- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} ++-}; ++-struct ushort2 { ++- unsigned short x, y; ++- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} ++-}; ++-struct ushort4 { ++- unsigned short x, y, z, w; ++- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct int1 { ++- int x; ++- __host__ __device__ int1(int x = 0) : x(x) {} ++-}; ++-struct int2 { ++- int x, 
y; ++- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} ++-}; ++-struct int4 { ++- int x, y, z, w; ++- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct uint1 { ++- unsigned x; ++- __host__ __device__ uint1(unsigned x = 0) : x(x) {} ++-}; ++-struct uint2 { ++- unsigned x, y; ++- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} ++-}; ++-struct uint3 { ++- unsigned x, y, z; ++- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} ++-}; ++-struct uint4 { ++- unsigned x, y, z, w; ++- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct longlong1 { ++- long long x; ++- __host__ __device__ longlong1(long long x = 0) : x(x) {} ++-}; ++-struct longlong2 { ++- long long x, y; ++- __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} ++-}; ++-struct longlong4 { ++- long long x, y, z, w; ++- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct ulonglong1 { ++- unsigned long long x; ++- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} ++-}; ++-struct ulonglong2 { ++- unsigned long long x, y; ++- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} ++-}; ++-struct ulonglong4 { ++- unsigned long long x, y, z, w; ++- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct float1 { ++- float x; ++- __host__ __device__ float1(float x = 0) : x(x) {} ++-}; ++-struct float2 { ++- float x, y; ++- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} ++-}; ++-struct float4 { ++- float x, y, z, w; ++- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct double1 { ++- double x; ++- __host__ __device__ double1(double x = 0) : x(x) {} ++-}; ++-struct double2 { ++- double x, y; ++- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} ++-}; ++-struct double4 { ++- double x, y, z, w; ++- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-typedef unsigned long long cudaTextureObject_t; ++-typedef unsigned long long cudaSurfaceObject_t; ++- ++-enum cudaTextureReadMode { ++- cudaReadModeNormalizedFloat, ++- cudaReadModeElementType ++-}; ++- ++-enum cudaSurfaceBoundaryMode { ++- cudaBoundaryModeZero, ++- cudaBoundaryModeClamp, ++- cudaBoundaryModeTrap ++-}; ++- ++-enum { ++- cudaTextureType1D, ++- cudaTextureType2D, ++- cudaTextureType3D, ++- cudaTextureTypeCubemap, ++- cudaTextureType1DLayered, ++- cudaTextureType2DLayered, ++- cudaTextureTypeCubemapLayered ++-}; ++- ++-struct textureReference {}; ++-template ++-struct __attribute__((device_builtin_texture_type)) texture ++- : public textureReference {}; ++- ++-#endif // !__NVCC__ ++diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h ++--- a/clang/test/CodeGen/Inputs/cuda.h +++++ b/clang/test/CodeGen/Inputs/cuda.h ++@@ -0,0 +1,194 @@ +++/* Minimal declarations for CUDA support. Testing purposes only. 
+++ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h +++ */ +++#pragma once +++ +++// Make this file work with nvcc, for testing compatibility. +++ +++#ifndef __NVCC__ +++#define __constant__ __attribute__((constant)) +++#define __device__ __attribute__((device)) +++#define __global__ __attribute__((global)) +++#define __host__ __attribute__((host)) +++#define __shared__ __attribute__((shared)) +++#define __managed__ __attribute__((managed)) +++#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) +++ +++struct dim3 { +++ unsigned x, y, z; +++ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} +++}; +++ +++// Host- and device-side placement new overloads. +++void *operator new(__SIZE_TYPE__, void *p) { return p; } +++void *operator new[](__SIZE_TYPE__, void *p) { return p; } +++__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } +++__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } +++ +++#define CUDA_VERSION 10100 +++ +++struct char1 { +++ char x; +++ __host__ __device__ char1(char x = 0) : x(x) {} +++}; +++struct char2 { +++ char x, y; +++ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} +++}; +++struct char4 { +++ char x, y, z, w; +++ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct uchar1 { +++ unsigned char x; +++ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} +++}; +++struct uchar2 { +++ unsigned char x, y; +++ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} +++}; +++struct uchar4 { +++ unsigned char x, y, z, w; +++ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct short1 { +++ short x; +++ __host__ __device__ short1(short x = 0) : x(x) {} +++}; +++struct short2 { +++ short x, y; +++ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} +++}; +++struct short4 { +++ short x, y, z, w; +++ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct ushort1 { +++ unsigned short x; +++ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} +++}; +++struct ushort2 { +++ unsigned short x, y; +++ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} +++}; +++struct ushort4 { +++ unsigned short x, y, z, w; +++ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct int1 { +++ int x; +++ __host__ __device__ int1(int x = 0) : x(x) {} +++}; +++struct int2 { +++ int x, y; +++ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} +++}; +++struct int4 { +++ int x, y, z, w; +++ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct uint1 { +++ unsigned x; +++ __host__ __device__ uint1(unsigned x = 0) : x(x) {} +++}; +++struct uint2 { +++ unsigned x, y; +++ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} +++}; +++struct uint3 { +++ unsigned x, y, z; +++ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} +++}; +++struct uint4 { +++ unsigned x, y, z, w; +++ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} +++}; 
+++ +++struct longlong1 { +++ long long x; +++ __host__ __device__ longlong1(long long x = 0) : x(x) {} +++}; +++struct longlong2 { +++ long long x, y; +++ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} +++}; +++struct longlong4 { +++ long long x, y, z, w; +++ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct ulonglong1 { +++ unsigned long long x; +++ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} +++}; +++struct ulonglong2 { +++ unsigned long long x, y; +++ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} +++}; +++struct ulonglong4 { +++ unsigned long long x, y, z, w; +++ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct float1 { +++ float x; +++ __host__ __device__ float1(float x = 0) : x(x) {} +++}; +++struct float2 { +++ float x, y; +++ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} +++}; +++struct float4 { +++ float x, y, z, w; +++ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct double1 { +++ double x; +++ __host__ __device__ double1(double x = 0) : x(x) {} +++}; +++struct double2 { +++ double x, y; +++ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} +++}; +++struct double4 { +++ double x, y, z, w; +++ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++typedef unsigned long long cudaTextureObject_t; +++typedef unsigned long long cudaSurfaceObject_t; +++ +++enum cudaTextureReadMode { +++ cudaReadModeNormalizedFloat, +++ cudaReadModeElementType +++}; +++ +++enum cudaSurfaceBoundaryMode { +++ cudaBoundaryModeZero, +++ cudaBoundaryModeClamp, +++ cudaBoundaryModeTrap +++}; +++ +++enum { +++ cudaTextureType1D, +++ cudaTextureType2D, +++ cudaTextureType3D, +++ cudaTextureTypeCubemap, +++ cudaTextureType1DLayered, +++ cudaTextureType2DLayered, +++ cudaTextureTypeCubemapLayered +++}; +++ +++struct textureReference {}; +++template +++struct __attribute__((device_builtin_texture_type)) texture +++ : public textureReference {}; +++ +++#endif // !__NVCC__ ++diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu ++--- a/clang/test/CodeGen/nvptx-surface.cu +++++ b/clang/test/CodeGen/nvptx-surface.cu ++@@ -1,6 +1,6 @@ ++ // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s ++ // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s ++-#include "include/cuda.h" +++#include "Inputs/cuda.h" ++ ++ #include "__clang_cuda_texture_intrinsics.h" ++ ++diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp ++--- a/clang/test/SemaTemplate/dependent-names.cpp +++++ b/clang/test/SemaTemplate/dependent-names.cpp ++@@ -458,3 +458,12 @@ ++ }; ++ int f(b ba) { return ba.add<0>(); } ++ } +++ +++namespace TransformDependentTemplates { +++ template struct Test1 { +++ template +++ using Arg = typename T::template Arg; +++ void f(Arg); +++ void f(Arg); +++ }; +++} // namespace TransformDependentTemplates ++diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp 
b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++@@ -15391,12 +15391,20 @@ ++ ++ if (E->State == TreeEntry::SplitVectorize) { ++ Res = FindLastInst(); +++ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { +++ for (auto *E : Entries) { +++ auto *I = dyn_cast_or_null(E->VectorizedValue); +++ if (!I) +++ I = &getLastInstructionInBundle(E); +++ if (Res->comesBefore(I)) +++ Res = I; +++ } +++ } ++ return *Res; ++ } ++ ++ // Set insertpoint for gathered loads to the very first load. ++- if (E->State != TreeEntry::SplitVectorize && ++- GatheredLoadsEntriesFirst.has_value() && +++ if (GatheredLoadsEntriesFirst.has_value() && ++ E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && ++ E->getOpcode() == Instruction::Load) { ++ Res = FindFirstInst(); ++diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp ++--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp ++@@ -2590,6 +2590,14 @@ ++ if (R.mayWriteToMemory() && !InterleaveR) ++ return; ++ +++ // Do not narrow interleave groups if there are VectorPointer recipes and +++ // the plan was unrolled. The recipe implicitly uses VF from +++ // VPTransformState. +++ // TODO: Remove restriction once the VF for the VectorPointer offset is +++ // modeled explicitly as operand. +++ if (isa(&R) && Plan.getUF() > 1) +++ return; +++ ++ // All other ops are allowed, but we reject uses that cannot be converted ++ // when checking all allowed consumers (store interleave groups) below. ++ if (!InterleaveR) ++diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll ++--- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +++++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll ++@@ -66,3 +66,91 @@ ++ exit: ++ ret void ++ } +++ +++define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { +++; CHECK-LABEL: define void @test_2xi64_with_wide_load( +++; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { +++; CHECK-NEXT: [[ENTRY:.*]]: +++; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] +++; CHECK: [[VECTOR_PH]]: +++; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] +++; CHECK: [[VECTOR_BODY]]: +++; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] +++; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 +++; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] +++; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 +++; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 +++; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 +++; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 +++; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 [[INDEX]], 1 +++; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 +++; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] +++; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] +++; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, 
ptr [[TMP8]], align 8 +++; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> +++; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> +++; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 +++; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> +++; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> +++; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] +++; CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] +++; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] +++; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] +++; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> +++; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> +++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 +++; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> +++; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> +++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 +++; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 +++; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 +++; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] +++; CHECK: [[MIDDLE_BLOCK]]: +++; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]] +++; CHECK: [[SCALAR_PH]]: +++; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] +++; CHECK-NEXT: br label %[[LOOP:.*]] +++; CHECK: [[LOOP]]: +++; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] +++; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] +++; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 +++; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 +++; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] +++; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 +++; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] +++; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 +++; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 +++; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] +++; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 +++; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] +++; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 +++; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 +++; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 +++; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] +++; CHECK: [[EXIT]]: +++; CHECK-NEXT: ret void +++; +++entry: +++ br label %loop +++ +++loop: +++ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] +++ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv +++ %l.factor = load i64, ptr %arrayidx, align 8 +++ %1 = shl nsw i64 %iv, 1 +++ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 +++ %l.0 = load i64, ptr %data.0, align 8 +++ %mul.0 = mul i64 %l.factor, %l.0 +++ store i64 %mul.0, ptr 
%data.0, align 8 +++ %3 = or disjoint i64 %1, 1 +++ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 +++ %l.1 = load i64, ptr %data.1, align 8 +++ %mul.1 = mul i64 %l.factor, %l.1 +++ store i64 %mul.1, ptr %data.1, align 8 +++ %iv.next = add nuw nsw i64 %iv, 1 +++ %ec = icmp eq i64 %iv.next, 100 +++ br i1 %ec, label %exit, label %loop +++ +++exit: +++ ret void +++} ++diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll ++--- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +++++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll ++@@ -0,0 +1,99 @@ +++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s +++ +++define void @test(ptr %0, <8 x i8> %1) { +++; CHECK-LABEL: define void @test( +++; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { +++; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 +++; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 +++; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 +++; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 +++; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 +++; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> +++; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 +++; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> +++; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> +++; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) +++; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 +++; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> +++; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) +++; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) +++; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] +++; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 +++; CHECK-NEXT: ret void +++; +++ %3 = load i8, ptr %0, align 2 +++ %4 = getelementptr i8, ptr %0, i64 13442 +++ %5 = load i8, ptr %4, align 2 +++ %6 = or i8 %5, %3 +++ %7 = getelementptr i8, ptr %0, i64 13550 +++ store i8 %6, ptr %7, align 2 +++ %8 = extractelement <8 x i8> %1, i64 0 +++ %9 = or i8 %5, %8 +++ %10 = getelementptr i8, ptr %0, i64 13542 +++ store i8 %9, ptr %10, align 2 +++ %11 = getelementptr i8, ptr %0, i64 13438 +++ %12 = load i8, ptr %11, align 2 +++ %13 = or i8 %12, %3 +++ %14 = getelementptr i8, ptr %0, i64 13546 +++ store i8 %13, ptr %14, align 2 +++ %15 = extractelement <8 x i8> %1, i64 2 +++ %16 = or i8 %12, %15 +++ %17 = getelementptr i8, ptr %0, i64 13538 +++ store i8 %16, ptr %17, align 2 +++ %18 = getelementptr i8, ptr %0, i64 13440 +++ %19 = load i8, ptr %18, align 4 +++ %20 = or i8 %19, %3 +++ %21 = getelementptr i8, ptr %0, i64 13548 +++ store i8 %20, ptr %21, align 4 +++ %22 = extractelement <8 x i8> %1, i64 4 +++ %23 = or i8 %19, %22 +++ %24 = getelementptr i8, ptr %0, i64 13540 +++ store i8 %23, ptr %24, align 4 +++ %25 = 
getelementptr i8, ptr %0, i64 13436 +++ %26 = load i8, ptr %25, align 4 +++ %27 = getelementptr i8, ptr %0, i64 13444 +++ %28 = load i8, ptr %27, align 4 +++ %29 = or i8 %28, %26 +++ %30 = getelementptr i8, ptr %0, i64 13544 +++ store i8 %29, ptr %30, align 4 +++ %31 = or i8 %26, %8 +++ %32 = getelementptr i8, ptr %0, i64 13536 +++ store i8 %31, ptr %32, align 4 +++ %33 = getelementptr i8, ptr %0, i64 13443 +++ %34 = load i8, ptr %33, align 1 +++ %35 = or i8 %34, %3 +++ %36 = getelementptr i8, ptr %0, i64 13551 +++ store i8 %35, ptr %36, align 1 +++ %37 = extractelement <8 x i8> %1, i64 7 +++ %38 = or i8 %34, %37 +++ %39 = getelementptr i8, ptr %0, i64 13543 +++ store i8 %38, ptr %39, align 1 +++ %40 = getelementptr i8, ptr %0, i64 13439 +++ %41 = load i8, ptr %40, align 1 +++ %42 = or i8 %41, %3 +++ %43 = getelementptr i8, ptr %0, i64 13547 +++ store i8 %42, ptr %43, align 1 +++ %44 = extractelement <8 x i8> %1, i64 3 +++ %45 = or i8 %41, %44 +++ %46 = getelementptr i8, ptr %0, i64 13539 +++ store i8 %45, ptr %46, align 1 +++ %47 = getelementptr i8, ptr %0, i64 13441 +++ %48 = load i8, ptr %47, align 1 +++ %49 = or i8 %48, %3 +++ %50 = getelementptr i8, ptr %0, i64 13549 +++ store i8 %49, ptr %50, align 1 +++ %51 = extractelement <8 x i8> %1, i64 5 +++ %52 = or i8 %48, %51 +++ %53 = getelementptr i8, ptr %0, i64 13541 +++ store i8 %52, ptr %53, align 1 +++ %54 = getelementptr i8, ptr %0, i64 13437 +++ %55 = load i8, ptr %54, align 1 +++ %56 = or i8 %55, %3 +++ %57 = getelementptr i8, ptr %0, i64 13545 +++ store i8 %56, ptr %57, align 1 +++ %58 = or i8 %55, %8 +++ %59 = getelementptr i8, ptr %0, i64 13537 +++ store i8 %58, ptr %59, align 1 +++ ret void +++} ++diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp ++--- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +++++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp ++@@ -151,9 +151,10 @@ ++ MachineModuleInfoWrapperPass *MMIWP = ++ new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); ++ ++- legacy::PassManager PassMgrF; ++ SmallString<1024> Buf; ++ llvm::raw_svector_ostream OS(Buf); +++ legacy::PassManager PassMgrF; +++ ++ AsmPrinter *Printer = ++ addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); ++ PassMgrF.run(*M); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 91166c3..4a58099 100644 +index 4a58099..c3bcd53 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" -- LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" -+ LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" -+ LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" +- LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" +- LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" ++ LLVM_COMMIT = "cd54cb062bba9c90a8f3723bf66caa7effbcf259" ++ LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 4b7db9e5a69eb3..3789787f919440 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = 
"98555add83dfaa334cd538a401b49130ecacb0d8" - SHARDY_SHA256 = "cafe90437597fedee14f57b3cccea63b689c254748df42c1be0105ed1d64f21f" + SHARDY_COMMIT = "cf9436497603650441904a21316fbd058551f663" + SHARDY_SHA256 = "133dcda8bf84d516f67b3bcd0e1c5a564b266a67cb96f8901ccf8be30e830d3e" tf_http_archive( name = "shardy", diff --git a/third_party/triton/llvm_integration/cl744822685.patch b/third_party/triton/llvm_integration/cl744822685.patch new file mode 100644 index 00000000000000..90163c2c4ab23c --- /dev/null +++ b/third_party/triton/llvm_integration/cl744822685.patch @@ -0,0 +1,13 @@ + +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp 2025-03-25 07:48:50.000000000 -0700 ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp 2025-04-07 13:13:57.000000000 -0700 +@@ -127,7 +127,8 @@ + Value cmp) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value threadMask = b.int_val(type.getIntOrFloatBitWidth(), -1); +- return rewriter.create(loc, type, threadMask, cmp); ++ return rewriter.create(loc, type, threadMask, cmp, ++ NVVM::VoteSyncKind::ballot); + } + + static Value mapa(RewriterBase &rewriter, Location loc, Value ptr, Value ctaid, diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index f95ab09fd5f2cb..a0964d051f1668 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -10,5 +10,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. llvm_patch_list = [ "//third_party/triton:llvm_integration/cl741558316.patch", "//third_party/triton:llvm_integration/cl742325920.patch", + "//third_party/triton:llvm_integration/cl744822685.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 4adb475a33423c..e0644f1eee41da 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,15 +1,949 @@ +diff --git a/shardy/dialect/sdy/transforms/export/passes.td b/shardy/dialect/sdy/transforms/export/passes.td +index 58c9f74..64cfe7f 100644 +--- a/shardy/dialect/sdy/transforms/export/passes.td ++++ b/shardy/dialect/sdy/transforms/export/passes.td +@@ -114,8 +114,8 @@ def TempExplicitReshardsForOptimizationsPass : Pass<"sdy-temp-explicit-reshards- + This pass is a temporary solution until we can enable the + `sdy-insert-explicit-reshards` pass by default. + +- It allows us to insert explicit reshards on specific operations for +- optimizations. ++ It allows us to improve specific use cases where the partitioner does the ++ sub-optimal thing. + }]; + } + +diff --git a/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc b/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc +index b20b794..0642e3c 100644 +--- a/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc ++++ b/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc +@@ -29,7 +29,6 @@ limitations under the License. 
+ #include "mlir/Support/LLVM.h" + #include "shardy/dialect/sdy/ir/dialect.h" + #include "shardy/dialect/sdy/ir/utils.h" +-#include "shardy/dialect/sdy/transforms/export/explicit_reshards_util.h" + #include "shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep + #include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h" + #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h" +@@ -236,9 +235,6 @@ struct TempExplicitReshardsForOptimizationsPass + [&](stablehlo::DotGeneralOp dotGeneralOp) { + processDot(dotGeneralOp, rewriter, symbolTable); + }); +- if (op->getName().getStringRef().str() == "mhlo.ragged_dot") { +- insertExplicitReshardsOnOp(op, rewriter, symbolTable); +- } + }); + } + }; +diff --git a/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir b/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir +index 48bcbcb..117954c 100644 +--- a/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir ++++ b/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir +@@ -1,8 +1,7 @@ +-// RUN: sdy_opt %s -allow-unregistered-dialect -sdy-temp-explicit-reshards-for-optimizations | FileCheck %s ++// RUN: sdy_opt %s -sdy-temp-explicit-reshards-for-optimizations | FileCheck %s + + sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]> + sdy.mesh @other_mesh = <["x"=2, "y"=2]> +-sdy.mesh @mesh_abcd = <["a"=2, "b"=2, "c"=2, "d"=2]> + + // CHECK-LABEL: func @reshard_dot_result_to_match_lhs + func.func @reshard_dot_result_to_match_lhs( +@@ -317,77 +316,3 @@ func.func @dot_result_conflicting_sharding_mismatch_with_reduction_axes_3( + (tensor<4x2x32xf32>, tensor<2x32x8xf32>) -> tensor<4x8xf32> + return %0 : tensor<4x8xf32> + } +- +-// CHECK-LABEL: func @ragged_dot_mode_non_contracting +-func.func @ragged_dot_mode_non_contracting( +- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, +- %arg1: tensor<4x16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]>}, +- %arg2: tensor<16x4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}]>}) -> tensor<16x32x8xf32> { +- // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh_abcd, [{"a"}, {}, {"c"}]> : tensor<16x32x64xf32> +- // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_abcd, [{}, {"a"}, {"c"}, {"d"}]> : tensor<4x16x64x8xf32> +- // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{"a"}, {}]> : tensor<16x4xi32> +- +- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%[[RESHARD0]], %[[RESHARD1]], %[[RESHARD2]]) <{ +- // CHECK: }> +- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {}, {"d"}]>]> +- +- // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {}, {"d"}]> : tensor<16x32x8xf32> +- // CHECK: %[[RESHARD3:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh_abcd, [{"a"}, {"b"}, {"c"}]> : tensor<16x32x8xf32> +- // CHECK: return %[[RESHARD3]] : tensor<16x32x8xf32> +- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = +- #mhlo.ragged_dot, +- lhs_ragged_dimensions = [1], rhs_group_dimensions = [0]>}> +- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [m, i, l, k], [i, m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=4} reduction={l} need_replication={j, m}>, +- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>]>} +- : (tensor<16x32x64xf32>, tensor<4x16x64x8xf32>, tensor<16x4xi32>) -> 
tensor<16x32x8xf32> +- return %0 : tensor<16x32x8xf32> +-} +- +-// CHECK-LABEL: func @ragged_dot_mode_contracting +-func.func @ragged_dot_mode_contracting( +- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, +- %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, +- %arg2: tensor<16x4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}]>}) -> tensor<4x16x32x8xf32> { +- // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh_abcd, [{"a"}, {"b"}, {}]> : tensor<16x32x64xf32> +- // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_abcd, [{"a"}, {}, {"d"}]> : tensor<16x64x8xf32> +- // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{"a"}, {}]> : tensor<16x4xi32> +- +- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%[[RESHARD0]], %[[RESHARD1]], %[[RESHARD2]]) <{ +- // CHECK: }> +- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{}, {"a"}, {"b"}, {"d"}]>]> +- +- // CHECK: %[[RESHARD3:.*]] = sdy.reshard %[[RAGGED_DOT]] <@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]> : tensor<4x16x32x8xf32> +- // CHECK: return %[[RESHARD3]] : tensor<4x16x32x8xf32> +- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = +- #mhlo.ragged_dot, +- lhs_ragged_dimensions = [2]>}> +- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [i, m])->([m, i, j, k]) {i=16, j=32, k=8, l=64, m=4} need_replication={l, m}>, +- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]>]>} +- : (tensor<16x32x64xf32>, tensor<16x64x8xf32>, tensor<16x4xi32>) -> tensor<4x16x32x8xf32> +- return %0 : tensor<4x16x32x8xf32> +-} +- +-// CHECK-LABEL: func @ragged_dot_mode_batch +-func.func @ragged_dot_mode_batch( +- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, +- %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>}, +- %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> tensor<16x32x8xf32> { +- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ +- // CHECK: }> +- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]> +- // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {"b"}, {"d"}]> : tensor<16x32x8xf32> +- // CHECK: return %[[ALL_REDUCE]] : tensor<16x32x8xf32> +- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = +- #mhlo.ragged_dot, +- lhs_ragged_dimensions = [0]>}> +- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=1} reduction={l}>, +- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>} +- : (tensor<16x32x64xf32>, tensor<16x64x8xf32>, tensor<4xi32>) -> tensor<16x32x8xf32> +- return %0 : tensor<16x32x8xf32> +-} +diff --git a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc +index 6cfed8f..4061903 100644 +--- a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc ++++ b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc +@@ -117,8 +117,8 @@ GroupIdToShardingGroups unifyShardingGroups( + int64_t reindexId = 0; + SmallDenseMap reindexMap; + for (const auto& group : shardingGroupEquivalences) { +- if (group.isLeader()) { +- reindexMap[group.getData()] = reindexId++; ++ if (group->isLeader()) { ++ reindexMap[group->getData()] = reindexId++; + } 
+ } + +diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir +index 97099a1..6a711ae 100644 +--- a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir ++++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir +@@ -64,8 +64,8 @@ func.func @sharding_groups_reindexes_ids(%arg0: tensor<4xf32>, %arg1: tensor<4xf + + // CHECK-LABEL: sharding_groups_reindex_ordering_matches_min_element_ordering + func.func @sharding_groups_reindex_ordering_matches_min_element_ordering(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) { +- // CHECK: sdy.sharding_group %arg0 group_id=1 : tensor<4xf32> +- // CHECK: sdy.sharding_group %arg1 group_id=0 : tensor<4xf32> ++ // CHECK: sdy.sharding_group %arg0 group_id=0 : tensor<4xf32> ++ // CHECK: sdy.sharding_group %arg1 group_id=1 : tensor<4xf32> + // CHECK: sdy.sharding_group %arg2 group_id=2 : tensor<4xf32> + sdy.sharding_group %arg0 group_id = 567 : tensor<4xf32> + sdy.sharding_group %arg0 group_id = 23 : tensor<4xf32> +diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch +index 2e6ff58..97282ec 100644 +--- a/third_party/llvm/generated.patch ++++ b/third_party/llvm/generated.patch +@@ -1,23 +1,748 @@ + Auto generated patch. Do not edit or delete it, even if empty. +-diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp +---- a/clang/lib/Sema/SemaCXXScopeSpec.cpp +-+++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp +-@@ -873,6 +873,7 @@ +- DependentTemplateSpecializationTypeLoc SpecTL +- = Builder.push(T); +- SpecTL.setElaboratedKeywordLoc(SourceLocation()); +-+ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); +- SpecTL.setTemplateKeywordLoc(TemplateKWLoc); +- SpecTL.setTemplateNameLoc(TemplateNameLoc); +- SpecTL.setLAngleLoc(LAngleLoc); +-diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +-+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +-@@ -1902,7 +1902,6 @@ +- name = "inv_trigf_utils", +- srcs = ["src/math/generic/inv_trigf_utils.cpp"], +- hdrs = [ +-- "src/math/generic/atan_utils.h", +- "src/math/generic/inv_trigf_utils.h", +- ], +- deps = [ ++diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp ++--- a/clang/lib/AST/ASTContext.cpp +++++ b/clang/lib/AST/ASTContext.cpp ++@@ -7011,7 +7011,7 @@ ++ getCanonicalTemplateArgument(subst->getArgumentPack()); ++ return getSubstTemplateTemplateParmPack( ++ canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), ++- subst->getFinal(), subst->getIndex()); +++ subst->getIndex(), subst->getFinal()); ++ } ++ case TemplateName::DeducedTemplate: { ++ assert(IgnoreDeduced == false); ++diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h ++--- a/clang/lib/Sema/TreeTransform.h +++++ b/clang/lib/Sema/TreeTransform.h ++@@ -7765,17 +7765,23 @@ ++ NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); ++ NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); ++ ++- typedef TemplateArgumentLocContainerIterator< ++- DependentTemplateSpecializationTypeLoc> ArgIterator; ++- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), ++- ArgIterator(TL, TL.getNumArgs()), ++- NewTemplateArgs)) +++ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); +++ +++ if 
(getDerived().TransformTemplateArguments(ArgsRange.begin(), +++ ArgsRange.end(), NewTemplateArgs)) ++ return QualType(); +++ bool TemplateArgumentsChanged = !llvm::equal( +++ ArgsRange, NewTemplateArgs.arguments(), +++ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { +++ return A.getArgument().structurallyEquals(B.getArgument()); +++ }); ++ ++ const DependentTemplateStorage &DTN = T->getDependentTemplateName(); ++ ++ QualType Result = TL.getType(); ++- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { +++ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || +++ TemplateArgumentsChanged) { ++ TemplateName Name = getDerived().RebuildTemplateName( ++ SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), ++ /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, ++diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ++--- a/clang/lib/Serialization/ASTReaderStmt.cpp +++++ b/clang/lib/Serialization/ASTReaderStmt.cpp ++@@ -2229,6 +2229,7 @@ ++ E->PackIndex = Record.readInt(); ++ else ++ E->PackIndex = 0; +++ E->Final = CurrentUnpackingBits->getNextBit(); ++ E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); ++ E->Replacement = Record.readSubExpr(); ++ } ++diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ++--- a/clang/lib/Serialization/ASTWriterStmt.cpp +++++ b/clang/lib/Serialization/ASTWriterStmt.cpp ++@@ -2229,6 +2229,7 @@ ++ CurrentPackingBits.addBit((bool)E->getPackIndex()); ++ if (auto PackIndex = E->getPackIndex()) ++ Record.push_back(*PackIndex + 1); +++ CurrentPackingBits.addBit(E->getFinal()); ++ ++ Record.AddSourceLocation(E->getNameLoc()); ++ Record.AddStmt(E->getReplacement()); ++diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h ++--- a/clang/test/CodeGen/include/cuda.h +++++ b/clang/test/CodeGen/include/cuda.h ++@@ -1,194 +0,0 @@ ++-/* Minimal declarations for CUDA support. Testing purposes only. ++- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h ++- */ ++-#pragma once ++- ++-// Make this file work with nvcc, for testing compatibility. ++- ++-#ifndef __NVCC__ ++-#define __constant__ __attribute__((constant)) ++-#define __device__ __attribute__((device)) ++-#define __global__ __attribute__((global)) ++-#define __host__ __attribute__((host)) ++-#define __shared__ __attribute__((shared)) ++-#define __managed__ __attribute__((managed)) ++-#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) ++- ++-struct dim3 { ++- unsigned x, y, z; ++- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} ++-}; ++- ++-// Host- and device-side placement new overloads. 
++-void *operator new(__SIZE_TYPE__, void *p) { return p; } ++-void *operator new[](__SIZE_TYPE__, void *p) { return p; } ++-__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } ++-__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } ++- ++-#define CUDA_VERSION 10100 ++- ++-struct char1 { ++- char x; ++- __host__ __device__ char1(char x = 0) : x(x) {} ++-}; ++-struct char2 { ++- char x, y; ++- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} ++-}; ++-struct char4 { ++- char x, y, z, w; ++- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct uchar1 { ++- unsigned char x; ++- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} ++-}; ++-struct uchar2 { ++- unsigned char x, y; ++- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} ++-}; ++-struct uchar4 { ++- unsigned char x, y, z, w; ++- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct short1 { ++- short x; ++- __host__ __device__ short1(short x = 0) : x(x) {} ++-}; ++-struct short2 { ++- short x, y; ++- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} ++-}; ++-struct short4 { ++- short x, y, z, w; ++- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct ushort1 { ++- unsigned short x; ++- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} ++-}; ++-struct ushort2 { ++- unsigned short x, y; ++- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} ++-}; ++-struct ushort4 { ++- unsigned short x, y, z, w; ++- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct int1 { ++- int x; ++- __host__ __device__ int1(int x = 0) : x(x) {} ++-}; ++-struct int2 { ++- int x, y; ++- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} ++-}; ++-struct int4 { ++- int x, y, z, w; ++- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct uint1 { ++- unsigned x; ++- __host__ __device__ uint1(unsigned x = 0) : x(x) {} ++-}; ++-struct uint2 { ++- unsigned x, y; ++- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} ++-}; ++-struct uint3 { ++- unsigned x, y, z; ++- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} ++-}; ++-struct uint4 { ++- unsigned x, y, z, w; ++- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct longlong1 { ++- long long x; ++- __host__ __device__ longlong1(long long x = 0) : x(x) {} ++-}; ++-struct longlong2 { ++- long long x, y; ++- __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} ++-}; ++-struct longlong4 { ++- long long x, y, z, w; ++- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct ulonglong1 { ++- unsigned long long x; ++- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} ++-}; ++-struct ulonglong2 { ++- unsigned long long x, y; ++- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} ++-}; ++-struct ulonglong4 { ++- 
unsigned long long x, y, z, w; ++- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct float1 { ++- float x; ++- __host__ __device__ float1(float x = 0) : x(x) {} ++-}; ++-struct float2 { ++- float x, y; ++- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} ++-}; ++-struct float4 { ++- float x, y, z, w; ++- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-struct double1 { ++- double x; ++- __host__ __device__ double1(double x = 0) : x(x) {} ++-}; ++-struct double2 { ++- double x, y; ++- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} ++-}; ++-struct double4 { ++- double x, y, z, w; ++- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} ++-}; ++- ++-typedef unsigned long long cudaTextureObject_t; ++-typedef unsigned long long cudaSurfaceObject_t; ++- ++-enum cudaTextureReadMode { ++- cudaReadModeNormalizedFloat, ++- cudaReadModeElementType ++-}; ++- ++-enum cudaSurfaceBoundaryMode { ++- cudaBoundaryModeZero, ++- cudaBoundaryModeClamp, ++- cudaBoundaryModeTrap ++-}; ++- ++-enum { ++- cudaTextureType1D, ++- cudaTextureType2D, ++- cudaTextureType3D, ++- cudaTextureTypeCubemap, ++- cudaTextureType1DLayered, ++- cudaTextureType2DLayered, ++- cudaTextureTypeCubemapLayered ++-}; ++- ++-struct textureReference {}; ++-template ++-struct __attribute__((device_builtin_texture_type)) texture ++- : public textureReference {}; ++- ++-#endif // !__NVCC__ ++diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h ++--- a/clang/test/CodeGen/Inputs/cuda.h +++++ b/clang/test/CodeGen/Inputs/cuda.h ++@@ -0,0 +1,194 @@ +++/* Minimal declarations for CUDA support. Testing purposes only. +++ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h +++ */ +++#pragma once +++ +++// Make this file work with nvcc, for testing compatibility. +++ +++#ifndef __NVCC__ +++#define __constant__ __attribute__((constant)) +++#define __device__ __attribute__((device)) +++#define __global__ __attribute__((global)) +++#define __host__ __attribute__((host)) +++#define __shared__ __attribute__((shared)) +++#define __managed__ __attribute__((managed)) +++#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) +++ +++struct dim3 { +++ unsigned x, y, z; +++ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} +++}; +++ +++// Host- and device-side placement new overloads. 
+++void *operator new(__SIZE_TYPE__, void *p) { return p; } +++void *operator new[](__SIZE_TYPE__, void *p) { return p; } +++__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } +++__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } +++ +++#define CUDA_VERSION 10100 +++ +++struct char1 { +++ char x; +++ __host__ __device__ char1(char x = 0) : x(x) {} +++}; +++struct char2 { +++ char x, y; +++ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} +++}; +++struct char4 { +++ char x, y, z, w; +++ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct uchar1 { +++ unsigned char x; +++ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} +++}; +++struct uchar2 { +++ unsigned char x, y; +++ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} +++}; +++struct uchar4 { +++ unsigned char x, y, z, w; +++ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct short1 { +++ short x; +++ __host__ __device__ short1(short x = 0) : x(x) {} +++}; +++struct short2 { +++ short x, y; +++ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} +++}; +++struct short4 { +++ short x, y, z, w; +++ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct ushort1 { +++ unsigned short x; +++ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} +++}; +++struct ushort2 { +++ unsigned short x, y; +++ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} +++}; +++struct ushort4 { +++ unsigned short x, y, z, w; +++ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct int1 { +++ int x; +++ __host__ __device__ int1(int x = 0) : x(x) {} +++}; +++struct int2 { +++ int x, y; +++ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} +++}; +++struct int4 { +++ int x, y, z, w; +++ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct uint1 { +++ unsigned x; +++ __host__ __device__ uint1(unsigned x = 0) : x(x) {} +++}; +++struct uint2 { +++ unsigned x, y; +++ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} +++}; +++struct uint3 { +++ unsigned x, y, z; +++ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} +++}; +++struct uint4 { +++ unsigned x, y, z, w; +++ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct longlong1 { +++ long long x; +++ __host__ __device__ longlong1(long long x = 0) : x(x) {} +++}; +++struct longlong2 { +++ long long x, y; +++ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} +++}; +++struct longlong4 { +++ long long x, y, z, w; +++ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct ulonglong1 { +++ unsigned long long x; +++ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} +++}; +++struct ulonglong2 { +++ unsigned long long x, y; +++ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} +++}; +++struct ulonglong4 { +++ 
unsigned long long x, y, z, w; +++ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct float1 { +++ float x; +++ __host__ __device__ float1(float x = 0) : x(x) {} +++}; +++struct float2 { +++ float x, y; +++ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} +++}; +++struct float4 { +++ float x, y, z, w; +++ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++struct double1 { +++ double x; +++ __host__ __device__ double1(double x = 0) : x(x) {} +++}; +++struct double2 { +++ double x, y; +++ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} +++}; +++struct double4 { +++ double x, y, z, w; +++ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} +++}; +++ +++typedef unsigned long long cudaTextureObject_t; +++typedef unsigned long long cudaSurfaceObject_t; +++ +++enum cudaTextureReadMode { +++ cudaReadModeNormalizedFloat, +++ cudaReadModeElementType +++}; +++ +++enum cudaSurfaceBoundaryMode { +++ cudaBoundaryModeZero, +++ cudaBoundaryModeClamp, +++ cudaBoundaryModeTrap +++}; +++ +++enum { +++ cudaTextureType1D, +++ cudaTextureType2D, +++ cudaTextureType3D, +++ cudaTextureTypeCubemap, +++ cudaTextureType1DLayered, +++ cudaTextureType2DLayered, +++ cudaTextureTypeCubemapLayered +++}; +++ +++struct textureReference {}; +++template +++struct __attribute__((device_builtin_texture_type)) texture +++ : public textureReference {}; +++ +++#endif // !__NVCC__ ++diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu ++--- a/clang/test/CodeGen/nvptx-surface.cu +++++ b/clang/test/CodeGen/nvptx-surface.cu ++@@ -1,6 +1,6 @@ ++ // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s ++ // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s ++-#include "include/cuda.h" +++#include "Inputs/cuda.h" ++ ++ #include "__clang_cuda_texture_intrinsics.h" ++ ++diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp ++--- a/clang/test/SemaTemplate/dependent-names.cpp +++++ b/clang/test/SemaTemplate/dependent-names.cpp ++@@ -458,3 +458,12 @@ ++ }; ++ int f(b ba) { return ba.add<0>(); } ++ } +++ +++namespace TransformDependentTemplates { +++ template struct Test1 { +++ template +++ using Arg = typename T::template Arg; +++ void f(Arg); +++ void f(Arg); +++ }; +++} // namespace TransformDependentTemplates ++diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++@@ -15391,12 +15391,20 @@ ++ ++ if (E->State == TreeEntry::SplitVectorize) { ++ Res = FindLastInst(); +++ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { +++ for (auto *E : Entries) { +++ auto *I = dyn_cast_or_null(E->VectorizedValue); +++ if (!I) +++ I = &getLastInstructionInBundle(E); +++ if (Res->comesBefore(I)) +++ Res = I; +++ } +++ } ++ return *Res; ++ } ++ ++ // Set insertpoint for gathered loads to the very first load. 
++- if (E->State != TreeEntry::SplitVectorize && ++- GatheredLoadsEntriesFirst.has_value() && +++ if (GatheredLoadsEntriesFirst.has_value() && ++ E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && ++ E->getOpcode() == Instruction::Load) { ++ Res = FindFirstInst(); ++diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp ++--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp ++@@ -2590,6 +2590,14 @@ ++ if (R.mayWriteToMemory() && !InterleaveR) ++ return; ++ +++ // Do not narrow interleave groups if there are VectorPointer recipes and +++ // the plan was unrolled. The recipe implicitly uses VF from +++ // VPTransformState. +++ // TODO: Remove restriction once the VF for the VectorPointer offset is +++ // modeled explicitly as operand. +++ if (isa(&R) && Plan.getUF() > 1) +++ return; +++ ++ // All other ops are allowed, but we reject uses that cannot be converted ++ // when checking all allowed consumers (store interleave groups) below. ++ if (!InterleaveR) ++diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll ++--- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +++++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll ++@@ -66,3 +66,91 @@ ++ exit: ++ ret void ++ } +++ +++define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { +++; CHECK-LABEL: define void @test_2xi64_with_wide_load( +++; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { +++; CHECK-NEXT: [[ENTRY:.*]]: +++; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] +++; CHECK: [[VECTOR_PH]]: +++; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] +++; CHECK: [[VECTOR_BODY]]: +++; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] +++; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 +++; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] +++; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 +++; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 +++; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 +++; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 +++; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 [[INDEX]], 1 +++; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 +++; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] +++; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] +++; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP8]], align 8 +++; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> +++; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> +++; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 +++; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> +++; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> +++; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] +++; 
CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] +++; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] +++; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] +++; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> +++; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> +++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 +++; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> +++; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> +++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 +++; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 +++; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 +++; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] +++; CHECK: [[MIDDLE_BLOCK]]: +++; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]] +++; CHECK: [[SCALAR_PH]]: +++; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] +++; CHECK-NEXT: br label %[[LOOP:.*]] +++; CHECK: [[LOOP]]: +++; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] +++; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] +++; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 +++; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 +++; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] +++; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 +++; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] +++; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 +++; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 +++; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] +++; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 +++; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] +++; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 +++; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 +++; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 +++; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] +++; CHECK: [[EXIT]]: +++; CHECK-NEXT: ret void +++; +++entry: +++ br label %loop +++ +++loop: +++ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] +++ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv +++ %l.factor = load i64, ptr %arrayidx, align 8 +++ %1 = shl nsw i64 %iv, 1 +++ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 +++ %l.0 = load i64, ptr %data.0, align 8 +++ %mul.0 = mul i64 %l.factor, %l.0 +++ store i64 %mul.0, ptr %data.0, align 8 +++ %3 = or disjoint i64 %1, 1 +++ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 +++ %l.1 = load i64, ptr %data.1, align 8 +++ %mul.1 = mul i64 %l.factor, %l.1 +++ store i64 %mul.1, ptr %data.1, align 8 +++ %iv.next = add nuw nsw i64 %iv, 1 +++ %ec = icmp eq i64 %iv.next, 100 +++ br i1 %ec, label %exit, label %loop +++ +++exit: +++ ret void +++} ++diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll ++--- 
a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +++++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll ++@@ -0,0 +1,99 @@ +++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s +++ +++define void @test(ptr %0, <8 x i8> %1) { +++; CHECK-LABEL: define void @test( +++; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { +++; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 +++; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 +++; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 +++; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 +++; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 +++; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> +++; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 +++; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> +++; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> +++; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) +++; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 +++; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> +++; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) +++; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) +++; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] +++; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 +++; CHECK-NEXT: ret void +++; +++ %3 = load i8, ptr %0, align 2 +++ %4 = getelementptr i8, ptr %0, i64 13442 +++ %5 = load i8, ptr %4, align 2 +++ %6 = or i8 %5, %3 +++ %7 = getelementptr i8, ptr %0, i64 13550 +++ store i8 %6, ptr %7, align 2 +++ %8 = extractelement <8 x i8> %1, i64 0 +++ %9 = or i8 %5, %8 +++ %10 = getelementptr i8, ptr %0, i64 13542 +++ store i8 %9, ptr %10, align 2 +++ %11 = getelementptr i8, ptr %0, i64 13438 +++ %12 = load i8, ptr %11, align 2 +++ %13 = or i8 %12, %3 +++ %14 = getelementptr i8, ptr %0, i64 13546 +++ store i8 %13, ptr %14, align 2 +++ %15 = extractelement <8 x i8> %1, i64 2 +++ %16 = or i8 %12, %15 +++ %17 = getelementptr i8, ptr %0, i64 13538 +++ store i8 %16, ptr %17, align 2 +++ %18 = getelementptr i8, ptr %0, i64 13440 +++ %19 = load i8, ptr %18, align 4 +++ %20 = or i8 %19, %3 +++ %21 = getelementptr i8, ptr %0, i64 13548 +++ store i8 %20, ptr %21, align 4 +++ %22 = extractelement <8 x i8> %1, i64 4 +++ %23 = or i8 %19, %22 +++ %24 = getelementptr i8, ptr %0, i64 13540 +++ store i8 %23, ptr %24, align 4 +++ %25 = getelementptr i8, ptr %0, i64 13436 +++ %26 = load i8, ptr %25, align 4 +++ %27 = getelementptr i8, ptr %0, i64 13444 +++ %28 = load i8, ptr %27, align 4 +++ %29 = or i8 %28, %26 +++ %30 = getelementptr i8, ptr %0, i64 13544 +++ store i8 %29, ptr %30, align 4 +++ %31 = or i8 %26, %8 +++ %32 = getelementptr i8, ptr %0, i64 13536 +++ store i8 %31, ptr %32, align 4 +++ %33 = getelementptr i8, ptr %0, i64 13443 +++ %34 = load i8, ptr %33, align 1 +++ %35 = or i8 %34, %3 +++ %36 = getelementptr i8, ptr %0, i64 13551 +++ store i8 %35, ptr %36, align 1 +++ %37 = 
extractelement <8 x i8> %1, i64 7 +++ %38 = or i8 %34, %37 +++ %39 = getelementptr i8, ptr %0, i64 13543 +++ store i8 %38, ptr %39, align 1 +++ %40 = getelementptr i8, ptr %0, i64 13439 +++ %41 = load i8, ptr %40, align 1 +++ %42 = or i8 %41, %3 +++ %43 = getelementptr i8, ptr %0, i64 13547 +++ store i8 %42, ptr %43, align 1 +++ %44 = extractelement <8 x i8> %1, i64 3 +++ %45 = or i8 %41, %44 +++ %46 = getelementptr i8, ptr %0, i64 13539 +++ store i8 %45, ptr %46, align 1 +++ %47 = getelementptr i8, ptr %0, i64 13441 +++ %48 = load i8, ptr %47, align 1 +++ %49 = or i8 %48, %3 +++ %50 = getelementptr i8, ptr %0, i64 13549 +++ store i8 %49, ptr %50, align 1 +++ %51 = extractelement <8 x i8> %1, i64 5 +++ %52 = or i8 %48, %51 +++ %53 = getelementptr i8, ptr %0, i64 13541 +++ store i8 %52, ptr %53, align 1 +++ %54 = getelementptr i8, ptr %0, i64 13437 +++ %55 = load i8, ptr %54, align 1 +++ %56 = or i8 %55, %3 +++ %57 = getelementptr i8, ptr %0, i64 13545 +++ store i8 %56, ptr %57, align 1 +++ %58 = or i8 %55, %8 +++ %59 = getelementptr i8, ptr %0, i64 13537 +++ store i8 %58, ptr %59, align 1 +++ ret void +++} ++diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp ++--- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +++++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp ++@@ -151,9 +151,10 @@ ++ MachineModuleInfoWrapperPass *MMIWP = ++ new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); ++ ++- legacy::PassManager PassMgrF; ++ SmallString<1024> Buf; ++ llvm::raw_svector_ostream OS(Buf); +++ legacy::PassManager PassMgrF; +++ ++ AsmPrinter *Printer = ++ addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); ++ PassMgrF.run(*M); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 91166c3..4a58099 100644 +index 4a58099..c3bcd53 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "537b6541e8067d7ef7aa38791989fca6303b7fdf" -- LLVM_SHA256 = "6dca45b9afe4f530a29f7cfd21b183f6ee51c61b0a4dede2f1173049615cab5f" -+ LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" -+ LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" +- LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" +- LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" ++ LLVM_COMMIT = "cd54cb062bba9c90a8f3723bf66caa7effbcf259" ++ LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index 4b7db9e5a69eb3..3789787f919440 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "98555add83dfaa334cd538a401b49130ecacb0d8" - SHARDY_SHA256 = "cafe90437597fedee14f57b3cccea63b689c254748df42c1be0105ed1d64f21f" + SHARDY_COMMIT = "cf9436497603650441904a21316fbd058551f663" + SHARDY_SHA256 = "133dcda8bf84d516f67b3bcd0e1c5a564b266a67cb96f8901ccf8be30e830d3e" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/triton/llvm_integration/cl744822685.patch b/third_party/xla/third_party/triton/llvm_integration/cl744822685.patch new file mode 100644 index 
00000000000000..90163c2c4ab23c --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl744822685.patch @@ -0,0 +1,13 @@ + +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp 2025-03-25 07:48:50.000000000 -0700 ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp 2025-04-07 13:13:57.000000000 -0700 +@@ -127,7 +127,8 @@ + Value cmp) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value threadMask = b.int_val(type.getIntOrFloatBitWidth(), -1); +- return rewriter.create(loc, type, threadMask, cmp); ++ return rewriter.create(loc, type, threadMask, cmp, ++ NVVM::VoteSyncKind::ballot); + } + + static Value mapa(RewriterBase &rewriter, Location loc, Value ptr, Value ctaid, diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl index f95ab09fd5f2cb..a0964d051f1668 100644 --- a/third_party/xla/third_party/triton/llvm_integration/series.bzl +++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl @@ -10,5 +10,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. llvm_patch_list = [ "//third_party/triton:llvm_integration/cl741558316.patch", "//third_party/triton:llvm_integration/cl742325920.patch", + "//third_party/triton:llvm_integration/cl744822685.patch", # Add new patches just above this line ] diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc index 464884a2d426f0..54ca1e51bbea80 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc +++ b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc @@ -456,12 +456,12 @@ bool simplifyLoopDeallocs(Block& block) { } for (auto it = eq.begin(), e = eq.end(); it != e; ++it) { - if (!it->isLeader()) continue; + if (!(*it)->isLeader()) continue; breaks_if_you_move_ops::ValueSet equivalentOperands; llvm::SmallVector deallocs; bool failed = false; - for (auto member = eq.member_begin(*it); + for (auto member = eq.member_begin(**it); !failed && member != eq.member_end(); ++member) { if (operands.contains(*member)) { equivalentOperands.insert(*member); From 1a584611e5e6ade2fe40c682e79ed6aa296415f1 Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Tue, 8 Apr 2025 18:19:35 -0700 Subject: [PATCH 0410/1324] [XLA] Insert missing optimization barrier during unrolling This is needed to keep the schedule consistent across unrolled bodies. 
PiperOrigin-RevId: 745367066 --- third_party/xla/xla/service/BUILD | 5 ++--- .../service/while_loop_pipeline_unroller.cc | 2 ++ .../while_loop_pipeline_unroller_test.cc | 21 +++++++++++-------- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 4241ca51b3da33..26c8cdf9cb47ff 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -6551,15 +6551,14 @@ xla_cc_test( deps = [ ":copy_insertion", ":while_loop_pipeline_unroller", - "//xla:test_helpers", "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test_helpers", "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/while_loop_pipeline_unroller.cc b/third_party/xla/xla/service/while_loop_pipeline_unroller.cc index 8f242ab227f869..06867a11ef7ac2 100644 --- a/third_party/xla/xla/service/while_loop_pipeline_unroller.cc +++ b/third_party/xla/xla/service/while_loop_pipeline_unroller.cc @@ -147,6 +147,8 @@ absl::StatusOr WhileLoopPipelineUnroller::Run( // Find the original bodies root after inlining. This is the inputs for // the next (unrolled) loop iteration. input_tuple = inline_map[loop_step->root_instruction()]; + input_tuple = unrolled_body->AddInstruction(HloInstruction::CreateUnary( + input_tuple->shape(), HloOpcode::kOptimizationBarrier, input_tuple)); original_roots.push_back(input_tuple); } // The final original root is now the root of the unrolled loop. diff --git a/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc b/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc index 82793a2e52b28f..55dae69ae588db 100644 --- a/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc +++ b/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc @@ -23,20 +23,19 @@ limitations under the License. #include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/service/copy_insertion.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { // Copied from xla/service/copy_insertion_test.cc -int64_t CountCopies(const HloComputation& computation) { +int64_t Count(HloOpcode opcode, const HloComputation& computation) { int64_t count = 0; for (const auto& instruction : computation.instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { + if (instruction->opcode() == opcode) { count++; } } @@ -91,12 +90,14 @@ ENTRY main { // arg.1 moves to index 0. // arg.2 moves to index 1. // out.0 moves to index 2. - EXPECT_EQ(CountCopies(*original_loop->while_body()), 3); + EXPECT_EQ(Count(HloOpcode::kCopy, *original_loop->while_body()), 3); const HloInstruction* unrolled_loop = original_loop->operand(0); EXPECT_EQ(unrolled_loop->opcode(), HloOpcode::kWhile); // There should be no copies inserted into the unrolled loop. 
- EXPECT_EQ(CountCopies(*unrolled_loop->while_body()), 0); + EXPECT_EQ(Count(HloOpcode::kCopy, *unrolled_loop->while_body()), 0); + EXPECT_EQ( + Count(HloOpcode::kOptimizationBarrier, *unrolled_loop->while_body()), 4); } TEST_F(WhileLoopPipelineUnrollerTest, PipelinedLoopWithInfeed) { @@ -157,12 +158,14 @@ ENTRY main { FindInstruction(module.get(), "while.0"); // The original loop should have 1 copy. // arg.2 moves to index 1. - EXPECT_EQ(CountCopies(*original_loop->while_body()), 1); + EXPECT_EQ(Count(HloOpcode::kCopy, *original_loop->while_body()), 1); const HloInstruction* unrolled_loop = original_loop->operand(0); EXPECT_EQ(unrolled_loop->opcode(), HloOpcode::kWhile); // There should be no copies inserted into the unrolled loop. - EXPECT_EQ(CountCopies(*unrolled_loop->while_body()), 0); + EXPECT_EQ(Count(HloOpcode::kCopy, *unrolled_loop->while_body()), 0); + EXPECT_EQ( + Count(HloOpcode::kOptimizationBarrier, *unrolled_loop->while_body()), 3); // All infeeds in the unrolled body need to be ordered with respect to each // other. From 6e29b761c6e6c9a66d1f9cbf1065c9b8f3fafdd8 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Tue, 8 Apr 2025 18:26:05 -0700 Subject: [PATCH 0411/1324] Delete reduction_emitter_test and transpose_emitter_test. Their targets were removed during migration, but these two files were missed. PiperOrigin-RevId: 745368348 --- .../gpu/tests/reduction_emitter_test.cc | 55 ---- .../gpu/tests/transpose_emitter_test.cc | 261 ------------------ 2 files changed, 316 deletions(-) delete mode 100644 third_party/xla/xla/service/gpu/tests/reduction_emitter_test.cc delete mode 100644 third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc diff --git a/third_party/xla/xla/service/gpu/tests/reduction_emitter_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_emitter_test.cc deleted file mode 100644 index b6ecfd527dee50..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduction_emitter_test.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.
-==============================================================================*/ - -#include "xla/error_spec.h" -#include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "tsl/platform/test.h" - -namespace xla { -namespace { - -class ReductionEmitterTest : public gpu::GpuCodegenTest {}; - -TEST_F(ReductionEmitterTest, ProperShmemAllocation) { - const char* const kHloString = R"( - HloModule m - - add { - a = f64[] parameter(0) - b = f64[] parameter(1) - ROOT out = f64[] add(a, b) - } - - fused_computation { - p1 = f64[1024,1024]{1,0} parameter(0) - p2 = f64[1024,1024]{1,0} parameter(1) - s = pred[1024,1024]{1,0} parameter(2) - p = f64[1024,1024]{1,0} select(s, p1, p2) - z = f64[] constant(0) - ROOT out = f64[1024]{0} reduce(p, z), to_apply=add, dimensions={0} - } - - ENTRY e { - p1 = f64[1024,1024]{1,0} parameter(0) - p2 = f64[1024,1024]{1,0} parameter(1) - s = pred[1024,1024]{1,0} parameter(2) - ROOT f = f64[1024]{0} fusion(p1, p2, s), kind=kInput, calls=fused_computation - })"; - - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -} // namespace -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc b/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc deleted file mode 100644 index d7f3de7ba224fe..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc +++ /dev/null @@ -1,261 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "xla/error_spec.h" -#include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "tsl/platform/test.h" - -namespace xla { -namespace { - -class TransposeEmitterTest : public gpu::GpuCodegenTest { - protected: - TransposeEmitterTest() = default; -}; - -// TODO(cheshire): Test vectorization somehow. 
- -TEST_F(TransposeEmitterTest, SimpleLogicalTranspose) { - const char* const kHloString = R"( - HloModule m - - ENTRY e { - para0 = f16[32,16,64]{2,1,0} parameter(0) - ROOT t = f16[64,32,16]{2,1,0} transpose(para0), dimensions={2,0,1} - })"; - - auto expected_ir = R"( -; CHECK: call void BARRIER() -)"; - CompileAndVerifyIr(kHloString, MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/true); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(TransposeEmitterTest, BatchedLogicalTranspose) { - const char* const kHloString = R"( - HloModule m - - ENTRY e { - para0 = f16[32,48,64]{2,1,0} parameter(0) - ROOT t = f16[32,64,48]{2,1,0} transpose(para0), dimensions={0,2,1} - })"; - - auto expected_ir = R"( -; CHECK: call void BARRIER() -)"; - CompileAndVerifyIr(kHloString, MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(TransposeEmitterTest, FusionAfterHero) { - const char* hlo = R"( -HloModule m - -%fused_computation { - %param_0.1 = f32[16,32]{1,0} parameter(0) - %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - bc = f32[1,16,32]{2,1,0} bitcast(%s.1) - %t.1 = f32[1,32,16]{2,1,0} transpose(bc), dimensions={0,2,1} - b = f32[32,16,1]{2,1,0} bitcast(%t.1) - ROOT o = f32[32,16,1]{2,1,0} sqrt(b) -} - -ENTRY main { - %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = f32[32,16,1]{2,1,0} fusion(%p), kind=kInput, calls=%fused_computation -} - )"; - - CompileAndVerifyIr(hlo, MakePlatformSpecificLlvm(R"( -// CHECK: call void BARRIER() - )"), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses(hlo, ErrorSpec{1e-3})); -} - -TEST_F(TransposeEmitterTest, MultipleTransposesWithPostFusion) { - const char* hlo = R"( -HloModule m - -%fused_computation { - %param_0.1 = f32[16,32]{1,0} parameter(0) - %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %bc.1 = f32[1,16,32]{2,1,0} bitcast(%s.1) - %bc.2 = f32[1,16,32]{2,1,0} bitcast(%param_0.1) - %t.1 = f32[1,32,16]{2,1,0} transpose(%bc.1), dimensions={0,2,1} - %t1.1 = f32[1,32,16]{2,1,0} transpose(%bc.2), dimensions={0,2,1} - %r.1 = f32[32,16,1]{2,1,0} reshape(%t.1) - %r1.1 = f32[32,16,1]{2,1,0} reshape(%t1.1) - ROOT %tuple = (f32[32,16,1]{2,1,0}, f32[32,16,1]{2,1,0}) tuple(%r.1, %r1.1) -} - -ENTRY main { - %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = (f32[32,16,1]{2,1,0}, f32[32,16,1]{2,1,0}) fusion(%p), kind=kInput, calls=%fused_computation -} - )"; - - CompileAndVerifyIr(hlo, MakePlatformSpecificLlvm(R"( -// CHECK: call void BARRIER() - )"), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses(hlo, ErrorSpec{1e-3})); -} - -TEST_F(TransposeEmitterTest, MultipleTransposes) { - const char* hlo = R"( -HloModule m - -%fused_computation { - %param_0.1 = f32[16,32]{1,0} parameter(0) - %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %t.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} - %t1.1 = f32[32,16]{1,0} transpose(%param_0.1), dimensions={1,0} - ROOT %tuple = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(%t.1, %t1.1) -} - -ENTRY main { - %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = (f32[32,16]{1,0}, f32[32,16]{1,0}) fusion(%p), kind=kInput, calls=%fused_computation -} - )"; - - CompileAndVerifyIr(hlo, MakePlatformSpecificLlvm(R"( -// CHECK: call void BARRIER() - )"), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); - 
EXPECT_TRUE(RunAndCompareNoHloPasses(hlo, ErrorSpec{1e-3})); -} - -TEST_F(TransposeEmitterTest, MultipleTransposesLogical) { - const char* hlo = R"( -HloModule m - -%fused_computation { - %param_0.1 = f32[16,32]{1,0} parameter(0) - %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %bc.1 = f32[1,16,32]{2,1,0} bitcast(%s.1) - %bc.2 = f32[1,16,32]{2,1,0} bitcast(%param_0.1) - %c.1 = f32[1,32,16]{2,1,0} transpose(%bc.1), dimensions={0,2,1} - %c1.1 = f32[1,32,16]{2,1,0} transpose(%bc.2), dimensions={0,2,1} - ROOT %tuple = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) tuple(%c.1, %c1.1) -} - -ENTRY main { - %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) fusion(%p), kind=kInput, calls=%fused_computation -} - )"; - - CompileAndVerifyIr(hlo, MakePlatformSpecificLlvm(R"( -// CHECK: call void BARRIER() - )"), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses(hlo, ErrorSpec{1e-3})); -} - -TEST_F(TransposeEmitterTest, MultipleTransposesDifferentTypes) { - const char* hlo = R"( -HloModule module - -%fused_computation (param_0.1: f16[16,32]) -> (f32[32,16], f16[32,16]) { - %param_0.1 = f16[16,32]{1,0} parameter(0) - %s.1 = f32[16,32]{1,0} convert(%param_0.1) - %t.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} - %t1.1 = f16[32,16]{1,0} transpose(%param_0.1), dimensions={1,0} - ROOT %tuple = (f32[32,16]{1,0}, f16[32,16]{1,0}) tuple(%t.1, %t1.1) -} - -ENTRY %main (p: f16[16,32]) -> (f32[32,16], f16[32,16]) { - %p = f16[16,32]{1,0} parameter(0) - %fusion = (f32[32,16]{1,0}, f16[32,16]{1,0}) fusion(%p), kind=kInput, calls=%fused_computation - %get-tuple-element = f32[32,16]{1,0} get-tuple-element(%fusion), index=0 - %get-tuple-element.1 = f16[32,16]{1,0} get-tuple-element(%fusion), index=1 - ROOT %t = (f32[32,16]{1,0}, f16[32,16]{1,0}) tuple(%get-tuple-element, %get-tuple-element.1) -} - )"; - - CompileAndVerifyIr(hlo, MakePlatformSpecificLlvm(R"( -// CHECK: call void BARRIER() - )"), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses(hlo, ErrorSpec{1e-3})); -} - -TEST_F(TransposeEmitterTest, TransposeAndInput) { - const char* hlo = R"( -HloModule m - -%fused_computation { - %param_0.1 = f32[16,32]{1,0} parameter(0) - %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %t.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} - %exp = f32[16,32]{1,0} exponential(%param_0.1) - ROOT %tuple = (f32[32,16]{1,0}, f32[16,32]{1,0}) tuple(%t.1, %exp) -} - -ENTRY entry { - %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = (f32[32,16]{1,0}, f32[16,32]{1,0}) fusion(%p), kind=kInput, calls=%fused_computation -} - )"; - - CompileAndVerifyIr(hlo, MakePlatformSpecificLlvm(R"( -// CHECK: call void BARRIER() - )"), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses(hlo, ErrorSpec{1e-3})); -} - -TEST_F(TransposeEmitterTest, InconsistentTransposes) { - const char* hlo = R"( -HloModule module - -fusion { - p0 = f32[32, 64] parameter(0) - p1 = f32[64, 32] parameter(1) - t0 = f32[64, 32] transpose(p0), dimensions={1,0} - t1 = f32[32, 64] transpose(p1), dimensions={1,0} - ROOT tuple = (f32[64, 32], f32[32, 64]) tuple(t0, t1) -} - -ENTRY module { - p0 = f32[32, 64] parameter(0) - p1 = f32[64, 32] parameter(1) - ROOT fusion = (f32[64, 32], f32[32, 64]) fusion(p0, p1), kind=kLoop, calls=fusion -} - )"; - CompileAndVerifyIr(hlo, MakePlatformSpecificLlvm(R"( -// CHECK-NOT: call void BARRIER() - )"), - /*match_optimized_ir=*/true, - 
/*run_optimization_passes=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses(hlo, ErrorSpec{1e-3})); -} - -} // namespace -} // namespace xla From 880c606b12b8f8f98e0170c1948fff31a6af0370 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Tue, 8 Apr 2025 18:49:31 -0700 Subject: [PATCH 0412/1324] Add shape_with_layout parameter to ClientLibraryTestRunnerMixin comparison. PiperOrigin-RevId: 745373744 --- .../tests/client_library_test_runner_mixin.h | 41 +++++++++++++++---- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/tests/client_library_test_runner_mixin.h b/third_party/xla/xla/tests/client_library_test_runner_mixin.h index 7235f93870af50..aba09bdedfcf2d 100644 --- a/third_party/xla/xla/tests/client_library_test_runner_mixin.h +++ b/third_party/xla/xla/tests/client_library_test_runner_mixin.h @@ -35,6 +35,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/bitmap.h" @@ -143,16 +144,40 @@ class ClientLibraryTestRunnerMixin : public T { void ComputeAndCompareLiteral( XlaBuilder* const builder, const Literal& expected, const absl::Span arguments, - const std::optional error = std::nullopt) { - TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder->Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - BuildAndVerifyHloModule(computation)); - TF_ASSERT_OK_AND_ASSIGN(Literal actual, - this->Execute(std::move(module), arguments)); + const std::optional error = std::nullopt, + const Shape* shape_with_layout = nullptr) { + if (error == std::nullopt) { + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { + LOG(WARNING) << "performing exact comparison of floating point numbers"; + } + } + // We allow using a float expected literal for a non float outputs. In this + // case, we need to convert the expected literal to test_type_. + const Literal* expected_ptr = &expected; + Literal converted_expected; + Shape layout_shape; + if (test_type_ != F32) { + converted_expected = MaybeConvertLiteralToTestType(expected); + expected_ptr = &converted_expected; + if (shape_with_layout != nullptr) { + layout_shape = *shape_with_layout; + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(test_type_); + } + }); + shape_with_layout = &layout_shape; + } + } + TF_ASSERT_OK_AND_ASSIGN( + Literal actual, + this->ExecuteAndTransfer(builder, arguments, shape_with_layout)); if (error.has_value()) { - EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, *error)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, *error)); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)); } } From fa3670c8ca907838eec299ede10b473d365eb158 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 Apr 2025 20:17:49 -0700 Subject: [PATCH 0413/1324] Make `safe_reinterpret_cast` work between function pointer and `void*`. XLA `reinterpret_cast`s between a function pointer and a `void*` in many places. Even though the C++ standard doesn't guarantee this is safe, POSIX requires it to be, and Windows has long supported this as a de facto standard. 
Since XLA only targets POSIX systems and Windows, it's fine to allow this pattern. PiperOrigin-RevId: 745394285 --- .../xla/xla/tsl/util/safe_reinterpret_cast.h | 20 +++++++++++++++---- .../tsl/util/safe_reinterpret_cast_test.cc | 8 ++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h index 085f7014e459cd..51d17d0158281b 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h @@ -68,12 +68,24 @@ template struct IsSafeCast : std::false_type {}; // It's safe to cast a pointer to/from a byte-like type, or to/from the same -// type. +// type. Also, while not guaranteed by the C++ standard, POSIX mandates that +// it's safe to cast a function pointer to/from a void pointer +// (https://pubs.opengroup.org/onlinepubs/9799919799/functions/dlsym.html). +// On Windows (with MSVC), casting a function pointer to/from a void pointer has +// been a widely adopted practice for decades and is considered safe in +// practice, even though it is not explicitly guaranteed by Microsoft. template struct IsSafeCast - : std::integral_constant::value || - IsCvByteLike::value || - std::is_same_v> {}; + : std::integral_constant< + bool, + // To/from a pointer to a byte-like type. + (IsCvByteLike::value || IsCvByteLike::value) || + // From function pointer to void pointer. + (std::is_function_v&& std::is_void_v) || + // From void pointer to function pointer. + (std::is_void_v&& std::is_function_v) || + // Between the same type. + std::is_same_v> {}; // If __restrict is a macro, we assume that the compiler doesn't support // the __restrict keyword (e.g. when the code is compiled for iOS). Otherwsie, diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc index 881cd73b0f4c11..7f9a48f2eb2248 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc @@ -139,5 +139,13 @@ TEST(SafeReinterpretCast, CanCastRestrictPointerToRestrictPointer) { &x)); } +void Dummy() {} + +TEST(SafeReinterepretCast, CanCastFuncPointerToFromVoidPointer) { + void* const void_p = safe_reinterpret_cast(&Dummy); + void (*func_p)() = safe_reinterpret_cast(void_p); + EXPECT_EQ(func_p, &Dummy); +} + } // namespace } // namespace tsl From b549e0f5f606bc20c2414d8fb600930d4e3c33ae Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 8 Apr 2025 22:51:06 -0700 Subject: [PATCH 0414/1324] Automated Code Change PiperOrigin-RevId: 745436991 --- tensorflow/core/kernels/image/BUILD | 5 ++++- tensorflow/core/kernels/image/adjust_hue_op.cc | 6 ++---- tensorflow/core/kernels/image/adjust_saturation_op.cc | 8 ++++---- tensorflow/core/kernels/image/extract_image_patches_op.cc | 1 + tensorflow/core/kernels/image/random_crop_op.cc | 2 ++ tensorflow/core/kernels/image/resize_bicubic_op.cc | 5 +++++ .../kernels/image/sample_distorted_bounding_box_op.cc | 5 +++++ 7 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/image/BUILD b/tensorflow/core/kernels/image/BUILD index 480eadb279bb2a..a47f1771243a1b 100644 --- a/tensorflow/core/kernels/image/BUILD +++ b/tensorflow/core/kernels/image/BUILD @@ -318,7 +318,10 @@ tf_kernel_library( tf_kernel_library( name = "sample_distorted_bounding_box_op", prefix = "sample_distorted_bounding_box_op", - deps = IMAGE_DEPS + ["//tensorflow/core/kernels:stateless_random_ops"], + deps = IMAGE_DEPS + [ + "//tensorflow/core/kernels:stateless_random_ops", + "@com_google_absl//absl/log:check", + ], ) tf_kernel_library( diff --git a/tensorflow/core/kernels/image/adjust_hue_op.cc b/tensorflow/core/kernels/image/adjust_hue_op.cc index fb089f13f8edd9..8795185c365dfd 100644 --- a/tensorflow/core/kernels/image/adjust_hue_op.cc +++ b/tensorflow/core/kernels/image/adjust_hue_op.cc @@ -11,16 +11,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #define EIGEN_USE_THREADS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif -#include "tensorflow/core/kernels/image/adjust_hue_op.h" - -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -28,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/image/adjust_hue_op.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" diff --git a/tensorflow/core/kernels/image/adjust_saturation_op.cc b/tensorflow/core/kernels/image/adjust_saturation_op.cc index 5c108aa2ab7434..5387e636f69a4f 100644 --- a/tensorflow/core/kernels/image/adjust_saturation_op.cc +++ b/tensorflow/core/kernels/image/adjust_saturation_op.cc @@ -12,22 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include +#include #define EIGEN_USE_THREADS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif -#include "tensorflow/core/kernels/image/adjust_saturation_op.h" - -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/image/adjust_saturation_op.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" diff --git a/tensorflow/core/kernels/image/extract_image_patches_op.cc b/tensorflow/core/kernels/image/extract_image_patches_op.cc index a1dbcd9efa3650..b40c59147e51b5 100644 --- a/tensorflow/core/kernels/image/extract_image_patches_op.cc +++ b/tensorflow/core/kernels/image/extract_image_patches_op.cc @@ -15,6 +15,7 @@ limitations under the License. // See docs in ../ops/image_ops.cc. +#include #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS diff --git a/tensorflow/core/kernels/image/random_crop_op.cc b/tensorflow/core/kernels/image/random_crop_op.cc index 987001c58c0a69..1fceed794d29a4 100644 --- a/tensorflow/core/kernels/image/random_crop_op.cc +++ b/tensorflow/core/kernels/image/random_crop_op.cc @@ -15,6 +15,8 @@ limitations under the License. // See docs in ../ops/image_ops.cc. +#include + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/image/resize_bicubic_op.cc b/tensorflow/core/kernels/image/resize_bicubic_op.cc index 23e6251f8a0f48..338a9fbfcf9a98 100644 --- a/tensorflow/core/kernels/image/resize_bicubic_op.cc +++ b/tensorflow/core/kernels/image/resize_bicubic_op.cc @@ -14,6 +14,11 @@ limitations under the License. ==============================================================================*/ // See docs in ../ops/image_ops.cc +#include +#include +#include +#include +#include #define EIGEN_USE_THREADS #include diff --git a/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc b/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc index 90e26496ed8f0a..a754a8cec1fc62 100644 --- a/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc +++ b/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc @@ -15,8 +15,13 @@ limitations under the License. // See docs in ../ops/image_ops.cc. 
#include +#include #include +#include +#include +#include +#include "absl/log/check.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" From b18b4b1508fa9cda72fa45a38fd0c6e66cbfcb17 Mon Sep 17 00:00:00 2001 From: vfdev Date: Wed, 9 Apr 2025 00:09:25 -0700 Subject: [PATCH 0415/1324] PR #24860: Fixed a small typo in the log message Imported from GitHub PR https://github.com/openxla/xla/pull/24860 ```diff - "^^^ Shared objects corresondence map^^^\n\n" + "^^^ Shared objects correspondence map^^^\n\n" ``` cc @hawkinsp Copybara import of the project: -- 8184c004b550bd3fd8946242aa3f4e3e1dc2fdd7 by vfdev-5 : Fixed a typo in the log message Merging this change closes #24860 PiperOrigin-RevId: 745457369 --- third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl b/third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl index 222b7d47c5283d..5c1a36466c8416 100644 --- a/third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl +++ b/third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl @@ -995,7 +995,7 @@ def _pywrap_binaries_impl(ctx): final_binaries = [] original_to_final_binaries = [ - "\n\nvvv Shared objects corresondence map, target = {} vvv".format(ctx.label), + "\n\nvvv Shared objects correspondence map, target = {} vvv".format(ctx.label), ] wheel_locations = {} for i in range(0, len(pywrap_infos)): @@ -1063,7 +1063,7 @@ def _pywrap_binaries_impl(ctx): ) original_to_final_binaries.append( - "^^^ Shared objects corresondence map^^^\n\n", + "^^^ Shared objects correspondence map^^^\n\n", ) print("\n".join(original_to_final_binaries)) From ae30aa9c5fec5fcac1463aba721bdee3d9978050 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Apr 2025 01:10:52 -0700 Subject: [PATCH 0416/1324] Automated Code Change PiperOrigin-RevId: 745474241 --- tensorflow/lite/experimental/genai/genai_ops_wrapper.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/lite/experimental/genai/genai_ops_wrapper.cc b/tensorflow/lite/experimental/genai/genai_ops_wrapper.cc index 8fa8451909e57e..cd57ea2cc55d80 100644 --- a/tensorflow/lite/experimental/genai/genai_ops_wrapper.cc +++ b/tensorflow/lite/experimental/genai/genai_ops_wrapper.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "tensorflow/lite/experimental/genai/genai_ops.h" From c588a494d668d718902620489d226fe4127db029 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 9 Apr 2025 01:12:36 -0700 Subject: [PATCH 0417/1324] [XLA:GPU] Use raw_fd_ostream for Triton passes dump. With the string stream, we don't get any output if the Triton pipeline crashes.
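A raw_fd_ostream writes through the file descriptor as each pass prints its IR, so whatever was dumped before a crash is already on disk; a raw_string_ostream only fills an in-memory buffer that dies with the process. A minimal sketch of the pattern, assuming only the public LLVM support headers (the file name and log text are placeholders):

```c++
#include <system_error>

#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  std::error_code err;
  // Open the log up front; every write goes straight to the file descriptor
  // rather than accumulating in an in-memory string.
  llvm::raw_fd_ostream log_stream("triton-passes.log", err,
                                  llvm::sys::fs::OF_None);
  if (err) {
    llvm::errs() << "failed to open dump file: " << err.message() << "\n";
    return 1;
  }
  log_stream << "IR after pass 1 ...\n";  // survives a crash in a later pass
  return 0;
}
```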
PiperOrigin-RevId: 745474705 --- .../gpu/codegen/triton/fusion_emitter.cc | 60 ++++++++++++------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 0265fcea8553a5..1983d45ecdf671 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include // NOLINT #include #include #include @@ -39,6 +40,7 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -135,6 +137,7 @@ limitations under the License. #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/path.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -1814,6 +1817,7 @@ absl::StatusOr CompileTritonToLLVM( #endif bool should_dump_mlir_passes = + hlo_config.debug_options().xla_enable_dumping() && DumpingEnabledForHloModule(hlo_module) && DumpingEnabledForHloPass("triton-fusion-emitter", hlo_config.debug_options()); @@ -1821,6 +1825,40 @@ absl::StatusOr CompileTritonToLLVM( mlir::PassManager pm(&mlir_context); pm.enableVerifier(should_verify); + std::optional log_stream; + if (should_dump_mlir_passes) { + std::string outputs_dir; + if (!tsl::io::GetTestUndeclaredOutputsDir(&outputs_dir)) { + outputs_dir = hlo_config.debug_options().xla_dump_to(); + } + if (!outputs_dir.empty()) { + const std::string basename = + absl::StrCat(absl::string_view(tsl::io::Basename(hlo_module.name())), + ".", kernel_name, ".triton-passes.log"); + std::string path = tsl::io::JoinPath(outputs_dir, basename); + std::error_code err; + log_stream.emplace(path, err, llvm::sys::fs::OF_None); + if (err) { + log_stream.reset(); + LOG(ERROR) << err.message(); + } else { + pm.getContext()->disableMultithreading(); + auto print_always = [](mlir::Pass*, mlir::Operation*) { return true; }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/print_always, + /*shouldPrintAfterPass=*/print_always, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure=*/true, *log_stream, + /*opPrintingFlags=*/{}); + } + } else { + LOG(ERROR) + << "--xla_dump_hlo_pass_re=triton-fusion-emitter is set, but neither " + << "the environment variable TEST_UNDECLARED_OUTPUTS_DIR nor the " + << "flag --xla_dump_to is set, so the llvm dumps are disabled."; + } + } + // TODO(b/315957220): Propagate TMA flag once it's supported. pm.addPass(mlir::triton::xla::CreateTritonXLAExtractInsertToTritonPass( device_info, /*tma_enabled=*/false)); @@ -1854,30 +1892,8 @@ absl::StatusOr CompileTritonToLLVM( // llvm::Linker::linkModules() segfaults if we don't strip locations. 
pm.addPass(mlir::createStripDebugInfoPass()); - std::string mlir_passes_dump_result; - llvm::raw_string_ostream log_stream(mlir_passes_dump_result); - if (should_dump_mlir_passes) { - pm.getContext()->disableMultithreading(); - auto print_always = [](mlir::Pass*, mlir::Operation*) { return true; }; - pm.enableIRPrinting(/*shouldPrintBeforePass=*/print_always, - /*shouldPrintAfterPass=*/print_always, - /*printModuleScope=*/true, - /*printAfterOnlyOnChange=*/false, - /*printAfterOnlyOnFailure=*/true, log_stream, - /*opPrintingFlags=*/{}); - - pm.printAsTextualPipeline(log_stream); - log_stream.write("\n\n", 2); - } - bool succeeded = mlir::succeeded(pm.run(triton_module)); - if (should_dump_mlir_passes) { - DumpToFileInDirOrStdout(hlo_module, "", - absl::StrCat(kernel_name, ".triton-passes.log"), - mlir_passes_dump_result); - } - if (!succeeded) { return Internal("Failed to compile Triton kernel."); } From a7dbd5fa60605fd1c7fb44f61fcc3bede347b3f2 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 9 Apr 2025 01:29:17 -0700 Subject: [PATCH 0418/1324] [XLA:GPU] Delete unused `xla_gpu_ensure_minor_dot_contraction_dims` flag and tests. PiperOrigin-RevId: 745479113 --- .../fusion_emitter_device_legacy_port_test.cc | 121 ------------------ .../fusion_emitter_device_legacy_test.cc | 117 ----------------- third_party/xla/xla/debug_options_flags.cc | 8 -- .../gpu/transforms/layout_assignment.cc | 9 -- third_party/xla/xla/xla.proto | 7 +- 5 files changed, 2 insertions(+), 260 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index bbc4a7e9651e96..c40fb80185a397 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -3887,127 +3887,6 @@ ENTRY e { /*run_hlo_passes=*/false)); } -// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be -// moved to deviceless test file. We should move all the -// `TritonGemmContractionDims` tests. 
-class TritonGemmContractionDims : public TritonGemmTest { - public: - DebugOptions GetDebugOptionsForTest() const override { - DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_ensure_minor_dot_contraction_dims(true); - debug_options.set_xla_gpu_autotune_level(0); - return debug_options; - } -}; - -TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { - if (!SupportsBF16(GpuComputeCapability())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view kHloText = R"( -HloModule m - -ENTRY e { - p0 = bf16[16,40]{1,0} parameter(0) - p1 = bf16[40,32]{1,0} parameter(1) - ROOT dot = bf16[16,32]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - EXPECT_THAT(module->entry_computation() - ->root_instruction() - ->fused_instructions_computation() - ->root_instruction(), - GmockMatch(m::Dot(m::Fusion().WithShape(BF16, {16, 40}, {1, 0}), - m::Fusion().WithShape(BF16, {40, 32}, {0, 1})) - .WithShape(BF16, {16, 32}, {1, 0}))); -} - -TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { - if (!SupportsBF16(GpuComputeCapability())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view kHloText = R"( -HloModule m - -ENTRY e { - p0 = bf16[32,4,36]{2,1,0} parameter(0) - p1 = bf16[40,4,36]{2,1,0} parameter(1) - ROOT dot = bf16[4,32,40]{2,1,0} dot(p0, p1), - lhs_batch_dims={1}, lhs_contracting_dims={2}, - rhs_batch_dims={1}, rhs_contracting_dims={2} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - // The contracting dims were already minor, so the layout is unchanged - // (non-major batch dims are fine). - EXPECT_THAT( - module->entry_computation() - ->root_instruction() - ->fused_instructions_computation() - ->root_instruction(), - GmockMatch(m::Dot(m::Fusion().WithShape(BF16, {32, 4, 36}, {2, 1, 0}), - m::Fusion().WithShape(BF16, {40, 4, 36}, {2, 1, 0})) - .WithShape(BF16, {4, 32, 40}, {2, 1, 0}))); -} - -TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { - if (!SupportsBF16(GpuComputeCapability())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view kHloText = R"( -HloModule m - -ENTRY e { - parameter_1 = bf16[16,16,48]{2,1,0} parameter(1) - parameter_2 = bf16[16,48,32]{2,1,0} parameter(0) - ROOT dot.16125 = bf16[16,16,32]{2,1,0} dot(parameter_1, parameter_2), - lhs_batch_dims={1}, lhs_contracting_dims={2}, - rhs_batch_dims={0}, rhs_contracting_dims={1} -})"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - // lhs has minor contracting dims, so the layout is changed. - // rhs changes layout to have minor contracting dims. 
- EXPECT_THAT( - module->entry_computation() - ->root_instruction() - ->fused_instructions_computation() - ->root_instruction(), - GmockMatch(m::Dot(m::Fusion().WithShape(BF16, {16, 16, 48}, {2, 1, 0}), - m::Fusion().WithShape(BF16, {16, 48, 32}, {1, 2, 0})) - .WithShape(BF16, {16, 16, 32}, {2, 1, 0}))); -} - -TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { - if (!SupportsBF16(GpuComputeCapability())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view kHloText = R"( -HloModule m - -ENTRY e { - p0 = bf16[16,32]{1,0} parameter(0) - p1 = bf16[40,32]{0,1} parameter(1) - ROOT dot = bf16[16,40]{1,0} dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -})"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - EXPECT_THAT(module->entry_computation() - ->root_instruction() - ->fused_instructions_computation() - ->root_instruction(), - GmockMatch(m::Dot(m::Fusion().WithShape(BF16, {16, 32}, {1, 0}), - m::Fusion().WithShape(BF16, {32, 40}, {1, 0})) - .WithShape(BF16, {16, 40}, {1, 0}))); -} - // TODO(b/393299275): this test uncovers a bug in hoisting bitcasts through // broadcasts (seems to generate a type mismatch). TEST_F(TritonTest, DISABLED_UseTF32For8BitOrLessWithF32) { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc index 9feda522639a66..1319dacde17e67 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc @@ -3968,123 +3968,6 @@ ENTRY e { /*run_hlo_passes=*/false)); } -class TritonGemmContractionDims : public TritonGemmTest { - public: - DebugOptions GetDebugOptionsForTest() const override { - DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_ensure_minor_dot_contraction_dims(true); - return debug_options; - } -}; - -TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { - if (!SupportsBF16(GpuComputeComp())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view kHloText = R"( -HloModule m - -ENTRY e { - parameter.0 = bf16[16,40]{1,0} parameter(0) - parameter.1 = bf16[40,32]{1,0} parameter(1) - ROOT dot.31472 = bf16[16,32]{1,0} dot(parameter.0, parameter.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - EXPECT_THAT(module->entry_computation() - ->root_instruction() - ->fused_instructions_computation() - ->root_instruction(), - GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 40}, {1, 0}), - m::Op().WithShape(BF16, {40, 32}, {0, 1})) - .WithShape(BF16, {16, 32}, {1, 0}))); -} - -TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { - if (!SupportsBF16(GpuComputeComp())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view kHloText = R"( -HloModule m - -ENTRY e { - parameter_0 = bf16[32,4,36]{2,1,0} parameter(0) - parameter_1 = bf16[40,4,36]{2,1,0} parameter(1) - ROOT dot.16450 = bf16[4,32,40]{2,1,0} dot(parameter_0, parameter_1), - lhs_batch_dims={1}, lhs_contracting_dims={2}, - rhs_batch_dims={1}, rhs_contracting_dims={2} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - // The contracting dims were already minor, so the layout is unchanged - // (non-major batch 
dims are fine). - EXPECT_THAT(module->entry_computation() - ->root_instruction() - ->fused_instructions_computation() - ->root_instruction(), - GmockMatch(m::Dot(m::Op().WithShape(BF16, {32, 4, 36}, {2, 1, 0}), - m::Op().WithShape(BF16, {40, 4, 36}, {2, 1, 0})) - .WithShape(BF16, {4, 32, 40}, {2, 1, 0}))); -} - -TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { - if (!SupportsBF16(GpuComputeComp())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view kHloText = R"( -HloModule m - -ENTRY e { - parameter_1 = bf16[16,16,48]{2,1,0} parameter(1) - parameter_2 = bf16[16,48,32]{2,1,0} parameter(0) - ROOT dot.16125 = bf16[16,16,32]{2,1,0} dot(parameter_1, parameter_2), - lhs_batch_dims={1}, lhs_contracting_dims={2}, - rhs_batch_dims={0}, rhs_contracting_dims={1} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - // lhs has minor contracting dims, so the layout is changed. - // rhs changes layout to have minor contracting dims. - EXPECT_THAT( - module->entry_computation() - ->root_instruction() - ->fused_instructions_computation() - ->root_instruction(), - GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 16, 48}, {2, 1, 0}), - m::Op().WithShape(BF16, {16, 48, 32}, {1, 2, 0})) - .WithShape(BF16, {16, 16, 32}, {2, 1, 0}))); -} - -TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { - if (!SupportsBF16(GpuComputeComp())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view kHloText = R"( -HloModule m - -ENTRY e { - parameter_0 = bf16[16,32]{1,0} parameter(0) - parameter_1 = bf16[40,32]{0,1} parameter(1) - ROOT dot.15148 = bf16[16,40]{1,0} dot(parameter_0, parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - EXPECT_THAT(module->entry_computation() - ->root_instruction() - ->fused_instructions_computation() - ->root_instruction(), - GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 32}, {1, 0}), - m::Op().WithShape(BF16, {32, 40}, {1, 0})) - .WithShape(BF16, {16, 40}, {1, 0}))); -} - TEST_F(TritonTest, UseTF32For8BitOrLessWithF32) { const std::string hlo_text = R"( HloModule t diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 9ea257825cc70a..9afc77a08a25c1 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -258,7 +258,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_while_loop_double_buffering(false); opts.set_xla_gpu_enable_while_loop_unrolling( DebugOptions::WHILE_LOOP_UNROLLING_AUTO_UNROLL); - opts.set_xla_gpu_ensure_minor_dot_contraction_dims(false); opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(true); opts.set_xla_gpu_fail_ptx_compilation_on_register_spilling(false); opts.set_xla_gpu_llvm_verification_level(0); @@ -1944,13 +1943,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, &DebugOptions::set_xla_gpu_enable_while_loop_double_buffering), debug_options->xla_gpu_enable_while_loop_double_buffering(), "Enable double buffering for while loop")); - flag_list->push_back(tsl::Flag( - "xla_gpu_ensure_minor_dot_contraction_dims", - bool_setter_for( - &DebugOptions::set_xla_gpu_ensure_minor_dot_contraction_dims), - debug_options->xla_gpu_ensure_minor_dot_contraction_dims(), - "Ensure that the contracting dimensions for matmul operands are the most " - "minor by changing layouts accordingly")); 
flag_list->push_back(tsl::Flag( "xla_gpu_filter_kernels_spilling_registers_on_autotuning", bool_setter_for( diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index b3ff0ea817d0bf..fd1d6182e2ef2c 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -370,20 +370,12 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints( // dimensions. Additionally, no batch dimension can be in the most // minor physical dimension for inputs or the output. - const bool xla_gpu_ensure_minor_dot_contraction_dims = - instruction->GetModule() - ->config() - .debug_options() - .xla_gpu_ensure_minor_dot_contraction_dims(); const bool pack_along_contracting_dims = instruction->GetModule() ->config() .debug_options() .xla_gpu_experimental_pack_dot_operands_along_k_dimension(); - const bool is_bf16_to_bf16 = - (output_type == PrimitiveType::BF16 && lhs.type == PrimitiveType::BF16 && - rhs.type == PrimitiveType::BF16); const bool is_s8_to_s32 = output_type == PrimitiveType::S32 && lhs.type == PrimitiveType::S8 && rhs.type == PrimitiveType::S8; @@ -395,7 +387,6 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints( const se::CudaComputeCapability* cc = std::get_if(&gpu_version_); const bool both_operands_require_minor_contraction_dims = - (is_bf16_to_bf16 && xla_gpu_ensure_minor_dot_contraction_dims) || is_s8_to_s32 || (is_fp8 && !(cc && cc->IsBlackwell())); for (const Side& side : {lhs, rhs}) { diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 6e10808a81b60d..b0389816056c9d 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -500,10 +500,6 @@ message DebugOptions { // Determine the while loop unrolling scheme. WhileLoopUnrolling xla_gpu_enable_while_loop_unrolling = 294; - // Change the layout of the second triton dot operand to be column major. - // Only works for (bf16 x bf16) -> bf16. - bool xla_gpu_ensure_minor_dot_contraction_dims = 249; - // Excludes non-deterministic ops from compiled executables. // Unlike --xla_gpu_deterministic_ops does not disable autotuning - the // compilation itself can be non-deterministic. @@ -1233,8 +1229,9 @@ message DebugOptions { // xla_gpu_enable_cudnn_fmha // xla_gpu_unsupported_force_triton_gemm // xla_allow_get_default_platform + // xla_gpu_ensure_minor_dot_contraction_dims reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320, - 325, 326, 332, 361, 270, 229, 271, 279, 218, 369, 371; + 325, 326, 332, 361, 270, 229, 271, 279, 218, 369, 371, 249; } // Contains flags which affects the GPU compilation result. From 22885935f463c5c5ebb4d347f3b09b9da001747e Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 9 Apr 2025 01:50:35 -0700 Subject: [PATCH 0419/1324] Automated Code Change PiperOrigin-RevId: 745485457 --- tensorflow/python/grappler/BUILD | 82 +++++++++++++++++++ tensorflow/python/grappler/cluster_wrapper.cc | 4 + tensorflow/python/grappler/cost_analyzer.cc | 13 +++ tensorflow/python/grappler/cost_analyzer.h | 2 + tensorflow/python/grappler/item_wrapper.cc | 1 + tensorflow/python/grappler/model_analyzer.cc | 7 +- tensorflow/python/grappler/model_analyzer.h | 2 + .../python/grappler/model_analyzer_wrapper.cc | 1 + .../python/grappler/tf_optimizer_wrapper.cc | 2 + 9 files changed, 113 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/grappler/BUILD b/tensorflow/python/grappler/BUILD index 687fd36ce2053e..6f5dd068d11ec5 100644 --- a/tensorflow/python/grappler/BUILD +++ b/tensorflow/python/grappler/BUILD @@ -26,6 +26,8 @@ cc_library( "//tensorflow/core/grappler/costs:cost_estimator", "//tensorflow/core/grappler/costs:measuring_cost_estimator", "//tensorflow/core/grappler/costs:utils", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", ] + tf_protos_grappler(), alwayslink = 1, ) @@ -58,12 +60,28 @@ tf_python_pybind_extension( starlark_only = True, deps = [ ":cost_analyzer_headers", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_headers_lib", + "//tensorflow/core/common_runtime:device_set", "//tensorflow/core/common_runtime/gpu:gpu_id", + "//tensorflow/core/framework:allocator", + "//tensorflow/core/platform:threadpool_options", "//tensorflow/python/lib/core:pybind11_status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:thread_annotations", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", "@pybind11", ], ) @@ -78,6 +96,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/costs:graph_properties", + "@com_google_absl//absl/status", ], ) @@ -94,10 +113,16 @@ tf_python_pybind_extension( ], starlark_only = True, deps = [ + "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/python/lib/core:pybind11_status", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:platform_port", "@pybind11", ] + if_pywrap(["//tensorflow/python/grappler:model_analyzer_lib"]), ) @@ -128,11 +153,29 @@ tf_python_pybind_extension( "_pywrap_tf_item.pyi", ], deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_headers_lib", + "//tensorflow/core/common_runtime:device_set", 
"//tensorflow/core/common_runtime/gpu:gpu_id", + "//tensorflow/core/framework:allocator", + "//tensorflow/core/grappler:utils", "//tensorflow/python/lib/core:pybind11_status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:thread_annotations", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", "@pybind11", ] + if_not_windows(["//tensorflow/core/grappler/costs:graph_properties"]), # b/148556093, ) @@ -213,13 +256,32 @@ tf_python_pybind_extension( "_pywrap_tf_cluster.pyi", ], deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_headers_lib", + "//tensorflow/core/common_runtime:device_set", "//tensorflow/core/common_runtime/gpu:gpu_id", + "//tensorflow/core/framework:allocator", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/platform:status", "//tensorflow/python/lib/core:pybind11_status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:thread_annotations", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", "@pybind11", ] + if_pywrap( if_true = [ @@ -292,12 +354,32 @@ tf_python_pybind_extension( # }), # static_deps = tf_python_pybind_static_deps(), deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_headers_lib", + "//tensorflow/core/common_runtime:device", + "//tensorflow/core/common_runtime:device_factory", + "//tensorflow/core/common_runtime:device_set", "//tensorflow/core/common_runtime/gpu:gpu_id", + "//tensorflow/core/framework:allocator", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", "//tensorflow/python/lib/core:pybind11_status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:thread_annotations", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", "@pybind11", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ] + if_pywrap( diff --git a/tensorflow/python/grappler/cluster_wrapper.cc b/tensorflow/python/grappler/cluster_wrapper.cc index dbf97535082413..0df0f3fcc25dc3 100644 --- a/tensorflow/python/grappler/cluster_wrapper.cc +++ b/tensorflow/python/grappler/cluster_wrapper.cc @@ -15,7 +15,9 @@ limitations under the License. 
#include #include +#include #include +#include #include #include #include @@ -24,6 +26,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 #include "tensorflow/core/framework/kernel_def.pb.h" diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc index 90f9b426d3756c..44239ebe140536 100644 --- a/tensorflow/python/grappler/cost_analyzer.cc +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -15,10 +15,23 @@ limitations under the License. #include "tensorflow/python/grappler/cost_analyzer.h" +#include +#include +#include #include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { namespace grappler { diff --git a/tensorflow/python/grappler/cost_analyzer.h b/tensorflow/python/grappler/cost_analyzer.h index 44e1e45265b9c5..b14a89b1f318d9 100644 --- a/tensorflow/python/grappler/cost_analyzer.h +++ b/tensorflow/python/grappler/cost_analyzer.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_PYTHON_GRAPPLER_COST_ANALYZER_H_ #include + +#include "absl/status/status.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/python/grappler/item_wrapper.cc b/tensorflow/python/grappler/item_wrapper.cc index 13d2ee6def5c75..27207ec2c053f4 100644 --- a/tensorflow/python/grappler/item_wrapper.cc +++ b/tensorflow/python/grappler/item_wrapper.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include "pybind11/pybind11.h" // from @pybind11 diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc index 202eb758a91221..5fe0e94ff8947b 100644 --- a/tensorflow/python/grappler/model_analyzer.cc +++ b/tensorflow/python/grappler/model_analyzer.cc @@ -15,10 +15,15 @@ limitations under the License. #include "tensorflow/python/grappler/model_analyzer.h" -#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" #include "tensorflow/core/grappler/grappler_item.h" namespace tensorflow { diff --git a/tensorflow/python/grappler/model_analyzer.h b/tensorflow/python/grappler/model_analyzer.h index d66ad8915c99b5..c76d850a5b119a 100644 --- a/tensorflow/python/grappler/model_analyzer.h +++ b/tensorflow/python/grappler/model_analyzer.h @@ -17,6 +17,8 @@ limitations under the License. 
#define TENSORFLOW_PYTHON_GRAPPLER_MODEL_ANALYZER_H_ #include + +#include "absl/status/status.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/python/grappler/model_analyzer_wrapper.cc b/tensorflow/python/grappler/model_analyzer_wrapper.cc index 86ac40701e303a..c5db3fffa3c123 100644 --- a/tensorflow/python/grappler/model_analyzer_wrapper.cc +++ b/tensorflow/python/grappler/model_analyzer_wrapper.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/core/grappler/grappler_item_builder.h" diff --git a/tensorflow/python/grappler/tf_optimizer_wrapper.cc b/tensorflow/python/grappler/tf_optimizer_wrapper.cc index 08b3a0895071a8..4e88995858e6a7 100644 --- a/tensorflow/python/grappler/tf_optimizer_wrapper.cc +++ b/tensorflow/python/grappler/tf_optimizer_wrapper.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include #include +#include +#include #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf From 94a3eabdab8dbdddfd5ecdb8f205470225a8eb9f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 9 Apr 2025 01:57:09 -0700 Subject: [PATCH 0420/1324] [xla] Add benchmark for grouped rendezvous This benchmark roughly corresponds to a rendezvous that happens for collective operation with 4/8 local devices grouped into 2/4 devices. ``` -------------------------------------------------------------------------------- Benchmark Time CPU Iterations -------------------------------------------------------------------------------- BM_GroupedRendezvous/2/2/process_time 55178 ns 111084 ns 7205 BM_GroupedRendezvous/4/2/process_time 54914 ns 158092 ns 3983 BM_GroupedRendezvous/2/4/process_time 73235 ns 213743 ns 3225 ``` PiperOrigin-RevId: 745487409 --- .../xla/xla/service/rendezvous_test.cc | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/rendezvous_test.cc b/third_party/xla/xla/service/rendezvous_test.cc index 36e2958413c301..789e01a59a9e7f 100644 --- a/third_party/xla/xla/service/rendezvous_test.cc +++ b/third_party/xla/xla/service/rendezvous_test.cc @@ -253,7 +253,7 @@ static void BM_RendezvousWithValues(benchmark::State& state) { for (auto _ : state) { absl::BlockingCounter counter(num_threads); for (int64_t i = 0; i < num_threads; ++i) { - thread_pool.Schedule([&] { + thread_pool.Schedule([&, i] { int32_t value = i; Rendezvous("rendezvous_test", 0, value, num_threads, [](auto) { return 42; }); @@ -264,6 +264,27 @@ static void BM_RendezvousWithValues(benchmark::State& state) { } } +static void BM_GroupedRendezvous(benchmark::State& state) { + int64_t num_groups = state.range(0); + int64_t group_size = state.range(1); + + auto thread_pool = CreateThreadPool(num_groups * group_size); + + for (auto _ : state) { + absl::BlockingCounter counter(num_groups * group_size); + for (int64_t group = 0; group < num_groups; ++group) { + for (int64_t i = 0; i < group_size; ++i) { + thread_pool.Schedule([&, group] { + Rendezvous("rendezvous_test", group, group_size, + [] { return 42; }); + counter.DecrementCount(); + }); + } + } + counter.Wait(); + } +} + BENCHMARK(BM_Rendezvous) ->MeasureProcessCPUTime() ->Arg(2) @@ -280,5 +301,11 @@ BENCHMARK(BM_RendezvousWithValues) ->Arg(16) ->Arg(32); +BENCHMARK(BM_GroupedRendezvous) + ->MeasureProcessCPUTime() + ->ArgPair(2, 2) + ->ArgPair(4, 2) + ->ArgPair(2, 4); + 
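+// Each ArgPair above is (num_groups, group_size): the benchmark spawns
+// num_groups * group_size threads, so the three cases model 4 or 8 local
+// devices rendezvousing in 2 or 4 independent groups.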
} // namespace } // namespace xla From 7dc0b9699c917a2590740272199db2cfd68a2309 Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Wed, 9 Apr 2025 01:58:19 -0700 Subject: [PATCH 0421/1324] PR #24708: [ROCm] Fix run_hlo_module Imported from GitHub PR https://github.com/openxla/xla/pull/24708 Copybara import of the project: -- fd4b7ed3abb25b697d4af32716ad9a59403fe7c2 by Dragan Mladjenovic : [ROCm] Fix run_hlo_module Merging this change closes #24708 PiperOrigin-RevId: 745487766 --- third_party/xla/xla/tools/BUILD | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 4a472ea029f6de..d38c371fe4f096 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -2,7 +2,7 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm", "if_rocm_is_configured") load("//xla:lit.bzl", "lit_test_suite") load( "//xla:xla.default.bzl", @@ -536,6 +536,8 @@ xla_cc_binary( "//xla/service:gpu_plugin", ]) + if_cuda([ "//xla/stream_executor:cuda_platform", + ]) + if_rocm([ + "//xla/stream_executor:rocm_platform", ]), ) From 270f6a798a1e441e6bd22f2dc6118377863e2909 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Apr 2025 02:01:23 -0700 Subject: [PATCH 0422/1324] Add interface of Codegen backend for autotuner. PiperOrigin-RevId: 745488823 --- third_party/xla/xla/backends/autotuner/BUILD | 27 +++++++ .../xla/xla/backends/autotuner/backend.h | 60 ++++++++++++++ .../xla/backends/autotuner/backends/gpu/BUILD | 20 +++++ .../autotuner/backends/gpu/gpu_backend.h | 81 +++++++++++++++++++ 4 files changed, 188 insertions(+) create mode 100644 third_party/xla/xla/backends/autotuner/BUILD create mode 100644 third_party/xla/xla/backends/autotuner/backend.h create mode 100644 third_party/xla/xla/backends/autotuner/backends/gpu/BUILD create mode 100644 third_party/xla/xla/backends/autotuner/backends/gpu/gpu_backend.h diff --git a/third_party/xla/xla/backends/autotuner/BUILD b/third_party/xla/xla/backends/autotuner/BUILD new file mode 100644 index 00000000000000..288e7c14bcbe8a --- /dev/null +++ b/third_party/xla/xla/backends/autotuner/BUILD @@ -0,0 +1,27 @@ +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "backend", + hdrs = ["backend.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:executable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:protobuf", + ], +) diff --git a/third_party/xla/xla/backends/autotuner/backend.h b/third_party/xla/xla/backends/autotuner/backend.h new file mode 100644 index 00000000000000..30954f6f656d88 --- /dev/null +++ b/third_party/xla/xla/backends/autotuner/backend.h @@ -0,0 +1,60 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_AUTOTUNER_BACKEND_H_
+#define XLA_BACKENDS_AUTOTUNER_BACKEND_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/executable.h"
+#include "tsl/platform/protobuf.h"
+
+namespace xla {
+
+using BackendConfig = tsl::protobuf::Message;
+
+// Interface for a codegen backend which can compile HLO instructions with
+// different configurations. It can be used to enumerate the supported configs
+// and to compile an HLO instruction with any of them.
+class Backend {
+ public:
+  virtual ~Backend() = default;
+
+  virtual absl::string_view name() const = 0;
+
+  // Returns all supported configs for the given HLO instruction.
+  virtual std::vector<std::unique_ptr<BackendConfig>> GetSupportedConfigs(
+      const HloInstruction& instr) = 0;
+
+  // Returns a default config for the given HLO instruction.
+  virtual absl::StatusOr<std::unique_ptr<BackendConfig>> GetDefaultConfig(
+      HloInstruction* instr) {
+    return absl::UnimplementedError("Not implemented.");
+  };
+
+  // Wraps the HLO instruction in a module, assigns the given config, and
+  // compiles it.
+  virtual absl::StatusOr<std::unique_ptr<Executable>> Compile(
+      const HloInstruction& instr, const BackendConfig& config) = 0;
+};
+
+}  // namespace xla
+
+#endif  // XLA_BACKENDS_AUTOTUNER_BACKEND_H_
diff --git a/third_party/xla/xla/backends/autotuner/backends/gpu/BUILD b/third_party/xla/xla/backends/autotuner/backends/gpu/BUILD
new file mode 100644
index 00000000000000..4de5fe9f3b1db2
--- /dev/null
+++ b/third_party/xla/xla/backends/autotuner/backends/gpu/BUILD
@@ -0,0 +1,20 @@
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    licenses = ["notice"],
+)
+
+cc_library(
+    name = "gpu_backend",
+    hdrs = ["gpu_backend.h"],
+    deps = [
+        "//xla/backends/autotuner:backend",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:compiler",
+        "//xla/service:executable",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
diff --git a/third_party/xla/xla/backends/autotuner/backends/gpu/gpu_backend.h b/third_party/xla/xla/backends/autotuner/backends/gpu/gpu_backend.h
new file mode 100644
index 00000000000000..46d23ba057fe6c
--- /dev/null
+++ b/third_party/xla/xla/backends/autotuner/backends/gpu/gpu_backend.h
@@ -0,0 +1,81 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef XLA_BACKENDS_AUTOTUNER_BACKENDS_GPU_GPU_BACKEND_H_ +#define XLA_BACKENDS_AUTOTUNER_BACKENDS_GPU_GPU_BACKEND_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/backends/autotuner/backend.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/compiler.h" +#include "xla/service/executable.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +// Abstract base class for GPU backends, implementing the Backend interface. +class GpuBackend : public Backend { + public: + // target_config and compiler should outlive the backend. + GpuBackend(absl::string_view name, + const Compiler::TargetConfig& target_config, Compiler* compiler) + : name_(name), target_config_(target_config), compiler_(compiler) {} + + absl::string_view name() const override { return name_; } + + absl::StatusOr> Compile( + const HloInstruction& hlo_instruction, + const BackendConfig& config) override { + TF_ASSIGN_OR_RETURN(auto hlo_module, WrapInModule(hlo_instruction, config)); + + Compiler::CompileOptions options; + options.target_config = target_config_; + + TF_ASSIGN_OR_RETURN(auto optimized_module, + RunHloPasses(std::move(hlo_module), options)); + return compiler_->RunBackend(std::move(optimized_module), + /*executor=*/nullptr, options); + } + + private: + // TODO(b/407494653): Provide a default implementation. + virtual absl::StatusOr> WrapInModule( + const HloInstruction& hlo_instruction, const BackendConfig& config) = 0; + + // Optimize the HLO module. + // TODO(b/407494653): Remove this when XLA pipelines is fixed and we autotune + // only optimized and fused HLOs. + virtual absl::StatusOr> RunHloPasses( + std::unique_ptr hlo_module, + const Compiler::CompileOptions& options) = 0; + + std::string name_; + const Compiler::TargetConfig& target_config_; + // TODO(b/407494653): remove compiler when we don't need to run any HLO passes + // and the codegen backend can directly produce an executable without a + // compiler instance. + Compiler* compiler_; +}; + +} // namespace xla + +#endif // XLA_BACKENDS_AUTOTUNER_BACKENDS_GPU_GPU_BACKEND_H_ From 6ffe819a4e8407a672e96f87a00e7f0efbd62e59 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Apr 2025 02:02:35 -0700 Subject: [PATCH 0423/1324] Update GraphDef version to 2192. PiperOrigin-RevId: 745489325 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index c8bc4044c228d2..8cb8f11d7f0d53 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2191 // Updated: 2025/4/8 +#define TF_GRAPH_DEF_VERSION 2192 // Updated: 2025/4/9 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 957ad8f2eec97e3dafd131de91cb960ea24f185c Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 9 Apr 2025 02:03:16 -0700 Subject: [PATCH 0424/1324] compat: Update forward compatibility horizon to 2025-04-09 PiperOrigin-RevId: 745489586 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 3b77d06cffd2fe..beebbbb3444e76 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 8) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 9) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 53c9e6ec7a72b682824d1e0a5477e8a109b1aab1 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Wed, 9 Apr 2025 02:18:09 -0700 Subject: [PATCH 0425/1324] [XLA] Escape frontend attrs when they are printed as strings. This is needed when the string has characters like \" that need escaping when they are printed, so they can be parsed correctly. PiperOrigin-RevId: 745494179 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 2 +- third_party/xla/xla/hlo/parser/hlo_parser_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 550a474059b265..ee444ab0951121 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -4297,7 +4297,7 @@ std::string FrontendAttributesToString( if (LexesAsJsonDict(item.second)) { absl::StrAppend(out, item.first, "=", item.second); } else { - absl::StrAppend(out, item.first, "=\"", item.second, "\""); + absl::StrAppend(out, item.first, "=\"", CEscape(item.second), "\""); } }; return absl::StrFormat("{%s}", diff --git a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc index 52afbbab8b92a1..04cc5ab3416852 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc @@ -4153,7 +4153,7 @@ TEST_F(HloParserTest, ParseUnknownSharding) { TEST_F(HloParserTest, ParseFrontendAttributes) { const std::string original = - R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})"; + R"({attr_a="test_a",attr_b="b",attr_c={type="s64"},attr_d="a=\"b/c\""})"; TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes, ParseFrontendAttributes(original)); EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original); From d9a88a0f8671c871fe6c2be41f59bb434e51c42b Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 9 Apr 2025 02:40:14 -0700 Subject: [PATCH 0426/1324] [XLA:GPU] Disable autotuning by default in `fusion_emitter_device_legacy_port_test.cc`. Makes the already enabled tests run 4x faster. 
PiperOrigin-RevId: 745500714 --- .../codegen/triton/fusion_emitter_device_legacy_port_test.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index c40fb80185a397..d11a868d6819bd 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -142,6 +142,9 @@ class TritonGemmTest : public TritonTest { public: DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); + // Disable autotuning by default, re-enable it on a per-test basis in order + // to avoid unnecessary slowness. + debug_options.set_xla_gpu_autotune_level(0); // Do not fall back to cuBLAS and disable cuDNN; we are testing Triton. debug_options.set_xla_gpu_cublas_fallback(false); debug_options.set_xla_gpu_cudnn_gemm_fusion_level(0); From 3cf288ab95749b3a657d5f4445dcdda6858f6b17 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Wed, 9 Apr 2025 02:45:33 -0700 Subject: [PATCH 0427/1324] Fix the build issue after commit df841fe When running tests, like `bazel test //xla/tests:hlo_op_profiler_test`, on a machine with only CUDA configured, the build system will try to build ROCm tests and fail. This commit fixes the issue by only building ROCm tests if ROCm is configured and vice versa. -- 6bebc0405cc8bc3a44c44845799b3861ac3376ee by Shraiysh Vaishay : address comments Merging this change closes #23997 PiperOrigin-RevId: 745502157 --- third_party/xla/xla/tests/build_defs.bzl | 11 ++++++++--- .../xla/xla/tsl/platform/default/cuda_build_defs.bzl | 6 ++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/tests/build_defs.bzl b/third_party/xla/xla/tests/build_defs.bzl index 4adcf42b1a0acb..4d878cd2cdf782 100644 --- a/third_party/xla/xla/tests/build_defs.bzl +++ b/third_party/xla/xla/tests/build_defs.bzl @@ -1,5 +1,9 @@ """Build rules for XLA testing. This file is only used for the OSS build.""" +load( + "@local_config_rocm//rocm:build_defs.bzl", + "is_rocm_configured", +) load("//xla:xla.default.bzl", "xla_cc_test") load("//xla/tests:plugin.bzl", "plugins") load("//xla/tsl:package_groups.bzl", "DEFAULT_LOAD_VISIBILITY") @@ -8,6 +12,7 @@ load( "tf_gpu_tests_tags", ) load("//xla/tsl/platform/default:build_config.bzl", "strict_cc_test") +load("//xla/tsl/platform/default:cuda_build_defs.bzl", "is_cuda_configured") visibility(DEFAULT_LOAD_VISIBILITY) @@ -362,8 +367,9 @@ def xla_test( fail_if_no_test_linked = fail_if_no_test_linked, **this_backend_kwargs ) - - test_names.append(test_name) + if ((backend in NVIDIA_GPU_BACKENDS and is_cuda_configured()) or + (backend in AMD_GPU_DEFAULT_BACKENDS and is_rocm_configured())): + test_names.append(test_name) # Notably, a test_suite with `tests = []` is not empty: # https://bazel.build/reference/be/general#test_suite_args and the default @@ -398,7 +404,6 @@ def xla_test( # --build_tag_filters (see above). Therefore we don't want to fail # if no test case is linked in. 
fail_if_no_test_linked = False, - **kwargs ) def xla_test_library( diff --git a/third_party/xla/xla/tsl/platform/default/cuda_build_defs.bzl b/third_party/xla/xla/tsl/platform/default/cuda_build_defs.bzl index e48575673b29b2..4490404b09fa74 100644 --- a/third_party/xla/xla/tsl/platform/default/cuda_build_defs.bzl +++ b/third_party/xla/xla/tsl/platform/default/cuda_build_defs.bzl @@ -7,6 +7,7 @@ load( "@local_config_cuda//cuda:build_defs.bzl", _if_cuda_is_configured = "if_cuda_is_configured", _if_cuda_newer_than = "if_cuda_newer_than", + _is_cuda_configured = "is_cuda_configured", ) # IMPORTANT: Do not remove this load statement. We rely on that //xla/tsl doesn't exist in g3 @@ -20,6 +21,11 @@ visibility(DEFAULT_LOAD_VISIBILITY) def if_cuda_is_configured(x, no_cuda = []): return _if_cuda_is_configured(x, no_cuda) +# We perform this indirection so that the copybara tool can distinguish this +# macro from others provided by the same file. +def is_cuda_configured(): + return _is_cuda_configured() + # Constructs rpath linker flags for use with nvidia wheel-packaged libs # avaialble from PyPI. Two paths are needed because symbols are used from # both the root of the TensorFlow installation directory as well as from From 2df2b8059717e39fdf19b49908a083bf13b46d48 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Wed, 9 Apr 2025 02:57:04 -0700 Subject: [PATCH 0428/1324] Add contracting split (split-K) selection logic to dynamic search space First step towards the full search space. We figure out the contracting split based on the problem size, and a few other properties of the search space (still hardcoded for now). We also have an option for forcing a specific split, which is helpful for both disabling autotuning that parameter, and will facilitate support for analytically setting it in the future. PiperOrigin-RevId: 745505400 --- .../xla/xla/service/gpu/autotuning/BUILD | 7 + .../gpu/autotuning/dot_search_space.cc | 138 ++++++++++++++++-- .../service/gpu/autotuning/dot_search_space.h | 38 ++++- .../gpu/autotuning/dot_search_space_test.cc | 113 +++++++++++--- .../gpu/autotuning/gemm_fusion_autotuner.cc | 7 +- 5 files changed, 270 insertions(+), 33 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 58433609cb3361..b750e922073c5f 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -264,10 +264,15 @@ cc_library( hdrs = ["dot_search_space.h"], tags = ["gpu"], deps = [ + "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service/gpu:matmul_utils", "//xla/stream_executor:device_description", + "//xla/tsl/lib/core:bits", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:protobuf", ], ) @@ -284,6 +289,8 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", ], ) diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc index 9e1510818150a4..9a223b44774273 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc @@ -15,35 +15,153 @@ limitations under the License. 
#include "xla/service/gpu/autotuning/dot_search_space.h" +#include +#include +#include #include #include +#include "absl/log/log.h" #include "absl/strings/str_format.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/lib/core/bits.h" +#include "xla/util.h" +#include "tsl/platform/protobuf.h" namespace xla::gpu { +namespace { + +// Returns the size (in number of elements) of the subshape of `shape` defined +// by `dimensions`. +int64_t GetSizeInDimensions( + const Shape& shape, + const tsl::protobuf::RepeatedField& dimensions) { + int64_t size = 1; + for (int64_t dim : dimensions) { + size *= shape.dimensions(dim); + } + return size; +} + +// Finds the next power of two larger than or equal to x. +// +// Unlike tsl::NextPowerOfTwo, doesn't crash for 0. +int64_t NextPowerOfTwo(int64_t x) { + if (x == 0) { + return 1; + } + return tsl::NextPowerOfTwoS64(x); +} + +} // namespace TritonDotFusionSearchSpace::TritonDotFusionSearchSpace( const se::DeviceDescription& device_description, const HloDotInstruction* dot) : // Set up basic information about the hardware and the problem. - device_description_(device_description) { - // TODO: b/404470821 - Do something based on `dot`. -} + device_description_(device_description), + contracting_size_(GetSizeInDimensions( + dot->operand(0)->shape(), + dot->dot_dimension_numbers().lhs_contracting_dimensions())), + batch_size_(GetSizeInDimensions( + dot->operand(0)->shape(), + dot->dot_dimension_numbers().lhs_batch_dimensions())), + lhs_parallel_size_(ShapeUtil::ElementsIn(dot->operand(0)->shape()) / + (contracting_size_ * batch_size_)), + rhs_parallel_size_(ShapeUtil::ElementsIn(dot->operand(1)->shape()) / + (contracting_size_ * batch_size_)), + // TODO: b/404470821 - Compute these from the problem properties instead + // of hardcoding. + desired_total_warps_(2160), + max_out_tile_{64, 128}, + min_warps_per_cta_(4), + min_contracting_tile_size_(16), + max_contracting_split_(GetMaxContractingSplit(max_out_tile_)) {} -std::vector TritonDotFusionSearchSpace::GenerateConfigs() { +std::vector TritonDotFusionSearchSpace::GenerateConfigs( + std::optional force_contracting_split) { + std::vector configs; + if (force_contracting_split.has_value()) { + TritonGemmConfig config; + config.split_k = force_contracting_split.value(); + configs.push_back(config); + } else { + configs = GenerateContractingSplitFactors(); + } // TODO: b/404470821 - Implement this properly rather than hardcoding the - // config. - return {TritonGemmConfig( - /*block_m=*/64, /*block_n=*/128, /*block_k=*/64, - /*split_k=*/1, /*num_stages=*/3, /*num_warps=*/4, - /*num_ctas=*/1)}; + // config parameters. + for (auto& config : configs) { + config.block_m = 64; + config.block_n = 128; + config.block_k = 64; + config.num_stages = 3; + config.num_warps = 4; + config.num_ctas = 1; + } + return configs; } std::string TritonDotFusionSearchSpace::Serialize() { - return absl::StrFormat("TODO: b/404470821 - Implement this."); + return absl::StrFormat( + "problem_size_BxMxNxK: %dx%dx%dx%d " + "tile_range_SxMxNxK: [1-%d]x[1-%d]x[1-%d]x[%d-?] 
" + "desired_total_warps: %d warps_per_block: [%d-?]", + batch_size_, lhs_parallel_size_, rhs_parallel_size_, contracting_size_, + max_contracting_split_, max_out_tile_.lhs_dim, max_out_tile_.rhs_dim, + min_contracting_tile_size_, desired_total_warps_, min_warps_per_cta_); +} + +int64_t TritonDotFusionSearchSpace::GetNumResultTiles( + OutputTile output_tile) const { + return batch_size_ * + CeilOfRatio(lhs_parallel_size_, output_tile.lhs_dim) * + CeilOfRatio(rhs_parallel_size_, output_tile.rhs_dim); +} + +int TritonDotFusionSearchSpace::GetMaxContractingSplit( + OutputTile output_tile) const { + const int64_t desired_num_blocks = desired_total_warps_ / min_warps_per_cta_; + VLOG(5) << "Computing split_k: Considering output tile " + << output_tile.lhs_dim << "x" << output_tile.rhs_dim; + VLOG(5) << "Computing split_k: Want up to " << desired_num_blocks + << " blocks to occupy all cores."; + + const int64_t min_result_tiles = GetNumResultTiles(output_tile); + VLOG(5) << "Computing split_k: Without split_k have " << min_result_tiles + << " tiles."; + + const int64_t split_for_occupancy = + NextPowerOfTwo(CeilOfRatio(desired_num_blocks, min_result_tiles)); + VLOG(5) << "Computing split_k: Want split_k of up to " << split_for_occupancy + << " for sufficient occupancy."; + + const int64_t split_for_contracting_size = + NextPowerOfTwo(contracting_size_ / min_contracting_tile_size_); + VLOG(5) << "Computing split_k: Can't have split_k more than " + << split_for_contracting_size + << " to have sufficiently large contracting dimension."; + + const int64_t split = + std::min(split_for_occupancy, split_for_contracting_size); + VLOG(5) << "Computing split_k: max_split_k = " << split; + return split; +} + +std::vector +TritonDotFusionSearchSpace::GenerateContractingSplitFactors() { + std::vector configs; + TritonGemmConfig config; + for (int split = 1; split <= max_contracting_split_; split *= 2) { + config.split_k = split; + VLOG(10) << "Generating contracting split factors: config = " + << config.ToString(); + configs.push_back(config); + } + return configs; } } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h index 3fd976b3fd62aa..25d986c15abefa 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_AUTOTUNING_DOT_SEARCH_SPACE_H_ #define XLA_SERVICE_GPU_AUTOTUNING_DOT_SEARCH_SPACE_H_ +#include +#include #include #include @@ -39,14 +41,46 @@ class TritonDotFusionSearchSpace { const HloDotInstruction* dot); // Generates the list of promising configs in the search space for the - // autotuner to try. - std::vector GenerateConfigs(); + // autotuner to try. If `force_contracting_split` is set, the search space + // will be restricted to only include configs with the given split_k factor. + std::vector GenerateConfigs( + std::optional force_contracting_split = std::nullopt); // Serializes the search space to a human-readable string. std::string Serialize(); private: + // Groups together the tiling of the dot's output dimensions: the parallel + // dimensions of the left and right hand sides. We assume that any batch + // dimensions are tiled by a factor of 1. + struct OutputTile { + int lhs_dim; // LHS tiling (aka. block_m). + int rhs_dim; // RHS tiling (aka. block_n). 
+ }; + + // Computes the number of result tiles we would have without + // splitting the contracting dimension for a given output tile. + int64_t GetNumResultTiles(OutputTile output_tile) const; + + // Computes the maximum sensible split in the contracting dimension + // (split_k) to sufficiently occupy all available cores when using the given + // output tile. + int GetMaxContractingSplit(OutputTile output_tile) const; + + // Finds all promising values for splitting the contracting dimension to + // achieve sufficient occupancy (split_k). + std::vector GenerateContractingSplitFactors(); + se::DeviceDescription device_description_; + int64_t contracting_size_; + int64_t batch_size_; + int64_t lhs_parallel_size_; + int64_t rhs_parallel_size_; + int desired_total_warps_; + OutputTile max_out_tile_; + int min_warps_per_cta_; + int min_contracting_tile_size_; + int max_contracting_split_; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc index bd6feba7bef256..c3d9aded73a1f1 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" @@ -31,40 +33,113 @@ limitations under the License. namespace xla::gpu { namespace { +using ::testing::Eq; using ::testing::Field; using ::testing::Ge; +using ::testing::IsEmpty; +using ::testing::Le; + +template +auto BlockMIs(MatcherType matcher) { + return Field("block_m", &TritonGemmConfig::block_m, matcher); +} +template +auto BlockNIs(MatcherType matcher) { + return Field("block_n", &TritonGemmConfig::block_n, matcher); +} +template +auto BlockKIs(MatcherType matcher) { + return Field("block_k", &TritonGemmConfig::block_k, matcher); +} +template +auto SplitKIs(MatcherType matcher) { + return Field("split_k", &TritonGemmConfig::split_k, matcher); +} +template +auto NumStagesIs(MatcherType matcher) { + return Field("num_stages", &TritonGemmConfig::num_stages, matcher); +} +template +auto NumWarpsIs(MatcherType matcher) { + return Field("num_warps", &TritonGemmConfig::num_warps, matcher); +} +template +auto NumCtasIs(MatcherType matcher) { + return Field("num_ctas", &TritonGemmConfig::num_ctas, matcher); +} auto IsValidConfig() { - return AllOf(Field("block_m", &TritonGemmConfig::block_m, Ge(1)), - Field("block_n", &TritonGemmConfig::block_n, Ge(1)), - Field("block_k", &TritonGemmConfig::block_k, Ge(1)), - Field("split_k", &TritonGemmConfig::split_k, Ge(1)), - Field("num_stages", &TritonGemmConfig::num_stages, Ge(1)), - Field("num_warps", &TritonGemmConfig::num_warps, Ge(1)), - Field("num_ctas", &TritonGemmConfig::num_ctas, Ge(1))); + return AllOf(BlockMIs(Ge(1)), BlockNIs(Ge(1)), BlockKIs(Ge(1)), + SplitKIs(Ge(1)), NumStagesIs(Ge(1)), NumWarpsIs(Ge(1)), + NumCtasIs(Ge(1))); }; class DotSearchSpaceTest : public HloHardwareIndependentTestBase { protected: se::DeviceDescription device_description_{ se::DeviceDescription(se::GpuDeviceInfoProto::default_instance())}; -}; -TEST_F(DotSearchSpaceTest, ReturnsValidConfigList) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + absl::StatusOr> GetDefaultDotModule( + int 
lhs_parallel_dim = 1024, int rhs_parallel_dim = 1024, + int contracting_dim = 1024) { + constexpr const char* kModuleTextFormat = R"( ENTRY e { - p0 = f32[1024,1024] parameter(0) - p1 = f32[1024,1024] parameter(1) - ROOT r = f32[1024,1024] dot(p0, p1), + p0 = f16[%d,%d] parameter(0) + p1 = f16[%d,%d] parameter(1) + ROOT r = f16[%d,%d] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - TritonDotFusionSearchSpace search_space( - device_description_, - Cast(module->entry_computation()->root_instruction())); +})"; + return ParseAndReturnVerifiedModule(absl::StrFormat( + kModuleTextFormat, lhs_parallel_dim, contracting_dim, contracting_dim, + rhs_parallel_dim, lhs_parallel_dim, rhs_parallel_dim)); + } + + HloDotInstruction* GetDot(VerifiedHloModule* module) { + return Cast( + module->entry_computation()->root_instruction()); + } +}; + +TEST_F(DotSearchSpaceTest, ReturnsValidConfigList) { + TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space(device_description_, + GetDot(module.get())); + + EXPECT_THAT(search_space.GenerateConfigs(), + AllOf(Not(IsEmpty()), Each(IsValidConfig()))); +} + +TEST_F(DotSearchSpaceTest, HonorsForcedContractingSplit) { + TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space(device_description_, + GetDot(module.get())); + + EXPECT_THAT( + search_space.GenerateConfigs(/*force_contracting_split=*/2), + AllOf(Not(IsEmpty()), Each(IsValidConfig()), Each(SplitKIs(Eq(2))))); +} + +TEST_F(DotSearchSpaceTest, ConsidersContractingSplitForSmallOutputSize) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + GetDefaultDotModule(/*lhs_parallel_dim=*/16, + /*rhs_parallel_dim=*/16, + /*contracting_dim=*/1024)); + TritonDotFusionSearchSpace search_space(device_description_, + GetDot(module.get())); + + EXPECT_THAT(search_space.GenerateConfigs(), Contains(SplitKIs(Ge(2)))); +} + +TEST_F(DotSearchSpaceTest, LimitsContractingSplitForSmallerContractingSize) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + GetDefaultDotModule(/*lhs_parallel_dim=*/16, + /*rhs_parallel_dim=*/16, + /*contracting_dim=*/32)); + TritonDotFusionSearchSpace search_space(device_description_, + GetDot(module.get())); EXPECT_THAT(search_space.GenerateConfigs(), - AllOf(Not(::testing::IsEmpty()), Each(IsValidConfig()))); + AllOf(Not(IsEmpty()), Each(SplitKIs(Le(2))))); } } // namespace diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 0cb9682053a86b..10bb06ef5d4c17 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -887,6 +887,8 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { bool small_dot = ShapeUtil::ElementsIn(dot.operand(0)->shape()) + ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements; + bool autotune_contracting_split = + debug_options_.xla_gpu_enable_split_k_autotuning(); if (debug_options_.xla_gpu_experimental_enable_dynamic_dot_search_space()) { if (small_dot || !IsAutotuningEnabled()) { @@ -896,7 +898,8 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { &dot); VLOG(1) << "Generating configs from search space: " << search_space.Serialize(); - return search_space.GenerateConfigs(); + return search_space.GenerateConfigs( + autotune_contracting_split ? 
std::make_optional(1) : std::nullopt); } // Retrieve the minimum bit-width participating in the dot. This is needed @@ -952,7 +955,7 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { config.block_n = std::min(config.block_n, limits.block_n); config.block_k = std::min(config.block_k, limits.block_k); int max_split_k = 1; - if (debug_options_.xla_gpu_enable_split_k_autotuning()) { + if (autotune_contracting_split) { int64_t ratio = kSufficientNumberOfTiles * config.block_m * config.block_n / result_size; max_split_k = 1 << std::max(tsl::Log2Floor64(ratio), 0); From 8a4aa4ef3aed779a7ded2b389bad9e53983df62c Mon Sep 17 00:00:00 2001 From: Chase Riley Roberts Date: Wed, 9 Apr 2025 03:09:36 -0700 Subject: [PATCH 0429/1324] PR #24796: Turn on stream annotation support by default Imported from GitHub PR https://github.com/openxla/xla/pull/24796 Turn this flag on by default, with the intention of removing this flag completely. Ideally, this should not break anything, and only enable a new feature for all users. Copybara import of the project: -- 61092c1a3dda3aada9313dba95e454ee53537048 by chaser : Turn on stream annotation support by default Merging this change closes #24796 PiperOrigin-RevId: 745508930 --- third_party/xla/xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 9afc77a08a25c1..7bd5f886de4ee5 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -280,7 +280,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_p2p_max_nchannels(0); opts.set_xla_gpu_multi_streamed_windowed_einsum(true); - opts.set_xla_gpu_experimental_stream_annotation(false); + opts.set_xla_gpu_experimental_stream_annotation(true); // Minimum combined size of matrices in matrix multiplication to // be rewritten to cuBLAS or Triton kernel call. // This threshold is a conservative estimate and has been measured From 2b9a56b658b01658cab71466ca5ced9faf3d2316 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 9 Apr 2025 03:19:07 -0700 Subject: [PATCH 0430/1324] [XLA:GPU] Add initial version of the CUDA kernel for one-shot all-reduce. This is only the initial version that has a number of limitations: * Only support `float` input. * Doesn't have vectorization or loop unrolling. * Requires the caller to do synchronization. 
PiperOrigin-RevId: 745511335 --- third_party/xla/xla/service/gpu/kernels/BUILD | 59 +++++++++++ .../service/gpu/kernels/all_reduce_kernel.cc | 91 +++++++++++++++++ .../gpu/kernels/all_reduce_kernel.cu.cc | 62 ++++++++++++ .../service/gpu/kernels/all_reduce_kernel.h | 54 ++++++++++ .../gpu/kernels/all_reduce_kernel_common.h | 34 +++++++ .../gpu/kernels/all_reduce_kernel_test.cc | 98 +++++++++++++++++++ 6 files changed, 398 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cc create mode 100644 third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cu.cc create mode 100644 third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.h create mode 100644 third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_common.h create mode 100644 third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_test.cc diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 06b0504b11047c..4b58886527fbc3 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -263,6 +263,65 @@ xla_test( ], ) +cc_library( + name = "all_reduce_kernel", + srcs = ["all_reduce_kernel.cc"], + hdrs = ["all_reduce_kernel.h"], + tags = ["gpu"], + visibility = [":friends"], + deps = [ + ":all_reduce_kernel_gpu", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:typed_kernel_factory", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + ], +) + +gpu_kernel_library( + name = "all_reduce_kernel_gpu", + srcs = ["all_reduce_kernel.cu.cc"], + hdrs = ["all_reduce_kernel_common.h"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", # build_cleaner: keep + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), +) + +xla_test( + name = "all_reduce_kernel_test", + srcs = ["all_reduce_kernel_test.cc"], + backends = ["gpu"], + deps = [ + ":all_reduce_kernel", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor/gpu:gpu_init", + "//xla/stream_executor/host:host_platform", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "ragged_all_to_all_kernel", srcs = ["ragged_all_to_all_kernel.cc"], diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cc b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cc new file mode 100644 index 00000000000000..d17217e82ead88 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cc @@ -0,0 +1,91 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/all_reduce_kernel.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/service/gpu/kernels/all_reduce_kernel_common.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/typed_kernel_factory.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace xla::gpu { + +namespace { + +void* GetKernel(PrimitiveType element_type) { + switch (element_type) { + case F32: + return GetAllReduceKernel(); + default: + return nullptr; + } +} + +} // namespace + +bool IsAllReduceKernelSupported(int64_t num_outputs, + PrimitiveType element_type) { + return num_outputs <= kMaxNumAllReduceInputPtrs && + GetKernel(element_type) != nullptr; +} + +absl::Status RunAllReduceKernel( + se::Stream* stream, PrimitiveType element_type, + absl::Span input_buffers, + se::DeviceMemoryBase output_buffer, int64_t num_inputs, + int64_t num_elements) { + if (input_buffers.size() > kMaxNumAllReduceInputPtrs) { + return absl::InvalidArgumentError( + "Number of input pointers exceeds the maximum supported number of " + "input pointers."); + } + + se::StreamExecutor* executor = stream->parent(); + + // TODO(b/383125489): Fine tune the block and thread dimensions. + static constexpr size_t kBlocks = 8; + static constexpr size_t kThreads = 512; + + TF_ASSIGN_OR_RETURN( + auto kernel, + (se::TypedKernelFactory, + se::DeviceMemoryBase, int64_t, + int64_t>::Create(executor, "one_shot_all_reduce", + GetKernel(element_type)))); + + std::array input_ptrs; + absl::c_transform( + input_buffers, input_ptrs.begin(), + [](se::DeviceMemoryBase buffer) { return buffer.opaque(); }); + + return kernel.Launch(se::ThreadDim(kThreads, 1, 1), + se::BlockDim(kBlocks, 1, 1), stream, input_ptrs, + output_buffer, num_inputs, num_elements); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cu.cc new file mode 100644 index 00000000000000..f819ab286a3390 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cu.cc @@ -0,0 +1,62 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <array>
+#include <cstdint>
+
+#include "xla/service/gpu/kernels/all_reduce_kernel_common.h"
+
+namespace xla::gpu {
+namespace {
+
+template <typename T>
+__global__ void AllReduceKernel(
+    std::array<void*, kMaxNumAllReduceInputPtrs> input_ptrs,
+    T* __restrict__ output_ptr, int64_t num_inputs, int64_t num_elements) {
+  int64_t offset = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t stride = blockDim.x * gridDim.x;
+
+  for (int i = offset; i < num_elements; i += stride) {
+    T sum = 0;
+
+#pragma unroll
+    for (int j = 0; j < kMaxNumAllReduceInputPtrs; ++j) {
+      if (j >= num_inputs) break;
+
+      // TODO(b/383125489): Add vectorization.
+      T* input_ptr =
+          reinterpret_cast<  // REINTERPRET_CAST_OK=tsl::safe_reinterpret_cast
+                             // doesn't work with __restrict__.
+              T* __restrict__>(input_ptrs[j]);
+      sum += input_ptr[i];
+    }
+
+    output_ptr[i] = sum;
+  }
+}
+
+}  // namespace
+
+template <typename T>
+void* GetAllReduceKernel() {
+  return reinterpret_cast<  // REINTERPRET_CAST_OK=tsl::safe_reinterpret_cast
+                            // doesn't support this cast, but it's necessary to
+                            // conform to se::TypedKernelFactory<>::Create().
+      void*>(&AllReduceKernel<T>);
+}
+
+template void* GetAllReduceKernel<float>();
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.h b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.h
new file mode 100644
index 00000000000000..7c354752e545f9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.h
@@ -0,0 +1,54 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_H_
+#define XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_H_
+
+#include <cstdint>
+
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/types.h"  // IWYU pragma: keep
+#include "xla/xla_data.pb.h"
+
+namespace xla::gpu {
+
+// Performs element-wise addition of all input buffers and stores the result in
+// the output buffer.
+// The kernel is intended to be used for all-reduce operations in environments
+// where direct peer memory access is available. Input buffers can point to
+// memory on different devices. The caller is responsible for gathering the
+// pointers from the different devices.
+//
+// TODO(b/383125489): Add synchronization between blocks in the kernel.
+// The caller is also responsible for synchronizing streams on all
+// participating devices before and after the kernel execution.
+//
+// Input arguments:
+// - input_buffers: A list of input buffers.
+// - output_buffer: The buffer to store the result.
+// - num_inputs: The number of input buffers.
+// - num_elements: The number of elements in each buffer.
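+//
+// A minimal usage sketch (illustrative only; `peer_bufs`, `out_buf`, `stream`
+// and `n` are assumed to exist, with streams already synchronized as described
+// above):
+//
+//   std::vector<se::DeviceMemoryBase> peer_bufs = ...;  // gathered by caller
+//   TF_RETURN_IF_ERROR(RunAllReduceKernel(stream, F32, peer_bufs, out_buf,
+//                                         /*num_inputs=*/peer_bufs.size(),
+//                                         /*num_elements=*/n));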
+absl::Status RunAllReduceKernel( + se::Stream* stream, PrimitiveType element_type, + absl::Span input_buffers, + se::DeviceMemoryBase output_buffer, int64_t num_inputs, + int64_t num_elements); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_common.h b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_common.h new file mode 100644 index 00000000000000..dc199258e58dd9 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_common.h @@ -0,0 +1,34 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_COMMON_H_ +#define XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_COMMON_H_ + +#include + +namespace xla::gpu { + +// The maximum number of input pointers that can be passed to the all-reduce +// kernel. +inline constexpr int64_t kMaxNumAllReduceInputPtrs = 8; + +// Returns a pointer to the all-reduce kernel for the given element type. +// Returns nullptr if the element type is not supported. +template +void* GetAllReduceKernel(); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_COMMON_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_test.cc new file mode 100644 index 00000000000000..32e3ac51b4a558 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/kernels/all_reduce_kernel.h" + +#include +#include +#include +#include +#include + +#include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "xla/primitive_util.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_handle.h" +#include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test.h" +#include "xla/xla_data.pb.h" + +namespace xla::gpu { +namespace { + +se::StreamExecutor* GetGpuExecutor() { + auto* platform = + se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); + return platform->ExecutorForDevice(0).value(); +} + +using AllReduceKernelTest = ::testing::Test; + +TEST_F(AllReduceKernelTest, SimpleKernelTest) { + using T = float; + + auto* executor = GetGpuExecutor(); + auto stream = executor->CreateStream().value(); + + constexpr int64_t num_inputs = 2; + constexpr int64_t num_elements = 128000; + + std::vector input_buffers; + for (int64_t i = 0; i < num_inputs; ++i) { + input_buffers.emplace_back(executor, + executor->AllocateArray(num_elements)); + ASSERT_TRUE(!input_buffers[i].memory().is_null()); + } + + se::DeviceMemoryHandle output_buffer( + executor, executor->AllocateArray(num_elements)); + ASSERT_TRUE(!output_buffer.memory().is_null()); + + std::vector output_data(num_elements); + for (int i = 0; i < num_inputs; ++i) { + std::vector input_data(num_elements); + std::iota(input_data.begin(), input_data.end(), 0); + + TF_ASSERT_OK(stream->Memcpy(input_buffers[i].memory_ptr(), + input_data.data(), num_elements * sizeof(T))); + + std::transform(input_data.begin(), input_data.end(), output_data.begin(), + output_data.begin(), std::plus()); + } + + std::vector input_buffers_span; + for (auto& input_buffer : input_buffers) { + input_buffers_span.push_back(input_buffer.memory()); + } + + TF_ASSERT_OK(RunAllReduceKernel( + stream.get(), primitive_util::NativeToPrimitiveType(), + input_buffers_span, output_buffer.memory(), num_inputs, num_elements)); + + std::vector output_results(num_elements); + TF_ASSERT_OK(stream->Memcpy(output_results.data(), output_buffer.memory(), + num_elements * sizeof(T))); + + EXPECT_EQ(output_results, output_data); +} + +} // namespace +} // namespace xla::gpu From 4934f50d3cc7df2ecac9fa824e8c71c25bc41b44 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 9 Apr 2025 03:23:16 -0700 Subject: [PATCH 0431/1324] Automated Code Change PiperOrigin-RevId: 745512351 --- .../compiler/mlir/tools/kernel_gen/ir/BUILD | 38 ++--- .../mlir/tools/kernel_gen/transforms/BUILD | 11 +- tensorflow/core/ir/BUILD | 96 +++++------ tensorflow/core/ir/types/BUILD | 52 ++---- tensorflow/core/transforms/BUILD | 22 +-- tensorflow/core/transforms/remapper/BUILD | 7 +- .../backends/cpu/codegen/emitters/ir/BUILD | 54 +++---- .../cpu/codegen/emitters/transforms/BUILD | 13 +- .../backends/gpu/codegen/emitters/ir/BUILD | 90 ++++------- .../gpu/codegen/emitters/transforms/BUILD | 13 +- .../xla/backends/gpu/codegen/triton/ir/BUILD | 80 ++++----- .../gpu/codegen/triton/transforms/BUILD | 15 +- third_party/xla/xla/codegen/emitters/ir/BUILD | 60 +++---- .../xla/xla/codegen/emitters/transforms/BUILD | 13 +- .../xla/xla/hlo/translate/mhlo_to_hlo/BUILD | 2 +- third_party/xla/xla/mlir/framework/ir/BUILD | 34 +--- 
.../xla/xla/mlir/framework/transforms/BUILD | 13 +- third_party/xla/xla/mlir_hlo/BUILD | 152 +++++------------- 18 files changed, 239 insertions(+), 526 deletions(-) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index 42b29d86d31e51..0c504a62de1627 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -30,24 +30,12 @@ td_library( gentbl_cc_library( name = "tf_framework_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_framework_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_framework_ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "tf_framework_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "tf_framework_dialect.cc.inc", - ), - ], + tbl_outs = { + "tf_framework_ops.h.inc": ["-gen-op-decls"], + "tf_framework_ops.cc.inc": ["-gen-op-defs"], + "tf_framework_dialect.h.inc": ["-gen-dialect-decls"], + "tf_framework_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_framework_ops.td", deps = [":td_files"], @@ -56,16 +44,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_status_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-enum-decls"], - "tf_status.h.inc", - ), - ( - ["-gen-enum-defs"], - "tf_status.cc.inc", - ), - ], + tbl_outs = { + "tf_status.h.inc": ["-gen-enum-decls"], + "tf_status.cc.inc": ["-gen-enum-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_status.td", deps = [":td_files"], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 331d1aa9c28aa8..262f9fc56d78f2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -85,13 +85,10 @@ cc_library( gentbl_cc_library( name = "kernel_gen_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [( - [ - "-gen-pass-decls", - "-name=KernelGen", - ], - "kernel_gen_passes.h.inc", - )], + tbl_outs = {"kernel_gen_passes.h.inc": [ + "-gen-pass-decls", + "-name=KernelGen", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/core/ir/BUILD b/tensorflow/core/ir/BUILD index 6d9aee324fb32d..9a77ffcfd97283 100644 --- a/tensorflow/core/ir/BUILD +++ b/tensorflow/core/ir/BUILD @@ -23,16 +23,10 @@ td_library( gentbl_cc_library( name = "InterfacesIncGen", - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "interfaces.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "interfaces.cc.inc", - ), - ], + tbl_outs = { + "interfaces.h.inc": ["-gen-op-interface-decls"], + "interfaces.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "interfaces.td", deps = [ @@ -59,56 +53,38 @@ td_library( gentbl_cc_library( name = "DialectIncGen", - tbl_outs = [ - ( - [ - "-gen-op-decls", - "-dialect", - "tfg", - ], - "ops.h.inc", - ), - ( - [ - "-gen-op-defs", - "-dialect", - "tfg", - ], - "ops.cc.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect", - "tfg", - ], - "dialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect", - "tfg", - ], - "dialect.cc.inc", - ), - ( - [ - "-gen-attrdef-decls", - "-attrdefs-dialect", - "tfg", - ], - "attributes.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "-attrdefs-dialect", - "tfg", - ], - 
"attributes.cc.inc", - ), - ], + tbl_outs = { + "ops.h.inc": [ + "-gen-op-decls", + "-dialect", + "tfg", + ], + "ops.cc.inc": [ + "-gen-op-defs", + "-dialect", + "tfg", + ], + "dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect", + "tfg", + ], + "dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect", + "tfg", + ], + "attributes.h.inc": [ + "-gen-attrdef-decls", + "-attrdefs-dialect", + "tfg", + ], + "attributes.cc.inc": [ + "-gen-attrdef-defs", + "-attrdefs-dialect", + "tfg", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ops.td", deps = [ diff --git a/tensorflow/core/ir/types/BUILD b/tensorflow/core/ir/types/BUILD index 4577967ebc7d19..73d8e61c03c6a3 100644 --- a/tensorflow/core/ir/types/BUILD +++ b/tensorflow/core/ir/types/BUILD @@ -30,16 +30,10 @@ td_library( gentbl_cc_library( name = "DialectIncGen", - tbl_outs = [ - ( - ["-gen-dialect-decls"], - "dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "dialect.cpp.inc", - ), - ], + tbl_outs = { + "dialect.h.inc": ["-gen-dialect-decls"], + "dialect.cpp.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "dialect.td", deps = [ @@ -49,24 +43,12 @@ gentbl_cc_library( gentbl_cc_library( name = "AttributesIncGen", - tbl_outs = [ - ( - ["-gen-attrdef-decls"], - "attributes.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "attributes.cc.inc", - ), - ( - ["-gen-enum-decls"], - "attributes_enum.h.inc", - ), - ( - ["-gen-enum-defs"], - "attributes_enum.cc.inc", - ), - ], + tbl_outs = { + "attributes.h.inc": ["-gen-attrdef-decls"], + "attributes.cc.inc": ["-gen-attrdef-defs"], + "attributes_enum.h.inc": ["-gen-enum-decls"], + "attributes_enum.cc.inc": ["-gen-enum-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "attributes.td", deps = [":DialectTdFiles"], @@ -74,16 +56,10 @@ gentbl_cc_library( gentbl_cc_library( name = "TypesIncGen", - tbl_outs = [ - ( - ["-gen-typedef-decls"], - "types.h.inc", - ), - ( - ["-gen-typedef-defs"], - "types.cc.inc", - ), - ], + tbl_outs = { + "types.h.inc": ["-gen-typedef-decls"], + "types.cc.inc": ["-gen-typedef-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "types.td", deps = [":DialectTdFiles"], diff --git a/tensorflow/core/transforms/BUILD b/tensorflow/core/transforms/BUILD index 9f887a843ba4be..5ebaa61c329d3e 100644 --- a/tensorflow/core/transforms/BUILD +++ b/tensorflow/core/transforms/BUILD @@ -15,16 +15,11 @@ package( gentbl_cc_library( name = "PassIncGen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "--name", - "TFGraph", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "--name", + "TFGraph", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = [ @@ -69,12 +64,7 @@ cc_library( gentbl_cc_library( name = "PDLLUtilsIncGen", - tbl_outs = [ - ( - ["-x=cpp"], - "utils/pdll/PDLLUtils.h.inc", - ), - ], + tbl_outs = {"utils/pdll/PDLLUtils.h.inc": ["-x=cpp"]}, tblgen = "@llvm-project//mlir:mlir-pdll", td_file = "utils/pdll/utils.pdll", deps = [ diff --git a/tensorflow/core/transforms/remapper/BUILD b/tensorflow/core/transforms/remapper/BUILD index 1b0c4ca4ded504..455abdaf32d851 100644 --- a/tensorflow/core/transforms/remapper/BUILD +++ b/tensorflow/core/transforms/remapper/BUILD @@ -13,12 +13,7 @@ package( gentbl_cc_library( name = "MklPDLLPatternsIncGen", - tbl_outs = [ - ( - ["-x=cpp"], - "pdll/MklPDLLPatterns.h.inc", - ), - ], + tbl_outs = {"pdll/MklPDLLPatterns.h.inc": ["-x=cpp"]}, tblgen = "@llvm-project//mlir:mlir-pdll", td_file = "pdll/mkl_patterns.pdll", 
deps = [ diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/ir/BUILD b/third_party/xla/xla/backends/cpu/codegen/emitters/ir/BUILD index 71e5d5af522bb6..7d223e2e09168d 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/ir/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/ir/BUILD @@ -30,16 +30,10 @@ gentbl_cc_library( name = "xla_cpu_dialect_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-dialect-decls"], - "xla_cpu_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "xla_cpu_dialect.cc.inc", - ), - ], + tbl_outs = { + "xla_cpu_dialect.h.inc": ["-gen-dialect-decls"], + "xla_cpu_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_cpu_dialect.td", deps = [":xla_cpu_td_files"], @@ -49,22 +43,16 @@ gentbl_cc_library( name = "xla_cpu_types_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-typedef-decls", - "-typedefs-dialect=xla_cpu", - ], - "xla_cpu_types.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "-typedefs-dialect=xla_cpu", - ], - "xla_cpu_types.cc.inc", - ), - ], + tbl_outs = { + "xla_cpu_types.h.inc": [ + "-gen-typedef-decls", + "-typedefs-dialect=xla_cpu", + ], + "xla_cpu_types.cc.inc": [ + "-gen-typedef-defs", + "-typedefs-dialect=xla_cpu", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_cpu_types.td", deps = [":xla_cpu_td_files"], @@ -74,16 +62,10 @@ gentbl_cc_library( name = "xla_cpu_ops_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "xla_cpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "xla_cpu_ops.cc.inc", - ), - ], + tbl_outs = { + "xla_cpu_ops.h.inc": ["-gen-op-decls"], + "xla_cpu_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_cpu_ops.td", deps = [":xla_cpu_td_files"], diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/transforms/BUILD b/third_party/xla/xla/backends/cpu/codegen/emitters/transforms/BUILD index abd71c599a8571..8d1db1679a840e 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/transforms/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/transforms/BUILD @@ -18,15 +18,10 @@ package_group( gentbl_cc_library( name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=XlaCpuTransforms", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=XlaCpuTransforms", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", visibility = ["//visibility:private"], diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD index 25b835dd9e52b7..bab734000f6404 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD @@ -35,16 +35,10 @@ gentbl_cc_library( name = "xla_gpu_dialect_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-dialect-decls"], - "xla_gpu_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "xla_gpu_dialect.cc.inc", - ), - ], + tbl_outs = { + "xla_gpu_dialect.h.inc": ["-gen-dialect-decls"], + "xla_gpu_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_gpu_dialect.td", deps = 
[":xla_gpu_td_files"], @@ -54,16 +48,10 @@ gentbl_cc_library( name = "xla_gpu_ops_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "xla_gpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "xla_gpu_ops.cc.inc", - ), - ], + tbl_outs = { + "xla_gpu_ops.h.inc": ["-gen-op-decls"], + "xla_gpu_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_gpu_ops.td", deps = [":xla_gpu_td_files"], @@ -73,30 +61,18 @@ gentbl_cc_library( name = "xla_gpu_attrs_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-enum-decls"], - "xla_gpu_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "xla_gpu_enums.cc.inc", - ), - ( - [ - "-gen-attrdef-decls", - "-attrdefs-dialect=xla_gpu", - ], - "xla_gpu_attrs.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "-attrdefs-dialect=xla_gpu", - ], - "xla_gpu_attrs.cc.inc", - ), - ], + tbl_outs = { + "xla_gpu_enums.h.inc": ["-gen-enum-decls"], + "xla_gpu_enums.cc.inc": ["-gen-enum-defs"], + "xla_gpu_attrs.h.inc": [ + "-gen-attrdef-decls", + "-attrdefs-dialect=xla_gpu", + ], + "xla_gpu_attrs.cc.inc": [ + "-gen-attrdef-defs", + "-attrdefs-dialect=xla_gpu", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_gpu_attrs.td", deps = [":xla_gpu_td_files"], @@ -106,22 +82,16 @@ gentbl_cc_library( name = "xla_gpu_types_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-typedef-decls", - "-typedefs-dialect=xla_gpu", - ], - "xla_gpu_types.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "-typedefs-dialect=xla_gpu", - ], - "xla_gpu_types.cc.inc", - ), - ], + tbl_outs = { + "xla_gpu_types.h.inc": [ + "-gen-typedef-decls", + "-typedefs-dialect=xla_gpu", + ], + "xla_gpu_types.cc.inc": [ + "-gen-typedef-defs", + "-typedefs-dialect=xla_gpu", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_gpu_types.td", deps = [":xla_gpu_td_files"], diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD index ef97d62edccc0c..6ba7c4508eb79f 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD @@ -20,15 +20,10 @@ package_group( gentbl_cc_library( name = "passes_inc_gen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=GpuFusionTransforms", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=GpuFusionTransforms", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", visibility = ["//visibility:private"], diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/ir/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/ir/BUILD index 7f9a41931fbc59..91261cfa3789c7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/ir/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/ir/BUILD @@ -27,16 +27,10 @@ td_library( gentbl_cc_library( name = "triton_xla_dialect_inc_gen", strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-dialect-decls"], - "triton_xla_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "triton_xla_dialect.cc.inc", - ), - ], + tbl_outs = { + "triton_xla_dialect.h.inc": ["-gen-dialect-decls"], + "triton_xla_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "triton_xla_dialect.td", deps = 
[":triton_xla_td_files"], @@ -45,16 +39,10 @@ gentbl_cc_library( gentbl_cc_library( name = "triton_xla_ops_inc_gen", strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "triton_xla_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "triton_xla_ops.cc.inc", - ), - ], + tbl_outs = { + "triton_xla_ops.h.inc": ["-gen-op-decls"], + "triton_xla_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "triton_xla_ops.td", deps = [ @@ -69,22 +57,16 @@ gentbl_cc_library( gentbl_cc_library( name = "triton_xla_types_inc_gen", strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-typedef-decls", - "-typedefs-dialect=triton_xla", - ], - "triton_xla_types.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "-typedefs-dialect=triton_xla", - ], - "triton_xla_types.cc.inc", - ), - ], + tbl_outs = { + "triton_xla_types.h.inc": [ + "-gen-typedef-decls", + "-typedefs-dialect=triton_xla", + ], + "triton_xla_types.cc.inc": [ + "-gen-typedef-defs", + "-typedefs-dialect=triton_xla", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "triton_xla_types.td", deps = [":triton_xla_td_files"], @@ -93,22 +75,16 @@ gentbl_cc_library( gentbl_cc_library( name = "triton_xla_attrs_inc_gen", strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-attrdef-decls", - "-attrdefs-dialect=triton_xla", - ], - "triton_xla_attrs.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "-attrdefs-dialect=triton_xla", - ], - "triton_xla_attrs.cc.inc", - ), - ], + tbl_outs = { + "triton_xla_attrs.h.inc": [ + "-gen-attrdef-decls", + "-attrdefs-dialect=triton_xla", + ], + "triton_xla_attrs.cc.inc": [ + "-gen-attrdef-defs", + "-attrdefs-dialect=triton_xla", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "triton_xla_attrs.td", deps = [ diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD index e992f71db66711..3a9da3b0c96c68 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD @@ -15,16 +15,11 @@ package_group( gentbl_cc_library( name = "passes_inc_gen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TritonXlaTransforms", - "-attrdefs-dialect=triton_xla", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=TritonXlaTransforms", + "-attrdefs-dialect=triton_xla", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", visibility = ["//visibility:private"], diff --git a/third_party/xla/xla/codegen/emitters/ir/BUILD b/third_party/xla/xla/codegen/emitters/ir/BUILD index 14bd888e53aa27..11da591197ee32 100644 --- a/third_party/xla/xla/codegen/emitters/ir/BUILD +++ b/third_party/xla/xla/codegen/emitters/ir/BUILD @@ -35,16 +35,10 @@ gentbl_cc_library( name = "xla_dialect_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-dialect-decls"], - "xla_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "xla_dialect.cc.inc", - ), - ], + tbl_outs = { + "xla_dialect.h.inc": ["-gen-dialect-decls"], + "xla_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_dialect.td", deps = [":xla_td_files"], @@ -54,16 +48,10 @@ gentbl_cc_library( name = "xla_ops_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "xla_ops.h.inc", - ), - ( - ["-gen-op-defs"], - 
"xla_ops.cc.inc", - ), - ], + tbl_outs = { + "xla_ops.h.inc": ["-gen-op-decls"], + "xla_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_ops.td", deps = [":xla_td_files"], @@ -73,28 +61,16 @@ gentbl_cc_library( name = "xla_attrs_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-enum-decls"], - "xla_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "xla_enums.cc.inc", - ), - ( - [ - "-gen-attrdef-decls", - ], - "xla_attrs.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - ], - "xla_attrs.cc.inc", - ), - ], + tbl_outs = { + "xla_enums.h.inc": ["-gen-enum-decls"], + "xla_enums.cc.inc": ["-gen-enum-defs"], + "xla_attrs.h.inc": [ + "-gen-attrdef-decls", + ], + "xla_attrs.cc.inc": [ + "-gen-attrdef-defs", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_attrs.td", deps = [":xla_td_files"], diff --git a/third_party/xla/xla/codegen/emitters/transforms/BUILD b/third_party/xla/xla/codegen/emitters/transforms/BUILD index 131dcf441126b2..d9c2424c708a9d 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/BUILD +++ b/third_party/xla/xla/codegen/emitters/transforms/BUILD @@ -32,15 +32,10 @@ cc_library( gentbl_cc_library( name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=Transforms", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=Transforms", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", visibility = ["//visibility:private"], diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD index 7ad5ef164428bc..5c5427a255ed28 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD @@ -206,7 +206,7 @@ cc_binary( gentbl_cc_library( name = "hlo_op_writer_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [([], "hlo_op_writer.inc")], + tbl_outs = {"hlo_op_writer.inc": []}, tblgen = ":gen_hlo_op_writer", td_file = "gen_hlo_op_writer.td", deps = [ diff --git a/third_party/xla/xla/mlir/framework/ir/BUILD b/third_party/xla/xla/mlir/framework/ir/BUILD index 05048b9d618add..e0feb5c25e583c 100644 --- a/third_party/xla/xla/mlir/framework/ir/BUILD +++ b/third_party/xla/xla/mlir/framework/ir/BUILD @@ -25,32 +25,14 @@ td_library( gentbl_cc_library( name = "xla_framework_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "xla_framework.h.inc", - ), - ( - ["-gen-op-defs"], - "xla_framework.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "xla_framework_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "xla_framework_dialect.cc.inc", - ), - ( - ["-gen-typedef-decls"], - "xla_framework_types.h.inc", - ), - ( - ["-gen-typedef-defs"], - "xla_framework_types.cc.inc", - ), - ], + tbl_outs = { + "xla_framework.h.inc": ["-gen-op-decls"], + "xla_framework.cc.inc": ["-gen-op-defs"], + "xla_framework_dialect.h.inc": ["-gen-dialect-decls"], + "xla_framework_dialect.cc.inc": ["-gen-dialect-defs"], + "xla_framework_types.h.inc": ["-gen-typedef-decls"], + "xla_framework_types.cc.inc": ["-gen-typedef-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_framework_ops.td", deps = [":td_files"], diff --git a/third_party/xla/xla/mlir/framework/transforms/BUILD b/third_party/xla/xla/mlir/framework/transforms/BUILD index f52120f9dc0eb9..b64af611274f66 100644 
--- a/third_party/xla/xla/mlir/framework/transforms/BUILD +++ b/third_party/xla/xla/mlir/framework/transforms/BUILD @@ -12,15 +12,10 @@ package( gentbl_cc_library( name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=XlaFramework", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=XlaFramework", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = [ diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 9caf96a8b72f2b..e2efc83b32793b 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -48,15 +48,10 @@ gentbl_cc_library( name = "mhlo_pass_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=AllMhlo", - ], - "mhlo/transforms/mhlo_passes.h.inc", - ), - ], + tbl_outs = {"mhlo/transforms/mhlo_passes.h.inc": [ + "-gen-pass-decls", + "-name=AllMhlo", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mhlo/transforms/mhlo_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], @@ -66,16 +61,10 @@ gentbl_cc_library( name = "hlo_ops_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "mhlo/IR/hlo_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "mhlo/IR/hlo_ops.cc.inc", - ), - ], + tbl_outs = { + "mhlo/IR/hlo_ops.h.inc": ["-gen-op-decls"], + "mhlo/IR/hlo_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mhlo/IR/hlo_ops.td", deps = [":hlo_ops_td_files"], @@ -85,16 +74,10 @@ gentbl_cc_library( name = "hlo_ops_attrs_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-attrdef-decls"], - "mhlo/IR/hlo_ops_attrs.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "mhlo/IR/hlo_ops_attrs.cc.inc", - ), - ], + tbl_outs = { + "mhlo/IR/hlo_ops_attrs.h.inc": ["-gen-attrdef-decls"], + "mhlo/IR/hlo_ops_attrs.cc.inc": ["-gen-attrdef-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mhlo/IR/hlo_ops.td", deps = [":hlo_ops_td_files"], @@ -104,16 +87,10 @@ gentbl_cc_library( name = "hlo_ops_enums_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-enum-decls"], - "mhlo/IR/hlo_ops_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "mhlo/IR/hlo_ops_enums.cc.inc", - ), - ], + tbl_outs = { + "mhlo/IR/hlo_ops_enums.h.inc": ["-gen-enum-decls"], + "mhlo/IR/hlo_ops_enums.cc.inc": ["-gen-enum-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mhlo/IR/hlo_ops.td", deps = [":hlo_ops_td_files"], @@ -123,22 +100,16 @@ gentbl_cc_library( name = "hlo_ops_typedefs_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-typedef-decls", - "--typedefs-dialect=mhlo", - ], - "mhlo/IR/hlo_ops_typedefs.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "--typedefs-dialect=mhlo", - ], - "mhlo/IR/hlo_ops_typedefs.cc.inc", - ), - ], + tbl_outs = { + "mhlo/IR/hlo_ops_typedefs.h.inc": [ + "-gen-typedef-decls", + "--typedefs-dialect=mhlo", + ], + "mhlo/IR/hlo_ops_typedefs.cc.inc": [ + "-gen-typedef-defs", + "--typedefs-dialect=mhlo", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mhlo/IR/hlo_ops.td", deps = [":hlo_ops_td_files"], @@ -148,12 +119,7 @@ gentbl_cc_library( name = 
"hlo_ops_pattern_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/IR/", - tbl_outs = [ - ( - ["-gen-rewriters"], - "mhlo/IR/hlo_patterns.cc.inc", - ), - ], + tbl_outs = {"mhlo/IR/hlo_patterns.cc.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mhlo/IR/hlo_patterns.td", deps = [ @@ -197,12 +163,7 @@ gentbl_cc_library( name = "canonicalize_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-rewriters"], - "mhlo/IR/mhlo_canonicalize.inc", - ), - ], + tbl_outs = {"mhlo/IR/mhlo_canonicalize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mhlo/IR/mhlo_canonicalize.td", deps = [":hlo_ops_td_files"], @@ -245,15 +206,10 @@ gentbl_cc_library( name = "deallocation_passes_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=Deallocation", - ], - "deallocation/transforms/passes.h.inc", - ), - ], + tbl_outs = {"deallocation/transforms/passes.h.inc": [ + "-gen-pass-decls", + "-name=Deallocation", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "deallocation/transforms/passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], @@ -621,12 +577,7 @@ gentbl_cc_library( name = "chlo_legalize_to_hlo_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/transforms", - tbl_outs = [ - ( - ["-gen-rewriters"], - "mhlo/transforms/chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc", - ), - ], + tbl_outs = {"mhlo/transforms/chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td", deps = [":hlo_ops_td_files"], @@ -856,15 +807,10 @@ gentbl_cc_library( name = "transforms_passes_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=LMHLOTransforms", - ], - "transforms/passes.h.inc", - ), - ], + tbl_outs = {"transforms/passes.h.inc": [ + "-gen-pass-decls", + "-name=LMHLOTransforms", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], @@ -874,15 +820,10 @@ gentbl_cc_library( name = "gpu_transforms_passes_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=LMHLOGPUTransforms", - ], - "transforms/gpu_passes.h.inc", - ), - ], + tbl_outs = {"transforms/gpu_passes.h.inc": [ + "-gen-pass-decls", + "-name=LMHLOGPUTransforms", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/gpu_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], @@ -1062,14 +1003,9 @@ gentbl_cc_library( name = "stablehlo_extension_pass_inc_gen", compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - ], - "stablehlo_ext/transforms/passes.h.inc", - ), - ], + tbl_outs = {"stablehlo_ext/transforms/passes.h.inc": [ + "-gen-pass-decls", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "stablehlo_ext/transforms/passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], From e8bbb0f7951df0253f240bbb6eec242bfd1aaeb7 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 9 Apr 2025 03:31:24 -0700 Subject: [PATCH 0432/1324] [XLA:GPU] Enable tests that were 
unblocked by fixing the issue of data type propagation in `NestGemmFusion`. [This commit](https://github.com/openxla/xla/commit/098a6e0ecce0268403cb087b7b1a7bf628e6ead6) allows turning on a couple more tests in `fusion_emitter_device_legacy_port_test.cc`. PiperOrigin-RevId: 745514508 --- .../fusion_emitter_device_legacy_port_test.cc | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index d11a868d6819bd..f07802411d6aa9 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -3116,10 +3116,7 @@ ENTRY e { /*run_hlo_passes=*/false)); } -// TODO(b/393299275): enable once `NestGemmFusion` data type propagation is -// fixed. At the moment, the data type is not propagated correctly and that -// causes a miscompile. -TEST_F(CompareTest, DISABLED_SplitKNontrivialBitcast) { +TEST_F(CompareTest, SplitKNontrivialBitcast) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } @@ -3890,9 +3887,7 @@ ENTRY e { /*run_hlo_passes=*/false)); } -// TODO(b/393299275): this test uncovers a bug in hoisting bitcasts through -// broadcasts (seems to generate a type mismatch). -TEST_F(TritonTest, DISABLED_UseTF32For8BitOrLessWithF32) { +TEST_F(TritonTest, UseTF32For8BitOrLessWithF32) { constexpr absl::string_view kHloText = R"( HloModule t @@ -3931,8 +3926,9 @@ CHECK: tt.dot CHECK: inputPrecision = tf32 )")); - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE( + RunAndCompareNoHloPasses(std::move(module_and_metadata.module), + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } // TODO(b/393299275): this test requires us to allow actual mixed type GEMMs From b875ce50f3499fd88fcd3c971595eddd8974e32c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Apr 2025 03:37:57 -0700 Subject: [PATCH 0433/1324] Automated Code Change PiperOrigin-RevId: 745516151 --- tensorflow/compiler/tf2xla/lib/scatter.cc | 5 ++--- tensorflow/compiler/tf2xla/lib/scatter.h | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 69c8a830937257..5d5ec2e22f2c3e 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -35,9 +35,8 @@ limitations under the License. namespace tensorflow { absl::StatusOr<xla::XlaOp> XlaScatter( - const xla::XlaOp& buffer, const xla::XlaOp& updates, - const xla::XlaOp& indices, bool indices_are_vectors, - bool indices_are_sorted, + const xla::XlaOp buffer, const xla::XlaOp updates, const xla::XlaOp indices, + bool indices_are_vectors, bool indices_are_sorted, const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>& combiner, xla::XlaBuilder* builder) { diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 90af6e63fcbf05..1428d173ea138c 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -45,9 +45,8 @@ namespace tensorflow { // the buffer using the combiner function. Otherwise, the updates replace the // existing values. The order of updates is implementation-defined. absl::StatusOr<xla::XlaOp> XlaScatter( - const xla::XlaOp& buffer, const xla::XlaOp& updates, - const xla::XlaOp& indices, bool indices_are_vectors, - bool indices_are_sorted, + xla::XlaOp buffer, xla::XlaOp updates, xla::XlaOp indices, + bool indices_are_vectors, bool indices_are_sorted, const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>& combiner, xla::XlaBuilder* builder);
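A note on the signature change above: `xla::XlaOp` is a small value-type handle, essentially a builder pointer plus an instruction id, so passing it by value costs the same as passing a const reference and reads more simply. A minimal, self-contained sketch of the idea, using hypothetical stand-in types rather than the real XLA headers:

// Hedged sketch: `Op` and `Builder` are mock stand-ins that mimic the shape
// of xla::XlaOp and xla::XlaBuilder; they are not the real types.
#include <cstdint>

struct Builder;                // stand-in for xla::XlaBuilder

struct Op {                    // stand-in for xla::XlaOp
  Builder* builder = nullptr;  // non-owning pointer back to the builder
  int64_t handle = -1;         // id of the instruction inside the builder
};

// By-value parameter: copying two machine words is as cheap as passing a
// const reference, and avoids one level of indirection at the call site.
int64_t HandleOf(Op op) { return op.handle; }

int main() {
  Op op{nullptr, 42};
  return HandleOf(op) == 42 ? 0 : 1;  // exits 0
}

This is exactly the style cleanup applied to `XlaScatter` above: the `const xla::XlaOp&` parameters become plain `xla::XlaOp`.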
From 6efefbf1aa28cf5f3e9427a4f03e30a027f3fffe Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 9 Apr 2025 03:44:02 -0700 Subject: [PATCH 0434/1324] [XLA:GPU] Enable and delete more tests in `fusion_emitter_device_legacy_port_test.cc`. All the removed tests were confirmed to work manually before being deleted. We list them here, along with a reason for deleting them: * `TritonGemmTest.BatchF32F16`: this is a trivial test of compositionality, which is tested by virtually any existing test of the generic Triton emitter, and is thus not worth a dedicated test; * `CompareTest.DifferentTilingsProduceSameResult`: this is pretty much covered by any existing test of the generic Triton emitter; * `CompareTest.F16`: this is already well covered by the parametrized `DotUnsetAlgorithmEmitterTest` in `fusion_emitter_device_test.cc`; * `CompareTest.F32`: this is already well covered by the parametrized `DotUnsetAlgorithmEmitterTest` in `fusion_emitter_device_test.cc`; * `CompareTest.BF16TransposedLHS`: this is already well covered by a combination of `TritonEmitterTest.DotWithMajorLhsContractingDimIsEmittedCorrectly` and `DotUnsetAlgorithmEmitterTest`; * `CompareTest.F16TransposedRHS`: this is already well covered by a combination of `TritonEmitterTest.DotWithMinorRhsContractingDimIsEmittedCorrectly` and `DotUnsetAlgorithmEmitterTest`; * `CompareTest.F32TransposedBoth`: same as `CompareTest.BF16TransposedLHS` and `CompareTest.F16TransposedRHS`; * `CompareTest.S8BF16`: this is already well covered by the parametrized `DotUnsetAlgorithmEmitterTest`. At the same time, we uncover that `CompareTest.SplitK` seems to be failing due to a miscompile. This issue remains to be investigated. PiperOrigin-RevId: 745517765 --- .../fusion_emitter_device_legacy_port_test.cc | 430 ++---------------- 1 file changed, 47 insertions(+), 383 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index f07802411d6aa9..b37d2e0a7658f0 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -783,7 +783,9 @@ ENTRY r { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_DoNotFuseSplitRhsContractingTranspose) { +// TODO(b/393299275): this is a pure test of fusion logic. It should be moved to +// a separate, fusion-specific, deviceless test.
+TEST_F(TritonGemmTest, DoNotFuseSplitRhsContractingTranspose) { constexpr absl::string_view kHloText = R"( HloModule t @@ -802,13 +804,13 @@ ENTRY e { ; CHECK: transpose ; CHECK: fusion ; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: "block_m": +; CHECK-SAME: "__triton_nested_gemm_fusion" )"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_DoNotFuseSplitLhsContractingTranspose) { +// TODO(b/393299275): this is a pure test of fusion logic. It should be moved to +// a separate, fusion-specific, deviceless test. +TEST_F(TritonGemmTest, DoNotFuseSplitLhsContractingTranspose) { constexpr absl::string_view kHloText = R"( HloModule t @@ -818,7 +820,7 @@ ENTRY e { p0tr = f16[16,75]{1,0} reshape(p0t) p1 = s8[128,75]{1,0} parameter(1) cp1 = f16[128,75]{1,0} convert(p1) - ROOT dot.126 = f16[16,128]{1,0} dot(p0tr, cp1), + ROOT dot = f16[16,128]{1,0} dot(p0tr, cp1), lhs_contracting_dims={1}, rhs_contracting_dims={1} })"; @@ -827,38 +829,14 @@ ENTRY e { ; CHECK: transpose ; CHECK: fusion ; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: "block_m": -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_BatchF32F16) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - x = f32[5,2,3] parameter(0) - y = f16[5,3,4] parameter(1) - cy = f32[5,3,4] convert(y) - ROOT _ = f32[5,2,4] dot(x, cy), - lhs_contracting_dims={2}, rhs_contracting_dims={1}, - lhs_batch_dims={0}, rhs_batch_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK-NEXT: parameter -; CHECK-NEXT: parameter -; CHECK-NEXT: fusion -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: "block_m": +; CHECK-SAME: "__triton_nested_gemm_fusion" )"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); } -TEST_F(TritonGemmTest, DISABLED_NonMajorMostInputBatchWorksCorrectly) { +// TODO(b/393299275): this test should be rewritten to start from +// post-optimization HLO. (Though I'm not entirely sure it's even worth keeping +// it.) +TEST_F(TritonGemmTest, NonMajorMostInputBatchWorksCorrectly) { constexpr absl::string_view kHloText = R"( HloModule t @@ -877,39 +855,15 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion ; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: "block_m": +; CHECK-SAME: "__triton_nested_gemm_fusion" )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_BatchTransposeF32F16) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - x = f32[5,3,2] parameter(0) - y = f16[5,3,4] parameter(1) - cy = f32[5,3,4] convert(y) - x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1} - ROOT _ = f32[5,2,4] dot(x_transposed, cy), - lhs_contracting_dims={2}, rhs_contracting_dims={1}, - lhs_batch_dims={0}, rhs_batch_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK-NEXT: parameter -; CHECK-NEXT: parameter -; CHECK-NEXT: fusion -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: "block_m": -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmTest, DISABLED_DoNotFuseArbitraryReshape) { +// TODO(b/393299275): this is a pure test of fusion logic. It should be moved to +// a separate, fusion-specific, deviceless test. 
+TEST_F(TritonGemmTest, DoNotFuseArbitraryReshape) { constexpr absl::string_view kHloText = R"( HloModule m @@ -918,7 +872,7 @@ ENTRY e { p0c = f32[5,2,3] convert(p0) p1 = f32[20,3] parameter(1) p1r = f32[5,3,4] reshape(p1) - ROOT dot.5 = f32[5,2,4] dot(p0c, p1r), + ROOT dot = f32[5,2,4] dot(p0c, p1r), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; @@ -928,10 +882,8 @@ ENTRY e { ; CHECK: f32[5,3,4]{2,1,0} bitcast ; CHECK: fusion ; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: "block_m": +; CHECK-SAME: "__triton_nested_gemm_fusion" )"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); } TEST_F(TritonGemmTest, MultipleBatchRequireSeparateTranspose) { @@ -2547,134 +2499,6 @@ ENTRY e { // into Triton fusions. using CompareTest = TritonGemmTest; -TEST_F(CompareTest, DISABLED_DifferentTilingsProduceSameResult) { - constexpr absl::string_view hlo_text_ref = R"( -HloModule t - -triton_dot { - p0 = s8[101,202] parameter(0) - p0c = f32[101,202] convert(p0) - p1 = f32[202,303] parameter(1) - ROOT dot = f32[101,303] dot(p0c, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = s8[101,202]{1,0} parameter(0) - p1 = f32[202,303]{1,0} parameter(1) - ROOT _ = f32[101,303] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":16,"block_n":64,"block_k":32, - "split_k":1,"num_stages":3,"num_warps":8, - "num_ctas":1}}} -})"; - - constexpr absl::string_view hlo_text_triton = R"( -HloModule t - -triton_dot { - p0 = s8[101,202] parameter(0) - p0c = f32[101,202] convert(p0) - p1 = f32[202,303] parameter(1) - ROOT dot = f32[101,303] dot(p0c, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = s8[101,202]{1,0} parameter(0) - p1 = f32[202,303]{1,0} parameter(1) - ROOT _ = f32[101,303] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":32,"block_n":128,"block_k":32, - "split_k":1,"num_stages":2,"num_warps":4, - "num_ctas":1}}} -})"; - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, - ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, - /*run_hlo_passes=*/false)); -} - -TEST_F(CompareTest, DISABLED_F16) { - constexpr absl::string_view hlo_text_ref = R"( -HloModule r - -ENTRY e { - arg0 = f16[5,7] parameter(0) - arg1 = f16[7,33] parameter(1) - gemm = (f16[5,33], s8[0]{0}) custom-call(arg0, arg1), - custom_call_target="__cublas$gemm", - backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} - ROOT get-tuple-element = f16[5,33]{1,0} get-tuple-element((f16[5,33]{1,0}, s8[0]{0}) gemm), index=0 -} -)"; - - constexpr absl::string_view hlo_text_triton = R"( -HloModule t - -triton_dot { - p0 = f16[5,7] parameter(0) - p1 = f16[7,33] parameter(1) - ROOT dot = f16[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f16[5,7]{1,0} parameter(0) - p1 = f16[7,33]{1,0} parameter(1) - ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, - 
"split_k":1,"num_stages":1,"num_warps":1, - "num_ctas":1}}} -} -)"; - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, - ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, - /*run_hlo_passes=*/false)); -} - -TEST_F(CompareTest, DISABLED_F32) { - constexpr absl::string_view hlo_text_ref = R"( -HloModule r - -ENTRY e { - arg0 = f32[5,7] parameter(0) - arg1 = f32[7,33] parameter(1) - gemm = (f32[5,33], s8[0]{0}) custom-call(arg0, arg1), - custom_call_target="__cublas$gemm", - backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} - ROOT get-tuple-element = f32[5,33]{1,0} get-tuple-element((f32[5,33]{1,0}, s8[0]{0}) gemm), index=0 -} -)"; - - constexpr absl::string_view hlo_text_triton = R"( -HloModule t - -triton_dot { - p0 = f32[5,7] parameter(0) - p1 = f32[7,33] parameter(1) - ROOT dot = f32[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f32[5,7]{1,0} parameter(0) - p1 = f32[7,33]{1,0} parameter(1) - ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, - "split_k":1,"num_stages":1,"num_warps":1, - "num_ctas":1}}} -} -)"; - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, - ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}, - /*run_hlo_passes=*/false)); -} TEST_F(CompareTest, DISABLED_F32WithTrivialNonContractingDimension) { constexpr absl::string_view hlo_text_ref = R"( @@ -2716,49 +2540,6 @@ ENTRY e { /*run_hlo_passes=*/false)); } -TEST_F(CompareTest, DISABLED_BF16TransposedLHS) { - if (!SupportsBF16(GpuComputeCapability())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view hlo_text_ref = R"( -HloModule r - -ENTRY e { - arg0 = bf16[512,16]{1,0} parameter(0) - arg1 = bf16[512,256]{1,0} parameter(1) - gemm = (bf16[16,256]{1,0}, s8[0]{0}) custom-call(arg0, arg1), - custom_call_target="__cublas$gemm", - backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} - ROOT get-tuple-element = bf16[16,256]{1,0} get-tuple-element((bf16[16,256]{1,0}, s8[0]{0}) gemm), index=0 -} -)"; - - constexpr absl::string_view hlo_text_triton = R"( -HloModule t - -triton_dot { - arg0 = bf16[512,16]{1,0} parameter(0) - arg1 = bf16[512,256]{1,0} parameter(1) - ROOT dot = bf16[16,256]{1,0} dot(arg0, arg1), - lhs_contracting_dims={0}, rhs_contracting_dims={0} -} - -ENTRY e { - arg0 = bf16[512,16]{1,0} parameter(0) - arg1 = bf16[512,256]{1,0} parameter(1) - ROOT _ = bf16[16,256]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":128,"block_n":32,"block_k":16, - "split_k":1,"num_stages":2,"num_warps":4, - "num_ctas":1}}} -} -)"; - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, - ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}, - /*run_hlo_passes=*/false)); -} - TEST_F(CompareTest, DISABLED_UsingOptinSharedMemoryOnAmpereProducesSameResult) { if (std::holds_alternative( 
GpuComputeCapability())) { @@ -2845,142 +2626,14 @@ ENTRY e { /*run_hlo_passes=*/false)); } -TEST_F(CompareTest, DISABLED_F16TransposedRHS) { - constexpr absl::string_view hlo_text_ref = R"( -HloModule r - -ENTRY e { - arg0 = f16[128,32]{1,0} parameter(0) - arg1 = f16[64,32]{1,0} parameter(1) - gemm = (f16[128,64]{1,0}, s8[0]{0}) custom-call(arg0, arg1), - custom_call_target="__cublas$gemm", - backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} - ROOT get-tuple-element = f16[128,64]{1,0} get-tuple-element((f16[128,64]{1,0}, s8[0]{0}) gemm), index=0 -} -)"; - - constexpr absl::string_view hlo_text_triton = R"( -HloModule t - -triton_dot { - arg0 = f16[128,32]{1,0} parameter(0) - arg1 = f16[64,32]{1,0} parameter(1) - ROOT dot = f16[128,64]{1,0} dot(arg0, arg1), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -} - -ENTRY e { - arg0 = f16[128,32]{1,0} parameter(0) - arg1 = f16[64,32]{1,0} parameter(1) - ROOT _ = f16[128,64]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":128,"block_n":32,"block_k":64, - "split_k":1,"num_stages":2,"num_warps":4, - "num_ctas":1}}} -} -)"; - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, - ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}, - /*run_hlo_passes=*/false)); -} - -TEST_F(CompareTest, DISABLED_F32TransposedBoth) { - constexpr absl::string_view hlo_text_ref = R"( -HloModule r - -ENTRY e { - arg0 = f32[64,128]{1,0} parameter(0) - arg1 = f32[1024,64]{1,0} parameter(1) - gemm = (f32[128,1024]{1,0}, s8[0]{0}) custom-call(arg0, arg1), - custom_call_target="__cublas$gemm", - backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} - ROOT get-tuple-element = f32[128,1024]{1,0} get-tuple-element((f32[128,1024]{1,0}, s8[0]{0}) gemm), index=0 -} -)"; - - constexpr absl::string_view hlo_text_triton = R"( -HloModule t - -triton_dot { - arg0 = f32[64,128]{1,0} parameter(0) - arg1 = f32[1024,64]{1,0} parameter(1) - ROOT dot = f32[128,1024]{1,0} dot(arg0, arg1), - lhs_contracting_dims={0}, rhs_contracting_dims={1} -} - -ENTRY e { - arg0 = f32[64,128]{1,0} parameter(0) - arg1 = f32[1024,64]{1,0} parameter(1) - ROOT _ = f32[128,1024]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":32,"block_n":32,"block_k":64, - "split_k":1,"num_stages":2,"num_warps":4, - "num_ctas":1}}} -} -)"; - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, - ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}, - /*run_hlo_passes=*/false)); -} - -TEST_F(CompareTest, DISABLED_S8BF16) { - if (!SupportsBF16(GpuComputeCapability())) { - GTEST_SKIP() << "BF16 not supported."; - } - constexpr absl::string_view hlo_text_ref = R"( -HloModule r - -fused_computation { - param_0.1 = s8[144,256]{1,0} parameter(0) - ROOT convert.4 = bf16[144,256]{1,0} convert(param_0.1) -} - -ENTRY e { - p0 = s8[144,256]{1,0} parameter(0) - fusion = bf16[144,256]{1,0} fusion(p0), 
kind=kInput, calls=fused_computation - p1 = bf16[256,122]{1,0} parameter(1) - gemm = (bf16[144,122]{1,0}, s8[0]{0}) custom-call(fusion, p1), - custom_call_target="__cublas$gemm", - backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} - ROOT get-tuple-element = bf16[144,122]{1,0} get-tuple-element((bf16[144,122]{1,0}, s8[0]{0}) gemm), index=0 -} -)"; - - constexpr absl::string_view hlo_text_triton = R"( -HloModule t - -triton_dot { - param_0.1 = s8[144,256]{1,0} parameter(0) - p0c = bf16[144,256]{1,0} convert(param_0.1) - param_1.1 = bf16[256,122]{1,0} parameter(1) - ROOT dot = bf16[144,122]{1,0} dot(p0c, param_1.1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = s8[144,256]{1,0} parameter(0) - p1 = bf16[256,122]{1,0} parameter(1) - ROOT _ = bf16[144,122]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":64,"block_n":64,"block_k":64, - "split_k":1,"num_stages":1,"num_warps":2, - "num_ctas":1}}} -} -)"; - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, - ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}, - /*run_hlo_passes=*/false)); -} - +// TODO(b/393299275): there seems to be a (not yet diagnosed) miscompile here. +// We have to investigate. TEST_F(CompareTest, DISABLED_SplitK) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } constexpr absl::string_view hlo_text_ref = R"( -HloModule t, is_scheduled=true +HloModule t triton_gemm_r { parameter_0 = s8[480,120]{1,0} parameter(0) @@ -2991,19 +2644,20 @@ triton_gemm_r { } ENTRY e { - p1 = bf16[16,120]{1,0} parameter(1) - p0 = s8[3,120,5,32]{3,2,1,0} parameter(0) - bitcast.4 = s8[480,120]{1,0} bitcast(p0) - ROOT triton_gemm_r = bf16[480,16]{1,0} fusion(bitcast.4, p1), kind=kCustom, + p0_pred = s8[480,120]{1,0} parameter(0) + p0 = s8[480,120]{1,0} convert(p0_pred) + p1_pred = pred[16,120]{1,0} parameter(1) + p1 = bf16[16,120]{1,0} convert(p1_pred) + ROOT triton_gemm_r = bf16[480,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_r, backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + triton_gemm_config: {"block_m":64,"block_n":16,"block_k":16, "split_k":1,"num_stages":4,"num_warps":4, "num_ctas":1}}} })"; constexpr absl::string_view hlo_text_splitk = R"( -HloModule t, is_scheduled=true +HloModule t triton_gemm_r { parameter_0 = s8[480,120]{1,0} parameter(0) @@ -3033,22 +2687,32 @@ fused_computation { } ENTRY e { - p1 = bf16[16,120]{1,0} parameter(1) - p0 = s8[3,120,5,32]{3,2,1,0} parameter(0) - bitcast.4 = s8[480,120]{1,0} bitcast(p0) - triton_gemm_r = bf16[4,480,16]{2,1,0} fusion(bitcast.4, p1), kind=kCustom, + p0_pred = s8[480,120]{1,0} parameter(0) + p0 = s8[480,120]{1,0} convert(p0_pred) + p1_pred = pred[16,120]{1,0} parameter(1) + p1 = bf16[16,120]{1,0} convert(p1_pred) + triton_gemm_r = bf16[4,480,16]{2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_r, backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":32,"block_n":32,"block_k":128, + triton_gemm_config: {"block_m":64,"block_n":16,"block_k":16, "split_k":4,"num_stages":1,"num_warps":4, "num_ctas":1}}} ROOT fusion.1 = 
bf16[480,16]{1,0} fusion(triton_gemm_r), kind=kLoop, calls=fused_computation })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_splitk, - ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, - /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN( + ModuleAndNestedFusionMetadata test_module_and_metadata, + GetModuleAndNestedFusionMetadata(hlo_text_splitk)); + + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata ref_module_and_metadata, + GetModuleAndNestedFusionMetadata(hlo_text_ref)); + + EXPECT_TRUE( + RunAndCompareTwoModules(std::move(ref_module_and_metadata.module), + std::move(test_module_and_metadata.module), + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); } TEST_F(CompareTest, DISABLED_SplitKBatch) { From 847192bcb361db599b65c432538d79d5265a695d Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Wed, 9 Apr 2025 04:49:47 -0700 Subject: [PATCH 0435/1324] [XLA:GPU] Fix the precision-algorithm behavior for trivial tensors like [1,1] Dots over such tensors are rewritten to a trivial multiply, and we don't have an algorithm-specific rewrite for them in the GPU codebase; we only have the rewrite to multiply for the less trivial cases like [1, 100] x [100, 1]. Let's split these two cases. PiperOrigin-RevId: 745535271 --- .../xla/xla/hlo/transforms/simplifiers/BUILD | 2 ++ .../simplifiers/algebraic_simplifier.cc | 13 ++++++---- .../simplifiers/algebraic_simplifier.h | 3 ++- .../simplifiers/algebraic_simplifier_test.cc | 20 ++++++++++++--- .../xla/xla/service/gpu/transforms/BUILD | 3 +++ .../gpu/transforms/algebraic_simplifier.cc | 6 ++++- .../gpu/transforms/algebraic_simplifier.h | 3 ++- .../transforms/algebraic_simplifier_test.cc | 25 +++++++++++++++++++ 8 files changed, 64 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD index 41a2ce422b6d86..cab76b59d355cb 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD +++ b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD @@ -367,6 +367,8 @@ cc_library( "//xla/service:memory_annotations_hdr", "//xla/service:pattern_matcher", "//xla/service:shape_inference", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index 398d8dead09791..7466c5d7d9c482 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -68,6 +69,8 @@ limitations under the License.
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" @@ -3961,7 +3964,7 @@ absl::Status AlgebraicSimplifierVisitor::RewriteBatchPlusContractingAsReduce( } bool AlgebraicSimplifierVisitor::SupportedDotPrecisionConfig( - const PrecisionConfig& config) { + const PrecisionConfig& config, bool has_contracting_dim) { return config.algorithm() == PrecisionConfig::ALG_UNSET || // TODO(loislo): Fixes a failure on a test with CPU backend. config.algorithm() == PrecisionConfig::ALG_DOT_F32_F32_F32; @@ -3995,11 +3998,10 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } - const bool can_rewrite_dot_with_precision_config_algorithm = - SupportedDotPrecisionConfig(dot->precision_config()); // If there are no contracting dimensions, a dot can be rewritten as // mul(broadcast(transpose(x)),broadcast(transpose(y))) - if (can_rewrite_dot_with_precision_config_algorithm && + if (SupportedDotPrecisionConfig(dot->precision_config(), + /*has_contracting_dim=*/false) && options_.enable_dot_to_multiply_rewrite() && dnums.lhs_contracting_dimensions_size() == 0) { return RewriteAsMultiplyDotWithZeroLhsContractingDim(dot, lhs, rhs, dnums); @@ -4026,7 +4028,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // If the lhs or rhs have only batch and contracting dimensions, a dot can be // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) - if (can_rewrite_dot_with_precision_config_algorithm && + if (SupportedDotPrecisionConfig(dot->precision_config(), + /*has_contracting_dim=*/true) && options_.enable_dot_strength_reduction() && DotHasOnlyBatchAndContractingOnOneOperand(lhs->shape().dimensions_size(), rhs->shape().dimensions_size(), diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h index d39680ed40dbc8..f02e170dfbd4f4 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h @@ -597,7 +597,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { private: // Returns whether the dot precision config is supported by simplifier. - virtual bool SupportedDotPrecisionConfig(const PrecisionConfig& config); + virtual bool SupportedDotPrecisionConfig(const PrecisionConfig& config, + bool has_contracting_dim); // Makes algorithm specific set of instructions for multiply with precision // algorithm in mind. In the trivial case it returns just multiply. 
diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc index 5d0d1705129942..984af5cdce37c8 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc @@ -9849,11 +9849,25 @@ TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) { TEST_F(AlgebraicSimplifierTest, NoDotToMultiplyRewriteWithPrecisionConfigAlgorithm) { constexpr char kModuleStr[] = R"( +HloModule test +ENTRY dot { + a = f32[128]{0} parameter(0) + b = f32[128]{0} parameter(1) + ROOT dot = f32[128,128]{1,0} dot(a, b), algorithm=dot_tf32_tf32_f32 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + +TEST_F(AlgebraicSimplifierTest, + NoDotToMultiplyRewriteZeroContractingDimWithPrecisionConfigAlgorithm) { + constexpr char kModuleStr[] = R"( HloModule test ENTRY dot { - a = f32[128]{0} parameter(0) - b = f32[128]{0} parameter(1) - ROOT dot = f32[128,128]{1,0} dot(a, b), algorithm=dot_tf32_tf32_f32 + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT dot = f32[] dot(a, b), algorithm=dot_tf32_tf32_f32 } )"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index f93c98fa6f03f4..2fab25b5c14fbb 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -80,12 +80,15 @@ xla_cc_test( deps = [ ":algebraic_simplifier", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/service:pattern_matcher", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc index 89bffebd32756a..d591b06a98d8db 100644 --- a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc @@ -82,7 +82,11 @@ absl::Status GpuAlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { } bool GpuAlgebraicSimplifierVisitor::SupportedDotPrecisionConfig( - const PrecisionConfig& config) { + const PrecisionConfig& config, bool has_contracting_dim) { + if (!has_contracting_dim) { + return config.algorithm() == PrecisionConfig::ALG_UNSET || + config.algorithm() == PrecisionConfig::ALG_DOT_F32_F32_F32; + } return config.algorithm() == PrecisionConfig::ALG_UNSET || config.algorithm() == PrecisionConfig::ALG_DOT_BF16_BF16_F32 || config.algorithm() == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h index 51a4c67fdefeac..781dea4d5fb680 100644 --- a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h @@ -45,7 +45,8 @@ class GpuAlgebraicSimplifierVisitor : public 
AlgebraicSimplifierVisitor { private: // Returns true if the dot precision config is supported by simplifier. - bool SupportedDotPrecisionConfig(const PrecisionConfig& config) override; + bool SupportedDotPrecisionConfig(const PrecisionConfig& config, + bool has_contracting_dim) override; // Makes algorithm specific set of instructions for multiply with precision // algorithm in mind. In the trivial case it returns just multiply. diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc index 50e76c98201e7e..24def9147c7ee4 100644 --- a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc @@ -19,12 +19,15 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/service/pattern_matcher.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/platform/statusor.h" namespace xla::gpu { @@ -360,5 +363,27 @@ TEST_F(GpuAlgebraicSimplifierTest, ASSERT_TRUE(GpuAlgebraicSimplifier(options, Ampere()).Run(m.get()).value()); } +TEST_F( + GpuAlgebraicSimplifierTest, + DotToMultiplyRewriteForZeroContractingDimWith_BF16_BF16_F32_X6_Algorithm) { + constexpr char kModuleStr[] = R"( + HloModule test + ENTRY dot { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT dot = f32[] dot(a, b), + algorithm=dot_bf16_bf16_f32_x6 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options; + ASSERT_TRUE(GpuAlgebraicSimplifier(options, Ampere()).Run(m.get()).value()); + constexpr absl::string_view kPattern = R"( + CHECK-COUNT-6: %[[partial_result:.*]] = bf16[] multiply + )"; + TF_ASSERT_OK_AND_ASSIGN(bool matched, RunFileCheck(m->ToString(), kPattern)); + EXPECT_TRUE(matched); +} + } // namespace } // namespace xla::gpu From bb9d43ca2f93ee6d6ef39f69b07ac580ae1f2089 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Wed, 9 Apr 2025 05:33:01 -0700 Subject: [PATCH 0436/1324] [XLA:GPU] Automatically enable AUTO layout when using multihost layout Make the recently added --use_layouts_from_hlo_module flag of the multihost runner also enable AUTO layout. 
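Mechanically, the change boils down to one extra debug option whenever entry layouts are taken from the module, as the diff below shows. In isolation, and with the include path assumed rather than checked, the wiring looks roughly like this sketch (not the runner's full `CompleteCompileOptions`):

```
#include "xla/pjrt/pjrt_executable.h"  // assumed location of xla::CompileOptions

// Rough sketch: when the runner keeps the module's entry layouts, it must
// also allow "auto" layouts so that the compiler may assign a layout
// wherever the module leaves one unset.
xla::CompileOptions MakeCompileOptionsKeepingModuleLayouts() {
  xla::CompileOptions compile_options;
  compile_options.executable_build_options.mutable_debug_options()
      ->set_xla_pjrt_allow_auto_layout_in_hlo(true);
  return compile_options;
}
```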
PiperOrigin-RevId: 745546575 --- third_party/xla/xla/tools/multihost_hlo_runner/BUILD | 1 + .../tools/multihost_hlo_runner/data/auto_layout.hlo | 7 +++++++ .../multihost_hlo_runner/functional_hlo_runner.cc | 2 ++ .../functional_hlo_runner_test.cc | 12 ++++++++++++ 4 files changed, 22 insertions(+) create mode 100644 third_party/xla/xla/tools/multihost_hlo_runner/data/auto_layout.hlo diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index d8d10a07b29e6b..685f72b8fd2d01 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -210,6 +210,7 @@ xla_test( "gpu", ], data = [ + "data/auto_layout.hlo", "data/multiple_gemm_fusions.hlo", "data/sharded_16_devices.hlo", "data/sharded_2_devices.hlo", diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/data/auto_layout.hlo b/third_party/xla/xla/tools/multihost_hlo_runner/data/auto_layout.hlo new file mode 100644 index 00000000000000..76dd13dc9156e1 --- /dev/null +++ b/third_party/xla/xla/tools/multihost_hlo_runner/data/auto_layout.hlo @@ -0,0 +1,7 @@ +HloModule t, entry_computation_layout={(bf16[4096,64,8], bf16[4096,512,8])->bf16[8,64,512]} + +ENTRY main { + p0 = bf16[4096,64,8] parameter(0) + p1 = bf16[4096,512,8] parameter(1) + ROOT res = bf16[8,64,512] dot(p0,p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}, lhs_batch_dims={2}, rhs_batch_dims={2} +} \ No newline at end of file diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index bc534ca0fed0b6..2f07003cfc6131 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -809,6 +809,8 @@ CompileOptions FunctionalHloRunner::CompleteCompileOptions( compile_options.argument_layouts = std::move(parameter_shapes); compile_options.executable_build_options.set_result_layout( layout.result_shape()); + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_pjrt_allow_auto_layout_in_hlo(true); } return compile_options; } diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index 9007e427adf28b..c60cb2ed61485d 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -402,6 +402,18 @@ TEST_F(FunctionalHloRunnerTest, KeepLayoutsFromHloModule) { /*num_partitions=*/1); } +TEST_F(FunctionalHloRunnerTest, AutoLayoutAssignsNonDefaultLayout) { + if (IsTestingCpu()) GTEST_SKIP() << "CPU doesn't support auto-layout yet."; + FunctionalHloRunner::PreprocessingOptions preproc_options; + preproc_options.use_layouts_from_hlo_module = true; + CompileAndFilecheck(GetHloPath("auto_layout.hlo"), + // Makes LHS contracting dimension minor. + "// CHECK: entry_computation_layout={(bf16[4096,64,8]{0", + preproc_options, + FunctionalHloRunner::HloPassesMode::kStandardCompile, + /*num_partitions=*/1); +} + TEST_F(FunctionalHloRunnerTest, CanCompileWithoutHavingEnoughGpus) { CompileAndFilecheck(GetHloPath("sharded_16_devices.hlo"), // Check that the sharding was done correctly. 
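A side note on reading the CHECK line in the new test: XLA layouts list dimensions minor-to-major, so a layout beginning with `{0` makes dimension 0 (here the contracting dimension of size 4096) the fastest-varying one. A few lines of plain C++ (illustrative only, not runner code, and using `{0, 1, 2}` as one possible full order) make the stride arithmetic explicit:

```
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // bf16[4096,64,8] with minor-to-major order {0, 1, 2}: dimension 0 gets
  // stride 1, i.e. the contracting dimension is contiguous in memory.
  const std::array<int64_t, 3> dims = {4096, 64, 8};
  const std::array<int, 3> minor_to_major = {0, 1, 2};
  std::array<int64_t, 3> strides{};
  int64_t stride = 1;
  for (int d : minor_to_major) {
    strides[d] = stride;
    stride *= dims[d];
  }
  // Prints: strides = {1, 4096, 262144}
  std::printf("strides = {%lld, %lld, %lld}\n",
              static_cast<long long>(strides[0]),
              static_cast<long long>(strides[1]),
              static_cast<long long>(strides[2]));
  return 0;
}
```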
From 02d7df64f89971f2a8ee00bdb25aa147db0c422d Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Wed, 9 Apr 2025 05:55:29 -0700
Subject: [PATCH 0437/1324] =?UTF-8?q?[XLA:GPU]=C2=A0Finish=20enabling=20or?=
 =?UTF-8?q?=20documenting=20failures=20for=20`CompareTest`=20tests=20in=20?=
 =?UTF-8?q?`fusion=5Femitter=5Fdevice=5Flegacy=5Fport=5Ftest.cc`.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PiperOrigin-RevId: 745552604
---
 .../fusion_emitter_device_legacy_port_test.cc | 221 ++++++++++--------
 1 file changed, 126 insertions(+), 95 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc
index b37d2e0a7658f0..fed261431da15c 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc
@@ -2499,7 +2499,8 @@ ENTRY e {
 // into Triton fusions.
 using CompareTest = TritonGemmTest;

-
+// TODO(bchetioui): same as
+// TritonTest.TestGemmWithTrivialNonContractingDimension.
 TEST_F(CompareTest, DISABLED_F32WithTrivialNonContractingDimension) {
   constexpr absl::string_view hlo_text_ref = R"(
 HloModule r
@@ -2540,59 +2541,70 @@ ENTRY e {
                                       /*run_hlo_passes=*/false));
 }

-TEST_F(CompareTest, DISABLED_UsingOptinSharedMemoryOnAmpereProducesSameResult) {
+// TODO(b/353484968, b/393299275): until now, the e2e test path was never
+// really testing anything useful, because it did not actually enable opt-in
+// shared memory properly. Additionally, it claimed to be testing Ampere
+// specifically, but runs on every chip that Triton supports. The test should
+// probably be made deviceless and repurposed to test only that opt-in shared
+// memory is used.
+TEST_F(CompareTest, UsingOptinSharedMemoryProducesSameResult) {
   if (std::holds_alternative<se::RocmComputeCapability>(
          GpuComputeCapability())) {
     GTEST_SKIP() << "No Optin Shared Memory on AMD.";
   }
   const se::DeviceDescription dev_info =
       backend().default_stream_executor()->GetDeviceDescription();
-  constexpr int kBytesOfSharedMemoryTested = 64 * 1024;
+  // TODO(b/353484968): pin this test to a specific device type to ensure
+  // correct expectations.
+  //
+  // On Hopper, the RHS has to be provided through shared memory, so, at a
+  // minimum, the kernel will use
+  //   num_stages * block_k * block_n * sizeof(rhs_element_type)
+  //   = 2 * 128 * 128 * 2
+  //   = 65536 bytes.
+  //
+  // This should hold on Blackwell as well.
+  constexpr int kBytesOfSharedMemoryTested = 2 * 128 * 128 * 2;
   EXPECT_GE(dev_info.shared_memory_per_block_optin(),
             kBytesOfSharedMemoryTested);

   const std::string kHloTextOptinShmem = R"(
-HloModule t
-
 triton_dot {
-  param_0.1 = s8[332,441]{1,0} parameter(0)
-  p0c = f16[332,441]{1,0} convert(param_0.1)
-  param_1.1 = f16[441,39]{1,0} parameter(1)
-  ROOT dot = f16[332,39]{1,0} dot(p0c, param_1.1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  p0 = s8[332,441]{1,0} parameter(0)
+  convert = bf16[332,441]{1,0} convert(p0)
+  p1 = bf16[441,39]{1,0} parameter(1)
+  // Fix an algorithm on the dot in order to explicitly control the size of the
+  // operands in shared memory, as well as the precision.
+ ROOT dot = bf16[332,39]{1,0} dot(convert, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32 } ENTRY e { p0 = s8[332,441]{1,0} parameter(0) - p1 = f16[441,39]{1,0} parameter(1) - ROOT _ = f16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, + p1 = bf16[441,39]{1,0} parameter(1) + ROOT _ = bf16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":128,"block_n":128,"block_k":128, "split_k":1,"num_stages":2,"num_warps":32, "num_ctas":1}}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kHloTextOptinShmem)); + TF_ASSERT_OK_AND_ASSIGN( + ModuleAndNestedFusionMetadata optin_shmem_module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloTextOptinShmem)); const HloFusionInstruction* triton_dot_fusion = Cast( - hlo_module->entry_computation()->root_instruction()); + optin_shmem_module_and_metadata.computation->FusionInstruction()); llvm::LLVMContext llvm_ctx; llvm::Module llvm_module("module", llvm_ctx); mlir::MLIRContext mlir_context; - TF_ASSERT_OK_AND_ASSIGN( - auto gpu_config, triton_dot_fusion->backend_config()); - const FusionBackendConfig& config = gpu_config.fusion_backend_config(); - auto gemm_config = config.triton_gemm_config(); - BlockLevelParameters block_level_parameters; - block_level_parameters.num_ctas = gemm_config.num_ctas(); - block_level_parameters.num_warps = gemm_config.num_warps(); - block_level_parameters.num_stages = gemm_config.num_stages(); TF_ASSERT_OK_AND_ASSIGN( const auto result, TritonWrapper("test_fn", triton_dot_fusion, GpuComputeCapability(), - dev_info, block_level_parameters, &llvm_module, - mlir_context)); + dev_info, + optin_shmem_module_and_metadata.block_level_parameters, + &llvm_module, mlir_context)); // The config is chosen so that the used memory size is slightly above the // 48 kB boundary of standard / optin shared memory so that any GPU that // has the optin one should be able to execute the test. @@ -2604,26 +2616,35 @@ ENTRY e { HloModule t triton_dot { - param_0.1 = s8[332,441]{1,0} parameter(0) - p0c = f16[332,441]{1,0} convert(param_0.1) - param_1.1 = f16[441,39]{1,0} parameter(1) - ROOT dot = f16[332,39]{1,0} dot(p0c, param_1.1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} + p0 = s8[332,441]{1,0} parameter(0) + convert = bf16[332,441]{1,0} convert(p0) + p1 = bf16[441,39]{1,0} parameter(1) + // Fix an algorithm on the dot in order to explicitly control the size of the + // operands in shared memory, as well as the precision. 
+  ROOT dot = bf16[332,39]{1,0} dot(convert, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0},
+    algorithm=dot_bf16_bf16_f32
 }

 ENTRY e {
   p0 = s8[332,441]{1,0} parameter(0)
-  p1 = f16[441,39]{1,0} parameter(1)
-  ROOT _ = f16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot,
+  p1 = bf16[441,39]{1,0} parameter(1)
+  ROOT _ = bf16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot,
   backend_config={"fusion_backend_config": {kind: "__triton_gemm",
     triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,
                          "split_k":1,"num_stages":1,"num_warps":4,
                          "num_ctas":1}}}
 })";

-  EXPECT_TRUE(RunAndCompareTwoModules(kHloTextLowShmem, kHloTextOptinShmem,
-                                      ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6},
-                                      /*run_hlo_passes=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(
+      ModuleAndNestedFusionMetadata low_shmem_module_and_metadata,
+      GetModuleAndNestedFusionMetadata(kHloTextLowShmem));
+
+  EXPECT_TRUE(
+      RunAndCompareTwoModules(std::move(low_shmem_module_and_metadata.module),
+                              std::move(optin_shmem_module_and_metadata.module),
+                              ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6},
+                              /*run_hlo_passes=*/false));
 }

 // TODO(b/393299275): there seems to be a (not yet diagnosed) miscompile here.
@@ -2715,13 +2736,11 @@ ENTRY e {
                                       /*run_hlo_passes=*/false));
 }

-TEST_F(CompareTest, DISABLED_SplitKBatch) {
+TEST_F(CompareTest, SplitKBatch) {
   if (!SupportsBF16(GpuComputeCapability())) {
     GTEST_SKIP() << "BF16 not supported.";
   }
   const std::string kHloTextRef = R"(
-HloModule m, is_scheduled=true
-
 triton_gemm_dot.24 {
   parameter_1 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1)
   bitcast.3 = bf16[800,5,128]{2,1,0} bitcast(parameter_1)
@@ -2742,9 +2761,7 @@ ENTRY e {
                          "num_ctas":1}}}
 })";

-  const std::string kHloTextSplitK = R"(
-HloModule m, is_scheduled=true
-
+  const std::string kHloTextTest = R"(
 triton_gemm_dot {
   parameter_1 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1)
   bitcast.3 = bf16[800,5,128]{2,1,0} bitcast(parameter_1)
@@ -2774,10 +2791,18 @@ ENTRY e {
   constant = f32[] constant(0)
   ROOT reduce = f32[5,128,700]{2,1,0} reduce(triton_gemm_dot.24, constant),
     dimensions={0}, to_apply=add
 })";
+  TF_ASSERT_OK_AND_ASSIGN(
+      ModuleAndNestedFusionMetadata test_module_and_metadata,
+      GetModuleAndNestedFusionMetadata(kHloTextTest));

-  EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextSplitK,
-                                      ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3},
-                                      /*run_hlo_passes=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata ref_module_and_metadata,
+                          GetModuleAndNestedFusionMetadata(kHloTextRef));
+
+  EXPECT_TRUE(
+      RunAndCompareTwoModules(std::move(ref_module_and_metadata.module),
+                              std::move(test_module_and_metadata.module),
+                              ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3},
+                              /*run_hlo_passes=*/false));
 }

 TEST_F(CompareTest, SplitKNontrivialBitcast) {
@@ -2866,80 +2891,88 @@ ENTRY entry {
 // There were relatively large numeric errors with an f16 temporary buffer, so I
 // ended up using --xla_gpu_triton_gemm_disable_reduced_precision_reduction=true
 // when generating this test case.
+//
+// TODO(b/393299275): transform this test once padding derivation is fixed.
TEST_F(CompareTest, DISABLED_SupportsSplitKWithIndivisibleKComplexExample) { constexpr absl::string_view kHloTextRef = R"( -HloModule extracted, entry_computation_layout={(s8[3,129,5,32]{3,2,1,0}, f16[16,129]{1,0})->f16[480,16]{1,0}} - -triton_gemm_dot.clone { - parameter_0 = s8[3,129,5,32]{3,2,1,0} parameter(0) - bitcast.1 = s8[3,5,32,129]{2,1,3,0} bitcast(parameter_0) - copy.1 = s8[3,5,32,129]{3,2,1,0} copy(bitcast.1) - reshape.5 = s8[480,129]{1,0} reshape(copy.1) - convert.8 = f16[480,129]{1,0} convert(reshape.5) - parameter_1 = f16[16,129]{1,0} parameter(1) - ROOT dot.0 = f16[480,16]{1,0} dot(convert.8, parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} +dot { + p0 = s8[3,129,5,32]{3,2,1,0} parameter(0) + bitcast = s8[3,5,32,129]{2,1,3,0} bitcast(p0) + copy = s8[3,5,32,129]{3,2,1,0} copy(bitcast) + reshape = s8[480,129]{1,0} reshape(copy) + convert = f16[480,129]{1,0} convert(reshape) + p1 = f16[16,129]{1,0} parameter(1) + ROOT dot = f16[480,16]{1,0} dot(convert, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} } ENTRY entry_computation { p0 = s8[3,129,5,32]{3,2,1,0} parameter(0) p1 = f16[16,129]{1,0} parameter(1) - ROOT fusion = f16[480,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256", - "split_k":"1","num_stages":"1","num_warps":"4", - "num_ctas":"1"}}} + ROOT fusion = f16[480,16]{1,0} fusion(p0, p1), kind=kCustom, calls=dot, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256", + "split_k":"1","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} } )"; - constexpr absl::string_view kHloTextSplitK = R"( -HloModule extracted, entry_computation_layout={(s8[3,129,5,32]{3,2,1,0}, f16[16,129]{1,0})->f16[480,16]{1,0}} - -triton_gemm_dot.clone { - parameter_0 = s8[3,129,5,32]{3,2,1,0} parameter(0) - bitcast.1 = s8[3,5,32,129]{2,1,3,0} bitcast(parameter_0) - copy.1 = s8[3,5,32,129]{3,2,1,0} copy(bitcast.1) - reshape.5 = s8[480,129]{1,0} reshape(copy.1) - convert.8 = f16[480,129]{1,0} convert(reshape.5) - constant = f16[] constant(0) - pad = f16[480,130]{1,0} pad(convert.8, constant), padding=0_0x0_1 - bitcast = f16[480,2,65]{2,1,0} bitcast(pad) - convert.1 = f32[480,2,65]{2,1,0} convert(bitcast) - parameter_1 = f16[16,129]{1,0} parameter(1) - constant.1 = f16[] constant(0) - pad.1 = f16[16,130]{1,0} pad(parameter_1, constant.1), padding=0_0x0_1 - bitcast.2 = f16[16,2,65]{2,1,0} bitcast(pad.1) - convert.2 = f32[16,2,65]{2,1,0} convert(bitcast.2) - ROOT dot.2 = f32[2,480,16]{2,1,0} dot(convert.1, convert.2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2} + constexpr absl::string_view kHloTextTest = R"( +dot { + p0 = s8[3,129,5,32]{3,2,1,0} parameter(0) + bitcast_p0 = s8[3,5,32,129]{2,1,3,0} bitcast(p0) + copy_p0 = s8[3,5,32,129]{3,2,1,0} copy(bitcast_p0) + reshape_p0 = s8[480,129]{1,0} reshape(copy_p0) + convert.8 = f16[480,129]{1,0} convert(reshape_p0) + c0 = f16[] constant(0) + pad_p0 = f16[480,130]{1,0} pad(convert.8, c0), padding=0_0x0_1 + bitcast_pad_p0 = f16[480,2,65]{2,1,0} bitcast(pad_p0) + dot_lhs = f32[480,2,65]{2,1,0} convert(bitcast_pad_p0) + p1 = f16[16,129]{1,0} parameter(1) + pad_p1 = f16[16,130]{1,0} pad(p1, c0), padding=0_0x0_1 + bitcast_pad_p1 = f16[16,2,65]{2,1,0} bitcast(pad_p1) + dot_rhs = f32[16,2,65]{2,1,0} convert(bitcast_pad_p1) + ROOT dot.2 = f32[2,480,16]{2,1,0} 
dot(dot_lhs, dot_rhs),
+    lhs_batch_dims={1}, lhs_contracting_dims={2},
+    rhs_batch_dims={1}, rhs_contracting_dims={2}
 }

-fusion.reduce_sub_computation {
+reducer {
   lhs = f32[] parameter(0)
   rhs = f32[] parameter(1)
   ROOT add = f32[] add(lhs, rhs)
 }

-fused_computation {
-  param_0.1 = f32[2,480,16]{2,1,0} parameter(0)
-  constant.3 = f32[] constant(0)
-  reduce.1 = f32[480,16]{1,0} reduce(param_0.1, constant.3), dimensions={0}, to_apply=fusion.reduce_sub_computation
-  ROOT convert.3 = f16[480,16]{1,0} convert(reduce.1)
+split_k_reducer {
+  p0 = f32[2,480,16]{2,1,0} parameter(0)
+  c0 = f32[] constant(0)
+  reduce = f32[480,16]{1,0} reduce(p0, c0), dimensions={0}, to_apply=reducer
+  ROOT convert = f16[480,16]{1,0} convert(reduce)
 }

 ENTRY entry_computation {
   p0 = s8[3,129,5,32]{3,2,1,0} parameter(0)
   p1 = f16[16,129]{1,0} parameter(1)
   fusion = f32[2,480,16]{2,1,0} fusion(p0, p1), kind=kCustom, calls=dot,
     backend_config={"fusion_backend_config": {"kind":"__triton_gemm",
       "triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"64",
                             "split_k":"2","num_stages":"1","num_warps":"8",
                             "num_ctas":"1"}}}
-  ROOT fusion.1 = f16[480,16]{1,0} fusion(fusion), kind=kLoop, calls=fused_computation
-}
-)";
+  ROOT output = f16[480,16]{1,0} fusion(fusion), kind=kLoop,
+    calls=split_k_reducer
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      ModuleAndNestedFusionMetadata test_module_and_metadata,
+      GetModuleAndNestedFusionMetadata(kHloTextTest));

-  EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextSplitK,
-                                      ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2},
-                                      /*run_hlo_passes=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata ref_module_and_metadata,
+                          GetModuleAndNestedFusionMetadata(kHloTextRef));
+
+  EXPECT_TRUE(
+      RunAndCompareTwoModules(std::move(ref_module_and_metadata.module),
+                              std::move(test_module_and_metadata.module),
+                              ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2},
+                              /*run_hlo_passes=*/false));
 }

 // TODO(b/393299275): transform this test once padding derivation is fixed.
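For intuition, the split-K arithmetic that the test above encodes in HLO (K = 129 zero-padded to 130, split into 2 slices of 65, partial dots summed by the trailing reduce) can be checked numerically in a few lines of plain C++. This is an illustration of the trick, not XLA code:

```
#include <cassert>
#include <vector>

int main() {
  // K = 129 is not divisible by split_k = 2, so the rewrite zero-pads K to
  // 130 and reshapes it into [2, 65]; padding with zeros keeps the result
  // exact.
  const int k = 129, split_k = 2, padded_k = 130;
  const int k_per_split = padded_k / split_k;  // 65
  std::vector<float> a(padded_k, 0.0f), b(padded_k, 0.0f);
  for (int i = 0; i < k; ++i) {
    a[i] = 1.0f;
    b[i] = 2.0f;
  }
  float reference = 0.0f;  // the plain dot over the original K
  for (int i = 0; i < k; ++i) reference += a[i] * b[i];
  float split_result = 0.0f;
  for (int s = 0; s < split_k; ++s) {  // the split (batch) dimension of size 2
    float partial = 0.0f;
    for (int i = 0; i < k_per_split; ++i) {
      partial += a[s * k_per_split + i] * b[s * k_per_split + i];
    }
    split_result += partial;  // the trailing reduce over the split dimension
  }
  assert(split_result == reference);
  return 0;
}
```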
@@ -2965,8 +2998,7 @@ ENTRY entry_computation { "triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"32", "split_k":"1","num_stages":"4","num_warps":"4", "num_ctas":"1"}}} -} -)"; +})"; constexpr absl::string_view kHloTextSplitK = R"( HloModule extracted, entry_computation_layout={(f16[1,8,4,1023]{3,2,1,0}, f16[1,1023,128]{2,1,0})->f16[1,8,4,128]{3,2,1,0}} @@ -3109,8 +3141,7 @@ ENTRY entry_computation { "num_ctas":"1"}}} ROOT fusion = f16[1,8,4,128]{3,2,1,0} fusion(dot), kind=kLoop, calls=split_k_reducer -} -)"; +})"; TF_ASSERT_OK_AND_ASSIGN( ModuleAndNestedFusionMetadata test_module_and_metadata, GetModuleAndNestedFusionMetadata(kHloTextTest)); From cdfa51b54192086dd3dbe8fb420e83eaca7ca47d Mon Sep 17 00:00:00 2001 From: Will Froom Date: Wed, 9 Apr 2025 06:04:42 -0700 Subject: [PATCH 0438/1324] [XLA:CPU] Split out fusion kernel api codegen from CpuFusionEmitterBase PiperOrigin-RevId: 745555819 --- .../codegen/emitters/cpu_fusion_emitter.cc | 180 ++++++++++-------- .../cpu/codegen/emitters/cpu_fusion_emitter.h | 5 + 2 files changed, 102 insertions(+), 83 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc index d0f175985d6baa..47d79125fd7bc3 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc @@ -192,90 +192,10 @@ CpuFusionEmitterBase::CreateMLIRModule( auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion.name())); mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(loc); - // Create the entry function. TF_ASSIGN_OR_RETURN( - std::vector arguments, - KernelApiIrBuilder::GetKernelArgumentsParameters(&fusion, - &buffer_assignment)); - TF_ASSIGN_OR_RETURN(std::vector results, - KernelApiIrBuilder::GetKernelResultsParameters( - &fusion, &buffer_assignment)); - - // TBD: Annotate tensors with the buffer indices. This way, the buffer - // propagation pass can clean them up later. - auto get_arg_attrs = [&](int index, BufferAllocation::Slice& slice, - bool is_result) -> absl::StatusOr { - SmallVector attrs; - attrs.push_back(builder.getNamedAttr( - "xla.slice_index", - builder.getIndexAttr(index + (is_result ? arguments.size() : 0)))); - attrs.push_back(builder.getNamedAttr( - mlir::LLVM::LLVMDialect::getDereferenceableAttrName(), - builder.getIndexAttr(slice.size()))); - attrs.push_back( - builder.getNamedAttr(mlir::LLVM::LLVMDialect::getAlignAttrName(), - builder.getIndexAttr(MinAlign()))); - return builder.getDictionaryAttr(attrs); - }; - - // First argument is the thread id. 
- SmallVector arg_attrs{builder.getDictionaryAttr( - builder.getNamedAttr("xla.invariant", builder.getUnitAttr()))}; - SmallVector param_types{builder.getIndexType()}; - - for (const auto& [index, arg] : llvm::enumerate(arguments)) { - param_types.push_back(emitters::TensorShapeToMlirType(arg.shape, builder)); - TF_ASSIGN_OR_RETURN( - arg_attrs.emplace_back(), - get_arg_attrs(index - 1, arg.slice, /*is_result=*/false)); - } - - auto result_types = emitters::ShapeToMlirTypes(fusion.shape(), builder); - param_types.append(result_types.begin(), result_types.end()); - for (const auto& [index, result] : llvm::enumerate(results)) { - TF_ASSIGN_OR_RETURN(arg_attrs.emplace_back(), - get_arg_attrs(index, result.slice, /*is_result=*/true)); - } - - builder.setInsertionPointToStart(module->getBody()); - auto entry_func = builder.create( - loc, entry_function_name, - mlir::FunctionType::get(&context, param_types, result_types), - /*sym_visibility=*/mlir::StringAttr{}, - mlir::ArrayAttr::get(&context, arg_attrs), - /*res_attrs=*/mlir::ArrayAttr{}); - entry_func->setAttr("xla.entry", mlir::UnitAttr::get(&context)); - SetBackendKind(&context, entry_func, xla::BackendKind::kCpu); - entry_func.setPrivate(); - - // Create wrapper for the entry function. This function has one call_frame - // argument and call the entry function. - auto error_type = cpu::ErrorType::get(&context); - auto call_frame_type = CallFrameType::get(mlir_context_); - auto call_frame_func = builder.create( - loc, fusion.name(), - builder.getFunctionType(/*arg_types=*/{call_frame_type}, - /*result_types=*/{error_type})); - builder.setInsertionPointToStart(call_frame_func.addEntryBlock()); - mlir::Value call_frame_arg = call_frame_func.getArgument(0); - SmallVector extracted_values; - extracted_values.reserve(arguments.size() + results.size() + 1); - extracted_values.push_back(builder.create( - loc, builder.getIndexType(), call_frame_arg)); - - for (int i = 1; i < param_types.size(); ++i) { - extracted_values.push_back(builder.create( - loc, param_types[i], call_frame_arg, i - 1)); - } - auto call_results = - builder.create(loc, entry_func, extracted_values); - call_results->setAttr("noinline", mlir::UnitAttr::get(&context)); - for (auto [index, call_result] : llvm::enumerate(call_results.getResults())) { - builder.create(loc, call_result, call_frame_arg, - index + arguments.size()); - } - auto error = builder.create(loc, error_type); - builder.create(loc, error.getResult()); + mlir::func::FuncOp entry_func, + EmitFusionKernelApi(module.get(), fusion, entry_function_name, + buffer_assignment)); TF_RETURN_IF_ERROR(EmitMlir(module.get(), entry_func, fusion)); return module; @@ -382,6 +302,100 @@ IndexingMap GetDefaultIndexingMap(absl::Span thread_tile_sizes, constraints); } +absl::StatusOr EmitFusionKernelApi( + mlir::ModuleOp fusion_module, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment& buffer_assignment) { + auto* context = fusion_module.getContext(); + mlir::OpBuilder builder(context); + auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion.name())); + TF_ASSIGN_OR_RETURN( + std::vector arguments, + KernelApiIrBuilder::GetKernelArgumentsParameters(&fusion, + &buffer_assignment)); + TF_ASSIGN_OR_RETURN(std::vector results, + KernelApiIrBuilder::GetKernelResultsParameters( + &fusion, &buffer_assignment)); + + // TBD: Annotate tensors with the buffer indices. This way, the buffer + // propagation pass can clean them up later. 
+ auto get_arg_attrs = [&](int index, BufferAllocation::Slice& slice, + bool is_result) -> absl::StatusOr { + SmallVector attrs; + attrs.push_back(builder.getNamedAttr( + "xla.slice_index", + builder.getIndexAttr(index + (is_result ? arguments.size() : 0)))); + attrs.push_back(builder.getNamedAttr( + mlir::LLVM::LLVMDialect::getDereferenceableAttrName(), + builder.getIndexAttr(slice.size()))); + attrs.push_back( + builder.getNamedAttr(mlir::LLVM::LLVMDialect::getAlignAttrName(), + builder.getIndexAttr(MinAlign()))); + return builder.getDictionaryAttr(attrs); + }; + + // First argument is the thread id. + SmallVector arg_attrs{builder.getDictionaryAttr( + builder.getNamedAttr("xla.invariant", builder.getUnitAttr()))}; + SmallVector param_types{builder.getIndexType()}; + + for (const auto& [index, arg] : llvm::enumerate(arguments)) { + param_types.push_back(emitters::TensorShapeToMlirType(arg.shape, builder)); + TF_ASSIGN_OR_RETURN( + arg_attrs.emplace_back(), + get_arg_attrs(index - 1, arg.slice, /*is_result=*/false)); + } + + auto result_types = emitters::ShapeToMlirTypes(fusion.shape(), builder); + param_types.append(result_types.begin(), result_types.end()); + for (const auto& [index, result] : llvm::enumerate(results)) { + TF_ASSIGN_OR_RETURN(arg_attrs.emplace_back(), + get_arg_attrs(index, result.slice, /*is_result=*/true)); + } + + builder.setInsertionPointToStart(fusion_module.getBody()); + auto entry_func = builder.create( + loc, entry_function_name, + mlir::FunctionType::get(context, param_types, result_types), + /*sym_visibility=*/mlir::StringAttr{}, + mlir::ArrayAttr::get(context, arg_attrs), + /*res_attrs=*/mlir::ArrayAttr{}); + entry_func->setAttr("xla.entry", mlir::UnitAttr::get(context)); + SetBackendKind(context, entry_func, xla::BackendKind::kCpu); + entry_func.setPrivate(); + + // Create wrapper for the entry function. This function has one call_frame + // argument and call the entry function. 
+ auto error_type = cpu::ErrorType::get(context); + auto call_frame_type = CallFrameType::get(context); + auto call_frame_func = builder.create( + loc, fusion.name(), + builder.getFunctionType(/*arg_types=*/{call_frame_type}, + /*result_types=*/{error_type})); + builder.setInsertionPointToStart(call_frame_func.addEntryBlock()); + mlir::Value call_frame_arg = call_frame_func.getArgument(0); + SmallVector extracted_values; + extracted_values.reserve(arguments.size() + results.size() + 1); + extracted_values.push_back(builder.create( + loc, builder.getIndexType(), call_frame_arg)); + + for (int i = 1; i < param_types.size(); ++i) { + extracted_values.push_back(builder.create( + loc, param_types[i], call_frame_arg, i - 1)); + } + auto call_results = + builder.create(loc, entry_func, extracted_values); + call_results->setAttr("noinline", mlir::UnitAttr::get(context)); + for (auto [index, call_result] : llvm::enumerate(call_results.getResults())) { + builder.create(loc, call_result, call_frame_arg, + index + arguments.size()); + } + auto error = builder.create(loc, error_type); + builder.create(loc, error.getResult()); + + return entry_func; +} + int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; } } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h index 52940c331461db..820886a57aa108 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h @@ -53,6 +53,11 @@ IndexingMap GetDefaultIndexingMap(absl::Span thread_tile_sizes, absl::Span shape, mlir::MLIRContext* mlir_context); +absl::StatusOr EmitFusionKernelApi( + mlir::ModuleOp fusion_module, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment& buffer_assignment); + class CpuFusionEmitterBase { public: CpuFusionEmitterBase(mlir::MLIRContext* mlir_context, From 5cba819760f2a34c0536841254ce0239d5a75f47 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 9 Apr 2025 06:12:52 -0700 Subject: [PATCH 0439/1324] [XLA:GPU] Re-enable `TritonGemmTest.NondefaultOperandLayoutIsSupported` when `-UNDEBUG` is set. Also adapt the test in `fusion_emitter_device_legacy_port_test` to be more accurate. PiperOrigin-RevId: 745558362 --- .../fusion_emitter_device_legacy_port_test.cc | 31 ++++++++++--------- .../fusion_emitter_device_legacy_test.cc | 15 +++------ 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index fed261431da15c..5bb2f36e15c913 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -757,28 +757,29 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +// TODO(b/393299275): it's not clear that this test is actually testing what it +// claims to be testing. It should either be rewritten to start from +// post-optimization HLO, or hoisted out to test the fusion logic specifically. TEST_F(TritonGemmTest, NondefaultOperandLayoutIsSupported) { - // TODO(bchetioui): reenable when b/285866137 is fixed. 
-#ifndef NDEBUG - GTEST_SKIP() << "This test times out when -UNDEBUG is set."; -#endif constexpr absl::string_view kHloText = R"( ENTRY r { - p1 = f16[9,140,128]{2,1,0} parameter(1) - cp = f16[9,140,128]{2,0,1} copy(p1) - cv = f32[9,140,128]{2,0,1} convert(cp) - p0 = f32[9,140,123]{2,1,0} parameter(0) - ROOT d = f32[9,128,123]{2,1,0} dot(cv, p0), + p1 = f16[3,10,128]{2,1,0} parameter(1) + cp = f16[3,10,128]{2,0,1} copy(p1) + cv = f32[3,10,128]{2,0,1} convert(cp) + p0 = f32[3,10,123]{2,1,0} parameter(0) + ROOT d = f32[3,128,123]{2,1,0} dot(cv, p0), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom))); + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: fusion( +; CHECK-SAME: kind=kCustom +; CHECK-SAME: "__triton_nested_gemm_fusion" +)"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc index 1319dacde17e67..6889723c51f8f6 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc @@ -863,21 +863,16 @@ ENTRY e { } TEST_F(TritonGemmTest, NondefaultOperandLayoutIsSupported) { - // TODO(bchetioui): reenable when b/285866137 is fixed. -#ifndef NDEBUG - GTEST_SKIP() << "This test times out when -UNDEBUG is set."; -#endif constexpr absl::string_view kHloText = R"( ENTRY r { - p1 = f16[9,140,128]{2,1,0} parameter(1) - cp = f16[9,140,128]{2,0,1} copy(p1) - cv = f32[9,140,128]{2,0,1} convert(cp) - p0 = f32[9,140,123]{2,1,0} parameter(0) - ROOT d = f32[9,128,123]{2,1,0} dot(cv, p0), + p1 = f16[3,10,128]{2,1,0} parameter(1) + cp = f16[3,10,128]{2,0,1} copy(p1) + cv = f32[3,10,128]{2,0,1} convert(cp) + p0 = f32[3,10,123]{2,1,0} parameter(0) + ROOT d = f32[3,128,123]{2,1,0} dot(cv, p0), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); EXPECT_THAT( From f2b3e84e7eed5df6bb659e06d656d2b3ab4e4db3 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 9 Apr 2025 06:31:16 -0700 Subject: [PATCH 0440/1324] PR #24898: [ROCM] fix asan invalid memory access in redzone allocator kernel rocm cu Imported from GitHub PR https://github.com/openxla/xla/pull/24898 Fix issue reported by asan while running the tests on rocm ci: ``` ==1718600==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x5030001d97f8 at pc 0x5647cfdda211 bp 0x7ffc9eb7eac0 sp 0x7ffc9eb7eab8 READ of size 8 at 0x5030001d97f8 thread T0 #0 0x5647cfdda210 in absl::lts_20230802::container_internal::CommonFields::capacity() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:990:36 #1 0x5647cfdda210 in void absl::lts_20230802::container_internal::InitializeSlots, 8ul, 8ul>(absl::lts_20230802::container_internal::CommonFields&, std::allocator) 
/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:1403:24 #2 0x7f066c2cfdde in absl::lts_20230802::container_internal::raw_hash_set, std::allocator>, void*>, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>, absl::lts_20230802::hash_internal::Hash, std::allocator>, void*>>, std::equal_to, std::allocator>, void*>>, std::allocator, std::allocator>, void*> const, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>>>::resize(unsigned long) (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x9dde) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #3 0x7f066c2cfd97 in absl::lts_20230802::container_internal::raw_hash_set, std::allocator>, void*>, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>, absl::lts_20230802::hash_internal::Hash, std::allocator>, void*>>, std::equal_to, std::allocator>, void*>>, std::allocator, std::allocator>, void*> const, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>>>::prepare_insert(unsigned long) (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x9d97) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #4 0x7f066c2cfcca in std::pair absl::lts_20230802::container_internal::raw_hash_set, std::allocator>, void*>, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>, absl::lts_20230802::hash_internal::Hash, std::allocator>, void*>>, std::equal_to, std::allocator>, void*>>, std::allocator, std::allocator>, void*> const, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>>>::find_or_prepare_insert, std::allocator>, void*>>(std::tuple, std::allocator>, void*> const&) (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x9cca) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #5 0x7f066c2cf9c4 in std::pair, std::allocator>, void*>, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>, absl::lts_20230802::hash_internal::Hash, std::allocator>, void*>>, std::equal_to, std::allocator>, void*>>, std::allocator, std::allocator>, void*> const, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>>>::iterator, bool> absl::lts_20230802::container_internal::raw_hash_set, std::allocator>, void*>, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>, absl::lts_20230802::hash_internal::Hash, std::allocator>, void*>>, std::equal_to, std::allocator>, void*>>, std::allocator, std::allocator>, void*> const, stream_executor::TypedKernel, unsigned char, unsigned long, stream_executor::DeviceMemory>>>>::EmplaceDecomposable::operator(), std::allocator>, void*>, std::piecewise_construct_t const&, std::tuple, std::allocator>, void*>&>, std::tuple, unsigned char, unsigned long, stream_executor::DeviceMemory>&&>>(std::tuple, std::allocator>, void*> const&, std::piecewise_construct_t const&, 
std::tuple, std::allocator>, void*>&>&&, std::tuple, unsigned char, unsigned long, stream_executor::DeviceMemory>&&>&&) const (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x99c4) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #6 0x7f066c2cf0ad in stream_executor::GetComparisonKernel(stream_executor::StreamExecutor*, stream_executor::GpuAsmOpts) (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x90ad) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #7 0x7f066c37ba93 in stream_executor::RedzoneAllocator::CheckRedzones() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/stream_executor/gpu/redzone_allocator.cc:272:3 #8 0x7f06b31bb7e9 in absl::lts_20230802::StatusOr xla::gpu::(anonymous namespace)::GemmAutotuner::GetBestAlgorithm(xla::HloInstruction const*, absl::lts_20230802::Span, double, bool, xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&)::'lambda'(long const&)&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:328:7 #9 0x7f06b31bb7e9 in xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:256:12 #10 0x7f06b31bb7e9 in xla::gpu::(anonymous namespace)::GemmAutotuner::operator()(xla::HloInstruction const*, xla::gpu::AutotuneCacheKey const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:137:18 #11 0x7f06b31b6760 in xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0::operator()() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:418:3 #12 0x7f06b31b6760 in absl::lts_20230802::StatusOr std::__invoke_impl, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>(std::__invoke_other, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:61:14 #13 0x7f06b31b6760 in std::enable_if, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>, absl::lts_20230802::StatusOr>::type std::__invoke_r, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>(xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:114:9 #14 0x7f06b31b6760 in std::_Function_handler (), xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0>::_M_invoke(std::_Any_data const&) 
/usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:290:9 #15 0x7f06b308670d in std::function ()>::operator()() const /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:590:9 #16 0x7f06b308670d in xla::gpu::AutotunerUtil::Autotune(xla::HloInstruction const*, xla::gpu::AutotuneConfig const&, std::function ()> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/autotuner_util.cc:460:3 #17 0x7f06b31b336e in xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:418:3 #18 0x7f06b31b336e in xla::gpu::(anonymous namespace)::RunOnComputation(xla::HloComputation*, xla::gpu::(anonymous namespace)::GemmAutotuner&, unsigned long*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:468:7 #19 0x7f06b31b336e in xla::gpu::GemmAlgorithmPicker::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:495:5 #20 0x7f06b30242f3 in xla::HloPassPipeline::RunHelper(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230802::flat_hash_set>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/hlo/pass/hlo_pass_pipeline.h:150:5 #21 0x7f06b3010bb9 in absl::lts_20230802::StatusOr xla::HloPassPipeline::RunPassesInternal(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230802::flat_hash_set>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/hlo/pass/hlo_pass_pipeline.cc:198:30 #22 0x7f06b300f786 in xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/hlo/pass/hlo_pass_pipeline.cc:338:10 #23 0x5647cfd66945 in xla::HloPassInterface::Run(xla::HloModule*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/hlo/pass/hlo_pass_interface.h:85:12 #24 0x7f06c2908be0 in xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1754:3 #25 0x7f06c2a000f3 in xla::gpu::AMDGPUCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/amdgpu_compiler.cc:197:3 #26 0x7f06c28f85e9 in xla::gpu::GpuCompiler::OptimizeHloModule(xla::HloModule*, 
stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1392:3 #27 0x7f06c291250d in xla::gpu::GpuCompiler::RunHloPasses(std::unique_ptr>, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1824:3 #28 0x5647cfd63784 in xla::Compiler::RunHloPasses(std::unique_ptr>, stream_executor::StreamExecutor*, stream_executor::DeviceMemoryAllocator*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/service/compiler.h:177:12 #29 0x7f06c339acba in xla::HloTestBase::GetOptimizedModule(std::unique_ptr>) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/tests/hlo_test_base.cc:188:32 #30 0x5647cfd89516 in xla::gpu::(anonymous namespace)::GpuCompilerTest_CollectivePermuteDecompositionAndPipelining_Test::TestBody() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler_test.cc:879:3 #31 0x7f06c2c649dd in void testing::internal::HandleSehExceptionsInMethodIfSupported(testing::Test*, void (testing::Test::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #32 0x7f06c2c649dd in void testing::internal::HandleExceptionsInMethodIfSupported(testing::Test*, void (testing::Test::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #33 0x7f06c2c64708 in testing::Test::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2739:5 #34 0x7f06c2c6771b in testing::TestInfo::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2885:11 #35 0x7f06c2c6a5ab in testing::TestSuite::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:3063:30 #36 0x7f06c2c96eba in testing::internal::UnitTestImpl::RunAllTests() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:6054:44 #37 0x7f06c2c9579d in bool testing::internal::HandleSehExceptionsInMethodIfSupported(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #38 0x7f06c2c9579d in bool testing::internal::HandleExceptionsInMethodIfSupported(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #39 0x7f06c2c95203 in testing::UnitTest::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:5594:10 #40 0x7f06c2d679b8 in RUN_ALL_TESTS() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/include/gtest/gtest.h:2334:73 #41 0x7f06c2d679b8 in main 
/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/tests/xla_internal_test_main.cc:65:10 #42 0x7f064c0b3d8f in __libc_start_call_main csu/../sysdeps/nptl/libc_start_call_main.h:58:16 #43 0x7f064c0b3e3f in __libc_start_main csu/../csu/libc-start.c:392:3 #44 0x5647cfc7b044 in _start (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/gpu_compiler_test_gpu_amd_any+0xff044) (BuildId: ef1ac485eb61840d0e2233a2cca69eec) 0x5030001d97f8 is located 8 bytes before 32-byte region [0x5030001d9800,0x5030001d9820) allocated by thread T0 here: #0 0x5647cfd1527f in malloc (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/gpu_compiler_test_gpu_amd_any+0x19927f) (BuildId: ef1ac485eb61840d0e2233a2cca69eec) #1 0x7f064c39798b in operator new(unsigned long) (/lib/x86_64-linux-gnu/libstdc++.so.6+0xae98b) (BuildId: e37fe1a879783838de78cbc8c80621fa685d58a2) #2 0x7f06b31bb5b7 in google::protobuf::Duration* google::protobuf::MessageLite::CreateMaybeMessage(google::protobuf::Arena*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_protobuf/src/google/protobuf/message_lite.h:425:12 #3 0x7f06b31bb5b7 in xla::AutotuneResult::_internal_mutable_run_time() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/autotuning.pb.h:3079:15 #4 0x7f06b31bb5b7 in xla::AutotuneResult::mutable_run_time() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/autotuning.pb.h:3085:45 #5 0x7f06b31bb5b7 in absl::lts_20230802::StatusOr xla::gpu::(anonymous namespace)::GemmAutotuner::GetBestAlgorithm(xla::HloInstruction const*, absl::lts_20230802::Span, double, bool, xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&)::'lambda'(long const&)&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:321:15 #6 0x7f06b31bb5b7 in xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:256:12 #7 0x7f06b31bb5b7 in xla::gpu::(anonymous namespace)::GemmAutotuner::operator()(xla::HloInstruction const*, xla::gpu::AutotuneCacheKey const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:137:18 #8 0x7f06b31b6760 in xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0::operator()() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:418:3 #9 0x7f06b31b6760 in absl::lts_20230802::StatusOr std::__invoke_impl, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>(std::__invoke_other, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:61:14 #10 0x7f06b31b6760 in std::enable_if, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous 
namespace)::GemmAutotuner&)::$_0&>, absl::lts_20230802::StatusOr>::type std::__invoke_r, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>(xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:114:9 #11 0x7f06b31b6760 in std::_Function_handler (), xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0>::_M_invoke(std::_Any_data const&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:290:9 #12 0x7f06b308670d in std::function ()>::operator()() const /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:590:9 #13 0x7f06b308670d in xla::gpu::AutotunerUtil::Autotune(xla::HloInstruction const*, xla::gpu::AutotuneConfig const&, std::function ()> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/autotuner_util.cc:460:3 #14 0x7f06b31b336e in xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:418:3 #15 0x7f06b31b336e in xla::gpu::(anonymous namespace)::RunOnComputation(xla::HloComputation*, xla::gpu::(anonymous namespace)::GemmAutotuner&, unsigned long*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:468:7 #16 0x7f06b31b336e in xla::gpu::GemmAlgorithmPicker::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:495:5 #17 0x7f06b30242f3 in xla::HloPassPipeline::RunHelper(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230802::flat_hash_set>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/hlo/pass/hlo_pass_pipeline.h:150:5 #18 0x7f06b3010bb9 in absl::lts_20230802::StatusOr xla::HloPassPipeline::RunPassesInternal(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230802::flat_hash_set>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/hlo/pass/hlo_pass_pipeline.cc:198:30 #19 0x7f06b300f786 in xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/hlo/pass/hlo_pass_pipeline.cc:338:10 #20 0x5647cfd66945 in xla::HloPassInterface::Run(xla::HloModule*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/hlo/pass/hlo_pass_interface.h:85:12 #21 0x7f06c2908be0 in xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, 
xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1754:3 #22 0x7f06c2a000f3 in xla::gpu::AMDGPUCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/amdgpu_compiler.cc:197:3 #23 0x7f06c28f85e9 in xla::gpu::GpuCompiler::OptimizeHloModule(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1392:3 #24 0x7f06c291250d in xla::gpu::GpuCompiler::RunHloPasses(std::unique_ptr>, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1824:3 #25 0x5647cfd63784 in xla::Compiler::RunHloPasses(std::unique_ptr>, stream_executor::StreamExecutor*, stream_executor::DeviceMemoryAllocator*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/service/compiler.h:177:12 #26 0x7f06c339acba in xla::HloTestBase::GetOptimizedModule(std::unique_ptr>) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/tests/hlo_test_base.cc:188:32 #27 0x5647cfd89516 in xla::gpu::(anonymous namespace)::GpuCompilerTest_CollectivePermuteDecompositionAndPipelining_Test::TestBody() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler_test.cc:879:3 #28 0x7f06c2c649dd in void testing::internal::HandleSehExceptionsInMethodIfSupported(testing::Test*, void (testing::Test::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #29 0x7f06c2c649dd in void testing::internal::HandleExceptionsInMethodIfSupported(testing::Test*, void (testing::Test::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #30 0x7f06c2c64708 in testing::Test::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2739:5 #31 0x7f06c2c6771b in testing::TestInfo::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2885:11 #32 0x7f06c2c6a5ab in testing::TestSuite::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:3063:30 #33 0x7f06c2c96eba in testing::internal::UnitTestImpl::RunAllTests() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:6054:44 #34 0x7f06c2c9579d in bool testing::internal::HandleSehExceptionsInMethodIfSupported(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #35 0x7f06c2c9579d in bool 
testing::internal::HandleExceptionsInMethodIfSupported(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #36 0x7f06c2c95203 in testing::UnitTest::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:5594:10 #37 0x7f06c2d679b8 in RUN_ALL_TESTS() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/include/gtest/gtest.h:2334:73 #38 0x7f06c2d679b8 in main /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/tests/xla_internal_test_main.cc:65:10 #39 0x7f064c0b3d8f in __libc_start_call_main csu/../sysdeps/nptl/libc_start_call_main.h:58:16 SUMMARY: AddressSanitizer: heap-buffer-overflow /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:990:36 in absl::lts_20230802::container_internal::CommonFields::capacity() const Shadow bytes around the buggy address: 0x5030001d9500: fd fd fd fa fa fa fd fd fd fa fa fa fd fd fd fa 0x5030001d9580: fa fa fd fd fd fd fa fa fd fd fd fd fa fa fd fd 0x5030001d9600: fd fa fa fa fd fd fd fa fa fa fd fd fd fa fa fa 0x5030001d9680: fd fd fd fd fa fa fd fd fd fa fa fa fd fd fd fa 0x5030001d9700: fa fa fd fd fd fd fa fa fd fd fd fd fa fa fd fd =>0x5030001d9780: fd fa fa fa 00 00 00 fa fa fa 00 00 00 00 fa[fa] 0x5030001d9800: 00 00 00 00 fa fa 00 00 00 00 fa fa fd fd fd fd 0x5030001d9880: fa fa fd fd fd fd fa fa fd fd fd fa fa fa fd fd 0x5030001d9900: fd fd fa fa fd fd fd fd fa fa fd fd fd fd fa fa 0x5030001d9980: fd fd fd fa fa fa fd fd fd fa fa fa fd fd fd fa 0x5030001d9a00: fa fa fd fd fd fa fa fa fd fd fd fd fa fa fd fd Shadow byte legend (one shadow byte represents 8 application bytes): Addressable: 00 Partially addressable: 01 02 03 04 05 06 07 Heap left redzone: fa Freed heap region: fd Stack left redzone: f1 Stack mid redzone: f2 Stack right redzone: f3 Stack after return: f5 Stack use after scope: f8 Global redzone: f9 Global init order: f6 Poisoned by user: f7 Container overflow: fc Array cookie: ac Intra object redzone: bb ASan internal: fe Left alloca redzone: ca Right alloca redzone: cb ==1718600==ABORTING ``` Copybara import of the project: -- 9a75d26eb9aab4226a690658d254a057fc59f22c by alekstheod : Fix access memory asan issue in redzone_allocator_kernel_rocm.cu Merging this change closes #24898 PiperOrigin-RevId: 745563669 --- .../stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc index cff487c5a062bb..f44988212ff180 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc @@ -39,12 +39,11 @@ namespace stream_executor { template static absl::StatusOr*> LoadKernelOrGetPtr( StreamExecutor* executor, absl::string_view kernel_name, void* kernel_ptr) { - using KernelPtrCacheKey = - std::tuple; + using KernelPtrCacheKey = std::tuple; static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit); static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) = - *new 
absl::node_hash_map>(); + *new std::map>; KernelPtrCacheKey kernel_ptr_cache_key{executor, kernel_name, kernel_ptr}; absl::MutexLock lock(&kernel_ptr_cache_mutex); From 7230b6035413bdc2dda1491801953d86723e1b80 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Wed, 9 Apr 2025 06:36:41 -0700 Subject: [PATCH 0441/1324] Move CUDA specific test from gpu_command_buffer_test.cc into cuda_command_buffer_test.cc Backend specific things can't be in gpu_command_buffer_test.cc, since including CUDA specific headers breaks the ROCm build. PiperOrigin-RevId: 745565149 --- .../xla/xla/stream_executor/cuda/BUILD | 26 +++ .../cuda/cuda_command_buffer_test.cc | 163 ++++++++++++++++++ third_party/xla/xla/stream_executor/gpu/BUILD | 3 - .../gpu/gpu_command_buffer_test.cc | 116 +------------ 4 files changed, 190 insertions(+), 118 deletions(-) create mode 100644 third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 83377e18a4bf6a..dc80d2397b3a07 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1377,6 +1377,32 @@ cc_library( ], ) +xla_test( + name = "cuda_command_buffer_test", + srcs = ["cuda_command_buffer_test.cc"], + backends = ["gpu"], + tags = ["cuda-only"], + deps = [ + ":cudnn_plugin", + "//xla/service:platform_util", + "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:dnn", + "//xla/stream_executor:numeric_options", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@cudnn_frontend_archive//:cudnn_frontend", + ], +) + cc_library( name = "cubin_or_ptx_image", hdrs = ["cubin_or_ptx_image.h"], diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc new file mode 100644 index 00000000000000..1479a14d17dc59 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -0,0 +1,163 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include +#include +#include +#include + +#include +#include +#include "absl/strings/ascii.h" +#include "absl/types/span.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend.h" // IWYU pragma: keep - cudnn frontend headers are not hermetic +#include "third_party/cudnn_frontend/include/cudnn_frontend/graph_interface.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend/graph_properties.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_utils.h" +#include "xla/service/platform_util.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/cuda/cuda_dnn.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/numeric_options.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" + +namespace stream_executor::cuda { +namespace { + +using ::testing::Each; +using ::tsl::testing::IsOkAndHolds; + +static Platform* CudaPlatform() { + auto name = absl::AsciiStrToUpper( + xla::PlatformUtil::CanonicalPlatformName("cuda").value()); + return PlatformManager::PlatformWithName(name).value(); +} + +static constexpr auto primary = CommandBuffer::Mode::kPrimary; // NOLINT + +TEST(CudaCommandBufferTest, CuDnnExplicitConstructionAndUpdateWork) { + Platform* platform = CudaPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + dnn::DnnSupport& dnn_support = *executor->AsDnn(); + + if (dnn_support.GetVersion().value_or(dnn::VersionInfo{0, 0, 0}) < + dnn::VersionInfo(9, 7, 0)) { + GTEST_SKIP() << "Requires cuDNN 9.7.0 or later."; + } + + constexpr int kDimSize = 32; + constexpr int kTotalElements = kDimSize * kDimSize; + + stream_executor::gpu::CudnnGraph graph([]() { + cudnn_frontend::graph::Graph graph; + graph.set_compute_data_type(cudnn_frontend::DataType_t::INT32); + std::shared_ptr lhs = + graph.tensor(cudnn_frontend::graph::Tensor_attributes() + .set_dim({1, kDimSize, kDimSize}) + .set_stride({kDimSize * kDimSize, kDimSize, 1}) + .set_data_type(cudnn_frontend::DataType_t::INT8) + .set_uid(1)); + std::shared_ptr rhs = + graph.tensor_like(lhs); + rhs->set_uid(2); + graph.matmul(lhs, rhs, cudnn_frontend::graph::Matmul_attributes()) + ->set_output(true) + .set_data_type(cudnn_frontend::DataType_t::INT32) + .set_uid(3); + return graph; + }()); + TF_ASSERT_OK(graph.Prepare(dnn_support, NumericOptions{})); + TF_ASSERT_OK(graph.Build(dnn_support, /*plan_id=*/std::nullopt)); + EXPECT_THAT(graph.SupportsExplicitCommandBufferConstruction(), + IsOkAndHolds(true)); + + DeviceMemory input = executor->AllocateArray(kTotalElements); + TF_ASSERT_OK(stream->MemZero(&input, input.size())); + DeviceMemory output0 = + executor->AllocateArray(kTotalElements); + DeviceMemoryBase workspace; + std::vector operands; + operands.reserve(4); + operands.push_back(input); // multiplying the input by itself + operands.push_back(input); + operands.push_back(output0); + if (graph.Graph().get_workspace_size() > 0) { + workspace = executor->Allocate(graph.Graph().get_workspace_size()); + operands.push_back(workspace); + } + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr 
cmd_buffer, + executor->CreateCommandBuffer(primary)); + TF_ASSERT_OK( + cmd_buffer + ->DnnGraph(graph, *stream, absl::Span(operands), {}) + .status()); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + std::vector host_buffer(output0.ElementCount()); + + // Initialize and check the output before execution. + TF_ASSERT_OK(stream->Memset32(&output0, 123, output0.size())); + TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output0, output0.size())); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + EXPECT_THAT(host_buffer, Each(123)); + + // Run the computation. + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); + + // Check the output after execution. + TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output0, output0.size())); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + EXPECT_THAT(host_buffer, Each(0)); + + // Swap the output buffer. + DeviceMemory output1 = + executor->AllocateArray(kTotalElements); + operands[2] = output1; + executor->Deallocate(&output0); + + // Initialize and check the output before execution. + TF_ASSERT_OK(stream->Memset32(&output1, 456, output1.size())); + TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output1, output1.size())); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + EXPECT_THAT(host_buffer, Each(456)); + + // Update the command buffer to write into the new output buffer. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK( + cmd_buffer + ->DnnGraph(graph, *stream, absl::Span(operands), {}) + .status()); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + // Run the computation. + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); + + // Check the output after execution. + TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output1, output1.size())); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + EXPECT_THAT(host_buffer, Each(0)); +} + +} // namespace +} // namespace stream_executor::cuda diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 22747215ae3b7a..71a6ece0220379 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -646,7 +646,6 @@ xla_test( "//xla/service:platform_util", "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_memory", - "//xla/stream_executor:dnn", "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", @@ -673,10 +672,8 @@ xla_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@cudnn_frontend_archive//:cudnn_frontend", ] + if_cuda([ "//xla/stream_executor/cuda:cuda_platform", - "//xla/stream_executor/cuda:cudnn_plugin", ]) + if_rocm([ "//xla/stream_executor/rocm:rocm_platform", ]), diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index ffa9d26a31e740..b4f94f963cf9dd 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -13,11 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/gpu/gpu_command_buffer.h" - #include #include -#include #include #include @@ -25,13 +22,10 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/strings/ascii.h" #include "absl/types/span.h" -#include "third_party/cudnn_frontend/include/cudnn_frontend.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/cuda/cuda_dnn.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -51,9 +45,6 @@ limitations under the License. namespace stream_executor::gpu { -using ::testing::Each; -using ::tsl::testing::IsOkAndHolds; - static Platform* GpuPlatform() { auto name = absl::AsciiStrToUpper( xla::PlatformUtil::CanonicalPlatformName("gpu").value()); @@ -146,7 +137,7 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { ASSERT_EQ(dst, expected); } -TEST(CudaCommandBufferTest, TraceSingleKernel) { +TEST(GpuCommandBufferTest, TraceSingleKernel) { Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); @@ -791,111 +782,6 @@ TEST(GpuCommandBufferTest, DISABLED_WhileNestedConditional) { ASSERT_EQ(dst, expected); } -TEST(GpuCommandBufferTest, CuDnnExplicitConstructionAndUpdateWork) { - Platform* platform = GpuPlatform(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); - dnn::DnnSupport& dnn_support = *executor->AsDnn(); - - if (dnn_support.GetVersion().value_or(dnn::VersionInfo{0, 0, 0}) < - dnn::VersionInfo(9, 7, 0)) { - GTEST_SKIP() << "Requires cuDNN 9.7.0 or later."; - } - - constexpr int dim_size = 32; - constexpr int total_elements = dim_size * dim_size; - - CudnnGraph graph([]() { - cudnn_frontend::graph::Graph graph; - graph.set_compute_data_type(cudnn_frontend::DataType_t::INT32); - std::shared_ptr lhs = - graph.tensor(cudnn_frontend::graph::Tensor_attributes() - .set_dim({1, dim_size, dim_size}) - .set_stride({dim_size * dim_size, dim_size, 1}) - .set_data_type(cudnn_frontend::DataType_t::INT8) - .set_uid(1)); - std::shared_ptr rhs = - graph.tensor_like(lhs); - rhs->set_uid(2); - graph.matmul(lhs, rhs, cudnn_frontend::graph::Matmul_attributes()) - ->set_output(true) - .set_data_type(cudnn_frontend::DataType_t::INT32) - .set_uid(3); - return graph; - }()); - TF_ASSERT_OK(graph.Prepare(dnn_support, NumericOptions{})); - TF_ASSERT_OK(graph.Build(dnn_support, /*plan_id=*/std::nullopt)); - EXPECT_THAT(graph.SupportsExplicitCommandBufferConstruction(), - IsOkAndHolds(true)); - - DeviceMemory input = executor->AllocateArray(total_elements); - TF_ASSERT_OK(stream->MemZero(&input, input.size())); - DeviceMemory output0 = - executor->AllocateArray(total_elements); - DeviceMemoryBase workspace; - std::vector operands; - operands.reserve(4); - operands.push_back(input); // multiplying the input by itself - operands.push_back(input); - operands.push_back(output0); - if (graph.Graph().get_workspace_size() > 0) { - workspace = executor->Allocate(graph.Graph().get_workspace_size()); - operands.push_back(workspace); - } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr cmd_buffer, - executor->CreateCommandBuffer(primary)); - TF_ASSERT_OK( - cmd_buffer - ->DnnGraph(graph, *stream, absl::Span(operands), {}) - .status()); - TF_ASSERT_OK(cmd_buffer->Finalize()); - - std::vector host_buffer(output0.ElementCount()); - - // Initialize and check the output before execution. 
- TF_ASSERT_OK(stream->Memset32(&output0, 123, output0.size())); - TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output0, output0.size())); - TF_ASSERT_OK(stream->BlockHostUntilDone()); - EXPECT_THAT(host_buffer, Each(123)); - - // Run the computation. - TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); - - // Check the output after execution. - TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output0, output0.size())); - TF_ASSERT_OK(stream->BlockHostUntilDone()); - EXPECT_THAT(host_buffer, Each(0)); - - // Swap the output buffer. - DeviceMemory output1 = - executor->AllocateArray(total_elements); - operands[2] = output1; - executor->Deallocate(&output0); - - // Initialize and check the output before execution. - TF_ASSERT_OK(stream->Memset32(&output1, 456, output1.size())); - TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output1, output1.size())); - TF_ASSERT_OK(stream->BlockHostUntilDone()); - EXPECT_THAT(host_buffer, Each(456)); - - // Update the command buffer to write into the new output buffer. - TF_ASSERT_OK(cmd_buffer->Update()); - TF_ASSERT_OK( - cmd_buffer - ->DnnGraph(graph, *stream, absl::Span(operands), {}) - .status()); - TF_ASSERT_OK(cmd_buffer->Finalize()); - - // Run the computation. - TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); - - // Check the output after execution. - TF_ASSERT_OK(stream->Memcpy(host_buffer.data(), output1, output1.size())); - TF_ASSERT_OK(stream->BlockHostUntilDone()); - EXPECT_THAT(host_buffer, Each(0)); -} - //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// From c2636ac0c7c44c46246bc715842dce17015245f7 Mon Sep 17 00:00:00 2001 From: Kostiantyn Liepieshov Date: Wed, 9 Apr 2025 06:39:44 -0700 Subject: [PATCH 0442/1324] #sdy integrate propagation barrier PiperOrigin-RevId: 745565953 --- .../xla/xla/service/sharding_remover.cc | 3 +++ .../xla/xla/service/spmd/shardy/constants.h | 8 ++++++ .../import_sdy_custom_calls.cc | 26 ++++++++++++++++++- .../spmd/shardy/sdy_round_trip/export_ops.cc | 25 +++++++++++++++--- .../shardy/stablehlo_round_trip/export_ops.cc | 21 ++++++++++++--- .../spmd/shardy/test/round_trip_pipeline.mlir | 9 +++++++ .../test/sdy_round_trip_export_pipeline.mlir | 8 ++++++ .../test/sdy_round_trip_import_pipeline.mlir | 20 ++++++++++++++ .../test/stablehlo_export_pipeline.mlir | 7 +++++ 9 files changed, 120 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/sharding_remover.cc b/third_party/xla/xla/service/sharding_remover.cc index 042e9f137ef1f0..e2d926a9045aa8 100644 --- a/third_party/xla/xla/service/sharding_remover.cc +++ b/third_party/xla/xla/service/sharding_remover.cc @@ -46,6 +46,7 @@ absl::StatusOr ShardingRemover::Run( "SPMDShardToFullShape", "SPMDFullToShardShape", sdy::kShardingGroupCustomCallTargetName, + sdy::kPropagationBarrierCustomCallTargetName, sdy::kFuncResultShardingTargetName, spmd::kShardBarrierFrom, spmd::kShardBarrierTo}; @@ -80,6 +81,8 @@ absl::StatusOr ShardingRemover::Run( if (instruction->custom_call_target() == "Sharding" || instruction->custom_call_target() == sdy::kFuncResultShardingTargetName || + instruction->custom_call_target() == + sdy::kPropagationBarrierCustomCallTargetName || instruction->custom_call_target() == spmd::kShardBarrierFrom || instruction->custom_call_target() == spmd::kShardBarrierTo) { auto copy = computation->AddInstruction( diff --git a/third_party/xla/xla/service/spmd/shardy/constants.h 
b/third_party/xla/xla/service/spmd/shardy/constants.h index a347d1cc272ee4..7cad4abc01fd26 100644 --- a/third_party/xla/xla/service/spmd/shardy/constants.h +++ b/third_party/xla/xla/service/spmd/shardy/constants.h @@ -96,6 +96,14 @@ inline constexpr llvm::StringRef kShardingGroupCustomCallTargetName = inline constexpr llvm::StringRef kShardingGroupIdAttr = "xla.sdy.sharding_group_id"; +// Shardy propagation barrier custom call target name. +inline constexpr llvm::StringRef kPropagationBarrierCustomCallTargetName = + "xla.sdy.PropagationBarrier"; + +// Propagation barrier allowed direction attribute name. +inline constexpr llvm::StringRef kAllowedDirectionAttr = + "xla.sdy.allowed_direction"; + // Attribute name for storing frontend attributes in XLA. inline constexpr llvm::StringRef kFrontendAttributesAttr = "mhlo.frontend_attributes"; diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc index c224ddc444fd18..48666b6a12ceb5 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc @@ -48,6 +48,8 @@ namespace { using ::mlir::IntegerAttr; using ::mlir::StringRef; +using ::mlir::sdy::PropagationBarrierOp; +using ::mlir::sdy::PropagationDirectionAttr; using ::mlir::sdy::ShardingConstraintOp; using ::mlir::sdy::ShardingGroupOp; using ::mlir::sdy::TensorShardingAttr; @@ -87,6 +89,24 @@ mlir::LogicalResult rewriteShardingCustomCall( return mlir::success(); } +mlir::LogicalResult rewritePropagationBarrierCustomCall( + CustomCallOp op, CustomCallOpAdaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) { + CHECK_EQ(op.getNumOperands(), 1); + CHECK_EQ(op.getNumResults(), 1); + std::optional allowedDirection = + tryGetFrontendAttr(op, kAllowedDirectionAttr); + if (!allowedDirection.has_value()) { + op.emitError() << "expected PropagationBarrier CustomCall Op with a " + "propagation direction."; + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp( + op, adaptor.getInputs().front(), allowedDirection->getValue()); + + return mlir::success(); +} mlir::LogicalResult rewriteShardingGroupCustomCall( CustomCallOp op, CustomCallOpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) { @@ -122,6 +142,9 @@ class SdyCustomCallPattern : public mlir::OpConversionPattern { if (op.getCallTargetName() == kShardingGroupCustomCallTargetName) { return rewriteShardingGroupCustomCall(op, adaptor, rewriter); } + if (op.getCallTargetName() == kPropagationBarrierCustomCallTargetName) { + return rewritePropagationBarrierCustomCall(op, adaptor, rewriter); + } return rewriter.notifyMatchFailure( op, "expected CustomCallOp with xla.sdy target name."); @@ -143,7 +166,8 @@ class ImportSdyCustomCallsPass target.addLegalDialect(); target.addDynamicallyLegalOp([](CustomCallOp op) { return op.getCallTargetName() != kShardingCustomCallTargetName && - op.getCallTargetName() != kShardingGroupCustomCallTargetName; + op.getCallTargetName() != kShardingGroupCustomCallTargetName && + op.getCallTargetName() != kPropagationBarrierCustomCallTargetName; }); mlir::RewritePatternSet patterns(&context); patterns.add(&context); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc index b9bfc2762440e4..1352a20a1d0636 100644 --- 
a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc @@ -62,6 +62,7 @@ using ::mlir::StringRef; using ::mlir::success; using ::mlir::sdy::ConstantOp; +using ::mlir::sdy::PropagationBarrierOp; using ::mlir::sdy::ShardingConstraintOp; using ::mlir::sdy::ShardingGroupOp; using ::mlir::sdy::TensorShardingAttr; @@ -128,6 +129,25 @@ class ShardingGroupPattern : public OpConversionPattern { } }; +class PropagationBarrierPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + private: + LogicalResult matchAndRewrite( + PropagationBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto customCallOp = rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), adaptor.getInput()); + + customCallOp.setCallTargetName(kPropagationBarrierCustomCallTargetName); + setFrontendAttribute(customCallOp, kAllowedDirectionAttr, + op.getAllowedDirectionAttr()); + return success(); + } +}; + class SdyRoundTripExportOpsPass : public PassWrapper> { public: @@ -139,9 +159,8 @@ class SdyRoundTripExportOpsPass target.addIllegalOp(); target.addLegalOp(); mlir::RewritePatternSet patterns(&context); - patterns - .add( - &context); + patterns.add(&context); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); diff --git a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/export_ops.cc index d0a65d998e7fc3..233d714a54339f 100644 --- a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/export_ops.cc +++ b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/export_ops.cc @@ -64,6 +64,7 @@ using ::mlir::sdy::AllSliceOp; using ::mlir::sdy::AllToAllOp; using ::mlir::sdy::CollectivePermuteOp; using ::mlir::sdy::ConstantOp; +using ::mlir::sdy::PropagationBarrierOp; using ::mlir::sdy::ReshardOp; using ::mlir::sdy::ShardingConstraintOp; using ::mlir::sdy::TensorShardingAttr; @@ -98,6 +99,20 @@ class AllReducePattern : public OpConversionPattern { } }; +class PropagationBarrierPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + private: + LogicalResult matchAndRewrite( + PropagationBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOp(op, adaptor.getInput()); + return success(); + } +}; + void rewriteCollectiveOp(mlir::Operation* op, mlir::Value input, TensorShardingAttr sharding, ConversionPatternRewriter& rewriter) { @@ -147,7 +162,7 @@ class ExportOpsPass // Hence, we add ShardingConstraintOp as an illegal op. target.addIllegalOp(); + ShardingConstraintOp, PropagationBarrierOp>(); target.addLegalOp(); mlir::RewritePatternSet patterns(&context); // After converting `sdy.constant` into `stablehlo.constant`, the constants @@ -155,8 +170,8 @@ class ExportOpsPass // greedy pattern rewriters. ExportHloShardingsPass does a simple walk, // which keeps the constants as is. 
patterns.add, CollectivePattern, - CollectivePattern, + PropagationBarrierPattern, CollectivePattern, + CollectivePattern, CollectivePattern, CollectivePattern>(&context); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir index 522976c5bbd2c5..113b2b85faef6c 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir @@ -230,6 +230,15 @@ func.func @main(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>) { // ----- +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>) { + // CHECK: sdy.propagation_barrier %arg0 allowed_direction=BACKWARD : tensor<8x16xf32> + %r = sdy.propagation_barrier %arg0 allowed_direction=BACKWARD : tensor<8x16xf32> + return %r : tensor<8x16xf32> +} + +// ----- + // Test call with backend config and multiple results. This is what JAX would // emit in the frontend, and then we'd convert it to a NamedComputationOp when // coming back. diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir index edb7be4c0a5d66..1e679b6d9644e1 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir @@ -103,6 +103,14 @@ func.func @export_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { return %arg0 : tensor<8x8xf32> } +// CHECK-LABEL: func @export_propagation_barrier +// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { +func.func @export_propagation_barrier(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK: %0 = stablehlo.custom_call @local_xla.sdy.PropagationBarrier(%arg0) {mhlo.frontend_attributes = {xla.sdy.allowed_direction = "2 : i32"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> + %0 = sdy.propagation_barrier %arg0 allowed_direction=BACKWARD : tensor<8x8xf32> + return %0 : tensor<8x8xf32> +} + // CHECK-LABEL: func @constant func.func @constant() -> tensor { // CHECK-NEXT: %[[CONST:.*]] = stablehlo.constant dense<0> diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 22262153e8ea78..89c57ff153ce3a 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -244,6 +244,26 @@ func.func @import_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { // ----- +// CHECK-LABEL: func @import_propagation_barrier_backward +// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { +func.func @import_propagation_barrier_backward(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK: %r = sdy.propagation_barrier %arg0 allowed_direction=BACKWARD : tensor<8x8xf32> + %r = stablehlo.custom_call @local_xla.sdy.PropagationBarrier(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.allowed_direction = "2 : i32"}} : (tensor<8x8xf32>) -> (tensor<8x8xf32>) + return %r : tensor<8x8xf32> +} + +// ----- + +// CHECK-LABEL: func @import_propagation_barrier_forward +// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { +func.func
@import_propagation_barrier_forward(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK: %r = sdy.propagation_barrier %arg0 allowed_direction=FORWARD : tensor<8x8xf32> + %r = stablehlo.custom_call @local_xla.sdy.PropagationBarrier(%arg0) {mhlo.frontend_attributes = {xla.sdy.allowed_direction = "1 : i32"}} : (tensor<8x8xf32>) -> (tensor<8x8xf32>) + return %r : tensor<8x8xf32> +} + +// ----- + func.func @callback_no_result(%arg0: tensor) { // CHECK: %[[C:.*]] = sdy.constant // CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) { diff --git a/third_party/xla/xla/service/spmd/shardy/test/stablehlo_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/stablehlo_export_pipeline.mlir index 7e720674ec11e0..1dd3d185058485 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/stablehlo_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/stablehlo_export_pipeline.mlir @@ -425,6 +425,13 @@ func.func @while_with_no_sharding_inside_manual_comp( return %0 : tensor<32x2xi32> } +// CHECK-LABEL: func @propagation_barrier +func.func @propagation_barrier(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>) { + // CHECK-NEXT: return %arg0 : tensor<8x16xf32> + %r = sdy.propagation_barrier %arg0 allowed_direction=BACKWARD : tensor<8x16xf32> + return %r : tensor<8x16xf32> +} + // CHECK-LABEL: func private @foo // CHECK-SAME: %arg0: tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} From 62fb2e161e2e9e797386930d4cefdd78507de0c7 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 9 Apr 2025 08:06:10 -0700 Subject: [PATCH 0443/1324] PR #24900: [ROCM] fix invalid memory access in rocm executor Imported from GitHub PR https://github.com/openxla/xla/pull/24900 Fix asan memory access violation: ``` exec ${PAGER:-/usr/bin/less} "$0" || exit 1 Executing tests from //xla/service:elemental_ir_emitter_test_gpu_amd_any ----------------------------------------------------------------------------- Running test /home/atheodor/projects/tmp/xla_asan/execroot/xla/bazel-out/k8-opt/bin/xla/service/elemental_ir_emitter_test_gpu_amd_any.runfiles/xla/xla/service/elemental_ir_emitter_test_gpu_amd_any --gtest_shuffle --gtest_fail_if_no_test_linked on GPU 3 Note: Randomizing tests' orders with a seed of 19906 . [==========] Running 118 tests from 13 test suites. [----------] Global test environment set-up.
[----------] 10 tests from ElementalIrEmitterExecutionTypedTest/7, where TypeParam = ml_dtypes::float8_internal::float8_e5m2 [ RUN ] ElementalIrEmitterExecutionTypedTest/7.ConvertFloatsToFloat ================================================================= ==2457579==ERROR: AddressSanitizer: use-after-poison on address 0x506000843a08 at pc 0x7f401151be6a bp 0x7ffd1e3c3410 sp 0x7ffd1e3c3408 READ of size 8 at 0x506000843a08 thread T0 #0 0x7f401151be69 in stream_executor::gpu::RocmExecutor::UnloadGpuBinary(stream_executor::ModuleHandle) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/stream_executor/rocm/rocm_executor.cc:596:23 #1 0x7f401151b036 in stream_executor::gpu::RocmExecutor::UnloadModule(stream_executor::ModuleHandle) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/stream_executor/rocm/rocm_executor.cc:496:10 #2 0x7f405dee713b in stream_executor::ScopedModuleHandle::~ScopedModuleHandle() /home/atheodor/projects/tmp/xla_asan/execroot/xla/./xla/stream_executor/scoped_module_handle.h:48:7 #3 0x7f405dee713b in std::pair::~pair() /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/stl_iterator.h:2488:12 #4 0x7f405dee713b in void __gnu_cxx::new_allocator>::destroy>(std::pair*) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/ext/new_allocator.h:168:10 #5 0x7f405dee713b in void std::allocator_traits>>::destroy>(std::allocator>&, std::pair*) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/alloc_traits.h:535:8 #6 0x7f405dee713b in void absl::lts_20230802::container_internal::map_slot_policy::destroy>>(std::allocator>*, absl::lts_20230802::container_internal::map_slot_type*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/container_memory.h:419:7 #7 0x7f405dee713b in void absl::lts_20230802::container_internal::FlatHashMapPolicy::destroy>>(std::allocator>*, absl::lts_20230802::container_internal::map_slot_type*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/flat_hash_map.h:578:5 #8 0x7f405dee713b in void absl::lts_20230802::container_internal::common_policy_traits, void>::destroy>>(std::allocator>*, absl::lts_20230802::container_internal::map_slot_type*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/common_policy_traits.h:50:5 #9 0x7f405dee713b in absl::lts_20230802::container_internal::raw_hash_set, absl::lts_20230802::container_internal::HashEq::Hash, absl::lts_20230802::container_internal::HashEq::Eq, std::allocator>>::destroy_slots() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:1946:9 #10 0x7f405dee713b in absl::lts_20230802::container_internal::raw_hash_set, absl::lts_20230802::container_internal::HashEq::Hash, absl::lts_20230802::container_internal::HashEq::Eq, std::allocator>>::~raw_hash_set() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:1885:5 #11 0x7f405dee8580 in xla::gpu::GpuExecutable::~GpuExecutable() /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/gpu/gpu_executable.cc:155:1 #12 0x7f405dee8d4d in xla::gpu::GpuExecutable::~GpuExecutable() /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/gpu/gpu_executable.cc:151:33 #13 0x7f407b818b3f in std::default_delete::operator()(xla::Executable*) const /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/unique_ptr.h:85:2 #14 0x7f407b818b3f in 
std::unique_ptr>::~unique_ptr() /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/unique_ptr.h:361:4 #15 0x7f407b818b3f in xla::(anonymous namespace)::HloRunnerExecutable::~HloRunnerExecutable() /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/hlo_runner.cc:69:7 #16 0x7f407b818b3f in xla::(anonymous namespace)::HloRunnerExecutable::~HloRunnerExecutable() /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/hlo_runner.cc:69:7 #17 0x7f407b7e6503 in std::default_delete::operator()(xla::OpaqueExecutable*) const /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/unique_ptr.h:85:2 #18 0x7f407b7e6503 in std::unique_ptr>::~unique_ptr() /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/unique_ptr.h:361:4 #19 0x7f407b7e6503 in xla::HloRunner::ExecuteWithMovedDeviceBuffersAndBufferAssignment(std::unique_ptr>, xla::BufferAssignmentProto const*, std::vector>, bool, xla::ExecutionProfile*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/hlo_runner.cc:400:1 #20 0x7f407b7e57c3 in xla::HloRunner::Execute(std::unique_ptr>, absl::lts_20230802::Span, bool, xla::ExecutionProfile*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/hlo_runner.cc:221:3 #21 0x55b8a3cb4622 in xla::HloRunnerInterface::Execute(std::unique_ptr>, absl::lts_20230802::Span, bool) /home/atheodor/projects/tmp/xla_asan/execroot/xla/./xla/service/hlo_runner_interface.h:244:12 #22 0x55b8a3cb4622 in xla::HloRunnerAgnosticReferenceMixin::RunAndCompareInternal(std::unique_ptr>, absl::lts_20230802::Span, std::optional const&, bool, std::function const&, std::function const&) /home/atheodor/projects/tmp/xla_asan/execroot/xla/./xla/tests/hlo_runner_agnostic_reference_mixin.h:238:5 #23 0x55b8a3cbf766 in xla::HloRunnerAgnosticReferenceMixin::RunAndCompare(std::unique_ptr>, absl::lts_20230802::Span, std::optional const&, std::function const&, std::function const&) /home/atheodor/projects/tmp/xla_asan/execroot/xla/./xla/tests/hlo_runner_agnostic_reference_mixin.h:94:9 #24 0x55b8a3cbf235 in xla::HloRunnerAgnosticReferenceMixin::RunAndCompare(std::unique_ptr>, std::optional const&, std::function const&, std::function const&, std::optional) /home/atheodor/projects/tmp/xla_asan/execroot/xla/./xla/tests/hlo_runner_agnostic_reference_mixin.h:140:12 #25 0x55b8a3cceda8 in xla::(anonymous namespace)::ElementalIrEmitterExecutionTest::RunTypeConversionTest(std::basic_string_view>) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/elemental_ir_emitter_test.cc:76:5 #26 0x55b8a3cd8cf3 in xla::(anonymous namespace)::ElementalIrEmitterExecutionTypedTest_ConvertFloatsToFloat_Test::TestBody() /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/elemental_ir_emitter_test.cc:472:36 #27 0x7f407b2f09dd in void testing::internal::HandleSehExceptionsInMethodIfSupported(testing::Test*, void (testing::Test::*)(), char const*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #28 0x7f407b2f09dd in void testing::internal::HandleExceptionsInMethodIfSupported(testing::Test*, void (testing::Test::*)(), char const*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #29 0x7f407b2f0708 in testing::Test::Run() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2739:5 #30 0x7f407b2f371b in testing::TestInfo::Run() 
/home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2885:11 #31 0x7f407b2f65ab in testing::TestSuite::Run() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:3063:30 #32 0x7f407b322eba in testing::internal::UnitTestImpl::RunAllTests() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:6054:44 #33 0x7f407b32179d in bool testing::internal::HandleSehExceptionsInMethodIfSupported(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #34 0x7f407b32179d in bool testing::internal::HandleExceptionsInMethodIfSupported(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #35 0x7f407b321203 in testing::UnitTest::Run() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:5594:10 #36 0x7f407b3f59b8 in RUN_ALL_TESTS() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/include/gtest/gtest.h:2334:73 #37 0x7f407b3f59b8 in main /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/tests/xla_internal_test_main.cc:65:10 #38 0x7f4004766d8f in __libc_start_call_main csu/../sysdeps/nptl/libc_start_call_main.h:58:16 #39 0x7f4004766e3f in __libc_start_main csu/../csu/libc-start.c:392:3 #40 0x55b8a3b9be44 in _start (/home/atheodor/projects/tmp/xla_asan/execroot/xla/bazel-out/k8-opt/bin/xla/service/elemental_ir_emitter_test_gpu_amd_any+0x10ce44) (BuildId: 1c37d17e488373aad7bf33204cb4234e) 0x506000843a08 is located 40 bytes inside of 56-byte region [0x5060008439e0,0x506000843a18) allocated by thread T0 here: #0 0x55b8a3c3607f in malloc (/home/atheodor/projects/tmp/xla_asan/execroot/xla/bazel-out/k8-opt/bin/xla/service/elemental_ir_emitter_test_gpu_amd_any+0x1a707f) (BuildId: 1c37d17e488373aad7bf33204cb4234e) #1 0x7f4004a4a98b in operator new(unsigned long) (/lib/x86_64-linux-gnu/libstdc++.so.6+0xae98b) (BuildId: e37fe1a879783838de78cbc8c80621fa685d58a2) #2 0x7f40115449aa in absl::lts_20230802::container_internal::raw_hash_set>, absl::lts_20230802::hash_internal::Hash, std::equal_to, std::allocator>>>::initialize_slots() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:2505:5 #3 0x7f40115449aa in absl::lts_20230802::container_internal::raw_hash_set>, absl::lts_20230802::hash_internal::Hash, std::equal_to, std::allocator>>>::resize(unsigned long) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:2515:5 #4 0x7f40115443fa in absl::lts_20230802::container_internal::raw_hash_set>, absl::lts_20230802::hash_internal::Hash, std::equal_to, std::allocator>>>::prepare_insert(unsigned long) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:2672:7 #5 0x7f40115442df in std::pair absl::lts_20230802::container_internal::raw_hash_set>, absl::lts_20230802::hash_internal::Hash, std::equal_to, std::allocator>>>::find_or_prepare_insert(stream_executor::ModuleHandle const&) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:2659:13 #6 
0x7f4011524701 in std::pair>, absl::lts_20230802::hash_internal::Hash, std::equal_to, std::allocator>>>::iterator, bool> absl::lts_20230802::container_internal::raw_hash_map>, absl::lts_20230802::hash_internal::Hash, std::equal_to, std::allocator>>>::try_emplace_impl(stream_executor::ModuleHandle const&) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_map.h:202:22 #7 0x7f4011524701 in std::pair>, absl::lts_20230802::hash_internal::Hash, std::equal_to, std::allocator>>>::iterator, bool> absl::lts_20230802::container_internal::raw_hash_map>, absl::lts_20230802::hash_internal::Hash, std::equal_to, std::allocator>>>::try_emplace(stream_executor::ModuleHandle const&) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_map.h:139:12 #8 0x7f4011524701 in decltype(absl::lts_20230802::container_internal::FlatHashMapPolicy>::value(std::pair>* std::addressof>>(std::pair>&)(decltype(__declval>>(0)) std::declval>&>()()))) absl::lts_20230802::container_internal::raw_hash_map>, absl::lts_20230802::hash_internal::Hash, std::equal_to, std::allocator>>>::operator[]>>(stream_executor::ModuleHandle const&) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_map.h:184:28 #9 0x7f4011524701 in stream_executor::gpu::RocmExecutor::LoadModuleFromHsaco(char const*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/stream_executor/rocm/rocm_executor.cc:717:39 #10 0x7f4011524387 in stream_executor::gpu::RocmExecutor::LoadModule(stream_executor::MultiModuleLoaderSpec const&) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/stream_executor/rocm/rocm_executor.cc:705:12 #11 0x7f405deeae34 in xla::gpu::GpuExecutable::ResolveConstantGlobals(stream_executor::Stream*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/gpu/gpu_executable.cc:499:5 #12 0x7f405def050a in xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl(xla::ServiceExecutableRunOptions const*, std::variant, absl::lts_20230802::Span>) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/gpu/gpu_executable.cc:703:5 #13 0x7f405deefc6f in xla::gpu::GpuExecutable::ExecuteAsyncOnStream(xla::ServiceExecutableRunOptions const*, std::vector>) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/gpu/gpu_executable.cc:661:10 #14 0x7f401607a78e in xla::Executable::ExecuteAsyncOnStreamWrapper(xla::ServiceExecutableRunOptions const*, std::vector>) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/executable.cc:229:7 #15 0x7f4016079fd3 in xla::Executable::ExecuteOnStreamWrapper(xla::ServiceExecutableRunOptions const*, std::vector>) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/executable.cc:153:7 #16 0x7f407b7ea78b in xla::HloRunner::ExecuteWithExecutionInputs(xla::Executable*, std::vector>, xla::ExecutionProfile*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/hlo_runner.cc:448:3 #17 0x7f407b7ecde2 in xla::HloRunner::ExecuteWithMovedDeviceBuffers(xla::Executable*, std::vector>, xla::ExecutionProfile*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/hlo_runner.cc:415:3 #18 0x7f407b7e642a in xla::HloRunner::ExecuteWithMovedDeviceBuffersAndBufferAssignment(std::unique_ptr>, xla::BufferAssignmentProto const*, std::vector>, bool, xla::ExecutionProfile*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/hlo_runner.cc:398:10 #19 0x7f407b7e57c3 in xla::HloRunner::Execute(std::unique_ptr>, 
absl::lts_20230802::Span, bool, xla::ExecutionProfile*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/hlo_runner.cc:221:3 #20 0x55b8a3cb4622 in xla::HloRunnerInterface::Execute(std::unique_ptr>, absl::lts_20230802::Span, bool) /home/atheodor/projects/tmp/xla_asan/execroot/xla/./xla/service/hlo_runner_interface.h:244:12 #21 0x55b8a3cb4622 in xla::HloRunnerAgnosticReferenceMixin::RunAndCompareInternal(std::unique_ptr>, absl::lts_20230802::Span, std::optional const&, bool, std::function const&, std::function const&) /home/atheodor/projects/tmp/xla_asan/execroot/xla/./xla/tests/hlo_runner_agnostic_reference_mixin.h:238:5 #22 0x55b8a3cbf766 in xla::HloRunnerAgnosticReferenceMixin::RunAndCompare(std::unique_ptr>, absl::lts_20230802::Span, std::optional const&, std::function const&, std::function const&) /home/atheodor/projects/tmp/xla_asan/execroot/xla/./xla/tests/hlo_runner_agnostic_reference_mixin.h:94:9 #23 0x55b8a3cbf235 in xla::HloRunnerAgnosticReferenceMixin::RunAndCompare(std::unique_ptr>, std::optional const&, std::function const&, std::function const&, std::optional) /home/atheodor/projects/tmp/xla_asan/execroot/xla/./xla/tests/hlo_runner_agnostic_reference_mixin.h:140:12 #24 0x55b8a3cceda8 in xla::(anonymous namespace)::ElementalIrEmitterExecutionTest::RunTypeConversionTest(std::basic_string_view>) /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/elemental_ir_emitter_test.cc:76:5 #25 0x55b8a3cd8cf3 in xla::(anonymous namespace)::ElementalIrEmitterExecutionTypedTest_ConvertFloatsToFloat_Test::TestBody() /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/service/elemental_ir_emitter_test.cc:472:36 #26 0x7f407b2f09dd in void testing::internal::HandleSehExceptionsInMethodIfSupported(testing::Test*, void (testing::Test::*)(), char const*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #27 0x7f407b2f09dd in void testing::internal::HandleExceptionsInMethodIfSupported(testing::Test*, void (testing::Test::*)(), char const*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #28 0x7f407b2f0708 in testing::Test::Run() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2739:5 #29 0x7f407b2f371b in testing::TestInfo::Run() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2885:11 #30 0x7f407b2f65ab in testing::TestSuite::Run() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:3063:30 #31 0x7f407b322eba in testing::internal::UnitTestImpl::RunAllTests() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:6054:44 #32 0x7f407b32179d in bool testing::internal::HandleSehExceptionsInMethodIfSupported(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #33 0x7f407b32179d in bool testing::internal::HandleExceptionsInMethodIfSupported(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #34 0x7f407b321203 in testing::UnitTest::Run() 
/home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:5594:10 #35 0x7f407b3f59b8 in RUN_ALL_TESTS() /home/atheodor/projects/tmp/xla_asan/execroot/xla/external/com_google_googletest/googletest/include/gtest/gtest.h:2334:73 #36 0x7f407b3f59b8 in main /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/tests/xla_internal_test_main.cc:65:10 #37 0x7f4004766d8f in __libc_start_call_main csu/../sysdeps/nptl/libc_start_call_main.h:58:16 SUMMARY: AddressSanitizer: use-after-poison /home/atheodor/projects/tmp/xla_asan/execroot/xla/xla/stream_executor/rocm/rocm_executor.cc:596:23 in stream_executor::gpu::RocmExecutor::UnloadGpuBinary(stream_executor::ModuleHandle) Shadow bytes around the buggy address: 0x506000843780: fa fa fa fa fd fd fd fd fd fd fd fa fa fa fa fa 0x506000843800: fd fd fd fd fd fd fd fa fa fa fa fa fd fd fd fd 0x506000843880: fd fd fd fa fa fa fa fa fd fd fd fd fd fd fd fd 0x506000843900: fa fa fa fa 00 00 00 00 00 00 00 fa fa fa fa fa 0x506000843980: fd fd fd fd fd fd fd fa fa fa fa fa 00 00 00 00 =>0x506000843a00: f7[f7]f7 fa fa fa fa fa 00 00 00 00 00 00 00 00 0x506000843a80: fa fa fa fa 00 00 00 00 00 00 00 fa fa fa fa fa 0x506000843b00: 00 00 00 00 00 00 00 fa fa fa fa fa 00 00 00 00 0x506000843b80: 00 00 00 fa fa fa fa fa 00 00 00 00 00 00 00 fa 0x506000843c00: fa fa fa fa 00 00 00 00 00 00 00 fa fa fa fa fa 0x506000843c80: 00 00 00 00 00 00 00 fa fa fa fa fa fd fd fd fd Shadow byte legend (one shadow byte represents 8 application bytes): Addressable: 00 Partially addressable: 01 02 03 04 05 06 07 Heap left redzone: fa Freed heap region: fd Stack left redzone: f1 Stack mid redzone: f2 Stack right redzone: f3 Stack after return: f5 Stack use after scope: f8 Global redzone: f9 Global init order: f6 Poisoned by user: f7 Container overflow: fc Array cookie: ac Intra object redzone: bb ASan internal: fe Left alloca redzone: ca Right alloca redzone: cb ==2457579==ABORTING ``` The root cause: `auto& module` held a reference into the loaded-modules map entry, and that entry goes away once the refcount drops to zero, so the module handle is now copied by value before use. Copybara import of the project: -- 8f74d4c822d951b5a213500ea9396ed7b160871d by alekstheod : Fix asan report memory access violation in rocm_executor Merging this change closes #24900 PiperOrigin-RevId: 745592235 --- third_party/xla/xla/stream_executor/rocm/rocm_executor.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index d4dde90953fa03..150eac7fc97d95 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -584,7 +584,7 @@ bool RocmExecutor::UnloadGpuBinary(ModuleHandle module_handle) { VLOG(3) << "No loaded HSACO module for " << module_handle; return false; } - auto& module = module_it->second.first; + auto module = module_it->second.first; auto& refcount = module_it->second.second; VLOG(3) << "Found HSACO module " << module << " with refcount " << refcount; if (--refcount == 0) { From 43312a90f397a575e70b463c883bbf8e6300cb3d Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Wed, 9 Apr 2025 08:23:28 -0700 Subject: [PATCH 0444/1324] #sdy check if `xla_dump_hlo_pass_re` is empty, otherwise `RE2::PartialMatch` will return true.
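A note on the pitfall this guards against (an illustrative sketch, not part of the patch; it assumes only the public `RE2::PartialMatch` API): an empty pattern partially matches every input, so an unset `xla_dump_hlo_pass_re` would otherwise look like a match for every pass name and enable dumping unconditionally.

```
// Minimal demonstration of the empty-pattern pitfall with RE2.
#include <cassert>

#include "re2/re2.h"

int main() {
  // An empty regular expression matches (an empty substring of) any text,
  // so PartialMatch reports a hit for every pass name.
  assert(RE2::PartialMatch("shardy-xla-pass", ""));
  // Non-empty patterns behave as expected.
  assert(RE2::PartialMatch("shardy-xla-pass", "shardy"));
  assert(!RE2::PartialMatch("shardy-xla-pass", "triton"));
  return 0;
}
```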
PiperOrigin-RevId: 745597694 --- .../service/spmd/shardy/shardy_xla_pass.cc | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc index 11924bb74f2f13..80f3b145610807 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc @@ -294,6 +294,19 @@ void removeFrontendAttributes(HloModule* hloModule, hloModule->set_frontend_attributes(feAttrs); } +std::string getShardyDirIfShouldDump(const DebugOptions& debugOptions, + absl::string_view passName) { + std::string shardyDir = debugOptions.xla_dump_to(); + if (shardyDir.empty()) { + return ""; + } + if (debugOptions.xla_dump_hlo_pass_re().empty() || + !RE2::PartialMatch(passName, debugOptions.xla_dump_hlo_pass_re())) { + return ""; + } + return shardyDir; +} + } // namespace absl::StatusOr ShardyXLA::Run( @@ -309,12 +322,7 @@ absl::StatusOr ShardyXLA::Run( xla::ConvertHloToStablehlo(*mlirContext.get(), hloModule)); const DebugOptions& debugOptions = hloModule->config().debug_options(); - std::string shardyDir = debugOptions.xla_dump_to(); - - if (!shardyDir.empty() && - !RE2::PartialMatch(name(), debugOptions.xla_dump_hlo_pass_re())) { - shardyDir.clear(); - } + std::string shardyDir = getShardyDirIfShouldDump(debugOptions, name()); if (shardyDir == "sponge") { shardyDir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); From 72a9a9b0f823c9cc29f26899e6f8f92602d5fdb6 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Wed, 9 Apr 2025 08:47:54 -0700 Subject: [PATCH 0445/1324] [XLA:WHILE_LOOP_FUSIBLE] Ensure while loop has at least one iteration and that the while loop induction variable for transformation is not used in the while condition computation. PiperOrigin-RevId: 745605952 --- .../xla/service/while_loop_fusible_sinking.cc | 3 +- .../while_loop_fusible_sinking_test.cc | 155 +++++++++++------- 2 files changed, 97 insertions(+), 61 deletions(-) diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc index 09b760863b76ef..ed7fb233b03fe5 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc @@ -188,7 +188,8 @@ absl::StatusOr TryRewritingBroadcastAsAllocateBuffer( HloInstruction* while_instr) { std::optional induction_var_tuple_index = GetLoopInductionVarTupleIdx(while_instr); - if (!induction_var_tuple_index.has_value()) { + if (!induction_var_tuple_index.has_value() || + ComputeWhileLoopTripCount(while_instr).value_or(0) == 0) { return false; } HloComputation* while_body = while_instr->while_body(); diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc b/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc index c161bc05d20417..885e2738daa57a 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -290,70 +289,107 @@ TEST_F(WhileLoopFusibleSinkingTest, op::While(op::Tuple(_, op::CustomCall(), _, _))); } -TEST_F(WhileLoopFusibleSinkingTest, - TestPlumbSingleBroadcastNoneZeroLoopIterationVar) { +TEST_F(WhileLoopFusibleSinkingTest, TestPlumbMultipleBroadcast) { const std::string hlo_string_before = R"( - HloModule cluster_6512412223095190558_f15n_0__.258 - - %wide._functionalize_body_1_const_0__.164.clone.clone.clone.clone (wide.arg_tuple.1: (s32[], f32[2])) -> (s32[], f32[2]) { - %wide.arg_tuple.1 = (s32[], f32[2]{0}) parameter(0) - %get-tuple-element.383 = s32[] get-tuple-element((s32[], f32[2]{0}) %wide.arg_tuple.1), index=0 - %constant.50..sunk.4 = s32[] constant(-1) - %add.48 = s32[] add(s32[] %get-tuple-element.383, s32[] %constant.50..sunk.4) - %get-tuple-element.384 = f32[2]{0} get-tuple-element((s32[], f32[2]{0}) %wide.arg_tuple.1), index=1 - %constant.11..sunk.4 = f32[] constant(1) - %broadcast.19 = f32[2]{0} broadcast(f32[] %constant.11..sunk.4), dimensions={} - %add.49 = f32[2]{0} add(f32[2]{0} %get-tuple-element.384, f32[2]{0} %broadcast.19) - ROOT %tuple.55 = (s32[], f32[2]{0}) tuple(s32[] %add.48, f32[2]{0} %add.49) - } - - %wide.cond_wrapper.236.clone.clone.clone.clone (wide.inputs.1: (s32[], f32[2])) -> pred[] { - %wide.inputs.1 = (s32[], f32[2]{0}) parameter(0) - %get-tuple-element.382 = s32[] get-tuple-element((s32[], f32[2]{0}) %wide.inputs.1), index=0 - %constant.66 = s32[] constant(1) - ROOT %compare.10 = pred[] compare(s32[] %get-tuple-element.382, s32[] %constant.66), direction=GE - } - - %_functionalize_body_0_const_0__.40.clone.clone.clone.clone.clone.clone.clone (arg_tuple.9: (s32[])) -> (s32[]) { - %arg_tuple.9 = (s32[]) parameter(0) - %get-tuple-element.409 = s32[] get-tuple-element((s32[]) %arg_tuple.9), index=0 - %constant.71 = s32[] constant(1) - %add.57 = s32[] add(s32[] %get-tuple-element.409, s32[] %constant.71) - ROOT %tuple.61 = (s32[]) tuple(s32[] %add.57) - } - - %cond_wrapper.120.clone.clone.clone.clone.clone.clone (inputs.7: (s32[])) -> pred[] { - %inputs.7 = (s32[]) parameter(0) - %get-tuple-element.408 = s32[] get-tuple-element((s32[]) %inputs.7), index=0 - %constant.70 = s32[] constant(10) - ROOT %compare.12 = pred[] compare(s32[] %get-tuple-element.408, s32[] %constant.70), direction=LT - } - - ENTRY %cluster_6512412223095190558_f15n_0__.258{ - %arg_tuple.1 = () parameter(0) - %constant.24 = s32[] constant(0) - %tuple.60 = (s32[]) tuple(s32[] %constant.24) - %while.10 = (s32[]) while((s32[]) %tuple.60), condition=%cond_wrapper.120.clone.clone.clone.clone.clone.clone, body=%_functionalize_body_0_const_0__.40.clone.clone.clone.clone.clone.clone.clone - %get-tuple-element.380 = s32[] get-tuple-element((s32[]) %while.10), index=0 - %constant.9 = f32[] constant(0) - %broadcast.10 = f32[2]{0} broadcast(f32[] %constant.9), dimensions={} - %tuple.54 = (s32[], f32[2]{0}) tuple(s32[] %get-tuple-element.380, f32[2]{0} %broadcast.10) - ROOT %while.8 = (s32[], f32[2]{0}) while((s32[], f32[2]{0}) %tuple.54), condition=%wide.cond_wrapper.236.clone.clone.clone.clone, body=%wide._functionalize_body_1_const_0__.164.clone.clone.clone.clone - } + HloModule test + + loop.body { + loop_var.1 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.1 = s32[]{:T(128)} 
get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.4 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=2 + get-tuple-element.3 = s32[4,3,5]{2,1,0} get-tuple-element(loop_var.1), index=3 + bitcast.12855 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} bitcast(get-tuple-element.3) + add.40974 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.2, bitcast.12855) + add.1 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.4, add.40974) + constant.1 = s32[]{:T(128)} constant(1) + idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(idx, add.40974, add.1, get-tuple-element.3) + } + + loop.condition { + loop_var.2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.3 = s32[]{:T(128)} get-tuple-element(loop_var.2), index=0 + constant.2 = s32[]{:T(128)} constant(4) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + + ENTRY %main { + param.1 = s32[4,3,5]{2,1,0} iota(), iota_dimension=0 + zero = s32[]{:T(128)} constant(0) + zeros32 = s32[]{:T(128)} constant(0) + broadcast = s32[1,1,1,4,3,5]{5,4,3,2,1,0} broadcast(zeros32) + input = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(zero, broadcast, broadcast, param.1) + ROOT while = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) while(input), condition=loop.condition, body=loop.body + } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_before, ParseAndReturnVerifiedModule(hlo_string_before)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopFusibleSinking{}.Run(module_before.get())); EXPECT_TRUE(changed); - EXPECT_THAT(FindInstruction(module_before.get(), "while.8"), - op::While(op::Tuple(_, op::CustomCall(), _))); + EXPECT_THAT( + FindInstruction(module_before.get(), "while"), + op::While(op::Tuple(_, op::CustomCall(), op::CustomCall(), _, _))); } -TEST_F(WhileLoopFusibleSinkingTest, TestPlumbMultipleBroadcast) { +TEST_F(WhileLoopFusibleSinkingTest, TestNoPlumbWithBadCondition) { const std::string hlo_string_before = R"( HloModule test + tmp { + x = s32[] parameter(0) + y = s32[] parameter(1) + ROOT add = s32[] add(x, y) + } + + loop.body { + loop_var.1 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.4 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=2 + get-tuple-element.3 = s32[4,3,5]{2,1,0} get-tuple-element(loop_var.1), index=3 + bitcast.12855 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} bitcast(get-tuple-element.3) + add.40974 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.2, bitcast.12855) + add.1 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.4, add.40974) + constant.1 = s32[]{:T(128)} constant(1) + idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(idx, add.40974, add.1, get-tuple-element.3) + } + + loop.condition { + loop_var.2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, 
s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.3 = s32[]{:T(128)} get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.2), index=1 + z = s32[]{:T(128)} constant(0) + r = s32[]{:T(128)} reduce(get-tuple-element.4, z), dimensions={0,1,2,3,4,5}, to_apply=tmp + ROOT less-than = pred[] compare(get-tuple-element.3, r), direction=LT + } + + ENTRY %main { + param.1 = s32[4,3,5]{2,1,0} parameter(0) + zero = s32[]{:T(128)} constant(0) + zeros32 = s32[]{:T(128)} constant(0) + broadcast = s32[1,1,1,4,3,5]{5,4,3,2,1,0} broadcast(zeros32) + input = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(zero, broadcast, broadcast, param.1) + ROOT while = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) while(input), condition=loop.condition, body=loop.body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_before, + ParseAndReturnVerifiedModule(hlo_string_before)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopFusibleSinking{}.Run(module_before.get())); + EXPECT_FALSE(changed); +} +TEST_F(WhileLoopFusibleSinkingTest, TestNoPlumbWithUnknownTripCount) { + const std::string hlo_string_before = R"( + HloModule test + tmp { + x = s32[] parameter(0) + y = s32[] parameter(1) + ROOT add = s32[] add(x, y) + } + loop.body { loop_var.1 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0 @@ -371,12 +407,14 @@ TEST_F(WhileLoopFusibleSinkingTest, TestPlumbMultipleBroadcast) { loop.condition { loop_var.2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) get-tuple-element.3 = s32[]{:T(128)} get-tuple-element(loop_var.2), index=0 - constant.2 = s32[]{:T(128)} constant(4) - ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + m = s32[] constant(0) + v = s32[] constant(10000) + rng = s32[] rng(m, v), distribution=rng_uniform + ROOT less-than = pred[] compare(get-tuple-element.3, rng), direction=LT } ENTRY %main { - param.1 = s32[4,3,5]{2,1,0} iota(), iota_dimension=0 + param.1 = s32[4,3,5]{2,1,0} parameter(0) zero = s32[]{:T(128)} constant(0) zeros32 = s32[]{:T(128)} constant(0) broadcast = s32[1,1,1,4,3,5]{5,4,3,2,1,0} broadcast(zeros32) @@ -388,10 +426,7 @@ TEST_F(WhileLoopFusibleSinkingTest, TestPlumbMultipleBroadcast) { ParseAndReturnVerifiedModule(hlo_string_before)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopFusibleSinking{}.Run(module_before.get())); - EXPECT_TRUE(changed); - EXPECT_THAT( - FindInstruction(module_before.get(), "while"), - op::While(op::Tuple(_, op::CustomCall(), op::CustomCall(), _, _))); + EXPECT_FALSE(changed); } } // namespace From f505a638366bb986f6880e1aa6e93e4036e98021 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 9 Apr 2025 08:56:58 -0700 Subject: [PATCH 0446/1324] [xla:gpu] CommandBuffer: add a RecordAction to distinguish between command create and update Also link CommandBufferCmd state with an underlying se::CommandBuffer, as the same command might be recorded multiple times (in the presence of control flow in the thunk sequence).
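A minimal sketch of the create-versus-update dispatch this change introduces. The `RecordCreate`, `RecordUpdate`, and `RecordAction` names mirror the patch, but the visitor body is illustrative only, not the actual implementation:

```
#include <type_traits>
#include <variant>

struct RecordCreate {};  // First recording: may define new commands in the DAG.
struct RecordUpdate {};  // Re-recording: may only patch existing commands.
using RecordAction = std::variant<RecordCreate, RecordUpdate>;

void Record(RecordAction action) {
  std::visit(
      [](auto&& act) {
        using T = std::decay_t<decltype(act)>;
        if constexpr (std::is_same_v<T, RecordCreate>) {
          // Create commands and stash handles in per-command-buffer state.
        } else {
          // Update previously recorded commands, e.g. with new buffer
          // addresses, without changing the DAG structure.
        }
      },
      action);
}
```

Keying the state manager on the (command, command buffer) pair, as the patch does, is what keeps this dispatch correct when the same command is recorded into several command buffers.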
PiperOrigin-RevId: 745609021 --- .../xla/xla/backends/gpu/runtime/BUILD | 10 +- .../gpu/runtime/command_buffer_cmd.cc | 109 +++++++++++++----- .../backends/gpu/runtime/command_buffer_cmd.h | 99 ++++++++++------ .../gpu/runtime/command_buffer_cmd_test.cc | 22 ++-- .../gpu/runtime/command_buffer_thunk.cc | 10 +- .../gpu/runtime/command_buffer_thunk_test.cc | 3 +- 6 files changed, 172 insertions(+), 81 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 372b2eaf0b59f6..101dc1bdd138f3 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -182,6 +182,7 @@ xla_test( "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_main", + "//xla/tsl/util:safe_reinterpret_cast", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -327,16 +328,16 @@ cc_library( "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:profiler_lock", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/lib:traceme_encode", @@ -369,6 +370,7 @@ xla_test( "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:executable", + "//xla/service:hlo_module_config", "//xla/service:platform_util", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 51c8bdfeebfc0a..634b42d0506590 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -170,8 +170,9 @@ CommandBufferCmd::StateManager::GetNextTypeId() { } CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrNull( - const CommandBufferCmd* cmd, TypeId type_id) { - StateKey key = {cmd, type_id}; + const CommandBufferCmd* cmd, const se::CommandBuffer* command_buffer, + TypeId type_id) { + Key key = {cmd, command_buffer, type_id}; if (auto it = state_.find(key); it != state_.end()) { return it->second.get(); } @@ -179,9 +180,9 @@ CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrNull( } CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrCreate( - const CommandBufferCmd* cmd, TypeId type_id, - absl::FunctionRef()> create) { - StateKey key = {cmd, type_id}; + const CommandBufferCmd* cmd, const se::CommandBuffer* command_buffer, + TypeId type_id, absl::FunctionRef()> create) { + Key key = {cmd, command_buffer, type_id}; if (auto it = state_.find(key); it != state_.end()) { return it->second.get(); } @@ -258,8 +259,12 @@ absl::Status CommandBufferCmdSequence::Record( } } - // Track the number of commands recorded between barriers. 
- int64_t num_recorded_commands = 0; + // Keep a state associated with commands in the sequence in the state manager. + CommandBufferCmd::StateManager& state = record_params.state; + + // If the command buffer is in update state, it means that we already recorded + // all commands into the underlying command buffer and we need to update them. + bool is_update = command_buffer->state() == se::CommandBuffer::State::kUpdate; for (std::unique_ptr& command : commands_) { if (execute_params.mock_collectives && @@ -270,11 +275,35 @@ absl::Status CommandBufferCmdSequence::Record( std::optional annotation = GetKernelAnnotation(command->profile_annotation()); - TF_ASSIGN_OR_RETURN( - CommandBufferCmd::RecordedCommands recorded_commands, - command->Record(execute_params, record_params, command_buffer)); - (void)recorded_commands; - ++num_recorded_commands; + if (is_update) { + // Update existing commands in the command buffer. + auto* record_state = + state.GetOrNull(command.get(), command_buffer); + DCHECK(record_state) << "Record state must be not null for " + << command->ToString(); + + auto record_action = CommandBufferCmd::RecordUpdate{ + std::move(record_state->recorded_commands)}; + TF_ASSIGN_OR_RETURN( + record_state->recorded_commands, + command->Record(execute_params, record_params, + std::move(record_action), command_buffer)); + + } else { + // Create new commands by recording them into the command buffer. + DCHECK(!state.GetOrNull(command.get(), command_buffer)) + << "Record state must be null for " << command->ToString(); + auto* record_state = + state.GetOrCreate(command.get(), command_buffer); + + // TODO(b/406370928): Fetch command dependencies computed from the command + // sequence, today we rely on implicit synchronization of all commands. + auto record_action = CommandBufferCmd::RecordCreate{}; + TF_ASSIGN_OR_RETURN( + record_state->recorded_commands, + command->Record(execute_params, record_params, + std::move(record_action), command_buffer)); + } } if (mode == RecordMode::kExclusive) { @@ -386,8 +415,8 @@ TracedCommandBufferCmd::AddTracedCommandBuffer( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer, absl::FunctionRef trace) { - auto traced_cmd = - record_params.state.GetOrCreate(this, [&] { + auto traced_cmd = record_params.state.GetOrCreate( + this, command_buffer, [&] { const auto& debug_options = xla::GetDebugOptionsFromFlags(); return std::make_unique( this, buffers(), debug_options.xla_cmd_buffer_trace_cache_size()); @@ -496,7 +525,8 @@ absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, absl::StatusOr ComputationIdCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = execute_params.buffer_allocations->GetDeviceAddress(dest_); @@ -576,7 +606,8 @@ absl::Status LaunchCmd::Initialize(const Thunk::InitializeParams& params, absl::StatusOr LaunchCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { VLOG(5) << "LaunchCmd: kernel=" << kernel_name_ << "; shmem_bytes=" << shmem_bytes_; @@ -647,6 +678,7 @@ absl::Status CustomKernelLaunchCmd::Initialize( absl::StatusOr CustomKernelLaunchCmd::Record(const 
Thunk::ExecuteParams& execute_params, const RecordParams& record_params, + RecordAction record_action, se::CommandBuffer* command_buffer) { VLOG(5) << "CustomKernelLaunchCmd: custom_kernel=" << custom_kernel_.name(); @@ -701,6 +733,7 @@ MemcpyDeviceToDeviceCmd::MemcpyDeviceToDeviceCmd( absl::StatusOr MemcpyDeviceToDeviceCmd::Record(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, + RecordAction record_action, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = execute_params.buffer_allocations->GetDeviceAddress(dst_); @@ -735,7 +768,8 @@ MemzeroCmd::MemzeroCmd(ExecutionStreamId execution_stream_id, absl::StatusOr MemzeroCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = execute_params.buffer_allocations->GetDeviceAddress(dst_); @@ -768,7 +802,8 @@ Memset32Cmd::Memset32Cmd(ExecutionStreamId execution_stream_id, absl::StatusOr Memset32Cmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = execute_params.buffer_allocations->GetDeviceAddress(dst_); @@ -811,7 +846,8 @@ absl::Status CaseCmd::Initialize(const Thunk::InitializeParams& params, absl::StatusOr CaseCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { se::DeviceMemoryBase index = execute_params.buffer_allocations->GetDeviceAddress(index_); @@ -868,7 +904,8 @@ absl::Status WhileCmd::Initialize(const Thunk::InitializeParams& params, absl::StatusOr WhileCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { se::DeviceMemoryBase pred = execute_params.buffer_allocations->GetDeviceAddress(pred_); @@ -925,7 +962,8 @@ absl::Status GemmCmd::Initialize(const Thunk::InitializeParams& params, absl::StatusOr GemmCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { se::DeviceMemoryBase lhs = execute_params.buffer_allocations->GetDeviceAddress(lhs_buffer_); se::DeviceMemoryBase rhs = @@ -1040,7 +1078,8 @@ absl::Status CublasLtCmd::Initialize(const Thunk::InitializeParams& params, absl::StatusOr CublasLtCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(execute_params.stream)); TF_ASSIGN_OR_RETURN(auto algorithm, GetMatmulAlgorithm(execute_params.stream, plan, @@ -1152,7 +1191,8 @@ absl::Status CuDnnCmd::Initialize(const Thunk::InitializeParams& params, absl::StatusOr CuDnnCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, 
RecordAction record_action, + se::CommandBuffer* command_buffer) { CHECK(graph_ != nullptr); std::vector operands; operands.reserve(args_.size()); @@ -1194,7 +1234,8 @@ CommandBufferCmd::BufferUseVector CuDnnCmd::buffers() { absl::StatusOr CustomCallCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { if (handler_ == nullptr) { return RecordLegacyCustomCall(execute_params, record_params, command_buffer); @@ -1402,7 +1443,8 @@ AllReduceCmd::AllReduceCmd(ExecutionStreamId execution_stream_id, absl::StatusOr AllReduceCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, @@ -1463,7 +1505,8 @@ ReduceScatterCmd::ReduceScatterCmd( absl::StatusOr ReduceScatterCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, @@ -1524,7 +1567,8 @@ AllToAllCmd::AllToAllCmd(ExecutionStreamId execution_stream_id, absl::StatusOr AllToAllCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, @@ -1582,7 +1626,8 @@ AllGatherCmd::AllGatherCmd(ExecutionStreamId execution_stream_id, absl::StatusOr AllGatherCmd::Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, @@ -1643,6 +1688,7 @@ CollectiveBroadcastCmd::CollectiveBroadcastCmd( absl::StatusOr CollectiveBroadcastCmd::Record(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, + RecordAction record_action, se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, @@ -1788,6 +1834,7 @@ absl::Status DynamicSliceFusionCmd::Prepare( absl::StatusOr DynamicSliceFusionCmd::Record(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, + RecordAction record_action, se::CommandBuffer* command_buffer) { se::Stream& stream = *execute_params.stream; @@ -1913,6 +1960,12 @@ DynamicSliceFusionCmd::Record(const Thunk::ExecuteParams& execute_params, Thunk::ExecuteParams new_params = Thunk::ExecuteParams::CloneWithNewAllocations(execute_params, slice_allocations); + + // TODO(b/406370928): Instead of creating a nested command buffer on every + // call we should create it once and update it. 
CommandBufferThunk state + // manager relies on command buffer pointer as an identity for command + // buffers, and it means that command buffer commands sequence should not + // create ephemeral command buffers at run time. auto nested_command_buffer = execute_params.stream->parent() ->CreateCommandBuffer(se::CommandBuffer::Mode::kNested) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index d65c1a82f0fb5b..b6d53ab7ffc339 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -21,7 +21,9 @@ limitations under the License. #include #include #include +#include #include +#include #include #include "absl/algorithm/container.h" @@ -135,32 +137,36 @@ class CommandBufferCmd { virtual ~State() = default; }; - // An external manager for a state attached to commands. + // An external manager for a state attached to commands recorded into command + // buffers (same command can be recorded into multiple command buffers). class StateManager { public: virtual ~StateManager() = default; template - ConcreteState* GetOrNull(const CommandBufferCmd* cmd) { + ConcreteState* GetOrNull(const CommandBufferCmd* cmd, + const se::CommandBuffer* command_buffer) { static_assert(std::is_base_of_v); return static_cast( - GetOrNull(cmd, GetTypeId())); + GetOrNull(cmd, command_buffer, GetTypeId())); } template ConcreteState* GetOrCreate( - const CommandBufferCmd* cmd, + const CommandBufferCmd* cmd, const se::CommandBuffer* command_buffer, absl::FunctionRef()> create) { static_assert(std::is_base_of_v); - return static_cast( - GetOrCreate(cmd, GetTypeId(), - [&]() -> std::unique_ptr { return create(); })); + return static_cast(GetOrCreate(cmd, command_buffer, + GetTypeId(), + [&] { return create(); })); } template - ConcreteState* GetOrCreate(const CommandBufferCmd* cmd) { - return GetOrCreate( - cmd, [] { return std::make_unique(); }); + ConcreteState* GetOrCreate(const CommandBufferCmd* cmd, + const se::CommandBuffer* command_buffer) { + return GetOrCreate(cmd, command_buffer, [] { + return std::make_unique(); + }); } private: @@ -175,13 +181,16 @@ class CommandBufferCmd { static TypeId GetNextTypeId(); - State* GetOrNull(const CommandBufferCmd* cmd, TypeId type_id); + State* GetOrNull(const CommandBufferCmd* cmd, + const se::CommandBuffer* command_buffer, TypeId type_id); - State* GetOrCreate(const CommandBufferCmd* cmd, TypeId type_id, + State* GetOrCreate(const CommandBufferCmd* cmd, + const se::CommandBuffer* command_buffer, TypeId type_id, absl::FunctionRef()> create); - using StateKey = std::pair; - absl::flat_hash_map> state_; + using Key = + std::tuple; + absl::flat_hash_map> state_; }; // Parameters for recording commands into the command buffer. @@ -200,6 +209,23 @@ class CommandBufferCmd { absl::InlinedVector commands; }; + // Create new commands in the command buffer using the given dependencies. + struct RecordCreate { + absl::Span dependencies; + }; + + // Update previously recorded commands in the command buffer. + struct RecordUpdate { + RecordedCommands recorded_commands; + }; + + // When recording a command into the command buffer we can either update + // previously recorded commands or create new ones. The command DAG structure + // can be defined only when we record commands the first time, after that we + // can only update previously recorded commands parameters (i.e. with pointers + // to new buffer allocations). 
+ using RecordAction = std::variant; + // See Thunk documentation for XLA execution stages (prepare, initialize, // execute). Commands mirror thunks as they are executed as CommandBufferThunk // that is plugged into the Thunk execution cycle. @@ -225,7 +251,8 @@ class CommandBufferCmd { // can do efficient command buffer updates. virtual absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) = 0; + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) = 0; // For some commands need to force update on Record even the input device // pointers do not change, e.g. command that has state that can be changed by @@ -344,6 +371,12 @@ class CommandBufferCmdSequence { } private: + // A state associated with commands in the sequence. We rely on this state to + // efficiently update command recorded into the command buffer. + struct RecordState : public CommandBufferCmd::State { + CommandBufferCmd::RecordedCommands recorded_commands; + }; + CommandBufferCmdSequence( SynchronizationMode synchronization_mode, std::vector> commands); @@ -425,7 +458,7 @@ class ComputationIdCmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -463,7 +496,7 @@ class LaunchCmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -498,7 +531,7 @@ class CustomKernelLaunchCmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -527,7 +560,7 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -549,7 +582,7 @@ class MemzeroCmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -569,7 +602,7 @@ class Memset32Cmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -594,7 +627,7 @@ class CaseCmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; bool force_update() override; @@ -622,7 +655,7 @@ class WhileCmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& 
record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; bool force_update() override; @@ -652,7 +685,7 @@ class GemmCmd : public TracedCommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -695,7 +728,7 @@ class CublasLtCmd : public TracedCommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -751,7 +784,7 @@ class CuDnnCmd : public TracedCommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -805,7 +838,7 @@ class CustomCallCmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -896,7 +929,7 @@ class AllReduceCmd : public CollectiveCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -923,7 +956,7 @@ class ReduceScatterCmd : public CollectiveCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -950,7 +983,7 @@ class AllToAllCmd : public CollectiveCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -976,7 +1009,7 @@ class AllGatherCmd : public CollectiveCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -1002,7 +1035,7 @@ class CollectiveBroadcastCmd : public CollectiveCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; @@ -1037,7 +1070,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd { absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, + const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; BufferUseVector buffers() override; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index 892e61c62c5136..43bf61a1fb614f 100644 --- 
a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" #include "xla/types.h" // IWYU pragma: keep namespace xla::gpu { @@ -76,7 +77,7 @@ struct TestOnlyCommandBufferCmd : public CommandBufferCmd { buffer_usage(buffer_usage) {} absl::StatusOr Record(const Thunk::ExecuteParams&, - const RecordParams&, + const RecordParams&, RecordAction, se::CommandBuffer*) override { return RecordedCommands{}; } @@ -93,7 +94,7 @@ class FakeCmd : public CommandBufferCmd { execution_stream_id) {} absl::StatusOr Record(const Thunk::ExecuteParams&, - const RecordParams&, + const RecordParams&, RecordAction, se::CommandBuffer*) override { return RecordedCommands{}; } @@ -110,31 +111,34 @@ TEST(CommandBufferCmdStateManageTest, GetOrCreateState) { }; // We need a fake command buffer pointer to use as a key. - CommandBufferCmd* cmd = reinterpret_cast(0x1234567); + auto* cmd = + tsl::safe_reinterpret_cast(std::intptr_t{0x1234567}); + auto* command_buffer = + tsl::safe_reinterpret_cast(std::intptr_t{0x1234567}); CommandBufferCmd::StateManager state_manager; // Create a state of type StateA. - auto* stateA0 = state_manager.GetOrNull(cmd); + auto* stateA0 = state_manager.GetOrNull(cmd, command_buffer); ASSERT_EQ(stateA0, nullptr); - auto* stateA1 = state_manager.GetOrCreate(cmd); + auto* stateA1 = state_manager.GetOrCreate(cmd, command_buffer); ASSERT_EQ(stateA1->value, 0); stateA1->value += 42; - auto* stateA2 = state_manager.GetOrCreate(cmd); + auto* stateA2 = state_manager.GetOrCreate(cmd, command_buffer); ASSERT_EQ(stateA2->value, 42); ASSERT_EQ(stateA1, stateA2); // StateB has a different type, and has no connection to StateA created above. - auto* stateB0 = state_manager.GetOrNull(cmd); + auto* stateB0 = state_manager.GetOrNull(cmd, command_buffer); ASSERT_EQ(stateB0, nullptr); - auto* stateB1 = state_manager.GetOrCreate(cmd); + auto* stateB1 = state_manager.GetOrCreate(cmd, command_buffer); ASSERT_EQ(stateB1->value, 0); stateB1->value += 42.0; - auto* stateB2 = state_manager.GetOrCreate(cmd); + auto* stateB2 = state_manager.GetOrCreate(cmd, command_buffer); ASSERT_EQ(stateB2->value, 42.0); ASSERT_EQ(stateB1, stateB2); } diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc index a13f590696506f..4232e9d57a4586 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -26,7 +25,6 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" -#include "xla/backends/gpu/runtime/annotation.h" #include "xla/backends/gpu/runtime/command_buffer_cmd.h" #include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -35,10 +33,10 @@ limitations under the License. 
#include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/profiler_lock.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index c759af5746bf23..6fea36551f864c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" #include "xla/shape.h" @@ -738,7 +739,7 @@ TEST(CommandBufferThunkTest, GemmCmd) { ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); } -TEST(CommandBufferThunkTest, DynamicSliceFusionCmd) { +TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { se::StreamExecutor* executor = GpuExecutor(); if (!IsAtLeastCuda12300(executor)) { From 161861433ff3e8ca3912256b9f33889def73e143 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 9 Apr 2025 09:04:18 -0700 Subject: [PATCH 0447/1324] Automated Code Change PiperOrigin-RevId: 745611755 --- third_party/xla/xla/python/ifrt/ir/BUILD | 226 ++++++------------ .../xla/xla/python/ifrt/ir/transforms/BUILD | 13 +- 2 files changed, 78 insertions(+), 161 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/ir/BUILD b/third_party/xla/xla/python/ifrt/ir/BUILD index a74a38deda79b3..09b6b12e6d163c 100644 --- a/third_party/xla/xla/python/ifrt/ir/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/BUILD @@ -38,50 +38,32 @@ td_library( gentbl_cc_library( name = "ifrt_dialect_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-dialect-decls", - "-dialect=ifrt", - ], - "ifrt_dialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=ifrt", - ], - "ifrt_dialect.cc.inc", - ), - ( - [ - "-gen-typedef-decls", - "--typedefs-dialect=ifrt", - ], - "ifrt_types.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "--typedefs-dialect=ifrt", - ], - "ifrt_types.cc.inc", - ), - ( - [ - "-gen-attrdef-decls", - "--attrdefs-dialect=ifrt", - ], - "ifrt_attrs.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "--attrdefs-dialect=ifrt", - ], - "ifrt_attrs.cc.inc", - ), - ], + tbl_outs = { + "ifrt_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=ifrt", + ], + "ifrt_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=ifrt", + ], + "ifrt_types.h.inc": [ + "-gen-typedef-decls", + "--typedefs-dialect=ifrt", + ], + "ifrt_types.cc.inc": [ + "-gen-typedef-defs", + "--typedefs-dialect=ifrt", + ], + "ifrt_attrs.h.inc": [ + "-gen-attrdef-decls", + "--attrdefs-dialect=ifrt", + ], + "ifrt_attrs.cc.inc": [ + "-gen-attrdef-defs", + "--attrdefs-dialect=ifrt", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ifrt_dialect.td", test = True, @@ -91,16 +73,10 @@ gentbl_cc_library( gentbl_cc_library( name = 
"ifrt_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ifrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ifrt_ops.cc.inc", - ), - ], + tbl_outs = { + "ifrt_ops.h.inc": ["-gen-op-decls"], + "ifrt_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ifrt_ops.td", test = True, @@ -110,24 +86,12 @@ gentbl_cc_library( gentbl_cc_library( name = "ifrt_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-attr-interface-decls"], - "ifrt_attr_interfaces.h.inc", - ), - ( - ["-gen-attr-interface-defs"], - "ifrt_attr_interfaces.cc.inc", - ), - ( - ["-gen-op-interface-decls"], - "ifrt_op_interfaces.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "ifrt_op_interfaces.cc.inc", - ), - ], + tbl_outs = { + "ifrt_attr_interfaces.h.inc": ["-gen-attr-interface-decls"], + "ifrt_attr_interfaces.cc.inc": ["-gen-attr-interface-defs"], + "ifrt_op_interfaces.h.inc": ["-gen-op-interface-decls"], + "ifrt_op_interfaces.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ifrt_interfaces.td", test = True, @@ -341,32 +305,14 @@ td_library( gentbl_cc_library( name = "vifrt_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-attr-interface-decls"], - "vifrt_attr_interfaces.h.inc", - ), - ( - ["-gen-attr-interface-defs"], - "vifrt_attr_interfaces.cc.inc", - ), - ( - ["-gen-type-interface-decls"], - "vifrt_type_interfaces.h.inc", - ), - ( - ["-gen-type-interface-defs"], - "vifrt_type_interfaces.cc.inc", - ), - ( - ["-gen-op-interface-decls"], - "vifrt_op_interfaces.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "vifrt_op_interfaces.cc.inc", - ), - ], + tbl_outs = { + "vifrt_attr_interfaces.h.inc": ["-gen-attr-interface-decls"], + "vifrt_attr_interfaces.cc.inc": ["-gen-attr-interface-defs"], + "vifrt_type_interfaces.h.inc": ["-gen-type-interface-decls"], + "vifrt_type_interfaces.cc.inc": ["-gen-type-interface-defs"], + "vifrt_op_interfaces.h.inc": ["-gen-op-interface-decls"], + "vifrt_op_interfaces.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "vifrt_interfaces.td", test = True, @@ -376,50 +322,32 @@ gentbl_cc_library( gentbl_cc_library( name = "vifrt_dialect_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-dialect-decls", - "-dialect=vifrt", - ], - "vifrt_dialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=vifrt", - ], - "vifrt_dialect.cc.inc", - ), - ( - [ - "-gen-typedef-decls", - "--typedefs-dialect=vifrt", - ], - "vifrt_types.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "--typedefs-dialect=vifrt", - ], - "vifrt_types.cc.inc", - ), - ( - [ - "-gen-attrdef-decls", - "--attrdefs-dialect=vifrt", - ], - "vifrt_attrs.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "--attrdefs-dialect=vifrt", - ], - "vifrt_attrs.cc.inc", - ), - ], + tbl_outs = { + "vifrt_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=vifrt", + ], + "vifrt_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=vifrt", + ], + "vifrt_types.h.inc": [ + "-gen-typedef-decls", + "--typedefs-dialect=vifrt", + ], + "vifrt_types.cc.inc": [ + "-gen-typedef-defs", + "--typedefs-dialect=vifrt", + ], + "vifrt_attrs.h.inc": [ + "-gen-attrdef-decls", + "--attrdefs-dialect=vifrt", + ], + "vifrt_attrs.cc.inc": [ + "-gen-attrdef-defs", + "--attrdefs-dialect=vifrt", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = 
"vifrt_dialect.td", test = True, @@ -429,16 +357,10 @@ gentbl_cc_library( gentbl_cc_library( name = "vifrt_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "vifrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "vifrt_ops.cc.inc", - ), - ], + tbl_outs = { + "vifrt_ops.h.inc": ["-gen-op-decls"], + "vifrt_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "vifrt_ops.td", test = True, diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD index 8a425c11e593d3..7cb7cf731fd05a 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD @@ -10,15 +10,10 @@ package( gentbl_cc_library( name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=IfrtIr", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=IfrtIr", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = [ From 62afe6aae7f24749e8b84e8c12954a871c778010 Mon Sep 17 00:00:00 2001 From: Fabian Mentzer Date: Wed, 9 Apr 2025 09:13:55 -0700 Subject: [PATCH 0448/1324] Reverts 6594d5a03c53cb86d250600cd55a8c557e12a611 PiperOrigin-RevId: 745615197 --- tensorflow/core/lib/gif/gif_io.cc | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc index e4123665dfe788..df9201a96f5b25 100644 --- a/tensorflow/core/lib/gif/gif_io.cc +++ b/tensorflow/core/lib/gif/gif_io.cc @@ -76,26 +76,30 @@ uint8* Decode(const void* srcdata, int datasize, return nullptr; } + int target_num_frames; if (DGifSlurp(gif_file) != GIF_OK) { *error_string = absl::StrCat("failed to slurp gif file: ", GifErrorStringNonNull(gif_file->Error)); // Stop load if no images are detected or the allocation of the last image // buffer was failed. if (gif_file->ImageCount <= 0 || - gif_file->SavedImages[gif_file->ImageCount - 1].RasterBits == nullptr || - gif_file->Error == D_GIF_ERR_EOF_TOO_SOON) { + gif_file->SavedImages[gif_file->ImageCount - 1].RasterBits == nullptr) { return nullptr; } + // If giflib parses the header correctly but the image data is corrupt, + // giflib incorrectly sets ImageCount if it hits an error. + target_num_frames = gif_file->ImageCount - 1; + LOG(WARNING) << "Decoding" << target_num_frames << " frames due to error."; LOG(ERROR) << *error_string; + } else { + target_num_frames = gif_file->ImageCount; } - if (gif_file->ImageCount <= 0) { + if (target_num_frames <= 0) { *error_string = "gif file does not contain any image"; return nullptr; } - int target_num_frames = gif_file->ImageCount; - // Don't request more memory than needed for each frame, preventing OOM int max_frame_width = 0; int max_frame_height = 0; From 8ec51a36f6c35fa74775387dda977bf4f83baeb7 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 9 Apr 2025 09:14:30 -0700 Subject: [PATCH 0449/1324] Use helper `IsAtLeastCuda12300` function. 
PiperOrigin-RevId: 745615366 --- .../xla/stream_executor/gpu/gpu_command_buffer_test.cc | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index b4f94f963cf9dd..3c1638a1340464 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -141,13 +141,7 @@ TEST(GpuCommandBufferTest, TraceSingleKernel) { Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - if (platform->id() == rocm::kROCmPlatformId) { - GTEST_SKIP() << "Not supported on ROCM"; - } - - if (platform->id() == cuda::kCudaPlatformId && - executor->GetDeviceDescription().runtime_version() < - SemanticVersion{12, 3, 0}) { + if (!IsAtLeastCuda12300(executor)) { GTEST_SKIP() << "Command buffer tracing is not supported"; } From 37946c8699c9194a72a0bd730f52a06debcb763a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 9 Apr 2025 10:11:58 -0700 Subject: [PATCH 0450/1324] [xla] Always return absl::StatusOr> result from Rendezvous PiperOrigin-RevId: 745637005 --- .../collectives/in_process_communicator.cc | 35 ++-- .../backends/gpu/collectives/gpu_cliques.cc | 6 +- .../gpu/runtime/collective_permute_thunk.cc | 14 +- .../backends/gpu/runtime/collective_thunk.cc | 8 +- .../gpu/runtime/ragged_all_to_all_thunk.cc | 12 +- third_party/xla/xla/service/BUILD | 4 + .../xla/xla/service/gpu/gpu_executable.cc | 4 +- third_party/xla/xla/service/rendezvous.h | 197 +++++++++--------- .../xla/xla/service/rendezvous_test.cc | 81 ++++--- 9 files changed, 204 insertions(+), 157 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc index 30af23eddee945..95baf3874dbfc7 100644 --- a/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc +++ b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc @@ -404,9 +404,10 @@ absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, std::string name = absl::StrCat("all reduce ", key.ToString()); AllReduceParticipant partiticipant{rank_, send_buffer, recv_buffer}; - auto op = Rendezvous>( - name, key, partiticipant, key.num_local_participants, - CollectParticipants); + TF_ASSIGN_OR_RETURN(auto op, + Rendezvous>( + name, key, partiticipant, key.num_local_participants, + CollectParticipants)); return op->Invoke(AllReduceOp, rank_, dtype, count, reduction_kind); } @@ -421,9 +422,10 @@ absl::Status InProcessCommunicator::ReduceScatter( std::string name = absl::StrCat("reduce scatter ", key.ToString()); ReduceScatterParticipant partiticipant{rank_, send_buffer, recv_buffer}; - auto op = Rendezvous>( - name, key, partiticipant, key.num_local_participants, - CollectParticipants); + TF_ASSIGN_OR_RETURN(auto op, + Rendezvous>( + name, key, partiticipant, key.num_local_participants, + CollectParticipants)); return op->Invoke(ReduceScatterOp, rank_, dtype, count, reduction_kind); } @@ -439,9 +441,10 @@ absl::Status InProcessCommunicator::CollectivePermute( CollectivePermuteParticipant partiticipant{rank_, source_rank, send_buffer, recv_buffer}; - auto op = Rendezvous>( - name, key, partiticipant, key.num_local_participants, - CollectParticipants); + TF_ASSIGN_OR_RETURN(auto op, + Rendezvous>( + name, key, partiticipant, key.num_local_participants, + 
CollectParticipants)); size_t num_bytes = count * primitive_util::ByteWidth(dtype); return op->Invoke(CollectivePermuteOp, rank_, num_bytes); @@ -459,9 +462,10 @@ absl::Status InProcessCommunicator::AllToAll( {send_buffers.begin(), send_buffers.end()}, {recv_buffers.begin(), recv_buffers.end()}}; - auto op = Rendezvous>( - name, key, partiticipant, key.num_local_participants, - CollectParticipants); + TF_ASSIGN_OR_RETURN(auto op, + Rendezvous>( + name, key, partiticipant, key.num_local_participants, + CollectParticipants)); size_t num_bytes = count * primitive_util::ByteWidth(dtype); return op->Invoke(AllToAllOp, rank_, num_bytes); @@ -477,9 +481,10 @@ absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer, std::string name = absl::StrCat("all gather ", key.ToString()); AllGatherParticipant partiticipant{rank_, send_buffer, recv_buffer}; - auto op = Rendezvous>( - name, key, partiticipant, key.num_local_participants, - CollectParticipants); + TF_ASSIGN_OR_RETURN(auto op, + Rendezvous>( + name, key, partiticipant, key.num_local_participants, + CollectParticipants)); size_t num_bytes = count * primitive_util::ByteWidth(dtype); return op->Invoke(AllGatherOp, rank_, num_bytes); diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc index 790020d4839845..6c366de471fe43 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc @@ -348,7 +348,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, // processes are not able to synchronize device activity. RendezvousArg rendezvous_arg = std::make_pair(device_rank, synchronized); - return Rendezvous>( + return Rendezvous( initialization_rendezvous_name, rendezvous_key, rendezvous_arg, num_local_participants, initialize, WarnStuckTimeout(), TerminateTimeout()); @@ -510,7 +510,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, rank.value(), clique_key.ToString(), run_id.ToInt(), parent_clique_key.ToString()); - return Rendezvous>( + return Rendezvous( initialization_rendezvous_name, rendezvous_key, rank_pair, num_local_participants, split, WarnStuckTimeout(), TerminateTimeout()); } @@ -545,7 +545,7 @@ absl::StatusOr> AcquireGpuClique( TF_ASSIGN_OR_RETURN( std::shared_ptr clique, - Rendezvous>( + Rendezvous( rendezvous_name, rendezvous_key, num_local_participants, [&] { tsl::profiler::TraceMe trace("LockGpuClique"); diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc index 5671862aa82ee8..6bb17ce4a8fd13 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc @@ -285,9 +285,10 @@ absl::Status CollectivePermuteStartThunk::RunCollective( // Perform a rendezvous to make sure all receivers have their events // recorded. - Rendezvous(rendezvous_name, rendezvous_key, num_local_participants, - /*warn_stuck_timeout=*/absl::Seconds(20), - /*terminate_timeout=*/absl::Seconds(40)); + TF_RETURN_IF_ERROR(Rendezvous(rendezvous_name, rendezvous_key, + num_local_participants, + /*warn_stuck_timeout=*/absl::Seconds(20), + /*terminate_timeout=*/absl::Seconds(40))); // For sending side, wait for the recorded event from the receiving side. 
if (target_id) { @@ -325,9 +326,10 @@ absl::Status CollectivePermuteStartThunk::RunCollective( // Perform a rendezvous to make sure all senders have their events // recorded. - Rendezvous(rendezvous_name, rendezvous_key, num_local_participants, - /*warn_stuck_timeout=*/absl::Seconds(20), - /*terminate_timeout=*/absl::Seconds(40)); + TF_RETURN_IF_ERROR(Rendezvous(rendezvous_name, rendezvous_key, + num_local_participants, + /*warn_stuck_timeout=*/absl::Seconds(20), + /*terminate_timeout=*/absl::Seconds(40))); // For receiving side, wait for the recorded event from the sending side. if (source_id) { diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_thunk.cc index e68ef4e353a3ef..a45c6e0cc4dd6d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_thunk.cc @@ -480,10 +480,10 @@ absl::Status CollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { "first call to collective operation %d; run_id=%d", config().op_id, params.collective_params->run_id.ToInt()); - Rendezvous(first_call_rendezvous_flag_, rendezvous_name, rendezvous_key, - num_local_participants, - /*warn_stuck_timeout=*/absl::Seconds(20), - /*terminate_timeout=*/absl::Seconds(40)); + TF_RETURN_IF_ERROR(Rendezvous(first_call_rendezvous_flag_, rendezvous_name, + rendezvous_key, num_local_participants, + /*warn_stuck_timeout=*/absl::Seconds(20), + /*terminate_timeout=*/absl::Seconds(40))); } return absl::OkStatus(); diff --git a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc index fff21ee1858a05..09a70533f4b16f 100644 --- a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc @@ -251,11 +251,13 @@ RendezvousBeforeKernelStart(absl::string_view name, std::string start_rendezvous_key = absl::StrFormat("start %s ragged-all-to-all for rank %d, clique %s", name, rank.value(), clique_key.ToString()); - std::shared_ptr> rendezvous_values = + TF_ASSIGN_OR_RETURN( + std::shared_ptr> rendezvous_values, Rendezvous>( /*name=*/ start_rendezvous_key, /*key=*/clique_key, - /*value=*/rendezvous_value, /*num_threads=*/num_ranks, rendezvous_fn); + /*value=*/rendezvous_value, /*num_threads=*/num_ranks, + rendezvous_fn)); // Wait for all devices to reach the start event. This indicates that all // output buffers are ready for transfer. @@ -280,9 +282,9 @@ absl::Status RendezvousAfterKernelFinish( std::string finish_rendezvous_key = absl::StrFormat("finish %s ragged-all-to-all for rank %d, clique %s", name, rank.value(), clique_key.ToString()); - Rendezvous(/*name=*/finish_rendezvous_key, - /*key=*/clique_key, - /*num_threads=*/num_ranks); + TF_RETURN_IF_ERROR(Rendezvous(/*name=*/finish_rendezvous_key, + /*key=*/clique_key, + /*num_threads=*/num_ranks)); // Wait for all devices to reach the end event. This indicates that all // updates from other devices have arrived. 
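Editorial note on the thunk changes above: the caller-side pattern is uniform. Barrier-style rendezvous calls, which used to return void, are now wrapped in TF_RETURN_IF_ERROR, and value-producing calls in TF_ASSIGN_OR_RETURN, so a rendezvous timeout propagates as a status instead of being dropped. A minimal sketch of the two shapes against the post-patch rendezvous.h follows; the key, timeouts, and callback are illustrative placeholders, not the actual thunk parameters:

#include <cstdint>
#include <memory>

#include "absl/status/status.h"
#include "absl/time/time.h"
#include "xla/service/rendezvous.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"

namespace xla {

// Barrier-only rendezvous: no value is produced, so the overload now returns
// a plain absl::Status that the caller must check.
absl::Status BarrierExample(int32_t key, size_t num_threads) {
  TF_RETURN_IF_ERROR(Rendezvous("example barrier", key, num_threads,
                                /*warn_stuck_timeout=*/absl::Seconds(20),
                                /*terminate_timeout=*/absl::Seconds(40)));
  return absl::OkStatus();
}

// Value-producing rendezvous: every participant receives the leader's result
// as absl::StatusOr<std::shared_ptr<int32_t>>.
absl::Status ValueExample(int32_t key, size_t num_threads) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<int32_t> result,
                      Rendezvous<int32_t>("example value", key, num_threads,
                                          [] { return 42; }));
  // All participants observe the same shared value.
  return result != nullptr ? absl::OkStatus()
                           : absl::InternalError("rendezvous was skipped");
}

}  // namespace xla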
diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 26c8cdf9cb47ff..40d48bbe5a3c44 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -5904,11 +5904,15 @@ xla_cc_test( srcs = ["rendezvous_test.cc"], deps = [ ":rendezvous", + "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", + "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 72865c0ba3e831..4d4dfce94580ee 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -426,7 +426,7 @@ absl::Status RendezvousAfterInitialization( run_options->device_ordinal(), run_options->run_options().run_id().ToInt()); - Rendezvous( + return Rendezvous( rendezvous_name, rendezvous_key, num_local_participants, absl::Seconds( debug_options @@ -436,8 +436,6 @@ absl::Status RendezvousAfterInitialization( debug_options ? debug_options->xla_gpu_executable_terminate_timeout_seconds() : 30)); - - return absl::OkStatus(); } absl::Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options, diff --git a/third_party/xla/xla/service/rendezvous.h b/third_party/xla/xla/service/rendezvous.h index 79c7da83105afd..86488963e16744 100644 --- a/third_party/xla/xla/service/rendezvous.h +++ b/third_party/xla/xla/service/rendezvous.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -39,60 +40,36 @@ limitations under the License. namespace xla { //===----------------------------------------------------------------------===// -// A rendezvous for a group of threads. +// Rendezvous synchronization. //===----------------------------------------------------------------------===// -// A little bit of compile time metaprogramming to simplify the rendezvous -// return type for functions returning `absl::StatusOr`. If we detect that -// rendezvous callback returns `absl::StatusOr` we swap the order of a shared -// pointer and status container. 
- -template <typename R> -struct RendezvousResult { - using Type = std::shared_ptr<R>; - - template <typename Result> - static Type Wrap(Result result) { - static_assert(std::is_constructible_v<R, Result>, - "Result `R` is not constructible from `Result`"); - return std::make_shared<R>(std::move(result)); - } - - static Type Empty() { return std::shared_ptr<R>(); } -}; - -template <typename R> -struct RendezvousResult<absl::StatusOr<R>> { - using Type = absl::StatusOr<std::shared_ptr<R>>; - - template <typename Result> - static Type Wrap(absl::StatusOr<Result> result) { - static_assert(std::is_constructible_v<R, Result>, - "Result `R` is not constructible from `Result`"); - if (!result.ok()) return result.status(); - return std::make_shared<R>(std::move(*result)); - } - - template <typename Result> - static Type Wrap(Result result) { - static_assert(std::is_constructible_v<R, Result>, - "Result `R` is not constructible from `Result`"); - return std::make_shared<R>(std::move(result)); - } - - static Type Empty() { return {std::shared_ptr<R>()}; } -}; - -template <> -struct RendezvousResult<absl::Status> { - using Type = absl::Status; - - static Type Wrap(absl::Status result) { return result; } - static Type Empty() { return absl::OkStatus(); } -}; +// Rendezvous is an XLA synchronization primitive that guarantees that all +// participating threads arrive at a rendezvous barrier identified by a key, and +// the last arriving thread becomes the leader that executes the rendezvous +// callback. The result of executing the callback is broadcast back to all +// participants as a `std::shared_ptr<R>` value, which makes all participants +// "collective owners" of the computed value. +// +// XLA uses rendezvous to guarantee that all ranks make progress together when +// executing a partitioned XLA program, and it acts as a guard against +// deadlocks in the lower parts of the stack (i.e., if not all participants +// arrive at an NCCL collective, we will get a deadlock on the device, which is +// a lot harder to debug). +// +// Rendezvous can synchronize only within a single process, as it relies on +// shared memory to communicate between participants. +// +// If a rendezvous reaches its `terminate_timeout`, it returns an error status +// to all participants, meaning that not all participants arrived at the +// rendezvous barrier in the given time. +// +// The rendezvous callback must return a value of type `R` or `absl::StatusOr<R>`, +// which is automatically converted to `absl::StatusOr<std::shared_ptr<R>>` +// for all participants. -template <typename R> -using RendezvousResultType = typename RendezvousResult<R>::Type; +//===----------------------------------------------------------------------===// +// Rendezvous API. +//===----------------------------------------------------------------------===// // The group of threads identifies itself with a key that must be unique to // the group. When all threads have arrived at the rendezvous, one thread @@ -100,14 +77,14 @@ using RendezvousResultType = typename RendezvousResult<R>::Type; // all threads receive the result. Rendezvous must have a human-readable name to // make it easy to debug stuck and timed out attempts. template <typename R, typename K, typename V, typename Fn> -RendezvousResultType<R> Rendezvous( +absl::StatusOr<std::shared_ptr<R>> Rendezvous( absl::string_view name, const K& key, const V& value, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration());
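Editor's aside: a usage sketch of the value-carrying overload declared above. The key, value, and callback are illustrative placeholders, not part of the header; per the static_assert in the implementation below, the callback receives the contributed values as absl::Span<const V* const>:

#include <cstdint>
#include <memory>

#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/service/rendezvous.h"

namespace xla {

// Each of the `num_threads` participants calls this concurrently with its own
// `rank`. The last thread to arrive runs the lambda over all contributed
// ranks, and every participant gets the same shared sum (or a timeout error).
absl::StatusOr<std::shared_ptr<int64_t>> SumRanks(int32_t key, int32_t rank,
                                                  size_t num_threads) {
  return Rendezvous<int64_t>(
      "sum ranks example", key, rank, num_threads,
      [](absl::Span<const int32_t* const> ranks) {
        int64_t sum = 0;
        for (const int32_t* r : ranks) sum += *r;
        return sum;
      },
      /*warn_stuck_timeout=*/absl::Seconds(10),
      /*terminate_timeout=*/absl::Seconds(30));
}

}  // namespace xla

// A rendezvous for a group of threads that do not have any value arguments.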
template <typename R, typename K, typename Fn> -RendezvousResultType<R> Rendezvous( +absl::StatusOr<std::shared_ptr<R>> Rendezvous( absl::string_view name, const K& key, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); @@ -115,9 +92,10 @@ RendezvousResultType<R> Rendezvous( // A rendezvous for a group of threads that do not have any computation to run // and simply acts as a barrier for a group of threads. template <typename K> -void Rendezvous(absl::string_view name, const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()); +absl::Status Rendezvous( + absl::string_view name, const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); // An `std::once_flag`-like primitive for executing Rendezvous operations. // @@ -170,10 +148,10 @@ class RendezvousFlag { // A rendezvous for a group of threads that will be executed only if the flag is // not in `completed` state and will switch it to `completed` after finishing a -// rendezvous. If rendezvous will not be executed it will return empty shared -// pointer result. +// rendezvous. If the rendezvous was not executed, the result will be an empty +// shared pointer. template <typename R, typename K, typename Fn> -RendezvousResultType<R> Rendezvous( +absl::StatusOr<std::shared_ptr<R>> Rendezvous( RendezvousFlag& flag, absl::string_view name, const K& key, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); @@ -183,10 +161,11 @@ RendezvousResultType<R> Rendezvous( // A rendezvous for a group of threads that will be executed only if the flag is // not in `completed` state and will switch it to `completed` after finishing a // rendezvous. template <typename K> -void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key, - size_t num_threads, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()); +absl::Status Rendezvous( + RendezvousFlag& flag, absl::string_view name, const K& key, + size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); //===----------------------------------------------------------------------===// // Internal implementation details. @@ -194,6 +173,12 @@ void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key, namespace internal { +// Detects types that are an `absl::StatusOr` container. +template <typename T> +struct IsStatusOrResult : std::false_type {}; +template <typename T> +struct IsStatusOrResult<absl::StatusOr<T>> : std::true_type {}; + // A base class for rendezvous state that holds synchronization primitives. struct RendezvousStateSynchronization { explicit RendezvousStateSynchronization(size_t num_threads) @@ -213,17 +198,15 @@ struct RendezvousStateSynchronization { // A state for a single round of rendezvous. We expect exactly `num_threads` to // arrive at a rendezvous and update corresponding slots in `values`. We -// pre-allocate storage for values so at run time each participant doesn't have +// pre-allocate storage for values, so at run time each participant doesn't have // to grab a lock and can simply write to the destination storage.
template struct RendezvousState : public RendezvousStateSynchronization { explicit RendezvousState(size_t n_threads) - : RendezvousStateSynchronization(n_threads), - values(n_threads, nullptr), - result(RendezvousResult::Empty()) {} + : RendezvousStateSynchronization(n_threads), values(n_threads, nullptr) {} std::vector values; - RendezvousResultType result; + absl::StatusOr> result; }; // A container for in-progress rendezvous. @@ -259,7 +242,8 @@ class RendezvousMap { return state = std::make_shared(num_threads); } - void Complete(const K& key, RendezvousResultType result) { + template + void Complete(const K& key, Result&& result) { std::shared_ptr state = [&] { absl::MutexLock lock(&mutex_); @@ -280,7 +264,15 @@ class RendezvousMap { // the progress of concurrent rendezvous for other keys. // Publish rendezvous result to all participants. - state->result = std::move(result); + if constexpr (IsStatusOrResult::value) { + if (ABSL_PREDICT_TRUE(result.ok())) { + state->result = std::make_shared(*std::forward(result)); + } else { + state->result = result.status(); + } + } else { + state->result = std::make_shared(std::forward(result)); + } // Notify awaiting participants that result is ready. absl::MutexLock lock(&state->mutex); @@ -304,10 +296,10 @@ void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, //===----------------------------------------------------------------------===// template -RendezvousResultType Rendezvous(absl::string_view name, const K& key, - const V& value, size_t num_threads, Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +absl::StatusOr> Rendezvous( + absl::string_view name, const K& key, const V& value, size_t num_threads, + Fn fn, absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { // Check that `fn` is callable with a span of values. static_assert(std::is_invocable_v>, "invalid rendezvous function signature"); @@ -315,7 +307,16 @@ RendezvousResultType Rendezvous(absl::string_view name, const K& key, // Fast-path (DO NOT REMOVE: the logic below doesn't work for single thread). if (num_threads == 1) { const V* ptr = &value; - return RendezvousResult::Wrap(fn(absl::MakeSpan(&ptr, 1))); + auto result = fn(absl::MakeSpan(&ptr, 1)); + + if constexpr (internal::IsStatusOrResult::value) { + if (ABSL_PREDICT_TRUE(result.ok())) { + return std::make_shared(*std::move(result)); + } + return result.status(); + } else { + return std::make_shared(std::move(result)); + } } using State = internal::RendezvousState; @@ -360,51 +361,55 @@ RendezvousResultType Rendezvous(absl::string_view name, const K& key, // `state->result` safe without any extra synchronization. 
tsl::profiler::TraceMe trace("ExecuteRendezvousCallback"); absl::Span values(state->values.data(), num_threads); - rendezvous.Complete(key, RendezvousResult::Wrap(fn(values))); + rendezvous.Complete(key, fn(values)); } return state->result; } template -RendezvousResultType Rendezvous(absl::string_view name, const K& key, - size_t num_threads, Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +absl::StatusOr> Rendezvous( + absl::string_view name, const K& key, size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, absl::Duration terminate_timeout) { return Rendezvous( name, key, std::nullopt, num_threads, [fn](auto) { return fn(); }, warn_stuck_timeout, terminate_timeout); } template -void Rendezvous(absl::string_view name, const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { - Rendezvous( - name, key, std::nullopt, num_threads, [](auto) { return std::nullopt; }, - warn_stuck_timeout, terminate_timeout); +absl::Status Rendezvous(absl::string_view name, const K& key, + size_t num_threads, absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + return Rendezvous( + name, key, std::nullopt, num_threads, + [](auto) { return std::nullopt; }, warn_stuck_timeout, + terminate_timeout) + .status(); } template -RendezvousResultType Rendezvous(RendezvousFlag& flag, absl::string_view name, - const K& key, size_t num_threads, Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +absl::StatusOr> Rendezvous( + RendezvousFlag& flag, absl::string_view name, const K& key, + size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { if (auto in_flight_rendezvous = flag.TryJoin()) { return Rendezvous(name, key, num_threads, std::move(fn), warn_stuck_timeout, terminate_timeout); } else { - return RendezvousResult::Empty(); + return std::shared_ptr(); } } template -void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key, - size_t num_threads, absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +absl::Status Rendezvous(RendezvousFlag& flag, absl::string_view name, + const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { if (auto in_flight_rendezvous = flag.TryJoin()) { - Rendezvous(name, key, num_threads, warn_stuck_timeout, - terminate_timeout); + return Rendezvous(name, key, num_threads, warn_stuck_timeout, + terminate_timeout); + } else { + return absl::OkStatus(); } } diff --git a/third_party/xla/xla/service/rendezvous_test.cc b/third_party/xla/xla/service/rendezvous_test.cc index 789e01a59a9e7f..f9fbb1c8287e20 100644 --- a/third_party/xla/xla/service/rendezvous_test.cc +++ b/third_party/xla/xla/service/rendezvous_test.cc @@ -20,12 +20,16 @@ limitations under the License. 
#include #include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test_benchmark.h" #include "xla/tsl/platform/threadpool.h" @@ -41,7 +45,9 @@ tsl::thread::ThreadPool CreateThreadPool(int32_t size) { } TEST(RendezvousTest, OneParticipant) { - auto result = Rendezvous("rendezvous_test", 0, 1, [] { return 42; }); + TF_ASSERT_OK_AND_ASSIGN( + std::shared_ptr result, + Rendezvous("rendezvous_test", 0, 1, [] { return 42; })); ASSERT_EQ(*result, 42); } @@ -51,8 +57,9 @@ TEST(RendezvousTest, TwoParticipants) { auto task = [&](int32_t id) { return [&, id] { - results[id] = - Rendezvous("rendezvous_test", 0, 2, [] { return 42; }); + TF_ASSERT_OK_AND_ASSIGN( + results[id], + Rendezvous("rendezvous_test", 0, 2, [] { return 42; })); counter.DecrementCount(); }; }; @@ -62,7 +69,6 @@ TEST(RendezvousTest, TwoParticipants) { thread_pool.Schedule(task(1)); counter.Wait(); - ASSERT_EQ(results.size(), 2); ASSERT_EQ(*results[0], 42); ASSERT_EQ(*results[1], 42); } @@ -79,8 +85,9 @@ TEST(RendezvousTest, TwoParticipantsWithValues) { auto task = [&](int32_t id) { return [&, id] { - results[id] = - Rendezvous("rendezvous_test", 0, id, 2, accumulate); + TF_ASSERT_OK_AND_ASSIGN( + results[id], + Rendezvous("rendezvous_test", 0, id, 2, accumulate)); counter.DecrementCount(); }; }; @@ -90,7 +97,6 @@ TEST(RendezvousTest, TwoParticipantsWithValues) { thread_pool.Schedule(task(1)); counter.Wait(); - ASSERT_EQ(results.size(), 2); ASSERT_EQ(*results[0], 1); ASSERT_EQ(*results[1], 1); } @@ -102,7 +108,8 @@ TEST(RendezvousTest, RepeatRendezvous) { absl::BlockingCounter counter(2); auto task = [&] { - Rendezvous("rendezvous_test", i, 2, [] { return 42; }); + TF_ASSERT_OK( + Rendezvous("rendezvous_test", i, 2, [] { return 42; })); counter.DecrementCount(); }; @@ -113,13 +120,38 @@ TEST(RendezvousTest, RepeatRendezvous) { } TEST(RendezvousTest, ReturningStatusOr) { + absl::BlockingCounter counter(2); + std::vector> results(2); + + auto task = [&](int32_t id) { + return [&, id] { + TF_ASSERT_OK_AND_ASSIGN( + results[id], + Rendezvous("rendezvous_test", 0, 2, + []() -> absl::StatusOr { return 42; })); + counter.DecrementCount(); + }; + }; + + auto thread_pool = CreateThreadPool(2); + thread_pool.Schedule(task(0)); + thread_pool.Schedule(task(1)); + counter.Wait(); + + ASSERT_EQ(*results[0], 42); + ASSERT_EQ(*results[1], 42); +} + +TEST(RendezvousTest, ReturningStatusError) { absl::BlockingCounter counter(2); std::vector>> results(2); auto task = [&](int32_t id) { return [&, id] { - results[id] = Rendezvous>("rendezvous_test", 0, 2, - [] { return 42; }); + results[id] = Rendezvous( + "rendezvous_test", 0, 2, []() -> absl::StatusOr { + return absl::InternalError("test error"); + }); counter.DecrementCount(); }; }; @@ -129,9 +161,8 @@ TEST(RendezvousTest, ReturningStatusOr) { thread_pool.Schedule(task(1)); counter.Wait(); - ASSERT_EQ(results.size(), 2); - ASSERT_EQ(**results[0], 42); - ASSERT_EQ(**results[1], 42); + ASSERT_EQ(results[0].status(), absl::InternalError("test error")); + ASSERT_EQ(results[1].status(), absl::InternalError("test error")); } TEST(RendezvousTest, RendezvousFlag) { @@ -145,9 +176,9 @@ TEST(RendezvousTest, RendezvousFlag) { auto 
task = [&](absl::BlockingCounter& counter) { return [&] { - Rendezvous( + TF_ASSERT_OK(Rendezvous( flag, "rendezvous_test", 0, 2, [&] { return ++num_executed; }, - Timeout(), Terminate()); + Timeout(), Terminate())); counter.DecrementCount(); }; }; @@ -178,8 +209,8 @@ TEST(RendezvousTest, RendezvousFlagRace) { auto task = [&](int32_t key) { return [&, key] { - Rendezvous(flag, "key: " + std::to_string(key), key, kNumThreads, - Timeout(), Terminate()); + TF_ASSERT_OK(Rendezvous(flag, "key: " + std::to_string(key), key, + kNumThreads, Timeout(), Terminate())); }; }; @@ -208,8 +239,8 @@ TEST(RendezvousTest, RendezvousFlagRaceWithBarriers) { return [&, key] { participants_ready.DecrementCount(); participants_notification.WaitForNotification(); - Rendezvous(flag, "key: " + std::to_string(key), key, kNumThreads, - Timeout(), Terminate()); + TF_ASSERT_OK(Rendezvous(flag, "key: " + std::to_string(key), key, + kNumThreads, Timeout(), Terminate())); participants_done.DecrementCount(); }; }; @@ -237,8 +268,8 @@ static void BM_Rendezvous(benchmark::State& state) { absl::BlockingCounter counter(num_threads); for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&] { - Rendezvous("rendezvous_test", 0, num_threads, - [] { return 42; }); + CHECK_OK(Rendezvous("rendezvous_test", 0, num_threads, + [] { return 42; })); counter.DecrementCount(); }); } @@ -255,8 +286,8 @@ static void BM_RendezvousWithValues(benchmark::State& state) { for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&, i] { int32_t value = i; - Rendezvous("rendezvous_test", 0, value, num_threads, - [](auto) { return 42; }); + CHECK_OK(Rendezvous("rendezvous_test", 0, value, num_threads, + [](auto) { return 42; })); counter.DecrementCount(); }); } @@ -275,8 +306,8 @@ static void BM_GroupedRendezvous(benchmark::State& state) { for (int64_t group = 0; group < num_groups; ++group) { for (int64_t i = 0; i < group_size; ++i) { thread_pool.Schedule([&, group] { - Rendezvous("rendezvous_test", group, group_size, - [] { return 42; }); + CHECK_OK(Rendezvous("rendezvous_test", group, group_size, + [] { return 42; })); counter.DecrementCount(); }); } From f762455e05a4f8f77cd368a374802d5b977ab394 Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Wed, 9 Apr 2025 10:12:21 -0700 Subject: [PATCH 0451/1324] Update to CUDA 12.6.3 and cuDNN 9.3.0 in tests and the documentation Reason for the update: - The current HERMETIC_CUDA_VERSION used by XLA is 12.1.1 (released April 2023), which is now 2 years old. - The current HERMETIC_CUDNN_VERSION in XLA is 8.6 (released October 2022), making it 2.5 years old. - JAX currently uses HERMETIC CUDA 12.8.0 and cuDNN 9.8.0, indicating it's already on a newer stack. - I tested upgrading XLA to CUDA 12.6.3 and cuDNN 9.3.0 - all GitHub XLA PR tests passed: PR #24788. 
Release dates: - CUDA 12.6.3: November 2024 - cuDNN 9.3.0: August 2024 PiperOrigin-RevId: 745637141 --- .../build_tools/configure/configure_test.py | 4 ++-- .../configure/testdata/cuda_clang.bazelrc | 4 ++-- .../configure/testdata/nvcc_clang.bazelrc | 4 ++-- .../configure/testdata/nvcc_gcc.bazelrc | 4 ++-- third_party/xla/docs/hermetic_cuda.md | 20 +++++++++---------- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/third_party/xla/build_tools/configure/configure_test.py b/third_party/xla/build_tools/configure/configure_test.py index 8849b931a905e9..a94cefdc4a92b1 100644 --- a/third_party/xla/build_tools/configure/configure_test.py +++ b/third_party/xla/build_tools/configure/configure_test.py @@ -35,9 +35,9 @@ # CUDA specific paths and versions _CUDA_SPECIFIC_PATHS_AND_VERSIONS = { - "cuda_version": '"12.1.1"', + "cuda_version": '"12.6.3"', "cuda_compute_capabilities": ["7.5"], - "cudnn_version": '"8.6"', + "cudnn_version": '"9.3.0"', "ld_library_path": "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", } _CUDA_COMPUTE_CAPABILITIES_AND_LD_LIBRARY_PATH = { diff --git a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc index 3f42ca9e563aa2..0832fbcb4c16ef 100644 --- a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc @@ -3,9 +3,9 @@ build --repo_env CC=/usr/lib/llvm-18/bin/clang build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang build --config cuda_clang build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang -build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.6.3" build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 -build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" +build:cuda --repo_env HERMETIC_CUDNN_VERSION="9.3.0" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc index 59d8d15c220843..96c655bc57c22e 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc @@ -3,9 +3,9 @@ build --repo_env CC=/usr/lib/llvm-18/bin/clang build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang build --config cuda_nvcc build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang -build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.6.3" build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 -build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" +build:cuda --repo_env HERMETIC_CUDNN_VERSION="9.3.0" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc index d587802dbd3a4f..cc6c84fb0195c7 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc @@ -1,8 +1,8 @@ build --action_env GCC_HOST_COMPILER_PATH=/usr/bin/gcc build --config cuda -build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env 
HERMETIC_CUDA_VERSION="12.6.3" build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 -build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" +build:cuda --repo_env HERMETIC_CUDNN_VERSION="9.3.0" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --define=xnn_enable_avxvnniint8=false diff --git a/third_party/xla/docs/hermetic_cuda.md b/third_party/xla/docs/hermetic_cuda.md index 208aae137d0c77..076c1f628ba5d2 100644 --- a/third_party/xla/docs/hermetic_cuda.md +++ b/third_party/xla/docs/hermetic_cuda.md @@ -23,26 +23,26 @@ default when `--config=cuda` is specified in Bazel command options. ## Environment variables controlling the hermetic CUDA/CUDNN versions `HERMETIC_CUDA_VERSION` environment variable should consist of major, minor and -patch CUDA version, e.g. `12.3.2`. +patch CUDA version, e.g. `12.6.3`. `HERMETIC_CUDNN_VERSION` environment variable should consist of major, minor and -patch CUDNN version, e.g. `9.1.1`. +patch CUDNN version, e.g. `9.3.0`. Three ways to set the environment variables for Bazel commands: ``` # Add an entry to your `.bazelrc` file -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.3" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" # OR pass it directly to your specific build command bazel build --config=cuda \ ---repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ ---repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +--repo_env=HERMETIC_CUDA_VERSION="12.6.3" \ +--repo_env=HERMETIC_CUDNN_VERSION="9.3.0" # If .bazelrc doesn't have corresponding entries and the environment variables # are not passed to bazel command, you can set them globally in your shell: -export HERMETIC_CUDA_VERSION="12.3.2" -export HERMETIC_CUDNN_VERSION="9.1.1" +export HERMETIC_CUDA_VERSION="12.6.3" +export HERMETIC_CUDNN_VERSION="9.3.0" ``` If `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` are not present, the @@ -114,8 +114,8 @@ is specified in [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https: respectively. Use only supported versions. 
You may set the environment variables directly in your shell or in `.bazelrc` file as shown below: ``` - build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" - build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.3" + build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" build:cuda --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" ``` From 13bf1f1e9db2748e155cf5399b665924ed0ca89a Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 9 Apr 2025 10:14:05 -0700 Subject: [PATCH 0452/1324] Automated Code Change PiperOrigin-RevId: 745637844 --- tensorflow/compiler/mlir/lite/BUILD | 188 ++++-------------- .../compiler/mlir/lite/experimental/tac/BUILD | 7 +- .../common/quantization_lib/BUILD | 14 +- .../compiler/mlir/lite/quantization/ir/BUILD | 49 ++--- .../mlir/lite/quantization/tensorflow/BUILD | 7 +- tensorflow/compiler/mlir/lite/stablehlo/BUILD | 48 +---- .../mlir/lite/stablehlo/odml_converter/BUILD | 20 +- .../mlir/quantization/common/ir/BUILD | 49 ++--- .../common/quantization_lib/BUILD | 14 +- .../common/tf_quantization_lib/BUILD | 14 +- .../mlir/quantization/tensorflow/BUILD | 98 ++------- tensorflow/compiler/mlir/tensorflow/BUILD | 150 +++++--------- .../mlir/tensorflow/ir/host_runtime/BUILD | 14 +- .../compiler/mlir/tensorflow/transforms/BUILD | 122 ++++-------- .../tensorflow/transforms/host_runtime/BUILD | 13 +- .../tensorflow/transforms/sparsecore/BUILD | 13 +- .../mlir/tf2xla/internal/inference/BUILD | 13 +- .../mlir/tf2xla/internal/passes/BUILD | 39 ++-- .../compiler/mlir/tf2xla/transforms/BUILD | 33 +-- tensorflow/compiler/mlir/tfr/BUILD | 21 +- tensorflow/compiler/mlir/tfrt/BUILD | 14 +- tensorflow/compiler/mlir/tfrt/ir/BUILD | 78 +++----- tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD | 56 ++---- .../compiler/mlir/tfrt/transforms/ifrt/BUILD | 13 +- 24 files changed, 289 insertions(+), 798 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 26e436cc519c72..4ecd78da80209c 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -101,15 +101,10 @@ td_library( gentbl_cc_library( name = "tensorflow_lite_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowLiteTd", - ], - "transforms/passes.h.inc", - ), - ], + tbl_outs = {"transforms/passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowLiteTd", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/passes.td", deps = [ @@ -120,23 +115,14 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tfl_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tfl_ops.cc.inc", - ), - ( - [ - "-gen-dialect-doc", - "-dialect=tfl", - ], - "g3doc/tfl_ops.md", - ), - ], + tbl_outs = { + "ir/tfl_ops.h.inc": ["-gen-op-decls"], + "ir/tfl_ops.cc.inc": ["-gen-op-defs"], + "g3doc/tfl_ops.md": [ + "-gen-dialect-doc", + "-dialect=tfl", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_ops.td", deps = [ @@ -147,24 +133,12 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "ir/tfl_ops_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "ir/tfl_ops_interface.cc.inc", - ), - ( - ["-gen-dialect-decls"], - 
"ir/tfl_ops_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "ir/tfl_ops_dialect.cc.inc", - ), - ], + tbl_outs = { + "ir/tfl_ops_interface.h.inc": ["-gen-op-interface-decls"], + "ir/tfl_ops_interface.cc.inc": ["-gen-op-interface-defs"], + "ir/tfl_ops_dialect.h.inc": ["-gen-dialect-decls"], + "ir/tfl_ops_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_op_interfaces.td", deps = [ @@ -175,24 +149,12 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_enums_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-enum-decls"], - "ir/tfl_ops_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "ir/tfl_ops_enums.cc.inc", - ), - ( - ["-gen-attrdef-decls"], - "ir/tfl_ops_attrdefs.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "ir/tfl_ops_attrdefs.cc.inc", - ), - ], + tbl_outs = { + "ir/tfl_ops_enums.h.inc": ["-gen-enum-decls"], + "ir/tfl_ops_enums.cc.inc": ["-gen-enum-defs"], + "ir/tfl_ops_attrdefs.h.inc": ["-gen-attrdef-decls"], + "ir/tfl_ops_attrdefs.cc.inc": ["-gen-attrdef-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_op_enums.td", deps = [ @@ -203,12 +165,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_prepare_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_prepare_tf.inc", - ), - ], + tbl_outs = {"transforms/generated_prepare_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/prepare_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -217,12 +174,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_lower_static_tensor_list_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_lower_static_tensor_list.inc", - ), - ], + tbl_outs = {"transforms/generated_lower_static_tensor_list.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/tensorlist_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -231,12 +183,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tf.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -245,12 +192,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_variables_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_variables.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_variables.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_variables.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -259,12 +201,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_optimize.inc", - ), - ], + tbl_outs = {"transforms/generated_optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -273,12 +210,7 @@ 
gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_optimize_batch_matmul_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_optimize_batch_matmul.inc", - ), - ], + tbl_outs = {"transforms/generated_optimize_batch_matmul.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_batch_matmul.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -287,12 +219,7 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_broadcast_like_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_optimize_broadcast_like.inc", - ), - ], + tbl_outs = {"transforms/generated_optimize_broadcast_like.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_broadcast_like_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -301,12 +228,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_quantize.inc", - ), - ], + tbl_outs = {"transforms/generated_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/quantize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -315,12 +237,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_quantize_by_converter_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_quantize_by_converter.inc", - ), - ], + tbl_outs = {"transforms/generated_quantize_by_converter.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/quantize_by_converter_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -329,12 +246,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_post_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_post_quantize.inc", - ), - ], + tbl_outs = {"transforms/generated_post_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/post_quantize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -343,12 +255,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tensorlist_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tensorlist.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_tensorlist.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_tensorlist.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -380,12 +287,7 @@ cc_library( gentbl_cc_library( name = "tensorflow_lite_canonicalize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "ir/tfl_canonicalize.inc", - ), - ], + tbl_outs = {"ir/tfl_canonicalize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_canonicalize.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -1470,7 +1372,7 @@ filegroup( gentbl_cc_library( name = "op_quant_spec_getters_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [([], "utils/generated_op_quant_spec_getters.inc")], + tbl_outs = {"utils/generated_op_quant_spec_getters.inc": 
[]}, tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", td_file = "ir/tfl_ops.td", deps = [ @@ -1481,7 +1383,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tflite_op_coverage_spec_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [([], "utils/tflite_op_coverage_spec.inc")], + tbl_outs = {"utils/tflite_op_coverage_spec.inc": []}, tblgen = "//tensorflow/compiler/mlir/lite/quantization:tflite_op_coverage_spec_getters_gen", td_file = "ir/tfl_ops.td", visibility = ["//learning/brain/mobile/model_optimization/g3doc/autogen:__pkg__"], @@ -1506,16 +1408,10 @@ tf_native_cc_binary( gentbl_cc_library( name = "converter_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["--gen-operator-converters"], - "operator_converters.inc", - ), - ( - ["--gen-runtime-verifiers"], - "runtime_verifiers.inc", - ), - ], + tbl_outs = { + "operator_converters.inc": ["--gen-operator-converters"], + "runtime_verifiers.inc": ["--gen-runtime-verifiers"], + }, tblgen = ":converter-gen", td_file = "ir/tfl_ops.td", test = 1, diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index f3edb169515bb6..e8018e4d9acdae 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -98,12 +98,7 @@ cc_library( gentbl_cc_library( name = "transform_patterns_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_transform_patterns.inc", - ), - ], + tbl_outs = {"transforms/generated_transform_patterns.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/transform_patterns.td", deps = [ diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD index f572ad6418feba..55f5727b81cd1d 100644 --- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD @@ -109,16 +109,10 @@ td_library( gentbl_cc_library( name = "quantization_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "quantization_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "quantization_interface.cc.inc", - ), - ], + tbl_outs = { + "quantization_interface.h.inc": ["-gen-op-interface-decls"], + "quantization_interface.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "quantization.td", deps = [ diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD index a6d6c61444548e..88022e023443f6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD @@ -26,30 +26,18 @@ td_library( gentbl_cc_library( name = "QuantOpsIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "QuantOps.h.inc", - ), - ( - ["-gen-op-defs"], - "QuantOps.cc.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=quantfork", - ], - "QuantOpsDialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=quantfork", - ], - "QuantOpsDialect.cc.inc", - ), - ], + tbl_outs = { + "QuantOps.h.inc": ["-gen-op-decls"], + "QuantOps.cc.inc": ["-gen-op-defs"], + "QuantOpsDialect.h.inc": [ + "-gen-dialect-decls", + 
"-dialect=quantfork", + ], + "QuantOpsDialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=quantfork", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "QuantOps.td", deps = [":QuantizationOpsTdFiles"], @@ -58,15 +46,10 @@ gentbl_cc_library( gentbl_cc_library( name = "QuantPassIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=quantfork", - ], - "Passes.h.inc", - ), - ], + tbl_outs = {"Passes.h.inc": [ + "-gen-pass-decls", + "-name=quantfork", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "Passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD index 8a73407338f697..dec7cbc852da00 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD @@ -36,12 +36,7 @@ td_library( gentbl_cc_library( name = "ptq_fallback_to_flex_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "fallback_to_flex_patterns.inc", - ), - ], + tbl_outs = {"fallback_to_flex_patterns.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "fallback_to_flex_patterns.td", deps = [":ptq_td_files"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index feb8afac64c098..2f25ec4532b233 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -123,12 +123,7 @@ tf_cc_test( gentbl_cc_library( name = "legalize_tf_patterns_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tf.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_tf_patterns.td", deps = [ @@ -330,15 +325,10 @@ cc_library( gentbl_cc_library( name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=OdmlStablehlo", - ], - "transforms/stablehlo_passes.h.inc", - ), - ], + tbl_outs = {"transforms/stablehlo_passes.h.inc": [ + "-gen-pass-decls", + "-name=OdmlStablehlo", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/stablehlo_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], @@ -626,12 +616,7 @@ cc_library( gentbl_cc_library( name = "hlo_legalize_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_hlo.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_hlo_patterns.td", deps = [ @@ -645,12 +630,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_legalize_tflite_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_tflite_legalize_hlo.inc", - ), - ], + tbl_outs = {"transforms/generated_tflite_legalize_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/tflite_legalize_hlo_patterns.td", deps = [ @@ -708,12 +688,7 @@ cc_library( gentbl_cc_library( name = "prepare_hlo_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - 
"transforms/generated_prepare_hlo.inc", - ), - ], + tbl_outs = {"transforms/generated_prepare_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/prepare_hlo.td", deps = [ @@ -959,12 +934,7 @@ cc_library( gentbl_cc_library( name = "composite_lowering_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_composite_lowering.inc", - ), - ], + tbl_outs = {"transforms/generated_composite_lowering.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/composite_lowering_patterns.td", deps = [ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD index c54545bd331391..d6b46ee3d31ad1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD @@ -64,12 +64,7 @@ cc_library( gentbl_cc_library( name = "shlo_simplify_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_shlo_simplify.inc", - ), - ], + tbl_outs = {"transforms/generated_shlo_simplify.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/shlo_simplify.td", deps = ["@stablehlo//:stablehlo_ops_td_files"], @@ -91,15 +86,10 @@ cc_library( gentbl_cc_library( name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=ODMLConverter", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=ODMLConverter", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/compiler/mlir/quantization/common/ir/BUILD b/tensorflow/compiler/mlir/quantization/common/ir/BUILD index 2821bb96a66950..162c14c4ad70f9 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/ir/BUILD @@ -25,30 +25,18 @@ td_library( gentbl_cc_library( name = "QuantOpsIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "QuantOps.h.inc", - ), - ( - ["-gen-op-defs"], - "QuantOps.cc.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=quantization", - ], - "QuantOpsDialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=quantization", - ], - "QuantOpsDialect.cc.inc", - ), - ], + tbl_outs = { + "QuantOps.h.inc": ["-gen-op-decls"], + "QuantOps.cc.inc": ["-gen-op-defs"], + "QuantOpsDialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=quantization", + ], + "QuantOpsDialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=quantization", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "QuantOps.td", deps = [":QuantizationOpsTdFiles"], @@ -57,15 +45,10 @@ gentbl_cc_library( gentbl_cc_library( name = "QuantPassIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=tfquant", - ], - "Passes.h.inc", - ), - ], + tbl_outs = {"Passes.h.inc": [ + "-gen-pass-decls", + "-name=tfquant", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "Passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD index b6b1d17d17a4a7..36b7152c15ff02 100644 --- 
a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD @@ -102,16 +102,10 @@ td_library( gentbl_cc_library( name = "quantization_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "quantization_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "quantization_interface.cc.inc", - ), - ], + tbl_outs = { + "quantization_interface.h.inc": ["-gen-op-interface-decls"], + "quantization_interface.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "quantization.td", deps = [ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD index 8079760d548d5f..f42c3ca6446c42 100644 --- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD @@ -102,16 +102,10 @@ td_library( gentbl_cc_library( name = "tf_quantization_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "tf_quantization_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "tf_quantization_interface.cc.inc", - ), - ], + tbl_outs = { + "tf_quantization_interface.h.inc": ["-gen-op-interface-decls"], + "tf_quantization_interface.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_quantization.td", deps = [ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index bfc5cde2dcbc82..37edee990eb314 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -114,12 +114,7 @@ td_library( gentbl_cc_library( name = "convert_tf_xla_op_to_tf_op_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/convert_tf_xla_op_to_tf_op.inc", - ), - ], + tbl_outs = {"passes/convert_tf_xla_op_to_tf_op.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/convert_tf_xla_op_to_tf_op.td", deps = [":quant_td_files"], @@ -128,12 +123,7 @@ gentbl_cc_library( gentbl_cc_library( name = "cast_bf16_ops_to_f32_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/cast_bf16_ops_to_f32.inc", - ), - ], + tbl_outs = {"passes/cast_bf16_ops_to_f32.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/cast_bf16_ops_to_f32.td", deps = [":quant_td_files"], @@ -142,12 +132,7 @@ gentbl_cc_library( gentbl_cc_library( name = "prepare_lifting_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/prepare_lifting.inc", - ), - ], + tbl_outs = {"passes/prepare_lifting.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/prepare_lifting.td", deps = [":quant_td_files"], @@ -156,12 +141,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", 
td_file = "passes/lift_quantizable_spots_as_functions.td", deps = [":quant_td_files"], @@ -170,12 +150,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_drq_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions_drq.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions_drq.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions_drq.td", deps = [":quant_td_files"], @@ -184,12 +159,7 @@ gentbl_cc_library( gentbl_cc_library( name = "prepare_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/prepare_quantize.inc", - ), - ], + tbl_outs = {"passes/prepare_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/prepare_quantize.td", deps = [":quant_td_files"], @@ -198,12 +168,7 @@ gentbl_cc_library( gentbl_cc_library( name = "quantize_composite_functions_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/quantize_composite_functions.inc", - ), - ], + tbl_outs = {"passes/quantize_composite_functions.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/quantize_composite_functions.td", deps = [":quant_td_files"], @@ -212,16 +177,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_quant_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "passes/tf_quant_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "passes/tf_quant_ops.cc.inc", - ), - ], + tbl_outs = { + "passes/tf_quant_ops.h.inc": ["-gen-op-decls"], + "passes/tf_quant_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/tf_quant_ops.td", deps = [ @@ -232,12 +191,7 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/optimize.inc", - ), - ], + tbl_outs = {"passes/optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/optimize.td", deps = [":quant_td_files"], @@ -246,12 +200,7 @@ gentbl_cc_library( gentbl_cc_library( name = "convert_tpu_model_to_cpu_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/convert_tpu_model_to_cpu.inc", - ), - ], + tbl_outs = {"passes/convert_tpu_model_to_cpu.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/convert_tpu_model_to_cpu.td", deps = [":quant_td_files"], @@ -260,12 +209,7 @@ gentbl_cc_library( gentbl_cc_library( name = "post_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/post_quantize.inc", - ), - ], + tbl_outs = {"passes/post_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/post_quantize.td", deps = [":quant_td_files"], @@ -274,12 +218,7 @@ gentbl_cc_library( gentbl_cc_library( name = "preprocess_op_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/preprocess_op.inc", - ), - ], + tbl_outs = {"passes/preprocess_op.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/preprocess_op.td", deps = [":quant_td_files"], @@ -319,12 +258,7 @@ 
cc_library( gentbl_cc_library( name = "replace_cast_hacks_with_tf_xla_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/replace_cast_hacks_with_tf_xla_ops.inc", - ), - ], + tbl_outs = {"passes/replace_cast_hacks_with_tf_xla_ops.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/replace_cast_hacks_with_tf_xla_ops.td", deps = [":quant_td_files"], diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 81bf61234707c0..a09199d428cd6b 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -47,16 +47,10 @@ td_library( gentbl_cc_library( name = "tensorflow_op_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "ir/tf_op_interfaces.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "ir/tf_op_interfaces.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_op_interfaces.h.inc": ["-gen-op-interface-decls"], + "ir/tf_op_interfaces.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_op_interfaces.td", test = True, @@ -68,12 +62,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_struct_doc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-dialect-doc"], - "g3doc/tf_ops.md", - ), - ], + tbl_outs = {"g3doc/tf_ops.md": ["-gen-dialect-doc"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", test = True, @@ -107,16 +96,10 @@ cc_library( gentbl_cc_library( name = "tensorflow_all_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_all_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_all_ops.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_all_ops.h.inc": ["-gen-op-decls"], + "ir/tf_all_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", deps = [ @@ -140,22 +123,16 @@ tf_ops_category_list = [ gentbl_cc_library( name = "tensorflow_" + target["name"] + "_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-op-decls", - "-op-include-regex=" + target["include"], - ], - "ir/tf_" + target["name"] + ".h.inc", - ), - ( - [ - "-gen-op-defs", - "-op-include-regex=" + target["include"], - ], - "ir/tf_" + target["name"] + ".cc.inc", - ), - ], + tbl_outs = { + "ir/tf_" + target["name"] + ".h.inc": [ + "-gen-op-decls", + "-op-include-regex=" + target["include"], + ], + "ir/tf_" + target["name"] + ".cc.inc": [ + "-gen-op-defs", + "-op-include-regex=" + target["include"], + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", deps = [ @@ -167,22 +144,16 @@ tf_ops_category_list = [ gentbl_cc_library( name = "tensorflow_remaining_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-op-decls", - "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), - ], - "ir/tf_remaining_ops.h.inc", - ), - ( - [ - "-gen-op-defs", - "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), - ], - "ir/tf_remaining_ops.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_remaining_ops.h.inc": [ + "-gen-op-decls", + "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), + ], + "ir/tf_remaining_ops.cc.inc": [ + "-gen-op-defs", + "-op-exclude-regex=" + "|".join([target["include"] 
for target in tf_ops_category_list]), + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", deps = [ @@ -193,20 +164,11 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_saved_model_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_saved_model.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_saved_model.cc.inc", - ), - ( - ["-gen-dialect-doc"], - "g3doc/tf_saved_model.md", - ), - ], + tbl_outs = { + "ir/tf_saved_model.h.inc": ["-gen-op-decls"], + "ir/tf_saved_model.cc.inc": ["-gen-op-defs"], + "g3doc/tf_saved_model.md": ["-gen-dialect-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_saved_model_ops.td", test = True, @@ -219,23 +181,14 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_executor_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_executor.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_executor.cc.inc", - ), - ( - [ - "-gen-dialect-doc", - "-dialect=tf_executor", - ], - "g3doc/tf_executor.md", - ), - ], + tbl_outs = { + "ir/tf_executor.h.inc": ["-gen-op-decls"], + "ir/tf_executor.cc.inc": ["-gen-op-defs"], + "g3doc/tf_executor.md": [ + "-gen-dialect-doc", + "-dialect=tf_executor", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_executor_ops.td", test = True, @@ -250,20 +203,11 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_device_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_device.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_device.cc.inc", - ), - ( - ["-gen-dialect-doc"], - "g3doc/tf_device.md", - ), - ], + tbl_outs = { + "ir/tf_device.h.inc": ["-gen-op-decls"], + "ir/tf_device.cc.inc": ["-gen-op-defs"], + "g3doc/tf_device.md": ["-gen-dialect-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_device_ops.td", test = True, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD index ccf7b0b547ab90..f1ab2432181e36 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD @@ -31,16 +31,10 @@ td_library( gentbl_cc_library( name = "tensorflow_tfrt_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_ops.cc.inc", - ), - ], + tbl_outs = { + "tfrt_ops.h.inc": ["-gen-op-decls"], + "tfrt_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_ops.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 19a1137b20de4f..3f4af7a26e34fc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -13,12 +13,7 @@ package( gentbl_cc_library( name = "tensorflow_canonicalize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_canonicalize.inc", - ), - ], + tbl_outs = {"generated_canonicalize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "canonicalize.td", deps = [ @@ -29,12 +24,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_reduce_patterns_inc_gen", - tbl_outs = [ - ( - ["-gen-rewriters"], - "reducer/tf_reduce_patterns.inc", - ), - ], + tbl_outs 
= {"reducer/tf_reduce_patterns.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "reducer/tf_mlir_reduce_patterns.td", deps = [ @@ -89,12 +79,7 @@ cc_library( gentbl_cc_library( name = "decompose_resource_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_decompose_resource_ops.inc", - ), - ], + tbl_outs = {"generated_decompose_resource_ops.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "decompose_resource_ops.td", deps = [ @@ -152,12 +137,7 @@ cc_library( gentbl_cc_library( name = "tf_data_optimization_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_tf_data_optimization.inc", - ), - ], + tbl_outs = {"generated_tf_data_optimization.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_data_optimization.td", deps = [ @@ -376,19 +356,13 @@ cc_library( gentbl_cc_library( name = "tf_pass_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlow", - ], - "tf_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/_includes/tf_passes.md", - ), - ], + tbl_outs = { + "tf_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlow", + ], + "g3doc/_includes/tf_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_passes.td", deps = [ @@ -399,19 +373,13 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_device_pass_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowDevice", - ], - "tf_device_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/includes/tf_device_passes.md", - ), - ], + tbl_outs = { + "tf_device_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowDevice", + ], + "g3doc/includes/tf_device_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_device_passes.td", deps = [ @@ -422,19 +390,13 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_savedmodel_pass_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowSavedModel", - ], - "tf_savedmodel_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/includes/tf_savedmodel_passes.md", - ), - ], + tbl_outs = { + "tf_savedmodel_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowSavedModel", + ], + "g3doc/includes/tf_savedmodel_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_savedmodel_passes.td", deps = [ @@ -445,19 +407,13 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_test_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowTest", - ], - "test_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/includes/tf_test_passes.md", - ), - ], + tbl_outs = { + "test_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowTest", + ], + "g3doc/includes/tf_test_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_test_passes.td", deps = [ @@ -1025,12 +981,7 @@ filegroup( gentbl_cc_library( name = "tensorflow_optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_optimize.inc", - ), - ], + tbl_outs = {"generated_optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = 
"optimize.td", deps = [ @@ -1045,12 +996,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lower_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_lower_tf.inc", - ), - ], + tbl_outs = {"generated_lower_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lower_tf.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index fa91275c392432..838dc1eb6fb8c1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -141,15 +141,10 @@ tf_cc_test( gentbl_cc_library( name = "runtime_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=RuntimeLowering", - ], - "runtime_passes.h.inc", - ), - ], + tbl_outs = {"runtime_passes.h.inc": [ + "-gen-pass-decls", + "-name=RuntimeLowering", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "runtime_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD index d19d5e8e8ab5aa..74f952a6cb7db6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD @@ -16,15 +16,10 @@ package( gentbl_cc_library( name = "sparsecore_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=SparseCore", - ], - "sparsecore_passes.h.inc", - ), - ], + tbl_outs = {"sparsecore_passes.h.inc": [ + "-gen-pass-decls", + "-name=SparseCore", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "sparsecore_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD index e80d33abb5cb37..d87efdfbf146f5 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD @@ -13,15 +13,10 @@ package( gentbl_cc_library( name = "inference_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TF2XLA", - ], - "inference_passes.h.inc", - ), - ], + tbl_outs = {"inference_passes.h.inc": [ + "-gen-pass-decls", + "-name=TF2XLA", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "inference_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index becdc528044f86..697ef42d3c6ac2 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -71,15 +71,10 @@ cc_library( gentbl_cc_library( name = "clustering_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TFXLABridgeClustering", - ], - "clustering_passes.h.inc", - ), - ], + tbl_outs = {"clustering_passes.h.inc": [ + "-gen-pass-decls", + "-name=TFXLABridgeClustering", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "clustering_passes.td", deps = [ @@ -229,15 +224,10 @@ cc_library( gentbl_cc_library( name = "mlir_to_graph_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TFXLABridgeMlirToGraph", 
- ], - "mlir_to_graph_passes.h.inc", - ), - ], + tbl_outs = {"mlir_to_graph_passes.h.inc": [ + "-gen-pass-decls", + "-name=TFXLABridgeMlirToGraph", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mlir_to_graph_passes.td", deps = [ @@ -459,15 +449,10 @@ cc_library( gentbl_cc_library( name = "lowering_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TFXLABridgeLowering", - ], - "lowering_passes.h.inc", - ), - ], + tbl_outs = {"lowering_passes.h.inc": [ + "-gen-pass-decls", + "-name=TFXLABridgeLowering", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lowering_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index abd057643629c6..87af60a1cfabe3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -16,12 +16,7 @@ package( gentbl_cc_library( name = "legalize_tf_patterns_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_legalize_tf.inc", - ), - ], + tbl_outs = {"generated_legalize_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "legalize_tf_patterns.td", deps = [ @@ -36,15 +31,10 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_legalize_tf_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=LegalizeTf", - ], - "xla_legalize_tf_passes.h.inc", - ), - ], + tbl_outs = {"xla_legalize_tf_passes.h.inc": [ + "-gen-pass-decls", + "-name=LegalizeTf", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_legalize_tf_passes.td", deps = [ @@ -55,15 +45,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_xla_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TfXla", - ], - "tf_xla_passes.h.inc", - ), - ], + tbl_outs = {"tf_xla_passes.h.inc": [ + "-gen-pass-decls", + "-name=TfXla", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_xla_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index a90f25aab887cf..4435ef59a7e385 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -53,16 +53,10 @@ td_library( gentbl_cc_library( name = "tfr_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tfr_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tfr_ops.cc.inc", - ), - ], + tbl_outs = { + "ir/tfr_ops.h.inc": ["-gen-op-decls"], + "ir/tfr_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfr_ops.td", deps = [ @@ -73,12 +67,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tfr_decompose_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/generated_decompose.inc", - ), - ], + tbl_outs = {"passes/generated_decompose.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/decompose_patterns.td", deps = [ diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 7c18a25ef08365..cc3c2b7beb5dca 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -60,16 +60,10 @@ td_library( gentbl_cc_library( name = "runtime_fallback_ops_inc_gen", - tbl_outs = [ - ( - 
["-gen-op-decls"], - "runtime_fallback_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "runtime_fallback_ops.cc.inc", - ), - ], + tbl_outs = { + "runtime_fallback_ops.h.inc": ["-gen-op-decls"], + "runtime_fallback_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "runtime_fallback/runtime_fallback_ops.td", deps = [":runtime_fallback_ops_td_files"], diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD b/tensorflow/compiler/mlir/tfrt/ir/BUILD index b29066807fbf78..ae5379f2102f36 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -141,16 +141,10 @@ td_library( gentbl_cc_library( name = "tfrt_fallback_opdefs_inc_gen", compatible_with = get_compatible_with_portable(), # copybara: comment - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_fallback.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_fallback.cpp.inc", - ), - ], + tbl_outs = { + "tfrt_fallback.h.inc": ["-gen-op-decls"], + "tfrt_fallback.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_fallback.td", deps = [":tfrt_fallback_td_files"], @@ -159,16 +153,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tfrt_fallback_async_opdefs_inc_gen", compatible_with = get_compatible_with_portable(), # copybara: comment - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_fallback_async.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_fallback_async.cpp.inc", - ), - ], + tbl_outs = { + "tfrt_fallback_async.h.inc": ["-gen-op-decls"], + "tfrt_fallback_async.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_fallback_async.td", deps = [":tfrt_fallback_td_files"], @@ -176,23 +164,14 @@ gentbl_cc_library( gentbl_cc_library( name = "tfrt_fallback_sync_opdefs_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_fallback_sync.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_fallback_sync.cpp.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=tfrt_fallback_sync", - ], - "tfrt_fallback_sync_dialect.h.inc", - ), - ], + tbl_outs = { + "tfrt_fallback_sync.h.inc": ["-gen-op-decls"], + "tfrt_fallback_sync.cpp.inc": ["-gen-op-defs"], + "tfrt_fallback_sync_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=tfrt_fallback_sync", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_fallback_sync.td", test = True, @@ -219,23 +198,14 @@ td_library( gentbl_cc_library( name = "tfrt_gpu_opdefs_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "gpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "gpu_ops.cpp.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=gpurt", - ], - "gpurt_dialect.h.inc", - ), - ], + tbl_outs = { + "gpu_ops.h.inc": ["-gen-op-decls"], + "gpu_ops.cpp.inc": ["-gen-op-defs"], + "gpurt_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=gpurt", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "gpu_ops.td", test = True, diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index 374aad2a242d9b..200f66fd722fef 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -23,16 +23,10 @@ td_library( gentbl_cc_library( name = "mlrt_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "mlrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "mlrt_ops.cpp.inc", - ), - ], + tbl_outs = { + "mlrt_ops.h.inc": ["-gen-op-decls"], + "mlrt_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mlrt_ops.td", deps = [":mlrt_td_files"], @@ -96,16 
+90,10 @@ td_library( gentbl_cc_library( name = "tf_mlrt_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_mlrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_mlrt_ops.cpp.inc", - ), - ], + tbl_outs = { + "tf_mlrt_ops.h.inc": ["-gen-op-decls"], + "tf_mlrt_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_mlrt_ops.td", deps = [":tf_mlrt_td_files"], @@ -113,16 +101,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_mlrt_tpu_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_mlrt_tpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_mlrt_tpu_ops.cpp.inc", - ), - ], + tbl_outs = { + "tf_mlrt_tpu_ops.h.inc": ["-gen-op-decls"], + "tf_mlrt_tpu_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_mlrt_tpu_ops.td", deps = [":tf_mlrt_tpu_td_files"], @@ -130,16 +112,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_ops.cpp.inc", - ), - ], + tbl_outs = { + "tf_ops.h.inc": ["-gen-op-decls"], + "tf_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_ops.td", deps = [":tf_mlrt_td_files"], diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 2162d37eebcfef..c957df04f2ff2f 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -33,15 +33,10 @@ package_group( gentbl_cc_library( name = "pass_inc_gen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TfrtIfrtServing", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=TfrtIfrtServing", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = [ From 73d9b29fc9dc34dbf566a6fdb2e10cb8a8381bed Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Wed, 9 Apr 2025 10:21:32 -0700 Subject: [PATCH 0453/1324] Add output tile selection logic to dynamic search space PiperOrigin-RevId: 745640819 --- .../xla/xla/service/gpu/autotuning/BUILD | 2 + .../gpu/autotuning/dot_search_space.cc | 196 ++++++++++++++++-- .../service/gpu/autotuning/dot_search_space.h | 50 ++++- .../gpu/autotuning/dot_search_space_test.cc | 100 ++++++++- .../gpu/autotuning/gemm_fusion_autotuner.cc | 4 +- 5 files changed, 315 insertions(+), 37 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index b750e922073c5f..3bb76e3111901e 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -271,7 +271,9 @@ cc_library( "//xla/stream_executor:device_description", "//xla/tsl/lib/core:bits", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc index 9a223b44774273..245ee5a3ba0c74 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc @@ -16,14 +16,19 @@ limitations under the License. 
#include "xla/service/gpu/autotuning/dot_search_space.h" #include +#include #include #include #include +#include #include +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/str_format.h" +#include "llvm/ADT/STLExtras.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -57,6 +62,18 @@ int64_t NextPowerOfTwo(int64_t x) { return tsl::NextPowerOfTwoS64(x); } +// Finds the previous power of two, smaller or equal to x. +// +// Returns 1 for an edge case of x = 0 (which we can get as a result of integer +// division). This might feel a bit weird, but it does the right thing when +// calculating tile sizes, since we need a strictly positive size. +int64_t PreviousPowerOfTwo(int64_t x) { + if (x == 0) { + return 1; + } + return tsl::NextPowerOfTwoS64(x + 1) / 2; +} + } // namespace TritonDotFusionSearchSpace::TritonDotFusionSearchSpace( @@ -74,47 +91,115 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace( (contracting_size_ * batch_size_)), rhs_parallel_size_(ShapeUtil::ElementsIn(dot->operand(1)->shape()) / (contracting_size_ * batch_size_)), + compute_bitwidth_(primitive_util::BitWidth(dot->shape().element_type())), // TODO: b/404470821 - Compute these from the problem properties instead // of hardcoding. desired_total_warps_(2160), - max_out_tile_{64, 128}, + max_out_tile_(GetMaxOutputTile()), + min_out_tile_{16, 16}, min_warps_per_cta_(4), min_contracting_tile_size_(16), - max_contracting_split_(GetMaxContractingSplit(max_out_tile_)) {} + max_contracting_split_(GetMaxContractingSplit(max_out_tile_)) { + // Make sure that the range of output tile sizes is not empty + // (min_output_tile_ is a hard limit, while max_output_tile_ is a soft one). + max_out_tile_.lhs_dim = + std::max(min_out_tile_.lhs_dim, max_out_tile_.lhs_dim); + max_out_tile_.rhs_dim = + std::max(min_out_tile_.rhs_dim, max_out_tile_.rhs_dim); +} std::vector TritonDotFusionSearchSpace::GenerateConfigs( std::optional force_contracting_split) { - std::vector configs; + std::vector configs; if (force_contracting_split.has_value()) { - TritonGemmConfig config; - config.split_k = force_contracting_split.value(); + ConfigWithNotes config; + const int split = force_contracting_split.value(); + config.config.split_k = split; + // It is possible that the user manually forced a huge contracting split + // that is outside of the search space. In that case, we would end up + // discarding all configs, and use the smallest possible tile size further + // down, which is likely not what the user had in mind. + config.keep_large_split = GetMaxContractingSplit(max_out_tile_) < split; + if (config.keep_large_split) { + LOG(WARNING) + << "split_k is larger than what we would have found automatically. " + "Skipping split and output tile compatibility checks. Should we " + "expand the split_k search space?"; + } configs.push_back(config); } else { configs = GenerateContractingSplitFactors(); } - // TODO: b/404470821 - Implement this properly rather than hardcoding the - // config parameters. - for (auto& config : configs) { - config.block_m = 64; - config.block_n = 128; + + ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddOutputTilings); + EliminateLowOccupancyConfigs(configs); + + std::vector result; + result.reserve(configs.size()); + for (auto& config_with_notes : configs) { + // TODO: b/404470821 - Implement this properly rather than hardcoding the + // config parameters. 
+ auto& config = config_with_notes.config; config.block_k = 64; config.num_stages = 3; config.num_warps = 4; config.num_ctas = 1; + result.push_back(config); } - return configs; + return result; } std::string TritonDotFusionSearchSpace::Serialize() { return absl::StrFormat( - "problem_size_BxMxNxK: %dx%dx%dx%d " - "tile_range_SxMxNxK: [1-%d]x[1-%d]x[1-%d]x[%d-?] " - "desired_total_warps: %d warps_per_block: [%d-?]", + "problem_size_BxMxNxKxE: %dx%dx%dx%dx%d " + "tile_range_SxMxNxK: [1-%d]x[%d-%d]x[%d-%d]x[%d-?] " + "desired_total_warps: %d warps_per_cta: [%d-?]", batch_size_, lhs_parallel_size_, rhs_parallel_size_, contracting_size_, - max_contracting_split_, max_out_tile_.lhs_dim, max_out_tile_.rhs_dim, + compute_bitwidth_, max_contracting_split_, min_out_tile_.lhs_dim, + max_out_tile_.lhs_dim, min_out_tile_.rhs_dim, max_out_tile_.rhs_dim, min_contracting_tile_size_, desired_total_warps_, min_warps_per_cta_); } +TritonDotFusionSearchSpace::OutputTile +TritonDotFusionSearchSpace::GetMaxOutputTile() const { + constexpr int kRegisterSizeInBits = 32; + const int64_t max_elements_per_cta = + device_description_.registers_per_block_limit() * kRegisterSizeInBits / + compute_bitwidth_; + auto limit_other_size_to_fit = [max_elements_per_cta](int64_t this_size) { + return PreviousPowerOfTwo(max_elements_per_cta / this_size); + }; + // We generally want to have square-ish tiles if possible to get maximal + // reuse. For wgmma the optimal instruction shape is 64x256, so optimizing for + // larger RHS given the choice. + OutputTile max_tile; + max_tile.lhs_dim = PreviousPowerOfTwo(std::sqrt(max_elements_per_cta)); + max_tile.rhs_dim = limit_other_size_to_fit(max_tile.lhs_dim); + VLOG(5) << "Computing max_output_tile: Based on available registers, " + "max_output_tile = " + << max_tile.lhs_dim << "x" << max_tile.rhs_dim; + + const int64_t lhs_parallel_limit = NextPowerOfTwo(lhs_parallel_size_); + const int64_t rhs_parallel_limit = NextPowerOfTwo(rhs_parallel_size_); + if (lhs_parallel_limit < max_tile.lhs_dim) { + max_tile.lhs_dim = lhs_parallel_limit; + max_tile.rhs_dim = std::min(limit_other_size_to_fit(lhs_parallel_limit), + rhs_parallel_limit); + VLOG(5) << "Computing max_tile: However, due to small LHS parallel size," + "max_output_tile = " + << max_tile.lhs_dim << "x" << max_tile.rhs_dim; + } + if (rhs_parallel_limit < max_tile.rhs_dim) { + max_tile.lhs_dim = std::min(limit_other_size_to_fit(rhs_parallel_limit), + lhs_parallel_limit); + max_tile.rhs_dim = rhs_parallel_limit; + VLOG(5) << "Computing max_tile: However, due to small RHS parallel " + "size, max_output_tile = " + << max_tile.lhs_dim << "x" << max_tile.rhs_dim; + } + return max_tile; +} + int64_t TritonDotFusionSearchSpace::GetNumResultTiles( OutputTile output_tile) const { return batch_size_ * @@ -124,18 +209,18 @@ int64_t TritonDotFusionSearchSpace::GetNumResultTiles( int TritonDotFusionSearchSpace::GetMaxContractingSplit( OutputTile output_tile) const { - const int64_t desired_num_blocks = desired_total_warps_ / min_warps_per_cta_; + const int64_t desired_num_ctas = desired_total_warps_ / min_warps_per_cta_; VLOG(5) << "Computing split_k: Considering output tile " << output_tile.lhs_dim << "x" << output_tile.rhs_dim; - VLOG(5) << "Computing split_k: Want up to " << desired_num_blocks - << " blocks to occupy all cores."; + VLOG(5) << "Computing split_k: Want up to " << desired_num_ctas + << " CTAs to occupy all cores."; const int64_t min_result_tiles = GetNumResultTiles(output_tile); VLOG(5) << "Computing split_k: Without split_k 
have " << min_result_tiles << " tiles."; const int64_t split_for_occupancy = - NextPowerOfTwo(CeilOfRatio(desired_num_blocks, min_result_tiles)); + NextPowerOfTwo(CeilOfRatio(desired_num_ctas, min_result_tiles)); VLOG(5) << "Computing split_k: Want split_k of up to " << split_for_occupancy << " for sufficient occupancy."; @@ -151,12 +236,13 @@ int TritonDotFusionSearchSpace::GetMaxContractingSplit( return split; } -std::vector +std::vector TritonDotFusionSearchSpace::GenerateContractingSplitFactors() { - std::vector configs; - TritonGemmConfig config; + CHECK_GE(max_contracting_split_, 1); + std::vector configs; + ConfigWithNotes config; for (int split = 1; split <= max_contracting_split_; split *= 2) { - config.split_k = split; + config.config.split_k = split; VLOG(10) << "Generating contracting split factors: config = " << config.ToString(); configs.push_back(config); @@ -164,4 +250,68 @@ TritonDotFusionSearchSpace::GenerateContractingSplitFactors() { return configs; } +void TritonDotFusionSearchSpace::ExtendConfigs( + std::vector& configs, ExtendConfigCallback extend_config) { + CHECK(!configs.empty()); + std::vector updated_configs; + for (ConfigWithNotes& config : configs) { + (this->*extend_config)(config, updated_configs); + } + CHECK(!updated_configs.empty()); + configs = std::move(updated_configs); +} + +void TritonDotFusionSearchSpace::AddOutputTilings( + const ConfigWithNotes& config, + std::vector& updated_configs) { + CHECK_GT(config.config.split_k, 0) + << "Need config with contracting split already set."; + const int split = config.config.split_k; + auto new_config = config; + auto& m = new_config.config.block_m; + auto& n = new_config.config.block_n; + for (m = min_out_tile_.lhs_dim; m <= max_out_tile_.lhs_dim; m *= 2) { + for (n = min_out_tile_.rhs_dim; n <= max_out_tile_.rhs_dim; n *= 2) { + OutputTile tile = {m, n}; + // We could make the tile size limits depend on split_k, but then we + // need to implement the "inverse" of `GetMaxContractingSplit`. + // Simpler is to just verify that the given combination of tiling and + // split_k is compatible. + if (!config.keep_large_split && GetMaxContractingSplit(tile) < split) { + VLOG(10) << "Skipping due to too large split_k, config = " + << new_config.ToString(); + continue; + } + new_config.not_enough_tiles = + GetNumResultTiles(tile) * split < device_description_.core_count(); + VLOG(10) << "Adding output tiling: config = " << new_config.ToString(); + updated_configs.push_back(new_config); + } + } +} + +void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs( + std::vector& configs) { + CHECK(!configs.empty()); + auto last_config = configs.back(); // Config with the largest split. + auto has_too_few_tiles = [](const ConfigWithNotes& config) { + if (config.not_enough_tiles) { + VLOG(10) << "Skipping due to fewer tiles than cores, config = " + << config.ToString(); + } + return config.not_enough_tiles; + }; + configs.erase(llvm::remove_if(configs, has_too_few_tiles), configs.end()); + if (configs.empty()) { + // We can get no configs if the problem is small enough to not even occupy + // all cores. In that case, we just use the largest split and smallest + // tiling. 
+    last_config.config.block_m = min_out_tile_.lhs_dim;
+    last_config.config.block_n = min_out_tile_.rhs_dim;
+    VLOG(10) << "No configs with sufficient occupancy, using config = "
+             << last_config.ToString();
+    configs.push_back(last_config);
+  }
+}
+
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
index 25d986c15abefa..44025a65dbc96f 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
@@ -54,10 +54,40 @@ class TritonDotFusionSearchSpace {
   // dimensions of the left and right hand sides. We assume that any batch
   // dimensions are tiled by a factor of 1.
   struct OutputTile {
-    int lhs_dim;  // LHS tiling (aka. block_m).
-    int rhs_dim;  // RHS tiling (aka. block_n).
+    int lhs_dim = 0;  // LHS tiling (aka. block_m).
+    int rhs_dim = 0;  // RHS tiling (aka. block_n).
   };
 
+  // Adds notes to configs, which carry additional information we need to
+  // consider while generating the search space.
+  struct ConfigWithNotes {
+    TritonGemmConfig config;
+    // This config has a larger than expected split_k, but we do not want to
+    // discard it.
+    bool keep_large_split = false;
+    // This config does not have enough tiles for all cores to be occupied.
+    bool not_enough_tiles = false;
+
+    std::string ToString() const { return config.ToString(); }
+  };
+
+  // Callback type for `ExtendConfigs`. The method should append zero or more
+  // extensions of `config` to the `updated_configs` vector.
+  using ExtendConfigCallback = void (TritonDotFusionSearchSpace::*)(
+      const ConfigWithNotes& config,
+      std::vector<ConfigWithNotes>& updated_configs);
+
+  // Extends Triton gemm configs by repeatedly calling `*extend_config()` on
+  // each config in `configs`. Expects that after all calls to `extend_config`,
+  // the updated list of configs is non-empty.
+  void ExtendConfigs(std::vector<ConfigWithNotes>& configs,
+                     ExtendConfigCallback extend_config);
+
+  // Computes the maximum sensible size of the output tile (block_m, block_n)
+  // based on the dot shape and element type, and the available registers on
+  // the core.
+  OutputTile GetMaxOutputTile() const;
+
   // Computes the number of result tiles we would have without
   // splitting the contracting dimension for a given output tile.
   int64_t GetNumResultTiles(OutputTile output_tile) const;
@@ -69,15 +99,29 @@ class TritonDotFusionSearchSpace {
 
   // Finds all promising values for splitting the contracting dimension to
   // achieve sufficient occupancy (split_k).
-  std::vector<TritonGemmConfig> GenerateContractingSplitFactors();
+  std::vector<ConfigWithNotes> GenerateContractingSplitFactors();
+
+  // Finds all promising output shape tilings (block_m, block_n), based on
+  // `config` with already determined contracting split value and appends them
+  // to `updated_configs`. Each config in the input list might yield zero or
+  // more configs in the output.
+  void AddOutputTilings(const ConfigWithNotes& config,
+                        std::vector<ConfigWithNotes>& updated_configs);
+
+  // Removes configs that are marked with `not_enough_tiles` from the list. If
+  // this results in an empty list, adds a config that should be the most
+  // optimal one even though it does not occupy all cores.
+  void EliminateLowOccupancyConfigs(std::vector<ConfigWithNotes>& configs);
 
   se::DeviceDescription device_description_;
   int64_t contracting_size_;
   int64_t batch_size_;
   int64_t lhs_parallel_size_;
   int64_t rhs_parallel_size_;
+  int compute_bitwidth_;
   int desired_total_warps_;
   OutputTile max_out_tile_;
+  OutputTile min_out_tile_;
   int min_warps_per_cta_;
   int min_contracting_tile_size_;
   int max_contracting_split_;
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
index c3d9aded73a1f1..6d150dff1f91a3 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
@@ -38,6 +38,7 @@ using ::testing::Field;
 using ::testing::Ge;
 using ::testing::IsEmpty;
 using ::testing::Le;
+using ::testing::SizeIs;
 
 template <typename MatcherType>
 auto BlockMIs(MatcherType matcher) {
@@ -76,8 +77,14 @@ auto IsValidConfig() {
 
 class DotSearchSpaceTest : public HloHardwareIndependentTestBase {
  protected:
-  se::DeviceDescription device_description_{
-      se::DeviceDescription(se::GpuDeviceInfoProto::default_instance())};
+  se::DeviceDescription device_description_;
+
+  DotSearchSpaceTest()
+      : device_description_(se::GpuDeviceInfoProto::default_instance()) {
+    // Using H100 numbers as the most relevant example here.
+    device_description_.set_registers_per_block_limit(64 * 1024);
+    device_description_.set_core_count(132);
+  }
 
   absl::StatusOr<std::unique_ptr<VerifiedHloModule>> GetDefaultDotModule(
       int lhs_parallel_dim = 1024, int rhs_parallel_dim = 1024,
@@ -98,12 +105,25 @@ ENTRY e {
     return Cast<HloDotInstruction>(
         module->entry_computation()->root_instruction());
   }
+
+  TritonDotFusionSearchSpace MakeSearchSpace(VerifiedHloModule* module) {
+    return TritonDotFusionSearchSpace(device_description_, GetDot(module));
+  }
 };
 
+TEST_F(DotSearchSpaceTest, SerializesSearchSpace) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule());
+  auto search_space = MakeSearchSpace(module.get());
+
+  EXPECT_EQ(search_space.Serialize(),
+            "problem_size_BxMxNxKxE: 1x1024x1024x1024x16 "
+            "tile_range_SxMxNxK: [1-64]x[16-256]x[16-512]x[16-?] "
+            "desired_total_warps: 2160 warps_per_cta: [4-?]");
+}
+
 TEST_F(DotSearchSpaceTest, ReturnsValidConfigList) {
   TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule());
-  TritonDotFusionSearchSpace search_space(device_description_,
-                                          GetDot(module.get()));
+  auto search_space = MakeSearchSpace(module.get());
 
   EXPECT_THAT(search_space.GenerateConfigs(),
               AllOf(Not(IsEmpty()), Each(IsValidConfig())));
@@ -111,8 +131,7 @@ TEST_F(DotSearchSpaceTest, ReturnsValidConfigList) {
 
 TEST_F(DotSearchSpaceTest, HonorsForcedContractingSplit) {
   TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule());
-  TritonDotFusionSearchSpace search_space(device_description_,
-                                          GetDot(module.get()));
+  auto search_space = MakeSearchSpace(module.get());
 
   EXPECT_THAT(
       search_space.GenerateConfigs(/*force_contracting_split=*/2),
@@ -124,8 +143,7 @@ TEST_F(DotSearchSpaceTest, ConsidersContractingSplitForSmallOutputSize) {
       GetDefaultDotModule(/*lhs_parallel_dim=*/16, /*rhs_parallel_dim=*/16,
                           /*contracting_dim=*/1024));
-  TritonDotFusionSearchSpace search_space(device_description_,
-                                          GetDot(module.get()));
+  auto search_space = MakeSearchSpace(module.get());
 
   EXPECT_THAT(search_space.GenerateConfigs(), Contains(SplitKIs(Ge(2))));
 }
@@ -135,12 +153,74 @@ TEST_F(DotSearchSpaceTest, LimitsContractingSplitForSmallerContractingSize) {
       GetDefaultDotModule(/*lhs_parallel_dim=*/16, /*rhs_parallel_dim=*/16,
                           /*contracting_dim=*/32));
-  TritonDotFusionSearchSpace search_space(device_description_,
-                                          GetDot(module.get()));
+  auto search_space = MakeSearchSpace(module.get());
 
   EXPECT_THAT(search_space.GenerateConfigs(),
               AllOf(Not(IsEmpty()), Each(SplitKIs(Le(2)))));
 }
 
+TEST_F(DotSearchSpaceTest, FindsGoodDataReuseOutputTiles) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule());
+  auto search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(search_space.GenerateConfigs(),
+              Contains(AllOf(BlockMIs(Ge(32)), BlockNIs(Ge(32)))).Times(Ge(2)));
+}
+
+TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForLowOccupancyProblem) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      GetDefaultDotModule(/*lhs_parallel_dim=*/4096, /*rhs_parallel_dim=*/16,
+                          /*contracting_dim=*/4096));
+  auto search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(search_space.GenerateConfigs(),
+              Contains(AllOf(BlockMIs(Ge(32)), SplitKIs(Ge(2)))));
+}
+
+TEST_F(DotSearchSpaceTest,
+       FindsUniqueOccupancyMaximizingTilingForSmallProblem) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      GetDefaultDotModule(/*lhs_parallel_dim=*/32, /*rhs_parallel_dim=*/32,
+                          /*contracting_dim=*/32));
+  auto search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(search_space.GenerateConfigs(),
+              AllOf(SizeIs(1), Each(AllOf(BlockMIs(Eq(16)), BlockNIs(Eq(16)),
+                                          SplitKIs(Eq(2))))));
+}
+
+TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForForcedHugeSplit) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule());
+  auto search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(
+      search_space.GenerateConfigs(/*force_contracting_split=*/128),
+      Contains(AllOf(BlockMIs(Ge(32)), BlockNIs(Ge(32)), SplitKIs(Eq(128)))));
+}
+
+TEST_F(DotSearchSpaceTest, PadsTilesForSmallParallelDimension) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          GetDefaultDotModule(/*lhs_parallel_dim=*/1024,
+                                              /*rhs_parallel_dim=*/15,
+                                              /*contracting_dim=*/1024));
+  auto search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(search_space.GenerateConfigs(), Contains(BlockNIs(Eq(16))));
+}
+
+TEST_F(DotSearchSpaceTest, HonorsMinimumOutputTileSizeForTinyProblem) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          GetDefaultDotModule(/*lhs_parallel_dim=*/12,
+                                              /*rhs_parallel_dim=*/8,
+                                              /*contracting_dim=*/16));
+  auto search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(
+      search_space.GenerateConfigs(),
+      AllOf(Not(IsEmpty()), Each(BlockMIs(Ge(16))), Each(BlockNIs(Ge(16)))));
+}
+
 }  // namespace
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
index 10bb06ef5d4c17..c05f71517ac5e0 100644
--- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -891,13 +891,15 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
       debug_options_.xla_gpu_enable_split_k_autotuning();
 
   if (debug_options_.xla_gpu_experimental_enable_dynamic_dot_search_space()) {
-    if (small_dot || !IsAutotuningEnabled()) {
+    if (!IsAutotuningEnabled()) {
       return {{kDefaultConfig}};
     }
     TritonDotFusionSearchSpace search_space(config_.GetDeviceDescription(),
                                             &dot);
     VLOG(1) << "Generating configs from search space: "
             << search_space.Serialize();
+    // We don't need to consider small_dot here. The new search space will
+    // already generate a unique config for small problems.
     return search_space.GenerateConfigs(
         autotune_contracting_split ? std::make_optional(1) : std::nullopt);
   }

From a7d97d667b4fd595e02116ca7ebfcbe253bfa4bc Mon Sep 17 00:00:00 2001
From: Michael Whittaker
Date: Wed, 9 Apr 2025 10:27:00 -0700
Subject: [PATCH 0454/1324] Removed dependence on deprecated BUILD rule.

PiperOrigin-RevId: 745642757
---
 tensorflow/core/common_runtime/eager/BUILD | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 7d2d3148818736..871d279bf396bf 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -253,7 +253,6 @@ tf_cuda_library(
         "@local_xla//xla/pjrt/gpu:se_gpu_pjrt_client",
         "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service",
         "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
-        "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_impl",
         "@local_xla//xla/tsl/distributed_runtime/preemption:preemption_notifier",
         "@local_xla//xla/tsl/platform:statusor",
     ],

From 101c8d5037aa773dab2e207b1ff5699456178d10 Mon Sep 17 00:00:00 2001
From: Daniel Suo
Date: Wed, 9 Apr 2025 10:51:46 -0700
Subject: [PATCH 0455/1324] Register DebugCallbackCustomCallPartitioner for OSS GPU.

PiperOrigin-RevId: 745652777
---
 third_party/xla/xla/pjrt/c/BUILD                 |  1 +
 third_party/xla/xla/python/BUILD                 |  3 +++
 .../xla/xla/python/debug_callback_partitioner.cc | 12 ++++++++++++
 third_party/xla/xla/python/version.h             |  2 +-
 4 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD
index 57ab02e4e50809..d1f3a43584314f 100644
--- a/third_party/xla/xla/pjrt/c/BUILD
+++ b/third_party/xla/xla/pjrt/c/BUILD
@@ -388,6 +388,7 @@ cc_library(
         "//xla/pjrt/gpu:se_gpu_pjrt_compiler",  # To register GPU AOT compiler
         "//xla/python:custom_call_batch_partitioner",
         "//xla/python:custom_partition_callback",
+        "//xla/python:debug_callback_partitioner",  # To register "DebugCallbackCustomCallPartitioner" custom partitioning handler.
        "//xla/python:inspect_sharding",  # To register "InspectSharding" custom partitioning handler.
"//xla/service:compiler", "//xla/service:custom_call_target_registry", diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index fcec573557ad2c..ca03d08d37427d 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -195,6 +195,7 @@ cc_library( name = "debug_callback_partitioner", srcs = ["debug_callback_partitioner.cc"], hdrs = ["debug_callback_partitioner.h"], + visibility = internal_visibility([":friends"]), deps = [ "//xla:shape_util", "//xla/hlo/ir:hlo", @@ -203,6 +204,8 @@ cc_library( "//xla/service/spmd:spmd_partitioner", "@com_google_absl//absl/status", ], + # Always register 'DebugCallbackCustomCallPartitioner' custom partitioning handler. + alwayslink = 1, ) cc_library( diff --git a/third_party/xla/xla/python/debug_callback_partitioner.cc b/third_party/xla/xla/python/debug_callback_partitioner.cc index fe6e4d6b6f28b7..cd58f9922885b6 100644 --- a/third_party/xla/xla/python/debug_callback_partitioner.cc +++ b/third_party/xla/xla/python/debug_callback_partitioner.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -81,4 +82,15 @@ absl::Status DebugCallbackCustomCallPartitioner::Partition( return absl::OkStatus(); } +namespace { +struct Registerer { + explicit Registerer(std::string target_name) { + RegisterCustomCallPartitioner( + target_name, std::make_unique()); + } +}; +Registerer cpu_registerer("xla_ffi_partitioned_python_cpu_callback"); +Registerer gpu_registerer("xla_ffi_partitioned_python_gpu_callback"); +} // namespace + } // namespace xla diff --git a/third_party/xla/xla/python/version.h b/third_party/xla/xla/python/version.h index a575c7ccb68c01..ea6e7b78c5d510 100644 --- a/third_party/xla/xla/python/version.h +++ b/third_party/xla/xla/python/version.h @@ -18,6 +18,6 @@ limitations under the License. // An increasing version number to protect jax code against breaking changes. // In JAX, reference this via jax._src.lib.ifrt_version. -#define JAX_IFRT_VERSION_NUMBER 3 +#define JAX_IFRT_VERSION_NUMBER 4 #endif // XLA_PYTHON_VERSION_H_ From a804ef0061dfc9d99dc255c354fd79f1a698df2b Mon Sep 17 00:00:00 2001 From: pizzud Date: Wed, 9 Apr 2025 11:01:54 -0700 Subject: [PATCH 0456/1324] safe_reinterpret_cast: Use std::remove_cv instead of building our own version. No need to type out all the const/volatile variations when the stdlib will do it for us. PiperOrigin-RevId: 745656894 --- .../xla/xla/tsl/util/safe_reinterpret_cast.h | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h index 51d17d0158281b..ab5e2efd800aae 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h @@ -48,17 +48,6 @@ struct IsByteLike : std::true_type {}; template <> struct IsByteLike : std::true_type {}; -// IsCvByteLike::value is true if T is a possibly CV-qualified byte-like type -// (char, unsigned char, or std::byte). -template -struct IsCvByteLike : IsByteLike {}; -template -struct IsCvByteLike : IsByteLike {}; -template -struct IsCvByteLike : IsByteLike {}; -template -struct IsCvByteLike : IsByteLike {}; - // IsSafeCast::value is true if it is safe to reinterpret_cast a // value of type From to a value of type To. // @@ -79,7 +68,8 @@ struct IsSafeCast : std::integral_constant< bool, // To/from a pointer to a byte-like type. 
-          (IsCvByteLike<To>::value || IsCvByteLike<From>::value) ||
+          (IsByteLike<typename std::remove_cv<To>::type>::value ||
+           IsByteLike<typename std::remove_cv<From>::type>::value) ||
           // From function pointer to void pointer.
           (std::is_function_v<From> && std::is_void_v<To>) ||
           // From void pointer to function pointer.

From 7ea58235c95d5df58243921e6a578f5bf862445c Mon Sep 17 00:00:00 2001
From: Christian Sigg
Date: Wed, 9 Apr 2025 11:10:26 -0700
Subject: [PATCH 0457/1324] Automated Code Change

PiperOrigin-RevId: 745660369
---
 .../mlir/quantization/stablehlo/BUILD | 72 +++++--------------
 1 file changed, 16 insertions(+), 56 deletions(-)

diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD
index f40c1371c9df19..8818f81be89813 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD
@@ -29,14 +29,9 @@ package(
 gentbl_cc_library(
     name = "stablehlo_passes_inc_gen",
     compatible_with = get_compatible_with_portable(),
-    tbl_outs = [
-        (
-            [
-                "-gen-pass-decls",
-            ],
-            "passes/passes.h.inc",
-        ),
-    ],
+    tbl_outs = {"passes/passes.h.inc": [
+        "-gen-pass-decls",
+    ]},
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "passes/passes.td",
     deps = [
@@ -212,12 +207,7 @@ td_library(
 gentbl_cc_library(
     name = "lift_quantizable_spots_as_functions_simple_inc_gen",
     compatible_with = get_compatible_with_portable(),
-    tbl_outs = [
-        (
-            ["-gen-rewriters"],
-            "passes/lift_quantizable_spots_as_functions_simple.inc",
-        ),
-    ],
+    tbl_outs = {"passes/lift_quantizable_spots_as_functions_simple.inc": ["-gen-rewriters"]},
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "passes/lift_quantizable_spots_as_functions_simple.td",
     deps = [
@@ -229,12 +219,7 @@ gentbl_cc_library(
 gentbl_cc_library(
     name = "lift_quantizable_spots_as_functions_fusion_inc_gen",
     compatible_with = get_compatible_with_portable(),
-    tbl_outs = [
-        (
-            ["-gen-rewriters"],
-            "passes/lift_quantizable_spots_as_functions_fusion.inc",
-        ),
-    ],
+    tbl_outs = {"passes/lift_quantizable_spots_as_functions_fusion.inc": ["-gen-rewriters"]},
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "passes/lift_quantizable_spots_as_functions_fusion.td",
     deps = [
@@ -246,12 +231,7 @@ gentbl_cc_library(
 gentbl_cc_library(
     name = "optimize_graph_inc_gen",
     compatible_with = get_compatible_with_portable(),
-    tbl_outs = [
-        (
-            ["-gen-rewriters"],
-            "passes/optimize_graph.inc",
-        ),
-    ],
+    tbl_outs = {"passes/optimize_graph.inc": ["-gen-rewriters"]},
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "passes/optimize_graph.td",
     deps = [
@@ -263,12 +243,7 @@ gentbl_cc_library(
 gentbl_cc_library(
     name = "remove_sharding_custom_call_inc_gen",
     compatible_with = get_compatible_with_portable(),
-    tbl_outs = [
-        (
-            ["-gen-rewriters"],
-            "passes/remove_sharding_custom_call.inc",
-        ),
-    ],
+    tbl_outs = {"passes/remove_sharding_custom_call.inc": ["-gen-rewriters"]},
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "passes/remove_sharding_custom_call.td",
     deps = [
@@ -279,15 +254,10 @@ gentbl_cc_library(
 gentbl_cc_library(
     name = "bridge_passes_inc_gen",
     compatible_with = get_compatible_with_portable(),
-    tbl_outs = [
-        (
-            [
-                "-gen-pass-decls",
-                "-name=Bridge",
-            ],
-            "passes/bridge/passes.h.inc",
-        ),
-    ],
+    tbl_outs = {"passes/bridge/passes.h.inc": [
+        "-gen-pass-decls",
+        "-name=Bridge",
+    ]},
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "passes/bridge/passes.td",
     deps = [
@@ -368,12 +338,7 @@ td_library(
 gentbl_cc_library(
     name = "optimize_inc_gen",
     compatible_with = get_compatible_with_portable(),
-    tbl_outs = [
-        (
-            ["-gen-rewriters"],
-            "passes/bridge/optimize.inc",
-        ),
-    ],
+    tbl_outs = {"passes/bridge/optimize.inc": ["-gen-rewriters"]},
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "passes/bridge/optimize.td",
     deps = [":optimize_td_files"],
@@ -498,15 +463,10 @@ cc_library(
 gentbl_cc_library(
     name = "stablehlo_test_passes_inc_gen",
     compatible_with = get_compatible_with_portable(),
-    tbl_outs = [
-        (
-            [
-                "-gen-pass-decls",
-                "-name=Test",
-            ],
-            "passes/testing/passes.h.inc",
-        ),
-    ],
+    tbl_outs = {"passes/testing/passes.h.inc": [
+        "-gen-pass-decls",
+        "-name=Test",
+    ]},
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "passes/testing/passes.td",
     deps = [

From 8507198de4e2fe8038575441676df534f8238e9b Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta
Date: Wed, 9 Apr 2025 11:12:50 -0700
Subject: [PATCH 0458/1324] [HLO Componentization] Migrate away from deprecated
 build targets and header files.

PiperOrigin-RevId: 745661250
---
 third_party/xla/xla/service/BUILD             | 131 ++++++----------
 .../xla/service/all_reduce_simplifier_test.cc |   2 +-
 .../xla/service/batchnorm_expander_test.cc    |   2 +-
 .../xla/xla/service/buffer_assignment_test.cc |   4 +-
 .../xla/service/collective_pipeliner_test.cc  |   2 +-
 .../service/compilation_environments_test.cc  |   2 +-
 .../service/conditional_code_motion_test.cc   |   2 +-
 .../service/conditional_simplifier_test.cc    |   2 +-
 .../xla/service/conditional_to_select_test.cc |   2 +-
 .../xla/xla/service/copy_insertion_test.cc    |   4 +-
 .../custom_call_target_registry_test.cc       |   2 +-
 .../dfs_hlo_visitor_with_default_test.cc      |   3 +-
 .../dynamic_dimension_inference_test.cc       |   3 +-
 .../xla/xla/service/dynamic_padder_test.cc    |   4 +-
 .../xla/service/dynamic_update_slice_test.cc  |   2 +-
 .../xla/service/elemental_ir_emitter_test.cc  |   2 +-
 .../xla/xla/service/gather_expander_test.cc   |   2 +-
 .../xla/xla/service/hlo_computation_test.cc   |   4 +-
 .../xla/xla/service/hlo_cost_analysis_test.cc |   2 +-
 .../xla/service/hlo_creation_utils_test.cc    |   2 +-
 .../xla/xla/service/hlo_domain_test.cc        |   2 +-
 .../xla/xla/service/hlo_graph_dumper_test.cc  |   2 +-
 .../hlo_input_output_alias_config_test.cc     |   2 +-
 .../xla/xla/service/hlo_instruction_test.cc   |   4 +-
 .../xla/xla/service/hlo_module_group_test.cc  |   2 +-
 .../xla/service/hlo_module_metadata_test.cc   |   4 +-
 .../xla/xla/service/hlo_module_test.cc        |   2 +-
 .../xla/xla/service/hlo_proto_util_test.cc    |   2 +-
 .../xla/xla/service/hlo_schedule_test.cc      |   2 +-
 .../xla/xla/service/hlo_sharding_test.cc      |   4 +-
 .../xla/xla/service/layout_assignment_test.cc |   4 +-
 .../service/loop_schedule_linearizer_test.cc  |   2 +-
 .../xla/xla/service/map_inliner_test.cc       |   2 +-
 .../mapped_ptr_container_sorter_test.cc       |   2 +-
 .../xla/service/pattern_matcher_gmock_test.cc |   2 +-
 .../xla/xla/service/pattern_matcher_test.cc   |   2 +-
 .../scatter_determinism_expander_test.cc      |   2 +-
 .../xla/xla/service/scatter_expander_test.cc  |   2 +-
 .../xla/xla/service/shape_inference_test.cc   |   4 +-
 .../xla/xla/service/shaped_buffer_test.cc     |   2 +-
 .../service/space_to_batch_converter_test.cc  |   2 +-
 .../xla/xla/service/stream_pool_test.cc       |   2 +-
 .../xla/xla/service/transpose_folding_test.cc |   4 +-
 .../xla/xla/service/tuple_util_test.cc        |   2 +-
 .../while_loop_constant_sinking_test.cc       |   2 +-
 .../while_loop_invariant_code_motion_test.cc  |   2 +-
 .../service/while_loop_simplifier_test.cc     |   2 +-
 .../xla/xla/service/while_util_test.cc        |   2 +-
 48 files changed, 116 insertions(+), 131 deletions(-)

diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index 40d48bbe5a3c44..1bbcd3de3db5a5 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -152,7 +152,6 @@ xla_cc_test( deps = [ ":all_reduce_promotion", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", @@ -193,7 +192,6 @@ xla_cc_test( deps = [ ":all_reduce_reassociate", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -506,12 +504,12 @@ xla_cc_test( ":memory_annotations_hdr", ":scheduling_annotations_util", "//xla:literal_util", - "//xla:test_helpers", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", @@ -637,12 +635,12 @@ xla_cc_test( deps = [ ":shape_inference", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/builder:padding", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -815,8 +813,8 @@ xla_test( deps = [ "//xla:execution_options_util", "//xla:status_macros", - "//xla:test", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", @@ -830,9 +828,8 @@ xla_cc_test( deps = [ ":hlo_runner", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -869,10 +866,10 @@ xla_cc_test( "//xla:comparison_util", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", @@ -894,11 +891,10 @@ xla_cc_test( srcs = ["pattern_matcher_gmock_test.cc"], deps = [ ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:shape_util", - "//xla:test", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", "//xla/tests:xla_internal_test_main", "@local_tsl//tsl/platform:test", ], @@ -934,17 +930,16 @@ xla_cc_test( deps = [ ":hlo_proto_cc", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:comparison_util", "//xla:literal_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/service/gpu:backend_configs_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -964,10 +959,10 @@ xla_cc_test( srcs = ["hlo_sharding_test.cc"], deps = [ "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/util/proto:proto_matchers", @@ -1065,11 +1060,11 @@ xla_cc_test( ":call_inliner", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", 
"//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -1509,8 +1504,8 @@ xla_cc_test( ":shaped_buffer", "//xla:shape_tree", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:test", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:stream_executor_h", @@ -1831,14 +1826,14 @@ xla_cc_test( "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/transforms/simplifiers:flatten_call_graph", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", @@ -1873,8 +1868,8 @@ xla_cc_test( deps = [ ":hlo_module_group_metadata", ":hlo_proto_cc", - "//xla:test", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -1961,9 +1956,9 @@ xla_cc_test( ":buffer_value", "//xla:literal_util", "//xla:shape_util", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/tests:hlo_test_base", @@ -1981,10 +1976,10 @@ xla_cc_test( srcs = ["hlo_input_output_alias_config_test.cc"], deps = [ "//xla:shape_util", - "//xla:test_helpers", "//xla:types", "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test_helpers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -2145,16 +2140,15 @@ xla_cc_test( deps = [ ":hlo_creation_utils", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:array2d", "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", @@ -2299,10 +2293,10 @@ xla_cc_test( ":scatter_expander", "//xla:literal", "//xla:shape_util", - "//xla:test", "//xla:types", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2321,7 +2315,7 @@ xla_test( deps = [ ":scatter_determinism_expander", "//xla:literal", - "//xla:test", + "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@local_tsl//tsl/platform:statusor", @@ -2429,11 +2423,11 @@ xla_test( "//xla:error_spec", "//xla:literal", "//xla:shape_util", - "//xla:test", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2629,14 +2623,13 @@ xla_cc_test( ":all_reduce_simplifier", ":hlo_module_config", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:shape_util", - "//xla:test", "//xla:types", "//xla:window_util", "//xla/hlo/ir:hlo", 
"//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", @@ -2752,9 +2745,9 @@ xla_cc_test( srcs = ["gather_expander_test.cc"], deps = [ ":gather_expander", - "//xla:test", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_query", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", @@ -2793,10 +2786,10 @@ xla_cc_test( ":conditional_simplifier", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2842,10 +2835,10 @@ xla_cc_test( ":conditional_code_motion", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2904,9 +2897,9 @@ xla_cc_test( srcs = ["space_to_batch_converter_test.cc"], deps = [ ":space_to_batch_converter", - "//xla:test", "//xla:types", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep @@ -3110,10 +3103,10 @@ xla_cc_test( ":while_loop_simplifier", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/utils:hlo_matchers", @@ -3357,20 +3350,19 @@ xla_test( ":dynamic_dimension_inference", ":dynamic_padder", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:error_spec", "//xla:literal", "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:test", - "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier", "//xla/hlo/transforms/simplifiers:hlo_dce", @@ -3405,12 +3397,11 @@ xla_cc_test( ":hlo_runner", "//xla:literal", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -3588,7 +3579,6 @@ xla_cc_test( ":local_service", ":service", "//xla:shape_util", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client", "//xla/client:client_library", @@ -3598,6 +3588,7 @@ xla_cc_test( "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test_helpers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/status:statusor", @@ -3641,16 +3632,15 @@ xla_cc_test( srcs = ["hlo_computation_test.cc"], deps = [ ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:comparison_util", "//xla:literal_util", "//xla:shape_tree", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", 
"//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -3676,10 +3666,10 @@ xla_cc_test( "//xla:debug_options_flags", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/hlo/utils:hlo_matchers", @@ -3704,9 +3694,9 @@ xla_cc_test( name = "hlo_module_metadata_test", srcs = ["hlo_module_metadata_test.cc"], deps = [ - "//xla:test", - "//xla:test_helpers", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/tests:xla_internal_test_main", ], ) @@ -4011,11 +4001,11 @@ xla_cc_test( "//xla:debug_options_flags", "//xla:literal_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/utils:hlo_matchers", "//xla/hlo/utils:hlo_query", "//xla/tests:hlo_test_base", @@ -4038,8 +4028,8 @@ xla_cc_test( ":copy_insertion", ":hlo_graph_dumper", ":loop_schedule_linearizer", - "//xla:test_helpers", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test_helpers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings:string_view", @@ -4216,18 +4206,17 @@ xla_cc_test( ":layout_assignment", ":logical_buffer", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:literal", "//xla:literal_util", "//xla:shape_layout", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -4299,7 +4288,6 @@ xla_cc_test( deps = [ ":hlo_cse", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:literal", "//xla:shape_util", "//xla:util", @@ -4395,9 +4383,9 @@ xla_cc_test( ":hlo_domain_verifier", ":sharding_propagation", "//xla:debug_options_flags", - "//xla:test", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -4507,9 +4495,9 @@ xla_test( "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla:test", "//xla:types", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", "//xla/tests:hlo_test_base", @@ -4630,9 +4618,9 @@ xla_cc_test( deps = [ ":hlo_graph_dumper", "//xla:literal_util", - "//xla:test", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", # fixdeps: keep @@ -4669,11 +4657,11 @@ xla_cc_test( ":transpose_folding", "//xla:literal", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/utils:hlo_matchers", "//xla/service/gpu:ir_emission_utils", "//xla/tests:hlo_test_base", @@ -4707,7 
+4695,7 @@ xla_cc_test( srcs = ["stream_pool_test.cc"], deps = [ ":stream_pool", - "//xla:test_helpers", + "//xla/hlo/testlib:test_helpers", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_platform", @@ -4736,9 +4724,9 @@ xla_cc_test( ":hlo_proto_util", "//xla:shape_util", "//xla:status_macros", - "//xla:test", "//xla:types", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", ], @@ -4908,9 +4896,9 @@ xla_cc_test( ":hlo_module_config", ":tuple_util", "//xla:shape_util", - "//xla:test", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -4992,10 +4980,10 @@ xla_cc_test( ":host_offload_utils", ":memory_annotations_hdr", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -5056,9 +5044,9 @@ xla_cc_test( srcs = ["while_util_test.cc"], deps = [ ":while_util", - "//xla:test", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", @@ -5192,10 +5180,10 @@ xla_cc_test( ":while_loop_invariant_code_motion", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -5277,8 +5265,8 @@ xla_cc_test( deps = [ ":while_loop_constant_sinking", "//xla:literal_util", - "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -5437,9 +5425,9 @@ xla_cc_test( "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", @@ -5482,8 +5470,8 @@ xla_cc_test( deps = [ ":conditional_to_select", "//xla:literal", - "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep @@ -5510,7 +5498,7 @@ xla_cc_test( deps = [ ":custom_call_status", ":custom_call_target_registry", - "//xla:test", + "//xla/hlo/testlib:test", "@com_google_googletest//:gtest_main", ], ) @@ -5844,7 +5832,7 @@ xla_cc_test( srcs = ["mapped_ptr_container_sorter_test.cc"], deps = [ ":mapped_ptr_container_sorter", - "//xla:test", + "//xla/hlo/testlib:test", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:bind_front", @@ -5970,8 +5958,8 @@ xla_cc_test( deps = [ ":compilation_environments", ":test_compilation_environment_proto_cc", - "//xla:test", "//xla:xla_proto_cc", + "//xla/hlo/testlib:test", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", @@ -6208,7 +6196,6 @@ xla_cc_test( deps = [ ":change_op_data_type", ":pattern_matcher", - ":pattern_matcher_gmock", "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # 
fixdeps: keep diff --git a/third_party/xla/xla/service/all_reduce_simplifier_test.cc b/third_party/xla/xla/service/all_reduce_simplifier_test.cc index 6850ef1d11b315..cce4f27b9cb5de 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier_test.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/batchnorm_expander_test.cc b/third_party/xla/xla/service/batchnorm_expander_test.cc index 25cfa87004be01..658426f867873b 100644 --- a/third_party/xla/xla/service/batchnorm_expander_test.cc +++ b/third_party/xla/xla/service/batchnorm_expander_test.cc @@ -23,9 +23,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/buffer_assignment_test.cc b/third_party/xla/xla/service/buffer_assignment_test.cc index d91b13a7bff1d6..8f50ff8dbd6801 100644 --- a/third_party/xla/xla/service/buffer_assignment_test.cc +++ b/third_party/xla/xla/service/buffer_assignment_test.cc @@ -39,6 +39,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/literal.h" @@ -52,8 +54,6 @@ limitations under the License. #include "xla/service/logical_buffer.h" #include "xla/service/memory_space_assignment/memory_space_assignment.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 983ba64dc85fca..556d0234db0500 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -39,13 +39,13 @@ limitations under the License. 
#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/service/memory_annotations.h" #include "xla/service/scheduling_annotations_util.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/compilation_environments_test.cc b/third_party/xla/xla/service/compilation_environments_test.cc index 1338f899b901b3..a2c044fa7e65a9 100644 --- a/third_party/xla/xla/service/compilation_environments_test.cc +++ b/third_party/xla/xla/service/compilation_environments_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/test_compilation_environment.pb.h" -#include "xla/test.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "tsl/platform/casts.h" diff --git a/third_party/xla/xla/service/conditional_code_motion_test.cc b/third_party/xla/xla/service/conditional_code_motion_test.cc index 1398a9b1fdc8d5..644a688fe4dfad 100644 --- a/third_party/xla/xla/service/conditional_code_motion_test.cc +++ b/third_party/xla/xla/service/conditional_code_motion_test.cc @@ -24,10 +24,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" diff --git a/third_party/xla/xla/service/conditional_simplifier_test.cc b/third_party/xla/xla/service/conditional_simplifier_test.cc index 8394cdcdc7a006..77b5fd35b216ed 100644 --- a/third_party/xla/xla/service/conditional_simplifier_test.cc +++ b/third_party/xla/xla/service/conditional_simplifier_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" diff --git a/third_party/xla/xla/service/conditional_to_select_test.cc b/third_party/xla/xla/service/conditional_to_select_test.cc index b32aeacf14e40f..f79de7206d3d9a 100644 --- a/third_party/xla/xla/service/conditional_to_select_test.cc +++ b/third_party/xla/xla/service/conditional_to_select_test.cc @@ -21,9 +21,9 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; diff --git a/third_party/xla/xla/service/copy_insertion_test.cc b/third_party/xla/xla/service/copy_insertion_test.cc index f26650863fc622..68c4ecfb979cdc 100644 --- a/third_party/xla/xla/service/copy_insertion_test.cc +++ b/third_party/xla/xla/service/copy_insertion_test.cc @@ -33,6 +33,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/layout.h" @@ -41,8 +43,6 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/service/custom_call_target_registry_test.cc b/third_party/xla/xla/service/custom_call_target_registry_test.cc index 1b423449953ba5..a4f58097d33b65 100644 --- a/third_party/xla/xla/service/custom_call_target_registry_test.cc +++ b/third_party/xla/xla/service/custom_call_target_registry_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/service/custom_call_target_registry.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/custom_call_status.h" -#include "xla/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/dfs_hlo_visitor_with_default_test.cc b/third_party/xla/xla/service/dfs_hlo_visitor_with_default_test.cc index 2fe22688ee2018..c244faf444cfec 100644 --- a/third_party/xla/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/third_party/xla/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -19,10 +19,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/hlo_runner.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc index 944a773e09a270..31ee1a58162c46 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc @@ -23,12 +23,11 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/hlo_runner.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/dynamic_padder_test.cc b/third_party/xla/xla/service/dynamic_padder_test.cc index 65ac2512efdcee..9ae246791e1bdc 100644 --- a/third_party/xla/xla/service/dynamic_padder_test.cc +++ b/third_party/xla/xla/service/dynamic_padder_test.cc @@ -34,6 +34,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" @@ -45,8 +47,6 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/llvm_irgen_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/service/dynamic_update_slice_test.cc b/third_party/xla/xla/service/dynamic_update_slice_test.cc index 96298fb6437ffd..154657307b09a9 100644 --- a/third_party/xla/xla/service/dynamic_update_slice_test.cc +++ b/third_party/xla/xla/service/dynamic_update_slice_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/execution_options_util.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/service/elemental_ir_emitter_test.cc b/third_party/xla/xla/service/elemental_ir_emitter_test.cc index 43426d10a6404e..cb6338cffb0a52 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter_test.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter_test.cc @@ -33,12 +33,12 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "xla/types.h" diff --git a/third_party/xla/xla/service/gather_expander_test.cc b/third_party/xla/xla/service/gather_expander_test.cc index e6ea76ea4ff2cd..938433b26be4d4 100644 --- a/third_party/xla/xla/service/gather_expander_test.cc +++ b/third_party/xla/xla/service/gather_expander_test.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_query.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc index ccefc268cb46b3..9a0c3cdb786a7f 100644 --- a/third_party/xla/xla/service/hlo_computation_test.cc +++ b/third_party/xla/xla/service/hlo_computation_test.cc @@ -31,14 +31,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/service/hlo_cost_analysis_test.cc b/third_party/xla/xla/service/hlo_cost_analysis_test.cc index a7011aa30de531..b2a3c50c6b738d 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis_test.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis_test.cc @@ -30,10 +30,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/service/local_service.h" #include "xla/service/service.h" #include "xla/shape_util.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/service/hlo_creation_utils_test.cc b/third_party/xla/xla/service/hlo_creation_utils_test.cc index 4e8f19f031d157..789a659867c215 100644 --- a/third_party/xla/xla/service/hlo_creation_utils_test.cc +++ b/third_party/xla/xla/service/hlo_creation_utils_test.cc @@ -27,13 +27,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/hlo_domain_test.cc b/third_party/xla/xla/service/hlo_domain_test.cc index 11acf73bf6cfff..20b61c76b30dd6 100644 --- a/third_party/xla/xla/service/hlo_domain_test.cc +++ b/third_party/xla/xla/service/hlo_domain_test.cc @@ -21,12 +21,12 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/call_inliner.h" #include "xla/service/hlo_domain_isolator.h" #include "xla/service/hlo_domain_remover.h" #include "xla/service/hlo_domain_verifier.h" #include "xla/service/sharding_propagation.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/service/hlo_graph_dumper_test.cc b/third_party/xla/xla/service/hlo_graph_dumper_test.cc index d76e734f33ec7d..25d6e5796ad795 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper_test.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" #include "xla/xla.pb.h" diff --git a/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc b/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc index 7c2cc1d945c0d5..19228b9772c433 100644 --- a/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc +++ b/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/shape_util.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index c97003ced8b129..5a363b658bd305 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -40,6 +40,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/layout_util.h" #include "xla/literal_util.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -47,8 +49,6 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/util/proto/proto_matchers.h" diff --git a/third_party/xla/xla/service/hlo_module_group_test.cc b/third_party/xla/xla/service/hlo_module_group_test.cc index 007df88bdcc9d9..e7ec60fa091883 100644 --- a/third_party/xla/xla/service/hlo_module_group_test.cc +++ b/third_party/xla/xla/service/hlo_module_group_test.cc @@ -15,10 +15,10 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_group_metadata.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/service/hlo_module_metadata_test.cc b/third_party/xla/xla/service/hlo_module_metadata_test.cc index fc64b31d019ab8..e3861c97c2f512 100644 --- a/third_party/xla/xla/service/hlo_module_metadata_test.cc +++ b/third_party/xla/xla/service/hlo_module_metadata_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_module_metadata.h" -#include "xla/test.h" -#include "xla/test_helpers.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/hlo_module_test.cc b/third_party/xla/xla/service/hlo_module_test.cc index 6c673f56f914c7..718b1cb6fad7bc 100644 --- a/third_party/xla/xla/service/hlo_module_test.cc +++ b/third_party/xla/xla/service/hlo_module_test.cc @@ -36,6 +36,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_original_value.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/hlo/utils/hlo_matchers.h" @@ -46,7 +47,6 @@ limitations under the License. #include "xla/service/test_compilation_environment.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/strings/proto_serialization.h" diff --git a/third_party/xla/xla/service/hlo_proto_util_test.cc b/third_party/xla/xla/service/hlo_proto_util_test.cc index d5ef461e58878e..b32a85bb56ae23 100644 --- a/third_party/xla/xla/service/hlo_proto_util_test.cc +++ b/third_party/xla/xla/service/hlo_proto_util_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include "xla/service/hlo_proto_util.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/types.h" diff --git a/third_party/xla/xla/service/hlo_schedule_test.cc b/third_party/xla/xla/service/hlo_schedule_test.cc index 8b4db06aac1a57..74ff22abcf44a6 100644 --- a/third_party/xla/xla/service/hlo_schedule_test.cc +++ b/third_party/xla/xla/service/hlo_schedule_test.cc @@ -25,13 +25,13 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/literal_util.h" #include "xla/service/buffer_value.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/hlo_sharding_test.cc b/third_party/xla/xla/service/hlo_sharding_test.cc index 221a8cb5538e6b..2bc8f320ff3497 100644 --- a/third_party/xla/xla/service/hlo_sharding_test.cc +++ b/third_party/xla/xla/service/hlo_sharding_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "absl/hash/hash.h" #include "absl/types/span.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/util/proto/proto_matchers.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/layout_assignment_test.cc b/third_party/xla/xla/service/layout_assignment_test.cc index 7e9d55e75bb0fb..0c0d9c5c5e508d 100644 --- a/third_party/xla/xla/service/layout_assignment_test.cc +++ b/third_party/xla/xla/service/layout_assignment_test.cc @@ -31,6 +31,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/layout.h" #include "xla/layout_util.h" @@ -42,8 +44,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/loop_schedule_linearizer_test.cc b/third_party/xla/xla/service/loop_schedule_linearizer_test.cc index 3d652478eb8713..c0eb33a1d467af 100644 --- a/third_party/xla/xla/service/loop_schedule_linearizer_test.cc +++ b/third_party/xla/xla/service/loop_schedule_linearizer_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/service/copy_insertion.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/map_inliner_test.cc b/third_party/xla/xla/service/map_inliner_test.cc index c9387108a19fae..d3cedb796d825d 100644 --- a/third_party/xla/xla/service/map_inliner_test.cc +++ b/third_party/xla/xla/service/map_inliner_test.cc @@ -22,12 +22,12 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc b/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc index bb1b55ccdd646b..89651978105925 100644 --- a/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc +++ b/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include #include "absl/functional/bind_front.h" #include "absl/log/log.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/third_party/xla/xla/service/pattern_matcher_gmock_test.cc b/third_party/xla/xla/service/pattern_matcher_gmock_test.cc index 899f909c983df0..dc581173c7612d 100644 --- a/third_party/xla/xla/service/pattern_matcher_gmock_test.cc +++ b/third_party/xla/xla/service/pattern_matcher_gmock_test.cc @@ -20,12 +20,12 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/service/pattern_matcher_test.cc b/third_party/xla/xla/service/pattern_matcher_test.cc index fe3d92d29c12b3..21b902ebcf6b18 100644 --- a/third_party/xla/xla/service/pattern_matcher_test.cc +++ b/third_party/xla/xla/service/pattern_matcher_test.cc @@ -25,12 +25,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/scatter_determinism_expander_test.cc b/third_party/xla/xla/service/scatter_determinism_expander_test.cc index 81078b0da54499..b530a8d23b77f3 100644 --- a/third_party/xla/xla/service/scatter_determinism_expander_test.cc +++ b/third_party/xla/xla/service/scatter_determinism_expander_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/scatter_expander_test.cc b/third_party/xla/xla/service/scatter_expander_test.cc index 587780a32ff3cf..bd0b9b3a84110b 100644 --- a/third_party/xla/xla/service/scatter_expander_test.cc +++ b/third_party/xla/xla/service/scatter_expander_test.cc @@ -24,10 +24,10 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index a4f0693f4d2e0f..1a09139afcaf06 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -34,10 +34,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" diff --git a/third_party/xla/xla/service/shaped_buffer_test.cc b/third_party/xla/xla/service/shaped_buffer_test.cc index b07e246e33b43d..6e0d17b30b747f 100644 --- a/third_party/xla/xla/service/shaped_buffer_test.cc +++ b/third_party/xla/xla/service/shaped_buffer_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/shape_tree.h" @@ -29,7 +30,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/xla/service/space_to_batch_converter_test.cc b/third_party/xla/xla/service/space_to_batch_converter_test.cc index 6473c65dccf73b..b81e0681b3db02 100644 --- a/third_party/xla/xla/service/space_to_batch_converter_test.cc +++ b/third_party/xla/xla/service/space_to_batch_converter_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/stream_pool_test.cc b/third_party/xla/xla/service/stream_pool_test.cc index 2bea4119a4d9b7..92da6b7e51dc3e 100644 --- a/third_party/xla/xla/service/stream_pool_test.cc +++ b/third_party/xla/xla/service/stream_pool_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "xla/hlo/testlib/test_helpers.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/test_helpers.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/transpose_folding_test.cc b/third_party/xla/xla/service/transpose_folding_test.cc index e00b6adeafb79b..1ea8143f4d4b22 100644 --- a/third_party/xla/xla/service/transpose_folding_test.cc +++ b/third_party/xla/xla/service/transpose_folding_test.cc @@ -25,13 +25,13 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/shape_inference.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/service/tuple_util_test.cc b/third_party/xla/xla/service/tuple_util_test.cc index 6e91ad17f7e12d..886eaa2e3cde0f 100644 --- a/third_party/xla/xla/service/tuple_util_test.cc +++ b/third_party/xla/xla/service/tuple_util_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/while_loop_constant_sinking_test.cc b/third_party/xla/xla/service/while_loop_constant_sinking_test.cc index 2cfd69a9254e8b..92e0706863cc46 100644 --- a/third_party/xla/xla/service/while_loop_constant_sinking_test.cc +++ b/third_party/xla/xla/service/while_loop_constant_sinking_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc index d8b8eb620935ff..a44b940cffe57e 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc @@ -22,11 +22,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/while_loop_simplifier_test.cc b/third_party/xla/xla/service/while_loop_simplifier_test.cc index a478453e4aa881..d7543dc7573045 100644 --- a/third_party/xla/xla/service/while_loop_simplifier_test.cc +++ b/third_party/xla/xla/service/while_loop_simplifier_test.cc @@ -25,13 +25,13 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/while_util_test.cc b/third_party/xla/xla/service/while_util_test.cc index e2162a841d599e..d4686557946ea3 100644 --- a/third_party/xla/xla/service/while_util_test.cc +++ b/third_party/xla/xla/service/while_util_test.cc @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" From 3b77c9ab6dd0f3f81ddea49d825c7197c15aad0a Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 9 Apr 2025 11:34:30 -0700 Subject: [PATCH 0459/1324] Automated Code Change PiperOrigin-RevId: 745669468 --- tensorflow/dtensor/mlir/BUILD | 25 ++++++------------- tensorflow/dtensor/mlir/dtensor_dialect/BUILD | 24 +++++------------- 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index d3bdcb73f0839e..08ffac8e75b859 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -20,16 +20,10 @@ package( gentbl_cc_library( name = "tensorflow_dtensor_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_dtensor.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_dtensor.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_dtensor.h.inc": ["-gen-op-decls"], + "ir/tf_dtensor.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_dtensor.td", td_srcs = [ @@ -48,13 +42,10 @@ gentbl_cc_library( gentbl_cc_library( name = "dtensor_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [( - [ - "-gen-pass-decls", - "-name=DTensor", - ], - "dtensor_passes.h.inc", - )], + tbl_outs = {"dtensor_passes.h.inc": [ + "-gen-pass-decls", + "-name=DTensor", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "Passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD index c0cca6cf3846f4..e32cb17a0bdaa2 100644 --- a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD +++ b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD @@ -33,24 +33,12 @@ td_library( gentbl_cc_library( name = "DialectIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "ir/dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "ir/dialect.cc.inc", - ), - ], + tbl_outs = { + "ir/ops.h.inc": ["-gen-op-decls"], + "ir/ops.cc.inc": ["-gen-op-defs"], + "ir/dialect.h.inc": ["-gen-dialect-decls"], + "ir/dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/dtensor_ops.td", deps = 
[":dtensor_td_files"], From 92ea23d5c602aa6a929b67d1f97612b5d44b97c3 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Wed, 9 Apr 2025 12:22:37 -0700 Subject: [PATCH 0460/1324] Fix direct HLO->StableHLO conversion for programs with bounded dynamism PiperOrigin-RevId: 745687946 --- .../xla/xla/hlo/tools/hlo_translate.cc | 37 +++++++++------- third_party/xla/xla/hlo/translate/BUILD | 1 + .../hlo_to_mhlo/hlo_function_importer.cc | 6 ++- .../hlo_to_mhlo/hlo_module_importer.cc | 6 +++ .../import_bounded_dynamism_stablehlo.mlir | 13 ++++++ .../xla/xla/hlo/translate/stablehlo.cc | 43 ++++++++++++++++++- 6 files changed, 86 insertions(+), 20 deletions(-) create mode 100644 third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism_stablehlo.mlir diff --git a/third_party/xla/xla/hlo/tools/hlo_translate.cc b/third_party/xla/xla/hlo/tools/hlo_translate.cc index eab8ff6f5abbf3..162671f87bee3f 100644 --- a/third_party/xla/xla/hlo/tools/hlo_translate.cc +++ b/third_party/xla/xla/hlo/tools/hlo_translate.cc @@ -97,14 +97,21 @@ bool LoadHloProto(const std::string& contents, xla::HloProto* hlo_proto) { constexpr char kLoadHloError[] = "Failed to parse HLO."; absl::StatusOr> GetModuleFromHLOText( - absl::string_view content, mlir::MLIRContext* context) { + absl::string_view content, mlir::MLIRContext* context, bool emit_mhlo) { auto hlo_text = xla::ParseAndReturnUnverifiedModule( content, {}, xla::HloParserOptions().set_keep_module_auto_layouts(true)); if (!hlo_text.ok()) return absl::InvalidArgumentError(kLoadHloError); + auto hlo_module = std::move(hlo_text.value()); + + // For emitting StableHLO, use new APIs by defualt. + if (!emit_mhlo) { + return xla::ConvertHloToStablehlo(*context, hlo_module.get()); + } + + // For MHLO require legacy API for now. mlir::OwningOpRef module = xla::llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context)); - auto hlo_module = std::move(hlo_text.value()); auto status = ConvertHloToMlirHlo(*module, hlo_module.get(), /*import_all_computations=*/true, /*flatten_computation_args_result*/ true); @@ -113,11 +120,17 @@ absl::StatusOr> GetModuleFromHLOText( } absl::StatusOr> GetModuleFromHLOProto( - const std::string& content, mlir::MLIRContext* context) { + const std::string& content, mlir::MLIRContext* context, bool emit_mhlo) { xla::HloProto hlo_proto; if (!LoadHloProto(content, &hlo_proto)) return absl::InvalidArgumentError(kLoadHloError); + // For emitting StableHLO, use new APIs by defualt. + if (!emit_mhlo) { + return xla::ConvertHloToStablehlo(*context, hlo_proto.mutable_hlo_module()); + } + + // For MHLO require legacy API for now. 
   mlir::OwningOpRef<mlir::ModuleOp> module =
       xla::llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context));
   auto status =
@@ -130,7 +143,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GetModuleFromHLOProto(

 mlir::OwningOpRef<mlir::ModuleOp> GetModuleFromHloInput(
     const std::shared_ptr<llvm::SourceMgr>& source_mgr,
-    mlir::MLIRContext* context) {
+    mlir::MLIRContext* context, bool emit_mhlo) {
   const llvm::MemoryBuffer* input =
       source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
   absl::string_view content =
@@ -144,7 +157,7 @@ mlir::OwningOpRef<mlir::ModuleOp> GetModuleFromHloInput(
   };

   // Try HLO Text
-  auto module_from_text = GetModuleFromHLOText(content, context);
+  auto module_from_text = GetModuleFromHLOText(content, context, emit_mhlo);
   if (module_from_text.ok()) return std::move(module_from_text.value());
   if (module_from_text.status().message() != kLoadHloError) {
     emitError() << "Failed to convert HLO to MLIR: "
@@ -153,7 +166,8 @@ mlir::OwningOpRef<mlir::ModuleOp> GetModuleFromHloInput(
   }

   // Try HLO Proto
-  auto module_from_proto = GetModuleFromHLOProto(std::string(content), context);
+  auto module_from_proto =
+      GetModuleFromHLOProto(std::string(content), context, emit_mhlo);
   if (module_from_proto.ok()) return std::move(module_from_proto.value());
   if (module_from_proto.status().message() != kLoadHloError) {
     emitError() << "Failed to convert HLO to MLIR: "
@@ -172,19 +186,10 @@ static mlir::OwningOpRef<mlir::ModuleOp> HloToMlirTranslate(
     const std::shared_ptr<llvm::SourceMgr>& sourceMgr,
     mlir::MLIRContext* context) {
   mlir::OwningOpRef<mlir::ModuleOp> module =
-      GetModuleFromHloInput(sourceMgr, context);
+      GetModuleFromHloInput(sourceMgr, context, emit_mhlo);
   if (!module) return nullptr;

-  if (emit_mhlo) return module;
-
-  mlir::PassManager pm(context);
-  pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
-  if (failed(pm.run(*module))) {
-    module->emitError("Failed to legalize to StableHLO");
-    return nullptr;
-  }
-
   return module;
 }
diff --git a/third_party/xla/xla/hlo/translate/BUILD b/third_party/xla/xla/hlo/translate/BUILD
index 4a5bcf1e127909..ef35b6f67a6276 100644
--- a/third_party/xla/xla/hlo/translate/BUILD
+++ b/third_party/xla/xla/hlo/translate/BUILD
@@ -95,6 +95,7 @@ cc_library(
         "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
         "//xla/hlo/translate/mhlo_to_hlo:module_attributes_exporter",
         "//xla/mlir/utils:error_util",
+        "//xla/mlir_hlo",
         "//xla/mlir_hlo:hlo_dialect_registration",
         "//xla/mlir_hlo:mhlo_passes",
         "//xla/mlir_hlo:stablehlo_extension_passes",
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
index eef67b1c3ef9da..177f3fc0386dfb 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -2325,8 +2325,10 @@ HloFunctionImporter::ImportInstructionWithLayout(
   }

   // Print generic in debug since module may be invalid while printing.
-  LLVM_DEBUG(
-      op->print(llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()));
+  LLVM_DEBUG({
+    op->print(llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });

   // See MlirToHloConversionOptions for more about layouts.
   //
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc
index 0a65ba01b4a3e3..6958f7ae0c3736 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/status/status.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -38,6 +39,8 @@ limitations under the License. #include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" +#define DEBUG_TYPE "xla-translate" + namespace xla { HloModuleImporter::HloModuleImporter(mlir::ModuleOp module, @@ -91,7 +94,9 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { &function_map_, &builder_, /*is_main*/ true, flatten_computation_args_result_) .status()); + // Convert all ops to MHLO + LLVM_DEBUG(llvm::dbgs() << "Emit StableHLO: " << emit_stablehlo_ << "\n"); if (!emit_stablehlo_) { TF_RETURN_IF_ERROR(ConvertToMhlo(module)); } @@ -114,6 +119,7 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { hlo_module, module, flatten_computation_args_result_, builder_)); // Convert all ops to MHLO + LLVM_DEBUG(llvm::dbgs() << "Emit StableHLO: " << emit_stablehlo_ << "\n"); if (!emit_stablehlo_) { TF_RETURN_IF_ERROR(ConvertToMhlo(module)); } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism_stablehlo.mlir b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism_stablehlo.mlir new file mode 100644 index 00000000000000..ab00b96b84f27c --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism_stablehlo.mlir @@ -0,0 +1,13 @@ +// RUN: hlo-translate -hlo-to-mlir -split-input-file -verify-diagnostics %s | FileCheck %s + +HloModule main, entry_computation_layout={(f32[16,50]{1,0}, s64[1,<=16]{1,0})->f32[<=16,50]{1,0}} + +// CHECK-LABEL: main +// CHECK: stablehlo.reshape {{.*}} (tensor<1x?xi64, #stablehlo.bounds>) -> tensor> +// CHECK-NEXT: "stablehlo.gather"{{.*}} : (tensor<16x50xf32>, tensor>) -> tensor> +ENTRY %main.5 (Arg_0.1: f32[16,50], Arg_1.2: s64[1,<=16]) -> f32[<=16,50] { + %Arg_0.1 = f32[16,50] parameter(0) + %Arg_1.2 = s64[1,<=16] parameter(1) + %reshape.3 = s64[<=16] reshape(%Arg_1.2), metadata={source_file="/tmp/t.mlir" source_line=3} + ROOT %gather.4 = f32[<=16,50] gather(%Arg_0.1, %reshape.3), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,50}, metadata={source_file="/tmp/t.mlir" source_line=4} +} diff --git a/third_party/xla/xla/hlo/translate/stablehlo.cc b/third_party/xla/xla/hlo/translate/stablehlo.cc index f7642181220077..cde29fc4125e35 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo.cc @@ -18,9 +18,13 @@ limitations under the License. #include #include +#include "mhlo/transforms/passes.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -28,9 +32,11 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "stablehlo/dialect/Register.h" @@ -40,6 +46,7 @@ limitations under the License. 
#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h" #include "xla/mlir/utils/error_util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/mlir_hlo/stablehlo_ext/transforms/passes.h" @@ -47,15 +54,47 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/tsl/platform/errors.h" +#define DEBUG_TYPE "xla-translate" + namespace xla { namespace { + +bool isBoundedDynamic(mlir::Type type) { + LLVM_DEBUG(llvm::dbgs() << "isBoundedDynamic: " << type << "\n"); + if (!llvm::isa(type)) { + return false; + } + auto encoding = llvm::cast(type).getEncoding(); + return encoding && llvm::isa(encoding); +} + +bool hasBoundedDynamism(mlir::ModuleOp module) { + bool has_bounded_dynamism = false; + module->walk([&](mlir::Operation* op) { + auto results = op->getResultTypes(); + has_bounded_dynamism |= llvm::any_of(results, isBoundedDynamic); + if (has_bounded_dynamism) { + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + return has_bounded_dynamism; +} + absl::Status MhloToStablehlo(mlir::ModuleOp module) { + LLVM_DEBUG(llvm::dbgs() << "MHLO to StableHLO\n"); auto context = module.getContext(); mlir::PassManager pm(context); mlir::BaseScopedDiagnosticHandler diag_handler(context); mlir::mhlo::HloLegalizeToStablehloPassOptions options; options.allow_xla_features_ = true; + bool has_bounded_dynamism = hasBoundedDynamism(module); + if (has_bounded_dynamism) { + // Need to converge program to MHLO before StableHLO in the presence of + // bounded dynamism. + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + } pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass(options)); if (failed(pm.run(module))) { return diag_handler.ConsumeStatus(); @@ -151,7 +190,7 @@ absl::StatusOr> ConvertHloToStablehlo( TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), /*import_all_computation=*/true, /*flatten_computation_args_result=*/true, - /*emit_stablehlo=*/true) + /*emit_stablehlo=*/false) .Import(*hlo_module)); TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); return std::move(mlir_module); @@ -164,7 +203,7 @@ absl::StatusOr> ConvertHloToStablehlo( TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), /*import_all_computation=*/true, /*flatten_computation_args_result=*/true, - /*emit_stablehlo=*/true) + /*emit_stablehlo=*/false) .Import(*hlo_module_proto)); TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); return std::move(mlir_module); From fc08d9dcca73ce184c2f8e95a1fd20a3ca230395 Mon Sep 17 00:00:00 2001 From: Robert David Date: Wed, 9 Apr 2025 12:53:26 -0700 Subject: [PATCH 0461/1324] Use `==` instead of `*_EQ` macros checking equivalence with `nullptr`. PiperOrigin-RevId: 745699112 --- .../lite/kernels/bidirectional_sequence_lstm.cc | 3 +-- tensorflow/lite/kernels/lstm.cc | 8 ++++---- .../lite/kernels/unidirectional_sequence_lstm.cc | 12 +++++++++--- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index 40f3b812825497..6472da7ca6601b 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/lstm_eval.h" @@ -320,7 +319,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes( const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, input_gate_bias_tensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + TF_LITE_ENSURE(context, input_gate_bias == nullptr); } else { TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index 9f74aad97553a0..a88ba32428f2e7 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -25,10 +25,10 @@ limitations under the License. #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -982,7 +982,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, kInputGateBiasTensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + TF_LITE_ENSURE(context, input_gate_bias == nullptr); } else { TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); @@ -1061,7 +1061,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( context, node, kInputLayerNormCoefficientsTensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr); + TF_LITE_ENSURE(context, input_layer_norm_coefficients == nullptr); } else { TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr); TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index 04167e897fb808..518b5c7d69bccb 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -16,14 +16,20 @@ limitations under the License. 
#include #include +#include #include +#include +#include +#include +#include #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -562,7 +568,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + TF_LITE_ENSURE(context, input_gate_bias == nullptr); } else { TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); @@ -642,7 +648,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( context, node, lstm::full::kInputLayerNormCoefficientsTensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr); + TF_LITE_ENSURE(context, input_layer_norm_coefficients == nullptr); } else { TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr); TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); From c300b7dece76084bc25ff1a5c4e2a47cab6f57b0 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 9 Apr 2025 12:55:24 -0700 Subject: [PATCH 0462/1324] [xla] Use std::weak_ptr to keep state for in-progress rendezvous PiperOrigin-RevId: 745699808 --- third_party/xla/xla/service/rendezvous.h | 106 ++++++++---------- .../xla/xla/service/rendezvous_test.cc | 20 ++-- 2 files changed, 59 insertions(+), 67 deletions(-) diff --git a/third_party/xla/xla/service/rendezvous.h b/third_party/xla/xla/service/rendezvous.h index 86488963e16744..dbc2fae612fa14 100644 --- a/third_party/xla/xla/service/rendezvous.h +++ b/third_party/xla/xla/service/rendezvous.h @@ -214,15 +214,15 @@ struct RendezvousState : public RendezvousStateSynchronization { // Rendezvous state ownership: // // (1) When rendezvous participant initiates a rendezvous with a particular key -// we create a new state for it, keep it in a map for tracking and return a -// shared pointer to the caller. +// we create a new state for it, keep it in a map as weak pointer for +// tracking and return a shared pointer to the caller. // // (2) When rendezvous participant joins in-progress rendezvous it gets back // a shared pointer that is copied from a tracking map. // -// (3) When the last rendezvous participant computes the result it completes the -// rendezvous and removes a shared pointer to a state. Remaining shared -// pointers destructed when all participants are notified. +// (3) When all rendezvous participants complete the rendezvous, shared pointers +// are destructed and the tracking map will have an expired weak pointer, +// that will be lazily garbage collected by the next rendezvous. // // This process guarantees that all completed rendezvous are removed from a map // and a map has records only for rendezvous in progress. 
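To make the three-step ownership protocol above concrete, here is a minimal standalone sketch of the same weak_ptr tracking pattern; TrackingMap and State are illustrative names for this sketch, not part of the XLA API:

  #include <cstddef>
  #include <map>
  #include <memory>
  #include <mutex>

  struct State {
    explicit State(std::size_t n) : num_threads(n) {}
    std::size_t num_threads;
  };

  class TrackingMap {
   public:
    std::shared_ptr<State> Join(int key, std::size_t num_threads) {
      std::lock_guard<std::mutex> lock(mutex_);
      // Step (3): lazily erase entries whose participants have all released
      // their shared pointers, leaving an expired weak pointer behind.
      for (auto it = state_.begin(); it != state_.end();) {
        if (it->second.expired()) {
          it = state_.erase(it);
        } else {
          ++it;
        }
      }
      // Step (2): join an in-progress rendezvous if one exists...
      std::weak_ptr<State>& in_progress = state_[key];
      if (std::shared_ptr<State> joined = in_progress.lock()) {
        return joined;
      }
      // Step (1): ...otherwise start a new one and track it weakly.
      auto started = std::make_shared<State>(num_threads);
      in_progress = started;
      return started;
    }

   private:
    std::mutex mutex_;
    std::map<int, std::weak_ptr<State>> state_;
  };

Because the map holds only weak pointers, a completed rendezvous never pins its state alive, and cleanup needs no explicit Complete() bookkeeping.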
@@ -233,56 +233,25 @@ class RendezvousMap { std::shared_ptr Join(const K& key, size_t num_threads) { absl::MutexLock lock(&mutex_); - std::shared_ptr& state = state_[key]; - // Join an in-progress rendezvous. - if (state) return state; + // Erase expired rendezvous from the map. + absl::erase_if(state_, [](const auto& e) { return e.second.expired(); }); - // Join a newly created rendezvous. - return state = std::make_shared(num_threads); - } - - template - void Complete(const K& key, Result&& result) { - std::shared_ptr state = [&] { - absl::MutexLock lock(&mutex_); - - // Extract state from the map so we can immediately start a new round of - // rendezvous with the same key. A state for previous rendezvous will be - // destructed with the last copy of a shared pointer. - std::shared_ptr state = state_.extract(key).mapped(); - - // Check that we have have exactly the number of participants we expected: - // +1 reference for all participants and a +1 reference we extracted. - CHECK_EQ(state.use_count(), 1 + state->values.size()); // NOLINT + std::weak_ptr& in_progress = state_[key]; - return state; - }(); - - // We notify awaiting participants without holding a rendezvous map lock, as - // the rendezvous callback might be an expensive operation and might block - // the progress of concurrent rendezvous for other keys. - - // Publish rendezvous result to all participants. - if constexpr (IsStatusOrResult::value) { - if (ABSL_PREDICT_TRUE(result.ok())) { - state->result = std::make_shared(*std::forward(result)); - } else { - state->result = result.status(); - } - } else { - state->result = std::make_shared(std::forward(result)); + // Try to join an in-progress rendezvous for a given key. + if (std::shared_ptr joined = in_progress.lock()) { + return joined; } - // Notify awaiting participants that result is ready. - absl::MutexLock lock(&state->mutex); - state->ready = true; - state->cv.SignalAll(); + // Start a new rendezvous for a given key. + std::shared_ptr start = std::make_shared(num_threads); + return (in_progress = start, start); } private: absl::Mutex mutex_; - absl::flat_hash_map> state_ ABSL_GUARDED_BY(mutex_); + absl::flat_hash_map> state_ ABSL_GUARDED_BY(mutex_); }; void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, @@ -295,6 +264,22 @@ void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, // Rendezvous implemenetation. //===----------------------------------------------------------------------===// +template +absl::StatusOr> InvokeRendezvous( + Fn fn, absl::Span values) { + auto result = fn(values); + + if constexpr (internal::IsStatusOrResult::value) { + if (ABSL_PREDICT_TRUE(result.ok())) { + return std::make_shared(*std::move(result)); + } else { + return result.status(); + } + } else { + return std::make_shared(std::move(result)); + } +} + template absl::StatusOr> Rendezvous( absl::string_view name, const K& key, const V& value, size_t num_threads, @@ -307,16 +292,7 @@ absl::StatusOr> Rendezvous( // Fast-path (DO NOT REMOVE: the logic below doesn't work for single thread). 
   if (num_threads == 1) {
     const V* ptr = &value;
-    auto result = fn(absl::MakeSpan(&ptr, 1));
-
-    if constexpr (internal::IsStatusOrResult<Result>::value) {
-      if (ABSL_PREDICT_TRUE(result.ok())) {
-        return std::make_shared<R>(*std::move(result));
-      }
-      return result.status();
-    } else {
-      return std::make_shared<R>(std::move(result));
-    }
+    return InvokeRendezvous<R>(std::move(fn), absl::MakeSpan(&ptr, 1));
   }

   using State = internal::RendezvousState<V>;
@@ -359,9 +335,23 @@ absl::StatusOr<std::shared_ptr<R>> Rendezvous(
     // be notified via `state->ready` flag when result is ready, and we rely on
     // the store to a flag to create a memory barrier that makes access to
     // `state->result` safe without any extra synchronization.
-    tsl::profiler::TraceMe trace("ExecuteRendezvousCallback");
+    tsl::profiler::TraceMe trace("InvokeRendezvous");
     absl::Span<const V* const> values(state->values.data(), num_threads);
-    rendezvous.Complete(key, fn(values));
+
+    // Check that we have exactly the number of participants we expect.
+    CHECK_EQ(state.use_count(), num_threads);  // NOLINT
+
+    // Publish rendezvous result to all participants.
+    state->result = InvokeRendezvous<R>(std::move(fn), values);
+
+    // Switch `ready` flag to signal all participants that result is ready.
+    {
+      absl::MutexLock lock(&state->mutex);
+      state->ready = true;
+    }
+
+    // Notify awaiting participants that result is ready.
+    state->cv.SignalAll();
   }

   return state->result;
diff --git a/third_party/xla/xla/service/rendezvous_test.cc b/third_party/xla/xla/service/rendezvous_test.cc
index f9fbb1c8287e20..70725f135ea24e 100644
--- a/third_party/xla/xla/service/rendezvous_test.cc
+++ b/third_party/xla/xla/service/rendezvous_test.cc
@@ -37,8 +37,8 @@ limitations under the License.
 namespace xla {
 namespace {

-absl::Duration Timeout() { return absl::Seconds(10); }
-absl::Duration Terminate() { return absl::Seconds(10); }
+absl::Duration Timeout() { return absl::Seconds(5); }
+absl::Duration Terminate() { return absl::Seconds(5); }

 tsl::thread::ThreadPool CreateThreadPool(int32_t size) {
   return tsl::thread::ThreadPool(tsl::Env::Default(), "rendezvous_test", size);
@@ -268,8 +268,9 @@ static void BM_Rendezvous(benchmark::State& state) {
     absl::BlockingCounter counter(num_threads);
     for (int64_t i = 0; i < num_threads; ++i) {
       thread_pool.Schedule([&] {
-        CHECK_OK(Rendezvous("rendezvous_test", 0, num_threads,
-                            [] { return 42; }));
+        CHECK_OK(Rendezvous(
+            "rendezvous_test", /*key=*/0, num_threads, [] { return 42; },
+            Timeout(), Terminate()));
         counter.DecrementCount();
       });
     }
@@ -285,9 +286,9 @@ static void BM_RendezvousWithValues(benchmark::State& state) {
     absl::BlockingCounter counter(num_threads);
     for (int64_t i = 0; i < num_threads; ++i) {
       thread_pool.Schedule([&, i] {
-        int32_t value = i;
-        CHECK_OK(Rendezvous("rendezvous_test", 0, value, num_threads,
-                            [](auto) { return 42; }));
+        CHECK_OK(Rendezvous(
+            "rendezvous_test", /*key=*/0, /*value=*/i, num_threads,
+            [](auto) { return 42; }, Timeout(), Terminate()));
         counter.DecrementCount();
       });
     }
@@ -306,8 +307,9 @@ static void BM_GroupedRendezvous(benchmark::State& state) {
   for (int64_t group = 0; group < num_groups; ++group) {
     for (int64_t i = 0; i < group_size; ++i) {
       thread_pool.Schedule([&, group] {
-        CHECK_OK(Rendezvous("rendezvous_test", group, group_size,
-                            [] { return 42; }));
+        CHECK_OK(Rendezvous(
+            "rendezvous_test", /*key=*/group, /*num_threads=*/group_size,
+            [] { return 42; }, Timeout(), Terminate()));
         counter.DecrementCount();
       });
     }

From 8f343cbf73155ef63e5544d707b26a48aee1af30 Mon Sep 17 00:00:00 2001
From: "A.
Unique TensorFlower" Date: Wed, 9 Apr 2025 13:00:46 -0700 Subject: [PATCH 0463/1324] Ensure control dependencies are safely removed before removing instructions, in preparation for adding more control dependencies during copy insertion. PiperOrigin-RevId: 745701678 --- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/cpu/small_while_loop_hoisting_pass.cc | 1 + third_party/xla/xla/service/gpu/gpu_compiler_test.cc | 2 +- third_party/xla/xla/service/instruction_fusion.cc | 2 ++ 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 1bbcd3de3db5a5..86ff582f47a9be 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -2022,6 +2022,7 @@ cc_library( "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc b/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc index 767e914a075992..c599b5f3ff6425 100644 --- a/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc +++ b/third_party/xla/xla/service/cpu/small_while_loop_hoisting_pass.cc @@ -127,6 +127,7 @@ absl::StatusOr SmallWhileLoopHoistingPass::Run( call_instruction->add_frontend_attribute("xla_cpu_small_call", "true"); TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(call_instruction)); + TF_RETURN_IF_ERROR(while_instr->SafelyDropAllControlDependencies()); TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr)); changed = true; diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index feaba89cb931b6..bc8bd7a3fde0ec 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -837,7 +837,7 @@ CHECK: %[[RESULT_RECV:.*]] = recv(%[[AFTER_ALL]]) CHECK-SAME: channel_id=[[CHANNEL_ID]] CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}}, -CHECK-SAME: control-predecessors={%[[CUSTOM_CALL]]} +CHECK-SAME: control-predecessors={%[[CUSTOM_CALL:.*]]} CHECK: %[[RESULT_SEND:.*]] = send(%[[SOME_SEND_ARG:.*]], %[[AFTER_ALL]]) CHECK-SAME: channel_id=1 CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", diff --git a/third_party/xla/xla/service/instruction_fusion.cc b/third_party/xla/xla/service/instruction_fusion.cc index d9f66a552a5266..f75aaf57fd130e 100644 --- a/third_party/xla/xla/service/instruction_fusion.cc +++ b/third_party/xla/xla/service/instruction_fusion.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/errors.h" // The source_location.h is not available in open source. #if defined(PLATFORM_GOOGLE) #include "absl/types/source_location.h" @@ -727,6 +728,7 @@ absl::StatusOr InstructionFusion::Run( // Operand is now dead. Remove from queue. fusion_queue->RemoveInstruction(operand); // Remove from computation. 
+ TF_RETURN_IF_ERROR(operand->SafelyDropAllControlDependencies()); TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand)); } From a9fb074d5b4c8a11d735658e9fdc4041b7cc278b Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Wed, 9 Apr 2025 13:03:00 -0700 Subject: [PATCH 0464/1324] [XLA:Upkeep] Resolve 2 instances of the following issue: Todo (resolved) PiperOrigin-RevId: 745702499 --- third_party/xla/xla/service/buffer_assignment_test.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/third_party/xla/xla/service/buffer_assignment_test.cc b/third_party/xla/xla/service/buffer_assignment_test.cc index 8f50ff8dbd6801..4bc7e2935c98c9 100644 --- a/third_party/xla/xla/service/buffer_assignment_test.cc +++ b/third_party/xla/xla/service/buffer_assignment_test.cc @@ -1771,8 +1771,6 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { GetTopLevelAllocation(*assignment, tuple_element)); } -// TODO(b/32248867): Enable when buffer assignment gives allocations to -// constants. TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. @@ -1956,8 +1954,6 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { GetTopLevelAllocation(*assignment, bitcast)); } -// TODO(b/34669761): Remove this test when buffers are allowed to share -// allocations. TEST_F(BufferAssignmentTest, TupleBufferNotReused) { // Test a computation that returns a tuple parameter. auto builder = HloComputation::Builder(TestName()); From 3ff1aa33e561bfcbf4a01be1a2c9cd3931b7cefb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Apr 2025 13:04:08 -0700 Subject: [PATCH 0465/1324] Adds IOPDDL utilities to XLA Auto Sharding's third_party directory. PiperOrigin-RevId: 745702969 --- .../xla/hlo/experimental/auto_sharding/BUILD | 33 +++++ .../experimental/auto_sharding/example.json | 45 ++++++ .../hlo/experimental/auto_sharding/iopddl.cc | 133 ++++++++++++++++++ .../hlo/experimental/auto_sharding/iopddl.h | 81 +++++++++++ .../experimental/auto_sharding/iopddl_test.cc | 130 +++++++++++++++++ .../hlo/experimental/auto_sharding/solver.cc | 65 +++++++++ .../hlo/experimental/auto_sharding/solver.h | 37 +++++ 7 files changed, 524 insertions(+) create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/example.json create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/iopddl.cc create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/iopddl.h create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/iopddl_test.cc create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/solver.cc create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/solver.h diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 9499e2e5380bce..1615747d70a1bd 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -440,3 +440,36 @@ xla_cc_test( "@local_tsl//tsl/platform:statusor", ] + if_google(["@com_google_ortools//ortools/linear_solver:linear_solver_scip"]), ) + +cc_library( + name = "iopddl_lib", + srcs = [ + "iopddl.cc", + "solver.cc", + ], + hdrs = [ + "iopddl.h", + "solver.h", + ], + deps = [ + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +cc_test( + # 
CC_TEST_OK=tests imported from the IOPDDL library. + name = "iopddl_test", + srcs = ["iopddl_test.cc"], + data = ["example.json"], + deps = [ + ":iopddl_lib", + "//xla/tsl/platform:status_matchers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/example.json b/third_party/xla/xla/hlo/experimental/auto_sharding/example.json new file mode 100644 index 00000000000000..198ffd1621c09f --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/example.json @@ -0,0 +1,45 @@ +{ + "problem": { + "name": "example", + "nodes": { + "intervals": [ + [30, 70], + [40, 70], + [50, 120], + [110, 140], + [110, 150] + ], + "costs": [ + [15], + [55, 65], + [25, 45, 35], + [85, 75], + [95] + ], + "usages": [ + [10], + [25, 25], + [15, 20, 15], + [10, 10], + [15] + ] + }, + "edges": { + "nodes": [ + [0, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 4] + ], + "costs": [ + [30, 40], + [50, 10, 40], + [90, 10, 20, 80], + [60, 20, 30], + [70, 60] + ] + }, + "usage_limit": 50 + } +} diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/iopddl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/iopddl.cc new file mode 100644 index 00000000000000..88748eb9782084 --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/iopddl.cc @@ -0,0 +1,133 @@ +/* +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "xla/hlo/experimental/auto_sharding/iopddl.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +//////////////////////////////////////////////////////////////////////////////// +///////// Utilities for reading problems and evaluating solutions. ///////// +///////// Contest participants do not need to modify this code. 
/////////
+////////////////////////////////////////////////////////////////////////////////
+
+namespace iopddl {
+
+bool Strategy::operator==(const Strategy& other) const {
+  return cost == other.cost && usage == other.usage;
+}
+
+bool Node::operator==(const Node& other) const {
+  return interval == other.interval && strategies == other.strategies;
+}
+
+bool Edge::operator==(const Edge& other) const {
+  return nodes == other.nodes && strategies == other.strategies;
+}
+
+bool Problem::operator==(const Problem& other) const {
+  return name == other.name && nodes == other.nodes && edges == other.edges &&
+         usage_limit == other.usage_limit;
+}
+
+absl::StatusOr<TotalCost> Evaluate(const Problem& problem,
+                                   const Solution& solution) {
+  if (solution.size() != problem.nodes.size()) {
+    return absl::InvalidArgumentError("Incorrect solution size");
+  }
+  TimeIdx max_time = 0;
+  for (const Node& node : problem.nodes) {
+    max_time = std::max(max_time, node.interval.second);
+  }
+  TotalCost cost = 0;
+  std::vector<TotalUsage> total_usages(max_time);
+  for (NodeIdx node_idx = 0; node_idx < problem.nodes.size(); ++node_idx) {
+    const Node& node = problem.nodes[node_idx];
+    const StrategyIdx strategy_idx = solution[node_idx];
+    if (strategy_idx < 0 || strategy_idx >= (int64_t)node.strategies.size()) {
+      return absl::OutOfRangeError("Invalid strategy index");
+    }
+    cost += node.strategies[strategy_idx].cost;
+    for (TimeIdx t = node.interval.first; t < node.interval.second; ++t) {
+      total_usages[t] += node.strategies[strategy_idx].usage;
+    }
+  }
+  for (const Edge& edge : problem.edges) {
+    StrategyIdx strategy_idx = 0;
+    for (const NodeIdx node_idx : edge.nodes) {
+      strategy_idx *= problem.nodes[node_idx].strategies.size();
+      strategy_idx += solution[node_idx];
+    }
+    cost += edge.strategies[strategy_idx].cost;
+  }
+  if (problem.usage_limit) {
+    for (const TotalUsage& total_usage : total_usages) {
+      if (total_usage > *problem.usage_limit) {
+        return absl::ResourceExhaustedError("Usage limit exceeded");
+      }
+    }
+  }
+  return cost;
+}
+
+// TODO(moffitt): Re-implement this using an XLA-friendly library (e.g., jsoncpp).
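+// One possible shape for that rewrite, sketched here under the assumption
+// that the JSON schema matches the nlohmann::json code kept below (untested):
+//
+//   Json::Value data;
+//   std::ifstream(filename) >> data;
+//   for (const auto& interval : data["problem"]["nodes"]["intervals"]) {
+//     problem.nodes.push_back(
+//         {.interval = {interval[0].asInt64(), interval[1].asInt64()}});
+//   }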
+absl::StatusOr<Problem> ReadProblem(const std::string& filename) {
+/*
+  const nlohmann::json data = nlohmann::json::parse(std::ifstream(filename));
+  Problem problem = {.name = data["problem"]["name"]};
+  const auto& nodes = data["problem"]["nodes"];
+  for (const auto& node_interval : nodes["intervals"]) {
+    problem.nodes.push_back({.interval = {node_interval[0], node_interval[1]}});
+  }
+  for (NodeIdx node_idx = 0; node_idx < problem.nodes.size(); ++node_idx) {
+    Node& node = problem.nodes[node_idx];
+    const auto& costs = nodes["costs"][node_idx];
+    const auto& usages = nodes["usages"][node_idx];
+    node.strategies.reserve(costs.size());
+    for (StrategyIdx strategy_idx = 0; strategy_idx < costs.size();
+         ++strategy_idx) {
+      node.strategies.push_back(
+          {.cost = costs[strategy_idx], .usage = usages[strategy_idx]});
+    }
+  }
+  const auto& edges = data["problem"]["edges"];
+  for (const auto& node_list : edges["nodes"]) {
+    problem.edges.push_back({});
+    for (const NodeIdx node_idx : node_list) {
+      problem.edges.back().nodes.push_back(node_idx);
+    }
+  }
+  for (EdgeIdx edge_idx = 0; edge_idx < problem.edges.size(); ++edge_idx) {
+    Edge& edge = problem.edges[edge_idx];
+    for (const Cost cost : edges["costs"][edge_idx]) {
+      edge.strategies.push_back({.cost = cost, .usage = 0});
+    }
+  }
+  if (data["problem"].contains("usage_limit")) {
+    problem.usage_limit = data["problem"]["usage_limit"];
+  }
+  return problem;
+*/
+  return absl::UnimplementedError("ReadProblem is not implemented");
+}
+
+}  // namespace iopddl
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/iopddl.h b/third_party/xla/xla/hlo/experimental/auto_sharding/iopddl.h
new file mode 100644
index 00000000000000..9dbab5cbc38ca9
--- /dev/null
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/iopddl.h
@@ -0,0 +1,81 @@
+/*
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_IOPDDL_H_
+#define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_IOPDDL_H_
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/numeric/int128.h"
+#include "absl/status/statusor.h"
+
+////////////////////////////////////////////////////////////////////////////////
+///////// Basic definitions for problem & solution data structures. /////////
+///////// Contest participants do not need to modify this code. /////////
+////////////////////////////////////////////////////////////////////////////////
+
+namespace iopddl {
+
+using Cost = int64_t;
+using Usage = int64_t;
+using TimeIdx = int64_t;
+using NodeIdx = int64_t;
+using EdgeIdx = int64_t;
+using StrategyIdx = int64_t;
+using Interval = std::pair<TimeIdx, TimeIdx>;
+using Solution = std::vector<StrategyIdx>;
+using TotalUsage = absl::int128;
+using TotalCost = absl::int128;
+
+struct Strategy {
+  Cost cost;
+  Usage usage;
+  bool operator==(const Strategy& other) const;
+};
+
+struct Node {
+  Interval interval;  // Interpreted as half-open with an exclusive upper bound
+  std::vector<Strategy> strategies;
+  bool operator==(const Node& other) const;
+};
+
+struct Edge {
+  using Nodes = std::vector<NodeIdx>;
+  Nodes nodes;
+  std::vector<Strategy> strategies;
+  bool operator==(const Edge& other) const;
+};
+
+struct Problem {
+  std::string name;
+  std::vector<Node> nodes;
+  std::vector<Edge> edges;
+  std::optional<Usage> usage_limit;
+  bool operator==(const Problem& other) const;
+};
+
+absl::StatusOr<TotalCost> Evaluate(const Problem& problem,
+                                   const Solution& solution);
+
+absl::StatusOr<Problem> ReadProblem(const std::string& filename);
+
+}  // namespace iopddl
+
+#endif  // XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_IOPDDL_H_
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/iopddl_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/iopddl_test.cc
new file mode 100644
index 00000000000000..4dc2e528c7a411
--- /dev/null
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/iopddl_test.cc
@@ -0,0 +1,130 @@
+/*
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +#include "xla/hlo/experimental/auto_sharding/iopddl.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "xla/hlo/experimental/auto_sharding/solver.h" +#include "xla/tsl/platform/status_matchers.h" + +namespace iopddl { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +Problem GetExampleProblem() { + return { + .name = "example", + .nodes = + { + // Node 0 + {.interval = {30, 70}, .strategies = {{.cost = 15, .usage = 10}}}, + // Node 1 + {.interval = {40, 70}, + .strategies = {{.cost = 55, .usage = 25}, + {.cost = 65, .usage = 25}}}, + // Node 2 + {.interval = {50, 120}, + .strategies = {{.cost = 25, .usage = 15}, + {.cost = 45, .usage = 20}, + {.cost = 35, .usage = 15}}}, + // Node 3 + {.interval = {110, 140}, + .strategies = {{.cost = 85, .usage = 10}, + {.cost = 75, .usage = 10}}}, + // Node 4 + {.interval = {110, 150}, + .strategies = {{.cost = 95, .usage = 15}}}, + }, + .edges = + { + {.nodes = {0, 1}, .strategies = {{.cost = 30}, {.cost = 40}}}, + {.nodes = {0, 2}, + .strategies = {{.cost = 50}, {.cost = 10}, {.cost = 40}}}, + {.nodes = {1, 3}, + .strategies = + {{.cost = 90}, {.cost = 10}, {.cost = 20}, {.cost = 80}}}, + {.nodes = {2, 4}, + .strategies = {{.cost = 60}, {.cost = 20}, {.cost = 30}}}, + {.nodes = {3, 4}, .strategies = {{.cost = 70}, {.cost = 60}}}, + }, + .usage_limit = 50}; +} + +TEST(EvaluateTest, LegalSolution) { + // Node costs: 15 + 65 + 35 + 85 + 95 = 295 + // Edge costs: 40 + 40 + 20 + 30 + 70 = 200 + EXPECT_THAT(Evaluate(GetExampleProblem(), {0, 1, 2, 0, 0}), + IsOkAndHolds(495)); +} + +TEST(EvaluateTest, LegalSolutionNoUsageLimit) { + Problem problem = GetExampleProblem(); + problem.usage_limit.reset(); + // Node costs: 15 + 55 + 45 + 75 + 95 = 285 + // Edge costs: 30 + 10 + 10 + 20 + 60 = 130 + EXPECT_THAT(Evaluate(problem, {0, 0, 1, 1, 0}), IsOkAndHolds(415)); +} + +TEST(EvaluateTest, IllegalSolutionEclipsesUsageLimit) { + EXPECT_EQ(Evaluate(GetExampleProblem(), {0, 0, 1, 1, 0}).status().code(), + absl::StatusCode::kResourceExhausted); +} + +TEST(EvaluateTest, IllegalSolutionHasTooManyTerms) { + EXPECT_EQ(Evaluate(GetExampleProblem(), {0, 0, 0, 0, 0, 0}).status().code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(EvaluateTest, IllegalSolutionHasTooFewTerms) { + EXPECT_EQ(Evaluate(GetExampleProblem(), {0, 0, 0, 0}).status().code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(EvaluateTest, IllegalSolutionHasNegativeStrategyIndex) { + EXPECT_EQ(Evaluate(GetExampleProblem(), {0, 0, -1, 0, 0}).status().code(), + absl::StatusCode::kOutOfRange); +} + +TEST(EvaluateTest, IllegalSolutionHasBogusStrategyIndex) { + EXPECT_EQ(Evaluate(GetExampleProblem(), {0, 0, 4, 0, 0}).status().code(), + absl::StatusCode::kOutOfRange); +} + +TEST(DISABLED_ReadProblemTest, ExampleFile) { + const std::string filename = "example.json"; + EXPECT_THAT(ReadProblem(filename), IsOkAndHolds(GetExampleProblem())); +} + +TEST(SolveTest, FindsOptimalSolution) { + EXPECT_THAT(Solver().Solve(GetExampleProblem(), absl::Seconds(1)), + IsOkAndHolds(Solution{0, 0, 2, 1, 0})); +} + +TEST(SolveTest, NoSolutionFound) { + Problem problem = GetExampleProblem(); + problem.usage_limit = 0; + EXPECT_EQ(Solver().Solve(problem, absl::Seconds(1)).status().code(), + absl::StatusCode::kNotFound); +} + +} // namespace +} // namespace iopddl diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/solver.cc new file mode 100644 index 00000000000000..616a1d0f2a611f --- 
/dev/null
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/solver.cc
@@ -0,0 +1,65 @@
+/*
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include "xla/hlo/experimental/auto_sharding/solver.h"
+
+#include <stdlib.h>
+
+#include <iostream>
+#include <optional>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_join.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "xla/hlo/experimental/auto_sharding/iopddl.h"
+
+namespace iopddl {
+
+////////////////////////////////////////////////////////////////////////////////
+//  A simple solver that generates random solutions until the given timeout. //
+//  Contest participants SHOULD replace this implementation with their own!! //
+////////////////////////////////////////////////////////////////////////////////
+
+absl::StatusOr<Solution> Solver::Solve(const Problem& problem,
+                                       absl::Duration timeout) {
+  const absl::Time start_time = absl::Now();
+  std::optional<TotalCost> best_cost;
+  std::optional<Solution> best_solution;
+  unsigned int seed = 2025;
+  while (absl::Now() - start_time < timeout) {
+    Solution solution;
+    solution.reserve(problem.nodes.size());
+    for (const Node& node : problem.nodes) {
+      solution.push_back(rand_r(&seed) % node.strategies.size());
+    }
+    auto cost = Evaluate(problem, solution);
+    if (!cost.ok() || (best_cost && *best_cost <= *cost)) {
+      continue;
+    }
+    std::cout << "# Found solution [" << absl::StrJoin(solution, ", ")
+              << "] with cost " << *cost << std::endl;
+    best_cost = *cost;
+    best_solution = solution;
+  }
+  if (!best_solution) {
+    return absl::NotFoundError("No solution found");
+  }
+  return *best_solution;
+}
+
+}  // namespace iopddl
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/solver.h
new file mode 100644
index 00000000000000..1da575ae7c37c5
--- /dev/null
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/solver.h
@@ -0,0 +1,37 @@
+/*
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_SOLVER_H_ +#define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_SOLVER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "xla/hlo/experimental/auto_sharding/iopddl.h" + +namespace iopddl { + +class Solver { + public: + absl::StatusOr Solve( + const Problem& problem, + absl::Duration timeout = absl::InfiniteDuration()); +}; + +} // namespace iopddl + +#endif // XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_SOLVER_H_ From 6bc18a5a09055bf4acdc3b725aa5068f0cdccdcb Mon Sep 17 00:00:00 2001 From: Ezekiel Calubaquib Date: Wed, 9 Apr 2025 13:08:44 -0700 Subject: [PATCH 0466/1324] copy tensorflow/compiler/mlir/lite:validators to mlir/utils:validators PiperOrigin-RevId: 745704781 --- tensorflow/compiler/mlir/stablehlo/BUILD | 2 +- .../mhlo_passes/fuse_convolution_pass.cc | 4 +- .../compiler/mlir/tensorflow/transforms/BUILD | 2 +- .../mlir/tensorflow/transforms/optimize.cc | 2 +- .../mlir/tensorflow/transforms/optimize.td | 2 +- tensorflow/compiler/mlir/utils/BUILD | 16 ++ tensorflow/compiler/mlir/utils/validators.cc | 147 ++++++++++++++++++ tensorflow/compiler/mlir/utils/validators.h | 126 +++++++++++++++ 8 files changed, 295 insertions(+), 6 deletions(-) create mode 100644 tensorflow/compiler/mlir/utils/validators.cc create mode 100644 tensorflow/compiler/mlir/utils/validators.h diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index 4d24ec1f0e6661..ef782f48d939da 100644 --- a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -140,8 +140,8 @@ cc_library( "-Ithird_party", ], deps = [ - "//tensorflow/compiler/mlir/lite:validators", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/utils:validators", "@llvm-project//llvm:Support", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc index a701f7830841b0..2a6db05dffc98e 100644 --- a/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc @@ -36,8 +36,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/utils/validators.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { @@ -95,7 +95,7 @@ class FuseMhloMulAndConvolutionPattern : public OpRewritePattern { // format and backprop input conv filter is in HWOI format. // Only fuses multiplier if all dimensions other than the out channel // dimension are equal to 1. 
- if (!TFL::IsDimensionsDegenerateExceptLastOne( + if (!TF::IsDimensionsDegenerateExceptLastOne( mul_value.getShapedType().getShape())) { return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { diag << "entities 'mul_value' failed to satisfy constraint: " diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 3f4af7a26e34fc..1e14755b0119f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -557,7 +557,6 @@ cc_library( ":verify_no_outside_compilation_markers_pass", "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/lite:validators", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", @@ -599,6 +598,7 @@ cc_library( "//tensorflow/compiler/mlir/tf2xla/transforms:split_into_island_per_op_pass", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/utils:validators", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla/kernels:xla_call_module_loader", "//tensorflow/core:core_cpu_base", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index be7e914bd29846..f02dffc5d6f2f3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -27,10 +27,10 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h" +#include "tensorflow/compiler/mlir/utils/validators.h" // IWYU pragma: keep namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td index be01d276902047..188fbbb6be532b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td @@ -40,7 +40,7 @@ def IsNotComplexType : Constraint>; + Constraint>; def F32ElementsAttr : ElementsAttrBase< CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; diff --git a/tensorflow/compiler/mlir/utils/BUILD b/tensorflow/compiler/mlir/utils/BUILD index 2256c421b45717..13cdb3e51d33a9 100644 --- a/tensorflow/compiler/mlir/utils/BUILD +++ b/tensorflow/compiler/mlir/utils/BUILD @@ -37,3 +37,19 @@ cc_library( "@llvm-project//llvm:Support", ], ) + +cc_library( + name = "validators", + srcs = [ + "validators.cc", + ], + hdrs = [ + "validators.h", + ], + deps = [ + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/utils/validators.cc b/tensorflow/compiler/mlir/utils/validators.cc new file mode 100644 index 00000000000000..870c7e1f1efbfe --- /dev/null +++ 
b/tensorflow/compiler/mlir/utils/validators.cc
@@ -0,0 +1,147 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/utils/validators.h"
+
+#include <algorithm>
+#include <cstdint>
+
+#include "mlir/Dialect/Traits.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
+namespace mlir {
+namespace TF {
+
+// Returns true if the given `op`
+//   * has an attribute with the given `name`,
+//   * and the attribute is an integer list of the form [1, X, Y, 1],
+// and writes X, Y as 32-bit integer attribute to `x`, `y`.
+bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
+                     IntegerAttr *y) {
+  auto attr = op->getAttrOfType<ArrayAttr>(name);
+  if (!attr) return false;
+
+  auto elements = attr.getValue();
+  if (elements.size() != 4 ||
+      std::any_of(elements.begin(), elements.end(),
+                  [](Attribute e) { return !mlir::isa<IntegerAttr>(e); }))
+    return false;
+
+  if (mlir::cast<IntegerAttr>(elements.front()).getInt() != 1 ||
+      mlir::cast<IntegerAttr>(elements.back()).getInt() != 1)
+    return false;
+
+  Builder b(op->getContext());
+  *x = b.getI32IntegerAttr(mlir::cast<IntegerAttr>(elements[1]).getInt());
+  *y = b.getI32IntegerAttr(mlir::cast<IntegerAttr>(elements[2]).getInt());
+
+  return true;
+}
+
+// Returns true if the attribute is an integer list of the form [1, X, Y, 1].
+bool TFIntListIs1XY1(const Attribute attr) {
+  const auto &elements = mlir::cast<ArrayAttr>(attr).getValue();
+  if (elements.size() != 4 ||
+      std::any_of(elements.begin(), elements.end(),
+                  [](Attribute e) { return !mlir::isa<IntegerAttr>(e); }))
+    return false;
+
+  if (mlir::cast<IntegerAttr>(elements.front()).getValue() != 1 ||
+      mlir::cast<IntegerAttr>(elements.back()).getValue() != 1)
+    return false;
+  return true;
+}
+
+// Returns true if the attribute is an integer list of the form [1, 1, X, Y].
+bool TFIntListIs11XY(const Attribute attr) {
+  const auto &elements = mlir::cast<ArrayAttr>(attr).getValue();
+  if (elements.size() != 4 ||
+      std::any_of(elements.begin(), elements.end(),
+                  [](Attribute e) { return !mlir::isa<IntegerAttr>(e); }))
+    return false;
+
+  const Attribute *data = elements.data();
+  if (mlir::cast<IntegerAttr>(data[0]).getValue() != 1 ||
+      mlir::cast<IntegerAttr>(data[1]).getValue() != 1)
+    return false;
+  return true;
+}
+
+// Returns true if the given `op`
+//   * has an attribute with the given `name`,
+//   * and the attribute is an integer list of the form [1, X, Y, Z, 1],
+// and writes X, Y, Z as 32-bit integer attributes to `x`, `y`, `z`.
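+// For example, an attribute holding [1, 2, 3, 4, 1] writes 2, 3, and 4 to
+// `x`, `y`, and `z` respectively.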
+bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x, + IntegerAttr *y, IntegerAttr *z) { + auto attr = op->getAttrOfType(name); + if (!attr) return false; + + auto elements = attr.getValue(); + if (elements.size() != 5 || + std::any_of(elements.begin(), elements.end(), + [](Attribute e) { return !mlir::isa(e); })) + return false; + + if (mlir::cast(elements.front()).getInt() != 1 || + mlir::cast(elements.back()).getInt() != 1) + return false; + + Builder b(op->getContext()); + *x = b.getI32IntegerAttr(mlir::cast(elements[1]).getInt()); + *y = b.getI32IntegerAttr(mlir::cast(elements[2]).getInt()); + *z = b.getI32IntegerAttr(mlir::cast(elements[3]).getInt()); + + return true; +} + +// Returns true if every element of the attribute is 1. All elements of `attr` +// must be `IntegerAttr`. +bool TFIntListIsAllOnes(const Attribute attr) { + const auto &elements = mlir::cast(attr).getValue(); + + return !std::any_of(elements.begin(), elements.end(), [](Attribute e) { + return mlir::cast(e).getValue() != 1; + }); +} + +bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b) { + // This would return false if we had unranked tensors (where they should + // probably be considered as broadcastable), but given we are working with + // attributes here that shouldn't be an issue, + return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type(); +} + +bool IsDimensionsDegenerateExceptLastOne(ArrayRef elements_shape) { + if (elements_shape.empty()) return true; + + for (auto dim : elements_shape.drop_back(1)) { + if (dim != 1) return false; + } + return true; +} + +bool IsDimensionsDegenerateExceptLastOne(TypedAttr val) { + if (auto ranked_type = mlir::dyn_cast(val.getType())) { + return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape()); + } + return false; +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/utils/validators.h b/tensorflow/compiler/mlir/utils/validators.h new file mode 100644 index 00000000000000..b55bd219914603 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/validators.h @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common validators used by TFLite transformation +// passes to validate op attributes or values. 
+
+#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_VALIDATORS_H_
+#define TENSORFLOW_COMPILER_MLIR_UTILS_VALIDATORS_H_
+
+#include <cstdint>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
+namespace mlir {
+namespace TF {
+
+// TODO(jpienaar): Change these to being one of these variants and/or generate
+// these predicates.
+
+// Returns true if the given TensorFlow op does not have a `data_format`
+// attribute (then defaults to "NHWC"), or its `data_format` attribute is
+// "NHWC".
+inline bool TFDataFormatIsNHWC(Operation *op) {
+  auto attr = op->getAttrOfType<StringAttr>("data_format");
+  return !attr || attr.getValue() == "NHWC";
+}
+
+// Returns true if the given TensorFlow op does not have a `data_format`
+// attribute (then defaults to "NDHWC"), or its `data_format` attribute is
+// "NDHWC".
+inline bool TFDataFormatIsNDHWC(Operation *op) {
+  auto attr = op->getAttrOfType<StringAttr>("data_format");
+  return !attr || attr.getValue() == "NDHWC";
+}
+
+// Returns true if the given `op`
+//   * has an attribute with the given `name`,
+//   * and the attribute is an integer list of the form [1, X, Y, 1],
+// and writes X, Y as 32-bit integer attributes to `x`, `y`.
+bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
+                     IntegerAttr *y);
+
+// Returns true if the attribute is an integer list of the form [1, X, Y, 1].
+bool TFIntListIs1XY1(Attribute attr);
+
+// Returns true if the attribute is an integer list of the form [1, 1, X, Y].
+bool TFIntListIs11XY(Attribute attr);
+
+// Returns true if the given `op`
+//   * has an attribute with the given `name`,
+//   * and the attribute is an integer list of the form [1, X, Y, Z, 1],
+// and writes X, Y, Z as 32-bit integer attributes to `x`, `y`, `z`.
+bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x,
+                      IntegerAttr *y, IntegerAttr *z);
+
+// Returns true if every element of the attribute is 1. All elements of `attr`
+// must be `IntegerAttr`.
+bool TFIntListIsAllOnes(Attribute attr);
+
+// Returns true iff the given value is a float32 tensor (i.e., its element
+// type is "DT_FLOAT").
+inline bool TFTypeIsFloat32Tensor(Value value) {
+  auto tensorType = mlir::dyn_cast<TensorType>(value.getType());
+  if (!tensorType) return false;
+  return tensorType.getElementType().isF32();
+}
+
+// Returns true iff the given value is a bf16 tensor.
+inline bool TFTypeIsBFloat16Tensor(Value value) {
+  auto tensorType = mlir::dyn_cast<TensorType>(value.getType());
+  if (!tensorType) return false;
+  return tensorType.getElementType().isBF16();
+}
+
+// Returns true iff the given value is a f16 tensor.
+inline bool TFTypeIsHalfTensor(Value value) {
+  auto tensorType = mlir::dyn_cast<TensorType>(value.getType());
+  if (!tensorType) return false;
+  return tensorType.getElementType().isF16();
+}
+
+// Returns true iff the given value is a f16 or bf16 tensor.
+inline bool TFTypeIsBFloat16OrHalfTensor(Value value) {
+  return TFTypeIsBFloat16Tensor(value) || TFTypeIsHalfTensor(value);
+}
+
+// Returns true iff the given TensorFlow op has a `padding` attribute whose
+// value is "SAME" or "VALID", and writes the attribute to `padding`.
+inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
+  auto padding_attr = op->getAttrOfType<StringAttr>("padding");
+  if (padding_attr.getValue() != "SAME" && padding_attr.getValue() != "VALID")
+    return false;
+  *padding = padding_attr;
+  return true;
+}
+
+/// Returns whether the given `a` and `b` have broadcast-compatible
+/// types.
+bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b);
+
+// Returns true if every dimension of the attribute is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(mlir::TypedAttr val);
+
+// Returns true if every dimension of `elements_shape` is 1 except the last
+// one.
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape);
+
+}  // end namespace TF
+}  // end namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_UTILS_VALIDATORS_H_

From cb94d95e45e6870889e49fb2f80051a6af2a2a70 Mon Sep 17 00:00:00 2001
From: Farzin Houshmand
Date: Wed, 9 Apr 2025 13:36:15 -0700
Subject: [PATCH 0467/1324] Fix check in HloDataflowAnalysis for compatible
 shape kind.

PiperOrigin-RevId: 745714947
---
 third_party/xla/xla/hlo/analysis/BUILD              |  1 +
 .../analysis/hlo_dataflow_analysis_test.cc          | 49 +++++++++++++++++++
 third_party/xla/xla/service/hlo_value.cc            |  9 +++-
 3 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/third_party/xla/xla/hlo/analysis/BUILD b/third_party/xla/xla/hlo/analysis/BUILD
index e0157bf490a36b..c20bfc78def4ec 100644
--- a/third_party/xla/xla/hlo/analysis/BUILD
+++ b/third_party/xla/xla/hlo/analysis/BUILD
@@ -243,6 +243,7 @@ xla_cc_test(
         "//xla/service:hlo_creation_utils",
         "//xla/service:hlo_value",
         "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
diff --git a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc
index 783bda9b153eac..fe1ce5f513807f 100644
--- a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc
+++ b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc
@@ -43,6 +43,7 @@ limitations under the License.
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
@@ -3625,5 +3626,53 @@ ENTRY AllToAll {
   EXPECT_EQ(in_place_pairs, expected_pairs);
 }
 
+// Test to check that the dataflow analysis works with a module that has a
+// scalar bitcast user.
+TEST_P(HloDataflowAnalysisTest, b409416499) {
+  const char* after_layout_bitcast = R"(
+  HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(s32[1]{0:T(128)}, s32[1]{0:T(128)}, s32[1]{0:T(128)}, s32[1]{0:T(128)})->(s32[1]{0:T(128)}, s32[1]{0:T(128)}, s32[1]{0:T(128)}, s32[1]{0:T(128)})}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, allow_spmd_sharding_propagation_to_output={true,true,true,true}, num_partitions=4
+  %region_0.13_spmd (param.1: s32[]) -> s32[] {
+    %param.1 = s32[]{:T(128)} parameter(0), metadata={op_name="jit()/jit(main)/jit(shmap_body)/while"}
+    %constant.1 = s32[]{:T(128)} constant(1)
+    ROOT %add.0 = s32[]{:T(128)} add(%param.1, %constant.1), metadata={op_name="jit()/jit(main)/jit(shmap_body)/while/body/add" source_file="third_party/py/jax/tests/shard_map_test.py" source_line=1052}
+  }
+
+  %region_1.17_spmd (param: s32[]) -> pred[] {
+    %param = s32[]{:T(128)} parameter(0), metadata={op_name="jit()/jit(main)/jit(shmap_body)/while"}
+    %constant = s32[]{:T(128)} constant(1)
+    ROOT %compare.0 = pred[]{:T(512)} compare(%param, %constant), direction=LT, metadata={op_name="jit()/jit(main)/jit(shmap_body)/while/cond/lt" source_file="third_party/py/jax/tests/shard_map_test.py" source_line=1049}
+  }
+
+  ENTRY %main.44_spmd (param.2: s32[1], param.3: s32[1], param.4: s32[1], param.5: s32[1]) -> (s32[1], s32[1], s32[1], s32[1]) {
+    %param.2 = s32[1]{0:T(128)} parameter(0), sharding={devices=[4]<=[4]}, metadata={op_name="args[0]"}
+    %bitcast.2 = s32[]{:T(128)} bitcast(%param.2), metadata={op_name="jit()/jit(main)/jit(shmap_body)/squeeze" source_file="third_party/py/jax/tests/shard_map_test.py" source_line=1053}
+    %while.1 = s32[]{:T(128)} while(%bitcast.2), condition=%region_1.17_spmd, body=%region_0.13_spmd, metadata={op_name="jit()/jit(main)/jit(shmap_body)/while" source_file="third_party/py/jax/tests/shard_map_test.py" source_line=1053}
+    %bitcast.3 = s32[1]{0:T(128)} bitcast(%while.1), metadata={op_name="jit()/jit(main)/jit(shmap_body)/broadcast_in_dim" source_file="third_party/py/jax/tests/shard_map_test.py" source_line=1053}
+    %param.3 = s32[1]{0:T(128)} parameter(1), sharding={devices=[4]<=[4]}, metadata={op_name="args[1]"}
+    %param.4 = s32[1]{0:T(128)} parameter(2), sharding={devices=[4]<=[4]}, metadata={op_name="args[2]"}
+    %param.5 = s32[1]{0:T(128)} parameter(3), sharding={devices=[4]<=[4]}, metadata={op_name="args[3]"}
+    ROOT %tuple.1 = (s32[1]{0:T(128)}, s32[1]{0:T(128)}, s32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(%bitcast.3, %param.3, %param.4, %param.5)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto after_layout_bitcast_module,
+                          ParseAndReturnVerifiedModule(after_layout_bitcast));
+  TF_ASSERT_OK_AND_ASSIGN(auto analysis,
+                          HloDataflowAnalysis::Run(*after_layout_bitcast_module,
+                                                   /*ssa_form=*/false));
+  HloInstruction* bitcast3 =
+      FindInstruction(after_layout_bitcast_module.get(), "bitcast.3");
+  HloInstruction* param2 =
+      FindInstruction(after_layout_bitcast_module.get(), "param.2");
+  HloComputation* while_body =
+      FindComputation(after_layout_bitcast_module.get(), "region_0.13_spmd");
+  HloInstruction* add0 = while_body->root_instruction();
+  std::vector<HloInstruction*> defining_instructions;
+  for (const HloValue* value :
+       analysis->GetValueSet(bitcast3, {}).TakeValues()) {
+    defining_instructions.push_back(value->defining_instruction());
+  }
+  EXPECT_THAT(defining_instructions, UnorderedElementsAre(param2, add0));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/service/hlo_value.cc
b/third_party/xla/xla/service/hlo_value.cc
index 8d331f89b81f81..699b4ed6ad474c 100644
--- a/third_party/xla/xla/service/hlo_value.cc
+++ b/third_party/xla/xla/service/hlo_value.cc
@@ -287,7 +287,14 @@ bool InstructionValueSet::AssignUnionOf(
     absl::Span<const InstructionValueSet* const> inputs) {
   CHECK_GT(inputs.size(), 0);
   for (int i = 1; i < inputs.size(); ++i) {
-    DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
+    // It is possible that some values come from effective scalar shapes, i.e.,
+    // an X[1] that was bitcast to X[]. In such cases the shapes are not
+    // compatible, but it is still valid to take the union of the values.
+    bool shapes_are_effective_scalar =
+        ShapeUtil::IsEffectiveScalar(inputs[0]->shape()) &&
+        ShapeUtil::IsEffectiveScalar(inputs[i]->shape());
+    DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()) ||
+           shapes_are_effective_scalar);
   }
   bool changed = false;
   for (auto& pair : *this) {

From d38b0b108f22053ab4541bc1b4628f536044a2ed Mon Sep 17 00:00:00 2001
From: Kevin Gleason
Date: Wed, 9 Apr 2025 13:41:35 -0700
Subject: [PATCH 0468/1324] Skip jax/tests:unary_ops_accuracy_test when running
 with older versions of StableHLO.

PiperOrigin-RevId: 745717137
---
 third_party/xla/xla/pjrt/mlir_to_hlo.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/pjrt/mlir_to_hlo.cc b/third_party/xla/xla/pjrt/mlir_to_hlo.cc
index c3cbb60ee51ca4..24dede0eeca4fa 100644
--- a/third_party/xla/xla/pjrt/mlir_to_hlo.cc
+++ b/third_party/xla/xla/pjrt/mlir_to_hlo.cc
@@ -270,8 +270,8 @@ absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
       mlir_module, target.value(), os))) {
     const absl::Status status = diagnostic_handler.ConsumeStatus();
     return absl::InvalidArgumentError(absl::StrCat(
-        "Failed to serialize StableHLO;\n\nDetailed error from MLIR: ",
-        status.message()));
+        "Failed to serialize StableHLO to plugin version ", target.value(),
+        ";\n\nDetailed error from MLIR: ", status.message()));
   }
   return buffer;
 }

From 4fcda01c64d8180b893e922056f96f7e154e2de6 Mon Sep 17 00:00:00 2001
From: Artem Belevich
Date: Wed, 9 Apr 2025 14:11:21 -0700
Subject: [PATCH 0469/1324] Drop the old CUB version. Always use the CUB
 bundled with the CUDA SDK.
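Consumer code should be unaffected by this cleanup: the same CUB header now
resolves against the headers bundled with the CUDA toolkit (the ":cuda_headers"
target) instead of the pinned cub-1.9.9 archive. Illustrative sketch only; the
exact include spelling varies by call site:

  // Unchanged by this CL; with CUDA >= 11 the header comes from the SDK,
  // so the separate @cub_archive repository is no longer needed.
  #include "cub/cub.cuh"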
PiperOrigin-RevId: 745729192 --- tensorflow/tools/lib_package/BUILD | 2 -- tensorflow/tools/pip_package/BUILD | 1 - tensorflow/workspace2.bzl | 10 ---------- third_party/gpus/cuda_configure.bzl | 6 +----- third_party/xla/third_party/gpus/cuda_configure.bzl | 6 +----- 5 files changed, 2 insertions(+), 23 deletions(-) diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 755479e5823ff2..05e9c40074471d 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -166,7 +166,6 @@ genrule( "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", ], }) + if_cuda([ - "@cub_archive//:LICENSE.TXT", "@local_config_nccl//:LICENSE", ]) + if_mkl([ "//third_party/mkl_dnn:LICENSE", @@ -205,7 +204,6 @@ genrule( "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", ], }) + if_cuda([ - "@cub_archive//:LICENSE.TXT", "@local_config_nccl//:LICENSE", ]) + if_mkl([ "//third_party/mkl_dnn:LICENSE", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index f54660fa75d241..5b2f730f402cc4 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -126,7 +126,6 @@ filegroup( "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", ], }) + if_cuda([ - "@cub_archive//:LICENSE.TXT", "@local_config_nccl//:LICENSE", ]) + if_mkl([ "//third_party/mkl_dnn:LICENSE", diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index eab74a3f526060..6ce2f4be5b2a80 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -644,16 +644,6 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/google/pprof/archive/83db2b799d1f74c40857232cb5eb4c60379fe6c2.tar.gz"), ) - # The CUDA 11 toolkit ships with CUB. We should be able to delete this rule - # once TF drops support for CUDA 10. 
- tf_http_archive( - name = "cub_archive", - build_file = "//third_party:cub.BUILD", - sha256 = "162514b3cc264ac89d91898b58450190b8192e2af1142cf8ccac2d59aa160dda", - strip_prefix = "cub-1.9.9", - urls = tf_mirror_urls("https://github.com/NVlabs/cub/archive/1.9.9.zip"), - ) - tf_http_archive( name = "cython", build_file = "//third_party:cython.BUILD", diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index c1e6c42e4fd49e..d110fa8146083d 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1065,10 +1065,6 @@ def _create_local_cuda_repository(repository_ctx): }, ) - cub_actual = "@cub_archive//:cub" - if int(cuda_config.cuda_version_major) >= 11: - cub_actual = ":cuda_headers" - repository_ctx.template( "cuda/BUILD", tpl_paths["cuda:BUILD"], @@ -1085,7 +1081,7 @@ def _create_local_cuda_repository(repository_ctx): "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]), "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]), "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]), - "%{cub_actual}": cub_actual, + "%{cub_actual}": ":cuda_headers", "%{copy_rules}": "\n".join(copy_rules), }, ) diff --git a/third_party/xla/third_party/gpus/cuda_configure.bzl b/third_party/xla/third_party/gpus/cuda_configure.bzl index c1e6c42e4fd49e..d110fa8146083d 100644 --- a/third_party/xla/third_party/gpus/cuda_configure.bzl +++ b/third_party/xla/third_party/gpus/cuda_configure.bzl @@ -1065,10 +1065,6 @@ def _create_local_cuda_repository(repository_ctx): }, ) - cub_actual = "@cub_archive//:cub" - if int(cuda_config.cuda_version_major) >= 11: - cub_actual = ":cuda_headers" - repository_ctx.template( "cuda/BUILD", tpl_paths["cuda:BUILD"], @@ -1085,7 +1081,7 @@ def _create_local_cuda_repository(repository_ctx): "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]), "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]), "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]), - "%{cub_actual}": cub_actual, + "%{cub_actual}": ":cuda_headers", "%{copy_rules}": "\n".join(copy_rules), }, ) From 6d4739d22cccfbf67fa215653b015445c880afa4 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Wed, 9 Apr 2025 14:21:46 -0700 Subject: [PATCH 0470/1324] Extract shared ClientLibraryTestBase+ClientLibraryTestRunnerMixin helper funs. 
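As an illustrative sketch of the intended call sites after this extraction (the
names come from the new client_library_test_runner_utils.h; the surrounding
test setup is assumed):

  #include "xla/tests/client_library_test_runner_utils.h"

  // Both ClientLibraryTestBase and ClientLibraryTestRunnerMixin now delegate
  // to these shared free functions instead of keeping private copies.
  XlaComputation max_f32 = xla::CreateScalarMax(F32);
  std::vector<float> r1 =
      xla::CreatePseudorandomR1<float>(/*width=*/16, 0.0f, 1.0f, /*seed=*/42);
  std::unique_ptr<Array2D<float>> padded =
      xla::CreatePatternedMatrixWithZeroPadding(/*rows=*/2, /*cols=*/3,
                                                /*rows_padded=*/4,
                                                /*cols_padded=*/4);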
PiperOrigin-RevId: 745733153 --- third_party/xla/xla/tests/BUILD | 19 ++++ .../xla/xla/tests/client_library_test_base.cc | 22 ----- .../xla/xla/tests/client_library_test_base.h | 27 ++---- .../tests/client_library_test_runner_mixin.h | 23 +---- .../tests/client_library_test_runner_utils.cc | 88 +++++++++++++++++++ .../tests/client_library_test_runner_utils.h | 74 ++++++++++++++++ third_party/xla/xla/tests/concat_test.cc | 1 + 7 files changed, 194 insertions(+), 60 deletions(-) create mode 100644 third_party/xla/xla/tests/client_library_test_runner_utils.cc create mode 100644 third_party/xla/xla/tests/client_library_test_runner_utils.h diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index e29b60c4c4070b..840e28a3d61583 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -338,6 +338,7 @@ cc_library( srcs = ["client_library_test_base.cc"], hdrs = ["client_library_test_base.h"], deps = [ + ":client_library_test_runner_utils", ":literal_test_util", ":test_utils", "//xla:array2d", @@ -375,6 +376,7 @@ cc_library( testonly = True, hdrs = ["client_library_test_runner_mixin.h"], deps = [ + ":client_library_test_runner_utils", ":hlo_runner_agnostic_test_base", ":literal_test_util", "//xla:array2d", @@ -401,6 +403,22 @@ cc_library( ], ) +cc_library( + name = "client_library_test_runner_utils", + testonly = True, + srcs = ["client_library_test_runner_utils.cc"], + hdrs = ["client_library_test_runner_utils.h"], + deps = [ + ":test_utils", + "//xla:array2d", + "//xla:shape_util", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/tsl/platform:status", + "@com_google_absl//absl/log:check", + ], +) + cc_library( name = "llvm_irgen_test_base", testonly = True, @@ -2560,6 +2578,7 @@ xla_test( ], deps = [ ":client_library_test_runner_mixin", + ":client_library_test_runner_utils", ":hlo_pjrt_interpreter_reference_mixin", ":hlo_pjrt_test_base", ":test_macros_header", diff --git a/third_party/xla/xla/tests/client_library_test_base.cc b/third_party/xla/xla/tests/client_library_test_base.cc index 74ac6e801a68fe..1358c8600ffc04 100644 --- a/third_party/xla/xla/tests/client_library_test_base.cc +++ b/third_party/xla/xla/tests/client_library_test_base.cc @@ -472,28 +472,6 @@ ClientLibraryTestBase::ComputeValueAndReference( return std::make_pair(std::move(reference), std::move(result)); } -XlaComputation ClientLibraryTestBase::CreateScalarReluF32() { - XlaBuilder builder("relu"); - auto shape = ShapeUtil::MakeShape(F32, {}); - auto z_value = Parameter(&builder, 0, shape, "z_value"); - auto zero = ConstantR0(&builder, 0.0f); - Max(z_value, zero); - auto computation_status = builder.Build(); - TF_CHECK_OK(computation_status.status()); - return std::move(computation_status).value(); -} - -XlaComputation ClientLibraryTestBase::CreateScalarMax() { - XlaBuilder builder("max"); - auto shape = ShapeUtil::MakeShape(test_type_, {}); - auto x = Parameter(&builder, 0, shape, "x"); - auto y = Parameter(&builder, 1, shape, "y"); - Max(x, y); - auto computation_status = builder.Build(); - TF_CHECK_OK(computation_status.status()); - return std::move(computation_status).value(); -} - std::unique_ptr> ClientLibraryTestBase::CreatePatternedMatrix( int rows, int cols, float offset) { auto array = std::make_unique>(rows, cols); diff --git a/third_party/xla/xla/tests/client_library_test_base.h b/third_party/xla/xla/tests/client_library_test_base.h index a5d0a457525275..03d526b85a6105 100644 --- 
a/third_party/xla/xla/tests/client_library_test_base.h +++ b/third_party/xla/xla/tests/client_library_test_base.h @@ -29,6 +29,7 @@ static_assert(false, "test that has been explicitly migrated to use HloRunnerPjRt."); #endif // XLA_TEST_MIGRATED_TO_HLO_RUNNER_PJRT +#include #include #include #include @@ -46,6 +47,7 @@ static_assert(false, #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tests/client_library_test_runner_utils.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/bitmap.h" @@ -219,8 +221,8 @@ class ClientLibraryTestBase : public ::testing::Test { absl::Span arguments, std::optional error = std::nullopt); // Create scalar operations for use in reductions. - XlaComputation CreateScalarReluF32(); - XlaComputation CreateScalarMax(); + XlaComputation CreateScalarReluF32() { return xla::CreateScalarReluF32(); } + XlaComputation CreateScalarMax() { return xla::CreateScalarMax(test_type_); } // Special case convenience functions for creating filled arrays. @@ -576,27 +578,16 @@ std::unique_ptr ClientLibraryTestBase::CreateParameter( template std::vector ClientLibraryTestBase::CreatePseudorandomR1( - const int width, NativeT min_value, NativeT max_value, uint32_t seed) { - std::vector result(width); - PseudorandomGenerator generator(min_value, max_value, seed); - for (int i = 0; i < width; ++i) { - result[i] = generator.get(); - } - return result; + const int width, NativeT min_value, NativeT max_value, + const uint32_t seed) { + return xla::CreatePseudorandomR1(width, min_value, max_value, seed); } template std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( const int rows, const int cols, NativeT min_value, NativeT max_value, - uint32_t seed) { - auto result = std::make_unique>(rows, cols); - PseudorandomGenerator generator(min_value, max_value, seed); - for (int y = 0; y < rows; ++y) { - for (int x = 0; x < cols; ++x) { - (*result)(y, x) = generator.get(); - } - } - return result; + const uint32_t seed) { + return xla::CreatePseudorandomR2(rows, cols, min_value, max_value, seed); } } // namespace xla diff --git a/third_party/xla/xla/tests/client_library_test_runner_mixin.h b/third_party/xla/xla/tests/client_library_test_runner_mixin.h index aba09bdedfcf2d..ca94bc1ff5c188 100644 --- a/third_party/xla/xla/tests/client_library_test_runner_mixin.h +++ b/third_party/xla/xla/tests/client_library_test_runner_mixin.h @@ -36,6 +36,7 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tests/client_library_test_runner_utils.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/bitmap.h" @@ -249,6 +250,8 @@ class ClientLibraryTestRunnerMixin : public T { ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } + XlaComputation CreateScalarMax() { return xla::CreateScalarMax(test_type_); } + Literal CreateParameterAndTransferLiteral(const int64_t parameter_number, const Literal& literal, const std::string& name, @@ -328,26 +331,6 @@ class ClientLibraryTestRunnerMixin : public T { return execution_options_.mutable_debug_options(); } - // Creates a (rows x cols) array filled in the following form: - // - // [ 0 1 ... cols-1] - // [ 1,000 1,001 ... 1000.0 + cols-1] - // [ ... ... ... ...] - // [(rows-1)*1000.0 ... ... 
(rows-1)*1000.0 + cols-1]
-  //
-  // If provided, offset is added uniformly to every element (e.g. an offset
-  // of 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064,
-  // etc.)
-  static std::unique_ptr<Array2D<float>> CreatePatternedMatrix(
-      const int rows, const int cols, float offset = 0.0) {
-    auto array = std::make_unique<Array2D<float>>(rows, cols);
-    for (int64_t row = 0; row < rows; ++row) {
-      for (int64_t col = 0; col < cols; ++col) {
-        (*array)(row, col) = col + (row * 1000.0f) + offset;
-      }
-    }
-    return array;
-  }
-
  private:
   absl::StatusOr<std::unique_ptr<HloModule>> BuildAndVerifyHloModule(
       const XlaComputation& computation,
diff --git a/third_party/xla/xla/tests/client_library_test_runner_utils.cc b/third_party/xla/xla/tests/client_library_test_runner_utils.cc
new file mode 100644
index 00000000000000..f344ca5311eb2b
--- /dev/null
+++ b/third_party/xla/xla/tests/client_library_test_runner_utils.cc
@@ -0,0 +1,88 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/tests/client_library_test_runner_utils.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "absl/log/check.h"
+#include "xla/array2d.h"
+#include "xla/hlo/builder/xla_builder.h"
+#include "xla/hlo/builder/xla_computation.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tsl/platform/status.h"
+
+namespace xla {
+
+XlaComputation CreateScalarReluF32() {
+  XlaBuilder builder("relu");
+  Shape shape = ShapeUtil::MakeShape(F32, {});
+  XlaOp z_value = Parameter(&builder, 0, std::move(shape), "z_value");
+  XlaOp zero = ConstantR0<float>(&builder, 0.0f);
+  Max(std::move(z_value), std::move(zero));
+  absl::StatusOr<XlaComputation> computation = builder.Build();
+  TF_CHECK_OK(computation.status());
+  return *std::move(computation);
+}
+
+XlaComputation CreateScalarMax(const PrimitiveType test_type) {
+  XlaBuilder builder("max");
+  Shape shape = ShapeUtil::MakeShape(test_type, {});
+  XlaOp x = Parameter(&builder, 0, shape, "x");
+  XlaOp y = Parameter(&builder, 1, shape, "y");
+  Max(std::move(x), std::move(y));
+  absl::StatusOr<XlaComputation> computation = builder.Build();
+  TF_CHECK_OK(computation.status());
+  return *std::move(computation);
+}
+
+// Creates a (rows x cols) array filled in the following form:
+//
+//   [      0                1 ...                   cols-1]
+//   [  1,000            1,001 ...          1000.0 + cols-1]
+//   [    ...              ... ...                      ...]
+//   [(rows-1)*1000.0      ... ... (rows-1)*1000.0 + cols-1]
+//
+// If provided, offset is added uniformly to every element (e.g. an offset of
+// 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
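+//
+// Worked example (follows directly from the formula in the body below):
+//   CreatePatternedMatrix(2, 3)      => [[ 0,  1,  2], [1000, 1001, 1002]]
+//   CreatePatternedMatrix(2, 3, 64)  => [[64, 65, 66], [1064, 1065, 1066]]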
+std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
+                                                      const int cols,
+                                                      float offset) {
+  auto array = std::make_unique<Array2D<float>>(rows, cols);
+  for (int64_t row = 0; row < rows; ++row) {
+    for (int64_t col = 0; col < cols; ++col) {
+      (*array)(row, col) = col + (row * 1000.0f) + offset;
+    }
+  }
+  return array;
+}
+
+std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
+    const int rows, const int cols, const int rows_padded,
+    const int cols_padded) {
+  CHECK_GE(rows_padded, rows);
+  CHECK_GE(cols_padded, cols);
+  auto array = std::make_unique<Array2D<float>>(rows_padded, cols_padded, 0.0);
+  for (int64_t row = 0; row < rows; ++row) {
+    for (int64_t col = 0; col < cols; ++col) {
+      (*array)(row, col) = col + (row * 1000.0f);
+    }
+  }
+  return array;
+}
+}  // namespace xla
diff --git a/third_party/xla/xla/tests/client_library_test_runner_utils.h b/third_party/xla/xla/tests/client_library_test_runner_utils.h
new file mode 100644
index 00000000000000..0c0d04e0d3bab2
--- /dev/null
+++ b/third_party/xla/xla/tests/client_library_test_runner_utils.h
@@ -0,0 +1,74 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_TESTS_CLIENT_LIBRARY_TEST_RUNNER_UTILS_H_
+#define XLA_TESTS_CLIENT_LIBRARY_TEST_RUNNER_UTILS_H_
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "xla/array2d.h"
+#include "xla/hlo/builder/xla_computation.h"
+#include "xla/tests/test_utils.h"
+
+namespace xla {
+// Create scalar operations for use in reductions.
+XlaComputation CreateScalarReluF32();
+XlaComputation CreateScalarMax(PrimitiveType test_type);
+
+// Special case convenience functions for creating filled arrays.
+
+// Creates an array of pseudorandom values lying between the given minimum and
+// maximum values.
+template <typename NativeT>
+std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
+                                          NativeT max_value, uint32_t seed) {
+  std::vector<NativeT> result(width);
+  PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
+  for (int i = 0; i < width; ++i) {
+    result[i] = generator.get();
+  }
+  return result;
+}
+
+template <typename NativeT>
+std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(
+    const int rows, const int cols, const NativeT min_value,
+    const NativeT max_value, const uint32_t seed) {
+  auto result = std::make_unique<Array2D<NativeT>>(rows, cols);
+  PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
+  for (int y = 0; y < rows; ++y) {
+    for (int x = 0; x < cols; ++x) {
+      (*result)(y, x) = generator.get();
+    }
+  }
+  return result;
+}
+
+// Creates a (rows x cols) matrix with the patterned values described above
+// CreatePatternedMatrix in the .cc file, plus an optional uniform offset.
+std::unique_ptr<Array2D<float>> CreatePatternedMatrix(int rows, int cols,
+                                                      float offset = 0.0f);
+
+// Creates a (rows x cols) array as above, padded out to
+// (rows_padded x cols_padded) with zeroes. Requires rows_padded >= rows
+// and cols_padded >= cols.
+std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
+    int rows, int cols, int rows_padded, int cols_padded);
+
+}  // namespace xla
+
+#endif  // XLA_TESTS_CLIENT_LIBRARY_TEST_RUNNER_UTILS_H_
diff --git a/third_party/xla/xla/tests/concat_test.cc b/third_party/xla/xla/tests/concat_test.cc
index 7fa71f38934a7b..78b2d090560168 100644
--- a/third_party/xla/xla/tests/concat_test.cc
+++ b/third_party/xla/xla/tests/concat_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "xla/reference_util.h"
 #include "xla/shape_util.h"
 #include "xla/tests/client_library_test_runner_mixin.h"
+#include "xla/tests/client_library_test_runner_utils.h"
 #include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
 #include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tests/test_macros.h"

From ed315b3a71228ca6341a06dc82c90b3a5424280c Mon Sep 17 00:00:00 2001
From: Niklas Vangerow
Date: Wed, 9 Apr 2025 14:48:38 -0700
Subject: [PATCH 0471/1324] Add ComputeAndCompareLiteral overload without
 error_spec.

PiperOrigin-RevId: 745742726
---
 .../xla/xla/tests/client_library_test_runner_mixin.h | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/third_party/xla/xla/tests/client_library_test_runner_mixin.h b/third_party/xla/xla/tests/client_library_test_runner_mixin.h
index ca94bc1ff5c188..c44536ae19cacb 100644
--- a/third_party/xla/xla/tests/client_library_test_runner_mixin.h
+++ b/third_party/xla/xla/tests/client_library_test_runner_mixin.h
@@ -140,6 +140,16 @@ class ClientLibraryTestRunnerMixin : public T {
     EXPECT_TRUE(this->RunAndCompare(std::move(module), arguments, error));
   }
 
+  // Compare with a literal, without an error spec.
+  // Side effect: EXPECT_OK
+  void ComputeAndCompareLiteral(XlaBuilder* const builder,
+                                const Literal& expected,
+                                const absl::Span<const Literal* const> arguments,
+                                const Shape* shape_with_layout) {
+    return ComputeAndCompareLiteral(builder, expected, arguments, std::nullopt,
+                                    shape_with_layout);
+  }
+
   // Compare with literal.
   // Side effect: EXPECT_OK
   void ComputeAndCompareLiteral(

From 6358e9ee9c40dea8f2fda695242ef63d07ae67b4 Mon Sep 17 00:00:00 2001
From: Ezekiel Calubaquib
Date: Wed, 9 Apr 2025 14:55:34 -0700
Subject: [PATCH 0472/1324] Add python 3.13 version for wheel release job for
 linux, mac, win

PiperOrigin-RevId: 745744999
---
 ci/official/envs/py313 | 15 +++++++++++++++
 1 file changed, 15 insertions(+)
 create mode 100644 ci/official/envs/py313

diff --git a/ci/official/envs/py313 b/ci/official/envs/py313
new file mode 100644
index 00000000000000..1210c5eca815f8
--- /dev/null
+++ b/ci/official/envs/py313
@@ -0,0 +1,15 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +TFCI_PYTHON_VERSION=3.13 From 63db9bf38fc405f73b3e68fee3911bcffcafefe9 Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Wed, 9 Apr 2025 15:13:04 -0700 Subject: [PATCH 0473/1324] [XLA:Upkeep] Resolve 6 instances of the following issue: Todo (resolved) PiperOrigin-RevId: 745751276 --- third_party/xla/xla/backends/cpu/collectives/BUILD | 6 ------ 1 file changed, 6 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/collectives/BUILD b/third_party/xla/xla/backends/cpu/collectives/BUILD index bca2f4d6a7f852..cfce13a7963680 100644 --- a/third_party/xla/xla/backends/cpu/collectives/BUILD +++ b/third_party/xla/xla/backends/cpu/collectives/BUILD @@ -110,7 +110,6 @@ cc_library( ], ) -# TODO(b/380457503): Restrict visibility to private. cc_library( name = "in_process_collectives", srcs = ["in_process_collectives.cc"], @@ -143,7 +142,6 @@ cc_library( ], ) -# TODO(b/380457503): Restrict visibility to private. cc_library( name = "in_process_communicator", srcs = ["in_process_communicator.cc"], @@ -177,7 +175,6 @@ cc_library( ], ) -# TODO(b/380457503): Restrict visibility to private. cc_library( name = "gloo_kv_store", srcs = ["gloo_kv_store.cc"], @@ -200,7 +197,6 @@ cc_library( ], ) -# TODO(b/380457503): Restrict visibility to private. cc_library( name = "gloo_collectives", srcs = ["gloo_collectives.cc"], @@ -282,7 +278,6 @@ xla_cc_test( }), ) -# TODO(b/380457503): Restrict visibility to private. cc_library( name = "gloo_communicator", srcs = ["gloo_communicator.cc"], @@ -362,7 +357,6 @@ cc_library( ], ) -# TODO(b/380457503): Restrict visibility to private. cc_library( name = "mpi_communicator", srcs = ["mpi_communicator.cc"], From 0e2c33b26d31c25fa7604856c9c9fc5b685e1077 Mon Sep 17 00:00:00 2001 From: Ezekiel Calubaquib Date: Wed, 9 Apr 2025 15:41:53 -0700 Subject: [PATCH 0474/1324] move tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.{h,cc} to tensorflow/compiler/mlir/tools/. 
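For downstream code the migration is mechanical: only the include path (and the
header guard) changes, as in this sketch taken from the renames below:

  // Before:
  #include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h"
  // After:
  #include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h"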
PiperOrigin-RevId: 745760984 --- tensorflow/compiler/mlir/BUILD | 2 +- tensorflow/compiler/mlir/lite/BUILD | 1 - .../compiler/mlir/lite/tf_tfl_translate.cc | 7 +- .../compiler/mlir/lite/tf_tfl_translate_cl.cc | 70 +++++++++++++++++++ .../compiler/mlir/lite/tf_tfl_translate_cl.h | 11 +++ tensorflow/compiler/mlir/lite/tools/BUILD | 16 +---- .../tools/tf_mlir_translate_registration.cc | 2 +- tensorflow/compiler/mlir/tensorflow/BUILD | 2 +- .../tensorflow/utils/tf_xla_mlir_translate.cc | 2 +- .../mlir/tf2xla/tests/registration/BUILD | 2 +- .../graph_to_tf_executor_registration.cc | 2 +- .../compiler/mlir/tf_mlir_translate_main.cc | 2 +- tensorflow/compiler/mlir/tools/BUILD | 21 ++++++ .../{lite => }/tools/tf_mlir_translate_cl.cc | 2 +- .../{lite => }/tools/tf_mlir_translate_cl.h | 6 +- 15 files changed, 120 insertions(+), 28 deletions(-) create mode 100644 tensorflow/compiler/mlir/tools/BUILD rename tensorflow/compiler/mlir/{lite => }/tools/tf_mlir_translate_cl.cc (98%) rename tensorflow/compiler/mlir/{lite => }/tools/tf_mlir_translate_cl.h (91%) diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 20c7d3abb35dfc..e85e3935e54b95 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -220,11 +220,11 @@ tf_cc_binary( srcs = ["tf_mlir_translate_main.cc"], deps = [ ":init_mlir", - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/lite/tools:translate_registration", "//tensorflow/compiler/mlir/tensorflow:tf_xla_mlir_translate", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tf2xla/tests/registration:graph_to_tf_executor_registration", + "//tensorflow/compiler/mlir/tools:translate_cl_options", "//tensorflow/core:lib", "//tensorflow/core:tensorflow", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 4ecd78da80209c..5ee2ccbb024788 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1769,7 +1769,6 @@ tf_cc_binary( ":tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite:converter_flags_proto_cc", - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 5b20a6e72f9984..5cc8856aedd0c7 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -48,7 +48,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" @@ -58,9 +57,15 @@ limitations under the License. 
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" +using llvm::cl::opt; using mlir::MLIRContext; using mlir::ModuleOp; +// NOLINTNEXTLINE +opt upgrade_legacy("tf-upgrade-legacy", + llvm::cl::desc("Upgrade legacy TF graph behavior"), + llvm::cl::init(false)); + // NOLINTNEXTLINE static llvm::cl::opt weight_quantization( "weight_quantization", diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc index 0f05c371868b8d..a332af0e2a5734 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc @@ -218,3 +218,73 @@ opt model_origin_framework( "model-origin-framework", llvm::cl::desc("The source model type: PYTORCH, JAX, TENSORFLOW, etc."), llvm::cl::init("UNSET")); + +// NOLINTNEXTLINE +opt input_arrays( + "tf-input-arrays", llvm::cl::desc("Input tensor names, separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt input_dtypes( + "tf-input-data-types", + llvm::cl::desc("(Optional) Input tensor data types, separated by ','. Use " + "'' if a single data type is skipped. The data type from " + "the import graph is used if it is skipped."), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt input_shapes( + "tf-input-shapes", + llvm::cl::desc( + "Input tensor shapes. Shapes for different tensors are separated by " + "':', and dimension sizes for the same tensor are separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt output_arrays( + "tf-output-arrays", llvm::cl::desc("Output tensor names, separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt control_output_arrays( + "tf-control-output-arrays", + llvm::cl::desc("Control output node names, separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt inference_type( + "tf-inference-type", + llvm::cl::desc( + "Sets the type of real-number arrays in the output file. Only allows " + "float and quantized types"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt min_values( + "tf-input-min-values", + llvm::cl::desc( + "Sets the lower bound of the input data. Separated by ','; Each entry " + "in the list should match an entry in -tf-input-arrays. This is " + "used when -tf-inference-type is a quantized type."), + llvm::cl::Optional, llvm::cl::init("")); + +// NOLINTNEXTLINE +opt max_values( + "tf-input-max-values", + llvm::cl::desc( + "Sets the upper bound of the input data. Separated by ','; Each entry " + "in the list should match an entry in -tf-input-arrays. 
This is " + "used when -tf-inference-type is a quantized type."), + llvm::cl::Optional, llvm::cl::init("")); + +// NOLINTNEXTLINE +opt debug_info_file( + "tf-debug-info", + llvm::cl::desc("Path to the debug info file of the input graph def"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt enable_shape_inference( + "tf-enable-shape-inference-on-import", + llvm::cl::desc("Enable shape inference on import (temporary)"), + llvm::cl::init(false)); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h index c225291360c9df..6095b69d471ad8 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h @@ -48,6 +48,17 @@ extern llvm::cl::opt enable_dynamic_update_slice; extern llvm::cl::opt preserve_assert_op; extern llvm::cl::opt legalize_custom_tensor_list_ops; extern llvm::cl::opt reduce_type_precision; +extern llvm::cl::opt input_arrays; +extern llvm::cl::opt input_dtypes; +extern llvm::cl::opt input_shapes; +extern llvm::cl::opt output_arrays; +extern llvm::cl::opt control_output_arrays; +extern llvm::cl::opt inference_type; +extern llvm::cl::opt min_values; +extern llvm::cl::opt max_values; +extern llvm::cl::opt debug_info_file; +extern llvm::cl::opt upgrade_legacy; +extern llvm::cl::opt enable_shape_inference; // Import saved model. extern llvm::cl::opt import_saved_model_object_graph; diff --git a/tensorflow/compiler/mlir/lite/tools/BUILD b/tensorflow/compiler/mlir/lite/tools/BUILD index 63590fc545fd4d..134f5b767f863d 100644 --- a/tensorflow/compiler/mlir/lite/tools/BUILD +++ b/tensorflow/compiler/mlir/lite/tools/BUILD @@ -23,31 +23,17 @@ cc_library( # LINT.ThenChange(//tensorflow/lite/tools:command_line_flags) -cc_library( - name = "translate_cl_options", - srcs = [ - "tf_mlir_translate_cl.cc", - ], - hdrs = [ - "tf_mlir_translate_cl.h", - ], - deps = [ - "@llvm-project//llvm:Support", - ], - alwayslink = 1, -) - cc_library( name = "translate_registration", srcs = [ "tf_mlir_translate_registration.cc", ], deps = [ - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/compiler/mlir/tools:translate_cl_options", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_base", diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc index 4a07a184bbffb9..7d14d3e954b5f9 100644 --- a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/file_tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/core/framework/graph.pb.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index a09199d428cd6b..ecbc3de6bd7db4 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -978,9 +978,9 @@ cc_library( ":mlir_roundtrip_flags", ":serialize_mlir_module_utils", ":tensorflow", - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow/translate/tools:parsers", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", + "//tensorflow/compiler/mlir/tools:translate_cl_options", "//tensorflow/compiler/mlir/utils:string_container_utils", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_argument", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index 348ae41e3d2ebb..34917780dc80cb 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -48,13 +48,13 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_argument.h" diff --git a/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD b/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD index a5d8d8d8c5183f..f46627f0e43565 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD @@ -12,11 +12,11 @@ cc_library( "graph_to_tf_executor_registration.cc", ], deps = [ - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/compiler/mlir/tools:translate_cl_options", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_base", diff --git a/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc 
b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc index 8a9811c8dcbcbc..7b7b5771f5a4be 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc +++ b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc @@ -26,11 +26,11 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/file_tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/client_library.h" diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index babd62f6b13f89..80e58756bbfaad 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -33,8 +33,8 @@ limitations under the License. #include "mlir/Support/ToolUtilities.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/core/platform/init_main.h" // NOLINTNEXTLINE diff --git a/tensorflow/compiler/mlir/tools/BUILD b/tensorflow/compiler/mlir/tools/BUILD new file mode 100644 index 00000000000000..d3d0aa56ab97aa --- /dev/null +++ b/tensorflow/compiler/mlir/tools/BUILD @@ -0,0 +1,21 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +cc_library( + name = "translate_cl_options", + srcs = [ + "tf_mlir_translate_cl.cc", + ], + hdrs = [ + "tf_mlir_translate_cl.h", + ], + deps = [ + "@llvm-project//llvm:Support", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.cc b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.cc similarity index 98% rename from tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.cc rename to tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.cc index 46b3a5500052a9..db21d257cd58f5 100644 --- a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.cc +++ b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h similarity index 91% rename from tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h rename to tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h index b3da62caa95e5e..ef67186d206644 100644 --- a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h +++ b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_TF_MLIR_TRANSLATE_CL_H_ // This file contains command-line options aimed to provide the parameters // required by the TensorFlow Graph(Def) to MLIR module conversion. It is only @@ -51,4 +51,4 @@ extern llvm::cl::opt set_original_tf_func_name; extern llvm::cl::opt export_entry_func_to_flib; extern llvm::cl::opt export_original_tf_func_name; -#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_TF_MLIR_TRANSLATE_CL_H_ From 43d5e257613171aba5ec8e87c963fc27a83bb4eb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Apr 2025 16:02:40 -0700 Subject: [PATCH 0475/1324] Integrate LLVM at llvm/llvm-project@f280d60c9839 Updates LLVM usage to match [f280d60c9839](https://github.com/llvm/llvm-project/commit/f280d60c9839) PiperOrigin-RevId: 745768055 --- .../executor_tpuv1_inline_tpu_island.cc | 4 +- .../sparsecore/embedding_pipelining.cc | 5 +- .../compiler/mlir/tfr/passes/canonicalize.cc | 6 +- .../compiler/mlir/tfr/passes/decompose.cc | 4 +- third_party/llvm/generated.patch | 162 --- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 1153 ++++------------- third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/shardy/temporary.patch | 1153 ++++------------- .../xla/third_party/shardy/workspace.bzl | 4 +- .../deallocation/transforms/buffer_reuse.cc | 4 +- .../xla/mlir_hlo/deallocation/utils/util.h | 3 - 12 files changed, 468 insertions(+), 2038 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index 61bba38454afd8..8bdd088b2ddef2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -57,6 +58,7 @@ void ExecutorTPUV1IslandInliningPass::runOnOperation() { if (!nested_module) return; InlinerInterface inliner(&getContext()); + InlinerConfig config; auto walk_result = getOperation().walk([&](TF::PartitionedCallOp call_op) { if (!call_op.getF().getRootReference().getValue().starts_with( kNestedModule)) @@ -69,7 +71,7 @@ void ExecutorTPUV1IslandInliningPass::runOnOperation() { auto called_func = dyn_cast_or_null(call_interface.resolveCallable()); - if (failed(inlineCall(inliner, call_interface, + if (failed(inlineCall(inliner, config.getCloneCallback(), call_interface, cast(called_func.getOperation()), called_func.getCallableRegion(), /* shouldCloneInlinedRegion = */ false))) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc index ccd246bd0d85a2..d22180fdbe45f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc @@ -148,6 +148,7 @@ return selected_results #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" @@ -422,6 +423,7 @@ struct Inliner : public InlinerInterface { LogicalResult InlineCallsInFunc(func::FuncOp func, bool inline_all_funcs = false) { llvm::SetVector ops_to_erase; + InlinerConfig config; for (auto caller : func.getRegion().getOps()) { if (!inline_all_funcs && @@ -441,7 +443,8 @@ struct Inliner : public InlinerInterface { auto callee = llvm::dyn_cast(symbol_table.lookup(caller.getF())); auto& src_region = callee.getRegion(); - auto result = inlineCall(*this, caller, callee, &src_region, true); + auto result = inlineCall(*this, config.getCloneCallback(), caller, callee, + &src_region, true); if (failed(result)) { func.emitError("Inliner failed"); return result; diff --git a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc index 9cc555b7893563..fb0640536d4fe5 100644 --- a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc +++ b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "mlir/IR/Region.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" @@ -142,8 +143,9 @@ LogicalResult SimplifySCFIfOp::InlineRegion(Location loc, Operation *inline_point, Region *region) const { InlinerInterface interface(loc.getContext()); - if (failed(inlineRegion(interface, region, inline_point, {}, - inline_point->getResults(), loc, + InlinerConfig config; + if (failed(inlineRegion(interface, config.getCloneCallback(), region, + inline_point, {}, inline_point->getResults(), loc, /*shouldCloneInlinedRegion=*/true))) { return failure(); } diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc index 3a5d6f23072b00..105cd8de2041aa 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -47,6 +47,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -282,6 +283,7 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() { // The Inliner will automatically use the registered dialect inliner. InlinerInterface inliner(&getContext()); + InlinerConfig config; func::FuncOp func = getOperation(); SymbolTable table(external_tfr_module_.has_value() ? *external_tfr_module_ @@ -301,7 +303,7 @@ LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() { // Use the inliner to replace all the uses of the call_op by its // composition. - if (failed(inlineCall(inliner, + if (failed(inlineCall(inliner, config.getCloneCallback(), cast(call_op.getOperation()), cast(callee.getOperation()), callee.getCallableRegion(), diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 97282ecaf5c6ac..a3ecef4dbbedb7 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,16 +1,4 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp ---- a/clang/lib/AST/ASTContext.cpp -+++ b/clang/lib/AST/ASTContext.cpp -@@ -7011,7 +7011,7 @@ - getCanonicalTemplateArgument(subst->getArgumentPack()); - return getSubstTemplateTemplateParmPack( - canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), -- subst->getFinal(), subst->getIndex()); -+ subst->getIndex(), subst->getFinal()); - } - case TemplateName::DeducedTemplate: { - assert(IgnoreDeduced == false); diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -44,28 +32,6 @@ diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/ TemplateName Name = getDerived().RebuildTemplateName( SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, -diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ---- a/clang/lib/Serialization/ASTReaderStmt.cpp -+++ b/clang/lib/Serialization/ASTReaderStmt.cpp -@@ -2229,6 +2229,7 @@ - E->PackIndex = Record.readInt(); - else - E->PackIndex = 0; -+ E->Final = CurrentUnpackingBits->getNextBit(); - E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); - E->Replacement = Record.readSubExpr(); - } -diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ---- a/clang/lib/Serialization/ASTWriterStmt.cpp -+++ b/clang/lib/Serialization/ASTWriterStmt.cpp -@@ -2229,6 +2229,7 @@ - CurrentPackingBits.addBit((bool)E->getPackIndex()); - if (auto PackIndex = E->getPackIndex()) - Record.push_back(*PackIndex + 1); -+ CurrentPackingBits.addBit(E->getFinal()); - - Record.AddSourceLocation(E->getNameLoc()); - Record.AddStmt(E->getReplacement()); diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h --- a/clang/test/CodeGen/include/cuda.h +++ b/clang/test/CodeGen/include/cuda.h @@ -515,119 +481,6 @@ diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && E->getOpcode() == Instruction::Load) { Res = FindFirstInst(); -diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp ---- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp -+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp -@@ -2590,6 +2590,14 @@ - if (R.mayWriteToMemory() && !InterleaveR) - return; - -+ // Do not narrow interleave groups if there are VectorPointer recipes and -+ // the plan was unrolled. The recipe implicitly uses VF from -+ // VPTransformState. -+ // TODO: Remove restriction once the VF for the VectorPointer offset is -+ // modeled explicitly as operand. -+ if (isa(&R) && Plan.getUF() > 1) -+ return; -+ - // All other ops are allowed, but we reject uses that cannot be converted - // when checking all allowed consumers (store interleave groups) below. 
- if (!InterleaveR) -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll ---- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll -+++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll -@@ -66,3 +66,91 @@ - exit: - ret void - } -+ -+define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { -+; CHECK-LABEL: define void @test_2xi64_with_wide_load( -+; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { -+; CHECK-NEXT: [[ENTRY:.*]]: -+; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] -+; CHECK: [[VECTOR_PH]]: -+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] -+; CHECK: [[VECTOR_BODY]]: -+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -+; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 -+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] -+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 -+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 -+; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 -+; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 -+; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 [[INDEX]], 1 -+; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 -+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] -+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] -+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP8]], align 8 -+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> -+; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 -+; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> -+; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] -+; CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] -+; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] -+; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] -+; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> -+; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> -+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 -+; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> -+; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> -+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 -+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 -+; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 -+; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] -+; CHECK: [[MIDDLE_BLOCK]]: -+; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label 
%[[SCALAR_PH]] -+; CHECK: [[SCALAR_PH]]: -+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] -+; CHECK-NEXT: br label %[[LOOP:.*]] -+; CHECK: [[LOOP]]: -+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] -+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] -+; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 -+; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 -+; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] -+; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 -+; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] -+; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 -+; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 -+; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] -+; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 -+; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] -+; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 -+; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 -+; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 -+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] -+; CHECK: [[EXIT]]: -+; CHECK-NEXT: ret void -+; -+entry: -+ br label %loop -+ -+loop: -+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] -+ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv -+ %l.factor = load i64, ptr %arrayidx, align 8 -+ %1 = shl nsw i64 %iv, 1 -+ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 -+ %l.0 = load i64, ptr %data.0, align 8 -+ %mul.0 = mul i64 %l.factor, %l.0 -+ store i64 %mul.0, ptr %data.0, align 8 -+ %3 = or disjoint i64 %1, 1 -+ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 -+ %l.1 = load i64, ptr %data.1, align 8 -+ %mul.1 = mul i64 %l.factor, %l.1 -+ store i64 %mul.1, ptr %data.1, align 8 -+ %iv.next = add nuw nsw i64 %iv, 1 -+ %ec = icmp eq i64 %iv.next, 100 -+ br i1 %ec, label %exit, label %loop -+ -+exit: -+ ret void -+} diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll --- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll @@ -731,18 +584,3 @@ diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-nod + store i8 %58, ptr %59, align 1 + ret void +} -diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp ---- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp -+++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp -@@ -151,9 +151,10 @@ - MachineModuleInfoWrapperPass *MMIWP = - new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); - -- legacy::PassManager PassMgrF; - SmallString<1024> Buf; - llvm::raw_svector_ostream OS(Buf); -+ legacy::PassManager PassMgrF; -+ - AsmPrinter *Printer = - addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); - PassMgrF.run(*M); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index c3bcd538474076..73450ce1ae572a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = 
"cd54cb062bba9c90a8f3723bf66caa7effbcf259" - LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" + LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" + LLVM_SHA256 = "4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index e0644f1eee41da..bd3beff6435ac1 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,949 +1,242 @@ -diff --git a/shardy/dialect/sdy/transforms/export/passes.td b/shardy/dialect/sdy/transforms/export/passes.td -index 58c9f74..64cfe7f 100644 ---- a/shardy/dialect/sdy/transforms/export/passes.td -+++ b/shardy/dialect/sdy/transforms/export/passes.td -@@ -114,8 +114,8 @@ def TempExplicitReshardsForOptimizationsPass : Pass<"sdy-temp-explicit-reshards- - This pass is a temporary solution until we can enable the - `sdy-insert-explicit-reshards` pass by default. - -- It allows us to insert explicit reshards on specific operations for -- optimizations. -+ It allows us to improve specific use cases where the partitioner does the -+ sub-optimal thing. - }]; - } - -diff --git a/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc b/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc -index b20b794..0642e3c 100644 ---- a/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc -+++ b/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc -@@ -29,7 +29,6 @@ limitations under the License. +diff --git a/shardy/dialect/sdy/ir/BUILD b/shardy/dialect/sdy/ir/BUILD +index 780cd17..fe8986b 100644 +--- a/shardy/dialect/sdy/ir/BUILD ++++ b/shardy/dialect/sdy/ir/BUILD +@@ -164,6 +164,7 @@ cc_library( + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_assembly_format", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_type_inference", +diff --git a/shardy/dialect/sdy/ir/canonicalization.cc b/shardy/dialect/sdy/ir/canonicalization.cc +index e1b391f..7ab3e28 100644 +--- a/shardy/dialect/sdy/ir/canonicalization.cc ++++ b/shardy/dialect/sdy/ir/canonicalization.cc +@@ -25,6 +25,7 @@ limitations under the License. 
+ #include "mlir/IR/Region.h" + #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" ++#include "mlir/Transforms/Inliner.h" + #include "mlir/Transforms/InliningUtils.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" --#include "shardy/dialect/sdy/transforms/export/explicit_reshards_util.h" - #include "shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep - #include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h" - #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h" -@@ -236,9 +235,6 @@ struct TempExplicitReshardsForOptimizationsPass - [&](stablehlo::DotGeneralOp dotGeneralOp) { - processDot(dotGeneralOp, rewriter, symbolTable); - }); -- if (op->getName().getStringRef().str() == "mhlo.ragged_dot") { -- insertExplicitReshardsOnOp(op, rewriter, symbolTable); -- } - }); - } - }; -diff --git a/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir b/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir -index 48bcbcb..117954c 100644 ---- a/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir -+++ b/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir -@@ -1,8 +1,7 @@ --// RUN: sdy_opt %s -allow-unregistered-dialect -sdy-temp-explicit-reshards-for-optimizations | FileCheck %s -+// RUN: sdy_opt %s -sdy-temp-explicit-reshards-for-optimizations | FileCheck %s - - sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]> - sdy.mesh @other_mesh = <["x"=2, "y"=2]> --sdy.mesh @mesh_abcd = <["a"=2, "b"=2, "c"=2, "d"=2]> - - // CHECK-LABEL: func @reshard_dot_result_to_match_lhs - func.func @reshard_dot_result_to_match_lhs( -@@ -317,77 +316,3 @@ func.func @dot_result_conflicting_sharding_mismatch_with_reduction_axes_3( - (tensor<4x2x32xf32>, tensor<2x32x8xf32>) -> tensor<4x8xf32> - return %0 : tensor<4x8xf32> - } -- --// CHECK-LABEL: func @ragged_dot_mode_non_contracting --func.func @ragged_dot_mode_non_contracting( -- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, -- %arg1: tensor<4x16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]>}, -- %arg2: tensor<16x4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}]>}) -> tensor<16x32x8xf32> { -- // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh_abcd, [{"a"}, {}, {"c"}]> : tensor<16x32x64xf32> -- // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_abcd, [{}, {"a"}, {"c"}, {"d"}]> : tensor<4x16x64x8xf32> -- // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{"a"}, {}]> : tensor<16x4xi32> -- -- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%[[RESHARD0]], %[[RESHARD1]], %[[RESHARD2]]) <{ -- // CHECK: }> -- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {}, {"d"}]>]> -- -- // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {}, {"d"}]> : tensor<16x32x8xf32> -- // CHECK: %[[RESHARD3:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh_abcd, [{"a"}, {"b"}, {"c"}]> : tensor<16x32x8xf32> -- // CHECK: return %[[RESHARD3]] : tensor<16x32x8xf32> -- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = -- #mhlo.ragged_dot, -- lhs_ragged_dimensions = [1], rhs_group_dimensions = [0]>}> -- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [m, i, l, k], [i, m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=4} reduction={l} need_replication={j, m}>, -- sdy.sharding = 
#sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>]>} -- : (tensor<16x32x64xf32>, tensor<4x16x64x8xf32>, tensor<16x4xi32>) -> tensor<16x32x8xf32> -- return %0 : tensor<16x32x8xf32> --} -- --// CHECK-LABEL: func @ragged_dot_mode_contracting --func.func @ragged_dot_mode_contracting( -- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, -- %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, -- %arg2: tensor<16x4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}]>}) -> tensor<4x16x32x8xf32> { -- // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh_abcd, [{"a"}, {"b"}, {}]> : tensor<16x32x64xf32> -- // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_abcd, [{"a"}, {}, {"d"}]> : tensor<16x64x8xf32> -- // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{"a"}, {}]> : tensor<16x4xi32> -- -- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%[[RESHARD0]], %[[RESHARD1]], %[[RESHARD2]]) <{ -- // CHECK: }> -- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{}, {"a"}, {"b"}, {"d"}]>]> -- -- // CHECK: %[[RESHARD3:.*]] = sdy.reshard %[[RAGGED_DOT]] <@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]> : tensor<4x16x32x8xf32> -- // CHECK: return %[[RESHARD3]] : tensor<4x16x32x8xf32> -- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = -- #mhlo.ragged_dot, -- lhs_ragged_dimensions = [2]>}> -- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [i, m])->([m, i, j, k]) {i=16, j=32, k=8, l=64, m=4} need_replication={l, m}>, -- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]>]>} -- : (tensor<16x32x64xf32>, tensor<16x64x8xf32>, tensor<16x4xi32>) -> tensor<4x16x32x8xf32> -- return %0 : tensor<4x16x32x8xf32> --} -- --// CHECK-LABEL: func @ragged_dot_mode_batch --func.func @ragged_dot_mode_batch( -- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, -- %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>}, -- %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> tensor<16x32x8xf32> { -- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ -- // CHECK: }> -- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]> -- // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {"b"}, {"d"}]> : tensor<16x32x8xf32> -- // CHECK: return %[[ALL_REDUCE]] : tensor<16x32x8xf32> -- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = -- #mhlo.ragged_dot, -- lhs_ragged_dimensions = [0]>}> -- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=1} reduction={l}>, -- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>} -- : (tensor<16x32x64xf32>, tensor<16x64x8xf32>, tensor<4xi32>) -> tensor<16x32x8xf32> -- return %0 : tensor<16x32x8xf32> --} -diff --git a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc -index 6cfed8f..4061903 100644 ---- a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc -+++ b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc -@@ -117,8 +117,8 @@ GroupIdToShardingGroups unifyShardingGroups( - int64_t reindexId = 0; - SmallDenseMap reindexMap; - for (const auto& group : shardingGroupEquivalences) { -- if 
(group.isLeader()) { -- reindexMap[group.getData()] = reindexId++; -+ if (group->isLeader()) { -+ reindexMap[group->getData()] = reindexId++; +@@ -103,9 +104,11 @@ class RedundantManualComputationPattern } - } - -diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir -index 97099a1..6a711ae 100644 ---- a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir -+++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir -@@ -64,8 +64,8 @@ func.func @sharding_groups_reindexes_ids(%arg0: tensor<4xf32>, %arg1: tensor<4xf - // CHECK-LABEL: sharding_groups_reindex_ordering_matches_min_element_ordering - func.func @sharding_groups_reindex_ordering_matches_min_element_ordering(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) { -- // CHECK: sdy.sharding_group %arg0 group_id=1 : tensor<4xf32> -- // CHECK: sdy.sharding_group %arg1 group_id=0 : tensor<4xf32> -+ // CHECK: sdy.sharding_group %arg0 group_id=0 : tensor<4xf32> -+ // CHECK: sdy.sharding_group %arg1 group_id=1 : tensor<4xf32> - // CHECK: sdy.sharding_group %arg2 group_id=2 : tensor<4xf32> - sdy.sharding_group %arg0 group_id = 567 : tensor<4xf32> - sdy.sharding_group %arg0 group_id = 23 : tensor<4xf32> + mlir::InlinerInterface inliner(manualComputationOp.getContext()); ++ mlir::InlinerConfig config; + if (inlineRegion( +- inliner, &manualComputationOp.getRegion(), +- manualComputationOp->getBlock(), manualComputationOp->getIterator(), ++ inliner, config.getCloneCallback(), ++ &manualComputationOp.getRegion(), manualComputationOp->getBlock(), ++ manualComputationOp->getIterator(), + manualComputationOp.getOperands(), manualComputationOp.getResults()) + .failed()) { + manualComputationOp.emitOpError( diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 2e6ff58..97282ec 100644 +index 97282ec..a3ecef4 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,23 +1,748 @@ +@@ -1,16 +1,4 @@ Auto generated patch. Do not edit or delete it, even if empty. 
--diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp ----- a/clang/lib/Sema/SemaCXXScopeSpec.cpp --+++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp --@@ -873,6 +873,7 @@ -- DependentTemplateSpecializationTypeLoc SpecTL -- = Builder.push(T); -- SpecTL.setElaboratedKeywordLoc(SourceLocation()); --+ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); -- SpecTL.setTemplateKeywordLoc(TemplateKWLoc); -- SpecTL.setTemplateNameLoc(TemplateNameLoc); -- SpecTL.setLAngleLoc(LAngleLoc); --diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ----- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --@@ -1902,7 +1902,6 @@ -- name = "inv_trigf_utils", -- srcs = ["src/math/generic/inv_trigf_utils.cpp"], -- hdrs = [ --- "src/math/generic/atan_utils.h", -- "src/math/generic/inv_trigf_utils.h", -- ], -- deps = [ -+diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp -+--- a/clang/lib/AST/ASTContext.cpp -++++ b/clang/lib/AST/ASTContext.cpp -+@@ -7011,7 +7011,7 @@ -+ getCanonicalTemplateArgument(subst->getArgumentPack()); -+ return getSubstTemplateTemplateParmPack( -+ canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), -+- subst->getFinal(), subst->getIndex()); -++ subst->getIndex(), subst->getFinal()); -+ } -+ case TemplateName::DeducedTemplate: { -+ assert(IgnoreDeduced == false); -+diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h -+--- a/clang/lib/Sema/TreeTransform.h -++++ b/clang/lib/Sema/TreeTransform.h -+@@ -7765,17 +7765,23 @@ -+ NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); -+ NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); -+ -+- typedef TemplateArgumentLocContainerIterator< -+- DependentTemplateSpecializationTypeLoc> ArgIterator; -+- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), -+- ArgIterator(TL, TL.getNumArgs()), -+- NewTemplateArgs)) -++ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); -++ -++ if (getDerived().TransformTemplateArguments(ArgsRange.begin(), -++ ArgsRange.end(), NewTemplateArgs)) -+ return QualType(); -++ bool TemplateArgumentsChanged = !llvm::equal( -++ ArgsRange, NewTemplateArgs.arguments(), -++ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { -++ return A.getArgument().structurallyEquals(B.getArgument()); -++ }); -+ -+ const DependentTemplateStorage &DTN = T->getDependentTemplateName(); -+ -+ QualType Result = TL.getType(); -+- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { -++ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || -++ TemplateArgumentsChanged) { -+ TemplateName Name = getDerived().RebuildTemplateName( -+ SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), -+ /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, -+diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp -+--- a/clang/lib/Serialization/ASTReaderStmt.cpp -++++ b/clang/lib/Serialization/ASTReaderStmt.cpp -+@@ -2229,6 +2229,7 @@ -+ E->PackIndex = Record.readInt(); -+ else -+ E->PackIndex = 0; -++ E->Final = CurrentUnpackingBits->getNextBit(); -+ E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); -+ E->Replacement = Record.readSubExpr(); -+ } -+diff -ruN --strip-trailing-cr 
a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp -+--- a/clang/lib/Serialization/ASTWriterStmt.cpp -++++ b/clang/lib/Serialization/ASTWriterStmt.cpp -+@@ -2229,6 +2229,7 @@ -+ CurrentPackingBits.addBit((bool)E->getPackIndex()); -+ if (auto PackIndex = E->getPackIndex()) -+ Record.push_back(*PackIndex + 1); -++ CurrentPackingBits.addBit(E->getFinal()); -+ -+ Record.AddSourceLocation(E->getNameLoc()); -+ Record.AddStmt(E->getReplacement()); -+diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h -+--- a/clang/test/CodeGen/include/cuda.h -++++ b/clang/test/CodeGen/include/cuda.h -+@@ -1,194 +0,0 @@ -+-/* Minimal declarations for CUDA support. Testing purposes only. -+- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h -+- */ -+-#pragma once -+- -+-// Make this file work with nvcc, for testing compatibility. -+- -+-#ifndef __NVCC__ -+-#define __constant__ __attribute__((constant)) -+-#define __device__ __attribute__((device)) -+-#define __global__ __attribute__((global)) -+-#define __host__ __attribute__((host)) -+-#define __shared__ __attribute__((shared)) -+-#define __managed__ __attribute__((managed)) -+-#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) -+- -+-struct dim3 { -+- unsigned x, y, z; -+- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} -+-}; -+- -+-// Host- and device-side placement new overloads. -+-void *operator new(__SIZE_TYPE__, void *p) { return p; } -+-void *operator new[](__SIZE_TYPE__, void *p) { return p; } -+-__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } -+-__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } -+- -+-#define CUDA_VERSION 10100 -+- -+-struct char1 { -+- char x; -+- __host__ __device__ char1(char x = 0) : x(x) {} -+-}; -+-struct char2 { -+- char x, y; -+- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} -+-}; -+-struct char4 { -+- char x, y, z, w; -+- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct uchar1 { -+- unsigned char x; -+- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} -+-}; -+-struct uchar2 { -+- unsigned char x, y; -+- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} -+-}; -+-struct uchar4 { -+- unsigned char x, y, z, w; -+- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct short1 { -+- short x; -+- __host__ __device__ short1(short x = 0) : x(x) {} -+-}; -+-struct short2 { -+- short x, y; -+- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} -+-}; -+-struct short4 { -+- short x, y, z, w; -+- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct ushort1 { -+- unsigned short x; -+- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} -+-}; -+-struct ushort2 { -+- unsigned short x, y; -+- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} -+-}; -+-struct ushort4 { -+- unsigned short x, y, z, w; -+- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct int1 { -+- int x; -+- __host__ __device__ int1(int x = 0) : x(x) {} -+-}; -+-struct int2 { -+- int x, 
y; -+- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} -+-}; -+-struct int4 { -+- int x, y, z, w; -+- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct uint1 { -+- unsigned x; -+- __host__ __device__ uint1(unsigned x = 0) : x(x) {} -+-}; -+-struct uint2 { -+- unsigned x, y; -+- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} -+-}; -+-struct uint3 { -+- unsigned x, y, z; -+- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} -+-}; -+-struct uint4 { -+- unsigned x, y, z, w; -+- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct longlong1 { -+- long long x; -+- __host__ __device__ longlong1(long long x = 0) : x(x) {} -+-}; -+-struct longlong2 { -+- long long x, y; -+- __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} -+-}; -+-struct longlong4 { -+- long long x, y, z, w; -+- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct ulonglong1 { -+- unsigned long long x; -+- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} -+-}; -+-struct ulonglong2 { -+- unsigned long long x, y; -+- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} -+-}; -+-struct ulonglong4 { -+- unsigned long long x, y, z, w; -+- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct float1 { -+- float x; -+- __host__ __device__ float1(float x = 0) : x(x) {} -+-}; -+-struct float2 { -+- float x, y; -+- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} -+-}; -+-struct float4 { -+- float x, y, z, w; -+- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct double1 { -+- double x; -+- __host__ __device__ double1(double x = 0) : x(x) {} -+-}; -+-struct double2 { -+- double x, y; -+- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} -+-}; -+-struct double4 { -+- double x, y, z, w; -+- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-typedef unsigned long long cudaTextureObject_t; -+-typedef unsigned long long cudaSurfaceObject_t; -+- -+-enum cudaTextureReadMode { -+- cudaReadModeNormalizedFloat, -+- cudaReadModeElementType -+-}; -+- -+-enum cudaSurfaceBoundaryMode { -+- cudaBoundaryModeZero, -+- cudaBoundaryModeClamp, -+- cudaBoundaryModeTrap -+-}; -+- -+-enum { -+- cudaTextureType1D, -+- cudaTextureType2D, -+- cudaTextureType3D, -+- cudaTextureTypeCubemap, -+- cudaTextureType1DLayered, -+- cudaTextureType2DLayered, -+- cudaTextureTypeCubemapLayered -+-}; -+- -+-struct textureReference {}; -+-template -+-struct __attribute__((device_builtin_texture_type)) texture -+- : public textureReference {}; -+- -+-#endif // !__NVCC__ -+diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h -+--- a/clang/test/CodeGen/Inputs/cuda.h -++++ b/clang/test/CodeGen/Inputs/cuda.h -+@@ -0,0 +1,194 @@ -++/* Minimal declarations for CUDA support. Testing purposes only. 
-++ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h -++ */ -++#pragma once -++ -++// Make this file work with nvcc, for testing compatibility. -++ -++#ifndef __NVCC__ -++#define __constant__ __attribute__((constant)) -++#define __device__ __attribute__((device)) -++#define __global__ __attribute__((global)) -++#define __host__ __attribute__((host)) -++#define __shared__ __attribute__((shared)) -++#define __managed__ __attribute__((managed)) -++#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) -++ -++struct dim3 { -++ unsigned x, y, z; -++ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} -++}; -++ -++// Host- and device-side placement new overloads. -++void *operator new(__SIZE_TYPE__, void *p) { return p; } -++void *operator new[](__SIZE_TYPE__, void *p) { return p; } -++__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } -++__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } -++ -++#define CUDA_VERSION 10100 -++ -++struct char1 { -++ char x; -++ __host__ __device__ char1(char x = 0) : x(x) {} -++}; -++struct char2 { -++ char x, y; -++ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} -++}; -++struct char4 { -++ char x, y, z, w; -++ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct uchar1 { -++ unsigned char x; -++ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} -++}; -++struct uchar2 { -++ unsigned char x, y; -++ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} -++}; -++struct uchar4 { -++ unsigned char x, y, z, w; -++ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct short1 { -++ short x; -++ __host__ __device__ short1(short x = 0) : x(x) {} -++}; -++struct short2 { -++ short x, y; -++ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} -++}; -++struct short4 { -++ short x, y, z, w; -++ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct ushort1 { -++ unsigned short x; -++ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} -++}; -++struct ushort2 { -++ unsigned short x, y; -++ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} -++}; -++struct ushort4 { -++ unsigned short x, y, z, w; -++ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct int1 { -++ int x; -++ __host__ __device__ int1(int x = 0) : x(x) {} -++}; -++struct int2 { -++ int x, y; -++ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} -++}; -++struct int4 { -++ int x, y, z, w; -++ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct uint1 { -++ unsigned x; -++ __host__ __device__ uint1(unsigned x = 0) : x(x) {} -++}; -++struct uint2 { -++ unsigned x, y; -++ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} -++}; -++struct uint3 { -++ unsigned x, y, z; -++ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} -++}; -++struct uint4 { -++ unsigned x, y, z, w; -++ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} -++}; 
-++ -++struct longlong1 { -++ long long x; -++ __host__ __device__ longlong1(long long x = 0) : x(x) {} -++}; -++struct longlong2 { -++ long long x, y; -++ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} -++}; -++struct longlong4 { -++ long long x, y, z, w; -++ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct ulonglong1 { -++ unsigned long long x; -++ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} -++}; -++struct ulonglong2 { -++ unsigned long long x, y; -++ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} -++}; -++struct ulonglong4 { -++ unsigned long long x, y, z, w; -++ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct float1 { -++ float x; -++ __host__ __device__ float1(float x = 0) : x(x) {} -++}; -++struct float2 { -++ float x, y; -++ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} -++}; -++struct float4 { -++ float x, y, z, w; -++ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct double1 { -++ double x; -++ __host__ __device__ double1(double x = 0) : x(x) {} -++}; -++struct double2 { -++ double x, y; -++ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} -++}; -++struct double4 { -++ double x, y, z, w; -++ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++typedef unsigned long long cudaTextureObject_t; -++typedef unsigned long long cudaSurfaceObject_t; -++ -++enum cudaTextureReadMode { -++ cudaReadModeNormalizedFloat, -++ cudaReadModeElementType -++}; -++ -++enum cudaSurfaceBoundaryMode { -++ cudaBoundaryModeZero, -++ cudaBoundaryModeClamp, -++ cudaBoundaryModeTrap -++}; -++ -++enum { -++ cudaTextureType1D, -++ cudaTextureType2D, -++ cudaTextureType3D, -++ cudaTextureTypeCubemap, -++ cudaTextureType1DLayered, -++ cudaTextureType2DLayered, -++ cudaTextureTypeCubemapLayered -++}; -++ -++struct textureReference {}; -++template -++struct __attribute__((device_builtin_texture_type)) texture -++ : public textureReference {}; -++ -++#endif // !__NVCC__ -+diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu -+--- a/clang/test/CodeGen/nvptx-surface.cu -++++ b/clang/test/CodeGen/nvptx-surface.cu -+@@ -1,6 +1,6 @@ -+ // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s -+ // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s -+-#include "include/cuda.h" -++#include "Inputs/cuda.h" -+ -+ #include "__clang_cuda_texture_intrinsics.h" -+ -+diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp -+--- a/clang/test/SemaTemplate/dependent-names.cpp -++++ b/clang/test/SemaTemplate/dependent-names.cpp -+@@ -458,3 +458,12 @@ -+ }; -+ int f(b ba) { return ba.add<0>(); } -+ } -++ -++namespace TransformDependentTemplates { -++ template struct Test1 { -++ template -++ using Arg = typename T::template Arg; -++ void f(Arg); -++ void f(Arg); -++ }; -++} // namespace TransformDependentTemplates -+diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp 
b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp -+--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp -++++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp -+@@ -15391,12 +15391,20 @@ -+ -+ if (E->State == TreeEntry::SplitVectorize) { -+ Res = FindLastInst(); -++ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { -++ for (auto *E : Entries) { -++ auto *I = dyn_cast_or_null(E->VectorizedValue); -++ if (!I) -++ I = &getLastInstructionInBundle(E); -++ if (Res->comesBefore(I)) -++ Res = I; -++ } -++ } -+ return *Res; -+ } -+ -+ // Set insertpoint for gathered loads to the very first load. -+- if (E->State != TreeEntry::SplitVectorize && -+- GatheredLoadsEntriesFirst.has_value() && -++ if (GatheredLoadsEntriesFirst.has_value() && -+ E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && -+ E->getOpcode() == Instruction::Load) { -+ Res = FindFirstInst(); -+diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp -+--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp -++++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp -+@@ -2590,6 +2590,14 @@ -+ if (R.mayWriteToMemory() && !InterleaveR) -+ return; -+ -++ // Do not narrow interleave groups if there are VectorPointer recipes and -++ // the plan was unrolled. The recipe implicitly uses VF from -++ // VPTransformState. -++ // TODO: Remove restriction once the VF for the VectorPointer offset is -++ // modeled explicitly as operand. -++ if (isa(&R) && Plan.getUF() > 1) -++ return; -++ -+ // All other ops are allowed, but we reject uses that cannot be converted -+ // when checking all allowed consumers (store interleave groups) below. -+ if (!InterleaveR) -+diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll -+--- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll -++++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll -+@@ -66,3 +66,91 @@ -+ exit: -+ ret void -+ } -++ -++define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { -++; CHECK-LABEL: define void @test_2xi64_with_wide_load( -++; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { -++; CHECK-NEXT: [[ENTRY:.*]]: -++; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] -++; CHECK: [[VECTOR_PH]]: -++; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] -++; CHECK: [[VECTOR_BODY]]: -++; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -++; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 -++; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] -++; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 -++; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 -++; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 -++; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 -++; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 [[INDEX]], 1 -++; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 -++; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] -++; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] -++; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, 
ptr [[TMP8]], align 8 -++; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> -++; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> -++; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 -++; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> -++; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> -++; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] -++; CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] -++; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] -++; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] -++; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> -++; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> -++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 -++; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> -++; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> -++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 -++; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 -++; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 -++; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] -++; CHECK: [[MIDDLE_BLOCK]]: -++; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]] -++; CHECK: [[SCALAR_PH]]: -++; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] -++; CHECK-NEXT: br label %[[LOOP:.*]] -++; CHECK: [[LOOP]]: -++; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] -++; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] -++; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 -++; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 -++; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] -++; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 -++; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] -++; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 -++; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 -++; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] -++; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 -++; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] -++; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 -++; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 -++; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 -++; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] -++; CHECK: [[EXIT]]: -++; CHECK-NEXT: ret void -++; -++entry: -++ br label %loop -++ -++loop: -++ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] -++ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv -++ %l.factor = load i64, ptr %arrayidx, align 8 -++ %1 = shl nsw i64 %iv, 1 -++ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 -++ %l.0 = load i64, ptr %data.0, align 8 -++ %mul.0 = mul i64 %l.factor, %l.0 -++ store i64 %mul.0, ptr 
%data.0, align 8 -++ %3 = or disjoint i64 %1, 1 -++ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 -++ %l.1 = load i64, ptr %data.1, align 8 -++ %mul.1 = mul i64 %l.factor, %l.1 -++ store i64 %mul.1, ptr %data.1, align 8 -++ %iv.next = add nuw nsw i64 %iv, 1 -++ %ec = icmp eq i64 %iv.next, 100 -++ br i1 %ec, label %exit, label %loop -++ -++exit: -++ ret void -++} -+diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -+--- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -++++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -+@@ -0,0 +1,99 @@ -++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 -++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s -++ -++define void @test(ptr %0, <8 x i8> %1) { -++; CHECK-LABEL: define void @test( -++; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { -++; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 -++; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 -++; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 -++; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 -++; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 -++; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> -++; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 -++; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> -++; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> -++; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) -++; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 -++; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> -++; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) -++; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) -++; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] -++; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 -++; CHECK-NEXT: ret void -++; -++ %3 = load i8, ptr %0, align 2 -++ %4 = getelementptr i8, ptr %0, i64 13442 -++ %5 = load i8, ptr %4, align 2 -++ %6 = or i8 %5, %3 -++ %7 = getelementptr i8, ptr %0, i64 13550 -++ store i8 %6, ptr %7, align 2 -++ %8 = extractelement <8 x i8> %1, i64 0 -++ %9 = or i8 %5, %8 -++ %10 = getelementptr i8, ptr %0, i64 13542 -++ store i8 %9, ptr %10, align 2 -++ %11 = getelementptr i8, ptr %0, i64 13438 -++ %12 = load i8, ptr %11, align 2 -++ %13 = or i8 %12, %3 -++ %14 = getelementptr i8, ptr %0, i64 13546 -++ store i8 %13, ptr %14, align 2 -++ %15 = extractelement <8 x i8> %1, i64 2 -++ %16 = or i8 %12, %15 -++ %17 = getelementptr i8, ptr %0, i64 13538 -++ store i8 %16, ptr %17, align 2 -++ %18 = getelementptr i8, ptr %0, i64 13440 -++ %19 = load i8, ptr %18, align 4 -++ %20 = or i8 %19, %3 -++ %21 = getelementptr i8, ptr %0, i64 13548 -++ store i8 %20, ptr %21, align 4 -++ %22 = extractelement <8 x i8> %1, i64 4 -++ %23 = or i8 %19, %22 -++ %24 = getelementptr i8, ptr %0, i64 13540 -++ store i8 %23, ptr %24, align 4 -++ %25 = 
getelementptr i8, ptr %0, i64 13436 -++ %26 = load i8, ptr %25, align 4 -++ %27 = getelementptr i8, ptr %0, i64 13444 -++ %28 = load i8, ptr %27, align 4 -++ %29 = or i8 %28, %26 -++ %30 = getelementptr i8, ptr %0, i64 13544 -++ store i8 %29, ptr %30, align 4 -++ %31 = or i8 %26, %8 -++ %32 = getelementptr i8, ptr %0, i64 13536 -++ store i8 %31, ptr %32, align 4 -++ %33 = getelementptr i8, ptr %0, i64 13443 -++ %34 = load i8, ptr %33, align 1 -++ %35 = or i8 %34, %3 -++ %36 = getelementptr i8, ptr %0, i64 13551 -++ store i8 %35, ptr %36, align 1 -++ %37 = extractelement <8 x i8> %1, i64 7 -++ %38 = or i8 %34, %37 -++ %39 = getelementptr i8, ptr %0, i64 13543 -++ store i8 %38, ptr %39, align 1 -++ %40 = getelementptr i8, ptr %0, i64 13439 -++ %41 = load i8, ptr %40, align 1 -++ %42 = or i8 %41, %3 -++ %43 = getelementptr i8, ptr %0, i64 13547 -++ store i8 %42, ptr %43, align 1 -++ %44 = extractelement <8 x i8> %1, i64 3 -++ %45 = or i8 %41, %44 -++ %46 = getelementptr i8, ptr %0, i64 13539 -++ store i8 %45, ptr %46, align 1 -++ %47 = getelementptr i8, ptr %0, i64 13441 -++ %48 = load i8, ptr %47, align 1 -++ %49 = or i8 %48, %3 -++ %50 = getelementptr i8, ptr %0, i64 13549 -++ store i8 %49, ptr %50, align 1 -++ %51 = extractelement <8 x i8> %1, i64 5 -++ %52 = or i8 %48, %51 -++ %53 = getelementptr i8, ptr %0, i64 13541 -++ store i8 %52, ptr %53, align 1 -++ %54 = getelementptr i8, ptr %0, i64 13437 -++ %55 = load i8, ptr %54, align 1 -++ %56 = or i8 %55, %3 -++ %57 = getelementptr i8, ptr %0, i64 13545 -++ store i8 %56, ptr %57, align 1 -++ %58 = or i8 %55, %8 -++ %59 = getelementptr i8, ptr %0, i64 13537 -++ store i8 %58, ptr %59, align 1 -++ ret void -++} -+diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp -+--- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp -++++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp -+@@ -151,9 +151,10 @@ -+ MachineModuleInfoWrapperPass *MMIWP = -+ new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); -+ -+- legacy::PassManager PassMgrF; -+ SmallString<1024> Buf; -+ llvm::raw_svector_ostream OS(Buf); -++ legacy::PassManager PassMgrF; -++ -+ AsmPrinter *Printer = -+ addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); -+ PassMgrF.run(*M); +-diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp +---- a/clang/lib/AST/ASTContext.cpp +-+++ b/clang/lib/AST/ASTContext.cpp +-@@ -7011,7 +7011,7 @@ +- getCanonicalTemplateArgument(subst->getArgumentPack()); +- return getSubstTemplateTemplateParmPack( +- canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), +-- subst->getFinal(), subst->getIndex()); +-+ subst->getIndex(), subst->getFinal()); +- } +- case TemplateName::DeducedTemplate: { +- assert(IgnoreDeduced == false); + diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h + --- a/clang/lib/Sema/TreeTransform.h + +++ b/clang/lib/Sema/TreeTransform.h +@@ -44,28 +32,6 @@ diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/ + TemplateName Name = getDerived().RebuildTemplateName( + SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), + /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, +-diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp +---- a/clang/lib/Serialization/ASTReaderStmt.cpp +-+++ b/clang/lib/Serialization/ASTReaderStmt.cpp +-@@ -2229,6 +2229,7 @@ +- 
E->PackIndex = Record.readInt(); +- else +- E->PackIndex = 0; +-+ E->Final = CurrentUnpackingBits->getNextBit(); +- E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); +- E->Replacement = Record.readSubExpr(); +- } +-diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp +---- a/clang/lib/Serialization/ASTWriterStmt.cpp +-+++ b/clang/lib/Serialization/ASTWriterStmt.cpp +-@@ -2229,6 +2229,7 @@ +- CurrentPackingBits.addBit((bool)E->getPackIndex()); +- if (auto PackIndex = E->getPackIndex()) +- Record.push_back(*PackIndex + 1); +-+ CurrentPackingBits.addBit(E->getFinal()); +- +- Record.AddSourceLocation(E->getNameLoc()); +- Record.AddStmt(E->getReplacement()); + diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h + --- a/clang/test/CodeGen/include/cuda.h + +++ b/clang/test/CodeGen/include/cuda.h +@@ -515,119 +481,6 @@ diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp + E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && + E->getOpcode() == Instruction::Load) { + Res = FindFirstInst(); +-diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +---- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +-+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +-@@ -2590,6 +2590,14 @@ +- if (R.mayWriteToMemory() && !InterleaveR) +- return; +- +-+ // Do not narrow interleave groups if there are VectorPointer recipes and +-+ // the plan was unrolled. The recipe implicitly uses VF from +-+ // VPTransformState. +-+ // TODO: Remove restriction once the VF for the VectorPointer offset is +-+ // modeled explicitly as operand. +-+ if (isa(&R) && Plan.getUF() > 1) +-+ return; +-+ +- // All other ops are allowed, but we reject uses that cannot be converted +- // when checking all allowed consumers (store interleave groups) below. 
+- if (!InterleaveR) +-diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +---- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +-+++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +-@@ -66,3 +66,91 @@ +- exit: +- ret void +- } +-+ +-+define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { +-+; CHECK-LABEL: define void @test_2xi64_with_wide_load( +-+; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { +-+; CHECK-NEXT: [[ENTRY:.*]]: +-+; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] +-+; CHECK: [[VECTOR_PH]]: +-+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] +-+; CHECK: [[VECTOR_BODY]]: +-+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] +-+; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 +-+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] +-+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 +-+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 +-+; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 +-+; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 +-+; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 [[INDEX]], 1 +-+; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 +-+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] +-+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] +-+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP8]], align 8 +-+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> +-+; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> +-+; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 +-+; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> +-+; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> +-+; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] +-+; CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] +-+; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] +-+; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] +-+; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> +-+; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> +-+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 +-+; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> +-+; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> +-+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 +-+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 +-+; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 +-+; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] +-+; CHECK: [[MIDDLE_BLOCK]]: +-+; CHECK-NEXT: 
br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]] +-+; CHECK: [[SCALAR_PH]]: +-+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] +-+; CHECK-NEXT: br label %[[LOOP:.*]] +-+; CHECK: [[LOOP]]: +-+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] +-+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] +-+; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 +-+; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 +-+; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] +-+; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 +-+; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] +-+; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 +-+; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 +-+; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] +-+; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 +-+; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] +-+; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 +-+; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 +-+; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 +-+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] +-+; CHECK: [[EXIT]]: +-+; CHECK-NEXT: ret void +-+; +-+entry: +-+ br label %loop +-+ +-+loop: +-+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] +-+ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv +-+ %l.factor = load i64, ptr %arrayidx, align 8 +-+ %1 = shl nsw i64 %iv, 1 +-+ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 +-+ %l.0 = load i64, ptr %data.0, align 8 +-+ %mul.0 = mul i64 %l.factor, %l.0 +-+ store i64 %mul.0, ptr %data.0, align 8 +-+ %3 = or disjoint i64 %1, 1 +-+ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 +-+ %l.1 = load i64, ptr %data.1, align 8 +-+ %mul.1 = mul i64 %l.factor, %l.1 +-+ store i64 %mul.1, ptr %data.1, align 8 +-+ %iv.next = add nuw nsw i64 %iv, 1 +-+ %ec = icmp eq i64 %iv.next, 100 +-+ br i1 %ec, label %exit, label %loop +-+ +-+exit: +-+ ret void +-+} + diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll + --- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll + +++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +@@ -731,18 +584,3 @@ diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-nod + + store i8 %58, ptr %59, align 1 + + ret void + +} +-diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +---- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +-+++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +-@@ -151,9 +151,10 @@ +- MachineModuleInfoWrapperPass *MMIWP = +- new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); +- +-- legacy::PassManager PassMgrF; +- SmallString<1024> Buf; +- llvm::raw_svector_ostream OS(Buf); +-+ legacy::PassManager PassMgrF; +-+ +- AsmPrinter *Printer = +- addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); +- PassMgrF.run(*M); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 4a58099..c3bcd53 100644 +index c3bcd53..73450ce 100644 --- a/third_party/llvm/workspace.bzl +++ 
b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" -- LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" -+ LLVM_COMMIT = "cd54cb062bba9c90a8f3723bf66caa7effbcf259" -+ LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" +- LLVM_COMMIT = "cd54cb062bba9c90a8f3723bf66caa7effbcf259" +- LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" ++ LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" ++ LLVM_SHA256 = "4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 3789787f919440..104a8de57175ad 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "cf9436497603650441904a21316fbd058551f663" - SHARDY_SHA256 = "133dcda8bf84d516f67b3bcd0e1c5a564b266a67cb96f8901ccf8be30e830d3e" + SHARDY_COMMIT = "1ba08b6822b3bce9ce4acb3b839b05b3266ca0bc" + SHARDY_SHA256 = "6930a383ed9b516041f08ae948e86a3926a2cf11c7457fa50950f298275c6a84" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index e0644f1eee41da..bd3beff6435ac1 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,949 +1,242 @@ -diff --git a/shardy/dialect/sdy/transforms/export/passes.td b/shardy/dialect/sdy/transforms/export/passes.td -index 58c9f74..64cfe7f 100644 ---- a/shardy/dialect/sdy/transforms/export/passes.td -+++ b/shardy/dialect/sdy/transforms/export/passes.td -@@ -114,8 +114,8 @@ def TempExplicitReshardsForOptimizationsPass : Pass<"sdy-temp-explicit-reshards- - This pass is a temporary solution until we can enable the - `sdy-insert-explicit-reshards` pass by default. - -- It allows us to insert explicit reshards on specific operations for -- optimizations. -+ It allows us to improve specific use cases where the partitioner does the -+ sub-optimal thing. - }]; - } - -diff --git a/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc b/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc -index b20b794..0642e3c 100644 ---- a/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc -+++ b/shardy/dialect/sdy/transforms/export/temp_explicit_reshards_for_optimizations.cc -@@ -29,7 +29,6 @@ limitations under the License. +diff --git a/shardy/dialect/sdy/ir/BUILD b/shardy/dialect/sdy/ir/BUILD +index 780cd17..fe8986b 100644 +--- a/shardy/dialect/sdy/ir/BUILD ++++ b/shardy/dialect/sdy/ir/BUILD +@@ -164,6 +164,7 @@ cc_library( + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_assembly_format", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_type_inference", +diff --git a/shardy/dialect/sdy/ir/canonicalization.cc b/shardy/dialect/sdy/ir/canonicalization.cc +index e1b391f..7ab3e28 100644 +--- a/shardy/dialect/sdy/ir/canonicalization.cc ++++ b/shardy/dialect/sdy/ir/canonicalization.cc +@@ -25,6 +25,7 @@ limitations under the License. 
+ #include "mlir/IR/Region.h" + #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" ++#include "mlir/Transforms/Inliner.h" + #include "mlir/Transforms/InliningUtils.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" --#include "shardy/dialect/sdy/transforms/export/explicit_reshards_util.h" - #include "shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep - #include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h" - #include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h" -@@ -236,9 +235,6 @@ struct TempExplicitReshardsForOptimizationsPass - [&](stablehlo::DotGeneralOp dotGeneralOp) { - processDot(dotGeneralOp, rewriter, symbolTable); - }); -- if (op->getName().getStringRef().str() == "mhlo.ragged_dot") { -- insertExplicitReshardsOnOp(op, rewriter, symbolTable); -- } - }); - } - }; -diff --git a/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir b/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir -index 48bcbcb..117954c 100644 ---- a/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir -+++ b/shardy/dialect/sdy/transforms/export/test/temp_explicit_reshards_for_optimizations.mlir -@@ -1,8 +1,7 @@ --// RUN: sdy_opt %s -allow-unregistered-dialect -sdy-temp-explicit-reshards-for-optimizations | FileCheck %s -+// RUN: sdy_opt %s -sdy-temp-explicit-reshards-for-optimizations | FileCheck %s - - sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]> - sdy.mesh @other_mesh = <["x"=2, "y"=2]> --sdy.mesh @mesh_abcd = <["a"=2, "b"=2, "c"=2, "d"=2]> - - // CHECK-LABEL: func @reshard_dot_result_to_match_lhs - func.func @reshard_dot_result_to_match_lhs( -@@ -317,77 +316,3 @@ func.func @dot_result_conflicting_sharding_mismatch_with_reduction_axes_3( - (tensor<4x2x32xf32>, tensor<2x32x8xf32>) -> tensor<4x8xf32> - return %0 : tensor<4x8xf32> - } -- --// CHECK-LABEL: func @ragged_dot_mode_non_contracting --func.func @ragged_dot_mode_non_contracting( -- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, -- %arg1: tensor<4x16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]>}, -- %arg2: tensor<16x4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}]>}) -> tensor<16x32x8xf32> { -- // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh_abcd, [{"a"}, {}, {"c"}]> : tensor<16x32x64xf32> -- // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_abcd, [{}, {"a"}, {"c"}, {"d"}]> : tensor<4x16x64x8xf32> -- // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{"a"}, {}]> : tensor<16x4xi32> -- -- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%[[RESHARD0]], %[[RESHARD1]], %[[RESHARD2]]) <{ -- // CHECK: }> -- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {}, {"d"}]>]> -- -- // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {}, {"d"}]> : tensor<16x32x8xf32> -- // CHECK: %[[RESHARD3:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh_abcd, [{"a"}, {"b"}, {"c"}]> : tensor<16x32x8xf32> -- // CHECK: return %[[RESHARD3]] : tensor<16x32x8xf32> -- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = -- #mhlo.ragged_dot, -- lhs_ragged_dimensions = [1], rhs_group_dimensions = [0]>}> -- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [m, i, l, k], [i, m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=4} reduction={l} need_replication={j, m}>, -- sdy.sharding = 
#sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>]>} -- : (tensor<16x32x64xf32>, tensor<4x16x64x8xf32>, tensor<16x4xi32>) -> tensor<16x32x8xf32> -- return %0 : tensor<16x32x8xf32> --} -- --// CHECK-LABEL: func @ragged_dot_mode_contracting --func.func @ragged_dot_mode_contracting( -- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, -- %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, -- %arg2: tensor<16x4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}]>}) -> tensor<4x16x32x8xf32> { -- // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh_abcd, [{"a"}, {"b"}, {}]> : tensor<16x32x64xf32> -- // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_abcd, [{"a"}, {}, {"d"}]> : tensor<16x64x8xf32> -- // CHECK: %[[RESHARD2:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{"a"}, {}]> : tensor<16x4xi32> -- -- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%[[RESHARD0]], %[[RESHARD1]], %[[RESHARD2]]) <{ -- // CHECK: }> -- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{}, {"a"}, {"b"}, {"d"}]>]> -- -- // CHECK: %[[RESHARD3:.*]] = sdy.reshard %[[RAGGED_DOT]] <@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]> : tensor<4x16x32x8xf32> -- // CHECK: return %[[RESHARD3]] : tensor<4x16x32x8xf32> -- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = -- #mhlo.ragged_dot, -- lhs_ragged_dimensions = [2]>}> -- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [i, m])->([m, i, j, k]) {i=16, j=32, k=8, l=64, m=4} need_replication={l, m}>, -- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"c"}, {"d"}]>]>} -- : (tensor<16x32x64xf32>, tensor<16x64x8xf32>, tensor<16x4xi32>) -> tensor<4x16x32x8xf32> -- return %0 : tensor<4x16x32x8xf32> --} -- --// CHECK-LABEL: func @ragged_dot_mode_batch --func.func @ragged_dot_mode_batch( -- %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>}, -- %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>}, -- %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> tensor<16x32x8xf32> { -- // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ -- // CHECK: }> -- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]> -- // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {"b"}, {"d"}]> : tensor<16x32x8xf32> -- // CHECK: return %[[ALL_REDUCE]] : tensor<16x32x8xf32> -- %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = -- #mhlo.ragged_dot, -- lhs_ragged_dimensions = [0]>}> -- {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=1} reduction={l}>, -- sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>} -- : (tensor<16x32x64xf32>, tensor<16x64x8xf32>, tensor<4xi32>) -> tensor<16x32x8xf32> -- return %0 : tensor<16x32x8xf32> --} -diff --git a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc -index 6cfed8f..4061903 100644 ---- a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc -+++ b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc -@@ -117,8 +117,8 @@ GroupIdToShardingGroups unifyShardingGroups( - int64_t reindexId = 0; - SmallDenseMap reindexMap; - for (const auto& group : shardingGroupEquivalences) { -- if 
(group.isLeader()) { -- reindexMap[group.getData()] = reindexId++; -+ if (group->isLeader()) { -+ reindexMap[group->getData()] = reindexId++; +@@ -103,9 +104,11 @@ class RedundantManualComputationPattern } - } - -diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir -index 97099a1..6a711ae 100644 ---- a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir -+++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir -@@ -64,8 +64,8 @@ func.func @sharding_groups_reindexes_ids(%arg0: tensor<4xf32>, %arg1: tensor<4xf - // CHECK-LABEL: sharding_groups_reindex_ordering_matches_min_element_ordering - func.func @sharding_groups_reindex_ordering_matches_min_element_ordering(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) { -- // CHECK: sdy.sharding_group %arg0 group_id=1 : tensor<4xf32> -- // CHECK: sdy.sharding_group %arg1 group_id=0 : tensor<4xf32> -+ // CHECK: sdy.sharding_group %arg0 group_id=0 : tensor<4xf32> -+ // CHECK: sdy.sharding_group %arg1 group_id=1 : tensor<4xf32> - // CHECK: sdy.sharding_group %arg2 group_id=2 : tensor<4xf32> - sdy.sharding_group %arg0 group_id = 567 : tensor<4xf32> - sdy.sharding_group %arg0 group_id = 23 : tensor<4xf32> + mlir::InlinerInterface inliner(manualComputationOp.getContext()); ++ mlir::InlinerConfig config; + if (inlineRegion( +- inliner, &manualComputationOp.getRegion(), +- manualComputationOp->getBlock(), manualComputationOp->getIterator(), ++ inliner, config.getCloneCallback(), ++ &manualComputationOp.getRegion(), manualComputationOp->getBlock(), ++ manualComputationOp->getIterator(), + manualComputationOp.getOperands(), manualComputationOp.getResults()) + .failed()) { + manualComputationOp.emitOpError( diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 2e6ff58..97282ec 100644 +index 97282ec..a3ecef4 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,23 +1,748 @@ +@@ -1,16 +1,4 @@ Auto generated patch. Do not edit or delete it, even if empty. 
--diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp ----- a/clang/lib/Sema/SemaCXXScopeSpec.cpp --+++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp --@@ -873,6 +873,7 @@ -- DependentTemplateSpecializationTypeLoc SpecTL -- = Builder.push(T); -- SpecTL.setElaboratedKeywordLoc(SourceLocation()); --+ SpecTL.setQualifierLoc(NestedNameSpecifierLoc()); -- SpecTL.setTemplateKeywordLoc(TemplateKWLoc); -- SpecTL.setTemplateNameLoc(TemplateNameLoc); -- SpecTL.setLAngleLoc(LAngleLoc); --diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ----- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --@@ -1902,7 +1902,6 @@ -- name = "inv_trigf_utils", -- srcs = ["src/math/generic/inv_trigf_utils.cpp"], -- hdrs = [ --- "src/math/generic/atan_utils.h", -- "src/math/generic/inv_trigf_utils.h", -- ], -- deps = [ -+diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp -+--- a/clang/lib/AST/ASTContext.cpp -++++ b/clang/lib/AST/ASTContext.cpp -+@@ -7011,7 +7011,7 @@ -+ getCanonicalTemplateArgument(subst->getArgumentPack()); -+ return getSubstTemplateTemplateParmPack( -+ canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), -+- subst->getFinal(), subst->getIndex()); -++ subst->getIndex(), subst->getFinal()); -+ } -+ case TemplateName::DeducedTemplate: { -+ assert(IgnoreDeduced == false); -+diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h -+--- a/clang/lib/Sema/TreeTransform.h -++++ b/clang/lib/Sema/TreeTransform.h -+@@ -7765,17 +7765,23 @@ -+ NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); -+ NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); -+ -+- typedef TemplateArgumentLocContainerIterator< -+- DependentTemplateSpecializationTypeLoc> ArgIterator; -+- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), -+- ArgIterator(TL, TL.getNumArgs()), -+- NewTemplateArgs)) -++ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); -++ -++ if (getDerived().TransformTemplateArguments(ArgsRange.begin(), -++ ArgsRange.end(), NewTemplateArgs)) -+ return QualType(); -++ bool TemplateArgumentsChanged = !llvm::equal( -++ ArgsRange, NewTemplateArgs.arguments(), -++ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { -++ return A.getArgument().structurallyEquals(B.getArgument()); -++ }); -+ -+ const DependentTemplateStorage &DTN = T->getDependentTemplateName(); -+ -+ QualType Result = TL.getType(); -+- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { -++ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || -++ TemplateArgumentsChanged) { -+ TemplateName Name = getDerived().RebuildTemplateName( -+ SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), -+ /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, -+diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp -+--- a/clang/lib/Serialization/ASTReaderStmt.cpp -++++ b/clang/lib/Serialization/ASTReaderStmt.cpp -+@@ -2229,6 +2229,7 @@ -+ E->PackIndex = Record.readInt(); -+ else -+ E->PackIndex = 0; -++ E->Final = CurrentUnpackingBits->getNextBit(); -+ E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); -+ E->Replacement = Record.readSubExpr(); -+ } -+diff -ruN --strip-trailing-cr 
a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp -+--- a/clang/lib/Serialization/ASTWriterStmt.cpp -++++ b/clang/lib/Serialization/ASTWriterStmt.cpp -+@@ -2229,6 +2229,7 @@ -+ CurrentPackingBits.addBit((bool)E->getPackIndex()); -+ if (auto PackIndex = E->getPackIndex()) -+ Record.push_back(*PackIndex + 1); -++ CurrentPackingBits.addBit(E->getFinal()); -+ -+ Record.AddSourceLocation(E->getNameLoc()); -+ Record.AddStmt(E->getReplacement()); -+diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h -+--- a/clang/test/CodeGen/include/cuda.h -++++ b/clang/test/CodeGen/include/cuda.h -+@@ -1,194 +0,0 @@ -+-/* Minimal declarations for CUDA support. Testing purposes only. -+- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h -+- */ -+-#pragma once -+- -+-// Make this file work with nvcc, for testing compatibility. -+- -+-#ifndef __NVCC__ -+-#define __constant__ __attribute__((constant)) -+-#define __device__ __attribute__((device)) -+-#define __global__ __attribute__((global)) -+-#define __host__ __attribute__((host)) -+-#define __shared__ __attribute__((shared)) -+-#define __managed__ __attribute__((managed)) -+-#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) -+- -+-struct dim3 { -+- unsigned x, y, z; -+- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} -+-}; -+- -+-// Host- and device-side placement new overloads. -+-void *operator new(__SIZE_TYPE__, void *p) { return p; } -+-void *operator new[](__SIZE_TYPE__, void *p) { return p; } -+-__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } -+-__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } -+- -+-#define CUDA_VERSION 10100 -+- -+-struct char1 { -+- char x; -+- __host__ __device__ char1(char x = 0) : x(x) {} -+-}; -+-struct char2 { -+- char x, y; -+- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} -+-}; -+-struct char4 { -+- char x, y, z, w; -+- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct uchar1 { -+- unsigned char x; -+- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} -+-}; -+-struct uchar2 { -+- unsigned char x, y; -+- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} -+-}; -+-struct uchar4 { -+- unsigned char x, y, z, w; -+- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct short1 { -+- short x; -+- __host__ __device__ short1(short x = 0) : x(x) {} -+-}; -+-struct short2 { -+- short x, y; -+- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} -+-}; -+-struct short4 { -+- short x, y, z, w; -+- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct ushort1 { -+- unsigned short x; -+- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} -+-}; -+-struct ushort2 { -+- unsigned short x, y; -+- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} -+-}; -+-struct ushort4 { -+- unsigned short x, y, z, w; -+- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct int1 { -+- int x; -+- __host__ __device__ int1(int x = 0) : x(x) {} -+-}; -+-struct int2 { -+- int x, 
y; -+- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} -+-}; -+-struct int4 { -+- int x, y, z, w; -+- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct uint1 { -+- unsigned x; -+- __host__ __device__ uint1(unsigned x = 0) : x(x) {} -+-}; -+-struct uint2 { -+- unsigned x, y; -+- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} -+-}; -+-struct uint3 { -+- unsigned x, y, z; -+- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} -+-}; -+-struct uint4 { -+- unsigned x, y, z, w; -+- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct longlong1 { -+- long long x; -+- __host__ __device__ longlong1(long long x = 0) : x(x) {} -+-}; -+-struct longlong2 { -+- long long x, y; -+- __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} -+-}; -+-struct longlong4 { -+- long long x, y, z, w; -+- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct ulonglong1 { -+- unsigned long long x; -+- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} -+-}; -+-struct ulonglong2 { -+- unsigned long long x, y; -+- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} -+-}; -+-struct ulonglong4 { -+- unsigned long long x, y, z, w; -+- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct float1 { -+- float x; -+- __host__ __device__ float1(float x = 0) : x(x) {} -+-}; -+-struct float2 { -+- float x, y; -+- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} -+-}; -+-struct float4 { -+- float x, y, z, w; -+- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-struct double1 { -+- double x; -+- __host__ __device__ double1(double x = 0) : x(x) {} -+-}; -+-struct double2 { -+- double x, y; -+- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} -+-}; -+-struct double4 { -+- double x, y, z, w; -+- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} -+-}; -+- -+-typedef unsigned long long cudaTextureObject_t; -+-typedef unsigned long long cudaSurfaceObject_t; -+- -+-enum cudaTextureReadMode { -+- cudaReadModeNormalizedFloat, -+- cudaReadModeElementType -+-}; -+- -+-enum cudaSurfaceBoundaryMode { -+- cudaBoundaryModeZero, -+- cudaBoundaryModeClamp, -+- cudaBoundaryModeTrap -+-}; -+- -+-enum { -+- cudaTextureType1D, -+- cudaTextureType2D, -+- cudaTextureType3D, -+- cudaTextureTypeCubemap, -+- cudaTextureType1DLayered, -+- cudaTextureType2DLayered, -+- cudaTextureTypeCubemapLayered -+-}; -+- -+-struct textureReference {}; -+-template -+-struct __attribute__((device_builtin_texture_type)) texture -+- : public textureReference {}; -+- -+-#endif // !__NVCC__ -+diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h -+--- a/clang/test/CodeGen/Inputs/cuda.h -++++ b/clang/test/CodeGen/Inputs/cuda.h -+@@ -0,0 +1,194 @@ -++/* Minimal declarations for CUDA support. Testing purposes only. 
-++ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h -++ */ -++#pragma once -++ -++// Make this file work with nvcc, for testing compatibility. -++ -++#ifndef __NVCC__ -++#define __constant__ __attribute__((constant)) -++#define __device__ __attribute__((device)) -++#define __global__ __attribute__((global)) -++#define __host__ __attribute__((host)) -++#define __shared__ __attribute__((shared)) -++#define __managed__ __attribute__((managed)) -++#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) -++ -++struct dim3 { -++ unsigned x, y, z; -++ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} -++}; -++ -++// Host- and device-side placement new overloads. -++void *operator new(__SIZE_TYPE__, void *p) { return p; } -++void *operator new[](__SIZE_TYPE__, void *p) { return p; } -++__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } -++__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } -++ -++#define CUDA_VERSION 10100 -++ -++struct char1 { -++ char x; -++ __host__ __device__ char1(char x = 0) : x(x) {} -++}; -++struct char2 { -++ char x, y; -++ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} -++}; -++struct char4 { -++ char x, y, z, w; -++ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct uchar1 { -++ unsigned char x; -++ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} -++}; -++struct uchar2 { -++ unsigned char x, y; -++ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} -++}; -++struct uchar4 { -++ unsigned char x, y, z, w; -++ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct short1 { -++ short x; -++ __host__ __device__ short1(short x = 0) : x(x) {} -++}; -++struct short2 { -++ short x, y; -++ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} -++}; -++struct short4 { -++ short x, y, z, w; -++ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct ushort1 { -++ unsigned short x; -++ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} -++}; -++struct ushort2 { -++ unsigned short x, y; -++ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} -++}; -++struct ushort4 { -++ unsigned short x, y, z, w; -++ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct int1 { -++ int x; -++ __host__ __device__ int1(int x = 0) : x(x) {} -++}; -++struct int2 { -++ int x, y; -++ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} -++}; -++struct int4 { -++ int x, y, z, w; -++ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct uint1 { -++ unsigned x; -++ __host__ __device__ uint1(unsigned x = 0) : x(x) {} -++}; -++struct uint2 { -++ unsigned x, y; -++ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} -++}; -++struct uint3 { -++ unsigned x, y, z; -++ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} -++}; -++struct uint4 { -++ unsigned x, y, z, w; -++ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} -++}; 
-++ -++struct longlong1 { -++ long long x; -++ __host__ __device__ longlong1(long long x = 0) : x(x) {} -++}; -++struct longlong2 { -++ long long x, y; -++ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} -++}; -++struct longlong4 { -++ long long x, y, z, w; -++ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct ulonglong1 { -++ unsigned long long x; -++ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} -++}; -++struct ulonglong2 { -++ unsigned long long x, y; -++ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} -++}; -++struct ulonglong4 { -++ unsigned long long x, y, z, w; -++ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct float1 { -++ float x; -++ __host__ __device__ float1(float x = 0) : x(x) {} -++}; -++struct float2 { -++ float x, y; -++ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} -++}; -++struct float4 { -++ float x, y, z, w; -++ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++struct double1 { -++ double x; -++ __host__ __device__ double1(double x = 0) : x(x) {} -++}; -++struct double2 { -++ double x, y; -++ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} -++}; -++struct double4 { -++ double x, y, z, w; -++ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} -++}; -++ -++typedef unsigned long long cudaTextureObject_t; -++typedef unsigned long long cudaSurfaceObject_t; -++ -++enum cudaTextureReadMode { -++ cudaReadModeNormalizedFloat, -++ cudaReadModeElementType -++}; -++ -++enum cudaSurfaceBoundaryMode { -++ cudaBoundaryModeZero, -++ cudaBoundaryModeClamp, -++ cudaBoundaryModeTrap -++}; -++ -++enum { -++ cudaTextureType1D, -++ cudaTextureType2D, -++ cudaTextureType3D, -++ cudaTextureTypeCubemap, -++ cudaTextureType1DLayered, -++ cudaTextureType2DLayered, -++ cudaTextureTypeCubemapLayered -++}; -++ -++struct textureReference {}; -++template -++struct __attribute__((device_builtin_texture_type)) texture -++ : public textureReference {}; -++ -++#endif // !__NVCC__ -+diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu -+--- a/clang/test/CodeGen/nvptx-surface.cu -++++ b/clang/test/CodeGen/nvptx-surface.cu -+@@ -1,6 +1,6 @@ -+ // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s -+ // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s -+-#include "include/cuda.h" -++#include "Inputs/cuda.h" -+ -+ #include "__clang_cuda_texture_intrinsics.h" -+ -+diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp -+--- a/clang/test/SemaTemplate/dependent-names.cpp -++++ b/clang/test/SemaTemplate/dependent-names.cpp -+@@ -458,3 +458,12 @@ -+ }; -+ int f(b ba) { return ba.add<0>(); } -+ } -++ -++namespace TransformDependentTemplates { -++ template struct Test1 { -++ template -++ using Arg = typename T::template Arg; -++ void f(Arg); -++ void f(Arg); -++ }; -++} // namespace TransformDependentTemplates -+diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp 
b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp -+--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp -++++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp -+@@ -15391,12 +15391,20 @@ -+ -+ if (E->State == TreeEntry::SplitVectorize) { -+ Res = FindLastInst(); -++ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { -++ for (auto *E : Entries) { -++ auto *I = dyn_cast_or_null(E->VectorizedValue); -++ if (!I) -++ I = &getLastInstructionInBundle(E); -++ if (Res->comesBefore(I)) -++ Res = I; -++ } -++ } -+ return *Res; -+ } -+ -+ // Set insertpoint for gathered loads to the very first load. -+- if (E->State != TreeEntry::SplitVectorize && -+- GatheredLoadsEntriesFirst.has_value() && -++ if (GatheredLoadsEntriesFirst.has_value() && -+ E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && -+ E->getOpcode() == Instruction::Load) { -+ Res = FindFirstInst(); -+diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp -+--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp -++++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp -+@@ -2590,6 +2590,14 @@ -+ if (R.mayWriteToMemory() && !InterleaveR) -+ return; -+ -++ // Do not narrow interleave groups if there are VectorPointer recipes and -++ // the plan was unrolled. The recipe implicitly uses VF from -++ // VPTransformState. -++ // TODO: Remove restriction once the VF for the VectorPointer offset is -++ // modeled explicitly as operand. -++ if (isa(&R) && Plan.getUF() > 1) -++ return; -++ -+ // All other ops are allowed, but we reject uses that cannot be converted -+ // when checking all allowed consumers (store interleave groups) below. -+ if (!InterleaveR) -+diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll -+--- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll -++++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll -+@@ -66,3 +66,91 @@ -+ exit: -+ ret void -+ } -++ -++define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { -++; CHECK-LABEL: define void @test_2xi64_with_wide_load( -++; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { -++; CHECK-NEXT: [[ENTRY:.*]]: -++; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] -++; CHECK: [[VECTOR_PH]]: -++; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] -++; CHECK: [[VECTOR_BODY]]: -++; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -++; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 -++; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] -++; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 -++; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 -++; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 -++; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 -++; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 [[INDEX]], 1 -++; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 -++; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] -++; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] -++; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, 
ptr [[TMP8]], align 8 -++; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> -++; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> -++; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 -++; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> -++; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> -++; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] -++; CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] -++; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] -++; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] -++; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> -++; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> -++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 -++; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> -++; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> -++; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 -++; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 -++; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 -++; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] -++; CHECK: [[MIDDLE_BLOCK]]: -++; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]] -++; CHECK: [[SCALAR_PH]]: -++; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] -++; CHECK-NEXT: br label %[[LOOP:.*]] -++; CHECK: [[LOOP]]: -++; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] -++; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] -++; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 -++; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 -++; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] -++; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 -++; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] -++; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 -++; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 -++; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] -++; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 -++; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] -++; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 -++; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 -++; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 -++; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] -++; CHECK: [[EXIT]]: -++; CHECK-NEXT: ret void -++; -++entry: -++ br label %loop -++ -++loop: -++ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] -++ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv -++ %l.factor = load i64, ptr %arrayidx, align 8 -++ %1 = shl nsw i64 %iv, 1 -++ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 -++ %l.0 = load i64, ptr %data.0, align 8 -++ %mul.0 = mul i64 %l.factor, %l.0 -++ store i64 %mul.0, ptr 
%data.0, align 8 -++ %3 = or disjoint i64 %1, 1 -++ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 -++ %l.1 = load i64, ptr %data.1, align 8 -++ %mul.1 = mul i64 %l.factor, %l.1 -++ store i64 %mul.1, ptr %data.1, align 8 -++ %iv.next = add nuw nsw i64 %iv, 1 -++ %ec = icmp eq i64 %iv.next, 100 -++ br i1 %ec, label %exit, label %loop -++ -++exit: -++ ret void -++} -+diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -+--- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -++++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -+@@ -0,0 +1,99 @@ -++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 -++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s -++ -++define void @test(ptr %0, <8 x i8> %1) { -++; CHECK-LABEL: define void @test( -++; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { -++; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 -++; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 -++; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 -++; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 -++; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 -++; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> -++; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 -++; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> -++; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> -++; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) -++; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 -++; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> -++; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) -++; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) -++; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] -++; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 -++; CHECK-NEXT: ret void -++; -++ %3 = load i8, ptr %0, align 2 -++ %4 = getelementptr i8, ptr %0, i64 13442 -++ %5 = load i8, ptr %4, align 2 -++ %6 = or i8 %5, %3 -++ %7 = getelementptr i8, ptr %0, i64 13550 -++ store i8 %6, ptr %7, align 2 -++ %8 = extractelement <8 x i8> %1, i64 0 -++ %9 = or i8 %5, %8 -++ %10 = getelementptr i8, ptr %0, i64 13542 -++ store i8 %9, ptr %10, align 2 -++ %11 = getelementptr i8, ptr %0, i64 13438 -++ %12 = load i8, ptr %11, align 2 -++ %13 = or i8 %12, %3 -++ %14 = getelementptr i8, ptr %0, i64 13546 -++ store i8 %13, ptr %14, align 2 -++ %15 = extractelement <8 x i8> %1, i64 2 -++ %16 = or i8 %12, %15 -++ %17 = getelementptr i8, ptr %0, i64 13538 -++ store i8 %16, ptr %17, align 2 -++ %18 = getelementptr i8, ptr %0, i64 13440 -++ %19 = load i8, ptr %18, align 4 -++ %20 = or i8 %19, %3 -++ %21 = getelementptr i8, ptr %0, i64 13548 -++ store i8 %20, ptr %21, align 4 -++ %22 = extractelement <8 x i8> %1, i64 4 -++ %23 = or i8 %19, %22 -++ %24 = getelementptr i8, ptr %0, i64 13540 -++ store i8 %23, ptr %24, align 4 -++ %25 = 
getelementptr i8, ptr %0, i64 13436 -++ %26 = load i8, ptr %25, align 4 -++ %27 = getelementptr i8, ptr %0, i64 13444 -++ %28 = load i8, ptr %27, align 4 -++ %29 = or i8 %28, %26 -++ %30 = getelementptr i8, ptr %0, i64 13544 -++ store i8 %29, ptr %30, align 4 -++ %31 = or i8 %26, %8 -++ %32 = getelementptr i8, ptr %0, i64 13536 -++ store i8 %31, ptr %32, align 4 -++ %33 = getelementptr i8, ptr %0, i64 13443 -++ %34 = load i8, ptr %33, align 1 -++ %35 = or i8 %34, %3 -++ %36 = getelementptr i8, ptr %0, i64 13551 -++ store i8 %35, ptr %36, align 1 -++ %37 = extractelement <8 x i8> %1, i64 7 -++ %38 = or i8 %34, %37 -++ %39 = getelementptr i8, ptr %0, i64 13543 -++ store i8 %38, ptr %39, align 1 -++ %40 = getelementptr i8, ptr %0, i64 13439 -++ %41 = load i8, ptr %40, align 1 -++ %42 = or i8 %41, %3 -++ %43 = getelementptr i8, ptr %0, i64 13547 -++ store i8 %42, ptr %43, align 1 -++ %44 = extractelement <8 x i8> %1, i64 3 -++ %45 = or i8 %41, %44 -++ %46 = getelementptr i8, ptr %0, i64 13539 -++ store i8 %45, ptr %46, align 1 -++ %47 = getelementptr i8, ptr %0, i64 13441 -++ %48 = load i8, ptr %47, align 1 -++ %49 = or i8 %48, %3 -++ %50 = getelementptr i8, ptr %0, i64 13549 -++ store i8 %49, ptr %50, align 1 -++ %51 = extractelement <8 x i8> %1, i64 5 -++ %52 = or i8 %48, %51 -++ %53 = getelementptr i8, ptr %0, i64 13541 -++ store i8 %52, ptr %53, align 1 -++ %54 = getelementptr i8, ptr %0, i64 13437 -++ %55 = load i8, ptr %54, align 1 -++ %56 = or i8 %55, %3 -++ %57 = getelementptr i8, ptr %0, i64 13545 -++ store i8 %56, ptr %57, align 1 -++ %58 = or i8 %55, %8 -++ %59 = getelementptr i8, ptr %0, i64 13537 -++ store i8 %58, ptr %59, align 1 -++ ret void -++} -+diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp -+--- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp -++++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp -+@@ -151,9 +151,10 @@ -+ MachineModuleInfoWrapperPass *MMIWP = -+ new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); -+ -+- legacy::PassManager PassMgrF; -+ SmallString<1024> Buf; -+ llvm::raw_svector_ostream OS(Buf); -++ legacy::PassManager PassMgrF; -++ -+ AsmPrinter *Printer = -+ addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); -+ PassMgrF.run(*M); +-diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp +---- a/clang/lib/AST/ASTContext.cpp +-+++ b/clang/lib/AST/ASTContext.cpp +-@@ -7011,7 +7011,7 @@ +- getCanonicalTemplateArgument(subst->getArgumentPack()); +- return getSubstTemplateTemplateParmPack( +- canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), +-- subst->getFinal(), subst->getIndex()); +-+ subst->getIndex(), subst->getFinal()); +- } +- case TemplateName::DeducedTemplate: { +- assert(IgnoreDeduced == false); + diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h + --- a/clang/lib/Sema/TreeTransform.h + +++ b/clang/lib/Sema/TreeTransform.h +@@ -44,28 +32,6 @@ diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/ + TemplateName Name = getDerived().RebuildTemplateName( + SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), + /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, +-diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp +---- a/clang/lib/Serialization/ASTReaderStmt.cpp +-+++ b/clang/lib/Serialization/ASTReaderStmt.cpp +-@@ -2229,6 +2229,7 @@ +- 
E->PackIndex = Record.readInt(); +- else +- E->PackIndex = 0; +-+ E->Final = CurrentUnpackingBits->getNextBit(); +- E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); +- E->Replacement = Record.readSubExpr(); +- } +-diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp +---- a/clang/lib/Serialization/ASTWriterStmt.cpp +-+++ b/clang/lib/Serialization/ASTWriterStmt.cpp +-@@ -2229,6 +2229,7 @@ +- CurrentPackingBits.addBit((bool)E->getPackIndex()); +- if (auto PackIndex = E->getPackIndex()) +- Record.push_back(*PackIndex + 1); +-+ CurrentPackingBits.addBit(E->getFinal()); +- +- Record.AddSourceLocation(E->getNameLoc()); +- Record.AddStmt(E->getReplacement()); + diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h + --- a/clang/test/CodeGen/include/cuda.h + +++ b/clang/test/CodeGen/include/cuda.h +@@ -515,119 +481,6 @@ diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp + E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && + E->getOpcode() == Instruction::Load) { + Res = FindFirstInst(); +-diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +---- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +-+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +-@@ -2590,6 +2590,14 @@ +- if (R.mayWriteToMemory() && !InterleaveR) +- return; +- +-+ // Do not narrow interleave groups if there are VectorPointer recipes and +-+ // the plan was unrolled. The recipe implicitly uses VF from +-+ // VPTransformState. +-+ // TODO: Remove restriction once the VF for the VectorPointer offset is +-+ // modeled explicitly as operand. +-+ if (isa(&R) && Plan.getUF() > 1) +-+ return; +-+ +- // All other ops are allowed, but we reject uses that cannot be converted +- // when checking all allowed consumers (store interleave groups) below. 
+- if (!InterleaveR) +-diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +---- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +-+++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll +-@@ -66,3 +66,91 @@ +- exit: +- ret void +- } +-+ +-+define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { +-+; CHECK-LABEL: define void @test_2xi64_with_wide_load( +-+; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { +-+; CHECK-NEXT: [[ENTRY:.*]]: +-+; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] +-+; CHECK: [[VECTOR_PH]]: +-+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] +-+; CHECK: [[VECTOR_BODY]]: +-+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] +-+; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 +-+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] +-+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 +-+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 +-+; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 +-+; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 +-+; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 [[INDEX]], 1 +-+; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 +-+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] +-+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] +-+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP8]], align 8 +-+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> +-+; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> +-+; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 +-+; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> +-+; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> +-+; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] +-+; CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] +-+; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] +-+; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] +-+; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> +-+; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> +-+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 +-+; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> +-+; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> +-+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 +-+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 +-+; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 +-+; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] +-+; CHECK: [[MIDDLE_BLOCK]]: +-+; CHECK-NEXT: 
br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]] +-+; CHECK: [[SCALAR_PH]]: +-+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] +-+; CHECK-NEXT: br label %[[LOOP:.*]] +-+; CHECK: [[LOOP]]: +-+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] +-+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] +-+; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 +-+; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 +-+; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] +-+; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 +-+; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] +-+; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 +-+; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 +-+; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] +-+; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 +-+; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] +-+; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 +-+; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 +-+; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 +-+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] +-+; CHECK: [[EXIT]]: +-+; CHECK-NEXT: ret void +-+; +-+entry: +-+ br label %loop +-+ +-+loop: +-+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] +-+ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv +-+ %l.factor = load i64, ptr %arrayidx, align 8 +-+ %1 = shl nsw i64 %iv, 1 +-+ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 +-+ %l.0 = load i64, ptr %data.0, align 8 +-+ %mul.0 = mul i64 %l.factor, %l.0 +-+ store i64 %mul.0, ptr %data.0, align 8 +-+ %3 = or disjoint i64 %1, 1 +-+ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 +-+ %l.1 = load i64, ptr %data.1, align 8 +-+ %mul.1 = mul i64 %l.factor, %l.1 +-+ store i64 %mul.1, ptr %data.1, align 8 +-+ %iv.next = add nuw nsw i64 %iv, 1 +-+ %ec = icmp eq i64 %iv.next, 100 +-+ br i1 %ec, label %exit, label %loop +-+ +-+exit: +-+ ret void +-+} + diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll + --- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll + +++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +@@ -731,18 +584,3 @@ diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-nod + + store i8 %58, ptr %59, align 1 + + ret void + +} +-diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +---- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +-+++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp +-@@ -151,9 +151,10 @@ +- MachineModuleInfoWrapperPass *MMIWP = +- new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); +- +-- legacy::PassManager PassMgrF; +- SmallString<1024> Buf; +- llvm::raw_svector_ostream OS(Buf); +-+ legacy::PassManager PassMgrF; +-+ +- AsmPrinter *Printer = +- addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); +- PassMgrF.run(*M); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 4a58099..c3bcd53 100644 +index c3bcd53..73450ce 100644 --- a/third_party/llvm/workspace.bzl +++ 
b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109" -- LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46" -+ LLVM_COMMIT = "cd54cb062bba9c90a8f3723bf66caa7effbcf259" -+ LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" +- LLVM_COMMIT = "cd54cb062bba9c90a8f3723bf66caa7effbcf259" +- LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" ++ LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" ++ LLVM_SHA256 = "4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index 3789787f919440..104a8de57175ad 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "cf9436497603650441904a21316fbd058551f663" - SHARDY_SHA256 = "133dcda8bf84d516f67b3bcd0e1c5a564b266a67cb96f8901ccf8be30e830d3e" + SHARDY_COMMIT = "1ba08b6822b3bce9ce4acb3b839b05b3266ca0bc" + SHARDY_SHA256 = "6930a383ed9b516041f08ae948e86a3926a2cf11c7457fa50950f298275c6a84" tf_http_archive( name = "shardy", diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc index 54ca1e51bbea80..30f3cec32c4723 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc +++ b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc @@ -15,11 +15,11 @@ limitations under the License. #include #include -#include #include #include "deallocation/transforms/passes.h" #include "deallocation/utils/util.h" +#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" @@ -440,7 +440,7 @@ bool simplifyLoopDeallocs(Block& block) { } } - breaks_if_you_move_ops::ValueEquivalenceClasses eq; + llvm::EquivalenceClasses eq; auto getAliases = [&](RegionBranchPoint point) { for (const auto& edge : getSuccessorRegions(rbi, point)) { for (auto [pred, succ] : llvm::zip(edge.getPredecessorOperands(), diff --git a/third_party/xla/xla/mlir_hlo/deallocation/utils/util.h b/third_party/xla/xla/mlir_hlo/deallocation/utils/util.h index c18c8f9dcd2485..2349e243ab0376 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/utils/util.h +++ b/third_party/xla/xla/mlir_hlo/deallocation/utils/util.h @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "llvm/ADT/EquivalenceClasses.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" namespace mlir { @@ -136,8 +135,6 @@ namespace breaks_if_you_move_ops { // The comparator depends on the location of ops, so if you insert an op into // a set and then move it, it may end up in the wrong location. -using ValueEquivalenceClasses = - llvm::EquivalenceClasses; using ValueSet = std::set; template using ValueMap = std::map; From 11e1b4051fdec807d60b15704084204b58191ce9 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 9 Apr 2025 16:31:19 -0700 Subject: [PATCH 0476/1324] Nit: Remove % from a variable for consistency PiperOrigin-RevId: 745778122 --- third_party/xla/xla/hlo/transforms/tests/cholesky_expander.hlo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/transforms/tests/cholesky_expander.hlo b/third_party/xla/xla/hlo/transforms/tests/cholesky_expander.hlo index dfdd63e56e41ed..f1a177a84ddb45 100644 --- a/third_party/xla/xla/hlo/transforms/tests/cholesky_expander.hlo +++ b/third_party/xla/xla/hlo/transforms/tests/cholesky_expander.hlo @@ -337,5 +337,5 @@ HloModule CholeskyExpanderTest ENTRY test { input = f32[1,256,256] parameter(0) - ROOT decomp = f32[1,256,256] cholesky(%input) + ROOT decomp = f32[1,256,256] cholesky(input) } From 34d6e719dd1dc704a989d29dd3ae180009ca8d09 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Apr 2025 16:43:50 -0700 Subject: [PATCH 0477/1324] Extend `safe_reinterpret_cast` to work between a data pointer and `void*`. Currently `safe_reinterpret_cast` allows casting between a function pointer and `void*`. According to https://google.github.io/styleguide/cppguide.html#Casting, `reinterpret_cast` between a data pointer and a `void*` is allowed too. PiperOrigin-RevId: 745782186 --- .../xla/xla/tsl/util/safe_reinterpret_cast.h | 25 +++++++++++-------- .../tsl/util/safe_reinterpret_cast_test.cc | 7 ++++++ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h index ab5e2efd800aae..079f4c060bf490 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h @@ -56,13 +56,18 @@ struct IsByteLike : std::true_type {}; template struct IsSafeCast : std::false_type {}; -// It's safe to cast a pointer to/from a byte-like type, or to/from the same -// type. Also, while not guaranteed by the C++ standard, POSIX mandates that -// it's safe to cast a function pointer to/from a void pointer -// (https://pubs.opengroup.org/onlinepubs/9799919799/functions/dlsym.html). -// On Windows (with MSVC), casting a function pointer to/from a void pointer has -// been a widely adopted practice for decades and is considered safe in -// practice, even though it is not explicitly guaranteed by Microsoft. +// 1. The C++ standard guarantees that it's safe to cast a pointer to/from a +// pointer to a byte-like type. +// 2a. The Google C++ style guide states that it's safe to cast a data pointer +// to/from a void pointer. +// 2b. While not guaranteed by the C++ standard, POSIX mandates that it's safe +// to cast a function pointer to/from a void pointer +// (https://pubs.opengroup.org/onlinepubs/9799919799/functions/dlsym.html). +// On Windows (with MSVC), casting a function pointer to/from a void +// pointer has been a widely adopted practice for decades and is considered +// safe in practice, even though it is not explicitly guaranteed by +// Microsoft. +// 3. It's safe to cast a pointer to/from the same type. template struct IsSafeCast : std::integral_constant< @@ -70,10 +75,8 @@ struct IsSafeCast // To/from a pointer to a byte-like type. (IsByteLike::type>::value || IsByteLike::type>::value) || - // From function pointer to void pointer. - (std::is_function_v&& std::is_void_v) || - // From void pointer to function pointer. - (std::is_void_v&& std::is_function_v) || + // To/from void pointer.
+ (std::is_void_v || std::is_void_v) || // Between the same type. std::is_same_v> {}; diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc index 7f9a48f2eb2248..24c82637e42041 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc @@ -147,5 +147,12 @@ TEST(SafeReinterepretCast, CanCastFuncPointerToFromVoidPointer) { EXPECT_EQ(func_p, &Dummy); } +TEST(SafeReinterepretCast, CanCastDataPointerToFromVoidPointer) { + int x = 42; + void* const void_p = safe_reinterpret_cast<void*>(&x); + int* const int_p = safe_reinterpret_cast<int*>(void_p); + EXPECT_EQ(int_p, &x); +} + } // namespace } // namespace tsl
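As a minimal sketch of what this extension permits (assuming the `tsl::safe_reinterpret_cast<To>(from)` spelling exercised by the test above and the `xla/tsl/util/safe_reinterpret_cast.h` include path; the surrounding `main` is illustrative only, not part of the patch), a data pointer can now round-trip through `void*`:

    // Minimal sketch, not part of the patch: assumes the
    // tsl::safe_reinterpret_cast<To>(from) API shown in the test above.
    #include "xla/tsl/util/safe_reinterpret_cast.h"

    int main() {
      int x = 42;
      // Case 2a above: a data pointer may now be erased to void*...
      void* const erased = tsl::safe_reinterpret_cast<void*>(&x);
      // ...and recovered again, since IsSafeCast accepts both directions.
      int* const recovered = tsl::safe_reinterpret_cast<int*>(erased);
      return recovered == &x ? 0 : 1;
    }

Function-pointer casts keep working through the same branch: the single "to/from void pointer" condition added above subsumes the two function-pointer cases it replaces.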
From dbaf77714730454d2b7a50d077bfca246eedd4da Mon Sep 17 00:00:00 2001 From: Ezekiel Calubaquib Date: Wed, 9 Apr 2025 17:03:06 -0700 Subject: [PATCH 0478/1324] move tf_mlir_translate_registration.cc to tensorflow/compiler/mlir/tools PiperOrigin-RevId: 745787932 --- tensorflow/compiler/mlir/BUILD | 2 +- tensorflow/compiler/mlir/lite/tools/BUILD | 30 ------------------- tensorflow/compiler/mlir/tools/BUILD | 30 +++++++++++++++++++ .../tools/tf_mlir_translate_registration.cc | 0 4 files changed, 31 insertions(+), 31 deletions(-) rename tensorflow/compiler/mlir/{lite => }/tools/tf_mlir_translate_registration.cc (100%) diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index e85e3935e54b95..c11a761a089128 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -220,11 +220,11 @@ tf_cc_binary( srcs = ["tf_mlir_translate_main.cc"], deps = [ ":init_mlir", - "//tensorflow/compiler/mlir/lite/tools:translate_registration", "//tensorflow/compiler/mlir/tensorflow:tf_xla_mlir_translate", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tf2xla/tests/registration:graph_to_tf_executor_registration", "//tensorflow/compiler/mlir/tools:translate_cl_options", + "//tensorflow/compiler/mlir/tools:translate_registration", "//tensorflow/core:lib", "//tensorflow/core:tensorflow", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/lite/tools/BUILD b/tensorflow/compiler/mlir/lite/tools/BUILD index 134f5b767f863d..055877d0b32200 100644 --- a/tensorflow/compiler/mlir/lite/tools/BUILD +++ b/tensorflow/compiler/mlir/lite/tools/BUILD @@ -22,33 +22,3 @@ cc_library( ) # LINT.ThenChange(//tensorflow/lite/tools:command_line_flags) - -cc_library( - name = "translate_registration", - srcs = [ - "tf_mlir_translate_registration.cc", - ], - deps = [ - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", - "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", - "//tensorflow/compiler/mlir/tools:translate_cl_options", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/container:flat_hash_set", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:TranslateLib", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/client:client_library", - "@local_xla//xla/client:compile_only_client", - "@local_xla//xla/service/cpu:cpu_compiler", - "@local_xla//xla/service/cpu:cpu_transfer_manager", - "@local_xla//xla/stream_executor/host:host_platform", -
"@local_xla//xla/stream_executor/host:host_platform_id", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/mlir/tools/BUILD b/tensorflow/compiler/mlir/tools/BUILD index d3d0aa56ab97aa..3b29e0f5666497 100644 --- a/tensorflow/compiler/mlir/tools/BUILD +++ b/tensorflow/compiler/mlir/tools/BUILD @@ -19,3 +19,33 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "translate_registration", + srcs = [ + "tf_mlir_translate_registration.cc", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", + "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/compiler/mlir/tools:translate_cl_options", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla/client:client_library", + "@local_xla//xla/client:compile_only_client", + "@local_xla//xla/service/cpu:cpu_compiler", + "@local_xla//xla/service/cpu:cpu_transfer_manager", + "@local_xla//xla/stream_executor/host:host_platform", + "@local_xla//xla/stream_executor/host:host_platform_id", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tools/tf_mlir_translate_registration.cc similarity index 100% rename from tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc rename to tensorflow/compiler/mlir/tools/tf_mlir_translate_registration.cc From bc49118573e2baacd5576f0a0cb4ef09945113bd Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 9 Apr 2025 18:15:01 -0700 Subject: [PATCH 0479/1324] tensorflow/lite/experimental/litert has now been moved to google-ai-edge/litert PiperOrigin-RevId: 745806922 --- tensorflow/lite/experimental/litert/BUILD | 18 - .../experimental/litert/build_common/BUILD | 45 - .../export_litert_only_darwin.lds | 8 - .../build_common/export_litert_only_linux.lds | 29 - .../export_litert_runtime_only_darwin.lds | 2 - .../export_litert_runtime_only_linux.lds | 20 - .../litert/build_common/litert_build_defs.bzl | 321 ----- .../litert/build_common/special_rule.bzl | 35 - .../litert/build_common/tfl_model_gen.bzl | 49 - tensorflow/lite/experimental/litert/c/BUILD | 589 ---------- .../litert/c/litert_accelerator.cc | 108 -- .../litert/c/litert_accelerator.h | 76 -- .../litert_accelerator_compilation_options.cc | 145 --- .../litert_accelerator_compilation_options.h | 93 -- ...rt_accelerator_compilation_options_test.cc | 182 --- .../c/litert_accelerator_registration.cc | 122 -- .../c/litert_accelerator_registration.h | 98 -- .../c/litert_accelerator_registration_test.cc | 142 --- .../litert/c/litert_accelerator_test.cc | 287 ----- .../lite/experimental/litert/c/litert_any.h | 67 -- .../litert/c/litert_c_api_common_test.c | 38 - .../experimental/litert/c/litert_common.cc | 51 - .../experimental/litert/c/litert_common.h | 139 --- .../litert/c/litert_common_test.cc | 71 -- .../litert/c/litert_compilation_options.cc | 91 -- .../litert/c/litert_compilation_options.h | 67 -- .../c/litert_compilation_options_test.cc | 139 --- .../litert/c/litert_compiled_model.cc | 144 --- .../litert/c/litert_compiled_model.h | 136 --- .../litert/c/litert_compiled_model_test.cc | 177 --- .../litert/c/litert_dispatch_delegate.h | 61 - .../litert/c/litert_environment.cc | 69 -- .../litert/c/litert_environment.h | 54 - .../litert/c/litert_environment_options.cc | 37 - .../litert/c/litert_environment_options.h | 52 - .../c/litert_environment_options_test.cc | 97 -- .../experimental/litert/c/litert_event.cc | 105 -- .../lite/experimental/litert/c/litert_event.h | 59 - .../experimental/litert/c/litert_event_type.h | 32 - .../experimental/litert/c/litert_gl_types.h | 43 - .../experimental/litert/c/litert_layout.h | 45 - .../experimental/litert/c/litert_logging.cc | 115 -- .../experimental/litert/c/litert_logging.h | 95 -- .../litert/c/litert_logging_test.cc | 34 - .../experimental/litert/c/litert_model.cc | 506 -------- .../lite/experimental/litert/c/litert_model.h | 372 ------ .../litert/c/litert_model_test.cc | 390 ------ .../experimental/litert/c/litert_op_code.h | 245 ---- .../experimental/litert/c/litert_options.cc | 771 ------------ .../experimental/litert/c/litert_options.h | 351 ------ .../litert/c/litert_options_test.cc | 462 -------- .../litert/c/litert_tensor_buffer.cc | 480 -------- .../litert/c/litert_tensor_buffer.h | 244 ---- .../c/litert_tensor_buffer_requirements.cc | 92 -- .../c/litert_tensor_buffer_requirements.h | 57 - .../litert_tensor_buffer_requirements_test.cc | 92 -- .../litert/c/litert_tensor_buffer_test.cc | 439 ------- .../litert/c/litert_tensor_buffer_types.h | 30 - tensorflow/lite/experimental/litert/cc/BUILD | 708 ----------- .../litert_accelerator_compilation_options.h | 124 -- .../lite/experimental/litert/cc/litert_any.h | 221 ---- .../experimental/litert/cc/litert_any_test.cc | 110 -- .../litert/cc/litert_buffer_ref.h | 356 ------ .../litert/cc/litert_buffer_ref_test.cc | 332 ------ .../litert/cc/litert_compilation_options.h | 78 -- .../litert/cc/litert_compiled_model.cc | 214 
---- .../litert/cc/litert_compiled_model.h | 438 ------- .../cc/litert_compiled_model_gpu_test.cc | 203 ---- .../litert_compiled_model_integration_test.cc | 376 ------ .../litert/cc/litert_compiled_model_test.cc | 359 ------ .../experimental/litert/cc/litert_consts.h | 34 - .../experimental/litert/cc/litert_detail.h | 105 -- .../litert/cc/litert_dispatch_delegate.h | 41 - .../litert/cc/litert_element_type.h | 157 --- .../litert/cc/litert_element_type_test.cc | 48 - .../litert/cc/litert_environment.h | 86 -- .../litert/cc/litert_environment_test.cc | 91 -- .../experimental/litert/cc/litert_event.h | 100 -- .../litert/cc/litert_event_test.cc | 36 - .../experimental/litert/cc/litert_expected.h | 390 ------ .../litert/cc/litert_expected_test.cc | 233 ---- .../experimental/litert/cc/litert_handle.h | 74 -- .../experimental/litert/cc/litert_layout.h | 158 --- .../litert/cc/litert_layout_test.cc | 62 - .../experimental/litert/cc/litert_macros.cc | 74 -- .../experimental/litert/cc/litert_macros.h | 349 ------ .../litert/cc/litert_macros_test.cc | 207 ---- .../experimental/litert/cc/litert_model.cc | 154 --- .../experimental/litert/cc/litert_model.h | 473 -------- .../litert/cc/litert_model_predicates.cc | 120 -- .../litert/cc/litert_model_predicates.h | 78 -- .../litert/cc/litert_model_predicates_test.cc | 215 ---- .../litert/cc/litert_model_test.cc | 359 ------ .../litert/cc/litert_op_options.cc | 44 - .../litert/cc/litert_op_options.h | 63 - .../litert/cc/litert_op_options_test.cc | 60 - .../litert/cc/litert_shared_library.cc | 211 ---- .../litert/cc/litert_shared_library.h | 172 --- .../litert/cc/litert_shared_library_test.cc | 142 --- .../litert/cc/litert_tensor_buffer.h | 350 ------ .../cc/litert_tensor_buffer_requirements.h | 92 -- .../litert_tensor_buffer_requirements_test.cc | 104 -- .../litert/cc/litert_tensor_buffer_test.cc | 613 ---------- .../litert/cc/litert_tensor_buffer_utils.cc | 46 - .../litert/cc/litert_tensor_buffer_utils.h | 24 - .../litert/cc/test_shared_library.cc | 19 - .../lite/experimental/litert/compiler/BUILD | 18 - .../experimental/litert/compiler/plugin/BUILD | 152 --- .../litert/compiler/plugin/algo.cc | 331 ------ .../litert/compiler/plugin/algo.h | 39 - .../litert/compiler/plugin/algo_test.cc | 304 ----- .../litert/compiler/plugin/compiler_flags.cc | 88 -- .../litert/compiler/plugin/compiler_flags.h | 69 -- .../compiler/plugin/compiler_flags_test.cc | 100 -- .../litert/compiler/plugin/compiler_plugin.cc | 649 ---------- .../litert/compiler/plugin/compiler_plugin.h | 202 ---- .../compiler/plugin/compiler_plugin_test.cc | 498 -------- .../lite/experimental/litert/core/BUILD | 218 ---- .../experimental/litert/core/build_stamp.cc | 63 - .../experimental/litert/core/build_stamp.h | 57 - .../litert/core/build_stamp_test.cc | 51 - .../litert/core/dispatch_op_schema.cc | 89 -- .../litert/core/dispatch_op_schema.h | 57 - .../litert/core/dispatch_op_schema_test.cc | 74 -- .../litert/core/dynamic_loading.cc | 148 --- .../litert/core/dynamic_loading.h | 68 -- .../litert/core/dynamic_loading_test.cc | 200 ---- .../experimental/litert/core/environment.cc | 33 - .../experimental/litert/core/environment.h | 58 - .../litert/core/environment_options.cc | 71 -- .../litert/core/environment_options.h | 42 - .../litert/core/environment_options_test.cc | 78 -- .../litert/core/environment_test.cc | 68 -- .../experimental/litert/core/filesystem.cc | 102 -- .../experimental/litert/core/filesystem.h | 49 - .../litert/core/filesystem_test.cc | 39 - .../litert/core/insert_order_map.h | 78 -- 
.../litert/core/insert_order_map_test.cc | 99 -- .../lite/experimental/litert/core/model/BUILD | 349 ------ .../litert/core/model/buffer_manager.h | 115 -- .../litert/core/model/buffer_manager_test.cc | 71 -- .../litert/core/model/flatbuffer_to_litert.cc | 154 --- .../litert/core/model/flatbuffer_to_litert.h | 43 - .../core/model/flatbuffer_to_litert_test.cc | 125 -- .../litert/core/model/graph_validation.cc | 114 -- .../litert/core/model/graph_validation.h | 47 - .../litert/core/model/ir_allocator.h | 153 --- .../litert/core/model/ir_allocator_test.cc | 129 -- .../litert/core/model/litert_to_flatbuffer.cc | 128 -- .../litert/core/model/litert_to_flatbuffer.h | 32 - .../core/model/litert_to_flatbuffer_test.cc | 108 -- .../experimental/litert/core/model/model.cc | 211 ---- .../experimental/litert/core/model/model.h | 1024 ---------------- .../litert/core/model/model_buffer.cc | 134 --- .../litert/core/model/model_buffer.h | 58 - .../litert/core/model/model_buffer_test.cc | 128 -- .../litert/core/model/model_file_test.cc | 1041 ----------------- .../litert/core/model/model_file_test_util.cc | 181 --- .../litert/core/model/model_file_test_util.h | 51 - .../litert/core/model/model_graph.cc | 226 ---- .../litert/core/model/model_graph.h | 110 -- .../litert/core/model/model_graph_test.cc | 419 ------- .../litert/core/model/model_load.cc | 432 ------- .../litert/core/model/model_load.h | 36 - .../litert/core/model/model_serialize.cc | 559 --------- .../litert/core/model/model_serialize.h | 30 - .../litert/core/model/model_test.cc | 531 --------- .../lite/experimental/litert/core/util/BUILD | 89 -- .../litert/core/util/flatbuffer_tools.cc | 330 ------ .../litert/core/util/flatbuffer_tools.h | 314 ----- .../litert/core/util/flatbuffer_tools_test.cc | 175 --- .../litert/core/util/tensor_type_util.cc | 57 - .../litert/core/util/tensor_type_util.h | 111 -- .../litert/core/util/tensor_type_util_test.cc | 69 -- .../lite/experimental/litert/core/version.h | 37 - .../litert/integration_test/BUILD | 129 -- .../integration_test/gen_device_test.cc | 199 ---- .../integration_test/gen_device_test_lib.h | 139 --- .../litert/integration_test/run_on_device.bzl | 296 ----- .../run_on_device_driver_OSS.sh | 30 - .../single_op_models/add_f32.mlir | 6 - .../single_op_models/concatenate_f32.mlir | 6 - .../single_op_models/divide_f32.mlir | 6 - .../single_op_models/greater_f32.mlir | 6 - .../single_op_models/less_f32.mlir | 6 - .../single_op_models/multiply_f32.mlir | 6 - .../single_op_models/reshape_f32.mlir | 7 - .../reshape_f32_large_rank.mlir | 7 - .../single_op_models/rsqrt_f32.mlir | 6 - .../single_op_models/select_f32.mlir | 6 - .../single_op_models/slice_f32.mlir | 8 - .../single_op_models/subtract_f32.mlir | 6 - .../single_op_models/tanh_f32.mlir | 6 - .../lite/experimental/litert/python/BUILD | 18 - .../lite/experimental/litert/runtime/BUILD | 414 ------- .../experimental/litert/runtime/accelerator.h | 74 -- .../accelerator_model_compilation_data.h | 54 - ...accelerator_model_compilation_data_test.cc | 41 - .../litert/runtime/accelerator_registry.cc | 66 -- .../litert/runtime/accelerator_registry.h | 89 -- .../litert/runtime/accelerator_test.cc | 61 - .../litert/runtime/accelerators/BUILD | 52 - .../accelerator_implementation_helper.h | 147 --- .../runtime/accelerators/auto_registration.cc | 89 -- .../runtime/accelerators/auto_registration.h | 33 - .../runtime/accelerators/dispatch/BUILD | 41 - .../dispatch/dispatch_accelerator.cc | 242 ---- .../dispatch/dispatch_accelerator.h | 46 - 
.../litert/runtime/accelerators/xnnpack/BUILD | 38 - .../xnnpack/xnnpack_accelerator.cc | 101 -- .../xnnpack/xnnpack_accelerator.h | 41 - .../litert/runtime/ahwb_buffer.cc | 111 -- .../experimental/litert/runtime/ahwb_buffer.h | 53 - .../litert/runtime/compilation_options.h | 34 - .../litert/runtime/compiled_model.cc | 638 ---------- .../litert/runtime/compiled_model.h | 226 ---- .../litert/runtime/compiled_model_test.cc | 544 --------- .../litert/runtime/compiler/BUILD | 89 -- .../compiler/jit_compilation_mediatek_test.cc | 97 -- .../compiler/jit_compilation_qualcomm_test.cc | 99 -- .../litert/runtime/dispatch/BUILD | 206 ---- .../litert/runtime/dispatch/README.md | 20 - .../runtime/dispatch/dispatch_delegate.cc | 173 --- .../dispatch_delegate_google_tensor_test.cc | 529 --------- .../dispatch/dispatch_delegate_kernel.cc | 657 ----------- .../dispatch/dispatch_delegate_kernel.h | 121 -- .../dispatch_delegate_mediatek_test.cc | 406 ------- .../dispatch/dispatch_delegate_options.h | 123 -- .../dispatch_delegate_qualcomm_test.cc | 406 ------- .../runtime/dispatch/litert_dispatch.cc | 571 --------- .../litert/runtime/dmabuf_buffer.cc | 195 --- .../litert/runtime/dmabuf_buffer.h | 33 - .../lite/experimental/litert/runtime/event.cc | 121 -- .../lite/experimental/litert/runtime/event.h | 53 - .../runtime/external_litert_buffer_context.cc | 125 -- .../runtime/external_litert_buffer_context.h | 134 --- .../litert/runtime/fastrpc_buffer.cc | 158 --- .../litert/runtime/fastrpc_buffer.h | 33 - .../experimental/litert/runtime/gl_buffer.cc | 338 ------ .../experimental/litert/runtime/gl_buffer.h | 105 -- .../litert/runtime/gl_buffer_test.cc | 183 --- .../experimental/litert/runtime/gl_texture.cc | 110 -- .../experimental/litert/runtime/gl_texture.h | 57 - .../litert/runtime/gpu_environment.cc | 94 -- .../litert/runtime/gpu_environment.h | 74 -- .../litert/runtime/gpu_environment_test.cc | 77 -- .../experimental/litert/runtime/ion_buffer.cc | 196 ---- .../experimental/litert/runtime/ion_buffer.h | 33 - .../litert/runtime/open_cl_buffer.cc | 120 -- .../litert/runtime/open_cl_buffer.h | 92 -- .../experimental/litert/runtime/opencl/BUILD | 129 -- .../litert/runtime/opencl/buffer.cc | 114 -- .../litert/runtime/opencl/buffer.h | 125 -- .../litert/runtime/opencl/buffer_test.cc | 61 - .../litert/runtime/opencl/cl_command_queue.cc | 141 --- .../litert/runtime/opencl/cl_command_queue.h | 82 -- .../litert/runtime/opencl/cl_context.cc | 105 -- .../litert/runtime/opencl/cl_context.h | 57 - .../litert/runtime/opencl/cl_device.cc | 104 -- .../litert/runtime/opencl/cl_device.h | 73 -- .../litert/runtime/opencl/cl_event.cc | 58 - .../litert/runtime/opencl/cl_event.h | 33 - .../litert/runtime/opencl/opencl_wrapper.cc | 470 -------- .../litert/runtime/opencl/opencl_wrapper.h | 737 ------------ .../litert/runtime/tensor_buffer.cc | 655 ----------- .../litert/runtime/tensor_buffer.h | 239 ---- .../runtime/tensor_buffer_conversion.cc | 210 ---- .../litert/runtime/tensor_buffer_conversion.h | 33 - .../runtime/tensor_buffer_conversion_test.cc | 140 --- .../runtime/tensor_buffer_requirements.h | 49 - .../experimental/litert/runtime/tfl_utils.cc | 97 -- .../experimental/litert/runtime/tfl_utils.h | 32 - .../lite/experimental/litert/test/BUILD | 173 --- .../lite/experimental/litert/test/common.cc | 126 -- .../lite/experimental/litert/test/common.h | 108 -- .../lite/experimental/litert/test/matchers.h | 359 ------ .../experimental/litert/test/matchers_test.cc | 184 --- .../experimental/litert/test/test_models.h | 126 -- 
.../litert/test/testdata/add_cst.mlir | 7 - .../litert/test/testdata/add_simple.mlir | 6 - .../litert/test/testdata/cos_mul.mlir | 7 - .../test/testdata/cst_multi_subgraph.mlir | 12 - .../test/testdata/dynamic_shape_tensor.mlir | 6 - .../test/testdata/fully_connected_3d.mlir | 6 - .../litert/test/testdata/mul_simple.mlir | 7 - .../litert/test/testdata/multi_composite.mlir | 21 - .../testdata/multi_op_multi_subgraph.mlir | 9 - .../litert/test/testdata/multi_subgraph.mlir | 21 - .../test/testdata/multi_subgraph_mul.mlir | 13 - .../litert/test/testdata/multi_use_cst.mlir | 9 - .../test/testdata/nested_composite.mlir | 14 - .../litert/test/testdata/one_mul.mlir | 6 - .../litert/test/testdata/rms_norm.mlir | 16 - .../test/testdata/rms_norm_composite.mlir | 23 - .../litert/test/testdata/scala_reshape.mlir | 7 - .../test/testdata/shared_input_cpu_npu.mlir | 7 - .../litert/test/testdata/simple_add_op.mlir | 6 - .../test/testdata/simple_average_poll_2d.mlir | 6 - .../test/testdata/simple_batch_matmul_op.mlir | 6 - .../testdata/simple_cascade_model_npu.mlir | 7 - .../litert/test/testdata/simple_cast_op.mlir | 6 - .../test/testdata/simple_composite.mlir | 11 - .../testdata/simple_concatenation_op.mlir | 6 - .../test/testdata/simple_conv_2d_op.mlir | 6 - .../litert/test/testdata/simple_cos_op.mlir | 6 - .../testdata/simple_depth_to_space_op.mlir | 6 - .../testdata/simple_depthwise_conv_2d_op.mlir | 6 - .../litert/test/testdata/simple_div_op.mlir | 6 - .../simple_dynamic_update_slice_op.mlir | 7 - .../testdata/simple_embedding_lookup_op.mlir | 7 - .../test/testdata/simple_floor_mod_op.mlir | 6 - .../testdata/simple_fully_connected_op.mlir | 6 - .../test/testdata/simple_gather_op.mlir | 6 - .../litert/test/testdata/simple_gelu_op.mlir | 6 - .../test/testdata/simple_greater_op.mlir | 6 - .../test/testdata/simple_hard_swish_op.mlir | 6 - .../test/testdata/simple_leaky_relu_op.mlir | 6 - .../litert/test/testdata/simple_less_op.mlir | 6 - .../test/testdata/simple_logical_and_op.mlir | 6 - .../litert/test/testdata/simple_mean_op.mlir | 7 - .../litert/test/testdata/simple_model.mlir | 6 - .../testdata/simple_model_google_tensor.bin | Bin 12288 -> 0 bytes .../litert/test/testdata/simple_model_mtk.bin | Bin 6956 -> 0 bytes .../test/testdata/simple_model_npu.mlir | 6 - .../test/testdata/simple_model_qualcomm.bin | Bin 13800 -> 0 bytes .../test/testdata/simple_model_test_vectors.h | 67 -- .../litert/test/testdata/simple_mul_op.mlir | 6 - .../litert/test/testdata/simple_multi_op.mlir | 9 - .../litert/test/testdata/simple_pack_op.mlir | 7 - .../litert/test/testdata/simple_relu6_op.mlir | 6 - .../litert/test/testdata/simple_relu_op.mlir | 6 - .../test/testdata/simple_reshape_op.mlir | 6 - .../testdata/simple_resize_bilinear_op.mlir | 7 - .../simple_resize_nearest_neighbor_op.mlir | 7 - .../litert/test/testdata/simple_rsqrt_op.mlir | 6 - .../test/testdata/simple_select_op.mlir | 6 - .../test/testdata/simple_select_v2_op.mlir | 6 - .../litert/test/testdata/simple_sin_op.mlir | 6 - .../litert/test/testdata/simple_slice_op.mlir | 8 - .../test/testdata/simple_softmax_op.mlir | 6 - .../testdata/simple_space_to_depth_op.mlir | 6 - .../litert/test/testdata/simple_split_op.mlir | 7 - .../testdata/simple_stablehlo_scatter_op.mlir | 9 - .../testdata/simple_strided_slice_op.mlir | 6 - .../litert/test/testdata/simple_sub_op.mlir | 6 - .../litert/test/testdata/simple_sum_op.mlir | 7 - .../litert/test/testdata/simple_tanh_op.mlir | 6 - .../test/testdata/simple_transpose_op.mlir | 7 - .../litert/test/testdata/two_adds.mlir | 7 - 
.../litert/test/testdata/two_partition.mlir | 9 - .../litert/test/testdata/unranked_tensor.mlir | 6 - .../lite/experimental/litert/tools/BUILD | 263 ----- .../lite/experimental/litert/tools/README.md | 24 - .../experimental/litert/tools/apply_plugin.cc | 515 -------- .../experimental/litert/tools/apply_plugin.h | 160 --- .../litert/tools/apply_plugin_main.cc | 157 --- .../litert/tools/apply_plugin_test.cc | 194 --- .../litert/tools/benchmark_litert_model.cc | 93 -- .../litert/tools/benchmark_litert_model.h | 151 --- .../tools/benchmark_litert_model_main.cc | 35 - .../tools/benchmark_litert_model_test.cc | 86 -- .../lite/experimental/litert/tools/dump.cc | 442 ------- .../lite/experimental/litert/tools/dump.h | 68 -- .../experimental/litert/tools/dump_test.cc | 131 --- .../experimental/litert/tools/outstream.h | 83 -- .../experimental/litert/tools/run_model.cc | 113 -- .../experimental/litert/tools/tool_display.cc | 86 -- .../experimental/litert/tools/tool_display.h | 102 -- .../litert/tools/tool_display_test.cc | 99 -- .../lite/experimental/litert/vendors/c/BUILD | 67 -- .../litert/vendors/c/litert_compiler_plugin.h | 118 -- .../vendors/c/litert_compiler_plugin_api.h | 156 --- .../litert/vendors/c/litert_dispatch.h | 309 ----- .../litert/vendors/c/litert_dispatch_api.h | 245 ---- .../c/litert_vendor_c_api_common_test.c | 28 - .../lite/experimental/litert/vendors/cc/BUILD | 126 -- .../litert/vendors/cc/backend_ir.h | 79 -- .../litert/vendors/cc/conversion.h | 262 ----- .../litert/vendors/cc/convert_graph.h | 177 --- .../litert/vendors/cc/convert_graph_test.cc | 390 ------ .../experimental/litert/vendors/cc/ir_types.h | 50 - .../vendors/cc/litert_compiler_plugin.h | 47 - .../vendors/cc/partition_with_capabilities.h | 97 -- .../cc/partition_with_capabilities_test.cc | 207 ---- .../litert/vendors/examples/BUILD | 160 --- .../examples/example_conversion_impl.cc | 64 - .../examples/example_conversion_impl.h | 128 -- .../examples/example_conversion_impl_test.cc | 213 ---- .../litert/vendors/examples/example_ir.cc | 87 -- .../litert/vendors/examples/example_ir.h | 153 --- .../litert/vendors/examples/example_plugin.cc | 120 -- .../vendors/examples/example_plugin_common.cc | 137 --- .../vendors/examples/example_plugin_common.h | 29 - .../vendors/examples/example_plugin_test.cc | 98 -- .../example_plugin_with_conversions.cc | 139 --- .../example_plugin_with_conversions_test.cc | 111 -- .../litert/vendors/google_tensor/BUILD | 89 -- .../litert/vendors/google_tensor/adapter.cc | 95 -- .../litert/vendors/google_tensor/adapter.h | 68 -- .../vendors/google_tensor/adapter_test.cc | 92 -- .../vendors/google_tensor/compiler/BUILD | 93 -- .../google_tensor/compiler/compiler_plugin.cc | 360 ------ .../compiler/compiler_plugin_test.cc | 92 -- .../vendors/google_tensor/dispatch/BUILD | 153 --- .../google_tensor/dispatch/dispatch_api.cc | 644 ---------- .../google_tensor/dispatch/dispatch_api.h | 65 - .../dispatch_api_async_google_tensor_test.cc | 340 ------ .../dispatch_api_google_tensor_test.cc | 291 ----- .../litert_dispatch_device_context.cc | 294 ----- .../dispatch/litert_dispatch_device_context.h | 67 -- .../dispatch/litert_dispatch_graph.cc | 305 ----- .../dispatch/litert_dispatch_graph.h | 129 -- .../litert_dispatch_invocation_context.cc | 613 ---------- .../litert_dispatch_invocation_context.h | 104 -- .../dispatch/litert_dispatch_metrics.h | 58 - .../google_tensor/dispatch/southbound.cc | 163 --- .../google_tensor/dispatch/southbound.h | 133 --- .../litert/vendors/mediatek/BUILD | 44 - 
.../litert/vendors/mediatek/compiler/BUILD | 160 --- .../mediatek/compiler/compile_model.cc | 104 -- .../vendors/mediatek/compiler/compile_model.h | 32 - .../mediatek/compiler/compiler_plugin.cc | 370 ------ .../mediatek/compiler/compiler_plugin_test.cc | 133 --- .../vendors/mediatek/compiler/create_model.cc | 164 --- .../vendors/mediatek/compiler/create_model.h | 34 - .../mediatek/compiler/legalizations/BUILD | 338 ------ .../legalizations/add_op_legalization.cc | 76 -- .../legalizations/add_op_legalization.h | 32 - .../batch_matmul_op_legalization.cc | 89 -- .../batch_matmul_op_legalization.h | 33 - .../legalizations/common_op_legalization.cc | 66 -- .../legalizations/common_op_legalization.h | 35 - .../legalizations/concat_op_legalization.cc | 77 -- .../legalizations/concat_op_legalization.h | 32 - .../fully_connected_op_legalization.cc | 129 -- .../fully_connected_op_legalization.h | 32 - .../legalizations/gelu_op_legalization.cc | 70 -- .../legalizations/gelu_op_legalization.h | 32 - .../legalizations/mean_op_legalization.cc | 75 -- .../legalizations/mean_op_legalization.h | 32 - .../legalizations/mul_op_legalization.cc | 76 -- .../legalizations/mul_op_legalization.h | 32 - .../compiler/legalizations/neuron_utils.cc | 104 -- .../compiler/legalizations/neuron_utils.h | 43 - .../compiler/legalizations/operand_map.cc | 80 -- .../compiler/legalizations/operand_map.h | 269 ----- .../legalizations/quantize_op_legalization.cc | 61 - .../legalizations/quantize_op_legalization.h | 32 - .../legalizations/reshape_op_legalization.cc | 61 - .../legalizations/reshape_op_legalization.h | 32 - .../legalizations/rsqrt_op_legalization.cc | 61 - .../legalizations/rsqrt_op_legalization.h | 32 - .../legalizations/softmax_op_legalization.cc | 74 -- .../legalizations/softmax_op_legalization.h | 32 - .../legalizations/sub_op_legalization.cc | 76 -- .../legalizations/sub_op_legalization.h | 32 - .../transpose_op_legalization.cc | 61 - .../legalizations/transpose_op_legalization.h | 32 - .../litert/vendors/mediatek/dispatch/BUILD | 122 -- .../vendors/mediatek/dispatch/README.md | 4 - .../vendors/mediatek/dispatch/dispatch_api.cc | 327 ------ .../dispatch/dispatch_api_mediatek_test.cc | 638 ---------- .../litert_dispatch_device_context.cc | 190 --- .../dispatch/litert_dispatch_device_context.h | 87 -- .../litert_dispatch_invocation_context.cc | 435 ------- .../litert_dispatch_invocation_context.h | 94 -- .../vendors/mediatek/mediatek_build_defs.bzl | 91 -- .../vendors/mediatek/neuron_adapter_api.cc | 187 --- .../vendors/mediatek/neuron_adapter_api.h | 151 --- .../litert/vendors/mediatek/schema/BUILD | 43 - .../vendors/mediatek/schema/neuron_schema.fbs | 61 - .../vendors/mediatek/schema/schema_resolver.h | 184 --- .../litert/vendors/mediatek/supported_soc.csv | 17 - .../litert/vendors/qualcomm/BUILD | 146 --- .../litert/vendors/qualcomm/common.h | 100 -- .../litert/vendors/qualcomm/compiler/BUILD | 206 ---- .../litert/vendors/qualcomm/compiler/IR/BUILD | 123 -- .../compiler/IR/op_compatibility_test.cc | 82 -- .../vendors/qualcomm/compiler/IR/qnn_op.cc | 147 --- .../vendors/qualcomm/compiler/IR/qnn_op.h | 53 - .../qualcomm/compiler/IR/qnn_op_test.cc | 67 -- .../qualcomm/compiler/IR/qnn_tensor.cc | 253 ---- .../vendors/qualcomm/compiler/IR/qnn_tensor.h | 75 -- .../qualcomm/compiler/IR/qnn_tensor_test.cc | 202 ---- .../vendors/qualcomm/compiler/graph_mapper.cc | 190 --- .../vendors/qualcomm/compiler/graph_mapper.h | 125 -- .../qualcomm/compiler/legalizations/BUILD | 922 --------------- 
.../legalizations/add_op_legalization.cc | 51 - .../legalizations/add_op_legalization.h | 49 - .../batch_matmul_op_legalization.cc | 52 - .../batch_matmul_op_legalization.h | 49 - .../legalizations/cast_op_legalization.cc | 49 - .../legalizations/cast_op_legalization.h | 49 - .../concatenation_op_legalization.cc | 100 -- .../concatenation_op_legalization.h | 51 - .../legalizations/cos_op_legalization.cc | 49 - .../legalizations/cos_op_legalization.h | 49 - .../legalizations/div_op_legalization.cc | 51 - .../legalizations/div_op_legalization.h | 49 - .../dynamic_update_slice_op_legalization.cc | 304 ----- .../dynamic_update_slice_op_legalization.h | 46 - .../embedding_lookup_op_legalization.cc | 103 -- .../embedding_lookup_op_legalization.h | 51 - .../fully_connected_op_legalization.cc | 51 - .../fully_connected_op_legalization.h | 51 - .../legalizations/gelu_op_legalization.cc | 50 - .../legalizations/gelu_op_legalization.h | 49 - .../legalizations/greater_op_legalization.cc | 51 - .../legalizations/greater_op_legalization.h | 49 - .../compiler/legalizations/legalization.h | 50 - .../legalizations/less_op_legalization.cc | 51 - .../legalizations/less_op_legalization.h | 49 - .../logical_and_op_legalization.cc | 51 - .../logical_and_op_legalization.h | 51 - .../legalizations/mul_op_legalization.cc | 51 - .../legalizations/mul_op_legalization.h | 49 - .../legalizations/pack_op_legalization.cc | 140 --- .../legalizations/pack_op_legalization.h | 49 - .../legalizations/quantize_op_legalization.cc | 174 --- .../legalizations/quantize_op_legalization.h | 65 - .../legalizations/reshape_op_legalization.cc | 82 -- .../legalizations/reshape_op_legalization.h | 49 - .../legalizations/rsqrt_op_legalization.cc | 51 - .../legalizations/rsqrt_op_legalization.h | 49 - .../legalizations/select_op_legalization.cc | 55 - .../legalizations/select_op_legalization.h | 49 - .../legalizations/sin_op_legalization.cc | 49 - .../legalizations/sin_op_legalization.h | 49 - .../legalizations/slice_op_legalization.cc | 160 --- .../legalizations/slice_op_legalization.h | 47 - .../legalizations/softmax_op_legalization.cc | 99 -- .../legalizations/softmax_op_legalization.h | 49 - .../legalizations/sub_op_legalization.cc | 51 - .../legalizations/sub_op_legalization.h | 49 - .../legalizations/sum_op_legalization.cc | 146 --- .../legalizations/sum_op_legalization.h | 49 - .../legalizations/tanh_op_legalization.cc | 51 - .../legalizations/tanh_op_legalization.h | 49 - .../transpose_op_legalization.cc | 121 -- .../legalizations/transpose_op_legalization.h | 49 - .../qualcomm/compiler/legalizations/util.cc | 121 -- .../qualcomm/compiler/legalizations/util.h | 117 -- .../qualcomm/compiler/qnn_compiler_plugin.cc | 423 ------- .../compiler/qnn_compiler_plugin_test.cc | 399 ------- .../qualcomm/compiler/qnn_compose_graph.cc | 758 ------------ .../qualcomm/compiler/qnn_compose_graph.h | 49 - .../vendors/qualcomm/context_binary_info.cc | 216 ---- .../vendors/qualcomm/context_binary_info.h | 66 -- .../litert/vendors/qualcomm/core/BUILD | 27 - .../vendors/qualcomm/core/builders/BUILD | 574 --------- .../qualcomm/core/builders/cast_op_builder.cc | 27 - .../qualcomm/core/builders/cast_op_builder.h | 20 - .../core/builders/concatenation_op_builder.cc | 36 - .../core/builders/concatenation_op_builder.h | 20 - .../core/builders/conv2d_op_builder.cc | 166 --- .../core/builders/conv2d_op_builder.h | 25 - .../builders/depthwise_conv2d_op_builder.cc | 118 -- .../builders/depthwise_conv2d_op_builder.h | 25 - .../dynamic_update_slice_op_builder.cc 
| 135 --- .../dynamic_update_slice_op_builder.h | 21 - .../core/builders/elementwise_op_builder.cc | 233 ---- .../core/builders/elementwise_op_builder.h | 73 -- .../builders/embedding_lookup_op_builder.cc | 68 -- .../builders/embedding_lookup_op_builder.h | 20 - .../builders/fully_connected_op_builder.cc | 67 -- .../builders/fully_connected_op_builder.h | 22 - .../fully_connected_op_builder_htp.cc | 132 --- .../builders/fully_connected_op_builder_htp.h | 21 - .../core/builders/gather_op_builder.cc | 44 - .../core/builders/gather_op_builder.h | 21 - .../qualcomm/core/builders/gelu_op_builder.cc | 26 - .../qualcomm/core/builders/gelu_op_builder.h | 20 - .../core/builders/hard_swish_op_builder.cc | 38 - .../core/builders/hard_swish_op_builder.h | 21 - .../core/builders/leaky_relu_op_builder.cc | 101 -- .../core/builders/leaky_relu_op_builder.h | 21 - .../core/builders/matmul_op_builder.cc | 33 - .../core/builders/matmul_op_builder.h | 21 - .../qualcomm/core/builders/mean_op_builder.cc | 67 -- .../qualcomm/core/builders/mean_op_builder.h | 21 - .../qualcomm/core/builders/op_builder.cc | 107 -- .../qualcomm/core/builders/op_builder.h | 36 - .../qualcomm/core/builders/pack_op_builder.cc | 53 - .../qualcomm/core/builders/pack_op_builder.h | 23 - .../core/builders/pool2d_op_builder.cc | 113 -- .../core/builders/pool2d_op_builder.h | 33 - .../core/builders/quantize_op_builder.cc | 56 - .../core/builders/quantize_op_builder.h | 25 - .../core/builders/reduce_op_builder.cc | 66 -- .../core/builders/reduce_op_builder.h | 20 - .../core/builders/relu6_op_builder.cc | 24 - .../qualcomm/core/builders/relu6_op_builder.h | 21 - .../qualcomm/core/builders/relu_op_builder.cc | 24 - .../qualcomm/core/builders/relu_op_builder.h | 21 - .../core/builders/reshape_op_builder.cc | 28 - .../core/builders/reshape_op_builder.h | 20 - .../core/builders/resize_op_builder.cc | 59 - .../core/builders/resize_op_builder.h | 27 - .../core/builders/rms_norm_op_builder.cc | 85 -- .../core/builders/rms_norm_op_builder.h | 32 - .../core/builders/select_op_builder.cc | 30 - .../core/builders/select_op_builder.h | 20 - .../core/builders/slice_op_builder.cc | 75 -- .../qualcomm/core/builders/slice_op_builder.h | 19 - .../core/builders/softmax_op_builder.cc | 29 - .../core/builders/softmax_op_builder.h | 19 - .../builders/spatial_transform_op_builder.cc | 63 - .../builders/spatial_transform_op_builder.h | 28 - .../core/builders/split_op_builder.cc | 69 -- .../qualcomm/core/builders/split_op_builder.h | 21 - .../qualcomm/core/builders/tanh_op_builder.cc | 26 - .../qualcomm/core/builders/tanh_op_builder.h | 20 - .../core/builders/transpose_op_builder.cc | 39 - .../core/builders/transpose_op_builder.h | 20 - .../litert/vendors/qualcomm/core/common.h | 40 - .../vendors/qualcomm/core/tensor_pool.cc | 92 -- .../vendors/qualcomm/core/tensor_pool.h | 66 -- .../litert/vendors/qualcomm/core/utils/BUILD | 63 - .../litert/vendors/qualcomm/core/utils/log.h | 56 - .../qualcomm/core/utils/log_android.cc | 58 - .../qualcomm/core/utils/log_default.cc | 36 - .../vendors/qualcomm/core/utils/miscs.cc | 30 - .../vendors/qualcomm/core/utils/miscs.h | 44 - .../vendors/qualcomm/core/utils/utils_test.cc | 150 --- .../vendors/qualcomm/core/wrappers/BUILD | 69 -- .../qualcomm/core/wrappers/op_wrapper.cc | 82 -- .../qualcomm/core/wrappers/op_wrapper.h | 55 - .../qualcomm/core/wrappers/param_wrapper.cc | 27 - .../qualcomm/core/wrappers/param_wrapper.h | 78 -- .../core/wrappers/quantize_params_wrapper.cc | 114 -- .../core/wrappers/quantize_params_wrapper.h | 86 -- 
.../qualcomm/core/wrappers/tensor_wrapper.cc | 274 ----- .../qualcomm/core/wrappers/tensor_wrapper.h | 349 ------ .../qualcomm/core/wrappers/tests/BUILD | 78 -- .../core/wrappers/tests/op_wrapper_test.cc | 168 --- .../core/wrappers/tests/param_wrapper_test.cc | 232 ---- .../tests/quantize_params_wrapper_test.cc | 163 --- .../wrappers/tests/tensor_wrapper_test.cc | 309 ----- .../litert/vendors/qualcomm/dispatch/BUILD | 115 -- .../vendors/qualcomm/dispatch/dispatch_api.cc | 292 ----- .../dispatch/dispatch_api_qualcomm_test.cc | 544 --------- .../litert_dispatch_device_context.cc | 190 --- .../dispatch/litert_dispatch_device_context.h | 79 -- .../litert_dispatch_invocation_context.cc | 240 ---- .../litert_dispatch_invocation_context.h | 81 -- .../vendors/qualcomm/dispatch/registry.h | 71 -- .../litert/vendors/qualcomm/qnn_log.cc | 64 - .../litert/vendors/qualcomm/qnn_log.h | 28 - .../litert/vendors/qualcomm/qnn_manager.cc | 411 ------- .../litert/vendors/qualcomm/qnn_manager.h | 239 ---- .../vendors/qualcomm/qnn_manager_test.cc | 50 - .../litert/vendors/qualcomm/qnn_tensor.cc | 104 -- .../litert/vendors/qualcomm/qnn_tensor.h | 58 - .../vendors/qualcomm/qualcomm_build_defs.bzl | 118 -- .../litert/vendors/qualcomm/supported_soc.csv | 41 - .../litert/vendors/qualcomm/tools/BUILD | 31 - .../litert/vendors/qualcomm/tools/dump.cc | 88 -- .../litert/vendors/qualcomm/tools/dump.h | 29 - tensorflow/opensource_only.files | 1 - 659 files changed, 79998 deletions(-) delete mode 100644 tensorflow/lite/experimental/litert/BUILD delete mode 100644 tensorflow/lite/experimental/litert/build_common/BUILD delete mode 100644 tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds delete mode 100644 tensorflow/lite/experimental/litert/build_common/export_litert_only_linux.lds delete mode 100644 tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_darwin.lds delete mode 100644 tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_linux.lds delete mode 100644 tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl delete mode 100644 tensorflow/lite/experimental/litert/build_common/special_rule.bzl delete mode 100644 tensorflow/lite/experimental/litert/build_common/tfl_model_gen.bzl delete mode 100644 tensorflow/lite/experimental/litert/c/BUILD delete mode 100644 tensorflow/lite/experimental/litert/c/litert_accelerator.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_accelerator.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_accelerator_registration.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_accelerator_registration_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_accelerator_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_any.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c delete mode 100644 tensorflow/lite/experimental/litert/c/litert_common.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_common.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_common_test.cc delete mode 
100644 tensorflow/lite/experimental/litert/c/litert_compilation_options.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_compilation_options.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_compilation_options_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_compiled_model.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_compiled_model.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_environment.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_environment.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_environment_options.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_environment_options.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_environment_options_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_event.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_event.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_event_type.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_gl_types.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_layout.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_logging.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_logging.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_logging_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_model.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_model.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_model_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_op_code.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_options.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_options.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_options_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h delete mode 100644 tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc delete mode 100644 tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h delete mode 100644 tensorflow/lite/experimental/litert/cc/BUILD delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_any.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_any_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_buffer_ref_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_compilation_options.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_compiled_model.h delete mode 100644 
tensorflow/lite/experimental/litert/cc/litert_compiled_model_gpu_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_compiled_model_integration_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_consts.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_detail.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_element_type.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_element_type_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_environment.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_environment_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_event.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_event_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_expected.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_expected_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_handle.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_layout.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_layout_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_macros.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_macros.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_macros_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_model.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_model.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_model_predicates.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_model_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_op_options.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_op_options.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_op_options_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_shared_library.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_shared_library.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_shared_library_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.cc delete mode 100644 tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h delete mode 100644 tensorflow/lite/experimental/litert/cc/test_shared_library.cc delete mode 100644 tensorflow/lite/experimental/litert/compiler/BUILD delete mode 100644 tensorflow/lite/experimental/litert/compiler/plugin/BUILD delete mode 100644 tensorflow/lite/experimental/litert/compiler/plugin/algo.cc delete mode 100644 tensorflow/lite/experimental/litert/compiler/plugin/algo.h delete mode 100644 
tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.cc
delete mode 100644 tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h
delete mode 100644 tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc
delete mode 100644 tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h
delete mode 100644 tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/core/build_stamp.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/build_stamp.h
delete mode 100644 tensorflow/lite/experimental/litert/core/build_stamp_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/dispatch_op_schema.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/dispatch_op_schema.h
delete mode 100644 tensorflow/lite/experimental/litert/core/dispatch_op_schema_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/dynamic_loading.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/dynamic_loading.h
delete mode 100644 tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/environment.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/environment.h
delete mode 100644 tensorflow/lite/experimental/litert/core/environment_options.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/environment_options.h
delete mode 100644 tensorflow/lite/experimental/litert/core/environment_options_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/environment_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/filesystem.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/filesystem.h
delete mode 100644 tensorflow/lite/experimental/litert/core/filesystem_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/insert_order_map.h
delete mode 100644 tensorflow/lite/experimental/litert/core/insert_order_map_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/core/model/buffer_manager.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/buffer_manager_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/graph_validation.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/graph_validation.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/ir_allocator.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_buffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_buffer.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_buffer_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_file_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_file_test_util.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_graph.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_graph.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_graph_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_load.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_load.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_serialize.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_serialize.h
delete mode 100644 tensorflow/lite/experimental/litert/core/model/model_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/util/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h
delete mode 100644 tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/util/tensor_type_util.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/util/tensor_type_util.h
delete mode 100644 tensorflow/lite/experimental/litert/core/util/tensor_type_util_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/core/version.h
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/gen_device_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/gen_device_test_lib.h
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/run_on_device.bzl
delete mode 100755 tensorflow/lite/experimental/litert/integration_test/run_on_device_driver_OSS.sh
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/add_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/concatenate_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/divide_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/greater_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/less_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/multiply_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32_large_rank.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/rsqrt_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/select_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/slice_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/subtract_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/integration_test/single_op_models/tanh_f32.mlir
delete mode 100644 tensorflow/lite/experimental/litert/python/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/runtime/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerator.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerator_registry.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerator_registry.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerator_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/accelerator_implementation_helper.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/ahwb_buffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/compilation_options.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/compiled_model.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/compiled_model.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/compiler/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/README.md
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dispatch/litert_dispatch.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/event.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/event.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/gl_buffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/gl_buffer.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/gl_buffer_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/gl_texture.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/gl_texture.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/gpu_environment.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/gpu_environment.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/gpu_environment_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/ion_buffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/ion_buffer.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/open_cl_buffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/buffer.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/buffer_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/cl_event.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/tensor_buffer.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h
delete mode 100644 tensorflow/lite/experimental/litert/runtime/tfl_utils.cc
delete mode 100644 tensorflow/lite/experimental/litert/runtime/tfl_utils.h
delete mode 100644 tensorflow/lite/experimental/litert/test/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/test/common.cc
delete mode 100644 tensorflow/lite/experimental/litert/test/common.h
delete mode 100644 tensorflow/lite/experimental/litert/test/matchers.h
delete mode 100644 tensorflow/lite/experimental/litert/test/matchers_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/test/test_models.h
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/add_cst.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/add_simple.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/cos_mul.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/fully_connected_3d.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/mul_simple.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/multi_composite.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/multi_op_multi_subgraph.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/multi_use_cst.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/nested_composite.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/rms_norm.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/rms_norm_composite.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/scala_reshape.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/shared_input_cpu_npu.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_average_poll_2d.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_cascade_model_npu.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_cast_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_composite.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_conv_2d_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_cos_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_depth_to_space_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_depthwise_conv_2d_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_dynamic_update_slice_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_gather_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_greater_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_hard_swish_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_leaky_relu_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_less_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_logical_and_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_mean_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_model.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_model_google_tensor.bin
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_model_mtk.bin
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_model_qualcomm.bin
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_multi_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_pack_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_relu6_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_relu_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_resize_bilinear_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_resize_nearest_neighbor_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_sin_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_space_to_depth_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_split_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/two_adds.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/two_partition.mlir
delete mode 100644 tensorflow/lite/experimental/litert/test/testdata/unranked_tensor.mlir
delete mode 100644 tensorflow/lite/experimental/litert/tools/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/tools/README.md
delete mode 100644 tensorflow/lite/experimental/litert/tools/apply_plugin.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/apply_plugin.h
delete mode 100644 tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/benchmark_litert_model.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h
delete mode 100644 tensorflow/lite/experimental/litert/tools/benchmark_litert_model_main.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/benchmark_litert_model_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/dump.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/dump.h
delete mode 100644 tensorflow/lite/experimental/litert/tools/dump_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/outstream.h
delete mode 100644 tensorflow/lite/experimental/litert/tools/run_model.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/tool_display.cc
delete mode 100644 tensorflow/lite/experimental/litert/tools/tool_display.h
delete mode 100644 tensorflow/lite/experimental/litert/tools/tool_display_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/c/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c
delete mode 100644 tensorflow/lite/experimental/litert/vendors/cc/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/cc/conversion.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/cc/ir_types.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_ir.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/adapter_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_async_google_tensor_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/README.md
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/mediatek_build_defs.bzl
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/schema/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema.fbs
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/mediatek/supported_soc.csv
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/common.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_android.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_default.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/utils_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/op_wrapper_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/param_wrapper_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/quantize_params_wrapper_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/tensor_wrapper_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/supported_soc.csv
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc
delete mode 100644 tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h
diff --git a/tensorflow/lite/experimental/litert/BUILD b/tensorflow/lite/experimental/litert/BUILD
deleted file mode 100644
index 23b07d5602d7c8..00000000000000
--- a/tensorflow/lite/experimental/litert/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-# Copyright 2024 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"],
-)
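A brief aside on the deleted package file above: in Bazel, package(default_visibility = ...) applies only to targets declared in that same BUILD file, so the rule above made every target of this package visible to litert subpackages by default. A minimal sketch, assuming a hypothetical target in that package (not part of this patch):

# Hypothetical library for illustration only. With no explicit visibility
# attribute it inherits the package() default above and is therefore
# visible to all //tensorflow/lite/experimental/litert/... packages,
# and to nothing else in the tree.
cc_library(
    name = "example_lib",      # hypothetical name
    hdrs = ["example_lib.h"],  # hypothetical header
)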
- -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -exports_files(srcs = [ - "export_litert_only_darwin.lds", - "export_litert_only_linux.lds", - "export_litert_runtime_only_darwin.lds", - "export_litert_runtime_only_linux.lds", -]) - -bzl_library( - name = "special_rule_bzl", - srcs = ["special_rule.bzl"], - visibility = ["//visibility:private"], -) - -bzl_library( - name = "litert_build_defs_bzl", - srcs = ["litert_build_defs.bzl"], - visibility = ["//visibility:private"], -) - -bzl_library( - name = "tfl_model_gen_bzl", - srcs = ["tfl_model_gen.bzl"], - visibility = ["//visibility:private"], -) diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds deleted file mode 100644 index a51afcee0a21f0..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds +++ /dev/null @@ -1,8 +0,0 @@ -# Compiler Plugin -*LiteRt*CompilerPlugin* - -# Compiled Result -*LiteRt*CompiledResult* - -# Dispatch -*LiteRtDispatch* diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_only_linux.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_only_linux.lds deleted file mode 100644 index 97b05c1d655a71..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/export_litert_only_linux.lds +++ /dev/null @@ -1,29 +0,0 @@ -VERS_1.0 { - - /* - Export abi-stable "vendor" implemented symbols. - - TODO: Add all vendor symbols. Also export qnn libc++ symbols - (statically linked) as "protected" as needed. - */ - - global: - - /* Compiler Plugin */ - - LiteRt*CompilerPlugin*; - - /* Compiled Result */ - - LiteRt*CompiledResult*; - - /* Dispatch */ - - LiteRtDispatch*; - - local: - - /* Hide everything else */ - - *; -}; diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_darwin.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_darwin.lds deleted file mode 100644 index 9638faa6b23e98..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_darwin.lds +++ /dev/null @@ -1,2 +0,0 @@ -# All LiteRt C APIs -LiteRt* diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_linux.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_linux.lds deleted file mode 100644 index 6948af4950cfd6..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_linux.lds +++ /dev/null @@ -1,20 +0,0 @@ -VERS_1.0 { - - /* - Export abi-stable "vendor" implemented symbols. - - TODO: Add all vendor symbols. Also export qnn libc++ symbols - (statically linked) as "protected" as needed. - */ - - global: - - /* All LiteRt C APIs */ - LiteRt*; - - local: - - /* Hide everything else */ - - *; -}; diff --git a/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl b/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl deleted file mode 100644 index a6b13cb2d18767..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright 2024 Google LLC. 
diff --git a/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl b/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl
deleted file mode 100644
index a6b13cb2d18767..00000000000000
--- a/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl
+++ /dev/null
@@ -1,321 +0,0 @@
-# Copyright 2024 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Common LiteRT Build Utilities."""
-
-####################################################################################################
-# Util
-
-_LRT_SO_PREFIX = "libLiteRt"
-_SO_EXT = ".so"
-_SHARED_LIB_SUFFIX = "_so"
-
-# Public
-
-def make_linkopt(opt):
-    return "-Wl,{}".format(opt)
-
-def make_rpaths(rpaths):
-    return make_linkopt("-rpath={}".format(":".join(rpaths)))
-
-def append_rule_kwargs(rule_kwargs, **append):
-    for k, v in append.items():
-        append_to = rule_kwargs.pop(k, [])
-        append_to += v
-        rule_kwargs[k] = append_to
-
-def absolute_label(label, package_name = None):
-    """Get the absolute label for a given label.
-
-    Args:
-      label: The label to convert to absolute.
-      package_name: The package name to use if the label is relative.
-
-    Returns:
-      The absolute label.
-    """
-    if label.startswith("//"):
-        if ":" in label:
-            return label
-        return "%s:%s" % (label, label.rsplit("/", 1)[-1])
-    if not package_name:
-        package_name = native.package_name()
-    if label.startswith(":"):
-        return "//%s%s" % (package_name, label)
-    if ":" in label:
-        return "//%s/%s" % (package_name, label)
-    return "//%s:%s" % (package_name, label)
-
-# Private
-
-def _valid_shared_lib_name(name):
-    return name.endswith(_SHARED_LIB_SUFFIX)
-
-def _valid_so_name(name):
-    return name.startswith(_LRT_SO_PREFIX) and name.endswith(_SO_EXT)
-
-def _make_target_ref(name):
-    return ":{}".format(name)
-
-####################################################################################################
-# Explicitly Link System Libraries ("ungrte")
-
-_SYS_RPATHS_X86_64 = [
-    "/usr/lib/x86_64-linux-gnu",
-    "/lib/x86_64-linux-gnu",
-]
-_SYS_RPATHS_LINKOPT_X86_64 = make_rpaths(_SYS_RPATHS_X86_64)
-
-_SYS_ELF_INTERPRETER_X86_64 = "/lib64/ld-linux-x86-64.so.2"
-_SYS_ELF_INTERPRETER_LINKOPT_X86_64 = make_linkopt("--dynamic-linker={}".format(_SYS_ELF_INTERPRETER_X86_64))
-
-####################################################################################################
-# Symbol Hiding
-
-_EXPORT_LRT_ONLY_SCRIPT_LINUX = "//tensorflow/lite/experimental/litert/build_common:export_litert_only_linux.lds"
-_EXPORT_LRT_ONLY_SCRIPT_DARWIN = "//tensorflow/lite/experimental/litert/build_common:export_litert_only_darwin.lds"
-_EXPORT_LRT_ONLY_LINKOPT_LINUX = make_linkopt("--version-script=$(location {})".format(_EXPORT_LRT_ONLY_SCRIPT_LINUX))
-_EXPORT_LRT_ONLY_LINKOPT_DARWIN = make_linkopt("-exported_symbols_list,$(location {})".format(_EXPORT_LRT_ONLY_SCRIPT_DARWIN))
-
-def symbol_opts():
-    """Defines linker flags whether to include symbols or not."""
-    return select({
-        "//tensorflow:debug": [],
-        "//conditions:default": [
-            # Omit symbol table, for all non debug builds
-            "-Wl,-s",
-        ],
-    })
-
-def export_lrt_only_script():
-    return select({
-        "//tensorflow:linux_x86_64": [_EXPORT_LRT_ONLY_SCRIPT_LINUX],
-        "//tensorflow:android": [_EXPORT_LRT_ONLY_SCRIPT_LINUX],
-        "//tensorflow:macos": [_EXPORT_LRT_ONLY_SCRIPT_DARWIN],
-        "//tensorflow:ios": [_EXPORT_LRT_ONLY_SCRIPT_DARWIN],
-        "//conditions:default": [],
-    })
-
-def export_lrt_only_linkopt():
-    return select({
-        "//tensorflow:linux_x86_64": [_EXPORT_LRT_ONLY_LINKOPT_LINUX],
-        "//tensorflow:android": [_EXPORT_LRT_ONLY_LINKOPT_LINUX],
-        "//tensorflow:macos": [_EXPORT_LRT_ONLY_LINKOPT_DARWIN],
-        "//tensorflow:ios": [_EXPORT_LRT_ONLY_LINKOPT_DARWIN],
-        "//conditions:default": [],
-    }) + symbol_opts()
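The helpers above compose plain flag strings: make_linkopt("--gc-sections") returns "-Wl,--gc-sections", make_rpaths(["/usr/lib", "/opt/lib"]) returns "-Wl,-rpath=/usr/lib:/opt/lib", and export_lrt_only_script() / export_lrt_only_linkopt() pair the platform-selected .lds label with the matching export flag. A minimal usage sketch, assuming a hypothetical vendor plugin target (plain cc_binary shown; the file's own macros are defined further down):

# Hypothetical consumer of the macros above; names and sources are
# illustrative, not part of this patch.
load(
    "//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl",
    "export_lrt_only_linkopt",
    "export_lrt_only_script",
)

cc_binary(
    name = "libLiteRtExamplePlugin.so",  # hypothetical plugin .so
    srcs = ["example_plugin.cc"],        # hypothetical source
    linkshared = True,
    # Ship the version script with the link action and apply the
    # platform-appropriate export flag selected by the macros.
    additional_linker_inputs = export_lrt_only_script(),
    linkopts = export_lrt_only_linkopt(),
)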
[_EXPORT_LRT_ONLY_SCRIPT_DARWIN], - "//conditions:default": [], - }) - -def export_lrt_only_linkopt(): - return select({ - "//tensorflow:linux_x86_64": [_EXPORT_LRT_ONLY_LINKOPT_LINUX], - "//tensorflow:android": [_EXPORT_LRT_ONLY_LINKOPT_LINUX], - "//tensorflow:macos": [_EXPORT_LRT_ONLY_LINKOPT_DARWIN], - "//tensorflow:ios": [_EXPORT_LRT_ONLY_LINKOPT_DARWIN], - "//conditions:default": [], - }) + symbol_opts() - -_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_LINUX = "//tensorflow/lite/experimental/litert/build_common:export_litert_runtime_only_linux.lds" -_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_DARWIN = "//tensorflow/lite/experimental/litert/build_common:export_litert_runtime_only_darwin.lds" -_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_LINUX = make_linkopt("--version-script=$(location {})".format(_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_LINUX)) -_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_DARWIN = make_linkopt("-exported_symbols_list,$(location {})".format(_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_DARWIN)) - -# TODO b/391390553: Add "-Wl,--no-undefined" to make sure all symbols are defined. -_EXPORT_LRT_COMMON_LINKOPTS_LINUX = [ - "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. - "-Wl,--gc-sections", # Eliminate unused code and data. - "-Wl,--as-needed", # Don't link unused libs.a -] - -def export_lrt_runtime_only_script(): - return select({ - "//tensorflow:linux_x86_64": [_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_LINUX], - "//tensorflow:android": [_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_LINUX], - "//tensorflow:macos": [_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_DARWIN], - "//tensorflow:ios": [_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_DARWIN], - "//conditions:default": [], - }) - -def export_lrt_runtime_only_linkopt(): - return select({ - "//tensorflow:linux_x86_64": _EXPORT_LRT_COMMON_LINKOPTS_LINUX + [_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_LINUX], - "//tensorflow:android": _EXPORT_LRT_COMMON_LINKOPTS_LINUX + [ - "-Wl,-z,max-page-size=16384", - _EXPORT_LRT_RUNTIME_ONLY_LINKOPT_LINUX, - ], - "//tensorflow:macos": [_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_DARWIN], - "//tensorflow:ios": [_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_DARWIN], - "//conditions:default": [], - }) + symbol_opts() - -#################################################################################################### -# Macros - -# Private - -def _litert_base( - rule, - ungrte = False, - **cc_rule_kwargs): - """ - Base rule for LiteRT targets. - - Args: - rule: The underlying rule to use (e.g., cc_test, cc_library). - ungrte: Whether to link against system libraries ("ungrte"). - **cc_rule_kwargs: Keyword arguments to pass to the underlying rule. - """ - if ungrte: - append_rule_kwargs( - cc_rule_kwargs, - linkopts = select({ - "//tensorflow:linux_x86_64": [_SYS_ELF_INTERPRETER_LINKOPT_X86_64, _SYS_RPATHS_LINKOPT_X86_64], - "//conditions:default": [], - }), - ) - rule(**cc_rule_kwargs) - -# Public - -def litert_test( - ungrte = False, - use_sys_malloc = False, - **cc_test_kwargs): - """ - LiteRT test rule. - - Args: - ungrte: Whether to link against system libraries ("ungrte"). - use_sys_malloc: Whether to use the system malloc. - **cc_test_kwargs: Keyword arguments to pass to the underlying rule. - """ - if use_sys_malloc: - # copybara:uncomment cc_test_kwargs["malloc"] = "//base:system_malloc" - pass - - append_rule_kwargs( - cc_test_kwargs, - deps = ["@com_google_googletest//:gtest_main"], - ) - - _litert_base( - native.cc_test, - ungrte, - **cc_test_kwargs - ) - -def litert_lib( - ungrte = False, - **cc_lib_kwargs): - """ - LiteRT library rule. - - Args: - ungrte: Whether to link against system libraries ("ungrte"). 
- **cc_lib_kwargs: Keyword arguments to pass to the underlying rule. - """ - _litert_base( - native.cc_library, - ungrte, - **cc_lib_kwargs - ) - -def litert_bin( - ungrte = False, - export_litert_only = False, - **cc_bin_kwargs): - """ - LiteRT binary rule. - - Args: - ungrte: Whether to link against system libraries ("ungrte"). - export_litert_only: Whether to export only LiteRT symbols. - **cc_bin_kwargs: Keyword arguments to pass to the underlying rule. - """ - if export_litert_only: - append_rule_kwargs( - cc_bin_kwargs, - linkopts = export_lrt_only_linkopt(), - deps = export_lrt_only_script(), - ) - - _litert_base( - native.cc_binary, - ungrte, - **cc_bin_kwargs - ) - -def litert_dynamic_lib( - name, - shared_lib_name, - so_name, - export_litert_only = False, - ungrte = False, - **cc_lib_kwargs): - """ - LiteRT dynamic library rule. - - Args: - name: The name of the library. - shared_lib_name: The name of the shared library. - so_name: The name of the shared object file. - export_litert_only: Whether to export only LiteRT symbols. - ungrte: Whether to link against system libraries ("ungrte"). - **cc_lib_kwargs: Keyword arguments to pass to the underlying rule. - """ - if not _valid_shared_lib_name(shared_lib_name): - fail("\"shared_lib_name\" must end with \"_so\"") - if not _valid_so_name(so_name): - fail("\"so_name\" must be \"libLiteRt*.so\"") - - lib_name = name - cc_lib_kwargs["name"] = lib_name - - lib_target_ref = _make_target_ref(lib_name) - - vis = cc_lib_kwargs.get("visibility", None) - - # Share tags for all targets. - tags = cc_lib_kwargs.get("tags", []) - - litert_lib( - ungrte = ungrte, - **cc_lib_kwargs - ) - - user_link_flags = [] - additional_linker_inputs = [] - if export_litert_only: - user_link_flags = export_lrt_only_linkopt() - additional_linker_inputs = export_lrt_only_script() - - native.cc_shared_library( - name = shared_lib_name, - shared_lib_name = so_name, - user_link_flags = user_link_flags, - additional_linker_inputs = additional_linker_inputs, - tags = tags, - visibility = vis, - deps = [lib_target_ref], - ) - -def copy_file(name, src, target, visibility = None): - input_path = "$(location %s)" % src - output_path = "$(@D)/" + target - - native.genrule( - name = name, - srcs = [src], - outs = [target], - visibility = visibility, - cmd = "cp %s %s" % (input_path, output_path), - ) - -def gtest_main_no_heapcheck_deps(): - # copybara:uncomment_begin(google-only) - # return ["//testing/base/public:gunit_main_no_heapcheck"] - # copybara:uncomment_end - # copybara:comment_begin(oss-only) - return ["@com_google_googletest//:gtest_main"] - # copybara:comment_end diff --git a/tensorflow/lite/experimental/litert/build_common/special_rule.bzl b/tensorflow/lite/experimental/litert/build_common/special_rule.bzl deleted file mode 100644 index e6a3c1c47fcf1e..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/special_rule.bzl +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
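For context on the macros deleted above, here is a minimal Starlark sketch of how a dependent BUILD file might have used `litert_dynamic_lib` and `litert_test` from `litert_build_defs.bzl`. All target and file names here are hypothetical, not taken from the tree:

    load(
        "//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl",
        "litert_dynamic_lib",
        "litert_test",
    )

    # Hypothetical library. Per the validity checks above, `shared_lib_name`
    # must end in "_so" and `so_name` must match "libLiteRt*.so", or the
    # macro fails.
    litert_dynamic_lib(
        name = "my_backend",
        srcs = ["my_backend.cc"],
        hdrs = ["my_backend.h"],
        shared_lib_name = "my_backend_so",
        so_name = "libLiteRtMyBackend.so",
        export_litert_only = True,  # Hide all non-LiteRt symbols.
    )

    # Hypothetical test; litert_test appends gtest_main automatically.
    litert_test(
        name = "my_backend_test",
        srcs = ["my_backend_test.cc"],
        deps = [":my_backend"],
    )

The macro also emits a `cc_shared_library` named `my_backend_so` producing `libLiteRtMyBackend.so`, wired to the symbol-hiding link options when `export_litert_only` is set.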
- -"""External versions of LiteRT build rules that differ outside of Google.""" - -def lite_rt_friends(): - """Internal visibility for packages outside of LiteRT code location. - - Return the package group declaration for internal code locations that need - visibility to LiteRT APIs""" - - return [] - -def gles_deps(): - """This is a no-op outside of Google.""" - return [] - -def gles_headers(): - """This is a no-op outside of Google.""" - return [] - -def gles_linkopts(): - """This is a no-op outside of Google.""" - return [] diff --git a/tensorflow/lite/experimental/litert/build_common/tfl_model_gen.bzl b/tensorflow/lite/experimental/litert/build_common/tfl_model_gen.bzl deleted file mode 100644 index 654c8684cf5754..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/tfl_model_gen.bzl +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility to generate tflite models from MLIR files.""" - -def tfl_model_gen(name, srcs, subdir = "testdata"): - """ - Generates tflite models from MLIR files. - - Args: - name: name of the rule. - srcs: list of MLIR files. - subdir: subdirectory to place the generated tflite files. - """ - OUT_DIR = "$(RULEDIR)" - CONVERTER = "//tensorflow/compiler/mlir/lite:tf_tfl_translate" - CMD = """ - for mlir_file in $(SRCS); do - $(location {converter}) --input-mlir $$mlir_file --o={out_dir}/{subdir}/$$(basename $$mlir_file .mlir).tflite - done - """.format( - converter = CONVERTER, - out_dir = OUT_DIR, - subdir = subdir, - ) - - native.genrule( - name = name, - srcs = srcs, - outs = [s.removesuffix(".mlir") + ".tflite" for s in srcs], - cmd = CMD, - tools = [CONVERTER], - ) - - native.filegroup( - name = name + "_files", - srcs = [name], - ) diff --git a/tensorflow/lite/experimental/litert/c/BUILD b/tensorflow/lite/experimental/litert/c/BUILD deleted file mode 100644 index 6ea15aa4478430..00000000000000 --- a/tensorflow/lite/experimental/litert/c/BUILD +++ /dev/null @@ -1,589 +0,0 @@ -# copybara:uncomment_begin(google-only) -# load("//devtools/deps/check:deps_check.bzl", "check_dependencies") -# -# copybara:uncomment_end -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "copy_file", "export_lrt_runtime_only_linkopt", "export_lrt_runtime_only_script") -load("//tensorflow/lite/experimental/litert/build_common:special_rule.bzl", "gles_deps", "gles_headers", "gles_linkopts") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], -) - -cc_library( - name = "litert_common", - srcs = ["litert_common.cc"], - hdrs = ["litert_common.h"], -) - -cc_test( - name = "litert_common_test", - srcs = ["litert_common_test.cc"], - deps = [ - ":litert_common", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_any", - hdrs = ["litert_any.h"], -) - -cc_library( - name = "litert_environment", - srcs = ["litert_environment.cc"], - hdrs = ["litert_environment.h"], - deps 
= [ - ":litert_common", - ":litert_environment_options", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime:gpu_environment", - "//tensorflow/lite/experimental/litert/runtime/accelerators:auto_registration", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "litert_environment_options", - srcs = ["litert_environment_options.cc"], - hdrs = ["litert_environment_options.h"], - deps = [ - ":litert_any", - ":litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:environment_options", - ], -) - -cc_library( - name = "litert_environment_options_header", - hdrs = ["litert_environment_options.h"], - tags = ["avoid_dep"], - deps = [ - ":litert_any", - ":litert_common", - ], -) - -cc_test( - name = "litert_environment_options_test", - srcs = ["litert_environment_options_test.cc"], - deps = [ - ":litert_any", - ":litert_common", - ":litert_environment_options", - "//tensorflow/lite/experimental/litert/core:environment_options", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_logging", - srcs = [ - "litert_logging.cc", - ], - hdrs = [ - "litert_logging.h", - ], - deps = [ - ":litert_common", - "//tensorflow/lite:minimal_logging", - ], -) - -cc_test( - name = "litert_logging_test", - srcs = [ - "litert_logging_test.cc", - ], - deps = [ - ":litert_common", - ":litert_logging", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_layout", - hdrs = ["litert_layout.h"], - deps = [ - ":litert_common", - ":litert_op_code", - "//tensorflow/lite/core/c:c_api_types", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "litert_model", - srcs = ["litert_model.cc"], - hdrs = ["litert_model.h"], - deps = [ - ":litert_common", - ":litert_layout", - ":litert_op_code", - "//tensorflow/lite/core/c:c_api_types", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:model_load", - "//tensorflow/lite/experimental/litert/core/model:model_serialize", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "litert_model_test", - srcs = ["litert_model_test.cc"], - deps = [ - ":litert_common", - ":litert_model", - ":litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_op_code", - hdrs = ["litert_op_code.h"], - deps = ["//tensorflow/lite:builtin_ops"], -) - -cc_library( - name = "litert_options", - srcs = ["litert_options.cc"], - hdrs = [ - "litert_options.h", - ], - deps = [ - ":litert_common", - ":litert_op_code", - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/experimental/litert/core/model", - ], -) - -cc_test( - name = "litert_options_test", - srcs = ["litert_options_test.cc"], - 
data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - ], - tags = ["no_oss"], - deps = [ - ":litert_common", - ":litert_options", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_event_type", - hdrs = ["litert_event_type.h"], -) - -cc_library( - name = "litert_event", - srcs = ["litert_event.cc"], - hdrs = ["litert_event.h"], - deps = [ - ":litert_common", - ":litert_event_type", - ":litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/runtime:event", - ], -) - -cc_library( - name = "litert_tensor_buffer_types", - srcs = [], - hdrs = ["litert_tensor_buffer_types.h"], -) - -cc_library( - name = "litert_gl_types", - srcs = [], - hdrs = ["litert_gl_types.h"], -) - -cc_library( - name = "litert_tensor_buffer", - srcs = [ - "litert_tensor_buffer.cc", - "litert_tensor_buffer_requirements.cc", - ], - hdrs = [ - "litert_tensor_buffer.h", - "litert_tensor_buffer_requirements.h", - ], - linkopts = select({ - "//tensorflow:android": [ - "-landroid", - ], - "//conditions:default": [], - }) + gles_linkopts(), - deps = [ - ":litert_common", - ":litert_event", - ":litert_gl_types", - ":litert_logging", - ":litert_model", - ":litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/runtime:tensor_buffer", - "@com_google_absl//absl/types:span", - "@opencl_headers", - ] + gles_deps(), -) - -cc_test( - name = "litert_tensor_buffer_test", - srcs = [ - "litert_tensor_buffer_test.cc", - ], - # require GPU to run OpenCL tests. - tags = [ - "requires-gpu-nvidia", - ], - deps = [ - ":litert_common", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - "//tensorflow/lite/experimental/litert/runtime:tensor_buffer", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "litert_tensor_buffer_requirements_test", - srcs = [ - "litert_tensor_buffer_requirements_test.cc", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - deps = [ - ":litert_common", - ":litert_tensor_buffer", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_dispatch_delegate", - hdrs = [ - "litert_dispatch_delegate.h", - ], - deps = [ - ":litert_environment_options", - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/delegates/utils:simple_opaque_delegate", - "//tensorflow/lite/experimental/litert/runtime/dispatch:dispatch_delegate", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - ], -) - -cc_library( - name = "litert_compilation_options", - srcs = ["litert_compilation_options.cc"], - hdrs = [ - "litert_compilation_options.h", - ], - deps = [ - ":litert_accelerator_compilation_options", - ":litert_common", - ":litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/runtime:compilation_options", - ], -) - -cc_test( - name = "litert_compilation_options_test", - srcs = ["litert_compilation_options_test.cc"], - deps = [ - ":litert_accelerator_compilation_options", - ":litert_common", - ":litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - 
"//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_compiled_model", - srcs = ["litert_compiled_model.cc"], - hdrs = [ - "litert_compiled_model.h", - ], - deps = [ - ":litert_common", - ":litert_compilation_options", - ":litert_environment", - ":litert_logging", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/experimental/litert/runtime:compiled_model", - ], -) - -cc_test( - name = "litert_compiled_model_test", - srcs = [ - "litert_compiled_model_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - deps = [ - ":litert_common", - ":litert_compilation_options", - ":litert_compiled_model", - ":litert_environment", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -# The same test as `litert_compiled_model_test` but using the shared library `libLiteRtRuntimeCApi.so`. -cc_test( - name = "litert_compiled_model_shared_lib_test", - srcs = [ - "litert_compiled_model_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - deps = [ - ":litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_accelerator", - srcs = ["litert_accelerator.cc"], - hdrs = ["litert_accelerator.h"], - deps = [ - ":litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime:accelerator", - ], -) - -cc_test( - name = "litert_accelerator_test", - srcs = ["litert_accelerator_test.cc"], - deps = [ - ":litert_accelerator", - ":litert_accelerator_registration", - ":litert_common", - ":litert_environment", - "//tensorflow/lite/experimental/litert/runtime:accelerator", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_accelerator_compilation_options", - srcs = ["litert_accelerator_compilation_options.cc"], - hdrs = ["litert_accelerator_compilation_options.h"], - deps = [ - ":litert_common", - ], -) - -cc_test( - name = "litert_accelerator_compilation_options_test", - srcs = ["litert_accelerator_compilation_options_test.cc"], - deps = [ - ":litert_accelerator_compilation_options", - ":litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:version", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_accelerator_registration", - srcs = ["litert_accelerator_registration.cc"], - hdrs = ["litert_accelerator_registration.h"], - deps = [ - ":litert_accelerator", - ":litert_accelerator_compilation_options", - ":litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - 
"//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime:accelerator", - ], -) - -cc_test( - name = "litert_accelerator_registration_test", - srcs = ["litert_accelerator_registration_test.cc"], - deps = [ - ":litert_accelerator", - ":litert_accelerator_compilation_options", - ":litert_accelerator_registration", - ":litert_common", - ":litert_environment", - "//tensorflow/lite/experimental/litert/runtime:accelerator", - "@com_google_googletest//:gtest_main", - ], -) - -filegroup( - name = "litert_model_srcs", - srcs = ["litert_model.cc"], - visibility = ["//tensorflow/lite/experimental/litert/core/model:__pkg__"], -) - -filegroup( - name = "litert_model_hdrs", - srcs = ["litert_model.h"], - visibility = ["//tensorflow/lite/experimental/litert/core/model:__pkg__"], -) - -# Collection of all C API targets. -LITERT_C_API_COMMON_DEPS = [ - ":litert_accelerator", - ":litert_accelerator_registration", - ":litert_any", - ":litert_common", - ":litert_compiled_model", - ":litert_compilation_options", - ":litert_dispatch_delegate", - ":litert_event", - ":litert_environment", - ":litert_layout", - ":litert_logging", - ":litert_model", - ":litert_op_code", - ":litert_options", - ":litert_tensor_buffer", -] - -# This test verifies that the C API header files can build via C compiler. -cc_test( - name = "litert_c_api_common_test", - srcs = ["litert_c_api_common_test.c"], - copts = ["--std=c11"], - linkopts = ["-ldl"], - deps = LITERT_C_API_COMMON_DEPS, -) - -# Build `litert/c:litert_runtime_c_api_so` for `libLiteRtRuntimeCApi.so`. -cc_shared_library( - name = "litert_runtime_c_api_so", - additional_linker_inputs = export_lrt_runtime_only_script(), - shared_lib_name = "libLiteRtRuntimeCApi.so", - user_link_flags = export_lrt_runtime_only_linkopt() + select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }) + ["-Wl,-soname=libLiteRtRuntimeCApi.so"], - deps = LITERT_C_API_COMMON_DEPS, -) - -cc_library( - name = "litert_dispatch_headers", - hdrs = [ - ":litert_environment.h", - ":litert_environment_options.h", - ":litert_accelerator.h", - ":litert_accelerator_compilation_options.h", - ":litert_any.h", - ":litert_common.h", - ":litert_compiled_model.h", - ":litert_compilation_options.h", - ":litert_event.h", - ":litert_event_type.h", - ":litert_layout.h", - ":litert_logging.h", - ":litert_model.h", - ":litert_op_code.h", - ":litert_options.h", - ":litert_tensor_buffer.h", - ":litert_tensor_buffer_requirements.h", - ":litert_tensor_buffer_types.h", - ":litert_gl_types.h", - # Needed for litert/c/litert_op_code.h - "//tensorflow/lite:builtin_ops.h", - # Neeeded for litert/c/litert_model.h - "//tensorflow/lite/c:tensorflowlite_c_api_hdrs_filegroup", - "//tensorflow/lite/core/c:headers_filegroup", - ], # Export all header files (.h) in this directory - deps = [ - "@opencl_headers", - ], -) - -copy_file( - name = "copy_litert_runtime_c_api_so", - src = "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_so", - target = "libLiteRtRuntimeCApi.so", -) - -# This is cc_library target based on `libLiteRtRuntimeCApi.so`. 
-cc_library( - name = "litert_runtime_c_api_shared_lib", - srcs = [":litert_runtime_c_api_so"], - hdrs = glob(["litert_*.h"]) + [ - # Needed for litert/c/litert_op_code.h - "//tensorflow/lite:builtin_ops.h", - # Needed for litert/c/litert_model.h - "//tensorflow/lite/c:tensorflowlite_c_api_hdrs_filegroup", - "//tensorflow/lite/core/c:headers_filegroup", - ], - linkstatic = 1, - deps = [ - # Only depend on headers - "@opencl_headers", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - ] + gles_headers(), -) - -# copybara:uncomment_begin(google-only) -# # Check that litert runtime doesn't depend on MLIR. -# check_dependencies( -# of = [":litert_runtime_c_api_shared_lib"], -# dont_match_regexp = "^//third_party/llvm/llvm-project/mlir", -# ) -# copybara:uncomment_end - -exports_files(srcs = glob(["litert_*.h"])) diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator.cc deleted file mode 100644 index 3e90de58967d07..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator.cc +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// Gets the number of accelerators registered to LiteRT. -LiteRtStatus LiteRtGetNumAccelerators(LiteRtEnvironment environment, - LiteRtParamIndex* num_accelerators) { - if (!environment || !num_accelerators) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_accelerators = environment->GetAcceleratorRegistry().size(); - return kLiteRtStatusOk; - } - -// Gets the accelerator at the given index that is registered to LiteRT.
-LiteRtStatus LiteRtGetAccelerator(LiteRtEnvironment environment, - LiteRtParamIndex index, - - LiteRtAccelerator* accelerator) { - if (!environment || !accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - litert::Expected registered_accelerator = - environment->GetAcceleratorRegistry().Get(index); - if (!registered_accelerator.HasValue()) { - return registered_accelerator.Error().Status(); - } - *accelerator = registered_accelerator.Value(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAcceleratorName(LiteRtAccelerator accelerator, - char const** name) { - if (!accelerator || !accelerator->GetName || !name) { - return kLiteRtStatusErrorInvalidArgument; - } - return accelerator->GetName(accelerator, name); -} - -LiteRtStatus LiteRtGetAcceleratorId(LiteRtAccelerator accelerator, - LiteRtAcceleratorId* id) { - if (!accelerator || !accelerator->env || !id) { - return kLiteRtStatusErrorInvalidArgument; - } - litert::Expected index = - accelerator->env->GetAcceleratorRegistry().FindAcceleratorIndex( - accelerator); - if (!index.HasValue()) { - return index.Error().Status(); - } - *id = index.Value(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAcceleratorVersion(LiteRtAccelerator accelerator, - LiteRtApiVersion* version) { - if (!accelerator || !accelerator->GetVersion || !version) { - return kLiteRtStatusErrorInvalidArgument; - } - return accelerator->GetVersion(accelerator, version); -} - -LiteRtStatus LiteRtGetAcceleratorHardwareSupport( - LiteRtAccelerator accelerator, LiteRtHwAcceleratorSet* supported_hardware) { - if (!accelerator || !accelerator->GetHardwareSupport || !supported_hardware) { - return kLiteRtStatusErrorInvalidArgument; - } - return accelerator->GetHardwareSupport(accelerator, supported_hardware); -} - -LiteRtStatus LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - LiteRtAccelerator accelerator, bool* does_jit_compilation) { - if (!accelerator || !does_jit_compilation) { - return kLiteRtStatusErrorInvalidArgument; - } - if (!accelerator->IsTfLiteDelegateResponsibleForJitCompilation) { - *does_jit_compilation = false; - return kLiteRtStatusOk; - } - return accelerator->IsTfLiteDelegateResponsibleForJitCompilation( - accelerator, does_jit_compilation); -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator.h b/tensorflow/lite/experimental/litert/c/litert_accelerator.h deleted file mode 100644 index ff3ec4bf14f9a6..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
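The `litert_accelerator.cc` implementation removed above follows one pattern throughout: null-check every handle and out-parameter, then forward to the environment's accelerator registry or to the accelerator's callback. A hedged usage sketch of these getters follows; `env` is assumed to be a valid `LiteRtEnvironment` created elsewhere, and the helper name is hypothetical:

    #include <stdio.h>

    // Hypothetical helper: lists the accelerators registered with `env`.
    // Sketch only; real code should also handle the error statuses.
    void ListAccelerators(LiteRtEnvironment env) {
      LiteRtParamIndex num_accelerators = 0;
      if (LiteRtGetNumAccelerators(env, &num_accelerators) != kLiteRtStatusOk) {
        return;
      }
      for (LiteRtParamIndex i = 0; i < num_accelerators; ++i) {
        LiteRtAccelerator accelerator = NULL;
        const char* name = NULL;
        if (LiteRtGetAccelerator(env, i, &accelerator) == kLiteRtStatusOk &&
            LiteRtGetAcceleratorName(accelerator, &name) == kLiteRtStatusOk) {
          // Per the header, `name` is not owned by the caller.
          printf("accelerator %d: %s\n", (int)i, name);
        }
      }
    }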
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_H_ - -#include <stdbool.h> -#include <stddef.h> -#include <stdint.h> - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" - -#ifdef __cplusplus -extern "C" { -#endif - -LITERT_DEFINE_HANDLE(LiteRtAccelerator); - -typedef size_t LiteRtAcceleratorId; - -// Gets the number of accelerators registered to LiteRT. -LiteRtStatus LiteRtGetNumAccelerators(LiteRtEnvironment environment, - LiteRtParamIndex* num_accelerators); - -// Gets the accelerator at the given index that is registered to LiteRT. -LiteRtStatus LiteRtGetAccelerator(LiteRtEnvironment environment, - LiteRtParamIndex index, - LiteRtAccelerator* accelerator); - -// Fetches the name of the accelerator. -// -// Note: client code does not need to manage the `name` lifetime. -LiteRtStatus LiteRtGetAcceleratorName(LiteRtAccelerator accelerator, - char const** name); - -// Fetches the accelerator identifier. -// -// The identifier is a runtime-unique number, provided by the registrar to the -// accelerator upon registration. -LiteRtStatus LiteRtGetAcceleratorId(LiteRtAccelerator accelerator, - LiteRtAcceleratorId* id); - -// Fetches the version of the accelerator implementation. -// -// Note: This is NOT the LiteRT version. It's the accelerator-specific software -// implementation version. -LiteRtStatus LiteRtGetAcceleratorVersion(LiteRtAccelerator accelerator, - LiteRtApiVersion* version); - -// Fetches the accelerator hardware support. -// -// `supported_hardware` is a bitfield of `LiteRtHwAccelerators` values. -LiteRtStatus LiteRtGetAcceleratorHardwareSupport( - LiteRtAccelerator accelerator, LiteRtHwAcceleratorSet* supported_hardware); - -// Returns whether the accelerator TFLite delegate does some JIT compilation. -LiteRtStatus LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - LiteRtAccelerator accelerator, bool* does_jit_compilation); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.cc deleted file mode 100644 index fb6e0016e69885..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.cc +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
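Building on the header above, a hedged sketch of a capability check. `kLiteRtHwAcceleratorCpu` appears in the registration test later in this patch; the helper name is hypothetical:

    // Hypothetical helper: returns true if `accelerator` reports CPU support.
    bool SupportsCpu(LiteRtAccelerator accelerator) {
      LiteRtHwAcceleratorSet supported = 0;
      if (LiteRtGetAcceleratorHardwareSupport(accelerator, &supported) !=
          kLiteRtStatusOk) {
        return false;
      }
      // `supported` is a bitfield of LiteRtHwAccelerators values.
      return (supported & kLiteRtHwAcceleratorCpu) != 0;
    }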
- -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -struct LiteRtAcceleratorCompilationOptionsT { - LiteRtApiVersion payload_version; - std::string payload_identifier; - std::unique_ptr payload_data; - LiteRtAcceleratorCompilationOptionsT* next = nullptr; - - LiteRtAcceleratorCompilationOptionsT(const LiteRtApiVersion& payload_version_, - std::string payload_identifier_, - void* payload_data_, - void (*payload_destructor_)(void*)) - : payload_version(payload_version_), - payload_identifier(std::move(payload_identifier_)), - payload_data(payload_data_, payload_destructor_) {} -}; - -LiteRtStatus LiteRtCreateAcceleratorCompilationOptions( - const LiteRtApiVersion* payload_version, const char* payload_identifier, - void* payload_data, void (*payload_destructor)(void*), - LiteRtAcceleratorCompilationOptions* options) { - if (!payload_version || !payload_identifier || !payload_data || - !payload_destructor || !options) { - return kLiteRtStatusErrorInvalidArgument; - } - *options = new LiteRtAcceleratorCompilationOptionsT( - *payload_version, std::string(payload_identifier), payload_data, - payload_destructor); - return kLiteRtStatusOk; -} - -void LiteRtDestroyAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions options) { - while (options) { - LiteRtAcceleratorCompilationOptions next = options->next; - delete options; - options = next; - } -} - -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsVersion( - LiteRtAcceleratorCompilationOptions options, - LiteRtApiVersion* payload_version) { - if (!options || !payload_version) { - return kLiteRtStatusErrorInvalidArgument; - } - *payload_version = options->payload_version; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsIdentifier( - LiteRtAcceleratorCompilationOptions options, - const char** payload_identifier) { - if (!options || !payload_identifier) { - return kLiteRtStatusErrorInvalidArgument; - } - *payload_identifier = options->payload_identifier.c_str(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsData( - LiteRtAcceleratorCompilationOptions options, void** payload_data) { - if (!options || !payload_data) { - return kLiteRtStatusErrorInvalidArgument; - } - *payload_data = options->payload_data.get(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtFindAcceleratorCompilationOptionsData( - LiteRtAcceleratorCompilationOptions options, const char* payload_identifier, - LiteRtApiVersion* payload_version, void** payload_data) { - if (!options || !payload_identifier || !payload_version || !payload_data) { - return kLiteRtStatusErrorInvalidArgument; - } - while (options) { - if (!strcmp(options->payload_identifier.c_str(), payload_identifier)) { - *payload_version = options->payload_version; - *payload_data = options->payload_data.get(); - return kLiteRtStatusOk; - } else { - options = options->next; - } - } - return kLiteRtStatusErrorNotFound; -} - -LiteRtStatus LiteRtGetNextAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options) { - if (!options || !*options) { - return kLiteRtStatusErrorInvalidArgument; - } - *options = (*options)->next; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtAppendAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options, - LiteRtAcceleratorCompilationOptions appended_options) { - if (!options || !appended_options) { - return 
kLiteRtStatusErrorInvalidArgument; - } - while (*options) { - options = &((*options)->next); - } - *options = appended_options; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtPopAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options) { - if (!options) { - return kLiteRtStatusErrorInvalidArgument; - } - LiteRtAcceleratorCompilationOptions* last = options; - while ((*last)->next) { - last = &(*last)->next; - } - if (*last) { - LiteRtDestroyAcceleratorCompilationOptions(*last); - *last = nullptr; - } - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h b/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h deleted file mode 100644 index cfeff3df8d6cf7..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// A linked list of versioned accelerator compilation options. List items -// include: -// -// - a unique payload identifier field (string), used to distinguish payloads of -// different types; -// -// - a payload field and associated payload destructor callback; -// -// - a payload version field, used by the consumer code to know the structure of -// the payload. -LITERT_DEFINE_HANDLE(LiteRtAcceleratorCompilationOptions); - -LiteRtStatus LiteRtCreateAcceleratorCompilationOptions( - const LiteRtApiVersion* payload_version, const char* payload_identifier, - void* payload_data, void (*payload_destructor)(void* payload_data), - LiteRtAcceleratorCompilationOptions* options); - -// Releases an entire options list starting from `options`. -// -// Warning: Once an `options` item has been appended to another `options` item, -// the user will no longer need to destroy the former `options` item manually -// with this function. -void LiteRtDestroyAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions options); - -// Gets the payload version field of the first item in the given `options` list. -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsVersion( - LiteRtAcceleratorCompilationOptions options, - LiteRtApiVersion* payload_version); - -// Gets the payload identifier field of the first item in the given `options` -// list. -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsIdentifier( - LiteRtAcceleratorCompilationOptions options, - const char** payload_identifier); - -// Gets the payload data field of the first item in the given `options` list.
-LiteRtStatus LiteRtGetAcceleratorCompilationOptionsData( - LiteRtAcceleratorCompilationOptions options, void** payload_data); - -// Gets the payload version and data for the `options` list item with a given -// payload identifier. Returns kLiteRtStatusErrorNotFound if no such item is -// found. -LiteRtStatus LiteRtFindAcceleratorCompilationOptionsData( - LiteRtAcceleratorCompilationOptions options, const char* payload_identifier, - LiteRtApiVersion* payload_version, void** payload_data); - -// Advances `options` to the next item in the options list, setting parameter -// `options` to null if there is no next item. -LiteRtStatus LiteRtGetNextAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options); - -// Appends `appended_options` to the list pointed to by `options` and takes -// ownership of the appended object. While parameter `options` must be -// non-null, `*options` may however be null, in which case this call is -// equivalent to `*options = appended_options`. -LiteRtStatus LiteRtAppendAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options, - LiteRtAcceleratorCompilationOptions appended_options); - -// Removes and deallocates the last option in the linked list pointed to by -// parameter `options`. -LiteRtStatus LiteRtPopAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options_test.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options_test.cc deleted file mode 100644 index f13195942893e1..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options_test.cc +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/version.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using testing::StrEq; -using testing::litert::IsError; - -struct DummyAccleratorCompilationOptions { - static constexpr const LiteRtApiVersion kVersion = {0, 1, 0}; - static constexpr const char* const kIdentifier = "dummy-accelerator"; - - int dummy_option = 3; - - // Allocates and sets the basic structure for the accelerator options.
- static litert::Expected<LiteRtAcceleratorCompilationOptions> CreateOptions() { - auto* payload = new DummyAccleratorCompilationOptions; - auto payload_destructor = [](void* payload) { - delete reinterpret_cast<DummyAccleratorCompilationOptions*>(payload); - }; - return CreateOptions(kVersion, kIdentifier, payload, payload_destructor); - } - - static litert::Expected<LiteRtAcceleratorCompilationOptions> CreateOptions( - LiteRtApiVersion version, const char* identifier, void* payload, - void (*payload_destructor)(void*)) { - LiteRtAcceleratorCompilationOptions options; - LITERT_RETURN_IF_ERROR(LiteRtCreateAcceleratorCompilationOptions( - &version, identifier, payload, payload_destructor, &options)); - return options; - } -}; - -class LiteRtAcceleratorOptionsTest : public testing::Test { - public: - void SetUp() override { - auto options = DummyAccleratorCompilationOptions::CreateOptions(); - EXPECT_TRUE(options); - options_ = *options; - } - - void TearDown() override { - LiteRtDestroyAcceleratorCompilationOptions(options_); - options_ = nullptr; - } - - LiteRtAcceleratorCompilationOptions options_ = nullptr; -}; - -TEST_F(LiteRtAcceleratorOptionsTest, CreateAndDestroyDoesntLeak) {} - -TEST_F(LiteRtAcceleratorOptionsTest, GetIdentifier) { - const char* identifier = nullptr; - LITERT_EXPECT_OK( - LiteRtGetAcceleratorCompilationOptionsIdentifier(options_, &identifier)); - EXPECT_THAT(identifier, - StrEq(DummyAccleratorCompilationOptions::kIdentifier)); - EXPECT_THAT( - LiteRtGetAcceleratorCompilationOptionsIdentifier(nullptr, &identifier), - IsError(kLiteRtStatusErrorInvalidArgument)); - EXPECT_EQ(LiteRtGetAcceleratorCompilationOptionsIdentifier(options_, nullptr), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorOptionsTest, GetVersion) { - LiteRtApiVersion version; - EXPECT_EQ(LiteRtGetAcceleratorCompilationOptionsVersion(options_, &version), - kLiteRtStatusOk); - EXPECT_TRUE(litert::internal::IsSameVersion( - version, DummyAccleratorCompilationOptions::kVersion)); - EXPECT_EQ(LiteRtGetAcceleratorCompilationOptionsVersion(nullptr, &version), - kLiteRtStatusErrorInvalidArgument); - EXPECT_EQ(LiteRtGetAcceleratorCompilationOptionsVersion(options_, nullptr), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorOptionsTest, CreatingAndDestroyingAListWorks) { - auto appended_options1 = DummyAccleratorCompilationOptions::CreateOptions(); - ASSERT_TRUE(appended_options1); - auto appended_options2 = DummyAccleratorCompilationOptions::CreateOptions(); - ASSERT_TRUE(appended_options2); - - EXPECT_EQ( - LiteRtAppendAcceleratorCompilationOptions(&options_, *appended_options1), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtAppendAcceleratorCompilationOptions(&options_, *appended_options2), - kLiteRtStatusOk); - - // Iterate through the list to check that the links have been correctly added. - - LiteRtAcceleratorCompilationOptions options_it = options_; - ASSERT_EQ(LiteRtGetNextAcceleratorCompilationOptions(&options_it), - kLiteRtStatusOk); - EXPECT_EQ(options_it, *appended_options1); - - ASSERT_EQ(LiteRtGetNextAcceleratorCompilationOptions(&options_it), - kLiteRtStatusOk); - EXPECT_EQ(options_it, *appended_options2); - - // The list is destroyed in the `TearDown()` function.
-} - -TEST_F(LiteRtAcceleratorOptionsTest, FindData) { - constexpr LiteRtApiVersion appended_options_version = {1, 2, 3}; - constexpr auto* appended_options_id = "appended_options_id"; - void* appended_options_data = reinterpret_cast<void*>(12345); - constexpr auto appended_options_destructor = [](void*) {}; - - auto appended_options = DummyAccleratorCompilationOptions::CreateOptions( - appended_options_version, appended_options_id, appended_options_data, - appended_options_destructor); - - EXPECT_EQ( - LiteRtAppendAcceleratorCompilationOptions(&options_, *appended_options), - kLiteRtStatusOk); - - LiteRtApiVersion payload_version; - void* payload_data; - EXPECT_EQ(LiteRtFindAcceleratorCompilationOptionsData( - options_, appended_options_id, &payload_version, &payload_data), - kLiteRtStatusOk); - - EXPECT_EQ(payload_version.major, appended_options_version.major); - EXPECT_EQ(payload_version.minor, appended_options_version.minor); - EXPECT_EQ(payload_version.patch, appended_options_version.patch); - EXPECT_EQ(payload_data, appended_options_data); - - // The list is destroyed in the `TearDown()` function. -} - -TEST_F(LiteRtAcceleratorOptionsTest, Pop) { - constexpr LiteRtApiVersion appended_options_version = {1, 2, 3}; - constexpr auto* appended_options_id = "appended_options_id"; - void* appended_options_data = reinterpret_cast<void*>(12345); - constexpr auto appended_options_destructor = [](void*) {}; - - auto appended_options = DummyAccleratorCompilationOptions::CreateOptions( - appended_options_version, appended_options_id, appended_options_data, - appended_options_destructor); - - EXPECT_EQ( - LiteRtAppendAcceleratorCompilationOptions(&options_, *appended_options), - kLiteRtStatusOk); - - LiteRtApiVersion payload_version; - void* payload_data; - EXPECT_EQ(LiteRtFindAcceleratorCompilationOptionsData( - options_, appended_options_id, &payload_version, &payload_data), - kLiteRtStatusOk); - - // After popping the last item, we shouldn't be able to find it any longer. - EXPECT_EQ(LiteRtPopAcceleratorCompilationOptions(&options_), kLiteRtStatusOk); - EXPECT_NE(LiteRtFindAcceleratorCompilationOptionsData( - options_, appended_options_id, &payload_version, &payload_data), - kLiteRtStatusOk); - - // The list is destroyed in the `TearDown()` function. -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.cc deleted file mode 100644 index 8404f7275b7fb5..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.cc +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
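The options API exercised above amounts to an intrusive singly linked list of versioned payloads. A hedged sketch of how a caller might have built and queried such a list; the payload struct, identifier string, and version value are hypothetical:

    #include <stdlib.h>

    typedef struct { int num_threads; } MyOptionsPayload;  // Hypothetical payload.
    static void DestroyMyOptionsPayload(void* p) { free(p); }

    void BuildAndQueryOptions(void) {
      const LiteRtApiVersion version = {1, 0, 0};
      MyOptionsPayload* payload =
          (MyOptionsPayload*)malloc(sizeof(MyOptionsPayload));
      payload->num_threads = 4;

      LiteRtAcceleratorCompilationOptions head = NULL;
      LiteRtAcceleratorCompilationOptions node = NULL;
      LiteRtCreateAcceleratorCompilationOptions(&version, "my-options", payload,
                                                DestroyMyOptionsPayload, &node);
      // `head` is null, so appending makes `node` the list head.
      LiteRtAppendAcceleratorCompilationOptions(&head, node);

      LiteRtApiVersion found_version;
      void* found_data = NULL;
      if (LiteRtFindAcceleratorCompilationOptionsData(
              head, "my-options", &found_version, &found_data) ==
          kLiteRtStatusOk) {
        // `found_data` aliases `payload`; the list still owns it.
      }
      // Destroys every node and invokes each payload destructor.
      LiteRtDestroyAcceleratorCompilationOptions(head);
    }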
- -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -LiteRtStatus LiteRtCreateAccelerator(LiteRtAccelerator* accelerator) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - *accelerator = - litert::internal::AcceleratorRegistry::CreateEmptyAccelerator().release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDestroyAccelerator(LiteRtAccelerator accelerator) { - litert::internal::AcceleratorRegistry::DestroyAccelerator(accelerator); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtRegisterAccelerator(LiteRtEnvironment environment, - LiteRtAccelerator accelerator, - void* data, void (*ReleaseData)(void*)) { - std::unique_ptr data_guard(data, ReleaseData); - litert::internal::AcceleratorRegistry::Ptr accelerator_guard(accelerator); - if (!accelerator_guard) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator_guard->env = environment; - litert::Expected registered_accelerator = - environment->GetAcceleratorRegistry().RegisterAccelerator( - std::move(accelerator_guard)); - if (!registered_accelerator.HasValue()) { - return registered_accelerator.Error().Status(); - } - registered_accelerator.Value()->data = data_guard.release(); - registered_accelerator.Value()->ReleaseData = ReleaseData; - return kLiteRtStatusOk; -} - -// Sets the function used to retrieve the accelerator name. -LiteRtStatus LiteRtSetAcceleratorGetName( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetName)(LiteRtAccelerator accelerator, const char** name)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->GetName = GetName; - return kLiteRtStatusOk; -} - -// Sets the function used to retrieve the accelerator version. -LiteRtStatus LiteRtSetAcceleratorGetVersion( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetVersion)(LiteRtAccelerator accelerator, - LiteRtApiVersion* version)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->GetVersion = GetVersion; - return kLiteRtStatusOk; -} - -// Sets the function used to retrieve the accelerator hardware support. 
-LiteRtStatus LiteRtSetAcceleratorGetHardwareSupport( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetHardwareSupport)( - LiteRtAccelerator accelerator, - LiteRtHwAcceleratorSet* supported_hardware)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->GetHardwareSupport = GetHardwareSupport; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtSetDelegateFunction( - LiteRtAccelerator accelerator, - LiteRtStatus (*CreateDelegate)(LiteRtAccelerator accelerator, - LiteRtAcceleratorCompilationOptions options, - void** delegate), - void (*DestroyDelegate)(void* delegate)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->CreateDelegate = CreateDelegate; - accelerator->DestroyDelegate = DestroyDelegate; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtSetIsAcceleratorDelegateResponsibleForJitCompilation( - LiteRtAccelerator accelerator, - LiteRtStatus (*IsTfLiteDelegateResponsibleForJitCompilation)( - LiteRtAcceleratorT* accelerator, bool* does_jit_compilation)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->IsTfLiteDelegateResponsibleForJitCompilation = - IsTfLiteDelegateResponsibleForJitCompilation; - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h b/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h deleted file mode 100644 index 19369d436daea0..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_REGISTRATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_REGISTRATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// Creates an empty accelerator handle. -LiteRtStatus LiteRtCreateAccelerator(LiteRtAccelerator* accelerator); - -// Destroys an accelerator handle. -// -// Warning: This SHOULD NOT BE CALLED after a call to -// `LiteRtRegisterAccelerator`. -LiteRtStatus LiteRtDestroyAccelerator(LiteRtAccelerator accelerator); - -// Sets the registration data AND clean-up function, then registers the -// accelerator with the LiteRT environment. -// -// - `data` and `ReleaseData` may be null. -// -// Note: After this function returns successfully, `data` is managed by the -// LiteRT environment. `ReleaseData` is called to release its memory. -// -// Warning: In case of failure, `accelerator` is released and `data` is released -// using `ReleaseData`. 
-LiteRtStatus LiteRtRegisterAccelerator(LiteRtEnvironment environment, - LiteRtAccelerator accelerator, - void* data, void (*ReleaseData)(void*)); - -// Sets the function used to retrieve the accelerator name. -LiteRtStatus LiteRtSetAcceleratorGetName( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetName)(LiteRtAccelerator accelerator, const char** name)); - -// Sets the function used to retrieve the accelerator implementation version. -// -// Note: This is NOT the LiteRT version. It's the accelerator-specific software -// implementation version. -LiteRtStatus LiteRtSetAcceleratorGetVersion( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetVersion)(LiteRtAccelerator accelerator, - LiteRtApiVersion* version)); - -// Sets the function used to retrieve the accelerator hardware support. -LiteRtStatus LiteRtSetAcceleratorGetHardwareSupport( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetHardwareSupport)( - LiteRtAccelerator accelerator, - LiteRtHwAcceleratorSet* supported_hardware)); - -// Sets the function used to create the delegate that applies the accelerator -// for the compiled model, along with its destructor. The returned delegate -// object is owned by the compiled model. `void**` is used for the delegate -// instead of `TfLiteOpaqueDelegate**` to avoid a TFLite dependency. -LiteRtStatus LiteRtSetDelegateFunction( - LiteRtAccelerator accelerator, - LiteRtStatus (*CreateDelegate)(LiteRtAccelerator accelerator, - LiteRtAcceleratorCompilationOptions options, - void** delegate), - void (*DestroyDelegate)(void* delegate)); - -// Sets the function used to surface whether the delegate created by the -// accelerator does JIT compilation or not. -// -// This affects whether the compiled model creation will apply the accelerator -// without an explicit request in the JIT compilation options. -// -// If this isn't set, the result will be treated as `false`. -LiteRtStatus LiteRtSetIsAcceleratorDelegateResponsibleForJitCompilation( - LiteRtAccelerator accelerator, - LiteRtStatus (*IsTfLiteDelegateResponsibleForJitCompilation)( - LiteRtAccelerator accelerator, bool* does_jit_compilation)); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_REGISTRATION_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration_test.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_registration_test.cc deleted file mode 100644 index df7a051949447f..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration_test.cc +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
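A hedged sketch of the registration flow the header above supported, mirroring the Dummy* callback pattern in the test below; the callback, helper name, and `env` are illustrative assumptions:

    // Hypothetical name callback, matching the LiteRtSetAcceleratorGetName
    // signature.
    static LiteRtStatus MyGetName(LiteRtAccelerator accelerator,
                                  const char** name) {
      *name = "my-accelerator";
      return kLiteRtStatusOk;
    }

    void RegisterMyAccelerator(LiteRtEnvironment env) {
      LiteRtAccelerator accelerator = NULL;
      if (LiteRtCreateAccelerator(&accelerator) != kLiteRtStatusOk) return;
      LiteRtSetAcceleratorGetName(accelerator, MyGetName);
      // On success the environment takes ownership, so LiteRtDestroyAccelerator
      // must not be called afterwards; on failure the registration call itself
      // releases `accelerator`. Per the header, `data` and `ReleaseData` may
      // be null.
      LiteRtRegisterAccelerator(env, accelerator, /*data=*/NULL,
                                /*ReleaseData=*/NULL);
    }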
-
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h"
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/runtime/accelerator.h"
-
-namespace {
-
-class DummyAccelerator {
- public:
-  static std::unique_ptr<DummyAccelerator> CpuAccelerator() {
-    auto accelerator = std::make_unique<DummyAccelerator>();
-    accelerator->hardware_support_ = kLiteRtHwAcceleratorCpu;
-    return accelerator;
-  }
-
-  static void Destroy(void* dummy_accelerator) {
-    DummyAccelerator* instance =
-        reinterpret_cast<DummyAccelerator*>(dummy_accelerator);
-    delete instance;
-  }
-
-  static LiteRtStatus GetName(LiteRtAccelerator accelerator,
-                              const char** name) {
-    return kLiteRtStatusOk;
-  }
-
-  static LiteRtStatus GetVersion(LiteRtAccelerator accelerator,
-                                 LiteRtApiVersion* version) {
-    return kLiteRtStatusOk;
-  }
-
-  static LiteRtStatus GetHardwareSupport(
-      LiteRtAccelerator accelerator,
-      LiteRtHwAcceleratorSet* supported_hardware) {
-    return kLiteRtStatusOk;
-  }
-
-  static LiteRtStatus CreateDelegate(
-      LiteRtAccelerator accelerator,
-      LiteRtAcceleratorCompilationOptions options, void** delegate) {
-    return kLiteRtStatusOk;
-  }
-
-  static void DestroyDelegate(void* delegate) {}
-
-  LiteRtHwAccelerators hardware_support_;
-};
-
-TEST(LiteRtAcceleratorRegistrationTest, SetAcceleratorGetNameWorks) {
-  LiteRtAcceleratorT accelerator;
-  EXPECT_EQ(LiteRtSetAcceleratorGetName(nullptr, DummyAccelerator::GetName),
-            kLiteRtStatusErrorInvalidArgument);
-  LiteRtSetAcceleratorGetName(&accelerator, DummyAccelerator::GetName);
-  EXPECT_EQ(accelerator.GetName, DummyAccelerator::GetName);
-}
-
-TEST(LiteRtAcceleratorRegistrationTest, SetAcceleratorGetVersionWorks) {
-  LiteRtAcceleratorT accelerator;
-  EXPECT_EQ(
-      LiteRtSetAcceleratorGetVersion(nullptr, DummyAccelerator::GetVersion),
-      kLiteRtStatusErrorInvalidArgument);
-  LiteRtSetAcceleratorGetVersion(&accelerator, DummyAccelerator::GetVersion);
-  EXPECT_EQ(accelerator.GetVersion, DummyAccelerator::GetVersion);
-}
-
-TEST(LiteRtAcceleratorRegistrationTest, SetAcceleratorGetHardwareSupportWorks) {
-  LiteRtAcceleratorT accelerator;
-  EXPECT_EQ(LiteRtSetAcceleratorGetHardwareSupport(
-                nullptr, DummyAccelerator::GetHardwareSupport),
-            kLiteRtStatusErrorInvalidArgument);
-  LiteRtSetAcceleratorGetHardwareSupport(&accelerator,
-                                         DummyAccelerator::GetHardwareSupport);
-  EXPECT_EQ(accelerator.GetHardwareSupport,
-            DummyAccelerator::GetHardwareSupport);
-}
-
-TEST(LiteRtAcceleratorRegistrationTest, SetDelegateFunctionsWorks) {
-  LiteRtAcceleratorT accelerator;
-  EXPECT_EQ(LiteRtSetDelegateFunction(nullptr, DummyAccelerator::CreateDelegate,
-                                      DummyAccelerator::DestroyDelegate),
-            kLiteRtStatusErrorInvalidArgument);
-  LiteRtSetDelegateFunction(&accelerator, DummyAccelerator::CreateDelegate,
-                            DummyAccelerator::DestroyDelegate);
-  EXPECT_EQ(accelerator.CreateDelegate, DummyAccelerator::CreateDelegate);
-  EXPECT_EQ(accelerator.DestroyDelegate, DummyAccelerator::DestroyDelegate);
-}
-
-TEST(LiteRtAcceleratorRegistrationTest, CreateDestroyAcceleratorDoesntLeak) {
-  LiteRtAccelerator accelerator;
-  ASSERT_EQ(LiteRtCreateAccelerator(&accelerator), kLiteRtStatusOk);
-  ASSERT_EQ(LiteRtDestroyAccelerator(accelerator), kLiteRtStatusOk);
-}
-
-TEST(LiteRtAcceleratorRegistrationTest, RegisterAcceleratorWorks) {
-  LiteRtEnvironment env_;
-  LiteRtEnvironmentCreate(/*num_options=*/0, /*options=*/nullptr, &env_);
-  auto dummy_accelerator = DummyAccelerator::CpuAccelerator();
-  LiteRtAccelerator accelerator;
-  LiteRtCreateAccelerator(&accelerator);
-  LiteRtSetAcceleratorGetName(accelerator, DummyAccelerator::GetName);
-  LiteRtSetAcceleratorGetVersion(accelerator, DummyAccelerator::GetVersion);
-  LiteRtSetAcceleratorGetHardwareSupport(accelerator,
-                                         DummyAccelerator::GetHardwareSupport);
-  LiteRtRegisterAccelerator(env_, accelerator, dummy_accelerator.release(),
-                            DummyAccelerator::Destroy);
-  LiteRtDestroyEnvironment(env_);
-}
-
-TEST(LiteRtAcceleratorRegistrationTest,
-     RegisterAcceleratorFailsForNullAccelerator) {
-  LiteRtEnvironment env;
-  LiteRtEnvironmentCreate(/*num_options=*/0, /*options=*/nullptr, &env);
-  // We check that the memory is correctly deallocated if the registration
-  // fails.
-  auto dummy_accelerator = DummyAccelerator::CpuAccelerator();
-  EXPECT_EQ(LiteRtRegisterAccelerator(env, nullptr, dummy_accelerator.release(),
-                                      DummyAccelerator::Destroy),
-            kLiteRtStatusErrorInvalidArgument);
-  LiteRtDestroyEnvironment(env);
-}
-
-}  // namespace
diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_test.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_test.cc
deleted file mode 100644
index 0ef878b5cc3198..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_accelerator_test.cc
+++ /dev/null
@@ -1,287 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h"
-
-#include <memory>
-#include <string>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/runtime/accelerator.h"
-
-#define LITERT_ENSURE_OK(expr)       \
-  do {                               \
-    LiteRtStatus status = (expr);    \
-    if (status != kLiteRtStatusOk) { \
-      return status;                 \
-    }                                \
-  } while (0)
-
-namespace {
-using testing::Eq;
-using testing::Ne;
-using testing::NotNull;
-using testing::StrEq;
-
-class DummyAccelerator {
- public:
-  // `hardware_support` is a bitfield of `LiteRtHwAccelerators` values.
-  static LiteRtStatus RegisterAccelerator(int hardware_support,
-                                          LiteRtEnvironment env) {
-    auto dummy_accelerator = std::make_unique<DummyAccelerator>();
-    dummy_accelerator->hardware_support_ = hardware_support;
-    LiteRtAccelerator accelerator;
-    LiteRtCreateAccelerator(&accelerator);
-    LITERT_ENSURE_OK(
-        LiteRtSetAcceleratorGetName(accelerator, DummyAccelerator::GetName));
-    LITERT_ENSURE_OK(LiteRtSetAcceleratorGetVersion(
-        accelerator, DummyAccelerator::GetVersion));
-    LITERT_ENSURE_OK(LiteRtSetAcceleratorGetHardwareSupport(
-        accelerator, DummyAccelerator::GetHardwareSupport));
-    LITERT_ENSURE_OK(LiteRtRegisterAccelerator(env, accelerator,
-                                               dummy_accelerator.release(),
-                                               DummyAccelerator::Destroy));
-    return kLiteRtStatusOk;
-  }
-
-  static void Destroy(void* dummy_accelerator) {
-    DummyAccelerator* instance =
-        reinterpret_cast<DummyAccelerator*>(dummy_accelerator);
-    delete instance;
-  }
-
-  static LiteRtStatus GetName(LiteRtAccelerator accelerator,
-                              const char** name) {
-    if (!accelerator || !accelerator->data || !name) {
-      return kLiteRtStatusErrorInvalidArgument;
-    }
-    DummyAccelerator& self =
-        *reinterpret_cast<DummyAccelerator*>(accelerator->data);
-    if (self.name_.empty()) {
-      self.name_.append("Dummy");
-      if (self.hardware_support_ & kLiteRtHwAcceleratorCpu) {
-        self.name_.append("Cpu");
-      }
-      if (self.hardware_support_ & kLiteRtHwAcceleratorGpu) {
-        self.name_.append("Gpu");
-      }
-      self.name_.append("Accelerator");
-    }
-    *name = self.name_.c_str();
-    return kLiteRtStatusOk;
-  }
-
-  static LiteRtStatus GetVersion(LiteRtAccelerator accelerator,
-                                 LiteRtApiVersion* version) {
-    if (!version) {
-      return kLiteRtStatusErrorInvalidArgument;
-    }
-    version->major = 1;
-    version->minor = 2;
-    version->patch = 3;
-    return kLiteRtStatusOk;
-  }
-
-  static LiteRtStatus GetHardwareSupport(
-      LiteRtAccelerator accelerator,
-      LiteRtHwAcceleratorSet* supported_hardware) {
-    if (!accelerator || !accelerator->data || !supported_hardware) {
-      return kLiteRtStatusErrorInvalidArgument;
-    }
-
-    const DummyAccelerator& self =
-        *reinterpret_cast<const DummyAccelerator*>(accelerator->data);
-    *supported_hardware = self.hardware_support_;
-    return kLiteRtStatusOk;
-  }
-
-  int hardware_support_;
-  std::string name_;
-};
-
-class LiteRtAcceleratorTest : public testing::Test {
- public:
-  LiteRtEnvironment env_;
-  void SetUp() override {
-    LiteRtEnvironmentCreate(/*num_options=*/0, nullptr, &env_);
-    DummyAccelerator::RegisterAccelerator(kLiteRtHwAcceleratorCpu, env_);
-  }
-
-  void TearDown() override { LiteRtDestroyEnvironment(env_); }
-};
-
-TEST_F(LiteRtAcceleratorTest, IteratingOverAcceleratorsWorks) {
-  // CPU accelerator is registered in the SetUp function.
- DummyAccelerator::RegisterAccelerator(kLiteRtHwAcceleratorGpu, env_); - - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 2); - - EXPECT_THAT(LiteRtGetAccelerator(env_, 0, nullptr), - kLiteRtStatusErrorInvalidArgument); - LiteRtAccelerator accelerator0; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator0), kLiteRtStatusOk); - EXPECT_THAT(accelerator0, NotNull()); - - EXPECT_THAT(LiteRtGetAccelerator(env_, 1, nullptr), - kLiteRtStatusErrorInvalidArgument); - LiteRtAccelerator accelerator1; - ASSERT_THAT(LiteRtGetAccelerator(env_, 1, &accelerator1), kLiteRtStatusOk); - EXPECT_THAT(accelerator1, NotNull()); - - EXPECT_THAT(accelerator0, Ne(accelerator1)); - - LiteRtAccelerator accelerator2; - EXPECT_THAT(LiteRtGetAccelerator(env_, 2, &accelerator2), - kLiteRtStatusErrorNotFound); -} - -TEST_F(LiteRtAcceleratorTest, GetAcceleratorNameWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - const char* name = nullptr; - ASSERT_THAT(LiteRtGetAcceleratorName(accelerator, &name), kLiteRtStatusOk); - EXPECT_THAT(name, StrEq("DummyCpuAccelerator")); - - EXPECT_THAT(LiteRtGetAcceleratorName(nullptr, &name), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtGetAcceleratorName(accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - // Make the accelerator invalid. - accelerator->GetName = nullptr; - EXPECT_THAT(LiteRtGetAcceleratorName(accelerator, &name), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorTest, GetAcceleratorIdWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - LiteRtAcceleratorId accelerator_id; - ASSERT_THAT(LiteRtGetAcceleratorId(accelerator, &accelerator_id), - kLiteRtStatusOk); - EXPECT_THAT(accelerator_id, Eq(0)); - - EXPECT_THAT(LiteRtGetAcceleratorId(nullptr, &accelerator_id), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtGetAcceleratorId(accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - // Make the accelerator invalid. - accelerator->env = nullptr; - EXPECT_THAT(LiteRtGetAcceleratorId(accelerator, &accelerator_id), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorTest, GetAcceleratorVersionWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - LiteRtApiVersion version; - ASSERT_THAT(LiteRtGetAcceleratorVersion(accelerator, &version), - kLiteRtStatusOk); - EXPECT_THAT(version.major, Eq(1)); - EXPECT_THAT(version.minor, Eq(2)); - EXPECT_THAT(version.patch, Eq(3)); - - EXPECT_THAT(LiteRtGetAcceleratorVersion(nullptr, &version), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtGetAcceleratorVersion(accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - // Make the accelerator invalid. 
- accelerator->GetVersion = nullptr; - EXPECT_THAT(LiteRtGetAcceleratorVersion(accelerator, &version), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorTest, GetAcceleratorHardwareSupportWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - int hardware_support; - ASSERT_THAT( - LiteRtGetAcceleratorHardwareSupport(accelerator, &hardware_support), - kLiteRtStatusOk); - EXPECT_THAT(hardware_support & kLiteRtHwAcceleratorCpu, true); - EXPECT_THAT(hardware_support & kLiteRtHwAcceleratorGpu, false); - EXPECT_THAT(hardware_support & kLiteRtHwAcceleratorNpu, false); - - EXPECT_THAT(LiteRtGetAcceleratorHardwareSupport(nullptr, &hardware_support), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtGetAcceleratorHardwareSupport(accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - // Make the accelerator invalid. - accelerator->GetHardwareSupport = nullptr; - EXPECT_THAT( - LiteRtGetAcceleratorHardwareSupport(accelerator, &hardware_support), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorTest, - IsAcceleratorDelegateResponsibleForJitCompilationWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - bool does_jit_compilation; - ASSERT_THAT(LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - accelerator, &does_jit_compilation), - kLiteRtStatusOk); - EXPECT_THAT(does_jit_compilation, false); - - EXPECT_THAT(LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - nullptr, &does_jit_compilation), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - - // Add an implementation to the function. - accelerator->IsTfLiteDelegateResponsibleForJitCompilation = - [](LiteRtAccelerator, bool* does_jit) { - *does_jit = true; - return kLiteRtStatusOk; - }; - EXPECT_THAT(LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - accelerator, &does_jit_compilation), - kLiteRtStatusOk); - EXPECT_THAT(does_jit_compilation, true); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_any.h b/tensorflow/lite/experimental/litert/c/litert_any.h deleted file mode 100644 index e8e67b0c80f239..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_any.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
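The tests above pin down the introspection contract: registered accelerators are looked up by index from the environment, and each query goes through the env-owned handle. A compact sketch of that lookup path under the same signatures the tests exercise (error handling trimmed; the printed output is only illustrative):

#include <cstdio>

void ListAccelerators(LiteRtEnvironment env) {
  LiteRtParamIndex num_accelerators = 0;
  if (LiteRtGetNumAccelerators(env, &num_accelerators) != kLiteRtStatusOk) {
    return;
  }
  for (LiteRtParamIndex i = 0; i < num_accelerators; ++i) {
    LiteRtAccelerator accelerator = nullptr;
    const char* name = nullptr;
    int hardware_support = 0;
    if (LiteRtGetAccelerator(env, i, &accelerator) == kLiteRtStatusOk &&
        LiteRtGetAcceleratorName(accelerator, &name) == kLiteRtStatusOk &&
        LiteRtGetAcceleratorHardwareSupport(accelerator, &hardware_support) ==
            kLiteRtStatusOk) {
      // `hardware_support` is a LiteRtHwAccelerators bitfield.
      std::printf("%s (NPU: %s)\n", name,
                  (hardware_support & kLiteRtHwAcceleratorNpu) ? "yes" : "no");
    }
  }
}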
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_
-
-#include <stdbool.h>  // NOLINT: To use bool type in C
-#include <stdint.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-typedef enum {
-  kLiteRtAnyTypeNone = 0,
-  kLiteRtAnyTypeBool = 1,
-  kLiteRtAnyTypeInt = 2,
-  kLiteRtAnyTypeReal = 3,
-  kLiteRtAnyTypeString = 8,
-  kLiteRtAnyTypeVoidPtr = 9,
-} LiteRtAnyType;
-
-inline const char* LiteRtAnyTypeToString(LiteRtAnyType type) {
-  switch (type) {
-    case kLiteRtAnyTypeNone:
-      return "kLiteRtAnyTypeNone";
-    case kLiteRtAnyTypeBool:
-      return "kLiteRtAnyTypeBool";
-    case kLiteRtAnyTypeInt:
-      return "kLiteRtAnyTypeInt";
-    case kLiteRtAnyTypeReal:
-      return "kLiteRtAnyTypeReal";
-    case kLiteRtAnyTypeString:
-      return "kLiteRtAnyTypeString";
-    case kLiteRtAnyTypeVoidPtr:
-      return "kLiteRtAnyTypeVoidPtr";
-  }
-  return "Unknown";
-}
-
-typedef struct {
-  LiteRtAnyType type;
-  union {
-    bool bool_value;
-    int64_t int_value;
-    double real_value;
-    const char* str_value;
-    const void* ptr_value;
-  };
-} LiteRtAny;
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_
diff --git a/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c b/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c
deleted file mode 100644
index f4aa75e231c297..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file exists to verify that the below header files can build, link,
-// and run as C code.
-#ifdef __cplusplus
-#error "This file should be compiled as C code, not as C++."
-#endif
-
-// Include all the header files in the litert/c directory.
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_any.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_common.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_event.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_model.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_options.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" // NOLINT - -int main(void) { return 0; } diff --git a/tensorflow/lite/experimental/litert/c/litert_common.cc b/tensorflow/lite/experimental/litert/c/litert_common.cc deleted file mode 100644 index adbecb6259f2f8..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_common.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
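litert_any.h above defines a C tagged union: the `type` field selects which union member is meaningful, and readers must check it before touching a member. A small sketch of writing and reading one, assuming <cstdio> for the output (the cast to long long only sidesteps platform differences in the int64_t format specifier):

#include <cstdio>

LiteRtAny MakeIntOption(int64_t value) {
  LiteRtAny any;
  any.type = kLiteRtAnyTypeInt;  // Tag must match the member written below.
  any.int_value = value;
  return any;
}

void PrintOption(const LiteRtAny* any) {
  // Reading a member other than the one selected by `type` is undefined.
  if (any->type == kLiteRtAnyTypeInt) {
    std::printf("%s: %lld\n", LiteRtAnyTypeToString(any->type),
                (long long)any->int_value);
  }
}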
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-
-extern "C" {
-
-const char* LiteRtGetStatusString(LiteRtStatus status) {
-  switch (status) {
-    // NOLINTNEXTLINE(preprocessor-macros)
-#define LITERT_STATUS_STR_CASE(STATUS) \
-  case STATUS:                         \
-    return #STATUS;
-    LITERT_STATUS_STR_CASE(kLiteRtStatusOk);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidArgument);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorMemoryAllocationFailure);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorRuntimeFailure);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorMissingInputTensor);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorUnsupported);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorNotFound);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorTimeoutExpired);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorFileIO);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidFlatbuffer);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorDynamicLoading);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorSerialization);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorCompilation);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorIndexOOB);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidIrType);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidGraphInvariant);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorGraphModification);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidToolConfig);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusLegalizeNoMatch);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidLegalization);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorWrongVersion);
-    LITERT_STATUS_STR_CASE(kLiteRtStatusErrorUnknown);
-#undef LITERT_STATUS_STR_CASE
-  }
-}
-
-}  // extern "C"
diff --git a/tensorflow/lite/experimental/litert/c/litert_common.h b/tensorflow/lite/experimental/litert/c/litert_common.h
deleted file mode 100644
index 4ed0ca5407c4ad..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_common.h
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_
-
-#include <stddef.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-// Declares canonical opaque type.
-#define LITERT_DEFINE_HANDLE(name) typedef struct name##T* name
-
-#if __ANDROID_API__ >= 26
-#define LITERT_HAS_AHWB_SUPPORT 1
-#else
-#define LITERT_HAS_AHWB_SUPPORT 0
-#endif  // __ANDROID_API__ >= 26
-
-#if defined(__linux__) || defined(__ANDROID__)
-#define LITERT_HAS_SYNC_FENCE_SUPPORT 1
-#else
-#define LITERT_HAS_SYNC_FENCE_SUPPORT 0
-#endif
-
-#if defined(__ANDROID__)
-#define LITERT_HAS_ION_SUPPORT 1
-#define LITERT_HAS_DMABUF_SUPPORT 1
-#define LITERT_HAS_FASTRPC_SUPPORT 1
-#define LITERT_HAS_OPENGL_SUPPORT 1
-#define LITERT_HAS_OPENCL_SUPPORT_DEFAULT 1
-// copybara:comment_begin(google-only)
-#elif defined(GOOGLE_UNSUPPORTED_OS_LOONIX)
-#define LITERT_HAS_ION_SUPPORT 0
-#define LITERT_HAS_DMABUF_SUPPORT 1
-#define LITERT_HAS_FASTRPC_SUPPORT 0
-#define LITERT_HAS_OPENCL_SUPPORT_DEFAULT 1
-// copybara:comment_end
-#else
-#define LITERT_HAS_ION_SUPPORT 0
-#define LITERT_HAS_DMABUF_SUPPORT 0
-#define LITERT_HAS_FASTRPC_SUPPORT 0
-#define LITERT_HAS_OPENCL_SUPPORT_DEFAULT 1
-#define LITERT_HAS_OPENGL_SUPPORT 0
-#endif
-
-#if defined(LITERT_DISABLE_OPENCL_SUPPORT)
-#define LITERT_HAS_OPENCL_SUPPORT 0
-#else
-#define LITERT_HAS_OPENCL_SUPPORT LITERT_HAS_OPENCL_SUPPORT_DEFAULT
-#endif
-
-#define LITERT_API_VERSION_MAJOR 0
-#define LITERT_API_VERSION_MINOR 1
-#define LITERT_API_VERSION_PATCH 0
-
-typedef struct LiteRtApiVersion {
-  int major;
-  int minor;
-  int patch;
-} LiteRtApiVersion;
-
-typedef enum {
-  kLiteRtStatusOk = 0,
-
-  // Generic errors.
-  kLiteRtStatusErrorInvalidArgument = 1,
-  kLiteRtStatusErrorMemoryAllocationFailure = 2,
-  kLiteRtStatusErrorRuntimeFailure = 3,
-  kLiteRtStatusErrorMissingInputTensor = 4,
-  kLiteRtStatusErrorUnsupported = 5,
-  kLiteRtStatusErrorNotFound = 6,
-  kLiteRtStatusErrorTimeoutExpired = 7,
-  kLiteRtStatusErrorWrongVersion = 8,
-  kLiteRtStatusErrorUnknown = 9,
-
-  // File and loading related errors.
-  kLiteRtStatusErrorFileIO = 500,
-  kLiteRtStatusErrorInvalidFlatbuffer = 501,
-  kLiteRtStatusErrorDynamicLoading = 502,
-  kLiteRtStatusErrorSerialization = 503,
-  kLiteRtStatusErrorCompilation = 504,
-
-  // IR related errors.
-  kLiteRtStatusErrorIndexOOB = 1000,
-  kLiteRtStatusErrorInvalidIrType = 1001,
-  kLiteRtStatusErrorInvalidGraphInvariant = 1002,
-  kLiteRtStatusErrorGraphModification = 1003,
-
-  // Tool related errors.
-  kLiteRtStatusErrorInvalidToolConfig = 1500,
-
-  // Legalization related errors.
-  kLiteRtStatusLegalizeNoMatch = 2000,
-  kLiteRtStatusErrorInvalidLegalization = 2001,
-} LiteRtStatus;
-
-// Returns a string describing the status value.
-const char* LiteRtGetStatusString(LiteRtStatus status);
-
-typedef enum : int {
-  kLiteRtHwAcceleratorNone = 0,
-  kLiteRtHwAcceleratorCpu = 1 << 0,
-  kLiteRtHwAcceleratorGpu = 1 << 1,
-  kLiteRtHwAcceleratorNpu = 1 << 2,
-} LiteRtHwAccelerators;
-
-// A bit field of `LiteRtHwAccelerators` values.
-typedef int LiteRtHwAcceleratorSet;
-
-// For indexing into LiteRT collections or counting LiteRT things.
-typedef size_t LiteRtParamIndex;
-
-#if defined(_WIN32)
-// Provides posix_memalign() missing in Windows.
-#include <malloc.h>
-
-#define posix_memalign(p, a, s) \
-  (((*(p)) = _aligned_malloc((s), (a))), *(p) ? 0 : errno)
-#endif  // defined(_WIN32)
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_
diff --git a/tensorflow/lite/experimental/litert/c/litert_common_test.cc b/tensorflow/lite/experimental/litert/c/litert_common_test.cc
deleted file mode 100644
index be0993c1ce4733..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_common_test.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-
-namespace {
-
-using testing::Eq;
-using testing::Gt;
-using testing::Lt;
-using testing::StrEq;
-
-TEST(GetStatusString, Works) {
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusOk), StrEq("kLiteRtStatusOk"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidArgument),
-              StrEq("kLiteRtStatusErrorInvalidArgument"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorMemoryAllocationFailure),
-              StrEq("kLiteRtStatusErrorMemoryAllocationFailure"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorRuntimeFailure),
-              StrEq("kLiteRtStatusErrorRuntimeFailure"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorMissingInputTensor),
-              StrEq("kLiteRtStatusErrorMissingInputTensor"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorUnsupported),
-              StrEq("kLiteRtStatusErrorUnsupported"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorNotFound),
-              StrEq("kLiteRtStatusErrorNotFound"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorTimeoutExpired),
-              StrEq("kLiteRtStatusErrorTimeoutExpired"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorFileIO),
-              StrEq("kLiteRtStatusErrorFileIO"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidFlatbuffer),
-              StrEq("kLiteRtStatusErrorInvalidFlatbuffer"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorDynamicLoading),
-              StrEq("kLiteRtStatusErrorDynamicLoading"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorSerialization),
-              StrEq("kLiteRtStatusErrorSerialization"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorCompilation),
-              StrEq("kLiteRtStatusErrorCompilation"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorIndexOOB),
-              StrEq("kLiteRtStatusErrorIndexOOB"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidIrType),
-              StrEq("kLiteRtStatusErrorInvalidIrType"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidGraphInvariant),
-              StrEq("kLiteRtStatusErrorInvalidGraphInvariant"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorGraphModification),
-              StrEq("kLiteRtStatusErrorGraphModification"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidToolConfig),
-              StrEq("kLiteRtStatusErrorInvalidToolConfig"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusLegalizeNoMatch),
-              StrEq("kLiteRtStatusLegalizeNoMatch"));
-  EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidLegalization),
-
StrEq("kLiteRtStatusErrorInvalidLegalization")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorWrongVersion), - StrEq("kLiteRtStatusErrorWrongVersion")); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_compilation_options.cc b/tensorflow/lite/experimental/litert/c/litert_compilation_options.cc deleted file mode 100644 index 0ff32733d13d1e..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_compilation_options.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/runtime/compilation_options.h" - -#define LRT_CHECK_NON_NULL(handle) \ - if (!(handle)) { \ - LITERT_LOG(LITERT_ERROR, #handle " must not be null."); \ - return kLiteRtStatusErrorInvalidArgument; \ - } - -extern "C" { - -LiteRtStatus LiteRtCreateCompilationOptions(LiteRtCompilationOptions* options) { - LRT_CHECK_NON_NULL(options); - *options = new LiteRtCompilationOptionsT; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilationOptions(LiteRtCompilationOptions options) { - delete options; -} - -LiteRtStatus LiteRtSetCompilationOptionsHardwareAccelerators( - LiteRtCompilationOptions options, - LiteRtHwAcceleratorSet hardware_accelerators) { - LRT_CHECK_NON_NULL(options); - if ((hardware_accelerators & - (kLiteRtHwAcceleratorCpu | kLiteRtHwAcceleratorGpu | - kLiteRtHwAcceleratorNpu)) != hardware_accelerators) { - LITERT_LOG(LITERT_ERROR, - "Invalid bitfield value for hardware accelerator set: %d.", - hardware_accelerators); - return kLiteRtStatusErrorInvalidArgument; - } - options->hardware_accelerators = hardware_accelerators; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilationOptionsHardwareAccelerators( - LiteRtCompilationOptions options, - LiteRtHwAcceleratorSet* hardware_accelerators) { - LRT_CHECK_NON_NULL(options); - LRT_CHECK_NON_NULL(hardware_accelerators); - *hardware_accelerators = options->hardware_accelerators; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtAddAcceleratorCompilationOptions( - LiteRtCompilationOptions options, - LiteRtAcceleratorCompilationOptions accelerator_compilation_options) { - LRT_CHECK_NON_NULL(options); - LRT_CHECK_NON_NULL(accelerator_compilation_options); - LITERT_RETURN_IF_ERROR(options->accelerator_compilation_options.Append( - litert::AcceleratorCompilationOptions(accelerator_compilation_options, - /*owned=*/false))); - return kLiteRtStatusOk; -} - -// Retrieves the head of the accelerator compilation option list. -// -// Note: The following elements may be retrieved with -// `LiteRtGetNextAcceleratorCompilationOptions`. 
-LiteRtStatus LiteRtGetAcceleratorCompilationOptions( - LiteRtCompilationOptions options, - LiteRtAcceleratorCompilationOptions* accelerator_compilation_options) { - LRT_CHECK_NON_NULL(options); - LRT_CHECK_NON_NULL(accelerator_compilation_options); - *accelerator_compilation_options = - options->accelerator_compilation_options.Get(); - return kLiteRtStatusOk; -} - -} // extern "C" diff --git a/tensorflow/lite/experimental/litert/c/litert_compilation_options.h b/tensorflow/lite/experimental/litert/c/litert_compilation_options.h deleted file mode 100644 index d27aa3919ff2a2..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_compilation_options.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILATION_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILATION_OPTIONS_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// The compilation options for the LiteRtCompiledModel. -LITERT_DEFINE_HANDLE(LiteRtCompilationOptions); - -// Creates a compilation option object. -LiteRtStatus LiteRtCreateCompilationOptions(LiteRtCompilationOptions* options); - -// Destroys a compilation option object. -void LiteRtDestroyCompilationOptions(LiteRtCompilationOptions options); - -// Sets the requested hardware accelerators to apply during model compilation. -LiteRtStatus LiteRtSetCompilationOptionsHardwareAccelerators( - LiteRtCompilationOptions options, - LiteRtHwAcceleratorSet hardware_accelerators); - -// Gets the hardware accelerators to apply during model compilation. -LiteRtStatus LiteRtGetCompilationOptionsHardwareAccelerators( - LiteRtCompilationOptions options, - LiteRtHwAcceleratorSet* hardware_accelerators); - -// Adds compilation options for a specific accelerator to the accelerator -// compilation option list. -// -// Note: Multiple accelerator options may be added to the options object. -// -// Note: `accelerator_compilation_options`'s ownership is transferred to -// `options`. -LiteRtStatus LiteRtAddAcceleratorCompilationOptions( - LiteRtCompilationOptions options, - LiteRtAcceleratorCompilationOptions accelerator_compilation_options); - -// Retrieves the head of the accelerator compilation option list. -// -// Note: The following elements may be retrieved with -// `LiteRtGetNextAcceleratorCompilationOptions`. 
-LiteRtStatus LiteRtGetAcceleratorCompilationOptions(
-    LiteRtCompilationOptions options,
-    LiteRtAcceleratorCompilationOptions* accelerator_compilation_options);
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILATION_OPTIONS_H_
diff --git a/tensorflow/lite/experimental/litert/c/litert_compilation_options_test.cc b/tensorflow/lite/experimental/litert/c/litert_compilation_options_test.cc
deleted file mode 100644
index 3941105430e68e..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_compilation_options_test.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h"
-
-#include <gtest/gtest.h>
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-
-namespace {
-
-TEST(LiteRtCompiledModelOptionsTest, CreateAndDestroyDontLeak) {
-  LiteRtCompilationOptions options;
-  ASSERT_EQ(LiteRtCreateCompilationOptions(&options), kLiteRtStatusOk);
-  LiteRtDestroyCompilationOptions(options);
-}
-
-TEST(LiteRtCompiledModelOptionsTest, CreateWithANullPointerErrors) {
-  EXPECT_EQ(LiteRtCreateCompilationOptions(nullptr),
-            kLiteRtStatusErrorInvalidArgument);
-}
-
-TEST(LiteRtCompiledModelOptionsTest, SetAndGetHardwareAcceleratorsWorks) {
-  LiteRtCompilationOptions options;
-  ASSERT_EQ(LiteRtCreateCompilationOptions(&options), kLiteRtStatusOk);
-
-  LiteRtHwAcceleratorSet hardware_accelerators;
-
-  EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                options, kLiteRtHwAcceleratorNone),
-            kLiteRtStatusOk);
-  EXPECT_EQ(LiteRtGetCompilationOptionsHardwareAccelerators(
-                options, &hardware_accelerators),
-            kLiteRtStatusOk);
-  EXPECT_EQ(hardware_accelerators, kLiteRtHwAcceleratorNone);
-
-  EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                options, kLiteRtHwAcceleratorCpu),
-            kLiteRtStatusOk);
-  EXPECT_EQ(LiteRtGetCompilationOptionsHardwareAccelerators(
-                options, &hardware_accelerators),
-            kLiteRtStatusOk);
-  EXPECT_EQ(hardware_accelerators, kLiteRtHwAcceleratorCpu);
-
-  EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                options, kLiteRtHwAcceleratorGpu),
-            kLiteRtStatusOk);
-  EXPECT_EQ(LiteRtGetCompilationOptionsHardwareAccelerators(
-                options, &hardware_accelerators),
-            kLiteRtStatusOk);
-  EXPECT_EQ(hardware_accelerators, kLiteRtHwAcceleratorGpu);
-
-  EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                options, kLiteRtHwAcceleratorNpu),
-            kLiteRtStatusOk);
-  EXPECT_EQ(LiteRtGetCompilationOptionsHardwareAccelerators(
-                options, &hardware_accelerators),
-            kLiteRtStatusOk);
-  EXPECT_EQ(hardware_accelerators, kLiteRtHwAcceleratorNpu);
-
-  EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                options, (kLiteRtHwAcceleratorCpu | kLiteRtHwAcceleratorGpu |
-                          kLiteRtHwAcceleratorNpu) +
-                             1),
-            kLiteRtStatusErrorInvalidArgument);
-  EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                nullptr, kLiteRtHwAcceleratorNone),
-            kLiteRtStatusErrorInvalidArgument);
-
-  LiteRtDestroyCompilationOptions(options);
-}
-
-struct DummyAcceleratorCompilationOptions {
-  static constexpr const LiteRtApiVersion kVersion = {1, 0, 0};
-  static constexpr const char* const kIdentifier = "dummy-accelerator";
-
-  // Allocates and sets the basic structure for the accelerator options.
-  static litert::Expected<LiteRtAcceleratorCompilationOptions> CreateOptions() {
-    LiteRtAcceleratorCompilationOptions options;
-    auto* payload = new DummyAcceleratorCompilationOptions;
-    auto payload_destructor = [](void* payload) {
-      delete reinterpret_cast<DummyAcceleratorCompilationOptions*>(payload);
-    };
-    LITERT_RETURN_IF_ERROR(LiteRtCreateAcceleratorCompilationOptions(
-        &kVersion, kIdentifier, payload, payload_destructor, &options));
-    return options;
-  }
-};
-
-TEST(LiteRtCompiledModelOptionsTest, AddAcceleratorCompilationOptionsWorks) {
-  LiteRtCompilationOptions options;
-  ASSERT_EQ(LiteRtCreateCompilationOptions(&options), kLiteRtStatusOk);
-
-  auto accelerator_compilation_options1 =
-      DummyAcceleratorCompilationOptions::CreateOptions();
-  EXPECT_TRUE(accelerator_compilation_options1);
-  auto accelerator_compilation_options2 =
-      DummyAcceleratorCompilationOptions::CreateOptions();
-  EXPECT_TRUE(accelerator_compilation_options2);
-
-  EXPECT_EQ(LiteRtAddAcceleratorCompilationOptions(
-                nullptr, *accelerator_compilation_options1),
-            kLiteRtStatusErrorInvalidArgument);
-  EXPECT_EQ(LiteRtAddAcceleratorCompilationOptions(options, nullptr),
-            kLiteRtStatusErrorInvalidArgument);
-
-  EXPECT_EQ(LiteRtAddAcceleratorCompilationOptions(
-                options, *accelerator_compilation_options1),
-            kLiteRtStatusOk);
-  EXPECT_EQ(LiteRtAddAcceleratorCompilationOptions(
-                options, *accelerator_compilation_options2),
-            kLiteRtStatusOk);
-
-  LiteRtAcceleratorCompilationOptions options_it = nullptr;
-  EXPECT_EQ(LiteRtGetAcceleratorCompilationOptions(options, &options_it),
-            kLiteRtStatusOk);
-  EXPECT_EQ(options_it, *accelerator_compilation_options1);
-
-  EXPECT_EQ(LiteRtGetNextAcceleratorCompilationOptions(&options_it),
-            kLiteRtStatusOk);
-  EXPECT_EQ(options_it, *accelerator_compilation_options2);
-
-  LiteRtDestroyCompilationOptions(options);
-}
-
-}  // namespace
diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc b/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc
deleted file mode 100644
index 295fbf32c40596..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc
+++ /dev/null
@@ -1,144 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
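The options API above composes in two layers: a top-level LiteRtCompilationOptions carrying the hardware-accelerator bitmask, plus a list of per-accelerator options whose ownership transfers on Add. A usage sketch mirroring the test just shown (error handling trimmed to the essentials):

LiteRtStatus ConfigureCpuGpu(LiteRtCompilationOptions* out) {
  LiteRtCompilationOptions options = nullptr;
  LiteRtStatus status = LiteRtCreateCompilationOptions(&options);
  if (status != kLiteRtStatusOk) return status;
  // Only combinations of the CPU/GPU/NPU flags are accepted; anything
  // outside that mask fails with kLiteRtStatusErrorInvalidArgument.
  status = LiteRtSetCompilationOptionsHardwareAccelerators(
      options, kLiteRtHwAcceleratorCpu | kLiteRtHwAcceleratorGpu);
  if (status != kLiteRtStatusOk) {
    LiteRtDestroyCompilationOptions(options);
    return status;
  }
  // Per-accelerator options (e.g. created with
  // LiteRtCreateAcceleratorCompilationOptions) would be appended here via
  // LiteRtAddAcceleratorCompilationOptions; their ownership moves to
  // `options`.
  *out = options;  // Caller destroys with LiteRtDestroyCompilationOptions.
  return kLiteRtStatusOk;
}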
-
-#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h"
-
-#include <stddef.h>
-
-#include <memory>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
-#include "tensorflow/lite/experimental/litert/runtime/compiled_model.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-LiteRtStatus LiteRtCreateCompiledModel(
-    LiteRtEnvironment environment, LiteRtModel model,
-    LiteRtCompilationOptions jit_compilation_options,
-    LiteRtCompiledModel* compiled_model) {
-  if (!environment || !model || !compiled_model) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto created_compiled_model =
-      LiteRtCompiledModelT::Create(environment, model, jit_compilation_options);
-  if (!created_compiled_model) {
-    LITERT_LOG(LITERT_ERROR, "%s",
-               created_compiled_model.Error().Message().c_str());
-    return created_compiled_model.Error().Status();
-  }
-  *compiled_model = created_compiled_model->release();
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetCompiledModelInputBufferRequirements(
-    LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
-    LiteRtParamIndex input_index,
-    LiteRtTensorBufferRequirements* buffer_requirements) {
-  if (!compiled_model || !buffer_requirements) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-
-  auto res = compiled_model->GetInputBufferRequirementsCApi(signature_index,
-                                                            input_index);
-  if (!res) {
-    LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str());
-    return res.Error().Status();
-  }
-  *buffer_requirements = res.Value();
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetCompiledModelOutputBufferRequirements(
-    LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
-    LiteRtParamIndex output_index,
-    LiteRtTensorBufferRequirements* buffer_requirements) {
-  if (!compiled_model || !buffer_requirements) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-
-  auto res = compiled_model->GetOutputBufferRequirementsCApi(signature_index,
-                                                             output_index);
-  if (!res) {
-    LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str());
-    return res.Error().Status();
-  }
-  *buffer_requirements = res.Value();
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtRunCompiledModel(LiteRtCompiledModel compiled_model,
-                                    LiteRtParamIndex signature_index,
-                                    size_t num_input_buffers,
-                                    LiteRtTensorBuffer* input_buffers,
-                                    size_t num_output_buffers,
-                                    LiteRtTensorBuffer* output_buffers) {
-  if (!compiled_model || (num_input_buffers > 0 && !input_buffers) ||
-      (num_output_buffers > 0 && !output_buffers)) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-
-  bool async = false;
-  auto res =
-      compiled_model->RunCApi(signature_index, num_input_buffers, input_buffers,
-                              num_output_buffers, output_buffers, &async);
-  if (!res) {
-    LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str());
-    return res.Error().Status();
-  }
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtRunCompiledModelAsync(LiteRtCompiledModel compiled_model,
-                                         LiteRtParamIndex signature_index,
-                                         size_t num_input_buffers,
-                                         LiteRtTensorBuffer* input_buffers,
-                                         size_t num_output_buffers,
-                                         LiteRtTensorBuffer* output_buffers,
-                                         bool* async) {
-  if (!compiled_model || (num_input_buffers > 0 && !input_buffers) ||
-      (num_output_buffers > 0 && !output_buffers)) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-
-  if (async) {
-    *async = true;
-  }
-  bool async_ = true;
-  bool* async_ptr = async ? async : &async_;
-
-  auto res =
-      compiled_model->RunCApi(signature_index, num_input_buffers, input_buffers,
-                              num_output_buffers, output_buffers, async_ptr);
-  if (!res) {
-    LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str());
-    return res.Error().Status();
-  }
-  return kLiteRtStatusOk;
-}
-
-void LiteRtDestroyCompiledModel(LiteRtCompiledModel compiled_model) {
-  delete compiled_model;
-}
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif
diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model.h b/tensorflow/lite/experimental/litert/c/litert_compiled_model.h
deleted file mode 100644
index 76df573c5ea9e9..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_compiled_model.h
+++ /dev/null
@@ -1,136 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_H_
-
-#include <stddef.h>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-// The LiteRtCompiledModel is a higher-level inference API. It is created from
-// a model together with compilation options. Internally, it instantiates the
-// runtime and applies the delegates mapped to the compilation options.
-// It also supports getting LiteRtTensorBufferRequirements to create
-// input/output TensorBuffers, and it allows invoking the model with the
-// input/output TensorBuffers.
-//
-// Example user flow:
-//
-// 1. Create LiteRtCompiledModel
-// 2. Query the model input/output LiteRtTensorBufferRequirements
-// 3. Create input/output LiteRtTensorBuffer
-// 4. Fill the input LiteRtTensorBuffer with input data
-// 5. Invoke the model with the input/output LiteRtTensorBuffer
-// 6. Evaluate the output LiteRtTensorBuffer
-
-LITERT_DEFINE_HANDLE(LiteRtCompiledModel);
-
-// Creates a LiteRtCompiledModel from a LiteRtModel object. Parameter
-// `jit_compilation_options` is optional and can be null, and is owned by the
-// caller. The model is loaded into memory and the caller takes ownership of
-// the returned object.
-LiteRtStatus LiteRtCreateCompiledModel(
-    LiteRtEnvironment environment, LiteRtModel model,
-    LiteRtCompilationOptions compilation_options,
-    LiteRtCompiledModel* compiled_model);
-
-// Returns the buffer requirements for the given n-th input tensor. The
-// returned LiteRtTensorBufferRequirements is used to create the input tensor
-// buffer.
-//
-// Parameters:
-// - compiled_model: the target `LiteRtCompiledModel` object.
-// - signature_index: the index of the signature in `LiteRtModel`.
-// - input_index: the index of the input tensor in the signature (subgraph).
-// - buffer_requirements: the returned `LiteRtTensorBufferRequirements`.
-LiteRtStatus LiteRtGetCompiledModelInputBufferRequirements(
-    LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
-    LiteRtParamIndex input_index,
-    LiteRtTensorBufferRequirements* buffer_requirements);
-
-// Returns the buffer requirements for the given n-th output tensor. The
-// returned LiteRtTensorBufferRequirements is used to create the output tensor
-// buffer.
-//
-// Parameters:
-// - compiled_model: the target `LiteRtCompiledModel` object.
-// - signature_index: the index of the signature in `LiteRtModel`.
-// - output_index: the index of the output tensor in the signature (subgraph).
-// - buffer_requirements: the returned `LiteRtTensorBufferRequirements`.
-LiteRtStatus LiteRtGetCompiledModelOutputBufferRequirements(
-    LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
-    LiteRtParamIndex output_index,
-    LiteRtTensorBufferRequirements* buffer_requirements);
-
-// Runs the model of the given signature synchronously, with the provided
-// input/output LiteRtTensorBuffer.
-//
-// Parameters:
-// - compiled_model: the target `LiteRtCompiledModel` object.
-// - signature_index: the index of the signature in `LiteRtModel`.
-// - num_input_buffers: the number of input `LiteRtTensorBuffer`.
-// - input_buffers: the array of input `LiteRtTensorBuffer`.
-// - num_output_buffers: the number of output `LiteRtTensorBuffer`.
-// - output_buffers: the array of output LiteRtTensorBuffer.
-LiteRtStatus LiteRtRunCompiledModel(LiteRtCompiledModel compiled_model,
-                                    LiteRtParamIndex signature_index,
-                                    size_t num_input_buffers,
-                                    LiteRtTensorBuffer* input_buffers,
-                                    size_t num_output_buffers,
-                                    LiteRtTensorBuffer* output_buffers);
-
-// Runs the model of the given signature asynchronously, if possible, with the
-// provided input/output LiteRtTensorBuffers. If asynchronous execution is
-// possible, then the function sets parameter `async` to true; if asynchronous
-// execution is not possible, then the function runs the model synchronously
-// and sets parameter `async` to false. Note that:
-//
-// - Asynchronous execution is possible only in certain cases, based on the ops
-//   included in the model, the selected HW accelerator(s), and the capability
-//   of the user device hardware.
-//
-// - If asynchronous execution is indeed possible, it may be that only some
-//   parts of the model are run asynchronously (e.g., ops mapped to the GPU)
-//   while other parts of the model are still run synchronously with the
-//   invocation of this call (e.g., ops mapped to the CPU).
-//
-// - In case of asynchronous execution some or all of the output tensor buffers
-//   will have a synchronization event attached to them and the caller is
-//   responsible for passing such events to a downstream processing step.
-//
-// Parameters:
-// - async: optional boolean to let the caller know if the model is being run
-//   asynchronously.
-LiteRtStatus LiteRtRunCompiledModelAsync(
-    LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
-    size_t num_input_buffers, LiteRtTensorBuffer* input_buffers,
-    size_t num_output_buffers, LiteRtTensorBuffer* output_buffers, bool* async);
-
-void LiteRtDestroyCompiledModel(LiteRtCompiledModel compiled_model);
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_H_
diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc b/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc
deleted file mode 100644
index 6aa617bed4f551..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h"
-
-#include <cstddef>
-#include <cstring>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/log/absl_log.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
-#include "tensorflow/lite/experimental/litert/test/common.h"
-#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h"
-
-using testing::FloatNear;
-using testing::Pointwise;
-
-namespace litert {
-namespace {
-
-TEST(CompiledModelTest, Basic) {
-  auto path = testing::GetTestFilePath(kModelFileName);
-
-  LiteRtModel model;
-  ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk);
-
-  LiteRtCompilationOptions jit_compilation_options;
-  ASSERT_EQ(LiteRtCreateCompilationOptions(&jit_compilation_options),
-            kLiteRtStatusOk);
-  ASSERT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                jit_compilation_options, kLiteRtHwAcceleratorCpu),
-            kLiteRtStatusOk);
-
-  LiteRtEnvironment environment;
-  LiteRtEnvOption options = {};
-  ASSERT_EQ(LiteRtEnvironmentCreate(/*num_options=*/0, &options, &environment),
-            kLiteRtStatusOk);
-
-  LiteRtCompiledModel compiled_model;
-  ASSERT_EQ(LiteRtCreateCompiledModel(environment, model,
-                                      jit_compilation_options, &compiled_model),
-            kLiteRtStatusOk);
-
-  LiteRtDestroyCompilationOptions(jit_compilation_options);
-
-  LiteRtSubgraph subgraph;
-  ASSERT_EQ(LiteRtGetModelSubgraph(model, 0, &subgraph), kLiteRtStatusOk);
-
-  LiteRtParamIndex num_inputs;
-  ASSERT_EQ(LiteRtGetNumSubgraphInputs(subgraph, &num_inputs), kLiteRtStatusOk);
-
-  std::vector<LiteRtTensorBuffer> input_tensor_buffers;
-  input_tensor_buffers.reserve(num_inputs);
-  for (auto i = 0; i < num_inputs; ++i) {
-    LiteRtTensorBufferRequirements tensor_buffer_requirements;
ASSERT_EQ(LiteRtGetCompiledModelInputBufferRequirements( - compiled_model, /*signature_index=*/0, i, - &tensor_buffer_requirements), - kLiteRtStatusOk); - LiteRtTensorBufferType tensor_buffer_type; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - tensor_buffer_requirements, /*type_index=*/0, &tensor_buffer_type), - kLiteRtStatusOk); - size_t tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - tensor_buffer_requirements, &tensor_buffer_size), - kLiteRtStatusOk); - LiteRtTensorBuffer tensor_buffer; - EXPECT_EQ( - LiteRtCreateManagedTensorBuffer(tensor_buffer_type, &kInput0TensorType, - tensor_buffer_size, &tensor_buffer), - kLiteRtStatusOk); - input_tensor_buffers.push_back(tensor_buffer); - } - - LiteRtParamIndex num_outputs; - ASSERT_EQ(LiteRtGetNumSubgraphOutputs(subgraph, &num_outputs), - kLiteRtStatusOk); - - std::vector output_tensor_buffers; - output_tensor_buffers.reserve(num_outputs); - for (auto i = 0; i < num_outputs; ++i) { - LiteRtTensorBufferRequirements tensor_buffer_requirements; - ASSERT_EQ(LiteRtGetCompiledModelOutputBufferRequirements( - compiled_model, /*signature_index=*/0, i, - &tensor_buffer_requirements), - kLiteRtStatusOk); - LiteRtTensorBufferType tensor_buffer_type; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - tensor_buffer_requirements, /*type_index=*/0, &tensor_buffer_type), - kLiteRtStatusOk); - size_t tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - tensor_buffer_requirements, &tensor_buffer_size), - kLiteRtStatusOk); - LiteRtTensorBuffer tensor_buffer; - EXPECT_EQ( - LiteRtCreateManagedTensorBuffer(tensor_buffer_type, &kInput0TensorType, - tensor_buffer_size, &tensor_buffer), - kLiteRtStatusOk); - output_tensor_buffers.push_back(tensor_buffer); - } - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_tensor_buffers[0], &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_tensor_buffers[0]), - kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_tensor_buffers[1], &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_tensor_buffers[1]), - kLiteRtStatusOk); - } - - ASSERT_EQ(LiteRtRunCompiledModel( - compiled_model, /*signature_index=*/0, - input_tensor_buffers.size(), input_tensor_buffers.data(), - output_tensor_buffers.size(), output_tensor_buffers.data()), - kLiteRtStatusOk); - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffers[0], &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffers[0]), - kLiteRtStatusOk); - } - - LiteRtDestroyCompiledModel(compiled_model); - LiteRtDestroyModel(model); - LiteRtDestroyEnvironment(environment); - - for (auto tensor_buffer : input_tensor_buffers) { - LiteRtDestroyTensorBuffer(tensor_buffer); - } - for (auto tensor_buffer : output_tensor_buffers) { - LiteRtDestroyTensorBuffer(tensor_buffer); - } -} - -} // namespace -} // namespace litert diff --git 
a/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h b/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h deleted file mode 100644 index 7186bf794db7fe..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ - -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -typedef struct LiteRtDispatchDelegateOptions LiteRtDispatchDelegateOptions; -typedef struct LiteRtEnvironmentT* LiteRtEnvironment; - -// Returns DispatchDelegateOptions populated with default values. -LiteRtDispatchDelegateOptions* LiteRtCreateDefaultDispatchDelegateOptions( - LiteRtEnvironment environment); - -TfLiteStatus LiteRtAddDispatchDelegateOption( - LiteRtDispatchDelegateOptions* options, LiteRtDispatchOption option); - -void LiteRtDestroyDispatchDelegateOptions( - LiteRtDispatchDelegateOptions* options); - -// Create a delegate that uses the Dispatch API for execution. Takes ownership -// of the passed `options`. Must outlive the TFL interpreter. -TfLiteOpaqueDelegate* LiteRtCreateDispatchDelegate( - LiteRtEnvironmentOptions environment_options, - LiteRtDispatchDelegateOptions* options); - -// Do any needed cleanup and delete 'delegate'. -void LiteRtDestroyDispatchDelegate(TfLiteOpaqueDelegate* delegate); - -// -// Common option helpers -// - -// Alloc base is the address of the first byte of flatbuffer model in memory. It -// is used by ops to find the start of npu byte code appended to the file. -TfLiteStatus LiteRtDispatchDelegateAddAllocBaseOption( - LiteRtDispatchDelegateOptions* options, const void* alloc_base); - -// Alloc fd is the file descriptor for an mmapped flatbuffer. It is used by ops -// to find the start of npu byte code appended to the file. -TfLiteStatus LiteRtDispatchDelegateAddAllocFdOption( - LiteRtDispatchDelegateOptions* options, int alloc_fd); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_environment.cc b/tensorflow/lite/experimental/litert/c/litert_environment.cc deleted file mode 100644 index 702cf90e113379..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h" -#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h" - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtEnvironmentCreate(int num_options, - const LiteRtEnvOption* options, - LiteRtEnvironment* environment) { - LITERT_RETURN_IF_ERROR(environment != nullptr, - kLiteRtStatusErrorInvalidArgument); - LITERT_ASSIGN_OR_RETURN(auto env, LiteRtEnvironmentT::CreateWithOptions( - absl::MakeSpan(options, num_options))); - litert::TriggerAcceleratorAutomaticRegistration(*env); - *environment = env.release(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyEnvironment(LiteRtEnvironment environment) { - if (environment != nullptr) { - delete environment; - } -} - -LiteRtStatus LiteRtGetEnvironmentOptions(LiteRtEnvironment environment, - LiteRtEnvironmentOptions* options) { - LITERT_RETURN_IF_ERROR( - environment, litert::ErrorStatusBuilder(kLiteRtStatusErrorInvalidArgument) - << "Environment pointer is null."); - LITERT_RETURN_IF_ERROR( - options, litert::ErrorStatusBuilder(kLiteRtStatusErrorInvalidArgument) - << "Options pointer is null."); - *options = &environment->GetOptions(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGpuGlobalEnvironmentCreate(int num_options, - const LiteRtEnvOption* options) { - LITERT_ASSIGN_OR_RETURN(auto env, LiteRtEnvironmentT::CreateWithOptions( - absl::MakeSpan(options, num_options))); - litert::internal::GpuEnvironmentSingleton::Create(env.get()); - return kLiteRtStatusOk; -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_environment.h b/tensorflow/lite/experimental/litert/c/litert_environment.h deleted file mode 100644 index 20b834df645513..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment.h +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
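The dispatch-delegate API removed above is meant to be used roughly as follows. This is a sketch only: `env` and `flatbuffer_base` are assumed to exist, and the interpreter wiring is elided.

  // `env` is a previously created LiteRtEnvironment; `flatbuffer_base` points
  // at the first byte of the in-memory flatbuffer model (both assumed here).
  LiteRtDispatchDelegateOptions* delegate_opts =
      LiteRtCreateDefaultDispatchDelegateOptions(env);
  LiteRtDispatchDelegateAddAllocBaseOption(delegate_opts, flatbuffer_base);

  LiteRtEnvironmentOptions env_opts;
  LiteRtGetEnvironmentOptions(env, &env_opts);

  // Takes ownership of `delegate_opts`; the delegate must outlive the
  // TFL interpreter it is attached to.
  TfLiteOpaqueDelegate* delegate =
      LiteRtCreateDispatchDelegate(env_opts, delegate_opts);
  // ... register `delegate` with the interpreter, run inference ...
  LiteRtDestroyDispatchDelegate(delegate);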
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtEnvironment); - -// Create a LiteRT environment with options. -// Used to set the path of the compiler plugin library and dispatch library. -// -// Note: options of kLiteRtEnvOptionTagOpenCl* shouldn't be set with this API. -LiteRtStatus LiteRtEnvironmentCreate(int num_options, - const LiteRtEnvOption* options, - LiteRtEnvironment* environment); - -// Destroy a created LiteRT environment. -void LiteRtDestroyEnvironment(LiteRtEnvironment environment); - -// Get the options that the environment was created with. -LiteRtStatus LiteRtGetEnvironmentOptions(LiteRtEnvironment environment, - LiteRtEnvironmentOptions* options); - -// Create a LiteRT GPU global environment with options. -// This API is usually called by the GPU accelerator implementation to set GPU -// environment options which affect the entire LiteRT runtime. -// -// Note: In most cases, users should not call this API directly. -LiteRtStatus LiteRtGpuGlobalEnvironmentCreate(int num_options, - const LiteRtEnvOption* options); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_environment_options.cc b/tensorflow/lite/experimental/litert/c/litert_environment_options.cc deleted file mode 100644 index daf14bb79a046f..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment_options.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
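A minimal sketch of creating an environment with the API above, mirroring the deleted environment-options test later in this patch; the directory path is illustrative:

  LiteRtEnvOption option;
  option.tag = kLiteRtEnvOptionTagDispatchLibraryDir;
  option.value.type = kLiteRtAnyTypeString;
  option.value.str_value = "/data/local/tmp/dispatch";  // illustrative path

  LiteRtEnvironment env;
  LiteRtEnvironmentCreate(/*num_options=*/1, &option, &env);
  // ... compile and run models against `env` ...
  LiteRtDestroyEnvironment(env);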
- -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/environment_options.h" - -extern "C" { - -LiteRtStatus LiteRtGetEnvironmentOptionsValue(LiteRtEnvironmentOptions options, - LiteRtEnvOptionTag tag, - LiteRtAny* value) { - LITERT_RETURN_IF_ERROR( - options, litert::ErrorStatusBuilder(kLiteRtStatusErrorInvalidArgument)) - << "`options` handle is null."; - LITERT_RETURN_IF_ERROR( - value, litert::ErrorStatusBuilder(kLiteRtStatusErrorInvalidArgument)) - << "`value` handle is null."; - LITERT_ASSIGN_OR_RETURN(*value, options->GetOption(tag)); - return kLiteRtStatusOk; -} - -} // extern "C" diff --git a/tensorflow/lite/experimental/litert/c/litert_environment_options.h b/tensorflow/lite/experimental/litert/c/litert_environment_options.h deleted file mode 100644 index ab778e230f51d3..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment_options.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_OPTIONS_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum { - kLiteRtEnvOptionTagCompilerPluginLibraryDir = 0, - kLiteRtEnvOptionTagDispatchLibraryDir = 1, - kLiteRtEnvOptionTagOpenClDeviceId = 2, - kLiteRtEnvOptionTagOpenClPlatformId = 3, - kLiteRtEnvOptionTagOpenClContext = 4, - kLiteRtEnvOptionTagOpenClCommandQueue = 5, -} LiteRtEnvOptionTag; - -typedef struct { - LiteRtEnvOptionTag tag; - LiteRtAny value; -} LiteRtEnvOption; - -LITERT_DEFINE_HANDLE(LiteRtEnvironmentOptions); - -// Retrieves the value corresponding to the given tag. -// -// Returns kLiteRtStatusErrorNotFound if the option tag is not found. -LiteRtStatus LiteRtGetEnvironmentOptionsValue(LiteRtEnvironmentOptions options, - LiteRtEnvOptionTag tag, - LiteRtAny* value); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_environment_options_test.cc b/tensorflow/lite/experimental/litert/c/litert_environment_options_test.cc deleted file mode 100644 index d827ee4005edbd..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment_options_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/core/environment_options.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using testing::AnyOf; -using testing::Eq; -using testing::Not; -using testing::StrEq; -using testing::litert::IsError; - -class LiteRtEnvironmentOptionsTest : public testing::Test { - public: - void SetUp() override { - constexpr const char* kStrValue = "string_value"; - dispatch_option_.tag = kLiteRtEnvOptionTagDispatchLibraryDir; - dispatch_option_.value.type = kLiteRtAnyTypeString; - dispatch_option_.value.str_value = kStrValue; - options_.SetOption(dispatch_option_); - - constexpr int kIntValue = 3; - cl_device_id_option_.tag = kLiteRtEnvOptionTagOpenClDeviceId; - cl_device_id_option_.value.type = kLiteRtAnyTypeInt; - cl_device_id_option_.value.int_value = kIntValue; - options_.SetOption(cl_device_id_option_); - - ASSERT_THAT(NotInsertedOptionTag(), - Not(AnyOf(dispatch_option_.tag, cl_device_id_option_.tag))); - } - - LiteRtEnvironmentOptions Options() { return &options_; } - const LiteRtEnvOption& DispatchOption() const { return dispatch_option_; } - const LiteRtEnvOption& ClDeviceIdOption() const { - return cl_device_id_option_; - } - - static constexpr LiteRtEnvOptionTag NotInsertedOptionTag() { - return kLiteRtEnvOptionTagOpenClPlatformId; - } - - private: - LiteRtEnvironmentOptionsT options_; - LiteRtEnvOption dispatch_option_; - LiteRtEnvOption cl_device_id_option_; -}; - -TEST_F(LiteRtEnvironmentOptionsTest, - LiteRtGetEnvironmentOptionsValueReturnsAnErrorForInvalidArguments) { - LiteRtAny option_value; - EXPECT_THAT( - LiteRtGetEnvironmentOptionsValue( - /*options=*/nullptr, kLiteRtEnvOptionTagOpenClContext, &option_value), - IsError(kLiteRtStatusErrorInvalidArgument)); - EXPECT_THAT( - LiteRtGetEnvironmentOptionsValue( - Options(), kLiteRtEnvOptionTagOpenClContext, /*value=*/nullptr), - IsError(kLiteRtStatusErrorInvalidArgument)); -} - -TEST_F(LiteRtEnvironmentOptionsTest, LiteRtGetEnvironmentOptionsValueWorks) { - LiteRtAny option_value; - LITERT_EXPECT_OK(LiteRtGetEnvironmentOptionsValue( - Options(), ClDeviceIdOption().tag, &option_value)); - EXPECT_THAT(option_value.type, Eq(ClDeviceIdOption().value.type)); - EXPECT_THAT(option_value.int_value, Eq(ClDeviceIdOption().value.int_value)); - - LITERT_EXPECT_OK(LiteRtGetEnvironmentOptionsValue( - Options(), DispatchOption().tag, &option_value)); - EXPECT_THAT(option_value.type, Eq(DispatchOption().value.type)); - EXPECT_THAT(option_value.str_value, StrEq(DispatchOption().value.str_value)); - - EXPECT_THAT(LiteRtGetEnvironmentOptionsValue( - Options(), NotInsertedOptionTag(), &option_value), - IsError(kLiteRtStatusErrorNotFound)); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_event.cc b/tensorflow/lite/experimental/litert/c/litert_event.cc deleted file mode 100644 index 
09364b9cfdd1fa..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_event.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_event.h" - -#include - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/runtime/event.h" - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtCreateEventFromSyncFenceFd(int sync_fence_fd, bool owns_fd, - LiteRtEvent* event) { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - *event = new LiteRtEventT{.type = LiteRtEventTypeSyncFenceFd, - .fd = sync_fence_fd, - .owns_fd = owns_fd}; - return kLiteRtStatusOk; -#else - return kLiteRtStatusErrorUnsupported; -#endif -} - -LiteRtStatus LiteRtCreateEventFromOpenClEvent(cl_event cl_event, - LiteRtEvent* event) { -#if LITERT_HAS_OPENCL_SUPPORT - *event = new LiteRtEventT{ - .type = LiteRtEventTypeOpenCl, - .opencl_event = cl_event, - }; - return kLiteRtStatusOk; -#else - return kLiteRtStatusErrorUnsupported; -#endif -} - -LiteRtStatus LiteRtGetEventEventType(LiteRtEvent event, LiteRtEventType* type) { - *type = event->type; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetEventSyncFenceFd(LiteRtEvent event, int* sync_fence_fd) { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - if (event->type == LiteRtEventTypeSyncFenceFd) { - *sync_fence_fd = event->fd; - return kLiteRtStatusOk; - } -#endif - return kLiteRtStatusErrorUnsupported; -} - -LiteRtStatus LiteRtGetEventOpenClEvent(LiteRtEvent event, cl_event* cl_event) { -#if LITERT_HAS_OPENCL_SUPPORT - if (event->type == LiteRtEventTypeOpenCl) { - *cl_event = event->opencl_event; - return kLiteRtStatusOk; - } -#endif - return kLiteRtStatusErrorUnsupported; -} - -LiteRtStatus LiteRtCreateManagedEvent(LiteRtEventType type, - LiteRtEvent* event) { - auto event_res = LiteRtEventT::CreateManaged(type); - if (!event_res) { - return kLiteRtStatusErrorUnsupported; - } - *event = *event_res; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms) { - LITERT_RETURN_IF_ERROR(event->Wait(timeout_in_ms)); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtEventSignal(LiteRtEvent event) { - LITERT_RETURN_IF_ERROR(event->Signal()); - return kLiteRtStatusOk; -} - -void LiteRtDestroyEvent(LiteRtEvent event) { delete event; } - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_event.h b/tensorflow/lite/experimental/litert/c/litert_event.h deleted file mode 100644 index 16cc107168d0f0..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_event.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ - -#include // NOLINT: To use bool type in C -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// Forward declaration of OpenCL event to avoid including OpenCL headers. -typedef struct _cl_event* cl_event; - -LITERT_DEFINE_HANDLE(LiteRtEvent); - -LiteRtStatus LiteRtCreateEventFromSyncFenceFd(int sync_fence_fd, bool owns_fd, - LiteRtEvent* event); - -LiteRtStatus LiteRtCreateEventFromOpenClEvent(cl_event cl_event, - LiteRtEvent* event); - -LiteRtStatus LiteRtCreateManagedEvent(LiteRtEventType type, LiteRtEvent* event); - -LiteRtStatus LiteRtGetEventEventType(LiteRtEvent event, LiteRtEventType* type); - -LiteRtStatus LiteRtGetEventSyncFenceFd(LiteRtEvent event, int* sync_fence_fd); - -LiteRtStatus LiteRtGetEventOpenClEvent(LiteRtEvent event, cl_event* cl_event); - -// Pass -1 for timeout_in_ms for indefinite wait. -LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms); - -// Signal the event to notify the waiters. -LiteRtStatus LiteRtEventSignal(LiteRtEvent event); - -void LiteRtDestroyEvent(LiteRtEvent event); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_event_type.h b/tensorflow/lite/experimental/litert/c/litert_event_type.h deleted file mode 100644 index 24c7124702dcea..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_event_type.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
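A sketch of the event API removed above; `fence_fd` is an assumed, valid sync-fence file descriptor:

  // On builds without LITERT_HAS_SYNC_FENCE_SUPPORT this returns
  // kLiteRtStatusErrorUnsupported.
  LiteRtEvent event;
  LiteRtCreateEventFromSyncFenceFd(fence_fd, /*owns_fd=*/true, &event);

  // Block until the fence signals; -1 waits indefinitely.
  LiteRtEventWait(event, /*timeout_in_ms=*/-1);

  LiteRtEventType type;
  LiteRtGetEventEventType(event, &type);  // LiteRtEventTypeSyncFenceFd here.

  LiteRtDestroyEvent(event);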
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_TYPE_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_TYPE_H_
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-typedef enum {
-  LiteRtEventTypeUnknown = 0,
-  LiteRtEventTypeSyncFenceFd = 1,
-  LiteRtEventTypeOpenCl = 2,
-} LiteRtEventType;
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_TYPE_H_
diff --git a/tensorflow/lite/experimental/litert/c/litert_gl_types.h b/tensorflow/lite/experimental/litert/c/litert_gl_types.h
deleted file mode 100644
index 4f394e19ded7d1..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_gl_types.h
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_GL_TYPES_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_GL_TYPES_H_
-
-#include <stdint.h>
-#if LITERT_HAS_OPENGL_SUPPORT
-#include
-#include
-#endif  // LITERT_HAS_OPENGL_SUPPORT
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#if LITERT_HAS_OPENGL_SUPPORT
-typedef GLenum LiteRtGLenum;
-typedef GLuint LiteRtGLuint;
-typedef GLint LiteRtGLint;
-#else
-// Allows for compilation of GL types when OpenGL support is not available.
-typedef uint32_t LiteRtGLenum;
-typedef uint32_t LiteRtGLuint;
-typedef int32_t LiteRtGLint;
-#endif  // LITERT_HAS_OPENGL_SUPPORT
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_GL_TYPES_H_
diff --git a/tensorflow/lite/experimental/litert/c/litert_layout.h b/tensorflow/lite/experimental/litert/c/litert_layout.h
deleted file mode 100644
index b641985b9793af..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_layout.h
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LAYOUT_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LAYOUT_H_
-
-#include <stddef.h>
-#include <stdint.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-// Max number of dimensions in any ranked tensor type.
-#define LITERT_TENSOR_MAX_RANK 8
-
-// The shape information for tensor types of fixed rank.
-typedef struct {
-  // The number of dimensions.
-  uint32_t rank;
-
-  // Dimension sizes, array of length `rank`. Dynamic dimensions are anything
-  // less than 0. Everything from [rank, LITERT_TENSOR_MAX_RANK) is undefined.
-  int32_t dimensions[LITERT_TENSOR_MAX_RANK];
-
-  // Strides for a nominal NHWC layout. NULL if unused.
-  const uint32_t* strides;
-} LiteRtLayout;
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LAYOUT_H_
diff --git a/tensorflow/lite/experimental/litert/c/litert_logging.cc b/tensorflow/lite/experimental/litert/c/litert_logging.cc
deleted file mode 100644
index 66f92cd9e79545..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_logging.cc
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-
-#include <stdarg.h>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/logger.h"
-#include "tensorflow/lite/minimal_logging.h"
-
-class LiteRtLoggerT {
- public:
-  LiteRtLogSeverity GetMinSeverity() {
-    return ConvertSeverity(
-        tflite::logging_internal::MinimalLogger::GetMinimumLogSeverity());
-  }
-
-  void SetMinSeverity(LiteRtLogSeverity severity) {
-    tflite::logging_internal::MinimalLogger::SetMinimumLogSeverity(
-        ConvertSeverity(severity));
-  }
-
-  void Log(LiteRtLogSeverity severity, const char* format, va_list args) {
-    tflite::logging_internal::MinimalLogger::LogFormatted(
-        ConvertSeverity(severity), format, args);
-  }
-
- private:
-  static tflite::LogSeverity ConvertSeverity(LiteRtLogSeverity severity) {
-    return static_cast<tflite::LogSeverity>(severity);
-  }
-
-  static LiteRtLogSeverity ConvertSeverity(tflite::LogSeverity severity) {
-    return static_cast<LiteRtLogSeverity>(severity);
-  }
-};
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-LiteRtStatus LiteRtCreateLogger(LiteRtLogger* logger) {
-  if (!logger) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *logger = new LiteRtLoggerT;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetMinLoggerSeverity(LiteRtLogger logger,
-                                        LiteRtLogSeverity* min_severity) {
-  if (!logger || !min_severity) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *min_severity = logger->GetMinSeverity();
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtSetMinLoggerSeverity(LiteRtLogger logger,
-                                        LiteRtLogSeverity min_severity) {
-  if (!logger) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  logger->SetMinSeverity(min_severity);
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtLoggerLog(LiteRtLogger logger, LiteRtLogSeverity severity,
-                             const char* format, ...)
{ - if (!logger || !format) { - return kLiteRtStatusErrorInvalidArgument; - } - va_list args; - va_start(args, format); - logger->Log(severity, format, args); - va_end(args); - return kLiteRtStatusOk; -} - -void LiteRtDestroyLogger(LiteRtLogger logger) { - if (logger != nullptr) { - delete logger; - } -} - -#ifdef __cplusplus -} // extern "C" -#endif - -namespace { -LiteRtLoggerT StaticLogger; -LiteRtLogger DefaultLogger = &StaticLogger; -} // namespace - -LiteRtStatus LiteRtSetDefaultLogger(LiteRtLogger logger) { - if (!logger) { - return kLiteRtStatusErrorInvalidArgument; - } - DefaultLogger = logger; - return kLiteRtStatusOk; -} - -LiteRtLogger LiteRtGetDefaultLogger() { return DefaultLogger; } diff --git a/tensorflow/lite/experimental/litert/c/litert_logging.h b/tensorflow/lite/experimental/litert/c/litert_logging.h deleted file mode 100644 index 4570e76327b7f9..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_logging.h +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtLogger); - -// WARNING: The values of the following enum are to be kept in sync with -// tflite::LogSeverity. -typedef enum { - kLiteRtLogSeverityVerbose = 0, - kLiteRtLogSeverityInfo = 1, - kLiteRtLogSeverityWarning = 2, - kLiteRtLogSeverityError = 3, - kLiteRtLogSeveritySilent = 4, -} LiteRtLogSeverity; - -#define LITERT_VERBOSE kLiteRtLogSeverityVerbose -#define LITERT_INFO kLiteRtLogSeverityInfo -#define LITERT_WARNING kLiteRtLogSeverityWarning -#define LITERT_ERROR kLiteRtLogSeverityError -#define LITERT_SILENT kLiteRtLogSeveritySilent - -LiteRtStatus LiteRtCreateLogger(LiteRtLogger* logger); -LiteRtStatus LiteRtGetMinLoggerSeverity(LiteRtLogger logger, - LiteRtLogSeverity* min_severity); -LiteRtStatus LiteRtSetMinLoggerSeverity(LiteRtLogger logger, - LiteRtLogSeverity min_severity); -LiteRtStatus LiteRtLoggerLog(LiteRtLogger logger, LiteRtLogSeverity severity, - const char* format, ...); -void LiteRtDestroyLogger(LiteRtLogger logger); - -LiteRtLogger LiteRtGetDefaultLogger(); -LiteRtStatus LiteRtSetDefaultLogger(LiteRtLogger logger); -LiteRtStatus LiteRtDefaultLoggerLog(LiteRtLogSeverity severity, - const char* format, ...); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#define LITERT_LOGGER_LOG_PROD(logger, severity, format, ...) 
\ - { \ - LiteRtLogSeverity __min_severity__; \ - if (LiteRtGetMinLoggerSeverity(logger, &__min_severity__) != \ - kLiteRtStatusOk) { \ - __min_severity__ = kLiteRtLogSeverityVerbose; \ - } \ - if (severity >= __min_severity__) { \ - LiteRtLoggerLog(logger, severity, "[%s:%d] " format, __FILE__, __LINE__, \ - ##__VA_ARGS__); \ - } \ - } - -#ifndef NDEBUG -#define LITERT_LOGGER_LOG LITERT_LOGGER_LOG_PROD -#else -#define LITERT_LOGGER_LOG(logger, severity, format, ...) \ - do { \ - LITERT_LOGGER_LOG_PROD(logger, severity, format, ##__VA_ARGS__); \ - } while (false) -#endif - -#define LITERT_LOG(severity, format, ...) \ - LITERT_LOGGER_LOG(LiteRtGetDefaultLogger(), severity, format, ##__VA_ARGS__); - -#define LITERT_ABORT abort() - -#define LITERT_FATAL(format, ...) \ - do { \ - LITERT_LOG(kLiteRtLogSeverityError, format, ##__VA_ARGS__) \ - LITERT_ABORT; \ - } while (0) - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_logging_test.cc b/tensorflow/lite/experimental/litert/c/litert_logging_test.cc deleted file mode 100644 index 148fc778f18915..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_logging_test.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -TEST(Layout, Creation) { - LiteRtLogger logger; - ASSERT_EQ(LiteRtCreateLogger(&logger), kLiteRtStatusOk); - LiteRtDestroyLogger(logger); -} - -TEST(Layout, MinLogging) { - LiteRtLogger logger; - ASSERT_EQ(LiteRtCreateLogger(&logger), kLiteRtStatusOk); - ASSERT_EQ(LiteRtSetMinLoggerSeverity(logger, LITERT_SILENT), kLiteRtStatusOk); - LiteRtLogSeverity min_severity; - ASSERT_EQ(LiteRtGetMinLoggerSeverity(logger, &min_severity), kLiteRtStatusOk); - ASSERT_EQ(min_severity, LITERT_SILENT); - LiteRtDestroyLogger(logger); -} diff --git a/tensorflow/lite/experimental/litert/c/litert_model.cc b/tensorflow/lite/experimental/litert/c/litert_model.cc deleted file mode 100644 index af83f970c94c2f..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_model.cc +++ /dev/null @@ -1,506 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
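Tying the logging pieces above together, a short sketch of how the logger and the LITERT_LOG macro are intended to interact (values illustrative):

  LiteRtLogger logger;
  LiteRtCreateLogger(&logger);
  LiteRtSetMinLoggerSeverity(logger, kLiteRtLogSeverityWarning);
  LiteRtSetDefaultLogger(logger);

  // Routed through the default logger; the INFO line is filtered out because
  // it is below the WARNING threshold set above.
  LITERT_LOG(LITERT_INFO, "loaded %d subgraphs", 1);
  LITERT_LOG(LITERT_ERROR, "unexpected status: %d", 42);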
- -#include "tensorflow/lite/experimental/litert/c/litert_model.h" - -#include -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// -// Model -// - -LiteRtStatus LiteRtCreateModelFromFile(const char* filename, - LiteRtModel* model) { - if (!filename || !model) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto new_model = litert::internal::LoadModelFromFile(filename); - if (!new_model) { - return new_model.Error().Status(); - } - *model = new_model->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateModelFromBuffer(const void* buffer_addr, - size_t buffer_size, - LiteRtModel* model) { - if (!buffer_addr || !buffer_size || !model) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto new_model = litert::internal::LoadModelFromBuffer( - litert::BufferRef(buffer_addr, buffer_size)); - if (!new_model) { - return new_model.Error().Status(); - } - *model = new_model->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumModelSubgraphs(LiteRtModel model, - LiteRtParamIndex* num_subgraphs) { - if (model == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_subgraphs = model->Subgraphs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetModelSubgraph(LiteRtModel model, - LiteRtParamIndex subgraph_index, - LiteRtSubgraph* subgraph) { - if (model == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - if (subgraph_index >= model->Subgraphs().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *subgraph = &model->Subgraph(subgraph_index); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetMainModelSubgraphIndex( - LiteRtModel model, LiteRtParamIndex* main_subgraph_index) { - if (!model || !main_subgraph_index) { - return kLiteRtStatusErrorInvalidArgument; - } - *main_subgraph_index = LiteRtModelT::kMainSubgraphIndex; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetModelMetadata(LiteRtModel model, const char* metadata_key, - const void** metadata_buffer, - size_t* metadata_buffer_size) { - if (!model || !metadata_key || !metadata_buffer || !metadata_buffer_size) { - return kLiteRtStatusErrorInvalidArgument; - } - auto m_buf = model->FindMetadata(metadata_key); - if (!m_buf) { - return m_buf.Error().Status(); - } - *metadata_buffer = m_buf->Data(); - *metadata_buffer_size = m_buf->Size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumModelSignatures(LiteRtModel model, - LiteRtParamIndex* num_signatures) { - if (!model || !num_signatures) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_signatures = model->Signatures().size(); - return kLiteRtStatusOk; -} - -// Get the signature at the given index in the model -LiteRtStatus LiteRtGetModelSignature(LiteRtModel model, - LiteRtParamIndex signature_index, - LiteRtSignature* signature) { - if (!model || !signature) { - return kLiteRtStatusErrorInvalidArgument; - } - if (signature_index >= model->Signatures().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *signature = model->Signatures().at(signature_index); - return kLiteRtStatusOk; -} - -void 
LiteRtDestroyModel(LiteRtModel model) { delete model; } - -LiteRtStatus LiteRtSerializeModel(LiteRtModel model, uint8_t** buf, - size_t* size, size_t* offset, - bool destroy_model, - LiteRtModelSerializationOptions options) { - auto serialized = litert::internal::SerializeModel( - std::move(*model), options.bytecode_alignment); - // Even if we fail to serialize, we still need to destroy the model if - // requested. This is because the model may have been partially serialized - // and we don't want to leak memory. Also if ownership of the model is - // transferred to the caller, we need to ensure that the model is destroyed - // when the caller is done with it. - if (destroy_model) { - delete model; - } - if (!serialized) { - return serialized.Error().Status(); - } - std::tie(*buf, *size, *offset) = serialized->Release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtPushOp(LiteRtOpList op_list, LiteRtOp op, - LiteRtParamIndex index) { - if (!op_list || !op) { - return kLiteRtStatusErrorInvalidArgument; - } - op_list->Push(op, index); - return kLiteRtStatusOk; -} - -// -// Signature -// - -LiteRtStatus LiteRtGetDefaultSignatureKey(const char** signature_key) { - if (!signature_key) { - return kLiteRtStatusErrorInvalidArgument; - } - *signature_key = LiteRtSignatureT::kDefaultSignatureKey.data(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSignatureKey(LiteRtSignature signature, - const char** signature_key) { - if (!signature || !signature_key) { - return kLiteRtStatusErrorInvalidArgument; - } - *signature_key = signature->Key().data(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSignatureSubgraph(LiteRtSignature signature, - LiteRtSubgraph* subgraph) { - if (signature == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *subgraph = &signature->GetSubgraph(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumSignatureInputs(LiteRtSignature signature, - LiteRtParamIndex* num_inputs) { - if (!signature || !num_inputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_inputs = signature->InputNames().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSignatureInputName(LiteRtSignature signature, - LiteRtParamIndex input_idx, - const char** input_name) { - if (!signature || !input_name) { - return kLiteRtStatusErrorInvalidArgument; - } - if (input_idx >= signature->InputNames().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *input_name = signature->InputNames().at(input_idx).data(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumSignatureOutputs(LiteRtSignature signature, - LiteRtParamIndex* num_outputs) { - if (!signature || !num_outputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_outputs = signature->OutputNames().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature, - LiteRtParamIndex output_idx, - const char** output_name) { - if (!signature || !output_name) { - return kLiteRtStatusErrorInvalidArgument; - } - if (output_idx >= signature->OutputNames().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *output_name = signature->OutputNames().at(output_idx).data(); - return kLiteRtStatusOk; -} - -// -// Subgraph -// - -LiteRtStatus LiteRtGetNumSubgraphInputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_inputs) { - if (!subgraph || !num_inputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_inputs = subgraph->Inputs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSubgraphInput(LiteRtSubgraph subgraph, - 
LiteRtParamIndex input_index, - LiteRtTensor* input) { - if (!subgraph || !input) { - return kLiteRtStatusErrorInvalidArgument; - } else if (input_index < 0 || input_index >= subgraph->Inputs().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *input = subgraph->Inputs()[input_index]; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumSubgraphOutputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_outputs) { - if (!subgraph || !num_outputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_outputs = subgraph->Outputs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSubgraphOutput(LiteRtSubgraph subgraph, - LiteRtParamIndex output_index, - LiteRtTensor* output) { - if (!subgraph || !output) { - return kLiteRtStatusErrorInvalidArgument; - } else if (output_index < 0 || output_index >= subgraph->Outputs().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *output = subgraph->Outputs()[output_index]; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumSubgraphOps(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_ops) { - if (!subgraph || !num_ops) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_ops = subgraph->Ops().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSubgraphOp(LiteRtSubgraph subgraph, - LiteRtParamIndex op_index, LiteRtOp* op) { - if (!subgraph || !op) { - return kLiteRtStatusErrorInvalidArgument; - } else if (op_index < 0 || op_index >= subgraph->Ops().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *op = subgraph->Ops()[op_index]; - return kLiteRtStatusOk; -} - -// -// Op -// - -LiteRtStatus LiteRtGetOpCode(LiteRtOp op, LiteRtOpCode* code) { - if (!op || !code) { - return kLiteRtStatusErrorInvalidArgument; - } - *code = op->OpCode(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs) { - if (!op || !num_inputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_inputs = op->Inputs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetOpInput(LiteRtOp op, LiteRtParamIndex input_index, - LiteRtTensor* input) { - if (!op || !input) { - return kLiteRtStatusErrorInvalidArgument; - } else if (input_index < 0 || input_index >= op->Inputs().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *input = op->Inputs()[input_index]; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs) { - if (!op || !num_outputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_outputs = op->Outputs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetOpOutput(LiteRtOp op, LiteRtParamIndex output_index, - LiteRtTensor* output) { - if (!op || !output) { - return kLiteRtStatusErrorInvalidArgument; - } else if (output_index < 0 || output_index >= op->Outputs().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *output = op->Outputs()[output_index]; - return kLiteRtStatusOk; -} - -// -// Weights -// - -LiteRtStatus LiteRtGetWeightsBytes(LiteRtWeights weights, const void** addr, - size_t* size) { - if (!weights || !addr || !size) { - return kLiteRtStatusErrorInvalidArgument; - } - *addr = weights->Buffer().Data(); - *size = weights->Buffer().Size(); - return kLiteRtStatusOk; -} - -// -// Tensor -// - -LiteRtStatus LiteRtGetTensorWeights(LiteRtTensor tensor, - LiteRtWeights* weights) { - if (!tensor || !weights) { - return kLiteRtStatusErrorInvalidArgument; - } - *weights = &tensor->Weights(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumTensorUses(LiteRtTensor tensor, - 
LiteRtParamIndex* num_uses) { - if (!tensor || !num_uses) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_uses = tensor->Users().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorUse(LiteRtTensor tensor, LiteRtParamIndex use_index, - LiteRtOp* user, - LiteRtParamIndex* user_arg_index) { - if (!tensor || !user || !user_arg_index) { - return kLiteRtStatusErrorInvalidArgument; - } else if (use_index < 0 || use_index >= tensor->Users().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *user = tensor->Users()[use_index]; - *user_arg_index = tensor->UserArgInds()[use_index]; - return kLiteRtStatusOk; -} - -// Null if subgraph input or constant. -LiteRtStatus LiteRtGetTensorDefiningOp(LiteRtTensor tensor, - bool* has_defining_op, - LiteRtTensorDefiningOp* defining_op) { - if (!tensor || !has_defining_op || !defining_op) { - return kLiteRtStatusErrorInvalidArgument; - } - if (tensor->DefiningOp() != nullptr) { - *has_defining_op = true; - defining_op->op = tensor->DefiningOp(); - defining_op->op_output_index = tensor->DefiningOpOutInd(); - } else { - *has_defining_op = false; - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorTypeId(LiteRtTensor tensor, - LiteRtTensorTypeId* type_id) { - if (!tensor || !type_id) { - return kLiteRtStatusErrorInvalidArgument; - } - *type_id = tensor->Type().first; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetUnrankedTensorType( - LiteRtTensor tensor, LiteRtUnrankedTensorType* unranked_tensor_type) { - if (!tensor || !unranked_tensor_type) { - return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->Type().first != kLiteRtUnrankedTensorType) { - return kLiteRtStatusErrorInvalidIrType; - } - *unranked_tensor_type = tensor->Type().second.unranked_tensor_type; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetRankedTensorType( - LiteRtTensor tensor, LiteRtRankedTensorType* ranked_tensor_type) { - if (!tensor || !ranked_tensor_type) { - return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->Type().first != kLiteRtRankedTensorType) { - return kLiteRtStatusErrorInvalidIrType; - } - *ranked_tensor_type = tensor->Type().second.ranked_tensor_type; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorName(LiteRtTensor tensor, const char** name) { - if (!tensor || !name) { - return kLiteRtStatusErrorInvalidArgument; - } - *name = tensor->Name().data(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetQuantizationTypeId(LiteRtTensor tensor, - LiteRtQuantizationTypeId* q_type_id) { - if (!tensor || !q_type_id) { - return kLiteRtStatusErrorInvalidArgument; - } - *q_type_id = tensor->Qparams().first; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetPerTensorQuantization( - LiteRtTensor tensor, LiteRtQuantizationPerTensor* per_tensor_quantization) { - if (!tensor || !per_tensor_quantization) { - return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->Qparams().first != kLiteRtQuantizationPerTensor) { - return kLiteRtStatusErrorInvalidIrType; - } - auto& per_tensor = tensor->Qparams().second.per_tensor; - per_tensor_quantization->scale = per_tensor.scale; - per_tensor_quantization->zero_point = per_tensor.zero_point; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetPerChannelQuantization( - LiteRtTensor tensor, - LiteRtQuantizationPerChannel* per_channel_quantization) { - if (!tensor || !per_channel_quantization) { - return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->Qparams().first != kLiteRtQuantizationPerChannel) { - return kLiteRtStatusErrorInvalidIrType; - } - auto& 
per_channel = tensor->Qparams().second.per_channel; - per_channel_quantization->scales = per_channel.scales; - per_channel_quantization->zero_points = per_channel.zero_points; - per_channel_quantization->num_channels = per_channel.num_channels; - per_channel_quantization->quantized_dimension = - per_channel.quantized_dimension; - return kLiteRtStatusOk; -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_model.h b/tensorflow/lite/experimental/litert/c/litert_model.h deleted file mode 100644 index ba55d759e23e77..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_model.h +++ /dev/null @@ -1,372 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ - -#include // NOLINT: To use bool type in C -#include -#include - -#include "tensorflow/lite/core/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// -// Handles + Common -// - -// Constant data behind a tensor stored in the model. -LITERT_DEFINE_HANDLE(LiteRtWeights); - -// Values/edges of the models graph. -LITERT_DEFINE_HANDLE(LiteRtTensor); - -// Operations/nodes of the models graph. -LITERT_DEFINE_HANDLE(LiteRtOp); - -// Fundamental block of program, i.e. a function body. -LITERT_DEFINE_HANDLE(LiteRtSubgraph); - -// Signature of the model. -LITERT_DEFINE_HANDLE(LiteRtSignature); - -// A collection of subgraph + metadata + signature. -LITERT_DEFINE_HANDLE(LiteRtModel); - -// Append only list of ops. -LITERT_DEFINE_HANDLE(LiteRtOpList); - -// -// LiteRtTensor + Types -// - -// Get the string name associated with this tensor. This is an optional -// attribute and if not set will return a zero-length string. -LiteRtStatus LiteRtGetTensorName(LiteRtTensor tensor, const char** name); - -// TENSOR TYPES - -// Primitive types for elements in a tensor. 
-typedef enum {
-  kLiteRtElementTypeNone = kTfLiteNoType,
-  kLiteRtElementTypeBool = kTfLiteBool,
-  kLiteRtElementTypeInt4 = kTfLiteInt4,
-  kLiteRtElementTypeInt8 = kTfLiteInt8,
-  kLiteRtElementTypeInt16 = kTfLiteInt16,
-  kLiteRtElementTypeInt32 = kTfLiteInt32,
-  kLiteRtElementTypeInt64 = kTfLiteInt64,
-  kLiteRtElementTypeUInt8 = kTfLiteUInt8,
-  kLiteRtElementTypeUInt16 = kTfLiteUInt16,
-  kLiteRtElementTypeUInt32 = kTfLiteUInt32,
-  kLiteRtElementTypeUInt64 = kTfLiteUInt64,
-  kLiteRtElementTypeFloat16 = kTfLiteFloat16,
-  kLiteRtElementTypeBFloat16 = kTfLiteBFloat16,
-  kLiteRtElementTypeFloat32 = kTfLiteFloat32,
-  kLiteRtElementTypeFloat64 = kTfLiteFloat64,
-  kLiteRtElementTypeComplex64 = kTfLiteComplex64,
-  kLiteRtElementTypeComplex128 = kTfLiteComplex128,
-  kLiteRtElementTypeTfResource = kTfLiteResource,
-  kLiteRtElementTypeTfString = kTfLiteString,
-  kLiteRtElementTypeTfVariant = kTfLiteVariant,
-} LiteRtElementType;
-
-// Tensor whose rank is dynamic.
-typedef struct {
-  // The primitive element type of the constituent data.
-  LiteRtElementType element_type;
-} LiteRtUnrankedTensorType;
-
-// Tensor whose rank is static but dimensions may be dynamic.
-typedef struct {
-  // The primitive element type of the constituent data.
-  LiteRtElementType element_type;
-
-  // Shape information.
-  LiteRtLayout layout;
-} LiteRtRankedTensorType;
-
-// The identifier for tensor type union.
-typedef enum {
-  // Type with fixed rank and possibly dynamic dimensions.
-  kLiteRtRankedTensorType = 0,
-
-  // Type with dynamic rank.
-  kLiteRtUnrankedTensorType = 1,
-} LiteRtTensorTypeId;
-
-// Get type identifier from tensor.
-LiteRtStatus LiteRtGetTensorTypeId(LiteRtTensor tensor,
-                                   LiteRtTensorTypeId* type_id);
-
-// Get unranked tensor type info, return bad status if not unranked.
-LiteRtStatus LiteRtGetUnrankedTensorType(
-    LiteRtTensor tensor, LiteRtUnrankedTensorType* unranked_tensor_type);
-
-// Get ranked tensor type info, return bad status if not ranked.
-LiteRtStatus LiteRtGetRankedTensorType(
-    LiteRtTensor tensor, LiteRtRankedTensorType* ranked_tensor_type);
-
-// QUANTIZATION
-
-// Schema for tensors quantized with one set of q-params.
-typedef struct {
-  // Scaling factor.
-  float scale;
-
-  // The value that float:0 maps to in q-space.
-  int64_t zero_point;
-} LiteRtQuantizationPerTensor;
-
-// Schema for tensors quantized with one set of q-params per channel.
-typedef struct {
-  int32_t quantized_dimension;
-  uint64_t num_channels;
-  float* scales;
-  int64_t* zero_points;
-} LiteRtQuantizationPerChannel;
-
-// The identifier for quantization scheme type union.
-typedef enum {
-  // Tag for tensors without quantization.
-  kLiteRtQuantizationNone = 0,
-
-  // Basic quantization, one set of q-params per tensor.
-  kLiteRtQuantizationPerTensor = 1,
-
-  // [NOT IMPLEMENTED YET] Q-params for each element across a single dimension.
-  kLiteRtQuantizationPerChannel = 2,
-
-  // [NOT IMPLEMENTED YET] Q-params across blocks of fixed size (e.g. 2048).
-  kLiteRtQuantizationBlockWise = 3,
-} LiteRtQuantizationTypeId;
-
-// Get the identifier for the type of quantization for a given tensor.
-LiteRtStatus LiteRtGetQuantizationTypeId(LiteRtTensor tensor,
-                                         LiteRtQuantizationTypeId* q_type_id);
-
-// Get the per-tensor quantization information for a given tensor if it has it.
-LiteRtStatus LiteRtGetPerTensorQuantization(
-    LiteRtTensor tensor, LiteRtQuantizationPerTensor* per_tensor_quantization);
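The type and quantization accessors above are typically chained off the tag. A sketch, assuming `tensor` is a valid LiteRtTensor:

  LiteRtTensorTypeId type_id;
  LiteRtGetTensorTypeId(tensor, &type_id);
  if (type_id == kLiteRtRankedTensorType) {
    LiteRtRankedTensorType ranked;
    LiteRtGetRankedTensorType(tensor, &ranked);
    // ranked.layout.rank and ranked.layout.dimensions describe the shape;
    // dimensions < 0 are dynamic.
  }

  LiteRtQuantizationTypeId q_id;
  LiteRtGetQuantizationTypeId(tensor, &q_id);
  if (q_id == kLiteRtQuantizationPerTensor) {
    LiteRtQuantizationPerTensor q;
    LiteRtGetPerTensorQuantization(tensor, &q);
    // real_value = (quantized_value - q.zero_point) * q.scale
  }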
-// Get the per-channel quantization information for a given tensor if it has
-// it.
-LiteRtStatus LiteRtGetPerChannelQuantization(
-    LiteRtTensor tensor,
-    LiteRtQuantizationPerChannel* per_channel_quantization);
-
-// EDGES
-
-// Information about the op that defines a tensor.
-typedef struct LiteRtTensorDefiningOp {
-  // The defining op itself.
-  LiteRtOp op;
-
-  // The op output index that defines the specific tensor.
-  LiteRtParamIndex op_output_index;
-} LiteRtTensorDefiningOp;
-
-// Information about a reference to a tensor in the graph.
-typedef struct LiteRtTensorUserOp {
-  // The referring op itself.
-  LiteRtOp op;
-
-  // Index of which operand the op refers to a specific tensor on.
-  LiteRtParamIndex op_input_index;
-} LiteRtTensorUserOp;
-
-// Get all the ops that reference a given tensor, and at what operand index.
-LiteRtStatus LiteRtGetNumTensorUses(LiteRtTensor tensor,
-                                    LiteRtParamIndex* num_uses);
-LiteRtStatus LiteRtGetTensorUse(LiteRtTensor tensor, LiteRtParamIndex use_index,
-                                LiteRtOp* user,
-                                LiteRtParamIndex* user_arg_index);
-
-// Get the op that defines this tensor and the corresponding output index. If
-// tensor is a subgraph input, has_defining_op will be false.
-LiteRtStatus LiteRtGetTensorDefiningOp(LiteRtTensor tensor,
-                                       bool* has_defining_op,
-                                       LiteRtTensorDefiningOp* defining_op);
-
-// WEIGHTS (constant data)
-
-// Get static weights associated with a given tensor. All tensors have weights;
-// null weights have size = 0.
-LiteRtStatus LiteRtGetTensorWeights(LiteRtTensor tensor,
-                                    LiteRtWeights* weights);
-
-//
-// LiteRtWeights
-//
-
-// Get opaque array from given tensor weights.
-LiteRtStatus LiteRtGetWeightsBytes(LiteRtWeights weights, const void** addr,
-                                   size_t* size);
-
-//
-// LiteRtOp
-//
-
-// Get code corresponding to operation type for given op.
-LiteRtStatus LiteRtGetOpCode(LiteRtOp op, LiteRtOpCode* code);
-
-// Get input tensors of given op.
-LiteRtStatus LiteRtGetNumOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs);
-LiteRtStatus LiteRtGetOpInput(LiteRtOp op, LiteRtParamIndex input_index,
-                              LiteRtTensor* input);
-
-// Get output tensors of given op.
-LiteRtStatus LiteRtGetNumOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs);
-LiteRtStatus LiteRtGetOpOutput(LiteRtOp op, LiteRtParamIndex output_index,
-                               LiteRtTensor* output);
-
-//
-// LiteRtSubgraph
-//
-
-// Get input tensors for given subgraph.
-LiteRtStatus LiteRtGetNumSubgraphInputs(LiteRtSubgraph subgraph,
-                                        LiteRtParamIndex* num_inputs);
-LiteRtStatus LiteRtGetSubgraphInput(LiteRtSubgraph subgraph,
-                                    LiteRtParamIndex input_index,
-                                    LiteRtTensor* input);
-
-// Get output tensors for given subgraph.
-LiteRtStatus LiteRtGetNumSubgraphOutputs(LiteRtSubgraph subgraph,
-                                         LiteRtParamIndex* num_outputs);
-LiteRtStatus LiteRtGetSubgraphOutput(LiteRtSubgraph subgraph,
-                                     LiteRtParamIndex output_index,
-                                     LiteRtTensor* output);
-
-// Get all ops in given subgraph in a topological order.
-LiteRtStatus LiteRtGetNumSubgraphOps(LiteRtSubgraph subgraph,
-                                     LiteRtParamIndex* num_ops);
-LiteRtStatus LiteRtGetSubgraphOp(LiteRtSubgraph subgraph,
-                                 LiteRtParamIndex op_index, LiteRtOp* op);
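
Purely as an illustration of how the op and subgraph getters above compose (this block is not from the header), a caller might walk a subgraph op by op; the helper name VisitSubgraphOps is hypothetical and the error handling is abbreviated.

// Illustrative sketch only. Walks every op of a subgraph in topological order.
void VisitSubgraphOps(LiteRtSubgraph subgraph) {
  LiteRtParamIndex num_ops;
  if (LiteRtGetNumSubgraphOps(subgraph, &num_ops) != kLiteRtStatusOk) return;

  for (LiteRtParamIndex i = 0; i < num_ops; ++i) {
    LiteRtOp op;
    if (LiteRtGetSubgraphOp(subgraph, i, &op) != kLiteRtStatusOk) continue;

    LiteRtOpCode code;
    LiteRtGetOpCode(op, &code);  // e.g. kLiteRtOpCodeTflAdd

    LiteRtParamIndex num_inputs;
    LiteRtGetNumOpInputs(op, &num_inputs);
    for (LiteRtParamIndex j = 0; j < num_inputs; ++j) {
      LiteRtTensor input;
      LiteRtGetOpInput(op, j, &input);
      // Each input can be inspected with the tensor getters above.
    }
  }
}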
-//
-// LiteRtSignature
-//
-
-// Default signature key. This is the key that is used if the model does not
-// define any signatures.
-LiteRtStatus LiteRtGetDefaultSignatureKey(const char** signature_key);
-
-// Get the signature key string defined in the model.
-LiteRtStatus LiteRtGetSignatureKey(LiteRtSignature signature,
-                                   const char** signature_key);
-
-// Get the associated subgraph for the given signature.
-LiteRtStatus LiteRtGetSignatureSubgraph(LiteRtSignature signature,
-                                        LiteRtSubgraph* subgraph);
-
-// Get the number of inputs for the given signature.
-LiteRtStatus LiteRtGetNumSignatureInputs(LiteRtSignature signature,
-                                         LiteRtParamIndex* num_inputs);
-
-// Get the name of the i-th input tensor for the given signature.
-LiteRtStatus LiteRtGetSignatureInputName(LiteRtSignature signature,
-                                         LiteRtParamIndex input_idx,
-                                         const char** input_name);
-
-// Get the number of outputs for the given signature.
-LiteRtStatus LiteRtGetNumSignatureOutputs(LiteRtSignature signature,
-                                          LiteRtParamIndex* num_outputs);
-
-// Get the name of the i-th output tensor for the given signature.
-LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature,
-                                          LiteRtParamIndex output_idx,
-                                          const char** output_name);
-
-//
-// LiteRtModel
-//
-
-LiteRtStatus LiteRtCreateModelFromFile(const char* filename,
-                                       LiteRtModel* model);
-
-LiteRtStatus LiteRtCreateModelFromBuffer(const void* buffer_addr,
-                                         size_t buffer_size,
-                                         LiteRtModel* model);
-
-// Get the metadata buffer associated with given key if it exists.
-LiteRtStatus LiteRtGetModelMetadata(LiteRtModel model, const char* metadata_key,
-                                    const void** metadata_buffer,
-                                    size_t* metadata_buffer_size);
-
-// Get the index of the entry subgraph.
-// TODO: b/365299994 - Figure out signatures.
-LiteRtStatus LiteRtGetMainModelSubgraphIndex(
-    LiteRtModel model, LiteRtParamIndex* main_subgraph_index);
-
-// Get number of subgraphs in model.
-LiteRtStatus LiteRtGetNumModelSubgraphs(LiteRtModel model,
-                                        LiteRtParamIndex* num_subgraphs);
-
-// Get subgraph at given index in model.
-LiteRtStatus LiteRtGetModelSubgraph(LiteRtModel model,
-                                    LiteRtParamIndex subgraph_index,
-                                    LiteRtSubgraph* subgraph);
-
-// Get the number of signatures defined in the model.
-LiteRtStatus LiteRtGetNumModelSignatures(LiteRtModel model,
-                                         LiteRtParamIndex* num_signatures);
-
-// Get the signature at the given index in the model.
-LiteRtStatus LiteRtGetModelSignature(LiteRtModel model,
-                                     LiteRtParamIndex signature_index,
-                                     LiteRtSignature* signature);
-
-// Destroy the given model, freeing any memory it owns.
-void LiteRtDestroyModel(LiteRtModel model);
-
-//
-// Utility Types
-//
-
-// An append-only list of ops.
-LiteRtStatus LiteRtPushOp(LiteRtOpList op_list, LiteRtOp op,
-                          LiteRtParamIndex partition_index);
-
-//
-// Serialization related functions
-//
-
-// Options for model serialization.
-typedef struct LiteRtModelSerializationOptions {
-  // Alignment for bytecode assets that are appended to the model.
-  // Alignment is enforced relative to the first byte of the flatbuffer.
-  size_t bytecode_alignment;
-} LiteRtModelSerializationOptions;
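
To show how the model-level calls above chain together, here is a hedged sketch; the path "model.tflite" and the helper name LoadAndVisit are placeholders, and only functions declared in this header are used.

// Illustrative sketch only: load a model, visit its subgraphs, clean up.
void LoadAndVisit(void) {
  LiteRtModel model;
  if (LiteRtCreateModelFromFile("model.tflite", &model) != kLiteRtStatusOk) {
    return;  // Real code would surface the error.
  }

  LiteRtParamIndex num_subgraphs;
  if (LiteRtGetNumModelSubgraphs(model, &num_subgraphs) == kLiteRtStatusOk) {
    for (LiteRtParamIndex i = 0; i < num_subgraphs; ++i) {
      LiteRtSubgraph subgraph;
      if (LiteRtGetModelSubgraph(model, i, &subgraph) == kLiteRtStatusOk) {
        // Traverse with the subgraph getters declared earlier.
      }
    }
  }

  // The model owns its subgraphs, so destroying it also invalidates them.
  LiteRtDestroyModel(model);
}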
-// Serializes model to valid tflite flatbuffer bytes.
-//
-// This destroys the model before it returns unless destroy_model is false.
-// Caller takes ownership of `buf`. Flatbuffers are packed into their arrays
-// back to front, so the valid flatbuffer is buf[offset, size]. See the above
-// options for more details.
-LiteRtStatus LiteRtSerializeModel(LiteRtModel model, uint8_t** buf,
-                                  size_t* size, size_t* offset,
-                                  bool destroy_model,
-                                  LiteRtModelSerializationOptions options);
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_
diff --git a/tensorflow/lite/experimental/litert/c/litert_model_test.cc b/tensorflow/lite/experimental/litert/c/litert_model_test.cc
deleted file mode 100644
index 8f41902e8b70f1..00000000000000
--- a/tensorflow/lite/experimental/litert/c/litert_model_test.cc
+++ /dev/null
@@ -1,390 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-
-#include 
-#include 
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-#include "tensorflow/lite/experimental/litert/core/model/model.h"
-#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h"
-#include "tensorflow/lite/experimental/litert/test/matchers.h"
-
-namespace {
-
-using ::litert::BufferRef;
-using ::litert::OwningBufferRef;
-using ::testing::ElementsAreArray;
-using ::testing::litert::IsError;
-
-TEST(LiteRtWeightsTest, GetNullWeights) {
-  LiteRtWeightsT weights = {};
-
-  const void* addr;
-  size_t size;
-  LITERT_ASSERT_OK(LiteRtGetWeightsBytes(&weights, &addr, &size));
-
-  EXPECT_EQ(addr, nullptr);
-  EXPECT_EQ(size, 0);
-}
-
-TEST(LiteRtWeightsTest, GetWeights) {
-  static constexpr std::array kData = {1, 2, 3};
-  const uint8_t* kDataPtr = reinterpret_cast<const uint8_t*>(kData.data());
-  const auto kDataSize = kData.size() * sizeof(int32_t);
-
-  LiteRtWeightsT weights;
-  SetWeightsFromOwnedBuffer(weights,
-                            OwningBufferRef(kDataPtr, kDataSize));
-
-  const void* addr;
-  size_t size;
-  LITERT_ASSERT_OK(LiteRtGetWeightsBytes(&weights, &addr, &size));
-
-  EXPECT_NE(addr, nullptr);
-  EXPECT_EQ(size, 3 * sizeof(int32_t));
-
-  EXPECT_THAT(absl::MakeConstSpan(reinterpret_cast<const int32_t*>(addr), 3),
-              ElementsAreArray(kData));
-}
-
-TEST(LiteRtTensorTest, GetUnrankedType) {
-  static constexpr auto kElementType = kLiteRtElementTypeFloat32;
-  static constexpr auto kId = kLiteRtUnrankedTensorType;
-
-  TensorType type;
-  type.first = kId;
-  type.second.unranked_tensor_type.element_type = kElementType;
-
-  LiteRtTensorT tensor;
-  tensor.SetType(std::move(type));
-
-  LiteRtTensorTypeId id;
-  LITERT_ASSERT_OK(LiteRtGetTensorTypeId(&tensor, &id));
-  ASSERT_EQ(id, kId);
-
-  LiteRtUnrankedTensorType unranked;
-  LITERT_ASSERT_OK(LiteRtGetUnrankedTensorType(&tensor, &unranked));
-  EXPECT_EQ(unranked.element_type, kElementType);
-}
-
-TEST(LiteRtTensorTest, GetRankedTensorType) {
-  static constexpr auto kElementType = kLiteRtElementTypeFloat32;
-  static constexpr auto kId = kLiteRtRankedTensorType;
-
-  
LiteRtTensorT tensor; - tensor.SetType(MakeRankedTensorType(kElementType, {3, 3})); - - LiteRtTensorTypeId id; - LITERT_ASSERT_OK(LiteRtGetTensorTypeId(&tensor, &id)); - ASSERT_EQ(id, kId); - - LiteRtRankedTensorType ranked; - LITERT_ASSERT_OK(LiteRtGetRankedTensorType(&tensor, &ranked)); - EXPECT_EQ(ranked.element_type, kElementType); - ASSERT_EQ(ranked.layout.rank, 2); - EXPECT_THAT(absl::MakeConstSpan(ranked.layout.dimensions, 2), - ElementsAreArray({3, 3})); -} - -TEST(LiteRtTensorTest, GetUses) { - LiteRtTensorT tensor; - - LiteRtOpT user; - tensor.Users().push_back(&user); - tensor.UserArgInds().push_back(0); - - LiteRtOpT other_user; - tensor.Users().push_back(&other_user); - tensor.UserArgInds().push_back(1); - - LiteRtParamIndex num_uses; - LITERT_ASSERT_OK(LiteRtGetNumTensorUses(&tensor, &num_uses)); - ASSERT_EQ(num_uses, 2); - - LiteRtOp actual_user; - LiteRtParamIndex actual_user_arg_index; - LITERT_ASSERT_OK(LiteRtGetTensorUse(&tensor, /*use_index=*/0, &actual_user, - &actual_user_arg_index)); - ASSERT_EQ(actual_user, &user); - ASSERT_EQ(actual_user_arg_index, 0); - - LITERT_ASSERT_OK(LiteRtGetTensorUse(&tensor, /*use_index=*/1, &actual_user, - &actual_user_arg_index)); - ASSERT_EQ(actual_user, &other_user); - ASSERT_EQ(actual_user_arg_index, 1); -} - -TEST(LiteRtTensorTest, GetDefiningOp) { - LiteRtTensorT tensor; - - LiteRtOpT def_op; - tensor.SetDefiningOp(def_op, 0); - - LiteRtTensorDefiningOp actual_def_op; - bool has_defining_op; - LITERT_ASSERT_OK( - LiteRtGetTensorDefiningOp(&tensor, &has_defining_op, &actual_def_op)); - ASSERT_TRUE(has_defining_op); - EXPECT_EQ(actual_def_op.op, &def_op); - EXPECT_EQ(actual_def_op.op_output_index, 0); -} - -TEST(LiteRtTensorTest, NoDefiningOp) { - LiteRtTensorT tensor; - - LiteRtTensorDefiningOp actual_def_op; - bool has_defining_op; - LITERT_ASSERT_OK( - LiteRtGetTensorDefiningOp(&tensor, &has_defining_op, &actual_def_op)); - ASSERT_FALSE(has_defining_op); -} - -TEST(LiteRtTensorTest, Name) { - static constexpr const char kName[] = "foo"; - - LiteRtTensorT tensor; - tensor.SetName(std::string(kName)); - - const char* name; - LITERT_ASSERT_OK(LiteRtGetTensorName(&tensor, &name)); - EXPECT_STREQ(name, kName); -} - -TEST(LiteRtTensorTest, QuantizationNone) { - LiteRtTensorT tensor; - - LiteRtQuantizationTypeId q_type_id; - LITERT_ASSERT_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); - EXPECT_EQ(q_type_id, kLiteRtQuantizationNone); - - LiteRtQuantizationPerTensor per_tensor_quantization; - EXPECT_NE(LiteRtGetPerTensorQuantization(&tensor, &per_tensor_quantization), - kLiteRtStatusOk); -} - -TEST(LiteRtTensorTest, QuantizationPerTensor) { - static constexpr auto kScale = 1.0; - static constexpr auto kZeroPoint = 1; - - LiteRtTensorT tensor; - tensor.SetQarams(MakePerTensorQuantization(kScale, kZeroPoint)); - - LiteRtQuantizationTypeId q_type_id; - LITERT_ASSERT_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); - ASSERT_EQ(q_type_id, kLiteRtQuantizationPerTensor); - - LiteRtQuantizationPerTensor per_tensor_quantization; - LITERT_ASSERT_OK( - LiteRtGetPerTensorQuantization(&tensor, &per_tensor_quantization)); - - EXPECT_EQ(per_tensor_quantization.scale, kScale); - EXPECT_EQ(per_tensor_quantization.zero_point, kZeroPoint); -} - -TEST(LiteRtTensorTest, QuantizationPerChannel) { - static constexpr size_t kNumChannels = 2; - static constexpr size_t kQuantizedDimension = 0; - static constexpr float kScales[kNumChannels] = {1.0, 2.0}; - static constexpr int64_t kZps[kNumChannels] = {2, 3}; - - LiteRtTensorT tensor; - - { - auto 
per_channel = - MakePerChannelQuantization(kScales, kZps, kQuantizedDimension, tensor); - tensor.SetQarams(per_channel); - } - - LiteRtQuantizationTypeId q_type_id; - LITERT_ASSERT_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); - ASSERT_EQ(q_type_id, kLiteRtQuantizationPerChannel); - - LiteRtQuantizationPerChannel per_channel_quantization; - LITERT_ASSERT_OK( - LiteRtGetPerChannelQuantization(&tensor, &per_channel_quantization)); - - EXPECT_THAT( - absl::MakeConstSpan(per_channel_quantization.scales, kNumChannels), - testing::ElementsAreArray(kScales)); - EXPECT_THAT( - absl::MakeConstSpan(per_channel_quantization.zero_points, kNumChannels), - testing::ElementsAreArray(kZps)); - ASSERT_EQ(per_channel_quantization.num_channels, kNumChannels); - ASSERT_EQ(per_channel_quantization.quantized_dimension, kQuantizedDimension); -} - -TEST(LiteRtOpTest, GetOpCode) { - static constexpr auto kCode = kLiteRtOpCodeTflCustom; - - LiteRtOpT op; - op.SetOpCode(kCode); - - LiteRtOpCode code; - LITERT_ASSERT_OK(LiteRtGetOpCode(&op, &code)); - EXPECT_EQ(code, kCode); -} - -TEST(LiteRtOpTest, GetInputs) { - LiteRtTensorT input1; - LiteRtTensorT input2; - - LiteRtOpT op; - op.Inputs().push_back(&input1); - op.Inputs().push_back(&input2); - - LiteRtParamIndex num_inputs; - LITERT_ASSERT_OK(LiteRtGetNumOpInputs(&op, &num_inputs)); - ASSERT_EQ(num_inputs, 2); - - LiteRtTensor actual_input; - LITERT_ASSERT_OK(LiteRtGetOpInput(&op, /*input_index=*/0, &actual_input)); - EXPECT_EQ(actual_input, &input1); - - LITERT_ASSERT_OK(LiteRtGetOpInput(&op, /*input_index=*/1, &actual_input)); - EXPECT_EQ(actual_input, &input2); -} - -TEST(LiteRtOpTest, GetOutputs) { - LiteRtTensorT output1; - LiteRtTensorT output2; - - LiteRtOpT op; - op.Outputs().push_back(&output1); - op.Outputs().push_back(&output2); - - LiteRtParamIndex num_outputs; - LITERT_ASSERT_OK(LiteRtGetNumOpOutputs(&op, &num_outputs)); - ASSERT_EQ(num_outputs, 2); - - LiteRtTensor actual_output; - LITERT_ASSERT_OK(LiteRtGetOpOutput(&op, /*output_index=*/0, &actual_output)); - EXPECT_EQ(actual_output, &output1); - - LITERT_ASSERT_OK(LiteRtGetOpOutput(&op, /*output_index=*/1, &actual_output)); - EXPECT_EQ(actual_output, &output2); -} - -TEST(LiteRtSubgraphTest, GetInputs) { - LiteRtTensorT input1; - LiteRtTensorT input2; - - LiteRtSubgraphT subgraph; - subgraph.Inputs().push_back(&input1); - subgraph.Inputs().push_back(&input2); - - LiteRtParamIndex num_inputs; - LITERT_ASSERT_OK(LiteRtGetNumSubgraphInputs(&subgraph, &num_inputs)); - - LiteRtTensor actual_input; - LITERT_ASSERT_OK( - LiteRtGetSubgraphInput(&subgraph, /*input_index=*/0, &actual_input)); - EXPECT_EQ(actual_input, &input1); - - LITERT_ASSERT_OK( - LiteRtGetSubgraphInput(&subgraph, /*input_index=*/1, &actual_input)); - EXPECT_EQ(actual_input, &input2); -} - -TEST(LiteRtSubgraphTest, GetOutputs) { - LiteRtTensorT output1; - LiteRtTensorT output2; - - LiteRtSubgraphT subgraph; - subgraph.Outputs().push_back(&output1); - subgraph.Outputs().push_back(&output2); - - LiteRtParamIndex num_outputs; - LITERT_ASSERT_OK(LiteRtGetNumSubgraphOutputs(&subgraph, &num_outputs)); - - LiteRtTensor actual_output; - LITERT_ASSERT_OK( - LiteRtGetSubgraphOutput(&subgraph, /*output_index=*/0, &actual_output)); - EXPECT_EQ(actual_output, &output1); - - LITERT_ASSERT_OK( - LiteRtGetSubgraphOutput(&subgraph, /*output_index=*/1, &actual_output)); - EXPECT_EQ(actual_output, &output2); -} - -TEST(LiteRtSubgraphTest, GetOps) { - LiteRtSubgraphT subgraph; - auto& op1 = subgraph.EmplaceOp(); - auto& op2 = subgraph.EmplaceOp(); - 
- LiteRtParamIndex num_ops; - LITERT_ASSERT_OK(LiteRtGetNumSubgraphOps(&subgraph, &num_ops)); - ASSERT_EQ(num_ops, 2); - - LiteRtOp actual_op; - LITERT_ASSERT_OK(LiteRtGetSubgraphOp(&subgraph, /*op_index=*/0, &actual_op)); - ASSERT_EQ(actual_op, &op1); - - LITERT_ASSERT_OK(LiteRtGetSubgraphOp(&subgraph, /*op_index=*/1, &actual_op)); - ASSERT_EQ(actual_op, &op2); -} - -TEST(LiteRtModelTest, GetMetadata) { - static constexpr absl::string_view kKey = "KEY"; - static constexpr absl::string_view kData = "DATA"; - - LiteRtModelT model; - model.PushMetadata(kKey, kData); - - const void* metadata; - size_t metadata_size; - LITERT_ASSERT_OK( - LiteRtGetModelMetadata(&model, kKey.data(), &metadata, &metadata_size)); - EXPECT_EQ(BufferRef(metadata, metadata_size).StrView(), kData); -} - -TEST(LiteRtModelTest, GetSubgraph) { - LiteRtModelT model; - auto& subgraph = model.EmplaceSubgraph(); - - LiteRtSubgraph actual_subgraph; - LITERT_ASSERT_OK(LiteRtGetModelSubgraph(&model, 0, &actual_subgraph)); - EXPECT_EQ(actual_subgraph, &subgraph); -} - -TEST(LiteRtModelTest, GetSubgraphOOB) { - LiteRtModelT model; - - LiteRtSubgraph actual_subgraph; - EXPECT_THAT(LiteRtGetModelSubgraph(&model, 0, &actual_subgraph), - IsError(kLiteRtStatusErrorIndexOOB)); -} - -TEST(LiteRtOpListTest, PushOps) { - LiteRtOpListT op_list; - LiteRtOpT op; - - LITERT_ASSERT_OK(LiteRtPushOp(&op_list, &op, 0)); - auto vec = op_list.Values(); - ASSERT_EQ(vec.size(), 1); - EXPECT_EQ(vec.front().first, &op); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_op_code.h b/tensorflow/lite/experimental/litert/c/litert_op_code.h deleted file mode 100644 index 529360e87dc415..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_op_code.h +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ - -#include "tensorflow/lite/builtin_ops.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum { - kLiteRtOpCodeTflAdd = kTfLiteBuiltinAdd, - kLiteRtOpCodeTflAveragePool2d = kTfLiteBuiltinAveragePool2d, - kLiteRtOpCodeTflConcatenation = kTfLiteBuiltinConcatenation, - kLiteRtOpCodeTflConv2d = kTfLiteBuiltinConv2d, - kLiteRtOpCodeTflDepthwiseConv2d = kTfLiteBuiltinDepthwiseConv2d, - kLiteRtOpCodeTflDepthToSpace = kTfLiteBuiltinDepthToSpace, - kLiteRtOpCodeTflDequantize = kTfLiteBuiltinDequantize, - kLiteRtOpCodeTflEmbeddingLookup = kTfLiteBuiltinEmbeddingLookup, - kLiteRtOpCodeTflFloor = kTfLiteBuiltinFloor, - kLiteRtOpCodeTflFullyConnected = kTfLiteBuiltinFullyConnected, - kLiteRtOpCodeTflHashtableLookup = kTfLiteBuiltinHashtableLookup, - kLiteRtOpCodeTflL2Normalization = kTfLiteBuiltinL2Normalization, - kLiteRtOpCodeTflL2Pool2d = kTfLiteBuiltinL2Pool2d, - kLiteRtOpCodeTflLocalResponseNormalization = - kTfLiteBuiltinLocalResponseNormalization, - kLiteRtOpCodeTflLogistic = kTfLiteBuiltinLogistic, - kLiteRtOpCodeTflLshProjection = kTfLiteBuiltinLshProjection, - kLiteRtOpCodeTflLstm = kTfLiteBuiltinLstm, - kLiteRtOpCodeTflMaxPool2d = kTfLiteBuiltinMaxPool2d, - kLiteRtOpCodeTflMul = kTfLiteBuiltinMul, - kLiteRtOpCodeTflRelu = kTfLiteBuiltinRelu, - kLiteRtOpCodeTflReluN1To1 = kTfLiteBuiltinReluN1To1, - kLiteRtOpCodeTflRelu6 = kTfLiteBuiltinRelu6, - kLiteRtOpCodeTflReshape = kTfLiteBuiltinReshape, - kLiteRtOpCodeTflResizeBilinear = kTfLiteBuiltinResizeBilinear, - kLiteRtOpCodeTflRnn = kTfLiteBuiltinRnn, - kLiteRtOpCodeTflSoftmax = kTfLiteBuiltinSoftmax, - kLiteRtOpCodeTflSpaceToDepth = kTfLiteBuiltinSpaceToDepth, - kLiteRtOpCodeTflSvdf = kTfLiteBuiltinSvdf, - kLiteRtOpCodeTflTanh = kTfLiteBuiltinTanh, - kLiteRtOpCodeTflConcatEmbeddings = kTfLiteBuiltinConcatEmbeddings, - kLiteRtOpCodeTflSkipGram = kTfLiteBuiltinSkipGram, - kLiteRtOpCodeTflCall = kTfLiteBuiltinCall, - kLiteRtOpCodeTflCustom = kTfLiteBuiltinCustom, - kLiteRtOpCodeTflEmbeddingLookupSparse = kTfLiteBuiltinEmbeddingLookupSparse, - kLiteRtOpCodeTflPad = kTfLiteBuiltinPad, - kLiteRtOpCodeTflUnidirectionalSequenceRnn = - kTfLiteBuiltinUnidirectionalSequenceRnn, - kLiteRtOpCodeTflGather = kTfLiteBuiltinGather, - kLiteRtOpCodeTflBatchToSpaceNd = kTfLiteBuiltinBatchToSpaceNd, - kLiteRtOpCodeTflSpaceToBatchNd = kTfLiteBuiltinSpaceToBatchNd, - kLiteRtOpCodeTflTranspose = kTfLiteBuiltinTranspose, - kLiteRtOpCodeTflMean = kTfLiteBuiltinMean, - kLiteRtOpCodeTflSub = kTfLiteBuiltinSub, - kLiteRtOpCodeTflDiv = kTfLiteBuiltinDiv, - kLiteRtOpCodeTflSqueeze = kTfLiteBuiltinSqueeze, - kLiteRtOpCodeTflUnidirectionalSequenceLstm = - kTfLiteBuiltinUnidirectionalSequenceLstm, - kLiteRtOpCodeTflStridedSlice = kTfLiteBuiltinStridedSlice, - kLiteRtOpCodeTflBidirectionalSequenceRnn = - kTfLiteBuiltinBidirectionalSequenceRnn, - kLiteRtOpCodeTflExp = kTfLiteBuiltinExp, - kLiteRtOpCodeTflTopkV2 = kTfLiteBuiltinTopkV2, - kLiteRtOpCodeTflSplit = kTfLiteBuiltinSplit, - kLiteRtOpCodeTflLogSoftmax = kTfLiteBuiltinLogSoftmax, - kLiteRtOpCodeTflDelegate = kTfLiteBuiltinDelegate, - kLiteRtOpCodeTflBidirectionalSequenceLstm = - kTfLiteBuiltinBidirectionalSequenceLstm, - kLiteRtOpCodeTflCast = kTfLiteBuiltinCast, - kLiteRtOpCodeTflPrelu = kTfLiteBuiltinPrelu, - kLiteRtOpCodeTflMaximum = kTfLiteBuiltinMaximum, - kLiteRtOpCodeTflArgMax = kTfLiteBuiltinArgMax, - kLiteRtOpCodeTflMinimum = kTfLiteBuiltinMinimum, - 
kLiteRtOpCodeTflLess = kTfLiteBuiltinLess, - kLiteRtOpCodeTflNeg = kTfLiteBuiltinNeg, - kLiteRtOpCodeTflPadv2 = kTfLiteBuiltinPadv2, - kLiteRtOpCodeTflGreater = kTfLiteBuiltinGreater, - kLiteRtOpCodeTflGreaterEqual = kTfLiteBuiltinGreaterEqual, - kLiteRtOpCodeTflLessEqual = kTfLiteBuiltinLessEqual, - kLiteRtOpCodeTflSelect = kTfLiteBuiltinSelect, - kLiteRtOpCodeTflSlice = kTfLiteBuiltinSlice, - kLiteRtOpCodeTflSin = kTfLiteBuiltinSin, - kLiteRtOpCodeTflTransposeConv = kTfLiteBuiltinTransposeConv, - kLiteRtOpCodeTflSparseToDense = kTfLiteBuiltinSparseToDense, - kLiteRtOpCodeTflTile = kTfLiteBuiltinTile, - kLiteRtOpCodeTflExpandDims = kTfLiteBuiltinExpandDims, - kLiteRtOpCodeTflEqual = kTfLiteBuiltinEqual, - kLiteRtOpCodeTflNotEqual = kTfLiteBuiltinNotEqual, - kLiteRtOpCodeTflLog = kTfLiteBuiltinLog, - kLiteRtOpCodeTflSum = kTfLiteBuiltinSum, - kLiteRtOpCodeTflSqrt = kTfLiteBuiltinSqrt, - kLiteRtOpCodeTflRsqrt = kTfLiteBuiltinRsqrt, - kLiteRtOpCodeTflShape = kTfLiteBuiltinShape, - kLiteRtOpCodeTflPow = kTfLiteBuiltinPow, - kLiteRtOpCodeTflArgMin = kTfLiteBuiltinArgMin, - kLiteRtOpCodeTflFakeQuant = kTfLiteBuiltinFakeQuant, - kLiteRtOpCodeTflReduceProd = kTfLiteBuiltinReduceProd, - kLiteRtOpCodeTflReduceMax = kTfLiteBuiltinReduceMax, - kLiteRtOpCodeTflPack = kTfLiteBuiltinPack, - kLiteRtOpCodeTflLogicalOr = kTfLiteBuiltinLogicalOr, - kLiteRtOpCodeTflOneHot = kTfLiteBuiltinOneHot, - kLiteRtOpCodeTflLogicalAnd = kTfLiteBuiltinLogicalAnd, - kLiteRtOpCodeTflLogicalNot = kTfLiteBuiltinLogicalNot, - kLiteRtOpCodeTflUnpack = kTfLiteBuiltinUnpack, - kLiteRtOpCodeTflReduceMin = kTfLiteBuiltinReduceMin, - kLiteRtOpCodeTflFloorDiv = kTfLiteBuiltinFloorDiv, - kLiteRtOpCodeTflReduceAny = kTfLiteBuiltinReduceAny, - kLiteRtOpCodeTflSquare = kTfLiteBuiltinSquare, - kLiteRtOpCodeTflZerosLike = kTfLiteBuiltinZerosLike, - kLiteRtOpCodeTflFill = kTfLiteBuiltinFill, - kLiteRtOpCodeTflFloorMod = kTfLiteBuiltinFloorMod, - kLiteRtOpCodeTflRange = kTfLiteBuiltinRange, - kLiteRtOpCodeTflResizeNearestNeighbor = kTfLiteBuiltinResizeNearestNeighbor, - kLiteRtOpCodeTflLeakyRelu = kTfLiteBuiltinLeakyRelu, - kLiteRtOpCodeTflSquaredDifference = kTfLiteBuiltinSquaredDifference, - kLiteRtOpCodeTflMirrorPad = kTfLiteBuiltinMirrorPad, - kLiteRtOpCodeTflAbs = kTfLiteBuiltinAbs, - kLiteRtOpCodeTflSplitV = kTfLiteBuiltinSplitV, - kLiteRtOpCodeTflUnique = kTfLiteBuiltinUnique, - kLiteRtOpCodeTflCeil = kTfLiteBuiltinCeil, - kLiteRtOpCodeTflReverseV2 = kTfLiteBuiltinReverseV2, - kLiteRtOpCodeTflAddN = kTfLiteBuiltinAddN, - kLiteRtOpCodeTflGatherNd = kTfLiteBuiltinGatherNd, - kLiteRtOpCodeTflCos = kTfLiteBuiltinCos, - kLiteRtOpCodeTflWhere = kTfLiteBuiltinWhere, - kLiteRtOpCodeTflRank = kTfLiteBuiltinRank, - kLiteRtOpCodeTflElu = kTfLiteBuiltinElu, - kLiteRtOpCodeTflReverseSequence = kTfLiteBuiltinReverseSequence, - kLiteRtOpCodeTflMatrixDiag = kTfLiteBuiltinMatrixDiag, - kLiteRtOpCodeTflQuantize = kTfLiteBuiltinQuantize, - kLiteRtOpCodeTflMatrixSetDiag = kTfLiteBuiltinMatrixSetDiag, - kLiteRtOpCodeTflRound = kTfLiteBuiltinRound, - kLiteRtOpCodeTflHardSwish = kTfLiteBuiltinHardSwish, - kLiteRtOpCodeTflIf = kTfLiteBuiltinIf, - kLiteRtOpCodeTflWhile = kTfLiteBuiltinWhile, - kLiteRtOpCodeTflNonMaxSuppressionV4 = kTfLiteBuiltinNonMaxSuppressionV4, - kLiteRtOpCodeTflNonMaxSuppressionV5 = kTfLiteBuiltinNonMaxSuppressionV5, - kLiteRtOpCodeTflScatterNd = kTfLiteBuiltinScatterNd, - kLiteRtOpCodeTflSelectV2 = kTfLiteBuiltinSelectV2, - kLiteRtOpCodeTflDensify = kTfLiteBuiltinDensify, - kLiteRtOpCodeTflSegmentSum = kTfLiteBuiltinSegmentSum, - 
kLiteRtOpCodeTflBatchMatmul = kTfLiteBuiltinBatchMatmul, - kLiteRtOpCodeTflPlaceholderForGreaterOpCodeTfls = - kTfLiteBuiltinPlaceholderForGreaterOpCodes, - kLiteRtOpCodeTflCumsum = kTfLiteBuiltinCumsum, - kLiteRtOpCodeTflCallOnce = kTfLiteBuiltinCallOnce, - kLiteRtOpCodeTflBroadcastTo = kTfLiteBuiltinBroadcastTo, - kLiteRtOpCodeTflRfft2d = kTfLiteBuiltinRfft2d, - kLiteRtOpCodeTflConv3d = kTfLiteBuiltinConv3d, - kLiteRtOpCodeTflImag = kTfLiteBuiltinImag, - kLiteRtOpCodeTflReal = kTfLiteBuiltinReal, - kLiteRtOpCodeTflComplexAbs = kTfLiteBuiltinComplexAbs, - kLiteRtOpCodeTflHashtable = kTfLiteBuiltinHashtable, - kLiteRtOpCodeTflHashtableFind = kTfLiteBuiltinHashtableFind, - kLiteRtOpCodeTflHashtableImport = kTfLiteBuiltinHashtableImport, - kLiteRtOpCodeTflHashtableSize = kTfLiteBuiltinHashtableSize, - kLiteRtOpCodeTflReduceAll = kTfLiteBuiltinReduceAll, - kLiteRtOpCodeTflConv3dTranspose = kTfLiteBuiltinConv3dTranspose, - kLiteRtOpCodeTflVarHandle = kTfLiteBuiltinVarHandle, - kLiteRtOpCodeTflReadVariable = kTfLiteBuiltinReadVariable, - kLiteRtOpCodeTflAssignVariable = kTfLiteBuiltinAssignVariable, - kLiteRtOpCodeTflBroadcastArgs = kTfLiteBuiltinBroadcastArgs, - kLiteRtOpCodeTflRandomStandardNormal = kTfLiteBuiltinRandomStandardNormal, - kLiteRtOpCodeTflBucketize = kTfLiteBuiltinBucketize, - kLiteRtOpCodeTflRandomUniform = kTfLiteBuiltinRandomUniform, - kLiteRtOpCodeTflMultinomial = kTfLiteBuiltinMultinomial, - kLiteRtOpCodeTflGelu = kTfLiteBuiltinGelu, - kLiteRtOpCodeTflDynamicUpdateSlice = kTfLiteBuiltinDynamicUpdateSlice, - kLiteRtOpCodeTflRelu0To1 = kTfLiteBuiltinRelu0To1, - kLiteRtOpCodeTflUnsortedSegmentProd = kTfLiteBuiltinUnsortedSegmentProd, - kLiteRtOpCodeTflUnsortedSegmentMax = kTfLiteBuiltinUnsortedSegmentMax, - kLiteRtOpCodeTflUnsortedSegmentSum = kTfLiteBuiltinUnsortedSegmentSum, - kLiteRtOpCodeTflAtan2 = kTfLiteBuiltinAtan2, - kLiteRtOpCodeTflUnsortedSegmentMin = kTfLiteBuiltinUnsortedSegmentMin, - kLiteRtOpCodeTflSign = kTfLiteBuiltinSign, - kLiteRtOpCodeTflBitcast = kTfLiteBuiltinBitcast, - kLiteRtOpCodeTflBitwiseXor = kTfLiteBuiltinBitwiseXor, - kLiteRtOpCodeTflRightShift = kTfLiteBuiltinRightShift, - kLiteRtOpCodeShloLogistic = kTfLiteBuiltinStablehloLogistic, - kLiteRtOpCodeShloAdd = kTfLiteBuiltinStablehloAdd, - kLiteRtOpCodeShloDivide = kTfLiteBuiltinStablehloDivide, - kLiteRtOpCodeShloMultiply = kTfLiteBuiltinStablehloMultiply, - kLiteRtOpCodeShloMaximum = kTfLiteBuiltinStablehloMaximum, - kLiteRtOpCodeShloReshape = kTfLiteBuiltinStablehloReshape, - kLiteRtOpCodeShloClamp = kTfLiteBuiltinStablehloClamp, - kLiteRtOpCodeShloConcatenate = kTfLiteBuiltinStablehloConcatenate, - kLiteRtOpCodeShloBroadcastInDim = kTfLiteBuiltinStablehloBroadcastInDim, - kLiteRtOpCodeShloConvolution = kTfLiteBuiltinStablehloConvolution, - kLiteRtOpCodeShloSlice = kTfLiteBuiltinStablehloSlice, - kLiteRtOpCodeShloCustomCall = kTfLiteBuiltinStablehloCustomCall, - kLiteRtOpCodeShloReduce = kTfLiteBuiltinStablehloReduce, - kLiteRtOpCodeShloAbs = kTfLiteBuiltinStablehloAbs, - kLiteRtOpCodeShloAnd = kTfLiteBuiltinStablehloAnd, - kLiteRtOpCodeShloCosine = kTfLiteBuiltinStablehloCosine, - kLiteRtOpCodeShloExponential = kTfLiteBuiltinStablehloExponential, - kLiteRtOpCodeShloFloor = kTfLiteBuiltinStablehloFloor, - kLiteRtOpCodeShloLog = kTfLiteBuiltinStablehloLog, - kLiteRtOpCodeShloMinimum = kTfLiteBuiltinStablehloMinimum, - kLiteRtOpCodeShloNegate = kTfLiteBuiltinStablehloNegate, - kLiteRtOpCodeShloOr = kTfLiteBuiltinStablehloOr, - kLiteRtOpCodeShloPower = kTfLiteBuiltinStablehloPower, - 
kLiteRtOpCodeShloRemainder = kTfLiteBuiltinStablehloRemainder, - kLiteRtOpCodeShloRsqrt = kTfLiteBuiltinStablehloRsqrt, - kLiteRtOpCodeShloSelect = kTfLiteBuiltinStablehloSelect, - kLiteRtOpCodeShloSubtract = kTfLiteBuiltinStablehloSubtract, - kLiteRtOpCodeShloTanh = kTfLiteBuiltinStablehloTanh, - kLiteRtOpCodeShloScatter = kTfLiteBuiltinStablehloScatter, - kLiteRtOpCodeShloCompare = kTfLiteBuiltinStablehloCompare, - kLiteRtOpCodeShloConvert = kTfLiteBuiltinStablehloConvert, - kLiteRtOpCodeShloDynamicSlice = kTfLiteBuiltinStablehloDynamicSlice, - kLiteRtOpCodeShloDynamicUpdateSlice = - kTfLiteBuiltinStablehloDynamicUpdateSlice, - kLiteRtOpCodeShloPad = kTfLiteBuiltinStablehloPad, - kLiteRtOpCodeShloIota = kTfLiteBuiltinStablehloIota, - kLiteRtOpCodeShloGeneral = kTfLiteBuiltinStablehloDotGeneral, - kLiteRtOpCodeShloWindow = kTfLiteBuiltinStablehloReduceWindow, - kLiteRtOpCodeShloSort = kTfLiteBuiltinStablehloSort, - kLiteRtOpCodeShloWhile = kTfLiteBuiltinStablehloWhile, - kLiteRtOpCodeShloGather = kTfLiteBuiltinStablehloGather, - kLiteRtOpCodeShloTranspose = kTfLiteBuiltinStablehloTranspose, - kLiteRtOpCodeTflDilate = kTfLiteBuiltinDilate, - kLiteRtOpCodeShloRngBitGenerator = kTfLiteBuiltinStablehloRngBitGenerator, - kLiteRtOpCodeTflReduceWindow = kTfLiteBuiltinReduceWindow, - kLiteRtOpCodeShloComposite = kTfLiteBuiltinStablehloComposite, -} LiteRtOpCode; - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_options.cc b/tensorflow/lite/experimental/litert/c/litert_options.cc deleted file mode 100644 index e14759ae641809..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_options.cc +++ /dev/null @@ -1,771 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_options.h" - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// -// Op Options -// - -LiteRtStatus LiteRtGetAddFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflAdd) { - return kLiteRtStatusErrorInvalidArgument; - } - const auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorNotFound; - } - *fused_activation = opts.AsAddOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetBatchMatmulAdjXOption(LiteRtOp op, bool* adj_x) { - if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *adj_x = opts.AsBatchMatMulOptions()->adj_x; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetBatchMatmulAdjYOption(LiteRtOp op, bool* adj_y) { - if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *adj_y = opts.AsBatchMatMulOptions()->adj_y; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetBatchMatmulAsymmetricQuantizeInputOption( - LiteRtOp op, bool* asymmetric_quantize_input) { - if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *asymmetric_quantize_input = - opts.AsBatchMatMulOptions()->asymmetric_quantize_inputs; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConcatenationFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflConcatenation) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation = opts.AsConcatenationOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConcatenationAxisOption(LiteRtOp op, int32_t* axis) { - if (op->OpCode() != kLiteRtOpCodeTflConcatenation) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *axis = opts.AsConcatenationOptions()->axis; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDivFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflDiv) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation = opts.AsDivOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetFullyConnectedFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == 
nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation = opts.AsFullyConnectedOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetFullyConnectedKeepNumDimsOption(LiteRtOp op, - bool* keep_num_dims) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *keep_num_dims = opts.AsFullyConnectedOptions()->keep_num_dims; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtFullyConnectedGetQuantizedBiasTypeOption( - LiteRtOp op, uint32_t* quantized_bias_type) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *quantized_bias_type = opts.AsFullyConnectedOptions()->quantized_bias_type; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetFullyConnectedAsymmetricQuantizeInputOption( - LiteRtOp op, bool* asymmetric_quantize_input) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *asymmetric_quantize_input = - opts.AsFullyConnectedOptions()->asymmetric_quantize_inputs; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetFullyConnectedWeightsFormatOption( - LiteRtOp op, uint32_t* weights_format) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *weights_format = opts.AsFullyConnectedOptions()->weights_format; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetMulFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflMul) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation = opts.AsMulOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSoftmaxBetaOption(LiteRtOp op, float* beta) { - if (op->OpCode() != kLiteRtOpCodeTflSoftmax) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *beta = opts.AsSoftmaxOptions()->beta; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetStridedSliceBeginMaskOption(LiteRtOp op, - int32_t* begin_mask) { - if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *begin_mask = opts.AsStridedSliceOptions()->begin_mask; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetStridedSliceEndMaskOption(LiteRtOp op, - int32_t* end_mask) { - if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *end_mask = opts.AsStridedSliceOptions()->end_mask; - return 
kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetStridedSliceEllipsisMaskOption(LiteRtOp op,
-                                                     int32_t* ellipsis_mask) {
-  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *ellipsis_mask = opts.AsStridedSliceOptions()->ellipsis_mask;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetStridedSliceNewAxisMaskOption(LiteRtOp op,
-                                                    int32_t* new_axis_mask) {
-  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *new_axis_mask = opts.AsStridedSliceOptions()->new_axis_mask;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetStridedSliceShrinkAxisMaskOption(
-    LiteRtOp op, int32_t* shrink_axis_mask) {
-  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *shrink_axis_mask = opts.AsStridedSliceOptions()->shrink_axis_mask;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetStridedSliceOffsetOption(LiteRtOp op, bool* offset) {
-  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *offset = opts.AsStridedSliceOptions()->offset;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetSubFusedActivationOption(LiteRtOp op,
-                                               uint32_t* fused_activation) {
-  if (op->OpCode() != kLiteRtOpCodeTflSub) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *fused_activation = opts.AsSubOptions()->fused_activation_function;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op,
-                                            const int32_t** new_shape,
-                                            int32_t* new_shape_size) {
-  if (op->OpCode() != kLiteRtOpCodeTflReshape) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    *new_shape_size = -1;
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  if (opts.AsReshapeOptions() == nullptr) {
-    *new_shape_size = -1;
-    return kLiteRtStatusOk;
-  } else {
-    *new_shape = opts.AsReshapeOptions()->new_shape.data();
-    *new_shape_size = opts.AsReshapeOptions()->new_shape.size();
-  }
-  return kLiteRtStatusOk;
-}
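
As a caller-side illustration (not part of this file), the concatenation getters defined earlier combine as follows; the helper name ReadConcatenationOptions is hypothetical.

// Illustrative sketch only. On a non-concatenation op these getters return
// kLiteRtStatusErrorInvalidArgument, as implemented above.
void ReadConcatenationOptions(LiteRtOp op) {
  int32_t axis;
  uint32_t fused_activation;
  if (LiteRtGetConcatenationAxisOption(op, &axis) == kLiteRtStatusOk &&
      LiteRtGetConcatenationFusedActivationOption(op, &fused_activation) ==
          kLiteRtStatusOk) {
    // `axis` and `fused_activation` now mirror the op's TFL options.
  }
}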
-
-LiteRtStatus LiteRtGetSumKeepDimsOption(LiteRtOp op, bool* keepdims) {
-  if (op->OpCode() != kLiteRtOpCodeTflSum) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  // Sum op options are stored as ReducerOptions.
-  *keepdims = opts.AsReducerOptions()->keep_dims;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetPackAxisOption(LiteRtOp op, int32_t* axis) {
-  if (op->OpCode() != kLiteRtOpCodeTflPack) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *axis = opts.AsPackOptions()->axis;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetGatherAxisOption(LiteRtOp op, int32_t* axis) {
-  if (op->OpCode() != kLiteRtOpCodeTflGather) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *axis = opts.AsGatherOptions()->axis;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetGatherBatchDimsOption(LiteRtOp op, int32_t* batch_dims) {
-  if (op->OpCode() != kLiteRtOpCodeTflGather) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *batch_dims = opts.AsGatherOptions()->batch_dims;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetMeanKeepDimsOption(LiteRtOp op, bool* keepdims) {
-  if (op->OpCode() != kLiteRtOpCodeTflMean) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  // Mean op options are stored as ReducerOptions.
-  *keepdims = opts.AsReducerOptions()->keep_dims;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetSplitNumSplitsOption(LiteRtOp op, int32_t* num_splits) {
-  if (op->OpCode() != kLiteRtOpCodeTflSplit) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *num_splits = opts.AsSplitOptions()->num_splits;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetSHLOCompositeOpName(LiteRtOp op, const char** name) {
-  if (op->OpCode() != kLiteRtOpCodeShloComposite) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions2(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *name = opts.AsStableHLOCompositeOptions()->name.data();
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetSHLOCompositeOpDecompositionSubgraphIndex(
-    LiteRtOp op, int32_t* subgraph_index) {
-  if (op->OpCode() != kLiteRtOpCodeShloComposite) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions2(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *subgraph_index =
-      opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetConv2dPaddingOption(LiteRtOp op, uint32_t* padding) {
-  if (op->OpCode() != kLiteRtOpCodeTflConv2d) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *padding = opts.AsConv2DOptions()->padding;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetConv2dStrideWOption(LiteRtOp op, int32_t* stride_w) {
-  if (op->OpCode() != kLiteRtOpCodeTflConv2d) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  auto& opts = litert::internal::GetTflOptions(*op);
-  if (opts.value == nullptr) {
-    return 
kLiteRtStatusErrorInvalidArgument; - } - *stride_w = opts.AsConv2DOptions()->stride_w; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dStrideHOption(LiteRtOp op, int32_t* stride_h) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_h = opts.AsConv2DOptions()->stride_h; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation_function = - opts.AsConv2DOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dDilationWOption(LiteRtOp op, - int32_t* dilation_w_factor) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *dilation_w_factor = opts.AsConv2DOptions()->dilation_w_factor; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dDilationHOption(LiteRtOp op, - int32_t* dilation_h_factor) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *dilation_h_factor = opts.AsConv2DOptions()->dilation_h_factor; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dPaddingOption(LiteRtOp op, - uint32_t* padding) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *padding = opts.AsDepthwiseConv2DOptions()->padding; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dStrideWOption(LiteRtOp op, - int32_t* stride_w) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_w = opts.AsDepthwiseConv2DOptions()->stride_w; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dStrideHOption(LiteRtOp op, - int32_t* stride_h) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_h = opts.AsDepthwiseConv2DOptions()->stride_h; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation_function = - opts.AsDepthwiseConv2DOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dDilationWOption( - 
LiteRtOp op, int32_t* dilation_w_factor) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *dilation_w_factor = opts.AsDepthwiseConv2DOptions()->dilation_w_factor; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dDilationHOptions( - LiteRtOp op, int32_t* dilation_h_factor) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *dilation_h_factor = opts.AsDepthwiseConv2DOptions()->dilation_h_factor; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dOptions(LiteRtOp op, int8_t* padding, - int32_t* stride_w, int32_t* stride_h, - int32_t* filter_width, - int32_t* filter_height, - int8_t* fused_activation_function) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - auto* options = opts.AsPool2DOptions(); - *padding = options->padding; - *stride_w = options->stride_w; - *stride_h = options->stride_h; - *filter_width = options->filter_width; - *filter_height = options->filter_height; - *fused_activation_function = options->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dPaddingOption(LiteRtOp op, - uint32_t* padding) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *padding = opts.AsPool2DOptions()->padding; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dStrideWOption(LiteRtOp op, - int32_t* stride_w) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_w = opts.AsPool2DOptions()->stride_w; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dStrideHOption(LiteRtOp op, - int32_t* stride_h) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_h = opts.AsPool2DOptions()->stride_h; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dFilterWidthOption(LiteRtOp op, - int32_t* filter_width) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *filter_width = opts.AsPool2DOptions()->filter_width; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dFilterHeightOption(LiteRtOp op, - int32_t* filter_height) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *filter_height = 
opts.AsPool2DOptions()->filter_height; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation_function = - opts.AsPool2DOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetResizeBilinearAlignCornersOption(LiteRtOp op, - bool* align_corners) { - if (op->OpCode() != kLiteRtOpCodeTflResizeBilinear) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *align_corners = opts.AsResizeBilinearOptions()->align_corners; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetResizeBilinearHalfPixelCenterOption( - LiteRtOp op, bool* half_pixel_centers) { - if (op->OpCode() != kLiteRtOpCodeTflResizeBilinear) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *half_pixel_centers = opts.AsResizeBilinearOptions()->half_pixel_centers; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetLeakyReluAlphaOption(LiteRtOp op, float* alpha) { - if (op->OpCode() != kLiteRtOpCodeTflLeakyRelu) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *alpha = opts.AsLeakyReluOptions()->alpha; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthToSpaceBlockSizeOption(LiteRtOp op, - int32_t* block_size) { - if (op->OpCode() != kLiteRtOpCodeTflDepthToSpace) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *block_size = opts.AsDepthToSpaceOptions()->block_size; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSpaceToDepthBlockSizeOption(LiteRtOp op, - int32_t* block_size) { - if (op->OpCode() != kLiteRtOpCodeTflSpaceToDepth) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *block_size = opts.AsSpaceToDepthOptions()->block_size; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetResizeNearestNeighborAlignCornersOption( - LiteRtOp op, bool* align_corners) { - if (op->OpCode() != kLiteRtOpCodeTflResizeNearestNeighbor) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *align_corners = opts.AsResizeNearestNeighborOptions()->align_corners; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetResizeNearestNeighborHalfPixelCenterOption( - LiteRtOp op, bool* half_pixel_centers) { - if (op->OpCode() != kLiteRtOpCodeTflResizeNearestNeighbor) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *half_pixel_centers = - opts.AsResizeNearestNeighborOptions()->half_pixel_centers; - return kLiteRtStatusOk; -} - -#ifdef __cplusplus -} // 
extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_options.h b/tensorflow/lite/experimental/litert/c/litert_options.h deleted file mode 100644 index ed746575c3d03e..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_options.h +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ - -#include <stdbool.h> // NOLINT: To use bool type in C -#include <stdint.h> - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtOp); - -//============================================================================== -// -// Get option APIs for LiteRt ADD op. -// Options: -// - FusedActivationOption : uint32_t -// -//============================================================================== -LiteRtStatus LiteRtGetAddFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation); - -//============================================================================== -// -// Get option APIs for LiteRt BatchMatmul op. -// Options: -// - AdjXOption : bool -// - AdjYOption : bool -// - AsymmetricQuantizeInputOption : bool -// -//============================================================================== -LiteRtStatus LiteRtGetBatchMatmulAdjXOption(LiteRtOp op, bool* adj_x); -LiteRtStatus LiteRtGetBatchMatmulAdjYOption(LiteRtOp op, bool* adj_y); -LiteRtStatus LiteRtGetBatchMatmulAsymmetricQuantizeInputOption( - LiteRtOp op, bool* asymmetric_quantize_input); - -//============================================================================== -// -// Get option APIs for LiteRt Concatenation op. -// Options: -// - FusedActivationOption : uint32_t -// - AxisOption : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetConcatenationFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation); -LiteRtStatus LiteRtGetConcatenationAxisOption(LiteRtOp op, int32_t* axis); - -//============================================================================== -// -// Get option APIs for LiteRt Div op. -// Options: -// - FusedActivationOption : uint32_t -// -//============================================================================== -LiteRtStatus LiteRtGetDivFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation); - -//============================================================================== -// -// Get option APIs for LiteRt FullyConnected op. 
-// Options: -// - FusedActivationOption : uint32_t -// - WeightsFormatOption : uint32_t -// - KeepNumDimsOption : bool -// - QuantizedBiasTypeOption : uint32_t -// - AsymmetricQuantizeInputOption : bool -// -//============================================================================== -LiteRtStatus LiteRtGetFullyConnectedFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation); -LiteRtStatus LiteRtGetFullyConnectedWeightsFormatOption( - LiteRtOp op, uint32_t* weights_format); -LiteRtStatus LiteRtGetFullyConnectedKeepNumDimsOption(LiteRtOp op, - bool* keep_num_dims); -LiteRtStatus LiteRtFullyConnectedGetQuantizedBiasTypeOption( - LiteRtOp op, uint32_t* quantized_bias_type); -LiteRtStatus LiteRtGetFullyConnectedAsymmetricQuantizeInputOption( - LiteRtOp op, bool* asymmetric_quantize_input); - -//============================================================================== -// -// Get option APIs for LiteRt Mul op. -// Options: -// - FusedActivationOption : uint32_t -// -//============================================================================== -LiteRtStatus LiteRtGetMulFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation); - -//============================================================================== -// -// Get option APIs for LiteRt Softmax op. -// Options: -// - BetaOption : float -// -//============================================================================== -LiteRtStatus LiteRtGetSoftmaxBetaOption(LiteRtOp op, float* beta); - -//============================================================================== -// -// Get option APIs for LiteRt StridedSlice op. -// Options: -// - BeginMaskOption : int32_t -// - EndMaskOption : int32_t -// - EllipsisMaskOption : int32_t -// - NewAxisMaskOption : int32_t -// - ShrinkAxisMaskOption : int32_t -// - OffsetOption : bool -// -//============================================================================== -LiteRtStatus LiteRtGetStridedSliceBeginMaskOption(LiteRtOp op, - int32_t* begin_mask); -LiteRtStatus LiteRtGetStridedSliceEndMaskOption(LiteRtOp op, int32_t* end_mask); -LiteRtStatus LiteRtGetStridedSliceEllipsisMaskOption(LiteRtOp op, - int32_t* ellipsis_mask); -LiteRtStatus LiteRtGetStridedSliceNewAxisMaskOption(LiteRtOp op, - int32_t* new_axis_mask); -LiteRtStatus LiteRtGetStridedSliceShrinkAxisMaskOption( - LiteRtOp op, int32_t* shrink_axis_mask); -LiteRtStatus LiteRtGetStridedSliceOffsetOption(LiteRtOp op, bool* offset); - -//============================================================================== -// -// Get option APIs for LiteRt Sub op. -// Options: -// - FusedActivationOption : uint32_t -// - (Not supported) PotScaleInt16Option : bool -// -//============================================================================== -LiteRtStatus LiteRtGetSubFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation); - -//============================================================================== -// -// Get option APIs for LiteRt Reshape op. -// Options: -// - new_shape : int32_t[] -// -//============================================================================== -LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op, - const int32_t** new_shape, - int32_t* new_shape_size); - -//============================================================================== -// -// Get option APIs for LiteRt Sum op. 
-// Options: -// - KeepdimsOption : bool -// -//============================================================================== -LiteRtStatus LiteRtGetSumKeepDimsOption(LiteRtOp op, bool* keepdims); - -//============================================================================== -// -// Get option APIs for LiteRt Pack op. -// Options: -// - axisOption : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetPackAxisOption(LiteRtOp op, int32_t* axis); - -//============================================================================== -// -// Get option APIs for LiteRt Gather op. -// Options: -// - axisOption : int32_t -// - batch_dims : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetGatherAxisOption(LiteRtOp op, int32_t* axis); -LiteRtStatus LiteRtGetGatherBatchDimsOption(LiteRtOp op, int32_t* batch_dims); - -//============================================================================== -// -// Get option APIs for LiteRt Mean op. -// Options: -// - keepdimsOption : bool -// -//============================================================================== -LiteRtStatus LiteRtGetMeanKeepDimsOption(LiteRtOp op, bool* keepdims); - -//============================================================================== -// -// Get option APIs for LiteRt Split op. -// Options: -// - num_splits : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetSplitNumSplitsOption(LiteRtOp op, int32_t* num_splits); - -//============================================================================== -// -// Get option APIs for LiteRt SHLO Composite op. -// Options: -// - name : string -// - decomposition_subgraph_index : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetSHLOCompositeOpName(LiteRtOp op, const char** name); -LiteRtStatus LiteRtGetSHLOCompositeOpDecompositionSubgraphIndex( - LiteRtOp op, int32_t* decomposition_subgraph_index); - -//============================================================================== -// -// Get option APIs for LiteRt Conv2d op. -// Options: -// - padding : uint32_t -// - stride_w : int32_t -// - stride_h : int32_t -// - fused_activation_function : uint32_t -// - dilation_w_factor : int32_t -// - dilation_h_factor : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetConv2dPaddingOption(LiteRtOp op, uint32_t* padding); -LiteRtStatus LiteRtGetConv2dStrideWOption(LiteRtOp op, int32_t* stride_w); -LiteRtStatus LiteRtGetConv2dStrideHOption(LiteRtOp op, int32_t* stride_h); -LiteRtStatus LiteRtGetConv2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function); -LiteRtStatus LiteRtGetConv2dDilationWOption(LiteRtOp op, - int32_t* dilation_w_factor); -LiteRtStatus LiteRtGetConv2dDilationHOption(LiteRtOp op, - int32_t* dilation_h_factor); - -//============================================================================== -// -// Get option APIs for LiteRt DepthwiseConv2d op. 
-// Options: -// - padding : uint32_t -// - stride_w : int32_t -// - stride_h : int32_t -// - fused_activation_function : uint32_t -// - dilation_w_factor : int32_t -// - dilation_h_factor : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetDepthwiseConv2dPaddingOption(LiteRtOp op, - uint32_t* padding); -LiteRtStatus LiteRtGetDepthwiseConv2dStrideWOption(LiteRtOp op, - int32_t* stride_w); -LiteRtStatus LiteRtGetDepthwiseConv2dStrideHOption(LiteRtOp op, - int32_t* stride_h); -LiteRtStatus LiteRtGetDepthwiseConv2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function); -LiteRtStatus LiteRtGetDepthwiseConv2dDilationWOption( - LiteRtOp op, int32_t* dilation_w_factor); -LiteRtStatus LiteRtGetDepthwiseConv2dDilationHOptions( - LiteRtOp op, int32_t* dilation_h_factor); - -//============================================================================== -// -// Get option APIs for LiteRt AveragePool2d op. -// Options: -// - padding : uint32_t -// - stride_w : int32_t -// - stride_h : int32_t -// - filter_width : int32_t -// - filter_height : int32_t -// - fused_activation_function : uint32_t -// -//============================================================================== -LiteRtStatus LiteRtGetAveragePool2dPaddingOption(LiteRtOp op, - uint32_t* padding); -LiteRtStatus LiteRtGetAveragePool2dStrideWOption(LiteRtOp op, - int32_t* stride_w); -LiteRtStatus LiteRtGetAveragePool2dStrideHOption(LiteRtOp op, - int32_t* stride_h); -LiteRtStatus LiteRtGetAveragePool2dFilterWidthOption(LiteRtOp op, - int32_t* filter_width); -LiteRtStatus LiteRtGetAveragePool2dFilterHeightOption(LiteRtOp op, - int32_t* filter_height); -LiteRtStatus LiteRtGetAveragePool2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function); - -//============================================================================== -// -// Get option APIs for LiteRt ResizeBilinear op. -// Options: -// - align_corners : bool -// - half_pixel_centers : bool -// -//============================================================================== -LiteRtStatus LiteRtGetResizeBilinearAlignCornersOption(LiteRtOp op, - bool* align_corners); -LiteRtStatus LiteRtGetResizeBilinearHalfPixelCenterOption( - LiteRtOp op, bool* half_pixel_centers); - -//============================================================================== -// -// Get option APIs for LiteRt LeakyRelu op. -// Options: -// - alpha : float -// -//============================================================================== -LiteRtStatus LiteRtGetLeakyReluAlphaOption(LiteRtOp op, float* alpha); - -//============================================================================== -// -// Get option APIs for LiteRt DepthToSpace op. -// Options: -// - block_size : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetDepthToSpaceBlockSizeOption(LiteRtOp op, - int32_t* block_size); - -//============================================================================== -// -// Get option APIs for LiteRt SpaceToDepth op. -// Options: -// - block_size : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetSpaceToDepthBlockSizeOption(LiteRtOp op, - int32_t* block_size); - -//============================================================================== -// -// Get option APIs for LiteRt ResizeNearestNeighbor op. 
-// Options: -// - align_corners : bool -// - half_pixel_centers : bool -// -//============================================================================== -LiteRtStatus LiteRtGetResizeNearestNeighborAlignCornersOption( - LiteRtOp op, bool* align_corners); -LiteRtStatus LiteRtGetResizeNearestNeighborHalfPixelCenterOption( - LiteRtOp op, bool* half_pixel_centers); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_options_test.cc b/tensorflow/lite/experimental/litert/c/litert_options_test.cc deleted file mode 100644 index 41c0c07f8bf39e..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_options_test.cc +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_options.h" - -#include - -#include // IWYU pragma: keep -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { -using testing::litert::IsError; - -TEST(GetOpOptionTest, TestGetAddOptions) { - auto model = litert::testing::LoadTestFileModel("simple_add_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK(LiteRtGetAddFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); -} - -TEST(GetOpOptionTest, TestGetBatchMatmulOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_batch_matmul_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool adj_x; - LITERT_ASSERT_OK(LiteRtGetBatchMatmulAdjXOption(op, &adj_x)); - ASSERT_EQ(adj_x, false); - - bool adj_y; - LITERT_ASSERT_OK(LiteRtGetBatchMatmulAdjYOption(op, &adj_y)); - ASSERT_EQ(adj_y, false); - - bool asymmetric_quantize_input; - LITERT_ASSERT_OK(LiteRtGetBatchMatmulAsymmetricQuantizeInputOption( - op, &asymmetric_quantize_input)); - ASSERT_EQ(asymmetric_quantize_input, false); -} - -TEST(GetOpOptionTest, TestGetConcatenationOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_concatenation_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK( - LiteRtGetConcatenationFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); - - int32_t axis; - LITERT_ASSERT_OK(LiteRtGetConcatenationAxisOption(op, &axis)); - ASSERT_EQ(axis, 2); -} - -TEST(GetOpOptionTest, TestGetDivOptions) { - auto model = litert::testing::LoadTestFileModel("simple_div_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = 
subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK(LiteRtGetDivFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); -} - -TEST(GetOpOptionTest, TestGetFullyConnectedOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_fully_connected_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK( - LiteRtGetFullyConnectedFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); - - uint32_t weights_format; - LITERT_ASSERT_OK( - LiteRtGetFullyConnectedWeightsFormatOption(op, &weights_format)); - ASSERT_EQ(weights_format, 0); - - bool keep_num_dims; - LITERT_ASSERT_OK( - LiteRtGetFullyConnectedKeepNumDimsOption(op, &keep_num_dims)); - ASSERT_EQ(keep_num_dims, true); - - uint32_t quantized_bias_type; - LITERT_ASSERT_OK( - LiteRtFullyConnectedGetQuantizedBiasTypeOption(op, &quantized_bias_type)); - ASSERT_EQ(quantized_bias_type, 0); - - bool asymmetric_quantize_input; - LITERT_ASSERT_OK(LiteRtGetFullyConnectedAsymmetricQuantizeInputOption( - op, &asymmetric_quantize_input)); - ASSERT_EQ(asymmetric_quantize_input, false); -} - -TEST(GetOpOptionTest, TestGetMulOptions) { - auto model = litert::testing::LoadTestFileModel("simple_mul_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK(LiteRtGetMulFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); -} - -TEST(GetOpOptionTest, TestGetSoftmaxOptions) { - auto model = litert::testing::LoadTestFileModel("simple_softmax_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - float beta; - LITERT_ASSERT_OK(LiteRtGetSoftmaxBetaOption(op, &beta)); - EXPECT_FLOAT_EQ(beta, 1.0); -} - -TEST(GetOpOptionTest, TestGetStridedSliceOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_strided_slice_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t begin_mask; - LITERT_ASSERT_OK(LiteRtGetStridedSliceBeginMaskOption(op, &begin_mask)); - ASSERT_EQ(begin_mask, 0); - - int32_t end_mask; - LITERT_ASSERT_OK(LiteRtGetStridedSliceEndMaskOption(op, &end_mask)); - ASSERT_EQ(end_mask, 0); - - int32_t ellipsis_mask; - LITERT_ASSERT_OK(LiteRtGetStridedSliceEllipsisMaskOption(op, &ellipsis_mask)); - ASSERT_EQ(ellipsis_mask, 0); - - int32_t new_axis_mask; - LITERT_ASSERT_OK(LiteRtGetStridedSliceNewAxisMaskOption(op, &new_axis_mask)); - ASSERT_EQ(new_axis_mask, 0); - - int32_t shrink_axis_mask; - LITERT_ASSERT_OK( - LiteRtGetStridedSliceShrinkAxisMaskOption(op, &shrink_axis_mask)); - ASSERT_EQ(shrink_axis_mask, 0); - - bool offset; - LITERT_ASSERT_OK(LiteRtGetStridedSliceOffsetOption(op, &offset)); - ASSERT_EQ(offset, false); -} - -TEST(GetOpOptionTest, TestGetSubOptions) { - auto model = litert::testing::LoadTestFileModel("simple_sub_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK(LiteRtGetSubFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); -} - -TEST(GetOpOptionTest, TestGetNullReshapeOptions) { - 
auto model = litert::testing::LoadTestFileModel("simple_reshape_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - const int32_t* new_shape = nullptr; - int32_t new_shape_size; - - EXPECT_THAT(LiteRtGetReshapeNewShapeOption(op, &new_shape, &new_shape_size), - IsError(kLiteRtStatusErrorInvalidArgument)); - ASSERT_EQ(new_shape_size, -1); -} - -TEST(GetOpOptionTest, TestGetSumOptions) { - auto model = litert::testing::LoadTestFileModel("simple_sum_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool keepdims; - LITERT_ASSERT_OK(LiteRtGetSumKeepDimsOption(op, &keepdims)); - ASSERT_EQ(keepdims, true); -} - -TEST(GetOpOptionTest, TestGetPackOptions) { - auto model = litert::testing::LoadTestFileModel("simple_pack_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t axis; - LITERT_ASSERT_OK(LiteRtGetPackAxisOption(op, &axis)); - ASSERT_EQ(axis, 0); -} - -TEST(GetOpOptionTest, TestGetGatherOptions) { - auto model = litert::testing::LoadTestFileModel("simple_gather_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t axis; - LITERT_ASSERT_OK(LiteRtGetGatherAxisOption(op, &axis)); - ASSERT_EQ(axis, 0); - - int32_t batch_dims; - LITERT_ASSERT_OK(LiteRtGetGatherBatchDimsOption(op, &batch_dims)); - ASSERT_EQ(batch_dims, 0); -} - -TEST(GetOpOptionTest, TestGetMeanOptions) { - auto model = litert::testing::LoadTestFileModel("simple_mean_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool keepdims; - LITERT_ASSERT_OK(LiteRtGetMeanKeepDimsOption(op, &keepdims)); - ASSERT_EQ(keepdims, false); -} - -TEST(GetOpOptionTest, TestGetSplitOptions) { - auto model = litert::testing::LoadTestFileModel("simple_split_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t num_splits; - LITERT_ASSERT_OK(LiteRtGetSplitNumSplitsOption(op, &num_splits)); - ASSERT_EQ(num_splits, 3); -} - -TEST(GetOpOptionTest, TestGetConv2dOptions) { - auto model = litert::testing::LoadTestFileModel("simple_conv_2d_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t padding; - LITERT_ASSERT_OK(LiteRtGetConv2dPaddingOption(op, &padding)); - ASSERT_EQ(padding, 0); - int32_t stride_w; - LITERT_ASSERT_OK(LiteRtGetConv2dStrideWOption(op, &stride_w)); - ASSERT_EQ(stride_w, 1); - int32_t stride_h; - LITERT_ASSERT_OK(LiteRtGetConv2dStrideHOption(op, &stride_h)); - ASSERT_EQ(stride_h, 1); - uint32_t fused_activation_function; - LITERT_ASSERT_OK( - LiteRtGetConv2dFusedActivationOption(op, &fused_activation_function)); - ASSERT_EQ(fused_activation_function, 0); - int32_t dilation_w_factor; - LITERT_ASSERT_OK(LiteRtGetConv2dDilationWOption(op, &dilation_w_factor)); - ASSERT_EQ(dilation_w_factor, 1); - int32_t dilation_h_factor; - LITERT_ASSERT_OK(LiteRtGetConv2dDilationHOption(op, &dilation_h_factor)); - ASSERT_EQ(dilation_h_factor, 1); -} - -TEST(GetOpOptionTest, TestGetDepthwiseConv2dOptions) { - auto model = - 
litert::testing::LoadTestFileModel("simple_depthwise_conv_2d_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t padding; - LITERT_ASSERT_OK(LiteRtGetDepthwiseConv2dPaddingOption(op, &padding)); - ASSERT_EQ(padding, 1); - int32_t stride_w; - LITERT_ASSERT_OK(LiteRtGetDepthwiseConv2dStrideWOption(op, &stride_w)); - ASSERT_EQ(stride_w, 1); - int32_t stride_h; - LITERT_ASSERT_OK(LiteRtGetDepthwiseConv2dStrideHOption(op, &stride_h)); - ASSERT_EQ(stride_h, 1); - uint32_t fused_activation_function; - LITERT_ASSERT_OK(LiteRtGetDepthwiseConv2dFusedActivationOption( - op, &fused_activation_function)); - ASSERT_EQ(fused_activation_function, 0); - int32_t dilation_w_factor; - LITERT_ASSERT_OK( - LiteRtGetDepthwiseConv2dDilationWOption(op, &dilation_w_factor)); - ASSERT_EQ(dilation_w_factor, 4); - int32_t dilation_h_factor; - LITERT_ASSERT_OK( - LiteRtGetDepthwiseConv2dDilationHOptions(op, &dilation_h_factor)); - ASSERT_EQ(dilation_h_factor, 4); -} - -TEST(GetOpOptionTest, TestGetAveragePool2dOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_average_poll_2d.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t padding; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dPaddingOption(op, &padding)); - ASSERT_EQ(padding, 1); - int32_t stride_w; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dStrideWOption(op, &stride_w)); - ASSERT_EQ(stride_w, 4); - int32_t stride_h; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dStrideHOption(op, &stride_h)); - ASSERT_EQ(stride_h, 4); - int32_t filter_width; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dFilterWidthOption(op, &filter_width)); - ASSERT_EQ(filter_width, 4); - int32_t filter_height; - LITERT_ASSERT_OK( - LiteRtGetAveragePool2dFilterHeightOption(op, &filter_height)); - ASSERT_EQ(filter_height, 4); - uint32_t fused_activation_function; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dFusedActivationOption( - op, &fused_activation_function)); - ASSERT_EQ(fused_activation_function, 0); -} - -TEST(GetOpOptionTest, TestGetResizeBilinearOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_resize_bilinear_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool align_corners; - LITERT_ASSERT_OK( - LiteRtGetResizeBilinearAlignCornersOption(op, &align_corners)); - ASSERT_EQ(align_corners, false); - bool half_pixel_centers; - LITERT_ASSERT_OK( - LiteRtGetResizeBilinearHalfPixelCenterOption(op, &half_pixel_centers)); - ASSERT_EQ(half_pixel_centers, true); -} - -TEST(GetOpOptionTest, TestGetLeakyReluOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_leaky_relu_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - float alpha; - LITERT_ASSERT_OK(LiteRtGetLeakyReluAlphaOption(op, &alpha)); - ASSERT_FLOAT_EQ(alpha, 0.2); -} - -TEST(GetOpOptionTest, TestGetDepthToSpaceOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_depth_to_space_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t block_size; - LITERT_ASSERT_OK(LiteRtGetDepthToSpaceBlockSizeOption(op, &block_size)); - ASSERT_EQ(block_size, 2); -} - -TEST(GetOpOptionTest, TestGetSpaceToDepthOptions) { - 
auto model = - litert::testing::LoadTestFileModel("simple_space_to_depth_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t block_size; - LITERT_ASSERT_OK(LiteRtGetSpaceToDepthBlockSizeOption(op, &block_size)); - ASSERT_EQ(block_size, 2); -} - -TEST(GetOpOptionTest, TestGetResizeNearestNeighborOptions) { - auto model = litert::testing::LoadTestFileModel( - "simple_resize_nearest_neighbor_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool align_corners; - LITERT_ASSERT_OK( - LiteRtGetResizeNearestNeighborAlignCornersOption(op, &align_corners)); - ASSERT_EQ(align_corners, false); - bool half_pixel_centers; - LITERT_ASSERT_OK(LiteRtGetResizeNearestNeighborHalfPixelCenterOption( - op, &half_pixel_centers)); - ASSERT_EQ(half_pixel_centers, true); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc deleted file mode 100644 index 30588753ecb1ff..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc +++ /dev/null @@ -1,480 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
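[Editor's note] For reviewers tracing this removal, a minimal sketch of how the per-op getters deleted above were typically driven. It mirrors the deleted litert_options_test.cc, so `simple_add_op.tflite` and the litert::testing helpers are the ones referenced there rather than new assumptions; the wrapper function itself is illustrative only.

// Sketch only: reads the fused activation of an add op via the removed C API.
#include <cstdint>

#include "tensorflow/lite/experimental/litert/c/litert_options.h"
#include "tensorflow/lite/experimental/litert/test/common.h"

bool AddHasNoFusedActivation() {
  auto model = litert::testing::LoadTestFileModel("simple_add_op.tflite");
  auto subgraph = model.MainSubgraph();
  if (!subgraph) return false;
  LiteRtOp op = subgraph->Ops().front().Get();

  uint32_t fused_activation = 0;
  // Each getter validates the op code first and returns
  // kLiteRtStatusErrorInvalidArgument on a mismatch.
  if (LiteRtGetAddFusedActivationOption(op, &fused_activation) !=
      kLiteRtStatusOk) {
    return false;
  }
  return fused_activation == 0;  // 0 corresponds to kTfLiteActNone.
}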
- -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" - -#if LITERT_HAS_OPENCL_SUPPORT -#include -#endif // LITERT_HAS_OPENCL_SUPPORT - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtCreateTensorBufferFromHostMemory( - const LiteRtRankedTensorType* tensor_type, void* host_buffer_addr, - size_t size, LiteRtHostMemoryDeallocator deallocator, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !host_buffer_addr || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromHostMemory( - *tensor_type, - absl::MakeSpan(static_cast(host_buffer_addr), size), - deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -#if LITERT_HAS_AHWB_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromAhwb( - const LiteRtRankedTensorType* tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset, LiteRtAhwbDeallocator deallocator, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !ahwb || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromAhwb( - *tensor_type, ahwb, ahwb_offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferAhwb(LiteRtTensorBuffer tensor_buffer, - AHardwareBuffer** ahwb) { - if (!tensor_buffer || !ahwb) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto ahwb_buffer = tensor_buffer->GetAhwbBuffer(); - if (!ahwb_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", ahwb_buffer.Error().Message().c_str()); - return ahwb_buffer.Error().Status(); - } - - *ahwb = *ahwb_buffer; - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_AHWB_SUPPORT - -#if LITERT_HAS_ION_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromIonBuffer( - const LiteRtRankedTensorType* tensor_type, void* ion_buffer_addr, - int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, - LiteRtIonDeallocator deallocator, LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromIonBuffer( - *tensor_type, ion_buffer_addr, ion_buffer_fd, ion_buffer_size, - ion_buffer_offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferIonBuffer(LiteRtTensorBuffer tensor_buffer, - void** 
ion_buffer_addr, - int* ion_buffer_fd) { - if (!tensor_buffer || !ion_buffer_addr || !ion_buffer_fd) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto ion_buffer = tensor_buffer->GetIonBuffer(); - if (!ion_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", ion_buffer.Error().Message().c_str()); - return ion_buffer.Error().Status(); - } - - *ion_buffer_addr = ion_buffer->first; - *ion_buffer_fd = ion_buffer->second; - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_ION_SUPPORT - -#if LITERT_HAS_DMABUF_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromDmaBufBuffer( - const LiteRtRankedTensorType* tensor_type, void* dmabuf_buffer_addr, - int dmabuf_buffer_fd, size_t dmabuf_buffer_size, - size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromDmaBufBuffer( - *tensor_type, dmabuf_buffer_addr, dmabuf_buffer_fd, dmabuf_buffer_size, - dmabuf_buffer_offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferDmaBufBuffer(LiteRtTensorBuffer tensor_buffer, - void** dmabuf_buffer_addr, - int* dmabuf_buffer_fd) { - if (!tensor_buffer || !dmabuf_buffer_addr || !dmabuf_buffer_fd) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto dmabuf_buffer = tensor_buffer->GetDmaBufBuffer(); - if (!dmabuf_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", dmabuf_buffer.Error().Message().c_str()); - return dmabuf_buffer.Error().Status(); - } - - *dmabuf_buffer_addr = dmabuf_buffer->first; - *dmabuf_buffer_fd = dmabuf_buffer->second; - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_DMABUF_SUPPORT - -#if LITERT_HAS_OPENCL_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromOpenClBuffer( - const LiteRtRankedTensorType* tensor_type, cl_mem cl_mem_addr, - size_t opencl_buffer_size, LiteRtOpenClDeallocator deallocator, - LiteRtTensorBuffer* buffer) { - if (!tensor_type || !buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromOpenClBuffer( - *tensor_type, cl_mem_addr, opencl_buffer_size); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferOpenClBuffer(LiteRtTensorBuffer tensor_buffer, - cl_mem* cl_mem_addr) { - if (!tensor_buffer || !cl_mem_addr) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto opencl_buffer = tensor_buffer->GetOpenClBuffer(); - if (!opencl_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", opencl_buffer.Error().Message().c_str()); - return opencl_buffer.Error().Status(); - } - - *cl_mem_addr = (*opencl_buffer)->GetMemoryPtr(); - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_OPENCL_SUPPORT - -#if LITERT_HAS_FASTRPC_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromFastRpcBuffer( - const LiteRtRankedTensorType* tensor_type, void* fastrpc_buffer_addr, - int fastrpc_buffer_fd, size_t fastrpc_buffer_size, - size_t fastrpc_buffer_offset, LiteRtFastRpcDeallocator deallocator, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return 
kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromFastRpcBuffer( - *tensor_type, fastrpc_buffer_addr, fastrpc_buffer_fd, fastrpc_buffer_size, - fastrpc_buffer_offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferFastRpcBuffer( - LiteRtTensorBuffer tensor_buffer, void** fastrpc_buffer_addr, - int* fastrpc_buffer_fd) { - if (!tensor_buffer || !fastrpc_buffer_addr || !fastrpc_buffer_fd) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto fastrpc_buffer = tensor_buffer->GetFastRpcBuffer(); - if (!fastrpc_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", fastrpc_buffer.Error().Message().c_str()); - return fastrpc_buffer.Error().Status(); - } - - *fastrpc_buffer_addr = fastrpc_buffer->first; - *fastrpc_buffer_fd = fastrpc_buffer->second; - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_FASTRPC_SUPPORT - -LiteRtStatus LiteRtCreateTensorBufferFromGlBuffer( - const LiteRtRankedTensorType* tensor_type, LiteRtGLenum target, - LiteRtGLuint id, size_t size_bytes, size_t offset, - LiteRtGlBufferDeallocator deallocator, LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromGlBuffer( - *tensor_type, target, id, size_bytes, offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().data()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferGlBuffer(LiteRtTensorBuffer tensor_buffer, - LiteRtGLenum* target, - LiteRtGLuint* id, size_t* size_bytes, - size_t* offset) { - if (!tensor_buffer || !target || !id) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto gl_buffer_expected = tensor_buffer->GetGlBuffer(); - if (!gl_buffer_expected) { - LITERT_LOG(LITERT_ERROR, "%s", - gl_buffer_expected.Error().Message().c_str()); - return gl_buffer_expected.Error().Status(); - } - *target = (*gl_buffer_expected)->target(); - *id = (*gl_buffer_expected)->id(); - *size_bytes = (*gl_buffer_expected)->size_bytes(); - *offset = (*gl_buffer_expected)->offset(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateTensorBufferFromGlTexture( - const LiteRtRankedTensorType* tensor_type, LiteRtGLenum target, - LiteRtGLuint id, LiteRtGLenum format, size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator, LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromGlTexture( - *tensor_type, target, id, format, size_bytes, layer, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferGlTexture( - LiteRtTensorBuffer tensor_buffer, LiteRtGLenum* target, LiteRtGLuint* id, - LiteRtGLenum* format, size_t* size_bytes, LiteRtGLint* layer) { - if (!tensor_buffer || !target || !id || !format || !size_bytes || !layer) { 
- return kLiteRtStatusErrorInvalidArgument; - } - auto gl_texture_expected = tensor_buffer->GetGlTexture(); - if (!gl_texture_expected) { - LITERT_LOG(LITERT_ERROR, "%s", - gl_texture_expected.Error().Message().c_str()); - return gl_texture_expected.Error().Status(); - } - *target = (*gl_texture_expected)->target(); - *id = (*gl_texture_expected)->id(); - *format = (*gl_texture_expected)->format(); - *size_bytes = (*gl_texture_expected)->size_bytes(); - *layer = (*gl_texture_expected)->layer(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateManagedTensorBuffer( - LiteRtTensorBufferType buffer_type, - const LiteRtRankedTensorType* tensor_type, size_t buffer_size, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateManaged( - buffer_type, *tensor_type, buffer_size); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDuplicateTensorBuffer(LiteRtTensorBuffer tensor_buffer) { - if (!tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - tensor_buffer->Duplicate(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferType(LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferType* buffer_type) { - if (!tensor_buffer || !buffer_type) { - return kLiteRtStatusErrorInvalidArgument; - } - *buffer_type = tensor_buffer->buffer_type(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferTensorType( - LiteRtTensorBuffer tensor_buffer, LiteRtRankedTensorType* tensor_type) { - if (!tensor_buffer || !tensor_type) { - return kLiteRtStatusErrorInvalidArgument; - } - *tensor_type = tensor_buffer->tensor_type(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferSize(LiteRtTensorBuffer tensor_buffer, - size_t* buffer_size) { - if (!tensor_buffer || !buffer_size) { - return kLiteRtStatusErrorInvalidArgument; - } - *buffer_size = tensor_buffer->buffer_size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferOffset(LiteRtTensorBuffer tensor_buffer, - size_t* buffer_offset) { - if (!tensor_buffer || !buffer_offset) { - return kLiteRtStatusErrorInvalidArgument; - } - *buffer_offset = tensor_buffer->buffer_offset(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferHostMemory(LiteRtTensorBuffer tensor_buffer, - void** host_memory_addr) { - if (!tensor_buffer || !host_memory_addr) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto host_buffer = tensor_buffer->GetHostBuffer(); - if (!host_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", host_buffer.Error().Message().c_str()); - return host_buffer.Error().Status(); - } - - *host_memory_addr = *host_buffer; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtHasTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - bool* has_event) { - if (!tensor_buffer || !has_event) { - return kLiteRtStatusErrorInvalidArgument; - } - *has_event = tensor_buffer->HasEvent(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - LiteRtEvent* event) { - if (!tensor_buffer || !event) { - return kLiteRtStatusErrorInvalidArgument; - } - auto result = tensor_buffer->GetEvent(); - if (!result) { - LITERT_LOG(LITERT_ERROR, "%s", result.Error().Message().c_str()); - return result.Error().Status(); 
- } - *event = *result; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtSetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - LiteRtEvent event) { - if (!tensor_buffer || !event) { - return kLiteRtStatusErrorInvalidArgument; - } - tensor_buffer->SetEvent(event); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtClearTensorBufferEvent(LiteRtTensorBuffer tensor_buffer) { - if (!tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - tensor_buffer->ClearEvent(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer, - void** host_mem_addr) { - if (!tensor_buffer || !host_mem_addr) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto mapped_addr = tensor_buffer->Lock(); - if (!mapped_addr) { - LITERT_LOG(LITERT_ERROR, "%s", mapped_addr.Error().Message().c_str()); - return mapped_addr.Error().Status(); - } - - *host_mem_addr = *mapped_addr; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtUnlockTensorBuffer(LiteRtTensorBuffer tensor_buffer) { - if (!tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - - if (auto status = tensor_buffer->Unlock(); !status) { - LITERT_LOG(LITERT_ERROR, "%s", status.Error().Message().c_str()); - return status.Error().Status(); - } - - return kLiteRtStatusOk; -} - -void LiteRtDestroyTensorBuffer(LiteRtTensorBuffer tensor_buffer) { - if (tensor_buffer->Unref()) { - delete tensor_buffer; - } -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h deleted file mode 100644 index 7f1fd8af836e1a..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#if LITERT_HAS_OPENCL_SUPPORT -#include -#endif // LITERT_HAS_OPENCL_SUPPORT -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" - -#if LITERT_HAS_AHWB_SUPPORT -#include -#else -// Define a placeholder AHardwareBuffer struct just to enable compilation. 
-#ifdef __cplusplus -extern "C" { -#endif // __cplusplus -typedef struct AHardwareBuffer AHardwareBuffer; -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus -#endif // LITERT_HAS_AHWB_SUPPORT - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtTensorBuffer); - -#define LITERT_HOST_MEMORY_BUFFER_ALIGNMENT 64 - -typedef void (*LiteRtHostMemoryDeallocator)(void* addr); -typedef void (*LiteRtAhwbDeallocator)(AHardwareBuffer* ahwb); -typedef void (*LiteRtIonDeallocator)(void* ion_buffer_addr); -typedef void (*LiteRtDmaBufDeallocator)(void* dmabuf_buffer_addr); -typedef void (*LiteRtFastRpcDeallocator)(void* fastrpc_buffer_addr); -typedef void (*LiteRtOpenClDeallocator)(void* opencl_buffer_addr); -typedef void (*LiteRtGlBufferDeallocator)(void* gl_buffer_addr); -typedef void (*LiteRtGlTextureDeallocator)(void* gl_texture_addr); - -// ///////////////////////////////////////////////////////////////////////////// -// TensorBuffers. -// ///////////////////////////////////////////////////////////////////////////// - -// Create a tensor buffer from an existing host memory buffer of a given size, -// with optional host memory buffer deallocator (it can be NULL). Return an -// error if the passed host memory buffer doesn't satisfy -// LITERT_HOST_MEMORY_BUFFER_ALIGNMENT alignment. -LiteRtStatus LiteRtCreateTensorBufferFromHostMemory( - const LiteRtRankedTensorType* tensor_type, void* host_buffer_addr, - size_t host_buffer_size, LiteRtHostMemoryDeallocator deallocator, - LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not allocated on the host memory. -LiteRtStatus LiteRtGetTensorBufferHostMemory(LiteRtTensorBuffer tensor_buffer, - void** host_memory_addr); - -#if LITERT_HAS_AHWB_SUPPORT -// Create a tensor buffer from an existing AHardwareBuffer, with optional -// AHardwareBuffer deallocator (it can be NULL). A non-zero `buffer_offset` can -// be used to specify multiple tensor buffers sharing the same underlying AHWB, -// in which case the provided AHWB must be sufficiently large to accommodate -// the allocation needed for all tensor buffers sharing it. -LiteRtStatus LiteRtCreateTensorBufferFromAhwb( - const LiteRtRankedTensorType* tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset, LiteRtAhwbDeallocator deallocator, - LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not an AHardwareBuffer. -LiteRtStatus LiteRtGetTensorBufferAhwb(LiteRtTensorBuffer tensor_buffer, - AHardwareBuffer** ahwb); -#endif // LITERT_HAS_AHWB_SUPPORT - -#if LITERT_HAS_ION_SUPPORT -// Create a tensor buffer from an existing ION buffer of a given size, with -// optional ION buffer deallocator (it can be NULL). A non-zero -// `ion_buffer_offset` can be used to specify multiple tensor buffers sharing -// the same underlying ION buffer, in which case parameter `ion_buffer_size` -// must be the entire size of the underlying ION memory buffer, including the -// allocation needed for all tensor buffers sharing it. -LiteRtStatus LiteRtCreateTensorBufferFromIonBuffer( - const LiteRtRankedTensorType* tensor_type, void* ion_buffer_addr, - int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, - LiteRtIonDeallocator deallocator, LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not an ION buffer. 
-LiteRtStatus LiteRtGetTensorBufferIonBuffer(LiteRtTensorBuffer buffer, - void** ion_buffer_addr, - int* ion_buffer_fd); -#endif // LITERT_HAS_ION_SUPPORT - -#if LITERT_HAS_DMABUF_SUPPORT -// Create a tensor buffer from an existing DMA-BUF buffer of a given size, with -// optional DMA-BUF buffer deallocator (it can be NULL). A non-zero -// `dmabuf_buffer_offset` can be used to specify multiple tensor buffers sharing -// the same underlying DMA-BUF buffer, in which case parameter -// `dmabuf_buffer_size` must be the entire size of the underlying DMA-BUF memory -// buffer, including the allocation needed for all tensor buffers sharing it. -LiteRtStatus LiteRtCreateTensorBufferFromDmaBufBuffer( - const LiteRtRankedTensorType* tensor_type, void* dmabuf_buffer_addr, - int dmabuf_buffer_fd, size_t dmabuf_buffer_size, - size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator, - LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not a DMA-BUF buffer. -LiteRtStatus LiteRtGetTensorBufferDmaBufBuffer(LiteRtTensorBuffer tensor_buffer, - void** dmabuf_buffer_addr, - int* dmabuf_buffer_fd); -#endif // LITERT_HAS_DMABUF_SUPPORT - -#if LITERT_HAS_FASTRPC_SUPPORT -// Create a tensor buffer from an existing FastRPC memory buffer of a given -// size, with optional FastRPC memory buffer deallocator (it can be NULL). A -// non-zero `fastrpc_buffer_offset` can be used to specify multiple tensor -// buffers sharing the same underlying FastRPC memory buffer, in which case -// parameter `fastrpc_buffer_size` must be the entire size of the underlying -// FastRPC memory buffer, including the allocation needed for all tensor buffers -// sharing it. -LiteRtStatus LiteRtCreateTensorBufferFromFastRpcBuffer( - const LiteRtRankedTensorType* tensor_type, void* fastrpc_buffer_addr, - int fastrpc_fd, size_t fastrpc_buffer_size, size_t fastrpc_buffer_offset, - LiteRtFastRpcDeallocator deallocator, LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not a FastRPC memory buffer. -LiteRtStatus LiteRtGetTensorBufferFastRpcBuffer( - LiteRtTensorBuffer tensor_buffer, void** fastrpc_buffer_addr, - int* fastrpc_buffer_fd); -#endif // LITERT_HAS_FASTRPC_SUPPORT - -#if LITERT_HAS_OPENCL_SUPPORT -// Create a tensor buffer from an existing OpenCL buffer of a given size, with -// optional OpenCL memory buffer deallocator (it can be NULL). A non-zero -// `opencl_buffer_offset` can be used to specify multiple tensor buffers sharing -// the same underlying OpenCL buffer, in which case parameter -// `opencl_buffer_size` must be the entire size of the underlying OpenCL -// memory buffer, including the allocation needed for all tensor buffers -// sharing it. -LiteRtStatus LiteRtCreateTensorBufferFromOpenClBuffer( - const LiteRtRankedTensorType* tensor_type, cl_mem cl_mem_addr, - size_t opencl_buffer_size, LiteRtOpenClDeallocator deallocator, - LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not an OpenCL buffer. 
-LiteRtStatus LiteRtGetTensorBufferOpenClBuffer(LiteRtTensorBuffer tensor_buffer, - cl_mem* cl_mem_addr); -#endif // LITERT_HAS_OPENCL_SUPPORT - -LiteRtStatus LiteRtCreateTensorBufferFromGlBuffer( - const LiteRtRankedTensorType* tensor_type, LiteRtGLenum target, - LiteRtGLuint id, size_t size_bytes, size_t offset, - LiteRtGlBufferDeallocator deallocator, LiteRtTensorBuffer* buffer); - -LiteRtStatus LiteRtGetTensorBufferGlBuffer(LiteRtTensorBuffer tensor_buffer, - LiteRtGLenum* target, - LiteRtGLuint* id, size_t* size_bytes, - size_t* offset); - -LiteRtStatus LiteRtCreateTensorBufferFromGlTexture( - const LiteRtRankedTensorType* tensor_type, LiteRtGLenum target, - LiteRtGLuint id, LiteRtGLenum format, size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator, LiteRtTensorBuffer* buffer); - -LiteRtStatus LiteRtGetTensorBufferGlTexture( - LiteRtTensorBuffer tensor_buffer, LiteRtGLenum* target, LiteRtGLuint* id, - LiteRtGLenum* format, size_t* size_bytes, LiteRtGLint* layer); - -// Create a buffer backed by managed memory for a given size. -LiteRtStatus LiteRtCreateManagedTensorBuffer( - LiteRtTensorBufferType buffer_type, - const LiteRtRankedTensorType* tensor_type, size_t buffer_size, - LiteRtTensorBuffer* buffer); - -// Create a duplicate of the current tensor buffer. It increases the reference -// count of a managed tensor buffer; the count decreases when -// LiteRtDestroyTensorBuffer() is called. -LiteRtStatus LiteRtDuplicateTensorBuffer(LiteRtTensorBuffer tensor_buffer); - -LiteRtStatus LiteRtGetTensorBufferType(LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferType* buffer_type); - -LiteRtStatus LiteRtGetTensorBufferTensorType( - LiteRtTensorBuffer tensor_buffer, LiteRtRankedTensorType* tensor_type); - -LiteRtStatus LiteRtGetTensorBufferSize(LiteRtTensorBuffer tensor_buffer, - size_t* size); - -LiteRtStatus LiteRtGetTensorBufferOffset(LiteRtTensorBuffer tensor_buffer, - size_t* offset); - -LiteRtStatus LiteRtHasTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - bool* has_event); - -// Return an event attached to a given tensor buffer, or NULL if no such event -// exists. The tensor buffer retains ownership of the returned event. -LiteRtStatus LiteRtGetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - LiteRtEvent* event); - -// Attach a given event to a given tensor buffer. The tensor buffer takes -// ownership of the event. -LiteRtStatus LiteRtSetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - LiteRtEvent event); - -// Remove any event that may have been previously attached to the given tensor -// buffer and deallocate such event. -LiteRtStatus LiteRtClearTensorBufferEvent(LiteRtTensorBuffer tensor_buffer); - -// Lock a tensor buffer and map it to host memory, potentially synchronizing on -// an event that was previously attached to the tensor buffer with -// `LiteRtSetTensorBufferEvent`. -LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer, - void** host_mem_addr); - -// Unlock a tensor buffer and (potentially) unmap it from host memory. -LiteRtStatus LiteRtUnlockTensorBuffer(LiteRtTensorBuffer buffer); - -// Destroy a tensor buffer. If the tensor buffer is managed, the number of -// references to it is decreased, and the underlying TensorBufferT is released -// when the last reference is removed. 
-void LiteRtDestroyTensorBuffer(LiteRtTensorBuffer buffer); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc deleted file mode 100644 index fce2e4049f88e2..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h" - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtCreateTensorBufferRequirements( - int num_supported_tensor_buffer_types, - const LiteRtTensorBufferType* supported_tensor_buffer_types, - size_t buffer_size, int num_strides, const uint32_t* strides, - LiteRtTensorBufferRequirements* requirements) { - if (num_supported_tensor_buffer_types < 1 || !supported_tensor_buffer_types || - !requirements) { - return kLiteRtStatusErrorInvalidArgument; - } - *requirements = new LiteRtTensorBufferRequirementsT( - num_supported_tensor_buffer_types, supported_tensor_buffer_types, - buffer_size, std::vector(strides, strides + num_strides)); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - LiteRtTensorBufferRequirements requirements, int* num_types) { - if (!requirements || !num_types) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_types = requirements->SupportedBufferTypes().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - LiteRtTensorBufferRequirements requirements, int type_index, - LiteRtTensorBufferType* type) { - if (!requirements || type_index < 0 || - type_index >= requirements->SupportedBufferTypes().size()) { - return kLiteRtStatusErrorInvalidArgument; - } - *type = requirements->SupportedBufferTypes()[type_index]; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferRequirementsBufferSize( - LiteRtTensorBufferRequirements requirements, size_t* buffer_size) { - if (!requirements || !buffer_size) { - return kLiteRtStatusErrorInvalidArgument; - } - *buffer_size = requirements->BufferSize(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferRequirementsStrides( - LiteRtTensorBufferRequirements requirements, int* num_strides, - const uint32_t** strides) { - if (!requirements || !num_strides || !strides) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& s = requirements->Strides(); - *num_strides = s.size(); - *strides = s.data(); - return kLiteRtStatusOk; -} - -void 
LiteRtDestroyTensorBufferRequirements( - LiteRtTensorBufferRequirements requirements) { - delete requirements; -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h deleted file mode 100644 index 1c691a3ee38e9f..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtTensorBufferRequirements); - -LiteRtStatus LiteRtCreateTensorBufferRequirements( - int num_supported_tensor_buffer_types, - const LiteRtTensorBufferType* supported_tensor_buffer_types, - size_t buffer_size, int num_strides, const uint32_t* strides, - LiteRtTensorBufferRequirements* requirements); - -LiteRtStatus LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - LiteRtTensorBufferRequirements requirements, int* num_types); - -LiteRtStatus LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - LiteRtTensorBufferRequirements requirements, int type_index, - LiteRtTensorBufferType* type); - -LiteRtStatus LiteRtGetTensorBufferRequirementsBufferSize( - LiteRtTensorBufferRequirements requirements, size_t* buffer_size); - -LiteRtStatus LiteRtGetTensorBufferRequirementsStrides( - LiteRtTensorBufferRequirements requirements, int* num_strides, - const uint32_t** strides); - -void LiteRtDestroyTensorBufferRequirements( - LiteRtTensorBufferRequirements requirements); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc deleted file mode 100644 index 6a61eff786cbc9..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" - -#include -#include -#include - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -namespace { - -constexpr const LiteRtTensorBufferType kSupportedTensorBufferTypes[] = { - kLiteRtTensorBufferTypeHostMemory, - kLiteRtTensorBufferTypeAhwb, - kLiteRtTensorBufferTypeIon, - kLiteRtTensorBufferTypeFastRpc, -}; - -constexpr const size_t kNumSupportedTensorBufferTypes = - sizeof(kSupportedTensorBufferTypes) / - sizeof(kSupportedTensorBufferTypes[0]); - -constexpr const size_t kBufferSize = 1234; - -} // namespace - -TEST(TensorBufferRequirements, NoStrides) { - LiteRtTensorBufferRequirements requirements; - ASSERT_EQ(LiteRtCreateTensorBufferRequirements( - kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes, - kBufferSize, - /*num_strides=*/0, /*strides=*/nullptr, &requirements), - kLiteRtStatusOk); - - int num_types; - ASSERT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - requirements, &num_types), - kLiteRtStatusOk); - ASSERT_EQ(num_types, kNumSupportedTensorBufferTypes); - - for (auto i = 0; i < num_types; ++i) { - LiteRtTensorBufferType type; - ASSERT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - requirements, i, &type), - kLiteRtStatusOk); - ASSERT_EQ(type, kSupportedTensorBufferTypes[i]); - } - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferRequirementsBufferSize(requirements, &size), - kLiteRtStatusOk); - ASSERT_EQ(size, kBufferSize); - - LiteRtDestroyTensorBufferRequirements(requirements); -} - -TEST(TensorBufferRequirements, WithStrides) { - constexpr std::array kStrides = {1, 2, 3}; - - LiteRtTensorBufferRequirements requirements; - ASSERT_EQ(LiteRtCreateTensorBufferRequirements( - kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes, - kBufferSize, kStrides.size(), kStrides.data(), &requirements), - kLiteRtStatusOk); - - int num_strides; - const uint32_t* strides; - ASSERT_EQ(LiteRtGetTensorBufferRequirementsStrides(requirements, &num_strides, - &strides), - kLiteRtStatusOk); - ASSERT_EQ(num_strides, kStrides.size()); - for (auto i = 0; i < kStrides.size(); ++i) { - ASSERT_EQ(strides[i], kStrides[i]); - } - - LiteRtDestroyTensorBufferRequirements(requirements); -} diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc deleted file mode 100644 index c77388d382f5de..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc +++ /dev/null @@ -1,439 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
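-// See the License for the specific language governing permissions and
-// limitations under the License.

// A hedged illustration, not part of the original patch: the tests deleted
// below all drive the same create/lock/copy/unlock cycle, and the managed
// reference-counting contract documented in litert_tensor_buffer.h can be
// sketched with nothing but the C API removed above. `tensor_type` and
// `size` stand in for any valid arguments.
//
//   void DuplicateDestroySketch(const LiteRtRankedTensorType* tensor_type,
//                               size_t size) {
//     LiteRtTensorBuffer buf = nullptr;
//     if (LiteRtCreateManagedTensorBuffer(kLiteRtTensorBufferTypeHostMemory,
//                                         tensor_type, size,
//                                         &buf) != kLiteRtStatusOk) {
//       return;
//     }
//     LiteRtDuplicateTensorBuffer(buf);  // managed refcount: 1 -> 2
//     LiteRtDestroyTensorBuffer(buf);    // refcount: 2 -> 1, storage kept
//     LiteRtDestroyTensorBuffer(buf);    // refcount: 1 -> 0, storage freed
//   }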
- -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#include -#include - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/event.h" -#include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h" - -namespace { -constexpr const float kTensorData[] = {10, 20, 30, 40}; - -constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) / - sizeof(kTensorData[0])}; - -constexpr const LiteRtRankedTensorType kTensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTensorDimensions)}; - -} // namespace - -TEST(TensorBuffer, HostMemory) { - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, Ahwb) { - if (!litert::internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " - "skipping the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeAhwb; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - 
ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, Ion) { - if (!litert::internal::IonBuffer::IsSupported()) { - GTEST_SKIP() - << "ION buffers are not supported on this platform; skipping the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeIon; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, DmaBuf) { - if (!litert::internal::DmaBufBuffer::IsSupported()) { - GTEST_SKIP() - << "DMA-BUF buffers are not supported on this platform; skipping " - "the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeDmaBuf; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - 
ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, FastRpc) { - if (!litert::internal::FastRpcBuffer::IsSupported()) { - GTEST_SKIP() - << "FastRPC buffers are not supported on this platform; skipping " - "the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeFastRpc; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, Event) { - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory; - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - bool has_event = true; - ASSERT_EQ(LiteRtHasTensorBufferEvent(tensor_buffer, &has_event), - kLiteRtStatusOk); - EXPECT_FALSE(has_event); - - LiteRtEvent event = new LiteRtEventT; - ASSERT_EQ(LiteRtSetTensorBufferEvent(tensor_buffer, event), kLiteRtStatusOk); - - has_event = 
false; - ASSERT_EQ(LiteRtHasTensorBufferEvent(tensor_buffer, &has_event), - kLiteRtStatusOk); - EXPECT_TRUE(has_event); - - LiteRtEvent actual_event; - ASSERT_EQ(LiteRtGetTensorBufferEvent(tensor_buffer, &actual_event), - kLiteRtStatusOk); - ASSERT_EQ(actual_event, event); - - ASSERT_EQ(LiteRtClearTensorBufferEvent(tensor_buffer), kLiteRtStatusOk); - ASSERT_EQ(actual_event, event); - - has_event = true; - ASSERT_EQ(LiteRtHasTensorBufferEvent(tensor_buffer, &has_event), - kLiteRtStatusOk); - EXPECT_FALSE(has_event); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, OpenCL) { -// MSAN does not support GPU tests. -#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported In msan or tsan"; -#endif - - if (!litert::internal::OpenClBuffer::IsSupported()) { - GTEST_SKIP() << "OpenCL buffers are not supported on this platform; " - "skipping the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeOpenCl; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -#if LITERT_HAS_OPENGL_SUPPORT -TEST(TensorBuffer, GlBuffer) { -// MSAN does not support GPU tests. 
-#if defined(MEMORY_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported In msan"; -#endif - - if (!litert::internal::GlBuffer::IsSupported()) { - GTEST_SKIP() << "OpenGL buffers are not supported on this platform; " - "skipping the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeGlBuffer; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} -#endif // LITERT_HAS_OPENGL_SUPPORT diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h deleted file mode 100644 index 1953915c153c98..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
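// An illustrative aside, not part of the original patch: consumers commonly
// switch over the buffer-type enum below, e.g. for logging. The repo's
// litert_tensor_buffer_utils target provides similar functionality, but its
// exact signature is not shown in this patch, so the helper here is assumed.
//
//   const char* TensorBufferTypeName(LiteRtTensorBufferType type) {
//     switch (type) {
//       case kLiteRtTensorBufferTypeUnknown:    return "Unknown";
//       case kLiteRtTensorBufferTypeHostMemory: return "HostMemory";
//       case kLiteRtTensorBufferTypeAhwb:       return "Ahwb";
//       case kLiteRtTensorBufferTypeIon:        return "Ion";
//       case kLiteRtTensorBufferTypeDmaBuf:     return "DmaBuf";
//       case kLiteRtTensorBufferTypeFastRpc:    return "FastRpc";
//       case kLiteRtTensorBufferTypeOpenCl:     return "OpenCl";
//       case kLiteRtTensorBufferTypeGlBuffer:   return "GlBuffer";
//       case kLiteRtTensorBufferTypeGlTexture:  return "GlTexture";
//     }
//     return "Unknown";
//   }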
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_TYPES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_TYPES_H_ - -typedef enum { - kLiteRtTensorBufferTypeUnknown = 0, - kLiteRtTensorBufferTypeHostMemory = 1, - kLiteRtTensorBufferTypeAhwb = 2, - kLiteRtTensorBufferTypeIon = 3, - kLiteRtTensorBufferTypeDmaBuf = 4, - kLiteRtTensorBufferTypeFastRpc = 5, - kLiteRtTensorBufferTypeOpenCl = 6, - kLiteRtTensorBufferTypeGlBuffer = 7, - kLiteRtTensorBufferTypeGlTexture = 8, -} LiteRtTensorBufferType; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_TYPES_H_ diff --git a/tensorflow/lite/experimental/litert/cc/BUILD b/tensorflow/lite/experimental/litert/cc/BUILD deleted file mode 100644 index 566b35f370b70b..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/BUILD +++ /dev/null @@ -1,708 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], -) - -cc_library( - name = "litert_environment", - hdrs = ["litert_environment.h"], - deps = [ - ":litert_any", - ":litert_expected", - ":litert_handle", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "litert_environment_test", - srcs = [ - "litert_environment_test.cc", - ], - linkopts = select({ - "//tensorflow:android": ["-llog"], - "//conditions:default": [], - }), - deps = [ - ":litert_any", - ":litert_compiled_model", - ":litert_environment", - ":litert_expected", - ":litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_event", - hdrs = ["litert_event.h"], - deps = [ - ":litert_expected", - ":litert_handle", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_event", - "//tensorflow/lite/experimental/litert/c:litert_event_type", - ], -) - -cc_library( - name = "litert_any", - hdrs = ["litert_any.h"], - deps = [ - ":litert_expected", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "litert_any_test", - srcs = [ - "litert_any_test.cc", - ], - linkopts = select({ - "//tensorflow:android": ["-llog"], - "//conditions:default": [], - }), - deps = [ - ":litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = 
"litert_consts", - hdrs = [ - "litert_consts.h", - ], -) - -cc_library( - name = "litert_model", - srcs = ["litert_model.cc"], - hdrs = [ - "litert_model.h", - ], - deps = [ - ":litert_buffer_ref", - ":litert_consts", - ":litert_detail", - ":litert_element_type", - ":litert_expected", - ":litert_handle", - ":litert_layout", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "litert_model_test", - srcs = [ - "litert_model_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - ], - deps = [ - ":litert_element_type", - ":litert_layout", - ":litert_model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_handle", - hdrs = ["litert_handle.h"], -) - -cc_library( - name = "litert_tensor_buffer", - hdrs = [ - "litert_tensor_buffer.h", - "litert_tensor_buffer_requirements.h", - ], - deps = [ - ":litert_detail", - ":litert_event", - ":litert_expected", - ":litert_handle", - ":litert_macros", - ":litert_model", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event", - "//tensorflow/lite/experimental/litert/c:litert_gl_types", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "@com_google_absl//absl/types:span", - "@opencl_headers", - ], -) - -cc_test( - name = "litert_tensor_buffer_test", - srcs = [ - "litert_tensor_buffer_test.cc", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - deps = [ - ":litert_element_type", - ":litert_event", - ":litert_layout", - ":litert_macros", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/runtime:tensor_buffer", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:egl_environment", - ], - "//conditions:default": [], - }), -) - -cc_library( - name = "litert_tensor_buffer_requirements", - hdrs = [ - "litert_tensor_buffer_requirements.h", - ], - deps = [ - ":litert_detail", - ":litert_handle", - ":litert_macros", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = 
"litert_tensor_buffer_requirements_test", - srcs = [ - "litert_tensor_buffer_requirements_test.cc", - ], - deps = [ - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_buffer_ref", - hdrs = [ - "litert_buffer_ref.h", - ], - deps = [ - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "litert_macros", - srcs = ["litert_macros.cc"], - hdrs = ["litert_macros.h"], - deps = [ - ":litert_expected", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - -cc_test( - name = "litert_macros_test", - srcs = ["litert_macros_test.cc"], - deps = [ - ":litert_expected", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_expected", - hdrs = ["litert_expected.h"], - deps = [ - ":litert_detail", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_test( - name = "litert_expected_test", - srcs = ["litert_expected_test.cc"], - deps = [ - ":litert_buffer_ref", - ":litert_expected", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_detail", - hdrs = ["litert_detail.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/log:absl_check", - ], -) - -# Dispatch Delegate of LiteRt. -# Warning: This API is not ABI stable and is subject to change. 
-cc_library( - name = "litert_dispatch_delegate", - hdrs = [ - "litert_dispatch_delegate.h", - ], - deps = [ - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/delegates/utils:simple_opaque_delegate", - "//tensorflow/lite/experimental/litert/c:litert_environment_options", - "//tensorflow/lite/experimental/litert/runtime/dispatch:dispatch_delegate", - ], -) - -cc_test( - name = "litert_buffer_ref_test", - srcs = ["litert_buffer_ref_test.cc"], - deps = [ - ":litert_buffer_ref", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_element_type", - hdrs = ["litert_element_type.h"], - deps = ["//tensorflow/lite/experimental/litert/c:litert_model"], -) - -cc_test( - name = "litert_element_type_test", - srcs = ["litert_element_type_test.cc"], - deps = [ - ":litert_element_type", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_model_predicates", - srcs = ["litert_model_predicates.cc"], - hdrs = ["litert_model_predicates.h"], - deps = [ - ":litert_detail", - ":litert_model", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "litert_layout", - hdrs = ["litert_layout.h"], - deps = [ - ":litert_consts", - "//tensorflow/lite/experimental/litert/c:litert_layout", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "litert_model_predicates_test", - srcs = ["litert_model_predicates_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - ], - deps = [ - ":litert_element_type", - ":litert_model", - ":litert_model_predicates", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/test:common", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "litert_layout_test", - srcs = ["litert_layout_test.cc"], - deps = [ - ":litert_layout", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_compiled_model", - srcs = ["litert_compiled_model.cc"], - hdrs = ["litert_compiled_model.h"], - deps = [ - ":litert_compilation_options", - ":litert_environment", - ":litert_expected", - ":litert_handle", - ":litert_macros", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_compiled_model", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "litert_compilation_options", - hdrs = ["litert_compilation_options.h"], 
- deps = [ - ":litert_accelerator_compilation_options", - ":litert_environment", - ":litert_expected", - ":litert_handle", - ":litert_macros", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_compiled_model", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "litert_compiled_model_test", - srcs = ["litert_compiled_model_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - deps = [ - ":litert_compiled_model", - ":litert_environment", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "litert_compiled_model_integration_test", - srcs = ["litert_compiled_model_integration_test.cc"], - deps = [ - ":litert_buffer_ref", - ":litert_compiled_model", - ":litert_environment", - ":litert_event", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_dispatch_headers", - "//tensorflow/lite/experimental/litert/core/model:model_buffer", - "//tensorflow/lite/experimental/litert/runtime:tensor_buffer", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:egl_environment", - ], - "//conditions:default": [], - }), -) - -# copybara:uncomment_begin(google-only) -# cc_test( -# name = "litert_compiled_model_gpu_test", -# srcs = ["litert_compiled_model_gpu_test.cc"], -# data = [ -# "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", -# ], -# env = { -# "ASAN_OPTIONS": "detect_odr_violation=0", -# }, -# tags = [ -# "manual", -# "notap", -# "requires-gpu-nvidia", -# ], -# deps = [ -# ":litert_compiled_model", -# ":litert_environment", 
-# ":litert_event", -# ":litert_model", -# ":litert_tensor_buffer", -# "@com_google_googletest//:gtest_main", -# "@com_google_absl//absl/debugging:leak_check", -# "@com_google_absl//absl/log:absl_log", -# "@com_google_absl//absl/strings:string_view", -# "@com_google_absl//absl/types:span", -# "//third_party/odml/infra/ml_drift_delegate/litert:ml_drift_cl_accelerator", # buildcleaner: keep -# "//tensorflow/lite:framework", -# "//tensorflow/lite/c:c_api_opaque", -# "//tensorflow/lite/c:common", -# "//tensorflow/lite/experimental/litert/c:litert_common", -# "//tensorflow/lite/experimental/litert/c:litert_event", -# "//tensorflow/lite/experimental/litert/c:litert_event_type", -# "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", -# "//tensorflow/lite/experimental/litert/test:common", -# "//tensorflow/lite/experimental/litert/test:matchers", -# "//tensorflow/lite/experimental/litert/test:simple_model", -# "//tensorflow/lite/kernels:builtin_ops", -# ], -# ) -# -# # The same test as above, but for Android. -# # This test doesn't run on TAP. -# # libLiteRtGpuAccelerator.so and libLiteRtRuntimeCApi.so are required to run this test. -# cc_test( -# name = "litert_compiled_model_gpu_android_test", -# srcs = ["litert_compiled_model_gpu_test.cc"], -# data = [ -# "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", -# ], -# tags = [ -# "manual", -# "notap", -# ], -# deps = [ -# ":litert_compiled_model", -# ":litert_environment", -# ":litert_event", -# ":litert_model", -# ":litert_tensor_buffer", -# "@com_google_googletest//:gtest_main", -# "@com_google_absl//absl/debugging:leak_check", -# "@com_google_absl//absl/log:absl_log", -# "@com_google_absl//absl/strings:string_view", -# "@com_google_absl//absl/types:span", -# "//third_party/odml/infra/ml_drift_delegate/litert:ml_drift_cl_accelerator_shared_lib", # buildcleaner: keep -# "//tensorflow/lite:framework", -# "//tensorflow/lite/c:c_api_opaque", -# "//tensorflow/lite/c:common", -# "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", -# "//tensorflow/lite/experimental/litert/test:common", -# "//tensorflow/lite/experimental/litert/test:matchers", -# "//tensorflow/lite/experimental/litert/test:simple_model", -# "//tensorflow/lite/kernels:builtin_ops", -# ], -# ) -# copybara:uncomment_end - -cc_library( - name = "litert_tensor_buffer_utils", - srcs = ["litert_tensor_buffer_utils.cc"], - hdrs = ["litert_tensor_buffer_utils.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - ], -) - -cc_library( - name = "litert_op_options", - srcs = ["litert_op_options.cc"], - hdrs = ["litert_op_options.h"], - deps = [ - ":litert_expected", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "litert_op_options_test", - srcs = ["litert_op_options_test.cc"], - deps = [ - ":litert_op_options", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = 
"litert_shared_library", - srcs = ["litert_shared_library.cc"], - hdrs = ["litert_shared_library.h"], - deps = [ - ":litert_expected", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "test_litert_shared_library_impl", - srcs = ["test_shared_library.cc"], -) - -cc_shared_library( - name = "test_litert_shared_library", - shared_lib_name = "test_shared_library.so", - deps = [":test_litert_shared_library_impl"], -) - -cc_test( - name = "litert_shared_library_test", - srcs = ["litert_shared_library_test.cc"], - data = [":test_litert_shared_library"], - defines = ["LITERT_DEFINE_GTEST_STATUS_PRINTER"], - deps = [ - ":litert_shared_library", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "litert_event_test", - srcs = ["litert_event_test.cc"], - deps = [ - ":litert_event", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_accelerator_compilation_options", - hdrs = [ - "litert_accelerator_compilation_options.h", - ], - deps = [ - ":litert_expected", - ":litert_handle", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/strings:string_view", - ], -) - -exports_files(srcs = glob(["litert_*.h"])) diff --git a/tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h b/tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h deleted file mode 100644 index 80e36b6b4c980f..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -class AcceleratorCompilationOptions - : public internal::Handle { - public: - AcceleratorCompilationOptions() = default; - - // Parameter `owned` indicates if the created AcceleratorCompilationOptions - // object should take ownership of the provided `options` handle. 
- explicit AcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions options, bool owned = true) - : internal::Handle(options, - owned) {} - - static Expected Create( - const LiteRtApiVersion& payload_version, - const std::string& payload_identifier, void* payload_data, - void (*payload_destructor)(void* payload_data)) { - LiteRtAcceleratorCompilationOptions options; - LITERT_RETURN_IF_ERROR(LiteRtCreateAcceleratorCompilationOptions( - &payload_version, payload_identifier.c_str(), payload_data, - payload_destructor, &options)); - return AcceleratorCompilationOptions(options); - } - - Expected GetVersion() const { - LiteRtApiVersion payload_version; - LITERT_RETURN_IF_ERROR( - LiteRtGetAcceleratorCompilationOptionsVersion(Get(), &payload_version)); - return payload_version; - } - - Expected GetIdentifier() const { - const char* payload_identifier; - LITERT_RETURN_IF_ERROR(LiteRtGetAcceleratorCompilationOptionsIdentifier( - Get(), &payload_identifier)); - return absl::string_view(payload_identifier); - } - - template - Expected GetData() const { - void* payload_data; - LITERT_RETURN_IF_ERROR( - LiteRtGetAcceleratorCompilationOptionsData(Get(), &payload_data)); - return reinterpret_cast(payload_data); - } - - template - Expected> FindData( - const std::string& payload_identifier) { - LiteRtApiVersion payload_version; - void* payload_data; - LITERT_RETURN_IF_ERROR(LiteRtFindAcceleratorCompilationOptionsData( - Get(), payload_identifier.c_str(), &payload_version, &payload_data)); - return std::make_pair(payload_version, reinterpret_cast(payload_data)); - } - - Expected Next() { - auto h = Get(); - LITERT_RETURN_IF_ERROR(LiteRtGetNextAcceleratorCompilationOptions(&h)); - return AcceleratorCompilationOptions(h, /*owned=*/false); - } - - Expected Append(AcceleratorCompilationOptions&& appended_options) { - auto h = Get(); - LITERT_RETURN_IF_ERROR(LiteRtAppendAcceleratorCompilationOptions( - &h, appended_options.Release())); - if (h != Get()) { - // If appending a new linked list item has changed the linked list head - // pointer, then we need to reflect that as the new handle. Note that - // should happen only if the previous handle was null. - assert(!Get()); - *this = AcceleratorCompilationOptions(h); - } - return {}; - } - - Expected Pop() { - auto h = Get(); - LITERT_RETURN_IF_ERROR(LiteRtPopAcceleratorCompilationOptions(&h)); - if (h != Get()) { - // If popping the last item has changed the linked list head pointer, then - // we release the current handle since it has been already destructed by - // the pop call, and then use the new head pointer as the new handle. - (void)Release(); - *this = AcceleratorCompilationOptions(h); - } - return {}; - } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_any.h b/tensorflow/lite/experimental/litert/cc/litert_any.h deleted file mode 100644 index 97483ce3d63dcb..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_any.h +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_
-
-#include <any>
-#include <cstdint>
-#include <limits>
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/c/litert_any.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-
-namespace litert {
-
-inline std::any ToStdAny(LiteRtAny litert_any) {
-  std::any res;
-  switch (litert_any.type) {
-    case kLiteRtAnyTypeNone:
-      break;
-    case kLiteRtAnyTypeBool:
-      res = litert_any.bool_value;
-      break;
-    case kLiteRtAnyTypeInt:
-      res = litert_any.int_value;
-      break;
-    case kLiteRtAnyTypeReal:
-      res = litert_any.real_value;
-      break;
-    case kLiteRtAnyTypeString:
-      res = litert_any.str_value;
-      break;
-    case kLiteRtAnyTypeVoidPtr:
-      res = litert_any.ptr_value;
-      break;
-  }
-  return res;
-}
-
-inline Expected<LiteRtAny> ToLiteRtAny(const std::any& any) {
-  LiteRtAny result;
-  if (!any.has_value()) {
-    result.type = kLiteRtAnyTypeNone;
-    return result;
-
-  } else if (any.type() == typeid(LiteRtAny::bool_value)) {
-    result.type = kLiteRtAnyTypeBool;
-    result.bool_value = std::any_cast<decltype(LiteRtAny::bool_value)>(any);
-    return result;
-
-  } else if (any.type() == typeid(int8_t)) {
-    result.type = kLiteRtAnyTypeInt;
-    result.int_value = std::any_cast<int8_t>(any);
-    return result;
-
-  } else if (any.type() == typeid(int16_t)) {
-    result.type = kLiteRtAnyTypeInt;
-    result.int_value = std::any_cast<int16_t>(any);
-    return result;
-
-  } else if (any.type() == typeid(int32_t)) {
-    result.type = kLiteRtAnyTypeInt;
-    result.int_value = std::any_cast<int32_t>(any);
-    return result;
-
-  } else if (any.type() == typeid(int64_t)) {
-    result.type = kLiteRtAnyTypeInt;
-    result.int_value = std::any_cast<int64_t>(any);
-    return result;
-
-  } else if (any.type() == typeid(float)) {
-    result.type = kLiteRtAnyTypeReal;
-    result.real_value = std::any_cast<float>(any);
-    return result;
-
-  } else if (any.type() == typeid(double)) {
-    result.type = kLiteRtAnyTypeReal;
-    result.real_value = std::any_cast<double>(any);
-    return result;
-
-  } else if (any.type() == typeid(LiteRtAny::str_value)) {
-    result.type = kLiteRtAnyTypeString;
-    result.str_value = std::any_cast<decltype(LiteRtAny::str_value)>(any);
-    return result;
-
-  } else if (any.type() == typeid(absl::string_view)) {
-    result.type = kLiteRtAnyTypeString;
-    result.str_value = std::any_cast<absl::string_view>(any).data();
-    return result;
-
-  } else if (any.type() == typeid(LiteRtAny::ptr_value)) {
-    result.type = kLiteRtAnyTypeVoidPtr;
-    result.ptr_value = std::any_cast<decltype(LiteRtAny::ptr_value)>(any);
-    return result;
-
-  } else {
-    return Error(kLiteRtStatusErrorInvalidArgument,
-                 "Invalid argument for ToLiteRtAny");
-  }
-}
-
-namespace internal {
-
-inline Expected<void> CheckType(const LiteRtAny& any,
-                                const LiteRtAnyType type) {
-  if (any.type != type) {
-    return Error(kLiteRtStatusErrorInvalidArgument,
-                 absl::StrFormat("Wrong LiteRtAny type. Expected %s, got %s.",
-                                 LiteRtAnyTypeToString(type),
-                                 LiteRtAnyTypeToString(any.type)));
-  }
-  return {};
-}
-
-template <typename T>
-Expected<T> GetInt(const LiteRtAny& any) {
-  LITERT_RETURN_IF_ERROR(CheckType(any, kLiteRtAnyTypeInt));
-  if (any.int_value > std::numeric_limits<T>::max() ||
-      any.int_value < std::numeric_limits<T>::lowest()) {
-    return Error(
-        kLiteRtStatusErrorInvalidArgument,
-        absl::StrFormat("LiteRtAny integer is out of range. %v <= %v <= %v",
-                        std::numeric_limits<T>::lowest(), any.int_value,
-                        std::numeric_limits<T>::max()));
-  }
-  return any.int_value;
-}
-
-template <typename T>
-Expected<T> GetReal(const LiteRtAny& any) {
-  LITERT_RETURN_IF_ERROR(CheckType(any, kLiteRtAnyTypeReal));
-  if (any.real_value > std::numeric_limits<T>::max() ||
-      any.real_value < std::numeric_limits<T>::lowest()) {
-    return Error(
-        kLiteRtStatusErrorInvalidArgument,
-        absl::StrFormat(
-            "LiteRtAny real value is out of range. %v <= %v <= %v failed.",
-            std::numeric_limits<T>::lowest(), any.real_value,
-            std::numeric_limits<T>::max()));
-  }
-  return any.real_value;
-}
-}  // namespace internal
-
-// Extracts the value from a LiteRtAny object with type checking.
-template <typename T>
-inline Expected<T> Get(const LiteRtAny& any);
-
-template <>
-inline Expected<bool> Get<bool>(const LiteRtAny& any) {
-  LITERT_RETURN_IF_ERROR(internal::CheckType(any, kLiteRtAnyTypeBool));
-  return any.bool_value;
-}
-
-template <>
-inline Expected<int8_t> Get<int8_t>(const LiteRtAny& any) {
-  return internal::GetInt<int8_t>(any);
-}
-
-template <>
-inline Expected<int16_t> Get<int16_t>(const LiteRtAny& any) {
-  return internal::GetInt<int16_t>(any);
-}
-
-template <>
-inline Expected<int32_t> Get<int32_t>(const LiteRtAny& any) {
-  return internal::GetInt<int32_t>(any);
-}
-
-template <>
-inline Expected<int64_t> Get<int64_t>(const LiteRtAny& any) {
-  return internal::GetInt<int64_t>(any);
-}
-
-template <>
-inline Expected<float> Get<float>(const LiteRtAny& any) {
-  return internal::GetReal<float>(any);
-}
-
-template <>
-inline Expected<double> Get<double>(const LiteRtAny& any) {
-  return internal::GetReal<double>(any);
-}
-
-template <>
-inline Expected<std::string> Get<std::string>(const LiteRtAny& any) {
-  LITERT_RETURN_IF_ERROR(internal::CheckType(any, kLiteRtAnyTypeString));
-  return std::string(any.str_value);
-}
-
-template <>
-inline Expected<absl::string_view> Get<absl::string_view>(
-    const LiteRtAny& any) {
-  LITERT_RETURN_IF_ERROR(internal::CheckType(any, kLiteRtAnyTypeString));
-  return absl::string_view(any.str_value);
-}
-
-template <>
-inline Expected<const void*> Get<const void*>(const LiteRtAny& any) {
-  LITERT_RETURN_IF_ERROR(internal::CheckType(any, kLiteRtAnyTypeVoidPtr));
-  return any.ptr_value;
-}
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_
diff --git a/tensorflow/lite/experimental/litert/cc/litert_any_test.cc b/tensorflow/lite/experimental/litert/cc/litert_any_test.cc
deleted file mode 100644
index c6640ab8060c1c..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_any_test.cc
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
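// A minimal round-trip sketch, not one of the original tests: std::any is
// converted to the C struct with litert::ToLiteRtAny, then read back through
// the range-checked litert::Get<T> accessors from litert_any.h. Dereferencing
// the returned litert::Expected assumes the conversion succeeded, mirroring
// the operator-> style used by the tests below.
//
//   void AnyRoundTripSketch() {
//     LiteRtAny converted =
//         *litert::ToLiteRtAny(std::any(static_cast<int64_t>(42)));
//     assert(converted.type == kLiteRtAnyTypeInt);
//     assert(*litert::Get<int64_t>(converted) == 42);  // checked extraction
//   }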
- -#include -#include - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" - -TEST(Any, ConversionNone) { - EXPECT_FALSE( - litert::ToStdAny(LiteRtAny{/*.type=*/kLiteRtAnyTypeNone}).has_value()); - - ASSERT_EQ(litert::ToLiteRtAny(std::any())->type, kLiteRtAnyTypeNone); -} - -TEST(Any, ConversionBool) { - ASSERT_EQ(std::any_cast(litert::ToStdAny(LiteRtAny{ - /*.type=*/kLiteRtAnyTypeBool, {/*.bool_value=*/true}})), - true); - ASSERT_EQ(std::any_cast(litert::ToStdAny(LiteRtAny{ - /*.type=*/kLiteRtAnyTypeBool, {/*.bool_value=*/false}})), - false); - - ASSERT_EQ(litert::ToLiteRtAny(std::any(true))->type, kLiteRtAnyTypeBool); - ASSERT_EQ(litert::ToLiteRtAny(std::any(true))->bool_value, true); - ASSERT_EQ(litert::ToLiteRtAny(std::any(false))->type, kLiteRtAnyTypeBool); - ASSERT_EQ(litert::ToLiteRtAny(std::any(false))->bool_value, false); -} - -TEST(Any, ConversionInt) { - LiteRtAny litert_any; - litert_any.type = kLiteRtAnyTypeInt; - litert_any.int_value = 1234; - ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), 1234); - - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(12)))->type, - kLiteRtAnyTypeInt); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(12)))->int_value, - 12); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1234)))->type, - kLiteRtAnyTypeInt); - ASSERT_EQ( - litert::ToLiteRtAny(std::any(static_cast(1234)))->int_value, - 1234); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1234)))->type, - kLiteRtAnyTypeInt); - ASSERT_EQ( - litert::ToLiteRtAny(std::any(static_cast(1234)))->int_value, - 1234); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1234)))->type, - kLiteRtAnyTypeInt); - ASSERT_EQ( - litert::ToLiteRtAny(std::any(static_cast(1234)))->int_value, - 1234); -} - -TEST(Any, ConversionReal) { - LiteRtAny litert_any; - litert_any.type = kLiteRtAnyTypeReal; - litert_any.real_value = 123.4; - ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), 123.4); - - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1.2)))->type, - kLiteRtAnyTypeReal); - EXPECT_NEAR( - litert::ToLiteRtAny(std::any(static_cast(1.2)))->real_value, 1.2, - 1e-7); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1.2)))->type, - kLiteRtAnyTypeReal); - EXPECT_NEAR( - litert::ToLiteRtAny(std::any(static_cast(1.2)))->real_value, 1.2, - 1e-7); -} - -TEST(Any, ConversionString) { - constexpr const char* kTestString = "test"; - LiteRtAny litert_any; - litert_any.type = kLiteRtAnyTypeString; - litert_any.str_value = kTestString; - ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), - kTestString); - - ASSERT_EQ(litert::ToLiteRtAny(std::any("test"))->type, kLiteRtAnyTypeString); - EXPECT_STREQ(litert::ToLiteRtAny(std::any("test"))->str_value, "test"); -} - -TEST(Any, ConversionPtr) { - const void* kTestPtr = reinterpret_cast(1234); - LiteRtAny litert_any; - litert_any.type = kLiteRtAnyTypeVoidPtr; - litert_any.ptr_value = kTestPtr; - ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), kTestPtr); - - ASSERT_EQ(litert::ToLiteRtAny(std::any(kTestPtr))->type, - kLiteRtAnyTypeVoidPtr); - EXPECT_EQ(litert::ToLiteRtAny(std::any(kTestPtr))->ptr_value, kTestPtr); -} diff --git a/tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h b/tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h deleted file mode 100644 index c81b5d12524afc..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h +++ /dev/null @@ -1,356 +0,0 @@ -// 
Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_BUFFER_REF_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_BUFFER_REF_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" - -namespace litert { - -//===----------------------------------------------------------------------===// -// -// << BUFFER REF >> -// -// Read, read/write, and owning views of buffers of arbitrary byte width types. -// -// Serialized model artifacts and assets are frequently large strings with -// (annoyingly) non-standard char types and left padding. The following classes -// simplify handling such buffers in an efficient, copy-free manner. They also -// provide left-padding-aware read and write interpretability through standard -// signed char string types. This is used for making manual edits to flatbuffer -// metadata or directly to a serialized flatbuffer. -// NOTE: std::basic_xxx not supported by our C++ toolchain. -// -// Pre-allocated buffers can be transferred to these classes or allocation can -// be internalized. XBufferRefs can be implicitly upcast to non-owning -// read/write or read-only views to provide other routines with an appropriate view of -// the data. E.g.: -// -// ``` -// void ReadBuffer(BufferRef r_buf) { std::cerr << r_buf.StrView(); } -// void WriteToBuffer(MutableBufferRef rw_buf) { rw_buf.WriteInto("SomeData"); } -// ... -// OwningBufferRef buf(size); -// WriteToBuffer(buf); // Implicitly convert to read/write with no ownership. -// ReadBuffer(buf); // Implicitly convert to read-only. -// ``` -// -//===----------------------------------------------------------------------===// - -// Allocation/Deallocation behavior for owning buffer refs. An allocator is a -// trivially constructible/destructible object that overloads operator() for -// allocating and freeing memory. - -// Malloc/free based memory. -template -struct Mallocator { - void operator()(ByteT* d) { - if (d != nullptr) { - free(d); - } - } - - ByteT* operator()(size_t bytes) { - return reinterpret_cast(malloc(bytes)); - } -}; - -// New/delete based memory. -template -struct Newlocator { - void operator()(ByteT* d) { - if (d != nullptr) { - delete[] d; - } - } - - ByteT* operator()(size_t bytes) { return new ByteT[bytes]; } }; - -// -// Read-Only Bytes -// - -// Immutable and non-owning view of a buffer. -template -class BufferRef { - public: - using TupleT = std::tuple; - - // Null buffer. - BufferRef() : size_(0), offset_(0), data_(nullptr) {} - - // Construct from already allocated buffer. Methods will only expose - // data[offset, offset + size].
- BufferRef(const ByteT* data, size_t size, size_t offset = 0) - : size_(size), offset_(offset), data_(const_cast(data)) {} - BufferRef(const void* data, size_t size, size_t offset = 0) - : size_(size), - offset_(offset), - data_(const_cast(reinterpret_cast(data))) {} - explicit BufferRef(absl::Span data) - : size_(data.size()), - offset_(0), - data_(const_cast(data.data())) {} - - // Start of actual data. - const ByteT* Data() const { return data_ + offset_; } - - // Size of actual data. - size_t Size() const { return size_ - offset_; } - - // Get buffer details in tuple form. - TupleT Get() const { return TupleT(data_, size_, offset_); } - - // Start of actual data as signed char. Might not be null terminated. - const char* StrData() const { return reinterpret_cast(Data()); } - - // Convenience view of actual data as a string. Makes null terminated. - absl::string_view StrView() const { - return absl::string_view(StrData(), Size()); - } - - // Const view of actual data. - absl::Span Span() const { - return absl::MakeConstSpan(Data(), Size()); - } - - // Copy the buffer data to a vector. - std::vector ToVec() const { - return std::vector(StrData(), StrData() + Size()); - } - - // Write the string data to a stream. - void WriteStr(std::ostream& out) const { out.write(StrData(), Size()); } - - // Print info about this buffer. - void Dump(std::ostream& out) const { - out << absl::StreamFormat("%s[%lu:%lu]\n", TypeName(), offset_, size_); - } - - BufferRef(const BufferRef& other) = default; - BufferRef& operator=(const BufferRef& other) = default; - - virtual ~BufferRef() = default; - - protected: - size_t size_; - size_t offset_; - ByteT* data_ = nullptr; - - // Debug name. - virtual absl::string_view TypeName() const { return "BufferRef"; } -}; -template -BufferRef(const ByteT*, size_t, size_t) -> BufferRef; - -// -// Read-Write Non-Owning Bytes -// - -// Writeable (but still non-owning) version of BufferRef. -template -class MutableBufferRef : public BufferRef { - public: - using TupleT = std::tuple; - - // Null buffer. - MutableBufferRef() - : BufferRef((ByteT*)nullptr, /*size*/ 0, /*offset*/ 0) {} - - // Create a mutable view from pre-allocated non-const buffer. - MutableBufferRef(ByteT* data, size_t size, size_t offset = 0) - : BufferRef(data, size, offset) {} - MutableBufferRef(void* data, size_t size, size_t offset = 0) - : BufferRef(data, size, offset) {} - explicit MutableBufferRef(absl::Span data) : BufferRef(data) {} - explicit MutableBufferRef(absl::Span data) = delete; - MutableBufferRef(const ByteT*, size_t, size_t) = delete; - MutableBufferRef(const void*, size_t, size_t) = delete; - - // Mutable start of actual data. - ByteT* Data() { return this->data_ + this->offset_; } - - // Get the mutable start of actual data as a char pointer. - char* StrData() { return reinterpret_cast(Data()); } - - // Get buffer info in tuple form. - TupleT Get() { return TupleT(this->data_, this->size_, this->offset_); } - - // Mutable span of actual data. - absl::Span Span() { return absl::MakeSpan(Data(), this->Size()); } - - // Write string into the actual buffer at offset. Returns false if the entire - // string cannot fit into the actual buffer. 
- bool WriteInto(absl::string_view str, size_t offset = 0) { - if (str.size() > this->Size() - offset) { - return false; - } - std::memcpy(Data() + offset, str.data(), str.size()); - return true; - } - - MutableBufferRef(const MutableBufferRef& other) = default; - MutableBufferRef& operator=(const MutableBufferRef& other) = default; - - protected: - // Debug name. - absl::string_view TypeName() const override { return "MutableBufferRef"; } -}; -template -MutableBufferRef(ByteT*, size_t, size_t) -> MutableBufferRef; - -// -// Read-Write Owning Bytes -// - -// Writable and owning buffer reference. Can allocate new buffers internally and -// take ownership of existing buffers. Does not support resizing. -template > -class OwningBufferRef : public MutableBufferRef { - public: - using TupleT = std::tuple; - using WeakTupleT = std::tuple; - - // Null buffer. - OwningBufferRef() - : MutableBufferRef(/*data*/ (ByteT*)nullptr, /*size*/ 0, - /*offset*/ 0) {} - - // Initialize a new buffer reference and allocate internally. - explicit OwningBufferRef(size_t size) - : MutableBufferRef(/*data*/ (ByteT*)nullptr, size, /*offset*/ 0) { - this->data_ = (ByteT*)Allocator()(size); - } - - // Take ownership of the given buffer. - OwningBufferRef(ByteT* data, size_t size, size_t offset = 0) - : MutableBufferRef(data, size, offset) {} - OwningBufferRef(void* data, size_t size, size_t offset = 0) - : MutableBufferRef(data, size, offset) {} - explicit OwningBufferRef(absl::Span data) - : MutableBufferRef(data) {} - - // Copy the given buffer. - OwningBufferRef(const ByteT* data, size_t size) - : MutableBufferRef(/*data*/ (ByteT*)nullptr, size, - /*offset*/ 0) { - this->data_ = (ByteT*)Allocator()(size); - std::memcpy(this->data_, data, size); - } - explicit OwningBufferRef(absl::Span data) - : OwningBufferRef(data.data(), data.size()) {} - - // Copy data from the given string. - explicit OwningBufferRef(absl::string_view data) - : OwningBufferRef( - reinterpret_cast(data.data()), data.size()) {} - - // Copy data from the given C-style string. - explicit OwningBufferRef(const char* data) - : OwningBufferRef(absl::string_view(data)) {} - - // Drop reference to any owned memory. - void Drop() { - this->data_ = nullptr; - this->size_ = 0; - this->offset_ = 0; - } - - // Get the buffer details and drop references to them. - TupleT Release() { - auto res = std::make_tuple(this->data_, this->size_, this->offset_); - Drop(); - return res; - } - - // Get weak references to buffer data. Takes ownership of anything that - // is swapped in. - WeakTupleT GetWeak() { - return WeakTupleT(this->data_, this->size_, this->offset_); - } - - // Free any owned memory. - void Reset() { - Allocator()(this->data_); - Drop(); - } - - // Reset any existing data and copy in the given read-only buffer.
- void Assign(const ByteT* buf, size_t size, size_t offset = 0) { - Reset(); - this->size_ = size; - this->data_ = (ByteT*)Allocator()(this->size_); - std::memcpy(this->data_, buf, this->size_); - this->offset_ = offset; - } - - OwningBufferRef(OwningBufferRef&& other) - : MutableBufferRef(other.data_, other.size_, other.offset_) { - other.Drop(); - } - - OwningBufferRef& operator=(OwningBufferRef&& other) { - if (this != &other) { - Reset(); - this->data_ = other.data_; - this->size_ = other.size_; - this->offset_ = other.offset_; - other.Drop(); - } - return *this; - } - - OwningBufferRef(const OwningBufferRef& other) - : MutableBufferRef(/*data*/ (ByteT*)nullptr, other.size_, - other.offset_) { - Assign(other.data_, other.size_, other.offset_); - } - - OwningBufferRef& operator=(const OwningBufferRef& other) { - Assign(other.data_, other.size_, other.offset_); - return *this; - } - - ~OwningBufferRef() override { Reset(); } - - protected: - // Debug string. - absl::string_view TypeName() const override { return "OwningBufferRef"; } -}; - -template > -OwningBufferRef(const ByteT*, size_t) -> OwningBufferRef; - -template > -OwningBufferRef(ByteT*, size_t) -> OwningBufferRef; - -template > -OwningBufferRef(const char*) -> OwningBufferRef; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_BUFFER_REF_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_buffer_ref_test.cc b/tensorflow/lite/experimental/litert/cc/litert_buffer_ref_test.cc deleted file mode 100644 index a2900d0c8946fd..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_buffer_ref_test.cc +++ /dev/null @@ -1,332 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
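Before the tests below, here is the header's "Example user flow" comment expanded into a self-contained sketch (assuming the default ByteT = uint8_t; error handling omitted):

#include <cstdint>
#include <iostream>

#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"

// Read-only view: the callee can inspect but not mutate the bytes.
void ReadBuffer(litert::BufferRef<uint8_t> r_buf) {
  std::cerr << r_buf.StrView() << "\n";
}

// Read/write view: the callee can overwrite, but does not own, the bytes.
void WriteToBuffer(litert::MutableBufferRef<uint8_t> rw_buf) {
  rw_buf.WriteInto("SomeData");
}

void Demo() {
  // Owning ref allocates 8 bytes internally (Newlocator by default).
  litert::OwningBufferRef<uint8_t> buf(/*size=*/8);
  WriteToBuffer(buf);  // Implicit upcast to read/write, no ownership transfer.
  ReadBuffer(buf);     // Implicit upcast to read-only.
}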
- -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -using litert::BufferRef; -using litert::Mallocator; -using litert::MutableBufferRef; -using litert::Newlocator; -using litert::OwningBufferRef; -using litert::internal::FbBufToStr; -using testing::ElementsAreArray; -using testing::Eq; -using testing::Pointwise; -using testing::StartsWith; - -namespace { - -static constexpr size_t kOffset = 4; - -static constexpr absl::string_view kData = "SomeRawBuffer"; -static constexpr absl::string_view kOtherData = "SOMERawBuffer"; - -absl::Span MakeConstFbData(absl::string_view data) { - const uint8_t* fb_data = reinterpret_cast(data.data()); - return absl::MakeConstSpan(fb_data, data.size()); -} - -absl::Span MakeFbData(absl::string_view data) { - const uint8_t* c_fb_data = reinterpret_cast(data.data()); - uint8_t* fb_data = const_cast(c_fb_data); - return absl::MakeSpan(fb_data, data.size()); -} - -std::vector MakeFbDataVec(absl::string_view data) { - const uint8_t* c_fb_data = reinterpret_cast(data.data()); - uint8_t* fb_data = const_cast(c_fb_data); - return std::vector(fb_data, fb_data + data.size()); -} - -template , typename ByteT = uint8_t> -absl::Span MakeInternalTestBuffer(absl::string_view data) { - ByteT* buffer = Allocator()(data.size()); - std::memcpy(buffer, data.data(), data.size()); - return absl::MakeSpan(reinterpret_cast(buffer), data.size()); -} - -// -// flatbuffer_tools.h -// - -TEST(FbBufToStringTest, ConstSpan) { - EXPECT_THAT(FbBufToStr(MakeConstFbData(kData)), Pointwise(Eq(), kData)); -} - -TEST(FbBufToStringTest, Span) { - EXPECT_THAT(FbBufToStr(MakeFbData(kData)), Pointwise(Eq(), kData)); -} - -TEST(FbBufToStringTest, ConstPointer) { - auto data = MakeConstFbData(kData); - EXPECT_THAT(FbBufToStr(data.data(), data.size()), Pointwise(Eq(), kData)); -} - -TEST(FbBufToStringTest, Pointer) { - auto data = MakeFbData(kData); - EXPECT_THAT(FbBufToStr(data.data(), data.size()), Pointwise(Eq(), kData)); -} - -// -// BufferRef (read-only) -// - -TEST(BufferRefTest, Dump) { - BufferRef buf(kData.data(), kData.size()); - std::stringstream out; - buf.Dump(out); - EXPECT_THAT(out.str(), StartsWith("BufferRef")); -} - -TEST(BufferRefTest, WithData) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size()); - EXPECT_EQ(buf.Span(), data); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(BufferRefTest, WithDataAndOffset) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size(), kOffset); - EXPECT_EQ(buf.Span(), data.subspan(kOffset, buf.Size())); - EXPECT_EQ(buf.StrView(), kData.substr(kOffset, buf.Size())); -} - -TEST(BufferRefTest, ToVec) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size()); - EXPECT_THAT(buf.ToVec(), ElementsAreArray(data)); -} - -TEST(BufferRefTest, WriteStr) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size()); - std::stringstream out; - buf.WriteStr(out); - EXPECT_EQ(out.str(), kData); -} - -TEST(BufferRefTest, WriteStrOffset) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size(), kOffset); - std::stringstream out; - buf.WriteStr(out); - EXPECT_EQ(out.str(), kData.substr(kOffset, buf.Size())); -} - -TEST(BufferRefTest, TupleGet) { - auto input = MakeConstFbData(kData); - BufferRef 
buf(input); - auto [data, size, offset] = buf.Get(); - ASSERT_EQ(offset, 0); - EXPECT_EQ(input, buf.Span()); -} - -// -// MutableBufferRef (read/write) -// - -TEST(MutableBufferRefTest, Dump) { - MutableBufferRef buf; - std::stringstream out; - buf.Dump(out); - EXPECT_THAT(out.str(), StartsWith("MutableBufferRef")); -} - -TEST(MutableBufferRefTest, WriteInto) { - auto v_data = MakeFbDataVec(kOtherData); - MutableBufferRef buf(v_data.data(), v_data.size()); - ASSERT_TRUE(buf.WriteInto("Some")); - EXPECT_THAT(buf.Span(), ElementsAreArray(v_data)); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(MutableBufferRefTest, WriteIntoOffsetBuf) { - auto v_data = MakeFbDataVec(kOtherData); - static constexpr absl::string_view kExpData = "RAWBuffer"; - MutableBufferRef buf(v_data.data(), v_data.size(), kOffset); - ASSERT_TRUE(buf.WriteInto("RAW")); - EXPECT_THAT(buf.Span(), ElementsAreArray(MakeConstFbData(kExpData))); - EXPECT_EQ(buf.StrView(), kExpData); -} - -TEST(MutableBufferRefTest, WriteIntoOffsetData) { - auto v_data = MakeFbDataVec(kOtherData); - static constexpr absl::string_view kExpData = "SOMERAWBuffer"; - MutableBufferRef buf(v_data.data(), v_data.size()); - ASSERT_TRUE(buf.WriteInto("RAW", kOffset)); - EXPECT_THAT(buf.Span(), ElementsAreArray(MakeConstFbData(kExpData))); - EXPECT_EQ(buf.StrView(), kExpData); -} - -TEST(MutableBufferRefTest, TupleGet) { - auto input = MakeInternalTestBuffer("FOO"); - MutableBufferRef buf(input); - auto [data, size, offset] = buf.Get(); - *data = 'b'; - EXPECT_EQ(buf.StrView(), "bOO"); - delete[] input.data(); -} - -// -// OwningBufferRef (read/write with memory management) -// - -TEST(OwningBufferRefTest, Dump) { - OwningBufferRef buf; - std::stringstream out; - buf.Dump(out); - EXPECT_THAT(out.str(), StartsWith("OwningBufferRef")); -} - -TEST(OwningBufferRefTest, MoveCstor) { - auto raw = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(raw.data(), raw.size()); - OwningBufferRef> other(std::move(buf)); - EXPECT_EQ(other.StrView(), kData); -} - -TEST(OwningBufferRefTest, MoveAssign) { - auto raw = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(raw.data(), raw.size()); - OwningBufferRef> other = std::move(buf); - EXPECT_EQ(other.StrView(), kData); -} - -TEST(OwningBufferRefTest, CopyCstor) { - auto raw = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(raw.data(), raw.size()); - OwningBufferRef> other(buf); - other.WriteInto("SOME"); - EXPECT_EQ(buf.StrView(), kData); - EXPECT_EQ(other.StrView(), "SOMERawBuffer"); -} - -TEST(OwningBufferRefTest, CopyAssign) { - auto raw = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(raw.data(), raw.size()); - OwningBufferRef> other = buf; - other.WriteInto("SOME"); - EXPECT_EQ(buf.StrView(), kData); - EXPECT_EQ(other.StrView(), "SOMERawBuffer"); -} - -TEST(OwningBufferRefTest, InternalMalloc) { - OwningBufferRef> buf(kData.size()); - ASSERT_EQ(buf.Size(), kData.size()); - ASSERT_NE(buf.Data(), nullptr); - - buf.WriteInto(kData); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(OwningBufferRefTest, InternalNew) { - OwningBufferRef buf(kData.size()); - ASSERT_EQ(buf.Size(), kData.size()); - ASSERT_NE(buf.Data(), nullptr); - - buf.WriteInto(kData); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(OwningBufferRefTest, TakeOwnershipMalloc) { - auto malloc_buffer = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(malloc_buffer.data(), - malloc_buffer.size()); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(OwningBufferRefTest, TakeOwnershipNew) { - auto new_buffer = MakeInternalTestBuffer(kData); - 
OwningBufferRef buf(new_buffer.data(), new_buffer.size()); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(OwningBufferRefTest, TakeOwnershipOffset) { - auto malloc_buffer = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(malloc_buffer.data(), - malloc_buffer.size(), - /*offset=*/4); - EXPECT_EQ(buf.StrView(), "RawBuffer"); -} - -TEST(OwningBufferRefTest, CopyBuffer) { - auto const_buf = MakeConstFbData(kData); - OwningBufferRef buf(const_buf.data(), const_buf.size()); - buf.WriteInto("SOME"); - EXPECT_EQ(buf.StrView(), "SOMERawBuffer"); - EXPECT_EQ(FbBufToStr(const_buf), "SomeRawBuffer"); -} - -TEST(OwningBufferRefTest, ImplicitUpCasts) { - OwningBufferRef buf(kData.size()); - BufferRef c_buf = buf; - - buf.WriteInto(kData); - EXPECT_EQ(c_buf.StrView(), buf.StrView()); -} - -TEST(OwningBufferRefTest, TupleGetWeak) { - auto input = MakeInternalTestBuffer("FOO"); - - OwningBufferRef buf; - auto [data, size, offset] = buf.GetWeak(); - - data = input.data(); - size = input.size(); - offset = 0; - - ASSERT_EQ(buf.Size(), input.size()); - ASSERT_EQ(buf.Size(), input.size()); - - buf.WriteInto("BAR"); - - EXPECT_EQ(buf.StrView(), "BAR"); - EXPECT_EQ(buf.Span(), input); -} - -TEST(OwningBufferRefTest, TupleRelease) { - OwningBufferRef buf("BAZ"); - - auto [data, size, offset] = buf.Release(); - - EXPECT_EQ(buf.Size(), 0); - EXPECT_EQ(absl::string_view(data, size), "BAZ"); - - delete[] data; -} - -TEST(OwningBufferRefTest, Assign) { - auto const_buf = MakeConstFbData(kData); - OwningBufferRef buf; - buf.Assign(const_buf.data(), const_buf.size()); - buf.WriteInto("SOME"); - EXPECT_EQ(buf.StrView(), "SOMERawBuffer"); - EXPECT_EQ(FbBufToStr(const_buf), "SomeRawBuffer"); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/cc/litert_compilation_options.h b/tensorflow/lite/experimental/litert/cc/litert_compilation_options.h deleted file mode 100644 index 8a21f22d120a79..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compilation_options.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
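The TupleRelease test above encodes an ownership contract worth spelling out: after Release(), the ref no longer frees the memory, so the caller must deallocate with the matching allocator. A sketch, assuming the default Newlocator allocator:

#include <cstdint>

#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"

void Demo() {
  // Owns a heap copy of "BAZ" made with the default Newlocator.
  litert::OwningBufferRef<uint8_t> buf("BAZ");

  // Release() hands the allocation to the caller and nulls the ref, so the
  // destructor will not free it a second time.
  auto [data, size, offset] = buf.Release();

  // ... read the bytes in [data + offset, data + size) ...

  delete[] data;  // Matches Newlocator's new[] allocation.
}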
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILATION_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILATION_OPTIONS_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -class CompilationOptions - : public internal::Handle { - public: - CompilationOptions() = default; - - // Parameter `owned` indicates if the created CompilationOptions object - // should take ownership of the provided `compilation_options` handle. - explicit CompilationOptions(LiteRtCompilationOptions compilation_options, - bool owned = true) - : internal::Handle(compilation_options, - owned) {} - - static Expected Create() { - LiteRtCompilationOptions options; - LITERT_RETURN_IF_ERROR(LiteRtCreateCompilationOptions(&options)); - return CompilationOptions(options); - } - - Expected SetHardwareAccelerators(LiteRtHwAcceleratorSet accelerators) { - LITERT_RETURN_IF_ERROR( - LiteRtSetCompilationOptionsHardwareAccelerators(Get(), accelerators)); - return {}; - } - - Expected GetHardwareAccelerators() { - LiteRtHwAcceleratorSet accelerators; - LITERT_RETURN_IF_ERROR( - LiteRtGetCompilationOptionsHardwareAccelerators(Get(), &accelerators)); - return accelerators; - } - - Expected AddAcceleratorCompilationOptions( - AcceleratorCompilationOptions&& options) { - LITERT_RETURN_IF_ERROR( - LiteRtAddAcceleratorCompilationOptions(Get(), options.Release())); - return {}; - } - - Expected GetAcceleratorCompilationOptions() { - LiteRtAcceleratorCompilationOptions options; - LITERT_RETURN_IF_ERROR( - LiteRtGetAcceleratorCompilationOptions(Get(), &options)); - return AcceleratorCompilationOptions(options, /*owned=*/false); - } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILATION_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc deleted file mode 100644 index 9a6658dcaa481c..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
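A typical construction of the CompilationOptions wrapper defined above, sketched. Two assumptions are made here: that LITERT_RETURN_IF_ERROR accepts the Expected returned by SetHardwareAccelerators the same way it accepts a LiteRtStatus, and that LiteRtHwAcceleratorSet is a bitmask of the accelerator flags, as its name suggests.

#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/cc/litert_compilation_options.h"
#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"

// Builds options that request GPU acceleration with CPU fallback.
litert::Expected<litert::CompilationOptions> MakeGpuOptions() {
  LITERT_ASSIGN_OR_RETURN(litert::CompilationOptions options,
                          litert::CompilationOptions::Create());
  LITERT_RETURN_IF_ERROR(options.SetHardwareAccelerators(
      kLiteRtHwAcceleratorGpu | kLiteRtHwAcceleratorCpu));
  return options;
}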
- -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" - -namespace litert { - -Expected CompiledModel::FindInputIndex( - size_t signature_index, absl::string_view input_name) const { - LITERT_ASSIGN_OR_RETURN(const Signature& signature, - model_.GetSignature(signature_index)); - const std::vector& input_names = signature.InputNames(); - auto it = std::find(input_names.begin(), input_names.end(), input_name); - if (it != input_names.end()) { - return std::distance(input_names.begin(), it); - } - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input"); -} - -Expected CompiledModel::FindOutputIndex( - size_t signature_index, absl::string_view output_name) const { - LITERT_ASSIGN_OR_RETURN(const Signature& signature, - model_.GetSignature(signature_index)); - const std::vector& output_names = signature.OutputNames(); - auto it = std::find(output_names.begin(), output_names.end(), output_name); - if (it != output_names.end()) { - return std::distance(output_names.begin(), it); - } - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output"); -} - -Expected CompiledModel::CreateBufferImpl( - const TensorBufferRequirements& buffer_requirements, - const RankedTensorType& tensor_type) { - LITERT_ASSIGN_OR_RETURN( - const std::vector& supported_types, - buffer_requirements.SupportedTypes()); - if (supported_types.empty()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Input doesn't support any tensor buffer types"); - } - // For simplicity we just pick the first supported tensor buffer type. - LiteRtTensorBufferType tensor_buffer_type = supported_types[0]; - LITERT_ASSIGN_OR_RETURN(size_t buffer_size, buffer_requirements.BufferSize()); - - LITERT_ASSIGN_OR_RETURN(TensorBuffer buffer, - TensorBuffer::CreateManaged( - tensor_buffer_type, tensor_type, buffer_size)); - return buffer; -} - -Expected CompiledModel::CreateInputOutputBuffer( - size_t signature_index, absl::string_view tensor_name, - bool is_input) const { - LITERT_ASSIGN_OR_RETURN(Signature signature, - model_.GetSignature(signature_index)); - - LITERT_ASSIGN_OR_RETURN(Subgraph subgraph, model_.Subgraph(signature.Key())); - - Expected tensor_expected = - is_input ? subgraph.Input(tensor_name) : subgraph.Output(tensor_name); - Expected buffer_requirements_expected = - is_input ? 
GetInputBufferRequirements(signature_index, tensor_name) - : GetOutputBufferRequirements(signature_index, tensor_name); - - LITERT_ASSIGN_OR_RETURN(const Tensor& tensor, tensor_expected); - LITERT_ASSIGN_OR_RETURN(const TensorBufferRequirements& buffer_requirements, - buffer_requirements_expected); - LITERT_ASSIGN_OR_RETURN(const RankedTensorType& tensor_type, - tensor.RankedTensorType()); - - return CreateBufferImpl(buffer_requirements, tensor_type); -} - -Expected> CompiledModel::CreateInputOutputBuffers( - size_t signature_index, bool is_input) const { - LITERT_ASSIGN_OR_RETURN(const Signature& signature, - model_.GetSignature(signature_index)); - LITERT_ASSIGN_OR_RETURN(const Subgraph subgraph, - model_.Subgraph(signature.Key())); - std::vector tensor_buffers; - std::vector tensor_names; - - tensor_names = is_input ? signature.InputNames() : signature.OutputNames(); - tensor_buffers.reserve(tensor_names.size()); - - for (int i = 0; i < tensor_names.size(); ++i) { - LITERT_ASSIGN_OR_RETURN( - TensorBuffer tensor_buffer, - CreateInputOutputBuffer(signature.Key(), tensor_names[i], is_input)); - tensor_buffers.push_back(std::move(tensor_buffer)); - } - - return tensor_buffers; -} - -Expected CompiledModel::RunCApiHelper(LiteRtParamIndex signature_index, - size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, - size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers, - bool& async) const { - LiteRtStatus status = - async ? LiteRtRunCompiledModelAsync( - Get(), signature_index, num_input_buffers, input_buffers, - num_output_buffers, output_buffers, &async) - : LiteRtRunCompiledModel(Get(), signature_index, num_input_buffers, - input_buffers, num_output_buffers, - output_buffers); - if (status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to invoke the compiled model"); - } - return {}; -} - -Expected CompiledModel::RunHelper( - size_t signature_index, const std::vector& input_buffers, - const std::vector& output_buffers, bool& async) const { - auto input_buffers_ptr = - std::make_unique(input_buffers.size()); - for (int i = 0; i < input_buffers.size(); ++i) { - input_buffers_ptr[i] = input_buffers[i].Get(); - } - auto output_buffers_ptr = - std::make_unique(output_buffers.size()); - for (int i = 0; i < output_buffers.size(); ++i) { - output_buffers_ptr[i] = output_buffers[i].Get(); - } - return RunCApiHelper(signature_index, input_buffers.size(), - input_buffers_ptr.get(), output_buffers.size(), - output_buffers_ptr.get(), async); -} - -Expected CompiledModel::RunMapHelper( - absl::string_view signature_key, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const { - auto signature_index = model_.GetSignatureIndex(signature_key); - if (!signature_index) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get signature_index"); - } - auto subgraph = model_.Subgraph(signature_key); - if (!subgraph) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph"); - } - return RunMapWithIndexHelper(*signature_index, *subgraph, input_map, - output_map, async); -} - -Expected CompiledModel::RunMapWithIndexHelper( - size_t signature_index, const Subgraph& subgraph, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const { - auto input_tensors = subgraph.Inputs(); - size_t num_inputs = input_tensors.size(); - auto input_buffers_ptr = std::make_unique(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - absl::string_view input_name = 
input_tensors[i].Name(); - auto it = input_map.find(input_name); - if (it == input_map.end()) { - return Unexpected(kLiteRtStatusErrorNotFound, - "The given map is missing some input TensorBuffers"); - } - input_buffers_ptr[i] = it->second.Get(); - } - auto output_tensors = subgraph.Outputs(); - size_t num_outputs = output_tensors.size(); - auto output_buffers_ptr = std::make_unique(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - absl::string_view output_name = output_tensors[i].Name(); - auto it = output_map.find(output_name); - if (it == output_map.end()) { - return Unexpected(kLiteRtStatusErrorNotFound, - "The given map is missing some output TensorBuffers"); - } - output_buffers_ptr[i] = it->second.Get(); - } - return RunCApiHelper(signature_index, num_inputs, input_buffers_ptr.get(), - num_outputs, output_buffers_ptr.get(), async); -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h deleted file mode 100644 index 7ad7207ad569ce..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h +++ /dev/null @@ -1,438 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILED_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILED_MODEL_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" - -namespace litert { - -// The CompiledModel is a higher-level inference API. It is created from a -// provided model and a set of compilation options. Internally, it instantiates -// the runtime and applies the Delegates mapped to the compilation options. -// It also supports getting BufferRequirements to create input/output -// TensorBuffers, and it allows invoking the model with the input/output -// TensorBuffers. -// -// Example user flow: -// -// 1.
Create CompiledModel -// 2. Query the model input/output requirements -// 3. Create input/output TensorBuffers -// 4. Fill the input TensorBuffers with input data -// 5. Invoke the model with the input/output TensorBuffers -// 6. Evaluate the output TensorBuffers - -class CompiledModel - : public internal::Handle { - public: - CompiledModel() = default; - - // Creates a CompiledModel instance. - // - // If `owned` is `true`, then the created object takes ownership of the - // `compiled_model` handle. - explicit CompiledModel(LiteRtModel litert_model, - LiteRtCompiledModel compiled_model, bool owned = true) - : internal::Handle( - compiled_model, owned), - model_(Model::CreateFromNonOwnedHandle(litert_model)) {} - - // Creates a CompiledModel from a TFLite file. - // - // The model is loaded into memory and the caller takes ownership of the - // returned CompiledModel object. The caller should keep the model alive - // until the CompiledModel is destroyed. - // The given `compilation_options` is used for JIT compilation of the model. - // - // Note: The given environment must outlive the compiled model and any - // execution running it. - // Note: If the model is fully AOT compiled for NPU, NPU accelerator is used - // automatically which means the provided `compilation_options` are - // meaningless. - static Expected Create( - litert::Environment& env, litert::Model& model, - const CompilationOptions& jit_compilation_options) { - LiteRtModel litert_model = model.Get(); - LiteRtCompiledModel compiled_model; - LITERT_RETURN_IF_ERROR(LiteRtCreateCompiledModel( - env.Get(), litert_model, jit_compilation_options.Get(), - &compiled_model)); - return CompiledModel(litert_model, compiled_model); - } - - // Simpler version of Create() that uses the default compilation options. - // The provided hardware accelerator is used for JIT compilation of the model. - // - // Note: If the model is fully AOT compiled for NPU, NPU accelerator - // is used automatically which means the provided `hardware_accelerator` is - // meaningless. - static Expected Create( - litert::Environment& env, litert::Model& model, - LiteRtHwAccelerators hardware_accelerator = kLiteRtHwAcceleratorCpu) { - LITERT_ASSIGN_OR_RETURN(auto jit_compilation_options, - CompilationOptions::Create()); - jit_compilation_options.SetHardwareAccelerators(hardware_accelerator); - return Create(env, model, jit_compilation_options); - } - - // Get input buffer requirements for the given signature and input name. - Expected GetInputBufferRequirements( - absl::string_view signature_name, absl::string_view input_name) { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return GetInputBufferRequirements(signature_index, input_name); - } - - // Returns the buffer requirements for the given n-th input tensor. The - // returned TensorBufferRequirements is used to create the input tensor - // buffer. - Expected GetInputBufferRequirements( - size_t signature_index, size_t input_index) const { - LiteRtTensorBufferRequirements buffer_requirements; - LITERT_RETURN_IF_ERROR(LiteRtGetCompiledModelInputBufferRequirements( - Get(), signature_index, input_index, &buffer_requirements)); - return TensorBufferRequirements(buffer_requirements, /*owned=*/false); - } - - // The same as above except this function takes input tensor name. 
- Expected GetInputBufferRequirements( - size_t signature_index, absl::string_view input_name) const { - LITERT_ASSIGN_OR_RETURN(size_t input_index, - FindInputIndex(signature_index, input_name)); - return GetInputBufferRequirements(signature_index, input_index); - } - - // Get input buffer requirements of the default signature for the given n-th - // input tensor. - Expected GetInputBufferRequirements( - size_t input_index) const { - return GetInputBufferRequirements(/*signature_index=*/0, input_index); - } - - // Get input buffer requirements of the default signature for input name. - Expected GetInputBufferRequirements( - absl::string_view input_name) const { - return GetInputBufferRequirements(/*signature_index=*/0, input_name); - } - - // Get output buffer requirements for the given signature and output name. - Expected GetOutputBufferRequirements( - absl::string_view signature_name, absl::string_view output_name) { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return GetOutputBufferRequirements(signature_index, output_name); - } - - // Returns the buffer requirements for the given output tensor. The returned - // TensorBufferRequirements is used to create the output tensor - // buffer. - Expected GetOutputBufferRequirements( - size_t signature_index, size_t output_index) const { - LiteRtTensorBufferRequirements buffer_requirements; - LITERT_RETURN_IF_ERROR(LiteRtGetCompiledModelOutputBufferRequirements( - Get(), signature_index, output_index, &buffer_requirements)); - return TensorBufferRequirements(buffer_requirements, /*owned=*/false); - } - - // The same as above except this function takes output tensor name. - Expected GetOutputBufferRequirements( - size_t signature_index, absl::string_view output_name) const { - LITERT_ASSIGN_OR_RETURN(size_t output_index, - FindOutputIndex(signature_index, output_name)); - return GetOutputBufferRequirements(signature_index, output_index); - } - - // Get output buffer requirements of the default signature for the given n-th - // output tensor. - Expected GetOutputBufferRequirements( - size_t output_index) const { - return GetOutputBufferRequirements(/*signature_index=*/0, output_index); - } - - // Get output buffer requirements of the default signature for output name. - Expected GetOutputBufferRequirements( - absl::string_view output_name) const { - return GetOutputBufferRequirements(/*signature_index=*/0, output_name); - } - - // Creates an input tensor buffer for the given signature and input name. - Expected CreateInputBuffer(absl::string_view signature_name, - absl::string_view input_name) const { - return CreateInputOutputBuffer(signature_name, input_name, - /*is_input=*/true); - } - - // Creates an input tensor buffer of the default signature for the given input - // name. - Expected CreateInputBuffer(absl::string_view input_name) const { - return CreateInputOutputBuffer(/*signature_index=*/0, input_name, - /*is_input=*/true); - } - - // Creates an output tensor buffer for the given signature and output name. - Expected CreateOutputBuffer( - absl::string_view signature_name, absl::string_view output_name) const { - return CreateInputOutputBuffer(signature_name, output_name, - /*is_input=*/false); - } - - // Creates an output tensor buffer of the default signature for the given - // output name.
- Expected CreateOutputBuffer( - absl::string_view output_name) const { - return CreateInputOutputBuffer(/*signature_index=*/0, output_name, - /*is_input=*/false); - } - - // A helper function to create input tensor buffers for the given signature. - // It uses BufferRequirements and RankedTensorType to create the input tensor - // buffers. - Expected> CreateInputBuffers( - absl::string_view signature_name) const { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return CreateInputOutputBuffers(signature_index, /*is_input=*/true); - } - - // A helper function to create the input tensor buffers for the given - // signature. It uses BufferRequirements and RankedTensorType to create the - // input tensor buffers. - Expected> CreateInputBuffers( - size_t signature_index) const { - return CreateInputOutputBuffers(signature_index, /*is_input=*/true); - } - - // A helper function to create the input tensor buffers for the default - // signature. It uses BufferRequirements and RankedTensorType to create the - // input tensor buffers. - Expected> CreateInputBuffers() const { - return CreateInputOutputBuffers(/*signature_index=*/0, /*is_input=*/true); - } - - // A helper function to create output tensor buffers for the given signature. - // It uses BufferRequirements and RankedTensorType to create the output tensor - // buffers. - Expected> CreateOutputBuffers( - absl::string_view signature_name) const { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return CreateOutputBuffers(signature_index); - } - - // A helper function to create the output tensor buffers for the given - // signature. It uses BufferRequirements and RankedTensorType to create the - // output tensor buffers. - Expected> CreateOutputBuffers( - size_t signature_index) const { - return CreateInputOutputBuffers(signature_index, /*is_input=*/false); - } - - // A helper function to create the output tensor buffers for the default - // signature. It uses BufferRequirements and RankedTensorType to create the - // output tensor buffers. - Expected> CreateOutputBuffers() const { - return CreateInputOutputBuffers(/*signature_index=*/0, /*is_input=*/false); - } - - // Runs the model of the given signature index synchronously with the provided - // input/output TensorBuffers. - Expected Run(size_t signature_index, - const std::vector& input_buffers, - const std::vector& output_buffers) const { - bool async = false; - return RunHelper(signature_index, input_buffers, output_buffers, async); - } - - // Runs the model of the default signature synchronously with the provided - // input/output TensorBuffers. - Expected Run(const std::vector& input_buffers, - const std::vector& output_buffers) const { - bool async = false; - return RunHelper(/*signature_index=*/0, input_buffers, output_buffers, - async); - } - - // Runs the model of the given signature index asynchronously, if possible, - // with the provided input/output TensorBuffers. If asynchronous execution is - // possible then the function returns true in parameter `async`; otherwise the - // function runs the model synchronously. - Expected RunAsync(size_t signature_index, - const std::vector& input_buffers, - const std::vector& output_buffers, - bool& async) const { - async = true; - return RunHelper(signature_index, input_buffers, output_buffers, async); - } - - // Runs the model of the default signature asynchronously, if possible, - // with the provided input/output TensorBuffers.
If asynchronous execution is - // possible then the function returns true in parameter `async`; otherwise the - // function runs the model synchronously. - Expected RunAsync(const std::vector& input_buffers, - const std::vector& output_buffers, - bool& async) const { - async = true; - return RunHelper(/*signature_index=*/0, input_buffers, output_buffers, - async); - } - - // Runs the model of the given signature key synchronously with the provided - // input/output TensorBuffers. - Expected Run(absl::string_view signature_key, - const std::vector& input_buffers, - const std::vector& output_buffers) const { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_key)); - return Run(signature_index, input_buffers, output_buffers); - } - - // Runs the model of the given signature key asynchronously, if possible, with - // the provided input/output TensorBuffers. If asynchronous execution is - // possible then the function returns true in parameter `async`; otherwise the - // function runs the model synchronously. - Expected RunAsync(absl::string_view signature_key, - const std::vector& input_buffers, - const std::vector& output_buffers, - bool& async) const { - async = true; - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_key)); - return RunAsync(signature_index, input_buffers, output_buffers, async); - } - - // Runs the model of the given signature key synchronously with the provided - // input/output TensorBuffer map. - Expected Run( - absl::string_view signature_key, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map) - const { - bool async = false; - return RunMapHelper(signature_key, input_map, output_map, async); - } - - // Runs the model of the default signature synchronously with the provided - // input/output TensorBuffer map. - Expected Run( - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map) - const { - bool async = false; - auto subgraph = model_.MainSubgraph(); - if (!subgraph) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get main subgraph"); - } - return RunMapWithIndexHelper(/*signature_index=*/0, *subgraph, input_map, - output_map, async); - } - - // Runs the model of the given signature key asynchronously, if possible, with - // the provided input/output TensorBuffer map. If asynchronous execution is - // possible then the function returns true in parameter `async`; otherwise the - // function runs the model synchronously. - Expected RunAsync( - absl::string_view signature_key, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const { - async = true; - return RunMapHelper(signature_key, input_map, output_map, async); - } - - private: - // Returns the signature input index for the given input tensor name. - Expected FindInputIndex(size_t signature_index, - absl::string_view input_name) const; - - // Returns the signature output index for the given output tensor name. - Expected FindOutputIndex(size_t signature_index, - absl::string_view output_name) const; - - // Creates a TensorBuffer with the given buffer requirements and tensor type. - static Expected CreateBufferImpl( - const TensorBufferRequirements& buffer_requirements, - const RankedTensorType& tensor_type); - - // Creates a TensorBuffer for the given signature index and tensor name. 
- Expected CreateInputOutputBuffer(size_t signature_index, - absl::string_view tensor_name, - bool is_input) const; - - // Creates a TensorBuffer for the given signature and tensor name. - Expected CreateInputOutputBuffer( - absl::string_view signature_name, absl::string_view tensor_name, - bool is_input) const { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return CreateInputOutputBuffer(signature_index, tensor_name, is_input); - } - - // Creates a vector of TensorBuffers for the given signature subgraph. - Expected> CreateInputOutputBuffers( - size_t signature_index, bool is_input) const; - - Expected RunCApiHelper(LiteRtParamIndex signature_index, - size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, - size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers, - bool& async) const; - - Expected RunHelper(size_t signature_index, - const std::vector& input_buffers, - const std::vector& output_buffers, - bool& async) const; - - Expected RunMapHelper( - absl::string_view signature_key, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const; - - Expected RunMapWithIndexHelper( - size_t signature_index, const Subgraph& subgraph, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const; - - Model model_; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILED_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_gpu_test.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model_gpu_test.cc deleted file mode 100644 index 425658907802cb..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_gpu_test.cc +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
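To close out the removed header, here is the six-step flow from its class comment as one sketch. Assumptions: Model::CreateFromFile and TensorBuffer::Read exist with the shapes suggested by the tests below, LITERT_RETURN_IF_ERROR accepts an Expected, "model.tflite" is a placeholder path, and the tensor sizes are illustrative only.

#include <vector>

#include "absl/types/span.h"
#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h"
#include "tensorflow/lite/experimental/litert/cc/litert_environment.h"
#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h"

litert::Expected<void> RunOnce() {
  // 1. Create the CompiledModel (defaults to CPU acceleration).
  LITERT_ASSIGN_OR_RETURN(litert::Environment env,
                          litert::Environment::Create({}));
  LITERT_ASSIGN_OR_RETURN(litert::Model model,
                          litert::Model::CreateFromFile("model.tflite"));
  LITERT_ASSIGN_OR_RETURN(litert::CompiledModel compiled,
                          litert::CompiledModel::Create(env, model));

  // 2. + 3. The Create*Buffers helpers query the buffer requirements and
  // allocate matching TensorBuffers for the default signature.
  LITERT_ASSIGN_OR_RETURN(std::vector<litert::TensorBuffer> inputs,
                          compiled.CreateInputBuffers());
  LITERT_ASSIGN_OR_RETURN(std::vector<litert::TensorBuffer> outputs,
                          compiled.CreateOutputBuffers());

  // 4. Fill the input TensorBuffers (sizes are illustrative).
  const float input_data[4] = {1.f, 2.f, 3.f, 4.f};
  LITERT_RETURN_IF_ERROR(inputs[0].Write(absl::MakeConstSpan(input_data)));

  // 5. Invoke the model synchronously on the default signature.
  LITERT_RETURN_IF_ERROR(compiled.Run(inputs, outputs));

  // 6. Read the results back out of the output TensorBuffers.
  float output_data[4] = {};
  LITERT_RETURN_IF_ERROR(outputs[0].Read(absl::MakeSpan(output_data)));
  return {};
}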
- -#include -#include - -#include -#include -#include "absl/debugging/leak_check.h" -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -using testing::FloatNear; -using testing::Pointwise; - -namespace litert { -namespace { - -void BasicTest() { - auto model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - auto env = litert::Environment::Create({}); - ASSERT_TRUE(env); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto compiled_model, - CompiledModel::Create(*env, model, kLiteRtHwAcceleratorGpu)); - auto signatures = model.GetSignatures().Value(); - EXPECT_EQ(signatures.size(), 1); - - auto signature_key = signatures[0].Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - LITERT_ASSERT_OK_AND_ASSIGN( - auto input_buffers, compiled_model.CreateInputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto output_buffers, compiled_model.CreateOutputBuffers(signature_index)); - - // Fill model inputs. - auto input_names = signatures[0].InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - EXPECT_EQ(*input_buffers[0].BufferType(), kLiteRtTensorBufferTypeOpenCl); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - EXPECT_EQ(*input_buffers[1].BufferType(), kLiteRtTensorBufferTypeOpenCl); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. - auto output_names = signatures[0].OutputNames(); - EXPECT_EQ(output_names.size(), 1); - EXPECT_EQ(output_names.at(0), "tfl.add"); - EXPECT_EQ(*output_buffers[0].BufferType(), kLiteRtTensorBufferTypeOpenCl); - { - auto lock_and_addr = - litert::TensorBufferScopedLock::Create(output_buffers[0]); - ASSERT_TRUE(lock_and_addr); - auto output = absl::MakeSpan(lock_and_addr->second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(CompiledModelGpuTest, Basic) { - // MSAN does not support GPU tests. -#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported in MSAN"; -#endif - // To workaround the memory leak in Nvidia's driver - absl::LeakCheckDisabler disable_leak_check; - - BasicTest(); -} - -TEST(CompiledModelGpuTest, Basic2nd) { - // MSAN does not support GPU tests. 
-#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported in MSAN"; -#endif - // To workaround the memory leak in Nvidia's driver - absl::LeakCheckDisabler disable_leak_check; - - // Run the test twice to verify that the CL environment is shared between - // instances. - BasicTest(); -} - -TEST(CompiledModelGpuTest, Async) { - // MSAN does not support GPU tests. -#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported in MSAN"; -#endif - // To workaround the memory leak in Nvidia's driver - absl::LeakCheckDisabler disable_leak_check; - - auto model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - auto env = litert::Environment::Create({}); - ASSERT_TRUE(env); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto compiled_model, - CompiledModel::Create(*env, model, kLiteRtHwAcceleratorGpu)); - - auto signatures = model.GetSignatures().Value(); - EXPECT_EQ(signatures.size(), 1); - - auto signature_key = signatures[0].Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - LITERT_ASSERT_OK_AND_ASSIGN( - auto input_buffers, compiled_model.CreateInputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN(auto input_event, - Event::CreateManaged(LiteRtEventTypeOpenCl)); - // Copy of the event to trigger the signal since the ownership of the - // input_event is transferred to the input_buffers[0]. - LiteRtEvent litert_input_event = input_event.Get(); - - // Fill model inputs. - auto input_names = signatures[0].InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - EXPECT_EQ(*input_buffers[0].BufferType(), kLiteRtTensorBufferTypeOpenCl); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - EXPECT_EQ(*input_buffers[1].BufferType(), kLiteRtTensorBufferTypeOpenCl); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Bind the input event to the input buffers. - // Note: The task should be done after the input buffers are filled. - // Otherwise the input_buffers[0].Write<> will be blocked by the associated - // event. - input_buffers[0].SetEvent(std::move(input_event)); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto output_buffers, compiled_model.CreateOutputBuffers(signature_index)); - - // Execute model asynchronously. - bool async_execution_mode = true; - compiled_model.RunAsync(signature_index, input_buffers, output_buffers, - async_execution_mode); - - // Signal the input event to resume the async execution. - LiteRtEventSignal(litert_input_event); - - // Check model output. 
- auto output_names = signatures[0].OutputNames(); - EXPECT_EQ(output_names.size(), 1); - EXPECT_EQ(output_names.at(0), "tfl.add"); - EXPECT_EQ(*output_buffers[0].BufferType(), kLiteRtTensorBufferTypeOpenCl); - { - auto lock_and_addr = - litert::TensorBufferScopedLock::Create(output_buffers[0]); - ASSERT_TRUE(lock_and_addr); - auto output = absl::MakeSpan(lock_and_addr->second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_integration_test.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model_integration_test.cc deleted file mode 100644 index 7bc36e48268ed1..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_integration_test.cc +++ /dev/null @@ -1,376 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace litert { -namespace { - -using ::testing::Eq; -using ::testing::FloatNear; -using ::testing::Pointwise; -using ::testing::SizeIs; - -constexpr absl::string_view kNpuFile = kGoogleTensorModelFileName; -constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -TEST(CompiledModelTest, RunWithGoogleTensorModel) { - if (!litert::internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() - << "The rest of this test is specific to 
Android devices with a " - "GoogleTensor eTPU"; - } - - // Environment setup. - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, - litert::Environment::Create(environment_options)); - - // Create Model. - - // TODO(gcarranza): Replace internal API with C++ API or single npu tflite - // file. - LITERT_ASSERT_OK_AND_ASSIGN( - BufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(model.DefaultSignatureKey())); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(model.DefaultSignatureKey())); - - ASSERT_THAT(input_buffers, SizeIs(2)); - ASSERT_THAT(output_buffers, SizeIs(1)); - - // Confirm input and output buffers are AHWB. - EXPECT_THAT(*input_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*input_buffers[1].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*output_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - - LITERT_ASSERT_OK(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - LITERT_ASSERT_OK(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Run compiled model. - compiled_model.Run(model.DefaultSignatureKey(), input_buffers, - output_buffers); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(CompiledModel, RunAsyncWithGoogleTensorModel) { - if (!litert::internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() - << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; - } - - // Environment setup. - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, - litert::Environment::Create(environment_options)); - - // Create Model. - - // TODO(gcarranza): Replace internal API with C++ API or single npu tflite - // file. - LITERT_ASSERT_OK_AND_ASSIGN( - BufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - // Create CompiledModel. 
- LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(model.DefaultSignatureKey())); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(model.DefaultSignatureKey())); - - ASSERT_THAT(input_buffers, SizeIs(2)); - ASSERT_THAT(output_buffers, SizeIs(1)); - - // Confirm input and output buffers are AHWB. - EXPECT_THAT(*input_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*input_buffers[1].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*output_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - - LITERT_ASSERT_OK(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - LITERT_ASSERT_OK(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Run compiled model. - bool async; - compiled_model.RunAsync(model.DefaultSignatureKey(), input_buffers, - output_buffers, async); - // Since output buffers have events, async should be true. - ASSERT_TRUE(async); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -void FillGlBuffer1(LiteRtGLuint id, size_t size) { -#if LITERT_HAS_OPENGL_SUPPORT - std::string shader_source = R"( #version 310 es - precision highp float; - layout(local_size_x = 1, local_size_y = 1) in; - layout(std430, binding = 0) buffer Output {float elements[];} output_data; - void main() { - uint v = gl_GlobalInvocationID.x * 2u; - output_data.elements[v] = float(v + 1u) / 1.0; - output_data.elements[v + 1u] = float(v + 2u) / 1.0; - })"; - GLuint shader = glCreateShader(GL_COMPUTE_SHADER); - const GLchar* sources[] = {shader_source.c_str()}; - glShaderSource(shader, 1, sources, nullptr); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glCompileShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - GLuint to_buffer_program = glCreateProgram(); - glAttachShader(to_buffer_program, shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glLinkProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, id); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glUseProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDispatchCompute(size / 2, 1, 1); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -void FillGlBuffer2(LiteRtGLuint id, size_t size) { -#if LITERT_HAS_OPENGL_SUPPORT - std::string shader_source = R"( #version 310 es - precision highp float; - layout(local_size_x = 1, local_size_y = 1) in; - layout(std430, binding = 0) buffer Output {float elements[];} output_data; - void main() { - uint v = gl_GlobalInvocationID.x * 2u; - output_data.elements[v] = float(v + 1u) / 0.1; - output_data.elements[v + 1u] = float(v + 2u) / 0.1; - })"; - GLuint shader = 
glCreateShader(GL_COMPUTE_SHADER); - const GLchar* sources[] = {shader_source.c_str()}; - glShaderSource(shader, 1, sources, nullptr); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glCompileShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - GLuint to_buffer_program = glCreateProgram(); - glAttachShader(to_buffer_program, shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glLinkProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, id); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glUseProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDispatchCompute(size / 2, 1, 1); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -TEST(CompiledModel, RunAsyncWithGoogleTensorModelUseAhwbGlInterop) { - if (!litert::internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() - << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; - } - - // Environment setup. - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, - litert::Environment::Create(environment_options)); - - // Create Model. - - // TODO(gcarranza): Replace internal API with C++ API or single npu tflite - // file. - LITERT_ASSERT_OK_AND_ASSIGN( - BufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(model.DefaultSignatureKey())); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(model.DefaultSignatureKey())); - - ASSERT_THAT(input_buffers, SizeIs(2)); - ASSERT_THAT(output_buffers, SizeIs(1)); - - // Confirm input and output buffers are AHWB. - EXPECT_THAT(*input_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*input_buffers[1].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*output_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - - // TODO(gcarranza): Integrate with LiteRT Environment. -#if LITERT_HAS_OPENGL_SUPPORT - std::unique_ptr egl_env; - ASSERT_TRUE( - tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&egl_env).ok()); - LITERT_LOG(LITERT_INFO, "Initialized EGL environment"); -#else - LITERT_LOG(LITERT_INFO, "EGL environment not initialized"); -#endif // LITERT_HAS_OPENGL_SUPPORT - - // Write to input buffers on GPU. - LITERT_ASSERT_OK_AND_ASSIGN(auto gl_buffer_1, input_buffers[0].GetGlBuffer()); - FillGlBuffer1(gl_buffer_1.id, 2); - LITERT_ASSERT_OK_AND_ASSIGN(auto gl_buffer_2, input_buffers[1].GetGlBuffer()); - FillGlBuffer2(gl_buffer_2.id, 2); - - // Create EGL sync and fence before AHWB read. - // TODO(gcarranza): Integrate into LiteRT C++ API. 
- LITERT_ASSERT_OK_AND_ASSIGN( - int native_fence, ::litert::internal::GlBuffer::CreateEglSyncAndFence()); - - LITERT_ASSERT_OK_AND_ASSIGN( - Event event_1, - Event::CreateFromSyncFenceFd(native_fence, /*owns_fd=*/false)); - LITERT_ASSERT_OK_AND_ASSIGN( - Event event_2, - Event::CreateFromSyncFenceFd(native_fence, /*owns_fd=*/false)); - - // Set event so that AHWB read is blocked by GPU write. - input_buffers[0].SetEvent(std::move(event_1)); - input_buffers[1].SetEvent(std::move(event_2)); - - // Run compiled model asynchronously. - bool async; - compiled_model.RunAsync(model.DefaultSignatureKey(), input_buffers, - output_buffers, async); - // Since output buffers have events, async should be true. - ASSERT_TRUE(async); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc deleted file mode 100644 index 426817fc792e1d..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" - -#include -#include -#include - -#include -#include -#include "absl/container/flat_hash_map.h" -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -using ::testing::ElementsAre; -using ::testing::Eq; -using testing::FloatNear; -using testing::Pointwise; -using ::testing::SizeIs; - -namespace litert { -namespace { - -TEST(CompiledModelTest, Basic) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model. 
- Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Check CompiledModel buffer requirements. - // input and output expect host memory. - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg0, - compiled_model.GetInputBufferRequirements(/*input_name=*/"arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg0, - input_buffer_requirements_arg0.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg1, - compiled_model.GetInputBufferRequirements(/*input_name=*/"arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg1, - input_buffer_requirements_arg1.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements output_buffer_requirements, - compiled_model.GetOutputBufferRequirements(/*output_name=*/"tfl.add")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffer_types, - output_buffer_requirements.SupportedTypes()); - EXPECT_THAT(output_buffer_types, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - // Create and fill input and output buffers. - LITERT_ASSERT_OK_AND_ASSIGN(std::vector input_buffers, - compiled_model.CreateInputBuffers()); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector output_buffers, - compiled_model.CreateOutputBuffers()); - - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model with input and output buffers. - compiled_model.Run(input_buffers, output_buffers); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(CompiledModelTest, BasicSignatureIndex) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector signatures, - model.GetSignatures()); - EXPECT_EQ(signatures.size(), 1); - absl::string_view signature_key = signatures[0].Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - std::vector input_names = signatures[0].InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - std::vector output_names = signatures[0].OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.add")); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Check CompiledModel buffer requirements. - // input and output expect host memory. 
- LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg0, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg0, - input_buffer_requirements_arg0.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg1, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg1, - input_buffer_requirements_arg1.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements output_buffer_requirements, - compiled_model.GetOutputBufferRequirements(signature_index, - /*output_name=*/"tfl.add")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffer_types, - output_buffer_requirements.SupportedTypes()); - EXPECT_THAT(output_buffer_types, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - // Create and fill input and output buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(signature_index)); - - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model with input and output buffers. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(CompiledModelTest, RunWithInputOutputMap) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector signatures, - model.GetSignatures()); - EXPECT_EQ(signatures.size(), 1); - absl::string_view signature_key = signatures[0].Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - std::vector input_names = signatures[0].InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - std::vector output_names = signatures[0].OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.add")); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Check CompiledModel buffer requirements. - // input and output expect host memory. 
- LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg0, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg0, - input_buffer_requirements_arg0.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg1, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg1, - input_buffer_requirements_arg1.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements output_buffer_requirements, - compiled_model.GetOutputBufferRequirements(signature_index, - /*output_name=*/"tfl.add")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffer_types, - output_buffer_requirements.SupportedTypes()); - EXPECT_THAT(output_buffer_types, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - // Create and fill input and output buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer input_buffer0, - compiled_model.CreateInputBuffer(signature_key, "arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer input_buffer1, - compiled_model.CreateInputBuffer(signature_key, "arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer output_buffer0, - compiled_model.CreateOutputBuffer(signature_key, "tfl.add")); - - ASSERT_TRUE(input_buffer0.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffer1.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Create input and output map. - absl::flat_hash_map input_map; - input_map["arg0"] = std::move(input_buffer0); - input_map["arg1"] = std::move(input_buffer1); - - absl::flat_hash_map output_map; - output_map["tfl.add"] = std::move(output_buffer0); - - // Execute model with input and output maps instead of buffers. - compiled_model.Run(signature_key, input_map, output_map); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, litert::TensorBufferScopedLock::Create( - output_map["tfl.add"])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -// Tests Compiled Model async API on CPU. In the CPU case, the async API should -// always return false. -TEST(CompiledModelTest, RunAsyncReturnsFalse) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Create input and output buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(model.DefaultSignatureKey())); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(model.DefaultSignatureKey())); - - // Confirm input and output buffers are host memory. 
-  EXPECT_THAT(*input_buffers[0].BufferType(),
-              Eq(kLiteRtTensorBufferTypeHostMemory));
-  EXPECT_THAT(*input_buffers[1].BufferType(),
-              Eq(kLiteRtTensorBufferTypeHostMemory));
-  EXPECT_THAT(*output_buffers[0].BufferType(),
-              Eq(kLiteRtTensorBufferTypeHostMemory));
-
-  ASSERT_THAT(input_buffers, SizeIs(2));
-  ASSERT_THAT(output_buffers, SizeIs(1));
-
-  ASSERT_TRUE(input_buffers[0].Write(
-      absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)));
-  ASSERT_TRUE(input_buffers[1].Write(
-      absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)));
-
-  // Execute the model with input and output buffers via the async API.
-  bool async;
-  compiled_model.RunAsync(model.DefaultSignatureKey(), input_buffers,
-                          output_buffers, async);
-  // Since there are no events on the output buffers, async should be false.
-  ASSERT_FALSE(async);
-
-  // Check model output.
-  {
-    LITERT_ASSERT_OK_AND_ASSIGN(
-        auto lock_and_addr,
-        litert::TensorBufferScopedLock::Create(output_buffers[0]));
-    auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize);
-    for (auto i = 0; i < kTestOutputSize; ++i) {
-      ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i];
-    }
-    EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor));
-  }
-}
-
-}  // namespace
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/cc/litert_consts.h b/tensorflow/lite/experimental/litert/cc/litert_consts.h
deleted file mode 100644
index 14ac9a0b00e832..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_consts.h
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_CONSTS_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_CONSTS_H_
-
-#include <cstddef>
-
-namespace litert {
-
-// The following constants are used to properly size the absl::InlinedVector<>
-// instances used in the LiteRT code. Their values don't need to be exact; they
-// are just optimization hints.
-static constexpr size_t kExpectedMaxTensorRank = 6;
-static constexpr size_t kExpectedMaxNumOfTensorUses = 8;
-static constexpr size_t kExpectedMaxNumOfOpInputs = 4;
-static constexpr size_t kExpectedMaxNumOfOpOutputs = 8;
-static constexpr size_t kExpectedMaxNumOfSubgraphInputs = 4;
-static constexpr size_t kExpectedMaxNumOfSubgraphOutputs = 4;
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_CONSTS_H_
diff --git a/tensorflow/lite/experimental/litert/cc/litert_detail.h b/tensorflow/lite/experimental/litert/cc/litert_detail.h
deleted file mode 100644
index 566d8468fa8148..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_detail.h
+++ /dev/null
@@ -1,105 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DETAIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DETAIL_H_ - -#include -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -namespace litert { - -// See "std::construct_at" from C++20. -template -T* ConstructAt(T* p, Args&&... args) { - return ::new (static_cast(p)) T(std::forward(args)...); -} - -// Reduce all over zipped iters of same size. -template -bool AllZip(const LeftVals& lhs, const RightVals& rhs, - std::function - bin_pred) { - if (lhs.size() != rhs.size()) { - return false; - } - for (auto i = 0; i < lhs.size(); ++i) { - if (!bin_pred(lhs.at(i), rhs.at(i))) { - return false; - } - } - return true; -} - -// Reduce any over zipped iters of same size. -template -bool AnyZip(const LeftVals& lhs, const RightVals& rhs, - std::function - bin_pred) { - auto neg = [&](const auto& l, const auto& r) { return !bin_pred(l, r); }; - return !(AllZip(lhs, rhs, neg)); -} - -// Does element exist in range. -template -bool Contains(It begin, It end, const T& val) { - return std::find(begin, end, val) != end; -} - -// Does element exist in range satisfying pred. -template -bool ContainsIf(It begin, It end, UPred u_pred) { - return std::find_if(begin, end, u_pred) != end; -} - -// Get the ind of the given element if it is present. -template -std::optional FindInd(It begin, It end, T val) { - auto it = std::find(begin, end, val); - return (it == end) ? std::nullopt : std::make_optional(it - begin); -} - -namespace internal { - -// Call function "get" and assert it returns value equal to given expected -// value. -template -inline void AssertEq(F get, Expected expected, Args&&... args) { - auto status = get(std::forward(args)...); - ABSL_CHECK_EQ(status, expected); -} - -// Call function "get" and assert it returns true. -template -inline void AssertTrue(F get, Args&&... args) { - AssertEq(get, true, std::forward(args)...); -} - -// Call function "get" and assert it returns an OK LiteRtStatus. -template -inline void AssertOk(F get, Args&&... args) { - AssertEq(get, kLiteRtStatusOk, std::forward(args)...); -} - -} // namespace internal -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DETAIL_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h b/tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h deleted file mode 100644 index bdbb3a0c4df8c7..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DISPATCH_DELEGATE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DISPATCH_DELEGATE_H_ - -#include - -#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" - -namespace litert { - -using DispatchDelegateOptionsPtr = - std::unique_ptr; - -using DispatchDelegatePtr = tflite::TfLiteOpaqueDelegateUniquePtr; - -DispatchDelegateOptionsPtr CreateDispatchDelegateOptionsPtr( - LiteRtEnvironmentOptions environment_options); - -DispatchDelegatePtr CreateDispatchDelegatePtr( - LiteRtEnvironmentOptions environment_options, - DispatchDelegateOptionsPtr&& options); - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DISPATCH_DELEGATE_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_element_type.h b/tensorflow/lite/experimental/litert/cc/litert_element_type.h deleted file mode 100644 index 84b032b3820a7a..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_element_type.h +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ELEMENT_TYPE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ELEMENT_TYPE_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_model.h" - -namespace litert { - -// Data type of tensor elements. C++ equivalent to LiteRtElementType. -enum class ElementType { - None = kLiteRtElementTypeNone, - Bool = kLiteRtElementTypeBool, - Int4 = kLiteRtElementTypeInt4, - Int8 = kLiteRtElementTypeInt8, - Int16 = kLiteRtElementTypeInt16, - Int32 = kLiteRtElementTypeInt32, - Int64 = kLiteRtElementTypeInt64, - UInt8 = kLiteRtElementTypeUInt8, - UInt16 = kLiteRtElementTypeUInt16, - UInt32 = kLiteRtElementTypeUInt32, - UInt64 = kLiteRtElementTypeUInt64, - Float16 = kLiteRtElementTypeFloat16, - BFloat16 = kLiteRtElementTypeBFloat16, - Float32 = kLiteRtElementTypeFloat32, - Float64 = kLiteRtElementTypeFloat64, - Complex64 = kLiteRtElementTypeComplex64, - Complex128 = kLiteRtElementTypeComplex128, - TfResource = kLiteRtElementTypeTfResource, - TfString = kLiteRtElementTypeTfString, - TfVariant = kLiteRtElementTypeTfVariant, -}; - -// Get number of bytes of a single element of given type. 
-inline constexpr std::optional<size_t> GetByteWidth(ElementType ty) {
-  if (ty == ElementType::Bool)
-    return 1;
-  else if (ty == ElementType::Int8)
-    return 1;
-  else if (ty == ElementType::Int16)
-    return 2;
-  else if (ty == ElementType::Int32)
-    return 4;
-  else if (ty == ElementType::Int64)
-    return 8;
-  else if (ty == ElementType::UInt8)
-    return 1;
-  else if (ty == ElementType::UInt16)
-    return 2;
-  else if (ty == ElementType::UInt32)
-    return 4;
-  else if (ty == ElementType::UInt64)
-    return 8;
-  else if (ty == ElementType::Float16)
-    return 2;
-  else if (ty == ElementType::BFloat16)
-    return 2;
-  else if (ty == ElementType::Float32)
-    return 4;
-  else if (ty == ElementType::Float64)
-    return 8;
-  else
-    return std::nullopt;
-}
-
-// Get number of bytes of a single element of given type via template.
-template <ElementType Ty>
-inline constexpr size_t GetByteWidth() {
-  constexpr auto byte_width = GetByteWidth(Ty);
-  static_assert(byte_width.has_value(), "Type does not have byte width");
-  return byte_width.value();
-}
-
-template <typename>
-constexpr bool dependent_false = false;  // workaround before CWG2518/P2593R1
-
-// Get the litert::ElementType associated with given C++ type.
-template <typename T>
-inline constexpr ElementType GetElementType() {
-  static_assert(dependent_false<T>, "Unknown C++ type");
-  return ElementType::None;
-}
-
-template <>
-inline constexpr ElementType GetElementType<bool>() {
-  return ElementType::Bool;
-}
-
-template <>
-inline constexpr ElementType GetElementType<int8_t>() {
-  return ElementType::Int8;
-}
-
-template <>
-inline constexpr ElementType GetElementType<uint8_t>() {
-  return ElementType::UInt8;
-}
-
-template <>
-inline constexpr ElementType GetElementType<int16_t>() {
-  return ElementType::Int16;
-}
-
-template <>
-inline constexpr ElementType GetElementType<uint16_t>() {
-  return ElementType::UInt16;
-}
-
-template <>
-inline constexpr ElementType GetElementType<int32_t>() {
-  return ElementType::Int32;
-}
-
-template <>
-inline constexpr ElementType GetElementType<uint32_t>() {
-  return ElementType::UInt32;
-}
-
-template <>
-inline constexpr ElementType GetElementType<int64_t>() {
-  return ElementType::Int64;
-}
-
-template <>
-inline constexpr ElementType GetElementType<uint64_t>() {
-  return ElementType::UInt64;
-}
-
-template <>
-inline constexpr ElementType GetElementType<float>() {
-  return ElementType::Float32;
-}
-
-template <>
-inline constexpr ElementType GetElementType<double>() {
-  return ElementType::Float64;
-}
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ELEMENT_TYPE_H_
diff --git a/tensorflow/lite/experimental/litert/cc/litert_element_type_test.cc b/tensorflow/lite/experimental/litert/cc/litert_element_type_test.cc
deleted file mode 100644
index 929bc499f32c63..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_element_type_test.cc
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
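As a reading aid for the header removed above, here is a minimal sketch of how its compile-time helpers compose; it assumes only the litert_element_type.h declarations shown in this diff:

// Resolve a C++ type to its LiteRT element type, then to a byte width, all at
// compile time (C++17 single-argument static_assert).
#include <cstdint>
#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h"

static_assert(litert::GetElementType<int16_t>() == litert::ElementType::Int16);
static_assert(litert::GetByteWidth<litert::GetElementType<float>()>() == 4);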
-
-#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h"
-
-#include <cstddef>
-#include <cstdint>
-
-#include <gtest/gtest.h>
-
-namespace litert {
-
-namespace {
-
-template <typename T>
-class ElementTypeTest : public ::testing::Test {
- public:
-  size_t Size() const { return sizeof(T); }
-};
-
-TYPED_TEST_SUITE_P(ElementTypeTest);
-
-TYPED_TEST_P(ElementTypeTest, TypeAndSize) {
-  const size_t size = GetByteWidth<GetElementType<TypeParam>()>();
-  EXPECT_EQ(size, this->Size());
-}
-
-REGISTER_TYPED_TEST_SUITE_P(ElementTypeTest, TypeAndSize);
-
-using Types =
-    ::testing::Types<bool, int8_t, uint8_t, int16_t, uint16_t, int32_t,
-                     uint32_t, int64_t, uint64_t, float, double>;
-
-INSTANTIATE_TYPED_TEST_SUITE_P(ElementTypeTestSuite, ElementTypeTest, Types);
-
-}  // namespace
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/cc/litert_environment.h b/tensorflow/lite/experimental/litert/cc/litert_environment.h
deleted file mode 100644
index 69faebdea892d3..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_environment.h
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_
-
-#include <any>
-#include <vector>
-
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_any.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_handle.h"
-
-namespace litert {
-
-class Environment
-    : public internal::Handle {
- public:
-  explicit Environment(LiteRtEnvironment env)
-      : internal::Handle(env,
-                         true) {}
-
-  enum class OptionTag {
-    CompilerPluginLibraryDir = kLiteRtEnvOptionTagCompilerPluginLibraryDir,
-    DispatchLibraryDir = kLiteRtEnvOptionTagDispatchLibraryDir,
-  };
-
-  struct Option {
-    OptionTag tag;
-    std::any value;
-  };
-
-  static Expected<Environment> Create(absl::Span<const Option> options) {
-    auto c_options = ConvertOptions(options);
-    if (!c_options) {
-      return c_options.Error();
-    }
-    LiteRtEnvironment env;
-    if (auto status =
-            LiteRtEnvironmentCreate(c_options->size(), c_options->data(), &env);
-        status != kLiteRtStatusOk) {
-      return Error(status);
-    } else {
-      return Environment(env);
-    }
-  }
-
- private:
-  static Expected<std::vector<LiteRtEnvOption>> ConvertOptions(
-      absl::Span<const Option> options) {
-    std::vector<LiteRtEnvOption> c_options;
-    c_options.reserve(options.size());
-
-    for (auto& option : options) {
-      auto litert_any = ToLiteRtAny(option.value);
-      if (!litert_any) {
-        return litert_any.Error();
-      }
-
-      LiteRtEnvOption c_option = {
-          /*.tag=*/static_cast<LiteRtEnvOptionTag>(option.tag),
-          /*.value=*/*litert_any,
-      };
-      c_options.push_back(c_option);
-    }
-
-    return c_options;
-  }
-};
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_
diff --git a/tensorflow/lite/experimental/litert/cc/litert_environment_test.cc b/tensorflow/lite/experimental/litert/cc/litert_environment_test.cc
deleted file mode 100644
index
0f012aedfee66b..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_environment_test.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" - -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -namespace litert { -namespace { - -TEST(EnvironmentTest, Default) { - auto env = litert::Environment::Create({}); - EXPECT_TRUE(env); -} - -TEST(EnvironmentTest, Options) { - constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - EXPECT_TRUE(env); -} - -TEST(EnvironmentTest, CompiledModelBasic) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - // Create CompiledModel. - auto compiled_model = CompiledModel::Create(env, model); - EXPECT_TRUE(compiled_model); -} - -TEST(EnvironmentTest, StringLifeCycle) { - std::string dispatch_library_dir = "/data/local/tmp"; - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - absl::string_view(dispatch_library_dir), - }, - }; - - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - - EXPECT_TRUE(env); - - // Change the string value but the environment should still have a copy. - dispatch_library_dir = ""; - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - // Create CompiledModel. - auto compiled_model = CompiledModel::Create(*env, model); - EXPECT_TRUE(compiled_model); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_event.h b/tensorflow/lite/experimental/litert/cc/litert_event.h deleted file mode 100644 index 0f8582205d9c39..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_event.h +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EVENT_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EVENT_H_
-
-#include <cstdint>
-
-#include "tensorflow/lite/experimental/litert/c/litert_event.h"
-#include "tensorflow/lite/experimental/litert/c/litert_event_type.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_handle.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-
-extern "C" {
-// Forward declaration of OpenCL event to avoid including OpenCL headers.
-typedef struct _cl_event* cl_event;
-}
-
-namespace litert {
-
-class Event : public internal::Handle {
- public:
-  // Parameter `owned` indicates if the created Event object should take
-  // ownership of the provided `event` handle.
-  explicit Event(LiteRtEvent event, bool owned = true)
-      : internal::Handle(event, owned) {}
-
-  // Creates an Event object with the given `sync_fence_fd`.
-  static Expected<Event> CreateFromSyncFenceFd(int sync_fence_fd,
-                                               bool owns_fd) {
-    LiteRtEvent event;
-    LITERT_RETURN_IF_ERROR(
-        LiteRtCreateEventFromSyncFenceFd(sync_fence_fd, owns_fd, &event));
-    return Event(event);
-  }
-
-  // Creates an Event object with the given `cl_event`.
-  static Expected<Event> CreateFromOpenClEvent(cl_event cl_event) {
-    LiteRtEvent event;
-    LITERT_RETURN_IF_ERROR(LiteRtCreateEventFromOpenClEvent(cl_event, &event));
-    return Event(event);
-  }
-
-  // Creates a managed event of the given `type`. Currently only
-  // LiteRtEventTypeOpenCl is supported.
-  static Expected<Event> CreateManaged(LiteRtEventType type) {
-    LiteRtEvent event;
-    LITERT_RETURN_IF_ERROR(LiteRtCreateManagedEvent(type, &event));
-    return Event(event);
-  }
-
-  Expected<int> GetSyncFenceFd() {
-    int fd;
-    LITERT_RETURN_IF_ERROR(LiteRtGetEventSyncFenceFd(Get(), &fd));
-    return fd;
-  }
-
-  // Returns the underlying OpenCL event if the event type is OpenCL.
-  Expected<cl_event> GetOpenClEvent() {
-    cl_event cl_event;
-    LITERT_RETURN_IF_ERROR(LiteRtGetEventOpenClEvent(Get(), &cl_event));
-    return cl_event;
-  }
-
-  // Pass -1 for timeout_in_ms for an indefinite wait.
-  Expected<void> Wait(int64_t timeout_in_ms) {
-    LITERT_RETURN_IF_ERROR(LiteRtEventWait(Get(), timeout_in_ms));
-    return {};
-  }
-
-  // Signal the event.
-  // Note: This is only supported for OpenCL events.
-  Expected<void> Signal() {
-    LITERT_RETURN_IF_ERROR(LiteRtEventSignal(Get()));
-    return {};
-  }
-
-  // Returns the underlying event type.
-  LiteRtEventType Type() const {
-    LiteRtEventType type;
-    LiteRtGetEventEventType(Get(), &type);
-    return type;
-  }
-};
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EVENT_H_
diff --git a/tensorflow/lite/experimental/litert/cc/litert_event_test.cc b/tensorflow/lite/experimental/litert/cc/litert_event_test.cc
deleted file mode 100644
index 752e8c0a6c3ce6..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_event_test.cc
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/cc/litert_event.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "tensorflow/lite/experimental/litert/test/matchers.h"
-
-namespace litert {
-namespace {
-
-using ::testing::Eq;
-
-constexpr int kFakeSyncFenceFd = 1;
-
-TEST(Event, NoEvent) {
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      Event event, Event::CreateFromSyncFenceFd(kFakeSyncFenceFd, true));
-  LITERT_ASSERT_OK_AND_ASSIGN(int fd, event.GetSyncFenceFd());
-  EXPECT_THAT(fd, Eq(kFakeSyncFenceFd));
-}
-
-}  // namespace
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/cc/litert_expected.h b/tensorflow/lite/experimental/litert/cc/litert_expected.h
deleted file mode 100644
index 5d60e86094391a..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_expected.h
+++ /dev/null
@@ -1,390 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EXPECTED_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EXPECTED_H_
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "absl/log/absl_check.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_detail.h"
-
-namespace litert {
-
-// An "Expected" encapsulates the result of some routine which may have an
-// unexpected result. Unexpected results in this context are a standard
-// LiteRtStatus plus extra usability data such as error messages. This is
-// similar to an absl::StatusOr or std::expected (C++23) but better integrated
-// with LiteRtStatus as the canonical status code.
-
-// C++ wrapper around a LiteRtStatus code. Provides a status as well
-// as an error message.
-class Error {
- public:
-  // Construct an Error from a status and optional error message.
-  //
-  // NOTE: kLiteRtStatusOk should not be passed to an Error.
-  explicit Error(LiteRtStatus status, std::string message = "")
-      : status_(status), message_(std::move(message)) {
-    ABSL_DCHECK(status != kLiteRtStatusOk);
-  }
-
-  // Get the status.
-  constexpr LiteRtStatus Status() const { return status_; }
-
-  // Get the error message, empty string if none was attached.
-  const std::string& Message() const { return message_; }
-
-  friend std::ostream& operator<<(std::ostream& stream, const Error& error) {
-    stream << LiteRtGetStatusString(error.Status());
-    if (!error.Message().empty()) {
-      stream << ": " << error.Message();
-    }
-    return stream;
-  }
-
-  template <typename Sink>
-  friend void AbslStringify(Sink& sink, const Error& error) {
-    absl::Format(&sink, "%s", LiteRtGetStatusString(error.Status()));
-    if (!error.Message().empty()) {
-      absl::Format(&sink, ": %v", error.Message());
-    }
-  }
-
- private:
-  LiteRtStatus status_;
-  std::string message_;
-};
-
-class Unexpected {
- public:
-  template <typename... Args>
-  constexpr explicit Unexpected(Args&&... args)
-      : error_(std::forward<Args>(args)...) {}
-
-  // Allow for implicit conversion from convertible Error value inplace.
-  // NOLINTNEXTLINE(*-explicit-constructor)
-  Unexpected(class Error&& e) : error_(std::move(e)) {}
-
-  Unexpected(Unexpected&& other) = default;
-  Unexpected(const Unexpected& other) = default;
-  Unexpected& operator=(Unexpected&& other) = default;
-  Unexpected& operator=(const Unexpected& other) = default;
-
-  constexpr const class Error& Error() const& noexcept { return error_; }
-  constexpr class Error& Error() & noexcept { return error_; }
-  constexpr const class Error&& Error() const&& noexcept {
-    return std::move(error_);
-  }
-  constexpr class Error&& Error() && noexcept { return std::move(error_); }
-
-  template <typename Sink>
-  friend void AbslStringify(Sink& sink, const Unexpected& unexpected) {
-    AbslStringify(sink, unexpected.Error());
-  }
-
- private:
-  class Error error_;
-};
-
-// Utility for generic return values that may be a statused failure. Expecteds
-// store and own the lifetime of either an Unexpected, or a T. T may be any
-// type, primitive or non-primitive.
-//
-// No dynamic allocations occur during initialization, so the underlying T is
-// only movable (as opposed to something like "release"). Arguments should be
-// constructed in place at the time of initializing the expected if possible.
-//
-// Unexpected&& and T&& may be implicitly converted
-// to an Expected. For example,
-//
-// Expected<Foo> Bar() {
-//   bool success = ...
-//   if (!success) { return Unexpected(kLiteRtStatus, "Bad Baz"); }
-//   return Foo();
-// }
-//
-template <class T>
-class Expected {
- public:
-  // Construct Expected with T inplace.
-
-  // Construct T from initializer list inplace.
-  template <class U>
-  Expected(std::initializer_list<U> il) : has_value_(true), value_(il) {}
-
-  // Construct T from forwarded args inplace.
-  template <class... Args>
-  explicit Expected(Args&&... args)
-      : has_value_(true), value_(std::forward<Args>(args)...) {}
-
-  // NOLINTBEGIN(*-explicit-constructor)
-
-  // Allow for implicit conversion from convertible T value inplace.
-  Expected(const T& t) : has_value_(true), value_(t) {}
-  Expected(T&& t) : has_value_(true), value_(std::move(t)) {}
-
-  // Construct from Unexpected inplace.
-
-  // Allow for implicit conversion from Error.
- Expected(const Unexpected& err) : has_value_(false), unexpected_(err) {} - Expected(Unexpected&& err) : has_value_(false), unexpected_(std::move(err)) {} - Expected(const class Error& e) : has_value_(false), unexpected_(e) {} - - // NOLINTEND(*-explicit-constructor) - - // Copy/move - - Expected(Expected&& other) : has_value_(other.HasValue()) { - if (HasValue()) { - ConstructAt(std::addressof(value_), std::move(other.value_)); - } else { - ConstructAt(std::addressof(unexpected_), std::move(other.unexpected_)); - } - } - - Expected(const Expected& other) : has_value_(other.has_value_) { - if (HasValue()) { - ConstructAt(std::addressof(value_), other.value_); - value_ = other.value_; - } else { - ConstructAt(std::addressof(unexpected_), other.unexpected_); - } - } - - Expected& operator=(Expected&& other) { - if (this != &other) { - if (HasValue()) { - if (other.HasValue()) { - value_ = std::move(other.value_); - } else { - value_.~T(); - ConstructAt(std::addressof(unexpected_), - std::move(other.unexpected_)); - } - } else { - if (other.HasValue()) { - unexpected_.~Unexpected(); - ConstructAt(std::addressof(value_), std::move(other.value_)); - } else { - unexpected_ = std::move(other.unexpected_); - } - } - has_value_ = other.has_value_; - } - return *this; - } - - Expected& operator=(const Expected& other) { - if (this != &other) { - if (HasValue()) { - if (other.HasValue()) { - value_ = other.value_; - } else { - value_.~T(); - ConstructAt(std::addressof(unexpected_), other.unexpected_); - } - } else { - if (other.HasValue()) { - unexpected_.~Unexpected(); - ConstructAt(std::addressof(value_), other.value_); - } else { - unexpected_ = other.unexpected_; - } - } - has_value_ = other.has_value_; - } - return *this; - } - - ~Expected() { - if (has_value_ && std::is_destructible()) { - value_.~T(); - } else { - unexpected_.~Unexpected(); - } - } - - // Observers for T value, program exits if it doesn't have one. - const T& Value() const& { - CheckVal(); - return value_; - } - - T& Value() & { - CheckVal(); - return value_; - } - - const T&& Value() const&& { - CheckVal(); - return std::move(value_); - } - - T&& Value() && { - CheckVal(); - return std::move(value_); - } - - const T* operator->() const { - CheckVal(); - return &value_; - } - - T* operator->() { - CheckVal(); - return &value_; - } - - const T& operator*() const& { return Value(); } - - T& operator*() & { return Value(); } - - const T&& operator*() const&& { return std::move(Value()); } - - T&& operator*() && { return std::move(Value()); } - - // Observer for Unexpected, program exits if it doesn't have one. - const class Error& Error() const& { - CheckNoVal(); - return unexpected_.Error(); - } - - class Error& Error() & { - CheckNoVal(); - return unexpected_.Error(); - } - - const class Error&& Error() const&& { - CheckNoVal(); - return std::move(unexpected_.Error()); - } - - class Error&& Error() && { - CheckNoVal(); - return std::move(unexpected_.Error()); - } - - // Does this expected contain a T Value. It contains an unexpected if not. - bool HasValue() const { return has_value_; } - - // Convert to bool for HasValue. 
-  explicit operator bool() const { return HasValue(); }
-
- private:
-  bool has_value_;
-  union {
-    T value_;
-    Unexpected unexpected_;
-  };
-  void CheckNoVal() const { ABSL_CHECK(!HasValue()); }
-  void CheckVal() const { ABSL_CHECK(HasValue()); }
-};
-
-namespace internal {
-template <typename T>
-struct CanBeAbslFormated {
-  template <typename U>
-  static constexpr auto Check(int)
-      -> decltype(absl::StrCat(std::declval<U>()), true) {
-    return true;
-  }
-  template <typename>
-  static constexpr bool Check(...) {
-    return false;
-  }
-  enum { value = Check<T>(0) };
-};
-}  // namespace internal
-
-template <typename Sink, typename T>
-void AbslStringify(Sink& sink, const Expected<T>& expected) {
-  if (!expected.HasValue()) {
-    absl::Format(&sink, "%v", expected.Error());
-  } else {
-    if constexpr (std::is_same_v<T, void>) {
-      sink.Append("void expected value");
-    } else {
-      if constexpr (internal::CanBeAbslFormated<T>::value) {
-        absl::Format(&sink, "%v", expected.Value());
-      } else {
-        absl::Format(&sink, "unformattable expected value");
-      }
-    }
-  }
-}
-
-template <>
-class Expected<void> {
- public:
-  // Implicit construction is used to simplify returning a valid value, e.g.,
-  // in "return {};".
-  Expected() : unexpected_(std::nullopt) {}
-
-  // NOLINTBEGIN(*-explicit-constructor)
-
-  // Construct from Unexpected inplace.
-  Expected(const Unexpected& err) : unexpected_(err) {}
-  Expected(Unexpected&& err) : unexpected_(std::move(err)) {}
-
-  // Allow for implicit conversion from Error.
-  Expected(const Error& e) : unexpected_(e) {}
-
-  // NOLINTEND(*-explicit-constructor)
-
-  // Observer for Unexpected, program exits if it doesn't have one.
-  const class Error& Error() const& {
-    CheckNoVal();
-    return unexpected_->Error();
-  }
-
-  class Error& Error() & {
-    CheckNoVal();
-    return unexpected_->Error();
-  }
-
-  const class Error&& Error() const&& {
-    CheckNoVal();
-    return std::move(unexpected_->Error());
-  }
-
-  class Error&& Error() && {
-    CheckNoVal();
-    return std::move(unexpected_->Error());
-  }
-
-  // Does this expected contain a value. It contains an unexpected if not.
-  bool HasValue() const { return !unexpected_.has_value(); }
-
-  // Convert to bool for HasValue.
-  explicit operator bool() const { return HasValue(); }
-
- private:
-  std::optional<Unexpected> unexpected_;
-  void CheckNoVal() const { ABSL_CHECK(!HasValue()); }
-  void CheckVal() const { ABSL_CHECK(HasValue()); }
-};
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EXPECTED_H_
diff --git a/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc b/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc
deleted file mode 100644
index ad68a834dbe80f..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc
+++ /dev/null
@@ -1,233 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
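As a reading aid for the Expected/Unexpected machinery removed above, a minimal usage sketch follows; `ParseAnswer` is a hypothetical name, and the status code is the one the tests below use:

#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"

// Failure returns carry a status plus an optional message; success values
// convert implicitly to the Expected.
litert::Expected<int> ParseAnswer(bool ok) {
  if (!ok) {
    return litert::Unexpected(kLiteRtStatusErrorInvalidArgument, "bad input");
  }
  return 42;
}

// Callers branch on the bool conversion (HasValue) and read Value()/Error():
//   if (auto r = ParseAnswer(true)) { Use(*r); } else { Log(r.Error()); }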
- -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" - -namespace litert { - -namespace { -using testing::StrEq; - -static constexpr LiteRtStatus kErrorStatus = kLiteRtStatusErrorInvalidArgument; - -struct TypeWithAllocation { - TypeWithAllocation(std::initializer_list il) : allocated(il) {} - std::vector allocated; -}; - -struct TypeWithFields { - TypeWithFields(int i_, int j_) : i(i_), j(j_) {} - int i; - int j; -}; - -TEST(ExpectedTest, PrimitiveExplicit) { - Expected exp(1.0); - ASSERT_TRUE(exp.HasValue()); -} - -TEST(ExpectedTest, PrimitiveImplicit) { - Expected exp = 1.0; - ASSERT_TRUE(exp.HasValue()); -} - -TEST(ExpectedTest, ClassWithAllocation) { - Expected exp(TypeWithAllocation({1, 2, 3})); - ASSERT_TRUE(exp.HasValue()); -} - -TEST(ExpectedTest, ClassWithFields) { - Expected exp(TypeWithFields(1, 2)); - ASSERT_TRUE(exp.HasValue()); -} - -TEST(ExpectedTest, FromErrorExplicit) { - Expected exp((Unexpected(kErrorStatus, "MESSAGE"))); - ASSERT_FALSE(exp.HasValue()); -} - -TEST(ExpectedTest, FromErrorImplicit) { - Expected exp = Unexpected(kErrorStatus); - ASSERT_FALSE(exp.HasValue()); -} - -TEST(ExpectedTest, CopyCstorError) { - const Expected exp = Unexpected(kErrorStatus); - Expected other(exp); - ASSERT_FALSE(other.HasValue()); - EXPECT_EQ(other.Error().Status(), kErrorStatus); -} - -TEST(ExpectedTest, CopyCstorVal) { - const Expected exp = 2; - Expected other(exp); - ASSERT_TRUE(other.HasValue()); - EXPECT_EQ(other.Value(), 2); -} - -TEST(ExpectedTest, CopyAssignError) { - const Expected exp = Unexpected(kErrorStatus); - ASSERT_FALSE(exp.HasValue()); - Expected other = exp; - ASSERT_FALSE(other.HasValue()); - EXPECT_EQ(other.Error().Status(), kErrorStatus); -} - -TEST(ExpectedTest, CopyAssignVal) { - const Expected exp = 2; - Expected other = exp; - ASSERT_TRUE(other.HasValue()); - EXPECT_EQ(other.Value(), 2); -} - -TEST(ExpectedTest, MoveCstorError) { - Expected exp = Unexpected(kErrorStatus); - Expected other(std::move(exp)); - ASSERT_FALSE(other.HasValue()); - EXPECT_EQ(other.Error().Status(), kErrorStatus); -} - -TEST(ExpectedTest, MoveCstorVal) { - Expected exp = 2; - Expected other(std::move(exp)); - ASSERT_TRUE(other.HasValue()); - EXPECT_EQ(other.Value(), 2); -} - -TEST(ExpectedTest, MoveAssignError) { - Expected exp = Unexpected(kErrorStatus); - Expected other = std::move(exp); - ASSERT_FALSE(other.HasValue()); - EXPECT_EQ(other.Error().Status(), kErrorStatus); -} - -TEST(ExpectedTest, MoveAssignVal) { - Expected exp = 2; - Expected other = std::move(exp); - ASSERT_TRUE(other.HasValue()); - EXPECT_EQ(other.Value(), 2); -} - -TEST(ExpectedTest, Indirection) { - Expected exp(TypeWithFields(1, 2)); - EXPECT_EQ(exp->i, 1); - EXPECT_EQ(exp->j, 2); -} - -TEST(ExpectedTest, Dereference) { - Expected exp(TypeWithFields(1, 2)); - const auto& val = *exp; - EXPECT_EQ(val.i, 1); - EXPECT_EQ(val.j, 2); -} - -TEST(UnexpectedTest, WithStatus) { - Unexpected err(kErrorStatus); - EXPECT_EQ(err.Error().Status(), kErrorStatus); - EXPECT_TRUE(err.Error().Message().empty()); -} - -TEST(UnexpectedTest, WithMessage) { - Unexpected err(kErrorStatus, "MESSAGE"); - EXPECT_EQ(err.Error().Status(), kErrorStatus); - EXPECT_EQ(err.Error().Message(), "MESSAGE"); -} - -TEST(UnexpectedTest, WithLocalMessageString) { - // Message is a 
string with scoped lifetime. - Unexpected err(kErrorStatus, absl::StrCat("MESSAGE", 1)); - EXPECT_EQ(err.Error().Status(), kErrorStatus); - EXPECT_EQ(err.Error().Message(), "MESSAGE1"); -} - -Expected> Go() { - std::string data = "21234"; - OwningBufferRef buf(data.c_str()); - return buf; -} - -Expected> Forward() { - auto thing = Go(); - if (!thing.HasValue()) { - return thing.Error(); - } - // No copy elision here. - return thing; -} - -TEST(ExpectedTest, ForwardBufThroughFuncs) { - auto res = Forward(); - EXPECT_TRUE(res.HasValue()); - EXPECT_EQ(res->StrView(), "21234"); -} - -TEST(ExpectedWithNoValue, WithoutError) { - Expected expected = {}; - EXPECT_TRUE(expected.HasValue()); -} - -TEST(ExpectedWithNoValue, WithError) { - Expected expected(Unexpected(kErrorStatus, "MESSAGE")); - EXPECT_FALSE(expected.HasValue()); - EXPECT_EQ(expected.Error().Status(), kErrorStatus); - EXPECT_EQ(expected.Error().Message(), "MESSAGE"); -} - -TEST(ExpectedWithNoValue, OStreamOutput) { - Expected expected(Unexpected(kErrorStatus, "MESSAGE")); - std::ostringstream oss; - oss << expected.Error(); - EXPECT_THAT(oss.str(), testing::HasSubstr("MESSAGE")); -} - -TEST(ExpectedTest, PrintingWorks) { - EXPECT_THAT(absl::StrCat(Expected(3)), StrEq("3")); - - EXPECT_THAT(absl::StrCat(Expected()), StrEq("void expected value")); - - EXPECT_THAT(absl::StrCat(Unexpected(kLiteRtStatusErrorNotFound)), - StrEq("kLiteRtStatusErrorNotFound")); - - EXPECT_THAT(absl::StrCat(Unexpected(kLiteRtStatusErrorNotFound, - "Error not found message")), - StrEq("kLiteRtStatusErrorNotFound: Error not found message")); - - EXPECT_THAT(absl::StrCat(Error(kLiteRtStatusErrorNotFound)), - StrEq("kLiteRtStatusErrorNotFound")); - - EXPECT_THAT(absl::StrCat( - Error(kLiteRtStatusErrorNotFound, "Error not found message")), - StrEq("kLiteRtStatusErrorNotFound: Error not found message")); - - struct UnknownStruct {}; - EXPECT_THAT(absl::StrCat(Expected({})), - StrEq("unformattable expected value")); -} - -} // namespace - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_handle.h b/tensorflow/lite/experimental/litert/cc/litert_handle.h deleted file mode 100644 index 503eaad335b764..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_handle.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ - -#include -#include - -namespace litert { -namespace internal { - -template -inline void DummyDeleter(H) {} - -// This class is used to wrap and manage the lifetime of opaque handles from the -// C API into an equivalent C++ object. The class is a wrapper on -// std::unique_ptr<> that has a default constructor and doesn't crash if the -// deleter is null. -template -class Handle { - public: - Handle() = default; - explicit Handle(H handle, bool owned) noexcept - : ptr_(handle, owned ? 
deleter : DummyDeleter) {} - - Handle(Handle&& other) noexcept { *this = std::move(other); } - - Handle& operator=(Handle&& other) noexcept { - std::swap(ptr_, other.ptr_); - return *this; - } - - // Return true if the underlying LiteRtTensorBuffer handle is valid. - explicit operator bool() const noexcept { return static_cast(ptr_); } - - // Return the underlying LiteRtTensorBuffer handle. - H Get() const noexcept { return ptr_.get(); } - - H Release() noexcept { return ptr_.release(); } - - bool IsOwned() const noexcept { - return ptr_.get_deleter() != DummyDeleter; - } - - private: - std::unique_ptr, void (*)(H)> ptr_ = {nullptr, - DummyDeleter}; -}; - -// This class is similar to Handle, but the managed opaque handle is not owned -// (i.e., it will not be destroyed). -template -class NonOwnedHandle : public Handle> { - public: - explicit NonOwnedHandle(H handle) noexcept - : Handle>(handle, /*owned=*/false) {} -}; - -} // namespace internal -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_layout.h b/tensorflow/lite/experimental/litert/cc/litert_layout.h deleted file mode 100644 index a8f90ac6dc1069..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_layout.h +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_LAYOUT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_LAYOUT_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" -#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" - -namespace litert { - -using Dimensions = absl::InlinedVector; -using Strides = absl::InlinedVector; - -// Small standalone helper functions for working with the C layout API. - -// Build layout from given iterator of dimensions. -template -inline constexpr LiteRtLayout BuildLayout(Begin begin, End end, - const uint32_t* strides = nullptr) { - LiteRtLayout res{static_cast(end - begin), {}, strides}; - auto i = 0; - - for (auto* it = begin; it != end; ++it) { - res.dimensions[i] = *it; - ++i; - } - - return res; -} - -// Build layout from given iterable of dimensions. -template -inline constexpr LiteRtLayout BuildLayout(const Dims& dims, - const uint32_t* strides = nullptr) { - return BuildLayout(std::cbegin(dims), std::cend(dims), strides); -} - -// Build layout from literal dimensions. -inline constexpr LiteRtLayout BuildLayout(std::initializer_list dims, - const uint32_t* strides = nullptr) { - return BuildLayout(dims.begin(), dims.end(), strides); -} - -// Compute the number of elements in dims iterator. Nullopt if there exists -// a dynamic dimension. 
-template -inline constexpr std::optional NumElements(Begin begin, End end) { - if (end - begin == 0) { - return {}; - } - size_t res = 1; - for (auto* it = begin; it != end; ++it) { - if (*it < 0) { - return {}; - } - res *= *it; - } - return res; -} - -// Override for layouts. -inline constexpr std::optional NumElements(const LiteRtLayout& layout) { - auto* b = std::cbegin(layout.dimensions); - return NumElements(b, b + layout.rank); -} - -// Get dims as span. -inline constexpr absl::Span DimsSpan( - const LiteRtLayout& layout) { - return absl::MakeConstSpan(layout.dimensions, layout.rank); -} - -// Get strides as span if they exist. -inline constexpr std::optional> StridesSpan( - const LiteRtLayout& layout) { - if (layout.strides) { - return absl::MakeConstSpan(layout.strides, layout.rank); - } - return {}; -} - -// Tensor layout. C++ equivalent to LiteRtLayout. -class Layout { - public: - explicit Layout(litert::Dimensions&& dimensions, - litert::Strides&& strides = litert::Strides()) - : dimensions_(std::move(dimensions)), strides_(std::move(strides)) {} - - explicit Layout(const LiteRtLayout& layout) - : dimensions_(layout.dimensions, layout.dimensions + layout.rank) { - if (layout.strides) { - strides_.assign(layout.strides, layout.strides + layout.rank); - } - } - - // Cast the existing Layout to a LiteRtLayout. Note that the present Layout - // object must outlive the returned LiteRtLayout, otherwise pointers in the - // latter may become dangling. - explicit operator LiteRtLayout() const { - auto res = BuildLayout(dimensions_); - res.strides = HasStrides() ? strides_.data() : nullptr; - return res; - } - - bool operator==(const Layout& other) const { - return dimensions_ == other.dimensions_ && strides_ == other.strides_; - } - - uint32_t Rank() const { return dimensions_.size(); } - - absl::Span Dimensions() const { - return absl::MakeSpan(dimensions_.data(), dimensions_.size()); - } - - bool HasStrides() const { return !strides_.empty(); } - - absl::Span Strides() const { - if (HasStrides()) - return {strides_.data(), Rank()}; - else - return {}; - } - - // Get the number of scalar elements in this tensor type. std::nullopt if - // not fully static. - std::optional NumElements() const { - return ::litert::NumElements(dimensions_.cbegin(), dimensions_.cend()); - } - - private: - litert::Dimensions dimensions_; - litert::Strides strides_; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_LAYOUT_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_layout_test.cc b/tensorflow/lite/experimental/litert/cc/litert_layout_test.cc deleted file mode 100644 index 40d9cb9873e045..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_layout_test.cc +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
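[Editor's note: a small sketch of the `BuildLayout`/`NumElements`/`Layout` helpers from the header deleted above, mirroring the test file that follows; values are illustrative only.]

    #include <cassert>
    #include <cstdint>

    #include "tensorflow/lite/experimental/litert/cc/litert_layout.h"

    void LayoutSketch() {
      // Build a C-layout over literal dims; strides stay null here.
      constexpr int32_t kDims[] = {2, 3, 4};
      const LiteRtLayout c_layout = litert::BuildLayout(kDims);
      assert(c_layout.rank == 3);

      // All dims are static, so the element count is known: 2 * 3 * 4 = 24.
      const auto num = litert::NumElements(c_layout);
      assert(num.has_value() && *num == 24);

      // The C++ Layout wrapper owns copies of dims/strides; converting back
      // to LiteRtLayout yields a non-owning view into this object.
      const litert::Layout cpp_layout(c_layout);
      assert(cpp_layout.NumElements() == num);
    }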
- -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" - -#include - -#include -#include - -namespace litert { -namespace { - -using ::testing::ElementsAreArray; - -static constexpr int32_t kStaticDims[] = {2, 2}; -static constexpr int32_t kDynDims[] = {-1, 2}; -static constexpr uint32_t kStrides[] = {1, 1}; - -TEST(LayoutTest, BuildFromDims) { - auto layout = BuildLayout(kStaticDims); - EXPECT_EQ(layout.rank, 2); - EXPECT_THAT(DimsSpan(layout), ElementsAreArray(kStaticDims)); - EXPECT_EQ(layout.strides, nullptr); - EXPECT_FALSE(StridesSpan(layout).has_value()); -} - -TEST(LayoutTest, BuildFromDimsWithStrides) { - auto layout = BuildLayout(kStaticDims, kStrides); - EXPECT_EQ(layout.rank, 2); - EXPECT_THAT(DimsSpan(layout), ElementsAreArray(kStaticDims)); - auto strides = StridesSpan(layout); - ASSERT_TRUE(strides.has_value()); - EXPECT_THAT(*strides, ElementsAreArray(kStrides)); -} - -TEST(LayoutTest, NumElements) { - auto layout = BuildLayout(kStaticDims); - auto num_elements = NumElements(layout); - ASSERT_TRUE(num_elements.has_value()); - EXPECT_EQ(*num_elements, 4); -} - -TEST(LayoutTest, NumElementsDynamic) { - auto layout = BuildLayout(kDynDims); - auto num_elements = NumElements(layout); - ASSERT_FALSE(num_elements.has_value()); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_macros.cc b/tensorflow/lite/experimental/litert/cc/litert_macros.cc deleted file mode 100644 index 7d01ca346818e9..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_macros.cc +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -#include "absl/status/status.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { - -ErrorStatusBuilder::operator absl::Status() const noexcept { - switch (error_.Status()) { - case kLiteRtStatusOk: - return absl::OkStatus(); - case kLiteRtStatusErrorInvalidArgument: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorMemoryAllocationFailure: - return absl::ResourceExhaustedError(error_.Message()); - case kLiteRtStatusErrorRuntimeFailure: - return absl::InternalError(error_.Message()); - case kLiteRtStatusErrorMissingInputTensor: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorUnsupported: - return absl::UnimplementedError(error_.Message()); - case kLiteRtStatusErrorNotFound: - return absl::NotFoundError(error_.Message()); - case kLiteRtStatusErrorTimeoutExpired: - return absl::DeadlineExceededError(error_.Message()); - case kLiteRtStatusErrorWrongVersion: - return absl::FailedPreconditionError(error_.Message()); - case kLiteRtStatusErrorUnknown: - return absl::UnknownError(error_.Message()); - case kLiteRtStatusErrorFileIO: - return absl::UnavailableError(error_.Message()); - case kLiteRtStatusErrorInvalidFlatbuffer: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorDynamicLoading: - return absl::UnavailableError(error_.Message()); - case kLiteRtStatusErrorSerialization: - return absl::InternalError(error_.Message()); - case kLiteRtStatusErrorCompilation: - return absl::InternalError(error_.Message()); - case kLiteRtStatusErrorIndexOOB: - return absl::OutOfRangeError(error_.Message()); - case kLiteRtStatusErrorInvalidIrType: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorInvalidGraphInvariant: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorGraphModification: - return absl::InternalError(error_.Message()); - case kLiteRtStatusErrorInvalidToolConfig: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusLegalizeNoMatch: - return absl::NotFoundError(error_.Message()); - case kLiteRtStatusErrorInvalidLegalization: - return absl::InvalidArgumentError(error_.Message()); - default: - return absl::UnknownError(error_.Message()); - } -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_macros.h b/tensorflow/lite/experimental/litert/cc/litert_macros.h deleted file mode 100644 index 299649061fb333..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_macros.h +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MACROS_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MACROS_H_
-
-#include <cstdint>
-#include <memory>
-#include <sstream>
-#include <string>
-#include <utility>
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"  // IWYU pragma: keep
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"  // IWYU pragma: keep
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"  // IWYU pragma: keep
-
-#define _CONCAT_NAME_IMPL(x, y) x##y
-
-#define _CONCAT_NAME(x, y) _CONCAT_NAME_IMPL(x, y)
-
-#define _RETURN_VAL(val) return val
-
-#define LITERT_CHECK_STATUS_HAS_CODE(expr, code) ABSL_CHECK(expr == code);
-
-#define LITERT_CHECK_STATUS_OK(expr) \
-  LITERT_CHECK_STATUS_HAS_CODE(expr, kLiteRtStatusOk);
-
-#define LITERT_ENSURE_SUPPORTED(cond, msg)  \
-  if (!(cond)) {                            \
-    LITERT_LOG(LITERT_ERROR, "%s", msg);    \
-    return kLiteRtStatusErrorUnsupported;   \
-  }
-
-#define LITERT_ENSURE(expr, fail_stat, msg) \
-  if (!(expr)) {                            \
-    LITERT_LOG(LITERT_ERROR, "%s", msg);    \
-    return fail_stat;                       \
-  }
-
-#define LITERT_RETURN_IF_ERROR_OR_NOT_MATCHED(expr)                          \
-  if (LiteRtStatus status = expr;                                            \
-      (status != kLiteRtStatusOk && status != kLiteRtStatusLegalizeNoMatch)) \
-    return status;
-
-#define LITERT_STACK_ARRAY(ty, var, size, init) \
-  ty* var = (ty*)alloca(sizeof(ty) * size);     \
-  for (ty* e = var; e < var + size; ++e) {      \
-    *e = init;                                  \
-  }
-
-// LITERT_RETURN_IF_ERROR(expr);
-// LITERT_RETURN_IF_ERROR(expr, return_value);
-//
-// Returns the result of `expr` if it represents a LiteRT error status (either
-// a `litert::Expected` holding an error, a `LiteRtStatus`, or a bool that
-// evaluates to `false`).
-//
-// Returns `return_value` if the result of `expr` represents an error.
-//
-// The result of `expr` may be referenced as `status` in `return_expr`.
-//
-// By default, the return value is an `ErrorStatusBuilder` built from the
-// result of `expr`. The error message of this builder can be customized by
-// using its `*Log*()` functions and the << operator.
-//
-// ```cpp
-// LITERT_RETURN_IF_ERROR(expr) << "Failed while trying to ...";
-// ```
-#define LITERT_RETURN_IF_ERROR(...)                                        \
-  LITERT_RETURN_IF_ERROR_SELECT_OVERLOAD(                                  \
-      (__VA_ARGS__, LITERT_RETURN_IF_ERROR_2, LITERT_RETURN_IF_ERROR_1))(  \
-      __VA_ARGS__)
-
-// LITERT_ASSIGN_OR_RETURN(decl, expr)
-// LITERT_ASSIGN_OR_RETURN(decl, expr, return_value)
-//
-// Evaluates `expr`, which should convert to a `litert::Expected` object.
-//
-// - If the object holds a value, move-assigns the value to `decl`.
-// - If the object holds an error, returns the error, casting it to a
-// `LiteRtStatus` if required.
-//
-// `return_value` may be specified to return a custom value in case of error.
-//
-// When specifying `return_value`, an `ErrorStatusBuilder` variable called
-// `_` can be used to customize the error message.
-//
-// ```cpp
-// LITERT_ASSIGN_OR_RETURN(decl, expr, _ << "Failed while trying to ...");
-// ```
-#define LITERT_ASSIGN_OR_RETURN(DECL, ...)                                 \
-  LITERT_ASSIGN_OR_RETURN_SELECT_OVERLOAD((DECL, __VA_ARGS__,                \
-                                           LITERT_ASSIGN_OR_RETURN_HELPER_3, \
-                                           LITERT_ASSIGN_OR_RETURN_HELPER_2))( \
-      _CONCAT_NAME(expected_value_or_error_, __LINE__), DECL, __VA_ARGS__)
-
-namespace litert {
-
-#if defined(__has_builtin) && __has_builtin(__builtin_FILE) && \
-    __has_builtin(__builtin_LINE)
-#define LITERT_INTERNAL_BUILTIN_FILE __builtin_FILE()
-#define LITERT_INTERNAL_BUILTIN_LINE __builtin_LINE()
-#else
-#define LITERT_INTERNAL_BUILTIN_FILE "unknown"
-#define LITERT_INTERNAL_BUILTIN_LINE 0
-#endif
-
-// Stores a file and a line number.
-//
-// Mimics a subset of `std::source_location`, to be replaced by it when we
-// update to C++20.
-class SourceLocation {
-  // We have this to prevent `current()` parameters from being modified.
-  struct PrivateTag {};
-
- public:
-  // Creates a SourceLocation with the line and file corresponding to the
-  // call site.
-  static constexpr SourceLocation current(
-      PrivateTag = PrivateTag{},
-      const char* file = LITERT_INTERNAL_BUILTIN_FILE,
-      uint32_t line = LITERT_INTERNAL_BUILTIN_LINE) {
-    return SourceLocation{file, line};
-  }
-
-  constexpr const char* file_name() const { return file_; }
-  constexpr uint32_t line() const { return line_; }
-
- private:
-  // Builds a SourceLocation object.
-  //
-  // Note: This is private as `std::source_location` doesn't provide a way of
-  // manually building a source location.
-  constexpr SourceLocation(const char* file, uint32_t line)
-      : file_(file), line_(line) {}
-
-  const char* file_;
-  uint32_t line_;
-};
-
-// Converts implicitly to either `LiteRtStatus` or `litert::Expected<T>`
-// holding an error. This allows returning a status in functions using either
-// of these as a return type in `LITERT_RETURN_IF_ERROR` and
-// `LITERT_ASSIGN_OR_RETURN`.
-//
-// When a C++ error with a message is converted to a `LiteRtStatus`, the
-// message is logged (as an error by default; use the `Log*()` functions to
-// customize that).
-//
-// The error message may be completed with extra info by using the <<
-// operator.
-class ErrorStatusBuilder {
- public:
-  explicit ErrorStatusBuilder(
-      bool expr_result,
-      litert::SourceLocation loc = litert::SourceLocation::current())
-      : error_(kLiteRtStatusErrorUnknown), loc_(loc) {}
-
-  template <class T>
-  explicit ErrorStatusBuilder(
-      const litert::Expected<T>& expected,
-      litert::SourceLocation loc = litert::SourceLocation::current())
-      : error_(expected.Error()), loc_(loc) {}
-
-  template <class T>
-  explicit ErrorStatusBuilder(
-      litert::Expected<T>&& expected,
-      litert::SourceLocation loc = litert::SourceLocation::current())
-      : error_(std::move(expected.Error())), loc_(loc) {}
-
-  explicit ErrorStatusBuilder(
-      LiteRtStatus status,
-      litert::SourceLocation loc = litert::SourceLocation::current())
-      : error_(status), loc_(loc) {}
-
-  explicit ErrorStatusBuilder(
-      const litert::Unexpected& unexpected,
-      litert::SourceLocation loc = litert::SourceLocation::current())
-      : error_(unexpected.Error()), loc_(loc) {}
-
-  explicit ErrorStatusBuilder(
-      litert::Unexpected&& unexpected,
-      litert::SourceLocation loc = litert::SourceLocation::current())
-      : error_(std::move(unexpected.Error())), loc_(loc) {}
-
-  // NOLINTBEGIN(*-explicit-constructor): This class transparently converts to
-  // `LiteRtStatus` and `litert::Expected`.
-
-  // Note: this conversion logs the error message if there is one, unless
-  // NDEBUG is set (generally the case for optimized builds).
-  operator LiteRtStatus() const noexcept {
-#ifndef NDEBUG
-    if (ShouldLog()) {
-      LiteRtLogger logger = LiteRtGetDefaultLogger();
-      LiteRtLogSeverity __min_severity__;
-      if (LiteRtGetMinLoggerSeverity(logger, &__min_severity__) !=
-          kLiteRtStatusOk) {
-        __min_severity__ = kLiteRtLogSeverityVerbose;
-      }
-      if (log_level_ >= __min_severity__) {
-        LiteRtLoggerLog(logger, log_level_, "[%s:%u] %s", loc_.file_name(),
-                        loc_.line(), LogMessage().c_str());
-      }
-    }
-#endif
-    return error_.Status();
-  }
-
-  template <class T>
-  operator litert::Expected<T>() const noexcept {
-    return litert::Unexpected(error_.Status(), LogMessage());
-  }
-
-  operator absl::Status() const noexcept;
-
-  template <class T>
-  operator absl::StatusOr<T>() const noexcept {
-    return static_cast<absl::Status>(*this);
-  }
-  // NOLINTEND(*-explicit-constructor)
-
-  static constexpr bool IsError(bool status) { return !status; }
-
-  static constexpr bool IsError(LiteRtStatus status) {
-    return status != kLiteRtStatusOk;
-  }
-
-  static constexpr bool IsError(const litert::Unexpected&) { return true; }
-
-  template <class T>
-  static constexpr bool IsError(const litert::Expected<T>& expected) {
-    return !expected.HasValue();
-  }
-
-  // Appends data to the error message.
-  template <class T>
-  ErrorStatusBuilder& operator<<(T&& val) {
-    if (!extra_log_) {
-      extra_log_ = std::make_unique<std::ostringstream>();
-    }
-    *extra_log_ << static_cast<T&&>(val);
-    return *this;
-  }
-
-  // Sets the log level used when converting to a `LiteRtStatus`.
-  ErrorStatusBuilder& Log(LiteRtLogSeverity log_level) noexcept {
-    log_level_ = log_level;
-    return *this;
-  }
-
-  // Sets the log level used when converting to a `LiteRtStatus` to `verbose`.
-  ErrorStatusBuilder& LogVerbose() noexcept {
-    return Log(kLiteRtLogSeverityVerbose);
-  }
-
-  // Sets the log level used when converting to a `LiteRtStatus` to `info`.
-  ErrorStatusBuilder& LogInfo() noexcept { return Log(kLiteRtLogSeverityInfo); }
-
-  // Sets the log level used when converting to a `LiteRtStatus` to `warning`.
-  ErrorStatusBuilder& LogWarning() noexcept {
-    return Log(kLiteRtLogSeverityWarning);
-  }
-
-  // Sets the log level used when converting to a `LiteRtStatus` to `error`.
-  ErrorStatusBuilder& LogError() noexcept {
-    return Log(kLiteRtLogSeverityError);
-  }
-
-  // Prevents logging any message when converting to a `LiteRtStatus`.
-  ErrorStatusBuilder& NoLog() noexcept { return Log(kLiteRtLogSeveritySilent); }
-
- private:
-  bool ShouldLog() const noexcept {
-    return log_level_ != kLiteRtLogSeveritySilent &&
-           (!error_.Message().empty() || extra_log_);
-  }
-
-  std::string LogMessage() const {
-    if (!error_.Message().empty() && extra_log_) {
-      std::string res;
-      res.reserve(error_.Message().size() + extra_log_->tellp() + 2);
-      res.append(error_.Message());
-      res.append(" ");
-      res.append(extra_log_->str());
-      return res;
-    }
-    if (!error_.Message().empty()) {
-      return error_.Message();
-    }
-    if (extra_log_) {
-      return extra_log_->str();
-    }
-    return {};
-  }
-
-  litert::Error error_;
-  litert::SourceLocation loc_;
-  std::unique_ptr<std::ostringstream> extra_log_;
-  LiteRtLogSeverity log_level_ = kLiteRtLogSeverityError;
-};
-
-}  // namespace litert
-
-//////////// Implementation details start here. ///////////////////////
-
-#define LITERT_RETURN_IF_ERROR_SELECT_OVERLOAD_HELPER(_1, _2, OVERLOAD, ...)
\ - OVERLOAD - -#define LITERT_RETURN_IF_ERROR_SELECT_OVERLOAD(args) \ - LITERT_RETURN_IF_ERROR_SELECT_OVERLOAD_HELPER args - -#define LITERT_RETURN_IF_ERROR_1(EXPR) \ - LITERT_RETURN_IF_ERROR_2(EXPR, \ - ::litert::ErrorStatusBuilder{std::move(status)}) - -#define LITERT_RETURN_IF_ERROR_2(EXPR, RETURN_VALUE) \ - if (auto status = (EXPR); ::litert::ErrorStatusBuilder::IsError(status)) \ - return RETURN_VALUE - -#define LITERT_ASSIGN_OR_RETURN_SELECT_OVERLOAD_HELPER(_1, _2, _3, OVERLOAD, \ - ...) \ - OVERLOAD - -#define LITERT_ASSIGN_OR_RETURN_SELECT_OVERLOAD(args) \ - LITERT_ASSIGN_OR_RETURN_SELECT_OVERLOAD_HELPER args - -#define LITERT_ASSIGN_OR_RETURN_HELPER_2(TMP_VAR, DECL, EXPR) \ - LITERT_ASSIGN_OR_RETURN_HELPER_3(TMP_VAR, DECL, EXPR, _) - -#define LITERT_ASSIGN_OR_RETURN_HELPER_3(TMP_VAR, DECL, EXPR, RETURN_VALUE) \ - auto&& TMP_VAR = (EXPR); \ - if (::litert::ErrorStatusBuilder::IsError(TMP_VAR)) { \ - [[maybe_unused]] ::litert::ErrorStatusBuilder _(std::move(TMP_VAR)); \ - return RETURN_VALUE; \ - } \ - DECL = std::move(TMP_VAR.Value()); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MACROS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_macros_test.cc b/tensorflow/lite/experimental/litert/cc/litert_macros_test.cc deleted file mode 100644 index f1d0b66e6748bb..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_macros_test.cc +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
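[Editor's note: to make the two macros defined in this header concrete, a minimal usage sketch; `ComputeSize` and `Resize` are hypothetical, the macros and `Expected` come from the headers above.]

    #include <cstddef>

    #include "tensorflow/lite/experimental/litert/c/litert_common.h"
    #include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
    #include "tensorflow/lite/experimental/litert/cc/litert_macros.h"

    // Hypothetical helpers.
    litert::Expected<size_t> ComputeSize(int rank);
    LiteRtStatus Resize(size_t size);

    LiteRtStatus ResizeFromRank(int rank) {
      // Unwraps the Expected into `size`, or returns its error as a
      // LiteRtStatus (logging the message at error level by default).
      LITERT_ASSIGN_OR_RETURN(size_t size, ComputeSize(rank));
      // Propagates a non-OK LiteRtStatus, with extra context appended.
      LITERT_RETURN_IF_ERROR(Resize(size)) << "resizing to rank " << rank;
      return kLiteRtStatusOk;
    }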
- -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { -namespace { - -using testing::AllOf; -using testing::Property; - -TEST(LiteRtReturnIfErrorTest, ConvertsResultToLiteRtStatus) { - EXPECT_EQ( - []() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR( - Expected(Unexpected(kLiteRtStatusErrorNotFound))); - return kLiteRtStatusOk; - }(), - kLiteRtStatusErrorNotFound); - EXPECT_EQ( - []() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(Unexpected(kLiteRtStatusErrorNotFound)); - return kLiteRtStatusOk; - }(), - kLiteRtStatusErrorNotFound); - EXPECT_EQ( - []() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(kLiteRtStatusErrorNotFound); - return kLiteRtStatusOk; - }(), - kLiteRtStatusErrorNotFound); -} - -TEST(LiteRtReturnIfErrorTest, ConvertsResultToExpectedHoldingAnError) { - EXPECT_THAT( - []() -> Expected { - LITERT_RETURN_IF_ERROR( - Expected(Unexpected(kLiteRtStatusErrorNotFound))); - return {}; - }(), - AllOf(Property(&Expected::HasValue, false), - Property(&Expected::Error, - Property(&Error::Status, kLiteRtStatusErrorNotFound)))); - EXPECT_THAT( - []() -> Expected { - LITERT_RETURN_IF_ERROR(Unexpected(kLiteRtStatusErrorNotFound)); - return {}; - }(), - AllOf(Property(&Expected::HasValue, false), - Property(&Expected::Error, - Property(&Error::Status, kLiteRtStatusErrorNotFound)))); - EXPECT_THAT( - []() -> Expected { - LITERT_RETURN_IF_ERROR(kLiteRtStatusErrorNotFound); - return {}; - }(), - AllOf(Property(&Expected::HasValue, false), - Property(&Expected::Error, - Property(&Error::Status, kLiteRtStatusErrorNotFound)))); -} - -TEST(LiteRtReturnIfErrorTest, DoesntReturnOnSuccess) { - int canary_value = 0; - auto ReturnExpectedIfError = [&canary_value]() -> Expected { - LITERT_RETURN_IF_ERROR(Expected()); - canary_value = 1; - return {}; - }; - EXPECT_THAT(ReturnExpectedIfError(), - Property(&Expected::HasValue, true)); - EXPECT_EQ(canary_value, 1); - - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(kLiteRtStatusOk); - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 2); -} - -TEST(LiteRtReturnIfErrorTest, ExtraLoggingWorks) { - int canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false) << "Successful default level logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).LogVerbose() << "Successful verbose logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).LogInfo() << "Successful info logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).LogWarning() << "Successful warning logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).LogError() << "Successful error logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).NoLog() << "This should never be printed"; - canary_value = 2; - return 
kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); -} - -TEST(LiteRtAssignOrReturnTest, VariableAssignmentWorks) { - int canary_value = 0; - auto ChangeCanaryValue = [&canary_value]() -> LiteRtStatus { - LITERT_ASSIGN_OR_RETURN(canary_value, Expected(1)); - return kLiteRtStatusOk; - }; - EXPECT_EQ(ChangeCanaryValue(), kLiteRtStatusOk); - EXPECT_EQ(canary_value, 1); -} - -TEST(LiteRtAssignOrReturnTest, MoveOnlyVariableAssignmentWorks) { - struct MoveOnly { - explicit MoveOnly(int val) : val(val) {}; - MoveOnly(const MoveOnly&) = delete; - MoveOnly& operator=(const MoveOnly&) = delete; - MoveOnly(MoveOnly&&) = default; - MoveOnly& operator=(MoveOnly&&) = default; - int val = 1; - }; - - MoveOnly canary_value{0}; - auto ChangeCanaryValue = [&canary_value]() -> LiteRtStatus { - LITERT_ASSIGN_OR_RETURN(canary_value, Expected(1)); - return kLiteRtStatusOk; - }; - EXPECT_EQ(ChangeCanaryValue(), kLiteRtStatusOk); - EXPECT_EQ(canary_value.val, 1); -} - -TEST(LiteRtAssignOrReturnTest, ReturnsOnFailure) { - const Expected InvalidArgumentError = - Expected(Unexpected(kLiteRtStatusErrorInvalidArgument)); - - int canary_value = 0; - auto ErrorWithStatus = [&]() -> LiteRtStatus { - LITERT_ASSIGN_OR_RETURN(canary_value, InvalidArgumentError); - return kLiteRtStatusOk; - }; - EXPECT_EQ(ErrorWithStatus(), kLiteRtStatusErrorInvalidArgument); - EXPECT_EQ(canary_value, 0); - - auto ErrorWithCustomStatus = [&]() -> int { - LITERT_ASSIGN_OR_RETURN(canary_value, InvalidArgumentError, 42); - return 1; - }; - EXPECT_EQ(ErrorWithCustomStatus(), 42); - EXPECT_EQ(canary_value, 0); - - auto ErrorWithExpected = [&]() -> Expected { - LITERT_ASSIGN_OR_RETURN(canary_value, InvalidArgumentError); - return {}; - }; - auto expected_return = ErrorWithExpected(); - ASSERT_FALSE(expected_return.HasValue()); - EXPECT_EQ(expected_return.Error().Status(), - kLiteRtStatusErrorInvalidArgument); - EXPECT_EQ(canary_value, 0); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_model.cc b/tensorflow/lite/experimental/litert/cc/litert_model.cc deleted file mode 100644 index b67c5c75d2375a..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model.cc +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { - -bool Tensor::IsSubgraphOutput() const { return Uses().empty(); } - -bool Tensor::IsSubgraphInput() const { - // A special case for zero-sized tensors. 
- if (RankedTensorType()->Layout().Rank() == 1 && - RankedTensorType()->Layout().Dimensions()[0] == 0) { - return false; - } - return !HasWeights() && !DefiningOp().has_value(); -} - -bool Tensor::IsConstant() const { - return HasWeights() && !DefiningOp().has_value(); -} - -Tensor::TensorUses Tensor::Uses() const { - LiteRtParamIndex num_uses; - litert::internal::AssertOk(LiteRtGetNumTensorUses, Get(), &num_uses); - - TensorUses uses; - for (auto i = 0; i < num_uses; ++i) { - LiteRtOp user; - LiteRtParamIndex user_arg_index; - litert::internal::AssertOk(LiteRtGetTensorUse, Get(), i, &user, - &user_arg_index); - uses.emplace_back(TensorUse{Op(user), user_arg_index}); - } - return uses; -} - -OpInputs Op::Inputs() const { - LiteRtParamIndex num_inputs; - internal::AssertOk(LiteRtGetNumOpInputs, Get(), &num_inputs); - - OpInputs inputs; - for (auto i = 0; i < num_inputs; ++i) { - LiteRtTensor input; - internal::AssertOk(LiteRtGetOpInput, Get(), i, &input); - inputs.emplace_back(Tensor(input)); - } - return inputs; -} - -OpOutputs Op::Outputs() const { - LiteRtParamIndex num_outputs; - internal::AssertOk(LiteRtGetNumOpOutputs, Get(), &num_outputs); - - OpOutputs outputs; - for (auto i = 0; i < num_outputs; ++i) { - LiteRtTensor output; - internal::AssertOk(LiteRtGetOpOutput, Get(), i, &output); - outputs.emplace_back(Tensor(output)); - } - return outputs; -} - -SubgraphInputs Subgraph::Inputs() const { - LiteRtParamIndex num_inputs; - internal::AssertOk(LiteRtGetNumSubgraphInputs, Get(), &num_inputs); - - SubgraphInputs inputs; - for (auto i = 0; i < num_inputs; ++i) { - LiteRtTensor input; - internal::AssertOk(LiteRtGetSubgraphInput, Get(), i, &input); - inputs.emplace_back(Tensor(input)); - } - return inputs; -} - -Expected Subgraph::Input(absl::string_view name) const { - LiteRtParamIndex num_inputs; - internal::AssertOk(LiteRtGetNumSubgraphInputs, Get(), &num_inputs); - - for (auto i = 0; i < num_inputs; ++i) { - LiteRtTensor input; - internal::AssertOk(LiteRtGetSubgraphInput, Get(), i, &input); - const char* input_name; - internal::AssertOk(LiteRtGetTensorName, input, &input_name); - if (name == input_name) { - return Tensor(input); - } - } - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input"); -} - -Expected Subgraph::Output(absl::string_view name) const { - LiteRtParamIndex num_outputs; - internal::AssertOk(LiteRtGetNumSubgraphOutputs, Get(), &num_outputs); - - for (auto i = 0; i < num_outputs; ++i) { - LiteRtTensor output; - internal::AssertOk(LiteRtGetSubgraphOutput, Get(), i, &output); - const char* output_name; - internal::AssertOk(LiteRtGetTensorName, output, &output_name); - if (name == output_name) { - return Tensor(output); - } - } - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output"); -} - -SubgraphOutputs Subgraph::Outputs() const { - LiteRtParamIndex num_outputs; - internal::AssertOk(LiteRtGetNumSubgraphOutputs, Get(), &num_outputs); - - SubgraphOutputs outputs; - for (auto i = 0; i < num_outputs; ++i) { - LiteRtTensor output; - internal::AssertOk(LiteRtGetSubgraphOutput, Get(), i, &output); - outputs.emplace_back(Tensor(output)); - } - return outputs; -} - -std::vector Subgraph::Ops() const { - LiteRtParamIndex num_ops; - internal::AssertOk(LiteRtGetNumSubgraphOps, Get(), &num_ops); - - std::vector ops; - for (auto i = 0; i < num_ops; ++i) { - LiteRtOp op; - litert::internal::AssertOk(LiteRtGetSubgraphOp, Get(), i, &op); - ops.emplace_back(Op(op)); - } - return ops; -} - -} // namespace litert diff --git 
a/tensorflow/lite/experimental/litert/cc/litert_model.h b/tensorflow/lite/experimental/litert/cc/litert_model.h deleted file mode 100644 index 579e97db9888e1..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model.h +++ /dev/null @@ -1,473 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" - -namespace litert { - -// Type for tensors with known dimensions. C++ equivalent to -// LiteRtRankedTensorType. -class RankedTensorType { - public: - RankedTensorType(enum ElementType element_type, class Layout&& layout) - : element_type_(element_type), layout_(std::move(layout)) {} - explicit RankedTensorType(const LiteRtRankedTensorType& type) - : element_type_(static_cast(type.element_type)), - layout_(type.layout) {} - - explicit operator LiteRtRankedTensorType() const { - return LiteRtRankedTensorType{ - /*.element_type=*/static_cast(element_type_), - /*layout=*/static_cast(layout_), - }; - } - - bool operator==(const RankedTensorType& other) const { - return ElementType() == other.ElementType() && Layout() == other.Layout(); - } - - enum ElementType ElementType() const { return element_type_; } - - const class Layout& Layout() const { return layout_; } - - private: - enum ElementType element_type_; - class Layout layout_; -}; - -// Tensor weights. C++ equivalent of LiteRtWeights. -class Weights : public internal::NonOwnedHandle { - public: - explicit Weights(LiteRtWeights weights) - : internal::NonOwnedHandle(weights) {} - - absl::Span Bytes() const { - size_t size; - const void* addr; - internal::AssertOk(LiteRtGetWeightsBytes, Get(), &addr, &size); - return absl::MakeSpan(static_cast(addr), size); - } -}; - -// Tensor. C++ equivalent of LiteRtTensor. 
-class Tensor : public internal::NonOwnedHandle { - public: - explicit Tensor(LiteRtTensor tensor) - : internal::NonOwnedHandle(tensor) {} - - enum ElementType ElementType() const { - if (TypeId() == kLiteRtUnrankedTensorType) { - return static_cast(UnrankedTensorType()->element_type); - } else { - return RankedTensorType()->ElementType(); - } - } - - LiteRtTensorTypeId TypeId() const { - LiteRtTensorTypeId type_id; - internal::AssertOk(LiteRtGetTensorTypeId, Get(), &type_id); - return type_id; - } - - Expected UnrankedTensorType() const { - if (TypeId() != kLiteRtUnrankedTensorType) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Not an unranked invalid tensor"); - } - LiteRtUnrankedTensorType unranked_tensor_type; - internal::AssertOk(LiteRtGetUnrankedTensorType, Get(), - &unranked_tensor_type); - return unranked_tensor_type; - } - - Expected RankedTensorType() const { - if (TypeId() != kLiteRtRankedTensorType) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Not a ranked tensor type"); - } - LiteRtRankedTensorType ranked_tensor_type; - internal::AssertOk(LiteRtGetRankedTensorType, Get(), &ranked_tensor_type); - return litert::RankedTensorType(ranked_tensor_type); - } - - LiteRtQuantizationTypeId QTypeId() const { - LiteRtQuantizationTypeId q_type_id; - internal::AssertOk(LiteRtGetQuantizationTypeId, Get(), &q_type_id); - return q_type_id; - } - - bool HasQuantization() const { return QTypeId() != kLiteRtQuantizationNone; } - - LiteRtQuantizationPerTensor PerTensorQuantization() const { - internal::AssertEq([&]() { return QTypeId(); }, - kLiteRtQuantizationPerTensor); - LiteRtQuantizationPerTensor per_tensor_quantization; - internal::AssertOk(LiteRtGetPerTensorQuantization, Get(), - &per_tensor_quantization); - return per_tensor_quantization; - } - - LiteRtQuantizationPerChannel PerChannelQuantization() const { - internal::AssertEq([&]() { return QTypeId(); }, - kLiteRtQuantizationPerChannel); - LiteRtQuantizationPerChannel per_channel_quantization; - internal::AssertOk(LiteRtGetPerChannelQuantization, Get(), - &per_channel_quantization); - return per_channel_quantization; - } - - bool HasWeights() const { - auto weights = Weights(); - return !weights.Bytes().empty(); - } - - class Weights Weights() const { - LiteRtWeights weights; - internal::AssertOk(LiteRtGetTensorWeights, Get(), &weights); - return litert::Weights(weights); - } - - absl::string_view Name() const { - const char* name; - internal::AssertOk(LiteRtGetTensorName, Get(), &name); - return absl::string_view(name); - } - - struct TensorUse; - using TensorUses = - absl::InlinedVector; - - TensorUses Uses() const; - - template - Expected> WeightsData() const { - auto ranked_tensor_type = RankedTensorType(); - if (!ranked_tensor_type) { - return ranked_tensor_type.Error(); - } - - const enum ElementType ty = ranked_tensor_type->ElementType(); - if (ty != GetElementType()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - - if (!HasWeights()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - const absl::Span weights = Weights().Bytes(); - - auto num_elements = ranked_tensor_type->Layout().NumElements(); - if (!num_elements.has_value()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - auto byte_width = GetByteWidth(ty); - if (!byte_width.has_value()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - - if (byte_width.value() * num_elements.value() != weights.size()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); 
- } - - return absl::MakeConstSpan(reinterpret_cast(weights.data()), - num_elements.value()); - } - - std::optional DefiningOp() const { - bool has_defining_op; - LiteRtTensorDefiningOp defining_op; - internal::AssertOk(LiteRtGetTensorDefiningOp, Get(), &has_defining_op, - &defining_op); - if (has_defining_op) { - return defining_op; - } else { - return std::nullopt; - } - } - - bool IsSubgraphOutput() const; - bool IsSubgraphInput() const; - bool IsConstant() const; -}; - -using OpInputs = absl::InlinedVector; -using OpOutputs = absl::InlinedVector; - -// Operator. C++ equivalent of LiteRtOp. -class Op : public internal::NonOwnedHandle { - public: - explicit Op(LiteRtOp op) : internal::NonOwnedHandle(op) {} - - LiteRtOpCode Code() const { - LiteRtOpCode opcode; - internal::AssertOk(LiteRtGetOpCode, Get(), &opcode); - return opcode; - } - - OpInputs Inputs() const; - OpOutputs Outputs() const; -}; - -struct Tensor::TensorUse { - Op user; - LiteRtParamIndex user_arg_ind; -}; - -using SubgraphInputs = - absl::InlinedVector; -using SubgraphOutputs = - absl::InlinedVector; - -// Model subgraph. C++ equivalent of LiteRtSubgraph. -class Subgraph : public internal::NonOwnedHandle { - public: - explicit Subgraph(LiteRtSubgraph subgraph) - : internal::NonOwnedHandle(subgraph) {} - - SubgraphInputs Inputs() const; - SubgraphOutputs Outputs() const; - std::vector Ops() const; - - // Returns the input tensor with the given input signature name. - Expected Input(absl::string_view name) const; - - // Returns the output tensor with the given output signature name. - Expected Output(absl::string_view name) const; -}; - -// Model signature. C++ equivalent of LiteRtSignature. -class Signature : public internal::NonOwnedHandle { - public: - explicit Signature(LiteRtSignature signature) - : internal::NonOwnedHandle(signature) {} - - absl::string_view Key() const { - const char* key; - internal::AssertOk(LiteRtGetSignatureKey, Get(), &key); - return key; - } - - LiteRtSubgraph Subgraph() const { - LiteRtSubgraph subgraph; - internal::AssertOk(LiteRtGetSignatureSubgraph, Get(), &subgraph); - return subgraph; - } - - std::vector InputNames() const { - LiteRtParamIndex num_inputs; - internal::AssertOk(LiteRtGetNumSignatureInputs, Get(), &num_inputs); - std::vector input_names; - input_names.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - const char* input_name; - internal::AssertOk(LiteRtGetSignatureInputName, Get(), i, &input_name); - input_names.push_back(input_name); - } - return input_names; - } - - std::vector OutputNames() const { - LiteRtParamIndex num_outputs; - internal::AssertOk(LiteRtGetNumSignatureOutputs, Get(), &num_outputs); - std::vector output_names; - output_names.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - const char* output_name; - internal::AssertOk(LiteRtGetSignatureOutputName, Get(), i, &output_name); - output_names.push_back(output_name); - } - return output_names; - } -}; - -// Model. C++ equivalent of LiteRtModel. 
-class Model : public internal::Handle { - public: - Model() = default; - - static Model CreateFromOwnedHandle(LiteRtModel model) { - return Model(model, /*owned=*/true); - } - - static Model CreateFromNonOwnedHandle(LiteRtModel model) { - return Model(model, /*owned=*/false); - } - - static Expected CreateFromFile(const std::string& filename) { - LiteRtModel model; - if (auto status = LiteRtCreateModelFromFile(filename.c_str(), &model); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to load model from file"); - } - return CreateFromOwnedHandle(model); - } - - static Expected CreateFromBuffer(BufferRef buffer) { - LiteRtModel model; - if (auto status = - LiteRtCreateModelFromBuffer(buffer.Data(), buffer.Size(), &model); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to load model from buffer"); - } - return CreateFromOwnedHandle(model); - } - - Expected> Metadata( - const std::string& metadata_key) const { - const void* buffer; - size_t buffer_size; - if (LiteRtGetModelMetadata(Get(), metadata_key.data(), &buffer, - &buffer_size) != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorNotFound, "Metadata key not found"); - } - return absl::MakeSpan(static_cast(buffer), buffer_size); - } - - Expected MainSubgraph() const { - LiteRtParamIndex main_subgraph_index; - internal::AssertOk(LiteRtGetMainModelSubgraphIndex, Get(), - &main_subgraph_index); - return this->Subgraph(main_subgraph_index); - } - - size_t NumSubgraphs() const { - LiteRtParamIndex num_subgraphs; - internal::AssertOk(LiteRtGetNumModelSubgraphs, Get(), &num_subgraphs); - return num_subgraphs; - } - - Expected Subgraph(size_t subgraph_index) const { - LiteRtSubgraph subgraph; - if (LiteRtGetModelSubgraph(Get(), subgraph_index, &subgraph) != - kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorNotFound, "Subgraph not found"); - } - return litert::Subgraph(subgraph); - } - - Expected Subgraph(absl::string_view signature_key) const { - auto signature = FindSignature(signature_key); - if (!signature) { - return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); - } - return litert::Subgraph(signature->Subgraph()); - } - - size_t GetNumSignatures() const { - LiteRtParamIndex num_signatures; - internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); - return num_signatures; - } - - // Returns the list of signatures defined in the model. - Expected> GetSignatures() const { - LiteRtParamIndex num_signatures; - internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); - std::vector signatures; - signatures.reserve(num_signatures); - for (int i = 0; i < num_signatures; ++i) { - LiteRtSignature lite_rt_signature; - internal::AssertOk(LiteRtGetModelSignature, Get(), i, &lite_rt_signature); - Signature signature(lite_rt_signature); - signatures.push_back(std::move(signature)); - } - return std::move(signatures); - } - - // Returns the signature at the given index. - Expected GetSignature(size_t signature_index) const { - LiteRtSignature lite_rt_signature; - internal::AssertOk(LiteRtGetModelSignature, Get(), signature_index, - &lite_rt_signature); - return Signature(lite_rt_signature); - } - - // Returns the signature index for the given signature key. 
- Expected GetSignatureIndex(absl::string_view signature_key) const { - LiteRtParamIndex num_signatures; - internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); - for (int i = 0; i < num_signatures; ++i) { - LiteRtSignature lite_rt_signature; - internal::AssertOk(LiteRtGetModelSignature, Get(), i, &lite_rt_signature); - const char* key_cstr; - internal::AssertOk(LiteRtGetSignatureKey, lite_rt_signature, &key_cstr); - if (absl::string_view(key_cstr) == signature_key) { - return i; - } - } - return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); - } - - // Returns the Signature object for the given signature key. - Expected FindSignature( - absl::string_view signature_key) const { - LiteRtParamIndex num_signatures; - internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); - for (int i = 0; i < num_signatures; ++i) { - LiteRtSignature lite_rt_signature; - internal::AssertOk(LiteRtGetModelSignature, Get(), i, &lite_rt_signature); - const char* key_cstr; - internal::AssertOk(LiteRtGetSignatureKey, lite_rt_signature, &key_cstr); - if (absl::string_view(key_cstr) == signature_key) { - return Signature(lite_rt_signature); - } - } - return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); - } - - static absl::string_view DefaultSignatureKey() { - const char* key; - internal::AssertOk(LiteRtGetDefaultSignatureKey, &key); - return key; - } - - private: - // Parameter `owned` indicates if the created TensorBuffer object should take - // ownership of the provided `tensor_buffer` handle. - Model(LiteRtModel model, bool owned) - : internal::Handle(model, owned) {} -}; - -struct SerializationOptions { - static LiteRtModelSerializationOptions Defaults() { - return LiteRtModelSerializationOptions{}; - } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc b/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc deleted file mode 100644 index 18efea56f7ffa4..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
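[Editor's note: a brief sketch of how the `Model` signature lookups above were typically chained; "run" is a placeholder signature key.]

    #include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
    #include "tensorflow/lite/experimental/litert/cc/litert_model.h"

    // Resolve a signature's subgraph. The returned Subgraph is a non-owning
    // view, so `model` must outlive it.
    litert::Expected<litert::Subgraph> FindRunSubgraph(
        const litert::Model& model) {
      auto signature = model.FindSignature("run");
      if (!signature) {
        return signature.Error();  // Propagates kLiteRtStatusErrorNotFound.
      }
      return litert::Subgraph(signature->Subgraph());
    }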
- -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" - -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -namespace litert { -namespace { - -template -bool Any(absl::Span vals, std::function unary_pred) { - for (const auto& val : vals) { - if (unary_pred(val)) { - return true; - } - } - return false; -} - -bool UseSoftEqual(const Tensor::TensorUse& actual_use, - const UseInfo& expected_use) { - if (expected_use.user_param_ind.has_value() && - actual_use.user_arg_ind != expected_use.user_param_ind.value()) { - return false; - } - if (expected_use.op_code.has_value() && - actual_use.user.Code() != expected_use.op_code.value()) { - return false; - } - return true; -} - -} // namespace - -// Does given tensor have given type and shape info. Optional values considered -// to be a vacous match. -bool MatchRankedTensorType(const RankedTensorType& tensor_type, - const TensorTypeInfo& expected) { - if (expected.element_type.has_value() && - (tensor_type.ElementType() != expected.element_type.value())) { - return false; - } - - if (expected.dims.has_value()) { - auto actual_dims = tensor_type.Layout().Dimensions(); - auto expected_dims = absl::MakeConstSpan(expected.dims.value()); - return AllZip(actual_dims, expected_dims, - [](auto l, auto r) -> bool { return l == r; }); - } - return true; -} - -// Does given op have signature matching given types. Optional values considered -// to be a vacous match. -bool MatchOpType( - const Op& op, - const std::vector>& expected_inputs, - const std::vector>& expected_outputs) { - auto match = [](const Tensor& actual, - const std::optional& expected) -> bool { - if (!expected.has_value()) { - return true; - } - auto actual_ranked_tensor_type = actual.RankedTensorType(); - // Don't return a match if the tensor is unranked. - if (!actual_ranked_tensor_type) { - return false; - } - return MatchRankedTensorType(*actual_ranked_tensor_type, expected.value()); - }; - - const bool inputs_match = AllZip(absl::MakeConstSpan(op.Inputs()), - absl::MakeConstSpan(expected_inputs), match); - const bool outputs_match = - AllZip(absl::MakeConstSpan(op.Outputs()), - absl::MakeConstSpan(expected_outputs), match); - return inputs_match && outputs_match; -} - -bool MatchUse(const Tensor& tensor, const UseInfo& expected_use) { - auto soft_equal = [&expected_use = std::as_const(expected_use)]( - const Tensor::TensorUse& actual_use) { - return UseSoftEqual(actual_use, expected_use); - }; - return Any(tensor.Uses(), soft_equal); -} - -bool MatchUses(const Tensor& tensor, const std::vector& expected_uses, - bool strict) { - const auto uses = tensor.Uses(); - if (strict && uses.size() != expected_uses.size()) { - return false; - } - auto not_use = [&tensor = - std::as_const(tensor)](const UseInfo& expected_use) { - return !MatchUse(tensor, expected_use); - }; - return !Any(expected_uses, not_use); -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.h b/tensorflow/lite/experimental/litert/cc/litert_model_predicates.h deleted file mode 100644 index 238e9a455bbb9e..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_PREDICATES_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_PREDICATES_H_
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-
-// Predicates used for matching patterns in the graph. NOTE: All optionals in
-// matcher arguments are considered to be a vacuous match.
-
-namespace litert {
-
-struct TensorTypeInfo {
-  std::optional<ElementType> element_type = std::nullopt;
-  std::optional<absl::InlinedVector<int32_t, 4>> dims = std::nullopt;
-
-  explicit TensorTypeInfo(ElementType element_type)
-      : element_type(element_type) {}
-  explicit TensorTypeInfo(absl::InlinedVector<int32_t, 4> dims) : dims(dims) {}
-  TensorTypeInfo(ElementType element_type,
-                 absl::InlinedVector<int32_t, 4> dims)
-      : element_type(element_type), dims(dims) {}
-};
-
-struct UseInfo {
-  std::optional<LiteRtOpCode> op_code = std::nullopt;
-  std::optional<LiteRtParamIndex> user_param_ind = std::nullopt;
-};
-
-// Does this tensor have the given type and shape info.
-bool MatchRankedTensorType(const RankedTensorType& tensor_type,
-                           const TensorTypeInfo& expected);
-
-// Does this op have a signature matching the given types.
-bool MatchOpType(
-    const Op& op,
-    const std::vector<std::optional<TensorTypeInfo>>& expected_inputs,
-    const std::vector<std::optional<TensorTypeInfo>>& expected_outputs);
-
-// Does this tensor contain weights whose values match expected_data.
-template <typename T>
-inline bool MatchWeights(const Tensor& tensor,
-                         absl::Span<const T> expected_data) {
-  auto weights = tensor.WeightsData<T>();
-  return weights.HasValue() && *weights == expected_data;
-}
-
-// Does this tensor have a user with the given information.
-bool MatchUse(const Tensor& tensor, const UseInfo& expected_use);
-
-// Does this tensor have matching users. If "strict" is true, the size of
-// expected_uses must equal the number of actual uses; otherwise this just
-// checks that each expected_use matches some actual use.
-bool MatchUses(const Tensor& tensor, const std::vector<UseInfo>& expected_uses,
-               bool strict = true);
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_PREDICATES_H_
diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc b/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc
deleted file mode 100644
index f16bc764e560c4..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc
+++ /dev/null
@@ -1,215 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" - -#include - -#include -#include "absl/container/inlined_vector.h" -#include "absl/log/absl_check.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" - -namespace litert { - -namespace { - -using ::litert::testing::LoadTestFileModel; - -TEST(MatchRankedTensorTypeTest, HasAll) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& input = inputs.front(); - auto input_tensor_type = input.RankedTensorType(); - EXPECT_TRUE(input_tensor_type); - EXPECT_TRUE(MatchRankedTensorType( - *input_tensor_type, TensorTypeInfo(ElementType::Float32, {2, 2}))); -} - -TEST(MatchRankedTensorTypeTest, NoMatch) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& input = inputs.front(); - auto input_tensor_type = input.RankedTensorType(); - EXPECT_TRUE(input_tensor_type); - EXPECT_FALSE(MatchRankedTensorType( - *input_tensor_type, TensorTypeInfo(ElementType::Float32, {3, 2}))); -} - -TEST(MatchRankedTensorTypeTest, AnyDims) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& input = inputs.front(); - auto input_tensor_type = input.RankedTensorType(); - EXPECT_TRUE(input_tensor_type); - EXPECT_TRUE(MatchRankedTensorType(*input_tensor_type, - TensorTypeInfo(ElementType::Float32))); -} - -TEST(MatchRankedTensorTypeTest, AnyElementType) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& input = inputs.front(); - auto input_tensor_type = input.RankedTensorType(); - EXPECT_TRUE(input_tensor_type); - EXPECT_TRUE( - MatchRankedTensorType(*input_tensor_type, TensorTypeInfo({2, 2}))); -} - -TEST(MatchOpTypeTest, HasAll) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); - EXPECT_TRUE(MatchOpType(ops.front(), {expected_type, expected_type}, - {expected_type})); -} - -TEST(MatchOpTypeTest, NoMatch) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); - TensorTypeInfo not_expected_type(ElementType::Int32, {2, 2}); - EXPECT_FALSE(MatchOpType(ops.front(), {not_expected_type, expected_type}, - {expected_type})); -} - -TEST(MatchOpTypeTest, AnyInput) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = 
subgraph->Ops(); - TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); - EXPECT_TRUE( - MatchOpType(ops.front(), {std::nullopt, expected_type}, {expected_type})); -} - -TEST(MatchOpTypeTest, AnyOutput) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); - EXPECT_TRUE( - MatchOpType(ops.front(), {std::nullopt, expected_type}, {std::nullopt})); -} - -TEST(MatchWeightsTest, Matches) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& cst = inputs.back(); - EXPECT_TRUE(MatchWeights(cst, absl::Span({1.0, 2.0, 3.0, 4.0}))); -} - -TEST(MatchWeightsTest, NoMatchBadType) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& cst = inputs.back(); - EXPECT_FALSE( - MatchWeights(cst, absl::Span({1.0, 2.0, 3.0, 4.0}))); -} -TEST(MatchWeightsTest, NoMatchBadVals) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& cst = inputs.back(); - EXPECT_FALSE( - MatchWeights(cst, absl::Span({3.0, 2.0, 3.0, 5.0}))); -} - -TEST(MatchUseTest, Match) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - EXPECT_TRUE(MatchUse(inputs.back(), UseInfo{kLiteRtOpCodeTflAdd, 1})); -} - -TEST(MatchUseTest, MatchAnyCode) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - EXPECT_TRUE(MatchUse(inputs.back(), UseInfo{std::nullopt, 1})); -} - -TEST(MatchUseTest, NoMatch) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - EXPECT_FALSE(MatchUse(inputs.back(), UseInfo{std::nullopt, 2})); -} - -TEST(MatchUsesTest, StrictMatch) { - auto litert_model = LoadTestFileModel("add_simple.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto subgraph_inputs = subgraph->Inputs(); - const auto& tensor = subgraph_inputs.front(); - EXPECT_TRUE( - MatchUses(tensor, {{kLiteRtOpCodeTflAdd, 0}, {kLiteRtOpCodeTflAdd, 1}})); -} - -TEST(MatchUsesTest, StrictNoMatch) { - auto litert_model = LoadTestFileModel("add_simple.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto subgraph_inputs = subgraph->Inputs(); - const auto& tensor = subgraph_inputs.front(); - EXPECT_FALSE(MatchUses(tensor, {{kLiteRtOpCodeTflAdd, 0}})); -} - -TEST(MatchUsesTest, NonStrict) { - auto litert_model = LoadTestFileModel("add_simple.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto subgraph_inputs = subgraph->Inputs(); - const auto& tensor = subgraph_inputs.front(); - EXPECT_TRUE(MatchUses(tensor, {{kLiteRtOpCodeTflAdd, 0}}, /*strict=*/false)); -} - 
-} // namespace - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_test.cc b/tensorflow/lite/experimental/litert/cc/litert_model_test.cc deleted file mode 100644 index a1a80f82e5f397..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model_test.cc +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" - -// Tests for CC Wrapper classes around public C api. - -namespace litert { - -namespace { - -static constexpr const int32_t kTensorDimensions[] = {1, 2, 3}; - -static constexpr const auto kRank = - sizeof(kTensorDimensions) / sizeof(kTensorDimensions[0]); - -static constexpr const uint32_t kTensorStrides[] = {6, 3, 1}; - -static constexpr const LiteRtLayout kLayout = BuildLayout(kTensorDimensions); - -static constexpr const LiteRtLayout kLayoutWithStrides = - BuildLayout(kTensorDimensions, kTensorStrides); - -static constexpr const LiteRtRankedTensorType kTensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - /*.layout=*/kLayout, -}; - -//===----------------------------------------------------------------------===// -// CC Model // -//===----------------------------------------------------------------------===// - -TEST(CcModelTest, SimpleModel) { - auto model = testing::LoadTestFileModel("one_mul.tflite"); - - LiteRtParamIndex num_subgraphs; - ASSERT_EQ(LiteRtGetNumModelSubgraphs(model.Get(), &num_subgraphs), - kLiteRtStatusOk); - EXPECT_EQ(model.NumSubgraphs(), num_subgraphs); - EXPECT_EQ(model.NumSubgraphs(), 1); - - LiteRtParamIndex main_subgraph_index; - ASSERT_EQ(LiteRtGetMainModelSubgraphIndex(model.Get(), &main_subgraph_index), - kLiteRtStatusOk); - EXPECT_EQ(main_subgraph_index, 0); - - LiteRtSubgraph litert_subgraph_0; - ASSERT_EQ(LiteRtGetModelSubgraph(model.Get(), /*subgraph_index=*/0, - &litert_subgraph_0), - kLiteRtStatusOk); - - auto subgraph_0 = model.Subgraph(0); - ASSERT_TRUE(subgraph_0); - EXPECT_EQ(subgraph_0->Get(), litert_subgraph_0); - - auto main_subgraph = model.MainSubgraph(); - EXPECT_EQ(main_subgraph->Get(), subgraph_0->Get()); -} - -//===----------------------------------------------------------------------===// -// CC Signature // -//===----------------------------------------------------------------------===// - -TEST(CcSignatureTest, Basic) { - auto model = 
testing::LoadTestFileModel("one_mul.tflite"); - - auto signatures = model.GetSignatures(); - ASSERT_TRUE(signatures); - ASSERT_EQ(signatures->size(), 1); - auto& signature = signatures->at(0); - EXPECT_THAT(signature.Key(), Model::DefaultSignatureKey()); - auto input_names = signature.InputNames(); - EXPECT_THAT(input_names[0], "arg0"); - EXPECT_THAT(input_names[1], "arg1"); - auto output_names = signature.OutputNames(); - EXPECT_THAT(output_names[0], "tfl.mul"); -} - -TEST(CcSignatureTest, Lookup) { - auto model = testing::LoadTestFileModel("one_mul.tflite"); - - { - auto signature = model.FindSignature("nonexistent"); - ASSERT_FALSE(signature); - } - auto signature = model.FindSignature(Model::DefaultSignatureKey()); - ASSERT_TRUE(signature); - EXPECT_THAT(signature->Key(), Model::DefaultSignatureKey()); - auto input_names = signature->InputNames(); - EXPECT_THAT(input_names[0], "arg0"); - EXPECT_THAT(input_names[1], "arg1"); - auto output_names = signature->OutputNames(); - EXPECT_THAT(output_names[0], "tfl.mul"); -} - -//===----------------------------------------------------------------------===// -// CC Layout // -//===----------------------------------------------------------------------===// - -TEST(CcLayoutTest, NoStrides) { - Layout layout(kLayout); - - ASSERT_EQ(layout.Rank(), kLayout.rank); - for (auto i = 0; i < layout.Rank(); ++i) { - ASSERT_EQ(layout.Dimensions()[i], kLayout.dimensions[i]); - } - ASSERT_FALSE(layout.HasStrides()); -} - -TEST(CcLayoutTest, WithStrides) { - Layout layout(kLayoutWithStrides); - - ASSERT_EQ(layout.Rank(), kLayoutWithStrides.rank); - for (auto i = 0; i < layout.Rank(); ++i) { - ASSERT_EQ(layout.Dimensions()[i], kLayoutWithStrides.dimensions[i]); - } - ASSERT_TRUE(layout.HasStrides()); - for (auto i = 0; i < layout.Rank(); ++i) { - ASSERT_EQ(layout.Strides()[i], kLayoutWithStrides.strides[i]); - } -} - -TEST(CcLayoutTest, Equal) { - auto&& dims = {2, 2}; - Layout layout1(BuildLayout(dims)); - Layout layout2(BuildLayout({2, 2})); - ASSERT_TRUE(layout1 == layout2); -} - -TEST(CcLayoutTest, NotEqual) { - Layout layout1(BuildLayout({2, 2}, nullptr)); - Layout layout2(BuildLayout({2, 2}, kTensorStrides)); - ASSERT_FALSE(layout1 == layout2); -} - -TEST(CcLayoutTest, NumElements) { - Layout layout(BuildLayout({2, 2, 3})); - auto num_elements = layout.NumElements(); - ASSERT_TRUE(num_elements.has_value()); - EXPECT_EQ(num_elements.value(), 12); -} - -//===----------------------------------------------------------------------===// -// CC Op // -//===----------------------------------------------------------------------===// - -TEST(CcOpTest, SimpleSupportedOp) { - auto litert_model = testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - const auto ops = subgraph->Ops(); - const auto& op = ops.front(); - - EXPECT_EQ(op.Code(), kLiteRtOpCodeTflMul); - EXPECT_EQ(op.Inputs().size(), 2); - EXPECT_EQ(op.Outputs().size(), 1); -} - -//===----------------------------------------------------------------------===// -// CC RankedTensorType // -//===----------------------------------------------------------------------===// - -TEST(CcRankedTensorTypeTest, Accessors) { - Layout layout(kLayout); - RankedTensorType tensor_type(kTensorType); - ASSERT_EQ(tensor_type.ElementType(), - static_cast(kTensorType.element_type)); - ASSERT_TRUE(tensor_type.Layout() == layout); -} - -//===----------------------------------------------------------------------===// -// CC Tensor // 
-//===----------------------------------------------------------------------===// - -TEST(CcTensorTest, SimpleModel) { - auto litert_model = testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - - auto inputs = subgraph->Inputs(); - ASSERT_EQ(inputs.size(), 2); - - { - const Tensor& input_tensor = inputs.front(); - ASSERT_EQ(input_tensor.TypeId(), kLiteRtRankedTensorType); - - auto input_ranked_tensor_type = input_tensor.RankedTensorType(); - EXPECT_TRUE(input_ranked_tensor_type); - ASSERT_EQ(input_ranked_tensor_type->ElementType(), ElementType::Float32); - - EXPECT_FALSE(input_tensor.HasWeights()); - - auto input_weights = input_tensor.Weights(); - ASSERT_EQ(input_weights.Bytes().size(), 0); - - ASSERT_EQ(input_tensor.DefiningOp(), std::nullopt); - - const auto uses = input_tensor.Uses(); - ASSERT_EQ(uses.size(), 1); - } - - auto outputs = subgraph->Outputs(); - ASSERT_EQ(outputs.size(), 1); - - { - const Tensor& output_tensor = outputs.front(); - ASSERT_EQ(output_tensor.TypeId(), kLiteRtRankedTensorType); - - auto output_defining_op = output_tensor.DefiningOp(); - EXPECT_TRUE(output_defining_op.has_value()); - - ASSERT_TRUE(output_tensor.Uses().empty()); - } -} - -TEST(CcTensorTest, WeightsData) { - auto litert_model = testing::LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - - auto data = subgraph->Ops().front().Inputs().back().WeightsData(); - ASSERT_TRUE(data.HasValue()); - EXPECT_THAT(data.Value(), ::testing::ElementsAreArray({1.0, 2.0, 3.0, 4.0})); -} - -TEST(CcTensorTest, Name) { - static constexpr absl::string_view kName = "foo"; - LiteRtTensorT tensor; - tensor.SetName(std::string(kName)); - - Tensor cc_tensor(&tensor); - EXPECT_EQ(cc_tensor.Name(), kName); -} - -TEST(CcTensorTest, QuantizationNone) { - LiteRtTensorT litert_tensor; - litert_tensor.Qparams().first = kLiteRtQuantizationNone; - - Tensor tensor(&litert_tensor); - EXPECT_EQ(tensor.QTypeId(), kLiteRtQuantizationNone); - EXPECT_FALSE(tensor.HasQuantization()); -} - -TEST(CcTensorTest, QuantizationPerTensor) { - static constexpr auto kScale = 1.0; - static constexpr auto kZeroPoint = 1; - - LiteRtTensorT litert_tensor; - litert_tensor.SetQarams(MakePerTensorQuantization(kScale, kZeroPoint)); - - Tensor tensor(&litert_tensor); - ASSERT_EQ(tensor.QTypeId(), kLiteRtQuantizationPerTensor); - ASSERT_TRUE(tensor.HasQuantization()); - - const auto per_tensor_quantization = tensor.PerTensorQuantization(); - EXPECT_EQ(per_tensor_quantization.scale, kScale); - EXPECT_EQ(per_tensor_quantization.zero_point, kZeroPoint); -} - -TEST(CcTensorTest, QuantizationPerChannel) { - static constexpr auto kNumChannels = 2; - static constexpr auto kQuantizedDimension = 0; - static constexpr float kScales[kNumChannels] = {1.0, 2.0}; - static constexpr int64_t kZeroPoints[kNumChannels] = {0, 0}; - - LiteRtTensorT litert_tensor; - auto per_channel = MakePerChannelQuantization( - kScales, kZeroPoints, kQuantizedDimension, litert_tensor); - litert_tensor.SetQarams(per_channel); - - Tensor tensor(&litert_tensor); - ASSERT_EQ(tensor.QTypeId(), kLiteRtQuantizationPerChannel); - ASSERT_TRUE(tensor.HasQuantization()); - - const auto per_channel_quantization = tensor.PerChannelQuantization(); - EXPECT_THAT( - absl::MakeConstSpan(per_channel_quantization.scales, kNumChannels), - ::testing::ElementsAreArray(kScales)); - EXPECT_THAT( - absl::MakeConstSpan(per_channel_quantization.zero_points, kNumChannels), - ::testing::ElementsAreArray(kZeroPoints)); - 
EXPECT_EQ(per_channel_quantization.num_channels, kNumChannels); - EXPECT_EQ(per_channel_quantization.quantized_dimension, kQuantizedDimension); -} - -TEST(CcTensorTest, ZeroSizeTensorTest) { - auto litert_model = testing::LoadTestFileModel("scala_reshape.tflite"); - auto subgraph = litert_model.MainSubgraph(); - const auto ops = subgraph->Ops(); - const auto& op = ops.front(); - EXPECT_FALSE(op.Inputs().at(1).IsSubgraphInput()); -} - -//===----------------------------------------------------------------------===// -// CC Subgraph // -//===----------------------------------------------------------------------===// - -TEST(CcSubgraphTest, SimpleModel) { - auto model = testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = model.MainSubgraph(); - - ASSERT_EQ(subgraph->Inputs().size(), 2); - ASSERT_EQ(subgraph->Outputs().size(), 1); - ASSERT_EQ(subgraph->Ops().size(), 1); - - auto input0_tensor = subgraph->Input("arg0"); - ASSERT_TRUE(input0_tensor.HasValue()); - auto input1_tensor = subgraph->Input("arg1"); - ASSERT_TRUE(input1_tensor.HasValue()); - - auto output_tensor = subgraph->Output("tfl.mul"); - ASSERT_TRUE(output_tensor.HasValue()); - ASSERT_EQ(output_tensor->TypeId(), kLiteRtRankedTensorType); - auto output_ranked_tensor_type = output_tensor->RankedTensorType(); - EXPECT_TRUE(output_ranked_tensor_type); - ASSERT_EQ(output_ranked_tensor_type->ElementType(), ElementType::Float32); -} - -//===----------------------------------------------------------------------===// -// CC ElementType // -//===----------------------------------------------------------------------===// - -TEST(CcElementTypeTest, GetByteWidth) { - const size_t width = GetByteWidth(); - EXPECT_EQ(width, 1); -} - -TEST(CcElementTypeTest, GetElementType) { - ElementType ty = GetElementType(); - EXPECT_EQ(ty, ElementType::Float32); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_op_options.cc b/tensorflow/lite/experimental/litert/cc/litert_op_options.cc deleted file mode 100644 index c2cdfc6e2d0c7a..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_op_options.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
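[Editor's note] As a companion to the quantization tests deleted above, a hedged sketch of inspecting quantization through the removed `litert::Tensor` wrapper; `tensor` is assumed to come from `subgraph->Ops()` or `subgraph->Inputs()` exactly as in those tests.

```cpp
#include <cstdio>

#include "tensorflow/lite/experimental/litert/c/litert_model.h"
#include "tensorflow/lite/experimental/litert/cc/litert_model.h"

// Sketch only: uses the accessors exercised by the deleted CcTensorTest cases.
void DumpQuantization(const litert::Tensor& tensor) {
  if (!tensor.HasQuantization()) return;
  if (tensor.QTypeId() == kLiteRtQuantizationPerTensor) {
    const auto q = tensor.PerTensorQuantization();
    std::printf("per-tensor: scale=%f zero_point=%lld\n", q.scale,
                static_cast<long long>(q.zero_point));
  } else if (tensor.QTypeId() == kLiteRtQuantizationPerChannel) {
    const auto q = tensor.PerChannelQuantization();
    std::printf("per-channel: %lld channels along dim %d\n",
                static_cast<long long>(q.num_channels),
                static_cast<int>(q.quantized_dimension));
  }
}
```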
- -#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h" - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -LiteRtStatus CompositeOptions::InitFromOp(LiteRtOp op) { - LiteRtOpCode opcode; - LITERT_RETURN_IF_ERROR(LiteRtGetOpCode(op, &opcode)); - if (opcode != kLiteRtOpCodeShloComposite) { - return kLiteRtStatusErrorInvalidArgument; - } - - const char* op_name; - LITERT_RETURN_IF_ERROR(LiteRtGetSHLOCompositeOpName(op, &op_name)); - name = op_name; - - LITERT_RETURN_IF_ERROR( - LiteRtGetSHLOCompositeOpDecompositionSubgraphIndex(op, &subgraph)); - - this->op = op; - - return kLiteRtStatusOk; -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_op_options.h b/tensorflow/lite/experimental/litert/cc/litert_op_options.h deleted file mode 100644 index 70f6de4a38007e..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_op_options.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_OP_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_OP_OPTIONS_H_ - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -struct OpOptions { - virtual LiteRtStatus InitFromOp(LiteRtOp op) = 0; - virtual ~OpOptions() = default; -}; - -// Struct to hold LiteRt composite ops. -struct CompositeOptions : public OpOptions { - // Name for special composites representing manual partitions. - static constexpr absl::string_view kNpuCall = "odml.npu_call"; - - // The root op. - LiteRtOp op; - // Decomposition subgraph. - int subgraph; - // The name of the composite op (stored in model). - absl::string_view name; - - LiteRtStatus InitFromOp(LiteRtOp op) override; -}; - -// Returns the composite info for the given op if it is a composite op. -template -Expected GetOptionsAs(LiteRtOp op) { - if constexpr (std::is_same_v) { - CompositeOptions options; - LITERT_RETURN_IF_ERROR(options.InitFromOp(op)); - return options; - } else { - // TODO: Add more as needed. 
- return Unexpected(kLiteRtStatusErrorInvalidArgument); - } -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_OP_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_op_options_test.cc b/tensorflow/lite/experimental/litert/cc/litert_op_options_test.cc deleted file mode 100644 index 4be92d3e22f9d4..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_op_options_test.cc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h" - -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert { -namespace { - -TEST(OpOptionsTest, GetCompositeOptions) { - static constexpr auto kOptsType = - ::tflite::BuiltinOptions2_StableHLOCompositeOptions; - static constexpr absl::string_view kName = "test.composite"; - static constexpr int kSubgraph = 1; - - LiteRtOpT op; - op.SetOpCode(kLiteRtOpCodeShloComposite); - - tflite::StableHLOCompositeOptionsT options; - options.name = kName; - options.decomposition_subgraph_index = kSubgraph; - - internal::TflOptions2 tfl_options; - tfl_options.type = kOptsType; - tfl_options.Set(std::move(options)); - litert::internal::SetTflOptions2(op, std::move(tfl_options)); - - auto res = GetOptionsAs(&op); - ASSERT_TRUE(res); - EXPECT_EQ(res->name, kName); - EXPECT_EQ(res->subgraph, kSubgraph); -} - -TEST(OpOptionsTest, GetUnsupportedOptions) { - LiteRtOpT op; - op.SetOpCode(kLiteRtOpCodeShloAdd); - ASSERT_FALSE(GetOptionsAs(&op)); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_shared_library.cc b/tensorflow/lite/experimental/litert/cc/litert_shared_library.cc deleted file mode 100644 index 7e769ed2b3a1bf..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_shared_library.cc +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
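[Editor's note] To make the deleted `GetOptionsAs` helper concrete, a sketch of reading composite-op options; it assumes `op` is a StableHLO composite, as arranged by the deleted `OpOptionsTest` above.

```cpp
#include <cstdio>
#include <string>

#include "tensorflow/lite/experimental/litert/c/litert_model.h"
#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h"

// Sketch only: `op` must carry kLiteRtOpCodeShloComposite for this to
// succeed; any other opcode yields kLiteRtStatusErrorInvalidArgument.
void PrintCompositeInfo(LiteRtOp op) {
  auto options = litert::GetOptionsAs<litert::CompositeOptions>(op);
  if (!options) return;
  std::printf("composite '%s' decomposes into subgraph %d\n",
              std::string(options->name).c_str(), options->subgraph);
}
```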
- -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" - -#if !LITERT_WINDOWS_OS -#include -#endif - -#if defined(_GNU_SOURCE) && !defined(__ANDROID__) && !defined(__APPLE__) -#define LITERT_IMPLEMENT_SHARED_LIBRARY_INFO 1 -#include -#endif - -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -// When using an address sanitizer, `RTLD_DEEPBIND` is not supported. When using -// one, we discard the flag and log an error. -#if defined(__SANITIZE_ADDRESS__) || \ - defined(__has_feature) && \ - (__has_feature(address_sanitizer) || __has_feature(memory_sanitizer)) -#define LITERT_SANITIZER_BUILD 1 -#endif - -#if LITERT_SANITIZER_BUILD && defined(RTLD_DEEPBIND) -namespace litert { -namespace { -RtldFlags SanitizeFlagsInCaseOfAsan(RtldFlags flags) { - LITERT_LOG( - LITERT_WARNING, - "Trying to load a library using `RTLD_DEEPBIND` is not supported by " - "address sanitizers. In an effort to enable testing we strip the flag. " - "If this leads to unintended behaviour, either remove the " - "`RTLD_DEEPBIND` flag or run without an address sanitizer. " - "See https://github.com/google/sanitizers/issues/611 for more " - "information."); - flags.flags &= ~RTLD_DEEPBIND; - return flags; -} -} // namespace -} // namespace litert -#else -#define SanitizeFlagsInCaseOfAsan(flags) (flags) -#endif - -#if LITERT_WINDOWS_OS -// Implement dummy functions from dlfnc.h on Windows. -namespace { - -const char* dlerror() { - return "Windows is not supported for loading shared libraries."; -} - -void* dlopen(const char*, int) { return NULL; } - -void dlclose(void*) {} - -void* dlsym(void*, const char*) { return NULL; } - -int dlinfo(void*, int, void*) { return -1; } - -#define RTLD_NEXT (void*)-1; -#define RTLD_DEFAULT (void*)0; - -} // namespace -#endif - -namespace litert { - -SharedLibrary::~SharedLibrary() noexcept { Close(); } - -SharedLibrary::SharedLibrary(SharedLibrary&& other) noexcept - : handle_kind_(other.handle_kind_), - path_(std::move(other.path_)), - handle_(other.handle_) { - other.handle_kind_ = HandleKind::kInvalid; - other.handle_ = nullptr; -} - -SharedLibrary& SharedLibrary::operator=(SharedLibrary&& other) noexcept { - Close(); - handle_kind_ = other.handle_kind_; - path_ = std::move(other.path_); - handle_ = other.handle_; - other.handle_kind_ = HandleKind::kInvalid; - other.handle_ = nullptr; - return *this; -} - -void SharedLibrary::Close() noexcept { - if (handle_kind_ == HandleKind::kPath) { - dlclose(handle_); - } - handle_kind_ = HandleKind::kInvalid; - path_.clear(); -} - -absl::string_view SharedLibrary::DlError() noexcept { - const char* error = dlerror(); - if (!error) { - return {}; - } - return error; -} - -Expected SharedLibrary::LoadImpl( - SharedLibrary::HandleKind handle_kind, absl::string_view path, - RtldFlags flags) { - SharedLibrary lib; - switch (handle_kind) { - case HandleKind::kInvalid: - return Error(kLiteRtStatusErrorDynamicLoading, - "This is a logic error. 
LoadImpl should not be called with "
-                   "HandleKind::kInvalid");
-    case HandleKind::kPath:
-      if (path.empty()) {
-        return Error(kLiteRtStatusErrorDynamicLoading,
-                     "Cannot load shared library: empty path.");
-      }
-      lib.path_ = path;
-      lib.handle_ =
-          dlopen(lib.Path().c_str(), SanitizeFlagsInCaseOfAsan(flags));
-      if (!lib.handle_) {
-        return Error(kLiteRtStatusErrorDynamicLoading,
-                     absl::StrFormat("Could not load shared library %s: %s.",
-                                     lib.path_, DlError()));
-      }
-      break;
-    case HandleKind::kRtldNext:
-      lib.handle_ = RTLD_NEXT;
-      break;
-    case HandleKind::kRtldDefault:
-      lib.handle_ = RTLD_DEFAULT;
-      break;
-  }
-  lib.handle_kind_ = handle_kind;
-  return lib;
-}
-
-Expected<void*> SharedLibrary::LookupSymbolImpl(const char* symbol_name) const {
-  void* symbol = dlsym(handle_, symbol_name);
-
-  if (!symbol) {
-    return Error(kLiteRtStatusErrorDynamicLoading,
-                 absl::StrFormat("Could not load symbol %s: %s.", symbol_name,
-                                 DlError()));
-  }
-  return symbol;
-}
-
-std::ostream& operator<<(std::ostream& os, const SharedLibrary& lib) {
-  static constexpr absl::string_view kHeader = "/// DLL Info ///\n";
-  static constexpr absl::string_view kFooter = "////////////////\n";
-
-  if (lib.handle_ == nullptr) {
-    os << kHeader << "Handle is nullptr.\n" << kFooter;
-    return os;
-  }
-
-  os << kHeader;
-#ifdef RTLD_DI_LMID
-  if (Lmid_t dl_ns_idx; dlinfo(lib.handle_, RTLD_DI_LMID, &dl_ns_idx) != 0) {
-    os << "Error getting lib namespace index: " << dlerror() << ".\n";
-  } else {
-    os << "LIB NAMESPACE INDEX: " << dl_ns_idx << "\n";
-  }
-#else
-  os << "Cannot retrieve namespace index on this platform.\n";
-#endif
-
-#ifdef RTLD_DI_LINKMAP
-  if (link_map* lm; dlinfo(lib.handle_, RTLD_DI_LINKMAP, &lm) != 0) {
-    os << "Error getting linked objects: " << dlerror() << ".\n";
-  } else {
-    os << "LINKED OBJECTS:\n";
-    // Rewind to the start of the linked list.
-    const link_map* link = lm;
-    while (link->l_prev) {
-      link = link->l_prev;
-    }
-    // Print all list elements
-    for (; link != nullptr; link = link->l_next) {
-      os << (link != lm ? "   " : "***") << link->l_name << "\n";
-    }
-  }
-#else
-  os << "Cannot retrieve lib map on this platform.\n";
-#endif
-  return os << kFooter;
-}
-
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/cc/litert_shared_library.h b/tensorflow/lite/experimental/litert/cc/litert_shared_library.h
deleted file mode 100644
index b28f1f7e9a604c..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_shared_library.h
+++ /dev/null
@@ -1,172 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
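[Editor's note] A hedged sketch of the loading path implemented above, using the `RtldFlags` builder declared in the header that follows; the library name is a placeholder.

```cpp
#include <cstdio>

#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h"

// Sketch only: "libfoo.so" is a placeholder, not a library from this patch.
int main() {
  auto lib = litert::SharedLibrary::Load("libfoo.so",
                                         litert::RtldFlags::Now().Local());
  if (!lib) {
    // On failure the Expected carries the dlerror() text formatted above.
    std::printf("dlopen failed\n");
    return 1;
  }
  // RAII: the handle is dlclose()d when `lib` goes out of scope.
  return 0;
}
```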
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_SHARED_LIBRARY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_SHARED_LIBRARY_H_ - -#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || \ - defined(__NT__) || defined(_WIN64) -#define LITERT_WINDOWS_OS 1 -#endif - -#if !LITERT_WINDOWS_OS -#include -#endif - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -struct RtldFlags { - int flags; - - static constexpr struct NextTag { - } kNext; - static constexpr struct DefaultTag { - } kDefault; - - // NOLINTNEXTLINE(*-explicit-constructor): we want this to be passed as flags. - operator int() { return flags; } - - static constexpr RtldFlags Lazy() { - return { -#if defined(RTLD_LAZY) - RTLD_LAZY -#endif - }; - } - static constexpr RtldFlags Now() { - return { -#if defined(RTLD_NOW) - RTLD_NOW -#endif - }; - } - static constexpr RtldFlags Default() { return Lazy().Local().DeepBind(); } - constexpr RtldFlags& Global() { -#if defined(RTLD_GLOBAL) - flags |= RTLD_GLOBAL; -#endif - return *this; - } - constexpr RtldFlags& Local() { -#if defined(RTLD_LOCAL) - flags |= RTLD_LOCAL; -#endif - return *this; - } - constexpr RtldFlags& NoDelete() { -#if defined(RTLD_NODELETE) - flags |= RTLD_NODELETE; -#endif - return *this; - } - constexpr RtldFlags& NoLoad() { -#if defined(RTLD_NOLOAD) - flags |= RTLD_NOLOAD; -#endif - return *this; - } - constexpr RtldFlags& DeepBind() { -#if defined(RTLD_DEEPBIND) - flags |= RTLD_DEEPBIND; -#endif - return *this; - } -}; - -// Wraps a dynamically loaded shared library to offer RAII semantics. -class SharedLibrary { - public: - SharedLibrary() = default; - SharedLibrary(const SharedLibrary&) = delete; - SharedLibrary& operator=(const SharedLibrary&) = delete; - SharedLibrary(SharedLibrary&&) noexcept; - SharedLibrary& operator=(SharedLibrary&&) noexcept; - ~SharedLibrary() noexcept; - - // Loads the library at the given path. - static Expected Load(absl::string_view path, - RtldFlags flags) noexcept { - return LoadImpl(HandleKind::kPath, path, flags); - } - - // Loads the library as the RTLD_NEXT special handle. - static Expected Load(RtldFlags::NextTag) noexcept { - return LoadImpl(HandleKind::kRtldNext, "", RtldFlags{}); - } - - // Loads the library as the RTLD_DEFAULT special handle. - static Expected Load(RtldFlags::DefaultTag) noexcept { - return LoadImpl(HandleKind::kRtldDefault, "", RtldFlags{}); - } - - // Gets the last shared library operation error if there was one. - // - // If there was no error, returns an empty view. - static absl::string_view DlError() noexcept; - - friend std::ostream& operator<<(std::ostream& os, const SharedLibrary& lib); - - bool Loaded() const noexcept { return handle_kind_ != HandleKind::kInvalid; } - - // Unloads the shared library. - // - // Note: this is automatically done when the object is destroyed. - void Close() noexcept; - - // Looks up a symbol in the shared library. - // - // Note: This takes a `char*` because the underlying system call requires a - // null terminated string which a string view doesn't guarantee. 
- template - Expected LookupSymbol(const char* symbol) const noexcept { - static_assert(std::is_pointer_v, - "The template parameter should always be a pointer."); - LITERT_ASSIGN_OR_RETURN(void* const raw_symbol, LookupSymbolImpl(symbol)); - return reinterpret_cast(raw_symbol); - } - - // Returns the loaded library path. - const std::string& Path() const noexcept { return path_; } - - // Returns the underlying shared library handle. - // - // Warning: some special handle value may be NULL. Do not rely on this value - // to check whether a library is loaded or not. - const void* Handle() const noexcept { return handle_; } - void* Handle() noexcept { return handle_; } - - private: - enum class HandleKind { kInvalid, kPath, kRtldNext, kRtldDefault }; - - static Expected LoadImpl(HandleKind handle_kind, - absl::string_view path, - RtldFlags flags); - - Expected LookupSymbolImpl(const char* symbol) const; - - HandleKind handle_kind_ = HandleKind::kInvalid; - std::string path_; - void* handle_ = nullptr; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_SHARED_LIBRARY_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_shared_library_test.cc b/tensorflow/lite/experimental/litert/cc/litert_shared_library_test.cc deleted file mode 100644 index 5a6fb051d0d0e0..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_shared_library_test.cc +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
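[Editor's note] And a sketch of the typed symbol lookup the header above provides, mirroring the deleted test file that follows; the exported symbol and its signature are illustrative assumptions.

```cpp
#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h"

// Sketch only: assumes the loaded library exports
// `const char* TestFunction()`, as in the deleted test file that follows.
const char* CallTestFunction(litert::SharedLibrary& lib) {
  using Fn = const char* (*)();
  // LookupSymbol reinterpret_casts the dlsym() result to the requested
  // function-pointer type; a missing symbol yields an error Expected.
  auto fn = lib.LookupSymbol<Fn>("TestFunction");
  return fn ? (*fn)() : nullptr;
}
```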
- -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" - -#include - -#include -#include - -#include -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -using testing::Eq; -using testing::NotNull; -using testing::StrEq; -using testing::litert::IsError; - -namespace litert { -namespace { - -extern "C" { - -const char* TestFunction() { return "local_test_function"; } - -} // extern "C" - -TEST(RtldFlagsTest, FlagFactoryWorks) { - EXPECT_THAT(static_cast(RtldFlags::Now()), Eq(RTLD_NOW)); - EXPECT_THAT(static_cast(RtldFlags::Lazy()), Eq(RTLD_LAZY)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().Global()), - Eq(RTLD_LAZY | RTLD_GLOBAL)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().Local()), - Eq(RTLD_LAZY | RTLD_LOCAL)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().NoDelete()), - Eq(RTLD_LAZY | RTLD_NODELETE)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().NoLoad()), - Eq(RTLD_LAZY | RTLD_NOLOAD)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().DeepBind()), - Eq(RTLD_LAZY | RTLD_DEEPBIND)); -} - -TEST(SharedLibraryTest, LoadRtldDefaultWorks) { - LITERT_ASSERT_OK_AND_ASSIGN(SharedLibrary lib, - SharedLibrary::Load(RtldFlags::kDefault)); - - EXPECT_THAT(lib.Path(), StrEq("")); - EXPECT_EQ(lib.Handle(), RTLD_DEFAULT); - - auto maybe_test_function = - lib.LookupSymbol("TestFunction"); - if (!maybe_test_function.HasValue()) { - GTEST_SKIP() << "TestFunction symbol was stripped from binary."; - } - - decltype(&TestFunction) test_function = maybe_test_function.Value(); - ASSERT_NE(test_function, nullptr); - EXPECT_THAT(test_function(), StrEq(TestFunction())); -} - -TEST(SharedLibraryTest, LoadRtldNextWorks) { - LITERT_ASSERT_OK_AND_ASSIGN(SharedLibrary lib, - SharedLibrary::Load(RtldFlags::kNext)); - - EXPECT_THAT(lib.Path(), StrEq("")); - EXPECT_EQ(lib.Handle(), RTLD_NEXT); -} - -TEST(SharedLibraryTest, LoadEmptyPathFails) { - EXPECT_THAT(SharedLibrary::Load("", RtldFlags::Now().Local()), IsError()); -} - -TEST(SharedLibraryTest, LoadPathWorks) { - const std::string lib_path = absl::StrCat( - "third_party/tensorflow/lite/experimental/litert/cc/" - "test_shared_library.so"); - LITERT_ASSERT_OK_AND_ASSIGN( - SharedLibrary lib, - SharedLibrary::Load(lib_path, RtldFlags::Now().Local())); - - EXPECT_TRUE(lib.Loaded()); - EXPECT_THAT(lib.Path(), StrEq(lib_path)); - EXPECT_THAT(lib.Handle(), NotNull()); - - using TestFunctionSignature = char* (*)(); - - LITERT_ASSERT_OK_AND_ASSIGN(TestFunctionSignature test_function, - lib.LookupSymbol("TestFunction")); - ASSERT_NE(test_function, nullptr); - EXPECT_THAT(test_function(), StrEq("test_shared_library")); - - lib.Close(); - EXPECT_THAT(lib.Path(), StrEq("")); - EXPECT_FALSE(lib.Loaded()); -} - -TEST(SharedLibraryTest, ConstructionAndAssignmentWork) { - const std::string lib_path = absl::StrCat( - "third_party/tensorflow/lite/experimental/litert/cc/" - "test_shared_library.so"); - LITERT_ASSERT_OK_AND_ASSIGN( - SharedLibrary lib, - SharedLibrary::Load(lib_path, RtldFlags::Now().Local())); - - const void* const lib_handle = lib.Handle(); - - SharedLibrary lib2(std::move(lib)); - - // NOLINTBEGIN(bugprone-use-after-move): Tests that moving clears up the - // object. 
- EXPECT_THAT(lib.Path(), StrEq("")); - EXPECT_FALSE(lib.Loaded()); - - EXPECT_TRUE(lib2.Loaded()); - EXPECT_THAT(lib2.Path(), StrEq(lib_path)); - EXPECT_THAT(lib2.Handle(), Eq(lib_handle)); - - lib = std::move(lib2); - EXPECT_THAT(lib2.Path(), StrEq("")); - EXPECT_FALSE(lib2.Loaded()); - - EXPECT_TRUE(lib.Loaded()); - EXPECT_THAT(lib.Path(), StrEq(lib_path)); - EXPECT_THAT(lib.Handle(), Eq(lib_handle)); - // NOLINTEND(bugprone-use-after-move) -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h deleted file mode 100644 index d7fef8cec7e0d7..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h +++ /dev/null @@ -1,350 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ - -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -#if LITERT_HAS_OPENCL_SUPPORT -#include -#else -typedef struct _cl_mem* cl_mem; -#endif - -namespace litert { - -// Tensor and associated backing buffer. C++ equivalent of LiteRtTensorBuffer. -class TensorBuffer - : public internal::Handle { - public: - TensorBuffer() = default; - - // Parameter `owned` indicates if the created TensorBuffer object should take - // ownership of the provided `tensor_buffer` handle. - explicit TensorBuffer(LiteRtTensorBuffer tensor_buffer, bool owned = true) - : internal::Handle( - tensor_buffer, owned) {} - - // Creates a duplicate of the current TensorBuffer object. The returned - // object is reference counted so the underlying LiteRtTensorBuffer handle is - // not released with the destructor until the last reference is removed. 
- Expected Duplicate() const { - if (!IsOwned()) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Cannot duplicate a non-owned tensor buffer"); - } - LITERT_RETURN_IF_ERROR(LiteRtDuplicateTensorBuffer(Get())); - return TensorBuffer(Get()); - } - - static Expected CreateManaged( - LiteRtTensorBufferType buffer_type, const RankedTensorType& tensor_type, - size_t buffer_size) { - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - LITERT_RETURN_IF_ERROR(LiteRtCreateManagedTensorBuffer( - buffer_type, &litert_tensor_type, buffer_size, &tensor_buffer)); - return TensorBuffer(tensor_buffer); - } - - // Creates a TensorBuffer object that wraps the provided host memory. - // The provided host memory is not owned by the TensorBuffer object and must - // outlive the TensorBuffer object. - static Expected CreateFromHostMemory( - const RankedTensorType& tensor_type, void* host_mem_addr, - size_t buffer_size) { - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - - LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferFromHostMemory( - &litert_tensor_type, host_mem_addr, buffer_size, - /*deallocator=*/nullptr, &tensor_buffer)); - return TensorBuffer(tensor_buffer); - } - - // Creates a TensorBuffer object that wraps an Android Hardware Buffer. Note - // that the provided AHardwareBuffer is not owned by the TensorBuffer object - // and must outlive the TensorBuffer object. The `ahwb_offset` parameter - // specifies the offset in bytes from the start of the AHardwareBuffer where - // the tensor data starts. - static Expected CreateFromAhwb( - const RankedTensorType& tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset) { -#if LITERT_HAS_AHWB_SUPPORT - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - - LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferFromAhwb( - &litert_tensor_type, ahwb, ahwb_offset, - /*deallocator=*/nullptr, &tensor_buffer)); - return TensorBuffer(tensor_buffer); -#else - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not supported on this platform"); -#endif - } - - litert::Expected GetAhwb() const { -#if LITERT_HAS_AHWB_SUPPORT - AHardwareBuffer* ahwb; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferAhwb(Get(), &ahwb)); - return ahwb; -#else - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not supported on this platform"); -#endif - } - - struct DmaBuf { - void* addr; - int fd; - }; - - litert::Expected GetDmaBuf() const { -#if LITERT_HAS_DMABUF_SUPPORT - DmaBuf dma_buf; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferDmaBufBuffer(Get(), &dma_buf.addr, &dma_buf.fd)); - return dma_buf; -#else - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "DMA-BUF is not supported on this platform"); -#endif - } - - Expected GetOpenClBuffer() const { -#if LITERT_HAS_OPENCL_SUPPORT - cl_mem cl_mem; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferOpenClBuffer(Get(), &cl_mem)); - return cl_mem; -#else - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenCL is not supported on this platform"); -#endif - } - - struct GlBuffer { - LiteRtGLenum target; - LiteRtGLuint id; - size_t size_bytes; - size_t offset; - }; - - static Expected CreateFromGlBuffer( - const RankedTensorType& tensor_type, LiteRtGLenum target, LiteRtGLuint id, - size_t size_bytes, size_t offset) { - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - 
LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferFromGlBuffer( - &litert_tensor_type, target, id, size_bytes, offset, - /*deallocator=*/nullptr, &tensor_buffer)); - return TensorBuffer(tensor_buffer); - } - - Expected GetGlBuffer() const { - GlBuffer gl_buffer; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferGlBuffer( - Get(), &gl_buffer.target, &gl_buffer.id, &gl_buffer.size_bytes, - &gl_buffer.offset)); - return gl_buffer; - } - struct GlTexture { - LiteRtGLenum target; - LiteRtGLuint id; - LiteRtGLenum format; - size_t size_bytes; - LiteRtGLint layer; - }; - static Expected CreateFromGlTexture( - const RankedTensorType& tensor_type, LiteRtGLenum target, LiteRtGLuint id, - LiteRtGLenum format, size_t size_bytes, LiteRtGLint layer) { - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferFromGlTexture( - &litert_tensor_type, target, id, format, size_bytes, layer, - /*deallocator=*/nullptr, &tensor_buffer)); - return TensorBuffer(tensor_buffer); - } - - Expected GetGlTexture() const { - GlTexture gl_texture; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferGlTexture( - Get(), &gl_texture.target, &gl_texture.id, &gl_texture.format, - &gl_texture.size_bytes, &gl_texture.layer)); - return gl_texture; - } - - Expected BufferType() const { - LiteRtTensorBufferType tensor_buffer_type; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferType(Get(), &tensor_buffer_type)); - return tensor_buffer_type; - } - - Expected TensorType() const { - LiteRtRankedTensorType tensor_type; - if (auto status = LiteRtGetTensorBufferTensorType(Get(), &tensor_type); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor type"); - } - return RankedTensorType(tensor_type); - } - - Expected Size() const { - size_t size; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferSize(Get(), &size)); - return size; - } - - Expected Offset() const { - size_t offset; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferOffset(Get(), &offset)); - return offset; - } - - bool HasEvent() const { - bool has_event; - internal::AssertOk(LiteRtHasTensorBufferEvent, Get(), &has_event); - return has_event; - } - - Expected GetEvent() const { - LiteRtEvent event; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferEvent(Get(), &event)); - return Event(event, /*owned=*/false); - } - - // Set the C++ Event object for the tensor buffer. - // The function takes ownership of the passed Event object. - Expected SetEvent(Event&& event) { - if (!event.IsOwned()) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Expected an owned event"); - } - LITERT_RETURN_IF_ERROR(LiteRtSetTensorBufferEvent(Get(), event.Release())); - return {}; - } - - // Set the C LiteRtEvent object for the tensor buffer. - // The function takes ownership of the passed LiteRtEvent object. - Expected SetLiteRtEvent(LiteRtEvent& litert_event) { - LITERT_RETURN_IF_ERROR(LiteRtSetTensorBufferEvent(Get(), litert_event)); - return {}; - } - - Expected ClearEvent() { - LITERT_RETURN_IF_ERROR(LiteRtClearTensorBufferEvent(Get())); - return {}; - } - - Expected Lock() { - void* host_mem_addr; - LITERT_RETURN_IF_ERROR(LiteRtLockTensorBuffer(Get(), &host_mem_addr)); - return host_mem_addr; - } - - Expected Unlock() { - LITERT_RETURN_IF_ERROR(LiteRtUnlockTensorBuffer(Get())); - return {}; - } - - // Writes data from the user provided Span to the tensor buffer. - // It returns an error if the provided buffer is bigger than the size of the - // tensor buffer. 
-  template <typename T>
-  Expected<void> Write(absl::Span<const T> data) {
-    LITERT_ASSIGN_OR_RETURN(void* host_mem_addr, Lock());
-    LITERT_ASSIGN_OR_RETURN(size_t size, Size());
-    if (size < data.size() * sizeof(T)) {
-      return Unexpected(
-          kLiteRtStatusErrorRuntimeFailure,
-          "TensorBuffer size is smaller than the given data size");
-    }
-    std::memcpy(host_mem_addr, data.data(), data.size() * sizeof(T));
-    Unlock();
-    return {};
-  }
-
-  // Reads data into the user-provided Span from the tensor buffer.
-  // If the provided buffer is smaller than the size of the tensor buffer, the
-  // data will be read up to the size of the provided buffer.
-  // It returns an error if the provided buffer is bigger than the size of the
-  // tensor buffer.
-  template <typename T>
-  Expected<void> Read(absl::Span<T> data) {
-    LITERT_ASSIGN_OR_RETURN(void* host_mem_addr, Lock());
-    LITERT_ASSIGN_OR_RETURN(size_t size, Size());
-    size_t total_read_size = data.size() * sizeof(T);
-    if (size < total_read_size) {
-      return Unexpected(
-          kLiteRtStatusErrorRuntimeFailure,
-          "TensorBuffer size is smaller than the given data size");
-    }
-    std::memcpy(data.data(), host_mem_addr, total_read_size);
-    Unlock();
-    return {};
-  }
-};
-
-class TensorBufferScopedLock {
- public:
-  TensorBufferScopedLock(const TensorBufferScopedLock& arg) = delete;
-  TensorBufferScopedLock(TensorBufferScopedLock&& arg) = default;
-  ~TensorBufferScopedLock() { (void)LiteRtUnlockTensorBuffer(tensor_buffer_); }
-
-  template <typename T = void>
-  static Expected<std::pair<TensorBufferScopedLock, T*>> Create(
-      TensorBuffer& tensor_buffer) {
-    return Create<T>(tensor_buffer.Get());
-  }
-
-  template <typename T = void>
-  static Expected<std::pair<TensorBufferScopedLock, T*>> Create(
-      LiteRtTensorBuffer tensor_buffer) {
-    void* host_mem_addr;
-    LITERT_RETURN_IF_ERROR(
-        LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr));
-    return std::make_pair(TensorBufferScopedLock(tensor_buffer),
-                          static_cast<T*>(host_mem_addr));
-  }
-
- private:
-  explicit TensorBufferScopedLock(LiteRtTensorBuffer& tensor_buffer)
-      : tensor_buffer_(tensor_buffer) {}
-
-  LiteRtTensorBuffer tensor_buffer_;
-};
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_
diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h
deleted file mode 100644
index 881e3662a2fff6..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h
+++ /dev/null
@@ -1,92 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
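// A minimal usage sketch of the TensorBuffer API in litert_tensor_buffer.h
// above. It assumes a valid litert::RankedTensorType named `kFloatType`
// describing four float32 elements (hypothetical, not part of this change),
// and round-trips data through a managed host-memory buffer:
//
//   auto buffer = litert::TensorBuffer::CreateManaged(
//       kLiteRtTensorBufferTypeHostMemory, kFloatType, 4 * sizeof(float));
//   if (buffer) {
//     const float data[] = {1.f, 2.f, 3.f, 4.f};
//     buffer->Write<float>(absl::MakeConstSpan(data));  // copies into buffer
//     float out[4] = {};
//     buffer->Read<float>(absl::MakeSpan(out));         // copies back out
//     // Or keep the buffer locked across several accesses; the lock is
//     // released when `lock_and_addr` goes out of scope:
//     auto lock_and_addr =
//         litert::TensorBufferScopedLock::Create<float>(*buffer);
//     if (lock_and_addr) { lock_and_addr->second[0] = 42.f; }
//   }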
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <vector>
-
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_handle.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-
-namespace litert {
-
-// Requirements for allocating a TensorBuffer, typically specified by a HW
-// accelerator for a given I/O tensor. C++ equivalent to
-// LiteRtTensorBufferRequirements.
-class TensorBufferRequirements
-    : public internal::Handle<LiteRtTensorBufferRequirements,
-                              LiteRtDestroyTensorBufferRequirements> {
- public:
-  TensorBufferRequirements() = default;
-
-  // Parameter `owned` indicates if the created TensorBufferRequirements object
-  // should take ownership of the provided `requirements` handle.
-  explicit TensorBufferRequirements(LiteRtTensorBufferRequirements requirements,
-                                    bool owned = true)
-      : internal::Handle<LiteRtTensorBufferRequirements,
-                         LiteRtDestroyTensorBufferRequirements>(requirements,
-                                                                owned) {}
-
-  static Expected<TensorBufferRequirements> Create(
-      absl::Span<const LiteRtTensorBufferType> buffer_types, size_t buffer_size,
-      absl::Span<const uint32_t> strides =
-          absl::MakeSpan(static_cast<const uint32_t*>(nullptr), 0)) {
-    LiteRtTensorBufferRequirements tensor_buffer_requirements;
-    LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferRequirements(
-        buffer_types.size(), buffer_types.data(), buffer_size, strides.size(),
-        strides.data(), &tensor_buffer_requirements));
-    return TensorBufferRequirements(tensor_buffer_requirements);
-  }
-
-  Expected<std::vector<LiteRtTensorBufferType>> SupportedTypes() const {
-    int num_types;
-    LITERT_RETURN_IF_ERROR(
-        LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes(Get(),
-                                                                 &num_types));
-    std::vector<LiteRtTensorBufferType> types(num_types);
-    for (auto i = 0; i < num_types; ++i) {
-      LITERT_RETURN_IF_ERROR(
-          LiteRtGetTensorBufferRequirementsSupportedTensorBufferType(
-              Get(), i, &types[i]));
-    }
-    return types;
-  }
-
-  Expected<size_t> BufferSize() const {
-    size_t buffer_size;
-    LITERT_RETURN_IF_ERROR(
-        LiteRtGetTensorBufferRequirementsBufferSize(Get(), &buffer_size));
-    return buffer_size;
-  }
-
-  Expected<absl::Span<const uint32_t>> Strides() const {
-    int num_strides;
-    const uint32_t* strides;
-    LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferRequirementsStrides(
-        Get(), &num_strides, &strides));
-    return absl::MakeSpan(strides, num_strides);
-  }
-};
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_
diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc
deleted file mode 100644
index 0dba6aaac27641..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc
+++ /dev/null
@@ -1,104 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <array>
-#include <cstddef>
-#include <cstdint>
-
-#include <gtest/gtest.h>  // NOLINT: Need when ANDROID_API_LEVEL >= 26
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
-
-namespace {
-
-constexpr const LiteRtTensorBufferType kSupportedTensorBufferTypes[] = {
-    kLiteRtTensorBufferTypeHostMemory,
-    kLiteRtTensorBufferTypeAhwb,
-    kLiteRtTensorBufferTypeIon,
-    kLiteRtTensorBufferTypeFastRpc,
-};
-
-constexpr const size_t kNumSupportedTensorBufferTypes =
-    sizeof(kSupportedTensorBufferTypes) /
-    sizeof(kSupportedTensorBufferTypes[0]);
-
-constexpr const size_t kBufferSize = 1234;
-
-}  // namespace
-
-TEST(TensorBufferRequirements, Owned) {
-  auto requirements = litert::TensorBufferRequirements::Create(
-      absl::MakeSpan(kSupportedTensorBufferTypes,
-                     kNumSupportedTensorBufferTypes),
-      kBufferSize);
-  ASSERT_TRUE(requirements);
-
-  auto supported_types = requirements->SupportedTypes();
-  ASSERT_TRUE(supported_types);
-  ASSERT_EQ(supported_types->size(), kNumSupportedTensorBufferTypes);
-  for (auto i = 0; i < supported_types->size(); ++i) {
-    ASSERT_EQ((*supported_types)[i], kSupportedTensorBufferTypes[i]);
-  }
-
-  auto size = requirements->BufferSize();
-  ASSERT_TRUE(size);
-  ASSERT_EQ(*size, kBufferSize);
-}
-
-TEST(TensorBufferRequirements, NotOwned) {
-  LiteRtTensorBufferRequirements litert_requirements;
-  ASSERT_EQ(LiteRtCreateTensorBufferRequirements(
-                kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes,
-                kBufferSize, /*num_strides=*/0, /*strides=*/nullptr,
-                &litert_requirements),
-            kLiteRtStatusOk);
-
-  litert::TensorBufferRequirements requirements(litert_requirements,
-                                                /*owned=*/false);
-
-  auto supported_types = requirements.SupportedTypes();
-  ASSERT_TRUE(supported_types);
-  ASSERT_EQ(supported_types->size(), kNumSupportedTensorBufferTypes);
-  for (auto i = 0; i < supported_types->size(); ++i) {
-    ASSERT_EQ((*supported_types)[i], kSupportedTensorBufferTypes[i]);
-  }
-
-  auto size = requirements.BufferSize();
-  ASSERT_TRUE(size);
-  ASSERT_EQ(*size, kBufferSize);
-
-  ASSERT_EQ(requirements.Get(), litert_requirements);
-
-  LiteRtDestroyTensorBufferRequirements(litert_requirements);
-}
-
-TEST(TensorBufferRequirements, WithStrides) {
-  constexpr std::array<uint32_t, 3> kStrides = {1, 2, 3};
-
-  auto requirements = litert::TensorBufferRequirements::Create(
-      absl::MakeSpan(kSupportedTensorBufferTypes,
-                     kNumSupportedTensorBufferTypes),
-      kBufferSize, absl::MakeSpan(kStrides.data(), kStrides.size()));
-  ASSERT_TRUE(requirements);
-
-  auto strides = requirements->Strides();
-  ASSERT_TRUE(strides);
-  ASSERT_EQ(strides->size(), kStrides.size());
-  for (auto i = 0; i < kStrides.size(); ++i) {
-    ASSERT_EQ((*strides)[i], kStrides[i]);
-  }
-}
diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc
deleted file mode 100644
index c366f2081c1b9b..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc
+++ /dev/null
@@ -1,613 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <cstdlib>
-#include <cstring>
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>  // NOLINT: Need when ANDROID_API_LEVEL >= 26
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_event.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_layout.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h"  // IWYU pragma: keep
-#include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h"  // IWYU pragma: keep
-#include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h"  // IWYU pragma: keep
-#include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h"  // IWYU pragma: keep
-#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/test/matchers.h"
-
-#if LITERT_HAS_AHWB_SUPPORT
-#include <android/hardware_buffer.h>
-#endif  // LITERT_HAS_AHWB_SUPPORT
-
-#if LITERT_HAS_OPENGL_SUPPORT
-#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h"
-#endif  // LITERT_HAS_OPENGL_SUPPORT
-
-namespace litert {
-namespace {
-
-using ::testing::Eq;
-using ::testing::Ne;
-
-constexpr const float kTensorData[] = {10, 20, 30, 40};
-
-constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) /
-                                               sizeof(kTensorData[0])};
-
-constexpr int kFakeSyncFenceFd = 1;
-
-constexpr const LiteRtRankedTensorType kTestTensorType = {
-    /*.element_type=*/kLiteRtElementTypeFloat32,
-    BuildLayout(kTensorDimensions)};
-
-int GetReferenceCount(const TensorBuffer& tensor_buffer) {
-  LiteRtTensorBufferT* internal_tensor_buffer =
-      static_cast<LiteRtTensorBufferT*>(tensor_buffer.Get());
-  return internal_tensor_buffer->RefCount();
-}
-
-TEST(TensorBuffer, HostMemory) {
-  const RankedTensorType kTensorType(kTestTensorType);
-  constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory;
-
-  auto tensor_buffer = TensorBuffer::CreateManaged(
-      kTensorBufferType, kTensorType, sizeof(kTensorData));
-  ASSERT_TRUE(tensor_buffer);
-
-  auto tensor_buffer_type = tensor_buffer->BufferType();
-  ASSERT_TRUE(tensor_buffer_type);
-  ASSERT_EQ(*tensor_buffer_type, kTensorBufferType);
-
-  auto tensor_type = tensor_buffer->TensorType();
-  ASSERT_TRUE(tensor_type);
-
-  ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32);
-  ASSERT_EQ(tensor_type->Layout().Rank(), 1);
-  ASSERT_EQ(tensor_type->Layout().Dimensions()[0],
-            kTensorType.Layout().Dimensions()[0]);
-  ASSERT_FALSE(tensor_type->Layout().HasStrides());
-
-  auto size = tensor_buffer->Size();
-  ASSERT_TRUE(size);
-  ASSERT_EQ(*size, sizeof(kTensorData));
-
-  auto offset = tensor_buffer->Offset();
-  ASSERT_TRUE(offset);
-  ASSERT_EQ(*offset, 0);
-
-  {
-    auto lock_and_addr =
TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, Ahwb) { - if (!internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " - "skipping the test"; - } - - const RankedTensorType kTensorType(kTestTensorType); - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeAhwb; - - auto tensor_buffer = TensorBuffer::CreateManaged( - kTensorBufferType, kTensorType, sizeof(kTensorData)); - ASSERT_TRUE(tensor_buffer); - - auto tensor_buffer_type = tensor_buffer->BufferType(); - ASSERT_TRUE(tensor_buffer_type); - ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); - - auto tensor_type = tensor_buffer->TensorType(); - ASSERT_TRUE(tensor_type); - - ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32); - ASSERT_EQ(tensor_type->Layout().Rank(), 1); - ASSERT_EQ(tensor_type->Layout().Dimensions()[0], - kTensorType.Layout().Dimensions()[0]); - ASSERT_FALSE(tensor_type->Layout().HasStrides()); - - auto size = tensor_buffer->Size(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, sizeof(kTensorData)); - - auto offset = tensor_buffer->Offset(); - ASSERT_TRUE(offset); - ASSERT_EQ(*offset, 0); - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, Ion) { - if (!internal::IonBuffer::IsSupported()) { - GTEST_SKIP() - << "ION buffers are not supported on this platform; skipping the test"; - } - - const RankedTensorType kTensorType(kTestTensorType); - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeIon; - - auto tensor_buffer = TensorBuffer::CreateManaged( - kTensorBufferType, kTensorType, sizeof(kTensorData)); - ASSERT_TRUE(tensor_buffer); - - auto tensor_buffer_type = tensor_buffer->BufferType(); - ASSERT_TRUE(tensor_buffer_type); - ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); - - auto tensor_type = tensor_buffer->TensorType(); - ASSERT_TRUE(tensor_type); - - ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32); - ASSERT_EQ(tensor_type->Layout().Rank(), 1); - ASSERT_EQ(tensor_type->Layout().Dimensions()[0], - kTensorType.Layout().Dimensions()[0]); - ASSERT_FALSE(tensor_type->Layout().HasStrides()); - - auto size = tensor_buffer->Size(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, sizeof(kTensorData)); - - auto offset = tensor_buffer->Offset(); - ASSERT_TRUE(offset); - ASSERT_EQ(*offset, 0); - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, DmaBuf) { - if (!internal::DmaBufBuffer::IsSupported()) { - GTEST_SKIP() - << "DMA-BUF buffers are not supported on this platform; skipping " - "the test"; - } - - const 
RankedTensorType kTensorType(kTestTensorType); - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeDmaBuf; - - auto tensor_buffer = TensorBuffer::CreateManaged( - kTensorBufferType, kTensorType, sizeof(kTensorData)); - ASSERT_TRUE(tensor_buffer); - - auto tensor_buffer_type = tensor_buffer->BufferType(); - ASSERT_TRUE(tensor_buffer_type); - ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); - - auto tensor_type = tensor_buffer->TensorType(); - ASSERT_TRUE(tensor_type); - - ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32); - ASSERT_EQ(tensor_type->Layout().Rank(), 1); - ASSERT_EQ(tensor_type->Layout().Dimensions()[0], - kTensorType.Layout().Dimensions()[0]); - ASSERT_FALSE(tensor_type->Layout().HasStrides()); - - auto size = tensor_buffer->Size(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, sizeof(kTensorData)); - - auto offset = tensor_buffer->Offset(); - ASSERT_TRUE(offset); - ASSERT_EQ(*offset, 0); - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, FastRpc) { - if (!internal::FastRpcBuffer::IsSupported()) { - GTEST_SKIP() - << "FastRPC buffers are not supported on this platform; skipping " - "the test"; - } - - const RankedTensorType kTensorType(kTestTensorType); - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeFastRpc; - - auto tensor_buffer = TensorBuffer::CreateManaged( - kTensorBufferType, kTensorType, sizeof(kTensorData)); - ASSERT_TRUE(tensor_buffer); - - auto tensor_buffer_type = tensor_buffer->BufferType(); - ASSERT_TRUE(tensor_buffer_type); - ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); - - auto tensor_type = tensor_buffer->TensorType(); - ASSERT_TRUE(tensor_type); - - ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32); - ASSERT_EQ(tensor_type->Layout().Rank(), 1); - ASSERT_EQ(tensor_type->Layout().Dimensions()[0], - kTensorType.Layout().Dimensions()[0]); - ASSERT_FALSE(tensor_type->Layout().HasStrides()); - - auto size = tensor_buffer->Size(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, sizeof(kTensorData)); - - auto offset = tensor_buffer->Offset(); - ASSERT_TRUE(offset); - ASSERT_EQ(*offset, 0); - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, NotOwned) { - LiteRtTensorBuffer litert_tensor_buffer; - ASSERT_EQ(LiteRtCreateManagedTensorBuffer( - kLiteRtTensorBufferTypeHostMemory, &kTestTensorType, - sizeof(kTensorData), &litert_tensor_buffer), - kLiteRtStatusOk); - - TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/false); - ASSERT_EQ(tensor_buffer.Get(), litert_tensor_buffer); - - LiteRtDestroyTensorBuffer(litert_tensor_buffer); -} - -TEST(TensorBuffer, CreateFromExternalHostMemory) { - // Allocate a tensor buffer with host memory. 
-  const int kTensorBufferSize =
-      std::max<size_t>(sizeof(kTensorData), LITERT_HOST_MEMORY_BUFFER_ALIGNMENT);
-  const RankedTensorType kTensorType(kTestTensorType);
-  void* host_memory_ptr;
-  ASSERT_EQ(
-      ::posix_memalign(&host_memory_ptr, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT,
-                       kTensorBufferSize),
-      0);
-
-  std::memcpy(host_memory_ptr, kTensorData, sizeof(kTensorData));
-
-  // Create a tensor buffer that wraps the host memory.
-  auto tensor_buffer_from_external_memory = TensorBuffer::CreateFromHostMemory(
-      kTensorType, host_memory_ptr, kTensorBufferSize);
-
-  auto lock_and_addr_external_memory =
-      TensorBufferScopedLock::Create(*tensor_buffer_from_external_memory);
-  ASSERT_TRUE(lock_and_addr_external_memory);
-  ASSERT_EQ(std::memcmp(lock_and_addr_external_memory->second, kTensorData,
-                        sizeof(kTensorData)),
-            0);
-
-  free(host_memory_ptr);
-}
-
-#if LITERT_HAS_AHWB_SUPPORT
-TEST(TensorBuffer, CreateFromAhwb) {
-  AHardwareBuffer* ahw_buffer = nullptr;
-  if (__builtin_available(android 26, *)) {
-    int error = 0;
-    AHardwareBuffer_Desc desc = {
-        .width = LITERT_HOST_MEMORY_BUFFER_ALIGNMENT,
-        .height = 1,
-        .layers = 1,
-        .format = AHARDWAREBUFFER_FORMAT_BLOB,
-        .usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY |
-                 AHARDWAREBUFFER_USAGE_CPU_READ_RARELY};
-    error = AHardwareBuffer_allocate(&desc, &ahw_buffer);
-    ASSERT_EQ(error, 0);
-
-    void* host_memory_ptr = nullptr;
-    error =
-        AHardwareBuffer_lock(ahw_buffer, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY,
-                             -1, nullptr, &host_memory_ptr);
-    ASSERT_EQ(error, 0);
-
-    std::memcpy(host_memory_ptr, kTensorData, sizeof(kTensorData));
-
-    int fence_file_descriptor = -1;
-    error = AHardwareBuffer_unlock(ahw_buffer, &fence_file_descriptor);
-    ASSERT_EQ(error, 0);
-  } else {
-    GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; "
-                    "skipping the test";
-  }
-
-  {
-    // Create a tensor buffer that wraps the AHardwareBuffer.
-    const RankedTensorType kTensorType(kTestTensorType);
-    auto tensor_buffer_from_ahwb =
-        TensorBuffer::CreateFromAhwb(kTensorType, ahw_buffer,
-                                     /*ahwb_offset=*/0);
-
-    auto lock_and_addr_external_memory =
-        TensorBufferScopedLock::Create(*tensor_buffer_from_ahwb);
-    ASSERT_TRUE(lock_and_addr_external_memory);
-    ASSERT_EQ(std::memcmp(lock_and_addr_external_memory->second, kTensorData,
-                          sizeof(kTensorData)),
-              0);
-  }
-
-  if (__builtin_available(android 26, *)) {
-    AHardwareBuffer_release(ahw_buffer);
-  }
-}
-#endif  // LITERT_HAS_AHWB_SUPPORT
-
-TEST(TensorBuffer, Duplicate) {
-  LiteRtTensorBuffer litert_tensor_buffer;
-  ASSERT_EQ(LiteRtCreateManagedTensorBuffer(
-                kLiteRtTensorBufferTypeHostMemory, &kTestTensorType,
-                sizeof(kTensorData), &litert_tensor_buffer),
-            kLiteRtStatusOk);
-
-  TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/true);
-  ASSERT_EQ(GetReferenceCount(tensor_buffer), 1);
-  {
-    auto duplicated_tensor_buffer = tensor_buffer.Duplicate();
-    ASSERT_TRUE(duplicated_tensor_buffer);
-    ASSERT_EQ(GetReferenceCount(*duplicated_tensor_buffer), 2);
-    // The duplicated tensor buffer should point to the same underlying
-    // LiteRtTensorBuffer object.
-    ASSERT_EQ(duplicated_tensor_buffer->Get(), tensor_buffer.Get());
-
-    // Update tensor buffer using the duplicated tensor buffer.
-    auto lock_and_addr =
-        TensorBufferScopedLock::Create(*duplicated_tensor_buffer);
-    ASSERT_TRUE(lock_and_addr);
-    std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData));
-
-    // When the scope ends, the duplicated tensor buffer should be destroyed.
-    // This should not affect the original tensor buffer.
-  }
-
-  ASSERT_EQ(GetReferenceCount(tensor_buffer), 1);
-  // Check that the original tensor buffer is not affected.
-  {
-    auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer);
-    ASSERT_TRUE(lock_and_addr);
-    ASSERT_EQ(
-        std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)),
-        0);
-  }
-}
-
-TEST(TensorBuffer, ReadWriteBasic) {
-  LiteRtTensorBuffer litert_tensor_buffer;
-  ASSERT_EQ(LiteRtCreateManagedTensorBuffer(
-                kLiteRtTensorBufferTypeHostMemory, &kTestTensorType,
-                sizeof(kTensorData), &litert_tensor_buffer),
-            kLiteRtStatusOk);
-
-  TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/true);
-  auto write_success = tensor_buffer.Write<float>(absl::MakeSpan(
-      kTensorData, sizeof(kTensorData) / sizeof(kTensorData[0])));
-  ASSERT_TRUE(write_success);
-  float read_data[sizeof(kTensorData) / sizeof(kTensorData[0])];
-  auto read_success = tensor_buffer.Read<float>(absl::MakeSpan(read_data));
-  ASSERT_TRUE(read_success);
-  ASSERT_EQ(std::memcmp(read_data, kTensorData, sizeof(kTensorData)), 0);
-}
-
-TEST(TensorBuffer, ReadWriteBufferSizeMismatch) {
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      TensorBuffer tensor_buffer,
-      TensorBuffer::CreateManaged(kLiteRtTensorBufferTypeHostMemory,
-                                  RankedTensorType(kTestTensorType),
-                                  sizeof(kTensorData)));
-  {
-    // Write with smaller size of data.
-    auto write_success =
-        tensor_buffer.Write<float>(absl::MakeSpan(kTensorData, 1));
-    ASSERT_TRUE(write_success);
-  }
-  {
-    constexpr const float big_data[] = {10, 20, 30, 40, 50};
-    // Write with larger size of data.
-    auto write_success =
-        tensor_buffer.Write<float>(absl::MakeSpan(big_data, 5));
-    ASSERT_FALSE(write_success);
-  }
-  auto write_success = tensor_buffer.Write<float>(absl::MakeSpan(
-      kTensorData, sizeof(kTensorData) / sizeof(kTensorData[0])));
-  ASSERT_TRUE(write_success);
-  {
-    // Read with smaller size of buffer.
-    float read_data[1];
-    auto read_success = tensor_buffer.Read<float>(absl::MakeSpan(read_data, 1));
-    ASSERT_TRUE(read_success);
-    ASSERT_EQ(read_data[0], kTensorData[0]);
-  }
-  {
-    // Read with larger size of buffer.
-    float read_data[5];
-    auto read_success = tensor_buffer.Read<float>(absl::MakeSpan(read_data, 5));
-    ASSERT_FALSE(read_success);
-  }
-}
-
-#if LITERT_HAS_OPENGL_SUPPORT
-TEST(TensorBuffer, CreateFromGlTexture) {
-  std::unique_ptr<tflite::gpu::gl::EglEnvironment> env;
-  ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok());
-
-  // Create GL texture.
-  tflite::gpu::gl::GlTexture gl_texture(GL_TEXTURE_2D, 1, GL_RGBA8, 1, 1,
-                                        /*has_ownership=*/true);
-  ASSERT_TRUE(gl_texture.is_valid());
-
-  // Create tensor buffer from existing GL texture (e.g. this could be from
-  // Android Camera API).
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      TensorBuffer tensor_buffer,
-      TensorBuffer::CreateFromGlTexture(
-          RankedTensorType(kTensorType), gl_texture.target(), gl_texture.id(),
-          gl_texture.format(), gl_texture.bytes_size(), gl_texture.layer()));
-}
-
-tflite::gpu::gl::GlBuffer CreateTestGlBuffer(size_t size_bytes) {
-  tflite::gpu::gl::GlBuffer gl_buffer;
-  CHECK_OK(tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<std::byte>(
-      size_bytes, &gl_buffer));
-  return gl_buffer;
-}
-
-TEST(TensorBuffer, CreateFromGlBuffer) {
-  std::unique_ptr<tflite::gpu::gl::EglEnvironment> env;
-  ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok());
-
-  // Create GL buffer.
-  tflite::gpu::gl::GlBuffer gl_buffer = CreateTestGlBuffer(sizeof(kTensorData));
-
-  // Create tensor buffer from existing GL buffer.
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      TensorBuffer tensor_buffer,
-      TensorBuffer::CreateFromGlBuffer(
-          RankedTensorType(kTensorType), gl_buffer.target(), gl_buffer.id(),
-          gl_buffer.bytes_size(), gl_buffer.offset()));
-}
-
-#if LITERT_HAS_AHWB_SUPPORT
-TEST(TensorBuffer, GetGlBufferFromAhwb) {
-  std::unique_ptr<tflite::gpu::gl::EglEnvironment> env;
-  ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok());
-
-  // Create AHWB Tensor buffer.
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      TensorBuffer ahwb_tensor_buffer,
-      TensorBuffer::CreateManaged(kLiteRtTensorBufferTypeAhwb,
-                                  RankedTensorType(kTensorType),
-                                  sizeof(kTensorData)));
-
-  // Write to AHWB Tensor buffer.
-  LITERT_ASSERT_OK(ahwb_tensor_buffer.Write<float>(absl::MakeConstSpan(
-      kTensorData, sizeof(kTensorData) / sizeof(kTensorData[0]))));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(TensorBuffer::GlBuffer gl_buffer,
-                              ahwb_tensor_buffer.GetGlBuffer());
-  EXPECT_THAT(gl_buffer.target, Eq(GL_SHADER_STORAGE_BUFFER));
-  EXPECT_THAT(gl_buffer.id, Ne(0));
-  EXPECT_THAT(gl_buffer.size_bytes, Eq(sizeof(kTensorData)));
-  EXPECT_THAT(gl_buffer.offset, Eq(0));
-
-  // Read from GL buffer.
-  // TODO(gcarranza): Add GlBuffer ReadLock functionality to LiteRT
-  // TensorBuffer. GlBuffer::Unlock currently writes to GL buffer.
-  tflite::gpu::gl::GlBuffer gl_buffer_from_ahwb(
-      gl_buffer.target, gl_buffer.id, gl_buffer.size_bytes, gl_buffer.offset,
-      /*has_ownership=*/false);
-  float read_data[sizeof(kTensorData) / sizeof(kTensorData[0])];
-  ASSERT_OK(gl_buffer_from_ahwb.Read(absl::MakeSpan(read_data)));
-  ASSERT_EQ(std::memcmp(read_data, kTensorData, sizeof(kTensorData)), 0);
-}
-#endif  // LITERT_HAS_AHWB_SUPPORT
-
-#endif  // LITERT_HAS_OPENGL_SUPPORT
-
-TEST(TensorBuffer, GetAhwb) {
-  if (!internal::AhwbBuffer::IsSupported()) {
-    GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; "
-                    "skipping the test";
-  }
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      TensorBuffer tensor_buffer,
-      TensorBuffer::CreateManaged(kLiteRtTensorBufferTypeAhwb,
-                                  RankedTensorType(kTestTensorType),
-                                  sizeof(kTensorData)));
-  LITERT_ASSERT_OK_AND_ASSIGN(AHardwareBuffer * ahwb, tensor_buffer.GetAhwb());
-  EXPECT_THAT(ahwb, Ne(nullptr));
-}
-
-TEST(TensorBuffer, Event) {
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      TensorBuffer tensor_buffer,
-      TensorBuffer::CreateManaged(kLiteRtTensorBufferTypeHostMemory,
-                                  RankedTensorType(kTestTensorType),
-                                  sizeof(kTensorData)));
-  // Create event.
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      Event event, Event::CreateFromSyncFenceFd(kFakeSyncFenceFd, true));
-  // Move event into tensor buffer.
-  LITERT_EXPECT_OK(tensor_buffer.SetEvent(std::move(event)));
-  EXPECT_TRUE(tensor_buffer.HasEvent());
-  LITERT_ASSERT_OK_AND_ASSIGN(Event tensor_buffer_event,
-                              tensor_buffer.GetEvent());
-  LITERT_ASSERT_OK_AND_ASSIGN(int fence_fd,
-                              tensor_buffer_event.GetSyncFenceFd());
-  EXPECT_THAT(fence_fd, Eq(kFakeSyncFenceFd));
-  // Clear event.
-  LITERT_ASSERT_OK(tensor_buffer.ClearEvent());
-  EXPECT_FALSE(tensor_buffer.HasEvent());
-}
-
-}  // namespace
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.cc b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.cc
deleted file mode 100644
index 67ef66c34291d4..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h"
-
-#include <string>
-
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h"
-
-std::string BufferTypeToString(LiteRtTensorBufferType buffer_type) {
-  switch (buffer_type) {
-    case kLiteRtTensorBufferTypeUnknown:
-      return "Unknown";
-    case kLiteRtTensorBufferTypeHostMemory:
-      return "HostMemory";
-    case kLiteRtTensorBufferTypeAhwb:
-      return "Ahwb";
-    case kLiteRtTensorBufferTypeIon:
-      return "Ion";
-    case kLiteRtTensorBufferTypeDmaBuf:
-      return "DmaBuf";
-    case kLiteRtTensorBufferTypeFastRpc:
-      return "FastRpc";
-    case kLiteRtTensorBufferTypeOpenCl:
-      return "OpenCl";
-    case kLiteRtTensorBufferTypeGlBuffer:
-      return "GlBuffer";
-    case kLiteRtTensorBufferTypeGlTexture:
-      return "GlTexture";
-  }
-  LITERT_LOG(LITERT_ERROR, "Unexpected value for LiteRtTensorBufferType: %d",
-             static_cast<int>(buffer_type));
-  return "UnexpectedBufferType";
-}
diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h
deleted file mode 100644
index a2ccf427211007..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_UTILS_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_UTILS_H_
-
-#include <string>
-
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h"
-
-std::string BufferTypeToString(LiteRtTensorBufferType buffer_type);
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_UTILS_H_
diff --git a/tensorflow/lite/experimental/litert/cc/test_shared_library.cc b/tensorflow/lite/experimental/litert/cc/test_shared_library.cc
deleted file mode 100644
index 37254390ea2018..00000000000000
--- a/tensorflow/lite/experimental/litert/cc/test_shared_library.cc
+++ /dev/null
@@ -1,19 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -extern "C" { - -const char* TestFunction() { return "test_shared_library"; } - -} // extern "C" diff --git a/tensorflow/lite/experimental/litert/compiler/BUILD b/tensorflow/lite/experimental/litert/compiler/BUILD deleted file mode 100644 index 23b07d5602d7c8..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/BUILD b/tensorflow/lite/experimental/litert/compiler/plugin/BUILD deleted file mode 100644 index 77f23b34399cc7..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/BUILD +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
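# test_shared_library.cc above only exports TestFunction() so that dynamic
# loading can be exercised in tests. A hypothetical BUILD sketch (target name
# and attributes assumed, not taken from this change) for producing the
# loadable .so:
#
#   cc_binary(
#       name = "test_shared_library_so",
#       testonly = 1,
#       srcs = ["test_shared_library.cc"],
#       linkshared = True,  # emit a shared object rather than an executable
#   )
#
# A test can then dlopen() the artifact and resolve TestFunction via dlsym().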
- -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "compiler_plugin", - srcs = ["compiler_plugin.cc"], - hdrs = ["compiler_plugin.h"], - deps = [ - ":algo", - ":compiler_flags", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_op_options", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:dynamic_loading", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/core:version", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:buffer_manager", - "//tensorflow/lite/experimental/litert/core/model:ir_allocator", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin_api", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -# copybara:uncomment_begin(no OSS for unique-test-directory) -# cc_test( -# name = "compiler_plugin_test", -# srcs = ["compiler_plugin_test.cc"], -# data = [ -# "//tensorflow/lite/experimental/litert/test:mlir_test_data", -# "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", -# ], -# tags = [ -# # Sanitizer runtimes are incompatible with RTLD_DEEPBIND. 
-# "noasan", -# "nomsan", -# "nosan", -# "notsan", -# ], -# deps = [ -# ":compiler_plugin", -# "@com_google_googletest//:gtest_main", -# "@com_google_absl//absl/strings:string_view", -# "//tensorflow/lite/experimental/litert/c:litert_common", -# "//tensorflow/lite/experimental/litert/c:litert_model", -# "//tensorflow/lite/experimental/litert/c:litert_op_code", -# "//tensorflow/lite/experimental/litert/cc:litert_environment", -# "//tensorflow/lite/experimental/litert/cc:litert_op_options", -# "//tensorflow/lite/experimental/litert/core:build_stamp", -# "//tensorflow/lite/experimental/litert/core:filesystem", -# "//tensorflow/lite/experimental/litert/core/model", -# "//tensorflow/lite/experimental/litert/test:common", -# "//tensorflow/lite/experimental/litert/test:matchers", -# "//tensorflow/lite/experimental/litert/tools:dump", -# ], -# ) -# copybara:uncomment_end - -cc_library( - name = "algo", - srcs = ["algo.cc"], - hdrs = ["algo.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/core:insert_order_map", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:model_graph", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:absl_check", - ], -) - -cc_test( - name = "algo_test", - srcs = ["algo_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - ], - deps = [ - ":algo", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:graph_validation", - "//tensorflow/lite/experimental/litert/test:common", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "compiler_flags", - srcs = ["compiler_flags.cc"], - hdrs = ["compiler_flags.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "compiler_flags_test", - srcs = ["compiler_flags_test.cc"], - deps = [ - ":compiler_flags", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc b/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc deleted file mode 100644 index eb36733486b3b2..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc +++ /dev/null @@ -1,331 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h"
-
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/absl_check.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/core/insert_order_map.h"
-#include "tensorflow/lite/experimental/litert/core/model/model.h"
-#include "tensorflow/lite/experimental/litert/core/model/model_graph.h"
-
-namespace litert::internal {
-namespace {
-
-//
-// flatlist to partition(s)
-//===----------------------------------------------------------------------===//
-
-class DisjointSets {
- public:
-  static std::vector<std::vector<LiteRtOp>> GetPartitionsFromFlatList(
-      const std::vector<LiteRtOpWithPartitionIndex>& flat_op_list);
-
- private:
-  void Insert(LiteRtOp op, LiteRtOp parent);
-  std::vector<std::vector<LiteRtOp>> GetBuckets();
-  LiteRtOp GetBucket(LiteRtOp op);
-  InsertOrderMap<LiteRtOp, LiteRtOp> map_;
-};
-
-//===----------------------------------------------------------------------===//
-// LiteRt Core union-find algorithm.
-//
-// This algorithm is used to group partitions into sub-DAGs.
-// The input to the algorithm is a list of ops with their partition index.
-//
-// [ (op_0, 0),
-//   (op_1, 0),
-//   (op_2, 0),
-//   ...
-//   (op_7, 1),
-//   (op_8, 1), ...]
-//
-// The union-find algorithm is run on each partition (the list of ops with the
-// same partition index).
-//
-// For each partition, the input to the union-find algorithm is a list of
-// ops with the same partition index. For example,
-//
-// [ op_0, op_1, op_2, op_3, op_4, op_5 ...]
-//
-// The output of the union-find algorithm is a list of lists of ops, where
-// each list is a disjoint set (a sub-DAG within the original Subgraph). For
-// example,
-//
-// [ [op_0, op_1, op_6],
-//   [op_2, op_3],
-//   [op_4, op_5] ... ]
-//
-// Similarly, the algorithm on the next partition would return something like
-//
-// [ [op_7, op_8, op_9],
-//   [op_10, op_11],
-//   [op_12, op_13] ... ]
-//
-// We aggregate all disjoint sets into the result buckets. For example,
-//
-// [ [op_0, op_1, op_6],
-//   [op_2, op_3],
-//   [op_4, op_5],
-//   [op_7, op_8, op_9],
-//   [op_10, op_11],
-//   [op_12, op_13] ... ]
-//
-// (A self-contained sketch of this union-find appears after algo.h below.)
-//===----------------------------------------------------------------------===//
-std::vector<std::vector<LiteRtOp>> DisjointSets::GetPartitionsFromFlatList(
-    const std::vector<LiteRtOpWithPartitionIndex>& flat_op_list) {
-  // Find all unique partition indices. Use the unique partition index as key
-  // and store the ops for each partition index as the value of the map.
-  absl::flat_hash_map<LiteRtParamIndex, std::vector<LiteRtOp>> partition_map;
-  for (int i = 0; i < flat_op_list.size(); ++i) {
-    partition_map[flat_op_list[i].second].push_back(flat_op_list[i].first);
-  }
-
-  // A vector of disjoint sets, where each partition contains ops with the
-  // same partition index.
-  std::vector<DisjointSets> partitions;
-
-  // A vector of all unique partition indices for iterative access. We keep
-  // this vector so that partition indices returned by the vendor plugin do
-  // not have to be zero-based.
-  std::vector<LiteRtParamIndex> flat_partition_indices;
-  for (auto& partition_index : partition_map) {
-    flat_partition_indices.push_back(partition_index.first);
-  }
-
-  // Resize the partitions vector to the number of unique partition indices.
-  partitions.resize(flat_partition_indices.size());
-
-  // Resulting buckets of the union-find algorithm.
-  std::vector<std::vector<LiteRtOp>> all_buckets;
-
-  // Run the union-find algorithm on each partition.
-  for (int i = 0; i < flat_partition_indices.size(); ++i) {
-    // For each partition, initialize the disjoint sets.
-    for (auto* op : partition_map[flat_partition_indices[i]]) {
-      partitions[i].map_.InsertOrAssign(op, op);
-    }
-    // For each partition, find all disjoint sets.
-    for (auto* op : partition_map[flat_partition_indices[i]]) {
-      for (auto* output : op->Outputs()) {
-        for (auto* user : output->Users()) {
-          if (!partitions[i].map_.Contains(user)) {
-            continue;
-          }
-          partitions[i].Insert(op, user);
-        }
-      }
-    }
-    // Aggregate all disjoint sets into the result buckets.
-    for (auto& bucket : partitions[i].GetBuckets()) {
-      all_buckets.push_back(std::move(bucket));
-    }
-  }
-  return all_buckets;
-}
-
-void DisjointSets::Insert(LiteRtOp op, LiteRtOp parent) {
-  auto* parent_bucket = GetBucket(parent);
-  auto* op_bucket = GetBucket(op);
-  if (op_bucket == parent_bucket) {
-    return;
-  }
-  map_.InsertOrAssign(op_bucket, parent_bucket);
-}
-
-// Get all disjoint sets.
-std::vector<std::vector<LiteRtOp>> DisjointSets::GetBuckets() {
-  // NOLINTBEGIN
-  std::unordered_map<LiteRtOp, std::vector<LiteRtOp>> invert_map;
-  // NOLINTEND
-  for (auto it = map_.Begin(); it != map_.End(); ++it) {
-    auto* bucket = GetBucket(it->first);
-
-    if (invert_map.find(bucket) == invert_map.end()) {
-      invert_map.insert_or_assign(bucket, std::vector<LiteRtOp>{});
-    }
-
-    invert_map[bucket].push_back(it->first);
-  }
-
-  std::vector<std::vector<LiteRtOp>> res;
-  res.reserve(invert_map.size());
-
-  for (auto& entry : invert_map) {
-    res.push_back(std::move(entry.second));
-  }
-
-  return res;
-}
-
-// Gets the pointer which serves as the key for the given op's bucket.
-// Collapses paths to amortize.
-LiteRtOp DisjointSets::GetBucket(LiteRtOp op) {
-  auto it = map_.Find(op);
-  auto* parent = it->get().second;
-  if (op != parent) {
-    parent = GetBucket(parent);
-    map_.InsertOrAssign(op, parent);
-  }
-  return parent;
-}
-
-//
-// slice partitions out of a subgraph (into new subgraphs)
-//===----------------------------------------------------------------------===//
-
-class GraphSlicer {
- public:
-  // Slices "partition" from "root" into the empty subgraph "slice". Assumes
-  // the partition is a valid sub-DAG, and replaces it with a single
-  // tfl.custom_op in "root". A reference to that op is returned.
-  static LiteRtOp SlicePartitionFromGraph(LiteRtSubgraphT& root,
-                                          LiteRtSubgraph slice,
-                                          std::vector<LiteRtOp>& partition);
-
- private:
-  explicit GraphSlicer(LiteRtSubgraph slice) : slice_(slice) {}
-
-  void CloneInto(const LiteRtOpT& op);
-
-  void RerouteTensorsThroughCustomOp(const LiteRtSubgraphT& root);
-
-  LiteRtSubgraph slice_;
-  // Maps tensors in the old subgraph to tensors in the new subgraph.
-  InsertOrderMap<LiteRtTensor, LiteRtTensor> tensor_map_;
-  LiteRtOp dispatch_op_ = nullptr;
-};
-
-LiteRtOp GraphSlicer::SlicePartitionFromGraph(
-    LiteRtSubgraphT& root, LiteRtSubgraph slice,
-    std::vector<LiteRtOp>& partition) {
-  GraphSlicer slicer(slice);
-
-  // Register input tensors of the sliced partition with respect to their
-  // original order in the root subgraph. This ensures the order of input
-  // tensors of the later outlined custom op is the same as the order of
-  // input tensors of the GraphInputs.
-  absl::flat_hash_set<LiteRtTensor> used_tensors;
-
-  // Get all tensors used in the partition.
-  for (auto* op : partition) {
-    used_tensors.insert(op->Inputs().cbegin(), op->Inputs().cend());
-  }
-  for (auto* old_input : root.Inputs()) {
-    if (used_tensors.contains(old_input)) {
-      auto* new_input = &MakeClone(*slicer.slice_, *old_input);
-      slicer.slice_->Inputs().push_back(new_input);
-      slicer.tensor_map_.InsertOrAssign(old_input, new_input);
-    }
-  }
-
-  for (auto* op : partition) {
-    slicer.CloneInto(*op);
-  }
-
-  for (auto* op : partition) {
-    Drop(*op);
-  }
-
-  // Reuse the storage from the last op in the partition to maintain
-  // topological order.
-  slicer.dispatch_op_ = partition.back();
-
-  ABSL_DCHECK(slicer.dispatch_op_->Inputs().empty());
-  ABSL_DCHECK(slicer.dispatch_op_->Outputs().empty());
-  MakeDispatchOp(*slicer.dispatch_op_);
-  slicer.RerouteTensorsThroughCustomOp(root);
-
-  DCE(root);
-
-  return slicer.dispatch_op_;
-}
-
-void GraphSlicer::RerouteTensorsThroughCustomOp(const LiteRtSubgraphT& root) {
-  for (auto it = tensor_map_.Begin(); it != tensor_map_.End(); ++it) {
-    auto* old_tensor = it->first;
-    auto* new_tensor = it->second;
-
-    // Reroute tensors which need to be passed into the scope of the new
-    // subgraph to inputs of the custom op.
-    if (new_tensor->DefiningOp() == nullptr && !IsConstant(*new_tensor)) {
-      AttachInput(old_tensor, *dispatch_op_);
-      continue;
-    }
-
-    // Reroute the custom op as the definer of tensors within the removed
-    // partition that are referenced later in the root graph.
-    if ((!old_tensor->Users().empty() && !IsConstant(*old_tensor)) ||
-        FindOutput(root, *old_tensor)) {
-      AttachOutput(old_tensor, *dispatch_op_);
-      slice_->Outputs().push_back(new_tensor);
-    }
-  }
-}
-
-void GraphSlicer::CloneInto(const LiteRtOpT& old_op) {
-  auto& new_op = MakeClone(*slice_, old_op);
-
-  for (auto i = 0; i < old_op.NumInputs(); ++i) {
-    auto* old_input = old_op.Inputs().at(i);
-    LiteRtTensor new_input;
-    if (tensor_map_.Contains(old_input)) {
-      // If old_input is already in the map, then map[old_input] is its cloned
-      // counterpart in the new graph.
-      auto it = tensor_map_.Find(old_input);
-      new_input = it->get().second;
-    } else {
-      // Otherwise, it must be a new subgraph input (or constant).
-      new_input = &MakeClone(*slice_, *old_input);
-      if (!IsConstant(*new_input)) {
-        slice_->Inputs().push_back(new_input);
-      }
-
-      tensor_map_.InsertOrAssign(old_input, new_input);
-    }
-
-    AttachInput(new_input, new_op);
-  }
-
-  for (int i = 0; i < old_op.NumOutputs(); ++i) {
-    auto* old_output = old_op.Outputs().at(i);
-    auto* new_output = &MakeClone(*slice_, *old_output);
-    AttachOutput(new_output, new_op);
-
-    // Update the values defined in the scope of the new subgraph.
-    tensor_map_.InsertOrAssign(old_output, new_output);
-  }
-}
-
-}  // namespace
-
-std::vector<std::vector<LiteRtOp>> GroupPartitions(
-    const std::vector<LiteRtOpWithPartitionIndex>& ops) {
-  return DisjointSets::GetPartitionsFromFlatList(ops);
-}
-
-LiteRtOp OutlinePartition(LiteRtSubgraphT& root, LiteRtSubgraph slice,
-                          std::vector<LiteRtOp>& partition) {
-  return GraphSlicer::SlicePartitionFromGraph(root, slice, partition);
-}
-
-}  // namespace litert::internal
diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/algo.h b/tensorflow/lite/experimental/litert/compiler/plugin/algo.h
deleted file mode 100644
index 8f82ca33ba0ffe..00000000000000
--- a/tensorflow/lite/experimental/litert/compiler/plugin/algo.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_ALGO_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_ALGO_H_
-
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/core/model/model.h"
-
-namespace litert::internal {
-
-// Identifies sub-DAGs of ops connected w.r.t. the use-def chain. Expects
-// all "ops" to belong to the same Subgraph. The ops in the input
-// and output will always be the same.
-std::vector<std::vector<LiteRtOp>> GroupPartitions(
-    const std::vector<LiteRtOpWithPartitionIndex>& ops);
-
-// Outlines "partition" from "root" into the empty subgraph "slice". Assumes
-// the partition is a valid sub-DAG, and replaces it with a single
-// tfl.custom_op in "root". A reference to that op is returned.
-LiteRtOp OutlinePartition(LiteRtSubgraphT& root, LiteRtSubgraph slice,
-                          std::vector<LiteRtOp>& partition);
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_ALGO_H_
diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc b/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc
deleted file mode 100644
index f756f649520c77..00000000000000
--- a/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc
+++ /dev/null
@@ -1,304 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
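// The DisjointSets class in algo.cc above is a textbook union-find. Below is
// a self-contained sketch of the same technique over plain ints (standing in
// for the LiteRtOp pointers used in the real code); it is an illustration of
// the algorithm, not the shipped implementation:
//
//   #include <numeric>
//   #include <vector>
//
//   struct UnionFind {
//     std::vector<int> parent;
//     explicit UnionFind(int n) : parent(n) {
//       std::iota(parent.begin(), parent.end(), 0);  // each node is its own root
//     }
//     int Find(int x) {
//       // Path halving keeps the trees shallow, amortizing later lookups.
//       while (parent[x] != x) x = parent[x] = parent[parent[x]];
//       return x;
//     }
//     void Union(int a, int b) { parent[Find(a)] = Find(b); }
//   };
//
//   // Ops 0 and 1 connected by a use-def edge, op 3 not connected, gives two
//   // buckets: after `UnionFind uf(4); uf.Union(0, 1);`, Find(0) == Find(1)
//   // while Find(3) stays in its own set.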
-
-#include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h"
-
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/container/flat_hash_set.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h"
-#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h"
-#include "tensorflow/lite/experimental/litert/core/model/model.h"
-#include "tensorflow/lite/experimental/litert/test/common.h"
-
-namespace litert::internal {
-namespace {
-
-TEST(TestPartitionsFromFlatList, SimpleMultiOp) {
-  auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite");
-  auto subgraph = model.MainSubgraph();
-  EXPECT_TRUE(subgraph);
-
-  auto ops = subgraph->Ops();
-
-  // func.func @main(arg0)
-  // 0 = tfl.add arg0, arg0
-  // 1 = tfl.mul 0, 0
-  // 2 = tfl.mul 1, 1
-  // 3 = tfl.add 2, 2
-  // return 3
-
-  {
-    std::vector<LiteRtOpWithPartitionIndex> selected_ops;
-    selected_ops.push_back({ops.at(1).Get(), 0});
-    selected_ops.push_back({ops.at(2).Get(), 0});
-
-    auto partitions = GroupPartitions(selected_ops);
-    ASSERT_EQ(partitions.size(), 1);
-    ASSERT_EQ(partitions.front().size(), 2);
-
-    EXPECT_EQ(partitions.front().at(0), selected_ops.at(0).first);
-    EXPECT_EQ(partitions.front().at(1), selected_ops.at(1).first);
-  }
-
-  {
-    std::vector<LiteRtOpWithPartitionIndex> selected_ops;
-    selected_ops.push_back({ops.at(1).Get(), 0});
-    selected_ops.push_back({ops.at(3).Get(), 0});
-
-    auto partitions = GroupPartitions(selected_ops);
-    ASSERT_EQ(partitions.size(), 2);
-    ASSERT_EQ(partitions.front().size(), 1);
-    ASSERT_EQ(partitions.back().size(), 1);
-
-    auto p1_op_code = partitions.front().front()->OpCode();
-    auto p2_op_code = partitions.back().front()->OpCode();
-
-    ASSERT_TRUE((p1_op_code == kLiteRtOpCodeTflMul &&
-                 p2_op_code == kLiteRtOpCodeTflAdd) ||
-                (p1_op_code == kLiteRtOpCodeTflAdd &&
-                 p2_op_code == kLiteRtOpCodeTflMul));
-  }
-
-  {
-    std::vector<LiteRtOpWithPartitionIndex> selected_ops;
-
-    auto partitions = GroupPartitions(selected_ops);
-    ASSERT_EQ(partitions.size(), 0);
-  }
-
-  {
-    std::vector<LiteRtOpWithPartitionIndex> selected_ops;
-    selected_ops.push_back({ops.at(0).Get(), 0});
-    selected_ops.push_back({ops.at(1).Get(), 0});
-    selected_ops.push_back({ops.at(2).Get(), 0});
-    selected_ops.push_back({ops.at(3).Get(), 0});
-
-    auto partitions = GroupPartitions(selected_ops);
-    ASSERT_EQ(partitions.size(), 1);
-    ASSERT_EQ(partitions.front().size(), 4);
-
-    EXPECT_EQ(partitions.front().at(0), selected_ops.at(0).first);
-    EXPECT_EQ(partitions.front().at(1), selected_ops.at(1).first);
-    EXPECT_EQ(partitions.front().at(2), selected_ops.at(2).first);
-    EXPECT_EQ(partitions.front().at(3), selected_ops.at(3).first);
-  }
-}
-
-TEST(TestSliceSubgraphSimpleMultiOp, OnePartition) {
-  auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite");
-  auto subgraph = model.MainSubgraph();
-  EXPECT_TRUE(subgraph);
-
-  auto ops = subgraph->Ops();
-
-  // func.func @main(arg0)
-  // 0 = tfl.add arg0, arg0
-  // 1 = tfl.mul 0, 0
-  // 2 = tfl.mul 1, 1
-  // 3 = tfl.add 2, 2
-  // return 3
-
-  std::vector<LiteRtOp> partition;
-  partition.push_back(ops.at(1).Get());
-  partition.push_back(ops.at(2).Get());
-
-  auto sliced_graph = litert::Subgraph(&model.Get()->EmplaceSubgraph());
-  auto* dispatch_op =
-      OutlinePartition(*subgraph->Get(), sliced_graph.Get(), partition);
-
-  const auto& internal_sliced = *sliced_graph.Get();
-  ASSERT_TRUE(ValidateSubgraphIO(internal_sliced));
ASSERT_TRUE(ValidateLocalTopology(internal_sliced.Ops().cbegin(), - internal_sliced.Ops().cend())); - - auto edited_subgraph_ops = subgraph->Ops(); - - ASSERT_EQ(edited_subgraph_ops.size(), 3); - ASSERT_EQ(edited_subgraph_ops.at(0).Code(), kLiteRtOpCodeTflAdd); - ASSERT_EQ(edited_subgraph_ops.at(1).Code(), kLiteRtOpCodeTflCustom); - ASSERT_EQ(edited_subgraph_ops.at(2).Code(), kLiteRtOpCodeTflAdd); - - auto sliced_subgraph_ops = sliced_graph.Ops(); - - ASSERT_EQ(sliced_subgraph_ops.size(), 2); - ASSERT_EQ(sliced_subgraph_ops[0].Code(), kLiteRtOpCodeTflMul); - ASSERT_EQ(sliced_subgraph_ops[1].Code(), kLiteRtOpCodeTflMul); - - ASSERT_EQ(dispatch_op, edited_subgraph_ops.at(1).Get()); - const Op hal_call(dispatch_op); - - { - const auto dispatch_op_ins = hal_call.Inputs(); - - ASSERT_EQ(dispatch_op_ins.size(), 1); - - auto hal_input_defining_op = dispatch_op_ins.front().DefiningOp(); - ASSERT_EQ(hal_input_defining_op->op, edited_subgraph_ops.at(0).Get()); - ASSERT_EQ(hal_input_defining_op->op_output_index, 0); - - const auto sliced_subgraph_inputs = sliced_graph.Inputs(); - - ASSERT_EQ(sliced_subgraph_inputs.size(), 1); - - ASSERT_TRUE(MatchUses(sliced_subgraph_inputs.front(), - {UseInfo{sliced_subgraph_ops.front().Code(), 0}, - UseInfo{sliced_subgraph_ops.front().Code(), 0}})); - ASSERT_TRUE(sliced_subgraph_inputs.front().IsSubgraphInput()); - } - - { - const auto hal_call_outs = hal_call.Outputs(); - ASSERT_EQ(hal_call_outs.size(), 1); - const auto& hal_call_out = hal_call_outs.front(); - - ASSERT_TRUE(MatchUses(hal_call_out, - {UseInfo{edited_subgraph_ops.back().Code(), 0}, - UseInfo{edited_subgraph_ops.back().Code(), 1}})); - - auto sliced_subgraph_outputs = sliced_graph.Outputs(); - - ASSERT_EQ(sliced_subgraph_outputs.size(), 1); - - const auto defining_op = sliced_subgraph_outputs.front().DefiningOp(); - ASSERT_EQ(defining_op->op, sliced_subgraph_ops.back().Get()); - ASSERT_EQ(defining_op->op_output_index, 0); - - ASSERT_TRUE(sliced_subgraph_outputs.front().Uses().empty()); - } -} - -TEST(TestSliceSubgraphSimpleMultiOp, TwoPartitions) { - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - - // func.func @main(arg0) - // 0 = tfl.add arg0, arg0 - // 1 = tfl.mul 0, 0 - // 2 = tfl.mul 1, 1 - // 3 = tfl.add 2, 2 - // return 3 - - std::vector partition_1; - partition_1.push_back(ops.at(0).Get()); - - auto sliced_graph_1 = litert::Subgraph(&model.Get()->EmplaceSubgraph()); - OutlinePartition(*(subgraph->Get()), sliced_graph_1.Get(), partition_1); - - const auto& internal_slice_1 = *sliced_graph_1.Get(); - ASSERT_TRUE(ValidateSubgraphIO(internal_slice_1)); - ASSERT_TRUE(ValidateLocalTopology(internal_slice_1.Ops().cbegin(), - internal_slice_1.Ops().cend())); - - std::vector partition_2; - partition_2.push_back(ops.at(2).Get()); - partition_2.push_back(ops.at(3).Get()); - - auto sliced_graph_2 = litert::Subgraph(&model.Get()->EmplaceSubgraph()); - OutlinePartition(*(subgraph->Get()), sliced_graph_2.Get(), partition_2); - - const auto& internal_slice_2 = *sliced_graph_2.Get(); - ASSERT_TRUE(ValidateSubgraphIO(internal_slice_2)); - ASSERT_TRUE(ValidateLocalTopology(internal_slice_2.Ops().cbegin(), - internal_slice_2.Ops().cend())); - - auto edited_subgraph_ops = subgraph->Ops(); - - ASSERT_EQ(edited_subgraph_ops.size(), 3); - ASSERT_EQ(edited_subgraph_ops.at(0).Code(), kLiteRtOpCodeTflCustom); - ASSERT_EQ(edited_subgraph_ops.at(1).Code(), kLiteRtOpCodeTflMul); - 
ASSERT_EQ(edited_subgraph_ops.at(2).Code(), kLiteRtOpCodeTflCustom); - - { - auto sliced_ops = sliced_graph_1.Ops(); - - ASSERT_EQ(sliced_ops.size(), 1); - ASSERT_EQ(sliced_ops.at(0).Code(), kLiteRtOpCodeTflAdd); - } - - { - auto sliced_ops = sliced_graph_2.Ops(); - - ASSERT_EQ(sliced_ops.size(), 2); - ASSERT_EQ(sliced_ops.at(0).Code(), kLiteRtOpCodeTflMul); - ASSERT_EQ(sliced_ops.at(1).Code(), kLiteRtOpCodeTflAdd); - } -} - -TEST(TestSliceSubgraphSimpleMultiOp, PartitionWithIndex) { - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - - // func.func @main(arg0) - // 0 = tfl.add arg0, arg0 - // 1 = tfl.mul 0, 0 - // 2 = tfl.mul 1, 1 - // 3 = tfl.add 2, 2 - // return 3 - - { - std::vector selected_ops; - selected_ops.push_back({ops.at(1).Get(), 0}); - selected_ops.push_back({ops.at(2).Get(), 1}); - - auto partitions = GroupPartitions(selected_ops); - ASSERT_EQ(partitions.size(), 2); - ASSERT_EQ(partitions.front().size(), 1); - ASSERT_EQ(partitions.back().size(), 1); - - absl::flat_hash_set ops_in_partition; - for (int i = 0; i < partitions.size(); ++i) { - for (const auto& op : partitions.at(i)) { - ops_in_partition.insert(op); - } - } - for (int i = 0; i < partitions.size(); ++i) { - EXPECT_TRUE(ops_in_partition.contains(selected_ops.at(i).first)); - } - } - - { - std::vector selected_ops; - selected_ops.push_back({ops.at(0).Get(), 1}); - selected_ops.push_back({ops.at(1).Get(), 2}); - selected_ops.push_back({ops.at(2).Get(), 3}); - selected_ops.push_back({ops.at(3).Get(), 4}); - - auto partitions = GroupPartitions(selected_ops); - ASSERT_EQ(partitions.size(), 4); - - absl::flat_hash_set ops_in_partition; - for (int i = 0; i < partitions.size(); ++i) { - for (const auto& op : partitions.at(i)) { - ops_in_partition.insert(op); - } - } - for (int i = 0; i < partitions.size(); ++i) { - EXPECT_TRUE(ops_in_partition.contains(selected_ops.at(i).first)); - } - } -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.cc deleted file mode 100644 index 2cc0fea5f0909f..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
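[Editor's note] The PartitionWithIndex cases above pin down a subtle contract: ops tagged with different partition indices are never grouped together, even when they are adjacent in the use-def chain. Compressed into a sketch (the op handles here are hypothetical, with the same assumed pair type as the earlier sketch):

// Editorial sketch: adjacent ops with distinct indices yield distinct partitions.
std::vector<std::pair<LiteRtOp, LiteRtParamIndex>> selected = {
    {mul_op, /*partition_index=*/0},       // hypothetical handle
    {next_mul_op, /*partition_index=*/1},  // directly consumes mul_op's output
};
auto partitions = litert::internal::GroupPartitions(selected);
// partitions.size() == 2: the differing indices split what would otherwise
// be a single connected island.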
- -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" - -#include -#include -#include -#include - -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -namespace { -static constexpr absl::string_view kPairChar = "="; -static constexpr absl::string_view kDelim = ","; -} // namespace - -namespace litert::internal { - -void CompilerFlags::Clear() { - keys_.clear(); - values_.clear(); -} - -void CompilerFlags::Push(std::string key, std::string value) { - keys_.push_back(std::move(key)); - values_.push_back(std::move(value)); -} - -LiteRtStatus CompilerFlags::SetPluginFlags( - LiteRtCompilerPlugin handle, - decltype(LiteRtCompilerPluginSetFlags) set_flags) const { - std::vector keys(keys_.size()); - std::vector values(values_.size()); - for (auto i = 0; i < keys_.size(); ++i) { - keys[i] = keys_[i].c_str(); - values[i] = values_[i].c_str(); - } - return set_flags(handle, keys.size(), keys.data(), values.data()); -} - -Expected ParseCompilerFlags(absl::string_view flags_str) { - using KeyVal = std::pair; - - CompilerFlags result; - if (flags_str.empty()) { - return result; - } - - for (const auto flag : absl::StrSplit(flags_str, kDelim)) { - KeyVal key_value = absl::StrSplit(flag, absl::MaxSplits(kPairChar, 1)); - result.Push(std::move(key_value.first), std::move(key_value.second)); - } - - return result; -} - -} // namespace litert::internal - -std::ostream& operator<<(std::ostream& os, - const litert::internal::CompilerFlags& flags) { - for (auto i = 0; i < flags.keys_.size(); ++i) { - os << flags.keys_[i]; - const auto& value = flags.values_[i]; - if (!value.empty()) { - os << kPairChar << value; - } - if (i < flags.keys_.size() - 1) { - os << kDelim; - } - } - return os; -} diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h deleted file mode 100644 index 403ff1db527fa9..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_FLAGS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_FLAGS_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -namespace litert::internal { -class CompilerFlags; -} - -// For logging. 
-std::ostream& operator<<(std::ostream& os,
-                         const litert::internal::CompilerFlags& flags);
-
-namespace litert::internal {
-
-class CompilerFlags {
- public:
-  CompilerFlags() = default;
-
-  // Clears all flags.
-  void Clear();
-
-  // Pushes a new flag to the end of the list.
-  void Push(std::string key, std::string value = "");
-
-  // Sets the flags on the given plugin.
-  LiteRtStatus SetPluginFlags(
-      LiteRtCompilerPlugin handle,
-      decltype(LiteRtCompilerPluginSetFlags) set_flags) const;
-
- private:
-  friend std::ostream& ::operator<<(std::ostream& os,
-                                    const CompilerFlags& flags);
-
-  std::vector keys_;
-  std::vector values_;
-};
-
-// Parses a comma-separated (no space) list of compiler flags. Flags may be
-// key-value pairs in the format of "key=value", or just "key". E.g.
-// "key1=value1,key2".
-Expected ParseCompilerFlags(absl::string_view flags_str);
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_FLAGS_H_
diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags_test.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags_test.cc
deleted file mode 100644
index 0fcfdd72c52740..00000000000000
--- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags_test.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Utility types for mapping LiteRt IR to arbitrary backend specific
-// types. Implementations of these types define mapping for ops and tensors
-// that may be used in a standalone fashion. They also may be composed
-// to create lowerings of entire graphs with topology.
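[Editor's note] Putting the header above to work: a small sketch of the intended round trip, from the textual flag syntax through ParseCompilerFlags and into a plugin. The flag names are made up, and `plugin_handle` is assumed to come from an already-loaded CompilerPlugin.

// Editorial sketch of the documented "key=value,key" syntax.
auto flags = litert::internal::ParseCompilerFlags("precision=fp16,verbose");
if (flags) {
  // Forwards the parsed keys/values across the C ABI. In real usage the
  // setter is resolved from the vendor shared library; here we pass the
  // symbol directly.
  flags->SetPluginFlags(plugin_handle, LiteRtCompilerPluginSetFlags);
}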
-
-#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h"
-
-#include
-#include
-#include
-
-#include
-#include
-#include "absl/strings/str_cat.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/test/matchers.h"
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h"
-
-struct LiteRtCompilerPluginT {
-  using Flag = std::pair;
-  std::vector flags;
-};
-
-LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin,
-                                          LiteRtParamIndex num_flags,
-                                          const char** keys,
-                                          const char** values) {
-  auto& flags = compiler_plugin->flags;
-  flags.resize(num_flags);
-  for (int i = 0; i < num_flags; ++i) {
-    auto& flag = flags[i];
-    flag.first = std::string(keys[i]);
-    flag.second = std::string(values[i]);
-  }
-  return kLiteRtStatusOk;
-}
-
-namespace litert::internal {
-namespace {
-
-using ::testing::ElementsAre;
-using ::testing::Pair;
-
-TEST(CompilerFlagsTest, SetPluginFlags) {
-  static constexpr const char* kKey1 = "key1";
-  static constexpr const char* kKey2 = "key2";
-  static constexpr const char* kKey3 = "key3";
-  static constexpr const char* kValue1 = "value1";
-  static constexpr const char* kEmptyVal = "";
-
-  LiteRtCompilerPluginT plugin;
-  CompilerFlags flags;
-  flags.Push(kKey1, kValue1);
-  flags.Push(kKey2, kEmptyVal);
-  flags.Push(kKey3);
-  LITERT_ASSERT_OK(flags.SetPluginFlags(&plugin, LiteRtCompilerPluginSetFlags));
-
-  EXPECT_THAT(plugin.flags,
-              ElementsAre(Pair(kKey1, kValue1), Pair(kKey2, kEmptyVal),
-                          Pair(kKey3, kEmptyVal)));
-}
-
-TEST(CompilerFlagsTest, ParseCompilerFlags) {
-  static constexpr const char* kKey1 = "key1";
-  static constexpr const char* kKey2 = "key2";
-  static constexpr const char* kKey3 = "key3";
-  static constexpr const char* kValue1 = "value1";
-  static constexpr const char* kEmptyVal = "";
-
-  const auto flags_str =
-      absl::StrCat(kKey1, "=", kValue1, ",", kKey2, "=", kEmptyVal, ",", kKey3);
-
-  LiteRtCompilerPluginT plugin;
-  // Build the flags by parsing the string under test, then forward them to
-  // the stub plugin.
-  auto flags = ParseCompilerFlags(flags_str);
-  ASSERT_TRUE(flags);
-  LITERT_ASSERT_OK(
-      flags->SetPluginFlags(&plugin, LiteRtCompilerPluginSetFlags));
-
-  EXPECT_THAT(plugin.flags,
-              ElementsAre(Pair(kKey1, kValue1), Pair(kKey2, kEmptyVal),
-                          Pair(kKey3, kEmptyVal)));
-}
-
-}  // namespace
-}  // namespace litert::internal
diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc
deleted file mode 100644
index d593840081e293..00000000000000
--- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc
+++ /dev/null
@@ -1,649 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
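[Editor's note] The stub above doubles as documentation of the vendor-side contract: flags arrive as parallel `const char*` arrays, with empty strings (never null) for value-less flags. A vendor plugin might consume them as sketched below; `ConfigureBackendOption` is a hypothetical helper, not part of the API.

// Editorial sketch of a vendor-side LiteRtCompilerPluginSetFlags.
void ConfigureBackendOption(LiteRtCompilerPlugin plugin, const char* key,
                            const char* value);  // hypothetical
LiteRtStatus MyVendorSetFlags(LiteRtCompilerPlugin plugin,
                              LiteRtParamIndex num_flags, const char** keys,
                              const char** values) {
  for (LiteRtParamIndex i = 0; i < num_flags; ++i) {
    // values[i] is "" for flags pushed without a value (e.g. "key3").
    ConfigureBackendOption(plugin, keys[i], values[i]);
  }
  return kLiteRtStatusOk;
}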
- -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/absl_check.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/core/version.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h" - -namespace litert::internal { - -// -// CompiledResult -// - -Expected> CompiledResult::ByteCode( - LiteRtParamIndex byte_code_idx) const { - const void* data; - size_t size; - LITERT_RETURN_IF_ERROR(parent_.get_compiled_result_byte_code( - compiled_result_handle_, byte_code_idx, &data, &size)); - return BufferRef(data, size); -} - -Expected CompiledResult::NumByteCodeModules() const { - LiteRtParamIndex byte_code_idx; - LITERT_RETURN_IF_ERROR(parent_.get_compiled_result_num_byte_code( - compiled_result_handle_, &byte_code_idx)); - return byte_code_idx; -} - -Expected CompiledResult::NumCalls() const { - LiteRtParamIndex num_calls; - LITERT_RETURN_IF_ERROR(parent_.get_compiled_result_num_calls( - compiled_result_handle_, &num_calls)); - return num_calls; -} - -Expected CompiledResult::CallInfo(LiteRtParamIndex call_idx) const { - const void* data; - size_t size; - LiteRtParamIndex byte_code_idx; - - LITERT_RETURN_IF_ERROR(parent_.get_compiled_result_call_info( - compiled_result_handle_, call_idx, &data, &size, &byte_code_idx)); - - absl::string_view call_info_str(reinterpret_cast(data), size); - return ::litert::internal::CallInfo(call_info_str, byte_code_idx); -} - -CompiledResult::~CompiledResult() { - if (compiled_result_handle_ != nullptr) { - parent_.destroy_compiled_result(compiled_result_handle_); - } -} - -CompiledResult::CompiledResult(CompiledResult&& other) - : 
parent_(other.parent_), - compiled_result_handle_(other.compiled_result_handle_) { - other.parent_ = {}; - other.compiled_result_handle_ = nullptr; -} - -CompiledResult& CompiledResult::operator=(CompiledResult&& other) { - if (this != &other) { - parent_ = other.parent_; - other.parent_ = {}; - - compiled_result_handle_ = other.compiled_result_handle_; - other.compiled_result_handle_ = nullptr; - } - return *this; -} - -// -// CompilerPlugin -// - -namespace { - -#define RESOLVE_API_FUNC(name, dest) \ - LITERT_ASSIGN_OR_RETURN(dest, lib.LookupSymbol(name.data())); - -LiteRtStatus ResolvePluginApi(SharedLibrary& lib, - LiteRtCompilerPluginApi& result) { - RESOLVE_API_FUNC(kLiteRtGetCompilerPluginVersion, - result.get_compiler_plugin_version); - RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSupportedHardware, - result.get_compiler_plugin_supported_hardware); - RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSocManufacturer, - result.get_compiler_plugin_soc_manufacturer); - RESOLVE_API_FUNC(kLiteRtGetNumCompilerPluginSupportedSocModels, - result.get_num_compiler_plugin_supported_models); - RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSupportedSocModel, - result.get_compiler_plugin_supported_soc_model); - - RESOLVE_API_FUNC(kLiteRtCreateCompilerPlugin, result.create_compiler_plugin); - RESOLVE_API_FUNC(kLiteRtDestroyCompilerPlugin, - result.destroy_compiler_plugin); - - RESOLVE_API_FUNC(kLiteRtCompilerPluginPartition, - result.compiler_plugin_partition); - RESOLVE_API_FUNC(kLiteRtCompilerPluginCompile, - result.compiler_plugin_compile); - - RESOLVE_API_FUNC(kLiteRtDestroyCompiledResult, - result.destroy_compiled_result); - RESOLVE_API_FUNC(kLiteRtCompiledResultNumByteCodeModules, - result.get_compiled_result_num_byte_code); - RESOLVE_API_FUNC(kLiteRtGetCompiledResultByteCode, - result.get_compiled_result_byte_code); - RESOLVE_API_FUNC(kLiteRtGetCompiledResultCallInfo, - result.get_compiled_result_call_info); - RESOLVE_API_FUNC(kLiteRtGetNumCompiledResultCalls, - result.get_compiled_result_num_calls); - RESOLVE_API_FUNC(kLiteRtCompilerPluginSetFlags, result.set_flags); - - return kLiteRtStatusOk; -} - -Expected> GetSocModels( - const LiteRtCompilerPluginApi& api, LiteRtCompilerPlugin plugin_handle) { - std::vector soc_models; - - LiteRtParamIndex num_models; - LITERT_RETURN_IF_ERROR( - api.get_num_compiler_plugin_supported_models(plugin_handle, &num_models)); - - for (LiteRtParamIndex i = 0; i < num_models; ++i) { - const char* model; - if (api.get_compiler_plugin_supported_soc_model(plugin_handle, i, &model) != - kLiteRtStatusOk) { - continue; - } - soc_models.push_back(std::string(model)); - } - - return soc_models; -} - -// Sort plugins so that we first apply those supporting NPU, then those -// supporting GPU, and finally those supporting CPU. 
-void SortPlugins(std::vector& compiler_plugins) {
-  std::sort(compiler_plugins.begin(), compiler_plugins.end(),
-            [](auto& x, auto& y) {
-              auto x_supported_hardware = x.SupportedHardware();
-              auto y_supported_hardware = y.SupportedHardware();
-              if (x_supported_hardware && y_supported_hardware) {
-                bool x_npu = (*x_supported_hardware & kLiteRtHwAcceleratorNpu);
-                bool x_gpu = (*x_supported_hardware & kLiteRtHwAcceleratorGpu);
-                bool x_cpu = (*x_supported_hardware & kLiteRtHwAcceleratorCpu);
-                bool y_npu = (*y_supported_hardware & kLiteRtHwAcceleratorNpu);
-                bool y_gpu = (*y_supported_hardware & kLiteRtHwAcceleratorGpu);
-                bool y_cpu = (*y_supported_hardware & kLiteRtHwAcceleratorCpu);
-                int x_score = 100 * x_npu + 10 * x_gpu + x_cpu;
-                int y_score = 100 * y_npu + 10 * y_gpu + y_cpu;
-                // Higher scores (NPU-capable plugins) sort first, matching
-                // the ordering described in the comment above.
-                return x_score > y_score;
-              }
-              // Treat plugins with unknown capabilities as unordered so the
-              // comparator remains a strict weak ordering.
-              return false;
-            });
-}
-
-}  // namespace
-
-Expected CompilerPlugin::LoadPlugin(
-    const absl::string_view lib_path) {
-  CompilerPlugin plugin;
-  LITERT_LOG(LITERT_INFO, "Loading plugin at: %s", lib_path.data());
-
-  LITERT_ASSIGN_OR_RETURN(
-      plugin.lib_,
-      SharedLibrary::Load(lib_path, RtldFlags::Now().Local().DeepBind()));
-  LITERT_LOG(LITERT_INFO, "Loaded plugin at: %s", lib_path.data());
-
-  LITERT_RETURN_IF_ERROR(ResolvePluginApi(plugin.lib_, plugin.plugin_api_));
-  LITERT_LOG(LITERT_INFO, "Resolved plugin api at: %s", lib_path.data());
-
-  LITERT_RETURN_IF_ERROR(
-      plugin.plugin_api_.create_compiler_plugin(&plugin.plugin_handle_));
-  LITERT_LOG(LITERT_INFO, "Initialize plugin at: %s", lib_path.data());
-
-  auto api_version = plugin.ApiVersion();
-  if (!api_version) {
-    return api_version.Error();
-  }
-
-  LITERT_RETURN_IF_ERROR(litert::internal::IsSameVersionAsRuntime(*api_version),
-                         Unexpected(kLiteRtStatusErrorWrongVersion,
-                                    "Unsupported compiler plugin version"));
-
-  // This should never change throughout the lifetime of the compiler
-  // plugin, so cache it to avoid repeated queries.
-  auto soc_models = GetSocModels(plugin.plugin_api_, plugin.plugin_handle_);
-  if (!soc_models) {
-    return soc_models.Error();
-  }
-  plugin.soc_models_ = *soc_models;
-
-  return plugin;
-}
-
-Expected> CompilerPlugin::LoadPlugins(
-    absl::Span lib_search_paths) {
-  std::vector plugin_lib_paths;
-  for (auto search_path : lib_search_paths) {
-    // Skip paths that are not valid.
-    if (Exists(search_path)) {
-      LITERT_RETURN_IF_ERROR(
-          FindLiteRtCompilerPluginSharedLibs(search_path, plugin_lib_paths));
-    }
-  }
-
-  std::vector loaded_plugins;
-  // Reserve one slot per discovered library (not per search path).
-  loaded_plugins.reserve(plugin_lib_paths.size());
-
-  for (const auto& lib_path : plugin_lib_paths) {
-    LITERT_LOG(LITERT_INFO, "Loading plugin at: %s", lib_path.c_str());
-    auto plugin = LoadPlugin(lib_path);
-    if (!plugin.HasValue()) {
-      continue;
-    }
-    loaded_plugins.push_back(std::move(plugin.Value()));
-  }
-
-  // Sort plugins.
- SortPlugins(loaded_plugins); - - return loaded_plugins; -} - -CompilerPlugin::CompilerPlugin(CompilerPlugin&& other) - : soc_models_(std::move(other.soc_models_)), - lib_(std::move(other.lib_)), - plugin_api_(std::move(other.plugin_api_)), - plugin_handle_(std::move(other.plugin_handle_)) { - other.soc_models_ = {}; - other.plugin_api_ = {}; - other.lib_.Close(); - other.plugin_handle_ = nullptr; -} - -CompilerPlugin& CompilerPlugin::operator=(CompilerPlugin&& other) { - if (this != &other) { - std::swap(soc_models_, other.soc_models_); - std::swap(lib_, other.lib_); - std::swap(plugin_api_, other.plugin_api_); - std::swap(plugin_handle_, other.plugin_handle_); - } - return *this; -} - -CompilerPlugin::~CompilerPlugin() { - if (plugin_handle_ != nullptr) { - plugin_api_.destroy_compiler_plugin(plugin_handle_); - } -} - -std::string CompilerPlugin::DebugString() const { - std::string version_str = "?"; - if (auto version = ApiVersion(); version) { - version_str = absl::StrFormat("%d.%d.%d", version->major, version->minor, - version->patch); - } - return absl::StrFormat("%s compiler plugin (ver %s)", SocManufacturer(), - version_str); -} - -Expected CompilerPlugin::ApiVersion() const { - LiteRtApiVersion api_version; - LITERT_RETURN_IF_ERROR(plugin_api_.get_compiler_plugin_version(&api_version)); - return api_version; -} - -Expected CompilerPlugin::SupportedHardware() const { - LiteRtHwAccelerators supported_hardware; - LITERT_RETURN_IF_ERROR(plugin_api_.get_compiler_plugin_supported_hardware( - plugin_handle_, &supported_hardware)); - return supported_hardware; -} - -Expected> CompilerPlugin::Partition( - const Subgraph& subgraph, absl::string_view soc_model) { - LiteRtOpListT ops; - const char* soc_model_str = !soc_model.empty() ? soc_model.data() : nullptr; - LITERT_RETURN_IF_ERROR(plugin_api_.compiler_plugin_partition( - plugin_handle_, soc_model_str, subgraph.Get(), &ops)); - return ops.Values(); -} - -Expected CompilerPlugin::Compile(LiteRtModel partitions, - absl::string_view soc_model) { - CompiledResult result = MakeResult(); - // If the user has passed an soc_model, then we use it; otherwise we let the - // backend pick the appropriate one by passing nullptr as soc_model. This is - // important for on-device compilation, where the backend must determine the - // SoC model based on the user device. - const char* soc_model_str = !soc_model.empty() ? soc_model.data() : nullptr; - LITERT_RETURN_IF_ERROR(plugin_api_.compiler_plugin_compile( - plugin_handle_, soc_model_str, partitions, - &result.compiled_result_handle_)); - return result; -} - -namespace { - -LiteRtStatus PartitionSubgraph( - std::vector selected_ops, - LiteRtSubgraphT& subgraph, PartitionResult& result, - BufferManager* buffer_manager) { - // Group selected ops into connected islands. - auto islands = GroupPartitions(selected_ops); - if (islands.empty()) { - return kLiteRtStatusOk; - } - - // For each connected island, slice into new subgraph and replace use with - // single dispatch op. - for (auto& island : islands) { - auto& new_subgraph = result.second.EmplaceBack(buffer_manager); - auto* dispatch_op = OutlinePartition(subgraph, &new_subgraph, island); - result.first.push_back(dispatch_op); - } - - return kLiteRtStatusOk; -} - -} // namespace - -Expected PartitionModel( - CompilerPlugin& compiler_plugin, LiteRtModelT& model, - const absl::flat_hash_set& subgraphs_to_partition) { - // This algorithm decides the subgraphs to be partitioned by the plugin. 
This
-  // is a trivial process with the exception of composite ops and their
-  // decomposition subgraphs. Currently, we deploy the most naive approach to
-  // handling composite ops.
-  //
-  // There are two cases to consider:
-  // 1. The composite op is an "odml.npu_call", in which case it represents a
-  // partition which was explicitly requested by the model author.
-  //
-  // In this case, the composite itself is always selected, regardless of
-  // whether the plugin selects it. Its subgraph is not passed to the partition
-  // function and it is passed in its entirety to the compilation function.
-  //
-  // More advanced behavior could include:
-  //   * Ensuring the plugin can compile the entire partition, and inlining it
-  //   if not.
-  //
-  // 2. Standard non npu_call composite ops. Currently these are treated as a
-  // regular op, and their decomposition subgraphs are completely ignored in
-  // all phases of plugin application.
-  //
-  // More advanced behavior could include:
-  //   * Allowing the plugin to compile the decomposition subgraph in the case
-  //   it cannot lower the composite directly. Potentially inline in this case
-  //   contingent on the availability of a suitable CPU kernel for the
-  //   composite op.
-  //
-  // ASSUMPTIONS:
-  //   * npu_call ops ARE NOT nested within decompositions of other npu_call
-  //   ops.
-  //   * Standard composite ops ARE allowed to be nested within decompositions
-  //   of npu_call ops.
-  //   * No two npu_call ops share the same subgraph.
-
-  // Find decomposition subgraphs and npu_call ops. These will be used to
-  // filter subgraphs passed to the plugin and pass on auto-selected npu_call
-  // partitions.
-  absl::flat_hash_set decomp_subgraphs;
-  std::vector npu_calls;
-
-  ForEachIr(&model, [&](LiteRtOp op) {
-    auto info = GetOptionsAs(op);
-    if (!info) {
-      return;
-    }
-    decomp_subgraphs.insert(info->subgraph);
-    if (info->name == CompositeOptions::kNpuCall) {
-      npu_calls.push_back(std::move(*info));
-    }
-  });
-
-  // Build partition result via calling plugin on non-decomposition subgraphs.
-  PartitionResult result;
-  for (auto i = 0; i < model.Subgraphs().size(); ++i) {
-    if (decomp_subgraphs.contains(i)) {
-      continue;
-    }
-    if (!subgraphs_to_partition.empty() &&
-        !subgraphs_to_partition.contains(i)) {
-      continue;
-    }
-    auto* subgraph = model.Subgraphs()[i];
-    auto selected_ops = compiler_plugin.Partition(Subgraph(subgraph));
-    // TODO: Ensure selected ops don't contain npu_calls.
-    if (!selected_ops) {
-      return selected_ops.Error();
-    }
-    auto num_selected_ops = selected_ops->size();
-    auto num_ops = subgraph->Ops().size();
-
-    auto num_partitions = result.first.size();
-    LITERT_RETURN_IF_ERROR(PartitionSubgraph(
-        std::move(*selected_ops), *subgraph, result, model.Buffers()));
-    num_partitions = result.first.size() - num_partitions;
-    LITERT_LOG(LITERT_INFO,
-               "PartitionSubgraph: %d, selected num ops: %lu, from total ops: "
-               "%lu, num partitions: %lu",
-               i, num_selected_ops, num_ops, num_partitions);
-  }
-
-  // Add npu_call partitions to result. Update the npu_call ops to be dispatch
-  // ops.
- std::vector decomps_to_compile; - for (auto& npu_call : npu_calls) { - auto* op = npu_call.op; - MakeDispatchOp(*op); - result.first.push_back(op); - decomps_to_compile.push_back(npu_call.subgraph); - } - model.TransferSubgraphTo(result.second, std::move(decomps_to_compile)); - - return result; -} - -Expected PartitionModelDirect( - std::vector selected_ops, LiteRtModelT& model) { - if (model.Subgraphs().size() != 1) { - // Only single subgraphs supported for direct partitioning. - return Unexpected(kLiteRtStatusErrorRuntimeFailure); - } - // Accumulate partition results for each subgraph in model. - PartitionResult result; - auto* subgraph = model.Subgraphs().front(); - LITERT_RETURN_IF_ERROR(PartitionSubgraph(std::move(selected_ops), *subgraph, - result, model.Buffers())); - ABSL_DCHECK_EQ(result.first.size(), result.second.Size()); - return result; -} - -Expected ApplyPluginWithPartition(CompilerPlugin& compiler_plugin, - LiteRtModelT& model, - PartitionResult partitions, - absl::string_view soc_model) { - auto& dispatch_ops = partitions.first; - auto& subgraphs = partitions.second; - - // Wrap the partitioned subgraphs in a LiteRtModel. - LiteRtModelT sliced_model; - sliced_model.TransferSubgraphsFrom(std::move(subgraphs)); - - // Copy op codes. - const auto& op_codes = litert::internal::GetTflOpCodes(model); - - LiteRtModelT::TflOpCodes codes; - codes.reserve(op_codes.size()); - for (const auto& op_code : op_codes) { - codes.emplace_back(std::make_unique(*op_code)); - } - - litert::internal::SetTflOpCodes(sliced_model, std::move(codes)); - - // Pass sliced subgraphs to plugin for compilation. - auto compiled_result = compiler_plugin.Compile(&sliced_model, soc_model); - if (!compiled_result) { - return compiled_result.Error(); - } - - // Register byte code buffers as external buffers. Map the byte code indices - // to the registered buffer ids. - auto num_byte_code = compiled_result->NumByteCodeModules(); - if (!num_byte_code) { - return num_byte_code.Error(); - } - - std::vector byte_code_idx_to_buf_id(*num_byte_code); - - for (auto i = 0; i < *num_byte_code; ++i) { - auto byte_code = compiled_result->ByteCode(i); - if (!byte_code) { - return byte_code.Error(); - } - - // TODO: This copy could probably be avoided. - OwningBufferRef owned_byte_code(byte_code->Data(), - byte_code->Size()); - const auto buf_id = - model.Buffers()->RegisterOwnedBuffer(std::move(owned_byte_code)); - - byte_code_idx_to_buf_id[i] = buf_id; - } - - // Register byte code buffers and add edges from dispatch ops to them. - for (auto i = 0; i < dispatch_ops.size(); ++i) { - auto* dispatch_op = dispatch_ops.at(i); - - auto call_info = compiled_result->CallInfo(i); - if (!call_info) { - return call_info.Error(); - } - auto [name, byte_code_idx] = *call_info; - const auto buf_id = byte_code_idx_to_buf_id[byte_code_idx]; - - model.AttachAssetToOp(dispatch_op, buf_id, std::string(name)); - } - - // Tag the model with make/model from the plugin. - auto build_stamp = - MakeBuildStamp(compiler_plugin.SocManufacturer(), soc_model); - if (!build_stamp) { - return build_stamp.Error(); - } - - if (auto status = - model.PushMetadata(kLiteRtBuildStampKey, std::move(*build_stamp)); - status != kLiteRtStatusOk) { - return Error(status); - } - - return {}; -} - -Expected ApplyPlugin( - CompilerPlugin& compiler_plugin, LiteRtModelT& model, - absl::string_view soc_model, - const absl::flat_hash_set& subgraphs_to_partition) { - // Collect partitions to pass to compilation. 
- auto partitions = - PartitionModel(compiler_plugin, model, subgraphs_to_partition); - if (!partitions) { - return partitions.Error(); - } - return ApplyPluginWithPartition(compiler_plugin, model, - std::move(*partitions), soc_model); -} - -Expected ApplyPlugins( - LiteRtEnvironment environment, LiteRtModel model, - LiteRtHwAcceleratorSet selected_hw_accelerators, bool* mutated) { - auto option = - environment->GetOption(kLiteRtEnvOptionTagCompilerPluginLibraryDir); - if (!option.has_value() || option->type != kLiteRtAnyTypeString) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Compiler plugin is not configured"); - } - std::string compiler_plugin_lib_path = option->str_value; - - const std::array - compiler_plugin_lib_search_paths = {compiler_plugin_lib_path}; - - auto compiler_plugins = litert::internal::CompilerPlugin::LoadPlugins( - compiler_plugin_lib_search_paths); - if (!compiler_plugins) { - return compiler_plugins.Error(); - } - if (compiler_plugins->empty()) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "No compiler plugin found"); - } - - std::vector success_messages; - std::vector error_messages; - - ApplyPluginsResult result; - result.num_applied_plugins = 0; - for (auto& compiler_plugin : *compiler_plugins) { - auto plugin_name = compiler_plugin.DebugString(); - - auto plugin_supported_hardware = compiler_plugin.SupportedHardware(); - if (!plugin_supported_hardware) { - error_messages.push_back(absl::StrCat( - plugin_name, " ", plugin_supported_hardware.Error().Message())); - continue; - } - - if (*plugin_supported_hardware & selected_hw_accelerators) { - auto status = ApplyPlugin(compiler_plugin, *model); - if (mutated != nullptr) { - *mutated = true; - } - if (!status) { - error_messages.push_back( - absl::StrCat(plugin_name, " ", status.Error().Message())); - continue; - } - - success_messages.push_back(absl::StrCat(plugin_name)); - result.num_applied_plugins++; - } - } - - result.success_message = absl::StrJoin(success_messages, ", "); - result.error_message = absl::StrJoin(error_messages, ", "); - - return result; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h deleted file mode 100644 index 76c6ccbdc1b2df..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
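[Editor's note] From the embedder's point of view, the pipeline implemented above collapses to a single call. A sketch, assuming `env` and `model` were created beforehand and that the environment carries kLiteRtEnvOptionTagCompilerPluginLibraryDir pointing at a plugin directory (this mirrors the ApplyPlugins test later in this patch):

// Editorial sketch of driving ApplyPlugins end to end.
bool mutated = false;
auto result = litert::internal::ApplyPlugins(
    env, model, kLiteRtHwAcceleratorNpu, &mutated);
if (result) {
  LITERT_LOG(LITERT_INFO, "Applied %zu plugin(s): %s",
             result->num_applied_plugins, result->success_message.c_str());
} else {
  // Either no plugin directory was configured or no plugin could be loaded.
}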
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h"
-#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h"
-#include "tensorflow/lite/experimental/litert/core/model/model.h"
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h"
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h"
-
-// C++ wrappers and high-level functions for managing compiler plugins
-// and applying them to models.
-
-namespace litert::internal {
-
-// Name and index of byte code.
-using CallInfo = std::pair;
-
-// Wraps vendor compiled result. Must be outlived by the CompilerPlugin
-// that generated it.
-class CompiledResult {
- public:
-  friend class CompilerPlugin;
-
-  // Number of byte code modules compiled by the plugin.
-  Expected NumByteCodeModules() const;
-
-  // Get the single module of compiled byte code. This contains the
-  // compilation result for all entry points.
-  Expected> ByteCode(
-      LiteRtParamIndex byte_code_idx = 0) const;
-
-  // Get information regarding the "i-th" entry point in the compiled module.
-  // There will be one entry point for each compiled subgraph.
-  Expected CallInfo(LiteRtParamIndex call_idx) const;
-
-  // Get the number of entry points in the compiled module. This will be equal
-  // to the number of subgraphs passed to the compilation step.
-  Expected NumCalls() const;
-
-  explicit CompiledResult(const LiteRtCompilerPluginApi& parent)
-      : parent_(parent) {}
-
-  CompiledResult(CompiledResult&& other);
-  CompiledResult& operator=(CompiledResult&& other);
-  CompiledResult(const CompiledResult& other) = delete;
-  CompiledResult& operator=(const CompiledResult& other) = delete;
-
-  ~CompiledResult();
-
- private:
-  LiteRtCompilerPluginApi parent_;
-  LiteRtCompiledResult compiled_result_handle_ = nullptr;
-};
-
-// Wraps vendor compiler plugin.
-class CompilerPlugin {
- public:
-  std::string DebugString() const;
-
-  // Get the compiler plugin's API version.
-  Expected ApiVersion() const;
-
-  // Get the supported HW accelerators (e.g., GPU, NPU).
-  Expected SupportedHardware() const;
-
-  // Get the manufacturer associated with this plugin. NOTE: the
-  // SocManufacturer string returned by the underlying plugin is expected to
-  // have static lifetime.
-  absl::string_view SocManufacturer() const {
-    return plugin_api_.get_compiler_plugin_soc_manufacturer();
-  }
-
-  // Get list of unique soc models targetable by this plugin.
-  const std::vector& SocModels() const { return soc_models_; }
-
-  // Selects ops for the plugin to compile.
-  Expected> Partition(
-      const Subgraph& subgraph, absl::string_view soc_model = "");
-
-  // Compile given LiteRtSubgraphs. Result object must be outlived by
-  // this CompilerPlugin.
-  Expected Compile(LiteRtModel partitions,
-                   absl::string_view soc_model = "");
-
-  // Search for shared library files with prefix "libLiteRtCompilerPlugin" in
-  // the directories passed through "lib_search_paths". Populates
-  // "loaded_plugins" with resolved plugin apis for each found library that can
-  // be successfully loaded. Additionally initializes the compiler plugin
-  // instances and stores their handles.
-  static Expected> LoadPlugins(
-      absl::Span lib_search_paths);
-
-  // Set compiler flags within the plugin.
-  LiteRtStatus SetFlags(const CompilerFlags& flags) {
-    return flags.SetPluginFlags(plugin_handle_, plugin_api_.set_flags);
-  }
-
-  CompilerPlugin(CompilerPlugin&& other);
-  CompilerPlugin& operator=(CompilerPlugin&& other);
-  CompilerPlugin(const CompilerPlugin& other) = delete;
-  CompilerPlugin& operator=(const CompilerPlugin& other) = delete;
-
-  // Destroys any living `LiteRtCompilerPlugin` and frees reference
-  // to dynamically loaded library.
-  ~CompilerPlugin();
-
- private:
-  static Expected LoadPlugin(absl::string_view lib_path);
-  CompilerPlugin() = default;
-
-  std::vector soc_models_;
-  SharedLibrary lib_;
-  LiteRtCompilerPluginApi plugin_api_ = {};
-  LiteRtCompilerPlugin plugin_handle_ = nullptr;
-
-  // Internal LiteRtCompiledResult wrapper.
-
-  CompiledResult MakeResult() const { return CompiledResult(plugin_api_); }
-};
-
-// Higher level functions for applying plugin to graph.
-//===---------------------------------------------------------------------------
-
-// Dispatch op references and their subgraph to be compiled.
-using PartitionResult =
-    std::pair, typename LiteRtSubgraphT::Alloc>;
-
-// Applies just the partition phase of the plugin on the model. Returns
-// references to newly allocated subgraphs removed from the input and their
-// corresponding dispatch ops in the input.
-Expected PartitionModel(
-    CompilerPlugin& compiler_plugin, LiteRtModelT& model,
-    const absl::flat_hash_set& subgraphs_to_partition = {});
-
-// Same as "PartitionModel" but chooses partitions directly based on the
-// selected ops. Selected ops may contain any ops in the main subgraph of the
-// model. This function will separate them into DAGs and slice the model
-// accordingly.
-Expected PartitionModelDirect(
-    std::vector selected_ops, LiteRtModelT& model);
-
-// Applies both the partition and compile steps to the model. Generated
-// byte_code will be internalized within the model for later serialization.
-Expected ApplyPlugin(
-    CompilerPlugin& compiler_plugin, LiteRtModelT& model,
-    absl::string_view soc_model = "",
-    const absl::flat_hash_set& subgraphs_to_partition = {});
-
-// Applies the compilation step to the model given a predetermined partition.
-Expected ApplyPluginWithPartition(CompilerPlugin& compiler_plugin,
-                                  LiteRtModelT& model,
-                                  PartitionResult partitions,
-                                  absl::string_view soc_model = "");
-
-// Apply all available plugins providing the selected HW accelerators to the
-// given model, modify the model accordingly, and return (1) the number of
-// compiler plugins successfully applied, (2) a string listing the compiler
-// plugins that were successfully applied, and (3) a string listing the compiler
-// plugins that failed to apply with an associated error message. This mutates
-// the given model.
-struct ApplyPluginsResult { - size_t num_applied_plugins; - std::string success_message; - std::string error_message; -}; - -Expected ApplyPlugins( - LiteRtEnvironment environment, LiteRtModel model, - LiteRtHwAcceleratorSet selected_hw_accelerators, bool* mutated = nullptr); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_ diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc deleted file mode 100644 index 96403219ec93b3..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc +++ /dev/null @@ -1,498 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" - -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" - -namespace litert::internal { -namespace { - -using testing::UniqueTestDirectory; - -constexpr absl::string_view kTestPluginSearchPath = - "third_party/tensorflow/lite/experimental/litert/vendors/examples"; - -constexpr absl::string_view kTestManufacturer = "ExampleSocManufacturer"; -constexpr absl::string_view kTestModels = "ExampleSocModel"; - -TEST(CompilerPluginTest, LoadTestPlugin) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - ASSERT_EQ(plugins->front().SocModels().size(), 1); - EXPECT_EQ(plugins->front().SocModels().front(), kTestModels); -} - -TEST(CompilerPluginTest, LoadTestPluginWithMalformed) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), "notLibLiteRt.so"})); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); -} - -TEST(CompilerPluginTest, MultipleValidPlugins) { - auto plugins = CompilerPlugin::LoadPlugins( - {kTestPluginSearchPath, kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 2); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - 
EXPECT_EQ(plugins->back().SocManufacturer(), kTestManufacturer); -} - -TEST(CompilerPluginTest, MoveAssign) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - CompilerPlugin other = std::move(plugins->front()); - - EXPECT_EQ(other.SocManufacturer(), kTestManufacturer); -} - -TEST(CompilerPluginTest, MoveConstruct) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - CompilerPlugin other(std::move(plugins->front())); - - EXPECT_EQ(other.SocManufacturer(), kTestManufacturer); -} - -TEST(CompilerPluginTest, SocModels) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - EXPECT_THAT(plugins->front().SocModels(), - ::testing::ElementsAreArray({kTestModels})); -} - -TEST(CompilerPluginTest, SetFlags) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - LITERT_ASSERT_OK(plugins->front().SetFlags(CompilerFlags())); -} - -TEST(CompilerPluginTest, Partition) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - auto model = testing::LoadTestFileModel("mul_simple.tflite"); - auto subgraph = model.MainSubgraph(); - auto ops = plugins->front().Partition(*subgraph); - ASSERT_TRUE(ops); - - EXPECT_EQ(ops->size(), 2); -} - -TEST(CompilerPluginTest, Compile) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - auto& model = *model_wrap.Get(); - - auto result = plugins->front().Compile(&model); - ASSERT_TRUE(result); - - auto byte_code = result->ByteCode(); - ASSERT_TRUE(byte_code && byte_code->Size() > 0); - - auto num_calls = result->NumCalls(); - ASSERT_TRUE(num_calls); - ASSERT_EQ(*num_calls, 1); - - auto call_info = result->CallInfo(0); - ASSERT_TRUE(call_info); -} - -TEST(CompilerPluginTest, Dump) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - - std::stringstream dump; - Dump(plugins->front(), dump); - - ASSERT_EQ(dump.view(), - "SocManufacturer: ExampleSocManufacturer\nSocModels: { " - "ExampleSocModel }\n"); -} - -TEST(PartitionModelTest, Simple) { - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - auto& model = *model_wrap.Get(); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto& plugin = plugins->front(); - - auto partition_result = PartitionModel(plugin, model); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 1); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 1); - EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 1); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 2); -} - -TEST(PartitionModelTest, PartitionDirect) { - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - auto& model = *model_wrap.Get(); - - std::vector selected_ops = { - {model.MainSubgraph()->Ops().front(), 0}, - 
{model.MainSubgraph()->Ops().back(), 0}}; - - auto partition_result = PartitionModelDirect(std::move(selected_ops), model); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 1); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 1); - EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 1); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 2); -} - -TEST(PartitionModelTest, MultiSubgraph) { - auto model_wrap = testing::LoadTestFileModel("multi_subgraph_mul.tflite"); - auto& model = *model_wrap.Get(); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto& plugin = plugins->front(); - - auto partition_result = PartitionModel(plugin, model); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 2); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 2); - EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_EQ(ops.back()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 2); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 1); - EXPECT_EQ(subgraphs.Elements().back()->Ops().size(), 1); -} - -TEST(PartitionModelTest, MultiSubgraphWithSelectedSubgraphs) { - auto model_wrap = testing::LoadTestFileModel("multi_subgraph_mul.tflite"); - auto& model = *model_wrap.Get(); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto& plugin = plugins->front(); - - auto partition_result = PartitionModel(plugin, model, {1}); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 2); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 1); - EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 1); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 1); -} - -TEST(PartitionModelTest, CstMultiSubgraph) { - auto model_wrap = testing::LoadTestFileModel("multi_use_cst.tflite"); - auto& model = *model_wrap.Get(); - ASSERT_EQ(model.MainSubgraph()->Ops().size(), 3); - - std::vector selected_ops = { - {model.MainSubgraph()->Ops().front(), 0}, - {model.MainSubgraph()->Ops().back(), 0}, - }; - auto partition_result = PartitionModelDirect(std::move(selected_ops), model); - ASSERT_TRUE(partition_result); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 2); - EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_EQ(ops.back()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 2); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 1); - EXPECT_EQ(subgraphs.Elements().back()->Ops().size(), 1); - - const auto& cst_1 = - subgraphs.Elements().front()->Ops().front()->Input(1).Weights(); - const auto& cst_2 = - subgraphs.Elements().back()->Ops().front()->Input(1).Weights(); - - // Both weights should have the same object managed by the same buffer - // manager. 
- ASSERT_EQ(cst_1.GetBufferManager(), model.Buffers()); - ASSERT_EQ(cst_2.GetBufferManager(), model.Buffers()); - ASSERT_GT(cst_1.Buffer().Size(), 0); - ASSERT_GT(cst_2.Buffer().Size(), 0); - EXPECT_EQ(cst_1.GetBufferId(), cst_2.GetBufferId()); - ASSERT_EQ(cst_1.Buffer().Data(), cst_2.Buffer().Data()); -} - -TEST(ApplyTest, Simple) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - - ASSERT_TRUE(ApplyPlugin(plugins->front(), model)); - ASSERT_EQ(model.NumSubgraphs(), 1); - - auto& subgraph = *model.MainSubgraph(); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); - - EXPECT_TRUE(model.FindMetadata(kLiteRtBuildStampKey)); -} - -TEST(ApplyTest, WithPartition) { - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - auto& model = *model_wrap.Get(); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto& plugin = plugins->front(); - - auto partition_result = PartitionModel(plugin, model); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 1); - - ASSERT_TRUE(ApplyPluginWithPartition(plugins->front(), model, - std::move(*partition_result))); - - auto& subgraph = model.Subgraph(0); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); -} - -TEST(ApplyTest, MultiSubgraph) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto model_wrap = testing::LoadTestFileModel("multi_subgraph_mul.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - - ASSERT_TRUE(ApplyPlugin(plugins->front(), model)); - ASSERT_EQ(model.NumSubgraphs(), 2); - - { - auto& subgraph = model.Subgraph(0); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); - } - - { - auto& subgraph = model.Subgraph(1); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); - } - - EXPECT_TRUE(model.FindMetadata(kLiteRtBuildStampKey)); -} - -TEST(ApplyTest, ApplyPlugins) { - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - - const std::array environment_options = { - litert::Environment::Option{ - /*.tag=*/litert::Environment::OptionTag::CompilerPluginLibraryDir, - /*.value=*/kTestPluginSearchPath, - }, - }; - auto env = litert::Environment::Create(environment_options); - ASSERT_TRUE(env); - - LiteRtHwAccelerators compilation_options = static_cast( - kLiteRtHwAcceleratorCpu | kLiteRtHwAcceleratorGpu | - kLiteRtHwAcceleratorNpu); - auto result = - litert::internal::ApplyPlugins(env->Get(), &model, compilation_options); - ASSERT_TRUE(result); - - ASSERT_EQ(model.NumSubgraphs(), 1); - - auto& subgraph = *model.MainSubgraph(); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); - - EXPECT_TRUE(model.FindMetadata(kLiteRtBuildStampKey)); -} - 
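[Editor's note] The assertions above also show how a consumer can introspect a model after plugin application. Distilled into a sketch (accessor names as used in these tests; the surrounding setup is assumed):

// Editorial sketch: post-compilation introspection.
auto* op = model.MainSubgraph()->Ops().front();
if (op->OpCode() == kLiteRtOpCodeTflCustom) {
  // Dispatch ops carry an asset (vendor byte code plus entry-point name) ...
  auto asset = model.FindOpAsset(op);
  // ... and successfully compiled models are tagged with a build stamp.
  auto stamp = model.FindMetadata(kLiteRtBuildStampKey);
}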
-TEST(PartitionTest, MappedCompositeOp) {
-  auto model_wrap = testing::LoadTestFileModel("rms_norm_composite.tflite");
-  ASSERT_TRUE(model_wrap);
-  auto& model = *model_wrap.Get();
-  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
-
-  auto partition_result = PartitionModel(plugins->front(), model);
-  ASSERT_TRUE(partition_result);
-  // One new subgraph for the consumed composite op only, decomp not consumed.
-  ASSERT_EQ(partition_result->second.Size(), 1);
-}
-
-TEST(PartitionTest, SimpleNpuCallComposite) {
-  auto model_wrap = testing::LoadTestFileModel("simple_composite.tflite");
-  ASSERT_TRUE(model_wrap);
-  auto& model = *model_wrap.Get();
-  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
-
-  auto* decomp = model.Subgraphs()[1];
-
-  auto partition_result = PartitionModel(plugins->front(), model);
-  ASSERT_TRUE(partition_result);
-
-  auto& ops = partition_result->first;
-  ASSERT_EQ(ops.size(), 1);
-  ASSERT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom);
-
-  auto& sgs = partition_result->second;
-  ASSERT_EQ(sgs.Size(), 1);
-  ASSERT_EQ(sgs.Elements().front(), decomp);
-}
-
-TEST(PartitionTest, MultiNpuCallComposite) {
-  auto model_wrap = testing::LoadTestFileModel("multi_composite.tflite");
-  ASSERT_TRUE(model_wrap);
-  auto& model = *model_wrap.Get();
-  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
-
-  ASSERT_EQ(model.NumSubgraphs(), 4);
-  auto* decomp1 = model.Subgraphs()[1];
-  auto* non_npu_call_decomp = model.Subgraphs()[2];
-  auto* decomp2 = model.Subgraphs()[3];
-
-  auto partition_result = PartitionModel(plugins->front(), model);
-  ASSERT_TRUE(partition_result);
-
-  {
-    // Subgraphs to be compiled will be moved to the result from the model.
-    // Non-npu-call decompositions will be reindexed.
-    ASSERT_EQ(model.NumSubgraphs(), 2);
-    ASSERT_EQ(model.Subgraphs()[1], non_npu_call_decomp);
-    auto opts = GetOptionsAs(model.Subgraph(0).Ops()[1]);
-    ASSERT_TRUE(opts);
-    ASSERT_EQ(opts->subgraph, 1);
-  }
-
-  {
-    // All npu call ops are now dispatch ops.
-    auto& ops = partition_result->first;
-
-    ASSERT_EQ(ops.size(), 2);
-    auto* first_dispatch_op = ops.front();
-    auto* second_dispatch_op = ops.back();
-
-    ASSERT_EQ(first_dispatch_op->OpCode(), kLiteRtOpCodeTflCustom);
-    ASSERT_EQ(first_dispatch_op, model.Subgraphs()[0]->Ops().front());
-
-    ASSERT_EQ(second_dispatch_op->OpCode(), kLiteRtOpCodeTflCustom);
-    ASSERT_EQ(second_dispatch_op, model.Subgraphs()[0]->Ops().back());
-  }
-
-  {
-    // Bodies to compile are the decompositions of npu call ops.
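// (partition_result->second owns the subgraphs that were moved out of the
// model for compilation; for npu_call composites these are exactly the
// decomposition subgraphs captured as decomp1/decomp2 above.)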
- auto& sgs = partition_result->second; - - ASSERT_EQ(sgs.Size(), 2); - ASSERT_EQ(sgs.Elements().front(), decomp1); - ASSERT_EQ(sgs.Elements().back(), decomp2); - } -} - -TEST(PartitionTest, NestedNpuCallComposite) { - auto model_wrap = testing::LoadTestFileModel("nested_composite.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(model.NumSubgraphs(), 3); - - auto partition_result = PartitionModel(plugins->front(), model); - ASSERT_TRUE(partition_result); - - auto& ops = partition_result->first; - ASSERT_EQ(ops.size(), 1); - ASSERT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - - auto& sgs = partition_result->second; - ASSERT_EQ(sgs.Size(), 1); - ASSERT_EQ(sgs.Elements().front()->Op(0).OpCode(), kLiteRtOpCodeShloComposite); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/BUILD b/tensorflow/lite/experimental/litert/core/BUILD deleted file mode 100644 index 005ced8f23276c..00000000000000 --- a/tensorflow/lite/experimental/litert/core/BUILD +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - "//third_party/odml/infra/ml_drift_delegate/litert:__subpackages__", - ], -) - -cc_library( - name = "build_stamp", - srcs = ["build_stamp.cc"], - hdrs = ["build_stamp.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "build_stamp_test", - srcs = ["build_stamp_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - deps = [ - ":build_stamp", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "dynamic_loading", - srcs = ["dynamic_loading.cc"], - hdrs = ["dynamic_loading.h"], - linkopts = ["-ldl"], - deps = [ - ":filesystem", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = 
"insert_order_map", - hdrs = ["insert_order_map.h"], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - ], -) - -cc_test( - name = "insert_order_map_test", - srcs = ["insert_order_map_test.cc"], - deps = [ - ":insert_order_map", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "environment", - srcs = ["environment.cc"], - hdrs = [ - "environment.h", - "//tensorflow/lite/experimental/litert/c:litert_environment.h", - ], - deps = [ - ":environment_options", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment_options", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/runtime:accelerator_registry", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "environment_options", - srcs = ["environment_options.cc"], - hdrs = ["environment_options.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment_options_header", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) - -cc_test( - name = "environment_options_test", - srcs = ["environment_options_test.cc"], - deps = [ - ":environment_options", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment_options", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "environment_test", - srcs = ["environment_test.cc"], - deps = [ - ":environment", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "filesystem", - srcs = ["filesystem.cc"], - hdrs = ["filesystem.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "dispatch_op_schema", - srcs = ["dispatch_op_schema.cc"], - hdrs = ["dispatch_op_schema.h"], - copts = ["-DFLATBUFFERS_LOCALE_INDEPENDENT=0"], - deps = [ - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "@flatbuffers//:runtime_cc", - ], -) - -cc_test( - name = "filesystem_test", - srcs = ["filesystem_test.cc"], - deps = [ - ":filesystem", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -# copybara:uncomment_begin(no OSS for unique-test-directory) -# cc_test( -# name = "dynamic_loading_test", -# srcs = ["dynamic_loading_test.cc"], -# tags = [ -# # Sanitizer runtimes are incompatible with RTLD_DEEPBIND. 
-#         "noasan",
-#         "nomsan",
-#         "nosan",
-#     ],
-#     deps = [
-#         ":dynamic_loading",
-#         ":filesystem",
-#         "@com_google_googletest//:gtest_main",
-#         "@com_google_absl//absl/strings:string_view",
-#         "//tensorflow/lite/experimental/litert/c:litert_logging",  # buildcleaner: keep
-#         "//tensorflow/lite/experimental/litert/test:common",
-#         "//tensorflow/lite/experimental/litert/test:matchers",
-#     ],
-# )
-# copybara:uncomment_end
-
-cc_test(
-    name = "dispatch_op_schema_test",
-    srcs = ["dispatch_op_schema_test.cc"],
-    deps = [
-        ":dispatch_op_schema",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-cc_library(
-    name = "version",
-    hdrs = ["version.h"],
-    deps = [
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_logging",
-    ],
-)
diff --git a/tensorflow/lite/experimental/litert/core/build_stamp.cc b/tensorflow/lite/experimental/litert/core/build_stamp.cc
deleted file mode 100644
index 9b7e942c36622f..00000000000000
--- a/tensorflow/lite/experimental/litert/core/build_stamp.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/core/build_stamp.h"
-
-#include <cstdint>
-#include <tuple>
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-namespace litert::internal {
-
-namespace {
-// Simple metadata added to the flatbuffer related to compiler plugin.
-struct BuildStamp {
-  char soc_manufacturer[kSocManufacturerMaxLen + 1] = {};
-  char soc_model[kSocModelMaxLen + 1] = {};
-};
-
-}  // namespace
-
-Expected<OwningBufferRef<uint8_t>> MakeBuildStamp(
-    absl::string_view soc_manufacturer, absl::string_view soc_model) {
-  if (soc_manufacturer.size() >= kSocManufacturerMaxLen ||
-      soc_model.size() >= kSocModelMaxLen) {
-    LITERT_LOG(LITERT_ERROR, "%s", "Soc Make/Model strings too large\n");
-    return Unexpected(kLiteRtStatusErrorInvalidArgument);
-  }
-  BuildStamp stamp;
-  soc_manufacturer.copy(stamp.soc_manufacturer, soc_manufacturer.size());
-  soc_model.copy(stamp.soc_model, soc_model.size());
-  return OwningBufferRef<uint8_t>(reinterpret_cast<const uint8_t*>(&stamp),
-                                  sizeof(stamp));
-}
-
-// Parse a serialized build stamp from the given buf.
-Expected<std::tuple<absl::string_view, absl::string_view>> ParseBuildStamp(
-    BufferRef<uint8_t> buf) {
-  if (buf.Size() != sizeof(BuildStamp)) {
-    LITERT_LOG(LITERT_ERROR, "%s", "Build stamp size mismatch\n");
-    return Unexpected(kLiteRtStatusErrorInvalidArgument);
-  }
-  const BuildStamp* stamp = reinterpret_cast<const BuildStamp*>(buf.Data());
-  return std::make_tuple(absl::string_view(stamp->soc_manufacturer),
-                         absl::string_view(stamp->soc_model));
-}
-
-}  // namespace litert::internal
diff --git a/tensorflow/lite/experimental/litert/core/build_stamp.h b/tensorflow/lite/experimental/litert/core/build_stamp.h
deleted file mode 100644
index bf9ee91934e503..00000000000000
--- a/tensorflow/lite/experimental/litert/core/build_stamp.h
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_BUILD_STAMP_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_BUILD_STAMP_H_
-
-#include <cstddef>
-
-#include <cstdint>
-#include <tuple>
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-namespace litert::internal {
-
-// TODO update this library to use the flexbuffers api.
-
-// Shared "custom_code" for all dispatch ops.
-static constexpr absl::string_view kLiteRtDispatchOpCustomCode = "DISPATCH_OP";
-
-//
-// Build Stamp
-//
-
-// Maximum size of string for soc_manufacturer.
-static constexpr size_t kSocManufacturerMaxLen = 124;
-
-// Maximum size of string for soc_model.
-static constexpr size_t kSocModelMaxLen = 124;
-
-// Metadata key to lookup the build stamp.
-static constexpr absl::string_view kLiteRtBuildStampKey = "LiteRtStamp";
-
-// Make a serialized build stamp that can go directly in the flatbuffer.
-Expected<OwningBufferRef<uint8_t>> MakeBuildStamp(
-    absl::string_view soc_manufacturer, absl::string_view soc_model);
-
-// Parse a serialized build stamp from the given buf.
-Expected<std::tuple<absl::string_view, absl::string_view>> ParseBuildStamp(
-    BufferRef<uint8_t> buf);
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_BUILD_STAMP_H_
diff --git a/tensorflow/lite/experimental/litert/core/build_stamp_test.cc b/tensorflow/lite/experimental/litert/core/build_stamp_test.cc
deleted file mode 100644
index a0c3ce4fbbf1d0..00000000000000
--- a/tensorflow/lite/experimental/litert/core/build_stamp_test.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
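// Round-trip sketch for the build-stamp helpers above (mirrors the test
// cases below; limits per kSocManufacturerMaxLen / kSocModelMaxLen):
//
//   auto stamp = MakeBuildStamp("TestSocMan", "TestSocModel");
//   auto parsed = ParseBuildStamp(*stamp);
//   auto [man, model] = *parsed;  // "TestSocMan", "TestSocModel"
//
// Inputs at or over the max lengths fail with
// kLiteRtStatusErrorInvalidArgument rather than being truncated.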
-
-#include "tensorflow/lite/experimental/litert/core/build_stamp.h"
-
-#include <string>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/test/matchers.h"
-
-namespace litert::internal {
-
-namespace {
-
-using ::testing::litert::IsError;
-
-static constexpr absl::string_view kSocModel = "TestSocModel";
-static constexpr absl::string_view kSocMan = "TestSocMan";
-
-TEST(TestBuildStamp, MakeBuildStampInputsTooLarge) {
-  // NOLINTNEXTLINE
-  std::string long_manufacturer(256, 'a');
-  auto res = MakeBuildStamp(long_manufacturer, kSocModel);
-  EXPECT_THAT(res, IsError(kLiteRtStatusErrorInvalidArgument));
-}
-
-TEST(TestBuildStamp, MakeBuildStamp) {
-  auto stamp = MakeBuildStamp(kSocMan, kSocModel);
-  auto pstamp = ParseBuildStamp(*stamp);
-  auto [man, model] = *pstamp;
-  EXPECT_EQ(man, kSocMan);
-  EXPECT_EQ(model, kSocModel);
-}
-
-}  // namespace
-
-}  // namespace litert::internal
diff --git a/tensorflow/lite/experimental/litert/core/dispatch_op_schema.cc b/tensorflow/lite/experimental/litert/core/dispatch_op_schema.cc
deleted file mode 100644
index ed2226ef664cef..00000000000000
--- a/tensorflow/lite/experimental/litert/core/dispatch_op_schema.cc
+++ /dev/null
@@ -1,89 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <string>
-#include <utility>
-
-#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-
-namespace litert {
-namespace internal {
-namespace {
-
-static constexpr const char kBytecodeSizeKey[] = "bytecode_size";
-static constexpr const char kBytecodeOffsetKey[] = "bytecode_offset";
-static constexpr const char kNameKey[] = "name";
-
-}  // namespace
-
-OwningBufferRef<uint8_t> MakeDispatchOpOptions(DispatchOpOptions options) {
-  flexbuffers::Builder fbb;
-
-  // Set maximum width for scalars to 64 bits. This prevents any upsizing of
-  // the buffer when updating the bytecode size and offset in place.
-  fbb.ForceMinimumBitWidth(flexbuffers::BIT_WIDTH_64);
-
-  auto start = fbb.StartMap();
-
-  fbb.UInt(kBytecodeSizeKey, options.bytecode_size);
-  fbb.UInt(kBytecodeOffsetKey, options.bytecode_offset);
-  fbb.String(kNameKey, options.name);
-
-  fbb.EndMap(start);
-  fbb.Finish();
-
-  auto buf = fbb.GetBuffer();
-  OwningBufferRef<uint8_t> res;
-  res.Assign(buf.data(), buf.size());
-
-  return res;
-}
-
-bool UpdateDispatchOpOptionsInPlace(DispatchOpOptions options,
-                                    MutableBufferRef<uint8_t> buffer) {
-  auto opts = flexbuffers::GetRoot(buffer.Data(), buffer.Size()).AsMap();
-
-  // Update name if same len.
-  const auto name_ok = opts[kNameKey].MutateString(options.name);
-
-  // Update bytecode size and offset. Since min scalar bit width is set to max
-  // possible value, it shouldn't fail in theory.
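// (MakeDispatchOpOptions called ForceMinimumBitWidth(BIT_WIDTH_64) above, so
// each scalar already occupies a full 8 bytes and MutateUInt can overwrite
// the value without growing the flexbuffer.)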
-  const auto size_ok = opts[kBytecodeSizeKey].MutateUInt(options.bytecode_size);
-  const auto offset_ok =
-      opts[kBytecodeOffsetKey].MutateUInt(options.bytecode_offset);
-
-  return name_ok && size_ok && offset_ok;
-}
-
-DispatchOpOptions GetDispatchOpOptions(BufferRef<uint8_t> buffer) {
-  const auto opts = flexbuffers::GetRoot(buffer.Data(), buffer.Size()).AsMap();
-
-  const size_t bytecode_size = opts[kBytecodeSizeKey].AsUInt64();
-  const size_t bytecode_offset = opts[kBytecodeOffsetKey].AsUInt64();
-  std::string name(opts[kNameKey].AsString().c_str());
-
-  return DispatchOpOptions{
-      bytecode_size,
-      bytecode_offset,
-      std::move(name),
-  };
-}
-
-}  // namespace internal
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/core/dispatch_op_schema.h b/tensorflow/lite/experimental/litert/core/dispatch_op_schema.h
deleted file mode 100644
index a6f6eb9216caba..00000000000000
--- a/tensorflow/lite/experimental/litert/core/dispatch_op_schema.h
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_OP_SCHEMA_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_OP_SCHEMA_H_
-
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-
-// Utilities for working with the dispatch op custom options buffer. These
-// functions leverage the flexbuffers API under the hood, which allows for
-// in-place updates.
-
-namespace litert::internal {
-
-// Schema representing the custom options data for dispatch ops. Primarily
-// used for tracking the location of bytecode.
-struct DispatchOpOptions {
-  // The size of the bytecode for the dispatch op.
-  size_t bytecode_size;
-
-  // The offset of the bytecode for the dispatch op relative to the start of
-  // the model file.
-  size_t bytecode_offset;
-
-  // Name of specific dispatch op or entry point to be called in a shared
-  // bytecode module.
-  std::string name;
-};
-
-// Get a serialized representation of the dispatch op options. These should
-// be stored directly in the custom options of the dispatch op.
-OwningBufferRef<uint8_t> MakeDispatchOpOptions(DispatchOpOptions options);
-
-// Update the dispatch op options in the given buffer with the given options.
-// The buffer should be the custom options buffer of the dispatch op. Fails if
-// the passed values would resize the buffer.
-bool UpdateDispatchOpOptionsInPlace(DispatchOpOptions options,
-                                    MutableBufferRef<uint8_t> buffer);
-
-// Get the dispatch op options from the given buffer. The buffer should be the
-// custom options buffer of the dispatch op.
-DispatchOpOptions GetDispatchOpOptions(BufferRef<uint8_t> buffer);
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_OP_SCHEMA_H_
diff --git a/tensorflow/lite/experimental/litert/core/dispatch_op_schema_test.cc b/tensorflow/lite/experimental/litert/core/dispatch_op_schema_test.cc
deleted file mode 100644
index 53f784b50a674e..00000000000000
--- a/tensorflow/lite/experimental/litert/core/dispatch_op_schema_test.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h"
-
-#include <cstddef>
-
-#include <gtest/gtest.h>
-
-namespace litert {
-namespace internal {
-namespace {
-
-static constexpr size_t kBufferSize = 100;
-static constexpr size_t kBufferOffset = 200;
-static constexpr const char kName[] = "test_name";
-
-TEST(DispatchOpSchemaTest, DispatchOpOptions) {
-  DispatchOpOptions options = {
-      kBufferSize,
-      kBufferOffset,
-      kName,
-  };
-
-  auto buffer = MakeDispatchOpOptions(options);
-  ASSERT_GT(buffer.Size(), 0);
-
-  auto parsed_options = GetDispatchOpOptions(buffer);
-  ASSERT_EQ(parsed_options.bytecode_size, kBufferSize);
-  ASSERT_EQ(parsed_options.bytecode_offset, kBufferOffset);
-  ASSERT_EQ(parsed_options.name, kName);
-}
-
-TEST(DispatchOpSchemaTest, UpdateDispatchOpOptions) {
-  DispatchOpOptions options = {
-      kBufferSize,
-      kBufferOffset,
-      kName,
-  };
-
-  auto buffer = MakeDispatchOpOptions(options);
-  ASSERT_GT(buffer.Size(), 0);
-
-  static constexpr size_t kNewBufferSize = 1000;
-  static constexpr size_t kNewBufferOffset = 2000;
-
-  DispatchOpOptions new_options = {
-      kNewBufferSize,
-      kNewBufferOffset,
-      kName,
-  };
-
-  ASSERT_TRUE(UpdateDispatchOpOptionsInPlace(new_options, buffer));
-
-  auto parsed_options = GetDispatchOpOptions(buffer);
-  ASSERT_EQ(parsed_options.bytecode_size, kNewBufferSize);
-  ASSERT_EQ(parsed_options.bytecode_offset, kNewBufferOffset);
-  ASSERT_EQ(parsed_options.name, kName);
-}
-
-}  // namespace
-}  // namespace internal
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading.cc b/tensorflow/lite/experimental/litert/core/dynamic_loading.cc
deleted file mode 100644
index 37c4ef2040dd86..00000000000000
--- a/tensorflow/lite/experimental/litert/core/dynamic_loading.cc
+++ /dev/null
@@ -1,148 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
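// Usage sketch for the dispatch-op options schema above (mirrors
// dispatch_op_schema_test.cc; error handling elided):
//
//   DispatchOpOptions opts{/*bytecode_size=*/100, /*bytecode_offset=*/200,
//                          /*name=*/"test_name"};
//   auto buf = MakeDispatchOpOptions(opts);      // serialized flexbuffer
//
//   opts.bytecode_offset = 2000;                 // bytecode was relocated
//   UpdateDispatchOpOptionsInPlace(opts, buf);   // same-size in-place write
//
//   DispatchOpOptions parsed = GetDispatchOpOptions(buf);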
- -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" - -#include -#include - -// clang-format off -#ifndef __ANDROID__ -#if __has_include() -#include -#endif -#endif -// clang-format on - -#include -#include // NOLINT -#include -#include - -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" - -namespace litert::internal { - -namespace { - -static constexpr absl::string_view kLdLibraryPath = "LD_LIBRARY_PATH"; - -bool EnvPathContains(absl::string_view path, absl::string_view var_value) { - return absl::EndsWith(var_value, path) || - absl::StrContains(var_value, absl::StrCat(path, ":")); -} - -} // namespace - -static constexpr absl::string_view kSo = ".so"; - -LiteRtStatus FindLiteRtSharedLibsHelper(const std::string& search_path, - const std::string& lib_pattern, - bool full_match, - std::vector& results) { - if (!Exists(search_path)) { - return kLiteRtStatusErrorInvalidArgument; - } - - // TODO implement path glob in core/filesystem.h and remove filesystem - // include from this file. - for (const auto& entry : std::filesystem::directory_iterator( - search_path, - std::filesystem::directory_options::skip_permission_denied)) { - const auto& path = entry.path(); - if (access(path.c_str(), R_OK) != 0) { - continue; - } - if (entry.is_regular_file()) { - if (full_match) { - if (path.string().find(lib_pattern) != -1) { - LITERT_LOG(LITERT_VERBOSE, "Found shared library: %s", path.c_str()); - results.push_back(path); - } - } else { - const auto stem = path.stem().string(); - const auto ext = path.extension().string(); - if (stem.find(lib_pattern) == 0 && kSo == ext) { - LITERT_LOG(LITERT_VERBOSE, "Found shared library: %s", path.c_str()); - results.push_back(path); - } - } - } else if (entry.is_directory()) { - FindLiteRtSharedLibsHelper(path, lib_pattern, full_match, results); - } - } - - return kLiteRtStatusOk; -} - -static const char kCompilerPluginLibPatternFmt[] = "CompilerPlugin"; - -LiteRtStatus FindLiteRtCompilerPluginSharedLibs( - absl::string_view search_path, std::vector& results) { - std::string root(search_path); - const std::string lib_pattern = - absl::StrCat(kLiteRtSharedLibPrefix, kCompilerPluginLibPatternFmt); - return FindLiteRtSharedLibsHelper(root, lib_pattern, /*full_match=*/false, - results); -} - -static const char kDispatchLibPatternFmt[] = "Dispatch"; - -LiteRtStatus FindLiteRtDispatchSharedLibs(absl::string_view search_path, - std::vector& results) { - std::string root(search_path.data()); - const std::string lib_pattern = - absl::StrCat(kLiteRtSharedLibPrefix, kDispatchLibPatternFmt); - return FindLiteRtSharedLibsHelper(root, lib_pattern, /*full_match=*/false, - results); -} - -LiteRtStatus PutLibOnLdPath(absl::string_view search_path, - absl::string_view lib_pattern) { - std::vector results; - LITERT_RETURN_IF_ERROR(FindLiteRtSharedLibsHelper( - std::string(search_path), std::string(lib_pattern), true, results)); - if (results.empty()) { - LITERT_LOG(LITERT_INFO, "No match found in %s", search_path.data()); - return kLiteRtStatusOk; - } - - const auto lib_dir = std::filesystem::path(results[0]).parent_path().string(); - absl::string_view ld = getenv(kLdLibraryPath.data()); - - if (EnvPathContains(lib_dir, ld)) { - 
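// The directory is already on the path; skip the append to avoid
// accumulating duplicate entries across repeated calls.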
LITERT_LOG(LITERT_INFO, "dir already in LD_LIBRARY_PATH");
-    return kLiteRtStatusOk;
-  }
-
-  std::string new_ld;
-  if (ld.empty()) {
-    new_ld = lib_dir;
-  } else {
-    new_ld = absl::StrCat(ld, ":", lib_dir);
-  }
-
-  LITERT_LOG(LITERT_INFO, "Adding %s to LD_LIBRARY_PATH", new_ld.c_str());
-  setenv(kLdLibraryPath.data(), new_ld.c_str(), /*overwrite=*/1);
-
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::internal
diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading.h b/tensorflow/lite/experimental/litert/core/dynamic_loading.h
deleted file mode 100644
index 5f44f5aafaa851..00000000000000
--- a/tensorflow/lite/experimental/litert/core/dynamic_loading.h
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_
-
-#include
-#include
-
-#include
-#include
-#include
-#include
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-
-namespace litert::internal {
-
-constexpr absl::string_view kLiteRtSharedLibPrefix = "libLiteRt";
-
-// Loads shared library at given path. Logging can be disabled to probe for
-// shared libraries.
-LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle,
-                     bool log_failure = true);
-
-// Find all litert shared libraries in "search_path" and return
-// kLiteRtStatusErrorInvalidArgument if the provided search_path doesn't
-// exist. All internal dynamically linked dependencies for litert should be
-// prefixed with "libLiteRtCompilerPlugin".
-LiteRtStatus FindLiteRtCompilerPluginSharedLibs(
-    absl::string_view search_path, std::vector<std::string>& results);
-
-// Find all litert shared libraries in "search_path" and return
-// kLiteRtStatusErrorInvalidArgument if the provided search_path doesn't
-// exist. All internal dynamically linked dependencies for litert should be
-// prefixed with "libLiteRtDispatch".
-LiteRtStatus FindLiteRtDispatchSharedLibs(absl::string_view search_path,
-                                          std::vector<std::string>& results);
-
-// Find shared libraries for a given pattern in "search_path" and return
-// kLiteRtStatusErrorInvalidArgument if the provided search_path doesn't
-// exist.
-LiteRtStatus FindLiteRtSharedLibsHelper(const std::string& search_path,
-                                        const std::string& lib_pattern,
-                                        bool full_match,
-                                        std::vector<std::string>& results);
-
-// Analogous to the above, but the immediate parent directory of the first
-// match identified will be appended to LD_LIBRARY_PATH.
-LiteRtStatus PutLibOnLdPath(absl::string_view search_path, - absl::string_view lib_pattern); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_ diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc b/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc deleted file mode 100644 index 000e33947fb790..00000000000000 --- a/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" - -#include -#include // NOLINT -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace litert::internal { -namespace { - -using litert::testing::UniqueTestDirectory; -using ::testing::Contains; -using ::testing::HasSubstr; - -constexpr absl::string_view kNotLiteRtSo = "notLibLiteRt.so"; -constexpr absl::string_view kLiteRtSo1 = "libLiteRtCompilerPlugin_1.so"; -constexpr absl::string_view kLiteRtSo2 = "libLiteRtCompilerPlugin_2.so"; -constexpr absl::string_view kLiteRtSo3 = "libLiteRtDispatch_1.so"; -constexpr absl::string_view kLiteRtSo4 = "libLiteRtDispatch_2.so"; -constexpr absl::string_view kLdLibraryPath = "LD_LIBRARY_PATH"; - -TEST(TestDynamicLoading, GlobNoMatch) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), kNotLiteRtSo})); - - std::vector results; - LITERT_ASSERT_OK(litert::internal::FindLiteRtCompilerPluginSharedLibs( - dir->Str(), results)); - EXPECT_EQ(results.size(), 0); - std::vector results2; - LITERT_ASSERT_OK( - litert::internal::FindLiteRtDispatchSharedLibs(dir->Str(), results2)); - EXPECT_EQ(results2.size(), 0); -} - -TEST(TestDynamicLoading, GlobOneMatch) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), kLiteRtSo1})); - Touch(Join({dir->Str(), kLiteRtSo3})); - Touch(Join({dir->Str(), kNotLiteRtSo})); - - std::vector results; - LITERT_ASSERT_OK(litert::internal::FindLiteRtCompilerPluginSharedLibs( - dir->Str(), results)); - ASSERT_EQ(results.size(), 1); - EXPECT_TRUE(absl::string_view(results.front()).ends_with(kLiteRtSo1)); - - std::vector results2; - LITERT_ASSERT_OK( - litert::internal::FindLiteRtDispatchSharedLibs(dir->Str(), results2)); - ASSERT_EQ(results2.size(), 1); - EXPECT_TRUE(absl::string_view(results2.front()).ends_with(kLiteRtSo3)); -} - -TEST(TestDynamicLoading, GlobMultiMatch) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), kLiteRtSo1})); - Touch(Join({dir->Str(), kLiteRtSo2})); - Touch(Join({dir->Str(), kLiteRtSo3})); - Touch(Join({dir->Str(), kLiteRtSo4})); - Touch(Join({dir->Str(), kNotLiteRtSo})); - - std::vector 
results; - LITERT_ASSERT_OK(litert::internal::FindLiteRtCompilerPluginSharedLibs( - dir->Str(), results)); - ASSERT_EQ(results.size(), 2); - EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo1))); - EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo2))); - - std::vector results2; - LITERT_ASSERT_OK( - litert::internal::FindLiteRtDispatchSharedLibs(dir->Str(), results2)); - ASSERT_EQ(results2.size(), 2); - EXPECT_THAT(results2, Contains(HasSubstr(kLiteRtSo3))); - EXPECT_THAT(results2, Contains(HasSubstr(kLiteRtSo4))); -} - -TEST(TestDynamicLoadingHelper, HelperWithFullMatch) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), kLiteRtSo1})); - Touch(Join({dir->Str(), kLiteRtSo2})); - Touch(Join({dir->Str(), kLiteRtSo3})); - Touch(Join({dir->Str(), kLiteRtSo4})); - Touch(Join({dir->Str(), kNotLiteRtSo})); - - std::vector results; - LITERT_ASSERT_OK(litert::internal::FindLiteRtSharedLibsHelper( - std::string(dir->Str()), std::string(kLiteRtSo4), true, results)); - ASSERT_EQ(results.size(), 1); - EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo4))); -} - -TEST(TestPutLibOnLdPath, AppendToEmptyLdPath) { - unsetenv(kLdLibraryPath.data()); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - const auto lib_path = Join({dir_path, kLiteRtSo1}); - Touch(lib_path); - - LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - absl::string_view ld_library_path = getenv(kLdLibraryPath.data()); - EXPECT_THAT(ld_library_path, HasSubstr(dir_path)); -} - -TEST(TestPutLibOnLdPath, AppendToLdPathNoMatch) { - unsetenv(kLdLibraryPath.data()); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - - LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - ASSERT_EQ(getenv(kLdLibraryPath.data()), nullptr); -} - -TEST(TestPutLibOnLdPath, AppendToExistingLdPath) { - static constexpr absl::string_view kExistingLdPath = "an/existing/path"; - - unsetenv(kLdLibraryPath.data()); - setenv(kLdLibraryPath.data(), kExistingLdPath.data(), /*overwrite=*/1); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - const auto lib_path = Join({dir_path, kLiteRtSo1}); - Touch(lib_path); - - LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - absl::string_view ld_library_path = getenv(kLdLibraryPath.data()); - EXPECT_THAT(ld_library_path, HasSubstr(dir_path)); - EXPECT_THAT(ld_library_path, HasSubstr(kExistingLdPath)); -} - -TEST(TestPutLibOnLdPath, AppendToLdLibraryPathNoDupePath) { - unsetenv(kLdLibraryPath.data()); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - const auto lib_path = Join({dir_path, kLiteRtSo1}); - Touch(lib_path); - - setenv(kLdLibraryPath.data(), dir_path.data(), /*overwrite=*/1); - - LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - absl::string_view ld_library_path = getenv(kLdLibraryPath.data()); - EXPECT_THAT(ld_library_path, HasSubstr(dir_path)); - EXPECT_EQ(ld_library_path.size(), dir_path.size()); -} - -TEST(TestPutLibOnLdPath, AppendToNestedLdPath) { - unsetenv(kLdLibraryPath.data()); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - const auto nested_dir_path = Join({dir_path, "another/dir"}); - const auto lib_path = Join({nested_dir_path, kLiteRtSo1}); - ASSERT_TRUE(std::filesystem::create_directories(nested_dir_path)); - Touch(lib_path); - - 
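// The library sits in a nested subdirectory; the recursive search should
// still find it and export the directory that directly contains the match.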
LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - absl::string_view ld_library_path = getenv(kLdLibraryPath.data()); - EXPECT_THAT(ld_library_path, HasSubstr(nested_dir_path)); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/environment.cc b/tensorflow/lite/experimental/litert/core/environment.cc deleted file mode 100644 index 838a587fc471e6..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/environment.h" - -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -litert::Expected LiteRtEnvironmentT::CreateWithOptions( - absl::Span options) { - LITERT_LOG(LITERT_INFO, "Creating LiteRT environment with options"); - auto env = std::make_unique(); - for (const auto& opt : options) { - env->options_.SetOption(opt); - } - - return env; -} diff --git a/tensorflow/lite/experimental/litert/core/environment.h b/tensorflow/lite/experimental/litert/core/environment.h deleted file mode 100644 index 0ac5d42b0fc6b2..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment_options.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator_registry.h" - -// A singleton class that contains global LiteRT environment options. -class LiteRtEnvironmentT { - public: - using Ptr = std::unique_ptr; - - LiteRtEnvironmentT() = default; - // Create an environment instance with options. 
-  static litert::Expected<Ptr> CreateWithOptions(
-      absl::Span<const LiteRtEnvOption> options);
-
-  ~LiteRtEnvironmentT() = default;
-
-  std::optional<LiteRtAny> GetOption(LiteRtEnvOptionTag tag) const {
-    auto opt = options_.GetOption(tag);
-    return opt.HasValue() ? std::optional<LiteRtAny>(opt.Value())
-                          : std::nullopt;
-  }
-
-  LiteRtEnvironmentOptionsT& GetOptions() { return options_; }
-  const LiteRtEnvironmentOptionsT& GetOptions() const { return options_; }
-
-  litert::internal::AcceleratorRegistry& GetAcceleratorRegistry() {
-    return accelerators_;
-  }
-
- private:
-  litert::internal::AcceleratorRegistry accelerators_;
-  LiteRtEnvironmentOptionsT options_;
-};
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_
diff --git a/tensorflow/lite/experimental/litert/core/environment_options.cc b/tensorflow/lite/experimental/litert/core/environment_options.cc
deleted file mode 100644
index ce1ea724ae83de..00000000000000
--- a/tensorflow/lite/experimental/litert/core/environment_options.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/core/environment_options.h"
-
-#include <string>
-#include <utility>
-
-#include "tensorflow/lite/experimental/litert/c/litert_any.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-LiteRtEnvironmentOptionsT::LiteRtEnvironmentOptionsT(
-    LiteRtEnvironmentOptionsT&& other)
-    : options_(std::move(other.options_)),
-      string_option_values_(std::move(other.string_option_values_)) {
-  // Update the string pointers in case they have changed when moving the
-  // container. This can happen because of small string optimization.
-  RefreshStringOptionValuePointers();
-}
-
-LiteRtEnvironmentOptionsT& LiteRtEnvironmentOptionsT::operator=(
-    LiteRtEnvironmentOptionsT&& other) {
-  options_ = std::move(other.options_);
-  string_option_values_ = std::move(other.string_option_values_);
-  // Update the string pointers in case they have changed when moving the
-  // container. This can happen because of small string optimization.
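// (options_ caches raw const char* pointers into the strings held by
// string_option_values_; per the note above, a small string's characters may
// be stored inline and thus move with the container, so the cached pointers
// are recomputed here.)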
- RefreshStringOptionValuePointers(); - return *this; -} - -void LiteRtEnvironmentOptionsT::RefreshStringOptionValuePointers() { - for (const auto& [tag, value] : string_option_values_) { - options_[tag].str_value = value.c_str(); - } -} - -litert::Expected LiteRtEnvironmentOptionsT::GetOption( - LiteRtEnvOptionTag tag) const { - if (auto it = options_.find(tag); it != options_.end()) { - return it->second; - } - return litert::Error(kLiteRtStatusErrorNotFound, - "Option was not set for this environment."); -} - -litert::Expected LiteRtEnvironmentOptionsT::SetOption( - LiteRtEnvOption option) { - if (option.value.type == kLiteRtAnyTypeString) { - auto [string_it, _] = string_option_values_.insert_or_assign( - option.tag, option.value.str_value); - LiteRtAny value{/*type=*/kLiteRtAnyTypeString}; - value.str_value = string_it->second.c_str(); - options_[option.tag] = value; - } else { - options_[option.tag] = option.value; - } - return {}; -} diff --git a/tensorflow/lite/experimental/litert/core/environment_options.h b/tensorflow/lite/experimental/litert/core/environment_options.h deleted file mode 100644 index 92133fe59835c7..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment_options.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_OPTIONS_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -class LiteRtEnvironmentOptionsT { - public: - LiteRtEnvironmentOptionsT() = default; - - LiteRtEnvironmentOptionsT(LiteRtEnvironmentOptionsT&& other); - LiteRtEnvironmentOptionsT& operator=(LiteRtEnvironmentOptionsT&& other); - - litert::Expected GetOption(LiteRtEnvOptionTag tag) const; - litert::Expected SetOption(LiteRtEnvOption option); - - private: - void RefreshStringOptionValuePointers(); - - std::unordered_map options_; - std::unordered_map string_option_values_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/core/environment_options_test.cc b/tensorflow/lite/experimental/litert/core/environment_options_test.cc deleted file mode 100644 index 62294328590487..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment_options_test.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/environment_options.h" - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using testing::Eq; -using testing::Ne; -using testing::litert::IsError; - -TEST(EnvironmentOptionsTest, SetGetStringOptionWorks) { - LiteRtEnvironmentOptionsT options; - constexpr const char* kStrValue = "string_value"; - LiteRtEnvOption env_option{/*tag=*/kLiteRtEnvOptionTagDispatchLibraryDir, - /*value=*/{/*type=*/kLiteRtAnyTypeString}}; - env_option.value.str_value = kStrValue; - options.SetOption(env_option); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtAny stored_option, - options.GetOption(kLiteRtEnvOptionTagDispatchLibraryDir)); - - EXPECT_THAT(stored_option.type, Eq(kLiteRtAnyTypeString)); - EXPECT_THAT(stored_option.str_value, Ne(nullptr)); - EXPECT_THAT(stored_option.str_value, Ne(kStrValue)); -} - -TEST(EnvironmentOptionsTest, SetGetIntOptionWorks) { - constexpr int kIntValue = 3; - LiteRtEnvironmentOptionsT options; - LiteRtEnvOption env_option{/*tag=*/kLiteRtEnvOptionTagOpenClDeviceId, - /*value=*/{/*type=*/kLiteRtAnyTypeInt}}; - env_option.value.int_value = kIntValue; - options.SetOption(env_option); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtAny stored_option, - options.GetOption(kLiteRtEnvOptionTagOpenClDeviceId)); - - EXPECT_THAT(stored_option.type, Eq(kLiteRtAnyTypeInt)); - EXPECT_THAT(stored_option.int_value, Eq(kIntValue)); -} - -TEST(EnvironmentOptionsTest, GetNotSetReturnsNotFound) { - LiteRtEnvironmentOptionsT options; - - // Add a non related option. - constexpr const char* kStrValue = "string_value"; - LiteRtEnvOption env_option{/*tag=*/kLiteRtEnvOptionTagDispatchLibraryDir, - /*value=*/{/*type=*/kLiteRtAnyTypeString}}; - env_option.value.str_value = kStrValue; - options.SetOption(env_option); - - // Request an option that wasn't added. - EXPECT_THAT(options.GetOption(kLiteRtEnvOptionTagOpenClDeviceId), - IsError(kLiteRtStatusErrorNotFound)); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/core/environment_test.cc b/tensorflow/lite/experimental/litert/core/environment_test.cc deleted file mode 100644 index b0d199a53173cf..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment_test.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
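// The CheckStringCopy test below pins down an ownership detail: string option
// values are copied into the environment at SetOption time (see
// environment_options.cc above), so the caller's string may be destroyed
// after CreateWithOptions without invalidating later GetOption results.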
- -#include "tensorflow/lite/experimental/litert/core/environment.h" - -#include -#include -#include - -#include -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" - -namespace litert::internal { -namespace { - -TEST(LiteRtEnvironmentT, CreateWithOptions) { - const std::array environment_options = { - LiteRtEnvOption{ - kLiteRtEnvOptionTagCompilerPluginLibraryDir, - *ToLiteRtAny(std::any("sample path")), - }, - }; - auto env = LiteRtEnvironmentT::CreateWithOptions(environment_options); - ASSERT_TRUE(env); - - auto option = (*env)->GetOption(kLiteRtEnvOptionTagCompilerPluginLibraryDir); - ASSERT_TRUE(option.has_value()); - ASSERT_EQ(option->type, kLiteRtAnyTypeString); - ASSERT_STREQ(option->str_value, "sample path"); -} - -TEST(LiteRtEnvironmentT, CheckStringCopy) { - LiteRtEnvironmentT::Ptr env; - - // The passed string becomes obsolete after the scope. - { - const std::array environment_options = { - LiteRtEnvOption{ - kLiteRtEnvOptionTagCompilerPluginLibraryDir, - *ToLiteRtAny(std::any("sample path")), - }, - }; - auto res = LiteRtEnvironmentT::CreateWithOptions(environment_options); - ASSERT_TRUE(res); - env = std::move(*res); - } - - auto option = env->GetOption(kLiteRtEnvOptionTagCompilerPluginLibraryDir); - ASSERT_TRUE(option.has_value()); - ASSERT_EQ(option->type, kLiteRtAnyTypeString); - ASSERT_STREQ(option->str_value, "sample path"); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/filesystem.cc b/tensorflow/lite/experimental/litert/core/filesystem.cc deleted file mode 100644 index e97a583aee27af..00000000000000 --- a/tensorflow/lite/experimental/litert/core/filesystem.cc +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
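// Sketch of the public surface implemented below (paths illustrative):
//
//   const std::string p = Join({"some/dir", "model.tflite"});
//   Touch(p);                        // create an empty file
//   if (Exists(p)) {
//     auto size = Size(p);           // Expected<size_t>
//     auto buf = LoadBinaryFile(p);  // Expected<OwningBufferRef<uint8_t>>
//   }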
- -#include "tensorflow/lite/experimental/litert/core/filesystem.h" - -#include -#include -#include // NOLINT -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert::internal { - -namespace { - -using StdPath = std::filesystem::path; - -StdPath MakeStdPath(absl::string_view path) { - return StdPath(std::string(path.begin(), path.end())); -} - -bool StdExists(const StdPath& std_path) { - return std::filesystem::exists(std_path); -} - -size_t StdSize(const StdPath& std_path) { - return std::filesystem::file_size(std_path); -} - -LiteRtStatus StdIFRead(const StdPath& std_path, char* data, size_t size) { - std::ifstream in_file_stream(std_path, std::ifstream::binary); - if (!in_file_stream) { - return kLiteRtStatusErrorFileIO; - } - - in_file_stream.read(data, size); - if (!in_file_stream) { - return kLiteRtStatusErrorFileIO; - } - - in_file_stream.close(); - return kLiteRtStatusOk; -} - -} // namespace - -void Touch(absl::string_view path) { std::ofstream(MakeStdPath(path)); } - -std::string Join(const std::vector& paths) { - StdPath std_path; - for (auto subpath : paths) { - std_path /= MakeStdPath(subpath); - } - return std_path.generic_string(); -} - -bool Exists(absl::string_view path) { return StdExists(MakeStdPath(path)); } - -Expected Size(absl::string_view path) { - auto std_path = MakeStdPath(path); - if (!StdExists(std_path)) { - return Error(kLiteRtStatusErrorNotFound, - absl::StrFormat("File not found: %s", std_path.c_str())); - } - return StdSize(std_path); -} - -Expected> LoadBinaryFile(absl::string_view path) { - auto std_path = MakeStdPath(path); - - if (!StdExists(std_path)) { - return Error(kLiteRtStatusErrorNotFound, - absl::StrFormat("File not found: %s", std_path.c_str())); - } - - OwningBufferRef buf(StdSize(std_path)); - LITERT_RETURN_IF_ERROR(StdIFRead(std_path, buf.StrData(), buf.Size())); - - return buf; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/filesystem.h b/tensorflow/lite/experimental/litert/core/filesystem.h deleted file mode 100644 index 3de517dfd4d5c6..00000000000000 --- a/tensorflow/lite/experimental/litert/core/filesystem.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FILESYSTEM_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FILESYSTEM_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <string>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-// Generic file operations. Try to encapsulate the std filesystem header as
-// much as possible because it's technically unapproved.
-
-namespace litert::internal {
-
-// Append all given subpaths together (e.g. os.path.join).
-std::string Join(const std::vector<absl::string_view>& paths);
-
-// Make a new empty file at the given path.
-void Touch(absl::string_view path);
-
-// Does this file exist?
-bool Exists(absl::string_view path);
-
-// Get the size of the file.
-Expected<size_t> Size(absl::string_view path);
-
-// Load the bytes of the file at the given path.
-Expected<OwningBufferRef<uint8_t>> LoadBinaryFile(absl::string_view path);
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FILESYSTEM_H_
diff --git a/tensorflow/lite/experimental/litert/core/filesystem_test.cc b/tensorflow/lite/experimental/litert/core/filesystem_test.cc
deleted file mode 100644
index d961d469d10100..00000000000000
--- a/tensorflow/lite/experimental/litert/core/filesystem_test.cc
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/core/filesystem.h"
-
-#include <gtest/gtest.h>
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-
-namespace litert::internal {
-namespace {
-
-static constexpr absl::string_view kPrefix = "a/prefix";
-static constexpr absl::string_view kInfix = "an/infix";
-static constexpr absl::string_view kSuffix = "suffix.ext";
-
-TEST(FilesystemTest, JoinTwo) {
-  const auto path = Join({kPrefix, kSuffix});
-  EXPECT_EQ(path, absl::StrFormat("%s/%s", kPrefix, kSuffix));
-}
-
-TEST(FilesystemTest, JoinMany) {
-  const auto path = Join({kPrefix, kInfix, kSuffix});
-  EXPECT_EQ(path, absl::StrFormat("%s/%s/%s", kPrefix, kInfix, kSuffix));
-}
-
-}  // namespace
-}  // namespace litert::internal
diff --git a/tensorflow/lite/experimental/litert/core/insert_order_map.h b/tensorflow/lite/experimental/litert/core/insert_order_map.h
deleted file mode 100644
index f1c9ca46804943..00000000000000
--- a/tensorflow/lite/experimental/litert/core/insert_order_map.h
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_INSERT_ORDER_MAP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_INSERT_ORDER_MAP_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" - -namespace litert::internal { - -// A map implementation that iterates in the same order as initial insertion. -template -class InsertOrderMap { - public: - using Pair = std::pair; - using Values = std::vector; - using ValRef = std::reference_wrapper; - using Map = absl::flat_hash_map; - using Iterator = typename Values::iterator; - - InsertOrderMap() = default; - - std::optional Find(const Key& key) { - if (auto it = map_.find(key); it != map_.end()) { - const auto ind = it->second; - return values_[ind]; - } - return {}; - } - - bool Contains(const Key& key) const { return map_.find(key) != map_.end(); } - - void InsertOrAssign(const Key& key, const Val& val) { - if (auto it = map_.find(key); it != map_.end()) { - const auto ind = it->second; - values_[ind].second = val; - } else { - values_.push_back({key, val}); - map_.insert({key, values_.size() - 1}); - } - } - - size_t Size() const { return values_.size(); } - - void Clear() { - values_.clear(); - map_.clear(); - } - - Iterator Begin() { return values_.begin(); } - - Iterator End() { return values_.end(); } - - private: - Values values_; - Map map_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_INSERT_ORDER_MAP_H_ diff --git a/tensorflow/lite/experimental/litert/core/insert_order_map_test.cc b/tensorflow/lite/experimental/litert/core/insert_order_map_test.cc deleted file mode 100644 index 6c24a01be97bdf..00000000000000 --- a/tensorflow/lite/experimental/litert/core/insert_order_map_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
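The deleted tests below exercise this contract; as a compact sketch of the intended usage first (only the InsertOrderMap API above is real, the wrapper function is illustrative): lookups go through the hash map in O(1), while iteration follows first-insertion order even after a value is reassigned.

#include <string>

#include "tensorflow/lite/experimental/litert/core/insert_order_map.h"

namespace litert::internal {

void InsertOrderSketch() {
  InsertOrderMap<int, std::string> map;
  map.InsertOrAssign(2, "two");
  map.InsertOrAssign(1, "one");
  // Reassignment updates the value in place but keeps the original slot,
  // so iteration still yields the key-2 entry first.
  map.InsertOrAssign(2, "TWO");

  for (auto it = map.Begin(); it != map.End(); ++it) {
    // it->first is the key, it->second the value; order is 2 then 1.
  }

  // Find returns a reference wrapper into the value storage, or nullopt.
  if (auto entry = map.Find(1)) {
    entry->get().second = "ONE";  // Mutates the stored value.
  }
}

}  // namespace litert::internal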
- -#include "tensorflow/lite/experimental/litert/core/insert_order_map.h" - -#include -#include -#include - -#include -#include - -namespace litert::internal { -namespace { - -using ::testing::ElementsAre; - -using TestMap = InsertOrderMap; - -static constexpr int k1 = 1; -static constexpr int k2 = 2; -static constexpr int k3 = 3; -static constexpr int k4 = 4; -static constexpr const char kV1[] = "1"; -static constexpr const char kV2[] = "2"; -static constexpr const char kV3[] = "3"; -static constexpr const char kV4[] = "4"; - -TestMap MakeTestMap() { - TestMap map; - map.InsertOrAssign(k1, kV1); - map.InsertOrAssign(k2, kV2); - map.InsertOrAssign(k3, kV3); - return map; -} - -TEST(InsertOrderMapTest, IterateInInsertOrder) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - std::vector values(map.Begin(), map.End()); - EXPECT_THAT(values, - ElementsAre(std::make_pair(k1, kV1), std::make_pair(k2, kV2), - std::make_pair(k3, kV3))); -} - -TEST(InsertOrderMapTest, IterateInInsertOrderWithUpdate) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - map.InsertOrAssign(k1, kV4); - std::vector values(map.Begin(), map.End()); - EXPECT_THAT(values, - ElementsAre(std::make_pair(k1, kV4), std::make_pair(k2, kV2), - std::make_pair(k3, kV3))); -} - -TEST(InsertOrderMapTest, FindExisting) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - auto val = map.Find(k1); - ASSERT_TRUE(val.has_value()); - EXPECT_EQ(val->get().first, k1); - EXPECT_EQ(val->get().second, kV1); - - EXPECT_TRUE(map.Contains(k1)); -} - -TEST(InsertOrderMapTest, FindMissing) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - EXPECT_EQ(map.Find(k4), std::nullopt); - EXPECT_FALSE(map.Contains(k4)); -} - -TEST(InsertOrderMapTest, Clear) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - map.Clear(); - EXPECT_EQ(map.Size(), 0); - EXPECT_EQ(map.Begin(), map.End()); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/BUILD b/tensorflow/lite/experimental/litert/core/model/BUILD deleted file mode 100644 index 071d3200041830..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/BUILD +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:special_rule.bzl", "lite_rt_friends") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - ] + lite_rt_friends(), -) - -cc_library( - name = "model", - srcs = ["model.cc"], - hdrs = [ - "model.h", - "//tensorflow/lite/experimental/litert/c:litert_model_hdrs", - ], - deps = [ - ":buffer_manager", - ":ir_allocator", - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/lite/core/c:c_api_types", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_layout", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "model_test", - srcs = ["model_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - deps = [ - ":buffer_manager", - ":model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "model_load", - srcs = ["model_load.cc"], - hdrs = ["model_load.h"], - deps = [ - ":buffer_manager", - ":flatbuffer_to_litert", - ":model", - ":model_graph", - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "model_file_test", - srcs = ["model_file_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - # copybara:uncomment "//tensorflow/lite/java/demo/app/src/main/assets:mobilenet_v1_1.0_224.tflite", - ], - deps = [ - ":buffer_manager", - ":graph_validation", - ":model", - ":model_file_test_util", - ":model_load", - ":model_serialize", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - 
"//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:test_models", - "//tensorflow/lite/schema:schema_fbs_with_mutable", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "model_serialize", - srcs = ["model_serialize.cc"], - hdrs = ["model_serialize.h"], - deps = [ - ":litert_to_flatbuffer", - ":model", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - "//tensorflow/lite/experimental/litert/core:insert_order_map", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs_with_mutable", - "@com_google_absl//absl/container:flat_hash_map", - ], -) - -cc_library( - name = "flatbuffer_to_litert", - srcs = ["flatbuffer_to_litert.cc"], - hdrs = ["flatbuffer_to_litert.h"], - deps = [ - ":model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs", - ], -) - -cc_test( - name = "flatbuffer_to_litert_test", - srcs = ["flatbuffer_to_litert_test.cc"], - deps = [ - ":flatbuffer_to_litert", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_to_flatbuffer", - srcs = ["litert_to_flatbuffer.cc"], - hdrs = ["litert_to_flatbuffer.h"], - deps = [ - ":model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "litert_to_flatbuffer_test", - srcs = ["litert_to_flatbuffer_test.cc"], - deps = [ - ":litert_to_flatbuffer", - ":model", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "model_buffer", - srcs = ["model_buffer.cc"], - hdrs = ["model_buffer.h"], - deps = [ - ":model", - ":model_load", - ":model_serialize", - "//tensorflow/lite/experimental/litert/c:litert_common", - 
"//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:filesystem", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "model_file_test_util", - testonly = 1, - srcs = ["model_file_test_util.cc"], - hdrs = ["model_file_test_util.h"], - deps = [ - ":flatbuffer_to_litert", - ":model", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "ir_allocator", - hdrs = ["ir_allocator.h"], - deps = ["@com_google_absl//absl/types:span"], -) - -cc_test( - name = "ir_allocator_test", - srcs = ["ir_allocator_test.cc"], - deps = [ - ":ir_allocator", - ":model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "model_graph", - srcs = ["model_graph.cc"], - hdrs = [ - "model_graph.h", - "//tensorflow/lite/experimental/litert/cc:litert_consts.h", - ], - deps = [ - ":model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:absl_check", - ], -) - -cc_library( - name = "graph_validation", - srcs = ["graph_validation.cc"], - hdrs = ["graph_validation.h"], - deps = [ - ":model", - ":model_graph", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - ], -) - -cc_library( - name = "buffer_manager", - hdrs = ["buffer_manager.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) - -cc_test( - name = "model_graph_test", - srcs = ["model_graph_test.cc"], - deps = [ - ":graph_validation", - ":ir_allocator", - ":model", - ":model_graph", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "model_buffer_test", - srcs = ["model_buffer_test.cc"], - deps = [ - ":model", - ":model_buffer", - ":model_load", - "//tensorflow/compiler/mlir/lite:allocation", - "//tensorflow/lite:framework", - "//tensorflow/lite:model_builder", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_cascade_model_npu", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/kernels:builtin_ops", - 
"@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "buffer_manager_test", - srcs = ["buffer_manager_test.cc"], - deps = [ - ":buffer_manager", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/core/model/buffer_manager.h b/tensorflow/lite/experimental/litert/core/model/buffer_manager.h deleted file mode 100644 index af6b97f15c052b..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/buffer_manager.h +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_BUFFER_MANAGER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_BUFFER_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -// Extra info about how the buffer is handled during load or serialization. -struct BufferContext { - using Ref = std::reference_wrapper; - - // Whether the buffer should be appended to the flatbuffer during - // serialization. - bool should_append = false; -}; - -// Container type for efficiently holding data buffers used by the model. These -// buffers may be owned or non-owned by the model. Uses id based indexing. -class BufferManager { - public: - using Ptr = std::unique_ptr; - - // Unique identifier for a buffer. 0 is reserved for empty buffers. - using BufferId = uint32_t; - static constexpr BufferId kEmptyBufferId = 0; - - // Register a buffer that is not owned by the model. Caller must ensure the - // buffer outlives the model. - BufferId RegisterNonOwnedBuffer( - BufferRef buffer, - std::optional context = std::nullopt) { - auto&& ctx = context.has_value() ? std::move(*context) : BufferContext{}; - buffers_.emplace_back(BufferWithContext(buffer, std::move(ctx))); - return buffers_.size() - 1; - } - - // Register a buffer that is owned by the model. - BufferId RegisterOwnedBuffer( - OwningBufferRef&& buffer, - std::optional context = std::nullopt) { - auto&& ctx = context.has_value() ? std::move(*context) : BufferContext{}; - buffers_.emplace_back(BufferWithContext(buffer, std::move(ctx))); - return buffers_.size() - 1; - } - - // Get a view of the buffer at the given id. - Expected> GetBuffer(BufferId id) { - if (id >= buffers_.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return GetView(buffers_[id].first); - } - - // Get the context of the buffer at the given id. 
- Expected GetContext(BufferId id) { - if (id >= buffers_.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return std::ref(buffers_[id].second); - } - - // Number of buffers. Ids will be 0 <-> num - 1. - size_t NumBuffers() const { return buffers_.size(); } - - BufferManager() { - // Zero is reserved for empty buffers. - buffers_.emplace_back( - BufferWithContext(BufferRef(), BufferContext{})); - } - BufferManager(const BufferManager&) = delete; - BufferManager& operator=(const BufferManager&) = delete; - BufferManager(BufferManager&& other) = default; - BufferManager& operator=(BufferManager&& other) = default; - - private: - using BufferType = std::variant, OwningBufferRef>; - using BufferWithContext = std::pair; - - static BufferRef GetView(const BufferType& buffer) { - BufferRef res; - std::visit([&res](auto&& arg) { res = arg; }, buffer); - return res; - } - - std::vector buffers_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_BUFFER_MANAGER_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/buffer_manager_test.cc b/tensorflow/lite/experimental/litert/core/model/buffer_manager_test.cc deleted file mode 100644 index b077eda8b4f3da..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/buffer_manager_test.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
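A hedged sketch of the ownership model above, assuming only the BufferManager API as declared (the wrapper function and buffer sizes are illustrative); the deleted tests that follow cover the same paths.

#include <cstdint>
#include <utility>

#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h"

namespace litert::internal {

void BufferManagerSketch(BufferRef<uint8_t> external_weights) {
  BufferManager manager;  // NumBuffers() == 1: id 0 is the empty buffer.

  // Non-owned: the manager stores only a view; the caller must keep
  // `external_weights` alive for the manager's lifetime.
  const auto weights_id = manager.RegisterNonOwnedBuffer(external_weights);

  // Owned: the buffer is moved in; the context asks serialization to
  // append it to the flatbuffer.
  OwningBufferRef<uint8_t> scratch(16);
  BufferContext ctx;
  ctx.should_append = true;
  const auto scratch_id = manager.RegisterOwnedBuffer(std::move(scratch), ctx);

  if (auto view = manager.GetBuffer(weights_id)) {
    // view->Data() / view->Size() alias the registered storage.
  }
  (void)scratch_id;
}

}  // namespace litert::internal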
- -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" - -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" - -namespace litert::internal { - -namespace { - -static constexpr absl::string_view kData = "foo"; - -TEST(BufferManagerTest, EmptyFirstBuffer) { - BufferManager manager; - - EXPECT_EQ(manager.NumBuffers(), 1); - EXPECT_EQ(manager.GetBuffer(BufferManager::kEmptyBufferId)->Size(), 0); -} - -TEST(BufferManagerTest, RegisterNonOwnedBuffer) { - BufferManager manager; - - OwningBufferRef buffer(kData); - const auto id = manager.RegisterNonOwnedBuffer(buffer); - - EXPECT_EQ(manager.NumBuffers(), 2); - EXPECT_EQ(manager.GetBuffer(id)->StrView(), kData); -} - -TEST(BufferManagerTest, RegisterOwnedBuffer) { - BufferManager manager; - - OwningBufferRef buffer(kData); - const auto id = manager.RegisterOwnedBuffer(std::move(buffer)); - - EXPECT_EQ(manager.NumBuffers(), 2); - EXPECT_EQ(manager.GetBuffer(id)->StrView(), kData); -} - -TEST(BufferManagerTest, RegisterWithContext) { - BufferManager manager; - - OwningBufferRef buffer(kData); - BufferContext context = {true}; - const auto id = manager.RegisterNonOwnedBuffer(buffer, context); - - EXPECT_EQ(manager.NumBuffers(), 2); - EXPECT_EQ(manager.GetBuffer(id)->StrView(), kData); - EXPECT_EQ(manager.GetContext(id)->get().should_append, true); -} - -} // namespace - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc deleted file mode 100644 index 36c721af2009cc..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { - -LiteRtStatus IsOpSupported(const tflite::OperatorT& op) { - // TODO: b/365299994 - Check for supported options. - - if (!op.intermediates.empty()) { - // TODO: b/365299994 - Support intermediates. - LITERT_LOG(LITERT_ERROR, "Intermediate tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (op.large_custom_options_size != 0) { - // TODO: b/365299994 - Support large custom options. 
- LITERT_LOG(LITERT_ERROR, "Large custom options not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - for (auto m_input : op.mutating_variable_inputs) { - if (m_input) { - // TODO: b/365299994 - Support mutating variable inputs. - LITERT_LOG(LITERT_ERROR, "Mutating variable inputs not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - } - - return kLiteRtStatusOk; -} - -LiteRtStatus IsBufferSupported(const tflite::BufferT& buffer) { - if (buffer.offset != 0) { - // TODO: b/365299994 - Support buffer with offset. - LITERT_LOG(LITERT_ERROR, "Buffers with offset not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - return kLiteRtStatusOk; -} - -LiteRtStatus IsTensorSupported(const TflTensor& tensor) { - if (tensor.is_variable) { - // TODO: b/365299994 - Support variable tensors. - LITERT_LOG(LITERT_ERROR, "Variable tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (!tensor.variant_tensors.empty()) { - // TODO: b/365299994 - Support variant tensors. - LITERT_LOG(LITERT_ERROR, "Variant tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (tensor.sparsity) { - // TODO: b/365299994 - Support sparsity tensors. - LITERT_LOG(LITERT_ERROR, "Sparsity tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - return kLiteRtStatusOk; -} - -LiteRtElementType MapElementType(TflElementType type) { - switch (type) { - case tflite::TensorType_FLOAT32: - return kLiteRtElementTypeFloat32; - case tflite::TensorType_FLOAT16: - return kLiteRtElementTypeFloat16; - case tflite::TensorType_INT32: - return kLiteRtElementTypeInt32; - case tflite::TensorType_INT64: - return kLiteRtElementTypeInt64; - case tflite::TensorType_BOOL: - return kLiteRtElementTypeBool; - case tflite::TensorType_INT16: - return kLiteRtElementTypeInt16; - case tflite::TensorType_INT8: - return kLiteRtElementTypeInt8; - case tflite::TensorType_UINT8: - return kLiteRtElementTypeUInt8; - case tflite::TensorType_INT4: - return kLiteRtElementTypeInt4; - default: - return kLiteRtElementTypeNone; - } -} - -Expected MapTensorType(const TflTensorType& tfl_tensor_type) { - const auto& [element_type, shape] = tfl_tensor_type; - auto ranked_shape = AsDynamicShape(shape); - if (!ranked_shape) { - LITERT_LOG(LITERT_ERROR, "Only ranked tensors currently supported"); - return Error(kLiteRtStatusErrorUnsupported); - } - - auto litert_element_type = MapElementType(element_type); - if (litert_element_type == kLiteRtElementTypeNone) { - LITERT_LOG(LITERT_ERROR, "Element type not currently supported"); - return Error(kLiteRtStatusErrorUnsupported); - } - - TensorTypeDetail detail; - detail.ranked_tensor_type.element_type = litert_element_type; - detail.ranked_tensor_type.layout = BuildLayout(*ranked_shape); - - return std::make_pair(kLiteRtRankedTensorType, detail); -} - -Expected MapQuantization(const TflQuantization* tfl_quantization, - ScratchBufferProvider buffer_provider) { - if (!IsQuantized(tfl_quantization)) { - return MakeEmptyQuantization(); - } - - if (auto tfl_qparams = AsPerTensorQparams(tfl_quantization)) { - return MakePerTensorQuantization(tfl_qparams->second, tfl_qparams->first); - } - - if (auto tfl_qparams = AsPerChannelQparams(tfl_quantization)) { - [[maybe_unused]] const auto& [quantized_dimension, num_channels, - zero_points, scales] = *tfl_qparams; - return MakePerChannelQuantization(scales, zero_points, quantized_dimension, - buffer_provider); - } - - LITERT_LOG(LITERT_ERROR, "Uknown tfl quantization type"); - return 
Error(kLiteRtStatusErrorUnsupported); -} -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h deleted file mode 100644 index 92a7d11cdf0321..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { - -LiteRtStatus IsOpSupported(const TflOp& op); - -LiteRtStatus IsBufferSupported(const TflBuffer& buffer); - -// Checks if the misc non-type non quantization parts of this tensor are -// supported in the litet model api. -LiteRtStatus IsTensorSupported(const TflTensor& tensor); - -LiteRtElementType MapElementType(TflElementType element_type); - -Expected MapTensorType(const TflTensorType& tfl_tensor_type); - -Expected MapQuantization(const TflQuantization* tfl_quantization, - ScratchBufferProvider buffer_provider); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc deleted file mode 100644 index b0a2e6598a683f..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
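A hedged sketch of the flatbuffer-to-litert direction declared above, mirroring the deleted tests that follow; the wrapper function is illustrative, the mapped types and constructors are the ones used by those tests.

#include <cstdint>
#include <utility>

#include "absl/types/span.h"
#include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h"
#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h"

namespace litert::internal {

void MapTensorTypeSketch() {
  static constexpr int32_t kDims[] = {1, 4};

  // A flatbuffer (element type, shape) pair maps to a litert ranked tensor
  // type; unsupported element types come back as errors rather than CHECKs.
  auto mapped = MapTensorType(
      std::make_pair(TflElementType::TensorType_FLOAT32,
                     TflShapeInfo(absl::MakeConstSpan(kDims))));
  if (mapped) {
    // mapped->first == kLiteRtRankedTensorType and mapped->second carries
    // the element type plus a layout built from the shape.
  }
}

}  // namespace litert::internal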
- -#include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" - -#include -#include -#include - -#include -#include -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAreArray; - -TEST(FlatbufferToLiteRtTest, MapStaticTensorType) { - static constexpr int32_t kDims[] = {2, 2}; - static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); - - auto t = MapTensorType(std::make_pair(TflElementType::TensorType_INT32, - TflShapeInfo(kDimsSpan))); - ASSERT_TRUE(t); - - ASSERT_EQ(t->first, kLiteRtRankedTensorType); - auto& ranked = t->second.ranked_tensor_type; - EXPECT_EQ(ranked.element_type, kLiteRtElementTypeInt32); - EXPECT_EQ(absl::MakeSpan(ranked.layout.dimensions, ranked.layout.rank), - kDimsSpan); -} - -TEST(FlatbufferToLiteRtTest, MapStaticTensorInt4Type) { - static constexpr int32_t kDims[] = {2, 2}; - static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); - - auto t = MapTensorType( - std::make_pair(TflElementType::TensorType_INT4, TflShapeInfo(kDimsSpan))); - ASSERT_TRUE(t); - - ASSERT_EQ(t->first, kLiteRtRankedTensorType); - auto& ranked = t->second.ranked_tensor_type; - EXPECT_EQ(ranked.element_type, kLiteRtElementTypeInt4); - EXPECT_EQ(absl::MakeSpan(ranked.layout.dimensions, ranked.layout.rank), - kDimsSpan); -} - -TEST(FlatbufferToLiteRtTest, MapDynamicTensorType) { - static constexpr int32_t kDims[] = {-1, 2}; - static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); - - auto t = MapTensorType(std::make_pair(TflElementType::TensorType_INT32, - TflShapeInfo(kDimsSpan))); - ASSERT_TRUE(t); - - ASSERT_EQ(t->first, kLiteRtRankedTensorType); - auto& ranked = t->second.ranked_tensor_type; - EXPECT_EQ(ranked.element_type, kLiteRtElementTypeInt32); - EXPECT_EQ(absl::MakeSpan(ranked.layout.dimensions, ranked.layout.rank), - kDimsSpan); -} - -TEST(FlatbufferToLiteRtTest, MapNoQuantization) { - LiteRtTensorT tensor; - auto q = MapQuantization(nullptr, tensor); - ASSERT_TRUE(q); - ASSERT_EQ(q->first, kLiteRtQuantizationNone); -} - -TEST(FlatbufferToLiteRtTest, MapPerTensorQuantization) { - static constexpr float kScale = 1.0; - static constexpr int64_t kZp = 2; - - TflQuantization tfl_q; - tfl_q.scale.assign({kScale}); - tfl_q.zero_point.assign({kZp}); - - LiteRtTensorT tensor; - auto q = MapQuantization(&tfl_q, tensor); - ASSERT_TRUE(q); - ASSERT_EQ(q->first, kLiteRtQuantizationPerTensor); - EXPECT_EQ(q->second.per_tensor.scale, kScale); - EXPECT_EQ(q->second.per_tensor.zero_point, kZp); -} - -TEST(FlatbufferToLiteRtTest, MapPerChannelQuantization) { - static constexpr size_t kRank = 2; - static constexpr float kScales[kRank] = {1.0, 2.0}; - static constexpr int64_t kZps[kRank] = {2, 3}; - static constexpr size_t kQDim = 1; - - TflQuantization tfl_q; - tfl_q.scale.assign(kScales, kScales + kRank); - tfl_q.zero_point.assign(kZps, kZps + kRank); - tfl_q.quantized_dimension = kQDim; - - LiteRtTensorT tensor; - auto q = MapQuantization(&tfl_q, tensor); - ASSERT_TRUE(q); - ASSERT_EQ(q->first, kLiteRtQuantizationPerChannel); - EXPECT_THAT(absl::MakeConstSpan(q->second.per_channel.scales, kRank), - ElementsAreArray(kScales)); - - EXPECT_THAT(absl::MakeConstSpan(q->second.per_channel.zero_points, kRank), - ElementsAreArray(kZps)); - EXPECT_EQ(q->second.per_channel.quantized_dimension, kQDim); - EXPECT_EQ(q->second.per_channel.num_channels, kRank); -} - -} // namespace -} 
// namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/graph_validation.cc b/tensorflow/lite/experimental/litert/core/model/graph_validation.cc deleted file mode 100644 index a9a942c1bfaa14..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/graph_validation.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" - -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" - -namespace litert::internal { - -bool ValidateLocalTopology(const LiteRtOpT& litert_op) { - // Check number of in edges equals number of inputs and each input index - // appears on an in edge. - for (auto i = 0; i < litert_op.Inputs().size(); ++i) { - const auto& litert_tensor = litert_op.Input(i); - - auto input_use = - GetTensorUses(litert_tensor, FindUseInds(litert_tensor, litert_op)); - - if (!ContainsIf(input_use.cbegin(), input_use.cend(), - [i](auto u) { return u.second == i; })) { - LITERT_LOG(LITERT_WARNING, - "Input tensor %d not connected to op on correct index.", i); - return false; - } - } - - // Similar to above for outputs. 
- for (auto i = 0; i < litert_op.Outputs().size(); ++i) { - const auto& litert_tensor = litert_op.Output(i); - - if (litert_tensor.DefiningOp() != &litert_op) { - LITERT_LOG(LITERT_WARNING, "Output back edge doesn't refer to this op."); - return false; - } - - if (litert_tensor.DefiningOpOutInd() != i) { - LITERT_LOG(LITERT_WARNING, "Output back edge ind is incorrect."); - return false; - } - } - - return true; -} - -bool ValidateSubgraphIO(const LiteRtSubgraphT& litert_subgraph) { - auto num_implied_inputs = 0; - auto num_implied_outputs = 0; - for (auto* tensor : litert_subgraph.Tensors()) { - const auto implied_out = tensor->NumUses() == 0; - const auto implied_in = - !IsConstant(*tensor) && tensor->DefiningOp() == nullptr; - - if (implied_out && implied_in) { - LITERT_LOG(LITERT_WARNING, "Graph contains a dead tensor"); - return false; - } - - const auto is_io = IsIO(litert_subgraph, *tensor); - - if (implied_in) { - if (!is_io) { - LITERT_LOG(LITERT_WARNING, - "Implied input not reflected in subgraph io %lu", - tensor - litert_subgraph.Tensors().at(0)); - return false; - } - ++num_implied_inputs; - } - - if (implied_out) { - if (!is_io) { - LITERT_LOG(LITERT_WARNING, - "Implied output not reflected in subgraph io"); - return false; - } - ++num_implied_outputs; - } - } - - if (num_implied_inputs != litert_subgraph.NumInputs()) { - LITERT_LOG( - LITERT_WARNING, - "Number of implied %lu inputs not equal to number of actual inputs %lu", - num_implied_inputs, litert_subgraph.NumInputs()); - return false; - } - - if (num_implied_outputs != litert_subgraph.NumOutputs()) { - LITERT_LOG(LITERT_WARNING, - "Number of implied %lu outputs not equal to number of actual " - "outputs %lu", - num_implied_outputs, litert_subgraph.NumOutputs()); - return false; - } - - return true; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/graph_validation.h b/tensorflow/lite/experimental/litert/core/model/graph_validation.h deleted file mode 100644 index c0a199294f8677..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/graph_validation.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" - -// Helper functions for validating the structure of IR graphs. - -namespace litert::internal { - -// Checks the double-linked edges to immediate neighbors are valid. -bool ValidateLocalTopology(const LiteRtOpT& litert_op); - -// Runs ValidateLocalTopology across given LiteRtOp iterator. 
-template <class OpIt>
-bool ValidateLocalTopology(OpIt start, OpIt end) {
-  return std::all_of(start, end,
-                     [](const auto* op) { return ValidateLocalTopology(*op); });
-}
-
-// Checks the following are bijections:
-// * non-const tensor with no defining op <-> subgraph input
-// * tensor with no users <-> subgraph output (assuming no side effect ops)
-// These are used to figure out the i/o signatures when building a subgraph
-// from scratch.
-bool ValidateSubgraphIO(const LiteRtSubgraphT& litert_subgraph);
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_
diff --git a/tensorflow/lite/experimental/litert/core/model/ir_allocator.h b/tensorflow/lite/experimental/litert/core/model/ir_allocator.h
deleted file mode 100644
index 43433c1ecd02c8..00000000000000
--- a/tensorflow/lite/experimental/litert/core/model/ir_allocator.h
+++ /dev/null
@@ -1,153 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_
-
-#include <algorithm>
-#include <cstddef>
-#include <functional>
-#include <iterator>
-#include <list>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/types/span.h"
-
-namespace litert::internal {
-
-// A list of IR objects scoped to the same block (subgraph) that provides
-// pointer stability. Facilitates management of memory and c-like access
-// to elements.
-template <class Ir>
-class IrAllocator {
- private:
-  using Storage = std::list<Ir>;
-  using Refs = std::vector<Ir*>;
-
- public:
-  // Emplace a new element onto the list.
-  template <class... Args>
-  Ir& EmplaceBack(Args&&... args) {
-    auto& emp = storage_.emplace_back(std::forward<Args>(args)...);
-    refs_->push_back(&emp);
-    return emp;
-  }
-
-  // Get the array of (stable) pointers to underlying elements. Suitable
-  // for passing through a c-like interface. Constituent pointers are always
-  // guaranteed to be stable (unless explicitly erased). The array of pointers
-  // itself is guaranteed to be stable so long as no length-changing operations
-  // occur; moving this class does not invalidate the pointers or the array.
-  absl::Span<Ir*> Elements() const {
-    return absl::MakeSpan(refs_->data(), refs_->size());
-  }
-
-  // Remove elements from the allocator if they match the predicate.
-  // Returns the number of elements removed.
-  size_t RemoveIf(std::function<bool(const Ir& ir)> pred) {
-    auto ref_it = refs_->begin();
-    for (auto it = storage_.begin(); it != storage_.end();) {
-      if (!pred(*it)) {
-        *ref_it = &*it;
-        ++ref_it;
-        ++it;
-        continue;
-      }
-      it = storage_.erase(it);
-    }
-    const size_t removed = refs_->end() - ref_it;
-    refs_->resize(refs_->size() - removed);
-    return removed;
-  }
-
-  // Cuts all but the first `size` elements from storage. Does nothing if
-  // `size` is greater than or equal to the current size.
-  void ResizeDown(size_t size) {
-    if (size >= Size()) {
-      return;
-    }
-    storage_.resize(size);
-    refs_->resize(size);
-  }
-
-  // Transfers the ownership of the given allocator to this one. If `indices`
-  // is provided, only the objects at the given indices are transferred.
-  void TransferFrom(IrAllocator& other,
-                    std::optional<std::vector<size_t>> indices = std::nullopt) {
-    if (!indices) {
-      storage_.splice(storage_.cend(), other.storage_);
-      refs_->insert(refs_->end(), other.refs_->cbegin(), other.refs_->cend());
-      other.ResetRefs();
-      return;
-    }
-
-    auto& inds = *indices;
-    std::sort(inds.begin(), inds.end());
-    std::vector<typename Storage::iterator> its;
-    auto i = 0;
-    auto it = other.storage_.begin();
-    for (auto ind : inds) {
-      std::advance(it, ind - i);
-      i = ind;
-      its.push_back(it);
-    }
-    for (auto it : its) {
-      storage_.splice(storage_.cend(), other.storage_, it);
-    }
-
-    ResetRefs();
-    other.ResetRefs();
-  }
-
-  // Overload for rvalues.
-  void TransferFrom(IrAllocator&& other) { TransferFrom(other, std::nullopt); }
-
-  // Transfers the objects at the given indices to the back of the given
-  // allocator.
-  void TransferTo(IrAllocator& other,
-                  std::optional<std::vector<size_t>> indices = std::nullopt) {
-    other.TransferFrom(*this, std::move(indices));
-  }
-
-  // Number of elements stored by this allocator.
-  size_t Size() const { return storage_.size(); }
-
-  IrAllocator() { refs_ = std::make_unique<Refs>(); }
-
-  // IR is generally semantically movable (without reference invalidation)
-  // but not copyable. IrAllocators reflect that; note that moving the
-  // underlying lists does not invalidate references.
-  IrAllocator(const IrAllocator& other) = delete;
-  IrAllocator& operator=(const IrAllocator& other) = delete;
-  IrAllocator(IrAllocator&& other) = default;
-  IrAllocator& operator=(IrAllocator&& other) = default;
-
- private:
-  void ResetRefs() {
-    refs_->resize(storage_.size());
-    auto it = storage_.begin();
-    for (auto i = 0; i < storage_.size(); ++i, ++it) {
-      refs_->at(i) = &*it;
-    }
-  }
-
-  Storage storage_;
-  std::unique_ptr<Refs> refs_;
-};
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_
diff --git a/tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc b/tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc
deleted file mode 100644
index dd895dce211e25..00000000000000
--- a/tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc
+++ /dev/null
@@ -1,129 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
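A short sketch of the pointer-stability guarantee documented above, ahead of the deleted tests below; the wrapper function is illustrative, while LiteRtOpT and the op codes are the same ones those tests use.

#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h"
#include "tensorflow/lite/experimental/litert/core/model/model.h"

namespace litert::internal {

void IrAllocatorSketch() {
  IrAllocator<LiteRtOpT> ops;

  auto& op = ops.EmplaceBack();
  op.SetOpCode(kLiteRtOpCodeTflCustom);

  // Pointers handed out by Elements() survive later growth because the
  // elements live in a std::list, not a reallocating vector.
  LiteRtOp stable = ops.Elements().at(0);
  ops.EmplaceBack();

  // Survivors keep their addresses; only erased elements are invalidated.
  ops.RemoveIf([](const LiteRtOpT& o) {
    return o.OpCode() != kLiteRtOpCodeTflCustom;
  });
  (void)stable;
}

}  // namespace litert::internal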
- -#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" - -#include -#include -#include - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAreArray; - -static constexpr auto kCustomOpCode = kLiteRtOpCodeTflCustom; -static constexpr auto kNonCustomOpCode = kLiteRtOpCodeTflSoftmax; - -TEST(IrAllocatorTest, EmplaceBack) { - IrAllocator ops; - - LiteRtOpT my_op; - my_op.SetOpCode(kCustomOpCode); - - ops.EmplaceBack(std::move(my_op)); - ASSERT_EQ(ops.Elements().size(), 1); - EXPECT_EQ(ops.Elements().at(0)->OpCode(), kCustomOpCode); -} - -TEST(IrAllocatorTest, RemoveIf) { - IrAllocator ops; - - LiteRtOpT my_op; - my_op.SetOpCode(kNonCustomOpCode); - ops.EmplaceBack(std::move(my_op)); - - LiteRtOpT my_op2; - my_op2.SetOpCode(kCustomOpCode); - ops.EmplaceBack(std::move(my_op2)); - - LiteRtOpT my_op3; - my_op3.SetOpCode(kCustomOpCode); - ops.EmplaceBack(std::move(my_op3)); - - LiteRtOpT my_op4; - my_op4.SetOpCode(kNonCustomOpCode); - ops.EmplaceBack(std::move(my_op4)); - - auto pred = [](const auto& op) { return op.OpCode() != kCustomOpCode; }; - ASSERT_EQ(ops.RemoveIf(pred), 2); - - ASSERT_EQ(ops.Elements().size(), 2); - ASSERT_EQ(ops.Elements().at(0)->OpCode(), kCustomOpCode); - ASSERT_EQ(ops.Elements().at(1)->OpCode(), kCustomOpCode); -} - -TEST(IrAllocatorTest, ResizeDown) { - IrAllocator ops; - - LiteRtOp op1 = nullptr; - { - LiteRtOpT my_op; - my_op.SetOpCode(kNonCustomOpCode); - op1 = &ops.EmplaceBack(std::move(my_op)); - } - - { - LiteRtOpT my_op2; - my_op2.SetOpCode(kCustomOpCode); - ops.EmplaceBack(std::move(my_op2)); - } - - ops.ResizeDown(1); - - ASSERT_EQ(ops.Size(), 1); - EXPECT_EQ(ops.Elements().at(0), op1); -} - -TEST(IrAllocatorTest, Transfer) { - IrAllocator ops; - auto& op1 = ops.EmplaceBack(); - auto& op2 = ops.EmplaceBack(); - - IrAllocator other_ops; - auto& other_op1 = other_ops.EmplaceBack(); - auto& other_op2 = other_ops.EmplaceBack(); - - ops.TransferFrom(std::move(other_ops)); - - EXPECT_THAT(ops.Elements(), - ElementsAreArray({&op1, &op2, &other_op1, &other_op2})); -} - -TEST(IrAllocatorTest, TransferWithIndices) { - IrAllocator ops; - auto& op1 = ops.EmplaceBack(); - auto& op2 = ops.EmplaceBack(); - - IrAllocator other_ops; - auto& other_op1 = other_ops.EmplaceBack(); - auto& other_op2 = other_ops.EmplaceBack(); - auto& other_op3 = other_ops.EmplaceBack(); - auto& other_op4 = other_ops.EmplaceBack(); - - std::vector indices = {1, 3}; - ops.TransferFrom(other_ops, std::move(indices)); - - EXPECT_THAT(other_ops.Elements(), ElementsAreArray({&other_op1, &other_op3})); - EXPECT_THAT(ops.Elements(), - ElementsAreArray({&op1, &op2, &other_op2, &other_op4})); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc deleted file mode 100644 index 90292600ace0d3..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc +++ /dev/null @@ -1,128 +0,0 @@ - -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h" - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { - -namespace { - -Expected MapElementType(LiteRtElementType litert_element_type) { - switch (litert_element_type) { - case kLiteRtElementTypeFloat32: - return tflite::TensorType_FLOAT32; - case kLiteRtElementTypeFloat16: - return tflite::TensorType_FLOAT16; - case kLiteRtElementTypeInt32: - return tflite::TensorType_INT32; - case kLiteRtElementTypeInt64: - return tflite::TensorType_INT64; - case kLiteRtElementTypeBool: - return tflite::TensorType_BOOL; - case kLiteRtElementTypeInt16: - return tflite::TensorType_INT16; - case kLiteRtElementTypeInt8: - return tflite::TensorType_INT8; - default: - return Error(kLiteRtStatusErrorUnsupported); - } -} - -template -Expected MapTensorTypeDetail( - const LiteRtTenzorType& litert_tensor_type) { - return Error(kLiteRtStatusErrorUnsupported); -} - -template <> -Expected MapTensorTypeDetail( - const LiteRtRankedTensorType& litert_tensor_type) { - auto tfl_element_type = MapElementType(litert_tensor_type.element_type); - if (!tfl_element_type) { - return tfl_element_type.Error(); - } - - auto litert_shape = absl::MakeConstSpan(litert_tensor_type.layout.dimensions, - litert_tensor_type.layout.rank); - return std::make_pair(*tfl_element_type, TflShapeInfo(litert_shape)); -} - -template -Expected MapQuantizationDetail( - const LiteRtQuantDetail& litert_quantization) { - return Error(kLiteRtStatusErrorUnsupported); -} - -template <> -Expected MapQuantizationDetail( - const LiteRtQuantizationPerTensor& litert_quantization) { - auto tfl_quantization = std::make_unique(); - tfl_quantization->scale.assign({litert_quantization.scale}); - tfl_quantization->zero_point.assign({litert_quantization.zero_point}); - return tfl_quantization; -} - -template <> -Expected -MapQuantizationDetail( - const LiteRtQuantizationPerChannel& litert_quantization) { - auto tfl_quantization = std::make_unique(); - - for (int i = 0; i < litert_quantization.num_channels; ++i) { - tfl_quantization->scale.push_back(litert_quantization.scales[i]); - tfl_quantization->zero_point.push_back(litert_quantization.zero_points[i]); - } - tfl_quantization->quantized_dimension = - litert_quantization.quantized_dimension; - return tfl_quantization; -} - -} // namespace - -Expected MapTensorType(const TensorType& litert_tensor_type) { - switch (litert_tensor_type.first) { - case kLiteRtRankedTensorType: - return MapTensorTypeDetail(litert_tensor_type.second.ranked_tensor_type); - default: - return Error(kLiteRtStatusErrorUnsupported); - } -} - -Expected MapQuantization( - const Quantization& litert_quantization) { - switch 
(litert_quantization.first) { - case kLiteRtQuantizationNone: - return TflQuantizationPtr(nullptr); - case kLiteRtQuantizationPerTensor: - return MapQuantizationDetail(litert_quantization.second.per_tensor); - case kLiteRtQuantizationPerChannel: - return MapQuantizationDetail(litert_quantization.second.per_channel); - default: - return Error(kLiteRtStatusErrorUnsupported); - } -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h deleted file mode 100644 index 4fbe51bf9d3a0b..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h +++ /dev/null @@ -1,32 +0,0 @@ - -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { - -Expected MapTensorType(const TensorType& litert_tensor_type); - -Expected MapQuantization( - const Quantization& litert_quantization); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc deleted file mode 100644 index 3f5c8fdf101fa1..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc +++ /dev/null @@ -1,108 +0,0 @@ - -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h" - -#include -#include -#include - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAreArray; - -TEST(LiteRtToFlatbufferTest, MapNoQuantization) { - Quantization q; - auto tfl_q = MapQuantization(q); - ASSERT_TRUE(tfl_q); - EXPECT_EQ(tfl_q.Value(), nullptr); -} - -TEST(LiteRtToFlatbufferTest, MapPerTensorQuantization) { - static constexpr float kScale = 1.0; - static constexpr int64_t kZp = 2; - - Quantization q; - q.first = kLiteRtQuantizationPerTensor; - q.second.per_tensor.scale = kScale; - q.second.per_tensor.zero_point = kZp; - - auto tfl_q = MapQuantization(q); - ASSERT_TRUE(tfl_q); - EXPECT_THAT(tfl_q->get()->scale, ElementsAreArray({kScale})); - EXPECT_THAT(tfl_q->get()->zero_point, ElementsAreArray({kZp})); -} - -TEST(LiteRtToFlatbufferTest, MapPerChannelQuantization) { - static constexpr size_t kRank = 2; - static constexpr size_t kQuantizedDimension = 1; - static constexpr float kScales[kRank] = {1.0, 2.0}; - static constexpr int64_t kZps[kRank] = {2, 3}; - - Quantization q; - q.first = kLiteRtQuantizationPerChannel; - q.second.per_channel.scales = const_cast(kScales); - q.second.per_channel.zero_points = const_cast(kZps); - q.second.per_channel.num_channels = kRank; - q.second.per_channel.quantized_dimension = kQuantizedDimension; - - auto tfl_q = MapQuantization(q); - ASSERT_TRUE(tfl_q); - EXPECT_THAT(tfl_q->get()->scale, ElementsAreArray(kScales)); - EXPECT_THAT(tfl_q->get()->zero_point, ElementsAreArray(kZps)); -} - -TEST(LiteRtToFlatbufferTest, MapDynamicTensorType) { - static constexpr int32_t kDims[] = {-1, 2}; - - TensorType t; - t.first = kLiteRtRankedTensorType; - t.second.ranked_tensor_type.element_type = kLiteRtElementTypeFloat32; - t.second.ranked_tensor_type.layout = BuildLayout(kDims); - - auto tfl_t = MapTensorType(t); - ASSERT_TRUE(tfl_t); - EXPECT_EQ(tfl_t->first, TflElementType::TensorType_FLOAT32); - EXPECT_TRUE(tfl_t->second.has_rank); - EXPECT_THAT(tfl_t->second.shape, ElementsAreArray({1, 2})); - EXPECT_THAT(tfl_t->second.shape_signature, ElementsAreArray(kDims)); -} - -TEST(LiteRtToFlatbufferTest, MapStaticTensorType) { - static constexpr int32_t kDims[] = {2, 2}; - - TensorType t; - t.first = kLiteRtRankedTensorType; - t.second.ranked_tensor_type.element_type = kLiteRtElementTypeFloat32; - t.second.ranked_tensor_type.layout = BuildLayout(kDims); - - auto tfl_t = MapTensorType(t); - ASSERT_TRUE(tfl_t); - EXPECT_EQ(tfl_t->first, TflElementType::TensorType_FLOAT32); - EXPECT_TRUE(tfl_t->second.has_rank); - EXPECT_THAT(tfl_t->second.shape, ElementsAreArray({2, 2})); - EXPECT_TRUE(tfl_t->second.shape_signature.empty()); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model.cc b/tensorflow/lite/experimental/litert/core/model/model.cc deleted file mode 100644 index 552e3f1d5ed96a..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model.cc +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" - -using ::litert::BufferRef; -using ::litert::internal::TflBuffer; -using ::litert::internal::TflBufferPtr; -using ::litert::internal::TflOpCode; -using ::litert::internal::TflOpCodePtr; -using ::litert::internal::TflOptions; -using ::litert::internal::TflOptions2; - -std::optional GetBuildStamp( - const LiteRtModelT& model) { - using ::litert::internal::kLiteRtBuildStampKey; - using ::litert::internal::ParseBuildStamp; - - auto stamp_meta = model.FindMetadata(kLiteRtBuildStampKey); - if (!stamp_meta) { - return std::nullopt; - } - auto parsed_stamp = ParseBuildStamp(*stamp_meta); - if (!parsed_stamp) { - return std::nullopt; - } - auto [soc_manufacturer, soc_model] = *parsed_stamp; - return LiteRtModelT::BuildStamp{soc_manufacturer, soc_model}; -} - -bool IsCompiled(const LiteRtModelT& model) { - return GetBuildStamp(model).has_value(); -} - -std::optional GetCustomOpCode(const LiteRtModelT& model, - const LiteRtOpT& op) { - if (op.OpCode() != kLiteRtOpCodeTflCustom) { - return {}; - } - const auto& tfl_op_codes = litert::internal::GetTflOpCodes(model); - const auto tfl_op_code_ind = litert::internal::GetTflOpCodeInd(op); - return tfl_op_codes[tfl_op_code_ind]->custom_code; -} - -TensorType MakeRankedTensorType(LiteRtElementType element_type, - absl::Span dims) { - TensorType tensor_type; - tensor_type.first = kLiteRtRankedTensorType; - auto& ranked = tensor_type.second.ranked_tensor_type; - ranked.element_type = element_type; - ABSL_DCHECK_LE(dims.size(), LITERT_TENSOR_MAX_RANK); - ranked.layout.rank = dims.size(); - std::copy(dims.begin(), dims.end(), ranked.layout.dimensions); - // Strides not yet supported. 
- ranked.layout.strides = nullptr; - return tensor_type; -} - -Quantization MakePerTensorQuantization(float scale, int64_t zero_point) { - Quantization quantization; - quantization.first = kLiteRtQuantizationPerTensor; - quantization.second.per_tensor.scale = scale; - quantization.second.per_tensor.zero_point = zero_point; - return quantization; -} - -LiteRtSignatureT MakeDefaultSignature(LiteRtSubgraph subgraph) { - auto tensor_name = [](auto* tensor) { return std::string(tensor->Name()); }; - - auto in_start = subgraph->Inputs().cbegin(); - auto in_end = subgraph->Inputs().cend(); - std::vector input_names(subgraph->NumInputs()); - std::transform(in_start, in_end, input_names.begin(), tensor_name); - - auto out_start = subgraph->Outputs().cbegin(); - auto out_end = subgraph->Outputs().cend(); - std::vector output_names(subgraph->NumOutputs()); - std::transform(out_start, out_end, output_names.begin(), tensor_name); - - std::string name(LiteRtSignatureT::kDefaultSignatureKey); - return LiteRtSignatureT(subgraph, std::move(input_names), - std::move(output_names), std::move(name)); -} - -::litert::Expected LookupSubgraph( - const LiteRtModelT& model, absl::string_view signature_key) { - auto sig = model.FindSignature(signature_key); - if (!sig) { - return sig.Error(); - } - return &sig->get().GetSubgraph(); -} - -void LiteRtModelT::TransferSubgraphTo(LiteRtSubgraphT::Alloc& dest, - std::vector indices) { - if (indices.empty()) { - return; - } - std::sort(indices.begin(), indices.end()); - std::vector new_inds(subgraphs_.Size(), 0); - auto num_removed = 0; - auto i = indices.begin(); - for (size_t j = 0; j < new_inds.size(); ++j) { - if (i != indices.end() && *i == j) { - ++num_removed; - // Keep track of removed sgs just for dcheck. - new_inds[j] = -1; - ++i; - continue; - } - new_inds[j] = j - num_removed; - } - - ForEachIr( - this, [&](LiteRtSubgraph subgraph, int32_t subgraph_index, LiteRtOp op) { - if (op->OpCode() != kLiteRtOpCodeShloComposite) { - return; - } - auto opts = litert::internal::TakeTflOptions2(*op); - auto& decomp_ind = - opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index; - const auto new_ind = new_inds[decomp_ind]; - - // This op is either in a removed subgraph or refers to a subgraph that - // is not being removed. 
- ABSL_DCHECK((subgraph_index == -1) || (new_ind >= 0)); - - decomp_ind = new_ind; - litert::internal::SetTflOptions2(*op, std::move(opts)); - }); - - subgraphs_.TransferTo(dest, std::move(indices)); -} - -namespace litert::internal { - -void SetTflOpCodeInd(LiteRtOpT& litert_op, int32_t tfl_op_code_ind) { - litert_op.tfl_op_code_ind_ = tfl_op_code_ind; -} - -int32_t GetTflOpCodeInd(const LiteRtOpT& litert_op) { - return litert_op.tfl_op_code_ind_; -} - -const TflOptions& GetTflOptions(const LiteRtOpT& litert_op) { - return litert_op.tfl_option_; -} - -const TflOptions2& GetTflOptions2(const LiteRtOpT& litert_op) { - return litert_op.tfl_option_2_; -} - -TflOptions&& TakeTflOptions(LiteRtOpT& litert_op) { - return std::move(litert_op.tfl_option_); -} - -TflOptions2&& TakeTflOptions2(LiteRtOpT& litert_op) { - return std::move(litert_op.tfl_option_2_); -} - -const std::vector& GetTflOpCodes( - const LiteRtModelT& litert_model) { - return litert_model.tfl_operator_codes_; -} - -std::vector&& TakeTflOpCodes(LiteRtModelT& litert_model) { - return std::move(litert_model.tfl_operator_codes_); -} - -// new stuff start -void SetTflFlatbuffer(LiteRtModelT& litert_model, - LiteRtModelT::TflFlatbuffer&& tfl_flatbuffer) { - litert_model.tfl_flatbuffer_ = std::move(tfl_flatbuffer); -} - -const LiteRtModelT::TflFlatbuffer& GetTflFlatbuffer( - const LiteRtModelT& litert_model) { - return litert_model.tfl_flatbuffer_; -} -// new stuff end - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model.h b/tensorflow/lite/experimental/litert/core/model/model.h deleted file mode 100644 index 5d4bcfacb0a380..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model.h +++ /dev/null @@ -1,1024 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" // IWYU pragma: export -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -//////////////////////////////////////////////////////////////////////////////// -// Internal LiteRtIR -// -// These are the backing definitions for the opaque types in the c api -// (c/litert_model.h). 
-//
-// < STORAGE DETAIL >
-//
-// Unless deleted as a result of c api client calls, the lifetime of all "IR
-// Objects" (definitions of opaque types) is designed to be transitively owned
-// by the LiteRtModelT, which is generally the longest-living object. See
-// various "Emplace" methods.
-//
-// Since c api clients interface with pointers to IR Objects, a form of pointer
-// stability is desirable. Classes in this file enforce that pointers to IR
-// Objects are valid for their entire lifetime. Thus a c api client may store
-// pointers and depend on referential equality of IR Objects throughout
-// different calls. This also facilitates storing edge/parent-references as
-// pointers within IR Objects.
-//
-// Direct copying is generally not allowed for IR Objects since copying
-// instances of mutually recursive types is not entirely well-defined.
-//
-// IR Objects are generally default constructible to facilitate stable storage
-// and iterative construction.
-//
-// < EXPOSING TFLITE SCHEMA >
-//
-// Direct access to tflite schema types is limited to the "detail" namespace.
-// This indicates that encapsulating all the details of the flatbuffer is a
-// WIP. Future implementations may use different data forms (new litert
-// serialized format, tflite runtime types etc).
-//
-// < USAGE NOTE >
-//
-// The classes here contain only simple getters & setters. Care should be taken
-// to leave the IR in a valid state when using setters since the graph is
-// doubly-linked. Higher-level functionality for correct graph mutation can be
-// found in "model_graph.h".
-////////////////////////////////////////////////////////////////////////////////
-
-// All tflite schema type usage.
-namespace litert::internal {
-
-// OP
-
-// Placeholder for the ind of the dispatch op code added during serialization.
-static constexpr auto kDispatchOpCodeTflInd = -1;
-
-void SetTflOpCodeInd(LiteRtOpT& litert_op, int32_t tfl_op_code_ind);
-
-int32_t GetTflOpCodeInd(const LiteRtOpT& litert_op);
-
-template <class Arg>
-void SetTflOptions(LiteRtOpT& litert_op, Arg&& arg);
-
-template <class Arg>
-void SetTflOptions2(LiteRtOpT& litert_op, Arg&& arg);
-
-const ::litert::internal::TflOptions& GetTflOptions(const LiteRtOpT& litert_op);
-
-const ::litert::internal::TflOptions2& GetTflOptions2(
-    const LiteRtOpT& litert_op);
-
-::litert::internal::TflOptions&& TakeTflOptions(LiteRtOpT& litert_op);
-
-::litert::internal::TflOptions2&& TakeTflOptions2(LiteRtOpT& litert_op);
-
-void ClearTflOptions(LiteRtOpT& litert_op);
-
-// MODEL
-
-const std::vector<::litert::internal::TflOpCodePtr>& GetTflOpCodes(
-    const LiteRtModelT& litert_model);
-
-template <class Arg>
-void SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg);
-
-std::vector<::litert::internal::TflOpCodePtr>&& TakeTflOpCodes(
-    LiteRtModelT& litert_model);
-
-void SetTflFlatbuffer(LiteRtModelT& litert_model,
-                      ::litert::internal::FlatbufferWrapper&& tfl_flatbuffer);
-
-const ::litert::internal::FlatbufferWrapper& GetTflFlatbuffer(
-    const LiteRtModelT& litert_model);
-
-}  // namespace litert::internal
-
-//
-// Helpers for conceptual unions from C api.
-//
-
-// For requesting opaque data stored within IR.
-using ScratchBufferProvider = std::function<uint8_t*(size_t size)>;
-
-// TENSOR TYPE
-
-// Detail convenience type for tensor type union.
-typedef union {
-  LiteRtUnrankedTensorType unranked_tensor_type;
-  LiteRtRankedTensorType ranked_tensor_type;
-} TensorTypeDetail;
-
-// Union and identifier for tensor types.
-using TensorType = std::pair<LiteRtTensorTypeId, TensorTypeDetail>;
-
-// Construct tensor type union as ranked tensor.
NOTE: Copies data in `dims`. -TensorType MakeRankedTensorType(LiteRtElementType element_type, - absl::Span dims); - -// QUANTIZATION TYPE - -// Detail convenience type for quantization type union. -typedef union { - LiteRtQuantizationPerTensor per_tensor; - LiteRtQuantizationPerChannel per_channel; -} QuantizationDetail; - -// Union and identifier for quantization types. -using Quantization = std::pair; - -// Make default type with quantization info. -inline Quantization MakeEmptyQuantization() { - return Quantization(kLiteRtQuantizationNone, QuantizationDetail()); -} - -// Construct quantization type as per tensor. -Quantization MakePerTensorQuantization(float scale, int64_t zero_point); - -// Construct quantization type as per channel, requires buffer callback to -// store data. -template -Quantization MakePerChannelQuantization(const Scales& scales, - const ZeroPoints& zero_points, - int32_t quantized_dim, - ScratchBufferProvider buffer_provider) { - const auto size = std::size(scales); - ABSL_DCHECK_EQ(size, std::size(zero_points)); - - Quantization res; - res.first = kLiteRtQuantizationPerChannel; - - res.second.per_channel.num_channels = size; - res.second.per_channel.quantized_dimension = quantized_dim; - - const size_t scales_buf_size = size * sizeof(float); - const size_t zeros_buf_size = size * sizeof(int64_t); - auto* scales_buf = reinterpret_cast(buffer_provider(scales_buf_size)); - auto* zeros_buf = reinterpret_cast(buffer_provider(zeros_buf_size)); - std::copy(std::cbegin(scales), std::cend(scales), scales_buf); - std::copy(std::cbegin(zero_points), std::cend(zero_points), zeros_buf); - - res.second.per_channel.scales = scales_buf; - res.second.per_channel.zero_points = zeros_buf; - - return res; -} - -// -// Tensor -// - -// Constant data associated with a tensor. -class LiteRtWeightsT { - private: - using OwnedBuffer = ::litert::OwningBufferRef; - - public: - using BufferId = ::litert::internal::BufferManager::BufferId; - using BufferManager = ::litert::internal::BufferManager; - - // Underlying data. - ::litert::BufferRef Buffer() const { - auto buf = GetBufferManager()->GetBuffer(buffer_id_); - ABSL_DCHECK(buf.HasValue()); - return *buf; - } - - // Set the buffer manager, expects a stable pointer. A default buffer manager - // will be initialized for convenience but most cases will share a single - // buffer manager owned by the model. - void SetBufferManager(BufferManager* buffer_manager) { - buffer_manager_ = buffer_manager; - } - - // Get the underlying buffer manager. - BufferManager* GetBufferManager() const { - if (std::holds_alternative(buffer_manager_)) { - return std::get(buffer_manager_); - } else { - return std::get(buffer_manager_).get(); - } - } - - // Set from a pre-registered buffer. This expects buffer was registered - // with the same manager. - void SetBufferId(BufferId buffer_id) { buffer_id_ = buffer_id; } - - // Get the id generated for the buffer by the manager. - BufferId GetBufferId() const { return buffer_id_; } - - // IR is generally, default constructible and movable but not copyable. 
-  LiteRtWeightsT() = default;
-  explicit LiteRtWeightsT(BufferManager* buffer_manager)
-      : buffer_manager_(buffer_manager) {}
-  LiteRtWeightsT(const LiteRtWeightsT&) = delete;
-  LiteRtWeightsT(LiteRtWeightsT&&) = default;
-  LiteRtWeightsT& operator=(const LiteRtWeightsT&) = delete;
-  LiteRtWeightsT& operator=(LiteRtWeightsT&&) = default;
-
- private:
-  BufferId buffer_id_ = BufferManager::kEmptyBufferId;
-  std::variant<BufferManager::Ptr, BufferManager*> buffer_manager_ =
-      std::make_unique<BufferManager>();
-};
-
-// Set weights via an unowned buffer. Caller is responsible for ensuring the
-// buffer outlives the weights. Registers the buffer with the manager.
-inline void SetWeightsFromUnownedBuffer(
-    LiteRtWeightsT& weights, ::litert::BufferRef<uint8_t> buffer,
-    std::optional context = std::nullopt) {
-  auto* manager = weights.GetBufferManager();
-  auto buf_id = manager->RegisterNonOwnedBuffer(buffer, context);
-  weights.SetBufferId(buf_id);
-}
-
-// Set weights via an owned buffer, transferring ownership of the buffer to
-// the manager. Registers the buffer with the manager.
-inline void SetWeightsFromOwnedBuffer(
-    LiteRtWeightsT& weights, ::litert::OwningBufferRef<uint8_t>&& buffer,
-    std::optional context = std::nullopt) {
-  auto* manager = weights.GetBufferManager();
-  auto buf_id = manager->RegisterOwnedBuffer(std::move(buffer), context);
-  weights.SetBufferId(buf_id);
-}
-
-// Fundamental value in a litert program, "edges" in the graph.
-class LiteRtTensorT {
- private:
-  using UserData = std::unique_ptr<uint8_t[]>;
-
- public:
-  using Ref = std::reference_wrapper<LiteRtTensorT>;
-  using Use = std::pair<LiteRtOp, LiteRtParamIndex>;
-  using UseVec = std::vector<Use>;
-  using Alloc = ::litert::internal::IrAllocator<LiteRtTensorT>;
-
-  // The ops that take this tensor as input.
-  const std::vector<LiteRtOp>& Users() const { return users_; }
-  std::vector<LiteRtOp>& Users() { return users_; }
-
-  // Which operand index users take this tensor on; respects the ordering of
-  // users.
-  const std::vector<LiteRtParamIndex>& UserArgInds() const {
-    return user_arg_inds_;
-  }
-  std::vector<LiteRtParamIndex>& UserArgInds() { return user_arg_inds_; }
-
-  // Number of uses, same as number of user arg inds.
-  size_t NumUses() const { return users_.size(); }
-
-  // Get the ith use.
-  Use GetUse(size_t ind) const {
-    return {users_.at(ind), user_arg_inds_.at(ind)};
-  }
-
-  // Remove the use at the given index.
-  void RemoveUse(size_t ind) {
-    users_.erase(users_.begin() + ind);
-    user_arg_inds_.erase(user_arg_inds_.begin() + ind);
-  }
-
-  // Get the op that outputs this tensor, null if constant or subgraph input.
-  LiteRtOp DefiningOp() const { return defining_op_; }
-
-  // Get the output index of the op that defines this tensor, only meaningful
-  // if it has a defining op.
-  LiteRtParamIndex DefiningOpOutInd() const { return defining_op_out_ind_; }
-
-  // Update the defining op of this tensor. The caller is required to update
-  // the given op's output if not already correct.
-  void SetDefiningOp(LiteRtOpT& defining_op, LiteRtParamIndex out_ind) {
-    defining_op_ = &defining_op;
-    defining_op_out_ind_ = out_ind;
-  }
-
-  // Set the defining op to none.
-  void ClearDefiningOp() {
-    defining_op_ = nullptr;
-    defining_op_out_ind_ = 0;
-  }
-
-  // Any constant data associated with this tensor.
-  const LiteRtWeightsT& Weights() const { return weights_; }
-  LiteRtWeightsT& Weights() { return weights_; }
-
-  // Authored name associated with this tensor. May be empty.
-  absl::string_view Name() const { return name_; }
-
-  // Update the name associated with this tensor.
- void SetName(std::string name) { name_ = std::move(name); } - - // Get quantization information for this tensor. - const Quantization& Qparams() const { return quantization_; } - Quantization& Qparams() { return quantization_; } - - // Set quantization information. - template - void SetQarams(Arg&& arg) { - quantization_ = std::forward(arg); - } - - // Get the tensor type of this tensor. - const TensorType& Type() const { return tensor_type_; } - TensorType& Type() { return tensor_type_; } - - // Set the tensor type. - template - void SetType(Arg&& arg) { - tensor_type_ = std::forward(arg); - } - - // Get a new buffer that will live as long as this tensor. Used for storing - // various buffers passed through c-api (dims, quantization etc). - // NOTE: This is just scratch data unrelated to weights buffer. - uint8_t* RequestScratchBuffer(size_t size) { - user_data_.push_back(std::make_unique(size)); - return user_data_.back().get(); - } - - // Allow for implicit conversion to scratch buffer provider. - // NOTE: This is just scratch data unrelated to weights buffer. - // NOLINTNEXTLINE - operator ScratchBufferProvider() & { - return [this](auto s) { return this->RequestScratchBuffer(s); }; - } - - // IR is generally, default constructible and movable but not copyable. - LiteRtTensorT() = default; - LiteRtTensorT(::litert::internal::BufferManager* buffer_manager) - : weights_(buffer_manager) {} - LiteRtTensorT(const LiteRtTensorT&) = delete; - LiteRtTensorT(LiteRtTensorT&&) = default; - LiteRtTensorT& operator=(const LiteRtTensorT&) = delete; - LiteRtTensorT& operator=(LiteRtTensorT&&) = default; - - private: - std::vector users_; - std::vector user_arg_inds_; - - LiteRtOp defining_op_ = nullptr; - LiteRtParamIndex defining_op_out_ind_; - - LiteRtWeightsT weights_; - Quantization quantization_; - TensorType tensor_type_; - - std::string name_; - - std::vector user_data_; -}; - -// Helper to get multiple uses at once. -template -LiteRtTensorT::UseVec GetTensorUses(const LiteRtTensorT& tensor, - const Inds& inds) { - auto start = std::cbegin(inds); - auto end = std::cend(inds); - LiteRtTensorT::UseVec uses(end - start); - auto get = [&tensor = std::as_const(tensor)](auto i) { - return tensor.GetUse(i); - }; - std::transform(start, end, uses.begin(), get); - return uses; -} - -// -// Op -// - -// Fundamental unit of compute of a litert program, or "nodes" in the graph. -class LiteRtOpT { - public: - using Ref = std::reference_wrapper; - using Alloc = ::litert::internal::IrAllocator; - - // Input tensors for this op. - const std::vector& Inputs() const { return inputs_; } - std::vector& Inputs() { return inputs_; } - - // Access input at given ind. - LiteRtTensorT& Input(size_t ind) { return *Inputs().at(ind); } - const LiteRtTensorT& Input(size_t ind) const { return *Inputs().at(ind); } - - // Number of input tensors. - size_t NumInputs() const { return inputs_.size(); } - - // Output tensors for this op. - const std::vector& Outputs() const { return outputs_; } - std::vector& Outputs() { return outputs_; } - - // Number of output tensors. - size_t NumOutputs() const { return outputs_.size(); } - - // Access output at given ind. - LiteRtTensorT& Output(size_t ind) { return *Outputs().at(ind); } - const LiteRtTensorT& Output(size_t ind) const { return *Outputs().at(ind); } - - // Remove the ith entry of input list. - void RemoveInput(size_t ind) { inputs_.erase(inputs_.begin() + ind); } - - // Remove the ith entry of output list. 
-  void RemoveOutput(size_t ind) { outputs_.erase(outputs_.begin() + ind); }
-
-  // Get any custom options attached to this op. Empty if there are none.
-  litert::BufferRef<uint8_t> CustomOptions() const { return custom_options_; }
-
-  // Attach custom opaque options to this op.
-  template <class... Args>
-  void SetCustomOptions(Args&&... args) {
-    custom_options_ =
-        ::litert::OwningBufferRef<uint8_t>(std::forward<Args>(args)...);
-  }
-
-  // Sets the custom options to zero length buffer.
-  void ClearCustomOptions() { custom_options_.Reset(); }
-
-  // Get the op code.
-  LiteRtOpCode OpCode() const { return litert_op_code_; }
-
-  // Set the op code.
-  void SetOpCode(LiteRtOpCode litert_op_code) {
-    litert_op_code_ = litert_op_code;
-  }
-
-  // IR is generally default constructible and movable, but not copyable.
-  LiteRtOpT() = default;
-  LiteRtOpT(const LiteRtOpT&) = delete;
-  LiteRtOpT(LiteRtOpT&&) = default;
-  LiteRtOpT& operator=(const LiteRtOpT&) = delete;
-  LiteRtOpT& operator=(LiteRtOpT&&) = default;
-
-  // Friendship for internal tflite details.
-  friend void litert::internal::SetTflOpCodeInd(LiteRtOpT& litert_op,
-                                                int32_t tfl_op_code_ind);
-
-  friend int32_t litert::internal::GetTflOpCodeInd(const LiteRtOpT& litert_op);
-
-  template <class Arg>
-  friend void litert::internal::SetTflOptions(LiteRtOpT& litert_op, Arg&& arg);
-
-  template <class Arg>
-  friend void litert::internal::SetTflOptions2(LiteRtOpT& litert_op, Arg&& arg);
-
-  friend const ::litert::internal::TflOptions& litert::internal::GetTflOptions(
-      const LiteRtOpT& litert_op);
-
-  friend const ::litert::internal::TflOptions2&
-  litert::internal::GetTflOptions2(const LiteRtOpT& litert_op);
-
-  friend ::litert::internal::TflOptions&& litert::internal::TakeTflOptions(
-      LiteRtOpT& litert_op);
-
-  friend ::litert::internal::TflOptions2&& litert::internal::TakeTflOptions2(
-      LiteRtOpT& litert_op);
-
-  friend void litert::internal::ClearTflOptions(LiteRtOpT& litert_op);
-
- private:
-  LiteRtOpCode litert_op_code_;
-
-  ::litert::OwningBufferRef<uint8_t> custom_options_;
-
-  std::vector<LiteRtTensor> inputs_;
-  std::vector<LiteRtTensor> outputs_;
-
-  // TFLITE
-  int32_t tfl_op_code_ind_ = litert::internal::kDispatchOpCodeTflInd;
-  ::litert::internal::TflOptions tfl_option_;
-  ::litert::internal::TflOptions2 tfl_option_2_;
-};
-
-// Clears any attribute data and sets the op to be a dispatch op.
-inline void MakeDispatchOp(LiteRtOpT& op) {
-  litert::internal::ClearTflOptions(op);
-  op.ClearCustomOptions();
-  op.SetOpCode(kLiteRtOpCodeTflCustom);
-  litert::internal::SetTflOpCodeInd(op,
-                                    litert::internal::kDispatchOpCodeTflInd);
-}
-
-//
-// Subgraph
-//
-
-// Fundamental block of a litert program. Manages the storage of all
-// ops and tensors within.
-class LiteRtSubgraphT {
- public:
-  using Ref = std::reference_wrapper<LiteRtSubgraphT>;
-  using Alloc = ::litert::internal::IrAllocator<LiteRtSubgraphT>;
-
-  // Get a stable pointer for all of the tensors in this subgraph.
-  absl::Span<LiteRtTensor> Tensors() { return tensors_.Elements(); }
-  absl::Span<LiteRtTensor const> Tensors() const {
-    return tensors_.Elements();
-  }
-
-  // Access the tensor at given ind.
-  LiteRtTensorT& Tensor(size_t ind) { return *Tensors().at(ind); }
-  const LiteRtTensorT& Tensor(size_t ind) const { return *Tensors().at(ind); }
-
-  // Get a stable pointer for all of the ops in this subgraph. Will
-  // be a valid topological order.
-  absl::Span<LiteRtOp> Ops() { return ops_.Elements(); }
-  absl::Span<LiteRtOp const> Ops() const { return ops_.Elements(); }
-
-  // Access op at the given ind.
- LiteRtOpT& Op(size_t ind) { return *Ops().at(ind); } - const LiteRtOpT& Op(size_t ind) const { return *Ops().at(ind); } - - // All the subgraph input tensors, these also exist in Tensors. - const std::vector& Inputs() const { return inputs_; } - std::vector& Inputs() { return inputs_; } - - // Number of inputs tensors. - size_t NumInputs() const { return inputs_.size(); } - - // Access the subgraph input at given ind. - LiteRtTensorT& Input(size_t ind) { return *Inputs().at(ind); } - const LiteRtTensorT& Input(size_t ind) const { return *Inputs().at(ind); } - - // All the subgraph output tensors, these also exist in Tensors. - const std::vector& Outputs() const { return outputs_; } - std::vector& Outputs() { return outputs_; } - - // Number of outputs tensors. - size_t NumOutputs() const { return outputs_.size(); } - - // Access the subgraph output at given ind. - LiteRtTensorT& Output(size_t ind) { return *Outputs().at(ind); } - const LiteRtTensorT& Output(size_t ind) const { return *Outputs().at(ind); } - - // Clear the entry for the ith input. - void ClearInput(size_t ind) { inputs_.erase(inputs_.begin() + ind); } - - // Clear the entry for the ith output. - void ClearOutput(size_t ind) { outputs_.erase(outputs_.begin() + ind); } - - // Construct a new tensor which will be owned by this subgraph and get a - // reference to it. - template - LiteRtTensorT& EmplaceTensor(Args&&... args) { - if (buffer_manager_ == nullptr) { - return tensors_.EmplaceBack(std::forward(args)...); - } else { - // std::cerr << "Emplacing tensor with buffer manager \n"; - return tensors_.EmplaceBack(buffer_manager_, std::forward(args)...); - } - } - - // Construct a new op which will be owned by this subgraph and get a - // reference to it. - template - LiteRtOpT& EmplaceOp(Args&&... args) { - return ops_.EmplaceBack(std::forward(args)...); - } - - // De-allocates ops that pass given predicate. Returns number of ops removed. - size_t RemoveOpIf(std::function pred) { - return ops_.RemoveIf(pred); - } - - // De-allocates tensors that pass given predicate. Returns number of tensors - // removed. - size_t RemoveTensorIf(std::function pred) { - return tensors_.RemoveIf(pred); - } - - // IR is generally, default constructible and movable but not copyable. - LiteRtSubgraphT() = default; - LiteRtSubgraphT(::litert::internal::BufferManager* buffer_manager) - : buffer_manager_(buffer_manager) {}; - LiteRtSubgraphT(const LiteRtSubgraphT&) = delete; - LiteRtSubgraphT(LiteRtSubgraphT&&) = default; - LiteRtSubgraphT& operator=(const LiteRtSubgraphT&) = delete; - LiteRtSubgraphT& operator=(LiteRtSubgraphT&&) = default; - - // Get the buffer manager for this subgraph. - ::litert::internal::BufferManager* GetBufferManager() const { - return buffer_manager_; - } - - private: - // If null, tensors emplaced will own their own buffer managers. 
-  ::litert::internal::BufferManager* buffer_manager_ = nullptr;
-
-  LiteRtTensorT::Alloc tensors_;
-
-  LiteRtOpT::Alloc ops_;
-
-  std::vector<LiteRtTensor> inputs_;
-  std::vector<LiteRtTensor> outputs_;
-};
-
-//
-// Signature
-//
-
-class LiteRtSignatureT {
- private:
-  using StrVec = std::vector<std::string>;
-
- public:
-  using Ptr = std::unique_ptr<LiteRtSignatureT>;
-  using Ref = std::reference_wrapper<LiteRtSignatureT>;
-  using Alloc = ::litert::internal::IrAllocator<LiteRtSignatureT>;
-
-  static constexpr absl::string_view kDefaultSignatureKey =
-      "<placeholder signature>";
-
-  LiteRtSignatureT(LiteRtSubgraph subgraph, StrVec input_names,
-                   StrVec output_names, std::string key)
-      : key_(std::move(key)),
-        subgraph_(subgraph),
-        input_names_(std::move(input_names)),
-        output_names_(std::move(output_names)) {}
-
-  // String named inputs for called subgraph.
-  const StrVec& InputNames() const { return input_names_; }
-
-  // String named outputs for called subgraph.
-  const StrVec& OutputNames() const { return output_names_; }
-
-  // Get the callable subgraph.
-  const LiteRtSubgraphT& GetSubgraph() const { return *subgraph_; }
-  LiteRtSubgraphT& GetSubgraph() { return *subgraph_; }
-
-  // Name of the callable signature.
-  absl::string_view Key() const { return key_; }
-
-  bool operator==(const LiteRtSignatureT& other) const {
-    const auto key_eq = key_ == other.key_;
-    const auto subgraph_eq = subgraph_ == other.subgraph_;
-    const auto input_names_eq = input_names_ == other.input_names_;
-    const auto output_names_eq = output_names_ == other.output_names_;
-    return key_eq && subgraph_eq && input_names_eq && output_names_eq;
-  }
-
-  // IR is generally default constructible and movable, but not copyable.
-  LiteRtSignatureT() = default;
-  LiteRtSignatureT(const LiteRtSignatureT&) = delete;
-  LiteRtSignatureT(LiteRtSignatureT&&) = default;
-  LiteRtSignatureT& operator=(const LiteRtSignatureT&) = delete;
-  LiteRtSignatureT& operator=(LiteRtSignatureT&&) = default;
-
- private:
-  std::string key_;
-
-  LiteRtSubgraph subgraph_;
-
-  StrVec input_names_;
-  StrVec output_names_;
-};
-
-// Make a basic signature from information in the given subgraph. Used with the
-// main subgraph when no explicit signatures have been authored.
-LiteRtSignatureT MakeDefaultSignature(LiteRtSubgraph subgraph);
-
-//
-// Model
-//
-
-// Root-level graph object for litert programs. Manages the storage
-// of all litert graph objects within.
-class LiteRtModelT {
- public:
-  using Ref = std::reference_wrapper<LiteRtModelT>;
-  using Ptr = std::unique_ptr<LiteRtModelT>;
-  using TflOpCodes = std::vector<::litert::internal::TflOpCodePtr>;
-
-  using BufferManager = ::litert::internal::BufferManager;
-  using BufferId = BufferManager::BufferId;
-
-  using OpAssetReference = std::pair<BufferId, std::string>;
-  using OpAssetMap = absl::flat_hash_map<LiteRtOp, OpAssetReference>;
-
-  using MetadataMap = absl::flat_hash_map<std::string, BufferId>;
-
-  using TflFlatbuffer = ::litert::internal::FlatbufferWrapper;
-
-  // TODO: Replace this with the index of the default signature.
-  static constexpr const size_t kMainSubgraphIndex = 0;
-
-  // SUBGRAPHS
-
-  // Get a stable pointer for all of the subgraphs within this model.
-  absl::Span<LiteRtSubgraph> Subgraphs() { return subgraphs_.Elements(); }
-  absl::Span<LiteRtSubgraph const> Subgraphs() const {
-    return subgraphs_.Elements();
-  }
-
-  // Access subgraph at given ind.
-  LiteRtSubgraphT& Subgraph(size_t ind) { return *Subgraphs().at(ind); }
-  const LiteRtSubgraphT& Subgraph(size_t ind) const {
-    return *Subgraphs().at(ind);
-  }
-
-  // Number of subgraphs.
-  size_t NumSubgraphs() const { return subgraphs_.Elements().size(); }
-
-  // Default entry point of this model.
-  const LiteRtSubgraphT* MainSubgraph() const {
-    return &Subgraph(kMainSubgraphIndex);
-  }
-  LiteRtSubgraph MainSubgraph() { return &Subgraph(kMainSubgraphIndex); }
-
-  // Look up signature by key.
-  litert::Expected<LiteRtSignatureT::Ref> FindSignature(
-      absl::string_view signature_key) const {
-    for (LiteRtSignature sig : signatures_.Elements()) {
-      if (sig->Key() == signature_key) {
-        return std::ref(*sig);
-      }
-    }
-    return ::litert::Error(kLiteRtStatusErrorNotFound, "Signature not found");
-  }
-
-  // Build a new subgraph and get a stable reference to it.
-  template <class... Args>
-  LiteRtSubgraphT& EmplaceSubgraph(Args&&... args) {
-    return subgraphs_.EmplaceBack(Buffers(), std::forward<Args>(args)...);
-  }
-
-  // Transfers given subgraphs into this model. New subgraphs are appended.
-  void TransferSubgraphsFrom(LiteRtSubgraphT::Alloc&& subgraphs) {
-    // TODO: Consider merging buffer managers here.
-    subgraphs_.TransferFrom(std::move(subgraphs));
-  }
-
-  // Cut all but the first `size` subgraphs. Does nothing if the given size is
-  // greater than or equal to the current number of subgraphs.
-  void ResizeSubgraphsDown(size_t size) { subgraphs_.ResizeDown(size); }
-
-  // Transfers the subgraph at the given index to the back of the given
-  // allocator. Also updates any IR owned by the model that refers to subgraphs
-  // by index (e.g. composites). Does not update any IR in the subgraphs being
-  // transferred.
-  void TransferSubgraphTo(LiteRtSubgraphT::Alloc& dest,
-                          std::vector<size_t> indices);
-
-  // SIGNATURES
-
-  // All signatures registered with this model.
-  absl::Span<LiteRtSignature> Signatures() const {
-    return signatures_.Elements();
-  }
-
-  // Construct a new signature for this model.
-  template <class... Args>
-  LiteRtSignatureT& EmplaceSignature(Args&&... args) {
-    return signatures_.EmplaceBack(std::forward<Args>(args)...);
-  }
-
-  // METADATA
-
-  // Look up metadata by key, getting a view of its buffer as a string
-  // if it exists.
-  litert::Expected<::litert::BufferRef<uint8_t>> FindMetadata(
-      absl::string_view key) const {
-    if (auto it = metadata_.find(key); it != metadata_.end()) {
-      const auto buf_id = it->second;
-      return Buffers()->GetBuffer(buf_id);
-    }
-    return ::litert::Error(kLiteRtStatusErrorNotFound);
-  }
-
-  // Metadata key-val pair iterator.
-  MetadataMap::iterator MetadataBegin() { return metadata_.begin(); }
-  MetadataMap::iterator MetadataEnd() { return metadata_.end(); }
-
-  // Adds a new metadata buffer to the model. Fails if it already exists.
-  template <class... Args>
-  LiteRtStatus PushMetadata(absl::string_view key, Args&&... args) {
-    if (metadata_.contains(key)) {
-      return kLiteRtStatusErrorInvalidArgument;
-    }
-    const auto buf_id = Buffers()->RegisterOwnedBuffer(
-        ::litert::OwningBufferRef<uint8_t>(std::forward<Args>(args)...));
-    metadata_.emplace(std::make_pair(std::string(key), buf_id));
-    return kLiteRtStatusOk;
-  }
-
-  // BUFFERS
-
-  // Get stable pointer to buffer manager object.
-  BufferManager* Buffers() const { return buffer_manager_.get(); }
-
-  // Attach an asset to the given op. An asset is a non-tensor buffer
-  // that is used by the op. Assets may be referenced by multiple ops.
-  // Each edge from an op to an asset is identified by a name. All buffers
-  // are appended to the model upon serialization and referenced by offset
-  // relative to the start of the model within the referring op's custom
-  // options.
-  void AttachAssetToOp(LiteRtOp op, BufferId buf_id, std::string name) {
-    OpAssetReference ref = {buf_id, std::move(name)};
-    external_buffer_map_.emplace(op, std::move(ref));
-  }
-
-  // Returns an immutable view of the external buffer and the name of the edge
-  // if the given op has one attached.
-  litert::Expected<OpAssetReference> FindOpAsset(LiteRtOp op) {
-    if (auto it = external_buffer_map_.find(op);
-        it != external_buffer_map_.end()) {
-      return it->second;
-    }
-    return ::litert::Error(kLiteRtStatusErrorNotFound);
-  }
-
-  // Contains details about the compiler used if this model was compiled.
-  struct BuildStamp {
-    absl::string_view soc_manufacturer;
-    absl::string_view soc_model;
-  };
-
-  // IR is generally default constructible and movable, but not copyable.
-  LiteRtModelT() = default;
-  LiteRtModelT(const LiteRtModelT&) = delete;
-  LiteRtModelT(LiteRtModelT&&) = default;
-  LiteRtModelT& operator=(const LiteRtModelT&) = delete;
-  LiteRtModelT& operator=(LiteRtModelT&&) = default;
-
-  // TFLITE
-
-  // Friendship for internal tflite details.
-  friend const TflOpCodes& litert::internal::GetTflOpCodes(
-      const LiteRtModelT& litert_model);
-
-  template <class Arg>
-  friend void litert::internal::SetTflOpCodes(LiteRtModelT& litert_model,
-                                              Arg&& arg);
-
-  friend TflOpCodes&& litert::internal::TakeTflOpCodes(
-      LiteRtModelT& litert_model);
-
-  friend void litert::internal::SetTflFlatbuffer(
-      LiteRtModelT& litert_model, TflFlatbuffer&& tfl_flatbuffer);
-
-  friend const TflFlatbuffer& litert::internal::GetTflFlatbuffer(
-      const LiteRtModelT& litert_model);
-
-  explicit LiteRtModelT(TflFlatbuffer&& tfl_flatbuffer)
-      : tfl_flatbuffer_(std::move(tfl_flatbuffer)) {}
-
- private:
-  LiteRtSubgraphT::Alloc subgraphs_;
-  LiteRtSignatureT::Alloc signatures_;
-
-  MetadataMap metadata_;
-  OpAssetMap external_buffer_map_;
-
-  // Use unique ptr here to keep the pointer stable.
-  BufferManager::Ptr buffer_manager_ = std::make_unique<BufferManager>();
-
-  // TFLITE
-  TflOpCodes tfl_operator_codes_;
-  TflFlatbuffer tfl_flatbuffer_;
-};
-
-// Get the build stamp from the model if it exists.
-// TODO: Consider a setter and internalizing all build stamp stuff behind the
-// model interface.
-std::optional<LiteRtModelT::BuildStamp> GetBuildStamp(
-    const LiteRtModelT& model);
-
-// Returns true if this model contains any ops compiled for NPU.
-bool IsCompiled(const LiteRtModelT& model);
-
-// Get the custom op code from a given op if it is a custom op.
-std::optional<std::string> GetCustomOpCode(const LiteRtModelT& model,
-                                           const LiteRtOpT& op);
-
-// Look up subgraph by signature name.
-::litert::Expected<LiteRtSubgraph> LookupSubgraph(
-    const LiteRtModelT& model, absl::string_view signature_key);
-
-namespace litert::internal {
-
-template <class Arg>
-void SetTflOptions(LiteRtOpT& litert_op, Arg&& arg) {
-  litert_op.tfl_option_ = std::forward<Arg>(arg);
-}
-
-template <class Arg>
-void SetTflOptions2(LiteRtOpT& litert_op, Arg&& arg) {
-  litert_op.tfl_option_2_ = std::forward<Arg>(arg);
-}
-
-inline void ClearTflOptions(LiteRtOpT& litert_op) {
-  litert_op.tfl_option_2_.Reset();
-  litert_op.tfl_option_.Reset();
-}
-
-template <class Arg>
-void SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg) {
-  litert_model.tfl_operator_codes_ = std::forward<Arg>(arg);
-}
-
-}  // namespace litert::internal
-
-//
-// Misc Ir Containers
-//
-
-using LiteRtOpWithPartitionIndex = std::pair<LiteRtOp, LiteRtParamIndex>;
-
-// Used for communicating selections of ops when partitioning.
-class LiteRtOpListT { - public: - void Push(LiteRtOp op, LiteRtParamIndex partition_index = 0) { - values_.push_back(LiteRtOpWithPartitionIndex(op, partition_index)); - } - - std::vector Values() const { - std::vector ops; - ops.reserve(values_.size()); - ops.assign(values_.begin(), values_.end()); - - return ops; - } - - private: - // Investigate if this is possible with vector (hit some issues). - std::list values_; -}; - -// -// Traversal Utils -// - -// Apply func to all the IR in the given model. Iteration behavior is determined -// by the callback signature. -template -void ForEachIr(LiteRtModel model, F func) { - // Per subgraph callbacks. - using SgF1 = std::function; - using SgF2 = std::function; - - // Per op callbacks. - using OpF1 = std::function; - using OpF2 = std::function; - using OpF3 = - std::function; - - constexpr bool kIsSgOpF1 = std::is_convertible_v; - constexpr bool kIsSgF2 = std::is_convertible_v; - constexpr bool kIsOpF1 = std::is_convertible_v; - constexpr bool kIsOpF2 = std::is_convertible_v; - constexpr bool kIsOpF3 = std::is_convertible_v; - - for (int i = 0; i < model->NumSubgraphs(); ++i) { - auto subgraph = model->Subgraphs()[i]; - - if constexpr (kIsSgF2) { - func(subgraph, i); - } else if constexpr (kIsSgOpF1) { - func(subgraph); - } else { - for (int j = 0; j < subgraph->Ops().size(); ++j) { - auto* op = subgraph->Ops()[j]; - if constexpr (kIsOpF1) { - func(op); - } else if constexpr (kIsOpF2) { - func(subgraph, op); - } else if constexpr (kIsOpF3) { - func(subgraph, i, op); - } - } - } - } -} - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_buffer.cc b/tensorflow/lite/experimental/litert/core/model/model_buffer.cc deleted file mode 100644 index 3353b3adbf10af..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_buffer.cc +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
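A minimal usage sketch of the ForEachIr helper from the model.h deletion above (illustrative only, not part of the patch). The callback aliases in the template indicate it dispatches on the lambda's signature, so one entry point can walk either subgraphs or ops; `model` is assumed to be a LiteRtModel obtained elsewhere:

    // Visit every subgraph together with its index (per-subgraph callback).
    ForEachIr(model, [](LiteRtSubgraph subgraph, int32_t subgraph_index) {
      // ... inspect subgraph ...
    });

    // Visit every op; the per-subgraph loop runs inside ForEachIr.
    ForEachIr(model, [](LiteRtOp op) {
      // ... inspect op ...
    });

The three-argument op form, `[](LiteRtSubgraph sg, int32_t sg_index, LiteRtOp op)`, is the one exercised by TransferSubgraphTo in model.cc above.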
- -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" - -namespace litert { -namespace internal { - -Expected> GetModelBufWithByteCode( - LiteRtModelT&& model, - const absl::flat_hash_map>& - custom_code_to_npu_bytecode, - size_t bytecode_alignment) { - for (const auto& subgraph : model.Subgraphs()) { - for (auto op : subgraph->Ops()) { - if (op->OpCode() == kLiteRtOpCodeTflCustom) { - auto custom_code = GetCustomOpCode(model, *op); - if (!custom_code) { - continue; - } - - auto iter = custom_code_to_npu_bytecode.find(*custom_code); - if (iter == custom_code_to_npu_bytecode.end()) { - return Error(kLiteRtStatusErrorUnsupported, - absl::StrFormat("Unexpected custom code: %s", - custom_code->c_str())); - } - - LiteRtOpT* custom_op = op; - OwningBufferRef byte_code(iter->second); - const auto buf_id = - model.Buffers()->RegisterOwnedBuffer(std::move(byte_code)); - model.AttachAssetToOp(custom_op, buf_id, ""); - } - } - } - - return SerializeModel(std::move(model), bytecode_alignment); -} - -Expected> GetModelBufWithByteCode( - absl::string_view tfl_file, - const absl::flat_hash_map& - custom_code_to_npu_file, - size_t bytecode_alignment) { - auto model = LoadModelFromFile(tfl_file); - if (!model) { - return model.Error(); - } - - absl::flat_hash_map> - custom_code_to_npu_bytecode; - for (auto& iter : custom_code_to_npu_file) { - auto npu_file_buf = LoadBinaryFile(iter.second); - if (!npu_file_buf) { - return npu_file_buf.Error(); - } - custom_code_to_npu_bytecode[iter.first] = std::move(*npu_file_buf); - } - - return GetModelBufWithByteCode( - std::move(**model), custom_code_to_npu_bytecode, bytecode_alignment); -} - -Expected> GetModelBufWithByteCode( - LiteRtModelT&& model, BufferRef npu_byte_code, - size_t bytecode_alignment) { - absl::flat_hash_map> - custom_code_to_npu_bytecode; - for (const auto& subgraph : model.Subgraphs()) { - for (auto op : subgraph->Ops()) { - if (op->OpCode() == kLiteRtOpCodeTflCustom) { - auto custom_code = GetCustomOpCode(model, *op); - if (!custom_code) { - continue; - } - OwningBufferRef byte_code(npu_byte_code.Data(), - npu_byte_code.Size()); - custom_code_to_npu_bytecode[*custom_code] = std::move(byte_code); - } - } - } - - return GetModelBufWithByteCode(std::move(model), custom_code_to_npu_bytecode, - bytecode_alignment); -} - -Expected> GetModelBufWithByteCode( - absl::string_view tfl_file, absl::string_view npu_file, - size_t bytecode_alignment) { - auto model = LoadModelFromFile(tfl_file); - if (!model) { - return model.Error(); - } - - auto npu_file_buf = LoadBinaryFile(npu_file); - if (!npu_file_buf) { - return npu_file_buf.Error(); - } - - return GetModelBufWithByteCode(std::move(**model), std::move(*npu_file_buf), - bytecode_alignment); -} - -} // namespace internal -} // namespace litert diff --git 
a/tensorflow/lite/experimental/litert/core/model/model_buffer.h b/tensorflow/lite/experimental/litert/core/model/model_buffer.h deleted file mode 100644 index 623e86f19b2899..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_buffer.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_BUFFER_H_ - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -// Get a buffer that is the concatenation of given tflite file and one or more -// NPU byte code files. Adds metadata containing the offset/size of npu byte -// code. TFL custom ops are mapped to NPU byte code by their custom code, which -// must be non-null. -// -// NOTE: this is intended to be used for testing and tools and may be removed in -// the future. -Expected> GetModelBufWithByteCode( - absl::string_view tfl_file, - const absl::flat_hash_map& - custom_code_to_npu_file, - size_t bytecode_alignment = 1); - -// Same as above, but with a map specifying NPU byte code buffers. -Expected> GetModelBufWithByteCode( - LiteRtModelT&& model, - const absl::flat_hash_map>& - custom_code_to_npu_bytecode, - size_t bytecode_alignment = 1); - -// Same as above, but only a single NPU byte code file is specified. -Expected> GetModelBufWithByteCode( - absl::string_view tfl_file, absl::string_view npu_file, - size_t bytecode_alignment = 1); - -// Same as above, but only a single NPU byte code buffer is specified. -Expected> GetModelBufWithByteCode( - LiteRtModelT&& model, BufferRef npu_byte_code, - size_t bytecode_alignment = 1); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_buffer_test.cc b/tensorflow/lite/experimental/litert/core/model/model_buffer_test.cc deleted file mode 100644 index 00eb7f557f045e..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_buffer_test.cc +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
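A minimal usage sketch of the single-file GetModelBufWithByteCode overload declared in model_buffer.h above, mirroring what the tests below exercise; the paths are illustrative placeholders:

    auto buf = litert::internal::GetModelBufWithByteCode(
        "/path/to/model.tflite", "/path/to/bytecode.npu");
    if (buf) {
      // *buf holds the serialized .tflite with the NPU byte code appended;
      // each dispatch op's custom options record the code's offset and size.
    }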
- -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" - -#include -#include - -#include -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/compiler/mlir/lite/allocation.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/interpreter_builder.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/stderr_reporter.h" - -namespace litert::internal { -namespace { - -static constexpr absl::string_view kNpuFile = kGoogleTensorModelFileName; -static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -static constexpr absl::string_view kCascadedTfliteFile = - "simple_cascade_model_npu.tflite"; - -TEST(GetModelBufWithByteCode, CreateInterpreter) { - auto model_with_byte_code = - GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - - auto alloc = std::make_unique( - model_with_byte_code->Data(), model_with_byte_code->Size(), - tflite::DefaultErrorReporter()); - - auto fb_model = tflite::FlatBufferModel::BuildFromBuffer( - reinterpret_cast(alloc->base()), alloc->bytes()); - ASSERT_NE(fb_model, nullptr); - - tflite::ops::builtin::BuiltinOpResolver resolver; - std::unique_ptr interpreter; - tflite::InterpreterBuilder(*fb_model, resolver)(&interpreter); - EXPECT_NE(interpreter, nullptr); -} - -TEST(GetModelBufWithByteCode, CheckAppended) { - auto model_with_byte_code = - GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - - auto model = LoadModelFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - - auto* op = model->get()->Subgraphs().front()->Ops().front(); - ASSERT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - auto dispatch_opts = GetDispatchOpOptions(op->CustomOptions()); - EXPECT_EQ(dispatch_opts.name, ""); - EXPECT_LE(dispatch_opts.bytecode_offset + dispatch_opts.bytecode_size, - model_with_byte_code->Size()); -} - -TEST(GetModelBufWithByteCode, CreateInterpreterWithMultpleNpuNodes) { - absl::flat_hash_map custom_code_to_npu_file = { - {"DISPATCH_OP_1", testing::GetTestFilePath(kNpuFile)}, - {"DISPATCH_OP_2", testing::GetTestFilePath(kNpuFile)}, - }; - - auto model_with_byte_code = GetModelBufWithByteCode( - testing::GetTestFilePath(kCascadedTfliteFile), custom_code_to_npu_file); - ASSERT_TRUE(model_with_byte_code); - - auto alloc = std::make_unique( - model_with_byte_code->Data(), model_with_byte_code->Size(), - tflite::DefaultErrorReporter()); - - auto fb_model = tflite::FlatBufferModel::BuildFromBuffer( - reinterpret_cast(alloc->base()), alloc->bytes()); - ASSERT_NE(fb_model, nullptr); - - tflite::ops::builtin::BuiltinOpResolver resolver; - std::unique_ptr interpreter; - tflite::InterpreterBuilder(*fb_model, resolver)(&interpreter); - EXPECT_NE(interpreter, nullptr); -} - -TEST(GetModelBufWithByteCode, CheckAppendedWithMultipleNpuOps) { - absl::flat_hash_map custom_code_to_npu_file = { - {"DISPATCH_OP_1", 
testing::GetTestFilePath(kNpuFile)},
-      {"DISPATCH_OP_2", testing::GetTestFilePath(kNpuFile)},
-  };
-
-  auto model_with_byte_code = GetModelBufWithByteCode(
-      testing::GetTestFilePath(kCascadedTfliteFile), custom_code_to_npu_file);
-  ASSERT_TRUE(model_with_byte_code);
-
-  auto model = LoadModelFromBuffer(*model_with_byte_code);
-  ASSERT_TRUE(model);
-
-  for (auto& op : model->get()->Subgraphs().front()->Ops()) {
-    ASSERT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom);
-    auto dispatch_opts = GetDispatchOpOptions(op->CustomOptions());
-    EXPECT_EQ(dispatch_opts.name, "");
-    EXPECT_LE(dispatch_opts.bytecode_offset + dispatch_opts.bytecode_size,
-              model_with_byte_code->Size());
-  }
-}
-
-}  // namespace
-}  // namespace litert::internal
diff --git a/tensorflow/lite/experimental/litert/core/model/model_file_test.cc b/tensorflow/lite/experimental/litert/core/model/model_file_test.cc
deleted file mode 100644
index 9f3e07fd6caf3e..00000000000000
--- a/tensorflow/lite/experimental/litert/core/model/model_file_test.cc
+++ /dev/null
@@ -1,1041 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include
-#include
-#include
-#include
-#include  // NOLINT
-#include
-#include
-#include
-#include
-#include
-
-// schema/mutable/schema_generated.h and schema/schema_generated.h (included
-// through flatbuffer_tools.h via model.h) have the same #ifdef, thus this line
-// needs to be put at the top to ensure we get the "mutable" version.
-#if 1 -#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" -#endif - -#include // IWYU pragma: keep -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_file_test_util.h" -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/test_models.h" -#include "tensorflow/lite/schema/mutable/schema_generated.h" - -namespace litert::internal { -namespace { - -using ::litert::testing::GetTestFilePath; -using ::testing::Each; -using ::testing::ElementsAreArray; -using ::testing::FloatEq; -using ::testing::Values; -using ::testing::litert::IsError; - -using ModelFactory = std::function()>; - -static constexpr absl::string_view kAddSimple = "add_simple.tflite"; -static constexpr absl::string_view kAddCst = "add_cst.tflite"; -static constexpr absl::string_view kDynamicShapeModel = - "dynamic_shape_tensor.tflite"; -static constexpr absl::string_view kSimpleMultiOp = "simple_multi_op.tflite"; -static constexpr absl::string_view kOneMul = "one_mul.tflite"; -static constexpr absl::string_view kSimpleMultiSubgraph = - "multi_subgraph.tflite"; -static constexpr absl::string_view kCstMultiSubgraph = - "cst_multi_subgraph.tflite"; - -// Load a model, then serialize and re-load. Used to test serialization. -Expected LoadModelThroughRoundTrip(absl::string_view filename) { - auto model = Model::CreateFromFile(GetTestFilePath(filename)); - if (!model) { - return model.Error(); - } - - OwningBufferRef buf; - auto [data, size, offset] = buf.GetWeak(); - - const auto opts = litert::SerializationOptions::Defaults(); - LITERT_RETURN_IF_ERROR(LiteRtSerializeModel(model->Release(), &data, &size, - &offset, true, opts)); - - // Reload model. - LiteRtModel result = nullptr; - LITERT_RETURN_IF_ERROR( - LiteRtCreateModelFromBuffer(buf.Data(), buf.Size(), &result)); - - return Model::CreateFromOwnedHandle(result); -} - -ModelFactory MakeRoundTripFactory(absl::string_view filename) { - return [=]() { return LoadModelThroughRoundTrip(filename); }; -} - -ModelFactory MakeLoadFactory(absl::string_view filename) { - return [=]() { return Model::CreateFromFile(GetTestFilePath(filename)); }; -} - -// Test fixture parameterized by a file path to test model. 
-class TestWithModelPath : public ::testing::TestWithParam<absl::string_view> {
- protected:
-  std::string GetTestModelPath() const {
-    return testing::GetTestFilePath(GetParam());
-  }
-};
-
-// Test fixture parameterized by a function that loads a model.
-class TestWithModelFactory : public ::testing::TestWithParam<ModelFactory> {
- protected:
-  Expected<Model> LoadModel() { return GetParam()(); }
-};
-
-// Simple tests
-//===---------------------------------------------------------------------------
-
-TEST(ModelLoadTest, BadFilepath) {
-  LiteRtModel model = nullptr;
-  EXPECT_THAT(LiteRtCreateModelFromFile("bad_path", &model),
-              IsError(kLiteRtStatusErrorNotFound));
-}
-
-TEST(ModelLoadTest, BadFileData) {
-  // NOLINTBEGIN
-#ifndef NDEBUG
-  // In debug mode, flatbuffers will `assert` while verifying. This will
-  // cause this test to crash (as expected).
-  GTEST_SKIP();
-#endif
-  std::filesystem::path test_file_path(::testing::TempDir());
-  test_file_path.append("bad_file.txt");
-
-  std::ofstream bad_file;
-  bad_file.open(test_file_path.c_str());
-  bad_file << "not_tflite";
-  bad_file.close();
-
-  LiteRtModel model = nullptr;
-  EXPECT_THAT(LiteRtCreateModelFromFile(test_file_path.c_str(), &model),
-              IsError(kLiteRtStatusErrorInvalidFlatbuffer));
-  // NOLINTEND
-}
-
-TEST(ModelLoadTest, GetCustomOpCode) {
-  auto model = litert::testing::LoadTestFileModel("simple_model_npu.tflite");
-  ASSERT_TRUE(model);
-  const auto& litert_model = *model.Get();
-  const auto& op = *litert_model.MainSubgraph()->Ops().front();
-  auto custom_op_code = GetCustomOpCode(litert_model, op);
-  ASSERT_TRUE(custom_op_code.has_value());
-  EXPECT_EQ(*custom_op_code, "DISPATCH_OP");
-}
-
-TEST(ModelLoadTest, WithMetadata) {
-  constexpr static absl::string_view kMetadataName = "an_soc_manufacturer";
-  constexpr static absl::string_view kMetadataData = "My_Meta_Data";
-
-  auto flatbuffer =
-      FlatbufferWrapper::CreateFromTflFile(GetTestFilePath(kAddSimple));
-  auto tfl_model = flatbuffer->get()->Unpack();
-  PushMetadata(kMetadataName, *tfl_model,
-               BufferRef<uint8_t>(kMetadataData.data(), kMetadataData.size()));
-  auto serialized = SerializeFlatbuffer(*tfl_model);
-
-  auto litert_model = LoadModelFromBuffer(serialized);
-  ASSERT_TRUE(litert_model);
-
-  auto metadata = litert_model->get()->FindMetadata(kMetadataName);
-  ASSERT_TRUE(metadata);
-  EXPECT_EQ(metadata->StrView(), kMetadataData);
-}
-
-TEST(ModelSerializeTest, WithMetadata) {
-  auto model = litert::testing::LoadTestFileModel(kAddSimple);
-
-  constexpr static absl::string_view kMetadataName = "an_soc_manufacturer";
-  constexpr static absl::string_view kMetadataData = "My_Meta_Data";
-
-  LITERT_ASSERT_OK(model.Get()->PushMetadata(
-      kMetadataName, OwningBufferRef<uint8_t>(kMetadataData)));
-
-  auto serialized = SerializeModel(std::move(*model.Get()));
-  EXPECT_TRUE(VerifyFlatbuffer(serialized->Span()));
-
-  auto re_loaded = LoadModelFromBuffer(*serialized);
-  auto metadata = re_loaded->get()->FindMetadata(kMetadataName);
-  EXPECT_EQ(metadata->StrView(), kMetadataData);
-}
-
-TEST(ModelLoadTest, WithSignature) {
-  auto model = litert::testing::LoadTestFileModel(kAddSimple);
-  auto& litert_model = *model.Get();
-
-  auto signature =
-      litert_model.FindSignature(LiteRtSignatureT::kDefaultSignatureKey);
-  ASSERT_TRUE(signature);
-
-  EXPECT_EQ(signature->get().InputNames().size(), 1);
-  EXPECT_EQ(signature->get().OutputNames().size(), 1);
-  EXPECT_EQ(&signature->get().GetSubgraph(), litert_model.MainSubgraph());
-}
-
-TEST(ModelLoadTest, NoSignature) {
-  auto model = *Model::CreateFromFile(testing::GetTfliteFilePath(
-
"java/demo/app/src/main/assets/mobilenet_v1_1.0_224.tflite")); - if (!model) { - GTEST_SKIP() << "Model file is not available."; - } - auto& litert_model = *model.Get(); - auto signature = - litert_model.FindSignature(LiteRtSignatureT::kDefaultSignatureKey); - ASSERT_TRUE(signature); - EXPECT_EQ(signature->get().InputNames().size(), 1); - EXPECT_EQ(signature->get().OutputNames().size(), 1); - EXPECT_EQ(&signature->get().GetSubgraph(), litert_model.MainSubgraph()); -} - -TEST(ModelSerializeTest, WithSignature) { - auto model = litert::testing::LoadTestFileModel(kAddSimple); - auto& litert_model = *model.Get(); - - static constexpr char kInput[] = "foo"; - static constexpr char kOutput[] = "bar"; - static constexpr char kKey[] = "newKey"; - - LiteRtSignatureT signature(litert_model.MainSubgraph(), {kInput}, {kOutput}, - kKey); - litert_model.EmplaceSignature(std::move(signature)); - - auto serialized = SerializeModel(std::move(*model.Get())); - EXPECT_TRUE(VerifyFlatbuffer(serialized->Span())); - - auto re_loaded = LoadModelFromBuffer(*serialized); - auto re_loaded_signature = re_loaded->get()->FindSignature(kKey); - ASSERT_TRUE(re_loaded_signature); - const auto& sig = re_loaded_signature->get(); - - const auto& inputs = sig.InputNames(); - const auto& outputs = sig.OutputNames(); - EXPECT_THAT(inputs, ElementsAreArray({kInput})); - EXPECT_THAT(outputs, ElementsAreArray({kOutput})); - EXPECT_EQ(&sig.GetSubgraph(), re_loaded->get()->MainSubgraph()); -} - -TEST(ModelLoadTest, ReverseSignature) { - auto model = - litert::testing::LoadTestFileModel("reverse_signature_model.tflite"); - ASSERT_TRUE(model); - auto& litert_model = *model.Get(); - - auto signature = litert_model.FindSignature("serving_default"); - ASSERT_TRUE(signature); - - // Check if the input and output names are in the order of the subgraph - // inputs and outputs instead of the signature appearance order. - const auto& sig = signature->get(); - ASSERT_EQ(sig.InputNames().size(), 2); - EXPECT_STREQ(sig.InputNames()[0].c_str(), "y"); - EXPECT_STREQ(sig.InputNames()[1].c_str(), "x"); - ASSERT_EQ(sig.OutputNames().size(), 2); - EXPECT_STREQ(sig.OutputNames()[0].c_str(), "sum"); - EXPECT_STREQ(sig.OutputNames()[1].c_str(), "prod"); - - auto serialized = SerializeModel(std::move(*model.Get())); - EXPECT_TRUE(VerifyFlatbuffer(serialized->Span())); - - auto re_loaded = LoadModelFromBuffer(*serialized); - auto re_loaded_signature = re_loaded->get()->FindSignature("serving_default"); - ASSERT_TRUE(re_loaded_signature); - - // Check again with the serialized model. 
- const auto& re_sig = re_loaded_signature->get(); - ASSERT_EQ(re_sig.InputNames().size(), 2); - EXPECT_STREQ(re_sig.InputNames()[0].c_str(), "y"); - EXPECT_STREQ(re_sig.InputNames()[1].c_str(), "x"); - ASSERT_EQ(re_sig.OutputNames().size(), 2); - EXPECT_STREQ(re_sig.OutputNames()[0].c_str(), "sum"); - EXPECT_STREQ(re_sig.OutputNames()[1].c_str(), "prod"); -} - -TEST(ModelLoadTest, WithOffsetTensorBuffer) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - - auto flatbuffer = - FlatbufferWrapper::CreateFromTflFile(GetTestFilePath(kAddSimple)); - auto tfl_model = flatbuffer->get()->Unpack(); - const auto buf_ind = tfl_model->subgraphs[0]->tensors[0]->buffer; - auto& tfl_buffer = tfl_model->buffers[buf_ind]; - tfl_buffer->offset = 1; - tfl_buffer->size = 1; - auto model_buf = SerializeFlatbuffer(*tfl_model); - auto* packed_tfl = tflite::GetMutableModel(model_buf.Data()); - auto* buf = packed_tfl->mutable_buffers()->GetMutableObject(buf_ind); - ASSERT_TRUE(buf->mutate_offset(model_buf.Size())); - ASSERT_TRUE(buf->mutate_size(kTensorData.size())); - OwningBufferRef final_serializd(kTensorData.size() + - model_buf.Size()); - std::memcpy(final_serializd.Data(), model_buf.Data(), model_buf.Size()); - std::memcpy(final_serializd.Data() + model_buf.Size(), kTensorData.data(), - kTensorData.size()); - - auto litert_model = LoadModelFromBuffer(final_serializd); - ASSERT_TRUE(litert_model); - - const auto& weights_buffer = - litert_model->get()->Subgraph(0).Tensor(0).Weights(); - EXPECT_EQ(weights_buffer.Buffer().StrView(), kTensorData); - - // The loaded buffer should indicate that it should be also serialized as - // external. - const auto will_append = weights_buffer.GetBufferManager() - ->GetContext(weights_buffer.GetBufferId()) - ->get() - .should_append; - EXPECT_TRUE(will_append); - - // All tensors in the first subgraph should have the same buffer manager as - // the model. - for (auto* tensor : litert_model->get()->Subgraph(0).Tensors()) { - EXPECT_EQ(tensor->Weights().GetBufferManager(), - litert_model->get()->Buffers()); - } -} - -TEST(ModelSerializeTest, WithOffsetTensorBuffer) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& tensor = sg.EmplaceTensor(); - sg.EmplaceOp(); - tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - OwningBufferRef buffer(kTensorData); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify the op contains an offset and size to the byte code and the correct - // name. 
- auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[0]; - const auto tfl_buffer_ind = tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData); -} - -TEST(ModelSerializeTest, WithMultipleOffsetTensorBuffer) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - static constexpr absl::string_view kTensorData2 = "SOME_TENSOR_DATA2"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - sg.EmplaceOp(); - - { - auto& tensor = sg.EmplaceTensor(); - tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - OwningBufferRef buffer(kTensorData); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - } - - { - auto& tensor = sg.EmplaceTensor(); - tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - OwningBufferRef buffer(kTensorData2); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - } - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify the op contains an offset and size to the byte code and the correct - // name. - auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - - { - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[0]; - const auto tfl_buffer_ind = tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData); - } - - { - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[1]; - const auto tfl_buffer_ind = tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData2); - } -} - -TEST(ModelSerializeTest, WithSingleExternalBuffer) { - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "foo"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify the op contains an offset and size to the byte code and the correct - // name. 
- auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); -} - -TEST(ModelSerializeTest, WithMultipleUniqueExternalBuffer) { - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "foo"; - static constexpr absl::string_view kByteCode2 = "SOME_BYTE_CODE2"; - static constexpr absl::string_view kName2 = "bar"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - auto& op2 = sg.EmplaceOp(); - - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - - OwningBufferRef buffer2(kByteCode2); - const auto buf_id2 = root.Buffers()->RegisterOwnedBuffer(std::move(buffer2)); - root.AttachAssetToOp(&op2, buf_id2, std::string(kName2)); - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify both ops contains an offset and size to the byte code and the - // correct name. - auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - - { - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } - - { - const auto& opts = tfl->subgraphs[0]->operators[1]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName2); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode2); - } -} - -TEST(ModelSerializeTest, WithSharedExternalBuffer) { - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "foo"; - static constexpr absl::string_view kName2 = "bar"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - auto& op2 = sg.EmplaceOp(); - - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - root.AttachAssetToOp(&op2, buf_id, std::string(kName2)); - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify both ops point to the same appended buffer. 
- auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - - { - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } - - { - const auto& opts = tfl->subgraphs[0]->operators[1]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName2); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } -} - -TEST(ModelSerializeTest, WithOffsetTensorBufferAndOpAsset) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "name"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - auto& tensor = sg.EmplaceTensor(); - tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - { - OwningBufferRef buffer(kTensorData); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - } - - { - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - } - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - auto tfl = fb->get()->Unpack(); - - { - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[0]; - const auto tfl_buffer_ind = tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData); - } - - { - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } -} - -TEST(ModelSerializeTest, WithOffsetTensorBufferAndOpAssetHasAlignment) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "name"; - static constexpr size_t kAlignment = 32; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - auto& tensor = sg.EmplaceTensor(); - tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - { - OwningBufferRef buffer(kTensorData); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - } - - { - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - } - - auto serialized = SerializeModel(std::move(root), 
kAlignment); - ASSERT_TRUE(serialized); - - auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - auto tfl = fb->get()->Unpack(); - - { - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[0]; - const auto tfl_buffer_ind = tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData); - } - - { - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - ASSERT_EQ(dispatch_opts.bytecode_offset % kAlignment, 0); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } -} - -// Tests that explicitly check litert graph structure. -//===--------------------------------------------------------------------------- - -using AddSimpleTest = TestWithModelFactory; - -TEST_P(AddSimpleTest, CheckGraph) { - auto model = LoadModel(); - ASSERT_TRUE(model); - - // func(arg0) - // output = tfl.add(arg0, arg0) - // return(output) - // - - auto subgraph = model->MainSubgraph(); - const auto subgraph_inputs = subgraph->Inputs(); - const auto subgraph_outputs = subgraph->Outputs(); - const auto ops = subgraph->Ops(); - - ASSERT_EQ(subgraph_inputs.size(), 1); - ASSERT_EQ(subgraph_outputs.size(), 1); - - const auto& internal_ops = subgraph->Get()->Ops(); - ASSERT_TRUE( - ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); - ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); - - ASSERT_EQ(ops.size(), 1); - const auto& op = ops.front(); - - const TensorTypeInfo float_2by2_type(ElementType::Float32, {2, 2}); - ASSERT_TRUE( - MatchOpType(op, {float_2by2_type, float_2by2_type}, {float_2by2_type})); - EXPECT_EQ(op.Code(), kLiteRtOpCodeTflAdd); - - const auto op_inputs = op.Inputs(); - ASSERT_EQ(op_inputs.size(), 2); - ASSERT_EQ(op_inputs.front().Get(), subgraph_inputs.front().Get()); - ASSERT_EQ(op_inputs.front().Get(), op_inputs.back().Get()); - - const auto op_outputs = op.Outputs(); - ASSERT_EQ(op_outputs.size(), 1); - ASSERT_EQ(op_outputs.front().Get(), subgraph_outputs.front().Get()); - - ASSERT_FALSE(subgraph_outputs.front().IsConstant()); - ASSERT_FALSE(subgraph_inputs.front().IsConstant()); -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, AddSimpleTest, - Values(MakeLoadFactory(kAddSimple))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, AddSimpleTest, - Values(MakeRoundTripFactory(kAddSimple))); - -using AddCstTest = TestWithModelFactory; - -TEST_P(AddCstTest, CheckGraph) { - auto model = LoadModel(); - ASSERT_TRUE(model); - - // func(arg0) - // cst = ConstantTensor([1, 2, 3, 4]) - // output = tfl.add(arg0, cst) - // return(output) - // - - auto subgraph = model->MainSubgraph(); - const auto subgraph_inputs = subgraph->Inputs(); - const auto subgraph_outputs = subgraph->Outputs(); - const auto ops = subgraph->Ops(); - - ASSERT_EQ(subgraph_inputs.size(), 1); - ASSERT_EQ(subgraph_outputs.size(), 1); - - const auto& internal_ops = subgraph->Get()->Ops(); - ASSERT_TRUE( - ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); - ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); - - ASSERT_EQ(ops.size(), 1); - const auto& op = ops.front(); - - const TensorTypeInfo float_by4_type(ElementType::Float32, {4}); - ASSERT_TRUE( - MatchOpType(op, {float_by4_type, float_by4_type}, {float_by4_type})); - 
EXPECT_EQ(op.Code(), kLiteRtOpCodeTflAdd); - - const auto op_inputs = op.Inputs(); - ASSERT_EQ(op_inputs.size(), 2); - ASSERT_EQ(op_inputs.front().Get(), subgraph_inputs.front().Get()); - ASSERT_TRUE(MatchWeights(op_inputs.back(), - absl::Span({1.0, 2.0, 3.0, 4.0}))); - - const auto op_outputs = op.Outputs(); - ASSERT_EQ(op_outputs.size(), 1); - ASSERT_EQ(op_outputs.front().Get(), subgraph_outputs.front().Get()); - - ASSERT_FALSE(subgraph_outputs.front().IsConstant()); - ASSERT_FALSE(subgraph_inputs.front().IsConstant()); -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, AddCstTest, - Values(MakeLoadFactory(kAddCst))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, AddCstTest, - Values(MakeRoundTripFactory(kAddCst))); - -using SimpleMultiOpTest = TestWithModelFactory; - -TEST_P(SimpleMultiOpTest, CheckGraph) { - auto model = LoadModel(); - ASSERT_TRUE(model); - - // func.func @main(arg0) - // 0 = tfl.add arg0, arg0 - // 1 = tfl.mul 0, 0 - // 2 = tfl.mul 1, 1 - // 3 = tfl.add 2, 2 - // return 3 - - auto subgraph = model->MainSubgraph(); - const auto subgraph_inputs = subgraph->Inputs(); - const auto subgraph_outputs = subgraph->Outputs(); - const auto ops = subgraph->Ops(); - - ASSERT_EQ(subgraph_inputs.size(), 1); - ASSERT_EQ(subgraph_outputs.size(), 1); - - const auto& internal_ops = subgraph->Get()->Ops(); - ASSERT_TRUE( - ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); - ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); - - ASSERT_EQ(ops.size(), 4); - - for (const auto& op : ops) { - const auto inputs = op.Inputs(); - ASSERT_EQ(inputs.size(), 2); - ASSERT_EQ(inputs.front().Get(), inputs.back().Get()); - } - - const TensorTypeInfo float_2by2_type(ElementType::Float32, {2, 2}); - - ASSERT_TRUE(MatchOpType(ops.at(2), {float_2by2_type, float_2by2_type}, - {float_2by2_type})); - EXPECT_EQ(ops.at(2).Code(), kLiteRtOpCodeTflMul); -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, SimpleMultiOpTest, - Values(MakeLoadFactory(kSimpleMultiOp))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, SimpleMultiOpTest, - Values(MakeRoundTripFactory(kSimpleMultiOp))); - -using SimpleMultiSubgraphTest = TestWithModelFactory; - -TEST_P(SimpleMultiSubgraphTest, CheckGraph) { - auto model_wrap = LoadModel(); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap->Get(); - - ASSERT_EQ(model.NumSubgraphs(), 3); - - { - auto& main = *model.MainSubgraph(); - EXPECT_EQ(main.NumInputs(), 1); - EXPECT_EQ(main.NumOutputs(), 1); - EXPECT_EQ(main.Ops().size(), 1); - EXPECT_EQ(main.Tensors().size(), 3); - auto& op = main.Op(0); - auto* cst = op.Inputs().back(); - auto data = Tensor(cst).WeightsData(); - ASSERT_TRUE(data); - EXPECT_THAT(*data, Each(FloatEq(-1.0))); - EXPECT_TRUE(ValidateLocalTopology(main.Ops().cbegin(), main.Ops().cend())); - EXPECT_TRUE(ValidateSubgraphIO(main)); - } - - { - auto& func1 = model.Subgraph(1); - EXPECT_EQ(func1.NumInputs(), 1); - EXPECT_EQ(func1.NumOutputs(), 1); - EXPECT_EQ(func1.Ops().size(), 1); - EXPECT_EQ(func1.Tensors().size(), 3); - auto& op = func1.Op(0); - auto* cst = op.Inputs().back(); - auto data = Tensor(cst).WeightsData(); - ASSERT_TRUE(data); - EXPECT_THAT(*data, Each(FloatEq(1.0))); - EXPECT_TRUE( - ValidateLocalTopology(func1.Ops().cbegin(), func1.Ops().cend())); - EXPECT_TRUE(ValidateSubgraphIO(func1)); - } - - { - auto& func2 = model.Subgraph(2); - EXPECT_EQ(func2.NumInputs(), 1); - EXPECT_EQ(func2.NumOutputs(), 1); - EXPECT_EQ(func2.Ops().size(), 1); - EXPECT_EQ(func2.Tensors().size(), 3); - auto& op = func2.Op(0); - auto* cst = op.Inputs().back(); - 
auto data = Tensor(cst).WeightsData(); - ASSERT_TRUE(data); - EXPECT_THAT(*data, Each(FloatEq(2.0))); - EXPECT_TRUE( - ValidateLocalTopology(func2.Ops().cbegin(), func2.Ops().cend())); - EXPECT_TRUE(ValidateSubgraphIO(func2)); - } -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, SimpleMultiSubgraphTest, - Values(MakeLoadFactory(kSimpleMultiSubgraph))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, SimpleMultiSubgraphTest, - Values(MakeRoundTripFactory(kSimpleMultiSubgraph))); - -// Test when flatbuffer export has optimized multiple tensors to share the -// same buffer. -using MultiSubgraphDupeConstTest = TestWithModelFactory; - -TEST_P(MultiSubgraphDupeConstTest, CheckGraph) { - static constexpr std::array kWeights = {1.0, 2.0, 3.0, 4.0}; - - auto model_wrap = LoadModel(); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap->Get(); - - ASSERT_EQ(model.NumSubgraphs(), 2); - - { - ASSERT_EQ(model.Subgraph(0).Ops().size(), 1); - ASSERT_EQ(model.Subgraph(0).Tensors().size(), 3); - auto& cst = model.Subgraph(0).Op(0).Input(1); - Tensor t(&cst); - EXPECT_THAT(*t.WeightsData(), ElementsAreArray(kWeights)); - } - - { - ASSERT_EQ(model.Subgraph(1).Ops().size(), 1); - ASSERT_EQ(model.Subgraph(1).Tensors().size(), 3); - auto& cst = model.Subgraph(1).Op(0).Input(1); - Tensor t(&cst); - EXPECT_THAT(*t.WeightsData(), ElementsAreArray(kWeights)); - } - auto buf_id_0 = model.Subgraph(0).Op(0).Input(1).Weights().GetBufferId(); - auto buf_id_1 = model.Subgraph(1).Op(0).Input(1).Weights().GetBufferId(); - ASSERT_EQ(buf_id_0, buf_id_1); -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, MultiSubgraphDupeConstTest, - Values(MakeLoadFactory(kCstMultiSubgraph))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, MultiSubgraphDupeConstTest, - Values(MakeRoundTripFactory(kCstMultiSubgraph))); - -// Tests that programmatically check litert against tflite models. -//===--------------------------------------------------------------------------- - -using ModelLoadOpCheckTest = TestWithModelPath; - -TEST_P(ModelLoadOpCheckTest, CheckOps) { - const auto model_path = GetTestModelPath(); - - auto flatbuffer = FlatbufferWrapper::CreateFromTflFile(model_path); - ASSERT_TRUE(flatbuffer); - auto expected_fb = flatbuffer->get()->Unpack(); - - auto model = LoadModelFromFile(model_path); - ASSERT_TRUE(model); - - const auto* subgraph = model->get()->MainSubgraph(); - const auto& ops = subgraph->Ops(); - - const auto& fb_subgraph = *expected_fb->subgraphs.front(); - const auto& fb_ops = fb_subgraph.operators; - const auto& fb_tensors = fb_subgraph.tensors; - - ASSERT_EQ(ops.size(), fb_ops.size()); - - auto get_tfl_tensor = [&](uint32_t ind) -> const TflTensor& { - return *fb_tensors.at(ind); - }; - - for (auto i = 0; i < ops.size(); ++i) { - ASSERT_TRUE(EqualsFbOp(*ops.at(i), *fb_ops.at(i), get_tfl_tensor)); - } -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadQuantizedOpCheckTest, ModelLoadOpCheckTest, - ::testing::ValuesIn(kAllQModels)); - -INSTANTIATE_TEST_SUITE_P(ModelLoadDynamicOpCheckTest, ModelLoadOpCheckTest, - ::testing::ValuesIn({kDynamicShapeModel})); - -using ModelSerializeOpCheckTest = TestWithModelPath; - -TEST_P(ModelSerializeOpCheckTest, CheckOps) { - const auto model_path = GetTestModelPath(); - - // Save the initial fb for comparison. - auto expected_fb_data = FlatbufferWrapper::CreateFromTflFile(model_path); - ASSERT_TRUE(expected_fb_data); - auto expected_fb = expected_fb_data->get()->Unpack(); - - // Round trip the model. 
- auto model = LoadModelFromFile(model_path); - ASSERT_TRUE(model); - auto serialized = SerializeModel(std::move(**model)); - - auto actual_fb_data = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(actual_fb_data); - auto actual_fb = actual_fb_data->get()->Unpack(); - - const auto& expected_fb_subgraph = *expected_fb->subgraphs.front(); - const auto& expected_fb_ops = expected_fb_subgraph.operators; - const auto& expected_fb_tensors = expected_fb_subgraph.tensors; - - const auto& actual_fb_subgraph = *actual_fb->subgraphs.front(); - const auto& actual_fb_ops = actual_fb_subgraph.operators; - const auto& actual_fb_tensors = actual_fb_subgraph.tensors; - - ASSERT_EQ(expected_fb_ops.size(), actual_fb_ops.size()); - for (auto i = 0; i < actual_fb_ops.size(); ++i) { - const auto& expected = *expected_fb_ops.at(i); - const auto& actual = *actual_fb_ops.at(i); - EXPECT_EQ(expected.inputs.size(), actual.inputs.size()); - EXPECT_EQ(expected.outputs.size(), actual.outputs.size()); - } - - ASSERT_EQ(expected_fb_tensors.size(), actual_fb_tensors.size()); - for (auto i = 0; i < actual_fb_tensors.size(); ++i) { - const auto& expected = *expected_fb_tensors.at(i); - const auto& actual = *actual_fb_tensors.at(i); - - EXPECT_EQ(actual.type, expected.type); - EXPECT_EQ(actual.shape, expected.shape); - EXPECT_EQ(actual.shape_signature, expected.shape_signature); - - const auto expected_q_params = expected.quantization.get(); - const auto actual_q_params = actual.quantization.get(); - - const auto neither_quantized = - !IsQuantized(expected_q_params) && !IsQuantized(actual_q_params); - const auto both_per_tensor = IsPerTensorQuantized(expected_q_params) && - IsPerTensorQuantized(actual_q_params); - ASSERT_TRUE(neither_quantized || both_per_tensor); - - if (both_per_tensor) { - const auto expected_per_tensor = AsPerTensorQparams(expected_q_params); - const auto actual_per_tensor = AsPerTensorQparams(actual_q_params); - EXPECT_EQ(*expected_per_tensor, *actual_per_tensor); - } - } -} - -INSTANTIATE_TEST_SUITE_P(ModelSerializeOpCheckTest, ModelSerializeOpCheckTest, - ::testing::ValuesIn({kOneMul, kDynamicShapeModel})); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeQuantizedOpCheckTest, - ModelSerializeOpCheckTest, - ::testing::ValuesIn(kAllQModels)); - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc b/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc deleted file mode 100644 index 55bb72fa0c2961..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
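The deleted model_file_test.cc above centers on a load, serialize, reload round trip. As a compact reference, a hedged sketch of that pattern using the entry points the file included (model_load.h and model_serialize.h); the function name and boolean error handling are illustrative, not from this patch:

```cpp
#include <string>
#include <utility>

#include "tensorflow/lite/experimental/litert/core/model/model_load.h"
#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h"

// Sketch only: load a model from disk, serialize it back to bytes, and
// confirm the bytes reload, mirroring the removed round-trip tests.
bool RoundTrips(const std::string& path) {
  auto model = litert::internal::LoadModelFromFile(path);
  if (!model) return false;
  // SerializeModel consumes the model, as in the deleted tests.
  auto serialized = litert::internal::SerializeModel(std::move(**model));
  if (!serialized) return false;
  return static_cast<bool>(litert::internal::LoadModelFromBuffer(*serialized));
}
```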
- -#include "tensorflow/lite/experimental/litert/core/model/model_file_test_util.h" - -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { - -namespace { - -template -bool EqualsFbQuantizationDetail(LiteRtQType litert_quantization, - const TflQuantization* tfl_quantization) { - return false; -} - -template <> -bool EqualsFbQuantizationDetail( - LiteRtQuantizationPerTensor litert_quantization, - const TflQuantization* tfl_quantization) { - auto tfl_q_params = AsPerTensorQparams(tfl_quantization); - if (!tfl_q_params) return false; - return litert_quantization.zero_point == tfl_q_params->first && - litert_quantization.scale == tfl_q_params->second; -} - -template <> -bool EqualsFbQuantizationDetail( - LiteRtQuantizationPerChannel litert_quantization, - const TflQuantization* tfl_quantization) { - auto tfl_q_params = AsPerChannelQparams(tfl_quantization); - if (!tfl_q_params) return false; - const auto& [quantized_dimension, num_channels, zero_points, scales] = - *tfl_q_params; - const auto qd_eq = - litert_quantization.quantized_dimension == quantized_dimension; - const auto num_chan_eq = litert_quantization.num_channels == num_channels; - const auto zeros_eq = std::equal(zero_points.begin(), zero_points.end(), - litert_quantization.zero_points); - const auto scales_eq = - std::equal(scales.begin(), scales.end(), litert_quantization.scales); - return qd_eq && num_chan_eq && zeros_eq && scales_eq; -} -template -bool EqualsFbTensorTypeDetail(LiteRtTenzorType litert_tensor_type, - const TflTensorType& tfl_tensor) { - LITERT_LOG(LITERT_ERROR, "LiteRtTensorType not supported"); - return false; -} - -template <> -bool EqualsFbTensorTypeDetail( - LiteRtRankedTensorType litert_tensor_type, - const TflTensorType& tfl_tensor_type) { - auto tfl_shape = AsDynamicShape(tfl_tensor_type.second); - if (!tfl_shape) { - LITERT_LOG(LITERT_ERROR, "Not ranked shape"); - return false; - } - - if (MapElementType(tfl_tensor_type.first) != - static_cast(litert_tensor_type.element_type)) { - LITERT_LOG(LITERT_ERROR, "Element type not equal"); - return false; - } - - auto same_or_both_dyn = [](auto l, auto r) { - const auto same_static = l >= 0 && l == r; - const auto both_dyn = l < 0 && r < 0; - return same_static || both_dyn; - }; - - auto& layout = litert_tensor_type.layout; - const bool shape_eq = - AllZip(*tfl_shape, absl::MakeConstSpan(layout.dimensions, layout.rank), - same_or_both_dyn); - if (!shape_eq) { - LITERT_LOG(LITERT_ERROR, "Shapes are not equal"); - return false; - } - - return true; -} - -} // namespace - -bool EqualsFbQuantization(const Quantization& litert_quantization, - const TflQuantization* tfl_quantization) { - switch (litert_quantization.first) { - case kLiteRtQuantizationPerTensor: - return EqualsFbQuantizationDetail(litert_quantization.second.per_tensor, - tfl_quantization); - case kLiteRtQuantizationPerChannel: - return EqualsFbQuantizationDetail(litert_quantization.second.per_channel, - tfl_quantization); - case kLiteRtQuantizationNone: - return !IsQuantized(tfl_quantization); - default: - // Not implemented yet. 
- return false; - } -} - -// Compare tensor type within litert tensor to the type within flatbuffer -// tensor. -bool EqualsFbTensorType(const TensorType& litert_tensor_type, - const TflTensorType& tfl_tensor_type) { - switch (litert_tensor_type.first) { - case kLiteRtRankedTensorType: - return EqualsFbTensorTypeDetail( - litert_tensor_type.second.ranked_tensor_type, tfl_tensor_type); - default: - LITERT_LOG(LITERT_ERROR, "Tensor kind not supported"); - // Not implemented yet. - return false; - } -} - -bool EqualsFbTensor(const LiteRtTensorT& litert_tensor, - const TflTensor& tfl_tensor) { - if (!EqualsFbTensorType(litert_tensor.Type(), - {tfl_tensor.type, TflShapeInfo(tfl_tensor)})) { - LITERT_LOG(LITERT_ERROR, "Tensor not same type"); - return false; - } - - if (!EqualsFbQuantization(litert_tensor.Qparams(), - tfl_tensor.quantization.get())) { - LITERT_LOG(LITERT_ERROR, "Tensor not same quantization"); - return false; - } - - return true; -} - -bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op, - GetTflTensor get_tfl_tensor) { - auto check_tensors = [&](auto& litert_tensors, auto& tfl_tensors) { - if (litert_tensors.size() != tfl_tensors.size()) { - LITERT_LOG(LITERT_ERROR, "Tensors not same size"); - return false; - } - - for (auto i = 0; i < litert_tensors.size(); ++i) { - const auto& fb_tensor = get_tfl_tensor(tfl_tensors.at(i)).get(); - const auto& litert_tensor = *litert_tensors.at(i); - - if (!EqualsFbTensor(litert_tensor, fb_tensor)) { - LITERT_LOG(LITERT_ERROR, "Tensor %d not same", i); - return false; - } - } - - return true; - }; - - return check_tensors(litert_op.Inputs(), tfl_op.inputs) && - check_tensors(litert_op.Outputs(), tfl_op.outputs); -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h b/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h deleted file mode 100644 index df0138e321c063..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { - -// Callback to get a tfl tensor from it's index. -using GetTflTensor = - std::function(uint32_t ind)>; - -// Compare q-params for having the same type and values. -bool EqualsFbQuantization(const Quantization& litert_quantization, - const TflQuantization* tfl_quantization); - -// Compare tensor types for having the same shape and element type. 
-bool EqualsFbTensorType(const TensorType& litert_tensor_type,
-                        const TflTensorType& tfl_tensor_type);
-
-// Compare litert op to flatbuffer op along with their input/output tensors
-// types and quantization. Takes a callback to look up tfl tensors from their
-// indices within the tfl op.
-bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op,
-                GetTflTensor get_tfl_tensor);
-
-// Compare litert tensor to flatbuffer tensor for having same types and
-// quantization.
-bool EqualsFbTensor(const LiteRtTensorT& litert_tensor,
-                    const TflTensor& tfl_tensor);
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_
diff --git a/tensorflow/lite/experimental/litert/core/model/model_graph.cc b/tensorflow/lite/experimental/litert/core/model/model_graph.cc
deleted file mode 100644
index f7a8bb4c80ef61..00000000000000
--- a/tensorflow/lite/experimental/litert/core/model/model_graph.cc
+++ /dev/null
@@ -1,226 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/core/model/model_graph.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/log/absl_check.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_detail.h"
-#include "tensorflow/lite/experimental/litert/core/model/model.h"
-
-namespace litert::internal {
-
-namespace {
-
-bool IsOpDead(const LiteRtOpT& op) {
-  return op.Inputs().empty() && op.Outputs().empty();
-}
-
-bool IsTensorDead(const LiteRtTensorT& tensor) {
-  return tensor.DefiningOp() == nullptr && tensor.NumUses() == 0;
-}
-
-}  // namespace
-
-void CloneTo(const LiteRtTensorT& src, LiteRtTensorT& dest) {
-  dest.SetName({src.Name().cbegin(), src.Name().cend()});
-  dest.SetQarams(src.Qparams());
-  dest.SetType(src.Type());
-
-  // Manually copy per-channel quantization params; the quant array is owned
-  // by the tensor.
-  if (src.Qparams().first == kLiteRtQuantizationPerChannel) {
-    std::vector<float> scales(
-        src.Qparams().second.per_channel.scales,
-        src.Qparams().second.per_channel.scales +
-            src.Qparams().second.per_channel.num_channels);
-    std::vector<int64_t> zero_points(
-        src.Qparams().second.per_channel.zero_points,
-        src.Qparams().second.per_channel.zero_points +
-            src.Qparams().second.per_channel.num_channels);
-    Quantization dest_qparams = MakePerChannelQuantization(
-        scales, zero_points,
-        src.Qparams().second.per_channel.quantized_dimension,
-        [&dest](auto s) { return dest.RequestScratchBuffer(s); });
-    dest.SetQarams(std::move(dest_qparams));
-  }
-
-  // Share or copy the weight buffer from src to dest.
- const auto& src_weights = src.Weights(); - auto& dest_weights = dest.Weights(); - - const auto same_manager = - src_weights.GetBufferManager() == dest_weights.GetBufferManager(); - - if (same_manager) { - dest_weights.SetBufferId(src_weights.GetBufferId()); - } else { - OwningBufferRef weights_buffer(src_weights.Buffer().Data(), - src_weights.Buffer().Size()); - SetWeightsFromOwnedBuffer(dest_weights, std::move(weights_buffer)); - } -} - -void CloneTo(const LiteRtOpT& src, LiteRtOpT& dest) { - dest.SetCustomOptions(src.CustomOptions().Data(), src.CustomOptions().Size()); - litert::internal::SetTflOptions(dest, litert::internal::GetTflOptions(src)); - litert::internal::SetTflOpCodeInd(dest, - litert::internal::GetTflOpCodeInd(src)); - dest.SetOpCode(src.OpCode()); -} - -LiteRtTensorT& MakeClone(LiteRtSubgraphT& parent, const LiteRtTensorT& src) { - auto& new_tensor = parent.EmplaceTensor(); - CloneTo(src, new_tensor); - return new_tensor; -} - -LiteRtOpT& MakeClone(LiteRtSubgraphT& parent, const LiteRtOpT& src) { - auto& new_op = parent.EmplaceOp(); - CloneTo(src, new_op); - return new_op; -} - -std::optional FindInput(const LiteRtOpT& op, - const LiteRtTensorT& tensor) { - return FindInd(op.Inputs().cbegin(), op.Inputs().cend(), &tensor); -} - -std::optional FindOutput(const LiteRtOpT& op, - const LiteRtTensorT& tensor) { - return FindInd(op.Outputs().cbegin(), op.Outputs().cend(), &tensor); -} - -std::optional FindInput(const LiteRtSubgraphT& subgraph, - const LiteRtTensorT& tensor) { - return FindInd(subgraph.Inputs().cbegin(), subgraph.Inputs().cend(), &tensor); -} - -std::optional FindOutput(const LiteRtSubgraphT& subgraph, - const LiteRtTensorT& tensor) { - return FindInd(subgraph.Outputs().cbegin(), subgraph.Outputs().cend(), - &tensor); -} - -UseIndices FindUseInds(const LiteRtTensorT& tensor, const LiteRtOpT& op) { - UseIndices res; - for (auto i = 0; i < tensor.NumUses(); ++i) { - if (tensor.Users().at(i) == &op) { - res.push_back(i); - } - } - return res; -} - -bool IsConstant(const LiteRtTensorT& tensor) { - bool is_zero_sized = false; - auto layout = tensor.Type().second.ranked_tensor_type.layout; - if (layout.rank == 1) { - if (layout.dimensions[0] == 0) { - is_zero_sized = true; - } - } - const auto is_const = tensor.Weights().Buffer().Size() > 0 || is_zero_sized; - ABSL_DCHECK(!is_const || tensor.DefiningOp() == nullptr) - << "Constant tensors should not be defined by an op"; - return is_const; -} - -void AttachInput(LiteRtTensor tensor, LiteRtOpT& op) { - op.Inputs().push_back(tensor); - tensor->Users().push_back(&op); - tensor->UserArgInds().push_back(op.Inputs().size() - 1); -} - -void AttachOutput(LiteRtTensor tensor, LiteRtOpT& op) { - ABSL_DCHECK(tensor->DefiningOp() == nullptr) - << "Cannot add an already defined tensor as op output"; - op.Outputs().push_back(tensor); - tensor->SetDefiningOp(op, op.Outputs().size() - 1); -} - -LiteRtTensor DisconnectInput(LiteRtOpT& op, LiteRtParamIndex input_ind) { - ABSL_DCHECK(input_ind < op.Inputs().size()) << "Removing tensor index oob"; - auto& input = op.Input(input_ind); - - // Find the index of the use for the given in edge. - auto target_use_ind = -1; - for (auto i = 0; i < input.NumUses(); ++i) { - if (input.Users().at(i) == &op && input.UserArgInds().at(i) == input_ind) { - target_use_ind = i; - } - } - ABSL_DCHECK_GE(target_use_ind, 0) << "Malformed graph"; - - // Slide latter input use arg inds to the left. 
- for (auto i = input_ind + 1; i < op.Inputs().size(); ++i) { - auto& r_in = op.Input(i); - for (auto u = 0; u < r_in.NumUses(); ++u) { - auto& r_arg_ind = r_in.UserArgInds().at(u); - if (r_in.Users().at(u) == &op && r_arg_ind > input_ind) { - r_arg_ind -= 1; - } - } - } - - // Update the edges. - input.RemoveUse(target_use_ind); - op.RemoveInput(input_ind); - - return &input; -} - -bool IsIO(const LiteRtSubgraphT& subgraph, const LiteRtTensorT& tensor) { - return FindInput(subgraph, tensor) || FindOutput(subgraph, tensor); -} - -LiteRtTensor DisconnectOutput(LiteRtOpT& op, LiteRtParamIndex output_ind) { - ABSL_DCHECK(output_ind < op.Outputs().size()) << "Removing tensor index oob"; - auto& output = op.Output(output_ind); - output.ClearDefiningOp(); - op.RemoveOutput(output_ind); - return &output; -} - -void Drop(LiteRtOpT& litert_op) { - while (!litert_op.Inputs().empty()) { - DisconnectInput(litert_op, 0); - } - while (!litert_op.Outputs().empty()) { - DisconnectOutput(litert_op, 0); - } -} - -bool DCE(LiteRtSubgraphT& subgraph) { - const auto ops_removed = subgraph.RemoveOpIf(IsOpDead); - - auto rm_tensor = [&subgraph = std::as_const(subgraph)](const auto& t) { - return IsTensorDead(t) && !IsIO(subgraph, t); - }; - const auto tensors_removed = subgraph.RemoveTensorIf(rm_tensor); - LITERT_LOG(LITERT_INFO, "Removed %d ops, %d tensors", ops_removed, - tensors_removed); - - return (ops_removed + tensors_removed) > 0; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_graph.h b/tensorflow/lite/experimental/litert/core/model/model_graph.h deleted file mode 100644 index 55e00e90c833ee..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_graph.h +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_ - -#include -#include - -#include "absl/container/inlined_vector.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -// using IrMapping = absl::flat_hash_map; - -// CLONING - -// Clones the basic data between tensors (like name and data) but not -// things related to incoming/outgoing edges (users, defining op) or weights. -void CloneTo(const LiteRtTensorT& src, LiteRtTensorT& dest); - -// Clones the basic data between ops (like op code and options) but -// things related to incoming/outgoing edges (input/output tensors). 
-void CloneTo(const LiteRtOpT& src, LiteRtOpT& dest);
-
-// Same as clone to, but allocates the dest tensor into given subgraph.
-LiteRtTensorT& MakeClone(LiteRtSubgraphT& parent, const LiteRtTensorT& src);
-
-// Same as clone to, but allocates the dest op into given subgraph.
-LiteRtOpT& MakeClone(LiteRtSubgraphT& parent, const LiteRtOpT& src);
-
-// OBSERVERS
-
-// Checks if tensor is input to given op; returns its index if so.
-std::optional<LiteRtParamIndex> FindInput(const LiteRtOpT& op,
-                                          const LiteRtTensorT& tensor);
-
-// Checks if tensor is output to given op; returns its index if so.
-std::optional<LiteRtParamIndex> FindOutput(const LiteRtOpT& op,
-                                           const LiteRtTensorT& tensor);
-
-// Checks if tensor is input to given subgraph; returns its index if so.
-std::optional<LiteRtParamIndex> FindInput(const LiteRtSubgraphT& subgraph,
-                                          const LiteRtTensorT& tensor);
-
-// Checks if tensor is output to given subgraph; returns its index if so.
-std::optional<LiteRtParamIndex> FindOutput(const LiteRtSubgraphT& subgraph,
-                                           const LiteRtTensorT& tensor);
-
-// Check if tensor is part of subgraph IO.
-bool IsIO(const LiteRtSubgraphT& subgraph, const LiteRtTensorT& tensor);
-
-using UseIndices =
-    absl::InlinedVector;
-
-// Checks if tensor is used by op; returns the use inds for each use of tensor
-// by op (there may be multiple). These are the indexes to call
-// LiteRtTensorT::GetUse with.
-UseIndices FindUseInds(const LiteRtTensorT& tensor, const LiteRtOpT& op);
-
-// Is this tensor a constant tensor?
-bool IsConstant(const LiteRtTensorT& tensor);
-
-// MUTATORS
-
-// Attaches the pre-allocated tensor to be an input of given op.
-void AttachInput(LiteRtTensor tensor, LiteRtOpT& op);
-
-// Attaches the pre-allocated tensor to be an output of given op.
-void AttachOutput(LiteRtTensor tensor, LiteRtOpT& op);
-
-// Remove the input edge from an op. Return the disconnected tensor.
-LiteRtTensor DisconnectInput(LiteRtOpT& op, LiteRtParamIndex input_ind);
-
-// Remove an output edge from an op. Return the disconnected tensor.
-LiteRtTensor DisconnectOutput(LiteRtOpT& op, LiteRtParamIndex output_ind);
-
-// Remove all incoming and outgoing edges from this op. This can prep nodes
-// for removal in DCE.
-void Drop(LiteRtOpT& litert_op);
-
-// Run very naive dead code elimination. Removes only ops/tensors that have no
-// in/out edges. Ops are handled first. Ignores subgraph IO. Not recursive and
-// does only one pass. Returns whether the graph was modified.
-// NOTE: This de-allocates removed objects; only use when references to these
-// objects will not be used.
-// TODO: Update this with complete work-list based approach.
-bool DCE(LiteRtSubgraphT& subgraph);
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_
diff --git a/tensorflow/lite/experimental/litert/core/model/model_graph_test.cc b/tensorflow/lite/experimental/litert/core/model/model_graph_test.cc
deleted file mode 100644
index 4258bc9edb7418..00000000000000
--- a/tensorflow/lite/experimental/litert/core/model/model_graph_test.cc
+++ /dev/null
@@ -1,419 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" - -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" -#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { -namespace { - -using ::testing::UnorderedElementsAreArray; - -// Custom matcher; example: -// ``` -// LiteRtTensor tensor ... -// EXPECT_THAT(tensor, HasRankedType(kLiteRtInt, absl::MakeSpan({2, 2}))); -// ``` -// TODO: Update to use dumping API directly and move to shared header. -MATCHER_P2(HasRankedType, element_type, shape, "") { - if (arg.Type().first != kLiteRtRankedTensorType) { - *result_listener << "Not ranked tensor type"; - return false; - } - const auto& ranked_tensor_type = arg.Type().second.ranked_tensor_type; - const auto& layout = ranked_tensor_type.layout; - - const auto element_type_eq = ranked_tensor_type.element_type == element_type; - const auto rank_eq = layout.rank == std::size(shape); - - auto actual_shape = absl::MakeConstSpan(layout.dimensions, layout.rank); - auto expected_shape = - absl::MakeConstSpan(std::cbegin(shape), std::cend(shape)); - const auto shape_eq = actual_shape == expected_shape; - - if (shape_eq && element_type_eq && rank_eq) { - return true; - } - - *result_listener << "\n"; - if (!shape_eq) { - *result_listener << "Not correct shape\n"; - } - if (!element_type_eq) { - *result_listener << "Not correct element type\n"; - } - if (!rank_eq) { - *result_listener << "Not correct rank\n"; - } - - *result_listener << absl::StreamFormat("Actual ElementType is: %d\n", - ranked_tensor_type.element_type); - *result_listener << absl::StreamFormat("Actual Rank is: %lu\n", layout.rank); - *result_listener << "Actual shape is: { "; - for (const auto d : actual_shape) { - *result_listener << absl::StreamFormat("%d, ", d); - } - *result_listener << "}\n"; - - return false; -} - -using ::testing::ElementsAreArray; - -static constexpr size_t kRank = 1; -static constexpr int32_t kDims[] = {2}; -static constexpr absl::Span kDimsSpan(kDims); -static constexpr auto kType = kLiteRtElementTypeInt32; -static constexpr absl::string_view kCustomOptions = "OPTIONS"; -static constexpr auto kOpCode = kLiteRtOpCodeTflMul; - -LiteRtTensorT TestTensor() { - LiteRtTensorT tensor; - tensor.Type().first = kLiteRtRankedTensorType; - tensor.Type().second.ranked_tensor_type.element_type = kType; - tensor.Type().second.ranked_tensor_type.layout.dimensions[0] = kDims[0]; - tensor.Type().second.ranked_tensor_type.layout.rank = kRank; - return tensor; -} - -LiteRtTensorT& TestTensor(LiteRtTensorT& tensor) { - tensor.Type().first = kLiteRtRankedTensorType; - tensor.Type().second.ranked_tensor_type.element_type = kType; - tensor.Type().second.ranked_tensor_type.layout.dimensions[0] = kDims[0]; - tensor.Type().second.ranked_tensor_type.layout.rank = kRank; - return tensor; -} - -LiteRtOpT TestOp() { - LiteRtOpT op; - op.SetOpCode(kOpCode); - op.SetCustomOptions(kCustomOptions); - return op; -} - -TEST(ModelGraphTest, CloneTensor) { - 
LiteRtTensorT dest; - CloneTo(TestTensor(), dest); - EXPECT_THAT(dest, HasRankedType(kType, kDimsSpan)); -} - -TEST(ModelQuantizationTypeTest, ClonePerChannelQuantization) { - static constexpr std::array kScale = {1.0f, 2.0f}; - static constexpr std::array kZero = {1L, 2L}; - static constexpr int32_t kQdim = 0; - - IrAllocator tensor_allocator; - auto& tensor = tensor_allocator.EmplaceBack(); - LiteRtTensorT dest; - const auto quant = MakePerChannelQuantization( - kScale, kZero, kQdim, - [&tensor](auto s) { return tensor.RequestScratchBuffer(s); }); - - ASSERT_EQ(quant.first, kLiteRtQuantizationPerChannel); - const auto& per_channel = quant.second.per_channel; - - const auto size = per_channel.num_channels; - ASSERT_EQ(size, 2); - EXPECT_EQ(per_channel.quantized_dimension, 0); - tensor.SetQarams(quant); - - CloneTo(tensor, dest); - // Mimic DCE. - tensor_allocator.RemoveIf([](auto& t) { return true; }); - auto dest_quant = dest.Qparams(); - - auto scales = absl::MakeConstSpan(dest_quant.second.per_channel.scales, - dest_quant.second.per_channel.num_channels); - auto zeros = absl::MakeConstSpan(dest_quant.second.per_channel.zero_points, - dest_quant.second.per_channel.num_channels); - - ASSERT_EQ(scales.size(), 2); - ASSERT_EQ(zeros.size(), 2); - EXPECT_THAT(scales, ElementsAreArray(kScale)); - EXPECT_THAT(zeros, ElementsAreArray(kZero)); -} - -TEST(ModelGraphTest, MakeCloneTensor) { - LiteRtSubgraphT subgraph; - auto& dest = MakeClone(subgraph, TestTensor()); - EXPECT_THAT(dest, HasRankedType(kType, kDimsSpan)); -} - -TEST(ModelGraphTest, CloneCstSameManager) { - OwningBufferRef buffer("DATA"); - LiteRtModelT model; - const auto num_buffers = model.Buffers()->NumBuffers(); - auto& sg = model.EmplaceSubgraph(); - auto& src = TestTensor(sg.EmplaceTensor()); - SetWeightsFromUnownedBuffer(src.Weights(), buffer); - auto& dest = MakeClone(sg, src); - EXPECT_EQ(dest.Weights().Buffer().StrView(), buffer.StrView()); - EXPECT_EQ(model.Buffers()->NumBuffers(), num_buffers + 1); - EXPECT_EQ(dest.Weights().GetBufferId(), src.Weights().GetBufferId()); - EXPECT_EQ(dest.Weights().GetBufferManager(), - src.Weights().GetBufferManager()); - EXPECT_EQ(dest.Weights().Buffer().Data(), src.Weights().Buffer().Data()); -} - -TEST(ModelGraphTest, CloneCstDifferentManager) { - OwningBufferRef buffer("DATA"); - LiteRtSubgraphT sg; - auto& src = TestTensor(sg.EmplaceTensor()); - SetWeightsFromUnownedBuffer(src.Weights(), buffer); - auto& dest = MakeClone(sg, src); - EXPECT_EQ(dest.Weights().Buffer().StrView(), buffer.StrView()); - EXPECT_NE(dest.Weights().GetBufferManager(), - src.Weights().GetBufferManager()); - EXPECT_NE(dest.Weights().Buffer().Data(), src.Weights().Buffer().Data()); -} - -TEST(ModelGraphTest, CloneOp) { - LiteRtOpT dest; - CloneTo(TestOp(), dest); - EXPECT_EQ(dest.OpCode(), kOpCode); - EXPECT_EQ(dest.CustomOptions().StrView(), kCustomOptions); -} - -TEST(ModelGraphTest, MakeCloneOp) { - LiteRtSubgraphT subgraph; - auto& dest = MakeClone(subgraph, TestOp()); - EXPECT_EQ(dest.OpCode(), kOpCode); - EXPECT_EQ(dest.CustomOptions().StrView(), kCustomOptions); -} - -TEST(ModelGraphTest, OpFindInput) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachInput(&tensor, op); - auto input = FindInput(op, tensor); - ASSERT_TRUE(input); - EXPECT_EQ(*input, 0); -} - -TEST(ModelGraphTest, OpFindOutput) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachOutput(&tensor, op); - auto output = FindOutput(op, tensor); - ASSERT_TRUE(output); - EXPECT_EQ(*output, 0); -} - -TEST(ModelGraphTest, 
SubgraphFindInput) { - LiteRtSubgraphT subgraph; - auto tensor = TestTensor(); - subgraph.Inputs().push_back(&tensor); - auto input = FindInput(subgraph, tensor); - ASSERT_TRUE(input); - EXPECT_EQ(*input, 0); -} - -TEST(ModelGraphTest, SubgraphFindOutput) { - LiteRtSubgraphT subgraph; - auto tensor = TestTensor(); - subgraph.Outputs().push_back(&tensor); - auto output = FindOutput(subgraph, tensor); - ASSERT_TRUE(output); - EXPECT_EQ(*output, 0); -} - -TEST(ModelGraphTest, TensorFindUseInds) { - auto op1 = TestOp(); - auto op2 = TestOp(); - auto tensor = TestTensor(); - - AttachInput(&tensor, op1); - AttachInput(&tensor, op2); - AttachInput(&tensor, op1); - - auto use_inds = FindUseInds(tensor, op1); - auto uses = GetTensorUses(tensor, use_inds); - ASSERT_EQ(uses.size(), 2); - - LiteRtTensorT::UseVec expected = {{&op1, 0}, {&op1, 1}}; - EXPECT_THAT(uses, UnorderedElementsAreArray(expected)); -} - -TEST(ModelGraphTest, OpAttachInput) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachInput(&tensor, op); - EXPECT_THAT(op.Inputs(), ElementsAreArray({&tensor})); - EXPECT_THAT(tensor.Users(), ElementsAreArray({&op})); - EXPECT_THAT(tensor.UserArgInds(), ElementsAreArray({0})); -} - -TEST(ModelGraphTest, OpAttachOutput) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachOutput(&tensor, op); - EXPECT_THAT(op.Outputs(), ElementsAreArray({&tensor})); - EXPECT_EQ(tensor.DefiningOp(), &op); - EXPECT_EQ(tensor.DefiningOpOutInd(), 0); -} - -TEST(ModelGraphTest, DisconnectInputOp) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachInput(&tensor, op); - auto disconnected = DisconnectInput(op, 0); - EXPECT_EQ(disconnected, &tensor); - EXPECT_TRUE(op.Inputs().empty()); - EXPECT_TRUE(tensor.Users().empty()); - EXPECT_TRUE(tensor.UserArgInds().empty()); -} - -TEST(ModelGraphTest, DisconnectMiddleInputOp) { - auto op = TestOp(); - - auto tensor1 = TestTensor(); - auto tensor2 = TestTensor(); - auto tensor3 = TestTensor(); - - AttachInput(&tensor1, op); - AttachInput(&tensor2, op); - AttachInput(&tensor3, op); - - auto disconnected = DisconnectInput(op, 1); - - EXPECT_EQ(disconnected, &tensor2); - ASSERT_EQ(op.Inputs().size(), 2); - EXPECT_EQ(op.Inputs().front(), &tensor1); - EXPECT_EQ(op.Inputs().back(), &tensor3); - ASSERT_TRUE(tensor2.Users().empty()); - ASSERT_TRUE(tensor2.UserArgInds().empty()); - - ASSERT_TRUE(ValidateLocalTopology(op)); -} - -TEST(ModelGraphTest, DisconnectOutputOp) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachOutput(&tensor, op); - auto disconnected = DisconnectOutput(op, 0); - EXPECT_EQ(disconnected, &tensor); - EXPECT_EQ(tensor.DefiningOp(), nullptr); - EXPECT_TRUE(op.Outputs().empty()); -} - -TEST(ModelGraphTest, DropOp) { - LiteRtOpT op; - - LiteRtTensorT input1; - LiteRtTensorT input2; - LiteRtTensorT output; - - AttachInput(&input1, op); - AttachInput(&input2, op); - AttachOutput(&output, op); - - Drop(op); - - EXPECT_TRUE(op.Inputs().empty()); - EXPECT_TRUE(op.Outputs().empty()); - EXPECT_TRUE(input1.Users().empty()); - EXPECT_TRUE(input2.Users().empty()); - EXPECT_EQ(output.DefiningOp(), nullptr); -} - -TEST(ModelGraphTestDCE, NoDeadCode) { - LiteRtSubgraphT subgraph; - - auto& input = subgraph.EmplaceTensor(); - auto& output = subgraph.EmplaceTensor(); - - auto& op = subgraph.EmplaceOp(); - - AttachInput(&input, op); - AttachOutput(&output, op); - - subgraph.Inputs().push_back(&input); - subgraph.Outputs().push_back(&output); - - ASSERT_FALSE(DCE(subgraph)); - EXPECT_EQ(subgraph.Ops().size(), 1); - 
EXPECT_EQ(subgraph.Tensors().size(), 2); - - ASSERT_TRUE( - ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); - ASSERT_TRUE(ValidateSubgraphIO(subgraph)); -} - -TEST(ModelGraphTestDCE, DeadTensor) { - LiteRtSubgraphT subgraph; - subgraph.EmplaceTensor(); - - ASSERT_TRUE(DCE(subgraph)); - EXPECT_TRUE(subgraph.Tensors().empty()); - - ASSERT_TRUE( - ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); - ASSERT_TRUE(ValidateSubgraphIO(subgraph)); -} - -TEST(ModelGraphTestDCE, DeadOp) { - LiteRtSubgraphT subgraph; - subgraph.EmplaceOp(); - - ASSERT_TRUE(DCE(subgraph)); - EXPECT_TRUE(subgraph.Ops().empty()); - - ASSERT_TRUE( - ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); - ASSERT_TRUE(ValidateSubgraphIO(subgraph)); -} - -TEST(ModelGraphTestDCE, SomeDead) { - LiteRtSubgraphT subgraph; - - auto& input = subgraph.EmplaceTensor(); - auto& output = subgraph.EmplaceTensor(); - - auto& op = subgraph.EmplaceOp(); - - AttachInput(&input, op); - AttachOutput(&output, op); - - // Dead - subgraph.EmplaceTensor(); - subgraph.EmplaceOp(); - - subgraph.Inputs().push_back(&input); - subgraph.Outputs().push_back(&output); - - ASSERT_TRUE(DCE(subgraph)); - EXPECT_EQ(subgraph.Ops().size(), 1); - EXPECT_EQ(subgraph.Tensors().size(), 2); - - ASSERT_TRUE( - ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); - ASSERT_TRUE(ValidateSubgraphIO(subgraph)); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_load.cc b/tensorflow/lite/experimental/litert/core/model/model_load.cc deleted file mode 100644 index 86f11054608b05..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_load.cc +++ /dev/null @@ -1,432 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
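(Context for the loader deleted below: its two entry points are declared in model_load.h, which is deleted further down in this patch. A minimal usage sketch, assuming nothing beyond those declarations; the file path and the helper name LoadAndCheck are illustrative:

  #include "tensorflow/lite/experimental/litert/core/model/model_load.h"

  // Loads a .tflite file into the LiteRT IR. On failure the Expected wrapper
  // carries an Error with a LiteRtStatus instead of throwing.
  bool LoadAndCheck() {
    auto model = litert::internal::LoadModelFromFile("/tmp/model.tflite");
    if (!model) {
      return false;  // model.Error().Status() describes what went wrong.
    }
    LiteRtModelT& ir = **model;  // The Expected wraps a std::unique_ptr.
    (void)ir;
    return true;
  }
)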
- -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { -namespace { - -// Provides a view of model-level resources when constructing litert graph. -class FlatbufferContext { - public: - using LiteRtBufferId = uint32_t; - using TflBufferInd = uint32_t; - using BufferIdMap = absl::flat_hash_map; - - FlatbufferContext(const FlatbufferWrapper& tfl_flatbuffer, - BufferManager* buffer_manager) - : tfl_flatbuffer_(tfl_flatbuffer), buffer_manager_(buffer_manager) {} - - void SetOpCode(LiteRtOpT& litert_op, uint32_t ind) { - const auto builtin_code = - PackedModel()->operator_codes()->Get(ind)->builtin_code(); - litert_op.SetOpCode(static_cast(builtin_code)); - litert::internal::SetTflOpCodeInd(litert_op, ind); - } - - // Get the buffer at the given index in the tflite model. - Expected GetTflBuffer(uint32_t ind) const { - const auto* packed_model = tfl_flatbuffer_.PackedModel(); - if (ind >= packed_model->buffers()->size()) { - LITERT_LOG(LITERT_ERROR, "Buffer index out of range"); - return Error(kLiteRtStatusErrorInvalidArgument); - } - return packed_model->buffers()->Get(ind); - } - - BufferManager* GetBufferManager() { return buffer_manager_; } - - const uint8_t* AllocBase() const { return tfl_flatbuffer_.AllocBase(); } - - const TflPackedModel* PackedModel() const { - return tfl_flatbuffer_.PackedModel(); - } - - BufferIdMap& RegisteredTflBufferIds() { return registered_tfl_buffer_ids_; } - - private: - const FlatbufferWrapper& tfl_flatbuffer_; - BufferManager* buffer_manager_; - BufferIdMap registered_tfl_buffer_ids_; -}; - -LiteRtStatus UnpackOp(FlatbufferContext& context, LiteRtSubgraphT& parent, - const TflPackedOp& tfl_op, LiteRtOpT& litert_op) { - // I/O TENSORS - - if (tfl_op.intermediates() && tfl_op.intermediates()->size() != 0) { - // TODO: b/365299994 - Support intermediates. - LITERT_LOG(LITERT_ERROR, "Intermediate tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (tfl_op.mutating_variable_inputs() && - tfl_op.mutating_variable_inputs()->size() != 0) { - // TODO: b/365299994 - Support mutating variable inputs. - LITERT_LOG(LITERT_ERROR, "Mutating variable inputs not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - const auto num_inputs = tfl_op.inputs()->size(); - for (auto i = 0; i < num_inputs; ++i) { - const auto input_ind = tfl_op.inputs()->Get(i); - // Skipping optional input tensor. 
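(Context: the TFLite schema encodes omitted optional inputs as tensor index -1, which is what the check below relies on.)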
- if (input_ind == -1) { - continue; - } - AttachInput(&parent.Tensor(input_ind), litert_op); - } - - const auto num_outputs = tfl_op.outputs()->size(); - for (auto i = 0; i < num_outputs; ++i) { - const auto output_ind = tfl_op.outputs()->Get(i); - AttachOutput(&parent.Tensor(output_ind), litert_op); - } - - // OPTIONS - - if (tfl_op.large_custom_options_size() != 0) { - // TODO: b/365299994 - Support large custom options. - LITERT_LOG(LITERT_ERROR, "Large custom options not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - const auto* custom_opts = tfl_op.custom_options(); - if (custom_opts) { - litert_op.SetCustomOptions(custom_opts->data(), custom_opts->size()); - } - - // TODO figure out how to parse builtins with the packed flatbuffer api. - TflOpPtr tfl_op_ptr(tfl_op.UnPack()); - litert::internal::SetTflOptions(litert_op, - std::move(tfl_op_ptr->builtin_options)); - litert::internal::SetTflOptions2(litert_op, - std::move(tfl_op_ptr->builtin_options_2)); - - // OP CODE - - context.SetOpCode(litert_op, tfl_op.opcode_index()); - - return kLiteRtStatusOk; -} - -struct TflBufferContext { - BufferRef buffer; - // Is buffer appended to the flatbuffer? - bool is_external; -}; - -Expected ReadBuffer(FlatbufferContext& context, - uint32_t buffer_ind) { - auto buffer = context.GetTflBuffer(buffer_ind); - if (!buffer) { - return buffer.Error(); - } - - const auto& tfl_buffer = **buffer; - - if (tfl_buffer.offset() != 0) { - // Data is appended to the end of the flatbuffer. - - const auto* alloc_base = context.AllocBase(); - const auto offset = tfl_buffer.offset(); - const auto size = tfl_buffer.size(); - - return TflBufferContext{BufferRef(alloc_base + offset, size), - true}; - } else if (tfl_buffer.data()) { - // Data is in the flatbuffer. - - const auto* start = tfl_buffer.data()->data(); - const auto size = tfl_buffer.data()->size(); - - return TflBufferContext{BufferRef(start, size), false}; - } else { - return TflBufferContext{}; - } -} - -LiteRtStatus UnpackTensor(FlatbufferContext& context, - const TflPackedTensor& tfl_tensor, - LiteRtTensorT& litert_tensor) { - const auto buffer_ind = tfl_tensor.buffer(); - if (buffer_ind != 0) { - auto buffer = ReadBuffer(context, buffer_ind); - if (!buffer) { - return buffer.Error().Status(); - } - - auto it = context.RegisteredTflBufferIds().find(buffer_ind); - if (it != context.RegisteredTflBufferIds().end()) { - litert_tensor.Weights().SetBufferId(it->second); - } else { - BufferContext lrt_buf_ctx; - lrt_buf_ctx.should_append = buffer->is_external; - SetWeightsFromUnownedBuffer(litert_tensor.Weights(), buffer->buffer, - lrt_buf_ctx); - context.RegisteredTflBufferIds()[buffer_ind] = - litert_tensor.Weights().GetBufferId(); - } - } - - // TENSOR TYPE - - TflTensorType tfl_tensor_type(tfl_tensor.type(), TflShapeInfo(tfl_tensor)); - auto tensor_type = MapTensorType(tfl_tensor_type); - if (!tensor_type) { - return tensor_type.Error().Status(); - } - - litert_tensor.SetType(std::move(*tensor_type)); - - // QUANTIZATION - - if (tfl_tensor.quantization()) { - TflQuantizationPtr tfl_quantization(tfl_tensor.quantization()->UnPack()); - auto quantization = MapQuantization(tfl_quantization.get(), litert_tensor); - if (!quantization) { - return quantization.Error().Status(); - } - litert_tensor.SetQarams(std::move(*quantization)); - } - - // MISC - - if (tfl_tensor.name()) { - litert_tensor.SetName(tfl_tensor.name()->str()); - } - - if (tfl_tensor.is_variable()) { - // TODO: b/365299994 - Support variable tensors. 
-    LITERT_LOG(LITERT_ERROR, "Variable tensors not yet supported.");
-    return kLiteRtStatusErrorUnsupported;
-  }
-
-  if (tfl_tensor.variant_tensors() &&
-      tfl_tensor.variant_tensors()->size() != 0) {
-    // TODO: b/365299994 - Support variant tensors.
-    LITERT_LOG(LITERT_ERROR, "Variant tensors not yet supported.");
-    return kLiteRtStatusErrorUnsupported;
-  }
-
-  if (tfl_tensor.sparsity() != nullptr) {
-    // TODO: b/365299994 - Support sparsity tensors.
-    LITERT_LOG(LITERT_ERROR, "Sparsity tensors not yet supported.");
-    return kLiteRtStatusErrorUnsupported;
-  }
-
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus UnpackSubgraph(FlatbufferContext& context,
-                            const TflPackedSubgraph& tfl_subgraph,
-                            LiteRtSubgraphT& litert_subgraph) {
-  // Unpack tensors.
-  const auto num_tensors = tfl_subgraph.tensors()->size();
-  for (auto i = 0; i < num_tensors; ++i) {
-    const auto* tfl_tensor = tfl_subgraph.tensors()->Get(i);
-    LITERT_RETURN_IF_ERROR(
-        UnpackTensor(context, *tfl_tensor, litert_subgraph.EmplaceTensor()));
-  }
-
-  // Unpack ops; pass litert_subgraph so they can look up the new litert
-  // tensors.
-  const auto num_ops = tfl_subgraph.operators()->size();
-  for (auto i = 0; i < num_ops; ++i) {
-    const auto* tfl_op = tfl_subgraph.operators()->Get(i);
-    LITERT_RETURN_IF_ERROR(UnpackOp(context, litert_subgraph, *tfl_op,
-                                    litert_subgraph.EmplaceOp()));
-  }
-
-  // Update subgraph I/O.
-  const auto num_inputs = tfl_subgraph.inputs()->size();
-  for (auto i = 0; i < num_inputs; ++i) {
-    const auto tfl_input_ind = tfl_subgraph.inputs()->Get(i);
-    litert_subgraph.Inputs().push_back(&litert_subgraph.Tensor(tfl_input_ind));
-  }
-  const auto num_outputs = tfl_subgraph.outputs()->size();
-  for (auto i = 0; i < num_outputs; ++i) {
-    const auto tfl_output_ind = tfl_subgraph.outputs()->Get(i);
-    litert_subgraph.Outputs().push_back(
-        &litert_subgraph.Tensor(tfl_output_ind));
-  }
-
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus UnpackSignatures(std::vector<TflSignaturePtr>& tfl_signatures,
-                              LiteRtModelT& parent) {
-  for (auto& tfl_signature : tfl_signatures) {
-    if (tfl_signature->subgraph_index >= parent.Subgraphs().size()) {
-      LITERT_LOG(LITERT_ERROR,
-                 "Signature does not refer to a valid subgraph index.");
-      return kLiteRtStatusErrorInvalidArgument;
-    }
-
-    auto* litert_subgraph =
-        parent.Subgraphs().at(tfl_signature->subgraph_index);
-
-    auto& tfl_inputs = tfl_signature->inputs;
-    auto& tfl_outputs = tfl_signature->outputs;
-
-    // Tflite signatures map a tensor index to a name. The input & output
-    // indexes of signatures and subgraphs are not matched, but the number of
-    // inputs and outputs should be the same.
-    if (tfl_inputs.size() != litert_subgraph->Inputs().size() ||
-        tfl_outputs.size() != litert_subgraph->Outputs().size()) {
-      LITERT_LOG(LITERT_ERROR,
-                 "Signature has an incorrect number of inputs/outputs");
-      return kLiteRtStatusErrorInvalidFlatbuffer;
-    }
-
-    // The tensor names may not match between the signature and the subgraph.
-    // Update the tensor names with the signature names, since the signature
-    // names are used for LiteRT APIs.
- for (auto i = 0; i < tfl_inputs.size(); ++i) { - const auto& tfl_input = tfl_inputs.at(i); - auto* index_litert_input = - litert_subgraph->Tensors().at(tfl_input->tensor_index); - index_litert_input->SetName(tfl_input->name); - } - for (auto i = 0; i < tfl_outputs.size(); ++i) { - const auto& tfl_output = tfl_outputs.at(i); - auto* index_litert_output = - litert_subgraph->Tensors().at(tfl_output->tensor_index); - index_litert_output->SetName(tfl_output->name); - } - - // Keep signature input/output names in the same order as the subgraph. - std::vector input_names; - input_names.reserve(tfl_inputs.size()); - for (auto& tensor : litert_subgraph->Inputs()) { - input_names.push_back(std::string(tensor->Name())); - } - std::vector output_names; - output_names.reserve(tfl_outputs.size()); - for (auto& tensor : litert_subgraph->Outputs()) { - output_names.push_back(std::string(tensor->Name())); - } - - parent.EmplaceSignature(litert_subgraph, std::move(input_names), - std::move(output_names), - tfl_signature->signature_key); - } - - if (tfl_signatures.empty()) { - parent.EmplaceSignature(MakeDefaultSignature(parent.MainSubgraph())); - } - - return kLiteRtStatusOk; -} - -Expected UnpackModel(FlatbufferWrapper&& flatbuffer) { - auto litert_model = std::make_unique(std::move(flatbuffer)); - - FlatbufferContext context(litert::internal::GetTflFlatbuffer(*litert_model), - litert_model->Buffers()); - const auto* packed_model = context.PackedModel(); - - if (packed_model->subgraphs()) { - const auto num_subgraphs = packed_model->subgraphs()->size(); - for (auto i = 0; i < num_subgraphs; ++i) { - const auto* tfl_subgraph = packed_model->subgraphs()->Get(i); - LITERT_RETURN_IF_ERROR(UnpackSubgraph(context, *tfl_subgraph, - litert_model->EmplaceSubgraph())); - } - } - - // TODO Figure out how to load signatures in packed flatbuffer. 
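(Context: pending the TODO above, the block below UnPack()s each packed SignatureDef into its object-API form before processing; when the model carries no signatures at all, a default signature is synthesized from the main subgraph.)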
- if (packed_model->signature_defs()) { - std::vector tfl_signatures; - for (auto i = 0; i < packed_model->signature_defs()->size(); ++i) { - const auto* tfl_signature = packed_model->signature_defs()->Get(i); - tfl_signatures.push_back(TflSignaturePtr(tfl_signature->UnPack())); - } - LITERT_RETURN_IF_ERROR(UnpackSignatures(tfl_signatures, *litert_model)); - } else { - litert_model->EmplaceSignature( - MakeDefaultSignature(litert_model->MainSubgraph())); - } - - if (packed_model->metadata()) { - const auto num_metadata = packed_model->metadata()->size(); - for (auto i = 0; i < num_metadata; ++i) { - const auto* tfl_metadata = packed_model->metadata()->Get(i); - auto name = tfl_metadata->name()->str(); - const auto buf_id = tfl_metadata->buffer(); - auto buf = ReadBuffer(context, buf_id); - if (!buf) { - return buf.Error(); - } - - litert_model->PushMetadata(name, buf->buffer.Data(), buf->buffer.Size()); - } - } - - if (packed_model->operator_codes()) { - const auto num_operator_codes = packed_model->operator_codes()->size(); - std::vector tfl_op_codes(num_operator_codes); - for (auto i = 0; i < num_operator_codes; ++i) { - const auto* tfl_op_code = packed_model->operator_codes()->Get(i); - TflOpCodePtr tfl_op_code_ptr(tfl_op_code->UnPack()); - tfl_op_codes[i] = std::move(tfl_op_code_ptr); - } - litert::internal::SetTflOpCodes(*litert_model, std::move(tfl_op_codes)); - } - - return litert_model; -} - -} // namespace - -Expected LoadModelFromBuffer(BufferRef buffer) { - auto flatbuffer = FlatbufferWrapper::CreateFromBuffer(buffer); - if (!flatbuffer) { - return flatbuffer.Error(); - } - return UnpackModel(std::move(**flatbuffer)); -} - -Expected LoadModelFromFile(absl::string_view filename) { - auto flatbuffer = FlatbufferWrapper::CreateFromTflFile(filename); - if (!flatbuffer) { - return flatbuffer.Error(); - } - return UnpackModel(std::move(**flatbuffer)); -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_load.h b/tensorflow/lite/experimental/litert/core/model/model_load.h deleted file mode 100644 index b6a8c2cdd0f650..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_load.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_LOAD_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_LOAD_H_ - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -Expected> LoadModelFromFile( - absl::string_view filename); - -Expected> LoadModelFromBuffer( - BufferRef buffer); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_LOAD_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_serialize.cc b/tensorflow/lite/experimental/litert/core/model/model_serialize.cc deleted file mode 100644 index bc3f1467fd8a6f..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_serialize.cc +++ /dev/null @@ -1,559 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// schema/mutable/schema_generated.h and schema/schema_generated.h (included -// through flatbuffer_tools.h via model.h) have the same #ifdef, thus this line -// need to be put at the top to ensure we get the "mutable" version. -#if 1 -#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" -#endif - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/core/insert_order_map.h" -#include "tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/mutable/schema_generated.h" - -namespace litert::internal { -namespace { - -using TensorMap = absl::flat_hash_map; - -// This is expected to be used to serialize the dispatch op custom code. -TflOpCodePtr MakeCustomOpCode(std::string custom_code_name) { - auto custom_code = std::make_unique(); - custom_code->builtin_code = ::tflite::BuiltinOperator_CUSTOM; - custom_code->custom_code = std::move(custom_code_name); - custom_code->version = 1; - return custom_code; -} - -// Utility for accessing flatbuffer state and other relevant state. 
-class SerializationContext { - public: - // Subgraph and op index pair. - using TflOpInd = std::pair; - using TflOpAssetMap = - absl::flat_hash_map; - using TflBufferInd = uint32_t; - using TflOffsetTensorMap = - absl::flat_hash_map; - using TflBufferIdMap = - absl::flat_hash_map; - - explicit SerializationContext(uint32_t dispatch_op_code_ind, - LiteRtModelT& litert_model, - size_t bytecode_alignment) - : tfl_model_(std::make_unique()), - dispatch_op_code_ind_(dispatch_op_code_ind), - litert_model_(litert_model), - bytecode_alignment_(bytecode_alignment) { - // Tfl expects empty buffer 0. - tfl_model_->buffers.push_back(std::make_unique()); - } - - TflModel& Model() { return *tfl_model_.get(); } - - TflModelPtr Release() && { return std::move(tfl_model_); } - - LiteRtModelT& LitertModel() { return litert_model_; } - - size_t BytecodeAlignment() const { return bytecode_alignment_; } - - LiteRtStatus HandleTensorBuffer(TflTensor& tfl_tensor, - const LiteRtTensorT& litert_tensor) { - const auto litert_buf_id = litert_tensor.Weights().GetBufferId(); - auto* buffer_manager = litert_tensor.Weights().GetBufferManager(); - - auto litert_buf_ctx = buffer_manager->GetContext(litert_buf_id); - if (!litert_buf_ctx) { - LITERT_LOG(LITERT_ERROR, "Failed to get buffer context"); - return litert_buf_ctx.Error().Status(); - } - - auto litert_buf = buffer_manager->GetBuffer(litert_buf_id); - if (!litert_buf) { - LITERT_LOG(LITERT_ERROR, "Failed to get buffer"); - return litert_buf.Error().Status(); - } - - TflBufferInd tfl_buffer_ind; - if (buffer_id_map_.contains(litert_buf_id)) { - tfl_buffer_ind = buffer_id_map_.at(litert_buf_id); - } else { - auto& tfl_buffer = - tfl_model_->buffers.emplace_back(std::make_unique()); - tfl_buffer_ind = tfl_model_->buffers.size() - 1; - - if (litert_buf_ctx->get().should_append) { - tfl_buffer->offset = 1; - tfl_buffer->size = 1; - offset_tensor_map_.emplace(tfl_buffer_ind, litert_buf_id); - } else { - tfl_buffer->data.assign(litert_buf->Data(), - litert_buf->Data() + litert_buf->Size()); - } - buffer_id_map_[litert_buf_id] = tfl_buffer_ind; - } - - tfl_tensor.buffer = tfl_buffer_ind; - - return kLiteRtStatusOk; - } - - // Add to tfl model metadata. - void PushMetadata(std::string key, BufferRef data) { - auto& tfl_buffer = - tfl_model_->buffers.emplace_back(std::make_unique()); - const auto tfl_buffer_ind = tfl_model_->buffers.size() - 1; - tfl_buffer->data.assign(data.Data(), data.Data() + data.Size()); - tfl_model_->metadata_buffer.push_back(tfl_buffer_ind); - auto tfl_metadata = std::make_unique(); - tfl_metadata->name = key; - tfl_metadata->buffer = tfl_buffer_ind; - tfl_model_->metadata.push_back(std::move(tfl_metadata)); - } - - // Keep track of the given ops index as having a particular asset. - // These will be used to update the ops with the correct offset and size - // after the model is fully packed. - void AttachAssetToOp(size_t subgraph_ind, size_t op_ind, - LiteRtModelT::OpAssetReference asset) { - TflOpInd tfl_op_ind = {subgraph_ind, op_ind}; - op_asset_map_.emplace(tfl_op_ind, asset); - } - - const TflOpAssetMap& OpAssetMap() const { return op_asset_map_; } - - const TflOffsetTensorMap& OffsetTensorMap() const { - return offset_tensor_map_; - } - - // Get the index in the tfl op codes for the dispatch custom code. - // This should be the only new custom code added after loading the initial - // tfl. 
- uint32_t DispatchOpCodeInd() const { return dispatch_op_code_ind_; } - - private: - TflModelPtr tfl_model_; - uint32_t dispatch_op_code_ind_; - LiteRtModelT& litert_model_; - - TflOpAssetMap op_asset_map_; - TflOffsetTensorMap offset_tensor_map_; - TflBufferIdMap buffer_id_map_; - size_t bytecode_alignment_ = 0; -}; - -void SetOptions(const LiteRtOpT& litert_op, TflOp& tfl_op) { - tfl_op.builtin_options = litert::internal::GetTflOptions(litert_op); - if (litert_op.CustomOptions().Size() != 0) { - tfl_op.custom_options = litert_op.CustomOptions().ToVec(); - tfl_op.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS; - } -} - -LiteRtStatus PackOp(SerializationContext& builder, LiteRtOpT& litert_op, - TflOp& tfl_op, const TensorMap& tensor_map) { - // Get index of the op code in the tfl model. - auto tfl_op_code_ind = litert::internal::GetTflOpCodeInd(litert_op); - const bool is_dispatch_op = - tfl_op_code_ind == litert::internal::kDispatchOpCodeTflInd; - - if (is_dispatch_op) { - tfl_op_code_ind = builder.DispatchOpCodeInd(); - } - - tfl_op.opcode_index = tfl_op_code_ind; - - // Look up the tensor indices in the tfl model. - for (auto* in : litert_op.Inputs()) { - tfl_op.inputs.push_back(tensor_map.at(in)); - } - for (auto* out : litert_op.Outputs()) { - tfl_op.outputs.push_back(tensor_map.at(out)); - } - - // Set generic options. - tfl_op.builtin_options = litert::internal::GetTflOptions(litert_op); - - return kLiteRtStatusOk; -} - -LiteRtStatus PackTensor(SerializationContext& builder, - LiteRtTensorT& litert_tensor, TflTensor& tfl_tensor) { - auto tfl_tensor_type = MapTensorType(litert_tensor.Type()); - if (!tfl_tensor_type) { - return tfl_tensor_type.Error().Status(); - } - auto [tfl_elem_type, tfl_shape] = *tfl_tensor_type; - - tfl_tensor.type = tfl_elem_type; - tfl_tensor.shape.assign(tfl_shape.shape.begin(), tfl_shape.shape.end()); - tfl_tensor.has_rank = tfl_shape.has_rank; - tfl_tensor.shape_signature.assign(tfl_shape.shape_signature.begin(), - tfl_shape.shape_signature.end()); - - auto tfl_quantization = MapQuantization(litert_tensor.Qparams()); - if (!tfl_quantization) { - return tfl_quantization.Error().Status(); - } - tfl_tensor.quantization = std::move(*tfl_quantization); - - LITERT_RETURN_IF_ERROR(builder.HandleTensorBuffer(tfl_tensor, litert_tensor)); - - tfl_tensor.name = std::string(litert_tensor.Name()); - - return kLiteRtStatusOk; -} - -LiteRtStatus PackSubgraph(SerializationContext& builder, - LiteRtSubgraphT& litert_subgraph, - TflSubgraph& tfl_subgraph, TensorMap& tensor_map, - size_t subgraph_ind) { - for (auto* tensor : litert_subgraph.Tensors()) { - tfl_subgraph.tensors.push_back(std::make_unique()); - tensor_map.insert({tensor, tfl_subgraph.tensors.size() - 1}); - LITERT_RETURN_IF_ERROR( - PackTensor(builder, *tensor, *tfl_subgraph.tensors.back())); - } - - for (auto i = 0; i < litert_subgraph.Ops().size(); ++i) { - auto* op = litert_subgraph.Ops().at(i); - - tfl_subgraph.operators.push_back(std::make_unique()); - auto& tfl_op = *tfl_subgraph.operators.back(); - LITERT_RETURN_IF_ERROR(PackOp(builder, *op, tfl_op, tensor_map)); - - // Set custom options. - if (auto op_asset = builder.LitertModel().FindOpAsset(op)) { - // This mechanism is currently only used for dispatch ops to store - // location of bytecode. Here we update the name and placeholder values - // for offset and size. These will be updated when the model is fully - // packed. 
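(Context: the two 1s passed below are placeholder bytecode offset/size values; SerializeWithAppendedBuffers later rewrites them in place with the real locations once the asset buffers have been appended after the flatbuffer.)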
- auto dispatch_opts = MakeDispatchOpOptions({ - 1, - 1, - std::string(op_asset->second), - }); - tfl_op.custom_options = dispatch_opts.ToVec(); - - // Save the "location" of the op and its asset. - builder.AttachAssetToOp(subgraph_ind, i, *op_asset); - - } else if (op->CustomOptions().Size() != 0) { - tfl_op.custom_options = op->CustomOptions().ToVec(); - } - - tfl_op.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS; - } - - for (auto* in : litert_subgraph.Inputs()) { - tfl_subgraph.inputs.push_back(tensor_map.at(in)); - } - - for (auto* out : litert_subgraph.Outputs()) { - tfl_subgraph.outputs.push_back(tensor_map.at(out)); - } - - return kLiteRtStatusOk; -} - -Expected PackAsTflite(SerializationContext& builder) { - auto& litert_model = builder.LitertModel(); - - // Pack litert subgraphs into tfl subgraphs and save the mapping of - // tensors. - TensorMap tensor_map; - for (auto i = 0; i < litert_model.Subgraphs().size(); ++i) { - auto& litert_subgraph = litert_model.Subgraph(i); - auto& tfl_subgraph = *builder.Model().subgraphs.emplace_back( - std::make_unique()); - LITERT_RETURN_IF_ERROR( - PackSubgraph(builder, litert_subgraph, tfl_subgraph, tensor_map, i)); - } - - // Serialize the signatures using saved tensor mapping. - for (auto* litert_signature : litert_model.Signatures()) { - auto* litert_subgraph = &litert_signature->GetSubgraph(); - - auto& tfl_signature = *builder.Model().signature_defs.emplace_back( - std::make_unique()); - tfl_signature.signature_key = std::string(litert_signature->Key()); - - auto begin = litert_model.Subgraphs().cbegin(); - auto end = litert_model.Subgraphs().cend(); - const auto litert_subgraph_ind = - std::find(begin, end, litert_subgraph) - begin; - tfl_signature.subgraph_index = litert_subgraph_ind; - - auto input_ind = 0; - for (const auto& litert_name : litert_signature->InputNames()) { - auto& tfl_input = *tfl_signature.inputs.emplace_back( - std::make_unique<::tflite::TensorMapT>()); - tfl_input.name = litert_name; - tfl_input.tensor_index = - tensor_map.find(litert_subgraph->Inputs().at(input_ind))->second; - ++input_ind; - } - - auto output_ind = 0; - for (const auto& litert_name : litert_signature->OutputNames()) { - auto& tfl_output = *tfl_signature.outputs.emplace_back( - std::make_unique<::tflite::TensorMapT>()); - tfl_output.name = litert_name; - tfl_output.tensor_index = - tensor_map.find(litert_subgraph->Outputs().at(output_ind))->second; - ++output_ind; - } - } - - // Serialize metadata. - for (auto it = litert_model.MetadataBegin(); it != litert_model.MetadataEnd(); - ++it) { - const auto& [key, buf_id] = *it; - auto buf = litert_model.Buffers()->GetBuffer(buf_id); - if (!buf) { - LITERT_LOG(LITERT_ERROR, "Failed to find metadata buffer"); - return buf.Error(); - } - builder.PushMetadata(key, *buf); - } - - builder.Model().version = 3; - - return std::move(builder).Release(); -} - -// Appends external buffers to the back of the serialized tflite model. Updates -// the ops that references them with the correct offset and size in-place. -Expected> SerializeWithAppendedBuffers( - SerializationContext& builder, OwningBufferRef serialized_tfl, - LiteRtModelT& litert_model) { - if (builder.OpAssetMap().empty() && builder.OffsetTensorMap().empty()) { - return serialized_tfl; - } - - const auto align = builder.BytecodeAlignment(); - // Pad the original model to the next multiple of the alignment. 
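(The align_offset lambda that follows uses the standard power-of-two round-up identity. A standalone equivalent for reference, under the same implicit assumption the bit trick requires, namely that the alignment is a power of two; the helper name AlignUp is illustrative:

  #include <cstddef>

  // Rounds x up to the next multiple of the power-of-two alignment a,
  // e.g. AlignUp(13, 8) == 16 and AlignUp(16, 8) == 16.
  constexpr std::size_t AlignUp(std::size_t x, std::size_t a) {
    return (x + a - 1) & ~(a - 1);
  }
)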
- auto align_offset = [align](size_t& cur_offset) { - cur_offset = (cur_offset + align - 1) & ~(align - 1); - }; - - size_t cur_offset = serialized_tfl.Size(); - align_offset(cur_offset); - - // Calculate the offset and size of each op asset. - InsertOrderMap> - asset_buffer_offsets; - for (auto it = builder.OpAssetMap().cbegin(); - it != builder.OpAssetMap().cend(); ++it) { - const auto& [buf_id, name] = it->second; - auto asset_buf = litert_model.Buffers()->GetBuffer(buf_id); - if (!asset_buf) { - return asset_buf.Error(); - } - if (asset_buffer_offsets.Contains(buf_id)) { - continue; - } - asset_buffer_offsets.InsertOrAssign(buf_id, - {cur_offset, asset_buf->Size()}); - cur_offset += asset_buf->Size(); - align_offset(cur_offset); - } - - // Calculate the offset and size of each offset tensor. - InsertOrderMap> - offset_tensor_offsets; - for (auto it = builder.OffsetTensorMap().cbegin(); - it != builder.OffsetTensorMap().cend(); ++it) { - const auto& [tfl_buffer_ind, litert_buf_id] = *it; - auto litert_buf = litert_model.Buffers()->GetBuffer(litert_buf_id); - if (!litert_buf) { - LITERT_LOG(LITERT_ERROR, "Failed to find offset tensor buffer"); - return litert_buf.Error(); - } - if (offset_tensor_offsets.Contains(tfl_buffer_ind)) { - continue; - } - offset_tensor_offsets.InsertOrAssign(tfl_buffer_ind, - {cur_offset, litert_buf->Size()}); - cur_offset += litert_buf->Size(); - } - - // Read serialized tflite in packed form. - auto* tfl_model = tflite::GetMutableModel(serialized_tfl.Data()); - - // Find the ops that have external buffers and mark them with the future size - // and offset. - for (auto sg_ind = 0; sg_ind < tfl_model->mutable_subgraphs()->size(); - ++sg_ind) { - auto* sg = tfl_model->mutable_subgraphs()->GetMutableObject(sg_ind); - - for (auto op_ind = 0; op_ind < sg->mutable_operators()->size(); ++op_ind) { - SerializationContext::TflOpInd ind = {sg_ind, op_ind}; - - auto asset_buffer = builder.OpAssetMap().find(ind); - if (asset_buffer == builder.OpAssetMap().end()) { - // No external buffer for this op. - continue; - } - - auto* op = sg->mutable_operators()->GetMutableObject(op_ind); - - // The id of the buffer in the litert model. - const auto buf_id = asset_buffer->second.first; - - // The real offset and size of the buffer in the serialized tflite model. - const auto offset_and_size = asset_buffer_offsets.Find(buf_id); - if (!offset_and_size) { - LITERT_LOG(LITERT_ERROR, "Failed to find offset and size for buffer"); - return Error(kLiteRtStatusErrorInvalidFlatbuffer); - } - const auto [offset, size] = offset_and_size->get().second; - - // The custom options should have already been set with the name and - // placeholder values for size and offset. - MutableBufferRef old_raw_opts( - op->mutable_custom_options()->data(), - op->mutable_custom_options()->size()); - - // Update with real size and offset. - DispatchOpOptions dispach_opts(GetDispatchOpOptions(old_raw_opts)); - dispach_opts.bytecode_offset = offset; - dispach_opts.bytecode_size = size; - - if (!UpdateDispatchOpOptionsInPlace(dispach_opts, old_raw_opts)) { - LITERT_LOG(LITERT_ERROR, "Failed to update dispatch op options"); - return Error(kLiteRtStatusErrorInvalidFlatbuffer); - } - } - } - - // Find the buffers that are offset buffers and mark them with the future - // size and offset. 
-  for (auto i = 0; i < tfl_model->mutable_buffers()->size(); ++i) {
-    auto* tfl_buffer = tfl_model->mutable_buffers()->GetMutableObject(i);
-    auto offset_size = offset_tensor_offsets.Find(i);
-    if (!offset_size) {
-      // Not an offset buffer.
-      continue;
-    }
-    const auto [offset, size] = offset_size->get().second;
-    const auto offset_ok = tfl_buffer->mutate_offset(offset);
-    const auto size_ok = tfl_buffer->mutate_size(size);
-    if (!offset_ok || !size_ok) {
-      LITERT_LOG(LITERT_ERROR, "Failed to update offset and size for buffer");
-      return Error(kLiteRtStatusErrorInvalidFlatbuffer);
-    }
-  }
-
-  // Allocate a buffer large enough for the original model and the appended
-  // buffers, then copy.
-  OwningBufferRef<uint8_t> final_model(cur_offset);
-
-  // Copy serialized tflite model.
-  uint8_t* const start = final_model.Data();
-  std::memcpy(start, serialized_tfl.Data(), serialized_tfl.Size());
-
-  // Copy asset buffers (aligned).
-  for (auto it = asset_buffer_offsets.Begin(); it != asset_buffer_offsets.End();
-       ++it) {
-    const auto buf_id = it->first;
-
-    auto asset_buf = litert_model.Buffers()->GetBuffer(buf_id);
-    if (!asset_buf) {
-      LITERT_LOG(LITERT_ERROR, "Failed to find asset buffer");
-      return asset_buf.Error();
-    }
-    uint8_t* const offset = start + it->second.first;
-    std::memcpy(offset, asset_buf->Data(), asset_buf->Size());
-  }
-
-  // Copy offset tensor buffers.
-  for (auto it = offset_tensor_offsets.Begin();
-       it != offset_tensor_offsets.End(); ++it) {
-    const auto buf_id = it->first;
-
-    auto offset_buf = litert_model.Buffers()->GetBuffer(buf_id);
-    if (!offset_buf) {
-      LITERT_LOG(LITERT_ERROR, "Failed to find offset tensor buffer");
-      return offset_buf.Error();
-    }
-
-    uint8_t* const offset = start + it->second.first;
-    std::memcpy(offset, offset_buf->Data(), offset_buf->Size());
-  }
-
-  return final_model;
-}
-
-}  // namespace
-
-Expected<OwningBufferRef<uint8_t>> SerializeModel(LiteRtModelT&& model,
-                                                  size_t bytecode_alignment) {
-  // Pass through the op code list that was saved during loading. Add one more
-  // op code for the dispatch ops.
-  auto tfl_op_codes = litert::internal::TakeTflOpCodes(model);
-  tfl_op_codes.push_back(
-      MakeCustomOpCode(std::string(kLiteRtDispatchOpCustomCode)));
-
-  SerializationContext builder(tfl_op_codes.size() - 1, model,
-                               bytecode_alignment);
-  builder.Model().operator_codes = std::move(tfl_op_codes);
-
-  auto tfl_model = PackAsTflite(builder);
-  if (!tfl_model) {
-    LITERT_LOG(LITERT_ERROR, "Failed to pack as tflite");
-    return tfl_model.Error();
-  }
-
-  auto serialized_tfl = SerializeFlatbuffer(**tfl_model);
-  auto serialized_with_buffers =
-      SerializeWithAppendedBuffers(builder, std::move(serialized_tfl), model);
-  if (!serialized_with_buffers) {
-    LITERT_LOG(LITERT_ERROR, "Failed to serialize with appended buffers");
-    return serialized_with_buffers.Error();
-  }
-
-  if (!VerifyFlatbuffer(serialized_with_buffers->Span())) {
-    LITERT_LOG(LITERT_ERROR, "Failed to verify flatbuffer");
-    return Error(kLiteRtStatusErrorInvalidFlatbuffer);
-  }
-
-  return serialized_with_buffers;
-}
-
-}  // namespace litert::internal
diff --git a/tensorflow/lite/experimental/litert/core/model/model_serialize.h b/tensorflow/lite/experimental/litert/core/model/model_serialize.h
deleted file mode 100644
index 0ffa2d878ba8c3..00000000000000
--- a/tensorflow/lite/experimental/litert/core/model/model_serialize.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_SERIALIZE_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_SERIALIZE_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-namespace litert::internal {
-
-Expected<OwningBufferRef<uint8_t>> SerializeModel(
-    LiteRtModelT&& model, size_t bytecode_alignment = 1);
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_SERIALIZE_H_
diff --git a/tensorflow/lite/experimental/litert/core/model/model_test.cc b/tensorflow/lite/experimental/litert/core/model/model_test.cc
deleted file mode 100644
index 52dfcdc3778a4f..00000000000000
--- a/tensorflow/lite/experimental/litert/core/model/model_test.cc
+++ /dev/null
@@ -1,531 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAreArray; - -// -// Model -// - -TEST(ModelTest, GetMetadata) { - static constexpr absl::string_view kMetadata = "VALUE"; - static constexpr absl::string_view kKey = "KEY"; - - LiteRtModelT model; - LITERT_ASSERT_OK(model.PushMetadata(kKey, kMetadata)); - auto found_metadata = model.FindMetadata(kKey); - ASSERT_TRUE(found_metadata); - EXPECT_EQ(found_metadata->StrView(), kMetadata); -} - -TEST(ModelTest, MetadataDNE) { - LiteRtModelT model; - auto res = model.FindMetadata("FOO"); - ASSERT_FALSE(res.HasValue()); -} - -TEST(ModelTest, GetBuildStamp) { - static constexpr absl::string_view kSocManufacturer = "honda"; - static constexpr absl::string_view kSocModel = "accord"; - - LiteRtModelT model; - - LITERT_ASSERT_OK(model.PushMetadata( - kLiteRtBuildStampKey, *MakeBuildStamp(kSocManufacturer, kSocModel))); - auto build_stamp = GetBuildStamp(model); - ASSERT_TRUE(build_stamp); - EXPECT_TRUE(IsCompiled(model)); - EXPECT_EQ(build_stamp->soc_manufacturer, kSocManufacturer); - EXPECT_EQ(build_stamp->soc_model, kSocModel); -} - -TEST(ModelTest, EmplaceSubgraph) { - LiteRtModelT model; - auto& sg = model.EmplaceSubgraph(); - EXPECT_EQ(model.Subgraphs().size(), 1); - auto& tensor = sg.EmplaceTensor(); - EXPECT_EQ(tensor.Weights().GetBufferManager(), model.Buffers()); -} - -TEST(ModelTest, Signature) { - static constexpr absl::string_view kSignatureName = "MY_SIGNATURE"; - - const std::vector inputs = {"input_1", "input_2"}; - const std::vector outputs = {"output_1"}; - - LiteRtModelT model; - auto& subgraph = model.EmplaceSubgraph(); - - auto& signature = model.EmplaceSignature(&subgraph, inputs, outputs, - std::string(kSignatureName)); - - auto found_signature = model.FindSignature(kSignatureName); - ASSERT_TRUE(found_signature); - EXPECT_EQ(found_signature->get(), signature); -} - -TEST(ModelTest, SignatureDNE) { - static constexpr absl::string_view kSignatureName = "MY_SIGNATURE"; - LiteRtModelT model; - auto found_signature = model.FindSignature(kSignatureName); - EXPECT_FALSE(found_signature); -} - -TEST(ModelTest, AttachExternalBufferToOp) { - static constexpr absl::string_view kBufferData = "BUFFER_DATA"; - static constexpr absl::string_view kOpName = "OP1"; - static constexpr absl::string_view kOp2Name = "OP2"; - - LiteRtModelT model; - auto& subgraph = model.EmplaceSubgraph(); - auto& op = subgraph.EmplaceOp(); - auto& op2 = subgraph.EmplaceOp(); - - OwningBufferRef external_buf(kBufferData); - - auto buf1_id = model.Buffers()->RegisterOwnedBuffer(std::move(external_buf)); - - model.AttachAssetToOp(&op, buf1_id, std::string(kOpName)); - model.AttachAssetToOp(&op2, buf1_id, std::string(kOp2Name)); - - auto op_1_res = model.FindOpAsset(&op); - ASSERT_TRUE(op_1_res); - 
EXPECT_EQ(op_1_res->second, kOpName); - EXPECT_EQ(op_1_res->first, buf1_id); - - auto op_2_res = model.FindOpAsset(&op2); - ASSERT_TRUE(op_2_res); - EXPECT_EQ(op_2_res->second, kOp2Name); - EXPECT_EQ(op_2_res->first, buf1_id); -} - -TEST(ModelTest, ExternalBufferNotFound) { - LiteRtModelT model; - LiteRtOpT op; - ASSERT_FALSE(model.FindOpAsset(&op)); -} - -// -// Subgraph -// - -TEST(ModelSubgraphTest, Input) { - LiteRtTensorT tensor; - LiteRtSubgraphT subgraph; - subgraph.Inputs().push_back(&tensor); - EXPECT_EQ(&subgraph.Input(0), subgraph.Inputs().front()); -} - -TEST(ModelSubgraphTest, Output) { - LiteRtTensorT tensor; - LiteRtSubgraphT subgraph; - subgraph.Outputs().push_back(&tensor); - EXPECT_EQ(&subgraph.Output(0), subgraph.Outputs().front()); -} - -TEST(ModelSubgraphTest, EmplaceTensor) { - LiteRtSubgraphT subgraph; - auto& tensor = subgraph.EmplaceTensor(); - ASSERT_EQ(subgraph.Tensors().size(), 1); - EXPECT_THAT(subgraph.Tensors(), ElementsAreArray({&tensor})); -} - -TEST(ModelSubgraphTest, EmplaceOp) { - LiteRtSubgraphT subgraph; - auto& op = subgraph.EmplaceOp(); - ASSERT_EQ(subgraph.Ops().size(), 1); - EXPECT_THAT(subgraph.Ops(), ElementsAreArray({&op})); -} - -// -// Op -// - -TEST(ModelOpTest, Input) { - LiteRtOpT op; - LiteRtTensorT tensor; - op.Inputs().push_back(&tensor); - EXPECT_EQ(&op.Input(0), op.Inputs().front()); -} - -TEST(ModelOpTest, Output) { - LiteRtOpT op; - LiteRtTensorT tensor; - op.Outputs().push_back(&tensor); - EXPECT_EQ(&op.Output(0), op.Outputs().front()); -} - -TEST(ModelOpTest, CustomOptions) { - static constexpr absl::string_view kOpts = "OPTIONS"; - - LiteRtOpT op; - op.SetCustomOptions(kOpts); - EXPECT_EQ(op.CustomOptions().StrView(), kOpts); -} - -TEST(ModelOpTest, Options) { - static constexpr auto kOptsType = ::tflite::BuiltinOptions_AddOptions; - - TflOptions options; - options.type = kOptsType; - options.Set(::tflite::AddOptionsT()); - - LiteRtOpT op; - litert::internal::SetTflOptions(op, std::move(options)); - - ASSERT_EQ(litert::internal::GetTflOptions(op).type, kOptsType); -} - -TEST(ModelOpTest, OpCode) { - constexpr static auto kOpCode = kLiteRtOpCodeTflMul; - - LiteRtOpT op; - op.SetOpCode(kOpCode); - EXPECT_EQ(op.OpCode(), kOpCode); -} - -// -// Tensor -// - -TEST(ModelTensorTypeTest, MakeRankedTensorType) { - static constexpr const int32_t kDims[] = {2, 2}; - static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); - static constexpr auto kElementType = kLiteRtElementTypeFloat32; - const auto tensor_type = MakeRankedTensorType(kElementType, kDimsSpan); - ASSERT_EQ(tensor_type.first, kLiteRtRankedTensorType); - EXPECT_EQ(tensor_type.second.ranked_tensor_type.element_type, kElementType); - const auto& layout = tensor_type.second.ranked_tensor_type.layout; - ASSERT_EQ(layout.rank, kDimsSpan.size()); - EXPECT_THAT(absl::MakeConstSpan(layout.dimensions, kDimsSpan.size()), - ElementsAreArray(kDimsSpan)); -} - -TEST(ModelQuantizationTypeTest, MakePerTensor) { - static constexpr auto kScale = 1.0f; - static constexpr auto kZero = 1L; - const auto quant = MakePerTensorQuantization(kScale, kZero); - ASSERT_EQ(quant.first, kLiteRtQuantizationPerTensor); - const auto& per_tensor = quant.second.per_tensor; - EXPECT_EQ(per_tensor.scale, kScale); - EXPECT_EQ(per_tensor.zero_point, kZero); -} - -TEST(ModelQuantizationTypeTest, MakePerChannel) { - static constexpr std::array kScale = {1.0f, 2.0f}; - static constexpr std::array kZero = {1L, 2L}; - static constexpr int32_t kQdim = 0; - - LiteRtTensorT tensor; - const auto quant = 
MakePerChannelQuantization( - kScale, kZero, kQdim, - [&tensor](auto s) { return tensor.RequestScratchBuffer(s); }); - - ASSERT_EQ(quant.first, kLiteRtQuantizationPerChannel); - const auto& per_channel = quant.second.per_channel; - - const auto size = per_channel.num_channels; - ASSERT_EQ(size, 2); - EXPECT_EQ(per_channel.quantized_dimension, 0); - - auto scales = absl::MakeConstSpan(per_channel.scales, size); - auto zeros = absl::MakeConstSpan(per_channel.zero_points, size); - - EXPECT_THAT(scales, ElementsAreArray(kScale)); - EXPECT_THAT(zeros, ElementsAreArray(kZero)); -} - -TEST(ModelWeightsTest, EmptyWeights) { - LiteRtWeightsT weights; - EXPECT_EQ(weights.Buffer().Size(), 0); -} - -TEST(ModelWeightsTest, WeightsWithExternalBufferManager) { - static constexpr absl::string_view kData = "some_data"; - BufferManager manager; - - LiteRtWeightsT weights; - weights.SetBufferManager(&manager); - - BufferRef buf(kData.data(), kData.size()); - SetWeightsFromUnownedBuffer(weights, buf); - - EXPECT_EQ(manager.GetBuffer(weights.GetBufferId())->StrView(), kData); - EXPECT_EQ(weights.Buffer().StrView(), kData); -} - -TEST(ModelWeightsTest, WeightsFromUnownedBuffer) { - static constexpr absl::string_view kData = "some_data"; - - LiteRtWeightsT weights; - BufferRef buf(kData.data(), kData.size()); - SetWeightsFromUnownedBuffer(weights, buf); - - EXPECT_EQ(weights.Buffer().StrView(), kData); -} - -TEST(ModelWeightsTest, WeightsFromOwnedBuffer) { - static constexpr absl::string_view kData = "some_data"; - - LiteRtWeightsT weights; - - OwningBufferRef buf(kData); - SetWeightsFromUnownedBuffer(weights, std::move(buf)); - - EXPECT_EQ(weights.Buffer().StrView(), kData); -} - -TEST(ModelWeightsTest, OverwriteBuffer) { - static constexpr absl::string_view kData = "some_data"; - static constexpr absl::string_view kData2 = "some_data2"; - - LiteRtWeightsT weights; - - { - OwningBufferRef buf(kData); - SetWeightsFromOwnedBuffer(weights, std::move(buf)); - } - - { - OwningBufferRef buf(kData2); - SetWeightsFromOwnedBuffer(weights, std::move(buf)); - } - - EXPECT_EQ(weights.Buffer().StrView(), kData2); -} - -TEST(ModelTensorTest, Name) { - static constexpr absl::string_view kName = "TENSOR_NAME"; - - LiteRtTensorT tensor; - tensor.SetName(std::string(kName.begin(), kName.end())); - EXPECT_EQ(tensor.Name(), kName); -} - -TEST(ModelTensorTest, Use) { - LiteRtTensorT tensor; - tensor.Users().emplace_back(); - tensor.UserArgInds().push_back(0); - auto [user, ind] = tensor.GetUse(0); - EXPECT_EQ(user, tensor.Users().front()); - EXPECT_EQ(ind, 0); -} - -TEST(ModelTensorTest, DefiningOp) { - LiteRtTensorT tensor; - LiteRtOpT op; - tensor.SetDefiningOp(op, 0); - EXPECT_EQ(tensor.DefiningOp(), &op); - EXPECT_EQ(tensor.DefiningOpOutInd(), 0); -} - -TEST(ModelTest, TransferSubgraphToReindexComposite) { - LiteRtModelT model; - - auto& subgraph = model.EmplaceSubgraph(); - auto& other_subgraph = model.EmplaceSubgraph(); - auto& decomp_subgraph = model.EmplaceSubgraph(); - - auto& composite = subgraph.EmplaceOp(); - composite.SetOpCode(kLiteRtOpCodeShloComposite); - ::tflite::StableHLOCompositeOptionsT opts; - opts.name = "composite"; - opts.decomposition_subgraph_index = 2; - TflOptions2 options; - options.type = tflite::BuiltinOptions2_StableHLOCompositeOptions; - options.Set(std::move(opts)); - litert::internal::SetTflOptions2(composite, std::move(options)); - - LiteRtSubgraphT::Alloc dest; - std::vector indices = {1}; - model.TransferSubgraphTo(dest, std::move(indices)); - - EXPECT_THAT(model.Subgraphs(), - 
ElementsAreArray({&subgraph, &decomp_subgraph})); - EXPECT_THAT(dest.Elements(), ElementsAreArray({&other_subgraph})); - - const auto& new_opts = litert::internal::GetTflOptions2(composite); - const auto new_decomp_ind = - new_opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index; - EXPECT_EQ(new_decomp_ind, 1); -} - -TEST(ModelTest, TransferSubgraphToReindexCompositeNoChange) { - LiteRtModelT model; - - auto& subgraph = model.EmplaceSubgraph(); - auto& decomp_subgraph = model.EmplaceSubgraph(); - auto& other_subgraph = model.EmplaceSubgraph(); - - auto& composite = subgraph.EmplaceOp(); - composite.SetOpCode(kLiteRtOpCodeShloComposite); - ::tflite::StableHLOCompositeOptionsT opts; - opts.name = "composite"; - opts.decomposition_subgraph_index = 1; - TflOptions2 options; - options.type = tflite::BuiltinOptions2_StableHLOCompositeOptions; - ; - options.Set(std::move(opts)); - litert::internal::SetTflOptions2(composite, std::move(options)); - - LiteRtSubgraphT::Alloc dest; - std::vector indices = {2}; - model.TransferSubgraphTo(dest, std::move(indices)); - - EXPECT_THAT(model.Subgraphs(), - ElementsAreArray({&subgraph, &decomp_subgraph})); - EXPECT_THAT(dest.Elements(), ElementsAreArray({&other_subgraph})); - - const auto& new_opts = litert::internal::GetTflOptions2(composite); - const auto new_decomp_ind = - new_opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index; - EXPECT_EQ(new_decomp_ind, 1); -} - -TEST(ModelTest, TransferSubgraphToReindexCompositeMultiple) { - LiteRtModelT model; - - auto& subgraph = model.EmplaceSubgraph(); - auto& other_subgraph = model.EmplaceSubgraph(); - auto& other_subgraph2 = model.EmplaceSubgraph(); - auto& other_subgraph3 = model.EmplaceSubgraph(); - auto& decomp_subgraph = model.EmplaceSubgraph(); - auto& other_subgraph4 = model.EmplaceSubgraph(); - - auto& composite = subgraph.EmplaceOp(); - composite.SetOpCode(kLiteRtOpCodeShloComposite); - ::tflite::StableHLOCompositeOptionsT opts; - opts.name = "composite"; - opts.decomposition_subgraph_index = 4; - TflOptions2 options; - options.type = tflite::BuiltinOptions2_StableHLOCompositeOptions; - ; - options.Set(std::move(opts)); - litert::internal::SetTflOptions2(composite, std::move(options)); - - LiteRtSubgraphT::Alloc dest; - std::vector indices = {1, 3, 5}; - model.TransferSubgraphTo(dest, std::move(indices)); - - EXPECT_THAT(model.Subgraphs(), ElementsAreArray({&subgraph, &other_subgraph2, - &decomp_subgraph})); - EXPECT_THAT( - dest.Elements(), - ElementsAreArray({&other_subgraph, &other_subgraph3, &other_subgraph4})); - - const auto& new_opts = litert::internal::GetTflOptions2(composite); - const auto new_decomp_ind = - new_opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index; - EXPECT_EQ(new_decomp_ind, 2); -} - -// -// Misc Ir Containers -// - -TEST(ModelOpListTest, Push) { - LiteRtOpListT op_list; - LiteRtOpT op; - op_list.Push(&op); - auto vec = op_list.Values(); - EXPECT_EQ(vec.front().first, &op); -} - -TEST(ModelOpListTest, PushWithIndex) { - LiteRtOpListT op_list; - LiteRtOpT op; - op_list.Push(&op, 1); - auto vec = op_list.Values(); - EXPECT_EQ(vec.front().first, &op); - EXPECT_EQ(vec.front().second, 1); -} - -// -// Traversal Utils -// - -TEST(CcForEachIrTest, OpF3) { - LiteRtModelT model; - model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, [&](LiteRtSubgraph subgraph, int32_t subgraph_index, - LiteRtOp op) { count++; }); - EXPECT_EQ(count, 1); -} - -TEST(CcForEachIrTest, OpF1) { - LiteRtModelT model; - 
model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, [&](LiteRtOp op) { count++; }); - EXPECT_EQ(count, 1); -} - -TEST(CcForEachIrTest, OpF2) { - LiteRtModelT model; - model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, [&](LiteRtSubgraph subgraph, LiteRtOp op) { count++; }); - EXPECT_EQ(count, 1); -} - -TEST(CcForEachIrTest, SgF1) { - LiteRtModelT model; - model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, [&](LiteRtSubgraph subgraph) { count++; }); - EXPECT_EQ(count, 1); -} - -TEST(CcForEachIrTest, SgF2) { - LiteRtModelT model; - model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, - [&](LiteRtSubgraph subgraph, int32_t subgraph_index) { count++; }); - EXPECT_EQ(count, 1); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/util/BUILD b/tensorflow/lite/experimental/litert/core/util/BUILD deleted file mode 100644 index 88fb50a693cb11..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/BUILD +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - # copybara:uncomment "//third_party/mediapipe/calculators/tensor:__subpackages__", - "//tensorflow/lite/experimental/litert:__subpackages__", - ], -) - -cc_library( - name = "flatbuffer_tools", - srcs = ["flatbuffer_tools.cc"], - hdrs = [ - "flatbuffer_tools.h", - "//tensorflow/lite/experimental/litert/cc:litert_consts.h", - ], - deps = [ - "//tensorflow/compiler/mlir/lite:allocation", - "//tensorflow/lite:model_builder", - "//tensorflow/lite:stderr_reporter", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@flatbuffers//:runtime_cc", - ], -) - -cc_test( - name = "flatbuffer_tools_test", - srcs = ["flatbuffer_tools_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - deps = [ - ":flatbuffer_tools", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "tensor_type_util", - srcs = [ - "tensor_type_util.cc", - ], - hdrs = [ - "tensor_type_util.h", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - 
"//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "tensor_type_util_test", - srcs = ["tensor_type_util_test.cc"], - deps = [ - ":tensor_type_util", - "//tensorflow/lite/experimental/litert/c:litert_model", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc deleted file mode 100644 index ab67b75b2cbdc0..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -#include -#include -#include -#include - -#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "tensorflow/compiler/mlir/lite/allocation.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" - -#ifndef NDEBUG -// Make flatbuffers verifier `assert` in debug mode. -#define FLATBUFFERS_DEBUG_VERIFICATION_FAILURE - -#include "flatbuffers/flatbuffers.h" // from @flatbuffers // IWYU pragma: keep -#endif - -#include -#include - -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "flatbuffers/verifier.h" // from @flatbuffers -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/stderr_reporter.h" - -namespace litert::internal { - -using ::flatbuffers::Verifier; -using ::tflite::VerifyModelBuffer; - -namespace { - -Expected FindMetadataInd(const TflModel& model, - absl::string_view key) { - tflite::MetadataT* fb_metadata = nullptr; - for (auto& m : model.metadata) { - if (m->name == key) { - fb_metadata = m.get(); - break; - } - } - if (fb_metadata == nullptr) { - return Error(kLiteRtStatusErrorNotFound); - } - return fb_metadata->buffer; -} - -} // namespace - -absl::string_view FbBufToStr(const uint8_t* fb_data, size_t size) { - auto fb_buf_raw = reinterpret_cast(fb_data); - return absl::string_view(fb_buf_raw, size); -} - -absl::string_view FbBufToStr(absl::Span fb_buf) { - auto fb_buf_raw = reinterpret_cast(fb_buf.data()); - const size_t fb_buf_size = fb_buf.size(); - return absl::string_view(fb_buf_raw, fb_buf_size); -} - -absl::Span FbBufToStr(absl::Span fb_buf) { - return absl::MakeSpan(reinterpret_cast(fb_buf.data()), fb_buf.size()); -} - -absl::Span FbBufToStr(uint8_t* fb_data, size_t size) { - return absl::MakeSpan(reinterpret_cast(fb_data), size); -} - -bool VerifyFlatbuffer(absl::Span buf) { - return VerifyFlatbuffer(buf.data(), buf.size()); -} - -bool VerifyFlatbuffer(const 
uint8_t* buf, size_t buf_size) { - flatbuffers::Verifier::Options options; -#ifndef NDEBUG - options.assert = true; -#endif - flatbuffers::Verifier verifier(buf, buf_size, options); - return VerifyModelBuffer(verifier); -} - -Expected> GetMetadata(absl::string_view key, - TflModel& model) { - auto buffer_ind = FindMetadataInd(model, key); - if (!buffer_ind) { - // Metadata key already has value. - return buffer_ind.Error(); - } - auto& fb_vec = model.buffers.at(*buffer_ind)->data; - return MutableBufferRef(fb_vec.data(), fb_vec.size()); -} - -Expected> GetMetadata(absl::string_view key, - const TflModel& model) { - auto metadata = GetMetadata(key, const_cast(model)); - if (!metadata) { - return metadata.Error(); - } - return *metadata; -} - -LiteRtStatus PushMetadata(absl::string_view key, TflModel& model, - BufferRef metadata) { - auto buffer_ind = FindMetadataInd(model, key); - if (buffer_ind) { - // Metadata key already has value. - return kLiteRtStatusErrorInvalidArgument; - } - - auto& new_metadata = - model.metadata.emplace_back(std::make_unique()); - new_metadata->name.assign(key.data(), key.size()); - - const auto new_m_buffer_ind = model.buffers.size(); - new_metadata->buffer = new_m_buffer_ind; - - auto& new_buffer = model.buffers.emplace_back(std::make_unique()); - new_buffer->data.assign(metadata.Data(), metadata.Data() + metadata.Size()); - - return kLiteRtStatusOk; -} - -Expected> GetTflBuffer(TflModel& tfl_model, - uint32_t buffer_ind) { - if (buffer_ind >= tfl_model.buffers.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - auto& tfl_data = tfl_model.buffers.at(buffer_ind)->data; - return MutableBufferRef(tfl_data.data(), tfl_data.size()); -} - -Expected> GetTflBuffer(const TflModel& tfl_model, - uint32_t buffer_ind) { - auto buffer = GetTflBuffer(const_cast(tfl_model), buffer_ind); - if (!buffer) { - return buffer.Error(); - } - return *buffer; -} - -Expected GetBuffer(const TflModel& tfl_model, - uint32_t buffer_ind) { - if (buffer_ind >= tfl_model.buffers.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return tfl_model.buffers.at(buffer_ind).get(); -} - -Expected TakeBuffer(TflModel& tfl_model, uint32_t buffer_ind) { - if (buffer_ind >= tfl_model.buffers.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return std::move(tfl_model.buffers.at(buffer_ind)); -} - -Expected PushTflBuffer(TflModel& tfl_model, - BufferRef buffer) { - tfl_model.buffers.emplace_back(std::make_unique<::tflite::BufferT>()) - ->data.assign(buffer.Data(), buffer.Data() + buffer.Size()); - return tfl_model.buffers.size() - 1; -} - -Expected GetTflOpCode(const TflModel& tfl_model, - uint32_t op_code_ind) { - if (op_code_ind >= tfl_model.operator_codes.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return std::move(tfl_model.operator_codes.at(op_code_ind)->builtin_code); -} - -bool IsRankedTensorType(const TflShapeInfo& tfl_shape) { - return tfl_shape.has_rank; -} - -bool IsStaticTensorType(const TflShapeInfo& tfl_shape) { - return !IsRankedTensorType(tfl_shape) || - std::none_of(tfl_shape.shape_signature.begin(), - tfl_shape.shape_signature.end(), - [](auto d) { return d < 0; }); -} - -Expected> AsStaticShape( - const TflShapeInfo& tfl_shape) { - if (!IsStaticTensorType(tfl_shape)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return absl::MakeConstSpan(tfl_shape.shape.data(), tfl_shape.shape.size()); -} - -Expected> AsDynamicShape( - const TflShapeInfo& tfl_shape) { - auto static_shape = AsStaticShape(tfl_shape); - if (static_shape) { - return 
static_shape; - } - if (!IsRankedTensorType(tfl_shape)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return absl::MakeConstSpan(tfl_shape.shape_signature.data(), - tfl_shape.shape_signature.size()); -} - -bool IsQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && - (!tfl_quantization->scale.empty() || - tfl_quantization->details.type != tflite::QuantizationDetails_NONE); -} - -bool IsPerChannelQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && tfl_quantization->scale.size() > 1; -} - -bool IsPerTensorQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && tfl_quantization->scale.size() == 1; -} - -bool IsBlockwiseQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && - tfl_quantization->details.type == - tflite::QuantizationDetails_BlockwiseQuantization; -} - -bool IsCustomQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && tfl_quantization->details.type == - tflite::QuantizationDetails_CustomQuantization; -} - -Expected AsPerTensorQparams( - const TflQuantization* tfl_quantization) { - if (!IsPerTensorQuantized(tfl_quantization)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return std::make_pair(tfl_quantization->zero_point.front(), - tfl_quantization->scale.front()); -} - -Expected AsPerChannelQparams( - const TflQuantization* tfl_quantization) { - if (!IsPerChannelQuantized(tfl_quantization)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return TflPerChannelQParams(tfl_quantization->quantized_dimension, - tfl_quantization->zero_point.size(), - tfl_quantization->zero_point, - tfl_quantization->scale); -} - -::tflite::Allocation::Ptr MakeAllocation(BufferRef buf) { - return std::make_unique<::tflite::MemoryAllocation>( - buf.Data(), buf.Size(), ::tflite::DefaultErrorReporter()); -} - -Expected FlatbufferWrapper::CreateFromBuffer( - OwningBufferRef&& buffer) { - static constexpr size_t k2GiB = 2e+9; - if (buffer.Size() < k2GiB && - !VerifyFlatbuffer(buffer.Data(), buffer.Size())) { - return Error(kLiteRtStatusErrorInvalidFlatbuffer); - } - - auto alloc = MakeAllocation(buffer); - - if (alloc == nullptr) { - return Error(kLiteRtStatusErrorFileIO); - } - - auto fb_model = ::tflite::FlatBufferModel::BuildFromBuffer( - reinterpret_cast(alloc->base()), alloc->bytes()); - if (fb_model == nullptr) { - return Error(kLiteRtStatusErrorFileIO); - } - - return FlatbufferWrapper::Ptr(new FlatbufferWrapper( - std::move(fb_model), std::move(alloc), std::move(buffer))); -} - -Expected FlatbufferWrapper::CreateFromBuffer( - BufferRef buffer) { - return FlatbufferWrapper::CreateFromBuffer( - OwningBufferRef(buffer.Data(), buffer.Size())); -} - -Expected FlatbufferWrapper::CreateFromTflFile( - absl::string_view path) { - auto buf = LoadBinaryFile(path); - if (!buf) { - return buf.Error(); - } - return FlatbufferWrapper::CreateFromBuffer(std::move(*buf)); -} - -OwningBufferRef SerializeFlatbuffer(const TflModel& tfl_model) { - flatbuffers::FlatBufferBuilder b; - auto model_offset = tflite::Model::Pack(b, &tfl_model); - tflite::FinishModelBuffer(b, model_offset); - - OwningBufferRef buffer; - auto [new_buf, new_size, new_offset] = buffer.GetWeak(); - new_buf = b.ReleaseRaw(new_size, new_offset); - - return buffer; -} - -OwningBufferRef SerializeFlatbuffer( - const FlatbufferWrapper& flatbuffer) { - auto tfl_model = flatbuffer.Unpack(); - return SerializeFlatbuffer(*tfl_model); -} - -} // namespace litert::internal diff --git 
a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h deleted file mode 100644 index bf0ccf6604f737..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/mlir/lite/allocation.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { - -// Flatbuffer IR - -using TflTensor = ::tflite::TensorT; -using TflOp = ::tflite::OperatorT; -using TflBuffer = ::tflite::BufferT; -using TflSubgraph = ::tflite::SubGraphT; -using TflModel = ::tflite::ModelT; -using TflOpCodeEnum = ::tflite::BuiltinOperator; -using TflOpCode = ::tflite::OperatorCodeT; -using TflQuantization = ::tflite::QuantizationParametersT; -using TflElementType = ::tflite::TensorType; -using TflOptions = ::tflite::BuiltinOptionsUnion; -using TflOptions2 = ::tflite::BuiltinOptions2Union; -using TflSignature = ::tflite::SignatureDefT; -using TflMetadata = ::tflite::MetadataT; - -using TflPackedModel = ::tflite::Model; -using TflPackedSubgraph = ::tflite::SubGraph; -using TflPackedOp = ::tflite::Operator; -using TflPackedTensor = ::tflite::Tensor; -using TflPackedBuffer = ::tflite::Buffer; - -using TflBufferPtr = std::unique_ptr; -using TflModelPtr = std::unique_ptr; -using TflQuantizationPtr = std::unique_ptr; -using TflOpCodePtr = std::unique_ptr; -using TflSubgraphPtr = std::unique_ptr; -using TflTensorPtr = std::unique_ptr; -using TflOpPtr = std::unique_ptr; -using TflSignaturePtr = std::unique_ptr; -using TflMetadataPtr = std::unique_ptr; - -// Code and verion. -using TflOpCodeDetail = std::pair; - -// Zero-point, scale. -using TflPerTensorQParams = std::pair; - -// Quantized dim, num channels, zero-points, scales. -using TflPerChannelQParams = - std::tuple, std::vector>; - -// Mirror of all the tensor type related fields in flatbuffer tensor definition. -struct TflShapeInfo { - // Fixed or dynamic rank. - bool has_rank; - - // Basic shape, all elements are non-negative (even if this is a dynamic - // shape). - absl::InlinedVector shape; - - // Dynamic dyn info. If this is not empty, then its length is equal to shape. - // If i is a dyn dim, then shape[i] == 1 and shape_signature[i] < 0. Otherwise - // shape_signature[i] == shape[i]. 
- absl::InlinedVector shape_signature; - - // Convert from a single dims array. Will detect if array is static/dynamic - // and populate fields accordingly. - explicit TflShapeInfo(absl::Span shape_data) : has_rank(true) { - bool is_dyn = false; - shape.reserve(shape_data.size()); - shape_signature.reserve(shape_data.size()); - for (auto d : shape_data) { - if (d >= 0) { - shape.push_back(d); - shape_signature.push_back(d); - } else { - is_dyn = true; - shape.push_back(1); - shape_signature.push_back(-1); - } - } - if (!is_dyn) { - shape_signature.clear(); - } - } - - // Convert from tensor. - explicit TflShapeInfo(const TflTensor& tfl_tensor) - : has_rank(tfl_tensor.has_rank), - shape(tfl_tensor.shape.begin(), tfl_tensor.shape.end()), - shape_signature(tfl_tensor.shape_signature.begin(), - tfl_tensor.shape_signature.end()) {} - - explicit TflShapeInfo(const TflPackedTensor& tfl_tensor) - : has_rank(tfl_tensor.has_rank()) { - if (tfl_tensor.shape()) { - shape.assign(tfl_tensor.shape()->begin(), tfl_tensor.shape()->end()); - } - - if (tfl_tensor.shape_signature()) { - shape_signature.assign(tfl_tensor.shape_signature()->begin(), - tfl_tensor.shape_signature()->end()); - } - } -}; - -using TflTensorType = std::pair; - -// Flatbuffer bytes util. - -// Convenience method to get string view from native flatbuffer chars. -absl::string_view FbBufToStr(const uint8_t* fb_data, size_t size); - -// Span version. -absl::string_view FbBufToStr(absl::Span fb_buf); - -// Convenience method to get mutable signed char span from native flatbuffer -// chars. -absl::Span FbBufToStr(uint8_t* fb_data, size_t size); - -// Span to span version. -absl::Span FbBufToStr(absl::Span fb_buf); - -// Flatbuffer verifiers. - -// Verifies given serialized flatbuffer -bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size); - -// Override of above with view input. -bool VerifyFlatbuffer(absl::Span buf); - -// TFL flatbuffer IR helpers. - -// Get the metadata buffer under given key if it exists. -Expected> GetMetadata(absl::string_view key, - const TflModel& model); - -// Get the metadata buffer under given key if it exists that can be written to. -Expected> GetMutableMetadata(absl::string_view key, - TflModel& model); - -// Push the given metadata to the given key if the key does not already exist. -LiteRtStatus PushMetadata(absl::string_view key, TflModel& model, - BufferRef metadata); - -// Get the buffer object at the given index if it exists. -Expected> GetTflBuffer(const TflModel& tfl_model, - uint32_t buffer_ind); - -// Get the buffer object at the given index if it exists that can be written to. -Expected> GetMutableTflBuffer(TflModel& tfl_model, - uint32_t buffer_ind); - -// Get a non-owning view of tfl buffer if it exists. -Expected GetBuffer(const TflModel& tfl_model, - uint32_t buffer_ind); - -// Move and take ownership of the buffer object at given index if it exists. -Expected TakeBuffer(TflModel& tfl_model, uint32_t buffer_ind); - -// Add a new buffer to the tflite model, returning its index. -Expected PushTflBuffer(TflModel& tfl_model, - BufferRef buffer); - -// Make a tflite buffer from data. 
-template -TflBufferPtr MakeTflBuffer(std::initializer_list data) { - auto res = std::make_unique(); - const auto byte_size = data.size() * sizeof(T); - res->data.resize(byte_size); - for (auto it = data.begin(); it != data.end(); ++it) { - auto* write_to = - reinterpret_cast(res->data.data()) + (it - data.begin()); - *write_to = *it; - } - res->size = res->data.size(); - res->offset = 0; - return res; -} - -// Get the op code from the model at the given index if it exists. -Expected GetTflOpCode(const TflModel& tfl_model, - uint32_t op_code_ind); - -// Is tensor fixed rank, with possible dynamic dims. -bool IsRankedTensorType(const TflShapeInfo& tfl_shape); - -// Is ranked tensor type with static shape. -bool IsStaticTensorType(const TflShapeInfo& tfl_shape); - -// Get static shape info if given is indeed a static shape. -Expected> AsStaticShape( - const TflShapeInfo& tfl_shape); - -// Get ranked dynamic shape info if given is indeed a ranked. Still works with -// static shapes. -Expected> AsDynamicShape( - const TflShapeInfo& tfl_shape); - -// Is the tensor quantized. -bool IsQuantized(const TflQuantization* tfl_quantization); - -// Is the tensor per-tensor quantized. -bool IsPerTensorQuantized(const TflQuantization* tfl_quantization); - -// Is the tensor per-channel quantized. -bool IsPerChannelQuantized(const TflQuantization* tfl_quantization); - -// Is the tensor block-wise quantized. -bool IsBlockWiseQuantized(const TflQuantization* tfl_quantization); - -// Does tensor have custom quantization. -bool IsCustomQuantized(const TflQuantization* tfl_quantization); - -// Get the per-tensor tensor q-params if given tensor has them. -Expected AsPerTensorQparams( - const TflQuantization* tfl_quantization); - -// Get the per-channel tensor q-params if given tensor has them. -Expected AsPerChannelQparams( - const TflQuantization* tfl_quantization); - -// Flatbuffer management helpers. - -// Make a tfl allocation from buffer. -::tflite::Allocation::Ptr MakeAllocation(BufferRef buf); - -// Wrapper around a tflite model buffer. -class FlatbufferWrapper { - public: - using Ptr = std::unique_ptr; - - // TODO Don't return a unique_ptr, this can just be a move only type, all the - // fields are unique_ptrs. Load flatbuffer from file. - static Expected CreateFromTflFile(absl::string_view path); - - // Load flatbuffer from allocated buffer that will be copied. - static Expected CreateFromBuffer(BufferRef buffer); - - // Load flatbuffer from allocated buffer and take ownership. - static Expected CreateFromBuffer(OwningBufferRef&& buffer); - - // Underlying buffer. - BufferRef Buf() const { - return BufferRef(alloc_->base(), alloc_->bytes()); - } - - // Underlying model object. - const ::tflite::FlatBufferModel& FlatbufferModel() const { - return *fb_model_; - } - - // Packed schema object. - const TflPackedModel* PackedModel() const { return fb_model_->GetModel(); } - - // Unpack the contained flatbuffer. - TflModelPtr Unpack() const { - return TflModelPtr(fb_model_->GetModel()->UnPack()); - } - - // Address of first byte of the raw model buffer. - const uint8_t* AllocBase() const { return Buf().Data(); } - - // Default construct for compatibility. 
- FlatbufferWrapper() = default; - - private: - FlatbufferWrapper(::tflite::FlatBufferModel::Ptr fb_model, - ::tflite::Allocation::Ptr alloc, - OwningBufferRef&& model_buf) - : fb_model_(std::move(fb_model)), - alloc_(std::move(alloc)), - model_buf_(std::forward>(model_buf)) {} - - ::tflite::FlatBufferModel::Ptr fb_model_; - ::tflite::Allocation::Ptr alloc_; - OwningBufferRef model_buf_; -}; - -// Re-serialize the unpacked model from flatbuffer wrapper. -OwningBufferRef SerializeFlatbuffer( - const FlatbufferWrapper& flatbuffer); -OwningBufferRef SerializeFlatbuffer(const TflModel& tfl_model); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc deleted file mode 100644 index bc4fd6c493647c..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAre; -using ::testing::ElementsAreArray; -using ::testing::Lt; - -FlatbufferWrapper::Ptr TestFlatbuffer( - absl::string_view filename = "one_mul.tflite") { - const auto tfl_path = testing::GetTestFilePath(filename); - return *FlatbufferWrapper::CreateFromTflFile(tfl_path); -} - -static const absl::string_view kKey = "MyKey"; -static const absl::string_view kData = "MyData"; - -TEST(FlatbufferToolsTest, Metadata) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = flatbuffer->Unpack(); - - LITERT_ASSERT_OK(PushMetadata( - kKey, *tfl_model, BufferRef(kData.data(), kData.size()))); - - auto metadata = GetMetadata(kKey, *tfl_model); - ASSERT_TRUE(metadata); - EXPECT_EQ(metadata->StrView(), kData); -} - -TEST(FlatbufferToolsTest, GetMetadataNotFound) { - auto flatbuffer = TestFlatbuffer(); - auto tfl_model = flatbuffer->Unpack(); - ASSERT_NE(flatbuffer, nullptr); - EXPECT_FALSE(GetMetadata(kKey, *tfl_model)); -} - -TEST(FlatbufferToolsTest, TflBuffer) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = flatbuffer->Unpack(); - - auto ind = PushTflBuffer((*tfl_model), - BufferRef(kData.data(), kData.size())); - ASSERT_TRUE(ind); - - auto buf = GetTflBuffer((*tfl_model), *ind); - ASSERT_TRUE(buf); - ASSERT_EQ(buf->StrView(), kData); -} - -TEST(FlatbufferToolsTest, GetTflBufferNotFound) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = 
flatbuffer->Unpack(); - - auto buf = GetTflBuffer((*tfl_model), 100); - ASSERT_FALSE(buf); -} - -TEST(FlatbufferToolsTest, GetTflOpCode) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = flatbuffer->Unpack(); - - auto op_code = GetTflOpCode((*tfl_model), 0); - ASSERT_TRUE(op_code); -} - -TEST(FlatbufferToolsTest, GetTflOpCodeNotFound) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = flatbuffer->Unpack(); - - auto op_code = GetTflOpCode((*tfl_model), 100); - ASSERT_FALSE(op_code); -} - -TEST(FlatbufferToolsTest, StaticTensorTypeTest) { - auto flatbuffer = TestFlatbuffer(); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors.front(); - - TflShapeInfo shape(*tensor); - - ASSERT_TRUE(IsRankedTensorType(shape)); - ASSERT_TRUE(IsStaticTensorType(shape)); - - auto static_shape = AsStaticShape(shape); - - ASSERT_TRUE(static_shape); - ASSERT_THAT(*static_shape, ElementsAreArray({2, 2})); -} - -TEST(FlatbufferToolsTest, UnrankedTensorTypeTest) { - auto flatbuffer = TestFlatbuffer("unranked_tensor.tflite"); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors.front(); - - TflShapeInfo shape(*tensor); - - ASSERT_FALSE(IsRankedTensorType(shape)); -} - -TEST(FlatbufferToolsTest, RankedDynamicTensorTypeTest) { - auto flatbuffer = TestFlatbuffer("dynamic_shape_tensor.tflite"); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors.front(); - - TflShapeInfo shape(*tensor); - - ASSERT_TRUE(IsRankedTensorType(shape)); - ASSERT_FALSE(IsStaticTensorType(shape)); - - auto dyn_shape = AsDynamicShape(shape); - - ASSERT_TRUE(dyn_shape); - ASSERT_THAT(*dyn_shape, ElementsAre(Lt(0), 2)); -} - -TEST(FlatbufferToolsTest, PerTensorQuantizedTest) { - auto flatbuffer = - TestFlatbuffer("single_add_default_a16w8_recipe_quantized.tflite"); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors.front(); - - const auto* const q_parms = tensor->quantization.get(); - - ASSERT_TRUE(IsQuantized(q_parms)); - EXPECT_TRUE(IsPerTensorQuantized(q_parms)); - - auto per_tensor = AsPerTensorQparams(q_parms); - ASSERT_TRUE(per_tensor); -} - -TEST(FlatbufferToolsTest, PerChannelQuantizedTest) { - auto flatbuffer = TestFlatbuffer("static_w8_a16_quantized_k_einsum.tflite"); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors[1]; - - const auto* const q_parms = tensor->quantization.get(); - - ASSERT_TRUE(IsQuantized(q_parms)); - EXPECT_TRUE(IsPerChannelQuantized(q_parms)); - - auto per_channel = AsPerChannelQparams(q_parms); - ASSERT_TRUE(per_channel); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/util/tensor_type_util.cc b/tensorflow/lite/experimental/litert/core/util/tensor_type_util.cc deleted file mode 100644 index 4e3284374d24a2..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/tensor_type_util.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h"
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-namespace litert {
-namespace internal {
-
-Expected<Ratio> GetElementSize(LiteRtElementType element_type) {
-  switch (element_type) {
-    case kLiteRtElementTypeInt4:
-      return Ratio{1, 2};
-    case kLiteRtElementTypeBool:
-      return Ratio{1, 1};
-    case kLiteRtElementTypeInt8:
-    case kLiteRtElementTypeUInt8:
-      return Ratio{1, 1};
-    case kLiteRtElementTypeInt16:
-    case kLiteRtElementTypeUInt16:
-    case kLiteRtElementTypeFloat16:
-    case kLiteRtElementTypeBFloat16:
-      return Ratio{2, 1};
-    case kLiteRtElementTypeInt32:
-    case kLiteRtElementTypeUInt32:
-    case kLiteRtElementTypeFloat32:
-      return Ratio{4, 1};
-    case kLiteRtElementTypeInt64:
-    case kLiteRtElementTypeUInt64:
-    case kLiteRtElementTypeFloat64:
-      return Ratio{8, 1};
-    case kLiteRtElementTypeComplex64:
-      return Ratio{16, 1};
-    case kLiteRtElementTypeComplex128:
-      return Ratio{32, 1};
-    default:
-      return Unexpected(kLiteRtStatusErrorInvalidArgument,
-                        "Unexpected element type");
-  }
-}
-
-}  // namespace internal
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/core/util/tensor_type_util.h b/tensorflow/lite/experimental/litert/core/util/tensor_type_util.h
deleted file mode 100644
index 9663b2ac337403..00000000000000
--- a/tensorflow/lite/experimental/litert/core/util/tensor_type_util.h
+++ /dev/null
@@ -1,111 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_TENSOR_TYPE_UTIL_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_TENSOR_TYPE_UTIL_H_
-
-#include <cstddef>
-
-#include "absl/strings/str_cat.h"
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-namespace litert::internal {
-
-struct Ratio {
-  using Type = int;
-  Type num;
-  Type denom;
-  std::string ToString() const { return absl::StrCat(num, "/", denom); }
-};
-
-Expected<Ratio> GetElementSize(LiteRtElementType element_type);
-
-// Get the number of elements in a tensor with given dimensions.
-template <typename T>
-Expected<size_t> GetNumElements(absl::Span<T> dimensions) {
-  size_t num_elements = 1;
-  for (auto i = 0; i < dimensions.size(); ++i) {
-    auto dim = dimensions[i];
-    if (dim < 0) {
-      return Unexpected(kLiteRtStatusErrorInvalidArgument,
-                        "Unexpected negative dimension");
-    } else if (dim == 0) {
-      return Unexpected(kLiteRtStatusErrorInvalidArgument,
-                        "Unexpected 0 dimension");
-    }
-    num_elements *= dim;
-  }
-  return num_elements;
-}
-
-inline Expected<size_t> GetNumElements(
-    const LiteRtRankedTensorType& tensor_type) {
-  return GetNumElements(
-      absl::MakeSpan(tensor_type.layout.dimensions, tensor_type.layout.rank));
-}
-
-// Get the minimum number of bytes necessary to represent a packed tensor with
-// a given element type and dimensions.
-template <typename T>
-Expected<size_t> GetNumPackedBytes(LiteRtElementType element_type,
-                                   absl::Span<T> dimensions) {
-  auto element_size = GetElementSize(element_type);
-  if (!element_size) {
-    return element_size.Error();
-  }
-  auto num_elements = GetNumElements(dimensions);
-  if (!num_elements) {
-    return num_elements.Error();
-  }
-  return ((*num_elements * element_size->num) + (element_size->denom - 1)) /
-         element_size->denom;
-}
-
-// Get the number of bytes necessary to represent a packed tensor type,
-// ignoring any stride information.
-inline Expected<size_t> GetNumPackedBytes(
    const LiteRtRankedTensorType& tensor_type) {
-  return GetNumPackedBytes(
-      tensor_type.element_type,
-      absl::MakeSpan(tensor_type.layout.dimensions, tensor_type.layout.rank));
-}
-
-// Get the minimum number of bytes necessary to represent a possibly unpacked
-// tensor with a given element type, dimensions, and strides.
-template <typename T>
-Expected<size_t> GetNumBytes(LiteRtElementType element_type,
-                             absl::Span<T> dimensions, absl::Span<T> strides) {
-  if (dimensions.size() != strides.size()) {
-    return Unexpected(
-        kLiteRtStatusErrorInvalidArgument,
-        "Dimensions and strides have different number of elements");
-  }
-  auto element_size = GetElementSize(element_type);
-  if (!element_size) {
-    return element_size.Error();
-  }
-  auto rank = dimensions.size();
-  size_t num_elements = 1;
-  for (auto i = 0; i < rank; ++i) {
-    num_elements += (dimensions[i] - 1) * strides[i];
-  }
-  return ((num_elements * element_size->num) + (element_size->denom - 1)) /
-         element_size->denom;
-}
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_TENSOR_TYPE_UTIL_H_
diff --git a/tensorflow/lite/experimental/litert/core/util/tensor_type_util_test.cc b/tensorflow/lite/experimental/litert/core/util/tensor_type_util_test.cc
deleted file mode 100644
index bfb084140eb073..00000000000000
--- a/tensorflow/lite/experimental/litert/core/util/tensor_type_util_test.cc
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
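For intuition, the packed-size computation above is a ceiling division, byte count = ceil(num_elements * num / denom), which is what makes sub-byte types such as Int4 come out right. A minimal worked sketch (hypothetical dimensions; assumes the GetNumPackedBytes signature shown above):

#include <cstdint>

#include "absl/types/span.h"
#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h"

int main() {
  // Seven Int4 elements pack into ceil(7 * 1/2) = 4 bytes.
  constexpr int32_t kDims[] = {7, 1};  // hypothetical shape
  auto bytes = litert::internal::GetNumPackedBytes(
      kLiteRtElementTypeInt4, absl::MakeConstSpan(kDims));
  // (7 * 1 + (2 - 1)) / 2 == 4, per the Ratio{1, 2} entry for Int4.
  return (bytes && *bytes == 4) ? 0 : 1;
}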
- -#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h" - -#include -#include - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" - -using litert::internal::GetNumBytes; -using litert::internal::GetNumElements; -using litert::internal::GetNumPackedBytes; - -TEST(TensorTypeUtil, GetNumElements) { - constexpr std::array dimensions = {3, 2, 1}; - auto num_elements = GetNumElements(absl::MakeSpan(dimensions)); - EXPECT_TRUE(num_elements); - EXPECT_EQ(*num_elements, 6); -} - -TEST(TensorTypeUtil, GetNumElementsWithUnknownDimension) { - constexpr std::array dimensions = {3, -1, 1}; - auto num_elements = GetNumElements(absl::MakeSpan(dimensions)); - EXPECT_FALSE(num_elements); -} - -TEST(TensorTypeUtil, GetNumElementsWithZeroDimension) { - constexpr std::array dimensions = {3, 0, 1}; - auto num_elements = GetNumElements(absl::MakeSpan(dimensions)); - EXPECT_FALSE(num_elements); -} - -TEST(TensorTypeUtil, GetNumPackedBytes) { - LiteRtElementType element_type = kLiteRtElementTypeInt32; - constexpr std::array dimensions = {3, 2, 1}; - auto num_bytes = GetNumPackedBytes(element_type, absl::MakeSpan(dimensions)); - EXPECT_TRUE(num_bytes); - EXPECT_EQ(*num_bytes, sizeof(int32_t) * 6); -} - -TEST(TensorTypeUtil, GetNumBytes) { - LiteRtElementType element_type = kLiteRtElementTypeInt32; - constexpr std::array dimensions = {3, 2, 1}; - constexpr std::array strides = {1, 4, 8}; - // The data should be allocated as follows (where 'X' is a used cell and 'o' - // is an unused/padding cell): - // - // XXXo XXX - // - // The total is 4 + 3 = 7 cells - auto num_bytes = GetNumBytes(element_type, absl::MakeSpan(dimensions), - absl::MakeSpan(strides)); - EXPECT_TRUE(num_bytes); - EXPECT_EQ(*num_bytes, sizeof(int32_t) * 7); -} diff --git a/tensorflow/lite/experimental/litert/core/version.h b/tensorflow/lite/experimental/litert/core/version.h deleted file mode 100644 index fa9b017917c349..00000000000000 --- a/tensorflow/lite/experimental/litert/core/version.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_VERSION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_VERSION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -namespace litert::internal { - -// Return true if two API versions are the same. -inline bool IsSameVersion(const LiteRtApiVersion& v1, - const LiteRtApiVersion& v2) { - return (v1.major == v2.major) && (v1.minor == v2.minor) && - (v1.patch == v2.patch); -} - -// Return true if a given API version is the same as the current runtime. 
-inline bool IsSameVersionAsRuntime(const LiteRtApiVersion& v) { - return IsSameVersion(v, {LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR, - LITERT_API_VERSION_PATCH}); -} - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_VERSION_H_ diff --git a/tensorflow/lite/experimental/litert/integration_test/BUILD b/tensorflow/lite/experimental/litert/integration_test/BUILD deleted file mode 100644 index d36062a53e6c26..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/BUILD +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:tfl_model_gen.bzl", "tfl_model_gen") -load("//tensorflow/lite/experimental/litert/integration_test:run_on_device.bzl", "litert_integration_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -# C++ TEST SCAFFOLD ################################################################################ - -cc_test( - name = "gen_device_test", - srcs = ["gen_device_test.cc"], - copts = ["-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1"], - data = [":single_op_models"], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = ["manual"], - deps = [ - ":gen_device_test_lib", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/tools:dump", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "gen_device_test_lib", - testonly = True, - hdrs = ["gen_device_test_lib.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - ], -) - -# TEST MODELS ###################################################################################### - -filegroup( - name = "classic_ml_models", - srcs = glob(["classic_ml_models/*.tflite"]), -) - -tfl_model_gen( - name = "single_op_models", - srcs = glob(["single_op_models/*.mlir"]), - subdir = "single_op_models", -) - 
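As an aside, the version helpers deleted above (core/version.h) exist to gate compatibility between separately built components. A minimal hedged sketch of such a gate (the plugin_version value is hypothetical, not original code):

#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/core/version.h"

// Reject a component that was not built against the current runtime API.
inline bool PluginIsCompatible() {
  const LiteRtApiVersion plugin_version = {1, 0, 0};  // hypothetical
  return litert::internal::IsSameVersionAsRuntime(plugin_version);
}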
-filegroup( - name = "pre_compiled_models", - srcs = glob(["pre_compiled_models/*.tflite"]), -) - -# ON DEVICE INTEGRATION TESTS ###################################################################### - -# NOTE: Everything here should be built with -c opt --config=android_arm64. - -sh_binary( - name = "run_on_device_driver_OSS", - srcs = ["run_on_device_driver_OSS.sh"], -) - -litert_integration_test( - name = "single_op_device_tests_cpu", - hw = "cpu", - models = ":single_op_models", -) - -litert_integration_test( - name = "single_op_device_tests_qualcomm_JIT", - hw = "qualcomm", - models = ":single_op_models", - skips = [ - "greater_f32", # TODO: lukeboyer - Investigate (segfault). - "less_f32", # TODO: lukeboyer - Investigate (segfault). - ], -) - -litert_integration_test( - name = "classic_ml_device_tests_cpu", - hw = "cpu", - models = ":classic_ml_models", -) - -litert_integration_test( - name = "classic_ml_device_tests_qualcomm_JIT", - hw = "qualcomm", - models = ":classic_ml_models", -) - -litert_integration_test( - name = "pre_compiled_device_tests_qualcomm", - hw = "qualcomm", - models = ":pre_compiled_models", -) diff --git a/tensorflow/lite/experimental/litert/integration_test/gen_device_test.cc b/tensorflow/lite/experimental/litert/integration_test/gen_device_test.cc deleted file mode 100644 index fd2f05e70dd21f..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/gen_device_test.cc +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include // NOLINT -#include -#include -#include - -#include -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/integration_test/gen_device_test_lib.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" - -ABSL_FLAG(std::string, model_path, "", - "Tflite models to test. 
This can be a single tflite model or a " - "directory containing multiple tflite models."); -ABSL_FLAG(std::string, dispatch_library_dir, "/data/local/tmp/", - "Path to the dispatch library."); -ABSL_FLAG(std::string, compiler_library_dir, "/data/local/tmp/", - "Path to the compiler plugin library."); -ABSL_FLAG(std::string, hw, "cpu", "Which accelerator to use."); -ABSL_FLAG(std::vector, skips, std::vector{}, - "Substrings of models to skip."); - -namespace litert::test { -namespace { - -// UTILS /////////////////////////////////////////////////////////////////////// - -bool IsTfliteModel(const std::filesystem::path& path) { - return std::filesystem::is_regular_file(path) && - path.extension() == ".tflite"; -} - -std::vector GetModelPaths(const std::string& model_path_str) { - std::filesystem::path model_path = model_path_str; - std::vector models; - if (std::filesystem::is_directory(model_path)) { - for (const auto& entry : std::filesystem::directory_iterator(model_path)) { - if (!IsTfliteModel(entry.path())) { - continue; - } - models.push_back(entry.path().generic_string()); - } - return models; - } - - if (IsTfliteModel(model_path)) { - return {model_path.generic_string()}; - } - - return {}; -} - -std::string ModelName(const std::filesystem::path& path) { - return path.filename().replace_extension().generic_string(); -} - -} // namespace - -// FIXTURES //////////////////////////////////////////////////////////////////// - -class GenDeviceTestFixt : public ::testing::Test {}; - -// A test that simply calls the model and ensures it doesn't crash. -// Works with any accelerator. -template -class InvokeOnceTest : public GenDeviceTestFixt { - public: - InvokeOnceTest(std::string model_path, std::string dispatch_library_dir, - std::string compiler_library_dir) - : model_path_(std::move(model_path)), - dispatch_library_dir_(std::move(dispatch_library_dir)), - compiler_library_dir_(std::move(compiler_library_dir)) {} - - // Opens model and initializes the underlying invoker. - void SetUp() override { - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - absl::string_view(dispatch_library_dir_), - }, - litert::Environment::Option{ - litert::Environment::OptionTag::CompilerPluginLibraryDir, - absl::string_view(compiler_library_dir_), - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN( - auto env, litert::Environment::Create(environment_options)); - - LITERT_ASSERT_OK_AND_ASSIGN(auto model, - litert::Model::CreateFromFile(model_path_)); - litert::internal::Dump(*model.Get()); - - invoker_ = std::make_unique(std::move(env), std::move(model)); - invoker_->MaybeSkip(); - ASSERT_NO_FATAL_FAILURE(invoker_->Setup()); - } - - void TestBody() override { ASSERT_NO_FATAL_FAILURE(invoker_->Run()); } - - private: - std::string model_path_; - std::string dispatch_library_dir_; - std::string compiler_library_dir_; - - CmInvoker::Ptr invoker_; -}; - -// REGISTRATION //////////////////////////////////////////////////////////////// - -// Registers tests dynamically based on the hw flag and the model_path flag. -void ParseTests() { - auto model_path_flag = absl::GetFlag(FLAGS_model_path); - // Provide a sensible default based on linux/android. - if (model_path_flag.empty()) { -#if defined(__ANDROID__) - model_path_flag = "/data/local/tmp/"; -#else - // Set this on linux for smoke check linux presubmit. 
-    model_path_flag = testing::GetLiteRtPath(
-        "integration_test/single_op_models/add_f32.tflite");
-#endif
-  }
-  const auto model_paths = GetModelPaths(model_path_flag);
-  const auto hw = absl::GetFlag(FLAGS_hw);
-  const auto dispatch_library_dir = absl::GetFlag(FLAGS_dispatch_library_dir);
-  const auto compiler_library_dir = absl::GetFlag(FLAGS_compiler_library_dir);
-  const auto skips = absl::GetFlag(FLAGS_skips);
-
-  LITERT_LOG(LITERT_INFO, "hw: %s", hw.c_str());
-  LITERT_LOG(LITERT_INFO, "model_path: %s", model_path_flag.c_str());
-  LITERT_LOG(LITERT_INFO, "dispatch_library_dir: %s",
-             dispatch_library_dir.c_str());
-  LITERT_LOG(LITERT_INFO, "compiler_library_dir: %s",
-             compiler_library_dir.c_str());
-  LITERT_LOG(LITERT_INFO, "skips: %s", absl::StrJoin(skips, ",").c_str());
-
-  if (model_paths.empty()) {
-    LITERT_LOG(LITERT_WARNING, "No models found to test.");
-    return;
-  }
-
-  for (const auto& model_path : model_paths) {
-    LITERT_LOG(LITERT_INFO, "model_path: %s", model_path.c_str());
-
-    const auto test_name = absl::StrFormat("%s_%s", ModelName(model_path), hw);
-    const auto should_skip =
-        std::any_of(skips.cbegin(), skips.cend(), [&](const auto& skip) {
-          return (model_path.find(skip) != std::string::npos);
-        });
-
-    ::testing::RegisterTest(
-        "GenDeviceTest", test_name.c_str(), nullptr, nullptr, __FILE__,
-        __LINE__, [=]() -> GenDeviceTestFixt* {
-          if (should_skip) {
-            return new InvokeOnceTest<SkippedCmInvoker>(
-                model_path, dispatch_library_dir, compiler_library_dir);
-          } else if (hw == "npu") {
-            return new InvokeOnceTest<CmNpuInvoker>(
-                model_path, dispatch_library_dir, compiler_library_dir);
-          } else {
-            return new InvokeOnceTest<CmCpuInvoker>(
-                model_path, dispatch_library_dir, compiler_library_dir);
-          }
-        });
-  }
-}
-
-}  // namespace litert::test
-
-int main(int argc, char** argv) {
-  ::testing::InitGoogleTest(&argc, argv);
-  absl::ParseCommandLine(argc, argv);
-  litert::test::ParseTests();
-  return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/lite/experimental/litert/integration_test/gen_device_test_lib.h b/tensorflow/lite/experimental/litert/integration_test/gen_device_test_lib.h
deleted file mode 100644
index b2e585d9a277b4..00000000000000
--- a/tensorflow/lite/experimental/litert/integration_test/gen_device_test_lib.h
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
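ParseTests() above registers one gtest case per discovered .tflite file at runtime instead of declaring static TEST_F cases. A minimal, self-contained sketch of that ::testing::RegisterTest pattern follows; the fixture, test names, and model list are invented placeholders for illustration, not the real single-op model set:

#include <string>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

class ModelFixture : public ::testing::Test {};

// One dynamically created test per "model"; TestBody() is the test itself.
class OneModelTest : public ModelFixture {
 public:
  explicit OneModelTest(std::string path) : path_(std::move(path)) {}
  void TestBody() override { ASSERT_FALSE(path_.empty()); }

 private:
  std::string path_;
};

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  // Placeholder list; the real driver discovers .tflite files on disk.
  const std::vector<std::string> models = {"add_f32", "mul_f32"};
  for (const auto& model : models) {
    ::testing::RegisterTest(
        "GenDeviceTest", model.c_str(), nullptr, nullptr, __FILE__, __LINE__,
        // The factory returns a fixture pointer; gtest takes ownership.
        [=]() -> ModelFixture* { return new OneModelTest(model); });
  }
  return RUN_ALL_TESTS();
}

gtest copies the suite and test name strings during registration, so passing pointers into loop-local strings is safe; the factory lambda captures by value for the same reason.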
-
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/core/model/model.h"
-#include "tensorflow/lite/experimental/litert/test/matchers.h"
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_INTEGRATION_TEST_GEN_DEVICE_TEST_LIB_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_INTEGRATION_TEST_GEN_DEVICE_TEST_LIB_H_
-
-namespace litert::test {
-
-// Abstract wrapper for the invocation of the compiled model API within a
-// standard test environment.
-class CmInvoker {
- public:
-  using Ptr = std::unique_ptr<CmInvoker>;
-
-  CmInvoker(Environment&& env, Model&& model)
-      : env_(std::move(env)), model_(std::move(model)) {}
-
-  // Sets up the compiled model API and initializes the input and output
-  // buffers. Assumes the default signature.
-  void Setup() {
-    LITERT_ASSERT_OK_AND_ASSIGN(
-        compiled_model_, CompiledModel::Create(env_, model_, Accelerator()));
-    const auto sig = model_.DefaultSignatureKey();
-    LITERT_ASSERT_OK_AND_ASSIGN(input_buffers_,
-                                compiled_model_.CreateInputBuffers(sig));
-    LITERT_ASSERT_OK_AND_ASSIGN(output_buffers_,
-                                compiled_model_.CreateOutputBuffers(sig));
-  }
-
-  // Invokes the compiled model API. Must be called after Setup().
-  void Run() {
-    ASSERT_TRUE(compiled_model_.Run(model_.DefaultSignatureKey(),
-                                    input_buffers_, output_buffers_));
-  }
-
-  // Is this test in a state where it should be skipped? Implementations should
-  // call GTEST_SKIP().
-  virtual void MaybeSkip() const = 0;
-
-  // Which accelerator option to use.
-  virtual LiteRtHwAccelerators Accelerator() const = 0;
-
-  std::vector<TensorBuffer>& GetInputBuffers() { return input_buffers_; }
-  std::vector<TensorBuffer>& GetOutputBuffers() { return output_buffers_; }
-
-  virtual ~CmInvoker() = default;
-
- protected:
-  Environment env_;
-  Model model_;
-
-  CompiledModel compiled_model_;
-  std::vector<TensorBuffer> input_buffers_;
-  std::vector<TensorBuffer> output_buffers_;
-};
-
-class SkippedCmInvoker : public CmInvoker {
- public:
-  SkippedCmInvoker(Environment&& env, Model&& model)
-      : CmInvoker(std::move(env), std::move(model)) {}
-  void MaybeSkip() const override {
-    GTEST_SKIP() << "User requested skip for this model.";
-  }
-
-  LiteRtHwAccelerators Accelerator() const override {
-    return kLiteRtHwAcceleratorNone;
-  }
-};
-
-// Invocation of the compiled model API for the NPU accelerator. This handles
-// both JIT and pre-compiled models.
-class CmNpuInvoker : public CmInvoker {
- public:
-  CmNpuInvoker(Environment&& env, Model&& model)
-      : CmInvoker(std::move(env), std::move(model)) {}
-
-  // Returns true if invocation will require JIT compilation.
-  bool IsJit() const {
-    auto& m = *model_.Get();
-    return !IsCompiled(m);
-  }
-
-  LiteRtHwAccelerators Accelerator() const override {
-    return IsJit() ? kLiteRtHwAcceleratorNpu : kLiteRtHwAcceleratorNone;
-  }
-
-  void MaybeSkip() const override {
-#if !defined(__ANDROID__)
-    GTEST_SKIP() << "NPU test must run on an android device.";
-#endif
-  }
-};
-
-// Invocation of the compiled model API on CPU. This can run on linux in
-// addition to android.
-class CmCpuInvoker : public CmInvoker {
- public:
-  CmCpuInvoker(Environment&& env, Model&& model)
-      : CmInvoker(std::move(env), std::move(model)) {}
-
-  LiteRtHwAccelerators Accelerator() const override {
-    return kLiteRtHwAcceleratorCpu;
-  }
-
-  void MaybeSkip() const override {
-    if (IsCompiled(*model_.Get())) {
-      GTEST_SKIP() << "Cannot run CPU test on a compiled model.";
-    }
-  }
-};
-
-}  // namespace litert::test
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_INTEGRATION_TEST_GEN_DEVICE_TEST_LIB_H_
diff --git a/tensorflow/lite/experimental/litert/integration_test/run_on_device.bzl b/tensorflow/lite/experimental/litert/integration_test/run_on_device.bzl
deleted file mode 100644
index 3d229478e7c4c4..00000000000000
--- a/tensorflow/lite/experimental/litert/integration_test/run_on_device.bzl
+++ /dev/null
@@ -1,296 +0,0 @@
-# Copyright 2025 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This module defines the `run_on_device` macro, which helps execute a binary target on a device.
-"""
-
-load("//tensorflow:tensorflow.bzl", "if_oss")
-load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "absolute_label")
-
-# DEVICE PATHS #####################################################################################
-
-DEVICE_RLOCATION_ROOT = "/data/local/tmp/runfiles"
-
-def device_rlocation(label = None, get_parent = False):
-    """Get the path on device for a given label.
-
-    Args:
-        label: The label to get the path for. If None, returns the root path.
-        get_parent: If true, get the parent directory of the resolved path.
-
-    Returns:
-        The path on device for the given label.
-    """
-    if not label:
-        return DEVICE_RLOCATION_ROOT
-    abs_label = absolute_label(label)
-    res = DEVICE_RLOCATION_ROOT + "/" + abs_label.replace("//", "").replace(":", "/")
-    if get_parent:
-        return res[:res.rfind("/")]
-    return res
-
-def make_path_args(spec):
-    """Formats shell path-like variable assignment expressions from common directories in the given labels.
-
-    Useful for making things like LD_LIBRARY_PATH=... for paths on device.
-
-    An entry of the spec contains a key and a list of labels. Unique leaf directory paths are
-    extracted from the labels and joined into a colon-separated string.
-
-    Example:
-    ```
-    make_path_args({
-        "LD_LIBRARY_PATH": [
-            "//foo:bar",
-        ],
-        "ADSP_LIBRARY_PATH": [
-            "//foo:baz",
-            "//foo:bat",
-        ],
-    })
-    ```
-    will return:
-    ```
-    LD_LIBRARY_PATH=/data/local/tmp/runfiles/foo/bar
-    ADSP_LIBRARY_PATH=/data/local/tmp/runfiles/foo/baz:/data/local/tmp/runfiles/foo/bat
-    ```
-
-    Args:
-        spec: A dict of path variable names to lists of labels.
-
-    Returns:
-        A list of shell variable assignment expressions.
-    """
-
-    res = []
-    for path_var, values in spec.items():
-        # TODO: Figure out why OSS doesn't have `set` core datatype.
- dirs = [] - for v in values: - parent = device_rlocation(v, True) - if parent not in dirs: - dirs.append(parent) - res.append("{path_var}={paths}".format( - path_var = path_var, - paths = ":".join(dirs), - )) - return res - -# DYNAMIC LIBRARY DEPENDENCIES ##################################################################### - -LITERT_CORE_LIBS = [ - "//tensorflow/lite/experimental/litert/c:libLiteRtRuntimeCApi.so", -] - -def make_lib_spec(**kwargs): - return struct( - litert_base_libs = LITERT_CORE_LIBS, - core_libs = kwargs["core_libs"], - kernel_libs = kwargs["kernel_libs"], - dispatch_lib = kwargs["dispatch_lib"], - compiler_lib = kwargs["compiler_lib"], - ) - -BASE_LIB_SPEC = make_lib_spec( - core_libs = [], - kernel_libs = [], - dispatch_lib = None, - compiler_lib = None, -) - -def all_libs(spec): - """ - Returns all the dynamic libraries needed for the given spec. - - Args: - spec: The lib spec to get the libs for. - - Returns: - A list of all the dynamic libraries needed for the given spec. - """ - libs = spec.litert_base_libs + spec.core_libs + spec.kernel_libs - for lib in [spec.dispatch_lib, spec.compiler_lib]: - if lib: - libs.append(lib) - return libs - -# QNN - -QUALCOMM_LIB_SPEC = make_lib_spec( - core_libs = [ - "//third_party/qairt/latest:lib/aarch64-android/libQnnHtp.so", - "//third_party/qairt/latest:lib/aarch64-android/libQnnHtpV75Stub.so", - "//third_party/qairt/latest:lib/aarch64-android/libQnnSystem.so", - "//third_party/qairt/latest:lib/aarch64-android/libQnnHtpPrepare.so", - ], - kernel_libs = ["//third_party/qairt/latest:lib/hexagon-v75/unsigned/libQnnHtpV75Skel.so"], - dispatch_lib = "//tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so", - compiler_lib = "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:qnn_compiler_plugin_so", -) - -# MTK -# TODO - -# GOOGLE TENSOR -# TODO - -def get_lib_spec(backend_id): - """ - Returns the dynamic library spec for the given backend id. - - Args: - backend_id: The backend id to get the lib spec for. - - Returns: - The dynamic library spec for the given backend id. - """ - if backend_id == "qualcomm": - return QUALCOMM_LIB_SPEC - if backend_id == "cpu": - return BASE_LIB_SPEC - else: - fail("Unsupported backend id: {}".format(backend_id)) - -# RUN ON DEVICE MACRO ############################################################################## - -def get_driver(): - return if_oss( - "//tensorflow/lite/experimental/litert/integration_test:run_on_device_driver_OSS", - "//tensorflow/lite/experimental/litert/integration_test/google:run_on_device_driver", - ) - -def run_on_device( - name, - target, - driver, - data = [], - exec_args = [], - exec_env_vars = []): - """ - Macro to execute a binary target on a device (locally through ADB). - - The output of this macro is an executable shell script that pushes all the necessary files to - the device and executes the target with the given arguments and environment variables. - - Args: - name: Name of the target. - target: The binary target to execute on device. - driver: The driver script to use for execution. - data: List of data files to push to the device. - exec_args: List of arguments to pass to the executable. - exec_env_vars: List of environment variables to set before executing the target. 
- """ - call_mobile_install = """ - echo '$(location {driver}) \ - --bin=$(rlocationpath {target}) \ - --data={data} \ - --do_exec=true \ - --exec_args={exec_args} \ - --exec_env_vars={exec_env_vars} \ - '\ - > $@ - """ - - concat_targ_data = "$$(echo \"$(rlocationpaths {})\" | sed \"s/ /,/g\")" - data_str = ",".join([concat_targ_data.format(d) for d in data]) - - # NOTE: Tilde delimiter here (also see driver script) to allow passing list args to underlying - # binary. - exec_args_str = "~".join(["{}".format(a) for a in exec_args]) - exec_env_vars_str = ",".join(["{}".format(a) for a in exec_env_vars]) - - driver_targ = driver.removesuffix(".sh") - driver_sh = driver_targ + ".sh" - - cmd = call_mobile_install.format( - driver = driver_sh, - target = target, - data = data_str, - exec_args = exec_args_str, - exec_env_vars = exec_env_vars_str, - ) - - exec_script = name + "_exec.sh" - - native.genrule( - name = name + "_gen_script", - srcs = [driver_sh] + [target] + data, - outs = [exec_script], - tags = ["manual", "notap"], - cmd = cmd, - testonly = True, - ) - - native.sh_binary( - testonly = True, - tags = ["manual", "notap"], - name = name, - deps = [driver_targ], - srcs = [exec_script], - data = [target] + data, - ) - -def litert_integration_test( - name, - models, - hw = "cpu", - skips = []): - """ - Higher level macro that configures run_on_device or a mobile test to run with gen_device_test. - - Args: - name: Name of the target. - models: A single target that may contain model or many models in the same directory. - hw: The backend to test against (see gen_device_test). - skips: List of substrings of models to skip. - """ - - # Get libs for the given backend. - lib_spec = get_lib_spec(hw) - - # Accelerator option to pass to the compiled model api on device. - hw_cfg = hw if hw == "cpu" else "npu" - - # Create env args for paths to dynamic libraries. - env_args = make_path_args({ - "LD_LIBRARY_PATH": lib_spec.litert_base_libs + lib_spec.core_libs + [lib_spec.dispatch_lib, lib_spec.compiler_lib], - "ADSP_LIBRARY_PATH": lib_spec.kernel_libs, - }) - - skips_str = ",".join(skips) - - # Create CLI args for the gen_device_test binary on device. - cli_args = [ - "--model_path={}".format(device_rlocation(models)), - "--dispatch_library_dir={}".format(device_rlocation(lib_spec.dispatch_lib, True)), - "--compiler_library_dir={}".format(device_rlocation(lib_spec.compiler_lib, True)), - "--hw={}".format(hw_cfg), - "--skips={}".format(skips_str), - ] - - data = [models] + all_libs(lib_spec) - driver = get_driver() - target = "//tensorflow/lite/experimental/litert/integration_test:gen_device_test" - - # TODO: Also kick off a xeno mobile test here. - - run_on_device( - name = name, - target = target, - driver = driver, - data = data, - exec_args = cli_args, - exec_env_vars = env_args, - ) diff --git a/tensorflow/lite/experimental/litert/integration_test/run_on_device_driver_OSS.sh b/tensorflow/lite/experimental/litert/integration_test/run_on_device_driver_OSS.sh deleted file mode 100755 index 6fa24babd8b328..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/run_on_device_driver_OSS.sh +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/bin/bash - -# TODO: Implement this script to leverage integration tests in OSS. - -# This script must handle the following flags: - -# DEFINE_string --required bin "" "The binary to execute on the device." -# DEFINE_array data --type=string "" "The data files to install on the device." -# DEFINE_bool do_exec false "Whether to execute the target on the device." -# DEFINE_array exec_args --type=string "" "The arguments to pass to the executable on device." -# DEFINE_array exec_env_vars --type=string "" "The environment variables to set for the executable on device." -# DEFINE_string device_rlocation_root "/data/local/tmp/runfiles" "The root directory for device relative locations." - -# This script must push the bin file and all the data files to the device under -# the device_rlocation_root directory. If do_exec is true, it must execute the -# binary on the device with the given exec_args and exec_env_vars. \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/add_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/add_f32.mlir deleted file mode 100644 index d4e9d5f59da6dc..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/add_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xf32> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.add"}} { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<256x256xf32> - return %0 : tensor<256x256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/concatenate_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/concatenate_f32.mlir deleted file mode 100644 index ff1d3172f76e36..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/concatenate_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x4x2xf32>, %arg2: tensor<2x1x2xf32>) -> tensor<2x8x2xf32> attributes {tf.entry_function = {inputs = "arg0,arg1,arg2", outputs = "tfl.concatenation"}} { - %0 = "tfl.concatenation"(%arg0, %arg1, %arg2) <{axis = 1 : i32, fused_activation_function = "NONE"}> : (tensor<2x3x2xf32>, tensor<2x4x2xf32>, tensor<2x1x2xf32>) -> tensor<2x8x2xf32> - return %0 : tensor<2x8x2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/divide_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/divide_f32.mlir deleted file mode 100644 index 8bb1cf1b5f95af..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/divide_f32.mlir +++ /dev/null @@ 
-1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.6.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xf32> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.div"}} { - %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<256x256xf32> - return %0 : tensor<256x256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/greater_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/greater_f32.mlir deleted file mode 100644 index 00fef7d8448236..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/greater_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xi1> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.greater"}} { - %0 = tfl.greater(%arg0, %arg1) : (tensor<256x256xf32>, tensor<256x256xf32>) -> tensor<256x256xi1> - return %0 : tensor<256x256xi1> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/less_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/less_f32.mlir deleted file mode 100644 index 0c59c9e5c889ad..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/less_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xi1> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.less"}} { - %0 = tfl.less(%arg0, %arg1) : (tensor<256x256xf32>, tensor<256x256xf32>) -> tensor<256x256xi1> - return %0 : tensor<256x256xi1> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/multiply_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/multiply_f32.mlir deleted file mode 100644 index 3390ac72910615..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/multiply_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xf32> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.mul"}} { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<256x256xf32> - return %0 : tensor<256x256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32.mlir deleted file mode 100644 index 342cfcc69fa61c..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = 
"1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<3x4xf32>) -> tensor<4x3xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.reshape"}} { - %0 = "tfl.pseudo_const"() <{value = dense<[4, 3]> : tensor<2xi32>}> : () -> tensor<2xi32> - %1 = "tfl.reshape"(%arg0, %0) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<4x3xf32> - return %1 : tensor<4x3xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32_large_rank.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32_large_rank.mlir deleted file mode 100644 index 63df1776dd7b25..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32_large_rank.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<2x3x4x5x6x7x8xf32>) -> tensor<8x7x6x5x4x3x2xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.reshape"}} { - %0 = "tfl.pseudo_const"() <{value = dense<[8, 7, 6, 5, 4, 3, 2]> : tensor<7xi32>}> : () -> tensor<7xi32> - %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3x4x5x6x7x8xf32>, tensor<7xi32>) -> tensor<8x7x6x5x4x3x2xf32> - return %1 : tensor<8x7x6x5x4x3x2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/rsqrt_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/rsqrt_f32.mlir deleted file mode 100644 index 51e03dc9cdcdaf..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/rsqrt_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.10.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256xf32>) -> tensor<256xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.rsqrt"}} { - %0 = "tfl.rsqrt"(%arg0) : (tensor<256xf32>) -> tensor<256xf32> - return %0 : tensor<256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/select_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/select_f32.mlir deleted file mode 100644 index c37db98eee2114..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/select_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<2x2xi1>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf.entry_function = {inputs = "arg0,arg1,arg2", outputs = "tfl.select"}} { - %0 = "tfl.select"(%arg0, %arg1, %arg2) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/slice_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/slice_f32.mlir deleted file mode 100644 index 50b62c65be8ff2..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/slice_f32.mlir +++ /dev/null @@ -1,8 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = 
"1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.slice"}} { - %0 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32> - %1 = "tfl.pseudo_const"() <{value = dense<2> : tensor<2xi32>}> : () -> tensor<2xi32> - %2 = "tfl.slice"(%arg0, %0, %1) : (tensor<3x4xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xf32> - return %2 : tensor<2x2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/subtract_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/subtract_f32.mlir deleted file mode 100644 index 4265a8c95eeabd..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/subtract_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.6.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xf32> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.sub"}} { - %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<256x256xf32> - return %0 : tensor<256x256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/tanh_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/tanh_f32.mlir deleted file mode 100644 index 022392b4f294a4..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/tanh_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256xf32>) -> tensor<256xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.tanh"}} { - %0 = "tfl.tanh"(%arg0) : (tensor<256xf32>) -> tensor<256xf32> - return %0 : tensor<256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/python/BUILD b/tensorflow/lite/experimental/litert/python/BUILD deleted file mode 100644 index eeab7c9f0c21b0..00000000000000 --- a/tensorflow/lite/experimental/litert/python/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) diff --git a/tensorflow/lite/experimental/litert/runtime/BUILD b/tensorflow/lite/experimental/litert/runtime/BUILD deleted file mode 100644 index a0e34ead449a1c..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/BUILD +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2024 Google LLC. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "gtest_main_no_heapcheck_deps") -load("//tensorflow/lite/experimental/litert/build_common:special_rule.bzl", "gles_deps", "gles_linkopts", "lite_rt_friends") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "event", - srcs = [ - "event.cc", - ], - hdrs = [ - "event.h", - ], - deps = [ - ":gpu_environment", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_event", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_library( - name = "tensor_buffer", - srcs = [ - "ahwb_buffer.cc", - "dmabuf_buffer.cc", - "fastrpc_buffer.cc", - "gl_buffer.cc", - "gl_texture.cc", - "ion_buffer.cc", - "open_cl_buffer.cc", - "tensor_buffer.cc", - ], - hdrs = [ - "ahwb_buffer.h", - "dmabuf_buffer.h", - "event.h", - "fastrpc_buffer.h", - "gl_buffer.h", - "gl_texture.h", - "ion_buffer.h", - "open_cl_buffer.h", - "tensor_buffer.h", - "tensor_buffer_requirements.h", - "//tensorflow/lite/experimental/litert/c:litert_event.h", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer.h", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_requirements.h", - ], - linkopts = gles_linkopts(), - deps = [ - ":gpu_environment", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event_type", - "//tensorflow/lite/experimental/litert/c:litert_gl_types", - "//tensorflow/lite/experimental/litert/c:litert_layout", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/cc:litert_event", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer_utils", - "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", - "//tensorflow/lite/experimental/litert/runtime/opencl:buffer", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_command_queue", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_context", - "//tensorflow/lite/experimental/litert/runtime/opencl:opencl_wrapper", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@opencl_headers", 
- ] + gles_deps() + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:gl_buffer", - "//tensorflow/lite/delegates/gpu/gl:gl_texture", - ], - "//conditions:default": [], - }), -) - -cc_library( - name = "gpu_environment", - srcs = [ - "gpu_environment.cc", - ], - hdrs = [ - "gpu_environment.h", - ], - visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - "//tensorflow/lite/experimental/litert/c:__subpackages__", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_command_queue", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_context", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_device", - "//tensorflow/lite/experimental/litert/runtime/opencl:opencl_wrapper", - "@opencl_headers", - ], -) - -cc_test( - name = "gpu_environment_test", - srcs = ["gpu_environment_test.cc"], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "requires-gpu-nvidia", - ], - deps = [ - ":gpu_environment", - "@com_google_googletest//:gtest_main", - # copybara:uncomment_begin(google-only) - # "//third_party/ml_drift/cl:environment", - # "//third_party/ml_drift/cl:opencl_wrapper", - # copybara:uncomment_end - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/runtime/opencl:opencl_wrapper", - ], -) - -cc_library( - name = "tfl_utils", - srcs = [ - "tfl_utils.cc", - ], - hdrs = [ - "tfl_utils.h", - ], - deps = [ - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - ], -) - -cc_library( - name = "external_litert_buffer_context", - srcs = ["external_litert_buffer_context.cc"], - hdrs = ["external_litert_buffer_context.h"], - visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - ] + lite_rt_friends(), - deps = [ - ":tfl_utils", - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer_requirements", - ], -) - -cc_library( - name = "compiled_model", - srcs = ["compiled_model.cc"], - hdrs = ["compiled_model.h"], - deps = [ - ":accelerator", - ":accelerator_model_compilation_data", - ":compilation_options", - ":external_litert_buffer_context", - ":tensor_buffer", - "//tensorflow/compiler/mlir/lite:allocation", - "//tensorflow/lite:builtin_ops", - 
"//tensorflow/lite:framework", - "//tensorflow/lite:model_builder", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/delegates/utils:simple_opaque_delegate", - "//tensorflow/lite/experimental/litert/c:litert_accelerator", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_event", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer_requirements", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:model_serialize", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "compiled_model_test", - srcs = ["compiled_model_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - # require GPU to run OpenCL tests. 
- tags = [ - "requires-gpu-nvidia", - ], - deps = [ - ":compiled_model", - ":tensor_buffer", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "compilation_options", - hdrs = [ - "compilation_options.h", - "//tensorflow/lite/experimental/litert/c:litert_compilation_options.h", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_layout", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_accelerator_compilation_options", - ], -) - -cc_test( - name = "gl_buffer_test", - srcs = ["gl_buffer_test.cc"], - linkopts = select({ - "//tensorflow:android": [ - "-landroid", - ], - "//conditions:default": [], - }), - tags = [ - "notap", - ], - deps = [ - ":tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/test:matchers", - ] + gtest_main_no_heapcheck_deps() + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:egl_environment", - "//tensorflow/lite/delegates/gpu/gl:gl_buffer", - ], - "//conditions:default": [], - }), -) - -cc_library( - name = "tensor_buffer_conversion", - srcs = ["tensor_buffer_conversion.cc"], - hdrs = ["tensor_buffer_conversion.h"], - linkopts = gles_linkopts(), - deps = [ - ":tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer_utils", - "@com_google_absl//absl/strings:str_format", - ] + gles_deps(), -) - -cc_test( - name = "tensor_buffer_conversion_test", - srcs = ["tensor_buffer_conversion_test.cc"], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "notap", - ], - deps = [ - ":tensor_buffer", - ":tensor_buffer_conversion", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - 
"//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:egl_environment", - ], - "//conditions:default": [], - }), -) - -cc_library( - name = "accelerator", - hdrs = ["accelerator.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - ], -) - -cc_library( - name = "accelerator_registry", - srcs = ["accelerator_registry.cc"], - hdrs = ["accelerator_registry.h"], - deps = [ - ":accelerator", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - ], -) - -cc_library( - name = "accelerator_model_compilation_data", - hdrs = ["accelerator_model_compilation_data.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) - -cc_test( - name = "accelerator_model_compilation_data_test", - srcs = ["accelerator_model_compilation_data_test.cc"], - deps = [ - ":accelerator_model_compilation_data", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/core:version", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator.h b/tensorflow/lite/experimental/litert/runtime/accelerator.h deleted file mode 100644 index 2574588482976f..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerator.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -// We need to forward declare this to avoid a dependency loop. -struct LiteRtCompiledModelT; -struct LiteRtEnvironmentT; - -struct LiteRtAcceleratorT { - // Points to the type-erased accelerator state. - void* data; - - // Points to the environment that owns this accelerator. - LiteRtEnvironmentT* env; - - // NOLINTBEGIN(*-readability-class-member-naming) - - // Releases the the data. - // - // This function is used by the framework to clean up the accelerator. It - // should not be called by client code. 
-  void (*ReleaseData)(void*);
-
-  // Retrieves the accelerator name.
-  LiteRtStatus (*GetName)(LiteRtAcceleratorT* accelerator, const char** name);
-
-  // Retrieves the accelerator version.
-  LiteRtStatus (*GetVersion)(LiteRtAcceleratorT* accelerator,
-                             LiteRtApiVersion* version);
-
-  // Retrieves the accelerator hardware support.
-  LiteRtStatus (*GetHardwareSupport)(
-      LiteRtAcceleratorT* accelerator,
-      LiteRtHwAcceleratorSet* supported_hardware);
-
-  // Creates a delegate for the accelerator.
-  // Uses void** instead of TfLiteOpaqueDelegate** to avoid a TFLite dependency.
-  LiteRtStatus (*CreateDelegate)(
-      LiteRtAcceleratorT* accelerator,
-      LiteRtAcceleratorCompilationOptions compilation_options, void** delegate);
-
-  // Destroys a delegate created for the accelerator.
-  // The signature matches the existing TfLiteOpaqueDelegate interface.
-  // Uses void* instead of TfLiteOpaqueDelegate* to avoid a TFLite dependency.
-  void (*DestroyDelegate)(void* delegate);
-
-  LiteRtStatus (*IsTfLiteDelegateResponsibleForJitCompilation)(
-      LiteRtAcceleratorT* accelerator, bool* does_jit_compilation);
-
-  // NOLINTEND(*-readability-class-member-naming)
-};
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_H_
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h b/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h
deleted file mode 100644
index 9f465134533d25..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_MODEL_COMPILATION_DATA_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_MODEL_COMPILATION_DATA_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-namespace litert::internal {
-
-// Holds environment data that accelerators may need to prepare their
-// delegates.
-//
-// These options are automatically added to the compilation options list
-// during the creation of the compiled model.
-struct ModelCompilationData {
-  static constexpr LiteRtApiVersion kVersion = {1, 0, 0};
-  static constexpr auto kIdentifier = "environment-compilation-options";
-
-  static Expected<AcceleratorCompilationOptions> CreateOptions() {
-    auto* payload_data = new ModelCompilationData;
-    auto payload_destructor = [](void* payload_data) {
-      delete reinterpret_cast<ModelCompilationData*>(payload_data);
-    };
-    return AcceleratorCompilationOptions::Create(
-        kVersion, kIdentifier, payload_data, payload_destructor);
-  }
-
-  // Pointer to the start of the model file memory allocation.
-  const char* allocation_base;
-  // File descriptor of the model file memory allocation. If there is no such
-  // file descriptor, this must be set to -1.
-  int allocation_fd;
-
- private:
-  ModelCompilationData() = default;
-};
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_MODEL_COMPILATION_DATA_H_
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data_test.cc b/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data_test.cc
deleted file mode 100644
index 8ebb1aa5426ba5..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data_test.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/core/version.h"
-#include "tensorflow/lite/experimental/litert/test/matchers.h"
-
-namespace {
-
-using testing::Eq;
-using testing::StrEq;
-
-TEST(ModelCompilationDataTest, CreateSetsUpAllNecessaryFields) {
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      auto options, litert::internal::ModelCompilationData::CreateOptions());
-
-  LITERT_ASSERT_OK_AND_ASSIGN(auto identifier, options.GetIdentifier());
-  EXPECT_THAT(identifier,
-              StrEq(litert::internal::ModelCompilationData::kIdentifier));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(auto version, options.GetVersion());
-  EXPECT_TRUE(litert::internal::IsSameVersion(
-      version, litert::internal::ModelCompilationData::kVersion));
-}
-
-}  // namespace
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_registry.cc b/tensorflow/lite/experimental/litert/runtime/accelerator_registry.cc
deleted file mode 100644
index c74577d7f26ef8..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerator_registry.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
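ModelCompilationData::CreateOptions() above hands the C layer a heap-allocated payload together with a type-erasing destructor callback, which is how a C-style options list can own arbitrary C++ state. A standalone sketch of that handshake, assuming nothing beyond standard C++; Payload and Options here are illustrative stand-ins, not the real LiteRT types:

#include <cstdio>

// Stand-in for the model-compilation payload (illustrative only).
struct Payload {
  const char* allocation_base = nullptr;
  int allocation_fd = -1;  // -1 means "no backing file descriptor".
};

// Stand-in for the type-erased options object: a void* plus a destructor
// callback lets a C API own arbitrary C++ state without knowing its type.
struct Options {
  void* data = nullptr;
  void (*destructor)(void*) = nullptr;
  ~Options() {
    if (destructor != nullptr) destructor(data);
  }
};

int main() {
  Options options;
  options.data = new Payload;
  // Captureless lambda decays to a plain function pointer.
  options.destructor = [](void* p) { delete static_cast<Payload*>(p); };
  std::printf("fd: %d\n", static_cast<Payload*>(options.data)->allocation_fd);
  // The Options destructor frees the payload exactly once.
}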
- -#include "tensorflow/lite/experimental/litert/runtime/accelerator_registry.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" - -namespace litert::internal { - -void AcceleratorRegistry::DestroyAccelerator(LiteRtAcceleratorT* accelerator) { - if (accelerator && accelerator->ReleaseData) { - accelerator->env = nullptr; - accelerator->ReleaseData(accelerator->data); - } - delete accelerator; -} - -Expected AcceleratorRegistry::RegisterAccelerator( - Ptr accelerator) { - if (!accelerator) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Cannot register a null accelerator."); - } - accelerators_.push_back(std::move(accelerator)); - return accelerators_.back().get(); -} - -Expected AcceleratorRegistry::Get(LiteRtParamIndex idx) { - if (idx >= size()) { - return Error(kLiteRtStatusErrorNotFound, "Cannot find accelerator."); - } - return accelerators_[idx].get(); -} - -Expected AcceleratorRegistry::FindAcceleratorIndex( - LiteRtAcceleratorT* accelerator) { - for (size_t idx = 0; idx < accelerators_.size(); ++idx) { - if (accelerator == accelerators_[idx].get()) { - return static_cast(idx); - } - } - return Error(kLiteRtStatusErrorNotFound, - "The accelerator is not registered in the LiteRT environment."); -} - -void AcceleratorRegistry::TakeOwnershipOfSharedLibrary(SharedLibrary lib) { - accelerator_shared_libraries_.push_back(std::move(lib)); -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_registry.h b/tensorflow/lite/experimental/litert/runtime/accelerator_registry.h deleted file mode 100644 index 11c4feec022985..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerator_registry.h +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_REGISTRY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_REGISTRY_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -namespace litert::internal { - -// Holds a list of accelerators. -// -// This is a helper class for the LiteRT environment that manages the -// accelerators (and their resources) that are registered with it. -class AcceleratorRegistry { - public: - struct Deleter { - void operator()(LiteRtAcceleratorT* accelerator) { - DestroyAccelerator(accelerator); - } - }; - - // Wraps a pointer for LiteRtAcceleratorT with a custom deleter that handles - // cleaning up the accelerator internal data. 
-  using Ptr = std::unique_ptr<::LiteRtAcceleratorT, Deleter>;
-
-  // Internal implementation for the C API.
-  [[nodiscard]]
-  static Ptr CreateEmptyAccelerator() {
-    return Ptr(new LiteRtAcceleratorT());
-  }
-
-  // Internal implementation for the C API.
-  static void DestroyAccelerator(::LiteRtAcceleratorT* accelerator);
-
-  // Registers an accelerator.
-  Expected<LiteRtAcceleratorT*> RegisterAccelerator(Ptr accelerator);
-
-  // Returns the idx-th accelerator that was registered.
-  [[nodiscard]]
-  Expected<LiteRtAcceleratorT*> Get(LiteRtParamIndex idx);
-
-  // Goes through the accelerators and finds the index of the given one.
-  Expected<LiteRtParamIndex> FindAcceleratorIndex(
-      LiteRtAcceleratorT* accelerator);
-
-  // Gives ownership of the shared library to the registry.
-  //
-  // This should be called when an accelerator is loaded from a shared library
-  // to tie the library lifetime to the registry.
-  //
-  // The library will be closed when the registry is destroyed.
-  void TakeOwnershipOfSharedLibrary(SharedLibrary library);
-
-  // Returns the number of accelerators that have been registered.
-  size_t size() const { return accelerators_.size(); }
-  auto begin() const { return accelerators_.begin(); }
-  auto begin() { return accelerators_.begin(); }
-  auto end() const { return accelerators_.end(); }
-  auto end() { return accelerators_.end(); }
-
- private:
-  std::vector<Ptr> accelerators_;
-  // Some accelerators are loaded as shared libraries. This list keeps these
-  // libraries loaded while the environment uses them.
-  std::vector<SharedLibrary> accelerator_shared_libraries_;
-};
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_REGISTRY_H_
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_test.cc b/tensorflow/lite/experimental/litert/runtime/accelerator_test.cc
deleted file mode 100644
index 84f88d13b61c75..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerator_test.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -#include - -namespace litert::internal { -namespace { - -TEST(AcceleratorRegistryTest, CreateEmptyAcceleratorWorks) { - [[maybe_unused]] - auto accelerator_squeleton = AcceleratorRegistry::CreateEmptyAccelerator(); -} - -TEST(AcceleratorRegistryTest, AcceleratorCanBeRegisteredAndRetrieved) { - AcceleratorRegistry registry; - - auto registered_accelerator1 = registry.RegisterAccelerator( - AcceleratorRegistry::CreateEmptyAccelerator()); - ASSERT_TRUE(registered_accelerator1); - - auto registered_accelerator2 = registry.RegisterAccelerator( - AcceleratorRegistry::CreateEmptyAccelerator()); - ASSERT_TRUE(registered_accelerator2); - - ASSERT_NE(registered_accelerator1, registered_accelerator2); - - auto queried_accelerator1 = registry.Get(0); - ASSERT_TRUE(queried_accelerator1); - EXPECT_EQ(queried_accelerator1, registered_accelerator1); - - auto queried_accelerator2 = registry.Get(1); - ASSERT_TRUE(queried_accelerator2); - EXPECT_EQ(queried_accelerator2, registered_accelerator2); - - EXPECT_FALSE(registry.Get(2)); - EXPECT_FALSE(registry.Get(-1)); - - auto idx1 = registry.FindAcceleratorIndex(queried_accelerator1.Value()); - ASSERT_TRUE(idx1); - EXPECT_EQ(idx1.Value(), 0); - - auto idx2 = registry.FindAcceleratorIndex(queried_accelerator2.Value()); - ASSERT_TRUE(idx2); - EXPECT_EQ(idx2.Value(), 1); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/BUILD b/tensorflow/lite/experimental/litert/runtime/accelerators/BUILD deleted file mode 100644 index c007232c0c1702..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = [
-        "//tensorflow/lite/experimental/litert:__subpackages__",
-    ],
-)
-
-cc_library(
-    name = "auto_registration",
-    srcs = ["auto_registration.cc"],
-    hdrs = ["auto_registration.h"],
-    deps = [
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_logging",
-        "//tensorflow/lite/experimental/litert/cc:litert_expected",
-        "//tensorflow/lite/experimental/litert/cc:litert_macros",
-        "//tensorflow/lite/experimental/litert/cc:litert_shared_library",
-        "//tensorflow/lite/experimental/litert/core:environment",
-        "//tensorflow/lite/experimental/litert/runtime/accelerators/dispatch:dispatch_accelerator",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
-
-cc_library(
-    name = "accelerator_implementation_helper",
-    hdrs = ["accelerator_implementation_helper.h"],
-    deps = [
-        "//tensorflow/lite/experimental/litert/c:litert_accelerator",
-        "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options",
-        "//tensorflow/lite/experimental/litert/c:litert_accelerator_registration",
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/cc:litert_expected",
-        "//tensorflow/lite/experimental/litert/cc:litert_macros",
-        "//tensorflow/lite/experimental/litert/runtime:accelerator_model_compilation_data",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/accelerator_implementation_helper.h b/tensorflow/lite/experimental/litert/runtime/accelerators/accelerator_implementation_helper.h
deleted file mode 100644
index d44d6533267cfc..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerators/accelerator_implementation_helper.h
+++ /dev/null
@@ -1,147 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_ACCELERATOR_IMPLEMENTATION_HELPER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_ACCELERATOR_IMPLEMENTATION_HELPER_H_
-
-#include <memory>
-#include <utility>
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h"
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h"
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h"
-
-namespace litert::internal {
-
-struct AcceleratorDestructor {
-  void operator()(LiteRtAccelerator accelerator) {
-    LiteRtDestroyAccelerator(accelerator);
-  }
-};
-
-// RAII wrapper for accelerator handles.
-using AcceleratorGuard =
-    std::unique_ptr<std::pointer_traits<LiteRtAccelerator>::element_type,
-                    AcceleratorDestructor>;
-
-// Helps setting up an accelerator handle for accelerators that use the
-// `AcceleratorImplementationHelper` template as a base class.
-template <class T>
-Expected<void> SetAcceleratorBoilerplateFunctions(
-    AcceleratorGuard& accelerator) {
-  LITERT_RETURN_IF_ERROR(
-      LiteRtSetAcceleratorGetName(accelerator.get(), T::GetName));
-  LITERT_RETURN_IF_ERROR(
-      LiteRtSetAcceleratorGetVersion(accelerator.get(), T::GetVersion));
-  LITERT_RETURN_IF_ERROR(LiteRtSetAcceleratorGetHardwareSupport(
-      accelerator.get(), T::GetHardwareSupport));
-  LITERT_RETURN_IF_ERROR(LiteRtSetDelegateFunction(
-      accelerator.get(), T::CreateDelegate, T::DestroyDelegate));
-  return {};
-}
-
-// Goes through the options in the linked list and returns the model
-// compilation data if it exists.
-inline static Expected<const ModelCompilationData*> GetModelCompilationData(
-    LiteRtAcceleratorCompilationOptions options) {
-  LiteRtApiVersion payload_version;
-  void* payload_data;
-  LITERT_RETURN_IF_ERROR(LiteRtFindAcceleratorCompilationOptionsData(
-      options, litert::internal::ModelCompilationData::kIdentifier,
-      &payload_version, &payload_data));
-  return reinterpret_cast<const ModelCompilationData*>(payload_data);
-}
-
-// Helps accelerator implementation by providing a lot of the boilerplate
-// needed.
-//
-// Warning: The provided Ptr assumes that AcceleratorClass instances are
-// created using `operator new`.
-//
-// Warning: `version` should be incremented every time the code of this
-// accelerator is updated according to semantic versioning.
-template <class AcceleratorClass, const char* name_,
-          const LiteRtApiVersion& version_,
-          LiteRtHwAcceleratorSet hardware_support_>
-class AcceleratorImplementationHelper {
- public:
-  // The accelerator name returned by `GetName`.
-  constexpr static const absl::string_view kName = name_;
-  // The accelerator version returned by `GetVersion`.
-  constexpr static const LiteRtApiVersion kVersion = version_;
-  // The accelerator hardware support returned by `GetHardwareSupport`.
-  constexpr static const LiteRtHwAcceleratorSet kHwSupport = hardware_support_;
-
-  struct Deleter {
-    void operator()(AcceleratorClass* accelerator_impl) {
-      delete accelerator_impl;
-    }
-  };
-
-  // Owning pointer wrapping the accelerator.
-  using Ptr = std::unique_ptr<AcceleratorClass, Deleter>;
-
-  // Creates a new instance of the accelerator implementation.
-  template <class... Args>
-  static Ptr Allocate(Args&&... args) {
-    return Ptr(new AcceleratorClass(std::forward<Args>(args)...));
-  }
-
-  // Deletes the accelerator data.
-  static void Destroy(void* accelerator_impl) {
-    Deleter()(reinterpret_cast<AcceleratorClass*>(accelerator_impl));
-  }
-
-  // Returns the accelerator's name by setting `name`.
-  static LiteRtStatus GetName(LiteRtAccelerator accelerator,
-                              const char** name) {
-    LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Accelerator handle is invalid.");
-    LITERT_ENSURE(name != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Name pointer is null.");
-    *name = kName.data();
-    return kLiteRtStatusOk;
-  }
-
-  // Returns the accelerator's version by setting `version`.
-  static LiteRtStatus GetVersion(LiteRtAccelerator accelerator,
-                                 LiteRtApiVersion* version) {
-    LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Accelerator handle is invalid.");
-    LITERT_ENSURE(version != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Version pointer is null.");
-    *version = kVersion;
-    return kLiteRtStatusOk;
-  }
-
-  // Returns the accelerator's hardware support by setting `hw_set`.
-  static LiteRtStatus GetHardwareSupport(LiteRtAccelerator accelerator,
-                                         LiteRtHwAcceleratorSet* hw_set) {
-    LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Accelerator handle is invalid.");
-    LITERT_ENSURE(hw_set != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Hardware support pointer is null.");
-    *hw_set = kHwSupport;
-    return kLiteRtStatusOk;
-  }
-};
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_ACCELERATOR_IMPLEMENTATION_HELPER_H_
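To show how the pieces fit, a hypothetical accelerator built on this helper would look roughly as follows (FooAccelerator/kFooName are illustrative names, not part of the patch; the pattern mirrors the CpuAccelerator deleted further below):

    // Sketch only: name/version/hardware getters come from the base class;
    // only the delegate hooks are left to implement.
    constexpr const char kFooName[] = "FooAccelerator";
    constexpr const LiteRtApiVersion kFooVersion{1, 0, 0};

    class FooAccelerator final
        : public litert::internal::AcceleratorImplementationHelper<
              FooAccelerator, kFooName, kFooVersion, kLiteRtHwAcceleratorCpu> {
     public:
      static LiteRtStatus CreateDelegate(
          LiteRtAccelerator accelerator,
          LiteRtAcceleratorCompilationOptions options, void** delegate);
      static void DestroyDelegate(void* delegate);
    };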
-extern "C" bool (*LiteRtRegisterStaticLinkedAcceleratorGpu)( - LiteRtEnvironmentT& environment) = nullptr; - -namespace litert { - -Expected TriggerAcceleratorAutomaticRegistration( - LiteRtEnvironmentT& environment) { - // Register the NPU accelerator. - - auto npu_registration = - LiteRtRegisterNpuAccelerator(&environment, /*options=*/nullptr); - if (npu_registration != kLiteRtStatusOk) { - LITERT_LOG(LITERT_WARNING, - "GPU accelerator could not be loaded and registered: %s.", - LiteRtGetStatusString(npu_registration)); - } else { - LITERT_LOG(LITERT_INFO, "NPU accelerator registered."); - } - - // Register the GPU accelerator. - if (LiteRtRegisterStaticLinkedAcceleratorGpu != nullptr && - LiteRtRegisterStaticLinkedAcceleratorGpu(environment)) { - LITERT_LOG(LITERT_INFO, "Statically linked GPU accelerator registered."); - return {}; - } - auto gpu_registration = RegisterSharedObjectAccelerator( - environment, /*plugin_path=*/"libLiteRtGpuAccelerator.so", - /*registration_function_name=*/"LiteRtRegisterAcceleratorGpuOpenCl"); - if (!gpu_registration) { - LITERT_LOG(LITERT_WARNING, - "GPU accelerator could not be loaded and registered: %s.", - gpu_registration.Error().Message().c_str()); - } else { - LITERT_LOG(LITERT_INFO, "GPU accelerator registered."); - } - return {}; -}; - -Expected RegisterSharedObjectAccelerator( - LiteRtEnvironmentT& environment, absl::string_view plugin_path, - absl::string_view registration_function_name) { - auto maybe_lib = SharedLibrary::Load(plugin_path, RtldFlags::Lazy().Local()); - if (!maybe_lib.HasValue()) { - maybe_lib = SharedLibrary::Load(RtldFlags::kDefault); - } - // Note: the Load(kDefault) overload always succeeds, so we are sure that - // maybe_lib contains a value. - SharedLibrary lib(std::move(maybe_lib.Value())); - LITERT_ASSIGN_OR_RETURN(auto registration_function, - lib.LookupSymbol( - registration_function_name.data())); - LITERT_RETURN_IF_ERROR(registration_function(&environment)); - environment.GetAcceleratorRegistry().TakeOwnershipOfSharedLibrary( - std::move(lib)); - return {}; -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h b/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h deleted file mode 100644 index 5ec12d7ed8a735..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h b/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h
deleted file mode 100644
index 5ec12d7ed8a735..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_AUTO_REGISTRATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_AUTO_REGISTRATION_H_
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/core/environment.h"
-
-namespace litert {
-
-Expected<void> TriggerAcceleratorAutomaticRegistration(
-    LiteRtEnvironmentT& environment);
-
-Expected<void> RegisterSharedObjectAccelerator(
-    LiteRtEnvironmentT& environment, absl::string_view plugin_path,
-    absl::string_view registration_function_name);
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_AUTO_REGISTRATION_H_
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/BUILD b/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/BUILD
deleted file mode 100644
index 68758738b7dc45..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright 2025 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = [
-        "//tensorflow/lite/experimental/litert:__subpackages__",
-    ],
-)
-
-cc_library(
-    name = "dispatch_accelerator",
-    srcs = ["dispatch_accelerator.cc"],
-    hdrs = ["dispatch_accelerator.h"],
-    deps = [
-        "//tensorflow/lite/c:c_api_types",
-        "//tensorflow/lite/experimental/litert/c:litert_accelerator",
-        "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options",
-        "//tensorflow/lite/experimental/litert/c:litert_accelerator_registration",
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate",
-        "//tensorflow/lite/experimental/litert/cc:litert_any",
-        "//tensorflow/lite/experimental/litert/cc:litert_dispatch_delegate",
-        "//tensorflow/lite/experimental/litert/cc:litert_expected",
-        "//tensorflow/lite/experimental/litert/cc:litert_macros",
-        "//tensorflow/lite/experimental/litert/core:environment",
-        "//tensorflow/lite/experimental/litert/runtime:accelerator_model_compilation_data",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.cc b/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.cc
deleted file mode 100644
index 66189b4a84b028..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.cc
+++ /dev/null
@@ -1,242 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h"
-
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/c/c_api_types.h"
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h"
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h"
-#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_any.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/core/environment.h"
-#include "tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h"
-
-namespace litert {
-
-class NpuAccelerator final {
-  constexpr static const absl::string_view kName = "NpuAccelerator";
-  // Warning: this should be incremented every time the code of this
-  // accelerator is updated according to semantic versioning.
-  constexpr static const LiteRtApiVersion kVersion{1, 0, 0};
-  constexpr static const LiteRtHwAcceleratorSet kHwSupport =
-      kLiteRtHwAcceleratorNpu;
-
- public:
-  explicit NpuAccelerator(std::string library_folder)
-      : library_folder_(std::move(library_folder)) {}
-
-  struct Deleter {
-    void operator()(NpuAccelerator* npu_accelerator) {
-      delete npu_accelerator;
-    }
-  };
-  using Ptr = std::unique_ptr<NpuAccelerator, Deleter>;
-
-  static Expected<Ptr> Create(std::string library_folder) {
-    LITERT_RETURN_IF_ERROR(
-        !library_folder.empty(),
-        Error(kLiteRtStatusErrorInvalidArgument,
-              "Dispatch API implementation library folder was not specified."));
-    return Ptr(new NpuAccelerator(std::move(library_folder)));
-  }
-
-  // C API
-
-  // Deletes the accelerator data.
-  static void Destroy(void* npu_accelerator) {
-    Deleter()(reinterpret_cast<NpuAccelerator*>(npu_accelerator));
-  }
-
-  // Stores the accelerator's name in `name`.
-  static LiteRtStatus GetName(LiteRtAccelerator accelerator,
-                              const char** name) {
-    LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Accelerator handle is invalid.");
-    LITERT_ENSURE(name != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Name pointer is null.");
-    *name = kName.data();
-    return kLiteRtStatusOk;
-  }
-
-  // Stores the accelerator's version in `version`.
-  static LiteRtStatus GetVersion(LiteRtAccelerator accelerator,
-                                 LiteRtApiVersion* version) {
-    LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Accelerator handle is invalid.");
-    LITERT_ENSURE(version != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Version pointer is null.");
-    *version = kVersion;
-    return kLiteRtStatusOk;
-  }
-
-  // Stores the accelerator's hardware support in `hw_set`.
-  static LiteRtStatus GetHardwareSupport(LiteRtAccelerator accelerator,
-                                         LiteRtHwAcceleratorSet* hw_set) {
-    LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Accelerator handle is invalid.");
-    LITERT_ENSURE(hw_set != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Hardware support pointer is null.");
-    *hw_set = kHwSupport;
-    return kLiteRtStatusOk;
-  }
-
-  // Goes through the options in the linked list and returns the model
-  // compilation data if it exists.
-  static Expected<const litert::internal::ModelCompilationData*>
-  GetModelCompilationData(LiteRtAcceleratorCompilationOptions options) {
-    LiteRtApiVersion payload_version;
-    void* payload_data;
-    LITERT_RETURN_IF_ERROR(LiteRtFindAcceleratorCompilationOptionsData(
-        options, litert::internal::ModelCompilationData::kIdentifier,
-        &payload_version, &payload_data));
-    return reinterpret_cast<const litert::internal::ModelCompilationData*>(
-        payload_data);
-  }
-
-  // Creates a Dispatch delegate instance.
-  static LiteRtStatus CreateDelegate(
-      LiteRtAccelerator accelerator,
-      LiteRtAcceleratorCompilationOptions options, void** delegate) {
-    LITERT_ENSURE(delegate != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Delegate pointer is null.");
-    LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument,
-                  "Accelerator handle is invalid.");
-    LITERT_ENSURE(accelerator->env != nullptr,
-                  kLiteRtStatusErrorInvalidArgument,
-                  "Accelerator is not registered to an environment.");
-
-    LITERT_ASSIGN_OR_RETURN(
-        const litert::internal::ModelCompilationData* compilation_data,
-        GetModelCompilationData(options));
-
-    LITERT_ENSURE(compilation_data->allocation_base,
-                  kLiteRtStatusErrorRuntimeFailure,
-                  "No model allocation was passed by the runtime.");
-
-    auto dispatch_delegate_options = litert::CreateDispatchDelegateOptionsPtr(
-        &accelerator->env->GetOptions());
-    LITERT_ENSURE(dispatch_delegate_options != nullptr,
-                  kLiteRtStatusErrorRuntimeFailure,
-                  "Dispatch delegate options failed to be created.");
-
-    LITERT_ENSURE(
-        LiteRtDispatchDelegateAddAllocBaseOption(
-            dispatch_delegate_options.get(),
-            compilation_data->allocation_base) == kTfLiteOk,
-        kLiteRtStatusErrorRuntimeFailure,
-        "Could not add allocation base to dispatch delegate options.");
-
-    if (compilation_data->allocation_fd != -1) {
-      LITERT_ENSURE(LiteRtDispatchDelegateAddAllocFdOption(
-                        dispatch_delegate_options.get(),
-                        compilation_data->allocation_fd) == kTfLiteOk,
-                    kLiteRtStatusErrorRuntimeFailure,
-                    "Could not add allocation file descriptor to dispatch "
-                    "delegate options.");
-    }
-
-    auto dispatch_delegate = litert::CreateDispatchDelegatePtr(
-        &accelerator->env->GetOptions(), std::move(dispatch_delegate_options));
-    LITERT_ENSURE(dispatch_delegate != nullptr,
-                  kLiteRtStatusErrorRuntimeFailure,
-                  "Dispatch delegate failed to be created.");
-
-    *delegate = dispatch_delegate.release();
-    return kLiteRtStatusOk;
-  }
-
-  // Destroys a Dispatch delegate instance.
-  static void DestroyDelegate(void* delegate) {
-    LiteRtDestroyDispatchDelegate(
-        reinterpret_cast<TfLiteOpaqueDelegate*>(delegate));
-  }
-
- private:
-  // Note: we do not directly use the option structure because we want to copy
-  // and own all the option data.
-
-  // Folder containing the Dispatch API implementation shared library.
-  std::string library_folder_;
-};
-
-namespace {
-
-struct AcceleratorDestructor {
-  void operator()(LiteRtAccelerator accelerator) {
-    LiteRtDestroyAccelerator(accelerator);
-  }
-};
-
-using AcceleratorGuard =
-    std::unique_ptr<std::pointer_traits<LiteRtAccelerator>::element_type,
-                    AcceleratorDestructor>;
-
-}  // namespace
-}  // namespace litert
-
-extern "C" {
-
-LiteRtStatus LiteRtRegisterNpuAccelerator(
-    LiteRtEnvironmentT* environment, LiteRtNpuAcceleratorOptions* options) {
-  LITERT_ENSURE(environment != nullptr, kLiteRtStatusErrorInvalidArgument,
-                "Environment handle is invalid.");
-  LiteRtAccelerator accelerator_handle;
-  LITERT_RETURN_IF_ERROR(LiteRtCreateAccelerator(&accelerator_handle));
-  litert::AcceleratorGuard accelerator(accelerator_handle);
-
-  LITERT_RETURN_IF_ERROR(LiteRtSetAcceleratorGetName(
-      accelerator.get(), litert::NpuAccelerator::GetName));
-  LITERT_RETURN_IF_ERROR(LiteRtSetAcceleratorGetVersion(
-      accelerator.get(), litert::NpuAccelerator::GetVersion));
-  LITERT_RETURN_IF_ERROR(LiteRtSetAcceleratorGetHardwareSupport(
-      accelerator.get(), litert::NpuAccelerator::GetHardwareSupport));
-
-  LITERT_RETURN_IF_ERROR(LiteRtSetDelegateFunction(
-      accelerator.get(), litert::NpuAccelerator::CreateDelegate,
-      litert::NpuAccelerator::DestroyDelegate));
-
-  std::string library_folder;
-  if (options && options->library_folder) {
-    library_folder = options->library_folder;
-  }
-  // Check the environment options if the library folder wasn't set in the
-  // options.
-  if (library_folder.empty()) {
-    if (auto env_library_folder =
-            environment->GetOption(kLiteRtEnvOptionTagDispatchLibraryDir);
-        env_library_folder.has_value()) {
-      LITERT_ASSIGN_OR_RETURN(
-          library_folder,
-          litert::Get<std::string>(env_library_folder.value()));
-    }
-  }
-
-  LITERT_ASSIGN_OR_RETURN(
-      auto accelerator_impl,
-      litert::NpuAccelerator::Create(std::move(library_folder)));
-
-  LITERT_RETURN_IF_ERROR(LiteRtRegisterAccelerator(
-      environment, accelerator.release(), accelerator_impl.release(),
-      litert::NpuAccelerator::Destroy));
-  return kLiteRtStatusOk;
-}
-
-}  // extern "C"
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h b/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h
deleted file mode 100644
index 9c1d93938eb28c..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_DISPATCH_DISPATCH_ACCELERATOR_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_DISPATCH_DISPATCH_ACCELERATOR_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-struct LiteRtNpuAcceleratorOptions {
-  const char* library_folder;
-};
-
-// Registers the NPU accelerator to the given environment.
-//
-// `options` may be null, in which case the accelerator is registered with
-// a default configuration.
-//
-// If `options.library_folder` is not specified, the library folder is
-// replaced with the `LiteRtEnvOptionTagDispatchLibraryDir` environment
-// option (that was passed upon creation).
-//
-// Once this function has returned, options may be freed or reused.
-LiteRtStatus LiteRtRegisterNpuAccelerator(LiteRtEnvironment environment,
-                                          LiteRtNpuAcceleratorOptions* options);
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_DISPATCH_DISPATCH_ACCELERATOR_H_
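For context, a caller-side sketch of this C entry point (the environment is assumed to exist; the folder value is a made-up example):

    // Sketch only: register the NPU accelerator with an explicit library
    // folder; passing nullptr instead falls back to the
    // kLiteRtEnvOptionTagDispatchLibraryDir environment option.
    LiteRtNpuAcceleratorOptions npu_options{
        /*library_folder=*/"/data/local/tmp/dispatch"};
    LiteRtStatus status = LiteRtRegisterNpuAccelerator(env, &npu_options);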
- -#include "tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.h" - -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerators/accelerator_implementation_helper.h" - -namespace litert { - -namespace { -constexpr const char kCpuAcceleratorName[] = "CpuAccelerator"; -constexpr const LiteRtApiVersion kCpuAcceleratorVersion{1, 0, 0}; - -class CpuAccelerator final - : public internal::AcceleratorImplementationHelper< - CpuAccelerator, kCpuAcceleratorName, kCpuAcceleratorVersion, - kLiteRtHwAcceleratorCpu> { - public: - CpuAccelerator() = default; - - static Expected Create() { return Allocate(); } - - // C API - - // Creates a Dispatch delegate instance. - static LiteRtStatus CreateDelegate( - LiteRtAccelerator accelerator, - LiteRtAcceleratorCompilationOptions options, void** delegate) { - LITERT_ENSURE(delegate != nullptr, kLiteRtStatusErrorInvalidArgument, - "Delegate pointer is null."); - LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument, - "Accelerator handle is invalid."); - LITERT_ENSURE(accelerator->env != nullptr, - kLiteRtStatusErrorInvalidArgument, - "Accelerator is not registered to an environment."); - - // TODO: b/403547017 - Make the CPU accelerator configurable using the - // compilation options. - auto xnn_options = TfLiteXNNPackDelegateOptionsDefault(); - *delegate = TfLiteXNNPackDelegateCreate(&xnn_options); - - LITERT_ENSURE(*delegate != nullptr, kLiteRtStatusErrorRuntimeFailure, - "XNNPack delegate failed to be created."); - return kLiteRtStatusOk; - } - - // Destroys an XNNPack delegate instance. 
-  static void DestroyDelegate(void* delegate) {
-    TfLiteXNNPackDelegateDelete(reinterpret_cast<TfLiteDelegate*>(delegate));
-  }
-};
-
-}  // namespace
-}  // namespace litert
-
-extern "C" {
-
-LiteRtStatus LiteRtRegisterCpuAccelerator(
-    LiteRtEnvironmentT* environment, LiteRtCpuAcceleratorOptions* options) {
-  LITERT_ENSURE(environment != nullptr, kLiteRtStatusErrorInvalidArgument,
-                "Environment handle is invalid.");
-  LiteRtAccelerator accelerator_handle;
-  LITERT_RETURN_IF_ERROR(LiteRtCreateAccelerator(&accelerator_handle));
-  litert::internal::AcceleratorGuard accelerator(accelerator_handle);
-
-  LITERT_RETURN_IF_ERROR(litert::internal::SetAcceleratorBoilerplateFunctions<
-                         litert::CpuAccelerator>(accelerator));
-
-  LITERT_ASSIGN_OR_RETURN(auto accelerator_impl,
-                          litert::CpuAccelerator::Create());
-
-  LITERT_RETURN_IF_ERROR(LiteRtRegisterAccelerator(
-      environment, accelerator.release(), accelerator_impl.release(),
-      litert::CpuAccelerator::Destroy));
-  return kLiteRtStatusOk;
-}
-
-}  // extern "C"
diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.h b/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.h
deleted file mode 100644
index 01a252c7e1d429..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.h
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_XNNPACK_XNNPACK_ACCELERATOR_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_XNNPACK_XNNPACK_ACCELERATOR_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-// Options that may be passed to the CPU accelerator when it is registered.
-struct LiteRtCpuAcceleratorOptions {};
-
-// Registers the CPU accelerator to the given environment.
-//
-// `options` may be null, in which case the accelerator is registered with
-// a default configuration.
-//
-// Once this function has returned, options may be freed or reused.
-LiteRtStatus LiteRtRegisterCpuAccelerator(LiteRtEnvironment environment,
-                                          LiteRtCpuAcceleratorOptions* options);
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_XNNPACK_XNNPACK_ACCELERATOR_H_
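Caller-side, registration is a one-liner since the options struct is currently empty (sketch; `env` is assumed to be a valid LiteRtEnvironment):

    // Sketch only: a null options pointer selects the default configuration.
    if (LiteRtRegisterCpuAccelerator(env, /*options=*/nullptr) !=
        kLiteRtStatusOk) {
      // CPU execution falls back to the unaccelerated TFLite kernels.
    }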
diff --git a/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.cc b/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.cc
deleted file mode 100644
index 26746dcd632546..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h"
-
-#include <cstddef>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/runtime/event.h"
-#if LITERT_HAS_AHWB_SUPPORT
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#endif  // LITERT_HAS_AHWB_SUPPORT
-
-namespace litert {
-namespace internal {
-
-bool AhwbBuffer::IsSupported() {
-#if LITERT_HAS_AHWB_SUPPORT
-  return true;
-#else
-  return false;
-#endif
-}
-
-Expected<AhwbBuffer> AhwbBuffer::Alloc(size_t size) {
-#if LITERT_HAS_AHWB_SUPPORT
-  AHardwareBuffer* ahwb;
-  AHardwareBuffer_Desc ahwb_desc = {
-      .width = static_cast<uint32_t>(size),
-      .height = 1,
-      .layers = 1,
-      .format = AHARDWAREBUFFER_FORMAT_BLOB,
-      .usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY |
-               AHARDWAREBUFFER_USAGE_CPU_READ_RARELY |
-               AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER};
-  if (AHardwareBuffer_allocate(&ahwb_desc, &ahwb) != 0) {
-    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                      "Failed to allocate AHWB");
-  }
-  return AhwbBuffer{/*.ahwb=*/ahwb};
-#else
-  return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                    "AHardwareBuffers are not supported on this platform");
-#endif  // LITERT_HAS_AHWB_SUPPORT
-}
-
-void AhwbBuffer::Free(AHardwareBuffer* ahwb) {
-#if LITERT_HAS_AHWB_SUPPORT
-  AHardwareBuffer_release(ahwb);
-#endif
-}
-
-Expected<size_t> AhwbBuffer::GetSize(AHardwareBuffer* ahwb) {
-#if LITERT_HAS_AHWB_SUPPORT
-  AHardwareBuffer_Desc ahwb_desc;
-  AHardwareBuffer_describe(ahwb, &ahwb_desc);
-  return static_cast<size_t>(ahwb_desc.width) * ahwb_desc.height *
-         ahwb_desc.layers;
-#else
-  return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                    "AHardwareBuffers are not supported on this platform");
-#endif  // LITERT_HAS_AHWB_SUPPORT
-}
-
-Expected<void*> AhwbBuffer::Lock(AHardwareBuffer* ahwb, LiteRtEventT* event) {
-#if LITERT_HAS_AHWB_SUPPORT
-  int fence = -1;
-  if (event != nullptr) {
-    LITERT_ASSIGN_OR_RETURN(fence, event->GetSyncFenceFd());
-  }
-  void* host_addr;
-  LITERT_RETURN_IF_ERROR(
-      AHardwareBuffer_lock(ahwb,
-                           AHARDWAREBUFFER_USAGE_CPU_READ_RARELY |
-                               AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY,
-                           fence, /*rect=*/nullptr, &host_addr) == 0,
-      Unexpected(kLiteRtStatusErrorRuntimeFailure, "Failed to lock AHWB"));
-  return host_addr;
-#else
-  return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                    "AHardwareBuffers are not supported on this platform");
-#endif
-}
-
-Expected<void> AhwbBuffer::Unlock(AHardwareBuffer* ahwb) {
-#if LITERT_HAS_AHWB_SUPPORT
-  if (AHardwareBuffer_unlock(ahwb, /*fence=*/nullptr) != 0) {
-    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                      "Failed to unlock AHWB");
-  }
-  return {};
-#else
-  return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                    "AHardwareBuffers are not supported on this platform");
-#endif
-}
-
-}  // namespace internal
-}  // namespace litert
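For reference, the intended call pattern for this wrapper (sketch; error handling elided, and the calls only succeed on Android builds with AHWB support):

    // Sketch only: allocate, map, fill, and release an AHardwareBuffer.
    using litert::internal::AhwbBuffer;

    if (AhwbBuffer::IsSupported()) {
      auto buffer = AhwbBuffer::Alloc(/*size=*/1024);  // Expected<AhwbBuffer>
      auto host_ptr = AhwbBuffer::Lock(buffer->ahwb);  // Expected<void*>
      // ... write through *host_ptr ...
      AhwbBuffer::Unlock(buffer->ahwb);
      AhwbBuffer::Free(buffer->ahwb);
    }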
diff --git a/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h b/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h
deleted file mode 100644
index 8722305225c1e7..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_AHWB_BUFFER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_AHWB_BUFFER_H_
-
-#include <cstddef>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/runtime/event.h"
-
-#if LITERT_HAS_AHWB_SUPPORT
-#include <android/hardware_buffer.h>
-#else
-// Define a placeholder AHardwareBuffer struct just to enable compilation.
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-typedef struct AHardwareBuffer AHardwareBuffer;
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
-#endif  // LITERT_HAS_AHWB_SUPPORT
-
-namespace litert::internal {
-
-struct AhwbBuffer {
-  AHardwareBuffer* ahwb;
-
-  static bool IsSupported();
-  static Expected<AhwbBuffer> Alloc(size_t size);
-  static void Free(AHardwareBuffer* ahwb);
-  static Expected<size_t> GetSize(AHardwareBuffer* ahwb);
-  static Expected<void*> Lock(AHardwareBuffer* ahwb,
-                              LiteRtEventT* event = nullptr);
-  static Expected<void> Unlock(AHardwareBuffer* ahwb);
-};
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_AHWB_BUFFER_H_
diff --git a/tensorflow/lite/experimental/litert/runtime/compilation_options.h b/tensorflow/lite/experimental/litert/runtime/compilation_options.h
deleted file mode 100644
index b92f6555a66b2d..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/compilation_options.h
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2025 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILATION_OPTIONS_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILATION_OPTIONS_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h"
-
-struct LiteRtCompilationOptionsT {
-  // This should be updated every time a field is added/edited.
-  //
-  // - Renaming a field: increment patch;
-  // - Adding or deprecating a field: set patch to 0, increment minor;
-  // - Breaking layout compatibility: set patch and minor to 0, increment
-  //   major.
-  //
-  // Note: Changing a default value does not impact the version.
-  LiteRtApiVersion version = {.major = 0, .minor = 0, .patch = 1};
-  LiteRtHwAcceleratorSet hardware_accelerators = kLiteRtHwAcceleratorNone;
-  litert::AcceleratorCompilationOptions accelerator_compilation_options;
-};
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILATION_OPTIONS_H_
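As a concrete reading of the versioning rule above (illustrative values only):

    // Starting from version {0, 0, 1}:
    //   rename a field             -> {0, 0, 2}
    //   add or deprecate a field   -> {0, 1, 0}
    //   break layout compatibility -> {1, 0, 0}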
"tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/stderr_reporter.h" - -using litert::Error; -using litert::Expected; -using litert::OwningBufferRef; -using litert::TensorBuffer; -using litert::Unexpected; -using litert::internal::ExternalLiteRtBufferContext; -using litert::internal::SerializeModel; - -Expected LiteRtCompiledModelT::InitializeRuntime() { - tflite::ops::builtin::BuiltinOpResolver resolver; - tflite::InterpreterBuilder(*fb_model_, resolver)(&interp_); - if (interp_ == nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to build TFL interpreter"); - } - - signature_keys_ = interp_->signature_keys(); - if (signature_keys_.empty()) { - static auto* default_signature_key = - new std::string(LiteRtSignatureT::kDefaultSignatureKey); - signature_keys_.push_back(default_signature_key); - } - // Register the ExternalLiteRtBufferContext for TensorBuffer handshaking. - buffer_context_ = - std::make_unique(); - interp_->SetExternalContext(kTfLiteLiteRtBufferContext, - buffer_context_.get()); - - return {}; -} - -Expected LiteRtCompiledModelT::InitializeModel( - LiteRtModelT& model, LiteRtHwAcceleratorSet hw_accelerators, - LiteRtEnvironmentT& env) { - bool need_reserialization = false; - - if (hw_accelerators != kLiteRtHwAcceleratorNone) { - LITERT_LOG(LITERT_INFO, "Applying compiler plugins..."); - auto jit_result = litert::internal::ApplyPlugins( - &env, &model, hw_accelerators, &need_reserialization); - if (!jit_result) { - LITERT_LOG(LITERT_WARNING, "Failed to apply compiler plugins: %s", - jit_result.Error().Message().c_str()); - } else { - LITERT_LOG( - LITERT_INFO, "%d compiler plugins were applied successfully: %s", - jit_result->num_applied_plugins, jit_result->success_message.c_str()); - LITERT_LOG(LITERT_WARNING, "Plugin errs: %s", - jit_result->error_message.c_str()); - } - } - - const auto& tfl_wrapper = litert::internal::GetTflFlatbuffer(model); - // Currently, in all situations where litert model was import from a - // flatbuffer, the litert model will own said flatbuffer and stored it in the - // OwningBufferRef. - auto tfl_buf = tfl_wrapper.Buf(); - - if (!need_reserialization && tfl_buf.Data() != nullptr) { - LITERT_LOG( - LITERT_INFO, - "Flatbuffer model initialized directly from incoming litert model."); - fb_model_ = tflite::FlatBufferModel::BuildFromBuffer(tfl_buf.StrData(), - tfl_buf.Size()); - return {}; - } - - LITERT_LOG(LITERT_INFO, "JIT compilation changed model, reserializing..."); - - auto serialized = SerializeModel(std::move(model)); - if (!serialized) { - return serialized.Error(); - } - - model_buf_ = std::move(*serialized); - fb_model_ = tflite::FlatBufferModel::BuildFromBuffer( - reinterpret_cast(model_buf_.Data()), model_buf_.Size()); - if (fb_model_ == nullptr) { - return Unexpected(kLiteRtStatusErrorFileIO, - "Failed to build flatbuffer from buffer"); - } - - return {}; -} - -namespace { - -// A utility class that allows appending additional compilation options, but -// only for the duration of a scope. 
-class ScopedCompilationOptionsModifier {
- public:
-  explicit ScopedCompilationOptionsModifier(
-      LiteRtCompilationOptions compilation_options)
-      : accelerator_options_(
-            compilation_options->accelerator_compilation_options) {}
-
-  ~ScopedCompilationOptionsModifier() {
-    // Remove any option that was appended during the lifetime of this object.
-    while (--num_appended_options_ >= 0) {
-      accelerator_options_.Pop();
-    }
-  }
-
-  Expected<void> Append(
-      litert::AcceleratorCompilationOptions&& accelerator_options) {
-    auto status = accelerator_options_.Append(std::move(accelerator_options));
-    if (status) {
-      ++num_appended_options_;
-    }
-    return status;
-  }
-
- private:
-  litert::AcceleratorCompilationOptions& accelerator_options_;
-  int num_appended_options_ = 0;
-};
-
-int GetAllocationFd(const tflite::Allocation* allocation) {
-  if (allocation != nullptr &&
-      allocation->type() == tflite::Allocation::Type::kMMap) {
-    auto& mmap_allocation =
-        static_cast<const tflite::MMAPAllocation&>(*allocation);
-    return mmap_allocation.fd();
-  }
-  return -1;
-}
-
-}  // namespace
-
-Expected<LiteRtCompiledModelT::Ptr> LiteRtCompiledModelT::Create(
-    LiteRtEnvironmentT* env, LiteRtModel model,
-    LiteRtCompilationOptions jit_compilation_options) {
-  // If no compilation options were passed, we use a default object. This
-  // allows us to add (for instance) accelerator compilation options.
-  std::unique_ptr<LiteRtCompilationOptionsT>
-      placeholder_jit_compilation_options;
-  if (!jit_compilation_options) {
-    placeholder_jit_compilation_options =
-        std::make_unique<LiteRtCompilationOptionsT>();
-    jit_compilation_options = placeholder_jit_compilation_options.get();
-  }
-
-  auto compiled_model = std::make_unique<LiteRtCompiledModelT>();
-
-  LiteRtHwAcceleratorSet hardware_accelerators = kLiteRtHwAcceleratorNone;
-  if (jit_compilation_options) {
-    LiteRtGetCompilationOptionsHardwareAccelerators(jit_compilation_options,
-                                                    &hardware_accelerators);
-  }
-
-  LITERT_RETURN_IF_ERROR(
-      compiled_model->InitializeModel(*model, hardware_accelerators, *env));
-
-  LITERT_RETURN_IF_ERROR(compiled_model->InitializeRuntime());
-  if (compiled_model->GetModelBase() == nullptr) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to initialize model memory.");
-  }
-
-  // Add a new link in the accelerator compilation options that holds some data
-  // that is computed during model compilation.
-  LITERT_ASSIGN_OR_RETURN(
-      auto model_compilation_data_options,
-      litert::internal::ModelCompilationData::CreateOptions());
-
-  LITERT_ASSIGN_OR_RETURN(
-      auto* model_compilation_data,
-      model_compilation_data_options
-          .GetData<litert::internal::ModelCompilationData>());
-  model_compilation_data->allocation_base = compiled_model->GetModelBase();
-  model_compilation_data->allocation_fd =
-      GetAllocationFd(compiled_model->fb_model_->allocation());
-
-  // Temporarily append model_compilation_data to the jit_compilation_options,
-  // but remove it before returning from this function since the caller owns
-  // jit_compilation_options and may use it for other purposes.
-  ScopedCompilationOptionsModifier scoped_modifier(jit_compilation_options);
-  LITERT_RETURN_IF_ERROR(
-      scoped_modifier.Append(std::move(model_compilation_data_options)));
-
-  // Retrieve the accelerator options list.
-  LiteRtAcceleratorCompilationOptions accelerator_options = nullptr;
-  LITERT_RETURN_IF_ERROR(LiteRtGetAcceleratorCompilationOptions(
-      jit_compilation_options, &accelerator_options));
-
-  // Apply accelerators matching the requested hardware support to the
-  // model in the order they were registered.
-  for (auto& accelerator : env->GetAcceleratorRegistry()) {
-    bool delegate_responsible_for_jit = false;
-    LITERT_RETURN_IF_ERROR(
-        LiteRtIsAcceleratorDelegateResponsibleForJitCompilation(
-            accelerator.get(), &delegate_responsible_for_jit));
-    LiteRtHwAcceleratorSet accelerator_supported_hardware;
-    LITERT_RETURN_IF_ERROR(accelerator->GetHardwareSupport(
-        accelerator.get(), &accelerator_supported_hardware));
-    // We don't apply the delegate if:
-    // - the delegate is responsible for JIT compilation
-    // - and JIT has not been requested for the hardware it supports.
-    if (delegate_responsible_for_jit &&
-        !(hardware_accelerators & accelerator_supported_hardware)) {
-      continue;
-    }
-
-    TfLiteOpaqueDelegate* delegate_ptr = nullptr;
-    LITERT_RETURN_IF_ERROR(
-        accelerator->CreateDelegate(accelerator.get(), accelerator_options,
-                                    reinterpret_cast<void**>(&delegate_ptr)));
-
-    auto delegate = tflite::TfLiteOpaqueDelegateUniquePtr(
-        delegate_ptr, reinterpret_cast<void (*)(TfLiteOpaqueDelegate*)>(
-                          accelerator->DestroyDelegate));
-
-    if (compiled_model->interp_->ModifyGraphWithDelegate(delegate_ptr) !=
-        kTfLiteOk) {
-      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                        "Failed to modify graph with delegate");
-    }
-    compiled_model->RegisterDelegate(std::move(delegate));
-  }
-
-  compiled_model->CheckCpuTensors();
-  return compiled_model;
-}
-
-void LiteRtCompiledModelT::CheckCpuTensors() {
-  cpu_tensors_.clear();
-  for (int subgraph_no = 0; subgraph_no < interp_->subgraphs_size();
-       ++subgraph_no) {
-    auto* subgraph = interp_->subgraph(subgraph_no);
-    auto& execution_plan = subgraph->execution_plan();
-    auto& nodes_and_registration = subgraph->nodes_and_registration();
-    for (int execution_plan_index = 0;
-         execution_plan_index < execution_plan.size(); execution_plan_index++) {
-      int node_index = execution_plan[execution_plan_index];
-      auto& node = nodes_and_registration[node_index].first;
-      const TfLiteRegistration& registration =
-          nodes_and_registration[node_index].second;
-
-      if (registration.builtin_code == kTfLiteBuiltinDelegate) {
-        continue;
-      }
-      if (registration.builtin_code == kTfLiteBuiltinCustom &&
-          litert::internal::kLiteRtDispatchOpCustomCode ==
-              registration.custom_name)
-        continue;
-      for (int i = 0; i < node.inputs->size; ++i) {
-        int input_tensor_index = node.inputs->data[i];
-        if (input_tensor_index == kTfLiteOptionalTensor) continue;
-        cpu_tensors_.insert(subgraph->tensor(input_tensor_index));
-      }
-    }
-  }
-}
-
-litert::Expected<LiteRtTensorBufferRequirements>
-LiteRtCompiledModelT::GetTensorBufferRequirements(const TfLiteTensor* tensor) {
-  // Use the buffer context to get the buffer requirements only if the tensor
-  // is not a CPU tensor.
- if (cpu_tensors_.find(tensor) == cpu_tensors_.end()) { - auto requirements = buffer_context_->GetBufferRequirement(tensor); - if (requirements) { - return (*requirements)->Get(); - } - } else { - LITERT_LOG(LITERT_VERBOSE, "Tensor %s is shared with CPU.\n", tensor->name); - } - LiteRtTensorBufferRequirements litert_cpu_buffer_requirements; - LiteRtTensorBufferType cpu_buffer_type[] = { - kLiteRtTensorBufferTypeHostMemory}; - uint32_t cpu_buffer_strides[] = {0}; - auto res = LiteRtCreateTensorBufferRequirements( - /*num_supported_tensor_buffer_types=*/1, cpu_buffer_type, tensor->bytes, - /*num_strides=*/1, cpu_buffer_strides, &litert_cpu_buffer_requirements); - if (res != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create CPU buffer requirements"); - } - cpu_buffer_requirements_[tensor] = - litert::TensorBufferRequirements(litert_cpu_buffer_requirements); - return litert_cpu_buffer_requirements; -} - -Expected -LiteRtCompiledModelT::GetInputBufferRequirements( - absl::string_view signature_key, size_t input_index) { - auto runner = GetSignatureRunner(signature_key); - if (runner == nullptr) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get signature runner"); - } - auto input_names = runner->subgraph_input_names(); - if (input_index >= input_names.size()) { - return Unexpected(kLiteRtStatusErrorIndexOOB, "Input index out of range"); - } - auto input_name = input_names[input_index]; - auto* input_tensor = runner->input_tensor(input_name); - if (input_tensor == nullptr) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get input tensor"); - } - - return GetTensorBufferRequirements(input_tensor); -} - -Expected -LiteRtCompiledModelT::GetOutputBufferRequirements( - absl::string_view signature_key, size_t output_index) { - auto runner = GetSignatureRunner(signature_key); - if (runner == nullptr) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get signature runner"); - } - auto output_names = runner->subgraph_output_names(); - if (output_index >= output_names.size()) { - return Unexpected(kLiteRtStatusErrorIndexOOB, "Output index out of range"); - } - auto output_name = output_names[output_index]; - auto* output_tensor = runner->output_tensor(output_name); - if (output_tensor == nullptr) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get output tensor"); - } - - return GetTensorBufferRequirements(output_tensor); -} - -tflite::SignatureRunner* LiteRtCompiledModelT::GetSignatureRunner( - absl::string_view signature_key) { - if (signature_runners_.contains(signature_key)) { - return signature_runners_[signature_key]; - } - auto runner = interp_->GetSignatureRunner( - signature_key == LiteRtSignatureT::kDefaultSignatureKey - ? nullptr - : std::string(signature_key).c_str()); - signature_runners_[signature_key] = runner; - return runner; -} - -Expected LiteRtCompiledModelT::RegisterBuffer( - tflite::SignatureRunner* runner, TfLiteTensor* tensor, - const char* tensor_name, LiteRtTensorBuffer buffer, bool is_input, - std::vector& locked_buffers) { - bool backend_requires_cpu_buffer = false; - - auto requirements = buffer_context_->GetBufferRequirement(tensor); - if (requirements) { - auto supported_types = (*requirements)->SupportedTypes(); - if (!supported_types) { - return supported_types.Error(); - } - - for (auto& type : *supported_types) { - if (type == buffer->buffer_type()) { - // Register tensor buffer if it can be used by the backend. 
-        buffer->Duplicate();
-        TensorBuffer duplicated_buffer(buffer);
-        if (auto status = buffer_context_->RegisterTensorBuffer(
-                tensor, std::move(duplicated_buffer));
-            status != kLiteRtStatusOk) {
-          return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                            "Failed to register tensor buffer");
-        }
-        // Mark the tensor as non-CPU to prevent TFLite from allocating it.
-        tensor->allocation_type = kTfLiteNonCpu;
-        tensor->data.data = nullptr;
-        return {};
-      }
-      if (type == kLiteRtTensorBufferTypeHostMemory) {
-        backend_requires_cpu_buffer = true;
-      }
-    }
-  } else {
-    // If the BufferRequirement is not registered, assume the backend requires
-    // a CPU buffer.
-    backend_requires_cpu_buffer = true;
-  }
-
-  if (backend_requires_cpu_buffer) {
-    // The backend requires a CPU buffer.
-    bool buffer_is_cpu_compatible =
-        buffer->buffer_type() == kLiteRtTensorBufferTypeHostMemory ||
-        buffer->buffer_type() == kLiteRtTensorBufferTypeOpenCl;
-#if defined(__ANDROID__)
-    if (buffer->buffer_type() == kLiteRtTensorBufferTypeAhwb) {
-      if (__builtin_available(android 26, *)) {
-        auto ahwb = buffer->GetAhwbBuffer();
-        if (ahwb) {
-          // TODO: b/382330322 - Update logic to check if the AHWB (stride) is
-          // CPU compatible.
-          AHardwareBuffer_Desc desc;
-          AHardwareBuffer_describe(*ahwb, &desc);
-          buffer_is_cpu_compatible = true;
-        }
-      }
-    }
-#endif
-    if (buffer_is_cpu_compatible) {
-      void* host_mem_addr;
-      if (auto status = LiteRtLockTensorBuffer(buffer, &host_mem_addr);
-          status != kLiteRtStatusOk) {
-        return Unexpected(status, "Failed to lock the tensor buffer");
-      }
-      locked_buffers.push_back(buffer);
-      TfLiteCustomAllocation custom_allocation{host_mem_addr, tensor->bytes};
-      if (is_input) {
-        runner->SetCustomAllocationForInputTensor(tensor_name,
-                                                  custom_allocation,
-                                                  /*flags=*/0);
-      } else {
-        runner->SetCustomAllocationForOutputTensor(tensor_name,
-                                                   custom_allocation,
-                                                   /*flags=*/0);
-      }
-      return {};
-    }
-  }
-
-  // If the tensor is shared with CPU, register the tensor buffer as is and
-  // let the accelerator handle the conversion.
-  if (cpu_tensors_.find(tensor) != cpu_tensors_.end()) {
-    void* host_mem_addr;
-    if (auto status = LiteRtLockTensorBuffer(buffer, &host_mem_addr);
-        status != kLiteRtStatusOk) {
-      return Unexpected(status, "Failed to lock the tensor buffer");
-    }
-    locked_buffers.push_back(buffer);
-    TfLiteCustomAllocation custom_allocation{host_mem_addr, tensor->bytes};
-    if (is_input) {
-      runner->SetCustomAllocationForInputTensor(tensor_name, custom_allocation,
-                                                /*flags=*/0);
-    } else {
-      runner->SetCustomAllocationForOutputTensor(tensor_name,
-                                                 custom_allocation,
-                                                 /*flags=*/0);
-    }
-    return {};
-  }
-  // TODO: b/382330322 - Add buffer conversion logic instead of returning error.
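// The host-memory path above is the standard TFLite custom-allocation
// pattern: lock the LiteRT buffer for a stable host pointer, then point the
// TFLite tensor at that memory. A minimal standalone sketch, assuming the
// buffer is already known to be CPU-compatible (the helper name
// WireHostBuffer is hypothetical; the real logic is inlined above):
static TfLiteStatus WireHostBuffer(tflite::SignatureRunner* runner,
                                   const char* tensor_name, size_t bytes,
                                   LiteRtTensorBuffer buffer, bool is_input) {
  void* host_mem_addr = nullptr;
  // Lock the LiteRT buffer; the caller must unlock it after inference.
  if (LiteRtLockTensorBuffer(buffer, &host_mem_addr) != kLiteRtStatusOk) {
    return kTfLiteError;
  }
  // Point the TFLite tensor at the locked host memory instead of letting
  // the interpreter allocate its own arena slot.
  TfLiteCustomAllocation allocation{host_mem_addr, bytes};
  return is_input ? runner->SetCustomAllocationForInputTensor(
                        tensor_name, allocation, /*flags=*/0)
                  : runner->SetCustomAllocationForOutputTensor(
                        tensor_name, allocation, /*flags=*/0);
}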
- return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "The given buffer type is not supported."); -} - -Expected LiteRtCompiledModelT::Run( - absl::string_view signature_key, - const std::vector& input_buffers, - const std::vector& output_buffers, bool& async) { - auto runner = GetSignatureRunner(signature_key); - if (runner == nullptr) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get signature runner"); - } - size_t num_inputs = input_buffers.size(); - if (num_inputs != runner->subgraph_input_names().size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Input buffer size mismatch"); - } - size_t num_outputs = output_buffers.size(); - if (num_outputs != runner->subgraph_output_names().size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Output buffer size mismatch"); - } - - // In general output buffer events are assigned by the runtime and not the - // caller; here we check for any violation of that condition. - for (auto litert_output_buffer : output_buffers) { - if (litert_output_buffer->HasEvent()) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Output buffers cannot have events attached"); - } - } - - // The collection of locked buffers. It is used to unlock the buffers after - // the inference is done. - std::vector locked_buffers; - locked_buffers.reserve(num_inputs + num_outputs); - auto unlock_buffers = absl::MakeCleanup([&locked_buffers]() { - for (auto locked_buffer : locked_buffers) { - LiteRtUnlockTensorBuffer(locked_buffer); - } - }); - for (int i = 0; i < num_inputs; ++i) { - const auto& input_name = runner->subgraph_input_names()[i]; - auto* input_tensor = runner->input_tensor(input_name); - auto res = - RegisterBuffer(runner, input_tensor, input_name, input_buffers[i], - /*is_input=*/true, locked_buffers); - if (!res) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("Failed to register input tensor buffer: ", - res.Error().Message())); - } - } - - for (int i = 0; i < runner->subgraph_output_names().size(); ++i) { - const auto& output_name = runner->subgraph_output_names()[i]; - auto* output_tensor = runner->output_tensor(output_name); - auto res = RegisterBuffer(runner, const_cast(output_tensor), - output_name, output_buffers[i], - /*is_input=*/false, locked_buffers); - if (!res) { - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("Failed to register output tensor buffer: ", - res.Error().Message())); - } - } - - if (auto res = runner->AllocateTensors(); res != kTfLiteOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate tensors"); - } - - // Relay the intended async execution mode to DelegateKernel of Accelerator. - buffer_context_->SetAsyncExecutionMode(async); - - if (auto res = runner->Invoke(); res != kTfLiteOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Failed to invoke"); - } - - if (async) { - // If the caller requested async execution, then set async to true if any of - // the output buffers have been assigned a synchronization event. - async = false; - for (auto& tb : output_buffers) { - async |= tb->HasEvent(); - } - } else { - // If the caller has not requested async execution, then wait on - // synchronization events that have been attached to the outputs. 
- for (auto& tb : output_buffers) { - if (tb->HasEvent()) { - auto event = tb->GetEvent(); - if (auto status = litert::Event(*event, /*owned=*/false) - .Wait(/*timeout_in_ms=*/-1); - !status) { - return status; - } - } - } - } - - return {}; -} - -litert::Expected LiteRtCompiledModelT::RunCApi( - size_t signature_index, size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers, bool* async) { - if (signature_index >= signature_keys_.size()) { - return litert::Unexpected( - kLiteRtStatusErrorIndexOOB, - "Signature index is out of range of signature keys"); - } - std::vector input_buffers_vec; - input_buffers_vec.reserve(num_input_buffers); - for (int i = 0; i < num_input_buffers; ++i) { - input_buffers_vec.push_back(std::move(input_buffers[i])); - } - std::vector output_buffers_vec; - output_buffers_vec.reserve(num_output_buffers); - for (int i = 0; i < num_output_buffers; ++i) { - output_buffers_vec.push_back(std::move(output_buffers[i])); - } - bool async_ = async ? *async : false; - auto result = Run(*signature_keys_[signature_index], input_buffers_vec, - output_buffers_vec, async_); - if (async) { - *async = async_; - } - return result; -} diff --git a/tensorflow/lite/experimental/litert/runtime/compiled_model.h b/tensorflow/lite/experimental/litert/runtime/compiled_model.h deleted file mode 100644 index 792efae934e4b8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compiled_model.h +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
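Run() above guarantees that every buffer it locks is unlocked on all exit paths by registering the unlock loop with absl::MakeCleanup before any lock is taken. A minimal sketch of that pattern in isolation, assuming only absl/cleanup/cleanup.h and the LiteRT lock/unlock entry points used above (the function name is illustrative):

#include <vector>

#include "absl/cleanup/cleanup.h"

void LockAllThenAutoUnlock(const std::vector<LiteRtTensorBuffer>& buffers) {
  std::vector<LiteRtTensorBuffer> locked_buffers;
  locked_buffers.reserve(buffers.size());
  // Registered before locking so that early returns still unlock everything.
  auto unlock_buffers = absl::MakeCleanup([&locked_buffers] {
    for (LiteRtTensorBuffer locked : locked_buffers) {
      LiteRtUnlockTensorBuffer(locked);
    }
  });
  for (LiteRtTensorBuffer buffer : buffers) {
    void* host_mem_addr = nullptr;
    if (LiteRtLockTensorBuffer(buffer, &host_mem_addr) == kLiteRtStatusOk) {
      locked_buffers.push_back(buffer);
    }
  }
  // ... use the locked host pointers; the cleanup runs at scope exit.
}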
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_
-
-#include
-#include
-#include
-#include
-#include
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "tensorflow/compiler/mlir/lite/allocation.h"
-#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
-#include "tensorflow/lite/experimental/litert/core/environment.h"
-#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h"
-#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/model_builder.h"
-
-// LiteRtCompiledModelT is the internal implementation of the CompiledModel
-// C++ API.
-class LiteRtCompiledModelT {
- public:
-  using Ptr = std::unique_ptr<LiteRtCompiledModelT>;
-
-  LiteRtCompiledModelT() = default;
-  ~LiteRtCompiledModelT() = default;
-
-  // Creates a LiteRtCompiledModelT from a LiteRtModel object.
-  // The model is loaded into memory and the caller takes ownership of the
-  // returned object.
-  static litert::Expected<Ptr> Create(
-      LiteRtEnvironmentT* env, LiteRtModel model,
-      LiteRtCompilationOptions jit_compilation_options = nullptr);
-
-  // Returns the buffer requirements for the n-th input tensor. The returned
-  // LiteRtTensorBufferRequirements is used to create the input tensor
-  // buffer.
-  litert::Expected<LiteRtTensorBufferRequirements> GetInputBufferRequirements(
-      absl::string_view signature_key, size_t input_index);
-
-  // The same as GetInputBufferRequirements() for C API.
-  litert::Expected<LiteRtTensorBufferRequirements>
-  GetInputBufferRequirementsCApi(size_t signature_index, size_t input_index) {
-    if (signature_index >= signature_keys_.size()) {
-      return litert::Unexpected(
-          kLiteRtStatusErrorIndexOOB,
-          "Signature index is out of range of signature keys");
-    }
-    return GetInputBufferRequirements(*signature_keys_[signature_index],
-                                      input_index);
-  }
-
-  // Returns the buffer requirements for the n-th output tensor. The returned
-  // LiteRtTensorBufferRequirements is used to create the output tensor
-  // buffer.
-  litert::Expected<LiteRtTensorBufferRequirements> GetOutputBufferRequirements(
-      absl::string_view signature_key, size_t output_index);
-
-  // The same as GetOutputBufferRequirements() for C API.
-  litert::Expected<LiteRtTensorBufferRequirements>
-  GetOutputBufferRequirementsCApi(size_t signature_index, size_t output_index) {
-    if (signature_index >= signature_keys_.size()) {
-      return litert::Unexpected(
-          kLiteRtStatusErrorIndexOOB,
-          "Signature index is out of range of signature keys");
-    }
-    return GetOutputBufferRequirements(*signature_keys_[signature_index],
-                                       output_index);
-  }
-
-  // Runs the model of the given signature with the provided input/output
-  // litert::TensorBuffers. If parameter `async` is true, then the model is
-  // run asynchronously, if possible.
-  // Upon returning, the function sets parameter `async` to true if
-  // asynchronous execution was requested and possible; otherwise it sets it
-  // to false.
-  litert::Expected<void> Run(
-      absl::string_view signature_key,
-      const std::vector<LiteRtTensorBuffer>& input_buffers,
-      const std::vector<LiteRtTensorBuffer>& output_buffers, bool& async);
-
-  // The same as Run() for C API.
-  litert::Expected<void> RunCApi(size_t signature_index,
-                                 size_t num_input_buffers,
-                                 LiteRtTensorBuffer* input_buffers,
-                                 size_t num_output_buffers,
-                                 LiteRtTensorBuffer* output_buffers,
-                                 bool* async);
-
- private:
-  // Initializes the internal TFLite interpreter and related objects.
-  // This is called in the public Create*() methods.
-  // The flatbuffer_model_ must be set before calling this method.
-  litert::Expected<void> InitializeRuntime();
-
-  // Handles any JIT compilation and initializes the flatbuffer_model_ and
-  // related fields within the compiled model.
-  //
-  // If no JIT compilation is requested, the compiled model will point to the
-  // underlying tflite::Model* owned by the input litert model. The compiled
-  // model's alloc_ and model_buf_ will be nullptr, as these are only relevant
-  // when the compiled model owns a flatbuffer.
-  //
-  // If JIT compilation does occur, a new flatbuffer owned by the compiled
-  // model will be serialized from the result of compilation. The alloc_ and
-  // model_buf_ will be set for storage of the new flatbuffer.
-  //
-  // NOTE: JIT compilation invalidates the input litert model.
-  // TODO: Design a better abstraction for optional ownership of the
-  // flatbuffer; consider caching the JIT result.
-  litert::Expected<void> InitializeModel(LiteRtModelT& model,
-                                         LiteRtHwAcceleratorSet hw_accelerators,
-                                         LiteRtEnvironmentT& env);
-
-  // Returns the base address of the flatbuffer memory.
-  //
-  // If no JIT compilation has taken place, this points to flatbuffer memory
-  // owned by the incoming litert model (litert models always own their
-  // flatbuffer memory until serialization).
-  //
-  // If JIT compilation has taken place, this points to the base address of a
-  // newly serialized flatbuffer which is owned by the compiled model (in
-  // model_buf_).
-  //
-  // NOTE: This should never be nullptr after initialization.
-  const char* GetModelBase() {
-    if (fb_model_ == nullptr) {
-      return nullptr;
-    }
-
-    // fb_model_->allocation is only null when the flatbuffer is built with
-    // BuildFlatBufferFromModel, which is not currently in use in either
-    // litert::LoadModel or LiteRtCompiledModelT::Create.
-    const auto* alloc = fb_model_->allocation();
-    if (alloc) {
-      // NOTE: During JIT, alloc->base() == model_buf_.Data(), which is owned
-      // by the compiled model. Otherwise, model_buf_.Data() is nullptr and
-      // alloc->base() points to a buffer owned by the incoming litert model.
-      return reinterpret_cast<const char*>(alloc->base());
-    }
-
-    return nullptr;
-  }
-
-  // Returns the buffer requirements for the given tensor.
-  litert::Expected<LiteRtTensorBufferRequirements> GetTensorBufferRequirements(
-      const TfLiteTensor* tensor);
-
-  // Returns the SignatureRunner for the given signature key.
-  // If the signature key is not found, returns nullptr.
-  tflite::SignatureRunner* GetSignatureRunner(absl::string_view signature_key);
-
-  // Registers the TensorBuffer for the given tensor with the SignatureRunner.
-  // If the TensorBuffer can be directly consumed as a CPU tensor, it is
-  // locked and used with a CustomAllocation. The locked buffer is kept in
-  // `locked_buffers`; the caller is responsible for unlocking these buffers.
-  // If the TensorBuffer can be consumed by the delegate, then `tensor` will
-  // be marked as non-CPU to prevent TFLite from allocating it.
-  litert::Expected<void> RegisterBuffer(
-      tflite::SignatureRunner* runner, TfLiteTensor* tensor,
-      const char* tensor_name, LiteRtTensorBuffer buffer, bool is_input,
-      std::vector<LiteRtTensorBuffer>& locked_buffers);
-
-  void RegisterDelegate(tflite::TfLiteOpaqueDelegateUniquePtr&& delegate) {
-    delegates_.push_back(std::move(delegate));
-  }
-
-  // Checks the CPU Tensors and stores them in the `cpu_tensors_` set.
-  void CheckCpuTensors();
-
-  // Map from signature key to SignatureRunner. This is used to lazily call
-  // GetSignatureRunner(), which is expensive.
-  absl::flat_hash_map<absl::string_view, tflite::SignatureRunner*>
-      signature_runners_;
-
-  // The buffer requirement map for CPU buffers. Delegates that use CPU
-  // buffers don't register TensorBufferRequirements. Instead, the
-  // CompiledModel creates the TensorBufferRequirements and stores them
-  // in this map.
-  absl::flat_hash_map<const TfLiteTensor*, litert::TensorBufferRequirements>
-      cpu_buffer_requirements_;
-
-  // The Interpreter and related objects used to run the model.
-  std::unique_ptr<::tflite::Interpreter> interp_;
-  std::unique_ptr<::tflite::FlatBufferModel> fb_model_;
-  litert::OwningBufferRef<uint8_t> model_buf_;
-  std::vector<const std::string*> signature_keys_;
-
-  // The ExternalLiteRtBufferContext used to register tensor buffers with
-  // Delegates.
-  // Note: The ExternalLiteRtBufferContext must be destroyed after the
-  // Interpreter.
-  std::unique_ptr<litert::internal::ExternalLiteRtBufferContext>
-      buffer_context_;
-
-  std::vector<tflite::TfLiteOpaqueDelegateUniquePtr> delegates_;
-
-  // The set of CPU Tensors. This is used to manage TensorBufferRequirements
-  // for shared CPU Tensors.
-  absl::flat_hash_set<const TfLiteTensor*> cpu_tensors_;
-};
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_
diff --git a/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc b/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc
deleted file mode 100644
index 76797ac5074eed..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc
+++ /dev/null
@@ -1,544 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
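Taken together, the header above implies a short calling sequence, which the tests that follow exercise in full: create the compiled model, query the buffer requirements, then run a signature. A condensed sketch with error handling reduced to early returns (buffer creation is elided; the function name is illustrative):

bool CompileAndRunOnce(LiteRtEnvironmentT* env, LiteRtModel model,
                       absl::string_view signature_key,
                       const std::vector<LiteRtTensorBuffer>& inputs,
                       const std::vector<LiteRtTensorBuffer>& outputs) {
  // JIT options default to nullptr, in which case Create() builds a default
  // options object internally.
  auto compiled = LiteRtCompiledModelT::Create(env, model);
  if (!compiled) return false;
  // Requirements should be queried before allocating real buffers
  // (input index 0 shown).
  auto reqs = (*compiled)->GetInputBufferRequirements(signature_key, 0);
  if (!reqs) return false;
  bool async = false;  // Request synchronous execution.
  auto status = (*compiled)->Run(signature_key, inputs, outputs, async);
  return static_cast<bool>(status);
}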
- -#include "tensorflow/lite/experimental/litert/runtime/compiled_model.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -namespace litert { -namespace { - -using ::testing::ElementsAre; -using ::testing::FloatNear; -using ::testing::Pointwise; - -// Creates a tensor buffer of the given tensor, buffer type, and size. -Expected CreateBufferOfType( - const LiteRtTensorT& tensor, LiteRtTensorBufferType buffer_type, - size_t bytes) { - const LiteRtRankedTensorType ranked_tensor_type = - tensor.Type().second.ranked_tensor_type; - - LiteRtTensorBufferT* tensor_buffer; - LITERT_RETURN_IF_ERROR(LiteRtCreateManagedTensorBuffer( - buffer_type, &ranked_tensor_type, bytes, &tensor_buffer)); - - return tensor_buffer; -} - -// Creates input or output tensor buffers of the given model, buffer type and -// size. -Expected> CreateInputOutputBuffersOfType( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtTensorBufferType buffer_type, size_t bytes, bool is_input) { - LITERT_ASSIGN_OR_RETURN(const LiteRtSignatureT& signature, - model.FindSignature(signature_key)); - const LiteRtSubgraphT& subgraph = signature.GetSubgraph(); - - const std::vector& tensors = - is_input ? subgraph.Inputs() : subgraph.Outputs(); - - std::vector tensor_buffers; - tensor_buffers.reserve(tensors.size()); - - for (int i = 0; i < tensors.size(); ++i) { - LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT * tensor_buffer, - CreateBufferOfType(*tensors[i], buffer_type, bytes)); - tensor_buffers.push_back(tensor_buffer); - } - return tensor_buffers; -} - -// Creates input buffers of the given model, buffer type, and size. -Expected> CreateInputBuffersOfType( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtTensorBufferType buffer_type, size_t bytes) { - return CreateInputOutputBuffersOfType(model, signature_key, buffer_type, - bytes, /*is_input=*/true); -} - -// Creates output buffers of the given model, buffer type, and size. 
-Expected> CreateOutputBuffersOfType( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtTensorBufferType buffer_type, size_t bytes) { - return CreateInputOutputBuffersOfType(model, signature_key, buffer_type, - bytes, /*is_input=*/false); -} - -// Creates a tensor buffer of the given tensor and buffer requirements. -Expected CreateBufferFromRequirements( - const LiteRtTensorT& tensor, - const LiteRtTensorBufferRequirementsT& requirements) { - return CreateBufferOfType(tensor, requirements.SupportedBufferTypes().at(0), - requirements.BufferSize()); -} - -// Creates input or output tensor buffers of the given model and requirements. -Expected> -CreateInputOutputBuffersFromRequirements(LiteRtModelT& model, - absl::string_view signature_key, - LiteRtCompiledModelT& compiled_model, - bool is_input) { - LITERT_ASSIGN_OR_RETURN(const LiteRtSignatureT& signature, - model.FindSignature(signature_key)); - const LiteRtSubgraphT& subgraph = signature.GetSubgraph(); - - const std::vector& tensors = - is_input ? subgraph.Inputs() : subgraph.Outputs(); - - std::vector tensor_buffers; - tensor_buffers.reserve(tensors.size()); - - for (int i = 0; i < tensors.size(); ++i) { - Expected requirements_expected = - is_input ? compiled_model.GetInputBufferRequirements(signature_key, i) - : compiled_model.GetOutputBufferRequirements(signature_key, i); - LITERT_ASSIGN_OR_RETURN(LiteRtTensorBufferRequirementsT * requirements, - requirements_expected); - - LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT * tensor_buffer, - CreateBufferFromRequirements(*tensors[i], *requirements)); - tensor_buffers.push_back(tensor_buffer); - } - return tensor_buffers; -} - -// Creates input buffers of the given model and requirements. -Expected> CreateInputBuffersFromRequirements( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtCompiledModelT& compiled_model) { - return CreateInputOutputBuffersFromRequirements(model, signature_key, - compiled_model, - /*is_input=*/true); -} - -// Creates output buffers of the given model and requirements. -Expected> CreateOutputBuffersFromRequirements( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtCompiledModelT& compiled_model) { - return CreateInputOutputBuffersFromRequirements(model, signature_key, - compiled_model, - /*is_input=*/false); -} - -TEST(CompiledModelTest, Basic) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr env, - LiteRtEnvironmentT::CreateWithOptions({})); - LiteRtEnvironmentT* env_ptr = env.release(); - - // Create LiteRtModel and check signatures. - std::string path = testing::GetTestFilePath(kModelFileName); - LiteRtModel model; - ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk); - - absl::Span signatures = model->Signatures(); - ASSERT_EQ(signatures.size(), 1); - absl::string_view signature_key = signatures[0]->Key(); - EXPECT_EQ(signature_key, LiteRtSignatureT::kDefaultSignatureKey); - - const std::vector& input_names = signatures[0]->InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - const std::vector& output_names = signatures[0]->OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.add")); - - // Create CompiledModel with options. 
-  LiteRtCompilationOptions jit_compilation_options;
-  ASSERT_EQ(LiteRtCreateCompilationOptions(&jit_compilation_options),
-            kLiteRtStatusOk);
-  ASSERT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                jit_compilation_options, kLiteRtHwAcceleratorCpu),
-            kLiteRtStatusOk);
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtCompiledModelT::Ptr compiled_model,
-      LiteRtCompiledModelT::Create(env_ptr, model, jit_compilation_options));
-  LiteRtDestroyCompilationOptions(jit_compilation_options);
-
-  // Check CompiledModel buffer requirements.
-  // Input and output expect host memory.
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg0,
-      compiled_model->GetInputBufferRequirements(
-          /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey,
-          /*input_index=*/0));
-  const std::vector<LiteRtTensorBufferType>& input_buffer_types_arg0 =
-      input_buffer_requirements_arg0->SupportedBufferTypes();
-  EXPECT_THAT(input_buffer_types_arg0,
-              ElementsAre(kLiteRtTensorBufferTypeHostMemory));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg1,
-      compiled_model->GetInputBufferRequirements(
-          /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey,
-          /*input_index=*/1));
-  const std::vector<LiteRtTensorBufferType>& input_buffer_types_arg1 =
-      input_buffer_requirements_arg1->SupportedBufferTypes();
-  EXPECT_THAT(input_buffer_types_arg1,
-              ElementsAre(kLiteRtTensorBufferTypeHostMemory));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtTensorBufferRequirementsT * output_buffer_requirements,
-      compiled_model->GetOutputBufferRequirements(
-          /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey,
-          /*output_index=*/0));
-  const std::vector<LiteRtTensorBufferType>& output_buffer_types =
-      output_buffer_requirements->SupportedBufferTypes();
-  EXPECT_THAT(output_buffer_types,
-              ElementsAre(kLiteRtTensorBufferTypeHostMemory));
-
-  // Create and fill input and output LiteRtTensorBuffers. Buffers are
-  // created to match CompiledModel's TensorBufferRequirements.
-  LITERT_ASSERT_OK_AND_ASSIGN(std::vector<LiteRtTensorBuffer> input_buffers,
-                              CreateInputBuffersFromRequirements(
-                                  *model, signature_key, *compiled_model));
-  LITERT_ASSERT_OK_AND_ASSIGN(std::vector<LiteRtTensorBuffer> output_buffers,
-                              CreateOutputBuffersFromRequirements(
-                                  *model, signature_key, *compiled_model));
-
-  LiteRtTensorBuffer& input_0_buffer = input_buffers[0];
-  {
-    TensorBuffer cpu_buffer(input_0_buffer, /*owned=*/false);
-    cpu_buffer.Write<float>(
-        absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size));
-  }
-  LiteRtTensorBuffer& input_1_buffer = input_buffers[1];
-  {
-    TensorBuffer cpu_buffer(input_1_buffer, /*owned=*/false);
-    cpu_buffer.Write<float>(
-        absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size));
-  }
-
-  // Execute model.
-  bool async = false;
-  compiled_model->Run(signature_key, input_buffers, output_buffers, async);
-
-  // Check model output.
-  {
-    void* host_mem_addr;
-    ASSERT_EQ(LiteRtLockTensorBuffer(output_buffers[0], &host_mem_addr),
-              kLiteRtStatusOk);
-    absl::Span<const float> output = absl::MakeSpan(
-        static_cast<const float*>(host_mem_addr), kTestOutputSize);
-    for (auto i = 0; i < kTestOutputSize; ++i) {
-      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
-    }
-    EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor));
-    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_buffers[0]), kLiteRtStatusOk);
-  }
-
-  // Since the buffers are raw LiteRtTensorBuffer handles, we need to destroy
-  // them explicitly.
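// An alternative to the explicit destruction below: wrap each raw handle in
// an owning litert::TensorBuffer so the destructor releases it. A sketch,
// assuming the constructor used above also accepts /*owned=*/true:
void DestroyViaOwningWrapper(LiteRtTensorBuffer raw) {
  // Takes ownership; the underlying buffer is destroyed when `owner` goes
  // out of scope, replacing an explicit LiteRtDestroyTensorBuffer() call.
  litert::TensorBuffer owner(raw, /*owned=*/true);
}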
-  for (auto& input_buffer : input_buffers) {
-    LiteRtDestroyTensorBuffer(input_buffer);
-  }
-  for (auto& output_buffer : output_buffers) {
-    LiteRtDestroyTensorBuffer(output_buffer);
-  }
-
-  LiteRtDestroyModel(model);
-  LiteRtDestroyEnvironment(env_ptr);
-}
-
-TEST(CompiledModelTest, UseAhwbBuffer) {
-#if !defined(__ANDROID__)
-  GTEST_SKIP() << "The rest of this test is specific to Android devices";
-#endif
-  // Environment setup.
-  LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr env,
-                              LiteRtEnvironmentT::CreateWithOptions({}));
-  LiteRtEnvironmentT* env_ptr = env.release();
-
-  // Create LiteRtModel and check signatures.
-  std::string path = testing::GetTestFilePath(kModelFileName);
-  LiteRtModel model;
-  ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk);
-
-  auto signatures = model->Signatures();
-  ASSERT_EQ(signatures.size(), 1);
-  absl::string_view signature_key = signatures[0]->Key();
-  EXPECT_EQ(signature_key, LiteRtSignatureT::kDefaultSignatureKey);
-
-  const std::vector<std::string>& input_names = signatures[0]->InputNames();
-  EXPECT_THAT(input_names, ElementsAre("arg0", "arg1"));
-
-  const std::vector<std::string>& output_names = signatures[0]->OutputNames();
-  EXPECT_THAT(output_names, ElementsAre("tfl.add"));
-
-  // Create CompiledModel with options.
-  LiteRtCompilationOptions jit_compilation_options;
-  ASSERT_EQ(LiteRtCreateCompilationOptions(&jit_compilation_options),
-            kLiteRtStatusOk);
-  ASSERT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                jit_compilation_options, kLiteRtHwAcceleratorCpu),
-            kLiteRtStatusOk);
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtCompiledModelT::Ptr compiled_model,
-      LiteRtCompiledModelT::Create(env_ptr, model, jit_compilation_options));
-  LiteRtDestroyCompilationOptions(jit_compilation_options);
-
-  // Check that input and output buffer requirements expect host memory.
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg0,
-      compiled_model->GetInputBufferRequirements(
-          /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey,
-          /*input_index=*/0));
-  const std::vector<LiteRtTensorBufferType>& input_buffer_types_arg0 =
-      input_buffer_requirements_arg0->SupportedBufferTypes();
-  EXPECT_THAT(input_buffer_types_arg0,
-              ElementsAre(kLiteRtTensorBufferTypeHostMemory));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg1,
-      compiled_model->GetInputBufferRequirements(
-          /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey,
-          /*input_index=*/1));
-  const std::vector<LiteRtTensorBufferType>& input_buffer_types_arg1 =
-      input_buffer_requirements_arg1->SupportedBufferTypes();
-  EXPECT_THAT(input_buffer_types_arg1,
-              ElementsAre(kLiteRtTensorBufferTypeHostMemory));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtTensorBufferRequirementsT * output_buffer_requirements,
-      compiled_model->GetOutputBufferRequirements(
-          /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey,
-          /*output_index=*/0));
-  const std::vector<LiteRtTensorBufferType>& output_buffer_types =
-      output_buffer_requirements->SupportedBufferTypes();
-  EXPECT_THAT(output_buffer_types,
-              ElementsAre(kLiteRtTensorBufferTypeHostMemory));
-
-  // Create and fill input and output buffers. CompiledModel's
-  // TensorBufferRequirements expect host memory, but we create AHWB buffers.
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      std::vector<LiteRtTensorBuffer> input_buffers,
-      CreateInputBuffersOfType(*model, signature_key,
-                               kLiteRtTensorBufferTypeAhwb,
-                               sizeof(float) * kTestInput0Size));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      std::vector<LiteRtTensorBuffer> output_buffers,
-      CreateOutputBuffersOfType(*model, signature_key,
-                                kLiteRtTensorBufferTypeAhwb,
-                                sizeof(float) * kTestOutputSize));
-
-  LiteRtTensorBuffer& input_0_buffer = input_buffers[0];
-  EXPECT_EQ(input_0_buffer->buffer_type(), kLiteRtTensorBufferTypeAhwb);
-  {
-    TensorBuffer ahwb_buffer(input_0_buffer, /*owned=*/false);
-    ahwb_buffer.Write<float>(
-        absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size));
-  }
-  LiteRtTensorBuffer& input_1_buffer = input_buffers[1];
-  {
-    TensorBuffer ahwb_buffer(input_1_buffer, /*owned=*/false);
-    ahwb_buffer.Write<float>(
-        absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size));
-  }
-
-  // Execute model.
-  bool async = false;
-  compiled_model->Run(signature_key, input_buffers, output_buffers, async);
-
-  // Check model output.
-  {
-    void* host_mem_addr;
-    ASSERT_EQ(LiteRtLockTensorBuffer(output_buffers[0], &host_mem_addr),
-              kLiteRtStatusOk);
-    absl::Span<const float> output = absl::MakeSpan(
-        static_cast<const float*>(host_mem_addr), kTestOutputSize);
-    for (auto i = 0; i < kTestOutputSize; ++i) {
-      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
-    }
-    EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor));
-    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_buffers[0]), kLiteRtStatusOk);
-  }
-
-  // Since the buffers are raw LiteRtTensorBuffer handles, we need to destroy
-  // them explicitly.
-  for (auto& input_buffer : input_buffers) {
-    LiteRtDestroyTensorBuffer(input_buffer);
-  }
-  for (auto& output_buffer : output_buffers) {
-    LiteRtDestroyTensorBuffer(output_buffer);
-  }
-
-  LiteRtDestroyModel(model);
-  LiteRtDestroyEnvironment(env_ptr);
-}
-
-TEST(CompiledModelTest, UseOpenCLBuffer) {
-  // MSAN does not support GPU tests.
-#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER)
-  GTEST_SKIP() << "GPU tests are not supported in MSAN";
-#endif
-
-  if (!litert::internal::OpenClBuffer::IsSupported()) {
-    GTEST_SKIP() << "OpenCL buffers are not supported on this platform; "
-                    "skipping the test";
-  }
-  // Environment setup.
-  LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr env,
-                              LiteRtEnvironmentT::CreateWithOptions({}));
-  LiteRtEnvironmentT* env_ptr = env.release();
-
-  // Create LiteRtModel and check signatures.
-  std::string path = testing::GetTestFilePath(kModelFileName);
-  LiteRtModel model;
-  ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk);
-
-  auto signatures = model->Signatures();
-  ASSERT_EQ(signatures.size(), 1);
-  absl::string_view signature_key = signatures[0]->Key();
-  EXPECT_EQ(signature_key, LiteRtSignatureT::kDefaultSignatureKey);
-
-  const std::vector<std::string>& input_names = signatures[0]->InputNames();
-  EXPECT_THAT(input_names, ElementsAre("arg0", "arg1"));
-
-  const std::vector<std::string>& output_names = signatures[0]->OutputNames();
-  EXPECT_THAT(output_names, ElementsAre("tfl.add"));
-
-  // Create CompiledModel with options.
-  LiteRtCompilationOptions jit_compilation_options;
-  ASSERT_EQ(LiteRtCreateCompilationOptions(&jit_compilation_options),
-            kLiteRtStatusOk);
-  ASSERT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators(
-                jit_compilation_options, kLiteRtHwAcceleratorCpu),
-            kLiteRtStatusOk);
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtCompiledModelT::Ptr compiled_model,
-      LiteRtCompiledModelT::Create(env_ptr, model, jit_compilation_options));
-  LiteRtDestroyCompilationOptions(jit_compilation_options);
-
-  // Check CompiledModel buffer requirements.
-  // Input and output expect host memory.
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg0,
-      compiled_model->GetInputBufferRequirements(
-          /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey,
-          /*input_index=*/0));
-  const std::vector<LiteRtTensorBufferType>& input_buffer_types_arg0 =
-      input_buffer_requirements_arg0->SupportedBufferTypes();
-  EXPECT_THAT(input_buffer_types_arg0,
-              ElementsAre(kLiteRtTensorBufferTypeHostMemory));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg1,
-      compiled_model->GetInputBufferRequirements(
-          /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey,
-          /*input_index=*/1));
-  const std::vector<LiteRtTensorBufferType>& input_buffer_types_arg1 =
-      input_buffer_requirements_arg1->SupportedBufferTypes();
-  EXPECT_THAT(input_buffer_types_arg1,
-              ElementsAre(kLiteRtTensorBufferTypeHostMemory));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      LiteRtTensorBufferRequirementsT * output_buffer_requirements,
-      compiled_model->GetOutputBufferRequirements(
-          /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey,
-          /*output_index=*/0));
-  const std::vector<LiteRtTensorBufferType>& output_buffer_types =
-      output_buffer_requirements->SupportedBufferTypes();
-  EXPECT_THAT(output_buffer_types,
-              ElementsAre(kLiteRtTensorBufferTypeHostMemory));
-
-  // Create and fill input and output buffers. CompiledModel's
-  // TensorBufferRequirements expect host memory, but we create OpenCL
-  // buffers.
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      std::vector<LiteRtTensorBuffer> input_buffers,
-      CreateInputBuffersOfType(*model, signature_key,
-                               kLiteRtTensorBufferTypeOpenCl,
-                               sizeof(float) * kTestInput0Size));
-
-  LITERT_ASSERT_OK_AND_ASSIGN(
-      std::vector<LiteRtTensorBuffer> output_buffers,
-      CreateOutputBuffersOfType(*model, signature_key,
-                                kLiteRtTensorBufferTypeOpenCl,
-                                sizeof(float) * kTestOutputSize));
-
-  // Fill model inputs.
-  LiteRtTensorBuffer& input_0_buffer = input_buffers[0];
-  EXPECT_EQ(input_0_buffer->buffer_type(), kLiteRtTensorBufferTypeOpenCl);
-  {
-    TensorBuffer opencl_buffer(input_0_buffer, /*owned=*/false);
-    opencl_buffer.Write<float>(
-        absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size));
-  }
-  LiteRtTensorBuffer& input_1_buffer = input_buffers[1];
-  {
-    TensorBuffer opencl_buffer(input_1_buffer, /*owned=*/false);
-    opencl_buffer.Write<float>(
-        absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size));
-  }
-
-  // Execute model.
-  bool async = false;
-  compiled_model->Run(signature_key, input_buffers, output_buffers, async);
-
-  // Check model output.
-  {
-    void* host_mem_addr;
-    ASSERT_EQ(LiteRtLockTensorBuffer(output_buffers[0], &host_mem_addr),
-              kLiteRtStatusOk);
-    absl::Span<const float> output = absl::MakeSpan(
-        static_cast<const float*>(host_mem_addr), kTestOutputSize);
-    for (auto i = 0; i < kTestOutputSize; ++i) {
-      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
-    }
-    EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor));
-
-    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_buffers[0]), kLiteRtStatusOk);
-  }
-
-  // Since the buffers are raw LiteRtTensorBuffer handles, we need to destroy
-  // them explicitly.
-  for (auto& input_buffer : input_buffers) {
-    LiteRtDestroyTensorBuffer(input_buffer);
-  }
-  for (auto& output_buffer : output_buffers) {
-    LiteRtDestroyTensorBuffer(output_buffer);
-  }
-
-  LiteRtDestroyModel(model);
-  LiteRtDestroyEnvironment(env_ptr);
-}
-}  // namespace
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/runtime/compiler/BUILD b/tensorflow/lite/experimental/litert/runtime/compiler/BUILD
deleted file mode 100644
index 43bef76096cbe1..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/compiler/BUILD
+++ /dev/null
@@ -1,89 +0,0 @@
-# Copyright 2024 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"],
-)
-
-cc_test(
-    name = "jit_compilation_qualcomm_test",
-    srcs = ["jit_compilation_qualcomm_test.cc"],
-    data = [
-        "//tensorflow/lite/experimental/litert/test:simple_model",
-        "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:qnn_compiler_plugin_so",
-        "//tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so",
-    ],
-    linkopts = select({
-        "//tensorflow:android": ["-landroid"],
-        "//conditions:default": [],
-    }),
-    deps = [
-        "//tensorflow/lite:framework",
-        "//tensorflow/lite/c:c_api_opaque",
-        "//tensorflow/lite/c:common",
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/cc:litert_compiled_model",
-        "//tensorflow/lite/experimental/litert/cc:litert_environment",
-        "//tensorflow/lite/experimental/litert/cc:litert_model",
-        "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer",
-        "//tensorflow/lite/experimental/litert/test:common",
-        "//tensorflow/lite/experimental/litert/test:matchers",
-        "//tensorflow/lite/experimental/litert/test:simple_model_npu",
-        "//tensorflow/lite/kernels:builtin_ops",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:absl_log",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-cc_test(
-    name = "jit_compilation_mediatek_test",
-    srcs = ["jit_compilation_mediatek_test.cc"],
-    data = [
-        "//tensorflow/lite/experimental/litert/test:simple_model",
-        "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:compiler_plugin_so",
"//tensorflow/lite/experimental/litert/vendors/mediatek/dispatch:dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "no_oss", - "nobuilder", - "notap", - ], - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc b/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc deleted file mode 100644 index 4e3b2f24d87c2c..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -constexpr absl::string_view kCompilerPluginLibSearchPath = "/data/local/tmp"; -constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -using testing::FloatNear; -using testing::Pointwise; - -TEST(JitCompilation, MediaTek) { - const std::array environment_options = { - litert::Environment::Option{ - /*.tag=*/litert::Environment::OptionTag::CompilerPluginLibraryDir, - /*.value=*/kCompilerPluginLibSearchPath, - }, - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(auto environment, - litert::Environment::Create(environment_options)); - - auto model_path = litert::testing::GetTestFilePath(kModelFileName); - LITERT_ASSERT_OK_AND_ASSIGN(auto model, - litert::Model::CreateFromFile(model_path)); - - auto num_signatures = model.GetNumSignatures(); - ASSERT_EQ(num_signatures, 1); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - - LITERT_ASSERT_OK_AND_ASSIGN(auto compiled_model, - litert::CompiledModel::Create( - environment, model, kLiteRtHwAcceleratorNpu)); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto input_buffers, - compiled_model.CreateInputBuffers(/*signature_index=*/0)); - EXPECT_EQ(input_buffers.size(), 2); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto output_buffers, - compiled_model.CreateOutputBuffers(/*signature_index=*/0)); - EXPECT_EQ(output_buffers.size(), 1); - - LITERT_ASSERT_OK(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - LITERT_ASSERT_OK(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(/*signature_index=*/0, input_buffers, output_buffers); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} diff --git a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc b/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc deleted file mode 100644 index 1f7a3366f86af5..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -constexpr absl::string_view kCompilerPluginLibSearchPath = "/data/local/tmp"; -constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -using testing::FloatNear; -using testing::Pointwise; - -TEST(JitCompilation, Qualcomm) { - const std::array environment_options = { - litert::Environment::Option{ - /*.tag=*/litert::Environment::OptionTag::CompilerPluginLibraryDir, - /*.value=*/kCompilerPluginLibSearchPath, - }, - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(auto environment, - litert::Environment::Create(environment_options)); - - auto model_path = litert::testing::GetTestFilePath(kModelFileName); - LITERT_ASSERT_OK_AND_ASSIGN(auto model, - litert::Model::CreateFromFile(model_path)); - - auto num_signatures = model.GetNumSignatures(); - ASSERT_EQ(num_signatures, 1); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - - LITERT_ASSERT_OK_AND_ASSIGN(auto compiled_model, - litert::CompiledModel::Create( - environment, model, kLiteRtHwAcceleratorNpu)); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto input_buffers, - compiled_model.CreateInputBuffers(/*signature_index=*/0)); - EXPECT_EQ(input_buffers.size(), 2); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto output_buffers, - compiled_model.CreateOutputBuffers(/*signature_index=*/0)); - EXPECT_EQ(output_buffers.size(), 1); - - LITERT_ASSERT_OK(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - LITERT_ASSERT_OK(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(/*signature_index=*/0, input_buffers, output_buffers); - - // Check model output. 
- { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD b/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD deleted file mode 100644 index 6e8886c434994f..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:special_rule.bzl", "gles_linkopts") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - "//third_party/odml/infra/perf/mobile_tests/litert:__subpackages__", - ], -) - -# Dispatch API implementation, it is used by the dispatch delegate to call the vendor's dispatch -# API. -cc_library( - name = "dispatch", - srcs = [ - "litert_dispatch.cc", - ], - hdrs = [ - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch.h", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_api.h", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - "//tensorflow/lite/experimental/litert/core:dynamic_loading", - "//tensorflow/lite/experimental/litert/core:version", - ], - alwayslink = 1, -) - -cc_library( - name = "dispatch_delegate", - srcs = [ - "dispatch_delegate.cc", - "dispatch_delegate_kernel.cc", - ], - hdrs = [ - "dispatch_delegate_kernel.h", - "dispatch_delegate_options.h", - "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate.h", - "//tensorflow/lite/experimental/litert/cc:litert_dispatch_delegate.h", - ], - deps = [ - ":dispatch", - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core/c:c_api_opaque_without_op_resolver", - "//tensorflow/lite/delegates/utils:simple_opaque_delegate", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment_options", - "//tensorflow/lite/experimental/litert/c:litert_event", - 
"//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - "//tensorflow/lite/experimental/litert/core:environment_options", - "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", - "//tensorflow/lite/experimental/litert/runtime:tfl_utils", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "dispatch_delegate_google_tensor_test", - srcs = ["dispatch_delegate_google_tensor_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/shared_input_cpu_npu.tflite", - "//tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch:dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }) + gles_linkopts(), - deps = [ - ":dispatch_delegate", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/darwinn/driver_shared/fence:fence_test_util", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_event", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/core/model:model_buffer", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - ], -) - -cc_test( - name = "dispatch_delegate_qualcomm_test", - srcs = ["dispatch_delegate_qualcomm_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/shared_input_cpu_npu.tflite", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - deps = [ - ":dispatch_delegate", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - 
"//tensorflow/lite/experimental/litert/cc:litert_dispatch_delegate", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "dispatch_delegate_mediatek_test", - srcs = ["dispatch_delegate_mediatek_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/shared_input_cpu_npu.tflite", - "//tensorflow/lite/experimental/litert/vendors/mediatek/dispatch:dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "no_oss", - "nobuilder", - "notap", - ], - deps = [ - ":dispatch_delegate", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/core/model:model_buffer", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/README.md b/tensorflow/lite/experimental/litert/runtime/dispatch/README.md deleted file mode 100644 index 5a2e33e0806a8c..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/README.md +++ /dev/null @@ -1,20 +0,0 @@ -## Google Tensor - -Test case can dispatch_delegate_google_tensor_test can be run on a device with a -Pixel 9 device with the following comands - -$ ../../google/run_test_on_android.sh dispatch_delegate_google_tensor_test - -## Qualcomm - -Test case can dispatch_delegate_qualcomm_test can be run on a Samsung S24 device -with the following comands - -$ ../../google/run_test_on_android.sh dispatch_delegate_qualcomm_test - -## MediaTek - -Test case can dispatch_delegate_mediatek_test can be run on a device with a -MetiaTek mt6989 SoC with the following comands - -$ ../../google/run_test_on_android.sh dispatch_delegate_mediatek_test diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc deleted file mode 100644 index 2b69430a2eae19..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2024 
Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h" -#include "tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -namespace { - -using ::litert::internal::kLiteRtDispatchOpCustomCode; - -// A TFL Delegate that can recognize subgraphs that run on Dispatch API capable -// accelerators, e.g. TPU, DSP, ... It replaces such subgraphs and offloads -// their work through the Dispatch API. 
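-// A minimal usage sketch (illustrative only; the vendor tests further below
-// show the full setup). An application creates the delegate from environment
-// options and installs it on an interpreter before tensor allocation:
-//
-//   TfLiteOpaqueDelegate* delegate =
-//       LiteRtCreateDispatchDelegate(env_options, /*options=*/nullptr);
-//   interpreter.ModifyGraphWithDelegate(delegate);
-//   // ... run inference ...
-//   LiteRtDestroyDispatchDelegate(delegate);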
-class DispatchDelegate : public tflite::SimpleOpaqueDelegateInterface { - public: - static TfLiteOpaqueDelegate* Create(LiteRtDispatchDelegateOptions* options_) { - litert::DispatchDelegateOptionsPtr options( - options_, LiteRtDestroyDispatchDelegateOptions); - if (!options) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return nullptr; - } - - std::unique_ptr managed_sb_delegate( - new DispatchDelegate(std::move(options))); - return tflite::TfLiteOpaqueDelegateFactory::CreateSimpleDelegate( - std::move(managed_sb_delegate), - kTfLiteDelegateFlagsAllowDynamicTensors); - } - - bool IsNodeSupportedByDelegate(const TfLiteOperator* op, - const TfLiteOpaqueNode* node, - TfLiteOpaqueContext* context) const override; - - TfLiteStatus Initialize(TfLiteOpaqueContext* context) override; - - const char* Name() const override; - - std::unique_ptr - CreateDelegateKernelInterface() override; - - private: - static constexpr absl::string_view kDelegateName = "DispatchDelegate"; - - explicit DispatchDelegate(litert::DispatchDelegateOptionsPtr&& options) - : options_(std::move(options)) {} - - litert::DispatchDelegateOptionsPtr options_; - int dispatch_graph_name_id_ = 0; -}; - -bool DispatchDelegate::IsNodeSupportedByDelegate( - const TfLiteOperator* op, const TfLiteOpaqueNode* node, - TfLiteOpaqueContext* context) const { - auto custom_code = absl::string_view(TfLiteOperatorGetCustomName(op)); - return custom_code == kLiteRtDispatchOpCustomCode; -} - -TfLiteStatus DispatchDelegate::Initialize(TfLiteOpaqueContext* context) { - return kTfLiteOk; -} - -const char* DispatchDelegate::Name() const { return kDelegateName.data(); } - -std::unique_ptr -DispatchDelegate::CreateDelegateKernelInterface() { - std::string dispatch_graph_name = - absl::StrFormat("DispatchGraph_%d", dispatch_graph_name_id_++); - - auto kernel = litert::internal::DispatchDelegateKernel::Create( - std::move(dispatch_graph_name), *options_); - if (kernel) { - return std::move(*kernel); - } else { - LITERT_FATAL("Failed to create a dispatch delegate kernel: %s", - kernel.Error().Message().c_str()); - return nullptr; - } -} - -} // namespace - -LiteRtDispatchDelegateOptions* LiteRtCreateDefaultDispatchDelegateOptions( - LiteRtEnvironmentOptions environment_options) { - return new LiteRtDispatchDelegateOptions(environment_options); -} - -TfLiteStatus LiteRtAddDispatchDelegateOption( - LiteRtDispatchDelegateOptions* options, LiteRtDispatchOption option) { - if (!options) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kTfLiteError; - } - - options->AddOption(option); - return kTfLiteOk; -} - -TfLiteStatus LiteRtDispatchDelegateAddAllocBaseOption( - LiteRtDispatchDelegateOptions* options, const void* alloc_base) { - AddAllocBaseOption(alloc_base, *options); - return kTfLiteOk; -} - -TfLiteStatus LiteRtDispatchDelegateAddAllocFdOption( - LiteRtDispatchDelegateOptions* options, int alloc_fd) { - AddAllocFdOption(alloc_fd, *options); - return kTfLiteOk; -} - -void LiteRtDestroyDispatchDelegateOptions( - LiteRtDispatchDelegateOptions* options) { - delete options; -} - -TfLiteOpaqueDelegate* LiteRtCreateDispatchDelegate( - LiteRtEnvironmentOptions environment_options, - LiteRtDispatchDelegateOptions* options) { - if (!options) { - options = LiteRtCreateDefaultDispatchDelegateOptions(environment_options); - } - return DispatchDelegate::Create(options); -} - -void LiteRtDestroyDispatchDelegate(TfLiteOpaqueDelegate* delegate) { - tflite::TfLiteOpaqueDelegateFactory::DeleteSimpleDelegate(delegate); -} - -namespace litert { - 
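-// C++ conveniences over the C API above: the smart-pointer aliases pair each
-// created object with its matching destroy function, so options and delegates
-// are released automatically when they go out of scope.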
-DispatchDelegateOptionsPtr CreateDispatchDelegateOptionsPtr( - LiteRtEnvironmentOptions environment_options) { - return {LiteRtCreateDefaultDispatchDelegateOptions(environment_options), - LiteRtDestroyDispatchDelegateOptions}; -} - -DispatchDelegatePtr CreateDispatchDelegatePtr( - LiteRtEnvironmentOptions environment_options, - DispatchDelegateOptionsPtr&& options) { - return DispatchDelegatePtr( - LiteRtCreateDispatchDelegate(environment_options, options.release()), - LiteRtDestroyDispatchDelegate); -} -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc deleted file mode 100644 index b75f91627cb0db..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc +++ /dev/null @@ -1,529 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" - -#if defined(__ANDROID__) -#include "platforms/darwinn/tachyon/core/fence/fence.h" -#endif -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/darwinn/driver_shared/fence/fence_test_util.h" -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include 
"tensorflow/lite/interpreter.h" -#include "tensorflow/lite/signature_runner.h" - -using litert::testing::MakeRuntimeFromTestFileWithNpuModel; -using testing::FloatNear; -using testing::Pointwise; -using Fence = std::shared_ptr; -using ::testing::ElementsAre; - -namespace litert { -namespace { - -constexpr absl::string_view kNpuFile = kGoogleTensorModelFileName; -constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -litert::Expected CreateDefaultEnvironment() { - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - return litert::Environment::Create(absl::MakeConstSpan(environment_options)); -} - -TEST(DispatchDelegate, GoogleTensorCpuBuffer) { - LITERT_ASSERT_OK_AND_ASSIGN( - testing::TflRuntime::Ptr runtime, - MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile)); - tflite::Interpreter& interpreter = runtime->Interpreter(); - - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, CreateDefaultEnvironment()); - - internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options = nullptr; - LiteRtGetEnvironmentOptions(env.Get(), &env_options); - DispatchDelegateOptionsPtr dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - runtime->Flatbuffer().Buf().Data()); - DispatchDelegatePtr dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Get the list of signatures and check it. - auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - TfLiteTensor* input_0_tensor = runner->input_tensor("arg0"); - ASSERT_NE(input_0_tensor, nullptr); - float* input_0 = input_0_tensor->data.f; - std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - TfLiteTensor* input_1_tensor = runner->input_tensor("arg1"); - ASSERT_NE(input_1_tensor, nullptr); - auto* input_1 = input_1_tensor->data.f; - std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. 
- ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto output_tensor = runner->output_tensor("tfl.custom"); - ASSERT_NE(output_tensor, nullptr); - auto output = absl::MakeSpan(output_tensor->data.f, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(::testing::FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, GoogleTensorHwBuffer) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, CreateDefaultEnvironment()); - - LITERT_ASSERT_OK_AND_ASSIGN( - testing::TflRuntime::Ptr runtime, - MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile)); - tflite::Interpreter& interpreter = runtime->Interpreter(); - - internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options = nullptr; - LiteRtGetEnvironmentOptions(env.Get(), &env_options); - - DispatchDelegateOptionsPtr dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - runtime->Flatbuffer().Buf().Data()); - DispatchDelegatePtr dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Create and register tensor buffers for all inputs and outputs. 
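- // Each tensor gets an AHWB-backed buffer whose duplicate handle is handed
- // to the buffer context; the delegate can then consume the same underlying
- // hardware buffer directly instead of staging data through CPU memory.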
- std::vector input_buffers; - for (int i = 0; i < interpreter.inputs().size(); ++i) { - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements * input_buffer_requirements, - buffer_context.GetBufferRequirement(interpreter.input_tensor(i))); - ASSERT_EQ(input_buffer_requirements->SupportedTypes()->at(0), - kLiteRtTensorBufferTypeAhwb); - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer input_buffer, - buffer_context.CreateBufferForTensor(interpreter.input_tensor(i))); - ASSERT_TRUE(input_buffer.IsOwned()); - ASSERT_EQ(*input_buffer.BufferType(), kLiteRtTensorBufferTypeAhwb); - LITERT_ASSERT_OK_AND_ASSIGN(TensorBuffer duplicate_buffer, - input_buffer.Duplicate()); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.input_tensor(i), std::move(duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - input_buffers.push_back(std::move(input_buffer)); - } - - std::vector output_buffers; - for (int i = 0; i < interpreter.outputs().size(); ++i) { - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements * output_buffer_requirements, - buffer_context.GetBufferRequirement(interpreter.output_tensor(i))); - ASSERT_NE(output_buffer_requirements, nullptr); - ASSERT_EQ(output_buffer_requirements->SupportedTypes()->at(0), - kLiteRtTensorBufferTypeAhwb); - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer output_buffer, - buffer_context.CreateBufferForTensor(interpreter.output_tensor(i))); - ASSERT_TRUE(output_buffer.IsOwned()); - ASSERT_EQ(*output_buffer.BufferType(), kLiteRtTensorBufferTypeAhwb); - LITERT_ASSERT_OK_AND_ASSIGN(TensorBuffer duplicate_buffer, - output_buffer.Duplicate()); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.output_tensor(i), std::move(duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - output_buffers.push_back(std::move(output_buffer)); - } - - // Get the list of signatures and check it. - auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto& input_0_buffer = input_buffers[0]; - input_0_buffer.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto& input_1_buffer = input_buffers[1]; - input_1_buffer.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. - ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto& output_buffer = output_buffers[0]; - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - auto read_success = output_buffer.Read(output_span); - ASSERT_TRUE(read_success); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModel) { -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, CreateDefaultEnvironment()); - - // Create Model and check signatures. 
- LITERT_ASSERT_OK_AND_ASSIGN( - OwningBufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector signatures, - model.GetSignatures()); - EXPECT_EQ(signatures.size(), 1); - Signature& signature = signatures.at(0); - EXPECT_EQ(signature.Key(), Model::DefaultSignatureKey()); - size_t signature_index = 0; - - std::vector input_names = signature.InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - std::vector output_names = signature.OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.custom")); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Check CompiledModel buffer requirements. - // input and output expect AHWB. - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg0, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg0, - input_buffer_requirements_arg0.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeAhwb)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg1, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg1, - input_buffer_requirements_arg1.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeAhwb)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements output_buffer_requirements, - compiled_model.GetOutputBufferRequirements(signature_index, - /*output_name=*/"tfl.custom")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffer_types, - output_buffer_requirements.SupportedTypes()); - EXPECT_THAT(output_buffer_types, ElementsAre(kLiteRtTensorBufferTypeAhwb)); - - // Create and fill input and output tensor buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(signature_index)); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(signature_index)); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute compiled model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- float output_buffer_data[kTestOutputSize]; - absl::Span output_span = - absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModelSharedInput) { - auto model_with_byte_code = internal::GetModelBufWithByteCode( - testing::GetTestFilePath("shared_input_cpu_npu.tflite"), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - size_t signature_index = 0; - auto signature = *model->GetSignature(signature_index); - auto input_buffers = *compiled_model.CreateInputBuffers(signature_index); - auto output_buffers = *compiled_model.CreateOutputBuffers(signature_index); - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. - auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 2); - { - EXPECT_EQ(output_names.at(0), "tfl.add"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } - { - EXPECT_EQ(output_names.at(1), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[1].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(DispatchDelegate, CompiledModelAsync) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - // Environment setup. 
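- // The DispatchLibraryDir option tells the runtime where to load the
- // vendor's dispatch shared library from on the device (/data/local/tmp
- // here).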
- LITERT_ASSERT_OK_AND_ASSIGN(Environment env, CreateDefaultEnvironment()); - - // Create Model and check signatures. - LITERT_ASSERT_OK_AND_ASSIGN( - OwningBufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector signatures, - model.GetSignatures()); - EXPECT_EQ(signatures.size(), 1); - Signature& signature = signatures.at(0); - absl::string_view signature_key = signature.Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - std::vector input_names = signature.InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - std::vector output_names = signature.OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.custom")); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Create and fill input and output tensor buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN(auto input_0_cpu_addr_and_lock, - TensorBufferScopedLock::Create(input_buffers[0])); - - LITERT_ASSERT_OK_AND_ASSIGN(auto input_1_cpu_addr_and_lock, - TensorBufferScopedLock::Create(input_buffers[1])); - - // Attach events to input buffers. - Fence input_fence_0 = platforms::darwinn::fence_util::CreateFence(); - LITERT_ASSERT_OK_AND_ASSIGN( - Event input_event_0, - litert::Event::CreateFromSyncFenceFd(input_fence_0->GetFd(), - /*owns_fd=*/false)); - input_buffers[0].SetEvent(std::move(input_event_0)); - - Fence input_fence_1 = platforms::darwinn::fence_util::CreateFence(); - LITERT_ASSERT_OK_AND_ASSIGN( - Event input_event_1, - litert::Event::CreateFromSyncFenceFd(input_fence_1->GetFd(), - /*owns_fd=*/false)); - input_buffers[1].SetEvent(std::move(input_event_1)); - - // Start the model asynchronously. - bool async; - compiled_model.RunAsync(signature_index, input_buffers, output_buffers, - async); - ASSERT_TRUE(async); - ASSERT_TRUE(output_buffers[0].HasEvent()); - - // Set input values. - std::memcpy(input_0_cpu_addr_and_lock.second, kTestInput0Tensor, - sizeof(kTestInput0Tensor)); - std::memcpy(input_1_cpu_addr_and_lock.second, kTestInput1Tensor, - sizeof(kTestInput1Tensor)); - - // Signal input fences so that the inference can start. - ASSERT_OK(input_fence_0->Signal(/*success=*/true)); - ASSERT_OK(input_fence_1->Signal(/*success=*/true)); - - // Check model output. - float output_buffer_data[kTestOutputSize]; - absl::Span output_span = - absl::MakeSpan(output_buffer_data, kTestOutputSize); - // The next read operation will block on the output buffer's sync fence. - ASSERT_TRUE(output_buffers[0].Read(output_span)); - // Print and confirm the output values are correct. 
- for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc deleted file mode 100644 index b408ae39a027dc..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc +++ /dev/null @@ -1,657 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/c/c_api_opaque.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/runtime/tfl_utils.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -namespace litert { -namespace internal { - -DispatchDelegateKernel::~DispatchDelegateKernel() { - for (size_t i = 0; i < input_tensor_buffer_handles_.size(); ++i) { - (void)LiteRtDispatchDetachInput(invocation_context_, i, - input_tensor_buffer_handles_[i]); - } - - for (size_t i = 0; i < output_tensor_buffer_handles_.size(); ++i) { - (void)LiteRtDispatchDetachOutput(invocation_context_, i, - output_tensor_buffer_handles_[i]); - } - - if (invocation_context_) { - (void)LiteRtDispatchInvocationContextDestroy(invocation_context_); - } - - for (auto& buffer_handle : input_tensor_buffer_handles_) { - (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, buffer_handle); - } - - for (auto& buffer_handle : output_tensor_buffer_handles_) { - (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, 
buffer_handle); - } - - if (device_context_) { - (void)LiteRtDispatchDeviceContextDestroy(device_context_); - } - - input_tensor_buffers_.clear(); - output_tensor_buffers_.clear(); -} - -Expected<DispatchDelegateKernel::Ptr> DispatchDelegateKernel::Create( - std::string&& graph_name, const LiteRtDispatchDelegateOptions& options) { - auto dispatch_options = options.GetDispatchOptions(); - if (auto status = LiteRtDispatchInitialize(dispatch_options.data(), - dispatch_options.size()); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to initialize Dispatch API: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to initialize Dispatch API"); - } - - const char* vendor_id; - if (auto status = LiteRtDispatchGetVendorId(&vendor_id); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API vendor ID: %d", - status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get Dispatch API vendor ID"); - } - LITERT_LOG(LITERT_INFO, "Dispatch API vendor ID: %s", vendor_id); - - const char* build_id; - if (auto status = LiteRtDispatchGetBuildId(&build_id); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API build ID: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get Dispatch API build ID"); - } - LITERT_LOG(LITERT_INFO, "Dispatch API build ID: %s", build_id); - - LiteRtApiVersion api_version; - if (auto status = LiteRtDispatchGetApiVersion(&api_version); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get LiteRT Dispatch API version: %d", - status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get LiteRT Dispatch API version"); - } - LITERT_LOG(LITERT_INFO, "Dispatch API version: %d.%d.%d", api_version.major, - api_version.minor, api_version.patch); - // Check if the versions match.
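- // The vendor runtime must have the same major version and a minor version
- // at least as new as the one this delegate was built against; anything else
- // may not understand the compiled bytecode, so we bail out early.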
- if (api_version.major != LITERT_API_VERSION_MAJOR || - api_version.minor < LITERT_API_VERSION_MINOR) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Found Dispatch API with an unsupported version"); - } - - int capabilities; - if (auto status = LiteRtDispatchGetCapabilities(&capabilities); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API capabilities: %d", - status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get Dispatch API capabilities"); - } - LITERT_LOG(LITERT_INFO, "Dispatch API capabilities: %d", capabilities); - - if (!(capabilities & kLiteRtDispatchCapabilitiesBasic)) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Dispatch API has insufficient capabilities"); - } - - bool async_dispatch = (capabilities & kLiteRtDispatchCapabilitiesAsync); - if (async_dispatch) { - LITERT_LOG(LITERT_INFO, "Found async dispatch capabilities"); - } - - LiteRtDispatchDeviceContext device_context; - if (auto status = LiteRtDispatchDeviceContextCreate(&device_context); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API device context: %d", - status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create Dispatch API device context"); - } - - return Ptr(new DispatchDelegateKernel(options, std::move(graph_name), - device_context, async_dispatch)); -} - -TfLiteStatus DispatchDelegateKernel::Init( - TfLiteOpaqueContext* context, const TfLiteOpaqueDelegateParams* params) { - if (params->nodes_to_replace->size != 1) { - LITERT_LOG(LITERT_ERROR, - "Models with more than one dispatch node are not yet supported"); - return kTfLiteError; - } - - auto node_id = params->nodes_to_replace->data[0]; - TfLiteOpaqueNode* node; - TfLiteOperator* op; - if (auto status = TfLiteOpaqueContextGetNodeAndRegistration(context, node_id, - &node, &op); - status != kTfLiteOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get node and registration: %d", status); - return status; - } - - const void* init_data; - int init_data_size; - if (auto status = TfLiteOpaqueNodeGetCustomInitialData(node, &init_data, - &init_data_size); - status != kTfLiteOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get custom initial data: %d", status); - return status; - } - if (!init_data || !init_data_size) { - LITERT_LOG(LITERT_ERROR, "Found custom op with missing initial data"); - return kTfLiteError; - } - - BufferRef<uint8_t> custom_opts(init_data, init_data_size); - - // Read the bytecode offset and size (relative to alloc_base), as well as - // the function name, from the custom options. - const auto dispatch_opts = GetDispatchOpOptions(custom_opts); - if (dispatch_opts.bytecode_offset == 0) { - LITERT_LOG(LITERT_ERROR, "Found dispatch op with missing bytecode offset"); - return kTfLiteError; - } - - // Find pointer to the start of the loaded model buffer. - const auto alloc_base = FindAllocBase(options_); - if (!alloc_base) { - LITERT_LOG(LITERT_ERROR, - "Could not find required delegate options \"alloc_base\""); - return kTfLiteError; - } - - const auto alloc_fd = FindAllocFd(options_); - - // Get location of bytecode in the model buffer relative to alloc_base. - LiteRtMemBuffer exec_bytecode_buffer = { - /*.fd=*/alloc_fd ?
*alloc_fd : -1, - /*.base_addr=*/*alloc_base, - /*.offset=*/dispatch_opts.bytecode_offset, - /*.size=*/dispatch_opts.bytecode_size}; - const auto& function_name = dispatch_opts.name; - const int num_inputs = params->input_tensors->size; - const int num_outputs = params->output_tensors->size; - - if (auto status = LiteRtDispatchInvocationContextCreate( - device_context_, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, function_name.data(), num_inputs, num_outputs, - &invocation_context_); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to create invocation context: %d", status); - return kTfLiteError; - } - - input_tensor_buffers_require_cpu_sync_.resize(num_inputs); - input_tensor_buffers_.resize(num_inputs); - input_tensor_buffer_handles_.resize(num_inputs); - input_tensor_buffer_used_size_.resize(num_inputs); - - output_tensor_buffers_require_cpu_sync_.resize(num_outputs); - output_tensor_buffers_.resize(num_outputs); - output_tensor_buffer_handles_.resize(num_outputs); - output_tensor_buffer_used_size_.resize(num_outputs); - - void* external_context; - TfLiteOpaqueContextGetExternalContext(context, &external_context, - kTfLiteLiteRtBufferContext); - if (!external_context) { - LITERT_LOG(LITERT_ERROR, "External context not found"); - return kTfLiteError; - } - - buffer_context_ = - reinterpret_cast( - external_context); - - // Register input and output buffer requirements. - size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node); - for (size_t i = 0; i < num_node_inputs; ++i) { - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i); - if (!tfl_opaque_tensor) { - LITERT_LOG(LITERT_ERROR, "Failed to get TFL node input %d", i); - return kTfLiteError; - } - auto tensor_type = ConvertTensorType(tfl_opaque_tensor); - if (!tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", tensor_type.Error().Message().c_str()); - return kTfLiteError; - } - auto input_buffer_requirements = - GetBufferRequirements(*tensor_type, i, /*is_input=*/true); - if (auto res = buffer_context_->RegisterBufferRequirement( - tfl_opaque_tensor, std::move(*input_buffer_requirements)); - res != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register buffer requirement"); - return kTfLiteError; - } - } - - size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node); - for (size_t i = 0; i < num_node_outputs; ++i) { - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i); - if (!tfl_opaque_tensor) { - LITERT_LOG(LITERT_ERROR, "Failed to get TFL node output %d", i); - return kTfLiteError; - } - auto tensor_type = ConvertTensorType(tfl_opaque_tensor); - if (!tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", tensor_type.Error().Message().c_str()); - return kTfLiteError; - } - auto output_buffer_requirements = - GetBufferRequirements(*tensor_type, i, /*is_input=*/false); - if (auto res = buffer_context_->RegisterBufferRequirement( - tfl_opaque_tensor, std::move(*output_buffer_requirements)); - res != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register buffer requirement"); - return kTfLiteError; - } - } - - return kTfLiteOk; -} - -Expected -DispatchDelegateKernel::GetBufferRequirements( - const RankedTensorType& tensor_type, int io_tensor_index, - bool is_input) const { - auto litert_tensor_type = static_cast(tensor_type); - LiteRtTensorBufferRequirements tensor_buffer_requirements; - if (is_input) { - if (auto status = LiteRtDispatchGetInputRequirements( - invocation_context_, /*input_index=*/io_tensor_index, - &litert_tensor_type, 
&tensor_buffer_requirements); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, - "Failed to get tensor buffer requirements for input %d: %d", - io_tensor_index, status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get tensor buffer requirements for input"); - } - - } else { - if (auto status = LiteRtDispatchGetOutputRequirements( - invocation_context_, /*output_index=*/io_tensor_index, - &litert_tensor_type, &tensor_buffer_requirements); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, - "Failed to get tensor buffer requirements for output %d: %d", - io_tensor_index, status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get tensor buffer requirements for output"); - } - } - - return TensorBufferRequirements(tensor_buffer_requirements, - /*owned=*/true); -} - -TfLiteStatus DispatchDelegateKernel::CreateAndSetBuffer( - const TfLiteOpaqueTensor* tfl_opaque_tensor, int buffer_index, - bool is_input) { - auto& cached_tensor_buffer = is_input ? input_tensor_buffers_[buffer_index] - : output_tensor_buffers_[buffer_index]; - - auto tensor_type = ConvertTensorType(tfl_opaque_tensor); - if (!tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", tensor_type.Error().Message().c_str()); - return kTfLiteError; - } - - // Check if we can reuse a cached tensor buffer or we need to create a new - // one. - if (static_cast(cached_tensor_buffer)) { - if (auto cached_tensor_type = cached_tensor_buffer.TensorType(); - !cached_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", - cached_tensor_type.Error().Message().c_str()); - return kTfLiteError; - } - - if (tensor_type->Layout() == cached_tensor_buffer.TensorType()->Layout()) { - // We can reuse the cached tensor buffer. - return kTfLiteOk; - } - - // We cannot reuse the cached tensor buffer; proceed below. - } - - auto tensor_buffer_requirements = - GetBufferRequirements(*tensor_type, buffer_index, is_input); - if (!tensor_buffer_requirements) { - LITERT_LOG(LITERT_ERROR, "%s", - tensor_buffer_requirements.Error().Message().c_str()); - return kTfLiteError; - } - - auto supported_tensor_buffer_types = - tensor_buffer_requirements->SupportedTypes(); - if (!supported_tensor_buffer_types) { - LITERT_LOG(LITERT_ERROR, "%s", - supported_tensor_buffer_types.Error().Message().c_str()); - return kTfLiteError; - } - - if (supported_tensor_buffer_types->empty()) { - LITERT_LOG(LITERT_ERROR, - "Insufficient number of supported tensor buffer types"); - return kTfLiteError; - } - - // For now we simply pick the first buffer type that's supported. 
- LiteRtTensorBufferType tensor_buffer_type = - (*supported_tensor_buffer_types)[0]; - - auto tensor_buffer_size = tensor_buffer_requirements->BufferSize(); - if (!tensor_buffer_size) { - LITERT_LOG(LITERT_ERROR, "%s", - tensor_buffer_size.Error().Message().c_str()); - return kTfLiteError; - } - - auto litert_tensor_type = static_cast(*tensor_type); - LiteRtTensorBuffer litert_tensor_buffer; - if (auto status = LiteRtCreateManagedTensorBuffer( - tensor_buffer_type, &litert_tensor_type, *tensor_buffer_size, - &litert_tensor_buffer); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to create managed tensor buffer: %d", - status); - return kTfLiteError; - } - - return RegisterLiteRtTensorBuffer(TensorBuffer(litert_tensor_buffer), - *tensor_buffer_size, buffer_index, - is_input); -} - -TfLiteStatus DispatchDelegateKernel::RegisterLiteRtTensorBuffer( - TensorBuffer&& tensor_buffer, size_t buffer_used_size, int buffer_index, - bool is_input) { - LiteRtTensorBufferHandle buffer_handle; - if (auto status = LiteRtDispatchRegisterTensorBuffer( - device_context_, tensor_buffer.Get(), &buffer_handle); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffer: %d", status); - return kTfLiteError; - } - - if (is_input) { - if (auto status = LiteRtDispatchAttachInput(invocation_context_, - buffer_index, buffer_handle); - status != kLiteRtStatusOk) { - (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, - buffer_handle); - LITERT_LOG(LITERT_ERROR, "Failed to attach tensor buffer to input %d: %d", - buffer_index, status); - return kTfLiteError; - } - - if (tensor_buffer.HasEvent()) { - auto event = tensor_buffer.GetEvent(); - if (!event) { - LITERT_LOG(LITERT_ERROR, - "Failed to get event from tensor buffer %d: %s", - buffer_index, event.Error().Message().c_str()); - return kTfLiteError; - } - - if (!async_dispatch_) { - // If the Dispatch API runtime doesn't support async execution, then - // wait for the event on the CPU. 
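- // Blocking here trades latency for correctness: the invocation is only
- // dispatched once the producer of this input has signalled its fence.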
- LITERT_LOG(LITERT_WARNING, "Waiting on an input event on the CPU..."); - if (auto status = event->Wait(/*timeout_in_ms=*/-1); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to wait on event: %s", - status.Error().Message().c_str()); - return kTfLiteError; - } - - } else { - if (auto status = LiteRtDispatchAttachInputEvent( - invocation_context_, buffer_index, event->Get()); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to attach event to input %d: %d", - buffer_index, status); - return kTfLiteError; - } - } - } - - } else { - if (auto status = LiteRtDispatchAttachOutput(invocation_context_, - buffer_index, buffer_handle); - status != kLiteRtStatusOk) { - (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, - buffer_handle); - LITERT_LOG(LITERT_ERROR, - "Failed to attach tensor buffer to output %d: %d", - buffer_index, status); - return kTfLiteError; - } - } - - if (is_input) { - input_tensor_buffers_[buffer_index] = std::move(tensor_buffer); - input_tensor_buffer_handles_[buffer_index] = buffer_handle; - input_tensor_buffer_used_size_[buffer_index] = buffer_used_size; - } else { - output_tensor_buffers_[buffer_index] = std::move(tensor_buffer); - output_tensor_buffer_handles_[buffer_index] = buffer_handle; - output_tensor_buffer_used_size_[buffer_index] = buffer_used_size; - } - return kTfLiteOk; -} - -TfLiteStatus DispatchDelegateKernel::Prepare(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node) { - return kTfLiteOk; -} - -TfLiteStatus DispatchDelegateKernel::RegisterLiteRtTensorBuffers( - TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { - size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node); - for (size_t i = 0; i < num_node_inputs; ++i) { - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i); - auto tensor_buffer = buffer_context_->GetTensorBuffer(tfl_opaque_tensor); - if (tensor_buffer.HasValue()) { - // TODO - b/379176766: If the provided TensorBuffer is not supported - // types, we need to create a new one and convert the data from the - // provided TensorBuffer. - auto buffer_size = tensor_buffer->Size(); - if (!buffer_size) { - LITERT_LOG(LITERT_ERROR, "%s", buffer_size.Error().Message().c_str()); - return kTfLiteError; - } - if (auto status = RegisterLiteRtTensorBuffer(std::move(*tensor_buffer), - *buffer_size, i, - /*is_input=*/true); - status != kTfLiteOk) { - return status; - } - input_tensor_buffers_require_cpu_sync_[i] = false; - } else { - LITERT_LOG(LITERT_VERBOSE, - "Input#%d TensorBuffer is not registered. Create a new one", - i); - if (auto status = - CreateAndSetBuffer(tfl_opaque_tensor, i, /*is_input=*/true); - status != kTfLiteOk) { - return status; - } - input_tensor_buffers_require_cpu_sync_[i] = true; - } - } - - size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node); - for (size_t i = 0; i < num_node_outputs; ++i) { - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i); - auto tensor_buffer = buffer_context_->GetTensorBuffer(tfl_opaque_tensor); - if (tensor_buffer.HasValue()) { - // TODO - b/379176766: If the provided TensorBuffer is not supported - // types, we need to create a new one and convert the data back to the - // provided TensorBuffer. 
- auto buffer_size = tensor_buffer->Size(); - if (!buffer_size) { - LITERT_LOG(LITERT_ERROR, "%s", buffer_size.Error().Message().c_str()); - return kTfLiteError; - } - if (auto status = RegisterLiteRtTensorBuffer(std::move(*tensor_buffer), - *buffer_size, i, - /*is_input=*/false); - status != kTfLiteOk) { - return status; - } - output_tensor_buffers_require_cpu_sync_[i] = false; - } else { - LITERT_LOG(LITERT_VERBOSE, - "Output#%d TensorBuffer is not registered. Create a new one", - i); - if (auto status = - CreateAndSetBuffer(tfl_opaque_tensor, i, /*is_input=*/false); - status != kTfLiteOk) { - return status; - } - output_tensor_buffers_require_cpu_sync_[i] = true; - } - } - - return kTfLiteOk; -} - -TfLiteStatus DispatchDelegateKernel::Eval(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node) { - if (auto status = RegisterLiteRtTensorBuffers(context, node); - status != kTfLiteOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffers: %d", status); - return kTfLiteError; - } - - size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node); - if (num_node_inputs != input_tensor_buffers_.size()) { - LITERT_LOG(LITERT_ERROR, "Invalid number of inputs"); - return kTfLiteError; - } - - for (size_t i = 0; i < num_node_inputs; ++i) { - if (!input_tensor_buffers_require_cpu_sync_[i]) { - continue; - } - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i); - void* tensor_data = TfLiteOpaqueTensorData(tfl_opaque_tensor); - auto& tensor_buffer = input_tensor_buffers_[i]; - - auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer); - if (!lock_and_addr) { - LITERT_LOG(LITERT_ERROR, "%s", lock_and_addr.Error().Message().c_str()); - return kTfLiteError; - } - - size_t buffer_size = input_tensor_buffer_used_size_[i]; - std::memcpy(lock_and_addr->second, tensor_data, buffer_size); - } - - size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node); - if (num_node_outputs != output_tensor_buffers_.size()) { - LITERT_LOG(LITERT_ERROR, "Invalid number of outputs"); - return kTfLiteError; - } - - if (async_dispatch_ && buffer_context_->IsAsyncExecutionMode()) { - std::vector output_events(num_node_outputs); - if (auto status = LiteRtDispatchInvokeAsync( - invocation_context_, output_events.size(), output_events.data()); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to invoke context asynchronously: %d", - status); - return kTfLiteError; - } - for (size_t i = 0; i < output_events.size(); ++i) { - auto output_event = output_events[i]; - if (output_event) { - auto& tensor_buffer = output_tensor_buffers_[i]; - if (auto status = tensor_buffer.SetEvent(Event(output_event)); - !status) { - LITERT_LOG(LITERT_ERROR, - "Failed to set event on output tensor buffer: %s", - status.Error().Message().c_str()); - return kTfLiteError; - } - } - } - - } else { - if (auto status = LiteRtDispatchInvoke(invocation_context_); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to invoke context: %d", status); - return kTfLiteError; - } - } - - for (size_t i = 0; i < num_node_outputs; ++i) { - if (!output_tensor_buffers_require_cpu_sync_[i]) { - continue; - } - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i); - void* tensor_data = TfLiteOpaqueTensorData(tfl_opaque_tensor); - auto& tensor_buffer = output_tensor_buffers_[i]; - - auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer); - if (!lock_and_addr) { - LITERT_LOG(LITERT_ERROR, "%s", lock_and_addr.Error().Message().c_str()); - return kTfLiteError; - } - - size_t 
buffer_size = output_tensor_buffer_used_size_[i]; - std::memcpy(tensor_data, lock_and_addr->second, buffer_size); - } - - return kTfLiteOk; -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h deleted file mode 100644 index c53cc09e9d780c..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -namespace litert::internal { - -class ExternalLiteRtBufferContext; - -// A TFL kernel that the interpreter calls to dispatch execution through the -// Dispatch API. -class DispatchDelegateKernel - : public tflite::SimpleOpaqueDelegateKernelInterface { - public: - using Ptr = std::unique_ptr; - - ~DispatchDelegateKernel() override; - - static Expected Create(std::string&& graph_name, - const LiteRtDispatchDelegateOptions& options); - - TfLiteStatus Init(TfLiteOpaqueContext* context, - const TfLiteOpaqueDelegateParams* params) override; - - TfLiteStatus Prepare(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node) override; - - TfLiteStatus Eval(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node) override; - - private: - DispatchDelegateKernel(const LiteRtDispatchDelegateOptions& options, - std::string&& graph_name, - LiteRtDispatchDeviceContext device_context, - bool async_dispatch) - : options_(options), - graph_name_(std::move(graph_name)), - device_context_(device_context), - async_dispatch_(async_dispatch) {} - - Expected GetBufferRequirements( - const RankedTensorType& tensor_type, int io_tensor_index, - bool is_input) const; - - // Creates a new tensor buffer for the given tensor. After that the created - // tensor buffer is registered with RegisterLiteRtTensorBuffer(). - TfLiteStatus CreateAndSetBuffer(const TfLiteOpaqueTensor* tfl_opaque_tensor, - int buffer_index, bool is_input); - - // Registers the given LiteRtTensorBuffer (and its size) with the Dispatch - // API. 
- // Also update the internal state (input_tensor_buffers_, etc.) to keep track - // of the registered tensor buffers. - TfLiteStatus RegisterLiteRtTensorBuffer(TensorBuffer&& tensor_buffer, - size_t used_size, int buffer_index, - bool is_input); - - // Registers LiteRtTensorBuffers for all inputs and outputs of the given - // node. - // Also update the internal state (input_tensor_buffers_, etc.) to keep track - // of the registered tensor buffers. - TfLiteStatus RegisterLiteRtTensorBuffers(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node); - - const LiteRtDispatchDelegateOptions& options_; - std::string graph_name_; - LiteRtDispatchDeviceContext device_context_; - LiteRtDispatchInvocationContext invocation_context_ = nullptr; - // Indicates whether the Dispatch API can be invoked asynchronously. - const bool async_dispatch_; - - ExternalLiteRtBufferContext* buffer_context_ = nullptr; - - // Indicates whether the input tensor buffer requires a CPU sync before - // invoking the Dispatch API. - std::vector input_tensor_buffers_require_cpu_sync_; - - std::vector input_tensor_buffers_; - std::vector input_tensor_buffer_handles_; - std::vector input_tensor_buffer_used_size_; - - // Indicates whether the output tensor buffer requires a CPU sync after - // invoking the Dispatch API. - std::vector output_tensor_buffers_require_cpu_sync_; - - std::vector output_tensor_buffers_; - std::vector output_tensor_buffer_handles_; - std::vector output_tensor_buffer_used_size_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc deleted file mode 100644 index aa26b3c211c9ad..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc +++ /dev/null @@ -1,406 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
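The deleted Eval() above keeps CPU-visible tensor data and LiteRT tensor buffers in sync by locking each buffer and copying through the mapped address. A minimal sketch of that pattern follows (editorial, not part of the patch); it assumes the litert C++ API used here, where TensorBufferScopedLock::Create() yields a lock/address pair, and the helper name is hypothetical:

  // Hypothetical helper: copy host data into a LiteRT TensorBuffer.
  // Assumes #include <cstring> plus the litert C++ tensor-buffer headers
  // included by the deleted kernel above.
  litert::Expected<void> CopyIntoTensorBuffer(litert::TensorBuffer& buffer,
                                              const void* host_data,
                                              size_t size) {
    // Create() maps the buffer for CPU access; the mapping is released when
    // the returned lock goes out of scope.
    auto lock_and_addr = litert::TensorBufferScopedLock::Create(buffer);
    if (!lock_and_addr) {
      return litert::Unexpected(lock_and_addr.Error());
    }
    std::memcpy(lock_and_addr->second, host_data, size);  // second: void* addr
    return {};
  }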
- -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/signature_runner.h" - -namespace litert { -namespace { - -using ::litert::testing::MakeRuntimeFromTestFileWithNpuModel; -using ::testing::FloatNear; -using ::testing::Pointwise; - -static constexpr absl::string_view kNpuFile = kMediaTekModelFileName; -static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -static constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -TEST(DispatchDelegate, MediaTekCpuBuffer) { - auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); - ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; - auto& rt = **runtime; - auto& interpreter = rt.Interpreter(); - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - - litert::internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options; - LiteRtGetEnvironmentOptions(env->Get(), &env_options); - - auto dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - rt.Flatbuffer().Buf().Data()); - auto dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Get the list of signatures and check it. 
- auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto input_0_tensor = runner->input_tensor("arg0"); - ASSERT_NE(input_0_tensor, nullptr); - auto* input_0 = input_0_tensor->data.f; - std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto input_1_tensor = runner->input_tensor("arg1"); - ASSERT_NE(input_1_tensor, nullptr); - auto* input_1 = input_1_tensor->data.f; - std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. - ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto output_tensor = runner->output_tensor("tfl.custom"); - ASSERT_NE(output_tensor, nullptr); - auto output = absl::MakeSpan(output_tensor->data.f, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(::testing::FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, MediaTekHwBuffer) { - auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); - ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; - auto& rt = **runtime; - auto& interpreter = rt.Interpreter(); - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - - litert::internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options; - LiteRtGetEnvironmentOptions(env->Get(), &env_options); - - auto dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - rt.Flatbuffer().Buf().Data()); - auto dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Create and register tensor buffers for all inputs and outputs. 
- - std::vector input_buffers; - for (int i = 0; i < interpreter.inputs().size(); ++i) { - auto input_buffer_requirements = - buffer_context.GetBufferRequirement(interpreter.input_tensor(i)); - ASSERT_TRUE(input_buffer_requirements); - ASSERT_EQ((*input_buffer_requirements)->SupportedTypes().Value()[0], - kLiteRtTensorBufferTypeAhwb); - auto input_buffer = - buffer_context.CreateBufferForTensor(interpreter.input_tensor(i)); - ASSERT_TRUE(input_buffer); - ASSERT_TRUE(input_buffer->IsOwned()); - ASSERT_EQ(*input_buffer->BufferType(), kLiteRtTensorBufferTypeAhwb); - auto duplicate_buffer = (*input_buffer).Duplicate(); - ASSERT_TRUE(duplicate_buffer); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.input_tensor(i), std::move(*duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - input_buffers.push_back(std::move(*input_buffer)); - } - - std::vector output_buffers; - for (int i = 0; i < interpreter.outputs().size(); ++i) { - auto output_buffer_requirements = - buffer_context.GetBufferRequirement(interpreter.output_tensor(i)); - ASSERT_TRUE(output_buffer_requirements); - ASSERT_EQ((*output_buffer_requirements)->SupportedTypes().Value()[0], - kLiteRtTensorBufferTypeAhwb); - auto output_buffer = - buffer_context.CreateBufferForTensor(interpreter.output_tensor(i)); - ASSERT_TRUE(output_buffer.HasValue()); - ASSERT_TRUE(output_buffer->IsOwned()); - ASSERT_EQ(*output_buffer->BufferType(), kLiteRtTensorBufferTypeAhwb); - auto duplicate_buffer = (*output_buffer).Duplicate(); - ASSERT_TRUE(duplicate_buffer); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.output_tensor(i), std::move(*duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - output_buffers.push_back(std::move(*output_buffer)); - } - - // Get the list of signatures and check it. - auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto& input_0_buffer = input_buffers[0]; - input_0_buffer.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto& input_1_buffer = input_buffers[1]; - input_1_buffer.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. 
- ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto& output_buffer = output_buffers[0]; - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - auto read_success = output_buffer.Read(output_span); - ASSERT_TRUE(read_success); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModel) { - auto model_with_byte_code = - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - auto signatures = model->GetSignatures(); - ASSERT_TRUE(signatures); - EXPECT_EQ(signatures->size(), 1); - auto& signature = signatures->at(0); - auto signature_key = signature.Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - auto input_buffers_res = compiled_model.CreateInputBuffers(signature_index); - EXPECT_TRUE(input_buffers_res); - auto& input_buffers = *input_buffers_res; - - auto output_buffers_res = compiled_model.CreateOutputBuffers(signature_index); - EXPECT_TRUE(output_buffers_res); - auto& output_buffers = *output_buffers_res; - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 1); - EXPECT_EQ(output_names.at(0), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModelSharedInput) { - auto model_with_byte_code = internal::GetModelBufWithByteCode( - testing::GetTestFilePath("shared_input_cpu_npu.tflite"), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - size_t signature_index = 0; - auto signature = *model->GetSignature(signature_index); - auto input_buffers = *compiled_model.CreateInputBuffers(signature_index); - auto output_buffers = *compiled_model.CreateOutputBuffers(signature_index); - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 2); - { - EXPECT_EQ(output_names.at(0), "tfl.add"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } - { - EXPECT_EQ(output_names.at(1), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[1].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h deleted file mode 100644 index c4847fdaac0e86..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_ - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment_options.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -class LiteRtDispatchDelegateOptions { - public: - explicit LiteRtDispatchDelegateOptions( - const LiteRtEnvironmentOptionsT* environment_options) { - if (!environment_options) { - return; - } - auto option = - environment_options->GetOption(kLiteRtEnvOptionTagDispatchLibraryDir); - if (!option.HasValue()) { - return; - } - - if (option->type != kLiteRtAnyTypeString) { - LITERT_LOG(LITERT_WARNING, - "Ignoring option kLiteRtEnvOptionTagDispatchLibraryDir due " - "to invalid value"); - return; - } - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*option, - }; - AddOption(dispatch_option); - } - - // Push a new dispatch option. 
- void AddOption(LiteRtDispatchOption option) { options_.push_back(option); } - - // Get all dispatch options. - const std::vector& GetDispatchOptions() const { - return options_; - } - - // Find a dispatch option under the given name if it exists. - litert::Expected FindDispatchOption(absl::string_view name) const { - for (const auto& option : options_) { - if (option.name != name) { - continue; - } - return litert::ToStdAny(option.value); - } - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - - private: - std::vector options_; -}; - -// -// Common options -// - -static constexpr absl::string_view kAllocBase = "alloc_base"; -static constexpr absl::string_view kAllocFd = "alloc_fd"; - -inline void AddAllocBaseOption(const void* alloc_base, - LiteRtDispatchDelegateOptions& opts) { - LiteRtAny opt; - opt.type = kLiteRtAnyTypeVoidPtr; - opt.ptr_value = alloc_base; - opts.AddOption(LiteRtDispatchOption{kAllocBase.data(), opt}); -} - -inline litert::Expected FindAllocBase( - const LiteRtDispatchDelegateOptions& opts) { - auto alloc_base = opts.FindDispatchOption(kAllocBase); - if (!alloc_base) { - return alloc_base.Error(); - } - return std::any_cast(*alloc_base); -} - -inline void AddAllocFdOption(int alloc_fd, - LiteRtDispatchDelegateOptions& opts) { - LiteRtAny opt; - opt.type = kLiteRtAnyTypeInt; - opt.int_value = alloc_fd; - opts.AddOption(LiteRtDispatchOption{kAllocFd.data(), opt}); -} - -inline litert::Expected FindAllocFd( - const LiteRtDispatchDelegateOptions& opts) { - auto alloc_fd = opts.FindDispatchOption(kAllocFd); - if (!alloc_fd) { - return alloc_fd.Error(); - } - return std::any_cast(*alloc_fd); -} - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc deleted file mode 100644 index 2b18e48f63b5c0..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc +++ /dev/null @@ -1,406 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
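For reference, a short usage sketch of the option helpers above (editorial, not part of the original file); it assumes FindAllocBase yields the const void* stored by AddAllocBaseOption, and kFakeModel is a stand-in for a real flatbuffer allocation base:

  // Build delegate options without environment options, then round-trip the
  // alloc-base pointer through the option list.
  static const char kFakeModel[16] = {};  // stand-in allocation base
  LiteRtDispatchDelegateOptions opts(/*environment_options=*/nullptr);
  AddAllocBaseOption(kFakeModel, opts);
  if (auto alloc_base = FindAllocBase(opts)) {
    // *alloc_base == kFakeModel; the dispatch backend resolves bytecode
    // offsets relative to this base address.
  }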
- -#include -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/signature_runner.h" - -namespace litert { -namespace { - -using ::litert::testing::MakeRuntimeFromTestFileWithNpuModel; -using ::testing::FloatNear; -using ::testing::Pointwise; - -static constexpr absl::string_view kNpuFile = kQualcommModelFileName; -static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -static constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -TEST(DispatchDelegate, QualcommCpuBuffer) { - auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); - ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; - auto& rt = **runtime; - auto& interpreter = rt.Interpreter(); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - - litert::internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options; - LiteRtGetEnvironmentOptions(env->Get(), &env_options); - - auto dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - rt.Flatbuffer().Buf().Data()); - auto dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Get the list of signatures and check it. 
- auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto input_0_tensor = runner->input_tensor("arg0"); - ASSERT_NE(input_0_tensor, nullptr); - auto* input_0 = input_0_tensor->data.f; - std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto input_1_tensor = runner->input_tensor("arg1"); - ASSERT_NE(input_1_tensor, nullptr); - auto* input_1 = input_1_tensor->data.f; - std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. - ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto output_tensor = runner->output_tensor("tfl.custom"); - ASSERT_NE(output_tensor, nullptr); - auto output = absl::MakeSpan(output_tensor->data.f, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(::testing::FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, QualcommHwBuffer) { - auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); - ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; - auto& rt = **runtime; - auto& interpreter = rt.Interpreter(); - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - - litert::internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options; - LiteRtGetEnvironmentOptions(env->Get(), &env_options); - - auto dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - rt.Flatbuffer().Buf().Data()); - auto dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Create and register tensor buffers for all inputs and outputs. 
- - std::vector input_buffers; - for (int i = 0; i < interpreter.inputs().size(); ++i) { - auto input_buffer_requirements = - buffer_context.GetBufferRequirement(interpreter.input_tensor(i)); - ASSERT_TRUE(input_buffer_requirements); - ASSERT_EQ((*input_buffer_requirements)->SupportedTypes().Value()[0], - kLiteRtTensorBufferTypeFastRpc); - auto input_buffer = - buffer_context.CreateBufferForTensor(interpreter.input_tensor(i)); - ASSERT_TRUE(input_buffer); - ASSERT_TRUE(input_buffer->IsOwned()); - ASSERT_EQ(*input_buffer->BufferType(), kLiteRtTensorBufferTypeFastRpc); - auto duplicate_buffer = (*input_buffer).Duplicate(); - ASSERT_TRUE(duplicate_buffer); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.input_tensor(i), std::move(*duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - input_buffers.push_back(std::move(*input_buffer)); - } - - std::vector output_buffers; - for (int i = 0; i < interpreter.outputs().size(); ++i) { - auto output_buffer_requirements = - buffer_context.GetBufferRequirement(interpreter.output_tensor(i)); - ASSERT_TRUE(output_buffer_requirements); - ASSERT_EQ((*output_buffer_requirements)->SupportedTypes().Value()[0], - kLiteRtTensorBufferTypeFastRpc); - auto output_buffer = - buffer_context.CreateBufferForTensor(interpreter.output_tensor(i)); - ASSERT_TRUE(output_buffer.HasValue()); - ASSERT_TRUE(output_buffer->IsOwned()); - ASSERT_EQ(*output_buffer->BufferType(), kLiteRtTensorBufferTypeFastRpc); - auto duplicate_buffer = (*output_buffer).Duplicate(); - ASSERT_TRUE(duplicate_buffer); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.output_tensor(i), std::move(*duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - output_buffers.push_back(std::move(*output_buffer)); - } - - // Get the list of signatures and check it. - auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto& input_0_buffer = input_buffers[0]; - input_0_buffer.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto& input_1_buffer = input_buffers[1]; - input_1_buffer.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. 
- ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto& output_buffer = output_buffers[0]; - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - auto read_success = output_buffer.Read(output_span); - ASSERT_TRUE(read_success); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModel) { - auto model_with_byte_code = - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - auto signatures = model->GetSignatures(); - ASSERT_TRUE(signatures); - EXPECT_EQ(signatures->size(), 1); - auto& signature = signatures->at(0); - auto signature_key = signature.Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - auto input_buffers_res = compiled_model.CreateInputBuffers(signature_index); - EXPECT_TRUE(input_buffers_res); - auto& input_buffers = *input_buffers_res; - - auto output_buffers_res = compiled_model.CreateOutputBuffers(signature_index); - EXPECT_TRUE(output_buffers_res); - auto& output_buffers = *output_buffers_res; - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 1); - EXPECT_EQ(output_names.at(0), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, QualcommSharedInput) { - auto model_with_byte_code = internal::GetModelBufWithByteCode( - testing::GetTestFilePath("shared_input_cpu_npu.tflite"), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - size_t signature_index = 0; - auto signature = *model->GetSignature(signature_index); - auto input_buffers = *compiled_model.CreateInputBuffers(signature_index); - auto output_buffers = *compiled_model.CreateOutputBuffers(signature_index); - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 2); - { - EXPECT_EQ(output_names.at(0), "tfl.add"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } - { - EXPECT_EQ(output_names.at(1), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[1].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/litert_dispatch.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/litert_dispatch.cc deleted file mode 100644 index f725941832d5be..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/litert_dispatch.cc +++ /dev/null @@ -1,571 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -#include - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" -#include "tensorflow/lite/experimental/litert/core/version.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" - -#define INVOKE_FUNC(function, ...) \ - if (!TheApi.interface) { \ - LITERT_LOG(LITERT_ERROR, "Dispatch API interface not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - if (!TheApi.interface->function) { \ - LITERT_LOG(LITERT_ERROR, #function " not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - return TheApi.interface->function(__VA_ARGS__); - -#define INVOKE_ASYNC_FUNC(function, ...) 
\ - if (!TheApi.async_interface) { \ - LITERT_LOG(LITERT_ERROR, "Dispatch API async interface not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - if (!TheApi.async_interface->function) { \ - LITERT_LOG(LITERT_ERROR, #function " not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - return TheApi.async_interface->function(__VA_ARGS__); - -#define INVOKE_GRAPH_FUNC(function, ...) \ - if (!TheApi.graph_interface) { \ - LITERT_LOG(LITERT_ERROR, "Dispatch API graph interface not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - if (!TheApi.graph_interface->function) { \ - LITERT_LOG(LITERT_ERROR, #function " not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - return TheApi.graph_interface->function(__VA_ARGS__); - -namespace { - -litert::SharedLibrary* DispatchSharedLibrary = nullptr; -bool IsTheApiInitialized = false; -LiteRtDispatchApi TheApi = { - /*.version=*/{/*.major=*/0, /*.minor=*/0, /*.patch=*/0}, - /*.interface=*/nullptr, - /*.async_interface=*/nullptr, - /*.graph_interface=*/nullptr, -}; - -LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) { - INVOKE_FUNC(initialize, options, num_options); } - -litert::Expected GetSharedLibraryPath( - const LiteRtDispatchOption* options, int num_options) { - std::vector dispatch_lib_paths; - for (auto i = 0; i < num_options; ++i) { - auto& option = options[i]; - if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { - litert::internal::FindLiteRtDispatchSharedLibs(option.value.str_value, - dispatch_lib_paths); - } - } - if (dispatch_lib_paths.empty()) { - LITERT_LOG(LITERT_ERROR, "No dispatch library found"); - return litert::Error(kLiteRtStatusErrorRuntimeFailure); - } - if (dispatch_lib_paths.size() > 1) { - LITERT_LOG(LITERT_WARNING, "Multiple dispatch libraries found"); - } - return dispatch_lib_paths[0]; -} -} // namespace - -// ///////////////////////////////////////////////////////////////////////////// -// Basic Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus LiteRtDispatchInitialize(const LiteRtDispatchOption* options, - int num_options) { - if (IsTheApiInitialized) { - return kLiteRtStatusOk; - } - - // TODO(piyu): support Android systems where libraries are not unpacked in the - // system directory.
- LITERT_ASSIGN_OR_RETURN(auto shared_lib_path, - GetSharedLibraryPath(options, num_options)); - - LITERT_LOG(LITERT_INFO, "Loading shared library: %s", - shared_lib_path.c_str()); - - if (!DispatchSharedLibrary) { - DispatchSharedLibrary = new litert::SharedLibrary(); - } - - LITERT_ASSIGN_OR_RETURN( - *DispatchSharedLibrary, - litert::SharedLibrary::Load(shared_lib_path, - litert::RtldFlags::Now().Local())); - - using LiteRtDispatchGetApi_t = LiteRtStatus (*)(LiteRtDispatchApi*); - LITERT_ASSIGN_OR_RETURN( - auto LiteRtDispatchGetApi, - DispatchSharedLibrary->LookupSymbol( - "LiteRtDispatchGetApi")); - - if (auto status = LiteRtDispatchGetApi(&TheApi); status != kLiteRtStatusOk) { - return status; - } - - if (!litert::internal::IsSameVersionAsRuntime(TheApi.version)) { - LITERT_LOG(LITERT_ERROR, "Unsupported dispatch runtime version"); - return kLiteRtStatusErrorWrongVersion; - } - - auto status = Initialize(options, num_options); - if (status == kLiteRtStatusOk) { - IsTheApiInitialized = true; - } - return status; -} - -LiteRtStatus LiteRtDispatchGetApiVersion(LiteRtApiVersion* api_version) { - if (!api_version) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - *api_version = TheApi.version; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDispatchGetVendorId(const char** vendor_id) { - if (!vendor_id) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_vendor_id, vendor_id); -} - -LiteRtStatus LiteRtDispatchGetBuildId(const char** build_id) { - if (!build_id) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_build_id, build_id); -} - -LiteRtStatus LiteRtDispatchGetCapabilities(int* capabilities) { - if (!capabilities) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_capabilities, capabilities); -} - -LiteRtStatus LiteRtDispatchDeviceContextCreate( - LiteRtDispatchDeviceContext* device_context) { - if (!device_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(device_context_create, device_context); -} - -LiteRtStatus LiteRtDispatchDeviceContextDestroy( - LiteRtDispatchDeviceContext device_context) { - if (!device_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(device_context_destroy, device_context); -} - -LiteRtStatus LiteRtDispatchGetInputRequirements( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (!invocation_context || !tensor_type || !tensor_buffer_requirements) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_input_requirements, invocation_context, input_index, - tensor_type, tensor_buffer_requirements); -} - -LiteRtStatus LiteRtDispatchGetOutputRequirements( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (!invocation_context || !tensor_type || !tensor_buffer_requirements) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_output_requirements, invocation_context, output_index, - tensor_type, tensor_buffer_requirements); -} - -LiteRtStatus 
LiteRtDispatchRegisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle) { - if (!device_context || !tensor_buffer || !tensor_buffer_handle) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(register_tensor_buffer, device_context, tensor_buffer, - tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchUnregisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!device_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(unregister_tensor_buffer, device_context, tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchInvocationContextCreate( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context) { - if (!device_context || !exec_bytecode_buffer || !invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(invocation_context_create, device_context, exec_type, - exec_bytecode_buffer, function_name, num_inputs, num_outputs, - invocation_context); -} - -LiteRtStatus LiteRtDispatchInvocationContextDestroy( - LiteRtDispatchInvocationContext invocation_context) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(invocation_context_destroy, invocation_context); -} - -LiteRtStatus LiteRtDispatchAttachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(attach_input, invocation_context, graph_input_index, - tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchAttachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - if (!TheApi.interface) { - LITERT_LOG(LITERT_ERROR, "Dispatch API interface not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - if (!TheApi.interface->attach_output) { - LITERT_LOG(LITERT_ERROR, "attach_output not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - INVOKE_FUNC(attach_output, invocation_context, graph_output_index, - tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchDetachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(detach_input, invocation_context, graph_input_index, - tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchDetachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(detach_output, invocation_context, graph_output_index, - tensor_buffer_handle); -} -
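The registration, attach, and detach entry points above combine with LiteRtDispatchInvoke (defined just below) into the synchronous execution path. An illustrative sketch of a typical call sequence, assuming a device context, an invocation context, and a LiteRtTensorBuffer created earlier through the corresponding *Create entry points (error handling shown once for brevity):

  // Register a buffer once, attach it for a run, invoke, then detach and
  // unregister. The handle stays valid until explicitly unregistered.
  LiteRtTensorBufferHandle input_handle;
  if (LiteRtDispatchRegisterTensorBuffer(device_context, input_buffer,
                                         &input_handle) != kLiteRtStatusOk) {
    return;  // registration failed
  }
  LiteRtDispatchAttachInput(invocation_context, /*graph_input_index=*/0,
                            input_handle);
  LiteRtDispatchInvoke(invocation_context);
  LiteRtDispatchDetachInput(invocation_context, /*graph_input_index=*/0,
                            input_handle);
  LiteRtDispatchUnregisterTensorBuffer(device_context, input_handle);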
-LiteRtStatus LiteRtDispatchInvoke( - LiteRtDispatchInvocationContext invocation_context) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(invoke, invocation_context); -} - -LiteRtStatus LiteRtDispatchStartMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, int detail_level) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } else if (detail_level < 0) { - LITERT_LOG(LITERT_ERROR, "Invalid detail level"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(start_metrics_collection, invocation_context, detail_level); -} - -LiteRtStatus LiteRtDispatchStopMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, - LiteRtDispatchMetrics* metrics) { - if (!invocation_context || !metrics) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(stop_metrics_collection, invocation_context, metrics); -} - -LiteRtStatus LiteRtDispatchGetNumMetrics(LiteRtDispatchMetrics metrics, - int* num_metrics) { - if (!metrics || !num_metrics) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_num_metrics, metrics, num_metrics); -} - -LiteRtStatus LiteRtDispatchGetMetric(LiteRtDispatchMetrics metrics, - int metric_index, LiteRtMetric* metric) { - if (!metrics || !metric) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_metric, metrics, metric_index, metric); -} - -LiteRtStatus LiteRtDispatchDestroyMetrics(LiteRtDispatchMetrics metrics) { - if (!metrics) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(destroy_metrics, metrics); -} - -// ///////////////////////////////////////////////////////////////////////////// -// Async Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus LiteRtDispatchAttachInputEvent( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtEvent input_event) { - if (!invocation_context || !input_event) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_ASYNC_FUNC(attach_input_event, invocation_context, graph_input_index, - input_event); -} - -LiteRtStatus LiteRtDispatchInvokeAsync( - LiteRtDispatchInvocationContext invocation_context, int num_output_events, - LiteRtEvent* output_events) { - if (!invocation_context || !output_events) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_ASYNC_FUNC(invoke_async, invocation_context, num_output_events, - output_events); -} - -// ///////////////////////////////////////////////////////////////////////////// -// Graph Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus LiteRtDispatchGraphCreate( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph) { - if (!device_context || !graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(graph_create, device_context, graph); -} - -LiteRtStatus LiteRtDispatchGraphDestroy(LiteRtDispatchGraph graph) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(graph_destroy, graph); -} - -LiteRtStatus 
LiteRtDispatchAddNode(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(add_node, graph, node_id, node_type); -} - -LiteRtStatus LiteRtDispatchAddEdge(LiteRtDispatchGraph graph, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(add_edge, graph, edge_id); -} - -LiteRtStatus LiteRtDispatchConnectNodeInput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - int input_index, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(connect_node_input, graph, node_id, input_index, edge_id); -} - -LiteRtStatus LiteRtDispatchConnectNodeOutput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - int output_index, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(connect_node_output, graph, node_id, output_index, edge_id); -} - -LiteRtStatus LiteRtDispatchConnectGraphInput(LiteRtDispatchGraph graph, - int input_index, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(connect_graph_input, graph, input_index, edge_id); -} - -LiteRtStatus LiteRtDispatchConnectGraphOutput(LiteRtDispatchGraph graph, - int output_index, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(connect_graph_output, graph, output_index, edge_id); -} - -LiteRtStatus LiteRtDispatchLoadExecutable( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType type, const LiteRtMemBuffer* bytecode_buffer, - LiteRtDispatchExecutableHandle* exec_handle) { - if (!device_context || !bytecode_buffer || !exec_handle) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - if (!TheApi.graph_interface) { - LITERT_LOG(LITERT_ERROR, "Dispatch API graph interface not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - if (!TheApi.graph_interface->load_executable) { - LITERT_LOG(LITERT_ERROR, "load_executable not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - INVOKE_GRAPH_FUNC(load_executable, device_context, type, bytecode_buffer, - exec_handle); -} - -LiteRtStatus LiteRtDispatchUnloadExecutable( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableHandle exec_handle) { - if (!device_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(unload_executable, device_context, exec_handle); -} - -LiteRtStatus LiteRtDispatchAssignNodeFunction( - LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, - LiteRtDispatchExecutableHandle exec_handle, const char* function_name) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(assign_node_function, graph, node_id, exec_handle, - function_name); -} - -LiteRtStatus LiteRtDispatchAnnotateGraph(LiteRtDispatchGraph graph, - const char* key, const char* value) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } 
- INVOKE_GRAPH_FUNC(annotate_graph, graph, key, value); -} - -LiteRtStatus LiteRtDispatchAnnotateNode(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - const char* key, const char* value) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(annotate_node, graph, node_id, key, value); -} - -LiteRtStatus LiteRtDispatchAnnotateEdge(LiteRtDispatchGraph graph, - LiteRtDispatchEdgeId edge_id, - const char* key, const char* value) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(annotate_edge, graph, edge_id, key, value); -} - -LiteRtStatus LiteRtDispatchInvocationContextCreateFromGraph( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, - LiteRtDispatchInvocationContext* invocation_context) { - if (!device_context || !graph || !invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(invocation_context_create_from_graph, device_context, graph, - invocation_context); -} diff --git a/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.cc b/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.cc deleted file mode 100644 index 450ed56f8dc2aa..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.cc +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
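The graph interface wrapped above is driven in roughly the following order. An editorial sketch, assuming device_context, exec_type, bytecode_buffer, node_type, and invocation_context are set up elsewhere; the node/edge ids and the "main" entry point are placeholders, not values taken from this patch:

  // Build a one-node graph, wire graph I/O through edges, bind the node to a
  // function in the loaded executable, then create an invocation context.
  LiteRtDispatchGraph graph;
  LiteRtDispatchGraphCreate(device_context, &graph);
  LiteRtDispatchAddNode(graph, /*node_id=*/0, node_type);
  LiteRtDispatchAddEdge(graph, /*edge_id=*/0);
  LiteRtDispatchAddEdge(graph, /*edge_id=*/1);
  LiteRtDispatchConnectGraphInput(graph, /*input_index=*/0, /*edge_id=*/0);
  LiteRtDispatchConnectNodeInput(graph, /*node_id=*/0, /*input_index=*/0,
                                 /*edge_id=*/0);
  LiteRtDispatchConnectNodeOutput(graph, /*node_id=*/0, /*output_index=*/0,
                                  /*edge_id=*/1);
  LiteRtDispatchConnectGraphOutput(graph, /*output_index=*/0, /*edge_id=*/1);
  LiteRtDispatchExecutableHandle exec_handle;
  LiteRtDispatchLoadExecutable(device_context, exec_type, &bytecode_buffer,
                               &exec_handle);
  LiteRtDispatchAssignNodeFunction(graph, /*node_id=*/0, exec_handle, "main");
  LiteRtDispatchInvocationContextCreateFromGraph(device_context, graph,
                                                 &invocation_context);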
- -#include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h" - -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_DMABUF_SUPPORT -#include -#include -#endif // LITERT_HAS_DMABUF_SUPPORT - -namespace litert { -namespace internal { - -#if LITERT_HAS_DMABUF_SUPPORT -namespace { - -class DmaBufLibrary { - public: - using Ptr = std::unique_ptr; - - ~DmaBufLibrary() { - if (allocator_) { - free_allocator_(allocator_); - } - } - - static Expected Create() { - DlHandle dlhandle(::dlopen("libdmabufheap.so", RTLD_LAZY | RTLD_LOCAL), - ::dlclose); - if (!dlhandle) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "libdmabufheap.so not found"); - } - - auto create_allocator = reinterpret_cast( - ::dlsym(dlhandle.get(), "CreateDmabufHeapBufferAllocator")); - if (!create_allocator) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "CreateDmabufHeapBufferAllocator not found"); - } - - auto free_allocator = reinterpret_cast( - ::dlsym(dlhandle.get(), "FreeDmabufHeapBufferAllocator")); - if (!free_allocator) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "FreeDmabufHeapBufferAllocator not found"); - } - - auto alloc_buffer = reinterpret_cast( - ::dlsym(dlhandle.get(), "DmabufHeapAlloc")); - if (!alloc_buffer) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "DmabufHeapAlloc not found"); - } - - void* allocator = create_allocator(); - if (!allocator) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "CreateDmabufHeapBufferAllocator failed"); - } - - return Ptr(new DmaBufLibrary(std::move(dlhandle), allocator, free_allocator, - alloc_buffer)); - } - - Expected Alloc(size_t size) { - int fd = alloc_buffer_(allocator_, kDmaBufHeap, size, /*flags=*/0, - /*legacy_align=*/0); - if (fd < 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate DMA-BUF buffer"); - } - void* addr = - ::mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - if (addr == MAP_FAILED) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to mem-map DMA-BUF buffer"); - } - records_[addr] = Record{.fd = fd, .addr = addr, .size = size}; - return DmaBufBuffer{.fd = fd, .addr = addr}; - } - - void Free(void* addr) { - auto iter = records_.find(addr); - if (iter == records_.end()) { - return; - } - auto& record = iter->second; - ::munmap(record.addr, record.size); - ::close(record.fd); - records_.erase(iter); - } - - private: - static constexpr const char* kDmaBufHeap = "system"; - - struct Record { - int fd; - void* addr; - size_t size; - }; - - using DlHandle = std::unique_ptr; - using CreateAllocator = void* (*)(); - using FreeAllocator = void (*)(void*); - using AllocBuffer = int (*)(void*, const char*, size_t, unsigned int, size_t); - - DmaBufLibrary(DlHandle&& dlhandle, void* allocator, - FreeAllocator free_allocator, AllocBuffer alloc_buffer) - : dlhandle_(std::move(dlhandle)) { - allocator_ = allocator; - free_allocator_ = free_allocator; - alloc_buffer_ = alloc_buffer; - } - - DlHandle dlhandle_; - void* allocator_; - FreeAllocator free_allocator_; - AllocBuffer alloc_buffer_; - absl::node_hash_map records_; -}; - -DmaBufLibrary* TheDmaBufLibrary; -ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); - -Expected InitLibraryIfNeededUnlocked() { - 
if (!TheDmaBufLibrary) { - if (auto library = DmaBufLibrary::Create(); library) { - TheDmaBufLibrary = library->release(); - } else { - return Unexpected(library.Error()); - } - } - return {}; -} - -} // namespace -#endif // LITERT_HAS_DMABUF_SUPPORT - -bool DmaBufBuffer::IsSupported() { -#if LITERT_HAS_DMABUF_SUPPORT - absl::MutexLock lock(&TheMutex); - auto status = InitLibraryIfNeededUnlocked(); - return static_cast(status); -#else // LITERT_HAS_DMABUF_SUPPORT - return false; -#endif // LITERT_HAS_DMABUF_SUPPORT -} - -Expected DmaBufBuffer::Alloc(size_t size) { -#if LITERT_HAS_DMABUF_SUPPORT - absl::MutexLock lock(&TheMutex); - if (auto status = InitLibraryIfNeededUnlocked(); !status) { - return Unexpected(status.Error()); - } - return TheDmaBufLibrary->Alloc(size); -#else // LITERT_HAS_DMABUF_SUPPORT - return Unexpected(kLiteRtStatusErrorUnsupported, - "DmaBufBuffer::Alloc not implemented for this platform"); -#endif // LITERT_HAS_DMABUF_SUPPORT -} - -void DmaBufBuffer::Free(void* addr) { -#if LITERT_HAS_DMABUF_SUPPORT - absl::MutexLock lock(&TheMutex); - if (TheDmaBufLibrary) { - TheDmaBufLibrary->Free(addr); - } -#endif // LITERT_HAS_DMABUF_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h b/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h deleted file mode 100644 index a391e0cf892a56..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DMABUF_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DMABUF_BUFFER_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -struct DmaBufBuffer { - int fd; - void* addr; - - static bool IsSupported(); - static Expected Alloc(size_t size); - static void Free(void* addr); -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DMABUF_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/event.cc b/tensorflow/lite/experimental/litert/runtime/event.cc deleted file mode 100644 index 70f2cbb5beb512..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/event.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
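The DmaBufBuffer helper deleted above lazily dlopens Android's libdmabufheap, allocates from its "system" heap, and mmaps the returned fd so the CPU can touch the memory. A minimal usage sketch follows (hypothetical caller, assuming a build with LITERT_HAS_DMABUF_SUPPORT; Expected is assumed to support operator-> as elsewhere in LiteRT):

#include <cstddef>
#include <cstring>

#include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h"

bool CopyIntoDmaBuf(const void* src, size_t size) {
  using litert::internal::DmaBufBuffer;
  if (!DmaBufBuffer::IsSupported()) return false;  // tries to load libdmabufheap.so
  auto buf = DmaBufBuffer::Alloc(size);  // returns {fd, mmap'ed addr}
  if (!buf) return false;
  std::memcpy(buf->addr, src, size);  // CPU writes go through the mapping
  DmaBufBuffer::Free(buf->addr);      // munmap + close(fd) via the record map
  return true;
}

The fd in the returned struct is what gets handed to drivers; the addr exists purely for host access.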
- -#include "tensorflow/lite/experimental/litert/runtime/event.h" - -#include - -#include -#include - -#include "absl/strings/str_format.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -#if LITERT_HAS_SYNC_FENCE_SUPPORT -#include -#include -#endif // LITERT_HAS_SYNC_FENCE_SUPPORT -#if LITERT_HAS_OPENCL_SUPPORT -#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h" -#endif // LITERT_HAS_OPENCL_SUPPORT - -using litert::Error; -using litert::Expected; - -Expected LiteRtEventT::Wait(int64_t timeout_in_ms) { - if (type == LiteRtEventTypeSyncFenceFd) { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - struct pollfd fds = { - .fd = fd, - .events = POLLIN, - }; - - int ret; - do { - ret = ::poll(&fds, 1, timeout_in_ms); - if (ret == 1) { - break; - } else if (ret == 0) { - return Error(kLiteRtStatusErrorTimeoutExpired, "Timeout expired"); - } - } while (ret == -1 && (errno == EINTR || errno == EAGAIN)); - - if (ret < 0) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Error waiting for fence"); - } - - return {}; - -#else - return Error(kLiteRtStatusErrorUnsupported, - "LiteRtEventWait not implemented for this platform"); -#endif - } else if (type == LiteRtEventTypeOpenCl) { -#if LITERT_HAS_OPENCL_SUPPORT - return litert::cl::WaitForEvents(/*num_events=*/1, - /*event_list=*/&opencl_event); -#else - return Error(kLiteRtStatusErrorUnsupported, - "LiteRtEventWait not implemented for this platform"); -#endif - } - return Error(kLiteRtStatusErrorInvalidArgument, "Invalid event type"); -} - -#if LITERT_HAS_SYNC_FENCE_SUPPORT -namespace { -inline bool IsFdValid(int fd) { - return ::fcntl(fd, F_GETFD) != -1 || errno != EBADF; -} -} // namespace -#endif - -LiteRtEventT::~LiteRtEventT() { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - if (type == LiteRtEventTypeSyncFenceFd && owns_fd && IsFdValid(fd)) { - ::close(fd); - } -#endif -} - -Expected LiteRtEventT::Signal() { -#if LITERT_HAS_OPENCL_SUPPORT - if (type == LiteRtEventTypeOpenCl) { - return litert::cl::SetUserEventStatus(opencl_event); - } -#endif - return Error(kLiteRtStatusErrorInvalidArgument, - "The event signal is not supported"); -} - -Expected LiteRtEventT::CreateManaged(LiteRtEventType type) { -#if LITERT_HAS_OPENCL_SUPPORT - if (type == LiteRtEventTypeOpenCl) { - auto& env = litert::internal::GpuEnvironmentSingleton::GetInstance(); - LITERT_ASSIGN_OR_RETURN( - cl_event user_event, - litert::cl::CreateUserEvent(env.getContext()->context())); - return new LiteRtEventT{ - .type = LiteRtEventTypeOpenCl, - .opencl_event = user_event, - }; - } -#endif - return Error(kLiteRtStatusErrorInvalidArgument, - absl::StrFormat("CreateManaged doesn't support type %d", type)); -} diff --git a/tensorflow/lite/experimental/litert/runtime/event.h b/tensorflow/lite/experimental/litert/runtime/event.h deleted file mode 100644 index df93d5cfac10b4..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/event.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_OPENCL_SUPPORT -extern "C" { -typedef struct _cl_event* cl_event; -} -#endif // LITERT_HAS_OPENCL_SUPPORT - -struct LiteRtEventT { - LiteRtEventType type = LiteRtEventTypeUnknown; -#if LITERT_HAS_SYNC_FENCE_SUPPORT - int fd = -1; - bool owns_fd = false; -#endif -#if LITERT_HAS_OPENCL_SUPPORT - cl_event opencl_event; -#endif - ~LiteRtEventT(); - litert::Expected Wait(int64_t timeout_in_ms); - litert::Expected GetSyncFenceFd() const { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - return fd; -#else - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Sync fence is not supported on this platform"); -#endif - } - litert::Expected Signal(); - static litert::Expected CreateManaged(LiteRtEventType type); -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.cc b/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.cc deleted file mode 100644 index 63ace18c1a85da..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.cc +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
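On the sync-fence path, LiteRtEventT::Wait above reduces to a poll() loop that retries on EINTR/EAGAIN and distinguishes poll's three outcomes. The same logic as a free function, for reference (a sketch, not part of the patch):

#include <poll.h>

#include <cerrno>

// Returns 1 when the fence signals, 0 on timeout, <0 on error -- the caller
// maps these to kLiteRtStatusOk / kLiteRtStatusErrorTimeoutExpired / a
// runtime failure, exactly as the deleted Wait() does.
int WaitOnSyncFence(int fence_fd, int timeout_ms) {
  struct pollfd fds = {.fd = fence_fd, .events = POLLIN};
  int ret;
  do {
    ret = ::poll(&fds, 1, timeout_ms);
  } while (ret == -1 && (errno == EINTR || errno == EAGAIN));
  return ret;
}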
- -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" - -#include - -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/runtime/tfl_utils.h" - -namespace litert { -namespace internal { - -LiteRtStatus ExternalLiteRtBufferContext::RegisterBufferRequirement( - const TfLiteOpaqueTensor* tensor, - TensorBufferRequirements&& buffer_requirements) { - if (buffer_requirements_.find(tensor) != buffer_requirements_.end()) { - LITERT_LOG(LITERT_ERROR, - "RegisterBufferRequirement already exists for tensor: %p", - tensor); - return kLiteRtStatusErrorRuntimeFailure; - } - buffer_requirements_[tensor] = std::move(buffer_requirements); - return kLiteRtStatusOk; -} - -litert::Expected -ExternalLiteRtBufferContext::GetBufferRequirement( - const TfLiteOpaqueTensor* tensor) { - auto it = buffer_requirements_.find(tensor); - if (it == buffer_requirements_.end()) { - return litert::Unexpected(kLiteRtStatusErrorNotFound, - "Buffer requirement not found"); - } - return &(it->second); -} - -LiteRtStatus ExternalLiteRtBufferContext::RegisterTensorBuffer( - const TfLiteOpaqueTensor* tensor, TensorBuffer&& tensor_buffer) { - tensor_buffers_[tensor] = std::move(tensor_buffer); - return kLiteRtStatusOk; -} - -litert::Expected ExternalLiteRtBufferContext::GetTensorBuffer( - const TfLiteOpaqueTensor* tensor) { - auto it = tensor_buffers_.find(tensor); - if (it == tensor_buffers_.end()) { - return litert::Unexpected(kLiteRtStatusErrorNotFound, - "Tensor buffer not found"); - } - - auto duplicate_tensor_buffer = it->second.Duplicate(); - if (!duplicate_tensor_buffer) { - return litert::Unexpected(duplicate_tensor_buffer.Error()); - } - return std::move(duplicate_tensor_buffer.Value()); -} - -litert::Expected -ExternalLiteRtBufferContext::CreateBufferForTensor( - const TfLiteOpaqueTensor* tensor) { - auto tensor_buffer_requirements = GetBufferRequirement(tensor); - if (!tensor_buffer_requirements) { - return litert::Unexpected(tensor_buffer_requirements.Error()); - } - - auto tensor_type = litert::internal::ConvertTensorType(tensor); - if (!tensor_type) { - return litert::Unexpected(tensor_type.Error()); - } - - auto supported_tensor_buffer_types = - (*tensor_buffer_requirements)->SupportedTypes(); - if (!supported_tensor_buffer_types) { - return litert::Unexpected(supported_tensor_buffer_types.Error()); - } - if (supported_tensor_buffer_types->empty()) { - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "Insufficient number of supported tensor buffer types"); - } - - // For now we simply pick the first buffer type that's supported. 
- LiteRtTensorBufferType tensor_buffer_type = - (*supported_tensor_buffer_types)[0]; - - auto tensor_buffer_size = (*tensor_buffer_requirements)->BufferSize(); - if (!tensor_buffer_size) { - return litert::Unexpected(tensor_buffer_size.Error()); - } - auto litert_tensor_type = static_cast(*tensor_type); - - LiteRtTensorBuffer litert_tensor_buffer; - if (auto status = LiteRtCreateManagedTensorBuffer( - tensor_buffer_type, &litert_tensor_type, *tensor_buffer_size, - &litert_tensor_buffer); - status != kLiteRtStatusOk) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create managed tensor buffer"); - } - - return TensorBuffer(litert_tensor_buffer, /*owned=*/true); -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h b/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h deleted file mode 100644 index 81fc1fcdea9871..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EXTERNAL_LITERT_BUFFER_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EXTERNAL_LITERT_BUFFER_CONTEXT_H_ - -#include -#include -#include - -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" - -namespace litert::internal { - -class ExternalLiteRtBufferContext : public TfLiteExternalContext { - public: - ExternalLiteRtBufferContext() = default; - ~ExternalLiteRtBufferContext() = default; - - // Registers a tensor buffer requirements for the given tensor. - // The registered TensorBufferRequirements object is owned by - // ExternalLiteRtBufferContext. - // Note: Currently, the system pre-registers tensor buffer requirements before - // they're actually used. A more efficient approach would be to query - // DelegateKernel only when these requirements are needed. 
-  LiteRtStatus RegisterBufferRequirement(
-      const TfLiteOpaqueTensor* tensor,
-      TensorBufferRequirements&& buffer_requirements);
-
-  inline LiteRtStatus RegisterBufferRequirement(
-      const TfLiteTensor* tensor,
-      TensorBufferRequirements&& buffer_requirements) {
-    return RegisterBufferRequirement(
-        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor),
-        std::move(buffer_requirements));
-  }
-
-  inline LiteRtStatus RegisterLiteRtBufferRequirement(
-      const TfLiteTensor* tensor,
-      LiteRtTensorBufferRequirements& litert_buffer_requirements) {
-    return RegisterBufferRequirement(
-        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor),
-        TensorBufferRequirements(litert_buffer_requirements,
-                                 /*owned=*/true));
-  }
-
-  // Gets the registered tensor buffer requirements for the given tensor.
-  // The returned TensorBufferRequirements object is still owned by
-  // ExternalLiteRtBufferContext.
-  litert::Expected<TensorBufferRequirements*> GetBufferRequirement(
-      const TfLiteOpaqueTensor* tensor);
-
-  inline litert::Expected<TensorBufferRequirements*> GetBufferRequirement(
-      const TfLiteTensor* tensor) {
-    return GetBufferRequirement(
-        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor));
-  }
-
-  // Registers a tensor buffer for the given tensor.
-  // The registered TensorBuffer object is owned by ExternalLiteRtBufferContext.
-  LiteRtStatus RegisterTensorBuffer(const TfLiteOpaqueTensor* tensor,
-                                    TensorBuffer&& tensor_buffer);
-
-  inline LiteRtStatus RegisterTensorBuffer(const TfLiteTensor* tensor,
-                                           TensorBuffer&& tensor_buffer) {
-    return RegisterTensorBuffer(
-        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor),
-        std::move(tensor_buffer));
-  }
-
-  // Gets a registered tensor buffer for the given tensor.
-  // The returned TensorBuffer object is a duplicate (reference counted) of
-  // the registered TensorBuffer.
-  litert::Expected<TensorBuffer> GetTensorBuffer(
-      const TfLiteOpaqueTensor* tensor);
-
-  inline litert::Expected<TensorBuffer> GetTensorBuffer(
-      const TfLiteTensor* tensor) {
-    return GetTensorBuffer(
-        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor));
-  }
-
-  // Creates a tensor buffer for the given tensor.
-  // The caller takes ownership of the returned TensorBuffer object.
-  litert::Expected<TensorBuffer> CreateBufferForTensor(
-      const TfLiteOpaqueTensor* tensor);
-
-  inline litert::Expected<TensorBuffer> CreateBufferForTensor(
-      const TfLiteTensor* tensor) {
-    return CreateBufferForTensor(
-        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor));
-  }
-
-  // Sets the async execution mode. It's set by CompiledModel and used by
-  // DelegateKernel to decide whether to use async execution mode.
-  inline void SetAsyncExecutionMode(bool async_execution_mode) {
-    async_execution_mode_ = async_execution_mode;
-  }
-
-  // Returns true if the async execution mode is set.
-  inline bool IsAsyncExecutionMode() const { return async_execution_mode_; }
-
- private:
-  std::unordered_map<const TfLiteOpaqueTensor*, TensorBufferRequirements>
-      buffer_requirements_;
-  std::unordered_map<const TfLiteOpaqueTensor*, TensorBuffer> tensor_buffers_;
-
-  ExternalLiteRtBufferContext(const ExternalLiteRtBufferContext&) = delete;
-  ExternalLiteRtBufferContext& operator=(const ExternalLiteRtBufferContext&) =
-      delete;
-
-  bool async_execution_mode_ = false;
-};
-
-}  // namespace litert::internal
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EXTERNAL_LITERT_BUFFER_CONTEXT_H_
diff --git a/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.cc b/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.cc
deleted file mode 100644
index d0ec124b3177da..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.cc
+++ /dev/null
@@ -1,158 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h" - -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_FASTRPC_SUPPORT -#include -#endif // LITERT_HAS_FASTRPC_SUPPORT - -namespace litert { -namespace internal { - -#if LITERT_HAS_FASTRPC_SUPPORT -namespace { - -class FastRpcMemLibrary { - public: - using Ptr = std::unique_ptr; - - static Expected Create() { - DlHandle dlhandle(::dlopen("libcdsprpc.so", RTLD_NOW | RTLD_LOCAL), - ::dlclose); - if (!dlhandle) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "libcdsprpc.so not found"); - } - - auto rpcmem_alloc = - reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_alloc")); - if (!rpcmem_alloc) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "rpcmem_alloc not found"); - } - - auto rpcmem_free = - reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_free")); - if (!rpcmem_free) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "rpcmem_free not found"); - } - - auto rpcmem_to_fd = - reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_to_fd")); - if (!rpcmem_to_fd) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "rpcmem_to_fd not found"); - } - - return Ptr(new FastRpcMemLibrary(std::move(dlhandle), rpcmem_alloc, - rpcmem_free, rpcmem_to_fd)); - } - - void* Alloc(size_t size) const { - return rpcmem_alloc_(kRpcmemHeapIdSystem, kRpcmemDefaultFlags, size); - } - - void Free(void* buffer) const { return rpcmem_free_(buffer); } - - int ToFd(void* buffer) const { return rpcmem_to_fd_(buffer); } - - private: - static constexpr int kRpcmemHeapIdSystem = 25; - static constexpr uint32_t kRpcmemDefaultFlags = 1; - - using DlHandle = std::unique_ptr; - using RpcMemAlloc = void* (*)(int, uint32_t, int); - using RpcMemFree = void (*)(void*); - using RpcMemToFd = int (*)(void*); - - FastRpcMemLibrary(DlHandle&& dlhandle, RpcMemAlloc rpcmem_alloc, - RpcMemFree rpcmem_free, RpcMemToFd rpcmem_to_fd) - : dlhandle_(std::move(dlhandle)) { - rpcmem_alloc_ = rpcmem_alloc; - rpcmem_free_ = rpcmem_free; - rpcmem_to_fd_ = rpcmem_to_fd; - } - - DlHandle dlhandle_; - RpcMemAlloc rpcmem_alloc_; - RpcMemFree rpcmem_free_; - RpcMemToFd rpcmem_to_fd_; -}; - -FastRpcMemLibrary* TheFastRpcMemLibrary; -ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); - -Expected InitLibraryIfNeededUnlocked() { - if (!TheFastRpcMemLibrary) { - if (auto library = FastRpcMemLibrary::Create(); library) { - TheFastRpcMemLibrary = library->release(); - } else { - return Unexpected(library.Error()); - } - } - return {}; -} - -} // namespace -#endif // LITERT_HAS_FASTRPC_SUPPORT - -bool FastRpcBuffer::IsSupported() { -#if LITERT_HAS_FASTRPC_SUPPORT - absl::MutexLock lock(&TheMutex); - auto status = InitLibraryIfNeededUnlocked(); - return static_cast(status); -#else // LITERT_HAS_FASTRPC_SUPPORT - return false; -#endif // LITERT_HAS_FASTRPC_SUPPORT -} - -Expected 
FastRpcBuffer::Alloc(size_t size) { -#if LITERT_HAS_FASTRPC_SUPPORT - absl::MutexLock lock(&TheMutex); - if (auto status = InitLibraryIfNeededUnlocked(); !status) { - return status.Error(); - } - void* addr = TheFastRpcMemLibrary->Alloc(size); - int fd = TheFastRpcMemLibrary->ToFd(addr); - return FastRpcBuffer{.fd = fd, .addr = addr}; -#else // LITERT_HAS_FASTRPC_SUPPORT - return Unexpected(kLiteRtStatusErrorUnsupported, - "FastRpcBuffer::Alloc not implemented for this platform"); -#endif // LITERT_HAS_FASTRPC_SUPPORT -} - -void FastRpcBuffer::Free(void* addr) { -#if LITERT_HAS_FASTRPC_SUPPORT - absl::MutexLock lock(&TheMutex); - if (TheFastRpcMemLibrary) { - TheFastRpcMemLibrary->Free(addr); - } -#endif // LITERT_HAS_FASTRPC_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h b/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h deleted file mode 100644 index fa934ce0b693df..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_FASTRPC_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_FASTRPC_BUFFER_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -struct FastRpcBuffer { - int fd; - void* addr; - - static bool IsSupported(); - static Expected Alloc(size_t size); - static void Free(void* addr); -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_FASTRPC_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/gl_buffer.cc b/tensorflow/lite/experimental/litert/runtime/gl_buffer.cc deleted file mode 100644 index 6befdf4b844bb2..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gl_buffer.cc +++ /dev/null @@ -1,338 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
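FastRpcBuffer above and DmaBufBuffer earlier share one recipe: dlopen the vendor library once, resolve entry points with dlsym, and reinterpret_cast each result to a typed function-pointer alias before calling it. Distilled into a standalone sketch (the constants match the deleted code; the wrapper function name is hypothetical):

#include <dlfcn.h>

#include <cstddef>
#include <cstdint>
#include <memory>

using DlHandle = std::unique_ptr<void, int (*)(void*)>;
using RpcMemAlloc = void* (*)(int heap_id, uint32_t flags, int size);

void* AllocViaFastRpc(size_t size) {
  // The handle must outlive any resolved pointer; the real class stores it
  // as a member for exactly that reason.
  static DlHandle dlhandle(::dlopen("libcdsprpc.so", RTLD_NOW | RTLD_LOCAL),
                           ::dlclose);
  if (!dlhandle) return nullptr;  // FastRPC runtime absent on this device
  static auto rpcmem_alloc = reinterpret_cast<RpcMemAlloc>(
      ::dlsym(dlhandle.get(), "rpcmem_alloc"));
  if (!rpcmem_alloc) return nullptr;
  // kRpcmemHeapIdSystem == 25 and kRpcmemDefaultFlags == 1 in the code above.
  return rpcmem_alloc(/*heap_id=*/25, /*flags=*/1, static_cast<int>(size));
}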
- -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" - -#include - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include -#include -#include -#include - -#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace litert { -namespace internal { - -#if LITERT_HAS_AHWB_SUPPORT - -PFNGLBUFFERSTORAGEEXTERNALEXTPROC glBufferStorageExternalEXT; -PFNEGLGETNATIVECLIENTBUFFERANDROIDPROC eglGetNativeClientBufferANDROID; -PFNEGLDUPNATIVEFENCEFDANDROIDPROC eglDupNativeFenceFDANDROID; -PFNEGLCREATESYNCKHRPROC eglCreateSyncKHR; -PFNEGLWAITSYNCKHRPROC eglWaitSyncKHR; -PFNEGLCLIENTWAITSYNCKHRPROC eglClientWaitSyncKHR; -PFNEGLDESTROYSYNCKHRPROC eglDestroySyncKHR; - -bool IsAhwbToGlInteropSupported() { - static const bool extensions_allowed = [] { - eglGetNativeClientBufferANDROID = - reinterpret_cast( - eglGetProcAddress("eglGetNativeClientBufferANDROID")); - glBufferStorageExternalEXT = - reinterpret_cast( - eglGetProcAddress("glBufferStorageExternalEXT")); - eglDupNativeFenceFDANDROID = - reinterpret_cast( - eglGetProcAddress("eglDupNativeFenceFDANDROID")); - eglCreateSyncKHR = reinterpret_cast( - eglGetProcAddress("eglCreateSyncKHR")); - eglWaitSyncKHR = reinterpret_cast( - eglGetProcAddress("eglWaitSyncKHR")); - eglClientWaitSyncKHR = reinterpret_cast( - eglGetProcAddress("eglClientWaitSyncKHR")); - eglDestroySyncKHR = reinterpret_cast( - eglGetProcAddress("eglDestroySyncKHR")); - return eglClientWaitSyncKHR && eglWaitSyncKHR && - eglGetNativeClientBufferANDROID && glBufferStorageExternalEXT && - eglCreateSyncKHR && eglDupNativeFenceFDANDROID && eglDestroySyncKHR; - }(); - return extensions_allowed; -} - -Expected GlBuffer::AllocFromAhwbBuffer(AhwbBuffer& ahwb_buffer) { - LITERT_RETURN_IF_ERROR( - IsAhwbToGlInteropSupported(), - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer to GL interop is not supported")); - LITERT_RETURN_IF_ERROR( - ahwb_buffer.ahwb != nullptr, - Unexpected(kLiteRtStatusErrorRuntimeFailure, "AHardwareBuffer is null")); - - // Create GL buffer id. - GLuint gl_id; - glGenBuffers(1, &gl_id); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, gl_id); - - // Create EGLClientBuffer from AHardwareBuffer. - EGLClientBuffer native_buffer = - eglGetNativeClientBufferANDROID(ahwb_buffer.ahwb); - LITERT_RETURN_IF_ERROR( - native_buffer != nullptr, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create EGLClientBuffer from AHardwareBuffer")); - - LITERT_ASSIGN_OR_RETURN( - size_t size_bytes, - litert::internal::AhwbBuffer::GetSize(ahwb_buffer.ahwb)); - LITERT_RETURN_IF_ERROR(size_bytes != 0, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer size is 0")); - - // Create OpenGl buffer object backed by the AHardwareBuffer. - glBufferStorageExternalEXT( - GL_SHADER_STORAGE_BUFFER, 0, size_bytes, native_buffer, - GL_MAP_READ_BIT | GL_MAP_WRITE_BIT | GL_MAP_COHERENT_BIT_EXT | - GL_MAP_PERSISTENT_BIT_EXT); - // Check for OpenGL errors. 
- absl::Status status = tflite::gpu::gl::GetOpenGlErrors(); - if (!status.ok()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("glBufferStorageExternalEXT: Failed to " - "create GL buffer from AHardwareBuffer: ", - status.message())); - } - // Unbind the buffer. - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - - // Create GL buffer object. We assume ownership of the GL buffer id so that it - // will be automatically deallocated when the internal::GlBuffer is destroyed. - tflite::gpu::gl::GlBuffer tflite_gl_buffer(GL_SHADER_STORAGE_BUFFER, gl_id, - size_bytes, /*offset=*/0, - /*has_ownership=*/true); - return GlBuffer(std::move(tflite_gl_buffer), ahwb_buffer.ahwb); -} -#endif // LITERT_HAS_AHWB_SUPPORT - -GlBuffer::GlBuffer(LiteRtGLenum target, LiteRtGLuint id, size_t size_bytes, - size_t offset, LiteRtGlBufferDeallocator deallocator) { -#if LITERT_HAS_OPENGL_SUPPORT - size_bytes_ = size_bytes; - - if (deallocator != nullptr) { - tflite_gl_buffer_ = tflite::gpu::gl::GlBuffer( - target, id, size_bytes, offset, /*has_ownership=*/false); - deallocator_ = std::move(deallocator); - } else { - tflite_gl_buffer_ = tflite::gpu::gl::GlBuffer( - target, id, size_bytes, offset, /*has_ownership=*/true); - deallocator_ = nullptr; - } -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::GlBuffer() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -GlBuffer::GlBuffer(GlBuffer&& other) { -#if LITERT_HAS_OPENGL_SUPPORT - tflite_gl_buffer_ = std::move(other.tflite_gl_buffer_); - deallocator_ = std::move(other.deallocator_); - data_ = other.data_; - size_bytes_ = other.size_bytes_; -#if LITERT_HAS_AHWB_SUPPORT - ahwb_ = other.ahwb_; -#endif // LITERT_HAS_AHWB_SUPPORT - // Reset the other GlBuffer to a default state. - other.data_ = nullptr; - other.size_bytes_ = 0; -#if LITERT_HAS_AHWB_SUPPORT - other.ahwb_ = nullptr; -#endif // LITERT_HAS_AHWB_SUPPORT -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::GlBuffer() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -GlBuffer::~GlBuffer() { -#if LITERT_HAS_OPENGL_SUPPORT - if (deallocator_ != nullptr) { - deallocator_(reinterpret_cast(tflite_gl_buffer_.id())); - } - if (data_ != nullptr) { - free(data_); - } -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::~GlBuffer() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -LiteRtGLenum GlBuffer::target() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_buffer_.target(); -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::target() is not supported"); - return 0; -#endif // LITERT_HAS_OPENGL_SUPPORT -} -LiteRtGLuint GlBuffer::id() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_buffer_.id(); -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::id() is not supported"); - return 0; -#endif // LITERT_HAS_OPENGL_SUPPORT -} -size_t GlBuffer::size_bytes() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_buffer_.bytes_size(); -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::size_bytes() is not supported"); - return 0; -#endif // LITERT_HAS_OPENGL_SUPPORT -} -size_t GlBuffer::offset() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_buffer_.offset(); -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::offset() is not supported"); - return 0; -#endif -} - -Expected GlBuffer::Alloc(size_t size_bytes) { -#if LITERT_HAS_OPENGL_SUPPORT - tflite::gpu::gl::GlBuffer tflite_gl_buffer; - - if (!tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( - size_bytes, &tflite_gl_buffer) - .ok()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate GL buffer"); - } - - 
return GlBuffer(std::move(tflite_gl_buffer)); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenGL buffers are not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -template Expected GlBuffer::Lock(); -template Expected GlBuffer::Lock(); -template Expected GlBuffer::Unlock(); -template Expected GlBuffer::Unlock(); - -template -Expected GlBuffer::Lock() { -#if LITERT_HAS_OPENGL_SUPPORT - absl::MutexLock lock(&mutex_); -#if LITERT_HAS_AHWB_SUPPORT - if (ahwb_ != nullptr) { - LITERT_ASSIGN_OR_RETURN(void* data, - litert::internal::AhwbBuffer::Lock(ahwb_)); - return static_cast(data); - } -#endif // LITERT_HAS_AHWB_SUPPORT - if (data_ == nullptr) { - // Ensure the data is aligned. - if (auto rc = posix_memalign(&data_, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, - size_bytes_); - rc) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate aligned memory"); - } - if (auto status = tflite_gl_buffer_.Read( - absl::MakeSpan(static_cast(data_), size_bytes_ / sizeof(T))); - !status.ok()) { - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("Failed to read GL buffer: ", status.message())); - } - } - return Expected(static_cast(data_)); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "GlBuffer::Lock() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -template -Expected GlBuffer::Unlock() { -#if LITERT_HAS_OPENGL_SUPPORT - absl::MutexLock lock(&mutex_); -#if LITERT_HAS_AHWB_SUPPORT - if (ahwb_ != nullptr) { - return litert::internal::AhwbBuffer::Unlock(ahwb_); - } -#endif // LITERT_HAS_AHWB_SUPPORT - if (data_ == nullptr) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "Cannot unlock a buffer that wasn't locked in the first place"); - } - if (auto status = tflite_gl_buffer_.Write(absl::MakeSpan( - static_cast(data_), size_bytes_ / sizeof(T))); - !status.ok()) { - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("Failed to write GL buffer: ", status.message())); - } - return Expected(); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "GlBuffer::Unlock() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -Expected GlBuffer::CreateEglSyncAndFence() { -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT - LITERT_RETURN_IF_ERROR( - IsAhwbToGlInteropSupported(), - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer to GL interop is not supported")); - - auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY); - LITERT_RETURN_IF_ERROR(egl_display != EGL_NO_DISPLAY, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get EGL display")); - - EGLSyncKHR egl_sync = - eglCreateSyncKHR(egl_display, EGL_SYNC_NATIVE_FENCE_ANDROID, nullptr); - LITERT_RETURN_IF_ERROR( - egl_sync != EGL_NO_SYNC_KHR, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create EGL sync from AHardwareBuffer")); - - int native_fence = eglDupNativeFenceFDANDROID(egl_display, egl_sync); - LITERT_RETURN_IF_ERROR( - native_fence != -1, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to dup native fence from AHardwareBuffer")); - - return native_fence; -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer to GL interop is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/gl_buffer.h b/tensorflow/lite/experimental/litert/runtime/gl_buffer.h deleted file mode 100644 index f691317aa87405..00000000000000 
--- a/tensorflow/lite/experimental/litert/runtime/gl_buffer.h +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_BUFFER_H_ - -#include -#include - -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -#if LITERT_HAS_AHWB_SUPPORT -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" -#endif // LITERT_HAS_AHWB_SUPPORT - -namespace litert::internal { - -class GlBuffer { - public: -#if LITERT_HAS_OPENGL_SUPPORT - explicit GlBuffer(tflite::gpu::gl::GlBuffer&& tflite_gl_buffer -#if LITERT_HAS_AHWB_SUPPORT - , - AHardwareBuffer* ahwb = nullptr -#endif // LITERT_HAS_AHWB_SUPPORT - ) - : tflite_gl_buffer_(std::move(tflite_gl_buffer)), - deallocator_(nullptr), - size_bytes_(tflite_gl_buffer.bytes_size()) -#if LITERT_HAS_AHWB_SUPPORT - , - ahwb_(ahwb) -#endif // LITERT_HAS_AHWB_SUPPORT - { - } -#endif // LITERT_HAS_OPENGL_SUPPORT - - GlBuffer(LiteRtGLenum target, LiteRtGLuint id, size_t size_bytes, - size_t offset, LiteRtGlBufferDeallocator deallocator); - - GlBuffer(GlBuffer&& other); - - ~GlBuffer(); - - static bool IsSupported() { return true; } - static Expected Alloc(size_t size_bytes); - -#if LITERT_HAS_AHWB_SUPPORT - static Expected AllocFromAhwbBuffer(AhwbBuffer& ahwb_buffer); -#endif // LITERT_HAS_AHWB_SUPPORT - - template - Expected Lock(); - - template - Expected Unlock(); - - LiteRtGLenum target() const; - LiteRtGLuint id() const; - size_t size_bytes() const; - size_t offset() const; - - // Creates an EGL sync object on the GPU command queue and returns a native - // fence associated with the sync object. - // Note: This function assumes that all GL operations have been already added - // to the GPU command queue. - static Expected CreateEglSyncAndFence(); - - private: - absl::Mutex mutex_; -#if LITERT_HAS_OPENGL_SUPPORT - tflite::gpu::gl::GlBuffer tflite_gl_buffer_; - LiteRtGlBufferDeallocator deallocator_; - // The cpu memory buffer pointer. - void* data_ = nullptr; - // The size of the buffer in bytes. 
- size_t size_bytes_ = 0; -#endif // LITERT_HAS_OPENGL_SUPPORT -#if LITERT_HAS_AHWB_SUPPORT - AHardwareBuffer* ahwb_ = nullptr; -#endif // LITERT_HAS_AHWB_SUPPORT -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/gl_buffer_test.cc b/tensorflow/lite/experimental/litert/runtime/gl_buffer_test.cc deleted file mode 100644 index 905056a860cccc..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gl_buffer_test.cc +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include - -#include -#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -#if LITERT_HAS_AHWB_SUPPORT -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" -#endif // LITERT_HAS_AHWB_SUPPORT - -namespace litert { -namespace internal { -namespace { - -using ::testing::FloatEq; -using ::testing::FloatNear; -using ::testing::Pointwise; - -constexpr const float kTensorData[] = {10, 20, 30, 40}; - -TEST(Buffer, GlBufferAlloc) { - if (!GlBuffer::IsSupported()) { - GTEST_SKIP() << "OpenGL buffers are not supported on this platform"; - } - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - auto buffer = GlBuffer::Alloc(4 * sizeof(float)); - ASSERT_TRUE(buffer); - - // Test lock and unlock. - LITERT_ASSERT_OK_AND_ASSIGN(float* data, buffer->Lock()); - EXPECT_NE(data, nullptr); - LITERT_ASSERT_OK(buffer->Unlock()); -} - -#if LITERT_HAS_AHWB_SUPPORT -TEST(Buffer, GlBufferAllocFromAhwb) { - if (!GlBuffer::IsSupported()) { - GTEST_SKIP() << "OpenGL buffers are not supported on this platform"; - } - // TODO(gcarranza): Incorporate this into LiteRT environment. - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - LITERT_ASSERT_OK_AND_ASSIGN(AhwbBuffer ahwb_buffer, - AhwbBuffer::Alloc(4 * sizeof(float))); - // Write to AHWB on CPU. - LITERT_ASSERT_OK_AND_ASSIGN( - void* ahwb_host_data, - litert::internal::AhwbBuffer::Lock(ahwb_buffer.ahwb)); - std::memcpy(ahwb_host_data, kTensorData, sizeof(kTensorData)); - LITERT_ASSERT_OK(litert::internal::AhwbBuffer::Unlock(ahwb_buffer.ahwb)); - - // Create GL buffer from AHWB. - LITERT_ASSERT_OK_AND_ASSIGN(GlBuffer gl_buffer, - GlBuffer::AllocFromAhwbBuffer(ahwb_buffer)); - - // Read from GL buffer backed by AHWB. 
- LITERT_ASSERT_OK_AND_ASSIGN(float* gl_host_data, gl_buffer.Lock()); - ASSERT_NE(gl_host_data, nullptr); - EXPECT_EQ(std::memcmp(gl_host_data, kTensorData, sizeof(kTensorData)), 0); - LITERT_EXPECT_OK(gl_buffer.Unlock()); -} - -TEST(Buffer, NegativeFenceAhwbRead) { - LITERT_ASSERT_OK_AND_ASSIGN(AhwbBuffer ahwb_buffer, - AhwbBuffer::Alloc(4 * sizeof(float))); - - LiteRtEventT event; - LITERT_ASSERT_OK_AND_ASSIGN(int fence_fd, event.GetSyncFenceFd()); - ASSERT_EQ(fence_fd, -1); - // Since fence is -1, there should be no wait on fence. - LITERT_ASSERT_OK_AND_ASSIGN(void* ahwb_host_data, - AhwbBuffer::Lock(ahwb_buffer.ahwb, &event)); - ASSERT_TRUE(ahwb_host_data != nullptr); - LITERT_ASSERT_OK(AhwbBuffer::Unlock(ahwb_buffer.ahwb)); -} - -// Utility function to fill the GPU buffer. -void FillGlBuffer(GLuint id, std::size_t size) { - std::string shader_source = R"( #version 310 es - precision highp float; - layout(local_size_x = 1, local_size_y = 1) in; - layout(std430, binding = 0) buffer Output {float elements[];} output_data; - void main() { - uint v = gl_GlobalInvocationID.x * 2u; - output_data.elements[v] = float(v) / 10.0; - output_data.elements[v + 1u] = float(v + 1u) / 10.0; - })"; - GLuint shader = glCreateShader(GL_COMPUTE_SHADER); - const GLchar* sources[] = {shader_source.c_str()}; - glShaderSource(shader, 1, sources, nullptr); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glCompileShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - GLuint to_buffer_program = glCreateProgram(); - glAttachShader(to_buffer_program, shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glLinkProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, id); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glUseProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDispatchCompute(size / 2, 1, 1); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); -} - -TEST(Buffer, GpuWriteAhwbRead) { - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - LITERT_ASSERT_OK_AND_ASSIGN(AhwbBuffer ahwb_buffer, - AhwbBuffer::Alloc(4 * sizeof(float))); - // Write to AHWB on CPU. - LITERT_ASSERT_OK_AND_ASSIGN( - void* ahwb_host_data, - litert::internal::AhwbBuffer::Lock(ahwb_buffer.ahwb)); - std::memcpy(ahwb_host_data, kTensorData, sizeof(kTensorData)); - LITERT_ASSERT_OK(litert::internal::AhwbBuffer::Unlock(ahwb_buffer.ahwb)); - - // Create GL buffer from AHWB. - LITERT_ASSERT_OK_AND_ASSIGN(GlBuffer gl_buffer, - GlBuffer::AllocFromAhwbBuffer(ahwb_buffer)); - - // Schedule GPU write to GL buffer. - FillGlBuffer(gl_buffer.id(), 4); - - // Create EGL sync and fence before AHWB read. - LITERT_ASSERT_OK_AND_ASSIGN(int native_fence, - GlBuffer::CreateEglSyncAndFence()); - - // Wrap native fence in LiteRT event. - LiteRtEventT gpu_write_event = {.fd = native_fence, .owns_fd = true}; - - // Read from AHWB on CPU, waiting for GPU write to complete. 
- LITERT_ASSERT_OK_AND_ASSIGN( - void* ahwb_host_data_after_write_data, - AhwbBuffer::Lock(ahwb_buffer.ahwb, &gpu_write_event)); - ASSERT_NE(ahwb_host_data_after_write_data, nullptr); - auto ahwb_host_data_after_write = absl::MakeSpan( - reinterpret_cast(ahwb_host_data_after_write_data), 4); - // Check that the data is the same as the GPU write. - std::vector expected_data = {0.0f, 0.1f, 0.2f, 0.3f}; - EXPECT_THAT(ahwb_host_data_after_write, - Pointwise(FloatNear(1e-5), expected_data)); - LITERT_ASSERT_OK(AhwbBuffer::Unlock(ahwb_buffer.ahwb)); -} - -#endif // LITERT_HAS_AHWB_SUPPORT - -} // namespace -} // namespace internal -} // namespace litert - -#endif // LITERT_HAS_OPENGL_SUPPORT diff --git a/tensorflow/lite/experimental/litert/runtime/gl_texture.cc b/tensorflow/lite/experimental/litert/runtime/gl_texture.cc deleted file mode 100644 index 6b453f32957f21..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gl_texture.cc +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/gl_texture.h" - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace litert { -namespace internal { - -LiteRtGLenum GlTexture::target() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.target(); -#endif - LITERT_LOG(LITERT_ERROR, "GlTexture::target() is not supported"); - return 0; -} - -LiteRtGLuint GlTexture::id() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.id(); -#endif - LITERT_LOG(LITERT_ERROR, "GlTexture::id() is not supported"); - return 0; -} - -LiteRtGLenum GlTexture::format() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.format(); -#endif - LITERT_LOG(LITERT_ERROR, "GlTexture::format() is not supported"); - return 0; -} - -size_t GlTexture::size_bytes() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.bytes_size(); -#endif - LITERT_LOG(LITERT_ERROR, "GlTexture::size_bytes() is not supported"); - return 0; -} - -LiteRtGLint GlTexture::layer() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.layer(); -#else - LITERT_LOG(LITERT_ERROR, "GlTexture::layer() is not supported"); - return 0; -#endif -} - -GlTexture::GlTexture(LiteRtGLenum target, LiteRtGLuint id, LiteRtGLenum format, - size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator) { -#if LITERT_HAS_OPENGL_SUPPORT - if (deallocator != nullptr) { - tflite_gl_texture_ = tflite::gpu::gl::GlTexture( - target, id, format, size_bytes, layer, /*has_ownership=*/false); - deallocator_ = std::move(deallocator); - } else { - tflite_gl_texture_ = 
tflite::gpu::gl::GlTexture( - target, id, format, size_bytes, layer, /*has_ownership=*/true); - deallocator_ = nullptr; - } -#else - LITERT_LOG(LITERT_ERROR, "GlTexture::GlTexture() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -GlTexture::GlTexture(GlTexture&& other) { -#if LITERT_HAS_OPENGL_SUPPORT - tflite_gl_texture_ = std::move(other.tflite_gl_texture_); - deallocator_ = std::move(other.deallocator_); -#else - LITERT_LOG(LITERT_ERROR, "GlTexture::GlTexture() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -GlTexture::~GlTexture() { -#if LITERT_HAS_OPENGL_SUPPORT - if (deallocator_ != nullptr) { - deallocator_(reinterpret_cast(tflite_gl_texture_.id())); - } -#else - LITERT_LOG(LITERT_ERROR, "GlTexture::~GlTexture() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/gl_texture.h b/tensorflow/lite/experimental/litert/runtime/gl_texture.h deleted file mode 100644 index 3d358fe9515a32..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gl_texture.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_TEXTURE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_TEXTURE_H_ - -#include - -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace litert::internal { - -class GlTexture { - public: - GlTexture(LiteRtGLenum target, LiteRtGLuint id, LiteRtGLenum format, - size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator); - - GlTexture(GlTexture&& other); - - ~GlTexture(); - - LiteRtGLenum target() const; - LiteRtGLuint id() const; - LiteRtGLenum format() const; - size_t size_bytes() const; - LiteRtGLint layer() const; - - private: - absl::Mutex mutex_; -#if LITERT_HAS_OPENGL_SUPPORT - tflite::gpu::gl::GlTexture tflite_gl_texture_; - LiteRtGlTextureDeallocator deallocator_; -#endif // LITERT_HAS_OPENGL_SUPPORT -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_TEXTURE_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/gpu_environment.cc b/tensorflow/lite/experimental/litert/runtime/gpu_environment.cc deleted file mode 100644 index c30f9055570dee..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gpu_environment.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h"
-
-#include
-#include "tensorflow/lite/experimental/litert/c/litert_any.h"
-#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/core/environment.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h"
-
-namespace litert {
-namespace internal {
-
-GpuEnvironmentSingleton::GpuEnvironmentSingleton(
-    LiteRtEnvironmentT* environment) {
-  cl_device_id device_id = nullptr;
-  cl_platform_id platform_id = nullptr;
-  cl_context context = nullptr;
-  cl_command_queue command_queue = nullptr;
-  if (environment) {
-    auto device_option =
-        environment->GetOption(kLiteRtEnvOptionTagOpenClDeviceId);
-    if (device_option.has_value() && device_option->type == kLiteRtAnyTypeInt) {
-      device_id = reinterpret_cast<cl_device_id>(device_option->int_value);
-    }
-    auto platform_option =
-        environment->GetOption(kLiteRtEnvOptionTagOpenClPlatformId);
-    if (platform_option.has_value() &&
-        platform_option->type == kLiteRtAnyTypeInt) {
-      platform_id =
-          reinterpret_cast<cl_platform_id>(platform_option->int_value);
-    }
-    auto context_option =
-        environment->GetOption(kLiteRtEnvOptionTagOpenClContext);
-    if (context_option.has_value() &&
-        context_option->type == kLiteRtAnyTypeInt) {
-      context = reinterpret_cast<cl_context>(context_option->int_value);
-    }
-    auto command_queue_option =
-        environment->GetOption(kLiteRtEnvOptionTagOpenClCommandQueue);
-    if (command_queue_option.has_value() &&
-        command_queue_option->type == kLiteRtAnyTypeInt) {
-      command_queue =
-          reinterpret_cast<cl_command_queue>(command_queue_option->int_value);
-    }
-  }
-  if (device_id && platform_id) {
-    device_ = litert::cl::ClDevice(device_id, platform_id);
-  } else {
-    auto status = litert::cl::CreateDefaultGPUDevice(&device_);
-    if (!status.ok()) {
-      LITERT_LOG(LITERT_ERROR, "Failed to create OpenCL device");
-    }
-  }
-  if (context) {
-    context_ = litert::cl::ClContext(context, /*has_ownership=*/false);
-  } else {
-    auto status = litert::cl::CreateClContext(device_, &context_);
-    if (!status.ok()) {
-      LITERT_LOG(LITERT_ERROR, "Failed to create OpenCL context");
-    }
-  }
-  if (command_queue) {
-    command_queue_ =
-        litert::cl::ClCommandQueue(command_queue, /*has_ownership=*/false);
-  } else {
-    auto status =
-        litert::cl::CreateClCommandQueue(device_, context_, &command_queue_);
-    if (!status.ok()) {
-      LITERT_LOG(LITERT_ERROR, "Failed to create OpenCL command queue");
-    }
-  }
-}
-
-GpuEnvironmentSingleton* GpuEnvironmentSingleton::instance_ = nullptr;
-
-}  // namespace internal
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/runtime/gpu_environment.h b/tensorflow/lite/experimental/litert/runtime/gpu_environment.h
deleted file mode 100644
index 38f4d9215da15d..00000000000000
---
a/tensorflow/lite/experimental/litert/runtime/gpu_environment.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GPU_ENVIRONMENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GPU_ENVIRONMENT_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" - -namespace litert::internal { - -// Inner singleton class that is for storing the MLD global environment. -// This class is used to store OpenCL, OpenGL environment objects. -class GpuEnvironmentSingleton { - public: - GpuEnvironmentSingleton(const GpuEnvironmentSingleton&) = delete; - GpuEnvironmentSingleton& operator=(const GpuEnvironmentSingleton&) = delete; - ~GpuEnvironmentSingleton() = default; - litert::cl::ClDevice* getDevice() { return &device_; } - litert::cl::ClContext* getContext() { return &context_; } - litert::cl::ClCommandQueue* getCommandQueue() { return &command_queue_; } - - static GpuEnvironmentSingleton& GetInstance() { - if (instance_ == nullptr) { - instance_ = new GpuEnvironmentSingleton(nullptr); - } - return *instance_; - } - - // Create the singleton instance with the given environment. - // It will fail if the singleton instance already exists. - static Expected Create( - LiteRtEnvironmentT* environment) { - if (instance_ == nullptr) { - instance_ = new GpuEnvironmentSingleton(environment); - LITERT_LOG(LITERT_INFO, "Created LiteRT EnvironmentSingleton."); - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "EnvironmentSingleton already exists"); - } - return instance_; - } - - private: - // Load the OpenCL device, context and command queue from the environment if - // available. Otherwise, create the default device, context and command queue. - explicit GpuEnvironmentSingleton(LiteRtEnvironmentT* environment); - - litert::cl::ClDevice device_; - litert::cl::ClContext context_; - litert::cl::ClCommandQueue command_queue_; - static GpuEnvironmentSingleton* instance_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GPU_ENVIRONMENT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/gpu_environment_test.cc b/tensorflow/lite/experimental/litert/runtime/gpu_environment_test.cc deleted file mode 100644 index 5bf76f813d95e5..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gpu_environment_test.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2025 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h" - -#include -#include -#include - -#include -#include -#include "third_party/ml_drift/cl/environment.h" -#include "third_party/ml_drift/cl/opencl_wrapper.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace { - -TEST(EnvironmentSingletonTest, OpenClEnvironment) { - // MSAN does not support GPU tests. -#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported in MSAN"; -#endif - - if (!ml_drift::cl::LoadOpenCL().ok()) { - GTEST_SKIP() << "OpenCL not loaded for ml_drift"; - } - if (!litert::cl::LoadOpenCL().ok()) { - GTEST_SKIP() << "OpenCL not loaded for litert"; - } - - ml_drift::cl::Environment env; - ASSERT_OK(ml_drift::cl::CreateEnvironment(&env)); - - const std::array environment_options = { - LiteRtEnvOption{ - /*.tag=*/kLiteRtEnvOptionTagOpenClContext, - /*.value=*/ - *ToLiteRtAny( - std::any(reinterpret_cast(env.context().context()))), - }, - LiteRtEnvOption{ - /*.tag=*/kLiteRtEnvOptionTagOpenClCommandQueue, - /*.value=*/ - *ToLiteRtAny( - std::any(reinterpret_cast(env.queue()->queue()))), - }, - }; - auto litert_envt = LiteRtEnvironmentT::CreateWithOptions(environment_options); - ASSERT_TRUE(litert_envt); - auto singleton_env = - litert::internal::GpuEnvironmentSingleton::Create(litert_envt->get()); - ASSERT_TRUE(singleton_env); - EXPECT_EQ((*singleton_env)->getContext()->context(), env.context().context()); - EXPECT_EQ((*singleton_env)->getCommandQueue()->queue(), env.queue()->queue()); - - // Create another singleton environment should fail. - auto another_singleton_env = - litert::internal::GpuEnvironmentSingleton::Create(litert_envt->get()); - EXPECT_FALSE(another_singleton_env); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/ion_buffer.cc b/tensorflow/lite/experimental/litert/runtime/ion_buffer.cc deleted file mode 100644 index 41a3ee09c82643..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/ion_buffer.cc +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
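// The test above exercises the full contract of GpuEnvironmentSingleton:
// Create() adopts the OpenCL handles passed through LiteRtEnvOption and may
// run only once; every later caller goes through GetInstance(). A minimal
// sketch of that call pattern (InitGpuEnv is a hypothetical helper; `env` is
// assumed to be a valid LiteRtEnvironmentT*):
//
//   void InitGpuEnv(LiteRtEnvironmentT* env) {
//     auto created = litert::internal::GpuEnvironmentSingleton::Create(env);
//     // A second Create() call fails with kLiteRtStatusErrorRuntimeFailure.
//     auto& gpu_env =
//         litert::internal::GpuEnvironmentSingleton::GetInstance();
//     litert::cl::ClCommandQueue* queue = gpu_env.getCommandQueue();
//     (void)created;
//     (void)queue;
//   }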
- -#include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h" - -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_ION_SUPPORT -#include -#include -#endif // LITERT_HAS_ION_SUPPORT - -namespace litert { -namespace internal { - -#if LITERT_HAS_ION_SUPPORT -namespace { - -class IonLibrary { - public: - using Ptr = std::unique_ptr; - - ~IonLibrary() { - if (client_fd_ > 0) { - ion_close_(client_fd_); - } - } - - static Expected Create() { - DlHandle dlhandle(::dlopen("libion.so", RTLD_NOW | RTLD_LOCAL), ::dlclose); - if (!dlhandle) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "libion.so not found"); - } - - auto ion_open = - reinterpret_cast(::dlsym(dlhandle.get(), "ion_open")); - if (!ion_open) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "ion_open not found"); - } - - auto ion_close = - reinterpret_cast(::dlsym(dlhandle.get(), "ion_close")); - if (!ion_close) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "ion_close not found"); - } - - auto ion_alloc_fd = - reinterpret_cast(::dlsym(dlhandle.get(), "ion_alloc_fd")); - if (!ion_alloc_fd) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "ion_alloc_fd not found"); - } - - int client_fd = ion_open(); - if (client_fd < 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to open ion device"); - } - - return Ptr(new IonLibrary(std::move(dlhandle), client_fd, ion_close, - ion_alloc_fd)); - } - - Expected Alloc(size_t size, size_t alignment) { - int heap_id_mask = 1 << kIonHeapId; - int fd; - if (auto status = ion_alloc_fd_(client_fd_, size, alignment, heap_id_mask, - kIonFlags, &fd); - status != 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate DMA-BUF buffer"); - } - void* addr = - ::mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - if (addr == MAP_FAILED) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to mem-map DMA-BUF buffer"); - } - records_[addr] = Record{.fd = fd, .addr = addr, .size = size}; - return IonBuffer{.fd = fd, .addr = addr}; - } - - void Free(void* addr) { - auto iter = records_.find(addr); - if (iter == records_.end()) { - return; - } - auto& record = iter->second; - ::munmap(record.addr, record.size); - ::close(record.fd); - records_.erase(iter); - } - - private: - static constexpr const int kIonHeapId = 25; - static constexpr const int kIonFlags = 1; - - struct Record { - int fd; - void* addr; - size_t size; - }; - - using DlHandle = std::unique_ptr; - using IonOpen = int (*)(); - using IonClose = int (*)(int); - using IonAllocFd = int (*)(int, size_t, size_t, unsigned int, unsigned int, - int*); - - IonLibrary(DlHandle&& dlhandle, int client_fd, IonClose ion_close, - IonAllocFd ion_alloc_fd) - : dlhandle_(std::move(dlhandle)), - client_fd_(client_fd), - ion_close_(ion_close), - ion_alloc_fd_(ion_alloc_fd) {} - - DlHandle dlhandle_; - int client_fd_; - IonClose ion_close_; - IonAllocFd ion_alloc_fd_; - absl::node_hash_map records_; -}; - -IonLibrary* TheIonLibrary; -ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); - -Expected InitLibraryIfNeededUnlocked() { - if (!TheIonLibrary) { - if (auto library = IonLibrary::Create(); library) { - TheIonLibrary = library->release(); - } else { - return 
Unexpected(library.Error()); - } - } - return {}; -} - -} // namespace -#endif // LITERT_HAS_ION_SUPPORT - -bool IonBuffer::IsSupported() { -#if LITERT_HAS_ION_SUPPORT - absl::MutexLock lock(&TheMutex); - auto status = InitLibraryIfNeededUnlocked(); - return static_cast(status); -#else // LITERT_HAS_ION_SUPPORT - return false; -#endif // LITERT_HAS_ION_SUPPORT -} - -Expected IonBuffer::Alloc(size_t size, size_t alignment) { -#if LITERT_HAS_ION_SUPPORT - absl::MutexLock lock(&TheMutex); - if (auto status = InitLibraryIfNeededUnlocked(); !status) { - return status.Error(); - } - return TheIonLibrary->Alloc(size, alignment); -#else // LITERT_HAS_ION_SUPPORT - return Unexpected(kLiteRtStatusErrorUnsupported, - "IonBuffer::Alloc not implemented for this platform"); -#endif // LITERT_HAS_ION_SUPPORT -} - -void IonBuffer::Free(void* addr) { -#if LITERT_HAS_ION_SUPPORT - absl::MutexLock lock(&TheMutex); - if (TheIonLibrary) { - TheIonLibrary->Free(addr); - } -#endif // LITERT_HAS_ION_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/ion_buffer.h b/tensorflow/lite/experimental/litert/runtime/ion_buffer.h deleted file mode 100644 index 38a0b19abdc137..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/ion_buffer.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ION_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ION_BUFFER_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -struct IonBuffer { - int fd; - void* addr; - - static bool IsSupported(); - static Expected Alloc(size_t size, size_t alignment); - static void Free(void* addr); -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ION_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.cc b/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.cc deleted file mode 100644 index d99a9875472981..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.cc +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
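// Usage sketch for the IonBuffer helper deleted above (names as in
// ion_buffer.h; the size/alignment values are illustrative assumptions):
//
//   if (litert::internal::IonBuffer::IsSupported()) {
//     auto buf = litert::internal::IonBuffer::Alloc(/*size=*/4096,
//                                                   /*alignment=*/64);
//     if (buf) {
//       // buf->addr is the mmap'ed CPU view of the allocation; buf->fd is
//       // the ION file descriptor backing it.
//       // ... read/write through buf->addr ...
//       litert::internal::IonBuffer::Free(buf->addr);  // munmap + close(fd)
//     }
//   }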
-
-#include "tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h"
-
-#include
-
-#include
-#include
-#include
-#include
-
-#include "absl/synchronization/mutex.h"
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/buffer.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h"
-
-namespace litert {
-namespace internal {
-
-template Expected OpenClBuffer::Lock();
-template Expected OpenClBuffer::Lock();
-template Expected OpenClBuffer::Unlock();
-template Expected OpenClBuffer::Unlock();
-
-template <typename T>
-Expected<T*> OpenClBuffer::Lock() {
-  absl::MutexLock lock(&mutex_);
-  // The buffer has not been locked, so we need to read from the OpenCL
-  // buffer.
-  if (data_ == nullptr) {
-    litert::cl::ClCommandQueue* queue =
-        GpuEnvironmentSingleton::GetInstance().getCommandQueue();
-    std::vector<T> result;
-    auto status = buffer_.ReadData(queue, &result);
-    if (!status.ok()) {
-      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                        "Failed to read OpenCL buffer");
-    }
-    // Ensure the data is aligned.
-    if (auto rc =
-            posix_memalign(&data_, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, size_);
-        rc) {
-      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
-                        "Failed to allocate aligned memory");
-    }
-    // Copy the data from the OpenCL buffer to the aligned memory.
-    // TODO(piyu): Consider adding support in MLD OpenCL buffer to directly
-    // write to the aligned memory.
-    std::copy(result.begin(), result.end(), static_cast<T*>(data_));
-  }
-  return Expected<T*>(static_cast<T*>(data_));
-}
-
-template <typename T>
-Expected<void> OpenClBuffer::Unlock() {
-  absl::MutexLock lock(&mutex_);
-  litert::cl::ClCommandQueue* queue =
-      GpuEnvironmentSingleton::GetInstance().getCommandQueue();
-  // The buffer has not been locked, so we don't need to write back.
- if (data_ == nullptr) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "Cannot unlock a buffer that wasn't locked in the first place"); - } - size_t write_size = (size_ + sizeof(T) - 1) / sizeof(T); - auto status = buffer_.WriteData( - queue, absl::MakeSpan(static_cast(data_), write_size)); - - if (status.ok()) { - return Expected(); - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "The data failed to write to the OpenCL buffer when unlocked"); -} - -bool OpenClBuffer::IsSupported() { - static bool is_supported = ::litert::cl::LoadOpenCL().ok(); - return is_supported; -} - -Expected OpenClBuffer::Alloc(size_t bytes_size) { - LITERT_RETURN_IF_ERROR( - IsSupported(), - Unexpected(kLiteRtStatusErrorRuntimeFailure, "OpenCL is not supported")); - - litert::cl::Buffer buffer; - - litert::cl::ClContext* cl_context = - GpuEnvironmentSingleton::GetInstance().getContext(); - auto result = - litert::cl::CreateReadWriteBuffer(bytes_size, cl_context, &buffer); - if (!result.ok()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create OpenCL buffer"); - } - - return Expected(std::move(buffer), bytes_size); -} -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h b/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h deleted file mode 100644 index cf5c422afd1adb..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPEN_CL_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPEN_CL_BUFFER_H_ - -#include -#include -#include - -#include "absl/synchronization/mutex.h" -#include -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert::internal { - -/** - * The OpenCL buffer class that provides GPU memory allocation and two-way sync - * between the CPU memory and the GPU OpenCL buffer. - */ -class OpenClBuffer { - public: - OpenClBuffer(OpenClBuffer&& other) { - data_ = other.data_; - buffer_ = std::move(other.buffer_); - size_ = other.size_; - other.data_ = nullptr; - other.size_ = 0; - } - - OpenClBuffer(litert::cl::Buffer buffer, size_t size) - : buffer_(std::move(buffer)), size_(size) {} - - OpenClBuffer(cl_mem buffer, size_t size, LiteRtOpenClDeallocator deallocator) - : deallocator_(deallocator), size_(size) { - if (deallocator_ != nullptr) { - buffer_ = litert::cl::CreateBufferShared(buffer); - } else { // The buffer will be deallocated automatically. 
- buffer_ = litert::cl::Buffer(buffer, size); - } - } - - ~OpenClBuffer() { - if (deallocator_ != nullptr) { - deallocator_(buffer_.GetMemoryPtr()); - } - if (data_ != nullptr) { - free(data_); - }; - } - - cl_mem GetMemoryPtr() { return buffer_.GetMemoryPtr(); } - // Allocates a CPU memory and conducts a copy from the OpenCL buffer to the - // CPU memory. - template - Expected Lock(); - - // Writes the data from the CPU memory to the OpenCL buffer. - template - Expected Unlock(); - - static bool IsSupported(); - static Expected Alloc(size_t bytes_size); - size_t size_bytes() const { return size_; } - - private: - absl::Mutex mutex_; - // The cpu memory buffer pointer. - void* data_ = nullptr; - litert::cl::Buffer buffer_; - LiteRtOpenClDeallocator deallocator_ = nullptr; - // The size of the buffer in bytes. - size_t size_ = 0; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPEN_CL_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/BUILD b/tensorflow/lite/experimental/litert/runtime/opencl/BUILD deleted file mode 100644 index 21b74be48a932e..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/BUILD +++ /dev/null @@ -1,129 +0,0 @@ -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "cl_command_queue", - srcs = [ - "cl_command_queue.cc", - ], - hdrs = [ - "cl_command_queue.h", - ], - deps = [ - ":cl_context", - ":cl_device", - ":opencl_wrapper", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@opencl_headers", - ], -) - -cc_library( - name = "cl_device", - srcs = [ - "cl_device.cc", - ], - hdrs = [ - "cl_device.h", - ], - deps = [ - ":opencl_wrapper", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@opencl_headers", - ], -) - -cc_library( - name = "cl_context", - srcs = [ - "cl_context.cc", - ], - hdrs = [ - "cl_context.h", - ], - deps = [ - ":cl_device", - ":opencl_wrapper", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@opencl_headers", - ], -) - -cc_library( - name = "cl_event", - srcs = [ - "cl_event.cc", - ], - hdrs = [ - "cl_event.h", - ], - deps = [ - ":opencl_wrapper", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings:str_format", - "@opencl_headers", - ], -) - -cc_library( - name = "opencl_wrapper", - srcs = [ - "opencl_wrapper.cc", - ], - hdrs = [ - "opencl_wrapper.h", - ], - visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - "//third_party/odml/infra:__subpackages__", - ], - deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@opencl_headers", - ], -) - -cc_library( - name = "buffer", - srcs = [ - "buffer.cc", - ], - hdrs = [ - "buffer.h", - ], - deps = [ - ":cl_command_queue", - ":cl_context", - ":opencl_wrapper", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@opencl_headers", - ], -) - -cc_test( - name = "buffer_test", - srcs = ["buffer_test.cc"], - # require GPU to run OpenCL tests. 
- tags = [ - "requires-gpu-nvidia", - ], - deps = [ - ":buffer", - ":cl_command_queue", - ":cl_context", - ":cl_device", - ":opencl_wrapper", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc b/tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc deleted file mode 100644 index c2878a4839517a..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is a copy of third_party/ml_drift/cl/buffer.cc. -#include "tensorflow/lite/experimental/litert/runtime/opencl/buffer.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace cl { -absl::Status CreateClBuffer(cl_context context, size_t size_in_bytes, - bool read_only, void* data, cl_mem* result) { - cl_mem_flags flags = read_only ? CL_MEM_READ_ONLY : CL_MEM_READ_WRITE; - if (data) { - flags |= CL_MEM_COPY_HOST_PTR; - } - cl_int error_code; - *result = clCreateBuffer(context, flags, size_in_bytes, data, &error_code); - if (!*result) { - return absl::UnknownError( - absl::StrCat("Failed to allocate device memory (clCreateBuffer): ", - std::to_string(error_code))); - } - return absl::OkStatus(); -} -absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, - const void* data, ClContext* context, - Buffer* result) { - cl_mem buffer; - auto status = CreateClBuffer(context->context(), size_in_bytes, gpu_read_only, - const_cast(data), &buffer); - if (!status.ok()) { - return status; - } - *result = Buffer(buffer, size_in_bytes); - - return absl::OkStatus(); -} - -Buffer::Buffer(cl_mem buffer, size_t size_in_bytes, bool is_sub_buffer) - : buffer_(buffer), size_(size_in_bytes), is_sub_buffer_(is_sub_buffer) {} - -Buffer::Buffer(cl_mem buffer) - : buffer_(buffer), size_(0), is_sub_buffer_(false), owner_(false) {} - -Buffer::Buffer(Buffer&& buffer) - : buffer_(buffer.buffer_), - size_(buffer.size_), - is_sub_buffer_(buffer.is_sub_buffer_), - owner_(buffer.owner_) { - buffer.buffer_ = nullptr; - buffer.size_ = 0; - buffer.is_sub_buffer_ = false; -} - -Buffer& Buffer::operator=(Buffer&& buffer) { - if (this != &buffer) { - Release(); - std::swap(size_, buffer.size_); - std::swap(buffer_, buffer.buffer_); - std::swap(is_sub_buffer_, buffer.is_sub_buffer_); - std::swap(owner_, buffer.owner_); - } - return *this; -} - -void Buffer::Release() { - if (owner_ && buffer_) { - clReleaseMemObject(buffer_); - buffer_ = nullptr; - size_ = 0; - is_sub_buffer_ = false; - } -} - -Buffer CreateBufferShared(cl_mem buffer) { return Buffer(buffer); } - -absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, ClContext* context, 
-                                  Buffer* result) {
-  return CreateBuffer(size_in_bytes, true, nullptr, context, result);
-}
-
-absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data,
-                                  ClContext* context, Buffer* result) {
-  return CreateBuffer(size_in_bytes, true, data, context, result);
-}
-
-absl::Status CreateReadWriteBuffer(size_t size_in_bytes, ClContext* context,
-                                   Buffer* result) {
-  return CreateBuffer(size_in_bytes, false, nullptr, context, result);
-}
-
-}  // namespace cl
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.h b/tensorflow/lite/experimental/litert/runtime/opencl/buffer.h
deleted file mode 100644
index b1cb09f065508f..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.h
+++ /dev/null
@@ -1,125 +0,0 @@
-// Copyright 2024 The TensorFlow Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file is a copy of third_party/ml_drift/cl/buffer.h.
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_BUFFER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_BUFFER_H_
-
-#include
-#include
-#include
-
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h"
-
-namespace litert {
-namespace cl {
-
-// Buffer represents linear GPU data storage with an arbitrary data format.
-// Buffer is movable but not copyable.
-class Buffer {
- public:
-  Buffer() = default;  // just for using Buffer as a class member
-  Buffer(cl_mem buffer, size_t size_in_bytes, bool is_sub_buffer = false);
-  explicit Buffer(cl_mem buffer);
-
-  // Move only
-  Buffer(Buffer&& buffer);
-  Buffer& operator=(Buffer&& buffer);
-  Buffer(const Buffer&) = delete;
-  Buffer& operator=(const Buffer&) = delete;
-
-  ~Buffer() { Release(); }
-
-  // For profiling and memory statistics.
-  uint64_t GetMemorySizeInBytes() const { return size_; }
-
-  cl_mem GetMemoryPtr() const { return buffer_; }
-
-  bool IsSubBuffer() const { return is_sub_buffer_; }
-
-  // Writes data to the buffer. Data should point to a region whose size in
-  // bytes is exactly size_in_bytes (the constructor parameter).
-  template <typename T>
-  absl::Status WriteData(ClCommandQueue* queue, absl::Span<T> data);
-
-  // Reads data from the Buffer into CPU memory.
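  // (ReadData requires size_ to be an exact multiple of sizeof(T); the
  // output vector is resized to size_ / sizeof(T) elements before the read
  // is enqueued.)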
-  template <typename T>
-  absl::Status ReadData(ClCommandQueue* queue, std::vector<T>* result) const;
-
- private:
-  void Release();
-
-  cl_mem buffer_ = nullptr;
-  size_t size_ = 0;
-  bool is_sub_buffer_ = false;
-  bool owner_ = true;
-};
-
-Buffer CreateBufferShared(cl_mem buffer);
-
-absl::Status CreateClBuffer(cl_context context, size_t size_in_bytes,
-                            bool read_only, void* data, cl_mem* result);
-
-absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only,
-                          const void* data, ClContext* context, Buffer* result);
-
-absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, ClContext* context,
-                                  Buffer* result);
-
-absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data,
-                                  ClContext* context, Buffer* result);
-
-absl::Status CreateReadWriteBuffer(size_t size_in_bytes, ClContext* context,
-                                   Buffer* result);
-
-absl::Status CreateReadWriteSubBuffer(const Buffer& parent,
-                                      size_t origin_in_bytes,
-                                      size_t size_in_bytes, ClContext* context,
-                                      Buffer* result);
-
-template <typename T>
-absl::Status Buffer::WriteData(ClCommandQueue* queue,
-                               const absl::Span<T> data) {
-  if (sizeof(T) * data.size() > size_) {
-    return absl::InvalidArgumentError(
-        "absl::Span data size is greater than the buffer's allocated size.");
-  }
-  auto status = queue->EnqueueWriteBuffer(buffer_, size_, data.data());
-  if (!status.ok()) {
-    return status;
-  }
-  return absl::OkStatus();
-}
-
-template <typename T>
-absl::Status Buffer::ReadData(ClCommandQueue* queue,
-                              std::vector<T>* result) const {
-  if (size_ % sizeof(T) != 0) {
-    return absl::UnknownError("Wrong element size (is typename T correct?)");
-  }
-
-  const int elements_count = size_ / sizeof(T);
-  result->resize(elements_count);
-
-  return queue->EnqueueReadBuffer(buffer_, size_, result->data());
-}
-
-}  // namespace cl
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_BUFFER_H_
diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/buffer_test.cc b/tensorflow/lite/experimental/litert/runtime/opencl/buffer_test.cc
deleted file mode 100644
index 84280ef6af23b9..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/opencl/buffer_test.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2024 The ML Drift Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/runtime/opencl/buffer.h"
-
-#include
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h"
-
-using ::testing::FloatNear;
-using ::testing::Pointwise;
-
-namespace litert {
-namespace internal {
-
-TEST(OpenCLTest, BufferTestFloat) {
-  // MSAN does not support GPU tests.
-#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER)
-  GTEST_SKIP() << "GPU tests are not supported in MSAN";
-#endif
-
-  if (!litert::cl::LoadOpenCL().ok()) {
-    GTEST_SKIP() << "OpenCL buffers are not supported on this platform; "
-                    "skipping the test";
-  }
-  const std::vector<float> data = {1.0, 2.0, 3.0, -4.0, 5.1};
-  litert::cl::Buffer buffer;
-  litert::cl::ClContext context;
-  litert::cl::ClDevice device;
-  litert::cl::ClCommandQueue queue;
-  ASSERT_TRUE(CreateDefaultGPUDevice(&device).ok());
-  ASSERT_TRUE(CreateClContext(device, &context).ok());
-  ASSERT_TRUE(CreateClCommandQueue(device, context, &queue).ok());
-  ASSERT_TRUE(CreateReadWriteBuffer(sizeof(float) * 5, &context, &buffer).ok());
-  ASSERT_TRUE(
-      buffer.WriteData(&queue, absl::MakeConstSpan(data.data(), data.size()))
-          .ok());
-  std::vector<float> gpu_data;
-  ASSERT_TRUE(buffer.ReadData(&queue, &gpu_data).ok());
-
-  EXPECT_THAT(gpu_data, Pointwise(FloatNear(0.0f), data));
-}
-}  // namespace internal
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc
deleted file mode 100644
index 278862c3f87d20..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc
+++ /dev/null
@@ -1,141 +0,0 @@
-// Copyright 2024 The TensorFlow Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file is a copy of third_party/ml_drift/cl/cl_command_queue.cc.
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h"
-
-#include
-#include
-#include
-#include
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h"
-#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h"
-
-namespace litert {
-namespace cl {
-namespace {
-
-absl::StatusOr<cl_command_queue> CreateClCommandQueueWithProperties(
-    const ClDevice& device, const ClContext& context,
-    cl_command_queue_properties queue_properties) {
-  int error_code;
-  cl_command_queue queue;
-  if (clCreateCommandQueueWithProperties) {
-    std::vector<cl_queue_properties> props;
-    if (queue_properties != 0) {
-      props.push_back(CL_QUEUE_PROPERTIES);
-      props.push_back(queue_properties);
-    }
-    props.push_back(0);
-
-    queue = clCreateCommandQueueWithProperties(context.context(), device.id(),
-                                               props.data(), &error_code);
-  } else {
-    // Backwards compatibility for OpenCL versions before 2.0.
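    // (clCreateCommandQueue was deprecated in OpenCL 2.0; this fallback is
    // only reached when the loaded ICD does not export
    // clCreateCommandQueueWithProperties.)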
- queue = clCreateCommandQueue(context.context(), device.id(), - queue_properties, &error_code); - } - if (!queue) { - return absl::UnknownError(absl::StrCat( - "Failed to create a command queue - ", std::to_string(error_code))); - } - return queue; -} - -} // namespace - -ClCommandQueue::ClCommandQueue() = default; - -ClCommandQueue::ClCommandQueue(cl_command_queue queue, bool has_ownership) - : queue_(queue), has_ownership_(has_ownership) {} - -ClCommandQueue::ClCommandQueue(ClCommandQueue&& queue) - : queue_(queue.queue_), has_ownership_(queue.has_ownership_) { - queue.queue_ = nullptr; -} - -ClCommandQueue& ClCommandQueue::operator=(ClCommandQueue&& queue) { - if (this != &queue) { - Release(); - std::swap(queue_, queue.queue_); - has_ownership_ = queue.has_ownership_; - } - return *this; -} - -ClCommandQueue::~ClCommandQueue() { Release(); } - -void ClCommandQueue::Release() { - if (has_ownership_ && queue_) { - clReleaseCommandQueue(queue_); - queue_ = nullptr; - } -} - -absl::Status ClCommandQueue::EnqueueWriteBuffer(cl_mem memory, - size_t size_in_bytes, - const void* data, bool async) { - const cl_bool blocking = async ? CL_FALSE : CL_TRUE; - auto error_code = clEnqueueWriteBuffer( - queue_, memory, blocking, 0, size_in_bytes, data, 0, nullptr, nullptr); - if (error_code != CL_SUCCESS) { - return absl::UnknownError( - absl::StrCat("Failed to upload data to GPU (clEnqueueWriteBuffer) - ", - std::to_string(error_code))); - } - return absl::OkStatus(); -} - -absl::Status ClCommandQueue::EnqueueReadBuffer(cl_mem memory, - size_t size_in_bytes, void* data, - bool async) { - const cl_bool blocking = async ? CL_FALSE : CL_TRUE; - auto error_code = clEnqueueReadBuffer( - queue_, memory, blocking, 0, size_in_bytes, data, 0, nullptr, nullptr); - if (error_code != CL_SUCCESS) { - return absl::UnknownError( - absl::StrCat("Failed to read data from GPU (clEnqueueReadBuffer) - ", - std::to_string(error_code))); - } - return absl::OkStatus(); -} - -absl::Status ClCommandQueue::WaitForCompletion() { - auto error_code = clFinish(queue_); - if (error_code != CL_SUCCESS) { - return absl::UnknownError( - absl::StrCat("Failed to clFinish - ", std::to_string(error_code))); - } - return absl::OkStatus(); -} - -absl::Status CreateClCommandQueue(const ClDevice& device, - const ClContext& context, - ClCommandQueue* result) { - auto queue = CreateClCommandQueueWithProperties(device, context, 0); - if (!queue.ok()) { - return queue.status(); - } - *result = ClCommandQueue(*queue, true); - return absl::OkStatus(); -} - -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h deleted file mode 100644 index a7691d52e6c65b..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
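// The enqueue helpers above map async=false to a blocking OpenCL call
// (CL_TRUE) and async=true to a non-blocking one, in which case the host
// pointer must stay valid until the queue is drained. A short sketch of that
// contract (`queue`, `mem`, `host`, and `n` are assumed to be set up
// elsewhere):
//
//   absl::Status st = queue.EnqueueWriteBuffer(mem, n, host, /*async=*/true);
//   if (st.ok()) st = queue.WaitForCompletion();  // clFinish drains the queue
//   if (st.ok()) st = queue.EnqueueReadBuffer(mem, n, host);  // blocking read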
- -// This file is a copy of third_party/ml_drift/cl/cl_command_queue.h. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_COMMAND_QUEUE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_COMMAND_QUEUE_H_ - -#include -#include - -#include "absl/status/status.h" -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" - -namespace litert { -namespace cl { - -// A wrapper around opencl command queue -class ClCommandQueue { - public: - ClCommandQueue(); - ClCommandQueue(cl_command_queue queue, bool has_ownership); - - // Move only - ClCommandQueue(ClCommandQueue&& queue); - ClCommandQueue& operator=(ClCommandQueue&& queue); - ClCommandQueue(const ClCommandQueue&) = delete; - ClCommandQueue& operator=(const ClCommandQueue&) = delete; - - virtual ~ClCommandQueue(); - - cl_command_queue queue() const { return queue_; } - - absl::Status EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes, - const void* data, bool async = false); - absl::Status EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, - void* data, bool async = false); - - absl::Status WaitForCompletion(); - - protected: - void Release(); - - cl_command_queue queue_ = nullptr; - bool has_ownership_ = false; -}; - -class ProfilingCommandQueue : public ClCommandQueue { - public: - ProfilingCommandQueue(); - explicit ProfilingCommandQueue(cl_command_queue queue); - - // Move only - ProfilingCommandQueue(ProfilingCommandQueue&& queue); - ProfilingCommandQueue& operator=(ProfilingCommandQueue&& queue); - ProfilingCommandQueue(const ProfilingCommandQueue&) = delete; - ProfilingCommandQueue& operator=(const ProfilingCommandQueue&) = delete; - - private: - std::string current_label_; -}; - -absl::Status CreateClCommandQueue(const ClDevice& device, - const ClContext& context, - ClCommandQueue* result); - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_COMMAND_QUEUE_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc deleted file mode 100644 index 5eb5f4949d37f8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
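// The has_ownership flag in the wrapper declared above controls whether
// Release() calls clReleaseCommandQueue; this is how gpu_environment.cc
// adopts a caller-provided queue without taking it over. Sketch
// (`external_queue` is assumed to be a live cl_command_queue owned by the
// caller):
//
//   litert::cl::ClCommandQueue wrapped(external_queue,
//                                      /*has_ownership=*/false);
//   // ... use wrapped.queue() ...
//   // Destruction releases nothing; the caller still owns external_queue.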
- -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" - -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace cl { -namespace { - -absl::Status CreateClContext(const ClDevice& device, - const std::vector& props, - ClContext* result) { - int error_code; - cl_device_id device_id = device.id(); - std::vector props_local = props; - if (!props_local.empty()) { - props_local.push_back(0); - } - cl_context_properties* properties_ptr = - props_local.empty() ? nullptr : props_local.data(); - cl_context context = clCreateContext(properties_ptr, 1, &device_id, nullptr, - nullptr, &error_code); - if (!context) { - return absl::UnknownError( - absl::StrCat("Failed to create a compute context - ", error_code)); - } - - *result = ClContext(context, true); - return absl::OkStatus(); -} - -} // namespace - -ClContext::ClContext() = default; - -ClContext::ClContext(cl_context context, bool has_ownership) - : context_(context), has_ownership_(has_ownership) {} - -ClContext::ClContext(cl_context context, bool has_ownership, ClDevice& device) - : context_(context), has_ownership_(has_ownership) {} - -ClContext::ClContext(ClContext&& context) - : context_(context.context_), has_ownership_(context.has_ownership_) { - context.context_ = nullptr; -} - -ClContext& ClContext::operator=(ClContext&& context) { - if (this != &context) { - Release(); - std::swap(context_, context.context_); - has_ownership_ = context.has_ownership_; - } - return *this; -} - -ClContext::~ClContext() { Release(); } - -void ClContext::Release() { - if (has_ownership_ && context_) { - clReleaseContext(context_); - context_ = nullptr; - } -} - -absl::Status CreateClContext(const ClDevice& device, ClContext* result) { - std::vector props; - return CreateClContext(device, props, result); -} - -absl::Status CreateClGlContext(const ClDevice& device, - cl_context_properties egl_context, - cl_context_properties egl_display, - ClContext* result) { - cl_context_properties platform = - reinterpret_cast(device.platform()); - - std::vector props = {CL_GL_CONTEXT_KHR, egl_context, - CL_EGL_DISPLAY_KHR, egl_display, - CL_CONTEXT_PLATFORM, platform}; - - return CreateClContext(device, props, result); -} - -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h deleted file mode 100644 index 880e42b7c4a5c1..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
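// CreateClGlContext above passes the EGL handles through as
// cl_context_properties (CL_GL_CONTEXT_KHR / CL_EGL_DISPLAY_KHR /
// CL_CONTEXT_PLATFORM). A sketch of a caller with a current EGL context
// (using eglGetCurrentContext()/eglGetCurrentDisplay() here is an
// assumption, not something this file prescribes):
//
//   litert::cl::ClContext interop_context;
//   absl::Status st = litert::cl::CreateClGlContext(
//       device,
//       reinterpret_cast<cl_context_properties>(eglGetCurrentContext()),
//       reinterpret_cast<cl_context_properties>(eglGetCurrentDisplay()),
//       &interop_context);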
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_CONTEXT_H_ - -#include "absl/status/status.h" -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" - -namespace litert { -namespace cl { - -// A RAII wrapper around opencl context -class ClContext { - public: - ClContext(); - ClContext(cl_context context, bool has_ownership); - ClContext(cl_context context, bool has_ownership, ClDevice& device); - // Move only - ClContext(ClContext&& context); - ClContext& operator=(ClContext&& context); - ClContext(const ClContext&) = delete; - ClContext& operator=(const ClContext&) = delete; - - ~ClContext(); - - cl_context context() const { return context_; } - - private: - void Release(); - - cl_context context_ = nullptr; - bool has_ownership_ = false; -}; - -absl::Status CreateClContext(const ClDevice& device, ClContext* result); -absl::Status CreateClGlContext(const ClDevice& device, - cl_context_properties egl_context, - cl_context_properties egl_display, - ClContext* result); - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc deleted file mode 100644 index 5677e50927a3c8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// this is a copy of ml_drift/cl/cl_device.cc -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" - -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace cl { - -ClDevice::ClDevice(cl_device_id id, cl_platform_id platform_id) - : id_(id), platform_id_(platform_id) {} - -ClDevice::ClDevice(const ClDevice& device) = default; - -ClDevice& ClDevice::operator=(const ClDevice& device) { - if (this != &device) { - id_ = device.id_; - platform_id_ = device.platform_id_; - } - return *this; -} - -ClDevice::ClDevice(ClDevice&& device) - : id_(device.id_), platform_id_(device.platform_id_) { - device.id_ = nullptr; - device.platform_id_ = nullptr; -} - -ClDevice& ClDevice::operator=(ClDevice&& device) { - if (this != &device) { - id_ = nullptr; - platform_id_ = nullptr; - std::swap(id_, device.id_); - std::swap(platform_id_, device.platform_id_); - } - return *this; -} - -absl::Status CreateDefaultGPUDevice(ClDevice* result) { - cl_uint num_platforms; - cl_int status = clGetPlatformIDs(0, nullptr, &num_platforms); - if (status != CL_SUCCESS) { - return absl::UnknownError( - absl::StrFormat("clGetPlatformIDs returned %d", status)); - } - if (num_platforms == 0) { - return absl::UnknownError("No supported OpenCL platform."); - } - std::vector platforms(num_platforms); - status = clGetPlatformIDs(num_platforms, platforms.data(), nullptr); - if (status != CL_SUCCESS) { - return absl::UnknownError( - absl::StrFormat("clGetPlatformIDs returned %d", status)); - } - - cl_platform_id platform_id = platforms[0]; - cl_uint num_devices; - status = - clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices); - if (status != CL_SUCCESS) { - return absl::UnknownError( - absl::StrFormat("clGetDeviceIDs returned %d", status)); - } - if (num_devices == 0) { - return absl::UnknownError("No GPU on current platform."); - } - - std::vector devices(num_devices); - status = clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, num_devices, - devices.data(), nullptr); - if (status != CL_SUCCESS) { - return absl::UnknownError( - absl::StrFormat("clGetDeviceIDs returned %d", status)); - } - - *result = ClDevice(devices[0], platform_id); - LoadOpenCLFunctionExtensions(platform_id); - return absl::OkStatus(); -} - -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h deleted file mode 100644 index 71d93e64ace879..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2024 The ML Drift Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_DEVICE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_DEVICE_H_ - -#include - -#include "absl/status/status.h" -#include -#include - -namespace litert { -namespace cl { - -// A wrapper around opencl device id -class ClDevice { - public: - ClDevice() = default; - ClDevice(cl_device_id id, cl_platform_id platform_id); - - ClDevice(ClDevice&& device); - ClDevice& operator=(ClDevice&& device); - ClDevice(const ClDevice&); - ClDevice& operator=(const ClDevice&); - - ~ClDevice() = default; - - cl_device_id id() const { return id_; } - cl_platform_id platform() const { return platform_id_; } - std::string GetPlatformVersion() const; - - private: - cl_device_id id_ = nullptr; - cl_platform_id platform_id_ = nullptr; -}; - -absl::Status CreateDefaultGPUDevice(ClDevice* result); - -template -T GetDeviceInfo(cl_device_id id, cl_device_info info) { - T result; - cl_int error = clGetDeviceInfo(id, info, sizeof(T), &result, nullptr); - if (error != CL_SUCCESS) { - return {}; - } - return result; -} - -template -absl::Status GetDeviceInfo(cl_device_id id, cl_device_info info, T* result) { - cl_int error = clGetDeviceInfo(id, info, sizeof(T), result, nullptr); - if (error != CL_SUCCESS) { - return absl::InvalidArgumentError("cl error:" + std::to_string(error)); - } - return absl::OkStatus(); -} - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_DEVICE_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.cc deleted file mode 100644 index 4fd14a130b9dec..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
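// The GetDeviceInfo helpers above handle fixed-size query results only (they
// pass sizeof(T) to clGetDeviceInfo), and the value-returning overload
// swallows the error by returning a default-constructed T. Sketch:
//
//   cl_uint cus = litert::cl::GetDeviceInfo<cl_uint>(
//       device.id(), CL_DEVICE_MAX_COMPUTE_UNITS);
//   cl_ulong mem_size = 0;
//   absl::Status st = litert::cl::GetDeviceInfo(
//       device.id(), CL_DEVICE_GLOBAL_MEM_SIZE, &mem_size);
//   // String queries (e.g. CL_DEVICE_NAME) need the two-step, size-query
//   // form of clGetDeviceInfo instead.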
- -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h" - -#include "absl/strings/str_format.h" -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace cl { - -Expected WaitForEvents(int num_events, const cl_event* event_list) { - cl_int res = clWaitForEvents(num_events, event_list); - if (res != CL_SUCCESS) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("clWaitForEvents fails with error code %d", res)); - } - return {}; -} - -Expected SetUserEventStatus(cl_event event) { - cl_int res = clSetUserEventStatus(event, CL_COMPLETE); - if (res != CL_SUCCESS) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("clSetUserEventStatus fails with error code %d", res)); - } - return {}; -} - -Expected CreateUserEvent(cl_context context) { - cl_int res; - cl_event user_event = clCreateUserEvent(context, &res); - if (res != CL_SUCCESS) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("clCreateUserEvent fails with error code %d", res)); - } - return user_event; -} - -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h deleted file mode 100644 index 1ba38b99a90555..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_EVENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_EVENT_H_ - -#include -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { -namespace cl { - -Expected WaitForEvents(int num_events, const cl_event* event_list); - -Expected SetUserEventStatus(cl_event event); - -Expected CreateUserEvent(cl_context context); - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc b/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc deleted file mode 100644 index 79c4e33e2eb72f..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc +++ /dev/null @@ -1,470 +0,0 @@ -// Copyright 2024 The Tensorflow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc b/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc
deleted file mode 100644
index 79c4e33e2eb72f..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc
+++ /dev/null
@@ -1,470 +0,0 @@
-// Copyright 2024 The Tensorflow Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file is copied from third_party/ml_drift/cl/opencl_wrapper.cc.
-#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h"
-
-#if defined(_WIN32)
-#define __WINDOWS__
-#endif
-
-#ifdef __WINDOWS__
-#include <windows.h>
-#else
-#include <dlfcn.h>
-#endif
-
-#include <string>
-
-#include "absl/strings/str_cat.h"
-
-namespace litert {
-namespace cl {
-
-#ifdef __ANDROID__
-#define LoadFunction(function)                                                 \
-  if (use_wrapper) {                                                           \
-    function = reinterpret_cast<PFN_##function>(loadOpenCLPointer(#function)); \
-  } else {                                                                     \
-    function = reinterpret_cast<PFN_##function>(dlsym(libopencl, #function));  \
-  }
-
-namespace {
-
-// Loads a library from Android SP-HAL namespace which includes libraries from
-// the path /vendor/lib[64] directly and several sub-folders in it.
-// First tries using dlopen(), which should work if the process is running with
-// linker namespace "sphal" (so has permissions to sphal paths).
-// If it fails, for example if process is running with linker default namespace
-// because it's a sub-process of the app, then tries loading the library using
-// a sphal helper loader function from Vendor NDK support library.
-void* AndroidDlopenSphalLibrary(const char* filename, int dlopen_flags) {
-  void* lib = dlopen(filename, dlopen_flags);
-  if (lib != nullptr) {
-    return lib;
-  }
-  static void* (*android_load_sphal_library)(const char*, int) = nullptr;
-  if (android_load_sphal_library != nullptr) {
-    return android_load_sphal_library(filename, dlopen_flags);
-  }
-  android_load_sphal_library =
-      reinterpret_cast<void* (*)(const char*, int)>(
-          dlsym(RTLD_NEXT, "android_load_sphal_library"));
-  if (android_load_sphal_library == nullptr) {
-    void* vndk = dlopen("libvndksupport.so", RTLD_NOW);
-    if (vndk != nullptr) {
-      android_load_sphal_library =
-          reinterpret_cast<void* (*)(const char*, int)>(
-              dlsym(vndk, "android_load_sphal_library"));
-    }
-    if (android_load_sphal_library == nullptr) {
-      return nullptr;
-    }
-  }
-  return android_load_sphal_library(filename, dlopen_flags);
-}
-
-}  // namespace
-
-#elif defined(__WINDOWS__)
-#define LoadFunction(function) \
-  function =                   \
-      reinterpret_cast<PFN_##function>(GetProcAddress(libopencl, #function));
-#else
-#define LoadFunction(function) \
-  function = reinterpret_cast<PFN_##function>(dlsym(libopencl, #function));
-#endif
-
-#define LoadFunctionExtension(plat_id, function) \
-  function = reinterpret_cast<PFN_##function>(   \
-      clGetExtensionFunctionAddressForPlatform(plat_id, #function));
-
-#ifdef __WINDOWS__
-void LoadOpenCLFunctions(HMODULE libopencl);
-#else
-void LoadOpenCLFunctions(void* libopencl, bool use_wrapper);
-#endif
-
-absl::Status LoadOpenCL() {
-#ifdef __WINDOWS__
-  HMODULE libopencl = LoadLibraryA("OpenCL.dll");
-  if (libopencl) {
-    LoadOpenCLFunctions(libopencl);
-    return absl::OkStatus();
-  } else {
-    DWORD error_code = GetLastError();
-    return absl::UnknownError(absl::StrCat(
-        "Can not open OpenCL library on this device, error code - ",
-        error_code));
-  }
-#else
-  void* libopencl = nullptr;
-#ifdef __APPLE__
-  static const char* kClLibName =
-      "/System/Library/Frameworks/OpenCL.framework/OpenCL";
-#else
-  static const char* kClLibName = "libOpenCL.so";
-#endif
-#ifdef __ANDROID__
-  libopencl = AndroidDlopenSphalLibrary(kClLibName, RTLD_NOW | RTLD_LOCAL);
-  if (!libopencl) {
-    // Legacy Pixel phone or auto path?
-    libopencl =
-        AndroidDlopenSphalLibrary("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL);
-    if (!libopencl) {
-      libopencl =
-          AndroidDlopenSphalLibrary("libOpenCL-car.so", RTLD_NOW | RTLD_LOCAL);
-    }
-    if (libopencl) {
-      typedef void (*enableOpenCL_t)();
-      enableOpenCL_t enableOpenCL =
-          reinterpret_cast<enableOpenCL_t>(dlsym(libopencl, "enableOpenCL"));
-      enableOpenCL();
-      LoadOpenCLFunctions(libopencl, true);
-      return absl::OkStatus();
-    }
-  }
-#else
-  libopencl = dlopen(kClLibName, RTLD_NOW | RTLD_LOCAL);
-#endif
-  if (libopencl) {
-    LoadOpenCLFunctions(libopencl, false);
-    return absl::OkStatus();
-  }
-  // record error
-  std::string error(dlerror());
-
-  // Check if OpenCL functions are found via OpenCL ICD Loader.
-  LoadOpenCLFunctions(libopencl, /*use_wrapper=*/false);
-  if (clGetPlatformIDs != nullptr) {
-    cl_uint num_platforms;
-    cl_int status = clGetPlatformIDs(0, nullptr, &num_platforms);
-    if (status == CL_SUCCESS && num_platforms != 0) {
-      return absl::OkStatus();
-    }
-    return absl::UnknownError("OpenCL is not supported.");
-  }
-  return absl::UnknownError(
-      absl::StrCat("Can not open OpenCL library on this device - ", error));
-#endif
-}
-
-void LoadOpenCLFunctionExtensions(cl_platform_id platform_id) {
-  // cl_khr_command_buffer extension
-  LoadFunctionExtension(platform_id, clCreateCommandBufferKHR);
-  LoadFunctionExtension(platform_id, clRetainCommandBufferKHR);
-  LoadFunctionExtension(platform_id, clReleaseCommandBufferKHR);
-  LoadFunctionExtension(platform_id, clFinalizeCommandBufferKHR);
-  LoadFunctionExtension(platform_id, clEnqueueCommandBufferKHR);
-  LoadFunctionExtension(platform_id, clCommandNDRangeKernelKHR);
-  LoadFunctionExtension(platform_id, clGetCommandBufferInfoKHR);
-}
-
-#ifdef __WINDOWS__
-void LoadOpenCLFunctions(HMODULE libopencl) {
-#else
-void LoadOpenCLFunctions(void* libopencl, bool use_wrapper) {
-#ifdef __ANDROID__
-  typedef void* (*loadOpenCLPointer_t)(const char* name);
-  loadOpenCLPointer_t loadOpenCLPointer;
-  if (use_wrapper) {
-    loadOpenCLPointer = reinterpret_cast<loadOpenCLPointer_t>(
-        dlsym(libopencl, "loadOpenCLPointer"));
-  }
-#endif
-#endif
-
-  LoadFunction(clGetPlatformIDs);
-  LoadFunction(clGetPlatformInfo);
-  LoadFunction(clGetDeviceIDs);
-  LoadFunction(clGetDeviceInfo);
-  LoadFunction(clCreateSubDevices);
-  LoadFunction(clRetainDevice);
-  LoadFunction(clReleaseDevice);
-  LoadFunction(clCreateContext);
-  LoadFunction(clCreateContextFromType);
-  LoadFunction(clRetainContext);
-  LoadFunction(clReleaseContext);
-  LoadFunction(clGetContextInfo);
-  LoadFunction(clCreateCommandQueueWithProperties);
-  LoadFunction(clRetainCommandQueue);
-  LoadFunction(clReleaseCommandQueue);
-  LoadFunction(clGetCommandQueueInfo);
-  LoadFunction(clCreateBuffer);
-  LoadFunction(clCreateSubBuffer);
-  LoadFunction(clCreateImage);
-  LoadFunction(clCreatePipe);
-  LoadFunction(clRetainMemObject);
-  LoadFunction(clReleaseMemObject);
-  LoadFunction(clGetSupportedImageFormats);
-  LoadFunction(clGetMemObjectInfo);
-  LoadFunction(clGetImageInfo);
-  LoadFunction(clGetPipeInfo);
-  LoadFunction(clSetMemObjectDestructorCallback);
-  LoadFunction(clSVMAlloc);
-  LoadFunction(clSVMFree);
-  LoadFunction(clCreateSamplerWithProperties);
-  LoadFunction(clRetainSampler);
-  LoadFunction(clReleaseSampler);
-  LoadFunction(clGetSamplerInfo);
-  LoadFunction(clCreateProgramWithSource);
-  LoadFunction(clCreateProgramWithBinary);
-  LoadFunction(clCreateProgramWithBuiltInKernels);
-  LoadFunction(clRetainProgram);
-  LoadFunction(clReleaseProgram);
- LoadFunction(clBuildProgram); - LoadFunction(clCompileProgram); - LoadFunction(clLinkProgram); - LoadFunction(clUnloadPlatformCompiler); - LoadFunction(clGetProgramInfo); - LoadFunction(clGetProgramBuildInfo); - LoadFunction(clCreateKernel); - LoadFunction(clCreateKernelsInProgram); - LoadFunction(clRetainKernel); - LoadFunction(clReleaseKernel); - LoadFunction(clSetKernelArg); - LoadFunction(clSetKernelArgSVMPointer); - LoadFunction(clSetKernelExecInfo); - LoadFunction(clGetKernelInfo); - LoadFunction(clGetKernelArgInfo); - LoadFunction(clGetKernelWorkGroupInfo); - LoadFunction(clWaitForEvents); - LoadFunction(clGetEventInfo); - LoadFunction(clCreateUserEvent); - LoadFunction(clRetainEvent); - LoadFunction(clReleaseEvent); - LoadFunction(clSetUserEventStatus); - LoadFunction(clSetEventCallback); - LoadFunction(clGetEventProfilingInfo); - LoadFunction(clFlush); - LoadFunction(clFinish); - LoadFunction(clEnqueueReadBuffer); - LoadFunction(clEnqueueReadBufferRect); - LoadFunction(clEnqueueWriteBuffer); - LoadFunction(clEnqueueWriteBufferRect); - LoadFunction(clEnqueueFillBuffer); - LoadFunction(clEnqueueCopyBuffer); - LoadFunction(clEnqueueCopyBufferRect); - LoadFunction(clEnqueueReadImage); - LoadFunction(clEnqueueWriteImage); - LoadFunction(clEnqueueFillImage); - LoadFunction(clEnqueueCopyImage); - LoadFunction(clEnqueueCopyImageToBuffer); - LoadFunction(clEnqueueCopyBufferToImage); - LoadFunction(clEnqueueMapBuffer); - LoadFunction(clEnqueueMapImage); - LoadFunction(clEnqueueUnmapMemObject); - LoadFunction(clEnqueueMigrateMemObjects); - LoadFunction(clEnqueueNDRangeKernel); - LoadFunction(clEnqueueNativeKernel); - LoadFunction(clEnqueueMarkerWithWaitList); - LoadFunction(clEnqueueBarrierWithWaitList); - LoadFunction(clEnqueueSVMFree); - LoadFunction(clEnqueueSVMMemcpy); - LoadFunction(clEnqueueSVMMemFill); - LoadFunction(clEnqueueSVMMap); - LoadFunction(clEnqueueSVMUnmap); - LoadFunction(clGetExtensionFunctionAddressForPlatform); - LoadFunction(clCreateImage2D); - LoadFunction(clCreateImage3D); - LoadFunction(clEnqueueMarker); - LoadFunction(clEnqueueWaitForEvents); - LoadFunction(clEnqueueBarrier); - LoadFunction(clUnloadCompiler); - LoadFunction(clGetExtensionFunctionAddress); - LoadFunction(clCreateCommandQueue); - LoadFunction(clCreateSampler); - LoadFunction(clEnqueueTask); - - // OpenGL sharing - LoadFunction(clCreateFromGLBuffer); - LoadFunction(clCreateFromGLTexture); - LoadFunction(clEnqueueAcquireGLObjects); - LoadFunction(clEnqueueReleaseGLObjects); - - // cl_khr_egl_event extension - LoadFunction(clCreateEventFromEGLSyncKHR); - - // EGL sharing - LoadFunction(clCreateFromEGLImageKHR); - LoadFunction(clEnqueueAcquireEGLObjectsKHR); - LoadFunction(clEnqueueReleaseEGLObjectsKHR); - - // OpenCL 3.0 - LoadFunction(clCreateBufferWithProperties); - LoadFunction(clCreateImageWithProperties); -} - -// No OpenCL support, do not set function addresses -PFN_clGetPlatformIDs clGetPlatformIDs; -PFN_clGetPlatformInfo clGetPlatformInfo; -PFN_clGetDeviceIDs clGetDeviceIDs; -PFN_clGetDeviceInfo clGetDeviceInfo; -PFN_clCreateSubDevices clCreateSubDevices; -PFN_clRetainDevice clRetainDevice; -PFN_clReleaseDevice clReleaseDevice; -PFN_clCreateContext clCreateContext; -PFN_clCreateContextFromType clCreateContextFromType; -PFN_clRetainContext clRetainContext; -PFN_clReleaseContext clReleaseContext; -PFN_clGetContextInfo clGetContextInfo; -PFN_clCreateCommandQueueWithProperties clCreateCommandQueueWithProperties; -PFN_clRetainCommandQueue clRetainCommandQueue; -PFN_clReleaseCommandQueue 
clReleaseCommandQueue; -PFN_clGetCommandQueueInfo clGetCommandQueueInfo; -PFN_clCreateBuffer clCreateBuffer; -PFN_clCreateSubBuffer clCreateSubBuffer; -PFN_clCreateImage clCreateImage; -PFN_clCreatePipe clCreatePipe; -PFN_clRetainMemObject clRetainMemObject; -PFN_clReleaseMemObject clReleaseMemObject; -PFN_clGetSupportedImageFormats clGetSupportedImageFormats; -PFN_clGetMemObjectInfo clGetMemObjectInfo; -PFN_clGetImageInfo clGetImageInfo; -PFN_clGetPipeInfo clGetPipeInfo; -PFN_clSetMemObjectDestructorCallback clSetMemObjectDestructorCallback; -PFN_clSVMAlloc clSVMAlloc; -PFN_clSVMFree clSVMFree; -PFN_clCreateSamplerWithProperties clCreateSamplerWithProperties; -PFN_clRetainSampler clRetainSampler; -PFN_clReleaseSampler clReleaseSampler; -PFN_clGetSamplerInfo clGetSamplerInfo; -PFN_clCreateProgramWithSource clCreateProgramWithSource; -PFN_clCreateProgramWithBinary clCreateProgramWithBinary; -PFN_clCreateProgramWithBuiltInKernels clCreateProgramWithBuiltInKernels; -PFN_clRetainProgram clRetainProgram; -PFN_clReleaseProgram clReleaseProgram; -PFN_clBuildProgram clBuildProgram; -PFN_clCompileProgram clCompileProgram; -PFN_clLinkProgram clLinkProgram; -PFN_clUnloadPlatformCompiler clUnloadPlatformCompiler; -PFN_clGetProgramInfo clGetProgramInfo; -PFN_clGetProgramBuildInfo clGetProgramBuildInfo; -PFN_clCreateKernel clCreateKernel; -PFN_clCreateKernelsInProgram clCreateKernelsInProgram; -PFN_clRetainKernel clRetainKernel; -PFN_clReleaseKernel clReleaseKernel; -PFN_clSetKernelArg clSetKernelArg; -PFN_clSetKernelArgSVMPointer clSetKernelArgSVMPointer; -PFN_clSetKernelExecInfo clSetKernelExecInfo; -PFN_clGetKernelInfo clGetKernelInfo; -PFN_clGetKernelArgInfo clGetKernelArgInfo; -PFN_clGetKernelWorkGroupInfo clGetKernelWorkGroupInfo; -PFN_clWaitForEvents clWaitForEvents; -PFN_clGetEventInfo clGetEventInfo; -PFN_clCreateUserEvent clCreateUserEvent; -PFN_clRetainEvent clRetainEvent; -PFN_clReleaseEvent clReleaseEvent; -PFN_clSetUserEventStatus clSetUserEventStatus; -PFN_clSetEventCallback clSetEventCallback; -PFN_clGetEventProfilingInfo clGetEventProfilingInfo; -PFN_clFlush clFlush; -PFN_clFinish clFinish; -PFN_clEnqueueReadBuffer clEnqueueReadBuffer; -PFN_clEnqueueReadBufferRect clEnqueueReadBufferRect; -PFN_clEnqueueWriteBuffer clEnqueueWriteBuffer; -PFN_clEnqueueWriteBufferRect clEnqueueWriteBufferRect; -PFN_clEnqueueFillBuffer clEnqueueFillBuffer; -PFN_clEnqueueCopyBuffer clEnqueueCopyBuffer; -PFN_clEnqueueCopyBufferRect clEnqueueCopyBufferRect; -PFN_clEnqueueReadImage clEnqueueReadImage; -PFN_clEnqueueWriteImage clEnqueueWriteImage; -PFN_clEnqueueFillImage clEnqueueFillImage; -PFN_clEnqueueCopyImage clEnqueueCopyImage; -PFN_clEnqueueCopyImageToBuffer clEnqueueCopyImageToBuffer; -PFN_clEnqueueCopyBufferToImage clEnqueueCopyBufferToImage; -PFN_clEnqueueMapBuffer clEnqueueMapBuffer; -PFN_clEnqueueMapImage clEnqueueMapImage; -PFN_clEnqueueUnmapMemObject clEnqueueUnmapMemObject; -PFN_clEnqueueMigrateMemObjects clEnqueueMigrateMemObjects; -PFN_clEnqueueNDRangeKernel clEnqueueNDRangeKernel; -PFN_clEnqueueNativeKernel clEnqueueNativeKernel; -PFN_clEnqueueMarkerWithWaitList clEnqueueMarkerWithWaitList; -PFN_clEnqueueBarrierWithWaitList clEnqueueBarrierWithWaitList; -PFN_clEnqueueSVMFree clEnqueueSVMFree; -PFN_clEnqueueSVMMemcpy clEnqueueSVMMemcpy; -PFN_clEnqueueSVMMemFill clEnqueueSVMMemFill; -PFN_clEnqueueSVMMap clEnqueueSVMMap; -PFN_clEnqueueSVMUnmap clEnqueueSVMUnmap; -PFN_clGetExtensionFunctionAddressForPlatform - clGetExtensionFunctionAddressForPlatform; -PFN_clCreateImage2D clCreateImage2D; 
-PFN_clCreateImage3D clCreateImage3D; -PFN_clEnqueueMarker clEnqueueMarker; -PFN_clEnqueueWaitForEvents clEnqueueWaitForEvents; -PFN_clEnqueueBarrier clEnqueueBarrier; -PFN_clUnloadCompiler clUnloadCompiler; -PFN_clGetExtensionFunctionAddress clGetExtensionFunctionAddress; -PFN_clCreateCommandQueue clCreateCommandQueue; -PFN_clCreateSampler clCreateSampler; -PFN_clEnqueueTask clEnqueueTask; - -// OpenGL sharing -PFN_clCreateFromGLBuffer clCreateFromGLBuffer; -PFN_clCreateFromGLTexture clCreateFromGLTexture; -PFN_clEnqueueAcquireGLObjects clEnqueueAcquireGLObjects; -PFN_clEnqueueReleaseGLObjects clEnqueueReleaseGLObjects; - -// cl_khr_egl_event extension -PFN_clCreateEventFromEGLSyncKHR clCreateEventFromEGLSyncKHR; - -// EGL sharing -PFN_clCreateFromEGLImageKHR clCreateFromEGLImageKHR; -PFN_clEnqueueAcquireEGLObjectsKHR clEnqueueAcquireEGLObjectsKHR; -PFN_clEnqueueReleaseEGLObjectsKHR clEnqueueReleaseEGLObjectsKHR; - -// cl_khr_command_buffer extension -PFN_clCreateCommandBufferKHR clCreateCommandBufferKHR; -PFN_clRetainCommandBufferKHR clRetainCommandBufferKHR; -PFN_clReleaseCommandBufferKHR clReleaseCommandBufferKHR; -PFN_clFinalizeCommandBufferKHR clFinalizeCommandBufferKHR; -PFN_clEnqueueCommandBufferKHR clEnqueueCommandBufferKHR; -PFN_clCommandNDRangeKernelKHR clCommandNDRangeKernelKHR; -PFN_clGetCommandBufferInfoKHR clGetCommandBufferInfoKHR; - -// OpenCL 3.0 -PFN_clCreateBufferWithProperties clCreateBufferWithProperties; -PFN_clCreateImageWithProperties clCreateImageWithProperties; - -cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags, - const cl_image_format* image_format, - const cl_image_desc* image_desc, void* host_ptr, - cl_int* errcode_ret) { - if (clCreateImage) { // clCreateImage available since OpenCL 1.2 - return clCreateImage(context, flags, image_format, image_desc, host_ptr, - errcode_ret); - } else { - return clCreateImage2D(context, flags, image_format, - image_desc->image_width, image_desc->image_height, - image_desc->image_row_pitch, host_ptr, errcode_ret); - } -} - -cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags, - const cl_image_format* image_format, - const cl_image_desc* image_desc, void* host_ptr, - cl_int* errcode_ret) { - if (clCreateImage) { // clCreateImage available since OpenCL 1.2 - return clCreateImage(context, flags, image_format, image_desc, host_ptr, - errcode_ret); - } else { - return clCreateImage3D(context, flags, image_format, - image_desc->image_width, image_desc->image_height, - image_desc->image_depth, image_desc->image_row_pitch, - image_desc->image_slice_pitch, host_ptr, - errcode_ret); - } -} -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h b/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h deleted file mode 100644 index cfbeb805dbb49d..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h +++ /dev/null @@ -1,737 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is copied from third_party/ml_drift/cl/opencl_wrapper.h. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_OPENCL_WRAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_OPENCL_WRAPPER_H_ - -#include - -#include "absl/status/status.h" -#include // IWYU pragma: export -#include // IWYU pragma: export -#include // IWYU pragma: export -#include // IWYU pragma: export -#include // IWYU pragma: export - -namespace litert { -namespace cl { - -absl::Status LoadOpenCL(); -void LoadOpenCLFunctionExtensions(cl_platform_id platform_id); - -typedef cl_int(CL_API_CALL *PFN_clGetPlatformIDs)( - cl_uint /* num_entries */, cl_platform_id * /* platforms */, - cl_uint * /* num_platforms */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetPlatformInfo)( - cl_platform_id /* platform */, cl_platform_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetDeviceIDs)( - cl_platform_id /* platform */, cl_device_type /* device_type */, - cl_uint /* num_entries */, cl_device_id * /* devices */, - cl_uint * /* num_devices */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetDeviceInfo)( - cl_device_id /* device */, cl_device_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clCreateSubDevices)( - cl_device_id /* in_device */, - const cl_device_partition_property * /* properties */, - cl_uint /* num_devices */, cl_device_id * /* out_devices */, - cl_uint * /* num_devices_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clRetainDevice)(cl_device_id /* device */) - CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clReleaseDevice)(cl_device_id /* device */) - CL_API_SUFFIX__VERSION_1_2; -typedef cl_context(CL_API_CALL *PFN_clCreateContext)( - const cl_context_properties * /* properties */, cl_uint /* num_devices */, - const cl_device_id * /* devices */, - void(CL_CALLBACK * /* pfn_notify */)(const char *, const void *, size_t, - void *), - void * /* user_data */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_context(CL_API_CALL *PFN_clCreateContextFromType)( - const cl_context_properties * /* properties */, - cl_device_type /* device_type */, - void(CL_CALLBACK * /* pfn_notify*/)(const char *, const void *, size_t, - void *), - void * /* user_data */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clRetainContext)(cl_context /* context */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseContext)(cl_context /* context */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetContextInfo)( - cl_context /* context */, cl_context_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_command_queue(CL_API_CALL *PFN_clCreateCommandQueueWithProperties)( - cl_context /* context */, cl_device_id /* device */, - const cl_queue_properties * /* properties */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clRetainCommandQueue)( - cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0; -typedef 
cl_int(CL_API_CALL *PFN_clReleaseCommandQueue)( - cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetCommandQueueInfo)( - cl_command_queue /* command_queue */, - cl_command_queue_info /* param_name */, size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_mem(CL_API_CALL *PFN_clCreateBuffer)( - cl_context /* context */, cl_mem_flags /* flags */, size_t /* size */, - void * /* host_ptr */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_mem(CL_API_CALL *PFN_clCreateSubBuffer)( - cl_mem /* buffer */, cl_mem_flags /* flags */, - cl_buffer_create_type /* buffer_create_type */, - const void * /* buffer_create_info */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_mem(CL_API_CALL *PFN_clCreateImage)( - cl_context /* context */, cl_mem_flags /* flags */, - const cl_image_format * /* image_format */, - const cl_image_desc * /* image_desc */, void * /* host_ptr */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_mem(CL_API_CALL *PFN_clCreatePipe)( - cl_context /* context */, cl_mem_flags /* flags */, - cl_uint /* pipe_packet_size */, cl_uint /* pipe_max_packets */, - const cl_pipe_properties * /* properties */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clRetainMemObject)(cl_mem /* memobj */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseMemObject)(cl_mem /* memobj */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetSupportedImageFormats)( - cl_context /* context */, cl_mem_flags /* flags */, - cl_mem_object_type /* image_type */, cl_uint /* num_entries */, - cl_image_format * /* image_formats */, - cl_uint * /* num_image_formats */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetMemObjectInfo)( - cl_mem /* memobj */, cl_mem_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetImageInfo)( - cl_mem /* image */, cl_image_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetPipeInfo)( - cl_mem /* pipe */, cl_pipe_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clSetMemObjectDestructorCallback)( - cl_mem /* memobj */, - void(CL_CALLBACK * /*pfn_notify*/)(cl_mem /* memobj */, - void * /*user_data*/), - void * /*user_data */) CL_API_SUFFIX__VERSION_1_1; -typedef void *(CL_API_CALL *PFN_clSVMAlloc)( - cl_context /* context */, cl_svm_mem_flags /* flags */, size_t /* size */, - cl_uint /* alignment */)CL_API_SUFFIX__VERSION_2_0; -typedef void(CL_API_CALL *PFN_clSVMFree)(cl_context /* context */, - void * /* svm_pointer */) - CL_API_SUFFIX__VERSION_2_0; -typedef cl_sampler(CL_API_CALL *PFN_clCreateSamplerWithProperties)( - cl_context /* context */, - const cl_sampler_properties * /* normalized_coords */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clRetainSampler)(cl_sampler /* sampler */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseSampler)(cl_sampler /* sampler */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL 
*PFN_clGetSamplerInfo)( - cl_sampler /* sampler */, cl_sampler_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithSource)( - cl_context /* context */, cl_uint /* count */, const char ** /* strings */, - const size_t * /* lengths */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithBinary)( - cl_context /* context */, cl_uint /* num_devices */, - const cl_device_id * /* device_list */, const size_t * /* lengths */, - const unsigned char ** /* binaries */, cl_int * /* binary_status */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithBuiltInKernels)( - cl_context /* context */, cl_uint /* num_devices */, - const cl_device_id * /* device_list */, const char * /* kernel_names */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clRetainProgram)(cl_program /* program */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseProgram)(cl_program /* program */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clBuildProgram)( - cl_program /* program */, cl_uint /* num_devices */, - const cl_device_id * /* device_list */, const char * /* options */, - void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, - void * /* user_data */), - void * /* user_data */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clCompileProgram)( - cl_program /* program */, cl_uint /* num_devices */, - const cl_device_id * /* device_list */, const char * /* options */, - cl_uint /* num_input_headers */, const cl_program * /* input_headers */, - const char ** /* header_include_names */, - void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, - void * /* user_data */), - void * /* user_data */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_program(CL_API_CALL *PFN_clLinkProgram)( - cl_context /* context */, cl_uint /* num_devices */, - const cl_device_id * /* device_list */, const char * /* options */, - cl_uint /* num_input_programs */, const cl_program * /* input_programs */, - void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, - void * /* user_data */), - void * /* user_data */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clUnloadPlatformCompiler)( - cl_platform_id /* platform */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clGetProgramInfo)( - cl_program /* program */, cl_program_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetProgramBuildInfo)( - cl_program /* program */, cl_device_id /* device */, - cl_program_build_info /* param_name */, size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_kernel(CL_API_CALL *PFN_clCreateKernel)( - cl_program /* program */, const char * /* kernel_name */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clCreateKernelsInProgram)( - cl_program /* program */, cl_uint /* num_kernels */, - cl_kernel * /* kernels */, - cl_uint * /* num_kernels_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clRetainKernel)(cl_kernel /* kernel */) - 
CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseKernel)(cl_kernel /* kernel */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clSetKernelArg)( - cl_kernel /* kernel */, cl_uint /* arg_index */, size_t /* arg_size */, - const void * /* arg_value */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clSetKernelArgSVMPointer)( - cl_kernel /* kernel */, cl_uint /* arg_index */, - const void * /* arg_value */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clSetKernelExecInfo)( - cl_kernel /* kernel */, cl_kernel_exec_info /* param_name */, - size_t /* param_value_size */, - const void * /* param_value */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clGetKernelInfo)( - cl_kernel /* kernel */, cl_kernel_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetKernelArgInfo)( - cl_kernel /* kernel */, cl_uint /* arg_indx */, - cl_kernel_arg_info /* param_name */, size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clGetKernelWorkGroupInfo)( - cl_kernel /* kernel */, cl_device_id /* device */, - cl_kernel_work_group_info /* param_name */, size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clWaitForEvents)( - cl_uint /* num_events */, - const cl_event * /* event_list */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetEventInfo)( - cl_event /* event */, cl_event_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_event(CL_API_CALL *PFN_clCreateUserEvent)(cl_context /* context */, - cl_int * /* errcode_ret */) - CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clRetainEvent)(cl_event /* event */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseEvent)(cl_event /* event */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clSetUserEventStatus)( - cl_event /* event */, - cl_int /* execution_status */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clSetEventCallback)( - cl_event /* event */, cl_int /* command_exec_callback_type */, - void(CL_CALLBACK * /* pfn_notify */)(cl_event, cl_int, void *), - void * /* user_data */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clGetEventProfilingInfo)( - cl_event /* event */, cl_profiling_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clFlush)(cl_command_queue /* command_queue */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clFinish)(cl_command_queue /* command_queue */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueReadBuffer)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_read */, size_t /* offset */, size_t /* size */, - void * /* ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueReadBufferRect)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_read */, const 
size_t * /* buffer_offset */, - const size_t * /* host_offset */, const size_t * /* region */, - size_t /* buffer_row_pitch */, size_t /* buffer_slice_pitch */, - size_t /* host_row_pitch */, size_t /* host_slice_pitch */, - void * /* ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteBuffer)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_write */, size_t /* offset */, size_t /* size */, - const void * /* ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteBufferRect)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_write */, const size_t * /* buffer_offset */, - const size_t * /* host_offset */, const size_t * /* region */, - size_t /* buffer_row_pitch */, size_t /* buffer_slice_pitch */, - size_t /* host_row_pitch */, size_t /* host_slice_pitch */, - const void * /* ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clEnqueueFillBuffer)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - const void * /* pattern */, size_t /* pattern_size */, size_t /* offset */, - size_t /* size */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBuffer)( - cl_command_queue /* command_queue */, cl_mem /* src_buffer */, - cl_mem /* dst_buffer */, size_t /* src_offset */, size_t /* dst_offset */, - size_t /* size */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBufferRect)( - cl_command_queue /* command_queue */, cl_mem /* src_buffer */, - cl_mem /* dst_buffer */, const size_t * /* src_origin */, - const size_t * /* dst_origin */, const size_t * /* region */, - size_t /* src_row_pitch */, size_t /* src_slice_pitch */, - size_t /* dst_row_pitch */, size_t /* dst_slice_pitch */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clEnqueueReadImage)( - cl_command_queue /* command_queue */, cl_mem /* image */, - cl_bool /* blocking_read */, const size_t * /* origin[3] */, - const size_t * /* region[3] */, size_t /* row_pitch */, - size_t /* slice_pitch */, void * /* ptr */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteImage)( - cl_command_queue /* command_queue */, cl_mem /* image */, - cl_bool /* blocking_write */, const size_t * /* origin[3] */, - const size_t * /* region[3] */, size_t /* input_row_pitch */, - size_t /* input_slice_pitch */, const void * /* ptr */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueFillImage)( - cl_command_queue /* command_queue */, cl_mem /* image */, - const void * /* fill_color */, const size_t * /* origin[3] */, - const size_t * /* region[3] */, 
cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyImage)( - cl_command_queue /* command_queue */, cl_mem /* src_image */, - cl_mem /* dst_image */, const size_t * /* src_origin[3] */, - const size_t * /* dst_origin[3] */, const size_t * /* region[3] */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyImageToBuffer)( - cl_command_queue /* command_queue */, cl_mem /* src_image */, - cl_mem /* dst_buffer */, const size_t * /* src_origin[3] */, - const size_t * /* region[3] */, size_t /* dst_offset */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBufferToImage)( - cl_command_queue /* command_queue */, cl_mem /* src_buffer */, - cl_mem /* dst_image */, size_t /* src_offset */, - const size_t * /* dst_origin[3] */, const size_t * /* region[3] */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef void *(CL_API_CALL *PFN_clEnqueueMapBuffer)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_map */, cl_map_flags /* map_flags */, - size_t /* offset */, size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, cl_event * /* event */, - cl_int * /* errcode_ret */)CL_API_SUFFIX__VERSION_1_0; -typedef void *(CL_API_CALL *PFN_clEnqueueMapImage)( - cl_command_queue /* command_queue */, cl_mem /* image */, - cl_bool /* blocking_map */, cl_map_flags /* map_flags */, - const size_t * /* origin[3] */, const size_t * /* region[3] */, - size_t * /* image_row_pitch */, size_t * /* image_slice_pitch */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, cl_event * /* event */, - cl_int * /* errcode_ret */)CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueUnmapMemObject)( - cl_command_queue /* command_queue */, cl_mem /* memobj */, - void * /* mapped_ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueMigrateMemObjects)( - cl_command_queue /* command_queue */, cl_uint /* num_mem_objects */, - const cl_mem * /* mem_objects */, cl_mem_migration_flags /* flags */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueNDRangeKernel)( - cl_command_queue /* command_queue */, cl_kernel /* kernel */, - cl_uint /* work_dim */, const size_t * /* global_work_offset */, - const size_t * /* global_work_size */, const size_t * /* local_work_size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueNativeKernel)( - cl_command_queue /* command_queue */, - void(CL_CALLBACK * /*user_func*/)(void *), void * /* args */, - size_t /* cb_args */, cl_uint /* num_mem_objects */, - const cl_mem * /* mem_list */, const void ** /* args_mem_loc */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* 
event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueMarkerWithWaitList)( - cl_command_queue /* command_queue */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueBarrierWithWaitList)( - cl_command_queue /* command_queue */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMFree)( - cl_command_queue /* command_queue */, cl_uint /* num_svm_pointers */, - void *[] /* svm_pointers[] */, - void(CL_CALLBACK * /*pfn_free_func*/)(cl_command_queue /* queue */, - cl_uint /* num_svm_pointers */, - void *[] /* svm_pointers[] */, - void * /* user_data */), - void * /* user_data */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMemcpy)( - cl_command_queue /* command_queue */, cl_bool /* blocking_copy */, - void * /* dst_ptr */, const void * /* src_ptr */, size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMemFill)( - cl_command_queue /* command_queue */, void * /* svm_ptr */, - const void * /* pattern */, size_t /* pattern_size */, size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMap)( - cl_command_queue /* command_queue */, cl_bool /* blocking_map */, - cl_map_flags /* flags */, void * /* svm_ptr */, size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMUnmap)( - cl_command_queue /* command_queue */, void * /* svm_ptr */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef void *(CL_API_CALL *PFN_clGetExtensionFunctionAddressForPlatform)( - cl_platform_id /* platform */, - const char * /* func_name */)CL_API_SUFFIX__VERSION_1_2; -typedef cl_mem(CL_API_CALL *PFN_clCreateImage2D)( - cl_context /* context */, cl_mem_flags /* flags */, - const cl_image_format * /* image_format */, size_t /* image_width */, - size_t /* image_height */, size_t /* image_row_pitch */, - void * /* host_ptr */, cl_int * /* errcode_ret */); -typedef cl_mem(CL_API_CALL *PFN_clCreateImage3D)( - cl_context /* context */, cl_mem_flags /* flags */, - const cl_image_format * /* image_format */, size_t /* image_width */, - size_t /* image_height */, size_t /* image_depth */, - size_t /* image_row_pitch */, size_t /* image_slice_pitch */, - void * /* host_ptr */, cl_int * /* errcode_ret */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueMarker)( - cl_command_queue /* command_queue */, cl_event * /* event */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueWaitForEvents)( - cl_command_queue /* command_queue */, cl_uint /* num_events */, - const cl_event * /* event_list */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueBarrier)( - cl_command_queue /* command_queue */); -typedef cl_int(CL_API_CALL *PFN_clUnloadCompiler)(); -typedef void *(CL_API_CALL *PFN_clGetExtensionFunctionAddress)( - const char * /* 
func_name */); -typedef cl_command_queue(CL_API_CALL *PFN_clCreateCommandQueue)( - cl_context /* context */, cl_device_id /* device */, - cl_command_queue_properties /* properties */, cl_int * /* errcode_ret */); -typedef cl_sampler(CL_API_CALL *PFN_clCreateSampler)( - cl_context /* context */, cl_bool /* normalized_coords */, - cl_addressing_mode /* addressing_mode */, cl_filter_mode /* filter_mode */, - cl_int * /* errcode_ret */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueTask)( - cl_command_queue /* command_queue */, cl_kernel /* kernel */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, cl_event * /* event */); - -// OpenGL sharing -typedef cl_mem(CL_API_CALL *PFN_clCreateFromGLBuffer)(cl_context, cl_mem_flags, - cl_GLuint, int *); -typedef cl_mem(CL_API_CALL *PFN_clCreateFromGLTexture)( - cl_context /* context */, cl_mem_flags /* flags */, cl_GLenum /* target */, - cl_GLint /* miplevel */, cl_GLuint /* texture */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueAcquireGLObjects)( - cl_command_queue /* command_queue */, cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, cl_event * /* event */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueReleaseGLObjects)( - cl_command_queue /* command_queue */, cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; - -// cl_khr_egl_event extension - -// CLeglDisplayKHR is an opaque handle to an EGLDisplay -typedef void *CLeglDisplayKHR; - -// CLeglSyncKHR is an opaque handle to an EGLSync object -typedef void *CLeglSyncKHR; - -typedef cl_event(CL_API_CALL *PFN_clCreateEventFromEGLSyncKHR)( - cl_context /* context */, CLeglSyncKHR /* sync */, - CLeglDisplayKHR /* display */, cl_int * /* errcode_ret */); - -// EGL sharing -typedef cl_mem(CL_API_CALL *PFN_clCreateFromEGLImageKHR)( - cl_context /*context*/, CLeglDisplayKHR /*display*/, - CLeglImageKHR /*image*/, cl_mem_flags /*flags*/, - const cl_egl_image_properties_khr * /*properties*/, - cl_int * /*errcode_ret*/); -typedef cl_int(CL_API_CALL *PFN_clEnqueueAcquireEGLObjectsKHR)( - cl_command_queue /*command_queue*/, cl_uint /*num_objects*/, - const cl_mem * /*mem_objects*/, cl_uint /*num_events_in_wait_list*/, - const cl_event * /*event_wait_list*/, cl_event * /*event*/); -typedef cl_int(CL_API_CALL *PFN_clEnqueueReleaseEGLObjectsKHR)( - cl_command_queue /*command_queue*/, cl_uint /*num_objects*/, - const cl_mem * /*mem_objects*/, cl_uint /*num_events_in_wait_list*/, - const cl_event * /*event_wait_list*/, cl_event * /*event*/); - -// cl_khr_command_buffer -typedef cl_command_buffer_khr(CL_API_CALL *PFN_clCreateCommandBufferKHR)( - cl_uint /*num_queues*/, const cl_command_queue * /*queues*/, - const cl_command_buffer_properties_khr * /*properties*/, - cl_int * /*errcode_ret*/); - -typedef cl_int(CL_API_CALL *PFN_clRetainCommandBufferKHR)( - cl_command_buffer_khr /*command_buffer*/); - -typedef cl_int(CL_API_CALL *PFN_clReleaseCommandBufferKHR)( - cl_command_buffer_khr /*command_buffer*/); - -typedef cl_int(CL_API_CALL *PFN_clFinalizeCommandBufferKHR)( - cl_command_buffer_khr /*command_buffer*/); - -typedef cl_int(CL_API_CALL *PFN_clEnqueueCommandBufferKHR)( - cl_uint /*num_queues*/, cl_command_queue * /*queues*/, - cl_command_buffer_khr /*command_buffer*/, - cl_uint 
/*num_events_in_wait_list*/, const cl_event * /*event_wait_list*/, - cl_event * /*event*/); - -#if CL_KHR_COMMAND_BUFFER_EXTENSION_VERSION >= CL_MAKE_VERSION(0, 9, 5) -typedef cl_int(CL_API_CALL *PFN_clCommandNDRangeKernelKHR)( - cl_command_buffer_khr /*command_buffer*/, - cl_command_queue /*command_queue*/, - const cl_command_properties_khr * /*properties*/, cl_kernel /*kernel*/, - cl_uint /*work_dim*/, const size_t * /*global_work_offset*/, - const size_t * /*global_work_size*/, const size_t * /*local_work_size*/, - cl_uint /*num_sync_points_in_wait_list*/, - const cl_sync_point_khr * /*sync_point_wait_list*/, - cl_sync_point_khr * /*sync_point*/, - cl_mutable_command_khr * /*mutable_handle*/); -#else -typedef cl_int(CL_API_CALL *PFN_clCommandNDRangeKernelKHR)( - cl_command_buffer_khr /*command_buffer*/, - cl_command_queue /*command_queue*/, - const cl_ndrange_kernel_command_properties_khr * /*properties*/, - cl_kernel /*kernel*/, cl_uint /*work_dim*/, - const size_t * /*global_work_offset*/, const size_t * /*global_work_size*/, - const size_t * /*local_work_size*/, - cl_uint /*num_sync_points_in_wait_list*/, - const cl_sync_point_khr * /*sync_point_wait_list*/, - cl_sync_point_khr * /*sync_point*/, - cl_mutable_command_khr * /*mutable_handle*/); -#endif - -typedef cl_int(CL_API_CALL *PFN_clGetCommandBufferInfoKHR)( - cl_command_buffer_khr /*command_buffer*/, - cl_command_buffer_info_khr /*param_name*/, size_t /*param_value_size*/, - void * /*param_value*/, size_t * /*param_value_size_ret*/); - -// OpenCL 3.0 -typedef cl_mem(CL_API_CALL *PFN_clCreateBufferWithProperties)( - cl_context /*context*/, const cl_mem_properties * /*properties*/, - cl_mem_flags /*flags*/, size_t /*size*/, void * /*host_ptr*/, - cl_int * /*errcode_ret*/); -typedef cl_mem(CL_API_CALL *PFN_clCreateImageWithProperties)( - cl_context /*context*/, const cl_mem_properties * /*properties*/, - cl_mem_flags /*flags*/, const cl_image_format * /*image_format*/, - const cl_image_desc * /*image_desc*/, void * /*host_ptr*/, - cl_int * /*errcode_ret*/); - -extern PFN_clGetPlatformIDs clGetPlatformIDs; -extern PFN_clGetPlatformInfo clGetPlatformInfo; -extern PFN_clGetDeviceIDs clGetDeviceIDs; -extern PFN_clGetDeviceInfo clGetDeviceInfo; -extern PFN_clCreateSubDevices clCreateSubDevices; -extern PFN_clRetainDevice clRetainDevice; -extern PFN_clReleaseDevice clReleaseDevice; -extern PFN_clCreateContext clCreateContext; -extern PFN_clCreateContextFromType clCreateContextFromType; -extern PFN_clRetainContext clRetainContext; -extern PFN_clReleaseContext clReleaseContext; -extern PFN_clGetContextInfo clGetContextInfo; -extern PFN_clCreateCommandQueueWithProperties - clCreateCommandQueueWithProperties; -extern PFN_clRetainCommandQueue clRetainCommandQueue; -extern PFN_clReleaseCommandQueue clReleaseCommandQueue; -extern PFN_clGetCommandQueueInfo clGetCommandQueueInfo; -extern PFN_clCreateBuffer clCreateBuffer; -extern PFN_clCreateSubBuffer clCreateSubBuffer; -extern PFN_clCreateImage clCreateImage; -extern PFN_clCreatePipe clCreatePipe; -extern PFN_clRetainMemObject clRetainMemObject; -extern PFN_clReleaseMemObject clReleaseMemObject; -extern PFN_clGetSupportedImageFormats clGetSupportedImageFormats; -extern PFN_clGetMemObjectInfo clGetMemObjectInfo; -extern PFN_clGetImageInfo clGetImageInfo; -extern PFN_clGetPipeInfo clGetPipeInfo; -extern PFN_clSetMemObjectDestructorCallback clSetMemObjectDestructorCallback; -extern PFN_clSVMAlloc clSVMAlloc; -extern PFN_clSVMFree clSVMFree; -extern PFN_clCreateSamplerWithProperties 
clCreateSamplerWithProperties; -extern PFN_clRetainSampler clRetainSampler; -extern PFN_clReleaseSampler clReleaseSampler; -extern PFN_clGetSamplerInfo clGetSamplerInfo; -extern PFN_clCreateProgramWithSource clCreateProgramWithSource; -extern PFN_clCreateProgramWithBinary clCreateProgramWithBinary; -extern PFN_clCreateProgramWithBuiltInKernels clCreateProgramWithBuiltInKernels; -extern PFN_clRetainProgram clRetainProgram; -extern PFN_clReleaseProgram clReleaseProgram; -extern PFN_clBuildProgram clBuildProgram; -extern PFN_clCompileProgram clCompileProgram; -extern PFN_clLinkProgram clLinkProgram; -extern PFN_clUnloadPlatformCompiler clUnloadPlatformCompiler; -extern PFN_clGetProgramInfo clGetProgramInfo; -extern PFN_clGetProgramBuildInfo clGetProgramBuildInfo; -extern PFN_clCreateKernel clCreateKernel; -extern PFN_clCreateKernelsInProgram clCreateKernelsInProgram; -extern PFN_clRetainKernel clRetainKernel; -extern PFN_clReleaseKernel clReleaseKernel; -extern PFN_clSetKernelArg clSetKernelArg; -extern PFN_clSetKernelArgSVMPointer clSetKernelArgSVMPointer; -extern PFN_clSetKernelExecInfo clSetKernelExecInfo; -extern PFN_clGetKernelInfo clGetKernelInfo; -extern PFN_clGetKernelArgInfo clGetKernelArgInfo; -extern PFN_clGetKernelWorkGroupInfo clGetKernelWorkGroupInfo; -extern PFN_clWaitForEvents clWaitForEvents; -extern PFN_clGetEventInfo clGetEventInfo; -extern PFN_clCreateUserEvent clCreateUserEvent; -extern PFN_clRetainEvent clRetainEvent; -extern PFN_clReleaseEvent clReleaseEvent; -extern PFN_clSetUserEventStatus clSetUserEventStatus; -extern PFN_clSetEventCallback clSetEventCallback; -extern PFN_clGetEventProfilingInfo clGetEventProfilingInfo; -extern PFN_clFlush clFlush; -extern PFN_clFinish clFinish; -extern PFN_clEnqueueReadBuffer clEnqueueReadBuffer; -extern PFN_clEnqueueReadBufferRect clEnqueueReadBufferRect; -extern PFN_clEnqueueWriteBuffer clEnqueueWriteBuffer; -extern PFN_clEnqueueWriteBufferRect clEnqueueWriteBufferRect; -extern PFN_clEnqueueFillBuffer clEnqueueFillBuffer; -extern PFN_clEnqueueCopyBuffer clEnqueueCopyBuffer; -extern PFN_clEnqueueCopyBufferRect clEnqueueCopyBufferRect; -extern PFN_clEnqueueReadImage clEnqueueReadImage; -extern PFN_clEnqueueWriteImage clEnqueueWriteImage; -extern PFN_clEnqueueFillImage clEnqueueFillImage; -extern PFN_clEnqueueCopyImage clEnqueueCopyImage; -extern PFN_clEnqueueCopyImageToBuffer clEnqueueCopyImageToBuffer; -extern PFN_clEnqueueCopyBufferToImage clEnqueueCopyBufferToImage; -extern PFN_clEnqueueMapBuffer clEnqueueMapBuffer; -extern PFN_clEnqueueMapImage clEnqueueMapImage; -extern PFN_clEnqueueUnmapMemObject clEnqueueUnmapMemObject; -extern PFN_clEnqueueMigrateMemObjects clEnqueueMigrateMemObjects; -extern PFN_clEnqueueNDRangeKernel clEnqueueNDRangeKernel; -extern PFN_clEnqueueNativeKernel clEnqueueNativeKernel; -extern PFN_clEnqueueMarkerWithWaitList clEnqueueMarkerWithWaitList; -extern PFN_clEnqueueBarrierWithWaitList clEnqueueBarrierWithWaitList; -extern PFN_clEnqueueSVMFree clEnqueueSVMFree; -extern PFN_clEnqueueSVMMemcpy clEnqueueSVMMemcpy; -extern PFN_clEnqueueSVMMemFill clEnqueueSVMMemFill; -extern PFN_clEnqueueSVMMap clEnqueueSVMMap; -extern PFN_clEnqueueSVMUnmap clEnqueueSVMUnmap; -extern PFN_clGetExtensionFunctionAddressForPlatform - clGetExtensionFunctionAddressForPlatform; -extern PFN_clCreateImage2D clCreateImage2D; -extern PFN_clCreateImage3D clCreateImage3D; -extern PFN_clEnqueueMarker clEnqueueMarker; -extern PFN_clEnqueueWaitForEvents clEnqueueWaitForEvents; -extern PFN_clEnqueueBarrier clEnqueueBarrier; -extern 
PFN_clUnloadCompiler clUnloadCompiler; -extern PFN_clGetExtensionFunctionAddress clGetExtensionFunctionAddress; -extern PFN_clCreateCommandQueue clCreateCommandQueue; -extern PFN_clCreateSampler clCreateSampler; -extern PFN_clEnqueueTask clEnqueueTask; - -// OpenGL sharing -extern PFN_clCreateFromGLBuffer clCreateFromGLBuffer; -extern PFN_clCreateFromGLTexture clCreateFromGLTexture; -extern PFN_clEnqueueAcquireGLObjects clEnqueueAcquireGLObjects; -extern PFN_clEnqueueReleaseGLObjects clEnqueueReleaseGLObjects; - -// cl_khr_egl_event extension -extern PFN_clCreateEventFromEGLSyncKHR clCreateEventFromEGLSyncKHR; - -// EGL sharing -extern PFN_clCreateFromEGLImageKHR clCreateFromEGLImageKHR; -extern PFN_clEnqueueAcquireEGLObjectsKHR clEnqueueAcquireEGLObjectsKHR; -extern PFN_clEnqueueReleaseEGLObjectsKHR clEnqueueReleaseEGLObjectsKHR; - -// cl_khr_command_buffer extension -extern PFN_clCreateCommandBufferKHR clCreateCommandBufferKHR; -extern PFN_clRetainCommandBufferKHR clRetainCommandBufferKHR; -extern PFN_clReleaseCommandBufferKHR clReleaseCommandBufferKHR; -extern PFN_clFinalizeCommandBufferKHR clFinalizeCommandBufferKHR; -extern PFN_clEnqueueCommandBufferKHR clEnqueueCommandBufferKHR; -extern PFN_clCommandNDRangeKernelKHR clCommandNDRangeKernelKHR; -extern PFN_clGetCommandBufferInfoKHR clGetCommandBufferInfoKHR; - -// OpenCL 3.0 -extern PFN_clCreateBufferWithProperties clCreateBufferWithProperties; -extern PFN_clCreateImageWithProperties clCreateImageWithProperties; - -// For convenient image creation -// It uses clCreateImage if it available (clCreateImage available since cl 1.2) -// otherwise it will use legacy clCreateImage2D -cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags, - const cl_image_format *image_format, - const cl_image_desc *image_desc, void *host_ptr, - cl_int *errcode_ret); - -// It uses clCreateImage if it available (clCreateImage available since cl 1.2) -// otherwise it will use legacy clCreateImage3D -cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags, - const cl_image_format *image_format, - const cl_image_desc *image_desc, void *host_ptr, - cl_int *errcode_ret); - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_OPENCL_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc b/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc deleted file mode 100644 index d1b28539f4155d..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc +++ /dev/null @@ -1,655 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
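The tensor buffer implementation deleted below stores exactly one backing buffer per tensor in a std::variant and, when a deallocator was supplied, runs it from the destructor; a null deallocator means the caller keeps ownership (the CreateFrom* case). A standalone sketch of that ownership pattern; HostBuf and FdBuf are illustrative stand-ins, not the LiteRT structs:

#include <variant>

// Illustrative per-backend buffer records: a handle plus an optional
// deallocator that the wrapper owns.
struct HostBuf { void* addr; void (*dealloc)(void*); };
struct FdBuf { void* addr; int fd; void (*dealloc)(void*); };

class Buffer {
 public:
  explicit Buffer(HostBuf b) : rep_(b) {}
  explicit Buffer(FdBuf b) : rep_(b) {}
  ~Buffer() {
    // Dispatch cleanup on the active alternative; skip when the
    // deallocator is null because the memory is externally owned.
    if (auto* h = std::get_if<HostBuf>(&rep_)) {
      if (h->dealloc) h->dealloc(h->addr);
    } else if (auto* f = std::get_if<FdBuf>(&rep_)) {
      if (f->dealloc) f->dealloc(f->addr);
    }
  }
 private:
  std::variant<HostBuf, FdBuf> rep_;
};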
-
-#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h"
-
-#include <stdlib.h>
-
-#include <algorithm>
-#include <cstddef>
-#include <iterator>
-#include <memory>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/strings/str_format.h"
-#include "absl/types/span.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_event.h"
-#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h"
-#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h"
-#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h"
-#include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h"
-#include "tensorflow/lite/experimental/litert/runtime/event.h"
-#include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h"
-#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h"
-#include "tensorflow/lite/experimental/litert/runtime/gl_texture.h"
-#include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h"
-
-#if LITERT_HAS_OPENCL_SUPPORT
-#include <CL/cl.h>
-#include "tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h"
-#endif  // LITERT_HAS_OPENCL_SUPPORT
-
-using litert::Expected;
-using litert::Unexpected;
-
-namespace {
-
-template <typename T>
-void Copy(size_t array_size, const T* array, std::vector<T>& vec) {
-  vec.clear();
-  vec.reserve(array_size);
-  std::copy(array, array + array_size, std::back_inserter(vec));
-  array = vec.data();
-}
-
-}  // namespace
-
-LiteRtTensorBufferT::LiteRtTensorBufferT(
-    const LiteRtRankedTensorType& tensor_type,
-    LiteRtTensorBufferType buffer_type, size_t buffer_size,
-    size_t buffer_offset)
-    : tensor_type_(tensor_type),
-      buffer_type_(buffer_type),
-      buffer_size_(buffer_size),
-      buffer_offset_(buffer_offset),
-      ref_(1) {
-  // Copy local memory passed by the caller.
-  Copy(tensor_type_.layout.rank, tensor_type_.layout.dimensions, dimensions_);
-  if (tensor_type_.layout.strides) {
-    Copy(tensor_type_.layout.rank, tensor_type_.layout.strides, strides_);
-  }
-}
-
-LiteRtTensorBufferT::~LiteRtTensorBufferT() {
-  switch (buffer_type()) {
-    case kLiteRtTensorBufferTypeUnknown:
-      // Nothing to do.
-      break;
-    case kLiteRtTensorBufferTypeHostMemory:
-      if (auto& buffer = std::get<HostBuffer>(buffer_); buffer.deallocator) {
-        buffer.deallocator(buffer.addr);
-      }
-      break;
-    case kLiteRtTensorBufferTypeAhwb:
-      if (auto& buffer = std::get<AhwbBuffer>(buffer_); buffer.deallocator) {
-        buffer.deallocator(buffer.ahwb);
-      }
-      break;
-    case kLiteRtTensorBufferTypeIon:
-      if (auto& buffer = std::get<IonBuffer>(buffer_); buffer.deallocator) {
-        buffer.deallocator(buffer.addr);
-      }
-      break;
-    case kLiteRtTensorBufferTypeDmaBuf:
-      if (auto& buffer = std::get<DmaBufBuffer>(buffer_); buffer.deallocator) {
-        buffer.deallocator(buffer.addr);
-      }
-      break;
-    case kLiteRtTensorBufferTypeFastRpc:
-      if (auto& buffer = std::get<FastRpcBuffer>(buffer_); buffer.deallocator) {
-        buffer.deallocator(buffer.addr);
-      }
-      break;
-    case kLiteRtTensorBufferTypeOpenCl:
-      // internal opencl buffer is auto-disposed by the
-      // litert::internal::OpenClBuffer destructor.
- break; - case kLiteRtTensorBufferTypeGlBuffer: - // internal gl buffer is auto-disposed by the - // litert::internal::GlBuffer destructor. - case kLiteRtTensorBufferTypeGlTexture: - // internal gl texture is auto-disposed by the - // litert::internal::GlTexture destructor. - break; - } -} - -Expected LiteRtTensorBufferT::CreateFromHostMemory( - const LiteRtRankedTensorType& tensor_type, absl::Span host_memory, - LiteRtHostMemoryDeallocator deallocator) { - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeHostMemory, host_memory.size())); - tensor_buffer->buffer_ = HostBuffer{ - .addr = host_memory.data(), - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected -LiteRtTensorBufferT::CreateManagedOnHostMemory( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - void* host_memory_ptr; - if (auto rc = posix_memalign( - &host_memory_ptr, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, buffer_size); - rc) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate aligned memory"); - } - - LiteRtHostMemoryDeallocator deallocator = ::free; - LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT::Ptr tensor_buffer, - CreateFromHostMemory( - tensor_type, - absl::MakeSpan(static_cast(host_memory_ptr), buffer_size), - deallocator)); - - return std::move(tensor_buffer); -} - -Expected LiteRtTensorBufferT::CreateFromAhwb( - const LiteRtRankedTensorType& tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset, LiteRtAhwbDeallocator deallocator) { - LITERT_ASSIGN_OR_RETURN(size_t buffer_size, - litert::internal::AhwbBuffer::GetSize(ahwb)); - - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeAhwb, buffer_size, ahwb_offset)); - tensor_buffer->buffer_ = AhwbBuffer{ - .ahwb = ahwb, - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateManagedAhwbBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - LITERT_ASSIGN_OR_RETURN(litert::internal::AhwbBuffer buffer, - litert::internal::AhwbBuffer::Alloc(buffer_size)); - return CreateFromAhwb(tensor_type, buffer.ahwb, /*ahwb_offset=*/0, - /*deallocator=*/litert::internal::AhwbBuffer::Free); -} - -Expected LiteRtTensorBufferT::CreateFromIonBuffer( - const LiteRtRankedTensorType& tensor_type, void* ion_buffer_addr, - int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, - LiteRtIonDeallocator deallocator) { - if (!ion_buffer_addr) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid ION buffer address"); - } - if (ion_buffer_fd < 0) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid ION buffer fd"); - } - - Ptr tensor_buffer( - new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeIon, - ion_buffer_size, ion_buffer_offset)); - tensor_buffer->buffer_ = IonBuffer{ - .addr = ion_buffer_addr, - .fd = ion_buffer_fd, - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateManagedIonBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::IonBuffer::Alloc( - buffer_size, /*alignment=*/LITERT_HOST_MEMORY_BUFFER_ALIGNMENT); - if (!buffer) { - return 
Unexpected(buffer.Error()); - } - return CreateFromIonBuffer(tensor_type, buffer->addr, buffer->fd, buffer_size, - /*ion_buffer_offset=*/0, - litert::internal::IonBuffer::Free); -} - -Expected LiteRtTensorBufferT::CreateFromDmaBufBuffer( - const LiteRtRankedTensorType& tensor_type, void* dmabuf_buffer_addr, - int dmabuf_buffer_fd, size_t dmabuf_buffer_size, - size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator) { - if (!dmabuf_buffer_addr) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid DMA-BUF buffer address"); - } - if (dmabuf_buffer_fd < 0) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid DMA-BUF buffer fd"); - } - - Ptr tensor_buffer( - new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeDmaBuf, - dmabuf_buffer_size, dmabuf_buffer_offset)); - tensor_buffer->buffer_ = DmaBufBuffer{ - .addr = dmabuf_buffer_addr, - .fd = dmabuf_buffer_fd, - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected -LiteRtTensorBufferT::CreateManagedDmaBufBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::DmaBufBuffer::Alloc(buffer_size); - if (!buffer) { - return Unexpected(buffer.Error()); - } - return CreateFromDmaBufBuffer(tensor_type, buffer->addr, buffer->fd, - buffer_size, /*dmabuf_buffer_offset=*/0, - litert::internal::DmaBufBuffer::Free); -} - -Expected LiteRtTensorBufferT::CreateFromFastRpcBuffer( - const LiteRtRankedTensorType& tensor_type, void* fastrpc_buffer_addr, - int fastrpc_buffer_fd, size_t fastrpc_buffer_size, - size_t fastrpc_buffer_offset, LiteRtFastRpcDeallocator deallocator) { - if (!fastrpc_buffer_addr) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid FastRPC buffer address"); - } - if (fastrpc_buffer_fd < 0) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid FastRPC buffer fd"); - } - - Ptr tensor_buffer( - new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeFastRpc, - fastrpc_buffer_size, fastrpc_buffer_offset)); - tensor_buffer->buffer_ = FastRpcBuffer{ - .addr = fastrpc_buffer_addr, - .fd = fastrpc_buffer_fd, - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected -LiteRtTensorBufferT::CreateManagedFastRpcBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::FastRpcBuffer::Alloc(buffer_size); - if (!buffer) { - return Unexpected(buffer.Error()); - } - return CreateFromFastRpcBuffer(tensor_type, buffer->addr, buffer->fd, - buffer_size, /*fastrpc_buffer_offset=*/0, - litert::internal::FastRpcBuffer::Free); -} - -#if LITERT_HAS_OPENCL_SUPPORT -Expected LiteRtTensorBufferT::CreateFromOpenClBuffer( - const LiteRtRankedTensorType& tensor_type, cl_mem buffer, - size_t buffer_size, LiteRtOpenClDeallocator deallocator) { - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeOpenCl, buffer_size)); - tensor_buffer->buffer_.emplace( - buffer, buffer_size, deallocator); - return tensor_buffer; -} - -Expected -LiteRtTensorBufferT::CreateManagedOpenClBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::OpenClBuffer::Alloc(buffer_size); - if (!buffer) { - return Unexpected(buffer.Error()); - } - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, 
kLiteRtTensorBufferTypeOpenCl, buffer_size)); - tensor_buffer->buffer_.emplace( - std::move(*buffer)); - return tensor_buffer; -} -#endif // LITERT_HAS_OPENCL_SUPPORT - -Expected LiteRtTensorBufferT::CreateFromGlBuffer( - const LiteRtRankedTensorType& tensor_type, LiteRtGLenum target, - LiteRtGLuint id, size_t size_bytes, size_t offset, - LiteRtGlBufferDeallocator deallocator) { - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeGlBuffer, size_bytes)); - tensor_buffer->buffer_.emplace( - target, id, size_bytes, offset, deallocator); - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateManagedGlBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::GlBuffer::Alloc(buffer_size); - if (!buffer) { - return Unexpected(buffer.Error()); - } - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeGlBuffer, buffer_size)); - tensor_buffer->buffer_.emplace( - std::move(*buffer)); - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateFromGlTexture( - const LiteRtRankedTensorType& tensor_type, LiteRtGLenum target, - LiteRtGLuint id, LiteRtGLenum format, size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator) { - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeGlTexture, size_bytes)); - tensor_buffer->buffer_.emplace( - litert::internal::GlTexture(target, id, format, size_bytes, layer, - deallocator)); - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateManaged( - LiteRtTensorBufferType buffer_type, - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - switch (buffer_type) { - case kLiteRtTensorBufferTypeHostMemory: - return CreateManagedOnHostMemory(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeAhwb: - return CreateManagedAhwbBuffer(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeIon: - return CreateManagedIonBuffer(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeDmaBuf: - return CreateManagedDmaBufBuffer(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeFastRpc: - return CreateManagedFastRpcBuffer(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeOpenCl: { -#if LITERT_HAS_OPENCL_SUPPORT - return CreateManagedOpenClBuffer(tensor_type, buffer_size); -#else - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "OpenCL buffers are not supported."); -#endif // LITERT_HAS_OPENCL_SUPPORT - } - case kLiteRtTensorBufferTypeGlBuffer: { - return CreateManagedGlBuffer(tensor_type, buffer_size); - } - case kLiteRtTensorBufferTypeGlTexture: { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "LiteRT does not support managed GL textures."); - } - default: - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Unexpected tensor type"); - } -} - -Expected LiteRtTensorBufferT::IsValid() { - // Check for static dimensions. - for (auto i = 0; i < tensor_type_.layout.rank; ++i) { - if (tensor_type_.layout.dimensions[i] <= 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "TensorBuffer must have all static dimensions"); - } - } - - // Check for valid offset. - if (buffer_offset() >= buffer_size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Invalid buffer offset"); - } - - // Check for sufficient size. 
- if (auto num_bytes = litert::internal::GetNumPackedBytes(tensor_type_); - !num_bytes) { - return Unexpected(num_bytes.Error()); - } else if (*num_bytes > buffer_size() - buffer_offset()) { - const std::string error_message = absl::StrFormat( - "Insufficient buffer size: Required %d bytes, actual size %d bytes", - *num_bytes, buffer_size() - buffer_offset()); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, error_message); - } - - // Check for proper alignment. - if (buffer_type() == kLiteRtTensorBufferTypeHostMemory) { - auto host_buffer = GetHostBuffer(); - if (!host_buffer) { - return Unexpected(host_buffer.Error()); - } - if (reinterpret_cast(*host_buffer) % - LITERT_HOST_MEMORY_BUFFER_ALIGNMENT) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unaligned host memory pointer"); - } - } - - return {}; -} - -Expected LiteRtTensorBufferT::GetHostBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeHostMemory) { - return std::get(buffer_).addr; - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeHostMemory), - BufferTypeToString(buffer_type_))); -} - -Expected LiteRtTensorBufferT::GetAhwbBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeAhwb) { - return std::get(buffer_).ahwb; - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeAhwb), - BufferTypeToString(buffer_type_))); -} - -Expected> LiteRtTensorBufferT::GetIonBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeIon) { - auto buffer = std::get(buffer_); - return std::make_pair(buffer.addr, buffer.fd); - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeIon), - BufferTypeToString(buffer_type_))); -} - -Expected> LiteRtTensorBufferT::GetDmaBufBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeDmaBuf) { - auto buffer = std::get(buffer_); - return std::make_pair(buffer.addr, buffer.fd); - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeDmaBuf), - BufferTypeToString(buffer_type_))); -} - -Expected> LiteRtTensorBufferT::GetFastRpcBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeFastRpc) { - auto buffer = std::get(buffer_); - return std::make_pair(buffer.addr, buffer.fd); - } - - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeFastRpc), - BufferTypeToString(buffer_type_))); -} - -#if LITERT_HAS_OPENCL_SUPPORT -Expected -LiteRtTensorBufferT::GetOpenClBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeOpenCl) { - return &std::get(buffer_); - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeOpenCl), - BufferTypeToString(buffer_type_))); -} -#endif // LITERT_HAS_OPENCL_SUPPORT - -Expected LiteRtTensorBufferT::GetGlTexture() { - if (buffer_type_ != kLiteRtTensorBufferTypeGlTexture) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unexpected tensor buffer type"); - } - return &std::get(buffer_); -} - -Expected LiteRtTensorBufferT::GetGlBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeGlBuffer) { - 
return &std::get(buffer_); - } -#if LITERT_HAS_AHWB_SUPPORT - if (buffer_type_ == kLiteRtTensorBufferTypeAhwb) { - if (auto it = memory_backed_buffers_.find(kLiteRtTensorBufferTypeGlBuffer); - it != memory_backed_buffers_.end()) { - BufferVariant& memory_backed_buffer = it->second; - return &std::get(memory_backed_buffer); - } - // Create a new GL buffer from the AHWB buffer if not found. - litert::internal::AhwbBuffer ahwb_buffer = { - .ahwb = std::get(buffer_).ahwb}; - - LITERT_ASSIGN_OR_RETURN( - litert::internal::GlBuffer gl_buffer_from_ahwb, - litert::internal::GlBuffer::AllocFromAhwbBuffer(ahwb_buffer)); - - auto [it, inserted] = memory_backed_buffers_.insert( - {kLiteRtTensorBufferTypeGlBuffer, std::move(gl_buffer_from_ahwb)}); - LITERT_RETURN_IF_ERROR( - inserted == true, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to insert GL buffer into memory backed buffers")); - return &std::get(it->second); - } -#endif // LITERT_HAS_AHWB_SUPPORT - - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeGlBuffer), - BufferTypeToString(buffer_type_))); -} - -Expected LiteRtTensorBufferT::Lock() { - if (event_ != nullptr) { - // Only AHWB supports waiting on an input sync fence when locking the - // buffer. For all other buffer types we wait here. - if (buffer_type() != kLiteRtTensorBufferTypeAhwb) { - LITERT_RETURN_IF_ERROR(event_->Wait(/*timeout_in_ms=*/-1)); - } - } - - switch (buffer_type()) { - case kLiteRtTensorBufferTypeHostMemory: - return *GetHostBuffer(); - case kLiteRtTensorBufferTypeAhwb: - return litert::internal::AhwbBuffer::Lock( - *GetAhwbBuffer(), event_ != nullptr ? event_.get() : nullptr); - case kLiteRtTensorBufferTypeIon: - return GetIonBuffer()->first; - case kLiteRtTensorBufferTypeDmaBuf: - return GetDmaBufBuffer()->first; - case kLiteRtTensorBufferTypeFastRpc: - return GetFastRpcBuffer()->first; - case kLiteRtTensorBufferTypeOpenCl: { -#if LITERT_HAS_OPENCL_SUPPORT - auto opencl_buffer = *GetOpenClBuffer(); - auto host_memory_ptr = opencl_buffer->Lock(); - if (host_memory_ptr.HasValue()) { - return Expected(host_memory_ptr.Value()); - } else { - return Unexpected(host_memory_ptr.Error()); - } -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenCL buffers are not supported"); -#endif // LITERT_HAS_OPENCL_SUPPORT - } - case kLiteRtTensorBufferTypeGlBuffer: { -#if LITERT_HAS_OPENGL_SUPPORT - auto gl_buffer = *GetGlBuffer(); - auto host_memory_ptr = gl_buffer->Lock(); - if (host_memory_ptr.HasValue()) { - return Expected(host_memory_ptr.Value()); - } else { - return Unexpected(host_memory_ptr.Error()); - } -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenGL buffers are not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT - } - default: - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unexpected tensor buffer type"); - } -} - -Expected LiteRtTensorBufferT::Unlock() { - switch (buffer_type()) { - case kLiteRtTensorBufferTypeAhwb: { - auto ahwb = std::get(buffer_).ahwb; - return litert::internal::AhwbBuffer::Unlock(ahwb); - } - case kLiteRtTensorBufferTypeOpenCl: { -#if LITERT_HAS_OPENCL_SUPPORT - auto opencl_buffer = *GetOpenClBuffer(); - return opencl_buffer->Unlock(); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenCL buffers are not supported"); -#endif // LITERT_HAS_OPENCL_SUPPORT - } - case kLiteRtTensorBufferTypeGlBuffer: { -#if LITERT_HAS_OPENGL_SUPPORT - auto gl_buffer = *GetGlBuffer(); 
- return gl_buffer->Unlock(); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenGL buffers are not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT - } - default: - return {}; - } -} diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h b/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h deleted file mode 100644 index f0c9d8085a7e42..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/event.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_texture.h" - -#if LITERT_HAS_OPENCL_SUPPORT -#include -#include "tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h" -#endif // LITERT_HAS_OPENCL_SUPPORT - -class LiteRtTensorBufferT { - public: - using Ptr = std::unique_ptr; - - ~LiteRtTensorBufferT(); - - // Make this class non-copiable because it includes raw pointers and resource - // handles. 
- LiteRtTensorBufferT(const LiteRtTensorBufferT&) = delete; - LiteRtTensorBufferT(LiteRtTensorBufferT&&) = delete; - LiteRtTensorBufferT& operator=(const LiteRtTensorBufferT&) = delete; - LiteRtTensorBufferT& operator=(LiteRtTensorBufferT&&) = delete; - - static litert::Expected CreateFromHostMemory( - const LiteRtRankedTensorType& tensor_type, - absl::Span host_memory, - LiteRtHostMemoryDeallocator deallocator = nullptr); - - static litert::Expected CreateFromAhwb( - const LiteRtRankedTensorType& tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset, LiteRtAhwbDeallocator deallocator = nullptr); - - static litert::Expected CreateFromIonBuffer( - const LiteRtRankedTensorType& tensor_type, void* ion_buffer_addr, - int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, - LiteRtIonDeallocator deallocator = nullptr); - - static litert::Expected CreateFromDmaBufBuffer( - const LiteRtRankedTensorType& tensor_type, void* dmabuf_buffer_addr, - int dmabuf_buffer_fd, size_t dmabuf_buffer_size, - size_t dmabuf_buffer_offset, - LiteRtDmaBufDeallocator deallocator = nullptr); - - static litert::Expected CreateFromFastRpcBuffer( - const LiteRtRankedTensorType& tensor_type, void* fastrpc_buffer_addr, - int fastrpc_buffer_fd, size_t fastrpc_buffer_size, - size_t fastrpc_buffer_offset, - LiteRtFastRpcDeallocator deallocator = nullptr); - -#if LITERT_HAS_OPENCL_SUPPORT - static litert::Expected CreateFromOpenClBuffer( - const LiteRtRankedTensorType& tensor_type, cl_mem buffer, - size_t opencl_buffer_size, LiteRtOpenClDeallocator deallocator = nullptr); -#endif // LITERT_HAS_OPENCL_SUPPORT - - static litert::Expected CreateFromGlBuffer( - const LiteRtRankedTensorType& tensor_type, LiteRtGLenum target, - LiteRtGLuint id, size_t size_bytes, size_t offset, - LiteRtGlBufferDeallocator deallocator = nullptr); - static litert::Expected CreateFromGlTexture( - const LiteRtRankedTensorType& tensor_type, LiteRtGLenum target, - LiteRtGLuint id, LiteRtGLenum format, size_t size_bytes, - LiteRtGLint layer, LiteRtGlTextureDeallocator deallocator = nullptr); - - static litert::Expected CreateManaged( - LiteRtTensorBufferType buffer_type, - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - LiteRtRankedTensorType tensor_type() const { return tensor_type_; } - LiteRtTensorBufferType buffer_type() const { return buffer_type_; } - size_t buffer_size() const { return buffer_size_; } - size_t buffer_offset() const { return buffer_offset_; } - - bool HasEvent() const { return event_ != nullptr; } - - litert::Expected GetEvent() const { - if (!HasEvent()) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "TensorBuffer has no event"); - } - return event_.get(); - } - - void SetEvent(LiteRtEventT* e) { - // Take ownership of the event. - event_ = std::unique_ptr(e); - } - void ClearEvent() { event_ = nullptr; } - - litert::Expected GetHostBuffer(); - litert::Expected GetAhwbBuffer(); - litert::Expected> GetIonBuffer(); - litert::Expected> GetDmaBufBuffer(); - litert::Expected> GetFastRpcBuffer(); -#if LITERT_HAS_OPENCL_SUPPORT - litert::Expected GetOpenClBuffer(); -#endif // LITERT_HAS_OPENCL_SUPPORT - litert::Expected GetGlBuffer(); - litert::Expected GetGlTexture(); - - litert::Expected Lock(); - litert::Expected Unlock(); - - // Used to duplicate the current tensor buffer. Internally it increases - // reference count to the underlying buffer. - void Duplicate() const { Ref(); } - - // Increments reference count by one. 
- void Ref() const { ref_.fetch_add(1, std::memory_order_relaxed); } - - // Decrements reference count by one. If the count remains - // positive, returns false. When the count reaches zero, returns - // true. - bool Unref() const { - if (ref_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - return true; - } - return false; - } - - // Gets the current reference count. - int RefCount() const { return ref_.load(std::memory_order_relaxed); } - - private: - struct HostBuffer { - void* addr; - LiteRtHostMemoryDeallocator deallocator; - }; - - struct AhwbBuffer { - AHardwareBuffer* ahwb; - LiteRtAhwbDeallocator deallocator; - }; - - struct IonBuffer { - void* addr; - int fd; - LiteRtIonDeallocator deallocator; - }; - - struct DmaBufBuffer { - void* addr; - int fd; - LiteRtDmaBufDeallocator deallocator; - }; - - struct FastRpcBuffer { - void* addr; - int fd; - LiteRtFastRpcDeallocator deallocator; - }; - - using BufferVariant = - std::variant; - - LiteRtTensorBufferT(const LiteRtRankedTensorType& tensor_type, - LiteRtTensorBufferType buffer_type, size_t buffer_size, - size_t buffer_offset = 0); - - static litert::Expected CreateManagedOnHostMemory( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedAhwbBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedIonBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedDmaBufBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedFastRpcBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedOpenClBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedGlBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - litert::Expected IsValid(); - - LiteRtRankedTensorType tensor_type_; - std::vector> dimensions_; - std::vector> strides_; - LiteRtTensorBufferType buffer_type_; - size_t buffer_size_; - size_t buffer_offset_; - BufferVariant buffer_; - std::unique_ptr event_; - mutable std::atomic_int_fast32_t ref_; - // A map of memory backed buffers. Memory backed buffers are backed by the - // memory of buffer_. For example, a GL buffer can be backed by the memory of - // an AHWB buffer. - absl::flat_hash_map - memory_backed_buffers_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.cc b/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.cc deleted file mode 100644 index aac9c2b37f0af5..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.cc +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
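The Duplicate/Ref/Unref members in tensor_buffer.h above implement intrusive reference counting: relaxed increments, an acquire-release decrement, and deletion by whichever caller observes the count reach zero. A hedged sketch of the caller-side pattern follows; the two function names are illustrative shims, not confirmed C API.

// Hypothetical C-API shims over the intrusive refcount shown above.
void LiteRtDuplicateTensorBuffer(LiteRtTensorBufferT* buffer) {
  buffer->Duplicate();  // Relaxed increment is enough to share ownership.
}

void LiteRtDestroyTensorBuffer(LiteRtTensorBufferT* buffer) {
  // Unref() returns true only when this call dropped the count to zero; the
  // acquire-release decrement orders prior writes before the delete.
  if (buffer->Unref()) {
    delete buffer;
  }
}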
- -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h" - -#include "absl/strings/str_format.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h" - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -#include -#include -#include -#include - -#include - -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" - -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" - -namespace litert { -namespace internal { - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_OPENCL_SUPPORT - -// TODO(b/383176413): Add gl-cl interop extension. -Expected CopyGlToCl(GlBuffer& src, OpenClBuffer& dest) { - if (src.target() != GL_SHADER_STORAGE_BUFFER) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported GL target for conversion to OpenCL"); - } - size_t cl_size = dest.size_bytes(); - if (src.bytes_size() != cl_size) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "GL buffer size does not match OpenCL size"); - } - LITERT_ASSIGN_OR_RETURN(void* host_src, src.Lock()); - LITERT_ASSIGN_OR_RETURN(void* host_dest, dest.Lock()); - std::memcpy(host_dest, host_src, src.bytes_size()); - LITERT_RETURN_IF_ERROR(dest.Unlock()); - LITERT_RETURN_IF_ERROR(src.Unlock()); - return {}; -} - -Expected TensorBufferConvertGlToCl( - LiteRtTensorBufferT& tensor_buffer_gl) { - // Create a new CL tensor buffer. - LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT::Ptr tensor_buffer_cl, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeOpenCl, - tensor_buffer_gl.tensor_type(), - tensor_buffer_gl.buffer_size())); - LITERT_ASSIGN_OR_RETURN(OpenClBuffer * cl_buffer, - tensor_buffer_cl->GetOpenClBuffer()); - LITERT_ASSIGN_OR_RETURN(GlBuffer * gl_buffer, tensor_buffer_gl.GetGlBuffer()); - CopyGlToCl(*gl_buffer, *cl_buffer); - return tensor_buffer_cl; -} -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_CL_SUPPORT - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -Expected CopyGlToAhwb(GlBuffer& src, AhwbBuffer& dest) { - if (src.target() != GL_SHADER_STORAGE_BUFFER) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported GL target for conversion to AHWB"); - } - LITERT_ASSIGN_OR_RETURN(size_t ahwb_size, AhwbBuffer::GetSize(dest.ahwb)); - if (src.bytes_size() != ahwb_size) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "GL buffer size does not match AHWB size"); - } - LITERT_ASSIGN_OR_RETURN(void* host_src, src.Lock()); - LITERT_ASSIGN_OR_RETURN(void* host_dest, AhwbBuffer::Lock(dest.ahwb)); - std::memcpy(host_dest, host_src, src.bytes_size()); - LITERT_RETURN_IF_ERROR(AhwbBuffer::Unlock(dest.ahwb)); - LITERT_RETURN_IF_ERROR(src.Unlock()); - return {}; -} - -Expected TensorBufferConvertGlToAhwb( - LiteRtTensorBufferT& tensor_buffer_gl) { - // Create a new AHWB tensor buffer. 
- LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT::Ptr tensor_buffer_ahwb, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeAhwb, - tensor_buffer_gl.tensor_type(), - tensor_buffer_gl.buffer_size())); - LITERT_ASSIGN_OR_RETURN(AHardwareBuffer * ahwb, - tensor_buffer_ahwb->GetAhwbBuffer()); - AhwbBuffer ahwb_buffer{.ahwb = ahwb}; - LITERT_ASSIGN_OR_RETURN(GlBuffer * gl_buffer, tensor_buffer_gl.GetGlBuffer()); - CopyGlToAhwb(*gl_buffer, ahwb_buffer); - return tensor_buffer_ahwb; -} -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT - -#if LITERT_HAS_OPENGL_SUPPORT -Expected CopyHostToGl(void* host_src, GlBuffer& dest) { - LITERT_ASSIGN_OR_RETURN(void* host_dest, dest.Lock()); - std::memcpy(host_dest, host_src, dest.bytes_size()); - return {}; -} - -Expected TensorBufferConvertHostToGl( - LiteRtTensorBufferT& tensor_buffer_host) { - // Create a new GL tensor buffer. - LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT::Ptr tensor_buffer_gl, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeGlBuffer, - tensor_buffer_host.tensor_type(), - tensor_buffer_host.buffer_size())); - LITERT_ASSIGN_OR_RETURN(void* host_memory, - tensor_buffer_host.GetHostBuffer()); - LITERT_ASSIGN_OR_RETURN(GlBuffer * gl_buffer, - tensor_buffer_gl->GetGlBuffer()); - CopyHostToGl(host_memory, *gl_buffer); - return tensor_buffer_gl; -} -#endif - -Expected TensorBufferConvertHostTo( - LiteRtTensorBufferType buffer_type, LiteRtTensorBufferT& tensor_buffer) { - switch (buffer_type) { - case kLiteRtTensorBufferTypeGlBuffer: -#if LITERT_HAS_OPENGL_SUPPORT - return TensorBufferConvertHostToGl(tensor_buffer); -#else - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); -#endif - default: - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); - } -} - -Expected TensorBufferConvertGlTo( - LiteRtTensorBufferType buffer_type, LiteRtTensorBufferT& tensor_buffer) { - switch (buffer_type) { - case kLiteRtTensorBufferTypeAhwb: -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT - return TensorBufferConvertGlToAhwb(tensor_buffer); -#else - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); -#endif - case kLiteRtTensorBufferTypeOpenCl: -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_OPENCL_SUPPORT - return TensorBufferConvertGlToCl(tensor_buffer); -#else - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); -#endif - default: - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); - } -} - -Expected TensorBufferConvertTo( - LiteRtTensorBufferType buffer_type, LiteRtTensorBufferT& tensor_buffer) { - switch (tensor_buffer.buffer_type()) { - case kLiteRtTensorBufferTypeHostMemory: - return TensorBufferConvertHostTo(buffer_type, tensor_buffer); - case kLiteRtTensorBufferTypeGlBuffer: - return 
TensorBufferConvertGlTo(buffer_type, tensor_buffer); - default: - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); - } -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h b/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h deleted file mode 100644 index a3ebe1826303d8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_CONVERSION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_CONVERSION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" - -namespace litert::internal { - -// Converts the given tensor buffer to the specified buffer type. A new tensor -// buffer is created and returned. This function locks/unlocks the tensor buffer -// and will involve a copy. -// TODO(b/383176413): Investigate zero/fast-copy conversions. -Expected TensorBufferConvertTo( - LiteRtTensorBufferType buffer_type, LiteRtTensorBufferT& tensor_buffer); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_CONVERSION_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion_test.cc b/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion_test.cc deleted file mode 100644 index d358d1c1c2f785..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion_test.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
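TensorBufferConvertTo, declared just above, always allocates a new managed buffer and stages the data through host-visible memory. A short usage sketch, assuming host_buf is a valid LiteRtTensorBufferT::Ptr created with CreateManaged(kLiteRtTensorBufferTypeHostMemory, ...):

// Sketch: requesting a GL copy of a host-backed tensor buffer.
litert::Expected<LiteRtTensorBufferT::Ptr> gl_buf =
    litert::internal::TensorBufferConvertTo(kLiteRtTensorBufferTypeGlBuffer,
                                            *host_buf);
if (!gl_buf) {
  // Unsupported pairs fail with kLiteRtStatusErrorRuntimeFailure and a
  // "%s -> %s" message naming the source and destination buffer types.
}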
- -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h" - -#include -#include -#include - -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace { - -constexpr const float kTensorData[] = {10, 20, 30, 40}; - -constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) / - sizeof(kTensorData[0])}; - -constexpr const LiteRtRankedTensorType kTensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTensorDimensions)}; - -TEST(TensorBufferConversionTest, HostToGl) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr litert_env, - LiteRtEnvironmentT::CreateWithOptions({})); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_host, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeHostMemory, - kTensorType, sizeof(kTensorData))); - // Write data to the host memory. - LITERT_ASSERT_OK_AND_ASSIGN(void* host_memory, - tensor_buffer_host->GetHostBuffer()); - std::memcpy(host_memory, kTensorData, sizeof(kTensorData)); - -#if LITERT_HAS_OPENGL_SUPPORT - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_gl, - litert::internal::TensorBufferConvertTo(kLiteRtTensorBufferTypeGlBuffer, - *tensor_buffer_host)); - - // Ensure that data was copied correctly from host to GL. - LITERT_ASSERT_OK_AND_ASSIGN(void* host_gl, tensor_buffer_gl->Lock()); - ASSERT_EQ(std::memcmp(host_gl, kTensorData, sizeof(kTensorData)), 0); -#else - // Since GL support is not enabled, the conversion should fail. - EXPECT_FALSE(litert::internal::TensorBufferConvertTo( - kLiteRtTensorBufferTypeGlBuffer, *tensor_buffer_host)); -#endif -} - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -TEST(TensorBufferConversionTest, GlToAhwb) { - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_gl, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeGlBuffer, - kTensorType, sizeof(kTensorData))); - // Write data to the GL buffer. - LITERT_ASSERT_OK_AND_ASSIGN(litert::internal::GlBuffer * gl_buffer, - tensor_buffer_gl->GetGlBuffer()); - LITERT_ASSERT_OK_AND_ASSIGN(float* data, gl_buffer->Lock()); - std::memcpy(data, kTensorData, sizeof(kTensorData)); - gl_buffer->Unlock(); - - // Convert. - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_ahwb, - litert::internal::TensorBufferConvertTo(kLiteRtTensorBufferTypeAhwb, - *tensor_buffer_gl)); - // Ensure that data was copied correctly from Gl to Ahwb. 
- LITERT_ASSERT_OK_AND_ASSIGN(void* host_ahwb, tensor_buffer_ahwb->Lock()); - ASSERT_EQ(std::memcmp(host_ahwb, kTensorData, sizeof(kTensorData)), 0); -} -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_OPENCL_SUPPORT -TEST(TensorBufferConversionTest, GlToCl) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr litert_env, - LiteRtEnvironmentT::CreateWithOptions({})); - if (!litert::internal::OpenClBuffer::IsSupported()) { - GTEST_SKIP() << "OpenCL buffers are not supported on this platform; " - "skipping the test"; - } - - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_gl, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeGlBuffer, - kTensorType, sizeof(kTensorData))); - // Write data to the GL buffer. - LITERT_ASSERT_OK_AND_ASSIGN(litert::internal::GlBuffer * gl_buffer, - tensor_buffer_gl->GetGlBuffer()); - LITERT_ASSERT_OK_AND_ASSIGN(float* data, gl_buffer->Lock()); - std::memcpy(data, kTensorData, sizeof(kTensorData)); - gl_buffer->Unlock(); - - // Convert. - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_cl, - litert::internal::TensorBufferConvertTo(kLiteRtTensorBufferTypeOpenCl, - *tensor_buffer_gl)); - - // Ensure that data was copied correctly from Gl to CL. - LITERT_ASSERT_OK_AND_ASSIGN(void* host_cl, tensor_buffer_cl->Lock()); - ASSERT_EQ(std::memcmp(host_cl, kTensorData, sizeof(kTensorData)), 0); -} -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_OPENCL_SUPPORT - -} // namespace diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h b/tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h deleted file mode 100644 index 04f461966889b7..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_REQUIREMENTS_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_REQUIREMENTS_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-
-class LiteRtTensorBufferRequirementsT {
- public:
-  LiteRtTensorBufferRequirementsT(
-      int num_supported_tensor_buffer_types,
-      const LiteRtTensorBufferType* supported_tensor_buffer_types,
-      size_t buffer_size, std::vector<uint32_t>&& strides)
-      : supported_buffer_types_(
-            supported_tensor_buffer_types,
-            supported_tensor_buffer_types + num_supported_tensor_buffer_types),
-        buffer_size_(buffer_size),
-        strides_(std::move(strides)) {}
-  const std::vector<LiteRtTensorBufferType>& SupportedBufferTypes() const {
-    return supported_buffer_types_;
-  }
-  size_t BufferSize() const { return buffer_size_; }
-  const std::vector<uint32_t>& Strides() const { return strides_; }
-
- private:
-  std::vector<LiteRtTensorBufferType> supported_buffer_types_;
-  size_t buffer_size_;
-  // Stride for each dimension.
-  std::vector<uint32_t> strides_;
-};
-
-#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_REQUIREMENTS_H_
diff --git a/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc b/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc
deleted file mode 100644
index 37e419c68b00ca..00000000000000
--- a/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
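LiteRtTensorBufferRequirementsT above is a plain value holder. The sketch below shows a hypothetical consumer publishing its requirements; treating the first listed type as the preferred one is an assumption made here for illustration, not a documented contract.

// Sketch: hypothetical consumer-side construction of requirements.
const LiteRtTensorBufferType kSupported[] = {
    kLiteRtTensorBufferTypeAhwb, kLiteRtTensorBufferTypeHostMemory};
LiteRtTensorBufferRequirementsT reqs(
    /*num_supported_tensor_buffer_types=*/2, kSupported,
    /*buffer_size=*/4 * sizeof(float),
    /*strides=*/{});
// SupportedBufferTypes() preserves construction order, so a first-match
// scan over it can double as a preference order.
LiteRtTensorBufferType preferred = reqs.SupportedBufferTypes().front();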
- -#include "tensorflow/lite/experimental/litert/runtime/tfl_utils.h" - -#include -#include - -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -namespace litert { -namespace internal { - -Expected ConvertElementType(TfLiteType tfl_type) { - switch (tfl_type) { - case kTfLiteNoType: - return ElementType::None; - case kTfLiteBool: - return ElementType::Bool; - case kTfLiteInt4: - return ElementType::Int4; - case kTfLiteInt8: - return ElementType::Int8; - case kTfLiteInt16: - return ElementType::Int16; - case kTfLiteInt32: - return ElementType::Int32; - case kTfLiteInt64: - return ElementType::Int64; - case kTfLiteUInt8: - return ElementType::UInt8; - case kTfLiteUInt16: - return ElementType::UInt16; - case kTfLiteUInt32: - return ElementType::UInt32; - case kTfLiteUInt64: - return ElementType::UInt64; - case kTfLiteFloat16: - return ElementType::Float16; - case kTfLiteBFloat16: - return ElementType::BFloat16; - case kTfLiteFloat32: - return ElementType::Float32; - case kTfLiteFloat64: - return ElementType::Float64; - case kTfLiteComplex64: - return ElementType::Complex64; - case kTfLiteComplex128: - return ElementType::Complex128; - case kTfLiteResource: - return ElementType::TfResource; - case kTfLiteString: - return ElementType::TfString; - case kTfLiteVariant: - return ElementType::TfVariant; - default: - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Unsupported TfLiteType"); - } -} - -Expected ConvertTensorType( - const TfLiteOpaqueTensor* tfl_opaque_tensor) { - auto tfl_type = TfLiteOpaqueTensorType(tfl_opaque_tensor); - auto element_type = ConvertElementType(tfl_type); - if (!element_type) { - return Unexpected(element_type.Error()); - } - - size_t rank = TfLiteOpaqueTensorNumDims(tfl_opaque_tensor); - Dimensions dimensions(rank); - for (size_t i = 0; i < rank; ++i) { - dimensions[i] = TfLiteOpaqueTensorDim(tfl_opaque_tensor, i); - } - - return RankedTensorType(*element_type, Layout(std::move(dimensions))); -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/tfl_utils.h b/tensorflow/lite/experimental/litert/runtime/tfl_utils.h deleted file mode 100644 index 8874c7535f3d6a..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tfl_utils.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
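ConvertTensorType above composes ConvertElementType with a dimension walk over the opaque tensor. A brief usage sketch; tfl_tensor is assumed to come from a delegate kernel's inputs.

// Sketch: mapping an opaque TFLite tensor into LiteRT terms.
litert::Expected<litert::RankedTensorType> ranked_type =
    litert::internal::ConvertTensorType(tfl_tensor);
if (ranked_type) {
  // A kTfLiteFloat32 tensor arrives as ElementType::Float32, with each
  // TfLiteOpaqueTensorDim() result copied into the LiteRT Layout.
}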
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TFL_UTILS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TFL_UTILS_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -struct TfLiteOpaqueTensor; - -namespace litert::internal { - -Expected ConvertElementType(TfLiteType tfl_type); - -Expected ConvertTensorType( - const TfLiteOpaqueTensor* tfl_opaque_tensor); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TFL_UTILS_H_ diff --git a/tensorflow/lite/experimental/litert/test/BUILD b/tensorflow/lite/experimental/litert/test/BUILD deleted file mode 100644 index 14b60339d2726b..00000000000000 --- a/tensorflow/lite/experimental/litert/test/BUILD +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:tfl_model_gen.bzl", "tfl_model_gen") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - # copybara:uncomment "//third_party/mediapipe/calculators/tensor:__subpackages__", - # copybara:uncomment "//third_party/odml/litert:__subpackages__", - "//tensorflow/lite/experimental/litert:__subpackages__", - ], -) - -tfl_model_gen( - name = "mlir_test_data", - srcs = glob(["testdata/*.mlir"]), -) - -filegroup( - name = "tflite_test_data", - srcs = glob(["testdata/*.tflite"]), -) - -cc_library( - name = "common", - testonly = 1, - srcs = [ - "common.cc", - ], - hdrs = [ - "common.h", - ], - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/core/model:model_buffer", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform", - ], -) - -cc_library( - name = "simple_model", - testonly = 1, - hdrs = [ - "testdata/simple_model_test_vectors.h", - ], - data = [ - "testdata/simple_model.tflite", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - ], -) - -cc_library( - name = "simple_model_npu", - testonly = 1, - srcs = [], - hdrs = [ - "testdata/simple_model_test_vectors.h", - ], - data = [ - "testdata/simple_model_google_tensor.bin", - 
"testdata/simple_model_mtk.bin", - "testdata/simple_model_npu.tflite", - "testdata/simple_model_qualcomm.bin", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - ], -) - -cc_library( - name = "simple_cascade_model_npu", - testonly = 1, - srcs = [], - hdrs = [ - "testdata/simple_model_test_vectors.h", - ], - data = [ - "testdata/simple_cascade_model_npu.tflite", - "testdata/simple_model_google_tensor.bin", - "testdata/simple_model_mtk.bin", - "testdata/simple_model_qualcomm.bin", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - ], -) - -cc_library( - name = "test_models", - hdrs = ["test_models.h"], - deps = [ - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "matchers", - testonly = True, - hdrs = ["matchers.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - ], -) - -cc_test( - name = "matchers_test", - srcs = ["matchers_test.cc"], - deps = [ - ":matchers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_googletest//:gtest_main", - ], -) - -# Use this library if you want to enforce an OSS environment for your test. -cc_library( - name = "matchers_oss", - testonly = True, - hdrs = ["matchers.h"], - defines = ["LITERT_DEFINE_GTEST_STATUS_PRINTER"], - tags = ["avoid_dep"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - ], -) - -exports_files(srcs = [ - "testdata/mobilenet_v2_1.0_224.tflite", - "testdata/simple_model_google_tensor.bin", - "testdata/simple_model_mtk.bin", - "testdata/simple_model_qualcomm.bin", -]) diff --git a/tensorflow/lite/experimental/litert/test/common.cc b/tensorflow/lite/experimental/litert/test/common.cc deleted file mode 100644 index 8744bcc14cfbed..00000000000000 --- a/tensorflow/lite/experimental/litert/test/common.cc +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/test/common.h" - -#include -#include -#include // NOLINT -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tsl/platform/platform.h" - -namespace litert::testing { - -Expected UniqueTestDirectory::Create() { - constexpr size_t kMaxTries = 1000; - ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); - - // We don't want multiple threads to create the same directory. - absl::MutexLock l(&mutex); - - auto tmp_dir = std::filesystem::temp_directory_path(); - std::random_device dev; - std::mt19937 prng(dev()); - std::uniform_int_distribution rand(0); - std::stringstream ss; - - for (auto i = 0; i < kMaxTries; ++i) { - ss.clear(); - ss << std::hex << rand(prng); - auto path = tmp_dir / ss.str(); - if (std::filesystem::create_directory(path)) { - LITERT_LOG(LITERT_INFO, "Created unique temporary directory %s", - path.c_str()); - return UniqueTestDirectory(path); - } - } - - return Error(kLiteRtStatusErrorRuntimeFailure, - "Could not create a unique temporary directory"); -} - -UniqueTestDirectory::~UniqueTestDirectory() { - std::filesystem::remove_all(tmpdir_); -} - -std::string GetTestFilePath(absl::string_view filename) { - static constexpr absl::string_view kTestDataDir = - "tensorflow/lite/experimental/litert/" - "test/testdata/"; - - if constexpr (!tsl::kIsOpenSource) { - return internal::Join({"third_party", kTestDataDir, filename}); - } else { - return internal::Join({kTestDataDir, filename}); - } -} - -std::string GetTfliteFilePath(absl::string_view filename) { - static constexpr absl::string_view kTestDataDir = "tensorflow/lite/"; - - if constexpr (!tsl::kIsOpenSource) { - return internal::Join({"third_party", kTestDataDir, filename}); - } else { - return internal::Join({kTestDataDir, filename}); - } -} - -std::string GetLiteRtPath(absl::string_view rel_path) { - static constexpr absl::string_view kLiteRtRoot = - "tensorflow/lite/experimental/litert/"; - - if constexpr (!tsl::kIsOpenSource) { - return internal::Join({"third_party", kLiteRtRoot, rel_path}); - } else { - return internal::Join({kLiteRtRoot, rel_path}); - } -} - -Model LoadTestFileModel(absl::string_view filename) { - return *Model::CreateFromFile(GetTestFilePath(filename)); -} - -Expected TflRuntime::CreateFromFlatBuffer( - internal::FlatbufferWrapper::Ptr flatbuffer) { - ::tflite::Interpreter::Ptr interp; - tflite::ops::builtin::BuiltinOpResolver resolver; - tflite::InterpreterBuilder(flatbuffer->FlatbufferModel(), resolver)(&interp); - if (interp == nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure); - } - return TflRuntime::Ptr( - new TflRuntime(std::move(flatbuffer), std::move(interp))); -} - -} // namespace litert::testing diff --git a/tensorflow/lite/experimental/litert/test/common.h 
b/tensorflow/lite/experimental/litert/test/common.h deleted file mode 100644 index 6b7b20e989b7a3..00000000000000 --- a/tensorflow/lite/experimental/litert/test/common.h +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ - -#include <memory> -#include <string> - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/interpreter.h" - -namespace litert::testing { - -// A cross-platform compatible replacement for testing::UniqueTestDirectory. -class UniqueTestDirectory { - public: - static Expected<UniqueTestDirectory> Create(); - ~UniqueTestDirectory(); - - UniqueTestDirectory(const UniqueTestDirectory&) = delete; - UniqueTestDirectory(UniqueTestDirectory&&) = default; - UniqueTestDirectory& operator=(const UniqueTestDirectory&) = delete; - UniqueTestDirectory& operator=(UniqueTestDirectory&&) = default; - - absl::string_view Str() const { return tmpdir_; } - - private: - explicit UniqueTestDirectory(std::string&& tmpdir) - : tmpdir_(std::move(tmpdir)) {} - std::string tmpdir_; -}; - -// Gets the path to the given filename in the testdata directory. -std::string GetTestFilePath(absl::string_view filename); - -// Gets a path to the given filename in the tflite directory. -std::string GetTfliteFilePath(absl::string_view filename); - -// Gets a full path given a path relative to the litert directory. -std::string GetLiteRtPath(absl::string_view rel_path); - -Model LoadTestFileModel(absl::string_view filename); - -class TflRuntime { - public: - using Ptr = std::unique_ptr<TflRuntime>; - - static Expected<Ptr> CreateFromFlatBuffer( - internal::FlatbufferWrapper::Ptr flatbuffer); - - ::tflite::Interpreter& Interpreter() { return *interpreter_; } - - const internal::FlatbufferWrapper& Flatbuffer() const { return *flatbuffer_; } - - private: - TflRuntime(internal::FlatbufferWrapper::Ptr flatbuffer, - ::tflite::Interpreter::Ptr interpreter) - : flatbuffer_(std::move(flatbuffer)), - interpreter_(std::move(interpreter)) {} - - internal::FlatbufferWrapper::Ptr flatbuffer_; - ::tflite::Interpreter::Ptr interpreter_; -}; - -inline Expected<TflRuntime::Ptr> MakeRuntimeFromTestFile( - absl::string_view filename) { - auto flatbuffer = - internal::FlatbufferWrapper::CreateFromTflFile(GetTestFilePath(filename)); - if (!flatbuffer) { - return flatbuffer.Error(); - } - return TflRuntime::CreateFromFlatBuffer(std::move(*flatbuffer)); - } - -inline Expected<TflRuntime::Ptr> MakeRuntimeFromTestFileWithNpuModel( - absl::string_view filename, absl::string_view npu_filename) { - auto buf = internal::GetModelBufWithByteCode(GetTestFilePath(filename), - GetTestFilePath(npu_filename)); - if (!buf) { - return buf.Error(); - } - auto flatbuffer = - internal::FlatbufferWrapper::CreateFromBuffer(std::move(*buf)); - if (!flatbuffer) { - return flatbuffer.Error(); - } - return TflRuntime::CreateFromFlatBuffer(std::move(*flatbuffer)); -} - -} // namespace litert::testing - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/test/matchers.h b/tensorflow/lite/experimental/litert/test/matchers.h deleted file mode 100644 index 7db2c43435c2ca..00000000000000 --- a/tensorflow/lite/experimental/litert/test/matchers.h +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_MATCHERS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_MATCHERS_H_ - -#include <optional> -#include <ostream> -#include <string> -#include <utility> - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -// Is equivalent to `ASSERT_THAT(expr, testing::litert::IsOk())` -#define LITERT_ASSERT_OK(EXPR) ASSERT_THAT((EXPR), ::testing::litert::IsOk()) - -// Is equivalent to `EXPECT_THAT(expr, testing::litert::IsOk())` -#define LITERT_EXPECT_OK(EXPR) EXPECT_THAT((EXPR), ::testing::litert::IsOk()) - -// Checks that the result of `EXPR` (a `litert::Expected` object) is not an -// error and assigns the value it holds to `DECL` as if: -// ``` -// DECL = std::move(EXPR.Value()); -// ``` -// -// ```cpp -// Expected<Something> BuildSomething(); -// -// // Will fail the test if `BuildSomething()`'s returned value holds an error.
-// // Otherwise defines and assigns the returned `Something` value to `smth`. -// LITERT_ASSERT_OK_AND_ASSIGN(Something smth, BuildSomething()); -// ``` -#define LITERT_ASSERT_OK_AND_ASSIGN(DECL, EXPR) \ - LITERT_ASSERT_OK_AND_ASSIGN_HELPER2(__LINE__, DECL, EXPR) - -#define LITERT_ASSERT_OK_AND_ASSIGN_HELPER1(LINE, DECL, EXPR) \ - auto&& litert_expected_value_or_error_##LINE = (EXPR); \ - LITERT_ASSERT_OK(litert_expected_value_or_error_##LINE); \ - DECL = std::move(litert_expected_value_or_error_##LINE.Value()); - -#define LITERT_ASSERT_OK_AND_ASSIGN_HELPER2(LINE, DECL, EXPR) \ - LITERT_ASSERT_OK_AND_ASSIGN_HELPER1(LINE, DECL, EXPR) - -namespace testing::litert { - -// Matches `litert::Expected` values that hold a success value and -// `LiteRtStatusOk`. -// -// See `IsOk()` function below for usage examples. -class IsOkMatcher { - public: - // Implicitly builds and wraps the matcher implementation in a GTest - // Matcher object. - template <typename V> - // NOLINTNEXTLINE(*-explicit-constructor): This needs to be implicit. - operator testing::Matcher<V>() const { - return testing::Matcher<V>(new Impl<V>()); - } - - template <typename V> - class Impl : public testing::MatcherInterface<V> { - template <typename T> - bool MatchAndExplainImpl(const ::litert::Expected<T>& value, - testing::MatchResultListener* listener) const { - return value.HasValue(); - } - - bool MatchAndExplainImpl(const ::litert::Unexpected& unexpected, - testing::MatchResultListener* listener) const { - return false; - } - - bool MatchAndExplainImpl(const ::litert::Error& e, - testing::MatchResultListener* listener) const { - return false; - } - - bool MatchAndExplainImpl(const LiteRtStatus& status, - testing::MatchResultListener* listener) const { - if (status != kLiteRtStatusOk) { - *listener << "status is " << LiteRtGetStatusString(status); - return false; - } - return true; - } - - public: - using is_gtest_matcher = void; - - bool MatchAndExplain( - V value, testing::MatchResultListener* listener) const override { - return MatchAndExplainImpl(value, listener); - } - - void DescribeTo(std::ostream* os) const override { - if (os) { - *os << "is ok."; - } - } - - void DescribeNegationTo(std::ostream* os) const override { - if (os) { - *os << "is not ok."; - } - } - }; -}; - -// Matches `litert::Expected` values that hold a success value and -// `LiteRtStatusOk`. -// -// Note: you might want to use the convenience macros: -// - `LITERT_EXPECT_OK(expr)` -// - `LITERT_ASSERT_OK(expr)` -// - `LITERT_ASSERT_OK_AND_ASSIGN(type var, expr)` -// -// ```cpp -// LiteRtStatus DoSomething(); -// -// // Will fail the test if DoSomething() doesn't return kLiteRtStatusOk. -// EXPECT_THAT(DoSomething(), IsOk()); -// ``` -// -// This also works for `Expected` objects. -// -// Note: You probably want `LITERT_ASSERT_OK_AND_ASSIGN` when working with -// `Expected`. -// -// ```cpp -// Expected<Something> BuildSomething(); -// -// // Will fail the test if BuildSomething()'s returned value holds an error. -// // Note that the returned value is unused. -// EXPECT_THAT(BuildSomething(), IsOk()); -// ``` -inline IsOkMatcher IsOk() { return IsOkMatcher(); } - -// Matches `litert::Expected` values that hold an error and -// `LiteRtStatusError*` values. -// -// See `IsError(...)` functions below for usage examples. -class IsErrorMatcher { - public: - IsErrorMatcher(std::optional<LiteRtStatus> status, - std::optional<std::string> msg) - : impl_(status, msg) {} - - // Implicitly builds and wraps the matcher implementation in a GTest - // Matcher object. - template <typename V> - // NOLINTNEXTLINE(*-explicit-constructor): This needs to be implicit.
- operator testing::Matcher<V>() const { - return testing::Matcher<V>(new Impl<V>(impl_)); - } - - private: - class ImplBase { - public: - ImplBase() = default; - - explicit ImplBase(std::optional<LiteRtStatus> status, - std::optional<std::string> msg) - : status_(status), msg_(std::move(msg)) {}; - - protected: - bool MatchAndExplainImpl(const LiteRtStatus status, - const absl::string_view msg, - testing::MatchResultListener* listener) const { - if (status == kLiteRtStatusOk || - (status_.has_value() && status != status_.value())) { - if (listener) { - *listener << "status doesn't match"; - } - return false; - } - if (msg_.has_value() && msg != msg_.value()) { - if (listener) { - *listener << "message doesn't match"; - } - return false; - } - return true; - } - - template <typename T> - bool MatchAndExplainImpl(const ::litert::Expected<T>& value, - testing::MatchResultListener* listener) const { - if (value.HasValue()) { - *listener << "expected holds a value (but should hold an error)"; - return false; - } - return MatchAndExplainImpl(value.Error(), listener); - } - - bool MatchAndExplainImpl(const ::litert::Unexpected& e, - testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(e.Error().Status(), e.Error().Message(), - listener); - } - - bool MatchAndExplainImpl(const ::litert::Error& e, - testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(e.Status(), e.Message(), listener); - } - - bool MatchAndExplainImpl(const LiteRtStatus& status, - testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(status, {}, listener); - } - - void DescribeImpl(std::ostream* os, const bool negation) const { - if (os) { - *os << "is" << (negation ? " not" : "") << " an error"; - const char* sep = " with "; - if (status_.has_value()) { - *os << sep << "status " << LiteRtGetStatusString(status_.value()); - sep = " and "; - } - if (msg_.has_value()) { - *os << sep << "message matching: '" << msg_.value() << "'"; - } - *os << "."; - } - } - - private: - std::optional<LiteRtStatus> status_; - std::optional<std::string> msg_; - }; - - template <typename V> - class Impl : public testing::MatcherInterface<V>, ImplBase { - public: - using is_gtest_matcher = void; - - Impl() = default; - explicit Impl(const ImplBase& base) : ImplBase(base) {} - - bool MatchAndExplain( - V value, testing::MatchResultListener* listener) const override { - return MatchAndExplainImpl(value, listener); - } - - void DescribeTo(std::ostream* os) const override { - DescribeImpl(os, /*negation=*/false); - } - - void DescribeNegationTo(std::ostream* os) const override { - DescribeImpl(os, /*negation=*/true); - } - }; - - ImplBase impl_; -}; - -// Matches `litert::Expected`, `litert::Unexpected`, `litert::Error` and -// `LiteRtStatus` values that hold an error. - -// -// Note: This will always match `true` for `litert::Unexpected` and -// `litert::Error`. This can be useful to test template code that might always -// return an error for certain specialisations. -// -// ```cpp -// LiteRtStatus DoSomething(); -// -// // Will fail the test if `DoSomething()` returns `kLiteRtStatusOk`. -// EXPECT_THAT(DoSomething(), IsError()); -// ``` -// -// This also works for `Expected` objects. -// -// ```cpp -// Expected<Something> BuildSomething(); -// -// // Will fail the test if BuildSomething()'s returned object holds a value.
-// EXPECT_THAT(BuildSomething(), IsError()); -// ``` -inline IsErrorMatcher IsError() { - return IsErrorMatcher(/*status=*/std::nullopt, /*msg=*/std::nullopt); -} - -// Matches `litert::Expected`, `litert::Unexpected`, `litert::Error` and -// `LiteRtStatus` values that hold a specific error status. -// -// ```cpp -// Expected<Something> BuildSomething(); -// -// // Will fail the test if BuildSomething()'s returned object holds a value or -// // if the error status is not `kLiteRtStatusErrorSystemError`. -// EXPECT_THAT(BuildSomething(), IsError(kLiteRtStatusErrorSystemError)); -// ``` -inline IsErrorMatcher IsError(LiteRtStatus status) { - return IsErrorMatcher(status, /*msg=*/std::nullopt); -} - -// Matches `litert::Expected` and `LiteRtStatus` values that have a specific -// error status and error message. -// -// Warning: This will always return `false` for `LiteRtStatus` objects as those -// do not convey a message. -// -// ```cpp -// Expected<Something> BuildSomething(); -// -// // Will fail the test if BuildSomething()'s returned object holds a value. -// EXPECT_THAT(BuildSomething(), IsError(kLiteRtStatusErrorSystemError, -// "System is not initialised")); -// ``` -inline IsErrorMatcher IsError(LiteRtStatus status, std::string msg) { - return IsErrorMatcher(status, std::move(msg)); -} - -} // namespace testing::litert - -// GTest doesn't use `AbslStringify` if `GTEST_USE_ABSL` is not defined. This -// provides a fallback implementation. -// -// This is defined here instead of with `litert::Expected` because those -// functions should only be used for testing. -#if defined(LITERT_DEFINE_GTEST_STATUS_PRINTER) && !defined(GTEST_USE_ABSL) -#include "absl/strings/str_format.h" - -// GTest documentation explicitly states that functions like those below must -// live in the same namespace as the classes they are used with so that GTest -// can find them through ADL. -namespace litert { - -inline void PrintTo(const Error& error, std::ostream* os) { - *os << absl::StrFormat("%v", error); -} - -inline void PrintTo(const Unexpected& unexpected, std::ostream* os) { - *os << absl::StrFormat("%v", unexpected); -} - -template <typename T> -void PrintTo(const Expected<T>& expected, std::ostream* os) { - *os << absl::StrFormat("%v", expected); -} - -} // namespace litert - -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_MATCHERS_H_ diff --git a/tensorflow/lite/experimental/litert/test/matchers_test.cc b/tensorflow/lite/experimental/litert/test/matchers_test.cc deleted file mode 100644 index 1acfdf282a1810..00000000000000 --- a/tensorflow/lite/experimental/litert/test/matchers_test.cc +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
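The HELPER2/HELPER1 indirection in LITERT_ASSERT_OK_AND_ASSIGN above is the usual two-hop expansion trick: `##` pastes its operands before expanding them, so pasting against `__LINE__` directly would produce a literal `__LINE__` token instead of the line number. A standalone illustration of the idiom, with hypothetical macro names:

// Sketch only: illustrates the macro idiom, not part of the LiteRT API.
#define CONCAT_RAW(a, b) a##b          // pastes without expanding a or b
#define CONCAT(a, b) CONCAT_RAW(a, b)  // extra hop lets __LINE__ expand first

// CONCAT(tmp_, __LINE__) on line 42 yields `tmp_42`, unique per source line;
// CONCAT_RAW(tmp_, __LINE__) would instead yield the literal `tmp___LINE__`.
int CONCAT(tmp_, __LINE__) = 0;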
- -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -#include - -#include -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -using litert::Error; -using litert::Expected; -using litert::Unexpected; -using testing::Not; -using testing::StrEq; -using testing::litert::IsError; -using testing::litert::IsOk; - -namespace { - -struct CopyOnly { - CopyOnly() = default; - CopyOnly(const CopyOnly&) = default; - CopyOnly& operator=(const CopyOnly&) = default; -}; - -struct MoveOnly { - MoveOnly() = default; - MoveOnly(MoveOnly&&) = default; - MoveOnly& operator=(MoveOnly&&) = default; -}; - -TEST(IsOkMatcherTest, Works) { - const Expected error = Error(kLiteRtStatusErrorNotFound, "not found"); - EXPECT_THAT(kLiteRtStatusOk, IsOk()); - EXPECT_THAT(Expected(3), IsOk()); - - EXPECT_THAT(error, Not(IsOk())); - EXPECT_THAT(Unexpected(kLiteRtStatusErrorFileIO), Not(IsOk())); - EXPECT_THAT(Error(kLiteRtStatusErrorInvalidArgument), Not(IsOk())); - - EXPECT_THAT(kLiteRtStatusErrorUnsupported, Not(IsOk())); - - EXPECT_THAT(testing::DescribeMatcher>(IsOk()), StrEq("is ok.")); - EXPECT_THAT(testing::DescribeMatcher>(Not(IsOk())), - StrEq("is not ok.")); - - testing::StringMatchResultListener listener; - EXPECT_FALSE(testing::ExplainMatchResult( - IsOk(), kLiteRtStatusErrorUnsupported, &listener)); - EXPECT_THAT(listener.str(), StrEq("status is kLiteRtStatusErrorUnsupported")); - - listener.Clear(); - EXPECT_FALSE(testing::ExplainMatchResult(IsOk(), error, &listener)); - EXPECT_THAT(listener.str(), StrEq("")); - - listener.Clear(); - EXPECT_FALSE(testing::ExplainMatchResult(IsOk(), error.Error(), &listener)); - EXPECT_THAT(listener.str(), StrEq("")); -} - -// No, I'm not creating a templated test fixture just for that. This only -// contains non-fatal failures that are propagated to the test. -// -// The type of the error wrapper that fails is the test failure stack trace when -// debug options are specified. -template -void TestErrorWrapper() { - const ErrorWrapper error = Error(kLiteRtStatusErrorNotFound, "not found"); - EXPECT_THAT(error, IsError()); - EXPECT_THAT(error, IsError(kLiteRtStatusErrorNotFound)); - EXPECT_THAT(error, IsError(kLiteRtStatusErrorNotFound, "not found")); - // This checks against the wrong status. - EXPECT_THAT(error, Not(IsError(kLiteRtStatusErrorInvalidArgument))); - // This checks against the wrong message. 
- EXPECT_THAT(error, Not(IsError(kLiteRtStatusErrorNotFound, "oob"))); - - testing::StringMatchResultListener listener; - EXPECT_FALSE(testing::ExplainMatchResult( - IsError(kLiteRtStatusErrorInvalidArgument), error, &listener)); - EXPECT_THAT(listener.str(), StrEq("status doesn't match")); - - listener.Clear(); - EXPECT_FALSE(testing::ExplainMatchResult( - IsError(kLiteRtStatusErrorNotFound, "oob"), error, &listener)); - EXPECT_THAT(listener.str(), StrEq("message doesn't match")); -} - -TEST(IsErrorMatcherTest, Works) { - TestErrorWrapper<Expected<int>>(); - TestErrorWrapper<Unexpected>(); - TestErrorWrapper<Error>(); - - EXPECT_THAT(kLiteRtStatusErrorUnsupported, IsError()); - EXPECT_THAT(kLiteRtStatusOk, Not(IsError())); - EXPECT_THAT(Expected<int>(3), Not(IsError())); - - EXPECT_THAT(testing::DescribeMatcher<Expected<int>>(IsError()), - StrEq("is an error.")); - EXPECT_THAT(testing::DescribeMatcher<Expected<int>>(Not(IsError())), - StrEq("is not an error.")); - EXPECT_THAT( - testing::DescribeMatcher<Expected<int>>( - IsError(kLiteRtStatusErrorUnsupported)), - testing::StrEq("is an error with status kLiteRtStatusErrorUnsupported.")); - EXPECT_THAT(testing::DescribeMatcher<Expected<int>>( - IsError(kLiteRtStatusErrorUnsupported, "unsupported")), - testing::StrEq("is an error with status " - "kLiteRtStatusErrorUnsupported and message " - "matching: 'unsupported'.")); - - testing::StringMatchResultListener listener; - EXPECT_FALSE( - testing::ExplainMatchResult(IsError(), kLiteRtStatusOk, &listener)); - EXPECT_THAT(listener.str(), StrEq("status doesn't match")); - - listener.Clear(); - EXPECT_FALSE( - testing::ExplainMatchResult(IsError(), Expected<int>(3), &listener)); - EXPECT_THAT(listener.str(), - StrEq("expected holds a value (but should hold an error)")); -} - -TEST(LitertAssertOk, Works) { - LITERT_ASSERT_OK(Expected<int>(3)); - LITERT_ASSERT_OK(kLiteRtStatusOk); - EXPECT_FATAL_FAILURE( - LITERT_ASSERT_OK(Error(kLiteRtStatusErrorInvalidArgument)), "is ok"); -} -TEST(LitertExpectOk, Works) { - LITERT_EXPECT_OK(Expected<int>(3)); - LITERT_EXPECT_OK(kLiteRtStatusOk); - EXPECT_NONFATAL_FAILURE( - LITERT_EXPECT_OK(Error(kLiteRtStatusErrorInvalidArgument)), "is ok"); -} - -TEST(AssertOkAndAssign, DefineAVariableWorks) { - LITERT_ASSERT_OK_AND_ASSIGN(auto expected, Expected<int>(3)); - static_assert(std::is_same_v<decltype(expected), int>, - "Type should be deduced to int."); - EXPECT_EQ(expected, 3); - - LITERT_ASSERT_OK_AND_ASSIGN([[maybe_unused]] auto copy_only, - Expected<CopyOnly>(CopyOnly())); - LITERT_ASSERT_OK_AND_ASSIGN([[maybe_unused]] auto move_only, - Expected<MoveOnly>(MoveOnly())); -} - -TEST(AssertOkAndAssign, AssignAVariableWorks) { - int expected = 0; - LITERT_ASSERT_OK_AND_ASSIGN(expected, Expected<int>(3)); - EXPECT_EQ(expected, 3); - - [[maybe_unused]] CopyOnly copy_only; - [[maybe_unused]] MoveOnly move_only; - LITERT_ASSERT_OK_AND_ASSIGN(copy_only, Expected<CopyOnly>(CopyOnly())); - LITERT_ASSERT_OK_AND_ASSIGN(move_only, Expected<MoveOnly>(MoveOnly())); -} - -void TestAssertOkAndAssignFailure() { - LITERT_ASSERT_OK_AND_ASSIGN( - [[maybe_unused]] int expected, - Expected<int>(Unexpected(kLiteRtStatusErrorInvalidArgument))); -} - -TEST(AssertOkAndAssign, FailuresStopsExecution) { - EXPECT_FATAL_FAILURE(TestAssertOkAndAssignFailure(), "is ok"); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/test/test_models.h b/tensorflow/lite/experimental/litert/test/test_models.h deleted file mode 100644 index ddad473d40bb42..00000000000000 --- a/tensorflow/lite/experimental/litert/test/test_models.h +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2024 Google LLC.
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MODELS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MODELS_H_ - -#include "absl/strings/string_view.h" -#include "absl/types/span.h" - -// /////////////////////////////////////////////////////////////////////////// -// FP32 models. -// /////////////////////////////////////////////////////////////////////////// - -// Attention sub-module of a toy model. -static constexpr absl::string_view kAttentionModel = "attention.tflite"; - -// Attention vector einsum sub-module of a toy LLM. -static constexpr absl::string_view kAttnVecEinsumModel = - "attn_vec_einsum.tflite"; - -// Feed forward sub-module of a toy LLM. -static constexpr absl::string_view kFeedForwardModel = "ff.tflite"; - -// Key einsum sub-module of a toy LLM. -static constexpr absl::string_view kKeyEinsumModel = "k_einsum.tflite"; - -// Value einsum sub-module of a toy LLM. -static constexpr absl::string_view kValueEinsumModel = "v_einsum.tflite"; - -// Query einsum sub-module of a toy LLM. -static constexpr absl::string_view kQueryEinsumModel = "q_einsum.tflite"; - -// RMS Normalization sub-module of a toy LLM. -static constexpr absl::string_view kRMSNormModel = "norm.tflite"; - -// ROPE sub-module of a toy LLM. -static constexpr absl::string_view kROPEModel = "rope.tflite"; - -// ROPE sub-module of a toy LLM, uses embedding_lookup op for sin/cos. -static constexpr absl::string_view kLookUpROPEModel = "lookup_rope.tflite"; - -// Scaled dot product attention sub-module of a toy LLM. -static constexpr absl::string_view kSDPAModel = "sdpa.tflite"; - -// Transformer block sub-module of a toy LLM. -static constexpr absl::string_view kTransformerBlockModel = - "transformer.tflite"; - -// /////////////////////////////////////////////////////////////////////////// -// Quantized models. -// /////////////////////////////////////////////////////////////////////////// - -// Quantized model with a single mul op. -// Mul: <8x100x32x4xint16>, <8x100x32x4xint16> -> <8x100x32x4xint16> -static constexpr absl::string_view kQSimpleMul16x16Model = "mul_quant.tflite"; - -// Quantized model with a mul op and an add op. -// Mul: <8x100x32x4xint16>, <8x100x32x4xint16> -> <8x100x32x4xint16> -// Add: <8x100x32x4xint16>, <8x100x32x4xint16> -> <8x100x32x4xint16> -static constexpr absl::string_view kQMulAdd16x16Model = - "simple_quantized_ops.tflite"; - -// Single add op i16 activations and i8 weights and dynamic shape. -// Add: , -> -static constexpr absl::string_view kQSingleDynAdd16x8Model = - "single_add_default_a16w8_recipe_quantized.tflite"; - -// Single add op i8 activations and i8 weights and dynamic shape. -// Add: , -> -static constexpr absl::string_view kQSingleDynAdd8x8Model = - "single_add_default_a8w8_recipe_quantized.tflite"; - -// Single mul op i16 activations and i8 weights and dynamic shape.
-// Mul: , -> -static constexpr absl::string_view kQSingleDynMul16x8Model = - "single_mul_default_a16w8_recipe_quantized.tflite"; - -// Single mul op i8 activations and i8 weights and dynamic shape. -// Mul: , -> -static constexpr absl::string_view kQSingleDynMul8x8Model = - "single_mul_default_a8w8_recipe_quantized.tflite"; - -// Single rsqrt op i16 activations and i8 weights and dynamic shape. -// RSQRT: -> -static constexpr absl::string_view kQSingleDynRsqrt16x8Model = - "single_rsqrt_default_a16w8_recipe_quantized.tflite"; - -// Single rsqrt op i8 activations and i8 weights and dynamic shape. -// RSQRT: -> -static constexpr absl::string_view kQSingleDynRsqrt8x8Model = - "single_rsqrt_default_a8w8_recipe_quantized.tflite"; - -// Quantized einsum model with i16 activations and i8 weights. -static constexpr absl::string_view kQQueryEinsum16x8Model = - "static_w8_a16_quantized_q_einsum.tflite"; - -static constexpr absl::string_view kQKeyEinsum16x8Model = - "static_w8_a16_quantized_k_einsum.tflite"; - -static constexpr absl::string_view kQVauleEinsum16x8Model = - "static_w8_a16_quantized_v_einsum.tflite"; - -static constexpr absl::string_view kQAttnVecEinsum16x8Model = - "static_w8_a16_quantized_attn_vec_einsum.tflite"; - -static constexpr absl::string_view kQSDPAModel = - "static_a8w8_quantized_sdpa.tflite"; - -// All the quantized test models. -static constexpr auto kAllQModels = absl::MakeConstSpan((absl::string_view[]){ - kQSimpleMul16x16Model, kQMulAdd16x16Model, kQSingleDynAdd16x8Model, - kQSingleDynAdd8x8Model, kQSingleDynMul16x8Model, kQSingleDynMul8x8Model, - kQSingleDynRsqrt16x8Model, kQSingleDynRsqrt8x8Model}); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MODELS_H_ diff --git a/tensorflow/lite/experimental/litert/test/testdata/add_cst.mlir b/tensorflow/lite/experimental/litert/test/testdata/add_cst.mlir deleted file mode 100644 index 502a32a7845190..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/add_cst.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %0 : tensor<4xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/add_simple.mlir b/tensorflow/lite/experimental/litert/test/testdata/add_simple.mlir deleted file mode 100644 index 32945b4c8be23c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/add_simple.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/cos_mul.mlir b/tensorflow/lite/experimental/litert/test/testdata/cos_mul.mlir deleted file mode 100644 index e6f996a706f619..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/cos_mul.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x32x2xf32>, %arg1: tensor<8x100x1x2xf32>) -> tensor<8x100x32x2xf32> { - %0 = "tfl.cos"(%arg1) : (tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> - %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<8x100x32x2xf32>, tensor<8x100x1x2xf32>) -> tensor<8x100x32x2xf32> - return %1 : tensor<8x100x32x2xf32> -} -} \ No newline at end of file diff --git 
a/tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir b/tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir deleted file mode 100644 index 8a11bf4f58ba4f..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir +++ /dev/null @@ -1,12 +0,0 @@ -module { - func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = "tfl.pseudo_const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32>}> : () -> tensor<4xf32> - %1 = tfl.mul %arg0, %0 {fused_activation_function = "NONE"} : tensor<4xf32> - return %1 : tensor<4xf32> - } - func.func @other(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = "tfl.pseudo_const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32>}> : () -> tensor<4xf32> - %1 = tfl.mul %arg0, %0 {fused_activation_function = "NONE"} : tensor<4xf32> - return %1 : tensor<4xf32> - } -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir b/tensorflow/lite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir deleted file mode 100644 index 7024ce189b7745..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor) -> tensor { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor - return %0 : tensor -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/fully_connected_3d.mlir b/tensorflow/lite/experimental/litert/test/testdata/fully_connected_3d.mlir deleted file mode 100644 index a3db1d9a887a65..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/fully_connected_3d.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x128xf32>, %arg1: tensor<128x128xf32>, %arg2: none) -> tensor<8x100x128xf32> { - %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<8x100x128xf32>, tensor<128x128xf32>, none) -> tensor<8x100x128xf32> - return %0 : tensor<8x100x128xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/mul_simple.mlir b/tensorflow/lite/experimental/litert/test/testdata/mul_simple.mlir deleted file mode 100644 index dd02656c2f370f..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/mul_simple.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.mul %0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %1 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_composite.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_composite.mlir deleted file mode 100644 index 60a65cdfe4f38c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_composite.mlir +++ /dev/null @@ -1,21 +0,0 @@ -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = stablehlo.composite "odml.npu_call" %arg0, %arg1 {decomposition = @decomp1} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = stablehlo.composite "odml.regular_composite" %arg0, %0 {decomposition = @decomp2} : 
(tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %2 = stablehlo.composite "odml.npu_call" %arg0, %1 {decomposition = @decomp3} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %2 : tensor<2x2xf32> -} - -func.func private @decomp1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func private @decomp2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func private @decomp3(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_op_multi_subgraph.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_op_multi_subgraph.mlir deleted file mode 100644 index 433d166fe3c1f5..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_op_multi_subgraph.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.mul %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %2 = tfl.sub %1, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %3 = tfl.sub %2, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %3 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir deleted file mode 100644 index 7c1f0fe4e0f5b0..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir +++ /dev/null @@ -1,21 +0,0 @@ -module { - -func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[-1.0, -1.0, -1.0, -1.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %0 : tensor<4xf32> -} - -func.func @func1(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[1.0, 1.0, 1.0, 1.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %0 : tensor<4xf32> -} - -func.func @func2(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %0 : tensor<4xf32> -} - -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir deleted file mode 100644 index 607100dbc389b6..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir +++ /dev/null @@ -1,13 +0,0 @@ -module { - -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func @func1(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4x4xf32> - return %0 : tensor<4x4xf32> -} - -} \ No newline at end of file diff 
--git a/tensorflow/lite/experimental/litert/test/testdata/multi_use_cst.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_use_cst.mlir deleted file mode 100644 index 617c27db761e44..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_use_cst.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - %1 = tfl.add %0, %0 {fused_activation_function = "NONE"} : tensor<4xf32> - %2 = tfl.add %1, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %2 : tensor<4xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/nested_composite.mlir b/tensorflow/lite/experimental/litert/test/testdata/nested_composite.mlir deleted file mode 100644 index 32ca6e26f2bfc9..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/nested_composite.mlir +++ /dev/null @@ -1,14 +0,0 @@ -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = stablehlo.composite "odml.npu_call" %arg0, %arg1 {decomposition = @decomp1} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func private @decomp1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = stablehlo.composite "odml.regular_composite" %arg0, %arg1 {decomposition = @decomp2} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func private @decomp2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir b/tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir deleted file mode 100644 index afabf1903ee846..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/rms_norm.mlir b/tensorflow/lite/experimental/litert/test/testdata/rms_norm.mlir deleted file mode 100644 index 476c9829a5bd92..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/rms_norm.mlir +++ /dev/null @@ -1,16 +0,0 @@ -module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00zs\F5|\1F\CE)\0D\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.10.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<8x128x1024xf32> {tf_saved_model.index_path = ["args_0"]}) -> (tensor<8x128x1024xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_args_0:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { - %0 = tfl.mul %arg0, %arg0 {fused_activation_function = 
"NONE"} : tensor<8x128x1024xf32> - %1 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> - %2 = "tfl.sum"(%0, %1) <{keep_dims = false}> : (tensor<8x128x1024xf32>, tensor<1xi32>) -> tensor<8x128xf32> - %3 = "tfl.pseudo_const"() <{value = dense<1.024000e+03> : tensor}> : () -> tensor - %4 = tfl.div(%2, %3) <{fused_activation_function = "NONE"}> : (tensor<8x128xf32>, tensor) -> tensor<8x128xf32> - %5 = "tfl.pseudo_const"() <{value = dense<9.99999997E-7> : tensor}> : () -> tensor - %6 = tfl.add(%4, %5) <{fused_activation_function = "NONE"}> : (tensor<8x128xf32>, tensor) -> tensor<8x128xf32> - %7 = "tfl.pseudo_const"() <{value = dense<[8, 128, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> - %8 = "tfl.reshape"(%6, %7) : (tensor<8x128xf32>, tensor<3xi32>) -> tensor<8x128x1xf32> - %9 = "tfl.rsqrt"(%8) : (tensor<8x128x1xf32>) -> tensor<8x128x1xf32> - %10 = tfl.mul(%arg0, %9) <{fused_activation_function = "NONE"}> : (tensor<8x128x1024xf32>, tensor<8x128x1xf32>) -> tensor<8x128x1024xf32> - return %10 : tensor<8x128x1024xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/rms_norm_composite.mlir b/tensorflow/lite/experimental/litert/test/testdata/rms_norm_composite.mlir deleted file mode 100644 index 6995e4d739ab2a..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/rms_norm_composite.mlir +++ /dev/null @@ -1,23 +0,0 @@ -module { - -func.func @main(%arg0: tensor<1x128x2304xf32>, %arg1: tensor<2304xf32>) -> tensor<1x128x2304xf32> { - %0 = stablehlo.composite "odml.rms_norm" %arg0, %arg1 {composite_attributes = {epsilon = 9.99999997E-7 : f32}, decomposition = @odml.rms_norm.impl} : (tensor<1x128x2304xf32>, tensor<2304xf32>) -> tensor<1x128x2304xf32> - return %0 : tensor<1x128x2304xf32> -} - -func.func @odml.rms_norm.impl(%arg0: tensor<1x128x2304xf32>, %arg1: tensor<2304xf32>) -> tensor<1x128x2304xf32> { - %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<1x128x2304xf32> - %1 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> - %2 = "tfl.sum"(%0, %1) <{keep_dims = false}> : (tensor<1x128x2304xf32>, tensor<1xi32>) -> tensor<1x128xf32> - %3 = "tfl.pseudo_const"() <{value = dense<4.34027781E-4> : tensor}> : () -> tensor - %4 = tfl.mul(%2, %3) <{fused_activation_function = "NONE"}> : (tensor<1x128xf32>, tensor) -> tensor<1x128xf32> - %5 = "tfl.pseudo_const"() <{value = dense<9.99999997E-7> : tensor}> : () -> tensor - %6 = tfl.add(%4, %5) <{fused_activation_function = "NONE"}> : (tensor<1x128xf32>, tensor) -> tensor<1x128xf32> - %7 = "tfl.pseudo_const"() <{value = dense<[1, 128, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> - %8 = "tfl.reshape"(%6, %7) : (tensor<1x128xf32>, tensor<3xi32>) -> tensor<1x128x1xf32> - %9 = "tfl.rsqrt"(%8) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32> - %10 = tfl.mul(%arg0, %9) <{fused_activation_function = "NONE"}> : (tensor<1x128x2304xf32>, tensor<1x128x1xf32>) -> tensor<1x128x2304xf32> - %11 = tfl.mul(%10, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x128x2304xf32>, tensor<2304xf32>) -> tensor<1x128x2304xf32> - return %11 : tensor<1x128x2304xf32> - } -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/scala_reshape.mlir b/tensorflow/lite/experimental/litert/test/testdata/scala_reshape.mlir deleted file mode 100644 index 0b655f704eb5d7..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/scala_reshape.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x1xf32>) 
-> tensor<f32> { - %cst = arith.constant dense<[]> : tensor<0xi32> - %0 = "tfl.reshape"(%arg0, %cst) : (tensor<1x1xf32>, tensor<0xi32>) -> tensor<f32> - return %0 : tensor<f32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/shared_input_cpu_npu.mlir b/tensorflow/lite/experimental/litert/test/testdata/shared_input_cpu_npu.mlir deleted file mode 100644 index 42a5059e8861dd..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/shared_input_cpu_npu.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { - func.func @main(%x: tensor<2xf32>, %y: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %cpu_out = tfl.add %x, %y {fused_activation_function = "NONE"} : tensor<2xf32> - %npu_out = "tfl.custom"(%x, %y) {custom_code = "DISPATCH_OP", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - func.return %cpu_out, %npu_out : tensor<2xf32>, tensor<2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir deleted file mode 100644 index 0902f5966f8266..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x1xf32>, %arg1: tensor<1x128x1xf32>) -> tensor<1x128x1xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x1xf32> - return %0 : tensor<1x128x1xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_average_poll_2d.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_average_poll_2d.mlir deleted file mode 100644 index 979610cdaa0e1e..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_average_poll_2d.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x1728x2304x3xf32>) -> tensor<1x432x576x3xf32> { - %0 = "tfl.average_pool_2d"(%arg0) <{filter_height = 4 : i32, filter_width = 4 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 4 : i32}> : (tensor<1x1728x2304x3xf32>) -> tensor<1x432x576x3xf32> - return %0 : tensor<1x432x576x3xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir deleted file mode 100644 index e756a0dab87cbc..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x4x256x128xf32>, %arg1: tensor<1x4x128x128xf32>) -> tensor<1x4x256x128xf32> { - %0 = "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<1x4x256x128xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x256x128xf32> - return %0 : tensor<1x4x256x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_cascade_model_npu.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_cascade_model_npu.mlir deleted file mode 100644 index 5e262cb678714c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_cascade_model_npu.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { - func.func @main(%x1: tensor<2xf32>, %x2: tensor<2xf32>, %x3: tensor<2xf32>) -> tensor<2xf32> { - %t1 = "tfl.custom"(%x1, %x2) {custom_code = "DISPATCH_OP_1", custom_option = #tfl} : (tensor<2xf32>,
tensor<2xf32>) -> tensor<2xf32> - %out = "tfl.custom"(%t1, %x3) {custom_code = "DISPATCH_OP_2", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - func.return %out : tensor<2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_cast_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_cast_op.mlir deleted file mode 100644 index 6066c665713bc3..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_cast_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1xi32>) -> tensor<8x100x1xf32> { - %0 = "tfl.cast"(%arg0) : (tensor<8x100x1xi32>) -> tensor<8x100x1xf32> - return %0 : tensor<8x100x1xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_composite.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_composite.mlir deleted file mode 100644 index 79c64f423039ba..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_composite.mlir +++ /dev/null @@ -1,11 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32, tf_saved_model.semantics} { - func.func @main(%arg0: tensor<2x2xf32> { tf_saved_model.index_path = ["arg0"] }, %arg1: tensor<2x2xf32> { tf_saved_model.index_path = ["arg1"]}) -> (tensor<2x2xf32> {tf_saved_model.index_path = ["output"] }) attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "output"}, tf_saved_model.exported_names = ["serving_default"]} { - %0 = stablehlo.composite "odml.npu_call" %arg0, %arg1 {decomposition = @decomp} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> - } - func.func private @decomp(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> - } -} - diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir deleted file mode 100644 index e1e9bd36ae01b0..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<128x4x1x256xf32>, %arg1: tensor<128x4x1x256xf32>) -> tensor<128x4x2x256xf32> { - %0 = "tfl.concatenation"(%arg0, %arg1) <{axis = 2 : i32, fused_activation_function = "NONE"}> : (tensor<128x4x1x256xf32>, tensor<128x4x1x256xf32>) -> tensor<128x4x2x256xf32> - return %0 : tensor<128x4x2x256xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_conv_2d_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_conv_2d_op.mlir deleted file mode 100644 index 4eb0e0a04d32c2..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_conv_2d_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x216x288x24xf32>, %arg1: tensor<24x3x3x24xf32>, %arg2: tensor<24xf32>) -> tensor<1x216x288x24xf32> { - %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x216x288x24xf32>, tensor<24x3x3x24xf32>, tensor<24xf32>) -> tensor<1x216x288x24xf32> - return %0 : tensor<1x216x288x24xf32> -} -} \ No newline at end of file diff --git 
a/tensorflow/lite/experimental/litert/test/testdata/simple_cos_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_cos_op.mlir deleted file mode 100644 index 70ea46c1988b16..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_cos_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> { - %0 = "tfl.cos"(%arg0) : (tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> - return %0 : tensor<8x100x1x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_depth_to_space_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_depth_to_space_op.mlir deleted file mode 100644 index 2682b724b88c37..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_depth_to_space_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x216x288x12xf32>) -> tensor<1x432x576x3xf32> { - %0 = "tfl.depth_to_space"(%arg0) <{block_size = 2 : i32}> : (tensor<1x216x288x12xf32>) -> tensor<1x432x576x3xf32> - return %0 : tensor<1x432x576x3xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_depthwise_conv_2d_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_depthwise_conv_2d_op.mlir deleted file mode 100644 index 706295d3e27076..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_depthwise_conv_2d_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x40x40x192xf32>, %arg1: tensor<1x3x3x192xf32>, %arg2: tensor<192xf32>) -> tensor<1x32x32x192xf32> { - %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) <{depth_multiplier = 1 : i32, dilation_h_factor = 4 : i32, dilation_w_factor = 4 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x40x40x192xf32>, tensor<1x3x3x192xf32>, tensor<192xf32>) -> tensor<1x32x32x192xf32> - return %0 : tensor<1x32x32x192xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir deleted file mode 100644 index 3748d45bcd5249..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x8x128xf32>, %arg1: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { - %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x8x128xf32> - return %0 : tensor<1x128x8x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_dynamic_update_slice_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_dynamic_update_slice_op.mlir deleted file mode 100644 index a10606eccd41f9..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_dynamic_update_slice_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x64x4x64xf32>, %arg1: tensor<1x1x4x64xf32>) -> tensor<1x64x4x64xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<[0, 1, 0, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> - %0 = "tfl.dynamic_update_slice"(%arg0, %arg1, %cst) : (tensor<1x64x4x64xf32>, tensor<1x1x4x64xf32>, tensor<4xi32>) -> tensor<1x64x4x64xf32> - return %0 : tensor<1x64x4x64xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir 
b/tensorflow/lite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir deleted file mode 100644 index 75b8000bb97a35..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<5xi32>) -> tensor<5x1x2xf32> { - %table = "tfl.pseudo_const"() <{value = dense<"0x00010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001"> : tensor<20x1x2xf32>}> : () -> tensor<20x1x2xf32> - %0 = "tfl.embedding_lookup"(%arg0, %table) : (tensor<5xi32>, tensor<20x1x2xf32>) -> tensor<5x1x2xf32> - return %0 : tensor<5x1x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir deleted file mode 100644 index 6bd3f1fa79d77c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { - %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> - return %0 : tensor<5xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir deleted file mode 100644 index 5cad120662635e..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<128x2048xf32>, %arg1: tensor<2304x2048xf32>, %arg2: none) -> tensor<128x2304xf32> { - %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<128x2048xf32>, tensor<2304x2048xf32>, none) -> tensor<128x2304xf32> - return %0 : tensor<128x2304xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_gather_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_gather_op.mlir deleted file mode 100644 index 6b0375c77c24a9..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_gather_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x3x6xf32>, %arg1: tensor<4x5xi32>) -> tensor<4x5x3x6xf32> { - %0 = "tfl.gather"(%arg0, %arg1) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32> - return %0 : tensor<4x5x3x6xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir deleted file mode 100644 index 39ebcf24e972d0..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1xf32>) -> tensor<8x100x1xf32> { - %0 = "tfl.gelu"(%arg0) : (tensor<8x100x1xf32>) -> tensor<8x100x1xf32> - return %0 : tensor<8x100x1xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_greater_op.mlir 
b/tensorflow/lite/experimental/litert/test/testdata/simple_greater_op.mlir deleted file mode 100644 index b368def16d6e88..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_greater_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x1x64xi32>, %arg1: tensor<1x1x64xi32>) -> tensor<1x1x64xi1> { - %0 = "tfl.greater"(%arg0, %arg1) : (tensor<1x1x64xi32>, tensor<1x1x64xi32>) -> tensor<1x1x64xi1> - return %0 : tensor<1x1x64xi1> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_hard_swish_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_hard_swish_op.mlir deleted file mode 100644 index 5c95ca2bb4e573..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_hard_swish_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x216x288x48xf32>) -> tensor<1x216x288x48xf32> { - %0 = "tfl.hard_swish"(%arg0) : (tensor<1x216x288x48xf32>) -> tensor<1x216x288x48xf32> - return %0 : tensor<1x216x288x48xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_leaky_relu_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_leaky_relu_op.mlir deleted file mode 100644 index 13dacd3984493a..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_leaky_relu_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x32x32x192xf32>) -> tensor<1x32x32x192xf32> { - %0 = "tfl.leaky_relu"(%arg0) <{alpha = 2.000000e-01 : f32}> : (tensor<1x32x32x192xf32>) -> tensor<1x32x32x192xf32> - return %0 : tensor<1x32x32x192xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_less_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_less_op.mlir deleted file mode 100644 index 06370a186ddc5b..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_less_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x1x64xi32>, %arg1: tensor<1x1x64xi32>) -> tensor<1x1x64xi1> { - %0 = "tfl.less"(%arg0, %arg1) : (tensor<1x1x64xi32>, tensor<1x1x64xi32>) -> tensor<1x1x64xi1> - return %0 : tensor<1x1x64xi1> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_logical_and_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_logical_and_op.mlir deleted file mode 100644 index e58307caceb3be..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_logical_and_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x64x64xi1>, %arg1: tensor<1x64x64xi1>) -> tensor<1x64x64xi1> { - %0 = "tfl.logical_and"(%arg0, %arg1) : (tensor<1x64x64xi1>, tensor<1x64x64xi1>) -> tensor<1x64x64xi1> - return %0 : tensor<1x64x64xi1> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_mean_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_mean_op.mlir deleted file mode 100644 index 56b4fcb8a9f3e6..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_mean_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> - %0 = "tfl.mean"(%arg0, %cst) <{keep_dims = false}> : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<2xf32> - return %0 : tensor<2xf32> -} -} \ No newline at end of file 
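Each single-op module above is converted to a `.tflite` flatbuffer that the LiteRT tests and tools consume. As a rough sketch of how such a converted model can be opened through the LiteRT C++ API; it leans only on `litert::Model::CreateFromFile`, which is used by the tooling code later in this patch, and the file name here is a hypothetical stand-in:

```cpp
#include <cstdio>

#include "tensorflow/lite/experimental/litert/cc/litert_model.h"

// Sketch: open a converted testdata model. The path is a hypothetical
// stand-in; real tests resolve testdata locations through shared helpers.
int main() {
  auto model = litert::Model::CreateFromFile("simple_mean_op.tflite");
  if (!model) {
    // CreateFromFile returns an Expected; a falsy value carries the error.
    std::fprintf(stderr, "failed to load testdata model\n");
    return 1;
  }
  return 0;
}
```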
diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_model.mlir deleted file mode 100644 index d88a5d5923c77e..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_model.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2xf32> - return %0 : tensor<2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_google_tensor.bin b/tensorflow/lite/experimental/litert/test/testdata/simple_model_google_tensor.bin deleted file mode 100644 index 208cb983671510eedbb7b31f3491b9a0d6d14b9a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeHMU2I%O75?V#ubut5PMop{P*!TvxDXnX29zKR-;$K%AqtUNl^`K)oY-z%9NTgn znpQ}?g%>}81urXJh*E*9_QgtGtn$FB50#0cYFjF_Nc@;64{o295-JCoEZ>>?v-f&; zZBuNk5_?v&GiT13GjqroCQvtJl=~ov@OuKK-&Us3$!iJwm{ngZ40z5(6&I^ z0&NSlEzq{Wr_uu5=s^@ojxB{8@-WfOAF#xH+&Bs+KZY;*J6h}3VB4Lm z=DJYo>RjkTr3)Xpi(|2WIhwR%n2F&RFEqKf%ya~2`nbClt49s zs>Lj-S=c7^P}E@7#GL5SO+6}FiEE;Vq8{^IC~G~I zx;%dbq}U*zb-tszWl{ai-C9lZ#fKmxEiv=K{+cdY{O^^R4%}>ZqbD60pu+H{gn`Z<$f5K zfQr&<5<4_F_>Z~JeLRMM4y5nK!Y<|YKzu*SPrT}~OMw0=c6t~DmpJ8t{c(Q`Tr(8E zT;tuY?Ge2Eet)&bxM}?_`hUX{Mj;>fOZSXePJeZZQ?mAOh1OxGEoCbjTagh`M@v+h zkje(Iq{jCD99J|1Uj;w(@nkgEnP1(k;0rR@d!iY~2I}mVJY0wx?AMsl&;0IMwr~kn zN1s!&k-V34iriorJY;gSzq<3WW3x%xfWBa zcqw0&b$VmG{A`Hc7%$}$w%)!mUVc7A-}-p@aL~T&I}Dp2FL|2X)_9p;OTW@T`+2F- zEXsw=5Wp-sARYxilJojqIh-%lC}gJgihFv{yb<~j&Ig)pLD9wo{z6ca^JSj;DE>YH zEJh$-3#|?E%bz~R4KMGw@Wwi@40&--bj(N%wfy**=)mn|8$TfasmgIBHuGM&iugM$-2PB7?X5f84SZ|QZ z1{)ri)zRbd+R^0u?!kg{*y{U%@p&Id)$;sJ`u=t3kFG^;72ka$tUAxS{5O4@w4!?k zX+>wfhL-e}`@?~tKDH~y&o+O5=%H+jsn!1QScu-ZKYTGnZ`>cg6rwlo4+lf^t=}KM z9JDX{4#Vc}50A5h?L1E(Vt@X5+74?6Y>NK`gcjdrtO-Enk6~P|gHLe(6j(p{{>ckk ztY7Q?snnr;IWVAAMca#PZoTX;&;i}+VO4`Bs1tD?lQY7-1pC`Qe|G%!_r5(G{2xI)rmrB} z8q#;;on5~35Ep+{p3#Y~<*!!KZ>a^AvZ3!?zuW(&Aol^z3y`zdYZpg`(C<&@<9V|1ys6*0M8<9=kuf!Kr7%7`JUTuwSQxnc(%41Ti*AI36sLWgyQd-d%}?I@A_R^Vq94GkV27#;n(j7w}S{Rz(>*rpn2vx{|?hbDh8GWJ|yaA{t#IP5q(p^fOPMIZ=4>^taDERk%2Mv~ z*jn&xck4l;AO3}P4n3-ECHAT?%`lx~GMV@)%RI?>cI?6mL*`p}_cO_w3#m6Uc6q9B za&r92M4{LFkx%-Ocdw#ej|sJpxOgqmTB~`KHZqsa8V3x|4I2UC*fZqD;}&;RD+t%-*eBo@A~@l(R?|P!USLj^fd6pvaIY1=*vL% z8MUm^v`tb7g!F$JG}}p#Yf1h|H_=C6Jhq_5L9?Iq<#-AYZj~wEq9u~{%shxaN6EK$ z-6KBIOy{x&1b|J|0`?JCb}y%3BlBk;i?;ol=9mml8?F~hOKX8r4>AKjP+UV^+5*HY zB$HL53?Z|Tp1i7DnP((!fs`@xpxJPY{A^ezlkjHN*xiZhrEWujv__L7}=M0?L z*a>80%s;+tc<)LL#v}-5fn-s-*#O880dlI&)WKRHx4?{(mJgepow`D?ojIf<=f-pT z%b0Q=fP>7{X;)~SLUv9cCoSX4S;K=+(4QYT`=PUPZz6it2Il90?*cfs<&R+kG@J0g zC&0&S-l{Q#lQz$5PlDHMo=^8R2fqyb9C+Q<^BNk0Xg1Feas<58gMF-ktyp`@*)L=) zPhvAj)p}MvRAiol(AqCH!F&u1@7G+xMr0-}C5SJj?Jh&e`@ag{I6h!e9?$h~>cQp! 
z`^A8qABX37%Yht=b=#ipZE0?9YP13uP&5#)kfYdTd9*XI%(buK_vVLV8~J zUV4tJ1KFR~l>2m^e~2?+olkH}t;grp887AYSiOmUQhMXu^O-Oif@^$wRXQH$YWx~6 z7uRN1&0UL;Nu!>Ll<`h`U~c*=^zYOtO{I+U6RVE=!F#xC_C=qT@|!QNjq$iTf^ngg zm@YM$rVAKHmC2N0UZxW#7nCc^yOJqmjP`7rU2uH_O&txTM%~vyU8I53g8L?H>Xf}0 zFDl<++SSCgtA=U!ozzuYHWzP;mim=H)5$;j724G`8`t$xqf+$@iW8ogsH2kk@!iRC zqw?d8B#yZuQNBFCW{+NOjCJcI7OubeH8V%z~czjCXZ$3S#9 z#_Py1-l(rI)E%^1?4*DgSKZ`Sx~Ur4-?O$&FEZ2{p!%sntU2K(Lmi-JM1nd+{AX?t zcIGkEjr6jpf_}V6g?9+v@vp0}EYkF8Wn4XNsHzs4HqR)h{if-fI$|g#+TirjQVJRB zCpEg*U)6^7rZ+^8b${3=?o?^u6Ty%eKmO0fyYH$0GSpMys8|`|{l5{6iSeU>9 zjAN^^Bpe zt(p+ZkC&gZqJQ5|epGWo99{3JHq^mzMBK$`GSq!RT_nOGP<)#Dg5LC5Ig{lu zh=sV$w}L*wC(BShAx*p_EBK>UFj&d|Tz=5$1ECGvcNx`2XkO{Jou#*U5xG&gJF2_4)r( z9v)%d3GhEgr{N!ol_x#vGb{B;Tnl-4g!vrkj2t%?50ukfN5weZ5@mhH3n4Zz*P_k;DEW!)@F`th6(P94)S_Wkzs zAo@kbm23YHG~aJC>T4MB-yHa=MG?J%8GqX|-z6Grn`-r1eJ@1v69V5XoT3jXuSxlN zhIRa=lcTdg#K;>dN0hdtsI0cxWimdvJIlG1PnT@Bv^;ds};NeQRSqprN_3t)aE8sihV^{{_IKV)_68 diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir deleted file mode 100644 index f4959fb63e6231..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { - func.func @main(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> { - %out = "tfl.custom"(%x, %y) {custom_code = "DISPATCH_OP", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - func.return %out : tensor<2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_qualcomm.bin b/tensorflow/lite/experimental/litert/test/testdata/simple_model_qualcomm.bin deleted file mode 100644 index a66f76296d7698031d0d8df30aaefe093aa5d04d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 13800 zcmeHM%a0sK8Lv&8*m=r?9GnECCGi0{W>r1YGdru!M7y#R=Yj1A1Ch9Ndb(?7THDjz z?#Jwo5yFDpxa7bI{(&4o0>pts@QDLL;u81+xNs#LA|WBb{C-u{^Vml)$4Jw5Z+CxH z_0?D3$Pv>%J@`zRzO)7xBck!8MRy3(TOl@AKWo zec3+oz-R$}LcD_eU;jZb^rIZV$2%VS;yB9l1>ElhMHN?Zma6DVkR(3m{1kmZz(Bli zu-}8um+0xtd_6NBqm9hu^F`dHDyDHd@*_12>ZI~#m@z`%i&zQe7SJ{A@$8Bws0TUM z!%-b@t<(dbyr?hocmp^6h*lzROrv%2=LeMS5VD=o|1cOA{f2XY*7ZCZr>(74XG?Ci zT)ExqInv#6q`$Mh?Kq1-p+1(cUswX==0&;BTDkGLK{_Miyk5&`mM&DA2 zny9ooh)d;{^+fyDrqTVzCfbE~9=4)&3?Irw|EDz6BmJHHI1hc0EzO*qgBQOk(`mlk z`0*zCEas!sCgPUoBOl%mGT&*wd=5HZw%aE7abBZ${Q;esm2yQ~XS0U?wZ@a?q>nxc zlp?AUr_9PWHhSRO*!&r*V``Z=MLPg7BSZQ8 z|2->i-O;%D%bOvgGzCU~9W`48Z^5PUo8YF(HZr3Ue zRnd=zUNnU74;G|6EBq+;q|9=^%!YN_1{FF^RZ#+)`-QT& z>4n*35Cg1z@~y(U@`p)w(09Bt9D|yxV(1q^I#PZTPvWY-)Ahn2^@kYamw5t;^i&!Q z5`4}_V7PR{tgQS&P^t)BoK6i^Rh%e%NH?gftQr?8U|XaH^~g`uR3-hdIo|C$NUQjc z@<&CGkNd+QDV3)VD^;W`y3p$J!vM4TVWNT*tfo~s@e8U8b64?1Wi)b=z}=I^abY0Qrjdu z6*J$QXjU+ZQ$HxealhN6((?e;N#YP{8wa_vQ^lDdO#*)~%Y(9n7O60Qtm4tQ>i4`s zJn}JLHMW{>b-G>>6tK-8s6tSowZ7Bpd1(e?1Q>dthnALm$H4;1Bg8bQ;wdJ9Is&JA z;F{$P;4%A^USXl4I#m5GY@pyDa34P};;e|PS>M@ddGtG#z@QDkQVGY9-sw!et_S+x?Ls)%K7b${I zyNzta)QSQc=W2J&SjQdEm!Opn7CP=}E%G($ zU8_3*$NnpF%jsNR{qrg$R6&unyvCpBj99p4GnprnmJ`vHa`sOzl_5r+>ayg9Voly1O;SQ(B8zi zAqf*)Nl`T+()+F_-UhvG@v7&z)ayLF z3}2mEYzR25QJN)j&_smPPE8NYlK=px=d&L^LL7iiA=5ceE9)r9224Z<_c^P)P{ec~ zgQ06)0|XFI89;^@^|K<<3EU^f6C#k%HG27JFi~FDfz`9T2FMxbXp)KbVs(gAtJMx& z#S;iX>-mE?DE$dGg>lIEvteLZ20+%SG4CloG(w%yrVF1jjA-f` z-&9-HVt`pVk3w3u`eYL4D*b6()bm7^jWFpV5w5leSl(E|rETc<8zd zNR-|r8NIV8j`!0A37%>Pg;Ts#f9VBjn-bQglwr>HM zO~_me>#{=TBE^`y=bYT%w1xvj?TP$My9J7rYGn~*RF8TxIAn>8*pFFNn;Oy#86vku 
z{=&=3mBx&+o3gXk^|%?-cz-eot=!uGaIQn`YP;!avSJd<^J|3OwxnC`g-9)-@e)7Y(38)(w8+6SVWVoJVdLYNe!eU@gmmL)IsDP1qA? zo3Ubn-71_y&=Oc002>>tc{ig?K>mKCFS%`oVpxgLf!dC_nvnor$UOpb_9*6dir(94 zBMCalnT@i51YC5fW8DRKus$-TGSN?K#VE$7!$zhCR9V)}>|!OH$rR8AH0HtVJN>o` zb?>$MJJfs2q2+X`&7SQYgdFQR&)I9<+71erIRjnV4%$nXPUJ2&jZqv?^+RMhmH?|m zbTusBoM(d$am#EybXr+Nx70gDAQE>|%gDzf0^OJt7rRO6GMO}rB75daiG3ida+M;B zqX6KSu_v&3A(vyl@jAm0dtKyj2StZ*NNEM&QgHTKWcnC#0Xb%SvXn@T|IveN1rZK1 z>}9HP)_Jou$dFG5TU3yAjh^FyKaVXjm&QIOf$94?JL!V z+|xdaeI_EQP$(C5wFhIPTJPH!_7-b8fy0%>Fau{A>vT4?izk7gD5-jGFMrwD!G@M&j?>-lahhyV zz!T`sAk0dV0M zS-J-}+CzneEC_u-ogEzr7b@;#CPue>iqRj_&Wvf zk3P7Py@cmk-2DCG7EeUQr8iTc(=DDV?!G?UD&0%nEwOub5?mRpa6ixBJhxPJgz|0> zS7I-UFyc~MT)L?dZkb0LJZ{a>@MeyQ#Kzr6ApzvQ%ljd`+uwoie)OX&6@N$KZ$`X_ zLs7co^&wux`)Bk|9GmZ-**O2pCvJSocKdrSo5j7$dw(Kr`TBY58-Hxp`|Epe@3HNl z_;2Rh)q2ps7J$Hh4~c)@=Wj>6@O>9E?l#|+*0ilWt^9Kuc|C=jzvJ-z_TL`n&pzi5 ze)#bZ9>$_cJ9_1|&tIHB{G{o_e$7oeO55k&Uq61o>C^CGJNp*sgXbo{pSN>SKHmIo zoxcb2$-5b|i;=bE{o7M`qr0ef zo*C$ncb8o+SUPR{7;gG&(~n#Z(U)_w{RQ;plf3C0q&I8-u9bo9FOe}W_Pb+cAbk^b z(s+Llbm(dmui`6~&iAMO{>eQ)*VkGDYYnV5@OjYyaqFkA{O -#include - -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" - -constexpr const char* kModelFileName = "simple_model.tflite"; -constexpr const char* kQualcommModelFileName = "simple_model_qualcomm.bin"; -constexpr const char* kGoogleTensorModelFileName = - "simple_model_google_tensor.bin"; -constexpr const char* kMediaTekModelFileName = "simple_model_mtk.bin"; - -constexpr const int32_t kTestInput0Dimensions[] = {2}; -constexpr const int32_t kNumTestInput0Dimensions = - sizeof(kTestInput0Dimensions) / sizeof(kTestInput0Dimensions[0]); -constexpr const int32_t kTestInput1Dimensions[] = {2}; -constexpr const int32_t kNumTestInput1Dimensions = - sizeof(kTestInput1Dimensions) / sizeof(kTestInput1Dimensions[0]); -constexpr const int32_t kTestOutputDimensions[] = {2}; -constexpr const int32_t kNumTestOutputDimensions = - sizeof(kTestOutputDimensions) / sizeof(kTestOutputDimensions[0]); - -constexpr const float kTestInput0Tensor[] = {1, 2}; -constexpr const float kTestInput1Tensor[] = {10, 20}; -constexpr const float kTestOutputTensor[] = {11, 22}; - -constexpr const float kTestInput0Tensor_2[] = {10, 20}; -constexpr const float kTestInput1Tensor_2[] = {100, 200}; -constexpr const float kTestOutputTensor_2[] = {110, 220}; - -constexpr const size_t kTestInput0Size = - sizeof(kTestInput0Tensor) / sizeof(kTestInput0Tensor[0]); -constexpr const size_t kTestInput1Size = - sizeof(kTestInput1Tensor) / sizeof(kTestInput1Tensor[0]); -constexpr const size_t kTestOutputSize = - sizeof(kTestOutputTensor) / sizeof(kTestOutputTensor[0]); - -constexpr const LiteRtRankedTensorType kInput0TensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTestInput0Dimensions)}; - -constexpr const LiteRtRankedTensorType kInput1TensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTestInput1Dimensions)}; - -constexpr const LiteRtRankedTensorType kOutputTensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTestOutputDimensions)}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TESTDATA_SIMPLE_MODEL_TEST_VECTORS_H_ diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir deleted file mode 100644 index 7fb5ac2d2187f0..00000000000000 --- 
a/tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x2304xf32>, %arg1: tensor<1x128x2304xf32>) -> tensor<1x128x2304xf32> { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x2304xf32> - return %0 : tensor<1x128x2304xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_multi_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_multi_op.mlir deleted file mode 100644 index 07757fddec1b90..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_multi_op.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.mul %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %2 = tfl.mul %1, %1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %3 = tfl.add %2, %2 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %3 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_pack_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_pack_op.mlir deleted file mode 100644 index e94d4815d9545e..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_pack_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<4xi32> { - // %cst = "tfl.pseudo_const"() <{value = dense<0> : tensor}> : () -> tensor - %0 = "tfl.pack"(%arg0, %arg1, %arg2, %arg3) <{axis = 0 : i32, values_count = 4 : i32}> : (tensor, tensor, tensor, tensor) -> tensor<4xi32> - return %0 : tensor<4xi32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_relu6_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_relu6_op.mlir deleted file mode 100644 index 17bbc4ef2fdefe..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_relu6_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1xf32>) -> tensor<8x100x1xf32> { - %0 = "tfl.relu6"(%arg0) : (tensor<8x100x1xf32>) -> tensor<8x100x1xf32> - return %0 : tensor<8x100x1xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_relu_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_relu_op.mlir deleted file mode 100644 index 72306d2b9e6cd3..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_relu_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1xf32>) -> tensor<8x100x1xf32> { - %0 = "tfl.relu"(%arg0) : (tensor<8x100x1xf32>) -> tensor<8x100x1xf32> - return %0 : tensor<8x100x1xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir deleted file mode 100644 index 515db6e424e6a7..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<4xi32>) -> tensor<128x4x1x256xf32> { - %0 = "tfl.reshape"(%arg0, %arg1) : (tensor<1x128x4x256xf32>, tensor<4xi32>) -> tensor<128x4x1x256xf32> - return %0 : tensor<128x4x1x256xf32> -} -} \ No newline at end of file diff --git 
a/tensorflow/lite/experimental/litert/test/testdata/simple_resize_bilinear_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_resize_bilinear_op.mlir deleted file mode 100644 index 1cd9be9729f487..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_resize_bilinear_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x54x72x96xf32>) -> tensor<1x108x144x96xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<[108, 144]> : tensor<2xi32>}> : () -> tensor<2xi32> - %0 = "tfl.resize_bilinear"(%arg0, %cst) <{align_corners = false, half_pixel_centers = true}> : (tensor<1x54x72x96xf32>, tensor<2xi32>) -> tensor<1x108x144x96xf32> - return %0 : tensor<1x108x144x96xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_resize_nearest_neighbor_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_resize_nearest_neighbor_op.mlir deleted file mode 100644 index a73eb7e60e0b4c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_resize_nearest_neighbor_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x54x72x96xf32>) -> tensor<1x108x144x96xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<[108, 144]> : tensor<2xi32>}> : () -> tensor<2xi32> - %0 = "tfl.resize_nearest_neighbor"(%arg0, %cst) <{align_corners = false, half_pixel_centers = true}> : (tensor<1x54x72x96xf32>, tensor<2xi32>) -> tensor<1x108x144x96xf32> - return %0 : tensor<1x108x144x96xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir deleted file mode 100644 index 5083f3f3a30383..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x1xf32>) -> tensor<1x128x1xf32> { - %0 = "tfl.rsqrt"(%arg0) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32> - return %0 : tensor<1x128x1xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir deleted file mode 100644 index 2405e5d3626893..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x8x128xi1>, %arg1: tensor<1x128x8x128xf32>, %arg2: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { - %0 = "tfl.select"(%arg0, %arg1, %arg2) : (tensor<1x128x8x128xi1>, tensor<1x128x8x128xf32>, tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> - return %0 : tensor<1x128x8x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir deleted file mode 100644 index a8d80ecc80f970..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x1x1x100xi1>, %arg1: tensor<8x100x32x100xf32>, %arg2: tensor<8x100x32x100xf32>) -> tensor<8x100x32x100xf32> { - %0 = "tfl.select_v2"(%arg0, %arg1, %arg2) : (tensor<8x1x1x100xi1>, tensor<8x100x32x100xf32>, tensor<8x100x32x100xf32>) -> tensor<8x100x32x100xf32> - return %0 : tensor<8x100x32x100xf32> -} -} \ No newline at end of file diff --git 
a/tensorflow/lite/experimental/litert/test/testdata/simple_sin_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_sin_op.mlir deleted file mode 100644 index 431d3b93065441..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_sin_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> { - %0 = "tfl.sin"(%arg0) : (tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> - return %0 : tensor<8x100x1x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir deleted file mode 100644 index 4adfa00a204cfc..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir +++ /dev/null @@ -1,8 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x8x256xf32>) -> tensor<1x128x4x128xf32> { - %cst_0 = "tfl.pseudo_const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32> - %cst_1 = "tfl.pseudo_const"() <{value = dense<[1, 128, 4, 128]> : tensor<4xi32>}> : () -> tensor<4xi32> - %0 = "tfl.slice"(%arg0, %cst_0, %cst_1) : (tensor<1x128x8x256xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x128x4x128xf32> - return %0 : tensor<1x128x4x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir deleted file mode 100644 index bb3a83a3787f6f..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { - %0 = "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : f32}> : (tensor<8x128xf32>) -> tensor<8x128xf32> - return %0 : tensor<8x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_space_to_depth_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_space_to_depth_op.mlir deleted file mode 100644 index 3e336816486285..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_space_to_depth_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x432x576x6xf32>) -> tensor<1x216x288x24xf32> { - %0 = "tfl.space_to_depth"(%arg0) <{block_size = 2 : i32}> : (tensor<1x432x576x6xf32>) -> tensor<1x216x288x24xf32> - return %0 : tensor<1x216x288x24xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_split_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_split_op.mlir deleted file mode 100644 index 38c99095a01319..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_split_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x4x3x3xf32>) -> tensor<1x4x3x1xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<3> : tensor}> : () -> tensor - %0:3 = "tfl.split"(%cst, %arg0) <{num_splits = 3 : i32}> : (tensor, tensor<1x4x3x3xf32>) -> (tensor<1x4x3x1xf32>, tensor<1x4x3x1xf32>, tensor<1x4x3x1xf32>) - return %0#0 : tensor<1x4x3x1xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir deleted file mode 100644 index 9d098eb0b9f61d..00000000000000 --- 
a/tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<131072x4xi32>, %arg2: tensor<131072xf32>) -> tensor<1x128x4x256xf32> { - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{scatter_dimension_numbers = #stablehlo.scatter}> ({ - ^bb0(%arg3: tensor, %arg4: tensor): - stablehlo.return %arg4 : tensor - }) : (tensor<1x128x4x256xf32>, tensor<131072x4xi32>, tensor<131072xf32>) -> tensor<1x128x4x256xf32> - return %0 : tensor<1x128x4x256xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir deleted file mode 100644 index 373eff80ff3cd8..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4xi32>, %arg3: tensor<4xi32>) -> tensor<1x128x4x128xf32> { - %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<1x128x4x256xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x128x4x128xf32> - return %0 : tensor<1x128x4x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir deleted file mode 100644 index e1483fed87d802..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x4x128xf32>, %arg1: tensor<1x128x4x128xf32>) -> tensor<1x128x4x128xf32> { - %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x4x128xf32> - return %0 : tensor<1x128x4x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir deleted file mode 100644 index bb4613d5b4b6c5..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x2304xf32>) -> tensor<1x128x1xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> - %0 = "tfl.sum"(%arg0, %cst) <{keep_dims = true}> : (tensor<1x128x2304xf32>, tensor<1xi32>) -> tensor<1x128x1xf32> - return %0 : tensor<1x128x1xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir deleted file mode 100644 index ce1d0302c8a838..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { - %0 = "tfl.tanh"(%arg0) : (tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> - return %0 : tensor<1x128x8x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir deleted file mode 100644 index f24d72216897fd..00000000000000 --- 
a/tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<128x4x2x128xf32>) -> tensor<128x2x4x128xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> - %0 = "tfl.transpose"(%arg0, %cst) : (tensor<128x4x2x128xf32>, tensor<4xi32>) -> tensor<128x2x4x128xf32> - return %0 : tensor<128x2x4x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/two_adds.mlir b/tensorflow/lite/experimental/litert/test/testdata/two_adds.mlir deleted file mode 100644 index 463dd456dc5c5d..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/two_adds.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.add %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %1 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/two_partition.mlir b/tensorflow/lite/experimental/litert/test/testdata/two_partition.mlir deleted file mode 100644 index 738c8309110318..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/two_partition.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.mul %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %2 = tfl.add %1, %1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %3 = tfl.mul %2, %2 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %3 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/unranked_tensor.mlir b/tensorflow/lite/experimental/litert/test/testdata/unranked_tensor.mlir deleted file mode 100644 index 4e2403a7fadbd8..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/unranked_tensor.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<*xf32> - return %0 : tensor<*xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/tools/BUILD b/tensorflow/lite/experimental/litert/tools/BUILD deleted file mode 100644 index 88e02b5246d4d0..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/BUILD +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/vendors/qualcomm:qualcomm_build_defs.bzl", "litert_cc_bin_with_qnn") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "apply_plugin", - srcs = ["apply_plugin.cc"], - hdrs = ["apply_plugin.h"], - deps = [ - ":dump", - ":outstream", - ":tool_display", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_flags", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", - "//tensorflow/lite/experimental/litert/core/model:model_serialize", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "apply_plugin_test", - srcs = ["apply_plugin_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", - ], - tags = [ - "noasan", - "nomsan", - "nosan", - "notsan", - ], - deps = [ - ":apply_plugin", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -# TODO - @lukeboyer: Figure out some selective inclusiion of the data deps, some are very large. 
-litert_cc_bin_with_qnn( - name = "apply_plugin_main", - srcs = ["apply_plugin_main.cc"], - data = [ - # copybara:uncomment_begin(google-only) - # "//platforms/darwinn/compiler:compiler_api_wrapper", - # copybara:uncomment_end - "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", - "//tensorflow/lite/experimental/litert/vendors/google_tensor/compiler:google_tensor_compiler_plugin_so", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:qnn_compiler_plugin_so", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:compiler_plugin_so", - ], - export_litert_only = 1, - include_system = 1, - linkstatic = 1, - # copybara:uncomment malloc = "//base:system_malloc", - tags = [ - "noasan", - "nobuilder", - "nomsan", - "nosan", - ], - ungrte = True, - deps = [ - ":apply_plugin", - ":outstream", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_flags", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@llvm-project//llvm:Support", - ], -) - -# Fork of "apply_plugin_main" without the "ungrte" so this tool can be used as part of larger -# integration test pipelines with example_plugin. -cc_binary( - name = "apply_plugin_main_for_test", - testonly = 1, - srcs = ["apply_plugin_main.cc"], - data = [ - "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", - ], - linkstatic = 1, - tags = [ - "noasan", - "nomsan", - "nosan", - ], - deps = [ - ":apply_plugin", - ":outstream", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_flags", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@llvm-project//llvm:Support", - ], -) - -cc_library( - name = "tool_display", - srcs = ["tool_display.cc"], - hdrs = ["tool_display.h"], - deps = [ - ":outstream", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "tool_display_test", - srcs = ["tool_display_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - deps = [ - ":tool_display", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "dump", - srcs = ["dump.cc"], - hdrs = ["dump.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", - "//tensorflow/lite/experimental/litert/core/model", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "dump_test", - srcs = ["dump_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - deps = [ - ":dump", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "outstream", - hdrs = ["outstream.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_logging", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "benchmark_litert_model", - srcs = ["benchmark_litert_model.cc"], - hdrs = ["benchmark_litert_model.h"], - deps = [ - "//tensorflow/lite/c:c_api_types", - 
"//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/tools:utils", - "//tensorflow/lite/tools/benchmark:benchmark_model_lib", - "//tensorflow/lite/tools/benchmark:benchmark_params", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "benchmark_litert_model_test", - srcs = ["benchmark_litert_model_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/mobilenet_v2_1.0_224.tflite", - ], - env = { - "ASAN_OPTIONS": "detect_odr_violation=0", - }, - tags = [ - "manual", - "notap", - "requires-gpu-nvidia", - ], - deps = - [ - ":benchmark_litert_model", - "@com_google_googletest//:gtest_main", - # copybara:uncomment_begin(google-only) - # "//third_party/odml/infra/ml_drift_delegate/litert:ml_drift_cl_accelerator", # buildcleaner: keep - # copybara:uncomment_end - "//tensorflow/lite/core/c:c_api_types", - "//tensorflow/lite/tools/benchmark:benchmark_model_lib", - "//tensorflow/lite/tools/benchmark:benchmark_params", - ], -) - -# We create a library for benchmark_main.cc to faciliate the creation of a -# customized benchmark model binary that only needs linking with extra -# dependency, e.g., enabling creating of benchmark binaries with a custom -# delegate provider. -cc_library( - name = "benchmark_model_main", - srcs = [ - "benchmark_litert_model_main.cc", - ], - deps = [ - ":benchmark_litert_model", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/tools:logging", - ], -) diff --git a/tensorflow/lite/experimental/litert/tools/README.md b/tensorflow/lite/experimental/litert/tools/README.md deleted file mode 100644 index 400f9c1f9a5b19..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/README.md +++ /dev/null @@ -1,24 +0,0 @@ -## run_model - -This is a simple tool to run a model with the CompiledModel API. - -``` -run_model --graph= -``` - -### Use NPU via Dispatch API - -If you're using the Dispatch API, you need to pass the Dispatch library -(libLiteRtDispatch_xxx.so) location via `--dispatch_library_dir` - -``` -run_model --graph= --dispatch_library_dir= -``` - -### Use GPU - -If you run a model with GPU accelerator, use `--use_gpu` flag. - -``` -run_model --graph= --use_gpu -``` diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin.cc deleted file mode 100644 index 36db8e81000fe4..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin.cc +++ /dev/null @@ -1,515 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/tools/apply_plugin.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" -#include "tensorflow/lite/experimental/litert/tools/tool_display.h" - -namespace litert::tools { - -using ::litert::BufferRef; -using ::litert::internal::CompilerFlags; -using ::litert::internal::CompilerPlugin; -using ::litert::internal::Dump; -using ::litert::internal::PartitionResult; -using ::litert::internal::SerializeModel; -using ::litert::internal::VerifyFlatbuffer; -using ::litert::tools::ApplyPluginRun; - -#define LITERT_ENSURE_CONFIG(expr) \ - if (!(expr)) { \ - return kLiteRtStatusErrorInvalidToolConfig; \ - } - -namespace { - -class Context { - public: - using Ptr = std::unique_ptr; - - explicit Context(ApplyPluginRun::Ptr run) - : run_(std::move(run)), - display_(ToolDisplay(std::move(run_->dump_out), - Context::CmdStr(run_->cmd))) {} - - ApplyPluginRun::Cmd Cmd() const { return run_->cmd; } - - absl::Span LibSearchPaths() const { - return absl::MakeConstSpan(run_->lib_search_paths.data(), - run_->lib_search_paths.size()); - } - - absl::string_view SocModelTarget() const { - ABSL_CHECK_EQ(run_->soc_models.size(), 1); - return run_->soc_models.front(); - } - - absl::string_view SocManufacturer() const { - return run_->soc_manufacturer.value(); - } - - std::ostream& Out(size_t out_ind = 0) { - ABSL_CHECK_GE(run_->outs.size(), 1); - return run_->outs.at(out_ind); - } - - const CompilerFlags& Flags() const { return run_->compiler_flags; } - - OutStream SwapOut(OutStream out) { - ABSL_CHECK_EQ(run_->outs.size(), 1); - auto res = run_->outs.front(); - run_->outs.at(0) = out; - return res; - } - - uint32_t NumOuts() const { return run_->outs.size(); } - - const ApplyPluginRun& Run() const { return *run_; } - ApplyPluginRun& Run() { return *run_; } - - ToolDisplay& Dump() { return display_; } - - static absl::string_view CmdStr(ApplyPluginRun::Cmd cmd); - - private: - ApplyPluginRun::Ptr run_; - ToolDisplay display_; -}; - -void DumpSubgraphs(ToolDisplay& display, absl::string_view label, - absl::Span subgraphs) { - for (auto* subgraph : subgraphs) { - display.Labeled(); - display.Indented() << absl::StreamFormat("(%s graph)", label); - Dump(*subgraph, display.Display()); - } -} - -void DumpCompilationRequest(ToolDisplay& display, absl::string_view soc_model, - size_t num_subgraphs, const CompilerFlags& flags) { - display.Labeled() << absl::StreamFormat( - "Requesting compilation for target `%s` on %lu " - "partitions with flags: ", - soc_model, num_subgraphs) 
- << flags << "\n"; -} - -void DumpCompilationResult(ToolDisplay& display, size_t byte_code_size, - size_t num_entry_points) { - display.Labeled() << absl::StreamFormat( - "Compiled %lu partitions into %lu bytes\n", num_entry_points, - byte_code_size); -} - -void DumpModelStats(ToolDisplay& display, BufferRef buf) { - display.Labeled() << absl::StreamFormat( - "Serialized a model of size %lu bytes\n", buf.Size()); -} - -void DumpPartitionResult(ToolDisplay& display, const PartitionResult& result) { - display.Labeled() << absl::StreamFormat( - "Partitioning yielded %lu new subgraphs\n", result.second.Size()); - - DumpSubgraphs(display, "new subgraphs", result.second.Elements()); -} - -absl::string_view Context::CmdStr(ApplyPluginRun::Cmd cmd) { - switch (cmd) { - case ApplyPluginRun::Cmd::INFO: - return "INFO"; - case ApplyPluginRun::Cmd::NOOP: - return "NOOP"; - case ApplyPluginRun::Cmd::PARTITION: - return "PARTITION"; - case ApplyPluginRun::Cmd::COMPILE: - return "COMPILE"; - case ApplyPluginRun::Cmd::APPLY: - return "APPLY"; - } -} - -Expected> LoadAllPlugins(Context& ctx) { - ctx.Dump().Start("Load Plugins"); - ctx.Dump().Labeled() << "Loading plugins from: "; - const auto paths = ctx.LibSearchPaths(); - for (auto it = paths.begin(); it < paths.end(); ++it) { - ctx.Dump().Display() << *it; - if (it < paths.end() - 1) { - ctx.Dump().Display() << ", "; - } - } - ctx.Dump().Display() << "\n"; - - auto plugins = CompilerPlugin::LoadPlugins(ctx.LibSearchPaths()); - if (!plugins.HasValue()) { - ctx.Dump().Fail(); - return plugins; - } - ctx.Dump().Labeled() << "Found plugins\n"; - ctx.Dump().Labeled() << absl::StreamFormat("Loaded %lu plugins\n", - plugins.Value().size()); - - ctx.Dump().Done(); - return plugins; -} - -Expected LoadPlugin(Context& ctx) { - auto plugins = LoadAllPlugins(ctx); - if (!plugins) { - return plugins.Error(); - } - - ctx.Dump().Start("Select Plugin"); - - for (auto& plugin : *plugins) { - if (plugin.SocManufacturer() == ctx.Run().soc_manufacturer) { - ctx.Dump().Labeled() << absl::StreamFormat("Selected plugin for: %s\n", - plugin.SocManufacturer()); - ctx.Dump().Done(); - return std::move(plugin); - } - } - - ctx.Dump().Fail(); - return Unexpected(kLiteRtStatusErrorNotFound); -} - -Expected LoadModel(Context& ctx) { - ctx.Dump().Start("Load Model"); - ctx.Dump().Labeled() << absl::StreamFormat("Loading model from: %s\n", - ctx.Run().model.value()); - auto model_result = Model::CreateFromFile(ctx.Run().model->data()); - if (!model_result.HasValue()) { - ctx.Dump().Labeled() << "Failed to load model from file."; - ctx.Dump().Fail(); - return model_result; - } - - ctx.Dump().Labeled(); - Dump(*model_result.Value().Get(), ctx.Dump().Display()); - ctx.Dump().Done(); - - return model_result; -} - -// -// INFO Command -// - -LiteRtStatus ValidateInfoRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); - LITERT_ENSURE_CONFIG(run.outs.size() == 1); - return kLiteRtStatusOk; -} - -LiteRtStatus Info(Context& ctx) { - auto plugins = LoadAllPlugins(ctx); - if (!plugins) { - return plugins.Error().Status(); - } - - for (auto& plugin : *plugins) { - ctx.Out() << absl::StreamFormat("< LiteRtCompilerPlugin > \"%s\" | ", - plugin.SocManufacturer()); - const auto& models = plugin.SocModels(); - for (auto it = models.begin(); it < models.end(); ++it) { - ctx.Out() << absl::StreamFormat("\"%s\"", *it); - if (it < models.end() - 1) { - ctx.Out() << ", "; - } - } - ctx.Out() << "\n"; - } - return kLiteRtStatusOk; -} - -// -// NOOP Command -// - 
-LiteRtStatus ValidateNoopRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(run.model.has_value()); - LITERT_ENSURE_CONFIG(run.outs.size() == 1); - return kLiteRtStatusOk; -} - -LiteRtStatus Noop(Context& ctx) { - auto model = LoadModel(ctx); - if (!model) { - return model.Error().Status(); - } - - auto serialized = SerializeModel(std::move(*model->Get())); - if (!serialized) { - return serialized.Error().Status(); - } - LITERT_ENSURE(VerifyFlatbuffer(serialized->Span()), - kLiteRtStatusErrorInvalidFlatbuffer, - "Failed to validate flatbuffer"); - serialized->WriteStr(ctx.Out()); - return kLiteRtStatusOk; -} - -// -// PARTITION Command -// - -LiteRtStatus ValidatePartitionRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); - LITERT_ENSURE_CONFIG(run.model.has_value() && !run.model.value().empty()); - LITERT_ENSURE_CONFIG(run.soc_manufacturer.has_value()); - LITERT_ENSURE_CONFIG(!run.outs.empty()); - return kLiteRtStatusOk; -} - -LiteRtStatus Partition(Context& ctx) { - auto plugin = LoadPlugin(ctx); - if (!plugin) { - return plugin.Error().Status(); - } - - auto model_wrap = LoadModel(ctx); - if (!model_wrap) { - return model_wrap.Error().Status(); - } - auto& model = *model_wrap->Get(); - - ctx.Dump().Start("Partitioning model"); - auto partition_result = PartitionModel(*plugin, model, ctx.Run().subgraphs); - if (!partition_result) { - return partition_result.Error().Status(); - } - ctx.Dump().Done(); - DumpPartitionResult(ctx.Dump(), *partition_result); - - auto& new_subgraphs = partition_result->second; - model.TransferSubgraphsFrom(std::move(new_subgraphs)); - - ctx.Dump().Start("Serializing model"); - auto serialized = SerializeModel(std::move(model)); - DumpModelStats(ctx.Dump(), *serialized); - ctx.Dump().Done(); - - ctx.Dump().Start("Verifying flatbuffer"); - LITERT_ENSURE(VerifyFlatbuffer(serialized->Span()), - kLiteRtStatusErrorInvalidFlatbuffer, - "Failed to validate flatbuffer"); - ctx.Dump().Done(); - - ctx.Dump().Start("Writing to out"); - serialized->WriteStr(ctx.Out()); - ctx.Dump().Done(); - - return kLiteRtStatusOk; -} - -// -// COMPILE Command -// - -LiteRtStatus ValidateCompileRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); - LITERT_ENSURE_CONFIG(run.model.has_value()); - LITERT_ENSURE_CONFIG(run.soc_manufacturer.has_value()); - // TODO: implement multi target compilation.
- LITERT_ENSURE_SUPPORTED(run.soc_models.size() == 1, - "Multi target compilation not implemented."); - return kLiteRtStatusOk; -} - -LiteRtStatus Compile(Context& ctx) { - auto model_wrap = LoadModel(ctx); - if (!model_wrap) { - return model_wrap.Error().Status(); - } - auto& model = *model_wrap->Get(); - - auto plugin = LoadPlugin(ctx); - if (!plugin) { - return plugin.Error().Status(); - } - - ctx.Dump().Start("Compiling"); - DumpCompilationRequest(ctx.Dump(), ctx.SocModelTarget(), model.NumSubgraphs(), - ctx.Flags()); - plugin->SetFlags(ctx.Flags()); - auto compilation_result = plugin->Compile(&model, ctx.SocModelTarget()); - if (!compilation_result) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - - auto num_byte_code = compilation_result->NumByteCodeModules(); - if (!num_byte_code) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - if (*num_byte_code < 1) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - for (int i = 0; i < ctx.NumOuts(); ++i) { - auto byte_code = compilation_result->ByteCode(i); - if (!byte_code) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - auto num_calls = compilation_result->NumCalls(); - if (!num_calls) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - - DumpCompilationResult(ctx.Dump(), byte_code->Size(), *num_calls); - byte_code->WriteStr(ctx.Out(i)); - } - ctx.Dump().Done(); - - return kLiteRtStatusOk; -} - -// -// APPLY Command -// - -LiteRtStatus ValidateApplyRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); - LITERT_ENSURE_CONFIG(run.model.has_value()); - LITERT_ENSURE_CONFIG(run.soc_manufacturer.has_value()); - LITERT_ENSURE_CONFIG(run.outs.size() == run.soc_models.size()); - // TODO: implement multi target compilation.
- LITERT_ENSURE_SUPPORTED(run.soc_models.size() == 1, - "Multi target compilation not implemented."); - return kLiteRtStatusOk; -} - -LiteRtStatus Apply(Context& ctx) { - auto model_wrap = LoadModel(ctx); - if (!model_wrap) { - return model_wrap.Error().Status(); - } - auto& model = *model_wrap->Get(); - - auto plugin = LoadPlugin(ctx); - if (!plugin) { - return plugin.Error().Status(); - } - - ctx.Dump().Start("Applying plugin"); - plugin->SetFlags(ctx.Flags()); - if (auto status = litert::internal::ApplyPlugin( - *plugin, model, ctx.SocModelTarget(), ctx.Run().subgraphs); - !status) { - LITERT_LOG(LITERT_ERROR, "%s", status.Error().Message().c_str()); - return status.Error().Status(); - } - ctx.Dump().Done(); - - ctx.Dump().Start("Serializing model"); - auto serialized = SerializeModel(std::move(model)); - DumpModelStats(ctx.Dump(), *serialized); - ctx.Dump().Done(); - - ctx.Dump().Start("Verifying flatbuffer"); - LITERT_ENSURE(VerifyFlatbuffer(serialized->Span()), - kLiteRtStatusErrorInvalidFlatbuffer, - "Failed to validate flatbuffer"); - ctx.Dump().Done(); - - ctx.Dump().Start("Writing to out"); - serialized->WriteStr(ctx.Out()); - ctx.Dump().Done(); - - return kLiteRtStatusOk; -} - -} // namespace - -LiteRtStatus ApplyPlugin(ApplyPluginRun::Ptr run) { - Context context(std::move(run)); - DumpPreamble(context.Dump()); - - switch (context.Cmd()) { - case ApplyPluginRun::Cmd::INFO: - if (auto stat = ValidateInfoRun(context.Run()); stat != kLiteRtStatusOk) { - context.Dump().Labeled() << "Invalid arguments for INFO command\n"; - return stat; - } - return Info(context); - - case ApplyPluginRun::Cmd::PARTITION: - if (auto stat = ValidatePartitionRun(context.Run()); - stat != kLiteRtStatusOk) { - context.Dump().Labeled() << "Invalid arguments for PARTITION command\n"; - return stat; - } - return Partition(context); - - case ApplyPluginRun::Cmd::COMPILE: - if (auto stat = ValidateCompileRun(context.Run()); - stat != kLiteRtStatusOk) { - context.Dump().Labeled() << "Invalid arguments for COMPILE command\n"; - return stat; - } - return Compile(context); - - case ApplyPluginRun::Cmd::APPLY: - if (auto stat = ValidateApplyRun(context.Run()); - stat != kLiteRtStatusOk) { - context.Dump().Labeled() << "Invalid arguments for APPLY command\n"; - return stat; - } - return Apply(context); - - case ApplyPluginRun::Cmd::NOOP: - - if (auto stat = ValidateNoopRun(context.Run()); stat != kLiteRtStatusOk) { - context.Dump().Labeled() << "Invalid arguments for NOOP command\n"; - return stat; - } - return Noop(context); - - default: - return kLiteRtStatusErrorInvalidArgument; - } - - return kLiteRtStatusOk; -} - -} // namespace litert::tools diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin.h b/tensorflow/lite/experimental/litert/tools/apply_plugin.h deleted file mode 100644 index 8d105836eb8422..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin.h +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" -#include "tensorflow/lite/experimental/litert/tools/outstream.h" - -namespace litert::tools { - -using ::litert::internal::CompilerFlags; - -struct ApplyPluginRun { - // NOTE: All StrFlagT are expected to have static storage duration. - using Ptr = std::unique_ptr<ApplyPluginRun>; - - // A specific command implemented by the tool to run. - enum class Cmd { - // Displays info about all plugins found in the given search paths. - // - // FLAG SEMANTICS: - // "lib_search_paths": Required, at least one. - // "model": Ignored. - // "soc_manufacturer": Optional, filters plugins to display. - // "soc_models": Ignored. - // "outs": Required, must be size one. - // "dump_out": Optional. - INFO, - - // Does nothing and simply de-serializes and re-serializes the given model. - // This is intended for testing and internal debugging only. - // - // FLAG SEMANTICS: - // "lib_search_paths": Ignored. - // "model": Required. - // "soc_manufacturer": Ignored. - // "soc_models": Ignored. - // "outs": Required, must be size one. - // "dump_out": Optional. - NOOP, - - // Runs the entire end to end flow. This is the standard compiler plugin - // usage. A separate compilation step will occur for each soc_model tag that - // is supported by the loaded plugin, and a new output model will be - // generated for each. Partitioning is invariant across different soc_model - // targets from the same manufacturer, so only one partitioning step will - // occur even if multiple targets are requested. - // - // FLAG SEMANTICS: - // "lib_search_paths": Required, at least one. - // "model": Required. - // "soc_manufacturer": Required. - // "soc_models": Required, at least one. - // "outs": Required, must be size equal to "soc_models". - // "dump_out": Optional. - // - // TODO: Support multi target compilation. - APPLY, - - // Only run the partition step and skip compilation. Writes a ".tflite" model - // to "outs" where selected partitions are manifested as new standard - // flatbuffer subgraphs added to the input model. - // The partitions' original locations are replaced with a single custom op - // that contains an identifier to the corresponding partition (new subgraph). - // This is intended for testing and development. - // - // FLAG SEMANTICS: - // "lib_search_paths": Required, at least one. - // "model": Required. - // "soc_manufacturer": Required. - // "soc_models": Ignored. - // "outs": Required, must be size one. - // "dump_out": Optional. - PARTITION, - - // Skip partitioning and run the entire input model through compilation - // directly. Fails if any ops in the input model are unsupported by the - // plugin. Writes the raw compiled result to the "outs" stream without any - // wrapping flatbuffer. Runs multi-target compilation as in "APPLY". - // Intended for testing and development. - // - // FLAG SEMANTICS: - // "lib_search_paths": Required, at least one. - // "model": Required. - // "soc_manufacturer": Required. - // "soc_models": Required, at least one. - // "outs": Required, must be size equal to "soc_models". - // "dump_out": Optional. - // - // TODO: Support multi target compilation.
-    COMPILE,
-  };
-
-  // A command to run, see above.
-  Cmd cmd;
-
-  // Collection of paths on the local file system dictating where the tool
-  // should look for suitable LiteRtCompilerPlugin shared libraries. The tool
-  // will select the first ".so" file found with prefix "libLiteRtPlugin" that
-  // has the "soc_manufacturer" tag passed. Providing more than one plugin
-  // shared library for the same manufacturer results in an error.
-  std::vector<absl::string_view> lib_search_paths = {};
-
-  // Path to the ".tflite" model the tool should operate on.
-  std::optional<absl::string_view> model = {};
-
-  // A tag representing a manufacturer the tool should target for compilation.
-  // This is used to select the appropriate plugin if multiple plugins are
-  // found in "lib_search_paths".
-  std::optional<absl::string_view> soc_manufacturer = {};
-
-  // Collection of soc model tags the tool should target for compilation.
-  std::vector<absl::string_view> soc_models = {};
-
-  // Where the tool should write its result file(s) to. If the command runs
-  // compilation, an "out" stream should be passed for each "soc_model" target
-  // requested for compilation. Output for the "ith" target will be written to
-  // the "ith" outs stream.
-  std::vector<OutStream> outs = {std::cout};
-
-  // Where to direct logging for this run. Passing nullopt here indicates
-  // "silent" behavior and should only be used when this tool is part of a
-  // larger pipeline like an end2end test.
-  UserStream dump_out;
-
-  // Compiler flags to pass to the plugin. Only relevant for "APPLY" and
-  // "COMPILE" commands.
-  CompilerFlags compiler_flags;
-
-  // If provided, only the subgraphs with the given indices are applied with
-  // the plugin.
-  absl::flat_hash_set<uint32_t> subgraphs = {};
-};
-
-LiteRtStatus ApplyPlugin(ApplyPluginRun::Ptr run);
-
-}  // namespace litert::tools
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_
diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc
deleted file mode 100644
index 261ecd494e855f..00000000000000
--- a/tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc
+++ /dev/null
@@ -1,157 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
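// A hedged sketch of driving the APPLY command programmatically, following
// the flag semantics documented in apply_plugin.h above; the plugin directory
// and model path are illustrative placeholders, not values from this change.
//
//   auto run = std::make_unique<ApplyPluginRun>();
//   run->cmd = ApplyPluginRun::Cmd::APPLY;
//   run->lib_search_paths.push_back("/path/to/plugin/libs");  // placeholder
//   run->model.emplace("/tmp/model.tflite");                  // placeholder
//   run->soc_manufacturer.emplace("ExampleSocManufacturer");
//   run->soc_models.push_back("ExampleSocModel");
//   std::stringstream out;
//   run->outs.clear();
//   run->outs.push_back(out);  // one out stream per requested soc_model
//   LiteRtStatus status = ApplyPlugin(std::move(run));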
- -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/Support/CommandLine.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" -#include "tensorflow/lite/experimental/litert/tools/apply_plugin.h" -#include "tensorflow/lite/experimental/litert/tools/outstream.h" - -using ::litert::tools::ApplyPlugin; -using ::litert::tools::ApplyPluginRun; -using ::litert::tools::UserStream; - -// NOLINTNEXTLINE -static llvm::cl::opt cmd( - llvm::cl::Positional, - llvm::cl::desc("Routine to run (apply, partition, compile, info, noop)."), - llvm::cl::init("partition")); - -// NOLINTNEXTLINE -static llvm::cl::opt model( - "model", llvm::cl::desc("Path to flatbuffer file."), llvm::cl::init("")); - -// TODO: b/366821557 - Support path to pre-compiled plugin in flags. -// NOLINTNEXTLINE -static llvm::cl::opt soc_manufacturer( - "soc_man", - llvm::cl::desc("String identifier of SoC manufacturer (e.g., GoogleTensor, " - "Qualcomm)."), - llvm::cl::init("ExampleSocManufacturer")); - -// TODO: Support multi target compilation. -// NOLINTNEXTLINE -static llvm::cl::opt soc_model("soc_model", - llvm::cl::desc("Target SoC model."), - llvm::cl::init("ExampleSocModel")); - -// NOLINTNEXTLINE -static llvm::cl::list libs( - "libs", - llvm::cl::desc("List of directories in which to search for suitable " - "compiler plugin shared libraries."), - llvm::cl::list_init(llvm::ArrayRef{ - "third_party/tensorflow/lite/experimental/litert/vendors/examples", - "third_party/tensorflow/lite/experimental/litert/vendors/qualcomm/" - "compiler", - "third_party/tensorflow/lite/experimental/litert/vendors/mediatek/" - "compiler", - "third_party/tensorflow/lite/experimental/litert/vendors/" - "google_tensor/compiler"})); - -// NOLINTNEXTLINE -static llvm::cl::list outs( - "o", - llvm::cl::desc("Path to files for output, \"-\" indicates standard out, " - "\"--\" for standard err, \"none\" for null stream."), - llvm::cl::list_init(llvm::ArrayRef{"-"})); - -// NOLINTNEXTLINE -static llvm::cl::opt err( - "err", - llvm::cl::desc("Path to file for err output, \"-\" indicates standard out, " - "\"--\" for standard err, \"none\" for null stream."), - llvm::cl::init("--")); - -// NOLINTNEXTLINE -static llvm::cl::opt compiler_flags( - "compiler-flags", - llvm::cl::desc("List of comma separated (no space) compiler flags. Flags " - "may be key-value pairs " - "in the format of \"key=value\", or just \"key\". E.g. 
" - "\"--compiler-flags=key1=value1,key2\"")); - -// NOLINTNEXTLINE -static llvm::cl::list subgraphs( - "subgraphs", - llvm::cl::desc("If provides, only the subgraphs with the given indices " - "are applied with the plugin."), - llvm::cl::list_init(llvm::ArrayRef{})); - -ApplyPluginRun::Ptr ParseFlags() { - auto res = std::make_unique(); - - if (!model.empty()) { - res->model = model; - } - - res->compiler_flags = *litert::internal::ParseCompilerFlags(compiler_flags); - - res->soc_manufacturer = soc_manufacturer; - res->soc_models.push_back(soc_model); - - res->lib_search_paths.assign(libs.begin(), libs.end()); - - if (cmd == "apply") { - res->cmd = ApplyPluginRun::Cmd::APPLY; - } else if (cmd == "partition") { - res->cmd = ApplyPluginRun::Cmd::PARTITION; - } else if (cmd == "compile") { - res->cmd = ApplyPluginRun::Cmd::COMPILE; - } else if (cmd == "info") { - res->cmd = ApplyPluginRun::Cmd::INFO; - } else if (cmd == "noop") { - res->cmd = ApplyPluginRun::Cmd::NOOP; - } else { - return nullptr; - } - - for (auto subgraph_idx : subgraphs) { - res->subgraphs.insert(subgraph_idx); - } - - return res; -} - -int main(int argc, char* argv[]) { - llvm::cl::ParseCommandLineOptions(argc, argv); - - auto run = ParseFlags(); - if (run == nullptr) { - return 1; - } - - run->outs.clear(); - std::vector> oss; - for (const auto& out : outs) { - oss.push_back(std::make_unique( - UserStream::MakeFromFlag(out))); - run->outs.push_back(oss.back()->Get()); - } - - run->dump_out = UserStream::MakeFromFlag(err); - - run->dump_out.Get() << absl::StreamFormat( - "CMD: %s\nMODEL: %s\nSOC_MANUFACTURER: %s\nSOC_MODEL: %s\n", cmd, model, - soc_manufacturer, soc_model); - - return ApplyPlugin(std::move(run)); -} diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc deleted file mode 100644 index b86bc5ec19f874..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/tools/apply_plugin.h" - -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace litert::tools { -namespace { - -using ::litert::internal::kLiteRtBuildStampKey; -using ::litert::internal::ParseBuildStamp; -using ::testing::HasSubstr; -using ::testing::litert::IsError; - -static constexpr absl::string_view kPluginSearchPath = - "third_party/tensorflow/lite/experimental/litert/vendors/examples"; - -static constexpr absl::string_view kSocManufacturer = "ExampleSocManufacturer"; - -static constexpr absl::string_view kSocModel = "ExampleSocModel"; - -absl::string_view TestModelPath(absl::string_view filename) { - static char kModelPath[512] = {}; - const auto model_path = ::litert::testing::GetTestFilePath(filename); - ABSL_CHECK(model_path.size() < 512); - model_path.copy(kModelPath, model_path.size(), 0); - return kModelPath; -} - -ApplyPluginRun::Ptr MakeBaseRun( - ApplyPluginRun::Cmd cmd, absl::string_view model_path = "one_mul.tflite") { - auto run = std::make_unique(); - run->cmd = cmd; - run->lib_search_paths.push_back(kPluginSearchPath); - run->model.emplace(TestModelPath(model_path)); - run->soc_manufacturer.emplace(kSocManufacturer); - run->soc_models.push_back(kSocModel); - run->outs.clear(); - return run; -} - -TEST(TestApplyPluginTool, TestInfoBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); - run->lib_search_paths.clear(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestInfo) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); - std::stringstream out; - run->outs.push_back(out); - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - EXPECT_THAT(out.str(), - ::testing::HasSubstr( - "< LiteRtCompilerPlugin > \"ExampleSocManufacturer\" | " - "\"ExampleSocModel\"")); -} - -TEST(TestApplyPluginTool, TestNoopBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); - run->model.reset(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestNoop) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); - std::stringstream out; - run->outs.push_back(out); - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - - auto model = Model::CreateFromBuffer( - BufferRef(out.view().data(), out.view().size())); - EXPECT_EQ(model->Get()->NumSubgraphs(), 1); -} - -TEST(TestApplyPluginTool, TestPartitionBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); - run->model.reset(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestPartition) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); - std::stringstream out; - run->outs.push_back(out); 
- LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - EXPECT_FALSE(out.str().empty()); -} - -TEST(TestApplyPluginTool, TestCompileBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); - run->model.reset(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestCompile) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); - std::stringstream out; - run->outs.push_back(out); - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - EXPECT_FALSE(out.str().empty()); - EXPECT_THAT(out.str(), HasSubstr("Partition_0_with_1_muls")); -} - -TEST(TestApplyPluginTool, TestApplyBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); - run->model.reset(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestApply) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); - std::stringstream out; - run->outs.push_back(out); - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - - const auto out_str = out.str(); - BufferRef serialized(out_str.data(), out_str.size()); - - auto model = Model::CreateFromBuffer(serialized); - EXPECT_EQ(model->Get()->NumSubgraphs(), 1); - - { - auto stamp_buffer = model->Get()->FindMetadata(kLiteRtBuildStampKey); - auto stamp = ParseBuildStamp(*stamp_buffer); - auto [man, soc_model] = *stamp; - EXPECT_EQ(man, kSocManufacturer); - EXPECT_EQ(soc_model, kSocModel); - } - - auto* op = model->Get()->MainSubgraph()->Ops().front(); - ASSERT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - - const auto options = internal::GetDispatchOpOptions(op->CustomOptions()); - const auto& [size, offset, name] = options; - EXPECT_EQ(name, "Partition_0"); - ASSERT_LE(offset + size, serialized.Size()); - - EXPECT_THAT(serialized.StrView().substr(offset, size), - HasSubstr("Partition_0_with_1_muls")); -} - -TEST(TestApplyPluginTool, TestCompileToMultiByteCode) { - auto run = - MakeBaseRun(ApplyPluginRun::Cmd::COMPILE, "multi_subgraph_mul.tflite"); - std::stringstream out_0; - std::stringstream out_1; - run->outs.push_back(out_0); - run->outs.push_back(out_1); - - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - EXPECT_FALSE(out_0.str().empty()); - EXPECT_FALSE(out_1.str().empty()); - EXPECT_THAT(out_0.str(), HasSubstr("Partition_0_with_1_muls")); - EXPECT_THAT(out_1.str(), HasSubstr("Partition_1_with_1_muls")); -} - -} // namespace -} // namespace litert::tools diff --git a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.cc b/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.cc deleted file mode 100644 index b82fdc18fc0228..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.cc +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h" - -#include -#include -#include -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" - -namespace litert::benchmark { -namespace { -using ::litert::CompilationOptions; -using ::litert::CompiledModel; -using ::litert::TensorBuffer; - -CompilationOptions CreateCompiledModelOptions(const BenchmarkParams& params) { - auto use_gpu = params.Get("use_gpu"); - CompilationOptions compilation_options = - *litert::CompilationOptions::Create(); - if (use_gpu) { - compilation_options.SetHardwareAccelerators( - LiteRtHwAccelerators::kLiteRtHwAcceleratorGpu); - } - return compilation_options; -} -} // namespace - -TfLiteStatus BenchmarkLiteRtModel::Init() { - std::string fd_or_graph_path = params_.Get("graph"); - LITERT_LOG(LITERT_INFO, "Loading model from: %s", fd_or_graph_path.c_str()); - model_ = *litert::Model::CreateFromFile(fd_or_graph_path); - if (!model_) { - LITERT_LOG(LITERT_ERROR, "Failed to load model: %s", - fd_or_graph_path.c_str()); - return kTfLiteError; - } - - auto env = Environment::Create({}); - if (!env) { - LITERT_LOG(LITERT_ERROR, "Failed to create litert environment."); - return kTfLiteError; - } - - auto compilation_options = CreateCompiledModelOptions(params_); - auto compiled_model_result = - litert::CompiledModel::Create(*env, model_, compilation_options); - if (!compiled_model_result) { - LITERT_LOG(LITERT_ERROR, "Failed to create compiled model."); - return kTfLiteError; - } - - compiled_model_ = std::make_unique( - std::move(*compiled_model_result)); - auto signature = params_.Get("signature_to_run_for"); - auto input_buffers_result = compiled_model_->CreateInputBuffers(signature); - if (!input_buffers_result) { - LITERT_LOG(LITERT_ERROR, "Failed to create input buffers."); - return kTfLiteError; - } - input_buffers_ = std::make_unique>( - std::move(*input_buffers_result)); - - auto output_buffers_result = compiled_model_->CreateOutputBuffers(signature); - if (!output_buffers_result) { - LITERT_LOG(LITERT_ERROR, "Failed to create output buffers."); - return kTfLiteError; - } - output_buffers_ = std::make_unique>( - std::move(*output_buffers_result)); - - return kTfLiteOk; -} -} // namespace litert::benchmark diff --git a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h b/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h deleted file mode 100644 index 8534efddee78ab..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_BENCHMARK_LITERT_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_BENCHMARK_LITERT_MODEL_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/tools/benchmark/benchmark_model.h" -#include "tensorflow/lite/tools/benchmark/benchmark_params.h" -#include "tensorflow/lite/tools/utils.h" - -namespace litert { -namespace benchmark { - -using ::litert::CompiledModel; -using ::litert::Environment; -using ::litert::Model; -using ::litert::TensorBuffer; -using ::tflite::benchmark::BenchmarkModel; -using ::tflite::benchmark::BenchmarkParam; -using ::tflite::benchmark::BenchmarkParams; -using ::tflite::utils::InputTensorData; - -class BenchmarkLiteRtModel : public BenchmarkModel { - public: - BenchmarkLiteRtModel() = default; - explicit BenchmarkLiteRtModel(BenchmarkParams params) - : BenchmarkModel(std::move(params)) {} - ~BenchmarkLiteRtModel() override = default; - static BenchmarkParams DefaultParams() { - BenchmarkParams default_params = BenchmarkModel::DefaultParams(); - default_params.AddParam("graph", BenchmarkParam::Create("")); - default_params.AddParam("signature_to_run_for", - BenchmarkParam::Create("")); - default_params.AddParam("use_xnnpack", BenchmarkParam::Create(true)); - default_params.AddParam("use_gpu", BenchmarkParam::Create(false)); - - return default_params; - } - - TfLiteStatus Init() override; - - int64_t MayGetModelFileSize() override { - std::string fd_or_graph_path = params_.Get("graph"); - // Path can be one of the following: - // 1) File descriptor path: path must be in the format of - // "fd:%model_fd%:%model_offset%:%model_size%". - // 2) File path: path to the model file. - // Please see tensorflow/lite/tools/model_loader.h for more information. 
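// For example (illustrative values only, assuming the format above):
//   "fd:42:0:8192"                 -> fd 42, offset 0, model size 8192 bytes;
//   "/data/local/tmp/model.tflite" -> a plain file path, sized via ifstream.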
- std::vector parts = - absl::StrSplit(fd_or_graph_path, ':'); - if (!parts.empty() && parts[0] == "fd") { - int64_t model_size = -1; - if (parts.size() != 4 || !absl::SimpleAtoi(parts[3], &model_size)) { - LITERT_LOG(LITERT_ERROR, "Failed to parse model file size: %s", - fd_or_graph_path.c_str()); - } - return model_size; - } - std::ifstream in_file(fd_or_graph_path, std::ios::binary | std::ios::ate); - return in_file.tellg(); - } - - TfLiteStatus RunImpl() override { - if (!compiled_model_) { - LITERT_LOG(LITERT_ERROR, "Compiled model not initialized"); - return kTfLiteError; - } - auto signature = params_.Get("signature_to_run_for"); - if (compiled_model_->Run(signature, *input_buffers_, *output_buffers_)) { - return kTfLiteOk; - } else { - LITERT_LOG(LITERT_ERROR, "Run failed"); - return kTfLiteError; - } - } - - uint64_t ComputeInputBytes() override { - uint64_t total_bytes = 0; - for (const auto& buffer : *input_buffers_) { - total_bytes += *buffer.Size(); - } - return total_bytes; - } - - InputTensorData CreateRandomTensorData(const litert::TensorBuffer& t, - std::string name) { - float low_range = 0; - float high_range = 0; - tflite::utils::GetDataRangesForType( - static_cast(t.TensorType()->ElementType()), &low_range, - &high_range); - return tflite::utils::CreateRandomTensorData( - name, static_cast(t.TensorType()->ElementType()), *t.Size(), - low_range, high_range); - } - - TfLiteStatus PrepareInputData() override { - int index = 0; - for (auto& buffer : *input_buffers_) { - auto t_data = - CreateRandomTensorData(buffer, "input_" + std::to_string(index)); - buffer.Write(absl::MakeSpan( - reinterpret_cast(t_data.data.get()), t_data.bytes)); - ++index; - } - return kTfLiteOk; - } - - TfLiteStatus ResetInputsAndOutputs() override { return kTfLiteOk; } - - private: - Model model_; - std::unique_ptr compiled_model_; - std::unique_ptr> input_buffers_; - std::unique_ptr> output_buffers_; -}; - -} // namespace benchmark -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_BENCHMARK_LITERT_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_main.cc b/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_main.cc deleted file mode 100644 index 8cf1891085f5b9..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_main.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h" -#include "tensorflow/lite/tools/logging.h" - -namespace litert::benchmark { - -int Main(int argc, char** argv) { - TFLITE_LOG(INFO) << "STARTING!"; - BenchmarkLiteRtModel benchmark; - if (benchmark.Run(argc, argv) != kTfLiteOk) { - TFLITE_LOG(ERROR) << "Benchmarking failed."; - return EXIT_FAILURE; - } - return EXIT_SUCCESS; -} -} // namespace litert::benchmark - -int main(int argc, char** argv) { return litert::benchmark::Main(argc, argv); } diff --git a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_test.cc b/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_test.cc deleted file mode 100644 index 08634b7aff4d41..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_test.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h" - -#include -#include - -#include -#include - -#include -#include "tensorflow/lite/core/c/c_api_types.h" -#include "tensorflow/lite/tools/benchmark/benchmark_model.h" -#include "tensorflow/lite/tools/benchmark/benchmark_params.h" - -namespace litert { -namespace benchmark { -namespace { -using ::litert::benchmark::BenchmarkLiteRtModel; -using ::tflite::benchmark::BenchmarkListener; -using ::tflite::benchmark::BenchmarkParams; -using ::tflite::benchmark::BenchmarkResults; - -static constexpr char kModelPath[] = - "third_party/tensorflow/lite/experimental/litert/test/testdata/" - "mobilenet_v2_1.0_224.tflite"; -static constexpr char kSignatureToRunFor[] = ""; - -class TestBenchmarkListener : public BenchmarkListener { - public: - void OnBenchmarkEnd(const BenchmarkResults& results) override { - results_ = results; - } - - BenchmarkResults results_; -}; - -TEST(BenchmarkLiteRtModelTest, GetModelSizeFromPathSucceeded) { - BenchmarkParams params = BenchmarkLiteRtModel::DefaultParams(); - params.Set("graph", kModelPath); - params.Set("signature_to_run_for", kSignatureToRunFor); - params.Set("num_runs", 1); - params.Set("warmup_runs", 0); - params.Set("use_xnnpack", true); - params.Set("use_gpu", false); - BenchmarkLiteRtModel benchmark = BenchmarkLiteRtModel(std::move(params)); - TestBenchmarkListener listener; - benchmark.AddListener(&listener); - - benchmark.Run(); - - EXPECT_GE(listener.results_.model_size_mb(), 0); -} - -TEST(BenchmarkLiteRtModelTest, GPUAcceleration) { - // MSAN does not support GPU tests. 
-#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported In msan"; -#endif - BenchmarkParams params = BenchmarkLiteRtModel::DefaultParams(); - params.Set("graph", kModelPath); - params.Set("signature_to_run_for", kSignatureToRunFor); - params.Set("use_xnnpack", false); - params.Set("use_gpu", true); - - BenchmarkLiteRtModel benchmark = BenchmarkLiteRtModel(std::move(params)); - - EXPECT_EQ(benchmark.Run(), kTfLiteOk); -} - -} // namespace -} // namespace benchmark -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/tools/dump.cc b/tensorflow/lite/experimental/litert/tools/dump.cc deleted file mode 100644 index e6c84d631773d5..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/dump.cc +++ /dev/null @@ -1,442 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/tools/dump.h" - -#include - -#ifndef __ANDROID__ -#if __has_include() -#include -#endif -#endif - -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -namespace { - -static constexpr int kMaxDisplayCount = 16; - -void DumpNode(const LiteRtTensorT& tensor, std::ostream& out) { - switch (tensor.Type().first) { - case kLiteRtRankedTensorType: - Dump(tensor.Type().second.ranked_tensor_type, out); - break; - case kLiteRtUnrankedTensorType: - Dump(tensor.Type().second.unranked_tensor_type.element_type, out); - break; - default: - out << "UKNOWN_TENSOR_TYPE" << tensor.Type().first; - } - Dump(tensor.Qparams(), out); -} - -void DumpNode(const LiteRtOpT& op, std::ostream& out) { - Dump(op.OpCode(), out); -} - -void DumpSignature(const std::vector& ins, - const std::vector& outs, std::ostream& out) { - out << "("; - for (auto it = ins.begin(); it < ins.end(); ++it) { - DumpNode(**it, out); - if (it != ins.end() - 1) { - out << ", "; - } - } - out << ")"; - - out << " -> "; - const bool paren_outs = outs.size() != 1; - if (paren_outs) { - out << "("; - } - for (auto it = outs.begin(); it < outs.end(); ++it) { - DumpNode(**it, out); - if (it != outs.end() - 1) { - out << ", "; - } - } - if (paren_outs) { - out << ")"; - } -} - -} // namespace - -void Dump(LiteRtOpCode code, std::ostream& out) { - switch (code) { - case kLiteRtOpCodeTflAdd: - out << "TFL_ADD"; - break; - case kLiteRtOpCodeTflMul: - out << "TFL_MUL"; - break; - case kLiteRtOpCodeTflCustom: - out << "TFL_CUSTOM_OP"; - break; - case kLiteRtOpCodeTflSlice: - out << "TFL_SLICE"; - break; - case kLiteRtOpCodeTflDiv: - out << "TFL_DIV"; - break; - case kLiteRtOpCodeTflRsqrt: - out << "TFL_RSQRT"; - break; - case kLiteRtOpCodeTflTanh: - out << 
"TFL_TANH"; - break; - case kLiteRtOpCodeTflSub: - out << "TFL_SUB"; - break; - case kLiteRtOpCodeTflReshape: - out << "TFL_RESHAPE"; - break; - case kLiteRtOpCodeTflBatchMatmul: - out << "TFL_BATCH_MATMUL"; - break; - case kLiteRtOpCodeTflSum: - out << "TFL_SUM"; - break; - case kLiteRtOpCodeTflConcatenation: - out << "TFL_CONCATENATION"; - break; - case kLiteRtOpCodeTflSoftmax: - out << "TFL_SOFTMAX"; - break; - case kLiteRtOpCodeTflCast: - out << "TFL_CAST"; - break; - case kLiteRtOpCodeTflTranspose: - out << "TFL_TRANSPOSE"; - break; - case kLiteRtOpCodeTflSin: - out << "TFL_SIN"; - break; - case kLiteRtOpCodeTflCos: - out << "TFL_COS"; - break; - case kLiteRtOpCodeTflSelect: - out << "TFL_SELECT"; - break; - case kLiteRtOpCodeTflSelectV2: - out << "TFL_SELECT_V2"; - break; - case kLiteRtOpCodeTflFullyConnected: - out << "TFL_FULLY_CONNECTED"; - break; - case kLiteRtOpCodeTflEmbeddingLookup: - out << "TFL_EMBEDDING_LOOKUP"; - break; - case kLiteRtOpCodeTflLogicalAnd: - out << "TFL_LOGICAL_AND"; - break; - case kLiteRtOpCodeTflLess: - out << "TFL_LESS"; - break; - case kLiteRtOpCodeTflGreater: - out << "TFL_GREATER"; - break; - case kLiteRtOpCodeTflGelu: - out << "TFL_GELU"; - break; - case kLiteRtOpCodeTflDynamicUpdateSlice: - out << "TFL_DYNAMIC_UPDATE_SLICE"; - break; - case kLiteRtOpCodeTflPack: - out << "TFL_PACK"; - break; - case kLiteRtOpCodeTflQuantize: - out << "TFL_QUANTIZE"; - break; - case kLiteRtOpCodeTflLeakyRelu: - out << "TFL_LEAKY_RELU"; - break; - case kLiteRtOpCodeTflHardSwish: - out << "TFL_HARD_SWISH"; - break; - case kLiteRtOpCodeTflAveragePool2d: - out << "AVERAGE_POOL_2D"; - break; - case kLiteRtOpCodeTflDepthwiseConv2d: - out << "DEPTHWISE_CONV_2D"; - break; - case kLiteRtOpCodeTflSpaceToDepth: - out << "SPACE_TO_DEPTH"; - break; - case kLiteRtOpCodeTflDepthToSpace: - out << "DEPTH_TO_SPACE"; - break; - case kLiteRtOpCodeTflConv2d: - out << "CONV_2D"; - break; - case kLiteRtOpCodeTflResizeBilinear: - out << "RESIZE_BILINEAR"; - break; - case kLiteRtOpCodeTflMinimum: - out << "MINIMUM"; - break; - case kLiteRtOpCodeTflMaximum: - out << "MAXIMUM"; - break; - case kLiteRtOpCodeTflResizeNearestNeighbor: - out << "RESIZE_NEAREST_NEIGHBOR"; - break; - case kLiteRtOpCodeTflRelu: - out << "TFL_RELU"; - break; - case kLiteRtOpCodeTflRelu6: - out << "TFL_RELU6"; - break; - default: - out << "UKNOWN_OP_CODE: " << code; - break; - } -}; - -// Dump details about the given LiteRtElementType to the given stream. 
-void Dump(LiteRtElementType type, std::ostream& out) { - switch (type) { - case kLiteRtElementTypeFloat32: - out << "f32"; - break; - case kLiteRtElementTypeInt32: - out << "i32"; - break; - case kLiteRtElementTypeFloat64: - out << "f64"; - break; - case kLiteRtElementTypeInt64: - out << "i64"; - break; - case kLiteRtElementTypeFloat16: - out << "f16"; - break; - case kLiteRtElementTypeInt16: - out << "i16"; - break; - case kLiteRtElementTypeInt8: - out << "i8"; - break; - case kLiteRtElementTypeUInt8: - out << "ui8"; - break; - case kLiteRtElementTypeBool: - out << "i1"; - break; - default: - out << "UKNNOWN_ELEMENT_TYPE: " << type; - } -} - -void Dump(const LiteRtRankedTensorType& type, std::ostream& out) { - out << "<"; - for (int i = 0; i < type.layout.rank; ++i) { - out << type.layout.dimensions[i] << "x"; - } - Dump(type.element_type, out); - out << ">"; -} - -void Dump(const LiteRtTensorT& tensor, std::ostream& out) { - out << "LiteRtTensor : "; - DumpNode(tensor, out); - out << " [ "; - if (tensor.DefiningOp() == nullptr) { - out << "*"; - } else { - DumpNode(*tensor.DefiningOp(), out); - } - out << " ] "; - - out << "("; - for (auto it = tensor.Users().begin(); it < tensor.Users().end(); ++it) { - DumpNode(**it, out); - if (it != tensor.Users().end() - 1) { - out << ", "; - } - } - out << ")"; - out << "\n"; -} - -void Dump(const LiteRtOpT& op, std::ostream& out) { - out << "LiteRtOp : [ "; - DumpNode(op, out); - out << " ] "; - DumpSignature(op.Inputs(), op.Outputs(), out); - out << "\n"; -} - -void Dump(const LiteRtSubgraphT& subgraph, std::ostream& out) { - constexpr absl::string_view kSubgraphTpl = - "LiteRtSubgraph : [ #ops=%d #tensors=%d ] "; - out << absl::StreamFormat(kSubgraphTpl, subgraph.Ops().size(), - subgraph.Tensors().size()); - DumpSignature(subgraph.Inputs(), subgraph.Outputs(), out); - out << "\n"; -} - -void Dump(const CompilerPlugin& plugin, std::ostream& out) { - constexpr absl::string_view kPluginDumpTpl = - "SocManufacturer: %s\nSocModels: { "; - out << absl::StreamFormat(kPluginDumpTpl, plugin.SocManufacturer()); - - for (auto it = plugin.SocModels().begin(); it < plugin.SocModels().end(); - ++it) { - out << *it; - if (it != plugin.SocModels().end() - 1) { - out << ","; - } - out << " "; - } - - out << "}\n"; -} - -void Dump(const LiteRtModelT& model, std::ostream& out) { - out << absl::StreamFormat("LiteRtModel : [ #subgraphs=%d ]\n", - model.Subgraphs().size()); -} - -void DumpOptions(const LiteRtOpT& op, std::ostream& out) { - auto& opts = litert::internal::GetTflOptions(op); - if (opts.value == nullptr) { - out << "null options\n"; - return; - } - switch (op.OpCode()) { - case kLiteRtOpCodeTflAdd: - out << "fused_activation_function: " - << opts.AsAddOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflMul: - out << "fused_activation_function: " - << opts.AsMulOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflBatchMatmul: - out << "adj_x: " << opts.AsBatchMatMulOptions()->adj_x << "\n"; - out << "adj_y: " << opts.AsBatchMatMulOptions()->adj_y << "\n"; - out << "asymmetric_quantize_input: " - << opts.AsBatchMatMulOptions()->asymmetric_quantize_inputs << "\n"; - break; - case kLiteRtOpCodeTflConcatenation: - out << "axis: " << opts.AsConcatenationOptions()->axis << "\n"; - out << "fused_activation_function: " - << opts.AsConcatenationOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflDiv: - out << "fused_activation_function: " - << 
opts.AsDivOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflFullyConnected: - out << "weights_format: " - << opts.AsFullyConnectedOptions()->weights_format << "\n"; - out << "keep_num_dims: " << opts.AsFullyConnectedOptions()->keep_num_dims - << "\n"; - out << "quantized_bias_type: " - << opts.AsFullyConnectedOptions()->quantized_bias_type << "\n"; - out << "asymmetric_quantize_input: " - << opts.AsFullyConnectedOptions()->asymmetric_quantize_inputs << "\n"; - out << "fused_activation_function: " - << opts.AsFullyConnectedOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflSoftmax: - out << "beta: " << opts.AsSoftmaxOptions()->beta << "\n"; - break; - case kLiteRtOpCodeTflStridedSlice: - out << "begin_mask: " << opts.AsStridedSliceOptions()->begin_mask << "\n"; - out << "end_mask: " << opts.AsStridedSliceOptions()->end_mask << "\n"; - out << "ellipsis_mask: " << opts.AsStridedSliceOptions()->ellipsis_mask - << "\n"; - out << "new_axis_mask: " << opts.AsStridedSliceOptions()->new_axis_mask - << "\n"; - out << "shrink_axis_mask: " - << opts.AsStridedSliceOptions()->shrink_axis_mask << "\n"; - out << "offset: " << opts.AsStridedSliceOptions()->offset << "\n"; - break; - case kLiteRtOpCodeTflSub: - out << "fused_activation_function: " - << opts.AsSubOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflReshape: - out << "new_shape: "; - if (opts.AsReshapeOptions() != nullptr) { - const int32_t* new_shape = opts.AsReshapeOptions()->new_shape.data(); - int32_t new_shape_size = opts.AsReshapeOptions()->new_shape.size(); - for (int i = 0; i < new_shape_size; ++i) { - out << new_shape[i] << " "; - } - } - break; - case kLiteRtOpCodeTflSum: - out << "keepdims: " << opts.AsReducerOptions()->keep_dims << "\n"; - break; - case kLiteRtOpCodeTflPack: - out << "axis: " << opts.AsPackOptions()->axis << "\n"; - break; - default: - out << "No options for op code: " << op.OpCode(); - break; - } -} - -void Dump(Quantization quantization, std::ostream& out) { - int max_display_count; - switch (quantization.first) { - case kLiteRtQuantizationNone: - return; - case kLiteRtQuantizationPerTensor: - out << absl::StreamFormat(" ", - quantization.second.per_tensor.zero_point, - quantization.second.per_tensor.scale); - return; - case kLiteRtQuantizationPerChannel: - max_display_count = - kMaxDisplayCount < quantization.second.per_channel.num_channels - ? kMaxDisplayCount - : quantization.second.per_channel.num_channels; - out << absl::StreamFormat(" ", quantization.second.per_channel.quantized_dimension); - return; - default: - out << " "; - return; - } -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/tools/dump.h b/tensorflow/lite/experimental/litert/tools/dump.h deleted file mode 100644 index 89254ae48e29a6..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/dump.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -// -// LiteRt IR -// - -// Dump details about the given LiteRtOpT to the given stream. -void Dump(const LiteRtOpT& op, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtSubgraphT to the given stream. -void Dump(const LiteRtSubgraphT& subgraph, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtTensorT to the given stream. -void Dump(const LiteRtTensorT& tensor, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtOpCode to the given stream. -void Dump(LiteRtOpCode code, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtElementType to the given stream. -void Dump(LiteRtElementType type, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtRankedTensorType to the given stream. -void Dump(const LiteRtRankedTensorType& type, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtModel to the given stream. -void Dump(const LiteRtModelT& model, std::ostream& out = std::cerr); - -// Dump details about the given quantization params. -void Dump(Quantization quantization, std::ostream& out = std::cerr); - -// Dump details about options -void DumpOptions(const LiteRtOpT& op, std::ostream& out = std::cerr); - -// -// Library Utilities -// - -// Dumps details about the loaded LiteRtCompilerPlugin library. -void Dump(const CompilerPlugin& plugin, std::ostream& out = std::cerr); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ diff --git a/tensorflow/lite/experimental/litert/tools/dump_test.cc b/tensorflow/lite/experimental/litert/tools/dump_test.cc deleted file mode 100644 index ff89547c2350aa..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/dump_test.cc +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
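// As a quick orientation, the Dump() overloads in dump.h above produce
// one-line summaries like the following, taken from the expectations in
// dump_test.cc (one_mul.tflite):
//
//   LiteRtModel : [ #subgraphs=1 ]
//   LiteRtSubgraph : [ #ops=1 #tensors=3 ] (<2x2xf32>, <2x2xf32>) -> <2x2xf32>
//   LiteRtOp : [ TFL_MUL ] (<2x2xf32>, <2x2xf32>) -> <2x2xf32>
//   LiteRtTensor : <2x2xf32> [ * ] (TFL_MUL)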
- -#include "tensorflow/lite/experimental/litert/tools/dump.h" - -#include -#include -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" - -namespace { - -using ::litert::internal::Dump; -using ::litert::internal::DumpOptions; -using ::litert::testing::LoadTestFileModel; - -TEST(DumpTest, TestDump) { - auto model = LoadTestFileModel("one_mul.tflite"); - - { - std::ostringstream model_dump; - Dump(*model.Get(), model_dump); - EXPECT_EQ(model_dump.view(), "LiteRtModel : [ #subgraphs=1 ]\n"); - } - - { - const LiteRtTensorT& in_tensor = model.Get()->Subgraph(0).Input(0); - std::ostringstream in_tensor_dump; - Dump(in_tensor, in_tensor_dump); - EXPECT_EQ(in_tensor_dump.view(), - "LiteRtTensor : <2x2xf32> [ * ] (TFL_MUL)\n"); - } - - { - const LiteRtTensorT& out_tensor = model.Get()->Subgraph(0).Output(0); - std::ostringstream out_tensor_dump; - Dump(out_tensor, out_tensor_dump); - EXPECT_EQ(out_tensor_dump.view(), - "LiteRtTensor : <2x2xf32> [ TFL_MUL ] ()\n"); - } - - { - const LiteRtOpT& op = model.Get()->Subgraph(0).Op(0); - std::ostringstream op_dump; - Dump(op, op_dump); - EXPECT_EQ(op_dump.view(), - "LiteRtOp : [ TFL_MUL ] (<2x2xf32>, <2x2xf32>) -> <2x2xf32>\n"); - } - - { - const LiteRtSubgraphT& subgraph = model.Get()->Subgraph(0); - std::ostringstream subgraph_dump; - Dump(subgraph, subgraph_dump); - EXPECT_EQ( - subgraph_dump.view(), - "LiteRtSubgraph : [ #ops=1 #tensors=3 ] (<2x2xf32>, <2x2xf32>) -> " - "<2x2xf32>\n"); - } -} - -TEST(DumpTest, TestDumpOptions) { - auto model = LoadTestFileModel("simple_strided_slice_op.tflite"); - const LiteRtOpT& op = model.Get()->Subgraph(0).Op(0); - std::ostringstream op_dump; - DumpOptions(op, op_dump); - EXPECT_EQ(op_dump.view(), - "begin_mask: 0\n" - "end_mask: 0\n" - "ellipsis_mask: 0\n" - "new_axis_mask: 0\n" - "shrink_axis_mask: 0\n" - "offset: 0\n"); -} - -TEST(DumpTest, TestDumpPerTensorQuantization) { - QuantizationDetail per_tensor_detail; - per_tensor_detail.per_tensor.scale = 1.0; - per_tensor_detail.per_tensor.zero_point = 2; - std::ostringstream q_dump; - Dump(std::make_pair(kLiteRtQuantizationPerTensor, per_tensor_detail), q_dump); - EXPECT_EQ(q_dump.view(), " "); -} - -TEST(DumpTest, TestDumpPerChannelQuantization) { - static constexpr size_t kRank = 2; - static constexpr size_t kQuantizedDimension = 1; - static constexpr float kScales[kRank] = {1.0, 2.0}; - static constexpr int64_t kZps[kRank] = {2, 3}; - QuantizationDetail per_channel_detail; - per_channel_detail.per_channel.scales = const_cast(kScales); - per_channel_detail.per_channel.zero_points = const_cast(kZps); - per_channel_detail.per_channel.quantized_dimension = kQuantizedDimension; - per_channel_detail.per_channel.num_channels = kRank; - std::ostringstream q_dump; - Dump(std::make_pair(kLiteRtQuantizationPerChannel, per_channel_detail), - q_dump); - EXPECT_FALSE(q_dump.view().empty()); -} - -TEST(DumpTest, TestDumpNoQuantization) { - QuantizationDetail none_detail; - std::ostringstream q_dump; - Dump(std::make_pair(kLiteRtQuantizationNone, none_detail), q_dump); - EXPECT_TRUE(q_dump.view().empty()); -} - -TEST(DumpTest, TestDumpUnknownQuantization) { - QuantizationDetail detail; - std::ostringstream q_dump; - Dump(std::make_pair(kLiteRtQuantizationBlockWise, detail), q_dump); - EXPECT_EQ(q_dump.view(), " "); -} - -} // namespace diff --git 
a/tensorflow/lite/experimental/litert/tools/outstream.h b/tensorflow/lite/experimental/litert/tools/outstream.h deleted file mode 100644 index a920f21839592b..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/outstream.h +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_OUTSTREAM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_OUTSTREAM_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" - -namespace litert::tools { - -using OutStream = std::reference_wrapper; -using OutStreamPtr = std::unique_ptr; - -// Out stream configured by a user by flag. -class UserStream { - public: - // Parse the flag and get a configured stream. - static UserStream MakeFromFlag(absl::string_view flag) { - if (flag == kCerr) { - LITERT_LOG(LITERT_INFO, "Setup cerr stream\n", ""); - return UserStream(std::cerr); - } else if (flag == kCout) { - LITERT_LOG(LITERT_INFO, "Setup cout stream\n", ""); - return UserStream(std::cout); - } else if (flag == kNone) { - LITERT_LOG(LITERT_INFO, "Setup null stream\n", ""); - return UserStream(); - } else { - // File stream. - LITERT_LOG(LITERT_INFO, "Setup file stream\n", ""); - auto ofstream = std::make_unique(); - ofstream->open(flag.data()); - return UserStream(std::move(ofstream)); - } - } - - // Get the actual stream to write to. - OutStream Get() { return used_; } - - // Silent stream. - UserStream() - : stored_(std::make_unique(nullptr)), used_(*stored_) {} - // From reference to external stream (cerr, cout) - explicit UserStream(OutStream ostream) : stored_(nullptr), used_(ostream) {} - // From stream to internalize. - explicit UserStream(OutStreamPtr ostream) - : stored_(std::move(ostream)), used_(*stored_) {} - - UserStream(UserStream&&) = default; - UserStream& operator=(UserStream&&) = default; - - private: - // These are used in the various CLI's flags that configure output streams. - static constexpr absl::string_view kCerr = "--"; - static constexpr absl::string_view kCout = "-"; - static constexpr absl::string_view kNone = "none"; - - OutStreamPtr stored_; - OutStream used_; -}; - -} // namespace litert::tools - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_OUTSTREAM_H_ diff --git a/tensorflow/lite/experimental/litert/tools/run_model.cc b/tensorflow/lite/experimental/litert/tools/run_model.cc deleted file mode 100644 index c360beae0f2a4b..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/run_model.cc +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/profiling/time.h" - -ABSL_FLAG(std::string, graph, "", "Model filename to use for testing."); -ABSL_FLAG(std::string, dispatch_library_dir, "", - "Path to the dispatch library."); -ABSL_FLAG(bool, use_gpu, false, "Use GPU Accelerator."); - -namespace litert { -namespace { - -Expected RunModel() { - if (absl::GetFlag(FLAGS_graph).empty()) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Model filename is empty. Use --graph to provide it."); - } - - ABSL_LOG(INFO) << "Model: " << absl::GetFlag(FLAGS_graph); - LITERT_ASSIGN_OR_RETURN(auto model, - Model::CreateFromFile(absl::GetFlag(FLAGS_graph))); - - const std::string dispatch_library_dir = - absl::GetFlag(FLAGS_dispatch_library_dir); - - std::vector environment_options = {}; - if (!dispatch_library_dir.empty()) { - environment_options.push_back(litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - absl::string_view(dispatch_library_dir)}); - }; - - LITERT_ASSIGN_OR_RETURN( - auto env, - litert::Environment::Create(absl::MakeConstSpan(environment_options))); - - ABSL_LOG(INFO) << "Create CompiledModel"; - auto accelerator = absl::GetFlag(FLAGS_use_gpu) ? 
kLiteRtHwAcceleratorGpu - : kLiteRtHwAcceleratorNone; - if (accelerator == kLiteRtHwAcceleratorGpu) { - ABSL_LOG(INFO) << "Using GPU Accelerator"; - } - LITERT_ASSIGN_OR_RETURN(auto compiled_model, - CompiledModel::Create(env, model, accelerator)); - - LITERT_ASSIGN_OR_RETURN(auto signatures, model.GetSignatures()); - size_t signature_index = 0; - - ABSL_LOG(INFO) << "Prepare input buffers"; - - LITERT_ASSIGN_OR_RETURN(auto input_buffers, - compiled_model.CreateInputBuffers(signature_index)); - - ABSL_LOG(INFO) << "Prepare output buffers"; - - LITERT_ASSIGN_OR_RETURN(auto output_buffers, - compiled_model.CreateOutputBuffers(signature_index)); - - ABSL_LOG(INFO) << "Run model"; - uint64_t start = tflite::profiling::time::NowMicros(); - auto status = - compiled_model.Run(signature_index, input_buffers, output_buffers); - uint64_t end = tflite::profiling::time::NowMicros(); - LITERT_LOG(LITERT_INFO, "Run took %lu microseconds", end - start); - - ABSL_LOG(INFO) << "Model run completed"; - - return status; -} - -} // namespace -} // namespace litert - -int main(int argc, char** argv) { - absl::ParseCommandLine(argc, argv); - - auto res = litert::RunModel(); - if (!res) { - LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str()); - return EXIT_FAILURE; - } - return EXIT_SUCCESS; -} diff --git a/tensorflow/lite/experimental/litert/tools/tool_display.cc b/tensorflow/lite/experimental/litert/tools/tool_display.cc deleted file mode 100644 index 2067d7826adb66..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/tool_display.cc +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/tools/tool_display.h" - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/tools/outstream.h" - -namespace litert::tools { - -std::string ToolDisplay::MakeLabel(absl::string_view tool_label) { - return absl::StrFormat( - "[LITERT_TOOLS%s] ", - tool_label.empty() ? 
tool_label : absl::StrFormat(":%s", tool_label)); -} - -std::ostream& ToolDisplay::Display() { return ostream_.Get(); } - -std::ostream& ToolDisplay::Labeled() { - Display() << label_; - return Display(); -} - -std::ostream& ToolDisplay::Indented() { - Display() << "\t"; - return Display(); -} - -void ToolDisplay::Start(const absl::string_view scope_name) { - static constexpr absl::string_view kStartFmt = "Starting %s...\n"; - Labeled() << absl::StreamFormat(kStartFmt, scope_name); -} - -void ToolDisplay::Done(const absl::string_view scope_name) { - static constexpr absl::string_view kDoneFmt = "%s Done!\n"; - Labeled() << ""; - Indented() << absl::StreamFormat(kDoneFmt, scope_name); -} - -void ToolDisplay::Fail() { - Labeled() << ""; - Indented() << "Failed\n"; -} - -ToolDisplay::LoggedScope ToolDisplay::StartS(absl::string_view scope_name) { - return LoggedScope(*this, scope_name); -} - -void ToolDisplay::LoggedScope::Start() { parent_.Start(scope_name_); } - -void ToolDisplay::LoggedScope::Done() { parent_.Done(scope_name_); } - -ToolDisplay::LoggedScope::~LoggedScope() { Done(); } - -ToolDisplay::LoggedScope::LoggedScope(ToolDisplay& parent, - absl::string_view scope_name) - : parent_(parent), scope_name_(scope_name) { - Start(); -} - -static constexpr absl::string_view kArt = R"( - __ _ __ ____ __ - / / (_/ /____ / __ \/ /_ - / / / / __/ _ \/ /_/ / __/ - / /___/ / /_/ __/ _, _/ /_ -/_____/_/\__/\___/_/ |_|\__/ -)"; - -void DumpPreamble(ToolDisplay& display) { display.Display() << kArt << "\n"; } - -} // namespace litert::tools diff --git a/tensorflow/lite/experimental/litert/tools/tool_display.h b/tensorflow/lite/experimental/litert/tools/tool_display.h deleted file mode 100644 index 583d07ee3480f6..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/tool_display.h +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_ - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/tools/outstream.h" - -namespace litert::tools { - -// Utility class for interactive logging for usage in command line tools only. -// Allows user to explicitly set target stream. -class ToolDisplay { - public: - using Ptr = std::unique_ptr; - // Construct configured ToolDisplay. Label is used for prefixing dumps - // in "LabeledStream". 
-  explicit ToolDisplay(UserStream&& ostream, absl::string_view tool_label = "")
-      : label_(MakeLabel(tool_label)),
-        ostream_(std::forward<UserStream>(ostream)) {}
-  explicit ToolDisplay(OutStream ostream, absl::string_view tool_label = "")
-      : label_(MakeLabel(tool_label)), ostream_(UserStream(ostream)) {}
-
-  ToolDisplay(const ToolDisplay&) = delete;
-  ToolDisplay& operator=(const ToolDisplay&) = delete;
-  ToolDisplay(ToolDisplay&&) = delete;
-  ToolDisplay& operator=(ToolDisplay&&) = delete;
-
-  // Get out stream.
-  std::ostream& Display();
-
-  // Get Display with label prefix.
-  std::ostream& Labeled();
-
-  // Get Display with indent.
-  std::ostream& Indented();
-
-  // Log string indicating a subroutine is beginning.
-  void Start(absl::string_view scope_name);
-
-  // Log string indicating a subroutine is done and succeeded.
-  void Done(absl::string_view scope_name = "");
-
-  // Log string indicating a subroutine is done and failed.
-  void Fail();
-
-  // Logs "start/finish" messages automatically.
-  class LoggedScope {
-    friend class ToolDisplay;
-
-   public:
-    LoggedScope(const LoggedScope&) = delete;
-    LoggedScope& operator=(const LoggedScope&) = delete;
-    LoggedScope(LoggedScope&&) = delete;
-    LoggedScope& operator=(LoggedScope&&) = delete;
-
-    ~LoggedScope();
-
-   private:
-    explicit LoggedScope(ToolDisplay& parent, absl::string_view scope_name);
-
-    void Start();
-    void Done();
-
-    ToolDisplay& parent_;
-    // These should all be from literals.
-    absl::string_view scope_name_;
-  };
-
-  // Get object that prints a start message and an exit message
-  // automatically when it goes out of scope.
-  [[maybe_unused]] LoggedScope StartS(absl::string_view scope_name);
-
- private:
-  static std::string MakeLabel(absl::string_view tool_label);
-  std::string label_;
-  UserStream ostream_;
-};
-
-// Print art and info at cli startup.
-void DumpPreamble(ToolDisplay& display);
-
-}  // namespace litert::tools
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_
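// [Editorial sketch; not part of the patch.] A minimal, hypothetical use of
// the ToolDisplay interface above. It assumes only the declarations in
// tool_display.h and that std::cerr is accepted as an OutStream; the tool
// name and scope label are placeholders.
#include <iostream>

void ExampleToolMain() {
  litert::tools::ToolDisplay display(std::cerr, "example-tool");
  litert::tools::DumpPreamble(display);  // Prints the startup ASCII art.
  display.Labeled() << "some info\n";  // "[LITERT_TOOLS:example-tool] some info"
  {
    // Prints "Starting Compilation..." now and "Compilation Done!" when the
    // scope object is destroyed, via the RAII LoggedScope.
    auto scope = display.StartS("Compilation");
    // ... do work ...
  }
}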
diff --git a/tensorflow/lite/experimental/litert/tools/tool_display_test.cc b/tensorflow/lite/experimental/litert/tools/tool_display_test.cc
deleted file mode 100644
index 94027f663c301c..00000000000000
--- a/tensorflow/lite/experimental/litert/tools/tool_display_test.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/tools/tool_display.h"
-
-#include <sstream>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-
-namespace {
-
-using ::litert::tools::ToolDisplay;
-using ::testing::EndsWith;
-using ::testing::StartsWith;
-
-static constexpr absl::string_view kToolName = "test-tool";
-static constexpr absl::string_view kLabel = "[LITERT_TOOLS:test-tool]";
-static constexpr absl::string_view kStartLabel = "Test Routine";
-static constexpr absl::string_view kDisplayInfo = "info";
-
-TEST(TestToolDisplay, Display) {
-  std::stringstream out;
-  ToolDisplay display(out, kToolName);
-  display.Display() << kDisplayInfo;
-  EXPECT_EQ(out.view(), kDisplayInfo);
-}
-
-TEST(TestToolDisplay, Indented) {
-  std::stringstream out;
-  ToolDisplay display(out, kToolName);
-  display.Indented() << kDisplayInfo;
-  EXPECT_EQ(out.view(), absl::StrFormat("\t%s", kDisplayInfo));
-}
-
-TEST(TestToolDisplay, Labeled) {
-  std::stringstream out;
-  ToolDisplay display(out, kToolName);
-  display.Labeled() << kDisplayInfo;
-  EXPECT_EQ(out.view(), absl::StrFormat("%s %s", kLabel, kDisplayInfo));
-}
-
-TEST(TestToolDisplay, LabeledNoToolName) {
-  std::stringstream out;
-  ToolDisplay display(out);
-  display.Labeled() << kDisplayInfo;
-  EXPECT_EQ(out.view(),
-            absl::StrFormat("%s %s", "[LITERT_TOOLS]", kDisplayInfo));
-}
-
-TEST(TestToolDisplay, Start) {
-  std::stringstream out;
-  ToolDisplay display(out, kToolName);
-  display.Start(kStartLabel);
-  EXPECT_EQ(out.view(),
-            absl::StrFormat("%s Starting %s...\n", kLabel, kStartLabel));
-}
-
-TEST(TestToolDisplay, Done) {
-  std::stringstream out;
-  ToolDisplay display(out, kToolName);
-  display.Done(kStartLabel);
-  EXPECT_EQ(out.view(),
-            absl::StrFormat("%s \t%s Done!\n", kLabel, kStartLabel));
-}
-
-TEST(TestToolDisplay, Fail) {
-  std::stringstream out;
-  ToolDisplay display(out, kToolName);
-  display.Fail();
-  EXPECT_EQ(out.view(), absl::StrFormat("%s \tFailed\n", kLabel));
-}
-
-TEST(TestLoggedScope, EnterExit) {
-  std::stringstream out;
-  ToolDisplay display(out, kToolName);
-  {
-    auto s = display.StartS(kStartLabel);
-  }
-  EXPECT_THAT(out.view(), StartsWith(absl::StrFormat("%s Starting %s...\n",
-                                                     kLabel, kStartLabel)));
-  EXPECT_THAT(out.view(), EndsWith(absl::StrFormat("%s \t%s Done!\n", kLabel,
-                                                   kStartLabel)));
-}
-
-}  // namespace
diff --git a/tensorflow/lite/experimental/litert/vendors/c/BUILD b/tensorflow/lite/experimental/litert/vendors/c/BUILD
deleted file mode 100644
index 0692c1f0cd4a11..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/c/BUILD
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright 2024 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"],
-)
-
-cc_library(
-    name = "litert_compiler_plugin",
-    hdrs = ["litert_compiler_plugin.h"],
-    deps = [
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_model",
-    ],
-)
-
-cc_library(
-    name = "litert_compiler_plugin_api",
-    hdrs = ["litert_compiler_plugin_api.h"],
-    deps = [
-        ":litert_compiler_plugin",
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_model",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
-
-# This library is used to build the C API header files for the vendor dispatch API.
-# All the vendor dispatch .so targets should depend on this library.
-cc_library(
-    name = "litert_dispatch_c_api",
-    hdrs = [
-        "litert_dispatch.h",
-        "litert_dispatch_api.h",
-    ],
-    deps = [
-        # only depend on the headers, not the implementation.
-        "//tensorflow/lite/experimental/litert/c:litert_dispatch_headers",
-    ],
-)
-
-# This test verifies that the C API header files can build via C compiler.
-cc_test(
-    name = "litert_vendor_c_api_common_test",
-    srcs = ["litert_vendor_c_api_common_test.c"],
-    copts = ["--std=c11"],
-    linkopts = ["-ldl"],
-    deps = [
-        ":litert_compiler_plugin",
-        ":litert_compiler_plugin_api",
-        ":litert_dispatch_c_api",
-    ],
-)
-
-exports_files(srcs = glob(["litert_*.h"]))
diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h
deleted file mode 100644
index 926c4f98d469c0..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h
+++ /dev/null
@@ -1,118 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_
-
-#include <stddef.h>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-LITERT_DEFINE_HANDLE(LiteRtCompilerPlugin);
-
-// Artifact produced from compiling a selected partition of ops.
-LITERT_DEFINE_HANDLE(LiteRtCompiledResult);
-
-//
-// Plugin
-//
-
-LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version);
-
-// Name associated with the manufacturer this plugin relates to (e.g.,
-// GoogleTensor, Qualcomm).
-const char* LiteRtGetCompilerPluginSocManufacturer();
-
-LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin);
-
-void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin);
-
-// Return the HW supported by this plugin (e.g., GPU, NPU)
-LiteRtStatus LiteRtGetCompilerPluginSupportedHardware(
-    LiteRtCompilerPlugin compiler_plugin,
-    LiteRtHwAccelerators* supported_hardware);
-
-// Number of SoC models supported by this plugin.
-LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels(
-    LiteRtCompilerPlugin compiler_plugin,
-    LiteRtParamIndex* num_supported_soc_models);
-
-// Gets the name of the SoC model at the given index. The memory
-// associated with the returned name is owned by the plugin.
-LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel(
-    LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx,
-    const char** soc_model_name);
-
-// Select desired ops for compilation. This will only be called once
-// per subgraph; plugins should select all supportable ops.
-LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin,
-                                           const char* soc_model,
-                                           LiteRtSubgraph subgraph,
-                                           LiteRtOpList selected_ops);
-
-// Prepare result to pass to the runtime for a given model containing
-// partitioned subgraphs. Optionally, handles a SoC model (parameter
-// `soc_model` can be NULL to specify a default SoC model).
-LiteRtStatus LiteRtCompilerPluginCompile(LiteRtCompilerPlugin compiler_plugin,
                                         const char* soc_model,
-                                         LiteRtModel partitions,
-                                         LiteRtCompiledResult* compiled_result);
-
-// Set any flags for the compiler to use during compilation. Flag data may be
-// released or reused after this function returns. Flags are string key ->
-// optional string value pairs. A non-existent value is represented by an empty
-// string. Calling this function will unset any previously set flags.
-LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin,
-                                          LiteRtParamIndex num_flags,
-                                          const char** keys,
-                                          const char** values);
-
-//
-// Compiled Partition
-//
-
-void LiteRtDestroyCompiledResult(LiteRtCompiledResult result);
-
-// Get the buffer for the compiled byte code for the given index.
-LiteRtStatus LiteRtGetCompiledResultByteCode(
-    LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx,
-    const void** byte_code, size_t* byte_code_size);
-
-// The number of individual byte code modules.
-LiteRtStatus LiteRtCompiledResultNumByteCodeModules(
-    LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_byte_code);
-
-// Get per-op info related to a particular compiled partition as well as the
-// index of the respective byte code buffer.
-LiteRtStatus LiteRtGetCompiledResultCallInfo(
-    LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx,
-    const void** call_info, size_t* call_info_size,
-    LiteRtParamIndex* byte_code_idx);
-
-// Get the number of calls that will be made to the HAL for this graph.
-// This should equal the number of partitions given for compilation, which
-// is equal to the number of custom ops in the final model.
-LiteRtStatus LiteRtGetNumCompiledResultCalls(
-    LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls);
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_
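// [Editorial sketch; not part of the patch.] One plausible way a host drives
// the plugin API above; `subgraph`, `partitions`, and `selected_ops` are
// placeholders and all status checks are elided.
void ExamplePluginFlow(LiteRtSubgraph subgraph, LiteRtModel partitions,
                       LiteRtOpList selected_ops) {
  LiteRtCompilerPlugin plugin;
  LiteRtCreateCompilerPlugin(&plugin);

  // Ask the plugin to select the ops it supports, then compile the resulting
  // partitions (a NULL soc_model selects the plugin's default).
  LiteRtCompilerPluginPartition(plugin, /*soc_model=*/NULL, subgraph,
                                selected_ops);
  LiteRtCompiledResult result;
  LiteRtCompilerPluginCompile(plugin, /*soc_model=*/NULL, partitions, &result);

  LiteRtParamIndex num_calls;
  LiteRtGetNumCompiledResultCalls(result, &num_calls);

  LiteRtDestroyCompiledResult(result);
  LiteRtDestroyCompilerPlugin(plugin);
}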
diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h
deleted file mode 100644
index 8555933e6a9890..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h
+++ /dev/null
@@ -1,156 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_
-
-#include <stddef.h>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h"
-
-// Wrapper for dynamically loaded LiteRtCompilerPlugin library. See
-// "litert_compiler_plugin.h".
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-//
-// Api Interface
-//
-
-typedef LiteRtStatus (*LiteRtGetCompilerPluginVersionT)(LiteRtApiVersion*);
-
-typedef const char* (*LiteRtGetCompilerPluginSocManufacturerT)();
-
-typedef LiteRtStatus (*LiteRtCreateCompilerPluginT)(LiteRtCompilerPlugin*);
-
-typedef void (*LiteRtDestroyCompilerPluginT)(LiteRtCompilerPlugin);
-
-typedef LiteRtStatus (*LiteRtGetCompilerPluginSupportedHardwareT)(
-    LiteRtCompilerPlugin, LiteRtHwAccelerators*);
-
-typedef LiteRtStatus (*LiteRtGetNumCompilerPluginSupportedSocModelsT)(
-    LiteRtCompilerPlugin, LiteRtParamIndex*);
-
-typedef LiteRtStatus (*LiteRtGetCompilerPluginSupportedSocModelT)(
-    LiteRtCompilerPlugin, LiteRtParamIndex soc_model_idx,
-    const char** soc_model_name);
-
-typedef LiteRtStatus (*LiteRtCompilerPluginPartitionT)(
-    LiteRtCompilerPlugin, const char* soc_model, LiteRtSubgraph subgraph,
-    LiteRtOpList selected_ops);
-
-typedef LiteRtStatus (*LiteRtCompilerPluginCompileT)(
-    LiteRtCompilerPlugin, const char* soc_model, LiteRtModel partitions,
-    LiteRtCompiledResult* compiled_result);
-
-typedef void (*LiteRtDestroyCompiledResultT)(LiteRtCompiledResult);
-
-typedef LiteRtStatus (*LiteRtGetCompiledResultByteCodeT)(
-    LiteRtCompiledResult, LiteRtParamIndex byte_code_idx,
-    const void** byte_code, size_t* byte_code_size);
-
-typedef LiteRtStatus (*LiteRtCompiledResultNumByteCodeModulesT)(
-    LiteRtCompiledResult, LiteRtParamIndex* num_byte_code);
-
-typedef LiteRtStatus (*LiteRtGetCompiledResultCallInfoT)(
-    LiteRtCompiledResult, LiteRtParamIndex call_idx, const void** call_info,
-    size_t* call_info_size, LiteRtParamIndex* byte_code_idx);
-
-typedef LiteRtStatus (*LiteRtGetNumCompiledResultCallsT)(
-    LiteRtCompiledResult, LiteRtParamIndex* num_calls);
-
-typedef LiteRtStatus (*LiteRtCompilerPluginSetFlagsT)(
-    LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex num_flags,
-    const char** keys, const char** values);
-
-//
-// Function Pointer Container
-//
-
-// Wraps all resolved functions from api interface.
-struct LiteRtCompilerPluginApi {
-  LiteRtGetCompilerPluginVersionT get_compiler_plugin_version;
-  LiteRtGetCompilerPluginSocManufacturerT get_compiler_plugin_soc_manufacturer;
-  LiteRtCreateCompilerPluginT create_compiler_plugin;
-  LiteRtDestroyCompilerPluginT destroy_compiler_plugin;
-
-  LiteRtGetCompilerPluginSupportedHardwareT
-      get_compiler_plugin_supported_hardware;
-  LiteRtGetNumCompilerPluginSupportedSocModelsT
-      get_num_compiler_plugin_supported_models;
-  LiteRtGetCompilerPluginSupportedSocModelT
-      get_compiler_plugin_supported_soc_model;
-
-  LiteRtCompilerPluginPartitionT compiler_plugin_partition;
-  LiteRtCompilerPluginCompileT compiler_plugin_compile;
-
-  LiteRtDestroyCompiledResultT destroy_compiled_result;
-  LiteRtGetCompiledResultByteCodeT get_compiled_result_byte_code;
-  LiteRtCompiledResultNumByteCodeModulesT get_compiled_result_num_byte_code;
-  LiteRtGetCompiledResultCallInfoT get_compiled_result_call_info;
-  LiteRtGetNumCompiledResultCallsT get_compiled_result_num_calls;
-
-  LiteRtCompilerPluginSetFlagsT set_flags;
-};
-
-#ifdef __cplusplus
-}
-
-#include "absl/strings/string_view.h"
-
-static constexpr absl::string_view kLiteRtGetCompilerPluginVersion =
-    "LiteRtGetCompilerPluginVersion";
-
-static constexpr absl::string_view kLiteRtGetCompilerPluginSupportedHardware =
-    "LiteRtGetCompilerPluginSupportedHardware";
-
-static constexpr absl::string_view kLiteRtGetCompilerPluginSocManufacturer =
-    "LiteRtGetCompilerPluginSocManufacturer";
-static constexpr absl::string_view
-    kLiteRtGetNumCompilerPluginSupportedSocModels =
-        "LiteRtGetNumCompilerPluginSupportedSocModels";
-static constexpr absl::string_view kLiteRtGetCompilerPluginSupportedSocModel =
-    "LiteRtGetCompilerPluginSupportedSocModel";
-
-static constexpr absl::string_view kLiteRtCreateCompilerPlugin =
-    "LiteRtCreateCompilerPlugin";
-static constexpr absl::string_view kLiteRtDestroyCompilerPlugin =
-    "LiteRtDestroyCompilerPlugin";
-
-static constexpr absl::string_view kLiteRtCompilerPluginPartition =
-    "LiteRtCompilerPluginPartition";
-static constexpr absl::string_view kLiteRtCompilerPluginCompile =
-    "LiteRtCompilerPluginCompile";
-
-static constexpr absl::string_view kLiteRtDestroyCompiledResult =
-    "LiteRtDestroyCompiledResult";
-static constexpr absl::string_view kLiteRtGetCompiledResultByteCode =
-    "LiteRtGetCompiledResultByteCode";
-static constexpr absl::string_view kLiteRtCompiledResultNumByteCodeModules =
-    "LiteRtCompiledResultNumByteCodeModules";
-static constexpr absl::string_view kLiteRtGetCompiledResultCallInfo =
-    "LiteRtGetCompiledResultCallInfo";
-static constexpr absl::string_view kLiteRtGetNumCompiledResultCalls =
-    "LiteRtGetNumCompiledResultCalls";
-
-static constexpr absl::string_view kLiteRtCompilerPluginSetFlags =
-    "LiteRtCompilerPluginSetFlags";
-
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_
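// [Editorial sketch; not part of the patch.] The string constants above let a
// host resolve the plugin entry points from a vendor shared library. A
// hypothetical resolver using dlsym (remaining symbols and error handling
// elided):
#include <dlfcn.h>

bool ExampleResolvePluginApi(void* lib_handle, LiteRtCompilerPluginApi& api) {
  api.create_compiler_plugin = reinterpret_cast<LiteRtCreateCompilerPluginT>(
      ::dlsym(lib_handle, kLiteRtCreateCompilerPlugin.data()));
  api.compiler_plugin_compile = reinterpret_cast<LiteRtCompilerPluginCompileT>(
      ::dlsym(lib_handle, kLiteRtCompilerPluginCompile.data()));
  // ... resolve the other members the same way ...
  return api.create_compiler_plugin != nullptr &&
         api.compiler_plugin_compile != nullptr;
}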
diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h
deleted file mode 100644
index 7487daf9c9ae22..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h
+++ /dev/null
@@ -1,309 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_
-
-#include <stdbool.h>
-#include <stddef.h>
-#include <stdint.h>
-
-#include "tensorflow/lite/experimental/litert/c/litert_any.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_event.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-// /////////////////////////////////////////////////////////////////////////////
-// Basic Execution API
-// /////////////////////////////////////////////////////////////////////////////
-
-LITERT_DEFINE_HANDLE(LiteRtDispatchDeviceContext);
-LITERT_DEFINE_HANDLE(LiteRtDispatchInvocationContext);
-LITERT_DEFINE_HANDLE(LiteRtDispatchMetrics);
-
-typedef uint64_t LiteRtTensorBufferHandle;
-
-typedef enum LiteRtDispatchCapabilities {
-  kLiteRtDispatchCapabilitiesNone = 0,
-  kLiteRtDispatchCapabilitiesBasic = 1,  // The vendor supports the Basic API
-  kLiteRtDispatchCapabilitiesAsync = 2,  // The vendor supports the Async API
-  kLiteRtDispatchCapabilitiesGraph = 4,  // The vendor supports the Graph API
-} LiteRtDispatchCapabilities;
-
-// Types of executable that can run on the HW accelerators.
-typedef enum LiteRtDispatchExecutableType {
-  kLiteRtDispatchExecutableTypeUnknown = 0,
-  kLiteRtDispatchExecutableTypeDspLibrary = 1,  // DSP library
-  kLiteRtDispatchExecutableTypeMlModel = 2,     // Vendor-specific ML model
-} LiteRtDispatchExecutableType;
-
-typedef struct LiteRtDispatchOption {
-  const char* name;
-  LiteRtAny value;
-} LiteRtDispatchOption;
-
-typedef struct LiteRtMetric {
-  const char* name;
-  LiteRtAny value;
-} LiteRtMetric;
-
-typedef struct LiteRtMemBuffer {
-  int fd;  // File descriptor for an mmapped buffer, -1 if unused.
-  const void* base_addr;  // Base address of the buffer.
-  size_t offset;  // Offset of the buffer from the base address.
-  size_t size;  // Buffer size.
-} LiteRtMemBuffer;
-
-// This option can be used to specify a directory from where to load shared
-// libraries.
-static const char* kDispatchOptionSharedLibraryDir = "shared_library_dir";
-
-// Initialize the Dispatch API runtime.
-//
-// This function should be called before calling any other Dispatch API
-// functions.
-LiteRtStatus LiteRtDispatchInitialize(const LiteRtDispatchOption* options,
-                                      int num_options);
-
-// Return the version of the Dispatch API runtime.
-LiteRtStatus LiteRtDispatchGetApiVersion(LiteRtApiVersion* api_version);
-
-// Return the vendor id of the Dispatch API runtime.
-//
-// This function returns a pointer to a statically allocated string that is the
-// ID of the vendor providing the Dispatch API runtime.
-LiteRtStatus LiteRtDispatchGetVendorId(const char** vendor_id);
-
-// Return the build ID of the Dispatch API runtime.
-//
-// This function returns a pointer to a statically allocated string that is the
-// ID of the Dispatch API runtime build.
-LiteRtStatus LiteRtDispatchGetBuildId(const char** build_id);
-
-// Return the capabilities supported by the Dispatch API runtime as a set of
-// the values specified in LiteRtDispatchCapabilities.
-LiteRtStatus LiteRtDispatchGetCapabilities(int* capabilities);
-
-// Create a `LiteRtDispatchDeviceContext` object.
-//
-// The returned object is used to talk with the underlying HW. The caller owns
-// the memory associated with the context and should call
-// LiteRtDispatchDeviceContextDestroy() to release it. Returns an error status
-// on failure.
-LiteRtStatus LiteRtDispatchDeviceContextCreate(
-    LiteRtDispatchDeviceContext* device_context);
-
-// Release a `LiteRtDispatchDeviceContext` object.
-//
-// The given context should be released only after releasing all associated
-// objects.
-LiteRtStatus LiteRtDispatchDeviceContextDestroy(
-    LiteRtDispatchDeviceContext device_context);
-
-// Given a tensor type for an invocation context input, obtain the attributes
-// the HW requires for the associated tensor buffer. The returned
-// `tensor_buffer_requirements` object is owned by the caller.
-LiteRtStatus LiteRtDispatchGetInputRequirements(
-    LiteRtDispatchInvocationContext invocation_context, int input_index,
-    const LiteRtRankedTensorType* tensor_type,
-    LiteRtTensorBufferRequirements* tensor_buffer_requirements);
-
-// Given a tensor type for an invocation context output, obtain the attributes
-// the HW requires for the associated tensor buffer. The returned
-// `tensor_buffer_requirements` object is owned by the caller.
-LiteRtStatus LiteRtDispatchGetOutputRequirements(
-    LiteRtDispatchInvocationContext invocation_context, int output_index,
-    const LiteRtRankedTensorType* tensor_type,
-    LiteRtTensorBufferRequirements* tensor_buffer_requirements);
-
-// Registers a buffer with the given device context.
-// Note: The memory backing the buffer should be valid until
-// `LiteRtDispatchUnregisterTensorBuffer` is called.
-LiteRtStatus LiteRtDispatchRegisterTensorBuffer(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtTensorBuffer tensor_buffer,
-    LiteRtTensorBufferHandle* tensor_buffer_handle);
-
-// Unregisters the registered buffer associated with the given
-// `LiteRtTensorBufferHandle`.
-// Note: The registered `LiteRtTensorBufferHandle` should be unregistered with
-// this function before the associated device context is deleted by calling
-// `LiteRtDispatchDeviceContextDestroy`.
-LiteRtStatus LiteRtDispatchUnregisterTensorBuffer(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtTensorBufferHandle tensor_buffer_handle);
-
-// Create an invocation context to run a given function from a given
-// executable. Parameter `function_name` is required if the provided executable
-// includes multiple functions.
-LiteRtStatus LiteRtDispatchInvocationContextCreate(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtDispatchExecutableType exec_type,
-    const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name,
-    int num_inputs, int num_outputs,
-    LiteRtDispatchInvocationContext* invocation_context);
-
-LiteRtStatus LiteRtDispatchInvocationContextDestroy(
-    LiteRtDispatchInvocationContext invocation_context);
-
-LiteRtStatus LiteRtDispatchAttachInput(
-    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
-    LiteRtTensorBufferHandle tensor_buffer_handle);
-
-LiteRtStatus LiteRtDispatchAttachOutput(
-    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
-    LiteRtTensorBufferHandle tensor_buffer_handle);
-
-LiteRtStatus LiteRtDispatchDetachInput(
-    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
-    LiteRtTensorBufferHandle tensor_buffer_handle);
-
-LiteRtStatus LiteRtDispatchDetachOutput(
-    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
-    LiteRtTensorBufferHandle tensor_buffer_handle);
-
-LiteRtStatus LiteRtDispatchInvoke(
-    LiteRtDispatchInvocationContext invocation_context);
-
-// Start collection of HW-specific metrics at a specific level of detail
-// (>= 0).
-LiteRtStatus LiteRtDispatchStartMetricsCollection(
-    LiteRtDispatchInvocationContext invocation_context, int detail_level);
-
-// Stop collection of HW-specific metrics and report the collected
-// metrics. Note: The caller is responsible for deallocating the returned
-// metrics by calling `LiteRtDispatchDestroyMetrics`.
-LiteRtStatus LiteRtDispatchStopMetricsCollection(
-    LiteRtDispatchInvocationContext invocation_context,
-    LiteRtDispatchMetrics* metrics);
-
-LiteRtStatus LiteRtDispatchGetNumMetrics(LiteRtDispatchMetrics metrics,
-                                         int* num_metrics);
-
-// Fetch a specific metric. The runtime owns the returned object.
-LiteRtStatus LiteRtDispatchGetMetric(LiteRtDispatchMetrics metrics,
-                                     int metric_index, LiteRtMetric* metric);
-
-LiteRtStatus LiteRtDispatchDestroyMetrics(LiteRtDispatchMetrics metrics);
-
-// /////////////////////////////////////////////////////////////////////////////
-// Async Execution API
-// /////////////////////////////////////////////////////////////////////////////
-
-LiteRtStatus LiteRtDispatchAttachInputEvent(
-    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
-    LiteRtEvent input_event);
-
-LiteRtStatus LiteRtDispatchInvokeAsync(
-    LiteRtDispatchInvocationContext invocation_context, int num_output_events,
-    LiteRtEvent* output_events);
-
-// /////////////////////////////////////////////////////////////////////////////
-// Graph Execution API
-// /////////////////////////////////////////////////////////////////////////////
-
-typedef uint64_t LiteRtDispatchNodeId;
-typedef uint64_t LiteRtDispatchEdgeId;
-typedef uint64_t LiteRtDispatchExecutableHandle;
-
-LITERT_DEFINE_HANDLE(LiteRtDispatchGraph);
-
-// Types of graph nodes.
-typedef enum LiteRtDispatchNodeType {
-  kLiteRtDispatchNodeTypeUnknown = 0,
-  kLiteRtDispatchNodeTypeDsp =
-      1,  // Can execute both ML models and Dsp libraries
-  kLiteRtDispatchNodeTypeNpu = 2,  // Can execute only ML models
-} LiteRtDispatchNodeType;
-
-LiteRtStatus LiteRtDispatchGraphCreate(
-    LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph** graph);
-
-LiteRtStatus LiteRtDispatchGraphDestroy(LiteRtDispatchGraph* graph);
-
-// Add a compute node to a given graph. Parameter node_id should be unique to
-// the graph.
-LiteRtStatus LiteRtDispatchAddNode(LiteRtDispatchGraph* graph,
-                                   LiteRtDispatchNodeId node_id,
-                                   LiteRtDispatchNodeType node_type);
-
-// Add an edge to a given graph. Parameter edge_id should be unique to the
-// graph.
-LiteRtStatus LiteRtDispatchAddEdge(LiteRtDispatchGraph* graph,
-                                   LiteRtDispatchEdgeId edge_id);
-
-// Connect a given node's input.
-LiteRtStatus LiteRtDispatchConnectNodeInput(LiteRtDispatchGraph* graph,
-                                            LiteRtDispatchNodeId node_id,
-                                            int input_index,
-                                            LiteRtDispatchEdgeId edge_id);
-
-// Connect a given node's output.
-LiteRtStatus LiteRtDispatchConnectNodeOutput(LiteRtDispatchGraph* graph,
-                                             LiteRtDispatchNodeId node_id,
-                                             int output_index,
-                                             LiteRtDispatchEdgeId edge_id);
-
-// Connect a given graph's input.
-LiteRtStatus LiteRtDispatchConnectGraphInput(LiteRtDispatchGraph* graph,
-                                             int input_index,
-                                             LiteRtDispatchEdgeId edge_id);
-
-// Connect a given graph's output.
-LiteRtStatus LiteRtDispatchConnectGraphOutput(LiteRtDispatchGraph* graph,
-                                              int output_index,
-                                              LiteRtDispatchEdgeId edge_id);
-
-LiteRtStatus LiteRtDispatchLoadExecutable(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtDispatchExecutableType type, const LiteRtMemBuffer* bytecode_buffer,
-    LiteRtDispatchExecutableHandle* exec_handle);
-
-LiteRtStatus LiteRtDispatchUnloadExecutable(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtDispatchExecutableHandle exec_handle);
-
-// Assign an executable function to a graph node. Parameter `function_name` is
-// mandatory if the given executable includes multiple functions.
-LiteRtStatus LiteRtDispatchAssignNodeFunction(
-    LiteRtDispatchGraph* graph, LiteRtDispatchNodeId node_id,
-    LiteRtDispatchExecutableHandle exec_handle, const char* function_name);
-
-// Add an annotation to an entire graph.
-LiteRtStatus LiteRtDispatchAnnotateGraph(LiteRtDispatchGraph* graph,
-                                         const char* key, const char* value);
-
-// Add an annotation to a specified node.
-LiteRtStatus LiteRtDispatchAnnotateNode(LiteRtDispatchGraph* graph,
-                                        LiteRtDispatchNodeId node_id,
-                                        const char* key, const char* value);
-
-// Add an annotation to a specified edge.
-LiteRtStatus LiteRtDispatchAnnotateEdge(LiteRtDispatchGraph* graph,
-                                        LiteRtDispatchEdgeId edge_id,
-                                        const char* key, const char* value);
-
-LiteRtStatus LiteRtDispatchInvocationContextCreateFromGraph(
-    LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph,
-    LiteRtDispatchInvocationContext* invocation_context);
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_
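// [Editorial sketch; not part of the patch.] The typical call order for the
// Basic Execution API declared above; `bytecode`, `bytecode_size`, and
// `input_buffer` are placeholders and all status checks are elided.
void ExampleDispatchRun(const void* bytecode, size_t bytecode_size,
                        LiteRtTensorBuffer input_buffer) {
  LiteRtDispatchInitialize(/*options=*/NULL, /*num_options=*/0);

  LiteRtDispatchDeviceContext device_context;
  LiteRtDispatchDeviceContextCreate(&device_context);

  // Wrap the executable bytecode in a LiteRtMemBuffer (not fd-backed here).
  LiteRtMemBuffer exec_buffer = {/*fd=*/-1, bytecode, /*offset=*/0,
                                 bytecode_size};
  LiteRtDispatchInvocationContext invocation_context;
  LiteRtDispatchInvocationContextCreate(
      device_context, kLiteRtDispatchExecutableTypeMlModel, &exec_buffer,
      /*function_name=*/NULL, /*num_inputs=*/1, /*num_outputs=*/1,
      &invocation_context);

  LiteRtTensorBufferHandle input_handle;
  LiteRtDispatchRegisterTensorBuffer(device_context, input_buffer,
                                     &input_handle);
  LiteRtDispatchAttachInput(invocation_context, /*graph_input_index=*/0,
                            input_handle);
  // ... register and attach output buffers the same way ...

  LiteRtDispatchInvoke(invocation_context);

  // Tear down in reverse order of creation.
  LiteRtDispatchDetachInput(invocation_context, /*graph_input_index=*/0,
                            input_handle);
  LiteRtDispatchUnregisterTensorBuffer(device_context, input_handle);
  LiteRtDispatchInvocationContextDestroy(invocation_context);
  LiteRtDispatchDeviceContextDestroy(device_context);
}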
diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h
deleted file mode 100644
index 527a19c2630e09..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h
+++ /dev/null
@@ -1,245 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_
-
-#include <stddef.h>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_event.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-// /////////////////////////////////////////////////////////////////////////////
-
-typedef LiteRtStatus (*LiteRtDispatchInitializeT)(
-    const LiteRtDispatchOption* options, int num_options);
-
-typedef LiteRtStatus (*LiteRtDispatchGetVendorIdT)(const char** vendor_id);
-
-typedef LiteRtStatus (*LiteRtDispatchGetBuildIdT)(const char** build_id);
-
-typedef LiteRtStatus (*LiteRtDispatchGetCapabilitiesT)(int* capabilities);
-
-typedef LiteRtStatus (*LiteRtDispatchDeviceContextCreateT)(
-    LiteRtDispatchDeviceContext* device_context);
-
-typedef LiteRtStatus (*LiteRtDispatchDeviceContextDestroyT)(
-    LiteRtDispatchDeviceContext device_context);
-
-typedef LiteRtStatus (*LiteRtDispatchGetInputRequirementsT)(
-    LiteRtDispatchInvocationContext invocation_context, int input_index,
-    const LiteRtRankedTensorType* tensor_type,
-    LiteRtTensorBufferRequirements* tensor_buffer_requirements);
-
-typedef LiteRtStatus (*LiteRtDispatchGetOutputRequirementsT)(
-    LiteRtDispatchInvocationContext invocation_context, int output_index,
-    const LiteRtRankedTensorType* tensor_type,
-    LiteRtTensorBufferRequirements* tensor_buffer_requirements);
-
-typedef LiteRtStatus (*LiteRtDispatchRegisterTensorBufferT)(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtTensorBuffer tensor_buffer,
-    LiteRtTensorBufferHandle* tensor_buffer_handle);
-
-typedef LiteRtStatus (*LiteRtDispatchUnregisterTensorBufferT)(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtTensorBufferHandle handle);
-
-typedef LiteRtStatus (*LiteRtDispatchInvocationContextCreateT)(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtDispatchExecutableType exec_type,
-    const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name,
-    int num_inputs, int num_outputs,
-    LiteRtDispatchInvocationContext* invocation_context);
-
-typedef LiteRtStatus (*LiteRtDispatchInvocationContextDestroyT)(
-    LiteRtDispatchInvocationContext invocation_context);
-
-typedef LiteRtStatus (*LiteRtDispatchAttachInputT)(
-    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
-    LiteRtTensorBufferHandle tensor_buffer_handle);
-
-typedef LiteRtStatus (*LiteRtDispatchAttachOutputT)(
-    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
-    LiteRtTensorBufferHandle tensor_buffer_handle);
-
-typedef LiteRtStatus (*LiteRtDispatchDetachInputT)(
-    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
-    LiteRtTensorBufferHandle tensor_buffer_handle);
-
-typedef LiteRtStatus (*LiteRtDispatchDetachOutputT)(
-    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
-    LiteRtTensorBufferHandle tensor_buffer_handle);
-
-typedef LiteRtStatus (*LiteRtDispatchInvokeT)(
-    LiteRtDispatchInvocationContext invocation_context);
-
-typedef LiteRtStatus (*LiteRtDispatchStartMetricsCollectionT)(
-    LiteRtDispatchInvocationContext invocation_context, int detail_level);
-
-typedef LiteRtStatus (*LiteRtDispatchStopMetricsCollectionT)(
-    LiteRtDispatchInvocationContext invocation_context,
-    LiteRtDispatchMetrics* metrics);
-
-typedef LiteRtStatus (*LiteRtDispatchGetNumMetricsT)(
-    LiteRtDispatchMetrics metrics, int* num_metrics);
-
-typedef LiteRtStatus (*LiteRtDispatchGetMetricT)(LiteRtDispatchMetrics metrics,
-                                                 int metric_index,
-                                                 LiteRtMetric* metric);
-
-typedef LiteRtStatus (*LiteRtDispatchDestroyMetricsT)(
-    LiteRtDispatchMetrics metrics);
-
-typedef struct LiteRtDispatchInterface {
-  LiteRtDispatchInitializeT initialize;
-  LiteRtDispatchGetVendorIdT get_vendor_id;
-  LiteRtDispatchGetBuildIdT get_build_id;
-  LiteRtDispatchGetCapabilitiesT get_capabilities;
-  LiteRtDispatchDeviceContextCreateT device_context_create;
-  LiteRtDispatchDeviceContextDestroyT device_context_destroy;
-  LiteRtDispatchGetInputRequirementsT get_input_requirements;
-  LiteRtDispatchGetOutputRequirementsT get_output_requirements;
-  LiteRtDispatchRegisterTensorBufferT register_tensor_buffer;
-  LiteRtDispatchUnregisterTensorBufferT unregister_tensor_buffer;
-  LiteRtDispatchInvocationContextCreateT invocation_context_create;
-  LiteRtDispatchInvocationContextDestroyT invocation_context_destroy;
-  LiteRtDispatchAttachInputT attach_input;
-  LiteRtDispatchAttachOutputT attach_output;
-  LiteRtDispatchDetachInputT detach_input;
-  LiteRtDispatchDetachOutputT detach_output;
-  LiteRtDispatchInvokeT invoke;
-  LiteRtDispatchStartMetricsCollectionT start_metrics_collection;
-  LiteRtDispatchStopMetricsCollectionT stop_metrics_collection;
-  LiteRtDispatchGetNumMetricsT get_num_metrics;
-  LiteRtDispatchGetMetricT get_metric;
-  LiteRtDispatchDestroyMetricsT destroy_metrics;
-} LiteRtDispatchInterface;
-
-// /////////////////////////////////////////////////////////////////////////////
-
-typedef LiteRtStatus (*LiteRtDispatchAttachInputEventT)(
-    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
-    LiteRtEvent input_event);
-
-typedef LiteRtStatus (*LiteRtDispatchInvokeAsyncT)(
-    LiteRtDispatchInvocationContext invocation_context, int num_output_events,
-    LiteRtEvent* output_events);
-
-typedef struct LiteRtDispatchAsyncInterface {
-  LiteRtDispatchAttachInputEventT attach_input_event;
-  LiteRtDispatchInvokeAsyncT invoke_async;
-} LiteRtDispatchAsyncInterface;
-
-// /////////////////////////////////////////////////////////////////////////////
-
-typedef LiteRtStatus (*LiteRtDispatchGraphCreateT)(
-    LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph);
-
-typedef LiteRtStatus (*LiteRtDispatchGraphDestroyT)(LiteRtDispatchGraph graph);
-
-typedef LiteRtStatus (*LiteRtDispatchAddNodeT)(
-    LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id,
-    LiteRtDispatchNodeType node_type);
-
-typedef LiteRtStatus (*LiteRtDispatchAddEdgeT)(LiteRtDispatchGraph graph,
-                                               LiteRtDispatchEdgeId edge_id);
-
-typedef LiteRtStatus (*LiteRtDispatchConnectNodeInputT)(
-    LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, int input_index,
-    LiteRtDispatchEdgeId edge_id);
-
-typedef LiteRtStatus (*LiteRtDispatchConnectNodeOutputT)(
-    LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, int output_index,
-    LiteRtDispatchEdgeId edge_id);
-
-typedef LiteRtStatus (*LiteRtDispatchConnectGraphInputT)(
-    LiteRtDispatchGraph graph, int input_index, LiteRtDispatchEdgeId edge_id);
-
-typedef LiteRtStatus (*LiteRtDispatchConnectGraphOutputT)(
-    LiteRtDispatchGraph graph, int output_index, LiteRtDispatchEdgeId edge_id);
-
-typedef LiteRtStatus (*LiteRtDispatchLoadExecutableT)(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtDispatchExecutableType type, const LiteRtMemBuffer* bytecode_buffer,
-    LiteRtDispatchExecutableHandle* exec_handle);
-
-typedef LiteRtStatus (*LiteRtDispatchUnloadExecutableT)(
-    LiteRtDispatchDeviceContext device_context,
-    LiteRtDispatchExecutableHandle exec_handle);
-
-typedef LiteRtStatus (*LiteRtDispatchAssignNodeFunctionT)(
-    LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id,
-    LiteRtDispatchExecutableHandle exec_handle, const char* function_name);
-
-typedef LiteRtStatus (*LiteRtDispatchInvocationContextCreateFromGraphT)(
-    LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph,
-    LiteRtDispatchInvocationContext* invocation_context);
-
-typedef LiteRtStatus (*LiteRtDispatchAnnotateGraphT)(LiteRtDispatchGraph graph,
-                                                     const char* key,
-                                                     const char* value);
-
-typedef LiteRtStatus (*LiteRtDispatchAnnotateNodeT)(
-    LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, const char* key,
-    const char* value);
-
-typedef LiteRtStatus (*LiteRtDispatchAnnotateEdgeT)(
-    LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id, const char* key,
-    const char* value);
-
-typedef struct LiteRtDispatchGraphInterface {
-  LiteRtDispatchGraphCreateT graph_create;
-  LiteRtDispatchGraphDestroyT graph_destroy;
-  LiteRtDispatchAddNodeT add_node;
-  LiteRtDispatchAddEdgeT add_edge;
-  LiteRtDispatchConnectNodeInputT connect_node_input;
-  LiteRtDispatchConnectNodeOutputT connect_node_output;
-  LiteRtDispatchConnectGraphInputT connect_graph_input;
-  LiteRtDispatchConnectGraphOutputT connect_graph_output;
-  LiteRtDispatchLoadExecutableT load_executable;
-  LiteRtDispatchUnloadExecutableT unload_executable;
-  LiteRtDispatchAssignNodeFunctionT assign_node_function;
-  LiteRtDispatchAnnotateGraphT annotate_graph;
-  LiteRtDispatchAnnotateNodeT annotate_node;
-  LiteRtDispatchAnnotateEdgeT annotate_edge;
-  LiteRtDispatchInvocationContextCreateFromGraphT
-      invocation_context_create_from_graph;
-} LiteRtDispatchGraphInterface;
-
-// /////////////////////////////////////////////////////////////////////////////
-
-// FIXME See Vulkan and OpenCL extensions.
-typedef struct LiteRtDispatchApi {
-  LiteRtApiVersion version;
-  LiteRtDispatchInterface* interface;
-  LiteRtDispatchAsyncInterface* async_interface;
-  LiteRtDispatchGraphInterface* graph_interface;
-} LiteRtDispatchApi;
-
-LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api);
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_
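// [Editorial sketch; not part of the patch.] A vendor library would
// presumably return static interface tables from LiteRtDispatchGetApi;
// `kMyVersion` and the MyXxx entry points below are hypothetical.
static LiteRtDispatchInterface kMyInterface = {
    /*.initialize=*/MyInitialize,
    /*.get_vendor_id=*/MyGetVendorId,
    /* ... remaining entry points ... */
};

LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) {
  api->version = kMyVersion;
  api->interface = &kMyInterface;
  api->async_interface = NULL;  // Async API not implemented in this sketch.
  api->graph_interface = NULL;  // Graph API not implemented in this sketch.
  return kLiteRtStatusOk;
}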
diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c b/tensorflow/lite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c
deleted file mode 100644
index 60cedbb927035a..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file exists to verify that the below header files can build, link,
-// and run as C code.
-#ifdef __cplusplus
-#error "This file should be compiled as C code, not as C++."
-#endif
-
-// Include all the header files in the litert/vendors/c directory.
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h"  // NOLINT
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h"  // NOLINT
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h"  // NOLINT
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h"  // NOLINT
-
-int main(void) {
-  return 0;
-}
diff --git a/tensorflow/lite/experimental/litert/vendors/cc/BUILD b/tensorflow/lite/experimental/litert/vendors/cc/BUILD
deleted file mode 100644
index 25e6c26462cfab..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/cc/BUILD
+++ /dev/null
@@ -1,126 +0,0 @@
-# Copyright 2024 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"],
-)
-
-cc_library(
-    name = "litert_compiler_plugin",
-    hdrs = ["litert_compiler_plugin.h"],
-    deps = [
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/cc:litert_macros",
-        "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
-
-cc_library(
-    name = "conversion",
-    hdrs = ["conversion.h"],
-    deps = [
-        ":backend_ir",
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_model",
-        "//tensorflow/lite/experimental/litert/c:litert_op_code",
-        "//tensorflow/lite/experimental/litert/cc:litert_expected",
-        "//tensorflow/lite/experimental/litert/cc:litert_model",
-        "@com_google_absl//absl/container:flat_hash_map",
-    ],
-)
-
-cc_library(
-    name = "backend_ir",
-    hdrs = ["backend_ir.h"],
-    deps = ["//tensorflow/lite/experimental/litert/c:litert_common"],
-)
-
-cc_library(
-    name = "partition_with_capabilities",
-    hdrs = ["partition_with_capabilities.h"],
-    deps = [
-        ":conversion",
-        "//tensorflow/lite/experimental/litert/c:litert_logging",
-        "//tensorflow/lite/experimental/litert/c:litert_model",
-        "//tensorflow/lite/experimental/litert/cc:litert_expected",
-        "//tensorflow/lite/experimental/litert/cc:litert_model",
-    ],
-)
-
-cc_library(
-    name = "convert_graph",
-    hdrs = ["convert_graph.h"],
-    deps = [
-        ":conversion",
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_logging",
-        "//tensorflow/lite/experimental/litert/cc:litert_expected",
-        "//tensorflow/lite/experimental/litert/cc:litert_model",
-    ],
-)
-
-cc_library(
-    name = "ir_types",
-    hdrs = ["ir_types.h"],
-    deps = [
-        ":backend_ir",
-        ":conversion",
-        "//tensorflow/lite/experimental/litert/cc:litert_expected",
-    ],
-)
-
-cc_test(
-    name = "partition_with_capabilities_test",
-    srcs = ["partition_with_capabilities_test.cc"],
-    deps = [
-        ":partition_with_capabilities",
-        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
-        "//tensorflow/lite/experimental/litert/c:litert_model",
-        "//tensorflow/lite/experimental/litert/c:litert_op_code",
-        "//tensorflow/lite/experimental/litert/cc:litert_model",
-        "//tensorflow/lite/experimental/litert/core/model",
-        "//tensorflow/lite/experimental/litert/core/model:model_graph",
-        "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools",
-        "//tensorflow/lite/experimental/litert/vendors/examples:example_conversion_impl",
-        "//tensorflow/lite/experimental/litert/vendors/examples:example_ir",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-cc_test(
-    name = "convert_graph_test",
-    srcs = ["convert_graph_test.cc"],
-    deps = [
-        ":backend_ir",
-        ":convert_graph",
-        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_model",
-        "//tensorflow/lite/experimental/litert/c:litert_op_code",
-        "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref",
-        "//tensorflow/lite/experimental/litert/cc:litert_model",
-        "//tensorflow/lite/experimental/litert/core/model",
-        "//tensorflow/lite/experimental/litert/core/model:model_graph",
-        "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools",
-        "//tensorflow/lite/experimental/litert/test:matchers",
-        "//tensorflow/lite/experimental/litert/vendors/examples:example_conversion_impl",
-        "//tensorflow/lite/experimental/litert/vendors/examples:example_ir",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
diff --git a/tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h b/tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h
deleted file mode 100644
index 34cf95bd3643e6..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_BACKEND_IR_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_BACKEND_IR_H_
-
-#include <functional>
-#include <string>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-
-namespace litert {
-
-// Interfaces and types for managing backend IR to be targeted by LiteRt for
-// compilation.
-
-// Memory Management
-//===---------------------------------------------------------------------------
-
-// Callable for allocating a new instance of a backend IR type. This
-// facilitates external memory management for the backend IR implemented by
-// the backend. Implementations are encouraged to provide pointer stability
-// (consider std::list for storage).
-template <class BackendIr>
-using BackendIrAllocator = std::function<BackendIr*()>;
-
-// Allocator for backend tensors.
-template <class BackendTensor>
-using TensorAllocator = BackendIrAllocator<BackendTensor>;
-
-// Allocator for backend ops.
-template <class BackendOp>
-using OpAllocator = BackendIrAllocator<BackendOp>;
-
-// Graph Construction
-//===---------------------------------------------------------------------------
-
-// Wrapper for an in memory graph for a particular backend. Implementations
-// should contain an instance of a backend graph that can be iteratively
-// constructed via calls to this interface.
-template <class BackendOp, class BackendTensor>
-class BackendGraphBuilder {
- public:
-  // Hook called to initialize state for a new backend graph with a name. This
-  // will be called once per-instance before any other method.
-  virtual void InitGraph(std::string graph_name) = 0;
-
-  // Hook called to register a backend tensor once it
-  // has been converted. This will be called once per tensor.
-  virtual LiteRtStatus RegisterTensor(BackendTensor& tensor) = 0;
-
-  // Hook called to register a backend op once it has been converted. This will
-  // be called once per op (in a topological order). All input/output tensors
-  // will have been registered before called.
-  virtual LiteRtStatus RegisterOp(BackendOp& op) = 0;
-
-  // Hook called to register a graph when graph conversion is completed.
-  // Backend graph context should be stored as internal state. This will be
-  // called once per instance after all ops/tensors have been finalized.
-  virtual LiteRtStatus FinalizeGraph() = 0;
-
-  virtual ~BackendGraphBuilder() = default;
-};
-
-}  // namespace litert
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_BACKEND_IR_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/cc/conversion.h b/tensorflow/lite/experimental/litert/vendors/cc/conversion.h
deleted file mode 100644
index 139ba594bb1e8a..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/cc/conversion.h
+++ /dev/null
@@ -1,262 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Utility types for mapping LiteRt IR to arbitrary backend specific
-// types. Implementations of these types define mapping for ops and tensors
-// that may be used in a standalone fashion. They also may be composed
-// to create lowerings of entire graphs with topology.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERSION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERSION_H_
-
-#include <functional>
-#include <memory>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h"
-
-namespace litert {
-
-// Interfaces and types for implementing "conversions" that map LiteRt IR to
-// backend IR.
-// NOTE: Conversions depend on external memory management for the backend IR
-// types. User defined conversions are usually expected to leverage callbacks
-// to allocate backend IR types rather than constructing them directly.
-
-// Conversion Result Type
-//===---------------------------------------------------------------------------
-
-// Result of a one->many general mapping from LiteRt op to any number of
-// backend specific ops. Does not own the memory of the backend ops or tensors.
-template <class BackendOp, class BackendTensor>
-struct GeneralConversionResult {
-  // Ops emitted from translation pattern.
-  std::vector<BackendOp*> ops;
-
-  // Any backend tensors used within the result ops. Not relevant when
-  // size of backend ops == 1. This does not include input/output tensors of
-  // the op being converted.
-  std::vector<BackendTensor*> intermediate_tensors;
-};
-
-// The result of a one->one specialized mapping from LiteRt op to backend op.
-template <class BackendOp>
-using SimpleConversionResult = BackendOp*;
-
-// A tag-type for a conversion result that is a non-error non-match.
-struct NoMatch {};
-
-// Type union for conversion results.
-// TODO(lukeboyer): Update conversion result types to handle the case where
-// backend ops add extra inputs.
-template <class BackendOp, class BackendTensor>
-using ConversionResult =
-    std::variant<SimpleConversionResult<BackendOp>,
-                 GeneralConversionResult<BackendOp, BackendTensor>, NoMatch>;
-
-// Shorthand for holds_alternative.
-template <class Result, class BackendOp, class BackendTensor>
-bool ConversionIsA(const ConversionResult<BackendOp, BackendTensor>& result) {
-  return std::holds_alternative<Result>(result);
-}
-
-// Shorthand for holds_alternative.
-template <class BackendOp, class BackendTensor>
-bool ConversionMatched(
-    const ConversionResult<BackendOp, BackendTensor>& result) {
-  return !std::holds_alternative<NoMatch>(result);
-}
-
-// Shorthand for holds_alternative.
-template <class BackendOp, class BackendTensor>
-bool IsSimpleResult(const ConversionResult<BackendOp, BackendTensor>& result) {
-  return ConversionIsA<SimpleConversionResult<BackendOp>>(result);
-}
-
-// Shorthand for holds_alternative.
-template <class BackendOp, class BackendTensor>
-bool IsGeneralResult(
-    const ConversionResult<BackendOp, BackendTensor>& result) {
-  return ConversionIsA<GeneralConversionResult<BackendOp, BackendTensor>>(
-      result);
-}
-
-// Shorthand for std::get. Also checks if match and wraps in expected.
-template <class Result, class BackendOp, class BackendTensor>
-Expected<Result> GetConversionResult(
-    const ConversionResult<BackendOp, BackendTensor>& result) {
-  if (ConversionMatched(result)) {
-    return Expected<Result>(std::get<Result>(result));
-  }
-  return Error(kLiteRtStatusLegalizeNoMatch);
-}
-
-// Get simple result if there was a match.
-template <class BackendOp, class BackendTensor>
-Expected<SimpleConversionResult<BackendOp>> GetSimpleConversionResult(
-    const ConversionResult<BackendOp, BackendTensor>& result) {
-  if (!IsSimpleResult(result)) {
-    return Error(kLiteRtStatusErrorInvalidArgument);
-  }
-  return GetConversionResult<SimpleConversionResult<BackendOp>>(result);
-}
-
-// Get general result if there was a match.
-template <class BackendOp, class BackendTensor>
-Expected<GeneralConversionResult<BackendOp, BackendTensor>>
-GetGeneralConversionResult(
-    const ConversionResult<BackendOp, BackendTensor>& result) {
-  if (!IsGeneralResult(result)) {
-    return Error(kLiteRtStatusErrorInvalidArgument);
-  }
-  return GetConversionResult<
-      GeneralConversionResult<BackendOp, BackendTensor>>(result);
-}
-
-// Common IR Conversion
-//===---------------------------------------------------------------------------
-
-// User defined callback for converting a LiteRt tensor to a backend tensor.
-// These are leveraged in various higher-level conversion routines.
-// TensorConverters should not stack allocate memory for the backend tensor. In
-// most situations, these will be bound to an external allocator.
-template <class BackendTensor>
-using TensorConverter =
-    std::function<Expected<BackendTensor*>(const Tensor& litert_tensor)>;
-
-// User defined callback for creating a TensorConverter. This facilitates
-// TensorConverters that are bound to an external allocator.
-template <class BackendTensor>
-using TensorConverterFactory = std::function<TensorConverter<BackendTensor>(
-    TensorAllocator<BackendTensor> alloc)>;
-
-// Mapping from LiteRt tensor to backend tensor, used during iterative graph
-// conversions to store current scope.
-template <class BackendTensor> -using TensorMap = absl::flat_hash_map<LiteRtTensor, BackendTensor*>; - -// User-defined hook that calls the backend to determine if an op is supported. -template <class BackendOp> -using Capability = std::function<bool(const BackendOp* op)>; - -// Legalization -//===--------------------------------------------------------------------------- - -// A legalization is a particular type of user-defined conversion that is -// scheduled for execution on a particular type of LiteRtOp. They may be -// one-to-one or one-to-many conversions. -template <class BackendOp, class BackendTensor> -class Legalization { - private: - using Self = Legalization<BackendOp, BackendTensor>; - - public: - using Result = ConversionResult<BackendOp, BackendTensor>; - using TensorConverter = ::litert::TensorConverter<BackendTensor>; - using TensorConverterFactory = ::litert::TensorConverterFactory<BackendTensor>; - using Ptr = std::unique_ptr<Self>; - using TensorAllocator = ::litert::TensorAllocator<BackendTensor>; - using OpAllocator = ::litert::OpAllocator<BackendOp>; - using Tensors = std::vector<BackendTensor*>; - - // The type of op to schedule on. - virtual LiteRtOpCode OpToMatch() const = 0; - - // Invoke this legalization on the given LiteRt op. All new backend IR will be - // allocated via the given allocators. NOTE: In most cases, input and output - // converters will be the same. They are separated here for compatibility with - // graph-level conversion routines. - Expected<Result> Legalize(const Op& litert_op, - TensorConverterFactory input_converter, - TensorConverterFactory output_converter, - TensorAllocator tensor_allocator, - OpAllocator op_allocator) const { - const auto litert_inputs = litert_op.Inputs(); - Tensors inputs(litert_inputs.size()); - auto convert_input = input_converter(tensor_allocator); - - for (size_t i = 0; i < litert_inputs.size(); ++i) { - const auto& litert_input = litert_inputs[i]; - auto result = convert_input(litert_input); - if (!result) { - return result.Error(); - } - inputs[i] = *result; - } - - const auto litert_outputs = litert_op.Outputs(); - Tensors outputs(litert_outputs.size()); - auto convert_output = output_converter(tensor_allocator); - - for (size_t i = 0; i < litert_outputs.size(); ++i) { - const auto& litert_output = litert_outputs[i]; - auto result = convert_output(litert_output); - if (!result) { - return result.Error(); - } - outputs[i] = *result; - } - - return LegalizeImpl(litert_op, inputs, outputs, tensor_allocator, - op_allocator); - } - - virtual ~Legalization() = default; - - private: - // The user-defined implementation of a legalization. Users must use the - // given allocators to allocate any new backend IR types (e.g. intermediate - // ops/tensors in the case of a one-to-many legalization). BackendTensors - // corresponding to LiteRt inputs and outputs have been pre-converted. - virtual Expected<Result> LegalizeImpl(const Op& litert_op, - const Tensors& inputs, - const Tensors& outputs, - TensorAllocator tensor_allocator, - OpAllocator op_allocator) const = 0; -}; - -// Collection of legalizations for a specific backend. -template <class BackendOp, class BackendTensor> -using Legalizations = - std::vector<typename Legalization<BackendOp, BackendTensor>::Ptr>; - -// Map for instance lookup by op code. -template <class BackendOp, class BackendTensor> -using LegalizationMap = - absl::flat_hash_map<LiteRtOpCode, - const Legalization<BackendOp, BackendTensor>*>; - -// Construct a LegalizationMap from a collection of legalizations. -// TODO: Consider wrapping the legalization map in a class to avoid -// re-constructing it and to improve syntax.
-template <class BackendOp, class BackendTensor> -LegalizationMap<BackendOp, BackendTensor> MakeLegalizationMap( - const Legalizations<BackendOp, BackendTensor>& legalizations) { - LegalizationMap<BackendOp, BackendTensor> map; - for (const auto& l : legalizations) { - map.insert({l->OpToMatch(), l.get()}); - } - return map; -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERSION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h b/tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h deleted file mode 100644 index cd7221c7bba028..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERT_GRAPH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERT_GRAPH_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" - -namespace litert { - -// Performs iterative graph conversion with user-provided hooks. This function -// traverses the IR in topological order, converting ops and tensors with the -// given tensor converter and legalizations. It registers converted ops and -// tensors with the backend graph builder after they have been converted. The -// following are guaranteed: -// * Each tensor and op will be converted & registered at most once. -// * An op's input and output tensors will be registered before the op is -// converted (and before it is registered). -// * The graph builder will be initialized before any registration. -// * The graph builder will be finalized after all registration. -template <class Ir> -LiteRtStatus ConvertGraph( - const Subgraph& subgraph, std::string graph_name, - typename Ir::TensorConverterFactory tensor_converter_factory, - typename Ir::TensorAllocator tensor_alloc, - typename Ir::OpAllocator op_alloc, - const typename Ir::Legalizations& legalizations, - typename Ir::GraphBuilder& builder) { - // Store the mapping between evaluated litert tensors and corresponding - // backend tensors. - typename Ir::TensorMap tensor_map; - - // Initialize the backend graph builder. - builder.InitGraph(std::move(graph_name)); - - // Convert tensor, add to scope and register in backend graph builder.
- auto handle_tensor = [&tensor_map, &builder]( - const auto& litert_tensor, - auto tensor_converter) -> typename Ir::TensorResult { - auto converted = tensor_converter(litert_tensor); - if (!converted) { - LITERT_LOG(LITERT_ERROR, "Failed to convert tensor %lu", - litert_tensor.Get()); - return converted.Error(); - } - - if (auto status = builder.RegisterTensor(**converted); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register tensor %lu, with status %d", - litert_tensor.Get(), status); - return Error(status); - } - - tensor_map.insert({litert_tensor.Get(), *converted}); - return *converted; - }; - - // Wrap the provided tensor conversion logic for converting subgraph or op - // input tensors. The user-defined conversions are wrapped so that they are - // aware of the tensor map and graph builder registration. - auto input_tensor_convert_factory = [tensor_converter_factory, &tensor_map, - handle_tensor](auto tensor_alloc) { - return [tensor_alloc, tensor_converter_factory, &tensor_map, - handle_tensor](const Tensor& litert_tensor) -> typename Ir::TensorResult { - auto tensor_converter = tensor_converter_factory(tensor_alloc); - - // Check if the tensor has been converted already. - auto it = tensor_map.find(litert_tensor.Get()); - const auto in_scope = it != tensor_map.end(); - if (in_scope) { - LITERT_LOG(LITERT_VERBOSE, "Tensor %lu is in scope", - litert_tensor.Get()); - return it->second; - } - - // If it's a subgraph input or constant, we can convert it and add it to - // scope. - const auto is_cst = litert_tensor.IsConstant(); - const auto is_sg_input = litert_tensor.IsSubgraphInput(); - if (is_sg_input || is_cst) { - return handle_tensor(litert_tensor, tensor_converter); - } - - // A tensor must be added to scope before conversion, or must have no - // parent (e.g. a subgraph input or constant), so error out at this point. - LITERT_LOG(LITERT_ERROR, "Tensor %lu not handled", litert_tensor.Get()); - return Error(kLiteRtStatusErrorInvalidArgument); - }; - }; - - // Wrap the provided tensor conversion logic for op output tensors. Adds to - // the map and backend graph after conversion. - auto output_tensor_convert_factory = [tensor_converter_factory, - handle_tensor](auto tensor_alloc) { - return [tensor_alloc, tensor_converter_factory, - handle_tensor](const Tensor& litert_tensor) { - auto tensor_converter = tensor_converter_factory(tensor_alloc); - return handle_tensor(litert_tensor, tensor_converter); - }; - }; - - // Convert all ops in the subgraph in topological order.
- auto legalization_map = Ir::MakeLegalizationMap(legalizations); - for (const auto& op : subgraph.Ops()) { - auto it = legalization_map.find(op.Code()); - if (it == legalization_map.end()) { - LITERT_LOG(LITERT_ERROR, "No legalization found for op %d", op.Code()); - return kLiteRtStatusErrorUnsupported; - } - - auto result = it->second->Legalize(op, input_tensor_convert_factory, - output_tensor_convert_factory, - tensor_alloc, op_alloc); - if (!result) { - LITERT_LOG(LITERT_ERROR, "Failed to legalize op %d, with status %d", - op.Code(), result.Error().Status()); - return result.Error().Status(); - } - - auto simple_result = GetSimpleConversionResult(*result); - if (simple_result) { - if (auto stat = builder.RegisterOp(**simple_result); - stat != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register op %d, with status %d", - op.Code(), stat); - return stat; - } - } - - auto general_result = GetGeneralConversionResult(*result); - if (general_result) { - for (auto* tensor : general_result->intermediate_tensors) { - if (auto stat = builder.RegisterTensor(*tensor); - stat != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, - "Failed to register tensor %d, with status %d", tensor->id, - stat); - return stat; - } - } - - for (auto* op : general_result->ops) { - if (auto stat = builder.RegisterOp(*op); stat != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register op %d, with status %d", - op->op_code, stat); - return stat; - } - } - } - } - - builder.FinalizeGraph(); - - return kLiteRtStatusOk; -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERT_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc b/tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc deleted file mode 100644 index 9ad0e0e644e66f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
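[Editor's sketch] Before the tests, a rough usage sketch of ConvertGraph as the tests below drive it; the explicit ExampleTypes template argument is an inference from the IrTypes aliases, and error handling is elided:

ExampleTensorAllocator tensor_alloc;
ExampleOpAllocator op_alloc;
ExampleGraphBuilder builder;
// Walks the subgraph topologically, legalizing each op and registering
// converted tensors/ops with the builder as it goes.
LiteRtStatus status = ConvertGraph<ExampleTypes>(
    litert_subgraph, "example_graph", MakeTensorConverter, tensor_alloc,
    op_alloc, MakeAllLegalizations(), builder);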
- -#include "tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h" - -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -namespace litert { -namespace { - -using ::litert::example::ExampleOpAllocator; -using ::litert::example::ExampleOpType; -using ::litert::example::ExampleTensorAllocator; -using ::litert::example::ExampleTypes; -using ::litert::example::MakeAllLegalizations; -using ::litert::example::MakeTensorConverter; -using ::testing::AllOf; -using ::testing::ElementsAreArray; -using ::testing::Expectation; -using ::testing::ExpectationSet; -using ::testing::Field; -using ::testing::Return; - -static constexpr std::array kDims = {2, 2}; -static constexpr auto kElementType = kLiteRtElementTypeFloat32; -static constexpr absl::string_view kGraphName = "graph_name"; - -TensorType GetTestTensorType() { - return MakeRankedTensorType(kElementType, absl::MakeConstSpan(kDims)); -} - -class MockGraphBuilder - : public BackendGraphBuilder { - public: - MOCK_METHOD(void, InitGraph, (std::string name), (override)); - MOCK_METHOD(LiteRtStatus, RegisterTensor, (ExampleTypes::Tensor & tensor), - (override)); - MOCK_METHOD(LiteRtStatus, RegisterOp, (ExampleTypes::Op & op), (override)); - MOCK_METHOD(LiteRtStatus, FinalizeGraph, (), (override)); -}; - -TEST(ConvertGraphTest, ConvertSingleSimpleConversion) { - LiteRtSubgraphT subgraph; - - auto& op = subgraph.EmplaceOp(); - op.SetOpCode(kLiteRtOpCodeTflMul); - - auto& input1 = subgraph.EmplaceTensor(); - input1.SetType(GetTestTensorType()); - input1.SetName("input1"); - - auto& input2 = subgraph.EmplaceTensor(); - input2.SetType(GetTestTensorType()); - input2.SetName("input2"); - - auto& output = subgraph.EmplaceTensor(); - output.SetType(GetTestTensorType()); - output.SetName("output"); - - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - subgraph.Inputs().push_back(&input1); - subgraph.Inputs().push_back(&input2); - subgraph.Outputs().push_back(&output); - - Subgraph litert_subgraph(&subgraph); - - ExampleOpAllocator op_alloc; - ExampleTensorAllocator tensor_alloc; - - MockGraphBuilder builder; - - Expectation init_graph = - EXPECT_CALL(builder, InitGraph(std::string(kGraphName))).Times(1); - - ExpectationSet reg_inputs; - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input1.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - 
input2.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - ExpectationSet reg_outputs; - reg_outputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - output.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_op_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::MUL), - Field(&ExampleTypes::Op::input_names, - ElementsAreArray({input1.Name(), input2.Name()})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({output.Name()}))); - - Expectation reg_op = EXPECT_CALL(builder, RegisterOp(match_reg_op_args)) - .Times(1) - .After(reg_inputs, reg_outputs) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation finalize_graph = EXPECT_CALL(builder, FinalizeGraph()) - .Times(1) - .After(reg_op) - .WillOnce(Return(kLiteRtStatusOk)); - - auto stat = ConvertGraph( - litert_subgraph, std::string(kGraphName), MakeTensorConverter, - tensor_alloc, op_alloc, MakeAllLegalizations(), builder); - - LITERT_ASSERT_OK(stat); -} - -TEST(ConvertGraphTest, ConvertSingleGeneralConversion) { - LiteRtSubgraphT subgraph; - - auto& op = subgraph.EmplaceOp(); - op.SetOpCode(kLiteRtOpCodeTflAdd); - - tflite::AddOptionsT add_opts; - add_opts.fused_activation_function = tflite::ActivationFunctionType_RELU; - internal::TflOptions tfl_opts; - tfl_opts.Set(std::move(add_opts)); - litert::internal::SetTflOptions(op, std::move(tfl_opts)); - - auto& input1 = subgraph.EmplaceTensor(); - input1.SetType(GetTestTensorType()); - input1.SetName("input1"); - - auto& input2 = subgraph.EmplaceTensor(); - input2.SetType(GetTestTensorType()); - input2.SetName("input2"); - - auto& output = subgraph.EmplaceTensor(); - output.SetType(GetTestTensorType()); - output.SetName("output"); - - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - subgraph.Inputs().push_back(&input1); - subgraph.Inputs().push_back(&input2); - subgraph.Outputs().push_back(&output); - - Subgraph litert_subgraph(&subgraph); - - ExampleOpAllocator op_alloc; - ExampleTensorAllocator tensor_alloc; - - MockGraphBuilder builder; - - Expectation init_graph = - EXPECT_CALL(builder, InitGraph(std::string(kGraphName))).Times(1); - - ExpectationSet reg_inputs; - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input1.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input2.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - ExpectationSet reg_intermediates; - reg_intermediates += - EXPECT_CALL(builder, - RegisterTensor(Field(&ExampleTypes::Tensor::name, - example::kIntermediateTensorName))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - ExpectationSet reg_outputs; - reg_outputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - output.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_add_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::ADD), - Field(&ExampleTypes::Op::input_names, - ElementsAreArray({input1.Name(), input2.Name()})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({example::kIntermediateTensorName}))); - - Expectation reg_add = EXPECT_CALL(builder, RegisterOp(match_reg_add_args)) - .Times(1) - .After(reg_inputs, reg_intermediates) - 
.WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_relu_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::RELU), - Field(&ExampleTypes::Op::input_names, - ElementsAreArray({example::kIntermediateTensorName})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({output.Name()}))); - - Expectation reg_relu = EXPECT_CALL(builder, RegisterOp(match_reg_relu_args)) - .Times(1) - .After(reg_add, reg_intermediates, reg_outputs) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation finalize_graph = EXPECT_CALL(builder, FinalizeGraph()) - .Times(1) - .After(reg_relu) - .WillOnce(Return(kLiteRtStatusOk)); - - auto stat = ConvertGraph( - litert_subgraph, std::string(kGraphName), MakeTensorConverter, - tensor_alloc, op_alloc, MakeAllLegalizations(), builder); - - LITERT_ASSERT_OK(stat); -} - -TEST(ConvertGraphTest, ConvertMultipleOps) { - LiteRtSubgraphT subgraph; - - auto& op = subgraph.EmplaceOp(); - op.SetOpCode(kLiteRtOpCodeTflMul); - - auto& input1 = subgraph.EmplaceTensor(); - input1.SetType(GetTestTensorType()); - input1.SetName("input1"); - - auto& input2 = subgraph.EmplaceTensor(); - input2.SetType(GetTestTensorType()); - input2.SetName("input2"); - - auto& output1 = subgraph.EmplaceTensor(); - output1.SetType(GetTestTensorType()); - output1.SetName("output1"); - - auto& cst = subgraph.EmplaceTensor(); - OwningBufferRef weights(8); - SetWeightsFromUnownedBuffer(cst.Weights(), weights); - cst.SetName("cst"); - cst.SetType(GetTestTensorType()); - - auto& op2 = subgraph.EmplaceOp(); - op2.SetOpCode(kLiteRtOpCodeTflAdd); - - auto& output2 = subgraph.EmplaceTensor(); - output2.SetType(GetTestTensorType()); - output2.SetName("output2"); - - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output1, op); - - internal::AttachInput(&output1, op2); - internal::AttachInput(&cst, op2); - internal::AttachOutput(&output2, op2); - - subgraph.Inputs().push_back(&input1); - subgraph.Inputs().push_back(&input2); - subgraph.Outputs().push_back(&output2); - - Subgraph litert_subgraph(&subgraph); - - ExampleOpAllocator op_alloc; - ExampleTensorAllocator tensor_alloc; - - MockGraphBuilder builder; - - Expectation init_graph = - EXPECT_CALL(builder, InitGraph(std::string(kGraphName))).Times(1); - - ExpectationSet reg_inputs; - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input1.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input2.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation reg_output1 = - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - output1.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation reg_cst = - EXPECT_CALL(builder, RegisterTensor( - Field(&ExampleTypes::Tensor::name, cst.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation reg_output2 = - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - output2.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_op1_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::MUL), - Field(&ExampleTypes::Op::input_names, - ElementsAreArray({input1.Name(), input2.Name()})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({output1.Name()}))); - - Expectation reg_op1 = 
EXPECT_CALL(builder, RegisterOp(match_reg_op1_args)) - .Times(1) - .After(reg_inputs, reg_output1) - .WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_op2_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::ADD), - Field(&ExampleTypes::Op::input_names, - ElementsAreArray({output1.Name(), cst.Name()})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({output2.Name()}))); - - Expectation reg_op2 = EXPECT_CALL(builder, RegisterOp(match_reg_op2_args)) - .Times(1) - .After(reg_op1, reg_cst, reg_output2, reg_output1) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation finalize_graph = EXPECT_CALL(builder, FinalizeGraph()) - .Times(1) - .After(reg_op2) - .WillOnce(Return(kLiteRtStatusOk)); - - auto stat = ConvertGraph<ExampleTypes>( - litert_subgraph, std::string(kGraphName), MakeTensorConverter, - tensor_alloc, op_alloc, MakeAllLegalizations(), builder); - - LITERT_ASSERT_OK(stat); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/cc/ir_types.h b/tensorflow/lite/experimental/litert/vendors/cc/ir_types.h deleted file mode 100644 index a1da917de18a74..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/ir_types.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_IR_TYPES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_IR_TYPES_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" - -namespace litert { - -// Holds a particular backend's IR template aliases for convenience. -template <class BackendOp, class BackendTensor> -struct IrTypes { - using Op = BackendOp; - using Tensor = BackendTensor; - using OpAllocator = ::litert::OpAllocator<BackendOp>; - using TensorAllocator = ::litert::TensorAllocator<BackendTensor>; - using GraphBuilder = BackendGraphBuilder<BackendOp, BackendTensor>; - using GeneralConversionResult = ::litert::GeneralConversionResult<BackendOp, BackendTensor>; - using SimpleConversionResult = ::litert::SimpleConversionResult<BackendOp>; - using ConversionResult = Expected<::litert::ConversionResult<BackendOp, BackendTensor>>; - using Legalization = ::litert::Legalization<BackendOp, BackendTensor>; - using Legalizations = ::litert::Legalizations<BackendOp, BackendTensor>; - using LegalizationMap = ::litert::LegalizationMap<BackendOp, BackendTensor>; - using TensorConverter = ::litert::TensorConverter<BackendTensor>; - using TensorResult = Expected<BackendTensor*>; - using TensorConverterFactory = ::litert::TensorConverterFactory<BackendTensor>; - using TensorMap = ::litert::TensorMap<BackendTensor>; - using Capability = ::litert::Capability<BackendOp>; - // NOLINTNEXTLINE - inline static auto MakeLegalizationMap = - litert::MakeLegalizationMap<BackendOp, BackendTensor>; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_IR_TYPES_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h b/tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h deleted file mode 100644 index 654457f0f75e24..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2024 Google LLC.
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_LITERT_COMPILER_PLUGIN_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_LITERT_COMPILER_PLUGIN_H_ - -#include - -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -namespace litert { - -// Deleter for the incomplete compiler plugin type. -struct LiteRtCompilerPluginDeleter { - void operator()(LiteRtCompilerPlugin plugin) { - if (plugin != nullptr) { - LiteRtDestroyCompilerPlugin(plugin); - } - } -}; - -// Smart pointer wrapper for the incomplete plugin type. -using PluginPtr = - std::unique_ptr<LiteRtCompilerPluginT, LiteRtCompilerPluginDeleter>; - -// Initialize a plugin via the C API and wrap the result in a smart pointer. -inline PluginPtr CreatePlugin() { - LiteRtCompilerPlugin plugin; - LITERT_CHECK_STATUS_OK(LiteRtCreateCompilerPlugin(&plugin)); - return PluginPtr(plugin); -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_LITERT_COMPILER_PLUGIN_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h b/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h deleted file mode 100644 index a462d1744c3886..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_PARTITION_WITH_CAPABILITIES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_PARTITION_WITH_CAPABILITIES_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" - -namespace litert { - -// Higher-level functionality for partitioning by leveraging user-defined -// conversions. Ops are selected for partitioning via a callback that -// checks whether an op is supported by the backend. - -// Selects ops for partitioning from the given subgraph based on the given -// Capability check. Returns all ops in the given subgraph that are supported -// by the backend. Suitable for use in implementing -// LiteRtCompilerPluginPartition.
Any -// allocations of new backend IR types will be done through the given external -// allocators. -// NOTE: A missing legalization or any legalization failure will result in -// an op not being supported, rather than a failure of this function. -template <class Ir> -Expected<std::vector<LiteRtOp>> PartitionWithCapabilities( - const typename Ir::Legalizations& legalizations, - typename Ir::Capability capability, - typename Ir::TensorConverterFactory convert_tensor_fact, - typename Ir::TensorAllocator tensor_allocator, - typename Ir::OpAllocator op_allocator, const Subgraph& litert_subgraph) { - std::vector<LiteRtOp> results; - - // Build a map for legalization lookup by op code. - auto map = Ir::MakeLegalizationMap(legalizations); - - // Convert all ops from the given subgraph and check backend support. - for (const auto& litert_op : litert_subgraph.Ops()) { - const auto code = litert_op.Code(); - LITERT_LOG(LITERT_INFO, "Checking support for LiteRtOp: %d", code); - - auto it = map.find(code); - if (it == map.end()) { - LITERT_LOG(LITERT_WARNING, "No legalization found for LiteRtOp: %d", - code); - continue; - } - - // Call the user-defined conversion. - auto result = it->second->Legalize(litert_op, convert_tensor_fact, - convert_tensor_fact, tensor_allocator, - op_allocator); - if (!result) { - LITERT_LOG(LITERT_WARNING, "Failed to legalize LiteRtOp: %d", code); - continue; - } - - if (auto simple_result = GetSimpleConversionResult(*result)) { - if (capability(*simple_result)) { - LITERT_LOG(LITERT_INFO, "Selected LiteRtOp: %d", litert_op.Code()); - results.push_back(litert_op.Get()); - } - continue; - } - - // Check that all ops emitted from a one-to-many conversion are supported. - if (auto gen_result = GetGeneralConversionResult(*result)) { - const auto b_ops_start = gen_result->ops.cbegin(); - const auto b_ops_end = gen_result->ops.cend(); - if (std::all_of(b_ops_start, b_ops_end, capability)) { - LITERT_LOG(LITERT_INFO, "Selected LiteRtOp: %d", litert_op.Code()); - results.push_back(litert_op.Get()); - } - continue; - } - } - - return results; -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_PARTITION_WITH_CAPABILITIES_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc b/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc deleted file mode 100644 index cece5adb48dca9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utility types for mapping LiteRt IR to arbitrary backend-specific -// types. Implementations of these types define mappings for ops and tensors -// that may be used in a standalone fashion. They may also be composed -// to create lowerings of entire graphs with topology.
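[Editor's sketch] A compressed usage sketch of PartitionWithCapabilities, paraphrasing the tests that follow; the capability predicate and the explicit ExampleTypes template argument are assumptions mirroring the example backend:

// Backend claims support for ADD and RELU only.
bool ExampleCapability(const ExampleTypes::Op* op) {
  return op->op_code == ExampleOpType::ADD ||
         op->op_code == ExampleOpType::RELU;
}

// Inside a test or plugin: select supported ops from a LiteRt subgraph.
ExampleTypes::Legalizations legalizations;
legalizations.push_back(ExampleLegalizeAdd::Make());
ExampleTensorAllocator tensor_alloc;
ExampleOpAllocator op_alloc;
auto ops = PartitionWithCapabilities<ExampleTypes>(
    legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc,
    op_alloc, litert_subgraph);
// On success, *ops holds the LiteRtOps whose conversions passed the check.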
- -#include "tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h" - -#include -#include -#include - -#include -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -namespace litert { -namespace { - -using ::litert::example::ExampleLegalizeAdd; -using ::litert::example::ExampleLegalizeMul; -using ::litert::example::ExampleOpAllocator; -using ::litert::example::ExampleOpType; -using ::litert::example::ExampleTensorAllocator; -using ::litert::example::ExampleTypes; -using ::litert::example::MakeTensorConverter; - -bool ExampleCapability(const ExampleTypes::Op* op) { - return op->op_code == ExampleOpType::ADD || - op->op_code == ExampleOpType::RELU; -} - -TEST(PartitionWithCapabilitiesTest, EmptyGraph) { - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeAdd::Make()); - - LiteRtSubgraphT subgraph; - Subgraph litert_subgraph(&subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, - op_alloc, litert_subgraph); - ASSERT_TRUE(ops); - EXPECT_TRUE(ops->empty()); -} - -TEST(PartitionWithCapabilitiesTest, SingleSelectedOp) { - static constexpr std::array kDims = {2, 2}; - - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeAdd::Make()); - - LiteRtSubgraphT subgraph; - - const auto type = MakeRankedTensorType(kLiteRtElementTypeFloat32, kDims); - - auto& input1 = subgraph.EmplaceTensor(); - input1.SetType(type); - - auto& input2 = subgraph.EmplaceTensor(); - input2.SetType(type); - - auto& output = subgraph.EmplaceTensor(); - output.SetType(type); - - auto& op = subgraph.EmplaceOp(); - op.SetOpCode(kLiteRtOpCodeTflAdd); - - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - Subgraph litert_subgraph(&subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, - op_alloc, litert_subgraph); - - ASSERT_TRUE(ops); - EXPECT_EQ(ops->size(), 1); -} - -TEST(PartitionWithCapabilitiesTest, MultiSelectedOp) { - static constexpr std::array kDims = {2, 2}; - - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeAdd::Make()); - - LiteRtSubgraphT subgraph; - - const auto type = MakeRankedTensorType(kLiteRtElementTypeFloat32, kDims); - - auto& add1_input = subgraph.EmplaceTensor(); - add1_input.SetType(type); - auto& add1_output = subgraph.EmplaceTensor(); - add1_output.SetType(type); - auto& add1 = subgraph.EmplaceOp(); - add1.SetOpCode(kLiteRtOpCodeTflAdd); - - internal::AttachInput(&add1_input, add1); - internal::AttachInput(&add1_input, add1); - internal::AttachOutput(&add1_output, add1); - - auto& mul_output = subgraph.EmplaceTensor(); - mul_output.SetType(type); - 
auto& mul = subgraph.EmplaceOp(); - mul.SetOpCode(kLiteRtOpCodeTflMul); - - internal::AttachInput(&add1_output, mul); - internal::AttachOutput(&mul_output, mul); - - auto& add2_output = subgraph.EmplaceTensor(); - add2_output.SetType(type); - auto& add2 = subgraph.EmplaceOp(); - add2.SetOpCode(kLiteRtOpCodeTflAdd); - - internal::AttachInput(&mul_output, add2); - internal::AttachInput(&mul_output, add2); - internal::AttachOutput(&add2_output, add2); - - Subgraph litert_subgraph(&subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, - op_alloc, litert_subgraph); - - ASSERT_TRUE(ops); - - ASSERT_EQ(ops->size(), 2); - EXPECT_EQ(ops->front(), &add1); - EXPECT_EQ(ops->back(), &add2); -} - -TEST(PartitionWithCapabilitiesTest, WithGeneralResult) { - static constexpr std::array kDims = {2, 2}; - - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeAdd::Make()); - - LiteRtSubgraphT subgraph; - - const auto type = MakeRankedTensorType(kLiteRtElementTypeFloat32, kDims); - - auto& add1_input = subgraph.EmplaceTensor(); - add1_input.SetType(type); - auto& add1_output = subgraph.EmplaceTensor(); - add1_output.SetType(type); - auto& add1 = subgraph.EmplaceOp(); - add1.SetOpCode(kLiteRtOpCodeTflAdd); - - internal::AttachInput(&add1_input, add1); - internal::AttachInput(&add1_input, add1); - internal::AttachOutput(&add1_output, add1); - - tflite::AddOptionsT add_opts; - add_opts.fused_activation_function = tflite::ActivationFunctionType_RELU; - internal::TflOptions tfl_opts; - tfl_opts.Set(std::move(add_opts)); - litert::internal::SetTflOptions(add1, std::move(tfl_opts)); - - Subgraph litert_subgraph(&subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, - op_alloc, litert_subgraph); - - ASSERT_TRUE(ops); - - ASSERT_EQ(ops->size(), 1); - EXPECT_EQ(ops->front(), &add1); -} - -} // namespace - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/examples/BUILD b/tensorflow/lite/experimental/litert/vendors/examples/BUILD deleted file mode 100644 index 16427953b936fe..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/BUILD +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:private"], -) - -litert_dynamic_lib( - name = "example_plugin", - srcs = [ - "example_plugin.cc", - "example_plugin_common.cc", - "example_plugin_common.h", - ], - hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], - export_litert_only = True, - linkstatic = 1, - shared_lib_name = "example_plugin_so", - so_name = "libLiteRtCompilerPlugin_Example.so", - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/cc:litert_op_options", - ], -) - -cc_test( - name = "example_plugin_test", - srcs = [ - "example_plugin_test.cc", - ], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - deps = [ - ":example_plugin", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "example_conversion_impl", - srcs = ["example_conversion_impl.cc"], - hdrs = ["example_conversion_impl.h"], - visibility = ["//tensorflow/lite/experimental/litert/vendors/cc:__pkg__"], - deps = [ - ":example_ir", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/cc:backend_ir", - "//tensorflow/lite/experimental/litert/vendors/cc:conversion", - "//tensorflow/lite/experimental/litert/vendors/cc:ir_types", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "example_conversion_impl_test", - srcs = ["example_conversion_impl_test.cc"], - deps = [ - ":example_conversion_impl", - ":example_ir", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:model_graph", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/cc:conversion", - 
"//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "example_ir", - srcs = ["example_ir.cc"], - hdrs = ["example_ir.h"], - visibility = ["//tensorflow/lite/experimental/litert/vendors/cc:__pkg__"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/vendors/cc:backend_ir", - "//tensorflow/lite/experimental/litert/vendors/cc:ir_types", - ], -) - -cc_library( - name = "example_plugin_with_conversions", - srcs = [ - "example_plugin_common.cc", - "example_plugin_common.h", - "example_plugin_with_conversions.cc", - ], - deps = [ - ":example_conversion_impl", - ":example_ir", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/cc:convert_graph", - "//tensorflow/lite/experimental/litert/vendors/cc:partition_with_capabilities", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_test( - name = "example_plugin_with_conversions_test", - srcs = ["example_plugin_with_conversions_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - deps = [ - ":example_plugin_with_conversions", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc deleted file mode 100644 index fa6e163aee4b77..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -namespace litert::example { - -TensorConverter MakeTensorConverter( - TensorAllocator alloc) { - return [alloc](const Tensor& litert_tensor) -> Expected { - auto& tensor = *alloc(); - tensor.name = litert_tensor.Name(); - - auto litert_type = litert_tensor.RankedTensorType(); - if (!litert_type) { - return Error(litert_type.Error().Status()); - } - - const auto litert_dims = litert_type->Layout().Dimensions(); - - tensor.dims.assign(litert_dims.cbegin(), litert_dims.cend()); - - switch (litert_tensor.RankedTensorType()->ElementType()) { - case ElementType::Float32: - tensor.type = ExampleTensorType::FLOAT; - break; - case ElementType::Int32: - tensor.type = ExampleTensorType::INT; - break; - default: - return Error(kLiteRtStatusErrorInvalidArgument); - } - - return &tensor; - }; -} - -ExampleTypes::Legalizations MakeAllLegalizations() { - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeMul::Make()); - legalizations.push_back(ExampleLegalizeAdd::Make()); - return legalizations; -} - -} // namespace litert::example diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h deleted file mode 100644 index 64f3199bc363df..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_CONVERSION_IMPL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_CONVERSION_IMPL_H_ - -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/ir_types.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -namespace litert::example { - -// Conversion type implementations for the fictional "example" backend. 
- -ExampleTypes::TensorConverter MakeTensorConverter( - ExampleTypes::TensorAllocator alloc); - -static constexpr absl::string_view kIntermediateTensorName = - "intermediate_bin_output"; - -// Example legalization for simple binary ops. -template <LiteRtOpCode LiteRtOpType, ExampleOpType BackendOpType> -class ExampleBinOpLegalization : public Legalization<ExampleOp, ExampleTensor> { - private: - using Self = ExampleBinOpLegalization<LiteRtOpType, BackendOpType>; - - public: - using Ptr = std::unique_ptr<Self>; - - static Ptr Make() { return std::make_unique<Self>(); } - - // Return the litert op code to match on. - LiteRtOpCode OpToMatch() const override { return LiteRtOpType; } - - // Determines if the given litert op has a fused relu attribute. - bool HasFusedRelu(const Op& litert_op) const { - if constexpr (LiteRtOpType != kLiteRtOpCodeTflAdd) { - return false; - } - uint32_t faf; - if (LiteRtGetAddFusedActivationOption(litert_op.Get(), &faf) != - kLiteRtStatusOk) { - return false; - } - // 1 == tflite::ActivationFunctionType_RELU. - return faf == 1; - } - - // Transforms the matched LiteRt binary op into example op definitions; the - // input and output tensors have already been converted by the caller. - ExampleTypes::ConversionResult LegalizeImpl( - const Op& litert_op, const Tensors& inputs, const Tensors& outputs, - ExampleTypes::TensorAllocator tensor_allocator, - ExampleTypes::OpAllocator op_allocator) const override { - ABSL_DCHECK_EQ(litert_op.Code(), LiteRtOpType); - - auto& bin_op = *op_allocator(); - bin_op.op_code = BackendOpType; - - if (inputs.size() != 2 || outputs.size() != 1) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - - for (const auto* input : inputs) { - bin_op.inputs.push_back(input->id); - bin_op.input_names.push_back(input->name); - } - - auto& output_tensor = *outputs.front(); - if (!HasFusedRelu(litert_op)) { - bin_op.outputs.push_back(output_tensor.id); - bin_op.output_names.push_back(output_tensor.name); - return Expected<Result>(&bin_op); - } - - auto* bin_output = tensor_allocator(); - bin_output->dims = output_tensor.dims; - bin_output->type = output_tensor.type; - bin_output->name = std::string(kIntermediateTensorName); - bin_op.outputs.push_back(bin_output->id); - bin_op.output_names.push_back(bin_output->name); - - auto& relu = *op_allocator(); - relu.op_code = ExampleOpType::RELU; - relu.inputs.push_back(bin_output->id); - relu.input_names.push_back(bin_output->name); - relu.outputs.push_back(output_tensor.id); - relu.output_names.push_back(output_tensor.name); - - ExampleTypes::GeneralConversionResult result; - result.ops.push_back(&bin_op); - result.ops.push_back(&relu); - result.intermediate_tensors.push_back(bin_output); - - return ExampleTypes::ConversionResult(result); - } -}; - -using ExampleLegalizeAdd = - ExampleBinOpLegalization<kLiteRtOpCodeTflAdd, ExampleOpType::ADD>; -using ExampleLegalizeMul = - ExampleBinOpLegalization<kLiteRtOpCodeTflMul, ExampleOpType::MUL>; - -ExampleTypes::Legalizations MakeAllLegalizations(); - -} // namespace litert::example - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_CONVERSION_IMPL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc deleted file mode 100644 index 8baf028313eda3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" - -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::example { -namespace { - -using ::testing::ElementsAreArray; -using ::testing::HasSubstr; - -TEST(ExampleConversionImplTest, ConvertTensor) { - static constexpr std::array kDims = {2, 2}; - static constexpr absl::string_view kName = "foo"; - - LiteRtTensorT litert_tensor; - litert_tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - litert_tensor.SetName(std::string(kName)); - - ExampleTensorAllocator tensor_alloc; - auto tensor_convert = MakeTensorConverter(tensor_alloc); - - auto& example_tensor = **tensor_convert(Tensor(&litert_tensor)); - EXPECT_EQ(example_tensor.type, ExampleTensorType::FLOAT); - EXPECT_THAT(example_tensor.dims, ElementsAreArray(kDims)); - EXPECT_EQ(example_tensor.name, kName); -} - -TEST(ExampleConversionImplTest, ExampleGraphBuilder) { - ExampleTensor input; - input.type = ExampleTensorType::FLOAT; - input.dims = {2, 2}; - input.id = 1; - - ExampleTensor output; - output.type = ExampleTensorType::INT; - output.dims = {3, 3}; - output.id = 2; - - ExampleOp op; - op.op_code = ExampleOpType::ADD; - op.inputs = {1}; - op.outputs = {2}; - - ExampleGraphBuilder builder; - static constexpr absl::string_view kName = "FOO_GRAPH"; - - builder.InitGraph(std::string(kName)); - LITERT_ASSERT_OK(builder.RegisterTensor(input)); - LITERT_ASSERT_OK(builder.RegisterOp(op)); - LITERT_ASSERT_OK(builder.RegisterTensor(output)); - LITERT_ASSERT_OK(builder.FinalizeGraph()); - - const auto serialized = builder.Serialize(); - EXPECT_THAT(serialized, HasSubstr("1FLOAT[2, 2]")); - EXPECT_THAT(serialized, HasSubstr("2INT[3, 3]")); - EXPECT_THAT(serialized, HasSubstr("ADD(1)->(2)")); - EXPECT_THAT(serialized, HasSubstr("FINALIZED")); - EXPECT_THAT(serialized, HasSubstr(kName)); -} - -TEST(ExampleConversionImplTest, LegalizeAddSimpleResult) { - static constexpr std::array kDims = {2, 2}; - - LiteRtTensorT input1; - input1.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - input1.SetName("input1"); - - LiteRtTensorT input2; - input2.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - input2.SetName("input2"); - - LiteRtTensorT output; - 
output.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - output.SetName("output"); - - LiteRtOpT op; - op.SetOpCode(kLiteRtOpCodeTflAdd); - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - tflite::AddOptionsT add_opts; - add_opts.fused_activation_function = tflite::ActivationFunctionType_NONE; - internal::TflOptions tfl_opts; - tfl_opts.Set(std::move(add_opts)); - litert::internal::SetTflOptions(op, std::move(tfl_opts)); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - ExampleLegalizeAdd legalize_add; - EXPECT_EQ(legalize_add.OpToMatch(), kLiteRtOpCodeTflAdd); - - auto legalized = - legalize_add.Legalize(Op(&op), MakeTensorConverter, MakeTensorConverter, - tensor_alloc, op_alloc); - - ASSERT_TRUE(legalized); - - auto simple_result = GetSimpleConversionResult(*legalized); - ASSERT_TRUE(simple_result); - auto& example_op = **simple_result; - - EXPECT_EQ(example_op.op_code, ExampleOpType::ADD); - EXPECT_THAT(example_op.inputs, ElementsAreArray({0, 1})); - EXPECT_THAT(example_op.input_names, - ElementsAreArray({input1.Name(), input2.Name()})); - EXPECT_THAT(example_op.outputs, ElementsAreArray({2})); - EXPECT_THAT(example_op.output_names, ElementsAreArray({output.Name()})); -} - -TEST(ExampleConversionImplTest, LegalizeAddGeneralResult) { - static constexpr std::array kDims = {2, 2}; - LiteRtTensorT input1; - input1.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - input1.SetName("input1"); - - LiteRtTensorT input2; - input2.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - input2.SetName("input2"); - - LiteRtTensorT output; - output.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - output.SetName("output"); - - LiteRtOpT op; - op.SetOpCode(kLiteRtOpCodeTflAdd); - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - tflite::AddOptionsT add_opts; - add_opts.fused_activation_function = tflite::ActivationFunctionType_RELU; - internal::TflOptions tfl_opts; - tfl_opts.Set(std::move(add_opts)); - litert::internal::SetTflOptions(op, std::move(tfl_opts)); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto legalize_add = ExampleLegalizeAdd::Make(); - EXPECT_EQ(legalize_add->OpToMatch(), kLiteRtOpCodeTflAdd); - - auto legalized = - legalize_add->Legalize(Op(&op), MakeTensorConverter, MakeTensorConverter, - tensor_alloc, op_alloc); - ASSERT_TRUE(legalized); - - auto gen_result = GetGeneralConversionResult(*legalized); - ASSERT_TRUE(gen_result); - - ASSERT_EQ(gen_result->ops.size(), 2); - EXPECT_EQ(gen_result->ops[0]->op_code, ExampleOpType::ADD); - EXPECT_THAT(gen_result->ops[0]->inputs, ElementsAreArray({0, 1})); - EXPECT_THAT(gen_result->ops[0]->input_names, - ElementsAreArray({input1.Name(), input2.Name()})); - EXPECT_THAT(gen_result->ops[0]->outputs, ElementsAreArray({3})); - EXPECT_THAT(gen_result->ops[0]->output_names, - ElementsAreArray({kIntermediateTensorName})); - EXPECT_EQ(gen_result->ops[1]->op_code, ExampleOpType::RELU); - EXPECT_THAT(gen_result->ops[1]->inputs, ElementsAreArray({3})); - EXPECT_THAT(gen_result->ops[1]->input_names, - ElementsAreArray({kIntermediateTensorName})); - EXPECT_THAT(gen_result->ops[1]->outputs, ElementsAreArray({2})); - EXPECT_THAT(gen_result->ops[1]->output_names, - ElementsAreArray({output.Name()})); - 
EXPECT_EQ(gen_result->intermediate_tensors.size(), 1); - EXPECT_EQ(gen_result->intermediate_tensors.front()->id, 3); - EXPECT_EQ(gen_result->intermediate_tensors.front()->name, - kIntermediateTensorName); -} - -} // namespace - -} // namespace litert::example diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc deleted file mode 100644 index da06b617d9f15b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -namespace litert::example { - -namespace { - -template -void PrintWithCommas(It start, It end, std::ostream& out) { - for (auto it = start; it < end; ++it) { - out << std::to_string(*it); - if (it != end - 1) { - out << ", "; - } - } -} - -} // namespace - -LiteRtStatus ExampleGraphBuilder::RegisterOp(ExampleOp& op) { - switch (op.op_code) { - case ExampleOpType::ADD: - example_graph_ << "ADD"; - break; - case ExampleOpType::MUL: - example_graph_ << "MUL"; - break; - case ExampleOpType::RELU: - example_graph_ << "RELU"; - break; - } - example_graph_ << "("; - PrintWithCommas(op.inputs.cbegin(), op.inputs.cend(), example_graph_); - example_graph_ << ")->("; - PrintWithCommas(op.outputs.cbegin(), op.outputs.cend(), example_graph_); - example_graph_ << ")"; - return kLiteRtStatusOk; -} - -LiteRtStatus ExampleGraphBuilder::RegisterTensor(ExampleTensor& tensor) { - example_graph_ << std::to_string(tensor.id); - switch (tensor.type) { - case ExampleTensorType::FLOAT: - example_graph_ << "FLOAT"; - break; - case ExampleTensorType::INT: - example_graph_ << "INT"; - break; - } - example_graph_ << "["; - PrintWithCommas(tensor.dims.cbegin(), tensor.dims.cend(), example_graph_); - example_graph_ << "]"; - return kLiteRtStatusOk; -} - -LiteRtStatus ExampleGraphBuilder::FinalizeGraph() { - example_graph_ << "FINALIZED"; - return kLiteRtStatusOk; -} - -void ExampleGraphBuilder::InitGraph(std::string graph_name) { - example_graph_ << "name=" << graph_name << "\n"; -} - -std::string ExampleGraphBuilder::Serialize() const { - return example_graph_.str(); -} - -} // namespace litert::example diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.h b/tensorflow/lite/experimental/litert/vendors/examples/example_ir.h deleted file mode 100644 index e423a53f382b8d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.h +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_IR_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_IR_H_
-
-#include <cstdint>
-#include <list>
-#include <sstream>
-#include <string>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h"
-#include "tensorflow/lite/experimental/litert/vendors/cc/ir_types.h"
-
-namespace litert::example {
-
-// Example IR wrapper types for an imaginary backend.
-
-// Example backend knows only float and int32.
-enum class ExampleTensorType {
-  FLOAT,
-  INT,
-};
-
-// Example backend tensor wrapper that stores the type and shape and unique ID.
-struct ExampleTensor {
-  using Id = int32_t;
-  ExampleTensorType type;
-  std::vector<int32_t> dims;
-  std::string name;
-  Id id = -1;
-};
-
-// Example backend knows only a few simple ops.
-enum class ExampleOpType {
-  ADD,
-  MUL,
-  RELU,
-};
-
-// Example backend op that stores op type as well as input and output tensor
-// IDs and names.
-struct ExampleOp {
-  ExampleOpType op_code;
-  std::vector<ExampleTensor::Id> inputs;
-  std::vector<std::string> input_names;
-  std::vector<ExampleTensor::Id> outputs;
-  std::vector<std::string> output_names;
-};
-
-// Simple allocators for example IR types that provide pointer stability.
-template <class Ir>
-class ExampleIrAllocatorBase {
- public:
-  ExampleIrAllocatorBase(const ExampleIrAllocatorBase&) = delete;
-  ExampleIrAllocatorBase& operator=(const ExampleIrAllocatorBase&) = delete;
-  ExampleIrAllocatorBase() = default;
-
- protected:
-  std::list<Ir> ir_;
-};
-
-// Allocator for example tensors that provides pointer stability and unique IDs.
-class ExampleTensorAllocator : public ExampleIrAllocatorBase<ExampleTensor> {
- private:
-  using Alloc = BackendIrAllocator<ExampleTensor>;
-
- public:
-  ExampleTensor* operator()() {
-    auto& tensor = this->ir_.emplace_back();
-    tensor.id = this->next_id_++;
-    return &tensor;
-  }
-
-  // Return lambda instead of implicit copy construction when converting to
-  // function type.
-  // NOLINTNEXTLINE
-  operator Alloc() {
-    return [this]() { return this->operator()(); };
-  }
-
-  ExampleTensorAllocator(const ExampleTensorAllocator&) = delete;
-  ExampleTensorAllocator& operator=(const ExampleTensorAllocator&) = delete;
-  ExampleTensorAllocator() = default;
-
- private:
-  uint32_t next_id_ = 0;
-};
-
-// Allocator for example ops that provides pointer stability.
-class ExampleOpAllocator : public ExampleIrAllocatorBase<ExampleOp> {
- private:
-  using Alloc = BackendIrAllocator<ExampleOp>;
-
- public:
-  ExampleOp* operator()() { return &this->ir_.emplace_back(); }
-
-  // Return lambda instead of implicit copy construction when converting to
-  // function type.
-  // NOLINTNEXTLINE
-  operator Alloc() {
-    return [this]() { return this->operator()(); };
-  }
-
-  ExampleOpAllocator(const ExampleOpAllocator&) = delete;
-  ExampleOpAllocator& operator=(const ExampleOpAllocator&) = delete;
-  ExampleOpAllocator() = default;
-};
-
-// Builder for graph conversion to example IR. The internal example IR graph is
-// simply a string representation of the graph.
-class ExampleGraphBuilder
-    : public BackendGraphBuilder<ExampleOp, ExampleTensor> {
- public:
-  // Prefixes the IR string with the graph name.
-  void InitGraph(std::string graph_name) override;
-
-  // Registers a tensor into the current graph by simply appending its string
-  // representation.
-  LiteRtStatus RegisterTensor(ExampleTensor& tensor) override;
-
-  // Registers an op into the current graph by simply appending its string
-  // representation.
-  LiteRtStatus RegisterOp(ExampleOp& op) override;
-
-  // Simply appends a tag to the IR string.
-  LiteRtStatus FinalizeGraph() override;
-
-  // Gets the serialized IR representation.
-  std::string Serialize() const;
-
- private:
-  std::stringstream example_graph_;
-};
-
-using ExampleTypes = IrTypes<ExampleOp, ExampleTensor>;
-
-}  // namespace litert::example
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_IR_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc
deleted file mode 100644
index dff15a4490ec9d..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include
-#include
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h"
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h"
-#include "tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h"
-
-// A simple compiler plugin example that implements everything directly.
-// This plugin matches on mul ops, and emits "byte code" that is simply
-// a string representative of the ops consumed.
-
-// Plugins can hold state.
-struct LiteRtCompilerPluginT {}; - -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { - *compiler_plugin = new LiteRtCompilerPluginT; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { - delete compiler_plugin; -} - -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - ::litert::Subgraph main_subgraph(subgraph); - for (const auto& op : main_subgraph.Ops()) { - if (op.Code() == kLiteRtOpCodeTflMul) { - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 0)); - } else if (op.Code() == kLiteRtOpCodeTflSub) { - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 1)); - } else if (op.Code() == kLiteRtOpCodeShloComposite) { - const auto opts = - litert::GetOptionsAs(op.Get()); - if (!opts) { - return opts.Error().Status(); - } - if (opts->name == "odml.rms_norm") { - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 0)); - } - } - } - return kLiteRtStatusOk; -} - -namespace { - -LiteRtStatus CompileSinglePartition(LiteRtParamIndex partition_index, - LiteRtSubgraph subgraph, - LiteRtCompiledResultT& result, - int byte_code_idx) { - const litert::Subgraph sg(subgraph); - int num_muls_in_partition = 0; - for (const auto& op : sg.Ops()) { - if (op.Code() != kLiteRtOpCodeTflMul && op.Code() != kLiteRtOpCodeTflSub) { - return kLiteRtStatusErrorUnsupported; - } - if (op.Code() == kLiteRtOpCodeTflMul) { - ++num_muls_in_partition; - } - } - - { - char* byte_code_append; - (void)asprintf(&byte_code_append, - "Partition_%lu_with_%d_muls:", partition_index, - num_muls_in_partition); - result.byte_code[byte_code_idx].append(byte_code_append); - free(byte_code_append); - } - - { - char* per_op_data; - (void)asprintf(&per_op_data, "Partition_%lu", partition_index); - result.per_op_data.push_back(per_op_data); - free(per_op_data); - } - - return kLiteRtStatusOk; -} - -} // namespace - -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - auto result = std::make_unique(); - result->byte_code.resize(num_partitions); - for (auto i = 0; i < num_partitions; ++i) { - LITERT_RETURN_IF_ERROR( - CompileSinglePartition(i, model.Subgraph(i)->Get(), *result, i)); - } - - *compiled_result = result.release(); - - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc deleted file mode 100644 index 19c84dc55e7869..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
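The file above packs the whole example-plugin lifecycle into a handful of C entry points. As a rough caller-side sketch of that lifecycle, RunExamplePlugin below is a hypothetical wrapper (not part of the API) with abbreviated error handling; only the LiteRt* calls are the ones defined above:

// Hypothetical driver: create -> compile -> inspect byte code -> destroy.
LiteRtStatus RunExamplePlugin(LiteRtModel partitioned_model) {
  LiteRtCompilerPlugin plugin = nullptr;
  LiteRtStatus status = LiteRtCreateCompilerPlugin(&plugin);
  if (status != kLiteRtStatusOk) return status;

  LiteRtCompiledResult compiled = nullptr;
  status = LiteRtCompilerPluginCompile(plugin, /*soc_model=*/nullptr,
                                       partitioned_model, &compiled);
  if (status == kLiteRtStatusOk) {
    const void* byte_code = nullptr;
    size_t byte_code_size = 0;
    // For this example plugin the "byte code" is a string such as
    // "Partition_0_with_2_muls:".
    status = LiteRtGetCompiledResultByteCode(compiled, /*byte_code_idx=*/0,
                                             &byte_code, &byte_code_size);
    LiteRtDestroyCompiledResult(compiled);
  }
  LiteRtDestroyCompilerPlugin(plugin);
  return status;
}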
- -#include "tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h" - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -// -// Configurations -// - -namespace litert::example { -namespace { - -constexpr char kPluginManufacturer[] = "ExampleSocManufacturer"; -constexpr char kPluginSocModel[] = "ExampleSocModel"; - -} // namespace -} // namespace litert::example - -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values) { - // IMPLEMENT ME - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { - if (!api_version) { - return kLiteRtStatusErrorInvalidArgument; - } - api_version->major = LITERT_API_VERSION_MAJOR; - api_version->minor = LITERT_API_VERSION_MINOR; - api_version->patch = LITERT_API_VERSION_PATCH; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( - LiteRtCompilerPlugin compiler_plugin, - LiteRtHwAccelerators* supported_hardware) { - if (!compiler_plugin || !supported_hardware) { - return kLiteRtStatusErrorInvalidArgument; - } - *supported_hardware = kLiteRtHwAcceleratorCpu; - return kLiteRtStatusOk; -} - -const char* LiteRtGetCompilerPluginSocManufacturer() { - return litert::example::kPluginManufacturer; -} - -LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( - LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex* num_supported_soc_models) { - if (!compiler_plugin || !num_supported_soc_models) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_supported_soc_models = 1; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, - const char** soc_model_name) { - if (!compiler_plugin || !soc_model_name) { - return kLiteRtStatusErrorInvalidArgument; - } else if (soc_model_idx != 0) { - return kLiteRtStatusErrorUnsupported; - } - *soc_model_name = litert::example::kPluginSocModel; - return kLiteRtStatusOk; -} - -// -// Compiled Result Definition -// - -LiteRtStatus LiteRtGetCompiledResultByteCode( - LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx, - const void** byte_code, size_t* byte_code_size) { - if (!compiled_result) { - return kLiteRtStatusErrorInvalidArgument; - } - *byte_code = compiled_result->byte_code[byte_code_idx].data(); - *byte_code_size = compiled_result->byte_code[byte_code_idx].size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledResultCallInfo( - LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, - const void** call_info, size_t* call_info_size, - LiteRtParamIndex* byte_code_idx) { - if (call_idx >= compiled_result->per_op_data.size()) { - return kLiteRtStatusErrorIndexOOB; - } - *call_info = compiled_result->per_op_data.at(call_idx).data(); - *call_info_size = compiled_result->per_op_data.at(call_idx).size(); - *byte_code_idx = 0; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompiledResultCalls( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { - if (!compiled_result) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_calls = compiled_result->per_op_data.size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCompiledResultNumByteCodeModules( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* 
num_byte_code) { - *num_byte_code = compiled_result->byte_code.size(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { - delete compiled_result; -} diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h deleted file mode 100644 index cc7c0f60df4e85..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_PLUGIN_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_PLUGIN_COMMON_H_ - -#include -#include - -// Simple compiled result def holds byte code and per op data. -struct LiteRtCompiledResultT { - std::vector byte_code; - std::vector per_op_data; -}; - -namespace litert::example {} // namespace litert::example - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_PLUGIN_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc deleted file mode 100644 index 3b1b098ff62bfa..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
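The compiled-result accessors in example_plugin_common.cc pair naturally: a consumer walks the per-op entries by index and uses the returned byte_code_idx to select the byte code module each call dispatches into. A minimal sketch, assuming only the accessors above (DumpCallInfo is an illustrative name):

// Enumerate per-op call info; for this example plugin byte_code_idx is
// always 0 and each entry is a string like "Partition_0".
void DumpCallInfo(LiteRtCompiledResult compiled) {
  LiteRtParamIndex num_calls = 0;
  if (LiteRtGetNumCompiledResultCalls(compiled, &num_calls) !=
      kLiteRtStatusOk) {
    return;
  }
  for (LiteRtParamIndex i = 0; i < num_calls; ++i) {
    const void* call_info = nullptr;
    size_t call_info_size = 0;
    LiteRtParamIndex byte_code_idx = 0;
    if (LiteRtGetCompiledResultCallInfo(compiled, i, &call_info,
                                        &call_info_size,
                                        &byte_code_idx) != kLiteRtStatusOk) {
      continue;
    }
    // info now names the partition this call belongs to.
    absl::string_view info(static_cast<const char*>(call_info),
                           call_info_size);
  }
}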
- -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" - -namespace litert { -namespace { - -TEST(TestDummyPlugin, GetConfigInfo) { - ASSERT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), - "ExampleSocManufacturer"); - - auto plugin = CreatePlugin(); - - LiteRtParamIndex num_supported_soc_models; - LITERT_ASSERT_OK(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models)); - ASSERT_EQ(num_supported_soc_models, 1); - - const char* soc_model_name; - LITERT_ASSERT_OK(LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, - &soc_model_name)); - ASSERT_STREQ(soc_model_name, "ExampleSocModel"); -} - -TEST(TestCallDummyPlugin, PartitionSimpleMultiAdd) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("simple_multi_op.tflite"); - - LiteRtOpListT selected_op_list; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, model.Subgraph(0)->Get(), - &selected_op_list)); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 2); - ASSERT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflMul); - ASSERT_EQ(selected_ops[1].first->OpCode(), kLiteRtOpCodeTflMul); -} - -TEST(TestCallDummyPlugin, CompileMulSubgraph) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("mul_simple.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK(LiteRtCompilerPluginCompile( - plugin.get(), /*soc_model=*/nullptr, model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - - absl::string_view byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - ASSERT_EQ(byte_code_string, "Partition_0_with_2_muls:"); - - LiteRtParamIndex byte_code_idx; - const void* op_data; - size_t op_data_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, /*call_idx=*/0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - ASSERT_EQ(op_data_string, "Partition_0"); - - LiteRtDestroyCompiledResult(compiled); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc deleted file mode 100644 index 22f11167c2cc5a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "absl/strings/str_format.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h"
-#include "tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h"
-#include "tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h"
-#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h"
-#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h"
-#include "tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h"
-
-using ::litert::PartitionWithCapabilities;
-using ::litert::example::ExampleGraphBuilder;
-using ::litert::example::ExampleOpAllocator;
-using ::litert::example::ExampleOpType;
-using ::litert::example::ExampleTensorAllocator;
-using ::litert::example::ExampleTypes;
-using ::litert::example::MakeAllLegalizations;
-using ::litert::example::MakeTensorConverter;
-
-// Example plugin implementations that leverage the pluggable conversion
-// infrastructure. Implementations of common interfaces are provided in
-// example_conversion_impl.h. These are passed to higher-level litert functions
-// to perform the actual conversion.
-// The primary benefit of this approach is the re-use of conversion logic
-// between the partition and compile phases.
-
-// Plugins can hold state.
-struct LiteRtCompilerPluginT {
-  ExampleTypes::Legalizations legalizations;
-};
-
-namespace {
-
-bool MulCapability(const ExampleTypes::Op* op) {
-  return op->op_code == ExampleOpType::MUL;
-}
-
-}  // namespace
-
-// Initialize example plugin and register legalizations.
-LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) {
-  auto* plugin = new LiteRtCompilerPluginT;
-  plugin->legalizations = MakeAllLegalizations();
-  *compiler_plugin = plugin;
-  return kLiteRtStatusOk;
-}
-
-void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) {
-  delete compiler_plugin;
-}
-
-// Leverage the convert_types PartitionWithCapabilities algorithm for the
-// partitioning implementation.
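Since the capability predicate is the only policy hook in this flow, widening what the plugin claims is a one-function change. A hypothetical variant that also admits ADD ops (not part of this plugin):

// Hypothetical: claim both MUL and ADD ops; the partition entry point below
// would simply take this predicate instead of MulCapability.
bool MulOrAddCapability(const ExampleTypes::Op* op) {
  return op->op_code == ExampleOpType::MUL ||
         op->op_code == ExampleOpType::ADD;
}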
-LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - compiler_plugin->legalizations, MulCapability, MakeTensorConverter, - tensor_alloc, op_alloc, ::litert::Subgraph(subgraph)); - if (!ops) { - return ops.Error().Status(); - } - - for (auto* op : *ops) { - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op, 0)); - } - - return kLiteRtStatusOk; -} - -namespace { - -LiteRtStatus CompileSinglePartition( - const ExampleTypes::Legalizations& legalizations, std::string name, - LiteRtSubgraph subgraph, LiteRtCompiledResultT& result) { - ::litert::Subgraph litert_subgraph(subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - ExampleGraphBuilder builder; - - LITERT_RETURN_IF_ERROR(::litert::ConvertGraph( - litert_subgraph, name, MakeTensorConverter, tensor_alloc, op_alloc, - legalizations, builder)); - - // This example plugin only supports a single byte code module. - result.byte_code[0].append(builder.Serialize()); - result.per_op_data.push_back(std::move(name)); - - return kLiteRtStatusOk; -} - -} // namespace - -// Plugin compiler implementation that leverages the pluggable convert_types -// infrastructure. -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - auto result = std::make_unique(); - result->byte_code.resize(num_partitions); - for (auto i = 0; i < num_partitions; ++i) { - auto name = absl::StrFormat("partition_%lu", i); - LITERT_RETURN_IF_ERROR( - CompileSinglePartition(compiler_plugin->legalizations, std::move(name), - model.Subgraph(i)->Get(), *result)); - } - - *compiled_result = result.release(); - - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc deleted file mode 100644 index 10c7928cab629f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
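For reference, the string that CompileSinglePartition appends as byte code comes straight from ExampleGraphBuilder (see example_ir.cc above). Driving the builder by hand on a made-up one-tensor graph shows the format; the tensor and op values here are illustrative:

// Minimal sketch of the builder output format.
ExampleGraphBuilder builder;
builder.InitGraph("partition_0");

ExampleTensor t0;
t0.type = ExampleTensorType::FLOAT;
t0.dims = {2, 2};
t0.id = 0;
builder.RegisterTensor(t0);

ExampleOp mul;
mul.op_code = ExampleOpType::MUL;
mul.inputs = {0};
mul.outputs = {0};
builder.RegisterOp(mul);

builder.FinalizeGraph();
// Serialize() now returns "name=partition_0\n0FLOAT[2, 2]MUL(0)->(0)FINALIZED".
std::string serialized = builder.Serialize();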
- -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" - -namespace litert { -namespace { - -using ::testing::HasSubstr; - -TEST(ExamplePluginWithConvertTypesTest, GetConfigInfo) { - ASSERT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), - "ExampleSocManufacturer"); - - auto plugin = CreatePlugin(); - - LiteRtParamIndex num_supported_soc_models; - LITERT_ASSERT_OK(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models)); - ASSERT_EQ(num_supported_soc_models, 1); - - const char* soc_model_name; - LITERT_ASSERT_OK(LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, - &soc_model_name)); - ASSERT_STREQ(soc_model_name, "ExampleSocModel"); -} - -TEST(ExamplePluginWithConvertTypesTest, PartitionSimpleMultiAdd) { - auto plugin = CreatePlugin(); - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - - LiteRtOpListT selected_op_list; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, model.Get()->MainSubgraph(), - &selected_op_list)); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 2); - ASSERT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflMul); - ASSERT_EQ(selected_ops[1].first->OpCode(), kLiteRtOpCodeTflMul); -} - -TEST(ExamplePluginWithConvertTypesTest, CompileMulSubgraph) { - static constexpr absl::string_view kName = "partition_0"; - - auto plugin = CreatePlugin(); - auto model = litert::testing::LoadTestFileModel("mul_simple.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK(LiteRtCompilerPluginCompile( - plugin.get(), /*soc_model=*/nullptr, model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - absl::string_view byte_code_str(reinterpret_cast(byte_code), - byte_code_size); - - EXPECT_THAT(byte_code_str, HasSubstr(kName)); - EXPECT_THAT(byte_code_str, HasSubstr("0FLOAT[2, 2]")); - EXPECT_THAT(byte_code_str, HasSubstr("1FLOAT[2, 2]")); - EXPECT_THAT(byte_code_str, HasSubstr("2FLOAT[2, 2]")); - EXPECT_THAT(byte_code_str, HasSubstr("MUL")); - EXPECT_THAT(byte_code_str, HasSubstr("FINALIZED")); - - LiteRtParamIndex num_call_infos; - LITERT_ASSERT_OK(LiteRtGetNumCompiledResultCalls(compiled, &num_call_infos)); - - ASSERT_EQ(num_call_infos, 1); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, 0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_str(reinterpret_cast(op_data), - op_data_size); - EXPECT_EQ(op_data_str, kName); - - LiteRtDestroyCompiledResult(compiled); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/BUILD b/tensorflow/lite/experimental/litert/vendors/google_tensor/BUILD deleted file mode 100644 index 
e66b1c1b84ec95..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/BUILD +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_lib", "make_rpaths") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -litert_lib( - name = "adapter", - srcs = ["adapter.cc"], - hdrs = ["adapter.h"], - data = [ - # copybara:uncomment_begin(google-only) - # "//platforms/darwinn/compiler:compiler_api_wrapper", - # copybara:uncomment_end - ], - linkopts = [ - # TODO(abhirs): Make this work for OS. - #copybara:comment_begin(google-only) - make_rpaths(["platforms/darwinn/compiler"]), - # copybara:uncomment_end - ], - tags = [ - # Don't build/test in OS until libcompiler_api_wrapper.so is available. - "nobuilder", - "no_oss", - ], - ungrte = False, - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:dynamic_loading", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -cc_test( - name = "adapter_test", - srcs = ["adapter_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - linkstatic = 1, - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - #Don't build/test in OS until libcompiler_api_wrapper.so is available. - "nobuilder", - "no_oss", - # Sanitizer runtime doesn't work with anything that uses dlopen. - "nosan", - "manual", - ], - # This test can only be run on Android and Linux. - target_compatible_with = select({ - "@platforms//os:android": [], - "@platforms//os:linux": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - deps = [ - ":adapter", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/test:common", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.cc deleted file mode 100644 index b0e1c25c591e57..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.cc +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h" - -#include - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { -namespace google_tensor { - -Adapter::Adapter() : api_(new Api) {} - -Adapter::~Adapter() { - if (dlib_handle_) { - dlclose(dlib_handle_); // Use dlclose directly - } -} - -litert::Expected Adapter::Create( - std::optional shared_library_dir) { - Ptr adapter(new Adapter); - auto status = adapter->LoadSymbols(shared_library_dir); - if (!status.HasValue()) { - LITERT_LOG(LITERT_ERROR, "Failed to create Adapter: %s", - status.Error().Message().c_str()); - return status.Error(); - } - return adapter; -} - -litert::Expected Adapter::LoadSymbols( - std::optional shared_library_dir) { - constexpr auto kLibTensorTPUCompiler = "libcompiler_api_wrapper.so"; - - const std::vector so_paths = { - shared_library_dir.has_value() - ? absl::StrCat(*shared_library_dir, "/", kLibTensorTPUCompiler) - : kLibTensorTPUCompiler}; - - // Use dlopen directly - for (const auto& path : so_paths) { - dlib_handle_ = dlopen(path.c_str(), RTLD_LAZY | RTLD_LOCAL); - if (dlib_handle_) { - void* init_func = dlsym(dlib_handle_, "Initialize"); - if (init_func) { - (*reinterpret_cast(init_func))(); - } - break; // Found the library - } - } - - if (!dlib_handle_) { - const std::string error_message = - "Failed to load Tensor TPU compiler library: " + std::string(dlerror()); - LITERT_LOG(LITERT_ERROR, "Failed to load Tensor TPU compiler library: %s", - error_message.c_str()); // Include dlerror() for more info - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, error_message); - } - - api_->compile = - reinterpret_cast(dlsym(dlib_handle_, "CompileFlatbuffer")); - if (!api_->compile) { - const std::string error_message = - "Failed to load Tensor TPU compiler API: " + std::string(dlerror()); - LITERT_LOG(LITERT_ERROR, "Failed to load Tensor TPU compiler API: %s", - error_message.c_str()); // Include dlerror() - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, error_message); - } - - LITERT_LOG(LITERT_INFO, "Tensor TPU compiler API symbols loaded"); - return {}; -} - -} // namespace google_tensor -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h deleted file mode 100644 index 37a88a840c793f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_ADAPTER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_ADAPTER_H_
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/log/log.h"
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-namespace litert::google_tensor {
-
-// Flags is a vector of key-value pairs, where the key is the flag name and
-// the value is the flag value, e.g. {{"enable_reference", "true"}}.
-using Flags = std::vector<std::pair<std::string, std::string>>;
-typedef absl::Status (*Compile)(absl::string_view serialized_tfl_buffer,
-                                absl::string_view soc_model, const Flags& flags,
-                                std::string* compiled_code);
-
-// This class adapts the Google Tensor compiler API for dynamic loading.
-class Adapter {
- public:
-  // A smart pointer for managing Adapter objects.
-  using Ptr = std::unique_ptr<Adapter>;
-  struct Api;
-
-  Adapter();
-  ~Adapter();
-
-  // Creates a new Adapter and loads the compiler API symbols.
-  static litert::Expected<Ptr> Create(
-      std::optional<std::string> shared_library_dir);
-
-  // Returns a reference to the loaded API.
-  const Api& api() const { return *api_; }
-
- private:
-  // Loads the symbols from the compiler library.
-  litert::Expected<void> LoadSymbols(
-      std::optional<std::string> shared_library_dir);
-
-  void* dlib_handle_ = nullptr;
-  std::unique_ptr<Api> api_;
-};
-
-struct Adapter::Api {
-  // The function pointer to the compiler wrapper API.
-  Compile compile = nullptr;
-};
-
-}  // namespace litert::google_tensor
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_ADAPTER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter_test.cc
deleted file mode 100644
index 55872dfeb1160b..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter_test.cc
+++ /dev/null
@@ -1,92 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
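The dlopen/dlsym dance that Adapter::LoadSymbols (adapter.cc above) wraps reduces to a small generic pattern. A sketch with placeholder library and symbol names, not the actual Google Tensor API:

#include <dlfcn.h>

// "libfoo.so" and "FooCompile" are placeholders for whatever library and
// symbol an adapter needs to resolve.
using FooCompileFn = int (*)(const char* input, char** output);

FooCompileFn LoadFooCompile(void** out_handle) {
  void* handle = dlopen("libfoo.so", RTLD_LAZY | RTLD_LOCAL);
  if (!handle) return nullptr;  // dlerror() describes the failure.
  auto fn = reinterpret_cast<FooCompileFn>(dlsym(handle, "FooCompile"));
  if (!fn) {
    dlclose(handle);
    return nullptr;
  }
  *out_handle = handle;  // Caller dlclose()s it later, as ~Adapter() does.
  return fn;
}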
-
-#include "tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h"
-
-#include
-
-#include
-#include
-
-#include
-#include
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/test/common.h"
-
-namespace litert {
-namespace google_tensor {
-
-TEST(AdapterTest, CreateSuccess) {
-  auto adapter_result = Adapter::Create(/*shared_library_dir=*/
-                                        std::nullopt);
-  if (!adapter_result.HasValue()) {
-    LITERT_LOG(LITERT_ERROR, "Failed to create Adapter: %s",
-               adapter_result.Error().Message().c_str());
-  }
-  ASSERT_TRUE(adapter_result.HasValue());
-}
-
-TEST(AdapterTest, CreateFailure) {
-  auto kLibDarwinnCompilerNoLib = "libcompiler_api_wrapper_no_lib.so";
-  auto adapter_result = Adapter::Create(kLibDarwinnCompilerNoLib);
-  ASSERT_FALSE(adapter_result.HasValue());
-}
-
-TEST(AdapterTest, CompileSuccess) {
-  auto adapter_result = Adapter::Create(/*shared_library_dir=*/
-                                        std::nullopt);
-  if (!adapter_result.HasValue()) {
-    LITERT_LOG(LITERT_ERROR, "Failed to create Adapter: %s",
-               adapter_result.Error().Message().c_str());
-  }
-
-  auto model = litert::testing::LoadTestFileModel("mul_simple.tflite");
-  LiteRtModel litert_model = model.Get();
-
-  LITERT_LOG(LITERT_INFO, "%s", "Serializing model");
-  litert::OwningBufferRef<uint8_t> buf;
-
-  // Using weak pointers to link the data to the buffer.
-  auto [data, size, offset] = buf.GetWeak();
-
-  const auto opts = litert::SerializationOptions::Defaults();
-  auto status =
-      LiteRtSerializeModel(litert_model, &data, &size, &offset, false, opts);
-  if (status != kLiteRtStatusOk) {
-    LITERT_LOG(LITERT_ERROR, "Failed to serialize model");
-  }
-
-  absl::string_view buffer_str(reinterpret_cast<const char*>(buf.Data()),
-                               buf.Size());
-
-  ASSERT_FALSE(buffer_str.empty());
-  LITERT_LOG(LITERT_INFO, "buffer_str size: %zu", buffer_str.size());
-  LITERT_LOG(LITERT_INFO, "Compiling model...");
-  absl::string_view soc_model = "P25";
-  litert::google_tensor::Flags flags;
-  flags.clear();
-  std::string compiled_code;
-  auto compile_status = adapter_result.Value()->api().compile(
-      buffer_str, soc_model, flags, &compiled_code);
-  ASSERT_OK(compile_status);
-  ASSERT_FALSE(compiled_code.empty());
-}
-
-}  // namespace google_tensor
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/BUILD b/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/BUILD
deleted file mode 100644
index eaf46ce867d1b6..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/BUILD
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright 2024 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
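The CompileSuccess test above leans on an idiom worth spelling out: OwningBufferRef::GetWeak() hands back references to the buffer's internal pointer, size, and offset, so LiteRtSerializeModel writes its output directly into a buffer the caller owns. Condensed, assuming a uint8_t element type and a valid litert_model, with error handling elided:

// Serialize into a caller-owned buffer, then view it as a string.
litert::OwningBufferRef<uint8_t> buf;
auto [data, size, offset] = buf.GetWeak();
const auto opts = litert::SerializationOptions::Defaults();
// The `false` mirrors the argument used in the test above.
(void)LiteRtSerializeModel(litert_model, &data, &size, &offset, false, opts);
absl::string_view flatbuffer(reinterpret_cast<const char*>(buf.Data()),
                             buf.Size());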
-
-load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib")
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = ["//visibility:private"],
-)
-
-litert_dynamic_lib(
-    name = "compiler_plugin",
-    srcs = ["compiler_plugin.cc"],
-    hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"],
-    export_litert_only = True,
-    linkstatic = 1,
-    shared_lib_name = "google_tensor_compiler_plugin_so",
-    so_name = "libLiteRtCompilerPlugin_google_tensor.so",
-    tags = [
-        # Don't build/test in OS until google tensor is available.
-        "nobuilder",
-        "no_oss",
-        "notap",
-    ],
-    visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"],
-    deps = [
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_logging",
-        "//tensorflow/lite/experimental/litert/c:litert_model",
-        "//tensorflow/lite/experimental/litert/c:litert_op_code",
-        "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref",
-        "//tensorflow/lite/experimental/litert/cc:litert_expected",
-        "//tensorflow/lite/experimental/litert/cc:litert_macros",
-        "//tensorflow/lite/experimental/litert/cc:litert_model",
-        "//tensorflow/lite/experimental/litert/vendors/google_tensor:adapter",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
-
-cc_test(
-    name = "compiler_plugin_test",
-    srcs = [
-        "compiler_plugin_test.cc",
-    ],
-    data = [
-        "//tensorflow/lite/experimental/litert/test:mlir_test_data",
-        "//tensorflow/lite/experimental/litert/test:tflite_test_data",
-    ],
-    linkstatic = True,
-    tags = [
-        # Tests with ungrte deps do not currently work on forge.
-        "no-remote-exec",
-        "notap",
-        # Don't build/test in OS until google tensor is available.
-        "nobuilder",
-        "no_oss",
-        # Sanitizer runtime doesn't work with anything that loads a shared library.
-        "nosan",
-        "manual",
-    ],
-    # This test can only be run on Android and Linux.
-    target_compatible_with = select({
-        "@platforms//os:android": [],
-        "@platforms//os:linux": [],
-        "//conditions:default": ["@platforms//:incompatible"],
-    }),
-    deps = [
-        ":compiler_plugin",  # buildcleaner: keep
-        "//tensorflow/lite/experimental/litert/c:litert_common",
-        "//tensorflow/lite/experimental/litert/c:litert_model",
-        "//tensorflow/lite/experimental/litert/c:litert_op_code",
-        "//tensorflow/lite/experimental/litert/cc:litert_model",
-        "//tensorflow/lite/experimental/litert/test:common",
-        "//tensorflow/lite/experimental/litert/test:matchers",
-        "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin.cc
deleted file mode 100644
index 0b5854f03a18b9..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin.cc
+++ /dev/null
@@ -1,360 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h" - -// -// Configurations -// - -namespace google_tensor { - -constexpr char kPluginManufacturer[] = "GoogleTensor"; - -constexpr const char* kPluginSocModels[] = { - "P25", -}; // get the name for plugin soc model - -constexpr LiteRtOpCode kUnSupportedOps[] = { - kLiteRtOpCodeTflAssignVariable, - kLiteRtOpCodeTflBidirectionalSequenceLstm, - kLiteRtOpCodeTflBroadcastArgs, - kLiteRtOpCodeTflBucketize, - kLiteRtOpCodeTflCallOnce, - kLiteRtOpCodeTflComplexAbs, - kLiteRtOpCodeTflConv3d, - kLiteRtOpCodeTflConv3dTranspose, - kLiteRtOpCodeTflDensify, - kLiteRtOpCodeTflFakeQuant, - kLiteRtOpCodeTflHashtable, - kLiteRtOpCodeTflHashtableFind, - kLiteRtOpCodeTflHashtableImport, - kLiteRtOpCodeTflHashtableSize, - kLiteRtOpCodeTflImag, - kLiteRtOpCodeTflLocalResponseNormalization, - kLiteRtOpCodeTflMatrixDiag, - kLiteRtOpCodeTflMatrixSetDiag, - kLiteRtOpCodeTflMultinomial, - kLiteRtOpCodeTflNonMaxSuppressionV4, - kLiteRtOpCodeTflNonMaxSuppressionV5, - kLiteRtOpCodeTflRandomStandardNormal, - kLiteRtOpCodeTflRandomUniform, - kLiteRtOpCodeTflRank, - kLiteRtOpCodeTflReadVariable, - kLiteRtOpCodeTflReal, - kLiteRtOpCodeTflReduceProd, - kLiteRtOpCodeTflReverseSequence, - kLiteRtOpCodeTflRfft2d, - kLiteRtOpCodeTflSegmentSum, - kLiteRtOpCodeTflShape, - kLiteRtOpCodeTflSparseToDense, - kLiteRtOpCodeTflSvdf, - kLiteRtOpCodeTflUnidirectionalSequenceRnn, - kLiteRtOpCodeTflUnique, - kLiteRtOpCodeTflUnsortedSegmentMax, - kLiteRtOpCodeTflUnsortedSegmentMin, - kLiteRtOpCodeTflUnsortedSegmentProd, - kLiteRtOpCodeTflUnsortedSegmentSum, - kLiteRtOpCodeTflVarHandle, - kLiteRtOpCodeTflWhere, -}; -// clang format on - -constexpr auto kNumPluginSocModels = - sizeof(kPluginSocModels) / sizeof(kPluginSocModels[0]); - -} // namespace google_tensor - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { - if (api_version == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", "api_version is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - api_version->major = LITERT_API_VERSION_MAJOR; - api_version->minor = LITERT_API_VERSION_MINOR; - api_version->patch = LITERT_API_VERSION_PATCH; - return kLiteRtStatusOk; -} - -const char* LiteRtGetCompilerPluginSocManufacturer() { - return google_tensor::kPluginManufacturer; -} - -LiteRtStatus 
LiteRtGetCompilerPluginSupportedHardware( - LiteRtCompilerPlugin compiler_plugin, - LiteRtHwAccelerators* supported_hardware) { - if (!compiler_plugin || !supported_hardware) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiler_plugin or supported_hardware is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *supported_hardware = kLiteRtHwAcceleratorNpu; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( - LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex* num_supported_soc_models) { - if (compiler_plugin == nullptr || num_supported_soc_models == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiler_plugin or num_supported_soc_models is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *num_supported_soc_models = google_tensor::kNumPluginSocModels; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, - const char** soc_model_name) { - if (compiler_plugin == nullptr || - soc_model_idx >= google_tensor::kNumPluginSocModels || - soc_model_name == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiler_plugin or soc_model_idx or soc_model_name is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *soc_model_name = google_tensor::kPluginSocModels[soc_model_idx]; - return kLiteRtStatusOk; -} - -// -// Compiled Result Definition -// - -// TODO (abhirs): Revisit this struct after updating the compiler api wrapper to -// return multiple bytecodes. -struct LiteRtCompiledResultT { - std::string byte_code; - std::vector per_op_data; -}; - -LiteRtStatus LiteRtGetCompiledResultByteCode( - LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx, - const void** byte_code, size_t* byte_code_size) { - if (!compiled_result || !byte_code || !byte_code_size || - (byte_code_idx != 0)) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiled_result or byte_code or byte_code_size" - "or byte_code_idx is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *byte_code = compiled_result->byte_code.data(); - *byte_code_size = compiled_result->byte_code.size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCompiledResultNumByteCodeModules( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_byte_code) { - if (!compiled_result || !num_byte_code) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiled_result or num_byte_code is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *num_byte_code = 1; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledResultCallInfo( - LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, - const void** call_info, size_t* call_info_size, - LiteRtParamIndex* byte_code_idx) { - if (!compiled_result || !call_info || !call_info_size) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiled_result or call_info or call_info_size is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } else if (call_idx >= compiled_result->per_op_data.size()) { - LITERT_LOG(LITERT_ERROR, "%s", "call_idx is out of bounds"); - return kLiteRtStatusErrorIndexOOB; - } - - *call_info = compiled_result->per_op_data.at(call_idx).data(); - *call_info_size = compiled_result->per_op_data.at(call_idx).size(); - *byte_code_idx = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompiledResultCalls( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { - if (!compiled_result || !num_calls) { - LITERT_LOG(LITERT_ERROR, "%s", "compiled_result or num_calls is nullptr"); - return 
kLiteRtStatusErrorInvalidArgument; - } - *num_calls = compiled_result->per_op_data.size(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { - delete compiled_result; -} - -// -// Plugin Definition -// - -// Plugins can hold state. -struct LiteRtCompilerPluginT { - using Flag = std::pair; - std::vector flags; -}; - -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values) { - auto& flags = compiler_plugin->flags; - if (flags.size() != 0) { - LITERT_LOG(LITERT_INFO, "Overwriting existing flags"); - flags.clear(); - } - flags.resize(num_flags); - for (int i = 0; i < num_flags; ++i) { - auto& flag = flags[i]; - flag.first = std::string(keys[i]); - flag.second = std::string(values[i]); - LITERT_LOG(LITERT_INFO, "Setting Flag: %s = %s", flag.first.c_str(), - flag.second.c_str()); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { - *compiler_plugin = new LiteRtCompilerPluginT; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { - if (compiler_plugin == nullptr) { - return; - } - delete compiler_plugin; -} - -namespace google_tensor { -// TODO(abhirs): update the function to use the darwinn inbuilt way of -// finding supportedops -bool IsOpSupported(const litert::Op& op) { - for (auto unsupported_op : kUnSupportedOps) { - if (unsupported_op == op.Code()) { - return false; - } - } - return true; -} - -} // namespace google_tensor - -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - ::litert::Subgraph graph(subgraph); - for (const auto& op : graph.Ops()) { - if (!google_tensor::IsOpSupported(op)) { - continue; - } - - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 0)); - } - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - if (compiler_plugin == nullptr || soc_model == nullptr || - partitions == nullptr || compiled_result == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - LITERT_LOG(LITERT_INFO, - "Starting GoogleTensor Compilation for %d subgraphs, soc_model=%s", - num_partitions, soc_model); - - // Serialize model. - LITERT_LOG(LITERT_INFO, "%s", "Serializing model"); - litert::OwningBufferRef buf; - auto [data, size, offset] = buf.GetWeak(); - const auto opts = litert::SerializationOptions::Defaults(); - LITERT_RETURN_IF_ERROR( - LiteRtSerializeModel(partitions, &data, &size, &offset, false, opts)); - // TODO(abhirs): add support for serializing subgraphs - - absl::string_view buffer_str(reinterpret_cast(buf.Data()), - buf.Size()); - - // Loading Google Tensor Compiler Adapter - LITERT_LOG(LITERT_INFO, "%s", "Loading Google Tensor Compiler Adapter"); - auto adapter_result = litert::google_tensor::Adapter::Create( - /*shared_library_dir=*/std::nullopt); - if (!adapter_result.HasValue()) { - const auto& error_message = adapter_result.Error().Message(); - LITERT_LOG(LITERT_ERROR, "Failed to create adapter: %s", - error_message.c_str()); - return kLiteRtStatusErrorRuntimeFailure; - } - - // Compile model. 
- LITERT_LOG(LITERT_INFO, "%s", "Compiling model..."); - // TODO(b/398984678): add support for multiple bytecodes - absl::string_view soc_model_view(soc_model); - std::string compiled; - auto compile_status = adapter_result.Value()->api().compile( - buffer_str, soc_model_view, compiler_plugin->flags, &compiled); - - if (!compile_status.ok()) { - LITERT_LOG( - LITERT_ERROR, "%s", - absl::StrCat("Failed to compile model: ", compile_status.message()) - .c_str()); - return kLiteRtStatusErrorRuntimeFailure; - } - - // Assemble the result. - auto result = std::make_unique<LiteRtCompiledResultT>(); - - result->byte_code = std::string(compiled.data(), compiled.size()); - // Generate per_op_data. - for (size_t i = 0; i < num_partitions; ++i) { - result->per_op_data.emplace_back( - absl::StrFormat("Partition_%d", static_cast<int>(i))); - } - *compiled_result = result.release(); - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin_test.cc deleted file mode 100644 index 7f6ca4aaf95a74..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin_test.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
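Taken together, the compiled-result accessors in the file above form a small read-only protocol: a single bytecode module for now, plus per-call info naming each partition. As a minimal usage sketch (the function name and error handling here are illustrative assumptions, not part of the deleted plugin; it assumes the LiteRT compiler-plugin C headers are included):

// Hypothetical consumer of the compiled-result C API implemented above.
// `compiled` is assumed to come from a successful LiteRtCompilerPluginCompile.
void DumpCompiledResult(LiteRtCompiledResult compiled) {
  LiteRtParamIndex num_calls = 0;
  if (LiteRtGetNumCompiledResultCalls(compiled, &num_calls) != kLiteRtStatusOk) {
    return;
  }
  for (LiteRtParamIndex call_idx = 0; call_idx < num_calls; ++call_idx) {
    const void* call_info = nullptr;
    size_t call_info_size = 0;
    LiteRtParamIndex byte_code_idx = 0;
    // Yields "Partition_<i>" and, in this plugin, always byte_code_idx == 0.
    if (LiteRtGetCompiledResultCallInfo(compiled, call_idx, &call_info,
                                        &call_info_size, &byte_code_idx) !=
        kLiteRtStatusOk) {
      return;
    }
    const void* byte_code = nullptr;
    size_t byte_code_size = 0;
    // All calls currently share the single bytecode module.
    LiteRtGetCompiledResultByteCode(compiled, byte_code_idx, &byte_code,
                                    &byte_code_size);
  }
  LiteRtDestroyCompiledResult(compiled);  // The caller owns the result.
}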
- -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" - -namespace litert { -namespace { - -TEST(TestGoogleTensorPlugin, GetConfigInfo) { - ASSERT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), "GoogleTensor"); - - auto plugin = CreatePlugin(); - - LiteRtParamIndex num_supported_soc_models; - LITERT_ASSERT_OK(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models)); - ASSERT_EQ(num_supported_soc_models, 1); - - const char* soc_model_name; - LITERT_ASSERT_OK(LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, - &soc_model_name)); - ASSERT_STREQ(soc_model_name, "P25"); -} - -TEST(TestCallGoogleTensorPlugin, PartitionSimpleMultiAdd) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("simple_multi_op.tflite"); - - LiteRtOpListT selected_op_list; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, model.Subgraph(0)->Get(), - &selected_op_list)); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 4); - ASSERT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflAdd); - ASSERT_EQ(selected_ops[1].first->OpCode(), kLiteRtOpCodeTflMul); -} - -TEST(TestCallGoogleTensorPlugin, CompileMulSubgraph) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("mul_simple.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - LiteRtCompilerPluginCompile(plugin.get(), "P25", model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode(compiled, 0, &byte_code, - &byte_code_size)); - absl::string_view byte_code_string(reinterpret_cast<const char*>(byte_code), - byte_code_size); - ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, 0, &op_data, &op_data_size, &byte_code_idx)); - absl::string_view op_data_string(reinterpret_cast<const char*>(op_data), - op_data_size); - ASSERT_EQ("Partition_0", op_data_string); - - LiteRtDestroyCompiledResult(compiled); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD deleted file mode 100644 index 6267f882339fed..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "copy_file", "litert_dynamic_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -litert_dynamic_lib( - name = "dispatch_api", - srcs = [ - "dispatch_api.cc", - "litert_dispatch_device_context.cc", - "litert_dispatch_graph.cc", - "litert_dispatch_invocation_context.cc", - "southbound.cc", - ], - hdrs = [ - "dispatch_api.h", - "litert_dispatch_device_context.h", - "litert_dispatch_graph.h", - "litert_dispatch_invocation_context.h", - "litert_dispatch_metrics.h", - "southbound.h", - # copybara:uncomment "//third_party/odml/infra/southbound:sb_api.h", - ], - copts = [ - "-Os", - "-fno-exceptions", - "-fno-unwind-tables", - "-fno-asynchronous-unwind-tables", - "-ffunction-sections", - "-fdata-sections", - ], - export_litert_only = True, - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }) + [ - "-Wl,-soname=libLiteRtDispatch_GoogleTensor.so", - "-Wl,-lc++abi", - ], - shared_lib_name = "dispatch_api_so", - so_name = "libLiteRtDispatch_GoogleTensor.so", - tags = [ - # Don't build/test in OSS until Southbound is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/strings:string_view", - ], -) - -# This is the cc_library target for `libLiteRtDispatch_GoogleTensor.so`. -cc_library( - name = "dispatch_api_shared_lib", - srcs = [":dispatch_api_so"], - visibility = ["//visibility:public"], -) - -# Copies the shared library so that it is available for use in test data as libLiteRtDispatch_GoogleTensor.so. -copy_file( - name = "copy_dispatch_api_so", - src = "//tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch:dispatch_api_so", - target = "libLiteRtDispatch_GoogleTensor.so", -) - -cc_test( - name = "dispatch_api_google_tensor_test", - srcs = [ - "dispatch_api_google_tensor_test.cc", - ], - data = [ - ":dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - # Don't build/test in OSS until Southbound is available.
- "nobuilder", - "no_oss", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "dispatch_api_async_google_tensor_test", - srcs = [ - "dispatch_api_async_google_tensor_test.cc", - ], - data = [ - ":dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - # Don't build/test in OSS until Southbound is available. - "nobuilder", - "no_oss", - ], - deps = [ - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/darwinn/driver_shared/fence:fence_test_util", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc deleted file mode 100644 index f1aede35c98885..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc +++ /dev/null @@ -1,644 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h" - -#include -#include -#include -#include -#include - -#if LITERT_HAS_AHWB_SUPPORT -#include -#endif - -#include "absl/strings/string_view.h" -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -namespace { - -litert::google_tensor::Southbound* TheSouthbound; -char BuildId[256]; - -} // namespace - -namespace litert { -namespace google_tensor { - -// ///////////////////////////////////////////////////////////////////////////// -// Basic Execution API -// ///////////////////////////////////////////////////////////////////////////// - -const char* GetSharedLibraryDir(const LiteRtDispatchOption* options, - int num_options) { - for (auto i = 0; i < num_options; ++i) { - auto& option = options[i]; - if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { - return option.value.str_value; - } - } - return nullptr; -} - -LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) { - auto* shared_library_dir = GetSharedLibraryDir(options, num_options); - std::optional shared_library_dir_opt = - shared_library_dir ? std::make_optional(std::string(shared_library_dir)) - : std::nullopt; - - if (auto southbound = - litert::google_tensor::Southbound::Create(shared_library_dir_opt); - !southbound) { - LITERT_LOG(LITERT_INFO, "Initialization failure: %s", - southbound.Error().Message().c_str()); - return southbound.Error().Status(); - } else { - TheSouthbound = southbound->release(); - } - - auto thr_initialize = TheSouthbound->api().thr_initialize; - if (!thr_initialize) { - LITERT_LOG(LITERT_INFO, "thr_initialize not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - if (auto status = thr_initialize(); status != kThrStatusSuccess) { - LITERT_LOG(LITERT_INFO, "thr_initialize failed: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - auto thr_get_vendor_api_version = - TheSouthbound->api().thr_get_vendor_api_version; - const char* sb_api_version = - thr_get_vendor_api_version ? thr_get_vendor_api_version() : "N.A."; - auto thr_get_vendor_id = TheSouthbound->api().thr_get_vendor_id; - const char* sb_vendor_id = thr_get_vendor_id ? 
thr_get_vendor_id() : "N.A."; - snprintf( - BuildId, sizeof(BuildId), - "GoogleTensor Dispatch API version %d.%d.%d, Darwinn API version %s, " - "vendor id: %s", - LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR, - LITERT_API_VERSION_PATCH, sb_api_version, sb_vendor_id); - BuildId[sizeof(BuildId) - 1] = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus GetVendorId(const char** vendor_id) { - *vendor_id = "Google"; - return kLiteRtStatusOk; -} - -LiteRtStatus GetBuildId(const char** build_id) { - *build_id = BuildId; - return kLiteRtStatusOk; -} - -LiteRtStatus GetCapabilities(int* capabilities) { - *capabilities = kLiteRtDispatchCapabilitiesBasic | - kLiteRtDispatchCapabilitiesAsync | - kLiteRtDispatchCapabilitiesGraph; - return kLiteRtStatusOk; -} - -LiteRtStatus DeviceContextCreate(LiteRtDispatchDeviceContext* device_context) { - if (auto result = LiteRtDispatchDeviceContextT::Create(*TheSouthbound); - result) { - *device_context = result->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus DeviceContextDestroy(LiteRtDispatchDeviceContext device_context) { - delete device_context; - return kLiteRtStatusOk; -} - -LiteRtStatus GetInputRequirements( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto result = - invocation_context->GetInputRequirements(input_index, *tensor_type); - result) { - *tensor_buffer_requirements = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get input requirements: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus GetOutputRequirements( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto result = - invocation_context->GetOutputRequirements(output_index, *tensor_type); - result) { - *tensor_buffer_requirements = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get output requirements: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus RegisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, LiteRtTensorBuffer buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle) { - if (auto status = device_context->RegisterTensorBuffer(buffer); status) { - *tensor_buffer_handle = *status; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to register buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } -} - -LiteRtStatus UnregisterTensorBuffer(LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle handle) { - if (auto status = device_context->UnregisterTensorBuffer(handle); status) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to unregister buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } -} - -LiteRtStatus InvocationContextCreate( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context) { - function_name = ""; - if (auto result = 
LiteRtDispatchInvocationContextT::CreateFromBytecode( - *TheSouthbound, device_context, exec_type, exec_bytecode_buffer, - function_name, num_inputs, num_outputs); - result) { - *invocation_context = result->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create invocation context: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus InvocationContextDestroy( - LiteRtDispatchInvocationContext invocation_context) { - delete invocation_context; - return kLiteRtStatusOk; -} - -LiteRtStatus AttachInput(LiteRtDispatchInvocationContext invocation_context, - int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = invocation_context->AttachInput(graph_input_index, - tensor_buffer_handle); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to attach input: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AttachOutput(LiteRtDispatchInvocationContext invocation_context, - int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = invocation_context->AttachOutput(graph_output_index, - tensor_buffer_handle); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to attach output: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus DetachInput(LiteRtDispatchInvocationContext invocation_context, - int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = invocation_context->DetachInput(graph_input_index, - tensor_buffer_handle); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to detach input: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus DetachOutput(LiteRtDispatchInvocationContext invocation_context, - int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = invocation_context->DetachOutput(graph_output_index, - tensor_buffer_handle); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to detach output: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus Invoke(LiteRtDispatchInvocationContext invocation_context) { - if (auto result = invocation_context->Invoke(); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to invoke: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -// ///////////////////////////////////////////////////////////////////////////// -// Async Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus AttachInputEvent( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtEvent input_event) { - if (auto result = - invocation_context->AttachInputEvent(graph_input_index, input_event); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to attach input event: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus InvokeAsync(LiteRtDispatchInvocationContext invocation_context, - int num_output_events, LiteRtEvent* output_events) { - if (auto result = - invocation_context->InvokeAsync(num_output_events, output_events); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to invoke
asynchronously: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -// ///////////////////////////////////////////////////////////////////////////// -// Metrics API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus StartMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, int detail_level) { - if (auto result = invocation_context->StartMetricsCollection(detail_level); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to start metrics collection: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus StopMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, - LiteRtDispatchMetrics* metrics) { - if (auto result = invocation_context->StopMetricsCollection(metrics); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to stop metrics collection: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus GetNumMetrics(LiteRtDispatchMetrics metrics, int* num_metrics) { - if (metrics == nullptr) { - LITERT_LOG(LITERT_ERROR, - "GetNumMetrics failed: metrics should not be null"); - return kLiteRtStatusErrorInvalidArgument; - } - *num_metrics = metrics->GetNumMetrics(); - return kLiteRtStatusOk; -} - -LiteRtStatus GetMetric(LiteRtDispatchMetrics metrics, int metric_index, - LiteRtMetric* metric) { - if (metrics == nullptr) { - LITERT_LOG(LITERT_ERROR, "GetMetric failed: metrics should not be null"); - return kLiteRtStatusErrorInvalidArgument; - } - *metric = metrics->GetMetric(metric_index); - return kLiteRtStatusOk; -} - -LiteRtStatus DestroyMetrics(LiteRtDispatchMetrics metrics) { - if (metrics) { - delete metrics; - } - return kLiteRtStatusOk; -} - -// ///////////////////////////////////////////////////////////////////////////// -// Graph Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus GraphCreate(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchGraph* graph) { - if (auto result = device_context->CreateGraph(); result) { - *graph = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create graph: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus GraphDestroy(LiteRtDispatchGraph graph) { - if (auto result = graph->device_context()->DestroyGraph(graph); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to delete graph: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AddNode(LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type) { - if (auto result = graph->AddNode(node_id, node_type); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to add node: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AddEdge(LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->AddEdge(edge_id); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to add edge: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus ConnectNodeInput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, int input_index, - LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->ConnectNodeInput(node_id, 
input_index, edge_id); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to connect node input: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus ConnectNodeOutput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, int output_index, - LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->ConnectNodeOutput(node_id, output_index, edge_id); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to connect node output: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus ConnectGraphInput(LiteRtDispatchGraph graph, int input_index, - LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->ConnectGraphInput(input_index, edge_id); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to connect graph input: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus ConnectGraphOutput(LiteRtDispatchGraph graph, int output_index, - LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->ConnectGraphOutput(output_index, edge_id); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to connect graph output: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus LoadExecutable(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType type, - const LiteRtMemBuffer* bytecode_buffer, - LiteRtDispatchExecutableHandle* exec_handle) { - if (auto result = device_context->LoadExecutable(type, bytecode_buffer); - result) { - *exec_handle = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to load executable: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus UnloadExecutable(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableHandle exec_handle) { - if (auto result = device_context->UnloadExecutable(exec_handle); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to unload executable: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AssignNodeFunction(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - LiteRtDispatchExecutableHandle exec_handle, - const char* function_name) { - // TODO - b/397771624: Southbound currently doesn't support function names, so - // we override function names with empty strings as a temporary fix. We need - // to investigate with the CoreML team to find a more robust solution.
- function_name = ""; - if (auto result = - graph->AssignNodeFunction(node_id, exec_handle, function_name); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to assign node function: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AnnotateGraph(LiteRtDispatchGraph graph, const char* key, - const char* value) { - if (auto result = graph->AnnotateGraph(key, value); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to annotate graph: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AnnotateNode(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, const char* key, - const char* value) { - if (auto result = graph->AnnotateNode(node_id, key, value); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to annotate node: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AnnotateEdge(LiteRtDispatchGraph graph, - LiteRtDispatchEdgeId edge_id, const char* key, - const char* value) { - if (auto result = graph->AnnotateEdge(edge_id, key, value); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to annotate edge: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus InvocationContextCreateFromGraph( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, - LiteRtDispatchInvocationContext* invocation_context) { - if (auto result = LiteRtDispatchInvocationContextT::CreateFromGraph( - *TheSouthbound, device_context, graph); - result) { - *invocation_context = result->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create invocation context: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -} // namespace google_tensor -} // namespace litert - -// ///////////////////////////////////////////////////////////////////////////// - -namespace { - -LiteRtDispatchInterface TheInterface = { - .initialize = litert::google_tensor::Initialize, - .get_vendor_id = litert::google_tensor::GetVendorId, - .get_build_id = litert::google_tensor::GetBuildId, - .get_capabilities = litert::google_tensor::GetCapabilities, - .device_context_create = litert::google_tensor::DeviceContextCreate, - .device_context_destroy = litert::google_tensor::DeviceContextDestroy, - .get_input_requirements = litert::google_tensor::GetInputRequirements, - .get_output_requirements = litert::google_tensor::GetOutputRequirements, - .register_tensor_buffer = litert::google_tensor::RegisterTensorBuffer, - .unregister_tensor_buffer = litert::google_tensor::UnregisterTensorBuffer, - .invocation_context_create = litert::google_tensor::InvocationContextCreate, - .invocation_context_destroy = - litert::google_tensor::InvocationContextDestroy, - .attach_input = litert::google_tensor::AttachInput, - .attach_output = litert::google_tensor::AttachOutput, - .detach_input = litert::google_tensor::DetachInput, - .detach_output = litert::google_tensor::DetachOutput, - .invoke = litert::google_tensor::Invoke, - .start_metrics_collection = litert::google_tensor::StartMetricsCollection, - .stop_metrics_collection = litert::google_tensor::StopMetricsCollection, - .get_num_metrics = litert::google_tensor::GetNumMetrics, - .get_metric = litert::google_tensor::GetMetric, - .destroy_metrics = litert::google_tensor::DestroyMetrics, -}; - 
-LiteRtDispatchAsyncInterface TheAsyncInterface = { - .attach_input_event = litert::google_tensor::AttachInputEvent, - .invoke_async = litert::google_tensor::InvokeAsync, -}; - -LiteRtDispatchGraphInterface TheGraphInterface = { - .graph_create = litert::google_tensor::GraphCreate, - .graph_destroy = litert::google_tensor::GraphDestroy, - .add_node = litert::google_tensor::AddNode, - .add_edge = litert::google_tensor::AddEdge, - .connect_node_input = litert::google_tensor::ConnectNodeInput, - .connect_node_output = litert::google_tensor::ConnectNodeOutput, - .connect_graph_input = litert::google_tensor::ConnectGraphInput, - .connect_graph_output = litert::google_tensor::ConnectGraphOutput, - .load_executable = litert::google_tensor::LoadExecutable, - .unload_executable = litert::google_tensor::UnloadExecutable, - .assign_node_function = litert::google_tensor::AssignNodeFunction, - .annotate_graph = litert::google_tensor::AnnotateGraph, - .annotate_node = litert::google_tensor::AnnotateNode, - .annotate_edge = litert::google_tensor::AnnotateEdge, - .invocation_context_create_from_graph = - litert::google_tensor::InvocationContextCreateFromGraph, -}; - -LiteRtDispatchApi TheApi = { - .version = {.major = LITERT_API_VERSION_MAJOR, - .minor = LITERT_API_VERSION_MINOR, - .patch = LITERT_API_VERSION_PATCH}, - .interface = &TheInterface, - .async_interface = &TheAsyncInterface, - .graph_interface = &TheGraphInterface, -}; - -} // namespace - -LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { - *api = TheApi; - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h deleted file mode 100644 index 00392c06efe163..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
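LiteRtDispatchGetApi is the only symbol a host needs from this library; everything else is reached through the function tables above. As a rough sketch of how a host could resolve it (assumptions: plain dlopen/dlsym loading and the hypothetical wrapper name below; this is not necessarily how LiteRT's own loader locates vendor libraries):

#include <dlfcn.h>

// Hypothetical loader: resolves the entry point from the soname declared in
// the dispatch BUILD rule, then calls through the returned vtables.
LiteRtStatus LoadGoogleTensorDispatchApi(LiteRtDispatchApi* api) {
  void* lib =
      dlopen("libLiteRtDispatch_GoogleTensor.so", RTLD_NOW | RTLD_LOCAL);
  if (lib == nullptr) {
    return kLiteRtStatusErrorRuntimeFailure;
  }
  using GetApiFn = LiteRtStatus (*)(LiteRtDispatchApi*);
  auto get_api = reinterpret_cast<GetApiFn>(dlsym(lib, "LiteRtDispatchGetApi"));
  if (get_api == nullptr) {
    dlclose(lib);
    return kLiteRtStatusErrorRuntimeFailure;
  }
  // Fills in the version plus the basic, async, and graph interfaces.
  return get_api(api);
}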
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -namespace litert::google_tensor { - -LiteRtStatus GraphCreate(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchGraph* graph); -LiteRtStatus GraphDestroy(LiteRtDispatchGraph graph); -LiteRtStatus AddNode(LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type); -LiteRtStatus AddEdge(LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id); -LiteRtStatus ConnectNodeInput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, int input_index, - LiteRtDispatchEdgeId edge_id); -LiteRtStatus ConnectNodeOutput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, int output_index, - LiteRtDispatchEdgeId edge_id); -LiteRtStatus ConnectGraphInput(LiteRtDispatchGraph graph, int input_index, - LiteRtDispatchEdgeId edge_id); -LiteRtStatus ConnectGraphOutput(LiteRtDispatchGraph graph, int output_index, - LiteRtDispatchEdgeId edge_id); -LiteRtStatus LoadExecutable(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType type, - const LiteRtMemBuffer* bytecode_buffer, - LiteRtDispatchExecutableHandle* exec_handle); -LiteRtStatus UnloadExecutable(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableHandle exec_handle); -LiteRtStatus AssignNodeFunction(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - LiteRtDispatchExecutableHandle exec_handle, - const char* function_name); -LiteRtStatus AnnotateGraph(LiteRtDispatchGraph graph, const char* key, - const char* value); -LiteRtStatus AnnotateNode(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, const char* key, - const char* value); -LiteRtStatus AnnotateEdge(LiteRtDispatchGraph graph, - LiteRtDispatchEdgeId edge_id, const char* key, - const char* value); -LiteRtStatus InvocationContextCreateFromGraph( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, - LiteRtDispatchInvocationContext* invocation_context); - -} // namespace litert::google_tensor - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_async_google_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_async_google_tensor_test.cc deleted file mode 100644 index 762792a135e0ca..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_async_google_tensor_test.cc +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
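The graph interface declared in the header above composes in a fixed order: create the graph, add nodes and edges, wire them together, then load bytecode and bind it to a node before creating an invocation context. A hedged sketch of that sequence for a single-node graph (node and edge ids are assumed to be caller-chosen integer handles, the helper name is invented, and error checking is elided for brevity):

// Hypothetical usage of the graph API declared above; the executable type
// kLiteRtDispatchExecutableTypeMlModel is taken from the tests below.
LiteRtStatus BuildSingleNodeGraph(LiteRtDispatchDeviceContext device_context,
                                  LiteRtDispatchNodeType node_type,
                                  const LiteRtMemBuffer* bytecode,
                                  LiteRtDispatchInvocationContext* invocation) {
  namespace gt = litert::google_tensor;
  LiteRtDispatchGraph graph;
  gt::GraphCreate(device_context, &graph);
  const LiteRtDispatchNodeId node = 0;       // Assumed integer handle.
  const LiteRtDispatchEdgeId in_edge = 0;    // Assumed integer handle.
  const LiteRtDispatchEdgeId out_edge = 1;   // Assumed integer handle.
  gt::AddNode(graph, node, node_type);
  gt::AddEdge(graph, in_edge);
  gt::AddEdge(graph, out_edge);
  // Route the graph input through the node to the graph output.
  gt::ConnectGraphInput(graph, /*input_index=*/0, in_edge);
  gt::ConnectNodeInput(graph, node, /*input_index=*/0, in_edge);
  gt::ConnectNodeOutput(graph, node, /*output_index=*/0, out_edge);
  gt::ConnectGraphOutput(graph, /*output_index=*/0, out_edge);
  // Bind compiled bytecode to the node; function names are overridden to ""
  // by AssignNodeFunction itself (see the TODO in dispatch_api.cc).
  LiteRtDispatchExecutableHandle exec;
  gt::LoadExecutable(device_context, kLiteRtDispatchExecutableTypeMlModel,
                     bytecode, &exec);
  gt::AssignNodeFunction(graph, node, exec, /*function_name=*/"");
  return gt::InvocationContextCreateFromGraph(device_context, graph,
                                              invocation);
}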
- -#include -#include -#include -#include - -#if defined(__ANDROID__) -#include "platforms/darwinn/tachyon/core/fence/fence.h" -#endif -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "third_party/darwinn/driver_shared/fence/fence_test_util.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using ::testing::Pointwise; -using Fence = std::shared_ptr; - -TEST(DispatchApiAsync, GoogleTensor) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a GoogleTensor eTPU"; -#endif - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*litert::ToLiteRtAny(std::any("/data/local/tmp")), - }; - ASSERT_EQ( - LiteRtDispatchInitialize(/*options=*/&dispatch_option, /*num_options=*/1), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." << api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kGoogleTensorModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/nullptr, - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. 
- // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/0, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/0, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/0, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Attach sync fences to input buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - Fence input_fence_0 = platforms::darwinn::fence_util::CreateFence(); - Fence input_fence_1 = platforms::darwinn::fence_util::CreateFence(); - - LiteRtEvent input_event_0; - ASSERT_EQ(LiteRtCreateEventFromSyncFenceFd(input_fence_0->GetFd(), - /*owns_fd=*/false, &input_event_0), - kLiteRtStatusOk); - - LiteRtEvent input_event_1; - ASSERT_EQ(LiteRtCreateEventFromSyncFenceFd(input_fence_1->GetFd(), - /*owns_fd=*/false, &input_event_1), - kLiteRtStatusOk); - - ASSERT_EQ(LiteRtDispatchAttachInputEvent( - invocation_context, /*graph_input_index=*/0, input_event_0), - kLiteRtStatusOk); - ASSERT_EQ(LiteRtDispatchAttachInputEvent( - invocation_context, /*graph_input_index=*/1, input_event_1), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - LiteRtEvent output_event = nullptr; - EXPECT_EQ(LiteRtDispatchInvokeAsync(invocation_context, 1, &output_event), - kLiteRtStatusOk); - ASSERT_NE(output_event, nullptr); - - // Attach output event to output tensor buffer. - ASSERT_EQ(LiteRtSetTensorBufferEvent(output_tensor_buffer, output_event), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Signal input fences. - // /////////////////////////////////////////////////////////////////////////// - - ASSERT_OK(input_fence_0->Signal(/*success=*/true)); - ASSERT_OK(input_fence_1->Signal(/*success=*/true)); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast<float*>(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources.
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtDestroyEvent(input_event_0); - LiteRtDestroyEvent(input_event_1); - LiteRtDestroyEvent(output_event); - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc deleted file mode 100644 index 2d2cca562552ff..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using ::testing::Pointwise; - -TEST(DispatchApi, GoogleTensor) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a GoogleTensor eTPU"; -#endif - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*litert::ToLiteRtAny(std::any("/data/local/tmp")), - }; - ASSERT_EQ( - LiteRtDispatchInitialize(/*options=*/&dispatch_option, /*num_options=*/1), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." << api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kGoogleTensorModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/nullptr, - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. 
- // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/0, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/0, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/0, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
-  // ///////////////////////////////////////////////////////////////////////////
-
-  {
-    ABSL_LOG(INFO) << "Checking output...";
-    void* host_mem_addr;
-    ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr),
-              kLiteRtStatusOk);
-    auto output = absl::MakeSpan(static_cast<const float*>(host_mem_addr),
-                                 kTestOutputSize);
-    for (auto i = 0; i < kTestOutputSize; ++i) {
-      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
-    }
-    EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor));
-    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk);
-  }
-
-  // ///////////////////////////////////////////////////////////////////////////
-  // Clean up resources.
-  // ///////////////////////////////////////////////////////////////////////////
-
-  EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context,
-                                      /*graph_input_index=*/0, input_0_handle),
-            kLiteRtStatusOk);
-  EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context,
-                                      /*graph_input_index=*/1, input_1_handle),
-            kLiteRtStatusOk);
-  EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context,
-                                       /*graph_output_index=*/0, output_handle),
-            kLiteRtStatusOk);
-  EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle),
-            kLiteRtStatusOk);
-  EXPECT_EQ(
-      LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle),
-      kLiteRtStatusOk);
-  EXPECT_EQ(
-      LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle),
-      kLiteRtStatusOk);
-  LiteRtDestroyTensorBuffer(output_tensor_buffer);
-  LiteRtDestroyTensorBuffer(input_1_tensor_buffer);
-  LiteRtDestroyTensorBuffer(input_0_tensor_buffer);
-  EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context),
-            kLiteRtStatusOk);
-  EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context),
-            kLiteRtStatusOk);
-}
diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc
deleted file mode 100644
index 342c469a7cdb68..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc
+++ /dev/null
@@ -1,294 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
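
Condensed, the deleted test above exercises the dispatch API in a fixed order: initialize, create contexts, query requirements, register and attach buffers, invoke, then tear down in reverse. A minimal sketch of that lifecycle follows; it is illustrative only, and passing a null options array to `LiteRtDispatchInitialize` is an assumption here (the test always supplies the shared-library-dir option).

```cpp
#include <cstddef>

#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h"

// Sketch of the dispatch lifecycle; `bytecode`/`bytecode_size` stand in for
// the model file the test loads from disk.
bool RunModelOnce(const void* bytecode, size_t bytecode_size) {
  if (LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0) !=
      kLiteRtStatusOk) {
    return false;
  }

  LiteRtDispatchDeviceContext device_context = nullptr;
  if (LiteRtDispatchDeviceContextCreate(&device_context) != kLiteRtStatusOk) {
    return false;
  }

  // fd == -1 selects the base_addr/offset/size form of the buffer.
  LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, /*.base_addr=*/bytecode,
                                          /*.offset=*/0,
                                          /*.size=*/bytecode_size};
  LiteRtDispatchInvocationContext invocation_context = nullptr;
  if (LiteRtDispatchInvocationContextCreate(
          device_context, kLiteRtDispatchExecutableTypeMlModel,
          &exec_bytecode_buffer, /*function_name=*/nullptr,
          /*num_inputs=*/2, /*num_outputs=*/1,
          &invocation_context) != kLiteRtStatusOk) {
    return false;
  }

  // From here the test queries buffer requirements, creates managed AHWB
  // buffers, registers and attaches them, invokes, and tears everything down
  // in reverse order: detach -> unregister -> destroy.

  LiteRtDispatchInvocationContextDestroy(invocation_context);
  LiteRtDispatchDeviceContextDestroy(device_context);
  return true;
}
```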
- -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h" - -#include -#include - -#if __ANDROID__ -#include -#endif // __ANDROID__ - -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h" - -using litert::Error; -using litert::Expected; -using litert::Unexpected; - -LiteRtDispatchDeviceContextT::~LiteRtDispatchDeviceContextT() { - if (!thr_graphs_.empty()) { - auto thr_graph_delete = southbound_.api().thr_graph_delete; - if (!thr_graph_delete) { - LITERT_LOG(LITERT_ERROR, "thr_graph_delete not found"); - } else { - for (auto* thr_graph : thr_graphs_) { - thr_graph_delete(thr_graph); - } - } - } - - if (thr_context_) { - auto thr_context_delete = southbound_.api().thr_context_delete; - if (!thr_context_delete) { - LITERT_LOG(LITERT_ERROR, "thr_context_delete not found"); - } else { - thr_context_delete(thr_context_); - } - } -} - -Expected -LiteRtDispatchDeviceContextT::Create( - const litert::google_tensor::Southbound& southbound) { - Ptr device_context(new LiteRtDispatchDeviceContextT(southbound)); - - auto thr_context_create = southbound.api().thr_context_create; - if (!thr_context_create) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "thr_context_create not found"); - } - - device_context->thr_context_ = thr_context_create(); - return device_context; -} - -Expected -LiteRtDispatchDeviceContextT::RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer) { - LiteRtTensorBufferType tensor_buffer_type; - if (auto status = - LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get buffer type"); - } - - if (tensor_buffer_type != kLiteRtTensorBufferTypeAhwb) { - return Error(kLiteRtStatusErrorUnsupported, "Unsupported buffer type"); - } - - size_t tensor_buffer_size; - if (auto status = - LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get buffer size"); - } - - size_t tensor_buffer_offset; - if (auto status = - LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); - status != kLiteRtStatusOk) { - if (status == kLiteRtStatusErrorNotFound) { - tensor_buffer_offset = 0; - } else { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get buffer offset"); - } - } - - LiteRtRankedTensorType tensor_type; - if (auto status = - LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get tensor buffer type"); - } - - auto* tensor_strides = tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are not supported"); - } - - AHardwareBuffer* ahwb; -#if LITERT_HAS_AHWB_SUPPORT - if (auto status = LiteRtGetTensorBufferAhwb(tensor_buffer, &ahwb); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get AHWB"); - } -#else - return Error(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not 
supported on this platform"); -#endif - - ThrBufferHandle thr_buffer_handle; - - if (tensor_buffer_offset == 0) { - auto thr_register_buffer = southbound_.api().thr_register_buffer; - if (!thr_register_buffer) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_register_buffer not found"); - } - - if (auto status = thr_register_buffer( - thr_context_, ThrBufferType::kThrBufferTypeAHardwareBuffer, ahwb, - tensor_buffer_size, &thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_register_buffer failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_register_buffer failed"); - } - - } else { - auto thr_register_buffer_with_offset = - southbound_.api().thr_register_buffer_with_offset; - if (!thr_register_buffer_with_offset) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_register_buffer_with_offset not found"); - } - - if (auto status = thr_register_buffer_with_offset( - thr_context_, ThrBufferType::kThrBufferTypeAHardwareBuffer, ahwb, - tensor_buffer_offset, tensor_buffer_size, &thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_register_buffer_with_offset failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_register_buffer_with_offset failed"); - } - } - - return thr_buffer_handle; -} - -litert::Expected LiteRtDispatchDeviceContextT::UnregisterTensorBuffer( - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto thr_unregister_buffer = southbound_.api().thr_unregister_buffer; - if (!thr_unregister_buffer) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_unregister_buffer not found"); - } - - ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; - if (auto status = thr_unregister_buffer(thr_context_, thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_unregister_buffer failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_unregister_buffer failed"); - } - - return {}; -} - -litert::Expected -LiteRtDispatchDeviceContextT::CreateGraph() { - auto thr_graph_create = southbound_.api().thr_graph_create; - if (!thr_graph_create) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_create not found"); - } - - ThrGraph* thr_graph = thr_graph_create(thr_context_); - if (!thr_graph) { - return Error(kLiteRtStatusErrorRuntimeFailure, "thr_graph_create failed"); - } - - return new LiteRtDispatchGraphT(southbound_, thr_graph, this); -} - -litert::Expected LiteRtDispatchDeviceContextT::DestroyGraph( - LiteRtDispatchGraph graph) { - auto thr_graph_delete = southbound_.api().thr_graph_delete; - if (!thr_graph_delete) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_delete not found"); - } - - thr_graphs_.erase(graph->thr_graph()); - - ThrGraph* thr_graph = graph->thr_graph(); - if (auto status = thr_graph_delete(thr_graph); status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_destroy failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, "thr_graph_destroy failed"); - } - - delete graph; - return {}; -} - -litert::Expected -LiteRtDispatchDeviceContextT::LoadExecutable( - LiteRtDispatchExecutableType type, const LiteRtMemBuffer* bytecode_buffer) { - auto thr_load_sq_container = southbound_.api().thr_load_sq_container; - if (!thr_load_sq_container) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_load_sq_container not found"); - } - - ThrSqContainerType thr_type; - switch (type) { - case kLiteRtDispatchExecutableTypeDspLibrary: - thr_type = 
kThrSqContainerTypeFunctionLibrary; - break; - case kLiteRtDispatchExecutableTypeMlModel: - thr_type = kThrSqContainerTypeMlModel; - break; - default: - LITERT_LOG(LITERT_ERROR, "Unexpected executable type: %d", type); - return Error(kLiteRtStatusErrorRuntimeFailure, - "Unexpected executable type"); - } - - ThrSqContainerHandle sq_handle; - ThrStatus status; - if (bytecode_buffer->fd >= 0 && - // Unfortunately thrLoadSqContainerFd doesn't support passing an - // offset. So if the offset is non-zero, we fallback to passing a CPU - // memory address right below. - (bytecode_buffer->offset == 0)) { - bool lazy_loading = false; - status = southbound_.api().thr_load_sq_container_fd( - thr_context_, thr_type, bytecode_buffer->fd, bytecode_buffer->size, - lazy_loading, &sq_handle); - } else { - auto bytecode_ptr = - static_cast(bytecode_buffer->base_addr) + - bytecode_buffer->offset; - status = southbound_.api().thr_load_sq_container( - thr_context_, thr_type, bytecode_ptr, bytecode_buffer->size, - &sq_handle); - } - if (status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_load_sq_container failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_load_sq_container failed"); - } - - return sq_handle; -} - -litert::Expected LiteRtDispatchDeviceContextT::UnloadExecutable( - LiteRtDispatchExecutableHandle exec_handle) { - auto thr_unload_sq_container = southbound_.api().thr_unload_sq_container; - if (!thr_unload_sq_container) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_unload_sq_container not found"); - } - - ThrSqContainerHandle sq_handle = exec_handle; - if (auto status = thr_unload_sq_container(thr_context_, sq_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_unload_sq_container failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_unload_sq_container failed"); - } - - return {}; -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h deleted file mode 100644 index 4a7074d49ede66..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
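
The RegisterTensorBuffer/UnregisterTensorBuffer pair implemented above is easy to leak when a caller returns early between the two calls. A hypothetical RAII guard, not part of the deleted sources, that keeps the pair balanced:

```cpp
#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h"

// Registers a tensor buffer with a device context on construction and
// guarantees the matching unregister call on destruction, so the handle
// cannot leak on early returns.
class ScopedBufferRegistration {
 public:
  ScopedBufferRegistration(LiteRtDispatchDeviceContext device_context,
                           LiteRtTensorBuffer buffer)
      : device_context_(device_context) {
    ok_ = LiteRtDispatchRegisterTensorBuffer(device_context_, buffer,
                                             &handle_) == kLiteRtStatusOk;
  }
  ~ScopedBufferRegistration() {
    if (ok_) {
      LiteRtDispatchUnregisterTensorBuffer(device_context_, handle_);
    }
  }
  ScopedBufferRegistration(const ScopedBufferRegistration&) = delete;
  ScopedBufferRegistration& operator=(const ScopedBufferRegistration&) = delete;

  bool ok() const { return ok_; }
  LiteRtTensorBufferHandle handle() const { return handle_; }

 private:
  LiteRtDispatchDeviceContext device_context_;
  LiteRtTensorBufferHandle handle_ = 0;
  bool ok_ = false;
};
```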
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ - -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -class LiteRtDispatchDeviceContextT { - public: - using Ptr = std::unique_ptr; - - ~LiteRtDispatchDeviceContextT(); - - static litert::Expected Create( - const litert::google_tensor::Southbound& southbound); - - litert::Expected RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer); - - litert::Expected UnregisterTensorBuffer( - LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected CreateGraph(); - litert::Expected DestroyGraph(LiteRtDispatchGraph graph); - - litert::Expected LoadExecutable( - LiteRtDispatchExecutableType type, - const LiteRtMemBuffer* bytecode_buffer); - - litert::Expected UnloadExecutable( - LiteRtDispatchExecutableHandle exec_handle); - - ThrContext* thr_context() { return thr_context_; } - - void add_graph(ThrGraph* graph) { thr_graphs_.insert(graph); } - - private: - explicit LiteRtDispatchDeviceContextT( - const litert::google_tensor::Southbound& southbound) - : southbound_(southbound) {} - - const litert::google_tensor::Southbound& southbound_; - ThrContext* thr_context_ = nullptr; - absl::flat_hash_set thr_graphs_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.cc deleted file mode 100644 index d3530b56d57f46..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.cc +++ /dev/null @@ -1,305 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h" - -#include -#include - -#include "absl/strings/string_view.h" -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using litert::Error; -using litert::Expected; - -namespace { - -// We store THR names in a global set as a workaround to b/369144429. 
-std::set* ThrNames = new std::set(); - -absl::string_view ThrNodeIdStr(LiteRtDispatchNodeId node_id) { - auto str = "node_" + std::to_string(node_id); - auto iter = ThrNames->find(str); - if (iter == ThrNames->end()) { - iter = ThrNames->insert(iter, str); - } - return *iter; -} - -} // namespace - -absl::string_view ThrEdgeIdStr(LiteRtDispatchEdgeId edge_id) { - auto str = "edge_" + std::to_string(edge_id); - auto iter = ThrNames->find(str); - if (iter == ThrNames->end()) { - iter = ThrNames->insert(iter, str); - } - return *iter; -} - -litert::Expected LiteRtDispatchGraphT::AddNode( - LiteRtDispatchNodeId node_id, LiteRtDispatchNodeType node_type) { - auto thr_graph_add_sq_node = southbound_.api().thr_graph_add_sq_node; - if (!thr_graph_add_sq_node) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_add_sq_node not found"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - ThrNodeType thr_node_type; - switch (node_type) { - case kLiteRtDispatchNodeTypeDsp: - thr_node_type = kThrNodeTypeDsp; - break; - case kLiteRtDispatchNodeTypeNpu: - thr_node_type = kThrNodeTypeNpu; - break; - default: - LITERT_LOG(LITERT_ERROR, "Unexpected node type: %d", node_type); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected node type"); - } - - if (auto status = - thr_graph_add_sq_node(thr_graph_, thr_node_id.data(), thr_node_type); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_add_sq_node failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_add_sq_node failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AddEdge( - LiteRtDispatchEdgeId edge_id) { - auto thr_graph_add_edge = southbound_.api().thr_graph_add_edge; - if (!thr_graph_add_edge) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_add_edge not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - ThrEdgeType thr_edge_type = kThrEdgeNoType; - if (auto status = - thr_graph_add_edge(thr_graph_, thr_edge_id.data(), thr_edge_type); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_add_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, "thr_graph_add_edge failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::ConnectNodeInput( - LiteRtDispatchNodeId node_id, int input_index, - LiteRtDispatchEdgeId edge_id) { - auto thr_graph_connect_node_input = - southbound_.api().thr_graph_connect_node_input; - if (!thr_graph_connect_node_input) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_connect_node_input not found"); - } - - int next_input_index = NextNodeInputIndex(node_id); - if (input_index != next_input_index) { - LITERT_LOG(LITERT_ERROR, "Unexpected input index %d, expected %d", - input_index, next_input_index); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected input index"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_graph_connect_node_input(thr_graph_, thr_node_id.data(), - thr_edge_id.data()); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_input_edge failed"); - } - - AddInputEdge(input_index, edge_id); - return {}; -} - -litert::Expected LiteRtDispatchGraphT::ConnectNodeOutput( - LiteRtDispatchNodeId node_id, int output_index, - LiteRtDispatchEdgeId edge_id) { - auto thr_graph_connect_node_output = - 
southbound_.api().thr_graph_connect_node_output; - if (!thr_graph_connect_node_output) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_connect_node_output not found"); - } - - int next_output_index = NextNodeOutputIndex(node_id); - if (output_index != next_output_index) { - LITERT_LOG(LITERT_ERROR, "Unexpected output index %d, expected %d", - output_index, next_output_index); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected output index"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_graph_connect_node_output( - thr_graph_, thr_node_id.data(), thr_edge_id.data()); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_output_edge failed"); - } - - AddOutputEdge(output_index, edge_id); - return {}; -} - -litert::Expected LiteRtDispatchGraphT::ConnectGraphInput( - int input_index, LiteRtDispatchEdgeId edge_id) { - int next_input_index = NextGraphInputIndex(); - if (input_index != next_input_index) { - LITERT_LOG(LITERT_ERROR, "Unexpected input index %d, expected %d", - input_index, next_input_index); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected input index"); - } - - auto thr_graph_set_input_edge = southbound_.api().thr_graph_set_input_edge; - if (!thr_graph_set_input_edge) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_input_edge not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_graph_set_input_edge(thr_graph_, thr_edge_id.data()); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_input_edge failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::ConnectGraphOutput( - int output_index, LiteRtDispatchEdgeId edge_id) { - int next_output_index = NextGraphOutputIndex(); - if (output_index != next_output_index) { - LITERT_LOG(LITERT_ERROR, "Unexpected output index %d, expected %d", - output_index, next_output_index); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected output index"); - } - - auto thr_graph_set_output_edge = southbound_.api().thr_graph_set_output_edge; - if (!thr_graph_set_output_edge) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_output_edge not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_graph_set_output_edge(thr_graph_, thr_edge_id.data()); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_output_edge failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AssignNodeFunction( - LiteRtDispatchNodeId node_id, LiteRtDispatchExecutableHandle exec_handle, - const char* function_name) { - auto thr_graph_assign_sq = southbound_.api().thr_graph_assign_sq; - if (!thr_graph_assign_sq) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_assign_sq not found"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - ThrSqContainerHandle sq_handle = exec_handle; - // An empty function name represent no function name being provided and - // therefore we must pass a nullptr to the call below, otherwise the SB API - // will expect a model with a signature. See b/378913220. 
- const char* function_name_ptr = - absl::string_view(function_name).empty() ? nullptr : function_name; - if (auto status = thr_graph_assign_sq(thr_graph_, thr_node_id.data(), - sq_handle, function_name_ptr); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_assign_sq failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_assign_sq failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AnnotateGraph(const char* key, - const char* value) { - auto thr_graph_annotate_graph = southbound_.api().thr_graph_annotate_graph; - if (!thr_graph_annotate_graph) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_graph not found"); - } - - if (auto status = thr_graph_annotate_graph(thr_graph_, key, value); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_graph failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_graph failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AnnotateNode( - LiteRtDispatchNodeId node_id, const char* key, const char* value) { - auto thr_graph_annotate_node = southbound_.api().thr_graph_annotate_node; - if (!thr_graph_annotate_node) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_node not found"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - if (auto status = - thr_graph_annotate_node(thr_graph_, thr_node_id.data(), key, value); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_node failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_node failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AnnotateEdge( - LiteRtDispatchEdgeId edge_id, const char* key, const char* value) { - auto thr_graph_annotate_edge = southbound_.api().thr_graph_annotate_edge; - if (!thr_graph_annotate_edge) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_edge not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = - thr_graph_annotate_edge(thr_graph_, thr_edge_id.data(), key, value); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_edge failed"); - } - - return {}; -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h deleted file mode 100644 index 6586e58f9bd637..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
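
The AddNode/AddEdge/Connect* methods defined above are driven by CreateFromBytecode later in this patch, which wraps a single executable in a one-node graph. A sketch of that wiring, assuming `graph` was obtained from LiteRtDispatchDeviceContextT::CreateGraph() and the vendor's internal graph header is available:

```cpp
#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h"

// One NPU node, one edge per graph input/output; every edge is connected on
// both the graph side and the node side, mirroring CreateFromBytecode.
litert::Expected<void> WireSingleNode(LiteRtDispatchGraph graph,
                                      int num_inputs, int num_outputs) {
  constexpr LiteRtDispatchNodeId kNode = 0;
  if (auto s = graph->AddNode(kNode, kLiteRtDispatchNodeTypeNpu); !s) {
    return s.Error();
  }
  LiteRtDispatchEdgeId next_edge = 0;
  for (int i = 0; i < num_inputs; ++i) {
    LiteRtDispatchEdgeId edge = next_edge++;
    if (auto s = graph->AddEdge(edge); !s) return s.Error();
    // Index order matters: the graph tracks the next expected index per node
    // and rejects out-of-order connections.
    if (auto s = graph->ConnectGraphInput(i, edge); !s) return s.Error();
    if (auto s = graph->ConnectNodeInput(kNode, i, edge); !s) return s.Error();
  }
  for (int i = 0; i < num_outputs; ++i) {
    LiteRtDispatchEdgeId edge = next_edge++;
    if (auto s = graph->AddEdge(edge); !s) return s.Error();
    if (auto s = graph->ConnectNodeOutput(kNode, i, edge); !s) return s.Error();
    if (auto s = graph->ConnectGraphOutput(i, edge); !s) return s.Error();
  }
  return {};
}
```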
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_ - -#include -#include - -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -class LiteRtDispatchGraphT { - public: - LiteRtDispatchGraphT(const litert::google_tensor::Southbound& southbound, - ThrGraph* thr_graph, - LiteRtDispatchDeviceContext device_context) - : southbound_(southbound), - thr_graph_(thr_graph), - device_context_(device_context) {} - - ThrGraph* thr_graph() { return thr_graph_; } - - LiteRtDispatchDeviceContext device_context() { return device_context_; } - - litert::Expected AddNode(LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type); - litert::Expected AddEdge(LiteRtDispatchEdgeId edge_id); - - litert::Expected InputEdge(int input_index) const { - return IoEdge(input_index, input_edges_); - } - - litert::Expected OutputEdge(int output_index) const { - return IoEdge(output_index, output_edges_); - } - - size_t NumOutputs() const { return output_edges_.size(); } - - litert::Expected ConnectNodeInput(LiteRtDispatchNodeId node_id, - int input_index, - LiteRtDispatchEdgeId edge_id); - - litert::Expected ConnectNodeOutput(LiteRtDispatchNodeId node_id, - int output_index, - LiteRtDispatchEdgeId edge_id); - - litert::Expected ConnectGraphInput(int input_index, - LiteRtDispatchEdgeId edge_id); - - litert::Expected ConnectGraphOutput(int output_index, - LiteRtDispatchEdgeId edge_id); - - litert::Expected AssignNodeFunction( - LiteRtDispatchNodeId node_id, LiteRtDispatchExecutableHandle exec_handle, - const char* function_name); - - litert::Expected AnnotateGraph(const char* key, const char* value); - - litert::Expected AnnotateNode(LiteRtDispatchNodeId node_id, - const char* key, const char* value); - - litert::Expected AnnotateEdge(LiteRtDispatchEdgeId edge_id, - const char* key, const char* value); - - private: - using NextNodeIoIndexMap = std::map; - using IoIndexToEdgeIdMap = std::map; - - int NextNodeOutputIndex(LiteRtDispatchNodeId node_id) { - return NextNodeIoIndex(node_id, next_node_output_index_); - } - - int NextGraphInputIndex() { return next_graph_input_index_++; } - - int NextGraphOutputIndex() { return next_graph_output_index_++; } - - int NextNodeIoIndex(LiteRtDispatchNodeId node_id, NextNodeIoIndexMap& map) { - return map[node_id]++; - } - - litert::Expected IoEdge( - int io_index, const IoIndexToEdgeIdMap& map) const { - auto iter = map.find(io_index); - if (iter == map.end()) { - return litert::Unexpected(kLiteRtStatusErrorNotFound, - "Unexpected graph input/output index"); - } - return iter->second; - } - - int NextNodeInputIndex(LiteRtDispatchNodeId node_id) { - return NextNodeIoIndex(node_id, next_node_input_index_); - } - - void AddInputEdge(int input_index, LiteRtDispatchEdgeId edge_id) { - input_edges_[input_index] = edge_id; - } - - void AddOutputEdge(int output_index, LiteRtDispatchEdgeId edge_id) { - output_edges_[output_index] = edge_id; - } - - const litert::google_tensor::Southbound& southbound_; - ThrGraph* thr_graph_; - LiteRtDispatchDeviceContext device_context_; - NextNodeIoIndexMap next_node_input_index_; - NextNodeIoIndexMap 
next_node_output_index_; - int next_graph_input_index_ = 0; - int next_graph_output_index_ = 0; - IoIndexToEdgeIdMap input_edges_; - IoIndexToEdgeIdMap output_edges_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc deleted file mode 100644 index ac0a845c56ea4b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc +++ /dev/null @@ -1,613 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h" - -#include - -#include "absl/strings/string_view.h" -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -using litert::Error; -using litert::Expected; -using litert::Unexpected; - -extern absl::string_view ThrEdgeIdStr(LiteRtDispatchEdgeId edge_id); - -namespace { - -constexpr const size_t kEdgeTpuPadding = 64; - -template -inline constexpr auto Pad(X x, Align align) { - return ((x + align - 1) / align) * align; -} - -} // namespace - -litert::Expected -LiteRtDispatchInvocationContextT::CreateFromBytecode( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs) { - auto graph = device_context->CreateGraph(); - if (!graph) { - return graph.Error(); - } - - LiteRtDispatchNodeId node_id = 0; - LiteRtDispatchNodeType node_type; - switch (exec_type) { - case kLiteRtDispatchExecutableTypeDspLibrary: - node_type = kLiteRtDispatchNodeTypeDsp; - break; - case 
kLiteRtDispatchExecutableTypeMlModel: - node_type = kLiteRtDispatchNodeTypeNpu; - break; - default: - LITERT_LOG(LITERT_ERROR, "Unexpected executable type: %d", exec_type); - return Error(kLiteRtStatusErrorInvalidArgument, - "Unexpected executable type"); - } - - if (auto status = (*graph)->AddNode(node_id, node_type); !status) { - return status.Error(); - } - - auto exec_handle = - device_context->LoadExecutable(exec_type, exec_bytecode_buffer); - if (!exec_handle) { - return exec_handle.Error(); - } - - if (auto status = - (*graph)->AssignNodeFunction(node_id, *exec_handle, function_name); - !status) { - return status.Error(); - } - - LiteRtDispatchEdgeId next_edge_id = 0; - - for (auto input_index = 0; input_index < num_inputs; ++input_index) { - LiteRtDispatchEdgeId edge_id = next_edge_id++; - if (auto status = (*graph)->AddEdge(edge_id); !status) { - return status.Error(); - } - if (auto status = (*graph)->ConnectGraphInput(input_index, edge_id); - !status) { - return status.Error(); - } - if (auto status = (*graph)->ConnectNodeInput(node_id, input_index, edge_id); - !status) { - return status.Error(); - } - } - - for (auto output_index = 0; output_index < num_outputs; ++output_index) { - LiteRtDispatchEdgeId edge_id = next_edge_id++; - if (auto status = (*graph)->AddEdge(edge_id); !status) { - return status.Error(); - } - if (auto status = - (*graph)->ConnectNodeOutput(node_id, output_index, edge_id); - !status) { - return status.Error(); - } - if (auto status = (*graph)->ConnectGraphOutput(output_index, edge_id); - !status) { - return status.Error(); - } - } - - auto invocation_context = CreateFromGraph(southbound, device_context, *graph); - if (!invocation_context) { - return invocation_context.Error(); - } - - (*invocation_context)->AttachExecutable(*exec_handle); - - return invocation_context; -} - -litert::Expected -LiteRtDispatchInvocationContextT::CreateFromGraph( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph) { - auto thr_invocation_context_get = southbound.api().thr_invocation_context_get; - if (!thr_invocation_context_get) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_get not found"); - } - - ThrGraph* thr_graph = graph->thr_graph(); - auto thr_icontext = - thr_invocation_context_get(thr_graph, device_context->thr_context()); - if (!thr_icontext) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_get failed"); - } - - device_context->add_graph(thr_graph); - return Ptr(new LiteRtDispatchInvocationContextT(southbound, thr_icontext, - device_context, graph)); -} - -LiteRtDispatchInvocationContextT::~LiteRtDispatchInvocationContextT() { - auto thr_invocation_context_delete = - southbound_.api().thr_invocation_context_delete; - if (!thr_invocation_context_delete) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_delete not found"); - } else { - ThrGraph* thr_graph = graph_->thr_graph(); - if (auto status = - thr_invocation_context_delete(thr_graph, thr_invocation_context_); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_delete failed: %d", - status); - } - } - - if (exec_handle_) { - device_context_->UnloadExecutable(*exec_handle_); - } -} - -namespace { - -Expected GetTensorBufferRequirements( - const LiteRtRankedTensorType& tensor_type) { - auto* tensor_strides = tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are 
not supported on GoogleTensor"); - } - - LiteRtTensorBufferType supported_tensor_buffer_types[] = { - kLiteRtTensorBufferTypeAhwb, - }; - int num_supported_tensor_buffer_types = - sizeof(supported_tensor_buffer_types) / - sizeof(supported_tensor_buffer_types[0]); - - auto buffer_size = litert::internal::GetNumPackedBytes(tensor_type); - if (!buffer_size) { - return Unexpected(buffer_size.Error()); - } - - size_t padded_buffer_size = Pad(*buffer_size, kEdgeTpuPadding); - - LiteRtTensorBufferRequirements requirements; - if (auto status = LiteRtCreateTensorBufferRequirements( - num_supported_tensor_buffer_types, supported_tensor_buffer_types, - padded_buffer_size, /*num_strides=*/0, /*strides=*/nullptr, - &requirements); - status != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create tensor buffer requirements"); - } - - return requirements; -} -} // namespace - -Expected -LiteRtDispatchInvocationContextT::GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type) { - return GetTensorBufferRequirements(tensor_type); -} - -Expected -LiteRtDispatchInvocationContextT::GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type) { - return GetTensorBufferRequirements(tensor_type); -} - -namespace { - -litert::Expected AttachBufferHelper( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context, - LiteRtDispatchEdgeId edge_id, - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto thr_invocation_context_attach_buffer = - southbound.api().thr_invocation_context_attach_buffer; - if (!thr_invocation_context_attach_buffer) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_attach_buffer not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - ThrContext* thr_context = invocation_context->device_context()->thr_context(); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; - if (auto status = thr_invocation_context_attach_buffer( - thr_icontext, thr_context, thr_edge_id.data(), thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_attach_buffer failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_attach_buffer failed"); - } - - return {}; -} - -} // namespace - -litert::Expected LiteRtDispatchInvocationContextT::AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = graph_->InputEdge(graph_input_index); result) { - auto edge_id = *result; - return AttachBufferHelper(southbound_, this, edge_id, tensor_buffer_handle); - } else { - return result.Error(); - } -} - -litert::Expected LiteRtDispatchInvocationContextT::AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = graph_->OutputEdge(graph_output_index); result) { - auto edge_id = *result; - return AttachBufferHelper(southbound_, this, edge_id, tensor_buffer_handle); - } else { - return result.Error(); - } -} - -namespace { - -litert::Expected DetachTensorBufferHelper( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context, - LiteRtDispatchEdgeId edge_id, - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto thr_invocation_context_detach_buffer = - southbound.api().thr_invocation_context_detach_buffer; - if 
(!thr_invocation_context_detach_buffer) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_detach_buffer not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - ThrContext* thr_context = invocation_context->device_context()->thr_context(); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; - if (auto status = thr_invocation_context_detach_buffer( - thr_icontext, thr_context, thr_edge_id.data(), thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_detach_buffer failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_detach_buffer failed"); - } - - return {}; -} - -} // namespace - -litert::Expected LiteRtDispatchInvocationContextT::DetachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = graph_->InputEdge(graph_input_index); result) { - auto edge_id = *result; - return DetachTensorBufferHelper(southbound_, this, edge_id, - tensor_buffer_handle); - } else { - return result.Error(); - } -} - -litert::Expected LiteRtDispatchInvocationContextT::DetachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = graph_->OutputEdge(graph_output_index); result) { - auto edge_id = *result; - return DetachTensorBufferHelper(southbound_, this, edge_id, - tensor_buffer_handle); - } else { - return result.Error(); - } -} - -namespace { - -litert::Expected PrepareForInvoke( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context, - bool create_output_sync_fence) { - auto thr_invocation_context_prepare_for_invoke = - southbound.api().thr_invocation_context_prepare_for_invoke; - if (!thr_invocation_context_prepare_for_invoke) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_prepare_for_invoke not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - if (auto status = thr_invocation_context_prepare_for_invoke( - thr_icontext, create_output_sync_fence); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, - "thr_invocation_context_prepare_for_invoke failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_prepare_for_invoke failed"); - } - - return {}; -} - -litert::Expected InvokeOnce( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context) { - auto thr_invocation_context_invoke_once = - southbound.api().thr_invocation_context_invoke_once; - if (!thr_invocation_context_invoke_once) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_invoke_once not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - if (auto status = thr_invocation_context_invoke_once(thr_icontext); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_invoke_once failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_invoke_once failed"); - } - - return {}; -} - -litert::Expected Wait( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context) { - auto thr_invocation_context_wait = - southbound.api().thr_invocation_context_wait; - if (!thr_invocation_context_wait) { - return 
Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_wait not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - if (auto status = thr_invocation_context_wait(thr_icontext); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_wait failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_wait failed"); - } - - return {}; -} - -} // namespace - -litert::Expected LiteRtDispatchInvocationContextT::Invoke() { - if (auto result = PrepareForInvoke(southbound_, this, - /*create_output_sync_fence=*/false); - !result) { - return result.Error(); - } - if (auto result = InvokeOnce(southbound_, this); !result) { - return result.Error(); - } - return Wait(southbound_, this); -} - -litert::Expected LiteRtDispatchInvocationContextT::AttachInputEvent( - int graph_input_index, LiteRtEvent input_event) { - int input_fence_fd; - if (auto status = LiteRtGetEventSyncFenceFd(input_event, &input_fence_fd); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get sync fence fd from event"); - } - - auto edge = graph_->InputEdge(graph_input_index); - if (!edge) { - LITERT_LOG(LITERT_ERROR, "Unexpected graph input index: %d", - graph_input_index); - return edge.Error(); - } - auto edge_id = *edge; - - auto thr_invocation_context_attach_input_buffer_sync_fence = - southbound_.api().thr_invocation_context_attach_input_buffer_sync_fence; - if (!thr_invocation_context_attach_input_buffer_sync_fence) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_attach_input_buffer_sync_fence not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_invocation_context_attach_input_buffer_sync_fence( - thr_invocation_context_, thr_edge_id.data(), input_fence_fd); - status != kThrStatusSuccess) { - LITERT_LOG( - LITERT_ERROR, - "thr_invocation_context_attach_input_buffer_sync_fence failed: %d", - status); - return Error( - kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_attach_input_buffer_sync_fence failed"); - } - - input_sync_fences_[thr_edge_id.data()] = input_fence_fd; - return {}; -} - -namespace { - -litert::Expected GetOutputEvent( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtEvent* output_event) { - auto edge = invocation_context->graph()->OutputEdge(graph_output_index); - if (!edge) { - LITERT_LOG(LITERT_ERROR, "Unexpected graph output index: %d", - graph_output_index); - return edge.Error(); - } - auto edge_id = *edge; - - auto thr_invocation_context_get_output_buffer_sync_fence = - southbound.api().thr_invocation_context_get_output_buffer_sync_fence; - if (!thr_invocation_context_get_output_buffer_sync_fence) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_get_output_buffer_sync_fence not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - int output_fence_fd; - if (auto status = thr_invocation_context_get_output_buffer_sync_fence( - thr_icontext, thr_edge_id.data(), &output_fence_fd); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, - "thr_invocation_context_get_output_buffer_sync_fence failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_get_output_buffer_sync_fence failed"); - } - - if (auto status = 
LiteRtCreateEventFromSyncFenceFd(
-          output_fence_fd, /*owns_fd=*/false, output_event);
-      status != kLiteRtStatusOk) {
-    return Error(status, "Failed to create event from sync fence fd");
-  }
-
-  return {};
-}
-
-}  // namespace
-
-litert::Expected<void> LiteRtDispatchInvocationContextT::InvokeAsync(
-    int num_output_events, LiteRtEvent* output_events) {
-  if (num_output_events != graph_->NumOutputs()) {
-    LITERT_LOG(LITERT_ERROR, "Unexpected number of output events: %d",
-               num_output_events);
-    return Error(kLiteRtStatusErrorInvalidArgument,
-                 "Unexpected number of output events");
-  }
-
-  if (auto status = PrepareForInvoke(southbound_, this,
-                                     /*create_output_sync_fence=*/true);
-      !status) {
-    return status.Error();
-  }
-
-  if (auto status = InvokeOnce(southbound_, this); !status) {
-    return status.Error();
-  }
-
-  // Detach input fences.
-  auto thr_invocation_context_detach_input_buffer_sync_fence =
-      southbound_.api().thr_invocation_context_detach_input_buffer_sync_fence;
-  if (!thr_invocation_context_detach_input_buffer_sync_fence) {
-    return Error(
-        kLiteRtStatusErrorRuntimeFailure,
-        "thr_invocation_context_detach_input_buffer_sync_fence not found");
-  }
-  for (const auto& p : input_sync_fences_) {
-    const auto& thr_edge_id = p.first;
-    auto input_fence_fd = p.second;
-    if (auto status = thr_invocation_context_detach_input_buffer_sync_fence(
-            thr_invocation_context_, thr_edge_id.data(), input_fence_fd);
-        status != kThrStatusSuccess) {
-      return Error(
-          kLiteRtStatusErrorRuntimeFailure,
-          "thr_invocation_context_detach_input_buffer_sync_fence failed");
-    }
-  }
-  input_sync_fences_.clear();
-
-  // Extract output events.
-  for (auto graph_output_index = 0; graph_output_index < num_output_events;
-       ++graph_output_index) {
-    if (auto status = GetOutputEvent(southbound_, this, graph_output_index,
-                                     &output_events[graph_output_index]);
-        !status) {
-      LITERT_LOG(LITERT_ERROR, "Failed to get event for output %d",
-                 graph_output_index);
-      return status.Error();
-    }
-  }
-
-  return {};
-}
-
-litert::Expected<void> LiteRtDispatchInvocationContextT::StartMetricsCollection(
-    int detail_level) {
-  auto thr_invocation_context_start_metrics_collection =
-      southbound_.api().thr_invocation_context_start_metrics_collection;
-  if (!thr_invocation_context_start_metrics_collection) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "thr_invocation_context_start_metrics_collection not found");
-  }
-  if (auto status = thr_invocation_context_start_metrics_collection(
-          thr_invocation_context_, detail_level);
-      status != kThrStatusSuccess) {
-    LITERT_LOG(LITERT_ERROR,
-               "thr_invocation_context_start_metrics_collection failed: %d",
-               status);
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "thr_invocation_context_start_metrics_collection failed");
-  }
-  return {};
-}
-
-litert::Expected<void> LiteRtDispatchInvocationContextT::StopMetricsCollection(
-    LiteRtDispatchMetrics* metrics) {
-  auto thr_invocation_context_stop_metrics_collection =
-      southbound_.api().thr_invocation_context_stop_metrics_collection;
-  if (!thr_invocation_context_stop_metrics_collection) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "thr_invocation_context_stop_metrics_collection not found");
-  }
-  ThrInvocationMetrics thr_metrics{.version = 0};
-  if (auto status = thr_invocation_context_stop_metrics_collection(
-          thr_invocation_context_, &thr_metrics);
-      status != kThrStatusSuccess) {
-    LITERT_LOG(LITERT_ERROR,
-               "thr_invocation_context_stop_metrics_collection failed: %d",
-               status);
-    *metrics = new LiteRtDispatchMetricsT(/*num_metrics=*/0,
- /*metric_names=*/nullptr, - /*metric_values=*/nullptr); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_stop_metrics_collection failed"); - } - *metrics = new LiteRtDispatchMetricsT(thr_metrics.num_metrics, - thr_metrics.metric_keys, - thr_metrics.metric_values); - return {}; -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h deleted file mode 100644 index 8cbae593d0874c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ - -#include -#include -#include -#include -#include - -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -class LiteRtDispatchInvocationContextT { - public: - using Ptr = std::unique_ptr; - - static litert::Expected CreateFromBytecode( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs); - - static litert::Expected CreateFromGraph( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph); - - ~LiteRtDispatchInvocationContextT(); - - litert::Expected GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type); - litert::Expected GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type); - - litert::Expected AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle); - litert::Expected AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected DetachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle); - litert::Expected DetachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected Invoke(); - litert::Expected InvokeAsync(int num_output_events, - 
LiteRtEvent* output_events); - litert::Expected StartMetricsCollection(int detail_level); - litert::Expected StopMetricsCollection(LiteRtDispatchMetrics* metrics); - - litert::Expected AttachInputEvent(int graph_input_index, - LiteRtEvent input_event); - - ThrInvocationContext* thr_invocation_context() { - return thr_invocation_context_; - } - - LiteRtDispatchDeviceContext device_context() { return device_context_; } - - LiteRtDispatchGraph graph() { return graph_; } - - private: - LiteRtDispatchInvocationContextT( - const litert::google_tensor::Southbound& southbound, - ThrInvocationContext* thr_invocation_context, - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph) - : southbound_(southbound), - thr_invocation_context_(thr_invocation_context), - device_context_(device_context), - graph_(graph) {} - - void AttachExecutable(LiteRtDispatchExecutableHandle exec_handle) { - exec_handle_ = exec_handle; - } - - const litert::google_tensor::Southbound& southbound_; - ThrInvocationContext* thr_invocation_context_; - LiteRtDispatchDeviceContext device_context_; - LiteRtDispatchGraph graph_; - std::optional exec_handle_; - std::map input_sync_fences_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h deleted file mode 100644 index a33a69d4adc237..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_METRICS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_METRICS_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -class LiteRtDispatchMetricsT { - public: - // Construct a LiteRtDispatchMetricsT object using C-style arrays and strings. - // `metric_names` is an array of C-style strings representing metric names. - // `metric_values` is an array of int64_t values representing metric values. - // Both `metric_names` and `metric_values` have `num_metrics` elements. - // - // NOTE: The values in the arrays are copied into the LiteRtDispatchMetricsT. 
-  LiteRtDispatchMetricsT(int num_metrics, const char** metric_names,
-                         const int64_t* metric_values)
-      : metric_names_(metric_names, metric_names + num_metrics),
-        metric_values_(metric_values, metric_values + num_metrics) {}
-  int GetNumMetrics() const { return metric_names_.size(); }
-  LiteRtMetric GetMetric(int metric_index) const {
-    if (metric_index < 0 || metric_index >= GetNumMetrics()) {
-      return LiteRtMetric{.name = "invalid_metric",
-                          .value = LiteRtAny{.type = kLiteRtAnyTypeNone}};
-    }
-    return LiteRtMetric{
-        .name = metric_names_[metric_index].c_str(),
-        .value =
-            LiteRtAny{
-                .type = kLiteRtAnyTypeInt,
-                .int_value = metric_values_[metric_index],
-            },
-    };
-  }
-
- private:
-  const std::vector metric_names_;
-  const std::vector metric_values_;
-};
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_METRICS_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc
deleted file mode 100644
index e103c289d5820b..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc
+++ /dev/null
@@ -1,163 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h"
-
-#include
-
-#include
-#include
-#include
-
-#include "third_party/odml/infra/southbound/sb_api.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-#define Load(H, S) \
-  H = reinterpret_cast(::dlsym(dlib_handle_, #S)); \
-  if (!H) { \
-    LITERT_LOG(LITERT_WARNING, "Failed to load symbol %s: %s", #S, \
-               ::dlerror()); \
-  }
-
-namespace litert {
-namespace google_tensor {
-
-namespace {
-
-// The SouthBound APIs are implemented in the EdgeTPU libraries.
-// They used to be implemented in libedgetpu_util.so and have been moved to
-// libedgetpu_litert.so in newer Android builds.
-constexpr const char* kLiteRtLibPath = "/vendor/lib64/libedgetpu_litert.so";
-constexpr const char* kEdgeTpuUtilLibPath = "/vendor/lib64/libedgetpu_util.so";
-
-}  // namespace
-
-Southbound::Southbound() : api_(new ThrFunctions) {}
-
-Southbound::~Southbound() {
-  if (dlib_handle_) {
-    ::dlclose(dlib_handle_);
-  }
-}
-
-Expected Southbound::Create(
-    std::optional shared_library_dir) {
-  Ptr southbound(new Southbound);
-  if (auto status = southbound->LoadSymbols(shared_library_dir); !status) {
-    return Unexpected(status.Error());
-  }
-
-  return southbound;
-}
-
-Expected Southbound::LoadSymbols(
-    std::optional shared_library_dir) {
-  // Always load the Southbound API library from the vendor partition.
-  (void)shared_library_dir;
-
-  // Try loading the new EdgeTPU LiteRT library first. If it fails, it might be
-  // because the Android build is too old.
In that case, load the old EdgeTPU - // utility library. - dlib_handle_ = ::dlopen(kLiteRtLibPath, RTLD_NOW | RTLD_LOCAL); - if (!dlib_handle_) { - dlib_handle_ = ::dlopen(kEdgeTpuUtilLibPath, RTLD_NOW | RTLD_LOCAL); - if (!dlib_handle_) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to load Southbound shared library"); - } - } - - // Binds all supported symbols from the shared library to the function - // pointers. - Load(api_->thr_initialize, thrInitialize); - - Load(api_->thr_get_vendor_api_version, thrGetVendorApiVersion); - Load(api_->thr_get_vendor_id, thrGetVendorId); - - Load(api_->thr_context_create, thrContextCreate); - Load(api_->thr_context_delete, thrContextDelete); - - Load(api_->thr_graph_create, thrGraphCreate); - Load(api_->thr_graph_delete, thrGraphDelete); - - Load(api_->thr_graph_add_edge, thrGraphAddEdge); - Load(api_->thr_graph_add_sq_node, thrGraphAddSqNode); - - Load(api_->thr_graph_connect_node_input, thrGraphConnectNodeInput); - Load(api_->thr_graph_connect_node_output, thrGraphConnectNodeOutput); - - Load(api_->thr_graph_set_input_edge, thrGraphSetInputEdge); - Load(api_->thr_graph_set_output_edge, thrGraphSetOutputEdge); - - Load(api_->thr_graph_annotate_graph, thrGraphAnnotateGraph); - Load(api_->thr_graph_annotate_edge, thrGraphAnnotateEdge); - Load(api_->thr_graph_annotate_node, thrGraphAnnotateNode); - - Load(api_->thr_load_sq_container, thrLoadSqContainer); - Load(api_->thr_load_sq_container_fd, thrLoadSqContainerFd); - Load(api_->thr_load_sq_container_file, thrLoadSqContainerFile); - Load(api_->thr_unload_sq_container, thrUnloadSqContainer); - - Load(api_->thr_graph_assign_sq, thrGraphAssignSq); - Load(api_->thr_sq_query_scratch_pad, thrSqQueryScratchPad); - Load(api_->thr_sq_attach_scratch_pad_buffer, thrSqAttachScratchPadBuffer); - - Load(api_->thr_register_buffer, thrRegisterBuffer); - Load(api_->thr_register_buffer_with_offset, thrRegisterBufferWithOffset); - Load(api_->thr_unregister_buffer, thrUnregisterBuffer); - - Load(api_->thr_invocation_context_get, thrInvocationContextGet); - Load(api_->thr_invocation_context_delete, thrInvocationContextDelete); - - Load(api_->thr_invocation_context_attach_buffer, - thrInvocationContextAttachBuffer); - Load(api_->thr_invocation_context_detach_buffer, - thrInvocationContextDetachBuffer); - - Load(api_->thr_invocation_context_prepare_for_invoke, - thrInvocationContextPrepareForInvoke); - Load(api_->thr_invocation_context_invoke_once, - thrInvocationContextInvokeOnce); - Load(api_->thr_invocation_context_wait, thrInvocationContextWait); - - Load(api_->thr_invocation_context_attach_input_buffer_sync_fence, - thrInvocationContextAttachInputBufferSyncFence); - Load(api_->thr_invocation_context_get_output_buffer_sync_fence, - thrInvocationContextGetOutputBufferSyncFence); - Load(api_->thr_invocation_context_detach_input_buffer_sync_fence, - thrInvocationContextDetachInputBufferSyncFence); - - Load(api_->thr_invocation_context_query_node_scratch_pad, - thrInvocationContextQueryNodeScratchPad); - Load(api_->thr_invocation_context_attach_scratch_pad_buffer, - thrInvocationContextAttachScratchPadBuffer); - - Load(api_->thr_invocation_context_start_metrics_collection, - thrInvocationContextStartMetricsCollection); - Load(api_->thr_invocation_context_stop_metrics_collection, - thrInvocationContextStopMetricsCollection); - - Load(api_->thr_vendor_set_system_attribute_str, - thrVendorSetSystemAttributeStr); - Load(api_->thr_vendor_set_system_attribute_int64, - thrVendorSetSystemAttributeInt64); - - 
LITERT_LOG(LITERT_INFO, "SouthBound symbols loaded");
-  return {};
-}
-
-}  // namespace google_tensor
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h
deleted file mode 100644
index d3ab7367c7e665..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h
+++ /dev/null
@@ -1,133 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_
-
-#include
-#include
-#include
-
-#include "third_party/odml/infra/southbound/sb_api.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-
-namespace litert::google_tensor {
-
-class Southbound {
- public:
-  using Ptr = std::unique_ptr;
-  struct ThrFunctions;
-
-  Southbound(Southbound&) = delete;
-  Southbound(Southbound&&) = delete;
-  Southbound& operator=(const Southbound&) = delete;
-  Southbound& operator=(Southbound&&) = delete;
-
-  ~Southbound();
-
-  static Expected Create(std::optional shared_library_dir);
-
-  const ThrFunctions& api() const { return *api_; }
-
- private:
-  Southbound();
-  Expected LoadSymbols(std::optional shared_library_dir);
-
-  void* dlib_handle_ = nullptr;
-  std::unique_ptr api_;
-};
-
-// A convenient struct for holding function pointers to SouthBound symbols.
-// These function pointers are loaded from the shared library on device at
-// runtime.
-struct Southbound::ThrFunctions { - decltype(&thrInitialize) thr_initialize = nullptr; - - decltype(&thrGetVendorApiVersion) thr_get_vendor_api_version = nullptr; - decltype(&thrGetVendorId) thr_get_vendor_id = nullptr; - - decltype(&thrContextCreate) thr_context_create = nullptr; - decltype(&thrContextDelete) thr_context_delete = nullptr; - - decltype(&thrGraphCreate) thr_graph_create = nullptr; - decltype(&thrGraphDelete) thr_graph_delete = nullptr; - - decltype(&thrGraphAddEdge) thr_graph_add_edge = nullptr; - decltype(&thrGraphAddSqNode) thr_graph_add_sq_node = nullptr; - - decltype(&thrGraphConnectNodeInput) thr_graph_connect_node_input = nullptr; - decltype(&thrGraphConnectNodeOutput) thr_graph_connect_node_output = nullptr; - - decltype(&thrGraphSetInputEdge) thr_graph_set_input_edge = nullptr; - decltype(&thrGraphSetOutputEdge) thr_graph_set_output_edge = nullptr; - - decltype(&thrGraphAnnotateGraph) thr_graph_annotate_graph = nullptr; - decltype(&thrGraphAnnotateEdge) thr_graph_annotate_edge = nullptr; - decltype(&thrGraphAnnotateNode) thr_graph_annotate_node = nullptr; - - decltype(&thrLoadSqContainer) thr_load_sq_container = nullptr; - decltype(&thrLoadSqContainerFd) thr_load_sq_container_fd = nullptr; - decltype(&thrLoadSqContainerFile) thr_load_sq_container_file = nullptr; - decltype(&thrUnloadSqContainer) thr_unload_sq_container = nullptr; - - decltype(&thrGraphAssignSq) thr_graph_assign_sq = nullptr; - decltype(&thrSqQueryScratchPad) thr_sq_query_scratch_pad = nullptr; - decltype(&thrSqAttachScratchPadBuffer) thr_sq_attach_scratch_pad_buffer = - nullptr; - - decltype(&thrRegisterBuffer) thr_register_buffer = nullptr; - decltype(&thrRegisterBufferWithOffset) thr_register_buffer_with_offset = - nullptr; - decltype(&thrUnregisterBuffer) thr_unregister_buffer = nullptr; - - decltype(&thrInvocationContextGet) thr_invocation_context_get = nullptr; - decltype(&thrInvocationContextDelete) thr_invocation_context_delete = nullptr; - - decltype(&thrInvocationContextAttachBuffer) - thr_invocation_context_attach_buffer = nullptr; - decltype(&thrInvocationContextDetachBuffer) - thr_invocation_context_detach_buffer = nullptr; - - decltype(&thrInvocationContextPrepareForInvoke) - thr_invocation_context_prepare_for_invoke = nullptr; - decltype(&thrInvocationContextInvokeOnce) thr_invocation_context_invoke_once = - nullptr; - decltype(&thrInvocationContextWait) thr_invocation_context_wait = nullptr; - - decltype(&thrInvocationContextAttachInputBufferSyncFence) - thr_invocation_context_attach_input_buffer_sync_fence = nullptr; - decltype(&thrInvocationContextGetOutputBufferSyncFence) - thr_invocation_context_get_output_buffer_sync_fence = nullptr; - decltype(&thrInvocationContextDetachInputBufferSyncFence) - thr_invocation_context_detach_input_buffer_sync_fence = nullptr; - - decltype(&thrInvocationContextQueryNodeScratchPad) - thr_invocation_context_query_node_scratch_pad = nullptr; - decltype(&thrInvocationContextAttachScratchPadBuffer) - thr_invocation_context_attach_scratch_pad_buffer = nullptr; - - decltype(&thrInvocationContextStartMetricsCollection) - thr_invocation_context_start_metrics_collection = nullptr; - decltype(&thrInvocationContextStopMetricsCollection) - thr_invocation_context_stop_metrics_collection = nullptr; - - decltype(&thrVendorSetSystemAttributeStr) - thr_vendor_set_system_attribute_str = nullptr; - decltype(&thrVendorSetSystemAttributeInt64) - thr_vendor_set_system_attribute_int64 = nullptr; -}; - -} // namespace litert::google_tensor - -#endif // 
TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD deleted file mode 100644 index 73d1e8e484ebdf..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/vendors/mediatek:mediatek_build_defs.bzl", "litert_cc_lib_with_mtk") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -litert_cc_lib_with_mtk( - name = "neuron_adapter_api", - srcs = [ - "neuron_adapter_api.cc", - ], - hdrs = [ - "neuron_adapter_api.h", - ], - tags = [ - # Don't build/test in OS until neuron is available. - "nobuilder", - "notap", - ], - deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD deleted file mode 100644 index 11a4f96268b748..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib", "litert_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:__subpackages__"], -) - -litert_dynamic_lib( - name = "compiler_plugin", - srcs = ["compiler_plugin.cc"], - hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], - export_litert_only = True, - shared_lib_name = "compiler_plugin_so", - so_name = "libLiteRtCompilerPlugin_MediaTek.so", - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - ungrte = True, - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":compile_model", - ":create_model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:common_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/schema:mediatek_litert_schema", - "//tensorflow/lite/experimental/litert/vendors/mediatek/schema:neuron_litert_schema", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "create_model", - srcs = ["create_model.cc"], - hdrs = ["create_model.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:add_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:batch_matmul_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:common_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:concat_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:fully_connected_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:gelu_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:mean_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:mul_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:operand_map", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:reshape_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:rsqrt_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:softmax_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:sub_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:transpose_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/schema:mediatek_litert_schema", - ], -) - -cc_library( - name = "compile_model", - srcs = ["compile_model.cc"], - hdrs = ["compile_model.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek/schema:mediatek_litert_schema", - ], -) - -litert_test( - name = "compiler_plugin_test", - srcs = [ - "compiler_plugin_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - linkstatic = True, - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - "nobuilder", - "no_oss", - "nosan", - ], - # Currently this test can only be run on Android because we don't have x86 shared libraries for - # MTK. - target_compatible_with = select({ - "@platforms//os:android": [], - "@platforms//os:linux": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - ungrte = True, - use_sys_malloc = True, - deps = [ - ":compiler_plugin", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers_oss", - "//tensorflow/lite/experimental/litert/test:test_models", - "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc deleted file mode 100644 index 15a5485b20dba4..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
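The CompileModel implementation below is the middle stage of a three-step pipeline: CreateModel builds a NeuronModel from a LiteRt partition, CompileModel turns it into a NeuronCompilation, and the caller serializes the compiled network. A minimal sketch of that end-to-end flow, mirroring the CompilePartition helper that appears in compiler_plugin.cc later in this patch; CompileToBytecode is a hypothetical name introduced here for illustration only:

    #include <cstddef>
    #include <optional>
    #include <string>
    #include <vector>

    #include "tensorflow/lite/experimental/litert/c/litert_common.h"
    #include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
    #include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h"
    #include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h"

    namespace litert::mediatek {

    // Hypothetical wrapper for illustration; the production flow lives in
    // CompilePartition in compiler_plugin.cc further down in this patch.
    Expected<std::vector<char>> CompileToBytecode(
        const NeuronAdapterApi& api, const Subgraph& partition,
        const std::string& name, std::optional<std::string> soc_model) {
      auto model = CreateModel(api, partition, name);  // partition -> NeuronModel
      if (!model) return model.Error();
      auto compilation = CompileModel(api, model->get(), soc_model);
      if (!compilation) return compilation.Error();
      // Serialize the compiled network into a caller-owned buffer.
      size_t size = 0;
      if (api.api().compilation_get_compiled_network_size(
              compilation->get(), &size) != NEURON_NO_ERROR) {
        return Error(kLiteRtStatusErrorRuntimeFailure,
                     "Failed to get compiled network size");
      }
      std::vector<char> bytecode(size);
      if (api.api().compilation_store_compiled_network(
              compilation->get(), bytecode.data(), bytecode.size()) !=
          NEURON_NO_ERROR) {
        return Error(kLiteRtStatusErrorRuntimeFailure,
                     "Failed to store compiled network");
      }
      return bytecode;
    }

    }  // namespace litert::mediatek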
-
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h"
-
-#include
-#include
-
-#include "neuron/api/NeuronAdapter.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-Expected CompileModel(
-    const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model,
-    std::optional soc_model) {
-#if defined(__ANDROID__)
-  if (soc_model) {
-    return Error(kLiteRtStatusErrorInvalidArgument,
-                 "JIT compilation for a specific SoC is not supported");
-  }
-#endif
-
-  // Per MediaTek recommendation, Compilation_create,
-  // Compilation_createWithOptions, and Compilation_setOptimizationString
-  // should be used as follows:
-  // - AOT Compilation: Compilation_createWithOptions only
-  // - JIT Compilation: Compilation_create and Compilation_setOptimizationString
-  // The code below takes care of those conditions.
-
-  // NOLINTBEGIN
-  const auto compile_options =
-#if __ANDROID__
-      std::string(neuron_adapter_api.JitCompileOptions());
-#else
-      std::string(neuron_adapter_api.AotCompileOptions());
-#endif
-  // NOLINTEND
-
-  auto compilation =
-#if __ANDROID__
-      neuron_adapter_api.CreateCompilation(model);
-#else
-      neuron_adapter_api.CreateCompilation(model, compile_options);
-#endif
-  if (!compilation) {
-    return compilation.Error();
-  }
-
-  if (neuron_adapter_api.api().compilation_set_priority(
-          compilation->get(), NEURON_PRIORITY_HIGH) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to set compilation priority");
-  }
-
-  if (neuron_adapter_api.api().compilation_set_preference(
-          compilation->get(), NEURON_PREFER_SUSTAINED_SPEED) !=
-      NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to set compilation preference");
-  }
-
-#if __ANDROID__
-  if (!compile_options.empty()) {
-    if (auto status =
-            neuron_adapter_api.api().compilation_set_optimization_string(
-                compilation->get(), compile_options.c_str());
-        status != NEURON_NO_ERROR) {
-      LITERT_LOG(LITERT_INFO,
-                 "NeuronCompilation_setOptimizationString failed with error %d",
-                 status);
-      return Error(kLiteRtStatusErrorRuntimeFailure,
-                   "Failed to set optimization string");
-    }
-  }
-#endif
-
-  if (auto status =
-          neuron_adapter_api.api().compilation_finish(compilation->get());
-      status != NEURON_NO_ERROR) {
-    LITERT_LOG(LITERT_INFO, "NeuronCompilation_finish failed with error %d",
-               status);
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to finish compilation");
-  }
-
-  return compilation;
-}
-
-}  // namespace litert::mediatek
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h
deleted file mode 100644
index 3e30c0d8451b7b..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_COMPILE_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_COMPILE_MODEL_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected CompileModel( - const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model, - std::optional soc_model); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_COMPILE_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc deleted file mode 100644 index 1f92fb4168f124..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
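The plugin implementation below leans on the litert::Expected error-propagation idiom used throughout this patch: every fallible step returns an Expected, the caller tests it, and failures are forwarded with .Error(). A self-contained sketch of the same shape, using C++23 std::expected as a stand-in for the LiteRT type (which carries a LiteRtStatus code rather than a string):

    // Build with -std=c++23; std::expected here is only an analogy for
    // litert::Expected / litert::Error.
    #include <expected>
    #include <iostream>
    #include <string>

    using ExpectedInt = std::expected<int, std::string>;

    ExpectedInt ParsePositive(int raw) {
      if (raw <= 0) return std::unexpected("value must be positive");
      return raw;
    }

    ExpectedInt Doubled(int raw) {
      auto value = ParsePositive(raw);
      // Mirrors the `if (!status) return status.Error();` pattern below.
      if (!value) return std::unexpected(value.error());
      return *value * 2;
    }

    int main() {
      if (auto result = Doubled(-3); !result) {
        std::cout << "error: " << result.error() << "\n";
      }
    }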
- -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema_generated.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h" - -// -// Configurations -// - -using litert::Error; -using litert::Expected; -using litert::mediatek::NeuronAdapterApi; -using litert::mediatek::NeuronCompilationPtr; -using litert::mediatek::NeuronModelPtr; - -namespace { - -constexpr char kPluginManufacturer[] = "MediaTek"; - -// clang-format off -constexpr std::pair kPluginSocModels[] = { - {"mt6853", "mt6853"}, - {"mt6877", "mt6877"}, - {"mt6878", "mt6878"}, - {"mt6879", "mt6879"}, - {"mt6886", "mt6886"}, - {"mt6893", "mt6893"}, - {"mt6895", "mt6895"}, - {"mt6897", "mt6897"}, - {"mt6983", "mt6983"}, - {"mt6985", "mt6985"}, - {"mt6989", "mt6989"}, - {"mt6991", "mt6991"}, -}; - -constexpr LiteRtOpCode kSupportedOps[] = { - kLiteRtOpCodeTflAdd, - kLiteRtOpCodeTflMul, - kLiteRtOpCodeTflBatchMatmul, - kLiteRtOpCodeTflFullyConnected, - kLiteRtOpCodeTflReshape, - kLiteRtOpCodeTflTranspose, - kLiteRtOpCodeTflRsqrt, - kLiteRtOpCodeTflConcatenation, - kLiteRtOpCodeTflQuantize, - kLiteRtOpCodeTflSlice, - kLiteRtOpCodeTflSub, - kLiteRtOpCodeTflTanh, - kLiteRtOpCodeTflSoftmax, - kLiteRtOpCodeTflMean, - kLiteRtOpCodeTflGelu, -}; -// clang-format on - -constexpr auto kNumPluginSocModels = - sizeof(kPluginSocModels) / sizeof(kPluginSocModels[0]); - -std::optional FindSocModel(absl::string_view soc_model_name) { - std::optional soc_model; - for (auto i = 0; i < kNumPluginSocModels; ++i) { - if (soc_model_name == kPluginSocModels[i].first) { - soc_model = kPluginSocModels[i].second; - break; - } - } - return soc_model; -} - -} // namespace - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { - if (api_version == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - api_version->major = LITERT_API_VERSION_MAJOR; - api_version->minor = LITERT_API_VERSION_MINOR; - api_version->patch = LITERT_API_VERSION_PATCH; - return kLiteRtStatusOk; -} - -const char* LiteRtGetCompilerPluginSocManufacturer() { - return kPluginManufacturer; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( - LiteRtCompilerPlugin compiler_plugin, - LiteRtHwAccelerators* supported_hardware) { - if (!compiler_plugin || !supported_hardware) { - return kLiteRtStatusErrorInvalidArgument; - } - *supported_hardware = kLiteRtHwAcceleratorNpu; - return kLiteRtStatusOk; -} - -LiteRtStatus 
LiteRtGetNumCompilerPluginSupportedSocModels(
-    LiteRtCompilerPlugin compiler_plugin,
-    LiteRtParamIndex* num_supported_soc_models) {
-  if (!compiler_plugin || !num_supported_soc_models) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *num_supported_soc_models = kNumPluginSocModels;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel(
-    LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx,
-    const char** soc_model_name) {
-  if (!compiler_plugin || !soc_model_name) {
-    return kLiteRtStatusErrorInvalidArgument;
-  } else if (soc_model_idx < 0 || soc_model_idx >= kNumPluginSocModels) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *soc_model_name = kPluginSocModels[soc_model_idx].first;
-  return kLiteRtStatusOk;
-}
-
-//
-// Compiled Result Definition
-//
-
-// TODO: Revisit this struct after we extend the compiler plugin API to return
-// results with more than a single bytecode.
-struct LiteRtCompiledResultT {
-  std::vector graph_names;
-  neuron::BytecodeBuilder bytebuilder;
-};
-
-LiteRtStatus LiteRtCompiledResultNumByteCodeModules(
-    LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_byte_code) {
-  if (!compiled_result || !num_byte_code) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *num_byte_code = compiled_result->graph_names.size();
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetCompiledResultByteCode(
-    LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx,
-    const void** byte_code, size_t* byte_code_size) {
-  if (!compiled_result || !byte_code || !byte_code_size ||
-      (byte_code_idx >= compiled_result->graph_names.size())) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *byte_code = compiled_result->bytebuilder.GetBytecode().first;
-  *byte_code_size = compiled_result->bytebuilder.GetBytecode().second;
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetCompiledResultCallInfo(
-    LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx,
-    const void** call_info, size_t* call_info_size,
-    LiteRtParamIndex* byte_code_idx) {
-  if (!compiled_result || !call_info || !call_info_size) {
-    return kLiteRtStatusErrorInvalidArgument;
-  } else if (call_idx >= compiled_result->graph_names.size()) {
-    return kLiteRtStatusErrorIndexOOB;
-  }
-
-  auto& graph_name = compiled_result->graph_names[call_idx];
-  *call_info = graph_name.data();
-  *call_info_size = graph_name.size();
-  // MTK should have one byte code per call.
-  *byte_code_idx = call_idx;
-
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus LiteRtGetNumCompiledResultCalls(
-    LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) {
-  if (!compiled_result || !num_calls) {
-    return kLiteRtStatusErrorInvalidArgument;
-  }
-  *num_calls = compiled_result->graph_names.size();
-  return kLiteRtStatusOk;
-}
-
-void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) {
-  delete compiled_result;
-}
-
-//
-// Plugin Definition
-//
-
-// Plugins can hold state.
-struct LiteRtCompilerPluginT {}; - -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values) { - // IMPLEMENT ME - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { - auto* plugin = new LiteRtCompilerPluginT; - *compiler_plugin = plugin; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { - delete compiler_plugin; -} - -namespace { - -// TODO update this function to match the new legalizations. -bool IsOpSupported(const litert::Op& op) { - // NOTE: Currently we are demoing by just mapping simple f32 mul ops. Use a - // very loose guard for now -- only checking if op code is supported. - for (auto supported_op : kSupportedOps) { - if (op.Code() == supported_op && - litert::mediatek::VerifyCommonOp(op, op.Code())) { - return true; - } - } - return false; -} - -} // namespace - -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - litert::Subgraph graph(subgraph); - for (const auto& op : graph.Ops()) { - if (!IsOpSupported(op)) { - continue; - } - - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 0)); - } - - return kLiteRtStatusOk; -} - -namespace { - -Expected> CompilePartition( - NeuronAdapterApi& neuron_adapter_api, const litert::Subgraph& partition, - const std::string& graph_name, std::optional soc_model) { - auto model = CreateModel(neuron_adapter_api, partition, graph_name); - if (!model) { - return model.Error(); - } - - auto compilation = CompileModel(neuron_adapter_api, model->get(), soc_model); - if (!compilation) { - return compilation.Error(); - } - - size_t bytecode_size; - if (neuron_adapter_api.api().compilation_get_compiled_network_size( - compilation->get(), &bytecode_size) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get compiled network size"); - } - - std::vector bytecode(bytecode_size); - if (neuron_adapter_api.api().compilation_store_compiled_network( - compilation->get(), bytecode.data(), bytecode.size()) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get compiled network"); - } - - return bytecode; -} - -} // namespace - -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - static constexpr char dla_directory_template[] = "/tmp/tempdir_dla.XXXXXX"; - char* dla_directory_name = mkdtemp(const_cast(dla_directory_template)); - if (dla_directory_name == nullptr) { - LITERT_LOG(LITERT_ERROR, "Failed to make DLA temporary directory") - return kLiteRtStatusErrorFileIO; - } - setenv("MTKNN_ADAPTER_DLA_PLATFORM", soc_model, 1); - setenv("MTKNN_ADAPTER_DLA_DIR", dla_directory_name, 1); - - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - - LITERT_LOG(LITERT_INFO, - "Starting MediaTek Compilation for %d subgraphs, soc_model=%s", - num_partitions, soc_model); - - auto opt_soc_model = soc_model ? 
FindSocModel(soc_model) : std::nullopt; - if (opt_soc_model) { - LITERT_LOG(LITERT_ERROR, "Compiling for MediaTek architecture: %s", - *opt_soc_model); - } else if (soc_model) { - LITERT_LOG(LITERT_ERROR, "Unexpected SoC model: %s", soc_model); - rmdir(dla_directory_name); - return kLiteRtStatusErrorInvalidArgument; - } - - // Initialize SDK and load mediatek shared libraries. - - auto api = NeuronAdapterApi::Create(/*shared_library_dir=*/std::nullopt); - if (!api) { - rmdir(dla_directory_name); - return api.Error().Status(); - } - - auto result = std::make_unique(); - - for (auto i = 0; i < num_partitions; ++i) { - auto graph_name = absl::StrFormat("Partition_%d", i); - auto bytecode = - CompilePartition(**api, *model.Subgraph(i), graph_name, opt_soc_model); - rmdir(dla_directory_name); - if (!bytecode) { - LITERT_LOG(LITERT_INFO, "%s", bytecode.Error().Message().c_str()); - return bytecode.Error().Status(); - } - auto bufferIdx = result->bytebuilder.AddBuffer( - graph_name, (int8_t*)bytecode->data(), bytecode->size()); - result->bytebuilder.AddCompiledNetwork( - graph_name, NeuronSchema::CompiledType_AdapterCache, bufferIdx); - result->graph_names.emplace_back(graph_name); - } - - if (!result->bytebuilder.Finish()) { - return kLiteRtStatusErrorCompilation; - } - *compiled_result = result.release(); - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc deleted file mode 100644 index b8bb947587b229..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
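The test below walks the compiled-result C API defined earlier in compiler_plugin.cc. For reference, a consumer outside a test traverses it the same way; this sketch assumes `compiled` was produced by a successful LiteRtCompilerPluginCompile call and that the accessors are declared in the vendor header included here:

    #include <cstddef>
    #include <cstdio>

    #include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h"

    // Print each call's entry-point name and the size of the bytecode module
    // it maps to (one module per call for MediaTek, as noted above).
    void DumpCompiledResult(LiteRtCompiledResult compiled) {
      LiteRtParamIndex num_calls = 0;
      if (LiteRtGetNumCompiledResultCalls(compiled, &num_calls) !=
          kLiteRtStatusOk) {
        return;
      }
      for (LiteRtParamIndex i = 0; i < num_calls; ++i) {
        const void* call_info = nullptr;
        size_t call_info_size = 0;
        LiteRtParamIndex byte_code_idx = 0;
        if (LiteRtGetCompiledResultCallInfo(compiled, i, &call_info,
                                            &call_info_size,
                                            &byte_code_idx) != kLiteRtStatusOk) {
          continue;
        }
        const void* byte_code = nullptr;
        size_t byte_code_size = 0;
        if (LiteRtGetCompiledResultByteCode(compiled, byte_code_idx, &byte_code,
                                            &byte_code_size) == kLiteRtStatusOk) {
          std::printf("%.*s -> %zu bytes\n", static_cast<int>(call_info_size),
                      static_cast<const char*>(call_info), byte_code_size);
        }
      }
      // The caller still owns `compiled`; release it with
      // LiteRtDestroyCompiledResult when done.
    }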
- -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/test_models.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" - -namespace litert { -namespace { - -using ::testing::Values; - -// clang-format off -const auto kSupportedOps = Values( - "add_cst.tflite", - "add_simple.tflite", - "simple_add_op.tflite"); -// clang-format on - -TEST(TestMediatekPlugin, GetConfigInfo) { - EXPECT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), "MediaTek"); - - auto plugin = CreatePlugin(); - - LiteRtParamIndex num_supported_soc_models; - ASSERT_EQ(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models), - kLiteRtStatusOk); - ASSERT_EQ(num_supported_soc_models, 12); - - const char* config_id; - ASSERT_EQ( - LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, &config_id), - kLiteRtStatusOk); - EXPECT_STREQ(config_id, "mt6853"); -} - -TEST(TestMediatekPlugin, PartitionAdd) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("add_simple.tflite"); - - LiteRtOpListT selected_op_list; - ASSERT_EQ(LiteRtCompilerPluginPartition(plugin.get(), /*soc_model=*/nullptr, - model.Subgraph(0)->Get(), - &selected_op_list), - kLiteRtStatusOk); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 1); - EXPECT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflAdd); -} - -// ///////////////////////////////////////////////////////////////////////////// - -class MtkPluginOpCompatibilityTest - : public ::testing::TestWithParam {}; - -TEST_P(MtkPluginOpCompatibilityTest, SupportedOpsTest) { -#ifndef __ANDROID__ - GTEST_SKIP() << "Loading shared lib not currently supported on linux."; -#endif // __ANDROID__ - - LITERT_LOG(LITERT_INFO, "Testing TFLite model: %s", GetParam().c_str()); - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel(GetParam()); - - LiteRtCompiledResult compiled; - ASSERT_EQ(LiteRtCompilerPluginCompile(plugin.get(), /*soc_model=*/nullptr, - model.Get(), &compiled), - kLiteRtStatusOk); - - LiteRtParamIndex num_byte_code; - ASSERT_EQ(LiteRtCompiledResultNumByteCodeModules(compiled, &num_byte_code), - kLiteRtStatusOk); - ASSERT_EQ(num_byte_code, 1); - - const void* byte_code; - size_t byte_code_size; - - ASSERT_EQ(LiteRtGetCompiledResultByteCode(compiled, /*byte_code_idx=*/0, - &byte_code, &byte_code_size), - kLiteRtStatusOk); - - absl::string_view byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - - ASSERT_EQ(LiteRtGetCompiledResultCallInfo(compiled, /*call_idx=*/0, &op_data, - &op_data_size, &byte_code_idx), - kLiteRtStatusOk); - - EXPECT_EQ(byte_code_idx, 0); - - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - 
EXPECT_EQ(op_data_string, "Partition_0"); - - LiteRtDestroyCompiledResult(compiled); -} - -INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, MtkPluginOpCompatibilityTest, - kSupportedOps); - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc deleted file mode 100644 index c7b3ca2d3ebdb0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h" - -#include -#include -#include - -#include "neuron/api/NeuronAdapter.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h" - -namespace litert::mediatek { - -Expected CreateModel(const NeuronAdapterApi& neuron_adapter_api, - const litert::Subgraph& partition, - const std::string& model_name) { - auto model = 
neuron_adapter_api.CreateModel(); - if (!model) { - return model.Error(); - } - - if (neuron_adapter_api.api().model_set_name( - model->get(), model_name.c_str()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to set model name"); - } - - OperandMap operand_map(neuron_adapter_api, model->get()); - - std::vector input_indices; - for (const auto& input : partition.Inputs()) { - auto operand_index = operand_map.GetOperandIndex(input); - if (!operand_index) { - return operand_index.Error(); - } - input_indices.push_back(*operand_index); - } - - std::vector output_indices; - for (const auto& output : partition.Outputs()) { - auto operand_index = operand_map.GetOperandIndex(output); - if (!operand_index) { - return operand_index.Error(); - } - output_indices.push_back(*operand_index); - } - - if (neuron_adapter_api.api().model_identify_inputs_and_outputs( - model->get(), input_indices.size(), input_indices.data(), - output_indices.size(), output_indices.data()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to identify model I/Os"); - } - - for (const auto& op : partition.Ops()) { - Expected status; - switch (op.Code()) { - case kLiteRtOpCodeTflAdd: - status = - LegalizeAddOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflMul: - status = - LegalizeMulOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflBatchMatmul: - status = LegalizeBatchMatMulOp(neuron_adapter_api, model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflFullyConnected: - status = LegalizeFullyConnectedOp(neuron_adapter_api, model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflReshape: - status = LegalizeReshapeOp(neuron_adapter_api, model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflTranspose: - status = LegalizeTransposeOp(neuron_adapter_api, model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflRsqrt: - status = - LegalizeRsqrtOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflConcatenation: - status = - LegalizeConcatOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflQuantize: - status = LegalizeCommonOp(neuron_adapter_api, model->get(), operand_map, - op, NEURON_QUANTIZE); - break; - case kLiteRtOpCodeTflSlice: - status = LegalizeCommonOp(neuron_adapter_api, model->get(), operand_map, - op, NEURON_SLICE); - break; - case kLiteRtOpCodeTflTanh: - status = LegalizeCommonOp(neuron_adapter_api, model->get(), operand_map, - op, NEURON_TANH); - break; - case kLiteRtOpCodeTflSub: - status = - LegalizeSubOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflSoftmax: - status = LegalizeSoftmaxOp(neuron_adapter_api, model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflMean: - status = - LegalizeMeanOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflGelu: - status = - LegalizeGeluOp(neuron_adapter_api, model->get(), operand_map, op); - break; - default: - return Error(kLiteRtStatusErrorRuntimeFailure, "Unsupported op"); - } - - if (!status) { - return status.Error(); - } - } - - if (neuron_adapter_api.api().model_finish(model->get()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to finish model"); - } - - return model; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h 
b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h deleted file mode 100644 index 6e958d691a80e1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_CREATE_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_CREATE_MODEL_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -// Create a new NeuronModel Graph from given LiteRt Graph. -Expected CreateModel(const NeuronAdapterApi& neuron_adapter_api, - const Subgraph& partition, - const std::string& model_name); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_CREATE_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD deleted file mode 100644 index abd020b3ce9725..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:__subpackages__"], -) - -cc_library( - name = "operand_map", - srcs = ["operand_map.cc"], - hdrs = ["operand_map.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "neuron_utils", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "@com_google_absl//absl/container:flat_hash_map", - ], -) - -cc_library( - name = "neuron_utils", - srcs = ["neuron_utils.cc"], - hdrs = ["neuron_utils.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "add_op_legalization", - srcs = ["add_op_legalization.cc"], - hdrs = ["add_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "mul_op_legalization", - srcs = ["mul_op_legalization.cc"], - hdrs = ["mul_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "batch_matmul_op_legalization", - srcs = ["batch_matmul_op_legalization.cc"], - hdrs = ["batch_matmul_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "fully_connected_op_legalization", - srcs = ["fully_connected_op_legalization.cc"], - hdrs = ["fully_connected_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "reshape_op_legalization", - srcs = ["reshape_op_legalization.cc"], - hdrs = ["reshape_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "transpose_op_legalization", - srcs = ["transpose_op_legalization.cc"], - hdrs = ["transpose_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "rsqrt_op_legalization", - srcs = ["rsqrt_op_legalization.cc"], - hdrs = ["rsqrt_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "concat_op_legalization", - srcs = ["concat_op_legalization.cc"], - hdrs = ["concat_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "quantize_op_legalization", - srcs = ["quantize_op_legalization.cc"], - hdrs = ["quantize_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "common_op_legalization", - srcs = ["common_op_legalization.cc"], - hdrs = ["common_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "sub_op_legalization", - srcs = ["sub_op_legalization.cc"], - hdrs = ["sub_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "softmax_op_legalization", - srcs = ["softmax_op_legalization.cc"], - hdrs = ["softmax_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "mean_op_legalization", - srcs = ["mean_op_legalization.cc"], - hdrs = ["mean_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "gelu_op_legalization", - srcs = ["gelu_op_legalization.cc"], - hdrs = ["gelu_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc deleted file mode 100644 index 47194694f7de7b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeAddOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Add"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_ADD operation takes a 3rd scalar operand, which is used to pass a - // TfLiteFusedActivation value. 
- uint32_t tfl_fused_activation; - if (auto status = - LiteRtGetAddFusedActivationOption(op.Get(), &tfl_fused_activation); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get fused activation"); - } - auto fused_activation_operand_index = - operand_map.AddScalarInt32(tfl_fused_activation); - if (!fused_activation_operand_index) { - return fused_activation_operand_index.Error(); - } - input_indices.push_back(*fused_activation_operand_index); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_ADD, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_ADD op"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h deleted file mode 100644 index d774d6bcb972e9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeAddOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.cc deleted file mode 100644 index 23cbe10ff09386..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.cc +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_options.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-Expected<void> LegalizeBatchMatMulOp(const NeuronAdapterApi& neuron_adapter_api,
-                                     NeuronModel* model,
-                                     OperandMap& operand_map,
-                                     const litert::Op& op) {
-  LITERT_LOG(LITERT_INFO, "Legalize BatchMatMul");
-  std::vector<uint32_t> input_indices;
-  for (auto& input : op.Inputs()) {
-    auto id = operand_map.GetOperandIndex(input);
-    if (!id) {
-      return id.Error();
-    }
-    input_indices.push_back(*id);
-  }
-
-  // A NEURON_BATCH_MATMUL operation takes two additional scalar operands,
-  // which are used to pass the adjX and adjY values.
-  bool tfl_matmul_param_adj_x = false, tfl_matmul_param_adj_y = false;
-  if (auto status =
-          LiteRtGetBatchMatmulAdjXOption(op.Get(), &tfl_matmul_param_adj_x);
-      status != kLiteRtStatusOk) {
-    return Error(status, "Failed to get batch matmul adjX");
-  }
-
-  if (auto status =
-          LiteRtGetBatchMatmulAdjYOption(op.Get(), &tfl_matmul_param_adj_y);
-      status != kLiteRtStatusOk) {
-    return Error(status, "Failed to get batch matmul adjY");
-  }
-
-  auto adj_x_operand_index = operand_map.AddScalarBool(tfl_matmul_param_adj_x);
-  if (!adj_x_operand_index) {
-    return adj_x_operand_index.Error();
-  }
-  input_indices.push_back(*adj_x_operand_index);
-
-  auto adj_y_operand_index = operand_map.AddScalarBool(tfl_matmul_param_adj_y);
-  if (!adj_y_operand_index) {
-    return adj_y_operand_index.Error();
-  }
-  input_indices.push_back(*adj_y_operand_index);
-
-  std::vector<uint32_t> output_indices;
-  for (auto& output : op.Outputs()) {
-    auto id = operand_map.GetOperandIndex(output);
-    if (!id) {
-      return id.Error();
-    }
-    output_indices.push_back(*id);
-  }
-
-  if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_BATCH_MATMUL,
-                        input_indices, output_indices) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to add NEURON_BATCH_MATMUL op");
-  }
-
-  return {};
-}
-
-}  // namespace litert::mediatek
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h
deleted file mode 100644
index 227c6563713a58..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeBatchMatMulOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, - OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.cc deleted file mode 100644 index 0c3a62f8997b3a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -bool VerifyCommonOp(const litert::Op& op, LiteRtOpCode op_code) { - // Do some common check - return true; -} - -Expected LegalizeCommonOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op, - NeuronOperationType mtk_operation_type) { - LITERT_LOG(LITERT_INFO, "Legalize Op: %d", mtk_operation_type); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/mtk_operation_type, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to add operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h deleted file mode 100644 index 5995f77e888bf1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_COMMON_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_COMMON_OP_LEGALIZATION_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-bool VerifyCommonOp(const litert::Op& op, LiteRtOpCode op_code);
-
-Expected<void> LegalizeCommonOp(const NeuronAdapterApi& neuron_adapter_api,
-                                NeuronModel* model, OperandMap& operand_map,
-                                const litert::Op& op,
-                                NeuronOperationType mtk_operation_type);
-
-}  // namespace litert::mediatek
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_COMMON_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.cc
deleted file mode 100644
index 3320272f9c65f1..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_options.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-Expected<void> LegalizeConcatOp(const NeuronAdapterApi& neuron_adapter_api,
-                                NeuronModel* model, OperandMap& operand_map,
-                                const litert::Op& op) {
-  LITERT_LOG(LITERT_INFO, "Legalize Concat");
-  std::vector<uint32_t> input_indices;
-  for (auto& input : op.Inputs()) {
-    auto id = operand_map.GetOperandIndex(input);
-    if (!id) {
-      return id.Error();
-    }
-    input_indices.push_back(*id);
-  }
-
-  // A NEURON_CONCATENATION operation takes an additional scalar operand,
-  // which is used to pass the concatenation axis.
-  int32_t axis;
-  if (auto status = LiteRtGetConcatenationAxisOption(op.Get(), &axis);
-      status != kLiteRtStatusOk) {
-    return Error(status, "Failed to get axis option");
-  }
-
-  auto axis_operand_index = operand_map.AddScalarInt32(axis);
-  if (!axis_operand_index) {
-    return axis_operand_index.Error();
-  }
-
-  input_indices.push_back(*axis_operand_index);
-
-  std::vector<uint32_t> output_indices;
-  for (auto& output : op.Outputs()) {
-    auto id = operand_map.GetOperandIndex(output);
-    if (!id) {
-      return id.Error();
-    }
-    output_indices.push_back(*id);
-  }
-
-  if (ModelAddOperation(neuron_adapter_api, model,
-                        /*type=*/NEURON_CONCATENATION, input_indices,
-                        output_indices) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to add NEURON_CONCATENATION operation");
-  }
-
-  return {};
-}
-
-}  // namespace litert::mediatek
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h
deleted file mode 100644
index e7f1294ec39df4..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_CONCAT_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_CONCAT_OP_LEGALIZATION_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-Expected<void> LegalizeConcatOp(const NeuronAdapterApi& neuron_adapter_api,
-                                NeuronModel* model, OperandMap& operand_map,
-                                const litert::Op& op);
-
-}  // namespace litert::mediatek
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_CONCAT_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.cc
deleted file mode 100644
index 877c7511649a4a..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.cc
+++ /dev/null
@@ -1,129 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_options.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-#define GET_RANK(op) ((op).RankedTensorType()->Layout().Rank())
-#define GET_DIMENSION(op) ((op).RankedTensorType()->Layout().Dimensions())
-
-namespace litert::mediatek {
-
-Expected<void> LegalizeFullyConnectedOp(
-    const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model,
-    OperandMap& operand_map, const litert::Op& op) {
-  LITERT_LOG(LITERT_INFO, "Legalize Fully Connected");
-  std::vector<uint32_t> input_indices;
-  for (auto& input : op.Inputs()) {
-    auto id = operand_map.GetOperandIndex(input);
-    if (!id) {
-      return id.Error();
-    }
-    input_indices.push_back(*id);
-  }
-
-  // Synthesize a zero bias operand when the op carries only input and
-  // weights.
-  if (input_indices.size() < 3) {
-    auto weights_shape = GET_DIMENSION(op.Inputs()[1]);
-    std::vector<uint32_t> bias_shape = {
-        static_cast<uint32_t>(weights_shape[0])};
-    std::vector<int8_t> bias_data(bias_shape[0], 0);
-    auto bias_data_operand =
-        operand_map.AddTensorByType(NEURON_TENSOR_QUANT8_SYMM, bias_shape,
-                                    bias_data.data(),
-                                    bias_data.size() * sizeof(int8_t));
-    if (!bias_data_operand) {
-      return bias_data_operand.Error();
-    }
-    input_indices.push_back(*bias_data_operand);
-  }
-
-  // A NEURON_FULLY_CONNECTED operation takes a fourth scalar operand, which
-  // is used to pass a TfLiteFusedActivation value.
-  uint32_t tfl_fused_activation;
-  if (auto status = LiteRtGetFullyConnectedFusedActivationOption(
-          op.Get(), &tfl_fused_activation);
-      status != kLiteRtStatusOk) {
-    return Error(status, "Failed to get fused activation");
-  }
-  auto fused_activation_operand_index =
-      operand_map.AddScalarInt32(tfl_fused_activation);
-  if (!fused_activation_operand_index) {
-    return fused_activation_operand_index.Error();
-  }
-  input_indices.push_back(*fused_activation_operand_index);
-
-  auto output_operand = OperandType::Create(op.Outputs()[0]);
-  if (!output_operand) {
-    return output_operand.Error();
-  }
-  std::vector<uint32_t> output_indices;
-
-  if (GET_RANK(op.Outputs()[0]) > 2) {
-    // If the output rank is greater than 2, flatten the output to a 2-D
-    // shape of {element_count / last_dim, last_dim} and emit the result into
-    // an intermediate operand.
-    auto last_dim = output_operand->GetDimension().back();
-    auto elements = output_operand->GetElementCount();
-    std::vector<uint32_t> new_dimension = {elements / last_dim, last_dim};
-    if (auto res = output_operand->Reshape(new_dimension); !res) {
-      return res.Error();
-    }
-    auto intermediate_operand = operand_map.AddOperand(*output_operand);
-    if (!intermediate_operand) {
-      return intermediate_operand.Error();
-    }
-    output_indices.push_back(*intermediate_operand);
-  } else {
-    auto output_operand_index = operand_map.GetOperandIndex(op.Outputs()[0]);
-    if (!output_operand_index) {
-      return output_operand_index.Error();
-    }
-    output_indices.push_back(*output_operand_index);
-  }
-
-  if (ModelAddOperation(neuron_adapter_api, model,
-                        /*type=*/NEURON_FULLY_CONNECTED, input_indices,
-                        output_indices) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to add NEURON_FULLY_CONNECTED operation");
-  }
-
-  if (GET_RANK(op.Outputs()[0]) > 2) {
-    // Use the intermediate operand as the reshape input and restore the
-    // original output shape.
-    input_indices = {output_indices.back()};
-    auto output_operand_index = operand_map.GetOperandIndex(op.Outputs()[0]);
-    if (!output_operand_index) {
-      return output_operand_index.Error();
-    }
-
-    auto dimension = op.Outputs()[0].RankedTensorType()->Layout().Dimensions();
-    std::vector<int32_t> new_shape(dimension.begin(), dimension.end());
-    std::vector<uint32_t> tensor_shape = {(uint32_t)new_shape.size()};
-    auto new_shape_operand_index = operand_map.AddTensorInt32(
-        tensor_shape, new_shape.data(), new_shape.size() * sizeof(int32_t));
-    if (!new_shape_operand_index) {
-      return new_shape_operand_index.Error();
-    }
-    input_indices.push_back(*new_shape_operand_index);
-    output_indices = {*output_operand_index};
-    if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_RESHAPE,
-                          input_indices, output_indices) != NEURON_NO_ERROR) {
-      return Error(kLiteRtStatusErrorRuntimeFailure,
-                   "Failed to add Reshape after FC");
-    }
-  }
-
-  return {};
-}
-
-}  // namespace litert::mediatek
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h
deleted file mode 100644
index 68d6a319295b90..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeFullyConnectedOp( - const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model, - OperandMap& operand_map, const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.cc deleted file mode 100644 index 32af1156fae40b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.cc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -constexpr uint32_t kGeluApproximateTanh = 1; - -Expected LegalizeGeluOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Gelu"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - auto approximate_operand = operand_map.AddScalarUInt32(kGeluApproximateTanh); - if (!approximate_operand) { - return approximate_operand.Error(); - } - - input_indices.push_back(*approximate_operand); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_GELU_V2, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add GELU operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h deleted file mode 100644 index 9249263c77e902..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-Expected<void> LegalizeGeluOp(const NeuronAdapterApi& neuron_adapter_api,
-                              NeuronModel* model, OperandMap& operand_map,
-                              const litert::Op& op);
-
-}  // namespace litert::mediatek
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.cc
deleted file mode 100644
index efbad31106b51b..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_options.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-Expected<void> LegalizeMeanOp(const NeuronAdapterApi& neuron_adapter_api,
-                              NeuronModel* model, OperandMap& operand_map,
-                              const litert::Op& op) {
-  LITERT_LOG(LITERT_INFO, "Legalize Mean");
-  std::vector<uint32_t> input_indices;
-  for (auto& input : op.Inputs()) {
-    auto id = operand_map.GetOperandIndex(input);
-    if (!id) {
-      return id.Error();
-    }
-    input_indices.push_back(*id);
-  }
-
-  // A NEURON_MEAN operation takes an additional scalar operand, which is
-  // used to pass the keepdims flag.
-  bool keepdims;
-  if (auto status = LiteRtGetMeanKeepDimsOption(op.Get(), &keepdims);
-      status != kLiteRtStatusOk) {
-    return Error(status, "Failed to get keepdims option");
-  }
-  LITERT_LOG(LITERT_INFO, "keepdims: %d", keepdims);
-  auto keepdims_operand = operand_map.AddScalarInt32(keepdims);
-  if (!keepdims_operand) {
-    return keepdims_operand.Error();
-  }
-  input_indices.push_back(*keepdims_operand);
-
-  std::vector<uint32_t> output_indices;
-  for (auto& output : op.Outputs()) {
-    auto id = operand_map.GetOperandIndex(output);
-    if (!id) {
-      return id.Error();
-    }
-    output_indices.push_back(*id);
-  }
-
-  if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_MEAN,
-                        input_indices, output_indices) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to add NEURON_MEAN operation");
-  }
-
-  return {};
-}
-
-}  // namespace litert::mediatek
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h
deleted file mode 100644
index fc36f646d75836..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MEAN_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MEAN_OP_LEGALIZATION_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-Expected<void> LegalizeMeanOp(const NeuronAdapterApi& neuron_adapter_api,
-                              NeuronModel* model, OperandMap& operand_map,
-                              const litert::Op& op);
-
-}  // namespace litert::mediatek
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MEAN_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.cc
deleted file mode 100644
index b78f1640f8bc92..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_options.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-Expected<void> LegalizeMulOp(const NeuronAdapterApi& neuron_adapter_api,
-                             NeuronModel* model, OperandMap& operand_map,
-                             const litert::Op& op) {
-  LITERT_LOG(LITERT_INFO, "Legalize Mul");
-  std::vector<uint32_t> input_indices;
-  for (auto& input : op.Inputs()) {
-    auto id = operand_map.GetOperandIndex(input);
-    if (!id) {
-      return id.Error();
-    }
-    input_indices.push_back(*id);
-  }
-
-  // A NEURON_MUL operation takes a third scalar operand, which is used to
-  // pass a TfLiteFusedActivation value.
-  uint32_t tfl_fused_activation;
-  if (auto status =
-          LiteRtGetMulFusedActivationOption(op.Get(), &tfl_fused_activation);
-      status != kLiteRtStatusOk) {
-    return Error(status, "Failed to get fused activation");
-  }
-  auto fused_activation_operand_index =
-      operand_map.AddScalarInt32(tfl_fused_activation);
-  if (!fused_activation_operand_index) {
-    return fused_activation_operand_index.Error();
-  }
-  input_indices.push_back(*fused_activation_operand_index);
-
-  std::vector<uint32_t> output_indices;
-  for (auto& output : op.Outputs()) {
-    auto id = operand_map.GetOperandIndex(output);
-    if (!id) {
-      return id.Error();
-    }
-    output_indices.push_back(*id);
-  }
-
-  if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_MUL,
-                        input_indices, output_indices) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to add NEURON_MUL operation");
-  }
-
-  return {};
-}
-
-}  // namespace litert::mediatek
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h
deleted file mode 100644
index 8ff1c325fe3f27..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
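Add, Mul, and FullyConnected share one convention worth calling out: the TFLite fused-activation enum rides along as a trailing int32 scalar operand appended after the tensor inputs. A self-contained sketch of that operand layout (the index values are illustrative stand-ins for what OperandMap actually returns):

#include <cstdint>
#include <vector>

// Builds the Neuron input-index list for a binary op such as NEURON_ADD or
// NEURON_MUL: the two tensor operands first, then the activation scalar.
std::vector<uint32_t> BuildBinaryOpInputs(uint32_t lhs_index,
                                          uint32_t rhs_index,
                                          uint32_t activation_scalar_index) {
  return {lhs_index, rhs_index, activation_scalar_index};
}

// Example: operands 0 and 1 with the activation scalar registered at index 2
// yield the list {0, 1, 2} that is passed to ModelAddOperation.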
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeMulOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.cc deleted file mode 100644 index 059d51cc10dc9c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h" - -namespace litert::mediatek { - -Expected GetNeuronTensorType(const Tensor& t) { - auto ranked_tensor_type = t.RankedTensorType(); - if (!ranked_tensor_type) { - return ranked_tensor_type.Error(); - } - - int32_t mtk_type; - switch (ranked_tensor_type->ElementType()) { - case ElementType::Float32: - mtk_type = NEURON_TENSOR_FLOAT32; - break; - case ElementType::Float16: - mtk_type = NEURON_TENSOR_FLOAT16; - break; - case ElementType::Int32: - mtk_type = NEURON_TENSOR_INT32; - break; - case ElementType::Int16: - if (t.QTypeId() == kLiteRtQuantizationPerTensor) { - mtk_type = NEURON_TENSOR_QUANT16_SYMM; - } else { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Int16 is not supported."); - } - break; - case ElementType::Int8: - if (t.QTypeId() == kLiteRtQuantizationPerTensor) { - mtk_type = NEURON_TENSOR_QUANT8_SYMM; - } else if (t.QTypeId() == kLiteRtQuantizationPerChannel) { - mtk_type = NEURON_TENSOR_QUANT8_SYMM_PER_CHANNEL; - } else { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Int8 is not supported."); - } - break; - default: - return Error(kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Unsupported element type: %d", - ranked_tensor_type->ElementType())); - } - return mtk_type; -} - -Expected GetNeuronDataSize(NeuronTensorType type) { - switch (type) { - case NEURON_FLOAT32: - case NEURON_TENSOR_FLOAT32: - case NEURON_INT32: - case NEURON_TENSOR_INT32: - return 4; - case NEURON_FLOAT16: - case NEURON_TENSOR_FLOAT16: - case NEURON_EXT_TENSOR_QUANT16_ASYMM_SIGNED: - return 2; - case NEURON_BOOL: - case NEURON_TENSOR_BOOL8: - case NEURON_TENSOR_QUANT8_ASYMM: - case NEURON_TENSOR_QUANT8_ASYMM_SIGNED: - return 1; - default: - return Error(kLiteRtStatusErrorRuntimeFailure, - "Get Data Size fail for Neuron Type"); - } - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected neuron type"); -} - -Expected IsQuantizedType(NeuronTensorType type) { - switch (type) { - case NEURON_TENSOR_QUANT16_SYMM: - case NEURON_TENSOR_QUANT16_ASYMM: - case NEURON_TENSOR_QUANT8_ASYMM: - case NEURON_TENSOR_QUANT8_ASYMM_SIGNED: - return true; - } - return false; -} - -NeuronReturnCode ModelAddOperation(const NeuronAdapterApi& api, - NeuronModel* model, NeuronOperationType type, - std::vector input, - std::vector output) { - return api.api().model_add_operation(model, type, input.size(), input.data(), - output.size(), output.data()); -}; - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h deleted file mode 100644 index 27633fef2d746f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_NEURON_UTILS_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_NEURON_UTILS_H_
-
-#include <cstdint>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-using NeuronTensorType = int32_t;
-using NeuronReturnCode = int32_t;
-
-Expected<NeuronTensorType> GetNeuronTensorType(const Tensor& t);
-
-Expected<uint32_t> GetNeuronDataSize(NeuronTensorType type);
-
-Expected<bool> IsQuantizedType(NeuronTensorType type);
-
-NeuronReturnCode ModelAddOperation(const NeuronAdapterApi& api,
-                                   NeuronModel* model, NeuronOperationType type,
-                                   std::vector<uint32_t> input,
-                                   std::vector<uint32_t> output);
-
-}  // namespace litert::mediatek
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_NEURON_UTILS_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc
deleted file mode 100644
index 40347a2c00bbe5..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
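
The utilities above compose naturally. A minimal sketch of sizing a host buffer for a tensor, assuming only the declarations in neuron_utils.h; `TensorPayloadBytes` is a hypothetical helper, not part of the deleted sources:

```c++
// Hedged sketch: combines GetNeuronTensorType and GetNeuronDataSize to
// compute the payload size in bytes of a LiteRT tensor.
#include <cstddef>

#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h"

namespace litert::mediatek {

Expected<size_t> TensorPayloadBytes(const Tensor& t) {
  auto mtk_type = GetNeuronTensorType(t);
  if (!mtk_type) return mtk_type.Error();
  auto elem_bytes = GetNeuronDataSize(*mtk_type);  // bytes per element
  if (!elem_bytes) return elem_bytes.Error();
  auto ranked = t.RankedTensorType();
  if (!ranked) return ranked.Error();
  size_t elements = 1;
  for (auto d : ranked->Layout().Dimensions()) elements *= d;
  return elements * *elem_bytes;
}

}  // namespace litert::mediatek
```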
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected OperandMap::Register(const NeuronOperandType& operand_type) { - if (neuron_adapter_api_.api().model_add_operand(model_, &operand_type) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to register model operand"); - } - return AllocateOperandIndex(); -} - -Expected OperandMap::Register(const Tensor& t) { - auto operand_type = OperandType::Create(t); - if (!operand_type) { - return operand_type.Error(); - } - - auto operand_index = - Register(static_cast(*operand_type)); - if (!operand_index) { - return operand_index.Error(); - } - LITERT_LOG(LITERT_INFO, "\nOperandIndex: %d", operand_index.Value()); - operand_type->Info(); - - if (t.HasWeights()) { - auto weights = t.Weights().Bytes(); - if (t.QTypeId() == kLiteRtQuantizationPerChannel) { - auto quant_param = operand_type->GetPerChannelQuantParams().Value(); - if (neuron_adapter_api_.api().model_set_symm_per_channel_quant_params( - model_, *operand_index, &quant_param) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set param of per channel quant params"); - } - } - if (neuron_adapter_api_.api().model_set_operand_value( - model_, *operand_index, weights.data(), weights.size()) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set value of tensor weights"); - } - } - - map_[t.Get()] = *operand_index; - return *operand_index; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h deleted file mode 100644 index fb79626d6cf988..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_OPERAND_MAP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_OPERAND_MAP_H_ - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -class OperandType : public NeuronOperandType { - public: - static Expected Create(const Tensor& t) { - auto ranked_tensor_type = t.RankedTensorType(); - if (!ranked_tensor_type) { - return ranked_tensor_type.Error(); - } - - auto tensor_dimensions = ranked_tensor_type->Layout().Dimensions(); - std::vector mtk_dimensions; - mtk_dimensions.reserve(tensor_dimensions.size()); - std::copy(tensor_dimensions.begin(), tensor_dimensions.end(), - std::back_inserter(mtk_dimensions)); - - // tensor type dimensions couldn't be zero. - if (mtk_dimensions.size() == 0) { - mtk_dimensions = { - 1, - }; - } - - // BlockWise Quantize is not supported now. - if (t.HasQuantization() && t.QTypeId() == kLiteRtQuantizationBlockWise) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Doesn't support BlockWise quantize now"); - } - - auto mtk_type = GetNeuronTensorType(t); - if (!mtk_type) { - return mtk_type.Error(); - } - - if (t.QTypeId() == kLiteRtQuantizationPerTensor) { - auto quant_info = t.PerTensorQuantization(); - LITERT_LOG(LITERT_INFO, "zeroPoint: %d, scale: %f", quant_info.zero_point, - quant_info.scale); - return OperandType(*mtk_type, std::move(mtk_dimensions), quant_info.scale, - quant_info.zero_point, std::nullopt); - } else if (t.QTypeId() == kLiteRtQuantizationPerChannel) { - auto quant_info = t.PerChannelQuantization(); - NeuronSymmPerChannelQuantParams params; - params.scaleCount = quant_info.num_channels; - params.scales = quant_info.scales; - params.channelDim = quant_info.quantized_dimension; - LITERT_LOG(LITERT_INFO, "quantized_dimension: %d", - quant_info.quantized_dimension); - LITERT_LOG(LITERT_INFO, "params.channelDim: %d", params.channelDim); - return OperandType(*mtk_type, std::move(mtk_dimensions), 0, 0, params); - } else { - return OperandType(*mtk_type, std::move(mtk_dimensions), /*scale*/ 0, - /*zero_point*/ 0, std::nullopt); - } - } - - void Info() { - std::string vector = "["; - for (int i = 0; i < dimensionCount; i++) { - vector += std::to_string(dimensions_[i]); - vector += ","; - } - vector += "]"; - LITERT_LOG(LITERT_INFO, - "\n[Type] %d" - "\n[zeroPoint]%d" - "\n[scale]%f" - "\n[dimensionCount]%u" - "\n[dimensions]%s\n", - type, zeroPoint, scale, dimensionCount, vector.c_str()); - } - - OperandType(const OperandType&) = delete; - - OperandType(OperandType&& other) - : dimensions_(std::move(other.dimensions_)), - neuron_per_channel_params_(other.neuron_per_channel_params_) { - // Copy all the scalar fields from other. - *static_cast(this) = - *static_cast(&other); - // Reset the pointer fields by using own data. 
- dimensions = dimensions_.data(); - }; - - Expected Reshape(std::vector& shape) { - auto elements = GetElementCount(); - if (elements != std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies())) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "the elements is not the same"); - } - this->dimensions_ = shape; - this->dimensionCount = this->dimensions_.size(); - this->dimensions = this->dimensions_.data(); - return {}; - } - - Expected GetPerChannelQuantParams() { - if (!neuron_per_channel_params_.has_value()) { - return Error(kLiteRtStatusErrorRuntimeFailure, "No quant param is set"); - } - return neuron_per_channel_params_.value(); - } - - int32_t GetNeuronType() const { return this->type; } - - std::vector GetDimension() { return this->dimensions_; } - - uint32_t GetElementCount() { - return std::accumulate(dimensions_.begin(), dimensions_.end(), 1, - std::multiplies()); - } - - uint32_t GetRank() { return this->dimensions_.size(); } - - OperandType& operator=(const OperandType&) = delete; - OperandType& operator=(OperandType&& other) = delete; - - private: - explicit OperandType(int32_t mtk_type, std::vector&& mtk_dimensions, - float scale, int32_t zero_point, - std::optional pararms) - : dimensions_(std::move(mtk_dimensions)), - neuron_per_channel_params_(pararms) { - this->scale = scale; - this->zeroPoint = zero_point; - this->type = mtk_type; - this->dimensionCount = dimensions_.size(); - this->dimensions = dimensions_.data(); - } - - std::vector dimensions_; - - std::optional neuron_per_channel_params_ = - std::nullopt; -}; - -// This class takes care of registering Tensors and scalars with a given -// NeuronModel and returing their "operand index", which is how the MTK SDK -// handles them. -class OperandMap { - public: - OperandMap(const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model) - : neuron_adapter_api_(neuron_adapter_api), model_(model) {} - - // Add a scalar operand to the model. - Expected AddScalarBool(bool value) { - return AddScalar(NEURON_BOOL, value); - } - Expected AddScalarInt32(int32_t value) { - return AddScalar(NEURON_INT32, value); - } - Expected AddScalarUInt32(uint32_t value) { - return AddScalar(NEURON_UINT32, value); - } - Expected AddScalarFloat32(float value) { - return AddScalar(NEURON_FLOAT32, value); - } - - // Add a tensor operand to the model - Expected AddTensorInt32(std::vector& shape, - const void* data, const size_t data_size) { - return AddTensor(NEURON_TENSOR_INT32, shape, data, data_size); - } - - // Add a tensor operand to the model - Expected AddTensorByType(int mtk_type, std::vector& shape, - const void* data, const size_t data_size) { - return AddTensor(mtk_type, shape, data, data_size); - } - - Expected AddOperand(const NeuronOperandType& operand) { - return Register(operand); - } - - // Find the operand index for a given tensor and, if not done already, add the - // tensor as an operand in the model. 
- Expected GetOperandIndex(const Tensor& t) { - auto i = map_.find(t.Get()); - if (i != map_.end()) { - return i->second; - } else { - return Register(t); - } - } - - private: - Expected Register(const Tensor& t); - Expected Register(const NeuronOperandType& operand_type); - uint32_t AllocateOperandIndex() { return next_operand_index_++; } - - template - Expected AddScalar(int32_t mtk_type, T value) { - const NeuronOperandType scalar_type = { - .type = mtk_type, - .dimensionCount = 0, - .dimensions = nullptr, - }; - auto operand_index = Register(scalar_type); - if (!operand_index) { - return operand_index.Error(); - } - if (neuron_adapter_api_.api().model_set_operand_value( - model_, *operand_index, &value, sizeof(value)) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set value of scalar operand"); - } - return operand_index; - } - - Expected AddTensor(int32_t mtk_type, - const std::vector& shape, - const void* data, const size_t data_size) { - const NeuronOperandType scalar_type = { - .type = mtk_type, - .dimensionCount = (uint32_t)shape.size(), - .dimensions = shape.data(), - }; - auto operand_index = Register(scalar_type); - if (!operand_index) { - return operand_index.Error(); - } - if (neuron_adapter_api_.api().model_set_operand_value( - model_, *operand_index, data, data_size) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set value of tensor operand"); - } - return operand_index; - } - - const NeuronAdapterApi& neuron_adapter_api_; - NeuronModel* model_; - int next_operand_index_ = 0; - absl::flat_hash_map map_; -}; - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_OPERAND_MAP_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.cc deleted file mode 100644 index 662b93d66d1508..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
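
The class above is easiest to read from a caller's perspective. A hedged sketch of a legalization driving OperandMap; `LegalizeBinaryOp` and the fused-activation value 0 (no activation) are illustrative assumptions, while real legalizations read the activation from the op's options, as the sub legalization later in this patch does. Note that GetOperandIndex registers a tensor on first use and returns the cached index afterwards:

```c++
// Hypothetical sketch, not from the deleted sources.
Expected<void> LegalizeBinaryOp(const NeuronAdapterApi& api,
                                NeuronModel* model, OperandMap& operands,
                                const litert::Op& op,
                                NeuronOperationType type) {
  std::vector<uint32_t> inputs;
  for (const auto& t : op.Inputs()) {
    auto idx = operands.GetOperandIndex(t);  // registers on first use
    if (!idx) return idx.Error();
    inputs.push_back(*idx);
  }
  auto activation = operands.AddScalarInt32(0);  // assumed: no activation
  if (!activation) return activation.Error();
  inputs.push_back(*activation);  // trailing scalar operand

  std::vector<uint32_t> outputs;
  for (const auto& t : op.Outputs()) {
    auto idx = operands.GetOperandIndex(t);
    if (!idx) return idx.Error();
    outputs.push_back(*idx);
  }
  if (ModelAddOperation(api, model, type, inputs, outputs) !=
      NEURON_NO_ERROR) {
    return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to add operation");
  }
  return {};
}
```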
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeQuantizeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Quantize"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_QUANTIZE, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_QUANTIZE operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.h deleted file mode 100644 index d2db3761f374de..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeQuantizeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.cc deleted file mode 100644 index f9a9af0e8f1fb1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeReshapeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Reshape"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_RESHAPE, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_RESHAPE operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h deleted file mode 100644 index d8b3b3246ecbb8..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeReshapeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.cc deleted file mode 100644 index 8b35a9d0163174..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeRsqrtOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Rsqrt"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_RSQRT, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_RSQRT operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h deleted file mode 100644 index b8ae369796fc78..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeRsqrtOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.cc deleted file mode 100644 index 1f0ea602cec504..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.cc +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeSoftmaxOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Softmax"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_Softmax operation takes an additional scalar operand, which is - // used to pass a Beta value. 
- float beta; - if (auto status = LiteRtGetSoftmaxBetaOption(op.Get(), &beta); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get beta"); - } - auto beta_operand = operand_map.AddScalarFloat32(beta); - if (!beta_operand) { - return beta_operand.Error(); - } - input_indices.push_back(*beta_operand); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_SOFTMAX, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_SOFTMAX operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h deleted file mode 100644 index 22c9ea4f1aed63..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeSoftmaxOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.cc deleted file mode 100644 index 0b26d24bc39f73..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeSubOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Sub"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_SUB operation takes a 3rd scalar operand, which is used to pass a - // TfLiteFusedActivation value. - uint32_t tfl_fused_activation; - if (auto status = - LiteRtGetSubFusedActivationOption(op.Get(), &tfl_fused_activation); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get fused activation"); - } - auto fused_activation_operand_index = - operand_map.AddScalarInt32(tfl_fused_activation); - if (!fused_activation_operand_index) { - return fused_activation_operand_index.Error(); - } - input_indices.push_back(*fused_activation_operand_index); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_SUB, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add value of NEURON_SUB fused activation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h deleted file mode 100644 index bc1e783d55f7b2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeSubOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.cc deleted file mode 100644 index 754a77677a4167..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
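
The softmax (beta) and sub (fused activation) legalizations above illustrate the one wrinkle in this family: NeuronAdapter operands are positional, so op attributes ride along as scalar operands appended after all tensor inputs. A hypothetical convenience wrapper, not from the deleted files:

```c++
// Hedged sketch: appends an op attribute as a trailing scalar operand.
// Order matters; the scalar must follow every tensor input.
Expected<void> AppendScalarInt32Attr(OperandMap& operands,
                                     std::vector<uint32_t>& input_indices,
                                     int32_t value) {
  auto idx = operands.AddScalarInt32(value);
  if (!idx) return idx.Error();
  input_indices.push_back(*idx);
  return {};
}
```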
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeTransposeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Transpose"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_TRANSPOSE, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add reshape operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h deleted file mode 100644 index 94b445b6218025..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_
-
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-namespace litert::mediatek {
-
-Expected<void> LegalizeTransposeOp(const NeuronAdapterApi& neuron_adapter_api,
-                                   NeuronModel* model, OperandMap& operand_map,
-                                   const litert::Op& op);
-
-}  // namespace litert::mediatek
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD
deleted file mode 100644
index 7315f1598a9476..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright 2024 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "copy_file", "litert_dynamic_lib")
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"],
-)
-
-litert_dynamic_lib(
-    name = "dispatch_api",
-    srcs = [
-        "dispatch_api.cc",
-        "litert_dispatch_device_context.cc",
-        "litert_dispatch_invocation_context.cc",
-    ],
-    hdrs = [
-        "litert_dispatch_device_context.h",
-        "litert_dispatch_invocation_context.h",
-    ],
-    copts = [
-        "-Os",
-        "-fno-exceptions",
-        "-fno-unwind-tables",
-        "-fno-asynchronous-unwind-tables",
-        "-ffunction-sections",
-        "-fdata-sections",
-    ],
-    export_litert_only = True,
-    linkopts = select({
-        "//tensorflow:android": ["-landroid"],
-        "//conditions:default": [],
-    }) + [
-        "-Wl,-soname=libLiteRtDispatch_Mediatek.so",
-        "-Wl,-lc++abi",
-    ],
-    shared_lib_name = "dispatch_api_so",
-    so_name = "libLiteRtDispatch_Mediatek.so",
-    tags = [
-        # Remove when the SDK is available to Bazel.
- "nobuilder", - "notap", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:dynamic_loading", - "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek/schema:mediatek_litert_schema", - ], -) - -# This is cc_library target for `libLiteRtDispatch_Mediatek.so`. -cc_library( - name = "dispatch_api_shared_lib", - srcs = [":dispatch_api_so"], -) - -# Copies the shared library so that it is available for use in test data as libLiteRtDispatch_Mediatek.so. -copy_file( - name = "copy_dispatch_api_so", - src = "//tensorflow/lite/experimental/litert/vendors/mediatek/dispatch:dispatch_api_so", - target = "libLiteRtDispatch_Mediatek.so", -) - -cc_test( - name = "dispatch_api_mediatek_test", - srcs = [ - "dispatch_api_mediatek_test.cc", - ], - data = [ - ":dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "no-remote-exec", - "no_oss", - "nobuilder", - "nosan", - "notap", - ], - deps = [ - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/README.md b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/README.md deleted file mode 100644 index 35a6130c76d318..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/README.md +++ /dev/null @@ -1,4 +0,0 @@ -Test case can dispatch_api_mediatek_test can be run on a device with a MetiaTek -mt6989 SoC with the following comands - -$ ../../../google/run_test_on_android.sh dispatch_api_mediatek_test diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc deleted file mode 100644 index b8c5da6ee392f6..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc +++ /dev/null @@ -1,327 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#if LITERT_HAS_AHWB_SUPPORT -#include -#endif - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h" - -namespace { - -litert::mediatek::NeuronAdapterApi* TheNeuronAdapter; -char BuildId[256]; - -} // namespace - -namespace litert { -namespace mediatek { - -// ///////////////////////////////////////////////////////////////////////////// -// Basic Execution API -// ///////////////////////////////////////////////////////////////////////////// - -const char* GetSharedLibraryDir(const LiteRtDispatchOption* options, - int num_options) { - for (auto i = 0; i < num_options; ++i) { - auto& option = options[i]; - if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { - return option.value.str_value; - } - } - return nullptr; -} - -LiteRtStatus LiteRtInitialize(const LiteRtDispatchOption* options, - int num_options) { - auto* shared_library_dir = GetSharedLibraryDir(options, num_options); - std::optional shared_library_dir_opt = - shared_library_dir ? 
std::make_optional(std::string(shared_library_dir)) - : std::nullopt; - - if (auto neuron_adapter_api = - litert::mediatek::NeuronAdapterApi::Create(shared_library_dir_opt); - neuron_adapter_api) { - TheNeuronAdapter = neuron_adapter_api->release(); - } else { - LITERT_LOG(LITERT_INFO, "Initialization failure: %s", - neuron_adapter_api.Error().Message().c_str()); - return neuron_adapter_api.Error().Status(); - } - - auto get_version = TheNeuronAdapter->api().get_version; - if (!get_version) { - LITERT_LOG(LITERT_ERROR, "get_version not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - - NeuronRuntimeVersion version; - if (get_version(&version) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to get version"); - return kLiteRtStatusErrorRuntimeFailure; - } - LITERT_LOG(LITERT_INFO, "Neuron SDK version: %d.%d.%d", version.major, - version.minor, version.patch); - - snprintf(BuildId, sizeof(BuildId), - "MediaTek Dispatch API version %d.%d.%d, NeuronAdaptor API version " - "%d.%d.%d", - LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR, - LITERT_API_VERSION_PATCH, version.major, version.minor, - version.patch); - BuildId[sizeof(BuildId) - 1] = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetVendorId(const char** vendor_id) { - *vendor_id = "MediaTek"; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetBuildId(const char** build_id) { - *build_id = BuildId; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCapabilities(int* capabilities) { - *capabilities = kLiteRtDispatchCapabilitiesBasic; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDeviceContextCreate( - LiteRtDispatchDeviceContext* device_context) { - if (auto context = LiteRtDispatchDeviceContextT::Create(*TheNeuronAdapter); - context) { - *device_context = context->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s", - context.Error().Message().c_str()); - return context.Error().Status(); - } -} - -LiteRtStatus LiteRtDeviceContextDestroy( - LiteRtDispatchDeviceContext device_context) { - delete device_context; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetInputRequirements( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto requirements = - invocation_context->GetInputRequirements(input_index, *tensor_type); - requirements) { - *tensor_buffer_requirements = *requirements; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", - requirements.Error().Message().c_str()); - return requirements.Error().Status(); - } -} - -LiteRtStatus LiteRtGetOutputRequirements( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto requirements = - invocation_context->GetOutputRequirements(output_index, *tensor_type); - requirements) { - *tensor_buffer_requirements = *requirements; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", - requirements.Error().Message().c_str()); - return requirements.Error().Status(); - } -} - -LiteRtStatus LiteRtRegisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle) { - if (auto result = 
device_context->RegisterTensorBuffer(tensor_buffer); - result) { - *tensor_buffer_handle = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffer: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus LiteRtUnregisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = - device_context->UnregisterTensorBuffer(tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to unregister tensor buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtInvocationContextCreate( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context) { - auto context = LiteRtDispatchInvocationContextT::Create( - *TheNeuronAdapter, device_context, exec_type, exec_bytecode_buffer, - function_name, num_inputs, num_outputs); - if (!context) { - LITERT_LOG(LITERT_ERROR, "Failed to create context from context binary: %s", - context.Error().Message().c_str()); - return context.Error().Status(); - } - *invocation_context = context->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtInvocationContextDestroy( - LiteRtDispatchInvocationContext invocation_context) { - delete invocation_context; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtAttachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->AttachInput(graph_input_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to attach input: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtAttachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->AttachOutput(graph_output_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to attach output: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDetachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->DetachInput(graph_input_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to detach input: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDetachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->DetachOutput(graph_output_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to detach output: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtInvoke(LiteRtDispatchInvocationContext invocation_context) { - if (auto status = invocation_context->Invoke(); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to invoke context: %s", - 
status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -} // namespace mediatek -} // namespace litert - -// ///////////////////////////////////////////////////////////////////////////// - -namespace { - -LiteRtDispatchInterface TheInterface = { - .initialize = litert::mediatek::LiteRtInitialize, - .get_vendor_id = litert::mediatek::LiteRtGetVendorId, - .get_build_id = litert::mediatek::LiteRtGetBuildId, - .get_capabilities = litert::mediatek::LiteRtGetCapabilities, - .device_context_create = litert::mediatek::LiteRtDeviceContextCreate, - .device_context_destroy = litert::mediatek::LiteRtDeviceContextDestroy, - .get_input_requirements = litert::mediatek::LiteRtGetInputRequirements, - .get_output_requirements = litert::mediatek::LiteRtGetOutputRequirements, - .register_tensor_buffer = litert::mediatek::LiteRtRegisterTensorBuffer, - .unregister_tensor_buffer = litert::mediatek::LiteRtUnregisterTensorBuffer, - .invocation_context_create = - litert::mediatek::LiteRtInvocationContextCreate, - .invocation_context_destroy = - litert::mediatek::LiteRtInvocationContextDestroy, - .attach_input = litert::mediatek::LiteRtAttachInput, - .attach_output = litert::mediatek::LiteRtAttachOutput, - .detach_input = litert::mediatek::LiteRtDetachInput, - .detach_output = litert::mediatek::LiteRtDetachOutput, - .invoke = litert::mediatek::LiteRtInvoke, -}; - -LiteRtDispatchApi TheApi = { - .version = {.major = LITERT_API_VERSION_MAJOR, - .minor = LITERT_API_VERSION_MINOR, - .patch = LITERT_API_VERSION_PATCH}, - .interface = &TheInterface, - .async_interface = nullptr, - .graph_interface = nullptr, -}; - -} // namespace - -LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { - *api = TheApi; - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc deleted file mode 100644 index 9926f55e6884b6..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc +++ /dev/null @@ -1,638 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
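
The interface table above implies a fixed call order. A condensed happy-path sketch through the public LiteRtDispatch* wrappers, with status returns ignored for brevity; it assumes the wrapper names mirror the interface entries (the test below uses several of them) and that the buffer handles were registered against `dev` via LiteRtDispatchRegisterTensorBuffer:

```c++
// Hedged lifecycle sketch; the test that follows runs the same sequence
// with full error checking.
void RunOnce(LiteRtDispatchDeviceContext dev, const LiteRtMemBuffer& bytecode,
             LiteRtTensorBufferHandle input0, LiteRtTensorBufferHandle input1,
             LiteRtTensorBufferHandle output) {
  LiteRtDispatchInvocationContext inv = nullptr;
  LiteRtDispatchInvocationContextCreate(
      dev, kLiteRtDispatchExecutableTypeMlModel, &bytecode,
      /*function_name=*/nullptr, /*num_inputs=*/2, /*num_outputs=*/1, &inv);

  LiteRtDispatchAttachInput(inv, /*graph_input_index=*/0, input0);
  LiteRtDispatchAttachInput(inv, /*graph_input_index=*/1, input1);
  LiteRtDispatchAttachOutput(inv, /*graph_output_index=*/0, output);
  LiteRtDispatchInvoke(inv);

  LiteRtDispatchDetachInput(inv, 0, input0);
  LiteRtDispatchDetachInput(inv, 1, input1);
  LiteRtDispatchDetachOutput(inv, 0, output);
  LiteRtDispatchInvocationContextDestroy(inv);
}
```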
- -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using ::testing::Pointwise; - -TEST(MediaTek, DispatchApiWithAhwb) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a MediaTek NPU"; -#endif - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*litert::ToLiteRtAny(std::any("/data/local/tmp")), - }; - ASSERT_EQ( - LiteRtDispatchInitialize(/*options=*/&dispatch_option, /*num_options=*/1), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." << api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kMediaTekModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/nullptr, - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. 
- // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/0, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/0, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/0, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with more data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor_2, - sizeof(kTestInput0Tensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor_2, - sizeof(kTestInput1Tensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model once more. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking second execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor_2[i]; - } - EXPECT_THAT(output, - Pointwise(testing::FloatNear(1e-3), kTestOutputTensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. 
- // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} - -TEST(MediaTek, DispatchApiWithDmaBuf) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a MediaTek NPU"; -#endif - - EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." << api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kMediaTekModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/nullptr, - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. 
- // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 2); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/1, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 2); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/1, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 2); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/1, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with more data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor_2, - sizeof(kTestInput0Tensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor_2, - sizeof(kTestInput1Tensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model once more. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking second execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor_2[i]; - } - EXPECT_THAT(output, - Pointwise(testing::FloatNear(1e-3), kTestOutputTensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. 
- // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc deleted file mode 100644 index b728f5c5c15d5f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
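Both deleted tests above repeat the same lock/memcpy/unlock idiom whenever they fill an input buffer or read an output buffer. A small helper along these lines captures the pattern; it is an illustrative sketch, not code from the deleted files, and the helper name is hypothetical:

#include <cstddef>
#include <cstring>

#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"

// Hypothetical helper: copies host data into a (possibly AHWB- or
// DMA-BUF-backed) tensor buffer through its locked host mapping.
LiteRtStatus CopyIntoTensorBuffer(LiteRtTensorBuffer buffer, const void* src,
                                  size_t size) {
  void* host_mem_addr = nullptr;
  if (auto status = LiteRtLockTensorBuffer(buffer, &host_mem_addr);
      status != kLiteRtStatusOk) {
    return status;
  }
  std::memcpy(host_mem_addr, src, size);
  // Unlocking releases the mapping and publishes the writes.
  return LiteRtUnlockTensorBuffer(buffer);
}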
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h" - -#include - -#include -#include - -#include "neuron/api/NeuronAdapter.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -using litert::Error; - -LiteRtDispatchDeviceContextT::~LiteRtDispatchDeviceContextT() = default; - -litert::Expected -LiteRtDispatchDeviceContextT::Create( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api) { - return std::unique_ptr( - new LiteRtDispatchDeviceContextT(neuron_adapter_api)); -} - -litert::Expected -LiteRtDispatchDeviceContextT::RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer) { - LiteRtTensorBufferType tensor_buffer_type; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type)); - - if (tensor_buffer_type != kLiteRtTensorBufferTypeAhwb && - tensor_buffer_type != kLiteRtTensorBufferTypeDmaBuf) { - return Error(kLiteRtStatusErrorUnsupported, "Unsupported buffer type"); - } - - size_t tensor_buffer_size; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size)); - - size_t tensor_buffer_offset; - if (auto status = - LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); - status != kLiteRtStatusOk) { - if (status == kLiteRtStatusErrorNotFound) { - tensor_buffer_offset = 0; - } else { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get buffer offset"); - } - } - - LiteRtRankedTensorType tensor_type; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type)); - - auto* tensor_strides = tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are not supported"); - } - - switch (tensor_buffer_type) { - case kLiteRtTensorBufferTypeAhwb: -#if LITERT_HAS_AHWB_SUPPORT - AHardwareBuffer* ahwb; - if (auto status = LiteRtGetTensorBufferAhwb(tensor_buffer, &ahwb); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get AHWB"); - } -#else - return Error(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not supported on this platform"); -#endif // LITERT_HAS_AHWB_SUPPORT - NeuronMemory* neuron_memory; -#if LITERT_HAS_AHWB_SUPPORT - if (neuron_adapter_api_.api().memory_create_from_ahwb( - ahwb, &neuron_memory) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create NeuronMemory from AHWB"); - } - return neuron_memory_registry_.Register(neuron_memory, tensor_buffer_size, - tensor_buffer_offset); -#else - (void)neuron_adapter_api_; - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not supported on this platform"); -#endif // LITERT_HAS_AHWB_SUPPORT - break; - - case kLiteRtTensorBufferTypeDmaBuf: - - int fd; -#if LITERT_HAS_DMABUF_SUPPORT - void* addr; - if (auto status = - LiteRtGetTensorBufferDmaBufBuffer(tensor_buffer, &addr, &fd); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get DMA-BUF"); - } -#else - return Error(kLiteRtStatusErrorRuntimeFailure, - "DMA-BUF is not 
supported on this platform"); -#endif // LITERT_HAS_DMABUF_SUPPORT - if (neuron_adapter_api_.api().memory_create_from_fd( - tensor_buffer_size, /*protect*/ PROT_READ | PROT_WRITE, fd, - tensor_buffer_offset, &neuron_memory) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create NeuronMemory from DMA-BUF"); - } - return neuron_memory_registry_.Register(neuron_memory, tensor_buffer_size, - tensor_buffer_offset); - break; - - default: - LITERT_LOG(LITERT_ERROR, "Unsupported buffer type: %d", - tensor_buffer_type); - return litert::Unexpected(kLiteRtStatusErrorUnsupported); - } -} - -LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::~NeuronMemoryRegistry() { - for (auto i = 0; i < records_.size(); ++i) { - auto& record = records_[i]; - if (record.neuron_memory != nullptr) { - neuron_adapter_api_.api().memory_free(record.neuron_memory); - } - } -} - -LiteRtTensorBufferHandle -LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::Register( - NeuronMemory* neuron_memory, size_t size, size_t offset) { - int dest_index = -1; - for (auto i = 0; i < records_.size(); ++i) { - if (!records_[i].neuron_memory) { - dest_index = i; - break; - } - } - if (dest_index < 0) { - dest_index = records_.size(); - records_.push_back({}); - } - auto& dest = records_[dest_index]; - dest = {neuron_memory, size, offset}; - return dest_index; -} - -litert::Expected -LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::Unregister( - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto record = Find(tensor_buffer_handle); - if (!record) { - return record.Error(); - } else { - auto& mem = (*record)->neuron_memory; - neuron_adapter_api_.api().memory_free(mem); - mem = nullptr; - return {}; - } -} - -litert::Expected -LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::Find( - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (tensor_buffer_handle < 0 || tensor_buffer_handle >= records_.size()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid tensor buffer handle"); - } - return &records_[tensor_buffer_handle]; -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h deleted file mode 100644 index 483701fe919acc..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ - -#include - -#include "neuron/api/NeuronAdapter.h" -#include "absl/container/flat_hash_set.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -class LiteRtDispatchDeviceContextT { - public: - using Ptr = std::unique_ptr; - struct NeuronMemoryInfo { - NeuronMemory* neuron_memory; - size_t size; - size_t offset; - }; - - ~LiteRtDispatchDeviceContextT(); - - static litert::Expected Create( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api); - - litert::Expected RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer); - - litert::Expected UnregisterTensorBuffer( - LiteRtTensorBufferHandle tensor_buffer_handle) { - return neuron_memory_registry_.Unregister(tensor_buffer_handle); - } - - litert::Expected GetNeuronMemoryInfo( - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto record = neuron_memory_registry_.Find(tensor_buffer_handle); - if (!record) { - return record.Error(); - } else { - return NeuronMemoryInfo(**record); - } - } - - private: - class NeuronMemoryRegistry { - public: - explicit NeuronMemoryRegistry( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api) - : neuron_adapter_api_(neuron_adapter_api) {} - ~NeuronMemoryRegistry(); - LiteRtTensorBufferHandle Register(NeuronMemory* neuron_memory, size_t size, - size_t offset); - litert::Expected Unregister( - LiteRtTensorBufferHandle tensor_buffer_handle); - litert::Expected Find( - LiteRtTensorBufferHandle tensor_buffer_handle); - - private: - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api_; - std::vector records_; - }; - - explicit LiteRtDispatchDeviceContextT( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api) - : neuron_adapter_api_(neuron_adapter_api), - neuron_memory_registry_(neuron_adapter_api) {} - - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api_; - NeuronMemoryRegistry neuron_memory_registry_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc deleted file mode 100644 index 2f235e182c9614..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc +++ /dev/null @@ -1,435 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
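The implementation below, like the files above, threads errors through litert::Expected rather than raw status codes: success carries a value, failure carries a status plus a message, and the C shims convert back with Error().Status(). A toy example of the idiom (illustrative only; ParsePositive is not a real API):

#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"

// Success implicitly wraps the value; failure carries status and message.
litert::Expected<int> ParsePositive(int raw) {
  if (raw <= 0) {
    return litert::Error(kLiteRtStatusErrorInvalidArgument,
                         "value must be positive");
  }
  return raw;
}

// Typical call site, mirroring the C shims in dispatch_api.cc above.
LiteRtStatus Consume(int raw) {
  auto parsed = ParsePositive(raw);
  if (!parsed) {
    return parsed.Error().Status();
  }
  return kLiteRtStatusOk;
}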
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h" - -using litert::Error; -using litert::Expected; -using litert::mediatek::NeuronCompilationPtr; -using litert::mediatek::NeuronExecutionPtr; -using litert::mediatek::NeuronModelPtr; - -namespace { - -Expected> LoadFromCachedNetwork( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - const void* bytecode_addr, size_t bytecode_size) { - NeuronModel* model; - NeuronCompilation* compilation; - if (neuron_adapter_api.api().model_restore_from_compiled_network( - &model, &compilation, bytecode_addr, bytecode_size) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to restore model from compiled network"); - } - return std::make_pair( - NeuronModelPtr{model, neuron_adapter_api.api().model_free}, - NeuronCompilationPtr{compilation, - neuron_adapter_api.api().compilation_free}); -} - -uint16_t GetRestoreDlaExtensionOperandType( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api) { - NeuronRuntimeVersion version; - neuron_adapter_api.api().get_version(&version); - // The values below were suggested by MTK. - if (version.major >= 8) { - return 0x0200; - } else { - return 0x0100; - } -} - -Expected> LoadFromDlaBytecode( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - const void* bytecode_addr, size_t bytecode_size, int num_inputs, - int num_outputs) { - Expected model = neuron_adapter_api.CreateModel(); - if (!model) { - return model.Error(); - } - - // fake input, the real outputs are loaded by compiled network. 
-  constexpr const NeuronOperandType fake_io_operand_type{
-      .type = NEURON_TENSOR_FLOAT32,
-      .dimensionCount = 0,
-      .scale = 0.0f,
-      .zeroPoint = 0,
-  };
-
-  std::vector<uint32_t> input_op_number;
-  input_op_number.reserve(num_inputs);
-  for (auto i = 0; i < num_inputs; i++) {
-    if (neuron_adapter_api.api().model_add_operand(
-            model->get(), &fake_io_operand_type) != NEURON_NO_ERROR) {
-      return Error(kLiteRtStatusErrorRuntimeFailure,
-                   "Failed to add input operand");
-    }
-    input_op_number.emplace_back(i);
-  }
-
-  const uint16_t kNetworkOperandRestoreData =
-      GetRestoreDlaExtensionOperandType(neuron_adapter_api);
-  constexpr const uint16_t kRestoreDlaExtensionOperationType = 0;
-  constexpr const char* kExtensionRestoreCompiledNetwork =
-      "com.mediatek.compiled_network";
-
-  int32_t operand_type;
-  if (neuron_adapter_api.api().model_get_extension_operand_type(
-          model->get(), kExtensionRestoreCompiledNetwork,
-          kNetworkOperandRestoreData, &operand_type) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to get extension operand");
-  }
-
-  const NeuronOperandType extension_operand_type{
-      .type = operand_type,
-      .dimensionCount = 0,
-      .scale = 0.0f,
-      .zeroPoint = 0,
-  };
-  if (neuron_adapter_api.api().model_add_operand(
-          model->get(), &extension_operand_type) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to add extension operand");
-  }
-  input_op_number.emplace_back(input_op_number.size());
-  if (neuron_adapter_api.api().model_set_operand_value(
-          model->get(), input_op_number.back(), bytecode_addr, bytecode_size) !=
-      NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to set extension operand value");
-  }
-
-  std::vector<uint32_t> output_op_number;
-  for (auto i = 0; i < num_outputs; i++) {
-    if (neuron_adapter_api.api().model_add_operand(
-            model->get(), &fake_io_operand_type) != NEURON_NO_ERROR) {
-      return Error(kLiteRtStatusErrorRuntimeFailure,
-                   "Failed to add output operand");
-    }
-    output_op_number.emplace_back(input_op_number.size() + i);
-  }
-
-  int32_t operation_type;
-  if (neuron_adapter_api.api().model_get_extension_operation_type(
-          model->get(), kExtensionRestoreCompiledNetwork,
-          kRestoreDlaExtensionOperationType,
-          &operation_type) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to get extension operation");
-  }
-
-  // Add the extension operation that restores the compiled network.
-  if (neuron_adapter_api.api().model_add_operation(
-          model->get(), static_cast<NeuronOperationType>(operation_type),
-          input_op_number.size(), input_op_number.data(),
-          output_op_number.size(),
-          output_op_number.data()) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to add extension operation");
-  }
-
-  if (neuron_adapter_api.api().model_identify_inputs_and_outputs(
-          model->get(), input_op_number.size() - 1, input_op_number.data(),
-          output_op_number.size(),
-          output_op_number.data()) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to identify I/Os");
-  }
-
-  if (neuron_adapter_api.api().model_finish(model->get()) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to finish model");
-  }
-
-  auto compilation = neuron_adapter_api.CreateCompilation(model->get());
-  if (!compilation) {
-    return compilation.Error();
-  }
-
-  if (neuron_adapter_api.api().compilation_set_priority(
-          compilation->get(), NEURON_PRIORITY_HIGH) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to set compilation priority");
-  }
-
-  if
(neuron_adapter_api.api().compilation_set_preference( - compilation->get(), NEURON_PREFER_SUSTAINED_SPEED) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set compilation preference"); - } - - // We use AOT compile options since the DLA file was compiled ahead of time. - const auto compile_options = - std::string(neuron_adapter_api.AotCompileOptions()); - if (!compile_options.empty()) { - if (neuron_adapter_api.api().compilation_set_optimization_string( - compilation->get(), compile_options.c_str()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set optimization string"); - } - } - - if (neuron_adapter_api.api().compilation_finish(compilation->get()) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to finish compilation"); - } - - return std::make_pair(std::move(*model), std::move(*compilation)); -} - -Expected> -LoadModelAndCompilation( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - const void* bytecode_addr, size_t bytecode_size, int num_inputs, - int num_outputs) { - if (auto result = LoadFromDlaBytecode(neuron_adapter_api, bytecode_addr, - bytecode_size, num_inputs, num_outputs); - !result) { - return LoadFromCachedNetwork(neuron_adapter_api, bytecode_addr, - bytecode_size); - } else { - return result; - } -} - -} // namespace - -Expected -LiteRtDispatchInvocationContextT::Create( - litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs) { - neuron::SchemaResolver resolver; - - const void* exec_bytecode_ptr = - static_cast(exec_bytecode_buffer->base_addr) + - exec_bytecode_buffer->offset; - auto exec_bytecode_size = exec_bytecode_buffer->size; - auto res = resolver.Initialize((const uint8_t*)exec_bytecode_ptr, - exec_bytecode_size); - if (res.HasValue() && res.Value()) { - std::string func = function_name != nullptr ? 
function_name : ""; - auto graph = resolver.GetCompiledGraph(func); - if (!graph.has_value()) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Couldn't find the subgraph"); - } - auto compile_graph = graph.value().GetCompiledNetwork(); - if (!compile_graph) { - return compile_graph.Error(); - } - std::tie(exec_bytecode_ptr, exec_bytecode_size) = compile_graph.Value(); - } - - auto model_and_compilation = - LoadModelAndCompilation(neuron_adapter_api, exec_bytecode_ptr, - exec_bytecode_size, num_inputs, num_outputs); - if (!model_and_compilation) { - return model_and_compilation.Error(); - } - - auto& model = model_and_compilation->first; - auto& compilation = model_and_compilation->second; - - auto execution = neuron_adapter_api.CreateExecution(compilation.get()); - if (!execution) { - return execution.Error(); - } - - if (neuron_adapter_api.api().execution_set_boost_hint( - execution->get(), 100) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set execution boost hint"); - } - - return Ptr(new LiteRtDispatchInvocationContextT( - neuron_adapter_api, device_context, model.release(), - compilation.release(), execution->release(), num_inputs, num_outputs)); -} - -LiteRtDispatchInvocationContextT::~LiteRtDispatchInvocationContextT() { - if (execution_) { - neuron_adapter_api_.api().execution_free(execution_); - } - if (compilation_) { - neuron_adapter_api_.api().compilation_free(compilation_); - } - if (model_) { - neuron_adapter_api_.api().model_free(model_); - } -} - -LiteRtDispatchInvocationContextT::IoRequirementsBuilder::IoRequirementsBuilder( - size_t buffer_size, const std::vector& padded_dimensions) - : buffer_size_(buffer_size) { - auto rank = padded_dimensions.size(); - strides_.resize(rank); - strides_[0] = 1; - for (auto i = 1; i < rank; ++i) { - strides_[i] = padded_dimensions[i - 1]; - } -} - -Expected -LiteRtDispatchInvocationContextT::IoRequirementsBuilder::Create() { - static constexpr std::array kSupportedTensorBufferTypes = { -#if defined(__ANDROID__) - kLiteRtTensorBufferTypeAhwb, -#endif // __ANDROID__ - kLiteRtTensorBufferTypeDmaBuf, - }; - - LiteRtTensorBufferRequirements requirements; - if (auto status = LiteRtCreateTensorBufferRequirements( - kSupportedTensorBufferTypes.size(), - kSupportedTensorBufferTypes.data(), buffer_size_, strides_.size(), - strides_.data(), &requirements); - status != kLiteRtStatusOk) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to create tensor buffer requirements"); - } - - return requirements; -} - -Expected -LiteRtDispatchInvocationContextT::GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type) { - if (!input_requirements_builders_[input_index]) { - size_t buffer_size; - if (neuron_adapter_api_.api().compilation_get_input_padded_size( - compilation_, input_index, &buffer_size) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get input padded size"); - } - - std::vector padded_dimensions(tensor_type.layout.rank); - if (neuron_adapter_api_.api().compilation_get_input_padded_dimensions( - compilation_, input_index, padded_dimensions.data()) != - NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get input padded dimensions"); - } - - input_requirements_builders_[input_index] = - std::make_unique(buffer_size, padded_dimensions); - } - - return input_requirements_builders_[input_index]->Create(); -} - -Expected 
-LiteRtDispatchInvocationContextT::GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type) { - if (!output_requirements_builders_[output_index]) { - size_t buffer_size; - if (neuron_adapter_api_.api().compilation_get_output_padded_size( - compilation_, output_index, &buffer_size) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get output padded size"); - } - - std::vector padded_dimensions(tensor_type.layout.rank); - if (neuron_adapter_api_.api().compilation_get_output_padded_dimensions( - compilation_, output_index, padded_dimensions.data()) != - NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get output padded dimensions"); - } - - output_requirements_builders_[output_index] = - std::make_unique(buffer_size, padded_dimensions); - } - - return output_requirements_builders_[output_index]->Create(); -} - -Expected LiteRtDispatchInvocationContextT::AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - auto neuron_memory_info = - device_context_->GetNeuronMemoryInfo(tensor_buffer_handle); - if (!neuron_memory_info) { - return litert::Error(neuron_memory_info.Error()); - } - - if (neuron_adapter_api_.api().execution_set_input_from_memory( - execution_, graph_input_index, nullptr, - neuron_memory_info->neuron_memory, neuron_memory_info->offset, - neuron_memory_info->size) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set execution input from memory"); - } - return {}; -} - -Expected LiteRtDispatchInvocationContextT::AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - auto neuron_memory_info = - device_context_->GetNeuronMemoryInfo(tensor_buffer_handle); - if (!neuron_memory_info) { - return litert::Error(neuron_memory_info.Error()); - } - - if (neuron_adapter_api_.api().execution_set_output_from_memory( - execution_, graph_output_index, nullptr, - neuron_memory_info->neuron_memory, neuron_memory_info->offset, - neuron_memory_info->size) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set execution output from memory"); - } - return {}; -} - -Expected LiteRtDispatchInvocationContextT::DetachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - // Nothing to do. - return {}; -} - -Expected LiteRtDispatchInvocationContextT::DetachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - // Nothing to do. - return {}; -} - -Expected LiteRtDispatchInvocationContextT::Invoke() { - if (neuron_adapter_api_.api().execution_compute(execution_) != - NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to execute network"); - } - return {}; -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h deleted file mode 100644 index f58ee976b693e2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ - -#include - -#include "neuron/api/NeuronAdapter.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -class LiteRtDispatchInvocationContextT { - public: - using Ptr = std::unique_ptr; - - static litert::Expected Create( - litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs); - - ~LiteRtDispatchInvocationContextT(); - - litert::Expected GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type); - - litert::Expected GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type); - - litert::Expected AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle); - litert::Expected AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected DetachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle); - litert::Expected DetachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected Invoke(); - - private: - class IoRequirementsBuilder { - public: - IoRequirementsBuilder(size_t buffer_size, - const std::vector& padded_dimensions); - litert::Expected Create(); - - private: - size_t buffer_size_; - std::vector strides_; - }; - - LiteRtDispatchInvocationContextT( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - LiteRtDispatchDeviceContext device_context, NeuronModel* model, - NeuronCompilation* compilation, NeuronExecution* execution, - int num_inputs, int num_outputs) - : neuron_adapter_api_(neuron_adapter_api), - device_context_(device_context), - model_(model), - compilation_(compilation), - execution_(execution), - input_requirements_builders_(num_inputs), - output_requirements_builders_(num_outputs) {} - - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api_; - LiteRtDispatchDeviceContext device_context_; - NeuronModel* model_; - NeuronCompilation* compilation_; - NeuronExecution* execution_; - std::vector> - input_requirements_builders_; - std::vector> - output_requirements_builders_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/mediatek_build_defs.bzl b/tensorflow/lite/experimental/litert/vendors/mediatek/mediatek_build_defs.bzl deleted file mode 100644 index 
5427e9e0d29521..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/mediatek_build_defs.bzl +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Build definitions for Mediatek backend.""" - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "append_rule_kwargs", "litert_bin", "litert_lib", "make_rpaths") - -_MTK_STD_LIBS_HOST = [ - # copybara:uncomment_begin(google-only) - # "//third_party/neuro_pilot:latest/host/lib/libc++.so.1", - # "//third_party/neuro_pilot:latest/host/lib/libstdc++.so.6", - # copybara:uncomment_end -] # @unused - -_MTK_NEURON_ADAPTER_SO = [ - # copybara:uncomment_begin(google-only) - # "//third_party/neuro_pilot:latest/host/lib/libneuron_adapter.so", - # copybara:uncomment_end -] - -# TODO: Make rpaths dynamic with "$(location {})". -_MTK_HOST_RPATHS = [ - # copybara:uncomment_begin(google-only) - # "third_party/neuro_pilot/latest/host/lib", - # copybara:uncomment_end -] - -def _litert_with_mtk_base( - litert_rule, - use_custom_libcc, - **litert_rule_kwargs): - if use_custom_libcc: - # TODO: Figure out strategy for custom libcc. - fail("Custom libcc not yet supported") - - data_x86_64 = [] - data_x86_64.extend(_MTK_NEURON_ADAPTER_SO) - append_rule_kwargs( - litert_rule_kwargs, - data = select({ - "//tensorflow:linux_x86_64": data_x86_64, - "//conditions:default": [], - }), - linkopts = select({ - "//tensorflow:linux_x86_64": [make_rpaths(_MTK_HOST_RPATHS)], - "//conditions:default": [], - }), - ) - - litert_rule(**litert_rule_kwargs) - -def litert_cc_lib_with_mtk( - use_custom_libcc = False, - **litert_lib_kwargs): - """Creates a litert_lib target with Mediatek backend dependencies. - - Args: - use_custom_libcc: Whether to use a custom libcc. Not yet supported. - **litert_lib_kwargs: Keyword arguments passed to litert_lib. - """ - _litert_with_mtk_base( - litert_lib, - use_custom_libcc, - **litert_lib_kwargs - ) - -def litert_cc_bin_with_mtk( - use_custom_libcc = False, - **litert_bin_kwargs): - """Creates a litert_bin target with Mediatek backend dependencies. - - Args: - use_custom_libcc: Whether to use a custom libcc. Not yet supported. - **litert_bin_kwargs: Keyword arguments passed to litert_bin. - """ - _litert_with_mtk_base( - litert_bin, - use_custom_libcc, - **litert_bin_kwargs - ) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.cc deleted file mode 100644 index ab3cbfaddb1287..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.cc +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h"
-
-#include <dlfcn.h>
-
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/strings/str_cat.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h"
-
-#define LOAD_SYMB(S, H)                                                  \
-  if (auto maybe_H = dlib_.LookupSymbol(#S); maybe_H.HasValue()) {       \
-    H = reinterpret_cast<decltype(H)>(std::move(maybe_H).Value());       \
-  } else {                                                               \
-    LITERT_LOG(LITERT_WARNING, "Failed to load symbol %s: %s", #S,       \
-               dlib_.DlError());                                         \
-  }
-
-namespace litert {
-namespace mediatek {
-
-NeuronAdapterApi::NeuronAdapterApi() : api_(new Api) {}
-
-litert::Expected<NeuronAdapterApi::Ptr> NeuronAdapterApi::Create(
-    std::optional<std::string> shared_library_dir) {
-  std::unique_ptr<NeuronAdapterApi> neuron_adapter_api(new NeuronAdapterApi);
-  if (auto status = neuron_adapter_api->LoadSymbols(shared_library_dir);
-      !status) {
-    LITERT_LOG(LITERT_ERROR, "Failed to load NeuronAdapter shared library: %s",
-               status.Error().Message().c_str());
-    return status.Error();
-  }
-
-  return neuron_adapter_api;
-}
-
-litert::Expected<void> NeuronAdapterApi::LoadSymbols(
-    std::optional<std::string> shared_library_dir) {
-  constexpr auto kLibNeuronAdapterLib = "libneuron_adapter.so";
-
-  const std::vector<std::string> so_paths = {
-      // The following preinstalled libraries are for system partition
-      // applications.
-      "libneuronusdk_adapter.mtk.so", "libneuron_adapter_mgvi.so",
-      kLibNeuronAdapterLib,
-      // Finally, the app may provide its own version of the library.
-      shared_library_dir.has_value()
-          ? absl::StrCat(*shared_library_dir, "/", kLibNeuronAdapterLib)
-          : kLibNeuronAdapterLib};
-  for (auto& so_path : so_paths) {
-    auto maybe_dlib = SharedLibrary::Load(so_path, RtldFlags::Default());
-    if (maybe_dlib.HasValue()) {
-      dlib_ = std::move(maybe_dlib).Value();
-    }
-  }
-
-  if (!dlib_.Loaded()) {
-    return litert::Error(kLiteRtStatusErrorDynamicLoading,
-                         "Failed to load NeuronAdapter shared library");
-  }
-
-  LITERT_LOG(LITERT_INFO, "Loaded NeuronAdapter shared library.");
-
-  // Binds all supported symbols from the shared library to the function
-  // pointers.
-  LOAD_SYMB(NeuronCompilation_create, api_->compilation_create);
-  LOAD_SYMB(NeuronCompilation_createWithOptions,
-            api_->compilation_create_with_options);
-  LOAD_SYMB(NeuronCompilation_finish, api_->compilation_finish);
-  LOAD_SYMB(NeuronCompilation_free, api_->compilation_free);
-  LOAD_SYMB(NeuronCompilation_getInputPaddedDimensions,
-            api_->compilation_get_input_padded_dimensions);
-  LOAD_SYMB(NeuronCompilation_getInputPaddedSize,
-            api_->compilation_get_input_padded_size);
-  LOAD_SYMB(NeuronCompilation_getOutputPaddedDimensions,
-            api_->compilation_get_output_padded_dimensions);
-  LOAD_SYMB(NeuronCompilation_getOutputPaddedSize,
-            api_->compilation_get_output_padded_size);
-  LOAD_SYMB(NeuronCompilation_setOptimizationString,
-            api_->compilation_set_optimization_string);
-  LOAD_SYMB(NeuronCompilation_setPreference, api_->compilation_set_preference);
-  LOAD_SYMB(NeuronCompilation_setPriority, api_->compilation_set_priority);
-  LOAD_SYMB(NeuronExecution_compute, api_->execution_compute);
-  LOAD_SYMB(NeuronExecution_create, api_->execution_create);
-  LOAD_SYMB(NeuronExecution_free, api_->execution_free);
-  LOAD_SYMB(NeuronCompilation_getCompiledNetworkSize,
-            api_->compilation_get_compiled_network_size);
-  LOAD_SYMB(NeuronCompilation_storeCompiledNetwork,
-            api_->compilation_store_compiled_network);
-  LOAD_SYMB(NeuronExecution_setBoostHint, api_->execution_set_boost_hint);
-  LOAD_SYMB(NeuronExecution_setInputFromMemory,
-            api_->execution_set_input_from_memory);
-  LOAD_SYMB(NeuronExecution_setOutputFromMemory,
-            api_->execution_set_output_from_memory);
-  LOAD_SYMB(NeuronMemory_createFromAHardwareBuffer,
-            api_->memory_create_from_ahwb);
-  LOAD_SYMB(NeuronMemory_createFromFd, api_->memory_create_from_fd);
-  LOAD_SYMB(NeuronMemory_free, api_->memory_free);
-  LOAD_SYMB(NeuronModel_addOperand, api_->model_add_operand);
-  LOAD_SYMB(NeuronModel_addOperation, api_->model_add_operation);
-  LOAD_SYMB(NeuronModel_create, api_->model_create);
-  LOAD_SYMB(NeuronModel_finish, api_->model_finish);
-  LOAD_SYMB(NeuronModel_free, api_->model_free);
-  LOAD_SYMB(NeuronModel_getExtensionOperandType,
-            api_->model_get_extension_operand_type);
-  LOAD_SYMB(NeuronModel_getExtensionOperationType,
-            api_->model_get_extension_operation_type);
-  LOAD_SYMB(NeuronModel_identifyInputsAndOutputs,
-            api_->model_identify_inputs_and_outputs);
-  LOAD_SYMB(NeuronModel_restoreFromCompiledNetwork,
-            api_->model_restore_from_compiled_network);
-  LOAD_SYMB(NeuronModel_setName, api_->model_set_name);
-  LOAD_SYMB(NeuronModel_setOperandValue, api_->model_set_operand_value);
-  LOAD_SYMB(NeuronModel_setOperandSymmPerChannelQuantParams,
-            api_->model_set_symm_per_channel_quant_params);
-  LOAD_SYMB(Neuron_getVersion, api_->get_version);
-
-  LITERT_LOG(LITERT_INFO, "NeuronAdapter symbols loaded");
-  return {};
-}
-
-Expected<NeuronModelPtr> NeuronAdapterApi::CreateModel() const {
-  NeuronModel* model;
-  if (api().model_create(&model) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to create NeuronModel");
-  }
-  return NeuronModelPtr{model, api().model_free};
-}
-
-Expected<NeuronCompilationPtr> NeuronAdapterApi::CreateCompilation(
-    NeuronModel* model) const {
-  NeuronCompilation* compilation;
-  if (api().compilation_create(model, &compilation) != NEURON_NO_ERROR) {
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to create NeuronCompilation");
-  }
-  return NeuronCompilationPtr{compilation, api().compilation_free};
-}
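The factory methods above are meant to be used in sequence: Create loads the library and binds symbols, then a model is built, compiled, and wrapped in an execution. A minimal sketch of that flow, not part of the patch, with operand and finish steps elided where indicated:

// Sketch only: model population (operands, operations, I/O identification,
// api().model_finish and api().compilation_finish) is elided in comments.
litert::Expected<void> CompileAndRun() {
  auto api = litert::mediatek::NeuronAdapterApi::Create(
      /*shared_library_dir=*/std::nullopt);
  if (!api) return api.Error();
  auto model = api.Value()->CreateModel();
  if (!model) return model.Error();
  // ... add operands/operations, identify I/O, then
  // api.Value()->api().model_finish(model.Value().get()) ...
  auto compilation = api.Value()->CreateCompilation(model.Value().get());
  if (!compilation) return compilation.Error();
  // ... api.Value()->api().compilation_finish(compilation.Value().get()) ...
  auto execution = api.Value()->CreateExecution(compilation.Value().get());
  if (!execution) return execution.Error();
  // I/O buffers would be bound via api().execution_set_input_from_memory and
  // the run kicked off with api().execution_compute.
  return {};
}

-
-Expected<NeuronCompilationPtr> NeuronAdapterApi::CreateCompilation(
-    NeuronModel* model, const std::string& compile_options) const {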
-  NeuronCompilation* compilation;
-  if (auto status = api().compilation_create_with_options(
-          model, &compilation, compile_options.c_str());
-      status != NEURON_NO_ERROR) {
-    LITERT_LOG(LITERT_ERROR,
-               "NeuronCompilation_createWithOptions failed with error %d",
-               status);
-    return Error(kLiteRtStatusErrorRuntimeFailure,
-                 "Failed to create NeuronCompilation");
-  }
-  return NeuronCompilationPtr{compilation, api().compilation_free};
-}
-
-Expected<NeuronExecutionPtr> NeuronAdapterApi::CreateExecution(
-    NeuronCompilation* compilation) const {
-  NeuronExecution* execution;
-  if (api().execution_create(compilation, &execution) != NEURON_NO_ERROR) {
-    return litert::Error(kLiteRtStatusErrorRuntimeFailure,
-                         "Failed to create execution");
-  }
-  return NeuronExecutionPtr{execution, api().execution_free};
-}
-
-}  // namespace mediatek
-}  // namespace litert
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h b/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h
deleted file mode 100644
index 7d61d2c027f2d1..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h
+++ /dev/null
@@ -1,151 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_NEURON_ADAPTER_API_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_NEURON_ADAPTER_API_H_
-
-#include <memory>
-#include <optional>
-#include <string>
-
-#include "neuron/api/NeuronAdapter.h"
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h"
-
-#if LITERT_HAS_AHWB_SUPPORT
-#include <android/hardware_buffer.h>
-#else
-struct AHardwareBuffer {};
-#endif
-
-namespace litert::mediatek {
-
-using NeuronModelPtr = std::unique_ptr<NeuronModel, decltype(&NeuronModel_free)>;
-using NeuronCompilationPtr =
-    std::unique_ptr<NeuronCompilation, decltype(&NeuronCompilation_free)>;
-using NeuronExecutionPtr =
-    std::unique_ptr<NeuronExecution, decltype(&NeuronExecution_free)>;
-
-class NeuronAdapterApi {
- public:
-  using Ptr = std::unique_ptr<NeuronAdapterApi>;
-  struct Api;
-
-  NeuronAdapterApi(NeuronAdapterApi&) = delete;
-  NeuronAdapterApi(NeuronAdapterApi&&) = delete;
-  NeuronAdapterApi& operator=(const NeuronAdapterApi&) = delete;
-  NeuronAdapterApi& operator=(NeuronAdapterApi&&) = delete;
-
-  static Expected<Ptr> Create(std::optional<std::string> shared_library_dir);
-
-  const Api& api() const { return *api_; }
-
-  absl::string_view AotCompileOptions() const {
-    // Option `import_forever` has been recommended by MediaTek to reduce
-    // memory footprint when using the same I/O buffers across multiple
-    // invocations.
-    return "--apusys-config \"{ \\\"import_forever\\\": true }\"";
-  }
-
-  absl::string_view JitCompileOptions() const { return ""; }
-
-  Expected<NeuronModelPtr> CreateModel() const;
-
-  Expected<NeuronCompilationPtr> CreateCompilation(NeuronModel* model) const;
-
-  Expected<NeuronCompilationPtr> CreateCompilation(
-      NeuronModel* model, const std::string& compile_options) const;
-
-  Expected<NeuronExecutionPtr> CreateExecution(
-      NeuronCompilation* compilation) const;
-
- private:
-  NeuronAdapterApi();
-  litert::Expected<void> LoadSymbols(
-      std::optional<std::string> shared_library_dir);
-
-  // Handle to the shared library that implements the Neuron API.
-  //
-  // This keeps the shared library open until the NeuronAdapterApi object is
-  // destroyed.
-  SharedLibrary dlib_;
-  std::unique_ptr<Api> api_;
-};
-
-// This function is not declared in the provided NeuronAdapter header.
-int NeuronCompilation_createWithOptions(NeuronModel* model,
-                                        NeuronCompilation** compilation,
-                                        const char* options);
-
-// A convenience struct holding function pointers to NeuronAdapter API
-// symbols. These pointers are bound to the shared library loaded on device
-// at runtime.
-struct NeuronAdapterApi::Api {
-  decltype(&NeuronCompilation_create) compilation_create = nullptr;
-  decltype(&NeuronCompilation_createWithOptions)
-      compilation_create_with_options = nullptr;
-  decltype(&NeuronCompilation_finish) compilation_finish = nullptr;
-  decltype(&NeuronCompilation_free) compilation_free = nullptr;
-  decltype(&NeuronCompilation_getCompiledNetworkSize)
-      compilation_get_compiled_network_size = nullptr;
-  decltype(&NeuronCompilation_getInputPaddedDimensions)
-      compilation_get_input_padded_dimensions = nullptr;
-  decltype(&NeuronCompilation_getInputPaddedSize)
-      compilation_get_input_padded_size = nullptr;
-  decltype(&NeuronCompilation_getOutputPaddedDimensions)
-      compilation_get_output_padded_dimensions = nullptr;
-  decltype(&NeuronCompilation_getOutputPaddedSize)
-      compilation_get_output_padded_size = nullptr;
-  decltype(&NeuronCompilation_setOptimizationString)
-      compilation_set_optimization_string = nullptr;
-  decltype(&NeuronCompilation_setPreference) compilation_set_preference =
-      nullptr;
-  decltype(&NeuronCompilation_setPriority) compilation_set_priority = nullptr;
-  decltype(&NeuronCompilation_storeCompiledNetwork)
-      compilation_store_compiled_network = nullptr;
-  decltype(&NeuronExecution_compute) execution_compute = nullptr;
-  decltype(&NeuronExecution_create) execution_create = nullptr;
-  decltype(&NeuronExecution_free) execution_free = nullptr;
-  decltype(&NeuronExecution_setBoostHint) execution_set_boost_hint = nullptr;
-  decltype(&NeuronExecution_setInputFromMemory)
-      execution_set_input_from_memory = nullptr;
-  decltype(&NeuronExecution_setOutputFromMemory)
-      execution_set_output_from_memory = nullptr;
-  decltype(&NeuronMemory_createFromAHardwareBuffer) memory_create_from_ahwb =
-      nullptr;
-  decltype(&NeuronMemory_createFromFd) memory_create_from_fd = nullptr;
-  decltype(&NeuronMemory_free) memory_free = nullptr;
-  decltype(&NeuronModel_addOperand) model_add_operand = nullptr;
-  decltype(&NeuronModel_addOperation) model_add_operation = nullptr;
-  decltype(&NeuronModel_create) model_create = nullptr;
-  decltype(&NeuronModel_finish) model_finish = nullptr;
-  decltype(&NeuronModel_free) model_free = nullptr;
-  decltype(&NeuronModel_getExtensionOperandType)
-      model_get_extension_operand_type = nullptr;
-  decltype(&NeuronModel_getExtensionOperationType)
-      model_get_extension_operation_type = nullptr;
-  decltype(&NeuronModel_identifyInputsAndOutputs)
model_identify_inputs_and_outputs = nullptr; - decltype(&NeuronModel_restoreFromCompiledNetwork) - model_restore_from_compiled_network = nullptr; - decltype(&NeuronModel_setName) model_set_name = nullptr; - decltype(&NeuronModel_setOperandValue) model_set_operand_value = nullptr; - decltype(&NeuronModel_setOperandSymmPerChannelQuantParams) - model_set_symm_per_channel_quant_params = nullptr; - decltype(&Neuron_getVersion) get_version = nullptr; -}; - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_NEURON_ADAPTER_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/schema/BUILD deleted file mode 100644 index f17fd689d66796..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2025 MediaTek Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -flatbuffer_cc_library( - name = "neuron_litert_schema", - srcs = ["neuron_schema.fbs"], - compatible_with = get_compatible_with_portable(), -) - -cc_library( - name = "mediatek_litert_schema", - hdrs = [ - "schema_resolver.h", - ], - visibility = ["//visibility:public"], - deps = [ - "neuron_litert_schema", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings:str_format", - "@flatbuffers//:runtime_cc", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema.fbs b/tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema.fbs deleted file mode 100644 index 6d515fd7972e7c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema.fbs +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-namespace NeuronSchema;
-
-enum CompiledType : byte {
-  DLA = 0,
-  DLB,
-  AdapterCache
-}
-
-table Index {
-  value: int = -1;
-}
-
-table Identifier {
-  value: string;
-}
-
-// BufferIndicate specifies how a buffer is referenced: by index or by
-// string identifier.
-union BufferIndicate {
-  Index,
-  Identifier,
-}
-
-table Subgraph {
-  entry_point: string;  // Entry point of the subgraph
-  type: CompiledType;  // Type of the compiled subgraph
-  compiled_index: BufferIndicate;  // Index of the buffer in Graphs.data
-  weight_share_index: [BufferIndicate];  // Indices of buffers in Graphs.data; if empty, no weight sharing.
-}
-
-table Graphs {
-  version: short;  // Version of the graph schema
-  subgraphs: [Subgraph];  // List of subgraphs
-  data: [Buffer];
-  external: ExternalBuffer;
-}
-
-table Buffer {
-  identifier: string;
-  data: [byte];  // Binary data
-}
-
-// List of external buffers that are not stored in this schema
-table ExternalBuffer {
-  identifiers: [string];
-}
-
-root_type Graphs;
\ No newline at end of file
diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h b/tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h
deleted file mode 100644
index 9fd871da7e0f0e..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h
+++ /dev/null
@@ -1,184 +0,0 @@
-// Copyright (c) 2025 MediaTek Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_SCHEMA_SCHEMA_RESOLVER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_SCHEMA_SCHEMA_RESOLVER_H_
-
-#include <cassert>
-#include <cstddef>
-#include <cstdint>
-#include <optional>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include "absl/strings/str_format.h"
-#include "flatbuffers/buffer.h"  // from @flatbuffers
-#include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
-#include "flatbuffers/verifier.h"  // from @flatbuffers
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema_generated.h"
-
-namespace neuron {
-
-inline bool IsNeuronSchema(const uint8_t* buffer, size_t size) {
-  if (buffer == nullptr) {
-    return false;
-  }
-  flatbuffers::Verifier verifier(buffer, size);
-  return NeuronSchema::VerifyGraphsBuffer(verifier);
-}
-
-class CompiledGraph {
- public:
-  CompiledGraph(const NeuronSchema::Graphs& g, const NeuronSchema::Subgraph& s)
-      : graph_(g), subgraph_(s) {}
-
-  litert::Expected<std::pair<const int8_t*, size_t>> GetCompiledNetwork() {
-    // Neuron Adapter doesn't support DLB for now.
-    assert(GetCompiledType() != NeuronSchema::CompiledType_DLB);
-    // TODO: Support the external buffer.
-    assert(subgraph_.compiled_index_type() ==
-           NeuronSchema::BufferIndicate_Index);
-    auto index = subgraph_.compiled_index_as_Index();
-    return GetBuffer(index->value());
-  }
-
-  NeuronSchema::CompiledType GetCompiledType() { return subgraph_.type(); }
-
-  litert::Expected<std::pair<const int8_t*, size_t>> GetBuffer(int32_t i) {
-    auto array_size = graph_.data()->size();
-    if (i >= array_size) {
-      return litert::Error(
-          kLiteRtStatusErrorIndexOOB,
-          absl::StrFormat("Buffer array index %d is OOB, the array size : %d",
-                          i, array_size));
-    }
-    auto buffer = graph_.data()->Get(i);
-    return std::pair<const int8_t*, size_t>(buffer->data()->data(),
-                                            buffer->data()->size());
-  }
-
- private:
-  const NeuronSchema::Graphs& graph_;
-  const NeuronSchema::Subgraph& subgraph_;
-};
-
-class SchemaResolver {
- public:
-  SchemaResolver() = default;
-
-  litert::Expected<bool> Initialize(const uint8_t* buffer, size_t size) {
-    if (!IsNeuronSchema(buffer, size)) {
-      return litert::Error(kLiteRtStatusErrorInvalidFlatbuffer,
-                           "buffer is not a valid NeuronSchema");
-    }
-    graph_ = NeuronSchema::GetGraphs(buffer);
-
-    auto subgraphs = graph_->subgraphs();
-    for (const auto& subgraph : *subgraphs) {
-      auto graph_name = subgraph->entry_point()->str();
-      if (entry_points_.count(graph_name)) {
-        // Subgraphs must not share the same entry-point name.
-        return false;
-      } else {
-        LITERT_LOG(LITERT_INFO, "Found graph: %s", graph_name.c_str());
-        entry_points_[graph_name] = subgraph;
-      }
-    }
-    LITERT_LOG(LITERT_INFO, "There are %u subgraphs in the bytecode",
-               entry_points_.size());
-    return true;
-  }
-
-  std::optional<CompiledGraph> GetCompiledGraph(std::string& name) {
-    if (entry_points_.count(name) == 0) {
-      return std::nullopt;
-    }
-    return CompiledGraph(*graph_, *entry_points_[name]);
-  }
-
- private:
-  const NeuronSchema::Graphs* graph_ = nullptr;
-
-  std::unordered_map<std::string, const NeuronSchema::Subgraph*> entry_points_;
-};
-
-class BytecodeBuilder {
- public:
-  BytecodeBuilder() = default;
-
-  int32_t AddCompiledNetwork(std::string& entry_point,
-                             NeuronSchema::CompiledType type,
-                             int32_t buffer_index) {
-    auto index = NeuronSchema::CreateIndex(fb_, buffer_index);
-    auto subgraph = NeuronSchema::CreateSubgraph(
-        fb_, fb_.CreateString(entry_point), type,
-        NeuronSchema::BufferIndicate_Index, index.Union());
-
-    subgraphs_.push_back(subgraph);
-    return subgraphs_count_++;
-  }
-
-  int32_t AddBuffer(std::string& identifier, const std::vector<int8_t>& data) {
-    auto buffer =
-        NeuronSchema::CreateBufferDirect(fb_, identifier.c_str(), &data);
-    graph_data_.push_back(buffer);
-    return buffer_count_++;
-  }
-
-  int32_t AddBuffer(std::string& identifier, const int8_t* data,
-                    size_t length) {
-    auto data_offset = fb_.CreateVector(data, length);
-    auto identifier_offset = fb_.CreateString(identifier);
-    auto buffer =
-        NeuronSchema::CreateBuffer(fb_, identifier_offset, data_offset);
-    graph_data_.push_back(buffer);
-    return buffer_count_++;
-  }
-
-  bool Finish() {
-    auto graphs =
-        NeuronSchema::CreateGraphsDirect(fb_, 1, &subgraphs_, &graph_data_);
-    fb_.Finish(graphs);
-    raw_buffer_ = {fb_.GetBufferPointer(), fb_.GetSize()};
-    return true;
-  }
-
-  std::pair<const uint8_t*, size_t> GetBytecode() {
-    if (!raw_buffer_.has_value()) {
-      return {nullptr, 0};
-    }
-    return raw_buffer_.value();
-  }
-
- private:
-  ::flatbuffers::FlatBufferBuilder fb_;
-
-  std::optional<std::pair<const uint8_t*, size_t>> raw_buffer_;
-
-  std::vector<::flatbuffers::Offset<NeuronSchema::Subgraph>> subgraphs_;
-
-  std::vector<::flatbuffers::Offset<NeuronSchema::Buffer>> graph_data_;
-
-  int32_t subgraphs_count_ = 0;
-  int32_t buffer_count_ = 0;
-};
-
-}  // namespace neuron
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_SCHEMA_SCHEMA_RESOLVER_H_
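Taken together, BytecodeBuilder and SchemaResolver define the round trip for MediaTek bytecode. The sketch below is not part of the patch: it assumes the generated neuron_schema_generated.h and the LiteRT targets are available to build against, and it serializes one DLA blob and resolves it back by entry point.

// Minimal round-trip check for the builder/resolver pair above.
#include <cassert>
#include <string>
#include <vector>

#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h"

int main() {
  neuron::BytecodeBuilder builder;
  std::string identifier = "blob0";
  std::string entry_point = "main";
  std::vector<int8_t> blob = {0x01, 0x02, 0x03};
  // Register the compiled blob first, then a subgraph that refers to it by
  // its index in Graphs.data.
  const int32_t buffer_index = builder.AddBuffer(identifier, blob);
  builder.AddCompiledNetwork(entry_point, NeuronSchema::CompiledType_DLA,
                             buffer_index);
  builder.Finish();
  auto [bytecode, size] = builder.GetBytecode();

  neuron::SchemaResolver resolver;
  auto initialized = resolver.Initialize(bytecode, size);
  assert(initialized.HasValue() && initialized.Value());
  auto graph = resolver.GetCompiledGraph(entry_point);
  assert(graph.has_value());
  // GetCompiledNetwork() yields a {data, size} view into the flatbuffer.
  auto network = graph->GetCompiledNetwork();
  assert(network.HasValue() && network.Value().second == blob.size());
  return 0;
}

diff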
--git a/tensorflow/lite/experimental/litert/vendors/mediatek/supported_soc.csv b/tensorflow/lite/experimental/litert/vendors/mediatek/supported_soc.csv deleted file mode 100644 index 0d792650611649..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/supported_soc.csv +++ /dev/null @@ -1,17 +0,0 @@ -# manufacturer,model,android_api_level -Mediatek,MT6897,UNKNOWN -Mediatek,MT6895Z_A/TCZA,UNKNOWN -Mediatek,MT6985,UNKNOWN -Mediatek,MT6989,UNKNOWN -Mediatek,MT6983,UNKNOWN -Mediatek,MT6895Z/TCZA,UNKNOWN -Mediatek,MT6895Z_B/TCZA,UNKNOWN -Mediatek,MT6991,UNKNOWN -Mediatek,MT6983Z/CZA,UNKNOWN -Mediatek,MT6983W/CZA,UNKNOWN -Mediatek,MT6895,UNKNOWN -Mediatek,MT6983Z/TCZA,UNKNOWN -Mediatek,MT6991(ENG),UNKNOWN -Mediatek,MT6895Z/CZA,UNKNOWN -Mediatek,MT6989(ENG),UNKNOWN -Mediatek,MT6985(ENG),UNKNOWN diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD deleted file mode 100644 index 6a76dc9594a1c3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_lib", "litert_test") -load("//tensorflow/lite/experimental/litert/vendors/qualcomm:qualcomm_build_defs.bzl", "litert_cc_lib_with_qnn") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "common", - hdrs = ["common.h"], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model", - ], -) - -litert_lib( - name = "qnn_log", - srcs = ["qnn_log.cc"], - hdrs = ["qnn_log.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - ], -) - -cc_library( - name = "qnn_manager_hdr", - hdrs = ["qnn_manager.h"], - deps = [ - ":common", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - ], -) - -litert_cc_lib_with_qnn( - name = "qnn_manager", - srcs = [ - "qnn_manager.cc", - ], - hdrs = ["qnn_manager.h"], - include_system = True, - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - ungrte = True, - deps = [ - ":common", - ":qnn_log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - "//tensorflow/lite/experimental/litert/core:dynamic_loading", - ], -) - -litert_test( - name = "qnn_manager_test", - srcs = ["qnn_manager_test.cc"], - linkstatic = True, - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. - "nosan", - ], - # This test can be run only on Android and Linux. - target_compatible_with = select({ - "@platforms//os:android": [], - "@platforms//os:linux": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - deps = [ - ":qnn_manager", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers_oss", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/tools:dump", - ], -) - -cc_library( - name = "context_binary_info", - srcs = ["context_binary_info.cc"], - hdrs = ["context_binary_info.h"], - deps = [ - ":qnn_manager", - ":qnn_tensor", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) - -cc_library( - name = "qnn_tensor", - srcs = ["qnn_tensor.cc"], - hdrs = ["qnn_tensor.h"], - deps = [ - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/common.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/common.h deleted file mode 100644 index 34b8971460c466..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/common.h +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ - -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -#define LITERT_RETURN_STATUS_IF_QNN_NOT_OK(expr) \ - if (QNN_SUCCESS != (expr)) { \ - return kLiteRtStatusErrorNotFound; \ - } - -// Pointers to functions of a dynamically loaded QNN library. -typedef QNN_INTERFACE_VER_TYPE QnnApi; - -// Pointers to functions of a dynamically loaded QNN system library. -typedef QNN_SYSTEM_INTERFACE_VER_TYPE QnnSystemApi; - -// QNN backend library should be on DT_RUNPATH (-rpath). -static const char kLibQnnHtpSo[] = "libQnnHtp.so"; - -// QNN backend library should be on DT_RUNPATH (-rpath). -static const char kLibQnnSystemSo[] = "libQnnSystem.so"; - -// Map LiteRT element type to Qnn counterpart. -inline LiteRtStatus LegalizeElementType(litert::ElementType litert_type, - Qnn_DataType_t* qnn_type) { - switch (litert_type) { - case litert::ElementType::Bool: - *qnn_type = QNN_DATATYPE_BOOL_8; - break; - case litert::ElementType::Int4: - *qnn_type = QNN_DATATYPE_SFIXED_POINT_4; - break; - case litert::ElementType::Int8: - *qnn_type = QNN_DATATYPE_INT_8; - break; - case litert::ElementType::Int16: - *qnn_type = QNN_DATATYPE_INT_16; - break; - case litert::ElementType::Int32: - *qnn_type = QNN_DATATYPE_INT_32; - break; - case litert::ElementType::Int64: - *qnn_type = QNN_DATATYPE_INT_64; - break; - case litert::ElementType::UInt8: - *qnn_type = QNN_DATATYPE_UINT_8; - break; - case litert::ElementType::UInt16: - *qnn_type = QNN_DATATYPE_UINT_16; - break; - case litert::ElementType::UInt32: - *qnn_type = QNN_DATATYPE_UINT_32; - break; - case litert::ElementType::UInt64: - *qnn_type = QNN_DATATYPE_UINT_64; - break; - case litert::ElementType::Float16: - *qnn_type = QNN_DATATYPE_FLOAT_16; - break; - case litert::ElementType::Float32: - *qnn_type = QNN_DATATYPE_FLOAT_32; - break; - case litert::ElementType::Float64: - *qnn_type = QNN_DATATYPE_FLOAT_64; - break; - default: - return kLiteRtStatusErrorUnsupported; - } - return kLiteRtStatusOk; -} - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD deleted file mode 100644 index 6c54037f0803a1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib", "litert_lib", "litert_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:private"], -) - -litert_dynamic_lib( - name = "qnn_compiler_plugin", - srcs = ["qnn_compiler_plugin.cc"], - hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], - export_litert_only = True, - shared_lib_name = "qnn_compiler_plugin_so", - so_name = "libLiteRtCompilerPlugin_Qualcomm.so", - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - ungrte = True, - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":qnn_compose_graph", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -litert_test( - name = "qnn_compiler_plugin_test", - srcs = [ - "qnn_compiler_plugin_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - linkstatic = True, - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. - "nosan", - ], - # This test can be run only on Android and Linux. 
- target_compatible_with = select({ - "@platforms//os:android": [], - "@platforms//os:linux": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - use_sys_malloc = True, - deps = [ - ":qnn_compiler_plugin", # buildcleaner: keep - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers_oss", - "//tensorflow/lite/experimental/litert/test:test_models", - "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:quantize_op_legalization", - ], -) - -litert_lib( - name = "qnn_compose_graph", - srcs = ["qnn_compose_graph.cc"], - hdrs = ["qnn_compose_graph.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":graph_mapper", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/tools:dump", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:cast_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:concatenation_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:conv2d_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:depthwise_conv2d_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:dynamic_update_slice_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:elementwise_op_builder", - 
"//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:embedding_lookup_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:fully_connected_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:fully_connected_op_builder_htp", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:gather_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:gelu_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:hard_swish_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:leaky_relu_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:matmul_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:mean_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:pack_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:pool2d_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:quantize_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:reduce_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:relu6_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:relu_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:reshape_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:resize_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:rms_norm_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:select_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:slice_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:softmax_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:spatial_transform_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:split_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:tanh_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:transpose_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -litert_lib( - name = "graph_mapper", - srcs = [ - "graph_mapper.cc", - ], - hdrs = ["graph_mapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD deleted file mode 100644 index fa0e3f55e19b93..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:__subpackages__"], -) - -cc_library( - name = "qnn_tensor", - srcs = ["qnn_tensor.cc"], - hdrs = ["qnn_tensor.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - ], -) - -cc_test( - name = "qnn_tensor_test", - srcs = ["qnn_tensor_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - "no_oss", - ], - deps = [ - ":qnn_tensor", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:test_models", - ], -) - -cc_library( - name = "qnn_op", - srcs = ["qnn_op.cc"], - hdrs = ["qnn_op.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - ], -) - -cc_test( - name = "qnn_op_test", - srcs = ["qnn_op_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - ], - deps = [ - ":qnn_op", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - ], -) - -cc_test( - name = "op_compatibility_test", - srcs = ["op_compatibility_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - ], - deps = [ - ":qnn_op", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc deleted file mode 100644 index 477711417441f9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#include <string>
-
-#include <gtest/gtest.h>
-#include "absl/strings/match.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/test/common.h"
-#include "tensorflow/lite/experimental/litert/test/matchers.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h"
-
-namespace {
-
-static constexpr absl::string_view kOpTpl = "simple_%s_op.tflite";
-struct OpInfo {
-  std::string op_name;
-  std::string expected_type_name;
-};
-
-// TODO: b/365299994 - Add "stablehlo_scatter" once multiple subgraphs are
-// supported.
-// clang-format off
-const auto kSupportedOps = testing::Values(
-    OpInfo{"add", "ElementWiseAdd"},
-    OpInfo{"mul", "ElementWiseMultiply"},
-    OpInfo{"batch_matmul", "MatMul"},
-    OpInfo{"concatenation", "Concat"},
-    OpInfo{"div", "ElementWiseDivide"},
-    OpInfo{"fully_connected", "FullyConnected"},
-    OpInfo{"reshape", "Reshape"},
-    OpInfo{"rsqrt", "ElementWiseRsqrt"},
-    OpInfo{"select_v2", "ElementWiseSelect"},
-    OpInfo{"select", "ElementWiseSelect"},
-    OpInfo{"strided_slice", "StridedSlice"},
-    OpInfo{"slice", "StridedSlice"},
-    OpInfo{"softmax", "Softmax"},
-    OpInfo{"sub", "ElementWiseSubtract"},
-    OpInfo{"tanh", "Tanh"},
-    OpInfo{"transpose", "Transpose"});
-// clang-format on
-
-class OpCompatibilityTest : public ::testing::TestWithParam<OpInfo> {};
-
-TEST_P(OpCompatibilityTest, SupportedOpsTest) {
-  auto test_params = GetParam();
-  std::string model_path = absl::StrFormat(kOpTpl, test_params.op_name);
-  auto model = litert::testing::LoadTestFileModel(model_path);
-  auto subgraph = model.MainSubgraph();
-  EXPECT_TRUE(subgraph);
-  auto ops = subgraph->Ops();
-
-  Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp();
-  LITERT_ASSERT_OK(litert::qnn::LegalizeOp(ops.front().Get(), qnn_op));
-
-  EXPECT_TRUE(absl::StrContains(qnn_op.v1.name, test_params.op_name));
-  EXPECT_STREQ(qnn_op.v1.packageName, "qti.aisw");
-  EXPECT_STREQ(qnn_op.v1.typeName, test_params.expected_type_name.c_str());
-
-  EXPECT_EQ(qnn_op.v1.numOfInputs, 0);
-  EXPECT_EQ(qnn_op.v1.numOfOutputs, 0);
-  EXPECT_EQ(qnn_op.v1.numOfParams, 0);
-
-  litert::qnn::ResetOp(qnn_op);
-}
-
-INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, OpCompatibilityTest, kSupportedOps);
-
-}  // namespace
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc
deleted file mode 100644
index 0a6949afaf7807..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc
+++ /dev/null
@@ -1,147 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -// A macro dance to create a unique literal string given a prefix. -#define STRINGIFY(x) #x -#define QNN_OP_NAME(prefix) STRINGIFY(prefix##__COUNTER) - -namespace litert::qnn { - -namespace { - -// Maps "op-code" related information (name, packageName, typeName) from src -// to dest. -LiteRtStatus LegalizeOpType(const Op& src, Qnn_OpConfig_t& dest) { - switch (src.Code()) { - case kLiteRtOpCodeTflMul: - dest.v1.name = QNN_OP_NAME(mul_); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseMultiply"; - break; - case kLiteRtOpCodeTflAdd: - dest.v1.name = QNN_OP_NAME("add"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseAdd"; - break; - case kLiteRtOpCodeTflBatchMatmul: - dest.v1.name = QNN_OP_NAME("batch_matmul"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "MatMul"; - break; - case kLiteRtOpCodeTflConcatenation: - dest.v1.name = QNN_OP_NAME("concatenation"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Concat"; - break; - case kLiteRtOpCodeTflDiv: - dest.v1.name = QNN_OP_NAME("div"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseDivide"; - break; - case kLiteRtOpCodeTflFullyConnected: - dest.v1.name = QNN_OP_NAME("fully_connected"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "FullyConnected"; - break; - case kLiteRtOpCodeTflReshape: - dest.v1.name = QNN_OP_NAME("reshape"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Reshape"; - break; - case kLiteRtOpCodeTflRsqrt: - dest.v1.name = QNN_OP_NAME("rsqrt"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseRsqrt"; - break; - case kLiteRtOpCodeTflSelectV2: - dest.v1.name = QNN_OP_NAME("select_v2"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseSelect"; - break; - case kLiteRtOpCodeTflSelect: - dest.v1.name = QNN_OP_NAME("select"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseSelect"; - break; - case kLiteRtOpCodeTflStridedSlice: - dest.v1.name = QNN_OP_NAME("strided_slice"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "StridedSlice"; - break; - case kLiteRtOpCodeTflSlice: - dest.v1.name = QNN_OP_NAME("slice"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "StridedSlice"; - break; - case kLiteRtOpCodeTflSoftmax: - dest.v1.name = QNN_OP_NAME("softmax"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Softmax"; - break; - case kLiteRtOpCodeTflSub: - dest.v1.name = QNN_OP_NAME("sub"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseSubtract"; - break; - case kLiteRtOpCodeTflTanh: - dest.v1.name = QNN_OP_NAME("tanh"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Tanh"; - break; - case kLiteRtOpCodeTflTranspose: - dest.v1.name = QNN_OP_NAME("transpose"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Transpose"; - break; - default: - return kLiteRtStatusErrorUnsupported; - } - return kLiteRtStatusOk; -} - -} // namespace - -Qnn_OpConfig_t BuildDefaultOp() { - Qnn_OpConfig_t op = QNN_OPCONFIG_INIT; - ResetOp(op); - return op; -} -Qnn_Param_t BuildDefaultParam() { - Qnn_Param_t param = 
QNN_PARAM_INIT;
-  ResetParam(param);
-  return param;
-}
-
-void ResetOp(Qnn_OpConfig_t& op) {
-  op = QNN_OPCONFIG_INIT;
-  op.version = QNN_OPCONFIG_VERSION_1;
-  op.v1 = QNN_OPCONFIG_V1_INIT;
-}
-
-void ResetParam(Qnn_Param_t& param) { param = QNN_PARAM_INIT; }
-
-LiteRtStatus LegalizeOp(LiteRtOp src, Qnn_OpConfig_t& dest) {
-  ResetOp(dest);
-  Op op(src);
-  return LegalizeOpType(op, dest);
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h
deleted file mode 100644
index 20e0f27f798b98..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-
-namespace litert::qnn {
-
-//
-// Initialize QNN Op.
-//
-
-// NOTE: Any referential data within a QNN Op
-// is allocated with "new" and must be explicitly cleaned up with ResetOp.
-
-// Construct a "blank" QNN Op.
-Qnn_OpConfig_t BuildDefaultOp();
-
-// Construct a "blank" QNN Param.
-Qnn_Param_t BuildDefaultParam();
-
-// Reset the given op, deallocating anything on the heap that it points to.
-void ResetOp(Qnn_OpConfig_t& op);
-
-// Reset the given param, deallocating anything on the heap that it points to.
-void ResetParam(Qnn_Param_t& param);
-
-//
-// Legalize LiteRt Op to Analogous QNN Construct.
-//
-
-// Map src op onto dest. Resets dest before doing anything. This only handles
-// attribute-like info. It does not set edges (in/out tensors).
-LiteRtStatus LegalizeOp(LiteRtOp src, Qnn_OpConfig_t& dest);
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc
deleted file mode 100644
index dd78cfca40b88c..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "absl/strings/match.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using testing::litert::IsError; - -TEST(TestInitQnnOp, BuildDefaultOp) { - Qnn_OpConfig_t op = litert::qnn::BuildDefaultOp(); - ASSERT_EQ(op.version, QNN_OPCONFIG_VERSION_1); -} - -TEST(TestLegalizeOp, SimpleSupportedOp) { - auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - - Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp(); - LITERT_ASSERT_OK(litert::qnn::LegalizeOp(ops.front().Get(), qnn_op)); - - EXPECT_TRUE(absl::StrContains(qnn_op.v1.name, "mul")); - EXPECT_STREQ(qnn_op.v1.packageName, "qti.aisw"); - EXPECT_STREQ(qnn_op.v1.typeName, "ElementWiseMultiply"); - - EXPECT_EQ(qnn_op.v1.numOfInputs, 0); - EXPECT_EQ(qnn_op.v1.numOfOutputs, 0); - EXPECT_EQ(qnn_op.v1.numOfParams, 0); - - litert::qnn::ResetOp(qnn_op); -} - -TEST(TestLegalizeOp, UnsupportedOp) { - auto model = litert::testing::LoadTestFileModel("simple_floor_mod_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - - Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp(); - EXPECT_THAT(litert::qnn::LegalizeOp(ops.front().Get(), qnn_op), - IsError(kLiteRtStatusErrorUnsupported)); - - litert::qnn::ResetOp(qnn_op); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc deleted file mode 100644 index 4a308f6da78012..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
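Before the qnn_tensor.cc hunk continues, a minimal sketch of how the op-legalization API deleted above was meant to be driven (a hypothetical caller mirroring the pattern in qnn_op_test.cc; `src_op` is an assumed name for any LiteRtOp pulled from a subgraph, not code from this patch):

  // Legalize one LiteRt op into a QNN op config, then release it.
  Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp();
  if (litert::qnn::LegalizeOp(src_op, qnn_op) != kLiteRtStatusOk) {
    // Op code missing from the LegalizeOpType switch, e.g. floor_mod.
    return kLiteRtStatusErrorUnsupported;
  }
  // Only attribute-like fields (name/packageName/typeName) are set here;
  // in/out tensor edges are wired separately by the caller.
  litert::qnn::ResetOp(qnn_op);  // frees any heap data the config owns
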
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" - -#include - -#include "absl/log/absl_check.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" - -namespace litert::qnn { - -namespace { - -LiteRtStatus LegalizeShapeInfo(const litert::Layout& src, Qnn_Tensor_t& dest) { - LITERT_ENSURE_SUPPORTED(!src.HasStrides(), "Strides not yet supported"); - - dest.v2.rank = src.Rank(); - // Ad-hoc fix: rank 0 tensor needs to be single element 1D tensor in QNN. - if (dest.v2.rank == 0) { - LITERT_LOG(LITERT_INFO, "Setting rank 0 tensor to single element tensor"); - dest.v2.rank = 1; - dest.v2.dimensions = new uint32_t[1]; - dest.v2.dimensions[0] = 1; - return kLiteRtStatusOk; - } - - dest.v2.dimensions = new uint32_t[dest.v2.rank]; - for (int i = 0; i < dest.v2.rank; ++i) { - const auto src_dim = src.Dimensions()[i]; - LITERT_ENSURE(src_dim >= 1, kLiteRtStatusErrorInvalidArgument, - "Cannot pass dim < 1 to QNN Tensor."); - - dest.v2.dimensions[i] = src.Dimensions()[i]; - } - return kLiteRtStatusOk; -} - -void FreeTensorDims(Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_2 && - tensor.v2.dimensions != nullptr) { - delete[] tensor.v2.dimensions; - tensor.v2.dimensions = nullptr; - tensor.v2.rank = 0; - } -} - -void FreePerChannelQuantization(Qnn_Tensor_t& tensor) { - if (tensor.v2.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { - delete[] tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset = nullptr; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = 0; - } -} - -} // namespace - -void SetInputTensorAttrs(Qnn_Tensor_t& tensor) { - ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); - tensor.v2.type = QNN_TENSOR_TYPE_APP_WRITE; - tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - tensor.v2.clientBuf = QNN_CLIENT_BUFFER_INIT; -} - -void SetOutputTensorAttrs(Qnn_Tensor_t& tensor) { - ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); - tensor.v2.type = QNN_TENSOR_TYPE_APP_READ; -} - -void SetResultTensorAttrs(Qnn_Tensor_t& tensor) { - ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); - tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - tensor.v2.type = QNN_TENSOR_TYPE_NATIVE; -} - -void ResetTensor(Qnn_Tensor_t& tensor) { - FreeTensorDims(tensor); - FreePerChannelQuantization(tensor); - tensor = QNN_TENSOR_INIT; - tensor.version = QNN_TENSOR_VERSION_2; - tensor.v2 = QNN_TENSOR_V2_INIT; - tensor.v2.dataFormat = QNN_TENSOR_DATA_FORMAT_DENSE; - tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; -} - -Qnn_Tensor_t BuildDefaultTensor(uint32_t id) { - Qnn_Tensor_t tensor = QNN_TENSOR_INIT; - ResetTensor(tensor); - tensor.v2.id = id; - return tensor; -} - -Qnn_Tensor_t BuildDefaultTensor() { return BuildDefaultTensor(0); } - -Qnn_Tensor_t BuildInputTensor() { - auto tensor = BuildDefaultTensor(); - SetInputTensorAttrs(tensor); - return tensor; -} - -Qnn_ClientBuffer_t BuildDefaultClientBuffer() { - Qnn_ClientBuffer_t client_buf = QNN_CLIENT_BUFFER_INIT; - client_buf.data = nullptr; - 
client_buf.dataSize = 0; - return client_buf; -} - -Qnn_Tensor_t BuildOutputTensor() { - Qnn_Tensor_t tensor = BuildDefaultTensor(); - SetOutputTensorAttrs(tensor); - return tensor; -} - -uint32_t MoveToId(Qnn_Tensor_t& tensor) { - const auto id = tensor.v2.id; - ResetTensor(tensor); - tensor.v2.id = id; - return id; -} - -void SetPerChannelQuantization( - Qnn_Tensor_t& tensor, - const LiteRtQuantizationPerChannel& lite_rt_quantization_per_channel) { - tensor.v2.quantizeParams.quantizationEncoding = - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; - - tensor.v2.quantizeParams.axisScaleOffsetEncoding = QNN_AXIS_SCALE_OFFSET_INIT; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.axis = - lite_rt_quantization_per_channel.quantized_dimension; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = - lite_rt_quantization_per_channel.num_channels; - - // Allocates memory for the scaleOffset array. - tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset = - new Qnn_ScaleOffset_t[lite_rt_quantization_per_channel.num_channels]; - - for (int i = 0; i < lite_rt_quantization_per_channel.num_channels; ++i) { - tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i].scale = - lite_rt_quantization_per_channel.scales[i]; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i].offset = - lite_rt_quantization_per_channel.zero_points[i]; - } -} - -void SetPerTensorQuantization( - Qnn_Tensor_t& tensor, - const LiteRtQuantizationPerTensor& lite_rt_quantization_per_tensor) { - tensor.v2.quantizeParams.quantizationEncoding = - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; - tensor.v2.quantizeParams.scaleOffsetEncoding.scale = - lite_rt_quantization_per_tensor.scale; - tensor.v2.quantizeParams.scaleOffsetEncoding.offset = - lite_rt_quantization_per_tensor.zero_point; -} - -LiteRtStatus LegalizeQuantizationParameter(const litert::Tensor& src, - Qnn_Tensor_t& dest) { - LiteRtQuantizationTypeId lite_rt_quantization_type_id = src.QTypeId(); - switch (lite_rt_quantization_type_id) { - case kLiteRtQuantizationPerTensor: - SetPerTensorQuantization(dest, src.PerTensorQuantization()); - return kLiteRtStatusOk; - case kLiteRtQuantizationPerChannel: - SetPerChannelQuantization(dest, src.PerChannelQuantization()); - return kLiteRtStatusOk; - default: - LITERT_LOG(LITERT_ERROR, "Unsupported quantization type."); - return kLiteRtStatusErrorInvalidArgument; - } -} - -LiteRtStatus LegalizeTensor(const litert::Tensor& src, Qnn_Tensor_t& dest) { - if (src.TypeId() != kLiteRtRankedTensorType) { - return kLiteRtStatusErrorInvalidArgument; - } - - ResetTensor(dest); - - if (src.HasQuantization()) { - LITERT_RETURN_IF_ERROR(LegalizeQuantizationParameter(src, dest)); - } - - auto src_ranked_tensor_type = src.RankedTensorType(); - if (!src_ranked_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", - src_ranked_tensor_type.Error().Message().c_str()); - return src_ranked_tensor_type.Error().Status(); - } - - Qnn_DataType_t* qnn_data_type = &dest.v2.dataType; - LITERT_RETURN_IF_ERROR(LegalizeElementType( - src_ranked_tensor_type->ElementType(), qnn_data_type)); - - LITERT_RETURN_IF_ERROR( - LegalizeShapeInfo(src_ranked_tensor_type->Layout(), dest)); - - const bool is_subgraph_in = src.IsSubgraphInput(); - const bool is_subgraph_out = src.IsSubgraphOutput(); - const bool is_constant = src.IsConstant(); - - LITERT_ENSURE(!(is_subgraph_in && is_subgraph_out), - kLiteRtStatusErrorInvalidArgument, - "Malformed tensor, cannot be both subgraph in and out."); - if (is_constant) { - LITERT_LOG(LITERT_INFO, "Adding constant 
tensor %s to qnn graph", - dest.v2.name); - LITERT_ENSURE(src.HasWeights(), kLiteRtStatusErrorInvalidLegalization, - "Empty weights for constant tensor."); - Qnn_ClientBuffer_t client_buf = BuildDefaultClientBuffer(); - client_buf.data = (void*)src.Weights().Bytes().data(); - client_buf.dataSize = src.Weights().Bytes().size(); - dest.v2.clientBuf = client_buf; - dest.v2.memType = QNN_TENSORMEMTYPE_RAW; - dest.v2.type = QNN_TENSOR_TYPE_STATIC; - dest.v2.isDynamicDimensions = nullptr; - } - - if (is_subgraph_in) { - LITERT_LOG(LITERT_INFO, "Adding subgraph input tensor to qnn graph"); - SetInputTensorAttrs(dest); - } - if (is_subgraph_out) { - LITERT_LOG(LITERT_INFO, "Adding subgraph output tensor to qnn graph"); - SetOutputTensorAttrs(dest); - } - if (!is_constant && !is_subgraph_in && !is_subgraph_out) { - LITERT_LOG(LITERT_INFO, "Adding result tensor to qnn graph"); - SetResultTensorAttrs(dest); - } - - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h deleted file mode 100644 index 607cc4c3decba9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_ - -#include <cstdint> - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -namespace litert::qnn { - -// -// Initialize QNN Tensors. -// - -// NOTE: Within LiteRt land, all Qnn Tensors are treated as "v2". Any -// referential data (like dimensions : uint32_t*) within a QNN Tensor -// is allocated with "new" and must be explicitly cleaned up with ResetTensor. - -// Construct a "blank" QNN Tensor. -Qnn_Tensor_t BuildDefaultTensor(); - -// Construct a "blank" QNN Tensor with given id. -Qnn_Tensor_t BuildDefaultTensor(uint32_t id); - -// Construct a "blank" QNN Tensor meant to be used as a graph input. -Qnn_Tensor_t BuildInputTensor(); - -// Construct a "blank" QNN Tensor meant to be used as a graph output. -Qnn_Tensor_t BuildOutputTensor(); - -Qnn_ClientBuffer_t BuildDefaultClientBuffer(); - -// Adds attributes to given tensor making it amenable for use as graph input. -void SetInputTensorAttrs(Qnn_Tensor_t& tensor); - -// Adds attributes to given tensor making it amenable for use as graph output. -void SetOutputTensorAttrs(Qnn_Tensor_t& tensor); - -// Adds attributes to given tensor making it amenable for use as an -// intermediate output. -void SetResultTensorAttrs(Qnn_Tensor_t& tensor); - -// Reset the given tensor, deallocating anything on the heap that it points to. 
-void ResetTensor(Qnn_Tensor_t& tensor); - -// Resets all fields other than id in the given tensor and returns the id for -// convenience. Only the id is needed to traffic QNN Tensors after they have -// been registered with the context. -uint32_t MoveToId(Qnn_Tensor_t& tensor); - -// -// Legalize LiteRt Tensors to Analogous QNN Construct. -// - -// Map src tensor onto dest. Resets dest before doing anything. -LiteRtStatus LegalizeTensor(const litert::Tensor& src, Qnn_Tensor_t& dest); - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc deleted file mode 100644 index ba38fd211457a8..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/test_models.h" - -namespace { - -constexpr float kSimpleMulQuantModelOutputScale = 0.00028621565f; -constexpr float kSimpleMulQuantModelOutputOffset = 0; - -TEST(TestInitQnnTensor, BuildDefaultTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_DENSE); - EXPECT_EQ(tensor.v2.rank, 0); - EXPECT_EQ(tensor.v2.dimensions, nullptr); - EXPECT_EQ(tensor.v2.id, 0); -} - -TEST(TestInitQnnTensor, BuildDefaultTensorWithId) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(2); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_DENSE); - EXPECT_EQ(tensor.v2.rank, 0); - EXPECT_EQ(tensor.v2.dimensions, nullptr); - EXPECT_EQ(tensor.v2.id, 2); -} - -TEST(TestInitQnnTensor, BuildDefaultInputTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildInputTensor(); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_WRITE); - EXPECT_EQ(tensor.v2.memType, QNN_TENSORMEMTYPE_RAW); - EXPECT_EQ(tensor.v2.clientBuf.dataSize, 0); -} - -TEST(TestInitQnnTensor, SetInputTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); - litert::qnn::SetInputTensorAttrs(tensor); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_WRITE); - EXPECT_EQ(tensor.v2.memType, QNN_TENSORMEMTYPE_RAW); - 
EXPECT_EQ(tensor.v2.clientBuf.dataSize, 0); -} - -TEST(TestInitQnnTensor, BuildDefaultOutputTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildOutputTensor(); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); -} - -TEST(TestInitQnnTensor, SetOutputTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); - litert::qnn::SetOutputTensorAttrs(tensor); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); -} - -TEST(TestInitQnnTensor, MoveToId) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(2); - - litert::qnn::SetOutputTensorAttrs(tensor); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); - - EXPECT_EQ(litert::qnn::MoveToId(tensor), 2); - EXPECT_EQ(tensor.v2.id, 2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_UNDEFINED); -} - -TEST(TestLegalizeTensor, SimpleSupportedTensorSubgraphOutput) { - auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto outputs = subgraph->Outputs(); - - auto qnn_tensor = litert::qnn::BuildDefaultTensor(); - const auto& output_tensor = outputs.front(); - LITERT_ASSERT_OK(litert::qnn::LegalizeTensor(output_tensor, qnn_tensor)); - - ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_FLOAT_32); - EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); - - ASSERT_EQ(qnn_tensor.v2.rank, 2); - ASSERT_NE(qnn_tensor.v2.dimensions, nullptr); - EXPECT_THAT(absl::MakeConstSpan(qnn_tensor.v2.dimensions, 2), - ::testing::ElementsAreArray({2, 2})); - - litert::qnn::ResetTensor(qnn_tensor); -} - -TEST(TestLegalizeTensor, SimpleSupportedTensor) { - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - auto op_outs = ops.at(1).Outputs(); - - auto qnn_tensor = litert::qnn::BuildDefaultTensor(); - const auto& op_out = op_outs.front(); - LITERT_ASSERT_OK(litert::qnn::LegalizeTensor(op_out, qnn_tensor)); - - ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_FLOAT_32); - EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_NATIVE); - - ASSERT_EQ(qnn_tensor.v2.rank, 2); - ASSERT_NE(qnn_tensor.v2.dimensions, nullptr); - EXPECT_THAT(absl::MakeConstSpan(qnn_tensor.v2.dimensions, 2), - ::testing::ElementsAreArray({2, 2})); - - litert::qnn::ResetTensor(qnn_tensor); -} - -TEST(TestLegalizeTensor, SimpleQuantizedTensor) { - auto model = litert::testing::LoadTestFileModel(kQSimpleMul16x16Model); - - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - auto op_outs = ops.at(0).Outputs(); - - auto qnn_tensor = litert::qnn::BuildDefaultTensor(); - const auto& op_out = op_outs.front(); - LITERT_ASSERT_OK(litert::qnn::LegalizeTensor(op_out, qnn_tensor)); - - ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_INT_16); - EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); - - ASSERT_EQ(qnn_tensor.v2.quantizeParams.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); - ASSERT_FLOAT_EQ(qnn_tensor.v2.quantizeParams.scaleOffsetEncoding.scale, - kSimpleMulQuantModelOutputScale); - - ASSERT_FLOAT_EQ(qnn_tensor.v2.quantizeParams.scaleOffsetEncoding.offset, - kSimpleMulQuantModelOutputOffset); - 
litert::qnn::ResetTensor(qnn_tensor); -} - -TEST(TestLegalizeTensor, PerChannelQuantizedTensor) { - auto model = litert::testing::LoadTestFileModel(kQKeyEinsum16x8Model); - - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - auto op_ins = ops.at(1).Inputs(); - - auto qnn_tensor = litert::qnn::BuildDefaultTensor(); - const auto& per_channel_quant_tensor = op_ins[1]; - LITERT_ASSERT_OK( - litert::qnn::LegalizeTensor(per_channel_quant_tensor, qnn_tensor)); - - EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_INT_8); - - LiteRtQuantizationPerChannel per_channel_quant_params = - per_channel_quant_tensor.PerChannelQuantization(); - - ASSERT_EQ(qnn_tensor.v2.quantizeParams.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); - EXPECT_EQ(qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.axis, - per_channel_quant_params.quantized_dimension); - EXPECT_EQ( - qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets, - per_channel_quant_params.num_channels); - for (int i = 0; i < per_channel_quant_params.num_channels; ++i) { - ASSERT_FLOAT_EQ( - qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i] - .scale, - per_channel_quant_params.scales[i]); - ASSERT_EQ( - qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i] - .offset, - per_channel_quant_params.zero_points[i]); - } - litert::qnn::ResetTensor(qnn_tensor); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc deleted file mode 100644 index e0be0c0c8650ae..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
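Before the graph_mapper.cc hunk continues, a condensed sketch of the tensor-legalization lifecycle implemented by the files deleted above (mirroring qnn_tensor_test.cc; `litert_tensor` is an assumed name for any litert::Tensor obtained from a loaded model, not code from this patch):

  // Legalize one LiteRt tensor into a v2 QNN tensor.
  Qnn_Tensor_t qnn_tensor = litert::qnn::BuildDefaultTensor();
  if (litert::qnn::LegalizeTensor(litert_tensor, qnn_tensor) == kLiteRtStatusOk) {
    // dataType, rank/dimensions, and quantizeParams are now populated. Once
    // the tensor is registered with a QNN context, only its id is needed:
    const uint32_t id = litert::qnn::MoveToId(qnn_tensor);
    (void)id;
  }
  litert::qnn::ResetTensor(qnn_tensor);  // frees heap-allocated dims/scale-offsets
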
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" - -#include -#include - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpGraph.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnGraph.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -inline absl::Span GetDefaultGraphConfigs() { - static std::array graph_custom_configs; - // QNN suggest always enable relax precision. - graph_custom_configs[0] = QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT; - graph_custom_configs[0].option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION; - graph_custom_configs[0].precision = QNN_PRECISION_FLOAT16; - // Default use O3 for now. - graph_custom_configs[1] = QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT; - graph_custom_configs[1].option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; - graph_custom_configs[1].optimizationOption.type = - QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; - // Change to 2 if you want to use O2 (default). - graph_custom_configs[1].optimizationOption.floatValue = 3; - - static std::array graph_configs; - graph_configs[0] = QNN_GRAPH_CONFIG_INIT; - graph_configs[0].option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_configs[0].customConfig = &graph_custom_configs[0]; - - graph_configs[1] = QNN_GRAPH_CONFIG_INIT; - graph_configs[1].option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_configs[1].customConfig = &graph_custom_configs[1]; - - static std::array result = { - &graph_configs[0], &graph_configs[1], nullptr}; - - return absl::MakeSpan(result.data(), result.size()); -} - -inline absl::Span GetLegacyGraphConfigs() { - static QnnHtpGraph_CustomConfig_t graph_custom_config; - // Default use O3 for now. - graph_custom_config = QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT; - graph_custom_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; - graph_custom_config.optimizationOption.type = - QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; - // Change to 2 if you want to use O2 (default). 
- graph_custom_config.optimizationOption.floatValue = 3; - - static QnnGraph_Config_t graph_config; - graph_config = QNN_GRAPH_CONFIG_INIT; - graph_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_config.customConfig = &graph_custom_config; - - static std::array<const QnnGraph_Config_t*, 2> result = {&graph_config, - nullptr}; - - return absl::MakeSpan(result.data(), result.size()); -} - -absl::Span<const QnnGraph_Config_t*> GraphMapper::PickGraphConfigHeuristic() { - if (qnn_.IsLegacySocModel()) { - return GetLegacyGraphConfigs(); - } else { - return GetDefaultGraphConfigs(); - } -} - -LiteRtStatus GraphMapper::AssignTensorName(Qnn_Tensor_t& qnn_tensor) { - char* name = nullptr; - const int written = asprintf(&name, "Tensor_%d", cur_tensor_num_++); - LITERT_ENSURE(written != -1 && name != nullptr, kLiteRtStatusErrorNotFound, - "Failed to make tensor name"); - qnn_tensor.v2.name = name; - return kLiteRtStatusOk; -} - -absl::flat_hash_map<LiteRtTensor, uint32_t>& GraphMapper::CurrentScope() { - return current_scope_; -} - -LiteRtStatus GraphMapper::LookupInScope(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor) { - // If we go in topological order, this should never happen. TODO: add - // "internal error" status code. - const auto qnn_id = CurrentScope().find(litert_tensor); - // When qnn_id is not found, the tensor is a constant tensor that has not - // yet been added to the qnn graph. - if (qnn_id == CurrentScope().end()) { - LITERT_LOG(LITERT_INFO, "Adding constant tensor %s to qnn graph", - qnn_tensor.v2.name); - LITERT_RETURN_IF_ERROR(LegalizeAndRegister(litert_tensor, qnn_tensor)); - LITERT_RETURN_IF_ERROR(PushToScope(litert_tensor, qnn_tensor)); - return kLiteRtStatusOk; - } - LITERT_LOG(LITERT_INFO, "Found tensor %d in current_scope.", qnn_id->second); - ResetTensor(qnn_tensor); - qnn_tensor.v2.id = qnn_id->second; - - return kLiteRtStatusOk; -} - -LiteRtStatus GraphMapper::PushToScope(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor) { - CurrentScope()[litert_tensor] = MoveToId(qnn_tensor); - return kLiteRtStatusOk; -} - -QnnManager& GraphMapper::Qnn() { return qnn_; } - -Qnn_GraphHandle_t& GraphMapper::QnnGraph() { return qnn_graph_; } - -LiteRtStatus GraphMapper::LegalizeAndRegister(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor) { - litert::Tensor tensor(litert_tensor); - LITERT_RETURN_IF_ERROR(LegalizeTensor(tensor, qnn_tensor)); - LITERT_RETURN_IF_ERROR(AssignTensorName(qnn_tensor)); - - // Set tensor type to graph output if it was registered as a graph output. - if (graph_outputs_.contains(litert_tensor)) { - LITERT_LOG(LITERT_INFO, "Setting tensor %d as Graph output", - qnn_tensor.v2.id); - qnn_tensor.v2.type = QNN_TENSOR_TYPE_APP_READ; - } - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - qnn_.Api()->tensorCreateGraphTensor(QnnGraph(), &qnn_tensor)); - - LITERT_LOG(LITERT_INFO, "Legalized and registered tensor %d", - qnn_tensor.v2.id); - - for (int i = 0; i < qnn_tensor.v2.rank; ++i) { - LITERT_LOG(LITERT_INFO, "qnn_tensor dim[%d] = %d", i, - qnn_tensor.v2.dimensions[i]); - } - - return kLiteRtStatusOk; -} - -LiteRtStatus GraphMapper::IsLiteRtSubgraphSupported() { - // For now, we assume all LiteRt subgraphs are supported. - // TODO: b/381133565: Implement or remove this function. 
- return kLiteRtStatusOk; -} - -LiteRtStatus GraphMapper::InitQnnGraph(absl::string_view qnn_graph_name) { - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - qnn_.Api()->graphCreate(context_handle_, qnn_graph_name.data(), - PickGraphConfigHeuristic().data(), &QnnGraph())); - return kLiteRtStatusOk; -} - -LiteRtStatus GraphMapper::Finalize() { - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - qnn_.Api()->graphFinalize(QnnGraph(), nullptr, nullptr)); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h deleted file mode 100644 index 3e70e9f222e442..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_ - -#include <cstdint> - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnGraph.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -// Algorithm class for managing "scope" when mapping litert Subgraphs -// to QNN Graphs. -class GraphMapper { - public: - GraphMapper(LiteRtSubgraph subgraph, QnnManager& qnn, - Qnn_ContextHandle_t context_handle) - : subgraph_(Subgraph(subgraph)), - qnn_(qnn), - context_handle_(context_handle) {} - - // Legalize the given LiteRtTensor's attributes into a QNN Tensor registered - // with the QNN context. The resulting QNN Tensor is empty except for the - // canonical id assigned by the QNN API. - LiteRtStatus LegalizeAndRegister(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor); - - // Find the ID associated with an evaluated litert Tensor and add it to the - // given QNN Tensor. - LiteRtStatus LookupInScope(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor); - - // Adds a new mapping to the scope. All fields other than ID in the given QNN - // Tensor are cleared and its ID is added to "current_scope". Expects the QNN - // Tensor to have already been registered with the context. - LiteRtStatus PushToScope(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor); - - // NOTE: QNN Tensors must be created with a unique name. This ensures - // uniqueness, though more meaningful names would be preferable in the - // future. 
- LiteRtStatus AssignTensorName(Qnn_Tensor_t& qnn_tensor); - - // QNN Sdk Accessors - QnnManager& Qnn(); - Qnn_GraphHandle_t& QnnGraph(); - - // CC Convenience Accessors - const Subgraph& Graph() const { return subgraph_; } - - // Accessor for current scope. - // Since each QNN Tensor needs to have a unique name globally within each QNN - // context, we maintain "Current scope", which is a map of evaluated - // LiteRtTensors to their resolved QNN Tensor ID. - absl::flat_hash_map<LiteRtTensor, uint32_t>& CurrentScope(); - - // Whether the implementation can handle the given LiteRtSubgraph topology. - LiteRtStatus IsLiteRtSubgraphSupported(); - - // Initialize QNN Graph with given name. Call this after parsing - // LiteRtSubgraph. - LiteRtStatus InitQnnGraph(absl::string_view qnn_graph_name); - - // Finalize QNN Graph. Call this after all ops have been mapped. - LiteRtStatus Finalize(); - - inline void RegisterOutput(LiteRtTensor litert_tensor) { - graph_outputs_.insert(litert_tensor); - } - - // Pick graph config based on subgraph. - absl::Span<const QnnGraph_Config_t*> PickGraphConfigHeuristic(); - - inline bool IsTensorOutput(LiteRtTensor litert_tensor) { - return graph_outputs_.contains(litert_tensor); - } - - private: - const Subgraph subgraph_; - - // Set of all outputs of the graph. - absl::flat_hash_set<LiteRtTensor> graph_outputs_; - - // Maps evaluated tensors to their resolved QNN Tensor ID. - absl::flat_hash_map<LiteRtTensor, uint32_t> current_scope_; - - // - // QNN Sdk State - // - QnnManager& qnn_; - Qnn_ContextHandle_t context_handle_; - Qnn_GraphHandle_t qnn_graph_ = nullptr; - - // - // Tensor Naming - // - - uint32_t cur_tensor_num_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD deleted file mode 100644 index 46f27e985fd262..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD +++ /dev/null @@ -1,922 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
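For orientation while reading the legalizations/BUILD removals below, the deleted GraphMapper above was driven roughly as follows (a sketch under stated assumptions, not code from this patch; `subgraph`, `qnn`, and `ctx` stand for a LiteRtSubgraph, a QnnManager, and a Qnn_ContextHandle_t respectively):

  // Map one LiteRt subgraph onto a QNN graph.
  litert::qnn::GraphMapper mapper(subgraph, qnn, ctx);
  LITERT_RETURN_IF_ERROR(mapper.IsLiteRtSubgraphSupported());
  LITERT_RETURN_IF_ERROR(mapper.InitQnnGraph("qnn_graph"));
  for (const auto& out : mapper.Graph().Outputs()) {
    mapper.RegisterOutput(out.Get());  // marks the tensor APP_READ on legalization
  }
  // Each per-op legalization below then resolves its edges via
  // LookupInScope / PushToScope and registers the op with the QNN graph.
  LITERT_RETURN_IF_ERROR(mapper.Finalize());
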
- -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:private"], -) - -litert_lib( - name = "legalization", - hdrs = ["legalization.h"], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - ], -) - -litert_lib( - name = "add_op_legalization", - srcs = ["add_op_legalization.cc"], - hdrs = ["add_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "batch_matmul_op_legalization", - srcs = ["batch_matmul_op_legalization.cc"], - hdrs = ["batch_matmul_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "cast_op_legalization", - srcs = ["cast_op_legalization.cc"], - hdrs = ["cast_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "concatenation_op_legalization", - srcs = ["concatenation_op_legalization.cc"], - hdrs = ["concatenation_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "cos_op_legalization", - srcs = ["cos_op_legalization.cc"], - hdrs = ["cos_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "div_op_legalization", - srcs = ["div_op_legalization.cc"], - hdrs = ["div_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "dynamic_update_slice_op_legalization", - srcs = ["dynamic_update_slice_op_legalization.cc"], - hdrs = ["dynamic_update_slice_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "embedding_lookup_op_legalization", - srcs = ["embedding_lookup_op_legalization.cc"], - hdrs = ["embedding_lookup_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "transpose_op_legalization", - srcs = ["transpose_op_legalization.cc"], - hdrs = ["transpose_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "fully_connected_op_legalization", - srcs = ["fully_connected_op_legalization.cc"], - hdrs = ["fully_connected_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "gelu_op_legalization", - srcs = ["gelu_op_legalization.cc"], - hdrs = ["gelu_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "greater_op_legalization", - srcs = ["greater_op_legalization.cc"], - hdrs = ["greater_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "less_op_legalization", - srcs = ["less_op_legalization.cc"], - hdrs = ["less_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "logical_and_op_legalization", - srcs = ["logical_and_op_legalization.cc"], - hdrs = ["logical_and_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "mul_op_legalization", - srcs = ["mul_op_legalization.cc"], - hdrs = ["mul_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "pack_op_legalization", - srcs = ["pack_op_legalization.cc"], - hdrs = ["pack_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "quantize_op_legalization", - srcs = ["quantize_op_legalization.cc"], - hdrs = ["quantize_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "reshape_op_legalization", - srcs = ["reshape_op_legalization.cc"], - hdrs = ["reshape_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "rsqrt_op_legalization", - srcs = ["rsqrt_op_legalization.cc"], - hdrs = ["rsqrt_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "sin_op_legalization", - srcs = ["sin_op_legalization.cc"], - hdrs = ["sin_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "select_op_legalization", - srcs = ["select_op_legalization.cc"], - hdrs = ["select_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "slice_op_legalization", - srcs = ["slice_op_legalization.cc"], - hdrs = ["slice_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "sum_op_legalization", - srcs = ["sum_op_legalization.cc"], - hdrs = ["sum_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "sub_op_legalization", - srcs = ["sub_op_legalization.cc"], - hdrs = ["sub_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "softmax_op_legalization", - srcs = ["softmax_op_legalization.cc"], - hdrs = ["softmax_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "tanh_op_legalization", - srcs = ["tanh_op_legalization.cc"], - hdrs = ["tanh_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "util", - srcs = ["util.cc"], - hdrs = ["util.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/tools:dump", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc deleted file mode 100644 index a2a8da69bdc816..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc
deleted file mode 100644
index a2a8da69bdc816..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h"
-
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnAddOpTypeName = "ElementWiseAdd";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kAddOpFmt = "add_%d";
-
-LiteRtStatus AddOpLegalization::LegalizeOp(const litert::Op& src,
-                                           Qnn_OpConfig_t& dest,
-                                           GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflAdd) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  std::string op_name = absl::StrFormat(kAddOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnAddOpTypeName.data(), dest));
-  LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper));
-  LITERT_LOG(LITERT_INFO, "Legalized add op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
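The add legalization just deleted is representative: batch_matmul, cast, cos, and div below repeat it verbatim except for the op code, the QNN type string, and the name format. A hedged sketch of how that boilerplate could be folded into one template; the types are simplified stand-ins, and nothing like this existed in the tree.

#include <cstdio>
#include <string>

enum LiteRtStatus { kLiteRtStatusOk, kLiteRtStatusLegalizeNoMatch };

struct Op {
  int code;
  int Code() const { return code; }
};
struct OpConfig {
  std::string name, package_name, type_name;
};

// One generic "simple op" legalization: match on the op code, stamp out a
// unique node name from the format string, and record the QNN package/type
// pair, the same sequence the deleted LegalizeOp bodies perform.
template <int kMatchCode>
class SimpleOpLegalization {
 public:
  SimpleOpLegalization(const char* qnn_type_name, const char* name_fmt)
      : qnn_type_name_(qnn_type_name), name_fmt_(name_fmt) {}

  LiteRtStatus LegalizeOp(const Op& src, OpConfig& dest) {
    if (src.Code() != kMatchCode) return kLiteRtStatusLegalizeNoMatch;
    char name[64];
    std::snprintf(name, sizeof(name), name_fmt_, op_counter_++);
    dest = {name, "qti.aisw", qnn_type_name_};
    return kLiteRtStatusOk;
  }

 private:
  const char* qnn_type_name_;
  const char* name_fmt_;
  int op_counter_ = 0;  // Keeps generated node names unique per op type.
};

// Usage, mirroring the constants in the deleted files (kTflAddOpCode is a
// hypothetical numeric stand-in for kLiteRtOpCodeTflAdd):
//   SimpleOpLegalization<kTflAddOpCode> add("ElementWiseAdd", "add_%d");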
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h
deleted file mode 100644
index c8301cb124666b..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_
-
-#include
-#include
-
-#include
-#include
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class AddOpLegalization : public Legalization {
- public:
-  AddOpLegalization() = default;
-  ~AddOpLegalization() = default;
-  using Ptr = std::unique_ptr<AddOpLegalization>;
-  static Ptr Create() { return std::make_unique<AddOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_
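The Create()/Ptr pattern in this header, repeated in every header below, suggests call sites build a list of owned legalizations once and reuse it per graph. A plausible but unverified call-site shape, assuming the pre-deletion tree so the removed headers are still available:

#include <memory>
#include <vector>

// Hypothetical call site; Legalization is the base class from the deleted
// legalization.h, and the concrete types come from the removed headers.
#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h"
#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h"
#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h"
#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h"

std::vector<std::unique_ptr<litert::qnn::Legalization>> MakeLegalizations() {
  std::vector<std::unique_ptr<litert::qnn::Legalization>> out;
  out.push_back(litert::qnn::AddOpLegalization::Create());
  out.push_back(litert::qnn::BatchMatmulOpLegalization::Create());
  out.push_back(litert::qnn::CastOpLegalization::Create());
  out.push_back(litert::qnn::ConcatenationOpLegalization::Create());
  return out;
}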
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnBatchMatmulOpTypeName = "MatMul"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kBatchMatmulOpFmt = "batch_matmul_%d"; - -LiteRtStatus BatchMatmulOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflBatchMatmul) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kBatchMatmulOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnBatchMatmulOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized batch_matmul op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h deleted file mode 100644 index 60aee1f164f079..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h
deleted file mode 100644
index 60aee1f164f079..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_
-
-#include
-#include
-
-#include
-#include
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class BatchMatmulOpLegalization : public Legalization {
- public:
-  BatchMatmulOpLegalization() = default;
-  ~BatchMatmulOpLegalization() = default;
-  using Ptr = std::unique_ptr<BatchMatmulOpLegalization>;
-  static Ptr Create() { return std::make_unique<BatchMatmulOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc
deleted file mode 100644
index 8a3bdef7138a6d..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnCastOpTypeName = "Cast"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kCastOpFmt = "cast_%d"; - -LiteRtStatus CastOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflCast) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kCastOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnCastOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized cast op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h deleted file mode 100644 index fecbe54be7643d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CAST_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CAST_OP_LEGALIZATION_H_
-
-#include
-#include
-
-#include
-#include
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class CastOpLegalization : public Legalization {
- public:
-  CastOpLegalization() = default;
-  ~CastOpLegalization() = default;
-  using Ptr = std::unique_ptr<CastOpLegalization>;
-  static Ptr Create() { return std::make_unique<CastOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CAST_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc
deleted file mode 100644
index 11fd3f526fb8ed..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h" - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnConcatenationOpTypeName = "Concat"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kConcatenationOpFmt = "concatenation_%d"; - -static constexpr int kReduceConcatenationOpOutputSize = 1; -static constexpr int kReduceConcatenationOpParamSize = 1; - -LiteRtStatus ConcatenationOpLegalization::LegalizeOp( - const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflConcatenation) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kConcatenationOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnConcatenationOpTypeName.data(), dest)); - - // Look up op input tensors in scope. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT); - - Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins; - for (const auto& op_in : op_ins) { - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in)); - ++cur_qnn_op_in; - } - - // QNN concatenation op expects 1 output tensor. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, - kReduceConcatenationOpOutputSize, QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - // Extract axis option from concatenation op. - int32_t axis; - LITERT_RETURN_IF_ERROR(LiteRtGetConcatenationAxisOption(src.Get(), &axis)); - - // Construct the scalar "axis" param. 
-  Qnn_Param_t axis_param = BuildDefaultParam();
-  axis_param.paramType = QNN_PARAMTYPE_SCALAR;
-  axis_param.name = "axis";
-  Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT;
-  axis_scalar.dataType = QNN_DATATYPE_UINT_32;
-  axis_scalar.int32Value = axis;
-  axis_param.scalarParam = axis_scalar;
-
-  Qnn_Param_t concatenation_params[] = {axis_param};
-  dest.v1.inputTensors = qnn_op_ins;
-  dest.v1.numOfInputs = op_ins.size();
-  dest.v1.outputTensors = qnn_op_outs;
-  dest.v1.numOfOutputs = kReduceConcatenationOpOutputSize;
-  dest.v1.numOfParams = kReduceConcatenationOpParamSize;
-  dest.v1.params = concatenation_params;
-
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
-      graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest));
-
-  LITERT_LOG(LITERT_INFO, "Legalized concatenation op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
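The concatenation legalization above wires the TFLite axis option into a scalar QNN param. Note that the deleted code declares the scalar as QNN_DATATYPE_UINT_32 yet writes it through int32Value. A standalone sketch of the same construction against hypothetical mirrors of the QNN structs (the real definitions live in QnnTypes.h, which is not part of this patch):

#include <cstdint>

// Minimal stand-ins for the QNN types used by the deleted code.
enum Qnn_DataType_t : std::uint32_t { QNN_DATATYPE_INT_32, QNN_DATATYPE_UINT_32 };
enum Qnn_ParamType_t : std::uint32_t { QNN_PARAMTYPE_SCALAR };

struct Qnn_Scalar_t {
  Qnn_DataType_t dataType;
  union {
    std::int32_t int32Value;
    std::uint32_t uint32Value;
  };
};

struct Qnn_Param_t {
  const char* name;
  Qnn_ParamType_t paramType;
  Qnn_Scalar_t scalarParam;
};

// Build the "axis" scalar param the way the deleted legalization does. Writing
// int32Value under a UINT_32 data type relies on the union layout; a negative
// TFLite axis would be reinterpreted, so callers must normalize it first.
Qnn_Param_t MakeAxisParam(std::int32_t axis) {
  Qnn_Param_t param{};
  param.name = "axis";
  param.paramType = QNN_PARAMTYPE_SCALAR;
  param.scalarParam.dataType = QNN_DATATYPE_UINT_32;
  param.scalarParam.int32Value = axis;
  return param;
}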
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h
deleted file mode 100644
index b3c26971b57c43..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CONCATENATION_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CONCATENATION_OP_LEGALIZATION_H_
-
-#include
-#include
-
-#include
-#include
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class ConcatenationOpLegalization : public Legalization {
- public:
-  ConcatenationOpLegalization() = default;
-  ~ConcatenationOpLegalization() = default;
-  using Ptr = std::unique_ptr<ConcatenationOpLegalization>;
-  static Ptr Create() {
-    return std::make_unique<ConcatenationOpLegalization>();
-  }
-
-  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CONCATENATION_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc
deleted file mode 100644
index 7bd555b31cef8b..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h"
-
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnCosOpTypeName = "ElementWiseCos";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kCosOpFmt = "cos_%d";
-
-LiteRtStatus CosOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                                           GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflCos) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  std::string op_name = absl::StrFormat(kCosOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnCosOpTypeName.data(), dest));
-  LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper));
-  LITERT_LOG(LITERT_INFO, "Legalized cos op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h
deleted file mode 100644
index 6a35da2fb12d4c..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_COS_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_COS_OP_LEGALIZATION_H_
-
-#include
-#include
-
-#include
-#include
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class CosOpLegalization : public Legalization {
- public:
-  CosOpLegalization() = default;
-  ~CosOpLegalization() = default;
-  using UniquePtr = std::unique_ptr<CosOpLegalization>;
-  static UniquePtr Create() { return std::make_unique<CosOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_COS_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc
deleted file mode 100644
index 947bad6f719b0f..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnDivOpTypeName = "ElementWiseDivide"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kDivOpFmt = "div_%d"; - -LiteRtStatus DivOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflDiv) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kDivOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnDivOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized div op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h deleted file mode 100644 index a22b91248a4661..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_
-
-#include
-#include
-
-#include
-#include
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class DivOpLegalization : public Legalization {
- public:
-  DivOpLegalization() = default;
-  ~DivOpLegalization() = default;
-  using Ptr = std::unique_ptr<DivOpLegalization>;
-  static Ptr Create() { return std::make_unique<DivOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_
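The last and largest deletion below lowers tfl.dynamic_update_slice in five steps: a StridedSlice pulls the one meaningful index out of input[2] (via ranges = {1, 2, 1}), a Reshape shapes it for ScatterNd, two Transposes swap the leading dimensions of the operand and the update, ScatterNd writes the update row, and a final Transpose swaps the result back. A toy illustration of the dimension bookkeeping; the shapes are invented for the example.

#include <array>
#include <cstdint>
#include <iostream>
#include <utility>

// Toy model of the transpose bookkeeping: the legalization scatters along the
// first axis, so it swaps dims 0 and 1 (perm = {1, 0, 2, 3}) on the way in
// and swaps them back on the way out.
std::array<std::uint32_t, 4> SwapLeadingDims(std::array<std::uint32_t, 4> d) {
  std::swap(d[0], d[1]);
  return d;
}

int main() {
  // Invented shapes: a [2, 8, 4, 64] operand updated with one [2, 1, 4, 64]
  // slice along dimension 1, the axis the extracted index addresses.
  const std::array<std::uint32_t, 4> operand{2, 8, 4, 64};
  const std::array<std::uint32_t, 4> update{2, 1, 4, 64};
  for (std::uint32_t d : SwapLeadingDims(operand)) std::cout << d << ' ';
  std::cout << "<- ScatterNd operand\n";
  for (std::uint32_t d : SwapLeadingDims(update)) std::cout << d << ' ';
  std::cout << "<- ScatterNd update\n";
  return 0;
}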
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.cc
deleted file mode 100644
index 1511802a788a1b..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.cc
+++ /dev/null
@@ -1,304 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.h"
-
-#include <cstdint>
-#include <string>
-#include <vector>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-
-// Dynamic update slice op info.
-static constexpr int kDynamicUpdateSliceOpOperandIndex = 0;
-static constexpr int kDynamicUpdateSliceOpUpdateIndex = 1;
-static constexpr int kDynamicUpdateSliceOpIndicesIndex = 2;
-
-// ScatterND op config.
-static constexpr absl::string_view kQnnScatterNdOpTypeName = "ScatterNd";
-static constexpr absl::string_view kScatterNdOpFmt = "dus_scatter_nd_%d";
-static constexpr int kScatterNDOpInputSize = 3;
-static constexpr int kScatterNDOpOutputSize = 1;
-static constexpr int kScatterNDOutputRank = 4;
-static constexpr int kScatterNDParamSize = 0;
-
-// Strided slice op config.
-static constexpr absl::string_view kStridedSliceOpTypeName = "StridedSlice";
-static constexpr absl::string_view kStridedSliceOpFmt = "dus_strided_slice_%d";
-static constexpr int kStridedSliceOpInputSize = 1;
-static constexpr int kStridedSliceOpOutputSize = 1;
-static constexpr int kStridedSliceOpOutputRank = 1;
-static constexpr int kStridedSliceParamSize = 1;
-static constexpr absl::string_view kRangesParamName = "ranges";
-static constexpr int kRangesParamRank = 2;
-static constexpr int kRangesParamArgSize = 3;
-
-// Reshape op config.
-static constexpr absl::string_view kReshapeOpTypeName = "Reshape";
-static constexpr absl::string_view kReshapeOpFmt = "dus_reshape_%d";
-static constexpr int kReshapeOpInputSize = 1;
-static constexpr int kReshapeOpOutputSize = 1;
-static constexpr int kReshapeOpOutputRank = 2;
-static constexpr int kReshapeParamSize = 0;
-
-// Transpose op config.
-static constexpr absl::string_view kTransposeOpTypeName = "Transpose";
-static constexpr absl::string_view kTransposeOperandOpFmt =
-    "dus_transpose_operand_%d";
-static constexpr absl::string_view kTransposeUpdateOpFmt =
-    "dus_transpose_update_%d";
-static constexpr absl::string_view kTransposeResultOpFmt =
-    "dus_transpose_result_%d";
-static constexpr int kTransposeOpInputSize = 1;
-static constexpr int kTransposeOpOutputSize = 1;
-static constexpr int kTransposeOpOutputRank = 4;
-static constexpr int kTransposeParamSize = 1;
-static constexpr absl::string_view kPermParamName = "perm";
-static constexpr int kPermParamRank = 1;
-static constexpr int kPermParamArgSize = 4;
-
-LiteRtStatus DynamicUpdateSliceOpLegalization::LegalizeOp(
-    const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflDynamicUpdateSlice) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  DumpLegalization(*src.Get());
-
-  // Legalize input tensors, lookup operand tensor in scope.
-  const auto op_ins = src.Inputs();
-  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kScatterNDOpInputSize,
-                     QNN_TENSOR_INIT);
-
-  Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins;
-  for (const auto& op_in : op_ins) {
-    LITERT_RETURN_IF_ERROR(
-        graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in));
-    ++cur_qnn_op_in;
-  }
-  // Legalize op data type.
-  Qnn_DataType_t OperandDataType, UpdateDataType;
-  LITERT_RETURN_IF_ERROR(LegalizeElementType(
-      op_ins[kDynamicUpdateSliceOpOperandIndex].ElementType(),
-      &OperandDataType));
-  LITERT_RETURN_IF_ERROR(LegalizeElementType(
-      op_ins[kDynamicUpdateSliceOpUpdateIndex].ElementType(), &UpdateDataType));
-
-  //===========================================================================
-  // Step 1.1 Build strided slice op. Extract slice index from input[2]
-  // input: [0, x, 0, 0] (LiteRT.DUS input[2])
-  // output: [x]
-  Qnn_OpConfig_t strided_slice_op = BuildDefaultOp();
-  std::string op_name = absl::StrFormat(kStridedSliceOpFmt, op_counter_);
-  LITERT_RETURN_IF_ERROR(
-      SetOpInfo(op_name.c_str(), kDefaultQnnOpPackageName.data(),
-                kStridedSliceOpTypeName.data(), strided_slice_op));
-
-  // Prepare strided slice op params.
-  std::vector<int32_t> ranges = {1, 2, 1};
-  std::vector<uint32_t> ranges_dims = {1, kRangesParamArgSize};
-  Qnn_Param_t range_param = BuildDefaultParam();
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildQnnTesnorParam(
-      ranges.data(), ranges_dims.data(), QNN_DATATYPE_INT_32, kRangesParamRank,
-      kRangesParamName.data(), graph_mapper, range_param));
-
-  // Prepare strided slice op outputs.
-  Qnn_Tensor_t strided_slice_op_out = BuildDefaultTensor();
-  std::vector<uint32_t> slice_op_out_dims = {1};
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor(
-      QNN_DATATYPE_INT_32, kStridedSliceOpOutputRank, slice_op_out_dims.data(),
-      graph_mapper, strided_slice_op_out));
-
-  // Configure strided slice op.
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp(
-      kStridedSliceOpInputSize, &qnn_op_ins[kDynamicUpdateSliceOpIndicesIndex],
-      kStridedSliceOpOutputSize, &strided_slice_op_out, strided_slice_op,
-      kStridedSliceParamSize, &range_param, graph_mapper));
-
-  LITERT_LOG(LITERT_INFO, "Added strided slice op for dus");
-
-  //===========================================================================
-  // Step 1.2 Build reshape op. Construct input tensor shape for QNN.ScatterND
-  // op.
-  // input: [x] (QNN.StridedSlice output)
-  // output: [[x]]
-  Qnn_OpConfig_t reshape_op = BuildDefaultOp();
-  std::string reshpae_op_name = absl::StrFormat(kReshapeOpFmt, op_counter_);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kReshapeOpTypeName.data(), reshape_op));
-
-  // Prepare reshape op output tensor.
-  Qnn_Tensor_t reshape_op_out = BuildDefaultTensor();
-  std::vector<uint32_t> reshape_op_out_dims = {1, 1};
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor(
-      QNN_DATATYPE_INT_32, kReshapeOpOutputRank, reshape_op_out_dims.data(),
-      graph_mapper, reshape_op_out));
-
-  // Configure reshape op.
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp(
-      kReshapeOpInputSize, &strided_slice_op_out, kReshapeOpOutputSize,
-      &reshape_op_out, reshape_op, kReshapeParamSize, nullptr, graph_mapper));
-
-  LITERT_LOG(LITERT_INFO, "Added reshape op for dus");
-
-  //===========================================================================
-  // Step 2 Build transpose op. Swap the first two dimensions of the input
-  // tensor[0] and input tensor[1].
-  // op.
-  // input: [a, b, c, d] (LiteRT.DUS input[0]/input[1] )
-  // output: [b, a, c, d]
-  Qnn_OpConfig_t transpose_operand_op = BuildDefaultOp();
-  Qnn_OpConfig_t transpose_update_op = BuildDefaultOp();
-  std::string transpose_operand_op_name =
-      absl::StrFormat(kTransposeOperandOpFmt, op_counter_);
-  std::string transpose_update_op_name =
-      absl::StrFormat(kTransposeUpdateOpFmt, op_counter_);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(
-      transpose_operand_op_name.c_str(), kDefaultQnnOpPackageName.data(),
-      kTransposeOpTypeName.data(), transpose_operand_op));
-  LITERT_RETURN_IF_ERROR(SetOpInfo(
-      transpose_update_op_name.c_str(), kDefaultQnnOpPackageName.data(),
-      kTransposeOpTypeName.data(), transpose_update_op));
-
-  // Prepare transpose op params.
-  std::vector<uint32_t> perm = {1, 0, 2, 3};
-  std::vector<uint32_t> perm_dims = {kPermParamArgSize};
-  Qnn_Param_t perm_param = BuildDefaultParam();
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildQnnTesnorParam(
-      perm.data(), perm_dims.data(), QNN_DATATYPE_UINT_32, kPermParamRank,
-      kPermParamName.data(), graph_mapper, perm_param));
-
-  // Prepare transpose op outputs.
-  Qnn_Tensor_t transpose_operand_op_output = BuildDefaultTensor();
-  Qnn_Tensor_t transpose_update_op_output = BuildDefaultTensor();
-
-  // Cast const int to uint32_t.
-  auto cast_f = [](int const_int) { return static_cast<uint32_t>(const_int); };
-
-  std::vector<uint32_t> transpose_operand_op_output_dims(
-      kTransposeOpOutputRank);
-  std::vector<uint32_t> transpose_update_op_output_dims(kTransposeOpOutputRank);
-  auto operand_dims = src.Inputs()[kDynamicUpdateSliceOpOperandIndex]
-                          .RankedTensorType()
-                          ->Layout()
-                          .Dimensions();
-  transpose_operand_op_output_dims[0] = cast_f(operand_dims[1]);
-  transpose_operand_op_output_dims[1] = cast_f(operand_dims[0]);
-  transpose_operand_op_output_dims[2] = cast_f(operand_dims[2]);
-  transpose_operand_op_output_dims[3] = cast_f(operand_dims[3]);
-
-  auto update_dims = src.Inputs()[kDynamicUpdateSliceOpUpdateIndex]
-                         .RankedTensorType()
-                         ->Layout()
-                         .Dimensions();
-  transpose_update_op_output_dims[0] = cast_f(update_dims[1]);
-  transpose_update_op_output_dims[1] = cast_f(update_dims[0]);
-  transpose_update_op_output_dims[2] = cast_f(update_dims[2]);
-  transpose_update_op_output_dims[3] = cast_f(update_dims[3]);
-
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor(
-      OperandDataType, kTransposeOpOutputRank,
-      transpose_operand_op_output_dims.data(), graph_mapper,
-      transpose_operand_op_output));
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor(
-      UpdateDataType, kTransposeOpOutputRank,
-      transpose_update_op_output_dims.data(), graph_mapper,
-      transpose_update_op_output));
-
-  // Configure transpose ops.
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp(
-      kTransposeOpInputSize, &qnn_op_ins[kDynamicUpdateSliceOpOperandIndex],
-      kTransposeOpOutputSize, &transpose_operand_op_output,
-      transpose_operand_op, kTransposeParamSize, &perm_param, graph_mapper));
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp(
-      kTransposeOpInputSize, &qnn_op_ins[kDynamicUpdateSliceOpUpdateIndex],
-      kTransposeOpOutputSize, &transpose_update_op_output, transpose_update_op,
-      kTransposeParamSize, &perm_param, graph_mapper));
-
-  //===========================================================================
-  // Step 3 Build ScatterND op.
-  Qnn_OpConfig_t scatter_nd_op = BuildDefaultOp();
-  std::string scatter_nd_op_name =
-      absl::StrFormat(kScatterNdOpFmt, op_counter_);
-  LITERT_RETURN_IF_ERROR(
-      SetOpInfo(scatter_nd_op_name.c_str(), kDefaultQnnOpPackageName.data(),
-                kQnnScatterNdOpTypeName.data(), scatter_nd_op));
-
-  // Prepare scatter nd op output tensor.
-  Qnn_Tensor_t scatter_nd_op_output = BuildDefaultTensor();
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
-      BuildAndRegisterQnnNativeTensor(OperandDataType, kScatterNDOutputRank,
-                                      transpose_operand_op_output_dims.data(),
-                                      graph_mapper, scatter_nd_op_output));
-
-  // Configure ScatterND op.
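-  // Inputs are wired as (data, indices, updates): the transposed operand, the
-  // [1, 1] slice index produced in Steps 1.1/1.2, and the transposed update.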
-  LITERT_STACK_ARRAY(Qnn_Tensor_t, scatter_nd_op_ins, kScatterNDOpInputSize,
-                     QNN_TENSOR_INIT);
-  scatter_nd_op_ins[0] = transpose_operand_op_output;
-  scatter_nd_op_ins[1] = reshape_op_out;
-  scatter_nd_op_ins[2] = transpose_update_op_output;
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp(
-      kScatterNDOpInputSize, scatter_nd_op_ins, kScatterNDOpOutputSize,
-      &scatter_nd_op_output, scatter_nd_op, kScatterNDParamSize, nullptr,
-      graph_mapper));
-
-  //===========================================================================
-  // Step 4 Build final transpose op. Swap back the first two dimensions of the
-  // scatter nd op output.
-  // input: [b, a, c, d] (QNN.ScatterND output)
-  // output: [a, b, c, d]
-  std::string transpose_result_op_name = absl::StrFormat(
-      kTransposeResultOpFmt, /*increase counter*/ op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(transpose_result_op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kTransposeOpTypeName.data(), dest));
-
-  // Legalize op outputs and update scope.
-  const auto op_outs = src.Outputs();
-  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, op_outs.size(),
-                     QNN_TENSOR_INIT);
-  LITERT_RETURN_IF_ERROR(
-      graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0]));
-  LITERT_RETURN_IF_ERROR(
-      graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0]));
-
-  // Configure transpose op.
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp(
-      kTransposeOpInputSize, &scatter_nd_op_output, kTransposeOpOutputSize,
-      &qnn_op_outs[0], dest, kTransposeParamSize, &perm_param, graph_mapper));
-
-  LITERT_LOG(LITERT_INFO, "Legalized dynamic update slice op");
-
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.h
deleted file mode 100644
index 2a497f4f5cfcdb..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.h
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DYNAMIC_UPDATE_SLICE_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DYNAMIC_UPDATE_SLICE_OP_LEGALIZATION_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class DynamicUpdateSliceOpLegalization : public Legalization {
- public:
-  DynamicUpdateSliceOpLegalization() = default;
-  ~DynamicUpdateSliceOpLegalization() = default;
-  using Ptr = std::unique_ptr<DynamicUpdateSliceOpLegalization>;
-  static Ptr Create() {
-    return std::make_unique<DynamicUpdateSliceOpLegalization>();
-  }
-
-  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DYNAMIC_UPDATE_SLICE_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc
deleted file mode 100644
index ecab067e3846a5..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h"
-
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnEmbeddingLookupOpTypeName = "Gather";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kEmbeddingLookupOpFmt =
-    "embedding_lookup_%d";
-
-static constexpr int kReduceEmbeddingLookupOpOutputSize = 1;
-static constexpr int kReduceEmbeddingLookupOpParamSize = 1;
-
-static constexpr int kEmbeddingLookupOpTableInputIndex = 1;
-static constexpr int kEmbeddingLookupOpLookupInputIndex = 0;
-static constexpr int kQnnGatherOpTableInputIndex = 0;
-static constexpr int kQnnGatherOpLookupInputIndex = 1;
-
-LiteRtStatus EmbeddingLookupOpLegalization::LegalizeOp(
-    const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflEmbeddingLookup) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  DumpLegalization(*src.Get());
-  std::string op_name = absl::StrFormat(kEmbeddingLookupOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnEmbeddingLookupOpTypeName.data(), dest));
-
-  // Look up op input tensors in scope. Note that TFLite embedding_lookup takes
-  // (lookup, table) while QNN Gather takes (table, lookup), so the indices are
-  // swapped here.
-  const auto op_ins = src.Inputs();
-  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT);
-
-  LITERT_RETURN_IF_ERROR(graph_mapper.LookupInScope(
-      op_ins[kEmbeddingLookupOpLookupInputIndex].Get(),
-      qnn_op_ins[kQnnGatherOpLookupInputIndex]));
-  LITERT_RETURN_IF_ERROR(graph_mapper.LookupInScope(
-      op_ins[kEmbeddingLookupOpTableInputIndex].Get(),
-      qnn_op_ins[kQnnGatherOpTableInputIndex]));
-
-  // QNN embedding_lookup op expects 1 output tensor.
-  const auto op_outs = src.Outputs();
-  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs,
-                     kReduceEmbeddingLookupOpOutputSize, QNN_TENSOR_INIT);
-  LITERT_RETURN_IF_ERROR(
-      graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0]));
-  LITERT_RETURN_IF_ERROR(
-      graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0]));
-
-  // Construct the scalar "axis" param.
-  Qnn_Param_t axis_param = BuildDefaultParam();
-  axis_param.paramType = QNN_PARAMTYPE_SCALAR;
-  axis_param.name = "axis";
-  Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT;
-  axis_scalar.dataType = QNN_DATATYPE_INT_32;
-  // Embedding lookup op expects axis to always be 0.
-  axis_scalar.int32Value = 0;
-  axis_param.scalarParam = axis_scalar;
-
-  Qnn_Param_t embedding_lookup_params[] = {axis_param};
-  dest.v1.inputTensors = qnn_op_ins;
-  dest.v1.numOfInputs = op_ins.size();
-  dest.v1.outputTensors = qnn_op_outs;
-  dest.v1.numOfOutputs = kReduceEmbeddingLookupOpOutputSize;
-  dest.v1.numOfParams = kReduceEmbeddingLookupOpParamSize;
-  dest.v1.params = embedding_lookup_params;
-
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
-      graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest));
-
-  LITERT_LOG(LITERT_INFO, "Legalized embedding_lookup op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h
deleted file mode 100644
index e8bae779d2ae64..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_EMBEDDING_LOOKUP_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_EMBEDDING_LOOKUP_OP_LEGALIZATION_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class EmbeddingLookupOpLegalization : public Legalization {
- public:
-  EmbeddingLookupOpLegalization() = default;
-  ~EmbeddingLookupOpLegalization() = default;
-  using Ptr = std::unique_ptr<EmbeddingLookupOpLegalization>;
-  static Ptr Create() {
-    return std::make_unique<EmbeddingLookupOpLegalization>();
-  }
-
-  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_EMBEDDING_LOOKUP_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc
deleted file mode 100644
index fca0d31c26a987..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h"
-
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnFullyConnectedOpTypeName =
-    "FullyConnected";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kFullyConnectedOpFmt = "fully_connected_%d";
-
-LiteRtStatus FullyConnectedOpLegalization::LegalizeOp(
-    const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflFullyConnected) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  std::string op_name = absl::StrFormat(kFullyConnectedOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnFullyConnectedOpTypeName.data(), dest));
-  LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper));
-
-  LITERT_LOG(LITERT_INFO, "Legalized fully_connected op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h
deleted file mode 100644
index 0ff2983e59e708..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class FullyConnectedOpLegalization : public Legalization {
- public:
-  FullyConnectedOpLegalization() = default;
-  ~FullyConnectedOpLegalization() = default;
-  using Ptr = std::unique_ptr<FullyConnectedOpLegalization>;
-  static Ptr Create() {
-    return std::make_unique<FullyConnectedOpLegalization>();
-  }
-
-  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc
deleted file mode 100644
index 3b769d9bd7521e..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h"
-
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnGeluOpTypeName = "Gelu";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kGeluOpFmt = "gelu_%d";
-
-LiteRtStatus GeluOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                                            GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflGelu) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  const std::string op_name = absl::StrFormat(kGeluOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnGeluOpTypeName.data(), dest));
-  LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper));
-  LITERT_LOG(LITERT_INFO, "Legalized gelu op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h
deleted file mode 100644
index fdb31f5300d07c..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class GeluOpLegalization : public Legalization {
- public:
-  GeluOpLegalization() = default;
-  ~GeluOpLegalization() = default;
-  using Ptr = std::unique_ptr<GeluOpLegalization>;
-  static Ptr Create() { return std::make_unique<GeluOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc
deleted file mode 100644
index d07ca4f086c708..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h"
-
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnGreaterOpTypeName = "ElementWiseGreater";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kGreaterOpFmt = "greater_%d";
-
-LiteRtStatus GreaterOpLegalization::LegalizeOp(const litert::Op& src,
-                                               Qnn_OpConfig_t& dest,
-                                               GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflGreater) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  std::string op_name = absl::StrFormat(kGreaterOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnGreaterOpTypeName.data(), dest));
-  LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper));
-  LITERT_LOG(LITERT_INFO, "Legalized greater op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h
deleted file mode 100644
index bb353420291c00..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GREATER_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GREATER_OP_LEGALIZATION_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class GreaterOpLegalization : public Legalization {
- public:
-  GreaterOpLegalization() = default;
-  ~GreaterOpLegalization() = default;
-  using Ptr = std::unique_ptr<GreaterOpLegalization>;
-  static Ptr Create() { return std::make_unique<GreaterOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GREATER_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h
deleted file mode 100644
index 5f7c8ef96062ef..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-
-#define STRINGIFY(x) #x
-#define QNN_OP_NAME(prefix) STRINGIFY(prefix##__COUNTER__)
-
-namespace litert::qnn {
-
-class Legalization {
- public:
-  Legalization() = default;
-  virtual ~Legalization() = default;
-
-  virtual LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                                  GraphMapper& graph_mapper) = 0;
-
-  // Sets the op name, package name, and type.
-  // Note: All argument strings can't be de-allocated until the op has been
-  // registered with the qnn api, i.e. graphAddNode().
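-  // In practice each legalization keeps its op-name std::string alive as a
-  // local until graphAddNode() runs inside the same LegalizeOp() invocation,
-  // which satisfies this lifetime requirement.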
-  inline LiteRtStatus SetOpInfo(const char* name, const char* op_package_name,
-                                const char* op_type, Qnn_OpConfig_t& op) {
-    op.v1.name = name;
-    op.v1.packageName = op_package_name;
-    op.v1.typeName = op_type;
-    return kLiteRtStatusOk;
-  }
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc
deleted file mode 100644
index 23d45e4ba4a6fb..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h"
-
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnLessOpTypeName = "ElementWiseLess";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kLessOpFmt = "less_%d";
-
-LiteRtStatus LessOpLegalization::LegalizeOp(const litert::Op& src,
-                                            Qnn_OpConfig_t& dest,
-                                            GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflLess) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  std::string op_name = absl::StrFormat(kLessOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnLessOpTypeName.data(), dest));
-  LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper));
-  LITERT_LOG(LITERT_INFO, "Legalized less op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h
deleted file mode 100644
index b16c5335f01a8e..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LESS_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LESS_OP_LEGALIZATION_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class LessOpLegalization : public Legalization {
- public:
-  LessOpLegalization() = default;
-  ~LessOpLegalization() = default;
-  using Ptr = std::unique_ptr<LessOpLegalization>;
-  static Ptr Create() { return std::make_unique<LessOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LESS_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc
deleted file mode 100644
index 1a1bc4dbdc7aa5..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h"
-
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnLogicalAndOpTypeName = "ElementWiseAnd";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kLogicalAndOpFmt = "logical_and_%d";
-
-LiteRtStatus LogicalAndOpLegalization::LegalizeOp(const Op& src,
-                                                  Qnn_OpConfig_t& dest,
-                                                  GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflLogicalAnd) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  std::string op_name = absl::StrFormat(kLogicalAndOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnLogicalAndOpTypeName.data(), dest));
-  LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper));
-  LITERT_LOG(LITERT_INFO, "Legalized logical_and op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h
deleted file mode 100644
index ec5c5c2a03bf5e..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LOGICAL_AND_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LOGICAL_AND_OP_LEGALIZATION_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class LogicalAndOpLegalization : public Legalization {
- public:
-  LogicalAndOpLegalization() = default;
-  ~LogicalAndOpLegalization() = default;
-  using UniquePtr = std::unique_ptr<LogicalAndOpLegalization>;
-  static UniquePtr Create() {
-    return std::make_unique<LogicalAndOpLegalization>();
-  }
-
-  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LOGICAL_AND_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc
deleted file mode 100644
index 4185740e2cb2b0..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h"
-
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnMulOpTypeName = "ElementWiseMultiply";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kMulOpFmt = "mul_%d";
-
-LiteRtStatus MulOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                                           GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflMul) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  std::string op_name = absl::StrFormat(kMulOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnMulOpTypeName.data(), dest));
-  LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper));
-  LITERT_LOG(LITERT_INFO, "Legalized mul op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h
deleted file mode 100644
index 098d0954430d50..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class MulOpLegalization : public Legalization {
- public:
-  MulOpLegalization() = default;
-  ~MulOpLegalization() = default;
-  using Ptr = std::unique_ptr<MulOpLegalization>;
-  static Ptr Create() { return std::make_unique<MulOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.cc
deleted file mode 100644
index 6e1f3d350813fd..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.h"
-
-#include <cstdint>
-#include <string>
-#include <vector>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/c/litert_options.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-// Pack op config.
-static constexpr absl::string_view kQnnPackOpTypeName = "Pack";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kPackOpFmt = "pack_%d";
-static constexpr absl::string_view kPackOpAxisParamName = "axis";
-static constexpr int kPackOpAxisParamSize = 1;
-static constexpr int kPackScalarsOpOutputRank = 2;
-
-// Reshape op config.
-static constexpr absl::string_view kReshapeOpTypeName = "Reshape";
-static constexpr absl::string_view kReshapeOpFmt = "pack_reshape_%d";
-static constexpr int kReshapeOpInputSize = 1;
-static constexpr int kReshapeOpOutputSize = 1;
-static constexpr int kReshapeParamSize = 0;
-
-LiteRtStatus PackOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                                            GraphMapper& graph_mapper) {
-  if (src.Code() != kLiteRtOpCodeTflPack) {
-    return kLiteRtStatusLegalizeNoMatch;
-  }
-  std::string pack_op_name = absl::StrFormat(kPackOpFmt, op_counter_);
-  DumpLegalization(*src.Get());
-
-  // Legalize input tensors, lookup operand tensor in scope.
-  const auto op_ins = src.Inputs();
-  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT);
-  Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins;
-  for (const auto& op_in : op_ins) {
-    LITERT_RETURN_IF_ERROR(
-        graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in));
-    ++cur_qnn_op_in;
-  }
-
-  // Legalize output tensors.
-  const auto op_outs = src.Outputs();
-  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, op_outs.size(),
-                     QNN_TENSOR_INIT);
-  LITERT_RETURN_IF_ERROR(
-      graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0]));
-  LITERT_RETURN_IF_ERROR(
-      graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0]));
-
-  // Get axis option and build QNN scalar param.
-  int32_t axis;
-  LITERT_RETURN_IF_ERROR(LiteRtGetPackAxisOption(src.Get(), &axis));
-  uint32_t axis_value = static_cast<uint32_t>(axis);
-
-  Qnn_Param_t axis_param = BuildDefaultParam();
-  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildQnnScalarParam(
-      axis_value, QNN_DATATYPE_UINT_32, kPackOpAxisParamName.data(),
-      graph_mapper, axis_param));
-
-  // QNN does not support packing scalars; scalar values are legalized as 1-D
-  // tensors with a single element. In that case, we need to add a reshape op
-  // to convert the resulting packed 2-D tensor back to a 1-D tensor.
-  auto input_layout = op_ins[0].RankedTensorType()->Layout();
-  if (input_layout.Rank() == 0) {
-    // Prepare Pack op output tensor.
-    Qnn_Tensor_t pack_op_out = BuildDefaultTensor();
-    uint32_t pack_op_out_rank = kPackScalarsOpOutputRank;
-    Qnn_DataType_t PackOpDataType = QNN_DATATYPE_UNDEFINED;
-
-    LITERT_RETURN_IF_ERROR(
-        LegalizeElementType(op_ins[0].ElementType(), &PackOpDataType));
-    std::vector<uint32_t> pack_op_out_dims = {
-        static_cast<uint32_t>(op_ins.size()), 1};
-
-    LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor(
-        PackOpDataType, pack_op_out_rank, pack_op_out_dims.data(),
-        graph_mapper, pack_op_out));
-
-    // Build Pack op.
-    Qnn_OpConfig_t pack_op = BuildDefaultOp();
-    LITERT_RETURN_IF_ERROR(SetOpInfo(pack_op_name.c_str(),
-                                     kDefaultQnnOpPackageName.data(),
-                                     kQnnPackOpTypeName.data(), pack_op));
-    LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp(
-        op_ins.size(), qnn_op_ins, op_outs.size(), &pack_op_out, pack_op,
-        kPackOpAxisParamSize, &axis_param, graph_mapper));
-
-    // Build Reshape op.
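-    // The reshape is registered as `dest`, flattening the packed
-    // [num_inputs, 1] result back to the 1-D output tensor LiteRT expects.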
-    std::string reshape_op_name = absl::StrFormat(kReshapeOpFmt, op_counter_);
-    LITERT_RETURN_IF_ERROR(SetOpInfo(reshape_op_name.c_str(),
-                                     kDefaultQnnOpPackageName.data(),
-                                     kReshapeOpTypeName.data(), dest));
-    LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp(
-        kReshapeOpInputSize, &pack_op_out, kReshapeOpOutputSize, qnn_op_outs,
-        dest, kReshapeParamSize, nullptr, graph_mapper));
-  } else {
-    LITERT_RETURN_IF_ERROR(SetOpInfo(pack_op_name.c_str(),
-                                     kDefaultQnnOpPackageName.data(),
-                                     kQnnPackOpTypeName.data(), dest));
-    LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp(
-        op_ins.size(), qnn_op_ins, op_outs.size(), qnn_op_outs, dest,
-        kPackOpAxisParamSize, &axis_param, graph_mapper));
-  }
-  op_counter_++;
-
-  LITERT_LOG(LITERT_INFO, "Legalized pack op", "");
-  return kLiteRtStatusOk;
-}
-
-}  // namespace litert::qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.h
deleted file mode 100644
index 42bd24f95b7813..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_PACK_OP_LEGALIZATION_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_PACK_OP_LEGALIZATION_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
-
-namespace litert::qnn {
-
-class PackOpLegalization : public Legalization {
- public:
-  PackOpLegalization() = default;
-  ~PackOpLegalization() = default;
-  using Ptr = std::unique_ptr<PackOpLegalization>;
-  static Ptr Create() { return std::make_unique<PackOpLegalization>(); }
-
-  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
-                          GraphMapper& graph_mapper);
-
- private:
-  // Counter to ensure unique op names.
-  uint32_t op_counter_ = 0;
-};
-
-}  // namespace litert::qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_PACK_OP_LEGALIZATION_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.cc
deleted file mode 100644
index bf16efb347447d..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.cc
+++ /dev/null
@@ -1,174 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h"
-
-#include <cmath>
-#include <limits>
-#include <string>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_common.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/c/litert_model.h"
-#include "tensorflow/lite/experimental/litert/c/litert_op_code.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
-
-namespace litert::qnn {
-
-static constexpr absl::string_view kQnnConvertOpTypeName = "Convert";
-static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
-static constexpr absl::string_view kConvertOpFmt = "q_convert_%d";
-
-static constexpr absl::string_view kQnnQuantizeOpTypeName = "Quantize";
-static constexpr absl::string_view kQuantizeOpFmt = "quantize_%d";
-
-static constexpr absl::string_view kQnnCastOpTypeName = "Cast";
-static constexpr absl::string_view kCastOpFmt = "q_cast_%d";
-
-// SFIXED_8 and UFIXED_8 offset diff.
-static constexpr int kSUFixed8OffsetDiff = 128;
-// SFIXED_16 and UFIXED_16 offset diff.
-static constexpr int kSUFixed16OffsetDiff = 32768;
-
-LiteRtStatus QuantizeOpLegalization::LegalizeQuantizeOpAsConvertOp(
-    const litert::Op& src, Qnn_OpConfig_t& dest) {
-  std::string op_name = absl::StrFormat(kConvertOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnConvertOpTypeName.data(), dest));
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus QuantizeOpLegalization::LegalizeQuantizeOpAsCastOp(
-    const litert::Op& src, Qnn_OpConfig_t& dest) {
-  std::string op_name = absl::StrFormat(kCastOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnCastOpTypeName.data(), dest));
-  return kLiteRtStatusOk;
-}
-
-LiteRtStatus QuantizeOpLegalization::LegalizeQuantizeOpAsQuantizeOp(
-    const litert::Op& src, Qnn_OpConfig_t& dest) {
-  std::string op_name = absl::StrFormat(kQuantizeOpFmt, op_counter_++);
-  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
-                                   kDefaultQnnOpPackageName.data(),
-                                   kQnnQuantizeOpTypeName.data(), dest));
-  return kLiteRtStatusOk;
-}
-
-inline bool IsTensorUInt8(Tensor& tensor) {
-  return tensor.RankedTensorType()->ElementType() == ElementType::UInt8;
-}
-inline bool IsTensorInt8(Tensor& tensor) {
-  return tensor.RankedTensorType()->ElementType() == ElementType::Int8;
-} -inline bool IsTensorUInt16(Tensor& tensor) { - return tensor.RankedTensorType()->ElementType() == ElementType::UInt16; -} -inline bool IsTensorInt16(Tensor& tensor) { - return tensor.RankedTensorType()->ElementType() == ElementType::Int16; -} - -inline bool IsTensorPerTensorQuantized(Tensor& tensor) { - return (IsTensorInt8(tensor) || IsTensorUInt8(tensor) || - IsTensorInt16(tensor) || IsTensorUInt16(tensor)) && - tensor.QTypeId() == kLiteRtQuantizationPerTensor; -} - -inline bool WithinCastRange(Tensor& input_tensor, Tensor& output_tensor, - const int offset_diff) { - return (std::fabs(input_tensor.PerTensorQuantization().scale - - output_tensor.PerTensorQuantization().scale)) < - std::numeric_limits<float>::epsilon() && - std::abs(input_tensor.PerTensorQuantization().zero_point - - output_tensor.PerTensorQuantization().zero_point) == - offset_diff; -} - -LiteRtStatus QuantizeOpLegalization::ConfigureQnnOp(const litert::Op& src, - Qnn_OpConfig_t& dest) { - const bool is_input_tensor_per_tensor_quantized = - IsTensorPerTensorQuantized(src.Inputs().front()); - const bool is_output_tensor_per_tensor_quantized = - IsTensorPerTensorQuantized(src.Outputs().front()); - - if (is_input_tensor_per_tensor_quantized && - is_output_tensor_per_tensor_quantized) { - // Check if the input and output tensors are int8/uint8 or int16/uint16. - const bool is_input_tensor_int8 = IsTensorInt8(src.Inputs().front()); - const bool is_input_tensor_uint8 = IsTensorUInt8(src.Inputs().front()); - const bool is_input_tensor_int16 = IsTensorInt16(src.Inputs().front()); - const bool is_input_tensor_uint16 = IsTensorUInt16(src.Inputs().front()); - const bool is_output_tensor_int8 = IsTensorInt8(src.Outputs().front()); - const bool is_output_tensor_uint8 = IsTensorUInt8(src.Outputs().front()); - const bool is_output_tensor_int16 = IsTensorInt16(src.Outputs().front()); - const bool is_output_tensor_uint16 = IsTensorUInt16(src.Outputs().front()); - - if ((is_input_tensor_int8 && is_output_tensor_uint8) || - (is_input_tensor_uint8 && is_output_tensor_int8)) { - // Case if the input and output tensors are int8/uint8. - const bool is_quantization_range_within_cast_range = WithinCastRange( - src.Inputs().front(), src.Outputs().front(), kSUFixed8OffsetDiff); - if (is_quantization_range_within_cast_range) { - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - LegalizeQuantizeOpAsCastOp(src, dest)); - LITERT_LOG(LITERT_INFO, "Configured quantize op to Cast Op"); - return kLiteRtStatusOk; - } - } else if ((is_input_tensor_int16 && is_output_tensor_uint16) || - (is_input_tensor_uint16 && is_output_tensor_int16)) { - // Case if the input and output tensors are int16/uint16. - const bool is_quantization_range_within_cast_range = WithinCastRange( - src.Inputs().front(), src.Outputs().front(), kSUFixed16OffsetDiff); - if (is_quantization_range_within_cast_range) { - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - LegalizeQuantizeOpAsCastOp(src, dest)); - LITERT_LOG(LITERT_INFO, "Configured quantize op to Cast Op"); - return kLiteRtStatusOk; - } - } - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - LegalizeQuantizeOpAsConvertOp(src, dest)); - LITERT_LOG(LITERT_INFO, "Configured quantize op to Convert Op"); - return kLiteRtStatusOk; - } - - // Not per tensor quantized, legalize to Quantize Op.
- LITERT_RETURN_STATUS_IF_QNN_NOT_OK(LegalizeQuantizeOpAsQuantizeOp(src, dest)); - LITERT_LOG(LITERT_INFO, "Legalized quantize op to Quantize Op"); - return kLiteRtStatusOk; -} - -LiteRtStatus QuantizeOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflQuantize) { - return kLiteRtStatusLegalizeNoMatch; - } - LITERT_RETURN_IF_ERROR(ConfigureQnnOp(src, dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized quantize Op"); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h deleted file mode 100644 index 7621f701b87a3d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class QuantizeOpLegalization : public Legalization { - public: - QuantizeOpLegalization() = default; - ~QuantizeOpLegalization() = default; - using Ptr = std::unique_ptr<QuantizeOpLegalization>; - static Ptr Create() { return std::make_unique<QuantizeOpLegalization>(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - LiteRtStatus ConfigureQnnOp(const litert::Op& src, Qnn_OpConfig_t& dest); - - private: - // Requantization: legalize quantize to QNN Convert Op. - // Quantization range is not within QNN cast Op range. - LiteRtStatus LegalizeQuantizeOpAsConvertOp(const litert::Op& src, - Qnn_OpConfig_t& dest); - - // Ignore Requantization: legalize quantize to QNN Cast Op. - // Quantization range is within QNN cast Op range. Directly use QNN Cast Op. - LiteRtStatus LegalizeQuantizeOpAsCastOp(const litert::Op& src, - Qnn_OpConfig_t& dest); - - // Quantization: legalize quantize to QNN Quantize Op. - LiteRtStatus LegalizeQuantizeOpAsQuantizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest); - - // Counter to ensure unique op names.
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc deleted file mode 100644 index 1127b0f3188cf3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnReshapeOpTypeName = "Reshape"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kReshapeOpFmt = "reshape_%d"; - -static constexpr int kReshapeOpInputSize = 1; -static constexpr int kReshapeOpOutputSize = 1; - -LiteRtStatus ReshapeOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflReshape) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kReshapeOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnReshapeOpTypeName.data(), dest)); - DumpLegalization(*src.Get()); - // Look up op input tensors in scope. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kReshapeOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); - - // Legalize op outputs and update scope. 
- - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kReshapeOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - dest.v1.numOfInputs = kReshapeOpInputSize; - dest.v1.inputTensors = qnn_op_ins; - - dest.v1.numOfOutputs = kReshapeOpOutputSize; - dest.v1.outputTensors = qnn_op_outs; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized reshape op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h deleted file mode 100644 index e8553639fc0906..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class ReshapeOpLegalization : public Legalization { - public: - ReshapeOpLegalization() = default; - ~ReshapeOpLegalization() = default; - using Ptr = std::unique_ptr<ReshapeOpLegalization>; - static Ptr Create() { return std::make_unique<ReshapeOpLegalization>(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc deleted file mode 100644 index 363434821d6d08..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC.
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnRsqrtOpTypeName = "ElementWiseRsqrt"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kRsqrtOpFmt = "rsqrt_%d"; - -LiteRtStatus RsqrtOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflRsqrt) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kRsqrtOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnRsqrtOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized rsqrt op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h deleted file mode 100644 index 5971e9f98cd5b1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
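The quantize legalization above picks among three QNN ops: when input and output are per-tensor quantized with the same scale and zero points that differ by exactly the signed/unsigned offset (128 for 8-bit, 32768 for 16-bit), the requantize is a pure bit-pattern reinterpretation and lowers to Cast; other per-tensor requantizes lower to Convert; everything else falls back to Quantize. A minimal standalone sketch of that decision rule, using hypothetical plain structs in place of the LiteRT tensor API:

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <limits>

// Hypothetical stand-ins for the per-tensor quantization parameters.
struct QuantParams {
  float scale;
  int64_t zero_point;
};

enum class Lowering { kCast, kConvert };

// Mirrors WithinCastRange: same scale (up to float epsilon) and a
// zero-point delta equal to the signed<->unsigned offset for the width
// (128 for 8-bit, 32768 for 16-bit).
Lowering ChooseLowering(const QuantParams& in, const QuantParams& out,
                        int offset_diff) {
  const bool same_scale =
      std::fabs(in.scale - out.scale) < std::numeric_limits<float>::epsilon();
  const bool offset_only =
      std::abs(in.zero_point - out.zero_point) == offset_diff;
  return (same_scale && offset_only) ? Lowering::kCast : Lowering::kConvert;
}

// Example: int8 (zero point -128) -> uint8 (zero point 0) with an identical
// scale satisfies both conditions, so it lowers to a plain Cast:
//   ChooseLowering({0.05f, -128}, {0.05f, 0}, 128) == Lowering::kCast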
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class RsqrtOpLegalization : public Legalization { - public: - RsqrtOpLegalization() = default; - ~RsqrtOpLegalization() = default; - using Ptr = std::unique_ptr<RsqrtOpLegalization>; - static Ptr Create() { return std::make_unique<RsqrtOpLegalization>(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc deleted file mode 100644 index 9c6da052221bcc..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
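Most of the unary and binary legalizations here delegate to LegalizeSimpleOp from util.h, whose definition is not part of this diff. Assuming it follows the same wiring spelled out in the reshape legalization above (look up inputs in scope, legalize and register outputs, populate the v1 config, add the node), a condensed sketch of the shared pattern; a heap-backed vector stands in for the LITERT_STACK_ARRAY used by the real code:

#include <vector>

// Sketch of the shared simple-op wiring; not the actual LegalizeSimpleOp.
LiteRtStatus LegalizeSimpleOpSketch(const litert::Op& src,
                                    Qnn_OpConfig_t& dest,
                                    litert::qnn::GraphMapper& graph_mapper) {
  const auto op_ins = src.Inputs();
  const auto op_outs = src.Outputs();
  std::vector<Qnn_Tensor_t> qnn_op_ins(op_ins.size(), QNN_TENSOR_INIT);
  std::vector<Qnn_Tensor_t> qnn_op_outs(op_outs.size(), QNN_TENSOR_INIT);
  // Inputs must already be in scope from previously legalized ops.
  for (size_t i = 0; i < op_ins.size(); ++i) {
    LITERT_RETURN_IF_ERROR(
        graph_mapper.LookupInScope(op_ins.at(i).Get(), qnn_op_ins[i]));
  }
  // Outputs are created in the QNN graph and pushed into scope so that
  // downstream ops can find them.
  for (size_t i = 0; i < op_outs.size(); ++i) {
    LITERT_RETURN_IF_ERROR(graph_mapper.LegalizeAndRegister(
        op_outs.at(i).Get(), qnn_op_outs[i]));
    LITERT_RETURN_IF_ERROR(
        graph_mapper.PushToScope(op_outs.at(i).Get(), qnn_op_outs[i]));
  }
  dest.v1.numOfInputs = qnn_op_ins.size();
  dest.v1.inputTensors = qnn_op_ins.data();
  dest.v1.numOfOutputs = qnn_op_outs.size();
  dest.v1.outputTensors = qnn_op_outs.data();
  // Finally the node is added to the graph, as in the reshape legalization.
  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
      graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest));
  return kLiteRtStatusOk;
}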
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSelectOpTypeName = "ElementWiseSelect"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSelectOpFmt = "select_%d"; - -LiteRtStatus SelectOpLegalization::LegalizeOp(const Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSelect && - src.Code() != kLiteRtOpCodeTflSelectV2) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kSelectOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSelectOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - - return kLiteRtStatusOk; - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized select op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h deleted file mode 100644 index 526498a4bb4b51..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SELECT_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SELECT_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SelectOpLegalization : public Legalization { - public: - SelectOpLegalization() = default; - ~SelectOpLegalization() = default; - using Ptr = std::unique_ptr<SelectOpLegalization>; - static Ptr Create() { return std::make_unique<SelectOpLegalization>(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SELECT_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc deleted file mode 100644 index 17932971f8cee4..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
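Each LegalizeOp begins with an op-code guard and returns kLiteRtStatusLegalizeNoMatch when the op is not its kind, which is what allows a driver to probe candidate legalizations in turn. A hypothetical dispatch loop over the Create() factories, assuming LegalizeOp is virtual on the Legalization base class (the real registration site is not shown in this diff):

#include <memory>
#include <vector>

// Hypothetical registry; the real compiler wires these up elsewhere.
std::vector<std::unique_ptr<litert::qnn::Legalization>> MakeLegalizations() {
  std::vector<std::unique_ptr<litert::qnn::Legalization>> legalizations;
  legalizations.push_back(litert::qnn::SelectOpLegalization::Create());
  legalizations.push_back(litert::qnn::RsqrtOpLegalization::Create());
  return legalizations;
}

LiteRtStatus LegalizeWithFirstMatch(
    const litert::Op& src, Qnn_OpConfig_t& dest,
    litert::qnn::GraphMapper& graph_mapper,
    std::vector<std::unique_ptr<litert::qnn::Legalization>>& legalizations) {
  for (auto& legalization : legalizations) {
    LiteRtStatus status = legalization->LegalizeOp(src, dest, graph_mapper);
    if (status != kLiteRtStatusLegalizeNoMatch) {
      return status;  // Either success or a real error; stop probing.
    }
  }
  return kLiteRtStatusLegalizeNoMatch;  // No legalization claimed the op.
}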
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSinOpTypeName = "ElementWiseSin"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSinOpFmt = "sin_%d"; - -LiteRtStatus SinOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSin) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kSinOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSinOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized sin op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h deleted file mode 100644 index e87296eeb10fb2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SIN_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SIN_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SinOpLegalization : public Legalization { - public: - SinOpLegalization() = default; - ~SinOpLegalization() = default; - using UniquePtr = std::unique_ptr<SinOpLegalization>; - static UniquePtr Create() { return std::make_unique<SinOpLegalization>(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SIN_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc deleted file mode 100644 index 6749bd654eb51c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h" - -#include -#include - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSliceOpTypeName = "StridedSlice"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSliceOpFmt = "slice_%d"; - -static constexpr int kSliceOpInputSize = 1; -static constexpr int kSliceOpOutputSize = 1; -static constexpr int kSliceOpParamSize = 1; -// QNN StridedSlice op packs "start", "end", and "stride" into a single tensor -// param "ranges". -static constexpr int kRangesParamArgSize = 3; -static constexpr int kRangesParamRank = 2; - -LiteRtStatus SliceOpLegalization::LegalizeOp(const Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSlice) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kSliceOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSliceOpTypeName.data(), dest)); - - // QNN strided slice op expects 1 input tensor. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kSliceOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); - - // QNN strided slice op expects 1 output tensor. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kSliceOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - const auto& src_input_tensor = op_ins.front(); - auto src_input_tensor_type = src_input_tensor.RankedTensorType(); - if (!src_input_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", - src_input_tensor_type.Error().Message().c_str()); - return src_input_tensor_type.Error().Status(); - } - - auto src_input_tensor_rank = src_input_tensor_type->Layout().Rank(); - - // Prepare qnn strided slice parameters. - - auto src_begin_indices = op_ins.at(1).WeightsData(); - if (!src_begin_indices) { - return src_begin_indices.Error().Status(); - } - - auto src_size_indices = op_ins.at(2).WeightsData(); - if (!src_size_indices) { - return src_size_indices.Error().Status(); - } - - // Check if src_begin_indices and src_size_indices are weights tensors. 
- if (src_begin_indices->empty() || src_size_indices->empty()) { - return kLiteRtStatusErrorInvalidLegalization; - } - - LITERT_STACK_ARRAY(int32_t, range_tensor_data, - src_input_tensor_rank* kRangesParamArgSize, - /*init value*/ 0); - for (int i = 0; i < src_input_tensor_rank; ++i) { - // Copy begin, end, and stride values from src_begin_indices and - // src_size_indices to range_tensor_data. Stride is always 1. - range_tensor_data[i * kRangesParamArgSize] = src_begin_indices->at(i); - range_tensor_data[i * kRangesParamArgSize + 1] = - src_begin_indices->at(i) + src_size_indices->at(i); - range_tensor_data[i * kRangesParamArgSize + 2] = 1; - } - - Qnn_ClientBuffer_t range_tensor_client_buf = BuildDefaultClientBuffer(); - range_tensor_client_buf.data = range_tensor_data; - range_tensor_client_buf.dataSize = - src_input_tensor_rank * kRangesParamArgSize * sizeof(int32_t); - - // Construct the const tensor "ranges". - Qnn_Tensor_t range_tensor = BuildDefaultTensor(); - graph_mapper.AssignTensorName(range_tensor); - range_tensor.v2.dataType = QNN_DATATYPE_INT_32; - range_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; - range_tensor.v2.rank = kRangesParamRank; - range_tensor.v2.dimensions = new uint32_t[kRangesParamRank]; - range_tensor.v2.dimensions[0] = src_input_tensor_rank; - range_tensor.v2.dimensions[1] = kRangesParamArgSize; - range_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - range_tensor.v2.clientBuf = range_tensor_client_buf; - range_tensor.v2.isDynamicDimensions = nullptr; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &range_tensor)); - - Qnn_Param_t range_param = BuildDefaultParam(); - range_param.paramType = QNN_PARAMTYPE_TENSOR; - range_param.name = "ranges"; - range_param.tensorParam = range_tensor; - - Qnn_Param_t strided_slice_params[] = {range_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = kSliceOpInputSize; - dest.v1.outputTensors = qnn_op_outs; - dest.v1.numOfOutputs = kSliceOpOutputSize; - dest.v1.numOfParams = kSliceOpParamSize; - dest.v1.params = strided_slice_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized slice op", ""); - - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h deleted file mode 100644 index 1430d1e1fa43ca..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
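To make the "ranges" packing concrete: tfl.slice carries per-dimension begin and size vectors, while QNN StridedSlice wants (start, stop, stride) triples with stop = begin + size and stride fixed at 1, laid out as a rank-2 tensor of shape {rank, 3}. A small sketch with plain vectors standing in for the weights tensors:

#include <cstdint>
#include <vector>

// Packs TFL slice (begin, size) into QNN StridedSlice "ranges" rows of
// {start, stop, stride}, row-major, one row per input dimension.
std::vector<int32_t> PackRanges(const std::vector<int32_t>& begin,
                                const std::vector<int32_t>& size) {
  std::vector<int32_t> ranges;
  for (size_t i = 0; i < begin.size(); ++i) {
    ranges.push_back(begin[i]);            // start
    ranges.push_back(begin[i] + size[i]);  // stop (exclusive)
    ranges.push_back(1);                   // stride is always 1 for tfl.slice
  }
  return ranges;
}

// Example: begin = {1, 0}, size = {2, 3} on a rank-2 input yields the
// {2, 3}-shaped "ranges" tensor {1, 3, 1, 0, 3, 1}.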
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SliceOpLegalization : public Legalization { - public: - SliceOpLegalization() = default; - ~SliceOpLegalization() = default; - using Ptr = std::unique_ptr<SliceOpLegalization>; - static Ptr Create() { return std::make_unique<SliceOpLegalization>(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc deleted file mode 100644 index c974e6f462fda2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSoftmaxOpTypeName = "Softmax"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSoftmaxOpFmt = "softmax_%d"; - -static constexpr int kSoftmaxOpInputSize = 1; -static constexpr int kSoftmaxOpOutputSize = 1; -static constexpr int kSoftmaxOpParamSize = 1; - -LiteRtStatus SoftmaxOpLegalization::LegalizeOp(const Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSoftmax) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kSoftmaxOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSoftmaxOpTypeName.data(), dest)); - - // QNN reduce softmax op expects 1 input tensor. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kSoftmaxOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); - - // QNN softmax op expects 1 output tensor. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kSoftmaxOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - // Prepare QNN reduce softmax parameters. - - // Extract beta option from softmax op. 
- float beta; - LITERT_RETURN_IF_ERROR(LiteRtGetSoftmaxBetaOption(src.Get(), &beta)); - Qnn_Param_t beta_param = BuildDefaultParam(); - beta_param.paramType = QNN_PARAMTYPE_SCALAR; - beta_param.name = "beta"; - Qnn_Scalar_t beta_scalar = QNN_SCALAR_INIT; - beta_scalar.dataType = QNN_DATATYPE_FLOAT_32; - beta_scalar.floatValue = beta; - beta_param.scalarParam = beta_scalar; - - Qnn_Param_t softmax_params[] = {beta_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = kSoftmaxOpInputSize; - dest.v1.outputTensors = qnn_op_outs; - dest.v1.numOfOutputs = kSoftmaxOpOutputSize; - dest.v1.numOfParams = kSoftmaxOpParamSize; - dest.v1.params = softmax_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized softmax op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h deleted file mode 100644 index b4ecb005003c91..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SoftmaxOpLegalization : public Legalization { - public: - SoftmaxOpLegalization() = default; - ~SoftmaxOpLegalization() = default; - using Ptr = std::unique_ptr<SoftmaxOpLegalization>; - static Ptr Create() { return std::make_unique<SoftmaxOpLegalization>(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names.
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc deleted file mode 100644 index 09ff1cbbc4dcfe..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSubOpTypeName = "ElementWiseSubtract"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSubOpFmt = "sub_%d"; - -LiteRtStatus SubOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSub) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kSubOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSubOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized sub op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h deleted file mode 100644 index 3f05f8e04a7d3e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SubOpLegalization : public Legalization { - public: - SubOpLegalization() = default; - ~SubOpLegalization() = default; - using Ptr = std::unique_ptr<SubOpLegalization>; - static Ptr Create() { return std::make_unique<SubOpLegalization>(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc deleted file mode 100644 index 40fe0c10f878c0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
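For reference on the softmax legalization above: tfl.softmax computes a softmax over beta-scaled logits, and the legalization forwards beta unchanged as the FLOAT_32 scalar param named "beta". A reference implementation of the semantics that scalar controls:

#include <cmath>
#include <vector>

// Reference softmax with the TFLite/QNN "beta" temperature parameter:
// y_i = exp(beta * x_i) / sum_j exp(beta * x_j).
std::vector<float> SoftmaxWithBeta(const std::vector<float>& x, float beta) {
  std::vector<float> y(x.size());
  float sum = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(beta * x[i]);
    sum += y[i];
  }
  for (float& v : y) v /= sum;
  return y;
}

// With beta == 1.0f this is plain softmax; larger beta sharpens the
// distribution, smaller beta flattens it.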
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h" - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSumOpTypeName = "ReduceSum"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSumOpFmt = "sum_%d"; - -static constexpr int kReduceSumOpInputSize = 1; -static constexpr int kReduceSumOpOutputSize = 1; -static constexpr int kReduceSumOpParamSize = 1; -static constexpr int kReduceSumOpParamRank = 1; - -LiteRtStatus SumOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSum) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kSumOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSumOpTypeName.data(), dest)); - - // QNN reduce sum op expects 1 input tensor. - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kReduceSumOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(src.Inputs().front().Get(), qnn_op_ins[0])); - - // QNN sum op expects 1 output tensor. - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kReduceSumOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR(graph_mapper.LegalizeAndRegister( - src.Outputs().front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(src.Outputs().front().Get(), qnn_op_outs[0])); - - // Prepare QNN reduce sum parameters. - const auto inputs = src.Inputs(); - const auto& src_axes = inputs.at(1); - - // Check if src_axes are weights tensors. 
- if (!src_axes.HasWeights()) { - LITERT_LOG(LITERT_ERROR, "Sum op axes are not weights tensors"); - return kLiteRtStatusErrorInvalidLegalization; - } - - auto src_axes_tensor_type = src_axes.RankedTensorType(); - if (!src_axes_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", - src_axes_tensor_type.Error().Message().c_str()); - return src_axes_tensor_type.Error().Status(); - } - - int32_t dest_axes_size = src_axes_tensor_type->Layout().Dimensions()[0]; - auto src_axes_data = src_axes.Weights().Bytes(); - Qnn_ClientBuffer_t axes_tensor_client_buf = BuildDefaultClientBuffer(); - axes_tensor_client_buf.data = (void*)src_axes_data.data(); - axes_tensor_client_buf.dataSize = src_axes_data.size(); - - // Extract keepdims option from sum op. - bool keep_dims; - LITERT_RETURN_IF_ERROR(LiteRtGetSumKeepDimsOption(src.Get(), &keep_dims)); - - // Construct the scalar "keep_dims" param; it is attached alongside "axes" - // below. - Qnn_Param_t keep_dims_param = BuildDefaultParam(); - keep_dims_param.paramType = QNN_PARAMTYPE_SCALAR; - keep_dims_param.name = "keep_dims"; - Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT; - keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8; - keep_dims_scalar.bool8Value = keep_dims; - keep_dims_param.scalarParam = keep_dims_scalar; - - // Construct the const tensor "axes". - Qnn_Tensor_t axes_tensor = BuildDefaultTensor(); - graph_mapper.AssignTensorName(axes_tensor); - axes_tensor.v2.dataType = QNN_DATATYPE_INT_32; - axes_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; - axes_tensor.v2.rank = kReduceSumOpParamRank; - axes_tensor.v2.dimensions = new uint32_t[kReduceSumOpParamRank]; - axes_tensor.v2.dimensions[0] = dest_axes_size; - axes_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - axes_tensor.v2.clientBuf = axes_tensor_client_buf; - axes_tensor.v2.isDynamicDimensions = nullptr; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &axes_tensor)); - - Qnn_Param_t axes_param = BuildDefaultParam(); - axes_param.paramType = QNN_PARAMTYPE_TENSOR; - axes_param.name = "axes"; - axes_param.tensorParam = axes_tensor; - - // Attach both the "axes" tensor param and the "keep_dims" scalar param. - Qnn_Param_t reduce_sum_params[] = {axes_param, keep_dims_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = kReduceSumOpInputSize; - dest.v1.outputTensors = qnn_op_outs; - dest.v1.numOfOutputs = kReduceSumOpOutputSize; - dest.v1.numOfParams = kReduceSumOpParamSize + 1; - dest.v1.params = reduce_sum_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized sum op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h deleted file mode 100644 index a50e946ad069b8..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUM_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUM_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SumOpLegalization : public Legalization { - public: - SumOpLegalization() = default; - ~SumOpLegalization() = default; - using Ptr = std::unique_ptr<SumOpLegalization>; - static Ptr Create() { return std::make_unique<SumOpLegalization>(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUM_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc deleted file mode 100644 index 121c564c1e95fd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
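The ReduceSum params above have simple shape semantics: "axes" is a static rank-1 INT_32 tensor naming the dimensions to reduce, and the "keep_dims" BOOL_8 scalar decides whether reduced dimensions survive as size 1. A QNN-independent sketch of the resulting output shape:

#include <algorithm>
#include <cstdint>
#include <vector>

// Example: dims = {2, 3}, axes = {1} -> {2, 1} with keep_dims, {2} without.
std::vector<int64_t> ReducedShape(const std::vector<int64_t>& dims,
                                  const std::vector<int32_t>& axes,
                                  bool keep_dims) {
  std::vector<int64_t> out;
  for (int64_t d = 0; d < static_cast<int64_t>(dims.size()); ++d) {
    const bool reduced =
        std::find(axes.begin(), axes.end(), static_cast<int32_t>(d)) !=
        axes.end();
    if (!reduced) {
      out.push_back(dims[d]);  // Untouched dimension passes through.
    } else if (keep_dims) {
      out.push_back(1);  // Reduced dimension is kept as size 1.
    }  // Otherwise the reduced dimension is dropped entirely.
  }
  return out;
}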
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnTanhOpTypeName = "Tanh"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kTanhOpFmt = "tanh_%d"; - -LiteRtStatus TanhOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflTanh) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kTanhOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnTanhOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized tanh op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h deleted file mode 100644 index 486e321ae8e2d3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class TanhOpLegalization : public Legalization { - public: - TanhOpLegalization() = default; - ~TanhOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc deleted file mode 100644 index 487ecce2e66d79..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
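The tanh legalization above is the smallest instance of the pattern: match the op code, stamp a unique node name, and let LegalizeSimpleOp wire the tensors. A new element-wise op could be added the same way; the sketch below is illustrative only (ExpOpLegalization, kLiteRtOpCodeTflExp, and the "Exp" QNN type name are assumptions, not part of this patch).

static constexpr absl::string_view kQnnExpOpTypeName = "Exp";
static constexpr absl::string_view kExpOpFmt = "exp_%d";

LiteRtStatus ExpOpLegalization::LegalizeOp(const litert::Op& src,
                                           Qnn_OpConfig_t& dest,
                                           GraphMapper& graph_mapper) {
  // Bail out so the next legalization in the chain can try this op.
  if (src.Code() != kLiteRtOpCodeTflExp) {
    return kLiteRtStatusLegalizeNoMatch;
  }
  // Unique node name, mirroring the tanh/transpose counters.
  std::string op_name = absl::StrFormat(kExpOpFmt, op_counter_++);
  LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(),
                                   kDefaultQnnOpPackageName.data(),
                                   kQnnExpOpTypeName.data(), dest));
  // Exp is 1:1 on tensors and takes no params, so the simple path suffices.
  LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper));
  return kLiteRtStatusOk;
}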
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h" - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnTransposeOpTypeName = "Transpose"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kTransposeOpFmt = "transpose_%d"; - -static constexpr int kTransposeOpInputSize = 1; -static constexpr int kTransposeOpOutputSize = 1; -static constexpr int kTransposeOpParamSize = 1; -static constexpr int kTransposeOpParamRank = 1; - -LiteRtStatus TransposeOpLegalization::LegalizeOp(const Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflTranspose) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kTransposeOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnTransposeOpTypeName.data(), dest)); - - // QNN transpose op expects 1 input tensor. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kTransposeOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); - - // QNN transpose op expects 1 output tensor. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kTransposeOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - // Prepare QNN transpose parameters. - auto perm = Tensor(op_ins.at(1).Get()); - - // Check if src_axes are weights tensors. - if (!perm.HasWeights()) { - return kLiteRtStatusErrorInvalidLegalization; - } - auto perm_data = perm.Weights().Bytes(); - int32_t dest_axes_size = perm_data.size(); - Qnn_ClientBuffer_t perm_tensor_client_buf = BuildDefaultClientBuffer(); - perm_tensor_client_buf.data = (void*)perm_data.data(); - perm_tensor_client_buf.dataSize = dest_axes_size; - - // Construct the const tensor "perm". 
- Qnn_Tensor_t perm_tensor = BuildDefaultTensor(); - graph_mapper.AssignTensorName(perm_tensor); - perm_tensor.v2.dataType = QNN_DATATYPE_INT_32; - perm_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; - perm_tensor.v2.rank = kTransposeOpParamRank; - perm_tensor.v2.dimensions = new uint32_t[kTransposeOpParamRank]; - // The dimension is the number of int32 elements in the perm vector, not - // its byte size. - perm_tensor.v2.dimensions[0] = perm_size_bytes / sizeof(int32_t); - perm_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - perm_tensor.v2.clientBuf = perm_tensor_client_buf; - perm_tensor.v2.isDynamicDimensions = nullptr; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &perm_tensor)); - - Qnn_Param_t perm_param = BuildDefaultParam(); - perm_param.paramType = QNN_PARAMTYPE_TENSOR; - perm_param.name = "perm"; - perm_param.tensorParam = perm_tensor; - - Qnn_Param_t transpose_params[] = {perm_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = kTransposeOpInputSize; - dest.v1.outputTensors = qnn_op_outs; - dest.v1.numOfOutputs = kTransposeOpOutputSize; - dest.v1.numOfParams = kTransposeOpParamSize; - dest.v1.params = transpose_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized transpose op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h deleted file mode 100644 index 39d7fc645c8e80..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class TransposeOpLegalization : public Legalization { - public: - TransposeOpLegalization() = default; - ~TransposeOpLegalization() = default; - using Ptr = std::unique_ptr<TransposeOpLegalization>; - static Ptr Create() { return std::make_unique<TransposeOpLegalization>(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names.
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc deleted file mode 100644 index 5cd0646a907fc0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -using ::litert::internal::Dump; -using ::litert::internal::DumpOptions; - -// Dump source Op details. -void DumpLegalization(const LiteRtOpT& op) { - std::ostringstream dump; - // TODO Make dump tools part of stable api. - Dump(op, dump); - DumpOptions(op, dump); - std::string s = dump.str(); - LITERT_LOG(LITERT_INFO, "%s", s.data()); -} - -LiteRtStatus LegalizeSimpleOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - DumpLegalization(*src.Get()); - // Look up op input tensors in scope. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT); - - Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins; - for (const auto& op_in : op_ins) { - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in)); - ++cur_qnn_op_in; - } - - // Legalize op outputs and update scope. 
- - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, op_outs.size(), - QNN_TENSOR_INIT); - - Qnn_Tensor_t* cur_qnn_op_out = qnn_op_outs; - for (const auto& op_out : op_outs) { - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_out.Get(), *cur_qnn_op_out)); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_out.Get(), *cur_qnn_op_out)); - ++cur_qnn_op_out; - } - dest.v1.numOfInputs = op_ins.size(); - dest.v1.inputTensors = qnn_op_ins; - - dest.v1.numOfOutputs = op_outs.size(); - dest.v1.outputTensors = qnn_op_outs; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - return kLiteRtStatusOk; -} - -LiteRtStatus BuildAndRegisterQnnNativeTensor(Qnn_DataType_t param_data_type, - uint32_t rank, uint32_t* dims, - GraphMapper& graph_mapper, - Qnn_Tensor_t& tensor) { - graph_mapper.AssignTensorName(tensor); - tensor.v2.dataType = param_data_type; - tensor.v2.type = QNN_TENSOR_TYPE_NATIVE; - tensor.v2.rank = rank; - tensor.v2.dimensions = dims; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &tensor)); - return kLiteRtStatusOk; -} - -LiteRtStatus BuildAndRegisterQnnOp(uint32_t input_size, Qnn_Tensor_t* op_ins, - uint32_t output_size, Qnn_Tensor_t* op_outs, - Qnn_OpConfig_t& op, uint32_t param_size, - Qnn_Param_t* params, - GraphMapper& graph_mapper) { - op.v1.numOfInputs = input_size; - op.v1.inputTensors = op_ins; - op.v1.numOfOutputs = output_size; - op.v1.outputTensors = op_outs; - op.v1.numOfParams = param_size; - op.v1.params = params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), op)); - - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h deleted file mode 100644 index fb80708537b7e9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
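Taken together, the helpers above cover the decompositions that LegalizeSimpleOp cannot: a source op lowered to several QNN ops joined by an intermediate native tensor. A minimal sketch, assuming qnn_op_ins, qnn_op_outs, first_op, and second_op have been prepared the way the legalizations above prepare them (all names and the intermediate shape are illustrative):

// Intermediate tensor produced by the first QNN op, consumed by the second.
Qnn_Tensor_t intermediate = BuildDefaultTensor();
uint32_t intermediate_dims[] = {1, 128};  // assumed shape for illustration
LITERT_RETURN_IF_ERROR(BuildAndRegisterQnnNativeTensor(
    QNN_DATATYPE_FLOAT_32, /*rank=*/2, intermediate_dims, graph_mapper,
    intermediate));

// First op: source inputs -> intermediate.
LITERT_RETURN_IF_ERROR(BuildAndRegisterQnnOp(
    /*input_size=*/1, qnn_op_ins, /*output_size=*/1, &intermediate, first_op,
    /*param_size=*/0, /*params=*/nullptr, graph_mapper));

// Second op: intermediate -> registered outputs of the source op.
LITERT_RETURN_IF_ERROR(BuildAndRegisterQnnOp(
    /*input_size=*/1, &intermediate, /*output_size=*/1, qnn_op_outs,
    second_op, /*param_size=*/0, /*params=*/nullptr, graph_mapper));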
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_ - -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" - -namespace litert::qnn { - -// Use this function to legalize a LiteRtOp to a Qnn Op when: -// 1. Source input/output tensors and destination input/output tensors are -// mapped 1:1. -// 2. Assigning params to the destination OP does not depend on an input -// tensor of the source OP. -LiteRtStatus LegalizeSimpleOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - -// Dump source Op details. -void DumpLegalization(const LiteRtOpT& op); - -// Build and register a QNN native tensor in the QNN graph. -LiteRtStatus BuildAndRegisterQnnNativeTensor(Qnn_DataType_t param_data_type, - uint32_t rank, uint32_t* dims, - GraphMapper& graph_mapper, - Qnn_Tensor_t& tensor); - -// Build and register a QNN op in the QNN graph. -LiteRtStatus BuildAndRegisterQnnOp(uint32_t input_size, Qnn_Tensor_t* op_ins, - uint32_t output_size, Qnn_Tensor_t* op_outs, - Qnn_OpConfig_t& op, uint32_t param_size, - Qnn_Param_t* params, - GraphMapper& graph_mapper); - -// Build and register a QNN tensor param in the QNN graph. -template <typename T> -LiteRtStatus BuildQnnTesnorParam(T* param_data, uint32_t* param_dims, - Qnn_DataType_t param_data_type, - uint32_t param_rank, const char* param_name, - GraphMapper& graph_mapper, - Qnn_Param_t& param) { - // Build ClientBuffer for the param tensor. - Qnn_ClientBuffer_t tensor_client_buf = BuildDefaultClientBuffer(); - tensor_client_buf.data = param_data; - // Byte size of the param data: element count times element size - // (sizeof(param_data) would only measure the pointer itself). - uint32_t num_elements = 1; - for (uint32_t i = 0; i < param_rank; ++i) { - num_elements *= param_dims[i]; - } - tensor_client_buf.dataSize = num_elements * sizeof(T); - - // Build QNN param tensor. - Qnn_Tensor_t param_tensor = BuildDefaultTensor(); - graph_mapper.AssignTensorName(param_tensor); - param_tensor.v2.dataType = param_data_type; - param_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; - param_tensor.v2.rank = param_rank; - param_tensor.v2.dimensions = param_dims; - param_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - param_tensor.v2.clientBuf = tensor_client_buf; - - // Register param tensor in QNN graph. - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &param_tensor)); - param.paramType = QNN_PARAMTYPE_TENSOR; - param.name = param_name; - param.tensorParam = param_tensor; - return kLiteRtStatusOk; -} - -template <typename T> -LiteRtStatus BuildQnnScalarParam(T& param_data, Qnn_DataType_t param_data_type, - const char* param_name, - GraphMapper& graph_mapper, - Qnn_Param_t& param) { - // Build QNN scalar. - Qnn_Scalar_t scalar = QNN_SCALAR_INIT; - scalar.dataType = param_data_type; - - // Build QNN scalar param.
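- // Illustrative usage (assumed, not from this header): a reduce op's - // keep_dims flag could be attached with - //   Qnn_Param_t keep_dims_param; - //   BuildQnnScalarParam(keep_dims, QNN_DATATYPE_BOOL_8, "keep_dims", - //                       graph_mapper, keep_dims_param); - // the switch below then routes the value into scalar.bool8Value.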
- switch (param_data_type) { - case QNN_DATATYPE_BOOL_8: - scalar.bool8Value = param_data; - break; - case QNN_DATATYPE_UINT_32: - scalar.uint32Value = param_data; - break; - case QNN_DATATYPE_INT_32: - scalar.int32Value = param_data; - break; - default: - return kLiteRtStatusErrorUnsupported; - } - param.paramType = QNN_PARAMTYPE_SCALAR; - param.name = param_name; - param.scalarParam = scalar; - return kLiteRtStatusOk; -} - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc deleted file mode 100644 index a26a91a030d23b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc +++ /dev/null @@ -1,423 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -using ::litert::qnn::QnnManager; -using LiteRtBufferId = uint32_t; -using LiteRtContextHandleIdx = uint32_t; -using WeightSharingMap = - absl::flat_hash_map; - -// -// Configurations -// - -namespace { - -constexpr char kPluginManufacturer[] = "Qualcomm"; -constexpr LiteRtParamIndex kDefaultPartitionIndex = 0; - -// clang-format off -constexpr std::pair kPluginSocModels[] = { - {"V68", QNN_HTP_DEVICE_ARCH_V68}, - {"V69", QNN_HTP_DEVICE_ARCH_V69}, - {"V73", QNN_HTP_DEVICE_ARCH_V73}, - {"V75", QNN_HTP_DEVICE_ARCH_V75}, - {"V79", QNN_HTP_DEVICE_ARCH_V79}, -}; - -constexpr const char* kSocModelsSupportsWeightSharing[] = { - "V73", - "V75", - "V79", -}; -// clang-format on - -static constexpr absl::string_view kEntryPointNameFmt = "qnn_partition_%d"; - -constexpr auto kNumPluginSocModels = - 
sizeof(kPluginSocModels) / sizeof(kPluginSocModels[0]); - -std::optional FindSocModel( - absl::string_view soc_model_name) { - std::optional soc_model; - for (auto i = 0; i < kNumPluginSocModels; ++i) { - if (soc_model_name == kPluginSocModels[i].first) { - soc_model = kPluginSocModels[i].second; - break; - } - } - return soc_model; -} - -bool IsWeightSharingSupported(absl::string_view soc_model_name) { - return std::find(std::begin(kSocModelsSupportsWeightSharing), - std::end(kSocModelsSupportsWeightSharing), - soc_model_name) != std::end(kSocModelsSupportsWeightSharing); -} - -} // namespace - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { - if (api_version == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - api_version->major = LITERT_API_VERSION_MAJOR; - api_version->minor = LITERT_API_VERSION_MINOR; - api_version->patch = LITERT_API_VERSION_PATCH; - return kLiteRtStatusOk; -} - -const char* LiteRtGetCompilerPluginSocManufacturer() { - return kPluginManufacturer; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( - LiteRtCompilerPlugin compiler_plugin, - LiteRtHwAccelerators* supported_hardware) { - if (!compiler_plugin || !supported_hardware) { - return kLiteRtStatusErrorInvalidArgument; - } - *supported_hardware = kLiteRtHwAcceleratorNpu; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( - LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex* num_supported_soc_models) { - if (!compiler_plugin || !num_supported_soc_models) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_supported_soc_models = kNumPluginSocModels; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, - const char** soc_model_name) { - if (!compiler_plugin || !soc_model_name) { - return kLiteRtStatusErrorInvalidArgument; - } else if (soc_model_idx < 0 || soc_model_idx >= kNumPluginSocModels) { - return kLiteRtStatusErrorInvalidArgument; - } - *soc_model_name = kPluginSocModels[soc_model_idx].first; - return kLiteRtStatusOk; -} - -// -// Compiled Result Definition -// - -struct LiteRtCompiledResultT { - std::vector> context_bin; - std::vector graph_names; -}; - -LiteRtStatus LiteRtGetCompiledResultByteCode( - LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx, - const void** byte_code, size_t* byte_code_size) { - if (!compiled_result || !byte_code || !byte_code_size) { - return kLiteRtStatusErrorInvalidArgument; - } - - *byte_code = compiled_result->context_bin[byte_code_idx].data(); - *byte_code_size = compiled_result->context_bin[byte_code_idx].size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledResultCallInfo( - LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, - const void** call_info, size_t* call_info_size, - LiteRtParamIndex* byte_code_idx) { - if (!compiled_result || !call_info || !call_info_size) { - return kLiteRtStatusErrorInvalidArgument; - } else if (call_idx >= compiled_result->graph_names.size()) { - return kLiteRtStatusErrorIndexOOB; - } - - *call_info = compiled_result->graph_names.at(call_idx).data(); - *call_info_size = compiled_result->graph_names.at(call_idx).size(); - *byte_code_idx = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompiledResultCalls( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { - if (!compiled_result || !num_calls) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_calls = 
compiled_result->graph_names.size(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { - delete compiled_result; -} - -LiteRtStatus LiteRtCompiledResultNumByteCodeModules( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_byte_code) { - *num_byte_code = compiled_result->context_bin.size(); - return kLiteRtStatusOk; -} - -// -// Plugin Definition -// - -// Plugins can hold state. -struct LiteRtCompilerPluginT { - // A "key-only" flag will have an empty string as the value. - using Flag = std::pair; - std::vector flags; -}; - -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values) { - auto& flags = compiler_plugin->flags; - flags.resize(num_flags); - for (int i = 0; i < num_flags; ++i) { - auto& flag = flags[i]; - flag.first = std::string(keys[i]); - flag.second = std::string(values[i]); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { - auto* plugin = new LiteRtCompilerPluginT; - *compiler_plugin = plugin; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { - delete compiler_plugin; -} - -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - ::litert::Subgraph graph(subgraph); - - auto backend_configs = QnnManager::DefaultBackendConfigs(); - // TODO: pass soc_model as parameter - auto qnn_manager = QnnManager::Create(backend_configs, std::nullopt, - {QNN_HTP_DEVICE_ARCH_V75}); - if (!qnn_manager) { - LITERT_LOG(LITERT_ERROR, "%s", qnn_manager.Error().Message().data()); - return qnn_manager.Error().Status(); - } - LITERT_LOG(LITERT_INFO, "%s", "QNN manager created"); - - for (const auto& op : graph.Ops()) { - // default constructed, won't add tensor to QNN - ::qnn::TensorPool tensor_pool; - std::vector<::qnn::TensorWrapperRef> input_tensors; - for (const auto& input : op.Inputs()) { - ::qnn::TensorWrapper* res{nullptr}; - LITERT_RETURN_IF_ERROR( - litert::qnn::ConvertTensor(input, tensor_pool, res)); - input_tensors.emplace_back(*res); - } - - std::vector<::qnn::TensorWrapperRef> output_tensors; - for (const auto& output : op.Outputs()) { - ::qnn::TensorWrapper* res{nullptr}; - LITERT_RETURN_IF_ERROR( - litert::qnn::ConvertTensor(output, tensor_pool, res)); - output_tensors.emplace_back(*res); - } - - std::vector<::qnn::OpWrapper> op_wrappers; - LITERT_RETURN_IF_ERROR(litert::qnn::ConvertOp( - op, tensor_pool, input_tensors, output_tensors, op_wrappers)); - tensor_pool.ForEach([](::qnn::TensorWrapper& tensor_wrapper) { - // TODO(chunhsue): Use compile interface to get useQInt16AsQUint16. - constexpr bool useQInt16AsQUint16 = true; - if constexpr (useQInt16AsQUint16) { - tensor_wrapper.ConvertQint16ToQuint16(); - } - }); - // Empty op_wrappers means the op is not supported by QNN. - if (op_wrappers.empty()) { - continue; - } - if (std::all_of( - op_wrappers.begin(), op_wrappers.end(), - [&qnn_manager](::qnn::OpWrapper& op_wrapper) -> bool { - return kLiteRtStatusOk == - (*qnn_manager)->ValidateOp(op_wrapper.GetOpConfig()); - })) { - LITERT_RETURN_IF_ERROR( - // Use default partition index if vendor doesn't support multiple - // partitions. 
- LiteRtPushOp(selected_ops, op.Get(), kDefaultPartitionIndex)); - } - } - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - - LITERT_LOG(LITERT_INFO, - "Starting QNN Compilation for %d subgraphs, soc_model=%s", - num_partitions, soc_model); - - auto opt_soc_model = soc_model ? FindSocModel(soc_model) : std::nullopt; - if (opt_soc_model) { - LITERT_LOG(LITERT_INFO, "Compiling QNN architecture: %d", *opt_soc_model); - } else if (soc_model) { - LITERT_LOG(LITERT_ERROR, "Unexpected SoC model: %s", soc_model); - return kLiteRtStatusErrorInvalidArgument; - } - - auto result = std::make_unique<LiteRtCompiledResultT>(); - // Prepare one context binary per partition, since each partition is a - // separate subgraph that maps to a single Dispatch Op in the compiled - // model. - result->context_bin.resize(num_partitions); - - // Initialize SDK and load qnn shared libraries. - LITERT_LOG(LITERT_INFO, "%s", "Creating QNN manager"); - auto backend_configs = QnnManager::DefaultBackendConfigs(); - auto qnn_manager = QnnManager::Create( - backend_configs, /*shared_library_dir=*/std::nullopt, opt_soc_model); - if (!qnn_manager) { - LITERT_LOG(LITERT_ERROR, "%s", qnn_manager.Error().Message().c_str()); - return qnn_manager.Error().Status(); - } - LITERT_LOG(LITERT_INFO, "%s", "QNN manager created"); - - // Map of LiteRt buffer id to context handle index. - // This map memorizes the last context handle index in which a weight was - // registered. - WeightSharingMap weight_sharing_map; - LiteRtContextHandleIdx next_context_handle_idx = 0; - - std::vector context_handles; - - // Compile each partition (subgraph) individually. - for (int partition_idx = 0; partition_idx < num_partitions; ++partition_idx) { - LiteRtContextHandleIdx context_handle_idx = next_context_handle_idx; - uint64_t largest_weight_size = 0; - // Check all weights in this subgraph, see if any of them were previously - // seen and added to existing qnn context, use the largest weight size to - // determine which context to use. - for (const auto& op : model.Subgraph(partition_idx)->Ops()) { - for (const auto& input : op.Inputs()) { - if (input.IsConstant()) { - auto buffer_id = input.Weights().Get()->GetBufferId(); - auto it = weight_sharing_map.find(buffer_id); - if (it != weight_sharing_map.end()) { - if (input.Weights().Get()->Buffer().Size() >= largest_weight_size) { - context_handle_idx = it->second; - largest_weight_size = input.Weights().Get()->Buffer().Size(); - } - } - } - } - } - // If we didn't find an existing context handle for this subgraph, create a - // new one. - if (context_handle_idx == next_context_handle_idx) { - // Initialize context. - LITERT_LOG(LITERT_INFO, "%s", "Creating context handle"); - // We enable weight sharing by default; this could lead to issues when - // supporting legacy SoCs. - // TODO: use option to control weight sharing.
- auto context_configs = QnnManager::WeightSharingContextConfigs(); - if (!IsWeightSharingSupported(soc_model)) { - context_configs = QnnManager::DefaultContextConfigs(); - } - auto context_handle = - (*qnn_manager)->CreateContextHandle(context_configs); - if (!context_handle) { - LITERT_LOG(LITERT_ERROR, "%s", - context_handle.Error().Message().c_str()); - return context_handle.Error().Status(); - } - context_handles.push_back(std::move(context_handle.Value())); - LITERT_LOG(LITERT_INFO, "%s", "Context handle created"); - ++next_context_handle_idx; - } - // Set context handle index for all weight buffers in this subgraph. - for (const auto& op : model.Subgraph(partition_idx)->Ops()) { - for (const auto& input : op.Inputs()) { - if (input.IsConstant()) { - auto buffer_id = input.Weights().Get()->GetBufferId(); - weight_sharing_map[buffer_id] = context_handle_idx; - } - } - } - - // Compose graphs. - LITERT_LOG(LITERT_INFO, "%s", "Composing graph"); - std::string& entry_point_name = result->graph_names.emplace_back(); - entry_point_name = absl::StrFormat(kEntryPointNameFmt, partition_idx); - LiteRtSubgraph partition = model.Subgraph(partition_idx)->Get(); - LITERT_RETURN_IF_ERROR(litert::qnn::ComposeGraph( - **qnn_manager, context_handles[context_handle_idx].get(), partition, - entry_point_name)); - LITERT_LOG(LITERT_INFO, "%s", "Graph composed"); - } - - // Generate context binary. - result->context_bin.resize(next_context_handle_idx); - for (int i = 0; i < next_context_handle_idx; ++i) { - LITERT_LOG(LITERT_INFO, "%s", "Generating context binary"); - LITERT_RETURN_IF_ERROR((*qnn_manager) - ->GenerateContextBinary(context_handles[i].get(), - result->context_bin[i])); - LITERT_LOG(LITERT_INFO, "Context binary %d generated", i); - } - *compiled_result = result.release(); - - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc deleted file mode 100644 index 2b6016a0578d15..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc +++ /dev/null @@ -1,399 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
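The context-bucketing rule in LiteRtCompilerPluginCompile above is easy to lose in the surrounding plumbing. A standalone sketch of just that rule, with plain types standing in for the LiteRt/QNN handles (illustrative only, not part of this patch):

#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

using BufferId = uint32_t;
using ContextIdx = uint32_t;

// Picks the context a partition should join: the context that already holds
// its largest previously-seen constant buffer, or next_idx if none of this
// partition's constant buffers has been seen before (i.e. create a new one).
ContextIdx PickContext(
    const std::vector<std::pair<BufferId, uint64_t>>& constant_buffers,
    const std::unordered_map<BufferId, ContextIdx>& last_seen_in,
    ContextIdx next_idx) {
  ContextIdx chosen = next_idx;
  uint64_t largest = 0;
  for (const auto& [id, byte_size] : constant_buffers) {
    auto it = last_seen_in.find(id);
    if (it != last_seen_in.end() && byte_size >= largest) {
      chosen = it->second;  // share the context of the biggest shared weight
      largest = byte_size;
    }
  }
  return chosen;  // == next_idx means no shared weight was found
}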
- -#include -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/test_models.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h" - -namespace litert { -namespace { - -using ::testing::Values; - -// clang-format off -// TODO: Add support and uncomment these models. -const auto kSupportedOps = - Values( - "rms_norm_composite.tflite", - "simple_add_op.tflite", - "simple_div_op.tflite", - "simple_mul_op.tflite", - "simple_rsqrt_op.tflite", - "simple_slice_op.tflite", - "simple_sub_op.tflite", - "simple_sum_op.tflite", - "simple_tanh_op.tflite", - "simple_reshape_op.tflite", - "simple_batch_matmul_op.tflite", - "rms_norm.tflite", - "simple_concatenation_op.tflite", - "simple_softmax_op.tflite", - "simple_cast_op.tflite", - "simple_transpose_op.tflite", - "simple_sin_op.tflite", - "simple_cos_op.tflite", - "simple_select_op.tflite", - "simple_select_v2_op.tflite", - "simple_fully_connected_op.tflite", - "fully_connected_3d.tflite", - "simple_embedding_lookup_op.tflite", - "simple_logical_and_op.tflite", - "simple_less_op.tflite", - "simple_greater_op.tflite", - "simple_gelu_op.tflite", - "simple_dynamic_update_slice_op.tflite", - "simple_pack_op.tflite", - "simple_gather_op.tflite", - "simple_mean_op.tflite", - "simple_split_op.tflite", - "simple_average_poll_2d.tflite", - "simple_conv_2d_op.tflite", - "simple_depth_to_space_op.tflite", - "simple_depthwise_conv_2d_op.tflite", - "simple_hard_swish_op.tflite", - "simple_leaky_relu_op.tflite", - "simple_resize_bilinear_op.tflite", - "simple_space_to_depth_op.tflite", - "simple_resize_nearest_neighbor_op.tflite", - "simple_relu_op.tflite", - kFeedForwardModel, - kKeyEinsumModel, - kQueryEinsumModel, - kValueEinsumModel, - kAttnVecEinsumModel, - kROPEModel, - kLookUpROPEModel, - kRMSNormModel, - kSDPAModel, - kAttentionModel, - kTransformerBlockModel, - kQSimpleMul16x16Model, - kQMulAdd16x16Model, - kQQueryEinsum16x8Model, - kQKeyEinsum16x8Model, - kQVauleEinsum16x8Model, - kQAttnVecEinsum16x8Model - ); - -const auto kSupportedSocModels = Values( - "V68", - "V69", - "V73", - "V75", - "V79" -); -// clang-format on - -TEST(TestQnnPlugin, GetConfigInfo) { - EXPECT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), "Qualcomm"); - - auto plugin = CreatePlugin(); - - LiteRtParamIndex num_supported_soc_models; - LITERT_ASSERT_OK(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models)); - ASSERT_EQ(num_supported_soc_models, 5); - - const char* config_id; - LITERT_ASSERT_OK( - LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, &config_id)); - EXPECT_STREQ(config_id, 
"V68"); -} - -TEST(TestQnnPlugin, PartitionMulOps) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("one_mul.tflite"); - - LiteRtOpListT selected_op_list; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, model.Subgraph(0)->Get(), - &selected_op_list)); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 1); - EXPECT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflMul); -} - -TEST(TestQnnPlugin, CompileMulSubgraph) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("one_mul.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - - absl::string_view byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, /*call_idx=*/0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - ASSERT_EQ("qnn_partition_0", op_data_string); - - LiteRtDestroyCompiledResult(compiled); -} - -TEST(TestQnnPlugin, ShareContextBinary) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("cst_multi_subgraph.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled)); - uint64_t num_byte_code; - LITERT_ASSERT_OK( - LiteRtCompiledResultNumByteCodeModules(compiled, &num_byte_code)); - ASSERT_EQ(num_byte_code, 1); - - LiteRtDestroyCompiledResult(compiled); -} - -TEST(TestQnnPlugin, NotShareContextBinary) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("multi_subgraph.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled)); - uint64_t num_byte_code; - LITERT_ASSERT_OK( - LiteRtCompiledResultNumByteCodeModules(compiled, &num_byte_code)); - ASSERT_EQ(num_byte_code, 3); - - LiteRtDestroyCompiledResult(compiled); -} - -TEST(TestLegalization, QuantizeOpLegalizedToCastOp) { - static constexpr absl::string_view kQnnOpName = "Cast"; - static constexpr int kSUFixed8OffsetDiff = 128; - const auto input_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/0); - const auto output_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/kSUFixed8OffsetDiff); - LiteRtOpT quantize_op; - LiteRtTensorT input_tensor; - LiteRtTensorT output_tensor; - // Set quantization params, tensor type for input and output tensors. 
- input_tensor.SetQarams(input_quantization_params); - TensorType input_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeInt8, {1, 1}); - input_tensor.SetType(input_tensor_type); - output_tensor.SetQarams(output_quantization_params); - TensorType output_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeUInt8, {1, 1}); - output_tensor.SetType(output_tensor_type); - quantize_op.Inputs().push_back(&input_tensor); - quantize_op.Outputs().push_back(&output_tensor); - quantize_op.SetOpCode(kLiteRtOpCodeTflQuantize); - - qnn::QuantizeOpLegalization legalization; - Qnn_OpConfig_t legalized_qnn_op = qnn::BuildDefaultOp(); - litert::Op litert_quantize_op(&quantize_op); - LITERT_ASSERT_OK( - legalization.ConfigureQnnOp(litert_quantize_op, legalized_qnn_op)); - absl::string_view qnn_op_name(legalized_qnn_op.v1.typeName); - EXPECT_EQ(qnn_op_name, kQnnOpName); -} - -TEST(TestLegalization, QuantizeOpLegalizedToConvertOp) { - static constexpr absl::string_view kQnnOpName = "Convert"; - static constexpr int kSUFixed8OffsetDiff = 0; - const auto input_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/0); - const auto output_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/kSUFixed8OffsetDiff); - LiteRtOpT quantize_op; - LiteRtTensorT input_tensor; - LiteRtTensorT output_tensor; - // Set quantization params, tensor type for input and output tensors. - input_tensor.SetQarams(input_quantization_params); - TensorType input_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeInt8, {1, 1}); - input_tensor.SetType(input_tensor_type); - output_tensor.SetQarams(output_quantization_params); - TensorType output_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeUInt8, {1, 1}); - output_tensor.SetType(output_tensor_type); - quantize_op.Inputs().push_back(&input_tensor); - quantize_op.Outputs().push_back(&output_tensor); - quantize_op.SetOpCode(kLiteRtOpCodeTflQuantize); - - qnn::QuantizeOpLegalization legalization; - Qnn_OpConfig_t legalized_qnn_op = qnn::BuildDefaultOp(); - litert::Op litert_quantize_op(&quantize_op); - LITERT_ASSERT_OK( - legalization.ConfigureQnnOp(litert_quantize_op, legalized_qnn_op)); - absl::string_view qnn_op_name(legalized_qnn_op.v1.typeName); - EXPECT_EQ(qnn_op_name, kQnnOpName); -} - -TEST(TestLegalization, QuantizeOpLegalizedToQuantizeOp) { - static constexpr absl::string_view kQnnOpName = "Quantize"; - const auto output_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/0); - LiteRtOpT quantize_op; - LiteRtTensorT input_tensor; - LiteRtTensorT output_tensor; - // Set quantization params, tensor type for input and output tensors. 
- TensorType input_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeFloat32, {1, 1}); - input_tensor.SetType(input_tensor_type); - output_tensor.SetQarams(output_quantization_params); - TensorType output_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeInt16, {1, 1}); - output_tensor.SetType(output_tensor_type); - quantize_op.Inputs().push_back(&input_tensor); - quantize_op.Outputs().push_back(&output_tensor); - quantize_op.SetOpCode(kLiteRtOpCodeTflQuantize); - - qnn::QuantizeOpLegalization legalization; - Qnn_OpConfig_t legalized_qnn_op = qnn::BuildDefaultOp(); - litert::Op litert_quantize_op(&quantize_op); - LITERT_ASSERT_OK( - legalization.ConfigureQnnOp(litert_quantize_op, legalized_qnn_op)); - absl::string_view qnn_op_name(legalized_qnn_op.v1.typeName); - EXPECT_EQ(qnn_op_name, kQnnOpName); -} - -class QnnPluginSupportedSocCompilationTest - : public ::testing::TestWithParam<std::string> {}; - -TEST_P(QnnPluginSupportedSocCompilationTest, CompileMulSubgraph) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("one_mul.tflite"); - auto soc_model = GetParam(); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK(LiteRtCompilerPluginCompile(plugin.get(), soc_model.c_str(), - model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - - absl::string_view byte_code_string(reinterpret_cast<const char*>(byte_code), - byte_code_size); - ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, /*call_idx=*/0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_string(reinterpret_cast<const char*>(op_data), - op_data_size); - ASSERT_EQ("qnn_partition_0", op_data_string); - - LiteRtDestroyCompiledResult(compiled); -} - -INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, QnnPluginSupportedSocCompilationTest, - kSupportedSocModels); - -class QnnPluginOpValidationTest : public ::testing::TestWithParam<std::string> { -}; - -TEST_P(QnnPluginOpValidationTest, SupportedOpsTest) { - LITERT_LOG(LITERT_INFO, "Validating TFLite model: %s", GetParam().c_str()); - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel(GetParam()); - - const auto subgraph = model.MainSubgraph(); - LiteRtSubgraph litert_subgraph = subgraph->Get(); - - LiteRtOpListT selected_ops; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, litert_subgraph, &selected_ops)); - - EXPECT_EQ(selected_ops.Values().size(), litert_subgraph->Ops().size()); -} - -INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, QnnPluginOpValidationTest, - kSupportedOps); - -class QnnPluginOpCompatibilityTest - : public ::testing::TestWithParam<std::string> {}; - -TEST_P(QnnPluginOpCompatibilityTest, SupportedOpsTest) { - LITERT_LOG(LITERT_INFO, "Testing TFLite model: %s", GetParam().c_str()); - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel(GetParam()); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - - absl::string_view byte_code_string(reinterpret_cast<const char*>(byte_code), - byte_code_size); - ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex
byte_code_idx; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, /*call_idx=*/0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - ASSERT_EQ("qnn_partition_0", op_data_string); - - LiteRtDestroyCompiledResult(compiled); -} - -INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, QnnPluginOpCompatibilityTest, - kSupportedOps); - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc deleted file mode 100644 index 1c1a7fca31408d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc +++ /dev/null @@ -1,758 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h" -#include 
"tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -using ::litert::internal::Dump; - -LiteRtStatus ConvertPaddingType(const uint32_t litert_padding, - ::qnn::PaddingType& qnn_padding) { - switch (litert_padding) { - case 0: { - qnn_padding = ::qnn::PaddingType::Same; - break; - } - case 1: { - qnn_padding = ::qnn::PaddingType::Valid; - break; - } - default: { - return kLiteRtStatusErrorUnsupported; - } - } - return kLiteRtStatusOk; -} - -LiteRtStatus ConvertDataType(const litert::ElementType litert_type, - const bool is_quantized, - Qnn_DataType_t& qnn_type) { - qnn_type = QNN_DATATYPE_UNDEFINED; - switch (litert_type) { - case litert::ElementType::Bool: - qnn_type = QNN_DATATYPE_BOOL_8; - break; - case 
litert::ElementType::Int4: - qnn_type = QNN_DATATYPE_SFIXED_POINT_4; - break; - case litert::ElementType::Int8: - qnn_type = - is_quantized ? QNN_DATATYPE_SFIXED_POINT_8 : QNN_DATATYPE_INT_8; - break; - case litert::ElementType::Int16: - qnn_type = - is_quantized ? QNN_DATATYPE_SFIXED_POINT_16 : QNN_DATATYPE_INT_16; - break; - case litert::ElementType::Int32: - qnn_type = - is_quantized ? QNN_DATATYPE_SFIXED_POINT_32 : QNN_DATATYPE_INT_32; - break; - case litert::ElementType::Int64: - qnn_type = QNN_DATATYPE_INT_64; - break; - case litert::ElementType::UInt8: - qnn_type = - is_quantized ? QNN_DATATYPE_UFIXED_POINT_8 : QNN_DATATYPE_UINT_8; - break; - case litert::ElementType::UInt16: - qnn_type = - is_quantized ? QNN_DATATYPE_UFIXED_POINT_16 : QNN_DATATYPE_UINT_16; - break; - case litert::ElementType::UInt32: - qnn_type = - is_quantized ? QNN_DATATYPE_UFIXED_POINT_32 : QNN_DATATYPE_UINT_32; - break; - case litert::ElementType::UInt64: - qnn_type = QNN_DATATYPE_UINT_64; - break; - case litert::ElementType::Float16: - qnn_type = QNN_DATATYPE_FLOAT_16; - break; - case litert::ElementType::Float32: - qnn_type = QNN_DATATYPE_FLOAT_32; - break; - case litert::ElementType::Float64: - qnn_type = QNN_DATATYPE_FLOAT_64; - break; - default: - return kLiteRtStatusErrorUnsupported; - } - return kLiteRtStatusOk; -} - -LiteRtStatus ConvertTensor(const litert::Tensor& litert_tensor, - ::qnn::TensorPool& tensor_pool, - ::qnn::TensorWrapper*& tensor_wrapper, - bool is_tensor_read_and_write) { - tensor_wrapper = nullptr; - - if (litert_tensor.TypeId() != kLiteRtRankedTensorType) { - return kLiteRtStatusErrorInvalidArgument; - } - - const auto ranked_tensor_type = litert_tensor.RankedTensorType(); - if (!ranked_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", ranked_tensor_type.Error().Message().data()); - return ranked_tensor_type.Error().Status(); - } - - Qnn_DataType_t qnn_data_type; - LITERT_RETURN_IF_ERROR(ConvertDataType(ranked_tensor_type->ElementType(), - litert_tensor.HasQuantization(), - qnn_data_type)); - - std::vector dimentions; - const auto litert_layout = ranked_tensor_type->Layout(); - if (litert_layout.Rank() == 0) { - dimentions.resize(1, 1); - } else { - dimentions.resize(litert_layout.Rank()); - for (size_t i = 0; i < dimentions.size(); ++i) { - dimentions[i] = litert_layout.Dimensions()[i]; - } - } - - ::qnn::QuantizeParamsWrapperVariant quantize_params; - switch (litert_tensor.QTypeId()) { - case kLiteRtQuantizationPerTensor: { - const auto per_tensor_quant = litert_tensor.PerTensorQuantization(); - quantize_params.emplace<::qnn::ScaleOffsetQuantizeParamsWrapper>( - per_tensor_quant.scale, per_tensor_quant.zero_point); - break; - } - case kLiteRtQuantizationPerChannel: { - const auto per_channel_quant = litert_tensor.PerChannelQuantization(); - // convert zero points from std::int64_t to std::int32_t - std::vector zero_points(per_channel_quant.num_channels); - for (size_t i = 0; i < zero_points.size(); ++i) { - zero_points[i] = per_channel_quant.zero_points[i]; - } - quantize_params.emplace<::qnn::AxisScaleOffsetQuantizeParamsWrapper>( - per_channel_quant.quantized_dimension, - absl::Span{per_channel_quant.scales, - per_channel_quant.num_channels}, - absl::Span{zero_points.data(), - zero_points.size()}); - break; - } - case kLiteRtQuantizationBlockWise: { - LITERT_LOG(LITERT_ERROR, "Unsupported quantization type."); - return kLiteRtStatusErrorInvalidArgument; - } - case kLiteRtQuantizationNone: - default: - break; - } - - if (litert_tensor.IsSubgraphInput()) { - auto& res = 
tensor_pool.CreateInputTensor(qnn_data_type, quantize_params, - dimensions); - tensor_wrapper = &res; - } else if (litert_tensor.IsSubgraphOutput() || is_tensor_read_and_write) { - auto& res = tensor_pool.CreateOutpuTensor(qnn_data_type, quantize_params, - dimensions); - tensor_wrapper = &res; - } else if (litert_tensor.IsConstant()) { - LITERT_ENSURE(litert_tensor.HasWeights(), - kLiteRtStatusErrorInvalidLegalization, - "Empty weights for constant tensor."); - auto& res = tensor_pool.CreateStaticTensor( - qnn_data_type, quantize_params, dimensions, - litert_tensor.Weights().Bytes().size(), - reinterpret_cast<const void*>(litert_tensor.Weights().Bytes().data())); - tensor_wrapper = &res; - } else { - auto& res = tensor_pool.CreateNativeTensor(qnn_data_type, quantize_params, - dimensions); - tensor_wrapper = &res; - } - return kLiteRtStatusOk; -} - -LiteRtStatus ConvertOp( - const litert::Op& litert_op, ::qnn::TensorPool& tensor_pool, - const std::vector<::qnn::TensorWrapperRef>& input_tensors, - const std::vector<::qnn::TensorWrapperRef>& output_tensors, - std::vector<::qnn::OpWrapper>& op_wrappers) { - switch (litert_op.Code()) { - case LiteRtOpCode::kLiteRtOpCodeTflCast: { - op_wrappers = - ::qnn::BuildCastOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflConcatenation: { - int32_t axis{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetConcatenationAxisOption(litert_op.Get(), &axis)); - op_wrappers = ::qnn::BuildConcatenationOp(tensor_pool, input_tensors, - output_tensors, axis); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflAdd: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetAddFusedActivationOption( - litert_op.Get(), &fused_activation)); - op_wrappers = ::qnn::BuildElementwiseAddOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflLogicalAnd: { - op_wrappers = ::qnn::BuildElementwiseAndOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflCos: { - op_wrappers = ::qnn::BuildElementwiseCosOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDiv: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetDivFusedActivationOption( - litert_op.Get(), &fused_activation)); - op_wrappers = ::qnn::BuildElementwiseDivOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflGreater: { - op_wrappers = ::qnn::BuildElementwiseGreaterOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflLess: { - op_wrappers = ::qnn::BuildElementwiseLessOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflMul: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetMulFusedActivationOption( - litert_op.Get(), &fused_activation)); - op_wrappers = ::qnn::BuildElementwiseMulOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflRsqrt: { - op_wrappers = ::qnn::BuildElementwiseRsqrtOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSin: { - op_wrappers = ::qnn::BuildElementwiseSinOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSquaredDifference: { - op_wrappers = ::qnn::BuildElementwiseSquaredDifferenceOp( - tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSquare: { - op_wrappers =
::qnn::BuildElementwiseSquareOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSub: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetSubFusedActivationOption( - litert_op.Get(), &fused_activation)); - op_wrappers = ::qnn::BuildElementwiseSubOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflMinimum: { - op_wrappers = ::qnn::BuildElementwiseMinimumOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflMaximum: { - op_wrappers = ::qnn::BuildElementwiseMaximumOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflEmbeddingLookup: { - op_wrappers = ::qnn::BuildEmbeddingLookupOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflFullyConnected: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetFullyConnectedFusedActivationOption( - litert_op.Get(), &fused_activation)); - bool keep_num_dims{}; - LITERT_RETURN_IF_ERROR(LiteRtGetFullyConnectedKeepNumDimsOption( - litert_op.Get(), &keep_num_dims)); - // TODO(jiunkaiy): Use compile interface to get useHtpPreferencs. - constexpr LiteRtQnnOptions qnn_options = LITERT_QNN_OPTIONS_INIT; - if (qnn_options.useHtpPreferencs) { - op_wrappers = ::qnn::BuildFullyConnectedOpHtp( - tensor_pool, input_tensors, output_tensors, keep_num_dims); - } - if (op_wrappers.empty()) { - op_wrappers = ::qnn::BuildFullyConnectedOp( - tensor_pool, input_tensors, output_tensors, keep_num_dims); - } - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflGather: { - int32_t axis{}; - LITERT_RETURN_IF_ERROR(LiteRtGetGatherAxisOption(litert_op.Get(), &axis)); - int32_t batch_dims{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetGatherBatchDimsOption(litert_op.Get(), &batch_dims)); - op_wrappers = ::qnn::BuildGatherOp(tensor_pool, input_tensors, - output_tensors, axis, batch_dims); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflGelu: { - op_wrappers = - ::qnn::BuildGeluOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflRelu: { - op_wrappers = - ::qnn::BuildReluOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflRelu6: { - op_wrappers = - ::qnn::BuildRelu6Op(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflBatchMatmul: { - bool adj_x{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetBatchMatmulAdjXOption(litert_op.Get(), &adj_x)); - bool adj_y{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetBatchMatmulAdjYOption(litert_op.Get(), &adj_y)); - op_wrappers = ::qnn::BuildMatmulOp(tensor_pool, input_tensors, - output_tensors, adj_x, adj_y); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflMean: { - bool keep_dims{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetMeanKeepDimsOption(litert_op.Get(), &keep_dims)); - op_wrappers = ::qnn::BuildMeanOp(tensor_pool, input_tensors, - output_tensors, keep_dims); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflQuantize: { - op_wrappers = - ::qnn::BuildQuantizeOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDequantize: { - op_wrappers = - ::qnn::BuildDequantizeOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSum: { - bool keep_dims{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetSumKeepDimsOption(litert_op.Get(), &keep_dims)); - op_wrappers = ::qnn::BuildReduceSumOp(tensor_pool, input_tensors, - 
output_tensors, keep_dims); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflReshape: { - op_wrappers = - ::qnn::BuildReshapeOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSelect: - case LiteRtOpCode::kLiteRtOpCodeTflSelectV2: { - op_wrappers = - ::qnn::BuildSelectOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSlice: { - op_wrappers = - ::qnn::BuildSliceOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSoftmax: { - float beta{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetSoftmaxBetaOption(litert_op.Get(), &beta)); - op_wrappers = ::qnn::BuildSoftmaxOp(tensor_pool, input_tensors, - output_tensors, beta); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSplit: { - int32_t num_splits{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetSplitNumSplitsOption(litert_op.Get(), &num_splits)); - op_wrappers = ::qnn::BuildSplitOp(tensor_pool, input_tensors, - output_tensors, num_splits); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflTanh: { - op_wrappers = - ::qnn::BuildTanhOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflTranspose: { - op_wrappers = - ::qnn::BuildTransposeOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflPack: { - int32_t axis{}; - LiteRtGetPackAxisOption(litert_op.Get(), &axis); - op_wrappers = - ::qnn::BuildPackOp(tensor_pool, input_tensors, output_tensors, axis); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDynamicUpdateSlice: { - op_wrappers = ::qnn::BuildDynamicUpdateSliceOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeShloComposite: { - // TODO(yunandrew): Support custom epsilon for RMS Norm. 
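- // 9.99999997E-7 is the closest float to 1e-6; it stays hard-coded as the
- // RMS-norm epsilon until the composite op exposes its own attribute.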
- float epsilon = 9.99999997E-7; - op_wrappers = ::qnn::BuildRmsNormOp(tensor_pool, input_tensors, - output_tensors, epsilon); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflConv2d: { - uint32_t padding; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dPaddingOption(litert_op.Get(), &padding)); - int32_t stride_w; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dStrideWOption(litert_op.Get(), &stride_w)); - int32_t stride_h; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dStrideHOption(litert_op.Get(), &stride_h)); - int32_t dilation_w_factor; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dDilationWOption(litert_op.Get(), &dilation_w_factor)); - int32_t dilation_h_factor; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dDilationHOption(litert_op.Get(), &dilation_h_factor)); - - ::qnn::PaddingType qnn_padding; - LITERT_RETURN_IF_ERROR(ConvertPaddingType(padding, qnn_padding)); - op_wrappers = ::qnn::BuildConv2dOp( - tensor_pool, input_tensors, output_tensors, stride_h, stride_w, - dilation_h_factor, dilation_w_factor, qnn_padding); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDepthwiseConv2d: { - uint32_t padding; - LITERT_RETURN_IF_ERROR( - LiteRtGetDepthwiseConv2dPaddingOption(litert_op.Get(), &padding)); - int32_t stride_w; - LITERT_RETURN_IF_ERROR( - LiteRtGetDepthwiseConv2dStrideWOption(litert_op.Get(), &stride_w)); - int32_t stride_h; - LITERT_RETURN_IF_ERROR( - LiteRtGetDepthwiseConv2dStrideHOption(litert_op.Get(), &stride_h)); - int32_t dilation_w_factor; - LITERT_RETURN_IF_ERROR(LiteRtGetDepthwiseConv2dDilationWOption( - litert_op.Get(), &dilation_w_factor)); - int32_t dilation_h_factor; - LITERT_RETURN_IF_ERROR(LiteRtGetDepthwiseConv2dDilationHOptions( - litert_op.Get(), &dilation_h_factor)); - - ::qnn::PaddingType qnn_padding; - LITERT_RETURN_IF_ERROR(ConvertPaddingType(padding, qnn_padding)); - op_wrappers = ::qnn::BuildDepthwiseConv2dOp( - tensor_pool, input_tensors, output_tensors, stride_h, stride_w, - dilation_h_factor, dilation_w_factor, qnn_padding); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflAveragePool2d: { - uint32_t padding; - LITERT_RETURN_IF_ERROR( - LiteRtGetAveragePool2dPaddingOption(litert_op.Get(), &padding)); - int32_t stride_w; - LITERT_RETURN_IF_ERROR( - LiteRtGetAveragePool2dStrideWOption(litert_op.Get(), &stride_w)); - int32_t stride_h; - LITERT_RETURN_IF_ERROR( - LiteRtGetAveragePool2dStrideHOption(litert_op.Get(), &stride_h)); - int32_t filter_width; - LITERT_RETURN_IF_ERROR(LiteRtGetAveragePool2dFilterWidthOption( - litert_op.Get(), &filter_width)); - int32_t filter_height; - LITERT_RETURN_IF_ERROR(LiteRtGetAveragePool2dFilterHeightOption( - litert_op.Get(), &filter_height)); - - ::qnn::PaddingType qnn_padding; - LITERT_RETURN_IF_ERROR(ConvertPaddingType(padding, qnn_padding)); - op_wrappers = ::qnn::BuildAveragePoolOp( - tensor_pool, input_tensors, output_tensors, stride_h, stride_w, - filter_height, filter_width, qnn_padding); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDepthToSpace: { - int32_t block_size; - LITERT_RETURN_IF_ERROR( - LiteRtGetDepthToSpaceBlockSizeOption(litert_op.Get(), &block_size)); - op_wrappers = ::qnn::BuildDepthToSpaceOp(tensor_pool, input_tensors, - output_tensors, block_size); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSpaceToDepth: { - int32_t block_size; - LITERT_RETURN_IF_ERROR( - LiteRtGetSpaceToDepthBlockSizeOption(litert_op.Get(), &block_size)); - op_wrappers = ::qnn::BuildSpaceToDepthOp(tensor_pool, input_tensors, - output_tensors, block_size); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflHardSwish: { - op_wrappers
= - ::qnn::BuildHardSwishOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflLeakyRelu: { - float alpha; - LITERT_RETURN_IF_ERROR( - LiteRtGetLeakyReluAlphaOption(litert_op.Get(), &alpha)); - op_wrappers = ::qnn::BuildLeakyReluOp(tensor_pool, input_tensors, - output_tensors, alpha); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflResizeBilinear: { - bool align_corners; - LITERT_RETURN_IF_ERROR(LiteRtGetResizeBilinearAlignCornersOption( - litert_op.Get(), &align_corners)); - bool half_pixel_centers; - LITERT_RETURN_IF_ERROR(LiteRtGetResizeBilinearHalfPixelCenterOption( - litert_op.Get(), &half_pixel_centers)); - op_wrappers = ::qnn::BuildResizeBilinearOp(tensor_pool, input_tensors, - output_tensors, align_corners, - half_pixel_centers); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflResizeNearestNeighbor: { - bool align_corners; - LITERT_RETURN_IF_ERROR(LiteRtGetResizeNearestNeighborAlignCornersOption( - litert_op.Get(), &align_corners)); - bool half_pixel_centers; - LITERT_RETURN_IF_ERROR( - LiteRtGetResizeNearestNeighborHalfPixelCenterOption( - litert_op.Get(), &half_pixel_centers)); - op_wrappers = ::qnn::BuildResizeNearestOp(tensor_pool, input_tensors, - output_tensors, align_corners, - half_pixel_centers); - break; - } - default: { - LITERT_LOG(LITERT_ERROR, - "LiteRT Op Code: %d is not supported in Qualcomm Compiler.", - litert_op.Code()); - } - } - return kLiteRtStatusOk; -} - -LiteRtStatus MapGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, - LiteRtSubgraph subgraph, - absl::string_view qnn_graph_name) { - GraphMapper graph_mapper(subgraph, qnn, context_handle); - LITERT_RETURN_IF_ERROR(graph_mapper.IsLiteRtSubgraphSupported()); - LITERT_RETURN_IF_ERROR(graph_mapper.InitQnnGraph(qnn_graph_name)); - - // - // Legalize subgraph inputs and update tensors in scope - // - - ::qnn::TensorPool tensor_pool; - absl::flat_hash_map - litert_tensor_to_wrapper; - - for (const auto& subgraph_input : graph_mapper.Graph().Inputs()) { - ::qnn::TensorWrapper* tensor_wrapper{nullptr}; - LITERT_RETURN_IF_ERROR( - ConvertTensor(subgraph_input, tensor_pool, tensor_wrapper)); - litert_tensor_to_wrapper.emplace(subgraph_input.Get(), tensor_wrapper); - } - - for (const auto& subgraph_output : graph_mapper.Graph().Outputs()) { - graph_mapper.RegisterOutput(subgraph_output.Get()); - } - // - // Topologically traverse graph, legalizing and updating tensors in scope - // - - // TODO: make ConvertOp accept a vector and append OpWrapper in it. - std::vector<::qnn::OpWrapper> graph_op_wrappers; - std::ostringstream dump; - for (const auto& op : graph_mapper.Graph().Ops()) { - // Dump op info. 
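- // Serializing each op before legalization is purely diagnostic: when a
- // conversion fails, the verbose log shows which LiteRT op was being handled.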
- dump.str(""); - Dump(*op.Get(), dump); - std::string s = dump.str(); - LITERT_LOG(LITERT_VERBOSE, "%s", s.data()); - - std::vector<::qnn::TensorWrapperRef> input_tensors; - for (const auto& input : op.Inputs()) { - if (const auto it = litert_tensor_to_wrapper.find(input.Get()); - it == litert_tensor_to_wrapper.end()) { - ::qnn::TensorWrapper* tensor_wrapper{nullptr}; - LITERT_RETURN_IF_ERROR( - ConvertTensor(input, tensor_pool, tensor_wrapper)); - // Add into the map to capture re-used static tensors. - litert_tensor_to_wrapper.emplace(input.Get(), tensor_wrapper); - input_tensors.emplace_back(*tensor_wrapper); - } else { - input_tensors.emplace_back(*(it->second)); - } - } - - std::vector<::qnn::TensorWrapperRef> output_tensors; - for (const auto& output : op.Outputs()) { - bool is_tensor_read_and_write = graph_mapper.IsTensorOutput(output.Get()); - ::qnn::TensorWrapper* tensor_wrapper{nullptr}; - LITERT_RETURN_IF_ERROR(ConvertTensor(output, tensor_pool, tensor_wrapper, - is_tensor_read_and_write)); - litert_tensor_to_wrapper.emplace(output.Get(), tensor_wrapper); - output_tensors.emplace_back(*tensor_wrapper); - } - - std::vector<::qnn::OpWrapper> op_wrappers; - LITERT_RETURN_IF_ERROR( - ConvertOp(op, tensor_pool, input_tensors, output_tensors, op_wrappers)); - std::move(op_wrappers.begin(), op_wrappers.end(), - std::back_inserter(graph_op_wrappers)); - } - // Insert all tensors into Qnn graph and update the id of Qnn_Tensor_t inside. - tensor_pool.ForEach( - [&qnn, &graph_mapper](::qnn::TensorWrapper& tensor_wrapper) { - // TODO(chunhsue): Use compile interface to get useQInt16AsQUint16. - constexpr bool useQInt16AsQUint16 = true; - if constexpr (useQInt16AsQUint16) { - tensor_wrapper.ConvertQint16ToQuint16(); - } - qnn.Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &tensor_wrapper.GetQnnTensor()); - }); - // Ops can be added to the Qnn graph only after the tensor ids are updated. - for (auto& op_wrapper : graph_op_wrappers) { - qnn.Api()->graphAddNode(graph_mapper.QnnGraph(), op_wrapper.GetOpConfig()); - } - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(graph_mapper.Finalize()); - - return kLiteRtStatusOk; -} - -//===----------------------------------------------------------------------===// -// -// [WIP] LiteRT SUBGRAPH -> QNN GRAPH -// -// Core driver for IR translation. Traverses LiteRt Subgraph, iteratively -// "legalizing" (mapping) LiteRt entities to their QNN counterpart. -// -// APPROACH: -// -// To support the general case we will need a driver loop that either -// traverses input recursively through edges or just iterates topologically. -// -// The algorithm is pretty straightforward: -// * Store mapping between already evaluated LiteRtTensors and their -// newly constructed Qnn Tensor counterpart. -// * Look up QNN Tensors when setting QNN Op inputs. -// * Add new QNN Tensor when setting QNN Op outputs. -// -// NOTES ON QNN API: -// -// After QNN Tensors are registered in the context, they need only -// be stored as their ID. QNN Tensor and "id" : uint32_t are used -// interchangeably.
-// -//===----------------------------------------------------------------------===// - -LiteRtStatus ComposeGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, - LiteRtSubgraph subgraph, - absl::string_view qnn_graph_name) { - LITERT_RETURN_IF_ERROR( - MapGraph(qnn, context_handle, subgraph, qnn_graph_name)); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h deleted file mode 100644 index 3c43e1901acb02..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -LiteRtStatus ConvertDataType(const litert::ElementType litert_type, - const bool is_quantized, Qnn_DataType_t& qnn_type); - -LiteRtStatus ConvertTensor(const litert::Tensor& litert_tensor, - ::qnn::TensorPool& tensor_pool, - ::qnn::TensorWrapper*& tensor_wrapper, - bool is_tensor_read_and_write = false); - -LiteRtStatus ConvertOp( - const litert::Op& litert_op, ::qnn::TensorPool& tensor_pool, - const std::vector<::qnn::TensorWrapperRef>& input_tensors, - const std::vector<::qnn::TensorWrapperRef>& output_tensors, - std::vector<::qnn::OpWrapper>& op_wrappers); - -// Composes a new QNN Graph from given LiteRt Graph. Qnn Graph is written to -// context behind "qnn". Uses given graph_name to name entry point. -LiteRtStatus ComposeGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, - LiteRtSubgraph subgraph, - absl::string_view qnn_graph_name); - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc deleted file mode 100644 index 366f3a228b06d1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h" - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h" - -namespace litert { -namespace qnn { - -namespace { - -Expected InsertQnnTensors(int num_qnn_tensors, Qnn_Tensor_t* qnn_tensors, - std::vector* tensors) { - tensors->clear(); - tensors->reserve(num_qnn_tensors); - for (auto i = 0; i < num_qnn_tensors; ++i) { - auto tensor = QnnTensor::Create(qnn_tensors[i]); - if (!tensor) { - return Unexpected(tensor.Error()); - } - tensors->push_back(std::move(*tensor)); - } - return {}; -} - -Expected InsertQnnGraphInfos( - int num_qnn_graph_infos, QnnSystemContext_GraphInfo_t* qnn_graph_infos, - std::vector* graphs) { - graphs->clear(); - graphs->reserve(num_qnn_graph_infos); - for (auto i = 0; i < num_qnn_graph_infos; ++i) { - auto graph = GraphInfo::Create(qnn_graph_infos[i]); - if (!graph) { - return Unexpected(graph.Error()); - } - graphs->push_back(std::move(*graph)); - } - - return {}; -} - -} // namespace - -Expected GraphInfo::Create( - const QnnSystemContext_GraphInfo_t& graph_info) { - GraphInfo info; - auto status = info.Init(graph_info); - if (status) { - return info; - } else { - return Unexpected(status.Error()); - } -} - -Expected GraphInfo::Init(const QnnSystemContext_GraphInfo_t& graph_info) { - if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { - const auto& graph_info_ = graph_info.graphInfoV1; - name_ = graph_info_.graphName; - LITERT_LOG(LITERT_INFO, "Found qnn graph: %s", name_.c_str()); - - if (auto status = InsertQnnTensors(graph_info_.numGraphInputs, - graph_info_.graphInputs, &inputs_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs, - graph_info_.graphOutputs, &outputs_); - !status) { - return Unexpected(status.Error()); - } - - } else if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2) { - const auto& graph_info_ = graph_info.graphInfoV2; - name_ = graph_info_.graphName; - LITERT_LOG(LITERT_INFO, "Found qnn graph: %s", name_.c_str()); - - if (auto status = InsertQnnTensors(graph_info_.numGraphInputs, - graph_info_.graphInputs, &inputs_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs, - graph_info_.graphOutputs, &outputs_); - !status) { - return Unexpected(status.Error()); - } - } else if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { - const auto& graph_info_ = graph_info.graphInfoV3; - name_ = graph_info_.graphName; - LITERT_LOG(LITERT_INFO, 
"Found qnn graph: %s", name_.c_str()); - - if (auto status = InsertQnnTensors(graph_info_.numGraphInputs, - graph_info_.graphInputs, &inputs_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs, - graph_info_.graphOutputs, &outputs_); - !status) { - return Unexpected(status.Error()); - } - - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported graph info version."); - } - return {}; -} - -Expected ContextBinaryInfo::Init( - const QnnSystemContext_BinaryInfo_t& binary_info) { - if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { - const auto& context_binary_info = binary_info.contextBinaryInfoV1; - if (auto status = InsertQnnTensors(context_binary_info.numContextTensors, - context_binary_info.contextTensors, - &context_tensors_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs, - context_binary_info.graphs, &graphs_); - !status) { - return Unexpected(status.Error()); - } - - } else if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { - const auto& context_binary_info = binary_info.contextBinaryInfoV2; - if (auto status = InsertQnnTensors(context_binary_info.numContextTensors, - context_binary_info.contextTensors, - &context_tensors_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs, - context_binary_info.graphs, &graphs_); - !status) { - return Unexpected(status.Error()); - } - } else if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { - const auto& context_binary_info = binary_info.contextBinaryInfoV3; - if (auto status = InsertQnnTensors(context_binary_info.numContextTensors, - context_binary_info.contextTensors, - &context_tensors_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs, - context_binary_info.graphs, &graphs_); - !status) { - return Unexpected(status.Error()); - } - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported context binary version."); - } - return {}; -} - -Expected ContextBinaryInfo::Create( - QnnManager& qnn, const void* exec_bytecode_ptr, size_t exec_bytecode_size) { - auto system_context_handle = qnn.CreateSystemContextHandle(); - if (!system_context_handle) { - return Unexpected(system_context_handle.Error()); - } - - const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; - Qnn_ContextBinarySize_t binary_info_size = 0; - if (auto status = qnn.SystemApi()->systemContextGetBinaryInfo( - system_context_handle->get(), const_cast(exec_bytecode_ptr), - exec_bytecode_size, &binary_info, &binary_info_size); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to get context binary info: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get context binary info"); - } - - if (!binary_info) { - LITERT_LOG(LITERT_ERROR, "Null binary info", ""); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Null binary info"); - } - - ContextBinaryInfo info; - auto status = info.Init(*binary_info); - - if (status) { - return info; - } else { - return Unexpected(status.Error()); - } -} - -} // namespace qnn -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h deleted file mode 100644 index 
e1e11dfa19f375..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h" - -namespace litert::qnn { - -class GraphInfo { - public: - static Expected Create( - const QnnSystemContext_GraphInfo_t& graph_info); - const std::string& Name() const { return name_; } - const std::vector& Inputs() const { return inputs_; } - const std::vector& Outputs() const { return outputs_; } - - private: - GraphInfo() = default; - Expected Init(const QnnSystemContext_GraphInfo_t& graph_info); - std::string name_; - std::vector inputs_; - std::vector outputs_; -}; - -class ContextBinaryInfo { - public: - static Expected Create(QnnManager& qnn, - const void* exec_bytecode_ptr, - size_t exec_bytecode_size); - const std::vector& ContextTensors() const { - return context_tensors_; - } - const std::vector& Graphs() const { return graphs_; } - - private: - ContextBinaryInfo() = default; - Expected Init(const QnnSystemContext_BinaryInfo_t& binary_info); - std::vector context_tensors_; - std::vector graphs_; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/BUILD deleted file mode 100644 index 902bb5b5b49bee..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_library( - name = "tensor_pool", - srcs = ["tensor_pool.cc"], - hdrs = ["tensor_pool.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "common", - hdrs = ["common.h"], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/BUILD deleted file mode 100644 index 1c0f0367e2dab3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/BUILD +++ /dev/null @@ -1,574 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_library( - name = "op_builder", - srcs = ["op_builder.cc"], - hdrs = ["op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "fully_connected_op_builder_htp", - srcs = ["fully_connected_op_builder_htp.cc"], - hdrs = ["fully_connected_op_builder_htp.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "elementwise_op_builder", - srcs = ["elementwise_op_builder.cc"], - hdrs = ["elementwise_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:param_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "cast_op_builder", - srcs = ["cast_op_builder.cc"], - hdrs = ["cast_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "concatenation_op_builder", - srcs = ["concatenation_op_builder.cc"], - hdrs = ["concatenation_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "embedding_lookup_op_builder", - srcs = ["embedding_lookup_op_builder.cc"], - hdrs = ["embedding_lookup_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "fully_connected_op_builder", - srcs = ["fully_connected_op_builder.cc"], - hdrs = ["fully_connected_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "gather_op_builder", - srcs = ["gather_op_builder.cc"], - hdrs = ["gather_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "gelu_op_builder", - srcs = ["gelu_op_builder.cc"], - hdrs = ["gelu_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "relu_op_builder", - srcs = ["relu_op_builder.cc"], - hdrs = ["relu_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "relu6_op_builder", - srcs = ["relu6_op_builder.cc"], - hdrs = ["relu6_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "matmul_op_builder", - srcs = ["matmul_op_builder.cc"], - hdrs = ["matmul_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "mean_op_builder", - srcs = ["mean_op_builder.cc"], - hdrs = ["mean_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "quantize_op_builder", - srcs = ["quantize_op_builder.cc"], - hdrs = ["quantize_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "reduce_op_builder", - srcs = ["reduce_op_builder.cc"], - hdrs = ["reduce_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "reshape_op_builder", - srcs = ["reshape_op_builder.cc"], - hdrs = ["reshape_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "select_op_builder", - srcs = ["select_op_builder.cc"], - hdrs = ["select_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "slice_op_builder", - srcs = ["slice_op_builder.cc"], - hdrs = ["slice_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "softmax_op_builder", - srcs = ["softmax_op_builder.cc"], - hdrs = ["softmax_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "split_op_builder", - srcs = ["split_op_builder.cc"], - hdrs = ["split_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "tanh_op_builder", - srcs = ["tanh_op_builder.cc"], - hdrs = ["tanh_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "transpose_op_builder", - srcs = ["transpose_op_builder.cc"], - hdrs = ["transpose_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "pack_op_builder", - srcs = ["pack_op_builder.cc"], - hdrs = ["pack_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "dynamic_update_slice_op_builder", - srcs = ["dynamic_update_slice_op_builder.cc"], - hdrs = ["dynamic_update_slice_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "rms_norm_op_builder", - srcs = ["rms_norm_op_builder.cc"], - hdrs = ["rms_norm_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "conv2d_op_builder", - srcs = ["conv2d_op_builder.cc"], - hdrs = ["conv2d_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "pool2d_op_builder", - srcs = ["pool2d_op_builder.cc"], - hdrs = ["pool2d_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "spatial_transform_op_builder", - srcs = ["spatial_transform_op_builder.cc"], - hdrs = ["spatial_transform_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "resize_op_builder", - srcs = ["resize_op_builder.cc"], - hdrs = ["resize_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "leaky_relu_op_builder", - srcs = ["leaky_relu_op_builder.cc"], - hdrs = ["leaky_relu_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "hard_swish_op_builder", - srcs = ["hard_swish_op_builder.cc"], - hdrs = ["hard_swish_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "depthwise_conv2d_op_builder", - srcs = ["depthwise_conv2d_op_builder.cc"], - hdrs = ["depthwise_conv2d_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.cc deleted file mode 100644 index 361b6007572528..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildCastOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& op = CreateOpWrapper(res, QNN_OP_CAST); - op.AddInputTensor(inputs[0]); - op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h deleted file mode 100644 index 4de521da983870..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CAST_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CAST_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildCastOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CAST_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.cc deleted file mode 100644 index c75d985dbbd2bd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildConcatenationOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::int32_t axis) { - std::vector res; - - auto& concat_op = CreateOpWrapper(res, QNN_OP_CONCAT); - for (const auto& input : inputs) { - concat_op.AddInputTensor(input); - } - concat_op.AddOutputTensor(outputs[0]); - - std::uint32_t adjusted_axis = - (axis >= 0) ? axis : axis + inputs[0].get().GetRank(); - concat_op.AddScalarParam(QNN_OP_CONCAT_PARAM_AXIS, - adjusted_axis); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h deleted file mode 100644 index ed0784e27a913a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONCATENATION_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONCATENATION_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildConcatenationOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::int32_t axis); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONCATENATION_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc deleted file mode 100644 index a41132440e15bc..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc
deleted file mode 100644
index a41132440e15bc..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc
+++ /dev/null
@@ -1,166 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h"
-
-#include <array>
-#include <cstddef>
-#include <cstdint>
-#include <variant>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-namespace {
-constexpr size_t kInputIndex = 0;
-constexpr size_t kFilterIndex = 1;
-constexpr size_t kBiasIndex = 2;
-constexpr size_t kOutputIndex = 0;
-constexpr size_t kBatchIndex = 0;
-constexpr size_t kHeightIndex = 1;
-constexpr size_t kWidthIndex = 2;
-constexpr size_t kChannelIndex = 3;
-
-}  // namespace
-
-std::vector<OpWrapper> BuildConv2dOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const std::uint32_t stride_h,
-    const std::uint32_t stride_w, const std::uint32_t dilation_h,
-    const std::uint32_t dilation_w, const PaddingType padding_type) {
-  std::vector<OpWrapper> res;
-
-  // transpose filter
-  TensorWrapper& filter_tensor = inputs[kFilterIndex];
-  const std::vector<std::uint32_t>& filters_dims = filter_tensor.GetDims();
-  auto& filter_quant_params = filter_tensor.GetQuantParams();
-  std::vector<std::uint32_t> permute_dims{filters_dims[1], filters_dims[2],
-                                          filters_dims[3], filters_dims[0]};
-  if (std::holds_alternative<AxisScaleOffsetQuantizeParamsWrapper>(
-          filter_quant_params)) {
-    auto& axis_quant_params =
-        std::get<AxisScaleOffsetQuantizeParamsWrapper>(filter_quant_params);
-    const std::array<std::uint32_t, 4> new_axis{3, 0, 1, 2};
-    axis_quant_params.SetAxis(new_axis[axis_quant_params.GetAxis()]);
-  }
-
-  size_t filter_bytes = filter_tensor.GetTensorBytes();
-  TensorWrapper* transposed_filter_tensor = nullptr;
-  if (filter_tensor.IsTensorStatic() &&
-      filter_tensor.GetDataType() ==
-          Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_8) {
-    auto filter_data = filter_tensor.GetStaticTensorData<std::int8_t>();
-    std::vector<std::int8_t> transpose_weight_int8;
-    TransposeFromOHWIToHWIO(filter_data.value(), filters_dims,
-                            transpose_weight_int8);
-    transposed_filter_tensor = &(tensor_pool.CreateStaticTensor(
-        filter_tensor.GetDataType(), filter_quant_params, permute_dims,
-        filter_bytes, transpose_weight_int8.data()));
-  } else if (filter_tensor.IsTensorStatic() &&
-             filter_tensor.GetDataType() ==
-                 Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8) {
-    auto filter_data = filter_tensor.GetStaticTensorData<std::uint8_t>();
-    std::vector<std::uint8_t> transpose_weight_uint8;
-    TransposeFromOHWIToHWIO(filter_data.value(), filters_dims,
-                            transpose_weight_uint8);
-    transposed_filter_tensor = &(tensor_pool.CreateStaticTensor(
-        filter_tensor.GetDataType(), filter_quant_params, permute_dims,
-        filter_bytes, transpose_weight_uint8.data()));
-  } else {
-    transposed_filter_tensor =
-        &(tensor_pool.CloneNativeTensorFrom(filter_tensor, permute_dims));
-
-    const std::vector<std::uint32_t> permute_shape{4};
-    const std::array<std::uint32_t, 4> permute_data{
-        kHeightIndex, kWidthIndex, kChannelIndex, kBatchIndex};
-    auto& permute_tensor = tensor_pool.CreateStaticTensor(
-        QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, permute_shape,
-        sizeof(decltype(permute_data)::value_type) * permute_data.size(),
-        permute_data.data());
-
-    OpWrapper& transpose_op = CreateOpWrapper(res, QNN_OP_TRANSPOSE);
-    transpose_op.AddInputTensor(filter_tensor);
-    transpose_op.AddOutputTensor(*transposed_filter_tensor);
-    transpose_op.AddTensorParam(QNN_OP_TRANSPOSE_PARAM_PERM, permute_tensor);
-  }
-
-  // conv
-  OpWrapper& conv_op = CreateOpWrapper(res, QNN_OP_CONV_2D);
-  TensorWrapper& input_tensor = inputs[kInputIndex];
-  conv_op.AddInputTensor(input_tensor);
-  conv_op.AddInputTensor(*transposed_filter_tensor);
-  if (inputs.size() - 1 >= kBiasIndex) {
-    TensorWrapper& bias_tensor = inputs[kBiasIndex];
-    // QNN only supports per-tensor quant for bias,
-    // and the scale and offset are both zero.
-    bias_tensor.ConvertAxisScaleOffsetToScaleOffset();
-    conv_op.AddInputTensor(bias_tensor);
-  }
-
-  TensorWrapper& output_tensor = outputs[kOutputIndex];
-  conv_op.AddOutputTensor(output_tensor);
-  // TODO: fused activation
-
-  // stride param
-  const std::array stride_data{stride_h, stride_w};
-  const std::vector<std::uint32_t> stride_shape{2};
-  auto& stride_tensor = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, stride_shape,
-      sizeof(decltype(stride_data)::value_type) * stride_data.size(),
-      stride_data.data());
-  conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_STRIDE, stride_tensor);
-
-  // dilation param
-  const std::array dilation_data{dilation_h, dilation_w};
-  const std::vector<std::uint32_t> dilation_shape{2};
-  auto& dilation_tensor = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, dilation_shape,
-      sizeof(decltype(dilation_data)::value_type) * dilation_data.size(),
-      dilation_data.data());
-  conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_DILATION, dilation_tensor);
-
-  // padding param
-  const auto [padding_before_height, padding_after_height] =
-      ComputePaddingBeforeAfter(input_tensor.GetDim(kHeightIndex),
-                                filter_tensor.GetDim(kHeightIndex), stride_h,
-                                dilation_h, padding_type);
-  const auto [padding_before_width, padding_after_width] =
-      ComputePaddingBeforeAfter(input_tensor.GetDim(kWidthIndex),
-                                filter_tensor.GetDim(kWidthIndex), stride_w,
-                                dilation_w, padding_type);
-  const std::array padding_data = {
-      padding_before_height, padding_after_height, padding_before_width,
-      padding_after_width};
-  const std::vector<std::uint32_t> padding_shape{2, 2};
-  auto& padding_tensor = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, padding_shape,
-      sizeof(decltype(padding_data)::value_type) * padding_data.size(),
-      padding_data.data());
-  conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_PAD_AMOUNT, padding_tensor);
-
-  // group param
-  if ((input_tensor.GetDim(kChannelIndex) %
-       filter_tensor.GetDim(kChannelIndex)) != 0) {
-    QNN_LOG_WARNING(
-        "The channels of the input tensor cannot be evenly divided by the "
-        "channels of the filter tensor.");
-  }
-  if (const std::uint32_t groups = input_tensor.GetDim(kChannelIndex) /
-                                   filter_tensor.GetDim(kChannelIndex);
-      groups > 1) {
-    conv_op.AddScalarParam<std::uint32_t>(QNN_OP_CONV_2D_PARAM_GROUP, groups);
-  }
-
-  return res;
-}
-
-}  // namespace qnn
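
Editorial sketch (illustrative, not part of the patch): the group parameter above is derived purely from channel counts, since a filter that consumes fewer input channels than the input provides implies a grouped convolution. A minimal C++ illustration with assumed channel values:

    // Hypothetical demo of the group-count derivation; values are made up.
    #include <cstdint>
    #include <iostream>

    std::uint32_t GroupCount(std::uint32_t input_channels,
                             std::uint32_t filter_channels) {
      // The deleted builder warns (but proceeds) when this is not exact.
      return input_channels / filter_channels;
    }

    int main() {
      std::cout << GroupCount(64, 64) << "\n";  // 1: ordinary convolution
      std::cout << GroupCount(64, 16) << "\n";  // 4: grouped convolution
    }
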
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h
deleted file mode 100644
index 7cdd99bec46c40..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONV2D_OP_BUILDER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONV2D_OP_BUILDER_H_
-
-#include <cstdint>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildConv2dOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const std::uint32_t stride_h,
-    const std::uint32_t stride_w, const std::uint32_t dilation_h,
-    const std::uint32_t dilation_w, const PaddingType padding);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONV2D_OP_BUILDER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.cc
deleted file mode 100644
index 3d4840eb6c390a..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.cc
+++ /dev/null
@@ -1,118 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h"
-
-#include <array>
-#include <cstddef>
-#include <cstdint>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-namespace {
-constexpr size_t kInputIndex = 0;
-constexpr size_t kFilterIndex = 1;
-constexpr size_t kBiasIndex = 2;
-constexpr size_t kOutputIndex = 0;
-constexpr size_t kBatchIndex = 0;
-constexpr size_t kHeightIndex = 1;
-constexpr size_t kWidthIndex = 2;
-constexpr size_t kChannelIndex = 3;
-
-}  // namespace
-
-std::vector<OpWrapper> BuildDepthwiseConv2dOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const std::uint32_t stride_h,
-    const std::uint32_t stride_w, const std::uint32_t dilation_h,
-    const std::uint32_t dilation_w, const PaddingType padding_type) {
-  std::vector<OpWrapper> res;
-
-  // reshape filter
-  TensorWrapper& filter_tensor = inputs[kFilterIndex];
-
-  // 1HWC to HW1C, only needs a reshape instead of a transpose.
-  const std::vector<std::uint32_t> reshape_dims{
-      filter_tensor.GetDim(kHeightIndex), filter_tensor.GetDim(kWidthIndex),
-      filter_tensor.GetDim(kBatchIndex), filter_tensor.GetDim(kChannelIndex)};
-  TensorWrapper* reshaped_filter_tensor = nullptr;
-  if (filter_tensor.IsTensorStatic()) {
-    reshaped_filter_tensor =
-        &(tensor_pool.CloneStaticTensorFrom(filter_tensor, reshape_dims));
-  } else {
-    reshaped_filter_tensor =
-        &(tensor_pool.CloneNativeTensorFrom(filter_tensor, reshape_dims));
-
-    OpWrapper& reshape_op = CreateOpWrapper(res, QNN_OP_RESHAPE);
-    reshape_op.AddInputTensor(filter_tensor);
-    reshape_op.AddOutputTensor(*reshaped_filter_tensor);
-  }
-
-  // conv
-  OpWrapper& conv_op = CreateOpWrapper(res, QNN_OP_DEPTH_WISE_CONV_2D);
-  TensorWrapper& input_tensor = inputs[kInputIndex];
-  conv_op.AddInputTensor(input_tensor);
-  conv_op.AddInputTensor(*reshaped_filter_tensor);
-  if (inputs.size() - 1 >= kBiasIndex) {
-    TensorWrapper& bias_tensor = inputs[kBiasIndex];
-    // QNN only supports per-tensor quant for bias,
-    // and the scale and offset are both zero.
-    bias_tensor.ConvertAxisScaleOffsetToScaleOffset();
-    conv_op.AddInputTensor(bias_tensor);
-  }
-
-  TensorWrapper& output_tensor = outputs[kOutputIndex];
-  conv_op.AddOutputTensor(output_tensor);
-  // TODO: fused activation
-
-  // stride param
-  const std::array stride_data{stride_h, stride_w};
-  const std::vector<std::uint32_t> stride_shape{2};
-  auto& stride_tensor = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, stride_shape,
-      sizeof(decltype(stride_data)::value_type) * stride_data.size(),
-      stride_data.data());
-  conv_op.AddTensorParam(QNN_OP_DEPTH_WISE_CONV_2D_PARAM_STRIDE, stride_tensor);
-
-  // dilation param
-  const std::array dilation_data{dilation_h, dilation_w};
-  const std::vector<std::uint32_t> dilation_shape{2};
-  auto& dilation_tensor = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, dilation_shape,
-      sizeof(decltype(dilation_data)::value_type) * dilation_data.size(),
-      dilation_data.data());
-  conv_op.AddTensorParam(QNN_OP_DEPTH_WISE_CONV_2D_PARAM_DILATION,
-                         dilation_tensor);
-
-  // padding param
-  const auto [padding_before_height, padding_after_height] =
-      ComputePaddingBeforeAfter(input_tensor.GetDim(kHeightIndex),
-                                filter_tensor.GetDim(kHeightIndex), stride_h,
-                                dilation_h, padding_type);
-  const auto [padding_before_width, padding_after_width] =
-      ComputePaddingBeforeAfter(input_tensor.GetDim(kWidthIndex),
-                                filter_tensor.GetDim(kWidthIndex), stride_w,
-                                dilation_w, padding_type);
-  const std::array padding_data = {
-      padding_before_height, padding_after_height, padding_before_width,
-      padding_after_width};
-  const std::vector<std::uint32_t> padding_shape{2, 2};
-  auto& padding_tensor = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, padding_shape,
-      sizeof(decltype(padding_data)::value_type) * padding_data.size(),
-      padding_data.data());
-  conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_PAD_AMOUNT, padding_tensor);
-
-  return res;
-}
-
-}  // namespace qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h
deleted file mode 100644
index 32419352844a5b..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DEPTHWISE_CONV2D_OP_BUILDER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DEPTHWISE_CONV2D_OP_BUILDER_H_
-
-#include <cstdint>
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildDepthwiseConv2dOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const std::uint32_t stride_h,
-    const std::uint32_t stride_w, const std::uint32_t dilation_h,
-    const std::uint32_t dilation_w, const PaddingType padding);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DEPTHWISE_CONV2D_OP_BUILDER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.cc
deleted file mode 100644
index b356188becb799..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.cc
+++ /dev/null
@@ -1,135 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h"
-
-#include <cstdint>
-#include <numeric>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_logging.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-namespace {
-constexpr int kInputIdx = 0;
-constexpr int kUpdateIdx = 1;
-constexpr int kIndicesIdx = 2;
-constexpr int kOutputIdx = 0;
-}  // namespace
-
-std::vector<OpWrapper> BuildDynamicUpdateSliceOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-  // Dynamic Update Slice:
-  // in[0] operand: [1, 64, 4, 64]
-  // in[1] updates: [1, 1, 4, 64]
-  // in[2] start_indices: [4] -> data: [0, x, 0, 0]
-
-  // reduceSum and reshape in[2] -> index tensor
-
-  // Create static tensor table
-  // shape: [64]
-  // data: [0,...,63]
-
-  // QNN ElementWiseNotEqual:
-  // in[0]: table
-  // in[1]: index tensor
-  // out[0]: condition tensor
-
-  // reshape condition tensor due to QNN broadcast rules
-  // in[0]: [64]
-  // out[0]: [64, 1, 1]
-
-  // QNN ElementWiseSelect:
-  // in[0] condition: [64, 1, 1]
-  // in[1] input: [1, 64, 4, 64]
-  // in[2] updates: [1, 1, 4, 64]
-
-  // CAUTION!!! only supports the Gemma2 use case for now
-
-  auto& input_tensor = inputs[kInputIdx].get();
-  auto& update_tensor = inputs[kUpdateIdx].get();
-  auto& indices_tensor = inputs[kIndicesIdx].get();
-  auto& output_tensor = outputs[kOutputIdx].get();
-
-  if (input_tensor.GetRank() != update_tensor.GetRank()) {
-    LITERT_LOG(LITERT_ERROR, "%s",
-               "QNN LiteRT Delegate only supports Dynamic Update Slice when "
-               "operand and updates have the same rank.");
-    return {};
-  }
-
-  if (indices_tensor.GetDataType() != QNN_DATATYPE_INT_32) {
-    LITERT_LOG(LITERT_ERROR, "%s",
-               "Dynamic Update Slice only supports QNN_DATATYPE_INT_32 "
-               "start_indices.");
-    return {};
-  }
-
-  // reduce sum
-  auto& reduce_sum_op = CreateOpWrapper(res, QNN_OP_REDUCE_SUM);
-  reduce_sum_op.AddInputTensor(indices_tensor);
-
-  std::vector<std::uint32_t> axis_data = {0};
-  TensorWrapper& axis_tensor = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, {1},
-      sizeof(std::uint32_t), axis_data.data());
-  reduce_sum_op.AddTensorParam(QNN_OP_REDUCE_SUM_PARAM_AXES, axis_tensor);
-
-  // create intermediate tensor
-  TensorWrapper& one_dim_index =
-      tensor_pool.CloneNativeTensorFrom(indices_tensor, {1});
-  reduce_sum_op.AddOutputTensor(one_dim_index);
-
-  // ElementwiseNotEqual
-  // get table dims from in[0]->Dims[1]
-  if (input_tensor.GetRank() < 2) {
-    LITERT_LOG(LITERT_ERROR, "%s",
-               "Dynamic Update Slice only supports operand tensor rank >= 2");
-    return {};
-  }
-  uint32_t table_size = input_tensor.GetDim(1);
-  std::vector<std::uint32_t> static_table_dims = {table_size};
-  std::vector<std::int32_t> table_data(table_size);
-  std::iota(table_data.begin(), table_data.end(), 0);
-
-  // create static table tensor
-  TensorWrapper& static_table = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_INT_32, QuantizeParamsWrapperVariant{}, static_table_dims,
-      table_size * sizeof(std::int32_t), table_data.data());
-
-  OpWrapper& not_equal_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_NOT_EQUAL);
-  not_equal_op.AddInputTensor(static_table);
-  not_equal_op.AddInputTensor(one_dim_index);
-
-  TensorWrapper& not_equal_out = tensor_pool.CreateNativeTensor(
-      QNN_DATATYPE_BOOL_8, QuantizeParamsWrapperVariant{}, static_table_dims);
-  not_equal_op.AddOutputTensor(not_equal_out);
-
-  // reshape not equal output to [N, 1, 1]
-  OpWrapper& reshape_op = CreateOpWrapper(res, QNN_OP_RESHAPE);
-
-  reshape_op.AddInputTensor(not_equal_out);
-  TensorWrapper& reshape_out =
-      tensor_pool.CloneNativeTensorFrom(not_equal_out, {table_size, 1, 1});
-  reshape_op.AddOutputTensor(reshape_out);
-
-  // Select
-  OpWrapper& select_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SELECT);
-
-  select_op.AddInputTensor(reshape_out);
-  select_op.AddInputTensor(input_tensor);
-  select_op.AddInputTensor(update_tensor);
-  select_op.AddOutputTensor(output_tensor);
-  return res;
-}
-
-}  // namespace qnn
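
Editorial sketch (illustrative, not part of the patch): the select-based decomposition above replaces each slice along dim 1 only where the reduced start index matches. A minimal standalone C++ illustration of the same table/not-equal/select logic on plain arrays, with invented shapes:

    // Hypothetical demo of select-based dynamic update slice; shapes are
    // collapsed to one dimension for clarity.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const std::int32_t dim1 = 4;
      std::vector<float> operand(dim1, 0.0f);  // stands in for operand[:, h]
      const float update = 9.0f;               // stands in for updates[:, 0]
      const std::int32_t index = 2;            // reduced start_indices

      // table = 0..dim1-1; condition = (table != index);
      // Select(condition, operand, update).
      for (std::int32_t h = 0; h < dim1; ++h) {
        const bool keep_operand = (h != index);
        operand[h] = keep_operand ? operand[h] : update;
      }
      for (float v : operand) std::cout << v << " ";  // prints: 0 0 9 0
      std::cout << "\n";
    }
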
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h
deleted file mode 100644
index c5a74c1a7c5ed6..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DYNAMIC_UPDATE_SLICE_OP_BUILDER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DYNAMIC_UPDATE_SLICE_OP_BUILDER_H_
-
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildDynamicUpdateSliceOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DYNAMIC_UPDATE_SLICE_OP_BUILDER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.cc
deleted file mode 100644
index 38ed759d53fd21..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.cc
+++ /dev/null
@@ -1,233 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildElementwiseAddOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_ADD);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-
-  // TODO: fused activation
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseSubOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SUBTRACT);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-
-  // TODO: fused activation
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseMulOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_MULTIPLY);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-
-  // TODO: fused activation
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseDivOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_DIVIDE);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-
-  // TODO: fused activation
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseSinOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SIN);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseCosOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_COS);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseRsqrtOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_RSQRT);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseSquareOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  OpWrapper& elementwise_op =
-      CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_MULTIPLY);
-  elementwise_op.AddInputTensor(inputs[0]);
-  elementwise_op.AddInputTensor(inputs[0]);
-  elementwise_op.AddOutputTensor(outputs[0]);
-
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseSquaredDifferenceOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op =
-      CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SQUARED_DIFFERENCE);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseLessOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-  elementwise_op.AddScalarParam<std::uint32_t>(
-      QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION,
-      QNN_OP_ELEMENT_WISE_BINARY_OPERATION_LESS);
-
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseGreaterOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-  elementwise_op.AddScalarParam<std::uint32_t>(
-      QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION,
-      QNN_OP_ELEMENT_WISE_BINARY_OPERATION_GREATER);
-
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseAndOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-  elementwise_op.AddScalarParam<std::uint32_t>(
-      QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION,
-      QNN_OP_ELEMENT_WISE_BINARY_OPERATION_AND);
-
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseMinimumOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-  elementwise_op.AddScalarParam<std::uint32_t>(
-      QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION,
-      QNN_OP_ELEMENT_WISE_BINARY_OPERATION_MINIMUM);
-
-  return res;
-}
-
-std::vector<OpWrapper> BuildElementwiseMaximumOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY);
-  for (const auto& input : inputs) {
-    elementwise_op.AddInputTensor(input);
-  }
-  elementwise_op.AddOutputTensor(outputs[0]);
-  elementwise_op.AddScalarParam<std::uint32_t>(
-      QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION,
-      QNN_OP_ELEMENT_WISE_BINARY_OPERATION_MAXIMUM);
-
-  return res;
-}
-
-}  // namespace qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h
deleted file mode 100644
index 7953ce93c26b17..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h
+++ /dev/null
@@ -1,73 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_ELEMENTWISE_OP_BUILDER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_ELEMENTWISE_OP_BUILDER_H_
-
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildElementwiseAddOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseSubOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseMulOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseDivOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseSinOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseCosOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseRsqrtOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseSquareOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseSquaredDifferenceOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseLessOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseGreaterOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseAndOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseMinimumOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-std::vector<OpWrapper> BuildElementwiseMaximumOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_ELEMENTWISE_OP_BUILDER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.cc
deleted file mode 100644
index f33c5167ef3404..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-namespace {
-constexpr int kTableIdx = 1;
-constexpr int kIndicesIdx = 0;
-constexpr int kOutputIdx = 0;
-}  // namespace
-
-std::vector<OpWrapper> BuildEmbeddingLookupOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  TensorWrapper& table_tensor = inputs[kTableIdx];
-  TensorWrapper& indices_tensor = inputs[kIndicesIdx];
-  TensorWrapper& output_tensor = outputs[kOutputIdx];
-
-  auto& gather_op = CreateOpWrapper(res, QNN_OP_GATHER);
-  // Case: QInt8 table with QInt16 output
-  if (table_tensor.IsQuant8() && output_tensor.IsQuant16()) {
-    QNN_LOG_WARNING(
-        "The data type of embedding lookup table is int8, but output data "
-        "type is int16. Int8 table will be cast to int16.");
-    std::vector<std::int16_t> int16_data;
-    size_t data_len = table_tensor.GetTensorNumElements();
-    auto int8_data = table_tensor.GetStaticTensorData<std::int8_t>();
-    if (!int8_data.has_value()) {
-      QNN_LOG_ERROR("Embedding lookup get int8 table failed.");
-      return res;
-    }
-    int16_data.reserve(data_len);
-    for (int i = 0; i < data_len; ++i) {
-      int16_data.emplace_back(static_cast<std::int16_t>((*int8_data)[i]));
-    }
-
-    TensorWrapper& int16_table_tensor = tensor_pool.CreateStaticTensor(
-        output_tensor.GetDataType(), table_tensor.GetQuantParams(),
-        table_tensor.GetDims(),
-        sizeof(decltype(int16_data)::value_type) * int16_data.size(),
-        reinterpret_cast<const void*>(int16_data.data()));
-
-    gather_op.AddInputTensor(int16_table_tensor);
-  } else {
-    gather_op.AddInputTensor(table_tensor);
-  }
-
-  gather_op.AddInputTensor(indices_tensor);
-  gather_op.AddOutputTensor(output_tensor);
-  gather_op.AddScalarParam<std::int32_t>(QNN_OP_GATHER_PARAM_AXIS, 0);
-  return res;
-}
-
-}  // namespace qnn
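
Editorial sketch (illustrative, not part of the patch): the int8-to-int16 table widening above is value-preserving because the quantization parameters are reused unchanged, so each widened integer dequantizes to the same real number. A minimal C++ illustration:

    // Hypothetical demo of widening an int8 lookup table to int16.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const std::vector<std::int8_t> int8_table{-128, -1, 0, 127};
      std::vector<std::int16_t> int16_table;
      int16_table.reserve(int8_table.size());
      for (std::int8_t v : int8_table) {
        // Sign extension keeps the quantized value, and hence (with the same
        // scale/offset) the dequantized value, identical.
        int16_table.emplace_back(static_cast<std::int16_t>(v));
      }
      for (std::int16_t v : int16_table) std::cout << v << " ";  // -128 -1 0 127
      std::cout << "\n";
    }
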
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h
deleted file mode 100644
index 175f65dac0a5e0..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h
+++ /dev/null
@@ -1,20 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_EMBEDDING_LOOKUP_OP_BUILDER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_EMBEDDING_LOOKUP_OP_BUILDER_H_
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildEmbeddingLookupOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_EMBEDDING_LOOKUP_OP_BUILDER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.cc
deleted file mode 100644
index 2b471d5f7ce5d9..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h"
-
-#include <cstdint>
-#include <functional>
-#include <numeric>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-namespace {
-constexpr int kBiasIdx = 2;
-}
-
-std::vector<OpWrapper> BuildFullyConnectedOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const bool keep_num_dims) {
-  std::vector<OpWrapper> res;
-  OpWrapper& fully_connected_op = CreateOpWrapper(res, QNN_OP_FULLY_CONNECTED);
-
-  TensorWrapper& input_tensor = inputs[0];
-  fully_connected_op.AddInputTensor(input_tensor);
-  TensorWrapper& weight_tensor = inputs[1];
-  fully_connected_op.AddInputTensor(weight_tensor);
-  if (inputs.size() - 1 >= kBiasIdx) {
-    TensorWrapper& bias_tensor = inputs[kBiasIdx];
-    fully_connected_op.AddInputTensor(bias_tensor);
-  }
-
-  TensorWrapper& output_tensor = outputs[0];
-  if (keep_num_dims) {
-    auto& input_dims = input_tensor.GetDims();
-    std::uint32_t input_size = std::accumulate(
-        input_dims.begin(), input_dims.end(), 1, std::multiplies<>());
-    const std::uint32_t num_units = weight_tensor.GetDim(0);
-    const std::uint32_t num_input_elem = weight_tensor.GetDim(1);
-
-    // input_size must be divisible by num_input_elem. This should be validated
-    // by QNN.
-    const std::uint32_t batch_size = input_size / num_input_elem;
-    // QNN output should always be rank 2
-    qnn::TensorWrapper& fully_connected_out = tensor_pool.CloneNativeTensorFrom(
-        output_tensor, {batch_size, num_units});
-
-    fully_connected_op.AddOutputTensor(fully_connected_out);
-    // TODO: fused activation
-
-    qnn::OpWrapper& reshape_op = CreateOpWrapper(res, QNN_OP_RESHAPE);
-    reshape_op.AddInputTensor(fully_connected_out);
-    reshape_op.AddOutputTensor(output_tensor);
-  } else {
-    fully_connected_op.AddOutputTensor(outputs[0]);
-    // TODO: fused activation
-  }
-
-  return res;
-}
-
-}  // namespace qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h
deleted file mode 100644
index 3031be6f3002b8..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_H_
-
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildFullyConnectedOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const bool keep_num_dims);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_H_
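
Editorial sketch (illustrative, not part of the patch): with keep_num_dims the deleted builder still runs the QNN FullyConnected at rank 2, flattening the input to [batch_size, num_input_elem] and reshaping the rank-2 result back afterwards. A minimal C++ illustration of the shape arithmetic, with assumed dimensions:

    // Hypothetical demo of the rank-2 output shape computation.
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      const std::vector<std::uint32_t> input_dims{1, 8, 16};  // assumed shape
      const std::uint32_t num_units = 32;       // weight dim 0
      const std::uint32_t num_input_elem = 16;  // weight dim 1

      const std::uint32_t input_size = std::accumulate(
          input_dims.begin(), input_dims.end(), 1u, std::multiplies<>());
      const std::uint32_t batch_size = input_size / num_input_elem;  // 8

      // QNN FullyConnected output is always rank 2: [batch_size, num_units].
      std::cout << "[" << batch_size << ", " << num_units << "]\n";  // [8, 32]
    }
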
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h" - -#include -#include -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildFullyConnectedOpHtp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool keep_num_dims) { - std::vector res; - QNN_LOG_INFO("[FullyConnected Optimization] FC -> CONV2D"); - // TFLite FC Input: [1, k, n] and Weight: [m, n] - // QNN Conv2D Input: - // [batch, height, width, channel_in] - // -> [1, 1, k, n] - // QNN Conv2D Weight: - // [filter_height, filter_width, channel_in / group, channel_out] - // -> [1, 1, n, m] - bool is_supported = inputs[0].get().GetRank() == 3 && inputs.size() == 2 && - inputs[1].get().IsTensorStatic(); - if (!is_supported) { - QNN_LOG_INFO("[FullyConnected Optimization] FAILURE: Unsupported Input"); - return res; - } - - // TFLite FC -> QNN CONV2D: - // Reshape -> Conv2D -> Reshpae - TensorWrapper& input_tensor = inputs[0]; - TensorWrapper& weight_tensor = inputs[1]; - TensorWrapper& output_tensor = outputs[0]; - // Reshape - qnn::OpWrapper& reshape_op_1 = CreateOpWrapper(res, QNN_OP_RESHAPE); - reshape_op_1.AddInputTensor(input_tensor); - std::vector conv_input_dims = input_tensor.GetDims(); - conv_input_dims.insert(conv_input_dims.begin() + 1, 1); - qnn::TensorWrapper& conv_input_tensor = - tensor_pool.CloneNativeTensorFrom(input_tensor, conv_input_dims); - reshape_op_1.AddOutputTensor(conv_input_tensor); - // Conv2D Input, Weight, and Output - OpWrapper& conv_op = CreateOpWrapper(res, QNN_OP_CONV_2D); - conv_op.AddInputTensor(conv_input_tensor); - auto& quant_params = weight_tensor.GetQuantParams(); - if (std::holds_alternative( - quant_params)) { - auto& axis_quant_param = - std::get(quant_params); - axis_quant_param.SetAxis(3); - } - std::vector weight_dims{1, 1, weight_tensor.GetDim(1), - weight_tensor.GetDim(0)}; - size_t weight_bytes = weight_tensor.GetTensorBytes(); - const std::vector transpose_dim{weight_tensor.GetDim(0), 1, 1, - weight_tensor.GetDim(1)}; - TensorWrapper* weight; - if (weight_tensor.GetDataType() == QNN_DATATYPE_SFIXED_POINT_8) { - std::vector conv_weight; - auto fc_weight = weight_tensor.GetStaticTensorData(); - TransposeFromOHWIToHWIO(fc_weight.value(), transpose_dim, conv_weight); - weight = &(tensor_pool.CreateStaticTensor( - weight_tensor.GetDataType(), quant_params, weight_dims, weight_bytes, - conv_weight.data())); - } else if (weight_tensor.GetDataType() == QNN_DATATYPE_SFIXED_POINT_16) { - std::vector conv_weight; - auto fc_weight = weight_tensor.GetStaticTensorData(); - TransposeFromOHWIToHWIO(fc_weight.value(), transpose_dim, conv_weight); - weight = &(tensor_pool.CreateStaticTensor( - weight_tensor.GetDataType(), quant_params, weight_dims, weight_bytes, - conv_weight.data())); - } else if 
-  } else if (weight_tensor.GetDataType() == QNN_DATATYPE_UFIXED_POINT_16) {
-    std::vector<std::uint16_t> conv_weight;
-    auto fc_weight = weight_tensor.GetStaticTensorData<std::uint16_t>();
-    TransposeFromOHWIToHWIO(fc_weight.value(), transpose_dim, conv_weight);
-    weight = &(tensor_pool.CreateStaticTensor(
-        weight_tensor.GetDataType(), quant_params, weight_dims, weight_bytes,
-        conv_weight.data()));
-  } else if (weight_tensor.GetDataType() == QNN_DATATYPE_FLOAT_32) {
-    std::vector<float> conv_weight;
-    auto fc_weight = weight_tensor.GetStaticTensorData<float>();
-    TransposeFromOHWIToHWIO(fc_weight.value(), transpose_dim, conv_weight);
-    weight = &(tensor_pool.CreateStaticTensor(
-        weight_tensor.GetDataType(), quant_params, weight_dims, weight_bytes,
-        conv_weight.data()));
-  } else {
-    QNN_LOG_INFO(
-        "[FullyConnected Optimization] FAILURE: Unsupported Weight Datatype");
-    return {};
-  }
-  conv_op.AddInputTensor(*weight);
-  qnn::TensorWrapper& conv_out = tensor_pool.CloneNativeTensorFrom(
-      output_tensor, {conv_input_dims[0], conv_input_dims[1],
-                      conv_input_dims[2], weight_dims[3]});
-  conv_op.AddOutputTensor(conv_out);
-  // Conv2D Stride
-  const std::array<std::uint32_t, 2> stride_data{1, 1};
-  const std::vector<std::uint32_t> stride_shape{2};
-  auto& stride_tensor = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, stride_shape,
-      sizeof(std::uint32_t) * stride_data.size(), stride_data.data());
-  conv_op.AddTensorParam(QNN_OP_DEPTH_WISE_CONV_2D_PARAM_STRIDE, stride_tensor);
-  // Conv2D Padding
-  const std::array<std::uint32_t, 4> padding_data = {0, 0, 0, 0};
-  const std::vector<std::uint32_t> padding_shape{2, 2};
-  auto& padding_tensor = tensor_pool.CreateStaticTensor(
-      QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, padding_shape,
-      sizeof(std::uint32_t) * padding_data.size(), padding_data.data());
-  conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_PAD_AMOUNT, padding_tensor);
-
-  // Reshape
-  qnn::OpWrapper& reshape_op_2 = CreateOpWrapper(res, QNN_OP_RESHAPE);
-  reshape_op_2.AddInputTensor(conv_out);
-  reshape_op_2.AddOutputTensor(output_tensor);
-  return res;
-}
-
-}  // namespace qnn
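
Editorial sketch (illustrative, not part of the patch): the HTP rewrite above is pure shape bookkeeping around a 1x1 convolution; a [1, k, n] FC input and [m, n] weight become a [1, 1, k, n] conv input and [1, 1, n, m] filter, and the [1, 1, k, m] conv output is reshaped back. A minimal C++ illustration with assumed dimensions:

    // Hypothetical demo of the FC-as-1x1-Conv2D shape mapping.
    #include <cstdint>
    #include <iostream>

    int main() {
      const std::uint32_t k = 7, n = 16, m = 32;  // assumed FC dimensions
      const std::uint32_t conv_in[4] = {1, 1, k, n};   // reshaped input
      const std::uint32_t conv_w[4] = {1, 1, n, m};    // transposed weight
      const std::uint32_t conv_out[4] = {conv_in[0], conv_in[1], conv_in[2],
                                         conv_w[3]};   // {1, 1, k, m}
      std::cout << conv_out[2] << "x" << conv_out[3] << "\n";  // 7x32
    }
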
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h
deleted file mode 100644
index ccf8371fe9755e..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_HTP_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_HTP_H_
-
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildFullyConnectedOpHtp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const bool keep_num_dims);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_HTP_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.cc
deleted file mode 100644
index e20d7b31fad8bb..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildGatherOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const std::int32_t axis,
-    const std::int32_t batch_dims) {
-  std::vector<OpWrapper> res;
-
-  if (batch_dims != 0) {
-    QNN_LOG_ERROR("Gather op only supports batch_dims equal to 0.");
-    return res;
-  }
-
-  auto& gather_op = CreateOpWrapper(res, QNN_OP_GATHER);
-  for (const auto& input : inputs) {
-    gather_op.AddInputTensor(input);
-  }
-  for (const auto& output : outputs) {
-    gather_op.AddOutputTensor(output);
-  }
-  const std::int32_t adjusted_axis =
-      axis >= 0 ? axis : axis + inputs[0].get().GetRank();
-  gather_op.AddScalarParam<std::int32_t>(QNN_OP_GATHER_PARAM_AXIS,
-                                         adjusted_axis);
-
-  return res;
-}
-
-}  // namespace qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h
deleted file mode 100644
index 00b078c4f36e7e..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GATHER_OP_BUILDER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GATHER_OP_BUILDER_H_
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildGatherOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const std::int32_t axis,
-    const std::int32_t batch_dims);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GATHER_OP_BUILDER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.cc
deleted file mode 100644
index 8f382292b53bc5..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.cc
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h"
-
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildGeluOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  CreateSimpleActivationOp(res, QNN_OP_GELU, inputs[0], outputs[0]);
-
-  return res;
-}
-
-}  // namespace qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h
deleted file mode 100644
index 77a72154ee89a9..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h
+++ /dev/null
@@ -1,20 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GELU_OP_BUILDER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GELU_OP_BUILDER_H_
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildGeluOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GELU_OP_BUILDER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.cc
deleted file mode 100644
index be77996eb660e6..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-namespace {
-constexpr size_t kInputIndex = 0;
-constexpr size_t kOutputIndex = 0;
-}  // namespace
-
-std::vector<OpWrapper> BuildHardSwishOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs) {
-  std::vector<OpWrapper> res;
-
-  OpWrapper& hard_swish_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_NEURON);
-  hard_swish_op.AddInputTensor(inputs[kInputIndex]);
-  hard_swish_op.AddOutputTensor(outputs[kOutputIndex]);
-  hard_swish_op.AddScalarParam<std::uint32_t>(
-      QNN_OP_ELEMENT_WISE_NEURON_PARAM_OPERATION,
-      QNN_OP_ELEMENT_WISE_NEURON_OPERATION_HARD_SWISH);
-
-  return res;
-}
-
-}  // namespace qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h
deleted file mode 100644
index 9a0a6c3254d327..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_HARD_SWISH_OP_BUILDER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_HARD_SWISH_OP_BUILDER_H_
-
-#include <vector>
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-std::vector<OpWrapper> BuildHardSwishOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs);
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_HARD_SWISH_OP_BUILDER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.cc
deleted file mode 100644
index b6ece2ed343655..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.cc
+++ /dev/null
@@ -1,101 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h"
-
-#include <algorithm>
-#include <array>
-#include <cstddef>
-#include <cstdint>
-#include <variant>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnOpDef.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-namespace {
-constexpr size_t kInputIndex = 0;
-constexpr size_t kOutputIndex = 0;
-
-template <typename T>
-TensorWrapper& CreateAlphaTensor(
-    TensorPool& tensor_pool, const Qnn_DataType_t data_type,
-    const QuantizeParamsWrapperVariant& quant_param, const T alpha) {
-  const std::vector<std::uint32_t> alpha_dims{1};
-  const std::array alpha_data{alpha};
-  return tensor_pool.CreateStaticTensor(data_type, quant_param, alpha_dims,
-                                        sizeof(T) * alpha_data.size(),
-                                        alpha_data.data());
-}
-
-}  // namespace
-std::vector<OpWrapper> BuildLeakyReluOp(
-    TensorPool& tensor_pool, const std::vector<TensorWrapperRef>& inputs,
-    const std::vector<TensorWrapperRef>& outputs, const float alpha) {
-  std::vector<OpWrapper> res;
-
-  OpWrapper& leaky_relu_op = CreateOpWrapper(res, QNN_OP_PRELU);
-  TensorWrapper& input_tensor = inputs[kInputIndex];
-  leaky_relu_op.AddInputTensor(input_tensor);
-  leaky_relu_op.AddOutputTensor(outputs[kOutputIndex]);
-
-  if (std::holds_alternative<UndefinedQuantizeParamsWrapper>(
-          input_tensor.GetQuantParams())) {
-    TensorWrapper& alpha_tensor =
-        CreateAlphaTensor(tensor_pool, input_tensor.GetDataType(),
-                          input_tensor.GetQuantParams(), alpha);
-    leaky_relu_op.AddInputTensor(alpha_tensor);
-  } else if (std::holds_alternative<ScaleOffsetQuantizeParamsWrapper>(
-                 input_tensor.GetQuantParams())) {
-    QuantizeParamsWrapperVariant quant_param;
-    quant_param.emplace<ScaleOffsetQuantizeParamsWrapper>(
-        std::max(alpha, 0.0f), 0);
-
-    switch (input_tensor.GetDataType()) {
-      case QNN_DATATYPE_UFIXED_POINT_8: {
CreateAlphaTensor( - tensor_pool, input_tensor.GetDataType(), quant_param, 1); - leaky_relu_op.AddInputTensor(alpha_tensor); - break; - } - case QNN_DATATYPE_SFIXED_POINT_8: { - TensorWrapper& alpha_tensor = CreateAlphaTensor( - tensor_pool, input_tensor.GetDataType(), quant_param, 1); - leaky_relu_op.AddInputTensor(alpha_tensor); - break; - } - case QNN_DATATYPE_UFIXED_POINT_16: { - TensorWrapper& alpha_tensor = CreateAlphaTensor( - tensor_pool, input_tensor.GetDataType(), quant_param, 1); - leaky_relu_op.AddInputTensor(alpha_tensor); - break; - } - case QNN_DATATYPE_SFIXED_POINT_16: { - TensorWrapper& alpha_tensor = CreateAlphaTensor( - tensor_pool, input_tensor.GetDataType(), quant_param, 1); - leaky_relu_op.AddInputTensor(alpha_tensor); - break; - } - default: { - QNN_LOG_ERROR( - "Unsupported QNN data type when creating alpha tensor for " - "per-tensor quantization."); - break; - } - } - } else { - QNN_LOG_ERROR("Unsupported quantization type for LeakyRelu op."); - } - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h deleted file mode 100644 index 99f400a7285355..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_LEAKY_RELU_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_LEAKY_RELU_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildLeakyReluOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float alpha); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_LEAKY_RELU_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.cc deleted file mode 100644 index 9833c36fe71fec..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
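The quantized branch of BuildLeakyReluOp above encodes the slope without materializing it as float data: the static alpha tensor stores the raw value 1 with scale = max(alpha, 0) and offset 0, so dequantizing it yields alpha exactly. A standalone sketch of that encoding (variable names are illustrative):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// dequantize(raw) = scale * (raw - offset); with raw = 1 and offset = 0 the
// dequantized value is simply the scale, i.e. the LeakyRelu slope itself.
int main() {
  const float alpha = 0.2f;                   // example slope
  const float scale = std::max(alpha, 0.0f);  // per-tensor scale
  const std::int8_t raw = 1;                  // the single stored element
  const std::int32_t offset = 0;
  const float dequantized = scale * static_cast<float>(raw - offset);
  std::printf("alpha=%.3f recovered=%.3f\n", alpha, dequantized);
  return 0;
}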
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMatmulOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool adj_x, - const bool adj_y) { - std::vector res; - - auto& matmul_op = CreateOpWrapper(res, QNN_OP_MAT_MUL); - for (const auto& input : inputs) { - matmul_op.AddInputTensor(input); - } - matmul_op.AddOutputTensor(outputs[0]); - matmul_op.AddScalarParam(QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, adj_x); - matmul_op.AddScalarParam(QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, adj_y); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h deleted file mode 100644 index 40958ebb9c4db2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MATMUL_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MATMUL_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMatmulOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool adj_x, - const bool adj_y); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MATMUL_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.cc deleted file mode 100644 index 3495ee24efa6f9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMeanOp(TensorPool& tensor_pool, - const std::vector& inputs, - const std::vector& outputs, - const bool keep_dim) { - std::vector res; - - TensorWrapper& axis_tensor = inputs[1]; - if (!axis_tensor.IsTensorStatic() || axis_tensor.GetRank() != 1) { - QNN_LOG_ERROR( - "The axis tensor is not static, or the rank of axis tensor is not " - "equal to 1."); - return res; - } - - TensorWrapper& input_tensor = inputs[0]; - - auto axis_data = axis_tensor.GetStaticTensorData(); - if (!axis_data.has_value()) { - QNN_LOG_ERROR("Get axis_data failed."); - return res; - } - std::vector adjusted_axis_data; - for (size_t i = 0; i < axis_tensor.GetDim(0); ++i) { - std::uint32_t adjusted_axis = - (*axis_data)[i] >= 0 ? (*axis_data)[i] - : (*axis_data)[i] + input_tensor.GetRank(); - if (std::find(adjusted_axis_data.begin(), adjusted_axis_data.end(), - adjusted_axis) == adjusted_axis_data.end()) { - adjusted_axis_data.emplace_back(adjusted_axis); - } - } - TensorWrapper& adjusted_axis_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, axis_tensor.GetQuantParams(), - {static_cast(adjusted_axis_data.size())}, - sizeof(std::uint32_t) * adjusted_axis_data.size(), - adjusted_axis_data.data()); - - auto& reduce_op = CreateOpWrapper(res, QNN_OP_REDUCE_MEAN); - reduce_op.AddInputTensor(input_tensor); - reduce_op.AddOutputTensor(outputs[0]); - reduce_op.AddTensorParam(QNN_OP_REDUCE_MEAN_PARAM_AXES, adjusted_axis_tensor); - reduce_op.AddScalarParam(QNN_OP_REDUCE_MEAN_PARAM_KEEP_DIMS, keep_dim); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h deleted file mode 100644 index 50127647c90c10..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
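BuildMeanOp above normalizes the TFLite axis tensor before handing it to QNN: negative axes are wrapped by the input rank, and duplicates are dropped while preserving first-seen order (the same normalization recurs in BuildReduceSumOp later in this patch). A standalone sketch:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Wrap negative axes by the input rank and de-duplicate, as the builder does
// when constructing the adjusted axis tensor for QNN ReduceMean.
std::vector<std::uint32_t> AdjustAxes(const std::vector<std::int32_t>& axes,
                                      std::uint32_t rank) {
  std::vector<std::uint32_t> out;
  for (std::int32_t a : axes) {
    const std::uint32_t adjusted = a >= 0 ? a : a + rank;  // mirrors the diff
    if (std::find(out.begin(), out.end(), adjusted) == out.end())
      out.push_back(adjusted);
  }
  return out;
}

int main() {
  for (std::uint32_t a : AdjustAxes({-1, 1, 2, -2}, /*rank=*/4))
    std::printf("%u ", a);  // prints: 3 1 2
  return 0;
}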
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MEAN_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MEAN_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMeanOp(TensorPool& tensor_pool, - const std::vector& inputs, - const std::vector& outputs, - const bool keep_dims); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MEAN_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.cc deleted file mode 100644 index 8687927d3875b2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.cc +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::pair ComputePaddingBeforeAfter( - const std::uint32_t input_size, const std::uint32_t filter_size, - const std::uint32_t stride, const std::uint32_t dilation_rate, - const PaddingType padding_type) { - // padding_before, padding_after - std::pair result{0, 0}; - if (stride == 0) { - QNN_LOG_ERROR("Stride is 0"); - return result; - } - - std::uint32_t output_size{}; - std::uint32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1; - - switch (padding_type) { - case PaddingType::Same: - output_size = (input_size + stride - 1) / stride; - break; - case PaddingType::Valid: - output_size = (input_size + stride - effective_filter_size) / stride; - break; - default: // PaddingType::Unknown - QNN_LOG_ERROR("Unknown padding type"); - return result; - } - - std::uint32_t total_padding = - (output_size - 1) * stride + effective_filter_size - input_size; - result.first = total_padding / 2; - result.second = result.first + total_padding % 2; - return result; -} - -OpWrapper& CreateOpWrapper(std::vector& ops, const char* op_type) { - const auto op_count = ops.size(); - const auto name = "op_type_" + std::string(op_type) + "_op_count_" + - std::to_string(op_count); - return ops.emplace_back(std::move(name), op_type); -} - -OpWrapper& CreateSimpleActivationOp(std::vector& ops, - const char* op_type, - const TensorWrapper& input_tensor, - const TensorWrapper& output_tensor) { - auto& ret = CreateOpWrapper(ops, op_type); - ret.AddInputTensor(input_tensor); - ret.AddOutputTensor(output_tensor); - return ret; -} - -/* -LiteRtStatus OpMapper::AddFusedActivationNode( - const tflite::ActivationFunctionType activation, - const TensorWrapper& input_tensor, const TensorWrapper& output_tensor) { - switch (activation) { - case tflite::ActivationFunctionType_RELU: { - OpWrapper& activation_op = - CreateSimpleActivationOp(QNN_OP_RELU, input_tensor, output_tensor); - break; - } - case 
tflite::ActivationFunctionType_RELU_N1_TO_1: { - OpWrapper& activation_op = CreateSimpleActivationOp( - QNN_OP_RELU_MIN_MAX, input_tensor, output_tensor); - activation_op.AddScalarParam(QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, - -1.f); - activation_op.AddScalarParam(QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, - 1.f); - break; - } - case tflite::ActivationFunctionType_RELU6: { - OpWrapper& activation_op = CreateSimpleActivationOp( - QNN_OP_RELU_MIN_MAX, input_tensor, output_tensor); - activation_op.AddScalarParam(QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, - 0.f); - activation_op.AddScalarParam(QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, - 6.f); - break; - } - case tflite::ActivationFunctionType_TANH: { - OpWrapper& activation_op = - CreateSimpleActivationOp(QNN_OP_TANH, input_tensor, output_tensor); - break; - } - default: - return kLiteRtStatusErrorUnsupported; - } - - return kLiteRtStatusOk; -} -*/ - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h deleted file mode 100644 index 2888c3e84262c2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_OP_BUILDER_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -enum class PaddingType { - Unknown = 0, - Same, - Valid, -}; - -std::pair ComputePaddingBeforeAfter( - const std::uint32_t input_size, const std::uint32_t filter_size, - const std::uint32_t stride, const std::uint32_t dilation_rate, - const PaddingType padding_type); - -OpWrapper& CreateOpWrapper(std::vector& ops, const char* op_type); - -OpWrapper& CreateSimpleActivationOp(std::vector& ops, - const char* op_type, - const TensorWrapper& input_tensor, - const TensorWrapper& output_tensor); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.cc deleted file mode 100644 index 97dc4c5c9561b7..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
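ComputePaddingBeforeAfter above reproduces TFLite's SAME/VALID convention: for SAME padding the output size is ceil(input / stride), and whatever padding is needed to realize it is split evenly, with the odd element going to the trailing side. A standalone sketch of the SAME case:

#include <cstdint>
#include <cstdio>
#include <utility>

// SAME-padding arithmetic: output = ceil(input / stride); the total padding
// required is split as (total / 2, total / 2 + total % 2).
std::pair<std::uint32_t, std::uint32_t> SamePadding(std::uint32_t input,
                                                    std::uint32_t filter,
                                                    std::uint32_t stride,
                                                    std::uint32_t dilation) {
  const std::uint32_t effective = (filter - 1) * dilation + 1;
  const std::uint32_t output = (input + stride - 1) / stride;
  const std::uint32_t total = (output - 1) * stride + effective - input;
  return {total / 2, total / 2 + total % 2};
}

int main() {
  const auto [before, after] = SamePadding(/*input=*/224, /*filter=*/3,
                                           /*stride=*/2, /*dilation=*/1);
  std::printf("pad before=%u after=%u\n", before, after);  // 0 and 1
  return 0;
}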
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildPackOp(TensorPool& tensor_pool, - const std::vector& inputs, - const std::vector& outputs, - const int32_t axis) { - std::vector res; - - // pack op with only one input would violate op definition of qnn - // we'll replace it with reshape op - if (inputs.size() == 1) { - auto& op = CreateOpWrapper(res, QNN_OP_RESHAPE); - op.AddInputTensor(inputs[0]); - op.AddOutputTensor(outputs[0]); - return res; - } - - if (outputs[0].get().GetRank() != inputs[0].get().GetRank() + 1) { - auto& concat_op = CreateOpWrapper(res, QNN_OP_CONCAT); - for (const auto& input : inputs) { - concat_op.AddInputTensor(input); - } - concat_op.AddOutputTensor(outputs[0]); - } else { - auto& pack_op = CreateOpWrapper(res, QNN_OP_PACK); - for (const auto& input : inputs) { - pack_op.AddInputTensor(input); - } - std::uint32_t adjusted_axis = - axis < 0 ? axis + inputs[0].get().GetRank() : axis; - pack_op.AddScalarParam(QNN_OP_PACK_PARAM_AXIS, - adjusted_axis); - pack_op.AddOutputTensor(outputs[0]); - } - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h deleted file mode 100644 index b0e39cc74ccd2f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_PACK_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_PACK_OP_BUILDER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildPackOp(TensorPool& tensor_pool, - const std::vector& inputs, - const std::vector& outputs, - const int32_t axis); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_PACK_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.cc deleted file mode 100644 index b4b42d743a0cf1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.cc +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { - -constexpr size_t kInputIndex = 0; -constexpr size_t kOutputIndex = 0; -constexpr size_t kHeightIndex = 1; -constexpr size_t kWidthIndex = 2; - -std::vector BuildPool2dOp( - TensorPool& tensor_pool, const char* op_type, const char* filter_param_name, - const char* stride_param_name, const char* padding_param_name, - const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const std::uint32_t filter_width, - const PaddingType padding_type) { - std::vector res; - - OpWrapper& pool_op = CreateOpWrapper(res, op_type); - - TensorWrapper& input_tensor = inputs[kInputIndex]; - pool_op.AddInputTensor(input_tensor); - - // filter param - const std::vector filter_shape{2}; - const std::array filter_data{filter_height, filter_width}; - TensorWrapper& filter_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, filter_shape, - sizeof(decltype(filter_data)::value_type) * filter_data.size(), - filter_data.data()); - pool_op.AddTensorParam(filter_param_name, filter_tensor); - - // stride param - const std::vector stride_shape{2}; - const std::array stride_data{stride_height, stride_width}; - TensorWrapper& stride_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, stride_shape, - sizeof(decltype(stride_data)::value_type) * stride_data.size(), - stride_data.data()); - pool_op.AddTensorParam(stride_param_name, stride_tensor); - - // padding - const auto [padding_before_height, padding_after_height] = - ComputePaddingBeforeAfter(input_tensor.GetDim(kHeightIndex), - filter_height, stride_height, 1, padding_type); - const auto [padding_before_width, padding_after_width] = - ComputePaddingBeforeAfter(input_tensor.GetDim(kWidthIndex), filter_width, - stride_width, 1, padding_type); - const std::vector padding_shape{2, 2}; - const std::array padding_data{ - padding_before_height, padding_after_height, padding_before_width, - padding_after_width}; - TensorWrapper& padding_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, padding_shape, - sizeof(decltype(padding_data)::value_type) * padding_data.size(), - padding_data.data()); - pool_op.AddTensorParam(padding_param_name, padding_tensor); - - TensorWrapper& output_tensor = outputs[kOutputIndex]; - pool_op.AddOutputTensor(output_tensor); - // TODO: fused activation - - return res; -} - -} // namespace - -std::vector BuildMaxPoolOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const 
std::uint32_t filter_width, - const PaddingType padding_type) { - return BuildPool2dOp( - tensor_pool, QNN_OP_POOL_MAX_2D, QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE, - QNN_OP_POOL_MAX_2D_PARAM_STRIDE, QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT, - inputs, outputs, stride_height, stride_width, filter_height, filter_width, - padding_type); -} - -std::vector BuildAveragePoolOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const std::uint32_t filter_width, - const PaddingType padding_type) { - return BuildPool2dOp( - tensor_pool, QNN_OP_POOL_AVG_2D, QNN_OP_POOL_AVG_2D_PARAM_FILTER_SIZE, - QNN_OP_POOL_AVG_2D_PARAM_STRIDE, QNN_OP_POOL_AVG_2D_PARAM_PAD_AMOUNT, - inputs, outputs, stride_height, stride_width, filter_height, filter_width, - padding_type); -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h deleted file mode 100644 index cb8da0e7a19589..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_POOL2D_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_POOL2D_OP_BUILDER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMaxPoolOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const std::uint32_t filter_width, - const PaddingType padding_type); - -std::vector BuildAveragePoolOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const std::uint32_t filter_width, - const PaddingType padding_type); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_POOL2D_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.cc deleted file mode 100644 index 70c4b610336118..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.cc +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
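The shared pool builder above packs its parameters as small static tensors: filter_size and stride are 2-element [height, width] tensors, and pad_amount is a 2x2 tensor whose rows give (before, after) padding for height and width respectively. A sketch of the pad_amount layout (values are examples):

#include <array>
#include <cstdint>
#include <cstdio>

// Row 0: height padding (before, after); row 1: width padding (before, after).
int main() {
  const std::array<std::array<std::uint32_t, 2>, 2> pad_amount{{
      {0u, 1u},  // height
      {0u, 1u},  // width
  }};
  for (const auto& row : pad_amount) std::printf("%u %u\n", row[0], row[1]);
  return 0;
}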
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildQuantizeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - const char* qnn_op = nullptr; - if (inputs[0].get().IsPerTensorQuantWithOffsetDiff(outputs[0].get())) { - qnn_op = QNN_OP_CAST; - } else if ((inputs[0].get().IsQuant8() || inputs[0].get().IsQuant16()) && - (outputs[0].get().IsQuant8() || outputs[0].get().IsQuant16())) { - qnn_op = QNN_OP_CONVERT; - } else { - qnn_op = QNN_OP_QUANTIZE; - } - - auto& quantize_op = CreateOpWrapper(res, qnn_op); - quantize_op.AddInputTensor(inputs[0]); - quantize_op.AddOutputTensor(outputs[0]); - - return res; -} - -std::vector BuildDequantizeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - const char* qnn_op = nullptr; - if (inputs[0].get().IsF16() && outputs[0].get().IsF32()) { - qnn_op = QNN_OP_CAST; - } else { - qnn_op = QNN_OP_DEQUANTIZE; - } - - auto& quantize_op = CreateOpWrapper(res, qnn_op); - quantize_op.AddInputTensor(inputs[0]); - quantize_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h deleted file mode 100644 index 2b2cfd923202bc..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_QUANTIZE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_QUANTIZE_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildQuantizeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildDequantizeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_QUANTIZE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.cc deleted file mode 100644 index b978f10450213b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReduceSumOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool keep_dims) { - std::vector res; - - TensorWrapper& axis_tensor = inputs[1]; - if (!axis_tensor.IsTensorStatic() || axis_tensor.GetRank() != 1) { - QNN_LOG_ERROR( - "The axis tensor is not static, or the rank of axis tensor is not " - "equal to 1."); - return res; - } - - TensorWrapper& input_tensor = inputs[0]; - - auto axis_data = axis_tensor.GetStaticTensorData(); - if (!axis_data.has_value()) { - QNN_LOG_ERROR("Get axis_data failed."); - return res; - } - std::vector adjusted_axis_data; - for (size_t i = 0; i < axis_tensor.GetDim(0); ++i) { - std::uint32_t adjusted_axis = - (*axis_data)[i] >= 0 ? (*axis_data)[i] - : (*axis_data)[i] + input_tensor.GetRank(); - if (std::find(adjusted_axis_data.begin(), adjusted_axis_data.end(), - adjusted_axis) == adjusted_axis_data.end()) { - adjusted_axis_data.emplace_back(adjusted_axis); - } - } - TensorWrapper& adjusted_axis_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, axis_tensor.GetQuantParams(), - {static_cast(adjusted_axis_data.size())}, - sizeof(std::uint32_t) * adjusted_axis_data.size(), - adjusted_axis_data.data()); - - OpWrapper& reduce_op = CreateOpWrapper(res, QNN_OP_REDUCE_SUM); - reduce_op.AddInputTensor(input_tensor); - reduce_op.AddOutputTensor(outputs[0]); - reduce_op.AddTensorParam(QNN_OP_REDUCE_SUM_PARAM_AXES, adjusted_axis_tensor); - reduce_op.AddScalarParam(QNN_OP_REDUCE_SUM_PARAM_KEEP_DIMS, keep_dims); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h deleted file mode 100644 index cb43106587d91e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_REDUCE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_REDUCE_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReduceSumOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool keep_dims); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_REDUCE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.cc deleted file mode 100644 index ed9330211bf41f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.cc +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildRelu6Op( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - CreateSimpleActivationOp(res, QNN_OP_RELU6, inputs[0], outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.h deleted file mode 100644 index 6261da7fd1b80d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU6_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU6_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildRelu6Op( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU6_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.cc deleted file mode 100644 index bfbfb37c8dd247..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.cc +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReluOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - CreateSimpleActivationOp(res, QNN_OP_RELU, inputs[0], outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.h deleted file mode 100644 index 3d2d5da8f2fa7a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReluOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.cc deleted file mode 100644 index a51711dfb5ac59..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.cc +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReshapeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& reshape_op = CreateOpWrapper(res, QNN_OP_RESHAPE); - reshape_op.AddInputTensor(inputs[0]); - reshape_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h deleted file mode 100644 index 6b14ad38bbd01b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESHAPE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESHAPE_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReshapeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESHAPE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.cc deleted file mode 100644 index c0a1f173b423b0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr size_t kInputIndex = 0; -constexpr size_t kOutputIndex = 0; - -std::vector BuildResizeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const char* op_type, - const char* align_corners_param, const char* half_pixel_centers_param, - const bool align_corners, const bool half_pixel_centers) { - std::vector res; - - auto& resize_op = CreateOpWrapper(res, op_type); - resize_op.AddInputTensor(inputs[kInputIndex]); - resize_op.AddOutputTensor(outputs[kOutputIndex]); - resize_op.AddScalarParam(align_corners_param, align_corners); - resize_op.AddScalarParam(half_pixel_centers_param, half_pixel_centers); - - return res; -} -} // namespace - -std::vector BuildResizeBilinearOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool align_corners, - const bool half_pixel_centers) { - return BuildResizeOp(tensor_pool, inputs, outputs, QNN_OP_RESIZE_BILINEAR, - QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS, - QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS, - align_corners, half_pixel_centers); -} - -std::vector BuildResizeNearestOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool align_corners, - const bool half_pixel_centers) { - return BuildResizeOp(tensor_pool, inputs, outputs, - QNN_OP_RESIZE_NEAREST_NEIGHBOR, - QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_ALIGN_CORNERS, - QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_HALF_PIXEL_CENTERS, - align_corners, half_pixel_centers); -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h deleted file mode 100644 index c24e889ee9f0a2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
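The two scalar params forwarded above control the source-coordinate mapping of both resize variants. Assuming the standard TFLite semantics carry over to the QNN ops, the mapping from a destination index to a source coordinate is sketched below (standalone, illustrative):

#include <cstdio>

// align_corners:      src = dst * (in - 1) / (out - 1)
// half_pixel_centers: src = (dst + 0.5) * in / out - 0.5
// default:            src = dst * in / out
float SourceCoord(int dst, int in, int out, bool align, bool half_pixel) {
  if (align && out > 1) return dst * float(in - 1) / float(out - 1);
  if (half_pixel) return (dst + 0.5f) * in / out - 0.5f;
  return dst * float(in) / float(out);
}

int main() {
  std::printf("%.3f\n",
              SourceCoord(3, 4, 8, false, /*half_pixel=*/true));  // 1.250
  return 0;
}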
-// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESIZE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESIZE_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildResizeBilinearOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool align_corners, - const bool half_pixel_centers); - -std::vector BuildResizeNearestOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool align_corners, - const bool half_pixel_centers); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESIZE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.cc deleted file mode 100644 index fc88f639b76684..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.cc +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -static constexpr int kInputIndex = 0; -static constexpr int kGammaIndex = 1; - -std::vector BuildRmsNormOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float epsilon) { - std::vector res; - - auto& rms_norm_op = CreateOpWrapper(res, QNN_OP_RMS_NORM); - for (const auto& input : inputs) { - rms_norm_op.AddInputTensor(input); - } - - // Constructs axis param tensor. 
- std::vector axis_data; - axis_data.emplace_back(inputs[kInputIndex].get().GetRank() - 1); - TensorWrapper& axis_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, inputs[kInputIndex].get().GetQuantParams(), {1}, - sizeof(std::uint32_t) * axis_data.size(), axis_data.data()); - - if (inputs[kGammaIndex].get().GetDataType() == QNN_DATATYPE_FLOAT_32) { - // Construct float beta static all 0 tensor. - std::vector beta_data( - inputs[kGammaIndex].get().GetTensorNumElements(), 0); - TensorWrapper& beta_tensor = tensor_pool.CreateStaticTensor( - inputs[kGammaIndex].get().GetDataType(), - inputs[kGammaIndex].get().GetQuantParams(), - inputs[kGammaIndex].get().GetDims(), sizeof(float) * beta_data.size(), - beta_data.data()); - rms_norm_op.AddInputTensor(beta_tensor); - } else { - // Construct uint8_t beta static all 0 tensor. - std::vector beta_data( - inputs[kGammaIndex].get().GetTensorNumElements(), 0); - - // Offset needs to be 0, scale does not matter since data is 0 - ScaleOffsetQuantizeParamsWrapper q_param(0.00001, 0); - - TensorWrapper& beta_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UFIXED_POINT_8, q_param, - inputs[kGammaIndex].get().GetDims(), sizeof(uint8_t) * beta_data.size(), - beta_data.data()); - rms_norm_op.AddInputTensor(beta_tensor); - } - - rms_norm_op.AddScalarParam(QNN_OP_RMS_NORM_PARAM_EPSILON, epsilon); - rms_norm_op.AddTensorParam(QNN_OP_RMS_NORM_PARAM_AXES, axis_tensor); - rms_norm_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h deleted file mode 100644 index f97e35fd58717e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
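QNN's RmsNorm expects a beta input that the TFLite op does not have, so BuildRmsNormOp above synthesizes an all-zero beta matching gamma's shape. For the quantized branch the chosen scale is immaterial, exactly as the comment in the diff says: with offset 0, a stored raw 0 dequantizes to 0 for any scale. A standalone check:

#include <cstdint>
#include <cstdio>
#include <vector>

// dequantize(0) = scale * (0 - 0) = 0 regardless of scale, so the op sees
// beta == 0 whether the tensor is float or quantized.
int main() {
  const float scale = 0.00001f;  // matches the wrapper above; value is moot
  const std::vector<std::uint8_t> beta_data(8, 0);
  for (std::uint8_t raw : beta_data)
    std::printf("%.1f ", scale * (raw - 0));  // all zeros
  return 0;
}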
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RMS_NORM_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RMS_NORM_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildRmsNormOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float epsilon); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RMS_NORM_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.cc deleted file mode 100644 index 3312ae3d2e8d96..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSelectOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& select_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SELECT); - for (const auto& input : inputs) { - select_op.AddInputTensor(input); - } - select_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h deleted file mode 100644 index e5a4431f99ddb9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SELECT_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SELECT_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSelectOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SELECT_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.cc deleted file mode 100644 index af94e9ab833e22..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h" - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr int kDefaultStrideValue = 1; -constexpr int kSizeNegative = -1; -constexpr int kRangeNumElements = 3; -} // namespace - -std::vector BuildSliceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - TensorWrapper& input_tensor = inputs[0]; - TensorWrapper& begin_tensor = inputs[1]; - TensorWrapper& size_tensor = inputs[2]; - if (!begin_tensor.IsTensorStatic() || !size_tensor.IsTensorStatic()) { - QNN_LOG_ERROR( - "The begin tensor and size tensor of Slice OP is not static."); - return res; - } - - const auto input_rank = input_tensor.GetRank(); - auto begin_data = begin_tensor.GetStaticTensorData(); - if (!begin_data.has_value()) { - QNN_LOG_ERROR("Get begin_data failed."); - return res; - } - auto size_data = size_tensor.GetStaticTensorData(); - if (!size_data.has_value()) { - QNN_LOG_ERROR("Get size_data failed."); - return res; - } - std::vector range_data; - range_data.reserve(input_rank * kRangeNumElements); - for (size_t i = 0; i < input_rank; ++i) { - range_data.emplace_back((*begin_data)[i]); - if ((*size_data)[i] == kSizeNegative) { - range_data.emplace_back(input_tensor.GetDim(i)); - } else { - range_data.emplace_back((*begin_data)[i] + (*size_data)[i]); - } - range_data.emplace_back(kDefaultStrideValue); - } - TensorWrapper& range_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_INT_32, begin_tensor.GetQuantParams(), - {input_rank, kRangeNumElements}, sizeof(std::int32_t) * range_data.size(), - 
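BuildSliceOp above lowers TFLite Slice to QNN StridedSlice by rewriting the static (begin, size) pair into a [rank x 3] ranges tensor of (begin, end, stride) rows, where size == -1 means "through the end of that dimension" and the stride is always 1. A standalone sketch:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// One (begin, end, stride) triple per input dimension, flattened row-major.
std::vector<std::int32_t> BuildRanges(const std::vector<std::int32_t>& begin,
                                      const std::vector<std::int32_t>& size,
                                      const std::vector<std::int32_t>& dims) {
  std::vector<std::int32_t> ranges;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    ranges.push_back(begin[i]);
    ranges.push_back(size[i] == -1 ? dims[i] : begin[i] + size[i]);
    ranges.push_back(1);  // default stride
  }
  return ranges;
}

int main() {
  for (std::int32_t v : BuildRanges({1, 0}, {2, -1}, {4, 8}))
    std::printf("%d ", v);  // 1 3 1 0 8 1
  return 0;
}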
range_data.data()); - - auto& slice_op = CreateOpWrapper(res, QNN_OP_STRIDED_SLICE); - slice_op.AddTensorParam(QNN_OP_STRIDED_SLICE_PARAM_RANGES, range_tensor); - slice_op.AddInputTensor(input_tensor); - slice_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h deleted file mode 100644 index 7eb9c013dcccfd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SLICE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SLICE_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSliceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SLICE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.cc deleted file mode 100644 index 5d3e226b011846..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSoftmaxOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float beta) { - std::vector res; - - auto& softmax_op = CreateOpWrapper(res, QNN_OP_SOFTMAX); - softmax_op.AddInputTensor(inputs[0]); - softmax_op.AddOutputTensor(outputs[0]); - softmax_op.AddScalarParam(QNN_OP_SOFTMAX_PARAM_BETA, beta); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h deleted file mode 100644 index bac0ea1c0d76d1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
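The beta scalar forwarded above scales the logits before normalization: QNN's Softmax computes softmax(beta * x), so TFLite's beta (1.0 in the common case) carries over directly. A standalone sketch of the computation:

#include <cmath>
#include <cstdio>

// softmax(beta * x)_i = exp(beta * x_i) / sum_j exp(beta * x_j)
int main() {
  const float x[3] = {1.0f, 2.0f, 3.0f};
  const float beta = 1.0f;  // plain softmax
  float denom = 0.0f;
  for (float v : x) denom += std::exp(beta * v);
  for (float v : x) std::printf("%.4f ", std::exp(beta * v) / denom);
  return 0;
}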
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SOFTMAX_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SOFTMAX_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSoftmaxOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float beta); -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SOFTMAX_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.cc deleted file mode 100644 index 9f77d75c18ae01..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.cc +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr size_t kInputIndex = 0; -constexpr size_t kOutputIndex = 0; - -std::vector BuildSpatialTransformOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const char* op_type, - const char* block_param, const std::uint32_t block_size) { - std::vector res; - - auto& spatial_transform_op = CreateOpWrapper(res, op_type); - spatial_transform_op.AddInputTensor(inputs[kInputIndex]); - spatial_transform_op.AddOutputTensor(outputs[kOutputIndex]); - const std::array block_data = {block_size, block_size}; - const std::vector block_dims{2}; - auto& block_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, block_dims, - sizeof(decltype(block_dims)::value_type) * block_dims.size(), - block_data.data()); - spatial_transform_op.AddTensorParam(block_param, block_tensor); - - return res; -} -} // namespace - -std::vector BuildDepthToSpaceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t block_size) { - return BuildSpatialTransformOp( - tensor_pool, inputs, outputs, QNN_OP_DEPTH_TO_SPACE, - QNN_OP_DEPTH_TO_SPACE_PARAM_BLOCK_SIZE, block_size); -} - -std::vector BuildSpaceToDepthOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t block_size) { - return BuildSpatialTransformOp( - tensor_pool, inputs, outputs, 
QNN_OP_SPACE_TO_DEPTH, - QNN_OP_SPACE_TO_DEPTH_PARAM_BLOCK_SIZE, block_size); -} -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h deleted file mode 100644 index c2e7c5e19c68fd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPATIAL_TRANSFORM_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPATIAL_TRANSFORM_OP_BUILDER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildDepthToSpaceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t block_size); - -std::vector BuildSpaceToDepthOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t block_size); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPATIAL_TRANSFORM_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.cc deleted file mode 100644 index 4bdb6322bd0e0b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr int kSplitIndexRank = 1; -constexpr int kinputAxisIndex = 0; -} // namespace - -std::vector BuildSplitOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t num_splits) { - std::vector res; - - const TensorWrapper& axis_tensor = inputs[kinputAxisIndex]; - if (!axis_tensor.IsTensorStatic()) { - return res; - } - - const TensorWrapper& input_tensor = inputs[kSplitIndexRank]; - auto axis_data = axis_tensor.GetStaticTensorData(); - if (!axis_data.has_value()) { - QNN_LOG_ERROR("Get axis_data failed."); - return res; - } - std::uint32_t axis = (*axis_data)[0] >= 0 - ? 
(*axis_data)[0] - : (*axis_data)[0] + input_tensor.GetRank(); - - const std::uint32_t slice_size = input_tensor.GetDim(axis) / num_splits; - // The split_indice will do N cuts, split the dimension into N+1 clips - // so 0 will not be included in the split_indice - // for example, when we split 12 into 4 clip, the split index will be {3,6,9} - std::vector split_indice; - split_indice.reserve(num_splits); - for (int i = 1; i < num_splits; i++) { - split_indice.emplace_back(static_cast(i * slice_size)); - } - TensorWrapper& split_indice_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, axis_tensor.GetQuantParams(), {num_splits - 1}, - sizeof(std::uint32_t) * split_indice.size(), split_indice.data()); - - auto& split_op = CreateOpWrapper(res, QNN_OP_SPLIT); - split_op.AddInputTensor(input_tensor); - for (const auto& output : outputs) { - split_op.AddOutputTensor(output); - } - split_op.AddScalarParam(QNN_OP_SPLIT_PARAM_AXIS, axis); - split_op.AddTensorParam(QNN_OP_SPLIT_PARAM_SPLIT_INDEX, split_indice_tensor); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h deleted file mode 100644 index 76fafd15cba35c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPLIT_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPLIT_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSplitOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t num_splits); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPLIT_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.cc deleted file mode 100644 index 221ebf796c52e0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
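
The split-index rule in the deleted BuildSplitOp above deserves a worked example: N splits along an axis of extent D produce N - 1 cut points at multiples of D / N, and 0 is never emitted. A self-contained sketch with illustrative names (the loop counter here is unsigned, sidestepping the signed/unsigned comparison in the deleted version):

#include <cstdint>
#include <vector>

// N-way split along an axis of extent `dim`: emit N - 1 cut points.
std::vector<std::uint32_t> SplitIndices(std::uint32_t dim,
                                        std::uint32_t num_splits) {
  const std::uint32_t slice_size = dim / num_splits;
  std::vector<std::uint32_t> indices;
  indices.reserve(num_splits - 1);
  for (std::uint32_t i = 1; i < num_splits; ++i) {
    indices.push_back(i * slice_size);
  }
  return indices;
}

// SplitIndices(12, 4) == {3, 6, 9}, matching the comment in BuildSplitOp.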
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildTanhOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - CreateSimpleActivationOp(res, QNN_OP_TANH, inputs[0], outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h deleted file mode 100644 index 1ede3ba202baf3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TANH_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TANH_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildTanhOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TANH_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.cc deleted file mode 100644 index 5f1415ffadf8fd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildTransposeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - TensorWrapper& perm_tensor = inputs[1]; - if (!perm_tensor.IsTensorStatic()) { - QNN_LOG_ERROR("The param 'perm' of Transpose OP is not static."); - return res; - } - - auto& transpose_op = CreateOpWrapper(res, QNN_OP_TRANSPOSE); - transpose_op.AddInputTensor(inputs[0]); - transpose_op.AddOutputTensor(outputs[0]); - transpose_op.AddTensorParam( - QNN_OP_TRANSPOSE_PARAM_PERM, - tensor_pool.CloneStaticTensorFrom(perm_tensor, QNN_DATATYPE_UINT_32)); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h deleted file mode 100644 index 7f32710f29b309..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TRANSPOSE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TRANSPOSE_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildTransposeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TRANSPOSE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h deleted file mode 100644 index 7fd072eaff9b1d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_COMMON_H_ - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum LiteRtQnnLogLevel { // NOLINT(modernize-use-using) - /// Disable delegate and QNN backend logging messages. 
- kLogOff = 0, - kLogLevelError = 1, - kLogLevelWarn = 2, - kLogLevelInfo = 3, - kLogLevelVerbose = 4, - kLogLevelDebug = 5, -} LiteRtQnnLogLevel; - -typedef struct { // NOLINT(modernize-use-using) - /// Apply HTP-friendly op builder. - bool useHtpPreferencs; - /// This option will treat quantized int16 tensor as quantized uint16 tensor - /// for better backend compatibility. - bool useQInt16AsQUint16; -} LiteRtQnnOptions; - -// clang-format off -#define LITERT_QNN_OPTIONS_INIT \ - { \ - false, /*useHtpPreferencs*/ \ - true, /*useQInt16AsQUint16*/ \ - } -// clang-format on -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.cc deleted file mode 100644 index 27cce37e3f2c4d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -TensorPool::TensorPool() = default; - -TensorWrapper& TensorPool::CreateInputTensor( - Qnn_DataType_t data_type, const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_APP_WRITE, data_type, quant_params, dimentions); - return back; -} - -TensorWrapper& TensorPool::CreateOutpuTensor( - Qnn_DataType_t data_type, const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_APP_READ, data_type, quant_params, dimentions); - return back; -} - -TensorWrapper& TensorPool::CreateNativeTensor( - Qnn_DataType_t data_type, const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_NATIVE, data_type, quant_params, dimentions); - return back; -} - -TensorWrapper& TensorPool::CreateStaticTensor( - Qnn_DataType_t data_type, const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions, std::uint32_t bytes, - const void* data) { - const auto id = tensor_wrappers_.size(); - auto& back = - tensor_wrappers_.emplace_back(id, QNN_TENSOR_TYPE_STATIC, data_type, - quant_params, dimentions, bytes, data); - return back; -} - -TensorWrapper& TensorPool::CloneNativeTensorFrom(const TensorWrapper& src) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_NATIVE, src.GetDataType(), src.quantize_params_, - src.dimentions_); - return back; -} - -TensorWrapper& TensorPool::CloneNativeTensorFrom( - const TensorWrapper& src, const std::vector& dimentions) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back(id, QNN_TENSOR_TYPE_NATIVE, - src.GetDataType(), - src.quantize_params_, 
dimentions);
-  return back;
-}
-
-TensorWrapper& TensorPool::CloneStaticTensorFrom(const TensorWrapper& src,
-                                                 Qnn_DataType_t data_type) {
-  const auto id = tensor_wrappers_.size();
-  auto& back = tensor_wrappers_.emplace_back(
-      id, QNN_TENSOR_TYPE_STATIC, data_type, src.quantize_params_,
-      src.dimentions_, src.owned_data_.size(), src.owned_data_.data());
-  return back;
-}
-
-TensorWrapper& TensorPool::CloneStaticTensorFrom(
-    const TensorWrapper& src, const std::vector<std::uint32_t>& dimentions) {
-  const auto id = tensor_wrappers_.size();
-  auto& back = tensor_wrappers_.emplace_back(
-      id, QNN_TENSOR_TYPE_STATIC, src.qnn_tensor_.v2.dataType,
-      src.quantize_params_, dimentions, src.qnn_tensor_.v2.clientBuf.dataSize,
-      src.qnn_tensor_.v2.clientBuf.data);
-
-  return back;
-}
-
-}  // namespace qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h
deleted file mode 100644
index a21199ad2e40c2..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_TENSOR_POOL_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_TENSOR_POOL_H_
-
-#include <cstdint>
-#include <list>
-#include <vector>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-class TensorPool {
- public:
-  TensorPool();
-
-  TensorWrapper& CreateInputTensor(
-      Qnn_DataType_t data_type,
-      const QuantizeParamsWrapperVariant& quant_params,
-      const std::vector<std::uint32_t>& dimentions);
-
-  TensorWrapper& CreateOutpuTensor(
-      Qnn_DataType_t data_type,
-      const QuantizeParamsWrapperVariant& quant_params,
-      const std::vector<std::uint32_t>& dimentions);
-
-  TensorWrapper& CreateNativeTensor(
-      Qnn_DataType_t data_type,
-      const QuantizeParamsWrapperVariant& quant_params,
-      const std::vector<std::uint32_t>& dimentions);
-
-  TensorWrapper& CreateStaticTensor(
-      Qnn_DataType_t data_type,
-      const QuantizeParamsWrapperVariant& quant_params,
-      const std::vector<std::uint32_t>& dimentions, std::uint32_t bytes,
-      const void* data);
-
-  TensorWrapper& CloneNativeTensorFrom(const TensorWrapper& src);
-
-  TensorWrapper& CloneNativeTensorFrom(
-      const TensorWrapper& src, const std::vector<std::uint32_t>& dimentions);
-
-  TensorWrapper& CloneStaticTensorFrom(const TensorWrapper& src,
-                                       Qnn_DataType_t data_type);
-
-  TensorWrapper& CloneStaticTensorFrom(
-      const TensorWrapper& src, const std::vector<std::uint32_t>& dimentions);
-
-  template <typename UnaryFunc>
-  void ForEach(UnaryFunc f) {
-    for (auto& tensor_wrapper : tensor_wrappers_) {
-      f(tensor_wrapper);
-    }
-  }
-
- private:
-  std::list<TensorWrapper> tensor_wrappers_{};
-};
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_TENSOR_POOL_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/BUILD
deleted file mode 100644
index 3ce72dec755646..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/BUILD
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
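
TensorPool above hands out TensorWrapper& references that builders hold for the lifetime of graph construction, which is why the pool stores wrappers in a std::list rather than a std::vector: list elements never relocate as the pool grows. A self-contained sketch of that guarantee, not part of the patch:

#include <cassert>
#include <list>

int main() {
  std::list<int> pool;
  int& first = pool.emplace_back(1);
  for (int i = 0; i < 10000; ++i) pool.emplace_back(i);  // pool grows freely
  assert(&first == &pool.front());  // earlier references remain valid
  return 0;
}

A std::vector would invalidate `first` on reallocation, which for the real pool would leave the name and dimension pointers baked into Qnn_Tensor_t structs dangling.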
-# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_library( - name = "log", - srcs = select({ - "//tensorflow:android": ["log_android.cc"], - "//conditions:default": ["log_default.cc"], - }), - hdrs = ["log.h"], - linkopts = select({ - "//tensorflow:android": ["-llog"], - "//conditions:default": [], - }), - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:common", - ], -) - -cc_library( - name = "miscs", - srcs = ["miscs.cc"], - hdrs = ["miscs.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "utils_test", - srcs = [ - "utils_test.cc", - ], - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. - "nosan", - ], - deps = [ - ":log", - ":miscs", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:common", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h deleted file mode 100644 index f89b4131dea4b6..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_LOG_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_LOG_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h" - -namespace qnn { - -class QNNLogger { - public: - // Logging hook that takes variadic args. - static void Log(LiteRtQnnLogLevel severity, const char* format, ...); - - // Set file descriptor - static void SetLogFilePointer(FILE* fp); - - // Set log level - static void SetLogLevel(LiteRtQnnLogLevel log_level); - - private: - // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) - static FILE* log_file_pointer_; - static LiteRtQnnLogLevel log_level_; - // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) -}; -} // namespace qnn - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_VERBOSE(format, ...) \ - ::qnn::QNNLogger::Log(kLogLevelVerbose, ("VERBOSE: [Qnn] " format), \ - ##__VA_ARGS__); - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_INFO(format, ...) \ - ::qnn::QNNLogger::Log(kLogLevelInfo, ("INFO: [Qnn] " format), ##__VA_ARGS__); - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_WARNING(format, ...) \ - ::qnn::QNNLogger::Log(kLogLevelWarn, ("WARNING: [Qnn] " format), \ - ##__VA_ARGS__); - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_ERROR(format, ...) \ - ::qnn::QNNLogger::Log(kLogLevelError, ("ERROR: [Qnn] " format), \ - ##__VA_ARGS__); - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_DEBUG(format, ...) 
\ - ::qnn::QNNLogger::Log(kLogLevelDebug, ("DEBUG: [Qnn] " format), \ - ##__VA_ARGS__); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_LOG_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_android.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_android.cc deleted file mode 100644 index ec13856cda945b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_android.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include "log.h" - -namespace qnn { -namespace { - -int GetPlatformSeverity(LiteRtQnnLogLevel severity) { - switch (severity) { - case kLogLevelError: - return ANDROID_LOG_ERROR; - case kLogLevelWarn: - return ANDROID_LOG_WARN; - case kLogLevelInfo: - return ANDROID_LOG_INFO; - case kLogLevelVerbose: - return ANDROID_LOG_VERBOSE; - default: - return ANDROID_LOG_DEBUG; - } -} - -} // namespace - -// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) -FILE* QNNLogger::log_file_pointer_ = stderr; -LiteRtQnnLogLevel QNNLogger::log_level_ = kLogLevelInfo; -// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) -void QNNLogger::SetLogFilePointer(FILE* fp) { log_file_pointer_ = fp; } -void QNNLogger::SetLogLevel(LiteRtQnnLogLevel log_level) { - log_level_ = log_level; -} -// NOLINTNEXTLINE(cert-dcl50-cpp) -void QNNLogger::Log(LiteRtQnnLogLevel severity, const char* format, ...) { - if (severity > log_level_) { - return; - } - - // Pass to LogFormatted - va_list args; - va_start(args, format); - - // First log to Android's explicit log(cat) API. - va_list args_copy; - va_copy(args_copy, args); - __android_log_vprint(GetPlatformSeverity(severity), "qnn", format, args_copy); - va_end(args_copy); - - // Print to file pointer. - vfprintf(log_file_pointer_, format, args); - fputc('\n', log_file_pointer_); - - va_end(args); -} -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_default.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_default.cc deleted file mode 100644 index 6d9067d26d61a3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_default.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include - -#include "log.h" - -namespace qnn { - -// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) -FILE* QNNLogger::log_file_pointer_ = stderr; -LiteRtQnnLogLevel QNNLogger::log_level_ = kLogLevelInfo; -// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) -void QNNLogger::SetLogFilePointer(FILE* fp) { log_file_pointer_ = fp; } -void QNNLogger::SetLogLevel(LiteRtQnnLogLevel log_level) { - log_level_ = log_level; -} -// NOLINTNEXTLINE(cert-dcl50-cpp) -void QNNLogger::Log(LiteRtQnnLogLevel severity, const char* format, ...) { - if (severity > log_level_) { - return; - } - - // Pass to LogFormatted - va_list args; - va_start(args, format); - - // Print to file pointer. 
- vfprintf(log_file_pointer_, format, args); - fputc('\n', log_file_pointer_); - - va_end(args); -} -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.cc deleted file mode 100644 index e07ef251adcc10..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" - -#include -#include - -#include "absl/types/span.h" - -namespace qnn { -void ConvertDataFromInt16toUInt16(absl::Span src, - std::vector& dst) { - dst.clear(); - dst.reserve(src.size()); - for (const auto& data : src) { - dst.emplace_back(data + kUint16ZeroPoint); - } -} - -void ConvertDataFromUInt16toInt16(absl::Span src, - std::vector& dst) { - dst.clear(); - dst.reserve(src.size()); - for (const auto& data : src) { - dst.emplace_back(data - kUint16ZeroPoint); - } -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h deleted file mode 100644 index 7b12cc09eecf3d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_MISCS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_MISCS_H_ - -#include -#include -#include -#include -#include - -#include "absl/types/span.h" - -namespace qnn { - -constexpr uint32_t kUint16ZeroPoint = -std::numeric_limits::min(); - -template -inline constexpr bool always_false = false; - -template -T Quantize(const float val, const float scale, const int32_t zero_point) { - static_assert(std::is_integral::value, - "Integral required in Quantize function."); - return std::round(val / scale) + zero_point; -} - -template -float Dequantize(const T val, const float scale, const int32_t zero_point) { - static_assert(std::is_integral::value, - "Integral required in Dequantize function."); - return scale * (val - zero_point); -} - -void ConvertDataFromInt16toUInt16(absl::Span src, - std::vector& dst); - -void ConvertDataFromUInt16toInt16(absl::Span src, - std::vector& dst); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_MISCS_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/utils_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/utils_test.cc deleted file mode 100644 index c8953157ada8fb..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/utils_test.cc +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
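
The miscs.h helpers above encode two small pieces of arithmetic: Quantize computes round(val / scale) + zero_point, Dequantize computes scale * (val - zero_point), and the int16/uint16 converters rebase values by kUint16ZeroPoint == 32768 so the signed range [-32768, 32767] maps onto [0, 65535]. A self-contained check of all three, illustrative only:

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const float scale = 0.1f;
  const std::int32_t zero_point = 1;
  // Quantize: round(val / scale) + zero_point.
  const std::int8_t q =
      static_cast<std::int8_t>(std::round(1.0f / scale) + zero_point);
  assert(q == 11);
  // Dequantize: scale * (q - zero_point).
  assert(scale * (q - zero_point) == 1.0f);
  // Rebase int16 -> uint16 by the 32768 zero-point offset.
  const std::uint16_t u = static_cast<std::uint16_t>(std::int16_t{0} + 32768);
  assert(u == 32768);
  return 0;
}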
-// SPDX-License-Identifier: Apache-2.0 - -#include -#include -#include -#include -#include -#include -#include - -#include -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" - -namespace qnn { -namespace { - -bool IsPrefix(std::string_view prefix, std::string_view full) { - return prefix == full.substr(0, prefix.size()); -} - -bool CheckLoggoing(const std::string log_path, LiteRtQnnLogLevel log_level) { - std::ifstream fin(log_path); - std::string msg; - while (std::getline(fin, msg)) { - // Log severity: DEBUG > VERBOSE > INFO > WARN > ERROR - switch (log_level) { - case kLogOff: - if (IsPrefix("ERROR:", msg)) return false; - [[fallthrough]]; - case kLogLevelError: - if (IsPrefix("WARNING:", msg)) return false; - [[fallthrough]]; - case kLogLevelWarn: - if (IsPrefix("INFO:", msg)) return false; - [[fallthrough]]; - case kLogLevelInfo: - if (IsPrefix("VERBOSE:", msg)) return false; - [[fallthrough]]; - case kLogLevelVerbose: - if (IsPrefix("DEBUG:", msg)) return false; - [[fallthrough]]; - default: - break; - } - } - return true; -} - -} // namespace - -class LiteRtLog : public ::testing::TestWithParam {}; -INSTANTIATE_TEST_SUITE_P(, LiteRtLog, - ::testing::Values(kLogOff, kLogLevelError, - kLogLevelWarn, kLogLevelInfo, - kLogLevelVerbose, kLogLevelDebug)); - -TEST_P(LiteRtLog, SanityTest) { - // Create temp file for log - std::filesystem::path temp_path = - std::filesystem::temp_directory_path() / "temp.log"; - std::ofstream fout(temp_path); - ASSERT_TRUE(fout.is_open()); - - // Set log file pointer - FILE* file_ptr = fopen(temp_path.c_str(), "w"); - ASSERT_NE(file_ptr, nullptr); - qnn::QNNLogger::SetLogFilePointer(file_ptr); - - // Set log_level and print message to file - LiteRtQnnLogLevel log_level = GetParam(); - qnn::QNNLogger::SetLogLevel(log_level); - QNN_LOG_VERBOSE("This is a verbose message."); - QNN_LOG_INFO("This is an info message."); - QNN_LOG_WARNING("This is a warning message."); - QNN_LOG_ERROR("This is an error message."); - QNN_LOG_DEBUG("This is a debug message."); - qnn::QNNLogger::SetLogFilePointer(stderr); - fclose(file_ptr); - - // Check logging messages are as expected - ASSERT_EQ(CheckLoggoing(temp_path.string(), log_level), true); - - // Delete the temporary log file - std::filesystem::remove(temp_path); -} - -TEST(MiscTest, TestAlwaysFalse) { - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); -} - -TEST(MiscTests, Quantize) { - float val = 1; - float scale = 0.1; - int32_t zero_point = 1; - auto q_val = Quantize(val, scale, zero_point); - EXPECT_EQ(q_val, 11); -} - -TEST(MiscTests, Dequantize) { - std::int8_t q_val = 11; - float scale = 0.1; - int32_t zero_point = 1; - auto val = Dequantize(q_val, scale, zero_point); - EXPECT_FLOAT_EQ(val, 1); -} - -TEST(MiscTests, ConvertDataFromInt16toUInt16) { - constexpr int16_t int16_data[4] = {0, 1, 2, 3}; - size_t data_len = 
sizeof(int16_data) / sizeof(int16_data[0]); - absl::Span int16_span(int16_data, data_len); - std::vector uint16_data; - - ConvertDataFromInt16toUInt16(int16_span, uint16_data); - EXPECT_EQ(uint16_data[0], 32768); - EXPECT_EQ(uint16_data[1], 32769); - EXPECT_EQ(uint16_data[2], 32770); - EXPECT_EQ(uint16_data[3], 32771); -} - -TEST(MiscTests, ConvertDataFromUInt16toInt16) { - constexpr uint16_t uint16_data[4] = {32768, 32769, 32770, 32771}; - size_t data_len = sizeof(uint16_data) / sizeof(uint16_data[0]); - absl::Span uint16_span(uint16_data, data_len); - std::vector int16_data; - - ConvertDataFromUInt16toInt16(uint16_span, int16_data); - EXPECT_EQ(int16_data[0], 0); - EXPECT_EQ(int16_data[1], 1); - EXPECT_EQ(int16_data[2], 2); - EXPECT_EQ(int16_data[3], 3); -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/BUILD deleted file mode 100644 index e904d2a9c4efb3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/BUILD +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_library( - name = "quantize_params_wrapper", - srcs = ["quantize_params_wrapper.cc"], - hdrs = ["quantize_params_wrapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - ], -) - -cc_library( - name = "tensor_wrapper", - srcs = ["tensor_wrapper.cc"], - hdrs = ["tensor_wrapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":quantize_params_wrapper", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:miscs", - ], -) - -cc_library( - name = "param_wrapper", - srcs = ["param_wrapper.cc"], - hdrs = ["param_wrapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:miscs", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "op_wrapper", - srcs = ["op_wrapper.cc"], - hdrs = ["op_wrapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:param_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.cc deleted file mode 100644 index 43ac6a1a0704f9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -OpWrapper::OpWrapper(std::string name, const char* op_type) - : type_name_{op_type}, name_{std::move(name)} {} - -OpWrapper::OpWrapper(OpWrapper&& other) - : type_name_{other.type_name_}, - name_{std::move(other.name_)}, - input_tensors_{std::move(other.input_tensors_)}, - output_tensors_{std::move(other.output_tensors_)}, - scalar_params_{std::move(other.scalar_params_)}, - tensor_params_{std::move(other.tensor_params_)}, - qnn_input_tensors_{std::move(other.qnn_input_tensors_)}, - qnn_output_tensors_{std::move(other.qnn_output_tensors_)}, - qnn_params_{std::move(other.qnn_params_)} {} - -OpWrapper::~OpWrapper() = default; - -void OpWrapper::AddInputTensor(const TensorWrapper& tensor) { - input_tensors_.emplace_back(tensor); -} - -void OpWrapper::AddOutputTensor(const TensorWrapper& tensor) { - output_tensors_.emplace_back(tensor); -} - -void OpWrapper::AddTensorParam(const char* name, const TensorWrapper& tensor) { - tensor_params_.emplace_back(name, tensor); -} - -Qnn_OpConfig_t OpWrapper::GetOpConfig() { - Qnn_OpConfig_t qnn_op = QNN_OPCONFIG_INIT; - qnn_op.v1.packageName = QNN_OP_PACKAGE_NAME_QTI_AISW; - qnn_op.v1.typeName = type_name_; - qnn_op.v1.name = name_.data(); - // input tensors - qnn_input_tensors_.reserve(input_tensors_.size()); - qnn_input_tensors_.clear(); - for (const auto& input_tensor : input_tensors_) { - auto& back = qnn_input_tensors_.emplace_back(); - input_tensor.get().CloneTo(back); - } - qnn_op.v1.numOfInputs = qnn_input_tensors_.size(); - qnn_op.v1.inputTensors = qnn_input_tensors_.data(); - // output tensors - qnn_output_tensors_.reserve(output_tensors_.size()); - qnn_output_tensors_.clear(); - for (const auto& output_tensor : output_tensors_) { - auto& back = qnn_output_tensors_.emplace_back(); - output_tensor.get().CloneTo(back); - } - qnn_op.v1.numOfOutputs = qnn_output_tensors_.size(); - qnn_op.v1.outputTensors = qnn_output_tensors_.data(); - // params - qnn_params_.reserve(scalar_params_.size() + tensor_params_.size()); - qnn_params_.clear(); - for (const auto& scalar_param : scalar_params_) { - auto& back = qnn_params_.emplace_back(); - scalar_param.CloneTo(back); - } - for (const auto& tensor_param : tensor_params_) { - auto& back = qnn_params_.emplace_back(); - tensor_param.CloneTo(back); - } - qnn_op.v1.numOfParams = qnn_params_.size(); - qnn_op.v1.params = qnn_params_.data(); - return qnn_op; -} - -} // namespace qnn diff --git 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h deleted file mode 100644 index 62858fb2ec2421..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_OP_WRAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_OP_WRAPPER_H_ - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -class OpWrapper final { - public: - explicit OpWrapper(std::string name, const char* op_type); - - OpWrapper(const OpWrapper& other) = delete; - - OpWrapper(OpWrapper&& other); - - ~OpWrapper(); - - void AddInputTensor(const TensorWrapper& tensor); - - void AddOutputTensor(const TensorWrapper& tensor); - - template - void AddScalarParam(const char* name, const T data, - const bool is_quant = false) { - scalar_params_.emplace_back(name, data, is_quant); - } - - void AddTensorParam(const char* name, const TensorWrapper& tensor); - - Qnn_OpConfig_t GetOpConfig(); - - private: - const char* type_name_{nullptr}; - std::string name_{}; // human readable name - std::vector> input_tensors_{}; - std::vector> output_tensors_{}; - std::vector scalar_params_{}; - std::vector tensor_params_{}; - std::vector qnn_input_tensors_{}; - std::vector qnn_output_tensors_{}; - std::vector qnn_params_{}; -}; - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_OP_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.cc deleted file mode 100644 index 9be8b2b4d635c2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
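
OpWrapper above keeps its inputs and outputs as reference wrappers and only flattens them into Qnn_Tensor_t structs when GetOpConfig() is called, storing the flattened vectors as members so the pointers inside the returned Qnn_OpConfig_t stay valid. A self-contained sketch of that deferred-clone pattern, using stand-in types rather than the patch's API:

#include <cassert>
#include <functional>
#include <string>
#include <vector>

struct FlatTensor { const char* name; };  // stand-in for Qnn_Tensor_t

struct Tensor {  // stand-in for TensorWrapper
  std::string name;
  void CloneTo(FlatTensor& dst) const { dst.name = name.c_str(); }
};

int main() {
  Tensor a{"in0"};
  std::vector<std::reference_wrapper<const Tensor>> inputs{a};
  a.name = "in0_renamed";  // mutations before materialization are picked up
  std::vector<FlatTensor> flat;  // member storage; must outlive the config
  flat.reserve(inputs.size());
  for (const Tensor& t : inputs) {
    t.CloneTo(flat.emplace_back());
  }
  assert(std::string(flat[0].name) == "in0_renamed");
  return 0;
}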
-// SPDX-License-Identifier: Apache-2.0
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h"
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-void ScalarParamWrapper::CloneTo(Qnn_Param_t& dst) const {
-  dst.name = name_;
-  dst.paramType = QNN_PARAMTYPE_SCALAR;
-  dst.scalarParam = qnn_scalar_;
-}
-
-TensorParamWrapper::TensorParamWrapper(const char* name,
-                                       const TensorWrapper& tensor)
-    : name_{name}, tensor_{tensor} {}
-
-void TensorParamWrapper::CloneTo(Qnn_Param_t& dst) const {
-  dst.name = name_;
-  dst.paramType = QNN_PARAMTYPE_TENSOR;
-  tensor_.CloneTo(dst.tensorParam);
-}
-
-}  // namespace qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h
deleted file mode 100644
index 9dbc63102cf6f2..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_PARAM_WRAPPER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_PARAM_WRAPPER_H_
-
-#include <cstdint>
-#include <type_traits>
-
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
-
-namespace qnn {
-
-class ScalarParamWrapper {
- public:
-  template <typename T>
-  explicit ScalarParamWrapper(const char* name, const T data,
-                              const bool is_quant)
-      : name_{name} {
-    if constexpr (std::is_same_v<T, bool>) {
-      qnn_scalar_.dataType = QNN_DATATYPE_BOOL_8;
-      qnn_scalar_.bool8Value = data;
-    } else if constexpr (std::is_same_v<T, std::uint8_t>) {
-      qnn_scalar_.dataType =
-          is_quant ? QNN_DATATYPE_UFIXED_POINT_8 : QNN_DATATYPE_UINT_8;
-      qnn_scalar_.uint8Value = data;
-    } else if constexpr (std::is_same_v<T, std::int8_t>) {
-      qnn_scalar_.dataType =
-          is_quant ? QNN_DATATYPE_SFIXED_POINT_8 : QNN_DATATYPE_INT_8;
-      qnn_scalar_.int8Value = data;
-    } else if constexpr (std::is_same_v<T, std::uint16_t>) {
-      qnn_scalar_.dataType =
-          is_quant ? QNN_DATATYPE_UFIXED_POINT_16 : QNN_DATATYPE_UINT_16;
-      qnn_scalar_.uint16Value = data;
-    } else if constexpr (std::is_same_v<T, std::int16_t>) {
-      qnn_scalar_.dataType =
-          is_quant ? QNN_DATATYPE_SFIXED_POINT_16 : QNN_DATATYPE_INT_16;
-      qnn_scalar_.int16Value = data;
-    } else if constexpr (std::is_same_v<T, std::uint32_t>) {
-      qnn_scalar_.dataType =
-          is_quant ? QNN_DATATYPE_UFIXED_POINT_32 : QNN_DATATYPE_UINT_32;
-      qnn_scalar_.uint32Value = data;
-    } else if constexpr (std::is_same_v<T, std::int32_t>) {
-      qnn_scalar_.dataType =
-          is_quant ? QNN_DATATYPE_SFIXED_POINT_32 : QNN_DATATYPE_INT_32;
-      qnn_scalar_.int32Value = data;
-    } else if constexpr (std::is_same_v<T, float>) {
-      qnn_scalar_.dataType = QNN_DATATYPE_FLOAT_32;
-      qnn_scalar_.floatValue = data;
-    } else {
-      static_assert(::qnn::always_false<T>,
-                    "Unsupported data type for scalar param.");
-    }
-  }
-
-  void CloneTo(Qnn_Param_t& dst) const;
-
- private:
-  const char* name_ = nullptr;
-  Qnn_Scalar_t qnn_scalar_ = QNN_SCALAR_INIT;
-};
-
-class TensorParamWrapper {
- public:
-  explicit TensorParamWrapper(const char* name, const TensorWrapper& tensor);
-
-  void CloneTo(Qnn_Param_t& dst) const;
-
- private:
-  const char* name_ = nullptr;
-  const TensorWrapper& tensor_;
-};
-
-}  // namespace qnn
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_PARAM_WRAPPER_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.cc
deleted file mode 100644
index ce327633207ef5..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.cc
+++ /dev/null
@@ -1,114 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc.
-// All Rights Reserved.
-
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h"
-
-#include <cassert>
-#include <cstddef>
-#include <cstdint>
-#include <utility>
-#include <vector>
-
-#include "absl/types/span.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-
-namespace qnn {
-
-UndefinedQuantizeParamsWrapper::UndefinedQuantizeParamsWrapper() = default;
-
-UndefinedQuantizeParamsWrapper::UndefinedQuantizeParamsWrapper(
-    const UndefinedQuantizeParamsWrapper&) = default;
-
-UndefinedQuantizeParamsWrapper::UndefinedQuantizeParamsWrapper(
-    UndefinedQuantizeParamsWrapper&&) = default;
-
-void UndefinedQuantizeParamsWrapper::CloneTo(Qnn_QuantizeParams_t& dst) {
-  dst = qnn_quantize_param_;
-}
-
-ScaleOffsetQuantizeParamsWrapper::ScaleOffsetQuantizeParamsWrapper(
-    const float scale, const std::int32_t zero_point) {
-  qnn_quantize_param_.encodingDefinition = QNN_DEFINITION_DEFINED;
-  qnn_quantize_param_.quantizationEncoding =
-      QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
-  qnn_quantize_param_.scaleOffsetEncoding.scale = scale;
-  qnn_quantize_param_.scaleOffsetEncoding.offset = -1 * zero_point;
-}
-
-ScaleOffsetQuantizeParamsWrapper::ScaleOffsetQuantizeParamsWrapper(
-    const ScaleOffsetQuantizeParamsWrapper&) = default;
-
-ScaleOffsetQuantizeParamsWrapper::ScaleOffsetQuantizeParamsWrapper(
-    ScaleOffsetQuantizeParamsWrapper&&) = default;
-
-void ScaleOffsetQuantizeParamsWrapper::CloneTo(Qnn_QuantizeParams_t& dst) {
-  dst = qnn_quantize_param_;
-}
-
-AxisScaleOffsetQuantizeParamsWrapper::AxisScaleOffsetQuantizeParamsWrapper(
-    const std::int32_t axis, const absl::Span<const float> scales,
-    const absl::Span<const std::int32_t> zero_points)
-    : scale_offsets_(scales.size()) {
-  assert(scales.size() == zero_points.size());
-  for (size_t i = 0; i < scale_offsets_.size(); ++i) {
-    scale_offsets_[i].scale = scales[i];
-    scale_offsets_[i].offset = -1 * zero_points[i];
-  }
-
-  qnn_quantize_param_.encodingDefinition = QNN_DEFINITION_DEFINED;
-  qnn_quantize_param_.quantizationEncoding =
-      QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
-  qnn_quantize_param_.axisScaleOffsetEncoding.axis = axis;
-  qnn_quantize_param_.axisScaleOffsetEncoding.numScaleOffsets =
-      scale_offsets_.size();
-  qnn_quantize_param_.axisScaleOffsetEncoding.scaleOffset =
-      scale_offsets_.data();
-}
-
-AxisScaleOffsetQuantizeParamsWrapper::AxisScaleOffsetQuantizeParamsWrapper( - const AxisScaleOffsetQuantizeParamsWrapper& rhs) - : qnn_quantize_param_{rhs.qnn_quantize_param_}, - scale_offsets_{rhs.scale_offsets_} { - qnn_quantize_param_.axisScaleOffsetEncoding.scaleOffset = - scale_offsets_.data(); -} - -AxisScaleOffsetQuantizeParamsWrapper::AxisScaleOffsetQuantizeParamsWrapper( - AxisScaleOffsetQuantizeParamsWrapper&& rhs) - : qnn_quantize_param_{rhs.qnn_quantize_param_}, - scale_offsets_{std::move(rhs.scale_offsets_)} { - qnn_quantize_param_.axisScaleOffsetEncoding.scaleOffset = - scale_offsets_.data(); -} - -void AxisScaleOffsetQuantizeParamsWrapper::CloneTo(Qnn_QuantizeParams_t& dst) { - dst = qnn_quantize_param_; -} - -std::int32_t AxisScaleOffsetQuantizeParamsWrapper::GetAxis() const { - return qnn_quantize_param_.axisScaleOffsetEncoding.axis; -} - -void AxisScaleOffsetQuantizeParamsWrapper::SetAxis(const std::int32_t axis) { - qnn_quantize_param_.axisScaleOffsetEncoding.axis = axis; -} - -void AxisScaleOffsetQuantizeParamsWrapper::GetScales( - std::vector& scales) const { - scales.clear(); - scales.reserve(scale_offsets_.size()); - for (size_t i = 0; i < scale_offsets_.size(); ++i) { - scales.emplace_back(scale_offsets_[i].scale); - } -} - -void AxisScaleOffsetQuantizeParamsWrapper::GetZeroPoints( - std::vector& zero_points) const { - zero_points.clear(); - zero_points.reserve(scale_offsets_.size()); - for (size_t i = 0; i < scale_offsets_.size(); ++i) { - zero_points.emplace_back(-1 * scale_offsets_[i].offset); - } -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h deleted file mode 100644 index ee209ef4c7d2f9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_QUANTIZE_PARAMS_WRAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_QUANTIZE_PARAMS_WRAPPER_H_ - -#include -#include -#include -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" - -namespace qnn { - -class UndefinedQuantizeParamsWrapper final { - public: - UndefinedQuantizeParamsWrapper(); - - UndefinedQuantizeParamsWrapper(const UndefinedQuantizeParamsWrapper&); - - UndefinedQuantizeParamsWrapper(UndefinedQuantizeParamsWrapper&&); - - void CloneTo(Qnn_QuantizeParams_t& dst); - - private: - Qnn_QuantizeParams_t qnn_quantize_param_ = QNN_QUANTIZE_PARAMS_INIT; -}; - -class ScaleOffsetQuantizeParamsWrapper final { - public: - explicit ScaleOffsetQuantizeParamsWrapper(const float scale, - const std::int32_t zero_point); - - ScaleOffsetQuantizeParamsWrapper(const ScaleOffsetQuantizeParamsWrapper&); - - ScaleOffsetQuantizeParamsWrapper(ScaleOffsetQuantizeParamsWrapper&&); - - void CloneTo(Qnn_QuantizeParams_t& dst); - - float GetScale() const { - return qnn_quantize_param_.scaleOffsetEncoding.scale; - } - - std::int32_t GetZeroPoint() const { - return -1 * qnn_quantize_param_.scaleOffsetEncoding.offset; - } - - private: - Qnn_QuantizeParams_t qnn_quantize_param_ = QNN_QUANTIZE_PARAMS_INIT; -}; - -class AxisScaleOffsetQuantizeParamsWrapper final { - public: - explicit AxisScaleOffsetQuantizeParamsWrapper( - const std::int32_t axis, const absl::Span scales, - const absl::Span zero_points); - - AxisScaleOffsetQuantizeParamsWrapper( - const AxisScaleOffsetQuantizeParamsWrapper& rhs); - - AxisScaleOffsetQuantizeParamsWrapper( - AxisScaleOffsetQuantizeParamsWrapper&& rhs); - - void CloneTo(Qnn_QuantizeParams_t& dst); - - std::int32_t GetAxis() const; - - void SetAxis(const std::int32_t axis); - - void GetScales(std::vector& scales) const; - - void GetZeroPoints(std::vector& zero_points) const; - - private: - Qnn_QuantizeParams_t qnn_quantize_param_ = QNN_QUANTIZE_PARAMS_INIT; - std::vector scale_offsets_; -}; - -using QuantizeParamsWrapperVariant = - std::variant; - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_QUANTIZE_PARAMS_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.cc deleted file mode 100644 index 1e78b7922d866e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.cc +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
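
QuantizeParamsWrapperVariant above is a std::variant over the three encoding wrappers, and tensor_wrapper.cc dispatches on it with std::visit wherever the active wrapper has to be cloned into the flat Qnn_QuantizeParams_t. A self-contained sketch of that dispatch, using stand-in types:

#include <cassert>
#include <variant>

struct FlatParams { int encoding = 0; };  // stand-in for Qnn_QuantizeParams_t

struct Undefined {
  void CloneTo(FlatParams& dst) const { dst.encoding = 0; }
};
struct ScaleOffset {
  float scale;
  void CloneTo(FlatParams& dst) const { dst.encoding = 1; }
};

using ParamsVariant = std::variant<Undefined, ScaleOffset>;

int main() {
  ParamsVariant params = ScaleOffset{0.5f};
  FlatParams flat;
  // Same shape as the lambda used throughout tensor_wrapper.cc.
  std::visit([&flat](auto&& wrapper) { wrapper.CloneTo(flat); }, params);
  assert(flat.encoding == 1);  // the ScaleOffset alternative was dispatched
  return 0;
}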
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" - -namespace qnn { - -std::size_t GetDataTypeSize(const Qnn_DataType_t data_type) { - std::size_t bytes = 0; - switch (data_type) { - case QNN_DATATYPE_INT_8: - case QNN_DATATYPE_UINT_8: - case QNN_DATATYPE_SFIXED_POINT_8: - case QNN_DATATYPE_UFIXED_POINT_8: - case QNN_DATATYPE_BOOL_8: - bytes = 1; - break; - case QNN_DATATYPE_INT_16: - case QNN_DATATYPE_UINT_16: - case QNN_DATATYPE_FLOAT_16: - case QNN_DATATYPE_SFIXED_POINT_16: - case QNN_DATATYPE_UFIXED_POINT_16: - bytes = 2; - break; - case QNN_DATATYPE_INT_32: - case QNN_DATATYPE_UINT_32: - case QNN_DATATYPE_FLOAT_32: - case QNN_DATATYPE_SFIXED_POINT_32: - case QNN_DATATYPE_UFIXED_POINT_32: - bytes = 4; - break; - case QNN_DATATYPE_INT_64: - case QNN_DATATYPE_UINT_64: - case QNN_DATATYPE_FLOAT_64: - bytes = 8; - break; - case QNN_DATATYPE_UNDEFINED: - case QNN_DATATYPE_SFIXED_POINT_4: - case QNN_DATATYPE_UFIXED_POINT_4: - default: - bytes = 0; - break; - } - return bytes; -} - -TensorWrapper::TensorWrapper() = default; - -TensorWrapper::TensorWrapper( - std::uint32_t id, Qnn_TensorType_t tensor_type, Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quantize_params, - const std::vector& dimentions) - : name_{std::to_string(id)}, - dimentions_{dimentions}, - quantize_params_{quantize_params} { - qnn_tensor_.v2.name = name_.c_str(); - qnn_tensor_.v2.type = tensor_type; - qnn_tensor_.v2.dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER; - qnn_tensor_.v2.dataType = data_type; - std::visit( - [this](auto&& quantize_params) -> void { - quantize_params.CloneTo(qnn_tensor_.v2.quantizeParams); - }, - quantize_params_); - qnn_tensor_.v2.rank = dimentions_.size(); - qnn_tensor_.v2.dimensions = dimentions_.data(); - qnn_tensor_.v2.memType = QNN_TENSORMEMTYPE_RAW; -} - -TensorWrapper::TensorWrapper( - std::uint32_t id, Qnn_TensorType_t tensor_type, Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quantize_params, - const std::vector& dimentions, std::uint32_t bytes, - const void* data) - : TensorWrapper(id, tensor_type, data_type, quantize_params, dimentions) { - SetDataBy(bytes, data); -} - -TensorWrapper::TensorWrapper(const TensorWrapper& other) - : qnn_tensor_{other.qnn_tensor_}, - name_{other.name_}, - dimentions_{other.dimentions_}, - quantize_params_{other.quantize_params_}, - owned_data_{other.owned_data_} { - qnn_tensor_.v2.name = name_.c_str(); - qnn_tensor_.v2.dimensions = dimentions_.data(); - qnn_tensor_.v2.clientBuf.data = owned_data_.data(); - std::visit( - [this](auto&& quant_params) -> void { - quant_params.CloneTo(qnn_tensor_.v2.quantizeParams); - }, - quantize_params_); -} - -TensorWrapper::TensorWrapper(TensorWrapper&& other) - : qnn_tensor_{other.qnn_tensor_}, - name_{std::move(other.name_)}, - dimentions_{std::move(other.dimentions_)}, - quantize_params_{std::move(other.quantize_params_)}, - owned_data_{std::move(other.owned_data_)} { - qnn_tensor_.v2.name = name_.c_str(); - 
qnn_tensor_.v2.dimensions = dimentions_.data(); - qnn_tensor_.v2.clientBuf.data = owned_data_.data(); - std::visit( - [this](auto&& quant_params) -> void { - quant_params.CloneTo(qnn_tensor_.v2.quantizeParams); - }, - quantize_params_); -} - -TensorWrapper::~TensorWrapper() = default; - -std::uint32_t TensorWrapper::GetDim(size_t index) const { - return dimentions_[index]; -} - -Qnn_DataType_t TensorWrapper::GetDataType() const { - return qnn_tensor_.v2.dataType; -} - -void TensorWrapper::CloneTo(Qnn_Tensor_t& dst) const { dst = qnn_tensor_; } - -std::uint32_t TensorWrapper::GetRank() const { return qnn_tensor_.v2.rank; } - -Qnn_TensorType_t TensorWrapper::GetTensorType() const { - return qnn_tensor_.v2.type; -} - -std::uint32_t TensorWrapper::GetTensorNumElements() const { - return GetDims().empty() ? 0 - : std::accumulate(GetDims().begin(), GetDims().end(), - 1, std::multiplies<>()); -} - -size_t TensorWrapper::GetTensorBytes() const { - return GetDataTypeSize(GetDataType()) * GetTensorNumElements(); -} - -bool TensorWrapper::IsPerTensorQuantWithOffsetDiff( - const TensorWrapper& rhs) const { - const auto& lhs_quant = qnn_tensor_.v2.quantizeParams; - const auto& rhs_quant = rhs.qnn_tensor_.v2.quantizeParams; - - if (lhs_quant.encodingDefinition != QNN_DEFINITION_DEFINED || - rhs_quant.encodingDefinition != QNN_DEFINITION_DEFINED) { - return false; - } - - if (lhs_quant.quantizationEncoding != - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET || - rhs_quant.quantizationEncoding != - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { - return false; - } - - const auto lhs_scale = lhs_quant.scaleOffsetEncoding.scale; - const auto lhs_offset = lhs_quant.scaleOffsetEncoding.offset; - const auto rhs_scale = rhs_quant.scaleOffsetEncoding.scale; - const auto rhs_offset = rhs_quant.scaleOffsetEncoding.offset; - if ((GetDataType() == QNN_DATATYPE_SFIXED_POINT_8 && - rhs.GetDataType() == QNN_DATATYPE_UFIXED_POINT_8) || - (GetDataType() == QNN_DATATYPE_UFIXED_POINT_8 && - rhs.GetDataType() == QNN_DATATYPE_SFIXED_POINT_8)) { - constexpr int kSUFixed8OffsetDiff = 128; - if (std::fabs(lhs_scale - rhs_scale) < - std::numeric_limits::epsilon() && - std::abs(lhs_offset - rhs_offset) == kSUFixed8OffsetDiff) { - return true; - } - } else if ((GetDataType() == QNN_DATATYPE_SFIXED_POINT_16 && - rhs.GetDataType() == QNN_DATATYPE_UFIXED_POINT_16) || - (GetDataType() == QNN_DATATYPE_UFIXED_POINT_16 && - rhs.GetDataType() == QNN_DATATYPE_SFIXED_POINT_16)) { - constexpr int kSUFixed16OffsetDiff = 32768; - if (std::fabs(lhs_scale - rhs_scale) < - std::numeric_limits::epsilon() && - std::abs(lhs_offset - rhs_offset) == kSUFixed16OffsetDiff) { - return true; - } - } - return false; -} - -void TensorWrapper::SetDataBy(std::uint32_t bytes, const void* data) { - if (bytes != GetTensorBytes()) { - QNN_LOG_WARNING( - "Bytes: %d != GetTensorBytes(): %d, use GetTensorBytes() instead.", - bytes, GetTensorBytes()); - bytes = GetTensorBytes(); - } - owned_data_.resize(bytes); - std::memcpy(owned_data_.data(), reinterpret_cast(data), bytes); - qnn_tensor_.v2.clientBuf.dataSize = owned_data_.size(); - qnn_tensor_.v2.clientBuf.data = owned_data_.data(); -} - -void TensorWrapper::ConvertQint16ToQuint16() { - if (GetDataType() != QNN_DATATYPE_SFIXED_POINT_16) { - return; - } - - // adjust static data - if (IsTensorStatic()) { - auto int16_data = GetStaticTensorData(); - if (!int16_data.has_value()) { - QNN_LOG_ERROR( - "Cannot convert static QInt16 data to QUint16 data failed since " - "GetStaticTensorData failed."); - return; - } - 
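-    // Requantization note: the scale is unchanged and the zero point is
-    // shifted by kUint16ZeroPoint; ConvertDataFromInt16toUInt16 shifts the
-    // stored values by the same amount, so the represented real values,
-    // real = scale * (q - zero_point), stay identical.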
QNN_LOG_DEBUG("Converting static tensor data from QInt16 to QUint16...");
-    std::vector<std::uint16_t> uint16_data;
-    ConvertDataFromInt16toUInt16((*int16_data), uint16_data);
-    std::memcpy(owned_data_.data(),
-                reinterpret_cast(uint16_data.data()),
-                GetTensorBytes());
-    qnn_tensor_.v2.clientBuf.dataSize = owned_data_.size();
-    qnn_tensor_.v2.clientBuf.data = owned_data_.data();
-  }
-
-  // adjust quant param;
-  if (IsPerTensorQuant()) {
-    const auto& q_param =
-        std::get<ScaleOffsetQuantizeParamsWrapper>(GetQuantParams());
-    quantize_params_.emplace<ScaleOffsetQuantizeParamsWrapper>(
-        q_param.GetScale(), q_param.GetZeroPoint() + kUint16ZeroPoint);
-
-  } else if (IsPerChannelQuant()) {
-    const auto& q_param =
-        std::get<AxisScaleOffsetQuantizeParamsWrapper>(GetQuantParams());
-    std::int32_t axis = q_param.GetAxis();
-    std::vector<float> scales;
-    q_param.GetScales(scales);
-    std::vector<std::int32_t> zero_points;
-    q_param.GetZeroPoints(zero_points);
-    std::for_each(zero_points.begin(), zero_points.end(),
-                  [](std::int32_t& val) { val += kUint16ZeroPoint; });
-    quantize_params_.emplace<AxisScaleOffsetQuantizeParamsWrapper>(
-        axis, absl::MakeSpan(scales), absl::MakeSpan(zero_points));
-  }
-
-  std::visit(
-      [this](auto&& quantize_params) -> void {
-        quantize_params.CloneTo(qnn_tensor_.v2.quantizeParams);
-      },
-      quantize_params_);
-
-  // change data type here since GetStaticTensorData checks data type
-  qnn_tensor_.v2.dataType = QNN_DATATYPE_UFIXED_POINT_16;
-  QNN_LOG_DEBUG(
-      "QNN does not fully support QInt16 now, converting to QUint16 for better "
-      "compatibility.");
-}
-
-}  // namespace qnn
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h
deleted file mode 100644
index 5a079868d2b98e..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h
+++ /dev/null
@@ -1,349 +0,0 @@
-// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_TENSOR_WRAPPER_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_TENSOR_WRAPPER_H_
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "absl/types/span.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h"
-
-namespace qnn {
-
-// Get the Qnn_DataType_t associated with given C++ type.
-template <typename T>
-inline constexpr Qnn_DataType_t GetQnnDataType(const bool is_quant) {
-  if constexpr (std::is_same_v<T, bool>) {
-    return QNN_DATATYPE_BOOL_8;
-  } else if constexpr (std::is_same_v<T, std::uint8_t>) {
-    return is_quant ? QNN_DATATYPE_UFIXED_POINT_8 : QNN_DATATYPE_UINT_8;
-  } else if constexpr (std::is_same_v<T, std::int8_t>) {
-    return is_quant ? QNN_DATATYPE_SFIXED_POINT_8 : QNN_DATATYPE_INT_8;
-  } else if constexpr (std::is_same_v<T, std::uint16_t>) {
-    return is_quant ? QNN_DATATYPE_UFIXED_POINT_16 : QNN_DATATYPE_UINT_16;
-  } else if constexpr (std::is_same_v<T, std::int16_t>) {
-    return is_quant ? QNN_DATATYPE_SFIXED_POINT_16 : QNN_DATATYPE_INT_16;
-  } else if constexpr (std::is_same_v<T, std::uint32_t>) {
-    return is_quant ? QNN_DATATYPE_UFIXED_POINT_32 : QNN_DATATYPE_UINT_32;
-  } else if constexpr (std::is_same_v<T, std::int32_t>) {
-    return is_quant ?
QNN_DATATYPE_SFIXED_POINT_32 : QNN_DATATYPE_INT_32;
-  } else if constexpr (std::is_same_v<T, float>) {
-    return QNN_DATATYPE_FLOAT_32;
-  } else {
-    static_assert(always_false<T>, "Unknown C++ type");
-  }
-  return QNN_DATATYPE_UNDEFINED;
-}
-
-std::size_t GetDataTypeSize(const Qnn_DataType_t data_type);
-
-template <typename T>
-void TransposeFromOHWIToHWIO(absl::Span<const T> weight_data,
-                             const std::vector<std::uint32_t>& weight_dims,
-                             std::vector<T>& weight_data_transpose) {
-  weight_data_transpose.resize(weight_data.size());
-  uint32_t output = weight_dims[0];
-  uint32_t height = weight_dims[1];
-  uint32_t width = weight_dims[2];
-  uint32_t input = weight_dims[3];
-  // OHWI->HWIO
-  uint32_t map_o = 0;
-  uint32_t map_w = 0;
-  uint32_t map_h = 0;
-  for (uint32_t index_o = 0; index_o < output; index_o++) {
-    map_o = index_o * height * width * input;
-    for (uint32_t index_h = 0; index_h < height; index_h++) {
-      map_h = index_h * width * input;
-      for (uint32_t index_w = 0; index_w < width; index_w++) {
-        map_w = index_w * input;
-        for (uint32_t index_i = 0; index_i < input; index_i++) {
-          T inval = weight_data[map_o + map_h + map_w + index_i];
-          uint32_t index_transpose = index_h * width * input * output +
-                                     index_w * input * output +
-                                     index_i * output + index_o;
-          weight_data_transpose[index_transpose] = inval;
-        }
-      }
-    }
-  }
-}
-
-class TensorWrapper final {
-  friend class TensorPool;
-
- public:
-  explicit TensorWrapper();
-
-  explicit TensorWrapper(std::uint32_t id, Qnn_TensorType_t tensor_type,
-                         Qnn_DataType_t data_type,
-                         const QuantizeParamsWrapperVariant& quantize_params,
-                         const std::vector<std::uint32_t>& dimentions);
-
-  explicit TensorWrapper(std::uint32_t id, Qnn_TensorType_t tensor_type,
-                         Qnn_DataType_t data_type,
-                         const QuantizeParamsWrapperVariant& quantize_params,
-                         const std::vector<std::uint32_t>& dimentions,
-                         std::uint32_t bytes, const void* data);
-
-  TensorWrapper(const TensorWrapper& other);
-
-  TensorWrapper(TensorWrapper&& other);
-
-  ~TensorWrapper();
-
-  void CloneTo(Qnn_Tensor_t& dst) const;
-
-  Qnn_Tensor_t& GetQnnTensor() { return qnn_tensor_; }
-
-  std::uint32_t GetRank() const;
-
-  std::uint32_t GetDim(size_t index) const;
-
-  const std::vector<std::uint32_t>& GetDims() const { return dimentions_; };
-
-  std::uint32_t GetTensorNumElements() const;
-
-  const QuantizeParamsWrapperVariant& GetQuantParams() const {
-    return quantize_params_;
-  };
-
-  QuantizeParamsWrapperVariant& GetQuantParams() { return quantize_params_; };
-
-  bool IsQuant() const {
-    return !std::holds_alternative<UndefinedQuantizeParamsWrapper>(
-        quantize_params_);
-  };
-
-  bool IsPerTensorQuant() const {
-    return std::holds_alternative<ScaleOffsetQuantizeParamsWrapper>(
-        quantize_params_);
-  }
-
-  bool IsPerChannelQuant() const {
-    return std::holds_alternative<AxisScaleOffsetQuantizeParamsWrapper>(
-        quantize_params_);
-  }
-
-  bool IsPerTensorQuantWithOffsetDiff(const TensorWrapper& rhs) const;
-
-  bool IsQuant8() const {
-    return GetDataType() == QNN_DATATYPE_SFIXED_POINT_8 ||
-           GetDataType() == QNN_DATATYPE_UFIXED_POINT_8;
-  }
-
-  bool IsQuant16() const {
-    return GetDataType() == QNN_DATATYPE_SFIXED_POINT_16 ||
-           GetDataType() == QNN_DATATYPE_UFIXED_POINT_16;
-  }
-
-  bool IsF32() const { return GetDataType() == QNN_DATATYPE_FLOAT_32; }
-  bool IsF16() const { return GetDataType() == QNN_DATATYPE_FLOAT_16; }
-
-  Qnn_DataType_t GetDataType() const;
-
-  bool IsSubgraphInput() const {
-    return GetTensorType() == QNN_TENSOR_TYPE_APP_WRITE;
-  }
-
-  bool IsSubgraphOutput() const {
-    return GetTensorType() == QNN_TENSOR_TYPE_APP_READ;
-  }
-
-  bool IsTensorStatic() const {
-    return GetTensorType() == QNN_TENSOR_TYPE_STATIC;
-  }
-
-  template <typename T>
-  bool SetTensorData(absl::Span<const T> data) {
-    if (!IsSubgraphInput() && !IsTensorStatic()) {
-      QNN_LOG_ERROR(
-          "Cannot set tensor data of tensor type other than "
-          "QNN_TENSOR_TYPE_APP_WRITE or QNN_TENSOR_TYPE_STATIC.");
-      return false;
-    }
-
-    size_t num_elements = GetTensorNumElements();
-    if (!num_elements) {
-      QNN_LOG_ERROR("Cannot set tensor data, number of elements = 0");
-      return false;
-    }
-
-    size_t data_bytes = sizeof(T) * data.size();
-    size_t tensor_bytes = GetTensorBytes();
-    if (tensor_bytes > data_bytes) {
-      QNN_LOG_ERROR(
-          "Tensor bytes: %d > given data bytes: %d, SetTensorData failed.",
-          tensor_bytes, data_bytes);
-      return false;
-    }
-    if (tensor_bytes < data_bytes) {
-      QNN_LOG_WARNING(
-          "Tensor bytes : %d < given data bytes: %d, using only %d.",
-          tensor_bytes, data_bytes, tensor_bytes);
-    }
-
-    if constexpr (std::is_same_v<T, float>) {
-      if (qnn_tensor_.v2.dataType != QNN_DATATYPE_FLOAT_32) {
-        QNN_LOG_ERROR(
-            "Cannot set tensor data, setting float data on QNN data type %d.",
-            qnn_tensor_.v2.dataType);
-        return false;
-      }
-    } else if constexpr (std::is_same_v<T, std::int8_t>) {
-      if (qnn_tensor_.v2.dataType != QNN_DATATYPE_INT_8 &&
-          qnn_tensor_.v2.dataType != QNN_DATATYPE_SFIXED_POINT_8) {
-        QNN_LOG_ERROR(
-            "Cannot set tensor data, setting std::int8_t data on QNN data type "
-            "%d.",
-            qnn_tensor_.v2.dataType);
-        return false;
-      }
-    } else if constexpr (std::is_same_v<T, std::uint8_t>) {
-      if (qnn_tensor_.v2.dataType != QNN_DATATYPE_UINT_8 &&
-          qnn_tensor_.v2.dataType != QNN_DATATYPE_UFIXED_POINT_8) {
-        QNN_LOG_ERROR(
-            "Cannot set tensor data, setting std::uint8_t data on QNN data "
-            "type %d.",
-            qnn_tensor_.v2.dataType);
-        return false;
-      }
-    } else if constexpr (std::is_same_v<T, std::int16_t>) {
-      if (qnn_tensor_.v2.dataType != QNN_DATATYPE_INT_16 &&
-          qnn_tensor_.v2.dataType != QNN_DATATYPE_SFIXED_POINT_16) {
-        QNN_LOG_ERROR(
-            "Cannot set tensor data, setting std::int16_t data on QNN data "
-            "type %d.",
-            qnn_tensor_.v2.dataType);
-        return false;
-      }
-    } else if constexpr (std::is_same_v<T, std::uint16_t>) {
-      if (qnn_tensor_.v2.dataType != QNN_DATATYPE_UINT_16 &&
-          qnn_tensor_.v2.dataType != QNN_DATATYPE_UFIXED_POINT_16) {
-        QNN_LOG_ERROR(
-            "Cannot set tensor data, setting std::uint16_t data on QNN data "
-            "type %d.",
-            qnn_tensor_.v2.dataType);
-        return false;
-      }
-
-    } else if constexpr (std::is_same_v<T, std::int32_t>) {
-      if (qnn_tensor_.v2.dataType != QNN_DATATYPE_INT_32 &&
-          qnn_tensor_.v2.dataType != QNN_DATATYPE_SFIXED_POINT_32) {
-        QNN_LOG_ERROR(
-            "Cannot set tensor data, setting std::int32_t data on QNN data "
-            "type %d.",
-            qnn_tensor_.v2.dataType);
-        return false;
-      }
-    } else if constexpr (std::is_same_v<T, std::uint32_t>) {
-      if (qnn_tensor_.v2.dataType != QNN_DATATYPE_UINT_32 &&
-          qnn_tensor_.v2.dataType != QNN_DATATYPE_UFIXED_POINT_32) {
-        QNN_LOG_ERROR(
-            "Cannot set tensor data, setting std::uint32_t data on QNN data "
-            "type %d.",
-            qnn_tensor_.v2.dataType);
-        return false;
-      }
-    } else {
-      QNN_LOG_ERROR("Cannot set tensor data, unknown data type.");
-      return false;
-    }
-
-    owned_data_.resize(tensor_bytes);
-    std::memcpy(owned_data_.data(), reinterpret_cast(data.data()),
-                tensor_bytes);
-    qnn_tensor_.v2.clientBuf.dataSize = owned_data_.size();
-    qnn_tensor_.v2.clientBuf.data = owned_data_.data();
-    return true;
-  }
-
-  // Allocate memory on owned_data_ for output tensors
-  void AllocateOutputTensorBuffer() {
-    owned_data_.resize(GetTensorBytes());
-    qnn_tensor_.v2.clientBuf.dataSize = owned_data_.size();
-    qnn_tensor_.v2.clientBuf.data = owned_data_.data();
-  }
-
-  template <typename T>
-  std::optional<absl::Span<const T>> GetStaticTensorData() const;
-
-  void ConvertAxisScaleOffsetToScaleOffset() {
-    if
(!std::holds_alternative( - quantize_params_)) { - return; - } - - quantize_params_.emplace(0.0, 0); - } - - size_t GetTensorBytes() const; - - void ConvertQint16ToQuint16(); - - private: - Qnn_TensorType_t GetTensorType() const; - - void SetDataBy(std::uint32_t bytes, const void* data); - - bool HasStaticData() const { - return qnn_tensor_.v2.clientBuf.dataSize != 0 && - qnn_tensor_.v2.clientBuf.data != nullptr; - } - - Qnn_Tensor_t qnn_tensor_{.version = QNN_TENSOR_VERSION_2, - .v2 = QNN_TENSOR_V2_INIT}; - std::string name_{}; - std::vector dimentions_{}; - QuantizeParamsWrapperVariant quantize_params_{}; - std::vector owned_data_{}; -}; - -using TensorWrapperRef = std::reference_wrapper; - -template -std::optional> TensorWrapper::GetStaticTensorData() const { - if (!IsTensorStatic()) { - QNN_LOG_ERROR( - "Cannot GetStaticTensorData() on a non-static tensor, tensor type %d.", - GetTensorType()); - return std::nullopt; - } - - if (GetDataType() != GetQnnDataType(IsQuant())) { - QNN_LOG_ERROR("GetStaticTensorData() with incorrect template type."); - return std::nullopt; - } - - if (!HasStaticData()) { - QNN_LOG_ERROR("Empty static tensor data."); - return std::nullopt; - } - - if (qnn_tensor_.v2.clientBuf.dataSize != GetTensorBytes()) { - QNN_LOG_ERROR("Tensor bytes != stored data bytes."); - return std::nullopt; - } - - uint32_t num_elements = qnn_tensor_.v2.clientBuf.dataSize / sizeof(T); - if (!num_elements) { - QNN_LOG_ERROR("No element in this tensor."); - return std::nullopt; - } - - return absl::MakeConstSpan( - reinterpret_cast(qnn_tensor_.v2.clientBuf.data), num_elements); -} -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_TENSOR_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/BUILD deleted file mode 100644 index 6617e6ff0198e0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/BUILD +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_test( - name = "op_wrapper_test", - srcs = [ - "op_wrapper_test.cc", - ], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_googletest//:gtest_main", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_test( - name = "tensor_wrapper_test", - srcs = [ - "tensor_wrapper_test.cc", - ], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:miscs", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_test( - name = "param_wrapper_test", - srcs = [ - "param_wrapper_test.cc", - ], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_googletest//:gtest_main", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:param_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_test( - name = "quantize_params_wrapper_test", - srcs = [ - "quantize_params_wrapper_test.cc", - ], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_googletest//:gtest_main", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/op_wrapper_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/op_wrapper_test.cc deleted file mode 100644 index 60d121142ab5f7..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/op_wrapper_test.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" - -#include -#include -#include -#include -#include -#include - -#include -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { -namespace { - -void EXPECT_TENSOR_EQ(Qnn_Tensor_t actual, Qnn_Tensor_t expected) { - EXPECT_EQ(actual.v2.id, expected.v2.id); - EXPECT_EQ(actual.v2.type, expected.v2.type); - EXPECT_EQ(actual.v2.dataFormat, expected.v2.dataFormat); - EXPECT_EQ(actual.v2.dataType, expected.v2.dataType); - EXPECT_EQ(actual.v2.quantizeParams.encodingDefinition, - expected.v2.quantizeParams.encodingDefinition); - EXPECT_EQ(actual.v2.rank, expected.v2.rank); - for (size_t i = 0; i < actual.v2.rank; i++) { - EXPECT_EQ(actual.v2.dimensions[i], expected.v2.dimensions[i]); - } - EXPECT_EQ(actual.v2.memType, expected.v2.memType); - EXPECT_EQ(actual.v2.clientBuf.dataSize, expected.v2.clientBuf.dataSize); - const auto* actual_data = - reinterpret_cast(actual.v2.clientBuf.data); - const auto* expected_data = - reinterpret_cast(expected.v2.clientBuf.data); - for (size_t i = 0; i < actual.v2.clientBuf.dataSize; i++) { - EXPECT_EQ(actual_data[i], expected_data[i]); - } -} - -TEST(OpWrapperTest, SanityTest) { - OpWrapper op_wrapper{"name", "OP_TYPE"}; - const Qnn_OpConfig_t& op_config = op_wrapper.GetOpConfig(); - EXPECT_EQ(op_config.version, QNN_OPCONFIG_VERSION_1); - - const Qnn_OpConfigV1_t& op_config_v1 = op_config.v1; - EXPECT_STREQ(op_config_v1.typeName, "OP_TYPE"); - EXPECT_STREQ(op_config_v1.packageName, QNN_OP_PACKAGE_NAME_QTI_AISW); - EXPECT_STREQ(op_config_v1.name, "name"); - EXPECT_EQ(op_config_v1.numOfInputs, 0); - EXPECT_EQ(op_config_v1.numOfOutputs, 0); - EXPECT_EQ(op_config_v1.numOfParams, 0); - EXPECT_EQ(op_config_v1.params, nullptr); - EXPECT_EQ(op_config_v1.inputTensors, nullptr); - EXPECT_EQ(op_config_v1.outputTensors, nullptr); -} - -TEST(OpWrapperTest, MoveCtorSanityTest) { - OpWrapper op_wrapper{"name", "OP_TYPE"}; - OpWrapper moved{std::move(op_wrapper)}; - const Qnn_OpConfig_t& op_config = moved.GetOpConfig(); - EXPECT_EQ(op_config.version, QNN_OPCONFIG_VERSION_1); - - const Qnn_OpConfigV1_t& op_config_v1 = op_config.v1; - EXPECT_STREQ(op_config_v1.typeName, "OP_TYPE"); - EXPECT_STREQ(op_config_v1.packageName, QNN_OP_PACKAGE_NAME_QTI_AISW); - EXPECT_STREQ(op_config_v1.name, "name"); - EXPECT_EQ(op_config_v1.numOfInputs, 0); - EXPECT_EQ(op_config_v1.numOfOutputs, 0); - EXPECT_EQ(op_config_v1.numOfParams, 0); - EXPECT_EQ(op_config_v1.params, nullptr); - EXPECT_EQ(op_config_v1.inputTensors, nullptr); - EXPECT_EQ(op_config_v1.outputTensors, nullptr); -} - -TEST(OpWrapperTest, OpConfigTest) { - std::vector dummy_dims = {1, 1, 3}; - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - const auto data_size = - std::accumulate(dummy_dims.begin(), dummy_dims.end(), - sizeof(decltype(data)::value_type), std::multiplies<>()); - - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_WRITE, - QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(), - dummy_dims, - static_cast(data_size), - data_ptr}; - - Qnn_Tensor_t golden_qnn_tensor; - tensor_wrapper.CloneTo(golden_qnn_tensor); - - std::uint8_t value = 255; - OpWrapper 
op_wrapper{"name", "OP_TYPE"}; - op_wrapper.AddInputTensor(tensor_wrapper); - op_wrapper.AddOutputTensor(tensor_wrapper); - op_wrapper.AddScalarParam("uint8_param", value, false); - op_wrapper.AddTensorParam("tensor_param", tensor_wrapper); - - Qnn_OpConfig_t op_config = op_wrapper.GetOpConfig(); - EXPECT_EQ(op_config.version, QNN_OPCONFIG_VERSION_1); - EXPECT_STREQ(op_config.v1.typeName, "OP_TYPE"); - EXPECT_STREQ(op_config.v1.packageName, QNN_OP_PACKAGE_NAME_QTI_AISW); - EXPECT_STREQ(op_config.v1.name, "name"); - - Qnn_OpConfigV1_t op_config_v1 = op_config.v1; - - EXPECT_EQ(op_config_v1.numOfInputs, 1); - EXPECT_EQ(op_config_v1.numOfOutputs, 1); - EXPECT_EQ(op_config_v1.numOfParams, 2); - EXPECT_TENSOR_EQ(op_config_v1.inputTensors[0], golden_qnn_tensor); - EXPECT_TENSOR_EQ(op_config_v1.outputTensors[0], golden_qnn_tensor); - EXPECT_EQ(op_config_v1.params[0].paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(op_config_v1.params[0].name, "uint8_param"); - EXPECT_EQ(op_config_v1.params[0].scalarParam.dataType, QNN_DATATYPE_UINT_8); - EXPECT_EQ(op_config_v1.params[0].scalarParam.uint8Value, value); - EXPECT_EQ(op_config_v1.params[1].paramType, QNN_PARAMTYPE_TENSOR); - EXPECT_EQ(op_config_v1.params[1].name, "tensor_param"); - EXPECT_TENSOR_EQ(op_config_v1.params[1].tensorParam, golden_qnn_tensor); -} - -TEST(OpWrapperTest, MoveConstructorTest) { - std::vector dummy_dims = {1, 1, 3}; - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_WRITE, - QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(), - dummy_dims, - static_cast(data.size()), - data_ptr}; - Qnn_Tensor_t golden_qnn_tensor; - tensor_wrapper.CloneTo(golden_qnn_tensor); - std::uint8_t value = 255; - OpWrapper op_wrapper{"name", "OP_TYPE"}; - op_wrapper.AddInputTensor(tensor_wrapper); - op_wrapper.AddOutputTensor(tensor_wrapper); - op_wrapper.AddScalarParam("uint8_param", value, false); - op_wrapper.AddTensorParam("tensor_param", tensor_wrapper); - OpWrapper op_wrapper_move(std::move(op_wrapper)); - Qnn_OpConfig_t op_config = op_wrapper_move.GetOpConfig(); - EXPECT_EQ(op_config.version, QNN_OPCONFIG_VERSION_1); - EXPECT_STREQ(op_config.v1.typeName, "OP_TYPE"); - EXPECT_STREQ(op_config.v1.packageName, QNN_OP_PACKAGE_NAME_QTI_AISW); - EXPECT_STREQ(op_config.v1.name, "name"); - Qnn_OpConfigV1_t op_config_v1 = op_config.v1; - EXPECT_EQ(op_config_v1.numOfInputs, 1); - EXPECT_EQ(op_config_v1.numOfOutputs, 1); - EXPECT_EQ(op_config_v1.numOfParams, 2); - EXPECT_TENSOR_EQ(op_config_v1.inputTensors[0], golden_qnn_tensor); - EXPECT_TENSOR_EQ(op_config_v1.outputTensors[0], golden_qnn_tensor); - EXPECT_EQ(op_config_v1.params[0].paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(op_config_v1.params[0].name, "uint8_param"); - EXPECT_EQ(op_config_v1.params[0].scalarParam.dataType, QNN_DATATYPE_UINT_8); - EXPECT_EQ(op_config_v1.params[0].scalarParam.uint8Value, value); - EXPECT_EQ(op_config_v1.params[1].paramType, QNN_PARAMTYPE_TENSOR); - EXPECT_EQ(op_config_v1.params[1].name, "tensor_param"); - EXPECT_TENSOR_EQ(op_config_v1.params[1].tensorParam, golden_qnn_tensor); -} - -} // namespace -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/param_wrapper_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/param_wrapper_test.cc deleted file mode 100644 index 1472e494306911..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/param_wrapper_test.cc 
+++ /dev/null @@ -1,232 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h" - -#include -#include -#include -#include -#include - -#include -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { -namespace { - -TEST(ScalarParamWrapperTest, BoolParamTest) { - ScalarParamWrapper bool_param{"bool_param", true, false}; - Qnn_Param_t bool_qnn_param = QNN_PARAM_INIT; - bool_param.CloneTo(bool_qnn_param); - EXPECT_EQ(bool_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(bool_qnn_param.name, "bool_param"); - EXPECT_EQ(bool_qnn_param.scalarParam.dataType, QNN_DATATYPE_BOOL_8); - EXPECT_EQ(bool_qnn_param.scalarParam.bool8Value, 1); -} - -TEST(ScalarParamWrapperTest, Uint8ParamTest) { - constexpr std::uint8_t value = 255; - ScalarParamWrapper uint8_param{"uint8_param", value, false}; - Qnn_Param_t uint8_qnn_param = QNN_PARAM_INIT; - uint8_param.CloneTo(uint8_qnn_param); - EXPECT_EQ(uint8_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint8_qnn_param.name, "uint8_param"); - EXPECT_EQ(uint8_qnn_param.scalarParam.dataType, QNN_DATATYPE_UINT_8); - EXPECT_EQ(uint8_qnn_param.scalarParam.uint8Value, value); -} - -TEST(ScalarParamWrapperTest, Int8ParamTest) { - constexpr std::int8_t value = -128; - ScalarParamWrapper int8_param{"int8_param", value, false}; - Qnn_Param_t int8_qnn_param = QNN_PARAM_INIT; - int8_param.CloneTo(int8_qnn_param); - EXPECT_EQ(int8_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int8_qnn_param.name, "int8_param"); - EXPECT_EQ(int8_qnn_param.scalarParam.dataType, QNN_DATATYPE_INT_8); - EXPECT_EQ(int8_qnn_param.scalarParam.int8Value, value); -} - -TEST(ScalarParamWrapperTest, Uint16ParamTest) { - constexpr std::uint16_t value = 65535; - ScalarParamWrapper uint16_param{"uint16_param", value, false}; - Qnn_Param_t uint16_qnn_param = QNN_PARAM_INIT; - uint16_param.CloneTo(uint16_qnn_param); - EXPECT_EQ(uint16_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint16_qnn_param.name, "uint16_param"); - EXPECT_EQ(uint16_qnn_param.scalarParam.dataType, QNN_DATATYPE_UINT_16); - EXPECT_EQ(uint16_qnn_param.scalarParam.uint16Value, value); -} - -TEST(ScalarParamWrapperTest, Int16ParamTest) { - constexpr std::int16_t value = -32768; - ScalarParamWrapper int16_param{"int16_param", value, false}; - Qnn_Param_t int16_qnn_param = QNN_PARAM_INIT; - int16_param.CloneTo(int16_qnn_param); - EXPECT_EQ(int16_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int16_qnn_param.name, "int16_param"); - EXPECT_EQ(int16_qnn_param.scalarParam.dataType, QNN_DATATYPE_INT_16); - EXPECT_EQ(int16_qnn_param.scalarParam.int16Value, value); -} - -TEST(ScalarParamWrapperTest, Uint32ParamTest) { - constexpr std::uint32_t value = 4294967295; - ScalarParamWrapper uint32_param{"uint32_param", value, false}; - Qnn_Param_t uint32_qnn_param = QNN_PARAM_INIT; - uint32_param.CloneTo(uint32_qnn_param); - EXPECT_EQ(uint32_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint32_qnn_param.name, "uint32_param"); - EXPECT_EQ(uint32_qnn_param.scalarParam.dataType, QNN_DATATYPE_UINT_32); - EXPECT_EQ(uint32_qnn_param.scalarParam.uint32Value, value); -} - -TEST(ScalarParamWrapperTest, Int32ParamTest) { - constexpr 
std::int32_t value = -2147483648; - ScalarParamWrapper int32_param{"int32_param", value, false}; - Qnn_Param_t int32_qnn_param = QNN_PARAM_INIT; - int32_param.CloneTo(int32_qnn_param); - EXPECT_EQ(int32_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int32_qnn_param.name, "int32_param"); - EXPECT_EQ(int32_qnn_param.scalarParam.dataType, QNN_DATATYPE_INT_32); - EXPECT_EQ(int32_qnn_param.scalarParam.int32Value, value); -} - -TEST(ScalarParamWrapperTest, FloatParamTest) { - constexpr float value = 3.14f; - ScalarParamWrapper float_param{"float_param", value, false}; - Qnn_Param_t float_qnn_param = QNN_PARAM_INIT; - float_param.CloneTo(float_qnn_param); - EXPECT_EQ(float_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(float_qnn_param.name, "float_param"); - EXPECT_EQ(float_qnn_param.scalarParam.dataType, QNN_DATATYPE_FLOAT_32); - EXPECT_FLOAT_EQ(float_qnn_param.scalarParam.floatValue, value); -} - -TEST(ScalarParamWrapperTest, QuantizedBoolParamTest) { - ScalarParamWrapper bool_quant_param{"bool_quant_param", true, true}; - Qnn_Param_t bool_quant_qnn_param = QNN_PARAM_INIT; - bool_quant_param.CloneTo(bool_quant_qnn_param); - EXPECT_EQ(bool_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(bool_quant_qnn_param.name, "bool_quant_param"); - EXPECT_EQ(bool_quant_qnn_param.scalarParam.dataType, QNN_DATATYPE_BOOL_8); - EXPECT_EQ(bool_quant_qnn_param.scalarParam.bool8Value, 1); -} - -TEST(ScalarParamWrapperTest, QuantizedUint8ParamTest) { - constexpr std::uint8_t value = 255; - ScalarParamWrapper uint8_quant_param{"uint8_quant_param", value, true}; - Qnn_Param_t uint8_quant_qnn_param = QNN_PARAM_INIT; - uint8_quant_param.CloneTo(uint8_quant_qnn_param); - EXPECT_EQ(uint8_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint8_quant_qnn_param.name, "uint8_quant_param"); - EXPECT_EQ(uint8_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_EQ(uint8_quant_qnn_param.scalarParam.uint8Value, value); -} - -TEST(ScalarParamWrapperTest, QuantizedInt8ParamTest) { - constexpr std::int8_t value = -128; - ScalarParamWrapper int8_quant_param{"int8_quant_param", value, true}; - Qnn_Param_t int8_quant_qnn_param = QNN_PARAM_INIT; - int8_quant_param.CloneTo(int8_quant_qnn_param); - EXPECT_EQ(int8_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int8_quant_qnn_param.name, "int8_quant_param"); - EXPECT_EQ(int8_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_SFIXED_POINT_8); - EXPECT_EQ(int8_quant_qnn_param.scalarParam.int8Value, value); -} - -TEST(ScalarParamWrapperTest, QuantizedUint16ParamTest) { - constexpr std::uint16_t value = 65535; - ScalarParamWrapper uint16_quant_param{"uint16_quant_param", value, true}; - Qnn_Param_t uint16_quant_qnn_param = QNN_PARAM_INIT; - uint16_quant_param.CloneTo(uint16_quant_qnn_param); - EXPECT_EQ(uint16_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint16_quant_qnn_param.name, "uint16_quant_param"); - EXPECT_EQ(uint16_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_UFIXED_POINT_16); - EXPECT_EQ(uint16_quant_qnn_param.scalarParam.uint16Value, value); -} - -TEST(ScalarParamWrapperTest, QuantizedInt16ParamTest) { - constexpr std::int16_t value = -32768; - ScalarParamWrapper int16_quant_param{"int16_quant_param", value, true}; - Qnn_Param_t int16_quant_qnn_param = QNN_PARAM_INIT; - int16_quant_param.CloneTo(int16_quant_qnn_param); - EXPECT_EQ(int16_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int16_quant_qnn_param.name, "int16_quant_param"); - 
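-  // Note: with const char* operands, EXPECT_EQ compares pointer identity
-  // rather than characters; EXPECT_STREQ (as used in op_wrapper_test.cc)
-  // compares the string contents.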
EXPECT_EQ(int16_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_SFIXED_POINT_16); - EXPECT_EQ(int16_quant_qnn_param.scalarParam.int16Value, value); -} - -TEST(ScalarParamWrapperTest, QuantizedUint32ParamTest) { - constexpr std::uint32_t value = 4294967295; - ScalarParamWrapper uint32_quant_param{"uint32_quant_param", value, true}; - Qnn_Param_t uint32_quant_qnn_param = QNN_PARAM_INIT; - uint32_quant_param.CloneTo(uint32_quant_qnn_param); - EXPECT_EQ(uint32_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint32_quant_qnn_param.name, "uint32_quant_param"); - EXPECT_EQ(uint32_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_UFIXED_POINT_32); - EXPECT_EQ(uint32_quant_qnn_param.scalarParam.uint32Value, value); -} - -TEST(ScalarParamWrapperTest, QuantizedInt32ParamTest) { - constexpr std::int32_t value = -2147483648; - ScalarParamWrapper int32_quant_param{"int32_quant_param", value, true}; - Qnn_Param_t int32_quant_qnn_param = QNN_PARAM_INIT; - int32_quant_param.CloneTo(int32_quant_qnn_param); - EXPECT_EQ(int32_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int32_quant_qnn_param.name, "int32_quant_param"); - EXPECT_EQ(int32_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_SFIXED_POINT_32); - EXPECT_EQ(int32_quant_qnn_param.scalarParam.int32Value, value); -} - -TEST(ParamWrapperTest, TensorParamTest) { - std::vector dummy_dims = {1, 1, 3}; - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - - const auto data_size = - std::accumulate(dummy_dims.begin(), dummy_dims.end(), - sizeof(decltype(data)::value_type), std::multiplies<>()); - - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(), - dummy_dims, - static_cast(data_size), - data_ptr}; - - TensorParamWrapper tensor_param{"tensor_param", tensor_wrapper}; - - Qnn_Param_t qnn_tensor_param = QNN_PARAM_INIT; - tensor_param.CloneTo(qnn_tensor_param); - EXPECT_EQ(qnn_tensor_param.paramType, QNN_PARAMTYPE_TENSOR); - EXPECT_EQ(qnn_tensor_param.name, "tensor_param"); - - Qnn_Tensor_t& ref = qnn_tensor_param.tensorParam; - EXPECT_EQ(ref.v2.id, 0); - EXPECT_EQ(ref.v2.type, QNN_TENSOR_TYPE_STATIC); - EXPECT_EQ(ref.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER); - EXPECT_EQ(ref.v2.dataType, QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_EQ(ref.v2.quantizeParams.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(ref.v2.rank, dummy_dims.size()); - for (size_t i = 0; i < ref.v2.rank; i++) { - EXPECT_EQ(ref.v2.dimensions[i], dummy_dims[i]); - } - EXPECT_EQ(ref.v2.memType, QNN_TENSORMEMTYPE_RAW); - EXPECT_EQ(ref.v2.clientBuf.dataSize, data_size); - const auto* ref_data = - reinterpret_cast(ref.v2.clientBuf.data); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ(ref_data[i], data[i]); - } -} -} // namespace -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/quantize_params_wrapper_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/quantize_params_wrapper_test.cc deleted file mode 100644 index 8ed03dc50689ea..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/quantize_params_wrapper_test.cc +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" - -#include -#include -#include -#include - -#include -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" - -namespace qnn { -namespace { - -TEST(UndefinedQuantizeParamsWrapperTest, DefaultConstructorTest) { - UndefinedQuantizeParamsWrapper wrapper; - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_UNDEFINED); -} - -TEST(UndefinedQuantizeParamsWrapperTest, CopyConstructorTest) { - UndefinedQuantizeParamsWrapper wrapper1; - UndefinedQuantizeParamsWrapper wrapper2(wrapper1); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_UNDEFINED); -} - -TEST(UndefinedQuantizeParamsWrapperTest, MoveConstructorTest) { - UndefinedQuantizeParamsWrapper wrapper1; - UndefinedQuantizeParamsWrapper wrapper2(std::move(wrapper1)); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_UNDEFINED); -} - -TEST(ScaleOffsetQuantizeParamsWrapperTest, ConstructorTest) { - float scale = 1.5f; - std::int32_t zero_point = 10; - ScaleOffsetQuantizeParamsWrapper wrapper(scale, zero_point); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); - EXPECT_FLOAT_EQ(dst.scaleOffsetEncoding.scale, scale); - EXPECT_EQ(dst.scaleOffsetEncoding.offset, -zero_point); -} - -TEST(ScaleOffsetQuantizeParamsWrapperTest, CopyConstructorTest) { - float scale = 1.5f; - std::int32_t zero_point = 10; - ScaleOffsetQuantizeParamsWrapper wrapper1(scale, zero_point); - ScaleOffsetQuantizeParamsWrapper wrapper2(wrapper1); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); - EXPECT_FLOAT_EQ(dst.scaleOffsetEncoding.scale, scale); - EXPECT_EQ(dst.scaleOffsetEncoding.offset, -zero_point); -} - -TEST(ScaleOffsetQuantizeParamsWrapperTest, MoveConstructorTest) { - float scale = 1.5f; - std::int32_t zero_point = 10; - ScaleOffsetQuantizeParamsWrapper wrapper1(scale, zero_point); - ScaleOffsetQuantizeParamsWrapper wrapper2(std::move(wrapper1)); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); - EXPECT_FLOAT_EQ(dst.scaleOffsetEncoding.scale, scale); - EXPECT_EQ(dst.scaleOffsetEncoding.offset, -zero_point); -} - -TEST(ScaleOffsetQuantizeParamsWrapperTest, GetterTest) { - float scale = 1.5f; - std::int32_t zero_point = 10; - ScaleOffsetQuantizeParamsWrapper wrapper(scale, zero_point); - EXPECT_FLOAT_EQ(wrapper.GetScale(), scale); - EXPECT_EQ(wrapper.GetZeroPoint(), zero_point); -} - -TEST(AxisScaleOffsetQuantizeParamsWrapperTest, ConstructorTest) { - std::int32_t axis = 1; - std::vector scales = {1.5f, 2.5f}; - std::vector zero_points = {10, 20}; - 
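-  // Per-channel encoding carries one {scale, offset} pair per slice along
-  // `axis`, again storing each offset as the negated zero point (see the
-  // loop assertions below).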
AxisScaleOffsetQuantizeParamsWrapper wrapper(axis, scales, zero_points); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); - EXPECT_EQ(dst.axisScaleOffsetEncoding.axis, axis); - EXPECT_EQ(dst.axisScaleOffsetEncoding.numScaleOffsets, scales.size()); - for (size_t i = 0; i < scales.size(); ++i) { - EXPECT_FLOAT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].scale, - scales[i]); - EXPECT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].offset, - -zero_points[i]); - } -} - -TEST(AxisScaleOffsetQuantizeParamsWrapperTest, CopyConstructorTest) { - std::int32_t axis = 1; - std::vector scales = {1.5f, 2.5f}; - std::vector zero_points = {10, 20}; - AxisScaleOffsetQuantizeParamsWrapper wrapper1(axis, scales, zero_points); - AxisScaleOffsetQuantizeParamsWrapper wrapper2(wrapper1); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); - EXPECT_EQ(dst.axisScaleOffsetEncoding.axis, axis); - EXPECT_EQ(dst.axisScaleOffsetEncoding.numScaleOffsets, scales.size()); - for (size_t i = 0; i < scales.size(); ++i) { - EXPECT_FLOAT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].scale, - scales[i]); - EXPECT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].offset, - -zero_points[i]); - } -} - -TEST(AxisScaleOffsetQuantizeParamsWrapperTest, MoveConstructorTest) { - std::int32_t axis = 1; - std::vector scales = {1.5f, 2.5f}; - std::vector zero_points = {10, 20}; - AxisScaleOffsetQuantizeParamsWrapper wrapper1(axis, scales, zero_points); - AxisScaleOffsetQuantizeParamsWrapper wrapper2(std::move(wrapper1)); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); - EXPECT_EQ(dst.axisScaleOffsetEncoding.axis, axis); - EXPECT_EQ(dst.axisScaleOffsetEncoding.numScaleOffsets, scales.size()); - for (size_t i = 0; i < scales.size(); ++i) { - EXPECT_FLOAT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].scale, - scales[i]); - EXPECT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].offset, - -zero_points[i]); - } -} -TEST(AxisScaleOffsetQuantizeParamsWrapperTest, GetterTest) { - std::int32_t axis = 1; - std::vector scales = {1.5f, 2.5f}; - std::vector zero_points = {10, 20}; - AxisScaleOffsetQuantizeParamsWrapper wrapper(axis, scales, zero_points); - std::vector scales_out; - wrapper.GetScales(scales_out); - EXPECT_EQ(scales, scales_out); - std::vector zero_points_out; - wrapper.GetZeroPoints(zero_points_out); - EXPECT_EQ(zero_points, zero_points_out); -} -} // namespace -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/tensor_wrapper_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/tensor_wrapper_test.cc deleted file mode 100644 index 68e1828181ef9b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/tensor_wrapper_test.cc +++ /dev/null @@ -1,309 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
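The tensor_wrapper tests that follow use Quantize/Dequantize helpers from core/utils/miscs.h whose definitions are not shown in this patch; they are assumed to implement the standard affine mapping. A sketch under that assumption (hypothetical stand-ins, not the miscs.h API):

    #include <cmath>
    #include <cstdint>

    // Assumed affine quantization: q = round(x / scale) + zero_point.
    std::int16_t QuantizeSketch(float x, float scale, std::int32_t zero_point) {
      return static_cast<std::int16_t>(std::lround(x / scale) + zero_point);
    }
    // Inverse mapping: x = scale * (q - zero_point).
    float DequantizeSketch(std::int16_t q, float scale, std::int32_t zero_point) {
      return scale * static_cast<float>(q - zero_point);
    }
    // With scale = 1e-4 and zero_point = 0, as in ConvertQint16ToQuint16Test,
    // the round trip is exact to within scale/2, well inside EXPECT_NEAR's 1e-3.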
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" - -namespace qnn { -namespace { - -TEST(TensorWrapperTest, SanityTest) { - TensorWrapper tensor_wrapper{}; - - EXPECT_EQ(tensor_wrapper.GetRank(), 0); - EXPECT_TRUE(tensor_wrapper.GetDims().empty()); - EXPECT_TRUE(std::holds_alternative( - tensor_wrapper.GetQuantParams())); - EXPECT_FALSE(tensor_wrapper.IsPerTensorQuantWithOffsetDiff(tensor_wrapper)); - EXPECT_FALSE(tensor_wrapper.IsQuant8()); - EXPECT_FALSE(tensor_wrapper.IsQuant16()); - EXPECT_EQ(tensor_wrapper.GetDataType(), QNN_DATATYPE_UNDEFINED); - EXPECT_FALSE(tensor_wrapper.IsSubgraphInput()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphOutput()); - EXPECT_FALSE(tensor_wrapper.IsTensorStatic()); - EXPECT_EQ(tensor_wrapper.GetStaticTensorData(), std::nullopt); - std::vector data = {1, 2, 3}; - // expect no use, since tensor type not correct - tensor_wrapper.SetTensorData( - absl::MakeSpan(data.data(), data.size())); - EXPECT_EQ(tensor_wrapper.GetStaticTensorData(), std::nullopt); -} - -TEST(TensorWrapperTest, CopyTensorTest) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(1, 0); - TensorWrapper tensor_wrapper{0, QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, q_param, - dummy_dims}; - TensorWrapper copied{tensor_wrapper}; - - EXPECT_EQ(copied.GetRank(), 3); - EXPECT_EQ(copied.GetDims(), dummy_dims); - EXPECT_TRUE(std::holds_alternative( - copied.GetQuantParams())); - EXPECT_FALSE(copied.IsPerTensorQuantWithOffsetDiff(copied)); - EXPECT_TRUE(copied.IsQuant8()); - EXPECT_FALSE(copied.IsQuant16()); - EXPECT_EQ(copied.GetDataType(), QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_FALSE(copied.IsSubgraphInput()); - EXPECT_FALSE(copied.IsSubgraphOutput()); - EXPECT_TRUE(copied.IsTensorStatic()); - EXPECT_EQ(copied.GetStaticTensorData(), std::nullopt); - std::vector data = {1, 2, 3}; - copied.SetTensorData(absl::MakeSpan(data.data(), data.size())); - const auto tensor_data = copied.GetStaticTensorData(); - EXPECT_TRUE(tensor_data.has_value()); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ((*tensor_data)[i], data[i]); - } -} - -TEST(TensorWrapperTest, MoveTensorTest) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(1, 0); - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, - q_param, - dummy_dims, - static_cast(data.size()), - data_ptr}; - TensorWrapper moved{tensor_wrapper}; - - EXPECT_EQ(moved.GetRank(), 3); - EXPECT_EQ(moved.GetDims(), dummy_dims); - EXPECT_TRUE(std::holds_alternative( - moved.GetQuantParams())); - EXPECT_FALSE(moved.IsPerTensorQuantWithOffsetDiff(moved)); - EXPECT_TRUE(moved.IsQuant8()); - EXPECT_FALSE(moved.IsQuant16()); - EXPECT_EQ(moved.GetDataType(), QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_FALSE(moved.IsSubgraphInput()); - EXPECT_FALSE(moved.IsSubgraphOutput()); - EXPECT_TRUE(moved.IsTensorStatic()); - const auto tensor_data = moved.GetStaticTensorData(); - EXPECT_TRUE(tensor_data.has_value()); - for (size_t i = 0; i < data.size(); 
i++) { - EXPECT_EQ(tensor_data.value()[i], data[i]); - } -} - -TEST(TensorWrapperTest, QnnTensorTest) { - std::vector dummy_dims = {1, 1, 3}; - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - const auto data_size = - std::accumulate(dummy_dims.begin(), dummy_dims.end(), - sizeof(decltype(data)::value_type), std::multiplies<>()); - - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_WRITE, - QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(), - dummy_dims, - static_cast(data_size), - data_ptr}; - - Qnn_Tensor_t cloned; - tensor_wrapper.CloneTo(cloned); - EXPECT_EQ(cloned.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(cloned.v2.id, 0); - EXPECT_EQ(cloned.v2.type, QNN_TENSOR_TYPE_APP_WRITE); - EXPECT_EQ(cloned.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER); - EXPECT_EQ(cloned.v2.dataType, QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_EQ(cloned.v2.quantizeParams.encodingDefinition, - QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(cloned.v2.rank, dummy_dims.size()); - for (size_t i = 0; i < cloned.v2.rank; i++) { - EXPECT_EQ(cloned.v2.dimensions[i], dummy_dims[i]); - } - EXPECT_EQ(cloned.v2.memType, QNN_TENSORMEMTYPE_RAW); - EXPECT_EQ(cloned.v2.clientBuf.dataSize, data_size); - const auto* cloned_data = - reinterpret_cast(cloned.v2.clientBuf.data); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ(cloned_data[i], data[i]); - } - - Qnn_Tensor_t& ref = tensor_wrapper.GetQnnTensor(); - EXPECT_EQ(ref.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(ref.v2.id, 0); - EXPECT_EQ(ref.v2.type, QNN_TENSOR_TYPE_APP_WRITE); - EXPECT_EQ(ref.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER); - EXPECT_EQ(ref.v2.dataType, QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_EQ(ref.v2.quantizeParams.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(ref.v2.rank, dummy_dims.size()); - for (size_t i = 0; i < ref.v2.rank; i++) { - EXPECT_EQ(ref.v2.dimensions[i], dummy_dims[i]); - } - EXPECT_EQ(ref.v2.memType, QNN_TENSORMEMTYPE_RAW); - EXPECT_EQ(ref.v2.clientBuf.dataSize, data_size); - const auto* ref_data = - reinterpret_cast(ref.v2.clientBuf.data); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ(ref_data[i], data[i]); - } -} - -TEST(TensorWrapperTest, IsPerTensorQuantWithOffsetDiff8BitTest) { - constexpr int kSUFixed8OffsetDiff = 128; - ScaleOffsetQuantizeParamsWrapper wrapper1(1, 0); - ScaleOffsetQuantizeParamsWrapper wrapper2(1, kSUFixed8OffsetDiff); - TensorWrapper tensor_wrapper0{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(wrapper1), - {}}; - TensorWrapper tensor_wrapper1{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_SFIXED_POINT_8, - QuantizeParamsWrapperVariant(wrapper2), - {}}; - EXPECT_TRUE(tensor_wrapper0.IsPerTensorQuantWithOffsetDiff(tensor_wrapper1)); -} - -TEST(TensorWrapperTest, IsPerTensorQuantWithOffsetDiff16BitTest) { - constexpr int kSUFixed16OffsetDiff = 32768; - ScaleOffsetQuantizeParamsWrapper wrapper1(1, 0); - ScaleOffsetQuantizeParamsWrapper wrapper2(1, kSUFixed16OffsetDiff); - TensorWrapper tensor_wrapper0{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_16, - QuantizeParamsWrapperVariant(wrapper1), - {}}; - TensorWrapper tensor_wrapper1{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_SFIXED_POINT_16, - QuantizeParamsWrapperVariant(wrapper2), - {}}; - EXPECT_TRUE(tensor_wrapper0.IsPerTensorQuantWithOffsetDiff(tensor_wrapper1)); -} - -TEST(TensorWrapperTest, StaticTensorTest) { - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UNDEFINED, - QuantizeParamsWrapperVariant(), - {}}; 
- - EXPECT_TRUE(tensor_wrapper.IsTensorStatic()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphInput()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphOutput()); -} - -TEST(TensorWrapperTest, SubgraphInputTensorTest) { - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_WRITE, - QNN_DATATYPE_UNDEFINED, - QuantizeParamsWrapperVariant(), - {}}; - - EXPECT_FALSE(tensor_wrapper.IsTensorStatic()); - EXPECT_TRUE(tensor_wrapper.IsSubgraphInput()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphOutput()); -} - -TEST(TensorWrapperTest, SubgraphOutputTensorTest) { - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_READ, - QNN_DATATYPE_UNDEFINED, - QuantizeParamsWrapperVariant(), - {}}; - - EXPECT_FALSE(tensor_wrapper.IsTensorStatic()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphInput()); - EXPECT_TRUE(tensor_wrapper.IsSubgraphOutput()); -} - -TEST(TensorWrapperTest, GetStaticTensorDataNonStaticTest) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(1, 0); - TensorWrapper tensor_wrapper{0, QNN_TENSOR_TYPE_APP_WRITE, - QNN_DATATYPE_UFIXED_POINT_8, q_param, - dummy_dims}; - EXPECT_FALSE(tensor_wrapper.GetStaticTensorData().has_value()); -} - -TEST(TensorWrapperTest, GetStaticTensorDataTest) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(1, 0); - TensorWrapper tensor_wrapper{0, QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, q_param, - dummy_dims}; - - EXPECT_FALSE(tensor_wrapper.GetStaticTensorData().has_value()); - EXPECT_FALSE(tensor_wrapper.GetStaticTensorData().has_value()); - EXPECT_FALSE(tensor_wrapper.GetStaticTensorData().has_value()); - std::vector data = {1, 2, 3}; - tensor_wrapper.SetTensorData( - absl::MakeSpan(data.data(), data.size())); - const auto tensor_data = - *(tensor_wrapper.GetStaticTensorData()); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ(tensor_data[i], data[i]); - } -} - -TEST(TensorWrapperTest, ConvertQint16ToQuint16Test) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(0.0001, 0); - TensorWrapper tensor_wrapper{0, QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_SFIXED_POINT_16, q_param, - dummy_dims}; - - std::vector data = {1, 2, 3}; - const auto& int16_q_param_ref = tensor_wrapper.GetQuantParams(); - EXPECT_TRUE(std::holds_alternative( - int16_q_param_ref)); - const float int16_scale = - std::get(int16_q_param_ref).GetScale(); - const std::int32_t int16_zero_point = - std::get(int16_q_param_ref) - .GetZeroPoint(); - std::vector int16_data; - for (int i = 0; i < data.size(); ++i) { - int16_data.emplace_back( - Quantize(data[i], int16_scale, int16_zero_point)); - } - tensor_wrapper.SetTensorData( - absl::MakeSpan(int16_data.data(), int16_data.size())); - - tensor_wrapper.ConvertQint16ToQuint16(); - - const auto& uint16_q_param_ref = tensor_wrapper.GetQuantParams(); - EXPECT_TRUE(std::holds_alternative( - uint16_q_param_ref)); - const float uint16_scale = - std::get(uint16_q_param_ref).GetScale(); - const std::int32_t uint16_zero_point = - std::get(uint16_q_param_ref) - .GetZeroPoint(); - const auto uint16_data = - *(tensor_wrapper.GetStaticTensorData()); - std::vector deq_data; - for (size_t i = 0; i < data.size(); i++) { - deq_data.emplace_back( - Dequantize(uint16_data[i], uint16_scale, uint16_zero_point)); - } - ASSERT_EQ(data.size(), deq_data.size()); - for (size_t i = 0; i < data.size(); ++i) { - EXPECT_NEAR(data[i], deq_data[i], 1e-3); - } -} -} // namespace -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD 
b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD deleted file mode 100644 index 2809dcc115188f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "copy_file", "litert_dynamic_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -litert_dynamic_lib( - name = "dispatch_api", - srcs = [ - "dispatch_api.cc", - "litert_dispatch_device_context.cc", - "litert_dispatch_invocation_context.cc", - ], - hdrs = [ - "litert_dispatch_device_context.h", - "litert_dispatch_invocation_context.h", - "registry.h", - ], - copts = [ - "-Os", - "-fno-exceptions", - "-fno-unwind-tables", - "-fno-asynchronous-unwind-tables", - "-ffunction-sections", - "-fdata-sections", - ], - export_litert_only = True, - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }) + [ - "-Wl,-soname=libLiteRtDispatch_Qualcomm.so", - "-Wl,-lc++abi", - ], - shared_lib_name = "dispatch_api_so", - so_name = "libLiteRtDispatch_Qualcomm.so", - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:context_binary_info", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - ], -) - -# This is cc_library target for `libLiteRtDispatch_Qualcomm.so`. -cc_library( - name = "dispatch_api_shared_lib", - srcs = [":dispatch_api_so"], -) - -# Copies the shared library so that it is available for use in test data as libLiteRtDispatch_Qualcomm.so. 
-copy_file( - name = "copy_dispatch_api_so", - src = "//tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so", - target = "libLiteRtDispatch_Qualcomm.so", -) - -cc_test( - name = "dispatch_api_qualcomm_test", - srcs = [ - "dispatch_api_qualcomm_test.cc", - ], - data = [ - ":dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - linkstatic = 1, - tags = [ - "no-remote-exec", - "no_oss", - "notap", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc deleted file mode 100644 index f377e1a26581e1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc +++ /dev/null @@ -1,292 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
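The `dispatch_api.cc` deleted below is what the BUILD rule above packages as `libLiteRtDispatch_Qualcomm.so` (per its `so_name` and `-Wl,-soname` settings), and `LiteRtDispatchGetApi` is its single exported entry point. As a rough sketch of how a host runtime might consume such a vendor library: POSIX `dlopen`/`dlsym` are assumed, and `GetApiFn`/`LoadQualcommDispatch` are hypothetical names, not part of this patch.

// Hypothetical loader sketch: resolve LiteRtDispatchGetApi from the vendor
// shared library named by the BUILD rule above. Error handling is minimal,
// and LiteRtStatus/LiteRtDispatchApi are assumed to come from the dispatch
// API header shown in this patch.
#include <dlfcn.h>

#include <cstdio>

#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h"

using GetApiFn = LiteRtStatus (*)(LiteRtDispatchApi* api);

bool LoadQualcommDispatch(const char* lib_dir, LiteRtDispatchApi* api) {
  char path[512];
  std::snprintf(path, sizeof(path), "%s/libLiteRtDispatch_Qualcomm.so",
                lib_dir);
  void* lib = dlopen(path, RTLD_NOW | RTLD_LOCAL);  // keep symbols private
  if (lib == nullptr) return false;
  auto get_api =
      reinterpret_cast<GetApiFn>(dlsym(lib, "LiteRtDispatchGetApi"));
  if (get_api == nullptr) return false;
  return get_api(api) == kLiteRtStatusOk;  // fills version + interface tables
}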
- -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace { - -using ::litert::qnn::QnnManager; - -static std::unique_ptr TheQnnManager; - -QnnManager& Qnn() { return *TheQnnManager; } - -char BuildId[256]; - -// ///////////////////////////////////////////////////////////////////////////// -// Basic Execution API -// ///////////////////////////////////////////////////////////////////////////// - -const char* GetSharedLibraryDir(const LiteRtDispatchOption* options, - int num_options) { - for (auto i = 0; i < num_options; ++i) { - auto& option = options[i]; - if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { - return option.value.str_value; - } - } - return nullptr; -} - -LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) { - auto* shared_library_dir = GetSharedLibraryDir(options, num_options); - std::optional shared_library_dir_opt = - shared_library_dir ? std::make_optional(std::string(shared_library_dir)) - : std::nullopt; - - auto configs = QnnManager::DefaultBackendConfigs(); - if (auto qnn_manager = QnnManager::Create(configs, shared_library_dir_opt); - !qnn_manager) { - LITERT_LOG(LITERT_ERROR, "%s", qnn_manager.Error().Message().c_str()); - return qnn_manager.Error().Status(); - } else { - std::swap(TheQnnManager, *qnn_manager); - } - - Qnn_ApiVersion_t qnn_api_version; - if (auto status = Qnn().Api()->backendGetApiVersion(&qnn_api_version); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to get QNN API version: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - const char* build_id; - if (auto status = Qnn().Api()->backendGetBuildId(&build_id); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to get QNN build ID: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - snprintf(BuildId, sizeof(BuildId), - "Qualcomm Dispatch API version %d.%d.%d, QNN API version %d.%d.%d, " - "build id: %s", - LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR, - LITERT_API_VERSION_PATCH, qnn_api_version.coreApiVersion.major, - qnn_api_version.coreApiVersion.minor, - qnn_api_version.coreApiVersion.patch, build_id); - BuildId[sizeof(BuildId) - 1] = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus GetVendorId(const char** vendor_id) { - *vendor_id = "Qualcomm"; - return kLiteRtStatusOk; -} - -LiteRtStatus GetBuildId(const char** build_id) { - *build_id = BuildId; - return kLiteRtStatusOk; -} - -LiteRtStatus GetCapabilities(int* capabilities) { - *capabilities = kLiteRtDispatchCapabilitiesBasic; - return kLiteRtStatusOk; -} - -LiteRtStatus DeviceContextCreate(LiteRtDispatchDeviceContext* device_context) { - if (auto context = LiteRtDispatchDeviceContextT::Create(Qnn()); 
context) { - *device_context = context->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s", - context.Error().Message().c_str()); - return context.Error().Status(); - } -} - -LiteRtStatus DeviceContextDestroy(LiteRtDispatchDeviceContext device_context) { - delete device_context; - return kLiteRtStatusOk; -} - -LiteRtStatus GetInputRequirements( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto requirements = - invocation_context->GetInputRequirements(input_index, *tensor_type); - requirements) { - *tensor_buffer_requirements = *requirements; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", - requirements.Error().Message().c_str()); - return requirements.Error().Status(); - } -} - -LiteRtStatus GetOutputRequirements( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto requirements = - invocation_context->GetOutputRequirements(output_index, *tensor_type); - requirements) { - *tensor_buffer_requirements = *requirements; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", - requirements.Error().Message().c_str()); - return requirements.Error().Status(); - } -} - -LiteRtStatus RegisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, LiteRtTensorBuffer buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle) { - if (auto status = device_context->RegisterTensorBuffer(buffer); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to register buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } else { - *tensor_buffer_handle = *status; - return kLiteRtStatusOk; - } -} - -LiteRtStatus UnregisterTensorBuffer(LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle handle) { - if (auto status = device_context->UnregisterTensorBuffer(handle); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to unregister buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } else { - return kLiteRtStatusOk; - } -} - -LiteRtStatus InvocationContextCreate( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context) { - auto context = LiteRtDispatchInvocationContextT::Create( - Qnn(), *device_context, exec_bytecode_buffer, function_name); - if (!context) { - LITERT_LOG(LITERT_ERROR, "Failed to create context from context binary: %s", - context.Error().Message().c_str()); - return context.Error().Status(); - } - *invocation_context = context->release(); - device_context->SetInvocationContext(*invocation_context); - return kLiteRtStatusOk; -} - -LiteRtStatus InvocationContextDestroy( - LiteRtDispatchInvocationContext invocation_context) { - delete invocation_context; - return kLiteRtStatusOk; -} - -LiteRtStatus AttachInput(LiteRtDispatchInvocationContext invocation_context, - int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->AttachInput(graph_input_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to 
attach input buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus AttachOutput(LiteRtDispatchInvocationContext invocation_context, - int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->AttachOutput(graph_output_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to attach output buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus DetachInput(LiteRtDispatchInvocationContext invocation_context, - int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - // Nothing to do here. - return kLiteRtStatusOk; -} - -LiteRtStatus DetachOutput(LiteRtDispatchInvocationContext invocation_context, - int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - // Nothing to do here. - return kLiteRtStatusOk; -} - -LiteRtStatus Invoke(LiteRtDispatchInvocationContext invocation_context) { - if (auto status = invocation_context->Execute(); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to execute invocation context: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtDispatchInterface TheInterface = { - /*.initialize=*/Initialize, - /*.get_vendor_id=*/GetVendorId, - /*.get_build_id=*/GetBuildId, - /*.get_capabilities=*/GetCapabilities, - /*.device_context_create=*/DeviceContextCreate, - /*.device_context_destroy=*/DeviceContextDestroy, - /*.get_input_requirements=*/GetInputRequirements, - /*.get_output_requirements=*/GetOutputRequirements, - /*.register_tensor_buffer=*/RegisterTensorBuffer, - /*.unregister_tensor_buffer=*/UnregisterTensorBuffer, - /*.invocation_context_create=*/InvocationContextCreate, - /*.invocation_context_destroy=*/InvocationContextDestroy, - /*.attach_input=*/AttachInput, - /*.attach_output=*/AttachOutput, - /*.detach_input=*/DetachInput, - /*.detach_output=*/DetachOutput, - /*.invoke=*/Invoke, -}; - -LiteRtDispatchApi TheApi = { - /*.version=*/{/*.major=*/LITERT_API_VERSION_MAJOR, - /*.minor=*/LITERT_API_VERSION_MINOR, - /*.patch=*/LITERT_API_VERSION_PATCH}, - /*.interface=*/&TheInterface, - /*.async_interface=*/nullptr, - /*.graph_interface=*/nullptr, -}; - -} // namespace - -LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { - *api = TheApi; - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc deleted file mode 100644 index c1ae8d1c53d4e4..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc +++ /dev/null @@ -1,544 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
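The `TheInterface`/`TheApi` tables above follow a common plugin pattern: a static struct of function pointers plus an API version triple, handed out by value through one exported getter, so the vendor library exposes a stable, versionable C ABI. A minimal self-contained analogue of that pattern (all names here are illustrative, not from this patch):

// Versioned-vtable sketch: one exported symbol returns a struct whose
// function pointers the caller then dispatches through directly.
struct MiniInterface {
  int (*initialize)();
  int (*invoke)(int input);
};

struct MiniApi {
  struct { int major, minor, patch; } version;
  const MiniInterface* interface;
};

namespace {
int Initialize() { return 0; }
int Invoke(int input) { return input * 2; }

// Field order must match the struct; the /*.name=*/ comments used above in
// TheInterface document that correspondence.
const MiniInterface kInterface = {
    /*.initialize=*/Initialize,
    /*.invoke=*/Invoke,
};
}  // namespace

// The single exported entry point: callers copy the struct once, then never
// need another symbol lookup.
extern "C" int MiniGetApi(MiniApi* api) {
  *api = MiniApi{{1, 0, 0}, &kInterface};
  return 0;
}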
- -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using ::testing::Pointwise; - -TEST(Qualcomm, DispatchApiWithFastRpc) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a Qualcomm NPU"; -#endif - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*litert::ToLiteRtAny(std::any("/data/local/tmp")), - }; - ASSERT_EQ( - LiteRtDispatchInitialize(/*options=*/&dispatch_option, /*num_options=*/1), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." << api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kQualcommModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/"simple", - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. 
- // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/0, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/0, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/0, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} - -TEST(Qualcomm, DispatchApiWithDmaBuf) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a Qualcomm NPU"; -#endif - - EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." << api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kQualcommModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/"simple", - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. - // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/1, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/1, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/1, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - 
size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. 
- // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc deleted file mode 100644 index adf0ed86f80ca3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
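Both tests above repeat the same lock → `memcpy` → unlock sequence on tensor buffers. A small RAII guard makes the unlock impossible to skip on early returns; only `LiteRtLockTensorBuffer`/`LiteRtUnlockTensorBuffer` are real names from this patch, the `TensorBufferLock` wrapper itself is a hypothetical sketch.

// Illustrative RAII wrapper for the lock/copy/unlock pattern in the tests
// above. Assumes the two-argument lock signature used in this patch.
#include <cstring>

#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"

class TensorBufferLock {
 public:
  explicit TensorBufferLock(LiteRtTensorBuffer buffer) : buffer_(buffer) {
    locked_ =
        LiteRtLockTensorBuffer(buffer_, &host_mem_addr_) == kLiteRtStatusOk;
  }
  ~TensorBufferLock() {
    if (locked_) LiteRtUnlockTensorBuffer(buffer_);  // never leaks a lock
  }
  bool ok() const { return locked_; }
  void* data() const { return host_mem_addr_; }

 private:
  LiteRtTensorBuffer buffer_;
  void* host_mem_addr_ = nullptr;
  bool locked_ = false;
};

// Usage mirroring the "fill the input buffers" blocks above:
//   TensorBufferLock lock(input_0_tensor_buffer);
//   if (lock.ok())
//     std::memcpy(lock.data(), kTestInput0Tensor, sizeof(kTestInput0Tensor));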
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" - -#include -#include - -#include "absl/log/absl_check.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpMem.h" -#include "third_party/qairt/latest/include/QNN/QnnBackend.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnMem.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -using litert::Expected; -using litert::Unexpected; -using litert::qnn::QnnManager; - -Expected -LiteRtDispatchDeviceContextT::Create(QnnManager& qnn) { - return Ptr(new LiteRtDispatchDeviceContextT(qnn)); -} - -Expected LiteRtDispatchDeviceContextT::GetTensorBuffer( - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto registry_entry = tensor_buffer_registry_.Get(tensor_buffer_handle); - if (!registry_entry) { - return Unexpected(registry_entry.Error()); - } - - return (*registry_entry)->tensor_buffer; -} - -Expected LiteRtDispatchDeviceContextT::GetMemHandle( - LiteRtTensorBufferHandle tensor_buffer_handle, const Qnn_Tensor_t& tensor) { - auto registry_entry = tensor_buffer_registry_.Get(tensor_buffer_handle); - if (!registry_entry) { - return Unexpected(registry_entry.Error()); - } - - if (!(*registry_entry)->qnn_mem_handle) { - auto qnn_mem_handle = - RegisterTensorBuffer((*registry_entry)->tensor_buffer, tensor); - if (!qnn_mem_handle) { - return Unexpected(qnn_mem_handle.Error()); - } - (*registry_entry)->qnn_mem_handle = *qnn_mem_handle; - } - - return (*registry_entry)->qnn_mem_handle; -} - -Expected LiteRtDispatchDeviceContextT::RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer, const Qnn_Tensor_t& tensor) { - LiteRtTensorBufferType tensor_buffer_type; - if (auto status = - LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor buffer type"); - } - - size_t tensor_buffer_size; - if (auto status = - LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor buffer size"); - } - - size_t tensor_buffer_offset; - if (auto status = - LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor buffer offset"); - } - - LiteRtRankedTensorType tensor_type; - if (auto status = - LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor buffer's type"); - } - - auto element_type = - static_cast(tensor_type.element_type); - Qnn_DataType_t tensor_data_type; - if (auto status = LegalizeElementType(element_type, &tensor_data_type); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to legalize 
datatype"); - } - - uint32_t tensor_rank = tensor_type.layout.rank; - uint32_t* tensor_dimensions = reinterpret_cast( - const_cast(tensor_type.layout.dimensions)); - auto* tensor_strides = tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are not supported by QNN"); - } - - void* buffer_host_addr; - int buffer_fd; - (void)buffer_host_addr; - - switch (tensor_buffer_type) { - case kLiteRtTensorBufferTypeFastRpc: -#if LITERT_HAS_FASTRPC_SUPPORT - if (auto status = LiteRtGetTensorBufferFastRpcBuffer( - tensor_buffer, &buffer_host_addr, &buffer_fd); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get FastRPC buffer"); - } -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "FastRPC support is missing on this platform"); -#endif // LRT_HAS_FASTRPC_SUPPORT - break; - - case kLiteRtTensorBufferTypeDmaBuf: -#if LITERT_HAS_DMABUF_SUPPORT - if (auto status = LiteRtGetTensorBufferDmaBufBuffer( - tensor_buffer, &buffer_host_addr, &buffer_fd); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get DMA-BUF buffer"); - } -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "DmaBuf support is missing on this platform"); -#endif // LRT_HAS_DMABUF_SUPPORT - break; - - default: - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported tensor buffer type"); - } - - QnnMemHtp_Descriptor_t mem_htp_descriptor = {}; - mem_htp_descriptor.type = QNN_HTP_MEM_SHARED_BUFFER; - mem_htp_descriptor.size = tensor_buffer_size; - mem_htp_descriptor.sharedBufferConfig = - QnnHtpMem_SharedBufferConfig_t{buffer_fd, tensor_buffer_offset}; - - Qnn_MemDescriptor_t mem_descriptor = {}; - mem_descriptor.memShape = {tensor_rank, tensor_dimensions, nullptr}; - mem_descriptor.dataType = tensor_data_type; - mem_descriptor.memType = QNN_MEM_TYPE_CUSTOM; - mem_descriptor.customInfo = &mem_htp_descriptor; - - if (invocation_context_ == nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Missing invocation context"); - } - - Qnn_ContextHandle_t context_handle = invocation_context_->ContextHandle(); - - Qnn_MemHandle_t mem_handle = nullptr; - if (auto status = qnn_manager_.Api()->memRegister( - context_handle, &mem_descriptor, 1UL, &mem_handle); - status != QNN_SUCCESS) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to register tensor buffer"); - } - - if (!mem_handle) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to register buffer: null mem_handle"); - } - - return mem_handle; -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h deleted file mode 100644 index bd375c5137fcba..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
-
-#include "third_party/qairt/latest/include/QNN/QnnInterface.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h"
-
-class LiteRtDispatchDeviceContextT {
- public:
-  using Ptr = std::unique_ptr<LiteRtDispatchDeviceContextT>;
-
-  ~LiteRtDispatchDeviceContextT() = default;
-
-  static litert::Expected<Ptr> Create(litert::qnn::QnnManager& qnn_manager);
-
-  litert::Expected<LiteRtTensorBufferHandle> RegisterTensorBuffer(
-      LiteRtTensorBuffer tensor_buffer) {
-    return tensor_buffer_registry_.Register(
-        TensorBufferRegistryEntry(tensor_buffer));
-  }
-
-  litert::Expected<void> UnregisterTensorBuffer(
-      LiteRtTensorBufferHandle tensor_buffer_handle) {
-    return tensor_buffer_registry_.Unregister(tensor_buffer_handle);
-  }
-
-  litert::Expected<LiteRtTensorBuffer> GetTensorBuffer(
-      LiteRtTensorBufferHandle tensor_buffer_handle);
-
-  litert::Expected<Qnn_MemHandle_t> GetMemHandle(
-      LiteRtTensorBufferHandle tensor_buffer_handle,
-      const Qnn_Tensor_t& tensor);
-
-  void SetInvocationContext(
-      LiteRtDispatchInvocationContextT* invocation_context) {
-    invocation_context_ = invocation_context;
-  }
-
- private:
-  struct TensorBufferRegistryEntry {
-    LiteRtTensorBuffer tensor_buffer;
-    Qnn_MemHandle_t qnn_mem_handle = nullptr;
-    explicit TensorBufferRegistryEntry(LiteRtTensorBuffer tensor_buffer_)
-        : tensor_buffer(tensor_buffer_) {}
-  };
-
-  using TensorBufferRegistry =
-      litert::qnn::Registry<LiteRtTensorBufferHandle, TensorBufferRegistryEntry>;
-
-  LiteRtDispatchDeviceContextT(litert::qnn::QnnManager& qnn_manager)
-      : qnn_manager_(qnn_manager) {}
-
-  litert::Expected<Qnn_MemHandle_t> RegisterTensorBuffer(
-      LiteRtTensorBuffer tensor_buffer, const Qnn_Tensor_t& tensor);
-
-  litert::qnn::QnnManager& qnn_manager_;
-  TensorBufferRegistry tensor_buffer_registry_;
-  LiteRtDispatchInvocationContextT* invocation_context_ = nullptr;
-};
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc
deleted file mode 100644
index 6d05088cf06681..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc
+++ /dev/null
@@ -1,240 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
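Throughout this dispatch layer, C++ factories return `litert::Expected` of a `std::unique_ptr`, and the C entry points `release()` the pointer into an opaque handle whose matching `*Destroy` call deletes it (see `DeviceContextCreate`/`DeviceContextDestroy` earlier in this patch). A minimal sketch of that handoff, with the `Thing*` names being illustrative only:

// Ownership-handoff sketch: Expected<unique_ptr<T>> in C++, raw pointer
// across the C boundary, deletion in the paired Destroy call.
#include <memory>

#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"

struct Thing {
  using Ptr = std::unique_ptr<Thing>;
  static litert::Expected<Ptr> Create() { return Ptr(new Thing); }
};

LiteRtStatus ThingCreate(Thing** thing) {
  if (auto created = Thing::Create(); created) {
    *thing = created->release();  // caller now owns the raw pointer
    return kLiteRtStatusOk;
  } else {
    return created.Error().Status();
  }
}

LiteRtStatus ThingDestroy(Thing* thing) {
  delete thing;  // pairs with the release() in ThingCreate
  return kLiteRtStatusOk;
}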
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -using litert::Expected; -using litert::Unexpected; -using litert::qnn::QnnManager; - -LiteRtDispatchInvocationContextT::LiteRtDispatchInvocationContextT( - litert::qnn::QnnManager& qnn_manager, - const litert::qnn::ContextBinaryInfo& context_binary_info, - LiteRtDispatchDeviceContextT& device_context, - QnnManager::ContextHandle&& context_handle, - Qnn_ProfileHandle_t profile_handle, int graph_index, - Qnn_GraphHandle_t graph_handle) - : qnn_manager_(qnn_manager), - device_context_(device_context), - context_handle_(std::move(context_handle)), - profile_handle_(profile_handle), - graph_index_(graph_index), - graph_handle_(graph_handle), - inputs_(context_binary_info.Graphs()[graph_index].Inputs()), - outputs_(context_binary_info.Graphs()[graph_index].Outputs()) {} - -Expected -LiteRtDispatchInvocationContextT::Create( - QnnManager& qnn, LiteRtDispatchDeviceContextT& device_context, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name) { - auto exec_bytecode_ptr = - static_cast(exec_bytecode_buffer->base_addr) + - exec_bytecode_buffer->offset; - auto context_binary_info = litert::qnn::ContextBinaryInfo::Create( - qnn, exec_bytecode_ptr, exec_bytecode_buffer->size); - if (!context_binary_info) { - return Unexpected(context_binary_info.Error()); - } - - const auto& graphs = context_binary_info->Graphs(); - if (graphs.empty()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "No graph found"); - } - - int graph_index = -1; - // If the function_name is not specified and there is only one graph, then - // take that graph. 
- if (absl::string_view(function_name).empty() && graphs.size() == 1) { - graph_index = 0; - const auto& graph = graphs[graph_index]; - function_name = graph.Name().c_str(); - } else { - for (auto i = 0; i < graphs.size(); ++i) { - const auto& graph = graphs[i]; - if (graph.Name() == absl::string_view(function_name)) { - graph_index = i; - break; - } - } - } - if (graph_index < 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Function name not found"); - } - - auto configs = QnnManager::DefaultContextConfigs(); - Qnn_ProfileHandle_t profile_handle = nullptr; - auto context_handle = qnn.CreateContextHandle( - configs, - absl::MakeSpan(static_cast(exec_bytecode_ptr), - exec_bytecode_buffer->size), - profile_handle); - if (!context_handle) { - return Unexpected(context_handle.Error()); - } - - Qnn_GraphHandle_t graph_handle; - if (auto status = qnn.Api()->graphRetrieve(context_handle->get(), - function_name, &graph_handle); - status != QNN_SUCCESS) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to retrieve graph"); - } - - return Ptr(new LiteRtDispatchInvocationContextT( - qnn, std::move(*context_binary_info), device_context, - std::move(*context_handle), profile_handle, graph_index, graph_handle)); -} - -namespace { - -Expected GetTensorBufferRequirements( - const LiteRtRankedTensorType& tensor_type) { - auto* tensor_strides = tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are not supported by QNN"); - } - - static constexpr std::array - kSupportedTensorBufferTypes = { - kLiteRtTensorBufferTypeFastRpc, - kLiteRtTensorBufferTypeDmaBuf, - }; - - auto buffer_size = litert::internal::GetNumPackedBytes(tensor_type); - if (!buffer_size) { - return Unexpected(buffer_size.Error()); - } - - LiteRtTensorBufferRequirements requirements; - if (auto status = LiteRtCreateTensorBufferRequirements( - kSupportedTensorBufferTypes.size(), - kSupportedTensorBufferTypes.data(), *buffer_size, /*num_strides=*/0, - /*strides=*/nullptr, &requirements); - status != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Not implemented"); - } - - return requirements; -} - -} // namespace - -Expected -LiteRtDispatchInvocationContextT::GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type) { - return GetTensorBufferRequirements(tensor_type); -} - -Expected -LiteRtDispatchInvocationContextT::GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type) { - return GetTensorBufferRequirements(tensor_type); -} - -Expected LiteRtDispatchInvocationContextT::AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (graph_input_index < 0 || graph_input_index >= inputs_.size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Invalid graph_input_index"); - } - - auto& tensor = inputs_[graph_input_index]; - return AttachBuffer(tensor.Tensor(), tensor_buffer_handle); -} - -Expected LiteRtDispatchInvocationContextT::AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (graph_output_index < 0 || graph_output_index >= outputs_.size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Invalid graph_output_index"); - } - - auto& tensor = outputs_[graph_output_index]; - return AttachBuffer(tensor.Tensor(), tensor_buffer_handle); -} - -Expected LiteRtDispatchInvocationContextT::AttachBuffer( - Qnn_Tensor_t& tensor, LiteRtTensorBufferHandle 
tensor_buffer_handle) { - auto tensor_buffer = device_context_.GetTensorBuffer(tensor_buffer_handle); - if (!tensor_buffer) { - return Unexpected(tensor_buffer.Error()); - } - - auto mem_handle = device_context_.GetMemHandle(tensor_buffer_handle, tensor); - if (!mem_handle) { - return Unexpected(mem_handle.Error()); - } - - if (tensor.version == QNN_TENSOR_VERSION_1) { - tensor.v1.memType = QNN_TENSORMEMTYPE_MEMHANDLE; - tensor.v1.memHandle = *mem_handle; - - } else if (tensor.version == QNN_TENSOR_VERSION_2) { - if (tensor.v2.isDynamicDimensions != nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Dynamic dimensions not yet supported"); - } - tensor.v2.memType = QNN_TENSORMEMTYPE_MEMHANDLE; - tensor.v2.memHandle = *mem_handle; - - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported QNN tensor version"); - } - - return {}; -} - -Expected LiteRtDispatchInvocationContextT::Execute() { - const size_t num_ins = inputs_.size(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, inputs, num_ins, QNN_TENSOR_INIT); - for (size_t i = 0; i < num_ins; ++i) { - *(inputs + i) = inputs_.at(i).Tensor(); - } - - const size_t num_outs = outputs_.size(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, outputs, num_outs, QNN_TENSOR_INIT); - for (size_t i = 0; i < num_outs; ++i) { - *(outputs + i) = outputs_.at(i).Tensor(); - } - - if (auto status = qnn_manager_.Api()->graphExecute( - graph_handle_, inputs, num_ins, outputs, num_outs, - /*profileHandle=*/nullptr, /*signalHandle=*/nullptr); - status != QNN_SUCCESS) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to execute graph"); - } - - return {}; -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h deleted file mode 100644 index 17759238816302..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
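`GetTensorBufferRequirements` above advertises the supported buffer types in preference order (FastRPC first, then DMA-BUF), and the deleted tests simply pick them by `type_index`. A client that needs a specific type can scan the list instead; the helper below is a hypothetical sketch built only from the requirements accessors that appear in this patch.

// Illustrative negotiation helper over the requirements API used above.
#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"

LiteRtStatus PickBufferType(LiteRtTensorBufferRequirements requirements,
                            LiteRtTensorBufferType wanted,
                            LiteRtTensorBufferType* chosen) {
  int num_types = 0;
  if (auto status = LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes(
          requirements, &num_types);
      status != kLiteRtStatusOk) {
    return status;
  }
  for (int i = 0; i < num_types; ++i) {
    LiteRtTensorBufferType type;
    if (auto status =
            LiteRtGetTensorBufferRequirementsSupportedTensorBufferType(
                requirements, /*type_index=*/i, &type);
        status != kLiteRtStatusOk) {
      return status;
    }
    if (type == wanted) {  // first match wins; list is in preference order
      *chosen = type;
      return kLiteRtStatusOk;
    }
  }
  return kLiteRtStatusErrorNotFound;  // requested type not offered
}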
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
-
-#include <memory>
-
-#include "absl/strings/string_view.h"
-#include "third_party/qairt/latest/include/QNN/QnnInterface.h"
-#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
-#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
-#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h"
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h"
-
-class LiteRtDispatchDeviceContextT;
-
-class LiteRtDispatchInvocationContextT {
- public:
-  using Ptr = std::unique_ptr<LiteRtDispatchInvocationContextT>;
-
-  ~LiteRtDispatchInvocationContextT() = default;
-
-  static litert::Expected<Ptr> Create(
-      litert::qnn::QnnManager& qnn_manager,
-      LiteRtDispatchDeviceContextT& device_context,
-      const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name);
-
-  litert::Expected<LiteRtTensorBufferRequirements> GetInputRequirements(
-      int input_index, const LiteRtRankedTensorType& tensor_type);
-  litert::Expected<LiteRtTensorBufferRequirements> GetOutputRequirements(
-      int output_index, const LiteRtRankedTensorType& tensor_type);
-
-  litert::Expected<void> AttachInput(
-      int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle);
-
-  litert::Expected<void> AttachOutput(
-      int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle);
-
-  litert::Expected<void> Execute();
-
-  Qnn_ContextHandle_t ContextHandle() { return context_handle_.get(); }
-
- private:
-  LiteRtDispatchInvocationContextT(
-      litert::qnn::QnnManager& qnn_manager,
-      const litert::qnn::ContextBinaryInfo& context_binary_info,
-      LiteRtDispatchDeviceContextT& device_context,
-      litert::qnn::QnnManager::ContextHandle&& context_handle,
-      Qnn_ProfileHandle_t profile_handle, int graph_index,
-      Qnn_GraphHandle_t graph_handle);
-
-  litert::Expected<void> AttachBuffer(
-      Qnn_Tensor_t& tensor, LiteRtTensorBufferHandle tensor_buffer_handle);
-
-  litert::qnn::QnnManager& qnn_manager_;
-  LiteRtDispatchDeviceContextT& device_context_;
-  litert::qnn::QnnManager::ContextHandle context_handle_;
-  Qnn_ProfileHandle_t profile_handle_;
-  int graph_index_;
-  Qnn_GraphHandle_t graph_handle_;
-  std::vector<litert::qnn::QnnTensor> inputs_;
-  std::vector<litert::qnn::QnnTensor> outputs_;
-};
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h
deleted file mode 100644
index 8a80e342568e32..00000000000000
--- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2024 Google LLC.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_ - -#include - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::qnn { - -template -class Registry { - public: - Expected Register(const V& value) { - // TODO: improve this linear search by keeping an index to the first unused - // element. - for (auto i = 0; i < entries_.size(); ++i) { - auto& entry = entries_[i]; - if (!entry.used) { - entry.value = value; - entry.used = true; - return static_cast(i); - } - } - // Grow the set of entries. - H handle = static_cast(entries_.size()); - entries_.emplace_back(value); - return handle; - } - - Expected Unregister(H handle) { - if (handle < 0 || handle >= entries_.size()) { - return Unexpected(kLiteRtStatusErrorNotFound, "Unexpected handle"); - } - entries_[handle].used = false; - return {}; - } - - Expected Get(H handle) { - if (handle < 0 || handle >= entries_.size()) { - return Unexpected(kLiteRtStatusErrorNotFound, "Unexpected handle"); - } - return &entries_[handle].value; - } - - private: - struct Entry { - V value; - bool used; - explicit Entry(const V& v) : value(v), used(true) {} - }; - - std::vector entries_; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc deleted file mode 100644 index a0967992192570..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnLog.h" - -namespace litert::qnn { -namespace { - -void DefaultStdOutLogger(const char* fmt, QnnLog_Level_t level, - uint64_t timestamp, va_list argp) { - const char* levelStr = ""; - switch (level) { - case QNN_LOG_LEVEL_ERROR: - levelStr = " ERROR "; - break; - case QNN_LOG_LEVEL_WARN: - levelStr = "WARNING"; - break; - case QNN_LOG_LEVEL_INFO: - levelStr = " INFO "; - break; - case QNN_LOG_LEVEL_DEBUG: - levelStr = " DEBUG "; - break; - case QNN_LOG_LEVEL_VERBOSE: - levelStr = "VERBOSE"; - break; - case QNN_LOG_LEVEL_MAX: - levelStr = "UNKNOWN"; - break; - } - char buffer1[256]; - char buffer2[256]; - double ms = timestamp; - snprintf(buffer1, sizeof(buffer1), "%8.1fms [%-7s] ", ms, levelStr); - buffer1[sizeof(buffer1) - 1] = 0; - vsnprintf(buffer2, sizeof(buffer2), fmt, argp); - buffer2[sizeof(buffer1) - 2] = 0; - std::cout << buffer1 << buffer2; -} - -} // namespace - -QnnLog_Callback_t GetDefaultStdOutLogger() { return DefaultStdOutLogger; } - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h deleted file mode 100644 index 934a164b49f933..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ - -#include "third_party/qairt/latest/include/QNN/QnnLog.h" - -namespace litert::qnn { - -// Gets a default logger implementation to stdout. -// This is used when initializing qnn logging. -QnnLog_Callback_t GetDefaultStdOutLogger(); - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc deleted file mode 100644 index 0094d76cb6a340..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc +++ /dev/null @@ -1,411 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -#include - -#include -#include // NOLINT -#include -#include -#include - -#include "absl/strings/match.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpContext.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h" -#include "third_party/qairt/latest/include/QNN/QnnBackend.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnContext.h" -#include "third_party/qairt/latest/include/QNN/QnnDevice.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnLog.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemCommon.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h" - -namespace litert::qnn { - -namespace { - -constexpr char kLibQnnGetProvidersSymbol[] = "QnnInterface_getProviders"; - -constexpr char kLibQnnSystemGetProvidersSymbol[] = - "QnnSystemInterface_getProviders"; - -typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)( - const QnnInterface_t*** provider_list, uint32_t* num_providers); - -typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)( - const QnnSystemInterface_t***, uint32_t*); - -Expected> LoadProvidersFromLib( - SharedLibrary& lib) { - QnnInterfaceGetProvidersFn_t get_providers = nullptr; - LITERT_ASSIGN_OR_RETURN(get_providers, - lib.LookupSymbol( - kLibQnnGetProvidersSymbol)); - const QnnInterface_t** interface_providers = nullptr; - uint32_t num_providers = 0; - if (QNN_SUCCESS != get_providers(&interface_providers, &num_providers)) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to get providers"); - } - return absl::MakeSpan(interface_providers, num_providers); -} - -Expected> LoadSystemProvidersFromLib( - SharedLibrary& lib) { - LITERT_ASSIGN_OR_RETURN(QnnSystemInterfaceGetProvidersFn_t get_providers, - lib.LookupSymbol( - kLibQnnSystemGetProvidersSymbol)); - const QnnSystemInterface_t** interface_providers = nullptr; - uint32_t num_providers = 0; - if (QNN_SUCCESS != get_providers(&interface_providers, &num_providers)) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get system providers"); - } - return absl::MakeSpan(interface_providers, num_providers); -} - -} // namespace - -QnnManager::~QnnManager() { - (void)FreeDevice(); - (void)FreeBackend(); - (void)FreeLogging(); -} - -LiteRtStatus QnnManager::LoadLib(absl::string_view path) { - LITERT_LOG(LITERT_INFO, "Loading qnn shared library from \"%s\"", - path.data()); - LITERT_ASSIGN_OR_RETURN(lib_, - SharedLibrary::Load(path, RtldFlags::Default())); - LITERT_LOG(LITERT_INFO, "Loaded qnn shared library", ""); - 
return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::LoadSystemLib(absl::string_view path) { - LITERT_ASSIGN_OR_RETURN(lib_system_, - SharedLibrary::Load(path, RtldFlags::Default())); - return kLiteRtStatusOk; -} - -const QnnApi* QnnManager::Api() const { - if (interface_ == nullptr) { - return nullptr; - } - return &interface_->QNN_INTERFACE_VER_NAME; -} - -LiteRtStatus QnnManager::ResolveApi() { - if (!lib_.Loaded()) { - LITERT_LOG(LITERT_ERROR, "%s", - "Cannot resolve functions: libQnn*.so has not been loaded.\n"); - return kLiteRtStatusErrorDynamicLoading; - } - - LITERT_ASSIGN_OR_RETURN(auto providers, LoadProvidersFromLib(lib_)); - for (const auto& prov : providers) { - const bool major = - prov->apiVersion.coreApiVersion.major == QNN_API_VERSION_MAJOR; - - const bool minor = - prov->apiVersion.coreApiVersion.minor == QNN_API_VERSION_MINOR; - - const bool patch = - prov->apiVersion.coreApiVersion.patch == QNN_API_VERSION_PATCH; - - if (major && minor && patch) { - interface_ = prov; - break; - } - } - - if (interface_ == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", "No valid interface was provided\n"); - return kLiteRtStatusErrorDynamicLoading; - } - - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::ResolveSystemApi() { - if (!lib_system_.Loaded()) { - LITERT_LOG(LITERT_ERROR, "%s", - "Cannot resolve functions: libQnnSystem.so has not been loaded.\n"); - return kLiteRtStatusErrorDynamicLoading; - } - - LITERT_ASSIGN_OR_RETURN(auto system_providers, - LoadSystemProvidersFromLib(lib_system_)); - for (const auto& system_prov : system_providers) { - const bool major = - system_prov->systemApiVersion.major == QNN_SYSTEM_API_VERSION_MAJOR; - - const bool minor = - system_prov->systemApiVersion.minor == QNN_SYSTEM_API_VERSION_MINOR; - - const bool patch = - system_prov->systemApiVersion.patch == QNN_SYSTEM_API_VERSION_PATCH; - - if (major && minor && patch) { - system_interface_ = system_prov; - break; - } - } - - if (system_interface_ == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", "No valid system interface was provided\n"); - return kLiteRtStatusErrorDynamicLoading; - } - - return kLiteRtStatusOk; -} - -const QnnSystemApi* QnnManager::SystemApi() const { - if (system_interface_ == nullptr) { - return nullptr; - } - return &system_interface_->QNN_SYSTEM_INTERFACE_VER_NAME; -} - -LiteRtStatus QnnManager::FreeLogging() { - if (log_handle_ != nullptr) { - if (QNN_SUCCESS != Api()->logFree(log_handle_)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to free logging\n"); - return kLiteRtStatusErrorNotFound; - } - } - log_handle_ = nullptr; - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::FreeBackend() { - if (backend_handle_ != nullptr) { - if (QNN_SUCCESS != Api()->backendFree(backend_handle_)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to free backend\n"); - return kLiteRtStatusErrorNotFound; - } - } - backend_handle_ = nullptr; - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::FreeDevice() { - if (device_handle_ != nullptr) { - if (QNN_SUCCESS != Api()->deviceFree(device_handle_)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to free device\n"); - return kLiteRtStatusErrorNotFound; - } - } - device_handle_ = nullptr; - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::GenerateContextBinary( - Qnn_ContextHandle_t context_handle, std::vector& buffer) { - Qnn_ContextBinarySize_t bin_size = 0; - if (QNN_SUCCESS != Api()->contextGetBinarySize(context_handle, &bin_size)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to get context bin size\n"); - return kLiteRtStatusErrorNotFound; - } -
buffer.clear(); - buffer.resize(bin_size); - - Qnn_ContextBinarySize_t written_bin_size = 0; - if (QNN_SUCCESS != Api()->contextGetBinary(context_handle, buffer.data(), - buffer.size(), - &written_bin_size)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to generate context binary\n"); - return kLiteRtStatusErrorNotFound; - } - - LITERT_LOG(LITERT_INFO, "Serialized a context bin of size (bytes): %lu\n", - written_bin_size); - - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::ValidateOp(const Qnn_OpConfig_t& op_config) { - if (Qnn_ErrorHandle_t error = - Api()->backendValidateOpConfig(BackendHandle(), op_config); - QNN_SUCCESS != error) { - LITERT_LOG(LITERT_ERROR, "Failed to validate op %s\n, error: %lld", - op_config.v1.name, static_cast(error)); - return kLiteRtStatusErrorInvalidLegalization; - } - - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::Init(absl::Span configs, - std::optional shared_library_dir, - std::optional soc_model) { - // If shared_library_dir is provided, add it to the path as it may contain - // libs to be loaded. - // TODO: This should probably be done upstream in litert_dispatch. - if (shared_library_dir) { - LITERT_LOG(LITERT_INFO, "Adding shared library dir to path: %s", - shared_library_dir->c_str()); - - static constexpr char kAdsp[] = "ADSP_LIBRARY_PATH"; - if (getenv(kAdsp) == nullptr) { - setenv(kAdsp, shared_library_dir->data(), /*overwrite=*/1); - } - - // TODO: Put dynamic loading module in cc or vendor/cc. - litert::internal::PutLibOnLdPath(shared_library_dir->data(), kLibQnnHtpSo); - } - - LITERT_RETURN_IF_ERROR(LoadLib(kLibQnnHtpSo)); - LITERT_RETURN_IF_ERROR(ResolveApi()); - - LITERT_RETURN_IF_ERROR(LoadSystemLib(kLibQnnSystemSo)); - LITERT_RETURN_IF_ERROR(ResolveSystemApi()); - - if (auto status = Api()->logCreate(GetDefaultStdOutLogger(), - QNN_LOG_LEVEL_INFO, &LogHandle()); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN logger: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - if (auto status = - Api()->backendCreate(LogHandle(), configs.data(), &BackendHandle()); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN backend: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - if (soc_model.has_value()) { - soc_model_ = *soc_model; - LITERT_LOG(LITERT_INFO, - "Initializing QNN backend for device architecture %d", - *soc_model); - QnnHtpDevice_CustomConfig_t arch_custom_config = {}; - arch_custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_ARCH; - arch_custom_config.arch.arch = *soc_model; - arch_custom_config.arch.deviceId = 0; - - QnnDevice_Config_t arch_device_config = {}; - arch_device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; - arch_device_config.customConfig = &arch_custom_config; - - const QnnDevice_Config_t* device_configs[2] = { - &arch_device_config, - nullptr, - }; - - if (auto status = - Api()->deviceCreate(nullptr, device_configs, &DeviceHandle()); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN device: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - } - - return kLiteRtStatusOk; -} - -Expected -QnnManager::CreateSystemContextHandle() { - QnnSystemContext_Handle_t system_context_handle; - if (auto status = SystemApi()->systemContextCreate(&system_context_handle); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN system context: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create QNN system context"); - } - auto deleter = 
SystemApi()->systemContextFree; - return SystemContextHandle{system_context_handle, deleter}; -} - -Expected QnnManager::CreateContextHandle( - absl::Span configs) { - Qnn_ContextHandle_t context_handle; - if (auto status = Api()->contextCreate(BackendHandle(), DeviceHandle(), - configs.data(), &context_handle); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN context: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create QNN context"); - } - auto deleter = Api()->contextFree; - return ContextHandle{context_handle, /*profile=*/nullptr, deleter}; -} - -Expected QnnManager::CreateContextHandle( - absl::Span configs, - absl::Span bytecode, Qnn_ProfileHandle_t profile_handle) { - Qnn_ContextHandle_t context_handle; - if (auto status = Api()->contextCreateFromBinary( - BackendHandle(), DeviceHandle(), configs.data(), bytecode.data(), - bytecode.size(), &context_handle, profile_handle); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN context: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create QNN context"); - } - auto deleter = Api()->contextFree; - return ContextHandle{context_handle, profile_handle, deleter}; -} - -Expected QnnManager::Create( - absl::Span configs, - std::optional shared_library_dir, - std::optional soc_model) { - Ptr qnn_manager(new QnnManager); - if (auto status = qnn_manager->Init(configs, shared_library_dir, soc_model); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to set up QNN manager"); - } - return qnn_manager; -} - -absl::Span QnnManager::DefaultBackendConfigs() { - static const QnnBackend_Config_t* configs[] = {nullptr}; - return absl::MakeSpan(configs); -} - -absl::Span QnnManager::DefaultContextConfigs() { - static const QnnContext_Config_t* configs[] = {nullptr}; - return absl::MakeSpan(configs); -} - -absl::Span -QnnManager::WeightSharingContextConfigs() { - static QnnHtpContext_CustomConfig_t customConfig = - QNN_HTP_CONTEXT_CUSTOM_CONFIG_INIT; - customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; - customConfig.weightSharingEnabled = true; - static QnnContext_Config_t contextConfig = QNN_CONTEXT_CONFIG_INIT; - contextConfig.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; - contextConfig.customConfig = &customConfig; - static const QnnContext_Config_t* configs[2] = {&contextConfig, nullptr}; - return absl::MakeSpan(configs); -} - -}; // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h deleted file mode 100644 index 30d00ab7169706..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h" -#include "third_party/qairt/latest/include/QNN/QnnBackend.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnContext.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" - -//===----------------------------------------------------------------------===// -// -// QnnManager -// -// Syntactic sugar for various Qnn Sdk routines. -// -// Provides various utilities for linking shared libraries at runtime -// against Qnn symbols as well as convenience getters and storage of handles -// (pointers). Provides simple wrappers for freeing handles and returning -// LiteRtStatus rather than Qnn ones. Additionally exposes hooks for dumping -// api and shared library details. -// -// Does not own any memory and will always have trivial cstor/dstor. The -// user is responsible for freeing any Qnn handles explicitly. Note, -// Qnn handles will be automatically freed when the library is unloaded -// if they have not been freed already. -// -//===----------------------------------------------------------------------===// - -namespace litert::qnn { - -class QnnManager; - -namespace internal { - -void Dump(const QnnManager& qnn, std::ostream& out); - -} // namespace internal - -class QnnManager { - friend void internal::Dump(const QnnManager& qnn, std::ostream& out); - - public: - using Ptr = std::unique_ptr; - using SystemContextHandle = - std::unique_ptr::type, - QnnSystemContext_FreeFn_t>; - class ContextHandle; - - ~QnnManager(); - - static Expected Create( - absl::Span configs, - std::optional shared_library_dir = std::nullopt, - std::optional soc_model = std::nullopt); - - static absl::Span DefaultBackendConfigs(); - static absl::Span DefaultContextConfigs(); - static absl::Span WeightSharingContextConfigs(); - - // Get resolved function pointers for qnn sdk calls. Nullptr if functions - // have not been resolved yet. - const QnnApi* Api() const; - - // Get resolved function pointers for qnn sdk calls. Nullptr if functions - // have not been resolved yet. - const QnnSystemApi* SystemApi() const; - - // - // QNN SDK Objects. - // - - // Create system context handle. - Expected CreateSystemContextHandle(); - - // Create a context handle for compilation. - Expected CreateContextHandle( - absl::Span configs); - - // Create a context handle for inference, from a given bytecode. - Expected CreateContextHandle( - absl::Span configs, - absl::Span bytecode, Qnn_ProfileHandle_t profile_handle); - - // - // Context Binary - // - - // Generates QNN context binary from current context.
Writes to given - // buffer. - LiteRtStatus GenerateContextBinary(Qnn_ContextHandle_t context_handle, - std::vector& buffer); - - LiteRtStatus ValidateOp(const Qnn_OpConfig_t& op_config); - - bool IsLegacySocModel() { return soc_model_ == QNN_HTP_DEVICE_ARCH_V68; } - - private: - QnnManager() = default; - - LiteRtStatus Init(absl::Span configs, - std::optional shared_library_dir, - std::optional soc_model); - - // - // Manage libQnn*.so Loading - // - - // Loads the libQnn*.so at given path. - LiteRtStatus LoadLib(absl::string_view path); - - // Loads the libQnnSystem.so at given path. - LiteRtStatus LoadSystemLib(absl::string_view path); - - // - // Resolve and Access QNN SDK Functions - // - - // Resolve all available QNN SDK functions from (already) loaded so. If - // multiple providers are found, selects the first one with a suitable - // version. Fails if none can be found. - LiteRtStatus ResolveApi(); - - // Resolve all available QNN SDK functions from (already) loaded so. If - // multiple providers are found, selects the first one with a suitable - // version. Fails if none can be found. - LiteRtStatus ResolveSystemApi(); - - // Get qnn log handle. Nullptr if logCreate has not been successfully called. - Qnn_LogHandle_t& LogHandle() { return log_handle_; } - - // Get qnn backend handle. Nullptr if backendCreate has not been successfully - // called. - Qnn_BackendHandle_t& BackendHandle() { return backend_handle_; } - - // Get qnn device handle. Nullptr if deviceCreate has not been successfully - // called. - Qnn_DeviceHandle_t& DeviceHandle() { return device_handle_; } - - // Signal QNN SDK to free any memory related to the device. Does nothing - // if deviceCreate has not been called. - LiteRtStatus FreeDevice(); - - // Signal QNN SDK to free any memory related to logging. Does nothing - // if logCreate has not been called. - LiteRtStatus FreeLogging(); - - // Signal QNN SDK to free any memory related to backend. Does nothing - // if backendCreate has not been called. - LiteRtStatus FreeBackend(); - - // Handle to the shared library that implements the API. The library is - // released when the manager is destroyed. - SharedLibrary lib_; - - // Handle to the system shared library that implements the API. The library is - // released when the manager is destroyed. - SharedLibrary lib_system_; - - const QnnInterface_t* interface_ = nullptr; - const QnnSystemInterface_t* system_interface_ = nullptr; - - Qnn_LogHandle_t log_handle_ = nullptr; - Qnn_BackendHandle_t backend_handle_ = nullptr; - Qnn_DeviceHandle_t device_handle_ = nullptr; - QnnHtpDevice_Arch_t soc_model_ = QNN_HTP_DEVICE_ARCH_UNKNOWN; -}; - -// Unfortunately we can't use std::unique_ptr with a deleter because -// QnnContext_FreeFn_t takes a profile handle as a second argument. 
-class QnnManager::ContextHandle { - public: - ContextHandle(Qnn_ContextHandle_t context_handle, Qnn_ProfileHandle_t profile, - QnnContext_FreeFn_t free_fn) - : context_handle_(context_handle), profile_(profile), free_fn_(free_fn) {} - - ~ContextHandle() { - if (context_handle_ && free_fn_) { - free_fn_(context_handle_, profile_); - } - } - - ContextHandle(ContextHandle&& other) { *this = std::move(other); } - - ContextHandle(const ContextHandle& other) = delete; - - ContextHandle& operator=(ContextHandle&& other) { - std::swap(context_handle_, other.context_handle_); - std::swap(profile_, other.profile_); - std::swap(free_fn_, other.free_fn_); - return *this; - } - - ContextHandle& operator=(const ContextHandle& other) = delete; - - Qnn_ContextHandle_t get() const noexcept { return context_handle_; } - explicit operator bool() const noexcept { return context_handle_ != nullptr; } - - private: - Qnn_ContextHandle_t context_handle_ = nullptr; - Qnn_ProfileHandle_t profile_ = nullptr; - QnnContext_FreeFn_t free_fn_ = nullptr; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc deleted file mode 100644 index 742af4f508dd64..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -#include - -#include -#include -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h" - -namespace { - -using ::litert::qnn::QnnManager; -using ::litert::qnn::internal::Dump; -using ::testing::HasSubstr; - -// NOTE: This tests that all of the dynamic loading works properly and -// the QNN SDK instance can be properly initialized and destroyed. - -TEST(QnnManagerTest, SetupQnnManager) { - auto configs = QnnManager::DefaultBackendConfigs(); - auto qnn = QnnManager::Create(configs); - ASSERT_TRUE(qnn); -} - -TEST(QnnManagerTest, Dump) { - auto configs = QnnManager::DefaultBackendConfigs(); - auto qnn = QnnManager::Create(configs); - ASSERT_TRUE(qnn); - - std::ostringstream dump; - Dump(**qnn, dump); - - EXPECT_THAT(dump.str(), HasSubstr("< QnnInterface_t >")); - EXPECT_THAT(dump.str(), HasSubstr("< QnnSystemInterface_t >")); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc deleted file mode 100644 index 557a5d2f9ed56c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h" - -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { -namespace qnn { - -QnnTensor::QnnTensor(const QnnTensor& other) : QnnTensor(other.Tensor()) { - auto status = DeepCopy(); - // This should never fail because the input QnnTensor was already deep-copied. - if (!status) { - LITERT_LOG(LITERT_ERROR, "Failed to build QnnTensor: %s", - status.Error().Message().c_str()); - ABSL_CHECK(status); - } -} - -QnnTensor::QnnTensor(QnnTensor&& other) { - tensor_ = other.tensor_; - // Swap managed memory. - std::swap(name_, other.name_); - std::swap(dimensions_, other.dimensions_); - std::swap(is_dynamic_dimensions_, other.is_dynamic_dimensions_); -} - -Expected QnnTensor::Create(const Qnn_Tensor_t& tensor) { - QnnTensor qnn_tensor(tensor); - if (auto status = qnn_tensor.DeepCopy(); !status) { - return Unexpected(status.Error()); - } - return qnn_tensor; -} - -Expected QnnTensor::DeepCopy() { - if (tensor_.version == QNN_TENSOR_VERSION_1) { - dimensions_.reserve(tensor_.v1.rank); - std::copy(tensor_.v1.dimensions, tensor_.v1.dimensions + tensor_.v1.rank, - std::back_inserter(dimensions_)); - tensor_.v1.dimensions = dimensions_.data(); - - // FIXME: Implement deep copy for quantizeParams. - if (tensor_.v1.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION || - tensor_.v1.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_VECTOR) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported QNN quantization"); - } - - } else if (tensor_.version == QNN_TENSOR_VERSION_2) { - dimensions_.reserve(tensor_.v2.rank); - std::copy(tensor_.v2.dimensions, tensor_.v2.dimensions + tensor_.v2.rank, - std::back_inserter(dimensions_)); - tensor_.v2.dimensions = dimensions_.data(); - - if (tensor_.v2.isDynamicDimensions) { - is_dynamic_dimensions_.reserve(tensor_.v2.rank); - std::copy(tensor_.v2.isDynamicDimensions, - tensor_.v2.isDynamicDimensions + tensor_.v2.rank, - std::back_inserter(is_dynamic_dimensions_)); - tensor_.v2.isDynamicDimensions = is_dynamic_dimensions_.data(); - } - - // FIXME: Implement deep copy for quantizeParams. 
- if (tensor_.v2.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION || - tensor_.v2.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_VECTOR) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported QNN quantization"); - } - - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported QNN tensor version"); - } - - return {}; -} - -} // namespace qnn -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h deleted file mode 100644 index c0429ce01864e5..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_ - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::qnn { - -class QnnTensor { - public: - static Expected Create(const Qnn_Tensor_t& tensor); - - QnnTensor(const QnnTensor& other); - QnnTensor(QnnTensor&& other); - - QnnTensor& operator=(const QnnTensor&) = delete; - QnnTensor& operator=(QnnTensor&&) = delete; - - Qnn_Tensor_t& Tensor() { return tensor_; } - const Qnn_Tensor_t& Tensor() const { return tensor_; } - - size_t Rank() const { return dimensions_.size(); } - const uint32_t* Dimensions() const { return dimensions_.data(); } - - private: - explicit QnnTensor(const Qnn_Tensor_t& tensor) : tensor_(tensor) {} - Expected DeepCopy(); - - Qnn_Tensor_t tensor_; - std::string name_; - std::vector dimensions_; - std::vector is_dynamic_dimensions_; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl b/tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl deleted file mode 100644 index d4c9c70db3674e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Build definitions for QualComm backend.""" - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "append_rule_kwargs", "litert_bin", "litert_lib", "make_rpaths") - -_QNN_LIBCC_X86_64 = [ - # copybara:uncomment_begin(google-only) - # "//third_party/qairt/latest:lib/x86_64-linux-clang/libc++.so.1", - # "//third_party/qairt/latest:lib/x86_64-linux-clang/libc++abi.so.1", - # copybara:uncomment_end -] # @unused - -# TODO: Make rpaths dynamic with "$(location {})". -_QNN_LIB_RPATHS_X86_64 = [ - # copybara:uncomment_begin(google-only) - # "third_party/qairt/latest/lib/x86_64-linux-clang", - # copybara:uncomment_end -] - -_QNN_LIB_HTP_X86_64 = [ - # copybara:uncomment_begin(google-only) - # "//third_party/qairt/latest:lib/x86_64-linux-clang/libQnnHtp.so", - # copybara:uncomment_end -] - -_QNN_LIB_SYSTEM_X86_64 = [ - # copybara:uncomment_begin(google-only) - # "//third_party/qairt/latest:lib/x86_64-linux-clang/libQnnSystem.so", - # copybara:uncomment_end -] - -def _litert_with_qnn_base( - litert_rule, - backend, - include_system, - use_custom_libcc, - **litert_rule_kwargs): - if backend != "htp": - fail("Only htp currently supported") - - if use_custom_libcc: - # TODO: Figure out strategy for custom libcc. - fail("Custom libcc not yet supported") - - data_x86_64 = [] - data_x86_64.extend(_QNN_LIB_HTP_X86_64) - if include_system: - data_x86_64.extend(_QNN_LIB_SYSTEM_X86_64) - data = select({ - "//tensorflow:linux_x86_64": data_x86_64, - "//conditions:default": [], - }) - - append_rule_kwargs( - litert_rule_kwargs, - data = data, - linkopts = select({ - "//tensorflow:linux_x86_64": [make_rpaths(_QNN_LIB_RPATHS_X86_64)], - "//conditions:default": [], - }), - ) - - litert_rule(**litert_rule_kwargs) - -def litert_cc_lib_with_qnn( - backend = "htp", - include_system = False, - use_custom_libcc = False, - **litert_lib_kwargs): - """Creates a litert_lib target with QualComm backend dependencies. - - Args: - backend: The backend to use. Currently only "htp" is supported. - include_system: Whether to include libQnnSystem.so. - use_custom_libcc: Whether to use a custom libcc. Not yet supported. - **litert_lib_kwargs: Keyword arguments passed to litert_lib. - """ - _litert_with_qnn_base( - litert_lib, - backend, - include_system, - use_custom_libcc, - **litert_lib_kwargs - ) - -def litert_cc_bin_with_qnn( - backend = "htp", - include_system = False, - use_custom_libcc = False, - **litert_bin_kwargs): - """Creates a litert_bin target with QualComm backend dependencies. - - Args: - backend: The backend to use. Currently only "htp" is supported. - include_system: Whether to include libQnnSystem.so. - use_custom_libcc: Whether to use a custom libcc. Not yet supported. - **litert_bin_kwargs: Keyword arguments passed to litert_bin. 
- """ - _litert_with_qnn_base( - litert_bin, - backend, - include_system, - use_custom_libcc, - **litert_bin_kwargs - ) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/supported_soc.csv b/tensorflow/lite/experimental/litert/vendors/qualcomm/supported_soc.csv deleted file mode 100644 index 52b7f881f47a3e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/supported_soc.csv +++ /dev/null @@ -1,41 +0,0 @@ -# manufacturer,model,runtime_library_version,soc_model -Qualcomm,SM8750,v79,69 -Qualcomm,SM8650,v75,57 -Qualcomm,SM8635,v73,68 -Qualcomm,SM8550,v73,43 -Qualcomm,SM7675,v73,70 -Qualcomm,SM7550,v73, -Qualcomm,SM7435,v73, -Qualcomm,SM6450,v73,50 -Qualcomm,QCM8550LA,v73,66 -Qualcomm,QCM8550LE,v73,66 -Qualcomm,SM8475,v69,42 -Qualcomm,SM8450,v69,36 -Qualcomm,SM7475,v69,54 -Qualcomm,SM7450,v69,41 -Qualcomm,SM7425,v69, -Qualcomm,SXR2230P,v69,53 -Qualcomm,SXR2250P,v69, -Qualcomm,SM8350,v68,30 -Qualcomm,SM8350P,v68,30 -Qualcomm,SM7350,v68,32 -Qualcomm,SM7325,v68,35 -Qualcomm,SM7315,v68,38 -Qualcomm,QCM6490,v68,35 -Qualcomm,SM8250,v66,21 -Qualcomm,SM8150,v66, -Qualcomm,SM7250,v66,25 -Qualcomm,SM7225,v66,29 -Qualcomm,SM7125,v66, -Qualcomm,SM6350,v66,29 -Qualcomm,SM6225,v66,40 -Qualcomm,SM6150,v66, -Qualcomm,SM6125,v66, -Qualcomm,SM4350,v66,31 -Qualcomm,QRB5165U,v66,21 -Qualcomm,QRB5165LE,v66,21 -Qualcomm,QCS7230LA,v66,51 -Qualcomm,QCS7230LE,v66,51 -Qualcomm,SM6375,v66,31 -Qualcomm,SM7150,v65, -Qualcomm,SDM845,v65, diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD deleted file mode 100644 index 45df0fef3b5a21..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "dump", - srcs = ["dump.cc"], - hdrs = ["dump.h"], - tags = ["nobuilder"], - deps = [ - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager_hdr", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc deleted file mode 100644 index 0e94b6b0385890..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn::internal { -namespace { - -static constexpr absl::string_view kNullDumpTpl = "%s : nullptr\n"; - -void Dump(const QnnInterface_t* interface, std::ostream& out) { - static constexpr absl::string_view kQnnInterfaceHeader = "< QnnInterface_t >"; - // NOLINTBEGIN - static constexpr absl::string_view kQnnInterfaceDumpTpl = - "\ - %s\n\ - name: %s\n\ - backend_id: %u\n\ - core_api_version: %u.%u.%u\n\ - backend_api_version: %u.%u.%u\n"; - // NOLINTEND - - if (interface == nullptr) { - out << absl::StreamFormat(kNullDumpTpl, kQnnInterfaceHeader); - return; - } - - const auto core_version = interface->apiVersion.coreApiVersion; - const auto backend_version = interface->apiVersion.backendApiVersion; - - out << absl::StreamFormat(kQnnInterfaceDumpTpl, kQnnInterfaceHeader, - interface->providerName, interface->backendId, - core_version.major, core_version.minor, - core_version.patch, backend_version.major, - backend_version.minor, backend_version.patch); -} - -void Dump(const QnnSystemInterface_t* interface, std::ostream& out) { - static constexpr absl::string_view kQnnSystemInterfaceHeader = - "< QnnSystemInterface_t >"; - // NOLINTBEGIN - static constexpr absl::string_view kQnnSystemInterfaceDumpTpl = - "\ - %s\n\ - name: %s\n\ - backend_id: %u\n\ - system_api_version: %u.%u.%u\n"; - // NOLINTEND - - if (interface == nullptr) { - out << absl::StreamFormat(kNullDumpTpl, kQnnSystemInterfaceHeader); - return; - } - - const auto system_version = interface->systemApiVersion; - - out << absl::StreamFormat(kQnnSystemInterfaceDumpTpl, - kQnnSystemInterfaceHeader, interface->providerName, - interface->backendId, system_version.major, - system_version.minor, system_version.patch); -} - -} // namespace - -void Dump(const QnnManager& qnn, std::ostream& out) { - Dump(qnn.interface_, out); - Dump(qnn.system_interface_, out); -} -} // namespace litert::qnn::internal diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h deleted file mode 100644 index b64650249af0af..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn::internal { - -void Dump(const QnnManager& qnn, std::ostream& out = std::cerr); - -} - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_ diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index da69497a97f2b8..8e37b86fdf2f9c 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -109,7 +109,6 @@ tf_staging/tensorflow/lite/experimental/acceleration/configuration/configuration tf_staging/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h: tf_staging/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg.h: tf_staging/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl: -tf_staging/tensorflow/lite/experimental/litert/build_common/special_rule.bzl: tf_staging/tensorflow/lite/interpreter.h: tf_staging/tensorflow/lite/interpreter_builder.h: tf_staging/tensorflow/lite/ios/BUILD: From 3b1f42ab1de955491eb3819c575fb1aeaf2546c9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Apr 2025 18:28:12 -0700 Subject: [PATCH 0480/1324] Refactor `cuda`, `cudnn`, `nccl` repository rules to create CUDA repositories in parallel. Bazel processes the labels in the `load` statements at the beginning of the `cuda_configure.bzl` file in parallel, so the CUDA repositories are created in parallel as well. Previously, the CUDA repositories were created sequentially, in the order of the `repository_ctx.read()` operations inside the `cuda_configure` rule implementation. Bazel creates repositories only when they are directly used, and the list of required CUDA repositories was not known until the `cuda_configure` repository rule was executed.
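To illustrate the pattern, here is a minimal sketch with a single repository (the real rule wires up the full set of CUDA redistributions); the version lookup moves from an imperative read inside the rule implementation to a top-level `load` statement:

    # Before: each repository_ctx.read() forces the referenced repository
    # to be fetched at the moment the rule implementation executes, so the
    # repositories are materialized one by one.
    def _cuda_configure_impl(repository_ctx):
        cudart_version = repository_ctx.read(repository_ctx.attr.cudart_version)

    # After: the labels in top-level load() statements are known before the
    # rule runs, so Bazel can fetch all referenced repositories in parallel.
    load("@cuda_cudart//:version.bzl", _cudart_version = "VERSION")

    def _cuda_configure_impl(repository_ctx):
        cudart_version = _cudart_version

To make this possible, each redistribution repository now writes a `version.bzl` file (see `create_version_file` in the diff below) so that its version can be consumed through `load` instead of `repository_ctx.read()`.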
PiperOrigin-RevId: 745809642 --- .../gpus/cuda/hermetic/cuda_cccl.BUILD.tpl | 4 -- .../gpus/cuda/hermetic/cuda_configure.bzl | 50 +++++++++++-------- .../gpus/cuda/hermetic/cuda_cublas.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_cudart.BUILD.tpl | 4 -- .../gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_cufft.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_cupti.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_curand.BUILD.tpl | 3 -- .../cuda/hermetic/cuda_cusolver.BUILD.tpl | 3 -- .../cuda/hermetic/cuda_cusparse.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl | 1 - .../cuda/hermetic/cuda_nvjitlink.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_nvml.BUILD.tpl | 4 -- .../gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl | 4 -- .../cuda_redist_init_repositories.bzl | 20 +++++--- third_party/nccl/hermetic/nccl_configure.bzl | 6 +-- .../hermetic/nccl_redist_init_repository.bzl | 5 +- .../gpus/cuda/hermetic/cuda_cccl.BUILD.tpl | 4 -- .../gpus/cuda/hermetic/cuda_configure.bzl | 50 +++++++++++-------- .../gpus/cuda/hermetic/cuda_cublas.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_cudart.BUILD.tpl | 4 -- .../gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_cufft.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_cupti.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_curand.BUILD.tpl | 3 -- .../cuda/hermetic/cuda_cusolver.BUILD.tpl | 3 -- .../cuda/hermetic/cuda_cusparse.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl | 1 - .../cuda/hermetic/cuda_nvjitlink.BUILD.tpl | 3 -- .../gpus/cuda/hermetic/cuda_nvml.BUILD.tpl | 4 -- .../gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl | 4 -- .../cuda_redist_init_repositories.bzl | 20 +++++--- .../nccl/hermetic/nccl_configure.bzl | 6 +-- .../hermetic/nccl_redist_init_repository.bzl | 5 +- 36 files changed, 96 insertions(+), 154 deletions(-) diff --git a/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl index ce509857e5666a..85c0cbbb196fef 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl @@ -1,9 +1,5 @@ licenses(["restricted"]) # NVIDIA proprietary license -exports_files([ - "version.txt", -]) - cc_library( name = "headers", hdrs = glob([ diff --git a/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/gpus/cuda/hermetic/cuda_configure.bzl index 826b91b03d9f99..a9a7ad53cfe20c 100644 --- a/third_party/gpus/cuda/hermetic/cuda_configure.bzl +++ b/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -23,6 +23,14 @@ environment variable is used by GCC compiler. 
""" +load("@cuda_cublas//:version.bzl", _cublas_version = "VERSION") +load("@cuda_cudart//:version.bzl", _cudart_version = "VERSION") +load("@cuda_cudnn//:version.bzl", _cudnn_version = "VERSION") +load("@cuda_cufft//:version.bzl", _cufft_version = "VERSION") +load("@cuda_cupti//:version.bzl", _cupti_version = "VERSION") +load("@cuda_curand//:version.bzl", _curand_version = "VERSION") +load("@cuda_cusolver//:version.bzl", _cusolver_version = "VERSION") +load("@cuda_cusparse//:version.bzl", _cusparse_version = "VERSION") load( "//third_party/gpus:compiler_common_tools.bzl", "get_cxx_inc_directories", @@ -271,14 +279,14 @@ def _get_cuda_config(repository_ctx): return struct( cuda_version = get_cuda_version(repository_ctx), - cupti_version = repository_ctx.read(repository_ctx.attr.cupti_version), - cudart_version = repository_ctx.read(repository_ctx.attr.cudart_version), - cublas_version = repository_ctx.read(repository_ctx.attr.cublas_version), - cusolver_version = repository_ctx.read(repository_ctx.attr.cusolver_version), - curand_version = repository_ctx.read(repository_ctx.attr.curand_version), - cufft_version = repository_ctx.read(repository_ctx.attr.cufft_version), - cusparse_version = repository_ctx.read(repository_ctx.attr.cusparse_version), - cudnn_version = repository_ctx.read(repository_ctx.attr.cudnn_version), + cupti_version = _cupti_version, + cudart_version = _cudart_version, + cublas_version = _cublas_version, + cusolver_version = _cusolver_version, + curand_version = _curand_version, + cufft_version = _cufft_version, + cusparse_version = _cusparse_version, + cudnn_version = _cudnn_version, compute_capabilities = _compute_capabilities(repository_ctx), cpu_value = get_cpu_value(repository_ctx), ) @@ -645,20 +653,20 @@ cuda_configure = repository_rule( environ = _ENVIRONS, attrs = { "environ": attr.string_dict(), - "cccl_version": attr.label(default = Label("@cuda_cccl//:version.txt")), - "cublas_version": attr.label(default = Label("@cuda_cublas//:version.txt")), - "cudart_version": attr.label(default = Label("@cuda_cudart//:version.txt")), - "cudnn_version": attr.label(default = Label("@cuda_cudnn//:version.txt")), - "cufft_version": attr.label(default = Label("@cuda_cufft//:version.txt")), - "cupti_version": attr.label(default = Label("@cuda_cupti//:version.txt")), - "curand_version": attr.label(default = Label("@cuda_curand//:version.txt")), - "cusolver_version": attr.label(default = Label("@cuda_cusolver//:version.txt")), - "cusparse_version": attr.label(default = Label("@cuda_cusparse//:version.txt")), + "cccl_version": attr.label(default = Label("@cuda_cccl//:version.bzl")), + "cublas_version": attr.label(default = Label("@cuda_cublas//:version.bzl")), + "cudart_version": attr.label(default = Label("@cuda_cudart//:version.bzl")), + "cudnn_version": attr.label(default = Label("@cuda_cudnn//:version.bzl")), + "cufft_version": attr.label(default = Label("@cuda_cufft//:version.bzl")), + "cupti_version": attr.label(default = Label("@cuda_cupti//:version.bzl")), + "curand_version": attr.label(default = Label("@cuda_curand//:version.bzl")), + "cusolver_version": attr.label(default = Label("@cuda_cusolver//:version.bzl")), + "cusparse_version": attr.label(default = Label("@cuda_cusparse//:version.bzl")), "nvcc_binary": attr.label(default = Label("@cuda_nvcc//:bin/nvcc")), - "nvcc_version": attr.label(default = Label("@cuda_nvcc//:version.txt")), - "nvjitlink_version": attr.label(default = Label("@cuda_nvjitlink//:version.txt")), - "nvml_version": attr.label(default = 
Label("@cuda_nvml//:version.txt")), - "nvtx_version": attr.label(default = Label("@cuda_nvtx//:version.txt")), + "nvcc_version": attr.label(default = Label("@cuda_nvcc//:version.bzl")), + "nvjitlink_version": attr.label(default = Label("@cuda_nvjitlink//:version.bzl")), + "nvml_version": attr.label(default = Label("@cuda_nvml//:version.bzl")), + "nvtx_version": attr.label(default = Label("@cuda_nvtx//:version.bzl")), "local_config_cuda_build_file": attr.label(default = Label("//third_party/gpus:local_config_cuda.BUILD")), "build_defs_tpl": attr.label(default = Label("//third_party/gpus/cuda:build_defs.bzl.tpl")), "cuda_build_tpl": attr.label(default = Label("//third_party/gpus/cuda/hermetic:BUILD.tpl")), diff --git a/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl index d8f125fa3d3253..a414cf781d4c5e 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cublas_shared_library", diff --git a/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl index fabb310001cd39..3e513b69b68fb0 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl @@ -4,10 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) - filegroup( name = "static", srcs = ["lib/libcudart_static.a"], diff --git a/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl index c3701a6241243d..b762577405ed69 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cudnn_ops_infer", diff --git a/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl index 4e8bcbd84e0327..564066a420cb39 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cudnn_ops", diff --git a/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl index 2e55a742d54967..029171c28d4eba 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cufft_shared_library", diff --git a/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl index 16d6991b584154..7f24e8d048b3fb 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl @@ -5,9 +5,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cupti_shared_library", diff --git a/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl index 746503fcf22229..c33f35db0e97f8 100644 --- a/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl @@ -4,9 +4,6 
@@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "curand_shared_library", diff --git a/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl index 30bacf07eebda2..167739ce67bb0d 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cusolver_shared_library", diff --git a/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl index b7765ab22508dc..0c6ae547d09138 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cusparse_shared_library", diff --git a/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl index 16ff3c8bea80dc..7757a92a90b795 100644 --- a/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl @@ -2,7 +2,6 @@ licenses(["restricted"]) # NVIDIA proprietary license exports_files([ "bin/nvcc", - "version.txt", ]) filegroup( diff --git a/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl index 5be8d6ef2408ba..d86afc3e943f0d 100644 --- a/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "nvjitlink_shared_library", diff --git a/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl index 65bcb04db5c2e1..23ee30f09f8ff3 100644 --- a/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl @@ -1,9 +1,5 @@ licenses(["restricted"]) # NVIDIA proprietary license -exports_files([ - "version.txt", -]) - cc_library( name = "headers", %{comment}hdrs = ["include/nvml.h"], diff --git a/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl index 72418eeb158dc1..3457f41a502dee 100644 --- a/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl @@ -1,9 +1,5 @@ licenses(["restricted"]) # NVIDIA proprietary license -exports_files([ - "version.txt", -]) - cc_library( name = "headers", %{comment}hdrs = glob([ diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl index b40989c7b2797e..f30e68ce95e690 100644 --- a/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl +++ b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -232,6 +232,12 @@ def _create_cuda_version_file(repository_ctx, lib_name_to_version_dict): "MAJOR_CUDA_VERSION = \"{}\"".format(major_cudart_version), ) +def create_version_file(repository_ctx, major_lib_version): + repository_ctx.file( + "version.bzl", + "VERSION = \"{}\"".format(major_lib_version), + ) + def use_local_path(repository_ctx, local_path, dirs): # buildifier: disable=function-docstring-args """Creates repository using local redistribution paths.""" @@ 
-255,7 +261,7 @@ def use_local_path(repository_ctx, local_path, dirs): lib_name_to_version_dict, ) _create_cuda_version_file(repository_ctx, lib_name_to_version_dict) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) def _use_local_cuda_path(repository_ctx, local_cuda_path): # buildifier: disable=function-docstring-args @@ -329,7 +335,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): # If no CUDA version is found, comment out all cc_import targets. create_dummy_build_file(repository_ctx) _create_cuda_version_file(repository_ctx, {}) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return if len(repository_ctx.attr.url_dict) == 0: @@ -338,7 +344,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): )) # buildifier: disable=print create_dummy_build_file(repository_ctx) _create_cuda_version_file(repository_ctx, {}) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return # Download archive only when GPU config is used. @@ -371,7 +377,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): ) _create_cuda_header_symlinks(repository_ctx) _create_cuda_version_file(repository_ctx, lib_name_to_version_dict) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) def _cuda_repo_impl(repository_ctx): local_cuda_path = get_env_var(repository_ctx, "LOCAL_CUDA_PATH") @@ -409,7 +415,7 @@ def _use_downloaded_cudnn_redistribution(repository_ctx): if not cudnn_version: # If no CUDNN version is found, comment out cc_import targets. create_dummy_build_file(repository_ctx) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return if len(repository_ctx.attr.url_dict) == 0: @@ -417,7 +423,7 @@ def _use_downloaded_cudnn_redistribution(repository_ctx): repository_ctx.name, )) # buildifier: disable=print create_dummy_build_file(repository_ctx) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return # Download archive only when GPU config is used. @@ -455,7 +461,7 @@ def _use_downloaded_cudnn_redistribution(repository_ctx): major_version, ) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) def _cudnn_repo_impl(repository_ctx): local_cudnn_path = get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") diff --git a/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/nccl/hermetic/nccl_configure.bzl index acbfd146e2392f..fe80dc2a3d54b9 100644 --- a/third_party/nccl/hermetic/nccl_configure.bzl +++ b/third_party/nccl/hermetic/nccl_configure.bzl @@ -9,6 +9,7 @@ """ +load("@cuda_nccl//:version.bzl", _nccl_version = "VERSION") load( "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "HERMETIC_CUDA_VERSION", @@ -132,7 +133,7 @@ alias( def _create_local_nccl_repository(repository_ctx): cuda_version = get_cuda_version(repository_ctx).split(".")[:2] - nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + nccl_version = _nccl_version if get_host_environ(repository_ctx, _TF_NCCL_USE_STUB, "0") == "0": repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT) @@ -160,7 +161,7 @@ def _nccl_autoconf_impl(repository_ctx): # Add a dummy build file to make bazel query happy. 
repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) if use_cuda_redistributions(repository_ctx): - nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + nccl_version = _nccl_version repository_ctx.file( "nccl_config.h", "#define TF_NCCL_VERSION \"%s\"" % nccl_version, @@ -185,7 +186,6 @@ nccl_configure = repository_rule( implementation = _nccl_autoconf_impl, attrs = { "environ": attr.string_dict(), - "nccl_version": attr.label(default = Label("@cuda_nccl//:version.txt")), "generated_names_tpl": attr.label(default = Label("//third_party/nccl:generated_names.bzl.tpl")), "build_defs_tpl": attr.label(default = Label("//third_party/nccl:build_defs.bzl.tpl")), }, diff --git a/third_party/nccl/hermetic/nccl_redist_init_repository.bzl b/third_party/nccl/hermetic/nccl_redist_init_repository.bzl index 3bb2fe0efcf5bd..524ba30a50eb5b 100644 --- a/third_party/nccl/hermetic/nccl_redist_init_repository.bzl +++ b/third_party/nccl/hermetic/nccl_redist_init_repository.bzl @@ -20,6 +20,7 @@ load( "OS_ARCH_DICT", "create_build_file", "create_dummy_build_file", + "create_version_file", "get_archive_name", "get_env_var", "get_lib_name_to_version_dict", @@ -42,7 +43,7 @@ def _use_downloaded_nccl_wheel(repository_ctx): if not cuda_version: # If no CUDA version is found, comment out cc_import targets. create_dummy_build_file(repository_ctx) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return # Download archive only when GPU config is used. @@ -105,7 +106,7 @@ def _use_downloaded_nccl_wheel(repository_ctx): major_version, ) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) def _use_local_nccl_path(repository_ctx, local_nccl_path): # buildifier: disable=function-docstring-args diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl index ce509857e5666a..85c0cbbb196fef 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl @@ -1,9 +1,5 @@ licenses(["restricted"]) # NVIDIA proprietary license -exports_files([ - "version.txt", -]) - cc_library( name = "headers", hdrs = glob([ diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_configure.bzl index 826b91b03d9f99..a9a7ad53cfe20c 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_configure.bzl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -23,6 +23,14 @@ environment variable is used by GCC compiler. 
""" +load("@cuda_cublas//:version.bzl", _cublas_version = "VERSION") +load("@cuda_cudart//:version.bzl", _cudart_version = "VERSION") +load("@cuda_cudnn//:version.bzl", _cudnn_version = "VERSION") +load("@cuda_cufft//:version.bzl", _cufft_version = "VERSION") +load("@cuda_cupti//:version.bzl", _cupti_version = "VERSION") +load("@cuda_curand//:version.bzl", _curand_version = "VERSION") +load("@cuda_cusolver//:version.bzl", _cusolver_version = "VERSION") +load("@cuda_cusparse//:version.bzl", _cusparse_version = "VERSION") load( "//third_party/gpus:compiler_common_tools.bzl", "get_cxx_inc_directories", @@ -271,14 +279,14 @@ def _get_cuda_config(repository_ctx): return struct( cuda_version = get_cuda_version(repository_ctx), - cupti_version = repository_ctx.read(repository_ctx.attr.cupti_version), - cudart_version = repository_ctx.read(repository_ctx.attr.cudart_version), - cublas_version = repository_ctx.read(repository_ctx.attr.cublas_version), - cusolver_version = repository_ctx.read(repository_ctx.attr.cusolver_version), - curand_version = repository_ctx.read(repository_ctx.attr.curand_version), - cufft_version = repository_ctx.read(repository_ctx.attr.cufft_version), - cusparse_version = repository_ctx.read(repository_ctx.attr.cusparse_version), - cudnn_version = repository_ctx.read(repository_ctx.attr.cudnn_version), + cupti_version = _cupti_version, + cudart_version = _cudart_version, + cublas_version = _cublas_version, + cusolver_version = _cusolver_version, + curand_version = _curand_version, + cufft_version = _cufft_version, + cusparse_version = _cusparse_version, + cudnn_version = _cudnn_version, compute_capabilities = _compute_capabilities(repository_ctx), cpu_value = get_cpu_value(repository_ctx), ) @@ -645,20 +653,20 @@ cuda_configure = repository_rule( environ = _ENVIRONS, attrs = { "environ": attr.string_dict(), - "cccl_version": attr.label(default = Label("@cuda_cccl//:version.txt")), - "cublas_version": attr.label(default = Label("@cuda_cublas//:version.txt")), - "cudart_version": attr.label(default = Label("@cuda_cudart//:version.txt")), - "cudnn_version": attr.label(default = Label("@cuda_cudnn//:version.txt")), - "cufft_version": attr.label(default = Label("@cuda_cufft//:version.txt")), - "cupti_version": attr.label(default = Label("@cuda_cupti//:version.txt")), - "curand_version": attr.label(default = Label("@cuda_curand//:version.txt")), - "cusolver_version": attr.label(default = Label("@cuda_cusolver//:version.txt")), - "cusparse_version": attr.label(default = Label("@cuda_cusparse//:version.txt")), + "cccl_version": attr.label(default = Label("@cuda_cccl//:version.bzl")), + "cublas_version": attr.label(default = Label("@cuda_cublas//:version.bzl")), + "cudart_version": attr.label(default = Label("@cuda_cudart//:version.bzl")), + "cudnn_version": attr.label(default = Label("@cuda_cudnn//:version.bzl")), + "cufft_version": attr.label(default = Label("@cuda_cufft//:version.bzl")), + "cupti_version": attr.label(default = Label("@cuda_cupti//:version.bzl")), + "curand_version": attr.label(default = Label("@cuda_curand//:version.bzl")), + "cusolver_version": attr.label(default = Label("@cuda_cusolver//:version.bzl")), + "cusparse_version": attr.label(default = Label("@cuda_cusparse//:version.bzl")), "nvcc_binary": attr.label(default = Label("@cuda_nvcc//:bin/nvcc")), - "nvcc_version": attr.label(default = Label("@cuda_nvcc//:version.txt")), - "nvjitlink_version": attr.label(default = Label("@cuda_nvjitlink//:version.txt")), - "nvml_version": attr.label(default = 
Label("@cuda_nvml//:version.txt")), - "nvtx_version": attr.label(default = Label("@cuda_nvtx//:version.txt")), + "nvcc_version": attr.label(default = Label("@cuda_nvcc//:version.bzl")), + "nvjitlink_version": attr.label(default = Label("@cuda_nvjitlink//:version.bzl")), + "nvml_version": attr.label(default = Label("@cuda_nvml//:version.bzl")), + "nvtx_version": attr.label(default = Label("@cuda_nvtx//:version.bzl")), "local_config_cuda_build_file": attr.label(default = Label("//third_party/gpus:local_config_cuda.BUILD")), "build_defs_tpl": attr.label(default = Label("//third_party/gpus/cuda:build_defs.bzl.tpl")), "cuda_build_tpl": attr.label(default = Label("//third_party/gpus/cuda/hermetic:BUILD.tpl")), diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl index d8f125fa3d3253..a414cf781d4c5e 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cublas_shared_library", diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl index fabb310001cd39..3e513b69b68fb0 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl @@ -4,10 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) - filegroup( name = "static", srcs = ["lib/libcudart_static.a"], diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl index c3701a6241243d..b762577405ed69 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cudnn_ops_infer", diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl index 4e8bcbd84e0327..564066a420cb39 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cudnn_ops", diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl index 2e55a742d54967..029171c28d4eba 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cufft_shared_library", diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl index 16d6991b584154..7f24e8d048b3fb 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl @@ -5,9 +5,6 @@ load( "cuda_rpath_flags", ) 
-exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cupti_shared_library", diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl index 746503fcf22229..c33f35db0e97f8 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "curand_shared_library", diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl index 30bacf07eebda2..167739ce67bb0d 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cusolver_shared_library", diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl index b7765ab22508dc..0c6ae547d09138 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "cusparse_shared_library", diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl index 16ff3c8bea80dc..7757a92a90b795 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl @@ -2,7 +2,6 @@ licenses(["restricted"]) # NVIDIA proprietary license exports_files([ "bin/nvcc", - "version.txt", ]) filegroup( diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl index 5be8d6ef2408ba..d86afc3e943f0d 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl @@ -4,9 +4,6 @@ load( "cuda_rpath_flags", ) -exports_files([ - "version.txt", -]) %{multiline_comment} cc_import( name = "nvjitlink_shared_library", diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl index 65bcb04db5c2e1..23ee30f09f8ff3 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl @@ -1,9 +1,5 @@ licenses(["restricted"]) # NVIDIA proprietary license -exports_files([ - "version.txt", -]) - cc_library( name = "headers", %{comment}hdrs = ["include/nvml.h"], diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl index 72418eeb158dc1..3457f41a502dee 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl @@ -1,9 +1,5 @@ licenses(["restricted"]) # NVIDIA proprietary license -exports_files([ - "version.txt", -]) - cc_library( 
name = "headers", %{comment}hdrs = glob([ diff --git a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl index b40989c7b2797e..f30e68ce95e690 100644 --- a/third_party/xla/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl +++ b/third_party/xla/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -232,6 +232,12 @@ def _create_cuda_version_file(repository_ctx, lib_name_to_version_dict): "MAJOR_CUDA_VERSION = \"{}\"".format(major_cudart_version), ) +def create_version_file(repository_ctx, major_lib_version): + repository_ctx.file( + "version.bzl", + "VERSION = \"{}\"".format(major_lib_version), + ) + def use_local_path(repository_ctx, local_path, dirs): # buildifier: disable=function-docstring-args """Creates repository using local redistribution paths.""" @@ -255,7 +261,7 @@ def use_local_path(repository_ctx, local_path, dirs): lib_name_to_version_dict, ) _create_cuda_version_file(repository_ctx, lib_name_to_version_dict) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) def _use_local_cuda_path(repository_ctx, local_cuda_path): # buildifier: disable=function-docstring-args @@ -329,7 +335,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): # If no CUDA version is found, comment out all cc_import targets. create_dummy_build_file(repository_ctx) _create_cuda_version_file(repository_ctx, {}) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return if len(repository_ctx.attr.url_dict) == 0: @@ -338,7 +344,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): )) # buildifier: disable=print create_dummy_build_file(repository_ctx) _create_cuda_version_file(repository_ctx, {}) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return # Download archive only when GPU config is used. @@ -371,7 +377,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): ) _create_cuda_header_symlinks(repository_ctx) _create_cuda_version_file(repository_ctx, lib_name_to_version_dict) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) def _cuda_repo_impl(repository_ctx): local_cuda_path = get_env_var(repository_ctx, "LOCAL_CUDA_PATH") @@ -409,7 +415,7 @@ def _use_downloaded_cudnn_redistribution(repository_ctx): if not cudnn_version: # If no CUDNN version is found, comment out cc_import targets. create_dummy_build_file(repository_ctx) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return if len(repository_ctx.attr.url_dict) == 0: @@ -417,7 +423,7 @@ def _use_downloaded_cudnn_redistribution(repository_ctx): repository_ctx.name, )) # buildifier: disable=print create_dummy_build_file(repository_ctx) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return # Download archive only when GPU config is used. 
@@ -455,7 +461,7 @@ def _use_downloaded_cudnn_redistribution(repository_ctx): major_version, ) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) def _cudnn_repo_impl(repository_ctx): local_cudnn_path = get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") diff --git a/third_party/xla/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/xla/third_party/nccl/hermetic/nccl_configure.bzl index acbfd146e2392f..fe80dc2a3d54b9 100644 --- a/third_party/xla/third_party/nccl/hermetic/nccl_configure.bzl +++ b/third_party/xla/third_party/nccl/hermetic/nccl_configure.bzl @@ -9,6 +9,7 @@ """ +load("@cuda_nccl//:version.bzl", _nccl_version = "VERSION") load( "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "HERMETIC_CUDA_VERSION", @@ -132,7 +133,7 @@ alias( def _create_local_nccl_repository(repository_ctx): cuda_version = get_cuda_version(repository_ctx).split(".")[:2] - nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + nccl_version = _nccl_version if get_host_environ(repository_ctx, _TF_NCCL_USE_STUB, "0") == "0": repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT) @@ -160,7 +161,7 @@ def _nccl_autoconf_impl(repository_ctx): # Add a dummy build file to make bazel query happy. repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) if use_cuda_redistributions(repository_ctx): - nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + nccl_version = _nccl_version repository_ctx.file( "nccl_config.h", "#define TF_NCCL_VERSION \"%s\"" % nccl_version, @@ -185,7 +186,6 @@ nccl_configure = repository_rule( implementation = _nccl_autoconf_impl, attrs = { "environ": attr.string_dict(), - "nccl_version": attr.label(default = Label("@cuda_nccl//:version.txt")), "generated_names_tpl": attr.label(default = Label("//third_party/nccl:generated_names.bzl.tpl")), "build_defs_tpl": attr.label(default = Label("//third_party/nccl:build_defs.bzl.tpl")), }, diff --git a/third_party/xla/third_party/nccl/hermetic/nccl_redist_init_repository.bzl b/third_party/xla/third_party/nccl/hermetic/nccl_redist_init_repository.bzl index 3bb2fe0efcf5bd..524ba30a50eb5b 100644 --- a/third_party/xla/third_party/nccl/hermetic/nccl_redist_init_repository.bzl +++ b/third_party/xla/third_party/nccl/hermetic/nccl_redist_init_repository.bzl @@ -20,6 +20,7 @@ load( "OS_ARCH_DICT", "create_build_file", "create_dummy_build_file", + "create_version_file", "get_archive_name", "get_env_var", "get_lib_name_to_version_dict", @@ -42,7 +43,7 @@ def _use_downloaded_nccl_wheel(repository_ctx): if not cuda_version: # If no CUDA version is found, comment out cc_import targets. create_dummy_build_file(repository_ctx) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) return # Download archive only when GPU config is used. @@ -105,7 +106,7 @@ def _use_downloaded_nccl_wheel(repository_ctx): major_version, ) - repository_ctx.file("version.txt", major_version) + create_version_file(repository_ctx, major_version) def _use_local_nccl_path(repository_ctx, local_nccl_path): # buildifier: disable=function-docstring-args From d59b860d9292281b7d7dfb581d2933f5c78ba212 Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Wed, 9 Apr 2025 18:43:00 -0700 Subject: [PATCH 0481/1324] - When a function returns `absl::StatusOr<T>`, we can simply `return t`. There is no need to call `return std::move(t)`; C++ will use the `absl::StatusOr(T&&)` ctor anyway.
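For illustration, an editor's sketch of this first guideline (the `Thing`/`MakeThing`/`Init` names are invented for the example and are not code from this patch):

  absl::StatusOr<std::unique_ptr<Thing>> MakeThing() {
    auto thing = std::make_unique<Thing>();
    TF_RETURN_IF_ERROR(thing->Init());
    // Implicit move on return: the StatusOr(T&&) converting ctor is chosen
    // anyway, so wrapping the operand in std::move is redundant.
    return thing;
  }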
- When we repack a value from one container to another, it is recommended to move the source container first and then call `.value()`, instead of moving just one field out of the container. For example: - best (recommended): `return std::move(container).value();` - ok (still works): `return std::move(container.value());` PiperOrigin-RevId: 745813031 --- third_party/xla/xla/hlo/analysis/hlo_alias_analysis.cc | 2 +- third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc | 2 +- third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.cc | 2 +- third_party/xla/xla/hlo/analysis/indexing_map.cc | 2 +- third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc | 2 +- .../xla/xla/hlo/analysis/tuple_points_to_analysis.cc | 2 +- third_party/xla/xla/hlo/builder/lib/tuple.cc | 2 +- third_party/xla/xla/hlo/builder/value_inference.cc | 2 +- third_party/xla/xla/hlo/builder/xla_builder.cc | 6 +++--- third_party/xla/xla/hlo/builder/xla_computation.cc | 2 +- third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc | 4 ++-- .../xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h | 2 +- third_party/xla/xla/hlo/ir/hlo_computation.cc | 2 +- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 2 +- third_party/xla/xla/hlo/ir/hlo_instruction.h | 2 +- third_party/xla/xla/hlo/ir/hlo_instructions.cc | 2 +- third_party/xla/xla/hlo/ir/hlo_module.cc | 2 +- third_party/xla/xla/hlo/ir/hlo_schedule.cc | 4 ++-- third_party/xla/xla/hlo/ir/hlo_sharding.cc | 2 +- third_party/xla/xla/hlo/parser/hlo_parser.cc | 2 +- third_party/xla/xla/hlo/pass/hlo_pass_pipeline_test.cc | 2 +- .../hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc | 2 +- third_party/xla/xla/hlo/tools/hlo_translate.cc | 4 ++-- .../xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc | 2 +- third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc | 2 +- third_party/xla/xla/hlo/translate/stablehlo.cc | 4 ++-- third_party/xla/xla/hlo/utils/hlo_live_range.cc | 2 +- third_party/xla/xla/hlo/utils/hlo_sharding_util.cc | 6 +++--- third_party/xla/xla/literal.cc | 6 +++--- third_party/xla/xla/literal.h | 2 +- third_party/xla/xla/shape.h | 2 +- 31 files changed, 41 insertions(+), 41 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_alias_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis.cc index 8921ca3962a3e3..aae41c5fe06774 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_alias_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis.cc @@ -421,7 +421,7 @@ absl::StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run( }); XLA_VLOG_LINES(2, alias_analysis->ToString()); - return std::move(alias_analysis); + return alias_analysis; } } // namespace xla diff --git a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc index 833de6e30169b2..27c351c5099cca 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc @@ -1766,7 +1766,7 @@ absl::StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run( XLA_VLOG_LINES(1, dataflow_analysis->ToString()); - return std::move(dataflow_analysis); + return dataflow_analysis; } absl::Status HloDataflowAnalysis::Verify() const { diff --git a/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.cc index f9e63a385d8f1c..3c142220d052c1 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.cc @@ -350,7 +350,7 @@ absl::StatusOr<std::unique_ptr<HloLivenessAnalysis>> HloLivenessAnalysis::Run(
liveness_analysis->RunAnalysis(); - return std::move(liveness_analysis); + return liveness_analysis; } } // namespace xla diff --git a/third_party/xla/xla/hlo/analysis/indexing_map.cc b/third_party/xla/xla/hlo/analysis/indexing_map.cc index 74819f6118022d..c89388782fe8ba 100644 --- a/third_party/xla/xla/hlo/analysis/indexing_map.cc +++ b/third_party/xla/xla/hlo/analysis/indexing_map.cc @@ -1625,7 +1625,7 @@ SmallBitVector IndexingMap::RemoveUnusedSymbols() { if (!CompressVars(/*unused_dims=*/{}, unused_vars.unused_symbols)) { return {}; } - return std::move(unused_vars.unused_symbols); + return std::move(unused_vars).unused_symbols; } void IndexingMap::ResetToKnownEmpty() { diff --git a/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc b/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc index b6d92b51046174..d22bebb574f615 100644 --- a/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc @@ -58,7 +58,7 @@ LogicalBufferAnalysis::Run(const HloModule* module) { std::unique_ptr<LogicalBufferAnalysis> analysis( new LogicalBufferAnalysis(module)); TF_RETURN_IF_ERROR(analysis->Analyze()); - return std::move(analysis); + return analysis; } absl::Status LogicalBufferAnalysis::Analyze() { diff --git a/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc index d012c13e67ce7d..dcea12c88f5eb0 100644 --- a/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc @@ -150,7 +150,7 @@ TuplePointsToAnalysis::Run(const HloModule* module) { std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis( module, std::move(logical_buffer_analysis).value())); TF_RETURN_IF_ERROR(analysis->Analyze()); - return std::move(analysis); + return analysis; } absl::Status TuplePointsToAnalysis::Analyze() { diff --git a/third_party/xla/xla/hlo/builder/lib/tuple.cc b/third_party/xla/xla/hlo/builder/lib/tuple.cc index 6a0145addefbde..1edfa273aff452 100644 --- a/third_party/xla/xla/hlo/builder/lib/tuple.cc +++ b/third_party/xla/xla/hlo/builder/lib/tuple.cc @@ -39,7 +39,7 @@ absl::StatusOr<ShapeTree<XlaOp>> DisassembleTuple(XlaOp tuple) { *element = GetTupleElement(parent, index.back()); } }); - return std::move(result); + return result; } XlaOp AssembleTuple(XlaBuilder* builder, ShapeTree<XlaOp> elements) { diff --git a/third_party/xla/xla/hlo/builder/value_inference.cc b/third_party/xla/xla/hlo/builder/value_inference.cc index 940625cc006294..22e588a609f07e 100644 --- a/third_party/xla/xla/hlo/builder/value_inference.cc +++ b/third_party/xla/xla/hlo/builder/value_inference.cc @@ -547,7 +547,7 @@ PostorderDFSVisitor::AnalyzeConstantValueFallback(int64_t handle, call_context, "callee's root instruction"); return node.AddVisit([](Literal operand) -> absl::StatusOr<Literal> { // Forward result of callee's root to caller.
- return std::move(operand); + return operand; }); } diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index 180d00d702afe2..8b5d4c17fffc7f 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -811,7 +811,7 @@ absl::StatusOr<XlaComputation> XlaBuilder::Build( this->embedded_.clear(); this->parameter_numbers_.clear(); - return std::move(computation); + return computation; } /* static */ absl::Status XlaBuilder::PopulateInputOutputAliasAndBufferDonor( @@ -2286,7 +2286,7 @@ absl::StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction( if (precision_config != nullptr) { *instr.mutable_precision_config() = *precision_config; } - return std::move(instr); + return instr; } XlaOp XlaBuilder::DynamicConvInputGrad( @@ -4743,7 +4743,7 @@ absl::StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph( if (VLOG_IS_ON(4)) { VLOG(4) << "Constant computation:\n" << module->DebugString(); } - return std::move(computation); + return computation; } std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder( diff --git a/third_party/xla/xla/hlo/builder/xla_computation.cc b/third_party/xla/xla/hlo/builder/xla_computation.cc index 1d01870f1d85c9..807af8f4c49a61 100644 --- a/third_party/xla/xla/hlo/builder/xla_computation.cc +++ b/third_party/xla/xla/hlo/builder/xla_computation.cc @@ -37,7 +37,7 @@ absl::StatusOr<std::unique_ptr<HloSnapshot>> XlaComputation::Snapshot() const { } auto session = std::make_unique<HloSnapshot>(); *session->mutable_hlo()->mutable_hlo_module() = proto_; - return std::move(session); + return session; } } // namespace xla diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index eaf8792d3405ad..e89706164bf22b 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -133,7 +133,7 @@ absl::StatusOr<Literal> Compare(const Shape& shape, Comparison comparison, return compare_op(lhs, rhs); })); } - return std::move(result); + return result; }; switch (comparison.GetDirection()) { case ComparisonDirection::kEq: @@ -3839,7 +3839,7 @@ absl::StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal, return stochastic_convert_op(operand_literal.Get(multi_index), random_literal.Get(multi_index)); })); - return std::move(result); + return result; } // Converts from primitive types to native types.
diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 0ba235098e8886..2bfb9cdc9f0e39 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -98,7 +98,7 @@ auto ToArithmeticSafeType(T t) { return static_cast<std::make_unsigned_t<T>>(t); } if constexpr (!std::is_integral_v<T>) { - return std::move(t); + return t; } } diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index fc011d551b4f53..ea48be15dea2bb 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -1217,7 +1217,7 @@ HloComputation::CreateFromProto( if (!proto.execution_thread().empty()) { computation->SetExecutionThread(proto.execution_thread()); } - return std::move(computation); + return computation; } void HloComputation::AppendInstructionsIntoCalledComputation( diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index ee444ab0951121..74026577bb9c95 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -1381,7 +1381,7 @@ absl::StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->set_original_value(original_value); } - return std::move(instruction); + return instruction; } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter( diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 689a79825633fe..1930980c9b8d96 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1966,7 +1966,7 @@ class HloInstruction { absl::StatusOr<ConfigProto> backend_config() const { ConfigProto proto; TF_RETURN_IF_ERROR(backend_config_.GetProto(&proto)); - return std::move(proto); + return proto; } absl::Status set_backend_config(const tsl::protobuf::Message& proto) { diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index f6d1a73b47d3a6..5944ebf631a5d6 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -3454,7 +3454,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( cloned->set_padding_type(padding_type_); *cloned->mutable_precision_config() = precision_config(); cloned->set_custom_call_schedule(custom_call_schedule_); - return std::move(cloned); + return cloned; } HloPadInstruction::HloPadInstruction(const Shape& shape, diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index e15d4dff1cfb42..51e3bfa6ca744b 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -737,7 +737,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( module->stack_frame_index_ = std::move(proto.stack_frame_index()); } } - return std::move(module); + return module; } /* static */ diff --git a/third_party/xla/xla/hlo/ir/hlo_schedule.cc b/third_party/xla/xla/hlo/ir/hlo_schedule.cc index 9a8dd8307bd022..5e39d86170bf1c 100644 --- a/third_party/xla/xla/hlo/ir/hlo_schedule.cc +++ b/third_party/xla/xla/hlo/ir/hlo_schedule.cc @@ -80,7 +80,7 @@ namespace xla { } } TF_RETURN_IF_ERROR(schedule.Verify()); - return std::move(schedule); + return schedule; } absl::StatusOr<HloScheduleProto> HloSchedule::ToProto() const { @@ -96,7 +96,7 @@ absl::StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
proto_sequence.add_instruction_ids(id); } } - return std::move(proto); + return proto; } void HloSchedule::set_sequence(const HloComputation* computation, diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.cc b/third_party/xla/xla/hlo/ir/hlo_sharding.cc index 883f36eb93d61a..54551b70278e87 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.cc +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.cc @@ -641,7 +641,7 @@ absl::StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree( for (auto& index_to_sharding : result.leaves()) { index_to_sharding.second = *it++; } - return std::move(result); + return result; } else { return ShapeTree<HloSharding>(shape, *this); } diff --git a/third_party/xla/xla/hlo/parser/hlo_parser.cc b/third_party/xla/xla/hlo/parser/hlo_parser.cc index 3e4003344c3af6..1019c02b772456 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser.cc @@ -7282,7 +7282,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule( auto module = std::make_unique<HloModule>(/*name=*/"_", config); HloParserImpl parser(str, options); TF_RETURN_IF_ERROR(parser.Run(module.get())); - return std::move(module); + return module; } absl::StatusOr<HloSharding> ParseSharding(absl::string_view str) { diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline_test.cc b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline_test.cc index ee08fbd5abed79..be8df4849d736f 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline_test.cc +++ b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline_test.cc @@ -58,7 +58,7 @@ class HloPassPipelineTest : public HloHardwareIndependentTestBase { ParseAndReturnVerifiedModule(hlo_string)); group.push_back(std::move(module)); } - return std::move(group); + return group; } }; diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc index 7e1e25695267e0..0c284937a042e9 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc @@ -1215,6 +1215,6 @@ absl::StatusOr<std::unique_ptr<HloValueTracing>> HloValueTracing::Run( } absl::c_sort(hlo_value_tracing->values_vector_, HloValue::IdLessThan); - return std::move(hlo_value_tracing); + return hlo_value_tracing; } } // namespace xla diff --git a/third_party/xla/xla/hlo/tools/hlo_translate.cc b/third_party/xla/xla/hlo/tools/hlo_translate.cc index 162671f87bee3f..cceb7189f87e20 100644 --- a/third_party/xla/xla/hlo/tools/hlo_translate.cc +++ b/third_party/xla/xla/hlo/tools/hlo_translate.cc @@ -158,7 +158,7 @@ mlir::OwningOpRef<mlir::ModuleOp> GetModuleFromHloInput( // Try HLO Text auto module_from_text = GetModuleFromHLOText(content, context, emit_mhlo); - if (module_from_text.ok()) return std::move(module_from_text.value()); + if (module_from_text.ok()) return std::move(module_from_text).value(); if (module_from_text.status().message() != kLoadHloError) { emitError() << "Failed to convert HLO to MLIR: " << module_from_text.status().message(); @@ -168,7 +168,7 @@ mlir::OwningOpRef<mlir::ModuleOp> GetModuleFromHloInput( // Try HLO Proto auto module_from_proto = GetModuleFromHLOProto(std::string(content), context, emit_mhlo); - if (module_from_proto.ok()) return std::move(module_from_proto.value()); + if (module_from_proto.ok()) return std::move(module_from_proto).value(); if (module_from_proto.status().message() != kLoadHloError) { emitError() << "Failed to convert HLO to MLIR: " << module_from_proto.status().message(); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc
b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc index 997e9705f94202..1e711fbb994c4a 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc @@ -632,7 +632,7 @@ absl::StatusOr<HloSchedule> ScheduleModule( TF_RETURN_IF_ERROR(schedule.Verify()); - return std::move(schedule); + return schedule; } absl::StatusOr<HloSchedule> ScheduleModule( diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc index 1e2d46b2da7885..9771e245c785a9 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc @@ -136,7 +136,7 @@ mlir::OwningOpRef<mlir::ModuleOp> HloTextToStablehloTranslateFunction( return nullptr; } - return std::move(stablehlo_module.value()); + return std::move(stablehlo_module).value(); } } // namespace xla diff --git a/third_party/xla/xla/hlo/translate/stablehlo.cc b/third_party/xla/xla/hlo/translate/stablehlo.cc index cde29fc4125e35..cc4d0a761c9e46 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo.cc @@ -193,7 +193,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToStablehlo( /*emit_stablehlo=*/false) .Import(*hlo_module)); TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); - return std::move(mlir_module); + return mlir_module; } absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToStablehlo( @@ -206,7 +206,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToStablehlo( /*emit_stablehlo=*/false) .Import(*hlo_module_proto)); TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); - return std::move(mlir_module); + return mlir_module; } absl::StatusOr<std::unique_ptr<HloModule>> ConvertStablehloToHlo( diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range.cc b/third_party/xla/xla/hlo/utils/hlo_live_range.cc index 61ce4c68dfe486..12ee1fbc45ebab 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range.cc +++ b/third_party/xla/xla/hlo/utils/hlo_live_range.cc @@ -49,7 +49,7 @@ absl::StatusOr<std::unique_ptr<HloLiveRange>> HloLiveRange::Run( hlo_live_range->FlattenSchedule(*computation); hlo_live_range->CalculateBufferStartEndMap(); hlo_live_range->NormalizeAliasedBuffers(); - return std::move(hlo_live_range); + return hlo_live_range; } void HloLiveRange::NormalizeAliasedBuffers() { diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index d2e5b5c656ddc8..ca1317757dd35f 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -3144,7 +3144,7 @@ std::optional<HloSharding> ReturnImprovedShardingImpl( bool allow_aggressive_resharding) { // Always allow improve the sharding if it's straightly better. if (to_improved != nullptr && IsShardingStrictlyBetter(from, *to_improved)) { - return std::move(from); + return from; } // We don't want to propagate tile maximal shardings. if (!IsSpatiallyPartitioned(from)) { @@ -3152,7 +3152,7 @@ std::optional<HloSharding> ReturnImprovedShardingImpl( } // Any sharding is better than no sharding. if (to_improved == nullptr) { - return std::move(from); + return from; } // We don't want to propagate manual shardings.
if (from.IsManual()) { @@ -3172,7 +3172,7 @@ std::optional<HloSharding> ReturnImprovedShardingImpl( return std::nullopt; } } - return std::move(from); + return from; } return std::nullopt; } diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index 20803b2dec197a..f37c390738859d 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -402,7 +402,7 @@ absl::Status LiteralBase::SerializeToString(std::string* output) const { absl::StatusOr<std::string> LiteralBase::SerializeAsString() const { std::string result; TF_RETURN_IF_ERROR(SerializeToString(&result)); - return std::move(result); + return result; } template @@ -541,7 +541,7 @@ void MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return absl::OkStatus(); })); - return std::move(literal); + return literal; } Literal Literal::SubLiteral(ShapeIndexView shape_index) { @@ -1210,7 +1210,7 @@ absl::StatusOr<Literal> LiteralBase::Reshape( ShapeUtil::HumanString(shape()), ShapeUtil::HumanString(output.shape())); } - return std::move(output); + return output; } Literal LiteralBase::Transpose(absl::Span<const int64_t> permutation) const { diff --git a/third_party/xla/xla/literal.h b/third_party/xla/xla/literal.h index 0b7110b3331c06..10fb37b118f8a6 100644 --- a/third_party/xla/xla/literal.h +++ b/third_party/xla/xla/literal.h @@ -437,7 +437,7 @@ class LiteralBase { } }); - return std::move(state); + return state; } // Templated wrapper struct to control layout sensitivity during Absl::Hash. diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 1637eee92100f1..ee8a84cf293b20 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -488,7 +488,7 @@ class Shape { if (kIsLayoutSensitive) { h = H::combine(std::move(h), state->layout); } - return std::move(h); + return h; } return H::combine(std::move(h), s.element_type_); } From 5f309c8575ccd35f65c88956f2a2a026604fa006 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 9 Apr 2025 18:48:16 -0700 Subject: [PATCH 0482/1324] [IFRT] Strengthen the device list and memory kind checks of `ClientMakeArraysFromHostBufferShards()` `ClientMakeArraysFromHostBufferShards()`, a fallback implementation for `Client::MakeArraysFromHostBufferShards()`, checks whether the supplied array specs meet the current API contract (the shardings of all array specs must have equal device lists and memory kinds).
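For reference, a condensed sketch of the enforced contract (editor's illustration; the full check with detailed error messages is in the diff below):

  // Every spec must use the same device list and memory kind as specs[0].
  for (int i = 1; i < specs.size(); ++i) {
    const auto& sharding = *specs[i].array_spec.sharding;
    if (sharding.devices() != specs[0].array_spec.sharding->devices() ||
        sharding.memory_kind() != specs[0].array_spec.sharding->memory_kind()) {
      return absl::InvalidArgumentError(
          "All arrays must share one device list and memory kind");
    }
  }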
PiperOrigin-RevId: 745814187 --- .../xla/python/ifrt/array_impl_test_lib.cc | 106 ++++++++++++++++++ .../xla/xla/python/ifrt/client_impl_util.cc | 20 ++++ 2 files changed, 126 insertions(+) diff --git a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc index ffaf2aee644ad4..1902e5ee078c0a 100644 --- a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc @@ -485,6 +485,112 @@ TEST(ArrayImplTest, MakeArraysFromHostBufferShardsAndCopyToHostBuffer) { } } +TEST(ArrayImplTest, MakeArraysFromHostBufferShardsWithDifferentDevices) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + if (client->addressable_devices().size() < 2) { + GTEST_SKIP() << "This test is relevant only for clients with devices that " + "have at least 2 devices"; + } + + DType dtype(DType::kF32); + Shape shape({2, 3}); + Shape shard_shape = shape; + auto data = std::make_unique<std::vector<float>>(6); + std::iota(data->begin(), data->end(), 0); + + std::shared_ptr<const Sharding> sharding0 = SingleDeviceSharding::Create( + client->addressable_devices()[0], MemoryKind()); + std::shared_ptr<const Sharding> sharding1 = SingleDeviceSharding::Create( + client->addressable_devices()[1], MemoryKind()); + + std::vector<Client::MakeArraysFromHostBufferShardsSpec> specs; + // Create two arrays with different shardings. + specs.push_back({ + /*buffers=*/{ + {{0}, + {data->data(), dtype, shard_shape, /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}}}, + /*array_spec=*/{dtype, shape, sharding0, /*layout=*/nullptr}, + }); + specs.push_back({ + /*buffers=*/{ + {{0}, + {data->data(), dtype, shard_shape, /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}}}, + /*array_spec=*/{dtype, shape, sharding1, /*layout=*/nullptr}, + }); + + absl::Status status; + auto result = client->MakeArraysFromHostBufferShards( + absl::MakeSpan(specs), + Client::HostBufferSemantics::kImmutableOnlyDuringCall, + client->CreateUserContext()); + if (result.ok()) { + // Implementations may poison outputs instead of immediately returning an + // error. + status = result->at(0)->GetReadyFuture().Await(); + } else { + status = result.status(); + } + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ArrayImplTest, MakeArraysFromHostBufferShardsWithDifferentMemoryKinds) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + if (client->addressable_devices().front()->Memories().size() < 2) { + GTEST_SKIP() << "This test is relevant only for clients with a device that " + "have at least 2 memories"; + } + + DType dtype(DType::kF32); + Shape shape({2, 3}); + Shape shard_shape = shape; + auto data = std::make_unique<std::vector<float>>(6); + std::iota(data->begin(), data->end(), 0); + + std::vector<MemoryKind> memory_kinds; + for (const Memory* memory : + client->addressable_devices().front()->Memories()) { + memory_kinds.push_back(memory->Kind()); + } + + std::shared_ptr<const Sharding> sharding0 = SingleDeviceSharding::Create( + client->addressable_devices().front(), memory_kinds[0]); + std::shared_ptr<const Sharding> sharding1 = SingleDeviceSharding::Create( + client->addressable_devices().front(), memory_kinds[1]); + + std::vector<Client::MakeArraysFromHostBufferShardsSpec> specs; + // Create two arrays with different shardings.
+ specs.push_back({ + /*buffers=*/{ + {{0}, + {data->data(), dtype, shard_shape, /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}}}, + /*array_spec=*/{dtype, shape, sharding0, /*layout=*/nullptr}, + }); + specs.push_back({ + /*buffers=*/{ + {{0}, + {data->data(), dtype, shard_shape, /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}}}, + /*array_spec=*/{dtype, shape, sharding1, /*layout=*/nullptr}, + }); + + absl::Status status; + auto result = client->MakeArraysFromHostBufferShards( + absl::MakeSpan(specs), + Client::HostBufferSemantics::kImmutableOnlyDuringCall, + client->CreateUserContext()); + if (result.ok()) { + // Implementations may poison outputs instead of immediately returning an + // error. + status = result->at(0)->GetReadyFuture().Await(); + } else { + status = result.status(); + } + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument)); +} + TEST(ArrayImplTest, MakeErrorArrays) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); xla::ifrt::DeviceListRef device_list = diff --git a/third_party/xla/xla/python/ifrt/client_impl_util.cc b/third_party/xla/xla/python/ifrt/client_impl_util.cc index 7b9aba04cc52c9..00a59d92ead1b2 100644 --- a/third_party/xla/xla/python/ifrt/client_impl_util.cc +++ b/third_party/xla/xla/python/ifrt/client_impl_util.cc @@ -98,6 +98,26 @@ ClientMakeArraysFromHostBufferShards( absl::Span<Client::MakeArraysFromHostBufferShardsSpec> specs, Client::HostBufferSemantics semantics, tsl::RCReference<UserContext> user_context) { + for (int i = 1; i < specs.size(); ++i) { + const Client::MakeArraysFromHostBufferShardsSpec& spec = specs[i]; + if (specs[0].array_spec.sharding->devices() != + spec.array_spec.sharding->devices()) { + return absl::InvalidArgumentError(absl::StrCat( + "All arrays in MakeArraysFromHostBufferShards must have the " + "same device list, but got ", + specs[0].array_spec.sharding->devices(), " vs. ", + spec.array_spec.sharding->devices())); + } + if (specs[0].array_spec.sharding->memory_kind() != + spec.array_spec.sharding->memory_kind()) { + return absl::InvalidArgumentError(absl::StrCat( + "All arrays in MakeArraysFromHostBufferShards must have the " + "same memory kind, but got ", + specs[0].array_spec.sharding->memory_kind(), " vs. ", + spec.array_spec.sharding->memory_kind())); + } + } + std::vector<tsl::RCReference<Array>> arrays; arrays.reserve(specs.size()); for (Client::MakeArraysFromHostBufferShardsSpec& spec : specs) { From c5d30c01cc424f937e860ba4e7f5a3292de1d5a3 Mon Sep 17 00:00:00 2001 From: Vamsi Manchala Date: Wed, 9 Apr 2025 19:32:08 -0700 Subject: [PATCH 0483/1324] Avoid creating large constants in ConvertTFLBroadcastToMulOp optimization pass. This pattern is inherited from before and has proved to increase the model size due to the introduction of large splat constants. In its current form, the pattern replaces a tfl.broadcast_to op (with rank<4) with a tfl.mul against an all-ones tensor. This change keeps the broadcast_to ops as-is, because it is clear that introducing a MUL is not an optimization.
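For reference, the removed rewrite turned a broadcast into a multiply by a materialized splat of ones, roughly as follows (editor's illustration distilled from the updated tests below):

  %0 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
  // was rewritten to:
  %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32>
  %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
  // The dense<1.0> splat can serialize as a full constant buffer, inflating the model.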
PiperOrigin-RevId: 745823621 --- .../batched_gather_round_trip.mlir | 11 ++- .../batched_scatter_round_trip.mlir | 11 ++- .../legalize-tf-no-runtime-verification.mlir | 5 +- .../compiler/mlir/lite/tests/optimize.mlir | 42 +++------- .../mlir/lite/transforms/optimize_pass.cc | 76 +------------------ 5 files changed, 28 insertions(+), 117 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir index 12de9da5939573..adb22ddd009a80 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir @@ -4,11 +4,14 @@ module { // CHECK-LABEL: func.func public @main func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { - // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 - // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP: %0 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]]]]> : tensor<1x3x1x1xi32>}> : () -> tensor<1x3x1x1xi32> + // CHECK-ROUNDTRIP: %1 = "tfl.pseudo_const"() <{value = dense<[4, 3, 5, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> + // CHECK-ROUNDTRIP: %2 = "tfl.broadcast_to"(%0, %1) : (tensor<1x3x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %3 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]]], {{\[\[\[}}1]]], {{\[\[\[}}2]]], {{\[\[\[}}3]]]]> : tensor<4x1x1x1xi32>}> : () -> tensor<4x1x1x1xi32> + // CHECK-ROUNDTRIP: %4 = "tfl.broadcast_to"(%3, %1) : (tensor<4x1x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%2, %4, %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> - // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %2) <{ + // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ // CHECK-ROUNDTRIP-SAME: dimension_numbers = #stablehlo.gather< // CHECK-ROUNDTRIP-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], // CHECK-ROUNDTRIP-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir index 44d1bb7dd8b72f..7e42ff310c080f 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir @@ -4,11 +4,14 @@ module { // CHECK-LABEL: func.func public @main func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { - // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 - // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // 
CHECK-ROUNDTRIP: %0 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]]]]> : tensor<1x3x1x1xi32>}> : () -> tensor<1x3x1x1xi32> + // CHECK-ROUNDTRIP: %1 = "tfl.pseudo_const"() <{value = dense<[4, 3, 5, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> + // CHECK-ROUNDTRIP: %2 = "tfl.broadcast_to"(%0, %1) : (tensor<1x3x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %3 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]]], {{\[\[\[}}1]]], {{\[\[\[}}2]]], {{\[\[\[}}3]]]]> : tensor<4x1x1x1xi32>}> : () -> tensor<4x1x1x1xi32> + // CHECK-ROUNDTRIP: %4 = "tfl.broadcast_to"(%3, %1) : (tensor<4x1x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%2, %4, %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> - // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %2, %arg2) <{ + // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ // CHECK-ROUNDTRIP-SAME: scatter_dimension_numbers = #stablehlo.scatter // CHECK-ROUNDTRIP-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], // CHECK-ROUNDTRIP-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>}> diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir index 2c17e734c58dad..e0793cbf803c4f 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir @@ -5,7 +5,6 @@ func.func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> ten func.return %0: tensor<3x3xbf16> // CHECK-LABEL: broadcast_to_bf16 -// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xbf16> -// CHECK: [[MUL:%.*]] = tfl.mul(%arg0, [[CST]]) <{fused_activation_function = "NONE"}> : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16> -// CHECK: return [[MUL]] : tensor<3x3xbf16> +// CHECK: %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16> +// CHECK: return %0 : tensor<3x3xbf16> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 83b82d50fc064f..47d32734dbb27c 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -4399,11 +4399,11 @@ func.func @FuseExcessBroadcastingOnReshapes(%arg0: tensor<1x8xf32>) -> tensor<1x %1 = "tfl.broadcast_to"(%0, %cst_0) : (tensor<1x1x1x8x1x1xf32>, tensor<6xi32>) -> tensor<1x1x1x8x16x1xf32> %2 = "tfl.reshape"(%1, %cst_1) : (tensor<1x1x1x8x16x1xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> return %2 : tensor<1x1x1x128xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<8x16xf32> + // CHECK: %cst = arith.constant dense<[8, 16]> : tensor<2xi64> // CHECK: %cst_0 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> // CHECK: %cst_1 = arith.constant dense<[8, 1]> : tensor<2xi32> // CHECK: %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<1x8xf32>, tensor<2xi32>) -> tensor<8x1xf32> - // CHECK: %1 = tfl.mul(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %1 = "tfl.broadcast_to"(%0, %cst) : (tensor<8x1xf32>, tensor<2xi64>) -> tensor<8x16xf32> // CHECK: %2 = 
"tfl.reshape"(%1, %cst_0) : (tensor<8x16xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> // CHECK: return %2 : tensor<1x1x1x128xf32> } @@ -4425,83 +4425,63 @@ func.func @FuseExcessBroadcastingOnReshapesDynamicShapes(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK: return %0 : tensor<3x3xf32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32_low_dim func.func @broadcast_to_i32_low_dim(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> - // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> - // CHECK: return %0 : tensor<3x3xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_low_dim_with_unknown_shape func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK: return %0 : tensor<3x3xf32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i16_low_dim func.func @broadcast_to_i16_low_dim(%arg0: tensor<3xi16>, %arg1: tensor<2xi32>) -> tensor<3x3xi16> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> return %0 : tensor<3x3xi16> - // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi16> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi16>, tensor<3x3xi16>) -> tensor<3x3xi16> - // CHECK: return %0 : tensor<3x3xi16> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32_low_dim_with_unknown_output func.func @broadcast_to_i32_low_dim_with_unknown_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<*xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> - // CHECK: %cst = arith.constant dense<1> : tensor - // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<*xi32> - // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> - // CHECK: return %1 : tensor<*xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_ui32 func.func @broadcast_to_ui32(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor<10xui32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor, tensor<1xi64>) -> tensor<10xui32> return %0 : tensor<10xui32> - // CHECK: %cst = arith.constant dense<1> : tensor<10xui32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor, tensor<10xui32>) -> tensor<10xui32> - // CHECK: return %0 : tensor<10xui32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_f32 func.func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, 
%arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK: return %0 : tensor<3x3xf32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32 func.func @broadcast_to_i32(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> - // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> - // CHECK: return %0 : tensor<3x3xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32_with_dynamic_shape_and_output func.func @broadcast_to_i32_with_dynamic_shape_and_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x?xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x?xi32> return %0 : tensor<3x?xi32> - // CHECK: %cst = arith.constant dense<1> : tensor - // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<3x?xi32> - // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x?xi32>) -> tensor<3x?xi32> - // CHECK: return %1 : tensor<3x?xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_ui32_with_dynamic_output diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc index 1de8398dbd8bd9..448ef85bdac543 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc @@ -977,80 +977,6 @@ struct SqueezeReshapesAroundBroadcastOp } }; -// This pattern matches TFL::BroadcastToOp WITH TENSOR RANK <= 4 and replaces -// it with a MulOp that multiplies the tensor by a splat constant with 1s. -struct ConvertTFLBroadcastToMulOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, - PatternRewriter &rewriter) const override { - auto input_type = - mlir::cast(tfl_broadcast_to_op.getInput().getType()); - auto output_type = - mlir::cast(tfl_broadcast_to_op.getOutput().getType()); - auto shape_type = - mlir::cast(tfl_broadcast_to_op.getShape().getType()); - Type element_type = input_type.getElementType(); - - auto loc = tfl_broadcast_to_op->getLoc(); - - // Check that the output type is not dynamic and is less-than-equal to 4D or - // the shape type is static, 1D and has less-than-equal to 4 elements. - bool is_output_shape_dynamic = - (!output_type.hasRank() || (output_type.getRank() > 4) || - (output_type.getNumDynamicDims() > 0)); - bool is_broadcast_shape_dynamic = - (!shape_type.hasStaticShape() || (shape_type.getRank() != 1) || - (shape_type.getDimSize(0) > 4)); - if (is_output_shape_dynamic && is_broadcast_shape_dynamic) - return rewriter.notifyMatchFailure( - loc, "output_rank or broadcast_to shape not supported"); - - // Allow lowering when the input's elements type is F32, BFloat16, I32 or - // I16. 
- if (!(mlir::isa(element_type) || - element_type.isInteger(32) || element_type.isInteger(16))) - return rewriter.notifyMatchFailure(loc, "element_type_not_supported"); - - // TFL_FillOp is created only if is_output_shape_dynamic is true, otherwise - // a Arith.ConstOp is created. - if (is_output_shape_dynamic && - output_type.getElementType().isUnsignedInteger()) { - return rewriter.notifyMatchFailure( - loc, - "Unsigned broadcast_to output with dynamic shape is not supported"); - } - - Value mul_rhs_value; - if (!output_type.hasRank() || (output_type.getNumDynamicDims() > 0)) { - auto status_or_const_op = - CreateConstOpWithSingleValue(&rewriter, loc, input_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - mul_rhs_value = rewriter.create( - loc, output_type, tfl_broadcast_to_op.getShape(), - status_or_const_op.value()); - } else { - auto status_or_const_op = - CreateConstOpWithVectorValue(&rewriter, loc, output_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - mul_rhs_value = status_or_const_op.value(); - } - - auto mul_op = rewriter.create( - loc, output_type, tfl_broadcast_to_op.getInput(), mul_rhs_value, - rewriter.getStringAttr("NONE")); - rewriter.replaceOp(tfl_broadcast_to_op, mul_op.getResult()); - return success(); - } -}; - struct FuseAddAndStridedSlice : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -3081,7 +3007,7 @@ void OptimizePass::runOnOperation() { OptimizeTopK, FuseAddAndStridedSlice, FuseReshapeAndTransposeAroundBatchMatmul, FuseTransposeReshapeIntoBatchMatmul, MoveReshapeAfterFullyConnected, - EnableFullyConnectedKeepNumDimsBeforeReshape, ConvertTFLBroadcastToMulOp, + EnableFullyConnectedKeepNumDimsBeforeReshape, ReorderTransposeReshapeTranspose, FullyConnectedSwapOperandsWhenLHSIsConst>(ctx); if (!GetOptions().disable_fuse_mul_and_fc) { From fe99cc2b5772e4737276f9b2c22fa704ad5142fa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Apr 2025 20:08:19 -0700 Subject: [PATCH 0484/1324] Integrate LLVM at llvm/llvm-project@836476660e5c Updates LLVM usage to match [836476660e5c](https://github.com/llvm/llvm-project/commit/836476660e5c) PiperOrigin-RevId: 745830634 --- third_party/llvm/generated.patch | 585 ------------- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 804 +++++++++++++----- third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/shardy/temporary.patch | 804 +++++++++++++----- .../xla/third_party/shardy/workspace.bzl | 4 +- .../emitters/transforms/convert_float_amd.cc | 2 +- .../codegen/emitters/transforms/fuse_loops.cc | 2 +- .../emitters/transforms/flatten_tensors.cc | 2 +- .../transforms/vectorize_loads_stores.cc | 2 +- 10 files changed, 1178 insertions(+), 1035 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index a3ecef4dbbedb7..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,586 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h ---- a/clang/lib/Sema/TreeTransform.h -+++ b/clang/lib/Sema/TreeTransform.h -@@ -7765,17 +7765,23 @@ - NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); - NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); - -- typedef TemplateArgumentLocContainerIterator< -- DependentTemplateSpecializationTypeLoc> ArgIterator; -- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), -- ArgIterator(TL, TL.getNumArgs()), -- NewTemplateArgs)) -+ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); -+ -+ if (getDerived().TransformTemplateArguments(ArgsRange.begin(), -+ ArgsRange.end(), NewTemplateArgs)) - return QualType(); -+ bool TemplateArgumentsChanged = !llvm::equal( -+ ArgsRange, NewTemplateArgs.arguments(), -+ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { -+ return A.getArgument().structurallyEquals(B.getArgument()); -+ }); - - const DependentTemplateStorage &DTN = T->getDependentTemplateName(); - - QualType Result = TL.getType(); -- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { -+ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || -+ TemplateArgumentsChanged) { - TemplateName Name = getDerived().RebuildTemplateName( - SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), - /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h ---- a/clang/test/CodeGen/include/cuda.h -+++ b/clang/test/CodeGen/include/cuda.h -@@ -1,194 +0,0 @@ --/* Minimal declarations for CUDA support. Testing purposes only. -- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h -- */ --#pragma once -- --// Make this file work with nvcc, for testing compatibility. -- --#ifndef __NVCC__ --#define __constant__ __attribute__((constant)) --#define __device__ __attribute__((device)) --#define __global__ __attribute__((global)) --#define __host__ __attribute__((host)) --#define __shared__ __attribute__((shared)) --#define __managed__ __attribute__((managed)) --#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) -- --struct dim3 { -- unsigned x, y, z; -- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} --}; -- --// Host- and device-side placement new overloads. 
--void *operator new(__SIZE_TYPE__, void *p) { return p; } --void *operator new[](__SIZE_TYPE__, void *p) { return p; } --__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } --__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } -- --#define CUDA_VERSION 10100 -- --struct char1 { -- char x; -- __host__ __device__ char1(char x = 0) : x(x) {} --}; --struct char2 { -- char x, y; -- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} --}; --struct char4 { -- char x, y, z, w; -- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --struct uchar1 { -- unsigned char x; -- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} --}; --struct uchar2 { -- unsigned char x, y; -- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} --}; --struct uchar4 { -- unsigned char x, y, z, w; -- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --struct short1 { -- short x; -- __host__ __device__ short1(short x = 0) : x(x) {} --}; --struct short2 { -- short x, y; -- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} --}; --struct short4 { -- short x, y, z, w; -- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --struct ushort1 { -- unsigned short x; -- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} --}; --struct ushort2 { -- unsigned short x, y; -- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} --}; --struct ushort4 { -- unsigned short x, y, z, w; -- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --struct int1 { -- int x; -- __host__ __device__ int1(int x = 0) : x(x) {} --}; --struct int2 { -- int x, y; -- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} --}; --struct int4 { -- int x, y, z, w; -- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --struct uint1 { -- unsigned x; -- __host__ __device__ uint1(unsigned x = 0) : x(x) {} --}; --struct uint2 { -- unsigned x, y; -- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} --}; --struct uint3 { -- unsigned x, y, z; -- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} --}; --struct uint4 { -- unsigned x, y, z, w; -- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --struct longlong1 { -- long long x; -- __host__ __device__ longlong1(long long x = 0) : x(x) {} --}; --struct longlong2 { -- long long x, y; -- __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} --}; --struct longlong4 { -- long long x, y, z, w; -- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --struct ulonglong1 { -- unsigned long long x; -- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} --}; --struct ulonglong2 { -- unsigned long long x, y; -- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} --}; --struct ulonglong4 { -- unsigned long long x, y, z, w; -- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, 
unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --struct float1 { -- float x; -- __host__ __device__ float1(float x = 0) : x(x) {} --}; --struct float2 { -- float x, y; -- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} --}; --struct float4 { -- float x, y, z, w; -- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --struct double1 { -- double x; -- __host__ __device__ double1(double x = 0) : x(x) {} --}; --struct double2 { -- double x, y; -- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} --}; --struct double4 { -- double x, y, z, w; -- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} --}; -- --typedef unsigned long long cudaTextureObject_t; --typedef unsigned long long cudaSurfaceObject_t; -- --enum cudaTextureReadMode { -- cudaReadModeNormalizedFloat, -- cudaReadModeElementType --}; -- --enum cudaSurfaceBoundaryMode { -- cudaBoundaryModeZero, -- cudaBoundaryModeClamp, -- cudaBoundaryModeTrap --}; -- --enum { -- cudaTextureType1D, -- cudaTextureType2D, -- cudaTextureType3D, -- cudaTextureTypeCubemap, -- cudaTextureType1DLayered, -- cudaTextureType2DLayered, -- cudaTextureTypeCubemapLayered --}; -- --struct textureReference {}; --template --struct __attribute__((device_builtin_texture_type)) texture -- : public textureReference {}; -- --#endif // !__NVCC__ -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h ---- a/clang/test/CodeGen/Inputs/cuda.h -+++ b/clang/test/CodeGen/Inputs/cuda.h -@@ -0,0 +1,194 @@ -+/* Minimal declarations for CUDA support. Testing purposes only. -+ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h -+ */ -+#pragma once -+ -+// Make this file work with nvcc, for testing compatibility. -+ -+#ifndef __NVCC__ -+#define __constant__ __attribute__((constant)) -+#define __device__ __attribute__((device)) -+#define __global__ __attribute__((global)) -+#define __host__ __attribute__((host)) -+#define __shared__ __attribute__((shared)) -+#define __managed__ __attribute__((managed)) -+#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) -+ -+struct dim3 { -+ unsigned x, y, z; -+ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} -+}; -+ -+// Host- and device-side placement new overloads. 
-+void *operator new(__SIZE_TYPE__, void *p) { return p; } -+void *operator new[](__SIZE_TYPE__, void *p) { return p; } -+__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } -+__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } -+ -+#define CUDA_VERSION 10100 -+ -+struct char1 { -+ char x; -+ __host__ __device__ char1(char x = 0) : x(x) {} -+}; -+struct char2 { -+ char x, y; -+ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} -+}; -+struct char4 { -+ char x, y, z, w; -+ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+struct uchar1 { -+ unsigned char x; -+ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} -+}; -+struct uchar2 { -+ unsigned char x, y; -+ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} -+}; -+struct uchar4 { -+ unsigned char x, y, z, w; -+ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+struct short1 { -+ short x; -+ __host__ __device__ short1(short x = 0) : x(x) {} -+}; -+struct short2 { -+ short x, y; -+ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} -+}; -+struct short4 { -+ short x, y, z, w; -+ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+struct ushort1 { -+ unsigned short x; -+ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} -+}; -+struct ushort2 { -+ unsigned short x, y; -+ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} -+}; -+struct ushort4 { -+ unsigned short x, y, z, w; -+ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+struct int1 { -+ int x; -+ __host__ __device__ int1(int x = 0) : x(x) {} -+}; -+struct int2 { -+ int x, y; -+ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} -+}; -+struct int4 { -+ int x, y, z, w; -+ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+struct uint1 { -+ unsigned x; -+ __host__ __device__ uint1(unsigned x = 0) : x(x) {} -+}; -+struct uint2 { -+ unsigned x, y; -+ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} -+}; -+struct uint3 { -+ unsigned x, y, z; -+ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} -+}; -+struct uint4 { -+ unsigned x, y, z, w; -+ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+struct longlong1 { -+ long long x; -+ __host__ __device__ longlong1(long long x = 0) : x(x) {} -+}; -+struct longlong2 { -+ long long x, y; -+ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} -+}; -+struct longlong4 { -+ long long x, y, z, w; -+ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+struct ulonglong1 { -+ unsigned long long x; -+ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} -+}; -+struct ulonglong2 { -+ unsigned long long x, y; -+ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} -+}; -+struct ulonglong4 { -+ unsigned long long x, y, z, w; -+ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, 
unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+struct float1 { -+ float x; -+ __host__ __device__ float1(float x = 0) : x(x) {} -+}; -+struct float2 { -+ float x, y; -+ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} -+}; -+struct float4 { -+ float x, y, z, w; -+ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+struct double1 { -+ double x; -+ __host__ __device__ double1(double x = 0) : x(x) {} -+}; -+struct double2 { -+ double x, y; -+ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} -+}; -+struct double4 { -+ double x, y, z, w; -+ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} -+}; -+ -+typedef unsigned long long cudaTextureObject_t; -+typedef unsigned long long cudaSurfaceObject_t; -+ -+enum cudaTextureReadMode { -+ cudaReadModeNormalizedFloat, -+ cudaReadModeElementType -+}; -+ -+enum cudaSurfaceBoundaryMode { -+ cudaBoundaryModeZero, -+ cudaBoundaryModeClamp, -+ cudaBoundaryModeTrap -+}; -+ -+enum { -+ cudaTextureType1D, -+ cudaTextureType2D, -+ cudaTextureType3D, -+ cudaTextureTypeCubemap, -+ cudaTextureType1DLayered, -+ cudaTextureType2DLayered, -+ cudaTextureTypeCubemapLayered -+}; -+ -+struct textureReference {}; -+template -+struct __attribute__((device_builtin_texture_type)) texture -+ : public textureReference {}; -+ -+#endif // !__NVCC__ -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu ---- a/clang/test/CodeGen/nvptx-surface.cu -+++ b/clang/test/CodeGen/nvptx-surface.cu -@@ -1,6 +1,6 @@ - // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s - // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s --#include "include/cuda.h" -+#include "Inputs/cuda.h" - - #include "__clang_cuda_texture_intrinsics.h" - -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp ---- a/clang/test/SemaTemplate/dependent-names.cpp -+++ b/clang/test/SemaTemplate/dependent-names.cpp -@@ -458,3 +458,12 @@ - }; - int f(b ba) { return ba.add<0>(); } - } -+ -+namespace TransformDependentTemplates { -+ template struct Test1 { -+ template -+ using Arg = typename T::template Arg; -+ void f(Arg); -+ void f(Arg); -+ }; -+} // namespace TransformDependentTemplates -diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ---- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp -+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp -@@ -15391,12 +15391,20 @@ - - if (E->State == TreeEntry::SplitVectorize) { - Res = FindLastInst(); -+ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { -+ for (auto *E : Entries) { -+ auto *I = dyn_cast_or_null(E->VectorizedValue); -+ if (!I) -+ I = &getLastInstructionInBundle(E); -+ if (Res->comesBefore(I)) -+ Res = I; -+ } -+ } - return *Res; - } - - // Set insertpoint for gathered loads to the very first load. 
-- if (E->State != TreeEntry::SplitVectorize && -- GatheredLoadsEntriesFirst.has_value() && -+ if (GatheredLoadsEntriesFirst.has_value() && - E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && - E->getOpcode() == Instruction::Load) { - Res = FindFirstInst(); -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll ---- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -+++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -@@ -0,0 +1,99 @@ -+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 -+; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s -+ -+define void @test(ptr %0, <8 x i8> %1) { -+; CHECK-LABEL: define void @test( -+; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { -+; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 -+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 -+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 -+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 -+; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 -+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> -+; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 -+; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> -+; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> -+; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) -+; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 -+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> -+; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) -+; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) -+; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] -+; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 -+; CHECK-NEXT: ret void -+; -+ %3 = load i8, ptr %0, align 2 -+ %4 = getelementptr i8, ptr %0, i64 13442 -+ %5 = load i8, ptr %4, align 2 -+ %6 = or i8 %5, %3 -+ %7 = getelementptr i8, ptr %0, i64 13550 -+ store i8 %6, ptr %7, align 2 -+ %8 = extractelement <8 x i8> %1, i64 0 -+ %9 = or i8 %5, %8 -+ %10 = getelementptr i8, ptr %0, i64 13542 -+ store i8 %9, ptr %10, align 2 -+ %11 = getelementptr i8, ptr %0, i64 13438 -+ %12 = load i8, ptr %11, align 2 -+ %13 = or i8 %12, %3 -+ %14 = getelementptr i8, ptr %0, i64 13546 -+ store i8 %13, ptr %14, align 2 -+ %15 = extractelement <8 x i8> %1, i64 2 -+ %16 = or i8 %12, %15 -+ %17 = getelementptr i8, ptr %0, i64 13538 -+ store i8 %16, ptr %17, align 2 -+ %18 = getelementptr i8, ptr %0, i64 13440 -+ %19 = load i8, ptr %18, align 4 -+ %20 = or i8 %19, %3 -+ %21 = getelementptr i8, ptr %0, i64 13548 -+ store i8 %20, ptr %21, align 4 -+ %22 = extractelement <8 x i8> %1, i64 4 -+ %23 = or i8 %19, %22 -+ %24 = getelementptr i8, ptr %0, i64 13540 -+ store i8 %23, ptr %24, align 4 -+ %25 = getelementptr i8, ptr %0, i64 13436 -+ %26 = load i8, ptr %25, align 4 -+ %27 = getelementptr i8, ptr %0, i64 13444 -+ %28 = load i8, ptr %27, align 4 -+ %29 = or i8 %28, %26 
-+ %30 = getelementptr i8, ptr %0, i64 13544 -+ store i8 %29, ptr %30, align 4 -+ %31 = or i8 %26, %8 -+ %32 = getelementptr i8, ptr %0, i64 13536 -+ store i8 %31, ptr %32, align 4 -+ %33 = getelementptr i8, ptr %0, i64 13443 -+ %34 = load i8, ptr %33, align 1 -+ %35 = or i8 %34, %3 -+ %36 = getelementptr i8, ptr %0, i64 13551 -+ store i8 %35, ptr %36, align 1 -+ %37 = extractelement <8 x i8> %1, i64 7 -+ %38 = or i8 %34, %37 -+ %39 = getelementptr i8, ptr %0, i64 13543 -+ store i8 %38, ptr %39, align 1 -+ %40 = getelementptr i8, ptr %0, i64 13439 -+ %41 = load i8, ptr %40, align 1 -+ %42 = or i8 %41, %3 -+ %43 = getelementptr i8, ptr %0, i64 13547 -+ store i8 %42, ptr %43, align 1 -+ %44 = extractelement <8 x i8> %1, i64 3 -+ %45 = or i8 %41, %44 -+ %46 = getelementptr i8, ptr %0, i64 13539 -+ store i8 %45, ptr %46, align 1 -+ %47 = getelementptr i8, ptr %0, i64 13441 -+ %48 = load i8, ptr %47, align 1 -+ %49 = or i8 %48, %3 -+ %50 = getelementptr i8, ptr %0, i64 13549 -+ store i8 %49, ptr %50, align 1 -+ %51 = extractelement <8 x i8> %1, i64 5 -+ %52 = or i8 %48, %51 -+ %53 = getelementptr i8, ptr %0, i64 13541 -+ store i8 %52, ptr %53, align 1 -+ %54 = getelementptr i8, ptr %0, i64 13437 -+ %55 = load i8, ptr %54, align 1 -+ %56 = or i8 %55, %3 -+ %57 = getelementptr i8, ptr %0, i64 13545 -+ store i8 %56, ptr %57, align 1 -+ %58 = or i8 %55, %8 -+ %59 = getelementptr i8, ptr %0, i64 13537 -+ store i8 %58, ptr %59, align 1 -+ ret void -+} diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 73450ce1ae572a..7993194770a240 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" - LLVM_SHA256 = "4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" + LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" + LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index bd3beff6435ac1..b5a97068282d17 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,242 +1,606 @@ -diff --git a/shardy/dialect/sdy/ir/BUILD b/shardy/dialect/sdy/ir/BUILD -index 780cd17..fe8986b 100644 ---- a/shardy/dialect/sdy/ir/BUILD -+++ b/shardy/dialect/sdy/ir/BUILD -@@ -164,6 +164,7 @@ cc_library( - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:Support", -+ "@llvm-project//mlir:TransformUtils", - "@stablehlo//:stablehlo_assembly_format", - "@stablehlo//:stablehlo_ops", - "@stablehlo//:stablehlo_type_inference", -diff --git a/shardy/dialect/sdy/ir/canonicalization.cc b/shardy/dialect/sdy/ir/canonicalization.cc -index e1b391f..7ab3e28 100644 ---- a/shardy/dialect/sdy/ir/canonicalization.cc -+++ b/shardy/dialect/sdy/ir/canonicalization.cc -@@ -25,6 +25,7 @@ limitations under the License. 
- #include "mlir/IR/Region.h" - #include "mlir/IR/Value.h" - #include "mlir/Support/LLVM.h" -+#include "mlir/Transforms/Inliner.h" - #include "mlir/Transforms/InliningUtils.h" - #include "shardy/dialect/sdy/ir/dialect.h" - #include "shardy/dialect/sdy/ir/utils.h" -@@ -103,9 +104,11 @@ class RedundantManualComputationPattern - } - - mlir::InlinerInterface inliner(manualComputationOp.getContext()); -+ mlir::InlinerConfig config; - if (inlineRegion( -- inliner, &manualComputationOp.getRegion(), -- manualComputationOp->getBlock(), manualComputationOp->getIterator(), -+ inliner, config.getCloneCallback(), -+ &manualComputationOp.getRegion(), manualComputationOp->getBlock(), -+ manualComputationOp->getIterator(), - manualComputationOp.getOperands(), manualComputationOp.getResults()) - .failed()) { - manualComputationOp.emitOpError( diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 97282ec..a3ecef4 100644 +index a3ecef4..509398d 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,16 +1,4 @@ +@@ -1,586 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp ----- a/clang/lib/AST/ASTContext.cpp --+++ b/clang/lib/AST/ASTContext.cpp --@@ -7011,7 +7011,7 @@ -- getCanonicalTemplateArgument(subst->getArgumentPack()); -- return getSubstTemplateTemplateParmPack( -- canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), --- subst->getFinal(), subst->getIndex()); --+ subst->getIndex(), subst->getFinal()); -- } -- case TemplateName::DeducedTemplate: { -- assert(IgnoreDeduced == false); - diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h - --- a/clang/lib/Sema/TreeTransform.h - +++ b/clang/lib/Sema/TreeTransform.h -@@ -44,28 +32,6 @@ diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/ - TemplateName Name = getDerived().RebuildTemplateName( - SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), - /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, --diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ----- a/clang/lib/Serialization/ASTReaderStmt.cpp --+++ b/clang/lib/Serialization/ASTReaderStmt.cpp --@@ -2229,6 +2229,7 @@ -- E->PackIndex = Record.readInt(); -- else -- E->PackIndex = 0; --+ E->Final = CurrentUnpackingBits->getNextBit(); -- E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); -- E->Replacement = Record.readSubExpr(); -- } --diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ----- a/clang/lib/Serialization/ASTWriterStmt.cpp --+++ b/clang/lib/Serialization/ASTWriterStmt.cpp --@@ -2229,6 +2229,7 @@ -- CurrentPackingBits.addBit((bool)E->getPackIndex()); -- if (auto PackIndex = E->getPackIndex()) -- Record.push_back(*PackIndex + 1); --+ CurrentPackingBits.addBit(E->getFinal()); +-diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h +---- a/clang/lib/Sema/TreeTransform.h +-+++ b/clang/lib/Sema/TreeTransform.h +-@@ -7765,17 +7765,23 @@ +- NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); +- NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); +- +-- typedef TemplateArgumentLocContainerIterator< +-- DependentTemplateSpecializationTypeLoc> ArgIterator; +-- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), +-- 
ArgIterator(TL, TL.getNumArgs()), +-- NewTemplateArgs)) +-+ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); +-+ +-+ if (getDerived().TransformTemplateArguments(ArgsRange.begin(), +-+ ArgsRange.end(), NewTemplateArgs)) +- return QualType(); +-+ bool TemplateArgumentsChanged = !llvm::equal( +-+ ArgsRange, NewTemplateArgs.arguments(), +-+ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { +-+ return A.getArgument().structurallyEquals(B.getArgument()); +-+ }); - -- Record.AddSourceLocation(E->getNameLoc()); -- Record.AddStmt(E->getReplacement()); - diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h - --- a/clang/test/CodeGen/include/cuda.h - +++ b/clang/test/CodeGen/include/cuda.h -@@ -515,119 +481,6 @@ diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp - E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && - E->getOpcode() == Instruction::Load) { - Res = FindFirstInst(); --diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp ----- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp --+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp --@@ -2590,6 +2590,14 @@ -- if (R.mayWriteToMemory() && !InterleaveR) -- return; +- const DependentTemplateStorage &DTN = T->getDependentTemplateName(); - --+ // Do not narrow interleave groups if there are VectorPointer recipes and --+ // the plan was unrolled. The recipe implicitly uses VF from --+ // VPTransformState. --+ // TODO: Remove restriction once the VF for the VectorPointer offset is --+ // modeled explicitly as operand. --+ if (isa(&R) && Plan.getUF() > 1) --+ return; --+ -- // All other ops are allowed, but we reject uses that cannot be converted -- // when checking all allowed consumers (store interleave groups) below. -- if (!InterleaveR) --diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll ----- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll --+++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll --@@ -66,3 +66,91 @@ -- exit: -- ret void +- QualType Result = TL.getType(); +-- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { +-+ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || +-+ TemplateArgumentsChanged) { +- TemplateName Name = getDerived().RebuildTemplateName( +- SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), +- /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, +-diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h +---- a/clang/test/CodeGen/include/cuda.h +-+++ b/clang/test/CodeGen/include/cuda.h +-@@ -1,194 +0,0 @@ +--/* Minimal declarations for CUDA support. Testing purposes only. +-- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h +-- */ +--#pragma once +-- +--// Make this file work with nvcc, for testing compatibility. 
+-- +--#ifndef __NVCC__ +--#define __constant__ __attribute__((constant)) +--#define __device__ __attribute__((device)) +--#define __global__ __attribute__((global)) +--#define __host__ __attribute__((host)) +--#define __shared__ __attribute__((shared)) +--#define __managed__ __attribute__((managed)) +--#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) +-- +--struct dim3 { +-- unsigned x, y, z; +-- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} +--}; +-- +--// Host- and device-side placement new overloads. +--void *operator new(__SIZE_TYPE__, void *p) { return p; } +--void *operator new[](__SIZE_TYPE__, void *p) { return p; } +--__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } +--__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } +-- +--#define CUDA_VERSION 10100 +-- +--struct char1 { +-- char x; +-- __host__ __device__ char1(char x = 0) : x(x) {} +--}; +--struct char2 { +-- char x, y; +-- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} +--}; +--struct char4 { +-- char x, y, z, w; +-- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct uchar1 { +-- unsigned char x; +-- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} +--}; +--struct uchar2 { +-- unsigned char x, y; +-- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} +--}; +--struct uchar4 { +-- unsigned char x, y, z, w; +-- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct short1 { +-- short x; +-- __host__ __device__ short1(short x = 0) : x(x) {} +--}; +--struct short2 { +-- short x, y; +-- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} +--}; +--struct short4 { +-- short x, y, z, w; +-- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct ushort1 { +-- unsigned short x; +-- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} +--}; +--struct ushort2 { +-- unsigned short x, y; +-- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} +--}; +--struct ushort4 { +-- unsigned short x, y, z, w; +-- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct int1 { +-- int x; +-- __host__ __device__ int1(int x = 0) : x(x) {} +--}; +--struct int2 { +-- int x, y; +-- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} +--}; +--struct int4 { +-- int x, y, z, w; +-- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct uint1 { +-- unsigned x; +-- __host__ __device__ uint1(unsigned x = 0) : x(x) {} +--}; +--struct uint2 { +-- unsigned x, y; +-- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} +--}; +--struct uint3 { +-- unsigned x, y, z; +-- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} +--}; +--struct uint4 { +-- unsigned x, y, z, w; +-- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct longlong1 { +-- long long x; +-- __host__ __device__ longlong1(long long x = 0) : x(x) {} +--}; +--struct longlong2 { +-- long long x, y; +-- __host__ 
__device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} +--}; +--struct longlong4 { +-- long long x, y, z, w; +-- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct ulonglong1 { +-- unsigned long long x; +-- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} +--}; +--struct ulonglong2 { +-- unsigned long long x, y; +-- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} +--}; +--struct ulonglong4 { +-- unsigned long long x, y, z, w; +-- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct float1 { +-- float x; +-- __host__ __device__ float1(float x = 0) : x(x) {} +--}; +--struct float2 { +-- float x, y; +-- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} +--}; +--struct float4 { +-- float x, y, z, w; +-- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct double1 { +-- double x; +-- __host__ __device__ double1(double x = 0) : x(x) {} +--}; +--struct double2 { +-- double x, y; +-- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} +--}; +--struct double4 { +-- double x, y, z, w; +-- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--typedef unsigned long long cudaTextureObject_t; +--typedef unsigned long long cudaSurfaceObject_t; +-- +--enum cudaTextureReadMode { +-- cudaReadModeNormalizedFloat, +-- cudaReadModeElementType +--}; +-- +--enum cudaSurfaceBoundaryMode { +-- cudaBoundaryModeZero, +-- cudaBoundaryModeClamp, +-- cudaBoundaryModeTrap +--}; +-- +--enum { +-- cudaTextureType1D, +-- cudaTextureType2D, +-- cudaTextureType3D, +-- cudaTextureTypeCubemap, +-- cudaTextureType1DLayered, +-- cudaTextureType2DLayered, +-- cudaTextureTypeCubemapLayered +--}; +-- +--struct textureReference {}; +--template +--struct __attribute__((device_builtin_texture_type)) texture +-- : public textureReference {}; +-- +--#endif // !__NVCC__ +-diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h +---- a/clang/test/CodeGen/Inputs/cuda.h +-+++ b/clang/test/CodeGen/Inputs/cuda.h +-@@ -0,0 +1,194 @@ +-+/* Minimal declarations for CUDA support. Testing purposes only. +-+ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h +-+ */ +-+#pragma once +-+ +-+// Make this file work with nvcc, for testing compatibility. +-+ +-+#ifndef __NVCC__ +-+#define __constant__ __attribute__((constant)) +-+#define __device__ __attribute__((device)) +-+#define __global__ __attribute__((global)) +-+#define __host__ __attribute__((host)) +-+#define __shared__ __attribute__((shared)) +-+#define __managed__ __attribute__((managed)) +-+#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) +-+ +-+struct dim3 { +-+ unsigned x, y, z; +-+ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} +-+}; +-+ +-+// Host- and device-side placement new overloads. 
+-+void *operator new(__SIZE_TYPE__, void *p) { return p; } +-+void *operator new[](__SIZE_TYPE__, void *p) { return p; } +-+__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } +-+__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } +-+ +-+#define CUDA_VERSION 10100 +-+ +-+struct char1 { +-+ char x; +-+ __host__ __device__ char1(char x = 0) : x(x) {} +-+}; +-+struct char2 { +-+ char x, y; +-+ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} +-+}; +-+struct char4 { +-+ char x, y, z, w; +-+ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct uchar1 { +-+ unsigned char x; +-+ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} +-+}; +-+struct uchar2 { +-+ unsigned char x, y; +-+ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} +-+}; +-+struct uchar4 { +-+ unsigned char x, y, z, w; +-+ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct short1 { +-+ short x; +-+ __host__ __device__ short1(short x = 0) : x(x) {} +-+}; +-+struct short2 { +-+ short x, y; +-+ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} +-+}; +-+struct short4 { +-+ short x, y, z, w; +-+ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct ushort1 { +-+ unsigned short x; +-+ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} +-+}; +-+struct ushort2 { +-+ unsigned short x, y; +-+ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} +-+}; +-+struct ushort4 { +-+ unsigned short x, y, z, w; +-+ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct int1 { +-+ int x; +-+ __host__ __device__ int1(int x = 0) : x(x) {} +-+}; +-+struct int2 { +-+ int x, y; +-+ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} +-+}; +-+struct int4 { +-+ int x, y, z, w; +-+ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct uint1 { +-+ unsigned x; +-+ __host__ __device__ uint1(unsigned x = 0) : x(x) {} +-+}; +-+struct uint2 { +-+ unsigned x, y; +-+ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} +-+}; +-+struct uint3 { +-+ unsigned x, y, z; +-+ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} +-+}; +-+struct uint4 { +-+ unsigned x, y, z, w; +-+ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct longlong1 { +-+ long long x; +-+ __host__ __device__ longlong1(long long x = 0) : x(x) {} +-+}; +-+struct longlong2 { +-+ long long x, y; +-+ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} +-+}; +-+struct longlong4 { +-+ long long x, y, z, w; +-+ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct ulonglong1 { +-+ unsigned long long x; +-+ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} +-+}; +-+struct ulonglong2 { +-+ unsigned long long x, y; +-+ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} +-+}; +-+struct ulonglong4 { +-+ 
unsigned long long x, y, z, w; +-+ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct float1 { +-+ float x; +-+ __host__ __device__ float1(float x = 0) : x(x) {} +-+}; +-+struct float2 { +-+ float x, y; +-+ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} +-+}; +-+struct float4 { +-+ float x, y, z, w; +-+ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct double1 { +-+ double x; +-+ __host__ __device__ double1(double x = 0) : x(x) {} +-+}; +-+struct double2 { +-+ double x, y; +-+ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} +-+}; +-+struct double4 { +-+ double x, y, z, w; +-+ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+typedef unsigned long long cudaTextureObject_t; +-+typedef unsigned long long cudaSurfaceObject_t; +-+ +-+enum cudaTextureReadMode { +-+ cudaReadModeNormalizedFloat, +-+ cudaReadModeElementType +-+}; +-+ +-+enum cudaSurfaceBoundaryMode { +-+ cudaBoundaryModeZero, +-+ cudaBoundaryModeClamp, +-+ cudaBoundaryModeTrap +-+}; +-+ +-+enum { +-+ cudaTextureType1D, +-+ cudaTextureType2D, +-+ cudaTextureType3D, +-+ cudaTextureTypeCubemap, +-+ cudaTextureType1DLayered, +-+ cudaTextureType2DLayered, +-+ cudaTextureTypeCubemapLayered +-+}; +-+ +-+struct textureReference {}; +-+template +-+struct __attribute__((device_builtin_texture_type)) texture +-+ : public textureReference {}; +-+ +-+#endif // !__NVCC__ +-diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu +---- a/clang/test/CodeGen/nvptx-surface.cu +-+++ b/clang/test/CodeGen/nvptx-surface.cu +-@@ -1,6 +1,6 @@ +- // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s +- // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s +--#include "include/cuda.h" +-+#include "Inputs/cuda.h" +- +- #include "__clang_cuda_texture_intrinsics.h" +- +-diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp +---- a/clang/test/SemaTemplate/dependent-names.cpp +-+++ b/clang/test/SemaTemplate/dependent-names.cpp +-@@ -458,3 +458,12 @@ +- }; +- int f(b ba) { return ba.add<0>(); } - } -+ --+define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { --+; CHECK-LABEL: define void @test_2xi64_with_wide_load( --+; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { --+; CHECK-NEXT: [[ENTRY:.*]]: --+; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] --+; CHECK: [[VECTOR_PH]]: --+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] --+; CHECK: [[VECTOR_BODY]]: --+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] --+; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 --+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] --+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 --+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 --+; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 --+; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 --+; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 
[[INDEX]], 1 --+; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 --+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] --+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] --+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP8]], align 8 --+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> --+; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> --+; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 --+; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> --+; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> --+; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] --+; CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] --+; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] --+; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] --+; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> --+; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> --+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 --+; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> --+; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> --+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 --+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 --+; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 --+; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] --+; CHECK: [[MIDDLE_BLOCK]]: --+; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]] --+; CHECK: [[SCALAR_PH]]: --+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] --+; CHECK-NEXT: br label %[[LOOP:.*]] --+; CHECK: [[LOOP]]: --+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] --+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] --+; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 --+; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 --+; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] --+; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 --+; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] --+; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 --+; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 --+; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] --+; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 --+; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] --+; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 --+; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 --+; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 --+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] --+; CHECK: [[EXIT]]: +-+namespace TransformDependentTemplates { +-+ template struct Test1 { +-+ template +-+ using Arg = typename T::template Arg; +-+ void 
f(Arg); +-+ void f(Arg); +-+ }; +-+} // namespace TransformDependentTemplates +-diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +---- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +-+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +-@@ -15391,12 +15391,20 @@ +- +- if (E->State == TreeEntry::SplitVectorize) { +- Res = FindLastInst(); +-+ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { +-+ for (auto *E : Entries) { +-+ auto *I = dyn_cast_or_null(E->VectorizedValue); +-+ if (!I) +-+ I = &getLastInstructionInBundle(E); +-+ if (Res->comesBefore(I)) +-+ Res = I; +-+ } +-+ } +- return *Res; +- } +- +- // Set insertpoint for gathered loads to the very first load. +-- if (E->State != TreeEntry::SplitVectorize && +-- GatheredLoadsEntriesFirst.has_value() && +-+ if (GatheredLoadsEntriesFirst.has_value() && +- E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && +- E->getOpcode() == Instruction::Load) { +- Res = FindFirstInst(); +-diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +---- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +-+++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +-@@ -0,0 +1,99 @@ +-+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +-+; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s +-+ +-+define void @test(ptr %0, <8 x i8> %1) { +-+; CHECK-LABEL: define void @test( +-+; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { +-+; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 +-+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 +-+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 +-+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 +-+; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 +-+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> +-+; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 +-+; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> +-+; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> +-+; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) +-+; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 +-+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> +-+; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) +-+; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) +-+; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] +-+; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 -+; CHECK-NEXT: ret void -+; --+entry: --+ br label %loop --+ --+loop: --+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] --+ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv --+ %l.factor = load i64, ptr %arrayidx, align 8 --+ %1 = shl nsw i64 %iv, 1 --+ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 --+ %l.0 = load i64, ptr %data.0, align 8 --+ %mul.0 = mul i64 %l.factor, %l.0 --+ store 
i64 %mul.0, ptr %data.0, align 8 --+ %3 = or disjoint i64 %1, 1 --+ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 --+ %l.1 = load i64, ptr %data.1, align 8 --+ %mul.1 = mul i64 %l.factor, %l.1 --+ store i64 %mul.1, ptr %data.1, align 8 --+ %iv.next = add nuw nsw i64 %iv, 1 --+ %ec = icmp eq i64 %iv.next, 100 --+ br i1 %ec, label %exit, label %loop --+ --+exit: +-+ %3 = load i8, ptr %0, align 2 +-+ %4 = getelementptr i8, ptr %0, i64 13442 +-+ %5 = load i8, ptr %4, align 2 +-+ %6 = or i8 %5, %3 +-+ %7 = getelementptr i8, ptr %0, i64 13550 +-+ store i8 %6, ptr %7, align 2 +-+ %8 = extractelement <8 x i8> %1, i64 0 +-+ %9 = or i8 %5, %8 +-+ %10 = getelementptr i8, ptr %0, i64 13542 +-+ store i8 %9, ptr %10, align 2 +-+ %11 = getelementptr i8, ptr %0, i64 13438 +-+ %12 = load i8, ptr %11, align 2 +-+ %13 = or i8 %12, %3 +-+ %14 = getelementptr i8, ptr %0, i64 13546 +-+ store i8 %13, ptr %14, align 2 +-+ %15 = extractelement <8 x i8> %1, i64 2 +-+ %16 = or i8 %12, %15 +-+ %17 = getelementptr i8, ptr %0, i64 13538 +-+ store i8 %16, ptr %17, align 2 +-+ %18 = getelementptr i8, ptr %0, i64 13440 +-+ %19 = load i8, ptr %18, align 4 +-+ %20 = or i8 %19, %3 +-+ %21 = getelementptr i8, ptr %0, i64 13548 +-+ store i8 %20, ptr %21, align 4 +-+ %22 = extractelement <8 x i8> %1, i64 4 +-+ %23 = or i8 %19, %22 +-+ %24 = getelementptr i8, ptr %0, i64 13540 +-+ store i8 %23, ptr %24, align 4 +-+ %25 = getelementptr i8, ptr %0, i64 13436 +-+ %26 = load i8, ptr %25, align 4 +-+ %27 = getelementptr i8, ptr %0, i64 13444 +-+ %28 = load i8, ptr %27, align 4 +-+ %29 = or i8 %28, %26 +-+ %30 = getelementptr i8, ptr %0, i64 13544 +-+ store i8 %29, ptr %30, align 4 +-+ %31 = or i8 %26, %8 +-+ %32 = getelementptr i8, ptr %0, i64 13536 +-+ store i8 %31, ptr %32, align 4 +-+ %33 = getelementptr i8, ptr %0, i64 13443 +-+ %34 = load i8, ptr %33, align 1 +-+ %35 = or i8 %34, %3 +-+ %36 = getelementptr i8, ptr %0, i64 13551 +-+ store i8 %35, ptr %36, align 1 +-+ %37 = extractelement <8 x i8> %1, i64 7 +-+ %38 = or i8 %34, %37 +-+ %39 = getelementptr i8, ptr %0, i64 13543 +-+ store i8 %38, ptr %39, align 1 +-+ %40 = getelementptr i8, ptr %0, i64 13439 +-+ %41 = load i8, ptr %40, align 1 +-+ %42 = or i8 %41, %3 +-+ %43 = getelementptr i8, ptr %0, i64 13547 +-+ store i8 %42, ptr %43, align 1 +-+ %44 = extractelement <8 x i8> %1, i64 3 +-+ %45 = or i8 %41, %44 +-+ %46 = getelementptr i8, ptr %0, i64 13539 +-+ store i8 %45, ptr %46, align 1 +-+ %47 = getelementptr i8, ptr %0, i64 13441 +-+ %48 = load i8, ptr %47, align 1 +-+ %49 = or i8 %48, %3 +-+ %50 = getelementptr i8, ptr %0, i64 13549 +-+ store i8 %49, ptr %50, align 1 +-+ %51 = extractelement <8 x i8> %1, i64 5 +-+ %52 = or i8 %48, %51 +-+ %53 = getelementptr i8, ptr %0, i64 13541 +-+ store i8 %52, ptr %53, align 1 +-+ %54 = getelementptr i8, ptr %0, i64 13437 +-+ %55 = load i8, ptr %54, align 1 +-+ %56 = or i8 %55, %3 +-+ %57 = getelementptr i8, ptr %0, i64 13545 +-+ store i8 %56, ptr %57, align 1 +-+ %58 = or i8 %55, %8 +-+ %59 = getelementptr i8, ptr %0, i64 13537 +-+ store i8 %58, ptr %59, align 1 -+ ret void -+} - diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll - --- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll - +++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -@@ -731,18 +584,3 @@ diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-nod - + 
store i8 %58, ptr %59, align 1 - + ret void - +} --diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp ----- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp --+++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp --@@ -151,9 +151,10 @@ -- MachineModuleInfoWrapperPass *MMIWP = -- new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); -- --- legacy::PassManager PassMgrF; -- SmallString<1024> Buf; -- llvm::raw_svector_ostream OS(Buf); --+ legacy::PassManager PassMgrF; --+ -- AsmPrinter *Printer = -- addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); -- PassMgrF.run(*M); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index c3bcd53..73450ce 100644 +index 73450ce..7993194 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "cd54cb062bba9c90a8f3723bf66caa7effbcf259" -- LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" -+ LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" -+ LLVM_SHA256 = "4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" +- LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" +- LLVM_SHA256 = "4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" ++ LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" ++ LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 104a8de57175ad..503b82b5d33179 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "1ba08b6822b3bce9ce4acb3b839b05b3266ca0bc" - SHARDY_SHA256 = "6930a383ed9b516041f08ae948e86a3926a2cf11c7457fa50950f298275c6a84" + SHARDY_COMMIT = "2bd86e4ef697536b0683149a93022e21d8d5e6d3" + SHARDY_SHA256 = "a3b3672c72cadd8cafd837d7da219cebf97c5312e545ed1ebc639e71a47b60e5" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index bd3beff6435ac1..b5a97068282d17 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,242 +1,606 @@ -diff --git a/shardy/dialect/sdy/ir/BUILD b/shardy/dialect/sdy/ir/BUILD -index 780cd17..fe8986b 100644 ---- a/shardy/dialect/sdy/ir/BUILD -+++ b/shardy/dialect/sdy/ir/BUILD -@@ -164,6 +164,7 @@ cc_library( - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:Support", -+ "@llvm-project//mlir:TransformUtils", - "@stablehlo//:stablehlo_assembly_format", - "@stablehlo//:stablehlo_ops", - "@stablehlo//:stablehlo_type_inference", -diff --git a/shardy/dialect/sdy/ir/canonicalization.cc b/shardy/dialect/sdy/ir/canonicalization.cc -index e1b391f..7ab3e28 100644 ---- a/shardy/dialect/sdy/ir/canonicalization.cc -+++ b/shardy/dialect/sdy/ir/canonicalization.cc -@@ -25,6 +25,7 @@ limitations under the License. 
- #include "mlir/IR/Region.h" - #include "mlir/IR/Value.h" - #include "mlir/Support/LLVM.h" -+#include "mlir/Transforms/Inliner.h" - #include "mlir/Transforms/InliningUtils.h" - #include "shardy/dialect/sdy/ir/dialect.h" - #include "shardy/dialect/sdy/ir/utils.h" -@@ -103,9 +104,11 @@ class RedundantManualComputationPattern - } - - mlir::InlinerInterface inliner(manualComputationOp.getContext()); -+ mlir::InlinerConfig config; - if (inlineRegion( -- inliner, &manualComputationOp.getRegion(), -- manualComputationOp->getBlock(), manualComputationOp->getIterator(), -+ inliner, config.getCloneCallback(), -+ &manualComputationOp.getRegion(), manualComputationOp->getBlock(), -+ manualComputationOp->getIterator(), - manualComputationOp.getOperands(), manualComputationOp.getResults()) - .failed()) { - manualComputationOp.emitOpError( diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 97282ec..a3ecef4 100644 +index a3ecef4..509398d 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,16 +1,4 @@ +@@ -1,586 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp ----- a/clang/lib/AST/ASTContext.cpp --+++ b/clang/lib/AST/ASTContext.cpp --@@ -7011,7 +7011,7 @@ -- getCanonicalTemplateArgument(subst->getArgumentPack()); -- return getSubstTemplateTemplateParmPack( -- canonArgPack, subst->getAssociatedDecl()->getCanonicalDecl(), --- subst->getFinal(), subst->getIndex()); --+ subst->getIndex(), subst->getFinal()); -- } -- case TemplateName::DeducedTemplate: { -- assert(IgnoreDeduced == false); - diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h - --- a/clang/lib/Sema/TreeTransform.h - +++ b/clang/lib/Sema/TreeTransform.h -@@ -44,28 +32,6 @@ diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/ - TemplateName Name = getDerived().RebuildTemplateName( - SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), - /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, --diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ----- a/clang/lib/Serialization/ASTReaderStmt.cpp --+++ b/clang/lib/Serialization/ASTReaderStmt.cpp --@@ -2229,6 +2229,7 @@ -- E->PackIndex = Record.readInt(); -- else -- E->PackIndex = 0; --+ E->Final = CurrentUnpackingBits->getNextBit(); -- E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); -- E->Replacement = Record.readSubExpr(); -- } --diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ----- a/clang/lib/Serialization/ASTWriterStmt.cpp --+++ b/clang/lib/Serialization/ASTWriterStmt.cpp --@@ -2229,6 +2229,7 @@ -- CurrentPackingBits.addBit((bool)E->getPackIndex()); -- if (auto PackIndex = E->getPackIndex()) -- Record.push_back(*PackIndex + 1); --+ CurrentPackingBits.addBit(E->getFinal()); +-diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h +---- a/clang/lib/Sema/TreeTransform.h +-+++ b/clang/lib/Sema/TreeTransform.h +-@@ -7765,17 +7765,23 @@ +- NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); +- NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); +- +-- typedef TemplateArgumentLocContainerIterator< +-- DependentTemplateSpecializationTypeLoc> ArgIterator; +-- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), +-- 
ArgIterator(TL, TL.getNumArgs()), +-- NewTemplateArgs)) +-+ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); +-+ +-+ if (getDerived().TransformTemplateArguments(ArgsRange.begin(), +-+ ArgsRange.end(), NewTemplateArgs)) +- return QualType(); +-+ bool TemplateArgumentsChanged = !llvm::equal( +-+ ArgsRange, NewTemplateArgs.arguments(), +-+ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { +-+ return A.getArgument().structurallyEquals(B.getArgument()); +-+ }); - -- Record.AddSourceLocation(E->getNameLoc()); -- Record.AddStmt(E->getReplacement()); - diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h - --- a/clang/test/CodeGen/include/cuda.h - +++ b/clang/test/CodeGen/include/cuda.h -@@ -515,119 +481,6 @@ diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp - E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && - E->getOpcode() == Instruction::Load) { - Res = FindFirstInst(); --diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp ----- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp --+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp --@@ -2590,6 +2590,14 @@ -- if (R.mayWriteToMemory() && !InterleaveR) -- return; +- const DependentTemplateStorage &DTN = T->getDependentTemplateName(); - --+ // Do not narrow interleave groups if there are VectorPointer recipes and --+ // the plan was unrolled. The recipe implicitly uses VF from --+ // VPTransformState. --+ // TODO: Remove restriction once the VF for the VectorPointer offset is --+ // modeled explicitly as operand. --+ if (isa(&R) && Plan.getUF() > 1) --+ return; --+ -- // All other ops are allowed, but we reject uses that cannot be converted -- // when checking all allowed consumers (store interleave groups) below. -- if (!InterleaveR) --diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll ----- a/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll --+++ b/llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll --@@ -66,3 +66,91 @@ -- exit: -- ret void +- QualType Result = TL.getType(); +-- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { +-+ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || +-+ TemplateArgumentsChanged) { +- TemplateName Name = getDerived().RebuildTemplateName( +- SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), +- /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, +-diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h +---- a/clang/test/CodeGen/include/cuda.h +-+++ b/clang/test/CodeGen/include/cuda.h +-@@ -1,194 +0,0 @@ +--/* Minimal declarations for CUDA support. Testing purposes only. +-- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h +-- */ +--#pragma once +-- +--// Make this file work with nvcc, for testing compatibility. 
+-- +--#ifndef __NVCC__ +--#define __constant__ __attribute__((constant)) +--#define __device__ __attribute__((device)) +--#define __global__ __attribute__((global)) +--#define __host__ __attribute__((host)) +--#define __shared__ __attribute__((shared)) +--#define __managed__ __attribute__((managed)) +--#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) +-- +--struct dim3 { +-- unsigned x, y, z; +-- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} +--}; +-- +--// Host- and device-side placement new overloads. +--void *operator new(__SIZE_TYPE__, void *p) { return p; } +--void *operator new[](__SIZE_TYPE__, void *p) { return p; } +--__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } +--__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } +-- +--#define CUDA_VERSION 10100 +-- +--struct char1 { +-- char x; +-- __host__ __device__ char1(char x = 0) : x(x) {} +--}; +--struct char2 { +-- char x, y; +-- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} +--}; +--struct char4 { +-- char x, y, z, w; +-- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct uchar1 { +-- unsigned char x; +-- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} +--}; +--struct uchar2 { +-- unsigned char x, y; +-- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} +--}; +--struct uchar4 { +-- unsigned char x, y, z, w; +-- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct short1 { +-- short x; +-- __host__ __device__ short1(short x = 0) : x(x) {} +--}; +--struct short2 { +-- short x, y; +-- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} +--}; +--struct short4 { +-- short x, y, z, w; +-- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct ushort1 { +-- unsigned short x; +-- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} +--}; +--struct ushort2 { +-- unsigned short x, y; +-- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} +--}; +--struct ushort4 { +-- unsigned short x, y, z, w; +-- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct int1 { +-- int x; +-- __host__ __device__ int1(int x = 0) : x(x) {} +--}; +--struct int2 { +-- int x, y; +-- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} +--}; +--struct int4 { +-- int x, y, z, w; +-- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct uint1 { +-- unsigned x; +-- __host__ __device__ uint1(unsigned x = 0) : x(x) {} +--}; +--struct uint2 { +-- unsigned x, y; +-- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} +--}; +--struct uint3 { +-- unsigned x, y, z; +-- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} +--}; +--struct uint4 { +-- unsigned x, y, z, w; +-- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct longlong1 { +-- long long x; +-- __host__ __device__ longlong1(long long x = 0) : x(x) {} +--}; +--struct longlong2 { +-- long long x, y; +-- __host__ 
__device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} +--}; +--struct longlong4 { +-- long long x, y, z, w; +-- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct ulonglong1 { +-- unsigned long long x; +-- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} +--}; +--struct ulonglong2 { +-- unsigned long long x, y; +-- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} +--}; +--struct ulonglong4 { +-- unsigned long long x, y, z, w; +-- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct float1 { +-- float x; +-- __host__ __device__ float1(float x = 0) : x(x) {} +--}; +--struct float2 { +-- float x, y; +-- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} +--}; +--struct float4 { +-- float x, y, z, w; +-- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--struct double1 { +-- double x; +-- __host__ __device__ double1(double x = 0) : x(x) {} +--}; +--struct double2 { +-- double x, y; +-- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} +--}; +--struct double4 { +-- double x, y, z, w; +-- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} +--}; +-- +--typedef unsigned long long cudaTextureObject_t; +--typedef unsigned long long cudaSurfaceObject_t; +-- +--enum cudaTextureReadMode { +-- cudaReadModeNormalizedFloat, +-- cudaReadModeElementType +--}; +-- +--enum cudaSurfaceBoundaryMode { +-- cudaBoundaryModeZero, +-- cudaBoundaryModeClamp, +-- cudaBoundaryModeTrap +--}; +-- +--enum { +-- cudaTextureType1D, +-- cudaTextureType2D, +-- cudaTextureType3D, +-- cudaTextureTypeCubemap, +-- cudaTextureType1DLayered, +-- cudaTextureType2DLayered, +-- cudaTextureTypeCubemapLayered +--}; +-- +--struct textureReference {}; +--template +--struct __attribute__((device_builtin_texture_type)) texture +-- : public textureReference {}; +-- +--#endif // !__NVCC__ +-diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h +---- a/clang/test/CodeGen/Inputs/cuda.h +-+++ b/clang/test/CodeGen/Inputs/cuda.h +-@@ -0,0 +1,194 @@ +-+/* Minimal declarations for CUDA support. Testing purposes only. +-+ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h +-+ */ +-+#pragma once +-+ +-+// Make this file work with nvcc, for testing compatibility. +-+ +-+#ifndef __NVCC__ +-+#define __constant__ __attribute__((constant)) +-+#define __device__ __attribute__((device)) +-+#define __global__ __attribute__((global)) +-+#define __host__ __attribute__((host)) +-+#define __shared__ __attribute__((shared)) +-+#define __managed__ __attribute__((managed)) +-+#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) +-+ +-+struct dim3 { +-+ unsigned x, y, z; +-+ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} +-+}; +-+ +-+// Host- and device-side placement new overloads. 
+-+void *operator new(__SIZE_TYPE__, void *p) { return p; } +-+void *operator new[](__SIZE_TYPE__, void *p) { return p; } +-+__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } +-+__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } +-+ +-+#define CUDA_VERSION 10100 +-+ +-+struct char1 { +-+ char x; +-+ __host__ __device__ char1(char x = 0) : x(x) {} +-+}; +-+struct char2 { +-+ char x, y; +-+ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} +-+}; +-+struct char4 { +-+ char x, y, z, w; +-+ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct uchar1 { +-+ unsigned char x; +-+ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} +-+}; +-+struct uchar2 { +-+ unsigned char x, y; +-+ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} +-+}; +-+struct uchar4 { +-+ unsigned char x, y, z, w; +-+ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct short1 { +-+ short x; +-+ __host__ __device__ short1(short x = 0) : x(x) {} +-+}; +-+struct short2 { +-+ short x, y; +-+ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} +-+}; +-+struct short4 { +-+ short x, y, z, w; +-+ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct ushort1 { +-+ unsigned short x; +-+ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} +-+}; +-+struct ushort2 { +-+ unsigned short x, y; +-+ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} +-+}; +-+struct ushort4 { +-+ unsigned short x, y, z, w; +-+ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct int1 { +-+ int x; +-+ __host__ __device__ int1(int x = 0) : x(x) {} +-+}; +-+struct int2 { +-+ int x, y; +-+ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} +-+}; +-+struct int4 { +-+ int x, y, z, w; +-+ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct uint1 { +-+ unsigned x; +-+ __host__ __device__ uint1(unsigned x = 0) : x(x) {} +-+}; +-+struct uint2 { +-+ unsigned x, y; +-+ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} +-+}; +-+struct uint3 { +-+ unsigned x, y, z; +-+ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} +-+}; +-+struct uint4 { +-+ unsigned x, y, z, w; +-+ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct longlong1 { +-+ long long x; +-+ __host__ __device__ longlong1(long long x = 0) : x(x) {} +-+}; +-+struct longlong2 { +-+ long long x, y; +-+ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} +-+}; +-+struct longlong4 { +-+ long long x, y, z, w; +-+ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct ulonglong1 { +-+ unsigned long long x; +-+ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} +-+}; +-+struct ulonglong2 { +-+ unsigned long long x, y; +-+ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} +-+}; +-+struct ulonglong4 { +-+ 
unsigned long long x, y, z, w; +-+ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct float1 { +-+ float x; +-+ __host__ __device__ float1(float x = 0) : x(x) {} +-+}; +-+struct float2 { +-+ float x, y; +-+ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} +-+}; +-+struct float4 { +-+ float x, y, z, w; +-+ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+struct double1 { +-+ double x; +-+ __host__ __device__ double1(double x = 0) : x(x) {} +-+}; +-+struct double2 { +-+ double x, y; +-+ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} +-+}; +-+struct double4 { +-+ double x, y, z, w; +-+ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} +-+}; +-+ +-+typedef unsigned long long cudaTextureObject_t; +-+typedef unsigned long long cudaSurfaceObject_t; +-+ +-+enum cudaTextureReadMode { +-+ cudaReadModeNormalizedFloat, +-+ cudaReadModeElementType +-+}; +-+ +-+enum cudaSurfaceBoundaryMode { +-+ cudaBoundaryModeZero, +-+ cudaBoundaryModeClamp, +-+ cudaBoundaryModeTrap +-+}; +-+ +-+enum { +-+ cudaTextureType1D, +-+ cudaTextureType2D, +-+ cudaTextureType3D, +-+ cudaTextureTypeCubemap, +-+ cudaTextureType1DLayered, +-+ cudaTextureType2DLayered, +-+ cudaTextureTypeCubemapLayered +-+}; +-+ +-+struct textureReference {}; +-+template +-+struct __attribute__((device_builtin_texture_type)) texture +-+ : public textureReference {}; +-+ +-+#endif // !__NVCC__ +-diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu +---- a/clang/test/CodeGen/nvptx-surface.cu +-+++ b/clang/test/CodeGen/nvptx-surface.cu +-@@ -1,6 +1,6 @@ +- // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s +- // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s +--#include "include/cuda.h" +-+#include "Inputs/cuda.h" +- +- #include "__clang_cuda_texture_intrinsics.h" +- +-diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp +---- a/clang/test/SemaTemplate/dependent-names.cpp +-+++ b/clang/test/SemaTemplate/dependent-names.cpp +-@@ -458,3 +458,12 @@ +- }; +- int f(b ba) { return ba.add<0>(); } - } -+ --+define void @test_2xi64_with_wide_load(ptr noalias %data, ptr noalias %factor) { --+; CHECK-LABEL: define void @test_2xi64_with_wide_load( --+; CHECK-SAME: ptr noalias [[DATA:%.*]], ptr noalias [[FACTOR:%.*]]) { --+; CHECK-NEXT: [[ENTRY:.*]]: --+; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] --+; CHECK: [[VECTOR_PH]]: --+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] --+; CHECK: [[VECTOR_BODY]]: --+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] --+; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 2 --+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[INDEX]] --+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 0 --+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[TMP1]], i32 2 --+; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = load <2 x i64>, ptr [[TMP2]], align 8 --+; CHECK-NEXT: [[BROADCAST_SPLAT3:%.*]] = load <2 x i64>, ptr [[TMP3]], align 8 --+; CHECK-NEXT: [[TMP6:%.*]] = shl nsw i64 
[[INDEX]], 1 --+; CHECK-NEXT: [[TMP7:%.*]] = shl nsw i64 [[TMP0]], 1 --+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP6]] --+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP7]] --+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP8]], align 8 --+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> --+; CHECK-NEXT: [[STRIDED_VEC2:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> --+; CHECK-NEXT: [[WIDE_VEC3:%.*]] = load <4 x i64>, ptr [[TMP9]], align 8 --+; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> --+; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x i64> [[WIDE_VEC3]], <4 x i64> poison, <2 x i32> --+; CHECK-NEXT: [[TMP10:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[WIDE_LOAD]] --+; CHECK-NEXT: [[TMP11:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[WIDE_LOAD1]] --+; CHECK-NEXT: [[TMP15:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT]], [[STRIDED_VEC2]] --+; CHECK-NEXT: [[TMP16:%.*]] = mul <2 x i64> [[BROADCAST_SPLAT3]], [[STRIDED_VEC5]] --+; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x i64> [[TMP10]], <2 x i64> [[TMP15]], <4 x i32> --+; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP17]], <4 x i64> poison, <4 x i32> --+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP8]], align 8 --+; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x i64> [[TMP11]], <2 x i64> [[TMP16]], <4 x i32> --+; CHECK-NEXT: [[INTERLEAVED_VEC6:%.*]] = shufflevector <4 x i64> [[TMP18]], <4 x i64> poison, <4 x i32> --+; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC6]], ptr [[TMP9]], align 8 --+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 --+; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100 --+; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] --+; CHECK: [[MIDDLE_BLOCK]]: --+; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]] --+; CHECK: [[SCALAR_PH]]: --+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 100, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] --+; CHECK-NEXT: br label %[[LOOP:.*]] --+; CHECK: [[LOOP]]: --+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] --+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[FACTOR]], i64 [[IV]] --+; CHECK-NEXT: [[L_FACTOR:%.*]] = load i64, ptr [[ARRAYIDX]], align 8 --+; CHECK-NEXT: [[TMP13:%.*]] = shl nsw i64 [[IV]], 1 --+; CHECK-NEXT: [[DATA_0:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP13]] --+; CHECK-NEXT: [[L_0:%.*]] = load i64, ptr [[DATA_0]], align 8 --+; CHECK-NEXT: [[MUL_0:%.*]] = mul i64 [[L_FACTOR]], [[L_0]] --+; CHECK-NEXT: store i64 [[MUL_0]], ptr [[DATA_0]], align 8 --+; CHECK-NEXT: [[TMP14:%.*]] = or disjoint i64 [[TMP13]], 1 --+; CHECK-NEXT: [[DATA_1:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP14]] --+; CHECK-NEXT: [[L_1:%.*]] = load i64, ptr [[DATA_1]], align 8 --+; CHECK-NEXT: [[MUL_1:%.*]] = mul i64 [[L_FACTOR]], [[L_1]] --+; CHECK-NEXT: store i64 [[MUL_1]], ptr [[DATA_1]], align 8 --+; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 --+; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 100 --+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] --+; CHECK: [[EXIT]]: +-+namespace TransformDependentTemplates { +-+ template struct Test1 { +-+ template +-+ using Arg = typename T::template Arg; +-+ void 
f(Arg); +-+ void f(Arg); +-+ }; +-+} // namespace TransformDependentTemplates +-diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +---- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +-+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +-@@ -15391,12 +15391,20 @@ +- +- if (E->State == TreeEntry::SplitVectorize) { +- Res = FindLastInst(); +-+ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { +-+ for (auto *E : Entries) { +-+ auto *I = dyn_cast_or_null(E->VectorizedValue); +-+ if (!I) +-+ I = &getLastInstructionInBundle(E); +-+ if (Res->comesBefore(I)) +-+ Res = I; +-+ } +-+ } +- return *Res; +- } +- +- // Set insertpoint for gathered loads to the very first load. +-- if (E->State != TreeEntry::SplitVectorize && +-- GatheredLoadsEntriesFirst.has_value() && +-+ if (GatheredLoadsEntriesFirst.has_value() && +- E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && +- E->getOpcode() == Instruction::Load) { +- Res = FindFirstInst(); +-diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +---- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +-+++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll +-@@ -0,0 +1,99 @@ +-+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +-+; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s +-+ +-+define void @test(ptr %0, <8 x i8> %1) { +-+; CHECK-LABEL: define void @test( +-+; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { +-+; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 +-+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 +-+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 +-+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 +-+; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 +-+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> +-+; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 +-+; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> +-+; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> +-+; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) +-+; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 +-+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> +-+; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) +-+; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) +-+; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] +-+; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 -+; CHECK-NEXT: ret void -+; --+entry: --+ br label %loop --+ --+loop: --+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] --+ %arrayidx = getelementptr inbounds i64, ptr %factor, i64 %iv --+ %l.factor = load i64, ptr %arrayidx, align 8 --+ %1 = shl nsw i64 %iv, 1 --+ %data.0 = getelementptr inbounds i64, ptr %data, i64 %1 --+ %l.0 = load i64, ptr %data.0, align 8 --+ %mul.0 = mul i64 %l.factor, %l.0 --+ store 
i64 %mul.0, ptr %data.0, align 8 --+ %3 = or disjoint i64 %1, 1 --+ %data.1 = getelementptr inbounds i64, ptr %data, i64 %3 --+ %l.1 = load i64, ptr %data.1, align 8 --+ %mul.1 = mul i64 %l.factor, %l.1 --+ store i64 %mul.1, ptr %data.1, align 8 --+ %iv.next = add nuw nsw i64 %iv, 1 --+ %ec = icmp eq i64 %iv.next, 100 --+ br i1 %ec, label %exit, label %loop --+ --+exit: +-+ %3 = load i8, ptr %0, align 2 +-+ %4 = getelementptr i8, ptr %0, i64 13442 +-+ %5 = load i8, ptr %4, align 2 +-+ %6 = or i8 %5, %3 +-+ %7 = getelementptr i8, ptr %0, i64 13550 +-+ store i8 %6, ptr %7, align 2 +-+ %8 = extractelement <8 x i8> %1, i64 0 +-+ %9 = or i8 %5, %8 +-+ %10 = getelementptr i8, ptr %0, i64 13542 +-+ store i8 %9, ptr %10, align 2 +-+ %11 = getelementptr i8, ptr %0, i64 13438 +-+ %12 = load i8, ptr %11, align 2 +-+ %13 = or i8 %12, %3 +-+ %14 = getelementptr i8, ptr %0, i64 13546 +-+ store i8 %13, ptr %14, align 2 +-+ %15 = extractelement <8 x i8> %1, i64 2 +-+ %16 = or i8 %12, %15 +-+ %17 = getelementptr i8, ptr %0, i64 13538 +-+ store i8 %16, ptr %17, align 2 +-+ %18 = getelementptr i8, ptr %0, i64 13440 +-+ %19 = load i8, ptr %18, align 4 +-+ %20 = or i8 %19, %3 +-+ %21 = getelementptr i8, ptr %0, i64 13548 +-+ store i8 %20, ptr %21, align 4 +-+ %22 = extractelement <8 x i8> %1, i64 4 +-+ %23 = or i8 %19, %22 +-+ %24 = getelementptr i8, ptr %0, i64 13540 +-+ store i8 %23, ptr %24, align 4 +-+ %25 = getelementptr i8, ptr %0, i64 13436 +-+ %26 = load i8, ptr %25, align 4 +-+ %27 = getelementptr i8, ptr %0, i64 13444 +-+ %28 = load i8, ptr %27, align 4 +-+ %29 = or i8 %28, %26 +-+ %30 = getelementptr i8, ptr %0, i64 13544 +-+ store i8 %29, ptr %30, align 4 +-+ %31 = or i8 %26, %8 +-+ %32 = getelementptr i8, ptr %0, i64 13536 +-+ store i8 %31, ptr %32, align 4 +-+ %33 = getelementptr i8, ptr %0, i64 13443 +-+ %34 = load i8, ptr %33, align 1 +-+ %35 = or i8 %34, %3 +-+ %36 = getelementptr i8, ptr %0, i64 13551 +-+ store i8 %35, ptr %36, align 1 +-+ %37 = extractelement <8 x i8> %1, i64 7 +-+ %38 = or i8 %34, %37 +-+ %39 = getelementptr i8, ptr %0, i64 13543 +-+ store i8 %38, ptr %39, align 1 +-+ %40 = getelementptr i8, ptr %0, i64 13439 +-+ %41 = load i8, ptr %40, align 1 +-+ %42 = or i8 %41, %3 +-+ %43 = getelementptr i8, ptr %0, i64 13547 +-+ store i8 %42, ptr %43, align 1 +-+ %44 = extractelement <8 x i8> %1, i64 3 +-+ %45 = or i8 %41, %44 +-+ %46 = getelementptr i8, ptr %0, i64 13539 +-+ store i8 %45, ptr %46, align 1 +-+ %47 = getelementptr i8, ptr %0, i64 13441 +-+ %48 = load i8, ptr %47, align 1 +-+ %49 = or i8 %48, %3 +-+ %50 = getelementptr i8, ptr %0, i64 13549 +-+ store i8 %49, ptr %50, align 1 +-+ %51 = extractelement <8 x i8> %1, i64 5 +-+ %52 = or i8 %48, %51 +-+ %53 = getelementptr i8, ptr %0, i64 13541 +-+ store i8 %52, ptr %53, align 1 +-+ %54 = getelementptr i8, ptr %0, i64 13437 +-+ %55 = load i8, ptr %54, align 1 +-+ %56 = or i8 %55, %3 +-+ %57 = getelementptr i8, ptr %0, i64 13545 +-+ store i8 %56, ptr %57, align 1 +-+ %58 = or i8 %55, %8 +-+ %59 = getelementptr i8, ptr %0, i64 13537 +-+ store i8 %58, ptr %59, align 1 -+ ret void -+} - diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll - --- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll - +++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll -@@ -731,18 +584,3 @@ diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-nod - + 
store i8 %58, ptr %59, align 1 - + ret void - +} --diff -ruN --strip-trailing-cr a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp ----- a/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp --+++ b/llvm/unittests/CodeGen/X86MCInstLowerTest.cpp --@@ -151,9 +151,10 @@ -- MachineModuleInfoWrapperPass *MMIWP = -- new MachineModuleInfoWrapperPass(TM.get(), &*MCFoo); -- --- legacy::PassManager PassMgrF; -- SmallString<1024> Buf; -- llvm::raw_svector_ostream OS(Buf); --+ legacy::PassManager PassMgrF; --+ -- AsmPrinter *Printer = -- addPassesToEmitFile(PassMgrF, OS, CodeGenFileType::AssemblyFile, MMIWP); -- PassMgrF.run(*M); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index c3bcd53..73450ce 100644 +index 73450ce..7993194 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "cd54cb062bba9c90a8f3723bf66caa7effbcf259" -- LLVM_SHA256 = "4054d0f174e80e9d0ca62af465a60252faabe4c7163612c0fdcb86898f7f266a" -+ LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" -+ LLVM_SHA256 = "4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" +- LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" +- LLVM_SHA256 = "4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" ++ LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" ++ LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index 104a8de57175ad..503b82b5d33179 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "1ba08b6822b3bce9ce4acb3b839b05b3266ca0bc" - SHARDY_SHA256 = "6930a383ed9b516041f08ae948e86a3926a2cf11c7457fa50950f298275c6a84" + SHARDY_COMMIT = "2bd86e4ef697536b0683149a93022e21d8d5e6d3" + SHARDY_SHA256 = "a3b3672c72cadd8cafd837d7da219cebf97c5312e545ed1ebc639e71a47b60e5" tf_http_archive( name = "shardy", diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc index 5fd0d4c5a9429b..4075a4977572a8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc @@ -125,7 +125,7 @@ struct RewriteFp8TruncFPattern : public Fp8OpRewritePattern { size_t pos; auto insert = mlir::dyn_cast(op->use_begin()->getOwner()); - if (!insert || insert.getSource() != op->getResult(0) || + if (!insert || insert.getValueToStore() != op->getResult(0) || !matchPos(insert, &pos) || !insert.getDest().hasOneUse()) { return std::nullopt; } diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/fuse_loops.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/fuse_loops.cc index 43fcbb9c96faac..561408e034ad87 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/fuse_loops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/fuse_loops.cc @@ -148,7 +148,7 @@ void FuseExtractInsertLoopPair(MLIRContext* mlir_context, LoopOp insert_loop, auto vector_cst = insert_loop.getInits().back(); 
   insert_loop->replaceAllUsesWith(ValueRange(vector_cst));
   extract_loop->replaceAllUsesWith(new_loop.getResults());
-  extract.replaceAllUsesWith(insert.getSource());
+  extract.replaceAllUsesWith(insert.getValueToStore());
   auto insert_loop_yield =
       mlir::dyn_cast(insert_loop.getRegion().front().back());
   rewriter.eraseOp(insert_loop_yield);
diff --git a/third_party/xla/xla/codegen/emitters/transforms/flatten_tensors.cc b/third_party/xla/xla/codegen/emitters/transforms/flatten_tensors.cc
index 98f6dcedf68b39..8ec80544f72a7b 100644
--- a/third_party/xla/xla/codegen/emitters/transforms/flatten_tensors.cc
+++ b/third_party/xla/xla/codegen/emitters/transforms/flatten_tensors.cc
@@ -411,7 +411,7 @@ struct RewriteVectorInsert : OpRewritePattern {
                       GetFlattenedType(vector_type), vector)
                       .getResult(0);
     auto new_insert =
-        b.create(op.getSource(), vector_1D, linear_index);
+        b.create(op.getValueToStore(), vector_1D, linear_index);
     auto cast_to_orig_type = b.create(
         vector_type, new_insert.getResult());
     rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
diff --git a/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc b/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc
index 9fbfcca9e47c44..0e4524087757a5 100644
--- a/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc
+++ b/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc
@@ -466,7 +466,7 @@ struct FoldVectorInsertExtractPairs
     // Check that the value that we insert is produced by a vector.extract.
     auto extract = mlir::dyn_cast_or_null(
-        insert.getSource().getDefiningOp());
+        insert.getValueToStore().getDefiningOp());
     if (!extract || !extract.hasDynamicPosition() || !extract->hasOneUse()) {
       return rewriter.notifyMatchFailure(insert,
                                          "no single-use vector.extract found");

From 09dfb10a541041bd3ef95ce9569838ddab177292 Mon Sep 17 00:00:00 2001
From: Vlad Sytchenko
Date: Wed, 9 Apr 2025 22:45:40 -0700
Subject: [PATCH 0485/1324] [XLA] Preserve backend config when removing
 computation parameters

There is a rollover when a parameter in the middle is removed, so we also
need to carry over any extra attributes attached to the parameters.
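A condensed sketch of the behavior under test (it follows the regression test
added in hlo_computation_test.cc below; "hlo" is the HLO module string defined
in that test):

    // Parameter 1 of the called computation carries a backend config;
    // parameter 0 is dead.
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                            ParseAndReturnUnverifiedModule(hlo));
    HloComputation* computation =
        module->entry_computation()->root_instruction()->to_apply();
    // RemoveParameter(0) shifts parameter 1 into slot 0 by recreating the
    // instruction. SetupDerivedInstruction now copies the backend config
    // (and other derived metadata) onto the recreated parameter, so the
    // config survives the shift.
    TF_ASSERT_OK(computation->RemoveParameter(0));
    EXPECT_TRUE(computation->parameter_instruction(0)->has_backend_config());
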
PiperOrigin-RevId: 745874780 --- third_party/xla/xla/hlo/ir/hlo_computation.cc | 1 + third_party/xla/xla/service/BUILD | 2 ++ .../xla/xla/service/hlo_computation_test.cc | 32 +++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index ea48be15dea2bb..d345d97fa3910b 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -455,6 +455,7 @@ absl::Status HloComputation::RemoveParameter(int64_t param_no) { HloInstruction* new_instr = AddInstructionInternal(HloInstruction::CreateParameter( param_no, param_instruction->shape(), StrCat("param_", param_no))); + param_instruction->SetupDerivedInstruction(new_instr); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(ForceRemoveInstruction(param_instruction)); diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 86ff582f47a9be..afd16b7671a801 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -3646,9 +3646,11 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc index 9a0c3cdb786a7f..808ee67141bde8 100644 --- a/third_party/xla/xla/service/hlo_computation_test.cc +++ b/third_party/xla/xla/service/hlo_computation_test.cc @@ -21,9 +21,11 @@ limitations under the License. #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_clone_context.h" @@ -41,6 +43,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -1041,5 +1044,34 @@ TEST_F(HloComputationTest, ToStringWhileCreatingReplacements) { EXPECT_EQ(counter, entry->instruction_count()); } +TEST_F(HloComputationTest, RemoveParameterWithBackendConfig) { + const absl::string_view hlo = R"( +ENTRY main { + arg.0 = s32[] parameter(0) + arg.1 = s32[] parameter(1) + ROOT call.0 = (s32[]) call(arg.0, arg.1), to_apply={ + arg.0 = s32[] parameter(0) + arg.1 = s32[] parameter(1), backend_config={"config" : []} + ROOT tuple.0 = tuple(arg.1) + } +} + )"; + // Since we remove the called computation parameter, we also need to remove + // the operand from the callee, but that's not possible, due to all related + // APIs being private/protected, hence the module will be left in an illegal + // state and not verifiable. 
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule(hlo)); + + HloInstruction* call0 = module->entry_computation()->root_instruction(); + ASSERT_EQ(call0->opcode(), HloOpcode::kCall); + HloComputation* computation = call0->to_apply(); + ASSERT_TRUE(!computation->parameter_instruction(0)->has_backend_config()); + // Parameter 0 is dead and safe to remove. + TF_ASSERT_OK(computation->RemoveParameter(0)); + // Parameter 1 shifted to parameter 0 and should preserve its backend config. + EXPECT_TRUE(computation->parameter_instruction(0)->has_backend_config()); +} + } // namespace } // namespace xla From 9697fd2d49c0be9b30f1acf39785d1fc2196b428 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Wed, 9 Apr 2025 23:07:37 -0700 Subject: [PATCH 0486/1324] Reverts 88cbfdecee62ee0cc00a03e009bd7018db8d0a74 PiperOrigin-RevId: 745880173 --- third_party/xla/xla/client/local_client.cc | 27 +- third_party/xla/xla/client/local_client.h | 9 + third_party/xla/xla/pjrt/BUILD | 8 + .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 35 --- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 10 - .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 5 +- .../xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc | 11 +- .../pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc | 9 +- .../xla/pjrt/pjrt_stream_executor_client.cc | 265 ++++++++++++------ .../xla/pjrt/pjrt_stream_executor_client.h | 34 ++- .../xla/pjrt/stream_executor_executable.cc | 123 +++++++- .../xla/xla/pjrt/stream_executor_executable.h | 99 +++++-- .../xla/xla/service/buffer_assignment.cc | 5 +- 13 files changed, 460 insertions(+), 180 deletions(-) diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index 9296eeb2671521..ce57cec6e9d4c9 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -473,21 +473,34 @@ LocalClient::CompileAheadOfTime( absl::StatusOr> LocalClient::Load( const std::string& serialized_aot_result, const ExecutableBuildOptions& options) { - TF_ASSIGN_OR_RETURN(ExecutableBuildOptions updated_options, - UpdateBuildOptions(options, default_device_ordinal())); - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend().stream_executor(updated_options.device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, Compiler::GetForPlatform(platform())); TF_ASSIGN_OR_RETURN( std::unique_ptr aot_result, compiler->LoadAotCompilationResult(serialized_aot_result)); + return LoadInternal(std::move(aot_result), compiler.get(), options); +} + +absl::StatusOr> LocalClient::Load( + std::unique_ptr aot_result, + const ExecutableBuildOptions& options) { + TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, + Compiler::GetForPlatform(platform())); + return LoadInternal(std::move(aot_result), compiler.get(), options); +} + +absl::StatusOr> LocalClient::LoadInternal( + std::unique_ptr aot_result, Compiler* compiler, + const ExecutableBuildOptions& options) { + TF_ASSIGN_OR_RETURN(ExecutableBuildOptions updated_options, + UpdateBuildOptions(options, default_device_ordinal())); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + backend().stream_executor(updated_options.device_ordinal())); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - std::move(*aot_result).LoadExecutable(compiler.get(), executor)); + std::move(*aot_result).LoadExecutable(compiler, executor)); return std::make_unique(std::move(executable), local_service_->mutable_backend(), updated_options); diff --git a/third_party/xla/xla/client/local_client.h 
b/third_party/xla/xla/client/local_client.h index c9ee317bc42e5a..c687766fcc37b8 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -174,6 +174,11 @@ class LocalClient : public Client { const std::string& serialized_aot_result, const ExecutableBuildOptions& options); + // Variant of `Load()` that accepts an AotCompilationResult. + absl::StatusOr> Load( + std::unique_ptr aot_result, + const ExecutableBuildOptions& options); + // Copy the literal data to the device with the given ordinal and return as a // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. If null, the default memory allocator for the @@ -244,6 +249,10 @@ class LocalClient : public Client { private: LocalService* local_service_; + + absl::StatusOr> LoadInternal( + std::unique_ptr aot_result, Compiler* compiler, + const ExecutableBuildOptions& options); }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index dfd717ac84a7ca..afcfe8534c065f 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -500,12 +500,18 @@ cc_library( srcs = ["stream_executor_executable.cc"], hdrs = ["stream_executor_executable.h"], deps = [ + ":host_memory_spaces", ":pjrt_common", ":pjrt_executable", ":stream_executor_executable_proto_cc", + "//xla:shape_util", + "//xla:util", + "//xla/client:local_client", "//xla/hlo/ir:hlo", "//xla/service:compiler", + "//xla/service:executable", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -535,6 +541,7 @@ cc_library( ":pjrt_future", ":pjrt_stream_executor_device_description", ":semaphore", + ":stream_executor_executable", ":tracked_device_buffer", ":transpose", ":utils", @@ -577,6 +584,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 12393ed9910c09..d9a0e2df827359 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -897,41 +897,6 @@ StreamExecutorGpuClient::LoadSerialized(absl::string_view serialized, load_options); } -absl::StatusOr> -StreamExecutorGpuClient::Load(std::unique_ptr executable) { - auto se_executable = absl::WrapUnique( - tensorflow::down_cast(executable.release())); - - CompileOptions compile_options = se_executable->compile_options(); - CompileOptions input_options = compile_options; - TF_RETURN_IF_ERROR(compile_options.ApplyAllOptionOverrides()); - TF_ASSIGN_OR_RETURN( - ExecutableExtras extras, - UpdateCompileOptionsAndGetExecutableExtras(&compile_options)); - - // Load Executable from AOT compilation result. 
- std::vector> local_executables; - local_executables.reserve(se_executable->aot_executables().size()); - for (std::unique_ptr& aot_executable : - se_executable->aot_executables()) { - TF_ASSIGN_OR_RETURN(std::string serialized, - aot_executable->SerializeAsString()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr local_executable, - client()->Load(serialized, compile_options.executable_build_options)); - local_executables.push_back(std::move(local_executable)); - } - bool parameter_is_tupled_arguments = - compile_options.parameter_is_tupled_arguments; - auto ret = std::make_unique( - std::move(local_executables), parameter_is_tupled_arguments, - std::move(extras.device_assignment), std::move(input_options), - std::move(extras.addressable_device_logical_ids), - std::move(extras.addressable_devices), this); - TF_RETURN_IF_ERROR(ret->SetUpDonation(parameter_is_tupled_arguments)); - return std::unique_ptr(std::move(ret)); -} - namespace { #if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index e920edaf33c580..b2499d6659f1e1 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -140,16 +140,6 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { return &topology_; } - absl::StatusOr> Load( - std::unique_ptr executable, - const LoadOptions& load_options) override { - return absl::WrapUnique( - tensorflow::down_cast(executable.release())); - } - - absl::StatusOr> Load( - std::unique_ptr executable); - absl::StatusOr> LoadSerialized( absl::string_view serialized, std::optional options, const LoadOptions& load_options); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index b86a100d53c59f..03bac8cd5358ca 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -1154,7 +1154,7 @@ TEST(StreamExecutorGpuClientTest, GetDeviceFabricInfo) { &executor->GetDeviceDescription())) == 9) { auto fabric_info = GetDeviceFabricInfo(executor->device_ordinal()); if (fabric_info.ok()) { - EXPECT_FALSE(true); + ADD_FAILURE(); } } } @@ -1925,8 +1925,7 @@ TEST(StreamExecutorGpuClientTest, MlirParameterLayoutFromOptionsIsSetInHlo) { xla::CompileOptions options; options.argument_layouts = { {ShapeUtil::MakeShapeWithDenseLayout(S32, {2, 2, 2}, {0, 2, 1})}}; - TF_ASSERT_OK_AND_ASSIGN(auto executable, - client->CompileAndLoad(*module, options)); + TF_ASSERT_OK_AND_ASSIGN(auto executable, client->Compile(*module, options)); TF_ASSERT_OK_AND_ASSIGN(auto modules, executable->GetHloModules()); auto first_param_layout = diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index 50e910f58e02ef..93faf18a1706ed 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -126,7 +126,7 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, if (!options.target_config) { if (client != nullptr) { TF_RETURN_IF_ERROR(IsValidTopologyAndClientForCompile(topology, client)); - return client->CompileAndLoad(computation, options); + return client->Compile(computation, options); } const auto& gpu_topology = tensorflow::down_cast( @@ -172,21 +172,16 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, const int num_partitions = 
hlo_module->config().num_partitions(); const std::string name = hlo_module->name(); const std::string fingerprint = hlo_module->GetFingerprint128(); - const int num_outputs = hlo_module->result_shape().IsTuple() - ? hlo_module->result_shape().tuple_shapes_size() - : 1; auto unique_module_group = std::make_unique(std::move(hlo_module)); TF_ASSIGN_OR_RETURN( std::vector> aot_results, gpu_compiler->CompileAheadOfTime(std::move(unique_module_group), aot_options)); - std::vector> output_memory_kinds(1); - output_memory_kinds[0].resize(num_outputs, - StreamExecutorGpuHbmMemorySpace::kKind); return std::make_unique( std::move(input_options), std::move(aot_results), num_replicas, - num_partitions, name, fingerprint, std::move(output_memory_kinds)); + num_partitions, name, fingerprint, + /*default_memory_kind=*/StreamExecutorGpuHbmMemorySpace::kKind); } absl::StatusOr> diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc index f18c509108498c..956c5458150297 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc @@ -99,8 +99,9 @@ TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileMlirAndLoad) { TF_ASSERT_OK_AND_ASSIGN(auto executable, compiler.Compile(opts, mlir_module.get(), *topology, /*client=*/nullptr)); - TF_ASSERT_OK_AND_ASSIGN(auto loaded_executable, - se_client->Load(std::move(executable))); + TF_ASSERT_OK_AND_ASSIGN( + auto loaded_executable, + se_client->Load(std::move(executable), LoadOptions())); TF_ASSERT_OK_AND_ASSIGN( std::vector>> result, @@ -129,7 +130,7 @@ TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileXlaAndLoad) { compiler.Compile(opts, computation, *topology, /*client=*/nullptr)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr loaded_executable, - se_client->Load(std::move(executable))); + se_client->Load(std::move(executable), LoadOptions())); TF_ASSERT_OK_AND_ASSIGN( std::vector>> result, loaded_executable->Execute(/*argument_handles=*/{{}}, {})); @@ -192,7 +193,7 @@ TEST(StreamExecutorGpuCompilerTest, SuccessSerializeDeserialize) { compiler.Compile(opts, computation, *topology, /*client=*/nullptr)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr loaded_executable, - se_client->Load(std::move(executable))); + se_client->Load(std::move(executable), LoadOptions())); // Serialize the executable and deserialize it without failure. TF_ASSERT_OK_AND_ASSIGN(std::string serialized_executable, diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 6a8a97eb674121..596f5a9edb8d1a 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -87,6 +87,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -118,6 +119,7 @@ limitations under the License. 
#include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/profiling/device_time_measurement.h" #include "xla/pjrt/semaphore.h" +#include "xla/pjrt/stream_executor_executable.h" #include "xla/pjrt/tracked_device_buffer.h" #include "xla/pjrt/transpose.h" #include "xla/pjrt/utils.h" @@ -3183,8 +3185,23 @@ PjRtStreamExecutorLoadedExecutable::GetOutputMemoryKinds() const { return out; } -void PjRtStreamExecutorClient::UpdateCompileOptions(CompileOptions* options) { +absl::Status PjRtStreamExecutorClient::UpdateCompileOptions( + CompileOptions* options) { + return UpdateCompileOptionsInternal(options, /*returned_extras=*/nullptr); +} + +absl::StatusOr +PjRtStreamExecutorClient::UpdateCompileOptionsAndGetExecutableExtras( + CompileOptions* options) { + ExecutableExtras extras; + TF_RETURN_IF_ERROR(UpdateCompileOptionsInternal(options, &extras)); + return extras; +} + +absl::Status PjRtStreamExecutorClient::UpdateCompileOptionsInternal( + CompileOptions* options, ExecutableExtras* returned_extras) { ExecutableBuildOptions& build_options = options->executable_build_options; + const int original_device_ordinal = build_options.device_ordinal(); if (!build_options.compile_thread_pool()) { build_options.set_compile_thread_pool(thread_pool()); } @@ -3224,15 +3241,17 @@ void PjRtStreamExecutorClient::UpdateCompileOptions(CompileOptions* options) { if (build_options.device_ordinal() < 0) { build_options.set_device_ordinal(0); } -} - -absl::StatusOr -PjRtStreamExecutorClient::UpdateCompileOptionsAndGetExecutableExtras( - CompileOptions* options) { - const int original_device_ordinal = - options->executable_build_options.device_ordinal(); - UpdateCompileOptions(options); + // We need the information of device assignment for + // 1) XLA GPU shard autotuning, as the process count and the current process + // index are needed; + // 2) getting executable extras, as the addressable devices are needed. 
+ const bool use_xla_gpu_shard_autotuning = + build_options.has_debug_options() && + build_options.debug_options().xla_gpu_shard_autotuning(); + if (!use_xla_gpu_shard_autotuning && returned_extras == nullptr) { + return absl::OkStatus(); + } ExecutableExtras extras; std::shared_ptr& device_assignment = @@ -3284,7 +3303,6 @@ PjRtStreamExecutorClient::UpdateCompileOptionsAndGetExecutableExtras( device_assignment->ToString()); } - ExecutableBuildOptions& build_options = options->executable_build_options; if (original_device_ordinal < 0) { build_options.set_device_ordinal( addressable_devices.front()->local_hardware_id().value()); @@ -3293,17 +3311,20 @@ PjRtStreamExecutorClient::UpdateCompileOptionsAndGetExecutableExtras( build_options.set_process_index(*this_process_index); build_options.set_process_count(all_process_indices.size()); } - return extras; + if (returned_extras != nullptr) { + *returned_extras = std::move(extras); + } + return absl::OkStatus(); } -absl::StatusOr> +absl::StatusOr> PjRtStreamExecutorClient::CompileInternal( const XlaComputation& computation, const std::vector& argument_layout_pointers, LayoutCanonicalizationCallback layout_canonicalization_callback, CompileOptions options) { - tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile"); - VLOG(1) << "PjRtStreamExecutorClient::Compile"; + tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::CompileInternal"); + VLOG(1) << "PjRtStreamExecutorClient::CompileInternal"; if (key_value_store().has_value() && !options.executable_build_options.key_value_store()) { options.executable_build_options.set_key_value_store(*key_value_store()); @@ -3311,14 +3332,7 @@ PjRtStreamExecutorClient::CompileInternal( auto input_options = options; TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides()); - - TF_ASSIGN_OR_RETURN(ExecutableExtras extras, - UpdateCompileOptionsAndGetExecutableExtras(&options)); - std::shared_ptr& device_assignment = - extras.device_assignment; - std::vector& - addressable_device_logical_ids = extras.addressable_device_logical_ids; - std::vector& addressable_devices = extras.addressable_devices; + TF_RETURN_IF_ERROR(UpdateCompileOptions(&options)); // It is important to set the canonicalization callback after creating // a copy of the options so that the executable's options remain without @@ -3333,26 +3347,39 @@ PjRtStreamExecutorClient::CompileInternal( client()->Compile(computation, argument_layout_pointers, options.executable_build_options)); - auto executable = std::make_unique( - std::move(local_executables), options.parameter_is_tupled_arguments, - std::move(device_assignment), std::move(input_options), - std::move(addressable_device_logical_ids), std::move(addressable_devices), - this); + return BuildPjRtExecutable(std::move(local_executables), input_options); +} - TF_RETURN_IF_ERROR( - executable->SetUpDonation(options.parameter_is_tupled_arguments)); - const auto& ex_options = options.executable_build_options; - if (ex_options.has_debug_options() && - ex_options.debug_options().xla_gpu_dump_hlo_unoptimized_snapshots()) { - executable->SetInputHloSnapshotBits( - computation.proto(), options.executable_build_options.debug_options()); - } - return std::unique_ptr(std::move(executable)); +absl::StatusOr> +PjRtStreamExecutorClient::Compile(const XlaComputation& computation, + CompileOptions options) { + std::vector argument_layout_pointers; + const ExecutableBuildOptions& build_options = + options.executable_build_options; + const bool allow_auto_layout = + build_options.has_debug_options() && 
+ build_options.debug_options().xla_pjrt_allow_auto_layout_in_hlo(); + TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( + computation, + [local_client = client(), + allow_auto_layout](Shape shape) -> absl::StatusOr { + if (allow_auto_layout && !shape.has_layout()) { + return shape; + } + return local_client->backend() + .transfer_manager() + ->ChooseCompactLayoutForShape(shape); + }, + options.argument_layouts, &options.executable_build_options, + &argument_layout_pointers)); + return CompileInternal(computation, argument_layout_pointers, + /* layout_canonicalization_callback = */ nullptr, + options); } -absl::StatusOr> -PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, - CompileOptions options) { +absl::StatusOr> +PjRtStreamExecutorClient::Compile(mlir::ModuleOp module, + CompileOptions options) { XlaComputation xla_computation; const ExecutableBuildOptions& exec_build_options = options.executable_build_options; @@ -3364,7 +3391,7 @@ PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, // If the compile options specify argument layout, then let's // fall back to using the options to determine layouts. if (options.argument_layouts) { - return CompileAndLoad(xla_computation, options); + return Compile(xla_computation, options); } TF_ASSIGN_OR_RETURN(std::vector arg_layout_modes, @@ -3414,28 +3441,17 @@ PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, absl::StatusOr> PjRtStreamExecutorClient::CompileAndLoad(const XlaComputation& computation, CompileOptions options) { - std::vector argument_layout_pointers; - const ExecutableBuildOptions& build_options = - options.executable_build_options; - const bool allow_auto_layout = - build_options.has_debug_options() && - build_options.debug_options().xla_pjrt_allow_auto_layout_in_hlo(); - TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( - computation, - [local_client = client(), - allow_auto_layout](Shape shape) -> absl::StatusOr { - if (allow_auto_layout && !shape.has_layout()) { - return shape; - } - return local_client->backend() - .transfer_manager() - ->ChooseCompactLayoutForShape(shape); - }, - options.argument_layouts, &options.executable_build_options, - &argument_layout_pointers)); - return CompileInternal(computation, argument_layout_pointers, - /* layout_canonicalization_callback = */ nullptr, - options); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + Compile(computation, options)); + return Load(std::move(executable), LoadOptions()); +} + +absl::StatusOr> +PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module, + CompileOptions options) { + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + Compile(module, options)); + return Load(std::move(executable), LoadOptions()); } absl::StatusOr PjRtStreamExecutorClient::SerializeExecutable( @@ -3472,21 +3488,55 @@ absl::StatusOr PjRtStreamExecutorClient::SerializeExecutable( return proto.SerializeAsString(); } -absl::StatusOr> -PjRtStreamExecutorClient::LoadSerializedExecutable( - absl::string_view serialized, std::optional options, - const LoadOptions& load_options) { +absl::StatusOr> +PjRtStreamExecutorClient::BuildPjRtExecutable( + std::vector> local_executables, + CompileOptions compile_options) { + if (local_executables.empty()) { + return Internal("No local executable"); + } + if (local_executables.size() != 1) { + return Unimplemented("Multiple executables are not supported"); + } + Executable* built_executable = local_executables[0]->executable(); + if (!built_executable->has_module()) { + return 
absl::InternalError("Executable does not have HLO modules."); + } + const auto& hlo_module = built_executable->module(); + + const int num_replicas = hlo_module.config().replica_count(); + const int num_partitions = hlo_module.config().num_partitions(); + const std::string name = hlo_module.name(); + const std::string fingerprint = hlo_module.GetFingerprint128(); + + return std::make_unique( + std::move(compile_options), std::move(local_executables), client_, + num_replicas, num_partitions, name, fingerprint, + memory_spaces()[0]->kind()); +} + +absl::StatusOr> +PjRtStreamExecutorClient::DeserializeExecutable( + absl::string_view serialized, + std::optional compile_options) { + TF_ASSIGN_OR_RETURN( + auto local_executables_and_options, + DeserializeToLocalExecutable(serialized, compile_options)); + + return BuildPjRtExecutable(std::move(local_executables_and_options.first), + local_executables_and_options.second); +} + +absl::StatusOr< + std::pair>, CompileOptions>> +PjRtStreamExecutorClient::DeserializeToLocalExecutable( + absl::string_view serialized, std::optional options) { ExecutableAndOptionsProto proto; if (serialized.size() > std::numeric_limits::max()) { - return Internal( - "PjRtStreamExecutorClient::DeserializeExecutable proto too large " - "(>2GB)"); + return Internal("Proto is too large (>2GB)"); } if (!proto.ParseFromArray(serialized.data(), serialized.size())) { - return Internal( - "PjRtStreamExecutorClient::DeserializeExecutable proto " - "deserialization " - "failed"); + return Internal("Proto deserialization failed"); } CompileOptions compile_options; @@ -3496,11 +3546,39 @@ PjRtStreamExecutorClient::LoadSerializedExecutable( TF_ASSIGN_OR_RETURN(compile_options, CompileOptions::FromProto(proto.compile_options())); } - auto input_options = compile_options; tsl::profiler::TraceMe traceme( - "PjRtStreamExecutorClient::DeserializeExecutable"); - VLOG(1) << "PjRtStreamExecutorClient::DeserializeExecutable"; + "PjRtStreamExecutorClient::DeserializeToLocalExecutable"); + VLOG(1) << "PjRtStreamExecutorClient::DeserializeToLocalExecutable"; + + std::string str = std::move(*proto.mutable_serialized_executable()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr loaded, + client()->Load(str, compile_options.executable_build_options)); + + std::vector> local_executables; + local_executables.push_back(std::move(loaded)); + + return std::make_pair(std::move(local_executables), compile_options); +} + +absl::StatusOr> +PjRtStreamExecutorClient::LoadSerializedExecutable( + absl::string_view serialized, std::optional options, + const LoadOptions& load_options) { + TF_ASSIGN_OR_RETURN(auto local_executables_and_options, + DeserializeToLocalExecutable(serialized, options)); + return LoadInternal(std::move(local_executables_and_options.first), + local_executables_and_options.second); +} + +absl::StatusOr> +PjRtStreamExecutorClient::LoadInternal( + std::vector> local_executables, + CompileOptions compile_options) { + auto input_options = compile_options; + + TF_RETURN_IF_ERROR(compile_options.ApplyAllOptionOverrides()); TF_ASSIGN_OR_RETURN( ExecutableExtras extras, @@ -3511,13 +3589,14 @@ PjRtStreamExecutorClient::LoadSerializedExecutable( addressable_device_logical_ids = extras.addressable_device_logical_ids; std::vector& addressable_devices = extras.addressable_devices; - std::string str = std::move(*proto.mutable_serialized_executable()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr loaded, - client()->Load(str, compile_options.executable_build_options)); - - std::vector> local_executables; - 
local_executables.push_back(std::move(loaded)); + const auto& ex_options = compile_options.executable_build_options; + const bool xla_gpu_dump_hlo_unoptimized_snapshots = + ex_options.has_debug_options() && + ex_options.debug_options().xla_gpu_dump_hlo_unoptimized_snapshots(); + HloModuleProto hlo_module_proto; + if (xla_gpu_dump_hlo_unoptimized_snapshots) { + hlo_module_proto = local_executables[0]->executable()->module().ToProto(); + } auto executable = std::make_unique( std::move(local_executables), @@ -3528,9 +3607,29 @@ PjRtStreamExecutorClient::LoadSerializedExecutable( TF_RETURN_IF_ERROR( executable->SetUpDonation(compile_options.parameter_is_tupled_arguments)); + if (xla_gpu_dump_hlo_unoptimized_snapshots) { + executable->SetInputHloSnapshotBits( + std::move(hlo_module_proto), + compile_options.executable_build_options.debug_options()); + } return std::unique_ptr(std::move(executable)); } +absl::StatusOr> +PjRtStreamExecutorClient::Load(std::unique_ptr executable, + const LoadOptions& load_options) { + auto se_executable = absl::WrapUnique( + tensorflow::down_cast(executable.release())); + CompileOptions compile_options = se_executable->compile_options(); + + tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Load"); + VLOG(1) << "PjRtStreamExecutorClient::Load"; + + TF_ASSIGN_OR_RETURN(auto local_executables, se_executable->ConsumeExecutable( + client(), compile_options)); + return LoadInternal(std::move(local_executables), compile_options); +} + bool PjRtStreamExecutorClient::IsDmaMapped(const void* data_start, int64_t transfer_size) { absl::MutexLock lock(&dma_maps_mutex_); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 5f1cd62924ed4d..bf3a98c0998eb6 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -277,14 +277,22 @@ class PjRtStreamExecutorClient : public PjRtClient { absl::StatusOr GetDefaultLayout( PrimitiveType element_type, absl::Span dims) override; + absl::StatusOr> Compile( + const XlaComputation& computation, CompileOptions options) override; absl::StatusOr> CompileAndLoad( const XlaComputation& computation, CompileOptions options) override; + absl::StatusOr> Compile( + mlir::ModuleOp mlir_module, CompileOptions options) override; absl::StatusOr> CompileAndLoad( mlir::ModuleOp mlir_module, CompileOptions options) override; virtual absl::StatusOr SerializeExecutable( const PjRtLoadedExecutable& executable) const; + absl::StatusOr> DeserializeExecutable( + absl::string_view serialized, + std::optional options) override; + // For PjRtStreamExecutorClient, `options` is mandatory. // This function returns an InvalidArgument error if `std::nullopt` is passed. // TODO(b/237720161): make it actually optional @@ -293,6 +301,10 @@ class PjRtStreamExecutorClient : public PjRtClient { std::optional options, const LoadOptions& load_options) override; + absl::StatusOr> Load( + std::unique_ptr executable, + const LoadOptions& load_options) override; + absl::StatusOr> GetHloCostAnalysis() const override; @@ -413,18 +425,36 @@ class PjRtStreamExecutorClient : public PjRtClient { }; // Updates `options` for compilation. - void UpdateCompileOptions(CompileOptions* options); + absl::Status UpdateCompileOptions(CompileOptions* options); // Same as above, but also returns the executable extras. 
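The declarations above restore a clean split between compilation and loading. Below is a minimal sketch of the resulting compile/serialize/deserialize/load round trip; every method it calls appears in this commit's hunks, while the wrapper function, its parameters, and the error plumbing are assumptions for illustration.

```cpp
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "tsl/platform/statusor.h"
#include "xla/pjrt/pjrt_client.h"

// Sketch only: `client`, `computation`, and `options` are supplied by the
// caller.
absl::Status RoundTrip(xla::PjRtClient* client,
                       const xla::XlaComputation& computation,
                       xla::CompileOptions options) {
  // Compile to an unloaded executable; no device resources are held yet.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtExecutable> executable,
                      client->Compile(computation, options));
  // The unloaded form can be serialized and shipped across processes.
  TF_ASSIGN_OR_RETURN(std::string serialized,
                      executable->SerializeExecutable());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtExecutable> restored,
                      client->DeserializeExecutable(serialized, std::nullopt));
  // Loading onto this client's devices is now a separate, explicit step.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtLoadedExecutable> loaded,
                      client->Load(std::move(restored), xla::LoadOptions()));
  (void)loaded;
  return absl::OkStatus();
}
```

The design point is that only `Load` needs addressable devices, which is why `BuildPjRtExecutable` and `DeserializeToLocalExecutable` above can run without touching the device topology.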
absl::StatusOr UpdateCompileOptionsAndGetExecutableExtras( CompileOptions* options); - absl::StatusOr> CompileInternal( + // Updates `options` for compilation, and gets the executable extras if + // `returned_extras` is not null. + absl::Status UpdateCompileOptionsInternal(CompileOptions* options, + ExecutableExtras* returned_extras); + + absl::StatusOr> CompileInternal( const XlaComputation& computation, const std::vector& argument_layout_pointers, LayoutCanonicalizationCallback layout_canonicalization_callback, CompileOptions options); + absl::StatusOr> BuildPjRtExecutable( + std::vector> local_executables, + CompileOptions compile_options); + + absl::StatusOr< + std::pair>, CompileOptions>> + DeserializeToLocalExecutable(absl::string_view serialized, + std::optional options); + + absl::StatusOr> LoadInternal( + std::vector> local_executables, + CompileOptions compile_options); + absl::StatusOr> BufferFromHostBufferInternal( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, diff --git a/third_party/xla/xla/pjrt/stream_executor_executable.cc b/third_party/xla/xla/pjrt/stream_executor_executable.cc index ab82fdaf0c2ec1..91b80c72510669 100644 --- a/third_party/xla/xla/pjrt/stream_executor_executable.cc +++ b/third_party/xla/xla/pjrt/stream_executor_executable.cc @@ -18,27 +18,52 @@ limitations under the License. #include #include #include +#include +#include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/client/local_client.h" +#include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/stream_executor_executable.pb.h" #include "xla/service/compiler.h" +#include "xla/service/executable.h" +#include "xla/shape.h" +#include "xla/util.h" #include "tsl/platform/statusor.h" namespace xla { absl::StatusOr StreamExecutorExecutable::SerializeExecutable() const { - if (aot_executables_.empty()) { - return absl::InternalError("No local executable"); - } - if (aot_executables_.size() != 1) { - return absl::UnimplementedError( - "PjRtStreamExecutorClient::SerializeExecutable unimplemented for MPMD " - "executables"); + std::string serialized; + if (std::holds_alternative< + std::vector>>( + executables_)) { + const auto& aot_executables = + std::get>>( + executables_); + if (aot_executables.empty()) { + return absl::InternalError("No local executable"); + } + if (aot_executables.size() != 1) { + return absl::UnimplementedError( + "PjRtStreamExecutorClient::SerializeExecutable unimplemented for " + "MPMD executables"); + } + TF_ASSIGN_OR_RETURN(serialized, aot_executables[0]->SerializeAsString()); + } else { + const auto& local_executables = + std::get>>(executables_); + Executable* built_executable = local_executables[0]->executable(); + CHECK(local_client_ != nullptr); + TF_ASSIGN_OR_RETURN( + std::unique_ptr aot_result, + local_client_->backend().compiler()->Export(built_executable)); + + TF_ASSIGN_OR_RETURN(serialized, aot_result->SerializeAsString()); } - TF_ASSIGN_OR_RETURN(std::string serialized, - aot_executables_[0]->SerializeAsString()); if (serialized.empty()) { return absl::InternalError( "PjRtStreamExecutorClient::SerializeExecutable proto serialization " @@ -50,4 +75,84 @@ absl::StatusOr StreamExecutorExecutable::SerializeExecutable() compile_options_.ToProto()); return proto.SerializeAsString(); } + +namespace { + +absl::StatusOr MemoryKindFromSimpleShape( + const Shape& shape, absl::string_view default_memory_kind) { + if (!shape.has_layout()) { + return default_memory_kind; + } + switch 
(shape.layout().memory_space()) { + case Layout::kHostMemorySpace: + return PinnedHostMemorySpace::kKind; + case Layout::kGenericFastMemorySpace: + case Layout::kDefaultMemorySpace: + return default_memory_kind; + default: + return InvalidArgument("Unexpected memory space %d in output layout", + shape.layout().memory_space()); + } +} + +absl::StatusOr> MemoryKindsFromShape( + const Shape& shape, absl::string_view default_memory_kind) { + if (!shape.IsTuple()) { + TF_ASSIGN_OR_RETURN(absl::string_view memory_kind, + MemoryKindFromSimpleShape(shape, default_memory_kind)); + return {{memory_kind}}; + } + std::vector result; + result.reserve(shape.tuple_shapes_size()); + for (const auto& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN( + absl::string_view element_memory_kind, + MemoryKindFromSimpleShape(element_shape, default_memory_kind)); + result.push_back(element_memory_kind); + } + return result; +} + +} // namespace + +absl::StatusOr>> +StreamExecutorExecutable::GetOutputMemoryKinds() const { + TF_ASSIGN_OR_RETURN(auto shapes, GetOutputShapes()); + std::vector> out; + out.reserve(shapes.size()); + for (const auto& shape : shapes) { + TF_ASSIGN_OR_RETURN(std::vector memory_kind, + MemoryKindsFromShape(shape, default_memory_kind_)); + out.push_back(memory_kind); + } + return out; +} + +absl::StatusOr>> +StreamExecutorExecutable::ConsumeExecutable( + LocalClient* client, const CompileOptions& compile_options) { + if (std::holds_alternative>>( + executables_)) { + return std::get>>( + std::move(executables_)); + } else if (std::holds_alternative< + std::vector>>( + executables_)) { + auto aot_executables = + std::get>>( + std::move(executables_)); + std::vector> local_executables; + local_executables.reserve(aot_executables.size()); + for (int i = 0; i < aot_executables.size(); ++i) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr local_executable, + client->Load(std::move(aot_executables[i]), + compile_options.executable_build_options)); + local_executables.push_back(std::move(local_executable)); + } + return local_executables; + } + return absl::UnimplementedError("Unsupported executable type."); +} + } // namespace xla diff --git a/third_party/xla/xla/pjrt/stream_executor_executable.h b/third_party/xla/xla/pjrt/stream_executor_executable.h index 826e4f2912f176..bc984613f7d99e 100644 --- a/third_party/xla/xla/pjrt/stream_executor_executable.h +++ b/third_party/xla/xla/pjrt/stream_executor_executable.h @@ -27,6 +27,7 @@ limitations under the License. 
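To make the layout-to-memory-kind mapping above concrete: a host-pinned output layout maps to the pinned-host kind, and anything in the default or generic-fast space falls back to the caller-supplied default. The snippet below is illustrative only; the helper is file-local in the real code, and the `"device"` default kind is a placeholder assumption.

```cpp
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/pjrt/host_memory_spaces.h"
#include "xla/shape_util.h"

absl::StatusOr<absl::string_view> ExampleOutputKind() {
  xla::Shape shape =
      xla::ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {16, 16}, {1, 0});
  shape.mutable_layout()->set_memory_space(xla::Layout::kHostMemorySpace);
  // Yields PinnedHostMemorySpace::kKind ("pinned_host") rather than the
  // placeholder default kind passed in.
  return MemoryKindFromSimpleShape(shape, /*default_memory_kind=*/"device");
}
```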
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/client/local_client.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" @@ -39,15 +40,37 @@ class StreamExecutorExecutable : public PjRtExecutable { const CompileOptions& compile_options, std::vector> executables, int num_replicas, int num_partitions, absl::string_view name, - absl::string_view fingerprint, - std::optional>> - output_memory_kinds) + absl::string_view fingerprint, absl::string_view default_memory_kind) : compile_options_(compile_options), - aot_executables_(std::move(executables)), + executables_(std::move(executables)), num_replicas_(num_replicas), num_partitions_(num_partitions), name_(name), - fingerprint_(fingerprint) {} + fingerprint_(fingerprint), + default_memory_kind_(default_memory_kind) {} + + StreamExecutorExecutable( + const CompileOptions& compile_options, + std::vector> local_executables, + LocalClient* local_client, int num_replicas, int num_partitions, + absl::string_view name, absl::string_view fingerprint, + absl::string_view default_memory_kind) + : compile_options_(compile_options), + executables_(std::move(local_executables)), + local_client_(local_client), + num_replicas_(num_replicas), + num_partitions_(num_partitions), + name_(name), + fingerprint_(fingerprint), + default_memory_kind_(default_memory_kind) { + std::vector> hlo_modules; + for (const auto& local_executable : + std::get>>( + executables_)) { + hlo_modules.push_back(local_executable->executable()->shared_module()); + } + hlo_modules_ = std::move(hlo_modules); + } absl::StatusOr SerializeExecutable() const override; @@ -59,27 +82,64 @@ class StreamExecutorExecutable : public PjRtExecutable { } absl::StatusOr>> GetHloModules() const override { - return absl::UnimplementedError("GetHloModules is not supported."); + if (!hlo_modules_.has_value()) { + return absl::UnimplementedError("GetHloModules is not supported."); + } + return *hlo_modules_; } - absl::StatusOr>> - GetOutputMemoryKinds() const override { - if (output_memory_kinds_.has_value()) { - return *output_memory_kinds_; + absl::StatusOr GetCompiledMemoryStats() const override { + if (std::holds_alternative< + std::vector>>( + executables_)) { + return absl::UnimplementedError( + "Retrieving CompiledMemoryStats is not supported."); } - return absl::UnimplementedError("GetOutputMemoryKinds is not supported."); + const auto& local_executables = + std::get>>(executables_); + if (local_executables.size() != 1) { + return absl::UnimplementedError( + "Retrieving CompiledMemoryStats is not supported for multiple " + "executables."); + } + CompiledMemoryStats memory_stats = CompiledMemoryStats(); + memory_stats.generated_code_size_in_bytes = SizeOfGeneratedCodeInBytes(); + const HloProto* proto = local_executables[0]->executable()->hlo_proto(); + if (proto != nullptr) { + memory_stats.serialized_hlo_proto = proto->SerializeAsString(); + } + memory_stats.PopulateBufferStatsFromAllocations( + local_executables[0]->executable()->GetAllocations()); + return memory_stats; } + + absl::StatusOr>> + GetOutputMemoryKinds() const override; + absl::StatusOr> GetCostAnalysis() const override { return absl::UnimplementedError("GetCostAnalysis is not supported."); } - int64_t SizeOfGeneratedCodeInBytes() const override { return 0; } + int64_t SizeOfGeneratedCodeInBytes() const override { + if (std::holds_alternative< + std::vector>>( + executables_)) { + return 0; + } + int64_t size 
= 0; + for (auto& executable : + std::get>>( + executables_)) { + size += executable->executable()->SizeOfGeneratedCodeInBytes(); + } + return size; + } const CompileOptions& compile_options() const { return compile_options_; } - std::vector>& aot_executables() { - return aot_executables_; - } + + absl::StatusOr>> + ConsumeExecutable(LocalClient* client, const CompileOptions& compile_options); absl::StatusOr FingerprintExecutable() const override { return fingerprint_; @@ -87,13 +147,16 @@ class StreamExecutorExecutable : public PjRtExecutable { private: CompileOptions compile_options_; - std::vector> aot_executables_; + std::variant>, + std::vector>> + executables_; + LocalClient* local_client_ = nullptr; + std::optional>> hlo_modules_; int num_replicas_; int num_partitions_; std::string name_; std::string fingerprint_; - std::optional>> - output_memory_kinds_; + absl::string_view default_memory_kind_; }; } // namespace xla diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index c809be5b0971f6..1e6ea95df721fa 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -1225,8 +1225,11 @@ absl::StatusOr> BufferAssignment::FromProto( absl::c_copy(alloc_proto.parameter_shape_index(), std::back_inserter(shape_idx_vals)); ShapeIndex shape_index(shape_idx_vals); + const bool parameter_has_alias = + module->input_output_alias_config().ParameterHasAlias( + alloc_proto.parameter_number(), shape_index); allocation->set_entry_computation_parameter( - alloc_proto.parameter_number(), shape_index, false); + alloc_proto.parameter_number(), shape_index, parameter_has_alias); } // Process each logical buffer assigned to the current allocation and create From 33a7c5ab5c3affd9254138c9e186e12bfac6a093 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Thu, 10 Apr 2025 01:09:59 -0700 Subject: [PATCH 0487/1324] [HLO Componentization] Migrate away from deprecated build targets and header files. 
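Concretely, the migration amounts to include-path rewrites like the following; all four mappings appear verbatim in the hunks of this commit, and any file still using the old paths migrates the same way.

```cpp
#include "xla/hlo/builder/lib/conv_grad_size_util.h"  // was xla/client/lib/conv_grad_size_util.h
#include "xla/hlo/builder/padding.h"                  // was xla/client/padding.h
#include "xla/hlo/builder/sharding_builder.h"         // was xla/client/sharding_builder.h
#include "xla/hlo/transforms/expanders/op_expander_pass.h"  // was xla/service/op_expander_pass.h
```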
PiperOrigin-RevId: 745913712 --- tensorflow/compiler/mlir/lite/stablehlo/BUILD | 7 +++---- .../compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc | 6 +++--- tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD | 2 +- third_party/xla/xla/hlo/pass/BUILD | 2 +- third_party/xla/xla/hlo/transforms/expanders/BUILD | 2 -- third_party/xla/xla/pjrt/distributed/BUILD | 1 - third_party/xla/xla/pjrt/distributed/topology_util_test.cc | 1 - third_party/xla/xla/service/gpu/transforms/BUILD | 2 +- .../xla/service/gpu/transforms/block_scaling_rewriter.h | 2 +- 9 files changed, 10 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 2f25ec4532b233..b1a472ba5fcc2e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -166,13 +166,12 @@ cc_library( "@local_tsl//tsl/platform:bfloat16", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:padding", - "@local_xla//xla/client:sharding_builder", - "@local_xla//xla/client/lib:conv_grad_size_util", + "@local_xla//xla/hlo/builder:padding", + "@local_xla//xla/hlo/builder:sharding_builder", + "@local_xla//xla/hlo/builder/lib:conv_grad_size_util", "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:convert_op_folder", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", "@local_xla//xla/tsl/platform:status", "@stablehlo//:chlo_ops", ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc index d1e7dd75dcfa9d..c4c227c22bb149 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc @@ -59,9 +59,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" -#include "xla/client/lib/conv_grad_size_util.h" -#include "xla/client/padding.h" -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/lib/conv_grad_size_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/sharding_builder.h" #include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/utils/convert_op_folder.h" diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index 697ef42d3c6ac2..a70ca99ff30cbf 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -555,7 +555,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:sharding_builder", ], ) diff --git a/third_party/xla/xla/hlo/pass/BUILD b/third_party/xla/xla/hlo/pass/BUILD index fdc72786360172..aa4c6f25ff7e3c 100644 --- a/third_party/xla/xla/hlo/pass/BUILD +++ b/third_party/xla/xla/hlo/pass/BUILD @@ -53,13 +53,13 @@ xla_cc_test( ":hlo_pass", "//xla:literal_util", "//xla:shape_util", - "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test_helpers", "//xla/service:hlo_proto_cc", "//xla/service:pattern_matcher", "//xla/tsl/platform:errors", diff --git a/third_party/xla/xla/hlo/transforms/expanders/BUILD b/third_party/xla/xla/hlo/transforms/expanders/BUILD index 3207f550271c76..e7c91bc7c51900 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/BUILD +++ b/third_party/xla/xla/hlo/transforms/expanders/BUILD @@ -156,7 +156,6 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/service:hlo_creation_utils", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", @@ -289,7 +288,6 @@ xla_cc_test( "//xla/hlo/testlib:test", "//xla/service:dynamic_padder", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", # fixdeps: keep diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index a288d5fd8c9fa3..0efe8a1c641eb5 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -49,7 +49,6 @@ xla_cc_test( ":in_memory_key_value_store", ":protocol_proto_cc", ":topology_util", - "//xla:test_helpers", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:status_matchers", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc index a911c779758522..5ccdd234e5137c 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "absl/types/span.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/distributed/protocol.pb.h" -#include "xla/test_helpers.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/status_matchers.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 2fab25b5c14fbb..6eb28726d1db21 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -487,8 +487,8 @@ cc_library( "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:constants", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms/expanders:op_expander_pass", "//xla/service:hlo_creation_utils", - "//xla/service:op_expander_pass", "//xla/service:shape_inference", "//xla/service/gpu:cublas_cudnn", "//xla/tsl/platform:errors", diff --git a/third_party/xla/xla/service/gpu/transforms/block_scaling_rewriter.h b/third_party/xla/xla/service/gpu/transforms/block_scaling_rewriter.h index ecdb1ad60e8a47..37d261fe6e1888 100644 --- a/third_party/xla/xla/service/gpu/transforms/block_scaling_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/block_scaling_rewriter.h @@ -19,7 +19,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla::gpu { From 5f31a9a4deee460e2c768cecccbfeebb13b4b987 Mon Sep 17 00:00:00 2001 From: Will Froom Date: Thu, 10 Apr 2025 01:52:21 -0700 Subject: [PATCH 0488/1324] [XLA:CPU] Split out codegen of fusion call targets PiperOrigin-RevId: 745925825 --- .../codegen/emitters/cpu_fusion_emitter.cc | 131 ++++++++++-------- .../cpu/codegen/emitters/cpu_fusion_emitter.h | 19 +-- .../emitters/computation_partitioner.cc | 2 +- 3 files changed, 82 insertions(+), 70 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc index 47d79125fd7bc3..9db6088ea4ffc8 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc @@ -191,76 +191,24 @@ CpuFusionEmitterBase::CreateMLIRModule( mlir::OpBuilder builder(&context); auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion.name())); mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(loc); + SetDataLayoutAttribute(module.get(), fusion); TF_ASSIGN_OR_RETURN( mlir::func::FuncOp entry_func, EmitFusionKernelApi(module.get(), fusion, entry_function_name, buffer_assignment)); - TF_RETURN_IF_ERROR(EmitMlir(module.get(), entry_func, fusion)); - return module; -} - -// NOLINTNEXTLINE(readability-function-cognitive-complexity) -absl::Status CpuFusionEmitterBase::EmitMlir( - mlir::ModuleOp module, FuncOp entry_function, - const HloFusionInstruction& fusion) const { std::vector epilogues = - GetEpilogues(fusion, module->getContext()); + GetEpilogues(fusion, &context); emitters::PartitionedComputations computations( - fusion.fused_instructions_computation(), module->getContext(), - /*epilogues=*/epilogues); - auto subgraph_to_mlir_fn = computations.DeclareFunctions(module); - - // Erase subgraphs for all heroes that aren't used anywhere else. This is - // necessary because the instructions may not have elemental implementations - // (scatter). 
- for (const auto& epilogue : epilogues) { - for (auto* custom : epilogue.heroes) { - if (custom->user_count() == 0) { - subgraph_to_mlir_fn.extract(&computations.FindSubgraph(custom)) - .mapped() - .erase(); - } - } - } - - // The epilogue functions replace the root tuple. - auto* root = fusion.fused_instructions_computation()->root_instruction(); - if (root->opcode() == HloOpcode::kTuple && !epilogues.empty()) { - subgraph_to_mlir_fn.extract(&computations.FindSubgraph(root)) - .mapped() - .erase(); - } - - auto call_targets = - computations.CreateCallTargetProvider(subgraph_to_mlir_fn); - for (const auto& comp : computations.partitioned_computations()) { - for (const auto& subgraph : comp.subgraphs()) { - if (subgraph_to_mlir_fn.contains(&subgraph)) { - TF_RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( - comp, subgraph, subgraph_to_mlir_fn[&subgraph], call_targets)); - } - } - } - for (const auto& epilogue : computations.epilogues()) { - if (epilogue.roots.empty()) continue; - TF_RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( - computations.FindPartitionedComputation( - fusion.fused_instructions_computation()), - epilogue, subgraph_to_mlir_fn[&epilogue], call_targets)); - } - - int index_bitwidth = - Needs64BitIndices(fusion.fused_instructions_computation()) ? 64 : 32; - mlir::OpBuilder b(module->getContext()); - auto index_layout = mlir::DataLayoutEntryAttr::get( - b.getIndexType(), b.getI32IntegerAttr(index_bitwidth)); - module->setAttr( - mlir::DLTIDialect::kDataLayoutAttrName, - mlir::DataLayoutSpecAttr::get(module->getContext(), {index_layout})); + fusion.fused_instructions_computation(), &context, epilogues); + TF_ASSIGN_OR_RETURN( + emitters::CallTargetProvider call_targets, + EmitCallTargets(module.get(), fusion, computations, epilogues)); - return EmitEntryFunction(computations, call_targets, entry_function, fusion); + TF_RETURN_IF_ERROR( + EmitEntryFunction(computations, call_targets, entry_func, fusion)); + return module; } using mlir::AffineExpr; @@ -396,6 +344,67 @@ absl::StatusOr EmitFusionKernelApi( return entry_func; } +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +absl::StatusOr EmitCallTargets( + mlir::ModuleOp module, const HloFusionInstruction& fusion, + const emitters::PartitionedComputations& computations, + const std::vector& epilogues) { + auto subgraph_to_mlir_fn = computations.DeclareFunctions(module); + + // Erase subgraphs for all heroes that aren't used anywhere else. This is + // necessary because the instructions may not have elemental implementations + // (scatter). + for (const auto& epilogue : epilogues) { + for (auto* custom : epilogue.heroes) { + if (custom->user_count() == 0) { + subgraph_to_mlir_fn.extract(&computations.FindSubgraph(custom)) + .mapped() + .erase(); + } + } + } + + // The epilogue functions replace the root tuple. 
+ auto* root = fusion.fused_instructions_computation()->root_instruction(); + if (root->opcode() == HloOpcode::kTuple && !epilogues.empty()) { + subgraph_to_mlir_fn.extract(&computations.FindSubgraph(root)) + .mapped() + .erase(); + } + + auto call_targets = + computations.CreateCallTargetProvider(subgraph_to_mlir_fn); + for (const auto& comp : computations.partitioned_computations()) { + for (const auto& subgraph : comp.subgraphs()) { + if (subgraph_to_mlir_fn.contains(&subgraph)) { + TF_RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( + comp, subgraph, subgraph_to_mlir_fn[&subgraph], call_targets)); + } + } + } + for (const auto& epilogue : computations.epilogues()) { + if (epilogue.roots.empty()) continue; + TF_RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( + computations.FindPartitionedComputation( + fusion.fused_instructions_computation()), + epilogue, subgraph_to_mlir_fn[&epilogue], call_targets)); + } + + return call_targets; +} + +void SetDataLayoutAttribute(mlir::ModuleOp module, + const HloFusionInstruction& fusion) { + int index_bitwidth = + Needs64BitIndices(fusion.fused_instructions_computation()) ? 64 : 32; + mlir::OpBuilder b(module->getContext()); + auto index_layout = mlir::DataLayoutEntryAttr::get( + b.getIndexType(), b.getI32IntegerAttr(index_bitwidth)); + module->setAttr( + mlir::DLTIDialect::kDataLayoutAttrName, + mlir::DataLayoutSpecAttr::get(module->getContext(), {index_layout})); +} + int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; } } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h index 820886a57aa108..455c5f553ded17 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h @@ -58,6 +58,17 @@ absl::StatusOr EmitFusionKernelApi( const std::string& entry_function_name, const BufferAssignment& buffer_assignment); +// Emit the call targets for the given fusion. +absl::StatusOr EmitCallTargets( + mlir::ModuleOp module, const HloFusionInstruction& fusion, + const emitters::PartitionedComputations& computations, + const std::vector& epilogues); + +// Set the data layout attribute of the module based on the called instructions +// of the fusion. +void SetDataLayoutAttribute(mlir::ModuleOp module, + const HloFusionInstruction& fusion); + class CpuFusionEmitterBase { public: CpuFusionEmitterBase(mlir::MLIRContext* mlir_context, @@ -117,14 +128,6 @@ class CpuFusionEmitterBase { llvm::LLVMContext* llvm_context_; const BufferAssignment& buffer_assignment_; const HloFusionInstruction* fusion_; - - private: - // Emits MLIR for the given fusion. The entry function has one tensor argument - // per fusion parameter and output and one tensor result per fusion output. 
-  // The fusion outputs may only be used with `tensor.insert` ops.
-  absl::Status EmitMlir(mlir::ModuleOp module,
-                        mlir::func::FuncOp entry_function,
-                        const HloFusionInstruction& fusion) const;
 };
 
 int64_t CeilDiv(int64_t a, int64_t b);
 
diff --git a/third_party/xla/xla/codegen/emitters/computation_partitioner.cc b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc
index aefcaf0ca5be95..f90d3d911f5393 100644
--- a/third_party/xla/xla/codegen/emitters/computation_partitioner.cc
+++ b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc
@@ -438,7 +438,7 @@ const PartitionedComputation::Subgraph& PartitionedComputations::FindSubgraph(
 
 CallTargetProvider PartitionedComputations::CreateCallTargetProvider(
     const absl::flat_hash_map<const PartitionedComputation::Subgraph*,
                               mlir::func::FuncOp>& subgraph_to_func) const {
-  return [&, this](const HloInstruction* instr) {
+  return [subgraph_to_func, this](const HloInstruction* instr) {
     const auto& subgraph = FindSubgraph(instr);
     CHECK(subgraph_to_func.contains(&subgraph))
         << "No function found for subgraph with instruction "

From c995cc0f438d608e978238150bdca6f77994d478 Mon Sep 17 00:00:00 2001
From: Goran Flegar
Date: Thu, 10 Apr 2025 01:53:05 -0700
Subject: [PATCH 0489/1324] Compute the number of warps we should have based
 on hardware properties

Replaces the current placeholder value. We still have a slightly more
distilled placeholder of what we are trying to achieve in terms of
occupancy, but figuring that out is a problem for another day.

PiperOrigin-RevId: 745926087
---
 .../gpu/autotuning/dot_search_space.cc        | 12 ++++++++++--
 .../service/gpu/autotuning/dot_search_space.h | 14 ++++++++++++++
 .../gpu/autotuning/dot_search_space_test.cc   | 19 ++++++++++++++++++-
 3 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
index 245ee5a3ba0c74..f9982d35bebf62 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
@@ -92,10 +92,11 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace(
       rhs_parallel_size_(ShapeUtil::ElementsIn(dot->operand(1)->shape()) /
                          (contracting_size_ * batch_size_)),
       compute_bitwidth_(primitive_util::BitWidth(dot->shape().element_type())),
+      // Figure out some basic limitations on tiling based on the above.
+      desired_total_warps_(GetDesiredTotalWarps()),
+      max_out_tile_(GetMaxOutputTile()),
       // TODO: b/404470821 - Compute these from the problem properties instead
       // of hardcoding.
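Before the hunks that implement the new heuristic, the arithmetic can be checked standalone. The scheduler and warp constants come from this commit, and the 132-core count is the figure the updated test's comments use; everything below is a re-derivation, not code from the patch.

```cpp
constexpr int kMaxWarpsPerScheduler = 5;
constexpr int kSchedulersPerCore = 4;
constexpr int kCoreCount = 132;

// 5 warps/scheduler * 4 schedulers/core * 132 cores = 2640 desired warps.
constexpr int kDesiredTotalWarps =
    kMaxWarpsPerScheduler * kSchedulersPerCore * kCoreCount;
static_assert(kDesiredTotalWarps == 2640);

// For a 1024x512 dot tiled 32x32 with 4 warps per block:
// (1024/32) * (512/32) blocks = 512 blocks, i.e. 2048 warps.
constexpr int kWarpsWithoutSplit = (1024 / 32) * (512 / 32) * 4;
static_assert(kWarpsWithoutSplit == 2048);

// Desired split-K = ceil(2640 / 2048) = 2, matching the expected config.
static_assert((kDesiredTotalWarps + kWarpsWithoutSplit - 1) /
                  kWarpsWithoutSplit ==
              2);
```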
- desired_total_warps_(2160), - max_out_tile_(GetMaxOutputTile()), min_out_tile_{16, 16}, min_warps_per_cta_(4), min_contracting_tile_size_(16), @@ -160,6 +161,13 @@ std::string TritonDotFusionSearchSpace::Serialize() { min_contracting_tile_size_, desired_total_warps_, min_warps_per_cta_); } +int TritonDotFusionSearchSpace::GetDesiredTotalWarps() const { + constexpr int kSchedulersPerCore = 4; + constexpr int kDesiredWarpsPerCore = + kMaxWarpsPerScheduler * kSchedulersPerCore; + return kDesiredWarpsPerCore * device_description_.core_count(); +} + TritonDotFusionSearchSpace::OutputTile TritonDotFusionSearchSpace::GetMaxOutputTile() const { constexpr int kRegisterSizeInBits = 32; diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h index 44025a65dbc96f..fdeccf57936db4 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h @@ -71,6 +71,12 @@ class TritonDotFusionSearchSpace { std::string ToString() const { return config.ToString(); } }; + // Approximation on the maximum number of warps we would want to oversubscribe + // the SMs with to overlap different GPU pipes (memory, tensor core, ALU, + // special function unit, etc.) + // TODO: b/408114338 - Figure out a better model for this. + static constexpr int kMaxWarpsPerScheduler = 5; + // Callback type for `ExtendConfigs`. The method should append zero or more // extensions of `config` to the `updated_configs` vector. using ExtendConfigCallback = void (TritonDotFusionSearchSpace::*)( @@ -83,6 +89,14 @@ class TritonDotFusionSearchSpace { void ExtendConfigs(std::vector& configs, ExtendConfigCallback extend_config); + // Computes the maximum number of total warps we should have to sufficiently + // saturate the GPU. + // + // We're counting warps instead of blocks here, since we already need this + // value as a consideration to decide how large the blocks should be (which + // then impacts how many of them we should have). + int GetDesiredTotalWarps() const; + // Computes the maximum sensible size of the output tile (block_m, block_n) // based on the dot shape and element type, and the available registers on // the core. diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc index 6d150dff1f91a3..ccc4fb178ee635 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc @@ -118,7 +118,7 @@ TEST_F(DotSearchSpaceTest, SerializesSearchSpace) { EXPECT_EQ(search_space.Serialize(), "problem_size_BxMxNxKxE: 1x1024x1024x1024x16 " "tile_range_SxMxNxK: [1-64]x[16-256]x[16-512]x[16-?] " - "desired_total_warps: 2160 warps_per_cta: [4-?]"); + "desired_total_warps: 2640 warps_per_cta: [4-?]"); } TEST_F(DotSearchSpaceTest, ReturnsValidConfigList) { @@ -222,5 +222,22 @@ TEST_F(DotSearchSpaceTest, HonorsMinimumOutputTileSizeForTinyProblem) { AllOf(Not(IsEmpty()), Each(BlockMIs(Ge(16))), Each(BlockNIs(Ge(16))))); } +TEST_F(DotSearchSpaceTest, AssignsEnoughWarpsPerScheduler) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, + GetDefaultDotModule(/*lhs_parallel_dim=*/1024, /*rhs_parallel_dim=*/512, + /*contracting_dim=*/1024)); + auto search_space = MakeSearchSpace(module.get()); + + // 1024x512 elements / 32x32 elements/block = 32x16 blocks = 512 blocks. + // 512 blocks * 4 warps/block = 2048 warps. 
+ // 132 cores * 4 schedulers/core * 5 desired warps/scheduler = 2640 desired + // warps. + // ceil(2640 desired warps / 2048 warps) = ceil(1.3) = 2 desired split + EXPECT_THAT(search_space.GenerateConfigs(), + Contains(AllOf(BlockMIs(Eq(32)), BlockNIs(Eq(32)), + NumWarpsIs(Eq(4)), SplitKIs(Eq(2))))); +} + } // namespace } // namespace xla::gpu From 852a5e8a4ba6f98f6846b17250e55140a9389f59 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 02:02:45 -0700 Subject: [PATCH 0490/1324] Update GraphDef version to 2193. PiperOrigin-RevId: 745929183 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 8cb8f11d7f0d53..59eef5704606e7 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2192 // Updated: 2025/4/9 +#define TF_GRAPH_DEF_VERSION 2193 // Updated: 2025/4/10 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From e27d915fc3e5c51aca37f8992277349c2ccf9e86 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 02:03:32 -0700 Subject: [PATCH 0491/1324] compat: Update forward compatibility horizon to 2025-04-10 PiperOrigin-RevId: 745929481 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index beebbbb3444e76..75b98c6cb2b296 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 9) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 10) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 33c50f177c1d151da54b2ab3734dbe65eda3baa2 Mon Sep 17 00:00:00 2001 From: Will Froom Date: Thu, 10 Apr 2025 02:56:06 -0700 Subject: [PATCH 0492/1324] [XLA:CPU] Make CpuFusionEmitterBase a pure interface class PiperOrigin-RevId: 745944794 --- .../xla/backends/cpu/codegen/emitters/BUILD | 2 +- .../codegen/emitters/cpu_fusion_emitter.cc | 118 ++++++------------ .../cpu/codegen/emitters/cpu_fusion_emitter.h | 59 +-------- .../emitters/cpu_fusion_emitter_test.cc | 19 +-- .../codegen/emitters/cpu_scatter_emitter.cc | 37 +++++- .../codegen/emitters/cpu_scatter_emitter.h | 19 ++- .../xla/xla/service/cpu/ir_emitter2.cc | 28 +++-- 7 files changed, 125 insertions(+), 157 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD b/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD index 2848c407ecf636..a02d882ce5fd53 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD @@ -34,7 +34,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu:alignment", - "//xla/backends/cpu/codegen:fusion_compiler", "//xla/backends/cpu/codegen:kernel_api_ir_builder", "//xla/backends/cpu/codegen/emitters/ir:xla_cpu", "//xla/codegen/emitters:computation_partitioner", @@ -101,6 +100,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc index 9db6088ea4ffc8..44e3748625bc46 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc @@ -23,9 +23,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" @@ -33,6 +33,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/Linker/Linker.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" @@ -60,7 +61,6 @@ limitations under the License. #include "xla/backends/cpu/alignment.h" #include "xla/backends/cpu/codegen/emitters/ir/xla_cpu_ops.h" #include "xla/backends/cpu/codegen/emitters/ir/xla_cpu_types.h" -#include "xla/backends/cpu/codegen/fusion_compiler.h" #include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" #include "xla/codegen/emitters/computation_partitioner.h" #include "xla/codegen/emitters/elemental_hlo_to_mlir.h" @@ -76,7 +76,6 @@ limitations under the License. 
#include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/buffer_assignment.h" #include "xla/service/dump.h" -#include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/platform/errors.h" @@ -131,86 +130,6 @@ bool Needs64BitIndices(const HloComputation* computation) { } } // namespace -absl::StatusOr CpuFusionEmitterBase::Emit() const { - // Single-threaded for now. - TF_ASSIGN_OR_RETURN(auto module, - CreateLLVMModule(*mlir_context_, *llvm_context_, *fusion_, - buffer_assignment_)); - - const HloModule* hlo_module = fusion_->GetModule(); - if (hlo_module == nullptr) { - return Internal("HloModule is null"); - } - // Create a Kernel API Builder and a throwaway kernel prototype in order to - // extract useful info from them, e.g. noalias, invariant_arguments and - // entry function attributes. - // TODO(ecg): find a way to obtain the same info without wasting work by - // creating a throwaway module. All of this additional info should probably be - // explicit in the generated MLIR, not added afterwards like we're doing here. - // TODO(ecg): some attributes on the final loads are missing wrt those - // generated via KernelApiIrBuilder, e.g. noalias. Add them. - KernelApiIrBuilder kernel_api_ir_builder( - *llvm_context_, - KernelApiIrBuilder::Options::FromHloModuleConfig(hlo_module->config())); - std::unique_ptr throwaway_llvm_module = - KernelApiIrBuilder::CreateModule( - absl::StrCat(fusion_->name(), "_throwaway_module"), *llvm_context_); - TF_ASSIGN_OR_RETURN(KernelApiIrBuilder::KernelPrototype kernel_prototype, - kernel_api_ir_builder.EmitKernelPrototype( - *throwaway_llvm_module, fusion_, &buffer_assignment_, - "_throwaway_kernel_prototype")); - llvm::Function* kernel_function = module->getFunction(fusion_->name()); - kernel_api_ir_builder.SetKernelFunctionAttributes(kernel_function); - - CpuFusionEmissionResult result; - result.llvm_module = std::move(module); - result.invariant_arguments = std::move(kernel_prototype.invariant_arguments); - return result; -} - -absl::StatusOr> -CpuFusionEmitterBase::CreateLLVMModule( - mlir::MLIRContext& mlir_context, llvm::LLVMContext& llvm_context, - const HloFusionInstruction& fusion, - const BufferAssignment& buffer_assignment) const { - TF_ASSIGN_OR_RETURN(auto module, - CreateMLIRModule(mlir_context, fusion, - std::string(fusion.name()) + "_entry", - buffer_assignment)); - - FusionCompiler compiler(FusionCompiler::Options{}); - return compiler.Compile(llvm_context, module.get()); -} - -absl::StatusOr> -CpuFusionEmitterBase::CreateMLIRModule( - mlir::MLIRContext& context, const HloFusionInstruction& fusion, - const std::string& entry_function_name, - const BufferAssignment& buffer_assignment, - mlir::interpreter::MlirCompilationTrace* trace) const { - mlir::OpBuilder builder(&context); - auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion.name())); - mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(loc); - SetDataLayoutAttribute(module.get(), fusion); - - TF_ASSIGN_OR_RETURN( - mlir::func::FuncOp entry_func, - EmitFusionKernelApi(module.get(), fusion, entry_function_name, - buffer_assignment)); - - std::vector epilogues = - GetEpilogues(fusion, &context); - emitters::PartitionedComputations computations( - fusion.fused_instructions_computation(), &context, epilogues); - TF_ASSIGN_OR_RETURN( - emitters::CallTargetProvider call_targets, - EmitCallTargets(module.get(), fusion, computations, epilogues)); - - TF_RETURN_IF_ERROR( - EmitEntryFunction(computations, 
call_targets, entry_func, fusion)); - return module; -} - using mlir::AffineExpr; IndexingMap GetDefaultIndexingMap(absl::Span thread_tile_sizes, @@ -405,6 +324,39 @@ void SetDataLayoutAttribute(mlir::ModuleOp module, mlir::DataLayoutSpecAttr::get(module->getContext(), {index_layout})); } +absl::StatusOr> SetKernelFunctionAttributes( + llvm::Module& module, const BufferAssignment& buffer_assignment, + const HloFusionInstruction* fusion) { + const HloModule* hlo_module = fusion->GetModule(); + if (hlo_module == nullptr) { + return Internal("HloModule is null"); + } + + // Create a Kernel API Builder and a throwaway kernel prototype in order to + // extract useful info from them, e.g. noalias, invariant_arguments and + // entry function attributes. + // TODO(ecg): find a way to obtain the same info without wasting work by + // creating a throwaway module. All of this additional info should probably be + // explicit in the generated MLIR, not added afterwards like we're doing here. + // TODO(ecg): some attributes on the final loads are missing wrt those + // generated via KernelApiIrBuilder, e.g. noalias. Add them. + llvm::LLVMContext& context = module.getContext(); + KernelApiIrBuilder kernel_api_ir_builder( + context, + KernelApiIrBuilder::Options::FromHloModuleConfig(hlo_module->config())); + std::unique_ptr throwaway_llvm_module = + KernelApiIrBuilder::CreateModule( + absl::StrCat(fusion->name(), "_throwaway_module"), context); + TF_ASSIGN_OR_RETURN(KernelApiIrBuilder::KernelPrototype kernel_prototype, + kernel_api_ir_builder.EmitKernelPrototype( + *throwaway_llvm_module, fusion, &buffer_assignment, + "_throwaway_kernel_prototype")); + llvm::Function* kernel_function = module.getFunction(fusion->name()); + kernel_api_ir_builder.SetKernelFunctionAttributes(kernel_function); + + return kernel_prototype.invariant_arguments; +} + int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; } } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h index 455c5f553ded17..0bf3dbd1a04f8f 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h @@ -16,13 +16,11 @@ limitations under the License. #define XLA_BACKENDS_CPU_CODEGEN_EMITTERS_CPU_FUSION_EMITTER_H_ #include -#include #include #include #include #include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "llvm/IR/LLVMContext.h" @@ -30,10 +28,8 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" -#include "mlir/IR/Value.h" #include "mlir/Pass/PassManager.h" #include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" @@ -44,11 +40,6 @@ limitations under the License. 
namespace xla { namespace cpu { -struct CpuFusionEmissionResult { - std::unique_ptr llvm_module; - absl::flat_hash_set invariant_arguments; -}; - IndexingMap GetDefaultIndexingMap(absl::Span thread_tile_sizes, absl::Span shape, mlir::MLIRContext* mlir_context); @@ -69,17 +60,12 @@ absl::StatusOr EmitCallTargets( void SetDataLayoutAttribute(mlir::ModuleOp module, const HloFusionInstruction& fusion); +absl::StatusOr> SetKernelFunctionAttributes( + llvm::Module& module, const BufferAssignment& buffer_assignment, + const HloFusionInstruction* fusion); + class CpuFusionEmitterBase { public: - CpuFusionEmitterBase(mlir::MLIRContext* mlir_context, - llvm::LLVMContext* llvm_context, - const BufferAssignment& buffer_assignment, - const HloFusionInstruction* fusion) - : mlir_context_(mlir_context), - llvm_context_(llvm_context), - buffer_assignment_(buffer_assignment), - fusion_(fusion) {} - virtual ~CpuFusionEmitterBase() = default; virtual int64_t num_threads() const = 0; @@ -92,42 +78,7 @@ class CpuFusionEmitterBase { virtual std::string BackendExtraOptions() { return {}; } - absl::StatusOr Emit() const; - - // Visible for testing. - absl::StatusOr> CreateLLVMModule( - mlir::MLIRContext& mlir_context, llvm::LLVMContext& llvm_context, - const HloFusionInstruction& fusion, - const BufferAssignment& buffer_assignment) const; - - // Visible for testing. - absl::StatusOr> CreateMLIRModule( - mlir::MLIRContext& context, const HloFusionInstruction& fusion, - const std::string& entry_function_name, - const BufferAssignment& buffer_assignment, - mlir::interpreter::MlirCompilationTrace* trace = nullptr) const; - - protected: - virtual absl::Status EmitEntryFunction( - const emitters::PartitionedComputations& computations, - const emitters::CallTargetProvider& call_targets, - mlir::func::FuncOp entry_function, - const HloFusionInstruction& fusion) const = 0; - - virtual std::vector GetEpilogues( - const HloFusionInstruction& fusion, - mlir::MLIRContext* mlir_context) const { - // We don't actually support epilogues for scatter, but this is how we tell - // the base class that we don't want it to generate code for the scatter. - return {}; - } - - mlir::Value EmitThreadId(mlir::ImplicitLocOpBuilder& builder, int dim) const; - - mlir::MLIRContext* mlir_context_; - llvm::LLVMContext* llvm_context_; - const BufferAssignment& buffer_assignment_; - const HloFusionInstruction* fusion_; + virtual absl::StatusOr> Emit() const = 0; }; int64_t CeilDiv(int64_t a, int64_t b); diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc index 4bd542c900b4a1..6e7247d4e1e327 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc @@ -15,10 +15,12 @@ limitations under the License. 
#include "xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h" +#include #include #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" @@ -135,11 +137,7 @@ TEST_F(CpuFusionEmitterTest, ScatterMlir) { hlo_module->entry_computation()->root_instruction()); CpuScatterFusion emitter(mlir_context_.get(), &llvm_context_, *buffer_assignment, fusion); - TF_ASSERT_OK_AND_ASSIGN( - auto mlir_module, - emitter.CreateMLIRModule(*mlir_context_, *fusion, - std::string(fusion->name()) + "_entry", - *buffer_assignment)); + TF_ASSERT_OK_AND_ASSIGN(auto mlir_module, emitter.Emit()); auto mlir_dump = MlirModuleToString(*mlir_module); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, RunFileCheck(mlir_dump, kExpected)); @@ -165,8 +163,15 @@ TEST_F(CpuFusionEmitterTest, ScatterLlvm) { hlo_module->entry_computation()->root_instruction()); CpuScatterFusion emitter(mlir_context_.get(), &llvm_context_, *buffer_assignment, fusion); - TF_ASSERT_OK_AND_ASSIGN(auto result, emitter.Emit()); - auto llvm_dump = LlvmModuleToString(*result.llvm_module); + TF_ASSERT_OK_AND_ASSIGN(auto mlir_module, emitter.Emit()); + FusionCompiler compiler(FusionCompiler::Options{}); + llvm::LLVMContext llvm_context; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr llvm_module, + compiler.Compile(llvm_context, mlir_module.get())); + TF_ASSERT_OK_AND_ASSIGN( + absl::flat_hash_set invariant_arguments, + SetKernelFunctionAttributes(*llvm_module, *buffer_assignment, fusion)); + auto llvm_dump = LlvmModuleToString(*llvm_module); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, RunFileCheck(llvm_dump, kExpected)); EXPECT_TRUE(filecheck_matched); diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc index fa2a2b70d86132..1a502342d0c705 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc @@ -40,8 +40,10 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" +#include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h" #include "xla/codegen/emitters/computation_partitioner.h" #include "xla/codegen/emitters/elemental_hlo_to_mlir.h" @@ -55,9 +57,12 @@ limitations under the License. 
#include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/scatter_simplifier.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -173,8 +178,10 @@ CpuScatterFusion::CpuScatterFusion(mlir::MLIRContext* mlir_context, llvm::LLVMContext* llvm_context, const BufferAssignment& buffer_assignment, const HloFusionInstruction* fusion) - : CpuFusionEmitterBase{mlir_context, llvm_context, buffer_assignment, - fusion} { + : mlir_context_(mlir_context), + llvm_context_(llvm_context), + buffer_assignment_(buffer_assignment), + fusion_(fusion) { const auto* scatter = Cast( fusion->fused_instructions_computation()->root_instruction()); auto update_shape = scatter->scatter_updates().front()->shape(); @@ -236,6 +243,32 @@ IndexingMap GetScatterIndexingMap( {}, constraints); } +absl::StatusOr> CpuScatterFusion::Emit() + const { + mlir::OpBuilder builder(mlir_context_); + auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_->name())); + mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(loc); + SetDataLayoutAttribute(module.get(), *fusion_); + + TF_ASSIGN_OR_RETURN( + mlir::func::FuncOp entry_func, + EmitFusionKernelApi(module.get(), *fusion_, + std::string(fusion_->name()) + "_entry", + buffer_assignment_)); + + std::vector epilogues = + GetEpilogues(*fusion_, mlir_context_); + emitters::PartitionedComputations computations( + fusion_->fused_instructions_computation(), mlir_context_, epilogues); + TF_ASSIGN_OR_RETURN( + emitters::CallTargetProvider call_targets, + EmitCallTargets(module.get(), *fusion_, computations, epilogues)); + + TF_RETURN_IF_ERROR( + EmitEntryFunction(computations, call_targets, entry_func, *fusion_)); + return module; +} + absl::Status CpuScatterFusion::EmitEntryFunction( const emitters::PartitionedComputations& computations, const emitters::CallTargetProvider& call_targets, diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.h b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.h index 46189aa5912095..0d5a9191cbc9af 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.h +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.h @@ -17,12 +17,18 @@ limitations under the License. 
#include #include +#include #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/IR/LLVMContext.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" #include "xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h" #include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" @@ -51,18 +57,27 @@ class CpuScatterFusion : public CpuFusionEmitterBase { std::string BackendExtraOptions() override; + absl::StatusOr> Emit() const final; + protected: absl::Status EmitEntryFunction( const emitters::PartitionedComputations& computations, const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, - const HloFusionInstruction& fusion) const override; + const HloFusionInstruction& fusion) const; std::vector GetEpilogues( const HloFusionInstruction& fusion, - mlir::MLIRContext* mlir_context) const override; + mlir::MLIRContext* mlir_context) const; private: + mlir::Value EmitThreadId(mlir::ImplicitLocOpBuilder& builder, int dim) const; + + mlir::MLIRContext* mlir_context_; + llvm::LLVMContext* llvm_context_; + const BufferAssignment& buffer_assignment_; + const HloFusionInstruction* fusion_; + int64_t vector_size_; int64_t num_threads_; }; diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index a918da2805ab50..a49333ab0acc9e 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -213,18 +213,30 @@ absl::StatusOr IrEmitter2::EmitFusionWithFusionEmitters( fusion_emitter_kind, fusion->ToString()); } - TF_ASSIGN_OR_RETURN(auto fusion_result, emitter->Emit()); + TF_ASSIGN_OR_RETURN(auto mlir_module, emitter->Emit()); + + FusionCompiler compiler(FusionCompiler::Options{}); + TF_ASSIGN_OR_RETURN( + std::unique_ptr llvm_module, + compiler.Compile(module_->getContext(), mlir_module.get())); + + // TODO(willfroom): This should be done as part of the fusion emitter and + // lowered by the above compiler. + TF_ASSIGN_OR_RETURN( + absl::flat_hash_set invariant_arguments, + SetKernelFunctionAttributes(*llvm_module, + nested_ir_emitter_->assignment(), fusion)); + // Match data layouts to avoid warning messages. - fusion_result.llvm_module->setDataLayout(module_->getDataLayout()); - if (llvm::Linker::linkModules(*module_, - std::move(fusion_result.llvm_module))) { + llvm_module->setDataLayout(module_->getDataLayout()); + if (llvm::Linker::linkModules(*module_, std::move(llvm_module))) { return Internal("Cannot link additional LLVM module for fusion %s", fusion->name()); } - return kernels_.emplace_back(KernelInfo( - std::string(fusion->name()), se::BlockDim(), - se::ThreadDim(emitter->num_threads()), fusion_result.invariant_arguments, - emitter->BackendExtraOptions())); + return kernels_.emplace_back( + KernelInfo(std::string(fusion->name()), se::BlockDim(), + se::ThreadDim(emitter->num_threads()), invariant_arguments, + emitter->BackendExtraOptions())); } absl::StatusOr IrEmitter2::EmitFusionHostKernel( From d4b08f6a12b9de8061c4d5ed5acca6ad2b0fd81a Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Thu, 10 Apr 2025 04:15:05 -0700 Subject: [PATCH 0493/1324] #sdy don't escape frontend attr strings as this is now done by XLA dumping. 
PiperOrigin-RevId: 745967890 --- .../spmd/shardy/shardy_xla_pass_test.cc | 8 +-- .../test/sdy_round_trip_export_pipeline.mlir | 40 +++++++------- .../test/sdy_round_trip_import_pipeline.mlir | 54 +++++++++---------- .../test/sdy_round_trip_shard_map_export.mlir | 50 ++++++++--------- .../xla/xla/service/spmd/shardy/utils.cc | 24 ++++----- .../xla/xla/service/spmd/shardy/utils.h | 4 +- 6 files changed, 87 insertions(+), 93 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc index 41ffe8c23ff6cf..ce603a83f7013e 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc @@ -706,7 +706,7 @@ ENTRY %main.0 (Arg_0.0: s64[2]) -> s64[2] { TEST_F(ShardyXLATest, WhileShardingOnlyOnFreeVariables) { const char* const hloString = R"( - HloModule main, entry_computation_layout={(f32[32,96]{1,0}, f32[32,96]{1,0})->f32[32,96]{1,0}}, frontend_attributes={xla.sdy.meshes="{mesh = #sdy.mesh<[\"x\"=4]>}"} + HloModule main, entry_computation_layout={(f32[32,96]{1,0}, f32[32,96]{1,0})->f32[32,96]{1,0}}, frontend_attributes={xla.sdy.meshes={mesh = #sdy.mesh<["x"=4]>}} %region_0.6 (arg_tuple.7: (f32[32,96], s32[], f32[32,96])) -> (f32[32,96], s32[], f32[32,96]) { %arg_tuple.7 = (f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) parameter(0) @@ -756,7 +756,7 @@ TEST_F(ShardyXLATest, WhileShardingOnlyOnFreeVariables) { TEST_F(ShardyXLATest, EmptyResultLayout) { const char* const hloString = R"( - HloModule pjit_f_, entry_computation_layout={(s64[2,2,2]{2,1,0})->()}, num_partitions=2, frontend_attributes={xla.sdy.meshes="{maximal_mesh_0 = #sdy.mesh<[], device_ids=[0]>, mesh = #sdy.mesh<[\"x\"=2]>}"} + HloModule pjit_f_, entry_computation_layout={(s64[2,2,2]{2,1,0})->()}, num_partitions=2, frontend_attributes={xla.sdy.meshes={maximal_mesh_0 = #sdy.mesh<[], device_ids=[0]>, mesh = #sdy.mesh<["x"=2]>}} ENTRY %main.5 (Arg_0.1: s64[2,2,2]) -> () { %Arg_0.0 = s64[2,2,2]{2,1,0} parameter(0), sharding={devices=[2,1,1]<=[2]}, metadata={op_name="x"} @@ -772,7 +772,7 @@ TEST_F(ShardyXLATest, EmptyResultLayout) { TEST_F(ShardyXLATest, EmptyOperandLayout) { const char* const hloString = R"( - HloModule pjit_f_, entry_computation_layout={()->s64[2,2]{1,0}}, num_partitions=2, frontend_attributes={xla.sdy.meshes="{maximal_mesh_0 = #sdy.mesh<[], device_ids=[0]>, mesh = #sdy.mesh<[\"x\"=2]>}"} + HloModule pjit_f_, entry_computation_layout={()->s64[2,2]{1,0}}, num_partitions=2, frontend_attributes={xla.sdy.meshes={maximal_mesh_0 = #sdy.mesh<[], device_ids=[0]>, mesh = #sdy.mesh<["x"=2]>}} ENTRY %main.5 () -> s64[2,2] { ROOT %constant = s64[2,2]{1,0} constant({{1,1},{1,1}}) @@ -787,7 +787,7 @@ TEST_F(ShardyXLATest, EmptyOperandLayout) { TEST_F(ShardyXLATest, RaggedDotMode1) { const char* const hloString = R"( - HloModule ragged_dot, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={xla.sdy.meshes="{mesh = #sdy.mesh<[\"a\"=2, \"b\"=2, \"c\"=2]>}"} + HloModule ragged_dot, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={xla.sdy.meshes={mesh = #sdy.mesh<["a"=2, "b"=2, "c"=2]>}} ENTRY entry { p0 = f32[16,32,64] parameter(0), frontend_attributes={xla.sdy.sharding="#sdy.sharding<@mesh, [{\"a\", ?}, {\"b\", ?}, {\"c\", ?}]>"} p1 = f32[4,16,64,8] parameter(1) diff --git 
a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir index 1e679b6d9644e1..54110bb38086e7 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir @@ -8,20 +8,20 @@ sdy.mesh @mesh_2 = <["x"=8, "y"=4]> // CHECK: module attributes {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.meshes = "{ -// CHECK-SAME: mesh_0 = #sdy.mesh<[\\\22axis_0\\\22=2, \\\22axis_1\\\22=4, \\\22axis_2\\\22=4]>, -// CHECK-SAME: mesh_1 = #sdy.mesh<[\\\22axis_0\\\22=16]>, -// CHECK-SAME: mesh_2 = #sdy.mesh<[\\\22x\\\22=8, \\\22y\\\22=4]>}"}} { +// CHECK-SAME: mesh_0 = #sdy.mesh<[\22axis_0\22=2, \22axis_1\22=4, \22axis_2\22=4]>, +// CHECK-SAME: mesh_1 = #sdy.mesh<[\22axis_0\22=16]>, +// CHECK-SAME: mesh_2 = #sdy.mesh<[\22x\22=8, \22y\22=4]>}"}} { // CHECK-LABEL: func @multiple_shardings( -// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{\\\22axis_2\\\22}, {\\\22axis_0\\\22, \\\22axis_1\\\22}]>"}, mhlo.sharding = -// CHECK-SAME: %arg1: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\\\22axis_0\\\22, \\\22axis_2\\\22}]>"}, mhlo.sharding = -// CHECK-SAME: %arg2: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\\\22axis_1\\\22}]>"}, mhlo.sharding = +// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{\22axis_2\22}, {\22axis_0\22, \22axis_1\22}]>"}, mhlo.sharding = +// CHECK-SAME: %arg1: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_0\22, \22axis_2\22}]>"}, mhlo.sharding = +// CHECK-SAME: %arg2: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_1\22}]>"}, mhlo.sharding = // CHECK-SAME: -> tensor<8x16xf32> { func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0", "axis_2"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> { // CHECK-NEXT: stablehlo.add -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22axis_1\\\22, \\\22axis_0\\\22}, {}]>]>"}, mhlo.sharding = +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_0, [{\22axis_1\22, \22axis_0\22}, {}]>]>"}, mhlo.sharding = %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> @@ -31,7 +31,7 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.shardi func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { %0 = stablehlo.constant dense<0.000000e+00> : tensor // CHECK: stablehlo.reduce -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {\\\22y\\\22}]>, <@mesh_2, [{\\\22y\\\22}, {}]>]>"}, mhlo.sharding = +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = 
"#sdy.sharding_per_value<[<@mesh_2, [{}, {\22y\22}]>, <@mesh_2, [{\22y\22}, {}]>]>"}, mhlo.sharding = %1:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{}, {"y"}]>, <@mesh_2, [{"y"}, {}]>]>} : (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor, tensor) -> (tensor<4x8xf32>, tensor<4x8xf32>) @@ -44,13 +44,13 @@ func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) } // CHECK-LABEL: func @split_axes( -// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\\\22y\\\22}, {\\\22x\\\22:(2)2}]>"}, mhlo.sharding = -// CHECK-SAME: %arg1: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\\\22x\\\22:(1)2}, {\\\22x\\\22:(2)4}]>"}, mhlo.sharding = +// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22y\22}, {\22x\22:(2)2}]>"}, mhlo.sharding = +// CHECK-SAME: %arg1: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22x\22:(1)2}, {\22x\22:(2)4}]>"}, mhlo.sharding = // CHECK-SAME: -> tensor<8x16xf32> { func.func @split_axes(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"y"}, {"x":(2)2}]>}, %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x":(1)2}, {"x":(2)4}]>}) -> tensor<8x16xf32> { // CHECK-NEXT: stablehlo.dot -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22:(1)2, \\\22x\\\22:(4)2}, {}]>]>"}, mhlo.sharding = +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22:(1)2, \22x\22:(4)2}, {}]>]>"}, mhlo.sharding = %1 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -60,7 +60,7 @@ func.func @func_result_sharding_returning_func_arg( // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {mhlo.sharding = %arg0: tensor<8x16xf32> ) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}) { - // CHECK: %[[CUSTOM_CALL:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %[[CUSTOM_CALL:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {\22y\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: return %[[CUSTOM_CALL]] : tensor<8x16xf32> return %arg0 : tensor<8x16xf32> } @@ -76,10 +76,10 @@ func.func @func_result_sharding_returning_op_value(%arg0: tensor<8x16xf32>) tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{}, {}]>}) { // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = stablehlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, \\\22y\\\22}, {}]>, <@mesh_2, [{\\\22y\\\22, \\\22x\\\22}, {}]>]>"}, mhlo.sharding = - // CHECK-NEXT: %[[ADD_RESULT_SHARDING_0:.*]] = stablehlo.custom_call 
@local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_0:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_1:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22}, {\\\22y\\\22}p1]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = stablehlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, \22y\22}, {}]>, <@mesh_2, [{\22y\22, \22x\22}, {}]>]>"}, mhlo.sharding = + // CHECK-NEXT: %[[ADD_RESULT_SHARDING_0:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {\22y\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_0:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{?}, {\22y\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_1:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22}, {\22y\22}p1]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: %[[ADD_RESULT_SHARDING_1:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {}]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: return %[[ADD_RESULT_SHARDING_0]], %[[TEST_ONLY_RES_SHARDING_0]], %[[TEST_ONLY_RES_SHARDING_1]], %[[ADD_RESULT_SHARDING_1]] %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> @@ -90,7 +90,7 @@ func.func @func_result_sharding_returning_op_value(%arg0: tensor<8x16xf32>) // CHECK-LABEL: func @sharding_constraint // CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.func @sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {?}]>]>"}, mhlo.sharding = + // CHECK: stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {?}]>]>"}, mhlo.sharding = %0 = sdy.sharding_constraint %arg0 <@mesh_2, [{"x", ?}, {?}]> : tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -121,14 +121,14 @@ func.func @constant() -> tensor { // CHECK-LABEL: func @inlined_mesh( // CHECK-SAME: %arg0: tensor<32xi32> -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding, [{\\\22a\\\22}]>"}, +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding, [{\22a\22}]>"}, // CHECK-SAME: 
mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}) // CHECK-SAME: -> (tensor<32xi32> {mhlo.sharding = "{maximal device=5}"}) { func.func @inlined_mesh( %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding, [{"a"}]>} ) -> (tensor<32xi32> {sdy.sharding = #sdy.sharding, []>}) { // CHECK-NEXT: %[[SHARDING:.*]] = stablehlo.custom_call @Sharding(%arg0) - // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{\\\22c\\\22}]>]>"}, mhlo.sharding = "{devices=[4]<=[4]}"} + // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{\22c\22}]>]>"}, mhlo.sharding = "{devices=[4]<=[4]}"} // CHECK-NEXT: %[[RESULT_SHARDING:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[SHARDING]]) // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, []>]>"} // CHECK-NEXT: return %[[RESULT_SHARDING]] @@ -146,7 +146,7 @@ func.func @op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> t // CHECK-LABEL: func @sharding_and_op_sharding_rule func.func @sharding_and_op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { // CHECK: stablehlo.custom_call @foo(%arg0, %arg1) {mhlo.frontend_attributes = - // CHECK-SAME: {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22}, {}]>]>" + // CHECK-SAME: {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22}, {}]>]>" // CHECK-SAME: xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"} %0 = stablehlo.custom_call @foo(%arg0, %arg1) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>, diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 89c57ff153ce3a..379023c9db13ec 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: module @multiple_func_result_shardings module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = - "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>, mesh2 = #sdy.mesh<[\\\22a\\\22=1, \\\22b\\\22=4, \\\22c\\\22=1]>}"}} { + "{mesh = #sdy.mesh<[\"a\"=8, \"b\"=8, \"c\"=8]>, mesh2 = #sdy.mesh<[\"a\"=1, \"b\"=4, \"c\"=1]>}"}} { // CHECK: sdy.mesh @mesh = <["a"=8, "b"=8, "c"=8]> // CHECK: sdy.mesh @mesh2 = <["a"=1, "b"=4, "c"=1]> @@ -21,15 +21,15 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-NEXT: return %arg0, %arg1, %arg0, %arg1, %arg1, %arg2 // CHECK-NEXT: } func.func @func_results_with_sharding( - %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22b\\\22}p2]>"}}, - %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p1]>"}}, - %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22c\\\22}p0]>"}} + %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\"b\"}p2]>"}}, + %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\"a\"}p1]>"}}, + %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\"c\"}p0]>"}} ) -> (tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, 
tensor<32xi32>, tensor<32xi32>, tensor<32xi32>) { - %0 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %2 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %4 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"b\"}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"c\"}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %4 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %0, %1, %2, %3, %1, %4 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32> } @@ -45,9 +45,9 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x func.func @func_result_shardings_used_by_other_ops( %arg0: tensor<32xi32>, %arg1: tensor<32xi32> ) -> (tensor<32xi32>, tensor<32xi32>) { - %0 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %2 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"b\"}p2]>]>"}} : (tensor<32xi32>) -> 
tensor<32xi32> + %2 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %3 = stablehlo.add %1, %2 : tensor<32xi32> return %1, %3 : tensor<32xi32>, tensor<32xi32> } @@ -120,16 +120,16 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>}) // CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p4]>}) { func.func @discard_shardings_on_unknown_ops( - %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p0]>"}} + %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\"a\"}p0]>"}} ) -> tensor<32xi32> { // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<32xi32> // CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"a"}p2]> : tensor<32xi32> // CHECK-NEXT: %[[UNKNOWN:.*]] = stablehlo.custom_call @UnknownCustomCall(%[[SHARDING]]) : (tensor<32xi32>) -> tensor<32xi32> // CHECK-NEXT: return %[[UNKNOWN]] - %0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : tensor<32xi32> - %1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p1]>]>"}} : tensor<32xi32> + %1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %3 : tensor<32xi32> } @@ -137,11 +137,11 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding, [{"a"}]>}) // CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding, []>}) { func.func @inlined_mesh( - %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding, [{\\\22a\\\22}]>"}} + %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding, [{\"a\"}]>"}} ) -> tensor<32xi32> { // CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %arg0 , [{"c"}]> : tensor<32xi32> // CHECK-NEXT: return %[[SHARDING]] - %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{\\\22c\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = 
stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{\"c\"}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, []>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %1 : tensor<32xi32> } @@ -154,9 +154,9 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{}]>}, // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{"b"}]>}) { func.func @shardings_with_size_one_axes( - %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22b\\\22}p1], replicated={\\\22c\\\22}>"}}, - %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22a\\\22}p2], replicated={\\\22b\\\22}>"}}, - %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22c\\\22, \\\22b\\\22, ?}p0]>"}} + %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\"b\"}p1], replicated={\"c\"}>"}}, + %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\"a\"}p2], replicated={\"b\"}>"}}, + %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\"c\", \"b\", ?}p0]>"}} ) -> (tensor<32xi32>, tensor<32xi32>) { // CHECK-NEXT: %[[SC1:.*]] = sdy.sharding_constraint %arg0 <@mesh2, [{"b", ?}]> // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SC1]], %[[SC1]] @@ -164,11 +164,11 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-NEXT: %[[SC2:.*]] = sdy.sharding_constraint %arg1 <@mesh2, [{}]> // CHECK-NEXT: return %[[ADD]], %[[SC2]] // CHECK-NEXT: } - %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22, \\\22b\\\22, ?}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\"a\", \"b\", ?}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %1 = stablehlo.add %0, %0 : tensor<32xi32> - %2 = stablehlo.custom_call @Sharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22c\\\22, \\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %4 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22b\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = stablehlo.custom_call @Sharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\"c\", \"a\"}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\"a\"}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %4 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\"b\"}]>]>"}} : (tensor<32xi32>) -> 
tensor<32xi32> return %3, %4 : tensor<32xi32>, tensor<32xi32> } @@ -185,7 +185,7 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-NEXT: } : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %[[MAN_COMP]] %0:2 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<16x8xf32>, tensor<16x8xf32>) - %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh2, [{}, {\\\22b\\\22}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh2, [{}, {\\\22b\\\22, \\\22a\\\22}]>]>"}} : (tensor<16x8xf32>, tensor<16x8xf32>) -> tensor<16x8xf32> + %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh2, [{\"a\"}, {\"b\"}]>, <@mesh2, [{}, {\"b\"}], replicated={\"a\"}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh2, [{}, {\"b\", \"a\"}]>]>"}} : (tensor<16x8xf32>, tensor<16x8xf32>) -> tensor<16x8xf32> %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<16x8xf32>) -> tensor<16x32xf32> return %2 : tensor<16x32xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir index db52b4bc47009e..f3bfcf3b659034 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir @@ -8,9 +8,9 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.shard // CHECK-NEXT: %[[GLOBAL_TO_LOCAL:.*]]:2 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) {has_side_effect = true} : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body(%[[GLOBAL_TO_LOCAL]]#0, %[[GLOBAL_TO_LOCAL]]#1) // CHECK-SAME: {mhlo.frontend_attributes = { - // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", - // CHECK-SAME: xla.sdy.manual_axes = "#sdy", - // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\22a\22}, {\22b\22}]>, <@mesh_0, [{\22b\22}, {}], replicated={\22a\22}>]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\22a\22}, {}], replicated={\22b\22}>]>"}} // CHECK-SAME: : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> // CHECK-NEXT: %[[LOCAL_TO_GLOBAL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) {has_side_effect = true} : (tensor<2x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %[[LOCAL_TO_GLOBAL]] : tensor<8x32xf32> @@ -35,17 +35,17 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy // CHECK-NEXT: %[[GLOBAL_TO_LOCAL_0:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) {has_side_effect = true} : (tensor<8x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: 
%[[SHMAP_0:.*]] = call @local_xla.sdy.manual_computation_body_0(%[[GLOBAL_TO_LOCAL_0]]) // CHECK-SAME: {mhlo.frontend_attributes = { - // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", - // CHECK-SAME: xla.sdy.manual_axes = "#sdy", - // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\22a\22}, {}]>]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\22a\22}, {}]>]>"}} // CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[LOCAL_TO_GLOBAL_0:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP_0]]) {has_side_effect = true} : (tensor<2x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: %[[GLOBAL_TO_LOCAL_1:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%[[LOCAL_TO_GLOBAL_0]]) {has_side_effect = true} : (tensor<8x8xf32>) -> tensor<8x4xf32> // CHECK-NEXT: %[[SHMAP_1:.*]] = call @local_xla.sdy.manual_computation_body_1(%[[GLOBAL_TO_LOCAL_1]]) // CHECK-SAME: {mhlo.frontend_attributes = { - // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", - // CHECK-SAME: xla.sdy.manual_axes = "#sdy", - // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\22b\22}]>]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\22b\22}]>]>"}} // CHECK-SAME: : (tensor<8x4xf32>) -> tensor<8x4xf32 // CHECK-NEXT: %[[LOCAL_TO_GLOBAL_1:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP_1]]) {has_side_effect = true} : (tensor<8x4xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %[[LOCAL_TO_GLOBAL_1]] : tensor<8x8xf32> @@ -64,9 +64,9 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@m // CHECK-NEXT: %[[GLOBAL_TO_LOCAL:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) {has_side_effect = true} : (tensor<4x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_3(%[[GLOBAL_TO_LOCAL]]) // CHECK-SAME: {mhlo.frontend_attributes = { - // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", - // CHECK-SAME: xla.sdy.manual_axes = "#sdy", - // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\22a\22}, {}]>]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\22a\22}, {}]>]>"}} // CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[LOCAL_TO_GLOBAL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) {has_side_effect = true} : (tensor<2x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[LOCAL_TO_GLOBAL]] : tensor<4x8xf32> @@ -85,9 +85,9 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sh // CHECK-NEXT: %[[GLOBAL_TO_LOCAL:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) {has_side_effect = true} : (tensor<4x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_5(%[[GLOBAL_TO_LOCAL]]) // CHECK-SAME: 
{mhlo.frontend_attributes = { - // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", - // CHECK-SAME: xla.sdy.manual_axes = "#sdy", - // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\22a\22}, {}]>]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\22a\22}, {}]>]>"}} // CHECK-SAME: (tensor<2x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[LOCAL_TO_GLOBAL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) {has_side_effect = true} : (tensor<2x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[LOCAL_TO_GLOBAL]] : tensor<4x8xf32> @@ -107,8 +107,8 @@ func.func @manual_computation_no_inputs() -> tensor<4xi64> { // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_6() // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[]>", - // CHECK-SAME: xla.sdy.manual_axes = "#sdy", - // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\22b\22}]>]>"}} // CHECK-SAME: () -> tensor<2xi64> // CHECK-NEXT: %[[LOCAL_TO_GLOBAL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) {has_side_effect = true} : (tensor<2xi64>) -> tensor<4xi64> // CHECK-NEXT: return %[[LOCAL_TO_GLOBAL]] : tensor<4xi64> @@ -124,8 +124,8 @@ func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { // CHECK-NEXT: %[[GLOBAL_TO_LOCAL:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) {has_side_effect = true} : (tensor<4xi64>) -> tensor<2xi64> // CHECK-NEXT: call @local_xla.sdy.manual_computation_body_7(%[[GLOBAL_TO_LOCAL]]) // CHECK-SAME: {mhlo.frontend_attributes = { - // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", - // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\22b\22}]>]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[]>"}} // CHECK-SAME: : (tensor<2xi64>) -> () // CHECK-NEXT: return @@ -154,9 +154,9 @@ func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { // CHECK-NEXT: %[[GLOBAL_TO_LOCAL:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) {has_side_effect = true} : (tensor<2x8xf32>) -> tensor<2x4xf32> // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_2(%[[GLOBAL_TO_LOCAL]]) // CHECK-SAME: {mhlo.frontend_attributes = { -// CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", -// CHECK-SAME: xla.sdy.manual_axes = "#sdy", -// CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} +// CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\22b\22}]>]>", +// CHECK-SAME: xla.sdy.manual_axes = "#sdy", +// CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\22b\22}]>]>"}} // CHECK-SAME: : (tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK-NEXT: stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) {has_side_effect = true} : (tensor<2x4xf32>) -> tensor<2x8xf32> @@ -167,9 +167,9 @@ func.func 
@manual_computation_no_outputs(%arg0: tensor<4xi64>) { // CHECK-NEXT: %[[GLOBAL_TO_LOCAL:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) {has_side_effect = true} : (tensor<2x8xf32>) -> tensor<2x4xf32 // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_4(%[[GLOBAL_TO_LOCAL]]) // CHECK-SAME: {mhlo.frontend_attributes = { -// CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", -// CHECK-SAME: xla.sdy.manual_axes = "#sdy", -// CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} +// CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\22b\22}]>]>", +// CHECK-SAME: xla.sdy.manual_axes = "#sdy", +// CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\22b\22}]>]>"}} // CHECK-SAME: : (tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK-NEXT: %[[LOCAL_TO_GLOBAL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) {has_side_effect = true} : (tensor<2x4xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[LOCAL_TO_GLOBAL]], %[[LOCAL_TO_GLOBAL]] : tensor<2x8xf32> diff --git a/third_party/xla/xla/service/spmd/shardy/utils.cc b/third_party/xla/xla/service/spmd/shardy/utils.cc index 03127482898f65..e0ff8017583062 100644 --- a/third_party/xla/xla/service/spmd/shardy/utils.cc +++ b/third_party/xla/xla/service/spmd/shardy/utils.cc @@ -67,18 +67,12 @@ DictionaryAttr getFuncArgFrontendAttrs(FuncOp funcOp, unsigned int index) { namespace { -mlir::StringAttr getStringAttribute(Attribute attr, mlir::OpBuilder& builder, - bool escapeAttr) { +mlir::StringAttr getStringAttribute(Attribute attr, mlir::OpBuilder& builder) { std::string value; if (auto stringAttr = mlir::dyn_cast(attr)) { - if (!escapeAttr) { - return stringAttr; - } - value = stringAttr.getValue().str(); - } else { - value = mlir::sdy::attributeToString(attr); + return stringAttr; } - return builder.getStringAttr(escapeAttr ? 
absl::CEscape(value) : value); + return builder.getStringAttr(mlir::sdy::attributeToString(attr)); } SmallVector getExistingFrontendAttributes( @@ -96,9 +90,9 @@ SmallVector getExistingFrontendAttributes( } void setFrontendAttribute(SmallVector& existingAttributes, - StringRef name, Attribute value, bool escapeAttr) { + StringRef name, Attribute value) { mlir::OpBuilder builder(value.getContext()); - StringAttr stringValue = getStringAttribute(value, builder, escapeAttr); + StringAttr stringValue = getStringAttribute(value, builder); for (auto* it = existingAttributes.begin(); it != existingAttributes.end(); ++it) { if (it->getName() == name) { @@ -140,19 +134,19 @@ void setFuncArgFrontendAttrs(FuncOp funcOp, unsigned int index, } // namespace void setFrontendAttribute(Operation* op, StringRef name, Attribute value, - bool escapeAttr) { + bool) { SmallVector existingAttributes = getExistingFrontendAttributes(getFrontendAttrs(op), ""); - setFrontendAttribute(existingAttributes, name, value, escapeAttr); + setFrontendAttribute(existingAttributes, name, value); setFrontendAttrs(op, existingAttributes); } void setFrontendAttribute(FuncOp funcOp, StringRef name, Attribute value, - int64_t argNum, bool escapeAttr) { + int64_t argNum) { SmallVector existingAttributes = getExistingFrontendAttributes(getFuncArgFrontendAttrs(funcOp, argNum), ""); - setFrontendAttribute(existingAttributes, name, value, escapeAttr); + setFrontendAttribute(existingAttributes, name, value); setFuncArgFrontendAttrs(funcOp, argNum, existingAttributes); } diff --git a/third_party/xla/xla/service/spmd/shardy/utils.h b/third_party/xla/xla/service/spmd/shardy/utils.h index 7e7f2af813cb57..54134ce9986ed9 100644 --- a/third_party/xla/xla/service/spmd/shardy/utils.h +++ b/third_party/xla/xla/service/spmd/shardy/utils.h @@ -49,6 +49,7 @@ mlir::DictionaryAttr getFuncArgFrontendAttrs(mlir::func::FuncOp funcOp, // Adds `name` into the frontend attributes of `op` with value `value`. If // `name` already exists, it will be overwritten. Note that `value` will be // turned into a `StringAttr`. +// TODO(tomnatan): cleanup `escapeAttr` void setFrontendAttribute(mlir::Operation* op, mlir::StringRef name, mlir::Attribute value, bool escapeAttr = true); @@ -56,8 +57,7 @@ void setFrontendAttribute(mlir::Operation* op, mlir::StringRef name, // with value `value`. If `name` already exists, it will be overwritten. Note // that `value` will be turned into a `StringAttr`. void setFrontendAttribute(mlir::func::FuncOp funcOp, mlir::StringRef name, - mlir::Attribute value, int64_t argNum, - bool escapeAttr = true); + mlir::Attribute value, int64_t argNum); // Remove `attributeName` from the frontend attributes of `op`. 
void removeFrontendAttribute(mlir::Operation* op, From c4a23576372b3a8623c9c74d71654f17f6dac06c Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Thu, 10 Apr 2025 05:09:12 -0700 Subject: [PATCH 0494/1324] Add CTA size (num_warps) selection logic to dynamic search space PiperOrigin-RevId: 745982223 --- .../gpu/autotuning/dot_search_space.cc | 47 +++++++++-- .../service/gpu/autotuning/dot_search_space.h | 12 +++ .../gpu/autotuning/dot_search_space_test.cc | 82 +++++++++++++------ 3 files changed, 108 insertions(+), 33 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc index f9982d35bebf62..98b70dc615d96e 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc @@ -135,15 +135,16 @@ std::vector TritonDotFusionSearchSpace::GenerateConfigs( ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddOutputTilings); EliminateLowOccupancyConfigs(configs); + ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddCtaSizeParameter); + std::vector result; result.reserve(configs.size()); - for (auto& config_with_notes : configs) { + for (ConfigWithNotes& config_with_notes : configs) { // TODO: b/404470821 - Implement this properly rather than hardcoding the // config parameters. - auto& config = config_with_notes.config; + TritonGemmConfig& config = config_with_notes.config; config.block_k = 64; config.num_stages = 3; - config.num_warps = 4; config.num_ctas = 1; result.push_back(config); } @@ -215,6 +216,19 @@ int64_t TritonDotFusionSearchSpace::GetNumResultTiles( CeilOfRatio(rhs_parallel_size_, output_tile.rhs_dim); } +int TritonDotFusionSearchSpace::GetMaxWarpsPerCta(OutputTile tile) const { + // A single mma instruction is of output shape at least 16x8 (the same + // also holds for wgmma: the warp-group level instruction is at least + // 64x8, and split 4-ways across the 4 warps in the group). 
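+  // As a worked example of the arithmetic below: a 64x32 output tile covers
+  // at most ceil(64/16) * ceil(32/8) = 4 * 4 = 16 warps, so any larger CTA
+  // size would leave warps with no MMA sub-tile to compute.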
+ constexpr OutputTile kMmaSubTile = {16, 8}; + const int max_warps = device_description_.threads_per_block_limit() / + device_description_.threads_per_warp(); + const int lhs_warps = CeilOfRatio(tile.lhs_dim, kMmaSubTile.lhs_dim); + const int rhs_warps = CeilOfRatio(tile.rhs_dim, kMmaSubTile.rhs_dim); + return std::max(min_warps_per_cta_, + std::min(max_warps, lhs_warps * rhs_warps)); +} + int TritonDotFusionSearchSpace::GetMaxContractingSplit( OutputTile output_tile) const { const int64_t desired_num_ctas = desired_total_warps_ / min_warps_per_cta_; @@ -275,9 +289,9 @@ void TritonDotFusionSearchSpace::AddOutputTilings( CHECK_GT(config.config.split_k, 0) << "Need config with contracting split already set."; const int split = config.config.split_k; - auto new_config = config; - auto& m = new_config.config.block_m; - auto& n = new_config.config.block_n; + ConfigWithNotes new_config = config; + int& m = new_config.config.block_m; + int& n = new_config.config.block_n; for (m = min_out_tile_.lhs_dim; m <= max_out_tile_.lhs_dim; m *= 2) { for (n = min_out_tile_.rhs_dim; n <= max_out_tile_.rhs_dim; n *= 2) { OutputTile tile = {m, n}; @@ -298,10 +312,29 @@ void TritonDotFusionSearchSpace::AddOutputTilings( } } +void TritonDotFusionSearchSpace::AddCtaSizeParameter( + const ConfigWithNotes& config, + std::vector& updated_configs) { + ConfigWithNotes new_config = config; + int tile_rows = config.config.block_m; + int tile_cols = config.config.block_n; + int& warps = new_config.config.num_warps; + CHECK_GT(tile_rows * tile_cols, 0) + << "Need configs with output tilings determined."; + int max_warps = GetMaxWarpsPerCta({tile_rows, tile_cols}); + VLOG(5) << "Computing max_warps: For output_tile = " << tile_rows << "x" + << tile_cols + << " and (wg)mma instruction shape, max_warps = " << max_warps; + for (warps = min_warps_per_cta_; warps <= max_warps; warps *= 2) { + VLOG(10) << "Adding CTA size parameter: config = " << new_config.ToString(); + updated_configs.push_back(new_config); + } +} + void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs( std::vector& configs) { CHECK(!configs.empty()); - auto last_config = configs.back(); // Config with the largest split. + ConfigWithNotes last_config = configs.back(); // Largest split. auto has_too_few_tiles = [](const ConfigWithNotes& config) { if (config.not_enough_tiles) { VLOG(10) << "Skipping due to fewer tiles than cores, config = " diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h index fdeccf57936db4..d6c7be718acf08 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h @@ -106,6 +106,11 @@ class TritonDotFusionSearchSpace { // splitting the contracting dimension for a given output tile. int64_t GetNumResultTiles(OutputTile output_tile) const; + // Computes how many warps per Cooperative Thread Array (aka. CTA, aka. CUDA + // block) is reasonable for the given output tile and restrictions on + // instruction shape. + int GetMaxWarpsPerCta(OutputTile output_tile) const; + // Computes the maximum sensible split in the contracting dimension // (split_k) to sufficiently occupy all available cores when using the given // output tile. @@ -122,6 +127,13 @@ class TritonDotFusionSearchSpace { void AddOutputTilings(const ConfigWithNotes& config, std::vector& updated_configs); + // Finds all promising values for the Cooperative Thread Array (aka. CTA, aka. 
+ // CUDA block) size (num_warps), based on `config` with already determined + // output tiling and appends them to `updated_configs`. Each config in the + // input list might yield zero or more configs in the output. + void AddCtaSizeParameter(const ConfigWithNotes& config, + std::vector& updated_configs); + // Removes configs that are marked with `not_enough_tiles` from the list. If // this results in an empty list, adds a config that should be the most // optimal one even though it does not occupy all cores. diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc index ccc4fb178ee635..b85796de942d42 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc @@ -84,6 +84,8 @@ class DotSearchSpaceTest : public HloHardwareIndependentTestBase { // Using H100 numbers as the most relevant example here. device_description_.set_registers_per_block_limit(64 * 1024); device_description_.set_core_count(132); + device_description_.set_threads_per_block_limit(1024); + device_description_.set_threads_per_warp(32); } absl::StatusOr> GetDefaultDotModule( @@ -112,8 +114,9 @@ ENTRY e { }; TEST_F(DotSearchSpaceTest, SerializesSearchSpace) { - TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule()); - auto search_space = MakeSearchSpace(module.get()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_EQ(search_space.Serialize(), "problem_size_BxMxNxKxE: 1x1024x1024x1024x16 " @@ -122,16 +125,18 @@ TEST_F(DotSearchSpaceTest, SerializesSearchSpace) { } TEST_F(DotSearchSpaceTest, ReturnsValidConfigList) { - TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule()); - auto search_space = MakeSearchSpace(module.get()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT(search_space.GenerateConfigs(), AllOf(Not(IsEmpty()), Each(IsValidConfig()))); } TEST_F(DotSearchSpaceTest, HonorsForcedContractingSplit) { - TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule()); - auto search_space = MakeSearchSpace(module.get()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT( search_space.GenerateConfigs(/*force_contracting_split=*/2), @@ -139,29 +144,30 @@ TEST_F(DotSearchSpaceTest, HonorsForcedContractingSplit) { } TEST_F(DotSearchSpaceTest, ConsidersContractingSplitForSmallOutputSize) { - TF_ASSERT_OK_AND_ASSIGN(auto module, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetDefaultDotModule(/*lhs_parallel_dim=*/16, /*rhs_parallel_dim=*/16, /*contracting_dim=*/1024)); - auto search_space = MakeSearchSpace(module.get()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT(search_space.GenerateConfigs(), Contains(SplitKIs(Ge(2)))); } TEST_F(DotSearchSpaceTest, LimitsContractingSplitForSmallerContractingSize) { - TF_ASSERT_OK_AND_ASSIGN(auto module, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetDefaultDotModule(/*lhs_parallel_dim=*/16, /*rhs_parallel_dim=*/16, /*contracting_dim=*/32)); - auto search_space = MakeSearchSpace(module.get()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT(search_space.GenerateConfigs(), 
AllOf(Not(IsEmpty()), Each(SplitKIs(Le(2))))); } TEST_F(DotSearchSpaceTest, FindsGoodDataReuseOutputTiles) { - TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule()); - auto search_space = MakeSearchSpace(module.get()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT(search_space.GenerateConfigs(), Contains(AllOf(BlockMIs(Ge(32)), BlockNIs(Ge(32)))).Times(Ge(2))); @@ -169,10 +175,10 @@ TEST_F(DotSearchSpaceTest, FindsGoodDataReuseOutputTiles) { TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForLowOccupancyProblem) { TF_ASSERT_OK_AND_ASSIGN( - auto module, + std::unique_ptr module, GetDefaultDotModule(/*lhs_parallel_dim=*/4096, /*rhs_parallel_dim=*/16, /*contracting_dim=*/4096)); - auto search_space = MakeSearchSpace(module.get()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT(search_space.GenerateConfigs(), Contains(AllOf(BlockMIs(Ge(32)), SplitKIs(Ge(2))))); @@ -181,10 +187,10 @@ TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForLowOccupancyProblem) { TEST_F(DotSearchSpaceTest, FindsUniqueOccupancyMaximizingTilingForSmallProblem) { TF_ASSERT_OK_AND_ASSIGN( - auto module, + std::unique_ptr module, GetDefaultDotModule(/*lhs_parallel_dim=*/32, /*rhs_parallel_dim=*/32, /*contracting_dim=*/32)); - auto search_space = MakeSearchSpace(module.get()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT(search_space.GenerateConfigs(), AllOf(SizeIs(1), Each(AllOf(BlockMIs(Eq(16)), BlockNIs(Eq(16)), @@ -192,8 +198,9 @@ TEST_F(DotSearchSpaceTest, } TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForForcedHugeSplit) { - TF_ASSERT_OK_AND_ASSIGN(auto module, GetDefaultDotModule()); - auto search_space = MakeSearchSpace(module.get()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT( search_space.GenerateConfigs(/*force_contracting_split=*/128), @@ -201,21 +208,21 @@ TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForForcedHugeSplit) { } TEST_F(DotSearchSpaceTest, PadsTilesForSmallParallelDimension) { - TF_ASSERT_OK_AND_ASSIGN(auto module, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetDefaultDotModule(/*lhs_parallel_dim=*/1024, /*rhs_parallel_dim=*/15, /*contracting_dim=*/1024)); - auto search_space = MakeSearchSpace(module.get()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT(search_space.GenerateConfigs(), Contains(BlockNIs(Eq(16)))); } TEST_F(DotSearchSpaceTest, HonorsMinimumOutputTileSizeForTinyProblem) { - TF_ASSERT_OK_AND_ASSIGN(auto module, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetDefaultDotModule(/*lhs_parallel_dim=*/12, /*rhs_parallel_dim=*/8, /*contracting_dim=*/16)); - auto search_space = MakeSearchSpace(module.get()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT( search_space.GenerateConfigs(), @@ -224,13 +231,13 @@ TEST_F(DotSearchSpaceTest, HonorsMinimumOutputTileSizeForTinyProblem) { TEST_F(DotSearchSpaceTest, AssignsEnoughWarpsPerScheduler) { TF_ASSERT_OK_AND_ASSIGN( - auto module, + std::unique_ptr module, GetDefaultDotModule(/*lhs_parallel_dim=*/1024, /*rhs_parallel_dim=*/512, /*contracting_dim=*/1024)); - auto search_space = MakeSearchSpace(module.get()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); - // 1024x512 elements / 32x32 
elements/block = 32x16 blocks = 512 blocks. - // 512 blocks * 4 warps/block = 2048 warps. + // 1024x512 elements / 32x32 elements/CTA = 32x16 blocks = 512 CTAs. + // 512 CTAs * 4 warps/CTA = 2048 warps. // 132 cores * 4 schedulers/core * 5 desired warps/scheduler = 2640 desired // warps. // ceil(2640 desired warps / 2048 warps) = ceil(1.3) = 2 desired split @@ -239,5 +246,28 @@ TEST_F(DotSearchSpaceTest, AssignsEnoughWarpsPerScheduler) { NumWarpsIs(Eq(4)), SplitKIs(Eq(2))))); } +TEST_F(DotSearchSpaceTest, DoesNotBreakCTASizeLimits) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule(/*lhs_parallel_dim=*/1024 * 16, + /*rhs_parallel_dim=*/1024 * 16)); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + + EXPECT_THAT(search_space.GenerateConfigs(), + AllOf(Not(IsEmpty()), Each(NumWarpsIs(Le(32))))); +} + +TEST_F(DotSearchSpaceTest, ConsidersAppropriateCTASizeForTileSize) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule(/*lhs_parallel_dim=*/4096, + /*rhs_parallel_dim=*/4096)); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + + EXPECT_THAT(search_space.GenerateConfigs(), + AllOf(Contains(AllOf(BlockMIs(Eq(64)), BlockNIs(Eq(32)), + NumWarpsIs(Eq(4)))), + Contains(AllOf(BlockMIs(Eq(128)), BlockNIs(Eq(32)), + NumWarpsIs(Eq(8)))))); +} + } // namespace } // namespace xla::gpu From c060c02d7f5ca8d6163d4b5d7ca6048b37662d74 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Thu, 10 Apr 2025 05:33:13 -0700 Subject: [PATCH 0495/1324] [XLA:GPU/TMA] Fixes to TritonXLA Tile lowering to TMA. - tt.addptr should be emitted to get the correct base pointer either way. - Triton also expects the TMA arguments to be !tt.ptr. While this doesn't seem to affect the lowering, it's better to follow their convention in case this changes in the future. 
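Concretely, the lowering now materializes the base pointer with tt.addptr
before wrapping it in a descriptor, so the emitted sequence is tt.addptr
followed by tt.reinterpret_tensor_descriptor, as reflected in the updated
CHECK-TMA lines below.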
PiperOrigin-RevId: 745988526 --- .../triton_xla_extract_insert_to_triton.mlir | 10 ++- ...riton_xla_extract_insert_to_triton_pass.cc | 80 +++++++++++-------- 2 files changed, 52 insertions(+), 38 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir index a4d1ce9e9f59b5..673f15df5cfc25 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir @@ -36,10 +36,12 @@ func.func @lower_tile_extract_insert(%arg0: tensor<512x128xbf16>, // CHECK: tt.return // CHECK-TMA-LABEL:tt.func @lower_tile_extract_insert -// CHECK-TMA-SAME: %[[ARG_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.nv_tma_desc = 1 : i32, tt.tma_descriptor = #triton_xla.tma_descriptor}, -// CHECK-TMA-SAME: %[[ARG_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.nv_tma_desc = 1 : i32, tt.tma_descriptor = #triton_xla.tma_descriptor} -// CHECK-TMA: %[[DESC_0:.*]] = tt.reinterpret_tensor_descriptor %[[ARG_0]] -// CHECK-TMA: %[[DESC_1:.*]] = tt.reinterpret_tensor_descriptor %[[ARG_1]] +// CHECK-TMA-SAME: %[[ARG_0:.*]]: !tt.ptr {tt.nv_tma_desc = 1 : i32, tt.tma_descriptor = #triton_xla.tma_descriptor}, +// CHECK-TMA-SAME: %[[ARG_1:.*]]: !tt.ptr {tt.nv_tma_desc = 1 : i32, tt.tma_descriptor = #triton_xla.tma_descriptor} +// CHECK-TMA: %[[ADDPTR_0:.*]] = tt.addptr %[[ARG_0]] +// CHECK-TMA: %[[DESC_0:.*]] = tt.reinterpret_tensor_descriptor %[[ADDPTR_0]] +// CHECK-TMA: %[[ADDPTR_1:.*]] = tt.addptr %[[ARG_1]] +// CHECK-TMA: %[[DESC_1:.*]] = tt.reinterpret_tensor_descriptor %[[ADDPTR_1]] // CHECK-TMA: %[[LOAD:.*]] = tt.descriptor_load %[[DESC_0]] // CHECK-TMA: tt.descriptor_store %[[DESC_1]][{{.*}}], %[[LOAD]] // CHECK-TMA: tt.return diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc index 0e9dcc6d87cfb4..6c8dc01a829725 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc @@ -75,6 +75,14 @@ PointerType GetTensorPtrType(::xla::EmitterLocOpBuilder& builder, Type type) { mlir::NVVM::kGlobalMemorySpace); } +PointerType GetTensorPtrTypeForTma(::xla::EmitterLocOpBuilder& builder) { + // Triton frontend is passing zero in the address space. This doesn't map to + // anything meaningful in NVVM dialect. Setting it to be consistent with + // Triton. 
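+  // With this, TMA arguments lower to the plain byte pointer (!tt.ptr) that
+  // Triton's frontend produces, instead of a global-memory-space pointer.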
+ return PointerType::get(xgt::StorageType(builder, builder.getI8Type()), + /*addrspace=*/0); +} + TensorDescType GetTensorDescPtrType(::xla::EmitterLocOpBuilder& builder, RankedTensorType type) { return TensorDescType::get(builder.getContext(), type); @@ -275,13 +283,21 @@ struct RewriteFuncOp : mlir::OpRewritePattern { for (auto&& [index, operand_type] : llvm::enumerate(new_operand_types)) { mlir::BlockArgument func_arg = op.getArgument(index); - // !tt.ptr<> -> tensor - auto cast_to_orig_type = builder.create( - operand_type, func_arg); + mlir::UnrealizedConversionCastOp cast_to_orig_type; + if (op.getArgAttr(index, "tt.nv_tma_desc")) { + operand_type = GetTensorPtrTypeForTma(builder); + // !tt.ptr -> tensor + cast_to_orig_type = builder.create( + operand_type, func_arg); + } else { + // !tt.ptr<> -> tensor + cast_to_orig_type = builder.create( + operand_type, func_arg); + operand_type = GetTensorPtrType( + builder, mlir::cast(operand_type).getElementType()); + } func_arg.replaceAllUsesExcept(cast_to_orig_type.getResult(0), cast_to_orig_type); - operand_type = GetTensorPtrType( - builder, mlir::cast(operand_type).getElementType()); } // Replace the function arguments with the new types. @@ -310,6 +326,10 @@ struct RewriteFuncOp : mlir::OpRewritePattern { op.getLoc(), op.getName(), new_function_type, attrs, arg_attrs); for (int i = 0; i < new_func.getNumArguments(); ++i) { + // TMA arguments don't require tt.divisibility. + if (op.getArgAttr(i, "tt.nv_tma_desc")) { + continue; + } new_func.setArgAttr(i, "tt.divisibility", builder.getIntegerAttr(builder.getI32Type(), 16)); } @@ -342,9 +362,27 @@ struct RewriteTile : mlir::OpRewritePattern { TileOp op, mlir::PatternRewriter& rewriter) const override { ::xla::EmitterLocOpBuilder builder(op.getLoc(), rewriter); auto tiled_tensor_type = op.getTiledTensor().getType(); + bool can_use_tma = CanUseTMA(builder, tma_enabled, *device_description, + tiled_tensor_type, op.getTensor()); + + // can_use_tma ? "tensor -> !tt.ptr" otherwise "tensor -> !tt.ptr<>" + Type ptr_type = + can_use_tma ? GetTensorPtrTypeForTma(builder) + : GetTensorPtrType( + builder, op.getTensor().getType().getElementType()); + auto cast_to_tensor_ptr_type = + builder.create(ptr_type, + op.getTensor()); - if (CanUseTMA(builder, tma_enabled, *device_description, tiled_tensor_type, - op.getTensor())) { + auto linear_offset = ComputeLinearOffset(builder, tiled_tensor_type, + op.getOffsets(), op.getLayout()); + auto ptr = builder + .create( + cast_to_tensor_ptr_type.getResult(0).getType(), + cast_to_tensor_ptr_type.getResult(0), linear_offset) + .getResult(); + + if (can_use_tma) { // Add TMA attributes to the corresponding argument in the function. 
auto block_arg = mlir::dyn_cast(op.getTensor()); auto func_op = @@ -361,21 +399,12 @@ struct RewriteTile : mlir::OpRewritePattern { tiled_tensor_type.getTileShape(), tiled_tensor_type.getElementType().getIntOrFloatBitWidth() / 8)); - // tensor -> !tt.ptr<> - auto cast_to_tensor_ptr_type = - builder - .create( - GetTensorPtrType(builder, - op.getTensor().getType().getElementType()), - op.getTensor()) - .getResult(0); - auto reinterpret_tensor_desc = builder .create( mlir::triton::TensorDescType::get( builder.getContext(), tiled_tensor_type.getTileType()), - cast_to_tensor_ptr_type) + ptr) .getResult(); // !tt.tensordesc -> tiled_tensor @@ -388,23 +417,6 @@ struct RewriteTile : mlir::OpRewritePattern { return mlir::success(); } - // tensor -> !tt.ptr<> - auto cast_to_tensor_ptr_type = - builder - .create( - GetTensorPtrType(builder, - op.getTensor().getType().getElementType()), - op.getTensor()) - .getResult(0); - - auto linear_offset = ComputeLinearOffset(builder, tiled_tensor_type, - op.getOffsets(), op.getLayout()); - - auto ptr = builder - .create(cast_to_tensor_ptr_type.getType(), - cast_to_tensor_ptr_type, linear_offset) - .getResult(); - // Only emit make_tensor_ptr if the input is not a scalar. auto tile_shape = tiled_tensor_type.getTileShape(); if (!tile_shape.empty()) { From 97c65bec293151cbe5bb9ab1d828af5a879f9a35 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Thu, 10 Apr 2025 06:06:51 -0700 Subject: [PATCH 0496/1324] [XLA:GPU] Update default autotuner configs PiperOrigin-RevId: 745998088 --- .../autotuning/gemm_fusion_autotuner_cuda.cc | 45 ++++++++----------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc index da0cd4ffc3a691..860f44f560be82 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc @@ -84,33 +84,24 @@ std::vector GemmFusionAutotunerImpl::GetDefaultTritonConfigs() using Config = TritonGemmConfig; std::vector configs = { - Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), - Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), - Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4), - Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4), - Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4), - Config(64, 32, 64, 1, 2, 8), Config(128, 256, 32, 1, 3, 8), - Config(256, 128, 32, 1, 3, 8), Config(256, 64, 32, 1, 4, 4), - Config(64, 256, 32, 1, 4, 4), Config(128, 64, 32, 1, 4, 4), - Config(64, 128, 32, 1, 4, 4), Config(256, 128, 128, 1, 3, 8), - Config(256, 64, 128, 1, 4, 4), Config(64, 256, 128, 1, 4, 4), - Config(128, 128, 128, 1, 4, 4), Config(128, 64, 64, 1, 4, 4), - Config(64, 128, 64, 1, 4, 4), Config(128, 32, 64, 1, 4, 4), - Config(64, 32, 64, 1, 4, 4), Config(32, 128, 32, 1, 4, 4), - Config(128, 128, 32, 1, 4, 4), Config(16, 16, 256, 1, 3, 4), - Config(128, 128, 64, 2, 1, 8), Config(64, 64, 64, 1, 2, 4), - Config(16, 64, 256, 8, 1, 4), Config(256, 256, 128, 1, 3, 8)}; - auto cu_compute_capability = - std::get(GetComputeCapability()); - if (cu_compute_capability.IsAtLeastHopper()) { - absl::c_copy( - std::vector{ - Config(16, 32, 32, 8, 1, 2), - Config(16, 64, 128, 8, 1, 4), - Config(16, 64, 128, 16, 3, 4), - }, - std::back_inserter(configs)); - } + Config(16, 16, 64, 1, 4, 2), Config(16, 16, 128, 1, 4, 4), + Config(16, 16, 128, 128, 4, 2), Config(16, 16, 
128, 16, 1, 2), + Config(16, 256, 16, 1, 1, 2), Config(32, 32, 128, 16, 1, 4), + Config(32, 256, 32, 1, 3, 4), Config(32, 256, 32, 16, 3, 8), + Config(64, 16, 32, 1, 4, 2), Config(64, 16, 32, 16, 4, 2), + Config(64, 16, 64, 1, 1, 4), Config(64, 16, 64, 4, 3, 2), + Config(64, 16, 64, 16, 4, 4), Config(64, 16, 128, 1, 4, 2), + Config(64, 16, 128, 16, 4, 4), Config(64, 32, 32, 1, 4, 4), + Config(64, 32, 64, 16, 3, 4), Config(64, 32, 128, 1, 3, 2), + Config(64, 32, 128, 128, 2, 4), Config(64, 64, 32, 1, 4, 4), + Config(64, 64, 64, 1, 4, 4), Config(64, 64, 64, 4, 4, 4), + Config(64, 64, 128, 16, 3, 4), Config(64, 64, 256, 16, 4, 8), + Config(64, 128, 16, 1, 4, 2), Config(64, 128, 64, 1, 3, 4), + Config(64, 128, 128, 8, 1, 4), Config(64, 256, 32, 1, 4, 4), + Config(128, 16, 32, 8, 4, 2), Config(128, 16, 64, 16, 3, 2), + Config(128, 16, 64, 16, 1, 4), Config(128, 32, 32, 8, 4, 2), + Config(128, 128, 32, 8, 4, 8), Config(128, 256, 32, 1, 4, 8), + Config(128, 256, 64, 1, 4, 8)}; return configs; } From 1d57984ce53f1b4ff548391db9ee2f28934324c2 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 10 Apr 2025 07:14:21 -0700 Subject: [PATCH 0497/1324] PR #24774: [GPU] Fix se_gpu_pjrt_client_test in OSS. Imported from GitHub PR https://github.com/openxla/xla/pull/24774 - Make the GPU backend work - Use the OSS macros Copybara import of the project: -- ca956eb12ec22356e8cb9fe172ede2ef83bdff33 by Ilia Sergachev : [GPU] Fix se_gpu_pjrt_client_test in OSS. - Make the GPU backend work - Use the OSS macros Merging this change closes #24774 PiperOrigin-RevId: 746017593 --- third_party/xla/xla/pjrt/gpu/BUILD | 14 ++++---- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 34 +++++++++++++------ 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index b4b0367d1e8b40..317ce13bec7a7c 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -1,6 +1,5 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("//xla:xla.default.bzl", "xla_cc_test") load("//xla/pjrt/gpu:package_groups.bzl", "xla_gpu_internal_packages") load("//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load("//xla/tests:build_defs.bzl", "xla_test") @@ -162,16 +161,16 @@ cc_library( ]), ) -xla_cc_test( +xla_test( name = "se_gpu_pjrt_client_test", srcs = ["se_gpu_pjrt_client_test.cc"], - tags = [ - "gpu", + backend_tags = {"gpu": [ + "multi_gpu_h100", "no_oss", "noasan", "nomsan", - "requires-gpu-nvidia:2", - ], + ]}, + backends = ["gpu"], deps = [ ":gpu_topology", ":gpu_topology_proto_cc", @@ -203,13 +202,13 @@ xla_cc_test( "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/pjrt/profiling:device_time_measurement", "//xla/pjrt/profiling/test_util:mock_device_time_measurement", - "//xla/service:gpu_plugin", "//xla/service:platform_util", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tests:literal_test_util", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", "//xla/tsl/platform:statusor", "//xla/tsl/platform:subprocess", "@com_google_absl//absl/container:flat_hash_map", @@ -224,6 +223,7 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", diff --git 
a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 03bac8cd5358ca..e723f9e6f170e1 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" @@ -76,6 +77,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/subprocess.h" #include "xla/types.h" @@ -86,6 +88,7 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/mem.h" +#include "tsl/platform/platform.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" @@ -241,7 +244,7 @@ ENTRY main.5 { ASSERT_EQ(result.size(), 1); ASSERT_EQ(result[0].size(), 1); - EXPECT_OK(result[0][0]->GetReadyFuture().Await()); + TF_EXPECT_OK(result[0][0]->GetReadyFuture().Await()); } #endif @@ -2188,12 +2191,12 @@ TEST(TpuLocalClientTest, RawBuffer) { *client->addressable_devices()[0]->default_memory_space(), /*device_layout=*/nullptr) .value(); - ASSERT_OK(buffer->GetReadyFuture().Await()); - ASSERT_OK_AND_ASSIGN(auto raw_buffer, - PjRtRawBuffer::CreateRawAliasOfBuffer(buffer.get())); + TF_ASSERT_OK(buffer->GetReadyFuture().Await()); + TF_ASSERT_OK_AND_ASSIGN(auto raw_buffer, + PjRtRawBuffer::CreateRawAliasOfBuffer(buffer.get())); ASSERT_EQ(raw_buffer->memory_space(), buffer->memory_space()); - ASSERT_OK_AND_ASSIGN(size_t on_device_size, - raw_buffer->GetOnDeviceSizeInBytes()); + TF_ASSERT_OK_AND_ASSIGN(size_t on_device_size, + raw_buffer->GetOnDeviceSizeInBytes()); ASSERT_EQ(on_device_size, 1024); std::vector data2(256); @@ -2201,8 +2204,8 @@ TEST(TpuLocalClientTest, RawBuffer) { auto* dst1 = tsl::port::AlignedMalloc(1024, 1024); auto* dst2 = tsl::port::AlignedMalloc(1024, 1024); memcpy(dst1, data2.data(), sizeof(int32_t) * data2.size()); - EXPECT_OK(raw_buffer->CopyRawHostToDevice(dst1, 0, 1024).Await()); - EXPECT_OK(raw_buffer->CopyRawDeviceToHost(dst2, 0, 1024).Await()); + TF_EXPECT_OK(raw_buffer->CopyRawHostToDevice(dst1, 0, 1024).Await()); + TF_EXPECT_OK(raw_buffer->CopyRawDeviceToHost(dst2, 0, 1024).Await()); EXPECT_EQ(absl::MakeSpan(reinterpret_cast(dst2), 256), data2); tsl::port::AlignedFree(dst1); @@ -2234,8 +2237,6 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) { std::string cache_dir; CHECK(tsl::Env::Default()->LocalTempFilename(&cache_dir)); - tsl::setenv("TF_CPP_VMODULE", "gemm_fusion_autotuner=1", /*overwrite=*/true); - // Compile twice to test both empty and non-empty disk cache. for (int iteration = 0; iteration < 2; ++iteration) { tsl::SubProcess child[kNumNodes]; @@ -2251,6 +2252,11 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) { argv.push_back(absl::StrFormat("--num_nodes_using_cache=%d", param.num_nodes_using_cache)); argv.push_back(absl::StrFormat("--cache_dir=%s", cache_dir)); + // Test relies on VLOG(1) messages. Enable VLOG(1) in Non-OSS. 
+ if (!tsl::kIsOpenSource) { + argv.push_back("--vmodule=gemm_fusion_autotuner=1"); + argv.push_back("--logtostderr"); + } child[node_id].SetProgram("/proc/self/exe", argv); child[node_id].SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE); child[node_id].SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); @@ -2279,6 +2285,8 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) { "Rank %d / %d: autotuning %d / 1 fusions", node_id, param.num_active_nodes, num_fusions_to_autotune))); } else { + stderr_str = absl::StrReplaceAll( + stderr_str, {{"sharded_autotuning_test", "sharded_test"}}); EXPECT_THAT(stderr_str, Not(HasSubstr("autotuning"))); } } @@ -2290,6 +2298,11 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id, const int num_nodes_using_cache, absl::string_view cache_dir, bool use_xla_computation) { + if (tsl::kIsOpenSource) { + // Test relies on VLOG(1) messages. Enable VLOG(1) in OSS. + tsl::setenv("TF_CPP_VMODULE", "gemm_fusion_autotuner=1", + /*overwrite=*/true); + } std::unique_ptr service; if (node_id == 0) { TF_ASSIGN_OR_RETURN( @@ -2384,6 +2397,7 @@ INSTANTIATE_TEST_SUITE_P( {false, 2, 1}, {false, 2, 2}}), ShardedAutotuningTestInfo::Name); + } // namespace } // namespace xla From 4a9af2e64f3b145191a6e755c2c7f399d0533a63 Mon Sep 17 00:00:00 2001 From: Nikita Putikhin Date: Thu, 10 Apr 2025 07:16:40 -0700 Subject: [PATCH 0498/1324] Reverts c300b7dece76084bc25ff1a5c4e2a47cab6f57b0 PiperOrigin-RevId: 746018394 --- third_party/xla/xla/service/rendezvous.h | 106 ++++++++++-------- .../xla/xla/service/rendezvous_test.cc | 20 ++-- 2 files changed, 67 insertions(+), 59 deletions(-) diff --git a/third_party/xla/xla/service/rendezvous.h b/third_party/xla/xla/service/rendezvous.h index dbc2fae612fa14..86488963e16744 100644 --- a/third_party/xla/xla/service/rendezvous.h +++ b/third_party/xla/xla/service/rendezvous.h @@ -214,15 +214,15 @@ struct RendezvousState : public RendezvousStateSynchronization { // Rendezvous state ownership: // // (1) When rendezvous participant initiates a rendezvous with a particular key -// we create a new state for it, keep it in a map as weak pointer for -// tracking and return a shared pointer to the caller. +// we create a new state for it, keep it in a map for tracking and return a +// shared pointer to the caller. // // (2) When rendezvous participant joins in-progress rendezvous it gets back // a shared pointer that is copied from a tracking map. // -// (3) When all rendezvous participants complete the rendezvous, shared pointers -// are destructed and the tracking map will have an expired weak pointer, -// that will be lazily garbage collected by the next rendezvous. +// (3) When the last rendezvous participant computes the result it completes the +// rendezvous and removes a shared pointer to a state. Remaining shared +// pointers destructed when all participants are notified. // // This process guarantees that all completed rendezvous are removed from a map // and a map has records only for rendezvous in progress. @@ -233,25 +233,56 @@ class RendezvousMap { std::shared_ptr Join(const K& key, size_t num_threads) { absl::MutexLock lock(&mutex_); + std::shared_ptr& state = state_[key]; - // Erase expired rendezvous from the map. - absl::erase_if(state_, [](const auto& e) { return e.second.expired(); }); + // Join an in-progress rendezvous. + if (state) return state; - std::weak_ptr& in_progress = state_[key]; + // Join a newly created rendezvous. 
+ return state = std::make_shared(num_threads); + } + + template + void Complete(const K& key, Result&& result) { + std::shared_ptr state = [&] { + absl::MutexLock lock(&mutex_); + + // Extract state from the map so we can immediately start a new round of + // rendezvous with the same key. A state for previous rendezvous will be + // destructed with the last copy of a shared pointer. + std::shared_ptr state = state_.extract(key).mapped(); + + // Check that we have have exactly the number of participants we expected: + // +1 reference for all participants and a +1 reference we extracted. + CHECK_EQ(state.use_count(), 1 + state->values.size()); // NOLINT - // Try to join an in-progress rendezvous for a given key. - if (std::shared_ptr joined = in_progress.lock()) { - return joined; + return state; + }(); + + // We notify awaiting participants without holding a rendezvous map lock, as + // the rendezvous callback might be an expensive operation and might block + // the progress of concurrent rendezvous for other keys. + + // Publish rendezvous result to all participants. + if constexpr (IsStatusOrResult::value) { + if (ABSL_PREDICT_TRUE(result.ok())) { + state->result = std::make_shared(*std::forward(result)); + } else { + state->result = result.status(); + } + } else { + state->result = std::make_shared(std::forward(result)); } - // Start a new rendezvous for a given key. - std::shared_ptr start = std::make_shared(num_threads); - return (in_progress = start, start); + // Notify awaiting participants that result is ready. + absl::MutexLock lock(&state->mutex); + state->ready = true; + state->cv.SignalAll(); } private: absl::Mutex mutex_; - absl::flat_hash_map> state_ ABSL_GUARDED_BY(mutex_); + absl::flat_hash_map> state_ ABSL_GUARDED_BY(mutex_); }; void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, @@ -264,22 +295,6 @@ void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, // Rendezvous implemenetation. //===----------------------------------------------------------------------===// -template -absl::StatusOr> InvokeRendezvous( - Fn fn, absl::Span values) { - auto result = fn(values); - - if constexpr (internal::IsStatusOrResult::value) { - if (ABSL_PREDICT_TRUE(result.ok())) { - return std::make_shared(*std::move(result)); - } else { - return result.status(); - } - } else { - return std::make_shared(std::move(result)); - } -} - template absl::StatusOr> Rendezvous( absl::string_view name, const K& key, const V& value, size_t num_threads, @@ -292,7 +307,16 @@ absl::StatusOr> Rendezvous( // Fast-path (DO NOT REMOVE: the logic below doesn't work for single thread). if (num_threads == 1) { const V* ptr = &value; - return InvokeRendezvous(std::move(fn), absl::MakeSpan(&ptr, 1)); + auto result = fn(absl::MakeSpan(&ptr, 1)); + + if constexpr (internal::IsStatusOrResult::value) { + if (ABSL_PREDICT_TRUE(result.ok())) { + return std::make_shared(*std::move(result)); + } + return result.status(); + } else { + return std::make_shared(std::move(result)); + } } using State = internal::RendezvousState; @@ -335,23 +359,9 @@ absl::StatusOr> Rendezvous( // be notified via `state->ready` flag when result is ready, and we rely on // the store to a flag to create a memory barrier that makes access to // `state->result` safe without any extra synchronization. 
- tsl::profiler::TraceMe trace("InvokeRendezvous"); + tsl::profiler::TraceMe trace("ExecuteRendezvousCallback"); absl::Span values(state->values.data(), num_threads); - - // Check that we have have exactly the number of participants we expect. - CHECK_EQ(state.use_count(), num_threads); // NOLINT - - // Publish rendezvous result to all participants. - state->result = InvokeRendezvous(std::move(fn), values); - - // Switch `ready` flag to signal all participants that result is ready. - { - absl::MutexLock lock(&state->mutex); - state->ready = true; - } - - // Notify awaiting participants that result is ready. - state->cv.SignalAll(); + rendezvous.Complete(key, fn(values)); } return state->result; diff --git a/third_party/xla/xla/service/rendezvous_test.cc b/third_party/xla/xla/service/rendezvous_test.cc index 70725f135ea24e..f9fbb1c8287e20 100644 --- a/third_party/xla/xla/service/rendezvous_test.cc +++ b/third_party/xla/xla/service/rendezvous_test.cc @@ -37,8 +37,8 @@ limitations under the License. namespace xla { namespace { -absl::Duration Timeout() { return absl::Seconds(5); } -absl::Duration Terminate() { return absl::Seconds(5); } +absl::Duration Timeout() { return absl::Seconds(10); } +absl::Duration Terminate() { return absl::Seconds(10); } tsl::thread::ThreadPool CreateThreadPool(int32_t size) { return tsl::thread::ThreadPool(tsl::Env::Default(), "rendezvous_test", size); @@ -268,9 +268,8 @@ static void BM_Rendezvous(benchmark::State& state) { absl::BlockingCounter counter(num_threads); for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&] { - CHECK_OK(Rendezvous( - "rendezvous_test", /*key=*/0, num_threads, [] { return 42; }, - Timeout(), Terminate())); + CHECK_OK(Rendezvous("rendezvous_test", 0, num_threads, + [] { return 42; })); counter.DecrementCount(); }); } @@ -286,9 +285,9 @@ static void BM_RendezvousWithValues(benchmark::State& state) { absl::BlockingCounter counter(num_threads); for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&, i] { - CHECK_OK(Rendezvous( - "rendezvous_test", /*key=*/0, /*value=*/i, num_threads, - [](auto) { return 42; }, Timeout(), Terminate())); + int32_t value = i; + CHECK_OK(Rendezvous("rendezvous_test", 0, value, num_threads, + [](auto) { return 42; })); counter.DecrementCount(); }); } @@ -307,9 +306,8 @@ static void BM_GroupedRendezvous(benchmark::State& state) { for (int64_t group = 0; group < num_groups; ++group) { for (int64_t i = 0; i < group_size; ++i) { thread_pool.Schedule([&, group] { - CHECK_OK(Rendezvous( - "rendezvous_test", /*key=*/group, /*num_threads=*/group_size, - [] { return 42; }, Timeout(), Terminate())); + CHECK_OK(Rendezvous("rendezvous_test", group, group_size, + [] { return 42; })); counter.DecrementCount(); }); } From 6787818efa27ab4a2743ff0cf550ae097b5e0b2e Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 10 Apr 2025 07:25:29 -0700 Subject: [PATCH 0499/1324] [XLA:GPU] Prevent `SymbolicTileAnalysis` from collapsing point dimensions prior to symbolic tile derivation. If we collapse point dimensions, then we miss out on cases where dimensions of size `1` need to be padded to a larger power of 2 in Triton emission. This is typically the case for `dot` operations with a trivial non-contracting dimension (i.e. vectors). 
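As a concrete illustration (mirroring the new test below): an f32[1,115] dot
operand tiled with sizes {16,32} must keep the tile size of 16 on its trivial
dimension; collapsing that dimension to a point would produce an invalid tile
of size 1 for the dot.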
PiperOrigin-RevId: 746021129 --- .../gpu/model/symbolic_tile_analysis.cc | 53 ++++++++++++- .../gpu/model/symbolic_tile_analysis_test.cc | 78 ++++++++++++++++++- 2 files changed, 128 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index 03d63e57ff9b1b..b7772b7552ee4c 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -403,6 +403,51 @@ void SortTiledHloInstructionsInPostOrder( }); } +// Returns `true` if `SymbolicTileAnalysis` should simplify point dimensions +// away when deriving indexing maps. +// +// Simplifying point dimensions away is helpful as it allows symbolic tile +// derivation to succeed in more cases. However, it can lead to generating +// ill-typed programs when we need to propagate a larger (padded) tile through +// the program. In that case, simplifying the point dimension away prevents +// propagation, and leads to the downstream generation of an incorrect program. +// +// This is typically the case when trying to feed a vector-matrix or +// matrix-vector dot product into NVIDIA GPU tensor cores---which expect their +// inputs to have specific dimensions. In that case, we usually want to pretend +// to tile the vector with a tile size appropriate for the tensor core, even +// though one of its dimensions is 1. +// +// Adding this here is a slight abstraction leak, since it slightly specializes +// symbolic tile analysis to NVIDIA GPUs. This is not totally unreasonable +// though: given sufficient analytical capabilities for symbolic tile +// derivation, preventing the simplification of point dimensions should not +// cause us to fail to tile more programs, and would better track the +// propagation of tiles throughout the program. As a result, a mode that does +// not perform this simplification is actually "more correct"---but currently +// leads to more fusions being untileable. +bool ShouldDerivationSimplifyPointDimensions(const HloFusionAdaptor& fusion) { + for (const HloInstructionAdaptor& instruction_adaptor : + fusion.MakeInstructionPostOrder()) { + if (!fusion.ContainsInstruction(&instruction_adaptor.instruction())) { + continue; + } + + if (instruction_adaptor.opcode() == HloOpcode::kDot) { + return false; + } + + if (instruction_adaptor.opcode() == HloOpcode::kFusion) { + auto nested_fusion_adaptor = HloFusionAdaptor::ForComputation( + instruction_adaptor.instruction().fused_instructions_computation()); + if (!ShouldDerivationSimplifyPointDimensions(*nested_fusion_adaptor)) { + return false; + } + } + } + return true; +} + } // anonymous namespace // Extracts HloInstructions from a span of HloInstructionAdaptors. @@ -474,6 +519,12 @@ absl::StatusOr GetRealRootIndex( OrderedUniquePtrValueHashSet tiled_hlo_instructions_set; + IndexingMap::SimplifyPointDimensions simplification_mode = + IndexingMap::SimplifyPointDimensions::kPreserve; + if (ShouldDerivationSimplifyPointDimensions(fusion)) { + simplification_mode = IndexingMap::SimplifyPointDimensions::kReplace; + } + // TODO(b/372454662): Once we get rid of the restriction of only one real // root, this needs to be adapted. 
auto [root_tiled_hlo, _] = tiled_hlo_instructions_set.Insert( @@ -510,7 +561,7 @@ absl::StatusOr GetRealRootIndex( << tiled_hlo_instruction->hlo()->ToString() << " and operand " << operand.instruction().ToString(); } - operand_indexing_map.Simplify(); + operand_indexing_map.Simplify(simplification_mode); operand_indexing_map.RescaleSymbols(); operand_indexing_map.RemoveUnusedSymbols(); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 0e99c7b84dc678..e82f14470a0489 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "xla/hlo/analysis/indexing_test_utils.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_traversal.h" #include "xla/service/gpu/model/constraint_expression.h" @@ -1295,7 +1296,7 @@ ENTRY main { ->fused_instructions_computation(), &mlir_context_, /*emitter_specific_constraints_builder=*/nullptr); - EXPECT_TRUE(std::holds_alternative(analysis_or_error)); + ASSERT_TRUE(std::holds_alternative(analysis_or_error)); EXPECT_THAT(std::get(analysis_or_error).Explain(), ::testing::HasSubstr("Bailing out on reshape")); } @@ -1560,7 +1561,7 @@ ENTRY main { "kind":"__triton_nested_gemm_fusion"}} })")); std::optional analysis = TryAnalyzeModule(module.get()); - EXPECT_TRUE(analysis.has_value()); + ASSERT_TRUE(analysis.has_value()); TF_ASSERT_OK_AND_ASSIGN( TiledHloComputation tiled_hlo_computation, @@ -1617,6 +1618,79 @@ ENTRY main { ::testing::HasSubstr("not divisible by tile size"))); } +TEST_F(SymbolicTileAnalysisTest, TrivialDimensionParametersArePreserved) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +lhs { + ROOT p0 = f32[137,115] parameter(0) +} + +rhs { + ROOT p0 = f32[1,115] parameter(0) +} + +dot { + p0 = f32[137,115] parameter(0) + p1 = f32[1,115] parameter(1) + + lhs = f32[137,115] fusion(p0), + kind=kCustom, calls=lhs, backend_config={ + "fusion_backend_config":{ + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["16","32"]}]}}} + rhs = f32[1,115] fusion(p1), + kind=kCustom, calls=rhs, backend_config={ + "fusion_backend_config":{ + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["16","32"]}]}}} + + ROOT dot = f32[137,1] dot(lhs, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY main { + p0 = f32[137,115] parameter(0) + p1 = f32[1,115] parameter(1) + ROOT fusion = f32[137,1] fusion(p0, p1), + kind=kCustom, calls=dot +})")); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{16, 16}, + /*constraints_are_known_satisfied=*/true, + /*compute_all_tile_offset_indexing_maps=*/true)); + + const TiledHloInstruction* dot = tiled_hlo_computation.GetRoots().front(); + ASSERT_EQ(dot->hlo()->opcode(), HloOpcode::kDot); + + const TiledHloFusionInstruction* lhs_fusion = + static_cast(dot->operand(0)); + const TiledHloFusionInstruction* rhs_fusion = + static_cast(dot->operand(1)); + + EXPECT_THAT(*lhs_fusion->called_computation()->GetRoots().front(), + MatchTiledHloInstruction( + /*tile_sizes=*/{16, 32}, /*tile_strides=*/{1, 1}, + 
/*tile_offsets_indexing=*/ + "(pid_0, pid_1, pid_2) -> (pid_0 * 16, pid_2 * 32), domain: " + "pid_0 in [0, 8], pid_1 in [0, 0], pid_2 in [0, 3]")); + + // RHS has a trivial dimension. We make sure here that the requested padding + // is propagated as requested, and not simplified away (which would result in + // an invalid tile size of size "1"). + // The trivial argument is still expected to be eliminated in the + // `tile_offsets_indexing` map, since this allows for more effective CSE. + EXPECT_THAT(*rhs_fusion->called_computation()->GetRoots().front(), + MatchTiledHloInstruction( + /*tile_sizes=*/{16, 32}, /*tile_strides=*/{1, 1}, + /*tile_offsets_indexing=*/ + "(pid_0, pid_1, pid_2) -> (0, pid_2 * 32), domain: " + "pid_0 in [0, 8], pid_1 in [0, 0], pid_2 in [0, 3]")); +} + } // namespace } // namespace gpu } // namespace xla From 0602affd6bd47f53a358b084795270b6cd9388c7 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 10 Apr 2025 08:07:05 -0700 Subject: [PATCH 0500/1324] [XLA:GPU] Enable generic Triton emitter GEMM tests involving propagating trivial dimensions. Those are fixed by avoiding the simplification of point dimensions in symbolic tile analysis. PiperOrigin-RevId: 746034281 --- .../fusion_emitter_device_legacy_port_test.cc | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index 5bb2f36e15c913..3e014ea92e6d72 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -202,9 +202,7 @@ ENTRY e { EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); } -// TODO(bchetioui): there is already a change out to fix this, enable once it -// lands. -TEST_F(TritonTest, DISABLED_TestGemmWithTrivialNonContractingDimension) { +TEST_F(TritonTest, TestGemmWithTrivialNonContractingDimension) { constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true @@ -2500,10 +2498,8 @@ ENTRY e { // into Triton fusions. using CompareTest = TritonGemmTest; -// TODO(bchetioui): same as -// TritonTest.TestGemmWithTrivialNonContractingDimension. 
PiperOrigin-RevId: 746034281
---
 .../fusion_emitter_device_legacy_port_test.cc | 29 ++++++++++---------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc
index 5bb2f36e15c913..3e014ea92e6d72 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc
@@ -202,9 +202,7 @@ ENTRY e {
   EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
 }
 
-// TODO(bchetioui): there is already a change out to fix this, enable once it
-// lands.
-TEST_F(TritonTest, DISABLED_TestGemmWithTrivialNonContractingDimension) {
+TEST_F(TritonTest, TestGemmWithTrivialNonContractingDimension) {
   constexpr absl::string_view kHloText = R"(
 HloModule t, is_scheduled=true
@@ -2500,10 +2498,8 @@ ENTRY e {
 // into Triton fusions.
 using CompareTest = TritonGemmTest;
 
-// TODO(bchetioui): same as
-// TritonTest.TestGemmWithTrivialNonContractingDimension.
-TEST_F(CompareTest, DISABLED_F32WithTrivialNonContractingDimension) {
-  constexpr absl::string_view hlo_text_ref = R"(
+TEST_F(CompareTest, F32WithTrivialNonContractingDimension) {
+  constexpr absl::string_view kHloTextRef = R"(
 HloModule r
 
 ENTRY e {
@@ -2516,7 +2512,7 @@ ENTRY e {
 }
 )";
 
-  constexpr absl::string_view hlo_text_triton = R"(
+  constexpr absl::string_view kHloText = R"(
 HloModule t
 
 triton_dot {
@@ -2534,12 +2530,19 @@ ENTRY e {
                      triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,
                                           "split_k":1,"num_stages":1,"num_warps":1,
                                           "num_ctas":1}}}
-}
-)";
+})";
 
-  EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton,
-                                      ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3},
-                                      /*run_hlo_passes=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(
+      ModuleAndNestedFusionMetadata test_module_and_metadata,
+      GetModuleAndNestedFusionMetadata(kHloText));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> ref_module,
+                          ParseAndReturnVerifiedModule(kHloTextRef));
+
+  EXPECT_TRUE(RunAndCompareTwoModules(
+      std::move(ref_module), std::move(test_module_and_metadata.module),
+      ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3},
+      /*run_hlo_passes=*/false));
 }
 
 // TODO(b/353484968, b/393299275): the e2e test path was never really testing

From 3c7b3ea03e0fc35e8f64a2396d59ada7a9b2513e Mon Sep 17 00:00:00 2001
From: Alan Kelly
Date: Thu, 10 Apr 2025 08:20:37 -0700
Subject: [PATCH 0501/1324] Use a static constant float for the clamp instead
 of an arbitrary pointer whose lifetime is unknown.

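A minimal standalone sketch of the identity the delegate leans on (plain
C++, not the XNNPACK API): clamping any finite input x to [c, c] yields
c, so a constant with static storage duration, whose lifetime is
trivially long enough, can feed the clamp instead of borrowed tensor
memory whose lifetime the delegate cannot guarantee.

  #include <algorithm>
  #include <cassert>

  int main() {
    const float c = 0.25f;  // The constant the subgraph should produce.
    for (float x : {-1e9f, 0.0f, 3.0f}) {
      // clamp(x, c, c) == c regardless of x, so the static input data to
      // the clamp really can be anything.
      assert(std::clamp(x, c, c) == c);
    }
    return 0;
  }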
PiperOrigin-RevId: 746038908
---
 tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
index 7d32e5c9a450c7..0d904a9c44cf35 100644
--- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
+++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
@@ -71,6 +71,13 @@ namespace tflite {
 namespace xnnpack {
 namespace {
 
+// VisitDotAttentionNode uses a clamp to add a constant value to the XNNPack
+// subgraph. The constant data must outlive the XNNPack delegate, and there is
+// no simple way of guaranteeing that for borrowed tensor memory. Therefore a
+// clamp is used to clamp some arbitrary static data to the desired constant
+// value; the static input data to the clamp can be anything.
+const float kConstantClampData = 0.f;
+
 constexpr char kOdmlSDPA[] = "odml.scaled_dot_product_attention";
 
 template <typename T>
@@ -5886,7 +5893,7 @@ class Subgraph {
       TF_LITE_ENSURE_EQ(
           logging_context, xnn_status_success,
           xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0,
-                                  /*dims=*/nullptr, &query_proj.dims->data[3],
+                                  /*dims=*/nullptr, &kConstantClampData,
                                   XNN_INVALID_VALUE_ID, 0, &scale_orig_id));
       TF_LITE_ENSURE_EQ(
           logging_context, xnn_status_success,

From 8fc421adc66d42512b4db46d1398bdd075579318 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Thu, 10 Apr 2025 08:24:33 -0700
Subject: [PATCH 0502/1324] [XLA:GPU] Verify the functionality of (and
 subsequently delete) 12 more tests in
 `fusion_emitter_device_legacy_port_test.cc`.

All the removed tests were confirmed to work manually before being
deleted. We list them here, along with a reason for deleting them:

* `TritonGemmTest.MultiplePathsToSameOperandWorks`: this was testing a
  basic feature of the generic Triton emitter (multiple tiles for the
  same node), which is already well covered.
* `TritonGemmTest.SameInput`: same as above.
* `TritonGemmTest.MaximumReturnsLHS`: this is basically an elementwise
  test for `maximum`. This is already well covered by tests involving
  normalizations/reductions.
* `TritonGemmTest.MaximumReturnsRHS`: this is basically an elementwise
  test for `maximum`. This is already well covered by tests involving
  normalizations/reductions.
* `TritonGemmTest.MaximumHandlesNaNsOnTheLeft`: this is basically an
  elementwise test for `maximum`. This is already well covered by tests
  involving normalizations/reductions.
* `TritonGemmTest.MaximumHandlesNaNsOnTheRight`: this is basically an
  elementwise test for `maximum`. This is already well covered by tests
  involving normalizations/reductions.
* `TritonGemmTest.MinimumReturnsLHS`: this is basically an elementwise
  test for `minimum`. I consolidated this as a standalone test
  `TritonEmitterTest.MinimumIsEmittedCorrectly` in
  `fusion_emitter_device_test.cc`.
* `TritonGemmTest.MinimumReturnsRHS`: this is basically an elementwise
  test for `minimum`. I consolidated this as a standalone test
  `TritonEmitterTest.MinimumIsEmittedCorrectly` in
  `fusion_emitter_device_test.cc`.
* `TritonGemmTest.MinimumHandlesNaNsOnTheLeft`: this is basically an
  elementwise test for `minimum`. I consolidated this as a standalone
  test `TritonEmitterTest.MinimumIsEmittedCorrectly` in
  `fusion_emitter_device_test.cc`.
* `TritonGemmTest.MinimumHandlesNaNsOnTheRight`: this is basically an
  elementwise test for `minimum`. I consolidated this as a standalone
  test `TritonEmitterTest.MinimumIsEmittedCorrectly` in
  `fusion_emitter_device_test.cc`.
* `TritonGemmTest.SingleElementTileIsHandled`: it's actually unclear
  what special path this is exercising, but the HLO doesn't seem
  particularly interesting from the point of view of the generic Triton
  emitter.
* `TritonGemmTest.NoPadding`: this is another `dot` that is not
  particularly interesting, and there are plenty of other tests in the
  file that already do not involve padding.

PiperOrigin-RevId: 746040124
---
 .../fusion_emitter_device_legacy_port_test.cc | 306 +-----------------
 .../triton/fusion_emitter_device_test.cc      |  27 ++
 2 files changed, 29 insertions(+), 304 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc
index 3e014ea92e6d72..08be420041d777 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc
@@ -637,33 +637,7 @@ ENTRY e {
                            ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
-TEST_F(TritonGemmTest, DISABLED_NoPadding) {
-  constexpr absl::string_view kHloText = R"(
-HloModule t
-
-ENTRY e {
-  p0 = f16[15,19] parameter(0)
-  p1 = s8[19,17] parameter(1)
-  cp1 = f16[19,17] convert(p1)
-  ROOT _ = f16[15,17] dot(p0, cp1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
-  MatchOptimizedHlo(kHloText, R"(
-; CHECK: ENTRY
-; CHECK-NEXT: parameter
-; CHECK-NEXT: parameter
-; CHECK-NEXT: ROOT
-; CHECK-SAME: fusion(
-; CHECK-SAME: kind=kCustom
-; CHECK-PTX-SAME: "block_m":
-; CHECK-NOT: pad
-; CHECK-NOT: slice
-)");
-
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
-}
-
+// TODO(b/393299275): requires enabling mixed-type dots for s8xs8->s32.
TEST_F(TritonGemmTest, DISABLED_S8xS8) { constexpr absl::string_view kHloText = R"( HloModule t @@ -674,6 +648,7 @@ ENTRY f { ROOT z = s32[1024,1024]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; + MatchOptimizedHlo(kHloText, "CHECK: __triton_nested_gemm_fusion"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } @@ -1128,40 +1103,6 @@ ENTRY e { )"); } -TEST_F(TritonGemmTest, DISABLED_SingleElementTileIsHandled) { - if (std::holds_alternative( - GpuComputeCapability())) { - GTEST_SKIP() << "Not using autotuner on ROCM yet."; - } - MatchOptimizedHlo(R"( -t { - p0 = f32[2,7,3]{2,1,0} parameter(0) - p1 = s32[2,1]{1,0} parameter(1) - c = s32[] constant(1) - br0 = s32[2,1]{1,0} broadcast(c), dimensions={} - cmp = pred[2,1]{1,0} compare(p1, br0), direction=LT - bc0 = pred[2]{0} bitcast(cmp) - br1 = pred[2,1,3,3]{3,2,0,1} broadcast(bc0), dimensions={0} - cvt = f32[2,1,3,3]{3,2,0,1} convert(br1) - bc1 = f32[2,3,3]{2,1,0} bitcast(cvt) - ROOT d = f32[2,7,3]{2,1,0} dot(p0, bc1), - lhs_batch_dims={0}, lhs_contracting_dims={2}, - rhs_batch_dims={0}, rhs_contracting_dims={1} -} - -ENTRY e { - p0 = f32[2,7,3]{2,1,0} parameter(0) - p1 = s32[2,1]{1,0} parameter(1) - ROOT r = f32[2,7,3]{2,1,0} fusion(p0, p1), kind=kCustom, - calls=t, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} -})", - // This partially optimized HLO will go through the - // autotuner which will run the fusion through the emitter - // multiple times and assign block sizes on success. - R"( -; CHECK: block_m -)"); -} TEST_F(TritonGemmTest, BroadcastsOfTriviallySizedNonContractingDimensionsAreSupported) { @@ -1262,28 +1203,6 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_SameInput) { - constexpr absl::string_view kHloText = R"( -HloModule m - -ENTRY e { - p0 = pred[5,5]{1,0} parameter(0) - c = f32[5,5]{1,0} convert(p0) - ROOT r = f32[5,5]{1,0} dot(c, c), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -})"; - - // The fusion has separate parameters for each scope. - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK: %[[p0:.*]] = pred[5,5]{1,0} parameter(0) -; CHECK: fusion(%[[p0]], %[[p0]]), kind=kCustom -; CHECK-PTX-SAME: "block_m": -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); -} - TEST_F(TritonGemmTest, DISABLED_DynamicSliceIsSupportedInLhsEndToEnd) { // The select is used to restrict the start index to values that make sense. // If it was constant, then the dynamic-slice would be optimized to slice. 
It @@ -1353,35 +1272,6 @@ ENTRY e { kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } -TEST_F(TritonGemmTest, DISABLED_MultiplePathsToSameOperandWorks) { - constexpr absl::string_view kHloText = R"( -triton_computation { - p0 = bf16[8192,512]{1,0} parameter(0) - p1 = bf16[512,512]{1,0} parameter(1) - dot = bf16[8192,512]{1,0} dot(bf16[8192,512]{1,0} p0, bf16[512,512]{1,0} p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} - p2 = bf16[8192,512]{1,0} parameter(2) - multiply.1 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} dot, bf16[8192,512]{1,0} p2) - ROOT multiply.2 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} multiply.1, bf16[8192,512]{1,0} p2) -} - -ENTRY e { - p0 = bf16[8192,512]{1,0} parameter(0) - p1 = bf16[512,512]{1,0} parameter(1) - p2 = bf16[8192,512]{1,0} parameter(2) - ROOT fusion = bf16[8192,512]{1,0} fusion(p0,p1,p2), kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"64","block_n":"256","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4","num_ctas":"1"}}} -})"; - - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_computation", R"( - CHECK: tt.dot - CHECK-SAME: tensor<64x32xbf16> * tensor<32x256xbf16> -> tensor<64x256xf32> - CHECK: arith.mulf - CHECK: arith.mulf - )")); -} - class TritonGemmDynamicSliceClampingTest : public TritonTest, public ::testing::WithParamInterface {}; @@ -1937,198 +1827,6 @@ e { /*arel=*/1e-2})); } -TEST_F(TritonGemmTest, DISABLED_MinimumHandlesNaNsOnTheLeft) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,5] parameter(0) - neg1 = f32[] constant(-1) - neg1s = f32[5,5] broadcast(neg1), dimensions={} - nans = f32[5,5] sqrt(neg1s) - min = f32[5,5] minimum(nans, neg1s) - ROOT _ = f32[5,5] dot(p0, min), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion( -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_MinimumHandlesNaNsOnTheRight) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,5] parameter(0) - neg1 = f32[] constant(-1) - neg1s = f32[5,5] broadcast(neg1), dimensions={} - nans = f32[5,5] sqrt(neg1s) - min = f32[5,5] minimum(neg1s, nans) - ROOT _ = f32[5,5] dot(p0, min), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion( -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_MaximumHandlesNaNsOnTheLeft) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,5] parameter(0) - neg1 = f32[] constant(-1) - neg1s = f32[5,5] broadcast(neg1), dimensions={} - nans = f32[5,5] sqrt(neg1s) - max = f32[5,5] maximum(nans, neg1s) - ROOT _ = f32[5,5] dot(p0, max), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion( -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_MaximumHandlesNaNsOnTheRight) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,5] parameter(0) - neg1 = f32[] constant(-1) - neg1s = f32[5,5] broadcast(neg1), 
dimensions={} - nans = f32[5,5] sqrt(neg1s) - max = f32[5,5] maximum(neg1s, nans) - ROOT _ = f32[5,5] dot(p0, max), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion( -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_MinimumReturnsLHS) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,5] parameter(0) - zero = f32[] constant(0) - zeros = f32[5,5] broadcast(zero), dimensions={} - one = f32[] constant(1) - ones = f32[5,5] broadcast(one), dimensions={} - min = f32[5,5] minimum(zeros, ones) - ROOT _ = f32[5,5] dot(p0, min), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion( -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, - /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_MinimumReturnsRHS) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,5] parameter(0) - zero = f32[] constant(0) - zeros = f32[5,5] broadcast(zero), dimensions={} - one = f32[] constant(1) - ones = f32[5,5] broadcast(one), dimensions={} - min = f32[5,5] minimum(ones, zeros) - ROOT _ = f32[5,5] dot(p0, min), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion( -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, - /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_MaximumReturnsLHS) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,5] parameter(0) - zero = f32[] constant(0) - zeros = f32[5,5] broadcast(zero), dimensions={} - one = f32[] constant(1) - ones = f32[5,5] broadcast(one), dimensions={} - max = f32[5,5] maximum(ones, zeros) - ROOT _ = f32[5,5] dot(p0, max), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion( -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, - /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_MaximumReturnsRHS) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,5] parameter(0) - zero = f32[] constant(0) - zeros = f32[5,5] broadcast(zero), dimensions={} - one = f32[] constant(1) - ones = f32[5,5] broadcast(one), dimensions={} - max = f32[5,5] maximum(zeros, ones) - ROOT _ = f32[5,5] dot(p0, max), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion( -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, - /*arel=*/1e-3})); -} - TEST_F(TritonGemmTest, SineOutputIsNotFused) { constexpr absl::string_view kHloText = R"( HloModule m diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index ffe7033ef1ede8..ce59346d85ee1e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -75,6 +75,33 @@ class TritonEmitterTest : public GpuCodegenTest { } }; +// TODO(bchetioui): turn this into a general binary 
elementwise test. +TEST_F(TritonEmitterTest, MinimumIsEmittedCorrectly) { + constexpr absl::string_view kHloText = R"( +computation { + p0 = f32[8,4] parameter(0) + p1 = f32[8,4] parameter(1) + ROOT minimum = f32[8,4] minimum(p0, p1) +} + +ENTRY entry_computation { + p0 = f32[8,4] parameter(0) + p1 = f32[8,4] parameter(1) + ROOT fusion = f32[8,4] fusion(p0, p1), kind=kCustom, + calls=computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__triton", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["1", "4"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} +})"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); +} + TEST_F(TritonEmitterTest, ReductionOnMinormostAxisIsEmittedCorrectly) { constexpr absl::string_view kHloText = R"( HloModule m From 028196e352a8d7aaa2df6ce6e1197e8fd8246543 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 10 Apr 2025 08:49:06 -0700 Subject: [PATCH 0503/1324] Automated Code Change PiperOrigin-RevId: 746047694 --- third_party/stablehlo/temporary.patch | 653 ++++++++++++++++++ .../xla/third_party/stablehlo/temporary.patch | 653 ++++++++++++++++++ 2 files changed, 1306 insertions(+) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index a6f4b05d72cedb..8683dc781d6a82 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,3 +1,612 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -58,16 +58,10 @@ + gentbl_cc_library( + name = "base_attr_interfaces_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attr-interface-decls"], +- "stablehlo/dialect/BaseAttrInterfaces.h.inc", +- ), +- ( +- ["-gen-attr-interface-defs"], +- "stablehlo/dialect/BaseAttrInterfaces.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/BaseAttrInterfaces.h.inc": ["-gen-attr-interface-decls"], ++ "stablehlo/dialect/BaseAttrInterfaces.cpp.inc": ["-gen-attr-interface-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/Base.td", + deps = [":stablehlo_ops_td_files"], +@@ -107,16 +101,10 @@ + gentbl_cc_library( + name = "chlo_attrs_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attrdef-decls"], +- "stablehlo/dialect/ChloAttrs.h.inc", +- ), +- ( +- ["-gen-attrdef-defs"], +- "stablehlo/dialect/ChloAttrs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/ChloAttrs.h.inc": ["-gen-attrdef-decls"], ++ "stablehlo/dialect/ChloAttrs.cpp.inc": ["-gen-attrdef-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/ChloOps.td", + deps = [ +@@ -173,16 +161,10 @@ + gentbl_cc_library( + name = "chlo_enums_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-enum-decls"], +- "stablehlo/dialect/ChloEnums.h.inc", +- ), +- ( +- ["-gen-enum-defs"], +- "stablehlo/dialect/ChloEnums.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/ChloEnums.h.inc": ["-gen-enum-decls"], ++ "stablehlo/dialect/ChloEnums.cpp.inc": ["-gen-enum-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/ChloOps.td", + deps = [ +@@ -193,23 +175,14 @@ + gentbl_cc_library( + name = "chlo_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "stablehlo/dialect/ChloOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "stablehlo/dialect/ChloOps.cpp.inc", +- ), +- ( +- [ +- "-gen-dialect-doc", +- "--dialect=chlo", +- ], 
+- "stablehlo/dialect/chlo.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/ChloOps.h.inc": ["-gen-op-decls"], ++ "stablehlo/dialect/ChloOps.cpp.inc": ["-gen-op-defs"], ++ "stablehlo/dialect/chlo.md": [ ++ "-gen-dialect-doc", ++ "--dialect=chlo", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/ChloOps.td", + deps = [ +@@ -304,12 +277,7 @@ + + gentbl_cc_library( + name = "chlo_rewriters_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/ChloDecompositionPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/ChloDecompositionPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/ChloDecompositionPatterns.td", + deps = [ +@@ -320,12 +288,7 @@ + + gentbl_cc_library( + name = "stablehlo_aggressive_simplification_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td", + deps = [ +@@ -335,12 +298,7 @@ + + gentbl_cc_library( + name = "stablehlo_legalize_deprecated_ops_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/StablehloLegalizeDeprecatedOpsPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/StablehloLegalizeDeprecatedOpsPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloLegalizeDeprecatedOpsPatterns.td", + deps = [ +@@ -350,12 +308,7 @@ + + gentbl_cc_library( + name = "vhlo_rewriters_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/VhloToVersionPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/VhloToVersionPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/VhloToVersionPatterns.td", + deps = [ +@@ -365,12 +318,7 @@ + + gentbl_cc_library( + name = "stablehlo_create_compatibility_expander_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/StablehloCompatibilityExpanderPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/StablehloCompatibilityExpanderPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloCompatibilityExpanderPatterns.td", + deps = [ +@@ -380,12 +328,7 @@ + + gentbl_cc_library( + name = "stablehlo_create_complex_math_expander_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloComplexMathExpanderPatterns.td", + deps = [ +@@ -420,16 +363,10 @@ + gentbl_cc_library( + name = "interpreter_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "stablehlo/reference/InterpreterOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "stablehlo/reference/InterpreterOps.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/reference/InterpreterOps.h.inc": ["-gen-op-decls"], ++ "stablehlo/reference/InterpreterOps.cpp.inc": ["-gen-op-defs"], ++ }, 
+ tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/reference/InterpreterOps.td", + deps = [ +@@ -451,21 +388,15 @@ + gentbl_cc_library( + name = "interpreter_pass_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=InterpreterTransforms", +- ], +- "stablehlo/reference/InterpreterPasses.h.inc", +- ), +- ( +- [ +- "-gen-pass-doc", +- ], +- "stablehlo/reference/interpreter_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/reference/InterpreterPasses.h.inc": [ ++ "-gen-pass-decls", ++ "-name=InterpreterTransforms", ++ ], ++ "stablehlo/reference/interpreter_passes.md": [ ++ "-gen-pass-doc", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/reference/InterpreterPasses.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -494,21 +425,15 @@ + gentbl_cc_library( + name = "linalg_pass_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=StablehloLinalgTransforms", +- ], +- "stablehlo/conversions/linalg/transforms/Passes.h.inc", +- ), +- ( +- [ +- "-gen-pass-doc", +- ], +- "stablehlo/conversions/linalg/transforms/stablehlo_linalg_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/conversions/linalg/transforms/Passes.h.inc": [ ++ "-gen-pass-decls", ++ "-name=StablehloLinalgTransforms", ++ ], ++ "stablehlo/conversions/linalg/transforms/stablehlo_linalg_passes.md": [ ++ "-gen-pass-doc", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/conversions/linalg/transforms/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -901,16 +826,10 @@ + gentbl_cc_library( + name = "stablehlo_attrs_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attrdef-decls"], +- "stablehlo/dialect/StablehloAttrs.h.inc", +- ), +- ( +- ["-gen-attrdef-defs"], +- "stablehlo/dialect/StablehloAttrs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/StablehloAttrs.h.inc": ["-gen-attrdef-decls"], ++ "stablehlo/dialect/StablehloAttrs.cpp.inc": ["-gen-attrdef-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/StablehloOps.td", + deps = [ +@@ -1067,16 +986,10 @@ + gentbl_cc_library( + name = "stablehlo_enums_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-enum-decls"], +- "stablehlo/dialect/StablehloEnums.h.inc", +- ), +- ( +- ["-gen-enum-defs"], +- "stablehlo/dialect/StablehloEnums.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/StablehloEnums.h.inc": ["-gen-enum-decls"], ++ "stablehlo/dialect/StablehloEnums.cpp.inc": ["-gen-enum-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/StablehloOps.td", + deps = [ +@@ -1087,22 +1000,16 @@ + gentbl_cc_library( + name = "stablehlo_types_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-typedef-decls", +- "--typedefs-dialect=stablehlo", +- ], +- "stablehlo/dialect/StablehloTypeDefs.h.inc", +- ), +- ( +- [ +- "-gen-typedef-defs", +- "--typedefs-dialect=stablehlo", +- ], +- "stablehlo/dialect/StablehloTypeDefs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/StablehloTypeDefs.h.inc": [ ++ "-gen-typedef-decls", ++ "--typedefs-dialect=stablehlo", ++ ], ++ "stablehlo/dialect/StablehloTypeDefs.cpp.inc": [ ++ "-gen-typedef-defs", ++ "--typedefs-dialect=stablehlo", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/StablehloOps.td", + deps = [ +@@ -1113,16 +1020,10 @@ + gentbl_cc_library( + name = 
"stablehlo_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "stablehlo/dialect/StablehloOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "stablehlo/dialect/StablehloOps.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/StablehloOps.h.inc": ["-gen-op-decls"], ++ "stablehlo/dialect/StablehloOps.cpp.inc": ["-gen-op-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/StablehloOps.td", + deps = [ +@@ -1188,20 +1089,14 @@ + gentbl_cc_library( + name = "stablehlo_pass_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- ], +- "stablehlo/transforms/Passes.h.inc", +- ), +- ( +- [ +- "-gen-pass-doc", +- ], +- "stablehlo/transforms/stablehlo_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/transforms/Passes.h.inc": [ ++ "-gen-pass-decls", ++ ], ++ "stablehlo/transforms/stablehlo_passes.md": [ ++ "-gen-pass-doc", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -1304,21 +1199,15 @@ + gentbl_cc_library( + name = "stablehlo_passes_optimization_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=Optimization", +- ], +- "stablehlo/transforms/optimization/Passes.h.inc", +- ), +- ( +- [ +- "-gen-pass-doc", +- ], +- "stablehlo/transforms/optimization/stablehlo_optimization_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/transforms/optimization/Passes.h.inc": [ ++ "-gen-pass-decls", ++ "-name=Optimization", ++ ], ++ "stablehlo/transforms/optimization/stablehlo_optimization_passes.md": [ ++ "-gen-pass-doc", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/optimization/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -1555,19 +1444,13 @@ + gentbl_cc_library( + name = "tosa_pass_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=StablehloTOSATransforms", +- ], +- "stablehlo/conversions/tosa/transforms/Passes.h.inc", +- ), +- ( +- ["-gen-pass-doc"], +- "stablehlo/conversions/tosa/transforms/stablehlo_tosa_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/conversions/tosa/transforms/Passes.h.inc": [ ++ "-gen-pass-decls", ++ "-name=StablehloTOSATransforms", ++ ], ++ "stablehlo/conversions/tosa/transforms/stablehlo_tosa_passes.md": ["-gen-pass-doc"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/conversions/tosa/transforms/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -1607,12 +1490,7 @@ + + gentbl_cc_library( + name = "tosa_pdll_inc_gen", +- tbl_outs = [ +- ( +- ["-x=cpp"], +- "stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll.h.inc": ["-x=cpp"]}, + tblgen = "@llvm-project//mlir:mlir-pdll", + td_file = "stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll", + deps = [ +@@ -1710,16 +1588,10 @@ + gentbl_cc_library( + name = "vhlo_attr_interfaces_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attr-interface-decls"], +- "stablehlo/dialect/VhloAttrInterfaces.h.inc", +- ), +- ( +- ["-gen-attr-interface-defs"], +- "stablehlo/dialect/VhloAttrInterfaces.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloAttrInterfaces.h.inc": ["-gen-attr-interface-decls"], ++ "stablehlo/dialect/VhloAttrInterfaces.cpp.inc": 
["-gen-attr-interface-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloEnums.td", + deps = [ +@@ -1730,16 +1602,10 @@ + gentbl_cc_library( + name = "vhlo_attrs_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attrdef-decls"], +- "stablehlo/dialect/VhloAttrs.h.inc", +- ), +- ( +- ["-gen-attrdef-defs"], +- "stablehlo/dialect/VhloAttrs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloAttrs.h.inc": ["-gen-attrdef-decls"], ++ "stablehlo/dialect/VhloAttrs.cpp.inc": ["-gen-attrdef-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ +@@ -1750,16 +1616,10 @@ + gentbl_cc_library( + name = "vhlo_enums_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-enum-decls"], +- "stablehlo/dialect/VhloEnums.h.inc", +- ), +- ( +- ["-gen-enum-defs"], +- "stablehlo/dialect/VhloEnums.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloEnums.h.inc": ["-gen-enum-decls"], ++ "stablehlo/dialect/VhloEnums.cpp.inc": ["-gen-enum-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloEnums.td", + deps = [ +@@ -1770,16 +1630,10 @@ + gentbl_cc_library( + name = "vhlo_op_interfaces_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-interface-decls"], +- "stablehlo/dialect/VhloOpInterfaces.h.inc", +- ), +- ( +- ["-gen-op-interface-defs"], +- "stablehlo/dialect/VhloOpInterfaces.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloOpInterfaces.h.inc": ["-gen-op-interface-decls"], ++ "stablehlo/dialect/VhloOpInterfaces.cpp.inc": ["-gen-op-interface-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ +@@ -1822,16 +1676,10 @@ + gentbl_cc_library( + name = "vhlo_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "stablehlo/dialect/VhloOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "stablehlo/dialect/VhloOps.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloOps.h.inc": ["-gen-op-decls"], ++ "stablehlo/dialect/VhloOps.cpp.inc": ["-gen-op-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ +@@ -1884,16 +1732,10 @@ + gentbl_cc_library( + name = "vhlo_type_interfaces_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-type-interface-decls"], +- "stablehlo/dialect/VhloTypeInterfaces.h.inc", +- ), +- ( +- ["-gen-type-interface-defs"], +- "stablehlo/dialect/VhloTypeInterfaces.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloTypeInterfaces.h.inc": ["-gen-type-interface-decls"], ++ "stablehlo/dialect/VhloTypeInterfaces.cpp.inc": ["-gen-type-interface-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloTypes.td", + deps = [ +@@ -1904,16 +1746,10 @@ + gentbl_cc_library( + name = "vhlo_types_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-typedef-decls"], +- "stablehlo/dialect/VhloTypeDefs.h.inc", +- ), +- ( +- ["-gen-typedef-defs"], +- "stablehlo/dialect/VhloTypeDefs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloTypeDefs.h.inc": ["-gen-typedef-decls"], ++ "stablehlo/dialect/VhloTypeDefs.cpp.inc": ["-gen-typedef-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir 
b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir --- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir +++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir @@ -46,6 +655,50 @@ diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehl return %cst : tensor<14x15x0x9xcomplex> } } +diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel +--- stablehlo/stablehlo/tests/BUILD.bazel ++++ stablehlo/stablehlo/tests/BUILD.bazel +@@ -51,16 +51,10 @@ + gentbl_cc_library( + name = "check_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "CheckOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "CheckOps.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "CheckOps.h.inc": ["-gen-op-decls"], ++ "CheckOps.cpp.inc": ["-gen-op-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "CheckOps.td", + deps = [ +@@ -108,15 +102,10 @@ + gentbl_cc_library( + name = "test_utils_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=HloTest", +- ], +- "TestUtils.h.inc", +- ), +- ], ++ tbl_outs = {"TestUtils.h.inc": [ ++ "-gen-pass-decls", ++ "-name=HloTest", ++ ]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "TestUtils.td", + deps = [ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp --- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp +++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index a6f4b05d72cedb..8683dc781d6a82 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1,3 +1,612 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -58,16 +58,10 @@ + gentbl_cc_library( + name = "base_attr_interfaces_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attr-interface-decls"], +- "stablehlo/dialect/BaseAttrInterfaces.h.inc", +- ), +- ( +- ["-gen-attr-interface-defs"], +- "stablehlo/dialect/BaseAttrInterfaces.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/BaseAttrInterfaces.h.inc": ["-gen-attr-interface-decls"], ++ "stablehlo/dialect/BaseAttrInterfaces.cpp.inc": ["-gen-attr-interface-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/Base.td", + deps = [":stablehlo_ops_td_files"], +@@ -107,16 +101,10 @@ + gentbl_cc_library( + name = "chlo_attrs_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attrdef-decls"], +- "stablehlo/dialect/ChloAttrs.h.inc", +- ), +- ( +- ["-gen-attrdef-defs"], +- "stablehlo/dialect/ChloAttrs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/ChloAttrs.h.inc": ["-gen-attrdef-decls"], ++ "stablehlo/dialect/ChloAttrs.cpp.inc": ["-gen-attrdef-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/ChloOps.td", + deps = [ +@@ -173,16 +161,10 @@ + gentbl_cc_library( + name = "chlo_enums_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-enum-decls"], +- "stablehlo/dialect/ChloEnums.h.inc", +- ), +- ( +- ["-gen-enum-defs"], +- "stablehlo/dialect/ChloEnums.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ 
"stablehlo/dialect/ChloEnums.h.inc": ["-gen-enum-decls"], ++ "stablehlo/dialect/ChloEnums.cpp.inc": ["-gen-enum-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/ChloOps.td", + deps = [ +@@ -193,23 +175,14 @@ + gentbl_cc_library( + name = "chlo_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "stablehlo/dialect/ChloOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "stablehlo/dialect/ChloOps.cpp.inc", +- ), +- ( +- [ +- "-gen-dialect-doc", +- "--dialect=chlo", +- ], +- "stablehlo/dialect/chlo.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/ChloOps.h.inc": ["-gen-op-decls"], ++ "stablehlo/dialect/ChloOps.cpp.inc": ["-gen-op-defs"], ++ "stablehlo/dialect/chlo.md": [ ++ "-gen-dialect-doc", ++ "--dialect=chlo", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/ChloOps.td", + deps = [ +@@ -304,12 +277,7 @@ + + gentbl_cc_library( + name = "chlo_rewriters_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/ChloDecompositionPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/ChloDecompositionPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/ChloDecompositionPatterns.td", + deps = [ +@@ -320,12 +288,7 @@ + + gentbl_cc_library( + name = "stablehlo_aggressive_simplification_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td", + deps = [ +@@ -335,12 +298,7 @@ + + gentbl_cc_library( + name = "stablehlo_legalize_deprecated_ops_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/StablehloLegalizeDeprecatedOpsPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/StablehloLegalizeDeprecatedOpsPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloLegalizeDeprecatedOpsPatterns.td", + deps = [ +@@ -350,12 +308,7 @@ + + gentbl_cc_library( + name = "vhlo_rewriters_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/VhloToVersionPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/VhloToVersionPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/VhloToVersionPatterns.td", + deps = [ +@@ -365,12 +318,7 @@ + + gentbl_cc_library( + name = "stablehlo_create_compatibility_expander_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/StablehloCompatibilityExpanderPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/StablehloCompatibilityExpanderPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloCompatibilityExpanderPatterns.td", + deps = [ +@@ -380,12 +328,7 @@ + + gentbl_cc_library( + name = "stablehlo_create_complex_math_expander_inc_gen", +- tbl_outs = [ +- ( +- ["--gen-rewriters"], +- "stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc": ["--gen-rewriters"]}, + tblgen = 
"@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloComplexMathExpanderPatterns.td", + deps = [ +@@ -420,16 +363,10 @@ + gentbl_cc_library( + name = "interpreter_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "stablehlo/reference/InterpreterOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "stablehlo/reference/InterpreterOps.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/reference/InterpreterOps.h.inc": ["-gen-op-decls"], ++ "stablehlo/reference/InterpreterOps.cpp.inc": ["-gen-op-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/reference/InterpreterOps.td", + deps = [ +@@ -451,21 +388,15 @@ + gentbl_cc_library( + name = "interpreter_pass_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=InterpreterTransforms", +- ], +- "stablehlo/reference/InterpreterPasses.h.inc", +- ), +- ( +- [ +- "-gen-pass-doc", +- ], +- "stablehlo/reference/interpreter_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/reference/InterpreterPasses.h.inc": [ ++ "-gen-pass-decls", ++ "-name=InterpreterTransforms", ++ ], ++ "stablehlo/reference/interpreter_passes.md": [ ++ "-gen-pass-doc", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/reference/InterpreterPasses.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -494,21 +425,15 @@ + gentbl_cc_library( + name = "linalg_pass_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=StablehloLinalgTransforms", +- ], +- "stablehlo/conversions/linalg/transforms/Passes.h.inc", +- ), +- ( +- [ +- "-gen-pass-doc", +- ], +- "stablehlo/conversions/linalg/transforms/stablehlo_linalg_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/conversions/linalg/transforms/Passes.h.inc": [ ++ "-gen-pass-decls", ++ "-name=StablehloLinalgTransforms", ++ ], ++ "stablehlo/conversions/linalg/transforms/stablehlo_linalg_passes.md": [ ++ "-gen-pass-doc", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/conversions/linalg/transforms/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -901,16 +826,10 @@ + gentbl_cc_library( + name = "stablehlo_attrs_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attrdef-decls"], +- "stablehlo/dialect/StablehloAttrs.h.inc", +- ), +- ( +- ["-gen-attrdef-defs"], +- "stablehlo/dialect/StablehloAttrs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/StablehloAttrs.h.inc": ["-gen-attrdef-decls"], ++ "stablehlo/dialect/StablehloAttrs.cpp.inc": ["-gen-attrdef-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/StablehloOps.td", + deps = [ +@@ -1067,16 +986,10 @@ + gentbl_cc_library( + name = "stablehlo_enums_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-enum-decls"], +- "stablehlo/dialect/StablehloEnums.h.inc", +- ), +- ( +- ["-gen-enum-defs"], +- "stablehlo/dialect/StablehloEnums.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/StablehloEnums.h.inc": ["-gen-enum-decls"], ++ "stablehlo/dialect/StablehloEnums.cpp.inc": ["-gen-enum-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/StablehloOps.td", + deps = [ +@@ -1087,22 +1000,16 @@ + gentbl_cc_library( + name = "stablehlo_types_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-typedef-decls", +- "--typedefs-dialect=stablehlo", +- ], +- 
"stablehlo/dialect/StablehloTypeDefs.h.inc", +- ), +- ( +- [ +- "-gen-typedef-defs", +- "--typedefs-dialect=stablehlo", +- ], +- "stablehlo/dialect/StablehloTypeDefs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/StablehloTypeDefs.h.inc": [ ++ "-gen-typedef-decls", ++ "--typedefs-dialect=stablehlo", ++ ], ++ "stablehlo/dialect/StablehloTypeDefs.cpp.inc": [ ++ "-gen-typedef-defs", ++ "--typedefs-dialect=stablehlo", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/StablehloOps.td", + deps = [ +@@ -1113,16 +1020,10 @@ + gentbl_cc_library( + name = "stablehlo_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "stablehlo/dialect/StablehloOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "stablehlo/dialect/StablehloOps.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/StablehloOps.h.inc": ["-gen-op-decls"], ++ "stablehlo/dialect/StablehloOps.cpp.inc": ["-gen-op-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/StablehloOps.td", + deps = [ +@@ -1188,20 +1089,14 @@ + gentbl_cc_library( + name = "stablehlo_pass_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- ], +- "stablehlo/transforms/Passes.h.inc", +- ), +- ( +- [ +- "-gen-pass-doc", +- ], +- "stablehlo/transforms/stablehlo_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/transforms/Passes.h.inc": [ ++ "-gen-pass-decls", ++ ], ++ "stablehlo/transforms/stablehlo_passes.md": [ ++ "-gen-pass-doc", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -1304,21 +1199,15 @@ + gentbl_cc_library( + name = "stablehlo_passes_optimization_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=Optimization", +- ], +- "stablehlo/transforms/optimization/Passes.h.inc", +- ), +- ( +- [ +- "-gen-pass-doc", +- ], +- "stablehlo/transforms/optimization/stablehlo_optimization_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/transforms/optimization/Passes.h.inc": [ ++ "-gen-pass-decls", ++ "-name=Optimization", ++ ], ++ "stablehlo/transforms/optimization/stablehlo_optimization_passes.md": [ ++ "-gen-pass-doc", ++ ], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/optimization/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -1555,19 +1444,13 @@ + gentbl_cc_library( + name = "tosa_pass_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=StablehloTOSATransforms", +- ], +- "stablehlo/conversions/tosa/transforms/Passes.h.inc", +- ), +- ( +- ["-gen-pass-doc"], +- "stablehlo/conversions/tosa/transforms/stablehlo_tosa_passes.md", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/conversions/tosa/transforms/Passes.h.inc": [ ++ "-gen-pass-decls", ++ "-name=StablehloTOSATransforms", ++ ], ++ "stablehlo/conversions/tosa/transforms/stablehlo_tosa_passes.md": ["-gen-pass-doc"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/conversions/tosa/transforms/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +@@ -1607,12 +1490,7 @@ + + gentbl_cc_library( + name = "tosa_pdll_inc_gen", +- tbl_outs = [ +- ( +- ["-x=cpp"], +- "stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll.h.inc", +- ), +- ], ++ tbl_outs = {"stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll.h.inc": ["-x=cpp"]}, + tblgen = 
"@llvm-project//mlir:mlir-pdll", + td_file = "stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll", + deps = [ +@@ -1710,16 +1588,10 @@ + gentbl_cc_library( + name = "vhlo_attr_interfaces_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attr-interface-decls"], +- "stablehlo/dialect/VhloAttrInterfaces.h.inc", +- ), +- ( +- ["-gen-attr-interface-defs"], +- "stablehlo/dialect/VhloAttrInterfaces.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloAttrInterfaces.h.inc": ["-gen-attr-interface-decls"], ++ "stablehlo/dialect/VhloAttrInterfaces.cpp.inc": ["-gen-attr-interface-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloEnums.td", + deps = [ +@@ -1730,16 +1602,10 @@ + gentbl_cc_library( + name = "vhlo_attrs_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-attrdef-decls"], +- "stablehlo/dialect/VhloAttrs.h.inc", +- ), +- ( +- ["-gen-attrdef-defs"], +- "stablehlo/dialect/VhloAttrs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloAttrs.h.inc": ["-gen-attrdef-decls"], ++ "stablehlo/dialect/VhloAttrs.cpp.inc": ["-gen-attrdef-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ +@@ -1750,16 +1616,10 @@ + gentbl_cc_library( + name = "vhlo_enums_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-enum-decls"], +- "stablehlo/dialect/VhloEnums.h.inc", +- ), +- ( +- ["-gen-enum-defs"], +- "stablehlo/dialect/VhloEnums.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloEnums.h.inc": ["-gen-enum-decls"], ++ "stablehlo/dialect/VhloEnums.cpp.inc": ["-gen-enum-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloEnums.td", + deps = [ +@@ -1770,16 +1630,10 @@ + gentbl_cc_library( + name = "vhlo_op_interfaces_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-interface-decls"], +- "stablehlo/dialect/VhloOpInterfaces.h.inc", +- ), +- ( +- ["-gen-op-interface-defs"], +- "stablehlo/dialect/VhloOpInterfaces.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloOpInterfaces.h.inc": ["-gen-op-interface-decls"], ++ "stablehlo/dialect/VhloOpInterfaces.cpp.inc": ["-gen-op-interface-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ +@@ -1822,16 +1676,10 @@ + gentbl_cc_library( + name = "vhlo_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "stablehlo/dialect/VhloOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "stablehlo/dialect/VhloOps.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloOps.h.inc": ["-gen-op-decls"], ++ "stablehlo/dialect/VhloOps.cpp.inc": ["-gen-op-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ +@@ -1884,16 +1732,10 @@ + gentbl_cc_library( + name = "vhlo_type_interfaces_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-type-interface-decls"], +- "stablehlo/dialect/VhloTypeInterfaces.h.inc", +- ), +- ( +- ["-gen-type-interface-defs"], +- "stablehlo/dialect/VhloTypeInterfaces.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloTypeInterfaces.h.inc": ["-gen-type-interface-decls"], ++ "stablehlo/dialect/VhloTypeInterfaces.cpp.inc": ["-gen-type-interface-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloTypes.td", + deps = [ +@@ -1904,16 +1746,10 @@ + 
gentbl_cc_library( + name = "vhlo_types_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-typedef-decls"], +- "stablehlo/dialect/VhloTypeDefs.h.inc", +- ), +- ( +- ["-gen-typedef-defs"], +- "stablehlo/dialect/VhloTypeDefs.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "stablehlo/dialect/VhloTypeDefs.h.inc": ["-gen-typedef-decls"], ++ "stablehlo/dialect/VhloTypeDefs.cpp.inc": ["-gen-typedef-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir --- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir +++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir @@ -46,6 +655,50 @@ diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehl return %cst : tensor<14x15x0x9xcomplex> } } +diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel +--- stablehlo/stablehlo/tests/BUILD.bazel ++++ stablehlo/stablehlo/tests/BUILD.bazel +@@ -51,16 +51,10 @@ + gentbl_cc_library( + name = "check_ops_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- ["-gen-op-decls"], +- "CheckOps.h.inc", +- ), +- ( +- ["-gen-op-defs"], +- "CheckOps.cpp.inc", +- ), +- ], ++ tbl_outs = { ++ "CheckOps.h.inc": ["-gen-op-decls"], ++ "CheckOps.cpp.inc": ["-gen-op-defs"], ++ }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "CheckOps.td", + deps = [ +@@ -108,15 +102,10 @@ + gentbl_cc_library( + name = "test_utils_inc_gen", + strip_include_prefix = ".", +- tbl_outs = [ +- ( +- [ +- "-gen-pass-decls", +- "-name=HloTest", +- ], +- "TestUtils.h.inc", +- ), +- ], ++ tbl_outs = {"TestUtils.h.inc": [ ++ "-gen-pass-decls", ++ "-name=HloTest", ++ ]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "TestUtils.td", + deps = [ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp --- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp +++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp From 90d8302d74b229b4f4a6259ffde0cb723e0f9edb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 09:01:20 -0700 Subject: [PATCH 0504/1324] Fix uses of `reinterpret_cast` in internal XLA and suppress warnings on some uses. 
PiperOrigin-RevId: 746051752 --- third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc index 24c82637e42041..d917111ad31247 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc @@ -141,13 +141,13 @@ TEST(SafeReinterpretCast, CanCastRestrictPointerToRestrictPointer) { void Dummy() {} -TEST(SafeReinterepretCast, CanCastFuncPointerToFromVoidPointer) { +TEST(SafeReinterpretCast, CanCastFuncPointerToFromVoidPointer) { void* const void_p = safe_reinterpret_cast(&Dummy); void (*func_p)() = safe_reinterpret_cast(void_p); EXPECT_EQ(func_p, &Dummy); } -TEST(SafeReinterepretCast, CanCastDataPointerToFromVoidPointer) { +TEST(SafeReinterpretCast, CanCastDataPointerToFromVoidPointer) { int x = 42; void* const void_p = safe_reinterpret_cast(&x); int* const int_p = safe_reinterpret_cast(void_p); From dcb4d8d797bc2e0f4e339281c30eaf7353b66a76 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 10 Apr 2025 09:07:33 -0700 Subject: [PATCH 0505/1324] More HLO->StableHLO Direct Conversions Dot, Tuple, and Compare PiperOrigin-RevId: 746053852 --- .../xla/xla/hlo/translate/hlo_to_mhlo/BUILD | 1 + .../hlo_to_mhlo/attribute_importer.cc | 150 ++++++++++++++++++ .../hlo_to_mhlo/attribute_importer.h | 28 ++++ .../hlo_to_mhlo/hlo_function_importer.cc | 106 +++++++------ .../hlo/translate/hlo_to_mhlo/hlo_utils.cc | 34 +++- .../tests/import_emit_stablehlo.hlo | 68 ++++---- 6 files changed, 300 insertions(+), 87 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD index d62735b16d7a5a..8b5a77f674036f 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -193,6 +193,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:SparseTensorEnums", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc index 0be4b75f5e11e8..9f449a64c5c794 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc @@ -33,6 +33,8 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -43,6 +45,154 @@ limitations under the License. 
#include "xla/xla_data.pb.h" namespace xla { +namespace stablehlo { + +mlir::stablehlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers( + const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) { + std::vector offset_dims(dnums.offset_dims().begin(), + dnums.offset_dims().end()); + std::vector collapsed_slice_dims( + dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); + std::vector operand_batching_dims( + dnums.operand_batching_dims().begin(), + dnums.operand_batching_dims().end()); + std::vector start_indices_batching_dims( + dnums.start_indices_batching_dims().begin(), + dnums.start_indices_batching_dims().end()); + std::vector start_index_map(dnums.start_index_map().begin(), + dnums.start_index_map().end()); + return mlir::stablehlo::GatherDimensionNumbersAttr::get( + builder->getContext(), offset_dims, collapsed_slice_dims, + operand_batching_dims, start_indices_batching_dims, start_index_map, + dnums.index_vector_dim()); +} + +mlir::stablehlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers( + const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) { + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + std::vector inserted_window_dims( + dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); + std::vector input_batching_dims(dnums.input_batching_dims().begin(), + dnums.input_batching_dims().end()); + std::vector scatter_indices_batching_dims( + dnums.scatter_indices_batching_dims().begin(), + dnums.scatter_indices_batching_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + return mlir::stablehlo::ScatterDimensionNumbersAttr::get( + builder->getContext(), update_window_dims, inserted_window_dims, + input_batching_dims, scatter_indices_batching_dims, + scatter_dims_to_operand_dims, dnums.index_vector_dim()); +} + +mlir::stablehlo::DotAlgorithmAttr ConvertDotAlgorithm( + const PrecisionConfig::Algorithm algorithm, mlir::Builder* builder) { + mlir::Type lhs, rhs, accum; + int64_t lhsComponentCount = 1, rhsComponentCount = 1, + numPrimitiveOperations = 1; + bool allowImpreciseAccumulation = false; + switch (algorithm) { + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: { + lhs = rhs = builder->getType(); + accum = builder->getF32Type(); + break; + } + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: { + lhs = rhs = builder->getType(); + accum = builder->getF32Type(); + allowImpreciseAccumulation = true; + break; + } + case PrecisionConfig::ALG_DOT_F16_F16_F16: { + lhs = rhs = accum = builder->getF16Type(); + break; + } + case PrecisionConfig::ALG_DOT_F16_F16_F32: { + lhs = rhs = builder->getF16Type(); + accum = builder->getF32Type(); + break; + } + case PrecisionConfig::ALG_DOT_BF16_BF16_BF16: { + lhs = rhs = accum = builder->getBF16Type(); + break; + } + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: { + lhs = rhs = builder->getBF16Type(); + accum = builder->getF32Type(); + break; + } + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: { + lhs = rhs = builder->getBF16Type(); + accum = builder->getF32Type(); + numPrimitiveOperations = 3; + break; + } + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: { + lhs = rhs = builder->getBF16Type(); + accum = builder->getF32Type(); + numPrimitiveOperations = 6; + break; + } + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: { + lhs = rhs = builder->getTF32Type(); + accum = builder->getF32Type(); + break; 
+    }
+    case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: {
+      lhs = rhs = builder->getTF32Type();
+      accum = builder->getF32Type();
+      numPrimitiveOperations = 3;
+      break;
+    }
+    case PrecisionConfig::ALG_DOT_F32_F32_F32: {
+      lhs = rhs = accum = builder->getF32Type();
+      break;
+    }
+    case PrecisionConfig::ALG_DOT_F64_F64_F64: {
+      lhs = rhs = accum = builder->getF64Type();
+      break;
+    }
+    default:
+      // Unset, sentinels
+      return mlir::stablehlo::DotAlgorithmAttr{};
+  }
+  return mlir::stablehlo::DotAlgorithmAttr::get(
+      builder->getContext(), lhs, rhs, accum, lhsComponentCount,
+      rhsComponentCount, numPrimitiveOperations, allowImpreciseAccumulation);
+}
+
+mlir::stablehlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers(
+    const DotDimensionNumbers& dnums, mlir::Builder* builder) {
+  auto arrayref = [](absl::Span<const int64_t> array) {
+    return llvm::ArrayRef{array.data(), array.size()};
+  };
+  return mlir::stablehlo::DotDimensionNumbersAttr::get(
+      builder->getContext(), arrayref(dnums.lhs_batch_dimensions()),
+      arrayref(dnums.rhs_batch_dimensions()),
+      arrayref(dnums.lhs_contracting_dimensions()),
+      arrayref(dnums.rhs_contracting_dimensions()));
+}
+
+mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
+                                       mlir::Builder* builder) {
+  if (!config) return {};
+
+  // TODO(b/129709049) The HLO text format elides this in the all DEFAULT
+  // case and the parser sticks it in. Maybe we should too.
+  llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
+
+  for (auto prec : config->operand_precision()) {
+    operand_precision_attrs.push_back(mlir::stablehlo::PrecisionAttr::get(
+        builder->getContext(), mlir::stablehlo::symbolizePrecision(
+                                   PrecisionConfig_Precision_Name(prec))
+                                   .value()));
+  }
+  return builder->getArrayAttr(operand_precision_attrs);
+}
+
+}  // namespace stablehlo
 
 mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
                                        mlir::Builder* builder) {
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h
index b64ed37483a264..6cace3ee16a72c 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h
@@ -25,6 +25,7 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/shape.h"
@@ -33,23 +34,50 @@
 namespace xla {
 
+namespace stablehlo {
+// Converts the gather dimensions to attributes.
+mlir::stablehlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers(
+    const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder);
+
+// Converts the scatter dimensions to attributes.
+mlir::stablehlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers(
+    const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder);
+
+// Converts the dot algorithm to attributes.
+mlir::stablehlo::DotAlgorithmAttr ConvertDotAlgorithm(
+    PrecisionConfig::Algorithm algorithm, mlir::Builder* builder);
+
+// Converts the dot dimensions to attributes.
+mlir::stablehlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers(
+    const DotDimensionNumbers& dnums, mlir::Builder* builder);
+
+// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
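+// Unlike the mhlo variants further down in this header, the overloads in
+// xla::stablehlo produce StableHLO attribute types.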
+mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
+                                       mlir::Builder* builder);
+
+}  // namespace stablehlo
+
 // Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
 mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
                                        mlir::Builder* builder);
 
 // Converts the gather dimensions to attributes.
+// [Deprecated] Used in TF2XLA only.
 mlir::mhlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers(
     const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder);
 
 // Converts the scatter dimensions to attributes.
+// [Deprecated] Used in TF2XLA only.
 mlir::mhlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers(
     const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder);
 
 // Converts the dot algorithm to attributes.
+// Used by sparse dot.
 mlir::mhlo::DotAlgorithmAttr ConvertDotAlgorithm(
     PrecisionConfig::Algorithm algorithm, mlir::Builder* builder);
 
 // Converts the dot dimensions to attributes.
+// Used by sparse dot.
 mlir::mhlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers(
     const DotDimensionNumbers& dnums, mlir::Builder* builder);
 
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
index 177f3fc0386dfb..a362913ba730e8 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -900,47 +900,62 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
     case HloOpcode::kDot: {
       auto dot = Cast<HloDotInstruction>(instruction);
+      if (dot->sparse_operands()) {
+        attributes.push_back(builder_->getNamedAttr(
+            "precision_config",
+            ConvertPrecisionConfig(&instruction->precision_config(),
+                                   builder_)));
+        if (instruction->precision_config().algorithm() !=
+            PrecisionConfig::ALG_UNSET) {
+          attributes.push_back(builder_->getNamedAttr(
+              "algorithm",
+              ConvertDotAlgorithm(instruction->precision_config().algorithm(),
+                                  builder_)));
+        }
+        attributes.push_back(builder_->getNamedAttr(
+            "dot_dimension_numbers",
+            ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
+                                       builder_)));
+        for (const SparsityDescriptor& descriptor : dot->sparsity()) {
+          TF_ASSIGN_OR_RETURN(auto sparsity,
+                              ConvertSparsityDescriptor(descriptor, builder_));
+          attributes.push_back(builder_->getNamedAttr(
+              descriptor.index() == 0 ? "lhs_sparsity" : "rhs_sparsity",
+              sparsity));
+        }
+        // XLA Feature -- MHLO Only
+        return func_builder
+            ->create<mlir::mhlo::SparseDotOp>(loc, result_type, operands,
+                                              attributes)
+            .getOperation();
+      }
+
+      // Dot or DotGeneral
       attributes.push_back(builder_->getNamedAttr(
-          "precision_config",
-          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
+          "precision_config", stablehlo::ConvertPrecisionConfig(
+                                  &instruction->precision_config(), builder_)));
       if (instruction->precision_config().algorithm() !=
           PrecisionConfig::ALG_UNSET) {
         attributes.push_back(builder_->getNamedAttr(
             "algorithm",
-            ConvertDotAlgorithm(instruction->precision_config().algorithm(),
-                                builder_)));
+            stablehlo::ConvertDotAlgorithm(
+                instruction->precision_config().algorithm(), builder_)));
       }
       // Consider consolidating DotOps together.
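+      // DotIsDefault() routes between the two StableHLO forms below:
+      // canonical dots lower to the plain stablehlo.dot, while anything with
+      // explicit batching/contracting dimensions needs stablehlo.dot_general.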
       if (DotIsDefault(instruction) && !dot->sparse_operands()) {
-        // TODO(b/408024772) ToStablehlo: Convert[PrecisionConfig|DotAlgorithm]
         return func_builder
-            ->create<mlir::mhlo::DotOp>(loc, result_type, operands, attributes)
+            ->create<mlir::stablehlo::DotOp>(loc, result_type, operands,
+                                             attributes)
             .getOperation();
       }
 
       attributes.push_back(builder_->getNamedAttr(
           "dot_dimension_numbers",
-          ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
-                                     builder_)));
-      if (!dot->sparse_operands()) {
-        // TODO(b/408024772) ToStablehlo: ConvertDotDimensionNumbers
-        return func_builder
-            ->create<mlir::mhlo::DotGeneralOp>(loc, result_type, operands,
-                                               attributes)
-            .getOperation();
-      }
-
-      for (const SparsityDescriptor& descriptor : dot->sparsity()) {
-        TF_ASSIGN_OR_RETURN(auto sparsity,
-                            ConvertSparsityDescriptor(descriptor, builder_));
-        attributes.push_back(builder_->getNamedAttr(
-            descriptor.index() == 0 ? "lhs_sparsity" : "rhs_sparsity",
-            sparsity));
-      }
-      // XLA Feature -- MHLO Only
+          stablehlo::ConvertDotDimensionNumbers(
+              instruction->dot_dimension_numbers(), builder_)));
       return func_builder
-          ->create<mlir::mhlo::SparseDotOp>(loc, result_type, operands,
-                                            attributes)
+          ->create<mlir::stablehlo::DotGeneralOp>(loc, result_type, operands,
+                                                  attributes)
           .getOperation();
     }
     case HloOpcode::kRaggedAllToAll: {
@@ -1231,10 +1246,9 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
           compare->operand(0)->shape().element_type());
       if (compare->type() != default_type)
         attributes.push_back(ConvertComparisonType(compare->type()));
-      // TODO(b/408024772) ToStableHLO: ConvertComparison[Direction|Type]
       return func_builder
-          ->create<mlir::mhlo::CompareOp>(loc, result_type, operands,
-                                          attributes)
+          ->create<mlir::stablehlo::CompareOp>(loc, result_type, operands,
+                                               attributes)
          .getOperation();
     }
     case HloOpcode::kCholesky: {
@@ -1250,21 +1264,21 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       auto gather_instruction = Cast<HloGatherInstruction>(instruction);
       attributes.push_back(builder_->getNamedAttr(
           "dimension_numbers",
-          ConvertGatherDimensionNumbers(
+          stablehlo::ConvertGatherDimensionNumbers(
               gather_instruction->gather_dimension_numbers(), builder_)));
 
       std::vector<int64_t> slice_sizes(
           gather_instruction->gather_slice_sizes().begin(),
           gather_instruction->gather_slice_sizes().end());
       attributes.push_back(
-          builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
+          builder_->getNamedAttr("slice_sizes", ConvertArray(slice_sizes)));
       attributes.push_back(builder_->getNamedAttr(
           "indices_are_sorted",
           builder_->getBoolAttr(gather_instruction->indices_are_sorted())));
 
-      // TODO(b/408024772) ToStableHLO: ConvertGatherDimensionNumbers
       return func_builder
-          ->create<mlir::mhlo::GatherOp>(loc, result_type, operands, attributes)
+          ->create<mlir::stablehlo::GatherOp>(loc, result_type, operands,
+                                              attributes)
           .getOperation();
     }
     case HloOpcode::kDynamicSlice: {
@@ -1357,8 +1371,8 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       auto scatter = Cast<HloScatterInstruction>(instruction);
       attributes.push_back(builder_->getNamedAttr(
           "scatter_dimension_numbers",
-          ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
-                                         builder_)));
+          stablehlo::ConvertScatterDimensionNumbers(
+              scatter->scatter_dimension_numbers(), builder_)));
       attributes.push_back(builder_->getNamedAttr(
           "indices_are_sorted",
           builder_->getBoolAttr(scatter->indices_are_sorted())));
@@ -1368,8 +1382,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       llvm::SmallVector<mlir::Type> flattened_types;
       FlattenTupleType(result_type, flattened_types);
 
-      // TODO(b/408024772) ToStableHLO: ConvertScatterDimensionNumbers
-      auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
+      auto scatter_op = func_builder->create<mlir::stablehlo::ScatterOp>(
           loc, flattened_types, operands, attributes);
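+      // The tuple result type was flattened into flattened_types above;
+      // scatter->to_apply() is imported as the new op's update region below.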
       TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
                                         &scatter_op.getUpdateComputation()));
@@ -1692,12 +1705,10 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
         return_types = llvm::to_vector<6>(tuple_ty.getTypes());
       }
 
-      // TODO(b/408024772) ToStablehlo: ReduceOp builder is different in
-      // StableHLO
-      auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
-          loc, return_types, llvm::ArrayRef(operands).take_front(num_inputs),
-          llvm::ArrayRef(operands).drop_front(num_inputs),
-          ConvertDimensions(instruction->dimensions()));
+      auto reduce = func_builder->create<mlir::stablehlo::ReduceOp>(
+          loc, return_types, mlir::ValueRange(operands).take_front(num_inputs),
+          mlir::ValueRange(operands).drop_front(num_inputs),
+          ConvertArray(ToArrayRef(instruction->dimensions())));
       TF_RETURN_IF_ERROR(
           ImportAsRegion(*instruction->to_apply(), &reduce.getBody()));
@@ -1785,7 +1796,6 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       if (rng_op->shape().IsArray()) {
         return op.getOperation();
       }
-      // TODO(b/408024772) ToStablehlo: CreateTupleFromOpResults
       return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                       result_type);
     }
@@ -2385,8 +2395,8 @@ mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
     ComparisonDirection direction) {
   return builder_->getNamedAttr(
       "comparison_direction",
-      mlir::mhlo::ComparisonDirectionAttr::get(
-          builder_->getContext(), mlir::mhlo::symbolizeComparisonDirection(
+      mlir::stablehlo::ComparisonDirectionAttr::get(
+          builder_->getContext(), mlir::stablehlo::symbolizeComparisonDirection(
                                       ComparisonDirectionToString(direction))
                                       .value()));
 }
@@ -2395,9 +2405,9 @@ mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
     Comparison::Type type) {
   return builder_->getNamedAttr(
       "compare_type",
-      mlir::mhlo::ComparisonTypeAttr::get(
+      mlir::stablehlo::ComparisonTypeAttr::get(
           builder_->getContext(),
-          mlir::mhlo::symbolizeComparisonType(ComparisonTypeToString(type))
+          mlir::stablehlo::symbolizeComparisonType(ComparisonTypeToString(type))
               .value()));
 }
 
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc
index eb5744bdc6109f..ccf55a9739a41e 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/ValueRange.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir/utils/type_util.h" @@ -174,6 +175,23 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( vector); } +namespace { +bool HasMhloTokenType(mlir::TypeRange types) { + bool use_mhlo = false; + for (auto type : types) { + if (!use_mhlo) { + type.walk([&](mlir::Type type) { + use_mhlo |= llvm::isa(type); + if (use_mhlo) return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); + }); + } + } + return use_mhlo; +} + +} // namespace + mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, mlir::ValueRange& flatten_values, mlir::Type type) { @@ -190,7 +208,11 @@ mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, flatten_sub_values.push_back( CreateTupleValue(func_builder, loc, flatten_values, child_type)); - return func_builder->create(loc, flatten_sub_values) + if (HasMhloTokenType(mlir::TypeRange(flatten_sub_values))) { + return func_builder->create(loc, flatten_sub_values) + .getResult(); + } + return func_builder->create(loc, flatten_sub_values) .getResult(); } @@ -203,10 +225,12 @@ mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, mlir::ValueRange flattened_results_ref(op->getResults()); auto result = CreateTupleValue(func_builder, loc, flattened_results_ref, type); - auto defining_tuple_op = result.getDefiningOp(); - assert(defining_tuple_op && "builder didn't return the right type"); - auto tupleOp = defining_tuple_op.getOperation(); - return tupleOp; + mlir::Operation* tuple_op = result.getDefiningOp(); + if (!tuple_op) { + tuple_op = result.getDefiningOp(); + } + assert(tuple_op && "builder didn't return the right type"); + return tuple_op; } mlir::Operation* WrapVariadicResultsInTuple(mlir::OpBuilder* builder, diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo index d76a2b2fe51f21..0fe3161a055ff9 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo @@ -5,7 +5,7 @@ // minimized and named to reflect the test intent. // Regnerate them using the following command: -// $ TFILE=/path/to/import_emit_stablehlo.hlo +// $ TFILE=/path/to/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo // $ DELIM="Hlo Module" # Remove the space in the middle when running cmd. This comment needs the space since the source file is regex matched. 
// $ xla-translate $TFILE -hlo-text-to-mlir-hlo --emit-stablehlo --split-input-file --hlo-import-all-computations | \ // third_party/llvm/llvm-project/mlir/utils/generate-test-checks.py --source $TFILE --source_delim_regex="$DELIM" --starts_from_scope=0 -i @@ -855,13 +855,13 @@ ENTRY %main.8 (Arg_0.1: f32[6], Arg_1.2: f32[6], Arg_2.3: s32[3], Arg_3.4: s32[3 // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func private @top_k_gt_comparator.5(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor) -> tensor { -// CHECK: %[[VAL_4:.*]] = mhlo.compare GT, %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = stablehlo.compare GT, %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor // CHECK: return %[[VAL_4]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_5:.*]]: tensor<16x256xbf16>, %[[VAL_6:.*]]: tensor, %[[VAL_7:.*]]: tensor<16x256xi32>, %[[VAL_8:.*]]: tensor) -> tuple, tensor<16x4xi32>> { // CHECK: %[[VAL_9:.*]]:2 = "stablehlo.sort"(%[[VAL_5]], %[[VAL_7]]) <{dimension = 1 : i64, is_stable = false}> ({ // CHECK: ^bb0(%[[VAL_10:.*]]: tensor, %[[VAL_11:.*]]: tensor, %[[VAL_12:.*]]: tensor, %[[VAL_13:.*]]: tensor): -// CHECK: %[[VAL_14:.*]] = mhlo.compare GT, %[[VAL_10]], %[[VAL_11]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_14:.*]] = stablehlo.compare GT, %[[VAL_10]], %[[VAL_11]] : (tensor, tensor) -> tensor // CHECK: stablehlo.return %[[VAL_14]] : tensor // CHECK: }) : (tensor<16x256xbf16>, tensor<16x256xi32>) -> (tensor<16x256xbf16>, tensor<16x256xi32>) // CHECK: %[[VAL_15:.*]] = stablehlo.slice %[[VAL_16:.*]]#0 [0:16, 0:4] : (tensor<16x256xbf16>) -> tensor<16x4xbf16> @@ -900,11 +900,11 @@ ENTRY %main.20 (Arg_0.1: bf16[16,256], Arg_1.2: s32[], Arg_2.3: s32[16,256], Arg // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func private @top_k_gt_comparator.5(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor) -> tensor { -// CHECK: %[[VAL_4:.*]] = mhlo.compare GT, %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = stablehlo.compare GT, %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor // CHECK: return %[[VAL_4]] : tensor // CHECK: } // CHECK: func.func private @top_k_gt_comparator.14(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor, %[[VAL_7:.*]]: tensor, %[[VAL_8:.*]]: tensor) -> tensor { -// CHECK: %[[VAL_9:.*]] = mhlo.compare GT, %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = stablehlo.compare GT, %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: return %[[VAL_9]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_10:.*]]: tensor<16x256xbf16>, %[[VAL_11:.*]]: tensor, %[[VAL_12:.*]]: tensor<16x256xi32>, %[[VAL_13:.*]]: tensor) -> tuple, tensor<16x4xi32>> { @@ -913,7 +913,7 @@ ENTRY %main.20 (Arg_0.1: bf16[16,256], Arg_1.2: s32[], Arg_2.3: s32[16,256], Arg // CHECK: %[[VAL_16:.*]] = stablehlo.get_tuple_element %[[VAL_14]][1] : (tuple, tensor<16x128xi32>>) -> tensor<16x128xi32> // CHECK: %[[VAL_17:.*]]:2 = "stablehlo.sort"(%[[VAL_15]], %[[VAL_16]]) <{dimension = 1 : i64, is_stable = false}> ({ // CHECK: ^bb0(%[[VAL_18:.*]]: tensor, %[[VAL_19:.*]]: tensor, %[[VAL_20:.*]]: tensor, %[[VAL_21:.*]]: tensor): -// CHECK: %[[VAL_22:.*]] = 
mhlo.compare GT, %[[VAL_18]], %[[VAL_19]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_22:.*]] = stablehlo.compare GT, %[[VAL_18]], %[[VAL_19]] : (tensor, tensor) -> tensor // CHECK: stablehlo.return %[[VAL_22]] : tensor // CHECK: }) : (tensor<16x128xbf16>, tensor<16x128xi32>) -> (tensor<16x128xbf16>, tensor<16x128xi32>) // CHECK: %[[VAL_23:.*]] = stablehlo.slice %[[VAL_24:.*]]#0 [0:16, 0:4] : (tensor<16x128xbf16>) -> tensor<16x4xbf16> @@ -1046,7 +1046,7 @@ ENTRY %main.6 (Arg_0.1: f32[2,3]) -> (f32[2,3], f16[4,5]) { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi8>, %[[VAL_1:.*]]: tensor<3xi8>) -> tensor { -// CHECK: %[[VAL_2:.*]] = "mhlo.dot"(%[[VAL_0]], %[[VAL_1]]) <{precision_config = [#mhlo, #mhlo]}> : (tensor<3xi8>, tensor<3xi8>) -> tensor +// CHECK: %[[VAL_2:.*]] = stablehlo.dot %[[VAL_0]], %[[VAL_1]], precision = [DEFAULT, DEFAULT] : (tensor<3xi8>, tensor<3xi8>) -> tensor // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: } @@ -1062,7 +1062,7 @@ ENTRY %main.4 (Arg_0.1: s8[3], Arg_1.2: s8[3]) -> s64[] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi4>, %[[VAL_1:.*]]: tensor<3xi4>) -> tensor { -// CHECK: %[[VAL_2:.*]] = "mhlo.dot"(%[[VAL_0]], %[[VAL_1]]) <{precision_config = [#mhlo, #mhlo]}> : (tensor<3xi4>, tensor<3xi4>) -> tensor +// CHECK: %[[VAL_2:.*]] = stablehlo.dot %[[VAL_0]], %[[VAL_1]], precision = [DEFAULT, DEFAULT] : (tensor<3xi4>, tensor<3xi4>) -> tensor // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: } @@ -1078,7 +1078,7 @@ ENTRY %main.4 (Arg_0.1: s4[3], Arg_1.2: s4[3]) -> s8[] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xui4>, %[[VAL_1:.*]]: tensor<3xui4>) -> tensor { -// CHECK: %[[VAL_2:.*]] = "mhlo.dot"(%[[VAL_0]], %[[VAL_1]]) <{precision_config = [#mhlo, #mhlo]}> : (tensor<3xui4>, tensor<3xui4>) -> tensor +// CHECK: %[[VAL_2:.*]] = stablehlo.dot %[[VAL_0]], %[[VAL_1]], precision = [DEFAULT, DEFAULT] : (tensor<3xui4>, tensor<3xui4>) -> tensor // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: } @@ -1094,7 +1094,7 @@ ENTRY %main.4 (Arg_0.1: u4[3], Arg_1.2: u4[3]) -> u8[] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x2x2xi8>, %[[VAL_1:.*]]: tensor<2x2x3xi8>) -> tensor<2x2x3xi32> { -// CHECK: %[[VAL_2:.*]] = "mhlo.dot_general"(%[[VAL_0]], %[[VAL_1]]) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32> +// CHECK: %[[VAL_2:.*]] = stablehlo.dot_general %[[VAL_0]], %[[VAL_1]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32> // CHECK: return %[[VAL_2]] : tensor<2x2x3xi32> // CHECK: } // CHECK: } @@ -1127,7 +1127,7 @@ ENTRY %main.5 (Arg_0.1: bf16[10,16], Arg_1.2: bf16[32,20], Arg_2.3: u16[10,2]) - // CHECK-LABEL: module @main attributes 
{mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>, %[[VAL_1:.*]]: tensor<4x5xi32>) -> tensor<3x5xi32> { -// CHECK: %[[VAL_2:.*]] = "mhlo.dot"(%[[VAL_0]], %[[VAL_1]]) <{precision_config = [#mhlo, #mhlo]}> {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> +// CHECK: %[[VAL_2:.*]] = stablehlo.dot %[[VAL_0]], %[[VAL_1]], precision = [DEFAULT, DEFAULT] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> // CHECK: %[[VAL_3:.*]] = stablehlo.transpose %[[VAL_2]], dims = [0, 1] : (tensor<3x5xi32>) -> tensor<3x5xi32> // CHECK: return %[[VAL_3]] : tensor<3x5xi32> // CHECK: } @@ -1160,7 +1160,7 @@ ENTRY %main.3 (Arg_0.1: f32[3,9]) -> c64[3,5] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<200x100x300xf32>, %[[VAL_1:.*]]: tensor<10x2xi32>) -> tensor<10x300xf32> { -// CHECK: %[[VAL_2:.*]] = "mhlo.gather"(%[[VAL_0]], %[[VAL_1]]) <{dimension_numbers = #mhlo.gather, indices_are_sorted = true, slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>}> : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> +// CHECK: %[[VAL_2:.*]] = "stablehlo.gather"(%[[VAL_0]], %[[VAL_1]]) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = array}> : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> // CHECK: return %[[VAL_2]] : tensor<10x300xf32> // CHECK: } // CHECK: } @@ -1176,7 +1176,7 @@ ENTRY %main.4 (Arg_0.1: f32[200,100,300], Arg_1.2: s32[10,2]) -> f32[10,300] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<200x100x300xf32>, %[[VAL_1:.*]]: tensor<100x200x1xi32>) -> tensor<100x200x300xf32> { -// CHECK: %[[VAL_2:.*]] = "mhlo.gather"(%[[VAL_0]], %[[VAL_1]]) <{dimension_numbers = #mhlo.gather, indices_are_sorted = true, slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>}> : (tensor<200x100x300xf32>, tensor<100x200x1xi32>) -> tensor<100x200x300xf32> +// CHECK: %[[VAL_2:.*]] = "stablehlo.gather"(%[[VAL_0]], %[[VAL_1]]) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = array}> : (tensor<200x100x300xf32>, tensor<100x200x1xi32>) -> tensor<100x200x300xf32> // CHECK: return %[[VAL_2]] : tensor<100x200x300xf32> // CHECK: } // CHECK: } @@ -1508,11 +1508,11 @@ ENTRY %main.6 (Arg_0.1: token[]) -> token[] { // CHECK: return %[[VAL_6]] : tuple, tensor> // CHECK: } // CHECK: func.func @main(%[[VAL_7:.*]]: tensor<1x10xf32>, %[[VAL_8:.*]]: tensor<1x10xi32>, %[[VAL_9:.*]]: tensor, %[[VAL_10:.*]]: tensor) -> tuple, tensor<1xi32>> { -// CHECK: %[[VAL_11:.*]]:2 = mhlo.reduce(%[[VAL_7]] init: %[[VAL_9]]), (%[[VAL_8]] init: %[[VAL_10]]) across dimensions = [1] : (tensor<1x10xf32>, tensor<1x10xi32>, tensor, tensor) -> (tensor<1xf32>, tensor<1xi32>) +// CHECK: %[[VAL_11:.*]]:2 = stablehlo.reduce(%[[VAL_7]] init: %[[VAL_9]]), (%[[VAL_8]] init: %[[VAL_10]]) across dimensions = [1] : (tensor<1x10xf32>, tensor<1x10xi32>, tensor, tensor) -> (tensor<1xf32>, tensor<1xi32>) // CHECK: reducer(%[[VAL_12:.*]]: tensor, %[[VAL_13:.*]]: tensor) 
(%[[VAL_14:.*]]: tensor, %[[VAL_15:.*]]: tensor) { // CHECK: %[[VAL_16:.*]] = stablehlo.maximum %[[VAL_12]], %[[VAL_13]] : tensor // CHECK: %[[VAL_17:.*]] = stablehlo.maximum %[[VAL_14]], %[[VAL_15]] : tensor -// CHECK: mhlo.return %[[VAL_16]], %[[VAL_17]] : tensor, tensor +// CHECK: stablehlo.return %[[VAL_16]], %[[VAL_17]] : tensor, tensor // CHECK: } // CHECK: %[[VAL_18:.*]] = mhlo.tuple %[[VAL_11]]#0, %[[VAL_11]]#1 {xla_shape = "(f32[1]{0}, s32[1]{0})"} : tuple, tensor<1xi32>> // CHECK: return %[[VAL_18]] : tuple, tensor<1xi32>> @@ -1650,10 +1650,10 @@ ENTRY %main.5 () -> f32[2,3,5] { // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_3:.*]]: tensor<200x100x300xf32>, %[[VAL_4:.*]]: tensor<10x2xi32>, %[[VAL_5:.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32> { -// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) <{indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true}> ({ +// CHECK: %[[VAL_6:.*]] = "stablehlo.scatter"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: tensor, %[[VAL_8:.*]]: tensor): // CHECK: %[[VAL_9:.*]] = stablehlo.add %[[VAL_7]], %[[VAL_8]] : tensor -// CHECK: mhlo.return %[[VAL_9]] : tensor +// CHECK: stablehlo.return %[[VAL_9]] : tensor // CHECK: }) : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> // CHECK: return %[[VAL_6]] : tensor<200x100x300xf32> // CHECK: } @@ -1681,10 +1681,10 @@ ENTRY %main.9 (Arg_0.1: f32[200,100,300], Arg_1.2: s32[10,2], Arg_2.3: f32[10,30 // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_3:.*]]: tensor<200x100x300xf32>, %[[VAL_4:.*]]: tensor<100x200x1xi32>, %[[VAL_5:.*]]: tensor<100x200x300xf32>) -> tensor<200x100x300xf32> { -// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) <{indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true}> ({ +// CHECK: %[[VAL_6:.*]] = "stablehlo.scatter"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: tensor, %[[VAL_8:.*]]: tensor): // CHECK: %[[VAL_9:.*]] = stablehlo.add %[[VAL_7]], %[[VAL_8]] : tensor -// CHECK: mhlo.return %[[VAL_9]] : tensor +// CHECK: stablehlo.return %[[VAL_9]] : tensor // CHECK: }) : (tensor<200x100x300xf32>, tensor<100x200x1xi32>, tensor<100x200x300xf32>) -> tensor<200x100x300xf32> // CHECK: return %[[VAL_6]] : tensor<200x100x300xf32> // CHECK: } @@ -1714,11 +1714,11 @@ ENTRY %main.9 (Arg_0.1: f32[200,100,300], Arg_1.2: s32[100,200,1], Arg_2.3: f32[ // CHECK: return %[[VAL_6]] : tuple, tensor> // CHECK: } // CHECK: func.func @main(%[[VAL_7:.*]]: tensor<200x100x300xf32>, %[[VAL_8:.*]]: tensor<10x2xi64>, %[[VAL_9:.*]]: tensor<10x300xf32>) -> tuple, tensor<200x100x300xf32>> { -// CHECK: %[[VAL_10:.*]]:2 = "mhlo.scatter"(%[[VAL_7]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_9]]) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false}> ({ +// CHECK: %[[VAL_10:.*]]:2 = "stablehlo.scatter"(%[[VAL_7]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_9]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ // CHECK: ^bb0(%[[VAL_11:.*]]: tensor, %[[VAL_12:.*]]: tensor, %[[VAL_13:.*]]: tensor, %[[VAL_14:.*]]: tensor): // CHECK: 
%[[VAL_15:.*]] = stablehlo.add %[[VAL_11]], %[[VAL_12]] : tensor // CHECK: %[[VAL_16:.*]] = stablehlo.add %[[VAL_13]], %[[VAL_14]] : tensor -// CHECK: mhlo.return %[[VAL_15]], %[[VAL_16]] : tensor, tensor +// CHECK: stablehlo.return %[[VAL_15]], %[[VAL_16]] : tensor, tensor // CHECK: }) : (tensor<200x100x300xf32>, tensor<200x100x300xf32>, tensor<10x2xi64>, tensor<10x300xf32>, tensor<10x300xf32>) -> (tensor<200x100x300xf32>, tensor<200x100x300xf32>) // CHECK: %[[VAL_17:.*]] = mhlo.tuple %[[VAL_18:.*]]#0, %[[VAL_18]]#1 {xla_shape = "(f32[200,100,300]{2,1,0}, f32[200,100,300]{2,1,0})"} : tuple, tensor<200x100x300xf32>> // CHECK: return %[[VAL_17]] : tuple, tensor<200x100x300xf32>> @@ -1769,7 +1769,7 @@ ENTRY %main.6 (Arg_0.1: pred[], Arg_1.2: s32[2,3], Arg_2.3: s32[2,3]) -> s32[2,3 // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func private @region_0.4(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { -// CHECK: %[[VAL_2:.*]] = mhlo.compare GE, %[[VAL_0]], %[[VAL_1]], TOTALORDER : (tensor, tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = stablehlo.compare GE, %[[VAL_0]], %[[VAL_1]], TOTALORDER : (tensor, tensor) -> tensor // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: func.func private @region_1.8(%[[VAL_3:.*]]: tensor, %[[VAL_4:.*]]: tensor) -> tensor { @@ -1780,7 +1780,7 @@ ENTRY %main.6 (Arg_0.1: pred[], Arg_1.2: s32[2,3], Arg_2.3: s32[2,3]) -> s32[2,3 // CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor // CHECK: %[[VAL_9:.*]] = "stablehlo.select_and_scatter"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[VAL_10:.*]]: tensor, %[[VAL_11:.*]]: tensor): -// CHECK: %[[VAL_12:.*]] = mhlo.compare GE, %[[VAL_10]], %[[VAL_11]], TOTALORDER : (tensor, tensor) -> tensor +// CHECK: %[[VAL_12:.*]] = stablehlo.compare GE, %[[VAL_10]], %[[VAL_11]], TOTALORDER : (tensor, tensor) -> tensor // CHECK: stablehlo.return %[[VAL_12]] : tensor // CHECK: }, { // CHECK: ^bb0(%[[VAL_13:.*]]: tensor, %[[VAL_14:.*]]: tensor): @@ -2568,8 +2568,8 @@ ENTRY %main.3 (Arg_0.1: f32[?,784]) -> f32[?,784] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func private @region_0.3(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { // CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_3:.*]] = mhlo.compare NE, %[[VAL_0]], %[[VAL_2]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = mhlo.compare NE, %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = stablehlo.compare NE, %[[VAL_0]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = stablehlo.compare NE, %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = stablehlo.or %[[VAL_3]], %[[VAL_4]] : tensor // CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<1.000000e+00> : tensor // CHECK: %[[VAL_7:.*]] = stablehlo.select %[[VAL_5]], %[[VAL_6]], %[[VAL_2]] : tensor, tensor @@ -2577,17 +2577,17 @@ ENTRY %main.3 (Arg_0.1: f32[?,784]) -> f32[?,784] { // CHECK: } // CHECK: func.func @main(%[[VAL_8:.*]]: tensor<2x2xf32>) -> tuple> { // CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_10:.*]] = mhlo.reduce(%[[VAL_8]] 
init: %[[VAL_9]]) across dimensions = [0, 1] : (tensor<2x2xf32>, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = stablehlo.reduce(%[[VAL_8]] init: %[[VAL_9]]) across dimensions = [0, 1] : (tensor<2x2xf32>, tensor) -> tensor // CHECK: reducer(%[[VAL_11:.*]]: tensor, %[[VAL_12:.*]]: tensor) { // CHECK: %[[VAL_13:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_14:.*]] = mhlo.compare NE, %[[VAL_11]], %[[VAL_13]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_15:.*]] = mhlo.compare NE, %[[VAL_12]], %[[VAL_13]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_14:.*]] = stablehlo.compare NE, %[[VAL_11]], %[[VAL_13]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_15:.*]] = stablehlo.compare NE, %[[VAL_12]], %[[VAL_13]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_16:.*]] = stablehlo.or %[[VAL_14]], %[[VAL_15]] : tensor // CHECK: %[[VAL_17:.*]] = stablehlo.constant dense<1.000000e+00> : tensor // CHECK: %[[VAL_18:.*]] = stablehlo.select %[[VAL_16]], %[[VAL_17]], %[[VAL_13]] : tensor, tensor -// CHECK: mhlo.return %[[VAL_18]] : tensor +// CHECK: stablehlo.return %[[VAL_18]] : tensor // CHECK: } -// CHECK: %[[VAL_19:.*]] = mhlo.compare NE, %[[VAL_10]], %[[VAL_9]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_19:.*]] = stablehlo.compare NE, %[[VAL_10]], %[[VAL_9]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_20:.*]] = mhlo.tuple %[[VAL_19]] {xla_shape = "(pred[])"} : tuple> // CHECK: return %[[VAL_20]] : tuple> // CHECK: } @@ -2705,9 +2705,9 @@ ENTRY %main.9 (Arg_0.1: f32[2,17,31,7], Arg_1.2: f32[]) -> f32[2,16,30,7] { // CHECK: return %[[VAL_1]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_2:.*]]: tensor<3xi32>, %[[VAL_3:.*]]: tensor<1x1xi32>, %[[VAL_4:.*]]: tensor<1xi32>) -> tensor<3xi32> { -// CHECK: %[[VAL_5:.*]] = "mhlo.scatter"(%[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false}> ({ +// CHECK: %[[VAL_5:.*]] = "stablehlo.scatter"(%[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ // CHECK: ^bb0(%[[VAL_6:.*]]: tensor, %[[VAL_7:.*]]: tensor): -// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: stablehlo.return %[[VAL_7]] : tensor // CHECK: }) : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK: return %[[VAL_5]] : tensor<3xi32> // CHECK: } @@ -2731,7 +2731,7 @@ ENTRY %main.8 (Arg_0.1: s32[3], Arg_1.2: s32[1,1], Arg_2.3: s32[1]) -> s32[3] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func private @region_0.4(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tensor { // CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_3:.*]] = mhlo.compare GE, %[[VAL_0]], %[[VAL_2]], TOTALORDER : (tensor, tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = stablehlo.compare GE, %[[VAL_0]], %[[VAL_2]], TOTALORDER : (tensor, tensor) -> tensor // CHECK: return %[[VAL_3]] : tensor // CHECK: } // CHECK: func.func private @region_1.9(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor) -> tensor { @@ -2741,7 +2741,7 @@ ENTRY %main.8 (Arg_0.1: s32[3], Arg_1.2: s32[1,1], Arg_2.3: s32[1]) -> s32[3] { // CHECK: %[[VAL_9:.*]] = "stablehlo.select_and_scatter"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[VAL_10:.*]]: tensor, 
%[[VAL_11:.*]]: tensor): // CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_13:.*]] = mhlo.compare GE, %[[VAL_10]], %[[VAL_12]], TOTALORDER : (tensor, tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = stablehlo.compare GE, %[[VAL_10]], %[[VAL_12]], TOTALORDER : (tensor, tensor) -> tensor // CHECK: stablehlo.return %[[VAL_13]] : tensor // CHECK: }, { // CHECK: ^bb0(%[[VAL_14:.*]]: tensor, %[[VAL_15:.*]]: tensor): From 87d8e9e034f259f49752208d63dc09caa097e36d Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Thu, 10 Apr 2025 09:54:39 -0700 Subject: [PATCH 0506/1324] [XLA:Upkeep] Resolve the following technical debt issue: Todo(resolved) PiperOrigin-RevId: 746070354 --- third_party/xla/xla/stream_executor/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index c346625a666ea3..e99aeff6c0484b 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -921,7 +921,6 @@ alias( "gpu", "rocm-only", ] + if_google([ - # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. "manual", ]), ) From 24023f87dbfb7d6270bf3f779590408d4d09b89e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 10:06:45 -0700 Subject: [PATCH 0507/1324] Moves some existing XLA Auto Sharding utilities for IOPDDL into third_party/ PiperOrigin-RevId: 746075132 --- .../xla/hlo/experimental/auto_sharding/BUILD | 14 + .../auto_sharding/auto_sharding_iopddl.cc | 282 ++++++++++++++++++ .../auto_sharding/auto_sharding_iopddl.h | 44 +++ 3 files changed, 340 insertions(+) create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.cc create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.h diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 1615747d70a1bd..68f003c22b3232 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -451,6 +451,7 @@ cc_library( "iopddl.h", "solver.h", ], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", @@ -473,3 +474,16 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "auto_sharding_iopddl", + srcs = ["auto_sharding_iopddl.cc"], + hdrs = ["auto_sharding_iopddl.h"], + compatible_with = get_compatible_with_libtpu_portable(), + deps = [ + ":iopddl_lib", + "//xla/hlo/experimental/auto_sharding:auto_sharding_proto_cc", + "//xla/hlo/experimental/auto_sharding:auto_sharding_strategy", + "@com_google_absl//absl/log:check", + ], +) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.cc new file mode 100644 index 00000000000000..a2cd086c11f0bc --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.cc @@ -0,0 +1,282 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "xla/hlo/experimental/auto_sharding/iopddl.h" + +namespace xla { +namespace spmd { + +iopddl::Cost ConvertCost(const double cost) { + CHECK_GE(cost, 0); // Contest problems shouldn't include any negative costs. + if (cost >= kInfinityInt) { + return kInfinityInt; + } + return static_cast(cost); +} + +iopddl::Problem ConvertToProblem(const AutoShardingSolverRequest& request) { + CHECK(request.live().empty()); // Contest files don't support live matrices. + CHECK(request.node_groups().empty()); // Contest files don't support groups. + iopddl::Problem problem = {.name = request.request_name()}; + for (int64_t node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + CHECK_LT(node_idx, request.node_intervals_size()); + const auto& interval = request.node_intervals(node_idx); + iopddl::Interval node_interval = {0, 0}; + if (interval.first() <= interval.second()) { + node_interval = {interval.first(), interval.second() + 1}; + } + problem.nodes.push_back({node_interval}); + CHECK_LT(node_idx, request.s_len_size()); + CHECK_LT(node_idx, request.computation_costs_size()); + CHECK_LT(node_idx, request.communication_costs_size()); + CHECK_LT(node_idx, request.memory_costs_size()); + for (int64_t j = 0; j < request.s_len(node_idx); ++j) { + CHECK_LT(j, request.computation_costs(node_idx).costs_size()); + CHECK_LT(j, request.communication_costs(node_idx).costs_size()); + CHECK_LT(j, request.memory_costs(node_idx).costs_size()); + const double node_cost = request.computation_costs(node_idx).costs(j) + + request.communication_costs(node_idx).costs(j); + const iopddl::Cost cost = ConvertCost(node_cost); + const iopddl::Usage usage = + ConvertCost(request.memory_costs(node_idx).costs(j)); + problem.nodes.back().strategies.push_back({cost, usage}); + } + } + // The first kind of edges come from request.edges + for (int64_t edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) { + const auto& edge = request.edges(edge_idx); + CHECK_LT(edge.first(), request.s_len_size()); + CHECK_LT(edge.second(), request.s_len_size()); + CHECK_LT(edge_idx, request.resharding_costs_size()); + problem.edges.push_back({{edge.first(), edge.second()}}); + for (int64_t i = 0; i < request.s_len(edge.first()); ++i) { + for (int64_t j = 0; j < request.s_len(edge.second()); ++j) { + const int64_t k = i * request.s_len(edge.second()) + j; + CHECK_LT(k, request.resharding_costs(edge_idx).costs_size()); + const iopddl::Cost cost = + ConvertCost(request.resharding_costs(edge_idx).costs(k)); + problem.edges.back().strategies.push_back({cost}); + } + } + } + // The second kind of edges come from request.aliases + for (int64_t alias_idx = 0; alias_idx < request.aliases_size(); ++alias_idx) { + const auto& alias = request.aliases(alias_idx); + 
problem.edges.push_back({{alias.first(), alias.second()}}); + CHECK_LT(alias.first(), request.s_len_size()); + CHECK_LT(alias.second(), request.s_len_size()); + CHECK_LT(alias_idx, request.value_costs_size()); + for (int64_t i = 0; i < request.s_len(alias.first()); ++i) { + for (int64_t j = 0; j < request.s_len(alias.second()); ++j) { + const int64_t k = i * request.s_len(alias.second()) + j; + CHECK_LT(k, request.value_costs(alias_idx).costs_size()); + const iopddl::Cost cost = + ConvertCost(request.value_costs(alias_idx).costs(k) * kInfinityInt); + problem.edges.back().strategies.push_back({cost}); + } + } + } + // The third kind of edges come from request.s_follow + for (int64_t node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + CHECK_LT(node_idx, request.s_follow_size()); + if (request.s_follow(node_idx) < 0) { + continue; + } + problem.edges.push_back({{request.s_follow(node_idx), node_idx}}); + CHECK_LT(node_idx, request.s_len_size()); + for (int64_t i = 0; i < request.s_len(node_idx); ++i) { + for (int64_t j = 0; j < request.s_len(node_idx); ++j) { + const iopddl::Cost cost = (i == j) ? 0 : kInfinityInt; + problem.edges.back().strategies.push_back({cost}); + } + } + } + if (request.memory_budget() > 0) { + problem.usage_limit = request.memory_budget(); + } + return problem; +} + +static bool IsEdgeFollower(const iopddl::Problem& problem, + const iopddl::Edge& edge) { + int strategies0 = problem.nodes[edge.nodes[0]].strategies.size(); + int strategies1 = problem.nodes[edge.nodes[1]].strategies.size(); + if (strategies0 != strategies1) { + return false; + } + for (iopddl::StrategyIdx idx0 = 0; idx0 < strategies0; ++idx0) { + for (iopddl::StrategyIdx idx1 = 0; idx1 < strategies1; ++idx1) { + const auto strategy = edge.strategies[idx0 * strategies1 + idx1]; + if (idx0 == idx1 && strategy.cost != 0) { + return false; + } + if (idx0 != idx1 && strategy.cost != kInfinityInt) { + return false; + } + } + } + return true; +} + +static bool IsEdgeAlias(const iopddl::Edge& edge) { + for (const iopddl::Strategy& strategy : edge.strategies) { + if (strategy.cost == kInfinityInt) { + return true; + } + } + return false; +} + +AutoShardingSolverRequest ConvertToSolverRequest( + const iopddl::Problem& problem) { + AutoShardingSolverRequest request; + request.set_request_name(problem.name); + request.set_num_nodes(problem.nodes.size()); + request.set_memory_budget(*problem.usage_limit); + for (iopddl::NodeIdx node_idx = 0; node_idx < problem.nodes.size(); + ++node_idx) { + const iopddl::Node& node = problem.nodes[node_idx]; + request.add_s_len(node.strategies.size()); + request.add_s_follow(-1); + request.add_communication_costs(); + request.add_computation_costs(); + request.add_memory_costs(); + for (const iopddl::Strategy& strategy : node.strategies) { + double strategy_cost = (strategy.cost == kInfinityInt) + ? kInfinityCost + : static_cast(strategy.cost); + request.mutable_computation_costs()->rbegin()->add_costs( + static_cast(strategy_cost)); + request.mutable_communication_costs()->rbegin()->add_costs(0.0); + request.mutable_memory_costs()->rbegin()->add_costs( + static_cast(strategy.usage)); + } + request.add_node_intervals(); + bool empty_interval = (node.interval.first == node.interval.second); + request.mutable_node_intervals()->rbegin()->set_first( + empty_interval ? 100 : node.interval.first); + request.mutable_node_intervals()->rbegin()->set_second( + empty_interval ? 
-1 : node.interval.second - 1); + } + for (iopddl::EdgeIdx edge_idx = 0; edge_idx < problem.edges.size(); + ++edge_idx) { + const iopddl::Edge& edge = problem.edges[edge_idx]; + if (IsEdgeFollower(problem, edge)) { + request.mutable_s_follow()->Set(edge.nodes[1], edge.nodes[0]); + continue; + } + if (IsEdgeAlias(edge)) { + auto* alias = request.add_aliases(); + alias->set_first(edge.nodes[0]); + alias->set_second(edge.nodes[1]); + request.add_value_costs(); + for (const iopddl::Strategy& strategy : edge.strategies) { + request.mutable_value_costs()->rbegin()->add_costs( + strategy.cost == kInfinityInt ? 1.0 : 0.0); + } + continue; + } + auto* edge_proto = request.add_edges(); + edge_proto->set_first(edge.nodes[0]); + edge_proto->set_second(edge.nodes[1]); + request.add_resharding_costs(); + for (const iopddl::Strategy& strategy : edge.strategies) { + request.mutable_resharding_costs()->rbegin()->add_costs( + static_cast(strategy.cost)); + } + } + return request; +} + +void RandomizeCosts(iopddl::Problem& problem) { + unsigned int seed = 2025; + auto get_multiplier = [&]() { // Returns a value between 1/16 and 16.0 + return std::pow(2.0, (rand_r(&seed) % 9) - 4); + }; + auto randomize = [&](iopddl::Cost& cost, const double multiplier) { + if (cost != kInfinityInt) { + cost = static_cast(static_cast(cost) * multiplier); + } + }; + for (iopddl::Node& node : problem.nodes) { + const double multiplier = get_multiplier(); + for (iopddl::Strategy& strategy : node.strategies) { + randomize(strategy.cost, multiplier); + } + } + for (iopddl::Edge& edge : problem.edges) { + const double multiplier = get_multiplier(); + for (iopddl::Strategy& strategy : edge.strategies) { + randomize(strategy.cost, multiplier); + } + } +} + +// TODO(moffitt): Re-implement this using an XLA-friendly library (eg, jsoncpp). 
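+// Until that dependency is available, the nlohmann::json-based implementation
+// below is left commented out and the function simply returns "".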
+std::string ConvertToJsonString(const iopddl::Problem& problem) { +/* + nlohmann::json json; + json["problem"]["name"] = problem.name; + json["problem"]["nodes"]["intervals"] = nlohmann::json::array(); + json["problem"]["nodes"]["costs"] = nlohmann::json::array(); + json["problem"]["nodes"]["usages"] = nlohmann::json::array(); + for (const iopddl::Node& node : problem.nodes) { + auto intervals = nlohmann::json::array(); + auto costs = nlohmann::json::array(); + auto usages = nlohmann::json::array(); + intervals.push_back(node.interval.first); + intervals.push_back(node.interval.second); + for (const iopddl::Strategy& strategy : node.strategies) { + costs.push_back(strategy.cost); + usages.push_back(strategy.usage); + } + json["problem"]["nodes"]["intervals"].push_back(intervals); + json["problem"]["nodes"]["costs"].push_back(costs); + json["problem"]["nodes"]["usages"].push_back(usages); + } + json["problem"]["edges"]["nodes"] = nlohmann::json::array(); + json["problem"]["edges"]["costs"] = nlohmann::json::array(); + for (const iopddl::Edge& edge : problem.edges) { + auto nodes = nlohmann::json::array(); + auto costs = nlohmann::json::array(); + for (const iopddl::NodeIdx node_idx : edge.nodes) { + nodes.push_back(node_idx); + } + for (const iopddl::Strategy& strategy : edge.strategies) { + costs.push_back(strategy.cost); + } + json["problem"]["edges"]["nodes"].push_back(nodes); + json["problem"]["edges"]["costs"].push_back(costs); + } + if (problem.usage_limit.has_value()) { + json["problem"]["usage_limit"] = *problem.usage_limit; + } + return json.dump(); +*/ + return ""; +} + +} // namespace spmd +} // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.h new file mode 100644 index 00000000000000..77e346be87df90 --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.h @@ -0,0 +1,44 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_IOPDDL_H_
+#define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_IOPDDL_H_
+
+#include <cstdint>
+#include <string>
+
+#include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h"
+#include "xla/hlo/experimental/auto_sharding/iopddl.h"
+
+namespace xla {
+namespace spmd {
+
+constexpr int64_t kInfinityInt = 1e18;
+
+iopddl::Cost ConvertCost(double cost);
+
+iopddl::Problem ConvertToProblem(const AutoShardingSolverRequest& request);
+
+AutoShardingSolverRequest ConvertToSolverRequest(
+    const iopddl::Problem& problem);
+
+void RandomizeCosts(iopddl::Problem& problem);
+
+std::string ConvertToJsonString(const iopddl::Problem& problem);
+
+}  // namespace spmd
+}  // namespace xla
+
+#endif  // XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_IOPDDL_H_

From 09571891b434d613c9385a3f8ccfa886470e40e7 Mon Sep 17 00:00:00 2001
From: Tongfei Guo
Date: Thu, 10 Apr 2025 10:27:28 -0700
Subject: [PATCH 0508/1324] [XLA] Support kTranspose op in formatting step.

PiperOrigin-RevId: 746083396
---
 .../xla/xla/hlo/utils/hlo_sharding_util.cc    | 19 +++++++++++++++++++
 .../xla/xla/hlo/utils/hlo_sharding_util.h     |  1 +
 third_party/xla/xla/service/BUILD             |  1 +
 .../batched_gather_scatter_normalizer.h       |  2 +-
 .../xla/xla/service/hlo_creation_utils.cc     | 13 +++++++++++++
 .../xla/xla/service/hlo_creation_utils.h      |  8 ++++++++
 6 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
index ca1317757dd35f..c2faedaa8239f2 100644
--- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
+++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
@@ -101,6 +101,12 @@ HloInstruction* FormatShape(HloInstruction* data,
           step.output_shape, data, padding, padding_config));
       break;
     }
+    case HloOpcode::kTranspose: {
+      CHECK(step.xpose_permutation.has_value());
+      data = computation->AddInstruction(HloInstruction::CreateTranspose(
+          step.output_shape, data, *step.xpose_permutation));
+      break;
+    }
     default:
       LOG(FATAL) << "Unsupported formatting step";
   }
@@ -135,6 +141,19 @@ HloInstruction* ReverseFormatShape(
           previous_shape.dimensions(), strides));
       break;
     }
+    case HloOpcode::kTranspose: {
+      CHECK(step.xpose_permutation.has_value());
+      std::vector<int64_t> reverse_permutation;
+      reverse_permutation.reserve(step.xpose_permutation->size());
+      for (int64_t i = 0; i < step.xpose_permutation->size(); ++i) {
+        reverse_permutation.push_back(
+            absl::c_find(*step.xpose_permutation, i) -
+            step.xpose_permutation->begin());
+      }
+      data = computation->AddInstruction(HloInstruction::CreateTranspose(
+          previous_shape, data, reverse_permutation));
+      break;
+    }
     default:
       LOG(FATAL) << "Unsupported formatting step";
   }
diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h
index 49155a0705da88..427377b6025c27 100644
--- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h
+++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h
@@ -52,6 +52,7 @@ struct FormattingStep {
   std::optional<Shape> reverse_input_shape;
   HloOpcode formatting_opcode;
   HloInstruction* padding_value;
+  std::optional<std::vector<int64_t>> xpose_permutation;
 };
 
 struct GatherScatterDims {
diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index afd16b7671a801..1992a34ce73a8c 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -2095,6 +2095,7 @@ cc_library(
"//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:comparators", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer.h b/third_party/xla/xla/service/batched_gather_scatter_normalizer.h index 50c1d43def0293..e0576c4478bce2 100644 --- a/third_party/xla/xla/service/batched_gather_scatter_normalizer.h +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer.h @@ -27,7 +27,7 @@ namespace xla { class BatchedGatherScatterNormalizer : public OpExpanderPass { public: absl::string_view name() const override { - return "gather_scatter_normalizer"; + return "batched_gather_scatter_normalizer"; } protected: diff --git a/third_party/xla/xla/service/hlo_creation_utils.cc b/third_party/xla/xla/service/hlo_creation_utils.cc index dc3965a023112c..b815b1ffc79627 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.cc +++ b/third_party/xla/xla/service/hlo_creation_utils.cc @@ -46,6 +46,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -907,4 +908,16 @@ HloInstruction* ExpandDegenerateReshape(HloInstruction* inst) { return nullptr; } +absl::StatusOr MakeWithinBounds(HloInstruction* inst, + HloInstruction* lower_bound, + HloInstruction* upper_bound) { + TF_ASSIGN_OR_RETURN( + HloInstruction * le, + MakeCompareHlo(Comparison::Direction::kLe, lower_bound, inst)); + TF_ASSIGN_OR_RETURN( + HloInstruction * gt, + MakeCompareHlo(Comparison::Direction::kGt, upper_bound, inst)); + return MakeBinaryHlo(HloOpcode::kAnd, le, gt); +} + } // namespace xla diff --git a/third_party/xla/xla/service/hlo_creation_utils.h b/third_party/xla/xla/service/hlo_creation_utils.h index 6e9aa6acef2d8f..686ff34765f10f 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.h +++ b/third_party/xla/xla/service/hlo_creation_utils.h @@ -21,11 +21,13 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/literal_util.h" +#include "xla/primitive_util.h" #include "xla/xla_data.pb.h" namespace xla { @@ -430,6 +432,12 @@ std::unique_ptr MakeScalarConstantWithShape(const Shape& shape, shape.element_type()); } +// Create instructions that check if the given instruction is within the given +// bounds (lower_bound <= inst < upper_bound). 
+absl::StatusOr MakeWithinBounds(HloInstruction* inst, + HloInstruction* lower_bound, + HloInstruction* upper_bound); + } // namespace xla #endif // XLA_SERVICE_HLO_CREATION_UTILS_H_ From ba5031ea65e5f1ecc356e1c587a8253105d0d212 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 10 Apr 2025 10:41:04 -0700 Subject: [PATCH 0509/1324] [xla:gpu] CommandBuffer: remove CUDA <= 12.4 workarounds PiperOrigin-RevId: 746089128 --- .../gpu/runtime/command_buffer_cmd.cc | 49 +------------------ .../backends/gpu/runtime/command_buffer_cmd.h | 15 ------ 2 files changed, 2 insertions(+), 62 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 634b42d0506590..ce19ee9ac5bd73 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -70,8 +70,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/cuda/cuda_compute_capability.h" -#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" @@ -502,27 +500,6 @@ CommandBufferCmd::BufferUseVector ComputationIdCmd::buffers() { return {{dest_, MemoryAccess::kWrite}}; } -absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, - StateManager& state) { - auto cuda_cc = std::get_if( - ¶ms.executor->GetDeviceDescription().gpu_compute_capability()); - if (cuda_cc != nullptr) { - { - absl::MutexLock lock(&mutex_); - if (memset_kernels_.contains(params.executor)) return absl::OkStatus(); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - CreateKernel("memset32", 3, kMemset32Kernel, - /*cubin_data=*/{}, params.executor, - /*shared_mem_bytes=*/0)); - - absl::MutexLock lock(&mutex_); - memset_kernels_.emplace(params.executor, std::move(kernel)); - } - return absl::OkStatus(); -} - absl::StatusOr ComputationIdCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, @@ -545,30 +522,8 @@ absl::StatusOr ComputationIdCmd::Record( << "; value=" << value; VLOG(5) << " Id: " << dest_ << " (" << dst.opaque() << ")"; - auto cuda_cc = std::get_if( - &execute_params.stream->parent() - ->GetDeviceDescription() - .gpu_compute_capability()); - - if (cuda_cc != nullptr) { - se::Kernel* memset_kernel = [&] { - absl::MutexLock lock(&mutex_); - return memset_kernels_[execute_params.stream->parent()].get(); - }(); - - if (memset_kernel == nullptr) { - return absl::InternalError( - "Memset kernel not loaded on a command buffer executor"); - } - - auto args = se::PackKernelArgs(/*shmem_bytes=*/0, int64_t{1}, value, dst); - return RecordedCommands::Create(command_buffer->Launch( - se::ThreadDim(1), se::BlockDim(1), *memset_kernel, *args, {})); - - } else { - return RecordedCommands::Create( - command_buffer->Memset(&dst, value, /*num_elements=*/1, {})); - } + return RecordedCommands::Create( + command_buffer->Memset(&dst, value, /*num_elements=*/1, {})); } //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index b6d53ab7ffc339..db17ea1a639e9d 100644 --- 
a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -453,9 +453,6 @@ class ComputationIdCmd : public CommandBufferCmd { ComputationIdCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice dest, Kind kind); - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, @@ -466,18 +463,6 @@ class ComputationIdCmd : public CommandBufferCmd { private: BufferAllocation::Slice dest_; Kind kind_; - - // Command sequence can be recorded concurrently for multiple command buffers - // on different stream executors and we need to synchronize mutable state. - absl::Mutex mutex_; - - // TODO(ezhulenev): This is a workaround for CUDA graphs + conditional nodes - // bug that will be fixed in CUDA 12.4.1 release: currently it's impossible to - // update a memset node inside a conditional graph. Instead of using memset - // node we replace it with a kernel launch node of CUDA kernels doing 1D - // memset. This should be removed when bug is fixed in CUDA. - absl::flat_hash_map> - memset_kernels_ ABSL_GUARDED_BY(mutex_); }; //===----------------------------------------------------------------------===// From 113cc820c3882a265a69721350f5cdf472350d4d Mon Sep 17 00:00:00 2001 From: Renjie Wu Date: Thu, 10 Apr 2025 11:09:15 -0700 Subject: [PATCH 0510/1324] Bump up cast/quantize/dequantize reference kernels. PiperOrigin-RevId: 746100141 --- tensorflow/lite/kernels/register_ref.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 479f3d5c996a8f..a6cdc7b6cc80f4 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -375,10 +375,10 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { /* max_version = */ 2); AddBuiltin(BuiltinOperator_CAST, Register_CAST(), /* min_version = */ 1, - /* max_version = */ 4); + /* max_version = */ 7); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE_REF(), /* min_version = */ 1, - /* max_version = */ 4); + /* max_version = */ 6); AddBuiltin(BuiltinOperator_PRELU, Register_PRELU_REF()); AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM_REF(), /* min_version = */ 1, @@ -500,7 +500,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG()); AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE_REF(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG()); AddBuiltin(BuiltinOperator_IF, Register_IF()); AddBuiltin(BuiltinOperator_WHILE, Register_WHILE()); From ded0458a49c25b722d71d27629a09079fb651197 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 11:11:05 -0700 Subject: [PATCH 0511/1324] Enable subgraph reshaping by default in XNNPACK delegate This should improve performance (and could subtly change results) for some models that previously did not delegate to XNNPACK because this flag was disabled. 
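For reference, a minimal sketch of how a client previously opted in to
subgraph reshaping (editor's illustration; the option struct, factory
functions, and flag follow the existing XNNPACK delegate C API, and after
this change setting the flag becomes a no-op):

  TfLiteXNNPackDelegateOptions opts = TfLiteXNNPackDelegateOptionsDefault();
  opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING;
  TfLiteDelegate* delegate = TfLiteXNNPackDelegateCreate(&opts);
  // ... interpreter->ModifyGraphWithDelegate(delegate) ...
  TfLiteXNNPackDelegateDelete(delegate);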
PiperOrigin-RevId: 746100892 --- tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 0d904a9c44cf35..26cfad1ce7bf52 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -703,12 +703,7 @@ class Delegate { } bool enable_subgraph_reshaping() const { -#ifdef XNNPACK_DELEGATE_ENABLE_SUBGRAPH_RESHAPING return true; -#else - return (options_.flags & - TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING) != 0; -#endif } bool enable_slinky() const { From 9686f9a2dd64d5468658b6049d50afda2358c2a0 Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Thu, 10 Apr 2025 11:11:41 -0700 Subject: [PATCH 0512/1324] Replace VLOG(0) with VLOG(1) VLOG(0) is enabled by default and is too noisy. PiperOrigin-RevId: 746101128 --- third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc index 0a0577bea80455..4f8fe6260401c6 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc @@ -2699,11 +2699,11 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( // too far, not for correctness. Placing it before the executable launch // allows the inputs for the next executable to be fetched even if the // launch is delayed. - VLOG(0) << "Going to get compute reservation for " << name() << ": " + VLOG(1) << "Going to get compute reservation for " << name() << ": " << options.launch_id << "; replica: " << replica; auto compute_reservation = std::make_unique( device->max_inflight_computations_semaphore().ScopedAcquire(1)); - VLOG(0) << "Got compute reservation for " << name() << ": " + VLOG(1) << "Got compute reservation for " << name() << ": " << options.launch_id << "; replica: " << replica; auto ffi_context = options.context != nullptr ? 
&options.context->ffi_context() : nullptr; @@ -2730,7 +2730,7 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( send_device_memory(std::move(send_device_memory)), recv_device_memory(std::move(recv_device_memory)), client = client_](std::vector execution_inputs) mutable { - VLOG(0) << "execute_fn for " << executable_name << ": " << launch_id + VLOG(1) << "execute_fn for " << executable_name << ": " << launch_id << "; replica: " << replica; tsl::profiler::TraceMe traceme("execute_fn"); auto set_error = [&](absl::Status status) { @@ -2796,7 +2796,7 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( absl::StatusOr result_buffer_or_status = gpu_executable->RunAsync(std::move(execution_inputs), run_options); - VLOG(0) << "Replica " << replica << " partition " << partition + VLOG(1) << "Replica " << replica << " partition " << partition << " completed; ok=" << result_buffer_or_status.ok(); if (!result_buffer_or_status.ok()) { From 82285129fc6ee670f7ad2d7797763e8fba4d7659 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 10 Apr 2025 11:23:46 -0700 Subject: [PATCH 0513/1324] [xla] Use std::weak_ptr to keep state for in-progress rendezvous PiperOrigin-RevId: 746105925 --- third_party/xla/xla/service/rendezvous.cc | 5 +- third_party/xla/xla/service/rendezvous.h | 124 +++++++++--------- .../xla/xla/service/rendezvous_test.cc | 47 +++++-- 3 files changed, 102 insertions(+), 74 deletions(-) diff --git a/third_party/xla/xla/service/rendezvous.cc b/third_party/xla/xla/service/rendezvous.cc index a7b189906549d8..f69bdc5b0aacc0 100644 --- a/third_party/xla/xla/service/rendezvous.cc +++ b/third_party/xla/xla/service/rendezvous.cc @@ -50,13 +50,12 @@ static bool WaitForReadyWithTimeout(RendezvousStateSynchronization& state, }); bool timed_out = state.cv.WaitWithTimeout(&state.mutex, timeout); - bool ready = state.ready; // We are done and ready. - if (ready) return true; + if (state.ready) return true; // We are done with waiting because the timeout is exceeded. - if (timed_out && !ready) { + if (timed_out && !state.ready) { return false; } diff --git a/third_party/xla/xla/service/rendezvous.h b/third_party/xla/xla/service/rendezvous.h index 86488963e16744..3cb424d980bd9b 100644 --- a/third_party/xla/xla/service/rendezvous.h +++ b/third_party/xla/xla/service/rendezvous.h @@ -214,15 +214,19 @@ struct RendezvousState : public RendezvousStateSynchronization { // Rendezvous state ownership: // // (1) When rendezvous participant initiates a rendezvous with a particular key -// we create a new state for it, keep it in a map for tracking and return a -// shared pointer to the caller. +// we create a new state for it, keep it in a map as weak pointer for +// tracking and return a shared pointer to the caller. // // (2) When rendezvous participant joins in-progress rendezvous it gets back // a shared pointer that is copied from a tracking map. // -// (3) When the last rendezvous participant computes the result it completes the -// rendezvous and removes a shared pointer to a state. Remaining shared -// pointers destructed when all participants are notified. +// (3) When rendezvous completes, the thread that completes it removes a state +// from a map, so that the next rendezvous with the same key can start +// immediately and create a new state. +// +// (4) If rendezvous failed to complete, the weak pointer will expire when all +// participants left the rendezvous, and will be lazily garbage collected +// in the next call to `Join`. 
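+//
+// A minimal sketch of the tracking pattern (editor's illustration, mirroring
+// `Join` below; `State` is the per-rendezvous state type):
+//
+//   std::weak_ptr<State>& in_progress = state_[key];
+//   if (std::shared_ptr<State> joined = in_progress.lock()) {
+//     return joined;                                 // (2) join in progress
+//   }
+//   std::shared_ptr<State> start = std::make_shared<State>(num_threads);
+//   return (in_progress = start, start);             // (1) start a new round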
// // This process guarantees that all completed rendezvous are removed from a map // and a map has records only for rendezvous in progress. @@ -233,68 +237,59 @@ class RendezvousMap { std::shared_ptr Join(const K& key, size_t num_threads) { absl::MutexLock lock(&mutex_); - std::shared_ptr& state = state_[key]; - - // Join an in-progress rendezvous. - if (state) return state; - - // Join a newly created rendezvous. - return state = std::make_shared(num_threads); - } - - template - void Complete(const K& key, Result&& result) { - std::shared_ptr state = [&] { - absl::MutexLock lock(&mutex_); - // Extract state from the map so we can immediately start a new round of - // rendezvous with the same key. A state for previous rendezvous will be - // destructed with the last copy of a shared pointer. - std::shared_ptr state = state_.extract(key).mapped(); + // Erase expired rendezvous from the map. + absl::erase_if(state_, [](const auto& e) { return e.second.expired(); }); - // Check that we have have exactly the number of participants we expected: - // +1 reference for all participants and a +1 reference we extracted. - CHECK_EQ(state.use_count(), 1 + state->values.size()); // NOLINT + std::weak_ptr& in_progress = state_[key]; - return state; - }(); - - // We notify awaiting participants without holding a rendezvous map lock, as - // the rendezvous callback might be an expensive operation and might block - // the progress of concurrent rendezvous for other keys. - - // Publish rendezvous result to all participants. - if constexpr (IsStatusOrResult::value) { - if (ABSL_PREDICT_TRUE(result.ok())) { - state->result = std::make_shared(*std::forward(result)); - } else { - state->result = result.status(); - } - } else { - state->result = std::make_shared(std::forward(result)); + // Try to join an in-progress rendezvous for a given key. + if (std::shared_ptr joined = in_progress.lock()) { + return joined; } - // Notify awaiting participants that result is ready. - absl::MutexLock lock(&state->mutex); - state->ready = true; - state->cv.SignalAll(); + // Start a new rendezvous for a given key. + std::shared_ptr start = std::make_shared(num_threads); + return (in_progress = start, start); + } + + void Complete(const K& key) { + absl::MutexLock lock(&mutex_); + state_.erase(key); } private: absl::Mutex mutex_; - absl::flat_hash_map> state_ ABSL_GUARDED_BY(mutex_); + absl::flat_hash_map> state_ ABSL_GUARDED_BY(mutex_); }; void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, absl::string_view name, absl::Duration warn_stuck_timeout, absl::Duration terminate_timeout); + } // namespace internal //===----------------------------------------------------------------------===// // Rendezvous implemenetation. //===----------------------------------------------------------------------===// +template +absl::StatusOr> InvokeRendezvous( + Fn fn, absl::Span values) { + auto result = fn(values); + + if constexpr (internal::IsStatusOrResult::value) { + if (ABSL_PREDICT_TRUE(result.ok())) { + return std::make_shared(*std::move(result)); + } else { + return result.status(); + } + } else { + return std::make_shared(std::move(result)); + } +} + template absl::StatusOr> Rendezvous( absl::string_view name, const K& key, const V& value, size_t num_threads, @@ -307,16 +302,7 @@ absl::StatusOr> Rendezvous( // Fast-path (DO NOT REMOVE: the logic below doesn't work for single thread). 
if (num_threads == 1) {
     const V* ptr = &value;
-    auto result = fn(absl::MakeSpan(&ptr, 1));
-
-    if constexpr (internal::IsStatusOrResult::value) {
-      if (ABSL_PREDICT_TRUE(result.ok())) {
-        return std::make_shared(*std::move(result));
-      }
-      return result.status();
-    } else {
-      return std::make_shared(std::move(result));
-    }
+    return InvokeRendezvous(std::move(fn), absl::MakeSpan(&ptr, 1));
   }

   using State = internal::RendezvousState;
@@ -357,11 +343,29 @@ absl::StatusOr> Rendezvous(
   // Last thread to arrive executes the function and completes rendezvous by
   // making result available to all participants. All other participants will
   // be notified via `state->ready` flag when result is ready, and we rely on
-  // the store to a flag to create a memory barrier that makes access to
-  // `state->result` safe without any extra synchronization.
-  tsl::profiler::TraceMe trace("ExecuteRendezvousCallback");
+  // the mutex to create a memory barrier that makes access to `state->result`
+  // safe without any extra synchronization.
+  tsl::profiler::TraceMe trace("InvokeRendezvous");
   absl::Span values(state->values.data(), num_threads);
-  rendezvous.Complete(key, fn(values));
+
+  // Check that we have exactly the number of participants we expect.
+  CHECK_EQ(state.use_count(), num_threads);  // NOLINT
+
+  // Publish rendezvous result to all participants.
+  state->result = InvokeRendezvous(std::move(fn), values);
+
+  // Switch `ready` flag to signal all participants that result is ready.
+  {
+    absl::MutexLock lock(&state->mutex);
+    state->ready = true;
+  }
+
+  // Notify awaiting participants that result is ready.
+  state->cv.SignalAll();
+
+  // Mark rendezvous as completed, so that we can immediately start a new
+  // rendezvous with the same key.
+  rendezvous.Complete(key);
 }

 return state->result;
diff --git a/third_party/xla/xla/service/rendezvous_test.cc b/third_party/xla/xla/service/rendezvous_test.cc
index f9fbb1c8287e20..5bb8d2880a2ac1 100644
--- a/third_party/xla/xla/service/rendezvous_test.cc
+++ b/third_party/xla/xla/service/rendezvous_test.cc
@@ -37,8 +37,8 @@ limitations under the License.
 namespace xla {
 namespace {

-absl::Duration Timeout() { return absl::Seconds(10); }
-absl::Duration Terminate() { return absl::Seconds(10); }
+absl::Duration Timeout() { return absl::Seconds(5); }
+absl::Duration Terminate() { return absl::Seconds(5); }

 tsl::thread::ThreadPool CreateThreadPool(int32_t size) {
   return tsl::thread::ThreadPool(tsl::Env::Default(), "rendezvous_test", size);
@@ -108,8 +108,9 @@ TEST(RendezvousTest, RepeatRendezvous) {
     absl::BlockingCounter counter(2);

     auto task = [&] {
-      TF_ASSERT_OK(
-          Rendezvous("rendezvous_test", i, 2, [] { return 42; }));
+      TF_ASSERT_OK(Rendezvous(
+          "rendezvous_test", /*key=*/0, /*num_threads=*/2, [] { return 42; },
+          Timeout(), Terminate()));
       counter.DecrementCount();
     };

@@ -119,6 +120,28 @@
   }
 }

+TEST(RendezvousTest, BackToBackRendezvous) {
+  auto thread_pool = CreateThreadPool(2);
+
+  absl::BlockingCounter counter(2);
+
+  // In contrast to the previous test, both tasks do back-to-back rendezvous
+  // without synchronization with a main thread. We check that in this case
+  // rendezvous do not step on each other and execute correctly.
+  auto task = [&] {
+    for (int32_t i = 0; i < 10; ++i) {
+      TF_ASSERT_OK(Rendezvous(
+          "rendezvous_test", /*key=*/0, /*num_threads=*/2, [] { return 42; },
+          Timeout(), Terminate()));
+    }
+    counter.DecrementCount();
+  };
+
+  thread_pool.Schedule(task);
+  thread_pool.Schedule(task);
+  counter.Wait();
+}
+
 TEST(RendezvousTest, ReturningStatusOr) {
   absl::BlockingCounter counter(2);
   std::vector> results(2);
@@ -268,8 +291,9 @@ static void BM_Rendezvous(benchmark::State& state) {
     absl::BlockingCounter counter(num_threads);
     for (int64_t i = 0; i < num_threads; ++i) {
       thread_pool.Schedule([&] {
-        CHECK_OK(Rendezvous("rendezvous_test", 0, num_threads,
-                            [] { return 42; }));
+        CHECK_OK(Rendezvous(
+            "rendezvous_test", /*key=*/0, num_threads, [] { return 42; },
+            Timeout(), Terminate()));
         counter.DecrementCount();
       });
     }
@@ -285,9 +309,9 @@ static void BM_RendezvousWithValues(benchmark::State& state) {
     absl::BlockingCounter counter(num_threads);
     for (int64_t i = 0; i < num_threads; ++i) {
       thread_pool.Schedule([&, i] {
-        int32_t value = i;
-        CHECK_OK(Rendezvous("rendezvous_test", 0, value, num_threads,
-                            [](auto) { return 42; }));
+        CHECK_OK(Rendezvous(
+            "rendezvous_test", /*key=*/0, /*value=*/i, num_threads,
+            [](auto) { return 42; }, Timeout(), Terminate()));
         counter.DecrementCount();
       });
     }
@@ -306,8 +330,9 @@ static void BM_GroupedRendezvous(benchmark::State& state) {
   for (int64_t group = 0; group < num_groups; ++group) {
     for (int64_t i = 0; i < group_size; ++i) {
       thread_pool.Schedule([&, group] {
-        CHECK_OK(Rendezvous("rendezvous_test", group, group_size,
-                            [] { return 42; }));
+        CHECK_OK(Rendezvous(
+            "rendezvous_test", /*key=*/group, /*num_threads=*/group_size,
+            [] { return 42; }, Timeout(), Terminate()));
         counter.DecrementCount();
       });
     }
From e70561e67cda9229c7ceb4dcb735243272cc74f9 Mon Sep 17 00:00:00 2001
From: Won Jong Jeon
Date: Thu, 10 Apr 2025 11:59:41 -0700
Subject: [PATCH 0514/1324] [mlir][tosa] Fix log lit tests (#91088)

This fixes a couple of log lit tests where the tosa const attribute
"value" has changed to "values".

Change-Id: Ib845ac4f34256b33b16f36c0c87edfc213069080
Signed-off-by: Tai Ly
Co-authored-by: Tai Ly
---
 tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
index c802d3c6e9033e..fd44d2738beb22 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
@@ -943,7 +943,7 @@ func.func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

 // CHECK-LABEL: test_log_qi8
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform>
-// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<256xi8>}>
+// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<256xi8>}>
 // CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]]
 func.func @test_log_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) {
   %0 = "tfl.log"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform>
@@ -954,7 +954,7 @@

 // CHECK-LABEL: test_log_qi16
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform>
-// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}>
+// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<513xi16>}>
 // CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]]
 func.func @test_log_qi16(%arg0:
tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { %0 = "tfl.log"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> From aac649d516ccfd8512a701ab04267c318e6019bd Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Thu, 10 Apr 2025 11:24:55 -0700 Subject: [PATCH 0515/1324] Implement d2d transfer `TfrtGpuBuffer::CopyToMemorySpace` PiperOrigin-RevId: 746106384 --- third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc index 71c2615514244b..039ee0cde44649 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc @@ -1011,10 +1011,10 @@ TEST(TfrtGpuClientTest, CreateViewOfDeviceBuffer) { TEST(TfrtGpuClientTest, CopyRawToHostFullBuffer) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); auto literal = xla::LiteralUtil::CreateR1({41.0f, 42.0f}); + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr buffer, client->BufferFromHostLiteral(literal, client->memory_spaces()[0])); - TF_ASSERT_OK_AND_ASSIGN(int64_t size, buffer->GetOnDeviceSizeInBytes()); void* dst = tsl::port::AlignedMalloc(size, tsl::Allocator::kAllocatorAlignment); From 01285d52efc179f23cd3e1bcd594c5415edb33cc Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 10 Apr 2025 11:45:15 -0700 Subject: [PATCH 0516/1324] [xla:gpu] CommandBuffer: use explicit command update API to update command buffers Do not rely on UpdateState to track indices of updated commands and instead explicitly call update APIs from CommandBufferCmd implementations. PiperOrigin-RevId: 746114354 --- .../gpu/runtime/command_buffer_cmd.cc | 372 ++++++++++------- .../backends/gpu/runtime/command_buffer_cmd.h | 17 +- .../xla/xla/stream_executor/command_buffer.h | 11 + .../stream_executor/gpu/gpu_command_buffer.cc | 378 +++++++----------- .../stream_executor/gpu/gpu_command_buffer.h | 12 +- 5 files changed, 414 insertions(+), 376 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index ce19ee9ac5bd73..98c04247c17604 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -74,7 +74,6 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -157,6 +156,35 @@ CommandBufferCmd::RecordedCommands::Create( return RecordedCommands{{*command}}; } +//===----------------------------------------------------------------------===// +// CommandBufferCmd::RecordAction helpers. +//===----------------------------------------------------------------------===// + +using CreateCommand = + absl::FunctionRef( + absl::Span dependencies)>; + +using UpdateCommand = + absl::FunctionRef; + +// Handles a record action by calling one of the user-provided functions. 
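+// For example (editor's sketch, mirroring the call sites updated below), a
+// typical `Record` implementation provides one callback per action:
+//
+//   return Handle(
+//       std::move(record_action),
+//       [&](absl::Span<const se::CommandBuffer::Command*> dependencies) {
+//         return command_buffer->Memset(&dst, value, /*num_elements=*/1,
+//                                       dependencies);
+//       },
+//       [&](const se::CommandBuffer::Command* command) {
+//         return command_buffer->Memset(command, &dst, value,
+//                                       /*num_elements=*/1);
+//       });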
+static absl::StatusOr Handle( + CommandBufferCmd::RecordAction action, CreateCommand create_command, + UpdateCommand update_command) { + if (auto* create = std::get_if(&action)) { + return CommandBufferCmd::RecordedCommands::Create( + create_command(create->dependencies)); + } + + if (auto* update = std::get_if(&action)) { + auto* command = update->recorded_commands.commands[0]; + TF_RETURN_IF_ERROR(update_command(command)); + return std::move(update->recorded_commands); + } + + return Internal("Invalid record action"); +} + //===----------------------------------------------------------------------===// // CommandBufferCmd //===----------------------------------------------------------------------===// @@ -408,10 +436,11 @@ TracedCommandBufferCmd::TracedCommandBufferCmd( CommandBufferCmdType cmd_type, ExecutionStreamId execution_stream_id) : CommandBufferCmd(cmd_type, execution_stream_id) {} -absl::StatusOr -TracedCommandBufferCmd::AddTracedCommandBuffer( +absl::StatusOr +TracedCommandBufferCmd::RecordTracedCommand( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer, absl::FunctionRef trace) { auto traced_cmd = record_params.state.GetOrCreate( this, command_buffer, [&] { @@ -427,68 +456,21 @@ TracedCommandBufferCmd::AddTracedCommandBuffer( execute_params.command_buffer_trace_stream, trace)); VLOG(5) << "Add nested command buffer"; - return command_buffer->AddNestedCommandBuffer(*nested_cmd, {}); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->AddNestedCommandBuffer(*nested_cmd, + dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->AddNestedCommandBuffer(command, *nested_cmd); + }); } //===----------------------------------------------------------------------===// // ComputationId //===----------------------------------------------------------------------===// -// TODO(ezhulenev): PTX kernel should be replaced with CUDA C++ kernel but -// today we accidentally try to build them without CUDA support. We need to -// clean our build and testing infrastructure first. - -// PTX kernel compiled from: -// -// __global__ void memset32(int64_t n, uint32_t value, uint32_t* dst) -// { -// int i = blockIdx.x*blockDim.x + threadIdx.x; -// if (i < n) dst[i] = value; -// } -// -// Easiest way to get PTX from C++ is to use https://godbolt.org. 
-inline constexpr absl::string_view kMemset32Kernel = R"( -.version 4.0 -.target sm_50 -.address_size 64 - -.visible .entry memset32( - .param .u64 memset32_param_0, - .param .u32 memset32_param_1, - .param .u64 memset32_param_2 -) -{ - .reg .pred %p<2>; - .reg .b32 %r<6>; - .reg .b64 %rd<7>; - .loc 1 3 0 - - ld.param.u64 %rd3, [memset32_param_0]; - ld.param.u32 %r1, [memset32_param_1]; - ld.param.u64 %rd2, [memset32_param_2]; - .loc 1 5 3 - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mov.u32 %r4, %tid.x; - mad.lo.s32 %r5, %r2, %r3, %r4; - .loc 1 6 3 - cvt.s64.s32 %rd1, %r5; - setp.ge.s64 %p1, %rd1, %rd3; - @%p1 bra $L__BB0_2; - - .loc 1 5 3 - cvta.to.global.u64 %rd4, %rd2; - .loc 1 6 3 - shl.b64 %rd5, %rd1, 2; - add.s64 %rd6, %rd4, %rd5; - st.global.u32 [%rd6], %r1; - -$L__BB0_2: - .loc 1 7 1 - ret; - -})"; - ComputationIdCmd::ComputationIdCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice dest, Kind kind) : CommandBufferCmd(CommandBufferCmdType::kComputationIdCmd, @@ -522,8 +504,16 @@ absl::StatusOr ComputationIdCmd::Record( << "; value=" << value; VLOG(5) << " Id: " << dest_ << " (" << dst.opaque() << ")"; - return RecordedCommands::Create( - command_buffer->Memset(&dst, value, /*num_elements=*/1, {})); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->Memset(&dst, value, /*num_elements=*/1, + dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->Memset(command, &dst, value, + /*num_elements=*/1); + }); } //===----------------------------------------------------------------------===// @@ -587,9 +577,18 @@ absl::StatusOr LaunchCmd::Record( TF_ASSIGN_OR_RETURN(auto kernel_args, se::PackKernelArgs(buffers, shmem_bytes_)); - return RecordedCommands::Create( - command_buffer->Launch(dims_.thread_counts_per_block(), - dims_.block_counts(), *kernel, *kernel_args, {})); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->Launch(dims_.thread_counts_per_block(), + dims_.block_counts(), *kernel, + *kernel_args, dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->Launch(command, dims_.thread_counts_per_block(), + dims_.block_counts(), *kernel, + *kernel_args); + }); } CommandBufferCmd::BufferUseVector LaunchCmd::buffers() { @@ -659,9 +658,18 @@ CustomKernelLaunchCmd::Record(const Thunk::ExecuteParams& execute_params, se::KernelArgsDeviceMemoryArray kernel_args( buffers, custom_kernel_.shared_memory_bytes()); - return RecordedCommands::Create(command_buffer->Launch( - custom_kernel_.thread_dims(), custom_kernel_.block_dims(), *kernel, - kernel_args, {})); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->Launch(custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), *kernel, + kernel_args, dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->Launch(command, custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), *kernel, + kernel_args); + }); } CommandBufferCmd::BufferUseVector CustomKernelLaunchCmd::buffers() { @@ -704,8 +712,16 @@ MemcpyDeviceToDeviceCmd::Record(const Thunk::ExecuteParams& execute_params, return RecordedCommands{}; } - return RecordedCommands::Create( - command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_, {})); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_, + 
dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->MemcpyDeviceToDevice(command, &dst, src, + num_bytes_); + }); } CommandBufferCmd::BufferUseVector MemcpyDeviceToDeviceCmd::buffers() { @@ -736,9 +752,17 @@ absl::StatusOr MemzeroCmd::Record( return RecordedCommands{}; } - return RecordedCommands::Create( - command_buffer->Memset(&dst, uint8_t{0}, - /*num_elements=*/dst_.size(), {})); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->Memset(&dst, uint8_t{0}, + /*num_elements=*/dst_.size(), + dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->Memset(command, &dst, uint8_t{0}, + /*num_elements=*/dst_.size()); + }); } CommandBufferCmd::BufferUseVector MemzeroCmd::buffers() { @@ -770,9 +794,18 @@ absl::StatusOr Memset32Cmd::Record( return RecordedCommands{}; } - return RecordedCommands::Create(command_buffer->Memset( - &dst, bit_pattern_, - /*num_elements=*/dst_.size() / sizeof(uint32_t), {})); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->Memset( + &dst, bit_pattern_, + /*num_elements=*/dst_.size() / sizeof(uint32_t), dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->Memset( + command, &dst, bit_pattern_, + /*num_elements=*/dst_.size() / sizeof(uint32_t)); + }); } CommandBufferCmd::BufferUseVector Memset32Cmd::buffers() { @@ -809,19 +842,31 @@ absl::StatusOr CaseCmd::Record( VLOG(5) << "CaseCmd:"; VLOG(5) << " index: " << index_ << " (" << index.opaque() << ")"; - if (index_is_bool_) { - return RecordedCommands::Create( - command_buffer->Case(se::DeviceMemory(index), - CreateBuilders(absl::MakeSpan(branches_commands_), - &execute_params, &record_params), - {})); - } else { - return RecordedCommands::Create( - command_buffer->Case(se::DeviceMemory(index), - CreateBuilders(absl::MakeSpan(branches_commands_), - &execute_params, &record_params), - {})); - } + auto branches = CreateBuilders(absl::MakeSpan(branches_commands_), + &execute_params, &record_params); + + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + if (index_is_bool_) { + return command_buffer->Case(se::DeviceMemory(index), + std::move(branches), dependencies); + + } else { + return command_buffer->Case(se::DeviceMemory(index), + std::move(branches), dependencies); + } + }, + [&](const se::CommandBuffer::Command* command) { + if (index_is_bool_) { + return command_buffer->Case(command, se::DeviceMemory(index), + std::move(branches)); + + } else { + return command_buffer->Case(command, se::DeviceMemory(index), + std::move(branches)); + } + }); } bool CaseCmd::force_update() { @@ -868,11 +913,21 @@ absl::StatusOr WhileCmd::Record( << " body_commands=" << body_commands_.size(); VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")"; - return RecordedCommands::Create(command_buffer->While( - se::DeviceMemory(pred), - CreateExecutionScopeBuilder(&cond_commands_, &execute_params, - &record_params), - CreateBuilder(&body_commands_, &execute_params, &record_params), {})); + auto cond = CreateExecutionScopeBuilder(&cond_commands_, &execute_params, + &record_params); + auto body = CreateBuilder(&body_commands_, &execute_params, &record_params); + + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->While(se::DeviceMemory(pred), + std::move(cond), std::move(body), + dependencies); + }, + [&](const 
se::CommandBuffer::Command* command) { + return command_buffer->While(command, se::DeviceMemory(pred), + std::move(cond), std::move(body)); + }); } bool WhileCmd::force_update() { @@ -934,11 +989,12 @@ absl::StatusOr GemmCmd::Record( VLOG(5) << " Out: " << output_buffer_ << " (" << out.opaque() << ")"; VLOG(5) << " Workspace: " << workspace_ << " (" << workspace.opaque() << ")"; - return RecordedCommands::Create(AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunGemm(config_, lhs, rhs, out, workspace, deterministic_, - stream); - })); + return RecordTracedCommand(execute_params, record_params, + std::move(record_action), command_buffer, + [&](se::Stream* stream) { + return RunGemm(config_, lhs, rhs, out, workspace, + deterministic_, stream); + }); } CommandBufferCmd::BufferUseVector GemmCmd::buffers() { @@ -1079,8 +1135,9 @@ absl::StatusOr CublasLtCmd::Record( VLOG(5) << " d_amax_buffer: " << d_amax_buffer_.ToString(); VLOG(5) << " workspace_buffer: " << workspace_buffer_.ToString(); - return RecordedCommands::Create(AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RecordTracedCommand( + execute_params, record_params, std::move(record_action), command_buffer, + [&](se::Stream* stream) { return plan->ExecuteOnStream( stream, allocs.GetDeviceAddress(a_buffer_), allocs.GetDeviceAddress(b_buffer_), @@ -1088,7 +1145,7 @@ absl::StatusOr CublasLtCmd::Record( allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, algorithm, allocs.GetDeviceAddress(workspace_buffer_)); - })); + }); } CommandBufferCmd::BufferUseVector CublasLtCmd::buffers() { @@ -1165,12 +1222,13 @@ absl::StatusOr CuDnnCmd::Record( *graph_->get(), *execute_params.stream, absl::Span(operands), {})); } - return RecordedCommands::Create(AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RecordTracedCommand( + execute_params, record_params, std::move(record_action), command_buffer, + [&](se::Stream* stream) { return graph_->get()->Execute( *stream, absl::Span(operands), execute_params.collective_params->local_device_ordinal); - })); + }); } CommandBufferCmd::BufferUseVector CuDnnCmd::buffers() { @@ -1193,9 +1251,10 @@ absl::StatusOr CustomCallCmd::Record( se::CommandBuffer* command_buffer) { if (handler_ == nullptr) { return RecordLegacyCustomCall(execute_params, record_params, - command_buffer); + std::move(record_action), command_buffer); } - return RecordXlaFfiCall(execute_params, record_params, command_buffer); + return RecordXlaFfiCall(execute_params, record_params, + std::move(record_action), command_buffer); } namespace { @@ -1229,7 +1288,8 @@ absl::Status GetBuffers( absl::StatusOr CustomCallCmd::RecordLegacyCustomCall( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { std::vector buffers; buffers.reserve(operands_.size() + results_.size()); @@ -1255,13 +1315,21 @@ CustomCallCmd::RecordLegacyCustomCall( return absl::OkStatus(); })); - return RecordedCommands::Create( - command_buffer->AddNestedCommandBuffer(*nested_cmd, {})); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->AddNestedCommandBuffer(*nested_cmd, + dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return 
command_buffer->AddNestedCommandBuffer(command, *nested_cmd); + }); } absl::StatusOr CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, + RecordAction record_action, se::CommandBuffer* command_buffer) { // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing // a lot of extra allocation on every call. We have to keep attributes @@ -1329,8 +1397,15 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, return ffi::Call(handler_, call_frame, options); })); - return RecordedCommands::Create( - command_buffer->AddNestedCommandBuffer(*nested_cmd, {})); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->AddNestedCommandBuffer(*nested_cmd, + dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->AddNestedCommandBuffer(command, *nested_cmd); + }); } CommandBufferCmd::BufferUseVector CustomCallCmd::buffers() { @@ -1369,17 +1444,26 @@ absl::Status CollectiveCmd::Prepare( return resource_requests.AddClique(clique_key); } -absl::StatusOr -CollectiveCmd::AddTracedCommandBuffer( +absl::StatusOr +CollectiveCmd::RecordTracedCommand( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer, absl::FunctionRef trace) { TF_ASSIGN_OR_RETURN(std::unique_ptr nested_cmd, se::TraceCommandBufferFactory::Create( execute_params.stream->parent(), execute_params.command_buffer_trace_stream, trace)); - return command_buffer->AddNestedCommandBuffer(*nested_cmd, {}); + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->AddNestedCommandBuffer(*nested_cmd, + dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->AddNestedCommandBuffer(command, *nested_cmd); + }); } //===----------------------------------------------------------------------===// @@ -1428,11 +1512,12 @@ absl::StatusOr AllReduceCmd::Record( *execute_params.collective_cliques, config().replica_groups, config().group_mode, GetAsyncStreamKind())); - return RecordedCommands::Create(AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RecordTracedCommand( + execute_params, record_params, std::move(record_action), command_buffer, + [&](se::Stream* stream) { return RunAllReduce(collectives, reduction_kind_, device_buffers, *stream, comm_handle.comm); - })); + }); } CommandBufferCmd::BufferUseVector AllReduceCmd::buffers() { @@ -1491,11 +1576,12 @@ absl::StatusOr ReduceScatterCmd::Record( *execute_params.collective_cliques, config().replica_groups, config().group_mode, GetAsyncStreamKind())); - return RecordedCommands::Create(AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunReduceScatter(collectives, reduction_kind_, device_buffers, - *stream, comm_handle.comm); - })); + return RecordTracedCommand(execute_params, record_params, record_action, + command_buffer, [&](se::Stream* stream) { + return RunReduceScatter( + collectives, reduction_kind_, device_buffers, + *stream, comm_handle.comm); + }); } CommandBufferCmd::BufferUseVector ReduceScatterCmd::buffers() { @@ -1551,11 +1637,12 @@ absl::StatusOr AllToAllCmd::Record( *execute_params.collective_cliques, config().replica_groups, config().group_mode, 
GetAsyncStreamKind())); - return RecordedCommands::Create(AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RecordTracedCommand( + execute_params, record_params, std::move(record_action), command_buffer, + [&](se::Stream* stream) { return RunAllToAll(collectives, has_split_dimension_, device_buffers, *stream, comm_handle.comm); - })); + }); } CommandBufferCmd::BufferUseVector AllToAllCmd::buffers() { @@ -1611,11 +1698,12 @@ absl::StatusOr AllGatherCmd::Record( *execute_params.collective_cliques, config().replica_groups, config().group_mode, GetAsyncStreamKind())); - return RecordedCommands::Create(AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunAllGather(collectives, device_buffers, *stream, - comm_handle.comm); - })); + return RecordTracedCommand(execute_params, record_params, + std::move(record_action), command_buffer, + [&](se::Stream* stream) { + return RunAllGather(collectives, device_buffers, + *stream, comm_handle.comm); + }); } CommandBufferCmd::BufferUseVector AllGatherCmd::buffers() { @@ -1673,11 +1761,12 @@ CollectiveBroadcastCmd::Record(const Thunk::ExecuteParams& execute_params, *execute_params.collective_cliques, config().replica_groups, config().group_mode, GetAsyncStreamKind())); - return RecordedCommands::Create(AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RecordTracedCommand( + execute_params, record_params, std::move(record_action), command_buffer, + [&](se::Stream* stream) { return RunCollectiveBroadcast(device_buffers, *stream, comm_handle.comm, collectives); - })); + }); } CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() { @@ -1927,8 +2016,17 @@ DynamicSliceFusionCmd::Record(const Thunk::ExecuteParams& execute_params, .value(); TF_RETURN_IF_ERROR(embedded_commands_.Record(new_params, record_params, nested_command_buffer.get())); - return RecordedCommands::Create( - command_buffer->AddNestedCommandBuffer(*nested_command_buffer, {})); + + return Handle( + std::move(record_action), + [&](absl::Span dependencies) { + return command_buffer->AddNestedCommandBuffer(*nested_command_buffer, + dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + return command_buffer->AddNestedCommandBuffer(command, + *nested_command_buffer); + }); } CommandBufferCmd::BufferUseVector DynamicSliceFusionCmd::buffers() { diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index db17ea1a639e9d..f93992caf6099a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -436,9 +436,10 @@ class TracedCommandBufferCmd : public CommandBufferCmd { // Creates a command buffer by calling a user-provided `trace` function and // adds it as a nested command to `command_buffer`. Traced command buffers // cached and reused in an instance of `TracedCommandBuffer` kept in `state`. 
- absl::StatusOr AddTracedCommandBuffer( + absl::StatusOr RecordTracedCommand( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer, absl::FunctionRef trace); }; @@ -832,10 +833,13 @@ class CustomCallCmd : public CommandBufferCmd { private: absl::StatusOr RecordLegacyCustomCall( const Thunk::ExecuteParams& execute_param, - const RecordParams& record_params, se::CommandBuffer* command_buffer); + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer); + absl::StatusOr RecordXlaFfiCall( const Thunk::ExecuteParams& execute_param, - const RecordParams& record_params, se::CommandBuffer* command_buffer); + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer); std::string target_name_; @@ -874,9 +878,10 @@ class CollectiveCmd : public CommandBufferCmd { bool IsNestedCommandBuffer() const final { return true; } - absl::StatusOr AddTracedCommandBuffer( + absl::StatusOr RecordTracedCommand( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer, absl::FunctionRef trace); virtual AsyncStreamKind GetAsyncStreamKind() = 0; diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 15dbdcb52c1447..2ac80cf5030c17 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -87,6 +87,17 @@ class CommandBuffer { // enum class State { kCreate, kUpdate, kFinalized }; + friend absl::string_view StateToString(State state) { + switch (state) { + case CommandBuffer::State::kCreate: + return "create"; + case CommandBuffer::State::kUpdate: + return "update"; + case CommandBuffer::State::kFinalized: + return "finalized"; + } + } + // Command buffers have two modes of execution: // // (1) kPrimary: command buffer can be submitted for execution via diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index e1289b3abdfc93..3947d87cfbb2c7 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -28,7 +28,6 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -48,33 +47,10 @@ limitations under the License. namespace stream_executor::gpu { -//===----------------------------------------------------------------------===// -// Implementation details device kernels required by GpuCommandBuffer. 
-//===----------------------------------------------------------------------===// - using Mode = CommandBuffer::Mode; using State = CommandBuffer::State; using GraphNodeHandle = GpuCommandBuffer::GraphNodeHandle; using GraphConditionalHandle = GpuCommandBuffer::GraphConditionalHandle; -using GraphConditionalHandles = absl::Span; - -namespace { -absl::string_view to_string(State state) { - switch (state) { - case State::kCreate: - return "create"; - case State::kUpdate: - return "update"; - case State::kFinalized: - return "finalized"; - } -} - -absl::Status UnsupportedStateError(State state) { - return absl::InternalError( - absl::StrCat("Unsupported command buffer state: ", to_string(state))); -} -} // namespace //===----------------------------------------------------------------------===// // GpuCommandBuffer resource usage tracking @@ -145,39 +121,40 @@ absl::Status GpuCommandBuffer::CheckNotFinalized() { return absl::OkStatus(); } +absl::Status GpuCommandBuffer::CheckInState(State state) { + if (state_ != state) { + return absl::InternalError(absl::StrFormat( + "Expected command buffer to be in state %v but it was in state %v", + state, state_)); + } + return absl::OkStatus(); +} + absl::StatusOr GpuCommandBuffer::LaunchWithPackedArgs( const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& packed_args, absl::Span dependencies) { + TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); + CHECK_EQ(kernel.Arity() + (packed_args.number_of_shared_bytes() > 0), packed_args.number_of_arguments()); // Adds a new kernel node to the graph under construction. - if (state_ == State::kCreate) { - Dependencies barrier = dependencies.empty() - ? GetAutoDependencies() - : ToGraphNodeDependencies(dependencies); - TF_ASSIGN_OR_RETURN( - GraphNodeHandle handle, - CreateKernelNode(barrier, threads, blocks, kernel, packed_args)); - return AppendCommand(handle); - } - - // Updates kernel node in the executable graph. - if (state_ == State::kUpdate) { - Command& command = *commands_[update_state_.command_idx++]; - TF_RETURN_IF_ERROR( - LaunchWithPackedArgs(&command, threads, blocks, kernel, packed_args)); - return &command; - } + Dependencies barrier = dependencies.empty() + ? GetAutoDependencies() + : ToGraphNodeDependencies(dependencies); + TF_ASSIGN_OR_RETURN( + GraphNodeHandle handle, + CreateKernelNode(barrier, threads, blocks, kernel, packed_args)); - return UnsupportedStateError(state_); + return AppendCommand(handle); } absl::Status GpuCommandBuffer::LaunchWithPackedArgs( const Command* command, const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& packed_args) { + TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); auto* gpu_command = tsl::down_cast(command); return UpdateKernelNode(gpu_command->handle, threads, blocks, kernel, packed_args); @@ -186,7 +163,7 @@ absl::Status GpuCommandBuffer::LaunchWithPackedArgs( absl::StatusOr GpuCommandBuffer::Launch( const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgs& args, absl::Span dependencies) { - TF_RETURN_IF_ERROR(CheckNotFinalized()); + TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); // If arguments are already packed we can just launch the kernel. 
if (auto* packed = DynCast(&args)) { @@ -214,7 +191,7 @@ absl::Status GpuCommandBuffer::Launch(const Command* command, const BlockDim& blocks, const Kernel& kernel, const KernelArgs& args) { - TF_RETURN_IF_ERROR(CheckNotFinalized()); + TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); // If arguments are already packed we can just launch the kernel. if (auto* packed = DynCast(&args)) { @@ -241,30 +218,19 @@ absl::StatusOr GpuCommandBuffer::AddNestedCommandBuffer( const CommandBuffer& nested, absl::Span dependencies) { - TF_RETURN_IF_ERROR(CheckNotFinalized()); + TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); - // Adds a child graph node to the graph under construction. - if (state_ == State::kCreate) { - Dependencies barrier = dependencies.empty() - ? GetAutoDependencies() - : ToGraphNodeDependencies(dependencies); - TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, - CreateChildNode(barrier, nested)); - return AppendCommand(handle); - } + Dependencies barrier = dependencies.empty() + ? GetAutoDependencies() + : ToGraphNodeDependencies(dependencies); + TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, CreateChildNode(barrier, nested)); - // Updates child graph node in the executable graph. - if (state_ == State::kUpdate) { - Command& command = *commands_[update_state_.command_idx++]; - TF_RETURN_IF_ERROR(AddNestedCommandBuffer(&command, nested)); - return &command; - } - - return UnsupportedStateError(state_); + return AppendCommand(handle); } absl::Status GpuCommandBuffer::AddNestedCommandBuffer( const Command* command, const CommandBuffer& nested) { + TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); auto* gpu_command = tsl::down_cast(command); return UpdateChildNode(gpu_command->handle, nested); } @@ -273,30 +239,22 @@ absl::StatusOr GpuCommandBuffer::MemcpyDeviceToDevice( DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size, absl::Span dependencies) { - TF_RETURN_IF_ERROR(CheckNotFinalized()); + TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); - if (state_ == State::kCreate) { - Dependencies barrier = dependencies.empty() - ? GetAutoDependencies() - : ToGraphNodeDependencies(dependencies); - TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, - CreateMemcpyD2DNode(barrier, *dst, src, size)); - return AppendCommand(handle); - } - - if (state_ == State::kUpdate) { - Command& command = *commands_[update_state_.command_idx++]; - TF_RETURN_IF_ERROR(MemcpyDeviceToDevice(&command, dst, src, size)); - return &command; - } + Dependencies barrier = dependencies.empty() + ? GetAutoDependencies() + : ToGraphNodeDependencies(dependencies); + TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, + CreateMemcpyD2DNode(barrier, *dst, src, size)); - return UnsupportedStateError(state_); + return AppendCommand(handle); } absl::Status GpuCommandBuffer::MemcpyDeviceToDevice(const Command* command, DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size) { + TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); auto* gpu_command = tsl::down_cast(command); return UpdateMemcpyD2DNode(gpu_command->handle, *dst, src, size); } @@ -304,31 +262,23 @@ absl::Status GpuCommandBuffer::MemcpyDeviceToDevice(const Command* command, absl::StatusOr GpuCommandBuffer::Memset( DeviceMemoryBase* dst, BitPattern bit_pattern, size_t num_elements, absl::Span dependencies) { - TF_RETURN_IF_ERROR(CheckNotFinalized()); + TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); - if (state_ == State::kCreate) { - Dependencies barrier = dependencies.empty() - ? 
GetAutoDependencies() - : ToGraphNodeDependencies(dependencies); - TF_ASSIGN_OR_RETURN( - GraphNodeHandle handle, - CreateMemsetNode(barrier, *dst, bit_pattern, num_elements)); - return AppendCommand(handle); - } - - if (state_ == State::kUpdate) { - Command& command = *commands_[update_state_.command_idx++]; - TF_RETURN_IF_ERROR(Memset(&command, dst, bit_pattern, num_elements)); - return &command; - } + Dependencies barrier = dependencies.empty() + ? GetAutoDependencies() + : ToGraphNodeDependencies(dependencies); + TF_ASSIGN_OR_RETURN( + GraphNodeHandle handle, + CreateMemsetNode(barrier, *dst, bit_pattern, num_elements)); - return UnsupportedStateError(state_); + return AppendCommand(handle); } absl::Status GpuCommandBuffer::Memset(const Command* command, DeviceMemoryBase* dst, const BitPattern& bit_pattern, size_t num_elements) { + TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); auto* gpu_command = tsl::down_cast(command); return UpdateMemsetNode(gpu_command->handle, *dst, bit_pattern, num_elements); } @@ -337,41 +287,32 @@ absl::StatusOr GpuCommandBuffer::DnnGraph( dnn::DnnGraph& dnn_graph, Stream& stream, absl::Span operands, absl::Span dependencies) { - TF_RETURN_IF_ERROR(CheckNotFinalized()); + TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); - if (state_ == State::kCreate) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr nested, - stream.parent()->CreateCommandBuffer(CommandBuffer::Mode::kNested)); - GpuCommandBuffer& nested_gpu = - tensorflow::down_cast(*nested); - TF_RETURN_IF_ERROR( - nested_gpu.PopulateDnnGraphNode(dnn_graph, stream, operands)); - Dependencies barrier = dependencies.empty() - ? GetAutoDependencies() - : ToGraphNodeDependencies(dependencies); - TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, - CreateChildNode(barrier, *nested)); - return AppendCommand(handle); - } - - if (state_ == State::kUpdate) { - Command& command = *commands_[update_state_.command_idx++]; - TF_RETURN_IF_ERROR(DnnGraph(&command, dnn_graph, stream, operands)); - return &command; - } + TF_ASSIGN_OR_RETURN(std::unique_ptr nested, + stream.parent()->CreateCommandBuffer(Mode::kNested)); + GpuCommandBuffer& nested_gpu = + tensorflow::down_cast(*nested); + TF_RETURN_IF_ERROR( + nested_gpu.PopulateDnnGraphNode(dnn_graph, stream, operands)); + Dependencies barrier = dependencies.empty() + ? GetAutoDependencies() + : ToGraphNodeDependencies(dependencies); + TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, + CreateChildNode(barrier, *nested)); - return UnsupportedStateError(state_); + return AppendCommand(handle); } absl::Status GpuCommandBuffer::DnnGraph(const Command* command, dnn::DnnGraph& dnn_graph, Stream& stream, absl::Span operands) { + TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); + auto* gpu_command = tsl::down_cast(command); - TF_ASSIGN_OR_RETURN( - std::unique_ptr nested, - stream.parent()->CreateCommandBuffer(CommandBuffer::Mode::kNested)); + TF_ASSIGN_OR_RETURN(std::unique_ptr nested, + stream.parent()->CreateCommandBuffer(Mode::kNested)); GpuCommandBuffer& nested_gpu = tensorflow::down_cast(*nested); TF_RETURN_IF_ERROR(nested_gpu.UpdateDnnGraphNode(dnn_graph, stream, operands, @@ -397,83 +338,77 @@ absl::StatusOr GpuCommandBuffer::Case( DeviceMemory index, bool index_is_bool, std::vector branches, absl::Span dependencies) { + TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); + constexpr size_t kBranchBatchSize = 8; - if (state_ == State::kCreate) { - GpuCaseCommand command = {}; - - Dependencies barrier = dependencies.empty() - ? 
GetAutoDependencies()
-                            : ToGraphNodeDependencies(dependencies);
-
-    int32_t batch_offset = 0;
-    while (batch_offset < branches.size()) {
-      // Conditionals will by default run branches[branches.size()-1] if index is
-      // `< 0` or `>= branches.size()`. See
-      // https://openxla.org/xla/operation_semantics#conditional.
-      // To break down a large case with back-to-back ConditionalCommands, only
-      // the last batch should accept this default case.
-      int32_t remaining_branches = branches.size() - batch_offset;
-      int32_t batch_size;
-      bool enable_conditional_default;
-      if (remaining_branches <= kBranchBatchSize) {
-        batch_size = remaining_branches;
-        enable_conditional_default = true;
-      } else {
-        batch_size = kBranchBatchSize;
-        enable_conditional_default = false;
-      }
-
-      TF_ASSIGN_OR_RETURN(auto conditionals,
-                          CreateConditionalHandles(batch_size));
-
-      TF_ASSIGN_OR_RETURN(auto set_condition_node,
-                          CreateSetCaseConditionNode(
-                              conditionals, index, index_is_bool, batch_offset,
-                              enable_conditional_default, barrier));
-
-      std::vector conditional_nodes;
-      for (int z = 0; z < batch_size; ++z) {
-        int branch_offset = z + batch_offset;
-        TF_ASSIGN_OR_RETURN(
-            conditional_nodes.emplace_back(),
-            CreateConditionalNode({set_condition_node}, conditionals[z],
-                                  ConditionType::kIf));
-
-        GpuCommandBuffer* case_command_buffer =
-            conditional_nodes.back().command_buffer.get();
-        TF_RETURN_IF_ERROR(branches[branch_offset](case_command_buffer));
-        TF_RETURN_IF_ERROR(case_command_buffer->Finalize());
-      }
-
-      // Move the state into the recorded command.
-      command.conditionals.insert(command.conditionals.end(),
-                                  conditionals.begin(), conditionals.end());
-      command.set_condition_nodes.push_back(set_condition_node);
-      command.conditional_nodes.insert(
-          command.conditional_nodes.end(),
-          std::make_move_iterator(conditional_nodes.begin()),
-          std::make_move_iterator(conditional_nodes.end()));
-
-      batch_offset += batch_size;
+  GpuCaseCommand command = {};
+
+  Dependencies barrier = dependencies.empty()
+                             ? GetAutoDependencies()
+                             : ToGraphNodeDependencies(dependencies);
+
+  int32_t batch_offset = 0;
+  while (batch_offset < branches.size()) {
+    // Conditionals will by default run branches[branches.size()-1] if index is
+    // `< 0` or `>= branches.size()`. See
+    // https://openxla.org/xla/operation_semantics#conditional.
+    // To break down a large case with back-to-back ConditionalCommands, only
+    // the last batch should accept this default case.
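+    // For example, 20 branches are recorded as batches of sizes 8, 8, and 4,
+    // and only the final batch of 4 enables the conditional default.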
+ int32_t remaining_branches = branches.size() - batch_offset; + int32_t batch_size; + bool enable_conditional_default; + if (remaining_branches <= kBranchBatchSize) { + batch_size = remaining_branches; + enable_conditional_default = true; + } else { + batch_size = kBranchBatchSize; + enable_conditional_default = false; } - return AppendCommand(std::move(command)); - } + TF_ASSIGN_OR_RETURN(auto conditionals, + CreateConditionalHandles(batch_size)); + + TF_ASSIGN_OR_RETURN(auto set_condition_node, + CreateSetCaseConditionNode( + conditionals, index, index_is_bool, batch_offset, + enable_conditional_default, barrier)); + + std::vector conditional_nodes; + for (int z = 0; z < batch_size; ++z) { + int branch_offset = z + batch_offset; + TF_ASSIGN_OR_RETURN( + conditional_nodes.emplace_back(), + CreateConditionalNode({set_condition_node}, conditionals[z], + ConditionType::kIf)); + + GpuCommandBuffer* case_command_buffer = + conditional_nodes.back().command_buffer.get(); + TF_RETURN_IF_ERROR(branches[branch_offset](case_command_buffer)); + TF_RETURN_IF_ERROR(case_command_buffer->Finalize()); + } + + // Move the state into the recorded command. + command.conditionals.insert(command.conditionals.end(), + conditionals.begin(), conditionals.end()); + command.set_condition_nodes.push_back(set_condition_node); + command.conditional_nodes.insert( + command.conditional_nodes.end(), + std::make_move_iterator(conditional_nodes.begin()), + std::make_move_iterator(conditional_nodes.end())); - if (state_ == State::kUpdate) { - Command& command = *commands_[update_state_.command_idx++]; - TF_RETURN_IF_ERROR(Case(&command, index, index_is_bool, branches)); - return &command; + batch_offset += batch_size; } - return UnsupportedStateError(state_); + return AppendCommand(std::move(command)); } absl::Status GpuCommandBuffer::Case(const Command* command, DeviceMemory index, bool index_is_bool, std::vector branches) { + TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); + constexpr size_t kBranchBatchSize = 8; auto* gpu_command = tsl::down_cast(command); @@ -553,52 +488,46 @@ absl::Status GpuCommandBuffer::Case(const Command* command, absl::StatusOr GpuCommandBuffer::While( DeviceMemory pred, Builder cond_builder, Builder body_builder, absl::Span dependencies) { - if (state_ == State::kCreate) { - GpuWhileCommand command = {}; - - Dependencies barrier = dependencies.empty() - ? GetAutoDependencies() - : ToGraphNodeDependencies(dependencies); - - // TODO(ezhulenev): cond_builder should be able to take dependencies. 
- (void)barrier; - - TF_RETURN_IF_ERROR(cond_builder(this)); - - TF_ASSIGN_OR_RETURN(command.conditional, CreateConditionalHandle()); - TF_ASSIGN_OR_RETURN(command.set_init_condition_node, - CreateSetWhileConditionNode(command.conditional, pred, - GetAutoDependencies())); - TF_ASSIGN_OR_RETURN( - command.conditional_node, - CreateConditionalNode({command.set_init_condition_node}, - command.conditional, ConditionType::kWhile)); - - GpuCommandBuffer* body = command.conditional_node.command_buffer.get(); - TF_RETURN_IF_ERROR(body_builder(body)); - TF_RETURN_IF_ERROR(cond_builder(body)); - TF_ASSIGN_OR_RETURN( - command.set_body_condition_node, - body->CreateSetWhileConditionNode(command.conditional, pred, - body->GetAutoDependencies())); - TF_RETURN_IF_ERROR(command.conditional_node.command_buffer->Finalize()); - - return AppendCommand(std::move(command)); - } + TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); - if (state_ == State::kUpdate) { - Command& command = *commands_[update_state_.command_idx++]; - TF_RETURN_IF_ERROR(While(&command, pred, cond_builder, body_builder)); - return &command; - } + GpuWhileCommand command = {}; + + Dependencies barrier = dependencies.empty() + ? GetAutoDependencies() + : ToGraphNodeDependencies(dependencies); - return UnsupportedStateError(state_); + // TODO(ezhulenev): cond_builder should be able to take dependencies. + (void)barrier; + + TF_RETURN_IF_ERROR(cond_builder(this)); + + TF_ASSIGN_OR_RETURN(command.conditional, CreateConditionalHandle()); + TF_ASSIGN_OR_RETURN(command.set_init_condition_node, + CreateSetWhileConditionNode(command.conditional, pred, + GetAutoDependencies())); + TF_ASSIGN_OR_RETURN( + command.conditional_node, + CreateConditionalNode({command.set_init_condition_node}, + command.conditional, ConditionType::kWhile)); + + GpuCommandBuffer* body = command.conditional_node.command_buffer.get(); + TF_RETURN_IF_ERROR(body_builder(body)); + TF_RETURN_IF_ERROR(cond_builder(body)); + TF_ASSIGN_OR_RETURN( + command.set_body_condition_node, + body->CreateSetWhileConditionNode(command.conditional, pred, + body->GetAutoDependencies())); + TF_RETURN_IF_ERROR(command.conditional_node.command_buffer->Finalize()); + + return AppendCommand(std::move(command)); } absl::Status GpuCommandBuffer::While(const Command* command, DeviceMemory pred, Builder cond_builder, Builder body_builder) { + TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); + auto* gpu_command = tsl::down_cast(command); TF_RETURN_IF_ERROR(cond_builder(this)); @@ -695,7 +624,6 @@ absl::Status GpuCommandBuffer::Update() { << " command buffer " << this; state_ = State::kUpdate; - update_state_ = UpdateState(); return absl::OkStatus(); } @@ -705,7 +633,7 @@ GpuCommandBuffer::commands() const { } absl::Status GpuCommandBuffer::Submit(Stream* stream) { - if (mode_ != CommandBuffer::Mode::kPrimary) { + if (mode_ != Mode::kPrimary) { return absl::InvalidArgumentError( "Can't submit non-primary command buffer for execution"); } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 14f9da3e57b556..7381e973f7acb2 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -275,6 +275,10 @@ class GpuCommandBuffer : public CommandBuffer { // possible to add new commands to it, otherwise returns internal error. 
  absl::Status CheckNotFinalized();
 
+  // Returns OK status if the command buffer is in the given state, otherwise
+  // returns an error.
+  absl::Status CheckInState(State state);
+
   // Returns OK status if the command buffer can be updated.
   virtual absl::Status CheckCanBeUpdated() = 0;
 
@@ -415,14 +419,6 @@ class GpuCommandBuffer : public CommandBuffer {
 
   // Gpu commands recorded into the command buffer.
   std::vector> commands_;
-
-  // Tracks indices into data structures during command buffer updates.
-  struct UpdateState {
-    int64_t command_idx = 0;
-  };
-
-  // Tracks execution scope update state.
-  UpdateState update_state_;
 };
 
 }  // namespace stream_executor::gpu

From 458f4617a1cbb5ec727dc1bb8ab466efabcf2fb8 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 10 Apr 2025 12:06:28 -0700
Subject: [PATCH 0517/1324] Removes all memory term reduction that occurs outside the solver.

PiperOrigin-RevId: 746122805
---
 .../auto_sharding/auto_sharding.cc            | 156 ++++++------------
 .../auto_sharding/auto_sharding_impl.cc       |  11 +-
 .../auto_sharding/auto_sharding_wrapper.h     |   8 -
 3 files changed, 54 insertions(+), 121 deletions(-)

diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 82de1b0162d678..6ba427fbc8fc4d 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -1755,10 +1755,6 @@ CreateAutoShardingSolverRequestAndCallSolver(
     const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
     const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
     const CostGraph& cost_graph, const AliasSet& alias_set,
-    const std::vector>& node_intervals,
-    const std::vector>& edge_intervals,
-    const std::vector>& node_groups,
-    const std::vector>& edge_groups,
     const std::vector& s_hint, const bool compute_iis,
     const int64_t solver_timeout_in_seconds, const AutoShardingOption& option,
     std::optional max_cost, absl::string_view request_name,
@@ -1948,27 +1944,55 @@ CreateAutoShardingSolverRequestAndCallSolver(
     }
   }
 
-  for (const auto& interval : node_intervals) {
-    AutoShardingSolverRequest_Pair pair;
-    pair.set_first(interval.first);
-    pair.set_second(interval.second);
-    *request.add_node_intervals() = std::move(pair);
-  }
-  for (const auto& interval : edge_intervals) {
-    AutoShardingSolverRequest_Pair pair;
-    pair.set_first(interval.first);
-    pair.set_second(interval.second);
-    *request.add_edge_intervals() = std::move(pair);
+  // Serialize intervals
+  std::vector> node_to_edges(
+      strategy_groups.size());
+  EdgeIdx edge_idx = 0;
+  for (const auto& [edge, _] : cost_graph.edge_costs_) {
+    node_to_edges[edge.second].insert(edge_idx);
+    ++edge_idx;
+  }
+  const absl::flat_hash_map&
+      buffer_live_ranges = hlo_live_range.buffer_live_ranges();
+  absl::flat_hash_map node_to_time_bound;
+  absl::flat_hash_map edge_to_time_bound;
+  for (const auto& [value, time_bound] : buffer_live_ranges) {
+    const HloInstruction* instruction = value->instruction();
+    const ShapeIndex& index = value->index();
+    if (instruction->shape().IsTuple() && index.empty()) continue;
+    const spmd::StrategyGroup* strategy_group =
+        strategy_map.at(instruction).get();
+    const spmd::NodeIdx node_idx =
+        strategy_group->GetSubStrategyGroup(index)->node_idx;
+    if (node_idx < 0) continue;
+    node_to_time_bound[node_idx] = time_bound;
+    for (const EdgeIdx edge_idx : node_to_edges[node_idx]) {
+      edge_to_time_bound[edge_idx] = time_bound;
+    }
+  }
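+  // Note: node and edge indices that never appear in a buffer live range are
+  // serialized below with an inverted interval (start = numeric max, end = 0),
+  // i.e. an empty live range.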
- for (const auto& reduced_group : node_groups) { - AutoShardingSolverRequest_Group group; - group.mutable_prims()->Add(reduced_group.begin(), reduced_group.end()); - *request.add_node_groups() = std::move(group); + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + AutoShardingSolverRequest_Pair interval; + if (auto time_bound = node_to_time_bound.find(node_idx); + time_bound != node_to_time_bound.end()) { + interval.set_first(time_bound->second.start); + interval.set_second(time_bound->second.end); + } else { + interval.set_first(std::numeric_limits::max()); + interval.set_second(0); + } + *request.add_node_intervals() = std::move(interval); } - for (const auto& reduced_group : edge_groups) { - AutoShardingSolverRequest_Group group; - group.mutable_prims()->Add(reduced_group.begin(), reduced_group.end()); - *request.add_edge_groups() = std::move(group); + for (EdgeIdx edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) { + AutoShardingSolverRequest_Pair interval; + if (auto time_bound = edge_to_time_bound.find(edge_idx); + time_bound != edge_to_time_bound.end()) { + interval.set_first(time_bound->second.start); + interval.set_second(time_bound->second.end); + } else { + interval.set_first(std::numeric_limits::max()); + interval.set_second(0); + } + *request.add_edge_intervals() = std::move(interval); } PopulateTemporalValues(cost_graph, request); @@ -3722,90 +3746,12 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( spmd::CostGraph cost_graph(strategy_groups, associative_dot_pairs); cost_graph.Simplify(option_.simplify_graph); - // ----- Build & reduce node and edge intervals ----- - std::vector> node_to_edges( - strategy_groups.size()); - spmd::EdgeIdx edge_idx = 0; - for (const auto& [edge, _] : cost_graph.edge_costs_) { - node_to_edges[edge.second].insert(edge_idx); - ++edge_idx; - } - const absl::flat_hash_map& - buffer_live_ranges = hlo_live_range->buffer_live_ranges(); - absl::flat_hash_map - node_to_time_bound; - absl::flat_hash_map - edge_to_time_bound; - for (const auto& [value, time_bound] : buffer_live_ranges) { - const HloInstruction* instruction = value->instruction(); - const ShapeIndex& index = value->index(); - if (instruction->shape().IsTuple() && index.empty()) continue; - const spmd::StrategyGroup* strategy_group = - strategy_map.at(instruction).get(); - const spmd::NodeIdx node_idx = - strategy_group->GetSubStrategyGroup(index)->node_idx; - if (node_idx < 0) continue; - node_to_time_bound[node_idx] = time_bound; - for (const spmd::EdgeIdx edge_idx : node_to_edges[node_idx]) { - edge_to_time_bound[edge_idx] = time_bound; - } - } - std::vector> node_intervals, - edge_intervals; - for (spmd::NodeIdx node_idx = 0; node_idx < strategy_groups.size(); - ++node_idx) { - std::pair interval; - if (auto time_bound = node_to_time_bound.find(node_idx); - time_bound != node_to_time_bound.end()) { - interval.first = time_bound->second.start; - interval.second = time_bound->second.end; - } else { - interval.first = std::numeric_limits::max(); - interval.second = 0; - } - node_intervals.push_back(std::move(interval)); - } - for (spmd::EdgeIdx edge_idx = 0; edge_idx < cost_graph.edge_costs_.size(); - ++edge_idx) { - std::pair interval; - if (auto time_bound = edge_to_time_bound.find(edge_idx); - time_bound != edge_to_time_bound.end()) { - interval.first = time_bound->second.start; - interval.second = time_bound->second.end; - } else { - interval.first = std::numeric_limits::max(); - interval.second = 0; - } - 
edge_intervals.push_back(std::move(interval)); - } - const absl::Time term_reduction_start_time = absl::Now(); - std::vector> - reduced_node_intervals, reduced_edge_intervals; - std::vector> reduced_node_groups, - reduced_edge_groups; - auto num_node_terms = - ReduceMemoryTerms(strategy_groups.size(), node_intervals, - reduced_node_intervals, reduced_node_groups); - auto num_edge_terms = - ReduceMemoryTerms(cost_graph.edge_costs_.size(), edge_intervals, - reduced_edge_intervals, reduced_edge_groups); - const absl::Time term_reduction_end_time = absl::Now(); - const auto term_reduction_duration = - term_reduction_end_time - term_reduction_start_time; - LOG(INFO) << "Memory Term Reducer took " - << absl::ToInt64Milliseconds(term_reduction_duration) - << " ms and reduced the number of terms from " - << num_node_terms.first + num_edge_terms.first << " to " - << num_node_terms.second + num_edge_terms.second; - // ----- Call the ILP Solver ----- std::string request_name = absl::StrCat("mesh_idx_", mesh_idx); - TF_ASSIGN_OR_RETURN( - spmd::AutoShardingSolverOutput output, - Solve(*module, *hlo_live_range, strategy_map, strategy_groups, - cost_graph, alias_set, reduced_node_intervals, - reduced_edge_intervals, reduced_node_groups, reduced_edge_groups, - option_, request_name, sharding_propagation_solution)); + TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output, + Solve(*module, *hlo_live_range, strategy_map, + strategy_groups, cost_graph, alias_set, option_, + request_name, sharding_propagation_solution)); if (mesh_idx == partial_mesh_shapes.size() - 1) { this->solver_optimal_objective_value_ = output.cost; } else if (hard_memory_constraint) { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index b9226f561244ea..bc208c48f6506e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -42,19 +42,14 @@ absl::StatusOr Solve( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, - const std::vector>& node_intervals, - const std::vector>& edge_intervals, - const std::vector>& node_groups, - const std::vector>& edge_groups, const AutoShardingOption& option, absl::string_view request_prefix, const absl::flat_hash_map& sharding_propagation_solution) { return CreateAutoShardingSolverRequestAndCallSolver( hlo_module, hlo_live_range, strategy_map, strategy_groups, cost_graph, - alias_set, node_intervals, edge_intervals, node_groups, edge_groups, - /*s_hint*/ {}, - /*compute_iis*/ true, option.solver_timeout_in_seconds, option, - /*max_cost*/ std::nullopt, request_prefix, sharding_propagation_solution, + alias_set, /*s_hint*/ {}, /*compute_iis*/ true, + option.solver_timeout_in_seconds, option, /*max_cost*/ std::nullopt, + request_prefix, sharding_propagation_solution, /*deterministic mode*/ true); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index 333df715447f0b..5a07e71e726f89 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -45,10 +45,6 @@ absl::StatusOr Solve( const HloModule& hlo_module, const 
HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, - const std::vector>& node_intervals, - const std::vector>& edge_intervals, - const std::vector>& node_groups, - const std::vector>& edge_groups, const AutoShardingOption& option, absl::string_view request_prefix, const absl::flat_hash_map& sharding_propagation_solution = {}); @@ -60,10 +56,6 @@ CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, - const std::vector>& node_intervals, - const std::vector>& edge_intervals, - const std::vector>& node_groups, - const std::vector>& edge_groups, const std::vector& s_hint, bool compute_iis, int64_t solver_timeout_in_seconds, const AutoShardingOption& option, std::optional max_cost, absl::string_view request_name, From 99f886429c0dad5168662b9e7892b98ee377c2b9 Mon Sep 17 00:00:00 2001 From: John QiangZhang Date: Thu, 10 Apr 2025 12:44:07 -0700 Subject: [PATCH 0518/1324] Add dedicated thread pool for coordination service. Prevents heartbeat errors due to thread starvation in the shared thread pool. PiperOrigin-RevId: 746135865 --- .../core/distributed_runtime/rpc/grpc_server_lib.cc | 10 ++++++---- .../core/distributed_runtime/rpc/grpc_server_lib.h | 2 ++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 56d21e876f6db3..213047b6165e84 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -279,9 +279,11 @@ absl::Status GrpcServer::Init(const GrpcServerOptions& opts) { opts.worker_service_options) .release(); eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder); - thread::ThreadPool* compute_pool = ComputePool(sess_opts); - coordination_service_ = - new GrpcCoordinationServiceImpl(compute_pool, &builder); + coordination_compute_pool_ = std::make_unique( + env_, "CoordinationServiceRpcHandler", + /*num_threads=*/4); + coordination_service_ = new GrpcCoordinationServiceImpl( + coordination_compute_pool_.get(), &builder); profiler_service_ = tsl::profiler::CreateProfilerService(); builder.RegisterService(profiler_service_.get()); @@ -331,7 +333,7 @@ absl::Status GrpcServer::Init(const GrpcServerOptions& opts) { return WorkerCacheFactory(options, worker_cache); }, grpc_coordination_service->GetRpcHandler()); - worker_env_.compute_pool = compute_pool; + worker_env_.compute_pool = ComputePool(sess_opts); // Finish setting up master environment. master_env_.ops = OpRegistry::Global(); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index ca162c193d3b15..8aadbc3077732e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -38,6 +38,7 @@ limitations under the License. 
#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/platform/env.h" +#include "tsl/platform/threadpool.h" #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" namespace tensorflow { @@ -225,6 +226,7 @@ class GrpcServer : public ServerInterface { std::shared_ptr worker_session_; // Experimental coordination service implementation, and RPC polling thread. + std::unique_ptr coordination_compute_pool_ = nullptr; tsl::AsyncServiceInterface* coordination_service_ = nullptr; std::unique_ptr coordination_thread_ TF_GUARDED_BY(mu_); From 41ca755d6e3e50a8d2e3b1b3b35b442c1b025995 Mon Sep 17 00:00:00 2001 From: Robert David Date: Thu, 10 Apr 2025 13:00:11 -0700 Subject: [PATCH 0519/1324] Add `IWYU pragma: private` telling which is the public header to include. PiperOrigin-RevId: 746141508 --- tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h index 1327162f23262b..c580bf03cd3f59 100644 --- a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h +++ b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + /// WARNING: Users of TensorFlow Lite should not include this file directly, -/// but should instead include -/// "third_party/tensorflow/lite/c/builtin_op_data.h". -/// Only the TensorFlow Lite implementation itself should include this -/// file directly. +/// only the TensorFlow Lite implementation itself should. + +// IWYU pragma: private, include "third_party/tensorflow/lite/c/builtin_op_data.h" + #ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ From 3949bd4ee63dd4a1c2f608a95077cdb0461f91f1 Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Thu, 10 Apr 2025 13:49:42 -0700 Subject: [PATCH 0520/1324] Post submission readability review of cl/742875944. 
PiperOrigin-RevId: 746160404 --- third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc index 039ee0cde44649..0af1108e014f98 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc @@ -1363,7 +1363,8 @@ TEST(TfrtGpuClientTest, AsyncCopyToDevice) { } TEST(TfrtGpuClientTest, OnDoneSafelyDestructTransferManagerAsync) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetTfrtGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 1); PjRtDevice* const device = client->addressable_devices()[0]; From 3fc00d39fa2c8eb749925dbebb9a17e0852601bf Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Thu, 10 Apr 2025 13:51:48 -0700 Subject: [PATCH 0521/1324] Add config for unsafe dynamic broadcast fusion PiperOrigin-RevId: 746161270 --- tensorflow/compiler/mlir/lite/BUILD | 16 +++++++- .../mlir/lite/common/tfl_pass_config.h | 8 ++++ .../compiler/mlir/lite/converter_flags.proto | 8 +++- .../python/saved_model_to_tfl_flatbuffer.cc | 2 + .../compiler/mlir/lite/tf_tfl_passes.cc | 19 +++++++-- .../converter_pass_options_setter.cc | 7 ++++ .../converter_pass_options_setter.h | 2 + .../optimize_broadcast_like_pass.cc | 34 ++++++++++----- .../transforms/optimize_broadcast_like_pass.h | 10 +++-- .../optimize_broadcast_like_pass_options.h | 41 +++++++++++++++++++ .../lite/transforms/pass_options_setter.h | 2 + .../compiler/mlir/lite/transforms/passes.h | 3 +- tensorflow/lite/python/convert.py | 9 ++++ tensorflow/lite/python/lite.py | 4 ++ 14 files changed, 145 insertions(+), 20 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 5ee2ccbb024788..fbd280f063032e 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -377,6 +377,7 @@ cc_library( deps = [ ":common", ":converter_flags_proto_cc", + ":optimize_broadcast_like_pass_options", ":optimize_pass_options", ":pass_options", ":pass_options_setter", @@ -512,6 +513,7 @@ cc_library( ":converter_inc", ":cost_estimators", ":optimize_broadcast_like_pass", + ":optimize_broadcast_like_pass_options", ":optimize_pass_options", ":pass", ":pass_options", @@ -1013,6 +1015,7 @@ cc_library( ":fake_quant_utils", ":lstm_utils", ":nms_utils", + ":optimize_broadcast_like_pass_options", ":perception_ops_utils", ":shape_and_size_utils", ":stateful_ops_utils", @@ -1188,8 +1191,8 @@ cc_library( ], deps = [ ":optimize_broadcast_like_inc_gen", + ":optimize_broadcast_like_pass_options", ":pass", - ":pass_options", ":tensorflow_lite_ops", ":utils", "@llvm-project//llvm:Support", @@ -1197,6 +1200,7 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", ], @@ -1270,6 +1274,7 @@ cc_library( deps = [ "convert_type", ":op_quant_spec_getters_inc", + ":optimize_broadcast_like_pass_options", ":shape_and_size_utils", ":stateful_ops_utils", ":tensorflow_lite", @@ -1665,6 +1670,15 @@ cc_library( ], ) +cc_library( + name = "optimize_broadcast_like_pass_options", + hdrs = 
["transforms/optimize_broadcast_like_pass_options.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Pass", + ], +) + cc_library( name = "flatbuffer_translate_lib", hdrs = [ diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index db9715e99c1acd..aa552ec43d138a 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -107,6 +107,12 @@ struct PassConfig { // When set to true, convert +Inf/-Inf to MIN/MAX float value and output of // convert only contains finite values. bool canonicalizing_inf_as_min_max_float = true; + + // When set to true, allows fusion of dynamic shaped broadcast ops. It helps + // fusing implicit broadcasting ops when output shape has dynamic dimensions, + // but it may cause incorrect results when broadcasting ops are introduced by + // explicit broadcasting in the source model. + bool unsafe_fuse_dynamic_shaped_broadcast = false; }; inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, @@ -133,6 +139,8 @@ inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, << pass_config.enable_stablehlo_conversion << "\nlegalize_custom_tensor_list_ops: " << pass_config.legalize_custom_tensor_list_ops + << "\nunsafe_fuse_dynamic_shaped_broadcast: " + << pass_config.unsafe_fuse_dynamic_shaped_broadcast << "\nreduce_type_precision: " << pass_config.reduce_type_precision << "\nconvert_qdq_format: " << GetQDQQuantModeString(pass_config.qdq_conversion_mode) diff --git a/tensorflow/compiler/mlir/lite/converter_flags.proto b/tensorflow/compiler/mlir/lite/converter_flags.proto index 5b6b9e2ca752a6..1c1a1ad00aea74 100644 --- a/tensorflow/compiler/mlir/lite/converter_flags.proto +++ b/tensorflow/compiler/mlir/lite/converter_flags.proto @@ -41,7 +41,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 68. +// Next ID to use: 69. message ConverterFlags { // Input file format optional FileFormat input_format = 1; @@ -385,4 +385,10 @@ message ConverterFlags { // possible rather than quantizing any op that is possible to quantize. // WARNING: Experimental interface, subject to change. optional bool strict_qdq_mode = 67 [default = false]; + + // When set to true, allows fusion of dynamic shaped broadcast ops. It helps + // fusing implicit broadcasting ops when output shape has dynamic dimensions, + // but it may cause incorrect results when broadcasting ops are introduced by + // explicit broadcasting in the source model. 
+ optional bool unsafe_fuse_dynamic_shaped_broadcast = 68 [default = false]; } diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 4dcf1497476f77..6cf9cd3cc9711d 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -217,6 +217,8 @@ absl::Status ConvertSavedModelToTFLiteFlatBuffer( pass_config.model_origin_framework = converter_flags.model_origin_framework(); pass_config.canonicalizing_inf_as_min_max_float = converter_flags.canonicalizing_inf_as_min_max_float(); + pass_config.unsafe_fuse_dynamic_shaped_broadcast = + converter_flags.unsafe_fuse_dynamic_shaped_broadcast(); if (converter_flags.strict_qdq_mode()) { pass_config.quant_specs.qdq_conversion_mode = diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 25789ab44d17be..6cfd3cb992028a 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -70,8 +70,14 @@ void AddOptimizationPasses(const tflite::ConverterFlags& converter_flags, pass_manager->addPass(mlir::TFL::CreatePushTransposeThroughEwisePass()); - pass_manager->addNestedPass( - mlir::TFL::Create()); + // Add BroadcastLike optimization pass. + { + mlir::TFL::OptimizeBroadcastLikePassOptions options; + options.unsafe_fuse_dynamic_shaped_broadcast = + pass_config.unsafe_fuse_dynamic_shaped_broadcast; + pass_manager->addNestedPass( + mlir::TFL::Create(options)); + } // Add TFLite optimize pass. mlir::TFL::OptimizePassOptions optimize_pass_options; @@ -355,8 +361,13 @@ void AddPostQuantizationStableHloToTfPasses( // broadcasting support. This needs to be run immediately after HLO->TFL // legalization, otherwise the newly generated TFL broadcast ops can fold // and materialize the weights. - pass_manager.addNestedPass( - mlir::TFL::Create()); + { + mlir::TFL::OptimizeBroadcastLikePassOptions options; + options.unsafe_fuse_dynamic_shaped_broadcast = + pass_config.unsafe_fuse_dynamic_shaped_broadcast; + pass_manager.addNestedPass( + mlir::TFL::Create(options)); + } } // folds tf.BroadcastTo ops with subsequent ops if they have built in // broadcasting support. This needs to be run immediately after HLO->TF diff --git a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc index f0fb9361980f67..5a3f23fe6df382 100644 --- a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc +++ b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline_options.h" @@ -33,6 +34,12 @@ void ConverterPassOptionsSetter::SetOptions( options.enable_tflite_variables = pass_config_.enable_tflite_variables; } +void ConverterPassOptionsSetter::SetOptions( + OptimizeBroadcastLikePassOptions& options) const { + // options.unsafe_fuse_dynamic_shaped_broadcast = + // converter_flags_.unsafe_fuse_dynamic_shaped_broadcast(); +} + void ConverterPassOptionsSetter::SetOptions(EmptyPassOptions& options) const {} } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h index 01f71afe84ca3f..59151448b92f0a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h +++ b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h @@ -26,6 +26,7 @@ namespace TFL { class OptimizePassOptions; class VariableFreezingPipelineOptions; class EmptyPassOptions; +class OptimizeBroadcastLikePassOptions; // PassOptionsSetter to set TFLite Converter Pass/Pipeline Options based on // ConverterFlags and TFL::PassConfig values. @@ -40,6 +41,7 @@ class ConverterPassOptionsSetter : public PassOptionsSetter { void SetOptions(OptimizePassOptions& options) const override; void SetOptions(VariableFreezingPipelineOptions& options) const override; void SetOptions(EmptyPassOptions& options) const override; + void SetOptions(OptimizeBroadcastLikePassOptions& options) const override; private: tflite::ConverterFlags converter_flags_; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc index e85cfef6dd0d87..aed2946db17ba3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc @@ -42,6 +42,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" namespace mlir { @@ -55,8 +56,10 @@ using BroadcastedShapeFunction = class ConvertResultsBroadcastableShapeOp : public RewritePattern { public: - explicit ConvertResultsBroadcastableShapeOp(MLIRContext* context) - : RewritePattern(MatchAnyOpTypeTag(), /*PatternBenefit*/ 1, context) {} + explicit ConvertResultsBroadcastableShapeOp( + MLIRContext* context, const OptimizeBroadcastLikePassOptions& options) + : RewritePattern(MatchAnyOpTypeTag(), /*PatternBenefit*/ 1, context), + options_(options) {} LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override; @@ -65,6 +68,9 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern { LogicalResult RewriteOp( Operation* op, PatternRewriter& rewriter, BroadcastedShapeFunction& get_broadcasted_shape) const; + + private: + const OptimizeBroadcastLikePassOptions& options_; }; // Some tfl ops only support implicit broadcasting up to a certain rank. @@ -191,7 +197,8 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the result shape is fully defined. auto result_type = llvm::cast(op->getResultTypes().front()); - if (!result_type || !result_type.hasStaticShape()) + if (!result_type || (!options_.unsafe_fuse_dynamic_shaped_broadcast && + !result_type.hasStaticShape())) return rewriter.notifyMatchFailure( op, "Unsupported result shape for broadcasting on op: " + op->getName().getStringRef()); @@ -224,7 +231,10 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the operand of the broadcast has fully defined shape. auto broadcast_arg_type = llvm::cast(broadcast_like_op_input.getType()); - if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; + if (!broadcast_arg_type || + (!options_.unsafe_fuse_dynamic_shaped_broadcast && + !broadcast_arg_type.hasStaticShape())) + continue; auto other_arg = op->getOpOperand(1 - i).get(); // If non-splat operand is not fusable affine ops, then no need to apply @@ -238,7 +248,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the other argument has fully defined shape. auto other_arg_type = llvm::cast(other_arg.getType()); - if (!other_arg_type || !other_arg_type.hasStaticShape()) continue; + if (!other_arg_type || (!options_.unsafe_fuse_dynamic_shaped_broadcast && + !other_arg_type.hasStaticShape())) + continue; // Get the unbroadcasted shapes in the operand order. 
std::array, 2> operand_shapes; @@ -268,8 +280,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( class ConvertResultsBroadcastableBatchMatMulShapeOp : public ConvertResultsBroadcastableShapeOp { public: - explicit ConvertResultsBroadcastableBatchMatMulShapeOp(MLIRContext* context) - : ConvertResultsBroadcastableShapeOp(context) {} + explicit ConvertResultsBroadcastableBatchMatMulShapeOp( + MLIRContext* context, const OptimizeBroadcastLikePassOptions& options) + : ConvertResultsBroadcastableShapeOp(context, options) {} LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override; @@ -384,9 +397,10 @@ void OptimizeBroadcastLikePass::runOnOperation() { RewritePatternSet patterns(&getContext()); auto func = getOperation(); - patterns.add(func.getContext()); - patterns.add( - func.getContext()); + patterns.add(func.getContext(), + GetOptions()); + patterns.add(func.getContext(), + GetOptions()); patterns.add(func.getContext()); TFL::populateWithGenerated(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h index f13048a1982641..0b5f8f1f6bc2b1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h @@ -16,24 +16,28 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/pass.h" -#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" namespace mlir { namespace TFL { // Pass to optimize explicit broadcasting-like patterns. class OptimizeBroadcastLikePass - : public TFL::Pass { + : public TFL::Pass { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizeBroadcastLikePass) OptimizeBroadcastLikePass() = default; OptimizeBroadcastLikePass(const OptimizeBroadcastLikePass&) {}; + explicit OptimizeBroadcastLikePass(const mlir::detail::PassOptions& options) + : Pass(options) {} void runOnOperation() override; static llvm::StringRef GetName() { return "OptimizeBroadcastLikePass"; } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h new file mode 100644 index 00000000000000..7d11f5d74cc4c5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h @@ -0,0 +1,41 @@ + +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_OPTIONS_H_ + +#include "llvm/Support/CommandLine.h" +#include "mlir/Pass/PassOptions.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +//////////////////////////////////////////////////////////////////////////////// +// Pass Options +//////////////////////////////////////////////////////////////////////////////// + +struct OptimizeBroadcastLikePassOptions : public mlir::detail::PassOptions { + mlir::detail::PassOptions::Option unsafe_fuse_dynamic_shaped_broadcast{ + *this, "unsafe-fuse-dynamic-shaped-broadcast", + llvm::cl::desc( + "Enable fusion of dynamic shaped broadcast ops. It helps fusing " + "implicit broadcasting ops when output shape has dynamic dimensions, " + "but it may cause incorrect results when broadcasting ops are " + "introduced by explicit broadcasting in the source model."), + llvm::cl::init(false)}; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_OPTIONS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h b/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h index 534b1402dd4cd3..29906014fce292 100644 --- a/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h +++ b/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h @@ -22,6 +22,7 @@ namespace TFL { class OptimizePassOptions; class VariableFreezingPipelineOptions; class EmptyPassOptions; +class OptimizeBroadcastLikePassOptions; // Interface for setting options for TFLite Converter Pass/Pipeline Options. class PassOptionsSetter { @@ -30,6 +31,7 @@ class PassOptionsSetter { virtual void SetOptions(OptimizePassOptions& options) const = 0; virtual void SetOptions(VariableFreezingPipelineOptions& options) const = 0; virtual void SetOptions(EmptyPassOptions& options) const = 0; + virtual void SetOptions(OptimizeBroadcastLikePassOptions& options) const = 0; }; } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 4d8ecccaa5f3f7..8b808568c26fcd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise_pass.h" @@ -340,7 +341,7 @@ inline void registerTensorFlowLitePasses() { Register(); Register(); Register(); - Register(); + Register(); Register(); Register(); diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 2519835376fe4b..7bd4c5e9411271 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -453,6 +453,7 @@ def build_conversion_flags( model_origin_framework=lite_constants.UNSET, canonicalizing_inf_as_min_max_float=True, serialize_debug_metadata=False, + unsafe_fuse_dynamic_shaped_broadcast=False, **_, ): """Builds protocol buffer describing a conversion of a model. @@ -593,6 +594,11 @@ def build_conversion_flags( MIN/MAX float value and output of converter only contains finite values. serialize_debug_metadata: When set to true, serialize debug metadata in the flatbuffer. + unsafe_fuse_dynamic_shaped_broadcast: When set to true, allows fusion of + dynamic shaped broadcast ops. It helps fusing implicit broadcasting ops + when output shape has dynamic dimensions, but it may cause incorrect + results when broadcasting ops are introduced by explicit broadcasting in + the source model. Returns: conversion_flags: protocol buffer describing the conversion process. @@ -727,6 +733,9 @@ def build_conversion_flags( ) conversion_flags.serialize_debug_metadata = serialize_debug_metadata + conversion_flags.unsafe_fuse_dynamic_shaped_broadcast = ( + unsafe_fuse_dynamic_shaped_broadcast + ) return conversion_flags diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 3b62c3a62b78fd..55afc80c329795 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -685,6 +685,7 @@ def __init__(self): self.model_origin_framework = constants.UNSET self.canonicalizing_inf_as_min_max_float = True self._experimental_strict_qdq = False + self._experimental_unsafe_fuse_dynamic_shaped_broadcast = False # Debug parameters self.ir_dump_dir = None @@ -854,6 +855,9 @@ def _get_base_converter_args(self): self.canonicalizing_inf_as_min_max_float ), "serialize_debug_metadata": self.serialize_debug_metadata, + "unsafe_fuse_dynamic_shaped_broadcast": ( + self._experimental_unsafe_fuse_dynamic_shaped_broadcast + ), } if self.saved_model_dir: From a18e429fffb3923785a75a841114467ae57ed3d1 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Thu, 10 Apr 2025 14:05:45 -0700 Subject: [PATCH 0522/1324] Add a Test to avoid mixing a dialect with bounded types from an another dialect during direct StableHLO to HLO lowering. 
PiperOrigin-RevId: 746166864 --- .../mhlo/stablehlo-legalize-to-hlo-partial.mlir | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo-partial.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo-partial.mlir index 5008e99d64deb9..bc9b597b5a801a 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo-partial.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo-partial.mlir @@ -8,3 +8,18 @@ func.func @op_constant(%arg0: tensor) -> tensor { %cst = stablehlo.constant dense<0.000000e+00> : tensor return %cst : tensor } + +// ----- + +// CHECK-LABEL: bounded_dynamic_gather +func.func @bounded_dynamic_gather(%arg0: tensor<16x50xf32>, %arg1: tensor<1x?xi64, #stablehlo.bounds>) -> tensor> { + // CHECK: mhlo.reshape + // CHECK-NOT: #stablehlo.bounds + // CHECK-SAME: #mhlo.type_extensions>) -> tensor> + // CHECK: mhlo.gather + // CHECK-NOT: #stablehlo.bounds + // CHECK-SAME: #mhlo.type_extensions, indices_are_sorted = false, slice_sizes = array}> : (tensor<16x50xf32>, tensor>) -> tensor> + return %1 : tensor> +} From 0377f30a042a15791b9c329520e7610923663890 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Thu, 10 Apr 2025 14:35:11 -0700 Subject: [PATCH 0523/1324] Fix HloRunnerAgnosticTestBase argument const-ness. For some reason the `Literal*` values consumed by our `Execute` functions were not marked const, so all of the infrastructure that derives from this was also not marked const. We don't actually need to be able to mutate these literals so it is best to make them const. PiperOrigin-RevId: 746177763 --- third_party/xla/xla/tests/BUILD | 1 + .../tests/client_library_test_runner_mixin.h | 36 ++++++++++--------- .../hlo_runner_agnostic_reference_mixin.h | 9 +++-- .../tests/hlo_runner_agnostic_test_base.cc | 27 +++++++------- .../xla/tests/hlo_runner_agnostic_test_base.h | 24 +++++++------ 5 files changed, 54 insertions(+), 43 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 840e28a3d61583..99624926729622 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -270,6 +270,7 @@ cc_library( "//xla/tsl/platform:test", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], diff --git a/third_party/xla/xla/tests/client_library_test_runner_mixin.h b/third_party/xla/xla/tests/client_library_test_runner_mixin.h index c44536ae19cacb..7184e88958dfa1 100644 --- a/third_party/xla/xla/tests/client_library_test_runner_mixin.h +++ b/third_party/xla/xla/tests/client_library_test_runner_mixin.h @@ -94,7 +94,7 @@ class ClientLibraryTestRunnerMixin : public T { absl::StatusOr ExecuteAndTransfer( const XlaComputation& computation, - const absl::Span arguments, + const absl::Span arguments, const Shape* const shape_with_output_layout = nullptr) { ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { @@ -108,7 +108,8 @@ class ClientLibraryTestRunnerMixin : public T { } absl::StatusOr ExecuteAndTransfer( - XlaBuilder* const builder, const absl::Span arguments, + XlaBuilder* const builder, + const absl::Span arguments, const Shape* shape_with_output_layout = nullptr) { // Build the computation, as a convenience. 
TF_ASSIGN_OR_RETURN(XlaComputation computation, builder->Build()); @@ -118,8 +119,9 @@ class ClientLibraryTestRunnerMixin : public T { // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. - std::string ExecuteToString(XlaBuilder* const builder, - const absl::Span arguments) { + std::string ExecuteToString( + XlaBuilder* const builder, + const absl::Span arguments) { const absl::StatusOr result = ExecuteAndTransfer(builder, arguments); if (!result.ok()) { @@ -132,7 +134,7 @@ class ClientLibraryTestRunnerMixin : public T { // Compare with reference. // Side effect: EXPECT_OK void ComputeAndCompare(XlaBuilder* const builder, - const absl::Span arguments, + const absl::Span arguments, const std::optional error = std::nullopt) { TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder->Build()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -142,10 +144,10 @@ class ClientLibraryTestRunnerMixin : public T { // Compare with literal. // Side effect: EXPECT_OK - void ComputeAndCompareLiteral(XlaBuilder* const builder, - const Literal& expected, - const absl::Span arguments, - const Shape* shape_with_layout) { + void ComputeAndCompareLiteral( + XlaBuilder* const builder, const Literal& expected, + const absl::Span arguments, + const Shape* shape_with_layout) { return ComputeAndCompareLiteral(builder, expected, arguments, std::nullopt, shape_with_layout); } @@ -154,7 +156,7 @@ class ClientLibraryTestRunnerMixin : public T { // Side effect: EXPECT_OK void ComputeAndCompareLiteral( XlaBuilder* const builder, const Literal& expected, - const absl::Span arguments, + const absl::Span arguments, const std::optional error = std::nullopt, const Shape* shape_with_layout = nullptr) { if (error == std::nullopt) { @@ -195,14 +197,14 @@ class ClientLibraryTestRunnerMixin : public T { // Compare with literal. 
// Side effect: EXPECT_OK void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected, - absl::Span arguments, + absl::Span arguments, std::optional error = std::nullopt) { return ComputeAndCompareLiteral(builder, expected, arguments, error); } template void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, - absl::Span arguments, + absl::Span arguments, std::optional error = std::nullopt) { CheckErrorSpec(error); Literal expected_literal = LiteralUtil::CreateR0(expected); @@ -212,7 +214,7 @@ class ClientLibraryTestRunnerMixin : public T { template void ComputeAndCompareR1(XlaBuilder* builder, absl::Span expected, - absl::Span arguments, + absl::Span arguments, std::optional error = std::nullopt) { CheckErrorSpec(error); Literal expected_literal = LiteralUtil::CreateR1(expected); @@ -221,7 +223,7 @@ class ClientLibraryTestRunnerMixin : public T { void ComputeAndCompareR1(XlaBuilder* builder, const tsl::core::Bitmap& expected, - absl::Span arguments, + absl::Span arguments, std::optional error = std::nullopt) { Literal expected_literal = LiteralUtil::CreateR1(expected); ComputeAndCompareLiteral(builder, expected_literal, arguments, error); @@ -230,7 +232,7 @@ class ClientLibraryTestRunnerMixin : public T { template void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, - absl::Span arguments, + absl::Span arguments, std::optional error = std::nullopt) { CheckErrorSpec(error); Literal expected_literal = @@ -241,7 +243,7 @@ class ClientLibraryTestRunnerMixin : public T { template void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, - absl::Span arguments, + absl::Span arguments, std::optional error = std::nullopt) { CheckErrorSpec(error); Literal expected_literal = @@ -252,7 +254,7 @@ class ClientLibraryTestRunnerMixin : public T { template void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, - absl::Span arguments, + absl::Span arguments, std::optional error = std::nullopt) { CheckErrorSpec(error); Literal expected_literal = diff --git a/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.h b/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.h index b5b632e03247b9..d92ed2abfae0b5 100644 --- a/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.h +++ b/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/base/nullability.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/error_spec.h" @@ -86,7 +87,8 @@ class HloRunnerAgnosticReferenceMixin : public T { // reference backend. Note that the program shape of the module must not be // modified. ::testing::AssertionResult RunAndCompare( - std::unique_ptr module, absl::Span arguments, + std::unique_ptr module, + absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr) { @@ -104,7 +106,7 @@ class HloRunnerAgnosticReferenceMixin : public T { // optimization. ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - const absl::Span arguments, + const absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr) { @@ -223,7 +225,8 @@ class HloRunnerAgnosticReferenceMixin : public T { // compares the results. Returns whether the results are near or equal. 
If any // error happens before the results are computed, returns the error status. absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal( - std::unique_ptr module, absl::Span arguments, + std::unique_ptr module, + absl::Span arguments, const std::optional& error, bool run_hlo_passes, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr) { diff --git a/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.cc b/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.cc index c0739ced9eeed6..b988b5b13075e6 100644 --- a/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.cc +++ b/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.cc @@ -101,13 +101,14 @@ void HloRunnerAgnosticTestBase::UpdateEntryComputationLayout( } absl::StatusOr HloRunnerAgnosticTestBase::Execute( - std::unique_ptr module, absl::Span arguments, - bool run_hlo_passes) { + std::unique_ptr module, + absl::Span arguments, bool run_hlo_passes) { return test_runner_->Execute(std::move(module), arguments, run_hlo_passes); } Literal HloRunnerAgnosticTestBase::ExecuteNoHloPasses( - std::unique_ptr module, absl::Span arguments) { + std::unique_ptr module, + absl::Span arguments) { absl::StatusOr result = Execute(std::move(module), arguments, /*run_hlo_passes=*/false); CHECK_OK(result.status()); @@ -115,7 +116,8 @@ Literal HloRunnerAgnosticTestBase::ExecuteNoHloPasses( } Literal HloRunnerAgnosticTestBase::ExecuteAndTransfer( - std::unique_ptr module, absl::Span arguments) { + std::unique_ptr module, + absl::Span arguments) { absl::StatusOr result = test_runner_->Execute(std::move(module), arguments, true, nullptr); CHECK_OK(result.status()); @@ -125,8 +127,9 @@ Literal HloRunnerAgnosticTestBase::ExecuteAndTransfer( absl::StatusOr> HloRunnerAgnosticTestBase::ExecuteReplicated( std::unique_ptr module, - const absl::Span arguments, const int64_t num_replicas, - const bool use_threads, const bool run_hlo_passes) { + const absl::Span arguments, + const int64_t num_replicas, const bool use_threads, + const bool run_hlo_passes) { HloRunnerInterface::ReplicatedExecuteOptions options; options.num_replicas = num_replicas; options.arguments = {arguments.begin(), arguments.end()}; @@ -138,9 +141,9 @@ HloRunnerAgnosticTestBase::ExecuteReplicated( absl::StatusOr> HloRunnerAgnosticTestBase::ExecuteReplicated( std::unique_ptr module, - const absl::Span arguments, const int64_t num_replicas, - DeviceAssignment* const device_assignment, const bool run_hlo_passes, - const bool use_threads) { + const absl::Span arguments, + const int64_t num_replicas, DeviceAssignment* const device_assignment, + const bool run_hlo_passes, const bool use_threads) { HloRunnerInterface::ReplicatedExecuteOptions options; options.num_replicas = num_replicas; options.arguments = {arguments.begin(), arguments.end()}; @@ -324,7 +327,7 @@ HloRunnerAgnosticTestBase::RunAndCompareTwoModulesReplicated( ::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompareTwoModules( std::unique_ptr module_0, std::unique_ptr module_1, - const absl::Span arguments, + const absl::Span arguments, const std::optional& error, bool run_hlo_passes) { const absl::StatusOr<::testing::AssertionResult> result = RunAndCompareTwoModulesInternal(std::move(module_0), std::move(module_1), @@ -414,7 +417,7 @@ ::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompareTwoModules( ::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompareTwoModules( absl::string_view hlo_string_module_0, absl::string_view 
hlo_string_module_1, - const absl::Span arguments, + const absl::Span arguments, const std::optional& error, const bool run_hlo_passes) { auto module_0_or_status = ParseAndReturnVerifiedModule(hlo_string_module_0); if (!module_0_or_status.ok()) { @@ -626,7 +629,7 @@ HloRunnerAgnosticTestBase::RunAndCompareTwoModulesInternalReplicated( absl::StatusOr<::testing::AssertionResult> HloRunnerAgnosticTestBase::RunAndCompareTwoModulesInternal( std::unique_ptr module_0, std::unique_ptr module_1, - const absl::Span arguments, + const absl::Span arguments, const std::optional& error, bool run_hlo_passes) { TF_RETURN_IF_ERROR(verifier().Run(module_0.get()).status()); TF_RETURN_IF_ERROR(verifier().Run(module_1.get()).status()); diff --git a/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.h b/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.h index 216e7a94c3f360..7a88b9ccaa21a8 100644 --- a/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.h +++ b/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.h @@ -118,16 +118,16 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Executes the given module and return the result as a Literal. absl::StatusOr Execute(std::unique_ptr module, - absl::Span arguments, + absl::Span arguments, bool run_hlo_passes = true); // Same as above, except the module will be executed without running any HLO // passes on it. Literal ExecuteNoHloPasses(std::unique_ptr module, - absl::Span arguments); + absl::Span arguments); Literal ExecuteAndTransfer(std::unique_ptr module, - absl::Span arguments); + absl::Span arguments); // Compile the given module to an executable. absl::StatusOr> CreateExecutable( @@ -141,14 +141,16 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // with a thread-per-replica, vs using an implicitly async call such as // Executable::ExecuteOnStreams. absl::StatusOr> ExecuteReplicated( - std::unique_ptr module, absl::Span arguments, - int64_t num_replicas, bool use_threads, bool run_hlo_passes = false); + std::unique_ptr module, + absl::Span arguments, int64_t num_replicas, + bool use_threads, bool run_hlo_passes = false); // Same as above, but uses specified device assignment. absl::StatusOr> ExecuteReplicated( - std::unique_ptr module, absl::Span arguments, - int64_t num_replicas, DeviceAssignment* device_assignment, - bool run_hlo_passes, bool use_threads); + std::unique_ptr module, + absl::Span arguments, int64_t num_replicas, + DeviceAssignment* device_assignment, bool run_hlo_passes, + bool use_threads); // Same as above, but allows passing different programs for replicas. absl::StatusOr> ExecuteReplicated( @@ -208,7 +210,7 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Same as below, except requires passing fake arguments. ::testing::AssertionResult RunAndCompareTwoModules( std::unique_ptr module_0, std::unique_ptr module_1, - absl::Span arguments, + absl::Span arguments, const std::optional& error, bool run_hlo_passes = true); // Same as below, except requires passing the modules. @@ -238,7 +240,7 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { ::testing::AssertionResult RunAndCompareTwoModules( absl::string_view hlo_string_module_0, absl::string_view hlo_string_module_1, - absl::Span arguments, + absl::Span arguments, const std::optional& error, bool run_hlo_passes = true); // Executes an hlo module with fake inputs on multiple replicas. 
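Both comparison mixins in this patch, ClientLibraryTestRunnerMixin above and the HloRunnerAgnostic classes here, take their runtime inputs as a span of literal pointers (the span's element type has been elided in this copy of the diff). For orientation, a hedged usage sketch of the two-module entry point, written for a test body: kHloTextA, kHloTextB, and the error bound are made up, and the element type is assumed to be const Literal* const.

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_0,
                          ParseAndReturnVerifiedModule(kHloTextA));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module_1,
                          ParseAndReturnVerifiedModule(kHloTextB));
  Literal arg = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
  const Literal* arguments[] = {&arg};
  // The returned AssertionResult carries the mismatch message on failure.
  EXPECT_TRUE(RunAndCompareTwoModules(std::move(module_0), std::move(module_1),
                                      arguments, ErrorSpec{1e-5},
                                      /*run_hlo_passes=*/true));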
@@ -283,7 +285,7 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // error happens before the results are computed, returns the error status. absl::StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal( std::unique_ptr module_0, std::unique_ptr module_1, - absl::Span arguments, + absl::Span arguments, const std::optional& error, bool run_hlo_passes); std::unique_ptr test_runner_; From baf7092fa4b5ca53c4ddf6b082306ce3fc53e631 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 10 Apr 2025 15:08:11 -0700 Subject: [PATCH 0524/1324] Allow ctrl-c to cancel block_until_ready(). Partially addresses: https://github.com/jax-ml/jax/issues/18246. If compile can also be a future, this code can be used to safely block on that as well. PiperOrigin-RevId: 746189742 --- third_party/xla/xla/pjrt/pjrt_future.h | 45 ++++++++++++++++++-------- third_party/xla/xla/python/version.h | 2 +- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/pjrt/pjrt_future.h b/third_party/xla/xla/pjrt/pjrt_future.h index ac76be51859a1d..2e01a08b9a2770 100644 --- a/third_party/xla/xla/pjrt/pjrt_future.h +++ b/third_party/xla/xla/pjrt/pjrt_future.h @@ -253,22 +253,38 @@ class PjRtFutureBase : public PjRtFutureMoveControl { #endif }; - PjRtFutureHelpers::ProfilingKeys OnBlockStart() const { - return on_block_start_ ? on_block_start_() - : PjRtFutureHelpers::ProfilingKeys(); - } + class ProfilingCleanup { + public: + ProfilingCleanup(const PjRtFutureBase* parent, + PjRtFutureHelpers::ProfilingKeys keys) + : parent_(parent), keys_(std::move(keys)) {} + ~ProfilingCleanup() { + if (parent_ && parent_->on_block_end_) + parent_->on_block_end_(std::move(keys_)); + } + ProfilingCleanup(const ProfilingCleanup& other) = delete; + ProfilingCleanup(ProfilingCleanup&& other) = delete; + + private: + const PjRtFutureBase* parent_; + PjRtFutureHelpers::ProfilingKeys keys_; + }; - void OnBlockEnd(PjRtFutureHelpers::ProfilingKeys keys) const { - if (on_block_end_) on_block_end_(std::move(keys)); + ProfilingCleanup OnBlockStartScope() const { + return ProfilingCleanup(this, on_block_start_ + ? on_block_start_() + : PjRtFutureHelpers::ProfilingKeys()); } - // Blocks the calling thread until the future is ready. - void BlockUntilReady() const { + // Calls block_until_ready_fn to wait until the underlying AsyncValue is + // concrete. block_until_ready_fn should be equivalent to + // tsl::BlockUntilReady. + template + void BlockUntilReady(Fn&& block_until_ready_fn) const { CHECK(IsValid()); if (!promise_.IsAvailable()) { - PjRtFutureHelpers::ProfilingKeys keys = OnBlockStart(); - tsl::BlockUntilReady(promise_); - OnBlockEnd(std::move(keys)); + ProfilingCleanup scope = OnBlockStartScope(); + block_until_ready_fn(promise_.GetAsyncValue()); } DCHECK(promise_.IsConcrete()); } @@ -276,14 +292,16 @@ class PjRtFutureBase : public PjRtFutureMoveControl { // Blocks the calling thread until the future is ready, then returns the // final value. const T& Await() const& { - BlockUntilReady(); + BlockUntilReady( + static_cast(tsl::BlockUntilReady)); return *promise_; } // Blocks the calling thread until the future is ready, then returns the // final value. 
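Two things happen in the pjrt_future.h hunks above. First, the OnBlockStart/OnBlockEnd pair becomes an RAII ProfilingCleanup, so on_block_end_ runs even if the wait is abandoned early. Second, BlockUntilReady now takes the wait strategy as a parameter; the default Await() paths pass tsl::BlockUntilReady itself through a static_cast that selects the overload (its target type is elided here, presumably a function-pointer type such as void (*)(tsl::AsyncValue*)). Together these are what let a caller substitute an interruptible wait. A hedged sketch of one such wait function, not the actual JAX integration; CheckForInterrupt is hypothetical and assumed to throw after ctrl-c, and include paths are approximate:

  #include <memory>

  #include "absl/synchronization/notification.h"
  #include "xla/tsl/concurrency/async_value.h"

  void CheckForInterrupt();  // Hypothetical; throws if SIGINT was seen.

  // Honors the documented contract: returns only once *av is concrete, but
  // wakes periodically so an interrupt can abandon the wait by throwing.
  void InterruptibleBlockUntilReady(tsl::AsyncValue* av) {
    auto ready = std::make_shared<absl::Notification>();
    av->AndThen([ready] { ready->Notify(); });  // shared_ptr survives a throw
    while (!ready->WaitForNotificationWithTimeout(absl::Milliseconds(50))) {
      CheckForInterrupt();
    }
  }

A caller would then write future.BlockUntilReady(InterruptibleBlockUntilReady), which is exactly the pattern the new `using Base::BlockUntilReady;` export below opens up.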
std::conditional_t Await() && { - BlockUntilReady(); + BlockUntilReady( + static_cast(tsl::BlockUntilReady)); if constexpr (unique) { return std::move(*promise_); @@ -483,6 +501,7 @@ class PjRtFuture : public internal::PjRtFutureBase { std::move(on_block_end)) {} using Base::Await; + using Base::BlockUntilReady; using Base::OnReady; }; diff --git a/third_party/xla/xla/python/version.h b/third_party/xla/xla/python/version.h index ea6e7b78c5d510..11c9800037cd6a 100644 --- a/third_party/xla/xla/python/version.h +++ b/third_party/xla/xla/python/version.h @@ -18,6 +18,6 @@ limitations under the License. // An increasing version number to protect jax code against breaking changes. // In JAX, reference this via jax._src.lib.ifrt_version. -#define JAX_IFRT_VERSION_NUMBER 4 +#define JAX_IFRT_VERSION_NUMBER 5 #endif // XLA_PYTHON_VERSION_H_ From c9ff6b6245e87fc50d3313bfbd7220ddf30bb693 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 15:55:04 -0700 Subject: [PATCH 0525/1324] Add `--repo_env=USE_PYWRAP_RULES=True` to linux arm64 builds. PiperOrigin-RevId: 746206292 --- ci/official/envs/linux_arm64 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/official/envs/linux_arm64 b/ci/official/envs/linux_arm64 index 643be7a872e0df..2b6e38b0e42f04 100644 --- a/ci/official/envs/linux_arm64 +++ b/ci/official/envs/linux_arm64 @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config release_arm64_linux" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --repo_env=USE_PYWRAP_RULES=True --config release_arm64_linux" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 # Note: this is not set to "--cpu", because that changes the package name # to tensorflow_cpu. These ARM builds are supposed to have the name "tensorflow" From 90e4386f96b8bc8575b3bb8316cc6a8ee891f005 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Thu, 10 Apr 2025 16:16:18 -0700 Subject: [PATCH 0526/1324] Simplify the if-else structure for `SpmdPartitioningVisitor::Preprocess`. PiperOrigin-RevId: 746213663 --- .../xla/xla/service/spmd/spmd_partitioner.cc | 262 +++++++++--------- 1 file changed, 130 insertions(+), 132 deletions(-) diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index 40593280510c31..63d6d9a48e93a2 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -2568,6 +2568,18 @@ absl::Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) { absl::Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) { visiting_hlo_ = hlo; b_.set_visiting_hlo(hlo); + + if (hlo->opcode() == HloOpcode::kAllReduce || + hlo->opcode() == HloOpcode::kCall || + hlo->opcode() == HloOpcode::kConditional || + hlo->opcode() == HloOpcode::kInfeed || + hlo->opcode() == HloOpcode::kOutfeed || + hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kRng || hlo->opcode() == HloOpcode::kTuple || + hlo->opcode() == HloOpcode::kWhile) { + return absl::OkStatus(); + } + // Temporarily replace manual sharding to one-device sharding so that the // partitioner will not change the HLOs. 
auto manual_to_onedevice = [&](HloOpcode opcode, const Shape& shape, @@ -2592,143 +2604,129 @@ absl::Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) { return sharding; }; - if (hlo->opcode() != HloOpcode::kConditional && - hlo->opcode() != HloOpcode::kTuple && - hlo->opcode() != HloOpcode::kParameter && - hlo->opcode() != HloOpcode::kWhile && hlo->opcode() != HloOpcode::kRng && - hlo->opcode() != HloOpcode::kInfeed && - hlo->opcode() != HloOpcode::kOutfeed && - hlo->opcode() != HloOpcode::kAllReduce && - hlo->opcode() != HloOpcode::kCall) { - const bool has_manual_sharding = - hlo->sharding().IsManual() || - (hlo->sharding().IsTuple() && - absl::c_any_of( - hlo->sharding().tuple_elements(), - [](const HloSharding& sharding) { return sharding.IsManual(); })); - if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) { - visiting_hlo_sharding_ = hlo->sharding(); - auto get_sharding_shape = [](const HloInstruction* hlo) { - if (hlo->opcode() != HloOpcode::kOutfeed) { - return hlo->shape(); - } - std::vector operand_shapes(hlo->operand_count()); - for (int i = 0; i < hlo->operand_count(); ++i) { - operand_shapes[i] = hlo->operand(i)->shape(); - } - return ShapeUtil::MakeTupleShape(operand_shapes); - }; - hlo->set_sharding(manual_to_onedevice( - hlo->opcode(), get_sharding_shape(hlo), *visiting_hlo_sharding_)); - - visiting_hlo_operand_shardings_.reserve(hlo->operand_count()); - for (HloInstruction* operand : hlo->unique_operands()) { - visiting_hlo_operand_shardings_.push_back(operand->sharding()); - operand->set_sharding(manual_to_onedevice( - hlo->opcode(), get_sharding_shape(operand), operand->sharding())); - GetPartitionedHlo(operand).hlo()->copy_sharding(operand); + const bool has_manual_sharding = + hlo->sharding().IsManual() || + (hlo->sharding().IsTuple() && + absl::c_any_of( + hlo->sharding().tuple_elements(), + [](const HloSharding& sharding) { return sharding.IsManual(); })); + const bool has_manual_subgroup = + hlo->sharding().IsManualSubgroup() || + (hlo->sharding().IsTuple() && + absl::c_any_of(hlo->sharding().tuple_elements(), + [](const HloSharding& sharding) { + return sharding.IsManualSubgroup(); + })); + if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) { + visiting_hlo_sharding_ = hlo->sharding(); + auto get_sharding_shape = [](const HloInstruction* hlo) { + if (hlo->opcode() != HloOpcode::kOutfeed) { + return hlo->shape(); } - } else { - const bool has_manual_subgroup = - hlo->sharding().IsManualSubgroup() || - (hlo->sharding().IsTuple() && - absl::c_any_of(hlo->sharding().tuple_elements(), - [](const HloSharding& sharding) { - return sharding.IsManualSubgroup(); - })); - if (has_manual_subgroup && !hlo->IsCustomCall("SPMDFullToShardShape") && - !hlo->IsCustomCall("SPMDShardToFullShape") && - hlo->opcode() != HloOpcode::kGetTupleElement) { - auto get_grouped_sharding = - [&](const HloSharding& sharding, const Shape& shape, - const GroupedSharding* ref = - nullptr) -> absl::StatusOr { - if (!sharding.IsTuple()) { - GroupedSharding grouped = - hlo_sharding_util::GetManualSubgroupSharding(sharding); - if (ref != nullptr) { - auto aligned = - AlignGroupsWithIfCompatible(std::move(grouped), *ref); - TF_RET_CHECK(aligned.has_value()) - << "Incompatible manual sharding at " << hlo->ToString(); - return *aligned; - } - return grouped; - } - std::vector elements; - elements.reserve(sharding.tuple_elements().size()); - CHECK(!sharding.tuple_elements().empty()); - GroupedSharding grouped0 = - 
hlo_sharding_util::GetManualSubgroupSharding( - sharding.tuple_elements()[0]); - if (ref != nullptr) { - auto aligned = - AlignGroupsWithIfCompatible(std::move(grouped0), *ref); - TF_RET_CHECK(aligned.has_value()) - << "Incompatible manual sharding at " << hlo->ToString(); - grouped0 = std::move(*aligned); - } - elements.push_back(std::move(grouped0.sharding)); - for (int64_t i = 1; i < sharding.tuple_elements().size(); ++i) { - auto grouped_i = AlignGroupsWithIfCompatible( - hlo_sharding_util::GetManualSubgroupSharding( - sharding.tuple_elements()[i]), - grouped0); - TF_RET_CHECK(grouped_i.has_value()) - << "Incompatible manual sharding between tuple elements: " - << hlo->ToString(); - elements.push_back(std::move(grouped_i->sharding)); - } - grouped0.sharding = HloSharding::Tuple(shape, elements); - return grouped0; - }; - TF_ASSIGN_OR_RETURN( - auto group_sharding, - get_grouped_sharding(hlo->sharding(), hlo->shape())); - // Update sharding. - visiting_hlo_sharding_ = hlo->sharding(); - hlo->set_sharding(group_sharding.sharding); - // Update device_groups and num_partitions. - // Set device_groups_, visiting_partition_id_ and - // visiting_collective_ops_creator_ before MakePartitioningState() which - // uses them. - device_groups_ = group_sharding.device_groups; - visiting_num_partitions_ = num_partitions_; - num_partitions_ = num_partitions_ / group_sharding.device_groups.size(); - visiting_partition_id_ = partition_id_; - visiting_collective_ops_creator_ = std::move(collective_ops_creator_); - auto grouped_state = MakePartitioningState(); - collective_ops_creator_ = - std::move(grouped_state.collective_ops_creator); - partition_id_ = grouped_state.partition_id; - - // Update sharding for the operands. - visiting_hlo_operand_shardings_.reserve(hlo->operand_count()); - visiting_state_.reserve(hlo->operand_count()); - for (HloInstruction* operand : hlo->unique_operands()) { - visiting_hlo_operand_shardings_.push_back(operand->sharding()); - auto old_state = GetPartitionedHlo(operand).state(); - visiting_state_.push_back(old_state); - if (operand->shape().IsArray() && operand->IsConstant() && - operand->shape().dimensions_size() == 0 && - !operand->sharding().IsManualSubgroup()) { - // We allowed scalar constants to be CSE'ed between manual/auto - // subgraphs. It's possible that it doesn't have a manual subgroup. 
- continue; - } - TF_ASSIGN_OR_RETURN( - auto op_group_sharding, - get_grouped_sharding(operand->sharding(), operand->shape(), - &group_sharding)); - operand->set_sharding(op_group_sharding.sharding); - GetPartitionedHlo(operand).hlo()->copy_sharding(operand); - auto group_state = CreatePerGroupPartitioningState( - old_state, op_group_sharding.device_groups, &b_); - GetPartitionedHlo(operand).set_state(group_state); + std::vector operand_shapes(hlo->operand_count()); + for (int i = 0; i < hlo->operand_count(); ++i) { + operand_shapes[i] = hlo->operand(i)->shape(); + } + return ShapeUtil::MakeTupleShape(operand_shapes); + }; + hlo->set_sharding(manual_to_onedevice( + hlo->opcode(), get_sharding_shape(hlo), *visiting_hlo_sharding_)); + + visiting_hlo_operand_shardings_.reserve(hlo->operand_count()); + for (HloInstruction* operand : hlo->unique_operands()) { + visiting_hlo_operand_shardings_.push_back(operand->sharding()); + operand->set_sharding(manual_to_onedevice( + hlo->opcode(), get_sharding_shape(operand), operand->sharding())); + GetPartitionedHlo(operand).hlo()->copy_sharding(operand); + } + } else if (has_manual_subgroup && + !hlo->IsCustomCall("SPMDFullToShardShape") && + !hlo->IsCustomCall("SPMDShardToFullShape") && + hlo->opcode() != HloOpcode::kGetTupleElement) { + auto get_grouped_sharding = + [&](const HloSharding& sharding, const Shape& shape, + const GroupedSharding* ref = + nullptr) -> absl::StatusOr { + if (!sharding.IsTuple()) { + GroupedSharding grouped = + hlo_sharding_util::GetManualSubgroupSharding(sharding); + if (ref != nullptr) { + auto aligned = AlignGroupsWithIfCompatible(std::move(grouped), *ref); + TF_RET_CHECK(aligned.has_value()) + << "Incompatible manual sharding at " << hlo->ToString(); + return *aligned; } + return grouped; } + std::vector elements; + elements.reserve(sharding.tuple_elements().size()); + CHECK(!sharding.tuple_elements().empty()); + GroupedSharding grouped0 = hlo_sharding_util::GetManualSubgroupSharding( + sharding.tuple_elements()[0]); + if (ref != nullptr) { + auto aligned = AlignGroupsWithIfCompatible(std::move(grouped0), *ref); + TF_RET_CHECK(aligned.has_value()) + << "Incompatible manual sharding at " << hlo->ToString(); + grouped0 = std::move(*aligned); + } + elements.push_back(std::move(grouped0.sharding)); + for (int64_t i = 1; i < sharding.tuple_elements().size(); ++i) { + auto grouped_i = AlignGroupsWithIfCompatible( + hlo_sharding_util::GetManualSubgroupSharding( + sharding.tuple_elements()[i]), + grouped0); + TF_RET_CHECK(grouped_i.has_value()) + << "Incompatible manual sharding between tuple elements: " + << hlo->ToString(); + elements.push_back(std::move(grouped_i->sharding)); + } + grouped0.sharding = HloSharding::Tuple(shape, elements); + return grouped0; + }; + TF_ASSIGN_OR_RETURN(auto group_sharding, + get_grouped_sharding(hlo->sharding(), hlo->shape())); + // Update sharding. + visiting_hlo_sharding_ = hlo->sharding(); + hlo->set_sharding(group_sharding.sharding); + // Update device_groups and num_partitions. + // Set device_groups_, visiting_partition_id_ and + // visiting_collective_ops_creator_ before MakePartitioningState() which + // uses them. 
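The restructuring in this commit is easier to see with the sharding details stripped away. A condensed sketch with stand-in names (Hlo, IsExempt, HasManual, HasManualSubgroup, and the Rewrite helpers are illustrative, not the real functions):

  #include "absl/status/status.h"

  struct Hlo;  // Stand-in for HloInstruction.
  bool IsExempt(const Hlo*);           // The nine-opcode allowlist at the top.
  bool HasManual(const Hlo*);          // Fully manual sharding.
  bool HasManualSubgroup(const Hlo*);  // Manual subgroup sharding.
  void RewriteManual(Hlo*);
  void RewriteManualSubgroup(Hlo*);

  absl::Status Preprocess(Hlo* hlo) {
    // The early return replaces the old negated nine-opcode condition,
    // which used to wrap everything below in one large if block.
    if (IsExempt(hlo)) return absl::OkStatus();
    // With that gone, the subgroup case no longer nests inside an else;
    // the two rewrites form a flat if / else-if chain.
    if (HasManual(hlo)) {
      RewriteManual(hlo);
    } else if (HasManualSubgroup(hlo)) {
      RewriteManualSubgroup(hlo);
    }
    return absl::OkStatus();
  }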
+ device_groups_ = group_sharding.device_groups; + visiting_num_partitions_ = num_partitions_; + num_partitions_ = num_partitions_ / group_sharding.device_groups.size(); + visiting_partition_id_ = partition_id_; + visiting_collective_ops_creator_ = std::move(collective_ops_creator_); + auto grouped_state = MakePartitioningState(); + collective_ops_creator_ = std::move(grouped_state.collective_ops_creator); + partition_id_ = grouped_state.partition_id; + + // Update sharding for the operands. + visiting_hlo_operand_shardings_.reserve(hlo->operand_count()); + visiting_state_.reserve(hlo->operand_count()); + for (HloInstruction* operand : hlo->unique_operands()) { + visiting_hlo_operand_shardings_.push_back(operand->sharding()); + auto old_state = GetPartitionedHlo(operand).state(); + visiting_state_.push_back(old_state); + if (operand->shape().IsArray() && operand->IsConstant() && + operand->shape().dimensions_size() == 0 && + !operand->sharding().IsManualSubgroup()) { + // We allowed scalar constants to be CSE'ed between manual/auto + // subgraphs. It's possible that it doesn't have a manual subgroup. + continue; + } + TF_ASSIGN_OR_RETURN( + auto op_group_sharding, + get_grouped_sharding(operand->sharding(), operand->shape(), + &group_sharding)); + operand->set_sharding(op_group_sharding.sharding); + GetPartitionedHlo(operand).hlo()->copy_sharding(operand); + auto group_state = CreatePerGroupPartitioningState( + old_state, op_group_sharding.device_groups, &b_); + GetPartitionedHlo(operand).set_state(group_state); } } + return absl::OkStatus(); } From 9454ed4d12d53a92c7bee455317c600fbd609ea7 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Thu, 10 Apr 2025 16:23:41 -0700 Subject: [PATCH 0527/1324] Integrate StableHLO at openxla/stablehlo@8d9a84b5 PiperOrigin-RevId: 746216102 --- third_party/stablehlo/temporary.patch | 75 ------------------- third_party/stablehlo/workspace.bzl | 4 +- .../xla/third_party/stablehlo/temporary.patch | 75 ------------------- .../xla/third_party/stablehlo/workspace.bzl | 4 +- 4 files changed, 4 insertions(+), 154 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8683dc781d6a82..ff118b89c8b1ba 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -607,54 +607,6 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "stablehlo/dialect/VhloOps.td", deps = [ -diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir ---- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -+++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -@@ -12,7 +12,7 @@ - return %2 : tensor<14x15x0x33xf64> - } - func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { -- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> -+ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> - return %cst : tensor<14x15x0x17xcomplex> - } - func.func private @expected() -> (tensor<14x15x0x33xf64> {mhlo.layout_mode = "default"}) { -diff --ruN a/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir ---- stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -+++ stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -@@ -12,7 +12,7 @@ - return %2 : tensor<14x15x0x33xf32> - } - func.func private @inputs() -> 
(tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { -- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> -+ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> - return %cst : tensor<14x15x0x17xcomplex> - } - func.func private @expected() -> (tensor<14x15x0x33xf32> {mhlo.layout_mode = "default"}) { -diff --ruN a/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir ---- stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -+++ stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -@@ -16,7 +16,7 @@ - return %cst : tensor<14x15x0x17xf32> - } - func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { -- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> -+ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> - return %cst : tensor<14x15x0x9xcomplex> - } - } -diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir ---- stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -+++ stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -@@ -16,7 +16,7 @@ - return %cst : tensor<14x15x0x17xf64> - } - func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { -- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> -+ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> - return %cst : tensor<14x15x0x9xcomplex> - } - } diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel --- stablehlo/stablehlo/tests/BUILD.bazel +++ stablehlo/stablehlo/tests/BUILD.bazel @@ -699,31 +651,4 @@ diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/B tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "TestUtils.td", deps = [ -diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp ---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp -+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp -@@ -1539,8 +1539,8 @@ - - void populateStablehloHloImportCanonicalizationPatterns( - MLIRContext *context, RewritePatternSet *patterns) { -- patterns->add( -- context); -+ patterns->add(context); - } - - std::unique_ptr createStablehloAggressiveSimplificationPass( -diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td ---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td -+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td -@@ -366,7 +366,8 @@ - (StableHLO_ReshapeOpWithShape $reshape, $operand)>; - - // Pattern: reshape(X, [X.shape]) -> X --def : Pat<(StableHLO_ReshapeOp:$reshape $operand), -+def ReshapeIsNoop -+ : Pat<(StableHLO_ReshapeOp:$reshape $operand), - (replaceWithValue $operand), - [(TypesEqual $reshape, $operand)]>; - diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index da51958d55b282..0f079f6856ed99 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - 
STABLEHLO_COMMIT = "4bf77d23bd9150782a70d85fda9c12a2dec5328c" - STABLEHLO_SHA256 = "0efae2563d87c642cf9ad5c576911c5f08f9b5ee023b626ddb2a51a87d93297f" + STABLEHLO_COMMIT = "8d9a84b5efbd1fe57cfcb84c6fa38f751bdbabe8" + STABLEHLO_SHA256 = "6e4a05f016d428778b9a95e15da1c2126c4376c32105734343a86cc1b7adfbf4" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 8683dc781d6a82..ff118b89c8b1ba 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -607,54 +607,6 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "stablehlo/dialect/VhloOps.td", deps = [ -diff --ruN a/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir ---- stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -+++ stablehlo/stablehlo/testdata/fft_complex128_14_15_0_17.mlir -@@ -12,7 +12,7 @@ - return %2 : tensor<14x15x0x33xf64> - } - func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { -- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> -+ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> - return %cst : tensor<14x15x0x17xcomplex> - } - func.func private @expected() -> (tensor<14x15x0x33xf64> {mhlo.layout_mode = "default"}) { -diff --ruN a/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir ---- stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -+++ stablehlo/stablehlo/testdata/fft_complex64_14_15_0_17.mlir -@@ -12,7 +12,7 @@ - return %2 : tensor<14x15x0x33xf32> - } - func.func private @inputs() -> (tensor<14x15x0x17xcomplex> {mhlo.layout_mode = "default"}) { -- %cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex> -+ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex> - return %cst : tensor<14x15x0x17xcomplex> - } - func.func private @expected() -> (tensor<14x15x0x33xf32> {mhlo.layout_mode = "default"}) { -diff --ruN a/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir ---- stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -+++ stablehlo/stablehlo/testdata/fft_float32_14_15_0_17.mlir -@@ -16,7 +16,7 @@ - return %cst : tensor<14x15x0x17xf32> - } - func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { -- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> -+ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> - return %cst : tensor<14x15x0x9xcomplex> - } - } -diff --ruN a/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir b/stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir ---- stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -+++ stablehlo/stablehlo/testdata/fft_float64_14_15_0_17.mlir -@@ -16,7 +16,7 @@ - return %cst : tensor<14x15x0x17xf64> - } - func.func private @expected() -> (tensor<14x15x0x9xcomplex> {mhlo.layout_mode = "default"}) { -- %cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex> -+ %cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex> - return %cst : tensor<14x15x0x9xcomplex> - } - } diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel --- 
stablehlo/stablehlo/tests/BUILD.bazel +++ stablehlo/stablehlo/tests/BUILD.bazel @@ -699,31 +651,4 @@ diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/B tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "TestUtils.td", deps = [ -diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp ---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp -+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp -@@ -1539,8 +1539,8 @@ - - void populateStablehloHloImportCanonicalizationPatterns( - MLIRContext *context, RewritePatternSet *patterns) { -- patterns->add( -- context); -+ patterns->add(context); - } - - std::unique_ptr createStablehloAggressiveSimplificationPass( -diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td ---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td -+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td -@@ -366,7 +366,8 @@ - (StableHLO_ReshapeOpWithShape $reshape, $operand)>; - - // Pattern: reshape(X, [X.shape]) -> X --def : Pat<(StableHLO_ReshapeOp:$reshape $operand), -+def ReshapeIsNoop -+ : Pat<(StableHLO_ReshapeOp:$reshape $operand), - (replaceWithValue $operand), - [(TypesEqual $reshape, $operand)]>; - diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index da51958d55b282..0f079f6856ed99 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "4bf77d23bd9150782a70d85fda9c12a2dec5328c" - STABLEHLO_SHA256 = "0efae2563d87c642cf9ad5c576911c5f08f9b5ee023b626ddb2a51a87d93297f" + STABLEHLO_COMMIT = "8d9a84b5efbd1fe57cfcb84c6fa38f751bdbabe8" + STABLEHLO_SHA256 = "6e4a05f016d428778b9a95e15da1c2126c4376c32105734343a86cc1b7adfbf4" # LINT.ThenChange(Google-internal path) tf_http_archive( From bbc1615620aa9e3921d7e4f6bf8f7709be1a1207 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 10 Apr 2025 16:46:16 -0700 Subject: [PATCH 0528/1324] Reverts ded0458a49c25b722d71d27629a09079fb651197 PiperOrigin-RevId: 746223242 --- tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 26cfad1ce7bf52..0d904a9c44cf35 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -703,7 +703,12 @@ class Delegate { } bool enable_subgraph_reshaping() const { +#ifdef XNNPACK_DELEGATE_ENABLE_SUBGRAPH_RESHAPING return true; +#else + return (options_.flags & + TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING) != 0; +#endif } bool enable_slinky() const { From 42a3a3c7b7da608369017573b204b56db1dc4a09 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 10 Apr 2025 16:59:10 -0700 Subject: [PATCH 0529/1324] Update XNNPACK dependency version in XLA PiperOrigin-RevId: 746227299 --- tensorflow/lite/tools/cmake/modules/xnnpack.cmake | 2 +- third_party/xla/tsl_workspace2.bzl | 12 ++++++------ third_party/xla/workspace2.bzl | 12 ++++++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index 6a60e1e1f9fde9..7dbda8f999e4d3 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG ece21c589be842fbeaee297b0d668194d6f3a35b + GIT_TAG 8a2f5f441833b80806b58b5d704ec8335634182c GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/third_party/xla/tsl_workspace2.bzl b/third_party/xla/tsl_workspace2.bzl index deaf8b0ebaef0c..a3243925c95f09 100644 --- a/third_party/xla/tsl_workspace2.bzl +++ b/third_party/xla/tsl_workspace2.bzl @@ -115,9 +115,9 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "9e290e7b094134bdda0cad4ef4b89625fbde3c4b8e8f5dc84044c0f2e55b875a", - strip_prefix = "XNNPACK-5b4978cae19292232a27bdf0f495819bf5297167", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/5b4978cae19292232a27bdf0f495819bf5297167.zip"), + sha256 = "1832b8998252529d73e585b545c3f1a12a69ddd136ba9072ea9f717e17ce452b", + strip_prefix = "XNNPACK-8a2f5f441833b80806b58b5d704ec8335634182c", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/8a2f5f441833b80806b58b5d704ec8335634182c.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) @@ -130,9 +130,9 @@ def _tf_repositories(): tf_http_archive( name = "pthreadpool", - sha256 = "215724985c4845cdcadcb5f26a2a8777943927bb5a172a00e7716fe16a6f3c1b", - strip_prefix = "pthreadpool-b1aee199d54003fb557076a201bcac3398af580b", - urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/b1aee199d54003fb557076a201bcac3398af580b.zip"), + sha256 = "618fcb27b1bc895af5431acccc96bc005e872854115ad99cdbaf803a53b00a4c", + strip_prefix = "pthreadpool-da30a55fecdd9f2f90f236874a305db146567aef", + urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/da30a55fecdd9f2f90f236874a305db146567aef.zip"), ) tf_http_archive( diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index a76093e926f022..0a883cd0b865e7 100644 
--- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -45,9 +45,9 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "72e4368ff3e7bdefd8b43fc6e5708b8e9fada7a8302ba2362028832df6262c13", - strip_prefix = "XNNPACK-e67c0fbc360903f921ff286a235c18d9e12c6df6", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/e67c0fbc360903f921ff286a235c18d9e12c6df6.zip"), + sha256 = "1832b8998252529d73e585b545c3f1a12a69ddd136ba9072ea9f717e17ce452b", + strip_prefix = "XNNPACK-8a2f5f441833b80806b58b5d704ec8335634182c", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/8a2f5f441833b80806b58b5d704ec8335634182c.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) @@ -74,9 +74,9 @@ def _tf_repositories(): tf_http_archive( name = "pthreadpool", - sha256 = "2d56c31ebf6509d171d12ace2b543f6182ff0083ba674541515fc573738a3238", - strip_prefix = "pthreadpool-706a8ea9e4b8c2129718af195ddce7fc2573e719", - urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/706a8ea9e4b8c2129718af195ddce7fc2573e719.zip"), + sha256 = "618fcb27b1bc895af5431acccc96bc005e872854115ad99cdbaf803a53b00a4c", + strip_prefix = "pthreadpool-da30a55fecdd9f2f90f236874a305db146567aef", + urls = tf_mirror_urls("https://github.com/google/pthreadpool/archive/da30a55fecdd9f2f90f236874a305db146567aef.zip"), ) tf_http_archive( From 0e35da02f547acbaf91714015a42aefb1b641973 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 17:17:31 -0700 Subject: [PATCH 0530/1324] Fix more undefined behaviors in PJRT. It's undefined behavior to `reinterpret_cast` an `A*` to a `B*` where `A` and `B` are unrelated and neither is a `char`/`unsigned char`/`std::byte` type. In particular, we cannot `reinterpret_cast` a pointer to a PJRT extension struct `Foo` to a `PJRT_Extension_Base*`, as `Foo` is not related to `PJRT_Extension_Base` as far as the compiler is concerned. Usually, the fix is to make `Foo` derive from `PJRT_Extension_Base.` However, the code needs to work as C, and C doesn't support inheritance. Therefore we make `Foo` contain a `PJRT_Extension_Base` variable as its first field instead. 
PiperOrigin-RevId: 746232486 --- .../xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc | 10 ++--- .../pjrt_c_api_custom_partitioner_extension.h | 4 +- .../xla/xla/pjrt/c/pjrt_c_api_ffi_extension.h | 4 +- .../xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 8 ++-- .../xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h | 4 +- .../xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 41 ++++++++++--------- .../xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc | 6 +-- .../xla/pjrt/c/pjrt_c_api_stream_extension.h | 4 +- .../xla/pjrt/c/pjrt_c_api_triton_extension.h | 4 +- .../xla/pjrt/c/pjrt_c_api_triton_internal.h | 8 ++-- third_party/xla/xla/pjrt/pjrt_c_api_client.cc | 9 ++-- .../plugin/example_plugin/myplugin_c_pjrt.cc | 3 +- 12 files changed, 47 insertions(+), 58 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc index 75f030997f8704..4726bf2a3b16c3 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc @@ -69,18 +69,16 @@ const PJRT_Api* GetCpuPjrtApi() { pjrt::CreateLayoutsExtension(nullptr); static PJRT_MemoryDescriptions_Extension memory_descriptions_extension = - pjrt::CreateMemoryDescriptionsExtension( - reinterpret_cast(&layouts_extension)); + pjrt::CreateMemoryDescriptionsExtension(&layouts_extension.base); - static PJRT_FFI_Extension ffi_extension = pjrt::CreateFfiExtension( - reinterpret_cast(&memory_descriptions_extension)); + static PJRT_FFI_Extension ffi_extension = + pjrt::CreateFfiExtension(&memory_descriptions_extension.base); static const PJRT_Api pjrt_api = pjrt::CreatePjrtApi( pjrt::cpu_plugin::PJRT_Client_Create, pjrt::cpu_plugin::PJRT_ExecuteContext_Create, pjrt::cpu_plugin::PJRT_CpuDeviceTopology_Create, - pjrt::PJRT_Plugin_Initialize_NoOp, - reinterpret_cast(&ffi_extension), + pjrt::PJRT_Plugin_Initialize_NoOp, &ffi_extension.base, pjrt::PJRT_Plugin_Attributes_Xla); return &pjrt_api; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h index beee3d9a5e7229..a79764f5dc8913 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h @@ -130,9 +130,7 @@ typedef PJRT_Error* PJRT_Register_Batch_Partitionable( PJRT_Register_Batch_Partitionable_Args* args); typedef struct PJRT_Custom_Partitioner_Extension { - size_t struct_size; - PJRT_Extension_Type type; - PJRT_Extension_Base* next; + PJRT_Extension_Base base; PJRT_Register_Custom_Partitioner* register_custom_partitioner; PJRT_Register_Batch_Partitionable* register_batch_partitionable; } PJRT_Custom_Partitioner_Extension; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_extension.h index d9f6d163d842cf..995a2c7e50dc8f 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_extension.h @@ -70,9 +70,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data); typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args); typedef struct PJRT_FFI_Extension { - size_t struct_size; - PJRT_Extension_Type type; - PJRT_Extension_Base* next; + PJRT_Extension_Base base; PJRT_FFI_TypeID_Register* type_id_register; PJRT_FFI_UserData_Add* user_data_add; } PJRT_FFI; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc index 
8898088a23a8aa..5fa88eab330ad3 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc @@ -70,9 +70,11 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) { PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { return { - /*struct_size=*/PJRT_FFI_Extension_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_FFI, - /*next=*/next, + PJRT_Extension_Base{ + /*struct_size=*/PJRT_FFI_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_FFI, + /*next=*/next, + }, /*type_id_register=*/PJRT_FFI_TypeID_Register, /*user_data_add=*/PJRT_FFI_UserData_Add, }; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h index 28b17e5434f2ea..c456e1b1a85c55 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h @@ -43,9 +43,7 @@ typedef PJRT_Error* PJRT_Gpu_Register_Custom_Call( PJRT_Gpu_Register_Custom_Call_Args* args); typedef struct PJRT_Gpu_Custom_Call { - size_t struct_size; - PJRT_Extension_Type type; - PJRT_Extension_Base* next; + PJRT_Extension_Base base; PJRT_Gpu_Register_Custom_Call* custom_call; } PJRT_Gpu_Custom_Call; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Gpu_Custom_Call, custom_call); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 6980d736a57104..abb9d1a41b6dc3 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -350,9 +350,11 @@ PJRT_Error* PJRT_Register_Batch_Partitionable( } PJRT_Custom_Partitioner_Extension custom_partitioner{ - /*struct_size=*/PJRT_Custom_Partitioner_Extension_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner, - /*next=*/reinterpret_cast(&profiler_extension), + PJRT_Extension_Base{ + /*struct_size=*/PJRT_Custom_Partitioner_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner, + /*next=*/&profiler_extension.base, + }, /*register_custom_partitioner=*/PJRT_Register_Custom_Partitioner, /*register_batch_partitionable=*/PJRT_Register_Batch_Partitionable, }; @@ -383,9 +385,11 @@ PJRT_Error* PJRT_Wait_Until_Buffer_Ready_On_Stream( } PJRT_Stream_Extension stream{ - /*struct_size=*/PJRT_Stream_Extension_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Stream, - /*next=*/reinterpret_cast(&custom_partitioner), + PJRT_Extension_Base{ + /*struct_size=*/PJRT_Stream_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Stream, + /*next=*/&custom_partitioner.base, + }, /*get_stream=*/PJRT_Get_Stream_For_External_Ready_Events, /*wait_stream=*/PJRT_Wait_Until_Buffer_Ready_On_Stream, }; @@ -421,32 +425,31 @@ PJRT_Error* PJRT_Gpu_Register_Custom_Call( const PJRT_Api* GetGpuPjrtApi() { static PJRT_Gpu_Custom_Call custom_call{ - /*struct_size=*/PJRT_Gpu_Custom_Call_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call, - /*next=*/reinterpret_cast(&stream), + PJRT_Extension_Base{ + /*struct_size=*/PJRT_Gpu_Custom_Call_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call, + /*next=*/&stream.base, + }, /*custom_call=*/PJRT_Gpu_Register_Custom_Call, }; static PJRT_Layouts_Extension layouts_extension = - pjrt::CreateLayoutsExtension( - reinterpret_cast(&custom_call)); + 
pjrt::CreateLayoutsExtension(&custom_call.base); - static PJRT_FFI_Extension ffi_extension = pjrt::CreateFfiExtension( - reinterpret_cast(&layouts_extension)); + static PJRT_FFI_Extension ffi_extension = + pjrt::CreateFfiExtension(&layouts_extension.base); static PJRT_MemoryDescriptions_Extension memory_descriptions_extension = - pjrt::CreateMemoryDescriptionsExtension( - reinterpret_cast(&ffi_extension)); + pjrt::CreateMemoryDescriptionsExtension(&ffi_extension.base); - static PJRT_Triton_Extension triton_extension = pjrt::CreateTritonExtension( - reinterpret_cast(&memory_descriptions_extension)); + static PJRT_Triton_Extension triton_extension = + pjrt::CreateTritonExtension(&memory_descriptions_extension.base); static const PJRT_Api pjrt_api = pjrt::CreatePjrtApi( pjrt::gpu_plugin::PJRT_Client_Create, pjrt::gpu_plugin::PJRT_ExecuteContext_Create, pjrt::gpu_plugin::PJRT_GpuDeviceTopology_Create, - pjrt::PJRT_Plugin_Initialize_NoOp, - reinterpret_cast(&triton_extension), + pjrt::PJRT_Plugin_Initialize_NoOp, &triton_extension.base, pjrt::PJRT_Plugin_Attributes_Xla); return &pjrt_api; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 182477dcae19e6..08f7f14dc1f3a2 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -949,8 +949,7 @@ TEST(PjrtCApiGpuExtensionTest, CustomCallUntyped) { args.handler_initialize = nullptr; args.handler_execute = reinterpret_cast(&TestCustomCallV2); auto api = GetPjrtApi(); - const PJRT_Extension_Base* next = - reinterpret_cast(api->extension_start); + const PJRT_Extension_Base* next = api->extension_start; while (next != nullptr && next->type != PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { @@ -982,8 +981,7 @@ TEST(PjrtCApiGpuExtensionTest, CustomCallTyped) { args.handler_initialize = nullptr; args.handler_execute = reinterpret_cast(kNoop); auto api = GetPjrtApi(); - const PJRT_Extension_Base* next = - reinterpret_cast(api->extension_start); + const PJRT_Extension_Base* next = api->extension_start; while (next != nullptr && next->type != PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_stream_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_stream_extension.h index 3c691d43c41311..d277becd64ce60 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_stream_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_stream_extension.h @@ -51,9 +51,7 @@ typedef PJRT_Error* PJRT_Wait_Until_Buffer_Ready_On_Stream( PJRT_Wait_Until_Buffer_Ready_On_Stream_Args* args); typedef struct PJRT_Stream_Extension { - size_t struct_size; - PJRT_Extension_Type type; - PJRT_Extension_Base* next; + PJRT_Extension_Base base; PJRT_Get_Stream_For_External_Ready_Events* get_stream; PJRT_Wait_Until_Buffer_Ready_On_Stream* wait_stream; } PJRT_Stream_Extension; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_triton_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_triton_extension.h index 5fa0e866e07eb3..3eef6aca6b2c75 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_triton_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_triton_extension.h @@ -49,9 +49,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Triton_Compile_Args, out_cluster_dim_z); typedef PJRT_Error* PJRT_Triton_Compile(PJRT_Triton_Compile_Args* args); typedef struct PJRT_Triton_Extension { - size_t struct_size; - PJRT_Extension_Type type; - PJRT_Extension_Base* next; + PJRT_Extension_Base base; PJRT_Triton_Compile* 
compile; } PJRT_Triton; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Triton_Extension, compile); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_triton_internal.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_triton_internal.h index d10e9e8a0150d3..a8d955f42711ec 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_triton_internal.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_triton_internal.h @@ -25,9 +25,11 @@ PJRT_Error* PJRT_Triton_Compile(PJRT_Triton_Compile_Args* args); inline PJRT_Triton_Extension CreateTritonExtension(PJRT_Extension_Base* next) { return { - /*struct_size=*/PJRT_Triton_Extension_STRUCT_SIZE, - /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Triton, - /*next=*/next, + PJRT_Extension_Base{ + /*struct_size=*/PJRT_Triton_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Triton, + /*next=*/next, + }, /*compile=*/PJRT_Triton_Compile, }; } diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 6072097267cbb5..0ee251eb9fb112 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -360,8 +360,7 @@ InitializeArgsAndCompile(PjRtCApiClient* api_client, const PJRT_Api* c_api, args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE; PJRT_Profiler_Extension profiler_extension = pjrt::CreatePjrtProfilerExtension("PJRT_Client_Compile linkage"); - args.extension_start = - reinterpret_cast(&profiler_extension); + args.extension_start = &profiler_extension.base; args.client = client; TF_ASSIGN_OR_RETURN(const CompileOptionsProto options_proto, options.ToProto()); @@ -1875,8 +1874,7 @@ PjRtCApiLoadedExecutable::Execute( PJRT_Profiler_Extension profiler_extension = pjrt::CreatePjrtProfilerExtension( "PJRT_LoadedExecutable_Execute linkage"); - args.extension_start = - reinterpret_cast(&profiler_extension); + args.extension_start = &profiler_extension.base; RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_LoadedExecutable_Execute(&args), pjrt_c_api()); @@ -1947,8 +1945,7 @@ PjRtCApiLoadedExecutable::ExecuteWithSingleDevice( PJRT_Profiler_Extension profiler_extension = pjrt::CreatePjrtProfilerExtension( "PJRT_LoadedExecutable_Execute linkage"); - args.extension_start = - reinterpret_cast(&profiler_extension); + args.extension_start = &profiler_extension.base; RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_LoadedExecutable_Execute(&args), pjrt_c_api()); diff --git a/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_c_pjrt.cc b/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_c_pjrt.cc index 7945ed29a1c07d..56a8506992082b 100644 --- a/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_c_pjrt.cc +++ b/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_c_pjrt.cc @@ -60,8 +60,7 @@ const PJRT_Api* GetPjrtApi() { myplugin_pjrt::PJRT_MypluginClient_Create, myplugin_pjrt::PJRT_MypluginExecuteContext_Create, myplugin_pjrt::PJRT_MypluginDeviceTopology_Create, - pjrt::PJRT_Plugin_Initialize_NoOp, - reinterpret_cast(&layouts_extension), + pjrt::PJRT_Plugin_Initialize_NoOp, &layouts_extension.base, pjrt::PJRT_Plugin_Attributes_Xla); printf("MyPlugin called GetPjrtApi\n"); From 2f4e979adb3e56cbb3363f04052cbe8500824f7e Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Thu, 10 Apr 2025 17:51:57 -0700 Subject: [PATCH 0531/1324] [XLA] Migrate to use `dimensions().size()` instead of `dimensions_size()`. 
dimensions_size() is DEPRECATED , use dimensions().size() instead PiperOrigin-RevId: 746241385 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 20 +++++++++---------- third_party/xla/xla/hlo/ir/hlo_instruction.h | 2 +- .../xla/xla/hlo/ir/hlo_instructions.cc | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 74026577bb9c95..aa8391e0d0d5d9 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -537,9 +537,9 @@ absl::StatusOr> HloInstruction::CreateFromProto( proto.dimensions().end())); break; case HloOpcode::kConcatenate: - TF_RET_CHECK(proto.dimensions_size() == 1) + TF_RET_CHECK(proto.dimensions().size() == 1) << "Concatenate instruction should have 1 dimension but sees " - << proto.dimensions_size(); + << proto.dimensions().size(); instruction = CreateConcatenate(shape, all_operands(), proto.dimensions(0)); break; @@ -729,7 +729,7 @@ absl::StatusOr> HloInstruction::CreateFromProto( channel_id = proto.channel_id(); } - TF_RET_CHECK(proto.dimensions_size() == 1) + TF_RET_CHECK(proto.dimensions().size() == 1) << "AllGather cannot have more than 1 all-gather dimensions"; int64_t all_gather_dimension = proto.dimensions(0); if (opcode == HloOpcode::kAllGather) { @@ -767,7 +767,7 @@ absl::StatusOr> HloInstruction::CreateFromProto( proto.constrain_layout(), channel_id, proto.use_global_device_ids()); } else if (opcode == HloOpcode::kReduceScatter) { - TF_RET_CHECK(proto.dimensions_size() == 1) + TF_RET_CHECK(proto.dimensions().size() == 1) << "ReduceScatter cannot have more than 1 scatter dimensions"; int64_t scatter_dimension = proto.dimensions(0); instruction = CreateReduceScatter( @@ -788,8 +788,8 @@ absl::StatusOr> HloInstruction::CreateFromProto( channel_id = proto.channel_id(); } std::optional split_dimension; - if (proto.dimensions_size() > 0) { - TF_RET_CHECK(proto.dimensions_size() == 1) + if (!proto.dimensions().empty()) { + TF_RET_CHECK(proto.dimensions().size() == 1) << "AllToAll cannot have more than 1 dimension (split dimension)"; TF_RET_CHECK(all_operands().size() == 1) << "AllToAll must have a single operand when the split dimension " @@ -1135,9 +1135,9 @@ absl::StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kIota: - TF_RET_CHECK(proto.dimensions_size() == 1) + TF_RET_CHECK(proto.dimensions().size() == 1) << "Iota instruction should have 1 dimension but sees " - << proto.dimensions_size(); + << proto.dimensions().size(); instruction = CreateIota(shape, proto.dimensions(0)); break; case HloOpcode::kDot: { @@ -1205,12 +1205,12 @@ absl::StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kGetDimensionSize: - TF_RET_CHECK(proto.dimensions_size() == 1); + TF_RET_CHECK(proto.dimensions().size() == 1); instruction = CreateGetDimensionSize(shape, operands(0), proto.dimensions(0)); break; case HloOpcode::kSetDimensionSize: - TF_RET_CHECK(proto.dimensions_size() == 1); + TF_RET_CHECK(proto.dimensions().size() == 1); instruction = CreateSetDimensionSize(shape, operands(0), operands(1), proto.dimensions(0)); break; diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 1930980c9b8d96..12e1804a69e1d6 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -998,7 +998,7 @@ class HloInstruction { // Creates a dynamic reshape instruction. 
Similar to reshape but dynamic // dimensions sizes are provided as additional variadic arguments. // - // Precondition: dim_sizes.size() == shape.dimensions_size() + // Precondition: dim_sizes.size() == shape.dimensions().size() static std::unique_ptr CreateDynamicReshape( const Shape& shape, HloInstruction* data_operand, absl::Span dim_sizes); diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index 5944ebf631a5d6..bc5ace122c82d2 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -2982,7 +2982,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { void HloConvolutionInstruction::PrintExtraAttributesImpl( AttributePrinter& printer, const HloPrintOptions& options) const { - if (window_.dimensions_size() != 0) { + if (!window_.dimensions().empty()) { printer.Next([this](Printer* printer) { AppendCat(printer, "window={", window_util::ToString(window()), "}"); }); @@ -3067,7 +3067,7 @@ HloInstructionProto HloReduceWindowInstruction::ToProto() const { void HloReduceWindowInstruction::PrintExtraAttributesImpl( AttributePrinter& printer, const HloPrintOptions& options) const { - if (window_.dimensions_size() != 0) { + if (!window_.dimensions().empty()) { printer.Next([this](Printer* printer) { AppendCat(printer, "window={", window_util::ToString(window()), "}"); }); @@ -3118,7 +3118,7 @@ HloInstructionProto HloSelectAndScatterInstruction::ToProto() const { void HloSelectAndScatterInstruction::PrintExtraAttributesImpl( AttributePrinter& printer, const HloPrintOptions& options) const { - if (window_.dimensions_size() != 0) { + if (!window_.dimensions().empty()) { printer.Next([this](Printer* printer) { AppendCat(printer, "window={", window_util::ToString(window()), "}"); }); From 468974a2f9c2b5dc54e847e3b479b8fc5be1f402 Mon Sep 17 00:00:00 2001 From: Yin Zhang Date: Thu, 10 Apr 2025 19:35:13 -0700 Subject: [PATCH 0532/1324] Fix divide by 0 issue in roofline model data conversion PiperOrigin-RevId: 746266565 --- .../core/profiler/convert/op_stats_to_roofline_model.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc index 02066b0720c8ae..44cca8bdba14a2 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc @@ -62,9 +62,10 @@ RooflineModelRecord ConvertOpMetricsToRooflineModelRecord( // For RecordType::AVERAGE_STEP, divide by num_steps to show per-step // numbers when appropriate. int num_steps = op_stats.step_db().step_sequence_size(); - record.set_total_time_in_us(record.total_time_in_us() / num_steps); - record.set_total_self_time_in_us(record.total_self_time_in_us() / - num_steps); + record.set_total_time_in_us( + tsl::profiler::SafeDivide(record.total_time_in_us(), num_steps)); + record.set_total_self_time_in_us( + tsl::profiler::SafeDivide(record.total_self_time_in_us(), num_steps)); } record.set_total_time_per_core_in_us(tsl::profiler::SafeDivide( record.total_time_in_us(), From 46e7e79c3e47984d1835471ea970e5d9dbea04d2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 19:50:21 -0700 Subject: [PATCH 0533/1324] Support `safe_reinterpret_cast(nullptr)`. This usage is blessed by the C++ style guide. `nullptr` has a special type (`nullptr_t`) that's not a pointer type. 
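A standalone illustration of that point, using standard <type_traits> checks (not part of the patch):

    #include <cstddef>
    #include <type_traits>

    // decltype(nullptr) is std::nullptr_t, a distinct type, so
    // pointer-based trait specializations do not match it.
    static_assert(std::is_null_pointer_v<decltype(nullptr)>,
                  "nullptr has type std::nullptr_t");
    static_assert(!std::is_pointer_v<std::nullptr_t>,
                  "std::nullptr_t is not itself a pointer type");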
Hence we need a new specialization to handle it. PiperOrigin-RevId: 746269639 --- third_party/xla/xla/tsl/util/safe_reinterpret_cast.h | 6 ++++++ .../xla/xla/tsl/util/safe_reinterpret_cast_test.cc | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h index 079f4c060bf490..47dc02e68eba05 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast.h @@ -106,6 +106,12 @@ struct IsSafeCast : std::true_type {}; template struct IsSafeCast : std::true_type {}; +// It's safe to cast the nullptr literal to std::uintptr_t or std::intptr_t. +template <> +struct IsSafeCast : std::true_type {}; +template <> +struct IsSafeCast : std::true_type {}; + } // namespace internal // Like reinterpret_cast, but compiles only if it's safe. diff --git a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc index d917111ad31247..4b344bca53d9a7 100644 --- a/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc +++ b/third_party/xla/xla/tsl/util/safe_reinterpret_cast_test.cc @@ -101,6 +101,16 @@ TEST(SafeReinterpretCast, CanCastPointerToFromStdIntptrT) { EXPECT_EQ(safe_reinterpret_cast(intptr_t_p), &x); } +TEST(SafeReinterpretCast, CanCastNullptrToStdUintptrT) { + const std::uintptr_t n = safe_reinterpret_cast(nullptr); + EXPECT_EQ(safe_reinterpret_cast(n), nullptr); +} + +TEST(SafeReinterpretCast, CanCastNullptrToStdIntptrT) { + const std::intptr_t n = safe_reinterpret_cast(nullptr); + EXPECT_EQ(safe_reinterpret_cast(n), nullptr); +} + TEST(SafeReinterpretCast, CanCastPointerToFromSameType) { const int x = 42; const int* const int_p = safe_reinterpret_cast(&x); From 12febd898a409f6f602902ad8b0d4ecc078cf708 Mon Sep 17 00:00:00 2001 From: Robert David Date: Thu, 10 Apr 2025 19:53:30 -0700 Subject: [PATCH 0534/1324] Don't use `_EQ` macro with `nullptr`; use `==` operator instead. PiperOrigin-RevId: 746270375 --- tensorflow/lite/experimental/genai/BUILD | 3 --- tensorflow/lite/experimental/genai/kvcache.cc | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/experimental/genai/BUILD b/tensorflow/lite/experimental/genai/BUILD index 144734c0ba3660..35de2af73a28fd 100644 --- a/tensorflow/lite/experimental/genai/BUILD +++ b/tensorflow/lite/experimental/genai/BUILD @@ -26,12 +26,9 @@ cc_library( "//tensorflow/lite/experimental/resource", "//tensorflow/lite/experimental/resource:cache_buffer", "//tensorflow/lite/kernels:kernel_util", - "//tensorflow/lite/kernels:reference_ops", "//tensorflow/lite/kernels/internal:common", - "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", - "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/kernels/internal:types", "@flatbuffers", ], diff --git a/tensorflow/lite/experimental/genai/kvcache.cc b/tensorflow/lite/experimental/genai/kvcache.cc index 59fa3abd7ed510..f4f8bacf43eb0e 100644 --- a/tensorflow/lite/experimental/genai/kvcache.cc +++ b/tensorflow/lite/experimental/genai/kvcache.cc @@ -267,10 +267,10 @@ TfLiteStatus KVCacheEval(TfLiteContext* context, TfLiteNode* node) { v_ptr = v_ptr + sizeof(float) * op_data->layer_index * elements_in_one_block; // 0. 
Ensure output ptr is pointing to the cache data - TF_LITE_ENSURE_EQ(context, k_ptr, op_data->key_cache_ptr); - TF_LITE_ENSURE_EQ(context, v_ptr, op_data->value_cache_ptr); - TF_LITE_ENSURE_EQ(context, k_ptr, kfull->data.data); - TF_LITE_ENSURE_EQ(context, v_ptr, vfull->data.data); + TF_LITE_ENSURE(context, k_ptr == op_data->key_cache_ptr); + TF_LITE_ENSURE(context, v_ptr == op_data->value_cache_ptr); + TF_LITE_ENSURE(context, k_ptr == kfull->data.data); + TF_LITE_ENSURE(context, v_ptr == vfull->data.data); // 1. Determine which slots the inputs take up, and which slots are in the // existing span of the cache. From 25e9a64433da926566f0b83ab0e2f8d0db803a64 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Thu, 10 Apr 2025 20:06:46 -0700 Subject: [PATCH 0535/1324] Remove deprecated targets from build and header files PiperOrigin-RevId: 746273025 --- third_party/xla/xla/service/BUILD | 660 ------------------ .../xla/xla/service/add_original_value.h | 22 - .../xla/xla/service/algebraic_simplifier.h | 22 - .../service/all_gather_broadcast_reorder.h | 22 - .../xla/xla/service/all_gather_combiner.h | 22 - .../xla/xla/service/all_reduce_combiner.h | 22 - .../xla/xla/service/all_reduce_contiguous.h | 22 - .../xla/xla/service/all_reduce_folder.h | 22 - third_party/xla/xla/service/ar_crs_combiner.h | 22 - .../xla/service/async_collective_creator.h | 22 - .../xla/service/batch_dot_simplification.h | 22 - .../xla/service/bfloat16_conversion_folding.h | 22 - .../xla/xla/service/bfloat16_propagation.h | 22 - .../xla/xla/service/bitcast_dtypes_expander.h | 22 - .../xla/xla/service/broadcast_canonicalizer.h | 22 - .../xla/xla/service/cholesky_expander.h | 22 - .../xla/xla/service/collective_quantizer.h | 22 - .../collective_transformation_reorderer.h | 22 - .../service/collectives_schedule_linearizer.h | 22 - .../xla/xla/service/comparison_expander.h | 22 - .../xla/service/conditional_canonicalizer.h | 21 - .../convert_async_collectives_to_sync.h | 22 - ...memory_placement_to_internal_annotations.h | 22 - third_party/xla/xla/service/convert_mover.h | 22 - .../xla/xla/service/convert_operand_folding.h | 22 - .../xla/xla/service/convolution_4d_expander.h | 22 - .../xla/service/convolution_group_converter.h | 22 - .../xla/service/convolution_pred_expander.h | 22 - third_party/xla/xla/service/defuser.h | 22 - third_party/xla/xla/service/despecializer.h | 22 - third_party/xla/xla/service/dot_decomposer.h | 22 - .../xla/xla/service/dot_dimension_merger.h | 22 - third_party/xla/xla/service/dot_merger.h | 22 - .../service/dynamic_dimension_simplifier.h | 21 - .../xla/xla/service/dynamic_index_splitter.h | 22 - third_party/xla/xla/service/eigh_expander.h | 22 - .../xla/xla/service/flatten_call_graph.h | 24 - .../xla/xla/service/float_normalization.h | 22 - .../xla/xla/service/fusion_constant_sinking.h | 22 - .../xla/xla/service/gather_simplifier.h | 22 - .../xla/xla/service/hlo_alias_analysis.h | 22 - .../service/hlo_computation_deduplicator.h | 22 - .../xla/xla/service/hlo_constant_folding.h | 22 - .../xla/xla/service/hlo_dataflow_analysis.h | 26 - third_party/xla/xla/service/hlo_dce.h | 22 - .../xla/service/hlo_element_type_converter.h | 22 - third_party/xla/xla/service/hlo_lexer.h | 22 - .../xla/xla/service/hlo_liveness_analysis.h | 22 - .../xla/xla/service/hlo_memory_scheduler.h | 22 - third_party/xla/xla/service/hlo_ordering.h | 22 - third_party/xla/xla/service/hlo_parser.h | 22 - third_party/xla/xla/service/hlo_pass_fix.h | 22 - .../xla/xla/service/hlo_pass_interface.h | 22 - 
.../xla/xla/service/hlo_pass_pipeline.h | 22 - .../xla/xla/service/hlo_rematerialization.h | 21 - .../hlo_rematerialization_test_utils.h | 24 - .../xla/service/hlo_replication_analysis.h | 22 - .../service/hlo_value_semantics_analysis.h | 22 - .../service/host_memory_transfer_asyncifier.h | 21 - .../xla/xla/service/host_offload_legalize.h | 21 - third_party/xla/xla/service/host_offloader.h | 21 - .../xla/xla/service/host_offloading_prepare.h | 22 - .../xla/xla/service/indexed_array_analysis.h | 22 - .../xla/service/infeed_token_propagation.h | 22 - .../xla/xla/service/instruction_hoister.h | 22 - .../xla/xla/service/logical_buffer_analysis.h | 22 - .../xla/xla/service/logistic_expander.h | 22 - .../xla/service/memory_space_propagation.h | 22 - .../xla/xla/service/op_expander_pass.h | 22 - .../xla/xla/service/operand_upcaster.h | 22 - .../service/optimization_barrier_expander.h | 22 - .../optimize_input_output_buffer_alias.h | 22 - .../xla/xla/service/pattern_matcher_gmock.h | 22 - third_party/xla/xla/service/qr_expander.h | 22 - .../xla/xla/service/real_imag_expander.h | 22 - .../xla/xla/service/reduce_decomposer.h | 22 - .../xla/xla/service/reduce_window_rewriter.h | 22 - .../xla/xla/service/reshape_decomposer.h | 22 - third_party/xla/xla/service/reshape_mover.h | 22 - third_party/xla/xla/service/result_caster.h | 22 - .../xla/service/rng_bit_generator_expander.h | 22 - third_party/xla/xla/service/rng_expander.h | 22 - .../xla/xla/service/root_instruction_sinker.h | 22 - .../xla/xla/service/simplify_fp_conversions.h | 22 - third_party/xla/xla/service/slice_sinker.h | 22 - third_party/xla/xla/service/sort_simplifier.h | 22 - .../xla/xla/service/stable_sort_expander.h | 22 - .../service/stochastic_convert_decomposer.h | 22 - .../xla/xla/service/sub_byte_normalization.h | 22 - .../xla/xla/service/tree_reduction_rewriter.h | 22 - .../xla/service/tuple_points_to_analysis.h | 22 - .../xla/xla/service/tuple_simplifier.h | 22 - .../xla/xla/service/while_loop_analysis.h | 22 - .../service/while_loop_trip_count_annotator.h | 22 - .../xla/service/zero_sized_hlo_elimination.h | 22 - 95 files changed, 2730 deletions(-) delete mode 100644 third_party/xla/xla/service/add_original_value.h delete mode 100644 third_party/xla/xla/service/algebraic_simplifier.h delete mode 100644 third_party/xla/xla/service/all_gather_broadcast_reorder.h delete mode 100644 third_party/xla/xla/service/all_gather_combiner.h delete mode 100644 third_party/xla/xla/service/all_reduce_combiner.h delete mode 100644 third_party/xla/xla/service/all_reduce_contiguous.h delete mode 100644 third_party/xla/xla/service/all_reduce_folder.h delete mode 100644 third_party/xla/xla/service/ar_crs_combiner.h delete mode 100644 third_party/xla/xla/service/async_collective_creator.h delete mode 100644 third_party/xla/xla/service/batch_dot_simplification.h delete mode 100644 third_party/xla/xla/service/bfloat16_conversion_folding.h delete mode 100644 third_party/xla/xla/service/bfloat16_propagation.h delete mode 100644 third_party/xla/xla/service/bitcast_dtypes_expander.h delete mode 100644 third_party/xla/xla/service/broadcast_canonicalizer.h delete mode 100644 third_party/xla/xla/service/cholesky_expander.h delete mode 100644 third_party/xla/xla/service/collective_quantizer.h delete mode 100644 third_party/xla/xla/service/collective_transformation_reorderer.h delete mode 100644 third_party/xla/xla/service/collectives_schedule_linearizer.h delete mode 100644 third_party/xla/xla/service/comparison_expander.h delete mode 100644 
third_party/xla/xla/service/conditional_canonicalizer.h delete mode 100644 third_party/xla/xla/service/convert_async_collectives_to_sync.h delete mode 100644 third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h delete mode 100644 third_party/xla/xla/service/convert_mover.h delete mode 100644 third_party/xla/xla/service/convert_operand_folding.h delete mode 100644 third_party/xla/xla/service/convolution_4d_expander.h delete mode 100644 third_party/xla/xla/service/convolution_group_converter.h delete mode 100644 third_party/xla/xla/service/convolution_pred_expander.h delete mode 100644 third_party/xla/xla/service/defuser.h delete mode 100644 third_party/xla/xla/service/despecializer.h delete mode 100644 third_party/xla/xla/service/dot_decomposer.h delete mode 100644 third_party/xla/xla/service/dot_dimension_merger.h delete mode 100644 third_party/xla/xla/service/dot_merger.h delete mode 100644 third_party/xla/xla/service/dynamic_dimension_simplifier.h delete mode 100644 third_party/xla/xla/service/dynamic_index_splitter.h delete mode 100644 third_party/xla/xla/service/eigh_expander.h delete mode 100644 third_party/xla/xla/service/flatten_call_graph.h delete mode 100644 third_party/xla/xla/service/float_normalization.h delete mode 100644 third_party/xla/xla/service/fusion_constant_sinking.h delete mode 100644 third_party/xla/xla/service/gather_simplifier.h delete mode 100644 third_party/xla/xla/service/hlo_alias_analysis.h delete mode 100644 third_party/xla/xla/service/hlo_computation_deduplicator.h delete mode 100644 third_party/xla/xla/service/hlo_constant_folding.h delete mode 100644 third_party/xla/xla/service/hlo_dataflow_analysis.h delete mode 100644 third_party/xla/xla/service/hlo_dce.h delete mode 100644 third_party/xla/xla/service/hlo_element_type_converter.h delete mode 100644 third_party/xla/xla/service/hlo_lexer.h delete mode 100644 third_party/xla/xla/service/hlo_liveness_analysis.h delete mode 100644 third_party/xla/xla/service/hlo_memory_scheduler.h delete mode 100644 third_party/xla/xla/service/hlo_ordering.h delete mode 100644 third_party/xla/xla/service/hlo_parser.h delete mode 100644 third_party/xla/xla/service/hlo_pass_fix.h delete mode 100644 third_party/xla/xla/service/hlo_pass_interface.h delete mode 100644 third_party/xla/xla/service/hlo_pass_pipeline.h delete mode 100644 third_party/xla/xla/service/hlo_rematerialization.h delete mode 100644 third_party/xla/xla/service/hlo_rematerialization_test_utils.h delete mode 100644 third_party/xla/xla/service/hlo_replication_analysis.h delete mode 100644 third_party/xla/xla/service/hlo_value_semantics_analysis.h delete mode 100644 third_party/xla/xla/service/host_memory_transfer_asyncifier.h delete mode 100644 third_party/xla/xla/service/host_offload_legalize.h delete mode 100644 third_party/xla/xla/service/host_offloader.h delete mode 100644 third_party/xla/xla/service/host_offloading_prepare.h delete mode 100644 third_party/xla/xla/service/indexed_array_analysis.h delete mode 100644 third_party/xla/xla/service/infeed_token_propagation.h delete mode 100644 third_party/xla/xla/service/instruction_hoister.h delete mode 100644 third_party/xla/xla/service/logical_buffer_analysis.h delete mode 100644 third_party/xla/xla/service/logistic_expander.h delete mode 100644 third_party/xla/xla/service/memory_space_propagation.h delete mode 100644 third_party/xla/xla/service/op_expander_pass.h delete mode 100644 third_party/xla/xla/service/operand_upcaster.h delete mode 100644 
third_party/xla/xla/service/optimization_barrier_expander.h delete mode 100644 third_party/xla/xla/service/optimize_input_output_buffer_alias.h delete mode 100644 third_party/xla/xla/service/pattern_matcher_gmock.h delete mode 100644 third_party/xla/xla/service/qr_expander.h delete mode 100644 third_party/xla/xla/service/real_imag_expander.h delete mode 100644 third_party/xla/xla/service/reduce_decomposer.h delete mode 100644 third_party/xla/xla/service/reduce_window_rewriter.h delete mode 100644 third_party/xla/xla/service/reshape_decomposer.h delete mode 100644 third_party/xla/xla/service/reshape_mover.h delete mode 100644 third_party/xla/xla/service/result_caster.h delete mode 100644 third_party/xla/xla/service/rng_bit_generator_expander.h delete mode 100644 third_party/xla/xla/service/rng_expander.h delete mode 100644 third_party/xla/xla/service/root_instruction_sinker.h delete mode 100644 third_party/xla/xla/service/simplify_fp_conversions.h delete mode 100644 third_party/xla/xla/service/slice_sinker.h delete mode 100644 third_party/xla/xla/service/sort_simplifier.h delete mode 100644 third_party/xla/xla/service/stable_sort_expander.h delete mode 100644 third_party/xla/xla/service/stochastic_convert_decomposer.h delete mode 100644 third_party/xla/xla/service/sub_byte_normalization.h delete mode 100644 third_party/xla/xla/service/tree_reduction_rewriter.h delete mode 100644 third_party/xla/xla/service/tuple_points_to_analysis.h delete mode 100644 third_party/xla/xla/service/tuple_simplifier.h delete mode 100644 third_party/xla/xla/service/while_loop_analysis.h delete mode 100644 third_party/xla/xla/service/while_loop_trip_count_annotator.h delete mode 100644 third_party/xla/xla/service/zero_sized_hlo_elimination.h diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 1992a34ce73a8c..0b01456d54b023 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -109,13 +109,6 @@ cc_library( ], ) -cc_library( - name = "async_collective_creator", - hdrs = ["async_collective_creator.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:async_collective_creator instead.", - deps = ["//xla/hlo/transforms/collectives:async_collective_creator"], -) - cc_library( name = "all_reduce_key", srcs = ["all_reduce_key.cc"], @@ -207,13 +200,6 @@ xla_cc_test( ], ) -cc_library( - name = "all_reduce_folder", - hdrs = ["all_reduce_folder.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:all_reduce_folder instead.", - deps = ["//xla/hlo/transforms/simplifiers:all_reduce_folder"], -) - cc_library( name = "float_support", srcs = ["float_support.cc"], @@ -224,34 +210,6 @@ cc_library( ], ) -cc_library( - name = "broadcast_canonicalizer", - hdrs = ["broadcast_canonicalizer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:broadcast_canonicalizer instead.", - deps = ["//xla/hlo/transforms/simplifiers:broadcast_canonicalizer"], -) - -cc_library( - name = "bfloat16_conversion_folding", - hdrs = ["bfloat16_conversion_folding.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:bfloat16_conversion_folding instead.", - deps = ["//xla/hlo/transforms/simplifiers:bfloat16_conversion_folding"], -) - -cc_library( - name = "float_normalization", - hdrs = ["float_normalization.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:float_normalization instead.", - deps = ["//xla/hlo/transforms/simplifiers:float_normalization"], -) - -cc_library( - name = "bfloat16_propagation", - hdrs = ["bfloat16_propagation.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:bfloat16_propagation instead.", - deps = ["//xla/hlo/transforms:bfloat16_propagation"], -) - cc_library( name = "source_target_pairs", hdrs = ["source_target_pairs.h"], @@ -410,13 +368,6 @@ xla_cc_test( ], ) -cc_library( - name = "convert_async_collectives_to_sync", - hdrs = ["convert_async_collectives_to_sync.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:convert_async_collectives_to_sync instead.", - deps = ["//xla/hlo/transforms/collectives:convert_async_collectives_to_sync"], -) - cc_library( name = "value_range", srcs = ["value_range.cc"], @@ -524,13 +475,6 @@ xla_cc_test( ], ) -cc_library( - name = "collective_quantizer", - hdrs = ["collective_quantizer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:collective_quantizer instead.", - deps = ["//xla/hlo/transforms/collectives:collective_quantizer"], -) - cc_library( name = "dump", srcs = ["dump.cc"], @@ -878,14 +822,6 @@ xla_cc_test( ], ) -cc_library( - name = "pattern_matcher_gmock", - testonly = 1, - hdrs = ["pattern_matcher_gmock.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/testlib:pattern_matcher_gmock instead.", - deps = ["//xla/hlo/testlib:pattern_matcher_gmock"], -) - xla_cc_test( name = "pattern_matcher_gmock_test", srcs = ["pattern_matcher_gmock_test.cc"], @@ -1017,13 +953,6 @@ xla_cc_test( ], ) -cc_library( - name = "flatten_call_graph", - hdrs = ["flatten_call_graph.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:flatten_call_graph instead.", - deps = ["//xla/hlo/transforms/simplifiers:flatten_call_graph"], -) - cc_library( name = "call_inliner", srcs = ["call_inliner.cc"], @@ -1076,13 +1005,6 @@ xla_cc_test( ], ) -cc_library( - name = "hlo_computation_deduplicator", - hdrs = ["hlo_computation_deduplicator.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_computation_deduplicator instead.", - deps = ["//xla/hlo/transforms/simplifiers:hlo_computation_deduplicator"], -) - cc_library( name = "platform_util", srcs = ["platform_util.cc"], @@ -1853,13 +1775,6 @@ xla_cc_test( ], ) -cc_library( - name = "hlo_ordering", - hdrs = ["hlo_ordering.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_ordering instead.", - deps = ["//xla/hlo/analysis:hlo_ordering"], -) - xla_cc_test( name = "hlo_module_group_test", srcs = ["hlo_module_group_test.cc"], @@ -1989,14 +1904,6 @@ xla_cc_test( ], ) -cc_library( - name = "hlo_memory_scheduler", - hdrs = ["hlo_memory_scheduler.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_memory_scheduler instead.", - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = ["//xla/hlo/transforms/simplifiers:hlo_memory_scheduler"], -) - cc_library( name = "fusion_queue", hdrs = ["fusion_queue.h"], @@ -2186,13 +2093,6 @@ cc_library( ], ) -cc_library( - name = "op_expander_pass", - hdrs = ["op_expander_pass.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:op_expander_pass instead.", - deps = ["//xla/hlo/transforms/expanders:op_expander_pass"], -) - cc_library( name = "gather_expander", srcs = ["gather_expander.cc"], @@ -2212,20 +2112,6 @@ cc_library( ], ) -cc_library( - name = "optimization_barrier_expander", - hdrs = ["optimization_barrier_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:optimization_barrier_expander instead.", - deps = ["//xla/hlo/transforms/expanders:optimization_barrier_expander"], -) - -cc_library( - name = "comparison_expander", - hdrs = ["comparison_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:comparison_expander instead.", - deps = ["//xla/hlo/transforms/expanders:comparison_expander"], -) - cc_library( name = "scatter_utils", srcs = ["scatter_utils.cc"], @@ -2370,48 +2256,6 @@ xla_cc_test( ], ) -cc_library( - name = "cholesky_expander", - hdrs = ["cholesky_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:cholesky_expander instead.", - deps = ["//xla/hlo/transforms/expanders:cholesky_expander"], -) - -cc_library( - name = "qr_expander", - hdrs = ["qr_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:qr_expander instead.", - deps = ["//xla/hlo/transforms/expanders:qr_expander"], -) - -cc_library( - name = "real_imag_expander", - hdrs = ["real_imag_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:real_imag_expander instead.", - deps = ["//xla/hlo/transforms/expanders:real_imag_expander"], -) - -cc_library( - name = "eigh_expander", - hdrs = ["eigh_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:eigh_expander instead.", - deps = ["//xla/hlo/transforms/expanders:eigh_expander"], -) - -cc_library( - name = "convolution_4d_expander", - hdrs = ["convolution_4d_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:convolution_4d_expander instead.", - deps = ["//xla/hlo/transforms/expanders:convolution_4d_expander"], -) - -cc_library( - name = "convolution_pred_expander", - hdrs = ["convolution_pred_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:convolution_pred_expander instead.", - deps = ["//xla/hlo/transforms/expanders:convolution_pred_expander"], -) - xla_test( name = "batchnorm_expander_test", size = "small", @@ -2437,21 +2281,6 @@ xla_test( ], ) -cc_library( - name = "algebraic_simplifier", - hdrs = ["algebraic_simplifier.h"], - copts = tsl_copts(), - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:algebraic_simplifier instead.", - deps = ["//xla/hlo/transforms/simplifiers:algebraic_simplifier"], -) - -cc_library( - name = "tree_reduction_rewriter", - hdrs = ["tree_reduction_rewriter.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:tree_reduction_rewriter instead.", - deps = ["//xla/hlo/transforms/simplifiers:tree_reduction_rewriter"], -) - xla_test( name = "algebraic_simplifier_overflow_test", srcs = ["algebraic_simplifier_overflow_test.cc"], @@ -2463,27 +2292,6 @@ xla_test( ], ) -cc_library( - name = "simplify_fp_conversions", - hdrs = ["simplify_fp_conversions.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:simplify_fp_conversions instead.", - deps = ["//xla/hlo/transforms/simplifiers:simplify_fp_conversions"], -) - -cc_library( - name = "logistic_expander", - hdrs = ["logistic_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:logistic_expander instead.", - deps = ["//xla/hlo/transforms/expanders:logistic_expander"], -) - -cc_library( - name = "collectives_schedule_linearizer", - hdrs = ["collectives_schedule_linearizer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:collectives_schedule_linearizer instead.", - deps = ["//xla/hlo/transforms/collectives:collectives_schedule_linearizer"], -) - cc_library( name = "collective_combiner_utils", hdrs = ["collective_combiner_utils.h"], @@ -2518,41 +2326,6 @@ cc_library( ], ) -cc_library( - name = "all_gather_broadcast_reorder", - hdrs = ["all_gather_broadcast_reorder.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:all_gather_broadcast_reorder instead.", - deps = ["//xla/hlo/transforms/collectives:all_gather_broadcast_reorder"], -) - -cc_library( - name = "bitcast_dtypes_expander", - hdrs = ["bitcast_dtypes_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:bitcast_dtypes_expander instead.", - deps = ["//xla/hlo/transforms/expanders:bitcast_dtypes_expander"], -) - -cc_library( - name = "all_gather_combiner", - hdrs = ["all_gather_combiner.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:all_gather_combiner instead.", - deps = ["//xla/hlo/transforms/collectives:all_gather_combiner"], -) - -cc_library( - name = "all_reduce_combiner", - hdrs = ["all_reduce_combiner.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:all_reduce_combiner instead.", - deps = ["//xla/hlo/transforms/collectives:all_reduce_combiner"], -) - -cc_library( - name = "all_reduce_contiguous", - hdrs = ["all_reduce_contiguous.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:all_reduce_contiguous instead.", - deps = ["//xla/hlo/transforms/collectives:all_reduce_contiguous"], -) - cc_library( name = "reduce_scatter_combiner", srcs = ["reduce_scatter_combiner.cc"], @@ -2735,13 +2508,6 @@ xla_cc_test( ], ) -cc_library( - name = "batch_dot_simplification", - hdrs = ["batch_dot_simplification.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:batch_dot_simplification instead.", - deps = ["//xla/hlo/transforms/simplifiers:batch_dot_simplification"], -) - xla_cc_test( name = "gather_expander_test", srcs = ["gather_expander_test.cc"], @@ -2850,13 +2616,6 @@ xla_cc_test( ], ) -cc_library( - name = "convolution_group_converter", - hdrs = ["convolution_group_converter.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:convolution_group_converter instead.", - deps = ["//xla/hlo/transforms/simplifiers:convolution_group_converter"], -) - cc_library( name = "space_to_batch_converter", srcs = ["space_to_batch_converter.cc"], @@ -3057,13 +2816,6 @@ xla_cc_test( ], ) -cc_library( - name = "while_loop_analysis", - hdrs = ["while_loop_analysis.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/analysis:while_loop_analysis instead.", - deps = ["//xla/hlo/analysis:while_loop_analysis"], -) - cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -3121,48 +2873,6 @@ xla_cc_test( ], ) -cc_library( - name = "while_loop_trip_count_annotator", - hdrs = ["while_loop_trip_count_annotator.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:while_loop_trip_count_annotator instead.", - deps = ["//xla/hlo/transforms:while_loop_trip_count_annotator"], -) - -cc_library( - name = "defuser", - hdrs = ["defuser.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:defuser instead.", - deps = ["//xla/hlo/transforms:defuser"], -) - -cc_library( - name = "dot_decomposer", - hdrs = ["dot_decomposer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:dot_decomposer instead.", - deps = ["//xla/hlo/transforms/expanders:dot_decomposer"], -) - -cc_library( - name = "dot_dimension_merger", - hdrs = ["dot_dimension_merger.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:dot_dimension_merger instead.", - deps = ["//xla/hlo/transforms/simplifiers:dot_dimension_merger"], -) - -cc_library( - name = "dot_merger", - hdrs = ["dot_merger.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:dot_merger instead.", - deps = ["//xla/hlo/transforms/simplifiers:dot_merger"], -) - -cc_library( - name = "convert_mover", - hdrs = ["convert_mover.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:convert_mover instead.", - deps = ["//xla/hlo/transforms/simplifiers:convert_mover"], -) - cc_library( name = "all_to_all_decomposer", srcs = ["all_to_all_decomposer.cc"], @@ -3215,34 +2925,6 @@ xla_cc_test( ], ) -cc_library( - name = "tuple_simplifier", - hdrs = ["tuple_simplifier.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:tuple_simplifier instead.", - deps = ["//xla/hlo/transforms/simplifiers:tuple_simplifier"], -) - -cc_library( - name = "reshape_mover", - hdrs = ["reshape_mover.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:reshape_mover instead.", - deps = ["//xla/hlo/transforms/simplifiers:reshape_mover"], -) - -cc_library( - name = "reshape_decomposer", - hdrs = ["reshape_decomposer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:reshape_decomposer instead.", - deps = ["//xla/hlo/transforms/expanders:reshape_decomposer"], -) - -cc_library( - name = "reduce_decomposer", - hdrs = ["reduce_decomposer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:reduce_decomposer instead.", - deps = ["//xla/hlo/transforms/expanders:reduce_decomposer"], -) - cc_library( name = "dynamic_window_utils", srcs = ["dynamic_window_utils.cc"], @@ -3298,13 +2980,6 @@ cc_library( ], ) -cc_library( - name = "dynamic_dimension_simplifier", - hdrs = ["dynamic_dimension_simplifier.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier instead.", - deps = ["//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier"], -) - cc_library( name = "dynamic_padder", srcs = ["dynamic_padder.cc"], @@ -3802,27 +3477,6 @@ xla_cc_test( ], ) -cc_library( - name = "hlo_value_semantics_analysis", - hdrs = ["hlo_value_semantics_analysis.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_value_semantics_analysis instead.", - deps = ["//xla/hlo/analysis:hlo_value_semantics_analysis"], -) - -cc_library( - name = "hlo_replication_analysis", - hdrs = ["hlo_replication_analysis.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_replication_analysis instead.", - deps = ["//xla/hlo/analysis:hlo_replication_analysis"], -) - -cc_library( - name = "hlo_liveness_analysis", - hdrs = ["hlo_liveness_analysis.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_liveness_analysis instead.", - deps = ["//xla/hlo/analysis:hlo_liveness_analysis"], -) - cc_library( name = "hlo_buffer", srcs = ["hlo_buffer.cc"], @@ -3842,27 +3496,6 @@ cc_library( ], ) -cc_library( - name = "hlo_alias_analysis", - hdrs = ["hlo_alias_analysis.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_alias_analysis instead.", - deps = ["//xla/hlo/analysis:hlo_alias_analysis"], -) - -cc_library( - name = "logical_buffer_analysis", - hdrs = ["logical_buffer_analysis.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/analysis:logical_buffer_analysis instead.", - deps = ["//xla/hlo/analysis:logical_buffer_analysis"], -) - -cc_library( - name = "tuple_points_to_analysis", - hdrs = ["tuple_points_to_analysis.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/analysis:tuple_points_to_analysis instead.", - deps = ["//xla/hlo/analysis:tuple_points_to_analysis"], -) - cc_library( name = "compilation_cache", srcs = ["compilation_cache.cc"], @@ -4042,20 +3675,6 @@ xla_cc_test( ], ) -cc_library( - name = "memory_space_propagation", - hdrs = ["memory_space_propagation.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:memory_space_propagation instead.", - deps = ["//xla/hlo/transforms:memory_space_propagation"], -) - -cc_library( - name = "hlo_dce", - hdrs = ["hlo_dce.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_dce instead.", - deps = ["//xla/hlo/transforms/simplifiers:hlo_dce"], -) - cc_library( name = "hlo_module_dce", srcs = ["hlo_module_dce.cc"], @@ -4171,21 +3790,6 @@ xla_cc_test( ], ) -cc_library( - name = "hlo_rematerialization", - hdrs = ["hlo_rematerialization.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_rematerialization instead.", - deps = ["//xla/hlo/transforms/simplifiers:hlo_rematerialization"], -) - -cc_library( - name = "hlo_rematerialization_test_utils", - testonly = 1, - hdrs = ["hlo_rematerialization_test_utils.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_rematerialization_test_utils instead.", - deps = ["//xla/hlo/transforms/simplifiers:hlo_rematerialization_test_utils"], -) - xla_cc_test( name = "hlo_module_dce_test", srcs = ["hlo_module_dce_test.cc"], @@ -4235,40 +3839,6 @@ xla_cc_test( ], ) -cc_library( - name = "hlo_pass", - hdrs = [ - "hlo_pass_fix.h", - "hlo_pass_interface.h", - ], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/pass:hlo_pass instead.", - deps = [ - "//xla:status_macros", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_module_group", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "hlo_pass_pipeline", - hdrs = ["hlo_pass_pipeline.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/pass:hlo_pass_pipeline instead.", - deps = [ - ":compilation_stats", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/pass:hlo_pass_pipeline", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "hlo_cse", srcs = ["hlo_cse.cc"], @@ -4396,20 +3966,6 @@ xla_cc_test( ], ) -cc_library( - name = "hlo_element_type_converter", - hdrs = ["hlo_element_type_converter.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_element_type_converter instead.", - deps = ["//xla/hlo/transforms/simplifiers:hlo_element_type_converter"], -) - -cc_library( - name = "conditional_canonicalizer", - hdrs = ["conditional_canonicalizer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:conditional_canonicalizer instead.", - deps = ["//xla/hlo/transforms/simplifiers:conditional_canonicalizer"], -) - cc_library( name = "maybe_owning_device_memory", srcs = [ @@ -4677,13 +4233,6 @@ xla_cc_test( ], ) -cc_library( - name = "zero_sized_hlo_elimination", - hdrs = ["zero_sized_hlo_elimination.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination instead.", - deps = ["//xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination"], -) - cc_library( name = "stream_pool", srcs = ["stream_pool.cc"], @@ -4859,20 +4408,6 @@ cc_library( ], ) -cc_library( - name = "sort_simplifier", - hdrs = ["sort_simplifier.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:sort_simplifier instead.", - deps = ["//xla/hlo/transforms/simplifiers:sort_simplifier"], -) - -cc_library( - name = "stable_sort_expander", - hdrs = ["stable_sort_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:stable_sort_expander instead.", - deps = ["//xla/hlo/transforms/expanders:stable_sort_expander"], -) - cc_library( name = "tuple_util", srcs = ["tuple_util.cc"], @@ -4911,13 +4446,6 @@ xla_cc_test( ], ) -cc_library( - name = "root_instruction_sinker", - hdrs = ["root_instruction_sinker.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:root_instruction_sinker instead.", - deps = ["//xla/hlo/transforms/simplifiers:root_instruction_sinker"], -) - cc_library( name = "memory_annotations_hdr", hdrs = ["memory_annotations.h"], @@ -4926,27 +4454,6 @@ cc_library( ], ) -cc_library( - name = "convert_memory_placement_to_internal_annotations", - hdrs = ["convert_memory_placement_to_internal_annotations.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convert_memory_placement_to_internal_annotations instead.", - deps = ["//xla/hlo/transforms:convert_memory_placement_to_internal_annotations"], -) - -cc_library( - name = "host_memory_transfer_asyncifier", - hdrs = ["host_memory_transfer_asyncifier.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:host_memory_transfer_asyncifier instead.", - deps = ["//xla/hlo/transforms/simplifiers:host_memory_transfer_asyncifier"], -) - -cc_library( - name = "host_offload_legalize", - hdrs = ["host_offload_legalize.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:host_offload_legalize instead.", - deps = ["//xla/hlo/transforms:host_offload_legalize"], -) - cc_library( name = "host_offload_utils", srcs = ["host_offload_utils.cc"], @@ -5000,20 +4507,6 @@ xla_cc_test( ], ) -cc_library( - name = "host_offloader", - hdrs = ["host_offloader.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:host_offloader instead.", - deps = ["//xla/hlo/transforms:host_offloader"], -) - -cc_library( - name = "host_offloading_prepare", - hdrs = ["host_offloading_prepare.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:host_offloading_prepare instead.", - deps = ["//xla/hlo/transforms:host_offloading_prepare"], -) - cc_library( name = "while_util", srcs = ["while_util.cc"], @@ -5233,13 +4726,6 @@ xla_cc_test( ], ) -cc_library( - name = "fusion_constant_sinking", - hdrs = ["fusion_constant_sinking.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:fusion_constant_sinking instead.", - deps = ["//xla/hlo/transforms/simplifiers:fusion_constant_sinking"], -) - cc_library( name = "while_loop_constant_sinking", srcs = ["while_loop_constant_sinking.cc"], @@ -5322,13 +4808,6 @@ xla_cc_test( ], ) -cc_library( - name = "despecializer", - hdrs = ["despecializer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:despecializer instead.", - deps = ["//xla/hlo/transforms:despecializer"], -) - cc_library( name = "source_map_util", srcs = [], @@ -5340,33 +4819,6 @@ cc_library( ], ) -cc_library( - name = "indexed_array_analysis", - hdrs = ["indexed_array_analysis.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/analysis:indexed_array_analysis instead.", - deps = ["//xla/hlo/analysis:indexed_array_analysis"], -) - -cc_library( - name = "hlo_parser", - hdrs = ["hlo_parser.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/parser:hlo_parser instead.", - deps = [ - "//xla/hlo/parser:hlo_parser", - ], -) - -cc_library( - name = "hlo_lexer", - hdrs = [ - "hlo_lexer.h", - ], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/parser:hlo_lexer instead.", - deps = [ - "//xla/hlo/parser:hlo_lexer", - ], -) - cc_library( name = "map_inliner", srcs = ["map_inliner.cc"], @@ -5387,20 +4839,6 @@ cc_library( ], ) -cc_library( - name = "optimize_input_output_buffer_alias", - hdrs = ["optimize_input_output_buffer_alias.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:optimize_input_output_buffer_alias instead.", - deps = ["//xla/hlo/transforms/simplifiers:optimize_input_output_buffer_alias"], -) - -cc_library( - name = "ar_crs_combiner", - hdrs = ["ar_crs_combiner.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:ar_crs_combiner instead.", - deps = ["//xla/hlo/transforms/simplifiers:ar_crs_combiner"], -) - cc_library( name = "compilation_stats", srcs = ["compilation_stats.cc"], @@ -5414,13 +4852,6 @@ cc_library( ], ) -cc_library( - name = "dynamic_index_splitter", - hdrs = ["dynamic_index_splitter.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:dynamic_index_splitter instead.", - deps = ["//xla/hlo/transforms/expanders:dynamic_index_splitter"], -) - xla_cc_test( name = "map_inliner_test", srcs = ["map_inliner_test.cc"], @@ -5482,13 +4913,6 @@ xla_cc_test( ], ) -cc_library( - name = "slice_sinker", - hdrs = ["slice_sinker.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:slice_sinker instead.", - deps = ["//xla/hlo/transforms/simplifiers:slice_sinker"], -) - cc_library( name = "custom_call_target_registry", srcs = ["custom_call_target_registry.cc"], @@ -5586,20 +5010,6 @@ cc_library( deps = [":custom_call_status"], ) -cc_library( - name = "rng_expander", - hdrs = ["rng_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:rng_expander instead.", - deps = ["//xla/hlo/transforms/expanders:rng_expander"], -) - -cc_library( - name = "rng_bit_generator_expander", - hdrs = ["rng_bit_generator_expander.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:rng_bit_generator_expander instead.", - deps = ["//xla/hlo/transforms/expanders:rng_bit_generator_expander"], -) - cc_library( name = "slow_operation_alarm", srcs = ["slow_operation_alarm.cc"], @@ -5664,13 +5074,6 @@ cc_library( ], ) -cc_library( - name = "collective_transformation_reorderer", - hdrs = ["collective_transformation_reorderer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:collective_transformation_reorderer instead.", - deps = ["//xla/hlo/transforms/collectives:collective_transformation_reorderer"], -) - xla_cc_test( name = "collective_ops_utils_test", srcs = ["collective_ops_utils_test.cc"], @@ -5746,13 +5149,6 @@ cc_library( deps = ["//xla/hlo/transforms:operand_upcaster"], ) -cc_library( - name = "result_caster", - hdrs = ["result_caster.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:result_caster instead.", - deps = ["//xla/hlo/transforms/simplifiers:result_caster"], -) - cc_library( name = "global_device_id", srcs = ["global_device_id.cc"], @@ -5765,13 +5161,6 @@ cc_library( ], ) -cc_library( - name = "convert_operand_folding", - hdrs = ["convert_operand_folding.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:convert_operand_folding instead.", - deps = ["//xla/hlo/transforms/simplifiers:convert_operand_folding"], -) - cc_library( name = "xla_debug_info_manager", srcs = [ @@ -6001,13 +5390,6 @@ cc_library( ], ) -cc_library( - name = "instruction_hoister", - hdrs = ["instruction_hoister.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:instruction_hoister instead.", - deps = ["//xla/hlo/transforms/simplifiers:instruction_hoister"], -) - cc_library( name = "scatter_simplifier", srcs = ["scatter_simplifier.cc"], @@ -6123,13 +5505,6 @@ cc_library( ], ) -cc_library( - name = "gather_simplifier", - hdrs = ["gather_simplifier.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:gather_simplifier instead.", - deps = ["//xla/hlo/transforms/simplifiers:gather_simplifier"], -) - cc_library( name = "batched_gather_scatter_normalizer", srcs = ["batched_gather_scatter_normalizer.cc"], @@ -6150,20 +5525,6 @@ cc_library( ], ) -cc_library( - name = "reduce_window_rewriter", - hdrs = ["reduce_window_rewriter.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:reduce_window_rewriter instead.", - deps = ["//xla/hlo/transforms/simplifiers:reduce_window_rewriter"], -) - -cc_library( - name = "stochastic_convert_decomposer", - hdrs = ["stochastic_convert_decomposer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/expanders:stochastic_convert_decomposer instead.", - deps = ["//xla/hlo/transforms/expanders:stochastic_convert_decomposer"], -) - cc_library( name = "metrics_hook_interface", hdrs = ["metrics_hook_interface.h"], @@ -6175,13 +5536,6 @@ cc_library( ], ) -cc_library( - name = "sub_byte_normalization", - hdrs = ["sub_byte_normalization.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:sub_byte_normalization instead.", - deps = ["//xla/hlo/transforms/simplifiers:sub_byte_normalization"], -) - xla_cc_test( name = "batched_gather_scatter_normalizer_test", srcs = ["batched_gather_scatter_normalizer_test.cc"], @@ -6493,13 +5847,6 @@ cc_library( ], ) -cc_library( - name = "add_original_value", - hdrs = ["add_original_value.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms:add_original_value instead.", - deps = ["//xla/hlo/transforms:add_original_value"], -) - xla_cc_test( name = "propagate_original_value_test", srcs = ["propagate_original_value_test.cc"], @@ -6512,13 +5859,6 @@ xla_cc_test( ], ) -cc_library( - name = "infeed_token_propagation", - hdrs = ["infeed_token_propagation.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/transforms/collectives:infeed_token_propagation instead.", - deps = ["//xla/hlo/transforms/collectives:infeed_token_propagation"], -) - cc_library( name = "while_loop_pipeline_unroller", srcs = ["while_loop_pipeline_unroller.cc"], diff --git a/third_party/xla/xla/service/add_original_value.h b/third_party/xla/xla/service/add_original_value.h deleted file mode 100644 index 2a68cca88b0e94..00000000000000 --- a/third_party/xla/xla/service/add_original_value.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ -#define XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/add_original_value.h" - -#endif // XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ diff --git a/third_party/xla/xla/service/algebraic_simplifier.h b/third_party/xla/xla/service/algebraic_simplifier.h deleted file mode 100644 index 82fc943903041e..00000000000000 --- a/third_party/xla/xla/service/algebraic_simplifier.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ -#define XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" - -#endif // XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/all_gather_broadcast_reorder.h b/third_party/xla/xla/service/all_gather_broadcast_reorder.h deleted file mode 100644 index ce722207a37a62..00000000000000 --- a/third_party/xla/xla/service/all_gather_broadcast_reorder.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ALL_GATHER_BROADCAST_REORDER_H_ -#define XLA_SERVICE_ALL_GATHER_BROADCAST_REORDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/collectives/all_gather_broadcast_reorder.h" - -#endif // XLA_SERVICE_ALL_GATHER_BROADCAST_REORDER_H_ diff --git a/third_party/xla/xla/service/all_gather_combiner.h b/third_party/xla/xla/service/all_gather_combiner.h deleted file mode 100644 index 9c7029207c6c3d..00000000000000 --- a/third_party/xla/xla/service/all_gather_combiner.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ALL_GATHER_COMBINER_H_ -#define XLA_SERVICE_ALL_GATHER_COMBINER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/collectives/all_gather_combiner.h" - -#endif // XLA_SERVICE_ALL_GATHER_COMBINER_H_ diff --git a/third_party/xla/xla/service/all_reduce_combiner.h b/third_party/xla/xla/service/all_reduce_combiner.h deleted file mode 100644 index f0f3a200f22f1f..00000000000000 --- a/third_party/xla/xla/service/all_reduce_combiner.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_ALL_REDUCE_COMBINER_H_ -#define XLA_SERVICE_ALL_REDUCE_COMBINER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/collectives/all_reduce_combiner.h" - -#endif // XLA_SERVICE_ALL_REDUCE_COMBINER_H_ diff --git a/third_party/xla/xla/service/all_reduce_contiguous.h b/third_party/xla/xla/service/all_reduce_contiguous.h deleted file mode 100644 index 7dc1a6501259d4..00000000000000 --- a/third_party/xla/xla/service/all_reduce_contiguous.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ALL_REDUCE_CONTIGUOUS_H_ -#define XLA_SERVICE_ALL_REDUCE_CONTIGUOUS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/collectives/all_reduce_contiguous.h" - -#endif // XLA_SERVICE_ALL_REDUCE_CONTIGUOUS_H_ diff --git a/third_party/xla/xla/service/all_reduce_folder.h b/third_party/xla/xla/service/all_reduce_folder.h deleted file mode 100644 index 6054de621c1d03..00000000000000 --- a/third_party/xla/xla/service/all_reduce_folder.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ALL_REDUCE_FOLDER_H_ -#define XLA_SERVICE_ALL_REDUCE_FOLDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/all_reduce_folder.h" - -#endif // XLA_SERVICE_ALL_REDUCE_FOLDER_H_ diff --git a/third_party/xla/xla/service/ar_crs_combiner.h b/third_party/xla/xla/service/ar_crs_combiner.h deleted file mode 100644 index 57b36ee2b1599d..00000000000000 --- a/third_party/xla/xla/service/ar_crs_combiner.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_AR_CRS_COMBINER_H_ -#define XLA_SERVICE_AR_CRS_COMBINER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/ar_crs_combiner.h" - -#endif // XLA_SERVICE_AR_CRS_COMBINER_H_ diff --git a/third_party/xla/xla/service/async_collective_creator.h b/third_party/xla/xla/service/async_collective_creator.h deleted file mode 100644 index f3141f50ece42a..00000000000000 --- a/third_party/xla/xla/service/async_collective_creator.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ASYNC_COLLECTIVE_CREATOR_H_ -#define XLA_SERVICE_ASYNC_COLLECTIVE_CREATOR_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/collectives/async_collective_creator.h" - -#endif // XLA_SERVICE_ASYNC_COLLECTIVE_CREATOR_H_ diff --git a/third_party/xla/xla/service/batch_dot_simplification.h b/third_party/xla/xla/service/batch_dot_simplification.h deleted file mode 100644 index 381b67955adf09..00000000000000 --- a/third_party/xla/xla/service/batch_dot_simplification.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ -#define XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/batch_dot_simplification.h" - -#endif // XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ diff --git a/third_party/xla/xla/service/bfloat16_conversion_folding.h b/third_party/xla/xla/service/bfloat16_conversion_folding.h deleted file mode 100644 index deb5675fc85cbe..00000000000000 --- a/third_party/xla/xla/service/bfloat16_conversion_folding.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ -#define XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.h" - -#endif // XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ diff --git a/third_party/xla/xla/service/bfloat16_propagation.h b/third_party/xla/xla/service/bfloat16_propagation.h deleted file mode 100644 index e3a0e0fab40b4c..00000000000000 --- a/third_party/xla/xla/service/bfloat16_propagation.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_BFLOAT16_PROPAGATION_H_ -#define XLA_SERVICE_BFLOAT16_PROPAGATION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/bfloat16_propagation.h" - -#endif // XLA_SERVICE_BFLOAT16_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander.h b/third_party/xla/xla/service/bitcast_dtypes_expander.h deleted file mode 100644 index 7824af39cf5829..00000000000000 --- a/third_party/xla/xla/service/bitcast_dtypes_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_BITCAST_DTYPES_EXPANDER_H_ -#define XLA_SERVICE_BITCAST_DTYPES_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" - -#endif // XLA_SERVICE_BITCAST_DTYPES_EXPANDER_H_ diff --git a/third_party/xla/xla/service/broadcast_canonicalizer.h b/third_party/xla/xla/service/broadcast_canonicalizer.h deleted file mode 100644 index efedf3ed3481ab..00000000000000 --- a/third_party/xla/xla/service/broadcast_canonicalizer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_BROADCAST_CANONICALIZER_H_ -#define XLA_SERVICE_BROADCAST_CANONICALIZER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h" - -#endif // XLA_SERVICE_BROADCAST_CANONICALIZER_H_ diff --git a/third_party/xla/xla/service/cholesky_expander.h b/third_party/xla/xla/service/cholesky_expander.h deleted file mode 100644 index 7e9e7332e917f0..00000000000000 --- a/third_party/xla/xla/service/cholesky_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CHOLESKY_EXPANDER_H_ -#define XLA_SERVICE_CHOLESKY_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/cholesky_expander.h" - -#endif // XLA_SERVICE_CHOLESKY_EXPANDER_H_ diff --git a/third_party/xla/xla/service/collective_quantizer.h b/third_party/xla/xla/service/collective_quantizer.h deleted file mode 100644 index b63a3138b91e0b..00000000000000 --- a/third_party/xla/xla/service/collective_quantizer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_COLLECTIVE_QUANTIZER_H_ -#define XLA_SERVICE_COLLECTIVE_QUANTIZER_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/transforms/collectives/collective_quantizer.h" - -#endif // XLA_SERVICE_COLLECTIVE_QUANTIZER_H_ diff --git a/third_party/xla/xla/service/collective_transformation_reorderer.h b/third_party/xla/xla/service/collective_transformation_reorderer.h deleted file mode 100644 index 2bbae612c5e4c9..00000000000000 --- a/third_party/xla/xla/service/collective_transformation_reorderer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_COLLECTIVE_TRANSFORMATION_REORDERER_H_ -#define XLA_SERVICE_COLLECTIVE_TRANSFORMATION_REORDERER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/collectives/collective_transformation_reorderer.h" - -#endif // XLA_SERVICE_COLLECTIVE_TRANSFORMATION_REORDERER_H_ diff --git a/third_party/xla/xla/service/collectives_schedule_linearizer.h b/third_party/xla/xla/service/collectives_schedule_linearizer.h deleted file mode 100644 index 27f0de0032e2fa..00000000000000 --- a/third_party/xla/xla/service/collectives_schedule_linearizer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_ -#define XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/collectives/collectives_schedule_linearizer.h" - -#endif // XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_ diff --git a/third_party/xla/xla/service/comparison_expander.h b/third_party/xla/xla/service/comparison_expander.h deleted file mode 100644 index 333375478e59b7..00000000000000 --- a/third_party/xla/xla/service/comparison_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_COMPARISON_EXPANDER_H_ -#define XLA_SERVICE_COMPARISON_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/comparison_expander.h" - -#endif // XLA_SERVICE_COMPARISON_EXPANDER_H_ diff --git a/third_party/xla/xla/service/conditional_canonicalizer.h b/third_party/xla/xla/service/conditional_canonicalizer.h deleted file mode 100644 index 6a857fc4cf208c..00000000000000 --- a/third_party/xla/xla/service/conditional_canonicalizer.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ -#define XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/conditional_canonicalizer.h" - -#endif // XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ diff --git a/third_party/xla/xla/service/convert_async_collectives_to_sync.h b/third_party/xla/xla/service/convert_async_collectives_to_sync.h deleted file mode 100644 index 3e3884b98a0fbb..00000000000000 --- a/third_party/xla/xla/service/convert_async_collectives_to_sync.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ -#define XLA_SERVICE_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/collectives/convert_async_collectives_to_sync.h" - -#endif // XLA_SERVICE_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h deleted file mode 100644 index 17f629fd058847..00000000000000 --- a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ - -#ifndef XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ -#define XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/convert_memory_placement_to_internal_annotations.h" - -#endif // XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ diff --git a/third_party/xla/xla/service/convert_mover.h b/third_party/xla/xla/service/convert_mover.h deleted file mode 100644 index a335a4583caecd..00000000000000 --- a/third_party/xla/xla/service/convert_mover.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CONVERT_MOVER_H_ -#define XLA_SERVICE_CONVERT_MOVER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/convert_mover.h" - -#endif // XLA_SERVICE_CONVERT_MOVER_H_ diff --git a/third_party/xla/xla/service/convert_operand_folding.h b/third_party/xla/xla/service/convert_operand_folding.h deleted file mode 100644 index 863cd7da8d4914..00000000000000 --- a/third_party/xla/xla/service/convert_operand_folding.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CONVERT_OPERAND_FOLDING_H_ -#define XLA_SERVICE_CONVERT_OPERAND_FOLDING_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/transforms/simplifiers/convert_operand_folder.h" - -#endif // XLA_SERVICE_CONVERT_OPERAND_FOLDING_H_ diff --git a/third_party/xla/xla/service/convolution_4d_expander.h b/third_party/xla/xla/service/convolution_4d_expander.h deleted file mode 100644 index 2a290290ebddef..00000000000000 --- a/third_party/xla/xla/service/convolution_4d_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CONVOLUTION_4D_EXPANDER_H_ -#define XLA_SERVICE_CONVOLUTION_4D_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/convolution_4d_expander.h" - -#endif // XLA_SERVICE_CONVOLUTION_4D_EXPANDER_H_ diff --git a/third_party/xla/xla/service/convolution_group_converter.h b/third_party/xla/xla/service/convolution_group_converter.h deleted file mode 100644 index 21d68d2751a0fc..00000000000000 --- a/third_party/xla/xla/service/convolution_group_converter.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ -#define XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/convolution_group_converter.h" - -#endif // XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ diff --git a/third_party/xla/xla/service/convolution_pred_expander.h b/third_party/xla/xla/service/convolution_pred_expander.h deleted file mode 100644 index 84c57681afb00a..00000000000000 --- a/third_party/xla/xla/service/convolution_pred_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_CONVOLUTION_PRED_EXPANDER_H_ -#define XLA_SERVICE_CONVOLUTION_PRED_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/convolution_pred_expander.h" - -#endif // XLA_SERVICE_CONVOLUTION_PRED_EXPANDER_H_ diff --git a/third_party/xla/xla/service/defuser.h b/third_party/xla/xla/service/defuser.h deleted file mode 100644 index 46ad02630dfce0..00000000000000 --- a/third_party/xla/xla/service/defuser.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_DEFUSER_H_ -#define XLA_SERVICE_DEFUSER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/defuser.h" - -#endif // XLA_SERVICE_DEFUSER_H_ diff --git a/third_party/xla/xla/service/despecializer.h b/third_party/xla/xla/service/despecializer.h deleted file mode 100644 index c230c27805b012..00000000000000 --- a/third_party/xla/xla/service/despecializer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_DESPECIALIZER_H_ -#define XLA_SERVICE_DESPECIALIZER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/despecializer.h" - -#endif // XLA_SERVICE_DESPECIALIZER_H_ diff --git a/third_party/xla/xla/service/dot_decomposer.h b/third_party/xla/xla/service/dot_decomposer.h deleted file mode 100644 index 1e6f4015f169a3..00000000000000 --- a/third_party/xla/xla/service/dot_decomposer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_DOT_DECOMPOSER_H_ -#define XLA_SERVICE_DOT_DECOMPOSER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/dot_decomposer.h" - -#endif // XLA_SERVICE_DOT_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/dot_dimension_merger.h b/third_party/xla/xla/service/dot_dimension_merger.h deleted file mode 100644 index dcc23bc149217a..00000000000000 --- a/third_party/xla/xla/service/dot_dimension_merger.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_DOT_DIMENSION_MERGER_H_ -#define XLA_SERVICE_DOT_DIMENSION_MERGER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/dot_dimension_merger.h" - -#endif // XLA_SERVICE_DOT_DIMENSION_MERGER_H_ diff --git a/third_party/xla/xla/service/dot_merger.h b/third_party/xla/xla/service/dot_merger.h deleted file mode 100644 index 5f8c1160686c27..00000000000000 --- a/third_party/xla/xla/service/dot_merger.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_DOT_MERGER_H_ -#define XLA_SERVICE_DOT_MERGER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/dot_merger.h" - -#endif // XLA_SERVICE_DOT_MERGER_H_ diff --git a/third_party/xla/xla/service/dynamic_dimension_simplifier.h b/third_party/xla/xla/service/dynamic_dimension_simplifier.h deleted file mode 100644 index 0824118cb48bb5..00000000000000 --- a/third_party/xla/xla/service/dynamic_dimension_simplifier.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_ -#define XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" - -#endif // XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/dynamic_index_splitter.h b/third_party/xla/xla/service/dynamic_index_splitter.h deleted file mode 100644 index 670d297da852ce..00000000000000 --- a/third_party/xla/xla/service/dynamic_index_splitter.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ -#define XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" - -#endif // XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ diff --git a/third_party/xla/xla/service/eigh_expander.h b/third_party/xla/xla/service/eigh_expander.h deleted file mode 100644 index 5ef10cffe0bbcc..00000000000000 --- a/third_party/xla/xla/service/eigh_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_EIGH_EXPANDER_H_ -#define XLA_SERVICE_EIGH_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/eigh_expander.h" - -#endif // XLA_SERVICE_EIGH_EXPANDER_H_ diff --git a/third_party/xla/xla/service/flatten_call_graph.h b/third_party/xla/xla/service/flatten_call_graph.h deleted file mode 100644 index ff5af7039ee3b7..00000000000000 --- a/third_party/xla/xla/service/flatten_call_graph.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Flatten the call graph for an HLO module into a tree. - -#ifndef XLA_SERVICE_FLATTEN_CALL_GRAPH_H_ -#define XLA_SERVICE_FLATTEN_CALL_GRAPH_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" - -#endif // XLA_SERVICE_FLATTEN_CALL_GRAPH_H_ diff --git a/third_party/xla/xla/service/float_normalization.h b/third_party/xla/xla/service/float_normalization.h deleted file mode 100644 index db54be02642d8d..00000000000000 --- a/third_party/xla/xla/service/float_normalization.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_FLOAT_NORMALIZATION_H_ -#define XLA_SERVICE_FLOAT_NORMALIZATION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/float_normalization.h" - -#endif // XLA_SERVICE_FLOAT_NORMALIZATION_H_ diff --git a/third_party/xla/xla/service/fusion_constant_sinking.h b/third_party/xla/xla/service/fusion_constant_sinking.h deleted file mode 100644 index 15f15e41af05c9..00000000000000 --- a/third_party/xla/xla/service/fusion_constant_sinking.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_FUSION_CONSTANT_SINKING_H_ -#define XLA_SERVICE_FUSION_CONSTANT_SINKING_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/fusion_constant_sinking.h" - -#endif // XLA_SERVICE_FUSION_CONSTANT_SINKING_H_ diff --git a/third_party/xla/xla/service/gather_simplifier.h b/third_party/xla/xla/service/gather_simplifier.h deleted file mode 100644 index 0cbcadc09b0d26..00000000000000 --- a/third_party/xla/xla/service/gather_simplifier.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GATHER_SIMPLIFIER_H_ -#define XLA_SERVICE_GATHER_SIMPLIFIER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/gather_simplifier.h" - -#endif // XLA_SERVICE_GATHER_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/hlo_alias_analysis.h b/third_party/xla/xla/service/hlo_alias_analysis.h deleted file mode 100644 index e2789adda9f4bf..00000000000000 --- a/third_party/xla/xla/service/hlo_alias_analysis.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ -#define XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/hlo_alias_analysis.h" - -#endif // XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_computation_deduplicator.h b/third_party/xla/xla/service/hlo_computation_deduplicator.h deleted file mode 100644 index bf82bc4ff4204c..00000000000000 --- a/third_party/xla/xla/service/hlo_computation_deduplicator.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_ -#define XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.h" - -#endif // XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_ diff --git a/third_party/xla/xla/service/hlo_constant_folding.h b/third_party/xla/xla/service/hlo_constant_folding.h deleted file mode 100644 index 5f82f95d863ebb..00000000000000 --- a/third_party/xla/xla/service/hlo_constant_folding.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_CONSTANT_FOLDING_H_ -#define XLA_SERVICE_HLO_CONSTANT_FOLDING_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" - -#endif // XLA_SERVICE_HLO_CONSTANT_FOLDING_H_ diff --git a/third_party/xla/xla/service/hlo_dataflow_analysis.h b/third_party/xla/xla/service/hlo_dataflow_analysis.h deleted file mode 100644 index 571638e53cf80f..00000000000000 --- a/third_party/xla/xla/service/hlo_dataflow_analysis.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Analysis for determining the possible set of values for all positions -// (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped -// tracking values across computation boundaries. - -#ifndef XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ -#define XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/hlo_dataflow_analysis.h" - -#endif // XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_dce.h b/third_party/xla/xla/service/hlo_dce.h deleted file mode 100644 index d0ce0665d0d0df..00000000000000 --- a/third_party/xla/xla/service/hlo_dce.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_DCE_H_ -#define XLA_SERVICE_HLO_DCE_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/transforms/simplifiers/hlo_dce.h" - -#endif // XLA_SERVICE_HLO_DCE_H_ diff --git a/third_party/xla/xla/service/hlo_element_type_converter.h b/third_party/xla/xla/service/hlo_element_type_converter.h deleted file mode 100644 index 3fed0142430401..00000000000000 --- a/third_party/xla/xla/service/hlo_element_type_converter.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ -#define XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/hlo_element_type_converter.h" - -#endif // XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ diff --git a/third_party/xla/xla/service/hlo_lexer.h b/third_party/xla/xla/service/hlo_lexer.h deleted file mode 100644 index aad399ed291f3a..00000000000000 --- a/third_party/xla/xla/service/hlo_lexer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_LEXER_H_ -#define XLA_SERVICE_HLO_LEXER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/parser/hlo_lexer.h" - -#endif // XLA_SERVICE_HLO_LEXER_H_ diff --git a/third_party/xla/xla/service/hlo_liveness_analysis.h b/third_party/xla/xla/service/hlo_liveness_analysis.h deleted file mode 100644 index fd590408d53934..00000000000000 --- a/third_party/xla/xla/service/hlo_liveness_analysis.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ -#define XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/analysis/hlo_liveness_analysis.h" - -#endif // XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.h b/third_party/xla/xla/service/hlo_memory_scheduler.h deleted file mode 100644 index 09d8b432f998db..00000000000000 --- a/third_party/xla/xla/service/hlo_memory_scheduler.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ -#define XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" - -#endif // XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ diff --git a/third_party/xla/xla/service/hlo_ordering.h b/third_party/xla/xla/service/hlo_ordering.h deleted file mode 100644 index d035368156aed2..00000000000000 --- a/third_party/xla/xla/service/hlo_ordering.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_ORDERING_H_ -#define XLA_SERVICE_HLO_ORDERING_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/hlo_ordering.h" - -#endif // XLA_SERVICE_HLO_ORDERING_H_ diff --git a/third_party/xla/xla/service/hlo_parser.h b/third_party/xla/xla/service/hlo_parser.h deleted file mode 100644 index 6a9e8d8be6039d..00000000000000 --- a/third_party/xla/xla/service/hlo_parser.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_PARSER_H_ -#define XLA_SERVICE_HLO_PARSER_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/parser/hlo_parser.h" - -#endif // XLA_SERVICE_HLO_PARSER_H_ diff --git a/third_party/xla/xla/service/hlo_pass_fix.h b/third_party/xla/xla/service/hlo_pass_fix.h deleted file mode 100644 index c7dab4303b6e1a..00000000000000 --- a/third_party/xla/xla/service/hlo_pass_fix.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_PASS_FIX_H_ -#define XLA_SERVICE_HLO_PASS_FIX_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/pass/hlo_pass_fix.h" - -#endif // XLA_SERVICE_HLO_PASS_FIX_H_ diff --git a/third_party/xla/xla/service/hlo_pass_interface.h b/third_party/xla/xla/service/hlo_pass_interface.h deleted file mode 100644 index 1b6a373b3a1785..00000000000000 --- a/third_party/xla/xla/service/hlo_pass_interface.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_PASS_INTERFACE_H_ -#define XLA_SERVICE_HLO_PASS_INTERFACE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/pass/hlo_pass_interface.h" - -#endif // XLA_SERVICE_HLO_PASS_INTERFACE_H_ diff --git a/third_party/xla/xla/service/hlo_pass_pipeline.h b/third_party/xla/xla/service/hlo_pass_pipeline.h deleted file mode 100644 index 83d693ccfef3f8..00000000000000 --- a/third_party/xla/xla/service/hlo_pass_pipeline.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_PASS_PIPELINE_H_ -#define XLA_SERVICE_HLO_PASS_PIPELINE_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/pass/hlo_pass_pipeline.h" - -#endif // XLA_SERVICE_HLO_PASS_PIPELINE_H_ diff --git a/third_party/xla/xla/service/hlo_rematerialization.h b/third_party/xla/xla/service/hlo_rematerialization.h deleted file mode 100644 index 0dcdcee3636247..00000000000000 --- a/third_party/xla/xla/service/hlo_rematerialization.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ -#ifndef XLA_SERVICE_HLO_REMATERIALIZATION_H_ -#define XLA_SERVICE_HLO_REMATERIALIZATION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/hlo_rematerialization.h" - -#endif // XLA_SERVICE_HLO_REMATERIALIZATION_H_ diff --git a/third_party/xla/xla/service/hlo_rematerialization_test_utils.h b/third_party/xla/xla/service/hlo_rematerialization_test_utils.h deleted file mode 100644 index 8837169bec82fd..00000000000000 --- a/third_party/xla/xla/service/hlo_rematerialization_test_utils.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Class to create computations for testing rematerialization methods. - -#ifndef XLA_SERVICE_HLO_REMATERIALIZATION_TEST_UTILS_H_ -#define XLA_SERVICE_HLO_REMATERIALIZATION_TEST_UTILS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils.h" - -#endif // XLA_SERVICE_HLO_REMATERIALIZATION_TEST_UTILS_H_ diff --git a/third_party/xla/xla/service/hlo_replication_analysis.h b/third_party/xla/xla/service/hlo_replication_analysis.h deleted file mode 100644 index 85289cb01adb5e..00000000000000 --- a/third_party/xla/xla/service/hlo_replication_analysis.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ -#define XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/hlo_replication_analysis.h" - -#endif // XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.h b/third_party/xla/xla/service/hlo_value_semantics_analysis.h deleted file mode 100644 index 4a946206879037..00000000000000 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_HLO_VALUE_SEMANTICS_ANALYSIS_H_ -#define XLA_SERVICE_HLO_VALUE_SEMANTICS_ANALYSIS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/hlo_value_semantics_analysis.h" - -#endif // XLA_SERVICE_HLO_VALUE_SEMANTICS_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/host_memory_transfer_asyncifier.h b/third_party/xla/xla/service/host_memory_transfer_asyncifier.h deleted file mode 100644 index d2677f2ab2948e..00000000000000 --- a/third_party/xla/xla/service/host_memory_transfer_asyncifier.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ -#define XLA_SERVICE_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.h" - -#endif // XLA_SERVICE_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ diff --git a/third_party/xla/xla/service/host_offload_legalize.h b/third_party/xla/xla/service/host_offload_legalize.h deleted file mode 100644 index 181c82e269a183..00000000000000 --- a/third_party/xla/xla/service/host_offload_legalize.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ -#ifndef XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ -#define XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/host_offload_legalize.h" - -#endif // XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ diff --git a/third_party/xla/xla/service/host_offloader.h b/third_party/xla/xla/service/host_offloader.h deleted file mode 100644 index 0f68eb631fc033..00000000000000 --- a/third_party/xla/xla/service/host_offloader.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ -#ifndef XLA_SERVICE_HOST_OFFLOADER_H_ -#define XLA_SERVICE_HOST_OFFLOADER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/host_offloader.h" - -#endif // XLA_SERVICE_HOST_OFFLOADER_H_ diff --git a/third_party/xla/xla/service/host_offloading_prepare.h b/third_party/xla/xla/service/host_offloading_prepare.h deleted file mode 100644 index 016bfadb46bad7..00000000000000 --- a/third_party/xla/xla/service/host_offloading_prepare.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ - -#ifndef XLA_SERVICE_HOST_OFFLOADING_PREPARE_H_ -#define XLA_SERVICE_HOST_OFFLOADING_PREPARE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/host_offloading_prepare.h" - -#endif // XLA_SERVICE_HOST_OFFLOADING_PREPARE_H_ diff --git a/third_party/xla/xla/service/indexed_array_analysis.h b/third_party/xla/xla/service/indexed_array_analysis.h deleted file mode 100644 index 6dbfd2a1eccf74..00000000000000 --- a/third_party/xla/xla/service/indexed_array_analysis.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ -#define XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/indexed_array_analysis.h" - -#endif // XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/infeed_token_propagation.h b/third_party/xla/xla/service/infeed_token_propagation.h deleted file mode 100644 index 31a0aa19ed8c07..00000000000000 --- a/third_party/xla/xla/service/infeed_token_propagation.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ -#define XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/collectives/infeed_token_propagation.h" - -#endif // XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/instruction_hoister.h b/third_party/xla/xla/service/instruction_hoister.h deleted file mode 100644 index bd002321eecf92..00000000000000 --- a/third_party/xla/xla/service/instruction_hoister.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_INSTRUCTION_HOISTER_H_ -#define XLA_SERVICE_INSTRUCTION_HOISTER_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/transforms/simplifiers/instruction_hoister.h" - -#endif // XLA_SERVICE_INSTRUCTION_HOISTER_H_ diff --git a/third_party/xla/xla/service/logical_buffer_analysis.h b/third_party/xla/xla/service/logical_buffer_analysis.h deleted file mode 100644 index 6571558fb208e4..00000000000000 --- a/third_party/xla/xla/service/logical_buffer_analysis.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ -#define XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/logical_buffer_analysis.h" - -#endif // XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/logistic_expander.h b/third_party/xla/xla/service/logistic_expander.h deleted file mode 100644 index c0c5ec0c37f0da..00000000000000 --- a/third_party/xla/xla/service/logistic_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_LOGISTIC_EXPANDER_H_ -#define XLA_SERVICE_LOGISTIC_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/logistic_expander.h" - -#endif // XLA_SERVICE_LOGISTIC_EXPANDER_H_ diff --git a/third_party/xla/xla/service/memory_space_propagation.h b/third_party/xla/xla/service/memory_space_propagation.h deleted file mode 100644 index 11676aa45c3ba9..00000000000000 --- a/third_party/xla/xla/service/memory_space_propagation.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ -#define XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/memory_space_propagation.h" - -#endif // XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/op_expander_pass.h b/third_party/xla/xla/service/op_expander_pass.h deleted file mode 100644 index df65b012e1da6c..00000000000000 --- a/third_party/xla/xla/service/op_expander_pass.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_OP_EXPANDER_PASS_H_ -#define XLA_SERVICE_OP_EXPANDER_PASS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/op_expander_pass.h" - -#endif // XLA_SERVICE_OP_EXPANDER_PASS_H_ diff --git a/third_party/xla/xla/service/operand_upcaster.h b/third_party/xla/xla/service/operand_upcaster.h deleted file mode 100644 index 8b237a47e0cd65..00000000000000 --- a/third_party/xla/xla/service/operand_upcaster.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_OPERAND_UPCASTER_H_ -#define XLA_SERVICE_OPERAND_UPCASTER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/operand_upcaster.h" - -#endif // XLA_SERVICE_OPERAND_UPCASTER_H_ diff --git a/third_party/xla/xla/service/optimization_barrier_expander.h b/third_party/xla/xla/service/optimization_barrier_expander.h deleted file mode 100644 index b257010fe9a616..00000000000000 --- a/third_party/xla/xla/service/optimization_barrier_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_OPTIMIZATION_BARRIER_EXPANDER_H_ -#define XLA_SERVICE_OPTIMIZATION_BARRIER_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/optimization_barrier_expander.h" - -#endif // XLA_SERVICE_OPTIMIZATION_BARRIER_EXPANDER_H_ diff --git a/third_party/xla/xla/service/optimize_input_output_buffer_alias.h b/third_party/xla/xla/service/optimize_input_output_buffer_alias.h deleted file mode 100644 index 04ad98bc488386..00000000000000 --- a/third_party/xla/xla/service/optimize_input_output_buffer_alias.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ -#define XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h" - -#endif // XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ diff --git a/third_party/xla/xla/service/pattern_matcher_gmock.h b/third_party/xla/xla/service/pattern_matcher_gmock.h deleted file mode 100644 index f8bea2cff482a7..00000000000000 --- a/third_party/xla/xla/service/pattern_matcher_gmock.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ -#define XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/testlib/pattern_matcher_gmock.h" - -#endif // XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ diff --git a/third_party/xla/xla/service/qr_expander.h b/third_party/xla/xla/service/qr_expander.h deleted file mode 100644 index 067ea64c9166a9..00000000000000 --- a/third_party/xla/xla/service/qr_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_QR_EXPANDER_H_ -#define XLA_SERVICE_QR_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/qr_expander.h" - -#endif // XLA_SERVICE_QR_EXPANDER_H_ diff --git a/third_party/xla/xla/service/real_imag_expander.h b/third_party/xla/xla/service/real_imag_expander.h deleted file mode 100644 index fc87a60e747da6..00000000000000 --- a/third_party/xla/xla/service/real_imag_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_REAL_IMAG_EXPANDER_H_ -#define XLA_SERVICE_REAL_IMAG_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/real_imag_expander.h" - -#endif // XLA_SERVICE_REAL_IMAG_EXPANDER_H_ diff --git a/third_party/xla/xla/service/reduce_decomposer.h b/third_party/xla/xla/service/reduce_decomposer.h deleted file mode 100644 index 12fac9b0dec6b1..00000000000000 --- a/third_party/xla/xla/service/reduce_decomposer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_REDUCE_DECOMPOSER_H_ -#define XLA_SERVICE_REDUCE_DECOMPOSER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/reduce_decomposer.h" - -#endif // XLA_SERVICE_REDUCE_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/reduce_window_rewriter.h b/third_party/xla/xla/service/reduce_window_rewriter.h deleted file mode 100644 index 01f1ad58267695..00000000000000 --- a/third_party/xla/xla/service/reduce_window_rewriter.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_ -#define XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/reduce_window_rewriter.h" - -#endif // XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_ diff --git a/third_party/xla/xla/service/reshape_decomposer.h b/third_party/xla/xla/service/reshape_decomposer.h deleted file mode 100644 index f5d5b140b1921f..00000000000000 --- a/third_party/xla/xla/service/reshape_decomposer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_RESHAPE_DECOMPOSER_H_ -#define XLA_SERVICE_RESHAPE_DECOMPOSER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/reshape_decomposer.h" - -#endif // XLA_SERVICE_RESHAPE_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/reshape_mover.h b/third_party/xla/xla/service/reshape_mover.h deleted file mode 100644 index 63f2003ed3e8c3..00000000000000 --- a/third_party/xla/xla/service/reshape_mover.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_RESHAPE_MOVER_H_ -#define XLA_SERVICE_RESHAPE_MOVER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/reshape_mover.h" - -#endif // XLA_SERVICE_RESHAPE_MOVER_H_ diff --git a/third_party/xla/xla/service/result_caster.h b/third_party/xla/xla/service/result_caster.h deleted file mode 100644 index d8fc21221f5038..00000000000000 --- a/third_party/xla/xla/service/result_caster.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_RESULT_CASTER_H_ -#define XLA_SERVICE_RESULT_CASTER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/result_caster.h" - -#endif // XLA_SERVICE_RESULT_CASTER_H_ diff --git a/third_party/xla/xla/service/rng_bit_generator_expander.h b/third_party/xla/xla/service/rng_bit_generator_expander.h deleted file mode 100644 index 40a8b353804746..00000000000000 --- a/third_party/xla/xla/service/rng_bit_generator_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_ -#define XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/rng_bit_generator_expander.h" - -#endif // XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_ diff --git a/third_party/xla/xla/service/rng_expander.h b/third_party/xla/xla/service/rng_expander.h deleted file mode 100644 index 5f1951d7c2c6f4..00000000000000 --- a/third_party/xla/xla/service/rng_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_RNG_EXPANDER_H_ -#define XLA_SERVICE_RNG_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/rng_expander.h" - -#endif // XLA_SERVICE_RNG_EXPANDER_H_ diff --git a/third_party/xla/xla/service/root_instruction_sinker.h b/third_party/xla/xla/service/root_instruction_sinker.h deleted file mode 100644 index 38cc3c7756908e..00000000000000 --- a/third_party/xla/xla/service/root_instruction_sinker.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ -#define XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/root_instruction_sinker.h" - -#endif // XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ diff --git a/third_party/xla/xla/service/simplify_fp_conversions.h b/third_party/xla/xla/service/simplify_fp_conversions.h deleted file mode 100644 index b12727941fb086..00000000000000 --- a/third_party/xla/xla/service/simplify_fp_conversions.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_SIMPLIFY_FP_CONVERSIONS_H_ -#define XLA_SERVICE_SIMPLIFY_FP_CONVERSIONS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/simplify_fp_conversions.h" - -#endif // XLA_SERVICE_SIMPLIFY_FP_CONVERSIONS_H_ diff --git a/third_party/xla/xla/service/slice_sinker.h b/third_party/xla/xla/service/slice_sinker.h deleted file mode 100644 index d1d1aa599b1a0f..00000000000000 --- a/third_party/xla/xla/service/slice_sinker.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_SLICE_SINKER_H_ -#define XLA_SERVICE_SLICE_SINKER_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/transforms/simplifiers/slice_sinker.h" - -#endif // XLA_SERVICE_SLICE_SINKER_H_ diff --git a/third_party/xla/xla/service/sort_simplifier.h b/third_party/xla/xla/service/sort_simplifier.h deleted file mode 100644 index d05996705787c0..00000000000000 --- a/third_party/xla/xla/service/sort_simplifier.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_SORT_SIMPLIFIER_H_ -#define XLA_SERVICE_SORT_SIMPLIFIER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/sort_simplifier.h" - -#endif // XLA_SERVICE_SORT_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/stable_sort_expander.h b/third_party/xla/xla/service/stable_sort_expander.h deleted file mode 100644 index 78d58b24ba822e..00000000000000 --- a/third_party/xla/xla/service/stable_sort_expander.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_STABLE_SORT_EXPANDER_H_ -#define XLA_SERVICE_STABLE_SORT_EXPANDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/stable_sort_expander.h" - -#endif // XLA_SERVICE_STABLE_SORT_EXPANDER_H_ diff --git a/third_party/xla/xla/service/stochastic_convert_decomposer.h b/third_party/xla/xla/service/stochastic_convert_decomposer.h deleted file mode 100644 index 79aefac76e302a..00000000000000 --- a/third_party/xla/xla/service/stochastic_convert_decomposer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_STOCHASTIC_CONVERT_DECOMPOSER_H_ -#define XLA_SERVICE_STOCHASTIC_CONVERT_DECOMPOSER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" - -#endif // XLA_SERVICE_STOCHASTIC_CONVERT_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/sub_byte_normalization.h b/third_party/xla/xla/service/sub_byte_normalization.h deleted file mode 100644 index 3f9f700509c4b2..00000000000000 --- a/third_party/xla/xla/service/sub_byte_normalization.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_SUB_BYTE_NORMALIZATION_H_ -#define XLA_SERVICE_SUB_BYTE_NORMALIZATION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" - -#endif // XLA_SERVICE_SUB_BYTE_NORMALIZATION_H_ diff --git a/third_party/xla/xla/service/tree_reduction_rewriter.h b/third_party/xla/xla/service/tree_reduction_rewriter.h deleted file mode 100644 index e505b69e92d0d9..00000000000000 --- a/third_party/xla/xla/service/tree_reduction_rewriter.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_TREE_REDUCTION_REWRITER_H_ -#define XLA_SERVICE_TREE_REDUCTION_REWRITER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/tree_reduction_rewriter.h" - -#endif // XLA_SERVICE_TREE_REDUCTION_REWRITER_H_ diff --git a/third_party/xla/xla/service/tuple_points_to_analysis.h b/third_party/xla/xla/service/tuple_points_to_analysis.h deleted file mode 100644 index 1b231e4b76ad29..00000000000000 --- a/third_party/xla/xla/service/tuple_points_to_analysis.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ -#define XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/tuple_points_to_analysis.h" - -#endif // XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/tuple_simplifier.h b/third_party/xla/xla/service/tuple_simplifier.h deleted file mode 100644 index 19d81248537be4..00000000000000 --- a/third_party/xla/xla/service/tuple_simplifier.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_TUPLE_SIMPLIFIER_H_ -#define XLA_SERVICE_TUPLE_SIMPLIFIER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" - -#endif // XLA_SERVICE_TUPLE_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/while_loop_analysis.h b/third_party/xla/xla/service/while_loop_analysis.h deleted file mode 100644 index c6d95ac80db238..00000000000000 --- a/third_party/xla/xla/service/while_loop_analysis.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ -#define XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/while_loop_analysis.h" - -#endif // XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/while_loop_trip_count_annotator.h b/third_party/xla/xla/service/while_loop_trip_count_annotator.h deleted file mode 100644 index ee7377423b8b02..00000000000000 --- a/third_party/xla/xla/service/while_loop_trip_count_annotator.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ -#define XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/while_loop_trip_count_annotator.h" - -#endif // XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/zero_sized_hlo_elimination.h b/third_party/xla/xla/service/zero_sized_hlo_elimination.h deleted file mode 100644 index 3da82bd21355bb..00000000000000 --- a/third_party/xla/xla/service/zero_sized_hlo_elimination.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ -#define XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h" - -#endif // XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ From 9a463dc7f61c561b0a5167f589a791f127447cbd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 21:32:07 -0700 Subject: [PATCH 0536/1324] Add a 2D test in memories_test. PiperOrigin-RevId: 746295338 --- third_party/xla/xla/layout_util.cc | 8 ++++++++ third_party/xla/xla/layout_util.h | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/third_party/xla/xla/layout_util.cc b/third_party/xla/xla/layout_util.cc index 787d17bb60d4a2..98024f46f6a571 100644 --- a/third_party/xla/xla/layout_util.cc +++ b/third_party/xla/xla/layout_util.cc @@ -119,12 +119,20 @@ absl::string_view BoolToString(bool b) { return b ? 
"true" : "false"; } return MakeLayout(layout); } +/* static */ bool LayoutUtil::HasDescendingLayout(const Layout& layout) { + return absl::c_is_sorted(layout.minor_to_major(), std::greater()); +} + /* static */ Layout LayoutUtil::MakeAscendingLayout(int64_t num_dims) { std::vector layout(num_dims); std::iota(layout.begin(), layout.end(), static_cast(0)); return MakeLayout(layout); } +/* static */ bool LayoutUtil::HasAscendingLayout(const Layout& layout) { + return absl::c_is_sorted(layout.minor_to_major(), std::less()); +} + /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( absl::Span major_to_minor) { Layout layout; diff --git a/third_party/xla/xla/layout_util.h b/third_party/xla/xla/layout_util.h index 71507049bd71f3..06741b6fee52c9 100644 --- a/third_party/xla/xla/layout_util.h +++ b/third_party/xla/xla/layout_util.h @@ -63,10 +63,16 @@ class LayoutUtil { // dimensions. static Layout MakeDescendingLayout(int64_t num_dims); + // Returns true if the layout is descending. + static bool HasDescendingLayout(const Layout& layout); + // Returns a layout with ascending ((i.e. {0, 1, ... n-1}) minor-to-major // dimensions. static Layout MakeAscendingLayout(int64_t num_dims); + // Returns true if the layout is ascending. + static bool HasAscendingLayout(const Layout& layout); + // Returns default layout for the given shape. static Layout GetDefaultLayoutForShape(const Shape& shape); From 4799645c5ff51333a30fb43f8e9b503775c63e48 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 10 Apr 2025 22:20:22 -0700 Subject: [PATCH 0537/1324] [xla] Fix a race in rendezvous Rendezvous must be completed before we signal other threads that potentially can start the next round and clash with unfinished one. PiperOrigin-RevId: 746310238 --- third_party/xla/xla/service/rendezvous.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/rendezvous.h b/third_party/xla/xla/service/rendezvous.h index 3cb424d980bd9b..9caca88abd23d2 100644 --- a/third_party/xla/xla/service/rendezvous.h +++ b/third_party/xla/xla/service/rendezvous.h @@ -340,6 +340,10 @@ absl::StatusOr> Rendezvous( internal::AwaitAndLogIfStuck(*state, id, name, warn_stuck_timeout, terminate_timeout); } else { + // Mark rendezvous as completed, so that we can immediately start a new + // rendezvous with the same key. + rendezvous.Complete(key); + // Last thread to arrive executes the function and completes rendezvous by // making result available to all participants. All other participants will // be notified via `state->ready` flag when result is ready, and we rely on @@ -362,10 +366,6 @@ absl::StatusOr> Rendezvous( // Notify awaiting participants that result is ready. state->cv.SignalAll(); - - // Mark rendezvous as completed, so that we can immediately start a new - // rendezvous with the same key. - rendezvous.Complete(key); } return state->result; From 4ac198df8df35e0423d55c75b594e19b44f03f95 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 10 Apr 2025 23:20:36 -0700 Subject: [PATCH 0538/1324] [XLA:GPU] Add custom triton emitter check related to multi-output fusion. Since triton requires power of 2 tiling parameters, we need to ensure that there is no mismatch between the propagated tile parameters and the tile parameter that will actually be used in the emitter. Otherwise we may run into issues with buffer sharing. 
PiperOrigin-RevId: 746325192 --- third_party/xla/xla/service/gpu/model/BUILD | 1 + .../gpu/model/triton_emitter_constraints.cc | 28 +++++++++- .../gpu/model/triton_emitter_constraints.h | 6 +++ .../model/triton_emitter_constraints_test.cc | 52 +++++++++++++++++++ 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 6199e65a54afdf..3c1f29dd528412 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -767,6 +767,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc index 32183d2f753ba6..6c8c715936972e 100644 --- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" @@ -171,9 +172,20 @@ TritonEmitterConstraints::GetBuilder( instructions, const HloFusionAdaptor& fusion_adaptor) { llvm::DenseSet unique_tile_size_maps; + llvm::SmallVector size_maps; + auto roots = fusion_adaptor.GetRoots(); for (const auto& tiled_hlo_instruction : instructions) { unique_tile_size_maps.insert( tiled_hlo_instruction->symbolic_tile().size_map()); + // TODO(b/365727080): We should also enforce this for single-output + // fusions. + if (roots.size() > 1 && + absl::c_any_of(roots, [&tiled_hlo_instruction]( + const HloInstructionAdaptor& instr) { + return &instr.instruction() == tiled_hlo_instruction->hlo(); + })) { + size_maps.push_back(tiled_hlo_instruction->symbolic_tile().size_map()); + } } std::vector custom_constraints = @@ -184,7 +196,8 @@ TritonEmitterConstraints::GetBuilder( return std::unique_ptr( absl::WrapUnique(new TritonEmitterConstraints( - std::move(tile_size_maps), std::move(custom_constraints), + std::move(tile_size_maps), std::move(size_maps), + std::move(custom_constraints), /*root_shape=*/instructions.back()->hlo()->shape(), device_description))); }; @@ -233,6 +246,19 @@ absl::StatusOr TritonEmitterConstraints::ParametersSatisfyConstraints( return false; } } + for (const auto& size_map : size_maps_) { + llvm::SmallVector transformed_tile_parameters = + EvaluateAffineMap(size_map, + /*dim_values=*/tile_parameters); + // For multi-output fusions, we require that the propagated tile sizes for + // potential root tiles are powers of 2. + // TODO(b/365727080): Technically we should also enforce this for fusions + // with just one root. 
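+  // GetPaddedTileSizes is assumed to round each size up to the next power of
+  // 2 (e.g. a propagated root tile size of 12 would become 16), so the
+  // comparison below rejects any tiling that padding would change.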
+ if (GetPaddedTileSizes(transformed_tile_parameters) != + transformed_tile_parameters) { + return false; + } + } return true; } diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h index c8f1356c398d68..931ac97bf38010 100644 --- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h @@ -56,9 +56,11 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints { explicit TritonEmitterConstraints( llvm::SmallVector tile_size_maps, + llvm::SmallVector size_maps, std::vector custom_constraints, const Shape& root_shape, const se::DeviceDescription& device_info) : tile_size_maps_(std::move(tile_size_maps)), + size_maps_(std::move(size_maps)), custom_constraints_(std::move(custom_constraints)), root_shape_(root_shape), device_info_(device_info) {} @@ -91,6 +93,10 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints { // collection of unique maps to improve compilation time. llvm::SmallVector tile_size_maps_; + // Tile size maps that need to be checked whether they evaluate to powers of + // 2. We need this constraint for multi-output fusions. + llvm::SmallVector size_maps_; + // Custom emitter-specific constraints to check in // `ParametersSatisfyConstraints`. std::vector custom_constraints_; diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc index fd004d894c1ec9..d99fcee7158962 100644 --- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc @@ -388,6 +388,58 @@ ENTRY main { IsOkAndHolds(false)); } +TEST_F(TritonEmitterConstraintsTest, MultiOutputFusionHasPowerOfTwoTileSizes) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +fused_computation { + param_0 = f32[36] parameter(0) + abs = f32[36] abs(param_0) + reshape = f32[3,12] reshape(abs) + zero = f32[] constant(0) + reduce = f32[3] reduce(reshape, zero), to_apply=add, dimensions={1} + ROOT tuple = (f32[3], f32[36]) tuple(reduce, abs) +} + +ENTRY entry_computation { + param_0 = f32[36] parameter(0) + ROOT fusion = (f32[3], f32[36]) fusion(param_0), kind=kCustom, + calls=fused_computation, backend_config={"fusion_backend_config":{"kind":"__triton"}} +})")); + std::optional analysis_without_triton_constraints = + TryAnalyzeModule(module.get(), + /*with_triton_emitter_specific_constraints=*/false); + ASSERT_TRUE(analysis_without_triton_constraints.has_value()); + + // (1,) is a theoretically valid tiling for this multi-output fusion, so + // SymbolicTileAnalysis should allow it. + EXPECT_THAT( + analysis_without_triton_constraints->ParametersSatisfyConstraints({1}), + IsOkAndHolds(true)); + + std::optional analysis_with_triton_constraints = + TryAnalyzeModule(module.get(), + /*with_triton_emitter_specific_constraints=*/true); + + ASSERT_TRUE(analysis_with_triton_constraints.has_value()); + + // (1,) is a theoretically valid tiling for this multi-output fusion, but the + // propagated tile size of (1,12) for the extra output does not pass the + // condition that all tile sizes are powers of 2. 
This can result in different + // paddings for the different roots being used, which can cause problems if + // buffers are shared. + EXPECT_THAT( + analysis_with_triton_constraints->ParametersSatisfyConstraints({1}), + IsOkAndHolds(false)); +} + } // namespace } // namespace gpu } // namespace xla From 1e9dc8ce3511071ff1b889ec3b45066fe17f17c2 Mon Sep 17 00:00:00 2001 From: Venkat6871 Date: Fri, 11 Apr 2025 12:04:38 +0530 Subject: [PATCH 0539/1324] Fix typos in documentation strings --- tensorflow/python/keras/callbacks.py | 4 ++-- tensorflow/python/keras/engine/compile_utils.py | 2 +- tensorflow/python/keras/engine/training_utils_v1.py | 2 +- tensorflow/python/keras/layers/merge.py | 2 +- tensorflow/python/keras/losses.py | 4 ++-- tensorflow/python/keras/metrics.py | 4 ++-- tensorflow/python/keras/optimizer_v2/optimizer_v2.py | 6 +++--- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 527a001f096c28..01bcb5e0cbd718 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -616,7 +616,7 @@ class Callback: 1. You should pack all your callbacks into a single `callbacks.CallbackList` so they can all be called together. - 2. You will need to manually call all the `on_*` methods at the apropriate + 2. You will need to manually call all the `on_*` methods at the appropriate locations in your loop. Like this: ``` @@ -1627,7 +1627,7 @@ class BackupAndRestore(Callback): ... pass >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10, ... batch_size=1, callbacks=[callback], verbose=0) - >>> # Only 6 more epochs are run, since first trainning got interrupted at + >>> # Only 6 more epochs are run, since first training got interrupted at >>> # zero-indexed epoch 4, second training will continue from 4 to 9. >>> len(history.history['loss']) 6 diff --git a/tensorflow/python/keras/engine/compile_utils.py b/tensorflow/python/keras/engine/compile_utils.py index 05ef59d5317652..8bee5daf01e4fe 100644 --- a/tensorflow/python/keras/engine/compile_utils.py +++ b/tensorflow/python/keras/engine/compile_utils.py @@ -571,7 +571,7 @@ def _create_pseudo_names(tensors, prefix): """Creates pseudo {input | output} names for subclassed Models. Warning: this function should only be used to define default - names for `Metics` and `SavedModel`. No other use cases should + names for `Metrics` and `SavedModel`. No other use cases should rely on a `Model`'s input or output names. Example with dict: diff --git a/tensorflow/python/keras/engine/training_utils_v1.py b/tensorflow/python/keras/engine/training_utils_v1.py index 5bfa27a38234f5..58a4dd0a8beba1 100644 --- a/tensorflow/python/keras/engine/training_utils_v1.py +++ b/tensorflow/python/keras/engine/training_utils_v1.py @@ -1690,7 +1690,7 @@ def infer_steps_for_dataset(model, (dataset.options().experimental_distribute.auto_shard_policy != options_lib.AutoShardPolicy.OFF)): # If the dataset would be auto-sharded, we should not infer a local - # steps_per_epoch due to the possible inbalanced sharding between workers. + # steps_per_epoch due to the possible imbalanced sharding between workers. 
return None size = backend.get_value(cardinality.cardinality(dataset)) diff --git a/tensorflow/python/keras/layers/merge.py b/tensorflow/python/keras/layers/merge.py index 68461de0841f2f..b78e1b9f9d3b92 100644 --- a/tensorflow/python/keras/layers/merge.py +++ b/tensorflow/python/keras/layers/merge.py @@ -33,7 +33,7 @@ class _Merge(Layer): """ def __init__(self, **kwargs): - """Intializes a Merge layer. + """Initializes a Merge layer. Args: **kwargs: standard layer keyword arguments. diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index f03e7de0ed932f..a7c86f88222f3b 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -1216,7 +1216,7 @@ def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred, y_pred_extra_dim=False): """ def rt_is_equiv_dense(rt): - """Returns true if this RaggedTensor has the same row_lenghts across + """Returns true if this RaggedTensor has the same row_lengths across all ragged dimensions and thus can be converted to a dense tensor without loss of information. @@ -1676,7 +1676,7 @@ def _ragged_tensor_categorical_crossentropy(y_true, When used by CategoricalCrossentropy() with the default reduction (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the number of elements independent of the batch. E.g. if the RaggedTensor - has 2 batches with [2, 1] values respectivly the resulting loss is + has 2 batches with [2, 1] values respectively the resulting loss is the sum of the individual loss values divided by 3. """ fn = functools.partial( diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 96e63a57d4d206..726d579b924ac4 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -2045,12 +2045,12 @@ class AUC(Metric): Usage with `compile()` API: ```python - # Reports the AUC of a model outputing a probability. + # Reports the AUC of a model outputting a probability. model.compile(optimizer='sgd', loss=tf.keras.losses.BinaryCrossentropy(), metrics=[tf.keras.metrics.AUC()]) - # Reports the AUC of a model outputing a logit. + # Reports the AUC of a model outputting a logit. model.compile(optimizer='sgd', loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), metrics=[tf.keras.metrics.AUC(from_logits=True)]) diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index d239d49951346a..3c140f10427546 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -480,7 +480,7 @@ def _aggregate_gradients(self, grads_and_vars): grads_and_vars: List of (gradient, variable) pairs. Returns: - A list of (aggregrated_gradient, variable) pairs. By default, this calls + A list of (aggregated_gradient, variable) pairs. By default, this calls `self.gradient_aggregator`. """ return self.gradient_aggregator(grads_and_vars) @@ -619,7 +619,7 @@ def apply_gradients(self, name: Optional name for the returned operation. Default to the name passed to the `Optimizer` constructor. experimental_aggregate_gradients: Whether to sum gradients from different - replicas in the presense of `tf.distribute.Strategy`. If False, it's + replicas in the presence of `tf.distribute.Strategy`. If False, it's user responsibility to aggregate the gradients. Default to True. Returns: @@ -1452,7 +1452,7 @@ class RestoredOptimizer(OptimizerV2): Holds slot variables and hyperparameters when an optimizer is restored from a SavedModel. 
These variables may be referenced in functions along with ops created by the original optimizer, but currently we do not support using the - optimizer object iself (e.g. through `apply_gradients`). + optimizer object itself (e.g. through `apply_gradients`). """ # TODO(allenl): Make the restored optimizer functional by tracing its apply # methods. From 5c213593cf22774875a835887be9ace2795ff204 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Apr 2025 23:34:41 -0700 Subject: [PATCH 0540/1324] Automated Code Change PiperOrigin-RevId: 746328271 --- third_party/xla/xla/hlo/tools/BUILD | 4 ++++ third_party/xla/xla/hlo/tools/convert_computation.cc | 2 ++ third_party/xla/xla/hlo/tools/hex_floats_to_packed_literal.cc | 3 +++ 3 files changed, 9 insertions(+) diff --git a/third_party/xla/xla/hlo/tools/BUILD b/third_party/xla/xla/hlo/tools/BUILD index 2992a203790bda..bdde1602c06c87 100644 --- a/third_party/xla/xla/hlo/tools/BUILD +++ b/third_party/xla/xla/hlo/tools/BUILD @@ -31,6 +31,8 @@ xla_cc_binary( "//xla/tsl/lib/io:random_inputstream", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", @@ -58,6 +60,8 @@ xla_cc_binary( srcs = ["convert_computation.cc"], deps = [ "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", diff --git a/third_party/xla/xla/hlo/tools/convert_computation.cc b/third_party/xla/xla/hlo/tools/convert_computation.cc index 7ebc5d3f5aa4cd..b013ac383ab8b6 100644 --- a/third_party/xla/xla/hlo/tools/convert_computation.cc +++ b/third_party/xla/xla/hlo/tools/convert_computation.cc @@ -16,6 +16,8 @@ limitations under the License. // Usage: convert_computation serialized_computation_proto // // bin2txt spits out the result to stdout. txt2bin modifies the file in place. +#include "absl/log/check.h" +#include "absl/log/log.h" #include "tsl/platform/status.h" #ifndef _WIN32 #include diff --git a/third_party/xla/xla/hlo/tools/hex_floats_to_packed_literal.cc b/third_party/xla/xla/hlo/tools/hex_floats_to_packed_literal.cc index 659e4cde814b5d..efbcef00e0f376 100644 --- a/third_party/xla/xla/hlo/tools/hex_floats_to_packed_literal.cc +++ b/third_party/xla/xla/hlo/tools/hex_floats_to_packed_literal.cc @@ -15,10 +15,13 @@ limitations under the License. #include +#include #include #include #include "absl/base/casts.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/string_view.h" #include "xla/tsl/lib/io/buffered_inputstream.h" #include "xla/tsl/lib/io/random_inputstream.h" From d16a8453cedbf3ffb075af7f50ab31de66443ffd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 00:13:21 -0700 Subject: [PATCH 0541/1324] Automated Code Change PiperOrigin-RevId: 746339122 --- third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index d0a37584173235..59f6c02cb3a66e 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include From dc4fa262bf70ecb8760b46d8cd72083136296572 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 11 Apr 2025 00:14:50 -0700 Subject: [PATCH 0542/1324] Remove deprecated targets from build and header files PiperOrigin-RevId: 746339638 --- third_party/xla/xla/client/lib/BUILD | 208 ------------------ third_party/xla/xla/client/lib/approx_topk.h | 22 -- .../xla/xla/client/lib/approx_topk_shape.h | 22 -- third_party/xla/xla/client/lib/arithmetic.h | 22 -- third_party/xla/xla/client/lib/broadcast.h | 22 -- third_party/xla/xla/client/lib/comparators.h | 22 -- third_party/xla/xla/client/lib/constants.h | 22 -- .../xla/xla/client/lib/conv_grad_size_util.h | 22 -- .../xla/xla/client/lib/dynamic_shaped_ops.h | 22 -- third_party/xla/xla/client/lib/loops.h | 22 -- .../xla/xla/client/lib/lu_decomposition.h | 22 -- third_party/xla/xla/client/lib/math.h | 22 -- third_party/xla/xla/client/lib/matrix.h | 22 -- third_party/xla/xla/client/lib/pooling.h | 22 -- third_party/xla/xla/client/lib/prng.h | 22 -- third_party/xla/xla/client/lib/qr.h | 22 -- third_party/xla/xla/client/lib/quantize.h | 22 -- .../xla/xla/client/lib/self_adjoint_eig.h | 22 -- third_party/xla/xla/client/lib/slicing.h | 22 -- third_party/xla/xla/client/lib/sorting.h | 22 -- third_party/xla/xla/client/lib/svd.h | 22 -- third_party/xla/xla/client/lib/tridiagonal.h | 22 -- third_party/xla/xla/client/lib/tuple.h | 22 -- 23 files changed, 692 deletions(-) delete mode 100644 third_party/xla/xla/client/lib/approx_topk.h delete mode 100644 third_party/xla/xla/client/lib/approx_topk_shape.h delete mode 100644 third_party/xla/xla/client/lib/arithmetic.h delete mode 100644 third_party/xla/xla/client/lib/broadcast.h delete mode 100644 third_party/xla/xla/client/lib/comparators.h delete mode 100644 third_party/xla/xla/client/lib/constants.h delete mode 100644 third_party/xla/xla/client/lib/conv_grad_size_util.h delete mode 100644 third_party/xla/xla/client/lib/dynamic_shaped_ops.h delete mode 100644 third_party/xla/xla/client/lib/loops.h delete mode 100644 third_party/xla/xla/client/lib/lu_decomposition.h delete mode 100644 third_party/xla/xla/client/lib/math.h delete mode 100644 third_party/xla/xla/client/lib/matrix.h delete mode 100644 third_party/xla/xla/client/lib/pooling.h delete mode 100644 third_party/xla/xla/client/lib/prng.h delete mode 100644 third_party/xla/xla/client/lib/qr.h delete mode 100644 third_party/xla/xla/client/lib/quantize.h delete mode 100644 third_party/xla/xla/client/lib/self_adjoint_eig.h delete mode 100644 third_party/xla/xla/client/lib/slicing.h delete mode 100644 third_party/xla/xla/client/lib/sorting.h delete mode 100644 third_party/xla/xla/client/lib/svd.h delete mode 100644 third_party/xla/xla/client/lib/tridiagonal.h delete mode 100644 third_party/xla/xla/client/lib/tuple.h diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index 80aa72001f0314..4b3669466ea5f2 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -23,170 +23,6 @@ filegroup( # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() -cc_library( - name = "arithmetic", - hdrs = ["arithmetic.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:arithmetic instead.", - deps = [ - "//xla/hlo/builder/lib:arithmetic", - ], -) - -cc_library( - name = "comparators", - hdrs = [ - "comparators.h", - ], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:comparators instead.", - deps = [ - "//xla/hlo/builder/lib:comparators", - ], -) - -cc_library( - name = "constants", - hdrs = ["constants.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:constants instead.", - deps = [ - "//xla/hlo/builder/lib:constants", - ], -) - -cc_library( - name = "broadcast", - hdrs = ["broadcast.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:broadcast instead.", - deps = [ - "//xla/hlo/builder/lib:broadcast", - ], -) - -cc_library( - name = "conv_grad_size_util", - hdrs = ["conv_grad_size_util.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:conv_grad_size_util instead.", - deps = [ - "//xla/hlo/builder/lib:conv_grad_size_util", - ], -) - -cc_library( - name = "dynamic_shaped_ops", - hdrs = ["dynamic_shaped_ops.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:dynamic_shaped_ops instead.", - deps = [ - "//xla/hlo/builder/lib:dynamic_shaped_ops", - ], -) - -cc_library( - name = "loops", - hdrs = ["loops.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:loops instead.", - deps = [ - "//xla/hlo/builder/lib:loops", - ], -) - -cc_library( - name = "math", - hdrs = [ - "math.h", - ], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:math instead.", - deps = [ - "//xla/hlo/builder/lib:math", - ], -) - -cc_library( - name = "matrix", - hdrs = ["matrix.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:matrix instead.", - deps = [ - "//xla/hlo/builder/lib:matrix", - ], -) - -cc_library( - name = "pooling", - hdrs = ["pooling.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:pooling instead.", - deps = [ - "//xla/hlo/builder/lib:pooling", - ], -) - -cc_library( - name = "prng", - hdrs = ["prng.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:prng instead.", - deps = [ - "//xla/hlo/builder/lib:prng", - ], -) - -cc_library( - name = "qr", - hdrs = ["qr.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:qr instead.", - deps = [ - "//xla/hlo/builder/lib:qr", - ], -) - -cc_library( - name = "lu_decomposition", - hdrs = ["lu_decomposition.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:lu_decomposition instead.", - deps = [ - "//xla/hlo/builder/lib:lu_decomposition", - ], -) - -cc_library( - name = "approx_topk", - hdrs = ["approx_topk.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:approx_topk instead.", - deps = [ - "//xla/hlo/builder/lib:approx_topk", - ], -) - -cc_library( - name = "approx_topk_shape", - hdrs = ["approx_topk_shape.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:approx_topk_shape instead.", - deps = ["//xla/hlo/builder/lib:approx_topk_shape"], -) - -cc_library( - name = "slicing", - hdrs = ["slicing.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:slicing instead.", - deps = [ - "//xla/hlo/builder/lib:slicing", - ], -) - -cc_library( - name = "sorting", - hdrs = ["sorting.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:sorting instead.", - deps = [ - "//xla/hlo/builder/lib:sorting", - ], -) - -cc_library( - name = "quantize", - hdrs = ["quantize.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:quantize instead.", - deps = [ - "//xla/hlo/builder/lib:quantize", - ], -) - cc_library( name = "testing", srcs = ["testing.cc"], @@ -208,47 +44,3 @@ cc_library( "@local_tsl//tsl/platform:errors", ], ) - -cc_library( - name = "self_adjoint_eig", - hdrs = ["self_adjoint_eig.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:self_adjoint_eig instead.", - deps = [ - "//xla/hlo/builder/lib:self_adjoint_eig", - ], -) - -cc_library( - name = "svd", - hdrs = ["svd.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:svd instead.", - deps = [ - "//xla/hlo/builder/lib:svd", - ], -) - -cc_library( - name = "tridiagonal", - hdrs = ["tridiagonal.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:tridiagonal instead.", - deps = [ - "//xla/hlo/builder/lib:tridiagonal", - ], -) - -cc_library( - name = "logdet", - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:logdet instead.", - deps = [ - "//xla/hlo/builder/lib:logdet", - ], -) - -cc_library( - name = "tuple", - hdrs = ["tuple.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:tuple instead.", - deps = [ - "//xla/hlo/builder/lib:tuple", - ], -) diff --git a/third_party/xla/xla/client/lib/approx_topk.h b/third_party/xla/xla/client/lib/approx_topk.h deleted file mode 100644 index 175a12cad0e94a..00000000000000 --- a/third_party/xla/xla/client/lib/approx_topk.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_APPROX_TOPK_H_ -#define XLA_CLIENT_LIB_APPROX_TOPK_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/approx_topk.h" - -#endif // XLA_CLIENT_LIB_APPROX_TOPK_H_ diff --git a/third_party/xla/xla/client/lib/approx_topk_shape.h b/third_party/xla/xla/client/lib/approx_topk_shape.h deleted file mode 100644 index eef1e296f36fd3..00000000000000 --- a/third_party/xla/xla/client/lib/approx_topk_shape.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ -#define XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/approx_topk_shape.h" - -#endif // XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ diff --git a/third_party/xla/xla/client/lib/arithmetic.h b/third_party/xla/xla/client/lib/arithmetic.h deleted file mode 100644 index 0b8e000a2f276b..00000000000000 --- a/third_party/xla/xla/client/lib/arithmetic.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_ARITHMETIC_H_ -#define XLA_CLIENT_LIB_ARITHMETIC_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/arithmetic.h" - -#endif // XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/third_party/xla/xla/client/lib/broadcast.h b/third_party/xla/xla/client/lib/broadcast.h deleted file mode 100644 index deb85ae9ab8585..00000000000000 --- a/third_party/xla/xla/client/lib/broadcast.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_BROADCAST_H_ -#define XLA_CLIENT_LIB_BROADCAST_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/broadcast.h" - -#endif // XLA_CLIENT_LIB_BROADCAST_H_ diff --git a/third_party/xla/xla/client/lib/comparators.h b/third_party/xla/xla/client/lib/comparators.h deleted file mode 100644 index ad9b37d716d717..00000000000000 --- a/third_party/xla/xla/client/lib/comparators.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_COMPARATORS_H_ -#define XLA_CLIENT_LIB_COMPARATORS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/comparators.h" - -#endif // XLA_CLIENT_LIB_COMPARATORS_H_ diff --git a/third_party/xla/xla/client/lib/constants.h b/third_party/xla/xla/client/lib/constants.h deleted file mode 100644 index 2135f481977396..00000000000000 --- a/third_party/xla/xla/client/lib/constants.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_CONSTANTS_H_ -#define XLA_CLIENT_LIB_CONSTANTS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/constants.h" - -#endif // XLA_CLIENT_LIB_CONSTANTS_H_ diff --git a/third_party/xla/xla/client/lib/conv_grad_size_util.h b/third_party/xla/xla/client/lib/conv_grad_size_util.h deleted file mode 100644 index e991982968da9e..00000000000000 --- a/third_party/xla/xla/client/lib/conv_grad_size_util.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ -#define XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/conv_grad_size_util.h" - -#endif // XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ diff --git a/third_party/xla/xla/client/lib/dynamic_shaped_ops.h b/third_party/xla/xla/client/lib/dynamic_shaped_ops.h deleted file mode 100644 index cf62a37d6f920e..00000000000000 --- a/third_party/xla/xla/client/lib/dynamic_shaped_ops.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ -#define XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" - -#endif // XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ diff --git a/third_party/xla/xla/client/lib/loops.h b/third_party/xla/xla/client/lib/loops.h deleted file mode 100644 index d714efeaa415f1..00000000000000 --- a/third_party/xla/xla/client/lib/loops.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_LOOPS_H_ -#define XLA_CLIENT_LIB_LOOPS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/loops.h" - -#endif // XLA_CLIENT_LIB_LOOPS_H_ diff --git a/third_party/xla/xla/client/lib/lu_decomposition.h b/third_party/xla/xla/client/lib/lu_decomposition.h deleted file mode 100644 index 752e84c9d2b12f..00000000000000 --- a/third_party/xla/xla/client/lib/lu_decomposition.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ -#define XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/lu_decomposition.h" - -#endif // XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ diff --git a/third_party/xla/xla/client/lib/math.h b/third_party/xla/xla/client/lib/math.h deleted file mode 100644 index 9956776ee87d1a..00000000000000 --- a/third_party/xla/xla/client/lib/math.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_MATH_H_ -#define XLA_CLIENT_LIB_MATH_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/math.h" - -#endif // XLA_CLIENT_LIB_MATH_H_ diff --git a/third_party/xla/xla/client/lib/matrix.h b/third_party/xla/xla/client/lib/matrix.h deleted file mode 100644 index aaf938786fc020..00000000000000 --- a/third_party/xla/xla/client/lib/matrix.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_MATRIX_H_ -#define XLA_CLIENT_LIB_MATRIX_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/matrix.h" - -#endif // XLA_CLIENT_LIB_MATRIX_H_ diff --git a/third_party/xla/xla/client/lib/pooling.h b/third_party/xla/xla/client/lib/pooling.h deleted file mode 100644 index 22f3d2f0b07b9c..00000000000000 --- a/third_party/xla/xla/client/lib/pooling.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_POOLING_H_ -#define XLA_CLIENT_LIB_POOLING_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/pooling.h" - -#endif // XLA_CLIENT_LIB_POOLING_H_ diff --git a/third_party/xla/xla/client/lib/prng.h b/third_party/xla/xla/client/lib/prng.h deleted file mode 100644 index 0c9e460ba10cbb..00000000000000 --- a/third_party/xla/xla/client/lib/prng.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_PRNG_H_ -#define XLA_CLIENT_LIB_PRNG_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/prng.h" - -#endif // XLA_CLIENT_LIB_PRNG_H_ diff --git a/third_party/xla/xla/client/lib/qr.h b/third_party/xla/xla/client/lib/qr.h deleted file mode 100644 index 743b36503b6175..00000000000000 --- a/third_party/xla/xla/client/lib/qr.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_QR_H_ -#define XLA_CLIENT_LIB_QR_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/qr.h" - -#endif // XLA_CLIENT_LIB_QR_H_ diff --git a/third_party/xla/xla/client/lib/quantize.h b/third_party/xla/xla/client/lib/quantize.h deleted file mode 100644 index 459716b36b54db..00000000000000 --- a/third_party/xla/xla/client/lib/quantize.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_QUANTIZE_H_ -#define XLA_CLIENT_LIB_QUANTIZE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/quantize.h" - -#endif // XLA_CLIENT_LIB_QUANTIZE_H_ diff --git a/third_party/xla/xla/client/lib/self_adjoint_eig.h b/third_party/xla/xla/client/lib/self_adjoint_eig.h deleted file mode 100644 index ae81dbc0baf5a0..00000000000000 --- a/third_party/xla/xla/client/lib/self_adjoint_eig.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ -#define XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/self_adjoint_eig.h" - -#endif // XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ diff --git a/third_party/xla/xla/client/lib/slicing.h b/third_party/xla/xla/client/lib/slicing.h deleted file mode 100644 index c2ea243ae2c937..00000000000000 --- a/third_party/xla/xla/client/lib/slicing.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_SLICING_H_ -#define XLA_CLIENT_LIB_SLICING_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/slicing.h" - -#endif // XLA_CLIENT_LIB_SLICING_H_ diff --git a/third_party/xla/xla/client/lib/sorting.h b/third_party/xla/xla/client/lib/sorting.h deleted file mode 100644 index 5cb81a43c11f36..00000000000000 --- a/third_party/xla/xla/client/lib/sorting.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_SORTING_H_ -#define XLA_CLIENT_LIB_SORTING_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/sorting.h" - -#endif // XLA_CLIENT_LIB_SORTING_H_ diff --git a/third_party/xla/xla/client/lib/svd.h b/third_party/xla/xla/client/lib/svd.h deleted file mode 100644 index 54893697c5fced..00000000000000 --- a/third_party/xla/xla/client/lib/svd.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_SVD_H_ -#define XLA_CLIENT_LIB_SVD_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/svd.h" - -#endif // XLA_CLIENT_LIB_SVD_H_ diff --git a/third_party/xla/xla/client/lib/tridiagonal.h b/third_party/xla/xla/client/lib/tridiagonal.h deleted file mode 100644 index 5cc51c5e98262e..00000000000000 --- a/third_party/xla/xla/client/lib/tridiagonal.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_TRIDIAGONAL_H_ -#define XLA_CLIENT_LIB_TRIDIAGONAL_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/tridiagonal.h" - -#endif // XLA_CLIENT_LIB_TRIDIAGONAL_H_ diff --git a/third_party/xla/xla/client/lib/tuple.h b/third_party/xla/xla/client/lib/tuple.h deleted file mode 100644 index c1dc9de027a50f..00000000000000 --- a/third_party/xla/xla/client/lib/tuple.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_LIB_TUPLE_H_ -#define XLA_CLIENT_LIB_TUPLE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/lib/tuple.h" - -#endif // XLA_CLIENT_LIB_TUPLE_H_ From 48173835528a45fd00c5fe5df9149dba8e38655f Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 11 Apr 2025 00:17:33 -0700 Subject: [PATCH 0543/1324] Remove deprecated targets from build and header files PiperOrigin-RevId: 746340561 --- third_party/xla/xla/tests/BUILD | 20 ----------------- third_party/xla/xla/tests/filecheck.h | 22 ------------------- .../xla/xla/tests/verified_hlo_module.h | 21 ------------------ 3 files changed, 63 deletions(-) delete mode 100644 third_party/xla/xla/tests/filecheck.h delete mode 100644 third_party/xla/xla/tests/verified_hlo_module.h diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 99624926729622..a2d68605f00c95 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -106,16 +106,6 @@ cc_library( ], ) -cc_library( - name = "verified_hlo_module", - testonly = True, - hdrs = ["verified_hlo_module.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/testlib:verified_hlo_module instead.", - deps = [ - "//xla/hlo/testlib:verified_hlo_module", - ], -) - cc_library( name = "pjrt_client_registry", srcs = ["pjrt_client_registry.cc"], @@ -449,16 +439,6 @@ cc_library( ], ) -cc_library( - name = "filecheck", - testonly = True, - hdrs = ["filecheck.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/testlib:filecheck instead.", - deps = [ - "//xla/hlo/testlib:filecheck", - ], -) - cc_library( name = "local_client_test_base", testonly = True, diff --git a/third_party/xla/xla/tests/filecheck.h b/third_party/xla/xla/tests/filecheck.h deleted file mode 100644 index e96152510c455f..00000000000000 --- a/third_party/xla/xla/tests/filecheck.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TESTS_FILECHECK_H_ -#define XLA_TESTS_FILECHECK_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/testlib/filecheck.h" - -#endif // XLA_TESTS_FILECHECK_H_ diff --git a/third_party/xla/xla/tests/verified_hlo_module.h b/third_party/xla/xla/tests/verified_hlo_module.h deleted file mode 100644 index 3b27b3bd0cefa5..00000000000000 --- a/third_party/xla/xla/tests/verified_hlo_module.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_TESTS_VERIFIED_HLO_MODULE_H_ -#define XLA_TESTS_VERIFIED_HLO_MODULE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/testlib/verified_hlo_module.h" - -#endif // XLA_TESTS_VERIFIED_HLO_MODULE_H_ From dbf186b2d8b48b03fb7125136e56f4d0c7974eb2 Mon Sep 17 00:00:00 2001 From: Robert David Date: Fri, 11 Apr 2025 01:03:18 -0700 Subject: [PATCH 0544/1324] The second parameter to `TF_LITE_KERNEL_LOG` is expected to be a `printf` format string, known at compile time. Fix the used format specifiers to match the appropriate value, potentially updating the type of the value. Run IWYU to fix includes. PiperOrigin-RevId: 746352306 --- tensorflow/lite/delegates/xnnpack/BUILD | 12 +++--- .../delegates/xnnpack/xnnpack_delegate.cc | 43 ++++++++++--------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index a0905e314a020b..9c7a9d7021fdaf 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -240,11 +240,11 @@ cc_library( ":file_util", ":flexbuffers_util", ":quantization_util", - ":tflite_with_xnnpack_dynamic_fully_connected", - ":tflite_with_xnnpack_logging", - ":tflite_with_xnnpack_qs8", - ":tflite_with_xnnpack_qu8", - ":tflite_with_xnnpack_transient_indirection_buffer", + ":tflite_with_xnnpack_dynamic_fully_connected", # buildcleaner: keep + ":tflite_with_xnnpack_logging", # buildcleaner: keep + ":tflite_with_xnnpack_qs8", # buildcleaner: keep + ":tflite_with_xnnpack_qu8", # buildcleaner: keep + ":tflite_with_xnnpack_transient_indirection_buffer", # buildcleaner: keep ":weight_cache", "//tensorflow/compiler/mlir/lite/kernels/internal:compatibility_macros", "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata", @@ -257,7 +257,6 @@ cc_library( "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:padding", - "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal/utils:sparsity_format_converter", "//tensorflow/lite/schema:schema_fbs", @@ -301,7 +300,6 @@ cc_library( "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:padding", - "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal/utils:sparsity_format_converter", 
"//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 0d904a9c44cf35..48c8c03eedad52 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -319,17 +319,18 @@ xnn_datatype GetXNNPackDatatype(TfLiteContext* context, if (!CheckFp16Scale(context, tensor, t, quantization_params)) { return xnn_datatype_invalid; } - int num_scales = + int64_t num_scales = NumElements(&context->tensors[quantization_params->scale]); - int num_filter_elements = NumElements(&tensor); + int64_t num_filter_elements = NumElements(&tensor); if (num_filter_elements / num_scales != quantization_params->blocksize) { - TF_LITE_KERNEL_LOG(context, - "Unsupported combination of filter elements %d " - "number of scales %d and blocksize %d " - "%s tensor %d in XNNPACK delegate", - num_filter_elements, num_scales, - quantization_params->blocksize, t); + TF_LITE_KERNEL_LOG( + context, + "Unsupported combination of filter elements %" PRId64 + " number of scales %" PRId64 " and blocksize %" PRId32 + " for %s tensor %d in XNNPACK delegate", + num_filter_elements, num_scales, quantization_params->blocksize, + tensor.name, t); return xnn_datatype_invalid; } break; @@ -351,9 +352,10 @@ xnn_datatype GetXNNPackDatatype(TfLiteContext* context, case kTfLiteBlockwiseQuantization: return xnn_datatype_qbint4; default: - TF_LITE_KERNEL_LOG( - context, - "Unsupported quantization type %zu for INT4 tensor #%d", t); + TF_LITE_KERNEL_LOG(context, + "Unsupported quantization type %d for INT4 " + "tensor %d in XNNPACK delegate", + tensor.quantization.type, t); return xnn_datatype_invalid; } case kTfLiteInt8: @@ -2331,8 +2333,8 @@ class Subgraph { if (quantization_params->blocksize % 32 != 0) { TF_LITE_MAYBE_KERNEL_LOG( context, - "Blocksize %zu must be multiple of 32 in " - "tensor #%d in node #%d", + "Blocksize %" PRId32 + " must be multiple of 32 in tensor #%d in node #%d", quantization_params->blocksize, tensor_index, node_index); return kTfLiteError; } @@ -4372,11 +4374,11 @@ class Subgraph { logging_context, axis_tensor, node->inputs->data[1], BuiltinOperator_EXPAND_DIMS, node_index)); - const size_t num_new_axes = NumElements(&axis_tensor); + const int64_t num_new_axes = NumElements(&axis_tensor); if (num_new_axes != 1) { TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unexpected number of axes (%d) in node #%d: " - "TFLite only supports 1 new axes", + "unexpected number of axes (%" PRId64 + ") in node #%d: TFLite only supports 1 new axes", num_new_axes, node_index); return kTfLiteError; } @@ -5772,13 +5774,14 @@ class Subgraph { // inside our kernels; check here and punt those to the default // delegate implementation for it to decide how to handle them. const int64_t extent = input_tensor.dims->data[i]; - const size_t offset = begins[i] < 0 ? begins[i] + extent : begins[i]; - const size_t size = + const int64_t offset = begins[i] < 0 ? begins[i] + extent : begins[i]; + const int64_t size = ends[i] <= 0 ? 
ends[i] + extent - offset : ends[i] - offset; if (offset + size > extent) { TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "offset %zu + size %zu exceeds extent %zu in " - "STRIDED_SLICE node #%d for dimension %zu", + "offset %" PRId64 " + size %" PRId64 + " exceeds extent %" PRId64 + " in STRIDED_SLICE node #%d for dimension %zu", offset, size, extent, node_index, i); return kTfLiteError; } From 76e3fc33bb02c4dfb3c32e5b114aaae961022f15 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Fri, 11 Apr 2025 01:29:32 -0700 Subject: [PATCH 0545/1324] Integrate Triton up to [c629b06a](https://github.com/openai/triton/commits/c629b06ac06a7f8596fa2ba9b7728d3fd8314f11) PiperOrigin-RevId: 746359168 --- third_party/triton/llvm_integration/series.bzl | 2 -- third_party/triton/temporary/ptxas_12_4.patch | 15 --------------- third_party/triton/temporary/series.bzl | 1 - third_party/triton/workspace.bzl | 4 ++-- .../triton/llvm_integration/series.bzl | 2 -- .../third_party/triton/temporary/ptxas_12_4.patch | 15 --------------- .../xla/third_party/triton/temporary/series.bzl | 1 - third_party/xla/third_party/triton/workspace.bzl | 4 ++-- .../codegen/triton/compilation_pipeline_cuda.cc | 2 +- .../codegen/triton/compilation_pipeline_rocm.cc | 1 + .../xla/service/gpu/autotuning/autotuner_util.h | 2 +- 11 files changed, 7 insertions(+), 42 deletions(-) delete mode 100644 third_party/triton/temporary/ptxas_12_4.patch delete mode 100644 third_party/xla/third_party/triton/temporary/ptxas_12_4.patch diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index a0964d051f1668..d820528c8a38f6 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,8 +8,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ - "//third_party/triton:llvm_integration/cl741558316.patch", - "//third_party/triton:llvm_integration/cl742325920.patch", "//third_party/triton:llvm_integration/cl744822685.patch", # Add new patches just above this line ] diff --git a/third_party/triton/temporary/ptxas_12_4.patch b/third_party/triton/temporary/ptxas_12_4.patch deleted file mode 100644 index fcc71519083af1..00000000000000 --- a/third_party/triton/temporary/ptxas_12_4.patch +++ /dev/null @@ -1,15 +0,0 @@ -This can be removed as soon as we updated ptxas to 12.8 (b/385480934). - ---- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2025-03-26 00:22:57.000000000 -0700 -+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2025-04-02 02:43:36.000000000 -0700 -@@ -180,6 +180,10 @@ - // warpgroup. - Value warp = - b.and_(rewriter.create(loc), b.i32_val(0xFFFFFFFC)); -+ // Workaround for a bug in ptxas 12.3 that cause a failure in -+ // test_core.py::test_dot. The shuffle will force the compiler to treat the -+ // value as uniform and prevent wrong optimizations. -+ warp = mlir::LLVM::NVIDIA::shuffleIdx(loc, rewriter, warp, 0); - Value warpM = b.urem(warp, b.i32_val(wpt[0])); - Value warpId = b.urem(warpM, b.i32_val(shapePerCTA[0] / instrShape[0])); - diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index f38405c475ed6c..4fa55269e3323c 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -14,6 +14,5 @@ those to this list. 
""" temporary_patch_list = [ - "//third_party/triton:temporary/ptxas_12_4.patch", # Add new patches just above this line ] diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 656679a89df22a..895c7b4c96b2d3 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "triton_integrate_branch-1.2" - TRITON_SHA256 = "ba715575f8e8ead49df545a40c9557a4e40174400892fcf28fefdd15ff3f2c6a" + TRITON_COMMIT = "triton_integrate_branch-1.3" + TRITON_SHA256 = "930d2c40ded0300c4070b8f43a845d493c7a078a4864d023be3af2e16ec4f884" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl index a0964d051f1668..d820528c8a38f6 100644 --- a/third_party/xla/third_party/triton/llvm_integration/series.bzl +++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl @@ -8,8 +8,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ - "//third_party/triton:llvm_integration/cl741558316.patch", - "//third_party/triton:llvm_integration/cl742325920.patch", "//third_party/triton:llvm_integration/cl744822685.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/triton/temporary/ptxas_12_4.patch b/third_party/xla/third_party/triton/temporary/ptxas_12_4.patch deleted file mode 100644 index fcc71519083af1..00000000000000 --- a/third_party/xla/third_party/triton/temporary/ptxas_12_4.patch +++ /dev/null @@ -1,15 +0,0 @@ -This can be removed as soon as we updated ptxas to 12.8 (b/385480934). - ---- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2025-03-26 00:22:57.000000000 -0700 -+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2025-04-02 02:43:36.000000000 -0700 -@@ -180,6 +180,10 @@ - // warpgroup. - Value warp = - b.and_(rewriter.create(loc), b.i32_val(0xFFFFFFFC)); -+ // Workaround for a bug in ptxas 12.3 that cause a failure in -+ // test_core.py::test_dot. The shuffle will force the compiler to treat the -+ // value as uniform and prevent wrong optimizations. -+ warp = mlir::LLVM::NVIDIA::shuffleIdx(loc, rewriter, warp, 0); - Value warpM = b.urem(warp, b.i32_val(wpt[0])); - Value warpId = b.urem(warpM, b.i32_val(shapePerCTA[0] / instrShape[0])); - diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl index f38405c475ed6c..4fa55269e3323c 100644 --- a/third_party/xla/third_party/triton/temporary/series.bzl +++ b/third_party/xla/third_party/triton/temporary/series.bzl @@ -14,6 +14,5 @@ those to this list. 
""" temporary_patch_list = [ - "//third_party/triton:temporary/ptxas_12_4.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index 656679a89df22a..895c7b4c96b2d3 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "triton_integrate_branch-1.2" - TRITON_SHA256 = "ba715575f8e8ead49df545a40c9557a4e40174400892fcf28fefdd15ff3f2c6a" + TRITON_COMMIT = "triton_integrate_branch-1.3" + TRITON_SHA256 = "930d2c40ded0300c4070b8f43a845d493c7a078a4864d023be3af2e16ec4f884" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc index 835b25a5fef1de..398fe60a78c314 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc @@ -95,10 +95,10 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, pm->addPass(mt::gpu::createTritonGPUOptimizeAccumulatorInit()); pm->addPass( mt::gpu::createTritonGPUAutomaticWarpSpecialization({num_stages})); + pm->addPass(mt::gpu::createTritonGPUHoistTMEMAlloc()); pm->addPass(mt::gpu::createTritonGPUPipeline({num_stages})); pm->addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf()); pm->addPass(mlir::createTritonNvidiaGPUPromoteLHSToTMemPass()); - pm->addPass(mlir::createTritonNvidiaGPUKeepAccInTMemPass()); pm->addPass(mlir::createCanonicalizerPass()); } else if (cc.IsAtLeastAmpere()) { // Even though we don't run on pre-Ampere architectures anymore, we keep diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc index 21919310c67dfd..359032fb8e7f33 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc @@ -129,6 +129,7 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mlir::createTritonAMDGPUConvertToBufferOpsPass(arch_name)); } + pm->addPass(mlir::createTritonAMDGPUFoldTrueCmpIPass()); pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mlir::createCSEPass()); pm->addPass(mlir::createSymbolDCEPass()); diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h index d49545966372cf..a9b3fd70dce22f 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h @@ -61,7 +61,7 @@ class AutotuneCacheKey { // Tie a version to the cache key in order to invalidate the cache when // necessary. This should be incremented on triton upgrades or any other // changes that may affect the autotuning results. - static constexpr int kCurrentVersion = 2; + static constexpr int kCurrentVersion = 3; AutotuneCacheKey(const se::DeviceDescription& device_description, const HloInstruction& instruction) From cd9d1bd678747396714a979ce7616d311a38d43d Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 11 Apr 2025 02:02:44 -0700 Subject: [PATCH 0546/1324] Update GraphDef version to 2194. PiperOrigin-RevId: 746368496 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 59eef5704606e7..0e6bfd86c33f93 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2193 // Updated: 2025/4/10 +#define TF_GRAPH_DEF_VERSION 2194 // Updated: 2025/4/11 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From f99f00e2eed15f75ae4978a8155f2766b1a3ca54 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 02:03:37 -0700 Subject: [PATCH 0547/1324] compat: Update forward compatibility horizon to 2025-04-11 PiperOrigin-RevId: 746368802 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 75b98c6cb2b296..14e0ab50b31db3 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 10) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 11) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From ad14730b131b74d9504f2bfdfa27ee933bab9423 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Fri, 11 Apr 2025 02:25:07 -0700 Subject: [PATCH 0548/1324] PR #23400: Fix complex_unary_op_test_cpu on macOS Apple Imported from GitHub PR https://github.com/openxla/xla/pull/23400 As in the title. 
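For context, the generated C++ tables pin `min` to the smallest normal value, and this change remaps it on aarch64 to `std::nextafter(min, 1)` because, per the comments in the diff below, samples built directly on `min` behave differently when denormals-are-zero (DAZ) is enabled. A minimal standalone sketch of the values involved (illustrative only; it prints the candidates but does not itself toggle DAZ, which is a CPU mode, not a compiler flag):

  #include <cmath>
  #include <cstdio>
  #include <limits>

  int main() {
    const float min = std::numeric_limits<float>::min();  // smallest normal f32
    const float bumped = std::nextafter(min, 1.0f);       // value used on __aarch64__
    // min / 2 is subnormal; under DAZ/FTZ it reads back as exactly 0, which is
    // why sample tables built directly around `min` are platform-sensitive.
    std::printf("min    = %a\n", min);
    std::printf("bumped = %a\n", bumped);
    std::printf("min/2  = %a\n", min / 2.0f);
    return 0;
  }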
Fixes https://github.com/openxla/xla/issues/19824 Copybara import of the project: -- d215f46c6f50020fa5dcd3a3f2b08f67d0516460 by Pearu Peterson : Fix complex_unary_op_test_cpu on macOS Apple Merging this change closes #23400 PiperOrigin-RevId: 746374794 --- .../xla/xla/tests/complex_unary_op_samples.h | 66 +++++++++---------- .../generate_complex_unary_op_samples.py | 24 ++++++- 2 files changed, 55 insertions(+), 35 deletions(-) diff --git a/third_party/xla/xla/tests/complex_unary_op_samples.h b/third_party/xla/xla/tests/complex_unary_op_samples.h index 1851e564473adb..94d38ee05243fa 100644 --- a/third_party/xla/xla/tests/complex_unary_op_samples.h +++ b/third_party/xla/xla/tests/complex_unary_op_samples.h @@ -48,7 +48,11 @@ struct Log1p { const T pi3_4 = 2.3561945f; const T zero = 0.0f; const T inf = std::numeric_limits::infinity(); +#ifdef __aarch64__ + const T min = std::nextafter(std::numeric_limits::min(), T(1)); +#else const T min = std::numeric_limits::min(); +#endif const T max = std::numeric_limits::max(); const TableType table{ // clang-format off @@ -354,7 +358,11 @@ struct Log1p { const T pi3_4 = 2.356194490192345; const T zero = 0.0; const T inf = std::numeric_limits::infinity(); +#ifdef __aarch64__ + const T min = std::nextafter(std::numeric_limits::min(), T(1)); +#else const T min = std::numeric_limits::min(); +#endif const T max = std::numeric_limits::max(); const TableType table{ // clang-format off @@ -671,7 +679,11 @@ struct Tan { const T nan = std::nanf(""); const T zero = 0.0f; const T inf = std::numeric_limits::infinity(); +#ifdef __aarch64__ + const T min = std::nextafter(std::numeric_limits::min(), T(1)); +#else const T min = std::numeric_limits::min(); +#endif const T max = std::numeric_limits::max(); const TableType table{ // clang-format off @@ -804,15 +816,9 @@ struct Tan { /* 123 */ { { -2.e+00f, -min }, { 2.1850398e+00f, -6.7877737e-38f }, 2.5e-01f }, /* 124 */ { { -3.6093321e-13f, -min }, { -3.6093321e-13f, -min }, 2.1990233e+12f }, /* 125 */ { { -6.5136393e-26f, -min }, { -6.5136393e-26f, -min }, 9.6714066e+24f }, -#ifndef __aarch64__ -// TODO(b/342448599); Fix and re-enable on Arm. /* 126 */ { { -min, -min }, { -min, -min }, 4.2535296e+37f }, -#endif /* 127 */ { { zero, -min }, { zero, -min }, 4.2535296e+37f }, -#ifndef __aarch64__ -// TODO(b/342448599); Fix and re-enable on Arm. /* 128 */ { { min, -min }, { min, -min }, 4.2535296e+37f }, -#endif /* 129 */ { { 6.5136393e-26f, -min }, { 6.5136393e-26f, -min }, 9.6714066e+24f }, /* 130 */ { { 3.6093321e-13f, -min }, { 3.6093321e-13f, -min }, 2.1990233e+12f }, /* 131 */ { { 2.e+00f, -min }, { -2.1850398e+00f, -6.7877737e-38f }, 2.5e-01f }, @@ -827,15 +833,9 @@ struct Tan { /* 140 */ { { -2.e+00f, zero }, { 2.1850398e+00f, zero }, 2.5e-01f }, /* 141 */ { { -3.6093321e-13f, zero }, { -3.6093321e-13f, zero }, 2.1990233e+12f }, /* 142 */ { { -6.5136393e-26f, zero }, { -6.5136393e-26f, zero }, 9.6714066e+24f }, -#ifndef __aarch64__ -// TODO(b/342448599); Fix and re-enable on Arm. /* 143 */ { { -min, zero }, { -min, zero }, 4.2535296e+37f }, -#endif /* 144 */ { { zero, zero }, { zero, zero }, 1.e+00f }, -#ifndef __aarch64__ -// TODO(b/342448599); Fix and re-enable on Arm. 
/* 145 */ { { min, zero }, { min, zero }, 4.2535296e+37f }, -#endif /* 146 */ { { 6.5136393e-26f, zero }, { 6.5136393e-26f, zero }, 9.6714066e+24f }, /* 147 */ { { 3.6093321e-13f, zero }, { 3.6093321e-13f, zero }, 2.1990233e+12f }, /* 148 */ { { 2.e+00f, zero }, { -2.1850398e+00f, zero }, 2.5e-01f }, @@ -850,15 +850,9 @@ struct Tan { /* 157 */ { { -2.e+00f, min }, { 2.1850398e+00f, 6.7877737e-38f }, 2.5e-01f }, /* 158 */ { { -3.6093321e-13f, min }, { -3.6093321e-13f, min }, 2.1990233e+12f }, /* 159 */ { { -6.5136393e-26f, min }, { -6.5136393e-26f, min }, 9.6714066e+24f }, -#ifndef __aarch64__ -// TODO(b/342448599); Fix and re-enable on Arm. /* 160 */ { { -min, min }, { -min, min }, 4.2535296e+37f }, -#endif /* 161 */ { { zero, min }, { zero, min }, 4.2535296e+37f }, -#ifndef __aarch64__ -// TODO(b/342448599); Fix and re-enable on Arm. /* 162 */ { { min, min }, { min, min }, 4.2535296e+37f }, -#endif /* 163 */ { { 6.5136393e-26f, min }, { 6.5136393e-26f, min }, 9.6714066e+24f }, /* 164 */ { { 3.6093321e-13f, min }, { 3.6093321e-13f, min }, 2.1990233e+12f }, /* 165 */ { { 2.e+00f, min }, { -2.1850398e+00f, 6.7877737e-38f }, 2.5e-01f }, @@ -985,7 +979,6 @@ struct Tan { /* 286 */ { { 6.1409603e+25f, inf }, { zero, 1.e+00f }, 5.e-01f }, /* 287 */ { { max, inf }, { zero, 1.e+00f }, 5.e-01f }, /* 288 */ { { inf, inf }, { zero, 1.e+00f }, 5.e-01f } - // clang-format on }; return table; @@ -993,7 +986,11 @@ struct Tan { const T nan = std::nan(""); const T zero = 0.0; const T inf = std::numeric_limits::infinity(); +#ifdef __aarch64__ + const T min = std::nextafter(std::numeric_limits::min(), T(1)); +#else const T min = std::numeric_limits::min(); +#endif const T max = std::numeric_limits::max(); const TableType table{ // clang-format off @@ -1126,13 +1123,9 @@ struct Tan { /* 123 */ { { -1.9999999999998694e+00, -min }, { 2.185039863262273e+00, -1.2848464717505794e-307 }, 2.5e-01 }, /* 124 */ { { -4.4647944971961829e-103, -min }, { -4.4647944971961829e-103, -min }, 2.2397447421778042e+102 }, /* 125 */ { { -9.9671949510973086e-206, -min }, { -9.9671949510973086e-206, -min }, 1.0032913020226237e+205 }, -#ifndef __aarch64__ // Seems that denormalized values are being flushed to zero on arm (see b/342448599) /* 126 */ { { -min, -min }, { -min, -min }, 2.2471164185778949e+307 }, -#endif /* 127 */ { { zero, -min }, { zero, -min }, 2.2471164185778949e+307 }, -#ifndef __aarch64__ // Seems that denormalized values are being flushed to zero on arm (see b/342448599) /* 128 */ { { min, -min }, { min, -min }, 2.2471164185778949e+307 }, -#endif /* 129 */ { { 9.9671949510973086e-206, -min }, { 9.9671949510973086e-206, -min }, 1.0032913020226237e+205 }, /* 130 */ { { 4.4647944971961829e-103, -min }, { 4.4647944971961829e-103, -min }, 2.2397447421778042e+102 }, /* 131 */ { { 1.9999999999998694e+00, -min }, { -2.185039863262273e+00, -1.2848464717505794e-307 }, 2.5e-01 }, @@ -1147,13 +1140,9 @@ struct Tan { /* 140 */ { { -1.9999999999998694e+00, zero }, { 2.185039863262273e+00, zero }, 2.5e-01 }, /* 141 */ { { -4.4647944971961829e-103, zero }, { -4.4647944971961829e-103, zero }, 2.2397447421778042e+102 }, /* 142 */ { { -9.9671949510973086e-206, zero }, { -9.9671949510973086e-206, zero }, 1.0032913020226237e+205 }, -#ifndef __aarch64__ // Seems that denormalized values are being flushed to zero on arm (see b/342448599) /* 143 */ { { -min, zero }, { -min, zero }, 2.2471164185778949e+307 }, -#endif /* 144 */ { { zero, zero }, { zero, zero }, 1.e+00 }, -#ifndef __aarch64__ // Seems that denormalized values are being 
flushed to zero on arm (see b/342448599) /* 145 */ { { min, zero }, { min, zero }, 2.2471164185778949e+307 }, -#endif /* 146 */ { { 9.9671949510973086e-206, zero }, { 9.9671949510973086e-206, zero }, 1.0032913020226237e+205 }, /* 147 */ { { 4.4647944971961829e-103, zero }, { 4.4647944971961829e-103, zero }, 2.2397447421778042e+102 }, /* 148 */ { { 1.9999999999998694e+00, zero }, { -2.185039863262273e+00, zero }, 2.5e-01 }, @@ -1168,13 +1157,9 @@ struct Tan { /* 157 */ { { -1.9999999999998694e+00, min }, { 2.185039863262273e+00, 1.2848464717505794e-307 }, 2.5e-01 }, /* 158 */ { { -4.4647944971961829e-103, min }, { -4.4647944971961829e-103, min }, 2.2397447421778042e+102 }, /* 159 */ { { -9.9671949510973086e-206, min }, { -9.9671949510973086e-206, min }, 1.0032913020226237e+205 }, -#ifndef __aarch64__ // Seems that denormalized values are being flushed to zero on arm (see b/342448599) /* 160 */ { { -min, min }, { -min, min }, 2.2471164185778949e+307 }, -#endif /* 161 */ { { zero, min }, { zero, min }, 2.2471164185778949e+307 }, -#ifndef __aarch64__ // Seems that denormalized values are being flushed to zero on arm (see b/342448599) /* 162 */ { { min, min }, { min, min }, 2.2471164185778949e+307 }, -#endif /* 163 */ { { 9.9671949510973086e-206, min }, { 9.9671949510973086e-206, min }, 1.0032913020226237e+205 }, /* 164 */ { { 4.4647944971961829e-103, min }, { 4.4647944971961829e-103, min }, 2.2397447421778042e+102 }, /* 165 */ { { 1.9999999999998694e+00, min }, { -2.185039863262273e+00, 1.2848464717505794e-307 }, 2.5e-01 }, @@ -1300,8 +1285,7 @@ struct Tan { /* 285 */ { { 8.9589789687104559e+102, inf }, { zero, 1.e+00 }, 5.e-01 }, /* 286 */ { { 4.0131652080900752e+205, inf }, { zero, 1.e+00 }, 5.e-01 }, /* 287 */ { { max, inf }, { zero, 1.e+00 }, 5.e-01 }, - /* 288 */ { { inf, inf }, { zero, 1.e+00 }, 5.e-01 }, - // clang-format on + /* 288 */ { { inf, inf }, { zero, 1.e+00 }, 5.e-01 } // clang-format on }; return table; } else { @@ -1323,7 +1307,11 @@ struct Asin { const T pi_2 = 1.5707964f; const T zero = 0.0f; const T inf = std::numeric_limits::infinity(); +#ifdef __aarch64__ + const T min = std::nextafter(std::numeric_limits::min(), T(1)); +#else const T min = std::numeric_limits::min(); +#endif const T max = std::numeric_limits::max(); const TableType table{ // clang-format off @@ -1627,7 +1615,11 @@ struct Asin { const T pi_2 = 1.5707963267948966; const T zero = 0.0; const T inf = std::numeric_limits::infinity(); +#ifdef __aarch64__ + const T min = std::nextafter(std::numeric_limits::min(), T(1)); +#else const T min = std::numeric_limits::min(); +#endif const T max = std::numeric_limits::max(); const TableType table{ // clang-format off @@ -1945,7 +1937,11 @@ struct Asinh { const T pi_2 = 1.5707964f; const T zero = 0.0f; const T inf = std::numeric_limits::infinity(); +#ifdef __aarch64__ + const T min = std::nextafter(std::numeric_limits::min(), T(1)); +#else const T min = std::numeric_limits::min(); +#endif const T max = std::numeric_limits::max(); const TableType table{ // clang-format off @@ -2249,7 +2245,11 @@ struct Asinh { const T pi_2 = 1.5707963267948966; const T zero = 0.0; const T inf = std::numeric_limits::infinity(); +#ifdef __aarch64__ + const T min = std::nextafter(std::numeric_limits::min(), T(1)); +#else const T min = std::numeric_limits::min(); +#endif const T max = std::numeric_limits::max(); const TableType table{ // clang-format off diff --git a/third_party/xla/xla/tests/generate_complex_unary_op_samples.py 
b/third_party/xla/xla/tests/generate_complex_unary_op_samples.py index 08be6ea42f1ed7..6dc472e6bba84d 100644 --- a/third_party/xla/xla/tests/generate_complex_unary_op_samples.py +++ b/third_party/xla/xla/tests/generate_complex_unary_op_samples.py @@ -21,6 +21,7 @@ """ import os +import platform import re import sys import jax._src.test_util as jtu @@ -36,6 +37,16 @@ def disable(op, real, imag): def main(): + machine = platform.machine() + is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm') + if is_arm_cpu and platform.system() == 'Darwin': + # jtu.complex_plane_sample on Darwin ARM generates samples that + # are specific to the given platform (tiny is mapped to + # nextafter(tiny, inf) to avoid unexpected result when DAZ is + # enabled). Here we handle the Mac specific DAZ difference at C++ + # level (see the __aarch64__-dependent min value mapping below). + sys.stdout.write("Don't run this script under Darwin ARM\n") + return target = (sys.argv[1] if len(sys.argv) > 1 else 'xla').lower() assert target in {'xla', 'tensorflow'}, target header_file_define = dict( @@ -190,8 +201,17 @@ def tostr(v): max='std::numeric_limits::max()', ).items(): if name in used_constants: - constants.append(f'const T {name} = {value};') - constants = '\n '.join(constants) + if name == 'min': + constants.append('#ifdef __aarch64__') + constants.append(f'const T {name} = std::nextafter({value}, T(1));') + constants.append('#else') + constants.append(f'const T {name} = {value};') + constants.append('#endif') + else: + constants.append(f'const T {name} = {value};') + nl = '\n ' + constants = nl.join(constants) + constants = constants.replace(nl + '#', '\n#') ifblocks.append(f"""\ if constexpr (std::is_same_v) {{ From 85b53a47bb5979e89d52411fdb2ae6c834657c72 Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Fri, 11 Apr 2025 02:30:26 -0700 Subject: [PATCH 0549/1324] [XLA:GPU] Do not Re-Replace HLO ops to avoid inconsistencies in the HLO Module. This is one step towards fixing `--xla_gpu_triton_gemm_disable_reduced_precision_reduction` and make it pass all tests. PiperOrigin-RevId: 746376120 --- .../xla/service/gpu/split_k_gemm_rewriter.cc | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc index 646fbc2f2c0f98..2438410d776fdf 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc @@ -425,21 +425,16 @@ absl::Status MakeDotSplitKBatch(HloInstruction* dot_fusion, // The output of the reduce has to have the layout of the original dot. *reduce->mutable_shape()->mutable_layout() = output_layout; + HloInstruction* split_k_root = disable_reduced_precision_reduction + ? 
MakeConvertToHlo(reduce, output_type) + : reduce; + if (dot_fusion->IsRoot()) { - dot_fusion->parent()->set_root_instruction(reduce, + dot_fusion->parent()->set_root_instruction(split_k_root, /*accept_different_shape=*/true); } else { - TF_RETURN_IF_ERROR(dot_fusion->ReplaceAllUsesWithDifferentShape(reduce)); - } - - if (disable_reduced_precision_reduction) { - HloInstruction* convert = MakeConvertToHlo(reduce, output_type); - if (reduce->IsRoot()) { - reduce->parent()->set_root_instruction(convert, - /*accept_different_shape=*/true); - } else { - TF_RETURN_IF_ERROR(reduce->ReplaceAllUsesWithDifferentShape(convert)); - } + TF_RETURN_IF_ERROR( + dot_fusion->ReplaceAllUsesWithDifferentShape(split_k_root)); } return absl::OkStatus(); From 0b585b91831bc12bb7a1a05d32818f7e4b2adb37 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 11 Apr 2025 02:42:13 -0700 Subject: [PATCH 0550/1324] [XLA:GPU][Emitters] Extract the utility to create 0-splat constants to utils.h. PiperOrigin-RevId: 746379575 --- .../xla/backends/gpu/codegen/emitters/BUILD | 1 + .../backends/gpu/codegen/emitters/scatter.cc | 20 +------- .../tests/scatter/sorted_indices_large.hlo | 27 +++++++++++ third_party/xla/xla/codegen/emitters/BUILD | 13 +++++ third_party/xla/xla/codegen/emitters/utils.cc | 48 +++++++++++++++++++ third_party/xla/xla/codegen/emitters/utils.h | 29 +++++++++++ 6 files changed, 120 insertions(+), 18 deletions(-) create mode 100644 third_party/xla/xla/backends/gpu/codegen/emitters/tests/scatter/sorted_indices_large.hlo create mode 100644 third_party/xla/xla/codegen/emitters/utils.cc create mode 100644 third_party/xla/xla/codegen/emitters/utils.h diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD index b0e0d66fcdaa5c..244e5cafed1565 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD @@ -355,6 +355,7 @@ cc_library( "//xla/codegen/emitters:computation_partitioner", "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/codegen/emitters:type_util", + "//xla/codegen/emitters:utils", "//xla/codegen/emitters/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/scatter.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/scatter.cc index 7457e35e8ea999..93ee0e457f24a1 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/scatter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/scatter.cc @@ -52,6 +52,7 @@ limitations under the License. 
#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include "xla/codegen/emitters/ir/xla_ops.h" #include "xla/codegen/emitters/type_util.h" +#include "xla/codegen/emitters/utils.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -636,23 +637,6 @@ void ScatterWithDistributedIndices::ComputeIndexing( } } -DenseElementsAttr GetShapedZeroConstantAttr(VectorType vector_type) { - auto elem_type = vector_type.getElementType(); - if (auto float_type = mlir::dyn_cast(elem_type)) { - std::vector values( - vector_type.getNumElements(), - APFloat::getZero(float_type.getFloatSemantics())); - return DenseElementsAttr::get(vector_type, values); - } - if (auto int_type = mlir::dyn_cast(elem_type)) { - std::vector values( - vector_type.getNumElements(), - APInt::getZero(int_type.getIntOrFloatBitWidth())); - return DenseElementsAttr::get(vector_type, values); - } - llvm_unreachable("Unsupported vector element type"); -} - Value ScatterWithDistributedIndices::InitializeAccumulator( ImplicitLocOpBuilder& b) const { auto elem_type = emitters::PrimitiveTypeToMlirType(description_.elem_type, b); @@ -662,7 +646,7 @@ Value ScatterWithDistributedIndices::InitializeAccumulator( auto accumulator_type = VectorType::get({update_iterations_per_thread, vector_size_}, elem_type); return b.create( - accumulator_type, GetShapedZeroConstantAttr(accumulator_type)); + accumulator_type, emitters::GetZeroDenseElementsAttr(accumulator_type)); } absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl( diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/tests/scatter/sorted_indices_large.hlo b/third_party/xla/xla/backends/gpu/codegen/emitters/tests/scatter/sorted_indices_large.hlo new file mode 100644 index 00000000000000..189d82e1432760 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/tests/scatter/sorted_indices_large.hlo @@ -0,0 +1,27 @@ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s + +mul { + %p0 = s32[] parameter(0) + %p1 = s32[] parameter(1) + ROOT %prod = s32[] multiply(%p0, %p1) +} +scatter { + %operand = s32[100] parameter(0) + %indices = s32[2008,1] parameter(1) + %update = s32[2008,64] parameter(2) + + ROOT %scatter = s32[100] scatter( + s32[100] %operand, + s32[2008,1] %indices, + s32[2008,64] %update + ), + update_window_dims={1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + indices_are_sorted=true, + unique_indices=false, + to_apply=mul +} +// CHECK: arith.constant dense<0> : vector<2xi32> diff --git a/third_party/xla/xla/codegen/emitters/BUILD b/third_party/xla/xla/codegen/emitters/BUILD index 42bfde560a968f..027a21c82a5944 100644 --- a/third_party/xla/xla/codegen/emitters/BUILD +++ b/third_party/xla/xla/codegen/emitters/BUILD @@ -57,6 +57,19 @@ xla_cc_test( ], ) +cc_library( + name = "utils", + srcs = ["utils.cc"], + hdrs = ["utils.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "elemental_hlo_to_mlir", srcs = ["elemental_hlo_to_mlir.cc"], diff --git a/third_party/xla/xla/codegen/emitters/utils.cc b/third_party/xla/xla/codegen/emitters/utils.cc new file mode 100644 index 00000000000000..ed966925931dd8 --- /dev/null +++ b/third_party/xla/xla/codegen/emitters/utils.cc @@ -0,0 +1,48 @@ +/* Copyright 2025 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/codegen/emitters/utils.h" + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" + +namespace xla::emitters { + +using mlir::DenseElementsAttr; +using mlir::ShapedType; + +DenseElementsAttr GetZeroDenseElementsAttr(ShapedType shaped_type) { + auto elem_type = shaped_type.getElementType(); + if (auto float_type = mlir::dyn_cast(elem_type)) { + mlir::SmallVector values( + shaped_type.getNumElements(), + mlir::APFloat::getZero(float_type.getFloatSemantics())); + return DenseElementsAttr::get(shaped_type, values); + } + if (auto int_type = mlir::dyn_cast(elem_type)) { + mlir::SmallVector values( + shaped_type.getNumElements(), + mlir::APInt::getZero(int_type.getIntOrFloatBitWidth())); + return DenseElementsAttr::get(shaped_type, values); + } + llvm_unreachable("Unsupported element type"); +} + +} // namespace xla::emitters diff --git a/third_party/xla/xla/codegen/emitters/utils.h b/third_party/xla/xla/codegen/emitters/utils.h new file mode 100644 index 00000000000000..321fadf53f13b3 --- /dev/null +++ b/third_party/xla/xla/codegen/emitters/utils.h @@ -0,0 +1,29 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_CODEGEN_EMITTERS_UTILS_H_ +#define XLA_CODEGEN_EMITTERS_UTILS_H_ + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "xla/xla_data.pb.h" + +namespace xla::emitters { + +mlir::DenseElementsAttr GetZeroDenseElementsAttr(mlir::ShapedType shaped_type); + +} // namespace xla::emitters + +#endif // XLA_CODEGEN_EMITTERS_UTILS_H_ From b83ce1e407b73c4f52a508b681dbe4848c22d67a Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Fri, 11 Apr 2025 02:57:31 -0700 Subject: [PATCH 0551/1324] [XLA:GPU] NOOP change. Refactor split parts in triton matmul emitters to use named constants for the better readability. 
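For readers unfamiliar with the scheme: each f32 operand is decomposed into bf16-representable high/mid/low terms whose sum reconstructs the original value, so a few bf16 dots can emulate one f32 dot. A rough standalone sketch of the decomposition that kHigh/kMid/kLow index into (a truncation-based illustration, not the exact emitter logic; the real split also rounds the residual term to bf16):

  #include <cstdint>
  #include <cstdio>
  #include <cstring>

  // Keep the top 16 bits of an f32, i.e. the part exactly representable in bf16.
  float TruncateToBf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFF0000u;  // sign + exponent + top 7 mantissa bits
    std::memcpy(&x, &bits, sizeof(x));
    return x;
  }

  int main() {
    const float value = 0.1f;
    const float high = TruncateToBf16(value);        // the kHigh term
    const float mid = TruncateToBf16(value - high);  // the kMid term
    const float low = value - high - mid;            // the kLow residual
    std::printf("value        = %.9g\n", value);
    std::printf("high+mid+low = %.9g\n", high + mid + low);
    return 0;
  }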
PiperOrigin-RevId: 746383471 --- .../gpu/codegen/triton/dot_algorithms.cc | 77 +++++++++++-------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc index a45b6854372297..0404af051b94db 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc @@ -160,35 +160,35 @@ Value IEEEDot(EmitterLocOpBuilder b, Value lhs, Value rhs, Value acc) { absl::StatusOr EmitBF16x9Matmul(EmitterLocOpBuilder& b, const DotOperands& dot_operands, const PrecisionSpec& precision_spec) { + constexpr int kNumParts = 3; + constexpr int kHigh = 0; + constexpr int kMid = 1; + constexpr int kLow = 2; + Type f32 = b.getF32Type(); TF_RETURN_IF_ERROR(ExpectType(dot_operands.lhs, f32)); TF_RETURN_IF_ERROR(ExpectType(dot_operands.rhs, f32)); TF_RETURN_IF_ERROR(ExpectType(dot_operands.accumulator, f32)); - std::vector lhs_parts = SplitF32(b, dot_operands.lhs, 3); - std::vector rhs_parts = SplitF32(b, dot_operands.rhs, 3); + std::vector lhs_parts = SplitF32(b, dot_operands.lhs, kNumParts); + std::vector rhs_parts = SplitF32(b, dot_operands.rhs, kNumParts); - Value local_acc = triton::ZerosLike(b, dot_operands.accumulator); - Value result; + Value result = triton::ZerosLike(b, dot_operands.accumulator); - // low @ low + low @ mid + mid @ low - result = IEEEDot(b, lhs_parts[2], rhs_parts[2], local_acc); - result = IEEEDot(b, lhs_parts[1], rhs_parts[2], result); - result = IEEEDot(b, lhs_parts[2], rhs_parts[1], result); + result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kLow], result); + result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kLow], result); + result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kMid], result); - // mid @ mid - result = IEEEDot(b, lhs_parts[1], rhs_parts[1], result); + result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kMid], result); - // high @ low + low @ high - result = IEEEDot(b, lhs_parts[2], rhs_parts[0], result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[2], result); + result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kHigh], result); + result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kLow], result); - // high @ mid + mid @ high - result = IEEEDot(b, lhs_parts[1], rhs_parts[0], result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[1], result); + result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kHigh], result); + result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kMid], result); result = ZeroNaNs(b, result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[0], result); + result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kHigh], result); result = b.create(dot_operands.accumulator, result); return result; } @@ -198,26 +198,31 @@ absl::StatusOr EmitBF16x9Matmul(EmitterLocOpBuilder& b, absl::StatusOr EmitBF16x6Matmul(EmitterLocOpBuilder& b, const DotOperands& dot_operands, const PrecisionSpec& precision_spec) { + constexpr int kNumParts = 3; + constexpr int kHigh = 0; + constexpr int kMid = 1; + constexpr int kLow = 2; + Type f32 = b.getF32Type(); TF_RETURN_IF_ERROR(ExpectType(dot_operands.lhs, f32)); TF_RETURN_IF_ERROR(ExpectType(dot_operands.rhs, f32)); TF_RETURN_IF_ERROR(ExpectType(dot_operands.accumulator, f32)); - std::vector lhs_parts = SplitF32(b, dot_operands.lhs, 3); - std::vector rhs_parts = SplitF32(b, dot_operands.rhs, 3); + std::vector lhs_parts = SplitF32(b, dot_operands.lhs, kNumParts); + std::vector rhs_parts = SplitF32(b, dot_operands.rhs, kNumParts); + + Value result = 
triton::ZerosLike(b, dot_operands.accumulator); - Value local_acc = triton::ZerosLike(b, dot_operands.accumulator); - Value result = IEEEDot(b, lhs_parts[1], rhs_parts[1], local_acc); - // high @ low + low @ high - result = IEEEDot(b, lhs_parts[2], rhs_parts[0], result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[2], result); + result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kMid], result); - // high @ mid + mid @ high - result = IEEEDot(b, lhs_parts[1], rhs_parts[0], result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[1], result); + result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kHigh], result); + result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kLow], result); + + result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kHigh], result); + result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kMid], result); result = ZeroNaNs(b, result); - result = IEEEDot(b, lhs_parts[0], rhs_parts[0], result); + result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kHigh], result); result = b.create(dot_operands.accumulator, result); return result; } @@ -227,19 +232,23 @@ absl::StatusOr EmitBF16x6Matmul(EmitterLocOpBuilder& b, absl::StatusOr EmitBF16x3Matmul(EmitterLocOpBuilder& b, const DotOperands& dot_operands, const PrecisionSpec& precision_spec) { + constexpr int kNumParts = 2; + constexpr int kHigh = 0; + constexpr int kLow = 1; + Type f32 = b.getF32Type(); TF_RETURN_IF_ERROR(ExpectType(dot_operands.lhs, f32)); TF_RETURN_IF_ERROR(ExpectType(dot_operands.rhs, f32)); TF_RETURN_IF_ERROR(ExpectType(dot_operands.accumulator, f32)); - std::vector lhs_bf16 = SplitF32(b, dot_operands.lhs, 2); - std::vector rhs_bf16 = SplitF32(b, dot_operands.rhs, 2); + std::vector lhs_bf16 = SplitF32(b, dot_operands.lhs, kNumParts); + std::vector rhs_bf16 = SplitF32(b, dot_operands.rhs, kNumParts); - Value local_acc = triton::ZerosLike(b, dot_operands.accumulator); - Value result = IEEEDot(b, lhs_bf16[1], rhs_bf16[0], local_acc); - result = IEEEDot(b, lhs_bf16[0], rhs_bf16[1], result); + Value result = triton::ZerosLike(b, dot_operands.accumulator); + result = IEEEDot(b, lhs_bf16[kLow], rhs_bf16[kHigh], result); + result = IEEEDot(b, lhs_bf16[kHigh], rhs_bf16[kLow], result); result = ZeroNaNs(b, result); - result = IEEEDot(b, lhs_bf16[0], rhs_bf16[0], result); + result = IEEEDot(b, lhs_bf16[kHigh], rhs_bf16[kHigh], result); result = b.create(dot_operands.accumulator, result); return result; } From 3c3bc94f3291741f565b5c48b3740ba24188511a Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 11 Apr 2025 03:10:10 -0700 Subject: [PATCH 0552/1324] PR #24958: [ROCM] move rocm ci command into xla repo Imported from GitHub PR https://github.com/openxla/xla/pull/24958 Move rocm CI command into the xla directory. This is a preparation step before implementing asan build for rocm. 
Copybara import of the project: -- bd5a73bd96a0e4f9d5d65926cdf3f4d7206331e4 by alekstheod : Move ci command to xla repo -- a5cbb02fa751d8ad6175bcb79c6da44b28f8fdc0 by alekstheod : Address review comments -- dbd0b6b1656bed67a43dc552a608a7db9184b45a by alekstheod : Move ci build script to build_tools dir Merging this change closes #24958 PiperOrigin-RevId: 746386978 --- .../xla/build_tools/rocm/run_xla_ci_build.sh | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100755 third_party/xla/build_tools/rocm/run_xla_ci_build.sh diff --git a/third_party/xla/build_tools/rocm/run_xla_ci_build.sh b/third_party/xla/build_tools/rocm/run_xla_ci_build.sh new file mode 100755 index 00000000000000..14a5b28dc60910 --- /dev/null +++ b/third_party/xla/build_tools/rocm/run_xla_ci_build.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +CONFIG=$1 +DISK_CACHE_PATH=$2 + +bazel --bazelrc=/usertools/rocm.bazelrc test \ + --config=${CONFIG} \ + --config=xla_cpp \ + --disk_cache=${DISK_CACHE_PATH} \ + --test_tag_filters=gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-no_rocm,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only \ + --build_tag_filters=gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-no_rocm,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only \ + --profile=/tf/pkg/profile.json.gz \ + --keep_going \ + --test_env=TF_TESTS_PER_GPU=1 \ + --test_env=TF_GPU_COUNT=2 \ + --action_env=XLA_FLAGS=--xla_gpu_force_compilation_parallelism=16 \ + --action_env=XLA_FLAGS=--xla_gpu_enable_llvm_module_compilation_parallelism=true \ + --test_output=errors \ + --local_test_jobs=2 \ + --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute From 0fc2bfbb9bb865f4b20c3f241befa47ba131296f Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 11 Apr 2025 03:47:30 -0700 Subject: [PATCH 0553/1324] PR #24971: [ROCM] Ci fix mem leak in rocm dnn Imported from GitHub PR https://github.com/openxla/xla/pull/24971 Fix memory leak detected by asan. Early return on a wrong condition in destructor. 
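The bug class, reduced to a standalone sketch (hypothetical ScopedHandle/Destroy names, not the MIOpen wrapper API): with the guard inverted to `!=`, the destructor returns exactly when there is a live handle, so the cleanup call below it never runs and the resource leaks.

  struct ScopedHandle {
    void* handle_ = nullptr;

    ~ScopedHandle() {
      // The buggy guard was `if (handle_ != nullptr) return;`, which bails
      // out precisely when there is something to destroy.
      if (handle_ == nullptr) return;  // correct: skip cleanup only when empty
      Destroy(handle_);
    }

    static void Destroy(void* /*handle*/) { /* release the underlying object */ }
  };

  int main() { ScopedHandle h; return 0; }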
Copybara import of the project:

-- 7cb0126327384d280be5ed3e4b4bd0ed95fc25f2 by alekstheod :

Fix memory leak in rocm_dnn

Merging this change closes #24971

PiperOrigin-RevId: 746396540
---
 third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc
index a8312e5104ebc2..a12083b4821e93 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc
@@ -872,7 +872,7 @@ struct ScopedDescriptor {
   }
 
   ~ScopedDescriptor() {
-    if (handle_ != nullptr) return;
+    if (handle_ == nullptr) return;
 
     auto status = miDestroyObject(
         handle_);  // wrap::miopenDestroyTensorDescriptor(handle_);

From d3e4b324d423ac0db8d7fa4554d6de561b1be839 Mon Sep 17 00:00:00 2001
From: Greg Olechwierowicz
Date: Fri, 11 Apr 2025 04:00:45 -0700
Subject: [PATCH 0554/1324] [XLA:GPU] Embed single-host interpolation tables by
 default.

Also:
0. Read these tables by default if input profiles are not specified.
1. Make InterpolatorBase a proper abstract class.
2. Simplify the !IsHopper logic present in the previous version.

PiperOrigin-RevId: 746399857
---
 third_party/xla/xla/service/gpu/model/BUILD   |   12 +-
 .../gpu/model/collective_interpolator.cc      |   96 +-
 .../gpu/model/collective_interpolator.h       |   17 +-
 .../gpu/model/collective_interpolator_data.h  | 3022 +++++++++++++++++
 .../gpu/model/collective_interpolator_test.cc |   37 +-
 ...collective_ptable_stats_collection_test.cc |    4 +-
 .../xla/xla/service/gpu/model/interpolator.h  |   22 +-
 .../service/gpu/model/interpolator_test.cc    |    5 +
 .../collectives/collective_ops_utils.cc       |    9 +-
 9 files changed, 3179 insertions(+), 45 deletions(-)
 create mode 100644 third_party/xla/xla/service/gpu/model/collective_interpolator_data.h

diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD
index 3c1f29dd528412..f4897d2bbdfb6f 100644
--- a/third_party/xla/xla/service/gpu/model/BUILD
+++ b/third_party/xla/xla/service/gpu/model/BUILD
@@ -1066,13 +1066,18 @@ cc_test(
 cc_library(
     name = "collective_interpolator",
     srcs = ["collective_interpolator.cc"],
-    hdrs = ["collective_interpolator.h"],
+    hdrs = [
+        "collective_interpolator.h",
+        "collective_interpolator_data.h",
+    ],
     deps = [
         ":gpu_hlo_cost_analysis",
         ":hlo_op_profile_proto_cc",
+        ":hlo_op_profiles",
         ":interpolator",
         "//xla:shape_util",
         "//xla/hlo/ir:hlo",
+        "//xla/service:collective_ops_utils",
         "//xla/service:hlo_module_config",
         "//xla/service:hlo_proto_cc",
         "//xla/service/gpu/transforms/collectives:collective_ops_utils",
@@ -1098,12 +1103,16 @@ xla_cc_test(
         "//xla:shape_util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
+        "//xla/hlo/parser:hlo_parser",
         "//xla/service:hlo_proto_cc",
         "//xla/service/gpu:gpu_device_info_for_tests",
         "//xla/service/gpu/transforms/collectives:collective_ops_utils",
         "//xla/stream_executor:device_description",
+        "//xla/stream_executor/cuda:cuda_compute_capability",
+        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/time",
         "@com_google_googletest//:gtest_main",
     ],
@@ -1151,6 +1160,7 @@ xla_cc_test(
         "//xla/service:hlo_proto_cc",
         "//xla/service/gpu:gpu_device_info_for_tests",
         "//xla/stream_executor:device_description",
+        "//xla/stream_executor/cuda:cuda_compute_capability",
         "//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", diff --git a/third_party/xla/xla/service/gpu/model/collective_interpolator.cc b/third_party/xla/xla/service/gpu/model/collective_interpolator.cc index 0e9f23b5690fae..1ac08e7cbb6574 100644 --- a/third_party/xla/xla/service/gpu/model/collective_interpolator.cc +++ b/third_party/xla/xla/service/gpu/model/collective_interpolator.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include #include +#include +#include #include #include "absl/container/flat_hash_map.h" @@ -34,8 +36,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/model/collective_interpolator_data.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/hlo_op_profile.pb.h" +#include "xla/service/gpu/model/hlo_op_profiles.h" #include "xla/service/gpu/model/interpolator.h" #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" #include "xla/service/hlo.pb.h" @@ -53,10 +58,12 @@ namespace { struct InterpolationSpecification { HloOpcode opcode; GPUCommunicationType comm; - int num_devices; - int transfer_size; + int64_t num_devices; + int64_t transfer_size; }; +// Returns number of participating devices in an input `device_list`. Supports +// only `iota_replica_group_list`. absl::StatusOr GetNumParticipatingDevices( const CollectiveDeviceList& device_list) { auto iota = device_list.iota_replica_group_list(); @@ -67,11 +74,6 @@ absl::StatusOr GetNumParticipatingDevices( return iota->num_devices_per_group(); } -absl::StatusOr GetNumParticipatingDevices( - const HloCollectiveInstruction& instr) { - return GetNumParticipatingDevices(instr.device_list()); -} - absl::StatusOr Spec( const HloInstructionProfile& profile, const se::DeviceDescription& device_info) { @@ -85,12 +87,13 @@ absl::StatusOr Spec( GpuHloCostAnalysis analysis(GpuHloCostAnalysis::Options(), device_info); TF_RETURN_IF_ERROR(collective->Accept(&analysis)); - int bytes_transferred = analysis.BytesTransferred(*collective); + int64_t bytes_transferred = analysis.BytesTransferred(*collective); TF_ASSIGN_OR_RETURN( auto comm, CommunicationType(*collective, device_info.gpu_compute_capability())); - TF_ASSIGN_OR_RETURN(int num_devices, GetNumParticipatingDevices(*collective)); + TF_ASSIGN_OR_RETURN(int num_devices, + GetNumParticipatingDevices(collective->device_list())); return InterpolationSpecification{ /*opcode=*/collective->opcode(), @@ -247,12 +250,62 @@ HloOpcode AsyncToSyncOpcode(const HloCollectiveInstruction& instr) { return opcode; } +absl::StatusOr ReadDefaultProfiles( + const se::DeviceDescription& device_info) { + DeviceHloInstructionProfiles profile; + + if (!tsl::protobuf::TextFormat::ParseFromString(kDefaultCollectivePTable, + &profile)) { + return absl::FailedPreconditionError("Cannot parse a default profile."); + } + std::string key = HloOpProfiles::GetProfileName(device_info); + + if (!profile.entries().contains(key)) { + return absl::NotFoundError(absl::StrCat("Cannot find key: ", key)); + } + return profile.entries().at(key); +} + } // namespace +/*static*/ absl::StatusOr> +CollectiveInterpolator::Create(const se::DeviceDescription& device_info) { + auto interpolators = std::make_unique>>>(); + + TF_ASSIGN_OR_RETURN(HloInstructionProfileList profiles, + ReadDefaultProfiles(device_info)); + for (auto& profile : profiles.entries()) { + 
TF_ASSIGN_OR_RETURN(InterpolationSpecification spec, + Spec(profile, device_info)); + CollectiveInterpolator::InterpolatorKey key{ + /*opcode=*/spec.opcode, + /*communication_type=*/spec.comm, + }; + auto it = interpolators->find(key); + if (it == interpolators->end()) { + auto interpolator = + std::make_unique>( + /*next_context=*/std::array{-1, -1}, + /*next_power_context=*/std::array{1, 1}, + /*max_context=*/std::array{1 << 30, 8}, + /*min_context=*/std::array{1 << 10, 8}); + + (*interpolators)[key] = std::move(interpolator); + } + std::array point = {spec.transfer_size, spec.num_devices}; + interpolators->at(key)->Add(point, + profile.network_throughput_bytes_per_sec()); + } + return std::unique_ptr( + new CollectiveInterpolator(std::move(interpolators), device_info)); +} + /*static*/ absl::StatusOr> CollectiveInterpolator::Create(HloInstructionProfileList profiles, const se::DeviceDescription& device_info) { - CollectiveInterpolator::InterpolatorMap interpolators; + auto interpolators = std::make_unique>>>(); for (auto& profile : profiles.entries()) { TF_ASSIGN_OR_RETURN(InterpolationSpecification spec, @@ -261,19 +314,22 @@ CollectiveInterpolator::Create(HloInstructionProfileList profiles, /*opcode=*/spec.opcode, /*communication_type=*/spec.comm, }; - auto it = interpolators.find(key); - if (it == interpolators.end()) { - interpolators[key] = EuclideanNNInterpolator(); + auto it = interpolators->find(key); + if (it == interpolators->end()) { + auto interpolator = + std::make_unique>(); + (*interpolators)[key] = std::move(interpolator); } std::array point = {spec.transfer_size, spec.num_devices}; - interpolators[key].Add(point, profile.network_throughput_bytes_per_sec()); + interpolators->at(key)->Add(point, + profile.network_throughput_bytes_per_sec()); } return std::unique_ptr( - new CollectiveInterpolator(profiles, interpolators, device_info)); + new CollectiveInterpolator(std::move(interpolators), device_info)); } std::optional CollectiveInterpolator::EstimatedRuntime( - HloCollectiveInstruction& instr) { + const HloCollectiveInstruction& instr) const { GpuHloCostAnalysis analysis(GpuHloCostAnalysis::Options(), device_info_); CHECK_OK(instr.Accept(&analysis)); int64_t bytes_transferred = analysis.BytesTransferred(instr); @@ -281,21 +337,21 @@ std::optional CollectiveInterpolator::EstimatedRuntime( if (!comm.ok()) { return std::nullopt; } - auto num_devices = GetNumParticipatingDevices(instr); + auto num_devices = GetReplicaGroupCountAndSize(&instr); if (!num_devices.ok()) { return std::nullopt; } - std::array point({bytes_transferred, *num_devices}); + std::array point({bytes_transferred, (*num_devices)->second}); CollectiveInterpolator::InterpolatorKey key{ /*opcode=*/AsyncToSyncOpcode(instr), /*communication_type=*/*comm, }; - if (!interpolators_.contains(key)) { + if (!interpolators_->contains(key)) { VLOG(1) << "Cannot find key for instr: " << instr.ToString(); return std::nullopt; } return absl::Seconds(1.0 * bytes_transferred / - interpolators_.at(key).Eval(point)); + interpolators_->at(key)->Eval(point)); } /*static*/ std::unique_ptr CollectiveInterpolator::ConstructModule( diff --git a/third_party/xla/xla/service/gpu/model/collective_interpolator.h b/third_party/xla/xla/service/gpu/model/collective_interpolator.h index 948d7c6a32d7b5..4ef0789c1ae7f5 100644 --- a/third_party/xla/xla/service/gpu/model/collective_interpolator.h +++ b/third_party/xla/xla/service/gpu/model/collective_interpolator.h @@ -51,13 +51,16 @@ class CollectiveInterpolator { } }; - using InterpolatorMap 
diff --git a/third_party/xla/xla/service/gpu/model/collective_interpolator.h b/third_party/xla/xla/service/gpu/model/collective_interpolator.h
index 948d7c6a32d7b5..4ef0789c1ae7f5 100644
--- a/third_party/xla/xla/service/gpu/model/collective_interpolator.h
+++ b/third_party/xla/xla/service/gpu/model/collective_interpolator.h
@@ -51,13 +51,16 @@ class CollectiveInterpolator {
     }
   };
 
-  using InterpolatorMap =
-      absl::flat_hash_map<InterpolatorKey, EuclideanNNInterpolator<int64_t, 2>>;
+  using InterpolatorMap = std::unique_ptr<absl::flat_hash_map<
+      InterpolatorKey, std::unique_ptr<InterpolatorBase<int64_t, 2>>>>;
 
   static absl::StatusOr<std::unique_ptr<CollectiveInterpolator>> Create(
       HloInstructionProfileList profiles,
       const se::DeviceDescription& device_info);
 
+  static absl::StatusOr<std::unique_ptr<CollectiveInterpolator>> Create(
+      const se::DeviceDescription& device_info);
+
   // Constructs the semantically correct module from the profile.
   // Usually the root instruction of the entry computation is of interest and
   // is directly related to the `profile`d information.
@@ -66,19 +69,15 @@ class CollectiveInterpolator {
 
   // Returns the estimated runtime for a supported `collective`.
   std::optional<absl::Duration> EstimatedRuntime(
-      HloCollectiveInstruction& instr);
+      const HloCollectiveInstruction& instr) const;
 
  private:
   // Uses `EuclideanNNInterpolator` to get the closest neighbour from the
   // profiles.
-  explicit CollectiveInterpolator(HloInstructionProfileList profiles,
-                                  InterpolatorMap interpolators,
+  explicit CollectiveInterpolator(InterpolatorMap interpolators,
                                   const se::DeviceDescription& device_info)
-      : profiles_(profiles),
-        interpolators_(interpolators),
-        device_info_(device_info) {}
+      : interpolators_(std::move(interpolators)), device_info_(device_info) {}
 
-  HloInstructionProfileList profiles_;
   InterpolatorMap interpolators_;
   const se::DeviceDescription& device_info_;
 
diff --git a/third_party/xla/xla/service/gpu/model/collective_interpolator_data.h b/third_party/xla/xla/service/gpu/model/collective_interpolator_data.h
new file mode 100644
index 00000000000000..3a395b29706f64
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/model/collective_interpolator_data.h
@@ -0,0 +1,3022 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_COLLECTIVE_INTERPOLATOR_DATA_H_ +#define XLA_SERVICE_GPU_MODEL_COLLECTIVE_INTERPOLATOR_DATA_H_ + +// Textproto below is generated via +// +// bazel run --config=cuda -- //xla/tools:collective_perf_table_gen_main +constexpr char kDefaultCollectivePTable[] = R"pb( + entries { + key: "sm_90" + value { + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 256 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 28571428 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 256 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 158415841 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 256 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 33023735 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 256 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 167539267 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 32 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 49535603 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 256 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 164948453 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 512 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + 
iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 47024246 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 512 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 329896907 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 512 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 68965517 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 512 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 328205128 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 64 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 67297581 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 512 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 326530612 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 1024 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 113174182 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 1024 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 646464646 + } + entries { + instruction { + opcode: 
"all-gather" + shape { + element_type: F32 + dimensions: 1024 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 112775330 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 1024 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 684491978 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 128 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 240601503 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 1024 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 656410256 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 2048 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 345479082 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 2048 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1333333333 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 2048 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 266666666 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 2048 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + 
use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1333333333 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 256 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 451499118 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 2048 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1280000000 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 4096 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 413570274 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 4096 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 2680628272 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 4096 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 708160442 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 4096 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 2534653465 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 512 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + 
network_throughput_bytes_per_sec: 666666666 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 4096 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 2652849740 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 8192 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 739350180 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 8192 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 5145728643 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 8192 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1723905723 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 8192 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 5224489795 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 1024 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1479768786 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 8192 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 5305699481 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 16384 + layout { minor_to_major: 0 
tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1472322070 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 16384 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 10502564102 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 16384 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 3424749163 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 16384 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 10395939086 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 2048 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 2314124293 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 16384 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 10343434343 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 32768 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 3444911690 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 32768 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + 
num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 20480000000 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 32768 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 6714754098 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 32768 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 21005128205 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 4096 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 2862334032 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 32768 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 20897959183 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 65536 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 5957818181 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 65536 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 42226804123 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 65536 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 13255663430 + } + entries { + instruction { + opcode: "all-gather" + 
shape { + element_type: F32 + dimensions: 65536 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 42010256410 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 8192 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 8224899598 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 65536 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 40554455445 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 131072 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 18984936268 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 131072 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 81920000000 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 131072 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 20177339901 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 131072 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 81920000000 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 16384 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + 
use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 16532795156 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 131072 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 80313725490 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 262144 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 22443835616 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 262144 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 152409302325 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 262144 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 42666666666 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 262144 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 151703703703 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 32768 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 35121114683 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 262144 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + 
iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 148945454545 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 524288 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 43030860144 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 524288 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 296542986425 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 524288 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 62594078319 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 524288 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 295207207207 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 65536 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 82643127364 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 524288 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 291271111111 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 1048576 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 63906387128 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 1048576 + 
layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 553046413502 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 1048576 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 87323117921 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 1048576 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 537180327868 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 131072 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 116924174843 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 1048576 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 534987755102 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 2097152 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 74451576256 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 2097152 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 865161716171 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 2097152 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list 
{ + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 134432820512 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 2097152 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 826952681388 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 262144 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 153390286717 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 2097152 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 834853503184 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 4194304 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 110353188802 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 4194304 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 981812734082 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 4194304 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 194613214550 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 4194304 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + 
network_throughput_bytes_per_sec: 934559714795 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 524288 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 204560280920 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 4194304 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 948079566003 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 8388608 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 144272977435 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 8388608 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 992969696969 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 8388608 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 232397163120 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 8388608 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1005346116970 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 1048576 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 233900513049 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 8388608 + 
layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1009216554379 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 16777216 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 165927051190 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 16777216 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1149754385964 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 16777216 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 313428784934 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 16777216 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1141617855198 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 2097152 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 311519904931 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 16777216 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1137901247965 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 33554432 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + 
collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 195347398817 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 33554432 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1228200292825 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 33554432 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 338032237266 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 33554432 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1230001173020 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 4194304 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 327500897946 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 33554432 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1227481416447 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 67108864 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 208692606229 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 67108864 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + 
network_throughput_bytes_per_sec: 1278361475160 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 67108864 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 368002105724 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 67108864 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1276222120797 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 8388608 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 347901791639 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 67108864 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1275833916349 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 134217728 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 218058669855 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 134217728 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1287781393920 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 134217728 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 384032229267 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 134217728 
+ layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1287880248714 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 16777216 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 359209009549 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 134217728 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1286991101564 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 268435456 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 222030980976 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 268435456 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1305416744475 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 268435456 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 391036278245 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 268435456 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1303236571251 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 33554432 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + 
use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 364916444628 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 268435456 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1304452513314 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 536870912 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 224423344971 + } + entries { + instruction { + opcode: "all-reduce" + shape { + element_type: F32 + dimensions: 536870912 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1315576326674 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 536870912 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 400124397805 + } + entries { + instruction { + opcode: "all-gather" + shape { + element_type: F32 + dimensions: 536870912 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1316944621060 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 67108864 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 1 + num_devices_per_group: 8 + iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 373842693108 + } + entries { + instruction { + opcode: "reduce-scatter" + shape { + element_type: F32 + dimensions: 536870912 + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + is_dynamic_dimension: false + } + dimensions: 0 + channel_id: 1 + use_global_device_ids: true + collective_device_list { + iota_replica_group_list { + num_replica_groups: 8 + num_devices_per_group: 1 + 
iota_reshape_dims: 8 + iota_transpose_perm: 0 + } + } + } + network_throughput_bytes_per_sec: 1315318476705 + } + } + } +)pb"; + +#endif  // XLA_SERVICE_GPU_MODEL_COLLECTIVE_INTERPOLATOR_DATA_H_
diff --git a/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc b/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc
index 6a533e437bf644..5445a42859405d 100644
--- a/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc
+++ b/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc
@@ -24,19 +24,23 @@ limitations under the License.
 #include
 #include "absl/log/check.h"
 #include "absl/log/log.h"
+#include "absl/strings/string_view.h"
 #include "absl/time/time.h"
 #include "xla/hlo/ir/collective_device_list.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/parser/hlo_parser.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
 #include "xla/service/gpu/model/hlo_op_profile.pb.h"
 #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/xla_data.pb.h"

 namespace xla::gpu {
@@ -76,7 +80,8 @@ class CollectiveInterpolationTest : public TestWithParam {
                                    space_spec.network_througput_bytes);
       *profiles.add_entries() = entry;
     }
-    device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+    device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(
+        stream_executor::CudaComputeCapability::Hopper());
     interpolator_ = *CollectiveInterpolator::Create(profiles, device_info_);
   }

@@ -887,5 +892,35 @@ INSTANTIATE_TEST_SUITE_P(
       return info.param.test_name;
     });

+TEST(CollectiveInterpolatorTest, LoadsDefaultProfile) {
+  auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(
+      stream_executor::CudaComputeCapability::Hopper());
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<CollectiveInterpolator> interpolator,
+                          CollectiveInterpolator::Create(device_info));
+  absl::string_view kHlo = R"(
+    HloModule m, num_partitions=8
+
+    wrapped_add {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT _ = f32[] add(a,b)
+    }
+
+    ENTRY main {
+      p = f32[256] parameter(0)
+      ROOT _ = f32[256] all-reduce(p),
+        to_apply=wrapped_add,
+        replica_groups=[1,8]<=[8],
+        use_global_device_ids=true,
+        channel_id=1
+    }
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
+  HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
+      module->entry_computation()->root_instruction());
+
+  EXPECT_TRUE(interpolator->EstimatedRuntime(*instr).has_value());
+}
+
 }  // namespace
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/model/collective_ptable_stats_collection_test.cc b/third_party/xla/xla/service/gpu/model/collective_ptable_stats_collection_test.cc
index 07a21525b4df1b..2fc2b7f0ab0f1b 100644
--- a/third_party/xla/xla/service/gpu/model/collective_ptable_stats_collection_test.cc
+++ b/third_party/xla/xla/service/gpu/model/collective_ptable_stats_collection_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" @@ -77,7 +78,8 @@ DeviceHloInstructionProfiles TestProfiles( class CollectivePerfTableStatsCollectionTest : public HloTestBase { public: explicit CollectivePerfTableStatsCollectionTest() - : device_info_(TestGpuDeviceInfo::RTXA6000DeviceInfo()), + : device_info_(TestGpuDeviceInfo::RTXA6000DeviceInfo( + se::CudaComputeCapability(9, 0))), profiles_path_(tsl::io::JoinPath(tsl::testing::TmpDir(), kFile)) {} void SetUp() override { diff --git a/third_party/xla/xla/service/gpu/model/interpolator.h b/third_party/xla/xla/service/gpu/model/interpolator.h index 76d3260fe819d1..75abb0eecb5eb7 100644 --- a/third_party/xla/xla/service/gpu/model/interpolator.h +++ b/third_party/xla/xla/service/gpu/model/interpolator.h @@ -39,15 +39,10 @@ class InterpolatorBase { virtual ~InterpolatorBase() = default; // Adds point to the interpolation space. - void Add(std::array& point, R val) { - plane_.emplace_back(point, val); - }; + virtual void Add(std::array& point, R val) = 0; // Returns interpolated value. virtual R Eval(std::array& point) const = 0; - - protected: - std::vector, R>> plane_; }; // `Interpolates` any point in euclidean space by just returning the nearest @@ -59,13 +54,17 @@ class InterpolatorBase { template class EuclideanNNInterpolator : public InterpolatorBase { public: + void Add(std::array& point, R val) override { + plane_.emplace_back(point, val); + }; + R Eval(std::array& point) const override { - CHECK_GT(this->plane_.size(), 0); + CHECK_GT(plane_.size(), 0); R result; uint64_t min_dist = std::numeric_limits::max(); - for (const auto& [plane_point, val] : this->plane_) { + for (const auto& [plane_point, val] : plane_) { int64_t dist = Norm2(plane_point, point); if (dist < min_dist) { result = val; @@ -86,6 +85,8 @@ class EuclideanNNInterpolator : public InterpolatorBase { } return dist; } + + std::vector, R>> plane_; }; template @@ -100,13 +101,12 @@ class EuclideanComplementInterpolator : public EuclideanNNInterpolator { max_ctx_(max_context), min_ctx_(min_context) {} - void Add(std::array& point, R val) { - EuclideanNNInterpolator::Add(point, val); + void Add(std::array& point, R val) override { retrieval_[point] = val; } R Eval(std::array& point) const override { - CHECK_GT(this->plane_.size(), 0); + CHECK_GT(retrieval_.size(), 0); std::array interpolation_point; for (int i = 0; i < point.size(); ++i) { std::optional next_potential_dim; diff --git a/third_party/xla/xla/service/gpu/model/interpolator_test.cc b/third_party/xla/xla/service/gpu/model/interpolator_test.cc index 0dbcb2bce16c02..53fb9511d18558 100644 --- a/third_party/xla/xla/service/gpu/model/interpolator_test.cc +++ b/third_party/xla/xla/service/gpu/model/interpolator_test.cc @@ -52,6 +52,11 @@ class InterpolatorFake : public InterpolatorBase { // Fake eval function which just returns the size of the consumed set. 
int Eval(std::array& x) const override { return plane_.size(); }
+
+  void Add(std::array& x, int val) override { plane_.push_back(x); }
+
+ private:
+  std::vector> plane_;
 };

 TEST(Interpolator, PersistsEuclideanPoints) {
diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc
index 140b20462df012..8022b90cfcfa49 100644
--- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc
+++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc
@@ -15,6 +15,8 @@ limitations under the License.

 #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h"

+#include <variant>
+
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "xla/hlo/ir/collective_device_list.h"
@@ -41,10 +43,13 @@ absl::StatusOr CommunicationType(
     const se::GpuComputeCapability& gpu_version) {
   auto iota = instr.device_list().iota_replica_group_list();

+  if (!std::holds_alternative<se::CudaComputeCapability>(gpu_version)) {
+    return absl::FailedPreconditionError("Only CUDA is supported.");
+  }
+
   auto cuda_compute_capability = std::get<se::CudaComputeCapability>(gpu_version);

-  if (!(cuda_compute_capability.IsAtLeastAmpere() &&
-        !cuda_compute_capability.IsAtLeastBlackwell())) {
+  if (!cuda_compute_capability.IsHopper()) {
     return absl::FailedPreconditionError(
         "Only Hopper is supported to get communication type");
   }

From 55e970a23aebf58e0c9b82a9c65386b044866035 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev
Date: Fri, 11 Apr 2025 04:11:02 -0700
Subject: [PATCH 0555/1324] [XLA:GPU][Emitters] Move TransposeSpec to
 ir_emission_utils.

Also migrate the ir_emission_utils_test to the HloRunnerAgnosticTestBase.

PiperOrigin-RevId: 746403340
---
 .../xla/backends/gpu/codegen/emitters/BUILD   | 14 ---
 .../gpu/codegen/emitters/transpose.cc         | 45 ----------
 .../backends/gpu/codegen/emitters/transpose.h | 29 -------
 .../gpu/codegen/emitters/transpose_test.cc    | 87 ------------------
 third_party/xla/xla/service/gpu/BUILD         | 10 ++-
 .../xla/xla/service/gpu/ir_emission_utils.cc  | 69 ++++++++++++++-
 .../xla/xla/service/gpu/ir_emission_utils.h   | 55 ++++++++++++
 .../xla/service/gpu/ir_emission_utils_test.cc | 61 ++++++++++++-
 8 files changed, 189 insertions(+), 181 deletions(-)
 delete mode 100644 third_party/xla/xla/backends/gpu/codegen/emitters/transpose_test.cc

diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD
index 244e5cafed1565..dcb6158181976e 100644
--- a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD
+++ b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD
@@ -415,17 +415,3 @@ cc_library(
         "@llvm-project//mlir:TensorDialect",
     ],
 )
-
-xla_cc_test(
-    name = "transpose_test",
-    srcs = ["transpose_test.cc"],
-    deps = [
-        ":transpose",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/parser:hlo_parser",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest",
-    ],
-)
diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.cc
index 1fcea8b9571855..7a2df870437d85 100644
--- a/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.cc
@@ -81,51 +81,6 @@ constexpr int kMaxVectorizedBytes = 4;

 }  // namespace

-TransposeSpec GetTransposeSpec(const HloTransposeInstruction* transpose) {
-  auto
inv_permutation = InversePermutation(transpose->dimensions()); - auto& output_shape = transpose->shape(); - SmallVector canonical_output_shape = - llvm::to_vector<3>(output_shape.dimensions()); - SmallVector canonical_permutation = - llvm::to_vector<3>(transpose->dimensions()); - - // If the last dimension is transposed, add a size-1 B dimension. - if (canonical_permutation.back() != canonical_output_shape.size() - 1) { - canonical_permutation.push_back(output_shape.dimensions().size()); - canonical_output_shape.push_back(1); - } - int64_t dim_t1 = -1; - int64_t dim_t2 = -1; - for (int64_t i = canonical_permutation.size() - 1; i >= 0; --i) { - if (canonical_permutation[i] != i) { - dim_t2 = canonical_permutation[i]; - dim_t1 = i; - break; - } - } - // If T1 and T2 are adjacent, insert a size-1 A dimension between them. - if (dim_t1 - dim_t2 == 1) { - canonical_output_shape.insert(canonical_output_shape.begin() + dim_t1, 1); - for (auto& p : canonical_permutation) { - if (p > dim_t2) p++; - } - canonical_permutation.insert(canonical_permutation.begin() + dim_t1, - dim_t1); - } - auto canonical_inv_permutation = InversePermutation(canonical_permutation); - auto canonical_input_shape = - Permute(canonical_output_shape, canonical_inv_permutation); - return TransposeSpec{ - transpose, - llvm::to_vector<3>(transpose->dimensions()), - llvm::to_vector<3>(inv_permutation), - canonical_output_shape, - canonical_permutation, - llvm::to_vector<3>(canonical_inv_permutation), - llvm::to_vector<3>(canonical_input_shape), - }; -} - TransposeFusion::TransposeFusion(const HloFusionAnalysis& analysis) : analysis_(analysis), transpose_(analysis.tiled_transpose()), diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.h b/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.h index 0a18d24f4914bf..f935980f52cd0f 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.h +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transpose.h @@ -44,35 +44,6 @@ limitations under the License. namespace xla { namespace gpu { -// TODO(pifon): Unify this with TransposeDescription. -struct TransposeSpec { - const Shape& input_shape() const { return transpose->operand(0)->shape(); } - const Shape& output_shape() const { return transpose->shape(); } - PrimitiveType elem_type() const { return input_shape().element_type(); } - - const HloTransposeInstruction* transpose; - - llvm::SmallVector permutation; - llvm::SmallVector inv_permutation; - - // Canonical transpose permutates the input shape - // into - // . - // Note that the `D` dimensions are batch dimensions. They can also be - // permuted, but they are tiled by 1. - // - // Examples: - // 1. <8x32> -> <32x8> will be canonicalized to <8x1x32x1> -> <32x1x8x1>. - // 2. <8x2x32> -> <32x2x8> will be canonicalized to <8x2x32x1> -> <32x2x8x1>. - // 3. <8x2x32x7x6> -> <6x32x2x7x8> becomes <8x2x32x7x6x1> -> <6x32x2x7x8x1>. - - llvm::SmallVector canonical_output_shape; - llvm::SmallVector canonical_permutation; - llvm::SmallVector canonical_inv_permutation; - llvm::SmallVector canonical_input_shape; -}; -TransposeSpec GetTransposeSpec(const HloTransposeInstruction* transpose); - // Lowers kTranspose fusion to LLVM via MLIR using GPU's shared memory. 
// Each thread block of `kWarpSize` x `kNumRows` threads diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transpose_test.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transpose_test.cc deleted file mode 100644 index df33fa2b5f3b0c..00000000000000 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transpose_test.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/backends/gpu/codegen/emitters/transpose.h" - -#include -#include -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/tests/hlo_test_base.h" - -namespace xla::gpu { -namespace { - -using ::testing::ElementsAre; - -class TransposeTest : public HloTestBase { - public: - TransposeSpec GetTransposeSpecFromRoot(absl::string_view hlo_text) { - auto module = ParseAndReturnVerifiedModule(hlo_text).value(); - auto* root = module->entry_computation()->root_instruction(); - return GetTransposeSpec(Cast(root)); - } -}; - -TEST_F(TransposeTest, Transpose_10) { - auto spec = GetTransposeSpecFromRoot(R"(ENTRY entry { - p0 = f32[8, 32] parameter(0) - ROOT transpose_p0 = f32[32, 8] transpose(p0), dimensions={1, 0} - })"); - EXPECT_THAT(spec.permutation, ElementsAre(1, 0)); - EXPECT_THAT(spec.inv_permutation, ElementsAre(1, 0)); - EXPECT_THAT(spec.canonical_input_shape, ElementsAre(8, 1, 32, 1)); - EXPECT_THAT(spec.canonical_output_shape, ElementsAre(32, 1, 8, 1)); - EXPECT_THAT(spec.canonical_permutation, ElementsAre(2, 1, 0, 3)); - EXPECT_THAT(spec.canonical_inv_permutation, ElementsAre(2, 1, 0, 3)); -} - -TEST_F(TransposeTest, Transpose_210) { - auto spec = GetTransposeSpecFromRoot(R"(ENTRY entry { - p0 = f32[8, 2, 32] parameter(0) - ROOT transpose_p0 = f32[32, 2, 8] transpose(p0), dimensions={2, 1, 0} - })"); - EXPECT_THAT(spec.canonical_input_shape, ElementsAre(8, 2, 32, 1)); - EXPECT_THAT(spec.canonical_output_shape, ElementsAre(32, 2, 8, 1)); - EXPECT_THAT(spec.canonical_permutation, ElementsAre(2, 1, 0, 3)); - EXPECT_THAT(spec.canonical_inv_permutation, ElementsAre(2, 1, 0, 3)); -} - -TEST_F(TransposeTest, Transpose_102) { - auto spec = GetTransposeSpecFromRoot(R"(ENTRY entry { - p0 = f32[8, 2, 32] parameter(0) - ROOT transpose_p0 = f32[2, 8, 32] transpose(p0), dimensions={1, 0, 2} - })"); - EXPECT_THAT(spec.canonical_input_shape, ElementsAre(8, 1, 2, 32)); - EXPECT_THAT(spec.canonical_output_shape, ElementsAre(2, 1, 8, 32)); - EXPECT_THAT(spec.canonical_permutation, ElementsAre(2, 1, 0, 3)); - EXPECT_THAT(spec.canonical_inv_permutation, ElementsAre(2, 1, 0, 3)); -} - -TEST_F(TransposeTest, Transpose_42130) { - auto spec = GetTransposeSpecFromRoot(R"(ENTRY entry { - p0 = f32[8, 2, 32, 7, 6] parameter(0) - ROOT transpose_p0 = f32[6, 32, 2, 7, 8] transpose(p0), - dimensions={4, 2, 1, 3, 0} - })"); - EXPECT_THAT(spec.canonical_input_shape, ElementsAre(8, 2, 32, 7, 6, 1)); - 
EXPECT_THAT(spec.canonical_output_shape, ElementsAre(6, 32, 2, 7, 8, 1)); - EXPECT_THAT(spec.canonical_permutation, ElementsAre(4, 2, 1, 3, 0, 5)); - EXPECT_THAT(spec.canonical_inv_permutation, ElementsAre(4, 2, 1, 3, 0, 5)); -} - -} // namespace -} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 5dec0050d0be87..8d3345d55dbbb6 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -633,6 +633,7 @@ cc_library( ":backend_configs_cc", ":target_util", "//xla:literal", + "//xla:permutation_util", "//xla:shape_util", "//xla:status_macros", "//xla:util", @@ -645,6 +646,7 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", "//xla/tsl/lib/strings:proto_serialization", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -675,10 +677,14 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", "//xla/service:buffer_assignment", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/service:hlo_runner", + "//xla/service:platform_util", + "//xla/tests:hlo_runner_agnostic_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 918e0e90306ee4..3e222ebdef0801 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -30,12 +30,13 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/escaping.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/FPEnv.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsNVPTX.h" @@ -52,17 +53,18 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_traversal.h" #include "xla/literal.h" +#include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/target_util.h" -#include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" @@ -432,6 +434,69 @@ std::optional GetDescriptionForTiledTransposeEmitter( return std::nullopt; } +TransposeSpec GetTransposeSpec(const HloTransposeInstruction* transpose) { + auto inv_permutation = InversePermutation(transpose->dimensions()); + auto& output_shape = transpose->shape(); + llvm::SmallVector canonical_output_shape = + llvm::to_vector<3>(output_shape.dimensions()); + llvm::SmallVector canonical_permutation = + llvm::to_vector<3>(transpose->dimensions()); + + // If the last dimension is transposed, add a size-1 B dimension. + if (canonical_permutation.back() != canonical_output_shape.size() - 1) { + canonical_permutation.push_back(output_shape.dimensions_size()); + canonical_output_shape.push_back(1); + } + int64_t dim_t1 = -1; + int64_t dim_t2 = -1; + for (int64_t i = canonical_permutation.size() - 1; i >= 0; --i) { + if (canonical_permutation[i] != i) { + dim_t2 = canonical_permutation[i]; + dim_t1 = i; + break; + } + } + // Insert size-1 A dimension if necessary. + auto rank = canonical_output_shape.size(); + if (canonical_permutation[rank - 3] != rank - 3) { + canonical_output_shape.insert(canonical_output_shape.begin() + dim_t1, 1); + for (auto& p : canonical_permutation) { + if (p > rank - 3) p++; + } + canonical_permutation.insert(canonical_permutation.begin() + dim_t1, + dim_t1); + } + auto canonical_inv_permutation = InversePermutation(canonical_permutation); + auto canonical_input_shape = + Permute(canonical_output_shape, canonical_inv_permutation); + return TransposeSpec{ + transpose, + llvm::to_vector<3>(transpose->dimensions()), + llvm::to_vector<3>(inv_permutation), + canonical_output_shape, + canonical_permutation, + llvm::to_vector<3>(canonical_inv_permutation), + llvm::to_vector<3>(canonical_input_shape), + }; +} + +std::string TransposeSpec::ToString() const { + return absl::Substitute(R"( +transpose: $0 +canonical_input_shape: $1 +canonical_output_shape: $2 +canonical_permutation: $3 +canonical_inv_permutation: $4 +[T2, A, T1, B] = [$5, $6, $7, $8] +)", + transpose->ToString(), + absl::StrJoin(canonical_input_shape, ","), + absl::StrJoin(canonical_output_shape, ","), + absl::StrJoin(canonical_permutation, ","), + absl::StrJoin(canonical_inv_permutation, ","), + dim_T2(), dim_A(), dim_T1(), dim_B()); +} + bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count) { // Number of operands should be in range [1, allowed_operand_count]. if (instr->operand_count() == 0 || diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index 40d5024adfe99f..c4857fcf418561 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "xla/hlo/ir/backend_config.h" @@ -219,6 +220,60 @@ struct TransposeDescription { std::optional GetDescriptionForTiledTransposeEmitter( const HloInstruction& hero); +// Canonical transpose permutes the input shape +// into +// . +// Note that the `D` dimensions are batch dimensions. They can also be +// permuted, but they are tiled by 1. +// +// Examples: +// 1. <8x32> -> <32x8> will be canonicalized to <8x1x32x1> -> <32x1x8x1>. +// 2. <8x2x32> -> <32x2x8> will be canonicalized to <8x2x32x1> -> <32x2x8x1>. +// 3. <8x2x32x7x6> -> <6x32x2x7x8> becomes <8x2x32x7x6x1> -> <6x32x2x7x8x1>. + +// TODO(b/370690811): Unify this with TransposeDescription. +struct TransposeSpec { + PrimitiveType elem_type() const { return input_shape().element_type(); } + + const Shape& input_shape() const { return transpose->operand(0)->shape(); } + const Shape& output_shape() const { return transpose->shape(); } + + int64_t rank() const { return input_shape().dimensions_size(); } + int64_t canonical_rank() const { return canonical_input_shape.size(); } + + int64_t dim_A() const { return canonical_input_shape[dim_A_id()]; } + int64_t dim_A_id() const { return canonical_rank() - 3; } + + int64_t dim_B() const { return canonical_input_shape.back(); } + int64_t dim_B_id() const { return canonical_rank() - 1; } + + int64_t dim_T1() const { return canonical_input_shape[dim_T1_input_id()]; } + int64_t dim_T1_input_id() const { return canonical_rank() - 2; } + int64_t dim_T1_output_id() const { + return canonical_inv_permutation[canonical_rank() - 2]; + } + + int64_t dim_T2() const { return canonical_input_shape[dim_T2_input_id()]; } + int64_t dim_T2_input_id() const { + return canonical_permutation[canonical_rank() - 2]; + } + int64_t dim_T2_output_id() const { return canonical_rank() - 2; } + + std::string ToString() const; + + const HloTransposeInstruction* transpose; + + llvm::SmallVector permutation; + llvm::SmallVector inv_permutation; + + llvm::SmallVector canonical_output_shape; + llvm::SmallVector canonical_permutation; + llvm::SmallVector canonical_inv_permutation; + llvm::SmallVector canonical_input_shape; +}; + +TransposeSpec GetTransposeSpec(const HloTransposeInstruction* transpose); + // Checks if the instruction is elementwise. bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count = 1); diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index 1081a865252612..b2399db02db684 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -20,18 +20,26 @@ limitations under the License. 
#include
 #include
+#include
+#include
 #include "absl/container/inlined_vector.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "xla/hlo/ir/backend_config.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/utils/hlo_traversal.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/hlo_runner.h"
+#include "xla/service/platform_util.h"
 #include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_runner_agnostic_test_base.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/types.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
@@ -40,9 +48,22 @@
 namespace xla {
 namespace gpu {

+using ::testing::ElementsAre;
 using ::tsl::testing::IsOkAndHolds;

-using IrEmissionUtilsTest = HloTestBase;
+class IrEmissionUtilsTest : public HloRunnerAgnosticTestBase {
+ public:
+  IrEmissionUtilsTest()
+      : HloRunnerAgnosticTestBase(
+            std::make_unique<HloRunner>(*PlatformUtil::GetDefaultPlatform())) {}
+
+  TransposeSpec GetTransposeSpecFromRoot(absl::string_view hlo_text) {
+    auto module = ParseAndReturnVerifiedModule(hlo_text).value();
+    auto* root = module->entry_computation()->root_instruction();
+    return GetTransposeSpec(Cast<HloTransposeInstruction>(root));
+  }
+};
+
 using InlinedVector = absl::InlinedVector;

 TEST_F(IrEmissionUtilsTest, FindTiledLogicalTranspose) {
@@ -1271,5 +1292,41 @@ TEST_F(IrEmissionUtilsTest, ResolveWhileLoopDependencySideEffect) {
   ASSERT_FALSE(result.has_value());
 }

+TEST_F(IrEmissionUtilsTest, Transpose_10) {
+  auto spec = GetTransposeSpecFromRoot(R"(ENTRY entry {
+      p0 = f32[8, 32] parameter(0)
+      ROOT transpose_p0 = f32[32, 8] transpose(p0), dimensions={1, 0}
+  })");
+  EXPECT_THAT(spec.permutation, ElementsAre(1, 0));
+  EXPECT_THAT(spec.inv_permutation, ElementsAre(1, 0));
+  EXPECT_THAT(spec.canonical_input_shape, ElementsAre(8, 1, 32, 1));
+  EXPECT_THAT(spec.canonical_output_shape, ElementsAre(32, 1, 8, 1));
+  EXPECT_THAT(spec.canonical_permutation, ElementsAre(2, 1, 0, 3));
+  EXPECT_THAT(spec.canonical_inv_permutation, ElementsAre(2, 1, 0, 3));
+}
+
+TEST_F(IrEmissionUtilsTest, Transpose_210) {
+  auto spec = GetTransposeSpecFromRoot(R"(ENTRY entry {
+      p0 = f32[8, 2, 32] parameter(0)
+      ROOT transpose_p0 = f32[32, 2, 8] transpose(p0), dimensions={2, 1, 0}
+  })");
+  EXPECT_THAT(spec.canonical_input_shape, ElementsAre(8, 2, 32, 1));
+  EXPECT_THAT(spec.canonical_output_shape, ElementsAre(32, 2, 8, 1));
+  EXPECT_THAT(spec.canonical_permutation, ElementsAre(2, 1, 0, 3));
+  EXPECT_THAT(spec.canonical_inv_permutation, ElementsAre(2, 1, 0, 3));
+}
+
+TEST_F(IrEmissionUtilsTest, Transpose_42130) {
+  auto spec = GetTransposeSpecFromRoot(R"(ENTRY entry {
+      p0 = f32[8, 2, 32, 7, 6] parameter(0)
+      ROOT transpose_p0 = f32[6, 32, 2, 7, 8] transpose(p0),
+        dimensions={4, 2, 1, 3, 0}
+  })");
+  EXPECT_THAT(spec.canonical_input_shape, ElementsAre(8, 2, 32, 7, 6, 1));
+  EXPECT_THAT(spec.canonical_output_shape, ElementsAre(6, 32, 2, 7, 8, 1));
+  EXPECT_THAT(spec.canonical_permutation, ElementsAre(4, 2, 1, 3, 0, 5));
+  EXPECT_THAT(spec.canonical_inv_permutation, ElementsAre(4, 2, 1, 3, 0, 5));
+}
+
 }  // namespace gpu
 }  // namespace xla

From 2aff0691afee9c88c494e70bea34e6846d54bea8 Mon Sep 17 00:00:00 2001
From: Thomas Joerg
Date: Fri, 11 Apr 2025 04:44:09 -0700
Subject:
[PATCH 0556/1324] [XLA:GPU] Make Split-K rewrites propagate the accumulator
 dtype from the dot to the fusion root.

So far, Split-K rewrites change the type of the fusion root only, which does
not work in the general case. This is another step towards fixing
--xla_gpu_triton_gemm_disable_reduced_precision_reduction and making it pass
all tests.

PiperOrigin-RevId: 746411487
---
 .../xla/service/gpu/split_k_gemm_rewriter.cc  | 53 +++++++++++++------
 .../service/gpu/split_k_gemm_rewriter_test.cc | 14 ++++-
 2 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc
index 2438410d776fdf..a64da1db635ef2 100644
--- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc
+++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc
@@ -127,6 +127,20 @@ absl::StatusOr<HloInstruction*> MakeSparseMetaOperand(
   return MakeBitcastHlo(meta, new_shape);
 }

+PrimitiveType GetAccumulatorType(bool disable_reduced_precision_reduction,
+                                 HloComputation* computation,
+                                 HloInstruction* instr) {
+  if (!disable_reduced_precision_reduction) {
+    return instr->shape().element_type();
+  }
+
+  PrimitiveType output_type =
+      computation->root_instruction()->shape().element_type();
+  PrimitiveType accumulator_type = output_type == PrimitiveType::F64
+                                       ? PrimitiveType::F64
+                                       : PrimitiveType::F32;
+  return accumulator_type;
+}
 }  // namespace

 absl::StatusOr<HloInstruction*> MakeSplitKOperand(
@@ -320,8 +334,11 @@ absl::Status MakeDotComputationSplitKBatch(
       TF_ASSIGN_OR_RETURN(sparse_meta[i], MakeSparseMetaOperand(*dot, config));
     }
+    // Keep the precision of the accumulator type for the dot output.
+    PrimitiveType dot_dtype = GetAccumulatorType(
+        disable_reduced_precision_reduction, computation, dot);
     expanded = MakeDotHlo(lhs, rhs, new_dim_numbers, dot->precision_config(),
-                          dot->shape().element_type(), sparsity, sparse_meta)
+                          dot_dtype, sparsity, sparse_meta)
                    .value();
     // Make the added batch dimension the major-most, keep the order of the
     // original dimensions.
@@ -334,8 +351,13 @@ absl::Status MakeDotComputationSplitKBatch(
     expanded->mutable_shape()->mutable_layout()->add_minor_to_major(0);
     dot->SetupDerivedInstruction(expanded);
   } else {
-    expanded = computation->AddInstruction(current->CloneWithNewShape(
-        ShapeUtil::PrependMajorDimension(config.split_k, current->shape())));
+    // Propagate the precision of the accumulator to the GEMM fusion root.
+    PrimitiveType accumulator_dtype = GetAccumulatorType(
+        disable_reduced_precision_reduction, computation, current);
+    expanded = computation->AddInstruction(
+        current->CloneWithNewShape(ShapeUtil::PrependMajorDimension(
+            config.split_k, ShapeUtil::ChangeElementType(
+                                current->shape(), accumulator_dtype))));
     if (expanded->opcode() == HloOpcode::kTranspose) {
       const auto* old_transpose = Cast<HloTransposeInstruction>(current);
       auto* new_transpose = Cast<HloTransposeInstruction>(expanded);
@@ -358,28 +380,25 @@ absl::Status MakeDotComputationSplitKBatch(
   for (int i = 0; i < expanded->operands().size(); ++i) {
     HloInstruction* operand = expanded->mutable_operand(i);
     if (!to_process_set.contains(operand)) {
+      // Broadcast the operand to the Split-K dimension and convert to the
+      // accumulator dtype.
+ auto accumulator_dtype = GetAccumulatorType( + disable_reduced_precision_reduction, computation, operand); + HloInstruction* convert = MakeConvertToHlo(operand, accumulator_dtype); std::vector broadcast_dimensions( operand->shape().dimensions_size()); absl::c_iota(broadcast_dimensions, 1); TF_RETURN_IF_ERROR(expanded->ReplaceOperandWithDifferentShape( - i, MakeBroadcastHlo(operand, broadcast_dimensions, - ShapeUtil::PrependMajorDimension( - config.split_k, operand->shape())))); + i, + MakeBroadcastHlo(convert, broadcast_dimensions, + ShapeUtil::PrependMajorDimension( + config.split_k, + ShapeUtil::ChangeElementType( + operand->shape(), accumulator_dtype))))); } } } - if (disable_reduced_precision_reduction) { - PrimitiveType output_type = - computation->root_instruction()->shape().element_type(); - PrimitiveType accumulator_type = output_type == PrimitiveType::F64 - ? PrimitiveType::F64 - : PrimitiveType::F32; - - computation->root_instruction()->mutable_shape()->set_element_type( - accumulator_type); - } - if (did_pad) { // Check if the analysis can work on the transformed HLO. // We can fail gracefully here, but not in IrEmitterTriton. diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc index 20809a8863fa4b..5ae3103a50ce22 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -755,8 +755,18 @@ ENTRY e { TritonGemmConfig config(16, 16, 16, 4, 1, 4); TF_EXPECT_OK(MakeDotSplitKBatch( module->entry_computation()->root_instruction(), config)); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Convert(m::Reduce(m::Fusion(), m::Constant())))); + HloInstruction* dot_fusion; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Convert(m::Reduce(m::Fusion(&dot_fusion), m::Constant())))); + EXPECT_THAT( + dot_fusion->fused_instructions_computation()->root_instruction(), + GmockMatch( + m::MultiplyAnyOrder( + m::Broadcast().WithElementType(F32), + m::Convert(m::Dot().WithElementType(F32)).WithElementType(F32)) + .WithElementType(F32))) + << module->ToString(); } TEST_F(SplitKTest, MakeSplitKWithTransposeAfterDot) { From 7f2a0374e565d43a61d090691ab9389d71b25379 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 11 Apr 2025 05:08:52 -0700 Subject: [PATCH 0557/1324] [XLA:GPU] Unskip Triton regression test that was needlessly commented out. The test was tagged as needing to be skipped without actually failing, which is a weird state to be in. PiperOrigin-RevId: 746417770 --- .../fusion_emitter_device_legacy_port_test.cc | 19 +++++++++---------- .../fusion_emitter_device_legacy_test.cc | 4 ---- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index 08be420041d777..76a91d00104480 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -172,22 +172,18 @@ class TritonGemmTestWithSplitK : public TritonGemmTest { } }; +// TODO(b/393299275): requires enabling mixed-type dots for f8xf8->bf16. TEST_F(TritonGemmTest, DISABLED_FP8DotSmallTileDoesNotCrash) { - GTEST_SKIP() << "TODO(b/337839570): Re-enable once the bug is fixed. 
" - "Currently the test is not representative of the issue. " - "While the test passes, the end-to-end model fails."; - if (!GetCudaComputeCapability().IsAtLeastHopper()) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; } constexpr absl::string_view kHloText = R"( -HloModule m - triton_dot { - %parameter_0 = f8e4m3fn[32,32]{1,0} parameter(0) - %parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) - ROOT %dot.1643 = bf16[32,32]{1,0} dot(f8e4m3fn[32,32]{1,0} %parameter_0, f8e4m3fn[32,32]{0,1} %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + p0 = f8e4m3fn[32,32]{1,0} parameter(0) + p1 = f8e4m3fn[32,32]{1,0} parameter(1) + ROOT dot = bf16[32,32]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY e { @@ -199,7 +195,10 @@ ENTRY e { "split_k":1,"num_stages":2,"num_warps":2, "num_ctas":1}}} })"; - EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloText)); + EXPECT_TRUE(Run(std::move(module_and_metadata.module), + /*run_hlo_passes=*/false)); } TEST_F(TritonTest, TestGemmWithTrivialNonContractingDimension) { diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc index 6889723c51f8f6..d329bf0e30aeba 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc @@ -133,10 +133,6 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { }; TEST_F(TritonGemmTest, FP8DotSmallTileDoesNotCrash) { - GTEST_SKIP() << "TODO(b/337839570): Re-enable once the bug is fixed. " - "Currently the test is not representative of the issue. " - "While the test passes, the end-to-end model fails."; - if (!GetCudaComputeCapability().IsAtLeastHopper()) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; } From 8263838584dab995525877c49d120a0154db1f81 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 11 Apr 2025 13:27:57 +0100 Subject: [PATCH 0558/1324] [mlir][tosa] Slice maxpool2d input to tosa expected size (#90985) This commit fixes the maxpool2d legalization such that the generated tosa maxpool2d operator is generated with the expected shape by slicing the input accordingly. This is similar to the approaches of conv2d, conv3d, depthwise_conv2d and avg_pool2d. 
Change-Id: I56cd7730934c6e206b5916cb70fb1d7d455fab93 Signed-off-by: Luke Hutton --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 12 ++++++++++++ .../compiler/mlir/tosa/transforms/legalize_tfl.cc | 12 ++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index fd44d2738beb22..f8f753919e87c6 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -1208,6 +1208,18 @@ func.func @test_max_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32 // ----- +// CHECK-LABEL: test_max_pool2d_slicing +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[1, 31, 31, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x31x31x8xf32> +// CHECK: %[[VAL_4:.*]] = tosa.max_pool2d %[[VAL_3]] {kernel = array, pad = array, stride = array} : (tensor<1x31x31x8xf32>) -> tensor<1x15x15x8xf32> +func.func @test_max_pool2d_slicing(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { + %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: test_reshape // CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[VAR10]] diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index 61a9f54fd2f2e7..f3e6e371f47ca9 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -1346,6 +1346,8 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( DenseI64ArrayAttr kernel_size; DenseI64ArrayAttr stride; DenseI64ArrayAttr pad; + // Pooling has no non-unit dilation + DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr({1, 1}); { int64_t kernel_h = tfl_maxpool_op.getFilterHeight(); int64_t kernel_w = tfl_maxpool_op.getFilterWidth(); @@ -1364,9 +1366,6 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( if (!GetPaddingFromString(tfl_maxpool_op.getPadding().str(), &tf_pad).ok()) return failure(); - // Pooling has no non-unit dilation - DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr({1, 1}); - RankedTensorType filter_type = RankedTensorType::get(i64array, rewriter.getIntegerType(64)); @@ -1379,8 +1378,13 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( return failure(); } + // TFLite only supports NHWC format + const Value max_pool_input = getInputSlicedToItsUsedSize( + rewriter, op, tensorflow::FORMAT_NHWC, input_type, + tfl_maxpool_op.getInput(), kernel_size, pad, stride, dilation); + CreateReplaceOpAndInfer(rewriter, op, output_type, - tfl_maxpool_op.getInput(), + max_pool_input, kernel_size, stride, pad); return success(); } From 6c5b9025f6efc79e6d34071e1f82f48cb841e2f0 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Fri, 11 Apr 2025 05:09:21 -0700 Subject: [PATCH 0559/1324] Add contracting tile size selection logic to dynamic search space PiperOrigin-RevId: 746417878 --- 
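[Editor's note, not part of the patch: a standalone sketch of the shared-memory bound that this commit's GetContractingSizeLimitToFitSharedMemory computes. The code requires (lhs_dim + rhs_dim) * contracting_dim * bitwidth <= budget_in_bits; the names below are illustrative, not from the patch:

    #include <cstdint>
    #include <iostream>

    // Largest block_k such that the LHS and RHS tiles together fit in the
    // shared memory budget: (block_m + block_n) * block_k * bits / 8 <= smem.
    int64_t ContractingSizeLimit(int64_t smem_bytes, int64_t bits,
                                 int64_t block_m, int64_t block_n) {
      return 8 * smem_bytes / bits / (block_m + block_n);
    }

    int main() {
      // H100 numbers from the test below: 227 KiB opt-in shared memory,
      // 2-byte elements, 128x128 output tile => limit 454, so block_k is
      // capped at the power of two 256 and configs with block_k >= 512
      // are rejected.
      std::cout << ContractingSizeLimit(227 * 1024, 16, 128, 128) << "\n";
      return 0;
    }
]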
.../gpu/autotuning/dot_search_space.cc | 66 +++++++++++++++--- .../service/gpu/autotuning/dot_search_space.h | 16 +++++ .../gpu/autotuning/dot_search_space_test.cc | 69 +++++++++++++++++-- 3 files changed, 136 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc index 98b70dc615d96e..3520b814424baf 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc @@ -134,8 +134,8 @@ std::vector TritonDotFusionSearchSpace::GenerateConfigs( ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddOutputTilings); EliminateLowOccupancyConfigs(configs); - ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddCtaSizeParameter); + ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddContractingTiling); std::vector result; result.reserve(configs.size()); @@ -143,7 +143,6 @@ std::vector TritonDotFusionSearchSpace::GenerateConfigs( // TODO: b/404470821 - Implement this properly rather than hardcoding the // config parameters. TritonGemmConfig& config = config_with_notes.config; - config.block_k = 64; config.num_stages = 3; config.num_ctas = 1; result.push_back(config); @@ -258,6 +257,31 @@ int TritonDotFusionSearchSpace::GetMaxContractingSplit( return split; } +int TritonDotFusionSearchSpace::GetContractingSizeLimitToFitSharedMemory( + OutputTile output_tile) const { + const int64_t shared_memory_budget = + device_description_.shared_memory_per_block_optin(); + // Need to satisfy: + // (lhs_dim + rhs_dim) * contracting_dim * bitwidth <= budget_in_bits + return 8 * shared_memory_budget / compute_bitwidth_ / + (output_tile.lhs_dim + output_tile.rhs_dim); +} + +int TritonDotFusionSearchSpace::GetMaxContractingTileSize( + OutputTile output_tile, int contracting_split) const { + const int64_t available_size = contracting_size_ / contracting_split; + const int64_t size_limit = + GetContractingSizeLimitToFitSharedMemory(output_tile); + const int64_t max_size = + std::min(NextPowerOfTwo(available_size), PreviousPowerOfTwo(size_limit)); + VLOG(5) << "Computing max_contracting_tile_size for tiling BxMxN = " + << contracting_split << "x" << output_tile.lhs_dim << "x" + << output_tile.rhs_dim << ": limit based on problem is " + << available_size << ", limit based on available shared memory is " + << size_limit << ", max_contracting_tile_size = " << max_size; + return max_size; +} + std::vector TritonDotFusionSearchSpace::GenerateContractingSplitFactors() { CHECK_GE(max_contracting_split_, 1); @@ -290,10 +314,8 @@ void TritonDotFusionSearchSpace::AddOutputTilings( << "Need config with contracting split already set."; const int split = config.config.split_k; ConfigWithNotes new_config = config; - int& m = new_config.config.block_m; - int& n = new_config.config.block_n; - for (m = min_out_tile_.lhs_dim; m <= max_out_tile_.lhs_dim; m *= 2) { - for (n = min_out_tile_.rhs_dim; n <= max_out_tile_.rhs_dim; n *= 2) { + for (int m = min_out_tile_.lhs_dim; m <= max_out_tile_.lhs_dim; m *= 2) { + for (int n = min_out_tile_.rhs_dim; n <= max_out_tile_.rhs_dim; n *= 2) { OutputTile tile = {m, n}; // We could make the tile size limits depend on split_k, but then we // need to implement the "inverse" of `GetMaxContractingSplit`. 
@@ -306,6 +328,8 @@ void TritonDotFusionSearchSpace::AddOutputTilings( } new_config.not_enough_tiles = GetNumResultTiles(tile) * split < device_description_.core_count(); + new_config.config.block_m = m; + new_config.config.block_n = n; VLOG(10) << "Adding output tiling: config = " << new_config.ToString(); updated_configs.push_back(new_config); } @@ -316,21 +340,41 @@ void TritonDotFusionSearchSpace::AddCtaSizeParameter( const ConfigWithNotes& config, std::vector& updated_configs) { ConfigWithNotes new_config = config; - int tile_rows = config.config.block_m; - int tile_cols = config.config.block_n; - int& warps = new_config.config.num_warps; + const int tile_rows = config.config.block_m; + const int tile_cols = config.config.block_n; CHECK_GT(tile_rows * tile_cols, 0) << "Need configs with output tilings determined."; - int max_warps = GetMaxWarpsPerCta({tile_rows, tile_cols}); + const int max_warps = GetMaxWarpsPerCta({tile_rows, tile_cols}); VLOG(5) << "Computing max_warps: For output_tile = " << tile_rows << "x" << tile_cols << " and (wg)mma instruction shape, max_warps = " << max_warps; - for (warps = min_warps_per_cta_; warps <= max_warps; warps *= 2) { + for (int warps = min_warps_per_cta_; warps <= max_warps; warps *= 2) { + new_config.config.num_warps = warps; VLOG(10) << "Adding CTA size parameter: config = " << new_config.ToString(); updated_configs.push_back(new_config); } } +void TritonDotFusionSearchSpace::AddContractingTiling( + const ConfigWithNotes& config, + std::vector& updated_configs) { + const int tile_rows = config.config.block_m; + const int tile_cols = config.config.block_n; + const int split = config.config.split_k; + CHECK_GT(tile_rows * tile_cols, 0) + << "Need configs with output tilings determined."; + CHECK_GT(split, 0) << "Need config with contracting split determined."; + int max_tile_size = + std::max(GetMaxContractingTileSize({tile_rows, tile_cols}, split), + min_contracting_tile_size_); + ConfigWithNotes new_config = config; + for (int k = min_contracting_tile_size_; k <= max_tile_size; k *= 2) { + new_config.config.block_k = k; + VLOG(10) << "Adding contracting tiling: config = " << new_config.ToString(); + updated_configs.push_back(new_config); + } +} + void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs( std::vector& configs) { CHECK(!configs.empty()); diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h index d6c7be718acf08..7cc49026abd2ab 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h @@ -116,6 +116,15 @@ class TritonDotFusionSearchSpace { // output tile. int GetMaxContractingSplit(OutputTile output_tile) const; + // Computes the size limit for contracting dimension, based on the shared + // memory budget. + int GetContractingSizeLimitToFitSharedMemory(OutputTile output_tile) const; + + // Computes the maximum reasonable tile size for the contracting dimension for + // the given output tile and contracting split. + int GetMaxContractingTileSize(OutputTile output_tile, + int contracting_split) const; + // Finds all promising values for splitting the contracting dimension to // achieve sufficient occupancy (split_k). 
std::vector GenerateContractingSplitFactors(); @@ -134,6 +143,13 @@ class TritonDotFusionSearchSpace { void AddCtaSizeParameter(const ConfigWithNotes& config, std::vector& updated_configs); + // Finds all promising values for the contracting dimension tile size + // (block_k), based on `config` with already determined contracting split and + // output tiling, and appends them to `updated_configs`. Each config in the + // input list might yield zero or more configs in the output. + void AddContractingTiling(const ConfigWithNotes& config, + std::vector& updated_configs); + // Removes configs that are marked with `not_enough_tiles` from the list. If // this results in an empty list, adds a config that should be the most // optimal one even though it does not occupy all cores. diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc index b85796de942d42..7870a56112f3e6 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc @@ -82,10 +82,13 @@ class DotSearchSpaceTest : public HloHardwareIndependentTestBase { DotSearchSpaceTest() : device_description_(se::GpuDeviceInfoProto::default_instance()) { // Using H100 numbers as the most relevant example here. + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability + // https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/#nvidia_h100_gpu_architecture_in-depth device_description_.set_registers_per_block_limit(64 * 1024); device_description_.set_core_count(132); device_description_.set_threads_per_block_limit(1024); device_description_.set_threads_per_warp(32); + device_description_.set_shared_memory_per_block_optin(227 * 1024); } absl::StatusOr> GetDefaultDotModule( @@ -114,8 +117,10 @@ ENTRY e { }; TEST_F(DotSearchSpaceTest, SerializesSearchSpace) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetDefaultDotModule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetDefaultDotModule(/*lhs_parallel_dim=*/1024, /*rhs_parallel_dim=*/1024, + /*contracting_dim=*/1024)); TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_EQ(search_space.Serialize(), @@ -166,7 +171,8 @@ TEST_F(DotSearchSpaceTest, LimitsContractingSplitForSmallerContractingSize) { TEST_F(DotSearchSpaceTest, FindsGoodDataReuseOutputTiles) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetDefaultDotModule()); + GetDefaultDotModule(/*lhs_parallel_dim=*/1024, + /*rhs_parallel_dim=*/1024)); TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT(search_space.GenerateConfigs(), @@ -199,7 +205,8 @@ TEST_F(DotSearchSpaceTest, TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForForcedHugeSplit) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetDefaultDotModule()); + GetDefaultDotModule(/*lhs_parallel_dim=*/1024, + /*rhs_parallel_dim=*/1024)); TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); EXPECT_THAT( @@ -269,5 +276,59 @@ TEST_F(DotSearchSpaceTest, ConsidersAppropriateCTASizeForTileSize) { NumWarpsIs(Eq(8)))))); } +TEST_F(DotSearchSpaceTest, FindsFullCacheLineContractingTileSize) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetDefaultDotModule(/*lhs_parallel_dim=*/1024, /*rhs_parallel_dim=*/1024, + /*contracting_dim=*/1024)); + TritonDotFusionSearchSpace search_space = 
      MakeSearchSpace(module.get());
+
+  EXPECT_THAT(search_space.GenerateConfigs(), Contains(BlockKIs(Ge(64))));
+}
+
+TEST_F(DotSearchSpaceTest, HonorsSharedMemoryLimit) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<VerifiedHloModule> module,
+      GetDefaultDotModule(/*lhs_parallel_dim=*/4096, /*rhs_parallel_dim=*/4096,
+                          /*contracting_dim=*/4096));
+  TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
+  // We pick a specific output tiling and a contracting split of 1 (so that the
+  // effective contracting size is not reduced), and only verify that configs
+  // with these properties honor the memory limit. This simplifies the test
+  // logic and makes the calculation easier to verify by hand, without reducing
+  // the coverage of the test.
+  auto considered_configs =
+      AllOf(BlockMIs(Eq(128)), BlockNIs(Eq(128)), SplitKIs(Eq(1)));
+
+  // 2B * (128 + 128) * block_k <= 227 KB  =>
+  // block_k <= 227 KB / (2B * (128 + 128)) = 454
+  EXPECT_THAT(
+      search_space.GenerateConfigs(),
+      AllOf(Contains(considered_configs),
+            Not(Contains(AllOf(considered_configs, BlockKIs(Ge(512)))))));
+}
+
+TEST_F(DotSearchSpaceTest, HonorsContractingSizeLimit) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<VerifiedHloModule> module,
+      GetDefaultDotModule(/*lhs_parallel_dim=*/1024, /*rhs_parallel_dim=*/1024,
+                          /*contracting_dim=*/256));
+  TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(search_space.GenerateConfigs(/*force_contracting_split=*/4),
+              AllOf(Not(IsEmpty()), Each(BlockKIs(Le(64)))));
+}
+
+TEST_F(DotSearchSpaceTest, EnsuresContractingTileSizeFitsInstructionShape) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<VerifiedHloModule> module,
+      GetDefaultDotModule(/*lhs_parallel_dim=*/1024, /*rhs_parallel_dim=*/1024,
+                          /*contracting_dim=*/4));
+  TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(search_space.GenerateConfigs(),
+              AllOf(Not(IsEmpty()), Each(BlockKIs(Ge(8)))));
+}
+
 }  // namespace
 }  // namespace xla::gpu

From 467c2d43e550df18855e3ad1690f23851511ae8c Mon Sep 17 00:00:00 2001
From: Adrian Kuegel
Date: Fri, 11 Apr 2025 05:09:49 -0700
Subject: [PATCH 0560/1324] [XLA:GPU] Let EvaluateTileStrides() handle
 parameters > dimension bounds.

Currently the assumption is that it is never called with tile parameters
outside the dimension bounds. The underlying implementation detail is that the
IfNeqOne affine expression that we use for expanding reshapes assumes that the
tile parameter is not bigger than the dimension bound. To make the assumption
hold, we clamp the parameters accordingly.

PiperOrigin-RevId: 746417970
---
 .../xla/service/gpu/model/symbolic_tile.cc    | 15 +++++++-
 .../gpu/model/symbolic_tile_analysis_test.cc  | 36 +++++++++++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc
index 874882685d8f8d..7cf7fd99b23d0c 100644
--- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc
+++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc
@@ -15,6 +15,7 @@ limitations under the License.

 #include "xla/service/gpu/model/symbolic_tile.h"

+#include <algorithm>
 #include
 #include
 #include
@@ -296,8 +297,20 @@ llvm::SmallVector<int64_t> EvaluateTileSizes(

 llvm::SmallVector<int64_t> EvaluateTileStrides(
     const SymbolicTile& symbolic_tile, absl::Span<const int64_t> parameters) {
+  llvm::SmallVector<int64_t> clamped_parameters;
+  clamped_parameters.reserve(parameters.size());
+  // We need to clamp the parameters to the dimension bounds, otherwise the
+  // stride expressions would potentially return wrong results. The underlying
+  // implementation detail is that the IfNeqOne affine expression that we use
+  // for expanding reshapes assumes that the tile parameter is not bigger than
+  // the dimension bound. To make the assumption hold, we clamp the parameters
+  // accordingly.
+  for (auto [parameter, dim_bounds] :
+       llvm::zip(parameters, symbolic_tile.tile_map().GetDimensionBounds())) {
+    clamped_parameters.push_back(std::min(parameter, dim_bounds.upper));
+  }
   return EvaluateAffineMap(symbolic_tile.stride_map(),
-                           /*dim_values=*/parameters);
+                           /*dim_values=*/clamped_parameters);
 }

 }  // namespace gpu

diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc
index e82f14470a0489..7c6ffefc9c81c8 100644
--- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc
+++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc
@@ -232,6 +232,42 @@ ENTRY main {
   EXPECT_EQ(p0_from_subtract0, p0_from_subtract1);
 }

+TEST_F(SymbolicTileAnalysisTest,
+       ExpandingReshapeIsSupportedWithTileParamsOutsideBounds) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+fusion {
+  param_0 = f32[20] parameter(0)
+  abs = f32[20] abs(param_0)
+  ROOT reshape = f32[4,5] reshape(abs)
+}
+
+ENTRY entry_computation {
+  param_0 = f32[20] parameter(0)
+  ROOT fusion = f32[4, 5] fusion(param_0), kind=kCustom, calls=fusion
+})"));
+  std::optional<SymbolicTileAnalysis> analysis = TryAnalyzeModule(module.get());
+  ASSERT_TRUE(analysis.has_value());
+
+  TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation,
+                          analysis->ComputeTiledHloInstructions(
+                              /*tile_parameters=*/{1, 8},
+                              /*constraints_are_known_satisfied=*/false,
+                              /*compute_all_tile_offset_indexing_maps=*/true));
+
+  const TiledHloInstruction* root = tiled_hlo_computation.GetRoots()[0];
+  auto parameter = root->operand(0)->operand(0);
+  EXPECT_THAT(*parameter, MatchTiledHloInstruction(
+                              /*tile_sizes=*/{8},
+                              /*tile_strides=*/{1},
+                              /*tile_offsets_indexing=*/R"(
+    (pid_0, pid_1) -> (pid_0 * 5),
+    domain:
+    pid_0 in [0, 3],
+    pid_1 in [0, 0]
+  )")));
+}
+
 TEST_F(SymbolicTileAnalysisTest, ProducerConsumerFusionIsSupported) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(R"(

From 9ffc0fe0628c7ee159485dc5fdc4ace724f29c9c Mon Sep 17 00:00:00 2001
From: Thomas Joerg
Date: Fri, 11 Apr 2025 05:28:11 -0700
Subject: [PATCH 0561/1324] [XLA:GPU] Avoid cycles in Split-K Rewrites.

PiperOrigin-RevId: 746422055
---
 .../xla/service/gpu/split_k_gemm_rewriter.cc  |  9 +++++++--
 .../service/gpu/split_k_gemm_rewriter_test.cc | 36 +++++++++++++++++++
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc
index a64da1db635ef2..99dd94144ab4bf 100644
--- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc
+++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc
@@ -436,6 +436,10 @@ absl::Status MakeDotSplitKBatch(HloInstruction* dot_fusion,
   HloInstruction* zero =
       dot_fusion->parent()->AddInstruction(HloInstruction::CreateConstant(
           LiteralUtil::Zero(root->shape().element_type())));
+  auto initial_dot_fusion_users = dot_fusion->users();
   // The batch dimension to reduce is the first one by construction.
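+  // Note: `initial_dot_fusion_users` was captured above, before the reduce
+  // below exists, so the reduce is not part of the list and the replacement
+  // further down cannot redirect the reduce's own operand to form a cycle.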
  TF_ASSIGN_OR_RETURN(HloInstruction * reduce,
                      MakeReduceHlo(dot_fusion, zero,
                                    /*dimensions=*/{0},
@@ -452,8 +456,9 @@ absl::Status MakeDotSplitKBatch(HloInstruction* dot_fusion,
     dot_fusion->parent()->set_root_instruction(split_k_root,
                                                /*accept_different_shape=*/true);
   } else {
-    TF_RETURN_IF_ERROR(
-        dot_fusion->ReplaceAllUsesWithDifferentShape(split_k_root));
+    // Replace all users except for split_k_root created above to avoid cycles.
+    TF_RETURN_IF_ERROR(dot_fusion->ReplaceAllUsesWithDifferentShape(
+        initial_dot_fusion_users, split_k_root));
   }

   return absl::OkStatus();

diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc
index 5ae3103a50ce22..d7c3b72deabc04 100644
--- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc
+++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc
@@ -769,6 +769,42 @@ ENTRY e {
       << module->ToString();
 }

+TEST_F(SplitKTestWithMorePreciseReduction, MakeSplitKTuple) {
+  const std::string hlo_text = R"(
+HloModule t
+
+triton_gemm_dot {
+  parameter_0 = s8[3,128,5,32]{3,2,1,0} parameter(0)
+  bitcast.1 = s8[3,5,32,128]{2,1,3,0} bitcast(parameter_0)
+  copy.1 = s8[3,5,32,128]{3,2,1,0} copy(bitcast.1)
+  reshape.5 = s8[480,128]{1,0} reshape(copy.1)
+  convert.8 = bf16[480,128]{1,0} convert(reshape.5)
+  parameter_1 = bf16[16,128]{1,0} parameter(1)
+  ROOT dot.0 = bf16[480,16]{1,0} dot(convert.8, parameter_1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY e {
+  p0 = s8[3,128,5,32]{3,2,1,0} parameter(0)
+  p1 = bf16[16,128]{1,0} parameter(1)
+  fusion = bf16[480,16]{1,0} fusion(p0, p1),
+    kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm",
+    metadata={op_name="foo"}
+  ROOT tuple = (bf16[480,16]{1,0}, bf16[16,128]{1,0}) tuple(fusion, p1)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction()->mutable_operand(0),
+      config));
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple().WithOperand(
+                  0, m::Convert(m::Reduce().WithOperand(0, m::Fusion())))))
+      << module->ToString();
+}
+
 TEST_F(SplitKTest, MakeSplitKWithTransposeAfterDot) {
   const std::string hlo_text = R"(
 triton_gemm_dot {

From 93f72f0848cb5cca4d83351ee002c0cfa260579f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 11 Apr 2025 05:55:35 -0700
Subject: [PATCH 0562/1324] Remove stale build targets

After #25023, some stale targets were left over in the BUILD file that no
longer have corresponding source files.
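Each removed target was a pure forwarding wrapper around a relocated library,
so dependents migrate by referencing the real target directly. An illustrative
BUILD change for one of them (target names as in the diff below):

    deps = [
        # before: "//xla/service:hlo_constant_folding"
        "//xla/hlo/transforms/simplifiers:hlo_constant_folding",
    ],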
PiperOrigin-RevId: 746428411 --- third_party/xla/xla/service/BUILD | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 0b01456d54b023..64009010fb7389 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -3447,12 +3447,6 @@ cc_library( ], ) -cc_library( - name = "hlo_dataflow_analysis", - hdrs = ["hlo_dataflow_analysis.h"], - deps = ["//xla/hlo/analysis:hlo_dataflow_analysis"], -) - cc_library( name = "hlo_phi_graph", srcs = ["hlo_phi_graph.cc"], @@ -3880,12 +3874,6 @@ xla_cc_test( ], ) -cc_library( - name = "hlo_constant_folding", - hdrs = ["hlo_constant_folding.h"], - deps = ["//xla/hlo/transforms/simplifiers:hlo_constant_folding"], -) - cc_library( name = "hlo_domain_map", srcs = ["hlo_domain_map.cc"], @@ -5143,12 +5131,6 @@ xla_cc_test( ], ) -cc_library( - name = "operand_upcaster", - hdrs = ["operand_upcaster.h"], - deps = ["//xla/hlo/transforms:operand_upcaster"], -) - cc_library( name = "global_device_id", srcs = ["global_device_id.cc"], From 464bc87b0376ba4baedb09d7d8d447225630ac21 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 11 Apr 2025 14:18:12 +0000 Subject: [PATCH 0563/1324] Updated rules_python patch to get 3.13.2 python --- third_party/py/python_init_rules.bzl | 5 +- ...rules_python.patch => rules_python1.patch} | 0 third_party/py/rules_python2.patch | 74 +++++++++++++++++++ 3 files changed, 78 insertions(+), 1 deletion(-) rename third_party/py/{rules_python.patch => rules_python1.patch} (100%) create mode 100644 third_party/py/rules_python2.patch diff --git a/third_party/py/python_init_rules.bzl b/third_party/py/python_init_rules.bzl index 796ae3d92d999f..e66623b50d6c15 100644 --- a/third_party/py/python_init_rules.bzl +++ b/third_party/py/python_init_rules.bzl @@ -9,5 +9,8 @@ def python_init_rules(): strip_prefix = "rules_python-0.39.0", url = "https://github.com/bazelbuild/rules_python/releases/download/0.39.0/rules_python-0.39.0.tar.gz", patch_args = ["-p1"], - patches = [Label("//third_party/py:rules_python.patch")], + patches = [ + Label("//third_party/py:rules_python1.patch"), + Label("//third_party/py:rules_python2.patch"), + ], ) diff --git a/third_party/py/rules_python.patch b/third_party/py/rules_python1.patch similarity index 100% rename from third_party/py/rules_python.patch rename to third_party/py/rules_python1.patch diff --git a/third_party/py/rules_python2.patch b/third_party/py/rules_python2.patch new file mode 100644 index 00000000000000..304d0606ca08f7 --- /dev/null +++ b/third_party/py/rules_python2.patch @@ -0,0 +1,74 @@ +diff --git a/python/versions.bzl b/python/versions.bzl +index 91e59f9b..223c90d4 100644 +--- a/python/versions.bzl ++++ b/python/versions.bzl +@@ -575,25 +575,42 @@ TOOL_VERSIONS = { + }, + "strip_prefix": "python", + }, +- "3.13.0": { +- "url": "20241016/cpython-{python_version}+20241016-{platform}-{build}.{ext}", +- "sha256": { +- "aarch64-apple-darwin": "31397953849d275aa2506580f3fa1cb5a85b6a3d392e495f8030e8b6412f5556", +- "aarch64-unknown-linux-gnu": "e8378c0162b2e0e4cc1f62b29443a3305d116d09583304dbb0149fecaff6347b", +- "ppc64le-unknown-linux-gnu": "fc4b7f27c4e84c78f3c8e6c7f8e4023e4638d11f1b36b6b5ce457b1926cebb53", +- "s390x-unknown-linux-gnu": "66b19e6a07717f6cfcd3a8ca953f0a2eaa232291142f3d26a8d17c979ec0f467", +- "x86_64-apple-darwin": "cff1b7e7cd26f2d47acac1ad6590e27d29829776f77e8afa067e9419f2f6ce77", +- "x86_64-pc-windows-msvc": "b25926e8ce4164cf103bacc4f4d154894ea53e07dd3fdd5ebb16fb1a82a7b1a0", +- 
"x86_64-unknown-linux-gnu": "2c8cb15c6a2caadaa98af51df6fe78a8155b8471cb3dd7b9836038e0d3657fb4", +- "aarch64-apple-darwin-freethreaded": "efc2e71c0e05bc5bedb7a846e05f28dd26491b1744ded35ed82f8b49ccfa684b", +- "aarch64-unknown-linux-gnu-freethreaded": "59b50df9826475d24bb7eff781fa3949112b5e9c92adb29e96a09cdf1216d5bd", +- "ppc64le-unknown-linux-gnu-freethreaded": "1217efa5f4ce67fcc9f7eb64165b1bd0912b2a21bc25c1a7e2cb174a21a5df7e", +- "s390x-unknown-linux-gnu-freethreaded": "6c3e1e4f19d2b018b65a7e3ef4cd4225c5b9adfbc490218628466e636d5c4b8c", +- "x86_64-apple-darwin-freethreaded": "2e07dfea62fe2215738551a179c87dbed1cc79d1b3654f4d7559889a6d5ce4eb", +- "x86_64-pc-windows-msvc-freethreaded": "bfd89f9acf866463bc4baf01733da5e767d13f5d0112175a4f57ba91f1541310", +- "x86_64-unknown-linux-gnu-freethreaded": "a73adeda301ad843cce05f31a2d3e76222b656984535a7b87696a24a098b216c", ++ "3.13.2": { ++ "url": "20250317/cpython-{python_version}+20250317-{platform}-{build}.{ext}", ++ "sha256": { ++ "aarch64-apple-darwin": "faa44274a331eb39786362818b21b3a4e74514e8805000b20b0e55c590cecb94", ++ "aarch64-unknown-linux-gnu": "9c67260446fee6ea706dad577a0b32936c63f449c25d66e4383d5846b2ab2e36", ++ "ppc64le-unknown-linux-gnu": "345b53d2f86c9dbd7f1320657cb227ff9a42ef63ff21f129abbbc8c82a375147", ++ "s390x-unknown-linux-gnu": "ec3b16ea8a97e3138acec72bc5ff35949950c62c8994a8ec8e213fd93f0e806b", ++ "x86_64-apple-darwin": "ee4526e84b5ce5b11141c50060b385320f2773616249a741f90c96d460ce8e8f", ++ "x86_64-pc-windows-msvc": "84d7b52f3558c8e35c670a4fa14080c75e3ec584adfae49fec8b51008b75b21e", ++ "x86_64-unknown-linux-gnu": "db011f0cd29cab2291584958f4e2eb001b0e6051848d89b38a2dc23c5c54e512", ++ "x86_64-unknown-linux-musl": "00bb2d629f7eacbb5c6b44dc04af26d1f1da64cee3425b0d8eb5135a93830296", ++ "aarch64-apple-darwin-freethreaded": "c98c9c977e6fa05c3813bd49f3553904d89d60fed27e2e36468da7afa1d6d5e2", ++ "aarch64-unknown-linux-gnu-freethreaded": "b8635e59e3143fd17f19a3dfe8ccc246ee6587c87da359bd1bcab35eefbb5f19", ++ "ppc64le-unknown-linux-gnu-freethreaded": "6ae8fa44cb2edf4ab49cff1820b53c40c10349c0f39e11b8cd76ce7f3e7e1def", ++ "s390x-unknown-linux-gnu-freethreaded": "c074144cc80c2af32c420b79a9df26e8db405212619990c1fbdd308bd75afe3f", ++ "x86_64-apple-darwin-freethreaded": "0d73e4348d8d4b5159058609d2303705190405b485dd09ad05d870d7e0f36e0f", ++ "x86_64-pc-windows-msvc-freethreaded": "c51b4845fda5421e044067c111192f645234081d704313f74ee77fa013a186ea", ++ "x86_64-unknown-linux-gnu-freethreaded": "1aea5062614c036904b55c1cc2fb4b500b7f6f7a4cacc263f4888889d355eef8", ++ }, ++ "strip_prefix": { ++ "aarch64-apple-darwin": "python", ++ "aarch64-unknown-linux-gnu": "python", ++ "ppc64le-unknown-linux-gnu": "python", ++ "s390x-unknown-linux-gnu": "python", ++ "x86_64-apple-darwin": "python", ++ "x86_64-pc-windows-msvc": "python", ++ "x86_64-unknown-linux-gnu": "python", ++ "x86_64-unknown-linux-musl": "python", ++ "aarch64-apple-darwin-freethreaded": "python/install", ++ "aarch64-unknown-linux-gnu-freethreaded": "python/install", ++ "ppc64le-unknown-linux-gnu-freethreaded": "python/install", ++ "s390x-unknown-linux-gnu-freethreaded": "python/install", ++ "x86_64-apple-darwin-freethreaded": "python/install", ++ "x86_64-pc-windows-msvc-freethreaded": "python/install", ++ "x86_64-unknown-linux-gnu-freethreaded": "python/install", + }, +- "strip_prefix": "python", + }, + } + +@@ -604,7 +621,7 @@ MINOR_MAPPING = { + "3.10": "3.10.15", + "3.11": "3.11.10", + "3.12": "3.12.8", +- "3.13": "3.13.0", ++ "3.13": "3.13.2", + } + + def _generate_platforms(): From 
f5a40c1e60d4460e3674dad33781caf2ad045e4b Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 11 Apr 2025 07:22:56 -0700 Subject: [PATCH 0564/1324] PR #24502: [ROCM] Introduce asan config for rocm ci hermetic build Imported from GitHub PR https://github.com/openxla/xla/pull/24502 Copybara import of the project: -- db6ea475e514387856129d26708e8ea22c06382d by Alexandros Theodoridis : Introduce asan config for rocm ci hermetic build -- e209516d805e06c2317cace1b2fa8b024eab4ece by Alex : PR #24958: [ROCM] move rocm ci command into xla repo Imported from GitHub PR https://github.com/openxla/xla/pull/24958 Move rocm CI command into the xla directory. This is a preparation step before implementing asan build for rocm. Copybara import of the project: -- bd5a73bd96a0e4f9d5d65926cdf3f4d7206331e4 by alekstheod : Move ci command to xla repo -- a5cbb02fa751d8ad6175bcb79c6da44b28f8fdc0 by alekstheod : Address review comments -- dbd0b6b1656bed67a43dc552a608a7db9184b45a by alekstheod : Move ci build script to build_tools dir Merging this change closes #24958 PiperOrigin-RevId: 746386978 -- 5eea14508321c8a8aad9711a70bf298c301b6b94 by alekstheod : Add asan ignore lists -- 2e0e2ad1cd7de668a159be7bc52b99abb7d80a81 by alekstheod : Revert conversions api Merging this change closes #24502 PiperOrigin-RevId: 746452079 --- third_party/xla/build_tools/rocm/asan_ignore_list.txt | 2 ++ third_party/xla/build_tools/rocm/lsan_ignore_list.txt | 3 +++ third_party/xla/build_tools/rocm/run_xla_ci_build.sh | 4 +++- third_party/xla/tensorflow.bazelrc | 10 ++++++++++ 4 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/build_tools/rocm/asan_ignore_list.txt create mode 100644 third_party/xla/build_tools/rocm/lsan_ignore_list.txt diff --git a/third_party/xla/build_tools/rocm/asan_ignore_list.txt b/third_party/xla/build_tools/rocm/asan_ignore_list.txt new file mode 100644 index 00000000000000..9f9ba4ad9b8c69 --- /dev/null +++ b/third_party/xla/build_tools/rocm/asan_ignore_list.txt @@ -0,0 +1,2 @@ +interceptor_via_lib:libhsa-runtime64.so +interceptor_via_lib:libamdhip64.so diff --git a/third_party/xla/build_tools/rocm/lsan_ignore_list.txt b/third_party/xla/build_tools/rocm/lsan_ignore_list.txt new file mode 100644 index 00000000000000..41fd8f23abc3d5 --- /dev/null +++ b/third_party/xla/build_tools/rocm/lsan_ignore_list.txt @@ -0,0 +1,3 @@ +leak:libhsa-runtime64.so +leak:libstdc++.so +leak:libamdhip64.so diff --git a/third_party/xla/build_tools/rocm/run_xla_ci_build.sh b/third_party/xla/build_tools/rocm/run_xla_ci_build.sh index 14a5b28dc60910..dbc6b0321f31be 100755 --- a/third_party/xla/build_tools/rocm/run_xla_ci_build.sh +++ b/third_party/xla/build_tools/rocm/run_xla_ci_build.sh @@ -35,4 +35,6 @@ bazel --bazelrc=/usertools/rocm.bazelrc test \ --action_env=XLA_FLAGS=--xla_gpu_enable_llvm_module_compilation_parallelism=true \ --test_output=errors \ --local_test_jobs=2 \ - --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute + --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute \ + --test_env="ASAN_OPTIONS=suppressions=$(realpath $(dirname $0))/asan_ignore_list.txt" \ + --test_env="LSAN_OPTIONS=suppressions=$(realpath $(dirname $0))/lsan_ignore_list.txt" diff --git a/third_party/xla/tensorflow.bazelrc b/third_party/xla/tensorflow.bazelrc index 584a966f8cde27..98fdeb4f3fcca3 100644 --- a/third_party/xla/tensorflow.bazelrc +++ b/third_party/xla/tensorflow.bazelrc @@ -216,6 +216,14 @@ build:dbg --per_file_copt=+.*,-xla.*@-g0 # AWS SDK must be compiled in release mode. 
see: https://github.com/tensorflow/tensorflow/issues/37498 build:dbg --copt -DDEBUG_BUILD +build:asan --strip=never +build:asan --copt -fsanitize=address +build:asan --copt -DADDRESS_SANITIZER +build:asan --copt -O1 +build:asan --copt -g +build:asan --copt -fno-omit-frame-pointer +build:asan --linkopt -fsanitize=address + build:rocm_base --copt=-Wno-gnu-offsetof-extensions build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm_base --define=using_rocm_hipcc=true @@ -237,6 +245,8 @@ build:rocm_clang_official --linkopt="-fuse-ld=lld" build:rocm_clang_official --host_linkopt="-fuse-ld=lld" build:rocm_ci --config=rocm_clang_official + +build:rocm_ci_hermetic --config=asan build:rocm_ci_hermetic --config=rocm_clang_official build:rocm_ci_hermetic --repo_env="OS=ubuntu_22.04" build:rocm_ci_hermetic --repo_env="ROCM_VERSION=6.2.0" From 71e335411dfc3703f97db0c7465904958246e2fc Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Fri, 11 Apr 2025 08:16:42 -0700 Subject: [PATCH 0565/1324] [XLA:GPU] Productionize communication type detection. In XLA:GPU we always assume that iota is the optimal rank assignment. We compute generally whether participants are within a single node or not and whether full rail alignment is reached or not. PiperOrigin-RevId: 746467442 --- .../service/gpu/transforms/collectives/BUILD | 12 ++ .../collectives/collective_ops_utils.cc | 94 +++++++++- .../collectives/collective_ops_utils_test.cc | 162 +++++++++++++++++- 3 files changed, 255 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/BUILD b/third_party/xla/xla/service/gpu/transforms/collectives/BUILD index aeb588f65671da..987a2beaa5b4af 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/collectives/BUILD @@ -53,13 +53,18 @@ cc_library( srcs = ["collective_ops_utils.cc"], hdrs = ["collective_ops_utils.h"], deps = [ + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", ], ) @@ -68,10 +73,17 @@ xla_cc_test( srcs = ["collective_ops_utils_test.cc"], deps = [ ":collective_ops_utils", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/service:hlo_module_config", + "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc index 8022b90cfcfa49..ba1edbcb7e53fa 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc @@ -15,10 +15,18 @@ limitations under the License. 
#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" +#include +#include +#include +#include #include +#include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -26,10 +34,78 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { +namespace { + +struct CommunicationMetadata { + absl::flat_hash_map node_to_participant_count; +}; + +bool SameParticipantCounts(const absl::flat_hash_map& lhs, + const absl::flat_hash_map& rhs) { + std::vector lhs_counts, rhs_counts; + lhs_counts.reserve(lhs.size()); + for (const auto& [_, v] : lhs) { + lhs_counts.push_back(v); + } + + rhs_counts.reserve(rhs.size()); + for (const auto& [_, v] : rhs) { + rhs_counts.push_back(v); + } + std::sort(lhs_counts.begin(), lhs_counts.end()); + std::sort(rhs_counts.begin(), rhs_counts.end()); + return lhs_counts == rhs_counts; +} + +absl::StatusOr CommunicationContext( + const CollectiveDeviceList& device_list, int num_devices_per_host) { + absl::flat_hash_map node_to_participant_count; + for (const ReplicaGroup& replica_group : device_list.replica_groups()) { + absl::flat_hash_map buffer; + for (int64_t rank : replica_group.replica_ids()) { + int64_t node_id = rank / num_devices_per_host; + buffer[node_id]++; + } + if (!node_to_participant_count.empty() && + !SameParticipantCounts(buffer, node_to_participant_count)) { + return absl::FailedPreconditionError(absl::StrCat( + "Non homogenous replica group: ", device_list.ToString())); + } + if (node_to_participant_count.empty()) { + node_to_participant_count = buffer; + } + } + + return CommunicationMetadata{node_to_participant_count}; +} + +bool IsSingleHost(const CommunicationMetadata& pattern) { + return pattern.node_to_participant_count.size() == 1; +} + +bool IsRailAligned(const CommunicationMetadata& pattern, + int num_devices_per_host) { + return absl::c_all_of(pattern.node_to_participant_count, + [num_devices_per_host](const auto& elem) { + const auto& [node_id, participant_count] = elem; + return participant_count == num_devices_per_host; + }); +} + +bool IsNonRailAligned(const CommunicationMetadata& pattern, + int num_devices_per_host) { + return !IsSingleHost(pattern) && + !IsRailAligned(pattern, num_devices_per_host); +} + +} // namespace + bool IsGPUSyncCollective(const HloInstruction& instr) { auto backend_config = instr.backend_config(); if (!backend_config.ok()) { @@ -56,22 +132,22 @@ absl::StatusOr CommunicationType( // We assume no topology was provided to the compiler and no // `CUDA_VISIBLE_DEVICES` env var has been set. + // For now we only support H100 and assume 8GPUs per host. 
int num_devices_per_host = 8; - if (!iota.has_value()) { - return absl::FailedPreconditionError( - "Only iota device assignment is supported."); + TF_ASSIGN_OR_RETURN( + CommunicationMetadata comm, + CommunicationContext(instr.device_list(), num_devices_per_host)); + if (IsSingleHost(comm)) { + return GPUCommunicationType::SINGLE_HOST; } - if (iota->num_replica_groups() == 1) { + if (IsRailAligned(comm, num_devices_per_host)) { return GPUCommunicationType::RAIL_ALIGNED; } - if (iota->num_replica_groups() == num_devices_per_host && - iota->transpose_perm().size() == 2 && iota->transpose_perm()[0] == 1) { + if (IsNonRailAligned(comm, num_devices_per_host)) { return GPUCommunicationType::NON_RAIL_ALIGNED; } - if (iota->num_devices_per_group() == num_devices_per_host) { - return GPUCommunicationType::SINGLE_HOST; - } + return GPUCommunicationType::UNDEFINED; } diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc index ebcd83b7d07325..69ec8182e27769 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc @@ -15,15 +15,26 @@ limitations under the License. #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" +#include #include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/platform/statusor.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" -namespace xla { -namespace gpu { +namespace xla::gpu { namespace { +using ::testing::Test; +using ::tsl::testing::IsOkAndHolds; + bool IsMultiHostTopology(se::CudaComputeCapability compute_capability, int num_partitions, int replica_count) { HloModuleConfig config; @@ -63,6 +74,149 @@ TEST(IsMultiHostTopologyTest, MultiHosts) { /*num_partitions=*/1, /*replica_count=*/16)); } +class CommunicationTypeTest : public Test { + protected: + se::DeviceDescription& device_info() { return device_info_; } + + private: + se::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo( + stream_executor::CudaComputeCapability(9, 0)); +}; + +TEST_F(CommunicationTypeTest, DetectsSingleHost8Devices) { + absl::string_view kHlo = R"( + HloModule m, num_partitions=8 + + ENTRY e { + p = f32[128] parameter(0) + ROOT _ = f32[1024] all-gather(p), + dimensions={0}, + use_global_device_ids=true, + channel_id=1, + replica_groups=[1,8]<=[8] + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo)); + + HloCollectiveInstruction* instr = Cast( + module->entry_computation()->root_instruction()); + EXPECT_THAT(CommunicationType(*instr, device_info().gpu_compute_capability()), + IsOkAndHolds(GPUCommunicationType::SINGLE_HOST)); +} + +TEST_F(CommunicationTypeTest, DetectsSingleHost4Devices) { + absl::string_view kHlo = R"( + HloModule m, num_partitions=8 + + ENTRY e { + p = f32[128] parameter(0) + ROOT _ = f32[512] all-gather(p), + dimensions={0}, + use_global_device_ids=true, + channel_id=1, + replica_groups=[1,4]<=[4] + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, 
ParseAndReturnUnverifiedModule(kHlo)); + + HloCollectiveInstruction* instr = Cast( + module->entry_computation()->root_instruction()); + EXPECT_THAT(CommunicationType(*instr, device_info().gpu_compute_capability()), + IsOkAndHolds(GPUCommunicationType::SINGLE_HOST)); +} + +TEST_F(CommunicationTypeTest, DetectsSingleHost16Devices) { + absl::string_view kHlo = R"( + HloModule m, num_partitions=16 + + ENTRY e { + p = f32[128] parameter(0) + ROOT _ = f32[512] all-gather(p), + dimensions={0}, + use_global_device_ids=true, + channel_id=1, + replica_groups=[2,8]<=[16] + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo)); + + HloCollectiveInstruction* instr = Cast( + module->entry_computation()->root_instruction()); + EXPECT_THAT(CommunicationType(*instr, device_info().gpu_compute_capability()), + IsOkAndHolds(GPUCommunicationType::SINGLE_HOST)); +} + +TEST_F(CommunicationTypeTest, DetectRailAlignedAllDevices) { + absl::string_view kHlo = R"( + HloModule m, num_partitions=16 + + ENTRY e { + p = f32[128] parameter(0) + ROOT _ = f32[2048] all-gather(p), + dimensions={0}, + use_global_device_ids=true, + channel_id=1, + replica_groups=[1,16]<=[16] + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo)); + + HloCollectiveInstruction* instr = Cast( + module->entry_computation()->root_instruction()); + EXPECT_THAT(CommunicationType(*instr, device_info().gpu_compute_capability()), + IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED)); +} + +TEST_F(CommunicationTypeTest, DetectRailAlignedHalfMesh) { + absl::string_view kHlo = R"( + HloModule m, num_partitions=32 + + ENTRY e { + p = f32[128] parameter(0) + ROOT _ = f32[512] all-gather(p), + dimensions={0}, + use_global_device_ids=true, + channel_id=1, + replica_groups={ + {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, + {16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31} + } + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo)); + + HloCollectiveInstruction* instr = Cast( + module->entry_computation()->root_instruction()); + EXPECT_THAT(CommunicationType(*instr, device_info().gpu_compute_capability()), + IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED)); +} + +TEST_F(CommunicationTypeTest, DetectNonRailAligned) { + absl::string_view kHlo = R"( + HloModule m, num_partitions=16 + + ENTRY e { + p = f32[128] parameter(0) + ROOT _ = f32[512] all-gather(p), + dimensions={0}, + use_global_device_ids=true, + channel_id=1, + replica_groups={{0,8},{1,9},{2,10},{3,11},{4,12},{5,13},{6,14},{7,15}} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo)); + + HloCollectiveInstruction* instr = Cast( + module->entry_computation()->root_instruction()); + EXPECT_THAT(CommunicationType(*instr, device_info().gpu_compute_capability()), + IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED)); +} + } // namespace -} // namespace gpu -} // namespace xla +} // namespace xla::gpu From cc430b52385ecf882c3f9507e9fd3de331d4189b Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 11 Apr 2025 08:23:04 -0700 Subject: [PATCH 0566/1324] Remove deprecated targets from build and header files PiperOrigin-RevId: 746469337 --- third_party/xla/xla/BUILD | 18 ----- third_party/xla/xla/client/BUILD | 48 ------------ third_party/xla/xla/client/padding.h | 22 ------ third_party/xla/xla/client/sharding_builder.h | 22 ------ third_party/xla/xla/client/value_inference.h | 21 ----- third_party/xla/xla/client/xla_builder.h | 22 ------ 
third_party/xla/xla/client/xla_computation.h | 22 ------ third_party/xla/xla/hlo/ir/BUILD | 14 ---- .../xla/xla/hlo/ir/hlo_dfs_reachability.h | 22 ------ third_party/xla/xla/hlo/ir/hlo_reachability.h | 22 ------ third_party/xla/xla/test.h | 40 ---------- third_party/xla/xla/test_helpers.h | 22 ------ third_party/xla/xla/translate/BUILD | 22 ------ .../xla/xla/translate/hlo_to_mhlo/BUILD | 78 ------------------- .../hlo_to_mhlo/attribute_importer.h | 22 ------ .../hlo_to_mhlo/hlo_function_importer.h | 22 ------ .../hlo_to_mhlo/hlo_module_importer.h | 22 ------ .../translate/hlo_to_mhlo/hlo_to_mlir_hlo.h | 22 ------ .../xla/xla/translate/hlo_to_mhlo/hlo_utils.h | 24 ------ .../xla/xla/translate/hlo_to_mhlo/translate.h | 22 ------ .../xla/xla/translate/stablehlo_to_hlo/BUILD | 20 ----- .../translate/stablehlo_to_hlo/translate.h | 22 ------ 22 files changed, 571 deletions(-) delete mode 100644 third_party/xla/xla/client/padding.h delete mode 100644 third_party/xla/xla/client/sharding_builder.h delete mode 100644 third_party/xla/xla/client/value_inference.h delete mode 100644 third_party/xla/xla/client/xla_builder.h delete mode 100644 third_party/xla/xla/client/xla_computation.h delete mode 100644 third_party/xla/xla/hlo/ir/hlo_dfs_reachability.h delete mode 100644 third_party/xla/xla/hlo/ir/hlo_reachability.h delete mode 100644 third_party/xla/xla/test.h delete mode 100644 third_party/xla/xla/test_helpers.h delete mode 100644 third_party/xla/xla/translate/BUILD delete mode 100644 third_party/xla/xla/translate/hlo_to_mhlo/BUILD delete mode 100644 third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h delete mode 100644 third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h delete mode 100644 third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h delete mode 100644 third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h delete mode 100644 third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h delete mode 100644 third_party/xla/xla/translate/hlo_to_mhlo/translate.h delete mode 100644 third_party/xla/xla/translate/stablehlo_to_hlo/BUILD delete mode 100644 third_party/xla/xla/translate/stablehlo_to_hlo/translate.h diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 61194f0a629e72..aab8b49a216641 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -207,15 +207,6 @@ cc_library( ], ) -cc_library( - name = "test", - testonly = 1, - hdrs = ["test.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/testlib:test instead.", - visibility = internal_visibility([":friends"]), - deps = ["//xla/hlo/testlib:test"], -) - cc_library( name = "types", hdrs = ["types.h"], @@ -942,15 +933,6 @@ cc_library( ], ) -cc_library( - name = "test_helpers", - testonly = 1, - hdrs = ["test_helpers.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/testlib:test_helpers instead.", - visibility = internal_visibility([":friends"]), - deps = ["//xla/hlo/testlib:test_helpers"], -) - cc_library( name = "text_literal_reader", srcs = ["text_literal_reader.cc"], diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index 577aecd868e1be..9af439a45a83c2 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -27,15 +27,6 @@ filegroup( ]), ) -cc_library( - name = "padding", - hdrs = ["padding.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder:padding instead.", - deps = [ - "//xla/hlo/builder:padding", - ], -) - cc_library( name = "client", srcs = ["client.cc"], @@ -201,42 +192,3 @@ cc_library( "@com_google_absl//absl/synchronization", ], ) - -cc_library( - name = "sharding_builder", - hdrs = ["sharding_builder.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder:sharding_builder instead.", - deps = [ - "//xla/hlo/builder:sharding_builder", - ], -) - -cc_library( - name = "xla_computation", - hdrs = ["xla_computation.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder:xla_computation instead.", - visibility = ["//visibility:public"], - deps = [ - "//xla/hlo/builder:xla_computation", - ], -) - -cc_library( - name = "value_inference", - hdrs = ["value_inference.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder:value_inference instead.", - visibility = ["//visibility:public"], - deps = [ - "//xla/hlo/builder:value_inference", - ], -) - -cc_library( - name = "xla_builder", - hdrs = ["xla_builder.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/builder:xla_builder instead.", - visibility = ["//visibility:public"], - deps = [ - "//xla/hlo/builder:xla_builder", - ], -) diff --git a/third_party/xla/xla/client/padding.h b/third_party/xla/xla/client/padding.h deleted file mode 100644 index a9e928d865da0e..00000000000000 --- a/third_party/xla/xla/client/padding.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_PADDING_H_ -#define XLA_CLIENT_PADDING_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/builder/padding.h" - -#endif // XLA_CLIENT_PADDING_H_ diff --git a/third_party/xla/xla/client/sharding_builder.h b/third_party/xla/xla/client/sharding_builder.h deleted file mode 100644 index 995978b165f885..00000000000000 --- a/third_party/xla/xla/client/sharding_builder.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_SHARDING_BUILDER_H_ -#define XLA_CLIENT_SHARDING_BUILDER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/sharding_builder.h" - -#endif // XLA_CLIENT_SHARDING_BUILDER_H_ diff --git a/third_party/xla/xla/client/value_inference.h b/third_party/xla/xla/client/value_inference.h deleted file mode 100644 index f717cc703b2502..00000000000000 --- a/third_party/xla/xla/client/value_inference.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_CLIENT_VALUE_INFERENCE_H_ -#define XLA_CLIENT_VALUE_INFERENCE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/value_inference.h" - -#endif // XLA_CLIENT_VALUE_INFERENCE_H_ diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h deleted file mode 100644 index 1599160a713014..00000000000000 --- a/third_party/xla/xla/client/xla_builder.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_XLA_BUILDER_H_ -#define XLA_CLIENT_XLA_BUILDER_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/builder/xla_builder.h" - -#endif // XLA_CLIENT_XLA_BUILDER_H_ diff --git a/third_party/xla/xla/client/xla_computation.h b/third_party/xla/xla/client/xla_computation.h deleted file mode 100644 index 685fcfecb0b093..00000000000000 --- a/third_party/xla/xla/client/xla_computation.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_CLIENT_XLA_COMPUTATION_H_ -#define XLA_CLIENT_XLA_COMPUTATION_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/builder/xla_computation.h" - -#endif // XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 5f2367acee4761..47f798ca0c04d4 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -269,20 +269,6 @@ xla_cc_test( ], ) -cc_library( - name = "hlo_reachability", - hdrs = ["hlo_reachability.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_reachability instead.", - deps = ["//xla/hlo/analysis:hlo_reachability"], -) - -cc_library( - name = "hlo_dfs_reachability", - hdrs = ["hlo_dfs_reachability.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_dfs_reachability instead.", - deps = ["//xla/hlo/analysis:hlo_dfs_reachability"], -) - cc_library( name = "ptrvec", hdrs = ["ptrvec.h"], diff --git a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.h b/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.h deleted file mode 100644 index 446be761b96228..00000000000000 --- a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_HLO_IR_HLO_DFS_REACHABILITY_H_ -#define XLA_HLO_IR_HLO_DFS_REACHABILITY_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/analysis/hlo_dfs_reachability.h" - -#endif // XLA_HLO_IR_HLO_DFS_REACHABILITY_H_ diff --git a/third_party/xla/xla/hlo/ir/hlo_reachability.h b/third_party/xla/xla/hlo/ir/hlo_reachability.h deleted file mode 100644 index 30153bf07aadc8..00000000000000 --- a/third_party/xla/xla/hlo/ir/hlo_reachability.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_HLO_IR_HLO_REACHABILITY_H_ -#define XLA_HLO_IR_HLO_REACHABILITY_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/analysis/hlo_reachability.h" - -#endif // XLA_HLO_IR_HLO_REACHABILITY_H_ diff --git a/third_party/xla/xla/test.h b/third_party/xla/xla/test.h deleted file mode 100644 index 8ce11ab8a7a374..00000000000000 --- a/third_party/xla/xla/test.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TEST_H_ -#define XLA_TEST_H_ - -// This header includes gmock.h and enables the use of gmock matchers in tests -// in third_party/tensorflow/compiler/xla. -// -// Test including this header can use the macros EXPECT_THAT(...) and -// ASSERT_THAT(...) in combination with gmock matchers. -// Example: -// std::vector vec = Foo(); -// EXPECT_THAT(vec, ::testing::ElementsAre(1,2,3)); -// -// For more details on gmock matchers see: -// https://github.com/google/googletest/blob/master/googlemock/docs/CheatSheet.md#matchers -// -// The advantages of using gmock matchers instead of self defined matchers are -// better error messages, more maintainable tests and more test coverage. -// -// Note that while the use of gmock matchers is allowed in the xla project, the -// use of mocks is disallowed in the whole tensorflow project! - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/testlib/test.h" - -#endif // XLA_TEST_H_ diff --git a/third_party/xla/xla/test_helpers.h b/third_party/xla/xla/test_helpers.h deleted file mode 100644 index 77336bd5aa53cc..00000000000000 --- a/third_party/xla/xla/test_helpers.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TEST_HELPERS_H_ -#define XLA_TEST_HELPERS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/testlib/test_helpers.h" - -#endif // XLA_TEST_HELPERS_H_ diff --git a/third_party/xla/xla/translate/BUILD b/third_party/xla/xla/translate/BUILD deleted file mode 100644 index df293cbeac9ab9..00000000000000 --- a/third_party/xla/xla/translate/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//xla/tsl:tsl.bzl", "internal_visibility") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = internal_visibility([ - "//learning/brain/mlir:tensorflow_friends", - "//learning/brain/mlir:xla_friends", - ]), - licenses = ["notice"], -) - -alias( - name = "xla-translate", - actual = "//xla/hlo/translate:xla-translate", - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate:xla-translate instead.", -) - -alias( - name = "xla-translate-opt", - actual = "//xla/hlo/translate:xla-translate-opt", - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate:xla-translate-opt instead.", -) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD deleted file mode 100644 index d439babc4564e9..00000000000000 --- a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD +++ /dev/null @@ -1,78 +0,0 @@ -load("//xla/tsl:tsl.bzl", "internal_visibility") -load("//xla/tsl/platform:rules_cc.bzl", "cc_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = internal_visibility([ - "//learning/brain/mlir:tensorflow_friends", - "//learning/brain/mlir:xla_friends", - ]), - licenses = ["notice"], -) - -cc_library( - name = "attribute_importer", - hdrs = ["attribute_importer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:attribute_importer instead.", - deps = [ - "//xla/hlo/translate/hlo_to_mhlo:attribute_importer", - ], -) - -cc_library( - name = "hlo_function_importer", - hdrs = ["hlo_function_importer.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_function_importer instead.", - deps = [ - "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", - ], -) - -cc_library( - name = "hlo_module_importer", - hdrs = [ - "hlo_module_importer.h", - ], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_module_importer instead.", - deps = [ - "//xla/hlo/translate/hlo_to_mhlo:hlo_module_importer", - ], -) - -cc_library( - name = "hlo_to_mlir_hlo", - hdrs = ["hlo_to_mlir_hlo.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo instead.", - deps = [ - "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - ], -) - -cc_library( - name = "hlo_utils", - hdrs = ["hlo_utils.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_utils instead.", - includes = ["include"], - deps = [ - "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", - ], -) - -cc_library( - name = "translate", - hdrs = ["translate.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:translate instead.", - deps = [ - "//xla/hlo/translate/hlo_to_mhlo:translate", - ], -) - -cc_library( - name = "translate_registration", - testonly = True, - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:translate_registration instead.", - deps = [ - "//xla/hlo/translate/hlo_to_mhlo:translate_registration", - ], - alwayslink = 1, -) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h deleted file mode 100644 index 2b5f81982fd6d8..00000000000000 --- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" - -#endif // XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h deleted file mode 100644 index 0ebd37fa6af125..00000000000000 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" - -#endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h deleted file mode 100644 index 8577e86dc93839..00000000000000 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h" - -#endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h deleted file mode 100644 index 4943ef790d35f1..00000000000000 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" - -#endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h deleted file mode 100644 index 50e31028617463..00000000000000 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file defines helpers useful when creating or manipulating lhlo/hlo. - -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" - -#endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/translate.h b/third_party/xla/xla/translate/hlo_to_mhlo/translate.h deleted file mode 100644 index 4ed0dc5c1ba216..00000000000000 --- a/third_party/xla/xla/translate/hlo_to_mhlo/translate.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/hlo_to_mhlo/translate.h" - -#endif // XLA_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD b/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD deleted file mode 100644 index e588cb866371d7..00000000000000 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//xla/tsl:tsl.bzl", "internal_visibility") -load("//xla/tsl/platform:rules_cc.bzl", "cc_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = internal_visibility([ - "//learning/brain/mlir:tensorflow_friends", - "//learning/brain/mlir:xla_friends", - ]), - licenses = ["notice"], -) - -cc_library( - name = "translate", - hdrs = ["translate.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/stablehlo_to_hlo:translate instead.", - deps = [ - "//xla/hlo/translate/stablehlo_to_hlo:translate", - ], -) diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h b/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h deleted file mode 100644 index badaeeaa9acb30..00000000000000 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ -#define XLA_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/stablehlo_to_hlo/translate.h" - -#endif // XLA_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ From 5f5b7574457eda4dc407efda09b5f7f4b877f3a3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 08:42:41 -0700 Subject: [PATCH 0567/1324] Converts an AutoShardingSolverRequest to IOPDDL and back before solving. PiperOrigin-RevId: 746475138 --- .../xla/hlo/experimental/auto_sharding/BUILD | 1 + .../auto_sharding/auto_sharding.cc | 10 ++- .../auto_sharding/auto_sharding_iopddl.cc | 2 +- .../auto_sharding/auto_sharding_solver.cc | 29 +++--- .../auto_sharding/auto_sharding_solver.h | 7 +- .../auto_sharding_solver_test.cc | 89 +++++++++++++------ .../auto_sharding/auto_sharding_test.cc | 12 ++- 7 files changed, 100 insertions(+), 50 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 68f003c22b3232..53e03158ce2e42 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -33,6 +33,7 @@ cc_library( deps = [ ":auto_sharding_cost_graph", ":auto_sharding_device_mesh", + ":auto_sharding_iopddl", ":auto_sharding_option", ":auto_sharding_solver", ":auto_sharding_strategy", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 6ba427fbc8fc4d..5ba07480e91184 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -51,6 +51,7 @@ limitations under the License. #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_memory.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" @@ -1997,7 +1998,14 @@ CreateAutoShardingSolverRequestAndCallSolver( PopulateTemporalValues(cost_graph, request); - return FormulateAndSolveMIPFromSolverRequest(request); + const auto converted_problem = ConvertToProblem(request); + const auto converted_request = ConvertToSolverRequest(converted_problem); + const std::optional overbudget_coeff = + option.memory_overbudget_coeff >= 0.0 + ? 
std::make_optional(option.memory_overbudget_coeff) + : std::nullopt; + return FormulateAndSolveMIPFromSolverRequest(converted_request, + overbudget_coeff); } void CheckHloSharding( diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.cc index a2cd086c11f0bc..2e47c828a999fa 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_iopddl.cc @@ -154,7 +154,7 @@ AutoShardingSolverRequest ConvertToSolverRequest( AutoShardingSolverRequest request; request.set_request_name(problem.name); request.set_num_nodes(problem.nodes.size()); - request.set_memory_budget(*problem.usage_limit); + request.set_memory_budget(problem.usage_limit.value_or(-1)); for (iopddl::NodeIdx node_idx = 0; node_idx < problem.nodes.size(); ++node_idx) { const iopddl::Node& node = problem.nodes[node_idx]; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index b0c1a40b49af8c..bf017949e6a104 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -268,7 +268,7 @@ absl::StatusOr SolveAndExtractSolution( const std::vector>& s, const std::vector>& e, const MPVariable* overbudget_var, const MPVariable* makespan_var, - MPSolver& solver) { + const std::optional overbudget_coeff, MPSolver& solver) { auto status = solver.Solve(); LOG(INFO) << "Solver absl::Status: " << status; @@ -364,8 +364,7 @@ absl::StatusOr SolveAndExtractSolution( unsalted_objective += request.resharding_costs(edge_idx).costs(j); } if (overbudget_var) { - unsalted_objective += request.overbudget_coeff().coeff() * - overbudget_var->solution_value() * + unsalted_objective += *overbudget_coeff * overbudget_var->solution_value() * request.memory_budget(); } if (makespan_var) { @@ -578,7 +577,8 @@ void AddMemoryTerms( // is guaranteed to never produce a negative overall cost for the graph, // however. 
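// With this change, the overbudget coefficient is threaded in from the
// caller instead of being read from the request proto. A hypothetical call
// site, shown only as a sketch (the coefficient value is illustrative):
//
//   auto result = FormulateAndSolveMIPFromSolverRequest(
//       request, /*overbudget_coeff=*/std::make_optional(10.0));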
absl::StatusOr FormulateAndSolveMIPFromSolverRequest( - const AutoShardingSolverRequest& unscaled_request) { + const AutoShardingSolverRequest& unscaled_request, + std::optional overbudget_coeff) { const absl::Time start_time = absl::Now(); const AutoShardingSolverRequest request = ScaleRequest(unscaled_request); const size_t num_edges = request.edges_size(); @@ -662,7 +662,7 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( edge_map.insert({followed_edge, edge_idx}); } - if (request.memory_budget() > 0 && request.has_overbudget_coeff()) { + if (request.memory_budget() > 0 && overbudget_coeff.has_value()) { overbudget_var = solver->MakeNumVar(0.0, MPSolver::infinity(), "overbudget"); } @@ -808,8 +808,7 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( } if (overbudget_var && !request.minimize_departures()) { solver->MutableObjective()->SetCoefficient( - overbudget_var, - request.overbudget_coeff().coeff() * request.memory_budget()); + overbudget_var, *overbudget_coeff * request.memory_budget()); } LOG(INFO) << "Minimum memory budget estimate: " << MinimumMemoryBudgetRequired(request); @@ -974,11 +973,11 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( if (request.has_max_departures()) { VLOG(0) << "Max departures: " << request.max_departures().coeff(); } - auto result = SolveAndExtractSolution(request, s, e, overbudget_var, - makespan_var, *solver); + auto result = SolveAndExtractSolution( + request, s, e, overbudget_var, makespan_var, overbudget_coeff, *solver); if (result.ok()) { const AutoShardingEvaluation evaluation = - Evaluate(unscaled_request, *result); + Evaluate(unscaled_request, *result, overbudget_coeff); LOG(INFO) << "*** Total costs for the (unscaled) solver request ***"; LOG(INFO) << "Total Communication Cost: " << evaluation.total.communication_cost @@ -1207,7 +1206,8 @@ bool AutoShardingEvaluation::operator==( } AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, - const AutoShardingSolverOutput& result) { + const AutoShardingSolverOutput& result, + std::optional overbudget_coeff) { const auto& c = request.computation_costs(); const auto& d = request.communication_costs(); const auto& r = request.resharding_costs(); @@ -1376,11 +1376,10 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, evaluation.violation_codes.insert(kMemoryViolationCode); } } - if (request.has_overbudget_coeff()) { - evaluation.total.overbudget_cost = - request.overbudget_coeff().coeff() * total_overbudget; + if (overbudget_coeff.has_value()) { + evaluation.total.overbudget_cost = *overbudget_coeff * total_overbudget; evaluation.lower_bound.overbudget_cost = - request.overbudget_coeff().coeff() * lower_bound_overbudget; + *overbudget_coeff * lower_bound_overbudget; } } // Compute metrics and lower bounds. diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index f832699ffb5b31..a8b9a631807ac8 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_SOLVER_H_ #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_SOLVER_H_ +#include #include #include @@ -46,7 +47,8 @@ AutoShardingSolverRequest ScaleRequest( double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request); absl::StatusOr FormulateAndSolveMIPFromSolverRequest( - const AutoShardingSolverRequest& request); + const AutoShardingSolverRequest& request, + std::optional overbudget_coeff); // TODO(fahrbach): Create AutoShardingHeuristicOptions proto with a oneof field. // Runs a heuristic specified by one of the following values of `algorithm`: @@ -101,7 +103,8 @@ struct AutoShardingEvaluation { // Evaluates the given solver result w.r.t. the input request, computing various // solution quality metrics and validating the consistency of hard constraints. AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, - const AutoShardingSolverOutput& result); + const AutoShardingSolverOutput& result, + std::optional overbudget_coeff); // Computes the objective value of the sharding strategy. If the objective value // is infinite or the sharding is infeasible (e.g., violates the peak-memory diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index f5da95b53a6907..ed3abdc7204335 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -13,6 +13,7 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" #include +#include #include #include #include @@ -84,6 +85,13 @@ void AddGroups( } } +std::optional GetOverbudgetCoeff( + const AutoShardingSolverRequest& request) { + return request.has_overbudget_coeff() + ? 
std::make_optional(request.overbudget_coeff().coeff()) + : std::nullopt; +} + // clang-format off AutoShardingSolverRequest DefaultAutoShardingSolverRequest() { @@ -256,7 +264,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOptimally) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -270,7 +279,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) { request.mutable_overbudget_coeff()->set_coeff(10.0); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 9007650.0; @@ -283,7 +293,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesMaxDepartures) { request.mutable_max_departures()->set_coeff(3.0); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -296,7 +307,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, MinimizesDepartures) { request.set_minimize_departures(true); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 1, 0, 0, 1}; const double objective_value = 3.0; @@ -311,7 +323,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) { request.mutable_computation_costs(0)->set_costs(2, kInfinityCost); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {3, 0, 0, 0, 0}; const double objective_value = 10683.0; @@ -324,7 +337,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteEdgeCosts) { request.mutable_resharding_costs(0)->set_costs(0, kInfinityCost); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -349,7 +363,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) { AddCosts(request.mutable_duration_costs(), t); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 12650.0; @@ -376,7 +391,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) { AddCosts(request.mutable_duration_costs(), t); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double 
objective_value = 13972.0; @@ -390,7 +406,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) { request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end()); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -403,7 +420,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HonorsMaxCost) { request.mutable_max_cost()->set_coeff(7600.0); // Best possible is 7650.0 const absl::StatusOr result = - FormulateAndSolveMIPFromSolverRequest(request); + FormulateAndSolveMIPFromSolverRequest(request, + GetOverbudgetCoeff(request)); EXPECT_TRUE(absl::IsInternal(result.status())); } @@ -413,7 +431,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) { request.mutable_max_cost()->set_coeff(1e19); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -436,7 +455,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) { request.set_enable_memory_edge_costs(true); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -464,7 +484,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { request.set_enable_memory_edge_costs(true); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -497,7 +518,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, request.set_enable_memory_edge_costs(true); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -517,7 +539,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, request.set_enable_memory_edge_costs(false); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -558,7 +581,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, request.set_memory_budget(4321); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -571,7 +595,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesWithEquivalences) { AutoShardingSolverRequestWithEquivalences(); TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, - FormulateAndSolveMIPFromSolverRequest(request)); + FormulateAndSolveMIPFromSolverRequest( + request, 
GetOverbudgetCoeff(request))); const std::vector s_val = {0, 0, 5, 5, 1}; const double objective_value = 7650.0; @@ -585,7 +610,8 @@ TEST(AutoShardingEvaluatorTest, NoViolations) { const double objective_value = 12149.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 159.0; // 13+21+32+42+51 @@ -608,7 +634,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) { const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -637,7 +664,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudgetWithIntervals) { const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -669,7 +697,8 @@ TEST(AutoShardingEvaluatorTest, const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -692,7 +721,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesFollower) { const double objective_value = 12138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kFollowerViolationCode}; @@ -714,7 +744,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesAlias) { const double objective_value = 12138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kAliasViolationCode}; @@ -736,7 +767,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMemory) { const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kMemoryViolationCode}; @@ -761,7 +793,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) { const double objective_value = 1e+20; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation 
evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kInfiniteCostViolationCode}; @@ -784,7 +817,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) { const double objective_value = 1e+20; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kInfiniteCostViolationCode}; @@ -807,7 +841,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) { const double objective_value = 12149.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingEvaluation evaluation = Evaluate(request, output); + const AutoShardingEvaluation evaluation = + Evaluate(request, output, GetOverbudgetCoeff(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kMaxDeparturesViolationCode}; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 11f783b9e27dc5..625f3f95b5910d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -1527,15 +1527,19 @@ ENTRY twomatmul { op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); const HloInstruction* param3 = FindInstruction(module.get(), "parameter.3"); ASSERT_NE(param3, nullptr); - EXPECT_THAT(param3, - op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); + EXPECT_THAT( + param3, + AnyOf(op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[1,2,2]<=[4] last_tile_dim_replicate}"))); const HloInstruction* dot4 = FindInstruction(module.get(), "dot.4"); ASSERT_NE(dot4, nullptr); EXPECT_THAT(dot4, op::Sharding("{devices=[2,2]0,2,1,3}")); const HloInstruction* dot5 = FindInstruction(module.get(), "dot.5"); ASSERT_NE(dot5, nullptr); - EXPECT_THAT(dot5, - op::Sharding("{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}")); + EXPECT_THAT( + dot5, + AnyOf(op::Sharding("{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,2]<=[2,2]T(1,0)}"))); } TEST_F(AutoShardingTest, TwoMatmulWithDotReplicationEnabled) { From a59e97bb3c68ac8b764941e5b7c6a529246f5aab Mon Sep 17 00:00:00 2001 From: James Ward Date: Fri, 11 Apr 2025 16:54:53 +0100 Subject: [PATCH 0568/1324] Negate the pad values when legalizing transpose_conv2d (#61892) Matches the new definition in the specification Change-Id: I4f8dfa3d380039a88b96fd74f09e8f8ebabee3f5 Co-authored-by: Eric Kunze --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 16 ++++++++++++++++ .../mlir/tosa/transforms/legalize_utils.cc | 8 ++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index f8f753919e87c6..1e44952b105e5f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -85,6 +85,22 @@ func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tens // ----- +// CHECK-LABEL: 
test_transpose_conv2d_outpad +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} +func.func @test_transpose_conv2d_outpad(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x33x33x16xf32> { + %cst = arith.constant dense<[1, 33, 33, 16]> : tensor<4xi32> + %cst_1 = "tfl.no_value"() {value = unit} : () -> none + %0 = "tfl.transpose_conv"(%cst, %cst_0, %arg0, %cst_1) + {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, + fused_activation_function = "NONE"} + : (tensor<4xi32>, tensor<16x1x1x8xf32>, tensor<1x32x32x8xf32>, none) -> tensor<1x33x33x16xf32> + func.return %0 : tensor<1x33x33x16xf32> +} + +// ----- + // CHECK-LABEL: test_conv2d_qi8 // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<16x2x2x8xi8>}> // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0> : tensor<16xi32>}> diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index 29cc8208d3fa2b..101676be1a0110 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -1067,14 +1067,14 @@ bool getTransposeConv2dPaddingValues( return false; } - int total_padding = ((ifm_size - 1) * dim_stride + filter_size - ofm_size); - total_padding = total_padding > 0 ? total_padding : 0; + int total_padding = + ((ifm_size - 1) * dim_stride + filter_size - ofm_size); pad_before = total_padding / 2; pad_after = total_padding - pad_before; - computed_paddings.push_back(pad_before); - computed_paddings.push_back(pad_after); + computed_paddings.push_back(-pad_before); + computed_paddings.push_back(-pad_after); } explicit_padding = rewriter.getDenseI64ArrayAttr(computed_paddings); From ba906f2822e20ea7993cdfa968da872c9b989994 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Fri, 11 Apr 2025 08:57:33 -0700 Subject: [PATCH 0569/1324] [XLA:GPU] Pass perf profiles by reference instead of value. 
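Passing the profile list by const reference avoids copying the (potentially
large) proto on every call. A minimal sketch of the pattern, with all
surrounding detail elided (the function name here is hypothetical):

    // By value: the caller's HloInstructionProfileList is copied at the call.
    void Consume(HloInstructionProfileList profiles);
    // By const reference: the callee reads the caller's object; no copy.
    void Consume(const HloInstructionProfileList& profiles);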
PiperOrigin-RevId: 746479385
---
 .../xla/xla/service/gpu/model/collective_interpolator.cc        | 2 +-
 third_party/xla/xla/service/gpu/model/collective_interpolator.h | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/model/collective_interpolator.cc b/third_party/xla/xla/service/gpu/model/collective_interpolator.cc
index 1ac08e7cbb6574..cdf1f5a1922e5a 100644
--- a/third_party/xla/xla/service/gpu/model/collective_interpolator.cc
+++ b/third_party/xla/xla/service/gpu/model/collective_interpolator.cc
@@ -302,7 +302,7 @@ CollectiveInterpolator::Create(const se::DeviceDescription& device_info) {
 }

 /*static*/ absl::StatusOr>
-CollectiveInterpolator::Create(HloInstructionProfileList profiles,
+CollectiveInterpolator::Create(const HloInstructionProfileList& profiles,
                                const se::DeviceDescription& device_info) {
   auto interpolators = std::make_unique>>>();
diff --git a/third_party/xla/xla/service/gpu/model/collective_interpolator.h b/third_party/xla/xla/service/gpu/model/collective_interpolator.h
index 4ef0789c1ae7f5..2a1c800ff83f57 100644
--- a/third_party/xla/xla/service/gpu/model/collective_interpolator.h
+++ b/third_party/xla/xla/service/gpu/model/collective_interpolator.h
@@ -55,7 +55,7 @@ class CollectiveInterpolator {
       InterpolatorKey, std::unique_ptr>>>;

   static absl::StatusOr> Create(
-      HloInstructionProfileList profiles,
+      const HloInstructionProfileList& profiles,
       const se::DeviceDescription& device_info);

   static absl::StatusOr> Create(

From a578ae99956a4c005657673e13a94fc90478db05 Mon Sep 17 00:00:00 2001
From: Greg Olechwierowicz
Date: Fri, 11 Apr 2025 09:00:25 -0700
Subject: [PATCH 0570/1324] [XLA:GPU] Move mask variables within SoL class.

So that we must explicitly reference them with `SolGPUCostModel::`.

PiperOrigin-RevId: 746480207
---
 third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.h | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.h b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.h
index 6a84c801d65884..7b4c73519d7daa 100644
--- a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.h
+++ b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.h
@@ -24,11 +24,14 @@ limitations under the License.
 namespace xla {
 namespace gpu {

-inline constexpr absl::string_view kSplitMaskWorldLevel = "0x0";
+// Speed-of-Light (SoL) analytical cost model for NCCL collectives.
 class SolGPUCostModel {
-  // Speed-of-Light (SoL) analytical cost model for NCCL collectives.
 public:
+  static constexpr absl::string_view kSplitMaskWorldLevel = "0x0";
+
+  static constexpr absl::string_view kSplitMaskNonRailAligned = "0x7";
+
   // Tunable system configuration, see
   // xla_gpu_analytical_latency_estimator_options
   struct Config {

From 624a719cc325d1fed2ab81061216bdd8fb257bde Mon Sep 17 00:00:00 2001
From: Goran Flegar
Date: Fri, 11 Apr 2025 09:02:06 -0700
Subject: [PATCH 0571/1324] Compute minimum contracting size based on input
 type instead of hardcoding it

PiperOrigin-RevId: 746480765
---
 .../gpu/autotuning/dot_search_space.cc        | 31 +++++++++++++++----
 .../service/gpu/autotuning/dot_search_space.h |  7 ++++-
 .../gpu/autotuning/dot_search_space_test.cc   |  6 ++--
 .../gpu/autotuning/gemm_fusion_autotuner.cc   |  2 +-
 4 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
index 3520b814424baf..0b8ab610b115a9 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
@@ -91,6 +91,8 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace(
                          (contracting_size_ * batch_size_)),
      rhs_parallel_size_(ShapeUtil::ElementsIn(dot->operand(1)->shape()) /
                         (contracting_size_ * batch_size_)),
+      operand_bitwidth_(  // The bitwidth of both operands is the same.
+          primitive_util::BitWidth(dot->operand(0)->shape().element_type())),
      compute_bitwidth_(primitive_util::BitWidth(dot->shape().element_type())),
      // Figure out some basic limitations on tiling based on the above.
      desired_total_warps_(GetDesiredTotalWarps()),
@@ -99,7 +101,7 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace(
      // of hardcoding.
      min_out_tile_{16, 16},
      min_warps_per_cta_(4),
-      min_contracting_tile_size_(16),
+      min_contracting_tile_size_(GetMinContractingTileSize()),
      max_contracting_split_(GetMaxContractingSplit(max_out_tile_)) {
   // Make sure that the range of output tile sizes is not empty
   // (min_output_tile_ is a hard limit, while max_output_tile_ is a soft one).
@@ -150,15 +152,16 @@ std::vector TritonDotFusionSearchSpace::GenerateConfigs(
   return result;
 }

-std::string TritonDotFusionSearchSpace::Serialize() {
+std::string TritonDotFusionSearchSpace::ToString() const {
   return absl::StrFormat(
-      "problem_size_BxMxNxKxE: %dx%dx%dx%dx%d "
+      "problem_size_BxMxNxKxE: %dx%dx%dx%dx(%d->%d) "
       "tile_range_SxMxNxK: [1-%d]x[%d-%d]x[%d-%d]x[%d-?] "
       "desired_total_warps: %d warps_per_cta: [%d-?]",
       batch_size_, lhs_parallel_size_, rhs_parallel_size_, contracting_size_,
-      compute_bitwidth_, max_contracting_split_, min_out_tile_.lhs_dim,
-      max_out_tile_.lhs_dim, min_out_tile_.rhs_dim, max_out_tile_.rhs_dim,
-      min_contracting_tile_size_, desired_total_warps_, min_warps_per_cta_);
+      operand_bitwidth_, compute_bitwidth_, max_contracting_split_,
+      min_out_tile_.lhs_dim, max_out_tile_.lhs_dim, min_out_tile_.rhs_dim,
+      max_out_tile_.rhs_dim, min_contracting_tile_size_, desired_total_warps_,
+      min_warps_per_cta_);
 }

 int TritonDotFusionSearchSpace::GetDesiredTotalWarps() const {
@@ -228,6 +231,22 @@ int TritonDotFusionSearchSpace::GetMaxWarpsPerCta(OutputTile tile) const {
                   std::min(max_warps, lhs_warps * rhs_warps));
 }

+int TritonDotFusionSearchSpace::GetMinContractingTileSize() const {
+  // The number of bits that both MMA and WGMMA instructions expect to have in
+  // the contracting dimension. See
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shape
+  constexpr int kMmaContractingBitwidth = 128;
+  // TODO: b/395572776 - Triton currently requires at least 16 elements, but we
+  // should be able to relax this and remove this limit here.
+  constexpr int kTritonLowerLimit = 16;
+  const int min_contracting_tile_size =
+      std::max(kMmaContractingBitwidth / operand_bitwidth_, kTritonLowerLimit);
+  VLOG(5) << "Computing min_contracting_tile_size: Based on bitwidth of "
+          << operand_bitwidth_
+          << ", min_contracting_tile_size = " << min_contracting_tile_size;
+  return min_contracting_tile_size;
+}
+
 int TritonDotFusionSearchSpace::GetMaxContractingSplit(
     OutputTile output_tile) const {
   const int64_t desired_num_ctas = desired_total_warps_ / min_warps_per_cta_;
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
index 7cc49026abd2ab..569198e3bdae78 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
@@ -47,7 +47,7 @@ class TritonDotFusionSearchSpace {
       std::optional force_contracting_split = std::nullopt);

   // Serializes the search space to a human-readable string.
-  std::string Serialize();
+  std::string ToString() const;

 private:
   // Groups together the tiling of the dot's output dimensions: the parallel
@@ -111,6 +111,10 @@ class TritonDotFusionSearchSpace {
   // instruction shape.
   int GetMaxWarpsPerCta(OutputTile output_tile) const;

+  // Computes the minimum reasonable tile size for the contracting dimension
+  // given the element types of the operands.
+  int GetMinContractingTileSize() const;
+
   // Computes the maximum sensible split in the contracting dimension
   // (split_k) to sufficiently occupy all available cores when using the given
   // output tile.
@@ -160,6 +164,7 @@ class TritonDotFusionSearchSpace {
   int64_t batch_size_;
   int64_t lhs_parallel_size_;
   int64_t rhs_parallel_size_;
+  int operand_bitwidth_;
   int compute_bitwidth_;
   int desired_total_warps_;
   OutputTile max_out_tile_;
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
index 7870a56112f3e6..75d109534d85a2 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
@@ -123,8 +123,8 @@ TEST_F(DotSearchSpaceTest, SerializesSearchSpace) {
                           /*contracting_dim=*/1024));
   TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());

-  EXPECT_EQ(search_space.Serialize(),
-            "problem_size_BxMxNxKxE: 1x1024x1024x1024x16 "
+  EXPECT_EQ(search_space.ToString(),
+            "problem_size_BxMxNxKxE: 1x1024x1024x1024x(16->16) "
             "tile_range_SxMxNxK: [1-64]x[16-256]x[16-512]x[16-?] "
            "desired_total_warps: 2640 warps_per_cta: [4-?]");
 }

@@ -166,7 +166,7 @@ TEST_F(DotSearchSpaceTest, LimitsContractingSplitForSmallerContractingSize) {
   TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());

   EXPECT_THAT(search_space.GenerateConfigs(),
-              AllOf(Not(IsEmpty()), Each(SplitKIs(Le(2)))));
+              AllOf(Not(IsEmpty()), Each(SplitKIs(Le(4)))));
 }

 TEST_F(DotSearchSpaceTest, FindsGoodDataReuseOutputTiles) {
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
index c05f71517ac5e0..ea485aa0318d80 100644
--- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -897,7 +897,7 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
   TritonDotFusionSearchSpace search_space(config_.GetDeviceDescription(), &dot);
   VLOG(1) << "Generating configs from search space: "
-          << search_space.Serialize();
+          << search_space.ToString();
   // We don't need to consider small_dot here. The new search space will
   // already generate a unique config for small problems.
   return search_space.GenerateConfigs(

From 99e7ad988852f2a95f47b80f93032a5ab6a6e595 Mon Sep 17 00:00:00 2001
From: Shraiysh
Date: Fri, 11 Apr 2025 09:40:36 -0700
Subject: [PATCH 0572/1324] PR #24939: Dumping non-default debug options in a
 new file under xla_dump_to

Imported from GitHub PR https://github.com/openxla/xla/pull/24939

With this change, the non-default debug options are now dumped in the directory given by the `--xla_dump_to` flag in `XLA_FLAGS`. When dumping the non-default values, the defaults are taken from `DefaultDebugOptionsIgnoringFlags()`. The dump is in a protobuf-parsable text format.
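For illustration only, a dumped file might contain entries like the
following (the values here are hypothetical; the fields shown are the ones
exercised in the test added below):

    xla_dump_to: "/tmp/xla_dump"
    xla_gpu_enable_nccl_user_buffers: true
    xla_gpu_enable_command_buffer: CUBLAS
    xla_gpu_enable_command_buffer: FUSION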
The file is emitted using the same format as other dumps with a `debug_options` suffix: for example, `module_0001.test_module.debug_options`

`original debug options = DefaultDebugOptionsIgnoringFlags() + debug_options overwritten from file`

Copybara import of the project:

--
5c50eea207f6ac3846bcc8b37ed03e1cb71a6a99 by Shraiysh Vaishay :

Dumping non-default debug options in a new file under xla_dump_to

With this change, the non-default debug options are now dumped in the directory given by the `--xla_dump_to` flag in `XLA_FLAGS`. When dumping the non-default values, the defaults are taken from `DefaultDebugOptionsIgnoringFlags()`. The dump is in a protobuf-parsable text format.

The file is emitted using the same format as other dumps with a `debug_options` suffix: for example, `module_0001.test_module.debug_options`

`original debug options = DefaultDebugOptionsIgnoringFlags() + debug_options overwritten from file`

Merging this change closes #24939

PiperOrigin-RevId: 746492585
---
 third_party/xla/xla/service/BUILD        |   2 +
 third_party/xla/xla/service/dump.cc      | 162 +++++++++++++++++++++++
 third_party/xla/xla/service/dump.h       |  10 ++
 third_party/xla/xla/service/dump_test.cc | 128 ++++++++++++++++++
 4 files changed, 302 insertions(+)

diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index 64009010fb7389..187376ecdfbd0e 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -529,7 +529,9 @@ xla_cc_test(
     deps = [
         ":dump",
         ":hlo_module_config",
+        "//xla:debug_options_flags",
         "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
         "//xla/hlo/parser:hlo_parser",
         "//xla/runtime/large_hlo_snapshot_serialization:serialization",
         "//xla/tests:xla_internal_test_main",
diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc
index fda459a50b0de7..3b238dfa4f4b77 100644
--- a/third_party/xla/xla/service/dump.cc
+++ b/third_party/xla/xla/service/dump.cc
@@ -15,6 +15,7 @@ limitations under the License.

 #include "xla/service/dump.h"

+#include
 #include
 #include
 #include
@@ -24,6 +25,7 @@ limitations under the License.
 #include
 #include
 #include
+
 #include "absl/algorithm/container.h"
 #include "absl/base/const_init.h"
 #include "absl/container/flat_hash_set.h"
@@ -753,12 +755,172 @@ std::vector DumpHloModuleIfEnabled(const HloModule& module,
   return {};
 }

+std::string GetRepeatedValueAsString(
+    const tsl::protobuf::Reflection* reflection,
+    const DebugOptions& debug_options,
+    const tsl::protobuf::FieldDescriptor* field, int index) {
+  switch (field->type()) {
+    case tsl::protobuf::FieldDescriptor::TYPE_INT32:
+      return std::to_string(
+          reflection->GetRepeatedInt32(debug_options, field, index));
+    case tsl::protobuf::FieldDescriptor::TYPE_INT64:
+      return std::to_string(
+          reflection->GetRepeatedInt64(debug_options, field, index));
+    case tsl::protobuf::FieldDescriptor::TYPE_UINT32:
+      return std::to_string(
+          reflection->GetRepeatedUInt32(debug_options, field, index));
+    case tsl::protobuf::FieldDescriptor::TYPE_UINT64:
+      return std::to_string(
+          reflection->GetRepeatedUInt64(debug_options, field, index));
+    case tsl::protobuf::FieldDescriptor::TYPE_DOUBLE:
+      return std::to_string(
+          reflection->GetRepeatedDouble(debug_options, field, index));
+    case tsl::protobuf::FieldDescriptor::TYPE_FLOAT:
+      return std::to_string(
+          reflection->GetRepeatedFloat(debug_options, field, index));
+    case tsl::protobuf::FieldDescriptor::TYPE_BOOL:
+      return reflection->GetRepeatedBool(debug_options, field, index) ?
"true" + : "false"; + case tsl::protobuf::FieldDescriptor::TYPE_ENUM: + return std::string( + reflection->GetRepeatedEnum(debug_options, field, index)->name()); + case tsl::protobuf::FieldDescriptor::TYPE_STRING: + return reflection->GetRepeatedString(debug_options, field, index); + case tsl::protobuf::FieldDescriptor::TYPE_MESSAGE: { + tsl::protobuf::TextFormat::Printer tsl_printer; + tsl_printer.SetInitialIndentLevel(1); + std::string result; + tsl_printer.PrintToString( + reflection->GetRepeatedMessage(debug_options, field, index), &result); + return "{\n" + result + "}"; + } + default: + return "Unsupported field type"; + } +} + +std::string GetValueAsString(const tsl::protobuf::Reflection* reflection, + const DebugOptions& debug_options, + const tsl::protobuf::FieldDescriptor* field) { + // Based on the field type, get the value and convert it to a string + switch (field->type()) { + case tsl::protobuf::FieldDescriptor::TYPE_INT32: + return std::to_string(reflection->GetInt32(debug_options, field)); + case tsl::protobuf::FieldDescriptor::TYPE_INT64: + return std::to_string(reflection->GetInt64(debug_options, field)); + case tsl::protobuf::FieldDescriptor::TYPE_UINT32: + return std::to_string(reflection->GetUInt32(debug_options, field)); + case tsl::protobuf::FieldDescriptor::TYPE_UINT64: + return std::to_string(reflection->GetUInt64(debug_options, field)); + case tsl::protobuf::FieldDescriptor::TYPE_DOUBLE: + return std::to_string(reflection->GetDouble(debug_options, field)); + case tsl::protobuf::FieldDescriptor::TYPE_FLOAT: + return std::to_string(reflection->GetFloat(debug_options, field)); + case tsl::protobuf::FieldDescriptor::TYPE_BOOL: + return reflection->GetBool(debug_options, field) ? "true" : "false"; + case tsl::protobuf::FieldDescriptor::TYPE_ENUM: + return std::string(reflection->GetEnum(debug_options, field)->name()); + case tsl::protobuf::FieldDescriptor::TYPE_STRING: + return "\"" + reflection->GetString(debug_options, field) + "\""; + case tsl::protobuf::FieldDescriptor::TYPE_MESSAGE: { + tsl::protobuf::TextFormat::Printer tsl_printer; + tsl_printer.SetSingleLineMode(false); + std::string result; + tsl_printer.PrintToString(reflection->GetMessage(debug_options, field), + &result); + return "{\n" + result + "}"; + } + default: + return "Unsupported field type"; + } +} + +std::string GetNonDefaultDebugOptions(const DebugOptions& debug_options) { + // Create a default DebugOptions to compare against + DebugOptions default_options = DefaultDebugOptionsIgnoringFlags(); + std::string non_default_options; + + // Use protobuf reflection to compare fields + const tsl::protobuf::Descriptor* descriptor = debug_options.GetDescriptor(); + const tsl::protobuf::Reflection* reflection = debug_options.GetReflection(); + + // Iterate through all fields + for (int i = 0; i < descriptor->field_count(); i++) { + const tsl::protobuf::FieldDescriptor* field = descriptor->field(i); + + if (field->is_repeated()) { + // Handle repeated fields by comparing the values + int repeated_count = reflection->FieldSize(debug_options, field); + int default_count = reflection->FieldSize(default_options, field); + + // Only process if the repeated field has values + if (repeated_count > 0) { + std::vector debug_values(repeated_count); + std::vector default_values(default_count); + + // Collect all values from debug_options + for (int j = 0; j < repeated_count; j++) { + debug_values[j] = + GetRepeatedValueAsString(reflection, debug_options, field, j); + } + + // Collect all values from default_options + 
for (int j = 0; j < default_count; j++) { + default_values[j] = + GetRepeatedValueAsString(reflection, default_options, field, j); + } + + // Sort both vectors for comparison + std::sort(debug_values.begin(), debug_values.end()); + std::sort(default_values.begin(), default_values.end()); + + // Compare the sorted vectors + if (debug_values != default_values) { + // Values differ, append all debug values to output + for (const auto& value : debug_values) { + absl::StrAppend(&non_default_options, field->name(), ": ", value, + "\n"); + } + } + } + continue; + } + + // Check if this field differs from default + if (reflection->HasField(debug_options, field) && + !reflection->HasField(default_options, field)) { + // Field exists in debug_options but not defaults + absl::StrAppend(&non_default_options, field->name(), ": ", + GetValueAsString(reflection, debug_options, field), "\n"); + } else if (reflection->HasField(debug_options, field)) { + // Field exists in both, compare values + if (GetValueAsString(reflection, debug_options, field) != + GetValueAsString(reflection, default_options, field)) { + absl::StrAppend(&non_default_options, field->name(), ": ", + GetValueAsString(reflection, debug_options, field), + "\n"); + } + } + } + + return non_default_options; +} + +void DumpNonDefaultDebugOptions(const HloModule& module, + absl::string_view suffix) { + const DebugOptions& debug_options = module.config().debug_options(); + auto filename = FilenameFor(module, "", suffix); + auto nonDefaultDebugOptions = GetNonDefaultDebugOptions(debug_options); + DumpToFileInDir(debug_options, filename, nonDefaultDebugOptions); +} + std::vector DumpHloModuleIfEnabled( const HloModule& module, const BufferAssignment& buffer_assn, string_view name) { CanonicalDebugOptions opts(module.config().debug_options()); if (opts.should_dump_module(module.name())) { DumpHloModuleImpl(module, &buffer_assn, TimestampFor(module), name, opts); + DumpNonDefaultDebugOptions(module, kNonDefaultDebugOptionsDumpSuffix); } return {}; } diff --git a/third_party/xla/xla/service/dump.h b/third_party/xla/xla/service/dump.h index 75cb9a717fa1f1..36bde59814e8f0 100644 --- a/third_party/xla/xla/service/dump.h +++ b/third_party/xla/xla/service/dump.h @@ -39,6 +39,7 @@ namespace xla { // performed on an HloModule. constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations"; constexpr char kAfterOptimizationsDumpName[] = "after_optimizations"; +constexpr char kNonDefaultDebugOptionsDumpSuffix[] = "debug_options"; class BufferAssignment; class HloSnapshot; @@ -199,6 +200,15 @@ absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message, void DumpHloConfigIfEnabled(const HloModule& module); +// Dumps the non-default debug options to a file in the xla_dump_to directory +// specified by the module's DebugOptions. +void DumpNonDefaultDebugOptions(const HloModule& module, + absl::string_view suffix); + +// Returns the non-default debug options as a string. The default debug options +// are received from DefaultDebugOptionsIgnoringFlags(). +std::string GetNonDefaultDebugOptions(const DebugOptions& debug_options); + } // namespace xla #endif // XLA_SERVICE_DUMP_H_ diff --git a/third_party/xla/xla/service/dump_test.cc b/third_party/xla/xla/service/dump_test.cc index 2a0afaeefb1c61..71599516d18a92 100644 --- a/third_party/xla/xla/service/dump_test.cc +++ b/third_party/xla/xla/service/dump_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "xla/service/dump.h" +#include + #include #include #include @@ -22,6 +24,9 @@ limitations under the License. #include #include #include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/runtime/large_hlo_snapshot_serialization/serialization.h" #include "xla/service/hlo_module_config.h" @@ -288,5 +293,128 @@ TEST(DumpTest, DumpHloUnoptimizedSnapshotProtoBinary) { EXPECT_EQ(hlo_snapshot_loaded.hlo_module().name(), module.name()); } +TEST(DumpTest, GetNonDefaultDebugOptions) { + DebugOptions options; + DebugOptions default_options = DefaultDebugOptionsIgnoringFlags(); + std::string dump_folder = tsl::testing::TmpDir(); + + // String field + options.set_xla_dump_to(dump_folder); + // Int32 field + options.set_xla_gpu_dot_merger_threshold_mb( + default_options.xla_gpu_dot_merger_threshold_mb() + 100); + // Int64 field + options.set_xla_gpu_experimental_collective_cse_distance_threshold( + default_options.xla_gpu_experimental_collective_cse_distance_threshold() + + 100); + // Bool field + options.set_xla_gpu_enable_nccl_user_buffers( + !default_options.xla_gpu_enable_nccl_user_buffers()); + options.set_xla_enable_dumping(true); + // Enum field + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS); + options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); + // Message field + int gpus_per_node; + EXPECT_TRUE(absl::SimpleAtoi( + default_options.xla_gpu_analytical_latency_estimator_options().at( + "gpus_per_node"), + &gpus_per_node)); + int chunk_size_bytes; + EXPECT_TRUE(absl::SimpleAtoi( + default_options.xla_gpu_analytical_latency_estimator_options().at( + "chunk_size_bytes"), + &chunk_size_bytes)); + options.mutable_xla_gpu_analytical_latency_estimator_options()->insert( + {"gpus_per_node", std::to_string(gpus_per_node + 1)}); + options.mutable_xla_gpu_analytical_latency_estimator_options()->insert( + {"chunk_size_bytes", std::to_string(chunk_size_bytes)}); + + auto non_default_options = GetNonDefaultDebugOptions(options); + EXPECT_THAT(non_default_options, + testing::HasSubstr("xla_dump_to: \"" + dump_folder + "\"")); + EXPECT_THAT( + non_default_options, + testing::HasSubstr( + "xla_gpu_dot_merger_threshold_mb: " + + std::to_string(default_options.xla_gpu_dot_merger_threshold_mb() + + 100))); + EXPECT_THAT( + non_default_options, + testing::HasSubstr( + "xla_gpu_experimental_collective_cse_distance_threshold: " + + std::to_string( + default_options + .xla_gpu_experimental_collective_cse_distance_threshold() + + 100))); + EXPECT_THAT(non_default_options, + testing::HasSubstr("xla_gpu_enable_nccl_user_buffers: true")); + EXPECT_THAT(non_default_options, + testing::HasSubstr("xla_gpu_enable_command_buffer: CUBLAS")); + EXPECT_THAT(non_default_options, + testing::HasSubstr("xla_gpu_enable_command_buffer: FUSION")); + EXPECT_THAT( + non_default_options, + testing::HasSubstr("xla_gpu_analytical_latency_estimator_options: {\n" + " key: \"gpus_per_node\"\n" + " value: \"" + + std::to_string(gpus_per_node + 1) + + "\"\n" + "}")); + EXPECT_THAT( + non_default_options, + testing::HasSubstr("xla_gpu_analytical_latency_estimator_options: {\n" + " key: \"chunk_size_bytes\"\n" + " value: \"" + + std::to_string(chunk_size_bytes) + + "\"\n" + "}")); + tsl::protobuf::TextFormat::Parser parser; + DebugOptions parsed_options = DefaultDebugOptionsIgnoringFlags(); + 
parser.ParseFromString(non_default_options, &parsed_options); + EXPECT_EQ(parsed_options.xla_dump_to(), dump_folder); + EXPECT_EQ(parsed_options.xla_gpu_dot_merger_threshold_mb(), + default_options.xla_gpu_dot_merger_threshold_mb() + 100); + EXPECT_EQ( + parsed_options.xla_gpu_experimental_collective_cse_distance_threshold(), + default_options.xla_gpu_experimental_collective_cse_distance_threshold() + + 100); + EXPECT_EQ(parsed_options.xla_gpu_enable_nccl_user_buffers(), + !default_options.xla_gpu_enable_nccl_user_buffers()); + EXPECT_EQ(parsed_options.xla_gpu_enable_command_buffer_size(), 2); + EXPECT_EQ(parsed_options.xla_gpu_enable_command_buffer(0), + DebugOptions::CUBLAS); + EXPECT_EQ(parsed_options.xla_gpu_enable_command_buffer(1), + DebugOptions::FUSION); + EXPECT_EQ(parsed_options.xla_gpu_analytical_latency_estimator_options().at( + "gpus_per_node"), + std::to_string(gpus_per_node + 1)); + EXPECT_EQ(parsed_options.xla_gpu_analytical_latency_estimator_options().at( + "chunk_size_bytes"), + std::to_string(chunk_size_bytes)); + + HloModuleConfig config; + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnUnverifiedModule(R"( + HloModule test + ENTRY test { + p0 = s32[11] parameter(0) + c = s32[11] constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + ROOT x = s32[11] multiply(p0, c) + } + )", + config)); + DumpNonDefaultDebugOptions(*m, kNonDefaultDebugOptionsDumpSuffix); + std::string real_contents; + TF_ASSERT_OK(tsl::ReadFileToString( + tsl::Env::Default(), + tsl::io::JoinPath(dump_folder, + FilenameFor(*m, "", kNonDefaultDebugOptionsDumpSuffix)), + &real_contents)); + EXPECT_THAT(real_contents, testing::Eq(non_default_options)); +} + } // namespace } // namespace xla From 570721044ab0df002fd6c0db564f11db2b8d826f Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Fri, 11 Apr 2025 09:43:07 -0700 Subject: [PATCH 0573/1324] Move experimental model_utils code location PiperOrigin-RevId: 746493362 --- tensorflow/compiler/mlir/lite/integrations/BUILD | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/integrations/BUILD b/tensorflow/compiler/mlir/lite/integrations/BUILD index 1a54d980c52074..899c936e9929a9 100644 --- a/tensorflow/compiler/mlir/lite/integrations/BUILD +++ b/tensorflow/compiler/mlir/lite/integrations/BUILD @@ -20,8 +20,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/lite/integrations:__subpackages__", - "//tensorflow/lite/experimental/litert/python/google/tools/model_utils:__subpackages__", - "//third_party/odml/litert/litert/python/google/tools/model_utils:__subpackages__", + "//third_party/odml/litert/litert/python/tools/model_utils:__subpackages__", ], licenses = ["notice"], ) From 7282769ab6b408ea43aa0af948f04ec7aeae859f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 10:05:39 -0700 Subject: [PATCH 0574/1324] Remove some unnecessary `const_cast`s from OpenXLA. 
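In its minimal form, the pattern being removed looks like the sketch below
(assuming an accessor that already returns a mutable pointer, as
`se::DeviceMemoryBase::opaque()` does on a non-const object), where the cast
adds nothing:

    se::DeviceMemoryBase buffer = ...;
    void* p = const_cast<void*>(buffer.opaque());  // redundant cast
    void* q = buffer.opaque();                     // equivalent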
PiperOrigin-RevId: 746501512 --- .../cpu/collectives/gloo_communicator.cc | 24 +++++++++---------- .../cpu/collectives/mpi_communicator.cc | 4 ++-- .../xla/xla/backends/cpu/runtime/kernel.cc | 2 +- .../xla/xla/backends/gpu/codegen/custom.cc | 4 ++-- .../xla/xla/backends/interpreter/executor.h | 2 +- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc index e5e19aa3a1cfed..ff24895d48d62f 100644 --- a/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc @@ -65,12 +65,12 @@ static absl::Status SetAllReduceOptions(ReductionKind reduction_kind, se::DeviceMemoryBase output_buffer, size_t num_elements, gloo::AllreduceOptions& options) { - options.setInput( - reinterpret_cast(const_cast(input_buffer.opaque())), - num_elements); - options.setOutput( - reinterpret_cast(const_cast(output_buffer.opaque())), - num_elements); + options.setInput(reinterpret_cast( // REINTERPRET_CAST_OK=existing code. + input_buffer.opaque()), + num_elements); + options.setOutput(reinterpret_cast( // REINTERPRET_CAST_OK=existing code. + output_buffer.opaque()), + num_elements); using ReductionFn = void (*)(void*, const void*, const void*, size_t); @@ -262,10 +262,10 @@ absl::Status GlooCommunicator::AllToAll( context_->size); for (size_t i = 0; i < world_size; ++i) { if (i != my_rank) { - ins[i] = context_->createUnboundBuffer( - const_cast(send_buffers[i].opaque()), chunk_bytes); - outs[i] = context_->createUnboundBuffer( - const_cast(recv_buffers[i].opaque()), chunk_bytes); + ins[i] = context_->createUnboundBuffer(send_buffers[i].opaque(), + chunk_bytes); + outs[i] = context_->createUnboundBuffer(recv_buffers[i].opaque(), + chunk_bytes); } } @@ -276,8 +276,8 @@ absl::Status GlooCommunicator::AllToAll( outs[recv_rank]->recv(recv_rank, slot); } - std::memcpy(const_cast(recv_buffers[my_rank].opaque()), - send_buffers[my_rank].opaque(), chunk_bytes); + std::memcpy(recv_buffers[my_rank].opaque(), send_buffers[my_rank].opaque(), + chunk_bytes); auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout()); for (int i = 0; i < world_size; i++) { diff --git a/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc index 0062593da75407..b863c2eca62bc3 100644 --- a/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc +++ b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc @@ -197,8 +197,8 @@ absl::Status MpiCommunicator::AllToAll( std::vector output_buffers; for (int i = 0; i < size; i++) { - input_buffers.push_back(const_cast(send_buffers[i].opaque())); - output_buffers.push_back(const_cast(recv_buffers[i].opaque())); + input_buffers.push_back(send_buffers[i].opaque()); + output_buffers.push_back(recv_buffers[i].opaque()); } std::memcpy(output_buffers[rank], input_buffers[rank], chunk_bytes); diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel.cc b/third_party/xla/xla/backends/cpu/runtime/kernel.cc index 4464571b2d4efb..ac1a5d7181ec52 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel.cc @@ -54,7 +54,7 @@ static absl::InlinedVector ConvertBuffersToKernelArgs( absl::Span buffers) { absl::InlinedVector args(buffers.size()); for (size_t i = 0; i < buffers.size(); ++i) { - args[i].data = 
const_cast(buffers[i].opaque()); + args[i].data = buffers[i].opaque(); args[i].size = buffers[i].size(); } return args; diff --git a/third_party/xla/xla/backends/gpu/codegen/custom.cc b/third_party/xla/xla/backends/gpu/codegen/custom.cc index 3ad6c564a0bea5..6f22021dd62c6f 100644 --- a/third_party/xla/xla/backends/gpu/codegen/custom.cc +++ b/third_party/xla/xla/backends/gpu/codegen/custom.cc @@ -111,10 +111,10 @@ absl::StatusOr GetOperandSlice( } // Walk through ShapeIndex to find the real starting point. - auto* start = const_cast(&start_instr); + const auto* start = &start_instr; for (auto idx : shape_idx) { CHECK(start->shape().IsTuple()); - start = const_cast(start->operand(idx)); + start = start->operand(idx); } if (const auto* param = DynCast(start)) { diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 996d1e03c2e77f..8c5e8051564309 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -62,7 +62,7 @@ class InterpreterStream : public host::HostStream { absl::Status Memcpy(void *host_dst, const DeviceMemoryBase &gpu_src, uint64_t size) override { - void *src_mem = const_cast(gpu_src.opaque()); + void *src_mem = gpu_src.opaque(); EnqueueTask( [host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); }); return BlockUntilDone(); From 363a086ab94017baf3e3d7ccba458111f6237398 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 10:14:56 -0700 Subject: [PATCH 0575/1324] Reverts 22885935f463c5c5ebb4d347f3b09b9da001747e PiperOrigin-RevId: 746505093 --- tensorflow/python/grappler/BUILD | 82 ------------------- tensorflow/python/grappler/cluster_wrapper.cc | 4 - tensorflow/python/grappler/cost_analyzer.cc | 13 --- tensorflow/python/grappler/cost_analyzer.h | 2 - tensorflow/python/grappler/item_wrapper.cc | 1 - tensorflow/python/grappler/model_analyzer.cc | 7 +- tensorflow/python/grappler/model_analyzer.h | 2 - .../python/grappler/model_analyzer_wrapper.cc | 1 - .../python/grappler/tf_optimizer_wrapper.cc | 2 - 9 files changed, 1 insertion(+), 113 deletions(-) diff --git a/tensorflow/python/grappler/BUILD b/tensorflow/python/grappler/BUILD index 6f5dd068d11ec5..687fd36ce2053e 100644 --- a/tensorflow/python/grappler/BUILD +++ b/tensorflow/python/grappler/BUILD @@ -26,8 +26,6 @@ cc_library( "//tensorflow/core/grappler/costs:cost_estimator", "//tensorflow/core/grappler/costs:measuring_cost_estimator", "//tensorflow/core/grappler/costs:utils", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", ] + tf_protos_grappler(), alwayslink = 1, ) @@ -60,28 +58,12 @@ tf_python_pybind_extension( starlark_only = True, deps = [ ":cost_analyzer_headers", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", "//tensorflow/core:lib_headers_for_pybind", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_headers_lib", - "//tensorflow/core/common_runtime:device_set", "//tensorflow/core/common_runtime/gpu:gpu_id", - "//tensorflow/core/framework:allocator", - "//tensorflow/core/platform:threadpool_options", "//tensorflow/python/lib/core:pybind11_status", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - 
"@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:thread_annotations", - "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", "@pybind11", ], ) @@ -96,7 +78,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/costs:graph_properties", - "@com_google_absl//absl/status", ], ) @@ -113,16 +94,10 @@ tf_python_pybind_extension( ], starlark_only = True, deps = [ - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib", "//tensorflow/core:lib_headers_for_pybind", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/python/lib/core:pybind11_status", - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:platform_port", "@pybind11", ] + if_pywrap(["//tensorflow/python/grappler:model_analyzer_lib"]), ) @@ -153,29 +128,11 @@ tf_python_pybind_extension( "_pywrap_tf_item.pyi", ], deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_headers_lib", - "//tensorflow/core/common_runtime:device_set", "//tensorflow/core/common_runtime/gpu:gpu_id", - "//tensorflow/core/framework:allocator", - "//tensorflow/core/grappler:utils", "//tensorflow/python/lib/core:pybind11_status", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:thread_annotations", - "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", "@pybind11", ] + if_not_windows(["//tensorflow/core/grappler/costs:graph_properties"]), # b/148556093, ) @@ -256,32 +213,13 @@ tf_python_pybind_extension( "_pywrap_tf_cluster.pyi", ], deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", "//tensorflow/core:lib_headers_for_pybind", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_headers_lib", - "//tensorflow/core/common_runtime:device_set", "//tensorflow/core/common_runtime/gpu:gpu_id", - "//tensorflow/core/framework:allocator", - "//tensorflow/core/grappler:utils", - "//tensorflow/core/platform:status", "//tensorflow/python/lib/core:pybind11_status", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:thread_annotations", - "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", "@pybind11", ] + if_pywrap( if_true = [ @@ -354,32 +292,12 @@ tf_python_pybind_extension( # }), # static_deps = tf_python_pybind_static_deps(), deps = [ - "//tensorflow/core:core_cpu", - 
"//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", "//tensorflow/core:lib_headers_for_pybind", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_headers_lib", - "//tensorflow/core/common_runtime:device", - "//tensorflow/core/common_runtime:device_factory", - "//tensorflow/core/common_runtime:device_set", "//tensorflow/core/common_runtime/gpu:gpu_id", - "//tensorflow/core/framework:allocator", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:status", "//tensorflow/python/lib/core:pybind11_status", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:thread_annotations", - "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", "@pybind11", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ] + if_pywrap( diff --git a/tensorflow/python/grappler/cluster_wrapper.cc b/tensorflow/python/grappler/cluster_wrapper.cc index 0df0f3fcc25dc3..dbf97535082413 100644 --- a/tensorflow/python/grappler/cluster_wrapper.cc +++ b/tensorflow/python/grappler/cluster_wrapper.cc @@ -15,9 +15,7 @@ limitations under the License. #include #include -#include #include -#include #include #include #include @@ -26,8 +24,6 @@ limitations under the License. #include #include -#include "absl/log/check.h" -#include "absl/status/status.h" #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 #include "tensorflow/core/framework/kernel_def.pb.h" diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc index 44239ebe140536..90f9b426d3756c 100644 --- a/tensorflow/python/grappler/cost_analyzer.cc +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -15,23 +15,10 @@ limitations under the License. #include "tensorflow/python/grappler/cost_analyzer.h" -#include -#include -#include #include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "tensorflow/core/framework/cost_graph.pb.h" -#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { namespace grappler { diff --git a/tensorflow/python/grappler/cost_analyzer.h b/tensorflow/python/grappler/cost_analyzer.h index b14a89b1f318d9..44e1e45265b9c5 100644 --- a/tensorflow/python/grappler/cost_analyzer.h +++ b/tensorflow/python/grappler/cost_analyzer.h @@ -17,8 +17,6 @@ limitations under the License. #define TENSORFLOW_PYTHON_GRAPPLER_COST_ANALYZER_H_ #include - -#include "absl/status/status.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/python/grappler/item_wrapper.cc b/tensorflow/python/grappler/item_wrapper.cc index 27207ec2c053f4..13d2ee6def5c75 100644 --- a/tensorflow/python/grappler/item_wrapper.cc +++ b/tensorflow/python/grappler/item_wrapper.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include #include #include -#include #include #include "pybind11/pybind11.h" // from @pybind11 diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc index 5fe0e94ff8947b..202eb758a91221 100644 --- a/tensorflow/python/grappler/model_analyzer.cc +++ b/tensorflow/python/grappler/model_analyzer.cc @@ -15,15 +15,10 @@ limitations under the License. #include "tensorflow/python/grappler/model_analyzer.h" -#include -#include - -#include "absl/status/status.h" -#include "tensorflow/core/framework/node_def.pb.h" +#include #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/costs/graph_properties.h" -#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" #include "tensorflow/core/grappler/grappler_item.h" namespace tensorflow { diff --git a/tensorflow/python/grappler/model_analyzer.h b/tensorflow/python/grappler/model_analyzer.h index c76d850a5b119a..d66ad8915c99b5 100644 --- a/tensorflow/python/grappler/model_analyzer.h +++ b/tensorflow/python/grappler/model_analyzer.h @@ -17,8 +17,6 @@ limitations under the License. #define TENSORFLOW_PYTHON_GRAPPLER_MODEL_ANALYZER_H_ #include - -#include "absl/status/status.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/python/grappler/model_analyzer_wrapper.cc b/tensorflow/python/grappler/model_analyzer_wrapper.cc index c5db3fffa3c123..86ac40701e303a 100644 --- a/tensorflow/python/grappler/model_analyzer_wrapper.cc +++ b/tensorflow/python/grappler/model_analyzer_wrapper.cc @@ -15,7 +15,6 @@ limitations under the License. #include #include -#include #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/core/grappler/grappler_item_builder.h" diff --git a/tensorflow/python/grappler/tf_optimizer_wrapper.cc b/tensorflow/python/grappler/tf_optimizer_wrapper.cc index 4e88995858e6a7..08b3a0895071a8 100644 --- a/tensorflow/python/grappler/tf_optimizer_wrapper.cc +++ b/tensorflow/python/grappler/tf_optimizer_wrapper.cc @@ -17,8 +17,6 @@ limitations under the License. #include #include #include -#include -#include #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf From c211d1f917292a016d492e32242ecf99cb4315e3 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Fri, 11 Apr 2025 10:16:51 -0700 Subject: [PATCH 0576/1324] SliceOp, DynamicSliceOp : Direct StableHLO -> HLO translation. 
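
With this change the exporter no longer needs to round-trip slice and
dynamic-slice through MHLO. Condensed from the translate.cc change below
(a sketch of the relevant lines, not the full function):

mlir::PassManager pm(module->getContext());
mlir::mhlo::StablehloLegalizeToHloPassOptions opts;
// Leave XLA-supported StableHLO ops (now including slice / dynamic_slice)
// untouched so the HLO exporter can translate them directly.
opts.convert_xla_supported_stablehlo_ = false;
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass(opts));
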
PiperOrigin-RevId: 746505822 --- .../mhlo_to_hlo/gen_hlo_op_writer.td | 4 +-- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 17 +++++++++- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.h | 6 ++++ .../hlo/translate/mhlo_to_hlo/translate.cc | 4 ++- .../xla/hlo/translate/mhlo_to_hlo/translate.h | 18 ++++++++++- .../translate/stablehlo_to_hlo/translate.cc | 7 ++-- third_party/xla/xla/hlo/translate/tests/BUILD | 1 + .../xla/hlo/translate/tests/stablehlo.mlir | 32 +++++++++++++++++-- .../stablehlo_legalize_to_hlo_pass.cc | 7 +++- 9 files changed, 86 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td index f0e61f04070e5f..55596d1f5d836a 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td @@ -61,7 +61,7 @@ defvar HloConversionAllowedOps = [ // StableHLO_DynamicIotaOp, // StableHLO_DynamicPadOp, // StableHLO_DynamicReshapeOp, - // StableHLO_DynamicSliceOp, + StableHLO_DynamicSliceOp, // StableHLO_DynamicUpdateSliceOp, // StableHLO_EinsumOp, // StableHLO_Expm1Op, @@ -119,7 +119,7 @@ defvar HloConversionAllowedOps = [ // StableHLO_ShiftRightLogicalOp, // StableHLO_SignOp, // StableHLO_SineOp, - // StableHLO_SliceOp, + StableHLO_SliceOp, // StableHLO_SortOp, // StableHLO_SqrtOp, // StableHLO_SubtractOp, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 7ec6c4d8989b79..046aaa22c86b55 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -400,18 +400,30 @@ static mlir::FailureOr ExtractXlaShape(mlir::Operation* op) { return ConvertDenseIntAttr(attribute); \ } +#define I64_ARRAY_ATTR_TO_VECTOR(attribute) \ + static std::vector Convert_##attribute( \ + std::optional> attribute) { \ + if (!attribute.has_value()) return {}; \ + return {attribute->begin(), attribute->end()}; \ + } + I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes); I64_ELEMENTS_ATTR_TO_VECTOR(permutation); I64_ELEMENTS_ATTR_TO_VECTOR(start_indices); +I64_ARRAY_ATTR_TO_VECTOR(start_indices); I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices); +I64_ARRAY_ATTR_TO_VECTOR(limit_indices); I64_ELEMENTS_ATTR_TO_VECTOR(strides); +I64_ARRAY_ATTR_TO_VECTOR(strides); I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes); +I64_ARRAY_ATTR_TO_VECTOR(slice_sizes); I64_ELEMENTS_ATTR_TO_VECTOR(fft_length); I64_ELEMENTS_ATTR_TO_VECTOR(dimensions); I64_ELEMENTS_ATTR_TO_VECTOR(window_strides); I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation); I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation); +#undef I64_ARRAY_ATTR_TO_VECTOR #undef I64_ELEMENTS_ATTR_TO_VECTOR #define BOOL_ELEMENTS_ATTR_TO_VECTOR(attribute) \ @@ -4154,7 +4166,10 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, // temporarily support StableHLO to MHLO lowering here as well to ensure // a smooth migration. 
mlir::PassManager pm(module->getContext()); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + mhlo::StablehloLegalizeToHloPassOptions shlo_pass_opts; + shlo_pass_opts.convert_xla_supported_stablehlo_ = + !options.direct_stablehlo_to_hlo; + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass(shlo_pass_opts)); if (failed(pm.run(module))) { return tsl::errors::Internal("Unable to convert StableHLO to MHLO"); } diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h index 6a90e32bbb24da..f61223f5d85a78 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h @@ -61,6 +61,12 @@ struct MlirToHloConversionOptions { // Multiple return values are always converted to a tuple and returned as a // single value. bool return_tuple = true; + + // If true, StableHLO ops that are supported by XLA will be converted directly + // to HLO. Otherwise, they will be converted to MHLO and then lowered to HLO. + // This is a temporary flag to support the ongoing direct stableHLO to HLO + // translation. + bool direct_stablehlo_to_hlo = false; }; // Prefer `ConvertMlirHloToHloModule` over this method when possible, as it diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc index 8c0016ff24d264..63efd0c82ec056 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc @@ -149,7 +149,8 @@ absl::Status ConvertMlirHloToHloViaBuilder( mlir::LogicalResult MlirHloToHloTextTranslateFunction( mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts) { + bool print_sugar, bool via_builder, bool with_layouts, + bool direct_stablehlo_to_hlo) { if (!module) return mlir::failure(); HloProto hloProto; @@ -157,6 +158,7 @@ mlir::LogicalResult MlirHloToHloTextTranslateFunction( options.propagate_layouts = with_layouts; options.use_tuple_args = emit_use_tuple_arg; options.return_tuple = emit_return_tuple; + options.direct_stablehlo_to_hlo = direct_stablehlo_to_hlo; absl::StatusOr> statusOrHloModule; if (via_builder) { auto status = ConvertMlirHloToHloViaBuilder(module, &hloProto, options); diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h index 064db33984b864..a47c880f8fb9aa 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h @@ -27,15 +27,31 @@ limitations under the License. namespace xla { +// Translates the given MLIR module containing MHLO to a HLO module. +// The resulting HLO is written to the output stream. +// `emit_return_tuple` controls whether the return value should be a tuple. +// `emit_use_tuple_arg` controls whether the arguments should be a tuple. mlir::LogicalResult MlirHloToHloTranslateFunction(mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, bool emit_use_tuple_arg); +// Translates the given MLIR module containing MHLO to a HLO text program. +// The resulting HLO text is written to the output stream. +// `emit_return_tuple` controls whether the return value should be a tuple. 
+// `emit_use_tuple_arg` controls whether the arguments should be a tuple. +// `print_layouts` controls whether to print layouts. +// `print_large_constants` controls whether to print large constants. +// `print_sugar` controls whether to print sugar. +// `via_builder` controls whether to use the HLO builder. +// `with_layouts` controls whether to print layouts. +// `direct_stablehlo_to_hlo` controls whether to translate StableHLO directly to +// HLO. mlir::LogicalResult MlirHloToHloTextTranslateFunction( mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts); + bool print_sugar, bool via_builder, bool with_layouts, + bool direct_stablehlo_to_hlo = false); // Translate the MHLO program in in-memory file 'buffer' to a HLO program // written in a file represented with handle 'output_stream'; diff --git a/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc index ee50467832e872..68b9757fee7a2c 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc @@ -58,7 +58,9 @@ mlir::LogicalResult StablehloToHloTextTranslateFunction( if (!module) return mlir::failure(); mlir::PassManager pm(module->getContext()); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + mlir::mhlo::StablehloLegalizeToHloPassOptions shlo_pass_opts; + shlo_pass_opts.convert_xla_supported_stablehlo_ = false; + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass(shlo_pass_opts)); if (failed(pm.run(module))) { module->dump(); return mlir::failure(); @@ -66,7 +68,8 @@ mlir::LogicalResult StablehloToHloTextTranslateFunction( return xla::MlirHloToHloTextTranslateFunction( module, output, emit_return_tuple, emit_use_tuple_arg, print_layouts, - print_large_constants, print_sugar, via_builder, with_layouts); + print_large_constants, print_sugar, via_builder, with_layouts, + /*direct_stablehlo_to_hlo=*/true); } mlir::LogicalResult StablehloToHloTextMain( diff --git a/third_party/xla/xla/hlo/translate/tests/BUILD b/third_party/xla/xla/hlo/translate/tests/BUILD index d003281e4baeb8..17f631deda574c 100644 --- a/third_party/xla/xla/hlo/translate/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/tests/BUILD @@ -34,6 +34,7 @@ lit_test_suite( ], tools = [ "//xla/hlo/tools:hlo-translate", + "//xla/hlo/translate:xla-translate", "//xla/mlir_hlo:mlir-hlo-opt", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", diff --git a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir index 7e782eb10f3c97..18947908b99a84 100644 --- a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir +++ b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir @@ -1,5 +1,5 @@ -// RUN: hlo-translate -mlir-to-hlo %s | FileCheck %s -// RUN: mlir-hlo-opt --stablehlo-legalize-to-hlo=convert-xla-supported-stablehlo=false %s | FileCheck %s --check-prefix CHECK-DIRECT +// RUN: xla-translate --stablehlo-to-hlo-text -split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt --stablehlo-legalize-to-hlo=convert-xla-supported-stablehlo=false -split-input-file %s | FileCheck %s --check-prefix CHECK-DIRECT // Tests for all stablehlo ops to validate stablehlo -> hlo conversion. 
@@ -13,3 +13,31 @@ func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32> func.return %0 : tensor<4xf32> } // CHECK-DIRECT: stablehlo.add + +// ----- + +// CHECK-LABEL: HloModule main, entry_computation_layout={(s32[3,4]{1,0})->s32[1,2]{1,0}} + +// CHECK: ENTRY %[[$main_3:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = s32[3,4] parameter(0) +// CHECK-NEXT: ROOT %[[slice_2:[^ ]+]] = s32[1,2] slice(%[[Arg_0_1]]), slice={[1:2:1], [0:4:2]}, +func.func @main(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { + %0 = stablehlo.slice %arg0 [1:2, 0:4:2] : (tensor<3x4xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> +} +// CHECK-DIRECT: stablehlo.slice + +// ----- + +// CHECK-LABEL: HloModule main, entry_computation_layout={(s32[3,4]{1,0}, s64[], s64[])->s32[1,4]{1,0}} + +// CHECK: ENTRY %[[$main_5:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = s32[3,4] parameter(0) +// CHECK-NEXT: %[[Arg_1_2:[^ ]+]] = s64[] parameter(1) +// CHECK-NEXT: %[[Arg_2_3:[^ ]+]] = s64[] parameter(2) +// CHECK-NEXT: ROOT %[[dynamic_slice_4:[^ ]+]] = s32[1,4] dynamic-slice(%[[Arg_0_1]], %[[Arg_1_2]], %[[Arg_2_3]]), dynamic_slice_sizes={1,4}, +func.func @main(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + %0 = stablehlo.dynamic_slice %arg0, %arg1, %arg2, sizes = [1, 4] : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} +// CHECK-DIRECT: stablehlo.dynamic_slice diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc index 9257d88984b097..9e658a25043dc3 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc @@ -40,7 +40,12 @@ namespace mhlo { namespace { void legalDirectStablehloToHloConversionOps(ConversionTarget& target) { - target.addLegalOp(); + target.addLegalOp< + // go/keep-sorted start + stablehlo::AddOp, stablehlo::ConstantOp, stablehlo::DynamicSliceOp, + stablehlo::SliceOp + // go/keep-sorted end + >(); } struct StablehloLegalizeToHloPass From 50d476b24ff7befbd272b2872084db5f8832c84d Mon Sep 17 00:00:00 2001 From: Robert David Date: Fri, 11 Apr 2025 10:29:44 -0700 Subject: [PATCH 0577/1324] The second parameter to `TF_LITE_KERNEL_LOG` is expected to be a `printf` format string, known at compile time. Run IWYU to fix includes. 
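
A message produced at runtime may itself contain '%', which printf-style
logging would then interpret as a conversion specifier (undefined behavior).
The fix, as applied in decode_jpeg.cc below, is to route the message through
an explicit "%s" format:

// Before (unsafe if the message contains '%'):
//   TF_LITE_KERNEL_LOG(context, decode_status.error_message.c_str());
// After:
TF_LITE_KERNEL_LOG(context, "%s", decode_status.error_message.c_str());
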
PiperOrigin-RevId: 746510696 --- .../lite/experimental/acceleration/mini_benchmark/BUILD | 2 +- .../acceleration/mini_benchmark/decode_jpeg.cc | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD index 4360e6a615f64e..6d14959b1a9b74 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD @@ -294,8 +294,8 @@ cc_library( hdrs = ["decode_jpeg_register.h"], copts = tflite_copts(), deps = [ + ":decode_jpeg_status", ":libjpeg_decoder", - "//tensorflow/lite:string", "//tensorflow/lite:string_util", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg.cc index b1e2d619904ca5..ea6e7ff5ad5574 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg.cc @@ -14,17 +14,17 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include -#include #include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_register.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_status.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/string_type.h" #include "tensorflow/lite/string_util.h" namespace tflite { @@ -124,7 +124,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { std::unique_ptr decoder = LibjpegDecoder::Create(decoder_status); if (decoder_status.code != kTfLiteOk) { - TF_LITE_KERNEL_LOG(context, decoder_status.error_message.c_str()); + TF_LITE_KERNEL_LOG(context, "%s", decoder_status.error_message.c_str()); return kTfLiteError; } @@ -166,7 +166,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output_array_offset += kOutputImageSize; if (decode_status.code != kTfLiteOk) { - TF_LITE_KERNEL_LOG(context, decode_status.error_message.c_str()); + TF_LITE_KERNEL_LOG(context, "%s", decode_status.error_message.c_str()); return kTfLiteError; } } From a2f96fa014a00fe547cab25bf74fd15074b41300 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Fri, 11 Apr 2025 10:31:09 -0700 Subject: [PATCH 0578/1324] Add ExpandTestType utility from ClientLibraryTestBase to utility header. 
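
ExpandTestType fans a list of parameterized-test specs out across several
primitive types, stamping each copy's test_type. A hypothetical usage sketch
(R1Spec and base_specs are illustrative names, not part of this change):

struct R1Spec {
  int64_t num_elements;
  PrimitiveType test_type;
};
const std::vector<R1Spec> base_specs = {{16, F32}, {1024, F32}};
const std::vector<R1Spec> all_specs =
    ExpandTestType<R1Spec>({F32, BF16, F16}, base_specs);
// all_specs now holds six entries: each base spec once per requested type.
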
PiperOrigin-RevId: 746511168 --- third_party/xla/xla/tests/BUILD | 1 + .../xla/xla/tests/client_library_test_base.h | 14 -------------- .../xla/tests/client_library_test_runner_utils.h | 14 ++++++++++++++ 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index a2d68605f00c95..bf2069617ad78a 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -407,6 +407,7 @@ cc_library( "//xla/hlo/builder:xla_computation", "//xla/tsl/platform:status", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/tests/client_library_test_base.h b/third_party/xla/xla/tests/client_library_test_base.h index 03d526b85a6105..961f00abf3c0cb 100644 --- a/third_party/xla/xla/tests/client_library_test_base.h +++ b/third_party/xla/xla/tests/client_library_test_base.h @@ -58,20 +58,6 @@ static_assert(false, namespace xla { -template -std::vector ExpandTestType( - absl::Span test_type_params, - absl::Span specs) { - std::vector expanded; - for (const PrimitiveType test_type : test_type_params) { - for (const auto& spec : specs) { - expanded.push_back(spec); - expanded.back().test_type = test_type; - } - } - return expanded; -} - // A client library test establishes an in-process XLA client connection. class ClientLibraryTestBase : public ::testing::Test { protected: diff --git a/third_party/xla/xla/tests/client_library_test_runner_utils.h b/third_party/xla/xla/tests/client_library_test_runner_utils.h index 0c0d04e0d3bab2..ac4f4b00600fe5 100644 --- a/third_party/xla/xla/tests/client_library_test_runner_utils.h +++ b/third_party/xla/xla/tests/client_library_test_runner_utils.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "xla/array2d.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/tests/test_utils.h" @@ -69,6 +70,19 @@ std::unique_ptr> CreatePatternedMatrix(int rows, int cols, std::unique_ptr> CreatePatternedMatrixWithZeroPadding( int rows, int cols, int rows_padded, int cols_padded); +template +std::vector ExpandTestType( + absl::Span test_type_params, + absl::Span specs) { + std::vector expanded; + for (const PrimitiveType test_type : test_type_params) { + for (const auto& spec : specs) { + expanded.push_back(spec); + expanded.back().test_type = test_type; + } + } + return expanded; +} } // namespace xla #endif // XLA_TESTS_CLIENT_LIBRARY_TEST_RUNNER_UTILS_H_ From f76cd031cf78ad905b72041552fc69fdf5d6f2c0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 10:34:08 -0700 Subject: [PATCH 0579/1324] [Fix] [XLA] Disable pipelining for pad ops PiperOrigin-RevId: 746512385 --- third_party/xla/xla/service/collective_pipeliner.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index f76a80e6e1d141..1e8c1216ab193a 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -343,11 +343,12 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, ShapeUtil::ElementsIn(instr->operand(0)->shape()) < 1024)) { return true; } + // TODO(b/409716406): Reconsider cases where Pad can be supported. 
return HloPredicateIsOp(i) || + HloOpcode::kCollectivePermute, HloOpcode::kConvert, + HloOpcode::kReshape, HloOpcode::kAllReduce, + HloOpcode::kTranspose, HloOpcode::kBroadcast, + HloOpcode::kAllGather>(i) || (multi_uses_pipelining && i->IsElementwise()) || i->IsCustomCall(CollectivePipeliner::kInsertedByPreviousStep) || i->IsCustomCall(CollectivePipeliner::kSunkByPreviousStep); From 39097561bc64698c7b9c01f8eb6473cff6dcae0d Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Fri, 11 Apr 2025 10:34:11 -0700 Subject: [PATCH 0580/1324] Cleanup `tsl::errors`. * Use helper functions to create error status. * Remove unused overloads. * Remove unused template parameter for `...WithPayloads` functions. * Replace a few `tsl::strings` functions usage with Abseil equivalents. * Inline error status test functions. I plan to mark these as `ABSL_DEPRECATE_AND_INLINE` in a followup CL. PiperOrigin-RevId: 746512408 --- tensorflow/python/_pywrap_tensorflow.def | 13 +- third_party/xla/xla/tsl/platform/BUILD | 8 +- third_party/xla/xla/tsl/platform/errors.cc | 74 +----- third_party/xla/xla/tsl/platform/errors.h | 294 +++++++++------------ 4 files changed, 139 insertions(+), 250 deletions(-) diff --git a/tensorflow/python/_pywrap_tensorflow.def b/tensorflow/python/_pywrap_tensorflow.def index 70a2818ba4835e..2a6329d9d73e44 100644 --- a/tensorflow/python/_pywrap_tensorflow.def +++ b/tensorflow/python/_pywrap_tensorflow.def @@ -45,7 +45,6 @@ EXPORTS ?ForEachPayload@Status@lts_20230802@absl@@QEBAXV?$FunctionRef@$$A6AXV?$basic_string_view@DU?$char_traits@D@std@@@std@@AEBVCord@lts_20230802@absl@@@Z@23@@Z ??BCord@lts_20230802@absl@@QEBA?AV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@XZ ?ReadRecord@RecordReader@io@tsl@@QEAA?AVStatus@lts_20230802@absl@@PEA_KPEAVtstring@3@@Z - ?IsOutOfRange@errors@tsl@@YA_NAEBVStatus@lts_20230802@absl@@@Z ?FastUInt64ToBufferLeft@strings@tsl@@YA_K_KPEAD@Z ?StrCat@strings@tsl@@YA?AV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@AEBVAlphaNum@12@0@Z ??1RecordWriter@io@tsl@@QEAA@XZ @@ -80,7 +79,6 @@ EXPORTS ?ReadFileToString@tsl@@YA?AVStatus@lts_20230802@absl@@PEAVEnv@1@AEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@PEAV67@@Z ?GetChildren@Env@tsl@@QEAA?AVStatus@lts_20230802@absl@@AEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@PEAV?$vector@V?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@V?$allocator@V?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@@2@@7@@Z ?CreateDir@Env@tsl@@QEAA?AVStatus@lts_20230802@absl@@AEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@@Z - ?IsAlreadyExists@errors@tsl@@YA_NAEBVStatus@lts_20230802@absl@@@Z ?RecursivelyCreateDir@Env@tsl@@QEAA?AVStatus@lts_20230802@absl@@AEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@@Z ?CopyFile@Env@tsl@@QEAA?AVStatus@lts_20230802@absl@@AEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@0@Z ?StrCat@strings@tsl@@YA?AV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@AEBVAlphaNum@12@@Z @@ -840,7 +838,6 @@ EXPORTS ?NameRangesForNode@tensorflow@@YA?AVStatus@lts_20230802@absl@@AEBVAttrSlice@1@AEBVOpDef@1@PEAV?$FlatMap@V?$basic_string_view@DU?$char_traits@D@std@@@std@@U?$pair@HH@2@U?$hash@V?$basic_string_view@DU?$char_traits@D@std@@@std@@X@tsl@@U?$equal_to@V?$basic_string_view@DU?$char_traits@D@std@@@std@@@2@@gtl@tsl@@2@Z ?ParseTensorName@tensorflow@@YA?AUTensorId@1@V?$basic_string_view@DU?$char_traits@D@std@@@std@@@Z ?Hash64@tsl@@YA_KPEBD_K1@Z - 
?IsFailedPrecondition@errors@tsl@@YA_NAEBVStatus@lts_20230802@absl@@@Z ?HasAtomicMove@Env@tsl@@QEAA?AVStatus@lts_20230802@absl@@AEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@PEA_N@Z ?Stat@Env@tsl@@QEAA?AVStatus@lts_20230802@absl@@AEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@PEAUFileStatistics@2@@Z ?NewAppendableFile@Env@tsl@@QEAA?AVStatus@lts_20230802@absl@@AEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@PEAV?$unique_ptr@VWritableFile@tsl@@U?$default_delete@VWritableFile@tsl@@@std@@@7@@Z @@ -1245,3 +1242,13 @@ EXPORTS ?Signal@CondVar@lts_20230802@absl@@QEAAXXZ ?Unlock@Mutex@lts_20230802@absl@@QEAAXXZ ?Wait@CondVar@lts_20230802@absl@@QEAAXPEAVMutex@23@@Z + ?UnimplementedError@lts_20230802@absl@@YA?AVStatus@12@V?$basic_string_view@DU?$char_traits@D@std@@@std@@@Z + ?NotFoundError@lts_20230802@absl@@YA?AVStatus@12@V?$basic_string_view@DU?$char_traits@D@std@@@std@@@Z + ?InternalError@lts_20230802@absl@@YA?AVStatus@12@V?$basic_string_view@DU?$char_traits@D@std@@@std@@@Z + ?InvalidArgumentError@lts_20230802@absl@@YA?AVStatus@12@V?$basic_string_view@DU?$char_traits@D@std@@@std@@@Z + ?IsAlreadyExists@lts_20230802@absl@@YA_NAEBVStatus@12@@Z + ?AlreadyExistsError@lts_20230802@absl@@YA?AVStatus@12@V?$basic_string_view@DU?$char_traits@D@std@@@std@@@Z + ?IsOutOfRange@lts_20230802@absl@@YA_NAEBVStatus@12@@Z + ?FailedPreconditionError@lts_20230802@absl@@YA?AVStatus@12@V?$basic_string_view@DU?$char_traits@D@std@@@std@@@Z + ?PermissionDeniedError@lts_20230802@absl@@YA?AVStatus@12@V?$basic_string_view@DU?$char_traits@D@std@@@std@@@Z + ?IsFailedPrecondition@lts_20230802@absl@@YA_NAEBVStatus@12@@Z diff --git a/third_party/xla/xla/tsl/platform/BUILD b/third_party/xla/xla/tsl/platform/BUILD index cedaa35ff49072..db27bed86811cd 100644 --- a/third_party/xla/xla/tsl/platform/BUILD +++ b/third_party/xla/xla/tsl/platform/BUILD @@ -389,14 +389,12 @@ cc_library( srcs = ["errors.cc"], hdrs = ["errors.h"], deps = [ - "//xla/tsl/platform:logging", - "//xla/tsl/platform:macros", - "//xla/tsl/platform:status", - "@com_google_absl//absl/base:core_headers", + ":logging", + ":macros", + ":status", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", - "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", ], ) diff --git a/third_party/xla/xla/tsl/platform/errors.cc b/third_party/xla/xla/tsl/platform/errors.cc index 88aadeb1ac9f95..67f24f1562f3ce 100644 --- a/third_party/xla/xla/tsl/platform/errors.cc +++ b/third_party/xla/xla/tsl/platform/errors.cc @@ -18,8 +18,9 @@ limitations under the License. 
#include #include -#include "xla/tsl/platform/status.h" -#include "tsl/platform/strcat.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" namespace tsl { namespace errors { @@ -175,74 +176,9 @@ absl::StatusCode ErrnoToCode(int err_number) { } // namespace -absl::Status IOError(const string& context, int err_number) { +absl::Status IOError(absl::string_view context, int err_number) { auto code = ErrnoToCode(err_number); - return absl::Status(code, - strings::StrCat(context, "; ", strerror(err_number))); -} - -bool IsAborted(const absl::Status& status) { - return status.code() == tsl::error::Code::ABORTED; -} - -bool IsAlreadyExists(const absl::Status& status) { - return status.code() == tsl::error::Code::ALREADY_EXISTS; -} - -bool IsCancelled(const absl::Status& status) { - return status.code() == tsl::error::Code::CANCELLED; -} - -bool IsDataLoss(const absl::Status& status) { - return status.code() == tsl::error::Code::DATA_LOSS; -} - -bool IsDeadlineExceeded(const absl::Status& status) { - return status.code() == tsl::error::Code::DEADLINE_EXCEEDED; -} - -bool IsFailedPrecondition(const absl::Status& status) { - return status.code() == tsl::error::Code::FAILED_PRECONDITION; -} - -bool IsInternal(const absl::Status& status) { - return status.code() == tsl::error::Code::INTERNAL; -} - -bool IsInvalidArgument(const absl::Status& status) { - return status.code() == tsl::error::Code::INVALID_ARGUMENT; -} - -bool IsNotFound(const absl::Status& status) { - return status.code() == tsl::error::Code::NOT_FOUND; -} - -bool IsOutOfRange(const absl::Status& status) { - return status.code() == tsl::error::Code::OUT_OF_RANGE; -} - -bool IsPermissionDenied(const absl::Status& status) { - return status.code() == tsl::error::Code::PERMISSION_DENIED; -} - -bool IsResourceExhausted(const absl::Status& status) { - return status.code() == tsl::error::Code::RESOURCE_EXHAUSTED; -} - -bool IsUnauthenticated(const absl::Status& status) { - return status.code() == tsl::error::Code::UNAUTHENTICATED; -} - -bool IsUnavailable(const absl::Status& status) { - return status.code() == tsl::error::Code::UNAVAILABLE; -} - -bool IsUnimplemented(const absl::Status& status) { - return status.code() == tsl::error::Code::UNIMPLEMENTED; -} - -bool IsUnknown(const absl::Status& status) { - return status.code() == tsl::error::Code::UNKNOWN; + return absl::Status(code, absl::StrCat(context, "; ", strerror(err_number))); } } // namespace errors diff --git a/third_party/xla/xla/tsl/platform/errors.h b/third_party/xla/xla/tsl/platform/errors.h index a154d1d970f7df..6ef49cfb1889b0 100644 --- a/third_party/xla/xla/tsl/platform/errors.h +++ b/third_party/xla/xla/tsl/platform/errors.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "xla/tsl/platform/logging.h" @@ -87,7 +88,7 @@ inline const strings::AlphaNum& PrepareForStrCat(const strings::AlphaNum& a) { } // namespace internal // Maps UNIX errors into a Status. -absl::Status IOError(const string& context, int err_number); +absl::Status IOError(absl::string_view context, int err_number); // Returns all payloads from a Status as a key-value map. inline std::unordered_map GetPayloads( @@ -199,9 +200,8 @@ void AppendToMessage(absl::Status* status, Args... args) { // CANCELLED template absl::Status Cancelled(Args... 
args) { - return absl::Status(absl::StatusCode::kCancelled, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::CancelledError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } template absl::Status CancelledWithPayloads( @@ -213,19 +213,16 @@ absl::Status CancelledWithPayloads( // InvalidArgument template absl::Status InvalidArgument(Args... args) { - return absl::Status(absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::InvalidArgumentError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } - +// Specialized overloads to capture source location for up to four arguments. #if defined(PLATFORM_GOOGLE) -// Specialized overloads to capture source location for up to three arguments. template absl::Status InvalidArgument( Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), ::tsl::errors::internal::PrepareForStrCat(arg2), ::tsl::errors::internal::PrepareForStrCat(arg3), @@ -236,8 +233,7 @@ template absl::Status InvalidArgument( Arg1 arg1, Arg2 arg2, Arg3 arg3, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), ::tsl::errors::internal::PrepareForStrCat(arg2), ::tsl::errors::internal::PrepareForStrCat(arg3)), @@ -247,8 +243,7 @@ template absl::Status InvalidArgument( Arg1 arg1, Arg2 arg2, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), ::tsl::errors::internal::PrepareForStrCat(arg2)), loc); @@ -256,13 +251,11 @@ absl::Status InvalidArgument( template absl::Status InvalidArgument( Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)), loc); } -template -absl::Status InvalidArgumentWithPayloads( +inline absl::Status InvalidArgumentWithPayloads( absl::string_view message, const std::unordered_map& payloads, absl::SourceLocation loc = absl::SourceLocation::current()) { @@ -270,29 +263,7 @@ absl::Status InvalidArgumentWithPayloads( loc); } #else -template -absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3))); -} -template -absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2))); -} -template -absl::Status InvalidArgument(Arg1 arg1) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1))); -} -template -absl::Status InvalidArgumentWithPayloads( 
+inline absl::Status InvalidArgumentWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kInvalidArgument, message, payloads); @@ -302,18 +273,16 @@ absl::Status InvalidArgumentWithPayloads( // NotFound template absl::Status NotFound(Args... args) { - return absl::Status(absl::StatusCode::kNotFound, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::NotFoundError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -#if defined(PLATFORM_GOOGLE) // Specialized overloads to capture source location for up to three arguments. +#if defined(PLATFORM_GOOGLE) template absl::Status NotFound( Arg1 arg1, Arg2 arg2, Arg3 arg3, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kNotFound, + return absl::NotFoundError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), ::tsl::errors::internal::PrepareForStrCat(arg2), ::tsl::errors::internal::PrepareForStrCat(arg3)), @@ -323,8 +292,7 @@ template absl::Status NotFound( Arg1 arg1, Arg2 arg2, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kNotFound, + return absl::NotFoundError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), ::tsl::errors::internal::PrepareForStrCat(arg2)), loc); @@ -332,42 +300,18 @@ absl::Status NotFound( template absl::Status NotFound( Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kNotFound, + return absl::NotFoundError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)), loc); } -template -absl::Status NotFoundWithPayloads( +inline absl::Status NotFoundWithPayloads( absl::string_view message, const std::unordered_map& payloads, absl::SourceLocation loc = absl::SourceLocation::current()) { return errors::Create(absl::StatusCode::kNotFound, message, payloads, loc); } #else -template -absl::Status NotFound(Arg1 arg1, Arg2 arg2, Arg3 arg3) { - return absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3))); -} -template -absl::Status NotFound(Arg1 arg1, Arg2 arg2) { - return absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2))); -} -template -absl::Status NotFound(Arg1 arg1) { - return absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1))); -} -template -absl::Status NotFoundWithPayloads( +inline absl::Status NotFoundWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kNotFound, message, payloads); @@ -377,12 +321,10 @@ absl::Status NotFoundWithPayloads( // AlreadyExists template absl::Status AlreadyExists(Args... 
args) { - return absl::Status(absl::StatusCode::kAlreadyExists, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::AlreadyExistsError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status AlreadyExistsWithPayloads( +inline absl::Status AlreadyExistsWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kAlreadyExists, message, payloads); @@ -391,12 +333,10 @@ absl::Status AlreadyExistsWithPayloads( // ResourceExhausted template absl::Status ResourceExhausted(Args... args) { - return absl::Status(absl::StatusCode::kResourceExhausted, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::ResourceExhaustedError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status ResourceExhaustedWithPayloads( +inline absl::Status ResourceExhaustedWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kResourceExhausted, message, @@ -406,12 +346,10 @@ absl::Status ResourceExhaustedWithPayloads( // Unavailable template absl::Status Unavailable(Args... args) { - return absl::Status(absl::StatusCode::kUnavailable, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::UnavailableError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status UnavailableWithPayloads( +inline absl::Status UnavailableWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kUnavailable, message, payloads); @@ -420,12 +358,10 @@ absl::Status UnavailableWithPayloads( // FailedPrecondition template absl::Status FailedPrecondition(Args... args) { - return absl::Status(absl::StatusCode::kFailedPrecondition, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::FailedPreconditionError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status FailedPreconditionWithPayloads( +inline absl::Status FailedPreconditionWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kFailedPrecondition, message, @@ -435,12 +371,10 @@ absl::Status FailedPreconditionWithPayloads( // OutOfRange template absl::Status OutOfRange(Args... args) { - return absl::Status(absl::StatusCode::kOutOfRange, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::OutOfRangeError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status OutOfRangeWithPayloads( +inline absl::Status OutOfRangeWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kOutOfRange, message, payloads); @@ -449,12 +383,10 @@ absl::Status OutOfRangeWithPayloads( // Unimplemented template absl::Status Unimplemented(Args... 
args) { - return absl::Status(absl::StatusCode::kUnimplemented, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::UnimplementedError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status UnimplementedWithPayloads( +inline absl::Status UnimplementedWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kUnimplemented, message, payloads); @@ -463,12 +395,10 @@ absl::Status UnimplementedWithPayloads( // Internal template absl::Status Internal(Args... args) { - return absl::Status(absl::StatusCode::kInternal, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::InternalError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status InternalWithPayloads( +inline absl::Status InternalWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kInternal, message, payloads); @@ -477,12 +407,10 @@ absl::Status InternalWithPayloads( // Aborted template absl::Status Aborted(Args... args) { - return absl::Status(absl::StatusCode::kAborted, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::AbortedError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status AbortedWithPayloads( +inline absl::Status AbortedWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kAborted, message, payloads); @@ -491,12 +419,10 @@ absl::Status AbortedWithPayloads( // DeadlineExceeded template absl::Status DeadlineExceeded(Args... args) { - return absl::Status(absl::StatusCode::kDeadlineExceeded, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::DeadlineExceededError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status DeadlineExceededWithPayloads( +inline absl::Status DeadlineExceededWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kDeadlineExceeded, message, payloads); @@ -505,12 +431,10 @@ absl::Status DeadlineExceededWithPayloads( // DataLoss template absl::Status DataLoss(Args... args) { - return absl::Status(absl::StatusCode::kDataLoss, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); + return absl::DataLossError(::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); } -template -absl::Status DataLossWithPayloads( +inline absl::Status DataLossWithPayloads( absl::string_view message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kDataLoss, message, payloads); @@ -519,12 +443,10 @@ absl::Status DataLossWithPayloads( // Unknown template absl::Status Unknown(Args... 
-  return absl::Status(absl::StatusCode::kUnknown,
-                      ::tsl::strings::StrCat(
-                          ::tsl::errors::internal::PrepareForStrCat(args)...));
+  return absl::UnknownError(::tsl::strings::StrCat(
+      ::tsl::errors::internal::PrepareForStrCat(args)...));
 }
-template <typename... Args>
-absl::Status UnknownPayloads(
+inline absl::Status UnknownPayloads(
    absl::string_view message,
    const std::unordered_map<std::string, absl::Cord>& payloads) {
   return errors::Create(absl::StatusCode::kUnknown, message, payloads);
@@ -532,12 +454,10 @@ absl::Status UnknownPayloads(
 // PermissionDenied
 template <typename... Args>
 absl::Status PermissionDenied(Args... args) {
-  return absl::Status(absl::StatusCode::kPermissionDenied,
-                      ::tsl::strings::StrCat(
-                          ::tsl::errors::internal::PrepareForStrCat(args)...));
+  return absl::PermissionDeniedError(::tsl::strings::StrCat(
+      ::tsl::errors::internal::PrepareForStrCat(args)...));
 }
-template <typename... Args>
-absl::Status PermissionDeniedWithPayloads(
+inline absl::Status PermissionDeniedWithPayloads(
     absl::string_view message,
     const std::unordered_map<std::string, absl::Cord>& payloads) {
   return errors::Create(absl::StatusCode::kPermissionDenied, message,
                         payloads);
@@ -546,33 +466,63 @@ absl::Status PermissionDeniedWithPayloads(
 // Unauthenticated
 template <typename... Args>
 absl::Status Unauthenticated(Args... args) {
-  return absl::Status(absl::StatusCode::kUnauthenticated,
-                      ::tsl::strings::StrCat(
-                          ::tsl::errors::internal::PrepareForStrCat(args)...));
+  return absl::UnauthenticatedError(::tsl::strings::StrCat(
+      ::tsl::errors::internal::PrepareForStrCat(args)...));
 }
-template <typename... Args>
-absl::Status UnauthenticatedWithPayloads(
+inline absl::Status UnauthenticatedWithPayloads(
     absl::string_view message,
     const std::unordered_map<std::string, absl::Cord>& payloads) {
   return errors::Create(absl::StatusCode::kUnauthenticated, message,
                         payloads);
 }
 
-bool IsAborted(const absl::Status& status);
-bool IsAlreadyExists(const absl::Status& status);
-bool IsCancelled(const absl::Status& status);
-bool IsDataLoss(const absl::Status& status);
-bool IsDeadlineExceeded(const absl::Status& status);
-bool IsFailedPrecondition(const absl::Status& status);
-bool IsInternal(const absl::Status& status);
-bool IsInvalidArgument(const absl::Status& status);
-bool IsNotFound(const absl::Status& status);
-bool IsOutOfRange(const absl::Status& status);
-bool IsPermissionDenied(const absl::Status& status);
-bool IsResourceExhausted(const absl::Status& status);
-bool IsUnauthenticated(const absl::Status& status);
-bool IsUnavailable(const absl::Status& status);
-bool IsUnimplemented(const absl::Status& status);
-bool IsUnknown(const absl::Status& status);
+inline bool IsAborted(const absl::Status& status) {
+  return absl::IsAborted(status);
+}
+inline bool IsAlreadyExists(const absl::Status& status) {
+  return absl::IsAlreadyExists(status);
+}
+inline bool IsCancelled(const absl::Status& status) {
+  return absl::IsCancelled(status);
+}
+inline bool IsDataLoss(const absl::Status& status) {
+  return absl::IsDataLoss(status);
+}
+inline bool IsDeadlineExceeded(const absl::Status& status) {
+  return absl::IsDeadlineExceeded(status);
+}
+inline bool IsFailedPrecondition(const absl::Status& status) {
+  return absl::IsFailedPrecondition(status);
+}
+inline bool IsInternal(const absl::Status& status) {
+  return absl::IsInternal(status);
+}
+inline bool IsInvalidArgument(const absl::Status& status) {
+  return absl::IsInvalidArgument(status);
+}
+inline bool IsNotFound(const absl::Status& status) {
+  return absl::IsNotFound(status);
+}
+inline bool IsOutOfRange(const absl::Status& status) {
+  return absl::IsOutOfRange(status);
+}
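For callers nothing changes at the call site: the factory helpers keep their names and variadic StrCat-style signatures, and the Is* wrappers now agree with the absl predicates by construction. A minimal usage sketch (not part of this patch; the include path and the caller-side names here are assumptions):

    // Illustrative only -- exercises the helpers refactored above.
    #include <string>
    #include <unordered_map>

    #include "absl/status/status.h"
    #include "absl/strings/cord.h"
    #include "tsl/platform/errors.h"  // assumed include path

    absl::Status LookUpShard(int shard) {
      if (shard < 0) {
        // Variadic arguments are still StrCat'ed into the message.
        return tsl::errors::InvalidArgument("negative shard index: ", shard);
      }
      std::unordered_map<std::string, absl::Cord> payloads = {
          {"shard-id", absl::Cord(std::to_string(shard))}};
      return tsl::errors::UnknownPayloads("shard lookup failed", payloads);
    }

    bool IsUnknownShardError(const absl::Status& s) {
      // The wrapper and the absl predicate are now the same check.
      return tsl::errors::IsUnknown(s) && absl::IsUnknown(s);
    }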
+inline bool IsPermissionDenied(const absl::Status& status) {
+  return absl::IsPermissionDenied(status);
+}
+inline bool IsResourceExhausted(const absl::Status& status) {
+  return absl::IsResourceExhausted(status);
+}
+inline bool IsUnauthenticated(const absl::Status& status) {
+  return absl::IsUnauthenticated(status);
+}
+inline bool IsUnavailable(const absl::Status& status) {
+  return absl::IsUnavailable(status);
+}
+inline bool IsUnimplemented(const absl::Status& status) {
+  return absl::IsUnimplemented(status);
+}
+inline bool IsUnknown(const absl::Status& status) {
+  return absl::IsUnknown(status);
+}
 
 // Produces a formatted string pattern from the name which can uniquely identify
 // this node upstream to produce an informative error message. The pattern
@@ -581,19 +531,19 @@ bool IsUnknown(const absl::Status& status);
 // tensorflow/python/client/session.py
 // LINT.IfChange
 inline std::string FormatNodeNameForError(absl::string_view name) {
-  return strings::StrCat("{{node ", name, "}}");
+  return absl::StrCat("{{node ", name, "}}");
 }
 // LINT.ThenChange(//tensorflow/python/client/session.py)
 template <typename T>
 std::string FormatNodeNamesForError(const T& names) {
-  return absl::StrJoin(
-      names, ", ", [](std::string* output, absl::string_view s) {
-        ::tsl::strings::StrAppend(output, FormatNodeNameForError(s));
-      });
+  return absl::StrJoin(names, ", ",
+                       [](std::string* output, absl::string_view s) {
+                         absl::StrAppend(output, FormatNodeNameForError(s));
+                       });
 }
 // LINT.IfChange
 inline std::string FormatColocationNodeForError(absl::string_view name) {
-  return strings::StrCat("{{colocation_node ", name, "}}");
+  return absl::StrCat("{{colocation_node ", name, "}}");
 }
 // LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py)
 template <typename T>
 std::string FormatColocationNodeForError(const T& names) {
-  return absl::StrJoin(
-      names, ", ", [](std::string* output, absl::string_view s) {
-        ::tsl::strings::StrAppend(output, FormatColocationNodeForError(s));
-      });
+  return absl::StrJoin(names, ", ",
+                       [](std::string* output, absl::string_view s) {
+                         absl::StrAppend(output, FormatColocationNodeForError(s));
+                       });
 }
 
 inline absl::Status ReplaceErrorFromNonCommunicationOps(
     const absl::Status s, absl::string_view op_name) {
-  assert(::tsl::errors::IsUnavailable(s));
-  return absl::Status(
-      absl::StatusCode::kInternal,
-      strings::StrCat(s.message(), "\nExecuting non-communication op <",
-                      op_name,
-                      "> originally returned UnavailableError, and was replaced by "
-                      "InternalError to avoid invoking TF network error handling logic."));
+  assert(absl::IsUnavailable(s));
+  return absl::InternalError(absl::StrCat(
+      s.message(), "\nExecuting non-communication op <", op_name,
+      "> originally returned UnavailableError, and was replaced by "
+      "InternalError to avoid invoking TF network error handling logic."));
 }
 
 template <typename T>

From f02801666c9766bd0eb50fda4a359c43e9b26b4d Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Fri, 11 Apr 2025 11:02:25 -0700
Subject: [PATCH 0581/1324] [xla:gpu] CommandBuffer: use AbslStringify to
 convert enums to strings

PiperOrigin-RevId: 746522359
---
 .../xla/xla/stream_executor/command_buffer.h | 21 ++++++++++++-------
 .../cuda/cuda_command_buffer.h               |  3 ++-
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h
index 2ac80cf5030c17..dbc7bbab3d3da1 100644
--- a/third_party/xla/xla/stream_executor/command_buffer.h
+++ b/third_party/xla/xla/stream_executor/command_buffer.h
@@ -87,14 +87,18 @@ class CommandBuffer {
   //
   enum class State { kCreate, kUpdate, kFinalized };
 
-  friend absl::string_view StateToString(State state) {
+  template <typename Sink>
+  friend void AbslStringify(Sink& sink, State state) {
     switch (state) {
       case CommandBuffer::State::kCreate:
-        return "create";
+        sink.Append("create");
+        break;
      case CommandBuffer::State::kUpdate:
-        return "update";
+        sink.Append("update");
+        break;
      case CommandBuffer::State::kFinalized:
-        return "finalized";
+        sink.Append("finalized");
+        break;
    }
  }
 
@@ -107,12 +111,15 @@ class CommandBuffer {
   //
   enum class Mode { kPrimary, kNested };
 
-  friend absl::string_view ModeToString(Mode mode) {
+  template <typename Sink>
+  
friend void AbslStringify(Sink& sink, Mode mode) { switch (mode) { case CommandBuffer::Mode::kPrimary: - return "primary"; + sink.Append("primary"); + break; case CommandBuffer::Mode::kNested: - return "nested"; + sink.Append("nested"); + break; } } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h index 06b95380a91978..61b69ad8c5121d 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "third_party/gpus/cuda/include/cuda.h" @@ -65,7 +66,7 @@ class CudaCommandBuffer final : public GpuCommandBuffer { graph_(graph), is_owned_graph_(is_owned_graph) { VLOG(5) << "Created command buffer for graph " << graph_ - << "; mode=" << ModeToString(mode) + << "; mode=" << absl::StrCat(mode) << "; is_owned_graph=" << is_owned_graph_; } From 7c5891306ecaf878b8635e84e797b83a6559f46a Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 11 Apr 2025 11:35:26 -0700 Subject: [PATCH 0582/1324] Remove deprecated targets from build and header files PiperOrigin-RevId: 746533821 --- .../xla/xla/translate/mhlo_to_hlo/BUILD | 87 ------------------- .../mhlo_to_hlo/attribute_exporter.h | 22 ----- .../xla/translate/mhlo_to_hlo/layout_util.h | 24 ----- .../translate/mhlo_to_hlo/location_exporter.h | 22 ----- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.h | 22 ----- .../xla/xla/translate/mhlo_to_hlo/translate.h | 22 ----- .../xla/translate/mhlo_to_hlo/type_to_shape.h | 22 ----- 7 files changed, 221 deletions(-) delete mode 100644 third_party/xla/xla/translate/mhlo_to_hlo/BUILD delete mode 100644 third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h delete mode 100644 third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h delete mode 100644 third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h delete mode 100644 third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h delete mode 100644 third_party/xla/xla/translate/mhlo_to_hlo/translate.h delete mode 100644 third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD deleted file mode 100644 index d3dfd9e969d0c2..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD +++ /dev/null @@ -1,87 +0,0 @@ -load("//xla/tsl:tsl.bzl", "internal_visibility") -load("//xla/tsl/platform:rules_cc.bzl", "cc_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = internal_visibility([ - "//learning/brain/mlir:tensorflow_friends", - "//learning/brain/mlir:xla_friends", - ]), - licenses = ["notice"], -) - -cc_library( - name = "attribute_exporter", - hdrs = ["attribute_exporter.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:attribute_exporter instead.", - deps = [ - "//xla/hlo/translate/mhlo_to_hlo:attribute_exporter", - ], -) - -cc_library( - name = "layout_util", - hdrs = ["layout_util.h"], - deprecation = "This library is deprecated and will be removed in February 2025. 
Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:layout_util instead.", - deps = [ - "//xla/hlo/translate/mhlo_to_hlo:layout_util", - ], -) - -cc_library( - name = "location_exporter", - hdrs = ["location_exporter.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:location_exporter instead.", - deps = [ - "//xla/hlo/translate/mhlo_to_hlo:location_exporter", - ], -) - -alias( - name = "module_attributes_exporter", - actual = "//xla/hlo/translate/mhlo_to_hlo:module_attributes_exporter", - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:module_attributes_exporter instead.", -) - -alias( - name = "stack_frame_index_builder", - actual = "//xla/hlo/translate/mhlo_to_hlo:stack_frame_index_builder", - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:stack_frame_index_builder instead.", -) - -cc_library( - name = "mlir_hlo_to_hlo", - hdrs = ["mlir_hlo_to_hlo.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo instead.", - deps = [ - "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - ], -) - -cc_library( - name = "translate", - hdrs = ["translate.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:translate instead.", - deps = [ - "//xla/hlo/translate/mhlo_to_hlo:translate", - ], -) - -cc_library( - name = "translate_registration", - testonly = True, - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:translate_registration instead.", - deps = [ - "//xla/hlo/translate/mhlo_to_hlo:translate_registration", - ], - alwayslink = 1, -) - -cc_library( - name = "type_to_shape", - hdrs = ["type_to_shape.h"], - deprecation = "This library is deprecated and will be removed in February 2025. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:type_to_shape instead.", - deps = [ - "//xla/hlo/translate/mhlo_to_hlo:type_to_shape", - ], -) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h deleted file mode 100644 index 2caf77bf3a3d2a..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ - -// The current header will be deprecated in favour of the following. 
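Every header deleted in this commit is the same one-line forwarding shim, so the caller-side migration is mechanical; a sketch of the intended change in downstream code (using the attribute_exporter pair from this patch):

    // Before (forwarding shim, removed by this commit):
    #include "xla/translate/mhlo_to_hlo/attribute_exporter.h"
    // After (include the real header directly):
    #include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h"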
-#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" - -#endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h deleted file mode 100644 index 6005d23d69e910..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Utilities for working with XLA layout and shapes. - -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" - -#endif // XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h b/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h deleted file mode 100644 index b5c43ce49c481a..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" - -#endif // XLA_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h deleted file mode 100644 index 1544b99e069571..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" - -#endif // XLA_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate.h b/third_party/xla/xla/translate/mhlo_to_hlo/translate.h deleted file mode 100644 index 373eaca3fca4f3..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ - -// The current header will be deprecated in favour of the following. -#include "xla/hlo/translate/mhlo_to_hlo/translate.h" - -#endif // XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h deleted file mode 100644 index 2e99276efe7c5b..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ - -// The current header will be deprecated in favour of the following. 
-#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" - -#endif // XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ From bdbb0b53f805e0edf15f308a84686c4d2d90c624 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Fri, 11 Apr 2025 11:47:17 -0700 Subject: [PATCH 0583/1324] Add pipelining stages selection logic to dynamic search space PiperOrigin-RevId: 746537958 --- .../gpu/autotuning/dot_search_space.cc | 49 ++++++++++++++++-- .../service/gpu/autotuning/dot_search_space.h | 12 +++++ .../gpu/autotuning/dot_search_space_test.cc | 50 +++++++++++++++---- 3 files changed, 99 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc index 0b8ab610b115a9..9050b4a2c2e8ac 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc @@ -138,14 +138,13 @@ std::vector TritonDotFusionSearchSpace::GenerateConfigs( EliminateLowOccupancyConfigs(configs); ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddCtaSizeParameter); ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddContractingTiling); + ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddPipeliningParameter); std::vector result; result.reserve(configs.size()); for (ConfigWithNotes& config_with_notes : configs) { - // TODO: b/404470821 - Implement this properly rather than hardcoding the - // config parameters. TritonGemmConfig& config = config_with_notes.config; - config.num_stages = 3; + // TODO: b/408386169 - Implement CTA cluster support. config.num_ctas = 1; result.push_back(config); } @@ -301,6 +300,27 @@ int TritonDotFusionSearchSpace::GetMaxContractingTileSize( return max_size; } +int TritonDotFusionSearchSpace::GetMaxNumStages(OutputTile output_tile, + int contracting_tile_size, + int contracting_split) const { + const int64_t available_stages = CeilOfRatio( + contracting_size_, contracting_split * contracting_tile_size); + const int64_t stage_limit = + CeilOfRatio(GetContractingSizeLimitToFitSharedMemory(output_tile), + contracting_tile_size); + // Number of stages is basically a replacement for oversubscription, so + // the maximum number we want is also limited by kMaxWarpsPerScheduler. 
+  const int stages = std::min({available_stages, stage_limit,
+                               static_cast<int64_t>(kMaxWarpsPerScheduler)});
+  VLOG(5) << "Computing max_num_stages for tiling BxMxNxK = "
+          << contracting_split << "x" << output_tile.lhs_dim << "x"
+          << output_tile.rhs_dim << "x" << contracting_tile_size
+          << ": limit based on problem is " << available_stages
+          << ", limit based on available shared memory is " << stage_limit
+          << ", max_num_stages = " << stages;
+  return stages;
+}
+
 std::vector<int> TritonDotFusionSearchSpace::GenerateContractingSplitFactors() {
   CHECK_GE(max_contracting_split_, 1);
@@ -394,6 +414,29 @@ void TritonDotFusionSearchSpace::AddContractingTiling(
   }
 }
 
+void TritonDotFusionSearchSpace::AddPipeliningParameter(
+    const ConfigWithNotes& config,
+    std::vector<ConfigWithNotes>& updated_configs) {
+  const int tile_rows = config.config.block_m;
+  const int tile_cols = config.config.block_n;
+  const int tile_contracting = config.config.block_k;
+  const int split = config.config.split_k;
+  CHECK_GT(tile_rows * tile_cols, 0)
+      << "Need config with output tilings determined.";
+  CHECK_GT(tile_contracting, 0)
+      << "Need config with contracting tiling determined.";
+  CHECK_GT(split, 0) << "Need config with contracting split determined.";
+  int max_stages =
+      GetMaxNumStages({tile_rows, tile_cols}, tile_contracting, split);
+  ConfigWithNotes new_config = config;
+  for (int num_stages = 1; num_stages <= max_stages; ++num_stages) {
+    new_config.config.num_stages = num_stages;
+    VLOG(10) << "Adding pipelining parameter: config = "
+             << new_config.ToString();
+    updated_configs.push_back(new_config);
+  }
+}
+
 void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs(
     std::vector<ConfigWithNotes>& configs) {
   CHECK(!configs.empty());
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
index 569198e3bdae78..5c98406658920e 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
@@ -129,6 +129,11 @@ class TritonDotFusionSearchSpace {
   int GetMaxContractingTileSize(OutputTile output_tile,
                                 int contracting_split) const;
 
+  // Computes the maximum reasonable number of stages for the given output and
+  // input tilings and contracting split.
+  int GetMaxNumStages(OutputTile output_tile, int contracting_tile_size,
+                      int contracting_split) const;
+
   // Finds all promising values for splitting the contracting dimension to
   // achieve sufficient occupancy (split_k).
   std::vector<int> GenerateContractingSplitFactors();
@@ -154,6 +159,13 @@ class TritonDotFusionSearchSpace {
   void AddContractingTiling(const ConfigWithNotes& config,
                             std::vector<ConfigWithNotes>& updated_configs);
 
+  // Finds all promising values for the pipelining parameter, based on
+  // `config` with already determined contracting split, output tiling, and
+  // contracting tile size, and appends them to `updated_configs`. Each config
+  // in the input list might yield zero or more configs in the output.
+  void AddPipeliningParameter(const ConfigWithNotes& config,
+                              std::vector<ConfigWithNotes>& updated_configs);
+
   // Removes configs that are marked with `not_enough_tiles` from the list. If
   // this results in an empty list, adds a config that should be the most
   // optimal one even though it does not occupy all cores.
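The stage count above is the minimum of three independent limits: tiles available along the (split) contracting dimension, tiles that fit in shared memory, and the oversubscription cap. A self-contained restatement of the same arithmetic (the helper below is a sketch; kMaxWarpsPerScheduler's real value lives in the header and is assumed to be 8 here):

    #include <algorithm>
    #include <cstdint>

    constexpr int64_t kMaxWarpsPerScheduler = 8;  // assumed value

    int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

    int MaxNumStages(int64_t contracting_size, int64_t smem_limited_k_size,
                     int contracting_tile, int split) {
      // Cannot pipeline more tiles than the split contracting dim provides...
      const int64_t available =
          CeilOfRatio(contracting_size, int64_t{split} * contracting_tile);
      // ...nor more tiles than fit into shared memory at once.
      const int64_t smem_limit =
          CeilOfRatio(smem_limited_k_size, contracting_tile);
      return static_cast<int>(
          std::min({available, smem_limit, kMaxWarpsPerScheduler}));
    }
    // e.g. MaxNumStages(4096, 454, 64, 1) == min(64, 8, 8) == 8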
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
index 75d109534d85a2..ca1c2464992c1a 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
@@ -40,6 +40,15 @@ using ::testing::IsEmpty;
 using ::testing::Le;
 using ::testing::SizeIs;
 
+// Returns a matcher that verifies that each container element that matches
+// `filter` also matches `matcher`, and that there is at least one such element.
+template <typename FilterMatcher, typename Matcher>
+auto WhenFilteredBy(FilterMatcher filter, Matcher matcher) {
+  // We check the negation: there is no element that matches `filter` and does
+  // not match `matcher`.
+  return AllOf(Contains(filter), Not(Contains(AllOf(filter, Not(matcher)))));
+}
+
 template <typename MatcherType>
 auto BlockMIs(MatcherType matcher) {
   return Field("block_m", &TritonGemmConfig::block_m, matcher);
@@ -292,20 +301,16 @@ TEST_F(DotSearchSpaceTest, HonorsSharedMemoryLimit) {
       GetDefaultDotModule(/*lhs_parallel_dim=*/4096, /*rhs_parallel_dim=*/4096,
                           /*contracting_dim=*/4096));
   TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
-  // We pick a certain output tiling and contracting split of 1 (to not reduce
-  // the effective contracting size), and only verify that configs with these
+
+  // We pick the 128x128 output tiling and only verify that configs with these
   // properties honor the memory limit. This simplifies the test logic and makes
   // the calculation easier to verify by hand, while not reducing the coverage
   // of the test.
-  auto considered_configs =
-      AllOf(BlockMIs(Eq(128)), BlockNIs(Eq(128)), SplitKIs(Eq(1)));
-
   // 2B * (128 + 128) * block_k < 227 KB =>
   //   block_k <= 227 KB / (2B * (128 + 128)) = 454
-  EXPECT_THAT(
-      search_space.GenerateConfigs(),
-      AllOf(Contains(considered_configs),
-            Not(Contains(AllOf(considered_configs, BlockKIs(Ge(512)))))));
+  EXPECT_THAT(search_space.GenerateConfigs(/*force_contracting_split=*/1),
+              WhenFilteredBy(AllOf(BlockMIs(Eq(128)), BlockNIs(Eq(128))),
+                             BlockKIs(Le(256))));
 }
 
 TEST_F(DotSearchSpaceTest, HonorsContractingSizeLimit) {
@@ -330,5 +335,32 @@ TEST_F(DotSearchSpaceTest, EnsuresContractingTileSizeFitsInstructonShape) {
               AllOf(Not(IsEmpty()), Each(BlockKIs(Ge(8)))));
 }
 
+TEST_F(DotSearchSpaceTest, FindReasonablePipeliningStageCount) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          GetDefaultDotModule());
+  TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(search_space.GenerateConfigs(),
+              AllOf(Contains(NumStagesIs(Ge(2))).Times(Ge(2)),
+                    Contains(NumStagesIs(Eq(1))), Each(NumStagesIs(Le(5)))));
+}
+
+TEST_F(DotSearchSpaceTest, LimitsStagesToAvailableTileSize) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<VerifiedHloModule> module,
+      GetDefaultDotModule(/*lhs_parallel_dim=*/1024, /*rhs_parallel_dim=*/1024,
+                          /*contracting_dim=*/128));
+  TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
+
+  // We pick the 64x32x32 tiling and only verify that configs with these
+  // properties choose the right number of stages. This simplifies the test
+  // logic and makes the calculation easier to verify by hand, while not
+  // reducing the coverage of the test.
+  EXPECT_THAT(search_space.GenerateConfigs(/*force_contracting_split=*/2),
+              WhenFilteredBy(
+                  AllOf(BlockMIs(Eq(64)), BlockNIs(Eq(32)), BlockKIs(Eq(32))),
+                  NumStagesIs(Le(2))));
+}
+
 }  // namespace
 }  // namespace xla::gpu
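The WhenFilteredBy combinator introduced in this test file is plain gMock and works on any container; a standalone demonstration with made-up data (the test below is illustrative, not part of the patch):

    #include <vector>

    #include "gmock/gmock.h"
    #include "gtest/gtest.h"

    using ::testing::AllOf;
    using ::testing::Contains;
    using ::testing::Gt;
    using ::testing::Lt;
    using ::testing::Not;

    template <typename FilterMatcher, typename Matcher>
    auto WhenFilteredBy(FilterMatcher filter, Matcher matcher) {
      // No element may match the filter while failing the matcher, and at
      // least one element must match the filter.
      return AllOf(Contains(filter), Not(Contains(AllOf(filter, Not(matcher)))));
    }

    TEST(WhenFilteredByDemo, LargeElementsAreBounded) {
      std::vector<int> values = {1, 5, 7, 2};
      // There is at least one element > 3, and every element > 3 is < 10.
      EXPECT_THAT(values, WhenFilteredBy(Gt(3), Lt(10)));
    }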
From 4e88374a5fc570c4b2a9ef45e7690fe624594534 Mon Sep 17 00:00:00 2001
From: Sizhi Tan
Date: Fri, 11 Apr 2025 12:18:26 -0700
Subject: [PATCH 0584/1324] Fix staging buffer conditional use.

PiperOrigin-RevId: 746548924
---
 third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
index 596f5a9edb8d1a..de678204df0eb5 100644
--- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
+++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
@@ -793,8 +793,8 @@ PjRtStreamExecutorClient::BufferFromHostBufferInternal(
   // Allocating multigigabyte pinned buffers can be very slow. In that case,
   // using a staging buffer is probably worse than not using one.
   // TODO(phawkins): add chunking for transfers.
-  if (!IsDmaMapped(data, packed_size) &&
-      (must_use_staging_buffer || (should_stage_host_to_device_transfers() &&
+  if (must_use_staging_buffer || (!IsDmaMapped(data, packed_size) &&
+                                  (should_stage_host_to_device_transfers() &&
                                    packed_size < (int64_t{1} << 30)))) {
     void* ptr = host_memory_allocator()->AllocateRaw(
         tsl::Allocator::kAllocatorAlignment, transpose ? size : packed_size);

From 5f5db99be9f1cc9f7b9d8367af01f4df7b9e109b Mon Sep 17 00:00:00 2001
From: Goran Flegar
Date: Fri, 11 Apr 2025 12:21:56 -0700
Subject: [PATCH 0585/1324] Add parameter considerations that ensure good
 occupancy to dynamic search space

PiperOrigin-RevId: 746549984
---
 .../xla/xla/service/gpu/autotuning/BUILD      |  1 +
 .../gpu/autotuning/dot_search_space.cc        | 59 +++++++++++++++++--
 .../service/gpu/autotuning/dot_search_space.h | 22 +++++++
 .../gpu/autotuning/dot_search_space_test.cc   | 35 ++++++++++-
 4 files changed, 108 insertions(+), 9 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD
index 3bb76e3111901e..1e7d3629511875 100644
--- a/third_party/xla/xla/service/gpu/autotuning/BUILD
+++ b/third_party/xla/xla/service/gpu/autotuning/BUILD
@@ -290,6 +290,7 @@ xla_test(
         "//xla/service/gpu:matmul_utils",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:device_description_proto_cc",
+        "//xla/stream_executor/cuda:cuda_compute_capability",
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
index 9050b4a2c2e8ac..43eeecdf7abcf9 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
@@ -97,10 +97,9 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace(
       // Figure out some basic limitations on tiling based on the above.
       desired_total_warps_(GetDesiredTotalWarps()),
       max_out_tile_(GetMaxOutputTile()),
-      // TODO: b/404470821 - Compute these from the problem properties instead
-      // of hardcoding.
- min_out_tile_{16, 16}, - min_warps_per_cta_(4), + should_optimize_for_occupancy_(ShouldOptimizeForOccupancy()), + min_out_tile_(GetMinOutputTile()), + min_warps_per_cta_(GetMinWarpsPerCta()), min_contracting_tile_size_(GetMinContractingTileSize()), max_contracting_split_(GetMaxContractingSplit(max_out_tile_)) { // Make sure that the range of output tile sizes is not empty @@ -155,12 +154,13 @@ std::string TritonDotFusionSearchSpace::ToString() const { return absl::StrFormat( "problem_size_BxMxNxKxE: %dx%dx%dx%dx(%d->%d) " "tile_range_SxMxNxK: [1-%d]x[%d-%d]x[%d-%d]x[%d-?] " - "desired_total_warps: %d warps_per_cta: [%d-?]", + "desired_total_warps: %d occupancy_optimization: %d " + "warps_per_cta: [%d-?]", batch_size_, lhs_parallel_size_, rhs_parallel_size_, contracting_size_, operand_bitwidth_, compute_bitwidth_, max_contracting_split_, min_out_tile_.lhs_dim, max_out_tile_.lhs_dim, min_out_tile_.rhs_dim, max_out_tile_.rhs_dim, min_contracting_tile_size_, desired_total_warps_, - min_warps_per_cta_); + should_optimize_for_occupancy_, min_warps_per_cta_); } int TritonDotFusionSearchSpace::GetDesiredTotalWarps() const { @@ -210,6 +210,53 @@ TritonDotFusionSearchSpace::GetMaxOutputTile() const { return max_tile; } +bool TritonDotFusionSearchSpace::ShouldOptimizeForOccupancy() const { + const int64_t desired_num_ctas = + desired_total_warps_ / kMinWarpsPerCtaForWgmma; + const int64_t min_result_tiles = GetNumResultTiles(max_out_tile_); + if (desired_num_ctas > min_result_tiles) { + VLOG(5) << "Occupancy optimization: Might have as few as " + << min_result_tiles << " tiles, but want at least " + << desired_num_ctas + << " CTAs. Will consider trading off compute performance for " + "occupancy."; + return true; + } + return false; +} + +TritonDotFusionSearchSpace::OutputTile +TritonDotFusionSearchSpace::GetMinOutputTile() const { + // Triton currently doesn't support tiles smaller than 16x16. + // TODO: b/395572776 - Lift this restriction, and calculate a smaller tile + // based on the requested algorithm (e.g., if we want to use wgmma vs mma + // vs fma, the minimal reasonable tile size is different). 
+  constexpr OutputTile kMinSupportedTile = {16, 16};
+  constexpr OutputTile kMinWgmmaTile = {64, 16};
+  if (device_description_.cuda_compute_capability().IsAtLeastHopper() &&
+      !should_optimize_for_occupancy_) {
+    VLOG(5) << "Computing output_tile: Want to use wgmma, so output_tile >= "
+            << kMinWgmmaTile.lhs_dim << "x" << kMinWgmmaTile.rhs_dim;
+    return kMinWgmmaTile;
+  }
+  VLOG(5)
+      << "Computing output_tile: Might want to target mma, so output_tile >= "
+      << kMinSupportedTile.lhs_dim << "x" << kMinSupportedTile.rhs_dim;
+  return kMinSupportedTile;
+}
+
+int TritonDotFusionSearchSpace::GetMinWarpsPerCta() const {
+  if (device_description_.cuda_compute_capability().IsAtLeastHopper() &&
+      !should_optimize_for_occupancy_) {
+    VLOG(5) << "Computing num_warps: Want to use wgmma, so num_warps >= "
+            << kMinWarpsPerCtaForWgmma;
+    return kMinWarpsPerCtaForWgmma;
+  }
+  VLOG(5) << "Computing num_warps: Considering occupancy, so num_warps >= "
+          << kMinWarpsPerCtaForOccupancy;
+  return kMinWarpsPerCtaForOccupancy;
+}
+
 int64_t TritonDotFusionSearchSpace::GetNumResultTiles(
     OutputTile output_tile) const {
   return batch_size_ *
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
index 5c98406658920e..d070f5496ef9cb 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h
@@ -71,6 +71,14 @@ class TritonDotFusionSearchSpace {
     std::string ToString() const { return config.ToString(); }
   };
 
+  // Newer NVIDIA GPUs can achieve good enough occupancy with as
+  // few as 2 warps per Cooperative Thread Array (CTA). See
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+  static constexpr int kMinWarpsPerCtaForOccupancy = 2;
+  // To use Hopper's wgmma instructions, we need at least a single "warp
+  // group" (4 warps) within a CTA to cooperate on a single instruction.
+  /// https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions
+  static constexpr int kMinWarpsPerCtaForWgmma = 4;
   // Approximation on the maximum number of warps we would want to oversubscribe
   // the SMs with to overlap different GPU pipes (memory, tensor core, ALU,
   // special function unit, etc.)
@@ -106,12 +114,17 @@ class TritonDotFusionSearchSpace {
   // splitting the contracting dimension for a given output tile.
   int64_t GetNumResultTiles(OutputTile output_tile) const;
 
+  // Decides if the problem is small enough so it makes sense to trade off
+  // compute for occupancy efficiency.
+  bool ShouldOptimizeForOccupancy() const;
+
+  // Computes the minimum sensible size of the output tile (block_m, block_n).
+  OutputTile GetMinOutputTile() const;
+
+  // Computes the minimum number of warps we want to try using per Cooperative
+  // Thread Array (CTA).
+  int GetMinWarpsPerCta() const;
+
   // Computes how many warps per Cooperative Thread Array (aka. CTA, aka. CUDA
   // block) is reasonable for the given output tile and restrictions on
   // instruction shape.
@@ -171,6 +190,8 @@ class TritonDotFusionSearchSpace {
   // optimal one even though it does not occupy all cores.
   void EliminateLowOccupancyConfigs(std::vector<ConfigWithNotes>& configs);
 
+  // The order of these fields is important: the values of those defined earlier
+  // are used to compute the values of later ones.
   se::DeviceDescription device_description_;
   int64_t contracting_size_;
   int64_t batch_size_;
@@ -180,6 +201,7 @@ class TritonDotFusionSearchSpace {
   int compute_bitwidth_;
   int desired_total_warps_;
   OutputTile max_out_tile_;
+  bool should_optimize_for_occupancy_;
   OutputTile min_out_tile_;
   int min_warps_per_cta_;
   int min_contracting_tile_size_;
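The tile and warp floors above depend only on the architecture and the occupancy decision; a simplified stand-alone restatement (the struct and the two booleans are stand-ins for the real members, not the patch's API):

    struct OutputTile {
      int lhs_dim;
      int rhs_dim;
    };

    OutputTile MinOutputTile(bool is_hopper_or_newer,
                             bool optimize_for_occupancy) {
      constexpr OutputTile kMinSupportedTile = {16, 16};  // mma floor
      constexpr OutputTile kMinWgmmaTile = {64, 16};      // one warp group
      // wgmma needs a 64-wide LHS tile; small problems instead trade compute
      // efficiency for occupancy and keep the smaller mma-compatible floor.
      if (is_hopper_or_newer && !optimize_for_occupancy) return kMinWgmmaTile;
      return kMinSupportedTile;
    }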
diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
index ca1c2464992c1a..e1425791c8df11 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
 #include "xla/hlo/testlib/verified_hlo_module.h"
 #include "xla/service/gpu/matmul_utils.h"
+#include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/device_description.pb.h"
 #include "xla/tsl/platform/statusor.h"
@@ -98,6 +99,8 @@ class DotSearchSpaceTest : public HloHardwareIndependentTestBase {
     device_description_.set_threads_per_block_limit(1024);
     device_description_.set_threads_per_warp(32);
     device_description_.set_shared_memory_per_block_optin(227 * 1024);
+    device_description_.set_gpu_compute_capability(
+        se::CudaComputeCapability::Hopper());
   }
 
   absl::StatusOr<std::unique_ptr<VerifiedHloModule>> GetDefaultDotModule(
@@ -135,7 +138,8 @@ TEST_F(DotSearchSpaceTest, SerializesSearchSpace) {
   EXPECT_EQ(search_space.ToString(),
             "problem_size_BxMxNxKxE: 1x1024x1024x1024x(16->16) "
             "tile_range_SxMxNxK: [1-64]x[16-256]x[16-512]x[16-?] "
-            "desired_total_warps: 2640 warps_per_cta: [4-?]");
+            "desired_total_warps: 2640 occupancy_optimization: 1 "
+            "warps_per_cta: [2-?]");
 }
 
 TEST_F(DotSearchSpaceTest, ReturnsValidConfigList) {
@@ -262,7 +266,7 @@ TEST_F(DotSearchSpaceTest, AssignsEnoughWarpsPerScheduler) {
                              NumWarpsIs(Eq(4)), SplitKIs(Eq(2)))));
 }
 
-TEST_F(DotSearchSpaceTest, DoesNotBreakCTASizeLimits) {
+TEST_F(DotSearchSpaceTest, DoesNotBreakCtaSizeLimits) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           GetDefaultDotModule(/*lhs_parallel_dim=*/1024 * 16,
                                               /*rhs_parallel_dim=*/1024 * 16));
@@ -272,7 +276,7 @@ TEST_F(DotSearchSpaceTest, DoesNotBreakCtaSizeLimits) {
               AllOf(Not(IsEmpty()), Each(NumWarpsIs(Le(32)))));
 }
 
-TEST_F(DotSearchSpaceTest, ConsidersAppropriateCTASizeForTileSize) {
+TEST_F(DotSearchSpaceTest, ConsidersAppropriateCtaSizeForTileSize) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           GetDefaultDotModule(/*lhs_parallel_dim=*/4096,
                                               /*rhs_parallel_dim=*/4096));
@@ -362,5 +366,30 @@ TEST_F(DotSearchSpaceTest, LimitsStagesToAvailableTileSize) {
               NumStagesIs(Le(2)));
 }
 
+TEST_F(DotSearchSpaceTest, ConsidersFewWarpsPerCtaAndMmaForSmallProblem) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<VerifiedHloModule> module,
+      GetDefaultDotModule(/*lhs_parallel_dim=*/128, /*rhs_parallel_dim=*/128,
+                          /*contracting_dim=*/128));
+  TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(
+      search_space.GenerateConfigs(),
+      Contains(AllOf(NumWarpsIs(Eq(2)), BlockMIs(Eq(16)), BlockNIs(Eq(16)))));
+}
+
+TEST_F(DotSearchSpaceTest, EnsuresWgmmaShapeForLargeProblem) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          GetDefaultDotModule(/*lhs_parallel_dim=*/16 * 1024,
+                                              /*rhs_parallel_dim=*/16 * 1024,
+                                              /*contracting_dim=*/4096));
+  TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
+
+  EXPECT_THAT(
+      search_space.GenerateConfigs(),
+      
AllOf(Not(IsEmpty()), Each(AllOf(NumWarpsIs(Ge(4)), BlockMIs(Ge(64)), + BlockNIs(Ge(16)))))); +} + } // namespace } // namespace xla::gpu From c552b70092af6dc18e2ae542974c862b79854caa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 12:24:50 -0700 Subject: [PATCH 0586/1324] Adds a progress bar to stdout for long-running matcher processes to provide visual feedback. PiperOrigin-RevId: 746550923 --- .../hlo_diff/matchers/hlo_gumgraph_matcher.cc | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc index dd5f10256b1d9d..d3cc4a6ca0254c 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc @@ -15,7 +15,10 @@ #include "xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.h" #include -#include +#include +#include +#include +#include #include #include @@ -45,6 +48,18 @@ constexpr double kMetadataOpNameMatchScore = 0.1; constexpr double kMetadataSourceFileMatchScore = 0.1; constexpr double kMetadataSourceLineMatchScore = 0.1; +constexpr int kProgressBarWidth = 60; +constexpr char kProgressBarBlock = '|'; +constexpr char kProgressBarEmpty = ' '; + +void PrintProgress(int percentage) { + int lpad = static_cast(percentage / 100.0 * kProgressBarWidth); + int rpad = kProgressBarWidth - lpad; + std::cout << "\r" << std::setw(3) << percentage << "% [" + << std::string(lpad, kProgressBarBlock) + << std::string(rpad, kProgressBarEmpty) << "]" << std::flush; +} + struct NodePairSimilarity { const HloInstructionNode* left; const HloInstructionNode* right; @@ -419,7 +434,15 @@ void GreedyLimitedCandidatesBottomUpMatcher::Match( int current_mapping_count = mappings.left_to_right_instruction_map.size(); std::vector left_postorder = GetAllNodesInDfsOrder( left_.GetRoot(), DfsTraversalOrder::kPostOrder, left_.GetNodeCount()); - for (const HloInstructionNode* left_node : left_postorder) { + int progress = 0; + int total_steps = left_postorder.size(); + for (size_t i = 0; i < total_steps; ++i) { + const auto* left_node = left_postorder[i]; + int current_progress = static_cast((i * 100.0) / total_steps); + if (current_progress > progress) { + PrintProgress(current_progress); + progress = current_progress; + } // Skip matched nodes or ones without children. if (mappings.InstructionMapContainsLeft(left_node) || left_node->children.empty()) { From a1d3eea72299be14a8a33ef85865b72c6ae3a141 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 12:28:56 -0700 Subject: [PATCH 0587/1324] Integrate LLVM at llvm/llvm-project@98feb05825a1 Updates LLVM usage to match [98feb05825a1](https://github.com/llvm/llvm-project/commit/98feb05825a1) PiperOrigin-RevId: 746552182 --- third_party/llvm/generated.patch | 48 ++ third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 647 ++---------------- third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/shardy/temporary.patch | 647 ++---------------- .../xla/third_party/shardy/workspace.bzl | 4 +- 6 files changed, 164 insertions(+), 1190 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..bbffc2ff4b7cc3 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,49 @@ Auto generated patch. Do not edit or delete it, even if empty. 
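The PrintProgress helper added in the matcher patch above is self-contained; a rough driver showing the intended repaint-on-percent-change behavior (the loop bound is arbitrary, and main() is only for illustration):

    #include <iomanip>
    #include <iostream>
    #include <string>

    constexpr int kProgressBarWidth = 60;
    constexpr char kProgressBarBlock = '|';
    constexpr char kProgressBarEmpty = ' ';

    void PrintProgress(int percentage) {
      int lpad = static_cast<int>(percentage / 100.0 * kProgressBarWidth);
      int rpad = kProgressBarWidth - lpad;
      // '\r' rewinds to the start of the line so the bar repaints in place.
      std::cout << "\r" << std::setw(3) << percentage << "% ["
                << std::string(lpad, kProgressBarBlock)
                << std::string(rpad, kProgressBarEmpty) << "]" << std::flush;
    }

    int main() {
      const int total_steps = 250;
      int progress = 0;
      for (int i = 0; i < total_steps; ++i) {
        // ... do one unit of work here ...
        int current = static_cast<int>((i * 100.0) / total_steps);
        if (current > progress) {  // repaint only when the percentage changes
          PrintProgress(current);
          progress = current;
        }
      }
      std::cout << "\n";
    }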
+diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h +--- a/libc/src/__support/FPUtil/aarch64/sqrt.h ++++ b/libc/src/__support/FPUtil/aarch64/sqrt.h +@@ -18,6 +18,8 @@ + #error "Invalid include" + #endif + ++#include "src/__support/FPUtil/generic/sqrt.h" ++ + namespace LIBC_NAMESPACE_DECL { + namespace fputil { + +diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h +--- a/libc/src/__support/FPUtil/arm/sqrt.h ++++ b/libc/src/__support/FPUtil/arm/sqrt.h +@@ -18,6 +18,8 @@ + #error "Invalid include" + #endif + ++#include "src/__support/FPUtil/generic/sqrt.h" ++ + namespace LIBC_NAMESPACE_DECL { + namespace fputil { + +diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h +--- a/libc/src/__support/FPUtil/riscv/sqrt.h ++++ b/libc/src/__support/FPUtil/riscv/sqrt.h +@@ -18,6 +18,8 @@ + #error "Invalid include" + #endif + ++#include "src/__support/FPUtil/generic/sqrt.h" ++ + namespace LIBC_NAMESPACE_DECL { + namespace fputil { + +diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h +--- a/libc/src/__support/FPUtil/x86_64/sqrt.h ++++ b/libc/src/__support/FPUtil/x86_64/sqrt.h +@@ -18,6 +18,8 @@ + #error "sqrtss / sqrtsd need SSE2" + #endif + ++#include "src/__support/FPUtil/generic/sqrt.h" ++ + namespace LIBC_NAMESPACE_DECL { + namespace fputil { + diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 7993194770a240..0b67d8b3fd140f 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" - LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" + LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" + LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index b5a97068282d17..34a45370f62ef2 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,606 +1,69 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index a3ecef4..509398d 100644 +index 509398d..bbffc2f 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,586 +1 @@ +@@ -1 +1,49 @@ Auto generated patch. Do not edit or delete it, even if empty. 
--diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h ----- a/clang/lib/Sema/TreeTransform.h --+++ b/clang/lib/Sema/TreeTransform.h --@@ -7765,17 +7765,23 @@ -- NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); -- NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); -- --- typedef TemplateArgumentLocContainerIterator< --- DependentTemplateSpecializationTypeLoc> ArgIterator; --- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), --- ArgIterator(TL, TL.getNumArgs()), --- NewTemplateArgs)) --+ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); --+ --+ if (getDerived().TransformTemplateArguments(ArgsRange.begin(), --+ ArgsRange.end(), NewTemplateArgs)) -- return QualType(); --+ bool TemplateArgumentsChanged = !llvm::equal( --+ ArgsRange, NewTemplateArgs.arguments(), --+ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { --+ return A.getArgument().structurallyEquals(B.getArgument()); --+ }); -- -- const DependentTemplateStorage &DTN = T->getDependentTemplateName(); -- -- QualType Result = TL.getType(); --- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { --+ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || --+ TemplateArgumentsChanged) { -- TemplateName Name = getDerived().RebuildTemplateName( -- SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), -- /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, --diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h ----- a/clang/test/CodeGen/include/cuda.h --+++ b/clang/test/CodeGen/include/cuda.h --@@ -1,194 +0,0 @@ ---/* Minimal declarations for CUDA support. Testing purposes only. --- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h --- */ ---#pragma once --- ---// Make this file work with nvcc, for testing compatibility. --- ---#ifndef __NVCC__ ---#define __constant__ __attribute__((constant)) ---#define __device__ __attribute__((device)) ---#define __global__ __attribute__((global)) ---#define __host__ __attribute__((host)) ---#define __shared__ __attribute__((shared)) ---#define __managed__ __attribute__((managed)) ---#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) --- ---struct dim3 { --- unsigned x, y, z; --- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} ---}; --- ---// Host- and device-side placement new overloads. 
---void *operator new(__SIZE_TYPE__, void *p) { return p; } ---void *operator new[](__SIZE_TYPE__, void *p) { return p; } ---__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } ---__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } --- ---#define CUDA_VERSION 10100 --- ---struct char1 { --- char x; --- __host__ __device__ char1(char x = 0) : x(x) {} ---}; ---struct char2 { --- char x, y; --- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} ---}; ---struct char4 { --- char x, y, z, w; --- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct uchar1 { --- unsigned char x; --- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} ---}; ---struct uchar2 { --- unsigned char x, y; --- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} ---}; ---struct uchar4 { --- unsigned char x, y, z, w; --- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct short1 { --- short x; --- __host__ __device__ short1(short x = 0) : x(x) {} ---}; ---struct short2 { --- short x, y; --- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} ---}; ---struct short4 { --- short x, y, z, w; --- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct ushort1 { --- unsigned short x; --- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} ---}; ---struct ushort2 { --- unsigned short x, y; --- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} ---}; ---struct ushort4 { --- unsigned short x, y, z, w; --- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct int1 { --- int x; --- __host__ __device__ int1(int x = 0) : x(x) {} ---}; ---struct int2 { --- int x, y; --- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} ---}; ---struct int4 { --- int x, y, z, w; --- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct uint1 { --- unsigned x; --- __host__ __device__ uint1(unsigned x = 0) : x(x) {} ---}; ---struct uint2 { --- unsigned x, y; --- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} ---}; ---struct uint3 { --- unsigned x, y, z; --- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} ---}; ---struct uint4 { --- unsigned x, y, z, w; --- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct longlong1 { --- long long x; --- __host__ __device__ longlong1(long long x = 0) : x(x) {} ---}; ---struct longlong2 { --- long long x, y; --- __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} ---}; ---struct longlong4 { --- long long x, y, z, w; --- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct ulonglong1 { --- unsigned long long x; --- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} ---}; ---struct ulonglong2 { --- unsigned long long x, y; --- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} ---}; ---struct ulonglong4 { --- 
unsigned long long x, y, z, w; --- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct float1 { --- float x; --- __host__ __device__ float1(float x = 0) : x(x) {} ---}; ---struct float2 { --- float x, y; --- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} ---}; ---struct float4 { --- float x, y, z, w; --- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct double1 { --- double x; --- __host__ __device__ double1(double x = 0) : x(x) {} ---}; ---struct double2 { --- double x, y; --- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} ---}; ---struct double4 { --- double x, y, z, w; --- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---typedef unsigned long long cudaTextureObject_t; ---typedef unsigned long long cudaSurfaceObject_t; --- ---enum cudaTextureReadMode { --- cudaReadModeNormalizedFloat, --- cudaReadModeElementType ---}; --- ---enum cudaSurfaceBoundaryMode { --- cudaBoundaryModeZero, --- cudaBoundaryModeClamp, --- cudaBoundaryModeTrap ---}; --- ---enum { --- cudaTextureType1D, --- cudaTextureType2D, --- cudaTextureType3D, --- cudaTextureTypeCubemap, --- cudaTextureType1DLayered, --- cudaTextureType2DLayered, --- cudaTextureTypeCubemapLayered ---}; --- ---struct textureReference {}; ---template ---struct __attribute__((device_builtin_texture_type)) texture --- : public textureReference {}; --- ---#endif // !__NVCC__ --diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h ----- a/clang/test/CodeGen/Inputs/cuda.h --+++ b/clang/test/CodeGen/Inputs/cuda.h --@@ -0,0 +1,194 @@ --+/* Minimal declarations for CUDA support. Testing purposes only. --+ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h --+ */ --+#pragma once --+ --+// Make this file work with nvcc, for testing compatibility. --+ --+#ifndef __NVCC__ --+#define __constant__ __attribute__((constant)) --+#define __device__ __attribute__((device)) --+#define __global__ __attribute__((global)) --+#define __host__ __attribute__((host)) --+#define __shared__ __attribute__((shared)) --+#define __managed__ __attribute__((managed)) --+#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) --+ --+struct dim3 { --+ unsigned x, y, z; --+ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} --+}; --+ --+// Host- and device-side placement new overloads. 
--+void *operator new(__SIZE_TYPE__, void *p) { return p; } --+void *operator new[](__SIZE_TYPE__, void *p) { return p; } --+__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } --+__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } --+ --+#define CUDA_VERSION 10100 --+ --+struct char1 { --+ char x; --+ __host__ __device__ char1(char x = 0) : x(x) {} --+}; --+struct char2 { --+ char x, y; --+ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} --+}; --+struct char4 { --+ char x, y, z, w; --+ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct uchar1 { --+ unsigned char x; --+ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} --+}; --+struct uchar2 { --+ unsigned char x, y; --+ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} --+}; --+struct uchar4 { --+ unsigned char x, y, z, w; --+ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct short1 { --+ short x; --+ __host__ __device__ short1(short x = 0) : x(x) {} --+}; --+struct short2 { --+ short x, y; --+ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} --+}; --+struct short4 { --+ short x, y, z, w; --+ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct ushort1 { --+ unsigned short x; --+ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} --+}; --+struct ushort2 { --+ unsigned short x, y; --+ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} --+}; --+struct ushort4 { --+ unsigned short x, y, z, w; --+ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct int1 { --+ int x; --+ __host__ __device__ int1(int x = 0) : x(x) {} --+}; --+struct int2 { --+ int x, y; --+ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} --+}; --+struct int4 { --+ int x, y, z, w; --+ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct uint1 { --+ unsigned x; --+ __host__ __device__ uint1(unsigned x = 0) : x(x) {} --+}; --+struct uint2 { --+ unsigned x, y; --+ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} --+}; --+struct uint3 { --+ unsigned x, y, z; --+ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} --+}; --+struct uint4 { --+ unsigned x, y, z, w; --+ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct longlong1 { --+ long long x; --+ __host__ __device__ longlong1(long long x = 0) : x(x) {} --+}; --+struct longlong2 { --+ long long x, y; --+ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} --+}; --+struct longlong4 { --+ long long x, y, z, w; --+ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct ulonglong1 { --+ unsigned long long x; --+ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} --+}; --+struct ulonglong2 { --+ unsigned long long x, y; --+ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} --+}; --+struct ulonglong4 { --+ 
unsigned long long x, y, z, w; --+ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct float1 { --+ float x; --+ __host__ __device__ float1(float x = 0) : x(x) {} --+}; --+struct float2 { --+ float x, y; --+ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} --+}; --+struct float4 { --+ float x, y, z, w; --+ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct double1 { --+ double x; --+ __host__ __device__ double1(double x = 0) : x(x) {} --+}; --+struct double2 { --+ double x, y; --+ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} --+}; --+struct double4 { --+ double x, y, z, w; --+ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+typedef unsigned long long cudaTextureObject_t; --+typedef unsigned long long cudaSurfaceObject_t; --+ --+enum cudaTextureReadMode { --+ cudaReadModeNormalizedFloat, --+ cudaReadModeElementType --+}; --+ --+enum cudaSurfaceBoundaryMode { --+ cudaBoundaryModeZero, --+ cudaBoundaryModeClamp, --+ cudaBoundaryModeTrap --+}; --+ --+enum { --+ cudaTextureType1D, --+ cudaTextureType2D, --+ cudaTextureType3D, --+ cudaTextureTypeCubemap, --+ cudaTextureType1DLayered, --+ cudaTextureType2DLayered, --+ cudaTextureTypeCubemapLayered --+}; --+ --+struct textureReference {}; --+template --+struct __attribute__((device_builtin_texture_type)) texture --+ : public textureReference {}; --+ --+#endif // !__NVCC__ --diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu ----- a/clang/test/CodeGen/nvptx-surface.cu --+++ b/clang/test/CodeGen/nvptx-surface.cu --@@ -1,6 +1,6 @@ -- // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s -- // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s ---#include "include/cuda.h" --+#include "Inputs/cuda.h" -- -- #include "__clang_cuda_texture_intrinsics.h" -- --diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp ----- a/clang/test/SemaTemplate/dependent-names.cpp --+++ b/clang/test/SemaTemplate/dependent-names.cpp --@@ -458,3 +458,12 @@ -- }; -- int f(b ba) { return ba.add<0>(); } -- } --+ --+namespace TransformDependentTemplates { --+ template struct Test1 { --+ template --+ using Arg = typename T::template Arg; --+ void f(Arg); --+ void f(Arg); --+ }; --+} // namespace TransformDependentTemplates --diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ----- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --@@ -15391,12 +15391,20 @@ -- -- if (E->State == TreeEntry::SplitVectorize) { -- Res = FindLastInst(); --+ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { --+ for (auto *E : Entries) { --+ auto *I = dyn_cast_or_null(E->VectorizedValue); --+ if (!I) --+ I = &getLastInstructionInBundle(E); --+ if (Res->comesBefore(I)) --+ Res = I; --+ } --+ } -- return *Res; -- } -- -- // Set insertpoint for gathered loads to the very first load. 
--- if (E->State != TreeEntry::SplitVectorize && --- GatheredLoadsEntriesFirst.has_value() && --+ if (GatheredLoadsEntriesFirst.has_value() && -- E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && -- E->getOpcode() == Instruction::Load) { -- Res = FindFirstInst(); --diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll ----- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll --+++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll --@@ -0,0 +1,99 @@ --+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 --+; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s --+ --+define void @test(ptr %0, <8 x i8> %1) { --+; CHECK-LABEL: define void @test( --+; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { --+; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 --+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 --+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 --+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 --+; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 --+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> --+; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 --+; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> --+; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> --+; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) --+; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 --+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> --+; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) --+; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) --+; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] --+; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 --+; CHECK-NEXT: ret void --+; --+ %3 = load i8, ptr %0, align 2 --+ %4 = getelementptr i8, ptr %0, i64 13442 --+ %5 = load i8, ptr %4, align 2 --+ %6 = or i8 %5, %3 --+ %7 = getelementptr i8, ptr %0, i64 13550 --+ store i8 %6, ptr %7, align 2 --+ %8 = extractelement <8 x i8> %1, i64 0 --+ %9 = or i8 %5, %8 --+ %10 = getelementptr i8, ptr %0, i64 13542 --+ store i8 %9, ptr %10, align 2 --+ %11 = getelementptr i8, ptr %0, i64 13438 --+ %12 = load i8, ptr %11, align 2 --+ %13 = or i8 %12, %3 --+ %14 = getelementptr i8, ptr %0, i64 13546 --+ store i8 %13, ptr %14, align 2 --+ %15 = extractelement <8 x i8> %1, i64 2 --+ %16 = or i8 %12, %15 --+ %17 = getelementptr i8, ptr %0, i64 13538 --+ store i8 %16, ptr %17, align 2 --+ %18 = getelementptr i8, ptr %0, i64 13440 --+ %19 = load i8, ptr %18, align 4 --+ %20 = or i8 %19, %3 --+ %21 = getelementptr i8, ptr %0, i64 13548 --+ store i8 %20, ptr %21, align 4 --+ %22 = extractelement <8 x i8> %1, i64 4 --+ %23 = or i8 %19, %22 --+ %24 = getelementptr i8, ptr %0, i64 13540 --+ store i8 %23, ptr %24, align 4 --+ %25 = getelementptr i8, ptr %0, i64 13436 --+ %26 = load i8, ptr %25, align 4 --+ %27 = getelementptr i8, ptr %0, i64 
13444 --+ %28 = load i8, ptr %27, align 4 --+ %29 = or i8 %28, %26 --+ %30 = getelementptr i8, ptr %0, i64 13544 --+ store i8 %29, ptr %30, align 4 --+ %31 = or i8 %26, %8 --+ %32 = getelementptr i8, ptr %0, i64 13536 --+ store i8 %31, ptr %32, align 4 --+ %33 = getelementptr i8, ptr %0, i64 13443 --+ %34 = load i8, ptr %33, align 1 --+ %35 = or i8 %34, %3 --+ %36 = getelementptr i8, ptr %0, i64 13551 --+ store i8 %35, ptr %36, align 1 --+ %37 = extractelement <8 x i8> %1, i64 7 --+ %38 = or i8 %34, %37 --+ %39 = getelementptr i8, ptr %0, i64 13543 --+ store i8 %38, ptr %39, align 1 --+ %40 = getelementptr i8, ptr %0, i64 13439 --+ %41 = load i8, ptr %40, align 1 --+ %42 = or i8 %41, %3 --+ %43 = getelementptr i8, ptr %0, i64 13547 --+ store i8 %42, ptr %43, align 1 --+ %44 = extractelement <8 x i8> %1, i64 3 --+ %45 = or i8 %41, %44 --+ %46 = getelementptr i8, ptr %0, i64 13539 --+ store i8 %45, ptr %46, align 1 --+ %47 = getelementptr i8, ptr %0, i64 13441 --+ %48 = load i8, ptr %47, align 1 --+ %49 = or i8 %48, %3 --+ %50 = getelementptr i8, ptr %0, i64 13549 --+ store i8 %49, ptr %50, align 1 --+ %51 = extractelement <8 x i8> %1, i64 5 --+ %52 = or i8 %48, %51 --+ %53 = getelementptr i8, ptr %0, i64 13541 --+ store i8 %52, ptr %53, align 1 --+ %54 = getelementptr i8, ptr %0, i64 13437 --+ %55 = load i8, ptr %54, align 1 --+ %56 = or i8 %55, %3 --+ %57 = getelementptr i8, ptr %0, i64 13545 --+ store i8 %56, ptr %57, align 1 --+ %58 = or i8 %55, %8 --+ %59 = getelementptr i8, ptr %0, i64 13537 --+ store i8 %58, ptr %59, align 1 --+ ret void --+} ++diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h ++--- a/libc/src/__support/FPUtil/aarch64/sqrt.h +++++ b/libc/src/__support/FPUtil/aarch64/sqrt.h ++@@ -18,6 +18,8 @@ ++ #error "Invalid include" ++ #endif ++ +++#include "src/__support/FPUtil/generic/sqrt.h" +++ ++ namespace LIBC_NAMESPACE_DECL { ++ namespace fputil { ++ ++diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h ++--- a/libc/src/__support/FPUtil/arm/sqrt.h +++++ b/libc/src/__support/FPUtil/arm/sqrt.h ++@@ -18,6 +18,8 @@ ++ #error "Invalid include" ++ #endif ++ +++#include "src/__support/FPUtil/generic/sqrt.h" +++ ++ namespace LIBC_NAMESPACE_DECL { ++ namespace fputil { ++ ++diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h ++--- a/libc/src/__support/FPUtil/riscv/sqrt.h +++++ b/libc/src/__support/FPUtil/riscv/sqrt.h ++@@ -18,6 +18,8 @@ ++ #error "Invalid include" ++ #endif ++ +++#include "src/__support/FPUtil/generic/sqrt.h" +++ ++ namespace LIBC_NAMESPACE_DECL { ++ namespace fputil { ++ ++diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h ++--- a/libc/src/__support/FPUtil/x86_64/sqrt.h +++++ b/libc/src/__support/FPUtil/x86_64/sqrt.h ++@@ -18,6 +18,8 @@ ++ #error "sqrtss / sqrtsd need SSE2" ++ #endif ++ +++#include "src/__support/FPUtil/generic/sqrt.h" +++ ++ namespace LIBC_NAMESPACE_DECL { ++ namespace fputil { ++ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 73450ce..7993194 100644 +index 7993194..0b67d8b 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" -- LLVM_SHA256 = 
"4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" -+ LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" -+ LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" +- LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" +- LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" ++ LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" ++ LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 503b82b5d33179..1e0b188c3f5b28 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "2bd86e4ef697536b0683149a93022e21d8d5e6d3" - SHARDY_SHA256 = "a3b3672c72cadd8cafd837d7da219cebf97c5312e545ed1ebc639e71a47b60e5" + SHARDY_COMMIT = "0d88b5d25971bd66272195ceeb2288cde72997d0" + SHARDY_SHA256 = "e2cb1a9d409c49c724739e77156e7ca69b51b68e07e6017f149769f6fdafed42" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index b5a97068282d17..34a45370f62ef2 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,606 +1,69 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index a3ecef4..509398d 100644 +index 509398d..bbffc2f 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,586 +1 @@ +@@ -1 +1,49 @@ Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h ----- a/clang/lib/Sema/TreeTransform.h --+++ b/clang/lib/Sema/TreeTransform.h --@@ -7765,17 +7765,23 @@ -- NewTemplateArgs.setLAngleLoc(TL.getLAngleLoc()); -- NewTemplateArgs.setRAngleLoc(TL.getRAngleLoc()); -- --- typedef TemplateArgumentLocContainerIterator< --- DependentTemplateSpecializationTypeLoc> ArgIterator; --- if (getDerived().TransformTemplateArguments(ArgIterator(TL, 0), --- ArgIterator(TL, TL.getNumArgs()), --- NewTemplateArgs)) --+ auto ArgsRange = llvm::make_range>({TL, 0}, {TL, TL.getNumArgs()}); --+ --+ if (getDerived().TransformTemplateArguments(ArgsRange.begin(), --+ ArgsRange.end(), NewTemplateArgs)) -- return QualType(); --+ bool TemplateArgumentsChanged = !llvm::equal( --+ ArgsRange, NewTemplateArgs.arguments(), --+ [](const TemplateArgumentLoc &A, const TemplateArgumentLoc &B) { --+ return A.getArgument().structurallyEquals(B.getArgument()); --+ }); -- -- const DependentTemplateStorage &DTN = T->getDependentTemplateName(); -- -- QualType Result = TL.getType(); --- if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier()) { --+ if (getDerived().AlwaysRebuild() || SS.getScopeRep() != DTN.getQualifier() || --+ TemplateArgumentsChanged) { -- TemplateName Name = getDerived().RebuildTemplateName( -- SS, TL.getTemplateKeywordLoc(), DTN.getName(), TL.getTemplateNameLoc(), -- /*ObjectType=*/QualType(), /*FirstQualifierInScope=*/nullptr, --diff -ruN --strip-trailing-cr a/clang/test/CodeGen/include/cuda.h b/clang/test/CodeGen/include/cuda.h ----- a/clang/test/CodeGen/include/cuda.h --+++ b/clang/test/CodeGen/include/cuda.h --@@ -1,194 +0,0 @@ ---/* Minimal declarations for CUDA support. Testing purposes only. 
--- * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h --- */ ---#pragma once --- ---// Make this file work with nvcc, for testing compatibility. --- ---#ifndef __NVCC__ ---#define __constant__ __attribute__((constant)) ---#define __device__ __attribute__((device)) ---#define __global__ __attribute__((global)) ---#define __host__ __attribute__((host)) ---#define __shared__ __attribute__((shared)) ---#define __managed__ __attribute__((managed)) ---#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) --- ---struct dim3 { --- unsigned x, y, z; --- __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} ---}; --- ---// Host- and device-side placement new overloads. ---void *operator new(__SIZE_TYPE__, void *p) { return p; } ---void *operator new[](__SIZE_TYPE__, void *p) { return p; } ---__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } ---__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } --- ---#define CUDA_VERSION 10100 --- ---struct char1 { --- char x; --- __host__ __device__ char1(char x = 0) : x(x) {} ---}; ---struct char2 { --- char x, y; --- __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} ---}; ---struct char4 { --- char x, y, z, w; --- __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct uchar1 { --- unsigned char x; --- __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} ---}; ---struct uchar2 { --- unsigned char x, y; --- __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} ---}; ---struct uchar4 { --- unsigned char x, y, z, w; --- __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct short1 { --- short x; --- __host__ __device__ short1(short x = 0) : x(x) {} ---}; ---struct short2 { --- short x, y; --- __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} ---}; ---struct short4 { --- short x, y, z, w; --- __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct ushort1 { --- unsigned short x; --- __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} ---}; ---struct ushort2 { --- unsigned short x, y; --- __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} ---}; ---struct ushort4 { --- unsigned short x, y, z, w; --- __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct int1 { --- int x; --- __host__ __device__ int1(int x = 0) : x(x) {} ---}; ---struct int2 { --- int x, y; --- __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} ---}; ---struct int4 { --- int x, y, z, w; --- __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct uint1 { --- unsigned x; --- __host__ __device__ uint1(unsigned x = 0) : x(x) {} ---}; ---struct uint2 { --- unsigned x, y; --- __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} ---}; ---struct uint3 { --- unsigned x, y, z; --- __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} ---}; ---struct uint4 { --- unsigned x, y, z, w; --- __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} ---}; 
--- ---struct longlong1 { --- long long x; --- __host__ __device__ longlong1(long long x = 0) : x(x) {} ---}; ---struct longlong2 { --- long long x, y; --- __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} ---}; ---struct longlong4 { --- long long x, y, z, w; --- __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct ulonglong1 { --- unsigned long long x; --- __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} ---}; ---struct ulonglong2 { --- unsigned long long x, y; --- __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} ---}; ---struct ulonglong4 { --- unsigned long long x, y, z, w; --- __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct float1 { --- float x; --- __host__ __device__ float1(float x = 0) : x(x) {} ---}; ---struct float2 { --- float x, y; --- __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} ---}; ---struct float4 { --- float x, y, z, w; --- __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---struct double1 { --- double x; --- __host__ __device__ double1(double x = 0) : x(x) {} ---}; ---struct double2 { --- double x, y; --- __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} ---}; ---struct double4 { --- double x, y, z, w; --- __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} ---}; --- ---typedef unsigned long long cudaTextureObject_t; ---typedef unsigned long long cudaSurfaceObject_t; --- ---enum cudaTextureReadMode { --- cudaReadModeNormalizedFloat, --- cudaReadModeElementType ---}; --- ---enum cudaSurfaceBoundaryMode { --- cudaBoundaryModeZero, --- cudaBoundaryModeClamp, --- cudaBoundaryModeTrap ---}; --- ---enum { --- cudaTextureType1D, --- cudaTextureType2D, --- cudaTextureType3D, --- cudaTextureTypeCubemap, --- cudaTextureType1DLayered, --- cudaTextureType2DLayered, --- cudaTextureTypeCubemapLayered ---}; --- ---struct textureReference {}; ---template ---struct __attribute__((device_builtin_texture_type)) texture --- : public textureReference {}; --- ---#endif // !__NVCC__ --diff -ruN --strip-trailing-cr a/clang/test/CodeGen/Inputs/cuda.h b/clang/test/CodeGen/Inputs/cuda.h ----- a/clang/test/CodeGen/Inputs/cuda.h --+++ b/clang/test/CodeGen/Inputs/cuda.h --@@ -0,0 +1,194 @@ --+/* Minimal declarations for CUDA support. Testing purposes only. --+ * This should stay in sync with clang/test/Headers/Inputs/include/cuda.h --+ */ --+#pragma once --+ --+// Make this file work with nvcc, for testing compatibility. --+ --+#ifndef __NVCC__ --+#define __constant__ __attribute__((constant)) --+#define __device__ __attribute__((device)) --+#define __global__ __attribute__((global)) --+#define __host__ __attribute__((host)) --+#define __shared__ __attribute__((shared)) --+#define __managed__ __attribute__((managed)) --+#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__))) --+ --+struct dim3 { --+ unsigned x, y, z; --+ __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {} --+}; --+ --+// Host- and device-side placement new overloads. 
--+void *operator new(__SIZE_TYPE__, void *p) { return p; } --+void *operator new[](__SIZE_TYPE__, void *p) { return p; } --+__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } --+__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } --+ --+#define CUDA_VERSION 10100 --+ --+struct char1 { --+ char x; --+ __host__ __device__ char1(char x = 0) : x(x) {} --+}; --+struct char2 { --+ char x, y; --+ __host__ __device__ char2(char x = 0, char y = 0) : x(x), y(y) {} --+}; --+struct char4 { --+ char x, y, z, w; --+ __host__ __device__ char4(char x = 0, char y = 0, char z = 0, char w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct uchar1 { --+ unsigned char x; --+ __host__ __device__ uchar1(unsigned char x = 0) : x(x) {} --+}; --+struct uchar2 { --+ unsigned char x, y; --+ __host__ __device__ uchar2(unsigned char x = 0, unsigned char y = 0) : x(x), y(y) {} --+}; --+struct uchar4 { --+ unsigned char x, y, z, w; --+ __host__ __device__ uchar4(unsigned char x = 0, unsigned char y = 0, unsigned char z = 0, unsigned char w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct short1 { --+ short x; --+ __host__ __device__ short1(short x = 0) : x(x) {} --+}; --+struct short2 { --+ short x, y; --+ __host__ __device__ short2(short x = 0, short y = 0) : x(x), y(y) {} --+}; --+struct short4 { --+ short x, y, z, w; --+ __host__ __device__ short4(short x = 0, short y = 0, short z = 0, short w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct ushort1 { --+ unsigned short x; --+ __host__ __device__ ushort1(unsigned short x = 0) : x(x) {} --+}; --+struct ushort2 { --+ unsigned short x, y; --+ __host__ __device__ ushort2(unsigned short x = 0, unsigned short y = 0) : x(x), y(y) {} --+}; --+struct ushort4 { --+ unsigned short x, y, z, w; --+ __host__ __device__ ushort4(unsigned short x = 0, unsigned short y = 0, unsigned short z = 0, unsigned short w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct int1 { --+ int x; --+ __host__ __device__ int1(int x = 0) : x(x) {} --+}; --+struct int2 { --+ int x, y; --+ __host__ __device__ int2(int x = 0, int y = 0) : x(x), y(y) {} --+}; --+struct int4 { --+ int x, y, z, w; --+ __host__ __device__ int4(int x = 0, int y = 0, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct uint1 { --+ unsigned x; --+ __host__ __device__ uint1(unsigned x = 0) : x(x) {} --+}; --+struct uint2 { --+ unsigned x, y; --+ __host__ __device__ uint2(unsigned x = 0, unsigned y = 0) : x(x), y(y) {} --+}; --+struct uint3 { --+ unsigned x, y, z; --+ __host__ __device__ uint3(unsigned x = 0, unsigned y = 0, unsigned z = 0) : x(x), y(y), z(z) {} --+}; --+struct uint4 { --+ unsigned x, y, z, w; --+ __host__ __device__ uint4(unsigned x = 0, unsigned y = 0, unsigned z = 0, unsigned w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct longlong1 { --+ long long x; --+ __host__ __device__ longlong1(long long x = 0) : x(x) {} --+}; --+struct longlong2 { --+ long long x, y; --+ __host__ __device__ longlong2(long long x = 0, long long y = 0) : x(x), y(y) {} --+}; --+struct longlong4 { --+ long long x, y, z, w; --+ __host__ __device__ longlong4(long long x = 0, long long y = 0, long long z = 0, long long w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct ulonglong1 { --+ unsigned long long x; --+ __host__ __device__ ulonglong1(unsigned long long x = 0) : x(x) {} --+}; --+struct ulonglong2 { --+ unsigned long long x, y; --+ __host__ __device__ ulonglong2(unsigned long long x = 0, unsigned long long y = 0) : x(x), y(y) {} --+}; --+struct ulonglong4 { --+ 
unsigned long long x, y, z, w; --+ __host__ __device__ ulonglong4(unsigned long long x = 0, unsigned long long y = 0, unsigned long long z = 0, unsigned long long w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct float1 { --+ float x; --+ __host__ __device__ float1(float x = 0) : x(x) {} --+}; --+struct float2 { --+ float x, y; --+ __host__ __device__ float2(float x = 0, float y = 0) : x(x), y(y) {} --+}; --+struct float4 { --+ float x, y, z, w; --+ __host__ __device__ float4(float x = 0, float y = 0, float z = 0, float w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+struct double1 { --+ double x; --+ __host__ __device__ double1(double x = 0) : x(x) {} --+}; --+struct double2 { --+ double x, y; --+ __host__ __device__ double2(double x = 0, double y = 0) : x(x), y(y) {} --+}; --+struct double4 { --+ double x, y, z, w; --+ __host__ __device__ double4(double x = 0, double y = 0, double z = 0, double w = 0) : x(x), y(y), z(z), w(w) {} --+}; --+ --+typedef unsigned long long cudaTextureObject_t; --+typedef unsigned long long cudaSurfaceObject_t; --+ --+enum cudaTextureReadMode { --+ cudaReadModeNormalizedFloat, --+ cudaReadModeElementType --+}; --+ --+enum cudaSurfaceBoundaryMode { --+ cudaBoundaryModeZero, --+ cudaBoundaryModeClamp, --+ cudaBoundaryModeTrap --+}; --+ --+enum { --+ cudaTextureType1D, --+ cudaTextureType2D, --+ cudaTextureType3D, --+ cudaTextureTypeCubemap, --+ cudaTextureType1DLayered, --+ cudaTextureType2DLayered, --+ cudaTextureTypeCubemapLayered --+}; --+ --+struct textureReference {}; --+template --+struct __attribute__((device_builtin_texture_type)) texture --+ : public textureReference {}; --+ --+#endif // !__NVCC__ --diff -ruN --strip-trailing-cr a/clang/test/CodeGen/nvptx-surface.cu b/clang/test/CodeGen/nvptx-surface.cu ----- a/clang/test/CodeGen/nvptx-surface.cu --+++ b/clang/test/CodeGen/nvptx-surface.cu --@@ -1,6 +1,6 @@ -- // RUN: %clang_cc1 -triple nvptx-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s -- // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -fcuda-is-device -O3 -o - %s -emit-llvm | FileCheck %s ---#include "include/cuda.h" --+#include "Inputs/cuda.h" -- -- #include "__clang_cuda_texture_intrinsics.h" -- --diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-names.cpp b/clang/test/SemaTemplate/dependent-names.cpp ----- a/clang/test/SemaTemplate/dependent-names.cpp --+++ b/clang/test/SemaTemplate/dependent-names.cpp --@@ -458,3 +458,12 @@ -- }; -- int f(b ba) { return ba.add<0>(); } -- } --+ --+namespace TransformDependentTemplates { --+ template struct Test1 { --+ template --+ using Arg = typename T::template Arg; --+ void f(Arg); --+ void f(Arg); --+ }; --+} // namespace TransformDependentTemplates --diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ----- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --@@ -15391,12 +15391,20 @@ -- -- if (E->State == TreeEntry::SplitVectorize) { -- Res = FindLastInst(); --+ if (ArrayRef Entries = getTreeEntries(Res); !Entries.empty()) { --+ for (auto *E : Entries) { --+ auto *I = dyn_cast_or_null(E->VectorizedValue); --+ if (!I) --+ I = &getLastInstructionInBundle(E); --+ if (Res->comesBefore(I)) --+ Res = I; --+ } --+ } -- return *Res; -- } -- -- // Set insertpoint for gathered loads to the very first load. 
--- if (E->State != TreeEntry::SplitVectorize && --- GatheredLoadsEntriesFirst.has_value() && --+ if (GatheredLoadsEntriesFirst.has_value() && -- E->Idx >= *GatheredLoadsEntriesFirst && !E->isGather() && -- E->getOpcode() == Instruction::Load) { -- Res = FindFirstInst(); --diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll ----- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll --+++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-last-inst-vectorized.ll --@@ -0,0 +1,99 @@ --+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 --+; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s --+ --+define void @test(ptr %0, <8 x i8> %1) { --+; CHECK-LABEL: define void @test( --+; CHECK-SAME: ptr [[TMP0:%.*]], <8 x i8> [[TMP1:%.*]]) { --+; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[TMP0]], align 2 --+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13436 --+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13536 --+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP0]], i64 13437 --+; CHECK-NEXT: [[TMP7:%.*]] = load <8 x i8>, ptr [[TMP4]], align 4 --+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[TMP1]], <8 x i8> poison, <8 x i32> --+; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x i8> [[TMP7]], i8 [[TMP3]], i32 1 --+; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[TMP9]], <8 x i8> poison, <8 x i32> --+; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[TMP8]], <8 x i8> poison, <16 x i32> --+; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP11]], <8 x i8> [[TMP10]], i64 8) --+; CHECK-NEXT: [[TMP13:%.*]] = load <8 x i8>, ptr [[TMP6]], align 1 --+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> poison, <8 x i32> --+; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[TMP7]], i64 0) --+; CHECK-NEXT: [[TMP16:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> [[TMP15]], <8 x i8> [[TMP14]], i64 8) --+; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i8> [[TMP16]], [[TMP12]] --+; CHECK-NEXT: store <16 x i8> [[TMP17]], ptr [[TMP5]], align 4 --+; CHECK-NEXT: ret void --+; --+ %3 = load i8, ptr %0, align 2 --+ %4 = getelementptr i8, ptr %0, i64 13442 --+ %5 = load i8, ptr %4, align 2 --+ %6 = or i8 %5, %3 --+ %7 = getelementptr i8, ptr %0, i64 13550 --+ store i8 %6, ptr %7, align 2 --+ %8 = extractelement <8 x i8> %1, i64 0 --+ %9 = or i8 %5, %8 --+ %10 = getelementptr i8, ptr %0, i64 13542 --+ store i8 %9, ptr %10, align 2 --+ %11 = getelementptr i8, ptr %0, i64 13438 --+ %12 = load i8, ptr %11, align 2 --+ %13 = or i8 %12, %3 --+ %14 = getelementptr i8, ptr %0, i64 13546 --+ store i8 %13, ptr %14, align 2 --+ %15 = extractelement <8 x i8> %1, i64 2 --+ %16 = or i8 %12, %15 --+ %17 = getelementptr i8, ptr %0, i64 13538 --+ store i8 %16, ptr %17, align 2 --+ %18 = getelementptr i8, ptr %0, i64 13440 --+ %19 = load i8, ptr %18, align 4 --+ %20 = or i8 %19, %3 --+ %21 = getelementptr i8, ptr %0, i64 13548 --+ store i8 %20, ptr %21, align 4 --+ %22 = extractelement <8 x i8> %1, i64 4 --+ %23 = or i8 %19, %22 --+ %24 = getelementptr i8, ptr %0, i64 13540 --+ store i8 %23, ptr %24, align 4 --+ %25 = getelementptr i8, ptr %0, i64 13436 --+ %26 = load i8, ptr %25, align 4 --+ %27 = getelementptr i8, ptr %0, i64 
13444 --+ %28 = load i8, ptr %27, align 4 --+ %29 = or i8 %28, %26 --+ %30 = getelementptr i8, ptr %0, i64 13544 --+ store i8 %29, ptr %30, align 4 --+ %31 = or i8 %26, %8 --+ %32 = getelementptr i8, ptr %0, i64 13536 --+ store i8 %31, ptr %32, align 4 --+ %33 = getelementptr i8, ptr %0, i64 13443 --+ %34 = load i8, ptr %33, align 1 --+ %35 = or i8 %34, %3 --+ %36 = getelementptr i8, ptr %0, i64 13551 --+ store i8 %35, ptr %36, align 1 --+ %37 = extractelement <8 x i8> %1, i64 7 --+ %38 = or i8 %34, %37 --+ %39 = getelementptr i8, ptr %0, i64 13543 --+ store i8 %38, ptr %39, align 1 --+ %40 = getelementptr i8, ptr %0, i64 13439 --+ %41 = load i8, ptr %40, align 1 --+ %42 = or i8 %41, %3 --+ %43 = getelementptr i8, ptr %0, i64 13547 --+ store i8 %42, ptr %43, align 1 --+ %44 = extractelement <8 x i8> %1, i64 3 --+ %45 = or i8 %41, %44 --+ %46 = getelementptr i8, ptr %0, i64 13539 --+ store i8 %45, ptr %46, align 1 --+ %47 = getelementptr i8, ptr %0, i64 13441 --+ %48 = load i8, ptr %47, align 1 --+ %49 = or i8 %48, %3 --+ %50 = getelementptr i8, ptr %0, i64 13549 --+ store i8 %49, ptr %50, align 1 --+ %51 = extractelement <8 x i8> %1, i64 5 --+ %52 = or i8 %48, %51 --+ %53 = getelementptr i8, ptr %0, i64 13541 --+ store i8 %52, ptr %53, align 1 --+ %54 = getelementptr i8, ptr %0, i64 13437 --+ %55 = load i8, ptr %54, align 1 --+ %56 = or i8 %55, %3 --+ %57 = getelementptr i8, ptr %0, i64 13545 --+ store i8 %56, ptr %57, align 1 --+ %58 = or i8 %55, %8 --+ %59 = getelementptr i8, ptr %0, i64 13537 --+ store i8 %58, ptr %59, align 1 --+ ret void --+} ++diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h ++--- a/libc/src/__support/FPUtil/aarch64/sqrt.h +++++ b/libc/src/__support/FPUtil/aarch64/sqrt.h ++@@ -18,6 +18,8 @@ ++ #error "Invalid include" ++ #endif ++ +++#include "src/__support/FPUtil/generic/sqrt.h" +++ ++ namespace LIBC_NAMESPACE_DECL { ++ namespace fputil { ++ ++diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h ++--- a/libc/src/__support/FPUtil/arm/sqrt.h +++++ b/libc/src/__support/FPUtil/arm/sqrt.h ++@@ -18,6 +18,8 @@ ++ #error "Invalid include" ++ #endif ++ +++#include "src/__support/FPUtil/generic/sqrt.h" +++ ++ namespace LIBC_NAMESPACE_DECL { ++ namespace fputil { ++ ++diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h ++--- a/libc/src/__support/FPUtil/riscv/sqrt.h +++++ b/libc/src/__support/FPUtil/riscv/sqrt.h ++@@ -18,6 +18,8 @@ ++ #error "Invalid include" ++ #endif ++ +++#include "src/__support/FPUtil/generic/sqrt.h" +++ ++ namespace LIBC_NAMESPACE_DECL { ++ namespace fputil { ++ ++diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h ++--- a/libc/src/__support/FPUtil/x86_64/sqrt.h +++++ b/libc/src/__support/FPUtil/x86_64/sqrt.h ++@@ -18,6 +18,8 @@ ++ #error "sqrtss / sqrtsd need SSE2" ++ #endif ++ +++#include "src/__support/FPUtil/generic/sqrt.h" +++ ++ namespace LIBC_NAMESPACE_DECL { ++ namespace fputil { ++ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 73450ce..7993194 100644 +index 7993194..0b67d8b 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "f280d60c9839120618da353ab71004be33c4fa53" -- LLVM_SHA256 = 
"4bd04ea868766d48d3aabd666de4c38458ef0c6e074740fc1a82d4ec81efb16d" -+ LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" -+ LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" +- LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" +- LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" ++ LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" ++ LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index 503b82b5d33179..1e0b188c3f5b28 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "2bd86e4ef697536b0683149a93022e21d8d5e6d3" - SHARDY_SHA256 = "a3b3672c72cadd8cafd837d7da219cebf97c5312e545ed1ebc639e71a47b60e5" + SHARDY_COMMIT = "0d88b5d25971bd66272195ceeb2288cde72997d0" + SHARDY_SHA256 = "e2cb1a9d409c49c724739e77156e7ca69b51b68e07e6017f149769f6fdafed42" tf_http_archive( name = "shardy", From 472a50134b7302bc448315ca02062d774ebd6234 Mon Sep 17 00:00:00 2001 From: Robert David Date: Fri, 11 Apr 2025 12:40:14 -0700 Subject: [PATCH 0588/1324] Remove unnecessary `:lib` dependency in `session_options`. PiperOrigin-RevId: 746555792 --- tensorflow/core/BUILD | 1 - tensorflow/core/public/session_options.h | 1 - 2 files changed, 2 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c3d6b934ea232f..2f0ff5e91867f1 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -468,7 +468,6 @@ cc_library( hdrs = ["//tensorflow/core/public:session_options.h"], visibility = ["//visibility:public"], deps = [ - ":lib", ":protos_all_cc", ], ) diff --git a/tensorflow/core/public/session_options.h b/tensorflow/core/public/session_options.h index 92134528dbf975..3335046aa58d16 100644 --- a/tensorflow/core/public/session_options.h +++ b/tensorflow/core/public/session_options.h @@ -18,7 +18,6 @@ limitations under the License. 
#include -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" namespace tsl { From f803a404f95bbea4e8a1a3acb999e35e982c54b2 Mon Sep 17 00:00:00 2001 From: Won Jong Jeon Date: Fri, 11 Apr 2025 13:06:05 -0700 Subject: [PATCH 0589/1324] [mlir][tosa] Fix transpose_conv2d lit test (#91223) Fixes test_transpose_conv2d_outpad in tfl-to-tosa-pipeline.mlir Change-Id: I35315ce9c3a4ab653e61b474b348a6b093468b23 Signed-off-by: Tai Ly Co-authored-by: Tai Ly --- .../compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 1e44952b105e5f..54237fee7abf2f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -86,13 +86,12 @@ func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tens // ----- // CHECK-LABEL: test_transpose_conv2d_outpad -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} -func.func @test_transpose_conv2d_outpad(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x33x33x16xf32> { +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR0]], %[[VAR0]] {acc_type = f32, out_pad = array, stride = array} +func.func @test_transpose_conv2d_outpad(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>) -> tensor<1x33x33x16xf32> { %cst = arith.constant dense<[1, 33, 33, 16]> : tensor<4xi32> %cst_1 = "tfl.no_value"() {value = unit} : () -> none - %0 = "tfl.transpose_conv"(%cst, %cst_0, %arg0, %cst_1) + %0 = "tfl.transpose_conv"(%cst, %arg1, %arg0, %cst_1) {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<16x1x1x8xf32>, tensor<1x32x32x8xf32>, none) -> tensor<1x33x33x16xf32> From 7e060cb1d09359244bf9fa08dfd317aa328cad88 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 13:16:27 -0700 Subject: [PATCH 0590/1324] Avoid copying around large nested vectors using shared_ptr inside HloReplication. PiperOrigin-RevId: 746567634 --- .../hlo/analysis/hlo_replication_analysis.cc | 33 ++++++++++--------- .../hlo/analysis/hlo_replication_analysis.h | 8 +++-- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc index cb8da1d1a2cf75..e75030e8665b50 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc @@ -170,7 +170,7 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( // replica groups must contain every device, the size of the set is the // number of partitions or replicas. 
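// A sketch of the rationale for the HloReplication edits below, restating the
// commit message rather than adding behavior: binding device_sets by const
// reference avoids deep-copying an inner std::vector<int64_t> on every loop
// iteration, and moving device_set_root_per_replica_ behind a
// std::shared_ptr makes copying an HloReplication cheap, since a copy now
// bumps a reference count instead of cloning the nested vectors.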
bool fully_replicated = true; - for (auto device_sets : device_sets_per_replica) { + for (const auto& device_sets : device_sets_per_replica) { fully_replicated &= device_sets.size() == 1 && (*device_sets.begin()).size() == @@ -674,10 +674,12 @@ HloReplicationAnalysis::HloReplication::HloReplication( HloReplicationAnalysis::HloReplication::State state, absl::Span> device_set_root_per_replica) : state_(state), - device_set_root_per_replica_(device_set_root_per_replica.begin(), - device_set_root_per_replica.end()) { + device_set_root_per_replica_( + std::make_shared>>( + device_set_root_per_replica.begin(), + device_set_root_per_replica.end())) { CHECK(state == State::kPartiallyReplicated || - device_set_root_per_replica_.empty()); + device_set_root_per_replica_->empty()); } HloReplicationAnalysis::HloReplication @@ -736,13 +738,13 @@ HloReplicationAnalysis::HloReplication::Merge( bool unique_on_all_devices = true; std::vector>> device_sets_per_replica; - CHECK_EQ(device_set_root_per_replica_.size(), - other.device_set_root_per_replica_.size()); - for (int i = 0; i < device_set_root_per_replica_.size(); ++i) { + CHECK_EQ(device_set_root_per_replica_->size(), + other.device_set_root_per_replica_->size()); + for (int i = 0; i < device_set_root_per_replica_->size(); ++i) { const std::vector& my_device_set_root = - device_set_root_per_replica_[i]; + device_set_root_per_replica_->at(i); const std::vector& other_device_set_root = - other.device_set_root_per_replica_[i]; + other.device_set_root_per_replica_->at(i); absl::flat_hash_map> value_to_device_set; size_t num_devices = my_device_set_root.size(); @@ -782,9 +784,9 @@ bool HloReplicationAnalysis::HloReplication::Equal( if (state_ != other.state_) { return false; } - for (int i = 0; i < device_set_root_per_replica_.size(); ++i) { - if (device_set_root_per_replica_[i] != - other.device_set_root_per_replica_[i]) { + for (int i = 0; i < device_set_root_per_replica_->size(); ++i) { + if (device_set_root_per_replica_->at(i) != + other.device_set_root_per_replica_->at(i)) { return false; } } @@ -808,7 +810,8 @@ bool HloReplicationAnalysis::HloReplication::IsUniqueOnAllDevices() const { bool HloReplicationAnalysis::HloReplication::IsReplicatedWithinSubgroup( absl::Span device_ids) const { if (device_ids.empty()) return true; - for (std::vector device_set_roots : device_set_root_per_replica_) { + for (const std::vector& device_set_roots : + *device_set_root_per_replica_) { if (!absl::c_all_of(device_ids, [&device_ids, &device_set_roots](int device_id) { return device_set_roots[device_id] == @@ -829,12 +832,12 @@ std::string HloReplicationAnalysis::HloReplication::ToString() const { case State::kPartiallyReplicated: std::ostringstream oss; oss << "PartiallyReplicated{"; - for (int k = 0; k < device_set_root_per_replica_.size(); ++k) { + for (int k = 0; k < device_set_root_per_replica_->size(); ++k) { if (k > 0) { oss << ", "; } oss << absl::StrCat( - "{", absl::StrJoin(device_set_root_per_replica_[k], ","), "}"); + "{", absl::StrJoin(device_set_root_per_replica_->at(k), ","), "}"); } oss << "}"; return oss.str(); diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h index 5a9057ebf89cab..21f64e0d0e6ab7 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef XLA_HLO_ANALYSIS_HLO_REPLICATION_ANALYSIS_H_ #define XLA_HLO_ANALYSIS_HLO_REPLICATION_ANALYSIS_H_ +#include +#include #include #include #include @@ -95,7 +97,8 @@ class HloReplicationAnalysis { template friend H AbslHashValue(H h, const HloReplication& r) { - return H::combine(std::move(h), r.state_, r.device_set_root_per_replica_); + return H::combine(std::move(h), r.state_, + *r.device_set_root_per_replica_); } private: @@ -117,7 +120,8 @@ class HloReplicationAnalysis { // If cross_partition_spmd is false, groups_for_replicas_[k]'s size equals // the number of replicas, and within partition k, groups_for_replicas_[k] // maps each replica to the smallest replica ID in the set. - std::vector> device_set_root_per_replica_; + std::shared_ptr>> + device_set_root_per_replica_; }; static HloReplication DetermineHloInstructionIsReplicated( From 05d5cd2ac56b63f939b67c888a63b43e7b2c0656 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 13:25:27 -0700 Subject: [PATCH 0591/1324] Removes request scaling code (unused in production as of cl/601162681). PiperOrigin-RevId: 746570759 --- .../auto_sharding/auto_sharding_solver.cc | 76 ++----------- .../auto_sharding/auto_sharding_solver.h | 4 - .../auto_sharding_solver_test.cc | 102 ------------------ 3 files changed, 11 insertions(+), 171 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index bf017949e6a104..c592d4e86c78dd 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -83,55 +83,6 @@ bool AutoShardingSolverOutput::operator==( is_optimal == other.is_optimal && peak_times == other.peak_times; } -namespace { - -double MaxCoeff( - const tsl::protobuf::RepeatedPtrField& - cost_mat) { - double max_coeff = 0.0; - for (auto& costs : cost_mat) { - for (auto& cost : costs.costs()) { - if (cost < kInfinityCost) { - max_coeff = std::max(max_coeff, cost); - } - } - } - return max_coeff; -} - -void ScaleCoeffs( - double scaling_factor, - tsl::protobuf::RepeatedPtrField* - cost_mat) { - for (auto& costs : *cost_mat) { - for (auto& cost : *costs.mutable_costs()) { - if (cost < kInfinityCost) { - cost = floor(cost * scaling_factor); - } - } - } -} - -} // namespace - -AutoShardingSolverRequest ScaleRequest( - const AutoShardingSolverRequest& request) { - if (!request.has_coeff_limit()) return request; - VLOG(0) << "Scaling request by coefficient limit: " - << request.coeff_limit().coeff(); - double max_coeff = 0.0; - max_coeff = std::max(max_coeff, MaxCoeff(request.communication_costs())); - max_coeff = std::max(max_coeff, MaxCoeff(request.computation_costs())); - max_coeff = std::max(max_coeff, MaxCoeff(request.resharding_costs())); - if (max_coeff <= request.coeff_limit().coeff()) return request; - const double scaling_factor = request.coeff_limit().coeff() / max_coeff; - AutoShardingSolverRequest scaled_request = request; - ScaleCoeffs(scaling_factor, scaled_request.mutable_communication_costs()); - ScaleCoeffs(scaling_factor, scaled_request.mutable_computation_costs()); - ScaleCoeffs(scaling_factor, scaled_request.mutable_resharding_costs()); - return scaled_request; -} - double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { std::vector min_memory_required; if (request.node_intervals().empty()) { // Handles live matrices. 
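// For context, a condensed sketch of the ScaleRequest logic deleted above,
// reconstructed from the removed lines (simplified pseudocode, not code that
// belongs in this file):
//
//   max_coeff = largest finite entry across communication_costs,
//               computation_costs, and resharding_costs;
//   if (max_coeff > request.coeff_limit().coeff()) {
//     s = request.coeff_limit().coeff() / max_coeff;
//     every finite cost c becomes floor(c * s);  // kInfinityCost untouched
//   }
//
// After this commit, both the MIP formulation and RunHeuristicSolver consume
// the request's costs unscaled.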
@@ -577,10 +528,9 @@ void AddMemoryTerms( // is guaranteed to never produce a negative overall cost for the graph, // however. absl::StatusOr FormulateAndSolveMIPFromSolverRequest( - const AutoShardingSolverRequest& unscaled_request, + const AutoShardingSolverRequest& request, std::optional overbudget_coeff) { const absl::Time start_time = absl::Now(); - const AutoShardingSolverRequest request = ScaleRequest(unscaled_request); const size_t num_edges = request.edges_size(); const int num_workers = 32; // SAT or SCIP @@ -922,17 +872,17 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( LOG(ERROR) << write_status.message(); } } - // Exports the *unscaled* solver request proto for debugging. + // Exports the solver request proto for debugging. bool dump_solver_request = false; if (dump_solver_request) { uint64_t solver_request_fprint = - tsl::Fingerprint64(unscaled_request.SerializeAsString()); + tsl::Fingerprint64(request.SerializeAsString()); std::string request_dump_path = - absl::StrCat("/tmp/solver_request_", unscaled_request.request_name(), - "_", solver_request_fprint, ".textproto"); + absl::StrCat("/tmp/solver_request_", request.request_name(), "_", + solver_request_fprint, ".textproto"); auto write_status = file::SetTextProto( // Modify this file path if needed. - request_dump_path, unscaled_request, file::Defaults()); + request_dump_path, request, file::Defaults()); LOG(INFO) << "Dumped solver request to " << request_dump_path; if (!write_status.ok()) { LOG(ERROR) << write_status.message(); @@ -941,7 +891,7 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( // Invokes the solver request callback for any additional debugging. bool solver_request_callback = false; if (solver_request_callback) { - SolverRequestCallback(unscaled_request); + SolverRequestCallback(request); } #endif if (request.enable_output()) { @@ -977,8 +927,8 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( request, s, e, overbudget_var, makespan_var, overbudget_coeff, *solver); if (result.ok()) { const AutoShardingEvaluation evaluation = - Evaluate(unscaled_request, *result, overbudget_coeff); - LOG(INFO) << "*** Total costs for the (unscaled) solver request ***"; + Evaluate(request, *result, overbudget_coeff); + LOG(INFO) << "*** Total costs for the solver request ***"; LOG(INFO) << "Total Communication Cost: " << evaluation.total.communication_cost << " (lower bound: " << evaluation.lower_bound.communication_cost @@ -1156,11 +1106,7 @@ AutoShardingSolverOutput SolveGreedy(const AutoShardingSolverRequest& request, } // namespace absl::StatusOr RunHeuristicSolver( - const AutoShardingSolverRequest& unscaled_request, - const std::string& algorithm) { - // Scale the coefficients in the request in the same way as the MIP solver. 
- AutoShardingSolverRequest request = ScaleRequest(unscaled_request); - + const AutoShardingSolverRequest& request, const std::string& algorithm) { absl::Time start_time = absl::Now(); AutoShardingSolverOutput output; if (algorithm == "trivial") { @@ -1181,7 +1127,7 @@ absl::StatusOr RunHeuristicSolver( LOG(INFO) << "Solver took " << absl::ToInt64Milliseconds(duration) << " ms"; LOG(INFO) << "Objective value: " << output.cost; LOG(INFO) << "Total Cost: " - << ComputeShardingStrategyCost(unscaled_request, output.s_val); + << ComputeShardingStrategyCost(request, output.s_val); return output; } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index a8b9a631807ac8..551f9c7aff111c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -39,10 +39,6 @@ struct AutoShardingSolverOutput { bool operator==(const AutoShardingSolverOutput& other) const; }; -// Scales down values to reduce the range of costs & coefficients in the solver. -AutoShardingSolverRequest ScaleRequest( - const AutoShardingSolverRequest& request); - // Determines the minimum memory budget required to avoid memory violations. double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index ed3abdc7204335..ee28350cfcc6ed 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -858,108 +858,6 @@ TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) { EXPECT_EQ(evaluation, expected_evaluation); } -TEST(ScaleRequest, ScalesProperly) { - AutoShardingSolverRequest unscaled_request; - const CostMatrix c = {{10000000, 11000000, 12000000, 13000000}, - {20000000, 21000000, 22000000}, - {30000000, 31000000, 32000000, 33000000}, - {40000000, 41000000, 42000000, 43000000}, - {50000000, 51000000, 52000000, 53000000}}; - const CostMatrix d = {{100000000, 110000000, 120000000, 130000000}, - {200000000, 210000000, 220000000}, - {300000000, 310000000, 320000000, 330000000}, - {400000000, 410000000, 420000000, 430000000}, - {500000000, 510000000, 520000000}}; - const CostMatrix r = {{1000000000, 1100000000, 1200000000, 1300000000, - 2000000000, 2100000000, 2200000000, 2300000000, - 3000000000, 3100000000, 3200000000, 3300000000, - 4000000000, 4100000000, 4200000000, 4300000000}, - {5000000000, 5100000000, 5200000000, 5300000000, - 6000000000, 6100000000, 6200000000, 6300000000, - 7000000000, 7100000000, 7200000000, 10000000000000}}; - AddCosts(unscaled_request.mutable_computation_costs(), c); - AddCosts(unscaled_request.mutable_communication_costs(), d); - AddCosts(unscaled_request.mutable_resharding_costs(), r); - unscaled_request.mutable_coeff_limit()->set_coeff(1e7); - - AutoShardingSolverRequest request = ScaleRequest(unscaled_request); - - AutoShardingSolverRequest expected_request; - const CostMatrix expected_c = {{10, 11, 12, 13}, - {20, 21, 22}, - {30, 31, 32, 33}, - {40, 41, 42, 43}, - {50, 51, 52, 53}}; - const CostMatrix expected_d = {{100, 110, 120, 130}, - {200, 210, 220}, - {300, 310, 320, 330}, - {400, 410, 420, 430}, - {500, 510, 520}}; - const CostMatrix expected_r = {{1000, 1100, 1200, 1300, 
- 2000, 2100, 2200, 2300, - 3000, 3100, 3200, 3300, - 4000, 4100, 4200, 4300}, - {5000, 5100, 5200, 5300, - 6000, 6100, 6200, 6300, - 7000, 7100, 7200, 10000000}}; - AddCosts(expected_request.mutable_computation_costs(), expected_c); - AddCosts(expected_request.mutable_communication_costs(), expected_d); - AddCosts(expected_request.mutable_resharding_costs(), expected_r); - expected_request.mutable_coeff_limit()->set_coeff(1e7); - EXPECT_THAT(request, ::tsl::proto_testing::EqualsProto(expected_request)); -} - -TEST(ScaleRequest, SkipsScaling) { - AutoShardingSolverRequest unscaled_request; - const CostMatrix c = {{10, 11, 12, 13}, - {20, 21, 22}, - {30, 31, 32, 33}, - {40, 41, 42, 43}, - {50, 51, 52, 53}}; - const CostMatrix d = {{100, 110, 120, 130}, - {200, 210, 220}, - {300, 310, 320, 330}, - {400, 410, 420, 430}, - {500, 510, 520}}; - const CostMatrix r = {{1000, 1100, 1200, 1300, - 2000, 2100, 2200, 2300, - 3000, 3100, 3200, 3300, - 4000, 4100, 4200, 4300}, - {5000, 5100, 5200, 5300, - 6000, 6100, 6200, 6300, - 7000, 7100, 7200, 10000000}}; - AddCosts(unscaled_request.mutable_computation_costs(), c); - AddCosts(unscaled_request.mutable_communication_costs(), d); - AddCosts(unscaled_request.mutable_resharding_costs(), r); - unscaled_request.mutable_coeff_limit()->set_coeff(1e7); - - AutoShardingSolverRequest request = ScaleRequest(unscaled_request); - - AutoShardingSolverRequest expected_request; - const CostMatrix expected_c = {{10, 11, 12, 13}, - {20, 21, 22}, - {30, 31, 32, 33}, - {40, 41, 42, 43}, - {50, 51, 52, 53}}; - const CostMatrix expected_d = {{100, 110, 120, 130}, - {200, 210, 220}, - {300, 310, 320, 330}, - {400, 410, 420, 430}, - {500, 510, 520}}; - const CostMatrix expected_r = {{1000, 1100, 1200, 1300, - 2000, 2100, 2200, 2300, - 3000, 3100, 3200, 3300, - 4000, 4100, 4200, 4300}, - {5000, 5100, 5200, 5300, - 6000, 6100, 6200, 6300, - 7000, 7100, 7200, 10000000}}; - AddCosts(expected_request.mutable_computation_costs(), expected_c); - AddCosts(expected_request.mutable_communication_costs(), expected_d); - AddCosts(expected_request.mutable_resharding_costs(), expected_r); - expected_request.mutable_coeff_limit()->set_coeff(1e7); - EXPECT_THAT(request, ::tsl::proto_testing::EqualsProto(expected_request)); -} - TEST(MinimumMemoryBudgetRequired, HandlesLiveMatrix) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); EXPECT_EQ(MinimumMemoryBudgetRequired(request), 1000000.0); From c64f2796868cb7d5a53e54aca4d5f1295d1a815b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 13:33:54 -0700 Subject: [PATCH 0592/1324] Avoid calling `dimensions()` or `dimensions_size()` on non-array shapes. It's a bug to use an array shape as a tuple or use a tuple shape as an array, for example. To catch such bugs, we will make these methods fail if called on non-array shapes. In order to avoid such failures, we need to make sure the shape is an array before calling `dimensions()` or `dimensions_size()`. 
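A minimal sketch of the guard pattern this change applies is shown below.
`SafeDimensions` is an illustrative helper, not something the commit adds; it
assumes only the `xla::Shape` and `absl::Span` APIs already visible in the
diffs.

    #include <cstdint>

    #include "absl/types/span.h"
    #include "xla/shape.h"

    // Returns the array dimensions, or an empty span for non-array (tuple,
    // token) shapes, mirroring the pjrt_client.h change in this commit.
    absl::Span<const int64_t> SafeDimensions(const xla::Shape& shape) {
      return shape.IsArray() ? shape.dimensions()
                             : absl::Span<const int64_t>();
    }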
PiperOrigin-RevId: 746573574 --- .../tf2xla/kernels/xla_call_module_loader.cc | 7 ++-- .../emitters/computation_partitioner.cc | 6 ++-- .../xla/xla/hlo/builder/xla_builder.cc | 6 ++-- third_party/xla/xla/literal_comparison.cc | 18 +++++++--- third_party/xla/xla/pjrt/pjrt_client.h | 3 +- .../xla/xla/service/collective_pipeliner.cc | 13 ++++++-- .../service/dynamic_dimension_inference.cc | 33 ++++++++++--------- .../gpu/model/tiled_hlo_instruction.cc | 3 +- .../gpu/model/triton_emitter_constraints.cc | 8 +++-- .../xla/xla/service/hlo_graph_dumper.cc | 8 ++--- .../xla/xla/service/sharding_propagation.cc | 3 +- .../xla/service/spmd/spmd_partitioner_util.cc | 8 +++-- third_party/xla/xla/service/value_range.cc | 2 +- third_party/xla/xla/shape_util.cc | 16 +++++---- 14 files changed, 86 insertions(+), 48 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index eb946ab9085b93..820d5ded5abe68 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -267,8 +267,11 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( // Get static MLIR Type from xla Shape. const xla::Shape &xla_shape = input_shapes[next_actual_input++]; - std::vector xla_dimensions(xla_shape.dimensions().begin(), - xla_shape.dimensions().end()); + std::vector xla_dimensions; + if (xla_shape.IsArray()) { + xla_dimensions = std::vector(xla_shape.dimensions().begin(), + xla_shape.dimensions().end()); + } TF_ASSIGN_OR_RETURN( mlir::Type element_type, ConvertPrimitiveTypeToMlirType(xla_shape.element_type(), builder)); diff --git a/third_party/xla/xla/codegen/emitters/computation_partitioner.cc b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc index f90d3d911f5393..0de9ae6bc66cbd 100644 --- a/third_party/xla/xla/codegen/emitters/computation_partitioner.cc +++ b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc @@ -95,8 +95,10 @@ EpilogueSpecification EpilogueSpecification::FromIdentityIndexing( const HloInstruction* hero, const HloInstruction* root, mlir::MLIRContext* mlir_context) { EpilogueSpecification result; - absl::c_copy(root->shape().dimensions(), - std::back_inserter(result.index_ranges)); + if (root->shape().IsArray()) { + absl::c_copy(root->shape().dimensions(), + std::back_inserter(result.index_ranges)); + } result.roots.push_back(root); result.root_indexing.push_back( CreateIdentityMap(root->shape(), mlir_context)); diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index 8b5d4c17fffc7f..be5357ff5f2ae2 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -764,8 +764,10 @@ absl::StatusOr XlaBuilder::Build( remove_dynamic_dimension(shape->mutable_tuple_shapes(i)); } } - for (int64_t i = 0; i < shape->dimensions_size(); ++i) { - shape->set_dynamic_dimension(i, false); + if (shape->IsArray()) { + for (int64_t i = 0; i < shape->dimensions_size(); ++i) { + shape->set_dynamic_dimension(i, false); + } } }; for (size_t index = 0; index < instructions_.size(); ++index) { diff --git a/third_party/xla/xla/literal_comparison.cc b/third_party/xla/xla/literal_comparison.cc index 8c3e53e48bec37..a695215bcec211 100644 --- a/third_party/xla/xla/literal_comparison.cc +++ b/third_party/xla/xla/literal_comparison.cc @@ -706,14 +706,18 @@ absl::Status EqualHelper(const LiteralSlice& expected, 
next_index.pop_back(); } } else { - std::vector multi_index(expected.shape().dimensions_size(), 0); + std::vector multi_index( + expected.shape().IsArray() ? expected.shape().dimensions().size() : 0, + 0); auto index = absl::MakeSpan(multi_index); - Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED, - expected.shape().dimensions()); + const Shape unequal_shape = ShapeUtil::MakeShape( + PrimitiveType::PRED, expected.shape().IsArray() + ? expected.shape().dimensions() + : absl::Span()); Literal miscompared(unequal_shape); - Literal* miscompared_ptr = - (miscompare_callback == nullptr ? nullptr : &miscompared); + Literal* const miscompared_ptr = + (miscompare_callback == nullptr) ? nullptr : &miscompared; primitive_util::PrimitiveTypeSwitch( [&](auto primitive_type_constant) -> void { @@ -873,6 +877,10 @@ absl::Status EqualDynamicShapesAndDimensions(const LiteralSlice& expected, [&expected, &actual](const Shape& expected_shape, const ShapeIndex& index) -> absl::Status { auto actual_shape = ShapeUtil::GetSubshape(actual.shape(), index); + if (!expected_shape.IsArray()) { + return absl::OkStatus(); + } + for (int i = 0; i < expected_shape.dimensions().size(); ++i) { if (!expected_shape.is_dynamic_dimension(i) && !actual_shape.is_dynamic_dimension(i)) { diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index 8f8d9ae6495f19..0810a371ad9f16 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -974,7 +974,8 @@ class PjRtBuffer { // Returned dimensions have lifetime of this buffer. virtual absl::Span dimensions() const { - return on_device_shape().dimensions(); + return on_device_shape().IsArray() ? on_device_shape().dimensions() + : absl::Span(); } // The on-device memory layout of this buffer. Returned via shared_ptr to make diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 1e8c1216ab193a..60e57b4672641f 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -1243,6 +1243,15 @@ void WhileLoopAnalysis::MergeIntoExistingCollectives( "MergeIntoExistingCollectives "; } +// Returns the number of dimensions of the array shape, or 0 if the shape is not +// an array. 
+static int GetNumArrayDimensionsOrZero(const Shape& shape) { + if (shape.IsArray()) { + return shape.dimensions().size(); + } + return 0; +} + void WhileLoopAnalysis::CollectCollectivesToMove( int64_t level_to_operate_on, CollectivePipeliner::PipeliningDirection direction, @@ -1304,8 +1313,8 @@ void WhileLoopAnalysis::CollectCollectivesToMove( for (auto* instr : instructions_post_order) { if (direction == CollectivePipeliner::PipeliningDirection::kForward && (instr->operand_count() != 1 || - instr->shape().dimensions_size() != - instr->operand(0)->shape().dimensions_size())) { + GetNumArrayDimensionsOrZero(instr->shape()) != + GetNumArrayDimensionsOrZero(instr->operand(0)->shape()))) { continue; } if (!should_process(instr)) { diff --git a/third_party/xla/xla/service/dynamic_dimension_inference.cc b/third_party/xla/xla/service/dynamic_dimension_inference.cc index f92a01894420b1..2bfab519efce4e 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference.cc @@ -2814,21 +2814,24 @@ absl::Status DynamicDimensionInference::ForwardDynamicSize( bool DynamicDimensionInference::HasDynamicDimension( HloInstruction* inst, ShapeIndexView index) const { bool has_dynamic_dim = false; - ShapeUtil::ForEachSubshape(inst->shape(), [&](const Shape& subshape, - const ShapeIndex& subindex) { - if (subshape.IsTuple()) { - return; - } - if (ShapeIndexView(subindex).subspan(0, index.size()) != index) { - return; - } - for (int64_t i = 0; i < subshape.dimensions_size(); ++i) { - HloInstruction* operand_dynamic_size = GetDynamicSize(inst, subindex, i); - if (operand_dynamic_size != nullptr) { - has_dynamic_dim = true; - } - } - }); + ShapeUtil::ForEachSubshape( + inst->shape(), [&](const Shape& subshape, const ShapeIndex& subindex) { + if (subshape.IsTuple()) { + return; + } + if (ShapeIndexView(subindex).subspan(0, index.size()) != index) { + return; + } + if (subshape.IsArray()) { + for (int64_t i = 0; i < subshape.dimensions_size(); ++i) { + HloInstruction* operand_dynamic_size = + GetDynamicSize(inst, subindex, i); + if (operand_dynamic_size != nullptr) { + has_dynamic_dim = true; + } + } + } + }); return has_dynamic_dim; } diff --git a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc index ed21dd2fd4f369..e1f3de9994854b 100644 --- a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc +++ b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc @@ -47,7 +47,8 @@ absl::Status VerifyTiledHloInstructionConstructorPreconditions( const HloInstruction* hlo, llvm::SmallVector tile_sizes, llvm::SmallVector tile_strides, std::optional tile_offsets_indexing) { - int rank = hlo->shape().dimensions_size(); + const int rank = + hlo->shape().IsArray() ? 
hlo->shape().dimensions().size() : 0; if (tile_sizes.size() != rank) { return absl::InvalidArgumentError( diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc index 6c8c715936972e..259fa8daa05879 100644 --- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc @@ -219,9 +219,11 @@ absl::StatusOr TritonEmitterConstraints::ParametersSatisfyConstraints( } int64_t num_tiles = 1; - for (auto [dim_size, tile_size] : - llvm::zip(root_shape_.dimensions(), tile_parameters)) { - num_tiles *= (dim_size + tile_size - 1) / tile_size; + if (root_shape_.IsArray()) { + for (auto [dim_size, tile_size] : + llvm::zip(root_shape_.dimensions(), tile_parameters)) { + num_tiles *= (dim_size + tile_size - 1) / tile_size; + } } // Number of blocks will exceed the hardware limit. This limitation comes from diff --git a/third_party/xla/xla/service/hlo_graph_dumper.cc b/third_party/xla/xla/service/hlo_graph_dumper.cc index 1b767e816eab88..48268e43d45bec 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper.cc @@ -1486,10 +1486,10 @@ std::string HloDotDumper::GetInstructionNodeExtraInfo( // layout on tuples or tensors with just one dimension (which only have one // possible layout) to avoid visual noise. bool shape_is_multidim = false; - ShapeUtil::ForEachSubshape(instr->shape(), - [&](const Shape& s, const ShapeIndex&) { - shape_is_multidim |= s.dimensions_size() > 1; - }); + ShapeUtil::ForEachSubshape( + instr->shape(), [&](const Shape& s, const ShapeIndex&) { + shape_is_multidim |= s.IsArray() && s.dimensions().size() > 1; + }); std::string instr_shape; if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) { instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape()); diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index bd20f42360076e..daf8e300a5e1ba 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -1776,7 +1776,8 @@ std::optional ShardingPropagation::GetShardingFromUser( [&](const Shape& sub_shape, const ShapeIndex& index) { if (ShapeUtil::IsLeafIndex(instruction.shape(), index)) { shardings.push_back(hlo_sharding_util::ReplicateAllDataDims( - user_sharding, sub_shape.dimensions_size())); + user_sharding, + sub_shape.IsArray() ? 
sub_shape.dimensions().size() : 0)); } }); return HloSharding::Tuple(instruction.shape(), shardings); diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc index 6eebac549a51f8..241cb57f089fa9 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc @@ -100,9 +100,11 @@ bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { if (sharding.IsTileMaximal()) { return sharding.IsReplicated(); } - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { - if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { - return false; + if (shape.IsArray()) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { + if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { + return false; + } } } return true; diff --git a/third_party/xla/xla/service/value_range.cc b/third_party/xla/xla/service/value_range.cc index 57bffad1e00416..1a4ecd23d7fd42 100644 --- a/third_party/xla/xla/service/value_range.cc +++ b/third_party/xla/xla/service/value_range.cc @@ -96,7 +96,7 @@ Range RecursivelyIdentifyRange( // Non scalar or non-integer HLO. Abort. if ((!instr->shape().AreAllLeavesIntegers() && instr->shape().element_type() != PRED) || - instr->shape().dimensions_size() != 0) { + (instr->shape().IsArray() && !instr->shape().dimensions().empty())) { return Range{}; } VLOG(5) << "Computing Range for " << instr->ToString(); diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 944b7c44488d8a..dd7af95c71f352 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -716,8 +716,10 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { printer->Append("[]"); return; } + // Now we are in array shape with at least one dimension. printer->Append("["); - auto print_one = [&](int i) { + // Prints the i-th dimension of the array shape. 
+  auto print_dimension = [&](int i) {
     if (shape.is_dynamic_dimension(i)) {
       if (shape.dimensions(i) != Shape::kUnboundedSize) {
         printer->Append(StrCat("<=", shape.dimensions(i)));
@@ -728,10 +730,10 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) {
       printer->Append(shape.dimensions(i));
     }
   };
-  print_one(0);
+  print_dimension(0);
   for (int i = 1, n = shape.dimensions().size(); i < n; ++i) {
     printer->Append(",");
-    print_one(i);
+    print_dimension(i);
   }
   printer->Append("]");
 }
@@ -849,9 +851,11 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) {
 /* static */ DimensionVector ShapeUtil::CreateDimensionVectorFromShape(
     const Shape& shape) {
   DimensionVector dimensions;
-  dimensions.reserve(shape.dimensions().size());
-  for (int i = 0; i < shape.dimensions().size(); ++i) {
-    dimensions.push_back(shape.dimensions(i));
+  if (shape.IsArray()) {
+    dimensions.reserve(shape.dimensions().size());
+    for (int i = 0; i < shape.dimensions().size(); ++i) {
+      dimensions.push_back(shape.dimensions(i));
+    }
   }
   return dimensions;
 }

From 18090554b3bfb9a4c310ce5d276f70b45eacfa9b Mon Sep 17 00:00:00 2001
From: Won Jong Jeon
Date: Fri, 11 Apr 2025 14:03:02 -0700
Subject: [PATCH 0593/1324] [mlir][tosa] Add concat qi8 test case (#91226)

This adds a test case covering concatenation of quantized operands and
result, to guard against a previous tosa concat type inference bug.

Change-Id: Ic71d30f9237a758cffeac066d7acd8fead75a574

Signed-off-by: Tai Ly
Co-authored-by: Tai Ly
---
 .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
index 54237fee7abf2f..2c4a64af956ecf 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
@@ -3882,3 +3882,14 @@ func.func @test_transpose_conv2d_bias_f32(%arg0: tensor<1x64x64x256xf32>) -> ten
   %2 = "tfl.transpose_conv"(%cst, %0, %arg0, %1) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<128x2x2x256xf32>, tensor<1x64x64x256xf32>, tensor<128xf32>) -> tensor<1x128x128x128xf32>
   return %2 : tensor<1x128x128x128xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_concat_qconst
+// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<42> : tensor<28x19xi8>}> : () -> tensor<28x19x!quant.uniform>
+// CHECK-DAG: %[[VAR1:.*]] = tosa.concat %[[VAR0]], %arg0 {axis = 0 : i32} : (tensor<28x19x!quant.uniform>, tensor<1x19x!quant.uniform>) -> tensor<29x19x!quant.uniform>
+func.func @test_concat_qconst(%arg0: tensor<1x19x!quant.uniform> ) -> tensor<29x19x!quant.uniform> {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<28x19x!quant.uniform>, value = dense<42> : tensor<28x19xi8>} : () -> tensor<28x19x!quant.uniform>
+  %1 = "tfl.concatenation"(%0, %arg0) {axis = 0 : i32, fused_activation_function = "NONE"}: (tensor<28x19x!quant.uniform>, tensor<1x19x!quant.uniform>) -> tensor<29x19x!quant.uniform>
+  return %1 : tensor<29x19x!quant.uniform>
+}

From 8f8f31413e8ebed57ab43ce595d0c5e393d17b16 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 11 Apr 2025 13:59:57 -0700
Subject: [PATCH 0594/1324] Removes support for live matrices (unused in production as of cl/628152039).
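The interval-based representation is now the only supported encoding: each
node (or edge) carries a {first, last} pair of liveness time steps, and a pair
whose first element exceeds its second (e.g. {100, -1} in the tests below)
marks a primitive that is never live. A sketch of the correspondence, assuming
a hypothetical LiveMatrixToIntervals helper that is not part of this change:

  #include <cstdint>
  #include <utility>
  #include <vector>

  // Collapses a per-time-step liveness matrix into per-node {first, last}
  // intervals; {100, -1} is the "never live" sentinel used by the tests.
  std::vector<std::pair<int64_t, int64_t>> LiveMatrixToIntervals(
      const std::vector<std::vector<int64_t>>& live, int64_t num_nodes) {
    std::vector<std::pair<int64_t, int64_t>> intervals(num_nodes, {100, -1});
    for (int64_t t = 0; t < static_cast<int64_t>(live.size()); ++t) {
      for (int64_t node : live[t]) {
        auto& [first, last] = intervals[node];
        if (first > last) first = t;  // first time step this node is live
        last = t;                     // extend to the latest live time step
      }
    }
    return intervals;
  }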
PiperOrigin-RevId: 746582103 --- .../auto_sharding/auto_sharding_solver.cc | 225 +++++++----------- .../auto_sharding_solver_test.cc | 52 ++-- 2 files changed, 108 insertions(+), 169 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index c592d4e86c78dd..7641cd4bc5fe0d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -85,45 +85,36 @@ bool AutoShardingSolverOutput::operator==( double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { std::vector min_memory_required; - if (request.node_intervals().empty()) { // Handles live matrices. - min_memory_required.resize(request.live_size(), 0.0); - for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) { - for (NodeIdx node_idx : request.live(time_idx).nodes()) { - const auto& m = request.memory_costs(node_idx).costs(); - const double fixed_memory_cost = *std::min_element(m.begin(), m.end()); - min_memory_required[time_idx] += fixed_memory_cost; - } + std::vector min_memory_required_group; + for (const auto& group : request.node_groups()) { + double fixed_memory_cost = 0.0; + for (const NodeIdx node_idx : group.prims()) { + const auto& m = request.memory_costs(node_idx).costs(); + fixed_memory_cost += *std::min_element(m.begin(), m.end()); } - } else { // Handles the interval-based memory representation. - std::vector min_memory_required_group; - for (const auto& group : request.node_groups()) { - double fixed_memory_cost = 0.0; - for (const NodeIdx node_idx : group.prims()) { - const auto& m = request.memory_costs(node_idx).costs(); - fixed_memory_cost += *std::min_element(m.begin(), m.end()); - } - min_memory_required_group.push_back(fixed_memory_cost); + min_memory_required_group.push_back(fixed_memory_cost); + } + for (NodeIdx node_idx = 0; node_idx < request.node_intervals_size(); + ++node_idx) { + const auto& interval = request.node_intervals(node_idx); + if (interval.first() > interval.second()) { + continue; } - for (NodeIdx node_idx = 0; node_idx < request.node_intervals_size(); - ++node_idx) { - const auto& interval = request.node_intervals(node_idx); - if (interval.first() > interval.second()) continue; - // Expand cost vectors if needed to cover the range of this interval. - while (min_memory_required.size() <= interval.second()) { - min_memory_required.push_back(0.0); - } - double fixed_memory_cost = 0.0; - if (node_idx < request.num_nodes()) { - const auto& m = request.memory_costs(node_idx).costs(); - fixed_memory_cost = *std::min_element(m.begin(), m.end()); - } else { - int64_t group_idx = node_idx - request.num_nodes(); - fixed_memory_cost = min_memory_required_group[group_idx]; - } - for (LivenessIdx time_idx = interval.first(); - time_idx <= interval.second(); ++time_idx) { - min_memory_required[time_idx] += fixed_memory_cost; - } + // Expand cost vectors if needed to cover the range of this interval. 
+ while (min_memory_required.size() <= interval.second()) { + min_memory_required.push_back(0.0); + } + double fixed_memory_cost = 0.0; + if (node_idx < request.num_nodes()) { + const auto& m = request.memory_costs(node_idx).costs(); + fixed_memory_cost = *std::min_element(m.begin(), m.end()); + } else { + int64_t group_idx = node_idx - request.num_nodes(); + fixed_memory_cost = min_memory_required_group[group_idx]; + } + for (LivenessIdx time_idx = interval.first(); time_idx <= interval.second(); + ++time_idx) { + min_memory_required[time_idx] += fixed_memory_cost; } } double min_memory_budget_required_estimate = 0.0; @@ -341,10 +332,7 @@ absl::StatusOr SolveAndExtractSolution( // create constrained variables for the subsequent groups. std::optional> ReduceMemoryTerms( const AutoShardingSolverRequest& request, MPSolver& solver, - int64_t num_lives, int64_t num_primitives, - const std::function< - tsl::protobuf::RepeatedField(int64_t)>& // NOLINT - live, + int64_t num_primitives, const tsl::protobuf::RepeatedPtrField< // NOLINT AutoShardingSolverRequest_Pair>& intervals, const tsl::protobuf::RepeatedPtrField< // NOLINT @@ -360,8 +348,9 @@ std::optional> ReduceMemoryTerms( std::optional> num_terms = std::nullopt; std::vector> reduced_groups; if (groups.empty()) { - // If we've been given primitive intervals instead of a liveness matrix, we - // need to update the # of lives in order to use the memory term reducer. + // We need to compute the number of lives in order to use the memory term + // reducer. + int64_t num_lives = 0; for (const auto& interval : intervals) { if (interval.first() > interval.second()) continue; // Interval undefined num_lives = std::max(num_lives, interval.second() + 1); @@ -371,10 +360,7 @@ std::optional> ReduceMemoryTerms( return {intervals.at(prim_idx).first(), intervals.at(prim_idx).second()}; }; MemoryTermReducer reducer; - num_terms = - intervals.empty() - ? reducer.Reduce(num_lives, num_primitives, live) - : reducer.Reduce(num_lives, num_primitives, std::move(Intervals)); + num_terms = reducer.Reduce(num_lives, num_primitives, std::move(Intervals)); reduced_intervals = reducer.GetReducedIntervals(); reduced_groups = reducer.GetReducedGroups(); } else { // If we've already done term reduction, just copy over the results. @@ -720,30 +706,20 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( } // c. 
if (request.memory_budget() > 0) { - auto LiveNodes = - [request](int64_t live_idx) -> tsl::protobuf::RepeatedField { - return request.live(live_idx).nodes(); - }; - auto LiveEdges = - [request](int64_t live_idx) -> tsl::protobuf::RepeatedField { - return request.live_edges(live_idx).edges(); - }; std::vector> reduced_intervals_nodes, reduced_intervals_edges; absl::flat_hash_set reduced_times; std::vector group_node_vars, group_edge_vars; std::optional> num_node_terms, num_edge_terms; num_node_terms = ReduceMemoryTerms( - request, *solver, request.live_size(), request.num_nodes(), - std::move(LiveNodes), request.node_intervals(), request.node_groups(), - request.memory_costs(), "node", s, reduced_intervals_nodes, - group_node_vars, reduced_times); + request, *solver, request.num_nodes(), request.node_intervals(), + request.node_groups(), request.memory_costs(), "node", s, + reduced_intervals_nodes, group_node_vars, reduced_times); if (request.enable_memory_edge_costs()) { num_edge_terms = ReduceMemoryTerms( - request, *solver, request.live_edges_size(), request.edges_size(), - std::move(LiveEdges), request.edge_intervals(), request.edge_groups(), - request.memory_edge_costs(), "edge", e, reduced_intervals_edges, - group_edge_vars, reduced_times); + request, *solver, request.edges_size(), request.edge_intervals(), + request.edge_groups(), request.memory_edge_costs(), "edge", e, + reduced_intervals_edges, group_edge_vars, reduced_times); } absl::flat_hash_map constraints; AddMemoryTerms(request, *solver, request.num_nodes(), @@ -997,16 +973,6 @@ std::optional ShardingStrategyHasViolation( return AutoShardingViolationCode::kInfiniteCostViolationCode; } } - // Check that the peak-memory constraint is satisfied at each time step t. - for (LivenessIdx t = 0; t < request.live_size(); ++t) { - double live_memory = 0.0; - for (NodeIdx v : request.live(t).nodes()) { - live_memory += request.memory_costs(v).costs(node_strategies[v]); - if (live_memory > request.memory_budget()) { - return AutoShardingViolationCode::kMemoryViolationCode; - } - } - } return std::nullopt; } @@ -1203,76 +1169,55 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, } if (request.memory_budget() > 0) { std::vector total_memory_costs, lower_bound_memory_costs; - if (request.node_intervals().empty()) { // Handles live matrices. 
- total_memory_costs.resize(request.live_size(), 0.0); - lower_bound_memory_costs.resize(request.live_size(), 0.0); - for (LivenessIdx time_idx = 0; time_idx < request.live_size(); - ++time_idx) { - for (NodeIdx node_idx : request.live(time_idx).nodes()) { - const auto& m = request.memory_costs(node_idx).costs(); - total_memory_costs[time_idx] += m[s_val[node_idx]]; - lower_bound_memory_costs[time_idx] += - *std::min_element(m.begin(), m.end()); - } - if (!request.live_edges().empty() && - request.enable_memory_edge_costs()) { - for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) { - const auto& m = request.memory_edge_costs(edge_idx).costs(); - total_memory_costs[time_idx] += m[e_val(edge_idx)]; - lower_bound_memory_costs[time_idx] += - *std::min_element(m.begin(), m.end()); - } - } + std::vector total_node_group_costs, total_edge_group_costs, + lower_bound_node_group_costs, lower_bound_edge_group_costs; + for (const auto& group : request.node_groups()) { + double total_group_cost = 0.0; + double lower_bound_group_cost = 0.0; + for (const NodeIdx node_idx : group.prims()) { + const auto& m = request.memory_costs(node_idx).costs(); + total_group_cost += m[s_val[node_idx]]; + lower_bound_group_cost += *std::min_element(m.begin(), m.end()); } - } else { // Handles the interval-based memory representation. - std::vector total_node_group_costs, total_edge_group_costs, - lower_bound_node_group_costs, lower_bound_edge_group_costs; - for (const auto& group : request.node_groups()) { - double total_group_cost = 0.0; - double lower_bound_group_cost = 0.0; - for (const NodeIdx node_idx : group.prims()) { - const auto& m = request.memory_costs(node_idx).costs(); - total_group_cost += m[s_val[node_idx]]; - lower_bound_group_cost += *std::min_element(m.begin(), m.end()); - } - total_node_group_costs.push_back(total_group_cost); - lower_bound_node_group_costs.push_back(lower_bound_group_cost); + total_node_group_costs.push_back(total_group_cost); + lower_bound_node_group_costs.push_back(lower_bound_group_cost); + } + for (const auto& group : request.edge_groups()) { + double total_group_cost = 0.0; + double lower_bound_group_cost = 0.0; + for (const EdgeIdx edge_idx : group.prims()) { + const auto& m = request.memory_edge_costs(edge_idx).costs(); + total_group_cost += m[e_val(edge_idx)]; + lower_bound_group_cost += *std::min_element(m.begin(), m.end()); } - for (const auto& group : request.edge_groups()) { - double total_group_cost = 0.0; - double lower_bound_group_cost = 0.0; - for (const EdgeIdx edge_idx : group.prims()) { - const auto& m = request.memory_edge_costs(edge_idx).costs(); - total_group_cost += m[e_val(edge_idx)]; - lower_bound_group_cost += *std::min_element(m.begin(), m.end()); - } - total_edge_group_costs.push_back(total_group_cost); - lower_bound_edge_group_costs.push_back(lower_bound_group_cost); + total_edge_group_costs.push_back(total_group_cost); + lower_bound_edge_group_costs.push_back(lower_bound_group_cost); + } + for (NodeIdx node_idx = 0; node_idx < request.node_intervals_size(); + ++node_idx) { + const auto& interval = request.node_intervals(node_idx); + if (interval.first() > interval.second()) { + continue; } - for (NodeIdx node_idx = 0; node_idx < request.node_intervals_size(); - ++node_idx) { - const auto& interval = request.node_intervals(node_idx); - if (interval.first() > interval.second()) continue; - // Expand cost vectors if needed to cover the range of this interval. 
- while (total_memory_costs.size() <= interval.second()) { - total_memory_costs.push_back(0.0); - lower_bound_memory_costs.push_back(0.0); - } - double total_memory_cost = 0.0, lower_bound_memory_cost = 0.0; - if (node_idx < request.num_nodes()) { - const auto& m = request.memory_costs(node_idx).costs(); - total_memory_cost = m[s_val[node_idx]]; - lower_bound_memory_cost = *std::min_element(m.begin(), m.end()); - } else { - int64_t group_idx = node_idx - request.num_nodes(); - total_memory_cost = total_node_group_costs[group_idx]; - lower_bound_memory_cost = lower_bound_node_group_costs[group_idx]; - } - for (LivenessIdx time_idx = interval.first(); - time_idx <= interval.second(); ++time_idx) { - total_memory_costs[time_idx] += total_memory_cost; - lower_bound_memory_costs[time_idx] += lower_bound_memory_cost; - } + // Expand cost vectors if needed to cover the range of this interval. + while (total_memory_costs.size() <= interval.second()) { + total_memory_costs.push_back(0.0); + lower_bound_memory_costs.push_back(0.0); + } + double total_memory_cost = 0.0, lower_bound_memory_cost = 0.0; + if (node_idx < request.num_nodes()) { + const auto& m = request.memory_costs(node_idx).costs(); + total_memory_cost = m[s_val[node_idx]]; + lower_bound_memory_cost = *std::min_element(m.begin(), m.end()); + } else { + int64_t group_idx = node_idx - request.num_nodes(); + total_memory_cost = total_node_group_costs[group_idx]; + lower_bound_memory_cost = lower_bound_node_group_costs[group_idx]; + } + for (LivenessIdx time_idx = interval.first(); + time_idx <= interval.second(); ++time_idx) { + total_memory_costs[time_idx] += total_memory_cost; + lower_bound_memory_costs[time_idx] += lower_bound_memory_cost; } if (request.enable_memory_edge_costs()) { for (EdgeIdx edge_idx = 0; edge_idx < request.edge_intervals_size(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index ee28350cfcc6ed..475527d96f0a62 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -104,11 +104,10 @@ AutoShardingSolverRequest DefaultAutoShardingSolverRequest() { edge2.set_first(1); edge2.set_second(2); const auto edges = {edge1, edge2}; - const NodeMatrix live = {{1, 0}, - {1, 0}, - {1, 2, 0}, - {1, 2, 3, 0}, - {1, 3, 0}}; + const std::vector> node_intervals = + {{0, 4}, {0, 4}, {2, 3}, {3, 4}, {100, -1}}; + const std::vector> edge_intervals = + {{1, 2}, {2, 3}}; const CostMatrix c = {{10, 11, 12, 13}, {20, 21, 22}, {30, 31, 32, 33}, @@ -163,7 +162,8 @@ AutoShardingSolverRequest DefaultAutoShardingSolverRequest() { request.mutable_s_len()->Add(s_len.begin(), s_len.end()); request.mutable_s_follow()->Add(s_follow.begin(), s_follow.end()); request.mutable_edges()->Add(edges.begin(), edges.end()); - AddNodes(request.mutable_live(), live); + AddIntervals(request.mutable_node_intervals(), node_intervals); + AddIntervals(request.mutable_edge_intervals(), edge_intervals); AddCosts(request.mutable_computation_costs(), c); AddCosts(request.mutable_communication_costs(), d); AddCosts(request.mutable_memory_costs(), m); @@ -188,11 +188,10 @@ AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() { edge2.set_first(1); edge2.set_second(2); const auto edges = {edge1, edge2}; - const NodeMatrix live = {{1, 0}, - {1, 0}, - {1, 2, 0}, - {1, 2, 3, 0}, - {1, 3, 0}}; + const 
std::vector> node_intervals = + {{0, 4}, {0, 4}, {2, 3}, {3, 4}, {100, -1}}; + const std::vector> edge_intervals = + {{1, 2}, {2, 3}}; const CostMatrix c = {{10, 10, 10, 10}, {20, 20, 20}, {30, 30, 31, 30, 30, 30, 30}, @@ -246,7 +245,8 @@ AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() { request.mutable_s_len()->Add(s_len.begin(), s_len.end()); request.mutable_s_follow()->Add(s_follow.begin(), s_follow.end()); request.mutable_edges()->Add(edges.begin(), edges.end()); - AddNodes(request.mutable_live(), live); + AddIntervals(request.mutable_node_intervals(), node_intervals); + AddIntervals(request.mutable_edge_intervals(), edge_intervals); AddCosts(request.mutable_computation_costs(), c); AddCosts(request.mutable_communication_costs(), d); AddCosts(request.mutable_memory_costs(), m); @@ -466,10 +466,6 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) { TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - const std::vector> node_intervals = - {{0, 4}, {0, 4}, {2, 3}, {3, 4}, {100, -1}}; - const std::vector> edge_intervals = - {{1, 2}, {2, 3}}; const CostMatrix memory_edge_costs = {{1000000, 1100, 1200, 1300, 2000, 2100, 2200, 2300, 3000, 3100, 3200, 3300, @@ -477,9 +473,6 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { {5000000, 5100, 5200, 5300, 6000, 6100, 6200, 6300, 7000, 7100, 7200, 7300}}; - request.clear_live(); - AddIntervals(request.mutable_node_intervals(), node_intervals); - AddIntervals(request.mutable_edge_intervals(), edge_intervals); AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); @@ -509,7 +502,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, {5000000, 5100, 5200, 5300, 6000, 6100, 6200, 6300, 7000, 7100, 7200, 7300}}; - request.clear_live(); + request.clear_node_intervals(); + request.clear_edge_intervals(); AddIntervals(request.mutable_node_intervals(), node_intervals); AddIntervals(request.mutable_edge_intervals(), edge_intervals); AddGroups(request.mutable_node_groups(), node_groups); @@ -533,7 +527,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; const std::vector> node_groups = {{0, 1}}; - request.clear_live(); + request.clear_node_intervals(); + request.clear_edge_intervals(); AddIntervals(request.mutable_node_intervals(), node_intervals); AddGroups(request.mutable_node_groups(), node_groups); request.set_enable_memory_edge_costs(false); @@ -569,7 +564,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; - request.clear_live(); + request.clear_node_intervals(); + request.clear_edge_intervals(); request.clear_memory_costs(); AddIntervals(request.mutable_node_intervals(), node_intervals); AddIntervals(request.mutable_edge_intervals(), edge_intervals); @@ -654,12 +650,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) { TEST(AutoShardingEvaluatorTest, EvaluatesOverbudgetWithIntervals) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - const std::vector> node_intervals = - {{0, 4}, {0, 4}, {2, 3}, {3, 4}, {100, -1}}; request.set_memory_budget(100000); request.mutable_overbudget_coeff()->set_coeff(10.0); - request.clear_live(); - AddIntervals(request.mutable_node_intervals(), node_intervals); const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const double objective_value = 11138.0; 
const AutoShardingSolverOutput output = {s_val, objective_value}; @@ -690,7 +682,8 @@ TEST(AutoShardingEvaluatorTest, const std::vector> node_groups = {{0, 1}}; request.set_memory_budget(100000); request.mutable_overbudget_coeff()->set_coeff(10.0); - request.clear_live(); + request.clear_node_intervals(); + request.clear_edge_intervals(); AddIntervals(request.mutable_node_intervals(), node_intervals); AddGroups(request.mutable_node_groups(), node_groups); const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; @@ -868,13 +861,14 @@ TEST(MinimumMemoryBudgetRequired, HandlesReducedIntervalsAndGroups) { const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; const std::vector> node_groups = {{0, 1}}; - request.clear_live(); + request.clear_node_intervals(); + request.clear_edge_intervals(); AddIntervals(request.mutable_node_intervals(), node_intervals); AddGroups(request.mutable_node_groups(), node_groups); EXPECT_EQ(MinimumMemoryBudgetRequired(request), 1000000.0); } -TEST(StableMap, IterationOrderDeterminism){ +TEST(StableMap, IterationOrderDeterminism) { StableMap map; std::vector insertion_order = {6, 3, 1, 2, 4, 5, 10, 0, 7, 9, 8}; for (int key : insertion_order) { From 3b18a7be6335c16c31f7c1c104e8f9c1c2697c32 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Fri, 11 Apr 2025 14:17:04 -0700 Subject: [PATCH 0595/1324] Simplify `convertManualComputationOp` in `stablehlo_round_trip/shard_map_export.cc`. Replace `getInShardingWithoutManualAxes` with `eraseManualAxes` since we already have InSharding locally. Remove `getInShardingWithoutManualAxes` and `getOutShardingWithoutManualAxes`. PiperOrigin-RevId: 746588014 --- .../stablehlo_round_trip/shard_map_export.cc | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc index 2545b47dcb712b..760bc100e1c15e 100644 --- a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc +++ b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc @@ -243,10 +243,9 @@ void convertManualComputationOp( // Add copy and custom_call @SPMDFullToShardShape for each operand. The // copy corresponds to custom_call @Sharding before sharding propagation. SmallVector fullToShardResults; - for (auto [operand_index, args] : llvm::enumerate( - llvm::zip_equal(op.getOperands(), op.getBody().getArgumentTypes(), - op.getInShardings().getShardings()))) { - auto [globalOperand, localArgumentType, inSharding] = args; + for (auto [globalOperand, localArgumentType, inSharding] : + llvm::zip_equal(op.getOperands(), op.getBody().getArgumentTypes(), + op.getInShardings().getShardings())) { auto copy = rewriter.create(loc, globalOperand); copy->setAttr(kXlaShardingAttr, getStringAttr(convertToHloSharding(inSharding, getMeshAttr, @@ -255,7 +254,7 @@ void convertManualComputationOp( kXlaShardingAttr, fullyManual ? fullyManualSharding : getStringAttr(convertToHloSharding( - op.getInShardingWithoutManualAxes(operand_index), + eraseManualAxes(inSharding, regionManualAxes), getMeshAttr, regionManualAxes))); auto fullToShard = rewriter.create( loc, localArgumentType, copy.getResult(), fullToShardAttributes); @@ -270,11 +269,11 @@ void convertManualComputationOp( op.getOutShardings().getShardings())) { auto copy = rewriter.create(loc, terminatorOperand.get()); copy->setAttr(kXlaShardingAttr, - fullyManual ? 
fullyManualSharding - : getStringAttr(convertToHloSharding( - op.getOutShardingWithoutManualAxes( - terminatorOperand.getOperandNumber()), - getMeshAttr, regionManualAxes))); + fullyManual + ? fullyManualSharding + : getStringAttr(convertToHloSharding( + eraseManualAxes(outSharding, regionManualAxes), + getMeshAttr, regionManualAxes))); shardToFullAttributes.back() = rewriter.getNamedAttr( kXlaShardingAttr, getStringAttr(convertToHloSharding( outSharding, getMeshAttr, parentManualAxes))); From 3fac29ca138ce9a865491e7b3fc902451deedc4d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 14:42:23 -0700 Subject: [PATCH 0596/1324] Removes support for precomputed memory term reductions (unused in production as of cl/746122805). PiperOrigin-RevId: 746596307 --- .../auto_sharding/auto_sharding_solver.cc | 103 +++++------------- .../auto_sharding_solver_test.cc | 10 +- 2 files changed, 30 insertions(+), 83 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 7641cd4bc5fe0d..2c474773dee426 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -86,14 +86,6 @@ bool AutoShardingSolverOutput::operator==( double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { std::vector min_memory_required; std::vector min_memory_required_group; - for (const auto& group : request.node_groups()) { - double fixed_memory_cost = 0.0; - for (const NodeIdx node_idx : group.prims()) { - const auto& m = request.memory_costs(node_idx).costs(); - fixed_memory_cost += *std::min_element(m.begin(), m.end()); - } - min_memory_required_group.push_back(fixed_memory_cost); - } for (NodeIdx node_idx = 0; node_idx < request.node_intervals_size(); ++node_idx) { const auto& interval = request.node_intervals(node_idx); @@ -335,8 +327,6 @@ std::optional> ReduceMemoryTerms( int64_t num_primitives, const tsl::protobuf::RepeatedPtrField< // NOLINT AutoShardingSolverRequest_Pair>& intervals, - const tsl::protobuf::RepeatedPtrField< // NOLINT - AutoShardingSolverRequest_Group>& groups, const tsl::protobuf::RepeatedPtrField< // NOLINT AutoShardingSolverRequest_Costs>& memory_costs, absl::string_view prim_type, @@ -347,30 +337,23 @@ std::optional> ReduceMemoryTerms( const absl::Time term_reduction_start_time = absl::Now(); std::optional> num_terms = std::nullopt; std::vector> reduced_groups; - if (groups.empty()) { - // We need to compute the number of lives in order to use the memory term - // reducer. - int64_t num_lives = 0; - for (const auto& interval : intervals) { - if (interval.first() > interval.second()) continue; // Interval undefined - num_lives = std::max(num_lives, interval.second() + 1); - } - auto Intervals = - [intervals](int64_t prim_idx) -> std::pair { - return {intervals.at(prim_idx).first(), intervals.at(prim_idx).second()}; - }; - MemoryTermReducer reducer; - num_terms = reducer.Reduce(num_lives, num_primitives, std::move(Intervals)); - reduced_intervals = reducer.GetReducedIntervals(); - reduced_groups = reducer.GetReducedGroups(); - } else { // If we've already done term reduction, just copy over the results. 
- for (const auto& interval : intervals) { - reduced_intervals.push_back({interval.first(), interval.second()}); - } - for (const auto& group : groups) { - reduced_groups.push_back({group.prims().begin(), group.prims().end()}); + // We need to compute the number of lives in order to use the memory term + // reducer. + int64_t num_lives = 0; + for (const auto& interval : intervals) { + if (interval.first() > interval.second()) { + continue; // Interval undefined } + num_lives = std::max(num_lives, interval.second() + 1); } + auto Intervals = + [intervals](int64_t prim_idx) -> std::pair { + return {intervals.at(prim_idx).first(), intervals.at(prim_idx).second()}; + }; + MemoryTermReducer reducer; + num_terms = reducer.Reduce(num_lives, num_primitives, std::move(Intervals)); + reduced_intervals = reducer.GetReducedIntervals(); + reduced_groups = reducer.GetReducedGroups(); solver.MakeNumVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(), absl::StrCat("group_", prim_type), &group_vars); for (int64_t group_idx = 0; group_idx < group_vars.size(); ++group_idx) { @@ -713,13 +696,13 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( std::optional> num_node_terms, num_edge_terms; num_node_terms = ReduceMemoryTerms( request, *solver, request.num_nodes(), request.node_intervals(), - request.node_groups(), request.memory_costs(), "node", s, - reduced_intervals_nodes, group_node_vars, reduced_times); + request.memory_costs(), "node", s, reduced_intervals_nodes, + group_node_vars, reduced_times); if (request.enable_memory_edge_costs()) { num_edge_terms = ReduceMemoryTerms( request, *solver, request.edges_size(), request.edge_intervals(), - request.edge_groups(), request.memory_edge_costs(), "edge", e, - reduced_intervals_edges, group_edge_vars, reduced_times); + request.memory_edge_costs(), "edge", e, reduced_intervals_edges, + group_edge_vars, reduced_times); } absl::flat_hash_map constraints; AddMemoryTerms(request, *solver, request.num_nodes(), @@ -1169,30 +1152,6 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, } if (request.memory_budget() > 0) { std::vector total_memory_costs, lower_bound_memory_costs; - std::vector total_node_group_costs, total_edge_group_costs, - lower_bound_node_group_costs, lower_bound_edge_group_costs; - for (const auto& group : request.node_groups()) { - double total_group_cost = 0.0; - double lower_bound_group_cost = 0.0; - for (const NodeIdx node_idx : group.prims()) { - const auto& m = request.memory_costs(node_idx).costs(); - total_group_cost += m[s_val[node_idx]]; - lower_bound_group_cost += *std::min_element(m.begin(), m.end()); - } - total_node_group_costs.push_back(total_group_cost); - lower_bound_node_group_costs.push_back(lower_bound_group_cost); - } - for (const auto& group : request.edge_groups()) { - double total_group_cost = 0.0; - double lower_bound_group_cost = 0.0; - for (const EdgeIdx edge_idx : group.prims()) { - const auto& m = request.memory_edge_costs(edge_idx).costs(); - total_group_cost += m[e_val(edge_idx)]; - lower_bound_group_cost += *std::min_element(m.begin(), m.end()); - } - total_edge_group_costs.push_back(total_group_cost); - lower_bound_edge_group_costs.push_back(lower_bound_group_cost); - } for (NodeIdx node_idx = 0; node_idx < request.node_intervals_size(); ++node_idx) { const auto& interval = request.node_intervals(node_idx); @@ -1205,15 +1164,9 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, lower_bound_memory_costs.push_back(0.0); } double total_memory_cost = 0.0, 
lower_bound_memory_cost = 0.0; - if (node_idx < request.num_nodes()) { - const auto& m = request.memory_costs(node_idx).costs(); - total_memory_cost = m[s_val[node_idx]]; - lower_bound_memory_cost = *std::min_element(m.begin(), m.end()); - } else { - int64_t group_idx = node_idx - request.num_nodes(); - total_memory_cost = total_node_group_costs[group_idx]; - lower_bound_memory_cost = lower_bound_node_group_costs[group_idx]; - } + const auto& m = request.memory_costs(node_idx).costs(); + total_memory_cost = m[s_val[node_idx]]; + lower_bound_memory_cost = *std::min_element(m.begin(), m.end()); for (LivenessIdx time_idx = interval.first(); time_idx <= interval.second(); ++time_idx) { total_memory_costs[time_idx] += total_memory_cost; @@ -1230,15 +1183,9 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, lower_bound_memory_costs.push_back(0.0); } double total_memory_cost = 0.0, lower_bound_memory_cost = 0.0; - if (edge_idx < request.edges_size()) { - const auto& m = request.memory_edge_costs(edge_idx).costs(); - total_memory_cost = m[e_val(edge_idx)]; - lower_bound_memory_cost = *std::min_element(m.begin(), m.end()); - } else { - int64_t group_idx = edge_idx - request.edges_size(); - total_memory_cost = total_edge_group_costs[group_idx]; - lower_bound_memory_cost = lower_bound_edge_group_costs[group_idx]; - } + const auto& m = request.memory_edge_costs(edge_idx).costs(); + total_memory_cost = m[e_val(edge_idx)]; + lower_bound_memory_cost = *std::min_element(m.begin(), m.end()); for (LivenessIdx time_idx = interval.first(); time_idx <= interval.second(); ++time_idx) { total_memory_costs[time_idx] += total_memory_cost; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 475527d96f0a62..0b2b53790a3ddb 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -486,7 +486,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { EXPECT_EQ(result, expected_output); } -TEST(FormulateAndSolveMIPFromSolverRequestTest, +TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, HandlesReducedIntervalsAndGroups) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = @@ -521,7 +521,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, EXPECT_EQ(result, expected_output); } -TEST(FormulateAndSolveMIPFromSolverRequestTest, +TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = @@ -543,7 +543,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, EXPECT_EQ(result, expected_output); } -TEST(FormulateAndSolveMIPFromSolverRequestTest, +TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, HandlesGroupsWithTinyMemoryCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = @@ -674,7 +674,7 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudgetWithIntervals) { EXPECT_EQ(evaluation, expected_evaluation); } -TEST(AutoShardingEvaluatorTest, +TEST(DISABLED_AutoShardingEvaluatorTest, EvaluatesOverbudgetWithReducedIntervalsAndGroups) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = @@ -856,7 +856,7 
@@ TEST(MinimumMemoryBudgetRequired, HandlesLiveMatrix) { EXPECT_EQ(MinimumMemoryBudgetRequired(request), 1000000.0); } -TEST(MinimumMemoryBudgetRequired, HandlesReducedIntervalsAndGroups) { +TEST(DISABLED_MinimumMemoryBudgetRequired, HandlesReducedIntervalsAndGroups) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; From 3be27c9ffabd54594a643cf33ae0fc352aee5b37 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 11 Apr 2025 14:52:12 -0700 Subject: [PATCH 0597/1324] [xla:gpu] CommandBuffer: rename command buffer APIs to Create/Update Add Create/Update prefix to all command buffer APIs to clearly identify command creation vs command update. PiperOrigin-RevId: 746599392 --- .../gpu/runtime/command_buffer_cmd.cc | 112 ++++++++------- .../xla/xla/stream_executor/command_buffer.h | 125 +++++++++-------- .../cuda/cuda_command_buffer.cc | 29 +--- .../cuda/cuda_command_buffer.h | 3 - .../cuda/cuda_command_buffer_test.cc | 14 +- .../stream_executor/gpu/gpu_command_buffer.cc | 130 +++++++++--------- .../stream_executor/gpu/gpu_command_buffer.h | 91 ++++++------ .../gpu/gpu_command_buffer_test.cc | 76 +++++----- .../rocm/rocm_command_buffer.cc | 25 ---- .../rocm/rocm_command_buffer.h | 3 - 10 files changed, 280 insertions(+), 328 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 98c04247c17604..baab3b024cbd89 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -455,15 +455,14 @@ TracedCommandBufferCmd::RecordTracedCommand( execute_params.buffer_allocations, execute_params.stream->parent(), execute_params.command_buffer_trace_stream, trace)); - VLOG(5) << "Add nested command buffer"; + VLOG(5) << "Record traced command into command buffer: " << command_buffer; return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->AddNestedCommandBuffer(*nested_cmd, - dependencies); + return command_buffer->CreateNestedCommand(*nested_cmd, dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->AddNestedCommandBuffer(command, *nested_cmd); + return command_buffer->UpdateNestedCommand(command, *nested_cmd); }); } @@ -507,12 +506,12 @@ absl::StatusOr ComputationIdCmd::Record( return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->Memset(&dst, value, /*num_elements=*/1, - dependencies); + return command_buffer->CreateMemset(&dst, value, /*num_elements=*/1, + dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->Memset(command, &dst, value, - /*num_elements=*/1); + return command_buffer->UpdateMemset(command, &dst, value, + /*num_elements=*/1); }); } @@ -580,14 +579,14 @@ absl::StatusOr LaunchCmd::Record( return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->Launch(dims_.thread_counts_per_block(), - dims_.block_counts(), *kernel, - *kernel_args, dependencies); + return command_buffer->CreateLaunch(dims_.thread_counts_per_block(), + dims_.block_counts(), *kernel, + *kernel_args, dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->Launch(command, dims_.thread_counts_per_block(), - dims_.block_counts(), *kernel, - *kernel_args); + return command_buffer->UpdateLaunch( 
+ command, dims_.thread_counts_per_block(), dims_.block_counts(), + *kernel, *kernel_args); }); } @@ -661,14 +660,14 @@ CustomKernelLaunchCmd::Record(const Thunk::ExecuteParams& execute_params, return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->Launch(custom_kernel_.thread_dims(), - custom_kernel_.block_dims(), *kernel, - kernel_args, dependencies); + return command_buffer->CreateLaunch(custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), + *kernel, kernel_args, dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->Launch(command, custom_kernel_.thread_dims(), - custom_kernel_.block_dims(), *kernel, - kernel_args); + return command_buffer->UpdateLaunch( + command, custom_kernel_.thread_dims(), custom_kernel_.block_dims(), + *kernel, kernel_args); }); } @@ -715,12 +714,11 @@ MemcpyDeviceToDeviceCmd::Record(const Thunk::ExecuteParams& execute_params, return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_, - dependencies); + return command_buffer->CreateMemcpyD2D(&dst, src, num_bytes_, + dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->MemcpyDeviceToDevice(command, &dst, src, - num_bytes_); + return command_buffer->UpdateMemcpyD2D(command, &dst, src, num_bytes_); }); } @@ -755,13 +753,13 @@ absl::StatusOr MemzeroCmd::Record( return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->Memset(&dst, uint8_t{0}, - /*num_elements=*/dst_.size(), - dependencies); + return command_buffer->CreateMemset(&dst, uint8_t{0}, + /*num_elements=*/dst_.size(), + dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->Memset(command, &dst, uint8_t{0}, - /*num_elements=*/dst_.size()); + return command_buffer->UpdateMemset(command, &dst, uint8_t{0}, + /*num_elements=*/dst_.size()); }); } @@ -797,12 +795,12 @@ absl::StatusOr Memset32Cmd::Record( return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->Memset( + return command_buffer->CreateMemset( &dst, bit_pattern_, /*num_elements=*/dst_.size() / sizeof(uint32_t), dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->Memset( + return command_buffer->UpdateMemset( command, &dst, bit_pattern_, /*num_elements=*/dst_.size() / sizeof(uint32_t)); }); @@ -849,22 +847,22 @@ absl::StatusOr CaseCmd::Record( std::move(record_action), [&](absl::Span dependencies) { if (index_is_bool_) { - return command_buffer->Case(se::DeviceMemory(index), - std::move(branches), dependencies); + return command_buffer->CreateCase(se::DeviceMemory(index), + std::move(branches), dependencies); } else { - return command_buffer->Case(se::DeviceMemory(index), - std::move(branches), dependencies); + return command_buffer->CreateCase(se::DeviceMemory(index), + std::move(branches), dependencies); } }, [&](const se::CommandBuffer::Command* command) { if (index_is_bool_) { - return command_buffer->Case(command, se::DeviceMemory(index), - std::move(branches)); + return command_buffer->UpdateCase( + command, se::DeviceMemory(index), std::move(branches)); } else { - return command_buffer->Case(command, se::DeviceMemory(index), - std::move(branches)); + return command_buffer->UpdateCase( + command, se::DeviceMemory(index), std::move(branches)); } }); } @@ -920,13 +918,14 @@ absl::StatusOr WhileCmd::Record( return Handle( 
std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->While(se::DeviceMemory(pred), - std::move(cond), std::move(body), - dependencies); + return command_buffer->CreateWhile(se::DeviceMemory(pred), + std::move(cond), std::move(body), + dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->While(command, se::DeviceMemory(pred), - std::move(cond), std::move(body)); + return command_buffer->UpdateWhile(command, + se::DeviceMemory(pred), + std::move(cond), std::move(body)); }); } @@ -1218,7 +1217,7 @@ absl::StatusOr CuDnnCmd::Record( const bool supports_explicit, graph_->get()->SupportsExplicitCommandBufferConstruction()); if (supports_explicit) { - return RecordedCommands::Create(command_buffer->DnnGraph( + return RecordedCommands::Create(command_buffer->CreateDnnGraphCommand( *graph_->get(), *execute_params.stream, absl::Span(operands), {})); } @@ -1318,11 +1317,10 @@ CustomCallCmd::RecordLegacyCustomCall( return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->AddNestedCommandBuffer(*nested_cmd, - dependencies); + return command_buffer->CreateNestedCommand(*nested_cmd, dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->AddNestedCommandBuffer(command, *nested_cmd); + return command_buffer->UpdateNestedCommand(command, *nested_cmd); }); } @@ -1400,11 +1398,10 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->AddNestedCommandBuffer(*nested_cmd, - dependencies); + return command_buffer->CreateNestedCommand(*nested_cmd, dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->AddNestedCommandBuffer(command, *nested_cmd); + return command_buffer->UpdateNestedCommand(command, *nested_cmd); }); } @@ -1458,11 +1455,10 @@ CollectiveCmd::RecordTracedCommand( return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->AddNestedCommandBuffer(*nested_cmd, - dependencies); + return command_buffer->CreateNestedCommand(*nested_cmd, dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->AddNestedCommandBuffer(command, *nested_cmd); + return command_buffer->UpdateNestedCommand(command, *nested_cmd); }); } @@ -2020,12 +2016,12 @@ DynamicSliceFusionCmd::Record(const Thunk::ExecuteParams& execute_params, return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->AddNestedCommandBuffer(*nested_command_buffer, - dependencies); + return command_buffer->CreateNestedCommand(*nested_command_buffer, + dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->AddNestedCommandBuffer(command, - *nested_command_buffer); + return command_buffer->UpdateNestedCommand(command, + *nested_command_buffer); }); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index dbc7bbab3d3da1..fa98781c27733b 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -127,21 +127,23 @@ class CommandBuffer { // Command buffer API //===--------------------------------------------------------------------===// - // Adds a kernel launch command. - virtual absl::StatusOr Launch( + // Creates a kernel launch command. 
+  virtual absl::StatusOr<const Command*> CreateLaunch(
       const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel,
       const KernelArgs& args,
       absl::Span<const Command* const> dependencies) = 0;
 
   // Updates a kernel launch command.
-  virtual absl::Status Launch(const Command* command, const ThreadDim& threads,
-                              const BlockDim& blocks, const Kernel& kernel,
-                              const KernelArgs& args) = 0;
+  virtual absl::Status UpdateLaunch(const Command* command,
+                                    const ThreadDim& threads,
+                                    const BlockDim& blocks,
+                                    const Kernel& kernel,
+                                    const KernelArgs& args) = 0;
 
   // Type-safe wrapper for launching typed kernels. Notice that the order of
   // arguments is different to disambiguate from the regular launch API.
   template <typename... Args>
-  absl::StatusOr<const Command*> Launch(
+  absl::StatusOr<const Command*> CreateLaunch(
       const TypedKernel<Args...>& kernel, const ThreadDim& threads,
       const BlockDim& blocks, absl::Span<const Command* const> dependencies,
       Args... args);
@@ -149,79 +151,88 @@ class CommandBuffer {
   // Type-safe wrapper for updating typed kernels. Notice that the order of
   // arguments is different to disambiguate from the regular launch API.
   template <typename... Args>
-  absl::Status Launch(const Command* command,
-                      const TypedKernel<Args...>& kernel,
-                      const ThreadDim& threads, const BlockDim& blocks,
-                      Args... args);
+  absl::Status UpdateLaunch(const Command* command,
+                            const TypedKernel<Args...>& kernel,
+                            const ThreadDim& threads, const BlockDim& blocks,
+                            Args... args);
 
-  // Adds a nested command buffer.
-  virtual absl::StatusOr<const Command*> AddNestedCommandBuffer(
+  // Creates a command that launches a nested command buffer.
+  virtual absl::StatusOr<const Command*> CreateNestedCommand(
       const CommandBuffer& nested,
      absl::Span<const Command* const> dependencies) = 0;
 
-  // Updates a nested command buffer.
-  virtual absl::Status AddNestedCommandBuffer(const Command* command,
-                                              const CommandBuffer& nested) = 0;
+  // Updates a command that launches a nested command buffer.
+  virtual absl::Status UpdateNestedCommand(const Command* command,
+                                           const CommandBuffer& nested) = 0;
 
-  // Adds a device-to-device memory copy.
-  virtual absl::StatusOr<const Command*> MemcpyDeviceToDevice(
+  // Creates a device-to-device memory copy.
+  virtual absl::StatusOr<const Command*> CreateMemcpyD2D(
       DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size,
       absl::Span<const Command* const> dependencies) = 0;
 
   // Updates a device-to-device memory copy.
-  virtual absl::Status MemcpyDeviceToDevice(const Command* command,
-                                            DeviceMemoryBase* dst,
-                                            const DeviceMemoryBase& src,
-                                            uint64_t size) = 0;
+  virtual absl::Status UpdateMemcpyD2D(const Command* command,
+                                       DeviceMemoryBase* dst,
+                                       const DeviceMemoryBase& src,
+                                       uint64_t size) = 0;
 
-  // Adds a memset command.
-  virtual absl::StatusOr<const Command*> Memset(
+  // Creates a memset command.
+  virtual absl::StatusOr<const Command*> CreateMemset(
       DeviceMemoryBase* dst, BitPattern bit_pattern, size_t num_elements,
       absl::Span<const Command* const> dependencies) = 0;
 
   // Updates a memset command.
-  virtual absl::Status Memset(const Command* command, DeviceMemoryBase* dst,
-                              const BitPattern& bit_pattern,
-                              size_t num_elements) = 0;
+  virtual absl::Status UpdateMemset(const Command* command,
+                                    DeviceMemoryBase* dst,
+                                    const BitPattern& bit_pattern,
+                                    size_t num_elements) = 0;
 
-  // Adds a DNN graph launch command.
-  virtual absl::StatusOr<const Command*> DnnGraph(
+  //--------------------------------------------------------------------------//
+  // Command buffer DNN graph API
+  //--------------------------------------------------------------------------//
+
+  // Creates a DNN graph launch command.
+  virtual absl::StatusOr<const Command*> CreateDnnGraphCommand(
       dnn::DnnGraph&, Stream&, absl::Span<DeviceMemoryBase> operands,
       absl::Span<const Command* const> dependencies) = 0;
 
   // Updates a DNN graph command.
   // Updates a DNN graph command.
-  virtual absl::Status DnnGraph(const Command*, dnn::DnnGraph&, Stream&,
-                                absl::Span<DeviceMemoryBase> operands) = 0;
+  virtual absl::Status UpdateDnnGraphCommand(
+      const Command*, dnn::DnnGraph&, Stream&,
+      absl::Span<DeviceMemoryBase> operands) = 0;

   //--------------------------------------------------------------------------//
   // Command buffer conditional commands API
   //--------------------------------------------------------------------------//

-  // Adds a conditional operation that will execute a command buffer constructed
-  // by the `branches` builder at `index`. If `index` is out of range, then it
-  // will run a conditional command buffer constructed by the last builder.
+  // Creates a conditional operation that will execute a command buffer
+  // constructed by the `branches` builder at `index`. If `index` is out of
+  // range, then it will run a conditional command buffer constructed by the
+  // last builder.
   //
   // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case
-  virtual absl::StatusOr<const Command*> Case(
+  virtual absl::StatusOr<const Command*> CreateCase(
       DeviceMemory<int32_t> index, std::vector<Builder> branches,
       absl::Span<const Command* const> dependencies) = 0;

-  virtual absl::StatusOr<const Command*> Case(
+  virtual absl::StatusOr<const Command*> CreateCase(
       DeviceMemory<bool> index, std::vector<Builder> branches,
       absl::Span<const Command* const> dependencies) = 0;

-  // Updates a Case operation.
-  virtual absl::Status Case(const Command* command, DeviceMemory<int32_t> index,
-                            std::vector<Builder> branches) = 0;
+  // Updates a Case command.
+  virtual absl::Status UpdateCase(const Command* command,
+                                  DeviceMemory<int32_t> index,
+                                  std::vector<Builder> branches) = 0;

-  virtual absl::Status Case(const Command* command, DeviceMemory<bool> index,
-                            std::vector<Builder> branches) = 0;
+  virtual absl::Status UpdateCase(const Command* command,
+                                  DeviceMemory<bool> index,
+                                  std::vector<Builder> branches) = 0;

-  // Adds a conditional operation that will execute a command buffer constructed
-  // by the `cond_builder` that must update `pred` value, and then depending on
-  // the value might execute command buffer constructed by `body_builder` and
-  // `cond_builder`. Will continue while `pred` value (which is continuously
-  // updated by `cond_builder`) is `true`.
+  // Creates a conditional operation that will execute a command buffer
+  // constructed by `cond_builder`, which must update the `pred` value; then,
+  // depending on that value, it may execute the command buffers constructed by
+  // `body_builder` and `cond_builder`. Execution continues while the `pred`
+  // value (which is continuously updated by `cond_builder`) is `true`.
   //
   // In pseudocode:
   //
@@ -230,13 +241,15 @@ class CommandBuffer {
   //   while(pred):
   //     body_builder()
   //     cond_builder()
   //
-  virtual absl::StatusOr<const Command*> While(
+  virtual absl::StatusOr<const Command*> CreateWhile(
       DeviceMemory<bool> pred, Builder cond_builder, Builder body_builder,
       absl::Span<const Command* const> dependencies) = 0;

-  // Updates a While operation.
-  virtual absl::Status While(const Command* command, DeviceMemory<bool> pred,
-                             Builder cond_builder, Builder body_builder) = 0;
+  // Updates a While command.
+  virtual absl::Status UpdateWhile(const Command* command,
+                                   DeviceMemory<bool> pred,
+                                   Builder cond_builder,
+                                   Builder body_builder) = 0;

   // Submits the command buffer for execution.
   virtual absl::Status Submit(Stream* stream) {
@@ -283,21 +296,21 @@ class CommandBuffer {
 //===----------------------------------------------------------------------===//

 template <typename... Args>
-absl::StatusOr<const CommandBuffer::Command*> CommandBuffer::Launch(
+absl::StatusOr<const CommandBuffer::Command*> CommandBuffer::CreateLaunch(
     const TypedKernel<Args...>& kernel, const ThreadDim& threads,
     const BlockDim& blocks, absl::Span<const Command* const> dependencies,
     Args...
args) { auto kernel_args = PackKernelArgs(kernel, args...); - return Launch(threads, blocks, *kernel, *kernel_args, dependencies); + return CreateLaunch(threads, blocks, *kernel, *kernel_args, dependencies); } template -absl::Status CommandBuffer::Launch(const Command* command, - const TypedKernel& kernel, - const ThreadDim& threads, - const BlockDim& blocks, Args... args) { +absl::Status CommandBuffer::UpdateLaunch(const Command* command, + const TypedKernel& kernel, + const ThreadDim& threads, + const BlockDim& blocks, Args... args) { auto kernel_args = PackKernelArgs(kernel, args...); - return Launch(command, threads, blocks, *kernel, *kernel_args); + return UpdateLaunch(command, threads, blocks, *kernel, *kernel_args); } } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc index 8e94a896e25188..e156f4a4c88260 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc @@ -639,8 +639,7 @@ absl::Status CudaCommandBuffer::PrepareFinalization() { } TF_ASSIGN_OR_RETURN(NoOpKernel * noop, GetNoOpKernel()); - TF_RETURN_IF_ERROR( - CommandBuffer::Launch(*noop, ThreadDim(), BlockDim(), {}).status()); + TF_RETURN_IF_ERROR(CreateLaunch(*noop, ThreadDim(), BlockDim(), {}).status()); return absl::OkStatus(); } @@ -746,31 +745,5 @@ absl::Status CudaCommandBuffer::CheckCanBeUpdated() { return absl::OkStatus(); } -absl::StatusOr> -CudaCommandBuffer::GetNodeDependencies(GraphNodeHandle node) { - VLOG(2) << "Get CUDA graph node " << node << " dependencies"; - - std::vector dependencies; - - size_t num_dependencies = 0; - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuGraphNodeGetDependencies(ToCudaGraphHandle(node), - nullptr, &num_dependencies), - "Failed to get CUDA graph node depedencies size")); - - dependencies.resize(num_dependencies, nullptr); - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuGraphNodeGetDependencies(ToCudaGraphHandle(node), dependencies.data(), - &num_dependencies), - "Failed to get CUDA graph node depedencies")); - - std::vector result; - result.reserve(dependencies.size()); - absl::c_transform( - dependencies, std::back_inserter(result), - static_cast(&FromCudaGraphHandle)); - - return result; -} } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h index 61b69ad8c5121d..3cb1fa14203c41 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h @@ -175,9 +175,6 @@ class CudaCommandBuffer final : public GpuCommandBuffer { absl::Status CheckCanBeUpdated() override; - absl::StatusOr> GetNodeDependencies( - GraphNodeHandle node) override; - // A signature of a device kernels updating conditional handle(s). using SetCaseConditionKernel = TypedKernel cmd_buffer, executor->CreateCommandBuffer(primary)); - TF_ASSERT_OK( - cmd_buffer - ->DnnGraph(graph, *stream, absl::Span(operands), {}) - .status()); + TF_ASSERT_OK_AND_ASSIGN( + auto* dnn_command, + cmd_buffer->CreateDnnGraphCommand( + graph, *stream, absl::Span(operands), {})); TF_ASSERT_OK(cmd_buffer->Finalize()); std::vector host_buffer(output0.ElementCount()); @@ -144,10 +144,8 @@ TEST(CudaCommandBufferTest, CuDnnExplicitConstructionAndUpdateWork) { // Update the command buffer to write into the new output buffer. 
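The explicit construction API makes updates addressable: the `dnn_command` handle returned by `CreateDnnGraphCommand` above is what the update below targets, instead of re-recording the graph positionally. Condensed, the create-then-update pattern is (a paraphrase of this test, with the surrounding setup assumed):

  TF_ASSERT_OK_AND_ASSIGN(
      auto* dnn_command,
      cmd_buffer->CreateDnnGraphCommand(
          graph, *stream, absl::Span<DeviceMemoryBase>(operands), {}));
  TF_ASSERT_OK(cmd_buffer->Finalize());
  // ... later, after swapping an operand buffer:
  TF_ASSERT_OK(cmd_buffer->Update());
  TF_ASSERT_OK(cmd_buffer->UpdateDnnGraphCommand(
      dnn_command, graph, *stream, absl::Span<DeviceMemoryBase>(operands)));
  TF_ASSERT_OK(cmd_buffer->Finalize());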
TF_ASSERT_OK(cmd_buffer->Update()); - TF_ASSERT_OK( - cmd_buffer - ->DnnGraph(graph, *stream, absl::Span(operands), {}) - .status()); + TF_ASSERT_OK(cmd_buffer->UpdateDnnGraphCommand( + dnn_command, graph, *stream, absl::Span(operands))); TF_ASSERT_OK(cmd_buffer->Finalize()); // Run the computation. diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 3947d87cfbb2c7..8b81658ee401df 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -131,7 +131,7 @@ absl::Status GpuCommandBuffer::CheckInState(State state) { } absl::StatusOr -GpuCommandBuffer::LaunchWithPackedArgs( +GpuCommandBuffer::CreateLaunchWithPackedArgs( const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& packed_args, absl::Span dependencies) { @@ -148,10 +148,10 @@ GpuCommandBuffer::LaunchWithPackedArgs( GraphNodeHandle handle, CreateKernelNode(barrier, threads, blocks, kernel, packed_args)); - return AppendCommand(handle); + return AppendCommand(GpuCommand{handle}); } -absl::Status GpuCommandBuffer::LaunchWithPackedArgs( +absl::Status GpuCommandBuffer::UpdateLaunchWithPackedArgs( const Command* command, const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& packed_args) { TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); @@ -160,14 +160,15 @@ absl::Status GpuCommandBuffer::LaunchWithPackedArgs( packed_args); } -absl::StatusOr GpuCommandBuffer::Launch( +absl::StatusOr GpuCommandBuffer::CreateLaunch( const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgs& args, absl::Span dependencies) { TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); // If arguments are already packed we can just launch the kernel. if (auto* packed = DynCast(&args)) { - return LaunchWithPackedArgs(threads, blocks, kernel, *packed, dependencies); + return CreateLaunchWithPackedArgs(threads, blocks, kernel, *packed, + dependencies); } // For device memory array we rely on a custom kernel arguments packing. @@ -180,22 +181,24 @@ absl::StatusOr GpuCommandBuffer::Launch( } TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); - return LaunchWithPackedArgs(threads, blocks, kernel, *packed, dependencies); + return CreateLaunchWithPackedArgs(threads, blocks, kernel, *packed, + dependencies); } return absl::InternalError("Unsupported kernel arguments type"); } -absl::Status GpuCommandBuffer::Launch(const Command* command, - const ThreadDim& threads, - const BlockDim& blocks, - const Kernel& kernel, - const KernelArgs& args) { +absl::Status GpuCommandBuffer::UpdateLaunch(const Command* command, + const ThreadDim& threads, + const BlockDim& blocks, + const Kernel& kernel, + const KernelArgs& args) { TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); // If arguments are already packed we can just launch the kernel. if (auto* packed = DynCast(&args)) { - return LaunchWithPackedArgs(command, threads, blocks, kernel, *packed); + return UpdateLaunchWithPackedArgs(command, threads, blocks, kernel, + *packed); } // For device memory array we rely on a custom kernel arguments packing. 
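From the caller's side, both the create and update entry points accept either pre-packed arguments or a device-memory array that is packed through the kernel's own `args_packing()` function. A hedged sketch of the second path (buffer names hypothetical, and the `KernelArgsDeviceMemoryArray` constructor shape is an assumption about the StreamExecutor kernel API, not something this patch defines):

  // Type-erased launch: the command buffer packs `args` internally via
  // kernel.args_packing() before creating the graph node.
  KernelArgsDeviceMemoryArray args({a, b, c}, /*shared_memory_bytes=*/0);
  TF_ASSIGN_OR_RETURN(
      const CommandBuffer::Command* launch,
      cmd_buffer->CreateLaunch(ThreadDim(), BlockDim(4), *kernel, args,
                               /*dependencies=*/{}));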
@@ -208,14 +211,15 @@ absl::Status GpuCommandBuffer::Launch(const Command* command, } TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); - return LaunchWithPackedArgs(command, threads, blocks, kernel, *packed); + return UpdateLaunchWithPackedArgs(command, threads, blocks, kernel, + *packed); } return absl::InternalError("Unsupported kernel arguments type"); } absl::StatusOr -GpuCommandBuffer::AddNestedCommandBuffer( +GpuCommandBuffer::CreateNestedCommand( const CommandBuffer& nested, absl::Span dependencies) { TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); @@ -225,18 +229,17 @@ GpuCommandBuffer::AddNestedCommandBuffer( : ToGraphNodeDependencies(dependencies); TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, CreateChildNode(barrier, nested)); - return AppendCommand(handle); + return AppendCommand(GpuCommand{handle}); } -absl::Status GpuCommandBuffer::AddNestedCommandBuffer( +absl::Status GpuCommandBuffer::UpdateNestedCommand( const Command* command, const CommandBuffer& nested) { TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); auto* gpu_command = tsl::down_cast(command); return UpdateChildNode(gpu_command->handle, nested); } -absl::StatusOr -GpuCommandBuffer::MemcpyDeviceToDevice( +absl::StatusOr GpuCommandBuffer::CreateMemcpyD2D( DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size, absl::Span dependencies) { TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); @@ -247,19 +250,19 @@ GpuCommandBuffer::MemcpyDeviceToDevice( TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, CreateMemcpyD2DNode(barrier, *dst, src, size)); - return AppendCommand(handle); + return AppendCommand(GpuCommand{handle}); } -absl::Status GpuCommandBuffer::MemcpyDeviceToDevice(const Command* command, - DeviceMemoryBase* dst, - const DeviceMemoryBase& src, - uint64_t size) { +absl::Status GpuCommandBuffer::UpdateMemcpyD2D(const Command* command, + DeviceMemoryBase* dst, + const DeviceMemoryBase& src, + uint64_t size) { TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); auto* gpu_command = tsl::down_cast(command); return UpdateMemcpyD2DNode(gpu_command->handle, *dst, src, size); } -absl::StatusOr GpuCommandBuffer::Memset( +absl::StatusOr GpuCommandBuffer::CreateMemset( DeviceMemoryBase* dst, BitPattern bit_pattern, size_t num_elements, absl::Span dependencies) { TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); @@ -271,19 +274,24 @@ absl::StatusOr GpuCommandBuffer::Memset( GraphNodeHandle handle, CreateMemsetNode(barrier, *dst, bit_pattern, num_elements)); - return AppendCommand(handle); + return AppendCommand(GpuCommand{handle}); } -absl::Status GpuCommandBuffer::Memset(const Command* command, - DeviceMemoryBase* dst, - const BitPattern& bit_pattern, - size_t num_elements) { +absl::Status GpuCommandBuffer::UpdateMemset(const Command* command, + DeviceMemoryBase* dst, + const BitPattern& bit_pattern, + size_t num_elements) { TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); auto* gpu_command = tsl::down_cast(command); return UpdateMemsetNode(gpu_command->handle, *dst, bit_pattern, num_elements); } -absl::StatusOr GpuCommandBuffer::DnnGraph( +//----------------------------------------------------------------------------// +// Command buffer DNN graph API +//----------------------------------------------------------------------------// + +absl::StatusOr +GpuCommandBuffer::CreateDnnGraphCommand( dnn::DnnGraph& dnn_graph, Stream& stream, absl::Span operands, absl::Span dependencies) { @@ -301,13 +309,12 @@ absl::StatusOr GpuCommandBuffer::DnnGraph( TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, 
                       CreateChildNode(barrier, *nested));
-  return AppendCommand(handle);
+  return AppendCommand(GpuCommand{handle});
 }

-absl::Status GpuCommandBuffer::DnnGraph(const Command* command,
-                                        dnn::DnnGraph& dnn_graph,
-                                        Stream& stream,
-                                        absl::Span<DeviceMemoryBase> operands) {
+absl::Status GpuCommandBuffer::UpdateDnnGraphCommand(
+    const Command* command, dnn::DnnGraph& dnn_graph, Stream& stream,
+    absl::Span<DeviceMemoryBase> operands) {
   TF_RETURN_IF_ERROR(CheckInState(State::kUpdate));

   auto* gpu_command = tsl::down_cast<const GpuCommand*>(command);
@@ -320,9 +327,9 @@ absl::Status GpuCommandBuffer::DnnGraph(const Command* command,
   return UpdateChildNode(gpu_command->handle, *nested);
 }

-//--------------------------------------------------------------------------//
+//----------------------------------------------------------------------------//
 // Command buffer conditional commands API
-//--------------------------------------------------------------------------//
+//----------------------------------------------------------------------------//

 absl::StatusOr<std::vector<GraphConditionalHandle>>
 GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) {
@@ -334,7 +341,7 @@ GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) {
   return handles;
 }

-absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::Case(
+absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateCase(
     DeviceMemory<int32_t> index, bool index_is_bool,
     std::vector<Builder> branches,
     absl::Span<const Command* const> dependencies) {
@@ -403,10 +410,10 @@ absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::Case(
   return AppendCommand(std::move(command));
 }

-absl::Status GpuCommandBuffer::Case(const Command* command,
-                                    DeviceMemory<int32_t> index,
-                                    bool index_is_bool,
-                                    std::vector<Builder> branches) {
+absl::Status GpuCommandBuffer::UpdateCase(const Command* command,
+                                          DeviceMemory<int32_t> index,
+                                          bool index_is_bool,
+                                          std::vector<Builder> branches) {
   TF_RETURN_IF_ERROR(CheckInState(State::kUpdate));

   constexpr size_t kBranchBatchSize = 8;
@@ -451,41 +458,41 @@ absl::Status GpuCommandBuffer::Case(const Command* command,
   return absl::OkStatus();
 }

-absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::Case(
+absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateCase(
     DeviceMemory<int32_t> index, std::vector<Builder> branches,
     absl::Span<const Command* const> dependencies) {
-  return Case(
+  return CreateCase(
       DeviceMemory<int32_t>::MakeFromByteSize(index.opaque(), index.size()),
       /*index_is_bool=*/false, branches, dependencies);
 }

-absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::Case(
+absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateCase(
     DeviceMemory<bool> index, std::vector<Builder> branches,
     absl::Span<const Command* const> dependencies) {
-  return Case(
+  return CreateCase(
       DeviceMemory<int32_t>::MakeFromByteSize(index.opaque(), index.size()),
       /*index_is_bool=*/true, branches, dependencies);
 }

-absl::Status GpuCommandBuffer::Case(const Command* command,
-                                    DeviceMemory<int32_t> index,
-                                    std::vector<Builder> branches) {
-  return Case(
+absl::Status GpuCommandBuffer::UpdateCase(const Command* command,
+                                          DeviceMemory<int32_t> index,
+                                          std::vector<Builder> branches) {
+  return UpdateCase(
       command,
       DeviceMemory<int32_t>::MakeFromByteSize(index.opaque(), index.size()),
       /*index_is_bool=*/false, branches);
 }

-absl::Status GpuCommandBuffer::Case(const Command* command,
-                                    DeviceMemory<bool> index,
-                                    std::vector<Builder> branches) {
-  return Case(
+absl::Status GpuCommandBuffer::UpdateCase(const Command* command,
+                                          DeviceMemory<bool> index,
+                                          std::vector<Builder> branches) {
+  return UpdateCase(
       command,
       DeviceMemory<int32_t>::MakeFromByteSize(index.opaque(), index.size()),
       /*index_is_bool=*/true, branches);
 }

-absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::While(
+absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateWhile(
     DeviceMemory<bool> pred, Builder cond_builder, Builder body_builder,
     absl::Span<const Command* const> dependencies) {
   TF_RETURN_IF_ERROR(CheckInState(State::kCreate));
@@ -522,10 +529,10 @@ absl::StatusOr
GpuCommandBuffer::While( return AppendCommand(std::move(command)); } -absl::Status GpuCommandBuffer::While(const Command* command, - DeviceMemory pred, - Builder cond_builder, - Builder body_builder) { +absl::Status GpuCommandBuffer::UpdateWhile(const Command* command, + DeviceMemory pred, + Builder cond_builder, + Builder body_builder) { TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); auto* gpu_command = tsl::down_cast(command); @@ -619,9 +626,8 @@ absl::Status GpuCommandBuffer::Update() { "Command buffer has to be finalized first before it can be updated"); } - VLOG(5) << "Begin update of " - << (mode_ == Mode::kPrimary ? "primary" : "nested") - << " command buffer " << this; + VLOG(5) << "Begin update of " << absl::StrCat(mode_) << " command buffer " + << this; state_ = State::kUpdate; return absl::OkStatus(); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 7381e973f7acb2..13c379cb057943 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include @@ -30,6 +29,7 @@ limitations under the License. #include "xla/stream_executor/bit_pattern.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/scoped_update_mode.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" @@ -123,68 +123,70 @@ class GpuCommandBuffer : public CommandBuffer { GpuCommandBuffer(Mode mode, StreamExecutor* parent); - using CommandBuffer::Launch; + // Bring CreateLaunch and UpdateLaunch template functions into scope. 
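The using-declarations matter because of C++ name hiding: declaring the virtual `CreateLaunch`/`UpdateLaunch` overrides in `GpuCommandBuffer` would otherwise hide the base-class template overloads that share those names. A standalone illustration of the language rule (hypothetical types, not part of the patch):

  struct Base {
    virtual void Run(int x) {}
    template <typename T>
    void Run(T* t, int x) {}  // template overload sharing the name
  };
  struct Derived : Base {
    using Base::Run;  // without this, the override below hides Base::Run<T>
    void Run(int x) override {}
  };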
+ using CommandBuffer::CreateLaunch; + using CommandBuffer::UpdateLaunch; - absl::StatusOr Launch( + absl::StatusOr CreateLaunch( const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgs& args, absl::Span dependencies) override; - absl::Status Launch(const Command* command, const ThreadDim& threads, - const BlockDim& blocks, const Kernel& kernel, - const KernelArgs& args) override; + absl::Status UpdateLaunch(const Command* command, const ThreadDim& threads, + const BlockDim& blocks, const Kernel& kernel, + const KernelArgs& args) override; - absl::StatusOr AddNestedCommandBuffer( + absl::StatusOr CreateNestedCommand( const CommandBuffer& nested, absl::Span dependencies) override; - absl::Status AddNestedCommandBuffer(const Command* command, - const CommandBuffer& nested) override; + absl::Status UpdateNestedCommand(const Command* command, + const CommandBuffer& nested) override; - absl::StatusOr MemcpyDeviceToDevice( + absl::StatusOr CreateMemcpyD2D( DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size, absl::Span dependencies) override; - absl::Status MemcpyDeviceToDevice(const Command* command, - DeviceMemoryBase* dst, - const DeviceMemoryBase& src, - uint64_t size) override; + absl::Status UpdateMemcpyD2D(const Command* command, DeviceMemoryBase* dst, + const DeviceMemoryBase& src, + uint64_t size) override; - absl::StatusOr Memset( + absl::StatusOr CreateMemset( DeviceMemoryBase* dst, BitPattern bit_pattern, size_t num_elements, absl::Span dependencies) override; - absl::Status Memset(const Command* command, DeviceMemoryBase* dst, - const BitPattern& bit_pattern, - size_t num_elements) override; + absl::Status UpdateMemset(const Command* command, DeviceMemoryBase* dst, + const BitPattern& bit_pattern, + size_t num_elements) override; - absl::StatusOr Case( - DeviceMemory index, std::vector branches, - absl::Span dependencies) override; - - absl::StatusOr DnnGraph( + absl::StatusOr CreateDnnGraphCommand( dnn::DnnGraph&, Stream&, absl::Span operands, absl::Span dependencies) override; - absl::Status DnnGraph(const Command*, dnn::DnnGraph&, Stream&, - absl::Span operands) override; + absl::Status UpdateDnnGraphCommand( + const Command*, dnn::DnnGraph&, Stream&, + absl::Span operands) override; + + absl::StatusOr CreateCase( + DeviceMemory index, std::vector branches, + absl::Span dependencies) override; - absl::StatusOr Case( + absl::StatusOr CreateCase( DeviceMemory index, std::vector branches, absl::Span dependencies) override; - absl::Status Case(const Command* command, DeviceMemory index, - std::vector branches) override; + absl::Status UpdateCase(const Command* command, DeviceMemory index, + std::vector branches) override; - absl::Status Case(const Command* command, DeviceMemory index, - std::vector branches) override; + absl::Status UpdateCase(const Command* command, DeviceMemory index, + std::vector branches) override; - absl::StatusOr While( + absl::StatusOr CreateWhile( DeviceMemory pred, Builder cond_builder, Builder body_builder, absl::Span dependencies) override; - absl::Status While(const Command* command, DeviceMemory pred, - Builder cond_builder, Builder body_builder) override; + absl::Status UpdateWhile(const Command* command, DeviceMemory pred, + Builder cond_builder, Builder body_builder) override; absl::Status Finalize() override; absl::Status Update() override; @@ -195,12 +197,6 @@ class GpuCommandBuffer : public CommandBuffer { absl::Span> commands() const; - // Returns the list of dependencies for a given node. 
`node` must be a node - // added to the current command buffer. The returned node pointer's lifetimes - // are bound to the current command buffer. - virtual absl::StatusOr> GetNodeDependencies( - GraphNodeHandle node) = 0; - protected: // We track the total number of allocated and alive executable graphs in the // process to track the command buffers resource usage. Executable graph @@ -260,13 +256,13 @@ class GpuCommandBuffer : public CommandBuffer { //===--------------------------------------------------------------------===// // Launches CUDA kernels with packed arguments. - absl::StatusOr LaunchWithPackedArgs( + absl::StatusOr CreateLaunchWithPackedArgs( const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& packed_args, absl::Span dependencies); // Updates a kernel launch command with packed arguments. - absl::Status LaunchWithPackedArgs( + absl::Status UpdateLaunchWithPackedArgs( const Command* command, const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& packed_args); @@ -283,20 +279,13 @@ class GpuCommandBuffer : public CommandBuffer { virtual absl::Status CheckCanBeUpdated() = 0; private: - absl::StatusOr Case( + absl::StatusOr CreateCase( DeviceMemory index, bool index_is_bool, std::vector branches, absl::Span dependencies); - absl::Status Case(const Command* command, DeviceMemory index, - bool index_is_bool, std::vector branches); - - // Constructs a new command for the given graph node handle and appends it to - // the command buffer. - const Command* AppendCommand(GraphNodeHandle handle) { - commands_.push_back(std::make_unique(handle)); - return commands_.back().get(); - } + absl::Status UpdateCase(const Command* command, DeviceMemory index, + bool index_is_bool, std::vector branches); // Appends a new command to the command buffer. template diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index 3c1638a1340464..121b6378fa9970 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -107,7 +107,7 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { executor->CreateCommandBuffer(primary)); TF_ASSERT_OK_AND_ASSIGN( auto* launch, - cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), {}, a, b, c)); + cmd_buffer->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, c)); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -126,7 +126,7 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { // Update command buffer to write into `d` buffer. 
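`Update()` is only legal on a finalized buffer (the `Update` implementation earlier in this diff rejects anything else), so the update cycle is always finalize, update, re-record against saved handles, finalize again. Condensed form of the steps below (a paraphrase, not new patch content):

  TF_ASSERT_OK(cmd_buffer->Finalize());  // required before Update()
  TF_ASSERT_OK(cmd_buffer->Update());    // enter update state
  TF_ASSERT_OK(cmd_buffer->UpdateLaunch(launch, add, ThreadDim(), BlockDim(4),
                                        a, b, d));  // patch the saved command
  TF_ASSERT_OK(cmd_buffer->Finalize());  // back to executable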
TF_ASSERT_OK(cmd_buffer->Update()); TF_ASSERT_OK( - cmd_buffer->Launch(launch, add, ThreadDim(), BlockDim(4), a, b, d)); + cmd_buffer->UpdateLaunch(launch, add, ThreadDim(), BlockDim(4), a, b, d)); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -224,9 +224,10 @@ TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) { executor->CreateCommandBuffer(primary)); TF_ASSERT_OK_AND_ASSIGN(auto nested_cmd, executor->CreateCommandBuffer(nested)); - TF_ASSERT_OK(nested_cmd->Launch(add, ThreadDim(), BlockDim(4), {}, a, b, c)); + TF_ASSERT_OK( + nested_cmd->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, c)); TF_ASSERT_OK_AND_ASSIGN(auto* nested_command, - primary_cmd->AddNestedCommandBuffer(*nested_cmd, {})); + primary_cmd->CreateNestedCommand(*nested_cmd, {})); TF_ASSERT_OK(primary_cmd->Finalize()); TF_ASSERT_OK(primary_cmd->Submit(stream.get())); @@ -245,10 +246,10 @@ TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) { // Update command buffer to write into `d` buffer by creating a new nested // command buffer. nested_cmd = executor->CreateCommandBuffer(nested).value(); - TF_ASSERT_OK(nested_cmd->Launch(add, ThreadDim(), BlockDim(4), {}, a, b, d)); - TF_ASSERT_OK(primary_cmd->Update()); TF_ASSERT_OK( - primary_cmd->AddNestedCommandBuffer(nested_command, *nested_cmd)); + nested_cmd->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, d)); + TF_ASSERT_OK(primary_cmd->Update()); + TF_ASSERT_OK(primary_cmd->UpdateNestedCommand(nested_command, *nested_cmd)); TF_ASSERT_OK(primary_cmd->Finalize()); TF_ASSERT_OK(primary_cmd->Submit(stream.get())); @@ -277,8 +278,8 @@ TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) { // Create a command buffer with a single a to b memcpy command. TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer, executor->CreateCommandBuffer(primary)); - TF_ASSERT_OK_AND_ASSIGN( - auto* memcpy, cmd_buffer->MemcpyDeviceToDevice(&b, a, byte_length, {})); + TF_ASSERT_OK_AND_ASSIGN(auto* memcpy, + cmd_buffer->CreateMemcpyD2D(&b, a, byte_length, {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -292,7 +293,7 @@ TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) { // Update command buffer to swap the memcpy direction. TF_ASSERT_OK(cmd_buffer->Update()); - TF_ASSERT_OK(cmd_buffer->MemcpyDeviceToDevice(memcpy, &a, b, byte_length)); + TF_ASSERT_OK(cmd_buffer->UpdateMemcpyD2D(memcpy, &a, b, byte_length)); TF_ASSERT_OK(cmd_buffer->Finalize()); // Clear destination to test that command buffer actually copied memory. @@ -320,8 +321,9 @@ TEST(GpuCommandBufferTest, Memset) { // Create a command buffer with a single memset command. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); - TF_ASSERT_OK_AND_ASSIGN(const CommandBuffer::Command* memset, - cmd_buffer->Memset(&a, uint32_t{42}, length, {})); + TF_ASSERT_OK_AND_ASSIGN( + const CommandBuffer::Command* memset, + cmd_buffer->CreateMemset(&a, uint32_t{42}, length, {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -335,7 +337,7 @@ TEST(GpuCommandBufferTest, Memset) { // Update command buffer to use a new bit pattern. 
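`UpdateMemset` reuses the recorded node and only swaps its parameters. `BitPattern` is constructed implicitly from the integer literal; which widths it accepts beyond the `uint32_t` used here is an assumption about the `BitPattern` type rather than something this patch states:

  // Same handle, new fill value; destination and element count are
  // re-supplied as well.
  TF_ASSERT_OK(cmd_buffer->Update());
  TF_ASSERT_OK(cmd_buffer->UpdateMemset(memset, &a, uint32_t{43}, length));
  TF_ASSERT_OK(cmd_buffer->Finalize());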
TF_ASSERT_OK(cmd_buffer->Update()); - TF_ASSERT_OK(cmd_buffer->Memset(memset, &a, uint32_t{43}, length)); + TF_ASSERT_OK(cmd_buffer->UpdateMemset(memset, &a, uint32_t{43}, length)); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -380,7 +382,7 @@ TEST(GpuCommandBufferTest, ConditionalCaseEmptyGraph) { // if (index == 0) c = a + b CommandBuffer::Builder branch0 = [&](CommandBuffer* branch0_cmd) { - return branch0_cmd->Launch(add, ThreadDim(), BlockDim(4), {}, a, b, c) + return branch0_cmd->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, c) .status(); }; @@ -392,7 +394,7 @@ TEST(GpuCommandBufferTest, ConditionalCaseEmptyGraph) { // Create a command buffer with a single conditional operation. TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer, executor->CreateCommandBuffer(primary)); - TF_ASSERT_OK(cmd_buffer->Case(index, {branch0, branch1}, {})); + TF_ASSERT_OK(cmd_buffer->CreateCase(index, {branch0, branch1}, {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -481,8 +483,8 @@ TEST_P(GpuCommandBufferCaseTest, ConditionalMultiCase) { branches[i] = [&, i](CommandBuffer* branch_cmd) { // result = i * i; return branch_cmd - ->Launch(mul, ThreadDim(), BlockDim(kLength), {}, values[i], - values[i], results[i]) + ->CreateLaunch(mul, ThreadDim(), BlockDim(kLength), {}, values[i], + values[i], results[i]) .status(); }; } @@ -490,7 +492,7 @@ TEST_P(GpuCommandBufferCaseTest, ConditionalMultiCase) { // Create a command buffer with a single conditional operation. TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer, executor->CreateCommandBuffer(primary)); - TF_ASSERT_OK(cmd_buffer->Case(index, branches, {})); + TF_ASSERT_OK(cmd_buffer->CreateCase(index, branches, {})); TF_ASSERT_OK(cmd_buffer->Finalize()); // We test the out of bounds cases as well ( i < 0, i >= kNumCases). @@ -567,19 +569,19 @@ TEST(GpuCommandBufferTest, ConditionalCase) { // if (index == 0) c = a + b CommandBuffer::Builder branch0 = [&](CommandBuffer* branch0_cmd) { - return branch0_cmd->Launch(add, ThreadDim(), BlockDim(4), {}, a, b, c) + return branch0_cmd->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, c) .status(); }; // if (index == 1) c = a * b CommandBuffer::Builder branch1 = [&](CommandBuffer* branch1_cmd) { - return branch1_cmd->Launch(mul, ThreadDim(), BlockDim(4), {}, a, b, c) + return branch1_cmd->CreateLaunch(mul, ThreadDim(), BlockDim(4), {}, a, b, c) .status(); }; // Create a command buffer with a single conditional operation. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); - TF_ASSERT_OK(cmd_buffer->Case(index, {branch0, branch1}, {})); + TF_ASSERT_OK(cmd_buffer->CreateCase(index, {branch0, branch1}, {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -665,20 +667,21 @@ TEST(GpuCommandBufferTest, ConditionalWhile) { // Loop cond: loop_counter++ < num_iters; CommandBuffer::Builder cond_builder = [&](CommandBuffer* cond_cmd) { return cond_cmd - ->Launch(inc_and_cmp, ThreadDim(), BlockDim(), {}, loop_counter, pred, - num_iters) + ->CreateLaunch(inc_and_cmp, ThreadDim(), BlockDim(), {}, loop_counter, + pred, num_iters) .status(); }; // Loop body: b = a + b CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { - return body_cmd->Launch(add, ThreadDim(), BlockDim(length), {}, a, b, b) + return body_cmd + ->CreateLaunch(add, ThreadDim(), BlockDim(length), {}, a, b, b) .status(); }; // Create a command buffer with a single conditional operation. 
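The two builders above split the loop contract: `cond_builder` must refresh `pred` on every pass, and `body_builder` records the work. In host-side pseudocode the recorded While behaves roughly as follows (mirroring the comment in command_buffer.h; whether the backend seeds `pred` with an initial condition pass is an implementation detail this diff does not change):

  // Illustrative host-side analogue of the recorded While command.
  while (pred) {
    run_body();  // b = a + b
    run_cond();  // pred = (++loop_counter < num_iters)
  }

Here `run_body` and `run_cond` stand in for the nested command buffers recorded by the two builders.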
auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); - TF_ASSERT_OK(cmd_buffer->While(pred, cond_builder, body_builder, {})); + TF_ASSERT_OK(cmd_buffer->CreateWhile(pred, cond_builder, body_builder, {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -738,32 +741,34 @@ TEST(GpuCommandBufferTest, DISABLED_WhileNestedConditional) { CommandBuffer::Builder then_builder = // Then body: b = a + b [&](CommandBuffer* then_cmd) { - return then_cmd->Launch(add, ThreadDim(), BlockDim(length), {}, a, b, b) + return then_cmd + ->CreateLaunch(add, ThreadDim(), BlockDim(length), {}, a, b, b) .status(); }; auto nested_cmd = executor->CreateCommandBuffer(nested).value(); // TODO(b/339653343): Adding this Case condition causes AddNestedCommandBuffer // to fail. - TF_ASSERT_OK(nested_cmd->Case(pred_then, {then_builder, then_builder}, {})); + TF_ASSERT_OK( + nested_cmd->CreateCase(pred_then, {then_builder, then_builder}, {})); // Loop cond: loop_counter++ < num_iters; CommandBuffer::Builder cond_builder = [&](CommandBuffer* cond_cmd) { return cond_cmd - ->Launch(inc_and_cmp, ThreadDim(), BlockDim(length), {}, loop_counter, - pred, num_iters) + ->CreateLaunch(inc_and_cmp, ThreadDim(), BlockDim(length), {}, + loop_counter, pred, num_iters) .status(); }; CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) -> absl::Status { - CHECK_OK(body_cmd->AddNestedCommandBuffer(*nested_cmd, {})); + CHECK_OK(body_cmd->CreateNestedCommand(*nested_cmd, {})); return absl::OkStatus(); }; // Create a command buffer with a single conditional operation. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); - TF_ASSERT_OK(cmd_buffer->While(pred, cond_builder, body_builder, {})); + TF_ASSERT_OK(cmd_buffer->CreateWhile(pred, cond_builder, body_builder, {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -798,7 +803,8 @@ static void BM_CreateCommandBuffer(benchmark::State& state) { for (auto s : state) { auto cmd_buffer = executor->CreateCommandBuffer(nested).value(); for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), {}, b, b, b)); + CHECK_OK( + cmd_buffer->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, b, b, b)); } CHECK_OK(cmd_buffer->Finalize()); } @@ -845,14 +851,16 @@ static void BM_UpdateCommandBuffer(benchmark::State& state) { auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), {}, b, b, b)); + CHECK_OK( + cmd_buffer->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, b, b, b)); } CHECK_OK(cmd_buffer->Finalize()); for (auto s : state) { CHECK_OK(cmd_buffer->Update()); for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), {}, b, b, b)); + CHECK_OK( + cmd_buffer->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, b, b, b)); } CHECK_OK(cmd_buffer->Finalize()); } diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc index a989c2f5797ca3..27c4e5d0572046 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc @@ -491,29 +491,4 @@ absl::Status RocmCommandBuffer::CheckCanBeUpdated() { return absl::OkStatus(); } -absl::StatusOr> -RocmCommandBuffer::GetNodeDependencies(const GraphNodeHandle node) { - VLOG(2) << "Get HIP graph 
node " << node << " dependencies"; - - std::vector dependencies; - - size_t num_dependencies = 0; - TF_RETURN_IF_ERROR( - ToStatus(hipGraphNodeGetDependencies(ToHipGraphHandle(node), nullptr, - &num_dependencies), - "Failed to get HIP graph node depedencies size")); - - dependencies.resize(num_dependencies, nullptr); - TF_RETURN_IF_ERROR(ToStatus( - hipGraphNodeGetDependencies(ToHipGraphHandle(node), dependencies.data(), - &num_dependencies), - "Failed to get HIP graph node depedencies")); - - std::vector result; - result.reserve(dependencies.size()); - absl::c_transform( - dependencies, std::back_inserter(result), - static_cast(&FromHipGraphHandle)); - return result; -} } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h index 2d58554b1f8413..1ac7412a8b8955 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h @@ -161,9 +161,6 @@ class RocmCommandBuffer : public GpuCommandBuffer { absl::Status CheckCanBeUpdated() override; - absl::StatusOr> GetNodeDependencies( - GraphNodeHandle node) override; - static_assert(std::is_pointer_v, "hipGraph_t must be a pointer"); static_assert(std::is_pointer_v, "hipGraphExec_t must be a pointer"); From 1ae55ba00c49cdfe7300f592530cb01c0f2b4ca5 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Fri, 11 Apr 2025 15:07:46 -0700 Subject: [PATCH 0598/1324] Direct HLO->StableHLO for communication ops with tokens. Allow mhlo::AddDependencyOp to take either a mhlo::TokenType or a StableHLO token type. PiperOrigin-RevId: 746604336 --- .../compiler/mlir/tf2xla/transforms/BUILD | 6 +- .../mlir/tf2xla/transforms/tf2xla_rewriter.cc | 25 +- .../mlir/tf2xla/transforms/tf2xla_rewriter.h | 10 +- .../tf2xla/transforms/tf2xla_rewriter_test.cc | 17 +- .../xla/xla/hlo/translate/hlo_to_mhlo/BUILD | 3 +- .../translate/hlo_to_mhlo/async_importer.cc | 70 +++-- .../translate/hlo_to_mhlo/async_importer.h | 3 +- .../hlo_to_mhlo/attribute_importer.cc | 81 ++++-- .../hlo_to_mhlo/attribute_importer.h | 15 + .../hlo_to_mhlo/hlo_function_importer.cc | 180 ++++++------ .../xla/hlo/translate/hlo_to_mhlo/hlo_utils.h | 3 +- .../tests/import_emit_stablehlo.hlo | 270 +++++++++--------- third_party/xla/xla/mlir_hlo/BUILD | 2 +- .../xla/xla/mlir_hlo/mhlo/IR/hlo_base.td | 4 + .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 1 + .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td | 4 +- .../hlo_legalize_to_stablehlo_pass.cc | 45 ++- .../stablehlo_legalize_to_hlo_pass.cc | 30 ++ .../hlo-legalize-to-stablehlo-partial.mlir | 16 ++ .../mhlo/stablehlo-legalize-to-hlo.mlir | 16 ++ 20 files changed, 488 insertions(+), 313 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 87af60a1cfabe3..1da9a071a3c0f3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -324,7 +324,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/framework:allocator", "//tensorflow/core/protobuf:for_core_protos_cc", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -352,6 +351,7 @@ cc_library( "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:status", "@local_xla//xla/tsl/platform:statusor", + "@stablehlo//:stablehlo_ops", ], ) @@ -366,9 +366,7 @@ 
tf_cc_test( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:framework", "//tensorflow/core:ops", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", @@ -381,11 +379,11 @@ tf_cc_test( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder:xla_computation", - "@local_xla//xla/mlir_hlo", "@local_xla//xla/tsl/lib/core:status_test_util", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:status", "@local_xla//xla/tsl/platform:statusor", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index 161ae934df7d05..f35af77a1e6082 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -50,6 +50,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" @@ -154,7 +155,7 @@ Tf2XlaRewriter::~Tf2XlaRewriter() { if (context_) context_->Unref(); } -absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( +absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( XlaComputation& computation) { xla::DebugOptions debug_options; TF_ASSIGN_OR_RETURN(auto hlo_module_config, @@ -193,8 +194,8 @@ absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( xla::HloFunctionImporter::ImportInstructions( *hlo_module->entry_computation(), arguments, symbol_table, &builder)); - mhlo::TupleOp root_tuple = - mlir::dyn_cast_or_null(root_value.getDefiningOp()); + stablehlo::TupleOp root_tuple = + mlir::dyn_cast_or_null(root_value.getDefiningOp()); if (!root_tuple) { return tsl::errors::InvalidArgument( "Imported XLA Root Value is not a tuple op"); @@ -410,23 +411,23 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { if (failed(VerifyOpResults(op_context))) return failure(); - absl::StatusOr tuple_result_or_status = + absl::StatusOr tuple_result_or_status = CompileWithHloImporter(op_context); if (!tuple_result_or_status.ok()) { return op_->emitRemark() << tuple_result_or_status.status().ToString(); } - mhlo::TupleOp tuple_result = tuple_result_or_status.value(); + stablehlo::TupleOp tuple_result = tuple_result_or_status.value(); - llvm::SmallVector output_values; - if (failed(GetKernelOutputs(op_context, tuple_result, output_values))) { - return failure(); - } + llvm::SmallVector output_values; + if (failed(GetKernelOutputs(op_context, tuple_result, output_values))) { + return failure(); + } rewriter_.replaceOp(op_, output_values); return success(); } -absl::StatusOr Tf2XlaRewriter::CompileWithHloImporter( +absl::StatusOr Tf2XlaRewriter::CompileWithHloImporter( tensorflow::OpKernelContext& op_context) { // XLA can only return a single value. Wrap all output op return values // in a Tuple op that gets unpacked later. @@ -470,7 +471,7 @@ mlir::LogicalResult Tf2XlaRewriter::VerifyOpResults( // multiple values. 
We get around this by returning a tuple as an XLA op. We // then unpack it here to return the multiple values instead. mlir::LogicalResult Tf2XlaRewriter::UnpackTupleResults( - mhlo::TupleOp tuple_result, llvm::SmallVector& outputs) { + stablehlo::TupleOp tuple_result, llvm::SmallVector& outputs) { if (tuple_result->getNumOperands() != op_->getNumResults()) { return op_->emitRemark() << "Translated TF2XLA tuple has different " "number of results than original op"; @@ -485,7 +486,7 @@ mlir::LogicalResult Tf2XlaRewriter::UnpackTupleResults( } mlir::LogicalResult Tf2XlaRewriter::GetKernelOutputs( - tensorflow::OpKernelContext& op_context, mhlo::TupleOp tuple_results, + tensorflow::OpKernelContext& op_context, stablehlo::TupleOp tuple_results, llvm::SmallVector& outputs) { outputs.reserve(op_->getNumResults()); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h index c5c417e27ba022..c89316638a2ea5 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h @@ -28,12 +28,12 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/op_kernel.h" @@ -58,12 +58,12 @@ class Tf2XlaRewriter { // Compiles the given Operation with XlaBuilder and imports the generated HLO // via the HLO -> MHLO importer. - absl::StatusOr CompileWithHloImporter( + absl::StatusOr CompileWithHloImporter( tensorflow::OpKernelContext& op_context); // Import the given XlaComputation into the parent module. Returns the given // generated function. - absl::StatusOr ImportXlaComputation( + absl::StatusOr ImportXlaComputation( xla::XlaComputation& computation); // Prepares OpKernelContext params common to all the ops. @@ -83,12 +83,12 @@ class Tf2XlaRewriter { mlir::LogicalResult VerifyOpResults(tensorflow::OpKernelContext& op_context); mlir::LogicalResult GetKernelOutputs(tensorflow::OpKernelContext& op_context, - mhlo::TupleOp tuple_results, + stablehlo::TupleOp tuple_results, llvm::SmallVector& outputs); // Given a translated function with a single return value, unpack the tuple // results. - mlir::LogicalResult UnpackTupleResults(mhlo::TupleOp tuple_result, + mlir::LogicalResult UnpackTupleResults(stablehlo::TupleOp tuple_result, llvm::SmallVector& outputs); // Tries to legalize the specified TensorFlow op, if supported. diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc index eaad485ccab96a..e20be6bb9a173c 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -33,20 +34,19 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tensorflow/core/framework/op_kernel.h" namespace mlir { namespace mhlo { @@ -102,7 +102,7 @@ class Tf2XlaRewriterTestPeer { tf2xla_rewriter_(op, empty_rewriter_, /*device_type=*/"XLA_CPU_JIT") {} - absl::StatusOr ImportXlaComputationIntoModule( + absl::StatusOr ImportXlaComputationIntoModule( XlaComputation& computation) { return tf2xla_rewriter_.ImportXlaComputation(computation); } @@ -184,7 +184,7 @@ class Tf2XlaRewriterTest : public ::testing::Test { return main_func.getBody().front().front(); } - absl::StatusOr ImportXlaComputationIntoModule( + absl::StatusOr ImportXlaComputationIntoModule( XlaComputation& computation) { SourceMgrDiagnosticHandler sourceMgrHandler(source_manager_, &context_); @@ -204,7 +204,8 @@ TEST_F(Tf2XlaRewriterTest, LegalizesOpWithTf2xlaHloImporter) { TF_EXPECT_OK(LegalizeModule()); int num_tuple_ops = 0; - module_->walk([&num_tuple_ops](TupleOp tuple_op) { num_tuple_ops += 1; }); + module_->walk( + [&num_tuple_ops](stablehlo::TupleOp tuple_op) { num_tuple_ops += 1; }); EXPECT_EQ(num_tuple_ops, 0); } @@ -214,7 +215,7 @@ TEST_F(Tf2XlaRewriterTest, ImportsXlaComputationIntoModule) { XlaComputation computation = GetTestXlaComputation(); - TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple, + TF_ASSERT_OK_AND_ASSIGN(stablehlo::TupleOp root_tuple, ImportXlaComputationIntoModule(computation)); ModuleOp parent_module = @@ -261,7 +262,7 @@ TEST_F(Tf2XlaRewriterTest, ImportsSingleComputation) { EXPECT_EQ(computation.proto().computations_size(), 2); TF_ASSERT_OK(CreateMlirModule()); - TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple, + TF_ASSERT_OK_AND_ASSIGN(stablehlo::TupleOp root_tuple, ImportXlaComputationIntoModule(computation)); EXPECT_TRUE(root_tuple); @@ -356,7 +357,7 @@ TEST_F(Tf2XlaRewriterTest, ErrorsWithInvalidNumberOfParametersToArgs) { EXPECT_EQ(computation.proto().computations_size(), 2); TF_ASSERT_OK(CreateMlirModule()); - absl::StatusOr status_or_tuple_op = + absl::StatusOr status_or_tuple_op = ImportXlaComputationIntoModule(computation); EXPECT_FALSE(status_or_tuple_op.ok()); } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD index 8b5a77f674036f..9b1d5d88306774 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD +++ 
b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -21,6 +21,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/mlir_hlo", "//xla/service:hlo_proto_cc", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -37,7 +38,6 @@ cc_library( deps = [ ":attribute_importer", ":hlo_utils", - "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -49,6 +49,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc index 4cc344f80a16cd..bbd49177867e74 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -107,10 +108,10 @@ absl::StatusOr ImportOldStyleAsyncStart( // Attach the frontend_attributes and sharding attributes to the async op // instead of the sync op. First, semantically sharding attributes cannot be // attached to the sync op since the sync op may not produce the same number - // of results as the sharding's tuple element count, e.g., `mhlo.send` vs. HLO - // `send`. Second, `mlir_hlo_to_hlo.cc` imports these attributes from the - // `mhlo.async_start` ops, so attaching them to the sync op will make them - // disappear during MHLO to HLO lowering. + // of results as the sharding's tuple element count, e.g., `stablehlo.send` + // vs. HLO `send`. Second, `mlir_hlo_to_hlo.cc` imports these attributes from + // the `mhlo.async_start` ops, so attaching them to the sync op will make them + // disappear during StableHLO/MHLO to HLO lowering. for (auto it = attributes.begin(); it != attributes.end();) { if (it->getName() == kShardingAttr || it->getName() == kFrontendAttributesAttr) { @@ -195,7 +196,8 @@ absl::StatusOr ImportSend( channel_handle.set_type(send_op->is_host_transfer() ? ChannelHandle::DEVICE_TO_HOST : ChannelHandle::DEVICE_TO_DEVICE); - attributes.push_back(ConvertChannelHandle(channel_handle, builder)); + attributes.push_back( + stablehlo::ConvertChannelHandle(channel_handle, builder)); } bool isPipelined = @@ -219,22 +221,22 @@ absl::StatusOr ImportSend( auto async_bundled_tuple = mlir::TupleType::get( builder->getContext(), {async_arg_type, result_types[2], result_types[1]}); - return ImportOldStyleAsyncStart( + return ImportOldStyleAsyncStart( symbol_table, attributes, operands, loc, async_bundled_tuple, builder, "send_", [](auto) { return absl::OkStatus(); }); } // Otherwise return send op for non-pipelined send. 
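With the direct import, a non-pipelined transfer lowers straight to `stablehlo.send` rather than `mhlo.send`. For a token-only transfer the imported IR looks roughly like the following (illustrative shape; the attribute values are hypothetical):

  // %out = "stablehlo.send"(%token) {
  //   channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  //   is_host_transfer = true
  // } : (!stablehlo.token) -> !stablehlo.token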
- // Skip empty data in MLIR send(tuple<>, token) --> mhlo.send(token) + // Skip empty data in MLIR send(tuple<>, token) --> stablehlo.send(token) auto token = operands[1]; llvm::ArrayRef args = operands; if (args.size() == 2 && IsEmptyTuple(args[0].getType())) { args = args.drop_front(1); } - auto send = - builder - ->create(loc, token.getType(), args, attributes) - .getOperation(); + auto send = builder + ->create(loc, token.getType(), args, + attributes) + .getOperation(); if (instruction->has_sharding()) { const HloSharding& sharding = instruction->sharding(); if (sharding.IsTuple() && sharding.tuple_elements().size() == 3) { @@ -268,7 +270,8 @@ absl::StatusOr ImportRecv( channel_handle.set_type(recv_op->is_host_transfer() ? ChannelHandle::HOST_TO_DEVICE : ChannelHandle::DEVICE_TO_DEVICE); - attributes.push_back(ConvertChannelHandle(channel_handle, builder)); + attributes.push_back( + stablehlo::ConvertChannelHandle(channel_handle, builder)); } // Currently only consolidates async recv with result, 0-result recv uses old @@ -294,14 +297,14 @@ absl::StatusOr ImportRecv( auto async_result_type_tuple = builder->getTupleType(async_result_types); auto async_bundled_tuple = builder->getTupleType( {result_types[2], async_result_type_tuple, result_types[1]}); - return ImportOldStyleAsyncStart( + return ImportOldStyleAsyncStart( symbol_table, attributes, operands, loc, async_bundled_tuple, builder, "recv_", [](auto) { return absl::OkStatus(); }); } // Return recv op for non-pipelined send, skip empty tuple result type if (!IsEmptyTuple(result_types[0])) { - auto recv = builder->create( + auto recv = builder->create( loc, llvm::SmallVector{result_types[0], result_types[2]}, operands, attributes); if (instruction->has_sharding()) { @@ -324,13 +327,13 @@ absl::StatusOr ImportRecv( // Recv with no result, only token. // To keep parity, if op only returns token, wrap in tuple, token> - auto recv = builder->create( + auto recv = builder->create( loc, llvm::SmallVector{result_types[2]}, operands, attributes); - auto empty_tuple = - builder->create(loc, llvm::ArrayRef{}); + auto empty_tuple = builder->create( + loc, llvm::ArrayRef{}); - return builder->create( + return builder->create( loc, llvm::ArrayRef{empty_tuple.getResult(), recv.getResult(0)}); } @@ -350,22 +353,24 @@ absl::StatusOr ImportAllGatherStart( attributes.push_back( ConvertReplicaGroups(all_gather_start->replica_groups(), builder)); if (all_gather_start->channel_id().has_value()) - attributes.push_back( - ConvertChannelHandle(all_gather_start->channel_id().value(), builder)); + attributes.push_back(stablehlo::ConvertChannelHandle( + all_gather_start->channel_id().value(), builder)); if (all_gather_start->use_global_device_ids()) attributes.push_back(ConvertUseGlobalDeviceIds(builder)); if (all_gather_start->operands().size() > 1) - return InvalidArgument("Async tuple all-gather is not supported in MHLO"); + return InvalidArgument( + "Async tuple all-gather is not supported in StableHLO"); if (!llvm::isa(result_type)) { // Async AllGather's output type is bundle // There are some instances where the output type is not a tuple, this seems - // to be the more modern case, so we will wrap these in a tuple for MHLO. + // to be the more modern case, so we will wrap these in a tuple for + // StableHLO. 
result_type = mlir::TupleType::get(builder->getContext(), {operands[0].getType(), result_type}); } - return ImportOldStyleAsyncStart( + return ImportOldStyleAsyncStart( symbol_table, attributes, operands, loc, result_type, builder, "all_gather_", [](auto) { return absl::OkStatus(); }); } @@ -375,28 +380,30 @@ absl::StatusOr ImportAllReduceStart( const llvm::SmallVectorImpl& operands, llvm::SmallVectorImpl& attributes, mlir::Type result_type, mlir::OpBuilder* builder, - std::function mutate_op, + std::function mutate_op, mlir::SymbolTable& symbol_table) { auto all_reduce_start = Cast(instruction); attributes.push_back( ConvertReplicaGroups(all_reduce_start->replica_groups(), builder)); if (all_reduce_start->channel_id().has_value()) - attributes.push_back( - ConvertChannelHandle(all_reduce_start->channel_id().value(), builder)); + attributes.push_back(stablehlo::ConvertChannelHandle( + all_reduce_start->channel_id().value(), builder)); if (all_reduce_start->use_global_device_ids()) attributes.push_back(ConvertUseGlobalDeviceIds(builder)); if (all_reduce_start->operands().size() > 1) - return InvalidArgument("Async tuple all-reduce is not supported in MHLO"); + return InvalidArgument( + "Async tuple all-reduce is not supported in StableHLO"); if (!llvm::isa(result_type)) { // Async AllReduce's output type is bundle // There are some instances where the output type is not a tuple, this seems - // to be the more modern case, so we will wrap these in a tuple for MHLO. + // to be the more modern case, so we will wrap these in a tuple for + // StableHLO. result_type = mlir::TupleType::get(builder->getContext(), {operands[0].getType(), result_type}); } - return ImportOldStyleAsyncStart( + return ImportOldStyleAsyncStart( symbol_table, attributes, operands, loc, result_type, builder, "all_reduce_", mutate_op); } @@ -414,11 +421,12 @@ absl::StatusOr ImportCollectivePermuteStart( if (!llvm::isa(result_type)) { // Async CollectivePermute's output type is bundle // There are some instances where the output type is not a tuple, this seems - // to be the more modern case, so we will wrap these in a tuple for MHLO. + // to be the more modern case, so we will wrap these in a tuple for + // StableHLO. result_type = mlir::TupleType::get(builder->getContext(), {operands[0].getType(), result_type}); } - return ImportOldStyleAsyncStart( + return ImportOldStyleAsyncStart( symbol_table, attributes, operands, loc, result_type, builder, "collective_permute_", [&](auto) { return absl::OkStatus(); }); } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h index 116d17f86c7bc0..30fc9fbd125e42 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -62,7 +63,7 @@ absl::StatusOr ImportAllReduceStart( const llvm::SmallVectorImpl& operands, llvm::SmallVectorImpl& attributes, mlir::Type result_type, mlir::OpBuilder* builder, - std::function mutate_op, + std::function mutate_op, mlir::SymbolTable& symbol_table); absl::StatusOr ImportCollectivePermuteStart( diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc index 9f449a64c5c794..cf411420bbac99 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -87,6 +88,43 @@ mlir::stablehlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers( scatter_dims_to_operand_dims, dnums.index_vector_dim()); } +mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel, + mlir::Builder* builder) { + return builder->getNamedAttr( + "channel_handle", + mlir::stablehlo::ChannelHandleAttr::get( + builder->getContext(), channel.handle(), channel.type())); +} + +mlir::NamedAttribute ConvertChannelHandle(std::optional channel_id, + mlir::Builder* builder) { + ChannelHandle channel_handle; + if (channel_id) channel_handle.set_handle(*channel_id); + return stablehlo::ConvertChannelHandle(channel_handle, builder); +} + +absl::StatusOr +ConvertCustomCallApiVersion(xla::CustomCallApiVersion api_version) { + switch (api_version) { + case xla::CustomCallApiVersion::API_VERSION_UNSPECIFIED: + return mlir::stablehlo::CustomCallApiVersion::API_VERSION_UNSPECIFIED; + case xla::CustomCallApiVersion::API_VERSION_ORIGINAL: + return mlir::stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL; + case xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + return mlir::stablehlo::CustomCallApiVersion:: + API_VERSION_STATUS_RETURNING; + case xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + return mlir::stablehlo::CustomCallApiVersion:: + API_VERSION_STATUS_RETURNING_UNIFIED; + case xla::CustomCallApiVersion::API_VERSION_TYPED_FFI: + return mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI; + default: + return InvalidArgument("Unknown CustomCallApiVersion enum value #%d (%s)", + api_version, + xla::CustomCallApiVersion_Name(api_version)); + } +} + mlir::stablehlo::DotAlgorithmAttr ConvertDotAlgorithm( const PrecisionConfig::Algorithm algorithm, mlir::Builder* builder) { mlir::Type lhs, rhs, accum; @@ -175,6 +213,23 @@ mlir::stablehlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers( arrayref(dnums.rhs_contracting_dimensions())); } +mlir::ArrayAttr ConvertOutputOperandAliasing( + const std::vector>>& aliasInfo, + mlir::Builder* builder) { + auto arrayref = [](absl::Span array) { + return llvm::ArrayRef{array.data(), array.size()}; + }; + std::vector attrs; + for (auto& [output_tuple_idx, operand_idx] : aliasInfo) { + auto attr = mlir::stablehlo::OutputOperandAliasAttr::get( + builder->getContext(), arrayref(output_tuple_idx), operand_idx.first, + arrayref(operand_idx.second)); + attrs.push_back(attr); + } + return 
+}
+
 mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
                                        mlir::Builder* builder) {
   if (!config) return {};
@@ -439,23 +494,14 @@ absl::StatusOr<mlir::mhlo::Transpose> ConvertTranspose(

 absl::StatusOr<mlir::mhlo::CustomCallApiVersion> ConvertCustomCallApiVersion(
     xla::CustomCallApiVersion api_version) {
-  switch (api_version) {
-    case xla::CustomCallApiVersion::API_VERSION_UNSPECIFIED:
-      return mlir::mhlo::CustomCallApiVersion::API_VERSION_UNSPECIFIED;
-    case xla::CustomCallApiVersion::API_VERSION_ORIGINAL:
-      return mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL;
-    case xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
-      return mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING;
-    case xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
-      return mlir::mhlo::CustomCallApiVersion::
-          API_VERSION_STATUS_RETURNING_UNIFIED;
-    case xla::CustomCallApiVersion::API_VERSION_TYPED_FFI:
-      return mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI;
-    default:
-      return InvalidArgument("Unknown CustomCallApiVersion enum value #%d (%s)",
-                             api_version,
-                             xla::CustomCallApiVersion_Name(api_version));
-  }
+  TF_ASSIGN_OR_RETURN(auto stablehlo_api_version,
+                      stablehlo::ConvertCustomCallApiVersion(api_version));
+  auto mhlo_api_version = mlir::mhlo::symbolizeCustomCallApiVersion(
+      mlir::stablehlo::stringifyCustomCallApiVersion(stablehlo_api_version));
+  if (!mhlo_api_version.has_value())
+    return InvalidArgument("Unknown CustomCallApiVersion enum value #%d",
+                           api_version);
+  return mhlo_api_version.value();
 }

 mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel,
@@ -465,6 +511,7 @@ mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel,
       mlir::mhlo::ChannelHandleAttr::get(builder->getContext(),
                                          channel.handle(), channel.type()));
 }
+
 mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id,
                                           mlir::Builder* builder) {
   ChannelHandle channel_handle;
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h
index 6cace3ee16a72c..eb857cf0053528 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h
@@ -35,6 +35,15 @@ limitations under the License.
 namespace xla {
 namespace stablehlo {

+// Converts the channel handle to attributes.
+mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel,
+                                          mlir::Builder* builder);
+mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id,
+                                          mlir::Builder* builder);
+
+absl::StatusOr<mlir::stablehlo::CustomCallApiVersion>
+ConvertCustomCallApiVersion(xla::CustomCallApiVersion api_version);
+
 // Converts the gather dimensions to attributes.
 mlir::stablehlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers(
     const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder);
@@ -51,6 +60,12 @@ mlir::stablehlo::DotAlgorithmAttr ConvertDotAlgorithm(
 mlir::stablehlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers(
     const DotDimensionNumbers& dnums, mlir::Builder* builder);

+// Converts the output operand aliasing to attributes.
+mlir::ArrayAttr ConvertOutputOperandAliasing(
+    const std::vector<std::pair<xla::ShapeIndex,
+                                std::pair<int64_t, xla::ShapeIndex>>>&
+        aliasInfo,
+    mlir::Builder* builder);
+
 // Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
 mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
                                        mlir::Builder* builder);
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
index a362913ba730e8..8b35075984b148 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -1109,12 +1109,11 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       attributes.push_back(ConvertReplicaGroups(
           collective_broadcast->replica_groups(), builder_));
       if (collective_broadcast->channel_id().has_value())
-        attributes.push_back(ConvertChannelHandle(
+        attributes.push_back(stablehlo::ConvertChannelHandle(
            collective_broadcast->channel_id().value(), builder_));
-      // TODO(b/408024772) ToStablehlo: ConvertChannelHandle
       return func_builder
-          ->create<mlir::mhlo::CollectiveBroadcastOp>(loc, result_type, operands, attributes)
+          ->create<mlir::stablehlo::CollectiveBroadcastOp>(loc, result_type,
+                                                           operands, attributes)
           .getOperation();
     }

@@ -1123,12 +1122,11 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       attributes.push_back(ConvertSourceTargetPairs(
           collective_permute->source_target_pairs(), builder_));
       if (collective_permute->channel_id().has_value())
-        attributes.push_back(ConvertChannelHandle(
+        attributes.push_back(stablehlo::ConvertChannelHandle(
            collective_permute->channel_id().value(), builder_));
-      // TODO(b/408024772) ToStablehlo: ConvertChannelHandle
       return func_builder
-          ->create<mlir::mhlo::CollectivePermuteOp>(loc, result_type, operands,
-                                                    attributes)
+          ->create<mlir::stablehlo::CollectivePermuteOp>(loc, result_type,
+                                                         operands, attributes)
           .getOperation();
     }
     case HloOpcode::kCollectivePermuteStart: {
@@ -1181,11 +1179,6 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
             builder_->getNamedAttr("result_layouts", result_layouts));
       }

-      attributes.push_back(
-          ConvertCustomCallSchedule(custom_call->custom_call_schedule()));
-      TF_ASSIGN_OR_RETURN(
-          auto mlir_api_version,
-          ConvertCustomCallApiVersion(custom_call->api_version()));
       attributes.push_back(builder_->getNamedAttr(
           "call_target_name",
           builder_->getStringAttr(custom_call->custom_call_target())));
@@ -1224,19 +1217,42 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
             builder_->getNamedAttr("mhlo.literal", attr.value()));
       }

+      // StableHLO CustomCall doesn't have a schedule attribute
+      if (custom_call->custom_call_schedule() !=
+          CustomCallSchedule::SCHEDULE_NONE) {
+        attributes.push_back(
+            ConvertCustomCallSchedule(custom_call->custom_call_schedule()));
+        TF_ASSIGN_OR_RETURN(
+            auto mlir_api_version,
+            ConvertCustomCallApiVersion(custom_call->api_version()));
+        attributes.push_back(builder_->getNamedAttr(
+            "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
+                               builder_->getContext(), mlir_api_version)));
+        attributes.push_back(builder_->getNamedAttr(
+            "output_operand_aliases",
+            ConvertOutputOperandAliasing(instruction->output_operand_aliasing(),
+                                         builder_)));
+        // XLA Feature - MHLO Only
+        return func_builder
+            ->create<mlir::mhlo::CustomCallOp>(loc, result_type, operands,
+                                               attributes)
+            .getOperation();
+      }
+
+      // Valid StableHLO CustomCall
+      TF_ASSIGN_OR_RETURN(
+          auto mlir_api_version,
+          stablehlo::ConvertCustomCallApiVersion(custom_call->api_version()));
       attributes.push_back(builder_->getNamedAttr(
-          "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
+          "api_version", mlir::stablehlo::CustomCallApiVersionAttr::get(
                              builder_->getContext(), mlir_api_version)));
       attributes.push_back(builder_->getNamedAttr(
           "output_operand_aliases",
-          ConvertOutputOperandAliasing(instruction->output_operand_aliasing(),
-                                       builder_)));
-
-      // TODO(b/408024772) ToStablehlo: Special handling needed for CC schedules
-      // that aren't NONE and convert CC attrs to StableHLO
+          stablehlo::ConvertOutputOperandAliasing(
+              instruction->output_operand_aliasing(), builder_)));
       return func_builder
-          ->create<mlir::mhlo::CustomCallOp>(loc, result_type, operands,
-                                             attributes)
+          ->create<mlir::stablehlo::CustomCallOp>(loc, result_type, operands,
+                                                  attributes)
           .getOperation();
     }
     case HloOpcode::kCompare: {
@@ -1319,8 +1335,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       llvm::SmallVector<mlir::Type> flattened_ret_types;
       FlattenTupleType(result_type, flattened_ret_types);

-      // TODO(b/408024772) ToStableHLO: result_type needs to be StableHLO Token
-      auto op = func_builder->create<mlir::mhlo::InfeedOp>(
+      auto op = func_builder->create<mlir::stablehlo::InfeedOp>(
           loc, flattened_ret_types, operands, attributes);

       return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
@@ -1339,8 +1354,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
       flattened_operands.push_back(operands[1]);

-      // TODO(b/408024772) ToStableHLO: result_type needs to be StableHLO Token
-      auto op = func_builder->create<mlir::mhlo::OutfeedOp>(
+      auto op = func_builder->create<mlir::stablehlo::OutfeedOp>(
           loc, result_type, flattened_operands, attributes);

       return op.getOperation();
@@ -1578,12 +1592,11 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       attributes.push_back(
           ConvertReplicaGroups(all_gather->replica_groups(), builder_));
       if (all_gather->channel_id().has_value())
-        attributes.push_back(
-            ConvertChannelHandle(all_gather->channel_id().value(), builder_));
+        attributes.push_back(stablehlo::ConvertChannelHandle(
+            all_gather->channel_id().value(), builder_));
       if (all_gather->use_global_device_ids())
         attributes.push_back(ConvertUseGlobalDeviceIds(builder_));
-      // TODO(b/408024772) ToStablehlo: ConvertChannelHandleToStablehlo
-      auto all_gather_op = func_builder->create<mlir::mhlo::AllGatherOp>(
+      auto all_gather_op = func_builder->create<mlir::stablehlo::AllGatherOp>(
           loc, result_types, operands, attributes);
       if (result_tuple_ty) {
         return WrapInTuple(func_builder, all_gather_op);
       }
@@ -1610,12 +1623,11 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       attributes.push_back(
           ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
       if (all_reduce->channel_id().has_value())
-        attributes.push_back(
-            ConvertChannelHandle(all_reduce->channel_id().value(), builder_));
+        attributes.push_back(stablehlo::ConvertChannelHandle(
+            all_reduce->channel_id().value(), builder_));
       if (all_reduce->use_global_device_ids())
         attributes.push_back(ConvertUseGlobalDeviceIds(builder_));
-      // TODO(b/408024772) ToStablehlo: ConvertChannelHandleToStablehlo
-      auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
+      auto all_reduce_op = func_builder->create<mlir::stablehlo::AllReduceOp>(
           loc, result_types, operands, attributes);
       TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
                                         &all_reduce_op.getComputation()));
@@ -1625,9 +1637,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       return all_reduce_op.getOperation();
     }
     case HloOpcode::kAllReduceStart: {
-      // TODO(b/408024772) ToStablehlo: Special handling needed for
-      // AllReduceStart.
-      auto appendRegion = [&](mlir::mhlo::AllReduceOp all_reduce_sync) {
+      auto appendRegion = [&](mlir::stablehlo::AllReduceOp all_reduce_sync) {
         TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
                                           &all_reduce_sync.getComputation()));
         return absl::OkStatus();
       };
@@ -1638,62 +1648,64 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
                                  symbol_table_);
     }
     case HloOpcode::kAllReduceDone: {
-      // TODO(b/408024772) ToStablehlo: Special handling needed for
-      // AllReduceStart.
       return ImportAsyncOpDone(instruction, loc, operands, attributes,
                                result_type, func_builder);
     }
     case HloOpcode::kAllToAll: {
       auto all_to_all = Cast<HloAllToAllInstruction>(instruction);
-      auto result_tuple_ty = result_type.dyn_cast<mlir::TupleType>();
+
+      auto replica_groups_attr = llvm::cast<mlir::DenseIntElementsAttr>(
+          ConvertReplicaGroups(all_to_all->replica_groups(), builder_)
+              .getValue());

       // Check invariants of array all-to-all. This is a sanity check and is
       // verified by the HLO verifier.
+      auto result_tuple_ty = llvm::dyn_cast<mlir::TupleType>(result_type);
       if (result_tuple_ty) {
+        // Tuple all-to-all is MHLO only for now.
         if (all_to_all->split_dimension().has_value()) {
           return InvalidArgument(
               "Tuple all-to-all should not have a split dimension");
         }
-      } else {
-        if (!all_to_all->split_dimension().has_value() ||
-            operands.size() != 1 || all_to_all->replica_groups().empty()) {
-          return InvalidArgument(
-              "Array all-to-all should have a split dimension, one operand and "
-              "non-empty replica groups");
+        llvm::SmallVector<mlir::Type, 4> return_types =
+            llvm::to_vector<4>(result_tuple_ty.getTypes());
+        mlir::mhlo::ChannelHandleAttr channel_handle_attr{};
+        if (all_to_all->channel_id().has_value()) {
+          channel_handle_attr = llvm::cast<mlir::mhlo::ChannelHandleAttr>(
+              ConvertChannelHandle(all_to_all->channel_id().value(), builder_)
+                  .getValue());
         }
-      }
-
-      auto replica_groups_attr =
-          ConvertReplicaGroups(all_to_all->replica_groups(), builder_)
-              .getValue()
-              .cast<mlir::DenseIntElementsAttr>();
-      llvm::SmallVector<mlir::Type, 4> return_types = {result_type};
-      if (result_tuple_ty) {
-        return_types = llvm::to_vector<4>(result_tuple_ty.getTypes());
+        // TODO(b/408024772) Fix StableHLO AllToAll to support tuple types the
+        // way that XLA expects it. Currently it is mis-designed.
+        // XLA Feature -- MHLO Only
+        auto result = func_builder->create<mlir::mhlo::AllToAllOp>(
+            loc, return_types, operands, nullptr, nullptr, nullptr,
+            replica_groups_attr, channel_handle_attr);
+        return WrapInTuple(func_builder, result);
+      }
+      // Array AllToAll
+      if (!all_to_all->split_dimension().has_value() || operands.size() != 1 ||
+          all_to_all->replica_groups().empty()) {
+        return InvalidArgument(
+            "Array all-to-all should have a split dimension, one operand and "
+            "non-empty replica groups");
       }
-      // TODO(b/408024772) ToStablehlo: mhlo AllToAll and StableHLO AllToAll are
-      // different currently for multiple arguments, we need to fix this.
-      // Additionally ConvertChannelHandle changes are needed.
-      auto result = func_builder->create<mlir::mhlo::AllToAllOp>(
-          loc, return_types, operands, nullptr, nullptr, nullptr,
-          replica_groups_attr);
-
+      mlir::stablehlo::ChannelHandleAttr channel_handle_attr{};
       if (all_to_all->channel_id().has_value()) {
-        auto handle =
-            ConvertChannelHandle(all_to_all->channel_id().value(), builder_);
-        result.setChannelHandleAttr(
-            handle.getValue().cast<mlir::mhlo::ChannelHandleAttr>());
+        channel_handle_attr = llvm::cast<mlir::stablehlo::ChannelHandleAttr>(
+            stablehlo::ConvertChannelHandle(all_to_all->channel_id().value(),
+                                            builder_)
+                .getValue());
       }
-      if (result_tuple_ty) {
-        return WrapInTuple(func_builder, result);
-      }
+      auto result = func_builder->create<mlir::stablehlo::AllToAllOp>(
+          loc, result_type, operands[0], all_to_all->split_dimension().value(),
+          all_to_all->split_dimension().value(),
+          all_to_all->replica_groups()[0].replica_ids_size(),
+          replica_groups_attr, channel_handle_attr);

-      result.setSplitDimension(all_to_all->split_dimension().value());
-      result.setConcatDimension(all_to_all->split_dimension().value());
-      result.setSplitCount(all_to_all->replica_groups()[0].replica_ids_size());
       return result.getOperation();
     }
     case HloOpcode::kReduce: {
@@ -1879,13 +1891,12 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       attributes.push_back(
           ConvertReplicaGroups(reduce_scatter->replica_groups(), builder_));
       if (reduce_scatter->channel_id().has_value())
-        attributes.push_back(ConvertChannelHandle(
+        attributes.push_back(stablehlo::ConvertChannelHandle(
             reduce_scatter->channel_id().value(), builder_));
       if (reduce_scatter->use_global_device_ids())
         attributes.push_back(ConvertUseGlobalDeviceIds(builder_));
-      // TODO(b/408024772) ToStablehlo: ConvertChannelHandle
       auto reduce_scatter_op =
-          func_builder->create<mlir::mhlo::ReduceScatterOp>(
+          func_builder->create<mlir::stablehlo::ReduceScatterOp>(
               loc, result_type, operands, attributes);
       TF_RETURN_IF_ERROR(ImportAsRegion(*reduce_scatter->to_apply(),
                                         &reduce_scatter_op.getComputation()));
@@ -2055,22 +2066,17 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       }
     }
     case HloOpcode::kAfterAll: {
-      // HLO AfterAll ops without any token input are used to just create a
-      // token. MHLO has a special op CreateToken for this case.
+      // HLO AfterAll ops without any token input are used to create a token.
+      // TODO(b/408024772): Remove CreateTokenOp; it is redundant.
       if (instruction->operands().empty()) {
-        // TODO(b/408024772) ToStablehlo: result_type needs to be a
-        // stablehlo::TokenType.
-        // Also, remove CreateTokenOp usage to AfterAllOp.
         return func_builder
-            ->create<mlir::mhlo::CreateTokenOp>(loc, result_type, operands,
-                                                attributes)
+            ->create<mlir::stablehlo::CreateTokenOp>(loc, result_type,
+                                                     operands, attributes)
            .getOperation();
       } else {
-        // TODO(b/408024772) ToStablehlo: result_type needs to be a
-        // stablehlo::TokenType.
         return func_builder
-            ->create<mlir::mhlo::AfterAllOp>(loc, result_type, operands,
-                                             attributes)
+            ->create<mlir::stablehlo::AfterAllOp>(loc, result_type, operands,
+                                                  attributes)
            .getOperation();
       }
     }
@@ -2197,6 +2203,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       NO_ATTRIBUTE_CASE(kShiftRightLogical, ShiftRightLogicalOp);
       NO_ATTRIBUTE_CASE(kSign, SignOp);
       NO_ATTRIBUTE_CASE(kSubtract, SubtractOp);
+      NO_ATTRIBUTE_CASE(kTuple, TupleOp);
       NO_ATTRIBUTE_CASE(kXor, XorOp);

 #undef NO_ATTRIBUTE_CASE

@@ -2211,9 +2218,6 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       NO_ATTRIBUTE_CASE_MHLO(kCopy, CopyOp);
       NO_ATTRIBUTE_CASE_MHLO(kErf, ErfOp);
       NO_ATTRIBUTE_CASE_MHLO(kStochasticConvert, StochasticConvertOp);
-      // TODO(b/408024772) ToStablehlo: Once all tokens are stablehlo.token move
-      // to NO_ATTRIBUTE_CASE.
-      NO_ATTRIBUTE_CASE_MHLO(kTuple, TupleOp);

 #undef NO_ATTRIBUTE_CASE_MHLO

diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h
index 8c85eec6d711c3..19ef3e82fe760c 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h
@@ -37,6 +37,7 @@ limitations under the License.
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "xla/layout.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
@@ -170,7 +171,7 @@ static absl::StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
     return builder.getTupleType(contents);
   }
   if (shape.IsToken()) {
-    return mlir::mhlo::TokenType::get(builder.getContext());
+    return mlir::stablehlo::TokenType::get(builder.getContext());
   }
   return ConvertTensorShapeToType<TypeT>(shape, builder);
 }
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo
index 0fe3161a055ff9..1d13766c9e1b5d 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo
@@ -41,9 +41,9 @@ ENTRY %main.3 (Arg_0.1: pred[2]) -> pred[2] {
 // -----

 // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
-// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token {
-// CHECK: %[[VAL_2:.*]] = mhlo.after_all %[[VAL_0]], %[[VAL_1]] {xla_shape = "token[]"} : !mhlo.token
-// CHECK: return %[[VAL_2]] : !mhlo.token
+// CHECK: func.func @main(%[[VAL_0:.*]]: !stablehlo.token, %[[VAL_1:.*]]: !stablehlo.token) -> !stablehlo.token {
+// CHECK: %[[VAL_2:.*]] = stablehlo.after_all %[[VAL_0]], %[[VAL_1]] {xla_shape = "token[]"} : !stablehlo.token
+// CHECK: return %[[VAL_2]] : !stablehlo.token
 // CHECK: }
 // CHECK: }
 HloModule main, entry_computation_layout={(token[], token[])->token[]}
@@ -57,9 +57,9 @@ ENTRY %main.4 (Arg_0.1: token[], Arg_1.2: token[]) -> token[] {
 // -----

 // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
-// CHECK: func.func @main() -> !mhlo.token {
-// CHECK: %[[VAL_0:.*]] = mhlo.create_token {xla_shape = "token[]"} : !mhlo.token
-// CHECK: return %[[VAL_0]] : !mhlo.token
+// CHECK: func.func @main() -> !stablehlo.token {
+// CHECK: %[[VAL_0:.*]] = stablehlo.create_token {xla_shape = "token[]"} : !stablehlo.token
+// CHECK: return %[[VAL_0]] : !stablehlo.token
 // CHECK: }
 // CHECK: }
 HloModule main, entry_computation_layout={()->token[]}
@@ -76,10 +76,10 @@ ENTRY %main.2 () -> token[] {
 // CHECK: return %[[VAL_2]] : tensor<f32>
 // CHECK: }
 // CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<5xf32> {
-// CHECK: %[[VAL_4:.*]] = "mhlo.reduce_scatter"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2], [1, 3]]> : tensor<2x2xi64>, scatter_dimension = 0 : i64}> ({
+// CHECK: %[[VAL_4:.*]] = "stablehlo.reduce_scatter"(%[[VAL_3]]) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2], [1, 3]]> : tensor<2x2xi64>, scatter_dimension = 0 : i64}> ({
 // CHECK: ^bb0(%[[VAL_5:.*]]: tensor,
%[[VAL_6:.*]]: tensor): // CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor -// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: stablehlo.return %[[VAL_7]] : tensor // CHECK: }) : (tensor<10xf32>) -> tensor<5xf32> // CHECK: return %[[VAL_4]] : tensor<5xf32> // CHECK: } @@ -101,7 +101,7 @@ ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[5] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> { -// CHECK: %[[VAL_1:.*]] = "mhlo.all_gather"(%[[VAL_0]]) <{all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> : (tensor<128x32xf32>) -> tensor<128x128xf32> +// CHECK: %[[VAL_1:.*]] = "stablehlo.all_gather"(%[[VAL_0]]) <{all_gather_dim = 1 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> : (tensor<128x32xf32>) -> tensor<128x128xf32> // CHECK: return %[[VAL_1]] : tensor<128x128xf32> // CHECK: } // CHECK: } @@ -116,7 +116,7 @@ ENTRY %main.3 (Arg_0.1: f32[128,32]) -> f32[128,128] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> { -// CHECK: %[[VAL_1:.*]] = "mhlo.all_gather"(%[[VAL_0]]) <{all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<128x32xf32>) -> tensor<128x128xf32> +// CHECK: %[[VAL_1:.*]] = "stablehlo.all_gather"(%[[VAL_0]]) <{all_gather_dim = 1 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<128x32xf32>) -> tensor<128x128xf32> // CHECK: return %[[VAL_1]] : tensor<128x128xf32> // CHECK: } // CHECK: } @@ -131,8 +131,8 @@ ENTRY %main.3 (Arg_0.1: f32[128,32]) -> f32[128,128] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<8x2xf32>, %[[VAL_1:.*]]: tensor<8x4xf32>) -> tuple, tensor<8x16xf32>> { -// CHECK: %[[VAL_2:.*]]:2 = "mhlo.all_gather"(%[[VAL_0]], %[[VAL_1]]) <{all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) -// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_2]]#0, %[[VAL_2]]#1 {xla_shape = "(f32[8,8]{1,0}, f32[8,16]{1,0})"} : tuple, tensor<8x16xf32>> +// CHECK: %[[VAL_2:.*]]:2 = "stablehlo.all_gather"(%[[VAL_0]], %[[VAL_1]]) <{all_gather_dim = 1 : i64, channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) +// CHECK: %[[VAL_3:.*]] = stablehlo.tuple %[[VAL_2]]#0, %[[VAL_2]]#1 {xla_shape = "(f32[8,8]{1,0}, f32[8,16]{1,0})"} : tuple, tensor<8x16xf32>> // CHECK: return %[[VAL_3]] : tuple, tensor<8x16xf32>> // CHECK: } // CHECK: } @@ -158,10 
+158,10 @@ ENTRY %main.10 (Arg_0.1: f32[8,2], Arg_1.2: f32[8,4]) -> (f32[8,8], f32[8,16]) { // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<10xf32> { -// CHECK: %[[VAL_4:.*]] = "mhlo.all_reduce"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> ({ +// CHECK: %[[VAL_4:.*]] = "stablehlo.all_reduce"(%[[VAL_3]]) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> ({ // CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): // CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor -// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: stablehlo.return %[[VAL_7]] : tensor // CHECK: }) : (tensor<10xf32>) -> tensor<10xf32> // CHECK: return %[[VAL_4]] : tensor<10xf32> // CHECK: } @@ -187,10 +187,10 @@ ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] { // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<10xf32> { -// CHECK: %[[VAL_4:.*]] = "mhlo.all_reduce"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, -1], [1, 3, 5, 6]]> : tensor<2x4xi64>}> ({ +// CHECK: %[[VAL_4:.*]] = "stablehlo.all_reduce"(%[[VAL_3]]) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, -1], [1, 3, 5, 6]]> : tensor<2x4xi64>}> ({ // CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): // CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor -// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: stablehlo.return %[[VAL_7]] : tensor // CHECK: }) : (tensor<10xf32>) -> tensor<10xf32> // CHECK: return %[[VAL_4]] : tensor<10xf32> // CHECK: } @@ -216,10 +216,10 @@ ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] { // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<10xf32> { -// CHECK: %[[VAL_4:.*]] = "mhlo.all_reduce"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({ +// CHECK: %[[VAL_4:.*]] = "stablehlo.all_reduce"(%[[VAL_3]]) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({ // CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): // CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor -// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: stablehlo.return %[[VAL_7]] : tensor // CHECK: }) : (tensor<10xf32>) -> tensor<10xf32> // CHECK: return %[[VAL_4]] : tensor<10xf32> // CHECK: } @@ -245,12 +245,12 @@ ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] { // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_3:.*]]: tensor<8xf32>, %[[VAL_4:.*]]: tensor) -> tuple, tensor> { -// CHECK: %[[VAL_5:.*]]:2 = "mhlo.all_reduce"(%[[VAL_3]], %[[VAL_4]]) <{replica_groups = dense<> : tensor<0x0xi64>}> ({ +// CHECK: %[[VAL_5:.*]]:2 = "stablehlo.all_reduce"(%[[VAL_3]], %[[VAL_4]]) <{replica_groups = dense<> : tensor<0x0xi64>}> ({ // CHECK: ^bb0(%[[VAL_6:.*]]: tensor, %[[VAL_7:.*]]: tensor): // CHECK: %[[VAL_8:.*]] = stablehlo.add %[[VAL_6]], %[[VAL_7]] : tensor -// CHECK: mhlo.return %[[VAL_8]] : tensor +// CHECK: stablehlo.return %[[VAL_8]] : tensor // CHECK: }) : (tensor<8xf32>, tensor) -> (tensor<8xf32>, tensor) -// CHECK: %[[VAL_9:.*]] = 
mhlo.tuple %[[VAL_10:.*]]#0, %[[VAL_10]]#1 {xla_shape = "(f32[8]{0}, f32[])"} : tuple, tensor> +// CHECK: %[[VAL_9:.*]] = stablehlo.tuple %[[VAL_10:.*]]#0, %[[VAL_10]]#1 {xla_shape = "(f32[8]{0}, f32[])"} : tuple, tensor> // CHECK: return %[[VAL_9]] : tuple, tensor> // CHECK: } // CHECK: } @@ -282,10 +282,10 @@ ENTRY %main.14 (Arg_0.1: f32[8], Arg_1.2: f32[]) -> (f32[8], f32[]) { // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<5xf32> { -// CHECK: %[[VAL_4:.*]] = "mhlo.reduce_scatter"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2], [1, 3]]> : tensor<2x2xi64>, scatter_dimension = 0 : i64}> ({ +// CHECK: %[[VAL_4:.*]] = "stablehlo.reduce_scatter"(%[[VAL_3]]) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2], [1, 3]]> : tensor<2x2xi64>, scatter_dimension = 0 : i64}> ({ // CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): // CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor -// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: stablehlo.return %[[VAL_7]] : tensor // CHECK: }) : (tensor<10xf32>) -> tensor<5xf32> // CHECK: return %[[VAL_4]] : tensor<5xf32> // CHECK: } @@ -311,10 +311,10 @@ ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[5] { // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<5xf32> { -// CHECK: %[[VAL_4:.*]] = "mhlo.reduce_scatter"(%[[VAL_3]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2], [1, 3]]> : tensor<2x2xi64>, scatter_dimension = 0 : i64, use_global_device_ids}> ({ +// CHECK: %[[VAL_4:.*]] = "stablehlo.reduce_scatter"(%[[VAL_3]]) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 2], [1, 3]]> : tensor<2x2xi64>, scatter_dimension = 0 : i64, use_global_device_ids}> ({ // CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): // CHECK: %[[VAL_7:.*]] = stablehlo.maximum %[[VAL_5]], %[[VAL_6]] : tensor -// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: stablehlo.return %[[VAL_7]] : tensor // CHECK: }) : (tensor<10xf32>) -> tensor<5xf32> // CHECK: return %[[VAL_4]] : tensor<5xf32> // CHECK: } @@ -337,7 +337,7 @@ ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[5] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x2x2x2xf32>, %[[VAL_1:.*]]: tensor<2xf32>, %[[VAL_2:.*]]: tensor<2xf32>, %[[VAL_3:.*]]: tensor<2xf32>, %[[VAL_4:.*]]: tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { // CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]], %[[VAL_7:.*]] = "stablehlo.batch_norm_grad"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) <{epsilon = 1.000000e-03 : f32, feature_index = 0 : i64}> : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -// CHECK: %[[VAL_8:.*]] = mhlo.tuple %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] {xla_shape = "(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})"} : tuple, tensor<2xf32>, tensor<2xf32>> +// CHECK: %[[VAL_8:.*]] = stablehlo.tuple %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] {xla_shape = "(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})"} : tuple, tensor<2xf32>, tensor<2xf32>> // CHECK: return %[[VAL_8]] : tuple, tensor<2xf32>, tensor<2xf32>> // CHECK: } // CHECK: } @@ -361,7 +361,7 @@ ENTRY %main.11 (Arg_0.1: 
f32[2,2,2,2], Arg_1.2: f32[2], Arg_2.3: f32[2], Arg_3.4 // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x2x2x2xf32>, %[[VAL_1:.*]]: tensor<2xf32>, %[[VAL_2:.*]]: tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { // CHECK: %[[VAL_3:.*]], %[[VAL_4:.*]], %[[VAL_5:.*]] = "stablehlo.batch_norm_training"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})"} : tuple, tensor<2xf32>, tensor<2xf32>> +// CHECK: %[[VAL_6:.*]] = stablehlo.tuple %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})"} : tuple, tensor<2xf32>, tensor<2xf32>> // CHECK: return %[[VAL_6]] : tuple, tensor<2xf32>, tensor<2xf32>> // CHECK: } // CHECK: } @@ -386,7 +386,7 @@ ENTRY %main.9 (Arg_0.1: f32[2,2,2,2], Arg_1.2: f32[2], Arg_2.3: f32[2]) -> (f32[ // CHECK: %[[VAL_5:.*]] = stablehlo.shift_left %[[VAL_2]], %[[VAL_3]] : tensor<4xi32> // CHECK: %[[VAL_6:.*]] = stablehlo.shift_right_arithmetic %[[VAL_2]], %[[VAL_3]] : tensor<4xi32> // CHECK: %[[VAL_7:.*]] = stablehlo.shift_right_logical %[[VAL_2]], %[[VAL_3]] : tensor<4xi32> -// CHECK: %[[VAL_8:.*]] = mhlo.tuple %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] {xla_shape = "(f32[4]{0}, s32[4]{0}, s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>> +// CHECK: %[[VAL_8:.*]] = stablehlo.tuple %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] {xla_shape = "(f32[4]{0}, s32[4]{0}, s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>> // CHECK: return %[[VAL_8]] : tuple, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>> // CHECK: } // CHECK: } @@ -452,9 +452,9 @@ ENTRY %main.3 (Arg_0.1: f32[1]) -> f32[1,10] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main() -> !mhlo.token { -// CHECK: %[[VAL_0:.*]] = mhlo.create_token {xla_shape = "token[]"} : !mhlo.token -// CHECK: return %[[VAL_0]] : !mhlo.token +// CHECK: func.func @main() -> !stablehlo.token { +// CHECK: %[[VAL_0:.*]] = stablehlo.create_token {xla_shape = "token[]"} : !stablehlo.token +// CHECK: return %[[VAL_0]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={()->token[]} @@ -467,7 +467,7 @@ ENTRY %main.2 () -> token[] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func private @empty_callee.2() -> tuple<> { -// CHECK: %[[VAL_0:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: %[[VAL_0:.*]] = stablehlo.tuple {xla_shape = "()"} : tuple<> // CHECK: return %[[VAL_0]] : tuple<> // CHECK: } // CHECK: func.func @main(%[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { @@ -529,14 +529,14 @@ ENTRY %main.12 (Arg_0.1: s32[4]) -> s32[4] { // CHECK: func.func private @callee.2(%[[VAL_0:.*]]: tensor<4xi32>, %[[VAL_1:.*]]: tensor<4xi32>) -> tuple, tensor<4xi32>> { // CHECK: %[[VAL_2:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor<4xi32> // CHECK: 
%[[VAL_3:.*]] = stablehlo.multiply %[[VAL_0]], %[[VAL_1]] : tensor<4xi32> -// CHECK: %[[VAL_4:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_3]] {xla_shape = "(s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xi32>> +// CHECK: %[[VAL_4:.*]] = stablehlo.tuple %[[VAL_2]], %[[VAL_3]] {xla_shape = "(s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xi32>> // CHECK: return %[[VAL_4]] : tuple, tensor<4xi32>> // CHECK: } // CHECK: func.func @main(%[[VAL_5:.*]]: tensor<4xi32>) -> tuple, tensor<4xi32>> { // CHECK: %[[VAL_6:.*]] = call @callee.2(%[[VAL_5]], %[[VAL_5]]) {xla_shape = "(s32[4]{0}, s32[4]{0})"} : (tensor<4xi32>, tensor<4xi32>) -> tuple, tensor<4xi32>> // CHECK: %[[VAL_7:.*]] = stablehlo.get_tuple_element %[[VAL_6]][0] : (tuple, tensor<4xi32>>) -> tensor<4xi32> // CHECK: %[[VAL_8:.*]] = stablehlo.get_tuple_element %[[VAL_6]][1] : (tuple, tensor<4xi32>>) -> tensor<4xi32> -// CHECK: %[[VAL_9:.*]] = mhlo.tuple %[[VAL_7]], %[[VAL_8]] {xla_shape = "(s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xi32>> +// CHECK: %[[VAL_9:.*]] = stablehlo.tuple %[[VAL_7]], %[[VAL_8]] {xla_shape = "(s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xi32>> // CHECK: return %[[VAL_9]] : tuple, tensor<4xi32>> // CHECK: } // CHECK: } @@ -607,7 +607,7 @@ ENTRY %main.3 (Arg_0.1: f32[]) -> f32[] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> { -// CHECK: %[[VAL_1:.*]] = "mhlo.collective_broadcast"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 1], [2, 3]]> : tensor<2x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> +// CHECK: %[[VAL_1:.*]] = "stablehlo.collective_broadcast"(%[[VAL_0]]) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 1], [2, 3]]> : tensor<2x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> // CHECK: return %[[VAL_1]] : tensor<128x32xf32> // CHECK: } // CHECK: } @@ -622,7 +622,7 @@ ENTRY %main.3 (Arg_0.1: f32[128,32]) -> f32[128,32] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> { -// CHECK: %[[VAL_1:.*]] = "mhlo.collective_permute"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> +// CHECK: %[[VAL_1:.*]] = "stablehlo.collective_permute"(%[[VAL_0]]) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> // CHECK: return %[[VAL_1]] : tensor<128x32xf32> // CHECK: } // CHECK: } @@ -654,7 +654,7 @@ ENTRY %main.5 (Arg_0.1: f32[5,2], Arg_1.2: f32[5,5], Arg_2.3: f32[5,7]) -> f32[5 // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main() -> tuple<> { -// CHECK: %[[VAL_0:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: %[[VAL_0:.*]] = stablehlo.tuple {xla_shape = "()"} : tuple<> // CHECK: return %[[VAL_0]] : tuple<> // CHECK: } // CHECK: } @@ -795,10 +795,10 @@ ENTRY %main.3 (Arg_0.1: s32[2]) -> s32[2] { // CHECK: return %[[VAL_2]] : tensor // CHECK: } // CHECK: 
func.func @main(%[[VAL_3:.*]]: tensor<10xf32>) -> tensor<10xf32> { -// CHECK: %[[VAL_4:.*]] = "mhlo.all_reduce"(%[[VAL_3]]) <{replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> ({ +// CHECK: %[[VAL_4:.*]] = "stablehlo.all_reduce"(%[[VAL_3]]) <{replica_groups = dense<{{\[\[}}0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> ({ // CHECK: ^bb0(%[[VAL_5:.*]]: tensor, %[[VAL_6:.*]]: tensor): // CHECK: %[[VAL_7:.*]] = stablehlo.add %[[VAL_5]], %[[VAL_6]] : tensor -// CHECK: mhlo.return %[[VAL_7]] : tensor +// CHECK: stablehlo.return %[[VAL_7]] : tensor // CHECK: }) : (tensor<10xf32>) -> tensor<10xf32> // CHECK: return %[[VAL_4]] : tensor<10xf32> // CHECK: } @@ -820,7 +820,7 @@ ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32> { -// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @SetBound(%[[VAL_0]]) {backend_config = "", mhlo.literal = dense<1> : tensor} : (tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK: %[[VAL_1:.*]] = stablehlo.custom_call @SetBound(%[[VAL_0]]) {backend_config = "", mhlo.literal = dense<1> : tensor} : (tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK: return %[[VAL_1]] : tensor<2x3xf32> // CHECK: } // CHECK: } @@ -866,7 +866,7 @@ ENTRY %main.8 (Arg_0.1: f32[6], Arg_1.2: f32[6], Arg_2.3: s32[3], Arg_3.4: s32[3 // CHECK: }) : (tensor<16x256xbf16>, tensor<16x256xi32>) -> (tensor<16x256xbf16>, tensor<16x256xi32>) // CHECK: %[[VAL_15:.*]] = stablehlo.slice %[[VAL_16:.*]]#0 [0:16, 0:4] : (tensor<16x256xbf16>) -> tensor<16x4xbf16> // CHECK: %[[VAL_17:.*]] = stablehlo.slice %[[VAL_16]]#1 [0:16, 0:4] : (tensor<16x256xi32>) -> tensor<16x4xi32> -// CHECK: %[[VAL_18:.*]] = mhlo.tuple %[[VAL_15]], %[[VAL_17]] {xla_shape = "(bf16[16,4]{1,0}, s32[16,4]{1,0})"} : tuple, tensor<16x4xi32>> +// CHECK: %[[VAL_18:.*]] = stablehlo.tuple %[[VAL_15]], %[[VAL_17]] {xla_shape = "(bf16[16,4]{1,0}, s32[16,4]{1,0})"} : tuple, tensor<16x4xi32>> // CHECK: return %[[VAL_18]] : tuple, tensor<16x4xi32>> // CHECK: } // CHECK: } @@ -908,7 +908,7 @@ ENTRY %main.20 (Arg_0.1: bf16[16,256], Arg_1.2: s32[], Arg_2.3: s32[16,256], Arg // CHECK: return %[[VAL_9]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_10:.*]]: tensor<16x256xbf16>, %[[VAL_11:.*]]: tensor, %[[VAL_12:.*]]: tensor<16x256xi32>, %[[VAL_13:.*]]: tensor) -> tuple, tensor<16x4xi32>> { -// CHECK: %[[VAL_14:.*]] = mhlo.custom_call @PartialReduce(%[[VAL_10]], %[[VAL_12]], %[[VAL_13]], %[[VAL_11]]) {backend_config = "{\22log2_reduction\22: 1, \22reduction_dim\22: 1, \22to_apply_type\22: \22comparator\22, \22top_k\22: 4, \22recall_target\22: 0.949218}", called_computations = [@top_k_gt_comparator.5], xla_shape = "(bf16[16,128]{1,0}, s32[16,128]{1,0})"} : (tensor<16x256xbf16>, tensor<16x256xi32>, tensor, tensor) -> tuple, tensor<16x128xi32>> +// CHECK: %[[VAL_14:.*]] = stablehlo.custom_call @PartialReduce(%[[VAL_10]], %[[VAL_12]], %[[VAL_13]], %[[VAL_11]]) {backend_config = "{\22log2_reduction\22: 1, \22reduction_dim\22: 1, \22to_apply_type\22: \22comparator\22, \22top_k\22: 4, \22recall_target\22: 0.949218}", called_computations = [@top_k_gt_comparator.5], xla_shape = "(bf16[16,128]{1,0}, s32[16,128]{1,0})"} : (tensor<16x256xbf16>, tensor<16x256xi32>, tensor, tensor) -> tuple, tensor<16x128xi32>> // CHECK: %[[VAL_15:.*]] = stablehlo.get_tuple_element %[[VAL_14]][0] : (tuple, tensor<16x128xi32>>) 
-> tensor<16x128xbf16> // CHECK: %[[VAL_16:.*]] = stablehlo.get_tuple_element %[[VAL_14]][1] : (tuple, tensor<16x128xi32>>) -> tensor<16x128xi32> // CHECK: %[[VAL_17:.*]]:2 = "stablehlo.sort"(%[[VAL_15]], %[[VAL_16]]) <{dimension = 1 : i64, is_stable = false}> ({ @@ -918,7 +918,7 @@ ENTRY %main.20 (Arg_0.1: bf16[16,256], Arg_1.2: s32[], Arg_2.3: s32[16,256], Arg // CHECK: }) : (tensor<16x128xbf16>, tensor<16x128xi32>) -> (tensor<16x128xbf16>, tensor<16x128xi32>) // CHECK: %[[VAL_23:.*]] = stablehlo.slice %[[VAL_24:.*]]#0 [0:16, 0:4] : (tensor<16x128xbf16>) -> tensor<16x4xbf16> // CHECK: %[[VAL_25:.*]] = stablehlo.slice %[[VAL_24]]#1 [0:16, 0:4] : (tensor<16x128xi32>) -> tensor<16x4xi32> -// CHECK: %[[VAL_26:.*]] = mhlo.tuple %[[VAL_23]], %[[VAL_25]] {xla_shape = "(bf16[16,4]{1,0}, s32[16,4]{1,0})"} : tuple, tensor<16x4xi32>> +// CHECK: %[[VAL_26:.*]] = stablehlo.tuple %[[VAL_23]], %[[VAL_25]] {xla_shape = "(bf16[16,4]{1,0}, s32[16,4]{1,0})"} : tuple, tensor<16x4xi32>> // CHECK: return %[[VAL_26]] : tuple, tensor<16x4xi32>> // CHECK: } // CHECK: } @@ -995,7 +995,7 @@ ENTRY %main.4 (Arg_0.1: f32[2,3], Arg_1.2: f32[5,5]) -> f32[1,2,3] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>) -> tuple> { -// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @foo(%[[VAL_0]]) {backend_config = "", xla_shape = "(f32[2,3]{1,0})"} : (tensor<2x3xf32>) -> tuple> +// CHECK: %[[VAL_1:.*]] = stablehlo.custom_call @foo(%[[VAL_0]]) {backend_config = "", xla_shape = "(f32[2,3]{1,0})"} : (tensor<2x3xf32>) -> tuple> // CHECK: return %[[VAL_1]] : tuple> // CHECK: } // CHECK: } @@ -1010,7 +1010,7 @@ ENTRY %main.3 (Arg_0.1: f32[2,3]) -> (f32[2,3]) { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> { -// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @foo(%[[VAL_0]]) {backend_config = "", xla_shape = "(f32[2,3]{1,0}, f16[4,5]{1,0})"} : (tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> +// CHECK: %[[VAL_1:.*]] = stablehlo.custom_call @foo(%[[VAL_0]]) {backend_config = "", xla_shape = "(f32[2,3]{1,0}, f16[4,5]{1,0})"} : (tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> // CHECK: return %[[VAL_1]] : tuple, tensor<4x5xf16>> // CHECK: } // CHECK: } @@ -1025,10 +1025,10 @@ ENTRY %main.3 (Arg_0.1: f32[2,3]) -> (f32[2,3], f16[4,5]) { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> { -// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @foo(%[[VAL_0]]) {backend_config = "", xla_shape = "(f32[2,3]{1,0}, f16[4,5]{1,0})"} : (tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> +// CHECK: %[[VAL_1:.*]] = stablehlo.custom_call @foo(%[[VAL_0]]) {backend_config = "", xla_shape = "(f32[2,3]{1,0}, f16[4,5]{1,0})"} : (tensor<2x3xf32>) -> tuple, tensor<4x5xf16>> // CHECK: %[[VAL_2:.*]] = stablehlo.get_tuple_element %[[VAL_1]][0] : (tuple, tensor<4x5xf16>>) -> tensor<2x3xf32> // CHECK: %[[VAL_3:.*]] = stablehlo.get_tuple_element %[[VAL_1]][1] : (tuple, tensor<4x5xf16>>) -> tensor<4x5xf16> -// CHECK: %[[VAL_4:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_3]] {xla_shape = "(f32[2,3]{1,0}, 
f16[4,5]{1,0})"} : tuple, tensor<4x5xf16>> +// CHECK: %[[VAL_4:.*]] = stablehlo.tuple %[[VAL_2]], %[[VAL_3]] {xla_shape = "(f32[2,3]{1,0}, f16[4,5]{1,0})"} : tuple, tensor<4x5xf16>> // CHECK: return %[[VAL_4]] : tuple, tensor<4x5xf16>> // CHECK: } // CHECK: } @@ -1238,11 +1238,11 @@ ENTRY %main.3 (Arg_0.1: (f32[], s32[])) -> f32[] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> tuple, tensor>, !mhlo.token> { -// CHECK: %[[VAL_1:.*]]:3 = "mhlo.infeed"(%[[VAL_0]]) <{infeed_config = "foobar", layout = {{\[\[}}1, 0], []]}> : (!mhlo.token) -> (tensor<3x3xi32>, tensor, !mhlo.token) -// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]]#0, %[[VAL_1]]#1 {xla_shape = "(s32[3,3]{1,0}, pred[])"} : tuple, tensor> -// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_1]]#2 {xla_shape = "((s32[3,3]{1,0}, pred[]), token[])"} : tuple, tensor>, !mhlo.token> -// CHECK: return %[[VAL_3]] : tuple, tensor>, !mhlo.token> +// CHECK: func.func @main(%[[VAL_0:.*]]: !stablehlo.token) -> tuple, tensor>, !stablehlo.token> { +// CHECK: %[[VAL_1:.*]]:3 = "stablehlo.infeed"(%[[VAL_0]]) <{infeed_config = "foobar", layout = {{\[\[}}1, 0], []]}> : (!stablehlo.token) -> (tensor<3x3xi32>, tensor, !stablehlo.token) +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_1]]#0, %[[VAL_1]]#1 {xla_shape = "(s32[3,3]{1,0}, pred[])"} : tuple, tensor> +// CHECK: %[[VAL_3:.*]] = stablehlo.tuple %[[VAL_2]], %[[VAL_1]]#2 {xla_shape = "((s32[3,3]{1,0}, pred[]), token[])"} : tuple, tensor>, !stablehlo.token> +// CHECK: return %[[VAL_3]] : tuple, tensor>, !stablehlo.token> // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(token[])->((s32[3,3]{1,0}, pred[]), token[])} @@ -1261,8 +1261,8 @@ ENTRY %main.9 (Arg_0.1: token[]) -> ((s32[3,3], pred[]), token[]) { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false, mhlo.xla_entry_computation_result_layout = [dense<[0, 1]> : tensor<2xindex>], mhlo.xla_entry_computation_result_tiles = {{\[\[}}]]} { -// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> tensor<3x3xi32> { -// CHECK: %[[VAL_1:.*]]:2 = "mhlo.infeed"(%[[VAL_0]]) <{infeed_config = "foobar", layout = {{\[\[}}1, 0]]}> : (!mhlo.token) -> (tensor<3x3xi32>, !mhlo.token) +// CHECK: func.func @main(%[[VAL_0:.*]]: !stablehlo.token) -> tensor<3x3xi32> { +// CHECK: %[[VAL_1:.*]]:2 = "stablehlo.infeed"(%[[VAL_0]]) <{infeed_config = "foobar", layout = {{\[\[}}1, 0]]}> : (!stablehlo.token) -> (tensor<3x3xi32>, !stablehlo.token) // CHECK: return %[[VAL_1]]#0 : tensor<3x3xi32> // CHECK: } // CHECK: } @@ -1279,9 +1279,9 @@ ENTRY %main.6 (Arg_0.1: token[]) -> s32[3,3] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: %[[VAL_1:.*]] = "mhlo.infeed"(%[[VAL_0]]) <{infeed_config = "foobar", layout = []}> : (!mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_1]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_1:.*]] = "stablehlo.infeed"(%[[VAL_0]]) <{infeed_config = "foobar", layout = []}> : (!stablehlo.token) -> !stablehlo.token +// CHECK: return 
%[[VAL_1]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(token[])->token[]} @@ -1355,9 +1355,9 @@ ENTRY %main.3 (Arg_0.1: f32[4], Arg_1.2: s32[4]) -> f32[4] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: %[[VAL_2:.*]] = "mhlo.outfeed"(%[[VAL_0]], %[[VAL_1]]) <{outfeed_config = "foobar"}> {xla_shape = "token[]"} : (tensor<3xi32>, !mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_2]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi32>, %[[VAL_1:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_2:.*]] = "stablehlo.outfeed"(%[[VAL_0]], %[[VAL_1]]) <{outfeed_config = "foobar"}> {xla_shape = "token[]"} : (tensor<3xi32>, !stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_2]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(s32[3]{0}, token[])->token[]} @@ -1372,11 +1372,11 @@ ENTRY %main.5 (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x2xi32>, %[[VAL_1:.*]]: !mhlo.token) -> (!mhlo.token {mhlo.sharding = "{{\{\{}}devices=[2,1]0,1}, {maximal device=0}}"}) { -// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @Sharding(%[[VAL_0]]) {backend_config = "", mhlo.sharding = "{devices=[1,2]0,1}"} : (tensor<3x2xi32>) -> tensor<3x2xi32> -// CHECK: %[[VAL_3:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[VAL_2]]) {backend_config = "", mhlo.sharding = "{devices=[1,2]0,1}"} : (tensor<3x2xi32>) -> tensor<6x2xi32> -// CHECK: %[[VAL_4:.*]] = "mhlo.outfeed"(%[[VAL_3]], %[[VAL_1]]) <{outfeed_config = "foobar"}> {mhlo.sharding = "{{\{\{}}devices=[2,1]0,1}, {maximal device=0}}", xla_shape = "token[]"} : (tensor<6x2xi32>, !mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_4]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x2xi32>, %[[VAL_1:.*]]: !stablehlo.token) -> (!stablehlo.token {mhlo.sharding = "{{\{\{}}devices=[2,1]0,1}, {maximal device=0}}"}) { +// CHECK: %[[VAL_2:.*]] = stablehlo.custom_call @Sharding(%[[VAL_0]]) {backend_config = "", mhlo.sharding = "{devices=[1,2]0,1}"} : (tensor<3x2xi32>) -> tensor<3x2xi32> +// CHECK: %[[VAL_3:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[VAL_2]]) {backend_config = "", mhlo.sharding = "{devices=[1,2]0,1}"} : (tensor<3x2xi32>) -> tensor<6x2xi32> +// CHECK: %[[VAL_4:.*]] = "stablehlo.outfeed"(%[[VAL_3]], %[[VAL_1]]) <{outfeed_config = "foobar"}> {mhlo.sharding = "{{\{\{}}devices=[2,1]0,1}, {maximal device=0}}", xla_shape = "token[]"} : (tensor<6x2xi32>, !stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_4]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(s32[3,2]{1,0}, token[])->token[]} @@ -1393,9 +1393,9 @@ ENTRY %main.7 (Arg_0.1: s32[3,2], Arg_1.2: token[]) -> token[] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi32>, %[[VAL_1:.*]]: tensor<3xi32>, %[[VAL_2:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: 
%[[VAL_3:.*]] = "mhlo.outfeed"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{outfeed_config = "foobar"}> {xla_shape = "token[]"} : (tensor<3xi32>, tensor<3xi32>, !mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_3]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xi32>, %[[VAL_1:.*]]: tensor<3xi32>, %[[VAL_2:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_3:.*]] = "stablehlo.outfeed"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{outfeed_config = "foobar"}> {xla_shape = "token[]"} : (tensor<3xi32>, tensor<3xi32>, !stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_3]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(s32[3]{0}, s32[3]{0}, token[])->token[]} @@ -1411,9 +1411,9 @@ ENTRY %main.6 (Arg_0.1: s32[3], Arg_1.2: s32[3], Arg_2.3: token[]) -> token[] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: %[[VAL_1:.*]] = "mhlo.outfeed"(%[[VAL_0]]) <{outfeed_config = "foobar"}> {xla_shape = "token[]"} : (!mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_1]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_1:.*]] = "stablehlo.outfeed"(%[[VAL_0]]) <{outfeed_config = "foobar"}> {xla_shape = "token[]"} : (!stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_1]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(token[])->token[]} @@ -1443,10 +1443,10 @@ ENTRY %main.4 (Arg_0.1: f32[4,6], Arg_1.2: f32[]) -> f32[13,19] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> tuple, !mhlo.token> { -// CHECK: %[[VAL_1:.*]]:2 = "mhlo.recv"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) -// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]]#0, %[[VAL_1]]#1 {xla_shape = "(s32[3,4]{1,0}, token[])"} : tuple, !mhlo.token> -// CHECK: return %[[VAL_2]] : tuple, !mhlo.token> +// CHECK: func.func @main(%[[VAL_0:.*]]: !stablehlo.token) -> tuple, !stablehlo.token> { +// CHECK: %[[VAL_1:.*]]:2 = "stablehlo.recv"(%[[VAL_0]]) <{channel_handle = #stablehlo.channel_handle, is_host_transfer = true}> : (!stablehlo.token) -> (tensor<3x4xi32>, !stablehlo.token) +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_1]]#0, %[[VAL_1]]#1 {xla_shape = "(s32[3,4]{1,0}, token[])"} : tuple, !stablehlo.token> +// CHECK: return %[[VAL_2]] : tuple, !stablehlo.token> // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])} @@ -1463,10 +1463,10 @@ ENTRY %main.7 (Arg_0.1: token[]) -> (s32[3,4], token[]) { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> tuple, !mhlo.token> { -// CHECK: %[[VAL_1:.*]]:2 = "mhlo.recv"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = false}> : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) -// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]]#0, %[[VAL_1]]#1 {xla_shape = "(s32[3,4]{1,0}, token[])"} : tuple, 
!mhlo.token> -// CHECK: return %[[VAL_2]] : tuple, !mhlo.token> +// CHECK: func.func @main(%[[VAL_0:.*]]: !stablehlo.token) -> tuple, !stablehlo.token> { +// CHECK: %[[VAL_1:.*]]:2 = "stablehlo.recv"(%[[VAL_0]]) <{channel_handle = #stablehlo.channel_handle, is_host_transfer = false}> : (!stablehlo.token) -> (tensor<3x4xi32>, !stablehlo.token) +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_1]]#0, %[[VAL_1]]#1 {xla_shape = "(s32[3,4]{1,0}, token[])"} : tuple, !stablehlo.token> +// CHECK: return %[[VAL_2]] : tuple, !stablehlo.token> // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])} @@ -1483,9 +1483,9 @@ ENTRY %main.7 (Arg_0.1: token[]) -> (s32[3,4], token[]) { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: %[[VAL_1:.*]] = "mhlo.recv"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = false}> : (!mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_1]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_1:.*]] = "stablehlo.recv"(%[[VAL_0]]) <{channel_handle = #stablehlo.channel_handle, is_host_transfer = false}> : (!stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_1]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(token[])->token[]} @@ -1504,7 +1504,7 @@ ENTRY %main.6 (Arg_0.1: token[]) -> token[] { // CHECK: func.func private @region_0.5(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor) -> tuple, tensor> { // CHECK: %[[VAL_4:.*]] = stablehlo.maximum %[[VAL_0]], %[[VAL_2]] : tensor // CHECK: %[[VAL_5:.*]] = stablehlo.maximum %[[VAL_1]], %[[VAL_3]] : tensor -// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[], s32[])"} : tuple, tensor> +// CHECK: %[[VAL_6:.*]] = stablehlo.tuple %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[], s32[])"} : tuple, tensor> // CHECK: return %[[VAL_6]] : tuple, tensor> // CHECK: } // CHECK: func.func @main(%[[VAL_7:.*]]: tensor<1x10xf32>, %[[VAL_8:.*]]: tensor<1x10xi32>, %[[VAL_9:.*]]: tensor, %[[VAL_10:.*]]: tensor) -> tuple, tensor<1xi32>> { @@ -1514,7 +1514,7 @@ ENTRY %main.6 (Arg_0.1: token[]) -> token[] { // CHECK: %[[VAL_17:.*]] = stablehlo.maximum %[[VAL_14]], %[[VAL_15]] : tensor // CHECK: stablehlo.return %[[VAL_16]], %[[VAL_17]] : tensor, tensor // CHECK: } -// CHECK: %[[VAL_18:.*]] = mhlo.tuple %[[VAL_11]]#0, %[[VAL_11]]#1 {xla_shape = "(f32[1]{0}, s32[1]{0})"} : tuple, tensor<1xi32>> +// CHECK: %[[VAL_18:.*]] = stablehlo.tuple %[[VAL_11]]#0, %[[VAL_11]]#1 {xla_shape = "(f32[1]{0}, s32[1]{0})"} : tuple, tensor<1xi32>> // CHECK: return %[[VAL_18]] : tuple, tensor<1xi32>> // CHECK: } // CHECK: } @@ -1710,7 +1710,7 @@ ENTRY %main.9 (Arg_0.1: f32[200,100,300], Arg_1.2: s32[100,200,1], Arg_2.3: f32[ // CHECK: func.func private @region_0.4(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor) -> tuple, tensor> { // CHECK: %[[VAL_4:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_1]] : tensor // CHECK: %[[VAL_5:.*]] = stablehlo.add %[[VAL_2]], %[[VAL_3]] : tensor -// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[], f32[])"} : tuple, tensor> +// CHECK: %[[VAL_6:.*]] = stablehlo.tuple %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[], f32[])"} : 
tuple, tensor> // CHECK: return %[[VAL_6]] : tuple, tensor> // CHECK: } // CHECK: func.func @main(%[[VAL_7:.*]]: tensor<200x100x300xf32>, %[[VAL_8:.*]]: tensor<10x2xi64>, %[[VAL_9:.*]]: tensor<10x300xf32>) -> tuple, tensor<200x100x300xf32>> { @@ -1720,7 +1720,7 @@ ENTRY %main.9 (Arg_0.1: f32[200,100,300], Arg_1.2: s32[100,200,1], Arg_2.3: f32[ // CHECK: %[[VAL_16:.*]] = stablehlo.add %[[VAL_13]], %[[VAL_14]] : tensor // CHECK: stablehlo.return %[[VAL_15]], %[[VAL_16]] : tensor, tensor // CHECK: }) : (tensor<200x100x300xf32>, tensor<200x100x300xf32>, tensor<10x2xi64>, tensor<10x300xf32>, tensor<10x300xf32>) -> (tensor<200x100x300xf32>, tensor<200x100x300xf32>) -// CHECK: %[[VAL_17:.*]] = mhlo.tuple %[[VAL_18:.*]]#0, %[[VAL_18]]#1 {xla_shape = "(f32[200,100,300]{2,1,0}, f32[200,100,300]{2,1,0})"} : tuple, tensor<200x100x300xf32>> +// CHECK: %[[VAL_17:.*]] = stablehlo.tuple %[[VAL_18:.*]]#0, %[[VAL_18]]#1 {xla_shape = "(f32[200,100,300]{2,1,0}, f32[200,100,300]{2,1,0})"} : tuple, tensor<200x100x300xf32>> // CHECK: return %[[VAL_17]] : tuple, tensor<200x100x300xf32>> // CHECK: } // CHECK: } @@ -1814,9 +1814,9 @@ ENTRY %main.13 (Arg_0.1: f32[10,24,24,64], Arg_1.2: f32[10,12,12,64]) -> f32[10, // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {xla_shape = "token[]"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_2]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>, %[[VAL_1:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_2:.*]] = "stablehlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #stablehlo.channel_handle, is_host_transfer = true}> {xla_shape = "token[]"} : (tensor<3x4xi32>, !stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_2]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]} @@ -1831,9 +1831,9 @@ ENTRY %main.5 (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = false}> {xla_shape = "token[]"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_2]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xi32>, %[[VAL_1:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_2:.*]] = "stablehlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #stablehlo.channel_handle, is_host_transfer = false}> {xla_shape = "token[]"} : (tensor<3x4xi32>, !stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_2]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]} @@ -1848,9 +1848,9 @@ ENTRY %main.5 (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = 
false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: %[[VAL_1:.*]] = "mhlo.send"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = false}> {xla_shape = "token[]"} : (!mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_1]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_1:.*]] = "stablehlo.send"(%[[VAL_0]]) <{channel_handle = #stablehlo.channel_handle, is_host_transfer = false}> {xla_shape = "token[]"} : (!stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_1]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(token[])->token[]} @@ -1945,7 +1945,7 @@ ENTRY %main.4 (Arg_0.1: f32[4,4], Arg_1.2: f32[4,3]) -> f32[4,3] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor) -> tuple, tensor> { -// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_0]], %[[VAL_1]] {xla_shape = "(f32[], s32[])"} : tuple, tensor> +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_0]], %[[VAL_1]] {xla_shape = "(f32[], s32[])"} : tuple, tensor> // CHECK: return %[[VAL_2]] : tuple, tensor> // CHECK: } // CHECK: } @@ -1965,7 +1965,7 @@ ENTRY %main.4 (Arg_0.1: f32[], Arg_1.2: s32[]) -> (f32[], s32[]) { // CHECK: %[[VAL_3:.*]] = mhlo.log_plus_one %[[VAL_0]] : tensor<4xf32> // CHECK: %[[VAL_4:.*]] = stablehlo.not %[[VAL_1]] : tensor<4xi32> // CHECK: %[[VAL_5:.*]] = stablehlo.popcnt %[[VAL_1]] : tensor<4xi32> -// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[4]{0}, f32[4]{0}, s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>> +// CHECK: %[[VAL_6:.*]] = stablehlo.tuple %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[4]{0}, f32[4]{0}, s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>> // CHECK: return %[[VAL_6]] : tuple, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>> // CHECK: } // CHECK: } @@ -2001,7 +2001,7 @@ ENTRY %main.4 (Arg_0.1: pred[4], Arg_1.2: pred[4]) -> pred[4] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<16x16xf32>, %[[VAL_1:.*]]: tensor<16x16xi32>) -> tuple<> { -// CHECK: %[[VAL_2:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple {xla_shape = "()"} : tuple<> // CHECK: return %[[VAL_2]] : tuple<> // CHECK: } // CHECK: } @@ -2017,7 +2017,7 @@ ENTRY %main.4 (Arg_0.1: f32[16,16], Arg_1.2: s32[16,16]) -> () { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<16x16xf32>) -> tuple<> { -// CHECK: %[[VAL_1:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: %[[VAL_1:.*]] = stablehlo.tuple {xla_shape = "()"} : tuple<> // CHECK: return %[[VAL_1]] : tuple<> // CHECK: } // CHECK: } @@ -2032,7 +2032,7 @@ ENTRY %main.3 (Arg_0.1: f32[16,16]) -> () { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // 
CHECK: func.func @main(%[[VAL_0:.*]]: tensor<16x16xf32>) -> (tensor<16x16xf32> {mhlo.sharding = "{devices=[1,2]0,1}"}) { -// CHECK: %[[VAL_1:.*]] = mhlo.custom_call @Sharding(%[[VAL_0]]) {backend_config = "", mhlo.sharding = "{devices=[1,2]0,1}"} : (tensor<16x16xf32>) -> tensor<16x16xf32> +// CHECK: %[[VAL_1:.*]] = stablehlo.custom_call @Sharding(%[[VAL_0]]) {backend_config = "", mhlo.sharding = "{devices=[1,2]0,1}"} : (tensor<16x16xf32>) -> tensor<16x16xf32> // CHECK: return %[[VAL_1]] : tensor<16x16xf32> // CHECK: } // CHECK: } @@ -2050,7 +2050,7 @@ ENTRY %main.3 (Arg_0.1: f32[16,16]) -> f32[16,16] { // CHECK: return %[[VAL_0]] : tensor<2x3xf32> // CHECK: } // CHECK: func.func @main(%[[VAL_2:.*]]: tensor<2x3xf32>, %[[VAL_3:.*]]: tensor<5x5xf32>) -> tensor<2x3xf32> { -// CHECK: %[[VAL_4:.*]] = mhlo.custom_call @foo(%[[VAL_2]], %[[VAL_3]]) {backend_config = "", called_computations = [@foo.3]} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<2x3xf32> +// CHECK: %[[VAL_4:.*]] = stablehlo.custom_call @foo(%[[VAL_2]], %[[VAL_3]]) {backend_config = "", called_computations = [@foo.3]} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<2x3xf32> // CHECK: return %[[VAL_4]] : tensor<2x3xf32> // CHECK: } // CHECK: } @@ -2073,7 +2073,7 @@ ENTRY %main.7 (Arg_0.1: f32[2,3], Arg_1.2: f32[5,5]) -> f32[2,3] { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2xcomplex>, %[[VAL_1:.*]]: tensor<2xcomplex>) -> tuple, tensor<2xf64>> { // CHECK: %[[VAL_2:.*]] = stablehlo.abs %[[VAL_0]] : (tensor<2xcomplex>) -> tensor<2xf32> // CHECK: %[[VAL_3:.*]] = stablehlo.abs %[[VAL_1]] : (tensor<2xcomplex>) -> tensor<2xf64> -// CHECK: %[[VAL_4:.*]] = mhlo.tuple %[[VAL_2]], %[[VAL_3]] {xla_shape = "(f32[2]{0}, f64[2]{0})"} : tuple, tensor<2xf64>> +// CHECK: %[[VAL_4:.*]] = stablehlo.tuple %[[VAL_2]], %[[VAL_3]] {xla_shape = "(f32[2]{0}, f64[2]{0})"} : tuple, tensor<2xf64>> // CHECK: return %[[VAL_4]] : tuple, tensor<2xf64>> // CHECK: } // CHECK: } @@ -2136,11 +2136,11 @@ ENTRY %main.4 (Arg_0.1: f32[4], Arg_1.2: s32[]) -> f32[<=4] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>, %[[VAL_1:.*]]: !mhlo.token) -> tuple, !mhlo.token> { -// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_rendezvous = "channel_dtoh_0"}, xla_shape = "token[]"} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token -// CHECK: %[[VAL_3:.*]]:2 = "mhlo.recv"(%[[VAL_2]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_rendezvous = "channel_htod_0"}} : (!mhlo.token) -> (tensor<3x4xf32>, !mhlo.token) -// CHECK: %[[VAL_4:.*]] = mhlo.tuple %[[VAL_3]]#0, %[[VAL_3]]#1 {xla_shape = "(f32[3,4]{1,0}, token[])"} : tuple, !mhlo.token> -// CHECK: return %[[VAL_4]] : tuple, !mhlo.token> +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>, %[[VAL_1:.*]]: !stablehlo.token) -> tuple, !stablehlo.token> { +// CHECK: %[[VAL_2:.*]] = "stablehlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #stablehlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_rendezvous = "channel_dtoh_0"}, xla_shape = "token[]"} : (tensor<3x4xf32>, !stablehlo.token) -> !stablehlo.token +// CHECK: %[[VAL_3:.*]]:2 = "stablehlo.recv"(%[[VAL_2]]) <{channel_handle = 
#stablehlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_rendezvous = "channel_htod_0"}} : (!stablehlo.token) -> (tensor<3x4xf32>, !stablehlo.token) +// CHECK: %[[VAL_4:.*]] = stablehlo.tuple %[[VAL_3]]#0, %[[VAL_3]]#1 {xla_shape = "(f32[3,4]{1,0}, token[])"} : tuple, !stablehlo.token> +// CHECK: return %[[VAL_4]] : tuple, !stablehlo.token> // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(f32[3,4]{1,0}, token[])->(f32[3,4]{1,0}, token[])} @@ -2160,9 +2160,9 @@ ENTRY %main.10 (Arg_0.1: f32[3,4], Arg_1.2: token[]) -> (f32[3,4], token[]) { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {xla_shape = "token[]"} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_2]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>, %[[VAL_1:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_2:.*]] = "stablehlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #stablehlo.channel_handle, is_host_transfer = true}> {xla_shape = "token[]"} : (tensor<3x4xf32>, !stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_2]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(f32[3,4]{1,0}, token[])->token[]} @@ -2177,9 +2177,9 @@ ENTRY %main.5 (Arg_0.1: f32[3,4], Arg_1.2: token[]) -> token[] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>, %[[VAL_1:.*]]: !mhlo.token) -> !mhlo.token { -// CHECK: %[[VAL_2:.*]] = "mhlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {xla_shape = "token[]"} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token -// CHECK: return %[[VAL_2]] : !mhlo.token +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>, %[[VAL_1:.*]]: !stablehlo.token) -> !stablehlo.token { +// CHECK: %[[VAL_2:.*]] = "stablehlo.send"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #stablehlo.channel_handle, is_host_transfer = true}> {xla_shape = "token[]"} : (tensor<3x4xf32>, !stablehlo.token) -> !stablehlo.token +// CHECK: return %[[VAL_2]] : !stablehlo.token // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(f32[3,4]{1,0}, token[])->token[]} @@ -2196,7 +2196,7 @@ ENTRY %main.5 (Arg_0.1: f32[3,4], Arg_1.2: token[]) -> token[] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3xui64>) -> tuple, tensor<2x2xui32>> { // CHECK: %[[VAL_1:.*]], %[[VAL_2:.*]] = stablehlo.rng_bit_generator %[[VAL_0]], algorithm = PHILOX : (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>) -// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_1]], %[[VAL_2]] {xla_shape = "(u64[3]{0}, u32[2,2]{1,0})"} : tuple, tensor<2x2xui32>> +// CHECK: %[[VAL_3:.*]] = stablehlo.tuple %[[VAL_1]], %[[VAL_2]] {xla_shape = "(u64[3]{0}, u32[2,2]{1,0})"} : tuple, tensor<2x2xui32>> // CHECK: return %[[VAL_3]] : tuple, 
tensor<2x2xui32>> // CHECK: } // CHECK: } @@ -2260,7 +2260,7 @@ ENTRY %main.3 (Arg_0.1: f32[3,4]) -> f32[3,4,1] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor<3x4xf32>) -> tuple, tensor<3x4xf32>> { // CHECK: %[[VAL_2:.*]]:2 = stablehlo.optimization_barrier %[[VAL_0]], %[[VAL_1]] : tensor<4x4xf32>, tensor<3x4xf32> -// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_2]]#0, %[[VAL_2]]#1 {xla_shape = "(f32[4,4]{1,0}, f32[3,4]{1,0})"} : tuple, tensor<3x4xf32>> +// CHECK: %[[VAL_3:.*]] = stablehlo.tuple %[[VAL_2]]#0, %[[VAL_2]]#1 {xla_shape = "(f32[4,4]{1,0}, f32[3,4]{1,0})"} : tuple, tensor<3x4xf32>> // CHECK: return %[[VAL_3]] : tuple, tensor<3x4xf32>> // CHECK: } // CHECK: } @@ -2327,7 +2327,7 @@ ENTRY %main.4 (Arg_0.1: f32[4,4], Arg_1.2: f32[3,4]) -> f32[3,4] { // CHECK: func.func private @region_0.5(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor) -> tuple, tensor> { // CHECK: %[[VAL_4:.*]] = stablehlo.add %[[VAL_0]], %[[VAL_2]] : tensor // CHECK: %[[VAL_5:.*]] = stablehlo.add %[[VAL_1]], %[[VAL_3]] : tensor -// CHECK: %[[VAL_6:.*]] = mhlo.tuple %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[], s32[])"} : tuple, tensor> +// CHECK: %[[VAL_6:.*]] = stablehlo.tuple %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[], s32[])"} : tuple, tensor> // CHECK: return %[[VAL_6]] : tuple, tensor> // CHECK: } // CHECK: func.func @main(%[[VAL_7:.*]]: tensor<4x2xf32>, %[[VAL_8:.*]]: tensor<4x2xi32>, %[[VAL_9:.*]]: tensor, %[[VAL_10:.*]]: tensor) -> tuple, tensor<2x2xi32>> { @@ -2337,7 +2337,7 @@ ENTRY %main.4 (Arg_0.1: f32[4,4], Arg_1.2: f32[3,4]) -> f32[3,4] { // CHECK: %[[VAL_17:.*]] = stablehlo.add %[[VAL_13]], %[[VAL_15]] : tensor // CHECK: stablehlo.return %[[VAL_16]], %[[VAL_17]] : tensor, tensor // CHECK: }) : (tensor<4x2xf32>, tensor<4x2xi32>, tensor, tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) -// CHECK: %[[VAL_18:.*]] = mhlo.tuple %[[VAL_19:.*]]#0, %[[VAL_19]]#1 {xla_shape = "(f32[2,2]{1,0}, s32[2,2]{1,0})"} : tuple, tensor<2x2xi32>> +// CHECK: %[[VAL_18:.*]] = stablehlo.tuple %[[VAL_19:.*]]#0, %[[VAL_19]]#1 {xla_shape = "(f32[2,2]{1,0}, s32[2,2]{1,0})"} : tuple, tensor<2x2xi32>> // CHECK: return %[[VAL_18]] : tuple, tensor<2x2xi32>> // CHECK: } // CHECK: } @@ -2399,7 +2399,7 @@ ENTRY %main.3 (Arg_0.1: f32[2]) -> f32[2] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>) -> tuple, tensor<4x2xi32>> { // CHECK: %[[VAL_1:.*]], %[[VAL_2:.*]] = mhlo.topk(%[[VAL_0]], k = 2) : tensor<4x4xf32> -> (tensor<4x2xf32>, tensor<4x2xi32>) -// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_1]], %[[VAL_2]] {xla_shape = "(f32[4,2]{1,0}, s32[4,2]{1,0})"} : tuple, tensor<4x2xi32>> +// CHECK: %[[VAL_3:.*]] = stablehlo.tuple %[[VAL_1]], %[[VAL_2]] {xla_shape = "(f32[4,2]{1,0}, s32[4,2]{1,0})"} : tuple, tensor<4x2xi32>> // CHECK: return %[[VAL_3]] : tuple, tensor<4x2xi32>> // CHECK: } // CHECK: } @@ -2417,8 +2417,8 @@ ENTRY %main.6 (Arg_0.1: f32[4,4]) -> (f32[4,2], s32[4,2]) { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tuple, tensor<2x3xf32>>, 
%[[VAL_1:.*]]: tensor<5x5xf32>) -> tuple<> { -// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {backend_config = "", output_operand_aliases = [#mhlo.output_operand_alias], xla_shape = "(f32[2,3]{1,0})"} : (tuple, tensor<2x3xf32>>, tensor<5x5xf32>) -> tuple> -// CHECK: %[[VAL_3:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: %[[VAL_2:.*]] = stablehlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {backend_config = "", output_operand_aliases = [#stablehlo.output_operand_alias], xla_shape = "(f32[2,3]{1,0})"} : (tuple, tensor<2x3xf32>>, tensor<5x5xf32>) -> tuple> +// CHECK: %[[VAL_3:.*]] = stablehlo.tuple {xla_shape = "()"} : tuple<> // CHECK: return %[[VAL_3]] : tuple<> // CHECK: } // CHECK: } @@ -2435,8 +2435,8 @@ ENTRY %main.5 (Arg_0.1: (f32[1,1], f32[2,3]), Arg_1.2: f32[5,5]) -> () { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tuple, tensor<2x3xf32>>, %[[VAL_1:.*]]: tensor<5x5xf32>) -> tuple<> { -// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {backend_config = "", output_operand_aliases = [#mhlo.output_operand_alias]} : (tuple, tensor<2x3xf32>>, tensor<5x5xf32>) -> tensor<2x3xf32> -// CHECK: %[[VAL_3:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: %[[VAL_2:.*]] = stablehlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {backend_config = "", output_operand_aliases = [#stablehlo.output_operand_alias]} : (tuple, tensor<2x3xf32>>, tensor<5x5xf32>) -> tensor<2x3xf32> +// CHECK: %[[VAL_3:.*]] = stablehlo.tuple {xla_shape = "()"} : tuple<> // CHECK: return %[[VAL_3]] : tuple<> // CHECK: } // CHECK: } @@ -2453,8 +2453,8 @@ ENTRY %main.5 (Arg_0.1: (f32[1,1], f32[2,3]), Arg_1.2: f32[5,5]) -> () { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>) -> tensor<3x4xf32> { -// CHECK: %[[VAL_1:.*]] = mhlo.create_token {xla_shape = "token[]"} : !mhlo.token -// CHECK: %[[VAL_2:.*]] = mhlo.add_dependency %[[VAL_0]], %[[VAL_1]] : (tensor<3x4xf32>, !mhlo.token) -> tensor<3x4xf32> +// CHECK: %[[VAL_1:.*]] = stablehlo.create_token {xla_shape = "token[]"} : !stablehlo.token +// CHECK: %[[VAL_2:.*]] = mhlo.add_dependency %[[VAL_0]], %[[VAL_1]] : (tensor<3x4xf32>, !stablehlo.token) -> tensor<3x4xf32> // CHECK: return %[[VAL_2]] : tensor<3x4xf32> // CHECK: } // CHECK: } @@ -2470,8 +2470,8 @@ ENTRY %main.4 (Arg_0.1: f32[3,4]) -> f32[3,4] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>) -> tensor<3x4xf32> attributes {execution_thread = "test_thread"} { -// CHECK: %[[VAL_1:.*]] = mhlo.create_token {xla_shape = "token[]"} : !mhlo.token -// CHECK: %[[VAL_2:.*]] = mhlo.add_dependency %[[VAL_0]], %[[VAL_1]] : (tensor<3x4xf32>, !mhlo.token) -> tensor<3x4xf32> +// CHECK: %[[VAL_1:.*]] = stablehlo.create_token {xla_shape = "token[]"} : !stablehlo.token +// CHECK: %[[VAL_2:.*]] = mhlo.add_dependency %[[VAL_0]], %[[VAL_1]] : (tensor<3x4xf32>, !stablehlo.token) -> tensor<3x4xf32> // CHECK: return %[[VAL_2]] : tensor<3x4xf32> // CHECK: } // CHECK: } @@ -2487,7 +2487,7 @@ ENTRY %main.4 (Arg_0.1: f32[3,4]) -> f32[3,4] { // CHECK-LABEL: module 
@main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x2xi32>) -> tensor<2x2xi32> { -// CHECK: %[[VAL_1:.*]] = "mhlo.all_to_all"(%[[VAL_0]]) <{channel_handle = #mhlo.channel_handle, concat_dimension = 1 : i64, replica_groups = dense<{{\[\[}}1, 2], [0, 3]]> : tensor<2x2xi64>, split_count = 2 : i64, split_dimension = 1 : i64}> : (tensor<2x2xi32>) -> tensor<2x2xi32> +// CHECK: %[[VAL_1:.*]] = "stablehlo.all_to_all"(%[[VAL_0]]) <{channel_handle = #stablehlo.channel_handle, concat_dimension = 1 : i64, replica_groups = dense<{{\[\[}}1, 2], [0, 3]]> : tensor<2x2xi64>, split_count = 2 : i64, split_dimension = 1 : i64}> : (tensor<2x2xi32>) -> tensor<2x2xi32> // CHECK: return %[[VAL_1]] : tensor<2x2xi32> // CHECK: } // CHECK: } @@ -2503,7 +2503,7 @@ ENTRY %main.3 (Arg_0.1: s32[2,2]) -> s32[2,2] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<128x4xf32>, %[[VAL_1:.*]]: tensor<128x4xf32>) -> tuple, tensor<128x4xf32>> { // CHECK: %[[VAL_2:.*]]:2 = "mhlo.all_to_all"(%[[VAL_0]], %[[VAL_1]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 1]]> : tensor<1x2xi64>}> : (tensor<128x4xf32>, tensor<128x4xf32>) -> (tensor<128x4xf32>, tensor<128x4xf32>) -// CHECK: %[[VAL_3:.*]] = mhlo.tuple %[[VAL_2]]#0, %[[VAL_2]]#1 {xla_shape = "(f32[128,4]{1,0}, f32[128,4]{1,0})"} : tuple, tensor<128x4xf32>> +// CHECK: %[[VAL_3:.*]] = stablehlo.tuple %[[VAL_2]]#0, %[[VAL_2]]#1 {xla_shape = "(f32[128,4]{1,0}, f32[128,4]{1,0})"} : tuple, tensor<128x4xf32>> // CHECK: return %[[VAL_3]] : tuple, tensor<128x4xf32>> // CHECK: } // CHECK: } @@ -2522,7 +2522,7 @@ ENTRY %main.7 (Arg_0.1: f32[128,4], Arg_1.2: f32[128,4]) -> (f32[128,4], f32[128 // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<2x3xf32>, %[[VAL_1:.*]]: tensor<5x5xf32>) -> tensor<1x2x3xf32> { -// CHECK: %[[VAL_2:.*]] = mhlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {api_version = 4 : i32, backend_config = {user_attr0 = 123 : i32, user_attr1 = dense<42> : tensor}, has_side_effect = true} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_2:.*]] = stablehlo.custom_call @foo(%[[VAL_0]], %[[VAL_1]]) {api_version = 4 : i32, backend_config = {user_attr0 = 123 : i32, user_attr1 = dense<42> : tensor}, has_side_effect = true} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> // CHECK: return %[[VAL_2]] : tensor<1x2x3xf32> // CHECK: } // CHECK: } @@ -2588,7 +2588,7 @@ ENTRY %main.3 (Arg_0.1: f32[?,784]) -> f32[?,784] { // CHECK: stablehlo.return %[[VAL_18]] : tensor // CHECK: } // CHECK: %[[VAL_19:.*]] = stablehlo.compare NE, %[[VAL_10]], %[[VAL_9]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_20:.*]] = mhlo.tuple %[[VAL_19]] {xla_shape = "(pred[])"} : tuple> +// CHECK: %[[VAL_20:.*]] = stablehlo.tuple %[[VAL_19]] {xla_shape = "(pred[])"} : tuple> // CHECK: return %[[VAL_20]] : tuple> // CHECK: } // CHECK: } @@ -2620,9 +2620,9 @@ ENTRY %main.15 (Arg_0.1: f32[2,2]) -> (pred[]) { // CHECK: return %[[VAL_0]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_2:.*]]: tensor) -> tensor { -// CHECK: %[[VAL_3:.*]] = 
"mhlo.all_reduce"(%[[VAL_2]]) <{replica_groups = dense<{{\[\[}}0], [1]]> : tensor<2x1xi64>}> ({ +// CHECK: %[[VAL_3:.*]] = "stablehlo.all_reduce"(%[[VAL_2]]) <{replica_groups = dense<{{\[\[}}0], [1]]> : tensor<2x1xi64>}> ({ // CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): -// CHECK: mhlo.return %[[VAL_4]] : tensor +// CHECK: stablehlo.return %[[VAL_4]] : tensor // CHECK: }) : (tensor) -> tensor // CHECK: return %[[VAL_3]] : tensor // CHECK: } @@ -2646,9 +2646,9 @@ ENTRY %main.6 (Arg_0.1: f32[]) -> f32[] { // CHECK: return %[[VAL_0]] : tensor // CHECK: } // CHECK: func.func @main(%[[VAL_2:.*]]: tensor<4x16xf32>) -> tensor<4x4xf32> { -// CHECK: %[[VAL_3:.*]] = "mhlo.reduce_scatter"(%[[VAL_2]]) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<{{\[\[}}0, 1, 2, 3]]> : tensor<1x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ +// CHECK: %[[VAL_3:.*]] = "stablehlo.reduce_scatter"(%[[VAL_2]]) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<{{\[\[}}0, 1, 2, 3]]> : tensor<1x4xi64>, scatter_dimension = 1 : i64, use_global_device_ids}> ({ // CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): -// CHECK: mhlo.return %[[VAL_4]] : tensor +// CHECK: stablehlo.return %[[VAL_4]] : tensor // CHECK: }) : (tensor<4x16xf32>) -> tensor<4x4xf32> // CHECK: return %[[VAL_3]] : tensor<4x4xf32> // CHECK: } @@ -2775,7 +2775,7 @@ ENTRY %main.13 (Arg_0.1: f32[10,24,24,64], Arg_1.2: f32[10,23,23,64], Arg_2.3: f // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<16x16xf32>, %[[VAL_1:.*]]: tensor<16x16xi32>) -> tuple<> { -// CHECK: %[[VAL_2:.*]] = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple {xla_shape = "()"} : tuple<> // CHECK: return %[[VAL_2]] : tuple<> // CHECK: } // CHECK: } diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index e2efc83b32793b..48e9f1220ff7ad 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -295,9 +295,9 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:base", - "@stablehlo//:broadcast_utils", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_assembly_format", + "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_type_inference", ], ) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_base.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_base.td index 1f90057883629f..e4d7cdf400bdbb 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_base.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_base.td @@ -52,6 +52,9 @@ defvar MHLO_BroadcastDimAttr = I64ElementsAttr; // Token type. 
defvar MHLO_Token = HLO_Token;
+def StableHLO_TokenType : Type<CPred<"llvm::isa<::mlir::stablehlo::TokenType>($_self)">, "stablehlo token">;
+defvar MHLO_AnyToken = AnyTypeOf<[MHLO_Token, StableHLO_TokenType]>;
+
 // Any integer tensor types
 defvar MHLO_IntTensor = HLO_IntTensor;
@@ -80,6 +83,7 @@ defvar MHLO_ComplexTensor = HLO_ComplexTensor;
 defvar MHLO_Tuple = HLO_Tuple;
 defvar MHLO_TensorOrToken = HLO_TensorOrPerAxisQuantizedTensorOrToken;
+defvar MHLO_TensorOrAnyToken = AnyTypeOf<[MHLO_TensorOrToken, StableHLO_TokenType]>;
 defvar MHLO_TensorOrTokenOrTuple = AnyTypeOf<[MHLO_Tensor, MHLO_Token, MHLO_Tuple]>;
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
index 48c3c255fd3173..998a7ac6a1969a 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
@@ -85,6 +85,7 @@ limitations under the License.
 #include "mlir/Transforms/InliningUtils.h"
 #include "stablehlo/dialect/AssemblyFormat.h"
 #include "stablehlo/dialect/Base.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "stablehlo/dialect/TypeInference.h"
 #include "utils/convert_op_folder.h"
 #include "utils/hlo_utils.h"
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td
index 9d5ba5557e6c8f..a32574fcf11293 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td
+++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td
@@ -1271,8 +1271,8 @@ def MHLO_AddDependencyOp : MHLO_Op<"add_dependency", [Pure,
     ```
   }];

-  let arguments = (ins MHLO_TensorOrToken:$operand, MHLO_Token:$token);
-  let results = (outs MHLO_TensorOrToken:$output);
+  let arguments = (ins MHLO_TensorOrAnyToken:$operand, MHLO_AnyToken:$token);
+  let results = (outs MHLO_TensorOrAnyToken:$output);

   let hasCustomHLOConverter = 1;
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc
index 9c14dfd106e532..9f5e164969decd 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc
@@ -17,6 +17,7 @@ limitations under the License.

 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 #include "mhlo/IR/hlo_ops.h"
 #include "mhlo/transforms/passes.h"
 #include "mhlo/transforms/rewriters.h"
@@ -25,6 +26,7 @@ limitations under the License.
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/TypeID.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -38,6 +40,29 @@ namespace mhlo {

 namespace {

+// AddDependencyOp is the only op that doesn't exist in StableHLO but uses
+// token types. This led to two options: (1) support either token type in
+// AddDependencyOp, or (2) design a token conversion (or unrealized cast)
+// between MHLO and StableHLO. Option (1) seems safer, and we can hopefully
+// obsolete mhlo::TokenType altogether and just use StableHLO tokens everywhere.
+//
+// Note: Only the second argument needs to be converted. All token creation and
+// propagation is already handled by existing conversions.
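For contrast, the option (2) mentioned above would have required materializing casts between the two token types on the type converter. A rough, hypothetical sketch of what that would look like (not part of this patch; `converter` stands in for the `HloToStablehloTypeConverter` used below):

```cpp
// Hypothetical option (2): bridge !mhlo.token <-> !stablehlo.token with an
// unrealized cast that a later cleanup pass would have to fold away.
converter.addTargetMaterialization(
    [](mlir::OpBuilder& builder, mlir::Type resultType,
       mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
      return builder
          .create<mlir::UnrealizedConversionCastOp>(
              loc, mlir::TypeRange{resultType}, inputs)
          .getResult(0);
    });
```

Option (1) avoids leaving such casts behind in partially converted IR, which is why the patterns below rewrite the op in place instead.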
+struct AddDependencyOpToStablehloTokenConverter
+    : public OpConversionPattern<mhlo::AddDependencyOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mhlo::AddDependencyOp op, mhlo::AddDependencyOpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const override {
+    // Only rewrite once the token operand has been converted to a StableHLO
+    // token.
+    if (!llvm::isa<stablehlo::TokenType>(adaptor.getToken().getType()))
+      return rewriter.notifyMatchFailure(op, "nothing to convert");
+    rewriter.replaceOpWithNewOp<mhlo::AddDependencyOp>(
+        op, adaptor.getOperand(), adaptor.getToken());
+    return success();
+  }
+};
+
 struct HloLegalizeToStablehloPass
     : public impl::HloLegalizeToStablehloPassBase<HloLegalizeToStablehloPass> {
   HloLegalizeToStablehloPass()
@@ -51,18 +76,24 @@ struct HloLegalizeToStablehloPass
     target.addIllegalDialect<mhlo::MhloDialect>();
     target.addLegalDialect<stablehlo::StablehloDialect>();

+    stablehlo::HloToStablehloTypeConverter converter;
+    RewritePatternSet patterns(&getContext());
+
     if (allow_xla_features_) {
       // These ops do not exist in StableHLO.
       target.addLegalOp<
-          mhlo::AddDependencyOp, mhlo::AsyncDoneOp, mhlo::AsyncStartOp,
-          mhlo::AsyncUpdateOp, mhlo::BitcastOp, mhlo::CopyOp, mhlo::DomainOp,
-          mhlo::ErfOp, mhlo::FusionOp, mhlo::MinimumBroadcastShapesOp,
-          mhlo::RaggedDotOp, mhlo::SparseDotOp, mhlo::StochasticConvertOp,
-          mhlo::TopKOp, mhlo::TraceOp, mhlo::XlaRngGetAndUpdateStateOp>();
+          mhlo::AsyncDoneOp, mhlo::AsyncStartOp, mhlo::AsyncUpdateOp,
+          mhlo::BitcastOp, mhlo::CopyOp, mhlo::DomainOp, mhlo::ErfOp,
+          mhlo::FusionOp, mhlo::MinimumBroadcastShapesOp, mhlo::RaggedDotOp,
+          mhlo::SparseDotOp, mhlo::StochasticConvertOp, mhlo::TopKOp,
+          mhlo::TraceOp, mhlo::XlaRngGetAndUpdateStateOp>();
+      target.addDynamicallyLegalOp<mhlo::AddDependencyOp>(
+          [](mhlo::AddDependencyOp op) {
+            return llvm::isa<stablehlo::TokenType>(op.getToken().getType());
+          });
+      patterns.add<AddDependencyOpToStablehloTokenConverter>(&getContext());
     }

-    stablehlo::HloToStablehloTypeConverter converter;
-    RewritePatternSet patterns(&getContext());
     stablehlo::populateHloToStablehloPatterns(
         &patterns, &converter, &getContext(), allow_experimental_features_);
     stablehlo::registerFuncOpsForTypeConversion(target, patterns, converter);
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
index 9e658a25043dc3..5079b94ad8482d 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
@@ -18,6 +18,7 @@ limitations under the License.

 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 #include "mhlo/IR/hlo_ops.h"
 #include "mhlo/transforms/passes.h"
 #include "mhlo/transforms/rewriters.h"
@@ -26,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/TypeID.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -39,6 +41,29 @@ namespace mhlo {

 namespace {

+// AddDependencyOp is the only op that doesn't exist in StableHLO but uses
+// token types. This led to two options: (1) support either token type in
+// AddDependencyOp, or (2) design a token conversion (or unrealized cast)
+// between MHLO and StableHLO. Option (1) seems safer, and we can hopefully
+// obsolete mhlo::TokenType altogether and just use StableHLO tokens everywhere.
+//
+// Note: Only the second argument needs to be converted. All token creation and
+// propagation is already handled by existing conversions.
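For a sense of how the two directions pair up in practice, a minimal usage sketch (assuming the pass factories declared in `mhlo/transforms/passes.h`; exact names and signatures may differ by version):

```cpp
#include "mhlo/transforms/passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Round-trips a module through StableHLO and back. mhlo.add_dependency
// survives both legs; only its token operand changes dialect.
mlir::LogicalResult roundTrip(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
  return pm.run(module);
}
```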
+struct AddDependencyOpToMhloTokenConverter
+    : public OpConversionPattern<mhlo::AddDependencyOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mhlo::AddDependencyOp op, mhlo::AddDependencyOpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const override {
+    // Only rewrite once the token operand has been converted to an MHLO
+    // token.
+    if (!llvm::isa<mhlo::TokenType>(adaptor.getToken().getType()))
+      return rewriter.notifyMatchFailure(op, "nothing to convert");
+    rewriter.replaceOpWithNewOp<mhlo::AddDependencyOp>(
+        op, adaptor.getOperand(), adaptor.getToken());
+    return success();
+  }
+};
+
 void legalDirectStablehloToHloConversionOps(ConversionTarget& target) {
   target.addLegalOp<
       // go/keep-sorted start
@@ -55,6 +80,10 @@ struct StablehloLegalizeToHloPass
     ConversionTarget target(getContext());
     target.addIllegalDialect<stablehlo::StablehloDialect>();
     target.addLegalDialect<mhlo::MhloDialect>();
+    target.addDynamicallyLegalOp<mhlo::AddDependencyOp>(
+        [](mhlo::AddDependencyOp op) {
+          return llvm::isa<mhlo::TokenType>(op.getToken().getType());
+        });

     // Allow injecting legal ops to permit gradual migration.
     if (!convert_xla_supported_stablehlo_) {
@@ -63,6 +92,7 @@ struct StablehloLegalizeToHloPass

     stablehlo::StablehloToHloTypeConverter converter;
     RewritePatternSet patterns(&getContext());
+    patterns.add<AddDependencyOpToMhloTokenConverter>(&getContext());
     stablehlo::populateStablehloToHloPatterns(&patterns, &converter,
                                               &getContext());
     stablehlo::registerFuncOpsForTypeConversion(target, patterns, converter);
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-partial.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-partial.mlir
index 80ce93cd1aa0db..fdb11ec5d39ece 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-partial.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-partial.mlir
@@ -32,3 +32,19 @@ func.func @copy() -> tensor<2x1xi32> {
   // CHECK: return %[[COPY]]
   func.return %1 : tensor<2x1xi32>
 }
+
+// -----
+
+// Tokens flow between StableHLO and MHLO ops, so they need special conversion
+// logic. AddDependencyOp is the only op that doesn't exist in StableHLO but
+// uses token types, so it can have either StableHLO or MHLO token types as
+// input.
+
+// CHECK-LABEL: func @add_dependency
+func.func @add_dependency(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  // CHECK: %[[TOK:.*]] = stablehlo.create_token {{.*}} : !stablehlo.token
+  // CHECK-NEXT: %[[COPY:.*]] = mhlo.add_dependency %arg0, %[[TOK]] : (tensor<3x4xf32>, !stablehlo.token) -> tensor<3x4xf32>
+  %0 = mhlo.create_token {xla_shape = "token[]"} : !mhlo.token
+  %1 = mhlo.add_dependency %arg0, %0 : (tensor<3x4xf32>, !mhlo.token) -> tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+}
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
index 159150203db3af..a72ac0486c4f8b 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
@@ -442,6 +442,22 @@ func.func @op_add(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {

 // -----

+// Tokens flow between StableHLO and MHLO ops, so they need special conversion
+// logic. AddDependencyOp is the only op that doesn't exist in StableHLO but
+// uses token types, so it can have either StableHLO or MHLO token types as
+// input.
+ +// CHECK-LABEL: "add_dependency" +func.func @add_dependency(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { + // CHECK: %[[TOK:.*]] = "mhlo.create_token"() {{.*}} : () -> !mhlo.token + // CHECK-NEXT: %[[COPY:.*]] = "mhlo.add_dependency"(%arg0, %[[TOK]]) : (tensor<3x4xf32>, !mhlo.token) -> tensor<3x4xf32> + %0 = stablehlo.create_token {xla_shape = "token[]"} : !stablehlo.token + %1 = mhlo.add_dependency %arg0, %0 : (tensor<3x4xf32>, !stablehlo.token) -> tensor<3x4xf32> + return %1 : tensor<3x4xf32> +} + +// ----- + // CHECK-LABEL: "op_after_all" func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { // CHECK: "mhlo.after_all"([[ARG0:%arg[0-9]+]]) : (!mhlo.token) -> !mhlo.token From 1ba1883f1d6f3acd60a3f1665ebce8278abb0075 Mon Sep 17 00:00:00 2001 From: Robert David Date: Fri, 11 Apr 2025 16:17:00 -0700 Subject: [PATCH 0599/1324] The second parameter to `TF_LITE_KERNEL_LOG` is expected to be a `printf` format string, known at compile time. Use type-appropriate (potentially platform-dependent) specifiers in said formatting string. Use `TF_LITE_ENSURE` instead of `CHECK`. Run IWYU to fix includes. PiperOrigin-RevId: 746624426 --- tensorflow/lite/delegates/flex/BUILD | 5 ++++ tensorflow/lite/delegates/flex/delegate.cc | 27 ++++++++++-------- tensorflow/lite/delegates/flex/kernel.cc | 33 +++++++++++++++++----- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 4533a2417d0d88..a7572d9f74bdfb 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -166,6 +166,7 @@ cc_library( ":delegate_data", ":tflite_subgraph_execute", ":util", + "//tensorflow/core:session_options", "//tensorflow/core/tfrt/fallback:op_kernel_runner", "//tensorflow/lite:kernel_api", "//tensorflow/lite:macros", @@ -173,10 +174,14 @@ cc_library( "//tensorflow/lite:string", "//tensorflow/lite:string_util", "//tensorflow/lite:util", + "//tensorflow/lite/core:subgraph", "//tensorflow/lite/core/api", "//tensorflow/lite/core/c:common", "//tensorflow/lite/delegates/utils:simple_delegate", "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@flatbuffers", ] + if_mobile([ diff --git a/tensorflow/lite/delegates/flex/delegate.cc b/tensorflow/lite/delegates/flex/delegate.cc index eee6bd04e6de37..f7fca34d49d739 100644 --- a/tensorflow/lite/delegates/flex/delegate.cc +++ b/tensorflow/lite/delegates/flex/delegate.cc @@ -14,19 +14,24 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/flex/delegate.h" +#include +#include +#include #include #include -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/lite/context_util.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/public/session_options.h" #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/core/macros.h" +#include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/delegates/flex/buffer_map.h" #include "tensorflow/lite/delegates/flex/kernel.h" -#include "tensorflow/lite/delegates/flex/util.h" +#include "tensorflow/lite/delegates/utils/simple_delegate.h" +#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/util.h" @@ -158,11 +163,9 @@ TfLiteStatus FlexDelegate::CopyFromBufferHandle( if (output->bytes != t_data.size()) { TF_LITE_KERNEL_LOG(context, - absl::StrCat("The given ", output->bytes, - " bytes are not enough to store " - "TensorFlow's aligned buffer of size ", - t_data.size(), " bytes.") - .c_str()); + "The given %zu bytes are not enough to store " + "TensorFlow's aligned buffer of size %zu bytes.", + output->bytes, t_data.size()); return kTfLiteError; } diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index 7a8bf163161914..9e6532d6b7b908 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -14,7 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/flex/kernel.h" +#include + #include +#include +#include #include #include #include @@ -22,23 +26,38 @@ limitations under the License. 
#include #include +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/delegates/flex/buffer_map.h" #include "tensorflow/lite/delegates/flex/delegate.h" #include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/util.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/string_type.h" +#include "tensorflow/lite/util.h" // Note: this is part of TF Lite's Flex delegation code which is to be // completed soon. @@ -343,7 +362,7 @@ class OpNode { tf_tensor->TotalBytes() != tensor->bytes) { TF_LITE_KERNEL_LOG(context, "FlexDelegate: Tensor %s(%d) buffer size mismatch " - "%zu(%lld) != %ld(%ld)", + "%zu(%" PRId64 ") != %zu(%" PRId64 ")", tensor->name, tensor_index, tf_tensor->TotalBytes(), tf_tensor->NumElements(), tensor->bytes, NumElements(tensor)); @@ -466,14 +485,14 @@ TfLiteStatus DelegateKernel::Init(TfLiteContext* context, op_data_->shared_info.tensor_release_map = flex_delegate_data->GetTensorReleaseMap(context); - CHECK(params->output_tensors); + TF_LITE_ENSURE(context, params->output_tensors != nullptr); std::set output_set; for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) { op_data_->subgraph_outputs.push_back(tensor_index); output_set.insert(tensor_index); } - CHECK(params->input_tensors); + TF_LITE_ENSURE(context, params->input_tensors != nullptr); for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) { op_data_->subgraph_inputs.push_back(tensor_index); } @@ -482,7 +501,7 @@ TfLiteStatus DelegateKernel::Init(TfLiteContext* context, op_data_->nodes.reserve(params->nodes_to_replace->size); - CHECK(params->nodes_to_replace); + TF_LITE_ENSURE(context, params->nodes_to_replace != nullptr); absl::Status status; // Now we explicitly disable reusing TFLite tensor buffers for certain TF ops, @@ -813,7 +832,7 @@ TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) { tf_tensor.TotalBytes() != tensor->bytes) { TF_LITE_KERNEL_LOG(context, "FlexDelegate: Tensor %s(%d) buffer size mismatch " - "%zu(%lld) != %ld(%ld)", + "%zu(%" PRId64 ") != %zu(%" PRId64 ")", tensor->name, tensor_index, tf_tensor.TotalBytes(), tf_tensor.NumElements(), tensor->bytes, NumElements(tensor)); From 9d76958e600d8b334bc00e312b76d6987c17c00e Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Fri, 11 Apr 
2025 16:55:52 -0700
Subject: [PATCH 0600/1324] Replace `populateRegionManualAxes` with
 `getManualAxesHierarchy` and add documentation for the manual axes hierarchy.

Given a ManualComputationOp `op`,
* `op.getManualAxes()` is the local manual axes
* `parent` is the manual axes of its parent ManualComputationOp, recursively
* `region` is the concatenation of `op.getManualAxes()` and `parent`.

PiperOrigin-RevId: 746634820
---
 .../stablehlo_round_trip/shard_map_export.cc  | 54 ++++++++++---------
 1 file changed, 30 insertions(+), 24 deletions(-)

diff --git a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc
index 760bc100e1c15e..5e1997d6f92c21 100644
--- a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc
+++ b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc
@@ -87,22 +87,29 @@ using sdy::TensorShardingPerValueAttr;
 using ManualComputationToParentManualAxes =
     llvm::SmallDenseMap<ManualComputationOp, SmallVector<StringAttr>>;

-// Populates `regionManualAxes` with the manual axes of `op`. If `op` is nested
-// in another manual computation, also returns the manual axes of the parent
-// computation.
-mlir::ArrayRef<StringAttr> populateRegionManualAxes(
-    SmallVector<StringAttr>& regionManualAxes, ManualComputationOp op,
+// Given a ManualComputationOp `op`, `op.getManualAxes()` is the local manual
+// axes. `parent` is the manual axes of its parent ManualComputationOp,
+// recursively. `region` is the concatenation of `op.getManualAxes()` and
+// `parent`.
+struct ManualAxesHierarchy {
+  mlir::ArrayRef<StringAttr> parent;
+  SmallVector<StringAttr> region;
+};
+
+ManualAxesHierarchy getManualAxesHierarchy(
+    ManualComputationOp op,
     const ManualComputationToParentManualAxes& parentManualCompAxes) {
-  regionManualAxes = SmallVector<StringAttr>(op.getManualAxes().begin(),
+  ManualAxesHierarchy hierarchy;
+  hierarchy.region = SmallVector<StringAttr>(op.getManualAxes().begin(),
                                              op.getManualAxes().end());
-  mlir::ArrayRef<StringAttr> parentManualAxesRef;
+
   if (auto parentManualAxes = parentManualCompAxes.find(op);
       parentManualAxes != parentManualCompAxes.end()) {
-    parentManualAxesRef = parentManualAxes->getSecond();
-    regionManualAxes.append(parentManualAxes->getSecond().begin(),
+    hierarchy.parent = parentManualAxes->getSecond();
+    hierarchy.region.append(parentManualAxes->getSecond().begin(),
                             parentManualAxes->getSecond().end());
   }
-  return parentManualAxesRef;
+  return hierarchy;
 }

 // Returns the first sharding of `op`. If there are no in/out shardings, returns
@@ -136,15 +143,15 @@ void convertShardingsToStablehloShardings(
   CHECK(mesh);

   // The axes that are manual inside `op`'s region.
-  SmallVector<StringAttr> regionManualAxes;
-  (void)populateRegionManualAxes(regionManualAxes, op, parentManualCompAxes);
+  ManualAxesHierarchy manualAxes =
+      getManualAxesHierarchy(op, parentManualCompAxes);
   MLIRContext* context = op.getContext();
   std::function<StringAttr(const HloSharding&)> getStringAttr =
       [&](const HloSharding& hloSharding) {
         return StringAttr::get(context, hloSharding.ToString());
       };

-  if (mesh.getAxes().size() == regionManualAxes.size()) {
+  if (mesh.getAxes().size() == manualAxes.region.size()) {
     // All operations in the body have fully manual sharding.
StringAttr fullyManualSharding = getStringAttr(HloSharding::Manual()); op.getBody().front().walk( @@ -178,7 +185,7 @@ void convertShardingsToStablehloShardings( opInBody->setAttr(kXlaShardingAttr, convertToHloShardingAttr( opInBody, shardingPerValue.getShardings(), - getMeshAttr, getStringAttr, regionManualAxes)); + getMeshAttr, getStringAttr, manualAxes.region)); opInBody->removeAttr(kShardingAttr); return mlir::WalkResult::advance(); }); @@ -217,9 +224,8 @@ void convertManualComputationOp( CHECK(mesh); // The axes that are manual inside `op`'s region. - SmallVector regionManualAxes; - mlir::ArrayRef parentManualAxes = - populateRegionManualAxes(regionManualAxes, op, parentManualCompAxes); + ManualAxesHierarchy manualAxes = + getManualAxesHierarchy(op, parentManualCompAxes); std::function getStringAttr = [&](const HloSharding& hloSharding) { return rewriter.getStringAttr(hloSharding.ToString()); @@ -237,7 +243,7 @@ void convertManualComputationOp( SmallVector shardToFullAttributes = createAttributes(kSPMDShardToFullShapeCallTargetName); - bool fullyManual = mesh.getAxes().size() == regionManualAxes.size(); + bool fullyManual = mesh.getAxes().size() == manualAxes.region.size(); mlir::Location loc = op.getLoc(); auto getMeshAttr = [&](TensorShardingAttr) { return mesh; }; // Add copy and custom_call @SPMDFullToShardShape for each operand. The @@ -249,13 +255,13 @@ void convertManualComputationOp( auto copy = rewriter.create(loc, globalOperand); copy->setAttr(kXlaShardingAttr, getStringAttr(convertToHloSharding(inSharding, getMeshAttr, - parentManualAxes))); + manualAxes.parent))); fullToShardAttributes.back() = rewriter.getNamedAttr( kXlaShardingAttr, fullyManual ? fullyManualSharding : getStringAttr(convertToHloSharding( - eraseManualAxes(inSharding, regionManualAxes), - getMeshAttr, regionManualAxes))); + eraseManualAxes(inSharding, manualAxes.region), + getMeshAttr, manualAxes.region))); auto fullToShard = rewriter.create( loc, localArgumentType, copy.getResult(), fullToShardAttributes); fullToShardResults.push_back(fullToShard.getResult(0)); @@ -272,11 +278,11 @@ void convertManualComputationOp( fullyManual ? fullyManualSharding : getStringAttr(convertToHloSharding( - eraseManualAxes(outSharding, regionManualAxes), - getMeshAttr, regionManualAxes))); + eraseManualAxes(outSharding, manualAxes.region), + getMeshAttr, manualAxes.region))); shardToFullAttributes.back() = rewriter.getNamedAttr( kXlaShardingAttr, getStringAttr(convertToHloSharding( - outSharding, getMeshAttr, parentManualAxes))); + outSharding, getMeshAttr, manualAxes.parent))); auto shardToFull = rewriter.create( loc, opResult.getType(), copy.getResult(), shardToFullAttributes); opResult.replaceAllUsesWith(shardToFull.getResult(0)); From ebd9ac67f6bd7f6c12f1c08965948717da2f7939 Mon Sep 17 00:00:00 2001 From: Wilsin Gosti Date: Fri, 11 Apr 2025 17:17:46 -0700 Subject: [PATCH 0601/1324] Add more descriptive error messages to help debug the issue when example parsing fails. PiperOrigin-RevId: 746640107 --- .../core/util/example_proto_fast_parsing.cc | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index b4fac84e7aa017..18d58405287bbf 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -1582,9 +1582,10 @@ absl::Status FastParseSingleExample(const Config& config, return errors::InvalidArgument("Key: ", feature_name, ". 
", suffix); }; - auto parse_error = [feature_name] { - return errors::InvalidArgument("Key: ", feature_name, - ". Can't parse serialized Example."); + auto parse_error = [feature_name](absl::string_view description) { + return errors::InvalidArgument( + "Key: ", feature_name, + ". Can't parse serialized Example: ", description); }; DataType example_dtype; @@ -1619,27 +1620,30 @@ absl::Status FastParseSingleExample(const Config& config, case DT_INT64: { auto out_p = out->flat().data(); LimitedArraySlice slice(out_p, num_elements); - if (!feature.ParseInt64List(&slice)) return parse_error(); + if (!feature.ParseInt64List(&slice)) + return parse_error("Parsing int64_list failed."); if (slice.EndDistance() != 0) { - return parse_error(); + return parse_error("Some int64_list slice was not parsed."); } break; } case DT_FLOAT: { auto out_p = out->flat().data(); LimitedArraySlice slice(out_p, num_elements); - if (!feature.ParseFloatList(&slice)) return parse_error(); + if (!feature.ParseFloatList(&slice)) + return parse_error("Parsing float_list failed."); if (slice.EndDistance() != 0) { - return parse_error(); + return parse_error("Some float_list slice was not parsed."); } break; } case DT_STRING: { auto out_p = out->flat().data(); LimitedArraySlice slice(out_p, num_elements); - if (!feature.ParseBytesList(&slice)) return parse_error(); + if (!feature.ParseBytesList(&slice)) + return parse_error("Parsing bytes_list failed."); if (slice.EndDistance() != 0) { - return parse_error(); + return parse_error("Some bytes_list slice was not parsed."); } break; } @@ -1697,22 +1701,25 @@ absl::Status FastParseSingleExample(const Config& config, case DT_INT64: { // TODO(mrry): Use the fact that the `int64_list` is packed to read // out the length and pre-allocate the output tensor. 
- if (!feature.ParseInt64List(&int64_list)) return parse_error(); + if (!feature.ParseInt64List(&int64_list)) + return parse_error("Parsing int64_list failed."); num_elements = int64_list.size(); break; } case DT_FLOAT: { - if (!feature.ParseFloatList(&float_list)) return parse_error(); + if (!feature.ParseFloatList(&float_list)) + return parse_error("Parsing float_list failed."); num_elements = float_list.size(); break; } case DT_STRING: { int actual_num_elements = 0; if (!feature.GetNumElementsInBytesList(&actual_num_elements)) { - return parse_error(); + return parse_error("Could not get num elements in bytes_list."); } bytes_list.reserve(actual_num_elements); - if (!feature.ParseBytesList(&bytes_list)) return parse_error(); + if (!feature.ParseBytesList(&bytes_list)) + return parse_error("Parsing bytes_list failed."); num_elements = bytes_list.size(); break; } @@ -1778,7 +1785,9 @@ absl::Status FastParseSingleExample(const Config& config, } case DT_FLOAT: { if (!out->CopyFrom(float_list.tensor(), out_shape)) { - return parse_error(); + return parse_error(absl::StrCat("Size of float_list is ", + float_list.tensor().dims(), + ", expected ", out_shape.dims())); } break; } From 6a8c88106544f9288c82b446569d6abda178083a Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Fri, 11 Apr 2025 17:32:43 -0700 Subject: [PATCH 0602/1324] BroadcastOp, BroadCastInDimOp, DynamicBroadcastInDimOp : Direct StableHLO -> HLO Translation PiperOrigin-RevId: 746643207 --- .../xla/xla/hlo/translate/mhlo_to_hlo/BUILD | 1 + .../mhlo_to_hlo/gen_hlo_op_writer.td | 10 ++--- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 41 ++++++++++++++----- third_party/xla/xla/hlo/translate/tests/BUILD | 1 + .../xla/hlo/translate/tests/stablehlo.mlir | 26 ++++++++++++ .../translate/tests/stablehlo_invalid.mlir | 10 +++++ .../stablehlo_legalize_to_hlo_pass.cc | 5 ++- 7 files changed, 77 insertions(+), 17 deletions(-) create mode 100644 third_party/xla/xla/hlo/translate/tests/stablehlo_invalid.mlir diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD index 5c5427a255ed28..5d1094eaf05714 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD @@ -34,6 +34,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@stablehlo//:base", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td index 55596d1f5d836a..1fb6b15607312c 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td @@ -31,8 +31,8 @@ defvar HloConversionAllowedOps = [ // StableHLO_BatchNormInferenceOp, // StableHLO_BatchNormTrainingOp, // StableHLO_BitcastConvertOp, - // StableHLO_BroadcastInDimOp, - // StableHLO_BroadcastOp, + StableHLO_BroadcastInDimOp, + StableHLO_BroadcastOp, // StableHLO_CaseOp, // StableHLO_CbrtOp, // StableHLO_CeilOp, @@ -55,7 +55,7 @@ defvar HloConversionAllowedOps = [ // StableHLO_DivOp, // StableHLO_DotGeneralOp, // StableHLO_DotOp, - // StableHLO_DynamicBroadcastInDimOp, + StableHLO_DynamicBroadcastInDimOp, // StableHLO_DynamicConvOp, // StableHLO_DynamicGatherOp, // StableHLO_DynamicIotaOp, @@ -146,7 +146,7 @@ defvar CustomHloConverterOps = [ // StableHLO_BatchNormGradOp, // StableHLO_BatchNormTrainingOp, // StableHLO_BitcastConvertOp, - // StableHLO_BroadcastInDimOp, + 
StableHLO_BroadcastInDimOp, // StableHLO_CaseOp, // StableHLO_CollectiveBroadcastOp, // StableHLO_CompareOp, @@ -158,7 +158,7 @@ defvar CustomHloConverterOps = [ // StableHLO_CustomCallOp, // StableHLO_DotGeneralOp, // StableHLO_DotOp, - // StableHLO_DynamicBroadcastInDimOp, + StableHLO_DynamicBroadcastInDimOp, // StableHLO_DynamicConvOp, // StableHLO_DynamicGatherOp, // StableHLO_DynamicIotaOp, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 046aaa22c86b55..623cf96a2bd904 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -256,15 +256,6 @@ static std::vector ConvertDenseIntAttr( return ConvertDenseIntAttr(*attr); } -// Converts the broadcast_dimensions attribute into a vector of dimension -// numbers (empty if the attribute is absent). -static std::vector Convert_broadcast_dimensions( - std::optional broadcast_dimensions) { - if (!broadcast_dimensions.has_value()) return {}; - - return ConvertDenseIntAttr(*broadcast_dimensions); -} - static std::vector Convert_cross_program_prefetches( mlir::ArrayAttr prefetches) { std::vector cross_program_prefetches; @@ -408,6 +399,9 @@ static mlir::FailureOr ExtractXlaShape(mlir::Operation* op) { } I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes); +I64_ARRAY_ATTR_TO_VECTOR(broadcast_sizes); +I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_dimensions); +I64_ARRAY_ATTR_TO_VECTOR(broadcast_dimensions); I64_ELEMENTS_ATTR_TO_VECTOR(permutation); I64_ELEMENTS_ATTR_TO_VECTOR(start_indices); I64_ARRAY_ATTR_TO_VECTOR(start_indices); @@ -1107,6 +1101,31 @@ LogicalResult ExportXlaOp(ConstantOp op, OpLoweringContext ctx) { return failure(); } +LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) { + auto type = mlir::dyn_cast(op.getType()); + if (!type) { + return failure(); + } + auto& value_map = *ctx.values; + xla::XlaOp operand; + if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op))) { + return failure(); + } + + // Use TypeToShape to handle bounded dynamism. + // HLO expects broadcast sizes to use the bound's value, not kDynamic. + xla::Shape shape = xla::TypeToShape(type); + value_map[op] = + BroadcastInDim(operand, shape.dimensions(), + Convert_broadcast_dimensions(op.getBroadcastDimensions())); + return success(); +} + +LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) { + // This op has no expression in the legacy export format. + return failure(); +} + } // namespace } // namespace stablehlo @@ -1131,7 +1150,9 @@ LogicalResult ExportXlaOp(CompositeOp, OpLoweringContext) { } LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) { - // This op has no expression in the legacy export format. + // HLO has no support for DynamicBroadcastInDimOp. + // These all must be refined away before lowering. 
+ // See https://openxla.org/stablehlo/dynamism return failure(); } diff --git a/third_party/xla/xla/hlo/translate/tests/BUILD b/third_party/xla/xla/hlo/translate/tests/BUILD index 17f631deda574c..228f90ee2cc1b5 100644 --- a/third_party/xla/xla/hlo/translate/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/tests/BUILD @@ -19,6 +19,7 @@ lit_test_suite( "simple.hlo", "simple.mlir", "stablehlo.mlir", + "stablehlo_invalid.mlir", "vhlo_input.mlir", # go/keep-sorted end ], diff --git a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir index 18947908b99a84..1f9a500466d22c 100644 --- a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir +++ b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir @@ -41,3 +41,29 @@ func.func @main(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) return %0 : tensor<1x4xi32> } // CHECK-DIRECT: stablehlo.dynamic_slice + +// ----- + +// CHECK-LABEL: HloModule main, entry_computation_layout={(s32[4]{0})->s32[1,2,3,4]{3,2,1,0}} + +// CHECK: ENTRY %[[$main_3:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = s32[4] parameter(0) +// CHECK-NEXT: ROOT %[[broadcast_2:[^ ]+]] = s32[1,2,3,4] broadcast(%[[Arg_0_1]]), dimensions={3} +func.func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { + %0 = stablehlo.broadcast %arg0, sizes = [1, 2, 3] : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + return %0 : tensor<1x2x3x4xi32> +} +// CHECK-DIRECT: stablehlo.broadcast + +// ----- + +// CHECK-LABEL: HloModule main, entry_computation_layout={(f32[1]{0})->f32[1,10]{1,0}} + +// CHECK: ENTRY %[[$main_3:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = f32[1] parameter(0) +// CHECK-NEXT: ROOT %[[broadcast_2:[^ ]+]] = f32[1,10] broadcast(%[[Arg_0_1]]), dimensions={0} +func.func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<1xf32>) -> tensor<1x10xf32> + return %0 : tensor<1x10xf32> +} +// CHECK-DIRECT: stablehlo.broadcast_in_dim diff --git a/third_party/xla/xla/hlo/translate/tests/stablehlo_invalid.mlir b/third_party/xla/xla/hlo/translate/tests/stablehlo_invalid.mlir new file mode 100644 index 00000000000000..9e188bfcb238e9 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/stablehlo_invalid.mlir @@ -0,0 +1,10 @@ +// RUN: not hlo-translate -mlir-to-hlo -split-input-file %s 2>&1 | FileCheck %s + +// StableHLO ops that has no HLO support. These all must be refined away before +// lowering. 
See https://openxla.org/stablehlo/dynamism + +func.func @main(%arg0: tensor, %arg1: tensor<1xindex>) -> tensor { + // CHECK: Shape Error: Invalid element type + %0 = stablehlo.dynamic_broadcast_in_dim %arg0, %arg1, dims = [0] {known_expanding_dimensions = array, known_nonexpanding_dimensions = array} : (tensor, tensor<1xindex>) -> tensor + return %0 : tensor +} diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc index 5079b94ad8482d..a0f82893435d59 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc @@ -67,8 +67,9 @@ struct AddDependencyOpToMhoTokenConverter void legalDirectStablehloToHloConversionOps(ConversionTarget& target) { target.addLegalOp< // go/keep-sorted start - stablehlo::AddOp, stablehlo::ConstantOp, stablehlo::DynamicSliceOp, - stablehlo::SliceOp + stablehlo::AddOp, stablehlo::BroadcastInDimOp, stablehlo::BroadcastOp, + stablehlo::ConstantOp, stablehlo::DynamicBroadcastInDimOp, + stablehlo::DynamicSliceOp, stablehlo::SliceOp // go/keep-sorted end >(); } From ebe6a597273cce474759a7cc37a26fde48b022ed Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 11 Apr 2025 18:43:58 -0700 Subject: [PATCH 0603/1324] [xla:gpu] CommandBuffer: add explicit RecordCreate and RecordUpdate APIs to CommandBufferCmdSequence + Remove RecordedCommands struct as today CommandBufferCmd always creates or updates a single command in the underlying command buffer. Remove unnecessary complexity and simply keep track of a single pointer. PiperOrigin-RevId: 746658289 --- .../gpu/runtime/command_buffer_cmd.cc | 250 +++++++++++------- .../backends/gpu/runtime/command_buffer_cmd.h | 145 ++++++---- .../gpu/runtime/command_buffer_cmd_test.cc | 16 +- 3 files changed, 253 insertions(+), 158 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index baab3b024cbd89..53d6bded7ba291 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -149,13 +149,6 @@ static se::CommandBuffer::Builder CreateExecutionScopeBuilder( }; } -absl::StatusOr -CommandBufferCmd::RecordedCommands::Create( - absl::StatusOr command) { - if (!command.ok()) return command.status(); - return RecordedCommands{{*command}}; -} - //===----------------------------------------------------------------------===// // CommandBufferCmd::RecordAction helpers. //===----------------------------------------------------------------------===// @@ -168,18 +161,16 @@ using UpdateCommand = absl::FunctionRef; // Handles a record action by calling one of the user-provided functions. 
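// [Illustrative sketch, not part of this patch] A typical Cmd::Record
// implementation forwards its RecordAction to Handle() below with one
// callback per case, roughly:
//
//   return Handle(
//       std::move(record_action),
//       [&](absl::Span<const se::CommandBuffer::Command* const> deps) {
//         return command_buffer->CreateMemcpyD2D(&dst, src, num_bytes_, deps);
//       },
//       [&](const se::CommandBuffer::Command* command) {
//         return command_buffer->UpdateMemcpyD2D(command, &dst, src,
//                                                num_bytes_);
//       });
//
// CreateMemcpyD2D/UpdateMemcpyD2D are hypothetical stand-ins for the
// corresponding se::CommandBuffer entry points; the real names may differ.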
-static absl::StatusOr Handle( +static absl::StatusOr Handle( CommandBufferCmd::RecordAction action, CreateCommand create_command, UpdateCommand update_command) { if (auto* create = std::get_if(&action)) { - return CommandBufferCmd::RecordedCommands::Create( - create_command(create->dependencies)); + return create_command(create->dependencies); } if (auto* update = std::get_if(&action)) { - auto* command = update->recorded_commands.commands[0]; - TF_RETURN_IF_ERROR(update_command(command)); - return std::move(update->recorded_commands); + TF_RETURN_IF_ERROR(update_command(update->command)); + return update->command; } return Internal("Invalid record action"); @@ -261,23 +252,12 @@ absl::Status CommandBufferCmdSequence::Initialize( return absl::OkStatus(); } -static absl::string_view RecordModeString( - CommandBufferCmdSequence::RecordMode mode) { - switch (mode) { - case CommandBufferCmdSequence::RecordMode::kExclusive: - return "exclusive"; - case CommandBufferCmdSequence::RecordMode::kConditional: - return "conditional"; - } -} - absl::Status CommandBufferCmdSequence::Record( const Thunk::ExecuteParams& execute_params, const CommandBufferCmd::RecordParams& record_params, se::CommandBuffer* command_buffer, RecordMode mode) { - VLOG(3) << "Record " << commands_.size() << " commands into command buffer" - << "; mode=" << RecordModeString(mode); - uint64_t start_micros = tsl::Env::Default()->NowMicros(); + VLOG(3) << "Record " << commands_.size() + << " commands into command buffer; mode=" << absl::StrCat(mode); if (mode == RecordMode::kExclusive) { if (command_buffer->state() == se::CommandBuffer::State::kFinalized) { @@ -285,62 +265,136 @@ absl::Status CommandBufferCmdSequence::Record( } } + if (command_buffer->state() == se::CommandBuffer::State::kUpdate) { + TF_RETURN_IF_ERROR( + RecordUpdate(execute_params, record_params, command_buffer)); + } else { + TF_RETURN_IF_ERROR( + RecordCreate(execute_params, record_params, command_buffer, {}) + .status()); + } + + if (mode == RecordMode::kExclusive) { + TF_RETURN_IF_ERROR(command_buffer->Finalize()); + } + + return absl::OkStatus(); +} + +absl::StatusOr> +CommandBufferCmdSequence::RecordCreate( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer, + absl::Span dependencies) { + // Command buffer must be in create state. + TF_RETURN_IF_ERROR(CheckCommandBufferState( + command_buffer, se::CommandBuffer::State::kCreate)); + + VLOG(3) << "Create " << commands_.size() << " commands"; + uint64_t start_micros = tsl::Env::Default()->NowMicros(); + + // Short-circuit if there are no commands to record. + if (commands_.empty()) { + return std::vector{}; + } + // Keep a state associated with commands in the sequence in the state manager. CommandBufferCmd::StateManager& state = record_params.state; - // If the command buffer is in update state, it means that we already recorded - // all commands into the underlying command buffer and we need to update them. - bool is_update = command_buffer->state() == se::CommandBuffer::State::kUpdate; - for (std::unique_ptr& command : commands_) { + std::optional annotation = + GetKernelAnnotation(command->profile_annotation()); + + // Skip recording collective commands if mock collectives are enabled. if (execute_params.mock_collectives && dynamic_cast(command.get())) { continue; } + // Create new commands by recording them into the command buffer. 
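    // [Note, illustrative, not part of this patch] Callers normally reach
    // this create path through Record(), which dispatches on command buffer
    // state as shown above:
    //
    //   if (command_buffer->state() == se::CommandBuffer::State::kUpdate) {
    //     TF_RETURN_IF_ERROR(
    //         RecordUpdate(execute_params, record_params, command_buffer));
    //   } else {
    //     TF_RETURN_IF_ERROR(
    //         RecordCreate(execute_params, record_params, command_buffer, {})
    //             .status());
    //   }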
+ DCHECK(!state.GetOrNull(command.get(), command_buffer)) + << "Record state must be null for " << command->ToString(); + auto* record_state = + state.GetOrCreate(command.get(), command_buffer); + + // TODO(b/406370928): Fetch command dependencies computed from the command + // sequence, today we rely on implicit synchronization of all commands. + auto record_action = CommandBufferCmd::RecordCreate{}; + + TF_ASSIGN_OR_RETURN( + record_state->command, + command->Record(execute_params, record_params, std::move(record_action), + command_buffer)); + } + + uint64_t end_micros = tsl::Env::Default()->NowMicros(); + VLOG(3) << "Created " << commands_.size() << " commands in " + << (end_micros - start_micros) << " μs"; + + // TODO(b/406370928): Depending on synchronization mode we must collect + // commands created for all sink nodes in the execution graph. + auto* last_recorded = + state.GetOrNull(commands_.back().get(), command_buffer); + DCHECK(last_recorded) << "Last recorded command state must be not null"; + return std::vector{last_recorded->command}; +} + +absl::Status CommandBufferCmdSequence::RecordUpdate( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + // Command buffer must be already prepared for recording updates. + TF_RETURN_IF_ERROR(CheckCommandBufferState( + command_buffer, se::CommandBuffer::State::kUpdate)); + + VLOG(3) << "Update " << commands_.size() << " commands"; + uint64_t start_micros = tsl::Env::Default()->NowMicros(); + + // Short-circuit if there are no commands to update. + if (commands_.empty()) { + return absl::OkStatus(); + } + + // Keep a state associated with commands in the sequence in the state manager. + CommandBufferCmd::StateManager& state = record_params.state; + + for (std::unique_ptr& command : commands_) { std::optional annotation = GetKernelAnnotation(command->profile_annotation()); - if (is_update) { - // Update existing commands in the command buffer. - auto* record_state = - state.GetOrNull(command.get(), command_buffer); - DCHECK(record_state) << "Record state must be not null for " - << command->ToString(); - - auto record_action = CommandBufferCmd::RecordUpdate{ - std::move(record_state->recorded_commands)}; - TF_ASSIGN_OR_RETURN( - record_state->recorded_commands, - command->Record(execute_params, record_params, - std::move(record_action), command_buffer)); - - } else { - // Create new commands by recording them into the command buffer. - DCHECK(!state.GetOrNull(command.get(), command_buffer)) - << "Record state must be null for " << command->ToString(); - auto* record_state = - state.GetOrCreate(command.get(), command_buffer); - - // TODO(b/406370928): Fetch command dependencies computed from the command - // sequence, today we rely on implicit synchronization of all commands. - auto record_action = CommandBufferCmd::RecordCreate{}; - TF_ASSIGN_OR_RETURN( - record_state->recorded_commands, - command->Record(execute_params, record_params, - std::move(record_action), command_buffer)); + // Skip updating collective commands if mock collectives are enabled. + if (execute_params.mock_collectives && + dynamic_cast(command.get())) { + continue; } - } - if (mode == RecordMode::kExclusive) { - TF_RETURN_IF_ERROR(command_buffer->Finalize()); + // Update existing commands in the command buffer. 
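    // [Note, illustrative, not part of this patch] The single Command*
    // captured at create time is read back from RecordState, so one update
    // round-trip per command amounts to:
    //
    //   auto* s = state.GetOrNull<RecordState>(command.get(), command_buffer);
    //   auto action = CommandBufferCmd::RecordUpdate{s->command};
    //   TF_ASSIGN_OR_RETURN(
    //       s->command, command->Record(execute_params, record_params,
    //                                   std::move(action), command_buffer));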
+ auto* record_state = + state.GetOrNull(command.get(), command_buffer); + DCHECK(record_state) << "Record state must be not null for " + << command->ToString(); + + auto record_action = CommandBufferCmd::RecordUpdate{record_state->command}; + + TF_ASSIGN_OR_RETURN( + record_state->command, + command->Record(execute_params, record_params, std::move(record_action), + command_buffer)); } uint64_t end_micros = tsl::Env::Default()->NowMicros(); - VLOG(3) << "Recorded " << commands_.size() - << " commands into command buffer in " << (end_micros - start_micros) - << " μs; mode=" << RecordModeString(mode); + VLOG(3) << "Updated " << commands_.size() << " commands in " + << (end_micros - start_micros) << " μs"; + + return absl::OkStatus(); +} +absl::Status CommandBufferCmdSequence::CheckCommandBufferState( + se::CommandBuffer* command_buffer, + se::CommandBuffer::State expected_state) { + if (command_buffer->state() != expected_state) { + return Internal("Command buffer must be in %v state, got %v", + expected_state, command_buffer->state()); + } return absl::OkStatus(); } @@ -436,7 +490,7 @@ TracedCommandBufferCmd::TracedCommandBufferCmd( CommandBufferCmdType cmd_type, ExecutionStreamId execution_stream_id) : CommandBufferCmd(cmd_type, execution_stream_id) {} -absl::StatusOr +absl::StatusOr TracedCommandBufferCmd::RecordTracedCommand( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, @@ -481,7 +535,7 @@ CommandBufferCmd::BufferUseVector ComputationIdCmd::buffers() { return {{dest_, MemoryAccess::kWrite}}; } -absl::StatusOr ComputationIdCmd::Record( +absl::StatusOr ComputationIdCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -548,7 +602,7 @@ absl::Status LaunchCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::StatusOr LaunchCmd::Record( +absl::StatusOr LaunchCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -628,11 +682,10 @@ absl::Status CustomKernelLaunchCmd::Initialize( return absl::OkStatus(); } -absl::StatusOr -CustomKernelLaunchCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - RecordAction record_action, - se::CommandBuffer* command_buffer) { +absl::StatusOr CustomKernelLaunchCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { VLOG(5) << "CustomKernelLaunchCmd: custom_kernel=" << custom_kernel_.name(); se::Kernel* kernel = [&] { @@ -692,7 +745,7 @@ MemcpyDeviceToDeviceCmd::MemcpyDeviceToDeviceCmd( src_(src), num_bytes_(num_bytes) {} -absl::StatusOr +absl::StatusOr MemcpyDeviceToDeviceCmd::Record(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, @@ -708,7 +761,7 @@ MemcpyDeviceToDeviceCmd::Record(const Thunk::ExecuteParams& execute_params, if (num_bytes_ == 0) { VLOG(5) << "Skip recording MemcpyDeviceToDeviceCmd command of 0 bytes"; - return RecordedCommands{}; + return nullptr; } return Handle( @@ -735,7 +788,7 @@ MemzeroCmd::MemzeroCmd(ExecutionStreamId execution_stream_id, : CommandBufferCmd(CommandBufferCmdType::kMemzeroCmd, execution_stream_id), dst_(dst) {} -absl::StatusOr MemzeroCmd::Record( +absl::StatusOr MemzeroCmd::Record( const Thunk::ExecuteParams& 
execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -747,7 +800,7 @@ absl::StatusOr MemzeroCmd::Record( if (dst_.size() == 0) { VLOG(5) << "Skip recording MemzeroCmd command of 0 bytes"; - return RecordedCommands{}; + return nullptr; } return Handle( @@ -777,7 +830,7 @@ Memset32Cmd::Memset32Cmd(ExecutionStreamId execution_stream_id, dst_(dst), bit_pattern_(bit_pattern) {} -absl::StatusOr Memset32Cmd::Record( +absl::StatusOr Memset32Cmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -789,7 +842,7 @@ absl::StatusOr Memset32Cmd::Record( if (dst_.size() == 0) { VLOG(5) << "Skip recording Memset32Cmd command of 0 bytes"; - return RecordedCommands{}; + return nullptr; } return Handle( @@ -830,7 +883,7 @@ absl::Status CaseCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::StatusOr CaseCmd::Record( +absl::StatusOr CaseCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -900,7 +953,7 @@ absl::Status WhileCmd::Initialize(const Thunk::InitializeParams& params, return body_commands_.Initialize(params, state); } -absl::StatusOr WhileCmd::Record( +absl::StatusOr WhileCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -969,7 +1022,7 @@ absl::Status GemmCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::StatusOr GemmCmd::Record( +absl::StatusOr GemmCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -1086,7 +1139,7 @@ absl::Status CublasLtCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::StatusOr CublasLtCmd::Record( +absl::StatusOr CublasLtCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -1200,7 +1253,7 @@ absl::Status CuDnnCmd::Initialize(const Thunk::InitializeParams& params, return absl::OkStatus(); } -absl::StatusOr CuDnnCmd::Record( +absl::StatusOr CuDnnCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -1217,9 +1270,9 @@ absl::StatusOr CuDnnCmd::Record( const bool supports_explicit, graph_->get()->SupportsExplicitCommandBufferConstruction()); if (supports_explicit) { - return RecordedCommands::Create(command_buffer->CreateDnnGraphCommand( + return command_buffer->CreateDnnGraphCommand( *graph_->get(), *execute_params.stream, - absl::Span(operands), {})); + absl::Span(operands), {}); } return RecordTracedCommand( execute_params, record_params, std::move(record_action), command_buffer, @@ -1244,7 +1297,7 @@ CommandBufferCmd::BufferUseVector CuDnnCmd::buffers() { // CustomCallCmd //===----------------------------------------------------------------------===// -absl::StatusOr CustomCallCmd::Record( +absl::StatusOr CustomCallCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -1284,7 +1337,7 @@ absl::Status GetBuffers( } } // namespace -absl::StatusOr +absl::StatusOr 
CustomCallCmd::RecordLegacyCustomCall( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, @@ -1324,7 +1377,7 @@ CustomCallCmd::RecordLegacyCustomCall( }); } -absl::StatusOr +absl::StatusOr CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, @@ -1441,7 +1494,7 @@ absl::Status CollectiveCmd::Prepare( return resource_requests.AddClique(clique_key); } -absl::StatusOr +absl::StatusOr CollectiveCmd::RecordTracedCommand( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, @@ -1476,7 +1529,7 @@ AllReduceCmd::AllReduceCmd(ExecutionStreamId execution_stream_id, reduction_kind_(reduction_kind), buffers_(buffers.begin(), buffers.end()) {} -absl::StatusOr AllReduceCmd::Record( +absl::StatusOr AllReduceCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -1539,7 +1592,7 @@ ReduceScatterCmd::ReduceScatterCmd( reduction_kind_(reduction_kind), buffers_(buffers.begin(), buffers.end()) {} -absl::StatusOr ReduceScatterCmd::Record( +absl::StatusOr ReduceScatterCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -1602,7 +1655,7 @@ AllToAllCmd::AllToAllCmd(ExecutionStreamId execution_stream_id, has_split_dimension_(has_split_dimension), buffers_(buffers.begin(), buffers.end()) {} -absl::StatusOr AllToAllCmd::Record( +absl::StatusOr AllToAllCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -1662,7 +1715,7 @@ AllGatherCmd::AllGatherCmd(ExecutionStreamId execution_stream_id, async_from_stream_id, std::move(config)), buffers_(buffers.begin(), buffers.end()) {} -absl::StatusOr AllGatherCmd::Record( +absl::StatusOr AllGatherCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { @@ -1724,7 +1777,7 @@ CollectiveBroadcastCmd::CollectiveBroadcastCmd( std::move(config)), buffers_(buffers.begin(), buffers.end()) {} -absl::StatusOr +absl::StatusOr CollectiveBroadcastCmd::Record(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, @@ -1871,11 +1924,10 @@ absl::Status DynamicSliceFusionCmd::Prepare( return absl::OkStatus(); } -absl::StatusOr -DynamicSliceFusionCmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - RecordAction record_action, - se::CommandBuffer* command_buffer) { +absl::StatusOr DynamicSliceFusionCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { se::Stream& stream = *execute_params.stream; const BufferAllocations& orig_allocations = diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index f93992caf6099a..c21bea10e80133 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -200,15 +200,6 @@ class CommandBufferCmd { StateManager& state; }; - // A list of commands recorded into the command buffer (or updated). 
- struct RecordedCommands { - // Creates a recorded commands from a single se::CommandBuffer command. - static absl::StatusOr Create( - absl::StatusOr command); - - absl::InlinedVector commands; - }; - // Create new commands in the command buffer using the given dependencies. struct RecordCreate { absl::Span dependencies; @@ -216,7 +207,7 @@ class CommandBufferCmd { // Update previously recorded commands in the command buffer. struct RecordUpdate { - RecordedCommands recorded_commands; + const se::CommandBuffer::Command* command; }; // When recording a command into the command buffer we can either update @@ -249,7 +240,7 @@ class CommandBufferCmd { // Records commands into the command buffer. Returned commands will be passed // back on the next call to `Record` into the same command buffer, so that it // can do efficient command buffer updates. - virtual absl::StatusOr Record( + virtual absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) = 0; @@ -298,6 +289,35 @@ class CommandBufferCmdSequence { CommandBufferCmdSequence(CommandBufferCmdSequence&&) = default; CommandBufferCmdSequence& operator=(CommandBufferCmdSequence&&) = default; + using RecordParams = CommandBufferCmd::RecordParams; + + enum class RecordMode { + // In exclusive mode no one else is recording commands into the command + // buffer argument, and cmd sequence is responsible for updating command + // buffer state: finalizing after all commands recorded, and + // switching to update state before recording updates. + kExclusive, + + // In conditional mode multiple cmd sequences can be recorded into the + // command buffer argument, and with command buffer state managed externally + // cmd sequence should not finalize or update it. This mode is used when + // command buffer cmd sequence is recorded into conditional command buffers + // owned by the parent command buffer. + kConditional + }; + + template + friend void AbslStringify(Sink& sink, RecordMode mode) { + switch (mode) { + case RecordMode::kExclusive: + sink.Append("exclusive"); + break; + case RecordMode::kConditional: + sink.Append("conditional"); + break; + } + } + // Synchronization mode defines how much concurrency is allowed between // commands in the sequence. enum class SynchronizationMode { @@ -310,6 +330,18 @@ class CommandBufferCmdSequence { kAutomatic }; + template + friend void AbslStringify(Sink& sink, SynchronizationMode mode) { + switch (mode) { + case SynchronizationMode::kSerialize: + sink.Append("serialize"); + break; + case SynchronizationMode::kAutomatic: + sink.Append("automatic"); + break; + } + } + // A command buffer cmd sequence builder for lazy command sequence // construction. class Builder { @@ -327,21 +359,6 @@ class CommandBufferCmdSequence { std::vector> commands_; }; - enum class RecordMode { - // In exclusive mode no one else is recording commands into the command - // buffer argument, and cmd sequence is responsible for updating command - // buffer state: finalizing after all commands recorded, and - // switching to update state before recording updates. - kExclusive, - - // In conditional mode multiple cmd sequences can be recorded into the - // command buffer argument, and with command buffer state managed externally - // cmd sequence should not finalize or update it. This mode is used when - // command buffer cmd sequence is recorded into conditional command buffers - // owned by the parent command buffer. 
- kConditional - }; - // Prepares all commands added to a sequence. absl::Status Prepare(const Thunk::PrepareParams& params, Thunk::ResourceRequestsInterface& resource_requests); @@ -350,12 +367,35 @@ class CommandBufferCmdSequence { absl::Status Initialize(const Thunk::InitializeParams& params, CommandBufferCmd::StateManager& state); - // Records all commands added to a sequence into the given command buffer. + // Records commands into the command buffer. This method automatically + // switches between `RecordCreate` or `RecordUpdate` depending on the command + // buffer state. This method assumes that no other command buffer sequences is + // recorded into the same command buffer, and doesn't set up initial + // dependencies for recorded commands. + // + // TODO(b/406370928): This API must be removed, and instead users should + // explicitly call `RecordCreate` or `RecordUpdate` depending on what they + // want to do. absl::Status Record(const Thunk::ExecuteParams& execute_params, - const CommandBufferCmd::RecordParams& record_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer, RecordMode mode = RecordMode::kExclusive); + // Records command creation into the command buffer. Command buffer must be + // in create state. The next command sequence recorded into the same command + // buffer must use returned commands as dependencies, to guarantee that it is + // correctly ordered after this command sequence. + absl::StatusOr> RecordCreate( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer, + absl::Span dependencies); + + // Records command updates into the command buffer. Command buffer must be + // in update state. + absl::Status RecordUpdate(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer); + // Returns buffers referenced by commands in this sequence. const absl::flat_hash_set& buffers() const; @@ -374,13 +414,16 @@ class CommandBufferCmdSequence { // A state associated with commands in the sequence. We rely on this state to // efficiently update command recorded into the command buffer. struct RecordState : public CommandBufferCmd::State { - CommandBufferCmd::RecordedCommands recorded_commands; + const se::CommandBuffer::Command* command; }; CommandBufferCmdSequence( SynchronizationMode synchronization_mode, std::vector> commands); + absl::Status CheckCommandBufferState(se::CommandBuffer* command_buffer, + se::CommandBuffer::State expected_state); + SynchronizationMode synchronization_mode_; std::vector> commands_; @@ -436,7 +479,7 @@ class TracedCommandBufferCmd : public CommandBufferCmd { // Creates a command buffer by calling a user-provided `trace` function and // adds it as a nested command to `command_buffer`. Traced command buffers // cached and reused in an instance of `TracedCommandBuffer` kept in `state`. 
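  // [Note, illustrative, not part of this patch] A derived command typically
  // passes a trace callback that replays its stream work so it can be
  // captured into a nested command buffer; with a hypothetical device slice
  // `dst`, a memset-style command would record itself as:
  //
  //   return RecordTracedCommand(
  //       execute_params, record_params, std::move(record_action),
  //       command_buffer,
  //       [&](se::Stream* stream) { return stream->MemZero(&dst, dst.size()); });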
- absl::StatusOr RecordTracedCommand( + absl::StatusOr RecordTracedCommand( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer, @@ -454,7 +497,7 @@ class ComputationIdCmd : public CommandBufferCmd { ComputationIdCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice dest, Kind kind); - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -480,7 +523,7 @@ class LaunchCmd : public CommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -515,7 +558,7 @@ class CustomKernelLaunchCmd : public CommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -544,7 +587,7 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { BufferAllocation::Slice dst, BufferAllocation::Slice src, int64_t num_bytes); - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -566,7 +609,7 @@ class MemzeroCmd : public CommandBufferCmd { MemzeroCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice dst); - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -586,7 +629,7 @@ class Memset32Cmd : public CommandBufferCmd { Memset32Cmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice dst, uint32_t bit_pattern); - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -611,7 +654,7 @@ class CaseCmd : public CommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -639,7 +682,7 @@ class WhileCmd : public CommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -669,7 +712,7 @@ class GemmCmd : public TracedCommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -712,7 +755,7 @@ class CublasLtCmd : public TracedCommandBufferCmd { 
absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -768,7 +811,7 @@ class CuDnnCmd : public TracedCommandBufferCmd { absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -822,7 +865,7 @@ class CustomCallCmd : public CommandBufferCmd { operands_(std::move(operands)), results_(std::move(results)) {} - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -831,12 +874,12 @@ class CustomCallCmd : public CommandBufferCmd { bool IsNestedCommandBuffer() const final { return true; } private: - absl::StatusOr RecordLegacyCustomCall( + absl::StatusOr RecordLegacyCustomCall( const Thunk::ExecuteParams& execute_param, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer); - absl::StatusOr RecordXlaFfiCall( + absl::StatusOr RecordXlaFfiCall( const Thunk::ExecuteParams& execute_param, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer); @@ -878,7 +921,7 @@ class CollectiveCmd : public CommandBufferCmd { bool IsNestedCommandBuffer() const final { return true; } - absl::StatusOr RecordTracedCommand( + absl::StatusOr RecordTracedCommand( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer, @@ -917,7 +960,7 @@ class AllReduceCmd : public CollectiveCmd { ReductionKind reduction_kind, absl::Span buffers); - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -944,7 +987,7 @@ class ReduceScatterCmd : public CollectiveCmd { CollectiveConfig config, ReductionKind reduction_kind, absl::Span buffers); - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -971,7 +1014,7 @@ class AllToAllCmd : public CollectiveCmd { bool has_split_dimension, absl::Span buffers); - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -997,7 +1040,7 @@ class AllGatherCmd : public CollectiveCmd { ExecutionStreamId async_from_stream_id, CollectiveConfig config, absl::Span buffers); - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; @@ -1023,7 +1066,7 @@ class CollectiveBroadcastCmd : public CollectiveCmd { CollectiveConfig config, absl::Span buffers); - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) 
override; @@ -1058,7 +1101,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd { const Thunk::PrepareParams& params, Thunk::ResourceRequestsInterface& resource_requests) final; - absl::StatusOr Record( + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) override; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index 43bf61a1fb614f..740a8008d74a04 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -76,10 +76,10 @@ struct TestOnlyCommandBufferCmd : public CommandBufferCmd { execution_stream_id), buffer_usage(buffer_usage) {} - absl::StatusOr Record(const Thunk::ExecuteParams&, - const RecordParams&, RecordAction, - se::CommandBuffer*) override { - return RecordedCommands{}; + absl::StatusOr Record( + const Thunk::ExecuteParams&, const RecordParams&, RecordAction, + se::CommandBuffer*) override { + return nullptr; } BufferUseVector buffers() override { return buffer_usage; } @@ -93,10 +93,10 @@ class FakeCmd : public CommandBufferCmd { : CommandBufferCmd(CommandBufferCmdType::kTracedCommandBufferCmd, execution_stream_id) {} - absl::StatusOr Record(const Thunk::ExecuteParams&, - const RecordParams&, RecordAction, - se::CommandBuffer*) override { - return RecordedCommands{}; + absl::StatusOr Record( + const Thunk::ExecuteParams&, const RecordParams&, RecordAction, + se::CommandBuffer*) override { + return nullptr; } BufferUseVector buffers() override { return BufferUseVector{}; } }; From 58c215fd9f0287fc77fd13049de4c1de03d1c6de Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 11 Apr 2025 18:50:23 -0700 Subject: [PATCH 0604/1324] [XLA] treat dot, convolution and ragged_dot the same with respect to hlo semantics analysis PiperOrigin-RevId: 746659572 --- .../analysis/hlo_value_semantics_analysis.cc | 38 +++++---------- .../analysis/hlo_value_semantics_analysis.h | 14 +++++- .../hlo_value_semantics_analysis_test.cc | 46 +++++++++++++++++++ 3 files changed, 70 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc index 564e29394ba6d5..ef6630e9e94229 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc @@ -50,8 +50,6 @@ limitations under the License. 
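// [Note, illustrative, not part of this patch] The change below funnels dot,
// convolution and ragged-dot through a single opcode predicate; a minimal
// equivalent of the IsDotOrConvolution helper it introduces is:
//
//   bool IsDotOrConvolution(const HloInstruction* instruction) {
//     return HloPredicateIsOp<HloOpcode::kDot, HloOpcode::kConvolution,
//                             HloOpcode::kRaggedDot>(instruction);
//   }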
#include "xla/shape_util.h" #include "xla/side_effect_util.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla { @@ -387,11 +385,6 @@ absl::Status EinsumDepthAnalysis::HandleDot(HloInstruction* dot) { return HandleDepthIncrementInstruction(dot); } -absl::Status EinsumDepthAnalysis::HandleConvolution( - HloInstruction* convolution) { - return HandleDepthIncrementInstruction(convolution); -} - absl::Status EinsumDepthAnalysis::HandleCall(HloInstruction* call) { const ShapeTree& depth_tree = GetDepthTreeOrDie(call); return HandleCalledComputation(*call->called_computations()[0], depth_tree, @@ -774,12 +767,6 @@ absl::Status EinsumHeightAnalysis::HandleDot(HloInstruction* dot) { return HandleHeightIncrementInstruction(dot); } -absl::Status EinsumHeightAnalysis::HandleConvolution( - HloInstruction* convolution) { - RETURN_IF_HEIGHT_EXISTS(convolution); - return HandleHeightIncrementInstruction(convolution); -} - absl::Status EinsumHeightAnalysis::HandleCall(HloInstruction* call) { RETURN_IF_HEIGHT_EXISTS(call); TF_RETURN_IF_ERROR(HandleCalledComputation(*(call->called_computations()[0]), @@ -1179,6 +1166,13 @@ const HloValueSemantics* HloValueSemanticsPropagation::AddSemantics( return analysis_->NewHloValueSemantics(semantics.label(), semantics.origin()); } +namespace { +bool IsDotOrConvolution(const HloInstruction* instruction) { + return HloPredicateIsOp(instruction); +} +} // namespace + std::vector HloValueSemanticsPropagation::FindEinsumsWhereOriginDependsOnOther( const HloValueSemantics& semantics, const HloPosition& origin_dependence, @@ -1204,10 +1198,8 @@ HloValueSemanticsPropagation::FindEinsumsWhereOriginDependsOnOther( if (origin.instruction->opcode() == HloOpcode::kDynamicSlice) { operands = operands.subspan(0, 1); } - bool is_einsum = origin.instruction->opcode() == HloOpcode::kDot || - origin.instruction->opcode() == HloOpcode::kConvolution; bool found_einsum = false; - if (is_einsum) { + if (IsDotOrConvolution(origin.instruction)) { for (int64_t operand_index = 0; operand_index < operands.size(); ++operand_index) { const HloInstruction* origin_operand = operands[operand_index]; @@ -1251,9 +1243,7 @@ HloValueSemanticsPropagation::ComputeSemanticsFromStaticAndOther( return CopySemanticsWithNewOrigin(other_semantics, instruction); } - bool is_dot_or_convolution = instruction->opcode() == HloOpcode::kDot || - instruction->opcode() == HloOpcode::kConvolution; - if (is_dot_or_convolution && + if (IsDotOrConvolution(instruction) && other_semantics.label() == HloValueSemanticLabel::kActivationGradient) { return MaybeCreateGradientSemantics( instruction, HloValueSemanticLabel::kActivationGradient); @@ -1301,10 +1291,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( CHECK(weight_semantics.label() == HloValueSemanticLabel::kWeight); CHECK(other_semantics.label() != HloValueSemanticLabel::kStatic && other_semantics.label() != HloValueSemanticLabel::kRandom); - bool is_dot_or_convolution = instruction->opcode() == HloOpcode::kDot || - instruction->opcode() == HloOpcode::kConvolution; if (other_semantics.label() == HloValueSemanticLabel::kWeight) { - if (!is_dot_or_convolution) { + if (!IsDotOrConvolution(instruction)) { if (weight_semantics.origin() == other_semantics.origin()) { return CopySemantics(other_semantics); } @@ -1313,7 +1301,7 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( return HloValueSemantics(HloValueSemanticLabel::kActivation, {instruction, {}}); } - if 
(!is_dot_or_convolution) { + if (!IsDotOrConvolution(instruction)) { return CopySemantics(other_semantics); } if (other_semantics.label() == HloValueSemanticLabel::kActivation) { @@ -1362,9 +1350,7 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther( CHECK(other_semantics.label() != HloValueSemanticLabel::kStatic && other_semantics.label() != HloValueSemanticLabel::kRandom && other_semantics.label() != HloValueSemanticLabel::kWeight); - bool is_dot_or_convolution = instruction->opcode() == HloOpcode::kDot || - instruction->opcode() == HloOpcode::kConvolution; - if (!is_dot_or_convolution) { + if (!IsDotOrConvolution(instruction)) { if (activation_semantics.origin() == other_semantics.origin()) { return CopySemantics(other_semantics); } diff --git a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h index 0223e9fce5b8c2..a688ccb72aa4d9 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h @@ -99,7 +99,12 @@ class EinsumDepthAnalysis : public DfsHloVisitorWithDefault { absl::Status HandleGetTupleElement( HloInstruction* get_tuple_element) override; absl::Status HandleDot(HloInstruction* dot) override; - absl::Status HandleConvolution(HloInstruction* convolution) override; + absl::Status HandleConvolution(HloInstruction* convolution) override { + return HandleDot(convolution); + } + absl::Status HandleRaggedDot(HloInstruction* ragged_dot) override { + return HandleDot(ragged_dot); + } absl::Status HandleCall(HloInstruction* call) override; absl::Status HandleFusion(HloInstruction* fusion) override; absl::Status HandleWhile(HloInstruction* xla_while) override; @@ -155,7 +160,12 @@ class EinsumHeightAnalysis : public DfsHloVisitorWithDefault { absl::Status HandleGetTupleElement( HloInstruction* get_tuple_element) override; absl::Status HandleDot(HloInstruction* dot) override; - absl::Status HandleConvolution(HloInstruction* convolution) override; + absl::Status HandleConvolution(HloInstruction* convolution) override { + return HandleDot(convolution); + } + absl::Status HandleRaggedDot(HloInstruction* ragged_dot) override { + return HandleDot(ragged_dot); + } absl::Status HandleCall(HloInstruction* call) override; absl::Status HandleFusion(HloInstruction* fusion) override; absl::Status HandleWhile(HloInstruction* xla_while) override; diff --git a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc index 255ec9d9b0a631..fafc7958855242 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc @@ -250,6 +250,52 @@ ENTRY entry { EXPECT_TRUE(IsWeight(*hlo_value_semantics_analysis, module.get(), "dot.2")); } +TEST_F(HloValueSemanticsAnalysisTest, OneRaggedDot) { + const std::string module_str = R"( +HloModule OneMatmul + +region_0.39 { + Arg_0.40 = f32[] parameter(0) + Arg_1.41 = f32[] parameter(1) + ROOT add.42 = f32[] add(Arg_0.40, Arg_1.41) +} + +ENTRY entry { + Arg_1.2 = f32[8,32,128]{2,1,0} parameter(0), sharding={devices=[1,2,1]0,1} + Arg_7.8 = f32[4,32]{1,0} parameter(1), sharding={devices=[2,1]0,1} + G = s32[8]{0} parameter(2), sharding={replicated} + copy = f32[4,32]{1,0} copy(Arg_7.8), sharding={devices=[2,1]0,1} + dot.0 = f32[4,128]{1,0} ragged-dot(copy, Arg_1.2, G), lhs_contracting_dims={1}, 
rhs_contracting_dims={1}, lhs_ragged_dims={0}, rhs_group_dims={0}, sharding={devices=[2,1]0,1} + constant.5 = f32[] constant(0), sharding={replicated} + broadcast.2 = f32[4,128]{1,0} broadcast(constant.5), dimensions={}, sharding={devices=[2,1]0,1} + maximum.33 = f32[4,128]{1,0} maximum(dot.0, broadcast.2), sharding={devices=[2,1]0,1} + compare.34 = pred[4,128]{1,0} compare(dot.0, maximum.33), direction=EQ, sharding={devices=[2,1]0,1} + constant.4 = f32[] constant(1), sharding={replicated} + broadcast.1 = f32[4,128]{1,0} broadcast(constant.4), dimensions={}, sharding={devices=[2,1]0,1} + select.35 = f32[4,128]{1,0} select(compare.34, broadcast.1, broadcast.2), sharding={devices=[2,1]0,1} + dot.2 = f32[32,128]{0,1} dot(copy, select.35), lhs_contracting_dims={0}, rhs_contracting_dims={0}, sharding={devices=[2,1]0,1} + constant.11 = f32[] constant(-0.01), sharding={replicated} + broadcast.12 = f32[32,128]{1,0} broadcast(constant.11), dimensions={}, sharding={devices=[2,1]0,1} + multiply.52 = f32[32,128]{0,1} multiply(dot.2, broadcast.12), sharding={devices=[2,1]0,1} + reduce.43 = f32[] reduce(maximum.33, constant.5), dimensions={0,1}, to_apply=region_0.39, sharding={replicated} + ROOT tuple.109 = (f32[32,128]{1,0}, f32[]) tuple(multiply.52, reduce.43), sharding={{devices=[2,1]0,1}, {replicated}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1, + /*num_partitions=*/2)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_value_semantics_analysis, + HloValueSemanticsAnalysis::Run(*module)); + EXPECT_TRUE(IsWeight(*hlo_value_semantics_analysis, module.get(), "copy")); + EXPECT_TRUE(IsWeight(*hlo_value_semantics_analysis, module.get(), "Arg_1.2")); + EXPECT_TRUE( + IsActivation(*hlo_value_semantics_analysis, module.get(), "dot.0")); + EXPECT_TRUE( + IsStatic(*hlo_value_semantics_analysis, module.get(), "select.35")); + EXPECT_TRUE(IsWeight(*hlo_value_semantics_analysis, module.get(), "dot.2")); +} TEST_F(HloValueSemanticsAnalysisTest, HandleConditional) { const std::string module_str = R"( HloModule Module From 9db2bb97d32ee443e997a8621a8e5af427abbb10 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Fri, 11 Apr 2025 18:58:07 -0700 Subject: [PATCH 0605/1324] ConvolutionOp: Direct StableHLO to HLO Translate PiperOrigin-RevId: 746661137 --- .../mhlo_to_hlo/attribute_exporter.cc | 30 +++++++++++ .../mhlo_to_hlo/attribute_exporter.h | 5 ++ .../mhlo_to_hlo/gen_hlo_op_writer.td | 4 +- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 29 +++++++++++ .../xla/hlo/translate/tests/stablehlo.mlir | 51 +++++++++++++++++++ .../stablehlo_legalize_to_hlo_pass.cc | 5 +- 6 files changed, 120 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc index 652401bf7c3703..dcecb5599cb8f4 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc @@ -27,6 +27,7 @@ limitations under the License. 
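// [Note, illustrative, not part of this patch] Overloading
// ConvertConvDimensionNumbers on the stablehlo attribute type lets the
// exporter pick the conversion by argument type, so the ConvolutionOp
// lowering later in this change can simply write:
//
//   xla::ConvolutionDimensionNumbers dnums =
//       xla::ConvertConvDimensionNumbers(op.getDimensionNumbers());
//
// regardless of whether `op` is an mhlo or a stablehlo convolution.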
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #include "stablehlo/dialect/Base.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -66,6 +67,35 @@ ConvolutionDimensionNumbers ConvertConvDimensionNumbers( return output; } +ConvolutionDimensionNumbers ConvertConvDimensionNumbers( + mlir::stablehlo::ConvDimensionNumbersAttr input) { + ConvolutionDimensionNumbers output; + + output.set_input_batch_dimension(input.getInputBatchDimension()); + output.set_input_feature_dimension(input.getInputFeatureDimension()); + for (auto v : input.getInputSpatialDimensions()) { + output.add_input_spatial_dimensions(v); + } + + output.set_kernel_input_feature_dimension( + input.getKernelInputFeatureDimension()); + output.set_kernel_output_feature_dimension( + input.getKernelOutputFeatureDimension()); + + for (auto v : input.getKernelSpatialDimensions()) { + output.add_kernel_spatial_dimensions(v); + } + + output.set_output_batch_dimension(input.getOutputBatchDimension()); + output.set_output_feature_dimension(input.getOutputFeatureDimension()); + + for (auto v : input.getOutputSpatialDimensions()) { + output.add_output_spatial_dimensions(v); + } + + return output; +} + absl::StatusOr ConvertDotAlgorithm( mlir::mhlo::DotAlgorithmAttr attr) { auto algorithm = mlir::hlo::detail::getKnownDotAlgorithm( diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h index bc8344ce11b01d..f5b910784b246b 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" @@ -38,6 +39,10 @@ namespace xla { ConvolutionDimensionNumbers ConvertConvDimensionNumbers( mlir::mhlo::ConvDimensionNumbersAttr input); +// Converts the conv dimensions attribute to XLA HLO. +ConvolutionDimensionNumbers ConvertConvDimensionNumbers( + mlir::stablehlo::ConvDimensionNumbersAttr input); + // Converts the dot algorithm attribute to XLA HLO. 
absl::StatusOr ConvertDotAlgorithm( mlir::mhlo::DotAlgorithmAttr attr); diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td index 1fb6b15607312c..f5e19c39e8149c 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td @@ -47,7 +47,7 @@ defvar HloConversionAllowedOps = [ // StableHLO_ConcatenateOp, StableHLO_ConstantOp, // StableHLO_ConvertOp, - // StableHLO_ConvolutionOp, + StableHLO_ConvolutionOp, // StableHLO_CosineOp, // StableHLO_CreateTokenOp, // StableHLO_CrossReplicaSumOp, @@ -153,7 +153,7 @@ defvar CustomHloConverterOps = [ // StableHLO_CompositeOp, StableHLO_ConstantOp, // StableHLO_ConvertOp, - // StableHLO_ConvolutionOp, + StableHLO_ConvolutionOp, // StableHLO_CosineOp, // StableHLO_CustomCallOp, // StableHLO_DotGeneralOp, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 623cf96a2bd904..3450820faf0c38 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -414,8 +414,11 @@ I64_ARRAY_ATTR_TO_VECTOR(slice_sizes); I64_ELEMENTS_ATTR_TO_VECTOR(fft_length); I64_ELEMENTS_ATTR_TO_VECTOR(dimensions); I64_ELEMENTS_ATTR_TO_VECTOR(window_strides); +I64_ARRAY_ATTR_TO_VECTOR(window_strides); I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation); +I64_ARRAY_ATTR_TO_VECTOR(lhs_dilation); I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation); +I64_ARRAY_ATTR_TO_VECTOR(rhs_dilation); #undef I64_ARRAY_ATTR_TO_VECTOR #undef I64_ELEMENTS_ATTR_TO_VECTOR @@ -1126,6 +1129,32 @@ LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) { return failure(); } +LogicalResult ExportXlaOp(mlir::stablehlo::ConvolutionOp op, + OpLoweringContext ctx) { + auto& value_map = *ctx.values; + xla::XlaOp lhs, rhs; + if (failed(GetXlaOp(op.getLhs(), value_map, &lhs, op))) { + return mlir::failure(); + } + if (failed(GetXlaOp(op.getRhs(), value_map, &rhs, op))) { + return mlir::failure(); + } + xla::PrimitiveType preferred_element_type = + xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType())); + xla::XlaOp xla_result = xla::ConvGeneralDilated( + lhs, rhs, Convert_window_strides(op.getWindowStrides()), + Convert_padding(op.getPadding()), + Convert_lhs_dilation(op.getLhsDilation()), + Convert_rhs_dilation(op.getRhsDilation()), + xla::ConvertConvDimensionNumbers(op.getDimensionNumbers()), + Convertuint64_t(op.getFeatureGroupCount()), + Convertuint64_t(op.getBatchGroupCount()), + Unwrap(Convert_precision_config(op.getPrecisionConfig())), + preferred_element_type, op.getWindowReversal()); + value_map[op] = xla_result; + return mlir::success(); +} + } // namespace } // namespace stablehlo diff --git a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir index 1f9a500466d22c..ec933fcfb5c12c 100644 --- a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir +++ b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir @@ -67,3 +67,54 @@ func.func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> { return %0 : tensor<1x10xf32> } // CHECK-DIRECT: stablehlo.broadcast_in_dim + +// ----- + +// CHECK-LABEL: HloModule main, entry_computation_layout={(f32[100,26,26,32]{3,2,1,0}, f32[3,3,1,32]{3,2,1,0})->f32[100,28,28,1]{3,2,1,0}} + +// CHECK: ENTRY %[[$main_4:[^ ]+]] +// CHECK-NEXT: 
%[[Arg_0_1:[^ ]+]] = f32[100,26,26,32] parameter(0) +// CHECK-NEXT: %[[Arg_1_2:[^ ]+]] = f32[3,3,1,32] parameter(1) +// CHECK-NEXT: ROOT %[[convolution_3:[^ ]+]] = f32[100,28,28,1] convolution(%[[Arg_0_1]], %[[Arg_1_2]]), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01oi->b01f, metadata= + +module { + func.func @main(%arg0: tensor<100x26x26x32xf32>, %arg1: tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> + return %0 : tensor<100x28x28x1xf32> + } +} +// CHECK-DIRECT: stablehlo.convolution + +// ----- + +// CHECK-LABEL: HloModule main, entry_computation_layout={(s8[100,26,26,32]{3,2,1,0}, s8[3,3,1,32]{3,2,1,0})->s32[100,28,28,1]{3,2,1,0}} + +// CHECK: ENTRY %[[$main_4:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = s8[100,26,26,32] parameter(0) +// CHECK-NEXT: %[[Arg_1_2:[^ ]+]] = s8[3,3,1,32] parameter(1) +// CHECK-NEXT: ROOT %[[convolution_3:[^ ]+]] = s32[100,28,28,1] convolution(%[[Arg_0_1]], %[[Arg_1_2]]), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01oi->b01f, metadata= + +module { + func.func @main(%arg0: tensor<100x26x26x32xi8>, %arg1: tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> + return %0 : tensor<100x28x28x1xi32> + } +} +// CHECK-DIRECT: stablehlo.convolution + +// ----- + +// CHECK-LABEL: HloModule main, entry_computation_layout={(s8[100,26,26,32]{3,2,1,0}, s8[3,3,1,32]{3,2,1,0})->s32[100,28,28,1]{3,2,1,0}} + +// CHECK: ENTRY %[[$main_4:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = s8[100,26,26,32] parameter(0) +// CHECK-NEXT: %[[Arg_1_2:[^ ]+]] = s8[3,3,1,32] parameter(1) +// CHECK-NEXT: ROOT %[[convolution_3:[^ ]+]] = s32[100,28,28,1] convolution(%[[Arg_0_1]], %[[Arg_1_2]]), window={size=3x3 pad=2_2x2_2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, metadata= + +module { + func.func @main(%arg0: tensor<100x26x26x32xi8>, %arg1: tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [true, true]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> + return %0 : tensor<100x28x28x1xi32> + } +} +// CHECK-DIRECT: stablehlo.convolution diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc index a0f82893435d59..14f1b5290f7f3f 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc @@ -68,8 +68,9 @@ void legalDirectStablehloToHloConversionOps(ConversionTarget& target) { target.addLegalOp< // go/keep-sorted start 
stablehlo::AddOp, stablehlo::BroadcastInDimOp, stablehlo::BroadcastOp, - stablehlo::ConstantOp, stablehlo::DynamicBroadcastInDimOp, - stablehlo::DynamicSliceOp, stablehlo::SliceOp + stablehlo::ConstantOp, stablehlo::ConvolutionOp, + stablehlo::DynamicBroadcastInDimOp, stablehlo::DynamicSliceOp, + stablehlo::SliceOp // go/keep-sorted end >(); } From 3ccbc7022bba7d947ba86b6f73f53cc38b9f6454 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Fri, 11 Apr 2025 20:03:39 -0700 Subject: [PATCH 0606/1324] Unary Elementwise Ops : Direct StableHLO -> HLO Translation PiperOrigin-RevId: 746675298 --- .../mhlo_to_hlo/gen_hlo_op_writer.td | 60 ++++---- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 129 +++++++++++++++--- third_party/xla/xla/hlo/translate/tests/BUILD | 1 + .../tests/stablehlo_unary_elementwise.mlir | 117 ++++++++++++++++ .../stablehlo_legalize_to_hlo_pass.cc | 14 +- 5 files changed, 272 insertions(+), 49 deletions(-) create mode 100644 third_party/xla/xla/hlo/translate/tests/stablehlo_unary_elementwise.mlir diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td index f5e19c39e8149c..98cc0524fdf4d5 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td @@ -19,7 +19,7 @@ include "stablehlo/dialect/StablehloOps.td" // List of StableHLO ops that are allowed to be directly converted to HLO // without intermediate MHLO step. defvar HloConversionAllowedOps = [ - // StableHLO_AbsOp, + StableHLO_AbsOp, StableHLO_AddOp, // StableHLO_AfterAllOp, // StableHLO_AllGatherOp, @@ -34,11 +34,11 @@ defvar HloConversionAllowedOps = [ StableHLO_BroadcastInDimOp, StableHLO_BroadcastOp, // StableHLO_CaseOp, - // StableHLO_CbrtOp, - // StableHLO_CeilOp, + StableHLO_CbrtOp, + StableHLO_CeilOp, // StableHLO_CholeskyOp, // StableHLO_ClampOp, - // StableHLO_ClzOp, + StableHLO_ClzOp, // StableHLO_CollectiveBroadcastOp, // StableHLO_CollectivePermuteOp, // StableHLO_CompareOp, @@ -46,9 +46,9 @@ defvar HloConversionAllowedOps = [ // StableHLO_CompositeOp, // StableHLO_ConcatenateOp, StableHLO_ConstantOp, - // StableHLO_ConvertOp, + StableHLO_ConvertOp, StableHLO_ConvolutionOp, - // StableHLO_CosineOp, + StableHLO_CosineOp, // StableHLO_CreateTokenOp, // StableHLO_CrossReplicaSumOp, // StableHLO_CustomCallOp, @@ -64,36 +64,36 @@ defvar HloConversionAllowedOps = [ StableHLO_DynamicSliceOp, // StableHLO_DynamicUpdateSliceOp, // StableHLO_EinsumOp, - // StableHLO_Expm1Op, - // StableHLO_ExpOp, + StableHLO_Expm1Op, + StableHLO_ExpOp, // StableHLO_FftOp, - // StableHLO_FloorOp, + StableHLO_FloorOp, // StableHLO_GatherOp, // StableHLO_GetDimensionSizeOp, // StableHLO_GetTupleElementOp, // StableHLO_IfOp, - // StableHLO_ImagOp, + StableHLO_ImagOp, // StableHLO_InfeedOp, // StableHLO_IotaOp, - // StableHLO_IsFiniteOp, - // StableHLO_Log1pOp, - // StableHLO_LogisticOp, - // StableHLO_LogOp, + StableHLO_IsFiniteOp, + StableHLO_Log1pOp, + StableHLO_LogisticOp, + StableHLO_LogOp, // StableHLO_MapOp, // StableHLO_MaxOp, // StableHLO_MinOp, // StableHLO_MulOp, - // StableHLO_NegOp, - // StableHLO_NotOp, + StableHLO_NegOp, + StableHLO_NotOp, // StableHLO_OptimizationBarrierOp, // StableHLO_OrOp, // StableHLO_OutfeedOp, // StableHLO_PadOp, // StableHLO_PartitionIdOp, - // StableHLO_PopulationCountOp, + StableHLO_PopulationCountOp, // StableHLO_PowOp, // StableHLO_RealDynamicSliceOp, - // StableHLO_RealOp, + StableHLO_RealOp, // StableHLO_RecvOp, 
// StableHLO_ReduceOp, // StableHLO_ReducePrecisionOp, @@ -106,9 +106,9 @@ defvar HloConversionAllowedOps = [ // StableHLO_ReverseOp, // StableHLO_RngBitGeneratorOp, // StableHLO_RngOp, - // StableHLO_RoundNearestEvenOp, - // StableHLO_RoundOp, - // StableHLO_RsqrtOp, + StableHLO_RoundNearestEvenOp, + StableHLO_RoundOp, + StableHLO_RsqrtOp, // StableHLO_ScatterOp, // StableHLO_SelectAndScatterOp, // StableHLO_SelectOp, @@ -117,14 +117,14 @@ defvar HloConversionAllowedOps = [ // StableHLO_ShiftLeftOp, // StableHLO_ShiftRightArithmeticOp, // StableHLO_ShiftRightLogicalOp, - // StableHLO_SignOp, - // StableHLO_SineOp, + StableHLO_SignOp, + StableHLO_SineOp, StableHLO_SliceOp, // StableHLO_SortOp, - // StableHLO_SqrtOp, + StableHLO_SqrtOp, // StableHLO_SubtractOp, - // StableHLO_TanhOp, - // StableHLO_TanOp, + StableHLO_TanhOp, + StableHLO_TanOp, // StableHLO_TorchIndexSelectOp, // StableHLO_TransposeOp, // StableHLO_TriangularSolveOp, @@ -152,9 +152,9 @@ defvar CustomHloConverterOps = [ // StableHLO_CompareOp, // StableHLO_CompositeOp, StableHLO_ConstantOp, - // StableHLO_ConvertOp, + StableHLO_ConvertOp, StableHLO_ConvolutionOp, - // StableHLO_CosineOp, + StableHLO_CosineOp, // StableHLO_CustomCallOp, // StableHLO_DotGeneralOp, // StableHLO_DotOp, @@ -185,10 +185,10 @@ defvar CustomHloConverterOps = [ // StableHLO_SelectAndScatterOp, // StableHLO_SendOp, // StableHLO_SetDimensionSizeOp, - // StableHLO_SineOp, + StableHLO_SineOp, // StableHLO_SortOp, // StableHLO_SubtractOp, - // StableHLO_TanOp, + StableHLO_TanOp, // StableHLO_UniformDequantizeOp, // StableHLO_UniformQuantizeOp, // StableHLO_WhileOp, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 3450820faf0c38..58da59d75e2f3a 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -38,6 +38,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -107,6 +108,8 @@ limitations under the License. #include "xla/tsl/platform/types.h" #include "xla/xla_data.pb.h" +#define DEBUG_TYPE "xla-translate" + using ::int64_t; using ::tsl::int16; using ::tsl::int32; @@ -651,6 +654,8 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( } // Converts ResultAccuracyAttr to XLA ResultAccuracy proto. +// This function name is non-standard to match the codegen +// function name, similar to other attribute converters. 
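+// As a rough sketch of the mapping implemented below (inferred from this
+// code, not from a separate spec): a TOLERANCE-mode attribute populates
+// tolerance{atol, rtol, ulps} on the proto, any other mode (e.g. HIGHEST)
+// is parsed into the ResultAccuracy::Mode enum, and an unrecognized mode
+// string emits an error and yields a default-constructed proto.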
static xla::ResultAccuracy Convert_result_accuracy( std::optional optional_result_accuracy_attr) { @@ -665,21 +670,58 @@ static xla::ResultAccuracy Convert_result_accuracy( optional_result_accuracy_attr.value().getRtol().convertToDouble()); result_accuracy.mutable_tolerance()->set_ulps( optional_result_accuracy_attr.value().getUlps()); - } else { - xla::ResultAccuracy::Mode mode; - auto result_accuracy_mode = - ::mlir::mhlo::stringifyResultAccuracyMode( - optional_result_accuracy_attr.value().getMode().getValue()) - .str(); - if (xla::ResultAccuracy::Mode_Parse(result_accuracy_mode, &mode)) { - result_accuracy.set_mode(mode); - } else { - auto* context = optional_result_accuracy_attr.value().getContext(); - mlir::emitError(mlir::UnknownLoc::get(context)) - << "unexpected result accuracy mode " << result_accuracy_mode; - return xla::ResultAccuracy(); - } + return result_accuracy; } + + xla::ResultAccuracy::Mode mode; + auto result_accuracy_mode = + ::mlir::mhlo::stringifyResultAccuracyMode( + optional_result_accuracy_attr.value().getMode().getValue()) + .str(); + if (!xla::ResultAccuracy::Mode_Parse(result_accuracy_mode, &mode)) { + auto* context = optional_result_accuracy_attr.value().getContext(); + mlir::emitError(mlir::UnknownLoc::get(context)) + << "unexpected result accuracy mode " << result_accuracy_mode; + return xla::ResultAccuracy(); + } + + result_accuracy.set_mode(mode); + return result_accuracy; +} + +// Converts ResultAccuracyAttr to XLA ResultAccuracy proto. +// This function name is non-standard to match the codegen +// function name, similar to other attribute converters. +static xla::ResultAccuracy Convert_result_accuracy( + std::optional + optional_result_accuracy_attr) { + if (!optional_result_accuracy_attr.has_value()) return xla::ResultAccuracy(); + + auto result_accuracy = xla::ResultAccuracy(); + if (optional_result_accuracy_attr.value().getMode().getValue() == + mlir::stablehlo::ResultAccuracyMode::TOLERANCE) { + result_accuracy.mutable_tolerance()->set_atol( + optional_result_accuracy_attr.value().getAtol().convertToDouble()); + result_accuracy.mutable_tolerance()->set_rtol( + optional_result_accuracy_attr.value().getRtol().convertToDouble()); + result_accuracy.mutable_tolerance()->set_ulps( + optional_result_accuracy_attr.value().getUlps()); + return result_accuracy; + } + + xla::ResultAccuracy::Mode mode; + auto result_accuracy_mode = + ::mlir::stablehlo::stringifyResultAccuracyMode( + optional_result_accuracy_attr.value().getMode().getValue()) + .str(); + if (!xla::ResultAccuracy::Mode_Parse(result_accuracy_mode, &mode)) { + auto* context = optional_result_accuracy_attr.value().getContext(); + mlir::emitError(mlir::UnknownLoc::get(context)) + << "unexpected result accuracy mode " << result_accuracy_mode; + return xla::ResultAccuracy(); + } + + result_accuracy.set_mode(mode); return result_accuracy; } @@ -1155,6 +1197,57 @@ LogicalResult ExportXlaOp(mlir::stablehlo::ConvolutionOp op, return mlir::success(); } +LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + xla::XlaOp operand; + if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op))) + return failure(); + + value_map[op] = xla::ConvertElementType( + operand, + xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); + return success(); +} + +LogicalResult ExportXlaOp(CosineOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + auto result = op.getResult(); + xla::XlaOp arg; + if (failed(GetXlaOp(*op.getODSOperands(0).begin(), 
value_map, &arg, op))) + return mlir::failure(); + xla::ResultAccuracy result_accuracy = + Convert_result_accuracy(op.getResultAccuracy()); + auto xla_result = xla::Cos(Unwrap(arg), result_accuracy); + value_map[result] = xla_result; + return mlir::success(); +} + +LogicalResult ExportXlaOp(SineOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + auto result = op.getResult(); + xla::XlaOp arg; + xla::ResultAccuracy result_accuracy = + Convert_result_accuracy(op.getResultAccuracy()); + if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &arg, op))) + return mlir::failure(); + auto xla_result = xla::Sin(Unwrap(arg), result_accuracy); + value_map[result] = xla_result; + return mlir::success(); +} + +LogicalResult ExportXlaOp(TanOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + auto result = op.getResult(); + xla::XlaOp arg; + xla::ResultAccuracy result_accuracy = + Convert_result_accuracy(op.getResultAccuracy()); + if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &arg, op))) + return mlir::failure(); + auto xla_result = xla::Tan(Unwrap(arg), result_accuracy); + value_map[result] = xla_result; + return mlir::success(); +} + } // namespace } // namespace stablehlo @@ -4132,11 +4225,14 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( } xla::XlaOp return_value; - for (auto& inst : *block) + for (auto& inst : *block) { + if (isa(inst.getDialect())) + LLVM_DEBUG(llvm::dbgs() + << "Lowering: " << inst.getName().getStringRef() << "\n"); if (failed(Lower(&inst, is_entry_function, ret_shardings, implicit_results, builder, &lowering, &return_value))) return failure(); - + } // Build the XlaComputation and check for failures. auto computation_or = return_value.valid() ? builder->Build(return_value) : builder->Build(); @@ -4145,6 +4241,7 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( return failure(); } *result = std::move(computation_or.value()); + LLVM_DEBUG(llvm::dbgs() << "Created: " << result->name() << "\n"); return success(); } diff --git a/third_party/xla/xla/hlo/translate/tests/BUILD b/third_party/xla/xla/hlo/translate/tests/BUILD index 228f90ee2cc1b5..383f5c9ff45d89 100644 --- a/third_party/xla/xla/hlo/translate/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/tests/BUILD @@ -20,6 +20,7 @@ lit_test_suite( "simple.mlir", "stablehlo.mlir", "stablehlo_invalid.mlir", + "stablehlo_unary_elementwise.mlir", "vhlo_input.mlir", # go/keep-sorted end ], diff --git a/third_party/xla/xla/hlo/translate/tests/stablehlo_unary_elementwise.mlir b/third_party/xla/xla/hlo/translate/tests/stablehlo_unary_elementwise.mlir new file mode 100644 index 00000000000000..2d0b90d6e809e7 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/stablehlo_unary_elementwise.mlir @@ -0,0 +1,117 @@ +// RUN: xla-translate --stablehlo-to-hlo-text -split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt --stablehlo-legalize-to-hlo=convert-xla-supported-stablehlo=false -split-input-file %s | FileCheck %s --check-prefix CHECK-DIRECT + +// CHECK-LABEL: HloModule main, entry_computation_layout={(f32[3,4]{1,0}, c64[2]{0}, s32[5]{0}, pred[5]{0}, f32[?,784]{1,0})->(f32[3,4]{1,0}, s32[5]{0}, f32[2]{0}, f32[2]{0}, pred[5]{0}, /*index=5*/f32[?,784]{1,0}, f16[3,4]{1,0}, pred[3,4]{1,0})} + +// CHECK: ENTRY %[[$main_35:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = f32[3,4] parameter(0) +// CHECK-NEXT: %[[abs_6:[^ ]+]] = f32[3,4] abs(%[[Arg_0_1]]), metadata= +// CHECK-NEXT: %[[cbrt_7:[^ ]+]] = f32[3,4] cbrt(%[[abs_6]]), metadata= +// CHECK-NEXT: 
%[[ceil_8:[^ ]+]] = f32[3,4] ceil(%[[cbrt_7]]), metadata= +// CHECK-NEXT: %[[cosine_9:[^ ]+]] = f32[3,4] cosine(%[[ceil_8]]), metadata= +// CHECK-NEXT: %[[exponential_10:[^ ]+]] = f32[3,4] exponential(%[[cosine_9]]), metadata= +// CHECK-NEXT: %[[exponential_minus_one_11:[^ ]+]] = f32[3,4] exponential-minus-one(%[[exponential_10]]), metadata= +// CHECK-NEXT: %[[floor_12:[^ ]+]] = f32[3,4] floor(%[[exponential_minus_one_11]]), metadata= +// CHECK-NEXT: %[[log_13:[^ ]+]] = f32[3,4] log(%[[floor_12]]), metadata= +// CHECK-NEXT: %[[log_plus_one_14:[^ ]+]] = f32[3,4] log-plus-one(%[[log_13]]), metadata= +// CHECK-NEXT: %[[logistic_15:[^ ]+]] = f32[3,4] logistic(%[[log_plus_one_14]]), metadata= +// CHECK-NEXT: %[[negate_16:[^ ]+]] = f32[3,4] negate(%[[logistic_15]]), metadata= +// CHECK-NEXT: %[[round_nearest_afz_17:[^ ]+]] = f32[3,4] round-nearest-afz(%[[negate_16]]), metadata= +// CHECK-NEXT: %[[round_nearest_even_18:[^ ]+]] = f32[3,4] round-nearest-even(%[[round_nearest_afz_17]]), metadata= +// CHECK-NEXT: %[[rsqrt_19:[^ ]+]] = f32[3,4] rsqrt(%[[round_nearest_even_18]]), metadata= +// CHECK-NEXT: %[[sign_20:[^ ]+]] = f32[3,4] sign(%[[rsqrt_19]]), metadata= +// CHECK-NEXT: %[[sine_21:[^ ]+]] = f32[3,4] sine(%[[sign_20]]), metadata= +// CHECK-NEXT: %[[sqrt_22:[^ ]+]] = f32[3,4] sqrt(%[[sine_21]]), metadata= +// CHECK-NEXT: %[[tan_23:[^ ]+]] = f32[3,4] tan(%[[sqrt_22]]), metadata= +// CHECK-NEXT: %[[tanh_24:[^ ]+]] = f32[3,4] tanh(%[[tan_23]]), metadata= +// CHECK-NEXT: %[[Arg_2_3:[^ ]+]] = s32[5] parameter(2) +// CHECK-NEXT: %[[abs_25:[^ ]+]] = s32[5] abs(%[[Arg_2_3]]), metadata= +// CHECK-NEXT: %[[count_leading_zeros_26:[^ ]+]] = s32[5] count-leading-zeros(%[[abs_25]]), metadata= +// CHECK-NEXT: %[[not_27:[^ ]+]] = s32[5] not(%[[count_leading_zeros_26]]), metadata= +// CHECK-NEXT: %[[Arg_1_2:[^ ]+]] = c64[2] parameter(1) +// CHECK-NEXT: %[[imag_28:[^ ]+]] = f32[2] imag(%[[Arg_1_2]]), metadata= +// CHECK-NEXT: %[[real_29:[^ ]+]] = f32[2] real(%[[Arg_1_2]]), metadata= +// CHECK-NEXT: %[[Arg_3_4:[^ ]+]] = pred[5] parameter(3) +// CHECK-NEXT: %[[not_30:[^ ]+]] = pred[5] not(%[[Arg_3_4]]), metadata= +// CHECK-NEXT: %[[Arg_4_5:[^ ]+]] = f32[?,784] parameter(4) +// CHECK-NEXT: %[[abs_31:[^ ]+]] = f32[?,784] abs(%[[Arg_4_5]]), metadata= +// CHECK-NEXT: %[[convert_32:[^ ]+]] = f16[3,4] convert(%[[Arg_0_1]]), metadata= +// CHECK-NEXT: %[[is_finite_33:[^ ]+]] = pred[3,4] is-finite(%[[Arg_0_1]]), metadata= +// CHECK-NEXT: ROOT %[[tuple_34:[^ ]+]] = (f32[3,4], s32[5], f32[2], f32[2], pred[5], /*index=5*/f32[?,784], f16[3,4], pred[3,4]) tuple(%[[tanh_24]], %[[not_27]], %[[imag_28]], %[[real_29]], %[[not_30]], /*index=5*/%[[abs_31]], %[[convert_32]], %[[is_finite_33]]) + +func.func @main( + %arg_f32: tensor<3x4xf32>, + %arg_complex: tensor<2xcomplex>, + %arg_int: tensor<5xi32>, + %arg_bool: tensor<5xi1>, + %arg_dynamic: tensor +) -> ( + tensor<3x4xf32>, + tensor<5xi32>, + tensor<2xf32>, + tensor<2xf32>, + tensor<5xi1>, + tensor, + tensor<3x4xf16>, + tensor<3x4xi1> +) { + %f0 = stablehlo.abs %arg_f32 : tensor<3x4xf32> + %f1 = stablehlo.cbrt %f0 : tensor<3x4xf32> + %f2 = stablehlo.ceil %f1 : tensor<3x4xf32> + %f4 = stablehlo.cosine %f2 : tensor<3x4xf32> + %f6 = stablehlo.exponential %f4 : tensor<3x4xf32> + %f7 = stablehlo.exponential_minus_one %f6 : tensor<3x4xf32> + %f8 = stablehlo.floor %f7 : tensor<3x4xf32> + %f11 = stablehlo.log %f8 : tensor<3x4xf32> + %f12 = stablehlo.log_plus_one %f11 : tensor<3x4xf32> + %f13 = stablehlo.logistic %f12 : tensor<3x4xf32> + %f14 = stablehlo.negate %f13 : tensor<3x4xf32> + 
%f19 = stablehlo.round_nearest_afz %f14 : tensor<3x4xf32> + %f20 = stablehlo.round_nearest_even %f19 : tensor<3x4xf32> + %f21 = stablehlo.rsqrt %f20 : tensor<3x4xf32> + %f22 = stablehlo.sign %f21 : tensor<3x4xf32> + %f23 = stablehlo.sine %f22 : tensor<3x4xf32> + %f24 = stablehlo.sqrt %f23 : tensor<3x4xf32> + %f25 = stablehlo.tan %f24 : tensor<3x4xf32> + %f26 = stablehlo.tanh %f25 : tensor<3x4xf32> + %i0 = stablehlo.abs %arg_int : tensor<5xi32> + %i5 = stablehlo.count_leading_zeros %i0 : tensor<5xi32> + %i16 = stablehlo.not %i5 : tensor<5xi32> + %cx9 = stablehlo.imag %arg_complex : (tensor<2xcomplex>) -> tensor<2xf32> + %cx18 = stablehlo.real %arg_complex : (tensor<2xcomplex>) -> tensor<2xf32> + %b15 = stablehlo.not %arg_bool : tensor<5xi1> + %d3 = stablehlo.abs %arg_dynamic : tensor + %t3 = stablehlo.convert %arg_f32 : (tensor<3x4xf32>) -> tensor<3x4xf16> + %t10 = stablehlo.is_finite %arg_f32 : (tensor<3x4xf32>) -> tensor<3x4xi1> + + // Return all the final results to prevent DCE + func.return %f26, %i16, %cx9, %cx18, %b15, %d3, %t3, %t10 : tensor<3x4xf32>, tensor<5xi32>, tensor<2xf32>, tensor<2xf32>, tensor<5xi1>, tensor, tensor<3x4xf16>, tensor<3x4xi1> +} +// CHECK-DIRECT: stablehlo.abs +// CHECK-DIRECT: stablehlo.cbrt +// CHECK-DIRECT: stablehlo.ceil +// CHECK-DIRECT: stablehlo.cosine +// CHECK-DIRECT: stablehlo.exponential +// CHECK-DIRECT: stablehlo.exponential_minus_one +// CHECK-DIRECT: stablehlo.floor +// CHECK-DIRECT: stablehlo.log +// CHECK-DIRECT: stablehlo.log_plus_one +// CHECK-DIRECT: stablehlo.logistic +// CHECK-DIRECT: stablehlo.negate +// CHECK-DIRECT: stablehlo.round_nearest_afz +// CHECK-DIRECT: stablehlo.round_nearest_even +// CHECK-DIRECT: stablehlo.rsqrt +// CHECK-DIRECT: stablehlo.sign +// CHECK-DIRECT: stablehlo.sine +// CHECK-DIRECT: stablehlo.sqrt +// CHECK-DIRECT: stablehlo.tan +// CHECK-DIRECT: stablehlo.tanh +// CHECK-DIRECT: stablehlo.abs +// CHECK-DIRECT: stablehlo.count_leading_zeros +// CHECK-DIRECT: stablehlo.not +// CHECK-DIRECT: stablehlo.imag +// CHECK-DIRECT: stablehlo.real +// CHECK-DIRECT: stablehlo.not +// CHECK-DIRECT: stablehlo.abs +// CHECK-DIRECT: stablehlo.convert +// CHECK-DIRECT: stablehlo.is_finite diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc index 14f1b5290f7f3f..58c1ab8d9f70cc 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc @@ -67,10 +67,18 @@ struct AddDependencyOpToMhoTokenConverter void legalDirectStablehloToHloConversionOps(ConversionTarget& target) { target.addLegalOp< // go/keep-sorted start + stablehlo::AbsOp, stablehlo::CbrtOp, stablehlo::SqrtOp, stablehlo::TanOp, stablehlo::AddOp, stablehlo::BroadcastInDimOp, stablehlo::BroadcastOp, - stablehlo::ConstantOp, stablehlo::ConvolutionOp, - stablehlo::DynamicBroadcastInDimOp, stablehlo::DynamicSliceOp, - stablehlo::SliceOp + stablehlo::CeilOp, stablehlo::ClzOp, stablehlo::ConvertOp, + stablehlo::ConstantOp, stablehlo::ConvolutionOp, stablehlo::CosineOp, + stablehlo::DynamicSliceOp, stablehlo::FloorOp, stablehlo::ImagOp, + stablehlo::ExpOp, stablehlo::Expm1Op, stablehlo::DynamicBroadcastInDimOp, + stablehlo::IsFiniteOp, stablehlo::Log1pOp, stablehlo::LogOp, + stablehlo::LogisticOp, stablehlo::NegOp, stablehlo::NotOp, + 
stablehlo::PopulationCountOp, stablehlo::RealOp, + stablehlo::RoundNearestEvenOp, stablehlo::RoundOp, stablehlo::RsqrtOp, + stablehlo::SignOp, stablehlo::SineOp, stablehlo::SliceOp, + stablehlo::TanhOp // go/keep-sorted end >(); } From 333fe4c6868d43be17011c7c19353b986a3df2ba Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 20:31:06 -0700 Subject: [PATCH 0607/1324] Reverts c5d30c01cc424f937e860ba4e7f5a3292de1d5a3 PiperOrigin-RevId: 746681568 --- .../batched_gather_round_trip.mlir | 11 +-- .../batched_scatter_round_trip.mlir | 11 +-- .../legalize-tf-no-runtime-verification.mlir | 5 +- .../compiler/mlir/lite/tests/optimize.mlir | 42 +++++++--- .../mlir/lite/transforms/optimize_pass.cc | 76 ++++++++++++++++++- 5 files changed, 117 insertions(+), 28 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir index adb22ddd009a80..12de9da5939573 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir @@ -4,14 +4,11 @@ module { // CHECK-LABEL: func.func public @main func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { - // CHECK-ROUNDTRIP: %0 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]]]]> : tensor<1x3x1x1xi32>}> : () -> tensor<1x3x1x1xi32> - // CHECK-ROUNDTRIP: %1 = "tfl.pseudo_const"() <{value = dense<[4, 3, 5, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> - // CHECK-ROUNDTRIP: %2 = "tfl.broadcast_to"(%0, %1) : (tensor<1x3x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %3 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]]], {{\[\[\[}}1]]], {{\[\[\[}}2]]], {{\[\[\[}}3]]]]> : tensor<4x1x1x1xi32>}> : () -> tensor<4x1x1x1xi32> - // CHECK-ROUNDTRIP: %4 = "tfl.broadcast_to"(%3, %1) : (tensor<4x1x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%2, %4, %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 + // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> - // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ + // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %2) <{ // CHECK-ROUNDTRIP-SAME: dimension_numbers = #stablehlo.gather< // CHECK-ROUNDTRIP-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], // CHECK-ROUNDTRIP-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir index 7e42ff310c080f..44d1bb7dd8b72f 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir @@ -4,14 +4,11 @@ module { // CHECK-LABEL: func.func public @main func.func public @main(%arg0: 
tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { - // CHECK-ROUNDTRIP: %0 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]]]]> : tensor<1x3x1x1xi32>}> : () -> tensor<1x3x1x1xi32> - // CHECK-ROUNDTRIP: %1 = "tfl.pseudo_const"() <{value = dense<[4, 3, 5, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> - // CHECK-ROUNDTRIP: %2 = "tfl.broadcast_to"(%0, %1) : (tensor<1x3x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %3 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]]], {{\[\[\[}}1]]], {{\[\[\[}}2]]], {{\[\[\[}}3]]]]> : tensor<4x1x1x1xi32>}> : () -> tensor<4x1x1x1xi32> - // CHECK-ROUNDTRIP: %4 = "tfl.broadcast_to"(%3, %1) : (tensor<4x1x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%2, %4, %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 + // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> - // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ + // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %2, %arg2) <{ // CHECK-ROUNDTRIP-SAME: scatter_dimension_numbers = #stablehlo.scatter // CHECK-ROUNDTRIP-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], // CHECK-ROUNDTRIP-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>}> diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir index e0793cbf803c4f..2c17e734c58dad 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir @@ -5,6 +5,7 @@ func.func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> ten func.return %0: tensor<3x3xbf16> // CHECK-LABEL: broadcast_to_bf16 -// CHECK: %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16> -// CHECK: return %0 : tensor<3x3xbf16> +// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xbf16> +// CHECK: [[MUL:%.*]] = tfl.mul(%arg0, [[CST]]) <{fused_activation_function = "NONE"}> : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16> +// CHECK: return [[MUL]] : tensor<3x3xbf16> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 47d32734dbb27c..83b82d50fc064f 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -4399,11 +4399,11 @@ func.func @FuseExcessBroadcastingOnReshapes(%arg0: tensor<1x8xf32>) -> tensor<1x %1 = "tfl.broadcast_to"(%0, %cst_0) : (tensor<1x1x1x8x1x1xf32>, tensor<6xi32>) -> tensor<1x1x1x8x16x1xf32> %2 = "tfl.reshape"(%1, %cst_1) : (tensor<1x1x1x8x16x1xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> return %2 : tensor<1x1x1x128xf32> - // CHECK: %cst = arith.constant dense<[8, 16]> : tensor<2xi64> + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<8x16xf32> // CHECK: %cst_0 = 
arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> // CHECK: %cst_1 = arith.constant dense<[8, 1]> : tensor<2xi32> // CHECK: %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<1x8xf32>, tensor<2xi32>) -> tensor<8x1xf32> - // CHECK: %1 = "tfl.broadcast_to"(%0, %cst) : (tensor<8x1xf32>, tensor<2xi64>) -> tensor<8x16xf32> + // CHECK: %1 = tfl.mul(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<8x16xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> // CHECK: return %2 : tensor<1x1x1x128xf32> } @@ -4425,63 +4425,83 @@ func.func @FuseExcessBroadcastingOnReshapesDynamicShapes(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: tfl.broadcast_to + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: return %0 : tensor<3x3xf32> } // CHECK-LABEL: @broadcast_to_i32_low_dim func.func @broadcast_to_i32_low_dim(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> - // CHECK: tfl.broadcast_to + // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> + // CHECK: return %0 : tensor<3x3xi32> } // CHECK-LABEL: @broadcast_to_low_dim_with_unknown_shape func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: tfl.broadcast_to + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: return %0 : tensor<3x3xf32> } // CHECK-LABEL: @broadcast_to_i16_low_dim func.func @broadcast_to_i16_low_dim(%arg0: tensor<3xi16>, %arg1: tensor<2xi32>) -> tensor<3x3xi16> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> return %0 : tensor<3x3xi16> - // CHECK: tfl.broadcast_to + // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi16> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi16>, tensor<3x3xi16>) -> tensor<3x3xi16> + // CHECK: return %0 : tensor<3x3xi16> } // CHECK-LABEL: @broadcast_to_i32_low_dim_with_unknown_output func.func @broadcast_to_i32_low_dim_with_unknown_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<*xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> - // CHECK: tfl.broadcast_to + // CHECK: %cst = arith.constant dense<1> : tensor + // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<*xi32> + // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> + // CHECK: return %1 : tensor<*xi32> } // CHECK-LABEL: @broadcast_to_ui32 func.func @broadcast_to_ui32(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor<10xui32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor, tensor<1xi64>) -> 
tensor<10xui32> return %0 : tensor<10xui32> - // CHECK: tfl.broadcast_to + // CHECK: %cst = arith.constant dense<1> : tensor<10xui32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor, tensor<10xui32>) -> tensor<10xui32> + // CHECK: return %0 : tensor<10xui32> } // CHECK-LABEL: @broadcast_to_f32 func.func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: tfl.broadcast_to + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: return %0 : tensor<3x3xf32> } // CHECK-LABEL: @broadcast_to_i32 func.func @broadcast_to_i32(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> - // CHECK: tfl.broadcast_to + // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> + // CHECK: return %0 : tensor<3x3xi32> } // CHECK-LABEL: @broadcast_to_i32_with_dynamic_shape_and_output func.func @broadcast_to_i32_with_dynamic_shape_and_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x?xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x?xi32> return %0 : tensor<3x?xi32> - // CHECK: tfl.broadcast_to + // CHECK: %cst = arith.constant dense<1> : tensor + // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<3x?xi32> + // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x?xi32>) -> tensor<3x?xi32> + // CHECK: return %1 : tensor<3x?xi32> } // CHECK-LABEL: @broadcast_to_ui32_with_dynamic_output diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc index 448ef85bdac543..1de8398dbd8bd9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc @@ -977,6 +977,80 @@ struct SqueezeReshapesAroundBroadcastOp } }; +// This pattern matches TFL::BroadcastToOp WITH TENSOR RANK <= 4 and replaces +// it with a MulOp that multiplies the tensor by a splat constant with 1s. +struct ConvertTFLBroadcastToMulOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, + PatternRewriter &rewriter) const override { + auto input_type = + mlir::cast(tfl_broadcast_to_op.getInput().getType()); + auto output_type = + mlir::cast(tfl_broadcast_to_op.getOutput().getType()); + auto shape_type = + mlir::cast(tfl_broadcast_to_op.getShape().getType()); + Type element_type = input_type.getElementType(); + + auto loc = tfl_broadcast_to_op->getLoc(); + + // Check that the output type is not dynamic and is less-than-equal to 4D or + // the shape type is static, 1D and has less-than-equal to 4 elements. 
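+    // For example (illustrative): broadcasting tensor<3xf32> with a static
+    // 1D shape operand to tensor<3x3xf32> is accepted, while a rank-5 output
+    // whose shape operand is also non-static is rejected below.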
+ bool is_output_shape_dynamic = + (!output_type.hasRank() || (output_type.getRank() > 4) || + (output_type.getNumDynamicDims() > 0)); + bool is_broadcast_shape_dynamic = + (!shape_type.hasStaticShape() || (shape_type.getRank() != 1) || + (shape_type.getDimSize(0) > 4)); + if (is_output_shape_dynamic && is_broadcast_shape_dynamic) + return rewriter.notifyMatchFailure( + loc, "output_rank or broadcast_to shape not supported"); + + // Allow lowering when the input's elements type is F32, BFloat16, I32 or + // I16. + if (!(mlir::isa(element_type) || + element_type.isInteger(32) || element_type.isInteger(16))) + return rewriter.notifyMatchFailure(loc, "element_type_not_supported"); + + // TFL_FillOp is created only if is_output_shape_dynamic is true, otherwise + // a Arith.ConstOp is created. + if (is_output_shape_dynamic && + output_type.getElementType().isUnsignedInteger()) { + return rewriter.notifyMatchFailure( + loc, + "Unsigned broadcast_to output with dynamic shape is not supported"); + } + + Value mul_rhs_value; + if (!output_type.hasRank() || (output_type.getNumDynamicDims() > 0)) { + auto status_or_const_op = + CreateConstOpWithSingleValue(&rewriter, loc, input_type, 1); + if (!status_or_const_op.ok()) { + return failure(); + } + + mul_rhs_value = rewriter.create( + loc, output_type, tfl_broadcast_to_op.getShape(), + status_or_const_op.value()); + } else { + auto status_or_const_op = + CreateConstOpWithVectorValue(&rewriter, loc, output_type, 1); + if (!status_or_const_op.ok()) { + return failure(); + } + + mul_rhs_value = status_or_const_op.value(); + } + + auto mul_op = rewriter.create( + loc, output_type, tfl_broadcast_to_op.getInput(), mul_rhs_value, + rewriter.getStringAttr("NONE")); + rewriter.replaceOp(tfl_broadcast_to_op, mul_op.getResult()); + return success(); + } +}; + struct FuseAddAndStridedSlice : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -3007,7 +3081,7 @@ void OptimizePass::runOnOperation() { OptimizeTopK, FuseAddAndStridedSlice, FuseReshapeAndTransposeAroundBatchMatmul, FuseTransposeReshapeIntoBatchMatmul, MoveReshapeAfterFullyConnected, - EnableFullyConnectedKeepNumDimsBeforeReshape, + EnableFullyConnectedKeepNumDimsBeforeReshape, ConvertTFLBroadcastToMulOp, ReorderTransposeReshapeTranspose, FullyConnectedSwapOperandsWhenLHSIsConst>(ctx); if (!GetOptions().disable_fuse_mul_and_fc) { From f94fd1d5bfb53b9f3f1928ae35560da62c71b6e3 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Fri, 11 Apr 2025 20:37:16 -0700 Subject: [PATCH 0608/1324] Remove duplicate passes in XlaCallModule lowering StableHLO->HLO conversion already takes care of these in `ConvertStablehloToHlo` called in `XlaCallModuleLoader::ToXlaComputation` PiperOrigin-RevId: 746683567 --- .../tf2xla/kernels/xla_call_module_loader.cc | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 820d5ded5abe68..c88c4042ca2c7b 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -490,18 +490,14 @@ absl::Status XlaCallModuleLoader::ValidateStaticShapes() { absl::Status XlaCallModuleLoader::PrepareStablehloForLowering() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - // TODO (b/393390051): Migrate required passes to StableHLO. + // TODO (b/410057228): Replace MHLO canonicalization with StableHLO. 
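+  // (Per the commit description, the dropped ChloLegalizeToHlo and
+  // SinkConstantsToControlFlow passes are already run by
+  // ConvertStablehloToHlo inside XlaCallModuleLoader::ToXlaComputation.)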
+ // This code requires MHLO CaseOp canonicalization to remove unreachable + // branches, else `tf.call_tf_function` inlining can fail. mlir::PassManager pm(module_->getContext()); - applyTensorflowAndCLOptions(pm); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - pm.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - // In order to export to XLA, we must sink constants to control flow - // regions, since XLA uses functional control flow. - pm.addNestedPass( - mlir::mhlo::createSinkConstantsToControlFlowPass()); pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (failed(pm.run(*module_))) { return absl::InternalError( absl::StrCat("MHLO->HLO lowering passes failed: ", @@ -509,7 +505,7 @@ absl::Status XlaCallModuleLoader::PrepareStablehloForLowering() { } if (VLOG_IS_ON(5)) { - DumpMlirOpToFile("xla_call_module.after_mhlo_lowering", *module_); + DumpMlirOpToFile("xla_call_module.after_canonicalization", *module_); } return absl::OkStatus(); From 1efc0950dd4edc3be857d943ceb8321e52d0acb1 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Fri, 11 Apr 2025 21:55:09 -0700 Subject: [PATCH 0609/1324] Add CreateR4Parameter to ClientLibraryTestRunnerMixin. PiperOrigin-RevId: 746703385 --- .../xla/xla/tests/client_library_test_runner_mixin.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/third_party/xla/xla/tests/client_library_test_runner_mixin.h b/third_party/xla/xla/tests/client_library_test_runner_mixin.h index 7184e88958dfa1..0dc15530243b07 100644 --- a/third_party/xla/xla/tests/client_library_test_runner_mixin.h +++ b/third_party/xla/xla/tests/client_library_test_runner_mixin.h @@ -315,6 +315,16 @@ class ClientLibraryTestRunnerMixin : public T { return literal; } + template + Literal CreateR4Parameter(const Array4D& array_4d, + int64_t parameter_number, const std::string& name, + XlaBuilder* builder, XlaOp* data_handle) { + Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d); + literal = MaybeConvertLiteralToTestType(literal); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); + return literal; + } + Literal MaybeConvertLiteralToTestType(const Literal& literal) const { switch (test_type_) { case BF16: From b7577b9f70b1ccb8945b6a98280f744313207998 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 22:05:03 -0700 Subject: [PATCH 0610/1324] Automated Code Change PiperOrigin-RevId: 746706001 --- third_party/xla/xla/hlo/transforms/memory_space_propagation.cc | 1 - third_party/xla/xla/hlo/transforms/operand_upcaster.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc b/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc index ac6c6300d1aac4..c88232dd0023a3 100644 --- a/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc +++ b/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/hlo/transforms/memory_space_propagation.h" -#include #include #include diff --git a/third_party/xla/xla/hlo/transforms/operand_upcaster.cc b/third_party/xla/xla/hlo/transforms/operand_upcaster.cc index 4fe0df75d32590..2e774787e2e239 100644 --- a/third_party/xla/xla/hlo/transforms/operand_upcaster.cc +++ b/third_party/xla/xla/hlo/transforms/operand_upcaster.cc @@ -17,7 +17,6 @@ limitations under the License. 
 #include

-#include "absl/algorithm/container.h"
 #include "absl/status/statusor.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"

From 081fac73ad75ff175eebd624d34b9f1d80bd7c24 Mon Sep 17 00:00:00 2001
From: Vlad Sytchenko
Date: Fri, 11 Apr 2025 22:25:23 -0700
Subject: [PATCH 0611/1324] [XLA] Extend DCE to remove dead parameters for called computations

This is already being done for fusions today. Handle other trivial
control flow, such as kCall and kAsyncStart.

Also clean up the dead root detection logic in DCE; it had become too
convoluted to follow over the years.

PiperOrigin-RevId: 746709912
---
 third_party/xla/xla/hlo/ir/hlo_computation.cc | 108 +++++++++++++---
 third_party/xla/xla/hlo/ir/hlo_computation.h  |  18 ++-
 .../xla/xla/hlo/transforms/simplifiers/BUILD  |  11 +-
 .../xla/hlo/transforms/simplifiers/hlo_dce.cc | 116 +++++++++++++----
 .../xla/hlo/transforms/simplifiers/hlo_dce.h  |  17 ++-
 .../transforms/simplifiers/hlo_dce_test.cc    | 118 ++++++++++++++++++
 6 files changed, 334 insertions(+), 54 deletions(-)

diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc
index d345d97fa3910b..78c34570e3da6b 100644
--- a/third_party/xla/xla/hlo/ir/hlo_computation.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc
@@ -519,8 +519,11 @@ absl::Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
   return absl::OkStatus();
 }

-bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction,
-                                       bool ignore_control_dependency) {
+bool HloComputation::IsSafelyRemovable(
+    const HloInstruction* instruction, bool ignore_control_dependency,
+    std::optional<
+        absl::FunctionRef<std::vector<HloInstruction*>(const HloComputation*)>>
+        computation_callers) const {
   // If the instruction has control predecessors or successors then we cannot
   // remove the instruction without violating ordering constraints (added, for
   // example, to avert interference due to buffer aliasing).
@@ -528,11 +531,42 @@
     return false;
   }

-  if (instruction->opcode() == HloOpcode::kParameter &&
-      !IsFusionComputation()) {
-    return false;
+  if (instruction->opcode() == HloOpcode::kParameter) {
+    // If there is no parent, it is safe to remove the child.
+    if (instruction->parent() == nullptr) {
+      return true;
+    }
+    // Entry computation parameters can never be removed.
+    if (instruction->parent()->IsEntryComputation()) {
+      return false;
+    }
+    // We generally want to use the call graph to determine the caller, as
+    // this back pointer is very fragile; however, it is not reasonable to
+    // expect every caller to pass in the call graph.
+    if (IsFusionComputation()) {
+      return true;
+    }
+    // If we can't fix up the caller, then we can't remove the parameter.
+    if (!computation_callers.has_value()) {
+      return false;
+    }
+    std::vector<HloInstruction*> callers =
+        (*computation_callers)(instruction->parent());
+    if (callers.empty()) {
+      return false;
+    }
+    for (HloInstruction* caller :
+         (*computation_callers)(instruction->parent())) {
+      if (caller->opcode() != HloOpcode::kFusion &&
+          caller->opcode() != HloOpcode::kCall &&
+          caller->opcode() != HloOpcode::kAsyncStart) {
+        // We don't handle callers with non-trivial control flow today.
+        return false;
+      }
+    }
   }

+  // All other instructions are generally safe to remove.
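+  // For illustration (editor's sketch, not part of the upstream change):
+  // given
+  //   callee { p0 = f32[] parameter(0) /*unused*/, p1 = f32[] parameter(1) }
+  //   caller:  r = f32[] call(a, b), to_apply=callee
+  // p0 is reported removable above only because the sole caller is a kCall
+  // whose operand list can be shrunk in lockstep; a kWhile caller would have
+  // returned false.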
   return true;
 }

@@ -552,12 +586,19 @@ bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) {

 absl::Status HloComputation::RemoveInstructionAndUnusedOperands(
     HloInstruction* instruction,
     std::optional<absl::FunctionRef<void(HloInstruction*)>> cleanup,
-    bool ignore_control_dependencies) {
+    bool ignore_control_dependencies,
+    std::optional<
+        absl::FunctionRef<std::vector<HloInstruction*>(const HloComputation*)>>
+        computation_callers) {
   TF_RET_CHECK(root_instruction() != instruction);
   TF_RET_CHECK(instruction->IsDead());
-  TF_RET_CHECK(IsSafelyRemovable(instruction, ignore_control_dependencies))
+  TF_RET_CHECK(IsSafelyRemovable(instruction, ignore_control_dependencies,
+                                 computation_callers))
       << "Cannot remove instruction: " << instruction->ToString();

+  // Remember the parent, in case we lose all references to it, in order to
+  // clean up the callers.
+  HloComputation* parent = instruction->parent();
   absl::flat_hash_set<HloInstruction*> removed;
   std::queue<HloInstruction*> worklist;
   worklist.push(instruction);
@@ -567,7 +608,8 @@ absl::Status HloComputation::RemoveInstructionAndUnusedOperands(
     worklist.pop();

     if (removed.contains(item) || !item->IsDead() ||
-        !IsSafelyRemovable(item, ignore_control_dependencies) ||
+        !IsSafelyRemovable(item, ignore_control_dependencies,
+                           computation_callers) ||
         (item->HasSideEffect() && item != instruction)) {
       continue;
     }
@@ -585,10 +627,9 @@ absl::Status HloComputation::RemoveInstructionAndUnusedOperands(
       (*cleanup)(item);
     }
     if (item->opcode() == HloOpcode::kParameter) {
-      // Note that right now, only parameters inside fusion computations are
-      // considered to be safely removable. We cannot remove a parameter
-      // directly, because it may cause a renumbering of other parameters which
-      // may invalidate some of the pointers in the worklist.
+      // We cannot remove a parameter directly, because it may cause a
+      // renumbering of other parameters which may invalidate some of the
+      // pointers in the worklist.
       parameters_to_be_removed.push_back(item);
     } else {
       TF_RETURN_IF_ERROR(RemoveInstruction(item));
@@ -601,18 +642,47 @@ absl::Status HloComputation::RemoveInstructionAndUnusedOperands(
             [](HloInstruction* a, HloInstruction* b) {
               return a->parameter_number() > b->parameter_number();
             });
+  std::vector<HloInstruction*> callers;
+  if (!parameters_to_be_removed.empty()) {
+    if (parent != nullptr && computation_callers.has_value()) {
+      callers = (*computation_callers)(parent);
+    }
+    // We generally want to use the call graph to determine the caller, as
+    // this back pointer is very fragile; however, it is not reasonable to
+    // expect every caller to pass in the call graph.
+    if (callers.empty() && FusionInstruction() != nullptr) {
+      callers = {FusionInstruction()};
+    }
+  }
+  // Only attempt to remove parameters if we can fix up the caller.
+  if (callers.empty()) {
+    return absl::OkStatus();
+  }
   for (HloInstruction* param : parameters_to_be_removed) {
     int64_t parameter_number = param->parameter_number();
     TF_RETURN_IF_ERROR(RemoveParameter(parameter_number));
-    if (FusionInstruction() != nullptr) {
-      auto operand = FusionInstruction()->mutable_operand(parameter_number);
-      FusionInstruction()->RemoveOperandAt(parameter_number);
-      FusionInstruction()->DetachFrom(operand);
-      if (operand->IsDead() && operand->parent()->IsSafelyRemovable(
-                                   operand, ignore_control_dependencies)) {
+    for (HloInstruction* caller : callers) {
+      // The caller could have been eagerly removed.
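+      // If so, skip it; a dead caller no longer needs its operand list
+      // adjusted.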
+      if (caller->IsDead()) {
+        continue;
+      }
+      auto operand = caller->mutable_operand(parameter_number);
+      caller->RemoveOperandAt(parameter_number);
+      caller->DetachFrom(operand);
+      // Clean up the operand shape embedded in the async-start shape.
+      if (caller->opcode() == HloOpcode::kAsyncStart) {
+        std::vector<Shape>* operand_shapes = caller->mutable_shape()
+                                                 ->mutable_tuple_shapes(0)
+                                                 ->mutable_tuple_shapes();
+        operand_shapes->erase(operand_shapes->begin() + parameter_number);
+      }
+      if (operand->IsDead() &&
+          operand->parent()->IsSafelyRemovable(
+              operand, ignore_control_dependencies, computation_callers)) {
         TF_RETURN_IF_ERROR(
             operand->parent()->RemoveInstructionAndUnusedOperands(
-                operand, cleanup, ignore_control_dependencies));
+                operand, cleanup, ignore_control_dependencies,
+                computation_callers));
       }
     }
   }

diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h
index cbe8fb6ccac4dd..0665663973d017 100644
--- a/third_party/xla/xla/hlo/ir/hlo_computation.h
+++ b/third_party/xla/xla/hlo/ir/hlo_computation.h
@@ -314,7 +314,10 @@ class HloComputation {
       HloInstruction* instruction,
       std::optional<absl::FunctionRef<void(HloInstruction*)>> cleanup =
           std::nullopt,
-      bool ignore_control_dependencies = false);
+      bool ignore_control_dependencies = false,
+      std::optional<absl::FunctionRef<std::vector<HloInstruction*>(
+          const HloComputation*)>>
+          computation_callers = std::nullopt);

   // Set the root of the computation to the given instruction. The instruction
   // must have already been added to the computation. In addition it must have
@@ -757,8 +760,10 @@ class HloComputation {

   // Returns true if the given instruction can be removed from the computation.
   // Parameter instructions cannot be removed without violating invariants of
-  // the HLO computation with the exception of fusion computation. A parameter
-  // instruction is removable for a fusion computation.
+  // the HLO computation, with the exception of computations with trivial
+  // control flow (fusion, call, async call). This is determined by checking
+  // the call graph via computation_callers, which is expected to be
+  // equivalent to CallGraph::GetComputationCallers().
   //
   // Note that IsSafelyRemovable() is a necessary condition to remove an
   // instruction rather than a sufficient condition. For example, instructions
   // with side effect (e.g., Send and Recv) may be removed from a computation,
   // but the transformation must guarantee the invariants relevant to the
   // instructions still hold (e.g., Send and Recv must be removed together to
   // make each channel complete).
-  bool IsSafelyRemovable(const HloInstruction* instruction,
-                         bool ignore_control_dependency = false);
+  bool IsSafelyRemovable(
+      const HloInstruction* instruction, bool ignore_control_dependency = false,
+      std::optional<absl::FunctionRef<std::vector<HloInstruction*>(
+          const HloComputation*)>>
+          computation_callers = std::nullopt) const;

   // Returns a map from an instruction to the group of instructions associated
   // with the same channel.
   // with the same channel. These instructions will be considered as a single
diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD
index cab76b59d355cb..79fa8234711daf 100644
--- a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD
+++ b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD
@@ -805,13 +805,15 @@ cc_library(
         "//xla:util",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/pass:hlo_pass",
+        "//xla/service:call_graph",
+        "//xla/tsl/platform:errors",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
     ],
 )

@@ -827,10 +829,13 @@ xla_cc_test(
         "//xla/hlo/ir:hlo",
         "//xla/hlo/testlib:hlo_hardware_independent_test_base",
         "//xla/hlo/testlib:pattern_matcher_gmock",
+        "//xla/hlo/testlib:verified_hlo_module",
         "//xla/service:pattern_matcher",
         "//xla/tests:literal_test_util",
         "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest_main",
     ],
diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc
index fe18654bf84623..0f99bcfa3f6ba0 100644
--- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc
+++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc
@@ -17,14 +17,18 @@ limitations under the License.

 #include
 #include
+#include
+#include
 #include
 #include
 #include
 #include
 #include

+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
@@ -34,12 +38,12 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/call_graph.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"

 namespace xla {

@@ -47,7 +51,7 @@ namespace {

 // Checks if the instruction is a removable while, given
 // remove_cross_partition_collective_ops.
-bool IsRemovableWhile(HloInstruction* instruction,
+bool IsRemovableWhile(const HloInstruction* instruction,
                       bool remove_cross_partition_collective_ops) {
   if (instruction->opcode() != HloOpcode::kWhile) {
     return false;
@@ -163,39 +167,92 @@ absl::StatusOr<bool> RemoveMultiOutputFusionsUnusedOutputs(
 }
 }  // namespace

 /*static*/ absl::StatusOr<bool> HloDCE::RunOnComputation(
-    HloComputation* computation, bool remove_cross_partition_collective_ops) {
+    HloComputation* computation, bool remove_cross_partition_collective_ops,
+    CallGraph* call_graph) {
   // We do this first, because it may create dead roots which we can clean up
   // next.
   TF_ASSIGN_OR_RETURN(bool changed,
                       RemoveMultiOutputFusionsUnusedOutputs(computation));

+  auto computation_callers =
+      [call_graph](
+          const HloComputation* computation) -> std::vector<HloInstruction*> {
+    if (call_graph == nullptr) {
+      return {};
+    }
+    return call_graph->GetComputationCallers(computation);
+  };
+
   // Remove any dead roots and their dead transitive operands. Collect
   // them into a separate list first to avoid problems with iterating through
   // the computation's instruction while simultaneously removing instructions.
   std::vector<HloInstruction*> dead_roots;
   for (auto* instruction : computation->instructions()) {
-    auto maybe_collective_op = DynCast<HloCollectiveInstruction>(instruction);
-    if (instruction->IsDead() && computation->IsSafelyRemovable(instruction) &&
-        (!instruction->IsCustomCall("Sharding") ||
-         (!instruction->operand(0)->IsRoot() &&
-          instruction->operand(0)->opcode() != HloOpcode::kParameter &&
-          instruction->operand(0)->user_count() == 1)) &&
-        (!instruction->HasSideEffect() ||
-         (remove_cross_partition_collective_ops && maybe_collective_op &&
-          !maybe_collective_op->constrain_layout()) ||
-         IsRemovableWhile(instruction,
-                          remove_cross_partition_collective_ops))) {
-      dead_roots.push_back(instruction);
+    if (!instruction->IsDead()) {
+      continue;
+    }
+    if (!computation->IsSafelyRemovable(
+            instruction,
+            /*ignore_control_dependency=*/false,
+            /*computation_callers=*/computation_callers)) {
+      continue;
+    }
+    // We cannot remove a parameter directly, because it may cause a
+    // renumbering of other parameters, which could invalidate the instruction
+    // pointers collected here; dead parameters are handled separately below.
+    if (instruction->opcode() == HloOpcode::kParameter) {
+      continue;
+    }
+    if (instruction->IsCustomCall("Sharding") &&
+        (instruction->operand(0)->IsRoot() ||
+         instruction->operand(0)->opcode() == HloOpcode::kParameter ||
+         instruction->operand(0)->user_count() != 1)) {
+      continue;
     }
+    if (instruction->HasSideEffect()) {
+      auto maybe_collective_op = DynCast<HloCollectiveInstruction>(instruction);
+      bool allow_collective = remove_cross_partition_collective_ops &&
+                              maybe_collective_op &&
+                              !maybe_collective_op->constrain_layout();
+      bool allow_while =
+          IsRemovableWhile(instruction, remove_cross_partition_collective_ops);
+      if (!allow_collective && !allow_while) {
+        continue;
+      }
+    }
+    dead_roots.push_back(instruction);
   }

   for (HloInstruction* dead_root : dead_roots) {
     VLOG(1) << "Removing dead root " << dead_root->ToString()
             << " and its unused operands";
-    TF_RETURN_IF_ERROR(
-        computation->RemoveInstructionAndUnusedOperands(dead_root));
+    TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(
+        dead_root, /*cleanup=*/std::nullopt,
+        /*ignore_control_dependencies=*/false,
+        /*computation_callers=*/computation_callers));
     changed = true;
   }
+
+  auto parameters = computation->parameter_instructions();
+  // Sort into decreasing order by parameter number; otherwise the renumbering
+  // of parameters when one parameter is deleted will cause issues.
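+  // (Removing parameter i renumbers every parameter after i, so an ascending
+  // visit order would leave stale parameter numbers for later removals.)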
+  absl::c_reverse(parameters);
+  for (HloInstruction* parameter : parameters) {
+    if (parameter->IsDead() &&
+        computation->IsSafelyRemovable(
+            parameter,
+            /*ignore_control_dependency=*/false,
+            /*computation_callers=*/computation_callers)) {
+      VLOG(1) << "Removing dead parameter " << parameter->ToString()
+              << " and its unused operands";
+      TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(
+          parameter, /*cleanup=*/std::nullopt,
+          /*ignore_control_dependencies=*/false,
+          /*computation_callers=*/computation_callers));
+      changed = true;
+    }
+  }
+
   return changed;
 }

@@ -207,6 +264,11 @@ absl::StatusOr<bool> HloDCE::Run(
   VLOG(2) << "Before dce; threads: " << absl::StrJoin(execution_threads, ",");
   XLA_VLOG_LINES(2, module->ToString());

+  std::unique_ptr<CallGraph> call_graph;
+  if (use_call_analysis_) {
+    call_graph = CallGraph::Build(module);
+  }
+
   // Run DCE on each computation. Visit callers before callees so that we
   // cleanup dead get-tuple-element users of MultiOutput fusions before cleaning
   // up the fusion computation. If the same callee is referred to by multiple
@@ -230,8 +292,8 @@ absl::StatusOr<bool> HloDCE::Run(
         execution_threads.contains(computation->execution_thread())) {
       TF_ASSIGN_OR_RETURN(
           bool computation_changed,
-          RunOnComputation(computation,
-                           remove_cross_partition_collective_ops_));
+          RunOnComputation(computation, remove_cross_partition_collective_ops_,
+                           call_graph.get()));
       changed |= computation_changed;
     }

@@ -244,6 +306,18 @@ absl::StatusOr<bool> HloDCE::Run(
       }
     }
   }
+  // Some computations might have been left dangling due to being detached
+  // indirectly. We need to rebuild the call graph to find these.
+  if (use_call_analysis_) {
+    call_graph = CallGraph::Build(module);
+    for (HloComputation* computation :
+         module->computations(execution_threads)) {
+      if (!computation->IsEntryComputation() &&
+          call_graph->GetComputationCallers(computation).empty()) {
+        to_remove.insert(computation);
+      }
+    }
+  }
   for (auto computation : to_remove) {
     // Only remove computations from the specified execution threads.
     if (execution_threads.empty() ||
diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.h b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.h
index a7579ad3538a27..d6f1e4e10803d7 100644
--- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.h
+++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/pass/hlo_pass_interface.h"
+#include "xla/service/call_graph.h"

 namespace xla {

@@ -33,20 +34,23 @@ namespace xla {
 // dead if it is not the entry computation of the module and it is not reachable
 // from the entry computation.
 //
-// This pass does not remove dead parameter instructions, as parameter
-// instructions cannot be deleted.
+// This pass does not remove dead parameter instructions unless call analysis
+// is enabled. Enabling call analysis slows down compilation, and is only
+// beneficial when the graph has not been fully inlined.
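+//
+// A minimal usage sketch with call analysis enabled (this mirrors the tests
+// added below in this patch; `module` is assumed to be an HloModule* the
+// caller already holds):
+//
+//   HloDCE dce(/*remove_cross_partition_collective_ops=*/false,
+//              /*use_call_analysis=*/true);
+//   TF_ASSIGN_OR_RETURN(bool changed, dce.Run(module));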
 class HloDCE : public HloModulePass {
  public:
-  HloDCE() : remove_cross_partition_collective_ops_(false) {}
-  explicit HloDCE(bool remove_cross_partition_collective_ops)
+  explicit HloDCE(bool remove_cross_partition_collective_ops = false,
+                  bool use_call_analysis = false)
       : remove_cross_partition_collective_ops_(
-            remove_cross_partition_collective_ops) {}
+            remove_cross_partition_collective_ops),
+        use_call_analysis_(use_call_analysis) {}
   ~HloDCE() override {}
   absl::string_view name() const override { return "dce"; }

   // Run DCE on a computation.
   static absl::StatusOr<bool> RunOnComputation(
-      HloComputation* computation, bool remove_cross_partition_collective_ops);
+      HloComputation* computation, bool remove_cross_partition_collective_ops,
+      CallGraph* call_graph = nullptr);

   // Run the pass on the given module. Returns whether the module was changed
   // (instructions were removed).
@@ -57,6 +61,7 @@ class HloDCE : public HloModulePass {

  private:
   bool remove_cross_partition_collective_ops_;
+  bool use_call_analysis_;
 };

 }  // namespace xla
diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce_test.cc
index 36c86c76fa4228..b3827ba430fc29 100644
--- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce_test.cc
+++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include
 #include
 #include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
@@ -30,12 +31,14 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
 #include "xla/hlo/testlib/pattern_matcher_gmock.h"
+#include "xla/hlo/testlib/verified_hlo_module.h"
 #include "xla/layout_util.h"
 #include "xla/literal_util.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/shape_util.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/types.h"
 #include "xla/xla_data.pb.h"

@@ -799,6 +802,7 @@ TEST_F(HloDceTest, MultiOutputFusionRemoveUnusedTupleElementAdjustTuple) {
           m::Tuple(m::Negate(), m::Add()).WithShapeEqualTo(&expected_shape)));
   EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);
 }
+
 TEST_F(HloDceTest,
        MultiOutputFusionRemoveUnusedTupleElementWithControlAdjustTupleAndDep) {
   constexpr char kHloString[] = R"(
@@ -832,5 +836,119 @@ TEST_F(HloDceTest,
   EXPECT_EQ(add2->control_predecessors().size(), 1);
   EXPECT_EQ(add2->control_predecessors()[0], fusion);
 }
+
+TEST_F(HloDceTest, UnusedCalledParameter) {
+  constexpr absl::string_view kHlo = R"(
+HloModule main
+
+ENTRY main {
+  arg.0 = s32[] parameter(0)
+  arg.1 = s32[] parameter(1)
+  ROOT call.0 = (s32[]) call(arg.0, arg.1), to_apply={
+    arg.0 = s32[] parameter(0)
+    arg.1 = s32[] parameter(1)
+    ROOT tuple.0 = tuple(arg.0)
+  }
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHlo));
+  HloDCE dce(/*remove_cross_partition_collective_ops=*/false,
+             /*use_call_analysis=*/true);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, dce.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloComputation* main = module->entry_computation();
+  EXPECT_EQ(main->parameter_instructions().size(), 2);
+
+  HloInstruction* call0 = main->root_instruction();
+  ASSERT_EQ(call0->opcode(), HloOpcode::kCall);
+  // arg.1 should have been removed.
+  EXPECT_EQ(call0->operand_count(), 1);
+  EXPECT_EQ(call0->operand(0), main->parameter_instruction(0));
+
+  HloComputation* called_computation = call0->to_apply();
+  EXPECT_EQ(called_computation->parameter_instructions().size(), 1);
+}
+
+TEST_F(HloDceTest, UnusedAsyncParameter) {
+  constexpr absl::string_view kHlo = R"(
+HloModule main
+
+ENTRY main {
+  arg.0 = s32[] parameter(0)
+  arg.1 = s32[] parameter(1)
+  call-start.0 = ((s32[], s32[]), (s32[]), s32[]) call-start(arg.0, arg.1), to_apply={
+    arg.0 = s32[] parameter(0)
+    arg.1 = s32[] parameter(1)
+    ROOT tuple.0 = tuple(arg.0)
+  }, async_execution_thread="thread"
+  ROOT call-done.0 = (s32[]) call-done(call-start.0)
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHlo));
+  HloDCE dce(/*remove_cross_partition_collective_ops=*/false,
+             /*use_call_analysis=*/true);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, dce.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloComputation* main = module->entry_computation();
+  EXPECT_EQ(main->parameter_instructions().size(), 2);
+
+  HloInstruction* call_done0 = main->root_instruction();
+  ASSERT_EQ(call_done0->opcode(), HloOpcode::kAsyncDone);
+  HloInstruction* call_start0 = call_done0->async_chain_start();
+  // arg.1 should have been removed.
+  EXPECT_EQ(call_start0->operand_count(), 1);
+  EXPECT_EQ(call_start0->operand(0), main->parameter_instruction(0));
+
+  HloComputation* async_wrapped_computation =
+      call_start0->async_wrapped_computation();
+  EXPECT_EQ(async_wrapped_computation->parameter_instructions().size(), 1);
+
+  HloComputation* async_computation =
+      call_start0->async_wrapped_instruction()->to_apply();
+  EXPECT_EQ(async_computation->parameter_instructions().size(), 1);
+}
+
+TEST_F(HloDceTest, IndirectComputationRemoval) {
+  constexpr absl::string_view kHlo = R"(
+HloModule main
+
+ENTRY main {
+  arg.0 = s32[] parameter(0)
+  call.0 = (s32[]) call(arg.0), to_apply={
+    arg.0 = s32[] parameter(0)
+    ROOT tuple.0 = tuple(arg.0)
+  }
+  gte.0 = get-tuple-element(call.0), index=0
+  ROOT call.1 = (s32[]) call(gte.0), to_apply={
+    arg.0 = s32[] parameter(0)
+    zero.0 = s32[] constant(0)
+    ROOT tuple.0 = tuple(zero.0)
+  }
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHlo));
+  HloDCE dce(/*remove_cross_partition_collective_ops=*/false,
+             /*use_call_analysis=*/true);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, dce.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // call.0 should be removed.
+  EXPECT_EQ(module->computation_count(), 2);
+
+  HloComputation* main = module->entry_computation();
+  EXPECT_EQ(main->parameter_instructions().size(), 1);
+
+  HloInstruction* call1 = module->entry_computation()->root_instruction();
+  ASSERT_EQ(call1->opcode(), HloOpcode::kCall);
+  EXPECT_EQ(call1->operand_count(), 0);
+}
 }  // namespace
 }  // namespace xla

From 23b6121ab44f5599d9501c71ae0bba20505092e1 Mon Sep 17 00:00:00 2001
From: "A.
Unique TensorFlower" Date: Fri, 11 Apr 2025 22:29:36 -0700 Subject: [PATCH 0612/1324] Automated Code Change PiperOrigin-RevId: 746710504 --- .../xla/service/gpu/conv_layout_normalization.cc | 4 ++-- .../xla/service/gpu/cublas_padding_requirements.cc | 2 +- .../xla/xla/service/gpu/cudnn_support_utils.cc | 6 +++--- .../xla/service/gpu/fusion_dispatch_pipeline.cc | 2 +- .../service/gpu/gpu_latency_hiding_scheduler.cc | 2 +- .../xla/xla/service/gpu/gpu_transfer_manager.cc | 2 +- .../xla/xla/service/gpu/ir_emission_utils.cc | 4 ++-- .../xla/xla/service/gpu/ir_emitter_unnested.cc | 14 +++++++------- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/service/gpu/conv_layout_normalization.cc b/third_party/xla/xla/service/gpu/conv_layout_normalization.cc index ef2d17e776735f..336a1fea11d4d5 100644 --- a/third_party/xla/xla/service/gpu/conv_layout_normalization.cc +++ b/third_party/xla/xla/service/gpu/conv_layout_normalization.cc @@ -45,7 +45,7 @@ absl::StatusOr> UpdateLayoutForCudnnConvolution( hlo->convolution_dimension_numbers(); auto transpose_dim = [&](int64_t dim, const Shape& unnormalized_shape) { - return unnormalized_shape.dimensions_size() - + return unnormalized_shape.dimensions().size() - FindIndex(unnormalized_shape.layout().minor_to_major(), dim) - 1; }; @@ -110,7 +110,7 @@ absl::StatusOr> UpdateLayoutForCudnnConvolution( Shape normalized_shape; if (hlo->shape().IsTuple()) { - TF_RET_CHECK(hlo->shape().tuple_shapes().back().dimensions_size() == 1) + TF_RET_CHECK(hlo->shape().tuple_shapes().back().dimensions().size() == 1) << "The last element in the tuple returned by a convolution Custom " "Call is expected to be an " "allocator of rank one"; diff --git a/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc b/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc index 3e606f93cff619..917f88be91581b 100644 --- a/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc +++ b/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc @@ -59,7 +59,7 @@ bool ShapeRequiresPadding(const Shape& shape, int batch_dimensions_size, // Non-batch dimensions requiring potential padding are placed at higher // indices than batch dimensions. This is because dots are canonicalized prior // to padding. 
- for (int i = batch_dimensions_size; i < shape.dimensions_size(); i++) { + for (int i = batch_dimensions_size; i < shape.dimensions().size(); i++) { if (DimensionRequiresPadding(shape.dimensions(i), shape.element_type(), cc)) { return true; diff --git a/third_party/xla/xla/service/gpu/cudnn_support_utils.cc b/third_party/xla/xla/service/gpu/cudnn_support_utils.cc index 621ca194e6ba43..489ed743dae43c 100644 --- a/third_party/xla/xla/service/gpu/cudnn_support_utils.cc +++ b/third_party/xla/xla/service/gpu/cudnn_support_utils.cc @@ -129,7 +129,7 @@ CudnnInferTransposeForFilterReordering( const Shape& shape, const ConvolutionDimensionNumbers& dimension_numbers) { // A normal filter should have four dimensions: [O, I, H, W] // An already vectorized filter will have five: [O, I/k, H, W, k]; k=4|32 - if (shape.dimensions_size() != 4 && shape.dimensions_size() != 5) { + if (shape.dimensions().size() != 4 && shape.dimensions().size() != 5) { return Internal("Filter shape has unexpected rank."); } @@ -140,7 +140,7 @@ CudnnInferTransposeForFilterReordering( const int64_t dW = dimension_numbers.kernel_spatial_dimensions().at(1); // In case of re-vectorization (rank=5), the missing dimension can be // calculated as Σi(i=0..4)-(dO+dI+dH+dW) - bool revectorize = shape.dimensions_size() == 5; + bool revectorize = shape.dimensions().size() == 5; const int64_t dZ = revectorize ? 10 - dO - dI - dH - dW : -1; const int64_t vsize = revectorize ? shape.dimensions(dZ) : 1; @@ -196,7 +196,7 @@ CudnnInferTransposeForFilterReordering( absl::StatusOr CudnnInferTransposeForBiasReordering(const Shape& shape) { // Expected bias has one dimension: [O] - if (shape.dimensions_size() != 1) { + if (shape.dimensions().size() != 1) { return Internal("Bias shape has unexpected rank."); } if (shape.dimensions(0) % 32 != 0) { diff --git a/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc index 28302ad6ffe6aa..6bc69cec0b3907 100644 --- a/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc +++ b/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc @@ -57,7 +57,7 @@ bool IsSlowLoopTransposeFusion(const HloFusionInstruction* fusion) { // is neither the minormost nor the second minormost dimension in the output, // and the output minormost dimension is swapped with the new minormost // dimension. - int64_t rank = root->shape().dimensions_size(); + int64_t rank = root->shape().dimensions().size(); // The transpose dimension grouper has run, so it should be enough to check // that the minormost dimension's index within the result is smaller than diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc index 359b827ec51f11..fddb4441fc6d0b 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc @@ -174,7 +174,7 @@ HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction( return size; } // Each dynamic dimension size is represented as a S32. 
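     // (A rank-R dynamic shape thus adds R * sizeof(int32_t) bytes of metadata
     // to the returned size estimate.)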
- int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size(); + int64_t metadata_size = sizeof(int32_t) * shape.dimensions().size(); return size + metadata_size; }; } diff --git a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc index a1d90d328b18b8..6a4e62139a4d7f 100644 --- a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc +++ b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc @@ -218,7 +218,7 @@ absl::Status GpuTransferManager::ReadDynamicShapes( for (int i = 0; i < copies.size(); i++) { Shape* dst_shape = copies[i].second; int32_t* dst = h2d_memcpy_dsts[i]; - for (int64_t j = 0; j < dst_shape->dimensions_size(); j++) { + for (int64_t j = 0; j < dst_shape->dimensions().size(); j++) { dst_shape->mutable_dimensions()[j] = dst[j]; } } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 3e222ebdef0801..95554d7ab775b5 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -77,12 +77,12 @@ namespace { // Return whether the given shape is rank 2 excluding the batch dimensions. bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) { - return shape.dimensions_size() == batch_dimensions_size + 2; + return shape.dimensions().size() == batch_dimensions_size + 2; } // Return whether the given shape is rank 1 excluding the batch dimensions. bool IsRank1(const Shape& shape, int64_t batch_dimensions_size) { - return shape.dimensions_size() == batch_dimensions_size + 1; + return shape.dimensions().size() == batch_dimensions_size + 1; } } // namespace diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index d9e09ee69ea464..7e156db5f36730 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -836,7 +836,7 @@ absl::Status IrEmitterUnnested::EmitConvolutionReorderThunk( const HloCustomCallInstruction* instr) { bool has_bias = instr->operand_count() > 1; Shape shape = has_bias ? 
instr->shape().tuple_shapes(0) : instr->shape(); - if (shape.dimensions_size() != 5 || shape.dimensions(4) != 32) { + if (shape.dimensions().size() != 5 || shape.dimensions(4) != 32) { return Internal("Unexpected shape for convolution reorder: %s", instr->ToString()); } @@ -1021,7 +1021,7 @@ absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort( : std::nullopt, operands, results, scratch, options.descending(), Product(operand_shape.dimensions()) / - operand_shape.dimensions(operand_shape.dimensions_size() - 1)); + operand_shape.dimensions(operand_shape.dimensions().size() - 1)); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } @@ -1030,7 +1030,7 @@ absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) { TF_ASSIGN_OR_RETURN(CholeskyOptions options, instr->backend_config()); const Shape& shape = instr->operand(0)->shape(); - int ndim = shape.dimensions_size(); + int ndim = shape.dimensions().size(); CHECK_GE(ndim, 2); int64_t n = shape.dimensions(ndim - 1); @@ -1305,8 +1305,8 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall( /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape))); } - int64_t m = b_shape.dimensions(b_shape.dimensions_size() - 2); - int64_t n = b_shape.dimensions(b_shape.dimensions_size() - 1); + int64_t m = b_shape.dimensions(b_shape.dimensions().size() - 2); + int64_t n = b_shape.dimensions(b_shape.dimensions().size() - 1); int64_t batch_size = std::accumulate( b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1}, [](int64_t a, int64_t b) { return a * b; }); @@ -1347,11 +1347,11 @@ absl::Status IrEmitterUnnested::EmitTopKCustomCall( auto top_elements_shape = shape.tuple_shapes()[0]; auto indices_shape = shape.tuple_shapes()[1]; - TF_RET_CHECK(data_shape.dimensions_size() <= 2) << "Invalid input shape."; + TF_RET_CHECK(data_shape.dimensions().size() <= 2) << "Invalid input shape."; TF_RET_CHECK(indices_shape.element_type() == PrimitiveType::S32) << "Indices should be S32."; - bool has_batch = data_shape.dimensions_size() == 2; + bool has_batch = data_shape.dimensions().size() == 2; auto [batch_size, n, k] = has_batch ? std::tuple{data_shape.dimensions(0), From 9c8435bdadd75298ca5795804b0c9d4c50bc1d3f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 22:55:01 -0700 Subject: [PATCH 0613/1324] Automated Code Change PiperOrigin-RevId: 746715431 --- .../xla/xla/service/gpu/triton_tiling_propagation.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 961b35b51cdd3a..1f2a3f704a9bc2 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -157,7 +157,7 @@ using FragmentOrders = DimensionOrder::FragmentOrders; /*static*/ DimensionOrder DimensionOrder::FromDotOperandOrOutput( const HloInstruction& hlo, const int split_k_dimension_index) { DimensionOrder dim_order; - dim_order.tensor_fragments_order_.reserve(hlo.shape().dimensions_size()); + dim_order.tensor_fragments_order_.reserve(hlo.shape().dimensions().size()); for (const int i : hlo.shape().layout().minor_to_major()) { int target_dim_number = i; if (i == split_k_dimension_index) { @@ -611,8 +611,8 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( // Group subdimensions by iterating over them in the same order as over // full dimensions and matching by total size. 
std::vector> src_physical; - src_physical.reserve(src.shape().dimensions_size()); - if (src_fragments_order.size() < src.shape().dimensions_size()) { + src_physical.reserve(src.shape().dimensions().size()); + if (src_fragments_order.size() < src.shape().dimensions().size()) { // It's not supported currently to further propagate dimensions after // reaching a trivial sized tensor. We could probably support it, but now we // just prevent crashing here. @@ -676,7 +676,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( if (reduce->dimensions().size() != 1) { return FusionDecision::Forbid("Unsupported reduction."); } else if (reduce->dimensions().front() != - reduce->operand(0)->shape().dimensions_size() - 1) { + reduce->operand(0)->shape().dimensions().size() - 1) { return FusionDecision::Forbid("Only row reductions are supported."); } } else if (hlo.opcode() == HloOpcode::kConcatenate) { From e55b7c1c51242ed1de0b847f7f710786c9c391a2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Apr 2025 23:14:45 -0700 Subject: [PATCH 0614/1324] Automated Code Change PiperOrigin-RevId: 746719640 --- tensorflow/python/BUILD | 9 +++++++++ tensorflow/python/mlir_wrapper.cc | 1 - tensorflow/python/py_exception_registry_wrapper.cc | 2 -- tensorflow/python/pywrap_dtensor_device.cc | 8 ++++++++ tensorflow/python/tfe_wrapper_monitoring_reader.cc | 2 -- 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index fa305cd653fc41..c64325e230024e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1311,6 +1311,7 @@ tf_python_pybind_extension( "//tensorflow/core/config:flags_headers", "//tensorflow/core/framework:pywrap_required_hdrs", "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/lib/monitoring:cell_reader", "//tensorflow/core/platform", "//tensorflow/python/lib/core:pybind11_lib", "//tensorflow/python/lib/core:pybind11_status", @@ -1372,10 +1373,18 @@ tf_python_pybind_extension( features = ["-layering_check"], deps = [ ":pywrap_densor_device_headers", + "//tensorflow/core:protos_all_cc", + "//tensorflow/dtensor/cc:dtensor_device_cc", + "//tensorflow/dtensor/cc:tensor_layout", "//tensorflow/dtensor/proto:layout_proto_cc", + "//tensorflow/python/eager:pywrap_tfe_lib", "//tensorflow/python/lib/core:pybind11_lib", "//tensorflow/python/lib/core:pybind11_status_headers", + "//tensorflow/python/lib/core:safe_pyobject_ptr", + "//tensorflow/python/util:cpp_python_util", "//third_party/python_runtime:headers", # buildcleaner: keep + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc index 662f70ba3e112b..e00ff095b62e2c 100644 --- a/tensorflow/python/mlir_wrapper.cc +++ b/tensorflow/python/mlir_wrapper.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 diff --git a/tensorflow/python/py_exception_registry_wrapper.cc b/tensorflow/python/py_exception_registry_wrapper.cc index 0f1a17ccb1816e..54189c779c6210 100644 --- a/tensorflow/python/py_exception_registry_wrapper.cc +++ b/tensorflow/python/py_exception_registry_wrapper.cc @@ -15,8 +15,6 @@ limitations under the License. 
#include -#include - #include "pybind11/attr.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 diff --git a/tensorflow/python/pywrap_dtensor_device.cc b/tensorflow/python/pywrap_dtensor_device.cc index a055f784d382a3..78121b8f1dadc5 100644 --- a/tensorflow/python/pywrap_dtensor_device.cc +++ b/tensorflow/python/pywrap_dtensor_device.cc @@ -13,16 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include +#include #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/dtensor/cc/dtensor_device.h" #include "tensorflow/dtensor/cc/tensor_layout.h" +#include "tensorflow/dtensor/proto/layout.pb.h" #include "tensorflow/python/eager/pywrap_tensor.h" #include "tensorflow/python/eager/pywrap_tfe.h" #include "tensorflow/python/lib/core/pybind11_lib.h" diff --git a/tensorflow/python/tfe_wrapper_monitoring_reader.cc b/tensorflow/python/tfe_wrapper_monitoring_reader.cc index 5496065a572255..c906a9f851a260 100644 --- a/tensorflow/python/tfe_wrapper_monitoring_reader.cc +++ b/tensorflow/python/tfe_wrapper_monitoring_reader.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "Python.h" #include "pybind11/complex.h" // from @pybind11 #include "pybind11/functional.h" // from @pybind11 From 6bc632fa1c967b64f70d1afb87c6722e98839feb Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sat, 12 Apr 2025 00:51:56 -0700 Subject: [PATCH 0615/1324] Automated Code Change PiperOrigin-RevId: 746739880 --- .../xla/xla/hlo/builder/lib/approx_topk.cc | 8 ++--- .../xla/xla/hlo/builder/lib/arithmetic.cc | 2 +- .../xla/hlo/builder/lib/dynamic_shaped_ops.cc | 16 ++++----- .../xla/hlo/builder/lib/lu_decomposition.cc | 2 +- third_party/xla/xla/hlo/builder/lib/matrix.cc | 20 +++++------ third_party/xla/xla/hlo/builder/lib/prng.cc | 28 +++++++-------- third_party/xla/xla/hlo/builder/lib/qr.cc | 12 +++---- .../xla/xla/hlo/builder/lib/quantize.h | 8 ++--- .../xla/hlo/builder/lib/self_adjoint_eig.cc | 2 +- .../hlo/builder/lib/self_adjoint_eig_test.cc | 4 +-- .../xla/xla/hlo/builder/lib/slicing.cc | 36 +++++++++---------- .../xla/xla/hlo/builder/lib/sorting.cc | 22 ++++++------ third_party/xla/xla/hlo/builder/lib/svd.cc | 14 ++++---- .../xla/xla/hlo/builder/lib/svd_test.cc | 2 +- .../xla/xla/hlo/builder/lib/tridiagonal.cc | 16 ++++----- 15 files changed, 96 insertions(+), 96 deletions(-) diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk.cc b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc index dcdcddbff6be80..9d133a351f9beb 100644 --- a/third_party/xla/xla/hlo/builder/lib/approx_topk.cc +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc @@ -49,7 +49,7 @@ absl::StatusOr> GetOperandTypes( for (int i = 0; i < num_operands; ++i) { const auto& op_shape = operands_shapes[i]; const auto& init_shape = init_values_shapes[i]; - if (op_shape.dimensions_size() == 0) { + if (op_shape.dimensions().size() == 0) { return InvalidArgument("ApproxTopK operands must have rank 1+."); } if (!ShapeUtil::CompatibleIgnoringElementType(operands_shapes[0], @@ -115,7 +115,7 @@ XlaOp AggregateToTopKBuilder(XlaBuilder* builder, int64_t reduction_dim, const XlaComputation& comparator) { auto operands_shapes = builder->GetOperandShapes(operands).value(); - int64_t rank = operands_shapes[0].dimensions_size(); + int64_t rank = operands_shapes[0].dimensions().size(); int64_t num_operands = operands.size(); if (top_k == 1) { @@ -176,7 +176,7 @@ XlaOp ApproxTopK(XlaBuilder* builder, absl::Span operands, return builder->ReportError(status_or_optypes.status()); } auto op_types = status_or_optypes.value(); - int64_t rank = operands_shapes[0].dimensions_size(); + int64_t rank = operands_shapes[0].dimensions().size(); if (reduction_dim < 0 || reduction_dim >= rank) { return builder->ReportError( InvalidArgument("reduction_dim should range in [0,%d)", rank)); @@ -263,7 +263,7 @@ XlaOp ApproxTopKFallback(XlaBuilder* builder, absl::Span operands, bool aggregate_to_topk, int64_t reduction_input_size_override) { auto operands_shapes = builder->GetOperandShapes(operands).value(); - int64_t rank = operands_shapes[0].dimensions_size(); + int64_t rank = operands_shapes[0].dimensions().size(); uint64_t n = operands_shapes[0].dimensions(reduction_dim); // Align the output size with ApproxTopK. 
auto status_or_approx_output_size = ApproxTopKReductionOutputSize( diff --git a/third_party/xla/xla/hlo/builder/lib/arithmetic.cc b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc index fc2dfe56298b6b..84908846dae705 100644 --- a/third_party/xla/xla/hlo/builder/lib/arithmetic.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc @@ -110,7 +110,7 @@ XlaOp Any(XlaOp predicates) { XlaComputation logical_or = CreateScalarOrComputation(PRED, builder); TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, builder->GetShape(predicates)); - std::vector all_dimensions(predicates_shape.dimensions_size()); + std::vector all_dimensions(predicates_shape.dimensions().size()); std::iota(all_dimensions.begin(), all_dimensions.end(), 0); return Reduce(predicates, f, logical_or, all_dimensions); }); diff --git a/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc index 1428ad8e0952b6..44a2edc1689768 100644 --- a/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc +++ b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc @@ -66,8 +66,8 @@ Shape FindMaxShape(absl::Span shapes) { Shape result = *shapes[0]; for (const Shape* shape : shapes) { - CHECK(result.dimensions_size() == shape->dimensions_size()); - for (int64_t dim = 0; dim < result.dimensions_size(); ++dim) { + CHECK(result.dimensions().size() == shape->dimensions().size()); + for (int64_t dim = 0; dim < result.dimensions().size(); ++dim) { if (shape->dimensions(dim) > result.dimensions(dim)) { result.set_dimensions(dim, shape->dimensions(dim)); } @@ -106,15 +106,15 @@ absl::StatusOr ReconsileBranchDifference(const Shape& left_branch_shape, "right_branch_shape should not be a tuple, received %s", right_branch_shape.DebugString()); } - if (left_branch_shape.dimensions_size() != - right_branch_shape.dimensions_size()) { + if (left_branch_shape.dimensions().size() != + right_branch_shape.dimensions().size()) { return InvalidArgument( "left_branch_shape.dimensions_size() != " "right_branch_shape.dimensions_size() (%d vs %d)", - left_branch_shape.dimensions_size(), - right_branch_shape.dimensions_size()); + left_branch_shape.dimensions().size(), + right_branch_shape.dimensions().size()); } - for (int64_t dim = 0; dim < left_branch_shape.dimensions_size(); ++dim) { + for (int64_t dim = 0; dim < left_branch_shape.dimensions().size(); ++dim) { XlaOp original_dim = GetDimensionSize(result, dim); if (left_branch_shape.dimensions(dim) < right_branch_shape.dimensions(dim)) { @@ -273,7 +273,7 @@ absl::StatusOr SetAllDimensionSizes(ValueInference* value_inference, TF_RETURN_IF_ERROR(builder->GetCurrentStatus()); TF_ASSIGN_OR_RETURN(auto shape_ptr, builder->GetShapePtr(operand)); - for (int64_t i = 0; i < shape_ptr->dimensions_size(); ++i) { + for (int64_t i = 0; i < shape_ptr->dimensions().size(); ++i) { // If a dimension is dynamic, call set-dimension-size on the output. 
auto dim_size = xla::Slice(size_vector, {i}, {i + 1}, {1}); dim_size = xla::Reshape(dim_size, {}); diff --git a/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc index 4f5f5bb31bfe7c..eb9d541019d731 100644 --- a/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc +++ b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc @@ -33,7 +33,7 @@ LuDecompositionResult LuDecomposition(XlaOp a) { XlaBuilder* builder = a.builder(); XlaOp result = builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int ndims = a_shape.dimensions_size(); + const int ndims = a_shape.dimensions().size(); TF_RET_CHECK(ndims >= 2); const int64_t m = ShapeUtil::GetDimension(a_shape, -2); const int64_t n = ShapeUtil::GetDimension(a_shape, -1); diff --git a/third_party/xla/xla/hlo/builder/lib/matrix.cc b/third_party/xla/xla/hlo/builder/lib/matrix.cc index ddfbc000ad4da2..e9fe29ea83ee6b 100644 --- a/third_party/xla/xla/hlo/builder/lib/matrix.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix.cc @@ -64,7 +64,7 @@ XlaOp GetDiagonalMask(XlaOp x, int diagonal) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - auto n_dims = static_cast(shape.dimensions_size()); + auto n_dims = static_cast(shape.dimensions().size()); TF_RET_CHECK(n_dims >= 2); auto m = shape.dimensions(n_dims - 2); auto n = shape.dimensions(n_dims - 1); @@ -82,7 +82,7 @@ XlaOp GetMatrixDiagonal(XlaOp x, int k) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - auto n_dims = static_cast(shape.dimensions_size()); + auto n_dims = static_cast(shape.dimensions().size()); TF_RET_CHECK(n_dims >= 2); const int64_t m = shape.dimensions(n_dims - 2); const int64_t n = shape.dimensions(n_dims - 1); @@ -116,7 +116,7 @@ XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - auto n_dims = static_cast(shape.dimensions_size()); + auto n_dims = static_cast(shape.dimensions().size()); TF_RET_CHECK(n_dims >= 2); const int64_t m = shape.dimensions(n_dims - 2); const int64_t n = shape.dimensions(n_dims - 1); @@ -180,7 +180,7 @@ XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) { return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(matrix)); TF_ASSIGN_OR_RETURN(Shape diag_shape, builder->GetShape(diag)); - auto n_dims = static_cast(shape.dimensions_size()); + auto n_dims = static_cast(shape.dimensions().size()); TF_RET_CHECK(n_dims >= 2); const int64_t m = shape.dimensions(n_dims - 2); const int64_t n = shape.dimensions(n_dims - 1); @@ -195,7 +195,7 @@ XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) { if (pad_high != 0) { PaddingConfig padding_config; - for (int64_t i = 0; i < diag_shape.dimensions_size() - 1; ++i) { + for (int64_t i = 0; i < diag_shape.dimensions().size() - 1; ++i) { auto* dims = padding_config.add_dimensions(); dims->set_edge_padding_low(0); dims->set_interior_padding(0); @@ -218,7 +218,7 @@ XlaOp TriangleMask(XlaOp x, int diagonal) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64_t 
n_dims = shape.dimensions_size(); + const int64_t n_dims = shape.dimensions().size(); TF_RET_CHECK(n_dims >= 2); const int64_t m = shape.dimensions(n_dims - 2); const int64_t n = shape.dimensions(n_dims - 1); @@ -245,7 +245,7 @@ XlaOp Symmetrize(XlaOp x, bool lower) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - if (shape.dimensions_size() < 2) { + if (shape.dimensions().size() < 2) { return InvalidArgument( "Argument to symmetrize must have >= 2 dimensions, got %s", shape.ToString()); @@ -734,8 +734,8 @@ XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); TF_ASSIGN_OR_RETURN( auto einsum_config_numeric, - ParseEinsumString(einsum_config, x_shape.dimensions_size(), - y_shape.dimensions_size())); + ParseEinsumString(einsum_config, x_shape.dimensions().size(), + y_shape.dimensions().size())); return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1], einsum_config_numeric[2], precision, preferred_element_type, grad_x, grad_y); @@ -752,7 +752,7 @@ XlaOp TransposeInMinorDims(XlaOp x) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64_t n_dims = shape.dimensions_size(); + const int64_t n_dims = shape.dimensions().size(); TF_RET_CHECK(n_dims >= 2); std::vector permutation(n_dims); std::iota(permutation.begin(), permutation.end(), 0); diff --git a/third_party/xla/xla/hlo/builder/lib/prng.cc b/third_party/xla/xla/hlo/builder/lib/prng.cc index 2a43e0e4d3b638..661981f19c17ea 100644 --- a/third_party/xla/xla/hlo/builder/lib/prng.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng.cc @@ -156,7 +156,7 @@ std::pair GetThreeFryInputsAndUpdatedState( // initial_state is an R1, so reshape it to a scalar. auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions()); int64_t trailing_dims_product = 1; - for (int64_t i = shape.dimensions_size() - 1; i >= 0; --i) { + for (int64_t i = shape.dimensions().size() - 1; i >= 0; --i) { if (shape.dimensions(i) < 2) { continue; } @@ -181,7 +181,7 @@ struct SplitShapePair { // Split the shape on a dimension > 1 into two halves. SplitShapePair SplitShapeIntoHalves(const Shape& shape) { SplitShapePair pair; - if (shape.dimensions_size() == 0) { + if (shape.dimensions().size() == 0) { pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), {1}); pair.concat_shape = ShapeUtil::MakeShape(shape.element_type(), {2}); pair.split_dim = 0; @@ -189,7 +189,7 @@ SplitShapePair SplitShapeIntoHalves(const Shape& shape) { return pair; } pair.split_dim = -1; - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { if (shape.dimensions(i) % 2 == 0) { pair.split_dim = i; break; @@ -197,7 +197,7 @@ SplitShapePair SplitShapeIntoHalves(const Shape& shape) { } if (pair.split_dim == -1) { // No even dims. Find a dimension with maximum size. 
- for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { if (pair.split_dim == -1 || shape.dimensions(i) > shape.dimensions(pair.split_dim)) { pair.split_dim = i; @@ -209,7 +209,7 @@ SplitShapePair SplitShapeIntoHalves(const Shape& shape) { } std::vector half_shape_dims; std::vector concat_shape_dims; - const auto rank = shape.dimensions_size(); + const auto rank = shape.dimensions().size(); half_shape_dims.reserve(rank + 1); concat_shape_dims.reserve(rank + 1); for (int64_t i = 0; i < rank; ++i) { @@ -236,7 +236,7 @@ SplitShapePair SplitShapeIntoHalves(const Shape& shape) { XlaOp CombineShapePair(absl::Span pair, const SplitShapePair& shape_pair, const Shape& original_shape) { - if (original_shape.dimensions_size() == 0) { + if (original_shape.dimensions().size() == 0) { return Reshape(pair[0], {}); } XlaBuilder* builder = pair[0].builder(); @@ -248,10 +248,10 @@ XlaOp CombineShapePair(absl::Span pair, reshape_dims[shape_pair.split_dim] = RoundUpTo(pre_split_size, 2); result = Reshape(result, reshape_dims); if (reshape_dims[shape_pair.split_dim] != pre_split_size) { - result = - Slice(result, std::vector(original_shape.dimensions_size(), 0), - original_shape.dimensions(), - std::vector(original_shape.dimensions_size(), 1)); + result = Slice(result, + std::vector(original_shape.dimensions().size(), 0), + original_shape.dimensions(), + std::vector(original_shape.dimensions().size(), 1)); } return result; } @@ -735,15 +735,15 @@ RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state, // Separate the bits into two groups to perform the Box-Muller transform. XlaOp bits_0 = Slice(bits_state.value, - std::vector(shape_pair.half_shape.dimensions_size(), 0), + std::vector(shape_pair.half_shape.dimensions().size(), 0), shape_pair.half_shape.dimensions(), - std::vector(shape_pair.half_shape.dimensions_size(), 1)); - std::vector bits_1_starts(shape_pair.half_shape.dimensions_size(), + std::vector(shape_pair.half_shape.dimensions().size(), 1)); + std::vector bits_1_starts(shape_pair.half_shape.dimensions().size(), 0); bits_1_starts[shape_pair.new_concat_dim] = 1; XlaOp bits_1 = Slice( bits_state.value, bits_1_starts, shape_pair.concat_shape.dimensions(), - std::vector(shape_pair.half_shape.dimensions_size(), 1)); + std::vector(shape_pair.half_shape.dimensions().size(), 1)); std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1); // Put the numbers in the two groups back to form the requested shape. 
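For context on the prng.cc hunk above: BoxMullerTransform maps two uniform
samples u1, u2 in (0, 1] to two independent standard normal samples. A minimal
scalar sketch of that standard transform (illustrative only, not XLA's
vectorized implementation):

    #include <cmath>
    #include <utility>

    // Box-Muller: u1, u2 ~ U(0, 1] yield z0, z1 ~ N(0, 1).
    std::pair<float, float> BoxMuller(float u1, float u2) {
      const float kTwoPi = 6.28318530717958647692f;
      const float r = std::sqrt(-2.0f * std::log(u1));  // radial component
      const float theta = kTwoPi * u2;                  // angular component
      return {r * std::cos(theta), r * std::sin(theta)};
    }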
diff --git a/third_party/xla/xla/hlo/builder/lib/qr.cc b/third_party/xla/xla/hlo/builder/lib/qr.cc index 45dcd087e2313a..2118d54f345d4f 100644 --- a/third_party/xla/xla/hlo/builder/lib/qr.cc +++ b/third_party/xla/xla/hlo/builder/lib/qr.cc @@ -36,7 +36,7 @@ QrDecomposition Qr(XlaOp a) { auto result = [&]() -> absl::StatusOr { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int num_dims = a_shape.dimensions_size(); + const int num_dims = a_shape.dimensions().size(); if (num_dims < 2) { return InvalidArgument( "Arguments to QR must have rank >= 2: got shape %s", @@ -70,12 +70,12 @@ XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus) { return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape taus_shape, builder->GetShape(taus)); - if (a_shape.dimensions_size() < 2) { + if (a_shape.dimensions().size() < 2) { return InvalidArgument( "Matrix `a` must have >= 2 dimensions: got shape %s", a_shape.ToString()); } - if (taus_shape.dimensions_size() + 1 != a_shape.dimensions_size()) { + if (taus_shape.dimensions().size() + 1 != a_shape.dimensions().size()) { return InvalidArgument( "Matrix `taus` must have one fewer dimension than `a`: got shapes " "%s and %s", @@ -91,10 +91,10 @@ XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus) { } absl::Span a_batch_dims = absl::MakeConstSpan( a_shape.dimensions().begin(), - a_shape.dimensions().begin() + a_shape.dimensions_size() - 2); + a_shape.dimensions().begin() + a_shape.dimensions().size() - 2); absl::Span taus_batch_dims = absl::MakeConstSpan( taus_shape.dimensions().begin(), - taus_shape.dimensions().begin() + taus_shape.dimensions_size() - 1); + taus_shape.dimensions().begin() + taus_shape.dimensions().size() - 1); const int64_t k = ShapeUtil::GetDimension(taus_shape, -1); if (a_shape.element_type() != taus_shape.element_type() || a_batch_dims != taus_batch_dims || k > n) { @@ -125,7 +125,7 @@ void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r) { t = SliceInMinorDims(qr.q_and_r, {0, 0}, {m, m}); } else { t = PadInDim(qr.q_and_r, Zero(a.builder(), a_shape.element_type()), - a_shape.dimensions_size() - 1, /*pad_lo=*/0, + a_shape.dimensions().size() - 1, /*pad_lo=*/0, /*pad_hi=*/m - n); } q = ProductOfElementaryHouseholderReflectors(t, qr.taus); diff --git a/third_party/xla/xla/hlo/builder/lib/quantize.h b/third_party/xla/xla/hlo/builder/lib/quantize.h index d0126f0c021b2f..e5d427d5b350dc 100644 --- a/third_party/xla/xla/hlo/builder/lib/quantize.h +++ b/third_party/xla/xla/hlo/builder/lib/quantize.h @@ -120,11 +120,11 @@ inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, bit_mask |= 0x000000ff; } - std::vector shift_transpose_dimensions(shape.dimensions_size()); + std::vector shift_transpose_dimensions(shape.dimensions().size()); std::iota(shift_transpose_dimensions.begin(), shift_transpose_dimensions.end(), 0); shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, - shape.dimensions_size()); + shape.dimensions().size()); // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. 
XlaOp shifted_input = ShiftRightLogical( @@ -154,7 +154,7 @@ inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); } - std::vector transpose_dimensions(shape.dimensions_size()); + std::vector transpose_dimensions(shape.dimensions().size()); std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); @@ -171,7 +171,7 @@ inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, } // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. - std::vector result_dimensions(shape.dimensions_size()); + std::vector result_dimensions(shape.dimensions().size()); std::iota(result_dimensions.begin(), result_dimensions.end(), 0); std::reverse(result_dimensions.begin(), result_dimensions.end()); diff --git a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc index 09eb0b9457e838..ba651261751dd0 100644 --- a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc @@ -37,7 +37,7 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64_t max_iter, XlaBuilder* builder = a.builder(); XlaOp result = builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int64_t num_dims = a_shape.dimensions_size(); + const int64_t num_dims = a_shape.dimensions().size(); if (num_dims < 2) { return InvalidArgument( "Arguments to Eigen decomposition must have rank >= 2: got shape %s.", diff --git a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc index 192f00a9ac4a37..e46c284083bb94 100644 --- a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc @@ -126,10 +126,10 @@ XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { Shape shape = builder->GetShape(result.v).value(); absl::Span out_dims = shape.dimensions(); - std::vector broadcast_dims(shape.dimensions_size() - 1); + std::vector broadcast_dims(shape.dimensions().size() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); - broadcast_dims[shape.dimensions_size() - 2] = shape.dimensions_size() - 1; + broadcast_dims[shape.dimensions().size() - 2] = shape.dimensions().size() - 1; auto vw = Mul(result.v, BroadcastInDim(ConvertElementType(result.w, shape.element_type()), diff --git a/third_party/xla/xla/hlo/builder/lib/slicing.cc b/third_party/xla/xla/hlo/builder/lib/slicing.cc index cc0d82e691bb88..ae0f6b987497a1 100644 --- a/third_party/xla/xla/hlo/builder/lib/slicing.cc +++ b/third_party/xla/xla/hlo/builder/lib/slicing.cc @@ -44,7 +44,7 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span start, TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64_t n_dims = shape.dimensions_size(); + const int64_t n_dims = shape.dimensions().size(); TF_RET_CHECK(n_minor_dims <= n_dims); auto major_dims = shape.dimensions().subspan( /*pos=*/0, @@ -69,7 +69,7 @@ XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64_t n_dims 
= shape.dimensions_size(); + const int64_t n_dims = shape.dimensions().size(); const int64_t start_size = start.size(); TF_RET_CHECK(start_size == n_dims); @@ -89,7 +89,7 @@ XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64_t n_dims = shape.dimensions_size(); + const int64_t n_dims = shape.dimensions().size(); const int64_t n_minor_dims = start.size(); TF_RET_CHECK(n_minor_dims <= n_dims); std::vector padded_start(n_dims, 0); @@ -113,7 +113,7 @@ absl::StatusOr> PrependZerosInMajorDims( XlaOp x, absl::Span starts) { XlaBuilder* builder = x.builder(); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64_t n_dims = shape.dimensions_size(); + const int64_t n_dims = shape.dimensions().size(); auto zero = ConstantR0(builder, 0); std::vector padded_starts(n_dims, zero); for (int i = 0; i < starts.size(); ++i) { @@ -129,7 +129,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64_t n_dims = shape.dimensions_size(); + const int64_t n_dims = shape.dimensions().size(); int64_t n_minor_dims = starts.size(); TF_RET_CHECK(n_minor_dims == sizes.size()); TF_RET_CHECK(n_minor_dims <= n_dims); @@ -161,15 +161,15 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { index = ConvertElementType(index, U32); index_shape.set_element_type(U32); } - if (index_shape.dimensions_size() == 1) { + if (index_shape.dimensions().size() == 1) { return TorchIndexSelect(input, index, 0); } if (!sparse) { std::vector index_broadcast_dims; std::vector input_broadcast_dims; std::vector sizes; - sizes.reserve(index_shape.dimensions_size()); - for (int64_t i = 0; i < index_shape.dimensions_size(); ++i) { + sizes.reserve(index_shape.dimensions().size()); + for (int64_t i = 0; i < index_shape.dimensions().size(); ++i) { if (i < dim) { input_broadcast_dims.push_back(i); index_broadcast_dims.push_back(i); @@ -200,8 +200,8 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { ShapeUtil::AppendMajorDimension(1, &index_shape); std::vector to_concat; - to_concat.reserve(input_shape.dimensions_size()); - for (int64_t i = 0; i < input_shape.dimensions_size(); ++i) { + to_concat.reserve(input_shape.dimensions().size()); + for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { if (i == dim) { to_concat.push_back(Reshape(index, index_shape.dimensions())); } else { @@ -209,11 +209,11 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { } } XlaOp gather_indices = - ConcatInDim(builder, to_concat, input_shape.dimensions_size()); - std::vector slice_sizes(input_shape.dimensions_size(), 1); + ConcatInDim(builder, to_concat, input_shape.dimensions().size()); + std::vector slice_sizes(input_shape.dimensions().size(), 1); GatherDimensionNumbers gather_dnums; - gather_dnums.set_index_vector_dim(input_shape.dimensions_size()); - for (int64_t i = 0; i < input_shape.dimensions_size(); ++i) { + gather_dnums.set_index_vector_dim(input_shape.dimensions().size()); + for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { gather_dnums.add_collapsed_slice_dims(i); gather_dnums.add_start_index_map(i); } @@ -229,9 +229,9 @@ XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, TF_ASSIGN_OR_RETURN(Shape 
input_shape, builder->GetShape(input)); std::vector index_broadcast_dims; std::vector sizes; - const auto rank = index_shape.dimensions_size(); + const auto rank = index_shape.dimensions().size(); sizes.reserve(rank + 1); - for (int64_t i = 0; i < index_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < index_shape.dimensions().size(); ++i) { if (i < dim) { index_broadcast_dims.push_back(i); } else { @@ -278,7 +278,7 @@ XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim, } std::vector slice_sizes = SpanToVector(input_shape.dimensions()); GatherDimensionNumbers gather_dnums; - gather_dnums.set_index_vector_dim(index_shape.dimensions_size()); + gather_dnums.set_index_vector_dim(index_shape.dimensions().size()); if (batch_dims > 0) { ShapeUtil::AppendMajorDimension(1, &index_shape); std::vector to_concat; @@ -290,7 +290,7 @@ XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim, to_concat.push_back(Reshape(index, index_shape.dimensions())); index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim()); } - for (int64_t i = 0; i < input_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { if (i < batch_dims || i == dim) { slice_sizes[i] = std::min(slice_sizes[i], 1); gather_dnums.add_collapsed_slice_dims(i); diff --git a/third_party/xla/xla/hlo/builder/lib/sorting.cc b/third_party/xla/xla/hlo/builder/lib/sorting.cc index b4b59beb3bbc3a..f8958f34f26c24 100644 --- a/third_party/xla/xla/hlo/builder/lib/sorting.cc +++ b/third_party/xla/xla/hlo/builder/lib/sorting.cc @@ -37,7 +37,7 @@ XlaOp TopK(XlaOp input, int64_t k, PrimitiveType index_type) { XlaBuilder* const builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); - int last_dim = input_shape.dimensions_size() - 1; + int last_dim = input_shape.dimensions().size() - 1; int64_t last_dim_size = input_shape.dimensions(last_dim); // TODO(b/148796364): tune these constants for better performance. const int64_t kPerPartitionSize = 8192; // 2^13 @@ -56,7 +56,7 @@ XlaOp TopK(XlaOp input, int64_t k, PrimitiveType index_type) { Shape iota_shape = ShapeUtil::MakeShape(index_type, input_shape.dimensions()); XlaOp iota = Iota(builder, iota_shape, last_dim); - for (int64_t i = 0; i < input_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { if (input_shape.is_dynamic_dimension(i)) { // Propagate dynamic dimension from inputs to iota. 
iota = SetDimensionSize(iota, GetDimensionSize(input, i), i); @@ -79,13 +79,13 @@ XlaOp TopK(XlaOp input, int64_t k, PrimitiveType index_type) { (input_shape.element_type() == BF16 && last_dim_size < kLow16BitsLimit && (last_dim_size < kMaxLastDimSizeForSmallBatches || - (input_shape.dimensions_size() == 2 && + (input_shape.dimensions().size() == 2 && input_shape.dimensions(0) >= kSmallBatchSizeThreshold))); - std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector start_indices(input_shape.dimensions().size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; - std::vector strides(input_shape.dimensions_size(), 1); + std::vector strides(input_shape.dimensions().size(), 1); XlaOp values; XlaOp indices; @@ -165,7 +165,7 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions, XlaBuilder* const builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); - int last_dim = input_shape.dimensions_size() - 1; + int last_dim = input_shape.dimensions().size() - 1; // Calculate per partition size. auto input_dims = input_shape.dimensions(); int64_t last_dim_size = input_shape.dimensions(last_dim); @@ -179,7 +179,7 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions, Shape iota_shape = ShapeUtil::MakeShape(index_type, input_shape.dimensions()); XlaOp iota = Iota(builder, iota_shape, last_dim); - for (int64_t i = 0; i < input_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { if (input_shape.is_dynamic_dimension(i)) { // Propagate dynamic dimension from inputs to iota. iota = SetDimensionSize(iota, GetDimensionSize(input, i), i); @@ -213,9 +213,9 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions, sliced_indices.builder()), last_dim, true); - std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector start_indices(input_shape.dimensions().size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); - std::vector strides(input_shape.dimensions_size(), 1); + std::vector strides(input_shape.dimensions().size(), 1); // Slice topk. start_indices[last_dim] = 0; limit_indices[last_dim] = k; @@ -228,9 +228,9 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions, // Get the values and indices for the first topk so that they can // be passed to the while loop. - std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector start_indices(input_shape.dimensions().size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); - std::vector strides(input_shape.dimensions_size(), 1); + std::vector strides(input_shape.dimensions().size(), 1); start_indices[last_dim] = 0; limit_indices[last_dim] = per_partition_size; // Slice value and indices for the first partition. 
diff --git a/third_party/xla/xla/hlo/builder/lib/svd.cc b/third_party/xla/xla/hlo/builder/lib/svd.cc index e00b7a1100ef25..561c107ba8085a 100644 --- a/third_party/xla/xla/hlo/builder/lib/svd.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd.cc @@ -115,7 +115,7 @@ absl::StatusOr HouseRow( PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int64_t num_dims = a_shape.dimensions_size(); + const int64_t num_dims = a_shape.dimensions().size(); const int64_t n = ShapeUtil::GetDimension(a_shape, -1); XlaOp zero = ScalarLike(i, 0); XlaOp x = DynamicSliceInMinorDims(a, {i, zero}, {1, n}); @@ -181,7 +181,7 @@ absl::StatusOr HouseCol( PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int64_t num_dims = a_shape.dimensions_size(); + const int64_t num_dims = a_shape.dimensions().size(); const int64_t m = ShapeUtil::GetDimension(a_shape, -2); XlaOp zero = ScalarLike(i, 0); XlaOp x = DynamicSliceInMinorDims(a, {zero, j}, {m, 1}); @@ -259,7 +259,7 @@ absl::StatusOr HouseHolderBidiagonalization( XlaOp a, XlaOp eps, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int64_t num_dims = a_shape.dimensions_size(); + const int64_t num_dims = a_shape.dimensions().size(); const int64_t num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { @@ -464,7 +464,7 @@ absl::StatusOr OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp d = svd_result.d; XlaBuilder* builder = d.builder(); TF_ASSIGN_OR_RETURN(Shape d_shape, builder->GetShape(d)); - const int64_t num_dims = d_shape.dimensions_size(); + const int64_t num_dims = d_shape.dimensions().size(); const int64_t num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { @@ -574,7 +574,7 @@ absl::StatusOr OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, absl::StatusOr ComputeToleranceComparison(XlaOp w, XlaOp epsilon) { XlaBuilder* builder = w.builder(); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); - auto num_dims = static_cast(shape.dimensions_size()); + auto num_dims = static_cast(shape.dimensions().size()); int64_t n = shape.dimensions(num_dims - 1); shape.set_dimensions(num_dims - 2, n); auto w_sliced = SliceInMinorDims(w, {0, 0}, {n, n}); @@ -743,7 +743,7 @@ absl::StatusOr SortBySingularValuesAndPostProcessing( SVDResult result) { XlaBuilder* builder = result.d.builder(); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.d)); - const int64_t num_dims = shape.dimensions_size(); + const int64_t num_dims = shape.dimensions().size(); auto dimensions = shape.dimensions(); const int64_t m = ShapeUtil::GetDimension(shape, -2); const int64_t n = ShapeUtil::GetDimension(shape, -1); @@ -844,7 +844,7 @@ SVDResult SVD(XlaOp a, int64_t max_iter, float epsilon, return return_error(shape_with_status.status()); } Shape a_shape = shape_with_status.value(); - const int64_t num_dims = a_shape.dimensions_size(); + const int64_t num_dims = a_shape.dimensions().size(); const int64_t num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { diff --git a/third_party/xla/xla/hlo/builder/lib/svd_test.cc b/third_party/xla/xla/hlo/builder/lib/svd_test.cc index 170a7f0fa8c42e..0bb0a63a011fe4 100644 --- 
a/third_party/xla/xla/hlo/builder/lib/svd_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd_test.cc @@ -84,7 +84,7 @@ class SVDTest : public ClientLibraryTestBase { v = SliceInMinorDims(v, {0, 0}, {n, m}); } - int num_dims = u_shape.dimensions_size(); + int num_dims = u_shape.dimensions().size(); std::vector broadcast_dims(num_dims - 1); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); broadcast_dims[num_dims - 2] = num_dims - 1; diff --git a/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc b/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc index 0cd4e8190c66f8..ab9ffac8b51be0 100644 --- a/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc @@ -69,10 +69,10 @@ absl::StatusOr CheckSystemAndReturnNumEquations(XlaOp lower_diagonal, builder->GetShape(upper_diagonal)); TF_ASSIGN_OR_RETURN(Shape rhs_shape, builder->GetShape(rhs)); - const auto lower_diagonal_rank = lower_diagonal_shape.dimensions_size(); - const auto main_diagonal_rank = main_diagonal_shape.dimensions_size(); - const auto upper_diagonal_rank = upper_diagonal_shape.dimensions_size(); - const auto rhs_rank = rhs_shape.dimensions_size(); + const auto lower_diagonal_rank = lower_diagonal_shape.dimensions().size(); + const auto main_diagonal_rank = main_diagonal_shape.dimensions().size(); + const auto upper_diagonal_rank = upper_diagonal_shape.dimensions().size(); + const auto rhs_rank = rhs_shape.dimensions().size(); if (!((lower_diagonal_rank == main_diagonal_rank) && (lower_diagonal_rank == upper_diagonal_rank) && (lower_diagonal_rank == rhs_rank))) { @@ -127,8 +127,8 @@ struct TridiagonalMatMulShapeParams { absl::Status ValidateTridiagonalMatMulDiagonal( const Shape& diagonal_shape, const absl::string_view diagonal_name, const Shape& rhs_shape) { - const int64_t diagonal_rank = diagonal_shape.dimensions_size(); - const int64_t rhs_rank = rhs_shape.dimensions_size(); + const int64_t diagonal_rank = diagonal_shape.dimensions().size(); + const int64_t rhs_rank = rhs_shape.dimensions().size(); if (diagonal_rank != rhs_rank) { return InvalidArgument("%s must have same rank as rhs, but got %d and %d.", diagonal_name, diagonal_rank, rhs_rank); @@ -178,7 +178,7 @@ CheckMatMulSystemAndReturnShapeParams(XlaOp upper_diagonal, XlaOp main_diagonal, builder->GetShape(lower_diagonal)); TF_ASSIGN_OR_RETURN(const Shape rhs_shape, builder->GetShape(rhs)); - const int64_t rank = rhs_shape.dimensions_size(); + const int64_t rank = rhs_shape.dimensions().size(); if (rank < 2) { return InvalidArgument("Input must have rank >= 2, but got %d.", rank); } @@ -405,7 +405,7 @@ absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, XlaOp rhs) { XlaBuilder* builder = diagonals.builder(); TF_ASSIGN_OR_RETURN(Shape diagonals_shape, builder->GetShape(diagonals)); - const int64_t rank = diagonals_shape.dimensions_size(); + const int64_t rank = diagonals_shape.dimensions().size(); auto upper_diagonal = SliceInDim(diagonals, /*start_index=*/0, /*limit_index=*/1, From 6498b299a1c8331777752fa8ba82a969a7fc810a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 12 Apr 2025 02:02:29 -0700 Subject: [PATCH 0616/1324] Update GraphDef version to 2195. 
PiperOrigin-RevId: 746755099 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 0e6bfd86c33f93..c0c48201e1cb3e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2194 // Updated: 2025/4/11 +#define TF_GRAPH_DEF_VERSION 2195 // Updated: 2025/4/12 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From c163841a4e6792c93d32e573d7e9c51adb35facc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 12 Apr 2025 02:02:34 -0700 Subject: [PATCH 0617/1324] compat: Update forward compatibility horizon to 2025-04-12 PiperOrigin-RevId: 746755136 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 14e0ab50b31db3..3377ee2bae9e17 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 11) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 12) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From b93a95c14a381a400d3b135cfa77b9b260919573 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 12 Apr 2025 02:25:54 -0700 Subject: [PATCH 0618/1324] Automated Code Change PiperOrigin-RevId: 746760312 --- third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc b/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc index 1379a912e72edd..fb44caf056f6e7 100644 --- a/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc +++ b/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" From 2aa637d22915a75c2941149b8447a2ee19c77b8d Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sat, 12 Apr 2025 07:04:30 -0700 Subject: [PATCH 0619/1324] Automated Code Change PiperOrigin-RevId: 746811156 --- third_party/xla/xla/python/BUILD | 1 + third_party/xla/xla/python/custom_partition_callback.cc | 1 + third_party/xla/xla/python/custom_partition_callback.h | 2 ++ 3 files changed, 4 insertions(+) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index ca03d08d37427d..cc85f7b928294a 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -153,6 +153,7 @@ cc_library( deps = [ "//xla:debug_options_flags", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", diff --git a/third_party/xla/xla/python/custom_partition_callback.cc b/third_party/xla/xla/python/custom_partition_callback.cc index fed210f4157538..419798f02bdb3a 100644 --- a/third_party/xla/xla/python/custom_partition_callback.cc +++ b/third_party/xla/xla/python/custom_partition_callback.cc @@ -51,6 +51,7 @@ limitations under the License. #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/python/custom_partition_callback.h b/third_party/xla/xla/python/custom_partition_callback.h index 6ba1789a038daa..53cec15ed7a768 100644 --- a/third_party/xla/xla/python/custom_partition_callback.h +++ b/third_party/xla/xla/python/custom_partition_callback.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/c/pjrt_c_api.h" From adc0e5688dfc04d2ab49f916734954c7df869223 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 12 Apr 2025 10:52:19 -0700 Subject: [PATCH 0620/1324] Automated Code Change PiperOrigin-RevId: 746851492 --- third_party/xla/xla/python/ifrt/array.h | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/python/ifrt/array.h b/third_party/xla/xla/python/ifrt/array.h index f7d46970c74cfa..63509a1cd81c1e 100644 --- a/third_party/xla/xla/python/ifrt/array.h +++ b/third_party/xla/xla/python/ifrt/array.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/base/attributes.h" +#include "absl/base/macros.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/types/span.h" From 21ffe360aa844d882d898a17cc3846e94cc0b9e5 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower"
Date: Sat, 12 Apr 2025 11:02:37 -0700
Subject: [PATCH 0621/1324] Automated Code Change

PiperOrigin-RevId: 746853069
---
 third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc
index 301a82569aec80..1e85b4116dadf8 100644
--- a/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc
+++ b/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc
@@ -137,9 +137,11 @@ absl::StatusOr<tsl::RCReference<Array>> CreateArray(
       MemoryKind(), /*shape=*/shape,
       /*shard_shape=*/std::move(shard_shape));
+  absl::Span<tsl::RCReference<Array>> arrays = absl::MakeSpan(shards);
   return client->AssembleArrayFromSingleDeviceArrays(
-      std::move(shape), std::move(assembled_sharding), absl::MakeSpan(shards),
-      ArrayCopySemantics::kDonateInput);
+      arrays.at(0)->dtype(), std::move(shape), std::move(assembled_sharding),
+      arrays, ArrayCopySemantics::kDonateInput,
+      SingleDeviceShardSemantics::kAddressableShards);
 }

 // Checks the shards and contents of an array, same as what CreateArray would

From 34afb729aab89cc6eb8c6f89cf303f931fd9c800 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Sat, 12 Apr 2025 11:05:19 -0700
Subject: [PATCH 0622/1324] Automated Code Change

PiperOrigin-RevId: 746853674
---
 third_party/xla/xla/text_literal_reader.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/text_literal_reader.cc b/third_party/xla/xla/text_literal_reader.cc
index 587730874ebd2b..d796f39af00789 100644
--- a/third_party/xla/xla/text_literal_reader.cc
+++ b/third_party/xla/xla/text_literal_reader.cc
@@ -115,11 +115,11 @@ absl::StatusOr<Literal> TextLiteralReader::ReadAllLines() {
     }
     coordinate_values.push_back(coordinate_value);
   }
-  if (coordinate_values.size() != shape.dimensions_size()) {
+  if (coordinate_values.size() != shape.dimensions().size()) {
     return InvalidArgument(
         "line did not have expected number of coordinates; want %d got %u: "
         "\"%s\"",
-        shape.dimensions_size(), coordinate_values.size(), line);
+        shape.dimensions().size(), coordinate_values.size(), line);
   }
   result.Set(coordinate_values, value);
 }

From 6f5ef47b4316ef0cb10d711d8de8903b3d6a28cd Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Sat, 12 Apr 2025 19:01:23 -0700
Subject: [PATCH 0623/1324] Don't expand shape if the constant is a scalar,
 which can't be the case for a slice shape that's to be concatenated.
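
The motivating pattern, schematically (the HLO fragment below mirrors the
case added to collective_pipeliner_test.cc by this change):

```
  ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
  slice = bf16[1,8,120] slice(ar.1), slice={[0:1], [0:8], [0:120]}
  constant.2563 = bf16[] constant(5.0)
  pad = bf16[1,8,128] pad(slice, constant.2563), padding=0_0x0_0x0_8
```

The rank-0 constant.2563 is only a padding value; it can never itself be a
slice that is concatenated across iterations, so the pipeliner now clones
such scalar constants unchanged instead of broadcasting them to the full
per-iteration output shape. Constants of rank >= 1 are still broadcast as
before.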
PiperOrigin-RevId: 746937234 --- .../xla/xla/service/collective_pipeliner.cc | 60 +++++++++++-------- .../xla/service/collective_pipeliner_test.cc | 5 +- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 60e57b4672641f..1f7c172ab4387f 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -211,7 +211,7 @@ CollectDynamicSliceIndicesIfConstant(HloInstruction* instr) { for (int64_t i = dyn_slice->first_index_operand_number(); i < instr->operand_count(); ++i) { HloInstruction* operand = dyn_slice->mutable_operand(i); - CHECK_EQ(operand->shape().dimensions_size(), 0); + CHECK(operand->shape().dimensions().empty()); std::vector> stack( 1, std::make_pair(operand, 0)); absl::flat_hash_set visited; @@ -343,12 +343,11 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, ShapeUtil::ElementsIn(instr->operand(0)->shape()) < 1024)) { return true; } - // TODO(b/409716406): Reconsider cases where Pad can be supported. return HloPredicateIsOp(i) || + HloOpcode::kPad, HloOpcode::kCollectivePermute, + HloOpcode::kConvert, HloOpcode::kReshape, + HloOpcode::kAllReduce, HloOpcode::kTranspose, + HloOpcode::kBroadcast, HloOpcode::kAllGather>(i) || (multi_uses_pipelining && i->IsElementwise()) || i->IsCustomCall(CollectivePipeliner::kInsertedByPreviousStep) || i->IsCustomCall(CollectivePipeliner::kSunkByPreviousStep); @@ -1529,7 +1528,7 @@ Shape ComputeFullOutputShape(const WhileMoveInfo& move_info, // Create zero of base type ptype and broadcast it to shape. HloInstruction* CreateZero(HloComputation* comp, const Shape& shape, PrimitiveType ptype) { - if (shape.dimensions_size() == 0) { + if (shape.dimensions().empty()) { return comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(ptype))); } @@ -2009,8 +2008,8 @@ absl::Status TransformLoopForward( if (slice_target_shape != data_to_slice->shape()) { // Slice matrix. 
absl::InlinedVector dynamic_slice_sizes; - dynamic_slice_sizes.reserve(slice_target_shape.dimensions_size()); - for (int i = 0; i < slice_target_shape.dimensions_size(); ++i) { + dynamic_slice_sizes.reserve(slice_target_shape.dimensions().size()); + for (int i = 0; i < slice_target_shape.dimensions().size(); ++i) { dynamic_slice_sizes.push_back(slice_target_shape.dimensions(i)); } sliced_data = @@ -2268,7 +2267,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, Shape index_shape = move_info.dynamic_update_slices.front()->index_shapes()[0]; std::vector indices( - expanded_shape.dimensions_size(), + expanded_shape.dimensions().size(), CreateZero(body_computation, index_shape, index_shape.element_type())); indices[0] = move_info.dynamic_update_slices.front()->index_operands()[0]; @@ -2313,7 +2312,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, HloDynamicUpdateSliceInstruction* dyn_update = to_move.dynamic_update_slices[0]; std::vector indices( - expanded_shape.dimensions_size(), + expanded_shape.dimensions().size(), CreateZero(body_computation, dyn_update->index_shapes()[0], dyn_update->index_shapes()[0].element_type())); indices[0] = dyn_update->index_operands()[0]; @@ -2434,7 +2433,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, if (is_loop_invariant) { Shape full_shape = ComputeFullOutputShape(to_move, pipelined->shape()); absl::InlinedVector operand_dims; - operand_dims.resize(pipelined->shape().dimensions_size()); + operand_dims.resize(pipelined->shape().dimensions().size()); absl::c_iota(operand_dims, 1); HloInstruction* broadcasted = loop_computation->AddInstruction(HloInstruction::CreateBroadcast( @@ -2460,21 +2459,30 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, std::vector operands; for (auto* operand : instr->mutable_operands()) { if (operand->opcode() == HloOpcode::kConstant) { - HloInstruction* cloned_constant = loop_computation->AddInstruction( - operand->CloneWithNewOperands(operand->shape(), {})); - if (!to_add_batch_set.contains(instr)) { - operands.push_back(cloned_constant); + if (!operand->shape().dimensions().empty()) { + // Broadcast constant into full shape. + HloInstruction* cloned_constant = loop_computation->AddInstruction( + operand->CloneWithNewOperands(operand->shape(), {})); + if (!to_add_batch_set.contains(instr)) { + operands.push_back(cloned_constant); + continue; + } + Shape full_shape = + ComputeFullOutputShape(to_move, cloned_constant->shape()); + absl::InlinedVector operand_dims; + operand_dims.resize(cloned_constant->shape().dimensions().size()); + absl::c_iota(operand_dims, 1); + HloInstruction* broadcasted = loop_computation->AddInstruction( + HloInstruction::CreateBroadcast(full_shape, cloned_constant, + operand_dims)); + operands.push_back(broadcasted); continue; } - Shape full_shape = - ComputeFullOutputShape(to_move, cloned_constant->shape()); - absl::InlinedVector operand_dims; - operand_dims.resize(cloned_constant->shape().dimensions_size()); - absl::c_iota(operand_dims, 1); - HloInstruction* broadcasted = - loop_computation->AddInstruction(HloInstruction::CreateBroadcast( - full_shape, cloned_constant, operand_dims)); - operands.push_back(broadcasted); + // The constant may be for something like a padding value. And a + // scalar shape can't be for slice shape that's to be concatenated. + // No need to broadcast. 
+ operands.push_back(loop_computation->AddInstruction( + operand->CloneWithNewOperands(operand->shape(), {}))); continue; } auto it = pipelined_map.find(operand); @@ -2546,7 +2554,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, } // Constant scalars don't get expanded ahead of time and are kept // scalar. - if (operands[0]->shape().dimensions_size() == 0) { + if (operands[0]->shape().dimensions().empty()) { dimensions.clear(); } HloInstruction* expanded_broadcast = diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 556d0234db0500..e1c72ab46888c0 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -3253,7 +3253,10 @@ while_body { dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 - b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + slice = bf16[1,8,120] slice(ar.1), slice={[0:1], [0:8], [0:120]} + constant.2563 = bf16[] constant(5.0) + pad = bf16[1,8,128] pad(slice, constant.2563), padding=0_0x0_0x0_8 + b.1 = bf16[1,8,128,32] broadcast(pad), dimensions={0,1,2} constant = bf16[] constant(0) reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) From f95849c82d358be92c4ffd695175f9b5d4823b96 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sat, 12 Apr 2025 23:14:05 -0700 Subject: [PATCH 0624/1324] Run build_cleaner on xla/ directory. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
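
For reference, the shape of a typical cleanup (a schematic BUILD fragment,
not an actual target from this CL; concrete instances are in the
per-directory changes that follow):

```
cc_library(
    name = "example",  # hypothetical target, for illustration only
    deps = [
        ":cpu_collectives",         # simplified from //xla/backends/cpu/collectives:cpu_collectives
        "//xla:xla_data_proto_cc",  # added: dependency was missing
        # "//xla:shape_util",       # removed: unused dependency
    ],
)
```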
PiperOrigin-RevId: 746989444 --- third_party/xla/xla/BUILD | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index aab8b49a216641..766da958265d22 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -169,8 +169,6 @@ xla_cc_test( ":ef57", "//xla/hlo/testlib:test", "//xla/tsl/platform:test_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log:log_streamer", "@com_google_absl//absl/random", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", @@ -277,7 +275,6 @@ cc_library( deps = [ ":types", ":util", - "@com_google_absl//absl/base:core_headers", ], ) @@ -527,8 +524,8 @@ xla_cc_test( srcs = ["shape_partition_test.cc"], deps = [ ":shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", + ":util", + ":xla_data_proto_cc", "//xla/hlo/testlib:test_helpers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -563,7 +560,6 @@ xla_cc_test( "//xla/tsl/platform:logging", "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:test_main", - "@com_google_absl//absl/log", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", ], @@ -690,7 +686,6 @@ xla_cc_test( name = "literal_pool_test", srcs = ["literal_pool_test.cc"], deps = [ - ":literal", ":literal_pool", ":literal_util", "//xla/tsl/platform:test", @@ -916,7 +911,6 @@ cc_library( deps = [ ":literal", ":shape_util", - ":status_macros", ":types", ":util", ":xla_data_proto_cc", @@ -929,7 +923,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:protobuf", ], ) @@ -941,7 +934,6 @@ cc_library( deps = [ ":literal", ":shape_util", - ":status_macros", ":types", ":util", ":xla_data_proto_cc", @@ -954,7 +946,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:protobuf", ], ) @@ -965,7 +956,6 @@ xla_cc_test( ":literal", ":shape_util", ":text_literal_reader", - ":types", ":xla_data_proto_cc", "//xla/hlo/testlib:test", "//xla/tsl/platform:env", @@ -982,7 +972,6 @@ cc_library( deps = [ ":literal", ":shape_util", - ":status_macros", ":types", ":xla_data_proto_cc", "//xla/tsl/platform:env", @@ -997,10 +986,8 @@ xla_cc_test( name = "text_literal_writer_test", srcs = ["text_literal_writer_test.cc"], deps = [ - ":literal", ":literal_util", ":text_literal_writer", - ":types", "//xla/hlo/testlib:test", "//xla/hlo/testlib:test_helpers", "//xla/tsl/lib/core:status_test_util", @@ -1153,7 +1140,6 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], @@ -1403,7 +1389,6 @@ xla_cc_test( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/log:log_streamer", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", ], From 42c135d841ba3b347e44ea9f64b338040414c70a Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:27:01 -0700 Subject: [PATCH 0625/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/cpu/collectives. 
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I run this for all of xla to fix the low-hanging fruits. It
resolves several unnecessary & missing dependencies and simplifying target
paths, but not all of them.

Here are the issues that came up that I didn't attempt to fix:
* any conflicts that needs manual handling
* conflicts that needs to choose between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median             Δ                1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median             Δ                1-pval
cpu:      3599.015s ±131.4s  +8.3s, +0.2%     0.03 (not significant)
memory:   4533MB ±2.3MB      +0.0MB, +0.0%    0.00 (not significant)
system:   582.305s ±9.1s     -11.9s, -2.0%    0.25 (not significant)
wall:     808.958s ±95.5s    -98.6s, -10.9%   0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time. Since
I've excluded some execution tests under `stream_executor/` and `service/`
the estimated savings may be greater. Overall, it's a small improvement but
should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of
subdirectories to simplify merging.

PiperOrigin-RevId: 747004981
---
 .../xla/xla/backends/cpu/collectives/BUILD | 75 ++-----------------
 1 file changed, 6 insertions(+), 69 deletions(-)

diff --git a/third_party/xla/xla/backends/cpu/collectives/BUILD b/third_party/xla/xla/backends/cpu/collectives/BUILD
index cfce13a7963680..ac5ec02cab5e3d 100644
--- a/third_party/xla/xla/backends/cpu/collectives/BUILD
+++ b/third_party/xla/xla/backends/cpu/collectives/BUILD
@@ -24,7 +24,6 @@ cc_library(
         "//xla/service:global_device_id",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/hash",
-        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@local_tsl//tsl/platform:casts",
     ],
@@ -55,7 +54,6 @@ cc_library(
         ":cpu_clique_key",
         ":cpu_collectives",
         "//xla:util",
-        "//xla/core/collectives:clique",
         "//xla/core/collectives:communicator",
         "//xla/core/collectives:rank_id",
         "//xla/tsl/platform:errors",
@@ -91,7 +89,6 @@ cc_library(
     srcs = ["cpu_collectives.cc"],
     hdrs = ["cpu_collectives.h"],
     deps = [
-        "//xla:shape_util",
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/core/collectives",
@@ -102,7 +99,6 @@ cc_library(
         "//xla/service:collective_ops_utils",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/time",
         "@com_google_absl//absl/types:span",
@@ -117,28 +113,15 @@ cc_library(
     deps = [
         ":cpu_collectives",
         ":in_process_communicator",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/core/collectives:clique_id",
         "//xla/core/collectives:clique_key",
         "//xla/core/collectives:communicator",
-        "//xla/core/collectives:rank_id",
-        "//xla/service:collective_ops_utils",
-        "//xla/service:global_device_id",
-        "//xla/stream_executor:device_memory",
-        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", ], ) @@ -149,20 +132,17 @@ cc_library( deps = [ ":cpu_collectives", "//xla:shape_util", - "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", - "//xla/service:global_device_id", "//xla/service:rendezvous", "//xla/stream_executor:device_memory", "//xla/tsl/lib/math:math_util", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -170,7 +150,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", ], ) @@ -207,34 +186,19 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", + ":cpu_collectives", + ":gloo_communicator", "//xla:xla_data_proto_cc", - "//xla/backends/cpu/collectives:cpu_collectives", - "//xla/backends/cpu/collectives:gloo_communicator", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", "//xla/core/collectives:communicator", - "//xla/core/collectives:rank_id", - "//xla/service:collective_ops_utils", "//xla/service:global_device_id", - "//xla/stream_executor:device_memory", - "//xla/tsl/platform:errors", - "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@gloo", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", ], ) @@ -242,12 +206,12 @@ xla_cc_test( name = "gloo_collectives_test", srcs = ["gloo_collectives_test.cc"], deps = [ + ":cpu_clique_key", + ":cpu_collectives", ":gloo_collectives", ":gloo_kv_store", "//xla:executable_run_options", "//xla:xla_data_proto_cc", - "//xla/backends/cpu/collectives:cpu_clique_key", - "//xla/backends/cpu/collectives:cpu_collectives", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", "//xla/core/collectives:communicator", @@ -297,22 +261,16 @@ cc_library( "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", - "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@gloo", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", ], ) @@ -330,29 +288,17 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - "//xla:shape_util", - 
"//xla:status_macros", - "//xla:types", - "//xla:util", + ":cpu_collectives", + ":mpi_communicator", "//xla:xla_data_proto_cc", - "//xla/backends/cpu/collectives:cpu_collectives", - "//xla/backends/cpu/collectives:mpi_communicator", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", "//xla/core/collectives:communicator", - "//xla/service:collective_ops_utils", "//xla/service:global_device_id", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@mpitrampoline", ], ) @@ -373,27 +319,18 @@ cc_library( deps = [ "//xla:shape_util", "//xla:status_macros", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", - "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", - "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@mpitrampoline", ], From dfba91e153eb20ce4a211f761c44e3c095f3be7a Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:33:14 -0700 Subject: [PATCH 0626/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/gpu/codegen. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747006169 --- third_party/xla/xla/backends/gpu/codegen/BUILD | 8 -------- 1 file changed, 8 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/BUILD b/third_party/xla/xla/backends/gpu/codegen/BUILD index ea8b374199ed74..bf96c08415da39 100644 --- a/third_party/xla/xla/backends/gpu/codegen/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/BUILD @@ -50,15 +50,8 @@ xla_cc_test( ":copy", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/testlib:verified_hlo_module", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) @@ -96,7 +89,6 @@ xla_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "//xla/backends/gpu/runtime:thunk", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", "//xla/hlo/testlib:pattern_matcher_gmock", From 641e533c0bdf89fcdc764f8bfd7730f1b52aed0f Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:35:24 -0700 Subject: [PATCH 0627/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/gpu/codegen/emitters. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747006712 --- third_party/xla/xla/backends/gpu/codegen/emitters/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD index dcb6158181976e..b972f556a56a0e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD @@ -312,10 +312,8 @@ cc_library( "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:reduction_utils", "//xla/stream_executor:device_description", - "//xla/stream_executor:launch_dim", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", From c1287ff38255290f2c013e40c8411045e7191ca9 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:35:24 -0700 Subject: [PATCH 0628/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747006717 --- third_party/xla/xla/backends/interpreter/BUILD | 6 ------ third_party/xla/xla/backends/profiler/gpu/BUILD | 3 +-- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/third_party/xla/xla/backends/interpreter/BUILD b/third_party/xla/xla/backends/interpreter/BUILD index 036ba547cd59fc..a2474bb27a3988 100644 --- a/third_party/xla/xla/backends/interpreter/BUILD +++ b/third_party/xla/xla/backends/interpreter/BUILD @@ -163,17 +163,12 @@ cc_library( srcs = ["executor.cc"], hdrs = ["executor.h"], deps = [ - "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/stream_executor:blas", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:generic_memory_allocation", "//xla/stream_executor:generic_memory_allocator", - "//xla/stream_executor:kernel", - "//xla/stream_executor:kernel_spec", - "//xla/stream_executor:launch_dim", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:memory_allocator", "//xla/stream_executor:platform", @@ -186,6 +181,5 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index 2be5360977daef..6ba5c0f0c24a4a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -392,8 +392,6 @@ xla_test( ], deps = [ ":cupti_buffer_events", - ":cupti_collector", - ":cupti_tracer", ":cupti_utils", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test", @@ -435,6 +433,7 @@ cuda_library( local_defines = if_oss(["NVTX_VERSION_3_1=1"]), tags = ["cuda-only"], visibility = ["//visibility:public"], + deps = ["@local_config_cuda//cuda:cuda_headers"], ) xla_test( From d96c20edef976a3e812d0dbcc5bfae0e8b24f716 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:36:18 -0700 Subject: [PATCH 0629/1324] Run build_cleaner on BUILD file(s) located in /xla/ffi/api. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. 
Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747006901 --- third_party/xla/xla/ffi/api/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD index a542b768aa4882..b573810b6b7195 100644 --- a/third_party/xla/xla/ffi/api/BUILD +++ b/third_party/xla/xla/ffi/api/BUILD @@ -90,7 +90,6 @@ xla_cc_test( "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", - "//xla/tsl/platform:errors", "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", From dec1b1290b68714d2e1a953c96f0cda8edfad87c Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:36:40 -0700 Subject: [PATCH 0630/1324] Run build_cleaner on BUILD file(s) located in /xla/hlo/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747006992 --- third_party/xla/xla/hlo/analysis/BUILD | 7 +------ third_party/xla/xla/hlo/experimental/auto_sharding/BUILD | 3 --- third_party/xla/xla/hlo/ir/BUILD | 3 --- third_party/xla/xla/hlo/pass/BUILD | 2 -- third_party/xla/xla/hlo/testlib/BUILD | 1 - third_party/xla/xla/hlo/tools/tests/BUILD | 1 - 6 files changed, 1 insertion(+), 16 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/BUILD b/third_party/xla/xla/hlo/analysis/BUILD index c20bfc78def4ec..7fc5b511dca645 100644 --- a/third_party/xla/xla/hlo/analysis/BUILD +++ b/third_party/xla/xla/hlo/analysis/BUILD @@ -98,11 +98,11 @@ cc_library( hdrs = ["hlo_ordering.h"], deps = [ ":hlo_dataflow_analysis", + ":hlo_reachability", "//xla:shape_util", "//xla:status_macros", "//xla:types", "//xla:util", - "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", "//xla/service:call_graph", "//xla/service:hlo_proto_cc", @@ -110,7 +110,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -207,7 +206,6 @@ cc_library( "//xla/service:hlo_phi_graph", "//xla/service:hlo_value", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -355,7 +353,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", ], @@ -415,7 +412,6 @@ xla_cc_test( deps = [ ":hlo_alias_analysis", ":hlo_ordering", - "//xla:literal", "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -424,7 +420,6 @@ xla_cc_test( "//xla/hlo/testlib:test", "//xla/hlo/testlib:test_helpers", "//xla/hlo/transforms/simplifiers:flatten_call_graph", - "//xla/hlo/utils:hlo_matchers", "//xla/service:hlo_buffer", "//xla/service:hlo_value", "//xla/tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 53e03158ce2e42..e2f92b427ffd81 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -177,7 +177,6 @@ cc_library( hdrs = ["auto_sharding_cost_graph.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ - ":auto_sharding_device_mesh", ":auto_sharding_strategy", ":matrix", "//xla:shape_util", @@ -196,9 +195,7 @@ cc_library( hdrs = ["auto_sharding_option.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ - ":auto_sharding_device_mesh", ":auto_sharding_util", - "//xla:array", "//xla/service:hlo_module_config", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 47f798ca0c04d4..72bca9a992350a 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -223,7 +223,6 @@ xla_cc_test( "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@com_google_googletest//:gtest", ], ) @@ -263,7 +262,6 @@ xla_cc_test( ":hlo_instruction_utils", 
"//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_query", - "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", "@com_google_googletest//:gtest", ], @@ -317,7 +315,6 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/service:hlo_proto_cc", "//xla/tsl/platform:test_main", - "@com_google_absl//absl/log:globals", "@com_google_googletest//:gtest", ], ) diff --git a/third_party/xla/xla/hlo/pass/BUILD b/third_party/xla/xla/hlo/pass/BUILD index aa4c6f25ff7e3c..f4dd3228f24fcb 100644 --- a/third_party/xla/xla/hlo/pass/BUILD +++ b/third_party/xla/xla/hlo/pass/BUILD @@ -66,7 +66,6 @@ xla_cc_test( "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", @@ -103,7 +102,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:optional", "@local_tsl//tsl/profiler/lib:scoped_annotation", "@local_tsl//tsl/profiler/lib:traceme", ], diff --git a/third_party/xla/xla/hlo/testlib/BUILD b/third_party/xla/xla/hlo/testlib/BUILD index 549a30afa218f0..0300aed0a0bfce 100644 --- a/third_party/xla/xla/hlo/testlib/BUILD +++ b/third_party/xla/xla/hlo/testlib/BUILD @@ -119,7 +119,6 @@ cc_library( testonly = 1, hdrs = ["pattern_matcher_gmock.h"], deps = [ - "test", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", diff --git a/third_party/xla/xla/hlo/tools/tests/BUILD b/third_party/xla/xla/hlo/tools/tests/BUILD index 7d101c7dbbda9d..0fb22f00c5f56c 100644 --- a/third_party/xla/xla/hlo/tools/tests/BUILD +++ b/third_party/xla/xla/hlo/tools/tests/BUILD @@ -102,7 +102,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", ], ) From b4a41969206079d2a6f0ac87d2b7ba497187bc51 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:37:00 -0700 Subject: [PATCH 0631/1324] Run build_cleaner on BUILD file(s) located in /xla/codegen/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. 
Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747007086 --- third_party/xla/xla/codegen/emitters/BUILD | 1 - third_party/xla/xla/codegen/testlib/BUILD | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/xla/xla/codegen/emitters/BUILD b/third_party/xla/xla/codegen/emitters/BUILD index 027a21c82a5944..df91fb8fdb3543 100644 --- a/third_party/xla/xla/codegen/emitters/BUILD +++ b/third_party/xla/xla/codegen/emitters/BUILD @@ -51,7 +51,6 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", - "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", ], diff --git a/third_party/xla/xla/codegen/testlib/BUILD b/third_party/xla/xla/codegen/testlib/BUILD index 5e60e49471f992..345ea53acc5037 100644 --- a/third_party/xla/xla/codegen/testlib/BUILD +++ b/third_party/xla/xla/codegen/testlib/BUILD @@ -89,7 +89,7 @@ py_strict_test( "no_oss", ], deps = [ - ":_extension", + ":_extension", # buildcleaner: keep ":testlib", "//third_party/py/numpy", "//xla/python:xla_extension", From dad08645f9f23631553b229bc63d354828d1ee94 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:37:55 -0700 Subject: [PATCH 0632/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/gpu/codegen/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
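The `# buildcleaner: keep` annotation in the codegen/testlib hunk earlier in this series marks a dependency that is needed only at runtime (for example, a Python extension module loaded dynamically, or a library whose static initializer registers something), so it must survive cleanup even though no import or #include edge points at it. A minimal sketch, with the target names taken from that hunk and the load() line for the py_strict_test macro omitted since it matches whatever the surrounding BUILD file already uses:

```
py_strict_test(
    name = "kernel_runner_test",  # illustrative name, not a real target
    srcs = ["kernel_runner_test.py"],
    deps = [
        ":_extension",  # buildcleaner: keep  (loaded at runtime, never imported directly)
        ":testlib",
    ],
)
```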
PiperOrigin-RevId: 747007289 --- .../xla/xla/backends/gpu/codegen/tools/BUILD | 6 +-- .../xla/xla/backends/gpu/codegen/triton/BUILD | 39 ------------------- 2 files changed, 1 insertion(+), 44 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/tools/BUILD b/third_party/xla/xla/backends/gpu/codegen/tools/BUILD index 12bc08c27f8b2e..f4778af992fec1 100644 --- a/third_party/xla/xla/backends/gpu/codegen/tools/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/tools/BUILD @@ -121,10 +121,7 @@ py_strict_binary( py_strict_library( name = "ncu_rep_lib", srcs = ["ncu_rep_lib.py"], - deps = [ - "@absl_py//absl:app", - "@absl_py//absl/flags", - ], + deps = ["@absl_py//absl:app"], ) py_strict_test( @@ -132,7 +129,6 @@ py_strict_test( srcs = ["ncu_rep_test.py"], deps = [ ":ncu_rep_lib", - "@absl_py//absl/flags", "@absl_py//absl/testing:absltest", ], ) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index b1a34cfbeb91ff..cc323e19e4f271 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -355,57 +355,25 @@ cc_library( hdrs = ["dot_algorithms.h"], deps = [ ":emitter_helpers", - "//xla:comparison_util", - "//xla:literal", "//xla:shape_util", - "//xla:status_macros", - "//xla:util", "//xla:xla_data_proto_cc", - "//xla/backends/gpu/codegen/triton/ir:triton_xla", "//xla/codegen:emitter_loc_op_builder", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", "//xla/hlo/utils:hlo_traversal", - "//xla/mlir_hlo", - "//xla/mlir_hlo:map_mhlo_to_scalar_op", - "//xla/mlir_hlo:transformation_helpers", "//xla/service:algorithm_util", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:matmul_indexing_utils", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:triton_fusion_analysis", - "//xla/service/gpu:triton_tiling_propagation", - "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", - "//xla/service/llvm_ir:llvm_util", - "//xla/stream_executor:device_description", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor/gpu:tma_metadata", "//xla/tsl/platform:errors", - "//xla/tsl/platform:status", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@triton//:TritonDialects", ], @@ -457,7 +425,6 @@ xla_cc_test( ":fusion_emitter_stub_for_testing", "//xla:literal", "//xla:literal_util", - "//xla/backends/gpu/codegen/triton/ir:triton_xla", "//xla/codegen:emitter_loc_op_builder", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", @@ -636,7 +603,6 @@ xla_test( "//xla/stream_executor/cuda:cuda_compute_capability", 
"//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/platform:statusor", - "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:path", @@ -758,7 +724,6 @@ cc_library( "//xla/backends/profiler/gpu:cupti_collector", "//xla/backends/profiler/gpu:cupti_tracer", "//xla/tsl/profiler/utils:time_utils", - "@com_google_absl//absl/algorithm:container", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -905,7 +870,6 @@ cc_library( "//xla/service/gpu:matmul_indexing_utils", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", - "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -964,7 +928,6 @@ xla_test( tags = ["no_mac"], deps = [ ":fusion_emitter", - ":kernel_name_tracer", ":support", ":test_utils", "//xla:error_spec", @@ -993,8 +956,6 @@ cc_library( deps = [ "//xla:shape_util", "//xla/codegen:emitter_loc_op_builder", - "//xla/service:hlo_module_config", - "//xla/stream_executor:device_description", "//xla/stream_executor/gpu:tma_metadata", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", From cc063ee1f457522c2e109a1cbcb902534950e2be Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:39:43 -0700 Subject: [PATCH 0633/1324] Run build_cleaner on BUILD file(s) located in /xla/ffi. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747007688 --- third_party/xla/xla/ffi/BUILD | 6 ------ 1 file changed, 6 deletions(-) diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index 5bfc261890ae83..baa702b6650d45 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -47,10 +47,8 @@ xla_cc_test( "//xla/stream_executor:device_memory", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:test_main", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ], @@ -70,7 +68,6 @@ cc_library( "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", ], ) @@ -110,7 +107,6 @@ xla_cc_test( deps = [ ":execution_state", "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", @@ -174,7 +170,6 @@ cc_library( "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -193,7 +188,6 @@ cc_library( hdrs = ["attribute_map.h"], deps = [ ":call_frame", - "//xla:util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", From 426d8c10df095b74417fa0bfa41a4ac84d7b6831 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:40:25 -0700 Subject: [PATCH 0634/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/gpu/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
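To decode the table above: Δ is the difference of medians against the baseline. For wall time, 808.958s - 907.605s = -98.647s ≈ -98.6s, and -98.647 / 907.605 ≈ -10.9%. The 1-pval column appears to be 1 minus the p-value of the comparison, so even the wall-time improvement (0.57, i.e. p ≈ 0.43) does not clear a conventional significance threshold, which is why every row carries the "(not significant)" label.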
PiperOrigin-RevId: 747007849 --- third_party/xla/xla/backends/gpu/collectives/BUILD | 5 ----- third_party/xla/xla/backends/gpu/runtime/BUILD | 14 -------------- 2 files changed, 19 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index ec0f5076115cbb..3595fe88020edf 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -53,8 +53,6 @@ cc_library( srcs = ["gpu_clique_key.cc"], hdrs = ["gpu_clique_key.h"], deps = [ - "//xla/core/collectives", - "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", "//xla/service:global_device_id", "//xla/tsl/lib/gtl:int_type", @@ -76,11 +74,9 @@ xla_cc_test( ":gpu_clique_key", "//xla/core/collectives:clique_id", "//xla/service:global_device_id", - "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/status", "@com_google_googletest//:gtest", ], ) @@ -143,7 +139,6 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", - "//xla/tsl/platform:logging", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 101dc1bdd138f3..409717e1569e60 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -98,7 +98,6 @@ cc_library( "//xla/tsl/lib/gtl:int_type", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", - "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -126,7 +125,6 @@ cc_library( ":all_gather_thunk", ":all_reduce_thunk", ":all_to_all_thunk", - ":collective_thunk", ":command_buffer_cmd", ":conditional_thunk", ":copy_thunk", @@ -140,7 +138,6 @@ cc_library( ":replica_id_thunk", ":sequential_thunk", ":thunk", - ":wait_for_streams_thunk", ":while_thunk", "//xla:util", "//xla/runtime:buffer_use", @@ -415,7 +412,6 @@ cc_library( ":host_memory_pool", ":sequential_thunk", ":thunk", - "//xla:shape_util", "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", @@ -755,7 +751,6 @@ cc_library( "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], ) @@ -820,7 +815,6 @@ cc_library( "//xla/stream_executor:event", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:stream", - "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", @@ -871,7 +865,6 @@ cc_library( ":p2p_thunk_common", ":thunk", "//xla:executable_run_options", - "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", @@ -886,7 +879,6 @@ cc_library( "//xla/service/gpu/transforms/collectives:collective_ops_utils", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", - "//xla/stream_executor:memory_allocation", "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", @@ -970,7 +962,6 @@ cc_library( "//xla:status_macros", "//xla:xla_data_proto_cc", 
"//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/core/collectives:communicator", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/service:collective_ops_utils", @@ -1055,9 +1046,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/stream_executor:event", "//xla/stream_executor:stream", - "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) @@ -1195,7 +1184,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", ], ) @@ -1269,7 +1257,6 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/profiler/lib:traceme_encode", ], ) @@ -1283,7 +1270,6 @@ xla_test( "//xla:executable_run_options", "//xla/service:executable", "//xla/service:platform_util", - "//xla/service/gpu:buffer_allocations", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream", From 9d8bca547e17e7d83080e47c7eea91e662308109 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:41:15 -0700 Subject: [PATCH 0635/1324] Run build_cleaner on BUILD file(s) located in /xla/pjrt/c. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747008013 --- third_party/xla/xla/pjrt/c/BUILD | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index d1f3a43584314f..e95f741002efcd 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -385,7 +385,7 @@ cc_library( "//xla/pjrt/gpu:gpu_helpers", "//xla/pjrt/gpu:gpu_topology", "//xla/pjrt/gpu:se_gpu_pjrt_client", - "//xla/pjrt/gpu:se_gpu_pjrt_compiler", # To register GPU AOT compiler + "//xla/pjrt/gpu:se_gpu_pjrt_compiler", # buildcleaner: keep to register GPU AOT compiler "//xla/python:custom_call_batch_partitioner", "//xla/python:custom_partition_callback", "//xla/python:debug_callback_partitioner", # To register "DebugCallbackCustomCallPartitioner" custom partitioning handler. @@ -527,17 +527,13 @@ xla_cc_test( ":pjrt_c_api_helpers", ":pjrt_c_api_wrapper_impl", "//xla:shape_util", - "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt/distributed:in_memory_key_value_store", "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:status", - "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", "@stablehlo//:version", @@ -550,7 +546,6 @@ xla_cc_test( deps = [ ":pjrt_c_api_cpu", ":pjrt_c_api_test_common", - ":pjrt_c_api_wrapper_impl", "@com_google_googletest//:gtest_main", ], ) From 05de0c5f2e9d0d38e7966195f58435ce6e060042 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:41:20 -0700 Subject: [PATCH 0636/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/cpu/benchmarks. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
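On the `bzl_library` caveat repeated in these messages: .bzl files are tracked via bzl_library targets from bazel_skylib, and when a .bzl file has no such target the cleaner cannot record load() edges for it. A minimal sketch of what such a target looks like (names hypothetical):

```
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")

bzl_library(
    name = "build_defs_bzl",
    srcs = ["build_defs.bzl"],
    deps = ["//some/pkg:other_defs_bzl"],  # one entry per load() in build_defs.bzl
)
```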
PiperOrigin-RevId: 747008024 --- third_party/xla/xla/backends/cpu/benchmarks/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/benchmarks/BUILD b/third_party/xla/xla/backends/cpu/benchmarks/BUILD index 0251f718339e38..97fb8de2cc810e 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/BUILD +++ b/third_party/xla/xla/backends/cpu/benchmarks/BUILD @@ -35,8 +35,6 @@ xla_cc_test( "//xla/service:compiler", "//xla/service/cpu:cpu_aot_compilation_result", "//xla/service/cpu:test_header_helper", - "//xla/tsl/platform:status_matchers", - "//xla/tsl/platform:test", "@com_google_googletest//:gtest_main", ], ) @@ -470,7 +468,6 @@ xla_cc_test( "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", ], ) From b8e14456d0d7cf6fd34ab24ee743b4fd49cdefb5 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:42:51 -0700 Subject: [PATCH 0637/1324] Run build_cleaner on BUILD file(s) located in /xla/mlir/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747008234 --- third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD | 3 --- third_party/xla/xla/mlir/utils/BUILD | 1 - 2 files changed, 4 deletions(-) diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD b/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD index b0fde92aad9c20..3d9bbabd86535d 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD @@ -15,15 +15,12 @@ cc_library( ":compiler_trace_proto_cc", ":compiler_trace_proto_cc_impl", "//xla/service/llvm_ir:llvm_util", - "//xla/tsl/platform:env", "//xla/tsl/platform:logging", "@com_google_absl//absl/log", - "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/third_party/xla/xla/mlir/utils/BUILD b/third_party/xla/xla/mlir/utils/BUILD index 98e5bb41888842..827b84165fa474 100644 --- a/third_party/xla/xla/mlir/utils/BUILD +++ b/third_party/xla/xla/mlir/utils/BUILD @@ -23,7 +23,6 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", From c81085d5d9287d8625043fda6e2feada8ee27952 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:43:39 -0700 Subject: [PATCH 0638/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/gpu/codegen/emitters/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
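Note that build_cleaner adds dependencies as well as deleting them: the emitters/transforms hunk below gains `//xla:shape_util`, `//xla/mlir_hlo`, and `@llvm-project//mlir:MathTransforms` because sources in the target include their headers while previously relying on them leaking in transitively. Schematically, with deps abbreviated and the file name illustrative:

```
# Before: the code includes "xla/shape_util.h", but the providing target is
# absent and only reachable through another dep's transitive closure.
cc_library(
    name = "passes",
    srcs = ["lower_tensors.cc"],
    deps = ["//xla/backends/gpu/codegen/emitters/ir:xla_gpu"],
)

# After: the include edge is explicit, so the target no longer breaks if the
# intermediate dep stops re-exporting shape_util.
cc_library(
    name = "passes",
    srcs = ["lower_tensors.cc"],
    deps = [
        "//xla:shape_util",
        "//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
    ],
)
```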
PiperOrigin-RevId: 747008340 --- third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD | 3 --- .../xla/xla/backends/gpu/codegen/emitters/transforms/BUILD | 4 ++++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD index bab734000f6404..20b5dec1c59a74 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD @@ -113,9 +113,6 @@ cc_library( ":xla_gpu_types_inc_gen", "//xla/codegen/emitters/ir:xla", "//xla/hlo/analysis:indexing_analysis", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BytecodeOpInterface", diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD index 6ba7c4508eb79f..f00866292804ec 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/BUILD @@ -44,11 +44,14 @@ cc_library( copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ ":passes_inc_gen", + "//xla:shape_util", "//xla:util", "//xla/backends/gpu/codegen/emitters/ir:xla_gpu", "//xla/codegen/emitters/ir:xla", "//xla/codegen/emitters/transforms:atomic_rmw_utils", "//xla/hlo/analysis:indexing_analysis", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:ir_emission_utils", "//xla/stream_executor:device_description", @@ -69,6 +72,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFUtils", From ca00b4846e12515259f1aad772450f9088b089d2 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:45:08 -0700 Subject: [PATCH 0639/1324] Run build_cleaner on BUILD file(s) located in /xla/hlo/translate/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. 
Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747008614 --- third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD | 6 ------ third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD | 3 --- 2 files changed, 9 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD index 9b1d5d88306774..678814e990e69a 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -61,7 +61,6 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", @@ -203,13 +202,9 @@ xla_cc_test( srcs = ["hlo_utils_test.cc"], deps = [ ":hlo_utils", - "//xla:literal", - "//xla:literal_util", "//xla:shape_util", - "//xla:types", "//xla:xla_data_proto_cc", "//xla/mlir_hlo", - "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test_main", "@com_google_googletest//:gtest", @@ -249,7 +244,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD index 5d1094eaf05714..326c88676bb11a 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD @@ -65,7 +65,6 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/tsl/platform:errors", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -82,7 +81,6 @@ cc_library( deps = [ ":stack_frame_index_builder", "//xla:xla_data_proto_cc", - "@com_google_absl//absl/log", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -229,7 +227,6 @@ xla_cc_test( "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", - "@com_google_googletest//:gtest", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", From 06c724696661704f189eb76d8beb4fd6fc1cf8ea Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:46:27 -0700 Subject: [PATCH 0640/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/cpu/codegen/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. 
rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747008876 --- third_party/xla/xla/backends/cpu/codegen/elemental/BUILD | 2 -- third_party/xla/xla/backends/cpu/codegen/emitters/ir/BUILD | 1 - 2 files changed, 3 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD b/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD index 5f45d0e5e24809..46927228ab4924 100644 --- a/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD @@ -92,7 +92,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm-project//llvm:JITLink", "@llvm-project//llvm:ir_headers", ], ) @@ -104,7 +103,6 @@ xla_cc_test( ":elemental_kernel_emitter", "//xla:xla_data_proto_cc", "//xla/codegen:kernel_definition", - "//xla/codegen:kernel_spec", "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/ir/BUILD b/third_party/xla/xla/backends/cpu/codegen/emitters/ir/BUILD index 7d223e2e09168d..d5f892ba18d2ee 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/ir/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/ir/BUILD @@ -90,6 +90,5 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", ], ) From e2e78ba607e215684719d446f196d5485eb9dd12 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:51:50 -0700 Subject: [PATCH 0641/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/cpu/runtime. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. 
Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747009829 --- .../xla/xla/backends/cpu/runtime/BUILD | 58 +------------------ 1 file changed, 2 insertions(+), 56 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index eae9577bbbe3e1..009d8a8c6b5a73 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -86,19 +86,16 @@ cc_library( hdrs = ["kernel.h"], deps = [ ":kernel_c_api", + ":work_queue", "//xla:util", - "//xla/backends/cpu/runtime:work_queue", "//xla/stream_executor:device_memory", "//xla/stream_executor:launch_dim", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:logging", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", ], @@ -168,14 +165,12 @@ cc_library( srcs = ["parallel_loop_runner.cc"], hdrs = ["parallel_loop_runner.h"], deps = [ - "//xla/backends/cpu/runtime:work_queue", + ":work_queue", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/math:math_util", "//xla/tsl/platform:logging", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/time", "@eigen_archive//:eigen3", ], ) @@ -192,8 +187,6 @@ xla_cc_test( "//xla/tsl/platform:test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", ], @@ -224,7 +217,6 @@ cc_library( "//xla/ffi:execution_context", "//xla/runtime:buffer_use", "//xla/runtime:resource_use", - "//xla/service:buffer_assignment", "//xla/service:global_device_id", "//xla/service/cpu:cpu_executable_run_options", "//xla/service/cpu:cpu_runtime", @@ -233,7 +225,6 @@ cc_library( "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/profiler/lib:traceme", @@ -281,7 +272,6 @@ cc_library( local_defines = if_windows(["_ENABLE_EXTENDED_ALIGNED_STORAGE"]), deps = [ ":thunk", - "//xla:util", "//xla/runtime:buffer_use", "//xla/runtime:execution_graph", "//xla/runtime:resource_use", @@ -304,7 +294,6 @@ cc_library( "@local_tsl//tsl/profiler/lib:connected_traceme", "@local_tsl//tsl/profiler/lib:context_types_hdrs", 
"@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/profiler/lib:traceme_encode", ], ) @@ -346,12 +335,10 @@ cc_library( deps = [ ":thunk", ":thunk_executor", - "//xla/service:buffer_assignment", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:statusor", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -399,10 +386,8 @@ cc_library( ":collective_thunk", ":thunk", "//xla:shape_util", - "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", "//xla/core/collectives:communicator", - "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", @@ -413,7 +398,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -430,7 +414,6 @@ cc_library( "//xla/service:buffer_assignment", "//xla/tsl/platform:errors", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -468,24 +451,17 @@ cc_library( ":thunk", "//xla:executable_run_options", "//xla:shape_util", - "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", - "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -513,14 +489,9 @@ xla_cc_test( ":convolution_thunk", ":convolution_thunk_test_util", ":thunk", - ":thunk_testlib", - "//xla:literal", - "//xla:literal_util", - "//xla/service:buffer_assignment", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", @@ -537,7 +508,6 @@ cc_library( ":thunk", "//xla:shape_util", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", "//xla/core/collectives:communicator", "//xla/service:buffer_assignment", @@ -552,7 +522,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -578,7 +547,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -594,7 +562,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", "//xla/core/collectives:communicator", - "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", @@ -605,7 +572,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -636,7 +602,6 @@ cc_library( 
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -703,7 +668,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -762,7 +726,6 @@ cc_library( "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -862,7 +825,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -901,7 +863,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -945,7 +906,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -1000,7 +960,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -1048,7 +1007,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -1081,7 +1039,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -1132,7 +1089,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -1179,7 +1135,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -1265,14 +1220,10 @@ cc_library( deps = [ "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/math:math_util", - "//xla/tsl/platform:env", - "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/time", "@eigen_archive//:eigen3", ], ) @@ -1284,16 +1235,11 @@ xla_cc_test( ":work_queue", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:env", - "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_main", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", ], ) From 48794c9541f1909ccaebc84f9cf4480d0ccb5901 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 00:57:13 -0700 Subject: [PATCH 0642/1324] Run build_cleaner on BUILD file(s) located in /xla/mlir_hlo. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. 
It resolves several unnecessary & missing dependencies and simplifies target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix:

* any conflicts that need manual handling
* conflicts that require choosing between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric               median               Δ  1-pval
cpu:       3590.690s ±91.6s
memory:       4533MB ±2.6MB
system:     594.230s ±10.5s
wall:       907.605s ±83.0s
```

After:
```
metric               median               Δ  1-pval
cpu:      3599.015s ±131.4s    +8.3s, +0.2%  0.03 (not significant)
memory:       4533MB ±2.3MB   +0.0MB, +0.0%  0.00 (not significant)
system:      582.305s ±9.1s   -11.9s, -2.0%  0.25 (not significant)
wall:       808.958s ±95.5s  -98.6s, -10.9%  0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/`, the estimated savings may be greater. It's a small improvement, but it should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging.
PiperOrigin-RevId: 747014736 --- third_party/xla/xla/hlo/utils/BUILD | 2 -- third_party/xla/xla/hlo/utils/concurrency/BUILD | 1 - 2 files changed, 3 deletions(-) diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index f4920f643da533..a37ac5a918b1e3 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -39,8 +39,6 @@ cc_library( deps = [ "//xla:shape_util", "//xla/hlo/analysis:hlo_alias_analysis", - "//xla/hlo/analysis:hlo_dataflow_analysis", - "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/service:hlo_buffer", "//xla/service:hlo_value", diff --git a/third_party/xla/xla/hlo/utils/concurrency/BUILD b/third_party/xla/xla/hlo/utils/concurrency/BUILD index ca54aee41d4197..6101e251ef23fd 100644 --- a/third_party/xla/xla/hlo/utils/concurrency/BUILD +++ b/third_party/xla/xla/hlo/utils/concurrency/BUILD @@ -92,7 +92,6 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/service:hlo_module_config", - "//xla/tests:hlo_test_base", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", From d09af4a5f45e87e7ed8000fdb973d2509fba7871 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 01:20:02 -0700 Subject: [PATCH 0644/1324] Run build_cleaner on BUILD file(s) located in /xla/hlo/transforms/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
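A recurring pattern in the test-target hunks of this series: `@com_google_googletest//:gtest_main` already depends on the gtest library itself, so an explicit `:gtest` entry alongside it is redundant and gets dropped. Sketch, with a hypothetical target name and the xla_cc_test load() line omitted as in the surrounding BUILD files:

```
xla_cc_test(
    name = "some_pass_test",
    srcs = ["some_pass_test.cc"],
    deps = [
        # ":gtest" removed: gtest_main links both the library and a main().
        "@com_google_googletest//:gtest_main",
    ],
)
```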
PiperOrigin-RevId: 747015329 --- third_party/xla/xla/hlo/transforms/BUILD | 4 ---- third_party/xla/xla/hlo/transforms/collectives/BUILD | 3 --- third_party/xla/xla/hlo/transforms/expanders/BUILD | 8 -------- third_party/xla/xla/hlo/transforms/simplifiers/BUILD | 4 ---- 4 files changed, 19 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/BUILD b/third_party/xla/xla/hlo/transforms/BUILD index 84943b76aecb3b..5f57b9cd9b4576 100644 --- a/third_party/xla/xla/hlo/transforms/BUILD +++ b/third_party/xla/xla/hlo/transforms/BUILD @@ -116,7 +116,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:call_graph", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -274,7 +273,6 @@ xla_cc_test( "//xla/service:memory_annotations_hdr", "//xla/service:pattern_matcher", "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", @@ -354,7 +352,6 @@ cc_library( srcs = ["host_offloading_prepare.cc"], hdrs = ["host_offloading_prepare.h"], deps = [ - "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:call_graph", @@ -416,7 +413,6 @@ cc_library( "//xla:literal_pool", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/pass:hlo_pass_pipeline", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/hlo/transforms/collectives/BUILD b/third_party/xla/xla/hlo/transforms/collectives/BUILD index 30d9bfd24bf261..d1b86c51f67cb3 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/BUILD +++ b/third_party/xla/xla/hlo/transforms/collectives/BUILD @@ -182,7 +182,6 @@ cc_library( "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -372,8 +371,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "//xla/hlo/utils:hlo_sharding_util", "//xla/service:collective_combiner_utils", "//xla/service:collective_permute_key", "//xla/service:hlo_domain_map", diff --git a/third_party/xla/xla/hlo/transforms/expanders/BUILD b/third_party/xla/xla/hlo/transforms/expanders/BUILD index e7c91bc7c51900..0f6e3fe2e4532f 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/BUILD +++ b/third_party/xla/xla/hlo/transforms/expanders/BUILD @@ -301,10 +301,7 @@ cc_library( hdrs = ["bitcast_dtypes_expander.h"], deps = [ ":op_expander_pass", - "//xla:literal_util", "//xla:shape_util", - "//xla:status_macros", - "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", @@ -314,7 +311,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_module_config", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -332,8 +328,6 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "//xla/tsl/lib/core:status_test_util", 
"@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", @@ -486,9 +480,7 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD index 79fa8234711daf..de6d92b820ffce 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD +++ b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD @@ -434,7 +434,6 @@ cc_library( "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", @@ -619,10 +618,8 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -851,7 +848,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@com_google_googletest//:gtest", ], ) From 53998017e9dfc2793fae392b99c15067f567b92b Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 01:21:26 -0700 Subject: [PATCH 0645/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/cpu/runtime/xnnpack. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747015613 --- .../xla/backends/cpu/runtime/xnnpack/BUILD | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD b/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD index 44de812771bdde..82cfc16b99ade3 100644 --- a/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD @@ -59,15 +59,10 @@ xla_cc_test( ":xnn_threadpool", "//xla/backends/cpu/runtime:parallel_loop_runner", "//xla/tsl/concurrency:async_value", - "//xla/tsl/platform:env", "//xla/tsl/platform:test", - "//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_main", "@XNNPACK", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/synchronization", - "@com_google_googletest//:gtest", - "@eigen_archive//:eigen3", "@pthreadpool", ], ) @@ -104,7 +99,6 @@ xla_cc_test( deps = [ ":xnn_convolution_thunk", "//xla:error_spec", - "//xla:executable_run_options", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -123,10 +117,7 @@ xla_cc_test( "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@eigen_archive//:eigen3", @@ -188,10 +179,6 @@ cc_library( ":xnn_interop", ":xnn_threadpool", "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/backends/cpu/runtime:dot_lib", "//xla/backends/cpu/runtime:object_pool", "//xla/backends/cpu/runtime:parallel_loop_runner", "//xla/backends/cpu/runtime:thunk", @@ -210,14 +197,10 @@ cc_library( "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/profiler/lib:traceme", "@pthreadpool", ], ) @@ -228,7 +211,6 @@ xla_cc_test( deps = [ ":xnn_fusion_thunk", ":xnn_interop", - "//xla:executable_run_options", "//xla:literal_util", "//xla:shape_util", "//xla/backends/cpu/runtime:buffer_allocations", @@ -239,7 +221,6 @@ xla_cc_test( "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@XNNPACK", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", From 290f6e3023f180facbcf4a3f3f7d612e543baa51 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 01:21:29 -0700 Subject: [PATCH 0646/1324] Run build_cleaner on BUILD file(s) located in /xla/core/collectives. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I'd run this for all of xla to fix the low-hanging fruit. It resolves several unnecessary & missing dependencies and simplifies target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that need manual handling * conflicts that require choosing between two "valid" targets * missing BUILD file in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * targets that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, this yields modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/`, the actual savings may be greater. It's a small improvement, but it should pay dividends in the long run. Note: I'll be sending a series of CLs to fix the remaining issues in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747015639 --- third_party/xla/xla/core/collectives/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/core/collectives/BUILD b/third_party/xla/xla/core/collectives/BUILD index d199bcaad52964..85f30c2c76b0bf 100644 --- a/third_party/xla/xla/core/collectives/BUILD +++ b/third_party/xla/xla/core/collectives/BUILD @@ -19,7 +19,6 @@ cc_library( srcs = ["clique.cc"], hdrs = ["clique.h"], deps = [ - ":clique_id", ":communicator", ":rank_id", "//xla:util", From e036ea6028f2bf7edae59685a3efdcafaadf2baf Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Sun, 13 Apr 2025 01:52:55 -0700 Subject: [PATCH 0647/1324] Run build_cleaner on BUILD file(s) located in /xla/pjrt. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I'd run this for all of xla to fix the low-hanging fruit. It resolves several unnecessary & missing dependencies and simplifies target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that need manual handling * conflicts that require choosing between two "valid" targets * missing BUILD file in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * targets that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, this yields modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/`, the actual savings may be greater. It's a small improvement, but it should pay dividends in the long run. Note: I'll be sending a series of CLs to fix the remaining issues in batches of subdirectories to simplify merging.
PiperOrigin-RevId: 747022627 --- third_party/xla/xla/pjrt/BUILD | 4 ---- 1 file changed, 4 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index afcfe8534c065f..cb842dc4687f12 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -621,7 +621,6 @@ xla_cc_test( "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", - "//xla/pjrt/profiling:device_time_measurement", "//xla/service:cpu_plugin", "//xla/service:platform_util", "//xla/stream_executor:platform", @@ -632,7 +631,6 @@ xla_cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", ], ) @@ -924,8 +922,6 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", "@stablehlo//:version", ], ) From dae536ac2d98f8cb181d0a253a984f2f55289c94 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 13 Apr 2025 02:02:42 -0700 Subject: [PATCH 0648/1324] Update GraphDef version to 2196. PiperOrigin-RevId: 747024992 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index c0c48201e1cb3e..e4433bb4bebf24 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2195 // Updated: 2025/4/12 +#define TF_GRAPH_DEF_VERSION 2196 // Updated: 2025/4/13 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From fccea021c053ab83444f61459f76665b46458951 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 13 Apr 2025 02:03:10 -0700 Subject: [PATCH 0649/1324] compat: Update forward compatibility horizon to 2025-04-13 PiperOrigin-RevId: 747025124 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 3377ee2bae9e17..118d7e775c7081 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 12) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 13) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 2a688ed61fdd9b2d42faf8eddccde6df57063b0b Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 13 Apr 2025 07:24:01 -0700 Subject: [PATCH 0650/1324] Automated Code Change PiperOrigin-RevId: 747084521 --- third_party/xla/xla/python/aggregate_profile.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/python/aggregate_profile.cc b/third_party/xla/xla/python/aggregate_profile.cc index 29e09d04821c95..55f6797e937f11 100644 --- a/third_party/xla/xla/python/aggregate_profile.cc +++ b/third_party/xla/xla/python/aggregate_profile.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "xla/python/xplane_to_profile_instructions.h" +#include "tsl/profiler/protobuf/profiled_instructions.pb.h" namespace xla { From 45ebe85575d10b51dec07353d73db1f4da0c3f3c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 13 Apr 2025 09:37:47 -0700 Subject: [PATCH 0651/1324] Automated Code Change PiperOrigin-RevId: 747107669 --- .../xla/xla/backends/gpu/codegen/emitters/emitter_base.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc index 4d393924bfb774..e4eb5aef1e16b3 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" From fb00f6a851d89138997372b3baae42d4a25cd50b Mon Sep 17 00:00:00 2001 From: Weiyi Wang Date: Sun, 13 Apr 2025 09:55:43 -0700 Subject: [PATCH 0652/1324] Add gelu lowering pattern for dynamic shapes PiperOrigin-RevId: 747110641 --- .../stablehlo/transforms/composite_lowering_patterns.td | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index 656b9ec692568d..aa2d50438c0c54 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -147,6 +147,13 @@ def LegalizeCompositeGELUDynamicShaped2 : Pat< (TFL_GeluOp $inputs, (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; +def LegalizeCompositeGELUDynamicShaped3 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $_, $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + def LegalizeCompositeOdmlEmbeddingLookup : Pat< (MHLO_CompositeOp:$composite (variadic $indices, $table), From 2a3f646e6177178fde3e79f15b582580252e558c Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 13 Apr 2025 12:52:53 -0700 Subject: [PATCH 0653/1324] Automated Code Change PiperOrigin-RevId: 747141894 --- third_party/xla/xla/BUILD | 1 + third_party/xla/xla/index_util_test.cc | 1 + third_party/xla/xla/layout.cc | 1 + third_party/xla/xla/literal_util.cc | 4 ++++ third_party/xla/xla/metric_table_report.cc | 1 + 5 files changed, 8 insertions(+) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 766da958265d22..1c1b04a744a478 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -716,6 +716,7 @@ cc_library( "//xla/tsl/platform:status", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/index_util_test.cc b/third_party/xla/xla/index_util_test.cc index a312293d32b586..a828e7c7674353 100644 --- a/third_party/xla/xla/index_util_test.cc +++ b/third_party/xla/xla/index_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/index_util.h" +#include #include #include diff --git a/third_party/xla/xla/layout.cc b/third_party/xla/xla/layout.cc index 9e6a3b0e687221..e7218ca7006afb 100644 --- a/third_party/xla/xla/layout.cc +++ b/third_party/xla/xla/layout.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/layout.h" +#include #include #include #include diff --git a/third_party/xla/xla/literal_util.cc b/third_party/xla/xla/literal_util.cc index c975499125a103..3713c2beffbf02 100644 --- a/third_party/xla/xla/literal_util.cc +++ b/third_party/xla/xla/literal_util.cc @@ -15,16 +15,20 @@ limitations under the License. #include "xla/literal_util.h" +#include #include +#include #include #include #include +#include #include #include #include #include #include "absl/random/uniform_int_distribution.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" diff --git a/third_party/xla/xla/metric_table_report.cc b/third_party/xla/xla/metric_table_report.cc index 6852e7a2a71f5e..605e9cd43d6079 100644 --- a/third_party/xla/xla/metric_table_report.cc +++ b/third_party/xla/xla/metric_table_report.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include From c2a3c368e79be0292faeac380086c42169763908 Mon Sep 17 00:00:00 2001 From: Joan Puigcerver Date: Sun, 13 Apr 2025 16:12:38 -0700 Subject: [PATCH 0654/1324] Change breaking internal tests. 
Reverts 6f5ef47b4316ef0cb10d711d8de8903b3d6a28cd PiperOrigin-RevId: 747176251 --- .../xla/xla/service/collective_pipeliner.cc | 60 ++++++++----------- .../xla/service/collective_pipeliner_test.cc | 5 +- 2 files changed, 27 insertions(+), 38 deletions(-) diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 1f7c172ab4387f..60e57b4672641f 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -211,7 +211,7 @@ CollectDynamicSliceIndicesIfConstant(HloInstruction* instr) { for (int64_t i = dyn_slice->first_index_operand_number(); i < instr->operand_count(); ++i) { HloInstruction* operand = dyn_slice->mutable_operand(i); - CHECK(operand->shape().dimensions().empty()); + CHECK_EQ(operand->shape().dimensions_size(), 0); std::vector> stack( 1, std::make_pair(operand, 0)); absl::flat_hash_set visited; @@ -343,11 +343,12 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, ShapeUtil::ElementsIn(instr->operand(0)->shape()) < 1024)) { return true; } + // TODO(b/409716406): Reconsider cases where Pad can be supported. return HloPredicateIsOp(i) || + HloOpcode::kCollectivePermute, HloOpcode::kConvert, + HloOpcode::kReshape, HloOpcode::kAllReduce, + HloOpcode::kTranspose, HloOpcode::kBroadcast, + HloOpcode::kAllGather>(i) || (multi_uses_pipelining && i->IsElementwise()) || i->IsCustomCall(CollectivePipeliner::kInsertedByPreviousStep) || i->IsCustomCall(CollectivePipeliner::kSunkByPreviousStep); @@ -1528,7 +1529,7 @@ Shape ComputeFullOutputShape(const WhileMoveInfo& move_info, // Create zero of base type ptype and broadcast it to shape. HloInstruction* CreateZero(HloComputation* comp, const Shape& shape, PrimitiveType ptype) { - if (shape.dimensions().empty()) { + if (shape.dimensions_size() == 0) { return comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(ptype))); } @@ -2008,8 +2009,8 @@ absl::Status TransformLoopForward( if (slice_target_shape != data_to_slice->shape()) { // Slice matrix. 
absl::InlinedVector dynamic_slice_sizes; - dynamic_slice_sizes.reserve(slice_target_shape.dimensions().size()); - for (int i = 0; i < slice_target_shape.dimensions().size(); ++i) { + dynamic_slice_sizes.reserve(slice_target_shape.dimensions_size()); + for (int i = 0; i < slice_target_shape.dimensions_size(); ++i) { dynamic_slice_sizes.push_back(slice_target_shape.dimensions(i)); } sliced_data = @@ -2267,7 +2268,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, Shape index_shape = move_info.dynamic_update_slices.front()->index_shapes()[0]; std::vector indices( - expanded_shape.dimensions().size(), + expanded_shape.dimensions_size(), CreateZero(body_computation, index_shape, index_shape.element_type())); indices[0] = move_info.dynamic_update_slices.front()->index_operands()[0]; @@ -2312,7 +2313,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, HloDynamicUpdateSliceInstruction* dyn_update = to_move.dynamic_update_slices[0]; std::vector indices( - expanded_shape.dimensions().size(), + expanded_shape.dimensions_size(), CreateZero(body_computation, dyn_update->index_shapes()[0], dyn_update->index_shapes()[0].element_type())); indices[0] = dyn_update->index_operands()[0]; @@ -2433,7 +2434,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, if (is_loop_invariant) { Shape full_shape = ComputeFullOutputShape(to_move, pipelined->shape()); absl::InlinedVector operand_dims; - operand_dims.resize(pipelined->shape().dimensions().size()); + operand_dims.resize(pipelined->shape().dimensions_size()); absl::c_iota(operand_dims, 1); HloInstruction* broadcasted = loop_computation->AddInstruction(HloInstruction::CreateBroadcast( @@ -2459,30 +2460,21 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, std::vector operands; for (auto* operand : instr->mutable_operands()) { if (operand->opcode() == HloOpcode::kConstant) { - if (!operand->shape().dimensions().empty()) { - // Broadcast constant into full shape. - HloInstruction* cloned_constant = loop_computation->AddInstruction( - operand->CloneWithNewOperands(operand->shape(), {})); - if (!to_add_batch_set.contains(instr)) { - operands.push_back(cloned_constant); - continue; - } - Shape full_shape = - ComputeFullOutputShape(to_move, cloned_constant->shape()); - absl::InlinedVector operand_dims; - operand_dims.resize(cloned_constant->shape().dimensions().size()); - absl::c_iota(operand_dims, 1); - HloInstruction* broadcasted = loop_computation->AddInstruction( - HloInstruction::CreateBroadcast(full_shape, cloned_constant, - operand_dims)); - operands.push_back(broadcasted); + HloInstruction* cloned_constant = loop_computation->AddInstruction( + operand->CloneWithNewOperands(operand->shape(), {})); + if (!to_add_batch_set.contains(instr)) { + operands.push_back(cloned_constant); continue; } - // The constant may be for something like a padding value. And a - // scalar shape can't be for slice shape that's to be concatenated. - // No need to broadcast. 
- operands.push_back(loop_computation->AddInstruction( - operand->CloneWithNewOperands(operand->shape(), {}))); + Shape full_shape = + ComputeFullOutputShape(to_move, cloned_constant->shape()); + absl::InlinedVector operand_dims; + operand_dims.resize(cloned_constant->shape().dimensions_size()); + absl::c_iota(operand_dims, 1); + HloInstruction* broadcasted = + loop_computation->AddInstruction(HloInstruction::CreateBroadcast( + full_shape, cloned_constant, operand_dims)); + operands.push_back(broadcasted); continue; } auto it = pipelined_map.find(operand); @@ -2554,7 +2546,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, } // Constant scalars don't get expanded ahead of time and are kept // scalar. - if (operands[0]->shape().dimensions().empty()) { + if (operands[0]->shape().dimensions_size() == 0) { dimensions.clear(); } HloInstruction* expanded_broadcast = diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index e1c72ab46888c0..556d0234db0500 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -3253,10 +3253,7 @@ while_body { dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 - slice = bf16[1,8,120] slice(ar.1), slice={[0:1], [0:8], [0:120]} - constant.2563 = bf16[] constant(5.0) - pad = bf16[1,8,128] pad(slice, constant.2563), padding=0_0x0_0x0_8 - b.1 = bf16[1,8,128,32] broadcast(pad), dimensions={0,1,2} + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} constant = bf16[] constant(0) reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) From 19c30a9fab62c2d3a94a9c1591f0098aee6012d1 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 13 Apr 2025 21:21:05 -0700 Subject: [PATCH 0655/1324] Automated Code Change PiperOrigin-RevId: 747239490 --- tensorflow/python/lib/io/BUILD | 7 +++++++ tensorflow/python/lib/io/file_io_wrapper.cc | 3 +++ tensorflow/python/lib/io/record_io_wrapper.cc | 3 ++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/lib/io/BUILD b/tensorflow/python/lib/io/BUILD index 3c1c2ff182cd09..34fda92c56093e 100644 --- a/tensorflow/python/lib/io/BUILD +++ b/tensorflow/python/lib/io/BUILD @@ -27,9 +27,15 @@ tf_python_pybind_extension( ], deps = [ "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core:portable_jpeg_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:file_statistics", "//tensorflow/python/lib/core:pybind11_absl", "//tensorflow/python/lib/core:pybind11_status", + "@com_google_absl//absl/status", "@pybind11", ], ) @@ -51,6 +57,7 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:pybind11_absl", "//tensorflow/python/lib/core:pybind11_status", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@pybind11", ], ) diff --git a/tensorflow/python/lib/io/file_io_wrapper.cc b/tensorflow/python/lib/io/file_io_wrapper.cc index 54ec610b534a6d..5b310a5772ecf2 100644 --- a/tensorflow/python/lib/io/file_io_wrapper.cc +++ b/tensorflow/python/lib/io/file_io_wrapper.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include +#include "absl/status/status.h" #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 #include "tensorflow/core/lib/core/error_codes.pb.h" diff --git a/tensorflow/python/lib/io/record_io_wrapper.cc b/tensorflow/python/lib/io/record_io_wrapper.cc index 3b27d5b8a1885a..1e4256a9418686 100644 --- a/tensorflow/python/lib/io/record_io_wrapper.cc +++ b/tensorflow/python/lib/io/record_io_wrapper.cc @@ -15,8 +15,9 @@ limitations under the License. #include #include +#include -#include "absl/memory/memory.h" +#include "absl/status/status.h" #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" From c3f5cd1973d208173aaa512e4f69cf517ec96d53 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 13 Apr 2025 23:02:35 -0700 Subject: [PATCH 0656/1324] Automated Code Change PiperOrigin-RevId: 747262393 --- third_party/xla/xla/hlo/tools/hlo_diff/utils/BUILD | 1 + .../xla/hlo/tools/hlo_diff/utils/connected_components_test.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/utils/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/utils/BUILD index c820faaaad25b6..32231488c5bb4d 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/utils/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_diff/utils/BUILD @@ -32,6 +32,7 @@ xla_cc_test( deps = [ ":connected_components", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components_test.cc index 2b6c0237506556..10b8f2b7b98196 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components_test.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/utils/connected_components_test.cc @@ -22,6 +22,7 @@ #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace hlo_diff { From 8ebaf621730a2c62198fe486cdd44622554208f8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 13 Apr 2025 23:54:40 -0700 Subject: [PATCH 0657/1324] Automated Code Change PiperOrigin-RevId: 747274751 --- .../xla/xla/service/gpu/transforms/BUILD | 64 ++++++++++++++++++- .../transforms/command_buffer_scheduling.h | 1 + .../command_buffer_scheduling_test.cc | 4 ++ .../service/gpu/transforms/conv_rewriter.cc | 4 +- .../gpu/transforms/conv_rewriter_test.cc | 4 ++ .../xla/service/gpu/transforms/copy_fusion.cc | 2 + .../gpu/transforms/copy_fusion_test.cc | 1 + .../transforms/cublas_gemm_rewriter_test.cc | 2 + .../gpu/transforms/cublas_pad_for_gemms.cc | 3 + .../gpu/transforms/cublas_pad_for_gemms.h | 1 + .../transforms/cublas_pad_for_gemms_test.cc | 2 + .../transforms/cudnn_custom_call_compiler.cc | 3 +- .../transforms/cudnn_fused_conv_rewriter.cc | 3 + .../cudnn_fused_conv_rewriter_test.cc | 4 ++ .../transforms/cudnn_pad_for_convolutions.cc | 4 ++ .../cudnn_pad_for_convolutions_test.cc | 1 + .../gpu/transforms/cudnn_simplify_padding.cc | 3 + .../transforms/cudnn_simplify_padding_test.cc | 3 + .../cudnn_vectorize_convolutions.cc | 3 + .../cudnn_vectorize_convolutions_test.cc | 1 + .../custom_kernel_fusion_rewriter.cc | 5 ++ .../custom_kernel_fusion_rewriter_test.cc | 2 + .../gpu/transforms/dot_algorithm_rewriter.cc | 1 + .../gpu/transforms/dot_dimension_sorter.cc | 2 + .../transforms/dot_dimension_sorter_test.cc | 1 + .../service/gpu/transforms/dot_normalizer.cc | 1 + .../gpu/transforms/dot_normalizer_test.cc | 3 + .../gpu/transforms/dot_operand_converter.cc | 2 + .../dynamic_slice_fusion_rewriter.cc | 6 ++ .../dynamic_slice_fusion_rewriter_test.cc | 4 ++ ...xplicit_collectives_group_async_wrapper.cc | 2 + .../fusion_block_level_rewriter_test.cc | 1 - .../fusion_dynamic_memcpy_rewriter.cc | 1 - .../fusion_dynamic_memcpy_rewriter_test.cc | 3 +- .../gpu/transforms/fusion_wrapper_test.cc | 2 +- .../gemm_broadcast_folding_rewriter_test.cc | 2 + .../gpu/transforms/gemm_rewriter_fp8_test.cc | 1 + .../gpu/transforms/splitk_rewriter_test.cc | 1 + 38 files changed, 146 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD 
b/third_party/xla/xla/service/gpu/transforms/BUILD index 6eb28726d1db21..6b21f9d1c50a58 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -30,6 +30,7 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", @@ -698,6 +699,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla:xla_proto_cc", "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -731,6 +733,7 @@ xla_test( ], deps = [ ":command_buffer_scheduling", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:filecheck", @@ -743,6 +746,7 @@ xla_test( "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -810,6 +814,8 @@ cc_library( "//xla/stream_executor:dnn", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -829,6 +835,7 @@ xla_cc_test( "//xla:array4d", "//xla:literal_util", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/testlib:test", @@ -841,7 +848,10 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -862,6 +872,7 @@ cc_library( "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", @@ -878,6 +889,7 @@ xla_cc_test( "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:pattern_matcher", "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", @@ -894,6 +906,7 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/backends/gpu/codegen/triton:support", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -901,6 +914,7 @@ cc_library( "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:logging", @@ -917,6 +931,8 @@ xla_cc_test( ], deps = [ ":cublas_pad_for_gemms", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:pattern_matcher", @@ -966,6 +982,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + 
"//xla:xla_proto_cc", "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -977,6 +994,7 @@ cc_library( "//xla/stream_executor:dnn", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1012,6 +1030,8 @@ xla_test( ":cudnn_fused_conv_rewriter", "//xla:comparison_util", "//xla:error_spec", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", @@ -1033,7 +1053,9 @@ xla_test( "//xla/stream_executor:semantic_version", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1144,6 +1166,7 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:cublas_cudnn", @@ -1152,6 +1175,8 @@ cc_library( "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -1167,6 +1192,7 @@ xla_cc_test( srcs = ["cudnn_pad_for_convolutions_test.cc"], deps = [ ":cudnn_pad_for_convolutions", + "//xla:xla_data_proto_cc", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:pattern_matcher", @@ -1193,6 +1219,8 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1211,6 +1239,7 @@ xla_cc_test( ":cudnn_vectorize_convolutions", "//xla:literal", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/testlib:pattern_matcher_gmock", @@ -1225,6 +1254,8 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", # build_cleaner: keep "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", @@ -1241,6 +1272,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", @@ -1256,6 +1288,8 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1272,6 +1306,7 @@ xla_cc_test( deps = [ ":cudnn_vectorize_convolutions", "//xla:util", + "//xla:xla_data_proto_cc", 
"//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:call_inliner", @@ -1321,7 +1356,11 @@ cc_library( "//xla/stream_executor/cuda:cudnn_plugin", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - ]) + ["@com_google_absl//absl/container:flat_hash_map"], + ]) + [ + "//xla:xla_data_proto_cc", + "//xla/tsl/protobuf:dnn_proto_cc", + "@com_google_absl//absl/container:flat_hash_map", + ], ) cc_library( @@ -1331,8 +1370,11 @@ cc_library( tags = ["gpu"], deps = [ "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_proto_cc", + "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu/kernels:custom_fusion_library", "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", "//xla/stream_executor:device_description", @@ -1340,6 +1382,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1356,6 +1399,7 @@ xla_cc_test( deps = [ ":custom_kernel_fusion_rewriter", "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", "//xla/stream_executor:device_description", @@ -1378,6 +1422,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -1394,6 +1439,7 @@ xla_test( deps = [ ":dot_dimension_sorter", "//xla:error_spec", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/tests:xla_internal_test_main", # fixdeps: keep @@ -1409,6 +1455,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", "@com_google_absl//absl/status:statusor", @@ -1422,6 +1469,7 @@ xla_cc_test( srcs = ["dot_normalizer_test.cc"], deps = [ ":dot_normalizer", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:pattern_matcher", @@ -1440,8 +1488,10 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", @@ -1573,6 +1623,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/ffi:ffi_api", "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", @@ -1581,6 +1632,7 @@ cc_library( "//xla/hlo/utils:hlo_traversal", "//xla/service:call_graph", "//xla/service:custom_call_target_registry", + "//xla/service:hlo_proto_cc", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:gpu_constants", @@ -1590,6 +1642,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", 
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1610,6 +1664,8 @@ xla_cc_test( deps = [ ":dynamic_slice_fusion_rewriter", "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/ffi", "//xla/ffi:ffi_api", "//xla/hlo/builder:xla_builder", @@ -1617,6 +1673,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service:custom_call_target_registry", "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", @@ -1722,6 +1779,7 @@ xla_cc_test( srcs = ["fusion_wrapper_test.cc"], deps = [ ":fusion_wrapper", + "//xla/stream_executor:device_description_proto_cc", "//xla/tests:hlo_test_base", "@com_google_googletest//:gtest_main", ], @@ -1756,6 +1814,7 @@ xla_test( ":gemm_broadcast_folding_rewriter", ":gemm_rewriter", "//xla:error_spec", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:semantic_version", @@ -1963,6 +2022,7 @@ xla_test( ":gemm_rewriter", ":gemm_rewriter_test_lib", "//xla:error_spec", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", @@ -1989,6 +2049,7 @@ xla_test( ":gemm_rewriter", ":gemm_rewriter_test_lib", "//xla:error_spec", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -3615,6 +3676,7 @@ xla_cc_test( "//xla/hlo/pass:hlo_pass", "//xla/hlo/testlib:filecheck", "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", "//xla/tsl/platform:statusor", diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h index 71d5b421c1ee56..b06acc83127e86 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h @@ -31,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" namespace xla::gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index e74ff34ec0d625..ecc753d2dbfc10 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -15,11 +15,14 @@ limitations under the License. #include "xla/service/gpu/transforms/command_buffer_scheduling.h" #include +#include #include #include +#include #include #include +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" @@ -34,6 +37,7 @@ limitations under the License. 
#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/xla.pb.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc index 714c7561868192..13aab8fd9203c8 100644 --- a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc @@ -28,8 +28,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/strings/str_replace.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc index d186f88775cbcb..6330e9e66242b8 100644 --- a/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc @@ -19,7 +19,10 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "xla/array4d.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -35,6 +38,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/util/proto/proto_matchers.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc index 23706a4dbcf149..1ff643bddcb203 100644 --- a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc @@ -21,6 +21,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc index e557990fa07c11..b55271ac81efe8 100644 --- a/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/service/pattern_matcher.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_description.pb.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cublas_gemm_rewriter_test.cc index aac833f4cedb03..5784d91d86d234 100644 --- a/third_party/xla/xla/service/gpu/transforms/cublas_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cublas_gemm_rewriter_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_replace.h" @@ -32,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/semantic_version.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc index c94f142af8d5c5..d79d13051e0f41 100644 --- a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc +++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc @@ -20,6 +20,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/backends/gpu/codegen/triton/support_legacy.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -31,6 +33,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h index 15c6c74d9b50e4..9df27d017e2c4a 100644 --- a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h +++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc index aac0a4f579cc93..ee93e56f06da1c 100644 --- a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc @@ -23,6 +23,8 @@ limitations under the License. 
#include "xla/service/pattern_matcher.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" namespace m = ::xla::match; diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index 4bd09cc48e3364..66017bd43dd0d9 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -46,7 +45,9 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_dnn.h" #include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" #include "xla/stream_executor/dnn.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc index 71f30a90f6046d..ee7b9c91398d0e 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -55,7 +56,9 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/util.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/ml_dtypes.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc index b0e0ac3a78c823..f2da50e6bcaf55 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_replace.h" @@ -54,6 +55,9 @@ limitations under the License. #include "xla/stream_executor/semantic_version.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/dnn.pb.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc index 49c680d5073982..7a0200cc254fdf 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc @@ -25,7 +25,10 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/functional/bind_front.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -40,6 +43,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc index 55c5197a13c915..67acd3a1fb102d 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/pattern_matcher.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc index 577f55d21e69e3..b58d6837acae93 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc @@ -24,6 +24,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc index 38cbe080d1d53b..f369ab9a0e8c47 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include #include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/hlo/pass/hlo_pass_fix.h" @@ -40,6 +42,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc index 0fdf28c82621a2..fdef50d6b13415 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc @@ -24,6 +24,8 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -48,6 +50,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc index 5118d99d60248f..d3faf0b1363524 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc index bfbdef3dadda56..223b0338c3c38b 100644 --- a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc @@ -24,6 +24,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -32,9 +34,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" +#include "xla/service/hlo.pb.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc index 4d44d99066e9ff..1fb3e0d43e1ba4 100644 --- a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc @@ -19,9 +19,11 @@ limitations under the License. 
#include #include +#include #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.cc index fedba5849fd7dd..fac30e4daac8c4 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" namespace xla::gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc index 7bfdb137e47c12..736e84ccbc3a57 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc @@ -22,7 +22,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc index 8b6efd2e757da8..aa97bee74bd379 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/xla.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc index 924b70fd48ec7f..853562a43be3a3 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla::gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/dot_normalizer_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_normalizer_test.cc index 8e8c83017821ca..dcaaa10db0f663 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_normalizer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_normalizer_test.cc @@ -15,10 +15,13 @@ limitations under the License. 
#include "xla/service/gpu/transforms/dot_normalizer.h" +#include +#include #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/service/pattern_matcher.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc index 4f2606d9ca6574..4840c0365fe106 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc @@ -15,10 +15,12 @@ limitations under the License. #include "xla/service/gpu/transforms/dot_operand_converter.h" +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla::gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc index 61163d3bd8cc48..e073e5062060d7 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -27,9 +28,12 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/ffi/ffi_api.h" @@ -47,9 +51,11 @@ limitations under the License. #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc index d7870e03c47723..41276d156c4da2 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" @@ -27,12 +28,15 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.cc index 42aa4b8af87fb5..e085dd20fcb4b2 100644 --- a/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.h" +#include + #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc index fd8c8b6e008260..eb6816d5acf863 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/gpu/transforms/fusion_block_level_rewriter.h" -#include #include #include diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter.cc index bc7dcbbfe9c25b..c27c22faf66bd5 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter.h" #include -#include #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter_test.cc index d90df2905437e3..d7205cb68071f1 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter_test.cc @@ -15,9 +15,10 @@ limitations under the License. #include "xla/service/gpu/transforms/fusion_dynamic_memcpy_rewriter.h" +#include + #include #include -#include "absl/log/check.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc index e1326bb27cfc47..6705cac8c8fcd3 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc @@ -14,10 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "xla/service/gpu/transforms/fusion_wrapper.h" -#include #include #include +#include "xla/stream_executor/device_description.pb.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc index 6973521e783734..312a57e4678d24 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc @@ -17,11 +17,13 @@ limitations under the License. #include +#include #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/stream_executor/semantic_version.h" +#include "xla/xla.pb.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_fp8_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_fp8_test.cc index c766ba81bd3179..27d587fea3e19e 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_fp8_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_fp8_test.cc @@ -45,6 +45,7 @@ limitations under the License. #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/splitk_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/splitk_rewriter_test.cc index bc32790c468dd9..3aa1bbd0f00335 100644 --- a/third_party/xla/xla/service/gpu/transforms/splitk_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/splitk_rewriter_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/filecheck.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_description.pb.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" #include "xla/tsl/platform/statusor.h" From 19836287138a2d62a86a51997d528078a8333e00 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Sun, 13 Apr 2025 23:58:34 -0700 Subject: [PATCH 0658/1324] PR #25065: [GPU] Fix se_gpu_pjrt_client_test/ShardedAutotuningTest in OSS. Imported from GitHub PR https://github.com/openxla/xla/pull/25065 Copybara import of the project: -- defe869c98da53b2818dcfe7bcd8678c3846a2e7 by Ilia Sergachev : [GPU] Fix se_gpu_pjrt_client_test/ShardedAutotuningTest in OSS. Merging this change closes #25065 PiperOrigin-RevId: 747275748 --- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index e723f9e6f170e1..c2c9683e8bbc15 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -2237,6 +2237,12 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) { std::string cache_dir; CHECK(tsl::Env::Default()->LocalTempFilename(&cache_dir)); + if (tsl::kIsOpenSource) { + // Test relies on VLOG(1) messages. Enable VLOG(1) in OSS. 
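// (Reviewer annotation, not part of the patch: the setenv call below now runs
// in the parent test process, before the tsl::SubProcess children further
// down are launched, so the children inherit TF_CPP_VMODULE at startup. The
// old call site inside ShardedAutotuningWorksTestBody, removed below, only
// ran inside the child, presumably after VLOG levels had already been
// initialized, which is why it did not take effect in OSS.)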
+    tsl::setenv("TF_CPP_VMODULE", "gemm_fusion_autotuner=1",
+                /*overwrite=*/true);
+  }
+
   // Compile twice to test both empty and non-empty disk cache.
   for (int iteration = 0; iteration < 2; ++iteration) {
     tsl::SubProcess child[kNumNodes];
@@ -2298,11 +2304,6 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
                                             const int num_nodes_using_cache,
                                             absl::string_view cache_dir,
                                             bool use_xla_computation) {
-  if (tsl::kIsOpenSource) {
-    // Test relies on VLOG(1) messages. Enable VLOG(1) in OSS.
-    tsl::setenv("TF_CPP_VMODULE", "gemm_fusion_autotuner=1",
-                /*overwrite=*/true);
-  }
   std::unique_ptr service;
   if (node_id == 0) {
     TF_ASSIGN_OR_RETURN(

From b238280eb1648473ca890580d4d06bdb4903f978 Mon Sep 17 00:00:00 2001
From: Theotime Combes
Date: Mon, 14 Apr 2025 00:10:06 -0700
Subject: [PATCH 0659/1324] [XLA:GPU] Check the shape of the root instruction
 in triton support test

Instead of the tested instruction. The codegen wrapper will generate a fusion
from the whole ENTRY computation, and start emitting from its root.

PiperOrigin-RevId: 747279592
---
 .../gpu/codegen/triton/support_test.cc | 20 +++++++++----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc
index 63f742904f604d..949986d2b4f899 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc
@@ -232,15 +232,16 @@ class TritonSupportTest : public TritonSupportTestBase {
   // If that is not the case, codegen could fail for that reason---which
   // wouldn't give any valuable signal here. The check is only done for array
   // and tuple shapes (only one layer of nesting is supported for tuples).
-  if (ti.Instruction().shape().IsArray()) {
+  const auto& root_instruction = ti.TritonComputation().root_instruction();
+  if (root_instruction->shape().IsArray()) {
     ASSERT_EQ(output_tile_sizes.size(), 1);
     ASSERT_EQ(output_tile_sizes[0].size(),
-              ti.Instruction().shape().dimensions().size());
-  } else if (ti.Instruction().shape().IsTuple()) {
+              root_instruction->shape().dimensions().size());
+  } else if (root_instruction->shape().IsTuple()) {
     ASSERT_EQ(output_tile_sizes.size(),
-              ti.Instruction().shape().tuple_shapes_size());
+              root_instruction->shape().tuple_shapes_size());
     for (int64_t i = 0; i < output_tile_sizes.size(); ++i) {
-      const auto& shape = ti.Instruction().shape().tuple_shapes(i);
+      const auto& shape = root_instruction->shape().tuple_shapes(i);
       if (shape.IsTuple()) {
         continue;  // No validation for nested tuples, as there is no way to
                    // specify output tile sizes for them.
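To see why the root matters, consider a hand-written example (illustrative
only, not taken from the patch): when the tested instruction is an async start
op, the instruction itself is tuple-shaped, but the ENTRY root that the
codegen wrapper actually emits is the matching done op with an array shape.

ENTRY triton_computation {
  parameter_0 = f32[2,2] parameter(0)
  start = (f32[2,2], f32[2,2]) collective-permute-start(parameter_0),
    source_target_pairs={{0,1}}
  ROOT done = f32[2,2] collective-permute-done(start)
}

Validating output_tile_sizes against `start` would demand two tile vectors,
while the emitted root `done` needs exactly one, e.g. {2, 2}, which is what
the updated RunSupportTest expectations in the next hunk encode.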
@@ -1159,8 +1160,7 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, HloOpcode::kCollectivePermuteDone)); - RunSupportTestMultipleOutputTiles(std::move(ti_start), - /*output_tile_sizes=*/{{2, 2}, {2, 2}}, cc); + RunSupportTest(std::move(ti_start), /*output_tile_sizes=*/{2, 2}, cc); RunSupportTest(std::move(ti_done), /*output_tile_sizes=*/{2, 2}, cc); } @@ -1215,10 +1215,8 @@ ENTRY triton_computation { TestedInstruction ti_done, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, HloOpcode::kAsyncDone)); - RunSupportTestMultipleOutputTiles(std::move(ti_start), - /*output_tile_sizes=*/{{1}, {1}}, cc); - RunSupportTestMultipleOutputTiles(std::move(ti_update), - /*output_tile_sizes=*/{{1}, {1}}, cc); + RunSupportTest(std::move(ti_start), /*output_tile_sizes=*/{1}, cc); + RunSupportTest(std::move(ti_update), /*output_tile_sizes=*/{1}, cc); RunSupportTest(std::move(ti_done), /*output_tile_sizes=*/{1}, cc); } From fa1e7543a53d1e8ef15a87bcbe61c3376687e4a1 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 14 Apr 2025 00:33:37 -0700 Subject: [PATCH 0660/1324] [XLA:GPU][EMITTERS] Remove pipelining pass. PiperOrigin-RevId: 747286426 --- .../emitters/transforms/optimize_loops.cc | 178 ------------------ .../transforms/tests/optimize_loops.mlir | 101 ---------- 2 files changed, 279 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/optimize_loops.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/optimize_loops.cc index b764c8626e1702..97448a64042cd2 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/optimize_loops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/optimize_loops.cc @@ -54,175 +54,6 @@ namespace gpu { namespace { -mlir::Value GetSource(mlir::vector::TransferReadOp op) { - return op.getSource(); -} - -bool DoIndicesDependOnInductionVar(mlir::ValueRange indices, - mlir::scf::ForOp loop) { - // We assume LICM ran, so we can just check if any index is defined in the - // loop. - return absl::c_any_of(indices, [&](mlir::Value v) { - return v.getParentRegion() == &loop.getBodyRegion(); - }); -} - -bool CanReplaceInductionVar(mlir::ValueRange indices) { - return absl::c_all_of(indices, [&](mlir::Value v) { - if (auto bbarg = mlir::dyn_cast(v)) { - auto for_op = mlir::dyn_cast_or_null( - v.getParentRegion()->getParentOp()); - // This is a bbarg that is defined outside of the loop, so it doesn't - // affect pipelining. - if (!for_op) { - return true; - } - // We can only replace the induction variable, not other loop-carried - // values. 
- return v == for_op.getInductionVar(); - } - auto* op = v.getDefiningOp(); - return op && - mlir::isa( - op) && - CanReplaceInductionVar(op->getOperands()); - }); -} - -llvm::SmallVector ReplaceInductionVar( - mlir::Value induction_var, mlir::Value replacement, - llvm::SmallVector indices, - mlir::ImplicitLocOpBuilder& builder) { - for (mlir::Value& index : indices) { - if (mlir::isa(index)) { - if (index == induction_var) { - index = replacement; - } - continue; - } - - auto* op = index.getDefiningOp(); - CHECK(op) << "Did CanReplaceInductionVar() fail?"; - if (mlir::isa(op)) { - continue; - } - - CHECK( - (mlir::isa(op))) - << "Did CanReplaceInductionVar() fail?"; - auto replaced_args = ReplaceInductionVar(induction_var, replacement, - op->getOperands(), builder); - index = builder - .create(builder.getLoc(), op->getName().getIdentifier(), - replaced_args, op->getResultTypes(), op->getAttrs()) - ->getResult(0); - } - return indices; -} - -mlir::Value GetSource(mlir::tensor::ExtractOp op) { return op.getTensor(); } - -// TODO(jreiffers): Use a shared memory queue for pipelining instead of -// registers. -template -struct PipelineLoad : mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - Op op, mlir::PatternRewriter& rewriter) const override { - auto loop = mlir::dyn_cast_or_null(op->getParentOp()); - if (!loop) { - return rewriter.notifyMatchFailure(op, "no loop found"); - } - - if (auto step = loop.getConstantStep(); - !step || step->getSExtValue() != 1) { - return rewriter.notifyMatchFailure(op, "loop step is not 1"); - } - - llvm::APInt lb, ub; - if (!mlir::matchPattern(loop.getLowerBound(), mlir::m_ConstantInt(&lb)) || - !mlir::matchPattern(loop.getUpperBound(), mlir::m_ConstantInt(&ub))) { - return rewriter.notifyMatchFailure(op, "bounds are not constants"); - } - if (lb.getSExtValue() != 0) { - return rewriter.notifyMatchFailure(op, "lower bound is not 0"); - } - - auto source = GetSource(op); - if (!source.getParentRegion()->isProperAncestor(&loop.getBodyRegion())) { - return rewriter.notifyMatchFailure( - op, "source is not defined outside the loop"); - } - - if (!DoIndicesDependOnInductionVar(op.getIndices(), loop)) { - // We don't run LICM between iterations, so this could happen. - // Just hoist the load out of the loop. 
- rewriter.moveOpBefore(op, loop); - return mlir::success(); - } - - if (!CanReplaceInductionVar(op.getIndices())) { - return rewriter.notifyMatchFailure(op, "unable to replace indices"); - } - - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - mlir::Value zero = b.create(0); - - b.setInsertionPoint(loop); - auto first_args = - ReplaceInductionVar(loop.getInductionVar(), zero, op.getOperands(), b); - auto loaded_first = - b.create(op->getResultTypes(), first_args, op->getAttrs()); - auto ub_minus_one = - b.create(ub.getSExtValue() - 1); - - b.setInsertionPointToStart(loop.getBody()); - - auto needs_load = b.create( - mlir::arith::CmpIPredicate::ult, loop.getInductionVar(), ub_minus_one); - auto next_value = - b.create(op->getResultTypes(), needs_load, true, true); - auto new_for = - mlir::cast(*loop.replaceWithAdditionalYields( - rewriter, loaded_first->getResult(0), - /*replaceInitOperandUsesInLoop=*/false, - [&](mlir::OpBuilder&, mlir::Location, - llvm::ArrayRef) { - return llvm::SmallVector{next_value->getResult(0)}; - })); - rewriter.replaceAllUsesWith(op, new_for.getRegionIterArgs().back()); - - b.setInsertionPointToStart(next_value.thenBlock()); - auto yield = b.create(op->getResult(0)); - - // We use this convoluted way to add 1 so folding works properly. - auto plus_one_map = mlir::AffineMap::get( - 1, 0, mlir::getAffineDimExpr(0, this->getContext()) + 1); - b.setInsertionPoint(next_value); - IndexingMap indexing_map(plus_one_map, - {IndexingMap::Variable{0, ub.getSExtValue() - 1}}, - /*range_vars=*/{}, /*rt_vars=*/{}); - auto induction_plus_one = - b.create(new_for.getInductionVar(), indexing_map) - ->getResult(0); - - // Create the new apply_indexing ops outside the if, to improve CSE. - rewriter.modifyOpInPlace(op, [&]() { - op->setOperands(ReplaceInductionVar( - new_for.getInductionVar(), induction_plus_one, op->getOperands(), b)); - }); - rewriter.moveOpBefore(op, yield); - - b.setInsertionPointToStart(next_value.elseBlock()); - b.create(new_for.getRegionIterArgs().back()); - return mlir::success(); - } -}; - int GetUnrollingFactor(mlir::scf::ForOp op) { // We only unroll loops with a step of 1 and a lower bound of 0. That's the // only type we generate. @@ -317,15 +148,6 @@ class OptimizeLoopsPass signalPassFailure(); return; } - - // Then pipeline the remaining loops. 
- mlir::RewritePatternSet patterns(&getContext()); - patterns.add, - PipelineLoad>(&getContext()); - if (mlir::failed( - mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { - signalPassFailure(); - } } }; diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/tests/optimize_loops.mlir b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/tests/optimize_loops.mlir index 8fe920c3abfcd4..345c7680139dab 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/tests/optimize_loops.mlir +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/tests/optimize_loops.mlir @@ -105,104 +105,3 @@ module { // CHECK-LABEL: @do_not_unroll // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: scf.for {{.*}} step %[[C1]] - -// ----- - -module { - func.func @pipeline_extract(%arg: tensor<31xf32>) -> f32 { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c31 = arith.constant 31 : index - %cst = arith.constant 0.0 : f32 - %ret = scf.for %i = %c0 to %c31 step %c1 iter_args (%iter = %cst) -> (f32) { - %val = tensor.extract %arg[%i] : tensor<31xf32> - %log = math.log %val : f32 - %add = arith.addf %log, %iter : f32 - scf.yield %add : f32 - } - return %ret : f32 - } -} - -// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 + 1), -// CHECK-LABEL: @pipeline_extract -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index -// CHECK: %[[VAL0:.*]] = tensor.extract %[[ARG0:.*]][%[[C0]]] -// CHECK: scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]]) -// CHECK-DAG: %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C30]] -// CHECK-DAG: %[[NEXT_I:.*]] = xla.apply_indexing #[[$MAP]](%[[I]] -// CHECK: %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]] -// CHECK-NEXT: tensor.extract %[[ARG0]][%[[NEXT_I]]] -// CHECK-NEXT: yield -// CHECK-NEXT: else -// CHECK-NEXT: yield %[[VAL]] -// CHECK: math.log %[[VAL]] -// CHECK: %[[ADD:.*]] = arith.addf -// CHECK: yield %[[ADD]], %[[NEXT_VAL]] - -// ----- - -module { - func.func @pipeline_transfer(%arg: tensor<34xf32>) -> vector<2xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c17 = arith.constant 17 : index - %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> - %cst0 = arith.constant 0.0 : f32 - %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla.apply_indexing #xla.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 15]">(%i) - %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> - %log = math.log %val : vector<2xf32> - %add = arith.addf %log, %iter : vector<2xf32> - scf.yield %add : vector<2xf32> - } - return %ret : vector<2xf32> - } -} - -// CHECK-DAG: #[[$MAP0:.*]] = #xla.indexing_map<"(d0) -> (d0 * 2), -// CHECK-DAG: #[[$MAP1:.*]] = #xla.indexing_map<"(d0) -> (d0 + 1), -// CHECK-LABEL: @pipeline_transfer -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index -// CHECK: %[[BASE0:.*]] = xla.apply_indexing #[[$MAP0]](%[[C0]] -// CHECK: %[[VAL0:.*]] = vector.transfer_read %[[ARG0:.*]][%[[BASE0]]] -// CHECK: scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]]) -// CHECK-DAG: %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C16]] -// CHECK-DAG: %[[NEXT_I:.*]] = xla.apply_indexing #[[$MAP1]](%[[I]] -// CHECK-DAG: %[[NEXT_BASE:.*]] = xla.apply_indexing #[[$MAP0]](%[[NEXT_I]] -// CHECK: %[[NEXT_VAL:.*]] = 
scf.if %[[NEXT_I_EXISTS]] -// CHECK-NEXT: vector.transfer_read %[[ARG0]][%[[NEXT_BASE]]] -// CHECK-NEXT: yield -// CHECK-NEXT: else -// CHECK-NEXT: yield %[[VAL]] -// CHECK: math.log %[[VAL]] -// CHECK: %[[ADD:.*]] = arith.addf -// CHECK: yield %[[ADD]], %[[NEXT_VAL]] - -// ----- - -module { - func.func @sequential_extract(%arg0: tensor<6xindex>, %arg1: tensor<22xindex>) -> (index) { - %c1 = arith.constant 1 : index - %c733 = arith.constant 733 : index - %c0 = arith.constant 0 : index - %2 = scf.for %i = %c0 to %c733 step %c1 iter_args(%x = %c1) -> (index) { - %extracted = tensor.extract %arg0[%i] : tensor<6xindex> - %extracted_1 = tensor.extract %arg1[%extracted] : tensor<22xindex> - scf.yield %extracted_1 : index - } - return %2 : index - } -} - -// Once `extracted` is pipelined, it becomes an iter arg, so `extracted_1` is -// extract %arg1[%arg]. While it is possible to pipeline this in principle, we -// do not currently do this. - -// CHECK-LABEL: @sequential_extract -// CHECK-SAME: (%[[ARG0:.*]]: tensor<6xindex>, %[[ARG1:.*]]: tensor<22xindex>) -// CHECK: tensor.extract %[[ARG0]] -// CHECK-NOT: tensor.extract -// CHECK: scf.for From 49a5502a4c460ee6e5d371541ebcc6f56df0e0fb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 00:46:08 -0700 Subject: [PATCH 0661/1324] Automated Code Change PiperOrigin-RevId: 747289914 --- third_party/xla/xla/codegen/emitter_loc_op_builder_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/codegen/emitter_loc_op_builder_test.cc b/third_party/xla/xla/codegen/emitter_loc_op_builder_test.cc index ee71c779b66457..e1a2da61f4a6f0 100644 --- a/third_party/xla/xla/codegen/emitter_loc_op_builder_test.cc +++ b/third_party/xla/xla/codegen/emitter_loc_op_builder_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "absl/strings/string_view.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinAttributes.h" From 54ba3dec0a81f5b2a7a6656c0255af2082fab2bd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 00:55:20 -0700 Subject: [PATCH 0662/1324] Automated Code Change PiperOrigin-RevId: 747292336 --- third_party/xla/xla/tools/matmul_perf_table_gen.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/tools/matmul_perf_table_gen.cc b/third_party/xla/xla/tools/matmul_perf_table_gen.cc index 588dd10b989f02..2ecd3b67468f31 100644 --- a/third_party/xla/xla/tools/matmul_perf_table_gen.cc +++ b/third_party/xla/xla/tools/matmul_perf_table_gen.cc @@ -309,12 +309,12 @@ int64_t GetFlops(const HloDotInstruction& dot) { const HloInstruction& rhs = *dot.operand(1); // Get non-contracting dims - for (int dim : GetNonContractingDims(lhs.shape().dimensions_size(), + for (int dim : GetNonContractingDims(lhs.shape().dimensions().size(), dot_dims.lhs_contracting_dimensions(), dot_dims.lhs_batch_dimensions())) { fmas *= dim_size(lhs, dim); } - for (int dim : GetNonContractingDims(rhs.shape().dimensions_size(), + for (int dim : GetNonContractingDims(rhs.shape().dimensions().size(), dot_dims.rhs_contracting_dimensions(), dot_dims.rhs_batch_dimensions())) { fmas *= dim_size(rhs, dim); From f75f7fed6bd3bf4ea9a6ddbe205e82b0ece147c4 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 00:59:31 -0700 Subject: [PATCH 0663/1324] Automated Code Change PiperOrigin-RevId: 747293330 --- third_party/xla/xla/service/BUILD | 5 +++++ third_party/xla/xla/service/conditional_to_select.h | 3 +++ third_party/xla/xla/service/conditional_to_select_test.cc | 2 +- third_party/xla/xla/service/constant_value.cc | 6 ++++++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 187376ecdfbd0e..2c7b067514974c 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -354,7 +354,9 @@ cc_library( "//xla:literal", "//xla:util", "@com_google_absl//absl/base", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", ], ) @@ -4883,6 +4885,9 @@ cc_library( "//xla:types", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", diff --git a/third_party/xla/xla/service/conditional_to_select.h b/third_party/xla/xla/service/conditional_to_select.h index 4e3676468fd8a6..8bba94a9329ff0 100644 --- a/third_party/xla/xla/service/conditional_to_select.h +++ b/third_party/xla/xla/service/conditional_to_select.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_CONDITIONAL_TO_SELECT_H_ #define XLA_SERVICE_CONDITIONAL_TO_SELECT_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/conditional_to_select_test.cc b/third_party/xla/xla/service/conditional_to_select_test.cc index f79de7206d3d9a..82a24f2cf2a7ef 100644 --- a/third_party/xla/xla/service/conditional_to_select_test.cc +++ b/third_party/xla/xla/service/conditional_to_select_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "xla/service/conditional_to_select.h" #include -#include +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/service/constant_value.cc b/third_party/xla/xla/service/constant_value.cc index 14de2501c5c0ef..0555a757b753fb 100644 --- a/third_party/xla/xla/service/constant_value.cc +++ b/third_party/xla/xla/service/constant_value.cc @@ -15,8 +15,14 @@ limitations under the License. #include "xla/service/constant_value.h" +#include #include +#include "absl/base/casts.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" + namespace xla { absl::StatusOr ConstantValue::FromLiteral( From d1b1aeb07de6ffe046d0a73094255b352215e6a8 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 14 Apr 2025 01:24:57 -0700 Subject: [PATCH 0664/1324] [XLA:GPU] Enforce tile size constraints also for single root fusions. It turns out that we use the tile sizes "as is" when computing launch configs, while padding dimension sizes to power of 2 in the triton fusion emitter. This means we would silently compute incorrect launch configs. While directly using the provided tile sizes is currently not a problem in end2end codegen, it is risky to assume that the passed tile sizes are either capturing the full dimension or are already powers of 2. 
In fact, we already had test cases that didn't follow this constraint, which
also showed that the wrong number of blocks was computed. Therefore we add
tile size constraints to the emitter-specific constraints, so that if tests
use a wrong tile size value it leads to an error.

PiperOrigin-RevId: 747300816
---
 .../triton/fusion_emitter_device_test.cc |  2 +-
 .../gpu/codegen/triton/fusion_test.cc |  6 +-
 third_party/xla/xla/service/gpu/model/BUILD |  1 +
 .../gpu/model/triton_emitter_constraints.cc | 39 +++++++------
 .../gpu/model/triton_emitter_constraints.h | 18 ++++--
 .../model/triton_emitter_constraints_test.cc | 57 +++++++++++++++++++
 6 files changed, 98 insertions(+), 25 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc
index ce59346d85ee1e..8a9f59e2a37dac 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc
@@ -1056,7 +1056,7 @@ ENTRY main {
       "fusion_backend_config":{
         "kind":"__triton",
         "block_level_fusion_config":{
-          "output_tiles":[{"sizes":["2","5","16"]}],
+          "output_tiles":[{"sizes":["2","8","16"]}],
           "num_warps":"4",
           "num_ctas":"1",
           "num_stages":"1"}}}
diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_test.cc
index f0ed8e26d4f5ab..a08ec05546022b 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_test.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_test.cc
@@ -56,7 +56,7 @@ ENTRY entry_computation {
     calls=triton_computation, backend_config={"fusion_backend_config":{
       "kind":"__triton",
-      "block_level_fusion_config":{"output_tiles":[{"sizes":["3","127"]}],
+      "block_level_fusion_config":{"output_tiles":[{"sizes":["4","127"]}],
       "num_warps":"4"}}}
 })"));
@@ -74,12 +74,12 @@ ENTRY entry_computation {
       triton_fusion->launch_config();
   ASSERT_NE(launch_config, std::nullopt);
   EXPECT_EQ(launch_config->launch_dimensions.num_blocks(),
-            /*ceil(125 / 3)=*/42);
+            /*ceil(125 / 4)=*/32);
   EXPECT_EQ(launch_config->launch_dimensions.num_threads_per_block(),
             /*32 * num_warps=*/128);
   EXPECT_EQ(launch_config->block_level_parameters.output_tile_sizes.size(), 1);
   EXPECT_THAT(launch_config->block_level_parameters.output_tile_sizes[0],
-              ElementsAre(3, 127));
+              ElementsAre(4, 127));
 }
 
 TEST_F(TritonFusionTest,
diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD
index f4897d2bbdfb6f..72c5f4a26ef42c 100644
--- a/third_party/xla/xla/service/gpu/model/BUILD
+++ b/third_party/xla/xla/service/gpu/model/BUILD
@@ -763,6 +763,7 @@ cc_library(
         ":symbolic_tile_analysis",
         ":symbolic_tiled_hlo_instruction",
         "//xla:shape_util",
+        "//xla:util",
         "//xla/hlo/analysis:indexing_analysis",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/utils:hlo_traversal",
diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc
index 259fa8daa05879..50c5726b535f92 100644
--- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc
+++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc
@@ -44,6 +44,7 @@ limitations under the License.
#include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/stream_executor/device_description.h" +#include "xla/util.h" namespace xla { namespace gpu { @@ -172,19 +173,18 @@ TritonEmitterConstraints::GetBuilder( instructions, const HloFusionAdaptor& fusion_adaptor) { llvm::DenseSet unique_tile_size_maps; - llvm::SmallVector size_maps; + llvm::SmallVector root_infos; auto roots = fusion_adaptor.GetRoots(); for (const auto& tiled_hlo_instruction : instructions) { unique_tile_size_maps.insert( tiled_hlo_instruction->symbolic_tile().size_map()); - // TODO(b/365727080): We should also enforce this for single-output - // fusions. - if (roots.size() > 1 && - absl::c_any_of(roots, [&tiled_hlo_instruction]( + if (absl::c_any_of(roots, [&tiled_hlo_instruction]( const HloInstructionAdaptor& instr) { return &instr.instruction() == tiled_hlo_instruction->hlo(); })) { - size_maps.push_back(tiled_hlo_instruction->symbolic_tile().size_map()); + root_infos.push_back(RootTileInfo{ + tiled_hlo_instruction->symbolic_tile().size_map(), + SpanToVector(tiled_hlo_instruction->hlo()->shape().dimensions())}); } } @@ -196,7 +196,7 @@ TritonEmitterConstraints::GetBuilder( return std::unique_ptr( absl::WrapUnique(new TritonEmitterConstraints( - std::move(tile_size_maps), std::move(size_maps), + std::move(tile_size_maps), std::move(root_infos), std::move(custom_constraints), /*root_shape=*/instructions.back()->hlo()->shape(), device_description))); @@ -248,17 +248,24 @@ absl::StatusOr TritonEmitterConstraints::ParametersSatisfyConstraints( return false; } } - for (const auto& size_map : size_maps_) { + for (const auto& root : roots_) { llvm::SmallVector transformed_tile_parameters = - EvaluateAffineMap(size_map, + EvaluateAffineMap(root.size_map, /*dim_values=*/tile_parameters); - // For multi-output fusions, we require that the propagated tile sizes for - // potential root tiles are powers of 2. - // TODO(b/365727080): Technically we should also enforce this for fusions - // with just one root. - if (GetPaddedTileSizes(transformed_tile_parameters) != - transformed_tile_parameters) { - return false; + // We require that the propagated tile sizes for potential root tiles are + // either powers of 2 or are equal to the dimension size. + // TODO(b/365727080): Technically the tile size should always be a power of + // 2, but currently if we capture a dimension fully, we use the dimension + // size as tile size. + for (auto [tile_size, dim_size] : + llvm::zip(transformed_tile_parameters, root.dim_sizes)) { + CHECK_GT(tile_size, 0); + // If the tile size is neither a power of 2, nor equal to dim size, it is + // invalid. Otherwise we would for example compute the launch config + // incorrectly. + if ((tile_size & (tile_size - 1)) && tile_size != dim_size) { + return false; + } } } diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h index 931ac97bf38010..26b6ca35f0a5e6 100644 --- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h @@ -54,13 +54,21 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints { ConstraintExpression constraints; }; + // Holds the info needed to validate whether the tiling parameters satisfy the + // constraint that they are either powers of 2, or equal to the dimension + // size. 
+ struct RootTileInfo { + mlir::AffineMap size_map; + std::vector dim_sizes; + }; + explicit TritonEmitterConstraints( llvm::SmallVector tile_size_maps, - llvm::SmallVector size_maps, + llvm::SmallVector roots, std::vector custom_constraints, const Shape& root_shape, const se::DeviceDescription& device_info) : tile_size_maps_(std::move(tile_size_maps)), - size_maps_(std::move(size_maps)), + roots_(std::move(roots)), custom_constraints_(std::move(custom_constraints)), root_shape_(root_shape), device_info_(device_info) {} @@ -93,9 +101,9 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints { // collection of unique maps to improve compilation time. llvm::SmallVector tile_size_maps_; - // Tile size maps that need to be checked whether they evaluate to powers of - // 2. We need this constraint for multi-output fusions. - llvm::SmallVector size_maps_; + // Holds the info for all fusion roots necessary to check whether the tile + // sizes evaluate to powers of 2 or have the same size as the dimension. + llvm::SmallVector roots_; // Custom emitter-specific constraints to check in // `ParametersSatisfyConstraints`. diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc index d99fcee7158962..1c428bfb921c52 100644 --- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc @@ -388,6 +388,63 @@ ENTRY main { IsOkAndHolds(false)); } +TEST_F(TritonEmitterConstraintsTest, FusionHasValidTileSizes) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +fused_computation { + param_0 = f32[36] parameter(0) + abs = f32[36] abs(param_0) + ROOT reshape = f32[6,6] reshape(abs) +} + +ENTRY entry_computation { + param_0 = f32[36] parameter(0) + ROOT fusion = f32[6,6] fusion(param_0), kind=kCustom, + calls=fused_computation, backend_config={"fusion_backend_config":{"kind":"__triton"}} +})")); + std::optional analysis_without_triton_constraints = + TryAnalyzeModule(module.get(), + /*with_triton_emitter_specific_constraints=*/false); + ASSERT_TRUE(analysis_without_triton_constraints.has_value()); + + // (1,3) is a theoretically valid tiling for this fusion, so + // SymbolicTileAnalysis should allow it. + EXPECT_THAT( + analysis_without_triton_constraints->ParametersSatisfyConstraints({1, 3}), + IsOkAndHolds(true)); + + std::optional analysis_with_triton_constraints = + TryAnalyzeModule(module.get(), + /*with_triton_emitter_specific_constraints=*/true); + + ASSERT_TRUE(analysis_with_triton_constraints.has_value()); + + // (1,3) is a theoretically valid tiling for this fusion, but it does not pass + // the triton specific condition that all tile sizes are either powers of 2, + // or equal to the dimension size. + EXPECT_THAT( + analysis_with_triton_constraints->ParametersSatisfyConstraints({1, 3}), + IsOkAndHolds(false)); + + // However if we capture the last dimension fully, it should be valid. + EXPECT_THAT( + analysis_with_triton_constraints->ParametersSatisfyConstraints({1, 6}), + IsOkAndHolds(true)); + + // Also powers of 2 are valid. 
+  EXPECT_THAT(
+      analysis_with_triton_constraints->ParametersSatisfyConstraints({2, 1}),
+      IsOkAndHolds(true));
+  EXPECT_THAT(
+      analysis_with_triton_constraints->ParametersSatisfyConstraints({1, 8}),
+      IsOkAndHolds(true));
+  EXPECT_THAT(
+      analysis_with_triton_constraints->ParametersSatisfyConstraints({1, 4}),
+      IsOkAndHolds(true));
+}
+
 TEST_F(TritonEmitterConstraintsTest, MultiOutputFusionHasPowerOfTwoTileSizes) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module,
                           ParseAndReturnVerifiedModule(R"(

From 149e988b78b54a5055371541b38adb0e4a72e9c4 Mon Sep 17 00:00:00 2001
From: Goran Flegar
Date: Mon, 14 Apr 2025 01:39:27 -0700
Subject: [PATCH 0665/1324] Fix propagating the contracting split autotuning
 flag

If we want to autotune the split, we should not be forcing a specific split,
but rather leaving it unset (std::nullopt). We had a bug and were doing the
opposite.

PiperOrigin-RevId: 747305144
---
 .../gpu/autotuning/dot_search_space.cc |  1 +
 .../gpu/autotuning/gemm_fusion_autotuner.cc |  4 +++-
 .../autotuning/gemm_fusion_autotuner_test.cc | 19 +++++++++++++++++++
 3 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
index 43eeecdf7abcf9..f5295147feeeca 100644
--- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc
@@ -122,6 +122,7 @@ std::vector TritonDotFusionSearchSpace::GenerateConfigs(
     // discarding all configs, and use the smallest possible tile size further
     // down, which is likely not what the user had in mind.
     config.keep_large_split = GetMaxContractingSplit(max_out_tile_) < split;
+    VLOG(5) << "Forcing split_k, config = " << config.ToString();
     if (config.keep_large_split) {
       LOG(WARNING)
           << "split_k is larger than what we would have found automatically. "
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
index ea485aa0318d80..19fde26bcb1012 100644
--- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -901,7 +901,9 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
   // We don't need to consider small_dot here. The new search space will
   // already generate a unique config for small problems.
   return search_space.GenerateConfigs(
-      autotune_contracting_split ? std::make_optional(1) : std::nullopt);
+      /*force_contracting_split=*/autotune_contracting_split
+          ? std::nullopt
+          : std::make_optional(1));
 }
 
 // Retrieve the minimum bit-width participating in the dot. This is needed
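The one-line fix above is easy to misread, so here is the intended semantics
distilled into a sketch (the helper name is ours, not from the patch):

#include <cstdint>
#include <optional>

// When the contracting split is being autotuned, no split may be forced and
// the search space decides; otherwise the split is pinned to split_k = 1.
std::optional<int64_t> ForceContractingSplit(bool autotune_contracting_split) {
  // The previous, buggy version had the two branches swapped
  // (autotune ? make_optional(1) : nullopt), which forced split_k = 1
  // precisely when autotuning was requested.
  if (autotune_contracting_split) return std::nullopt;
  return std::make_optional<int64_t>(1);
}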
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
index e57cb1ce5fcc48..973400d13ca8d1 100644
--- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
@@ -1789,6 +1789,25 @@ ENTRY e {
   EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/5e-3, /*arel=*/5e-3}));
 }
 
+TEST_F(DynamicSearchSpaceAutotunerTest, UsesSplitKForSmallOuterDimensions) {
+  const std::string hlo = R"(
+HloModule module
+ENTRY e {
+  x = s8[32,16384] parameter(0)
+  c = f16[32,16384] convert(x)
+  y = f16[16384,32] parameter(1)
+  ROOT out = f16[32,32] dot(c, y),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  CheckTritonAutotuning(hlo, R"(
+// CHECK: ENTRY
+// CHECK: __triton_gemm
+// CHECK-NOT: "split_k":"1"
+// CHECK: ROOT
+)");
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla

From 5836a09ec1f0b1ee740ba74d031c8f19bf83d086 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 14 Apr 2025 02:02:41 -0700
Subject: [PATCH 0666/1324] Update GraphDef version to 2197.

PiperOrigin-RevId: 747312570
---
 tensorflow/core/public/version.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index e4433bb4bebf24..7e517bef019aa0 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -93,7 +93,7 @@ limitations under the License.
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 2196  // Updated: 2025/4/13
+#define TF_GRAPH_DEF_VERSION 2197  // Updated: 2025/4/14
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //

From c794afede24c01ff6f9a2c8a60d007eb9a9ed4e7 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 14 Apr 2025 02:02:53 -0700
Subject: [PATCH 0667/1324] compat: Update forward compatibility horizon to
 2025-04-14

PiperOrigin-RevId: 747312662
---
 tensorflow/python/compat/compat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 118d7e775c7081..61c797a99bedbc 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -29,7 +29,7 @@
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 13)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 14)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None

From ef2ec38f46ed9b42994dff05cf54a611a8e5ac3a Mon Sep 17 00:00:00 2001
From: Christian Sigg
Date: Mon, 14 Apr 2025 02:09:09 -0700
Subject: [PATCH 0668/1324] Enable generic dot emitter for different result and
 input types.
PiperOrigin-RevId: 747315064 --- .../fusion_emitter_device_legacy_port_test.cc | 7 +- .../triton/fusion_emitter_device_test.cc | 75 ++++++++++--------- .../backends/gpu/codegen/triton/support.cc | 57 ++++++++++---- .../xla/backends/gpu/codegen/triton/support.h | 5 +- .../gpu/codegen/triton/support_test.cc | 56 +++++++++----- .../xla/stream_executor/device_description.h | 2 +- 6 files changed, 127 insertions(+), 75 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index 76a91d00104480..d652df9e2332f8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -3327,10 +3327,7 @@ CHECK: inputPrecision = tf32 ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -// TODO(b/393299275): this test requires us to allow actual mixed type GEMMs -// in the lowering. We need to expand support tests and the lowering to model -// mixed types as needed. (f8e4m3fn x f8e4m3fn -> f32) -TEST_F(TritonTest, DISABLED_Fp8LoweringIsSupportedPostHopper) { +TEST_F(TritonTest, Fp8LoweringIsSupportedPostHopper) { if (!GetCudaComputeCapability().IsAtLeastHopper()) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; } @@ -3365,7 +3362,7 @@ ENTRY main { CHECK: tt.dot {{.*}}{maxNumImpreciseAcc = 2147483647 : i32} : tensor<128x64xf8E4M3FN> * tensor<64x32xf8E4M3FN> -> tensor<128x32xf32> )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module_and_metadata.module), ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index 8a9f59e2a37dac..452f6583e0b17e 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -27,6 +28,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" @@ -1967,12 +1969,10 @@ ENTRY entry_computation { TEST_F(TritonEmitterTest, ConvertF16ToF8E5M2Exhaustive) { // TODO(b/396595945): enable post-Ampere once Triton respects RTNE semantics // on H100. 
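// (Reviewer annotation, not part of the patch: RTNE is round-to-nearest,
// ties-to-even. An f16 value exactly halfway between two adjacent f8e5m2
// values must round to the neighbour with an even mantissa; the exhaustive
// conversion test below is skipped on post-Ampere GPUs because, per the TODO
// above, Triton's conversion does not always follow that rule there.)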
- if (auto cc = - std::get_if(&GpuComputeCapability())) { - if (cc->IsAtLeastHopper()) { - GTEST_SKIP() << "Skipping tests above Ampere, Triton's conversion isn't " - "always correct"; - } + if (auto cc = std::get_if(&GpuComputeCapability()); + cc && cc->IsAtLeastHopper()) { + GTEST_SKIP() << "Skipping tests above Ampere, Triton's conversion isn't " + "always correct"; } constexpr absl::string_view kHloTextTemplate = R"( @@ -2011,11 +2011,9 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, FP8ToFP8EndToEnd) { - if (auto cc = - std::get_if(&GpuComputeCapability())) { - if (!cc->IsAtLeastHopper()) { - GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; - } + if (auto cc = std::get_if(&GpuComputeCapability()); + cc && !cc->IsAtLeastHopper()) { + GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; } const std::string hlo_text = R"( @@ -2669,7 +2667,7 @@ ErrorSpec ErrorSpecForDotAlgorithm(PrecisionConfig::Algorithm algorithm) { case PrecisionConfig::ALG_UNSET: // Give a loose tolerance to ALG_UNSET, as the expected behaviour is // not deducible from the algorithm name alone. - return ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-3}; + return ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}; case PrecisionConfig::ALG_DOT_F16_F16_F16: // Computed to make the tests pass (and it seems reasonable on the face of // it), and not derived from first principles. @@ -2868,52 +2866,61 @@ INSTANTIATE_TEST_SUITE_P(TF32DotAlgorithmEmitterTestSuite, class DotUnsetAlgorithmEmitterTest : public TritonEmitterTest, - public ::testing::WithParamInterface {}; + public ::testing::WithParamInterface< + std::tuple> { + public: + static std::string ParamToString( + const ::testing::TestParamInfo& + data) { + auto [result_type, input_type] = data.param; + return absl::StrCat(primitive_util::LowercasePrimitiveTypeName(result_type), + "_", + primitive_util::LowercasePrimitiveTypeName(input_type)); + }; +}; TEST_P(DotUnsetAlgorithmEmitterTest, UnsetAlgorithmIsEmittedCorrectly) { - // This currently assumes that the dot output type is the same as the input - // type. This is not enforced by the verifier/HLO spec, but is currently true - // for our emitters, and is enforced by `support_test.cc`. This test may - // require upgrading if we ever consider emitting code for truly mixed type - // `dot`s. - PrimitiveType ty = GetParam(); - if (!internal::IsResultTypeSupportedByAlgUnsetDot(ty, - GpuComputeCapability())) { - GTEST_SKIP() << primitive_util::LowercasePrimitiveTypeName(ty) - << " is not supported on this platform."; + auto [input_type, result_type] = GetParam(); + if (!internal::AreTypesSupportedByAlgUnsetDot(input_type, result_type, + GpuComputeCapability())) { + GTEST_SKIP() << "Not supported on this platform."; } ErrorSpec error_spec = ErrorSpecForDotAlgorithm(PrecisionConfig::ALG_UNSET); // For 8-bit floating point types, we need to allow large errors. 
- if (primitive_util::IsFloatingPointType(ty) && - primitive_util::BitWidth(ty) == 8) { + if (primitive_util::IsFloatingPointType(result_type) && + primitive_util::BitWidth(result_type) == 8) { error_spec = ErrorSpec{/*aabs=*/1e0, /*arel=*/1e-1}; } const std::string kHloText = - GetDotAlgorithmHlo(ty, ty, PrecisionConfig::ALG_UNSET); + GetDotAlgorithmHlo(input_type, result_type, PrecisionConfig::ALG_UNSET); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, error_spec)); } -std::vector AllXlaDataTypesSupportedByAlgUnsetDotLowering() { +auto AllXlaDataTypesSupportedByAlgUnsetDotLowering() { // We don't have a pointer to stream executor available here so we can't // detect the particular device we're running on with a canonical API call. // Instead, we just return a superset of the supported types (i.e. those that // are supported on the latest device), and filter out the unsupported types // in the test body. - std::vector supported_types; - absl::c_copy_if(AllXlaDataTypes(), std::back_inserter(supported_types), - [](PrimitiveType type) { - return internal::IsResultTypeSupportedByAlgUnsetDot( - type, se::CudaComputeCapability::Blackwell()); - }); + std::vector supported_types; + ::testing::internal::ParamGenerator + all_types = ::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()), + ::testing::ValuesIn(AllXlaDataTypes())); + absl::c_copy_if( + all_types, std::back_inserter(supported_types), [](const auto& types) { + auto [result_type, input_type] = types; + return static_cast(internal::AreTypesSupportedByAlgUnsetDot( + input_type, result_type, se::CudaComputeCapability::Blackwell())); + }); return supported_types; } INSTANTIATE_TEST_SUITE_P( DotUnsetAlgorithmEmitterTestSuite, DotUnsetAlgorithmEmitterTest, ::testing::ValuesIn(AllXlaDataTypesSupportedByAlgUnsetDotLowering()), - TypeTestParamToString); + DotUnsetAlgorithmEmitterTest::ParamToString); } // namespace } // namespace gpu diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index fff08e4d0da7b7..85f2ea80c64cc5 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -389,9 +389,9 @@ CodegenDecision IsTritonSupportedDot( } // TODO(b/393299275): add support tests for mixed types. 
- if (result_type != lhs_type || result_type != rhs_type) { + if (lhs_type != rhs_type) { return CodegenDecision::Forbid( - "Dot operation only supports same types for the result, lhs and rhs."); + "Dot operation only supports same types for lhs and rhs."); } absl::Status status = CheckSupportedCheckDotDimensions(dot); @@ -408,10 +408,12 @@ CodegenDecision IsTritonSupportedDot( PrecisionConfig::Algorithm_Name(algorithm))); } - if (algorithm == PrecisionConfig::ALG_UNSET && - !internal::IsResultTypeSupportedByAlgUnsetDot(result_type, gpu_version)) { - return CodegenDecision::Forbid( - "Unsupported result type for dot algorithm ALG_UNSET."); + if (algorithm == PrecisionConfig::ALG_UNSET) { + if (CodegenDecision decision = internal::AreTypesSupportedByAlgUnsetDot( + lhs_type, result_type, gpu_version); + !decision) { + return decision; + } } if (CodegenDecision conversion_decision = @@ -617,17 +619,46 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { } } -bool IsResultTypeSupportedByAlgUnsetDot( - PrimitiveType result_type, const se::GpuComputeCapability& gpu_version) { - std::vector supported_types = {BF16, F16, F32, F64, F8E5M2}; +CodegenDecision AreTypesSupportedByAlgUnsetDot( + PrimitiveType input_type, PrimitiveType result_type, + const se::GpuComputeCapability& gpu_version) { + if (input_type == F64 && result_type != F64) { + return CodegenDecision::Forbid( + "Dot operation only supports F64 result type for F64 input type."); + } - if (auto* cuda_cc = std::get_if(&gpu_version)) { - if (cuda_cc->IsAtLeastHopper()) { - supported_types.push_back(F8E4M3FN); + if (input_type == F8E4M3FN || result_type == F8E4M3FN) { + if (auto* cuda_cc = std::get_if(&gpu_version); + cuda_cc && !cuda_cc->IsAtLeastHopper()) { + return CodegenDecision::Forbid( + "Dot operation for F8E4M3FN is not supported before Hopper."); } } - return absl::c_linear_search(supported_types, result_type); + auto supported_float_types = {BF16, F16, F32, F64, F8E5M2, F8E4M3FN}; + if (absl::c_linear_search(supported_float_types, input_type)) { + return CodegenDecision::Allow(); + } + + if (input_type == S8 && result_type == S32) { + return CodegenDecision::Allow(); + } + + auto partially_supported_signed_types = {S8, S16, S32, S64}; + if (absl::c_linear_search(partially_supported_signed_types, input_type)) { + if (absl::c_linear_search(partially_supported_signed_types, result_type)) { + return CodegenDecision::Forbid( + "Dot operation does not support these signed integer types."); + } + if (primitive_util::IsFloatingPointType(result_type)) { + return CodegenDecision::Forbid( + "Dot operation does not support floating point input and signed " + "integer result types."); + } + return CodegenDecision::Allow(); + } + + return CodegenDecision::Forbid("Unsupported types."); } } // namespace internal diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.h b/third_party/xla/xla/backends/gpu/codegen/triton/support.h index 47f2f02c6f1ead..37301424eb06c8 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.h @@ -73,8 +73,9 @@ namespace internal { bool IsTritonUnsupportedOpcode(HloOpcode opcode); // This is exposed for testing purposes only. Do not use. 
-bool IsResultTypeSupportedByAlgUnsetDot( - PrimitiveType result_type, const se::GpuComputeCapability& gpu_version); +CodegenDecision AreTypesSupportedByAlgUnsetDot( + PrimitiveType input_type, PrimitiveType result_type, + const se::GpuComputeCapability& gpu_version); } // namespace internal diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index 949986d2b4f899..f117ba9f0b01e6 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -1737,18 +1737,35 @@ INSTANTIATE_TEST_SUITE_P(ReverseSuite, ReverseTest, using DotTest = TritonSupportTest; -class DotTypesTest : public DotTest, - public ::testing::WithParamInterface< - std::tuple> { +class DotTypesTest + : public DotTest, + public ::testing::WithParamInterface< + std::tuple> { + public: + static std::string ParamToString( + const ::testing::TestParamInfo& data) { + auto [result_type, input_type, cc] = data.param; + return absl::StrCat(primitive_util::LowercasePrimitiveTypeName(result_type), + "_", + primitive_util::LowercasePrimitiveTypeName(input_type), + "_", ComputeCapabilityToString(cc)); + }; }; TEST_P(DotTypesTest, Dot) { - // Testing A[] = dot(A[], A[]). - // TODO(b/393299275): Add tests for cases where LHS, RHS, and result have - // different types. Using infra of parameterized test will not work as the - // number of combinations is too large. - auto [type, cc] = GetParam(); - const std::string hlo_text = R"( + // Testing B[] = dot(A[], A[]). + auto [result_type, input_type, cc] = GetParam(); + + ExpectedFailMode fail_mode = ExpectedFailMode::kFail; + if (input_type == F8E4M3FN || result_type == F8E4M3FN) { + if (auto* cuda_cc = std::get_if(&cc); + cuda_cc && !cuda_cc->IsAtLeastHopper()) { + // Hits llvm::report_fatal_error during Triton compilation. 
+ fail_mode = ExpectedFailMode::kFailOrCrash; + } + } + + std::string hlo_text = R"( flhs { ROOT result = $0[128,256] parameter(0) } @@ -1774,29 +1791,28 @@ ENTRY triton_computation { } } } - ROOT result = $0[128,512]{1,0} dot(lhs, rhs), + ROOT result = $1[128,512]{1,0} dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; + hlo_text = absl::Substitute( + hlo_text, primitive_util::LowercasePrimitiveTypeName(input_type), + primitive_util::LowercasePrimitiveTypeName(result_type)); - ExpectedFailMode fail_mode = ExpectedFailMode::kFail; - if (absl::c_linear_search(std::vector{F8E5M2, F8E4M3FN, S8}, type)) { - fail_mode = ExpectedFailMode::kFailOrCrash; - } - - TF_ASSERT_OK_AND_ASSIGN( - TestedInstruction ti, - ParseTemplateAndGetInstruction(hlo_text, type, HloOpcode::kDot, - /* use_nested_gemm_fusions=*/true)); + TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, + ParseTemplateAndGetInstruction( + hlo_text, PRIMITIVE_TYPE_INVALID, HloOpcode::kDot, + /*use_nested_gemm_fusions=*/true)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc, fail_mode); } INSTANTIATE_TEST_SUITE_P( DotTestSuite, DotTypesTest, ::testing::Combine( + ::testing::ValuesIn(AllOpSupportedTypes(HloOpcode::kDot)), ::testing::ValuesIn(AllOpSupportedTypes(HloOpcode::kDot)), ::testing::ValuesIn(AllDevicesToTest())), - TritonSupportTestTypeAndDeviceToString); + DotTypesTest::ParamToString); TEST_F(DotTest, NonFusionRhs) { const std::string kHloTestTemplate = R"( diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index 1c5365be2b9f52..dadcdf1e99d727 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -284,7 +284,7 @@ class DeviceDescription { // Returns the CUDA compute capability if we're running on the CUDA platform. // If a CUDA compute capability is not available, the major version will be - // zero. + // negative. CudaComputeCapability cuda_compute_capability() const; // Returns the ROCm compute capability if we're running on the ROCm platform. From 95380d56d4fae73a76f24c0c5514a1dcfaca28e5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 02:14:30 -0700 Subject: [PATCH 0669/1324] [XLA:GPU]: Make memory compute parallelism inline. The constant never changes value so we can change it to a constexpr. 
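
The combination formula itself is unchanged; only the home of the 0.95 factor
moves. A reduced sketch of the pattern (illustrative names, not the actual
XLA classes):

    #include <algorithm>

    // Before (sketch): the factor travels in an options struct that every
    // caller has to thread through.
    struct Options {
      double memory_compute_parallelism = 0.95;
    };
    double CombineWithOptions(double compute, double memory,
                              const Options& opts) {
      return compute + memory -
             std::min(compute, memory) * opts.memory_compute_parallelism;
    }

    // After (sketch): the never-changing factor is a class-scope constexpr,
    // so the options parameter disappears and the value can be folded at
    // compile time.
    struct Model {
      static constexpr double kMemoryComputeParallelism = 0.95;
      static double Combine(double compute, double memory) {
        return compute + memory -
               std::min(compute, memory) * kMemoryComputeParallelism;
      }
    };

For instance, compute = 10us and memory = 4us combine to
10 + 4 - 0.95 * 4 = 10.2us, just above the 10us that perfect overlap
(factor 1.0) would give.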
PiperOrigin-RevId: 747317066 --- .../model/gpu_indexing_performance_model.cc | 8 +++--- .../model/gpu_indexing_performance_model.h | 2 -- .../gpu/model/gpu_performance_model.cc | 8 +++--- .../gpu/model/gpu_performance_model_base.cc | 6 ++--- .../gpu/model/gpu_performance_model_base.h | 26 +++++++++---------- 5 files changed, 22 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index ce41097da44677..c124961eca2744 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -335,8 +335,8 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForFusion( launch_dimensions.num_threads_per_block()); absl::Duration write_time = WriteTime(*device_info_, bytes_written); absl::Duration memory_access_time = read_time + write_time; - absl::Duration exec_time = CombineComputeAndMemoryAccessTime( - compute_time, memory_access_time, GpuPerformanceModelOptions::Default()); + absl::Duration exec_time = + CombineComputeAndMemoryAccessTime(compute_time, memory_access_time); EstimateRunTimeData runtime_data = {flops, bytes_read, bytes_written, read_time, write_time, compute_time, @@ -516,8 +516,8 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( launch_dimensions.num_threads_per_block()); absl::Duration memory_access_time = read_time + write_time; - absl::Duration exec_time = CombineComputeAndMemoryAccessTime( - compute_time, memory_access_time, GpuPerformanceModelOptions::Default()); + absl::Duration exec_time = + CombineComputeAndMemoryAccessTime(compute_time, memory_access_time); return EstimateRunTimeData{/*flops=*/flops, /*bytes_read=*/bytes_read, diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h index 577ab9bf7f6cda..ab41fdf82bee28 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ #define XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ -#include #include #include #include @@ -32,7 +31,6 @@ limitations under the License. 
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/hlo_op_profiles.h" -#include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/instruction_fusion.h" diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index ab09a82537e9b3..9a04d9d07eb51b 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -93,8 +93,8 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( } absl::Duration write_time = WriteTime(device_info, bytes_written); - absl::Duration exec_time = CombineComputeAndMemoryAccessTime( - compute_time, read_time + write_time, config); + absl::Duration exec_time = + CombineComputeAndMemoryAccessTime(compute_time, read_time + write_time); EstimateRunTimeData runtime_data = {flops, bytes_read, bytes_written, read_time, write_time, compute_time, @@ -207,8 +207,8 @@ GpuPerformanceModel::EstimateRunTimeForInstructionCached( write_time += producer_runtime.write_time; } - auto exec_time = CombineComputeAndMemoryAccessTime( - compute_time, read_time + write_time, config); + auto exec_time = + CombineComputeAndMemoryAccessTime(compute_time, read_time + write_time); VLOG(3) << "Runtime data for producer-consumer fusion:\n" << " producer: " << producer->name() << "\n" diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc index c15cc173db5014..87a9554112b464 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc @@ -351,11 +351,9 @@ absl::Duration GpuPerformanceModelBase::ComputeTime( /*static*/ absl::Duration GpuPerformanceModelBase::CombineComputeAndMemoryAccessTime( - absl::Duration compute_time, absl::Duration memory_access_time, - const GpuPerformanceModelOptions& config) { + absl::Duration compute_time, absl::Duration memory_access_time) { return compute_time + memory_access_time - - std::min(compute_time, memory_access_time) * - config.memory_compute_parallelism; + std::min(compute_time, memory_access_time) * kMemoryComputeParallelism; } /*static*/ diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h index 0ac09b5dcf2bd7..64963c286daf7b 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h @@ -127,18 +127,6 @@ class GpuPerformanceModelCache { }; struct GpuPerformanceModelOptions { - // Factor for how much parallelism between compute and memory accesses should - // be assumed. If 1.0, assume perfect parallelism (the run time is the maximum - // of both times). If 0.0, assume no parallelism (the run time is the sum of - // both times). - // - // This constant was chosen empirically in early 2024, based on runtime - // performance on a set of benchmarks internal to Google. Intuitively, we - // expect it to be close to 1, but not quite 1 (i.e., sometimes, compute - // or memory accesses will be stalled waiting for the other, but usually - // they won't). 
-  double memory_compute_parallelism = 0.95;
-
   // If present, use this to retrieve fusion analyses.
   HloFusionAnalysisCache* fusion_analysis_cache = nullptr;
 
@@ -167,6 +155,17 @@ class GpuPerformanceModelBase {
       absl::Microseconds(5);
   static constexpr float kL2CacheSpeedup = 2.5;
   static constexpr float kL1CacheSpeedup = 8;
+  // Factor for how much parallelism between compute and memory accesses should
+  // be assumed. If 1.0, assume perfect parallelism (the run time is the maximum
+  // of both times). If 0.0, assume no parallelism (the run time is the sum of
+  // both times).
+  //
+  // This constant was chosen empirically in early 2024, based on runtime
+  // performance on a set of benchmarks internal to Google. Intuitively, we
+  // expect it to be close to 1, but not quite 1 (i.e., sometimes, compute
+  // or memory accesses will be stalled waiting for the other, but usually
+  // they won't).
+  static constexpr double kMemoryComputeParallelism = 0.95;
 
   // Uses HloFusionAnalysis for computing the actual number of threads and
   // blocks that the IR emitter will use.
@@ -225,8 +224,7 @@ class GpuPerformanceModelBase {
                                     int64_t num_blocks,
                                     int64_t num_threads_per_block);
 
   static absl::Duration CombineComputeAndMemoryAccessTime(
-      absl::Duration compute_time, absl::Duration memory_access_time,
-      const GpuPerformanceModelOptions& config);
+      absl::Duration compute_time, absl::Duration memory_access_time);
 
   // Logs estimates for the operand read if VLOG is enabled.
   static void VLogOperandRead(const HloInstruction* operand,

From 0c1964c4285f6e90a4e644426bcf34a8640017a7 Mon Sep 17 00:00:00 2001
From: Thomas Joerg
Date: Mon, 14 Apr 2025 02:59:05 -0700
Subject: [PATCH 0670/1324] [XLA:GPU] Relax type constraints in the autotuning test `ApplySplitKWithoutAlteringTiling`.

The idea of the test is to ensure that tilings are not altered by
auto-tuning. Choosing a different accumulator type should be allowed and does
not affect tilings.

PiperOrigin-RevId: 747330366
---
 .../xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
index 973400d13ca8d1..e2daf1825f7a56 100644
--- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
@@ -614,7 +614,7 @@ ENTRY e {
 })";
 
   MatchOptimizedHlo(kHloText, R"(
-; CHECK: f16[3,55,20]
+; CHECK: f{{(16|32)}}[3,55,20]
 ; CHECK: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}
 ; CHECK: f16[55,20]{1,0} {{(reduce|fusion)}}
 )");

From 8529d146dcc23ab2d5b82a8aa68daf88fc07a505 Mon Sep 17 00:00:00 2001
From: Akhil Goel
Date: Mon, 14 Apr 2025 03:01:29 -0700
Subject: [PATCH 0671/1324] PR #18688: [XLA:CPU][oneDNN] Pad MemrefInfo pod

Imported from GitHub PR https://github.com/openxla/xla/pull/18688

This PR pads the MemrefInfoPOD emitted on stack so that it aligns with the
system cacheline.
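
A reduced sketch of the resulting layout (field order follows the
MemrefInfoPOD diff below; the sizes assume a 64-bit platform, and
kOneDnnMaxNDims == 12 is an assumption matching the [12 x i64] arrays checked
in the new test):

    #include <cstddef>
    #include <cstdint>

    constexpr int kOneDnnMaxNDims = 12;  // assumed; see the test pattern

    struct MemrefInfoPodSketch {
      int64_t dtype;                     // offset 0
      int64_t rank;                      // offset 8
      void* data;                        // offset 16 (8-byte pointer assumed)
      int64_t unused;                    // offset 24: the explicit pad
      int64_t dims[kOneDnnMaxNDims];     // offset 32
      int64_t strides[kOneDnnMaxNDims];  // offset 128
    };

    // With the pad, the strides array starts exactly on a 64-byte cacheline
    // boundary (offset 128) instead of straddling one.
    static_assert(offsetof(MemrefInfoPodSketch, strides) % 64 == 0,
                  "strides should start on a cacheline boundary");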
Copybara import of the project: -- 18a8fc8d59365c1382f1bc45a310f91575d27d32 by Akhil Goel : Pad Memref pod on stack -- 8ee40fc4da18178b14adca48bca2e631bfab38fd by Akhil Goel : Address review comment -- 31a04cea88c0177910380decb3a755e0e090f0dd by Akhil Goel : Address review comments -- 01123fd333306be98d8974819854c2dbc7776800 by Akhil Goel : Add benchmark -- c8150221d6a2d877f30bd2d86bb183f131bd2959 by Akhil Goel : Refine comment -- 8efe08757e2b49bfd2fb13f7283df771fd5c573c by Akhil Goel : Address review comments -- 5eae0678c64927c4acac4d5310223e26e509dca5 by Akhil Goel : Fix formatting style Merging this change closes #18688 PiperOrigin-RevId: 747331035 --- .../xla/xla/backends/cpu/benchmarks/BUILD | 23 ++++ .../cpu/benchmarks/hlo_benchmark_runner.cc | 12 ++ .../cpu/benchmarks/hlo_benchmark_runner.h | 1 + .../onednn_matmul_benchmark_test.cc | 98 ++++++++++++++++ .../xla/xla/service/cpu/onednn_memory_util.cc | 19 +-- third_party/xla/xla/service/cpu/tests/BUILD | 19 +++ .../cpu/tests/onednn_memory_util_test.cc | 110 ++++++++++++++++++ 7 files changed, 275 insertions(+), 7 deletions(-) create mode 100644 third_party/xla/xla/backends/cpu/benchmarks/onednn_matmul_benchmark_test.cc create mode 100644 third_party/xla/xla/service/cpu/tests/onednn_memory_util_test.cc diff --git a/third_party/xla/xla/backends/cpu/benchmarks/BUILD b/third_party/xla/xla/backends/cpu/benchmarks/BUILD index 97fb8de2cc810e..7ecdffa3c312b6 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/BUILD +++ b/third_party/xla/xla/backends/cpu/benchmarks/BUILD @@ -1,4 +1,5 @@ load("//xla:xla.default.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "tsl_copts") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -428,6 +429,28 @@ xla_cc_test( ], ) +xla_cc_test( + name = "onednn_matmul_benchmark_test", + srcs = ["onednn_matmul_benchmark_test.cc"], + copts = tsl_copts(), + fail_if_no_test_linked = False, # NOLINT=This contains benchmarks only, no tests. 
+ deps = [ + ":hlo_benchmark_runner", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/service/cpu:onednn_util", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + ], +) + xla_cc_test( name = "tanh_benchmark_test", srcs = ["tanh_benchmark_test.cc"], diff --git a/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc b/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc index 0cd50af444f5e9..ff9dc06a116445 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc +++ b/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc @@ -147,6 +147,12 @@ absl::Status RunHloBenchmark(benchmark::State& state, compile_options.executable_build_options.mutable_debug_options() ->add_xla_disable_hlo_passes("cpu-parallel-task-assigner"); } + // TODO(intel-tf): Remove this if-block once oneDNN custom calls are enabled + // with thunk runtime + if (!benchmark_options.use_thunk_runtime) { + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_cpu_use_thunk_runtime(false); + } std::unique_ptr executable; if (benchmark_options.aot_options) { auto* cpu_client = tsl::down_cast(client.get()); @@ -297,6 +303,12 @@ absl::Status CompileHloBenchmark(benchmark::State& state, compile_options.executable_build_options.mutable_debug_options() ->add_xla_disable_hlo_passes("cpu-parallel-task-assigner"); } + // TODO(intel-tf): Remove this if-block once oneDNN custom calls are enabled + // with thunk runtime + if (!benchmark_options.use_thunk_runtime) { + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_cpu_use_thunk_runtime(false); + } for (auto _ : state) { TF_ASSIGN_OR_RETURN(std::unique_ptr executable, diff --git a/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.h b/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.h index b11aef8fd9dcb5..3c5daf29707eca 100644 --- a/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.h +++ b/third_party/xla/xla/backends/cpu/benchmarks/hlo_benchmark_runner.h @@ -37,6 +37,7 @@ using StrToStrMapping = struct HloBenchmarkOptions { int32_t num_executions = 1; bool disable_parallel_task_assigner = false; + bool use_thunk_runtime = true; // If not null, AOT compilation will be used. std::unique_ptr aot_options; }; diff --git a/third_party/xla/xla/backends/cpu/benchmarks/onednn_matmul_benchmark_test.cc b/third_party/xla/xla/backends/cpu/benchmarks/onednn_matmul_benchmark_test.cc new file mode 100644 index 00000000000000..2338d3a74021d1 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/benchmarks/onednn_matmul_benchmark_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if defined(INTEL_MKL) + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/cpu/onednn_util.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { + +static void BM_oneDNN_MM(benchmark::State& state) { + PrimitiveType dtype = static_cast(state.range(0)); + int64_t d0 = state.range(1); + + absl::string_view hlo = R"( + HloModule oneDNN_$dtype_$d0 + + ENTRY e { + lhs = $dtype[128,$d0] parameter(0) + rhs = $dtype[$d0,256] parameter(1) + ROOT custom-call = $dtype[128,256] custom-call(lhs, rhs), + custom_call_target="__onednn$matmul" + } + )"; + + std::minstd_rand0 engine; + + auto lhs_shape = ShapeUtil::MakeShape(dtype, {128, d0}); + auto rhs_shape = ShapeUtil::MakeShape(dtype, {d0, 256}); + Literal p0, p1; + + if (dtype == F32) { + p0 = *LiteralUtil::CreateRandomLiteral(lhs_shape, &engine, 1.0f, 0.1f); + p1 = *LiteralUtil::CreateRandomLiteral(rhs_shape, &engine, 1.0f, 0.1f); + } else if (dtype == BF16 && IsSupportedType(BF16)) { + p0 = + *LiteralUtil::CreateRandomLiteral(lhs_shape, &engine, 1.0f, 0.1f); + p1 = + *LiteralUtil::CreateRandomLiteral(rhs_shape, &engine, 1.0f, 0.1f); + } else if (dtype == F16 && IsSupportedType(F16)) { + p0 = *LiteralUtil::CreateRandomLiteral(lhs_shape, &engine, 1.0f, 0.1f); + p1 = *LiteralUtil::CreateRandomLiteral(rhs_shape, &engine, 1.0f, 0.1f); + } else { + VLOG(0) << primitive_util::LowercasePrimitiveTypeName(dtype) + << " not supported on this platform"; + return; + } + + std::vector args = {&p0, &p1}; + HloBenchmarkOptions benchmark_options; + benchmark_options.use_thunk_runtime = false; + CHECK_OK(RunHloBenchmark( + state, hlo, args, + {{"$dtype", primitive_util::LowercasePrimitiveTypeName(dtype)}, + {"$d0", absl::StrCat(d0)}}, + benchmark_options)); +} + +#define BENCHMARK_ONEDNN_MM(dtype) \ + BENCHMARK(BM_oneDNN_MM) \ + ->MeasureProcessCPUTime() \ + ->Args({dtype, 512}) \ + ->Args({dtype, 1024}) \ + ->Args({dtype, 2048}) + +BENCHMARK_ONEDNN_MM(F32); +BENCHMARK_ONEDNN_MM(BF16); +BENCHMARK_ONEDNN_MM(F16); + +} // namespace xla::cpu + +#endif // INTEL_MKL diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.cc b/third_party/xla/xla/service/cpu/onednn_memory_util.cc index 90295d65f2bad9..c81a8e1863747c 100644 --- a/third_party/xla/xla/service/cpu/onednn_memory_util.cc +++ b/third_party/xla/xla/service/cpu/onednn_memory_util.cc @@ -45,9 +45,11 @@ namespace cpu { struct MemrefInfoPOD { int64_t dtype; int64_t rank; + void* data; + int64_t unused; // This unused value pads the struct to align with a 64-byte + // cacheline int64_t dims[kOneDnnMaxNDims]; int64_t strides[kOneDnnMaxNDims]; - void* data; }; MemrefInfoHandler CreateMemrefFromShape(const Shape& shape, void* const buf) { @@ -109,7 +111,7 @@ StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilderBase& builder, llvm::ArrayType::get(builder.getInt64Ty(), kOneDnnMaxNDims); llvm::StructType* memref_info_type = llvm::StructType::get( builder.getContext(), - {i64_type, i64_type, i64_array_type, i64_array_type, ptr_type}); + {i64_type, i64_type, ptr_type, i64_type, i64_array_type, i64_array_type}); // Prepare array dims and strides. 
llvm::Value* dims_val = llvm::UndefValue::get(i64_array_type); @@ -121,16 +123,19 @@ StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilderBase& builder, strides_val = builder.CreateInsertValue(strides_val, stride_val, i); } - // Prepare values for struct MemrefInfo. + // Prepare values for struct MemrefInfo with padding to align to system + // cacheline llvm::Value* dtype_val = builder.getInt64(shape.element_type()); llvm::Value* rank_val = builder.getInt64(rank); + llvm::Value* pad_val = builder.getInt64(0xff); llvm::Value* data_ptr = ir_array.GetBasePointer(); llvm::Value* memref_info_val = llvm::UndefValue::get(memref_info_type); memref_info_val = builder.CreateInsertValue(memref_info_val, dtype_val, 0); memref_info_val = builder.CreateInsertValue(memref_info_val, rank_val, 1); - memref_info_val = builder.CreateInsertValue(memref_info_val, dims_val, 2); - memref_info_val = builder.CreateInsertValue(memref_info_val, strides_val, 3); - memref_info_val = builder.CreateInsertValue(memref_info_val, data_ptr, 4); + memref_info_val = builder.CreateInsertValue(memref_info_val, data_ptr, 2); + memref_info_val = builder.CreateInsertValue(memref_info_val, pad_val, 3); + memref_info_val = builder.CreateInsertValue(memref_info_val, dims_val, 4); + memref_info_val = builder.CreateInsertValue(memref_info_val, strides_val, 5); // Allocate MemrefInfo on the stack llvm::Value* memref_info_ptr = llvm_ir::EmitAllocaAtFunctionEntry( @@ -159,8 +164,8 @@ dnnl::memory::data_type MemrefInfo::GetOneDnnDataType() const { } dnnl::memory::desc MemrefInfo::GetOneDnnMemDesc() const { - auto dims = GetOneDnnDims(); auto dtype = GetOneDnnDataType(); + auto dims = GetOneDnnDims(); auto strides = GetOneDnnStrides(); return dnnl::memory::desc{dims, dtype, strides}; } diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD index 9567429330f98d..a3b1c7aac2a8f1 100644 --- a/third_party/xla/xla/service/cpu/tests/BUILD +++ b/third_party/xla/xla/service/cpu/tests/BUILD @@ -410,6 +410,25 @@ xla_cc_test( ], ) +xla_cc_test( + name = "onednn_memory_util_test", + srcs = ["onednn_memory_util_test.cc"], + copts = tsl_copts(), + fail_if_no_test_linked = False, # NOLINT=There are only tests for Intel MKL. + deps = [ + "//xla:shape_util", + "//xla/hlo/testlib:filecheck", + "//xla/service:cpu_plugin", + "//xla/service/cpu:onednn_memory_util", + "//xla/service/llvm_ir:llvm_util", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@local_tsl//tsl/platform:test", + ], +) + xla_cc_test( name = "onednn_layer_norm_test", srcs = ["onednn_layer_norm_test.cc"], diff --git a/third_party/xla/xla/service/cpu/tests/onednn_memory_util_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_memory_util_test.cc new file mode 100644 index 00000000000000..b622f47cc720e0 --- /dev/null +++ b/third_party/xla/xla/service/cpu/tests/onednn_memory_util_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if defined(INTEL_MKL) + +#include "xla/service/cpu/onednn_memory_util.h" + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +class MemoryUtilTest + : public ::testing::Test, + public ::testing::WithParamInterface> { + protected: + constexpr static const char* test_pattern_ = R"( + CHECK: %[[mref0:[0-9]+]] = insertvalue + CHECK: %[[mref1:[0-9]+]] = insertvalue + CHECK-SAME: [[arr:\[12 x i64\]]] } %[[mref0]], i64 255, 3 + CHECK: %{{[0-9]+}} = insertvalue + CHECK-SAME: %[[mref1]], [[arr]] )"; + + auto GetMemRefTestPattern(Shape shape) { + std::ostringstream stream; + stream << "["; + absl::c_for_each(shape.dimensions(), + [&stream](auto x) { stream << "i64 " << x << ", "; }); + return absl::StrCat(test_pattern_, stream.str()); + } +}; + +TEST_P(MemoryUtilTest, VerifyMemRefTest) { + std::string filecheck_input; + llvm::LLVMContext context = llvm::LLVMContext(); + llvm::IRBuilder builder(context); + llvm::raw_string_ostream ostream(filecheck_input); + llvm::Module module("MemoryUtilTest", context); + + llvm::FunctionType* function_type = llvm::FunctionType::get( + llvm::Type::getVoidTy(context), {builder.getPtrTy()}, false); + llvm::Function* function = llvm::Function::Create( + function_type, llvm::Function::LinkageTypes::ExternalLinkage, + "memory_util_test", module); + llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "BB", function); + builder.SetInsertPoint(bb); + + Shape shape = ShapeUtil::MakeShape(F32, GetParam()); + llvm::Argument* ptr = function->getArg(0); + llvm::Type* type = llvm_ir::PrimitiveTypeToIrType(F32, builder.getContext()); + + if (shape.IsArray()) { + for (auto dim : LayoutUtil::MinorToMajor(shape)) { + type = llvm::ArrayType::get(type, shape.dimensions(dim)); + } + } + + llvm_ir::IrArray ir_array(ptr, type, shape); + auto alloca = GetAllocaAndEmitMemrefInfo(builder, ir_array); + alloca.EmitLifetimeEnd(); + ostream << module; + + absl::StatusOr match = + RunFileCheck(filecheck_input, GetMemRefTestPattern(shape)); + TF_ASSERT_OK(match.status()); + EXPECT_TRUE(match.value()); +} + +INSTANTIATE_TEST_SUITE_P( + MemoryUtilTestSuite, MemoryUtilTest, + ::testing::Values(std::vector({30}), + std::vector({30, 40}), + std::vector({30, 40, 50})), + [](const ::testing::TestParamInfo& info) { + return absl::StrCat("Rank_", info.param.size()); + }); + +} // namespace +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL From a3394eba93b8f25737582ef1e74ec55079b0a60a Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Mon, 14 Apr 2025 03:18:33 -0700 Subject: [PATCH 0672/1324] [XLA:GPU] Make split_k_gemm_rewriter_tests pass with high precision Split-K accumulation (`--xla_gpu_triton_gemm_disable_reduced_precision_reduction`). Duplicate tests to test both flag values. The flag itself and the related tests will be removed, hence, duplicating the tests is preferred over relaxing the test EXPECTations. 
Use GMockMatchers in the new tests for conciseness and better readability. PiperOrigin-RevId: 747336083 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../service/gpu/split_k_gemm_rewriter_test.cc | 219 ++++++++++++++++-- 2 files changed, 197 insertions(+), 23 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 8d3345d55dbbb6..ba0470ad748b6f 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -898,6 +898,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc index d7c3b72deabc04..d0a63b501feaf6 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -38,11 +38,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -70,7 +68,32 @@ TEST(HasDivisibleSuffixAllowingSplitTest, AllTests) { using SplitKTest = HloTestBase; -TEST_F(SplitKTest, MakeSplitK) { +// TODO(b/409940111): Remove these tests once the flag is deprecated. +class SplitKTestWithLowPreciseReduction + : public HloTestBase, + public ::testing::WithParamInterface { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction( + false); + return debug_options; + } +}; + +class SplitKTestWithMorePreciseReduction + : public HloTestBase, + public ::testing::WithParamInterface { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction( + true); + return debug_options; + } +}; + +TEST_F(SplitKTestWithLowPreciseReduction, MakeSplitK) { const std::string hlo_text = R"( HloModule t @@ -102,7 +125,7 @@ ENTRY e { EXPECT_EQ(root->metadata().op_name(), "foo"); } -TEST_F(SplitKTest, MakeSplitKWithOutputFusion) { +TEST_F(SplitKTestWithLowPreciseReduction, MakeSplitKWithOutputFusion) { const std::string hlo_text = R"( HloModule t @@ -217,7 +240,8 @@ ENTRY e { "Sliced contracting dimension is not supported yet."))); } -TEST_F(SplitKTest, MakeSplitKWithNonStandardOutputLayout) { +TEST_F(SplitKTestWithLowPreciseReduction, + MakeSplitKWithNonStandardOutputLayout) { const std::string kHloText = R"( HloModule t @@ -251,6 +275,40 @@ ENTRY e { Layout({0, 1})); } +TEST_F(SplitKTestWithMorePreciseReduction, + MakeSplitKWithNonStandardOutputLayout) { + const std::string kHloText = R"( +HloModule t + +triton_gemm_dot { +parameter_0 = s8[3,128,5,32]{3,2,1,0} parameter(0) +bitcast.1 = s8[3,5,32,128]{2,1,3,0} bitcast(parameter_0) +copy.1 = s8[3,5,32,128]{3,2,1,0} copy(bitcast.1) +reshape.5 = s8[480,128]{1,0} reshape(copy.1) +convert.8 = bf16[480,128]{1,0} convert(reshape.5) +parameter_1 = 
bf16[16,128]{1,0} parameter(1) +ROOT dot.0 = bf16[480,16]{0,1} dot(convert.8, parameter_1), +lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { +p0 = s8[3,128,5,32]{3,2,1,0} parameter(0) +p1 = bf16[16,128]{1,0} parameter(1) +ROOT fusion = bf16[480,16]{0,1} fusion(p0, p1), +kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + TritonGemmConfig config(16, 16, 16, 4, 1, 4); + + TF_EXPECT_OK(MakeDotSplitKBatch( + module->entry_computation()->root_instruction(), config)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Convert( + m::Reduce().WithShape(m::Shape().WithLayout({0, 1}))))); +} + TEST_F(SplitKTest, MakeSplitKWithExistingBatchDim) { const std::string hlo_text = R"( HloModule m @@ -501,7 +559,7 @@ ENTRY e { "Too small divisible part of the contracting dimension.")); } -TEST_F(SplitKTest, FragmentedKSupported) { +TEST_F(SplitKTestWithLowPreciseReduction, FragmentedKSupported) { const std::string hlo_text = R"( HloModule t @@ -556,6 +614,61 @@ ENTRY e { /*broadcast_multiplier=*/1))); } +TEST_F(SplitKTestWithMorePreciseReduction, FragmentedKSupported) { + const std::string hlo_text = R"( +HloModule t + +triton_gemm_dot { + p0 = f16[7,2,16,4,20] parameter(0) + t0 = f16[2,16,4,20,7] transpose(p0), dimensions={1,2,3,4,0} + b0 = f16[2560,7] bitcast(t0) + a1 = f16[2560,5] parameter(1) + ROOT r = f16[7,5] dot(b0, a1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f16[7,2,16,4,20] parameter(0) + p1 = f16[2560,5] parameter(1) + ROOT fusion = f16[7,5] fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + + TritonGemmConfig config(32, 32, 16, 1, 1, 4); + // 5 divides the contracting dimension, but not its major subdimensions. + config.split_k = 5; + EXPECT_THAT( + MakeDotSplitKBatch(module->entry_computation()->root_instruction(), + config), + tsl::testing::StatusIs(tsl::error::CANCELLED, + "Contracting dimension is too fragmented.")); + + // 8 fits the constraints. 
+ config.split_k = 8; + TF_EXPECT_OK(MakeDotSplitKBatch( + module->entry_computation()->root_instruction(), config)); + HloInstruction* dot_fusion; + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Convert( + m::Reduce(m::Fusion(&dot_fusion), m::ConstantScalar())))) + << module->ToString(); + const HloComputation* dot_computation = dot_fusion->called_computations()[0]; + const HloInstruction* p0 = dot_computation->parameter_instruction(0); + TF_ASSERT_OK_AND_ASSIGN( + const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation, config.split_k)); + EXPECT_EQ(dot_computation->root_instruction()->shape(), + ShapeUtil::MakeShapeWithDescendingLayout(F32, {8, 7, 5})); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/2560, /*slice_start=*/0, + /*slice_limit=*/2560, + /*subfragments=*/ElementsAre(20, 4, 4, 4, 2), + /*broadcast_multiplier=*/1))); +} + TEST_F(SplitKTest, FragmentedKUnsupported) { const std::string hlo_text = R"( HloModule t @@ -586,7 +699,8 @@ ENTRY e { "Contracting dimension is too fragmented.")); } -TEST_F(SplitKTest, MakeSplitKWithNonDefaultOutputLayout) { +TEST_F(SplitKTestWithLowPreciseReduction, + MakeSplitKWithNonDefaultOutputLayout) { const std::string kHloText = R"( triton_gemm_dot.4842_computation { parameter_0 = bf16[96,96]{1,0} parameter(0) @@ -618,6 +732,40 @@ ENTRY e { TritonFusionAnalysis::Execute(*dot_computation)); } +TEST_F(SplitKTestWithMorePreciseReduction, + MakeSplitKWithNonDefaultOutputLayout) { + const std::string kHloText = R"( +triton_gemm_dot.4842_computation { + parameter_0 = bf16[96,96]{1,0} parameter(0) + parameter_1 = bf16[96,7]{1,0} parameter(1) + dot.0 = bf16[96,7]{0,1} dot(parameter_0, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT bitcast.2 = bf16[7,3,32]{2,1,0} bitcast(dot.0) +} + +ENTRY e { + parameter_0.91 = bf16[96,96]{1,0} parameter(0) + parameter_1.86 = bf16[96,7]{1,0} parameter(1) + ROOT triton_gemm_dot.4842 = bf16[7,3,32]{2,1,0} + fusion(parameter_0.91, parameter_1.86), kind=kCustom, + calls=triton_gemm_dot.4842_computation +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + TritonGemmConfig config(16, 16, 16, 2, 1, 4); + TF_EXPECT_OK(MakeDotSplitKBatch( + module->entry_computation()->root_instruction(), config)); + HloInstruction* dot_fusion; + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Convert( + m::Reduce(m::Fusion(&dot_fusion), m::ConstantScalar())))) + << module->ToString(); + const HloComputation* dot_computation = + dot_fusion->fused_instructions_computation(); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); +} + TEST_F(SplitKTest, SparseDotWithLhsSparseOperandIsRewritten) { const std::string hlo_text = R"( HloModule test @@ -684,18 +832,6 @@ ENTRY e { EXPECT_FALSE(result.ok()); } -class SplitKTestWithMorePreciseReduction - : public HloTestBase, - public ::testing::WithParamInterface { - public: - DebugOptions GetDebugOptionsForTest() const override { - DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction( - true); - return debug_options; - } -}; - TEST_F(SplitKTestWithMorePreciseReduction, MakeSplitK) { constexpr absl::string_view kHloText = R"( HloModule t @@ -715,7 +851,8 @@ ENTRY e { p0 = s8[3,128,5,32]{3,2,1,0} parameter(0) p1 = bf16[16,128]{1,0} parameter(1) ROOT 
 fusion = bf16[480,16]{1,0} fusion(p0, p1),
-  kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm"
+  kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm",
+  metadata={op_name="foo"}
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
@@ -724,8 +861,12 @@ ENTRY e {
   TF_EXPECT_OK(MakeDotSplitKBatch(
       module->entry_computation()->root_instruction(), config));
 
+  HloInstruction* reduce;
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Convert(m::Reduce(m::Fusion(), m::Constant()))));
+              GmockMatch(m::Convert(m::Op(&reduce)
+                                        .WithOpcode(HloOpcode::kReduce)
+                                        .WithOperand(0, m::Fusion()))));
+  EXPECT_EQ(reduce->metadata().op_name(), "foo");
 }
 
 TEST_F(SplitKTestWithMorePreciseReduction, MakeSplitKWithOutputFusion) {
@@ -805,7 +946,7 @@ ENTRY e {
       << module->ToString();
 }
 
-TEST_F(SplitKTest, MakeSplitKWithTransposeAfterDot) {
+TEST_F(SplitKTestWithLowPreciseReduction, MakeSplitKWithTransposeAfterDot) {
   const std::string hlo_text = R"(
 triton_gemm_dot {
   p0 = f16[8,288,288]{2,1,0} parameter(0)
@@ -836,6 +977,38 @@ ENTRY e {
   EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
 }
 
+TEST_F(SplitKTestWithMorePreciseReduction, MakeSplitKWithTransposeAfterDot) {
+  const std::string hlo_text = R"(
+triton_gemm_dot {
+  p0 = f16[8,288,288]{2,1,0} parameter(0)
+  p1 = f16[8,288,32]{2,0,1} parameter(1)
+  d = f16[8,288,32]{2,1,0} dot(p0, p1),
+    lhs_batch_dims={0}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={1}
+  ROOT t = f16[288,8,32]{2,1,0} transpose(d), dimensions={1,0,2}
+}
+
+ENTRY e {
+  p0 = f16[8,288,288]{2,1,0} parameter(0)
+  p1 = f16[8,288,32]{2,0,1} parameter(1)
+  ROOT fusion = f16[288,8,32]{2,1,0} fusion(p0, p1),
+    kind=kCustom, calls=triton_gemm_dot
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  TritonGemmConfig config(16, 128, 32, 8, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
+  HloInstruction* dot_fusion;
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Convert(
+                  m::Reduce(m::Fusion(&dot_fusion), m::ConstantScalar()))))
+      << module->ToString();
+  const auto* transpose = Cast<HloTransposeInstruction>(
+      dot_fusion->fused_instructions_computation()->root_instruction());
+  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
+}
+
 TEST_F(SplitKTest, MakeSplitKWithTrivialDimension) {
   const std::string hlo_text = R"(
 triton_gemm_dot {

From 6e41eff11fa4e1496af49a7df66fb42c67d2a2e5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 14 Apr 2025 03:20:05 -0700
Subject: [PATCH 0673/1324] Copy insertion will elide unnecessary copies based
 on checking the live range of the current schedule.

Passes after copy-insertion might duplicate/reschedule nodes in a way that
causes live range overlap and results in runtime corruption; this change adds
dependencies to explicitly prevent that.

To avoid the control predicates impacting the schedule unnecessarily, we add
operands last to the DFS stack during ForEachInstructionPostOrderImpl() and
make GPU tests more permissive of op ordering.
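
A reduced sketch of the mechanism (AddControlDependencyTo is the API used in
the copy_insertion.cc diff below; the wrapper function and its name are
illustrative only, and assume the usual XLA headers):

    #include "absl/status/status.h"
    #include "xla/hlo/ir/hlo_instruction.h"

    // Sketch: once copy elision decides that 'last_use' must finish before
    // 'next_def' redefines the now-shared buffer, pin that order with an
    // explicit control edge so later passes cannot reschedule across it.
    absl::Status PinLiveRangeOrder(xla::HloInstruction* last_use,
                                   xla::HloInstruction* next_def) {
      if (last_use->parent() != next_def->parent() || last_use == next_def) {
        return absl::OkStatus();  // mirrors the guards in the real pass
      }
      return last_use->AddControlDependencyTo(next_def);
    }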
PiperOrigin-RevId: 747336466 --- third_party/xla/xla/service/copy_insertion.cc | 141 +++++++++++++----- third_party/xla/xla/service/copy_insertion.h | 3 +- 2 files changed, 107 insertions(+), 37 deletions(-) diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc index d2b7f1a73f00d6..cb107c5e566d97 100644 --- a/third_party/xla/xla/service/copy_insertion.cc +++ b/third_party/xla/xla/service/copy_insertion.cc @@ -49,6 +49,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/ptrvec.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" @@ -63,9 +64,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -615,33 +613,39 @@ class LiveRangeRegions { absl::InlinedVector computation_vector_; }; +#define RUNTIME_ORDER_LIST(V) \ + /* Indicates that there is no overlap whatsoever between the two regions. */ \ + V(kNoOverlap, 0) \ + /* Indicates that the first region includes the same set of instructions \ + as the second region. */ \ + V(kSameInstr, 1) \ + /* Indicates that the first region is entirely before the second region \ + starts. */ \ + V(kBeforeStart, 2) \ + /* Indicates that the first region is before the second region ends. */ \ + V(kBeforeStartOrSameInstr, kBeforeStart | kSameInstr) \ + /* Indicates that the first region is entirely after the second region \ + ends. */ \ + V(kAfterEnd, 4) \ + /* Indicates that the first region is after the second region \ + starts, with some instructions before the second region ends. */ \ + V(kAfterEndOrSameInstr, kAfterEnd | kSameInstr) \ + /* Indicates that the first region overlaps with the second one, but share \ + no common instructions. */ \ + V(kBeforeStartOrAfterEnd, kBeforeStart | kAfterEnd) \ + /* Indicates that the first region overlaps with the second one, and have \ + some common instructions. */ \ + V(kBeforeOrAfterOrOverlap, kBeforeStart | kAfterEnd | kSameInstr) + namespace { // Represent relations between the locations of two regions of instructions, // each region can include 0-n instructions. class Relation { public: enum RuntimeOrder { - // Indicate that there is no overlap whatsoever between the two regions. - kNoOverlap = 0, - // Indicate that the first region includes the same set of instructions as - // the second region. - kSameInstr = 1, - // Indicate that the first region is entirely before the second region - // starts. - kBeforeStart = 2, - // Indicate that the first region is before the second region ends. - kBeforeStartOrSameInstr = kBeforeStart | kSameInstr, - // Indicate that the first region is entirely after the second region ends. - kAfterEnd = 4, - // Indicate that the first region is after the second region - // starts, with some instructions before the second region ends. - kAfterEndOrSameInstr = kAfterEnd | kSameInstr, - // Indicate that the first region overlaps with the second one, but share no - // common instructions. - kBeforeStartOrAfterEnd = kBeforeStart | kAfterEnd, - // Indicate that the first region overlaps with the second one, and have - // some common instructions. 
- kBeforeOrAfterOrOverlap = kBeforeStart | kAfterEnd | kSameInstr, +#define DECLARE_ENUM(enum_name, enum_value) enum_name = enum_value, + RUNTIME_ORDER_LIST(DECLARE_ENUM) +#undef DECLARE_ENUM }; Relation() : intercept_def_use_(false) {} explicit Relation(RuntimeOrder order, bool intercept_def_use = false) @@ -699,8 +703,18 @@ class Relation { return orders_.size() == 1 && orders_[0] == kAfterEnd; } std::string ToString() const { - return absl::StrCat("Interception = ", intercept_def_use_, ";", - absl::StrJoin(orders_, ",")); + auto format_order = [](std::string* out, RuntimeOrder order) { + switch (order) { +#define DECLARE_CASE(enum_name, enum_value) \ + case enum_name: \ + absl::StrAppend(out, #enum_name); \ + break; + RUNTIME_ORDER_LIST(DECLARE_CASE) +#undef DECLARE_CASE + } + }; + return absl::StrCat("Interception = ", intercept_def_use_, " Orders = ", + absl::StrJoin(orders_, ", ", format_order), ","); } static bool DefinitionImpliesInterception(RuntimeOrder definition) { @@ -1523,8 +1537,8 @@ class CopyRemover { // live range interference is introduced by the copy's elimination. If // elision is possible, then the internal state (value lists) are updated, // and true is returned. Returns false otherwise. - bool TryElideCopy(const HloInstruction* copy, - int64_t* region_analysis_limit) { + bool TryElideCopy(const HloInstruction* copy, int64_t* region_analysis_limit, + bool insert_post_scheduling_control_dependencies) { VLOG(2) << "Trying to remove " << copy->name(); CHECK_NE(region_analysis_limit, nullptr); if (copy->shape().has_layout() && copy->operand(0)->shape().has_layout()) { @@ -1603,6 +1617,29 @@ class CopyRemover { VLOG(2) << "Region-based interference is false."; return false; }; + auto AddControlDependenciesBetween = [&](ValueNode* src, ValueNode* dst) { + if (src == nullptr || dst == nullptr) { + return; + } + for (auto use : src->uses) { + if (use->instruction->parent() != dst->value->instruction()->parent() || + use->instruction == dst->value->instruction()) { + // Don't add control dependencies if the use is in a different + // computation or if the use is the same as the destination. + continue; + } + + VLOG(2) << "Adding control dependency:"; + VLOG(2) << " From: " << use->instruction->ToString(); + VLOG(2) << " Use: " << use->ToString(); + + VLOG(2) << " To: " << dst->value->instruction()->ToShortString(); + VLOG(2) << " Value: " << dst->value->ToString(); + + CHECK_OK(use->instruction->AddControlDependencyTo( + dst->value->instruction())); + } + }; // A kCopy instruction copies an HLO value from a source buffer and // defines an HLO value in a destination buffer. Most generally, the @@ -1680,9 +1717,24 @@ class CopyRemover { // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. CheckLiveRangeBefore(copy_node.dest->prev, Next(*copy_node.src)); VLOG(2) << "LiveRangeBefore result: " << live_range_before; - if (!live_range_before && - CheckLiveRangeInterference(copy_node.src, copy_node.dest, - kMergeFirstDestInSource)) { + // If the live range is before, we can add control dependencies to ensure + // the ordering. Otherwise, we check for interference (which will + // also add control dependencies if needed) + if (live_range_before) { + if (insert_post_scheduling_control_dependencies) { + // Ensure that the last uses of the copy source (e.g. s_x) are + // ordered before the next definition of the copy destination buffer + // (d_1). 
+ AddControlDependenciesBetween(copy_node.src, Next(*copy_node.dest)); + + // Also ensure that the last uses of the copy destination (e.g. d_m) + // are ordered before the next definition of the copy source buffer + // (s_{x+1}). + AddControlDependenciesBetween(copy_node.dest->prev, + Next(*copy_node.src)); + } + } else if (CheckLiveRangeInterference(copy_node.src, copy_node.dest, + kMergeFirstDestInSource)) { return false; } VLOG(2) << "Splice dest after source."; @@ -1712,9 +1764,23 @@ class CopyRemover { // Live range of 'last_src' must be before next_dest d_{y+1}. CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest)); VLOG(2) << "LiveRangeBefore result: " << live_range_before; - if (!live_range_before && - CheckLiveRangeInterference(copy_node.src, copy_node.dest, - kMergeLastSourceInDest)) { + // If the live range is before, we can add control dependencies to ensure + // the ordering. Otherwise, we check for interference (which will + // also add control dependencies if needed) + if (live_range_before) { + if (insert_post_scheduling_control_dependencies) { + // Ensure that the last uses of the value defined just before the + // copy destination are ordered before the next definition in the + // copy source buffer. + AddControlDependenciesBetween(Prev(*copy_node.dest), + copy_node.src->next); + // Also ensure that the last uses of the copy source (e.g. s_n) are + // ordered before the next definition of the copy destination buffer + // (e.g. d_{y+1}). + AddControlDependenciesBetween(copy_node.src, Next(*copy_node.dest)); + } + } else if (CheckLiveRangeInterference(copy_node.src, copy_node.dest, + kMergeLastSourceInDest)) { VLOG(2) << "Region-based analysis concludes interference."; return false; } @@ -2453,7 +2519,8 @@ static int64_t GetNumExistingCopies( absl::Status CopyInsertion::RemoveUnnecessaryCopies( HloModule* module, bool check_live_range_ordering, - const absl::flat_hash_set& execution_threads) { + const absl::flat_hash_set& execution_threads, + bool insert_post_scheduling_control_dependencies) { XLA_VLOG_LINES( 4, module->ToString(HloPrintOptions().set_syntax_sugar_async_ops(false))); @@ -2504,7 +2571,9 @@ absl::Status CopyInsertion::RemoveUnnecessaryCopies( ? 0 : std::min(allowance.analysis_allowance(), use_region_based_live_range_analysis_); - if (copy_remover.TryElideCopy(instruction, &region_analysis_cost_now)) { + if (copy_remover.TryElideCopy( + instruction, &region_analysis_cost_now, + insert_post_scheduling_control_dependencies)) { changed = true; TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction)); TF_RETURN_IF_ERROR( diff --git a/third_party/xla/xla/service/copy_insertion.h b/third_party/xla/xla/service/copy_insertion.h index e580405341c87a..9e1d8a6e676255 100644 --- a/third_party/xla/xla/service/copy_insertion.h +++ b/third_party/xla/xla/service/copy_insertion.h @@ -80,7 +80,8 @@ class CopyInsertion : public HloModulePass { // in all the existing aliased buffers. absl::Status RemoveUnnecessaryCopies( HloModule* module, bool check_live_range_ordering = false, - const absl::flat_hash_set& execution_threads = {}); + const absl::flat_hash_set& execution_threads = {}, + bool insert_post_scheduling_control_dependencies = false); // Add copies to address special constraints on the roots of computations not // related to live range interference: From bdace03697603e732105cb886aeca9caff5679fa Mon Sep 17 00:00:00 2001 From: "A.
Unique TensorFlower" Date: Mon, 14 Apr 2025 04:42:17 -0700 Subject: [PATCH 0674/1324] Automated Code Change PiperOrigin-RevId: 747360927 --- .../xla/python/pjrt_ifrt/basic_string_array_test.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc index e489973967b1bd..cc5a0a0893ac74 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc @@ -594,7 +594,8 @@ TEST(DisassembleArrayIntoSingleDeviceArrays, TF_ASSERT_OK_AND_ASSIGN(auto disassembled_arrays, array->DisassembleIntoSingleDeviceArrays( - ArrayCopySemantics::kAlwaysCopy)); + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); ASSERT_EQ(disassembled_arrays.size(), 1); auto basic_string_array = @@ -616,7 +617,8 @@ TEST(DisassembleArrayIntoSingleDeviceArrays, ShardedArrayDisassembleSuccess) { TF_ASSERT_OK_AND_ASSIGN(auto disassembled_arrays, array->DisassembleIntoSingleDeviceArrays( - ArrayCopySemantics::kAlwaysCopy)); + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); ASSERT_EQ(disassembled_arrays.size(), 2); @@ -642,9 +644,10 @@ TEST(DisassembleArrayIntoSingleDeviceArrays, FailsIfTheArrayHasBeenDeleted) { array->Delete(); - EXPECT_THAT( - array->DisassembleIntoSingleDeviceArrays(ArrayCopySemantics::kAlwaysCopy), - StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_THAT(array->DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards), + StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST(CopyTest, SuccessSingleDeviceShardedArray) { From a70879dd728cf22dbc0efaec0c6f51f5e04eb1bc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 04:44:32 -0700 Subject: [PATCH 0675/1324] Automated Code Change PiperOrigin-RevId: 747361460 --- third_party/xla/xla/reference_util.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/reference_util.cc b/third_party/xla/xla/reference_util.cc index b4af564e8b91cc..3a5338af6299ef 100644 --- a/third_party/xla/xla/reference_util.cc +++ b/third_party/xla/xla/reference_util.cc @@ -448,7 +448,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( HloEvaluator evaluator; Literal result_literal = evaluator.Evaluate(*computation, {}).value(); - CHECK_EQ(result_literal.shape().dimensions_size(), 4); + CHECK_EQ(result_literal.shape().dimensions().size(), 4); auto result = std::make_unique>(result_literal.shape().dimensions(0), result_literal.shape().dimensions(1), From d47b747f6478694b929dea399a989181a3d66f0c Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 04:44:38 -0700 Subject: [PATCH 0676/1324] Automated Code Change PiperOrigin-RevId: 747361480 --- third_party/xla/xla/python/ifrt/test_util.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/python/ifrt/test_util.h b/third_party/xla/xla/python/ifrt/test_util.h index e83ac52685918c..e3e9b37536d97b 100644 --- a/third_party/xla/xla/python/ifrt/test_util.h +++ b/third_party/xla/xla/python/ifrt/test_util.h @@ -67,7 +67,8 @@ void AssertPerShardData( testing::ElementsAreArray(GetDeviceIds(expected_device_list))); TF_ASSERT_OK_AND_ASSIGN(auto actual_per_shard_arrays, actual->DisassembleIntoSingleDeviceArrays( - ArrayCopySemantics::kAlwaysCopy)); + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); ASSERT_EQ(actual_per_shard_arrays.size(), expected_per_shard_data.size()); for (int i = 0; i < actual_per_shard_arrays.size(); ++i) { SCOPED_TRACE(absl::StrCat("Shard ", i)); From 7c8659cd279873b865f27a4dcd3e100b7bd48d0d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 04:46:21 -0700 Subject: [PATCH 0677/1324] Automated Code Change PiperOrigin-RevId: 747361862 --- third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc index 4f8fe6260401c6..185c921b8a0667 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc @@ -1548,7 +1548,8 @@ absl::StatusOr> TfrtGpuClient::BufferFromHostBuffer( device_shape, transfer_manager->ChooseCompactLayoutForShape(device_shape)); - absl::InlinedVector shape_strides(device_shape.dimensions_size()); + absl::InlinedVector shape_strides( + device_shape.dimensions().size()); TF_RETURN_IF_ERROR( ShapeUtil::ByteStrides(device_shape, absl::MakeSpan(shape_strides))); bool host_and_device_strides_equal = From fcbb17dc4b272bf5db8e9b13110b6a6d1cd80151 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 04:47:47 -0700 Subject: [PATCH 0678/1324] Automated Code Change PiperOrigin-RevId: 747362258 --- tensorflow/compiler/tf2xla/lib/scatter.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 5d5ec2e22f2c3e..91e357ec69eaa9 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -51,7 +51,7 @@ absl::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > buffer_shape.dimensions_size()) { + if (num_index_dims > buffer_shape.dimensions().size()) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", xla::ShapeUtil::HumanString(indices_shape), @@ -140,11 +140,11 @@ absl::StatusOr XlaScatter( xla::ScatterDimensionNumbers dim_numbers; dim_numbers.set_index_vector_dim(indices_are_vectors - ? indices_shape.dimensions_size() - 1 - : indices_shape.dimensions_size()); + ? 
indices_shape.dimensions().size() - 1 + : indices_shape.dimensions().size()); - int64_t updates_rank = updates_shape.dimensions_size(); - int64_t buffer_rank = buffer_shape.dimensions_size(); + int64_t updates_rank = updates_shape.dimensions().size(); + int64_t buffer_rank = buffer_shape.dimensions().size(); int64_t num_window_dims_in_updates = buffer_rank - num_index_dims; // If the rank of `updates` is 0 and does not match the expected rank of @@ -159,7 +159,7 @@ absl::StatusOr XlaScatter( if (updates_rank == 0 && expected_updates_rank != 0) { new_updates = xla::Broadcast(updates, expected_updates_dims); TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); - updates_rank = updates_shape.dimensions_size(); + updates_rank = updates_shape.dimensions().size(); } if (updates_rank > 0) { From 7ad84c306c357d03cdcbeebab18549ab5d678bee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 04:48:56 -0700 Subject: [PATCH 0679/1324] Automated Code Change PiperOrigin-RevId: 747362647 --- third_party/xla/xla/python/custom_call_batch_partitioner.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/python/custom_call_batch_partitioner.cc b/third_party/xla/xla/python/custom_call_batch_partitioner.cc index acfa2387d7abde..d3ac1a7fe82434 100644 --- a/third_party/xla/xla/python/custom_call_batch_partitioner.cc +++ b/third_party/xla/xla/python/custom_call_batch_partitioner.cc @@ -159,7 +159,8 @@ std::pair ComputeResultShapeAndSharding( const Shape& shape, const HloSharding& batch_sharding, int64_t num_batch_dims) { if (!shape.IsTuple()) { - const int64_t num_replicate_dims = shape.dimensions_size() - num_batch_dims; + const int64_t num_replicate_dims = + shape.dimensions().size() - num_batch_dims; auto result_sharding = InsertNonBatchSharding(batch_sharding, num_replicate_dims); auto result_shape = spmd::MakePartitionedShape(shape, result_sharding); @@ -243,7 +244,7 @@ absl::Status CustomCallBatchPartitioner::Partition( partitioned_shapes_with_layout_constraints.reserve(num_operands); for (size_t i = 0; i < num_operands; ++i) { const int64_t num_replicate_dims = - hlo->operand(i)->shape().dimensions_size() - num_batch_dims; + hlo->operand(i)->shape().dimensions().size() - num_batch_dims; HloSharding operand_sharding = InsertNonBatchSharding(batch_sharding, num_replicate_dims); spmd::PartitionedHlo partitioned_operand = From 2dac39cf1e534144ac8ea65c2a7ed6cf0f4c1efa Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 04:50:05 -0700 Subject: [PATCH 0680/1324] Automated Code Change PiperOrigin-RevId: 747362951 --- .../xla/xla/service/gpu/parallel_loop_emitter.cc | 2 +- third_party/xla/xla/service/gpu/reduction_utils.cc | 2 +- .../xla/xla/service/gpu/split_k_gemm_rewriter.cc | 8 ++++---- .../xla/xla/service/gpu/stream_executor_util.cc | 10 +++++----- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc b/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc index 565e907fb719e8..6fd2ce9b33099b 100644 --- a/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc +++ b/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc @@ -180,7 +180,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Value* linear_index_base = linear_base_and_thread_idx.linear_base; - std::vector multidim(shape_.dimensions_size(), nullptr); + std::vector multidim(shape_.dimensions().size(), nullptr); for (int i = 0; i < launch_config_.unroll_factor; ++i) { // The add operation is needed even if the offset is 0, since when the // kernel is unrolled, the following GEP instruction shares the same pointer diff --git a/third_party/xla/xla/service/gpu/reduction_utils.cc b/third_party/xla/xla/service/gpu/reduction_utils.cc index 5858bcfc2fc3c7..30f6fbb8d2aa5e 100644 --- a/third_party/xla/xla/service/gpu/reduction_utils.cc +++ b/third_party/xla/xla/service/gpu/reduction_utils.cc @@ -185,7 +185,7 @@ ReductionDimensions GetReductionKindAndContiguousComponents( Shape input_shape = reduce.operand(0)->shape(); absl::Span dims_to_reduce = reduce.dimensions(); DimensionVector dims_to_keep; - for (int64_t dim = 0; dim < input_shape.dimensions_size(); ++dim) { + for (int64_t dim = 0; dim < input_shape.dimensions().size(); ++dim) { if (!absl::c_linear_search(dims_to_reduce, dim)) { dims_to_keep.push_back(dim); } diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc index 99dd94144ab4bf..53abd5d6fd0d78 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc @@ -203,7 +203,7 @@ absl::StatusOr MakeSplitKOperand( LiteralUtil::Zero(operand->shape().element_type()))); PaddingConfig padding_config = - MakeNoPaddingConfig(operand->shape().dimensions_size()); + MakeNoPaddingConfig(operand->shape().dimensions().size()); padding_config.mutable_dimensions(contracting_dim_idx) ->set_edge_padding_high(config.split_k - k % config.split_k); @@ -218,7 +218,7 @@ absl::StatusOr MakeSplitKOperand( const Shape& shape = operand->shape(); Shape new_shape(shape.element_type(), {}, {}); - for (int i = 0; i < shape.dimensions_size(); ++i) { + for (int i = 0; i < shape.dimensions().size(); ++i) { const int64_t dimension_size = shape.dimensions(i); if (i == contracting_dim_idx) { new_shape.add_dimensions(config.split_k); @@ -363,7 +363,7 @@ absl::Status MakeDotComputationSplitKBatch( auto* new_transpose = Cast(expanded); new_transpose->mutable_dimensions()->clear(); new_transpose->mutable_dimensions()->reserve( - new_transpose->shape().dimensions_size()); + new_transpose->shape().dimensions().size()); // The split-K batch dimension is always major. 
new_transpose->mutable_dimensions()->push_back(0); for (const int64_t dim : old_transpose->dimensions()) { @@ -386,7 +386,7 @@ absl::Status MakeDotComputationSplitKBatch( disable_reduced_precision_reduction, computation, operand); HloInstruction* convert = MakeConvertToHlo(operand, accumulator_dtype); std::vector broadcast_dimensions( - operand->shape().dimensions_size()); + operand->shape().dimensions().size()); absl::c_iota(broadcast_dimensions, 1); TF_RETURN_IF_ERROR(expanded->ReplaceOperandWithDifferentShape( i, diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index 123d515811f2c9..4aaef2204a5913 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -334,15 +334,15 @@ FindVectorizedFeatureDims(const ConvolutionDimensionNumbers& dnums, const Shape& input, const Shape& filter, const Shape& output) { return { - FindVectorizedDim(input.dimensions_size(), dnums.input_batch_dimension(), - dnums.input_feature_dimension(), - dnums.input_spatial_dimensions()), - FindVectorizedDim(filter.dimensions_size(), + FindVectorizedDim( + input.dimensions().size(), dnums.input_batch_dimension(), + dnums.input_feature_dimension(), dnums.input_spatial_dimensions()), + FindVectorizedDim(filter.dimensions().size(), dnums.kernel_input_feature_dimension(), dnums.kernel_output_feature_dimension(), dnums.kernel_spatial_dimensions()), FindVectorizedDim( - output.dimensions_size(), dnums.output_batch_dimension(), + output.dimensions().size(), dnums.output_batch_dimension(), dnums.output_feature_dimension(), dnums.output_spatial_dimensions()), }; } From 681614e2c47b769b5244f215c5167ee5debc32c9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 04:51:13 -0700 Subject: [PATCH 0681/1324] Automated Code Change PiperOrigin-RevId: 747363247 --- .../llvm_ir/dynamic_update_slice_util.cc | 2 +- .../xla/xla/service/llvm_ir/ir_array.cc | 32 +++++++++---------- .../xla/xla/service/llvm_ir/llvm_loop.cc | 4 +-- .../xla/xla/service/llvm_ir/loop_emitter.cc | 10 +++--- .../xla/xla/service/llvm_ir/sort_util.cc | 2 +- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc b/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc index 9780b77b5e3ad0..9b39776eba8bea 100644 --- a/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -130,7 +130,7 @@ static absl::Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. 
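  // Note (stated as an assumption about HLO semantics, not shown in this // hunk): dynamic-update-slice clamps each start index into // [0, output_dim_i - update_dim_i], so the emitted update always stays // inside the output buffer.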
- const int64_t rank = output_shape.dimensions_size(); + const int64_t rank = output_shape.dimensions().size(); std::vector start_multi_index(rank); for (int64_t i = 0; i < rank; ++i) { TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i)); diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 5ee3841035ad6b..f0d5b2979dea41 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -92,8 +92,8 @@ void IrArray::Index::Delinearize(std::vector* multidim, llvm::Value* linear, const Shape& shape, absl::Span dynamic_dims, llvm::IRBuilderBase* b) const { - CHECK_EQ(shape.dimensions_size(), dynamic_dims.size()); - CHECK_EQ(multidim_.size(), shape.dimensions_size()); + CHECK_EQ(shape.dimensions().size(), dynamic_dims.size()); + CHECK_EQ(multidim_.size(), shape.dimensions().size()); llvm::Value* divisor = GetConstantWithIndexType(1); const Layout& layout = shape.layout(); for (int64_t i = 0; i < layout.minor_to_major_size(); ++i) { @@ -119,7 +119,7 @@ void IrArray::Index::Delinearize(std::vector* multidim, IrArray::Index::Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilderBase* b) - : multidim_(shape.dimensions_size()), + : multidim_(shape.dimensions().size()), linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { @@ -134,13 +134,13 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, IrArray::Index::Index(llvm::Value* linear, absl::Span multidim, const Shape& shape, llvm::IRBuilderBase* b) - : multidim_(shape.dimensions_size()), + : multidim_(shape.dimensions().size()), linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { CHECK_NE(linear, nullptr); index_type_ = linear->getType(); - CHECK_EQ(multidim.size(), shape.dimensions_size()); + CHECK_EQ(multidim.size(), shape.dimensions().size()); for (auto dim : multidim) { if (dim) { CHECK_EQ(dim->getType(), index_type_); @@ -160,7 +160,7 @@ IrArray::Index::Index(llvm::Value* linear, IrArray::Index::Index(llvm::Value* linear, const Shape& shape, absl::Span dynamic_dims, llvm::IRBuilderBase* b) - : multidim_(shape.dimensions_size()), + : multidim_(shape.dimensions().size()), linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { @@ -186,7 +186,7 @@ IrArray::Index::Index(absl::Span multidim, dims_(shape.dimensions().begin(), shape.dimensions().end()), index_type_(index_type) { CHECK_NE(index_type_, nullptr); - CHECK_EQ(shape.dimensions_size(), multidim.size()); + CHECK_EQ(shape.dimensions().size(), multidim.size()); for (const auto* dim : multidim) { CHECK_NE(dim, nullptr); } @@ -212,7 +212,7 @@ IrArray::IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape) if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { - DCHECK_EQ(depth, shape_.dimensions_size()) << shape_.ShortDebugString(); + DCHECK_EQ(depth, shape_.dimensions().size()) << shape_.ShortDebugString(); } } @@ -228,9 +228,9 @@ bool IrArray::Index::LinearValidOnShape(const Shape& a) const { IrArray::Index IrArray::Index::SourceIndexOfReshape( const Shape& output_shape, const Shape& input_shape, llvm::IRBuilderBase* builder) const { - CHECK_EQ(multidim_.size(), output_shape.dimensions_size()); + CHECK_EQ(multidim_.size(), output_shape.dimensions().size()); std::vector source_multidim_index( - input_shape.dimensions_size(), 
llvm::UndefValue::get(index_type_)); + input_shape.dimensions().size(), llvm::UndefValue::get(index_type_)); if (std::optional trivial_reshape = ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape, @@ -394,7 +394,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, absl::Span dimension_mapping, llvm::IRBuilderBase* builder) const { - int64_t rank = operand_shape.dimensions_size(); + int64_t rank = operand_shape.dimensions().size(); std::vector source_index(rank); for (int64_t i = 0; i < rank; ++i) { source_index[i] = multidim_[dimension_mapping[i]]; @@ -408,7 +408,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( // The other dimensions can be masked out with a div and a mod operation. std::vector logical_to_physical = LayoutUtil::MakeLogicalToPhysical(shape.layout()); - int64_t output_rank = shape.dimensions_size(); + int64_t output_rank = shape.dimensions().size(); // The minimum physical dimension that is broadcasted. int64_t min_broadcasted_dimension = output_rank; // The maximum physical dimension that is broadcasted. @@ -511,7 +511,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, // over higher-rank arrays. return base_ptr_; } - CHECK_EQ(index.size(), shape_.dimensions_size()); + CHECK_EQ(index.size(), shape_.dimensions().size()); CHECK(index.ShapeIsCompatible(shape_)) << "Shape " << index.AsShapeWithType(shape_.element_type()).ToString(true) << " is not compatible with " << shape_.ToString(true); @@ -527,8 +527,8 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, if (!index.LinearValidOnShape(shape_)) { // Create a valid linear index. std::vector dimensions; - dimensions.reserve(shape_.dimensions_size()); - for (int64_t i = 0; i < shape_.dimensions_size(); ++i) { + dimensions.reserve(shape_.dimensions().size()); + for (int64_t i = 0; i < shape_.dimensions().size(); ++i) { dimensions.push_back(shape_.dimensions(i)); } llvm::Value* linearized = index.Linearize(dimensions, b); @@ -555,7 +555,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, CHECK_GT(index.size(), 0); std::vector gep_indices( 1, llvm::ConstantInt::get(index[0]->getType(), 0)); - for (int64_t i = 0; i < shape_.dimensions_size(); ++i) { + for (int64_t i = 0; i < shape_.dimensions().size(); ++i) { int64_t dimension = LayoutUtil::Major(shape_.layout(), i); gep_indices.push_back(actual_index[dimension]); } diff --git a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc index 26f94c04377f3a..43ee77ef0e85be 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc @@ -244,7 +244,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64_t start_index, IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, absl::string_view suffix) { - std::vector dimensions(shape.dimensions_size()); + std::vector dimensions(shape.dimensions().size()); std::iota(dimensions.begin(), dimensions.end(), 0); return IrArray::Index(AddLoopsForShapeOnDimensions(shape, dimensions, suffix), shape, index_type_); @@ -253,7 +253,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, std::vector ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, absl::Span dimensions, absl::string_view suffix) { - std::vector multi_index(shape.dimensions_size()); + std::vector multi_index(shape.dimensions().size()); for (int64_t dimension : dimensions) { std::unique_ptr loop = AddLoop( 
/*start_index=*/0, diff --git a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc index d8b7a4026b4db0..906a58d1cebe71 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc @@ -50,7 +50,7 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, std::vector dynamic_dims, llvm::IRBuilderBase* b) : LoopEmitter::LoopEmitter(body_emitter, shape, b) { - CHECK_EQ(dynamic_dims.size(), shape_.dimensions_size()); + CHECK_EQ(dynamic_dims.size(), shape_.dimensions().size()); dynamic_dims_ = std::move(dynamic_dims); } @@ -125,12 +125,12 @@ IrArray::Index LoopEmitter::EmitStaticIndex(ForLoopNest* loop_nest, // Loops are added from outermost to innermost order with the ForLoopNest // class so emit loops in order from most-major dimension down to most-minor // dimension (of the target shape). - std::vector array_multi_index(shape_.dimensions_size()); + std::vector array_multi_index(shape_.dimensions().size()); for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { int64_t dimension = LayoutUtil::Major(shape_.layout(), i); // Only unroll the most minor dimension, this seems to give us good runtime // performance with a large improvement in compile time. - auto unroll_mode = (i == shape_.dimensions_size() - 1) + auto unroll_mode = (i == shape_.dimensions().size() - 1) ? llvm_ir::UnrollMode::kDefaultUnroll : llvm_ir::UnrollMode::kNoUnroll; std::unique_ptr loop = loop_nest->AddLoop( @@ -149,12 +149,12 @@ IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest, // Loops are added from outermost to innermost order with the ForLoopNest // class so emit loops in order from most-major dimension down to most-minor // dimension (of the target shape). - std::vector array_multi_index(shape_.dimensions_size()); + std::vector array_multi_index(shape_.dimensions().size()); for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { int64_t dimension = LayoutUtil::Major(shape_.layout(), i); // Only unroll the most minor dimension, this seems to give us good runtime // performance with a large improvement in compile time. - auto unroll_mode = (i == shape_.dimensions_size() - 1) + auto unroll_mode = (i == shape_.dimensions().size() - 1) ? llvm_ir::UnrollMode::kDefaultUnroll : llvm_ir::UnrollMode::kNoUnroll; std::unique_ptr loop = loop_nest->AddLoop( diff --git a/third_party/xla/xla/service/llvm_ir/sort_util.cc b/third_party/xla/xla/service/llvm_ir/sort_util.cc index a574d0cd4af2f0..007868333e4a8b 100644 --- a/third_party/xla/xla/service/llvm_ir/sort_util.cc +++ b/third_party/xla/xla/service/llvm_ir/sort_util.cc @@ -340,7 +340,7 @@ absl::Status EmitSortInPlace( // comparisons). const Shape& keys_shape = values_arrays[0].GetShape(); - int64_t rank = keys_shape.dimensions_size(); + int64_t rank = keys_shape.dimensions().size(); int64_t dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); std::vector dimensions_in_iteration_order(rank); std::vector iteration_order_to_logical_order(rank); From 99d001b4509bcf8bfa2dfc69d38ea5bbda40c16a Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 04:52:12 -0700 Subject: [PATCH 0682/1324] Automated Code Change PiperOrigin-RevId: 747363526 --- .../gpu/transforms/algebraic_simplifier.cc | 4 ++-- .../gpu/transforms/all_reduce_splitter.cc | 2 +- .../transforms/conv_padding_legalization.cc | 16 ++++++------- .../service/gpu/transforms/conv_rewriter.cc | 4 ++-- .../gpu/transforms/cublas_pad_for_gemms.cc | 14 +++++------ .../gpu/transforms/cudnn_norm_rewriter.cc | 15 ++++++------ .../transforms/cudnn_pad_for_convolutions.cc | 8 +++---- .../gpu/transforms/cudnn_simplify_padding.cc | 6 ++--- .../transforms/cudnn_simplify_padding_test.cc | 2 +- .../cudnn_vectorize_convolutions.cc | 4 ++-- .../service/gpu/transforms/dot_normalizer.cc | 4 ++-- .../gpu/transforms/dot_sparsity_rewriter.cc | 6 ++--- .../gemm_broadcast_folding_rewriter.cc | 4 ++-- .../transforms/gemm_fusion_swap_operands.cc | 8 +++---- .../service/gpu/transforms/gemm_rewriter.cc | 24 +++++++++---------- .../service/gpu/transforms/gemv_rewriter.cc | 4 ++-- .../gpu/transforms/gpusolver_rewriter.cc | 2 +- .../gpu/transforms/horizontal_input_fusion.cc | 2 +- .../gpu/transforms/horizontal_loop_fusion.cc | 4 ++-- .../gpu/transforms/layout_assignment.cc | 17 ++++++------- .../gpu/transforms/multi_output_fusion.cc | 2 +- .../gpu/transforms/nest_gemm_fusion.cc | 8 ++++--- .../ragged_all_to_all_decomposer.cc | 16 ++++++------- .../gpu/transforms/reduce_scatter_creator.cc | 4 ++-- .../reduction_degenerate_dim_remover.cc | 2 +- .../transforms/reduction_dimension_grouper.cc | 4 ++-- .../transforms/reduction_layout_normalizer.cc | 8 +++---- .../transforms/scatter_slice_simplifier.cc | 10 ++++---- .../gpu/transforms/softmax_rewriter_triton.cc | 4 ++-- .../service/gpu/transforms/sort_rewriter.cc | 2 +- .../gpu/transforms/topk_specializer.cc | 4 ++-- .../service/gpu/transforms/topk_splitter.cc | 2 +- .../transforms/transpose_dimension_grouper.cc | 9 +++---- .../gpu/transforms/windowed_einsum_handler.cc | 20 ++++++++-------- 34 files changed, 125 insertions(+), 120 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc index d591b06a98d8db..94a9115ef08756 100644 --- a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc @@ -139,10 +139,10 @@ bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce( DotDimensionNumbers dnums = dot->dot_dimension_numbers(); bool lhs_is_vector = (dnums.lhs_batch_dimensions_size() + dnums.lhs_contracting_dimensions_size() == - lhs->shape().dimensions_size()); + lhs->shape().dimensions().size()); bool rhs_is_vector = (dnums.rhs_batch_dimensions_size() + dnums.rhs_contracting_dimensions_size() == - rhs->shape().dimensions_size()); + rhs->shape().dimensions().size()); // Strength-reduce vector-vector dots since they are not supported by // GemmFusion. 
if (lhs_is_vector && rhs_is_vector) { diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc index 5dbd2fa7382c0c..e3b156fdbd8344 100644 --- a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc @@ -126,7 +126,7 @@ std::optional GetSplitDim(const HloAllReduceInstruction& ar, const HloDynamicSliceInstruction& ds) { int split_dim = -1; int num_dims = 0; - for (int64_t dim = 0; dim < ar.shape().dimensions_size(); ++dim) { + for (int64_t dim = 0; dim < ar.shape().dimensions().size(); ++dim) { if (ar.shape().dimensions(dim) != ds.shape().dimensions(dim)) { num_dims++; split_dim = dim; diff --git a/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc index 9281f065a636cb..fc75d8f7c465ec 100644 --- a/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc @@ -77,7 +77,7 @@ HloInstruction* MaybePaddedAndSlicedInput( // within cudnn is basically free, whereas a kPad's cost increases as the // amount of padding increases. PaddingConfig padding_config = - MakeNoPaddingConfig(input->shape().dimensions_size()); + MakeNoPaddingConfig(input->shape().dimensions().size()); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { int64_t dim = conv_dnums.input_spatial_dimensions(i); if (conv_window->dimensions(i).padding_low() > 0) { @@ -109,10 +109,10 @@ HloInstruction* MaybePaddedAndSlicedInput( // // For each dimension, initialize the start index to 0 and the limit index // to the size of that dimension. - std::vector start_indices(input->shape().dimensions_size(), 0); + std::vector start_indices(input->shape().dimensions().size(), 0); std::vector limit_indices(input->shape().dimensions().begin(), input->shape().dimensions().end()); - std::vector strides(input->shape().dimensions_size(), 1); + std::vector strides(input->shape().dimensions().size(), 1); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { int64_t dim = conv_dnums.input_spatial_dimensions(i); // If dimension "dim" has negative padding, increase the start index or @@ -147,8 +147,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, // Compute the shape and padding config of the pad to be inserted. PaddingConfig padding_config; padding_config.mutable_dimensions()->Reserve( - kernel->shape().dimensions_size()); - for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) { + kernel->shape().dimensions().size()); + for (size_t i = 0; i < kernel->shape().dimensions().size(); ++i) { padding_config.add_dimensions(); } for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) { @@ -239,7 +239,7 @@ bool ConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( Window new_backward_conv_window = backward_conv->window(); // input_padding_config is the config of the kPad to be inserted. 
PaddingConfig input_padding_config = - MakeNoPaddingConfig(input->shape().dimensions_size()); + MakeNoPaddingConfig(input->shape().dimensions().size()); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { @@ -369,11 +369,11 @@ bool ConvPaddingLegalization::CanonicalizeBackwardInputConvolution( // // Initialize start_indices and limit_indices as no slicing. std::vector start_indices( - new_backward_conv->shape().dimensions_size(), 0LL); + new_backward_conv->shape().dimensions().size(), 0LL); std::vector limit_indices( new_backward_conv->shape().dimensions().begin(), new_backward_conv->shape().dimensions().end()); - std::vector strides(new_backward_conv->shape().dimensions_size(), + std::vector strides(new_backward_conv->shape().dimensions().size(), 1LL); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64_t padding_low = backward_conv->window().dimensions(i).padding_low(); diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc index 13aab8fd9203c8..de98e998172f3d 100644 --- a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc @@ -643,7 +643,7 @@ ConvolutionMatch MatchBackwardInput(HloInstruction* conv) { // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ..., // in_depth/G, G, out_depth / G] - std::vector transpose_dims(rhs->shape().dimensions_size()); + std::vector transpose_dims(rhs->shape().dimensions().size()); std::iota(transpose_dims.begin(), transpose_dims.end(), 0); transpose_dims.erase(transpose_dims.begin() + input_feature_dimension); transpose_dims.insert(transpose_dims.begin() + output_feature_dimension, @@ -742,7 +742,7 @@ HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution( // Transpose G to the axis before C, For eg: [G, N/G, H, W, C ] -> [N/G, H, // W, G, C] - std::vector transpose_dims(lhs->shape().dimensions_size()); + std::vector transpose_dims(lhs->shape().dimensions().size()); std::iota(transpose_dims.begin(), transpose_dims.end(), 0); transpose_dims.erase(transpose_dims.begin() + input_batch_dimension); transpose_dims.insert(transpose_dims.begin() + input_feature_dimension, diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc index d79d13051e0f41..f144a83423a6ce 100644 --- a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc +++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc @@ -64,8 +64,8 @@ static absl::StatusOr PadForGemm(HloDotInstruction* dot, // Since the dot instruction is canonicalized, the last two dimensions for // each operand represent non-batch dimensions, and the others are the same // for both operands and correspond to batch dimensions. 
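    // For example, a canonical dot has lhs [B..., M, K] and rhs [B..., K, N]; // the two pad_dim calls below touch only the trailing {M, K} and {K, N} // pairs and leave the shared batch dimensions B... untouched.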
- pad_dim(s, s.dimensions_size() - 2); - pad_dim(s, s.dimensions_size() - 1); + pad_dim(s, s.dimensions().size() - 2); + pad_dim(s, s.dimensions().size() - 1); return s; }; @@ -83,7 +83,7 @@ static absl::StatusOr PadForGemm(HloDotInstruction* dot, auto create_padding_config = [](Shape& shape, Shape& new_shape) { PaddingConfig padding_config; - for (int i = 0; i < shape.dimensions_size(); ++i) { + for (int i = 0; i < shape.dimensions().size(); ++i) { auto dimension = padding_config.add_dimensions(); dimension->set_edge_padding_high(new_shape.dimensions()[i] - shape.dimensions()[i]); @@ -113,8 +113,8 @@ static absl::StatusOr PadForGemm(HloDotInstruction* dot, HloInstruction* new_dot = parent->AddInstruction( dot->CloneWithNewOperands(new_result_shape, {lpad, rpad})); - std::vector start_indices(result_shape.dimensions_size(), 0); - std::vector strides(result_shape.dimensions_size(), 1); + std::vector start_indices(result_shape.dimensions().size(), 0); + std::vector strides(result_shape.dimensions().size(), 1); HloInstruction* slice = parent->AddInstruction( HloInstruction::CreateSlice(result_shape, new_dot, start_indices, result_shape.dimensions(), strides)); @@ -139,9 +139,9 @@ bool CheckCanonical(HloDotInstruction* dot) { const auto& dimension_numbers = dot->dot_dimension_numbers(); if (dimension_numbers.lhs_batch_dimensions_size() + 2 != - dot->operand(0)->shape().dimensions_size() || + dot->operand(0)->shape().dimensions().size() || dimension_numbers.rhs_batch_dimensions_size() + 2 != - dot->operand(1)->shape().dimensions_size()) { + dot->operand(1)->shape().dimensions().size()) { VLOG(2) << dot->ToString() << " is not canonical: Expected all dimensions but 2 to be " diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc index ab0f287935d9e3..d23a813d082033 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc @@ -184,7 +184,7 @@ std::vector AdjustedDimensions(const Shape& shape, absl::Span dimensions) { absl::flat_hash_map dimension_map; for (int64_t dimension = 0, non_degen_dimension = 0; - dimension < shape.dimensions_size(); ++dimension) { + dimension < shape.dimensions().size(); ++dimension) { if (shape.dimensions(dimension) > 1) { dimension_map.insert({dimension, non_degen_dimension}); non_degen_dimension++; @@ -328,13 +328,13 @@ std::vector MapDimensions(const Shape& original_shape, absl::flat_hash_map> dimensions_map; std::vector original_dimensions, reshaped_dimensions; for (int64_t original_dimension = 0, reshaped_dimension = 0; - original_dimension < original_shape.dimensions_size(); + original_dimension < original_shape.dimensions().size(); ++original_dimension) { original_dimensions.push_back(original_dimension); while ((reshaped_dimensions.empty() || dimension_product(reshaped_shape, reshaped_dimensions) < dimension_product(original_shape, original_dimensions)) && - reshaped_dimension < reshaped_shape.dimensions_size()) { + reshaped_dimension < reshaped_shape.dimensions().size()) { reshaped_dimensions.emplace_back(reshaped_dimension++); } @@ -827,7 +827,7 @@ auto F1(UniqueHloInstruction* x, UniqueHloInstruction* x_center, .GetAsDouble({}) .value(); int64_t nelems = 1; - for (int i = 0; i < instr->shape().dimensions_size(); ++i) { + for (int i = 0; i < instr->shape().dimensions().size(); ++i) { if (!absl::c_linear_search(instr->dimensions(), i)) { nelems *= instr->shape().dimensions()[i]; 
} @@ -994,7 +994,8 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { std::vector norm_dims_adjusted = AdjustedDimensions(reduce); if (norm_dims_adjusted.size() != ShapeUtil::DropDegenerateDimensions(scale->shape()) - .dimensions_size()) { + .dimensions() + .size()) { VLOG(1) << "Layer norm input dimensions not supported."; return absl::OkStatus(); } @@ -1015,7 +1016,7 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // If necessary, transpose the input so that the dimensions not being // normalized are the leading dimensions. std::vector non_norm_dims; - for (int64_t x_dim = 0; x_dim < x.Instr()->shape().dimensions_size(); + for (int64_t x_dim = 0; x_dim < x.Instr()->shape().dimensions().size(); ++x_dim) { if (std::find(norm_dims.begin(), norm_dims.end(), x_dim) == norm_dims.end()) { @@ -1350,7 +1351,7 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // broadcasted dimensions. float actual_r_nelems = scalar->literal().GetAsDouble({}).value(); int64_t nelems = 1; - for (int i = 0; i < broadcast->shape().dimensions_size(); ++i) { + for (int i = 0; i < broadcast->shape().dimensions().size(); ++i) { if (!absl::c_linear_search(broadcast->dimensions(), i)) { nelems *= broadcast->shape().dimensions()[i]; } diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc index 7a0200cc254fdf..a430c3a5ed8406 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc @@ -63,10 +63,10 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloComputation* comp = instr->parent(); const Shape& shape = instr->shape(); - PaddingConfig pad_config = MakeNoPaddingConfig(shape.dimensions_size()); + PaddingConfig pad_config = MakeNoPaddingConfig(shape.dimensions().size()); bool added_padding = false; - for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) { + for (int64_t dim = 0; dim < shape.dimensions().size(); ++dim) { if (shape.dimensions(dim) == new_shape.dimensions(dim)) { continue; } @@ -126,10 +126,10 @@ static absl::Status PadConv(HloCustomCallInstruction* conv, // Slice the new conv result if necessary, keeping in mind that new_conv // has tuple shape (new_result_shape, u8[0]). 
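  // In other words, the custom call yields a (conv_result, scratch) pair; // the code below extracts element 0 and, when padding was applied, slices // it back to the original result_shape (the u8[0] scratch slot holds no // elements to slice).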
if (!ShapeUtil::Equal(result_shape, new_result_shape)) { - std::vector start_indices(result_shape.dimensions_size(), 0); + std::vector start_indices(result_shape.dimensions().size(), 0); std::vector end_indices(result_shape.dimensions().begin(), result_shape.dimensions().end()); - std::vector strides(result_shape.dimensions_size(), 1); + std::vector strides(result_shape.dimensions().size(), 1); auto* new_conv_result = add( HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0)); diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc index b58d6837acae93..fc24b74a0f76fb 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc @@ -73,7 +73,7 @@ std::optional FindFalseIndex(absl::Span vals) { std::optional FindOutputVectCDim(HloInstruction* conv) { const ConvolutionDimensionNumbers& dnums = conv->convolution_dimension_numbers(); - int64_t num_dims = conv->shape().tuple_shapes(0).dimensions_size(); + int64_t num_dims = conv->shape().tuple_shapes(0).dimensions().size(); absl::InlinedVector seen_dims(num_dims); seen_dims[dnums.output_batch_dimension()] = true; seen_dims[dnums.output_feature_dimension()] = true; @@ -87,7 +87,7 @@ std::optional FindOutputVectCDim(HloInstruction* conv) { std::optional FindKernelVectCDim(HloInstruction* conv) { const ConvolutionDimensionNumbers& dnums = conv->convolution_dimension_numbers(); - int64_t num_dims = conv->operand(1)->shape().dimensions_size(); + int64_t num_dims = conv->operand(1)->shape().dimensions().size(); absl::InlinedVector seen_dims(num_dims); seen_dims[dnums.kernel_input_feature_dimension()] = true; seen_dims[dnums.kernel_output_feature_dimension()] = true; @@ -127,7 +127,7 @@ std::optional NumTrailingZeroOutputFeatures(HloInstruction* conv) { // has modified the filter, making it infeasible to get the original, // un-reordered value. if (!matched || feature_dim != 0 || - transpose->shape().dimensions_size() != 8) { + transpose->shape().dimensions().size() != 8) { VLOG(2) << "The filter output feature dimension cannot be determined, as " "the reordering sequence is modified"; return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc index f369ab9a0e8c47..a95b8920ea7ad5 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc @@ -627,7 +627,7 @@ TEST_F(CudnnSimplifyPaddingTest, SliceMoreElementsThanPad) { // into a slice.
ASSERT_THAT(root, GmockMatch(m::Slice( &slice, m::GetTupleElement(m::CustomCall(), 0)))); - for (int64_t i = 0; i < slice->shape().dimensions_size(); ++i) { + for (int64_t i = 0; i < slice->shape().dimensions().size(); ++i) { SCOPED_TRACE(i); EXPECT_EQ(slice->slice_starts(i), 0); EXPECT_EQ(slice->slice_strides(i), 1); diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc index fdef50d6b13415..33ef6df65dbc9c 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc @@ -131,7 +131,7 @@ static Shape SplitShapeAtDim(Shape shape, int64_t dim, int64_t vect_size) { // Transposes dimension `src` to right before `dst`. static XlaOp MoveDim(XlaOp instr, int64_t src, int64_t dst) { XlaBuilder& b = *instr.builder(); - int64_t rank = b.GetShape(instr)->dimensions_size(); + int64_t rank = b.GetShape(instr)->dimensions().size(); DimensionVector idxs(rank); absl::c_iota(idxs, 0); @@ -496,7 +496,7 @@ static absl::StatusOr TryVectorizeConv( return false; } - if (input_shape.dimensions_size() > + if (input_shape.dimensions().size() > 2 + dnums->input_spatial_dimensions_size()) { // Conv already has an extra dimension, which we assume is the vectorized // features dim. diff --git a/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc index 853562a43be3a3..8f0374174582c5 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc @@ -52,8 +52,8 @@ absl::StatusOr DotNormalizer::ExpandInstruction( dot->AddInstruction(HloInstruction::CreateBitcast(new_rhs_shape, rhs)); TF_RETURN_IF_ERROR(dot->ReplaceOperandWithDifferentShape(1, normalized_rhs)); DotDimensionNumbers* dnums = dot->mutable_dot_dimension_numbers(); - dnums->add_lhs_contracting_dimensions(new_lhs_shape.dimensions_size() - 1); - dnums->add_rhs_contracting_dimensions(new_rhs_shape.dimensions_size() - 1); + dnums->add_lhs_contracting_dimensions(new_lhs_shape.dimensions().size() - 1); + dnums->add_rhs_contracting_dimensions(new_rhs_shape.dimensions().size() - 1); return nullptr; } diff --git a/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc index 6e4654ec516c59..108e9acd77461d 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc @@ -71,12 +71,12 @@ class SparseDotRewriterImpl : public DfsHloRewriteVisitor { // Result dimensions: , , int batch_dims = dnums.lhs_batch_dimensions().size(); - int new_lhs_noncontracting = rhs->shape().dimensions_size() - batch_dims - + int new_lhs_noncontracting = rhs->shape().dimensions().size() - batch_dims - dnums.lhs_contracting_dimensions().size(); - int new_rhs_noncontracting = lhs->shape().dimensions_size() - batch_dims - + int new_rhs_noncontracting = lhs->shape().dimensions().size() - batch_dims - dnums.rhs_contracting_dimensions().size(); - int rank = dot->shape().dimensions_size(); + int rank = dot->shape().dimensions().size(); DimensionVector dimensions(rank); for (int i = 0; i < batch_dims; ++i) { dimensions[i] = i; diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc 
b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc index 3395ed8a7db1c8..dceb042dfd7443 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc @@ -53,8 +53,8 @@ class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config(); DotDimensionNumbers *dim_nums = config.mutable_dot_dimension_numbers(); int bcast_operand_index = instr->operand_index(bcast); - int num_bcast_dims = (bcast->shape().dimensions_size() - - bcast->operand(0)->shape().dimensions_size()); + int num_bcast_dims = (bcast->shape().dimensions().size() - + bcast->operand(0)->shape().dimensions().size()); int num_batch_dims = dim_nums->lhs_batch_dimensions_size(); const tsl::protobuf::RepeatedField &batch_dimensions = diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc index c625ed829760df..45c8fe8c954173 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc @@ -58,14 +58,14 @@ HloDotInstruction* MakeDotWithSwappedOperands(HloInstruction* dot) { const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers(); const size_t num_batch_dims = dot_dims.lhs_batch_dimensions_size(); const size_t num_lhs_noncontracting_dims = - dot->operand(0)->shape().dimensions_size() - num_batch_dims - + dot->operand(0)->shape().dimensions().size() - num_batch_dims - dot_dims.lhs_contracting_dimensions_size(); const size_t num_rhs_noncontracting_dims = - dot->operand(1)->shape().dimensions_size() - num_batch_dims - + dot->operand(1)->shape().dimensions().size() - num_batch_dims - dot_dims.rhs_contracting_dimensions_size(); std::vector out_shape_permutation; - out_shape_permutation.reserve(dot->shape().dimensions_size()); + out_shape_permutation.reserve(dot->shape().dimensions().size()); auto fill_permutation = [&](int64_t count, int64_t start) { while (count--) out_shape_permutation.push_back(start++); }; @@ -139,7 +139,7 @@ absl::StatusOr GetNonContractingDimsNumElements( operand_index == 0 ? 
dot_dims.lhs_contracting_dimensions() : dot_dims.rhs_contracting_dimensions(); const DimensionVector noncontracting_dim_indices = GetNonContractingDims( - shape.dimensions_size(), batch_dim_indices, contracting_dim_indices); + shape.dimensions().size(), batch_dim_indices, contracting_dim_indices); return absl::c_accumulate( noncontracting_dim_indices, int64_t{1}, [&](int64_t acc, int64_t dim) { return acc * shape.dimensions(dim); }); diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index 59f6c02cb3a66e..342e1e667654a9 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -128,7 +128,7 @@ bool IsF8Type(const HloInstruction *instr) { Shape PadShapeToMultipleOf16(const Shape old_shape, const absl::Span batch_dims) { Shape padded_shape = old_shape; - for (int i = 0; i < old_shape.dimensions_size(); ++i) { + for (int i = 0; i < old_shape.dimensions().size(); ++i) { if (!absl::c_linear_search(batch_dims, i)) { int64_t padded_dimension = RoundUpTo(old_shape.dimensions(i), 16); @@ -147,7 +147,7 @@ HloInstruction *PadOperandToTargetShape(const Shape &target, } PaddingConfig padding_config; - for (int i = 0; i < x->shape().dimensions_size(); ++i) { + for (int i = 0; i < x->shape().dimensions().size(); ++i) { auto dimension = padding_config.add_dimensions(); dimension->set_edge_padding_low(0); dimension->set_edge_padding_high(target.dimensions(i) - @@ -357,14 +357,14 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, auto input_shape = instr->shape(); // Identify the dimensional order which describes a transpose of the // contracting and non-contracting dimensions of the GEMM. - std::vector permutation(input_shape.dimensions_size(), -1); + std::vector permutation(input_shape.dimensions().size(), -1); // Discard the batch dimensions. for (int64_t batch_dim : batch_dims) { permutation[batch_dim] = batch_dim; } // Identify the non-contracting dimension. int non_contracting_dim; - for (int i = 0; i < input_shape.dimensions_size(); ++i) { + for (int i = 0; i < input_shape.dimensions().size(); ++i) { if (permutation[i] == -1 && contracting_dim != i) { non_contracting_dim = i; } @@ -629,9 +629,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { int64_t lhs_batch_dims_size = instr->dot_dimension_numbers().lhs_batch_dimensions_size(); bool is_lhs_vector = - lhs->shape().dimensions_size() == lhs_batch_dims_size + 1; + lhs->shape().dimensions().size() == lhs_batch_dims_size + 1; bool is_rhs_vector = - rhs->shape().dimensions_size() == lhs_batch_dims_size + 1; + rhs->shape().dimensions().size() == lhs_batch_dims_size + 1; int64_t lhs_stride = is_lhs_vector ? lhs->shape().dimensions(lhs_batch_dims_size) : lhs->shape().dimensions(lhs_batch_dims_size) * @@ -1271,7 +1271,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const HloInstruction *input = param.commutative_ops.empty() ? param.fp8_input : param.commutative_ops.back().first; - if (input->shape().dimensions_size() != num_batch_dims + 2) { + if (input->shape().dimensions().size() != num_batch_dims + 2) { VLOG(1) << "Failed to rewrite " << instr->ToShortString() << "into FP8 Custom Call. Inputs must have exactly one " "contracting and one non-contracting dimension."; @@ -1386,8 +1386,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Slice the result of the GEMM if the operands were padded. 
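    // Sketch of the round trip (via PadShapeToMultipleOf16 above): [M, K] and // [K, N] operands are padded to [M', K'] and [K', N'] with each non-batch // dimension rounded up to a multiple of 16, so the [M', N'] output must be // sliced back to the original [M, N] here.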
HloInstruction *slice = nullptr; if (new_output_shape.dimensions() != instr->shape().dimensions()) { - std::vector start_indices(instr->shape().dimensions_size(), 0); - std::vector strides(instr->shape().dimensions_size(), 1); + std::vector start_indices(instr->shape().dimensions().size(), 0); + std::vector strides(instr->shape().dimensions().size(), 1); slice = instr->AddInstruction(HloInstruction::CreateSlice( instr->shape(), new_custom_call, start_indices, instr->shape().dimensions(), strides)); @@ -1625,7 +1625,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // To ensure correctness, only slices that chop off the ends of dimensions // are supported. if (slice) { - int slice_op_dim = slice->operand(0)->shape().dimensions_size(); + int slice_op_dim = slice->operand(0)->shape().dimensions().size(); if (slice->slice_starts() != std::vector(slice_op_dim, 0) || slice->slice_strides() != std::vector(slice_op_dim, 1)) { return absl::OkStatus(); @@ -1763,13 +1763,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config(); // # output column dims == # non-contracting rhs operand dims. const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers(); - size_t num_col_dims = gemm->operand(1)->shape().dimensions_size() - + size_t num_col_dims = gemm->operand(1)->shape().dimensions().size() - dot_dims.rhs_batch_dimensions_size() - dot_dims.rhs_contracting_dimensions_size(); if ((gemm->user_count() != 1) || (config.epilogue() != GemmBackendConfig::DEFAULT) || - (bias->shape().dimensions_size() != num_col_dims)) { + (bias->shape().dimensions().size() != num_col_dims)) { return false; } // We require the bias vector to have been broadcast in the most major diff --git a/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc index f63091e21a2100..8684f34ec87302 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc @@ -67,11 +67,11 @@ class GemvRewriterVisitor : public DfsHloRewriteVisitor { // This pass relies on dot decomposer which ensures that all non-batch // dimensions are merged into one. 
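    // Concretely (after dot decomposition), a matrix-vector candidate looks // like lhs [B, M, K] dot rhs [B, K]: exactly one operand carries the single // merged non-contracting dimension, which the two checks below detect.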
bool lhs_has_non_contracting_dim = - lhs->shape().dimensions_size() == + lhs->shape().dimensions().size() == dim_numbers.lhs_batch_dimensions_size() + dim_numbers.lhs_contracting_dimensions_size() + 1; bool rhs_has_non_contracting_dim = - rhs->shape().dimensions_size() == + rhs->shape().dimensions().size() == dim_numbers.rhs_batch_dimensions_size() + dim_numbers.rhs_contracting_dimensions_size() + 1; diff --git a/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc index 693ebc5158cd7b..aa9ac168002b6e 100644 --- a/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc @@ -64,7 +64,7 @@ absl::StatusOr CreateCholesky( HloComputation* computation = operand->parent(); Shape a_shape = operand->shape(); - int ndim = a_shape.dimensions_size(); + int ndim = a_shape.dimensions().size(); CHECK_GE(ndim, 2); int64_t n = a_shape.dimensions(ndim - 1); diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc index de0ded3417619f..791bd2241fccd2 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc @@ -103,7 +103,7 @@ std::vector FindAndSortFusionCandidates( // shapes will be placed adjacent each other. // Sort `fusion_instrs` according to instruction counts, because // we'd like to fuse together computations of similar sizes. - return std::tuple{shape.dimensions_size(), shape.dimensions(), + return std::tuple{shape.dimensions().size(), shape.dimensions(), GetInstrCountOfFusible(*op), op->unique_id()}; }; return tuple_for_op(shape_a, a) < tuple_for_op(shape_b, b); diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc index a6cb22add0cd10..6f3d1d91144224 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -555,7 +555,7 @@ absl::Status HorizontalLoopFusionImpl::CreateFusedComputation( const HloInstruction* old_output = GetOutputsOfFusible(*fused_fusion_instrs[j])[i]; HloInstruction* new_output = clone_map[old_output]; - if (new_output->shape().dimensions_size() == 1) { + if (new_output->shape().dimensions().size() == 1) { instr_outputs[j] = new_output; } else { if (!LayoutUtil::IsMonotonicWithDim0Major( @@ -671,7 +671,7 @@ absl::Status HorizontalLoopFusionImpl::Fuse( HloInstruction * gep, MakeGetTupleElementHlo(hori_fusion_instr, total_output_id++)); // This pass runs late, so useless bitcast won't be cleaned up. 
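  // (Illustrative consequence: a rank-1 output below reuses the
  // get-tuple-element directly, while a higher-rank output gets an
  // explicit bitcast back to its original shape.)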
- if (output->shape().dimensions_size() == 1) { + if (output->shape().dimensions().size() == 1) { bitcasts_or_gte.push_back(gep); } else { bitcasts_or_gte.push_back(computation_->AddInstruction( diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index fd1d6182e2ef2c..c09efa402862ff 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -107,7 +107,7 @@ HeuristicLayoutAssignment(const HloInstruction* instr, int num_spatial_dimensions = dnums.input_spatial_dimensions_size(); if (primitive_util::IsIntegralType(input_ty)) { if (input_ty == S8 && num_spatial_dimensions == 2 && - input_shape.dimensions_size() == 5) { + input_shape.dimensions().size() == 5) { VLOG(2) << "Using NCHW_VECT_C for int8_t conv " << instr->ToString(); return kAllNCHW_VECT_C; } @@ -174,7 +174,7 @@ HeuristicLayoutAssignment(const HloInstruction* instr, cuda_compute_capability && cuda_compute_capability->IsAtLeast(se::CudaComputeCapability::kVolta); if (!isFloat16 || !is_volta || - instr->shape().tuple_shapes(0).dimensions_size() != 4) { + instr->shape().tuple_shapes(0).dimensions().size() != 4) { return kAllNCHW; } } else if (std::holds_alternative(gpu_version)) { @@ -184,7 +184,8 @@ HeuristicLayoutAssignment(const HloInstruction* instr, auto rocm_compute_capability = std::get(gpu_version); if (!isFloat16 || (!rocm_compute_capability.has_nhwc_layout_support()) || - instr->shape().tuple_shapes(0).dimensions_size() != 4 || !is_enabled) { + instr->shape().tuple_shapes(0).dimensions().size() != 4 || + !is_enabled) { return kAllNCHW; } } @@ -304,11 +305,11 @@ bool DotCanSupportShapeWithLayout(const HloInstruction* dot, // If we are able to construct a `MatrixLayout` then the dot can support // this layout. return MatrixLayout::For(shape, dot_dims.lhs_batch_dimensions().size(), - dot->operand(0)->shape().dimensions_size() - + dot->operand(0)->shape().dimensions().size() - dot_dims.lhs_contracting_dimensions().size() - dot_dims.lhs_batch_dimensions().size(), dot_dims.rhs_batch_dimensions().size(), - dot->operand(1)->shape().dimensions_size() - + dot->operand(1)->shape().dimensions().size() - dot_dims.rhs_contracting_dimensions().size() - dot_dims.rhs_batch_dimensions().size()) .ok(); @@ -461,11 +462,11 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction)); } else if ((HloPredicateIsOp(instruction) || IsCubDeviceRadixSort(*instruction)) && - instruction->operand(0)->shape().dimensions_size() > 1) { + instruction->operand(0)->shape().dimensions().size() > 1) { // Make sure that all the operands and the output(s) have the same layout. Shape keys_shape = instruction->operand(0)->shape(); Layout keys_layout = - LayoutUtil::GetDefaultLayoutForRank(keys_shape.dimensions_size()); + LayoutUtil::GetDefaultLayoutForRank(keys_shape.dimensions().size()); for (int64_t i = 0; i < instruction->operand_count(); ++i) { Shape shape = instruction->operand(i)->shape(); *shape.mutable_layout() = keys_layout; @@ -485,7 +486,7 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( } else if (IsCustomCallToTopK(*instruction)) { // The output of the TopK custom call needs to have default layout. 
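  // (For reference: the default layout for rank r is minor_to_major
  // {r-1, ..., 1, 0}, so a rank-2 TopK operand is assigned layout {1, 0}.)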
Layout default_layout = LayoutUtil::GetDefaultLayoutForRank( - instruction->operand(0)->shape().dimensions_size()); + instruction->operand(0)->shape().dimensions().size()); TF_ASSIGN_OR_RETURN( auto values_buffer, points_to_analysis_->GetBufferDefinedAt(instruction, {0})); diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc index b0eadfa43ff9df..d1a52e87b56d71 100644 --- a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc @@ -103,7 +103,7 @@ FusionDecision ParameterSlicesAreNonOverlapping(const HloInstruction& instr1, auto& limits1 = slice1->slice_limits(); auto& limits2 = slice2->slice_limits(); - for (int64_t dim = 0; dim < parent->shape().dimensions_size(); ++dim) { + for (int64_t dim = 0; dim < parent->shape().dimensions().size(); ++dim) { bool overlap = starts1[dim] < limits2[dim] && starts2[dim] < limits1[dim]; if (!overlap) { return FusionDecision::Forbid("slices are non-overlapping"); diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index ba80a524e6998f..e7d579efe1acfd 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -176,7 +176,7 @@ absl::Status AnnotateDotOperandNestedFusionImpl( // We have a single contracting dimension, and a single non-contracting // dimension. All the other output tile sizes are set to 1. std::vector output_tile_sizes( - dot.operand(0)->shape().dimensions_size(), 1); + dot.operand(0)->shape().dimensions().size(), 1); output_tile_sizes[contracting_dimensions[0]] = contracting_dim_size; output_tile_sizes[non_contracting_dimensions[0]] = non_contracting_dim_size; @@ -258,7 +258,8 @@ absl::StatusOr> FindOutputTileSizesForEpilogue( VLOG(3) << "FindOutputTileSizesForEpilogue: dot shape: " << dot->shape().ToString(); - auto expected_dot_tile_sizes = get_tile_sizes(dot->shape().dimensions_size()); + auto expected_dot_tile_sizes = + get_tile_sizes(dot->shape().dimensions().size()); if (VLOG_IS_ON(2)) { std::ostringstream oss; for (const auto& size : expected_dot_tile_sizes) { @@ -271,7 +272,8 @@ absl::StatusOr> FindOutputTileSizesForEpilogue( // Try all permutations of the dot tile sizes to see if any of them satisfy // the constraints of the analysis and map to the given config of the dot. 
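  // (Illustrative count: a rank-3 root shape gives up to 3! = 6 candidate
  // permutations of the expected dot tile sizes to check here.)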
- int64_t out_rank = computation->root_instruction()->shape().dimensions_size(); + int64_t out_rank = + computation->root_instruction()->shape().dimensions().size(); VLOG(3) << "FindOutputTileSizesForEpilogue: computation root shape: " << computation->root_instruction()->shape().ToString(); auto output_tile_sizes = get_tile_sizes(out_rank); diff --git a/third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer.cc b/third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer.cc index e32655499fa597..8f0d8eaa50c3bd 100644 --- a/third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer.cc +++ b/third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer.cc @@ -132,12 +132,12 @@ HloInstruction* GetRowSlice(HloInstruction* hlo, int64_t row_index) { Shape row_shape = hlo->shape(); row_shape.set_dimensions(0, 1); - std::vector slice_start_indices(row_shape.dimensions_size(), 0); + std::vector slice_start_indices(row_shape.dimensions().size(), 0); slice_start_indices[0] = row_index; std::vector slice_limit_indices{row_shape.dimensions().begin(), row_shape.dimensions().end()}; slice_limit_indices[0] = row_index + 1; - std::vector slice_strides(row_shape.dimensions_size(), 1); + std::vector slice_strides(row_shape.dimensions().size(), 1); HloInstruction* row_slice = computation->AddInstruction(HloInstruction::CreateSlice( @@ -201,7 +201,7 @@ HloInstruction* PadOutermostDimension(HloComputation* computation, Shape padded_shape = hlo->shape(); PaddingConfig padding_config = - MakeNoPaddingConfig(padded_shape.dimensions_size()); + MakeNoPaddingConfig(padded_shape.dimensions().size()); padding_config.mutable_dimensions(0)->set_edge_padding_high(padding_size); padded_shape.set_dimensions(0, padded_shape.dimensions(0) + padding_size); @@ -233,7 +233,7 @@ std::vector RaggedToDense(HloComputation* computation, for (int64_t j = 0; j < num_updates_per_replica; ++j) { auto offset_multi_index = GetOffsetMultiIndex( computation, offsets, i * num_updates_per_replica + j, - ragged_input->shape().dimensions_size()); + ragged_input->shape().dimensions().size()); HloInstruction* padded_input = PadOutermostDimension(computation, ragged_input, max_update_size); @@ -266,7 +266,7 @@ HloInstruction* DenseToRagged(HloComputation* computation, int64_t num_updates_per_replica, int64_t max_update_size) { int64_t num_rows = offsets->shape().dimensions(0); - int64_t rank = ragged_output->shape().dimensions_size(); + int64_t rank = ragged_output->shape().dimensions().size(); Shape original_shape = ragged_output->shape(); @@ -280,9 +280,9 @@ HloInstruction* DenseToRagged(HloComputation* computation, for (int64_t i = 0; i < num_rows / num_updates_per_replica; ++i) { for (int64_t j = 0; j < num_updates_per_replica; ++j) { int idx = i * num_updates_per_replica + j; - auto offset_multi_index = - GetOffsetMultiIndex(computation, offsets, idx, - padded_ragged_output->shape().dimensions_size()); + auto offset_multi_index = GetOffsetMultiIndex( + computation, offsets, idx, + padded_ragged_output->shape().dimensions().size()); // `dense_inputs` is a tuple of updates for each replica. The number of // elements in the tuple is equal to the number of replicas. 
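The hunks above and below all make the same mechanical substitution. A minimal
sketch of the before/after, assuming only the xla::Shape API in which
dimensions() returns a span of dimension extents (the helper names here are
illustrative, not part of the patch):

  #include <cstdint>

  #include "xla/shape.h"

  // Rank query in the form used before this change (proto-style accessor).
  int64_t RankBefore(const xla::Shape& shape) {
    return shape.dimensions_size();
  }

  // Rank query in the form used after this change: read the size of the
  // dimensions() span instead.
  int64_t RankAfter(const xla::Shape& shape) {
    return shape.dimensions().size();
  }

For array shapes both forms return the rank, so each substitution in these
diffs is behavior-preserving.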
diff --git a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc index 012451df67f5db..9d361ba3fb6384 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc @@ -82,9 +82,9 @@ absl::StatusOr ReduceScatterCreator::Run( scatter_dim_size * ar_spec->group_size); rs_input = computation->AddInstruction(HloInstruction::CreateSlice( scatter_shape, rs_input, - std::vector(scatter_shape.dimensions_size(), 0), + std::vector(scatter_shape.dimensions().size(), 0), scatter_shape.dimensions(), - std::vector(scatter_shape.dimensions_size(), 1))); + std::vector(scatter_shape.dimensions().size(), 1))); } scatter_shape.set_dimensions(split_dim, scatter_dim_size); diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc index 4dc435af41f244..b6e727c3ccc09d 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc @@ -65,7 +65,7 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor { auto reduced_dimensions = instr->dimensions(); int64_t shift = 0; - for (int dim = 0; dim < input_shape.dimensions_size(); dim++) { + for (int dim = 0; dim < input_shape.dimensions().size(); dim++) { if (input_shape.dimensions(dim) == 1) { shift++; } else { diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc index 615edc7dd6f3a6..ea4322d7fc9575 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc @@ -66,12 +66,12 @@ class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor { // Since we have enforced the standard layout, iteration over logical // dimensions is equivalent to iteration over the major-to-minor order. - for (int logical_dim = 0; logical_dim < shape.dimensions_size(); + for (int logical_dim = 0; logical_dim < shape.dimensions().size(); logical_dim++) { VLOG(5) << "Processing dimension " << logical_dim << " of size " << shape.dimensions(logical_dim); if (is_reduced(logical_dim) && - logical_dim < shape.dimensions_size() - 1 && + logical_dim < shape.dimensions().size() - 1 && is_reduced(logical_dim + 1)) { VLOG(5) << "This and consecutive dimension are reduced, merging"; changed = true; diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc index dda7b81ba6b469..38b8878eea128e 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc @@ -80,7 +80,7 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { // The layout order of the reduction output can be different to the // ordering of kept dimensions in the input operand, thus we need to // calculate the new layout. 
- DimensionVector new_reduce_shape_layout(reduce_shape.dimensions_size()); + DimensionVector new_reduce_shape_layout(reduce_shape.dimensions().size()); std::vector reduce_shape_logical_to_physical = LayoutUtil::MakeLogicalToPhysical(reduce_shape.layout()); @@ -92,11 +92,11 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { }); }; - for (int i = 0; i < operand_shape.dimensions_size(); i++) { + for (int i = 0; i < operand_shape.dimensions().size(); i++) { // Process the dimensions in the major-to-minor order in order to // enforce the default layout. int64_t major_to_minor_dim_idx = - operand_shape.dimensions_size() - i - 1; + operand_shape.dimensions().size() - i - 1; int64_t logical_dim = operand_layout.minor_to_major(major_to_minor_dim_idx); int64_t dim_size = operand_shape.dimensions(logical_dim); @@ -113,7 +113,7 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { reduce_shape_logical_to_physical[logical_reduce_dim]; VLOG(5) << "logical_reduce_dim = " << logical_reduce_dim << ", " << "physical_reduce_dim = " << physical_reduce_dim; - new_reduce_shape_layout[reduce_shape.dimensions_size() - + new_reduce_shape_layout[reduce_shape.dimensions().size() - physical_reduce_dim - 1] = new_reduce_shape_data.size() - 1; } diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc index e6ffa39092ff9a..a74110fd507815 100644 --- a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc @@ -84,7 +84,7 @@ class ScatterSliceMatcher { // the original scatter dimensions. Return `false` if the update is not // possible. bool UpdateDimensions(const HloSliceInstruction* slice) { - int64_t rank = slice->shape().dimensions_size(); + int64_t rank = slice->shape().dimensions().size(); for (int64_t i = 0; i < rank; ++i) { if (slice->slice_starts(i) != 0 || slice->slice_strides(i) != 1) { return false; // The slice is not a truncation. @@ -145,10 +145,10 @@ class ScatterSliceMatcher { // Create a replacement operand for the scatter instruction. 
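  // (Worked example: for a target shape f32[4, 5], the slice built below
  // uses start indices {0, 0}, limit indices {4, 5}, and strides {1, 1},
  // i.e. a plain truncation of the operand to the target extents.)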
HloInstruction* CreateSliceFrom(HloInstruction* operand, const Shape& shape) { - std::vector start_indices(shape.dimensions_size(), 0); - std::vector limit_indices(shape.dimensions_size()); - std::vector strides(shape.dimensions_size(), 1); - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + std::vector start_indices(shape.dimensions().size(), 0); + std::vector limit_indices(shape.dimensions().size()); + std::vector strides(shape.dimensions().size(), 1); + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { limit_indices[i] = shape.dimensions(i); } return operand->AddInstruction(HloInstruction::CreateSlice( diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc index a75b166bf2bd49..3a3948a221fcba 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc @@ -462,7 +462,7 @@ FusionDecision ShouldFuseReduction(const HloInstruction& reduce, if (reduce.dimensions().size() != 1 || reduce.dimensions(0) != - reduce.operand(0)->shape().dimensions_size() - 1) { + reduce.operand(0)->shape().dimensions().size() - 1) { return FusionDecision::Forbid( "The reductions in the diamond must reduce 1 dimension and that " "dimension must be the last dimension of the operand."); @@ -546,7 +546,7 @@ DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( producer = reduce->mutable_operand(0); if (absl::c_linear_search(broadcast->dimensions(), - broadcast->shape().dimensions_size() - 1)) { + broadcast->shape().dimensions().size() - 1)) { return FusionDecision::Forbid( "Broadcast is not along the reduction dimension."); } diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc index 85f68fd3da36d8..9f8967c5c56a36 100644 --- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc @@ -427,7 +427,7 @@ bool IsCubCompatibleSort(const HloSortInstruction* sort_op) { } const Shape& operand_shape = sort_op->operand(0)->shape(); - if (sort_op->sort_dimension() != operand_shape.dimensions_size() - 1) { + if (sort_op->sort_dimension() != operand_shape.dimensions().size() - 1) { VLOG(2) << "Sort dimension should be the minor one"; return false; } diff --git a/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc index 1cc6206ee8908a..6eb258e89db6b3 100644 --- a/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc @@ -55,11 +55,11 @@ absl::StatusOr SmallBufferOptimization( primitive_util::LowercasePrimitiveTypeName(data_shape.element_type())); } // We only support topk of the shape [x] or [batch, x]. - if (data_shape.dimensions_size() > 2) { + if (data_shape.dimensions().size() > 2) { return InvalidArgument("Invalid input dimensions: %s", data_shape.ToString()); } - bool has_batch = data_shape.dimensions_size() == 2; + bool has_batch = data_shape.dimensions().size() == 2; constexpr size_t max_k = 16; constexpr size_t min_n = 1024; size_t n = data_shape.dimensions(has_batch ? 
1 : 0); diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc index 385e06077b9c2c..82867c7a238e7f 100644 --- a/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc @@ -60,7 +60,7 @@ class TopkSplitterVisitor : public DfsHloRewriteVisitor { } HloComputation* comp = inst->parent(); Shape data_shape = topk->operand(0)->shape(); - bool has_batch = data_shape.dimensions_size() == 2; + bool has_batch = data_shape.dimensions().size() == 2; // TODO(doak): Support multiple batches. if (has_batch && data_shape.dimensions(0) != 1) { return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc index c2c9e8d25701d7..2889be609b0da5 100644 --- a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc +++ b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc @@ -86,8 +86,8 @@ absl::InlinedVector GetNormalizedTransposeShapeHelper( return normalized_dims; } // Derive the permutation from the segments. - std::vector segment_to_normalized_dim(output_shape.dimensions_size(), - -1); + std::vector segment_to_normalized_dim( + output_shape.dimensions().size(), -1); for (size_t segment : segments) { segment_to_normalized_dim[output_to_input[segment]] = 0; } @@ -137,10 +137,11 @@ absl::InlinedVector GetNormalizedLogicalTransposeShape( absl::InlinedVector &permutation) { permutation.clear(); // Drop degenerate dimensions. - absl::InlinedVector delta(output_shape.dimensions_size() + 1, 0); + absl::InlinedVector delta(output_shape.dimensions().size() + 1, + 0); auto input_dimensions = ComposePermutations(output_shape.dimensions(), InversePermutation(dimensions)); - for (int i = 0; i < output_shape.dimensions_size(); ++i) { + for (int i = 0; i < output_shape.dimensions().size(); ++i) { delta[i + 1] = delta[i]; if (input_dimensions[i] == static_cast(1)) { ++delta[i + 1]; diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc index 3d22a6c7e2590e..a58f90d52f945b 100644 --- a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc @@ -1009,15 +1009,15 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { // Each split is sliced out of the input buffer, we need to determine the // slice sizes and increments. 
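  // (Illustration: for an all-to-all result of shape f32[8, 16], the
  // bookkeeping below starts from slice sizes {0, 0}, increments {1, 1},
  // and a maximum range of {8, 16} taken from the shape's dimensions.)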
- std::vector lhs_slice_sizes(a2a->shape().dimensions_size(), 0); - std::vector lhs_slice_increments(a2a->shape().dimensions_size(), - 1); + std::vector lhs_slice_sizes(a2a->shape().dimensions().size(), 0); + std::vector lhs_slice_increments( + a2a->shape().dimensions().size(), 1); std::vector lhs_slice_max_range( a2a->shape().dimensions().begin(), a2a->shape().dimensions().end()); - std::vector rhs_slice_sizes(rhs->shape().dimensions_size(), 0); - std::vector rhs_slice_increments(rhs->shape().dimensions_size(), - 1); + std::vector rhs_slice_sizes(rhs->shape().dimensions().size(), 0); + std::vector rhs_slice_increments( + rhs->shape().dimensions().size(), 1); std::vector rhs_slice_max_range( rhs->shape().dimensions().begin(), rhs->shape().dimensions().end()); @@ -1248,17 +1248,17 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { // Each split is sliced out of the input buffer, we need to determine the // slice sizes and increments. std::vector lhs_slice_sizes( - matched_result.lhs->shape().dimensions_size(), 0); + matched_result.lhs->shape().dimensions().size(), 0); std::vector lhs_slice_increments( - matched_result.lhs->shape().dimensions_size(), 1); + matched_result.lhs->shape().dimensions().size(), 1); std::vector lhs_slice_max_range( matched_result.lhs->shape().dimensions().begin(), matched_result.lhs->shape().dimensions().end()); std::vector rhs_slice_sizes( - matched_result.rhs->shape().dimensions_size(), 0); + matched_result.rhs->shape().dimensions().size(), 0); std::vector rhs_slice_increments( - matched_result.rhs->shape().dimensions_size(), 1); + matched_result.rhs->shape().dimensions().size(), 1); std::vector rhs_slice_max_range( matched_result.rhs->shape().dimensions().begin(), matched_result.rhs->shape().dimensions().end()); From b25a028745985424353b925340b10d9b7bd912ac Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 04:54:16 -0700 Subject: [PATCH 0683/1324] Automated Code Change PiperOrigin-RevId: 747364157 --- .../xla/xla/service/cpu/cpu_executable.cc | 2 +- .../xla/service/cpu/cpu_instruction_fusion.cc | 11 ++--- .../cpu/cpu_instruction_fusion_test.cc | 3 +- .../xla/service/cpu/cpu_layout_assignment.cc | 4 +- .../service/cpu/cpu_layout_assignment_test.cc | 9 ++-- third_party/xla/xla/service/cpu/cpu_xfeed.cc | 2 +- .../xla/xla/service/cpu/dot_op_emitter.cc | 42 +++++++++---------- .../xla/xla/service/cpu/ir_emission_utils.cc | 9 ++-- third_party/xla/xla/service/cpu/ir_emitter.cc | 18 ++++---- .../xla/service/cpu/parallel_loop_emitter.cc | 2 +- .../xla/xla/service/cpu/thunk_emitter.cc | 2 +- .../xla/xla/service/cpu/xfeed_manager.cc | 2 +- 12 files changed, 56 insertions(+), 50 deletions(-) diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index 8ee9339e20b8c9..a0c5e6590bd493 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -513,7 +513,7 @@ absl::StatusOr CpuExecutable::ExecuteAsyncOnStream( return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } // Each dynamic dimension size is represented as a S32. 
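  // (Worked example: a dynamic rank-3 shape carries 3 * sizeof(int32_t)
  // = 12 bytes of dimension-size metadata in addition to its payload.)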
- int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size(); + int64_t metadata_size = sizeof(int32_t) * shape.dimensions().size(); return ShapeUtil::ByteSizeOf(shape, sizeof(void*)) + metadata_size; } diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index f224ab2993c420..bd26c1c87621cd 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -54,7 +54,8 @@ bool CanBeLoopFused(const HloInstruction& hlo) { bool IsNonComplexNonBatchedMatrixVectorDot(const HloInstruction* hlo) { const Shape& hlo_shape = hlo->shape(); return !ShapeUtil::ElementIsComplex(hlo_shape) && - hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1 && + hlo->opcode() == HloOpcode::kDot && + hlo_shape.dimensions().size() <= 1 && hlo->dot_dimension_numbers().lhs_batch_dimensions_size() == 0; } @@ -161,7 +162,7 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // better job with pure data movement loops. auto is_minor_dim_concatenate = [](const HloInstruction* hlo) { // For vectors it's always beneficial to fuse concatenations. - if (hlo->shape().dimensions_size() <= 1) return false; + if (hlo->shape().dimensions().size() <= 1) return false; // For small concatenated dimensions we don't loose any performance by // fusing the concatenation as we don't have opportunities for vectorization @@ -231,19 +232,19 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // fusion can easily be overshadowed by the overhead of a naive GEMM // algorithm in the IR. const Shape& output_shape = consumer->shape(); - if (output_shape.dimensions_size() <= 1) { + if (output_shape.dimensions().size() <= 1) { // We fuse in cases where we have a matrix*vector or vector*matrix dot and // fusion can get rid of the larger tensor. We assume that a naive // traversal of a small enough (to fit in L1) column or row tensor is // "good enough" from the perspective of cache management; and calling out // to an optimized GEMM kernel is not a huge win. 
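  // (Illustrative case: for a vector*matrix dot whose f32[k] vector operand
  // is under kFusionThresholdBytes, the branch below allows fusing the
  // matrix producer, since the larger tensor never needs to materialize.)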
- if (consumer->operand(0)->shape().dimensions_size() == 1 && + if (consumer->operand(0)->shape().dimensions().size() == 1 && operand_index == 1 && ShapeUtil::ByteSizeOfElements(consumer->operand(0)->shape()) < kFusionThresholdBytes) { VLOG(2) << "Fusing small matrix-vector product."; return FusionDecision::Allow(); - } else if (consumer->operand(1)->shape().dimensions_size() == 1 && + } else if (consumer->operand(1)->shape().dimensions().size() == 1 && operand_index == 0 && ShapeUtil::ByteSizeOfElements(consumer->operand(1)->shape()) < kFusionThresholdBytes) { diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc index d7d273de3c59b1..e348ffe5531537 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -47,7 +47,8 @@ using InstructionFusionTest = HloTestBase; std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(lhs->shape().dimensions_size() - 1); + dot_dnums.add_lhs_contracting_dimensions(lhs->shape().dimensions().size() - + 1); dot_dnums.add_rhs_contracting_dimensions(0); PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc index c6b67698a9ce68..a2ae89ee50f546 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc @@ -93,7 +93,7 @@ static Shape RowMajorShape(Shape shape) { if (!subshape->IsArray()) { return; } - std::vector dimension_order(subshape->dimensions_size()); + std::vector dimension_order(subshape->dimensions().size()); std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); *subshape->mutable_layout() = LayoutUtil::MakeLayout(dimension_order); }); @@ -102,7 +102,7 @@ static Shape RowMajorShape(Shape shape) { static Shape ColMajorShape(const Shape& old_shape) { Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); + std::vector dimension_order(new_shape.dimensions().size()); std::iota(dimension_order.begin(), dimension_order.end(), 0); *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); return new_shape; diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc index 108c1b3be91c41..5ce5edad657ec8 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc @@ -349,7 +349,8 @@ static void AssertCorrectLayoutForDotOutputFusion( ? 
LayoutUtil::MakeLayout({0, 1}) : LayoutUtil::MakeLayout({1, 0}); if (layout_assignment_result.dot_rhs_fusion_param->shape() - .dimensions_size() == 1) { + .dimensions() + .size() == 1) { expected_dot_rhs_layout = LayoutUtil::MakeLayout({0}); } EXPECT_TRUE(LayoutUtil::Equal( @@ -359,13 +360,15 @@ static void AssertCorrectLayoutForDotOutputFusion( EXPECT_TRUE(LayoutUtil::Equal( LayoutUtil::MakeDescendingLayout( layout_assignment_result.dot_lhs_fusion_param->shape() - .dimensions_size()), + .dimensions() + .size()), layout_assignment_result.dot_lhs_fusion_param->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( LayoutUtil::MakeDescendingLayout( layout_assignment_result.addend_fusion_param->shape() - .dimensions_size()), + .dimensions() + .size()), layout_assignment_result.addend_fusion_param->shape().layout())); EXPECT_THAT(computation->instructions(), Each(Not(op::Copy()))); } diff --git a/third_party/xla/xla/service/cpu/cpu_xfeed.cc b/third_party/xla/xla/service/cpu/cpu_xfeed.cc index bc89ecde2dd198..c608fe85134359 100644 --- a/third_party/xla/xla/service/cpu/cpu_xfeed.cc +++ b/third_party/xla/xla/service/cpu/cpu_xfeed.cc @@ -309,7 +309,7 @@ absl::Status ReadDynamicShapesOnCpu( reinterpret_cast(buffer_8 + offset); // Update shape size from metadata. - for (int64_t i = 0; i < device_sub_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < device_sub_shape.dimensions().size(); ++i) { device_sub_shape.mutable_dimensions()[i] = metadata_buffer[i]; } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index 5f1af495522768..310605eb7bfc5f 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -79,7 +79,7 @@ bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) { } // Return whether the given shape is rank 2. -bool IsRank2(const Shape& shape) { return shape.dimensions_size() == 2; } +bool IsRank2(const Shape& shape) { return shape.dimensions().size() == 2; } bool IsSimpleLayout(const Layout& layout) { return layout.tiles().empty() && LayoutUtil::IsDense(layout); @@ -178,8 +178,8 @@ DotImplementationStrategy GetNonBatchDotImplementationStrategy( // Any Matrix-Vector product of floating point or integral type, or // a transpose-dot fusion of the same can be lowered to a tiled LLVM // IR implementation. - if ((dot_info.result_shape.dimensions_size() <= 1 || - (dot_info.result_shape.dimensions_size() == 2 && + if ((dot_info.result_shape.dimensions().size() <= 1 || + (dot_info.result_shape.dimensions().size() == 2 && (dot_info.result_shape.dimensions(0) == 1 || dot_info.result_shape.dimensions(1) == 1))) && (primitive_util::IsFloatingPointType(element_type) || @@ -188,12 +188,12 @@ DotImplementationStrategy GetNonBatchDotImplementationStrategy( } // MatMul smaller than 3x3 should use naive nested loop. 
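  // (Illustrative case: a f32[2, 3] x f32[3, 2] product has a dimension of
  // at most 3 on both sides, so the check below picks the naive loop over
  // the overhead of calling an optimized kernel.)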
- if ((dot_info.lhs_shape.dimensions_size() <= 1 || - (dot_info.lhs_shape.dimensions_size() == 2 && + if ((dot_info.lhs_shape.dimensions().size() <= 1 || + (dot_info.lhs_shape.dimensions().size() == 2 && (dot_info.lhs_shape.dimensions(0) <= 3 || dot_info.lhs_shape.dimensions(1) <= 3))) && - (dot_info.rhs_shape.dimensions_size() <= 1 || - (dot_info.rhs_shape.dimensions_size() == 2 && + (dot_info.rhs_shape.dimensions().size() <= 1 || + (dot_info.rhs_shape.dimensions().size() == 2 && (dot_info.rhs_shape.dimensions(0) <= 3 || dot_info.rhs_shape.dimensions(1) <= 3))) && (primitive_util::IsFloatingPointType(element_type) || @@ -962,14 +962,14 @@ absl::Status DotOpEmitter::EmitCallToBatchRuntime() { } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { - CHECK_LE(dot_info_.result_shape.dimensions_size(), 2); + CHECK_LE(dot_info_.result_shape.dimensions().size(), 2); const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; auto is_column_major = [](const Shape& shape) { - return shape.dimensions_size() > 1 && + return shape.dimensions().size() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0; }; @@ -978,29 +978,29 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0); return { - /*m=*/lhs_shape.dimensions_size() <= 1 + /*m=*/lhs_shape.dimensions().size() <= 1 ? 1LL : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)), /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)), - /*n=*/rhs_shape.dimensions_size() <= 1 + /*n=*/rhs_shape.dimensions().size() <= 1 ? 1LL : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)), /*lhs_column_major=*/is_column_major(lhs_shape), - /*lhs_canonical=*/lhs_shape.dimensions_size() <= 1 || + /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 || dim_nums.lhs_contracting_dimensions(0) == 1, /*rhs_column_major=*/is_column_major(rhs_shape), /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0}; } DotOpEmitter::MatMultDims DotOpEmitter::GetBatchMatMultDims() const { - CHECK_LE(dot_info_.result_shape.dimensions_size(), 2); + CHECK_LE(dot_info_.result_shape.dimensions().size(), 2); const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; auto is_column_major = [](const Shape& shape) { - return shape.dimensions_size() > 1 && + return shape.dimensions().size() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0; }; @@ -1009,15 +1009,15 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetBatchMatMultDims() const { CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0); return { - /*m=*/lhs_shape.dimensions_size() <= 1 + /*m=*/lhs_shape.dimensions().size() <= 1 ? 1LL : lhs_shape.dimensions(2LL - dim_nums.lhs_contracting_dimensions(0)), /*k=*/lhs_shape.dimensions(1LL + dim_nums.lhs_contracting_dimensions(0)), - /*n=*/rhs_shape.dimensions_size() <= 1 + /*n=*/rhs_shape.dimensions().size() <= 1 ? 
1LL : rhs_shape.dimensions(2LL - dim_nums.rhs_contracting_dimensions(0)), /*lhs_column_major=*/is_column_major(lhs_shape), - /*lhs_canonical=*/lhs_shape.dimensions_size() <= 1 || + /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 || dim_nums.lhs_contracting_dimensions(0) == 1, /*rhs_column_major=*/is_column_major(rhs_shape), /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0}; @@ -1027,8 +1027,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetBatchMatMultDims() const { // column major. std::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo) { - if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) { - if (hlo.operand(0)->shape().dimensions_size() != 1 || + if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions().size() <= 1) { + if (hlo.operand(0)->shape().dimensions().size() != 1 || hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) != 0) { return {}; } @@ -1118,7 +1118,7 @@ llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilderBase* b, const Shape& shape = array.GetShape(); CHECK(shape.has_layout() && LayoutUtil::IsMonotonicWithDim0Major(shape.layout())); - CHECK_GE(shape.dimensions_size(), n); + CHECK_GE(shape.dimensions().size(), n); Shape new_shape = CollapseFirstNDims(shape, n); llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, b->getContext()); return llvm_ir::IrArray(array.GetBasePointer(), new_ir_type, @@ -1145,7 +1145,7 @@ llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array, llvm::Value* batch_index, llvm::IRBuilderBase* b) { Shape inner_shape = DropFirstDim(outer_array.GetShape()); - std::vector multidim_index(inner_shape.dimensions_size() + 1, + std::vector multidim_index(inner_shape.dimensions().size() + 1, b->getInt64(0)); multidim_index[0] = batch_index; llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(), diff --git a/third_party/xla/xla/service/cpu/ir_emission_utils.cc b/third_party/xla/xla/service/cpu/ir_emission_utils.cc index e8002970917811..a765f921d8d258 100644 --- a/third_party/xla/xla/service/cpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/cpu/ir_emission_utils.cc @@ -104,14 +104,15 @@ bool PotentiallyImplementedAsEigenConvolution( } return dnums.input_batch_dimension() == 0 && - dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 && + dnums.input_feature_dimension() == + input_shape.dimensions().size() - 1 && dnums.output_batch_dimension() == 0 && dnums.output_feature_dimension() == - output_shape.dimensions_size() - 1 && + output_shape.dimensions().size() - 1 && dnums.kernel_input_feature_dimension() == - kernel_shape.dimensions_size() - 2 && + kernel_shape.dimensions().size() - 2 && dnums.kernel_output_feature_dimension() == - kernel_shape.dimensions_size() - 1; + kernel_shape.dimensions().size() - 1; } } // namespace cpu diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 0177792ab8fe90..e28edbd3177c73 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -696,7 +696,7 @@ absl::Status IrEmitter::HandleSort(HloInstruction* hlo) { higher_dimensions *= normalized_keys_shape.dimensions(i); } int64_t lower_dimensions = 1; - for (int64_t i = normalized_keys_shape.dimensions_size() - 1; + for (int64_t i = normalized_keys_shape.dimensions().size() - 1; i > physical_dimension_to_sort; --i) { lower_dimensions *= normalized_keys_shape.dimensions(i); } @@ -827,7 +827,7 @@ absl::Status 
IrEmitter::HandleConvolution(HloInstruction* convolution) { // convolutions, except that we pretend that the 1D convolution is really // a 2D convolution with the missing dimension set to 1. We also adjust // the padding, dilation parameters as needed. - bool one_dim_convolution = lhs_shape.dimensions_size() == 3; + bool one_dim_convolution = lhs_shape.dimensions().size() == 3; llvm::Value* lhs_address = GetEmittedValueFor(lhs); llvm::Value* rhs_address = GetEmittedValueFor(rhs); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution)); @@ -1009,7 +1009,7 @@ absl::Status IrEmitter::HandleFft(HloInstruction* fft) { // Flatten operand batches. absl::InlinedVector operand_shape_flat(fft_rank + 1); int64_t input_batch = 1; - int64_t input_batch_length = fft->shape().dimensions_size() - fft_rank; + int64_t input_batch_length = fft->shape().dimensions().size() - fft_rank; for (int i = 0; i < input_batch_length; i++) { input_batch *= operand->shape().dimensions(i); } @@ -1443,7 +1443,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { const Shape& result_shape = reduce.shape(); int64_t delta = 0; - for (int64_t i = 0; i < operand_shape.dimensions_size(); i++) { + for (int64_t i = 0; i < operand_shape.dimensions().size(); i++) { if (reduced_dims.contains(i)) { delta++; } else { @@ -1455,7 +1455,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { // dimensions in the source and target shapes are equivalent. int64_t result_dim_idx = 0; for (int64_t operand_dim_idx = 0; - operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) { + operand_dim_idx < operand_shape.dimensions().size(); operand_dim_idx++) { int64_t operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx); if (!reduced_dims.contains(operand_dim)) { @@ -1466,7 +1466,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { } } - CHECK_EQ(result_dim_idx, result_shape.dimensions_size()); + CHECK_EQ(result_dim_idx, result_shape.dimensions().size()); return true; } @@ -1811,7 +1811,7 @@ absl::StatusOr IrEmitter::EmitVectorizedReduce( llvm_ir::ForLoopNest loop_nest(IrName(reduce), b()); std::vector array_multi_index( - reduce->shape().dimensions_size()); + reduce->shape().dimensions().size()); for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; --i) { int64_t dimension = LayoutUtil::Minor(reduce->shape().layout(), i); @@ -1969,7 +1969,7 @@ absl::Status IrEmitter::HandleSlice(HloInstruction* slice) { } const Layout& layout = operand->shape().layout(); - const int64_t num_dims = operand->shape().dimensions_size(); + const int64_t num_dims = operand->shape().dimensions().size(); // The slice lowering finds maximal contiguous blocks of memory that can be // copied from the source to the target. 
This is done by looking at the @@ -2427,7 +2427,7 @@ absl::Status IrEmitter::HandleTopK(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); const HloInstruction* input = hlo->operand(0); const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back(); - const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2; + const bool has_batch = hlo->shape().tuple_shapes(0).dimensions().size() == 2; TF_RET_CHECK(input->shape().element_type() == F32) << hlo->ToString(); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( hlo->shape().tuple_shapes(0).layout())) diff --git a/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc b/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc index ba5bec52bd18f5..2bfffd88df937e 100644 --- a/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc +++ b/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc @@ -50,7 +50,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, CHECK(!ShapeUtil::IsScalar(shape_)); llvm_ir::ForLoopNest loop_nest(loop_name, b_); - const int64_t num_dims = shape_.dimensions_size(); + const int64_t num_dims = shape_.dimensions().size(); std::vector array_multi_index(num_dims); // Add loops from outer-most to inner-most dimensions. diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 7d88c6d3bb3a04..dbbf31c627c2c2 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -941,7 +941,7 @@ absl::StatusOr ThunkEmitter::EmitTopKThunk( // Deduce parameters from the result shape and operand shape const int64_t input_size = input->shape().dimensions().back(); - const bool has_batch = result_shape.tuple_shapes(0).dimensions_size() == 2; + const bool has_batch = result_shape.tuple_shapes(0).dimensions().size() == 2; const int64_t batch_size = has_batch ? result_shape.tuple_shapes(0).dimensions(0) : 1; const int64_t k = result_shape.tuple_shapes(0).dimensions().back(); diff --git a/third_party/xla/xla/service/cpu/xfeed_manager.cc b/third_party/xla/xla/service/cpu/xfeed_manager.cc index d7d40ff09e1b9b..cc1a953ef9bd21 100644 --- a/third_party/xla/xla/service/cpu/xfeed_manager.cc +++ b/third_party/xla/xla/service/cpu/xfeed_manager.cc @@ -69,7 +69,7 @@ int64_t GetByteSizeRequirement(const Shape& shape, int64_t pointer_size) { if (shape.IsTuple() || shape.is_static()) { return ShapeUtil::ByteSizeOf(shape, pointer_size); } - int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size(); + int64_t metadata_size = sizeof(int32_t) * shape.dimensions().size(); return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size; } From 3b5d6e443a07b03c2f28007a48d45dd6e7b5f224 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 04:56:04 -0700 Subject: [PATCH 0684/1324] Automated Code Change PiperOrigin-RevId: 747364622 --- .../xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index 0b7e51cd2a952c..50beeddeab1c6c 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -104,8 +104,8 @@ struct GemmWithDynamicSlice { // Returns OK if dot instruction is a simple 2D row-major gemm. 
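// (Illustrative match: dot(f32[64, 32], f32[32, 48]) -> f32[64, 48] passes
// the rank-2 check below; operands of any other rank are rejected with an
// internal error.)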
absl::Status MatchRowMajorGemm(HloDotInstruction* dot) { - if (dot->operand(0)->shape().dimensions_size() != 2 || - dot->operand(1)->shape().dimensions_size() != 2) { + if (dot->operand(0)->shape().dimensions().size() != 2 || + dot->operand(1)->shape().dimensions().size() != 2) { return absl::InternalError("operands must have rank 2"); } From 751fb91f608a2f8bdd81859ae82670ee8bc95d30 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 04:56:20 -0700 Subject: [PATCH 0685/1324] Automated Code Change PiperOrigin-RevId: 747364677 --- .../gpu/codegen/triton/fusion_emitter_legacy_matmul.cc | 6 +++--- .../xla/xla/backends/gpu/codegen/triton/support_legacy.cc | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc index 98870b577d36e8..b0ba0148783952 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc @@ -735,9 +735,9 @@ ma::ConstantOp Cst64(EmitterLocOpBuilder b, int64_t v) { // split-K, batch, non-contracting LHS, non-contracting RHS, // where split-K and batch are optional. matmul_dims.out_rhs_noncontracting_dim_idx = - dot.shape().dimensions_size() - 1; + dot.shape().dimensions().size() - 1; matmul_dims.out_lhs_noncontracting_dim_idx = - dot.shape().dimensions_size() - 2; + dot.shape().dimensions().size() - 2; auto* root = dot.parent()->root_instruction(); auto iter_spec = @@ -852,7 +852,7 @@ absl::Status ValidateMatMulConfig(const TritonGemmConfig& config, TF_RET_CHECK(dims.lhs_contracting_dimensions_size() == 1); TF_RET_CHECK(dims.rhs_contracting_dimensions_size() == 1); - TF_RET_CHECK(dot.operand(0)->shape().dimensions_size() == + TF_RET_CHECK(dot.operand(0)->shape().dimensions().size() == 2 + (config.split_k > 1 ? 1 : 0) + num_batch_dims); return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc index f6d4d4ee15c0f7..6db61f06feaec7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc @@ -330,10 +330,10 @@ bool NoNonContractingDimension(const HloDotInstruction& dot) { const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); if (dim_numbers.lhs_batch_dimensions().size() + dim_numbers.lhs_contracting_dimensions().size() == - dot.operand(0)->shape().dimensions_size() || + dot.operand(0)->shape().dimensions().size() || dim_numbers.rhs_batch_dimensions().size() + dim_numbers.rhs_contracting_dimensions().size() == - dot.operand(1)->shape().dimensions_size()) { + dot.operand(1)->shape().dimensions().size()) { return true; } return false; @@ -362,7 +362,7 @@ CodegenDecision IsTritonSupportedDynamicSlice( int64_t majormost_dim_id = in_layout.minor_to_major(in_layout.minor_to_major_size() - 1); - for (int i = 0; i < input->shape().dimensions_size(); ++i) { + for (int i = 0; i < input->shape().dimensions().size(); ++i) { if (i == majormost_dim_id) { continue; } else if (input->shape().dimensions(i) != instr.slice_sizes(i)) { From 368667fae73ceb153d7642a39db1375221429630 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 04:56:34 -0700 Subject: [PATCH 0686/1324] Automated Code Change PiperOrigin-RevId: 747364734 --- third_party/xla/xla/backends/gpu/codegen/copy.cc | 4 ++-- third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/copy.cc b/third_party/xla/xla/backends/gpu/codegen/copy.cc index 047031d690b6ad..7a779ce7f2876d 100644 --- a/third_party/xla/xla/backends/gpu/codegen/copy.cc +++ b/third_party/xla/xla/backends/gpu/codegen/copy.cc @@ -207,7 +207,7 @@ bool DynamicMemcpyFusion::IsCandidateFusion( } } - int rank = root->operand(0)->shape().dimensions_size(); + int rank = root->operand(0)->shape().dimensions().size(); for (int i = 0; i < rank; ++i) { auto* operand = root->operand(i + first_offset_index); if (!IsZeroOffset(root, i) && operand->opcode() != HloOpcode::kConstant && @@ -238,7 +238,7 @@ DynamicMemcpyFusion::GetMemcpyDescriptorForFusion( } int first_offset_index = GetFirstOffsetOperandIndex(slice); - int rank = slice_input_shape.dimensions_size(); + int rank = slice_input_shape.dimensions().size(); auto stack = GetCallStack(fusion); VLOG(5) << "Preconditions passed, trying to build a memcpy descriptor."; diff --git a/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc index 9c02768ddf22f6..3c2a7fd4363810 100644 --- a/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc @@ -98,7 +98,7 @@ absl::Status AnnotateKernelLaunchDimensions( IndexingMap KernelFusionInterface::GetDefaultThreadIdIndexingMap( const LaunchDimensions& launch_dims, int unroll_factor, const Shape& shape, mlir::MLIRContext* ctx) { - std::vector output_dims(shape.dimensions_size()); + std::vector output_dims(shape.dimensions().size()); std::array thread_counts{ launch_dims.thread_counts_per_block().x, From 3726249397f440ee5aa756312bf3b2cd7fb79730 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 04:57:40 -0700 Subject: [PATCH 0687/1324] Automated Code Change PiperOrigin-RevId: 747365022 --- .../xla/xla/service/gpu/matmul_indexing_utils.cc | 4 ++-- third_party/xla/xla/service/gpu/matmul_utils.cc | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/gpu/matmul_indexing_utils.cc b/third_party/xla/xla/service/gpu/matmul_indexing_utils.cc index d1afa22e1b0832..4255c6e0a4af1d 100644 --- a/third_party/xla/xla/service/gpu/matmul_indexing_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_indexing_utils.cc @@ -33,11 +33,11 @@ namespace gpu { absl::StatusOr> GetNonContractingDims( const Shape& shape, absl::Span batch_dims, absl::Span contracting_dims) { - auto nc = ::xla::GetNonContractingDims(shape.dimensions_size(), + auto nc = ::xla::GetNonContractingDims(shape.dimensions().size(), contracting_dims, batch_dims); TF_RET_CHECK(batch_dims.size() + contracting_dims.size() + nc.size() == - shape.dimensions_size()); + shape.dimensions().size()); return std::vector(nc.begin(), nc.end()); } diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index d9fdbf565d85cf..1737508915baab 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -65,7 +65,7 @@ absl::StatusOr GetBatchRowColumnShape( TF_RET_CHECK(shape.has_layout()); std::vector minor_to_major; - for (size_t i = 0; i < shape.dimensions_size();) { + for (size_t i = 0; i < shape.dimensions().size();) { // The GeMM output always has its layout set such that the batch, row, and // col dim groups are each laid out physically sequentially. GeMM operands // must, therefore, be laid out similarly. @@ -112,7 +112,7 @@ absl::StatusOr GetBatchRowColumnShape( // Returns the matrix layout for a logical shape (batch, rows, columns). /*static*/ absl::StatusOr MatrixLayout::For(const Shape& shape) { - TF_RET_CHECK(shape.dimensions_size() == 3); + TF_RET_CHECK(shape.dimensions().size() == 3); TF_RET_CHECK(shape.has_layout()); int64_t batch_size = shape.dimensions(0); @@ -175,10 +175,10 @@ absl::StatusOr GetBatchRowColumnShape( size_t rhs_num_batch_dims, size_t rhs_num_col_dims) { size_t num_batch_dims = std::max(lhs_num_batch_dims, rhs_num_batch_dims); - TF_RET_CHECK(shape.dimensions_size() == + TF_RET_CHECK(shape.dimensions().size() == num_batch_dims + lhs_num_row_dims + rhs_num_col_dims); - std::vector dims(shape.dimensions_size()); + std::vector dims(shape.dimensions().size()); absl::c_iota(dims, 0); auto batch_dims = absl::Span(dims).first(num_batch_dims); @@ -300,10 +300,10 @@ absl::StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, int64_t num_batch_dims = std::max(lhs_batch_dims.size(), rhs_batch_dims.size()); - TF_RET_CHECK(output_shape.dimensions_size() == + TF_RET_CHECK(output_shape.dimensions().size() == num_batch_dims + lhs_row_dims.size() + rhs_col_dims.size()); - std::vector output_dims(output_shape.dimensions_size()); + std::vector output_dims(output_shape.dimensions().size()); absl::c_iota(output_dims, 0); auto output_batch_dims = From 4971bcaac7fd08c29410864d0cc5db00abafa273 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 05:00:39 -0700 Subject: [PATCH 0688/1324] Automated Code Change PiperOrigin-RevId: 747365742 --- third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index 157fc5def33c5e..dfa9e753d30f69 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -430,7 +430,7 @@ AbstractTfrtCpuBuffer::CopyToDeviceAcrossClients(PjRtDevice* dst_device) { // Avoid use-after-free on `literal` due to unsequenced move and use. Literal* literal_pointer = literal.get(); absl::InlinedVector byte_strides( - literal->shape().dimensions_size()); + literal->shape().dimensions().size()); TF_RETURN_IF_ERROR( ShapeUtil::ByteStrides(literal->shape(), absl::MakeSpan(byte_strides))); TF_ASSIGN_OR_RETURN(PjRtMemorySpace * dst_memory_space, From 72dfb2253a15d25e50baa793adec0c0905a8b385 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 05:02:16 -0700 Subject: [PATCH 0689/1324] Automated Code Change PiperOrigin-RevId: 747366180 --- third_party/xla/xla/pjrt/pjrt_c_api_client.cc | 2 +- third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 0ee251eb9fb112..cf9e35a0144c25 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -2319,7 +2319,7 @@ absl::StatusOr> PjRtCApiBuffer::CopyToMemorySpace( // Copy across PjRtClients by copying through host TF_ASSIGN_OR_RETURN(std::shared_ptr literal, ToLiteralSync()); absl::InlinedVector byte_strides( - literal->shape().dimensions_size()); + literal->shape().dimensions().size()); TF_RETURN_IF_ERROR( ShapeUtil::ByteStrides(literal->shape(), absl::MakeSpan(byte_strides))); // Avoid use-after-free on `literal` due to unsequenced move and use. diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index de678204df0eb5..b69c42ec009997 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -743,7 +743,8 @@ PjRtStreamExecutorClient::BufferFromHostBufferInternal( device_shape, transfer_manager->ChooseCompactLayoutForShape(device_shape)); } - absl::InlinedVector shape_strides(device_shape.dimensions_size()); + absl::InlinedVector shape_strides( + device_shape.dimensions().size()); TF_RETURN_IF_ERROR( ShapeUtil::ByteStrides(device_shape, absl::MakeSpan(shape_strides))); bool host_and_device_strides_equal = @@ -1687,7 +1688,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceMemorySpace( // Avoid use-after-free on `literal` due to unsequenced move and use. Literal* literal_pointer = literal.get(); absl::InlinedVector byte_strides( - literal->shape().dimensions_size()); + literal->shape().dimensions().size()); TF_RETURN_IF_ERROR( ShapeUtil::ByteStrides(literal->shape(), absl::MakeSpan(byte_strides))); return dst_device->client()->BufferFromHostBuffer( @@ -1865,7 +1866,7 @@ absl::Status CheckCompatibleShapes(bool strict_shape_checking, // shape `pred[0]`. 
if (execution_shape.IsToken() && buffer_on_device_shape.element_type() == PrimitiveType::PRED && - buffer_on_device_shape.dimensions_size() == 1 && + buffer_on_device_shape.dimensions().size() == 1 && buffer_on_device_shape.dimensions(0) == 0) { return absl::OkStatus(); } From 83e0a97462b940d94598bb96b62f0dbde134b1cd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 05:04:07 -0700 Subject: [PATCH 0690/1324] Automated Code Change PiperOrigin-RevId: 747366834 --- .../xla/backends/cpu/runtime/convolution_lib.cc | 10 +++++----- .../xla/xla/backends/cpu/runtime/dot_lib.cc | 14 +++++++------- .../xla/xla/backends/cpu/runtime/fft_thunk.cc | 2 +- .../xla/xla/backends/cpu/runtime/sort_thunk.cc | 6 +++--- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/convolution_lib.cc b/third_party/xla/xla/backends/cpu/runtime/convolution_lib.cc index 59931b302d7815..92a86cbf760191 100644 --- a/third_party/xla/xla/backends/cpu/runtime/convolution_lib.cc +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_lib.cc @@ -45,7 +45,7 @@ static size_t GetConvolutionRank(const Shape& input_shape) { // Convolution rank is the number of spatial dimensions. Besides spatial // dimensions, input shape contains two other dimensions (batch size and the // number of channels). - return input_shape.dimensions_size() - 2; + return input_shape.dimensions().size() - 2; } static absl::Status ValidateConvolutionShapes( @@ -59,13 +59,13 @@ static absl::Status ValidateConvolutionShapes( } // Rank of input, kernel and output buffers. - if (input_shape.dimensions_size() != kernel_shape.dimensions_size() || - input_shape.dimensions_size() != output_shape.dimensions_size()) { + if (input_shape.dimensions().size() != kernel_shape.dimensions().size() || + input_shape.dimensions().size() != output_shape.dimensions().size()) { return InvalidArgument( "ConvolutionThunk: Buffer ranks mismatch. Input rank (%d) vs kernel " "rank (%d) vs output rank (%d)", - input_shape.dimensions_size(), kernel_shape.dimensions_size(), - output_shape.dimensions_size()); + input_shape.dimensions().size(), kernel_shape.dimensions().size(), + output_shape.dimensions().size()); } // Batch size. diff --git a/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc b/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc index 9d8c96bf80ba1b..2b477022831e05 100644 --- a/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc @@ -107,9 +107,9 @@ absl::StatusOr GetDotShape(const DotDimensionNumbers& dot_dimensions, // Check that matmul shapes are rank 2 or less and can be represented as // Eigen 2D contraction. 
- if (lhs_matmul_shape.dimensions_size() > 2 || - rhs_matmul_shape.dimensions_size() > 2 || - out_matmul_shape.dimensions_size() > 2) { + if (lhs_matmul_shape.dimensions().size() > 2 || + rhs_matmul_shape.dimensions().size() > 2 || + out_matmul_shape.dimensions().size() > 2) { return InvalidArgument( "MatMul shape must be rank 2 or less: lhs=%s, rhs=%s, out=%s", lhs_matmul_shape.ToString(true), rhs_matmul_shape.ToString(true), @@ -150,20 +150,20 @@ absl::StatusOr GetDotCanonicalDims( TF_RET_CHECK(rhs_contracting_dims[0] < 2); auto is_column_major = [](const Shape& shape) { - return shape.dimensions_size() > 1 && + return shape.dimensions().size() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0; }; return DotCanonicalDims{ - /*m=*/dot_shape.lhs_matmul_shape.dimensions_size() <= 1 + /*m=*/dot_shape.lhs_matmul_shape.dimensions().size() <= 1 ? int64_t{1} : dot_shape.lhs_matmul_shape.dimensions(1 - lhs_contracting_dims[0]), /*k=*/dot_shape.lhs_matmul_shape.dimensions(lhs_contracting_dims[0]), - /*n=*/dot_shape.rhs_matmul_shape.dimensions_size() <= 1 + /*n=*/dot_shape.rhs_matmul_shape.dimensions().size() <= 1 ? int64_t{1} : dot_shape.rhs_matmul_shape.dimensions(1 - rhs_contracting_dims[0]), /*lhs_column_major=*/is_column_major(dot_shape.lhs_matmul_shape), - /*lhs_canonical=*/dot_shape.lhs_matmul_shape.dimensions_size() <= 1 || + /*lhs_canonical=*/dot_shape.lhs_matmul_shape.dimensions().size() <= 1 || lhs_contracting_dims[0] == 1, /*rhs_column_major=*/is_column_major(dot_shape.rhs_matmul_shape), /*rhs_canonical=*/rhs_contracting_dims[0] == 0, diff --git a/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc index a19e5cb49a7e17..cc099c66be537e 100644 --- a/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc @@ -80,7 +80,7 @@ tsl::AsyncValueRef FftThunk::Execute( // Flatten operand batches. absl::InlinedVector operand_shape_flat(fft_rank + 1); int64_t input_batch = 1; - int64_t input_batch_length = output_shape_.dimensions_size() - fft_rank; + int64_t input_batch_length = output_shape_.dimensions().size() - fft_rank; for (int i = 0; i < input_batch_length; i++) { input_batch *= input_shape_.dimensions(i); } diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc index 28293441a09a03..dd66ce0bbcf665 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc @@ -79,8 +79,8 @@ static absl::Status VerifySortInputs(absl::Span inputs, // Check that sort dimension is valid. int64_t sort_dimension = - dimension >= 0 ? dimension : shape.dimensions_size() + dimension; - if (shape.dimensions_size() <= sort_dimension) { + dimension >= 0 ? dimension : shape.dimensions().size() + dimension; + if (shape.dimensions().size() <= sort_dimension) { return Internal( "Shape of dimensions [%s] can't be sorted along dimension %d", absl::StrJoin(shape.dimensions(), ","), dimension); @@ -604,7 +604,7 @@ struct SortDims { // (or `std::stable_sort`) on each (strided) slice of the buffer. static SortDims GetSortDims(const Shape& shape, int64_t dimension) { int64_t sort_dimension = - dimension >= 0 ? dimension : shape.dimensions_size() + dimension; + dimension >= 0 ? dimension : shape.dimensions().size() + dimension; // We need to normalize shape + layout into a descending layout, so that we // can compute access strides according to the physical layout. 
From ba75f944a66b2a07a97f4f567014c03a0cd33bd6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 05:09:37 -0700 Subject: [PATCH 0691/1324] Automated Code Change PiperOrigin-RevId: 747368462 --- .../xla/xla/backends/gpu/runtime/command_buffer_cmd.cc | 10 +++++----- .../xla/backends/gpu/runtime/dynamic_slice_thunk.cc | 10 +++++----- third_party/xla/xla/backends/gpu/runtime/fft_thunk.cc | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 53d6bded7ba291..67fc4e89526ab5 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -1867,7 +1867,7 @@ DynamicSliceFusionCmd::DynamicSliceFusionCmd( offsets_allocs_base_.push_back(offsets_allocs_size_); if (slice.sliced_shape.has_value()) { offsets_allocs_size_ += - slice.sliced_shape->dimensions_size() * sizeof(int64_t); + slice.sliced_shape->dimensions().size() * sizeof(int64_t); } } } @@ -1915,9 +1915,9 @@ absl::Status DynamicSliceFusionCmd::Prepare( TF_RET_CHECK(slice.sliced_shape->IsArray()); TF_RET_CHECK(slice.offsets->size() == - slice.orig_shape->dimensions_size()); - TF_RET_CHECK(slice.sliced_shape->dimensions_size() == - slice.orig_shape->dimensions_size()); + slice.orig_shape->dimensions().size()); + TF_RET_CHECK(slice.sliced_shape->dimensions().size() == + slice.orig_shape->dimensions().size()); } } TF_RETURN_IF_ERROR(embedded_commands_.Prepare(params, resource_requests)); @@ -1968,7 +1968,7 @@ absl::StatusOr DynamicSliceFusionCmd::Record( const Shape& dst_shape = *slice.sliced_shape; absl::InlinedVector slice_starts; - slice_starts.reserve(dst_shape.dimensions_size()); + slice_starts.reserve(dst_shape.dimensions().size()); // Number of issues d2h transfers to copy offset values from device to // host. diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc index a60a12bbf704dc..7e8347abf38f32 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc @@ -107,7 +107,7 @@ DynamicSliceThunk::DynamicSliceThunk( offsets_allocs_base_.push_back(offsets_allocs_size_); if (slice.sliced_shape.has_value()) { offsets_allocs_size_ += - slice.sliced_shape->dimensions_size() * sizeof(int64_t); + slice.sliced_shape->dimensions().size() * sizeof(int64_t); } } } @@ -125,9 +125,9 @@ absl::Status DynamicSliceThunk::Prepare( TF_RET_CHECK(slice.sliced_shape->IsArray()); TF_RET_CHECK(slice.offsets->size() == - slice.orig_shape->dimensions_size()); - TF_RET_CHECK(slice.sliced_shape->dimensions_size() == - slice.orig_shape->dimensions_size()); + slice.orig_shape->dimensions().size()); + TF_RET_CHECK(slice.sliced_shape->dimensions().size() == + slice.orig_shape->dimensions().size()); } } @@ -201,7 +201,7 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) { const Shape& dst_shape = *slice.sliced_shape; absl::InlinedVector slice_starts; - slice_starts.reserve(dst_shape.dimensions_size()); + slice_starts.reserve(dst_shape.dimensions().size()); // Number of issues d2h transfers to copy offset values from device to // host. 
diff --git a/third_party/xla/xla/backends/gpu/runtime/fft_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/fft_thunk.cc index 5c571499c388e6..826ec109352dfb 100644 --- a/third_party/xla/xla/backends/gpu/runtime/fft_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/fft_thunk.cc @@ -148,7 +148,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, const int64_t fft_rank = fft_len.size(); CHECK_LE(fft_rank, 3); int batch_size = 1; - for (int i = 0; i < input_shape.dimensions_size() - fft_rank; ++i) { + for (int i = 0; i < input_shape.dimensions().size() - fft_rank; ++i) { batch_size *= input_shape.dimensions(i); } uint64_t fft_length[3]; @@ -160,7 +160,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, uint64_t output_distance = 1; for (int i = 0; i < fft_rank; ++i) { - auto dim_offset = input_shape.dimensions_size() - fft_rank + i; + auto dim_offset = input_shape.dimensions().size() - fft_rank + i; fft_length[i] = static_cast(fft_len[i]); input_embed[i] = input_shape.dimensions(dim_offset); input_distance *= input_shape.dimensions(dim_offset); From 96a8c80ca0fbc0edc3fc62d296875e43db8c6d14 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 05:10:02 -0700 Subject: [PATCH 0692/1324] Automated Code Change PiperOrigin-RevId: 747368594 --- tensorflow/compiler/tf2xla/shape_util.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 04fbb0cf31f834..0d7549d81c20f6 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -43,7 +43,7 @@ absl::Status PopulateInfeedLayoutVector(const xla::Shape& shape, layouts->push_back(dim); } } else { - layouts->insert(layouts->end(), shape.dimensions_size(), -1); + layouts->insert(layouts->end(), shape.dimensions().size(), -1); } return absl::OkStatus(); } @@ -97,7 +97,7 @@ absl::Status XLAShapeToTensorShape(const xla::Shape& shape, " cannot be converted to a TensorShape"); } *tensor_shape = TensorShape(); - for (int i = 0; i < shape.dimensions_size(); ++i) { + for (int i = 0; i < shape.dimensions().size(); ++i) { TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } return absl::OkStatus(); @@ -237,7 +237,7 @@ absl::Status GetShapeWithLayout( "Nested tuples not supported: ", xla::ShapeUtil::HumanString(input_shape)); } - int64_t rank = shape.dimensions_size(); + int64_t rank = shape.dimensions().size(); if (position + rank > minor_to_major.size()) { return errors::InvalidArgument( "Not enough layout attribute elements: position=", position, @@ -259,7 +259,7 @@ absl::Status GetShapeWithLayout( } *output_shape = xla::ShapeUtil::MakeTupleShape(shapes); } else { - int64_t rank = input_shape.dimensions_size(); + int64_t rank = input_shape.dimensions().size(); const int64_t minor_to_major_size = minor_to_major.size(); if (rank != minor_to_major_size) { return errors::InvalidArgument( From 1e0f31f81a2139cad4fc4e6f472d5d181b663f02 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 05:15:41 -0700 Subject: [PATCH 0693/1324] Automated Code Change PiperOrigin-RevId: 747370319 --- .../codegen/emitters/computation_partitioner.cc | 2 +- .../codegen/emitters/elemental_hlo_to_mlir.cc | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/codegen/emitters/computation_partitioner.cc b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc index 0de9ae6bc66cbd..a66ae7f7bccd27 100644 --- a/third_party/xla/xla/codegen/emitters/computation_partitioner.cc +++ b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc @@ -157,7 +157,7 @@ bool IsEvaluatedMoreThanOnce(const HloInstruction* instr) { return absl::c_any_of(instr->users(), [&](const HloInstruction* user) { if (user->opcode() == HloOpcode::kGather && absl::c_linear_search(user->OperandIndices(instr), 1) && - instr->shape().dimensions_size() >= 2 && + instr->shape().dimensions().size() >= 2 && instr->shape().dimensions(1) > 1) { return true; } diff --git a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc index 6d848e2912f34c..79bca0e4e850fb 100644 --- a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc @@ -319,7 +319,7 @@ absl::StatusOr> EmitDynamicSlice( SmallVector input_indices(indices); const auto& input_shape = instr->operand(0)->shape(); - for (int i = 0; i < input_shape.dimensions_size(); ++i) { + for (int i = 0; i < input_shape.dimensions().size(); ++i) { TF_ASSIGN_OR_RETURN( auto offset, GetSingleOperandValue(operand_provider, instr, i + 1, {})); offset = @@ -343,7 +343,7 @@ absl::StatusOr> EmitDynamicUpdateSlice( Value is_in_bounds = b.create(b.getIntegerAttr(b.getI1Type(), 1)); mlir::SmallVector update_indices; const auto& updates_shape = instr->operand(1)->shape(); - for (int i = 0; i < instr->shape().dimensions_size(); ++i) { + for (int i = 0; i < instr->shape().dimensions().size(); ++i) { int64_t update_size = updates_shape.dimensions(i); TF_ASSIGN_OR_RETURN( auto start_index, @@ -392,7 +392,7 @@ absl::StatusOr> EmitGather( // Gather allows the index vector to contain fewer elements than the rank // of the input. In that case, the remaining indices are 0. SmallVector operand_indices( - instr->operand(0)->shape().dimensions_size(), zero); + instr->operand(0)->shape().dimensions().size(), zero); // Produce start indices. // HLO allows the index vector dimension to be implicit, and the algebraic @@ -400,7 +400,7 @@ absl::StatusOr> EmitGather( // indices here and do the implicit reshape in place. const auto& indices_shape = instr->operand(1)->shape(); int num_indices = - indices_shape.dimensions_size() == 1 ? 1 : indices_shape.dimensions(1); + indices_shape.dimensions().size() == 1 ? 1 : indices_shape.dimensions(1); for (int i = 0; i < num_indices; ++i) { auto i_val = i == 0 ? zero : b.create(i); int64_t slice_size = instr->gather_slice_sizes()[i]; @@ -408,7 +408,7 @@ absl::StatusOr> EmitGather( // Read and clamp index. TF_ASSIGN_OR_RETURN(auto input_index, operand_provider(instr, 1, - indices_shape.dimensions_size() == 1 + indices_shape.dimensions().size() == 1 ? 
ValueRange{row} : ValueRange{row, i_val})); TF_RET_CHECK(input_index.size() == 1) @@ -711,7 +711,7 @@ absl::StatusOr> EmitTuple( while (first_shape->IsTuple()) { first_shape = &first_shape->tuple_shapes(0); } - CHECK_EQ(first_shape->dimensions_size(), indices.size()) + CHECK_EQ(first_shape->dimensions().size(), indices.size()) << "Indices for tuple must be for the first tuple element"; SmallVector operands; for (int i = 0; i < instr->operand_count(); ++i) { @@ -778,10 +778,10 @@ absl::StatusOr> GetOperands( if (is_elementwise && instr->shape().IsArray()) { // Check if the instruction is really elementwise. There may be some // broadcasting. - int64_t rank = instr->shape().dimensions_size(); + int64_t rank = instr->shape().dimensions().size(); is_elementwise &= absl::c_all_of(instr->operands(), [&](const HloInstruction* operand) { - return operand->shape().dimensions_size() == rank; + return operand->shape().dimensions().size() == rank; }); } From 0fe4c33dce087569fefcfd2fa07567e47ad88ee4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 05:16:27 -0700 Subject: [PATCH 0694/1324] Automated Code Change PiperOrigin-RevId: 747370611 --- third_party/xla/xla/tests/client_library_test_base.cc | 4 ++-- third_party/xla/xla/tests/concatenate_test.cc | 4 ++-- third_party/xla/xla/tests/test_utils.cc | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/tests/client_library_test_base.cc b/third_party/xla/xla/tests/client_library_test_base.cc index 1358c8600ffc04..d84efe662408f9 100644 --- a/third_party/xla/xla/tests/client_library_test_base.cc +++ b/third_party/xla/xla/tests/client_library_test_base.cc @@ -200,7 +200,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( verify_output(actual, ""); // Try with all output layouts. 
- std::vector<int64_t> minor_to_major(expected.shape().dimensions_size()); + std::vector<int64_t> minor_to_major(expected.shape().dimensions().size()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto layout = ShapeUtil::MakeShapeWithDenseLayout( @@ -243,7 +243,7 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return absl::OkStatus(); } - std::vector<int64_t> minor_to_major(literal.shape().dimensions_size()); + std::vector<int64_t> minor_to_major(literal.shape().dimensions().size()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = diff --git a/third_party/xla/xla/tests/concatenate_test.cc b/third_party/xla/xla/tests/concatenate_test.cc index 2144a5d49071db..e69018cb77175f 100644 --- a/third_party/xla/xla/tests/concatenate_test.cc +++ b/third_party/xla/xla/tests/concatenate_test.cc @@ -99,8 +99,8 @@ TEST_F(ConcatenateTest, ThreeR2Axis1) { } static auto MakeIotaForShape(const Shape& shape) { - std::vector<int64_t> strides(shape.dimensions_size(), 1); - for (int i = shape.dimensions_size() - 1; i > 0; --i) { + std::vector<int64_t> strides(shape.dimensions().size(), 1); + for (int i = shape.dimensions().size() - 1; i > 0; --i) { strides[i - 1] = strides[i] * shape.dimensions(i); } return [strides = std::move(strides)](absl::Span<const int64_t> indices) { diff --git a/third_party/xla/xla/tests/test_utils.cc b/third_party/xla/xla/tests/test_utils.cc index 12954db63fbffb..31dfe33cc43778 100644 --- a/third_party/xla/xla/tests/test_utils.cc +++ b/third_party/xla/xla/tests/test_utils.cc @@ -369,14 +369,14 @@ absl::Status VerifyHloModule(HloModule* const module, bool layout_sensitive, std::unique_ptr<HloInstruction> CreateCanonicalDot(const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_LE(lhs->shape().dimensions_size(), 2); - CHECK_LE(rhs->shape().dimensions_size(), 2); + CHECK_LE(lhs->shape().dimensions().size(), 2); + CHECK_LE(rhs->shape().dimensions().size(), 2); PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( 2, PrecisionConfig::DEFAULT); DotDimensionNumbers dot_dimension_numbers; dot_dimension_numbers.add_lhs_contracting_dimensions( - lhs->shape().dimensions_size() > 1 ? 1 : 0); + lhs->shape().dimensions().size() > 1 ? 1 : 0); dot_dimension_numbers.add_rhs_contracting_dimensions(0); return std::make_unique<HloDotInstruction>( shape, lhs, rhs, dot_dimension_numbers, precision_config); From aca38707d8a685b00ea4b9f4e6d49f9f90d3ee61 Mon Sep 17 00:00:00 2001 From: "A.
Unique TensorFlower" Date: Mon, 14 Apr 2025 05:21:11 -0700 Subject: [PATCH 0695/1324] Automated Code Change PiperOrigin-RevId: 747371953 --- .../auto_sharding/auto_sharding.cc | 34 ++++++++++--------- .../auto_sharding_dot_handler.cc | 14 ++++---- .../auto_sharding/auto_sharding_strategy.cc | 4 +-- .../auto_sharding/auto_sharding_util.cc | 27 ++++++++------- .../auto_sharding/auto_sharding_util.h | 4 +-- .../auto_sharding/cluster_environment.cc | 10 +++--- .../auto_sharding/cluster_environment.h | 4 +-- 7 files changed, 50 insertions(+), 47 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 5ba07480e91184..ed8a7fc0aa547c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -224,7 +224,7 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); const auto& operand_strategies = operand_strategy_group.GetStrategies(); const std::vector zeros(operand_strategies.size(), 0.0); - if (operand_shape.IsToken() || operand_shape.dimensions_size() == 0) { + if (operand_shape.IsToken() || operand_shape.dimensions().size() == 0) { communication_resharding_costs.push_back(zeros); memory_resharding_costs.push_back(zeros); if (!input_shardings.shardings[k].has_value()) { @@ -481,9 +481,9 @@ absl::StatusOr> FollowReduceStrategy( // op_dim_to_output_dim = [0, 1, -1] std::vector op_dim_to_output_dim = GetDimensionMapping(/*reduced_dimensions=*/ins->dimensions(), - /*op_count*/ operand->shape().dimensions_size()); - CHECK_EQ(ins->dimensions().size() + output_shape.dimensions_size(), - operand->shape().dimensions_size()) + /*op_count*/ operand->shape().dimensions().size()); + CHECK_EQ(ins->dimensions().size() + output_shape.dimensions().size(), + operand->shape().dimensions().size()) << "Invalid kReduce: output size + reduced dimensions size != op count"; for (const auto& src_strategy : src_strategy_group->GetStrategies()) { @@ -492,12 +492,12 @@ absl::StatusOr> FollowReduceStrategy( operand->shape(), input_sharding, /* consider_reverse_device_meshes */ true, /* crash_at_error */ crash_at_error); - if (tensor_dim_to_mesh.size() != operand->shape().dimensions_size()) { + if (tensor_dim_to_mesh.size() != operand->shape().dimensions().size()) { return absl::InvalidArgumentError( "Cannot generate tensor dim to mesh dim mapping"); } std::vector all_reduce_dims; - for (int64_t op_dim = 0; op_dim < operand->shape().dimensions_size(); + for (int64_t op_dim = 0; op_dim < operand->shape().dimensions().size(); ++op_dim) { int64_t mesh_dim = tensor_dim_to_mesh[op_dim]; // Replicates on this mesh dim. @@ -880,7 +880,7 @@ void EnumerateAll1DPartition( bool allow_shardings_small_dims_across_many_devices, const std::string& suffix, const CallGraph& call_graph, StrategyGroup& strategy_group) { - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { bool small_dims_sharding_check = !allow_shardings_small_dims_across_many_devices && @@ -939,7 +939,7 @@ void EnumerateAll1DPartition( // the cost model for sort (which, as noted above in the comments for // the function) is also an approximation. 
communication_cost = ComputeSortCommunicationCost( - ins->operand(0)->shape().dimensions_size() - 1, i, j, shape, + ins->operand(0)->shape().dimensions().size() - 1, i, j, shape, cluster_env); } strategy_group.AddStrategy( @@ -975,7 +975,7 @@ void EnumerateAllPartition( return; } // Fully tile the buffer to the mesh - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { auto tensor_it = std::find(tensor_dims.begin(), tensor_dims.end(), i); if (tensor_it != tensor_dims.end()) { continue; @@ -1044,7 +1044,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, CHECK(sort_ins); sort_or_topk_dim = sort_ins->sort_dimension(); } else if (IsTopKCustomCall(ins)) { - sort_or_topk_dim = ins->operand(0)->shape().dimensions_size() - 1; + sort_or_topk_dim = ins->operand(0)->shape().dimensions().size() - 1; } if (sort_or_topk_dim != -1) { @@ -1072,7 +1072,7 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins, const Shape& operand_shape = operand->shape(); const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); - for (int64_t i = 0; i < ins->shape().dimensions_size(); ++i) { + for (int64_t i = 0; i < ins->shape().dimensions().size(); ++i) { for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { if (device_mesh.dim(j) == 1 || (only_allow_divisible && @@ -1534,7 +1534,7 @@ void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( continue; } const auto& tile_assignment = strategy.output_sharding.tile_assignment(); - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { if (tile_assignment.dim(i) > 1 && tile_assignment.dim(i) > shape.dimensions(i)) { invalid_strategy_indices.push_back(sid); @@ -2221,8 +2221,10 @@ absl::Status InsertReshardReshapes( rhs->shape(), rhs_sharding, /*consider_reverse_device_meshes=*/true, crash_at_error); - if (lhs_tensor_dim_to_mesh_dim.size() != lhs->shape().dimensions_size() || - rhs_tensor_dim_to_mesh_dim.size() != rhs->shape().dimensions_size()) { + if (lhs_tensor_dim_to_mesh_dim.size() != + lhs->shape().dimensions().size() || + rhs_tensor_dim_to_mesh_dim.size() != + rhs->shape().dimensions().size()) { return absl::InvalidArgumentError( "Cannot generate tensor dim to mesh dim mapping"); } @@ -3235,7 +3237,7 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, } } else if (ins->opcode() == HloOpcode::kReduce) { // TODO(zhuohan): support more cases. 
- CHECK_EQ(ins->shape().dimensions_size(), 1); + CHECK_EQ(ins->shape().dimensions().size(), 1); int mesh_dim; if (absl::StrContains(input_shardings.name, "allreduce @ [0]")) { @@ -3291,7 +3293,7 @@ bool HasReduceScatterOpportunity(const HloInstruction* inst, } if (inst->opcode() == HloOpcode::kReduce && - inst->shape().dimensions_size() == 1) { + inst->shape().dimensions().size() == 1) { return true; } if (inst->opcode() == HloOpcode::kDot) { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 5cadf66f822d69..3c748b2ff9119f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -717,12 +717,12 @@ void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand( used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end()); } if (used_mesh_dims.size() == device_mesh_.num_dimensions() || - used_mesh_dims.size() == operand->shape().dimensions_size()) { + used_mesh_dims.size() == operand->shape().dimensions().size()) { return; } - for (int64_t tensor_dim = 0; tensor_dim < operand->shape().dimensions_size(); - ++tensor_dim) { + for (int64_t tensor_dim = 0; + tensor_dim < operand->shape().dimensions().size(); ++tensor_dim) { if (auto it = operand_dim_map.find(tensor_dim); it != operand_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) { continue; @@ -763,11 +763,11 @@ void DotHandler::AppendReduceScatterWindowedEinsumStrategy( } if (used_mesh_dims.size() == device_mesh_.num_dimensions() || - used_mesh_dims.size() == ins_->shape().dimensions_size()) { + used_mesh_dims.size() == ins_->shape().dimensions().size()) { return; } - for (int64_t tensor_dim = 0; tensor_dim < ins_->shape().dimensions_size(); + for (int64_t tensor_dim = 0; tensor_dim < ins_->shape().dimensions().size(); ++tensor_dim) { if (auto it = output_dim_map.find(tensor_dim); it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) { @@ -805,7 +805,7 @@ absl::Status DotHandler::RegisterStrategies() { [&](const DimMap& output_dim_map) { GenerateDotShardingStrategiesFromOutputSharding(output_dim_map); }, - ins_->shape().dimensions_size(), all_mesh_dims, + ins_->shape().dimensions().size(), all_mesh_dims, option_.allow_mixed_mesh_shape); SortStrategies(); return absl::OkStatus(); @@ -965,7 +965,7 @@ void ConvHandler::SplitDepthwise(bool forward) { }; std::vector all_mesh_dims(device_mesh_.num_dimensions()); std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0); - Enumerate(split_func, ins_->shape().dimensions_size(), all_mesh_dims, + Enumerate(split_func, ins_->shape().dimensions().size(), all_mesh_dims, option_.allow_mixed_mesh_shape); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 334986f81abbbd..909ab6f933a7c7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -134,13 +134,13 @@ ComputeSliceShardingAndCommunicationCostFromOperand( CHECK(old_shape.IsArray()); std::vector tensor_to_mesh_dim = GetTensorDimToMeshDim( - new_shape.dimensions_size(), input_spec, device_mesh, + new_shape.dimensions().size(), input_spec, device_mesh, /* consider_reverse_device_meshes */ true); std::vector mesh_dims_for_communication; 
std::vector tensor_dims; std::vector mesh_dims; - for (size_t i = 0; i < new_shape.dimensions_size(); ++i) { + for (size_t i = 0; i < new_shape.dimensions().size(); ++i) { if (tensor_to_mesh_dim[i] == -1) { continue; } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index b660e11323a579..8c72dab856892b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -121,7 +121,7 @@ std::optional PropagateDimwiseSharding( CHECK(old_shape.IsArray()); const auto& tile_assignment = input_spec.tile_assignment(); - for (int64_t i = 0; i < old_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < old_shape.dimensions().size(); ++i) { if (tile_assignment.dim(i) > 1 && new_shape.dimensions(i) != old_shape.dimensions(i)) { return std::nullopt; @@ -144,7 +144,7 @@ std::optional PropagateReduceWindowSharding( CHECK(!input_spec.IsTuple()); const auto& tile_assignment = input_spec.tile_assignment(); - for (int64_t i = 0; i < old_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < old_shape.dimensions().size(); ++i) { if (tile_assignment.dim(i) > 1 && window.dimensions(i).size() != 1) { return std::nullopt; } @@ -265,7 +265,7 @@ void BatchDimMapForward(const std::vector& instructions, if (batch_map.contains(GetBatchDimMapKey(operand))) { int value = batch_map[GetBatchDimMapKey(operand)]; int old_dim = -1; - for (int i = 0; i < ins->shape().dimensions_size(); ++i) { + for (int i = 0; i < ins->shape().dimensions().size(); ++i) { if (absl::c_linear_search(dimensions, i)) { old_dim++; } @@ -521,7 +521,7 @@ void BatchDimMapBackward(const std::vector& instructions, !batch_map.contains(GetBatchDimMapKey(operand))) { int value = batch_map[GetBatchDimMapKey(ins)]; int old_dim = -1; - for (int i = 0; i < ins->shape().dimensions_size(); ++i) { + for (int i = 0; i < ins->shape().dimensions().size(); ++i) { if (absl::c_linear_search(dimensions, i)) { old_dim++; } @@ -914,7 +914,7 @@ bool IsAlwaysReplicated(const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kConstant) { return true; } - if (inst->shape().dimensions_size() == 0) { + if (inst->shape().dimensions().size() == 0) { return true; } if (inst->opcode() == HloOpcode::kBroadcast) { @@ -1211,7 +1211,7 @@ absl::StatusOr ComputeIntermediateShape(const HloSharding& src_sharding, // Find an intermediate shape std::vector inter_shape_dims; - for (size_t i = 0; i < shape.dimensions_size(); ++i) { + for (size_t i = 0; i < shape.dimensions().size(); ++i) { if (sharding_1d->tile_assignment().dim(i) == 1) { inter_shape_dims.push_back(shape.dimensions(i)); } else { @@ -1501,7 +1501,7 @@ HloSharding TileV1(const Shape& tensor_shape, CHECK_EQ(tensor_dims.size(), mesh_dims.size()); CHECK(tensor_shape.IsArray()); std::vector tile_assignment_dimensions( - tensor_shape.dimensions_size(), 1); + tensor_shape.dimensions().size(), 1); // Split on certain mesh dimensions int64_t split_prod = 1; @@ -1542,7 +1542,7 @@ HloSharding TileV1(const Shape& tensor_shape, } if (proceed_to_next_tensor_dim && - current_tensor_dim == tensor_shape.dimensions_size() - 1) { + current_tensor_dim == tensor_shape.dimensions().size() - 1) { AppendFlattenElements(&tile_assignment_devices, device_mesh.DeviceArray(), mesh_indices); return; @@ -1598,7 +1598,7 @@ HloSharding TileV2(const Shape& tensor_shape, CHECK_EQ(tensor_dims.size(), mesh_dims.size()); 
CHECK(tensor_shape.IsArray()); std::vector tile_assignment_dimensions( - tensor_shape.dimensions_size(), 1); + tensor_shape.dimensions().size(), 1); std::vector transpose_perm; absl::Span reshape_dims = device_mesh.dimensions(); @@ -2261,8 +2261,9 @@ absl::StatusOr AdjustShardingsWithPartialMeshShape( output_flattened_shardings.push_back(sharding); continue; } - TF_ASSIGN_OR_RETURN(std::optional new_sharding, - adjust_sharding(shape.dimensions_size(), sharding)); + TF_ASSIGN_OR_RETURN( + std::optional new_sharding, + adjust_sharding(shape.dimensions().size(), sharding)); output_flattened_shardings.push_back( new_sharding.has_value() ? *new_sharding : sharding); changed |= new_sharding.has_value(); @@ -2277,7 +2278,7 @@ absl::StatusOr AdjustShardingsWithPartialMeshShape( } TF_ASSIGN_OR_RETURN( std::optional new_sharding, - adjust_sharding(inst->shape().dimensions_size(), inst->sharding())); + adjust_sharding(inst->shape().dimensions().size(), inst->sharding())); if (new_sharding.has_value()) { inst->set_sharding(*new_sharding); changed = true; @@ -2538,7 +2539,7 @@ bool IsShardingMisaligned(const HloSharding& sharding, const Shape& shape) { return false; } - for (size_t i = 0; i < shape.dimensions_size(); ++i) { + for (size_t i = 0; i < shape.dimensions().size(); ++i) { int64_t shape_dim = shape.dimensions()[i]; int64_t sharding_dim = sharding.tile_assignment().dim(i); if (shape_dim % sharding_dim != 0) { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 166c8df77ceaab..902036fa475bf3 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -184,7 +184,7 @@ GetSpaceDims(const Shape& lhs_shape, const Shape& rhs_shape, const DotDimensionNumbers& dnums) { tsl::protobuf::RepeatedField lhs_space_dims, rhs_space_dims; - for (int64_t i = 0; i < lhs_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < lhs_shape.dimensions().size(); ++i) { if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { continue; @@ -192,7 +192,7 @@ GetSpaceDims(const Shape& lhs_shape, const Shape& rhs_shape, lhs_space_dims.Add(i); } - for (int64_t i = 0; i < rhs_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < rhs_shape.dimensions().size(); ++i) { if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) || absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { continue; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc index 2b777245713d78..cc41ea7503e3a7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -143,11 +143,11 @@ double ClusterEnvironment::ReshardingCostMixedMeshShape( const HloSharding& dst_sharding) const { absl::StatusOr>> src_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding( - shape.dimensions_size(), src_sharding, device_mesh_, + shape.dimensions().size(), src_sharding, device_mesh_, /*consider_reverse_device_meshes=*/true); absl::StatusOr>> dst_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding( - shape.dimensions_size(), dst_sharding, device_mesh_, + shape.dimensions().size(), dst_sharding, device_mesh_, 
/*consider_reverse_device_meshes=*/true); if (!src_tensor_dim_to_mesh_axis.ok() || !dst_tensor_dim_to_mesh_axis.ok()) { return OverestimateReplicationCost(shape, src_sharding, device_mesh_); @@ -156,7 +156,7 @@ double ClusterEnvironment::ReshardingCostMixedMeshShape( int64_t num_devices = device_mesh_.num_elements(); std::vector collective_mesh_axes; // Only consider sharded dimensions, do not consider replicate_on_last_dim. - for (size_t i = 0; i < shape.dimensions_size(); ++i) { + for (size_t i = 0; i < shape.dimensions().size(); ++i) { if ((*src_tensor_dim_to_mesh_axis)[i] == (*dst_tensor_dim_to_mesh_axis)[i]) { continue; @@ -313,9 +313,9 @@ double ClusterEnvironment::ReshardingCost(const Shape& shape, // of an operand with a different shape, we need to use their // TiledDataRank(). size_t src_rank = - src_spec.IsTiled() ? src_spec.TiledDataRank() : shape.dimensions_size(); + src_spec.IsTiled() ? src_spec.TiledDataRank() : shape.dimensions().size(); size_t dst_rank = - dst_spec.IsTiled() ? dst_spec.TiledDataRank() : shape.dimensions_size(); + dst_spec.IsTiled() ? dst_spec.TiledDataRank() : shape.dimensions().size(); auto get_tensor_dim_to_mesh_dim = [&](int64_t rank, const HloSharding& sharding) { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h index 8d4bf328f61c4f..a0612e607117e7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -116,11 +116,11 @@ class ClusterEnvironment { std::vector tensor_dim_to_mesh_dim; if (crash_at_error) { tensor_dim_to_mesh_dim = - GetTensorDimToMeshDim(shape.dimensions_size(), spec, device_mesh_, + GetTensorDimToMeshDim(shape.dimensions().size(), spec, device_mesh_, consider_reverse_device_meshes); } else { auto tensor_dim_to_mesh_dim_status = GetTensorDimToMeshDimNoCrash( - shape.dimensions_size(), spec, device_mesh_, + shape.dimensions().size(), spec, device_mesh_, consider_reverse_device_meshes); if (tensor_dim_to_mesh_dim_status.ok()) { tensor_dim_to_mesh_dim = tensor_dim_to_mesh_dim_status.value(); From 9f2a6a9fcd2e8860fc5deeb98370c76bbff49238 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 05:24:35 -0700 Subject: [PATCH 0696/1324] Automated Code Change PiperOrigin-RevId: 747372864 --- .../xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc index d3ac9a653d8b66..ca15294ce32727 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc @@ -54,8 +54,8 @@ absl::Status RewriteLayoutWithShardedShape( sharding->TileOffsetForDevice(*xla_shape, device); std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); - std::vector dimensions(xla_shape->dimensions_size()); - for (int64_t i = 0; i < xla_shape->dimensions_size(); ++i) { + std::vector dimensions(xla_shape->dimensions().size()); + for (int64_t i = 0; i < xla_shape->dimensions().size(); ++i) { dimensions[i] = limit[i] - offset[i]; } xla::Shape per_device_xla_shape = @@ -115,7 +115,7 @@ absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( &to_shape)); } if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { - for (int64_t i = 0; i < original_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < original_shape.dimensions().size(); ++i) { to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); } } From b403fbe82e4bd6d5cc6cdf57dfbeee3b1e63046e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 05:24:47 -0700 Subject: [PATCH 0697/1324] Automated Code Change PiperOrigin-RevId: 747372925 --- .../xla/xla/hlo/utils/hlo_sharding_util.cc | 116 +++++++++--------- 1 file changed, 60 insertions(+), 56 deletions(-) diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index c2faedaa8239f2..c2c8062b255531 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -85,7 +85,7 @@ HloInstruction* FormatShape(HloInstruction* data, } case HloOpcode::kPad: { PaddingConfig padding_config; - for (int64_t i = 0; i < step.output_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < step.output_shape.dimensions().size(); ++i) { auto padding_config_dim = padding_config.add_dimensions(); padding_config_dim->set_edge_padding_low(0); padding_config_dim->set_interior_padding(0); @@ -134,8 +134,9 @@ HloInstruction* ReverseFormatShape( break; } case HloOpcode::kPad: { - std::vector start_indices(previous_shape.dimensions_size(), 0); - std::vector strides(previous_shape.dimensions_size(), 1); + std::vector start_indices(previous_shape.dimensions().size(), + 0); + std::vector strides(previous_shape.dimensions().size(), 1); data = computation->AddInstruction( HloInstruction::CreateSlice(previous_shape, data, start_indices, previous_shape.dimensions(), strides)); @@ -216,7 +217,7 @@ bool IsSubTilingOrEqualSharding(const Shape& potential_sharded_shape, // Different tiled ranks can't be compared (something is wrong, are the // shardings for different shapes?) 
if (tiled_data_rank != sharding.TiledDataRank() || - tiled_data_rank != potential_sharded_shape.dimensions_size()) { + tiled_data_rank != potential_sharded_shape.dimensions().size()) { return false; } @@ -917,7 +918,7 @@ std::optional ReshapeSharding(const Shape& source_shape, DimensionVector sharding_tile_dims_stack( source_sharding.tile_assignment().dimensions().begin(), source_sharding.tile_assignment().dimensions().begin() + - source_shape.dimensions_size()); + source_shape.dimensions().size()); std::reverse(sharding_tile_dims_stack.begin(), sharding_tile_dims_stack.end()); int64_t source_dims_index = -1; @@ -1013,7 +1014,7 @@ std::optional ReshapeSharding(const Shape& source_shape, return std::nullopt; } while (target_tile_assignment_dimensions.size() < - target_shape.dimensions_size()) { + target_shape.dimensions().size()) { target_tile_assignment_dimensions.push_back(1); } @@ -1045,7 +1046,7 @@ std::optional ReshapeSharding(const Shape& source_shape, } else if (absl::c_linear_search(subgroup_types, OpSharding::REPLICATED)) { target_tile_assignment_dimensions[sharding.SubgroupReplicationDim() - sharding.TiledDataRank() + - target_shape.dimensions_size()] = + target_shape.dimensions().size()] = partially_replicated.quot; } else { target_tile_assignment_dimensions.push_back(partially_replicated.quot); @@ -1071,9 +1072,9 @@ HloSharding PropagateShardingThroughReshape(const Shape& source_shape, HloSharding inner_reshaped = PropagateShardingThroughReshape( source_shape, target_shape, group.sharding); group.sharding = std::move(inner_reshaped); - group.data_rank = target_shape.dimensions_size(); + group.data_rank = target_shape.dimensions().size(); group.group_dims[0] += - target_shape.dimensions_size() - source_shape.dimensions_size(); + target_shape.dimensions().size() - source_shape.dimensions().size(); return UngroupSharding(group); } // Find intervals of consecutive dimensions that could use ReshapeSharding(). @@ -1081,13 +1082,13 @@ HloSharding PropagateShardingThroughReshape(const Shape& source_shape, // and if it fails, we find a sub-interval of it or a disjoint interval. HloSharding result = HloSharding::Replicate(); int64_t start_dim = 0; - while (start_dim < source_shape.dimensions_size()) { + while (start_dim < source_shape.dimensions().size()) { bool found_compatible = false; // For each start_dim, try to use all dims after it. If that fails, reduce // the range. 
- for (int64_t end_dim = source_shape.dimensions_size(); end_dim > start_dim; - --end_dim) { - DimensionVector grouped_tiling_dims(source_shape.dimensions_size(), 1); + for (int64_t end_dim = source_shape.dimensions().size(); + end_dim > start_dim; --end_dim) { + DimensionVector grouped_tiling_dims(source_shape.dimensions().size(), 1); for (int64_t i = start_dim; i < end_dim; ++i) { grouped_tiling_dims[i] = sharding.tile_assignment().dim(i); } @@ -1117,7 +1118,7 @@ HloSharding PropagateShardingThroughReshape(const Shape& source_shape, int64_t num_replicated_dims = sharding.tile_assignment().num_elements() / Product(reshape_dims); const int64_t diff = - reshape_dims.size() - target_shape.dimensions_size(); + reshape_dims.size() - target_shape.dimensions().size(); CHECK(diff == 0 || diff == 1); if (diff == 0) { reshape_dims.push_back(num_replicated_dims); @@ -1289,11 +1290,12 @@ HloSharding GatherOutputShardingFromIndex(const HloSharding& index_sharding, const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers(); const GatherScatterDims indices_output_dims = GetGatherConnectedDimsAcrossIndicesAndOutput( - hlo->operand(1)->shape().dimensions_size(), dnums.index_vector_dim(), - hlo->shape().dimensions_size(), dnums.offset_dims()); + hlo->operand(1)->shape().dimensions().size(), + dnums.index_vector_dim(), hlo->shape().dimensions().size(), + dnums.offset_dims()); return PropagateShardingAlongDimsAndReplicateOthers( index_sharding, indices_output_dims.indices_dims, - indices_output_dims.output_dims, hlo->shape().dimensions_size()); + indices_output_dims.output_dims, hlo->shape().dimensions().size()); } HloSharding GatherIndexShardingFromOutput(const HloSharding& output_sharding, @@ -1306,12 +1308,13 @@ HloSharding GatherIndexShardingFromOutput(const HloSharding& output_sharding, const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers(); const GatherScatterDims indices_output_dims = GetGatherConnectedDimsAcrossIndicesAndOutput( - hlo->operand(1)->shape().dimensions_size(), dnums.index_vector_dim(), - hlo->shape().dimensions_size(), dnums.offset_dims()); + hlo->operand(1)->shape().dimensions().size(), + dnums.index_vector_dim(), hlo->shape().dimensions().size(), + dnums.offset_dims()); return PropagateShardingAlongDimsAndReplicateOthers( output_sharding, indices_output_dims.output_dims, indices_output_dims.indices_dims, - hlo->operand(1)->shape().dimensions_size()); + hlo->operand(1)->shape().dimensions().size()); } HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { @@ -1320,9 +1323,9 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { } const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers(); - DimensionVector tile_assignment_dims(hlo.shape().dimensions_size()); + DimensionVector tile_assignment_dims(hlo.shape().dimensions().size()); int64_t num_elements = 1; - for (int64_t i = 0; i < hlo.shape().dimensions_size(); ++i) { + for (int64_t i = 0; i < hlo.shape().dimensions().size(); ++i) { if (!absl::c_binary_search(dnums.offset_dims(), i)) { tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i); num_elements *= hlo.sharding().tile_assignment().dim(i); @@ -1351,9 +1354,9 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { // - first dimension is non offset dimension, // - second dimension is offset dimension, // Then the result sharding will be [2,1]{0,2}. 
- DimensionVector slice_starts(hlo.shape().dimensions_size(), 0LL), - slice_limits(hlo.shape().dimensions_size()); - for (int64_t i = 0; i < hlo.shape().dimensions_size(); ++i) { + DimensionVector slice_starts(hlo.shape().dimensions().size(), 0LL), + slice_limits(hlo.shape().dimensions().size()); + for (int64_t i = 0; i < hlo.shape().dimensions().size(); ++i) { if (!absl::c_binary_search(dnums.offset_dims(), i)) { slice_limits[i] = hlo.sharding().tile_assignment().dim(i); } else { @@ -1375,14 +1378,14 @@ HloSharding ScatterIndexShardingFromUpdate( const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers(); const GatherScatterDims indices_update_dims = GetGatherConnectedDimsAcrossIndicesAndOutput( - scatter->scatter_indices()->shape().dimensions_size(), + scatter->scatter_indices()->shape().dimensions().size(), dnums.index_vector_dim(), - scatter->scatter_updates()[0]->shape().dimensions_size(), + scatter->scatter_updates()[0]->shape().dimensions().size(), dnums.update_window_dims()); return PropagateShardingAlongDimsAndReplicateOthers( update_sharding, indices_update_dims.output_dims, indices_update_dims.indices_dims, - scatter->scatter_indices()->shape().dimensions_size()); + scatter->scatter_indices()->shape().dimensions().size()); } HloSharding ScatterUpdateShardingFromIndex( @@ -1394,14 +1397,14 @@ HloSharding ScatterUpdateShardingFromIndex( const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers(); const GatherScatterDims indices_update_dims = GetGatherConnectedDimsAcrossIndicesAndOutput( - scatter->scatter_indices()->shape().dimensions_size(), + scatter->scatter_indices()->shape().dimensions().size(), dnums.index_vector_dim(), - scatter->scatter_updates()[0]->shape().dimensions_size(), + scatter->scatter_updates()[0]->shape().dimensions().size(), dnums.update_window_dims()); return PropagateShardingAlongDimsAndReplicateOthers( index_sharding, indices_update_dims.indices_dims, indices_update_dims.output_dims, - scatter->scatter_updates()[0]->shape().dimensions_size()); + scatter->scatter_updates()[0]->shape().dimensions().size()); } HloSharding ScatterEffectiveIndexSharding( @@ -1414,7 +1417,7 @@ HloSharding ScatterEffectiveIndexSharding( const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers(); int64_t num_elements = 1; int64_t index_dim = 0; - for (int64_t i = 0; i < scatter.shape().dimensions_size(); ++i) { + for (int64_t i = 0; i < scatter.shape().dimensions().size(); ++i) { if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { num_elements *= index_sharding.tile_assignment().dim(index_dim); index_dim++; @@ -1435,7 +1438,7 @@ HloSharding ScatterEffectiveIndexSharding( } const int64_t index_rank = - scatter.scatter_indices()->shape().dimensions_size(); + scatter.scatter_indices()->shape().dimensions().size(); DimensionVector slice_starts(index_rank, 0LL), slice_limits(index_rank); for (int64_t i = 0; i < index_rank; ++i) { if (i < index_dim) { @@ -1458,10 +1461,10 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers(); const int64_t data_rank = - scatter.scatter_updates()[0]->shape().dimensions_size(); + scatter.scatter_updates()[0]->shape().dimensions().size(); DimensionVector tile_assignment_dims(data_rank, 1LL); int64_t num_elements = 1; - for (int64_t i = 0; i < scatter.shape().dimensions_size(); ++i) { + for (int64_t i = 0; i < scatter.shape().dimensions().size(); ++i) { if 
(absl::c_binary_search(dnums.inserted_window_dims(), i)) { CHECK_LT(i, data_rank); tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i); @@ -1507,7 +1510,7 @@ GatherScatterDims GetGatherScatterOperandPassthroughDims( CHECK(absl::c_is_sorted(offset_or_window_dims)); int64_t collapsed_or_batching = 0; - for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < operand_shape.dimensions().size(); ++i) { if (IsCollapsedOrBatchingDim(collapsed_or_inserted_dims, operand_batching_dims, i)) { collapsed_or_batching++; @@ -1570,7 +1573,7 @@ std::optional PassthroughGatherOutputOrScatterUpdateToOperand( offset_or_window_dims, slice_size); HloSharding result = PropagateShardingAlongDimsAndReplicateOthers( output_or_update_sharding, operand_passthrough_dims.output_dims, - operand_passthrough_dims.operand_dims, operand_shape.dimensions_size()); + operand_passthrough_dims.operand_dims, operand_shape.dimensions().size()); if (result.IsTileMaximal()) { return std::nullopt; } @@ -1607,7 +1610,7 @@ std::optional GatherOperandShardingFromOutputParallelDimensions( return PropagateShardingAlongDimsAndReplicateOthers( output_sharding, parallel_dims.output_dims, parallel_dims.operand_dims, - gather.operand(0)->shape().dimensions_size()); + gather.operand(0)->shape().dimensions().size()); } } // namespace @@ -1625,7 +1628,7 @@ GatherOutputShardingFromOperandOperandPassthroughDimensions( const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.gather_dimension_numbers(); return PassthroughOperandToGatherOutputOrScatterUpdate( - operand_shape, operand_sharding, hlo.shape().dimensions_size(), + operand_shape, operand_sharding, hlo.shape().dimensions().size(), dnums.collapsed_slice_dims(), dnums.operand_batching_dims(), dnums.offset_dims(), slice_sizes); } @@ -1666,9 +1669,9 @@ std::optional GatherOperandShardingFromOutput( std::vector GetScatterSliceSize(const Shape& operand_shape, const Shape& update_shape, const ScatterDimensionNumbers& dnums) { - std::vector slice_size(operand_shape.dimensions_size(), 1); + std::vector slice_size(operand_shape.dimensions().size(), 1); int64_t num_update_window_dims = 0; - for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < operand_shape.dimensions().size(); ++i) { if (IsCollapsedOrBatchingDim(dnums.inserted_window_dims(), dnums.input_batching_dims(), i)) { continue; @@ -1745,7 +1748,7 @@ ScatterUpdateShardingFromOutputOperandPassthroughDimensions( const auto& dnums = scatter->scatter_dimension_numbers(); return PassthroughOperandToGatherOutputOrScatterUpdate( output_shape, output_sharding, - scatter->scatter_updates()[0]->shape().dimensions_size(), + scatter->scatter_updates()[0]->shape().dimensions().size(), dnums.inserted_window_dims(), dnums.input_batching_dims(), dnums.update_window_dims(), slice_sizes); } @@ -1780,7 +1783,7 @@ std::optional ScatterUpdateShardingFromOutputParallelDimensions( return PropagateShardingAlongDimsAndReplicateOthers( output_sharding, parallel_dims.operand_dims, parallel_dims.output_dims, - scatter.scatter_updates()[0]->shape().dimensions_size()); + scatter.scatter_updates()[0]->shape().dimensions().size()); } absl::StatusOr, HloOpcode>> @@ -2065,7 +2068,8 @@ std::optional GetDimensionForIota(const HloInstruction* maybe_iota, return std::nullopt; } if (maybe_iota->IsConstant()) { - std::vector is_iota_dim(maybe_iota->shape().dimensions_size(), true); + std::vector is_iota_dim(maybe_iota->shape().dimensions().size(), + true); 
maybe_iota->literal().EachCell( [&](absl::Span indices, int32_t val) { for (int64_t i = 0; i < indices.size(); ++i) { @@ -2190,7 +2194,7 @@ std::optional GetGatherScatterBatchParallelDims( int concatenated_dims = 0; for (const HloInstruction* op : indices->operands()) { const int64_t num_indices_from_element = - op->shape().dimensions_size() > index_vector_dim + op->shape().dimensions().size() > index_vector_dim ? op->shape().dimensions(index_vector_dim) : 1; if (std::optional maybe_iota_dim = @@ -2208,7 +2212,7 @@ std::optional GetGatherScatterBatchParallelDims( if (*maybe_iota_dim != index_vector_dim) { // This is a case of a single iota with index_dim being out of bounds. const int64_t num_indices_from_element = - indices->shape().dimensions_size() > index_vector_dim + indices->shape().dimensions().size() > index_vector_dim ? indices->shape().dimensions(index_vector_dim) : 1; index_parallel_in_dim.assign(num_indices_from_element, *maybe_iota_dim); @@ -2333,9 +2337,9 @@ GatherScatterDims GetGatherScatterIndexPassThroughDims( std::back_inserter(excluded_indices_dims)); } return GetGatherConnectedDimsAcrossIndicesAndOutput( - gather->operand(1)->shape().dimensions_size(), dnums.index_vector_dim(), - hlo.shape().dimensions_size(), dnums.offset_dims(), - excluded_indices_dims); + gather->operand(1)->shape().dimensions().size(), + dnums.index_vector_dim(), hlo.shape().dimensions().size(), + dnums.offset_dims(), excluded_indices_dims); } if (const auto* scatter = DynCast(&hlo)) { @@ -2349,9 +2353,9 @@ GatherScatterDims GetGatherScatterIndexPassThroughDims( std::back_inserter(excluded_indices_dims)); } return GetGatherConnectedDimsAcrossIndicesAndOutput( - scatter->scatter_indices()->shape().dimensions_size(), + scatter->scatter_indices()->shape().dimensions().size(), dnums.index_vector_dim(), - scatter->scatter_updates()[0]->shape().dimensions_size(), + scatter->scatter_updates()[0]->shape().dimensions().size(), dnums.update_window_dims(), excluded_indices_dims); } @@ -2364,7 +2368,7 @@ HloSharding InferGatherScatterParallelShardingFromOperandSharding( absl::Span output_parallel_dims) { return PropagateShardingAlongDimsAndReplicateOthers( operand_sharding, output_aligned_operand_parallel_dims, - output_parallel_dims, shape.dimensions_size()); + output_parallel_dims, shape.dimensions().size()); } std::string GroupedSharding::ToString() const { @@ -2933,14 +2937,14 @@ std::shared_ptr CreateTupleSharding( std::optional GetFirstTargetDimToMoveShardingTiles( const Shape& shape, const HloSharding& sharding, int64_t source_dim, std::function can_be_target_dim) { - if (shape.dimensions_size() < 2 || shape.dimensions(source_dim) == 1) { + if (shape.dimensions().size() < 2 || shape.dimensions(source_dim) == 1) { return std::nullopt; } if (!sharding.IsTiled() || sharding.tile_assignment().dim(source_dim) == 1) { return std::nullopt; } - for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) { + for (int64_t dim = 0; dim < shape.dimensions().size(); ++dim) { if (dim == source_dim) { continue; } @@ -3004,7 +3008,7 @@ Shape UntileLeafShape(const HloSharding& sharding, const Shape& shape) { Shape result_shape = shape; // sharding.TiledDataRank() == i < shape.dimensions_size() is not always true? 
for (int64_t i = 0; - i < sharding.TiledDataRank() && i < shape.dimensions_size(); ++i) { + i < sharding.TiledDataRank() && i < shape.dimensions().size(); ++i) { result_shape.set_dimensions( i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); } @@ -3039,7 +3043,7 @@ Shape TileLeafShape(const HloSharding& sharding, const Shape& shape) { } Shape result_shape = shape; for (int64_t i = 0; - i < sharding.TiledDataRank() && i < shape.dimensions_size(); ++i) { + i < sharding.TiledDataRank() && i < shape.dimensions().size(); ++i) { CHECK_EQ(shape.dimensions(i) % sharding.tile_assignment().dim(i), 0); result_shape.set_dimensions( i, shape.dimensions(i) / sharding.tile_assignment().dim(i)); From f1320d202edf1ef19366f5e325087346770d152b Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Mon, 14 Apr 2025 06:01:42 -0700 Subject: [PATCH 0698/1324] Move `RaggedAllToAllKernel` behind `GpuKernelRegistry`. * Moves `RaggedAllToAll` logic into `backends/gpu/runtime` since it's a runtime component. * Defines trait for the `RaggedAllToAll` kernel in `stream_executor/gpu/` * Moves the implementations of this kernel into stream_executor/{cuda|rocm} and registers them with the registry for each supported type. * Makes `RaggedAllToAll` retrieve the kernel by using the kernel registry. * Add the kernel implementations as dependencies to the `all_runtime` targets for CUDA and ROCm. PiperOrigin-RevId: 747382332 --- .../xla/xla/backends/gpu/runtime/BUILD | 52 +++++++++- .../gpu/runtime/ragged_all_to_all.cc} | 96 ++++++++++++------- .../gpu/runtime/ragged_all_to_all.h} | 6 +- .../gpu/runtime/ragged_all_to_all_test.cc} | 2 +- .../gpu/runtime/ragged_all_to_all_thunk.cc | 2 +- third_party/xla/xla/service/gpu/kernels/BUILD | 61 ------------ .../xla/xla/stream_executor/cuda/BUILD | 22 +++++ .../cuda/ragged_all_to_all_kernel_cuda.cc | 41 ++++++++ third_party/xla/xla/stream_executor/gpu/BUILD | 14 ++- .../gpu/ragged_all_to_all_kernel.h | 43 +++++++++ .../gpu/ragged_all_to_all_kernel_lib.cu.h} | 27 ++---- .../xla/xla/stream_executor/rocm/BUILD | 22 +++++ .../rocm/ragged_all_to_all_kernel_rocm.cc | 41 ++++++++ 13 files changed, 306 insertions(+), 123 deletions(-) rename third_party/xla/xla/{service/gpu/kernels/ragged_all_to_all_kernel.cc => backends/gpu/runtime/ragged_all_to_all.cc} (53%) rename third_party/xla/xla/{service/gpu/kernels/ragged_all_to_all_kernel.h => backends/gpu/runtime/ragged_all_to_all.h} (92%) rename third_party/xla/xla/{service/gpu/kernels/ragged_all_to_all_kernel_test.cc => backends/gpu/runtime/ragged_all_to_all_test.cc} (98%) create mode 100644 third_party/xla/xla/stream_executor/cuda/ragged_all_to_all_kernel_cuda.cc create mode 100644 third_party/xla/xla/stream_executor/gpu/ragged_all_to_all_kernel.h rename third_party/xla/xla/{service/gpu/kernels/ragged_all_to_all_kernel.cu.cc => stream_executor/gpu/ragged_all_to_all_kernel_lib.cu.h} (82%) create mode 100644 third_party/xla/xla/stream_executor/rocm/ragged_all_to_all_kernel_rocm.cc diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 409717e1569e60..4057f2d192f699 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -798,6 +798,7 @@ cc_library( tags = ["gpu"], deps = [ ":collective_thunk", + ":ragged_all_to_all", ":thunk", "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -808,7 +809,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service:rendezvous", - 
"//xla/service/gpu/kernels:ragged_all_to_all_kernel", "//xla/service/gpu/transforms/collectives:collective_ops_utils", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_handle", @@ -1426,3 +1426,53 @@ xla_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "ragged_all_to_all", + srcs = ["ragged_all_to_all.cc"], + hdrs = ["ragged_all_to_all.h"], + deps = [ + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "//xla/stream_executor/gpu:ragged_all_to_all_kernel", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "ragged_all_to_all_test", + srcs = ["ragged_all_to_all_test.cc"], + backends = ["gpu"], + disabled_backends = [], + deps = [ + ":ragged_all_to_all", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor/gpu:gpu_init", + "//xla/stream_executor/gpu:ragged_all_to_all_kernel", + "//xla/stream_executor/host:host_platform", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + ], +) diff --git a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel.cc b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.cc similarity index 53% rename from third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel.cc rename to third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.cc index 838767f89bf901..39c29d0a690df5 100644 --- a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel.cc +++ b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/kernels/ragged_all_to_all_kernel.h" - #include #include #include @@ -22,15 +20,15 @@ limitations under the License. 
#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/kernels/ragged_all_to_all_kernel_common.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/gpu/gpu_kernel_registry.h" +#include "xla/stream_executor/gpu/ragged_all_to_all_kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/typed_kernel_factory.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -39,27 +37,37 @@ namespace xla::gpu { namespace { -void* GetKernel(PrimitiveType element_type) { - switch (primitive_util::BitWidth(element_type)) { - case 8: - return GetRaggedAllToAllKernel(); - case 16: - return GetRaggedAllToAllKernel(); - case 32: - return GetRaggedAllToAllKernel(); - case 64: - return GetRaggedAllToAllKernel(); - default: - return nullptr; - } +template +absl::Status LaunchTypedKernel( + se::Stream* stream, se::StreamExecutor* executor, + const se::ThreadDim& thread_dims, const se::BlockDim& block_dims, + se::DeviceMemoryBase input_buffer, + const std::array& + output_ptrs, + se::DeviceMemoryBase input_offsets_buffer, + se::DeviceMemoryBase send_sizes_buffer, + se::DeviceMemoryBase output_offsets_buffer, int64_t num_updates_per_output, + int64_t num_row_elements) { + TF_ASSIGN_OR_RETURN( + auto kernel, se::gpu::GpuKernelRegistry::GetGlobalRegistry() + .LoadKernel>(executor)); + + return kernel.Launch(thread_dims, block_dims, stream, input_buffer, + output_ptrs, input_offsets_buffer, send_sizes_buffer, + output_offsets_buffer, num_updates_per_output, + num_row_elements); } } // namespace bool IsRaggedAllToAllKernelSupported(int64_t num_outputs, PrimitiveType element_type) { - return num_outputs <= kMaxNumRaggedAllToAllOutputPtrs && - GetKernel(element_type) != nullptr; + int bit_width = primitive_util::BitWidth(element_type); + + return num_outputs <= stream_executor::gpu::kMaxNumRaggedAllToAllOutputPtrs && + (bit_width == 8 || bit_width == 16 || bit_width == 32 || + bit_width == 64); } absl::Status RunRaggedAllToAllKernel( @@ -71,7 +79,8 @@ absl::Status RunRaggedAllToAllKernel( se::DeviceMemoryBase output_offsets_buffer, int64_t num_outputs, int64_t num_updates_per_output, int64_t num_input_rows, int64_t num_row_elements) { - if (output_buffers.size() > kMaxNumRaggedAllToAllOutputPtrs) { + if (output_buffers.size() > + stream_executor::gpu::kMaxNumRaggedAllToAllOutputPtrs) { return absl::InvalidArgumentError( "Number of output pointers exceeds the maximum supported number of " "output pointers."); @@ -93,25 +102,38 @@ absl::Status RunRaggedAllToAllKernel( std::min(CeilOfRatio(num_input_rows * num_row_elements, kThreads), kMaxBlocksPerUpdate); - TF_ASSIGN_OR_RETURN( - auto kernel, - (se::TypedKernelFactory< - se::DeviceMemoryBase, - std::array, - se::DeviceMemoryBase, se::DeviceMemoryBase, se::DeviceMemoryBase, - int64_t, int64_t>::Create(executor, "ragged_all_to_all", - GetKernel(element_type)))); - - std::array output_ptrs; + se::ThreadDim thread_dims(kThreads, 1, 1); + se::BlockDim block_dims(num_blocks_x, num_blocks_y, 1); + + std::array + output_ptrs; for (int64_t i = 0; i < output_buffers.size(); ++i) { output_ptrs[i] = output_buffers[i].opaque(); } - return kernel.Launch(se::ThreadDim(kThreads, 1, 1), - se::BlockDim(num_blocks_x, num_blocks_y, 1), stream, - 
input_buffer, output_ptrs, input_offsets_buffer, - send_sizes_buffer, output_offsets_buffer, - num_updates_per_output, num_row_elements); -} + auto launch_kernel = [&](auto type) -> absl::Status { + using T = decltype(type); + return LaunchTypedKernel(stream, executor, thread_dims, block_dims, + input_buffer, output_ptrs, input_offsets_buffer, + send_sizes_buffer, output_offsets_buffer, + num_updates_per_output, num_row_elements); + }; + switch (xla::primitive_util::BitWidth(element_type)) { + case 8: + return launch_kernel(uint8_t{}); + case 16: + return launch_kernel(uint16_t{}); + case 32: + return launch_kernel(uint32_t{}); + case 64: + return launch_kernel(uint64_t{}); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported element type: ", + primitive_util::LowercasePrimitiveTypeName(element_type), + " (bit width ", xla::primitive_util::BitWidth(element_type), + ") for RaggedAllToAll kernel.")); + } +} } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel.h b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.h similarity index 92% rename from third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel.h rename to third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.h index ed89d80f185ca4..408816b158968d 100644 --- a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel.h +++ b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_KERNELS_RAGGED_ALL_TO_ALL_KERNEL_H_ -#define XLA_SERVICE_GPU_KERNELS_RAGGED_ALL_TO_ALL_KERNEL_H_ +#ifndef XLA_BACKENDS_GPU_RUNTIME_RAGGED_ALL_TO_ALL_H_ +#define XLA_BACKENDS_GPU_RUNTIME_RAGGED_ALL_TO_ALL_H_ #include @@ -58,4 +58,4 @@ absl::Status RunRaggedAllToAllKernel( int64_t num_row_elements); } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_KERNELS_RAGGED_ALL_TO_ALL_KERNEL_H_ +#endif // XLA_BACKENDS_GPU_RUNTIME_RAGGED_ALL_TO_ALL_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_test.cc b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_test.cc rename to third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_test.cc index c7f9266504e346..fed99cc0e7b202 100644 --- a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/kernels/ragged_all_to_all_kernel.h" +#include "xla/backends/gpu/runtime/ragged_all_to_all.h" #include #include diff --git a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc index 09a70533f4b16f..feaae49dd850ca 100644 --- a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc @@ -37,13 +37,13 @@ limitations under the License. 
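The switch above dispatches on storage bit width rather than on the XLA element type: a value-initialized tag such as uint8_t{} selects the template instantiation, and decltype inside the lambda recovers the type without repeating the long launch argument list four times. The idiom in isolation (LaunchTyped stands in for the LaunchTypedKernel call above):

    template <typename T>
    absl::Status LaunchTyped();  // stand-in for LaunchTypedKernel<T>(...)

    absl::Status Dispatch(int bit_width) {
      auto launch = [&](auto tag) -> absl::Status {
        using T = decltype(tag);  // unsigned storage type carried by the tag
        return LaunchTyped<T>();
      };
      switch (bit_width) {
        case 8:  return launch(uint8_t{});
        case 16: return launch(uint16_t{});
        case 32: return launch(uint32_t{});
        case 64: return launch(uint64_t{});
        default: return absl::InvalidArgumentError("unsupported bit width");
      }
    }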
#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/runtime/collective_thunk.h" +#include "xla/backends/gpu/runtime/ragged_all_to_all.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/kernels/ragged_all_to_all_kernel.h" #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" #include "xla/service/rendezvous.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 4b58886527fbc3..def94e22faef1c 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -322,67 +322,6 @@ xla_test( ], ) -cc_library( - name = "ragged_all_to_all_kernel", - srcs = ["ragged_all_to_all_kernel.cc"], - hdrs = ["ragged_all_to_all_kernel.h"], - tags = ["gpu"], - visibility = [":friends"], - deps = [ - ":ragged_all_to_all_kernel_gpu", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:kernel", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor:stream", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor:typed_kernel_factory", - "//xla/tsl/platform:statusor", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - ], -) - -gpu_kernel_library( - name = "ragged_all_to_all_kernel_gpu", - srcs = ["ragged_all_to_all_kernel.cu.cc"], - hdrs = ["ragged_all_to_all_kernel_common.h"], - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", # build_cleaner: keep - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -xla_test( - name = "ragged_all_to_all_kernel_test", - srcs = ["ragged_all_to_all_kernel_test.cc"], - backends = ["gpu"], - deps = [ - ":ragged_all_to_all_kernel", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_handle", - "//xla/stream_executor:platform", - "//xla/stream_executor:platform_manager", - "//xla/stream_executor:stream", - "//xla/stream_executor/gpu:gpu_init", - "//xla/stream_executor/host:host_platform", - "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:test", - "//xla/tsl/platform:test_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - ], -) - #===--------------------------------------------------------------------------------------------===# # CUTLASS Gemm <-> xla::gpu::kernel::CustomKernel adaptor #===--------------------------------------------------------------------------------------------===# diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index dc80d2397b3a07..8e0d321ebef7c3 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1123,6 +1123,7 @@ cc_library( ":cudnn_plugin", ":cufft_plugin", ":make_batch_pointers_kernel_cuda", + ":ragged_all_to_all_kernel_cuda", "//xla/tsl/cuda:cusolver", "//xla/tsl/cuda:cusparse", "//xla/tsl/cuda:tensorrt_rpath", @@ -2042,3 +2043,24 @@ 
cuda_library( ], alwayslink = 1, ) + +cuda_library( + name = "ragged_all_to_all_kernel_cuda", + srcs = [ + "ragged_all_to_all_kernel_cuda.cc", + "//xla/stream_executor/gpu:ragged_all_to_all_kernel_lib.cu.h", + ], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":cuda_platform_id", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "//xla/stream_executor/gpu:ragged_all_to_all_kernel", + "@local_config_cuda//cuda:cuda_headers", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/stream_executor/cuda/ragged_all_to_all_kernel_cuda.cc b/third_party/xla/xla/stream_executor/cuda/ragged_all_to_all_kernel_cuda.cc new file mode 100644 index 00000000000000..cbb493e43f2b11 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/ragged_all_to_all_kernel_cuda.cc @@ -0,0 +1,41 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/gpu/gpu_kernel_registry.h" +#include "xla/stream_executor/gpu/ragged_all_to_all_kernel.h" +#include "xla/stream_executor/gpu/ragged_all_to_all_kernel_lib.cu.h" + +#define REGISTER_RAGGED_ALL_TO_ALL_KERNEL(TYPE, BITS) \ + GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY( \ + RaggedAllToAllKernelCudaUInt##BITS, \ + stream_executor::gpu::RaggedAllToAllKernel, \ + stream_executor::cuda::kCudaPlatformId, ([] { \ + stream_executor::MultiKernelLoaderSpec spec(7); \ + spec.AddInProcessSymbol( \ + absl::bit_cast( \ + &stream_executor::gpu::RaggedAllToAllKernelImpl), \ + "ragged_all_to_all_kernel_uint" #BITS); \ + return spec; \ + })); + +// Register the kernel for different integer types using the macro +REGISTER_RAGGED_ALL_TO_ALL_KERNEL(uint8_t, 8); +REGISTER_RAGGED_ALL_TO_ALL_KERNEL(uint16_t, 16); +REGISTER_RAGGED_ALL_TO_ALL_KERNEL(uint32_t, 32); +REGISTER_RAGGED_ALL_TO_ALL_KERNEL(uint64_t, 64); diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 71a6ece0220379..59a0f4c680e5f8 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -864,7 +864,10 @@ cc_library( ], ) -exports_files(["buffer_comparator_kernel_lib.cu.h"]) +exports_files([ + "buffer_comparator_kernel_lib.cu.h", + "ragged_all_to_all_kernel_lib.cu.h", +]) cc_library( name = "make_batch_pointers_kernel", @@ -874,3 +877,12 @@ cc_library( "//xla/stream_executor:kernel", ], ) + +cc_library( + name = "ragged_all_to_all_kernel", + hdrs = ["ragged_all_to_all_kernel.h"], + deps = [ + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + ], +) diff --git a/third_party/xla/xla/stream_executor/gpu/ragged_all_to_all_kernel.h b/third_party/xla/xla/stream_executor/gpu/ragged_all_to_all_kernel.h new file mode 100644 index 00000000000000..6e4530b1a421c9 --- /dev/null +++ 
b/third_party/xla/xla/stream_executor/gpu/ragged_all_to_all_kernel.h @@ -0,0 +1,43 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_RAGGED_ALL_TO_ALL_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_GPU_RAGGED_ALL_TO_ALL_KERNEL_H_ + +#include + +#include +#include + +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel.h" + +namespace stream_executor::gpu { +inline constexpr int64_t kMaxNumRaggedAllToAllOutputPtrs = 8; + +// Defines a trait for the RaggedAllToAll kernel that can be used to register +// and look up the kernel in the GPU kernel registry. +template +struct RaggedAllToAllKernel { + using KernelType = stream_executor::TypedKernel< + stream_executor::DeviceMemoryBase, + std::array, + stream_executor::DeviceMemoryBase, stream_executor::DeviceMemoryBase, + stream_executor::DeviceMemoryBase, int64_t, int64_t>; +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_GPU_RAGGED_ALL_TO_ALL_KERNEL_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel.cu.cc b/third_party/xla/xla/stream_executor/gpu/ragged_all_to_all_kernel_lib.cu.h similarity index 82% rename from third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel.cu.cc rename to third_party/xla/xla/stream_executor/gpu/ragged_all_to_all_kernel_lib.cu.h index e3a0b201ede5b2..3944b4fe433d24 100644 --- a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel.cu.cc +++ b/third_party/xla/xla/stream_executor/gpu/ragged_all_to_all_kernel_lib.cu.h @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef XLA_STREAM_EXECUTOR_GPU_RAGGED_ALL_TO_ALL_KERNEL_LIB_CU_H_ +#define XLA_STREAM_EXECUTOR_GPU_RAGGED_ALL_TO_ALL_KERNEL_LIB_CU_H_ + #include #include -#include "xla/service/gpu/kernels/ragged_all_to_all_kernel_common.h" +#include "xla/stream_executor/gpu/ragged_all_to_all_kernel.h" -namespace xla::gpu { -namespace { +namespace stream_executor::gpu { // RaggedAllToAll instruction performs a collective AllToAll operation on ragged // tensors. 
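// (Aside on wiring: the RaggedAllToAllKernel trait in the header above only
// declares the TypedKernel signature; each platform registers a concrete
// device symbol against it, as the CUDA and ROCm files in this patch do. A
// hedged sketch of one registration, following the macro those files define,
// with the template arguments spelled out:
//
//   GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY(
//       RaggedAllToAllKernelCudaUInt32,
//       stream_executor::gpu::RaggedAllToAllKernel<uint32_t>,
//       stream_executor::cuda::kCudaPlatformId, ([] {
//         stream_executor::MultiKernelLoaderSpec spec(7);
//         spec.AddInProcessSymbol(
//             absl::bit_cast<void*>(
//                 &stream_executor::gpu::RaggedAllToAllKernelImpl<uint32_t>),
//             "ragged_all_to_all_kernel_uint32");
//         return spec;
//       })); )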
For the semantics of each operand see the documentation of @@ -49,7 +51,7 @@ namespace { // - Block grid: (N*num_updates_per_rank, num_blocks_per_update, 1) // - Thread grid: (num_threads_per_update, 1, 1) template -__global__ void __launch_bounds__(128) RaggedAllToAllKernel( +__global__ void __launch_bounds__(128) RaggedAllToAllKernelImpl( const T* __restrict__ input_ptr, std::array output_ptrs, const int64_t* __restrict__ input_offsets_ptr, @@ -59,7 +61,7 @@ __global__ void __launch_bounds__(128) RaggedAllToAllKernel( int64_t update_idx = blockIdx.x; int64_t output_idx = update_idx / num_updates_per_replica; - T* output_ptr = reinterpret_cast(output_ptrs[output_idx]); + T* output_ptr = absl::bit_cast(output_ptrs[output_idx]); int64_t input_offset = input_offsets_ptr[update_idx]; int64_t send_size = send_sizes_ptr[update_idx]; @@ -75,17 +77,6 @@ __global__ void __launch_bounds__(128) RaggedAllToAllKernel( output_ptr[output_offset_start + i] = input_ptr[input_offset_start + i]; } } +} // namespace stream_executor::gpu -} // namespace - -template -void* GetRaggedAllToAllKernel() { - return reinterpret_cast(&RaggedAllToAllKernel); -} - -template void* GetRaggedAllToAllKernel(); -template void* GetRaggedAllToAllKernel(); -template void* GetRaggedAllToAllKernel(); -template void* GetRaggedAllToAllKernel(); - -} // namespace xla::gpu +#endif // XLA_STREAM_EXECUTOR_GPU_RAGGED_ALL_TO_ALL_KERNEL_LIB_CU_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index d42871f9cb2b95..84f2f9551f698a 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -828,6 +828,7 @@ cc_library( ":hipfft_plugin", ":make_batch_pointers_kernel_rocm", ":miopen_plugin", + ":ragged_all_to_all_kernel_rocm", ":rocblas_plugin", ":rocm_helpers", ":rocm_platform", @@ -1121,3 +1122,24 @@ rocm_library( ], alwayslink = 1, ) + +rocm_library( + name = "ragged_all_to_all_kernel_rocm", + srcs = [ + "ragged_all_to_all_kernel_rocm.cc", + "//xla/stream_executor/gpu:ragged_all_to_all_kernel_lib.cu.h", + ], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "gpu", + "rocm-only", + ], + deps = [ + ":rocm_platform_id", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "//xla/stream_executor/gpu:ragged_all_to_all_kernel", + "@local_config_rocm//rocm:rocm_headers", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/stream_executor/rocm/ragged_all_to_all_kernel_rocm.cc b/third_party/xla/xla/stream_executor/rocm/ragged_all_to_all_kernel_rocm.cc new file mode 100644 index 00000000000000..a3c29dd9bf1c49 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/ragged_all_to_all_kernel_rocm.cc @@ -0,0 +1,41 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "xla/stream_executor/gpu/gpu_kernel_registry.h" +#include "xla/stream_executor/gpu/ragged_all_to_all_kernel.h" +#include "xla/stream_executor/gpu/ragged_all_to_all_kernel_lib.cu.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" + +#define REGISTER_RAGGED_ALL_TO_ALL_KERNEL(TYPE, BITS) \ + GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY( \ + RaggedAllToAllKernelRocmUInt##BITS, \ + stream_executor::gpu::RaggedAllToAllKernel, \ + stream_executor::rocm::kROCmPlatformId, ([] { \ + stream_executor::MultiKernelLoaderSpec spec(7); \ + spec.AddInProcessSymbol( \ + absl::bit_cast( \ + &stream_executor::gpu::RaggedAllToAllKernelImpl), \ + "ragged_all_to_all_kernel_uint" #BITS); \ + return spec; \ + })); + +// Register the kernel for different integer types using the macro +REGISTER_RAGGED_ALL_TO_ALL_KERNEL(uint8_t, 8); +REGISTER_RAGGED_ALL_TO_ALL_KERNEL(uint16_t, 16); +REGISTER_RAGGED_ALL_TO_ALL_KERNEL(uint32_t, 32); +REGISTER_RAGGED_ALL_TO_ALL_KERNEL(uint64_t, 64); From 2586933beebdfb83af4ffaa5846bd05607439c05 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 14 Apr 2025 06:22:50 -0700 Subject: [PATCH 0699/1324] [XLA:GPU] Fix flag parsing for a newly added unsupported flag When adding the flag definition, I used copy/paste and forgot to adjust 2 lines. PiperOrigin-RevId: 747388785 --- third_party/xla/xla/debug_options_flags.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 7bd5f886de4ee5..6f5e06a6e65277 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -1779,9 +1779,8 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "xla_gpu_unsupported_enable_triton_multi_output_fusion", bool_setter_for( &DebugOptions:: - set_xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms), - debug_options - ->xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms(), + set_xla_gpu_unsupported_enable_triton_multi_output_fusion), + debug_options->xla_gpu_unsupported_enable_triton_multi_output_fusion(), "Enable Triton multi-output fusions.")); flag_list->push_back(tsl::Flag( "xla_gpu_verify_triton_fusion_numerics", From db2f9667f156c0e2844b6837deae6c01c1588f36 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 14 Apr 2025 07:05:59 -0700 Subject: [PATCH 0700/1324] [xla:gpu] NFC: Remove `builder` argument from some functions that don't need it. 
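StorageType and the pointer-type helpers only inspect or rebuild MLIR types, and an mlir::Type already carries its MLIRContext, so the builder parameter was dead weight. The builder-free shape of the helper, as the first hunk below spells out (template arguments restored here for readability):

    // i1 is widened to i8 for storage; the context now comes from the type
    // itself rather than from an op builder.
    mlir::Type StorageType(mlir::Type t) {
      if (auto i = mlir::dyn_cast<mlir::IntegerType>(t);
          i && i.getWidth() == 1) {
        return mlir::IntegerType::get(i.getContext(), 8, i.getSignedness());
      }
      return t;
    }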
PiperOrigin-RevId: 747401649 --- .../gpu/codegen/triton/emitter_helpers.cc | 6 +-- .../gpu/codegen/triton/emitter_helpers.h | 2 +- .../gpu/codegen/triton/fusion_emitter.cc | 24 +++++------ ...riton_xla_extract_insert_to_triton_pass.cc | 43 ++++++++----------- 4 files changed, 34 insertions(+), 41 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc index f7124577c5a061..fce076ae44fbc2 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc @@ -126,9 +126,9 @@ absl::StatusOr GetPrimitiveType(Type t) { return absl::UnimplementedError("Unsupported type in getPrimitiveType.\n"); } -Type StorageType(EmitterLocOpBuilder& b, Type t) { - if (t.isInteger(1)) { - return b.getI8Type(); +Type StorageType(Type t) { + if (auto i = mlir::dyn_cast(t); i && i.getWidth() == 1) { + return i.get(i.getContext(), 8, i.getSignedness()); } return t; } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h index 093e6cfdaea680..9217ed4b9e53ba 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h @@ -106,7 +106,7 @@ absl::StatusOr TritonType(EmitterLocOpBuilder& b, PrimitiveType t); // Triton type -> XLA type conversions. absl::StatusOr GetPrimitiveType(mlir::Type t); -mlir::Type StorageType(EmitterLocOpBuilder& b, mlir::Type t); +mlir::Type StorageType(mlir::Type t); // Get the value of the scalar constant's literal in a C++ type. template diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 1983d45ecdf671..f93698bd992c99 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -207,7 +207,7 @@ ScalarOrTensor EmitParameterExtract(EmitterLocOpBuilder& b, return ScalarOrTensor(b.create( mlir::RankedTensorType::get( tiled_tensor_type.getTileShape(), - StorageType(b, tiled_tensor_type.getElementType())), + StorageType(tiled_tensor_type.getElementType())), tile_op.getResult(), offsets)); } @@ -1120,7 +1120,7 @@ absl::StatusOr EmitTiledHloInstruction( if (expected_element_type != loaded_element_type) { // Ensure that we didn't mess up somewhere else by checking that we // indeed loaded the expected storage type for the expected element type. - if (loaded_element_type != StorageType(b, expected_element_type)) { + if (loaded_element_type != StorageType(expected_element_type)) { return absl::InternalError(absl::StrCat( "Parameters were loaded with an unexpected element type " "while lowering ", @@ -1363,7 +1363,7 @@ absl::StatusOr CreateTileOp( const Shape& shape = tiled_hlo.hlo()->shape(); TF_ASSIGN_OR_RETURN(Type expected_element_type, TritonType(b, shape.element_type())); - Type storage_type = StorageType(b, expected_element_type); + Type storage_type = StorageType(expected_element_type); auto result_type = mtx::TiledTensorType::get( b.getContext(), padded_tile_sizes, llvm::ArrayRef(shape.dimensions().data(), @@ -1441,7 +1441,7 @@ absl::StatusOr> EmitGeneric( // as i8. It's important to check converted types before storing if the type // of the result does not match the type of the output pointer. 
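  // Concretely: only i1 differs from its storage type (it is widened to i8),
  // so the conversion step below fires for boolean-producing fusions and is a
  // no-op for every other element type.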
Type result_element_type = getElementTypeOrSelf(result.getType()); - Type result_storage_type = StorageType(b, result_element_type); + Type result_storage_type = StorageType(result_element_type); if (result_element_type != result_storage_type) { result = @@ -1539,29 +1539,29 @@ absl::Status CreateInternalError(absl::string_view message, // Legacy emitter works with tt.func. New emitter works with func.func. // TODO(393299275): Remove legacy optionality once migration is complete. -void AppendFuncArgType(EmitterLocOpBuilder& b, absl::Span dims, +void AppendFuncArgType(absl::Span dims, absl::string_view fusion_kind, Type ir_type, SmallVector& fn_arg_types) { if (fusion_kind == kTritonGemmFusionKind) { fn_arg_types.push_back(ttir::PointerType::get( - StorageType(b, ir_type), mlir::NVVM::kGlobalMemorySpace)); + StorageType(ir_type), mlir::NVVM::kGlobalMemorySpace)); } else { fn_arg_types.push_back(mlir::RankedTensorType::get( llvm::ArrayRef(dims.data(), dims.size()), - StorageType(b, ir_type))); + StorageType(ir_type))); } } // Only needed for the new emitter since we are using func.func instead of // tt.func. // TODO(393299275): Remove legacy optionality once migration is complete. -void AppendFuncResultType(EmitterLocOpBuilder& b, absl::string_view fusion_kind, +void AppendFuncResultType(absl::string_view fusion_kind, absl::Span dims, Type ir_type, SmallVector& fn_result_types) { if (fusion_kind != kTritonGemmFusionKind) { fn_result_types.push_back(mlir::RankedTensorType::get( llvm::ArrayRef(dims.data(), dims.size()), - StorageType(b, ir_type))); + StorageType(ir_type))); } } @@ -1662,7 +1662,7 @@ absl::StatusOr> CreateTritonModule( TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type)); } - AppendFuncArgType(b, p->shape().dimensions(), fusion_kind, ir_type, + AppendFuncArgType(p->shape().dimensions(), fusion_kind, ir_type, fn_arg_types); } @@ -1671,9 +1671,9 @@ absl::StatusOr> CreateTritonModule( for (const ShapeUtil::IndexedShape& s : ShapeUtil::GetLeafShapes(fusion->shape())) { TF_ASSIGN_OR_RETURN(Type triton_ty, TritonType(b, s.shape.element_type())); - AppendFuncArgType(b, s.shape.dimensions(), fusion_kind, triton_ty, + AppendFuncArgType(s.shape.dimensions(), fusion_kind, triton_ty, fn_arg_types); - AppendFuncResultType(b, fusion_kind, s.shape.dimensions(), triton_ty, + AppendFuncResultType(fusion_kind, s.shape.dimensions(), triton_ty, fn_result_types); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc index 6c8dc01a829725..013589b223e612 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc @@ -70,8 +70,8 @@ namespace { #define GEN_PASS_DEF_TRITONXLAEXTRACTINSERTTOTRITONPASS #include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc" -PointerType GetTensorPtrType(::xla::EmitterLocOpBuilder& builder, Type type) { - return PointerType::get(xgt::StorageType(builder, type), +PointerType GetTensorPtrType(Type type) { + return PointerType::get(xgt::StorageType(type), mlir::NVVM::kGlobalMemorySpace); } @@ -79,7 +79,7 @@ PointerType GetTensorPtrTypeForTma(::xla::EmitterLocOpBuilder& builder) { // Triton frontend is passing zero in the address space. This doesn't map to // anything meaningful in NVVM dialect. Setting it to be consistent with // Triton. 
- return PointerType::get(xgt::StorageType(builder, builder.getI8Type()), + return PointerType::get(xgt::StorageType(builder.getI8Type()), /*addrspace=*/0); } @@ -88,10 +88,9 @@ TensorDescType GetTensorDescPtrType(::xla::EmitterLocOpBuilder& builder, return TensorDescType::get(builder.getContext(), type); } -RankedTensorType GetRankedTensorType(::xla::EmitterLocOpBuilder& builder, - TiledTensorType type) { - return RankedTensorType::get( - type.getTileShape(), xgt::StorageType(builder, type.getElementType())); +RankedTensorType GetRankedTensorType(TiledTensorType type) { + return RankedTensorType::get(type.getTileShape(), + xgt::StorageType(type.getElementType())); } bool AreRankedTensors(ArrayRef types) { @@ -294,7 +293,7 @@ struct RewriteFuncOp : mlir::OpRewritePattern { cast_to_orig_type = builder.create( operand_type, func_arg); operand_type = GetTensorPtrType( - builder, mlir::cast(operand_type).getElementType()); + mlir::cast(operand_type).getElementType()); } func_arg.replaceAllUsesExcept(cast_to_orig_type.getResult(0), cast_to_orig_type); @@ -367,9 +366,9 @@ struct RewriteTile : mlir::OpRewritePattern { // can_use_tma ? "tensor -> !tt.ptr" otherwise "tensor -> !tt.ptr<>" Type ptr_type = - can_use_tma ? GetTensorPtrTypeForTma(builder) - : GetTensorPtrType( - builder, op.getTensor().getType().getElementType()); + can_use_tma + ? GetTensorPtrTypeForTma(builder) + : GetTensorPtrType(op.getTensor().getType().getElementType()); auto cast_to_tensor_ptr_type = builder.create(ptr_type, op.getTensor()); @@ -410,8 +409,7 @@ struct RewriteTile : mlir::OpRewritePattern { // !tt.tensordesc -> tiled_tensor auto cast_desc_ptr_to_tiled_tensor_ptr_type = builder.create( - xgt::StorageType(builder, tiled_tensor_type), - reinterpret_tensor_desc); + xgt::StorageType(tiled_tensor_type), reinterpret_tensor_desc); rewriter.replaceOp(op, cast_desc_ptr_to_tiled_tensor_ptr_type); return mlir::success(); @@ -448,7 +446,7 @@ struct RewriteTile : mlir::OpRewritePattern { // !tt.ptr -> tiled_tensor auto cast_to_tiled_tensor_type = builder.create( - xgt::StorageType(builder, tiled_tensor_type), ptr); + xgt::StorageType(tiled_tensor_type), ptr); rewriter.replaceOp(op, cast_to_tiled_tensor_type); return mlir::success(); @@ -473,8 +471,7 @@ struct RewriteExtract : mlir::OpRewritePattern { builder .create( GetTensorDescPtrType( - builder, - GetRankedTensorType(builder, op.getSrc().getType())), + builder, GetRankedTensorType(op.getSrc().getType())), op.getSrc()) .getResult(0); @@ -492,8 +489,7 @@ struct RewriteExtract : mlir::OpRewritePattern { auto cast_to_tensor_ptr_type = builder .create( - GetTensorPtrType(builder, GetRankedTensorType( - builder, op.getSrc().getType())), + GetTensorPtrType(GetRankedTensorType(op.getSrc().getType())), op.getSrc()) .getResult(0); @@ -532,8 +528,7 @@ struct RewriteInsert : mlir::OpRewritePattern { builder .create( GetTensorDescPtrType( - builder, - GetRankedTensorType(builder, op.getDst().getType())), + builder, GetRankedTensorType(op.getDst().getType())), op.getDst()) .getResult(0); @@ -545,9 +540,7 @@ struct RewriteInsert : mlir::OpRewritePattern { auto cast_dst_to_tensor_ptr_type = builder .create( - GetTensorPtrType( - builder, - GetRankedTensorType(builder, op.getDst().getType())), + GetTensorPtrType(GetRankedTensorType(op.getDst().getType())), op.getDst()) .getResult(0); @@ -582,7 +575,7 @@ struct RewriteScalarInsert : mlir::OpRewritePattern { return rewriter.notifyMatchFailure(op, "Expected dest to be scalar."); } ::xla::EmitterLocOpBuilder builder(op.getLoc(), rewriter); - 
auto ptr_type = GetTensorPtrType(builder, op.getScalar().getType()); + auto ptr_type = GetTensorPtrType(op.getScalar().getType()); auto cast_dst_to_tensor_ptr_type = builder.create(ptr_type, op.getDest()) .getResult(0); @@ -604,7 +597,7 @@ struct RewriteScalarExtract : mlir::OpRewritePattern { return rewriter.notifyMatchFailure(op, "Expected src to be scalar."); } ::xla::EmitterLocOpBuilder builder(op.getLoc(), rewriter); - auto ptr_type = GetTensorPtrType(builder, op.getType()); + auto ptr_type = GetTensorPtrType(op.getType()); auto cast_src_to_tensor_ptr_type = builder .create(ptr_type, op.getTensor()) From 591a780aed798cd7d7445719b624b5b868fff658 Mon Sep 17 00:00:00 2001 From: Will Froom Date: Mon, 14 Apr 2025 07:18:10 -0700 Subject: [PATCH 0701/1324] [XLA:CPU] Create boilerplate required for fusion kernel emitters PiperOrigin-RevId: 747405490 --- .../xla/xla/backends/cpu/testlib/BUILD | 29 +++++++ .../xla/xla/backends/cpu/testlib/__init__.py | 8 ++ .../xla/backends/cpu/testlib/kernel_runner.cc | 29 +++++++ .../xla/backends/cpu/testlib/kernel_runner.h | 7 ++ .../cpu/testlib/kernel_runner_extension.cc | 32 ++++++++ .../cpu/testlib/kernel_runner_test.py | 61 +++++++++++++++ .../cpu/testlib/mlir_kernel_emitter.cc | 76 +++++++++++++++++++ .../cpu/testlib/mlir_kernel_emitter.h | 65 ++++++++++++++++ third_party/xla/xla/codegen/BUILD | 17 +++++ .../xla/xla/codegen/mlir_kernel_source.cc | 59 ++++++++++++++ .../xla/xla/codegen/mlir_kernel_source.h | 64 ++++++++++++++++ third_party/xla/xla/codegen/testlib/BUILD | 2 + .../xla/xla/codegen/testlib/__init__.py | 2 + .../testlib/kernel_runner_extension.cc | 18 +++++ third_party/xla/xla/hlo/ir/hlo_instruction.cc | 2 +- 15 files changed, 470 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.cc create mode 100644 third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.h create mode 100644 third_party/xla/xla/codegen/mlir_kernel_source.cc create mode 100644 third_party/xla/xla/codegen/mlir_kernel_source.h diff --git a/third_party/xla/xla/backends/cpu/testlib/BUILD b/third_party/xla/xla/backends/cpu/testlib/BUILD index ab30cccd80b831..8fd6bd61efd08f 100644 --- a/third_party/xla/xla/backends/cpu/testlib/BUILD +++ b/third_party/xla/xla/backends/cpu/testlib/BUILD @@ -26,6 +26,7 @@ cc_library( "//xla:xla_proto_cc", "//xla/backends/cpu/codegen:cpu_features", "//xla/backends/cpu/codegen:execution_engine", + "//xla/backends/cpu/codegen:fusion_compiler", "//xla/backends/cpu/codegen:ir_compiler", "//xla/backends/cpu/codegen:jit_compiler", "//xla/backends/cpu/runtime:function_library", @@ -34,6 +35,7 @@ cc_library( "//xla/codegen:kernel_definition", "//xla/codegen:kernel_spec", "//xla/codegen:llvm_ir_kernel_source", + "//xla/codegen:mlir_kernel_source", "//xla/codegen/testlib:kernel_runner", "//xla/service:hlo_module_config", "//xla/service/cpu:cpu_options", @@ -99,6 +101,28 @@ cc_library( ], ) +cc_library( + name = "mlir_kernel_emitter", + srcs = ["mlir_kernel_emitter.cc"], + hdrs = ["mlir_kernel_emitter.h"], + deps = [ + "//xla/backends/cpu/codegen:fusion_compiler", + "//xla/codegen:kernel_definition", + "//xla/codegen:kernel_emitter", + "//xla/codegen:kernel_spec", + "//xla/codegen:mlir_kernel_source", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/stream_executor:launch_dim", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + 
"@llvm-project//mlir:IR", + ], +) + tsl_pybind_extension( name = "_extension", testonly = 1, @@ -107,13 +131,16 @@ tsl_pybind_extension( deps = [ ":kernel_runner", ":llvm_ir_kernel_emitter", + ":mlir_kernel_emitter", # placeholder for index annotation deps # buildcleaner: keep "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", "@nanobind", "@local_config_python//:python_headers", # buildcleaner: keep "//xla/backends/cpu/codegen:computation_kernel_emitter", + "//xla/backends/cpu/codegen:fusion_compiler", "//xla/backends/cpu/codegen:jit_compiler", "//xla/backends/cpu/codegen:target_machine_features", "//xla/backends/cpu/codegen/dot:dot_kernel_emitter", @@ -121,6 +148,8 @@ tsl_pybind_extension( "//xla/backends/cpu/codegen/elemental:elemental_kernel_emitter", "//xla/codegen:kernel_definition", "//xla/codegen:kernel_emitter", + "//xla/codegen:llvm_ir_kernel_source", + "//xla/codegen:mlir_kernel_source", "//xla/codegen/testlib:kernel_runner", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", diff --git a/third_party/xla/xla/backends/cpu/testlib/__init__.py b/third_party/xla/xla/backends/cpu/testlib/__init__.py index bd43b544dee052..10aca5b99a8d52 100644 --- a/third_party/xla/xla/backends/cpu/testlib/__init__.py +++ b/third_party/xla/xla/backends/cpu/testlib/__init__.py @@ -16,6 +16,7 @@ from xla.backends.cpu.testlib import _extension +# Classes. # go/keep-sorted start ComputationKernelEmitter = _extension.ComputationKernelEmitter ConcatenateKernelEmitter = _extension.ConcatenateKernelEmitter @@ -25,5 +26,12 @@ JitCompiler = _extension.JitCompiler KernelRunner = _extension.KernelRunner LlvmIrKernelEmitter = _extension.LlvmIrKernelEmitter +MLIRContext = _extension.MLIRContext +MlirKernelEmitter = _extension.MlirKernelEmitter TargetMachineFeatures = _extension.TargetMachineFeatures # go/keep-sorted end + +# Free functions. +# go/keep-sorted start +lower_to_llvm = _extension.lower_to_llvm +# go/keep-sorted end diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc index 78efebd64e6b81..4c93abbd0eef9e 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc @@ -24,9 +24,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Module.h" #include "llvm/Target/TargetOptions.h" #include "xla/backends/cpu/codegen/cpu_features.h" #include "xla/backends/cpu/codegen/execution_engine.h" +#include "xla/backends/cpu/codegen/fusion_compiler.h" #include "xla/backends/cpu/codegen/ir_compiler.h" #include "xla/backends/cpu/codegen/jit_compiler.h" #include "xla/backends/cpu/runtime/function_library.h" @@ -35,6 +37,7 @@ limitations under the License. 
#include "xla/codegen/kernel_definition.h" #include "xla/codegen/kernel_spec.h" #include "xla/codegen/llvm_ir_kernel_source.h" +#include "xla/codegen/mlir_kernel_source.h" #include "xla/service/cpu/cpu_options.h" #include "xla/service/cpu/runtime_symbol_generator.h" #include "xla/service/hlo_module_config.h" @@ -56,6 +59,10 @@ absl::StatusOr KernelRunner::Create( dynamic_cast(kernel_source.get())) { return Create(kernel_spec, std::move(*llvm_kernel_source), std::move(compiler)); + } else if (auto* mlir_kernel_source = + dynamic_cast(kernel_source.get())) { + return Create(kernel_spec, std::move(*mlir_kernel_source), + std::move(compiler)); } return absl::InvalidArgumentError("Unrecognised kernel spec type"); @@ -79,6 +86,16 @@ absl::StatusOr KernelRunner::Create( return KernelRunner(std::move(library), Kernel(1, kernel_fn), thread_dim); } +absl::StatusOr KernelRunner::Create( + const KernelSpec& kernel_spec, MlirKernelSource mlir_kernel_source, + JitCompiler compiler) { + TF_ASSIGN_OR_RETURN(LlvmIrKernelSource llvm_ir_kernel_source, + LowerToLlvm(mlir_kernel_source)); + + return Create(kernel_spec, std::move(llvm_ir_kernel_source), + std::move(compiler)); +} + KernelRunner::KernelRunner(std::unique_ptr library, Kernel kernel, Kernel::ThreadDim thread_dim) : library_(std::move(library)), @@ -133,4 +150,16 @@ absl::StatusOr KernelRunner::CreateJitCompiler( std::move(ir_compiler)); } +absl::StatusOr LowerToLlvm( + MlirKernelSource& mlir_kernel_source) { + auto llvm_context = std::make_unique(); + + FusionCompiler fusion_compiler(FusionCompiler::Options{}); + TF_ASSIGN_OR_RETURN( + std::unique_ptr llvm_module, + fusion_compiler.Compile(*llvm_context, mlir_kernel_source.module())); + + return LlvmIrKernelSource(std::move(llvm_context), std::move(llvm_module)); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.h b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.h index 702cb13bf53bd4..126a8d9101ea92 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.h +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.h @@ -27,6 +27,7 @@ limitations under the License. #include "xla/codegen/kernel_definition.h" #include "xla/codegen/kernel_spec.h" #include "xla/codegen/llvm_ir_kernel_source.h" +#include "xla/codegen/mlir_kernel_source.h" #include "xla/codegen/testlib/kernel_runner.h" #include "xla/service/hlo_module_config.h" @@ -53,6 +54,9 @@ class KernelRunner final : public xla::KernelRunner { static absl::StatusOr Create( const KernelSpec& kernel_spec, LlvmIrKernelSource llvm_ir_kernel_source, JitCompiler compiler); + static absl::StatusOr Create( + const KernelSpec& kernel_spec, MlirKernelSource mlir_kernel_source, + JitCompiler compiler); KernelRunner(std::unique_ptr library, Kernel kernel, Kernel::ThreadDim thread_dim); @@ -62,6 +66,9 @@ class KernelRunner final : public xla::KernelRunner { Kernel::ThreadDim thread_dim_; }; +absl::StatusOr LowerToLlvm( + MlirKernelSource& mlir_kernel_source); + } // namespace xla::cpu #endif // XLA_BACKENDS_CPU_TESTLIB_KERNEL_RUNNER_H_ diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc index 9cf51d1fb3ad8d..95c00ed2a2c15b 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/tuple.h" // IWYU pragma: keep @@ -32,12 +33,16 @@ limitations under the License. #include "xla/backends/cpu/codegen/dot/dot_kernel_emitter.h" #include "xla/backends/cpu/codegen/elemental/concatenate_kernel_emitter.h" #include "xla/backends/cpu/codegen/elemental/elemental_kernel_emitter.h" +#include "xla/backends/cpu/codegen/fusion_compiler.h" #include "xla/backends/cpu/codegen/jit_compiler.h" #include "xla/backends/cpu/codegen/target_machine_features.h" #include "xla/backends/cpu/testlib/kernel_runner.h" #include "xla/backends/cpu/testlib/llvm_ir_kernel_emitter.h" +#include "xla/backends/cpu/testlib/mlir_kernel_emitter.h" #include "xla/codegen/kernel_definition.h" #include "xla/codegen/kernel_emitter.h" +#include "xla/codegen/llvm_ir_kernel_source.h" +#include "xla/codegen/mlir_kernel_source.h" #include "xla/codegen/testlib/kernel_runner.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -88,6 +93,30 @@ NB_MODULE(_extension, kernel_runner_module) { {}); }); + nb::class_(kernel_runner_module, + "MlirKernelEmitter") + .def("__init__", + [](MlirKernelEmitter* self, absl::string_view ir, + absl::string_view kernel_name, NbThreadDim thread_dim) { + new (self) MlirKernelEmitter( + ir, kernel_name, + se::ThreadDim{std::get<0>(thread_dim), std::get<1>(thread_dim), + std::get<2>(thread_dim)}, + {}); + }); + + kernel_runner_module.def("lower_to_llvm", [](MlirKernelSource& source) { + absl::StatusOr llvm_ir_kernel_source = + LowerToLlvm(source); + + if (!llvm_ir_kernel_source.ok()) { + throw std::runtime_error( + std::string(llvm_ir_kernel_source.status().message())); + } + + return std::move(llvm_ir_kernel_source).value(); + }); + nb::class_(kernel_runner_module, "HloCompiler") .def(nb::init<>()) .def("create_buffer_assignment", @@ -114,6 +143,9 @@ NB_MODULE(_extension, kernel_runner_module) { return std::move(schedule).value(); }); + nb::class_(kernel_runner_module, "MLIRContext") + .def(nb::new_([] { return FusionCompiler::CreateContext(); })); + nb::class_(kernel_runner_module, "TargetMachineFeatures") .def("__str__", &TargetMachineFeatures::get_target_feature_string); diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.py b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.py index 802e3ffa75c498..9689e288478ca8 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.py +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.py @@ -68,5 +68,66 @@ def test_llvm_ir_kernel_runner(self): np.testing.assert_array_equal(np.asarray(c), np.asarray(a) + np.asarray(b)) +class MlirKernelRunnerTest(absltest.TestCase): + + def test_mlir_kernel_runner(self): + ir = """ + #indexing_map = #xla.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1023], s1 in [0, 31]"> + module attributes {dlti.dl_spec = #dlti.dl_spec} { + func.func private + @sum_kernel_entry(%input_buffer: tensor<1024x32xf32>, + %output_buffer: tensor<1xf32>) -> tensor<1xf32> + attributes {xla.backend_kind = #xla.backend_kind, xla.entry} { + // Initial sum set to 0. + %sum_0 = arith.constant 0.0 : f32 + // iter_args binds initial values to the loop's region arguments. 
+        %sum = xla.loop ()[%i, %j] -> (%r0, %r1)
+            in #indexing_map iter_args(%sum_iter = %sum_0) -> (f32) {
+          %t = tensor.extract %input_buffer[%i, %j] : tensor<1024x32xf32>
+          %sum_next = arith.addf %sum_iter, %t : f32
+          // Yield the current iteration sum to the next iteration's %sum_iter,
+          // or to %sum on the final iteration.
+          xla.yield %sum_next : f32
+        }
+
+        // Ideally it would be possible to do this in the kernel region,
+        // but currently our lowering results in xla_cpu.store being removed
+        // before the tensors are lowered, which then results in the insert
+        // being removed, as tensors have value semantics and are treated as
+        // pure.
+        %zero_index = arith.constant 0 : index
+        %inserted = tensor.insert %sum into %output_buffer[%zero_index] : tensor<1xf32>
+
+        return %inserted : tensor<1xf32>
+      }
+      func.func @sum_kernel(%call_frame: !xla_cpu.call_frame) -> !xla_cpu.error {
+        %thread_idx = xla_cpu.thread_id %call_frame : index
+        %input_buffer = xla_cpu.load %call_frame, 0 : tensor<1024x32xf32>
+        %output_buffer = xla_cpu.load %call_frame, 1 : tensor<1xf32>
+        %sum = xla.pure_call @sum_kernel_entry(%input_buffer, %output_buffer)
+            {noinline} : (tensor<1024x32xf32>, tensor<1xf32>) -> tensor<1xf32>
+        xla_cpu.store %sum into %call_frame, 1 : tensor<1xf32>
+        %success = xla_cpu.success : !xla_cpu.error
+        return %success : !xla_cpu.error
+      }
+    }
+    """
+    mlir_emitter = cpu_testlib.MlirKernelEmitter(ir, "sum_kernel", (1, 1, 1))
+
+    kernel_definition = mlir_emitter.emit_kernel_definition()
+
+    runner = cpu_testlib.KernelRunner.create(
+        kernel_definition,
+        cpu_testlib.JitCompiler(base_testlib.HloModuleConfig()),
+    )
+    input_tensor = create_literal(np.ones([1024, 32], dtype=np.float32))
+    output_sum = create_literal(np.zeros([1], dtype=np.float32))
+    runner.call([input_tensor, output_sum])
+
+    np.testing.assert_array_equal(
+        np.asarray(output_sum).item(), np.asarray(input_tensor).sum()
+    )
+
+
 if __name__ == "__main__":
   absltest.main()
diff --git a/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.cc b/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.cc
new file mode 100644
index 00000000000000..b13008d834bc98
--- /dev/null
+++ b/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.cc
@@ -0,0 +1,76 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/cpu/testlib/mlir_kernel_emitter.h"
+
+#include <memory>
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/backends/cpu/codegen/fusion_compiler.h"
+#include "xla/codegen/kernel_definition.h"
+#include "xla/codegen/kernel_spec.h"
+#include "xla/codegen/mlir_kernel_source.h"
+#include "xla/runtime/buffer_use.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/stream_executor/launch_dim.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla::cpu {
+MlirKernelEmitter::MlirKernelEmitter(absl::string_view mlir,
+                                     absl::string_view kernel_name,
+                                     se::ThreadDim thread_dim,
+                                     absl::Span<const KernelArg> args)
+    : mlir_(mlir),
+      kernel_name_(kernel_name),
+      thread_dim_(thread_dim),
+      args_(args.begin(), args.end()) {
+  for (const MlirKernelEmitter::KernelArg& arg : args_) {
+    buffer_allocations_.emplace_back(buffer_allocations_.size(),
+                                     arg.size_bytes,
+                                     /*color=*/0);
+  }
+}
+
+absl::StatusOr<KernelDefinition> MlirKernelEmitter::EmitKernelDefinition() {
+  std::unique_ptr<mlir::MLIRContext> context = FusionCompiler::CreateContext();
+
+  TF_ASSIGN_OR_RETURN(
+      MlirKernelSource source,
+      MlirKernelSource::ParseFromString(mlir_, std::move(context)));
+
+  // Convert kernel arguments to fake allocations and buffer uses.
+  KernelSpec::Buffers argument_buffers;
+  KernelSpec::Buffers result_buffers;
+
+  for (const auto& [arg, allocation] : llvm::zip(args_, buffer_allocations_)) {
+    BufferAllocation::Slice slice(&allocation, 0, arg.size_bytes);
+    if (arg.memory_access == BufferUse::MemoryAccess::kRead) {
+      argument_buffers.push_back(slice);
+    } else {
+      result_buffers.push_back(slice);
+    }
+  }
+
+  KernelSpec kernel_spec(kernel_name_, thread_dim_, std::move(argument_buffers),
+                         std::move(result_buffers), /*invariant_arguments=*/{});
+  return KernelDefinition(
+      std::move(kernel_spec),
+      std::make_unique<MlirKernelSource>(std::move(source)));
+}
+}  // namespace xla::cpu
diff --git a/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.h b/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.h
new file mode 100644
index 00000000000000..4587d195452efe
--- /dev/null
+++ b/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.h
@@ -0,0 +1,65 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_CPU_TESTLIB_MLIR_KERNEL_EMITTER_H_
+#define XLA_BACKENDS_CPU_TESTLIB_MLIR_KERNEL_EMITTER_H_
+
+#include <cstddef>
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/codegen/kernel_definition.h"
+#include "xla/codegen/kernel_emitter.h"
+#include "xla/runtime/buffer_use.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/stream_executor/launch_dim.h"
+
+namespace xla::cpu {
+
+// An XLA kernel emitter that emits a kernel by parsing the given MLIR module
+// into the dedicated MLIR context and module instance. This kernel emitter is
+// intended to be used for testing purposes only: (1) load the parsed MLIR
+// kernel into the XLA kernel spec; (2) execute it with user-provided input
+// buffers.
+class MlirKernelEmitter : public KernelEmitter {
+ public:
+  // When loading kernel IR into the KernelSpec we create a separate buffer
+  // allocation for every kernel argument. We don't use buffer assignment in
+  // kernel testlib, but we still need to return a valid BufferUses vector.
+  struct KernelArg {
+    size_t size_bytes;
+    BufferUse::MemoryAccess memory_access;
+  };
+
+  MlirKernelEmitter(absl::string_view mlir, absl::string_view kernel_name,
+                    se::ThreadDim thread_dim,
+                    absl::Span<const KernelArg> args);
+
+  absl::StatusOr<KernelDefinition> EmitKernelDefinition() final;
+
+ private:
+  std::string mlir_;
+  std::string kernel_name_;
+  se::ThreadDim thread_dim_;
+  std::vector<KernelArg> args_;
+  // Normally this would be populated by the buffer assignment pass, but for
+  // testing purposes we hold it in the emitter.
+  std::vector<BufferAllocation> buffer_allocations_;
+};
+
+}  // namespace xla::cpu
+
+#endif  // XLA_BACKENDS_CPU_TESTLIB_MLIR_KERNEL_EMITTER_H_

diff --git a/third_party/xla/xla/codegen/BUILD b/third_party/xla/xla/codegen/BUILD
index 343af622364b3b..570e9637bb36f0 100644
--- a/third_party/xla/xla/codegen/BUILD
+++ b/third_party/xla/xla/codegen/BUILD
@@ -117,3 +117,20 @@ cc_library(
         "//xla/tsl/platform:logging",
     ],
 )
+
+cc_library(
+    name = "mlir_kernel_source",
+    srcs = ["mlir_kernel_source.cc"],
+    hdrs = ["mlir_kernel_source.h"],
+    deps = [
+        ":kernel_source",
+        "//xla:util",
+        "//xla/mlir/utils:error_util",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Support",
+    ],
+)

diff --git a/third_party/xla/xla/codegen/mlir_kernel_source.cc b/third_party/xla/xla/codegen/mlir_kernel_source.cc
new file mode 100644
index 00000000000000..f2f0a2e651d786
--- /dev/null
+++ b/third_party/xla/xla/codegen/mlir_kernel_source.cc
@@ -0,0 +1,59 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/codegen/mlir_kernel_source.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/Parser/Parser.h"
+#include "xla/util.h"
+
+namespace xla {
+
+absl::StatusOr<MlirKernelSource> MlirKernelSource::ParseFromString(
+    absl::string_view ir, std::unique_ptr<mlir::MLIRContext> context) {
+  llvm::SourceMgr source_mgr;
+
+  std::string error_string;
+  llvm::raw_string_ostream error_stream(error_string);
+  mlir::SourceMgrDiagnosticHandler source_mgr_handler(source_mgr, context.get(),
+                                                      error_stream);
+
+  source_mgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(ir),
+                                llvm::SMLoc());
+
+  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
+      mlir::parseSourceFile<mlir::ModuleOp>(source_mgr, context.get());
+
+  if (!mlir_module) {
+    return Internal("Failed to parse MLIR IR: %s", error_string);
+  }
+
+  return MlirKernelSource(std::move(context), std::move(mlir_module));
+}
+
+}  // namespace xla

diff --git a/third_party/xla/xla/codegen/mlir_kernel_source.h b/third_party/xla/xla/codegen/mlir_kernel_source.h
new file mode 100644
index 00000000000000..9ff4c14cbb19de
--- /dev/null
+++ b/third_party/xla/xla/codegen/mlir_kernel_source.h
@@ -0,0 +1,64 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_CODEGEN_MLIR_KERNEL_SOURCE_H_
+#define XLA_CODEGEN_MLIR_KERNEL_SOURCE_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/Support/DebugStringHelper.h"
+#include "xla/codegen/kernel_source.h"
+
+namespace xla {
+
+// Kernel JIT source that is backed by MLIR and owns the underlying
+// mlir::ModuleOp.
+
+// The MLIR source is typically created by a fusion emitter from either the CPU
+// or GPU backend, e.g., ScatterFusion. The specific dialect(s) that back the
+// source are not specified but are implicit in the passed context. It is
+// expected that the source will be lowered to LLVM by the corresponding
+// backend compiler.
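+//
+// A minimal usage sketch (illustrative only; `ir` stands for any MLIR string
+// whose dialects are registered in the passed context):
+//
+//   absl::StatusOr<MlirKernelSource> source =
+//       MlirKernelSource::ParseFromString(ir, std::move(context));
+//   if (source.ok()) {
+//     mlir::ModuleOp module = source->module();
+//   }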
+class MlirKernelSource final : public KernelSource {
+ public:
+  MlirKernelSource(std::unique_ptr<mlir::MLIRContext> mlir_context,
+                   mlir::OwningOpRef<mlir::ModuleOp> mlir_module)
+      : mlir_context_(std::move(mlir_context)),
+        module_(std::move(mlir_module)) {}
+
+  MlirKernelSource(MlirKernelSource&& other) noexcept = default;
+  MlirKernelSource& operator=(MlirKernelSource&& other) noexcept = default;
+
+  static absl::StatusOr<MlirKernelSource> ParseFromString(
+      absl::string_view ir, std::unique_ptr<mlir::MLIRContext> context);
+
+  mlir::ModuleOp module() { return *module_; }
+
+  std::string ToString() const final { return mlir::debugString(*module_); }
+
+ private:
+  std::unique_ptr<mlir::MLIRContext> mlir_context_;
+  mlir::OwningOpRef<mlir::ModuleOp> module_;
+};
+
+}  // namespace xla
+
+#endif  // XLA_CODEGEN_MLIR_KERNEL_SOURCE_H_

diff --git a/third_party/xla/xla/codegen/testlib/BUILD b/third_party/xla/xla/codegen/testlib/BUILD
index 345ea53acc5037..f4e95e39f3dc42 100644
--- a/third_party/xla/xla/codegen/testlib/BUILD
+++ b/third_party/xla/xla/codegen/testlib/BUILD
@@ -57,6 +57,8 @@ tsl_pybind_extension(
         "//xla/codegen:kernel_emitter",
         "//xla/codegen:kernel_source",
         "//xla/codegen:kernel_spec",
+        "//xla/codegen:llvm_ir_kernel_source",
+        "//xla/codegen:mlir_kernel_source",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/parser:hlo_parser",
         "//xla/python:nb_absl_inlined_vector",

diff --git a/third_party/xla/xla/codegen/testlib/__init__.py b/third_party/xla/xla/codegen/testlib/__init__.py
index 31599d5f9912eb..3c1222b04f198c 100644
--- a/third_party/xla/xla/codegen/testlib/__init__.py
+++ b/third_party/xla/xla/codegen/testlib/__init__.py
@@ -30,6 +30,8 @@
 KernelEmmitter = _extension.KernelEmitter
 KernelRunner = _extension.KernelRunner
 KernelSpec = _extension.KernelSpec
+LlvmIrKernelSource = _extension.LlvmIrKernelSource
+MlirKernelSource = _extension.MlirKernelSource
 # go/keep-sorted end

 # Functions

diff --git a/third_party/xla/xla/codegen/testlib/kernel_runner_extension.cc b/third_party/xla/xla/codegen/testlib/kernel_runner_extension.cc
index ee6907dddf8a91..d8d3220f94b6cb 100644
--- a/third_party/xla/xla/codegen/testlib/kernel_runner_extension.cc
+++ b/third_party/xla/xla/codegen/testlib/kernel_runner_extension.cc
@@ -38,6 +38,8 @@ limitations under the License.
#include "xla/codegen/kernel_emitter.h" #include "xla/codegen/kernel_source.h" #include "xla/codegen/kernel_spec.h" +#include "xla/codegen/llvm_ir_kernel_source.h" +#include "xla/codegen/mlir_kernel_source.h" #include "xla/codegen/testlib/kernel_runner.h" #include "xla/comparison_util.h" #include "xla/debug_options_flags.h" @@ -171,6 +173,22 @@ NB_MODULE(_extension, kernel_runner_module) { nb::class_(kernel_runner_module, "KernelSource") .def("__str__", &KernelSource::ToString); + nb::class_ llvm_kernel_source( + kernel_runner_module, "LlvmIrKernelSource"); + + nb::class_(kernel_runner_module, + "MlirKernelSource") + .def_static( + "parse_from_string", + [](absl::string_view ir, std::unique_ptr context) { + absl::StatusOr source = + MlirKernelSource::ParseFromString(ir, std::move(context)); + if (!source.ok()) { + throw std::runtime_error(std::string(source.status().message())); + } + return std::move(source).value(); + }); + nb::class_ kernel_spec(kernel_runner_module, "KernelSpec"); nb::class_(kernel_runner_module, "KernelDefinition") diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index aa8391e0d0d5d9..ed93d31470e51d 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -4746,7 +4746,7 @@ static absl::Status PostOrderDFS( int current_id = dfs_stack.back().first; HloInstruction* current_node = dfs_stack.back().second; - CHECK_GE(current_id, 0) << current_id << ": " << current_node + CHECK_GE(current_id, 0) << current_id << ": " << current_node->name() << ": instruction may not have parent computation"; typename Visitor::VisitState visit_state = visitor->GetVisitState(current_id); From 9f864536fbf47afcde2960f8ebd9892b6c9bb3be Mon Sep 17 00:00:00 2001 From: Will Froom Date: Mon, 14 Apr 2025 07:45:10 -0700 Subject: [PATCH 0702/1324] [XLA:CPU] Set xla_backend_extra_options on the fusion module PiperOrigin-RevId: 747414202 --- .../xla/xla/backends/cpu/codegen/BUILD | 4 ++++ .../xla/backends/cpu/codegen/emitters/BUILD | 1 + .../emitters/cpu_fusion_emitter_test.cc | 1 + .../codegen/emitters/cpu_scatter_emitter.cc | 9 ++++++++ .../backends/cpu/codegen/fusion_compiler.cc | 21 +++++++++++++++++++ .../xla/xla/codegen/emitters/ir/xla_attrs.td | 7 +++++++ .../xla/xla/service/cpu/ir_emitter2.cc | 3 +++ 7 files changed, 46 insertions(+) diff --git a/third_party/xla/xla/backends/cpu/codegen/BUILD b/third_party/xla/xla/backends/cpu/codegen/BUILD index 33e6a9e77db8ba..a6f6cae3d5a51c 100644 --- a/third_party/xla/xla/backends/cpu/codegen/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/BUILD @@ -100,6 +100,7 @@ cc_library( "//xla/backends/cpu/codegen/emitters/ir:xla_cpu", "//xla/backends/cpu/codegen/emitters/transforms:passes", "//xla/codegen/emitters/ir:xla", + "//xla/codegen/emitters/ir:xla_attrs_inc_gen", "//xla/codegen/emitters/transforms:passes", "//xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc", "//xla/mlir_hlo", @@ -109,7 +110,9 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:AffineTransforms", @@ -130,6 +133,7 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", 
"@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:Transforms", diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD b/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD index a02d882ce5fd53..e137ba5fe0d959 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD @@ -40,6 +40,7 @@ cc_library( "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/codegen/emitters:type_util", "//xla/codegen/emitters/ir:xla", + "//xla/codegen/emitters/ir:xla_attrs_inc_gen", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc", diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc index 6e7247d4e1e327..7be062b2396dd0 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter_test.cc @@ -107,6 +107,7 @@ static constexpr absl::string_view kScatterHlo = R"( TEST_F(CpuFusionEmitterTest, ScatterMlir) { constexpr absl::string_view kExpected = R"( + CHECK: module attributes {{{.*}}xla.extra_backend_options = #xla{{.*}}} CHECK: @wrapped_scatter_entry( CHECK-SAME: xla.entry CHECK: %[[XLA_LOOP:.+]] = xla.loop diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc index 1a502342d0c705..58c3096ccb0d3a 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -47,6 +48,7 @@ limitations under the License. #include "xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.h" #include "xla/codegen/emitters/computation_partitioner.h" #include "xla/codegen/emitters/elemental_hlo_to_mlir.h" +#include "xla/codegen/emitters/ir/xla_attrs.h.inc" #include "xla/codegen/emitters/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" @@ -250,6 +252,13 @@ absl::StatusOr> CpuScatterFusion::Emit() mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(loc); SetDataLayoutAttribute(module.get(), *fusion_); + mlir::StringAttr disable_loop_unrolling_attr = + builder.getStringAttr("xla_cpu_disable_loop_unrolling"); + module->getOperation()->setAttr( + xla::ExtraBackendOptionsAttr::name, + builder.getAttr( + llvm::ArrayRef{disable_loop_unrolling_attr})); + TF_ASSIGN_OR_RETURN( mlir::func::FuncOp entry_func, EmitFusionKernelApi(module.get(), *fusion_, diff --git a/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc b/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc index 9454821436c130..fc57d7c1a888ce 100644 --- a/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc +++ b/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc @@ -16,10 +16,14 @@ limitations under the License. 
#include "xla/backends/cpu/codegen/fusion_compiler.h" #include +#include #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" @@ -39,15 +43,19 @@ limitations under the License. #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/Passes.h" #include "xla/backends/cpu/codegen/emitters/ir/xla_cpu_dialect.h" #include "xla/backends/cpu/codegen/emitters/transforms/passes.h" +#include "xla/codegen/emitters/ir/xla_attrs.h.inc" #include "xla/codegen/emitters/ir/xla_ops.h" #include "xla/codegen/emitters/transforms/passes.h" #include "xla/mlir/tools/mlir_replay/public/compiler_trace.pb.h" @@ -175,6 +183,19 @@ absl::StatusOr> FusionCompiler::Compile( std::unique_ptr llvm_module = mlir::translateModuleToLLVMIR(mlir_module, llvm_context); + if (mlir::Attribute options = + mlir_module->getAttr(xla::ExtraBackendOptionsAttr::name)) { + const auto formater = [](std::string* out, const mlir::StringAttr& attr) { + absl::StrAppend(out, attr.str()); + }; + std::string options_csv = absl::StrJoin( + mlir::cast(options), ",", formater); + llvm::MDString* options_mdstring = + llvm::MDString::get(llvm_context, options_csv); + llvm_module->addModuleFlag(llvm::Module::Error, "xla_backend_extra_options", + options_mdstring); + } + TF_RET_CHECK(llvm_module != nullptr) << "Failed to translate module to LLVM IR."; diff --git a/third_party/xla/xla/codegen/emitters/ir/xla_attrs.td b/third_party/xla/xla/codegen/emitters/ir/xla_attrs.td index 197ac2e72c96ea..b410d33135fbbf 100644 --- a/third_party/xla/xla/codegen/emitters/ir/xla_attrs.td +++ b/third_party/xla/xla/codegen/emitters/ir/xla_attrs.td @@ -77,4 +77,11 @@ def XLA_BackendKindAttr : let assemblyFormat = "`<` $value `>`"; } +def XLA_ExtraBackendOptionsAttr + : ArrayOfAttr { +} + #endif // XLA_CODEGEN_EMITTERS_IR_XLA_ATTRS diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index a49333ab0acc9e..35ffcb786cd957 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -229,6 +229,9 @@ absl::StatusOr IrEmitter2::EmitFusionWithFusionEmitters( // Match data layouts to avoid warning messages. llvm_module->setDataLayout(module_->getDataLayout()); + // We need to clear the module flags so that they don't pollute the linked + // module (this will not be required when we migrate to the kernel API). + llvm_module->getModuleFlagsMetadata()->eraseFromParent(); if (llvm::Linker::linkModules(*module_, std::move(llvm_module))) { return Internal("Cannot link additional LLVM module for fusion %s", fusion->name()); From 889afc514434f13112267c0656c6b980a95ad4cc Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 08:05:52 -0700 Subject: [PATCH 0703/1324] Removes support for memory edge costs (unused in production as of cl/651447720). PiperOrigin-RevId: 747421601 --- .../auto_sharding/auto_sharding_solver.cc | 40 ++----------------- .../auto_sharding_solver_test.cc | 6 ++- 2 files changed, 7 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 2c474773dee426..cd2b5f813fd08d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -689,32 +689,19 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( } // c. if (request.memory_budget() > 0) { - std::vector> reduced_intervals_nodes, - reduced_intervals_edges; + std::vector> reduced_intervals_nodes; absl::flat_hash_set reduced_times; - std::vector group_node_vars, group_edge_vars; - std::optional> num_node_terms, num_edge_terms; + std::vector group_node_vars; + std::optional> num_node_terms; num_node_terms = ReduceMemoryTerms( request, *solver, request.num_nodes(), request.node_intervals(), request.memory_costs(), "node", s, reduced_intervals_nodes, group_node_vars, reduced_times); - if (request.enable_memory_edge_costs()) { - num_edge_terms = ReduceMemoryTerms( - request, *solver, request.edges_size(), request.edge_intervals(), - request.memory_edge_costs(), "edge", e, reduced_intervals_edges, - group_edge_vars, reduced_times); - } absl::flat_hash_map constraints; AddMemoryTerms(request, *solver, request.num_nodes(), reduced_intervals_nodes, request.memory_costs(), overbudget_var, reduced_times, s, group_node_vars, constraints); - if (request.enable_memory_edge_costs()) { - AddMemoryTerms(request, *solver, request.edges_size(), - reduced_intervals_edges, request.memory_edge_costs(), - overbudget_var, reduced_times, e, group_edge_vars, - constraints); - } if (overbudget_var && !request.minimize_departures()) { solver->MutableObjective()->SetCoefficient( overbudget_var, *overbudget_coeff * request.memory_budget()); @@ -1172,27 +1159,6 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, total_memory_costs[time_idx] += total_memory_cost; lower_bound_memory_costs[time_idx] += lower_bound_memory_cost; } - if (request.enable_memory_edge_costs()) { - for (EdgeIdx edge_idx = 0; edge_idx < request.edge_intervals_size(); - ++edge_idx) { - const auto& interval = request.edge_intervals(edge_idx); - if (interval.first() > interval.second()) continue; - // Expand cost vectors if needed to cover the range of this interval. 
- while (total_memory_costs.size() <= interval.second()) { - total_memory_costs.push_back(0.0); - lower_bound_memory_costs.push_back(0.0); - } - double total_memory_cost = 0.0, lower_bound_memory_cost = 0.0; - const auto& m = request.memory_edge_costs(edge_idx).costs(); - total_memory_cost = m[e_val(edge_idx)]; - lower_bound_memory_cost = *std::min_element(m.begin(), m.end()); - for (LivenessIdx time_idx = interval.first(); - time_idx <= interval.second(); ++time_idx) { - total_memory_costs[time_idx] += total_memory_cost; - lower_bound_memory_costs[time_idx] += lower_bound_memory_cost; - } - } - } } double total_overbudget = 0.0; double lower_bound_overbudget = 0.0; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 0b2b53790a3ddb..e7668e29224f0b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -440,7 +440,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) { EXPECT_EQ(result, expected_output); } -TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) { +TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, + HandlesMemoryEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const EdgeMatrix live_edges = {{}, {0}, {0, 1}, {1}, {}}; const CostMatrix memory_edge_costs = {{1000000, 1100, 1200, 1300, @@ -464,7 +465,8 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) { EXPECT_EQ(result, expected_output); } -TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { +TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, + HandlesIntervals) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const CostMatrix memory_edge_costs = {{1000000, 1100, 1200, 1300, 2000, 2100, 2200, 2300, From 729e03290e5bcf5f8f0853752d95c7e0ab6af2ed Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Mon, 14 Apr 2025 08:06:43 -0700 Subject: [PATCH 0704/1324] [XLA:GPU] Fix all-to-all thunk for buffers with element count larger than int32 max. PiperOrigin-RevId: 747421888 --- .../backends/gpu/runtime/all_to_all_thunk.cc | 2 +- .../xla/xla/tests/collective_ops_e2e_test.cc | 46 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc index 3e94cbfd182ad3..3f5bb1c464902c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc @@ -200,7 +200,7 @@ absl::Status RunAllToAll(GpuCollectives* collectives, bool has_split_dimension, TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks()); PrimitiveType element_type = buffers[0].element_type; - int32_t element_count = buffers[0].element_count; + int64_t element_count = buffers[0].element_count; // All buffers must have the same element type and count. 
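+  // (For scale, a worked example from the e2e test added below: a
+  // u8[2,32768,32768] all-to-all has 2 * 32768 * 32768 = 2^31 elements,
+  // one more than fits in int32_t, hence the 64-bit element count above.)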
 bool all_buffers_match = absl::c_all_of(buffers, [&](const auto& buffer) {

diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc
index 870c383c7961dd..216ae9fc2685f5 100644
--- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc
+++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc
@@ -742,6 +742,52 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithoutSplitDim) {
   LiteralTestUtil::ExpectR1Equal<uint32_t>({40, 60, 44, 64}, results[1]);
 }

+XLA_TEST_P(AsyncCollectiveOps,
+           AsyncAllToAllNumberOfElementsLargerThanInt32Max) {
+  const absl::string_view kModuleStr = R"(
+  HloModule test
+
+  ENTRY test_computation {
+    id = u32[] replica-id()
+    id_u8 = u8[] convert(id)
+    a0 = u8[2,32768,32768] broadcast(id_u8), dimensions={}
+    ROOT a2a = u8[2,32768,32768] all-to-all(u8[2,32768,32768] a0),
+      replica_groups={{0,1}}, dimensions={0}
+  }
+  )";
+  const int64_t kNumReplicas = 2;
+  if (test_runner().device_count() < kNumReplicas) {
+    GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
+                 << test_runner().device_count() << " available)";
+  }
+  const bool enable_async_all_to_all = GetParam();
+  TF_ASSERT_OK_AND_ASSIGN(auto executable,
+                          CreateExecutable(kModuleStr, kNumReplicas));
+  TF_ASSERT_OK_AND_ASSIGN(const HloModule* const hlo_module,
+                          test_runner().HloModuleFromWrapped(executable.get()));
+
+  HloInstruction* a2a_start =
+      FindInstruction(hlo_module, HloOpcode::kAsyncStart);
+  HloInstruction* a2a_done = FindInstruction(hlo_module, HloOpcode::kAsyncDone);
+  ASSERT_THAT(a2a_start, NotNull());
+  ASSERT_THAT(a2a_done, NotNull());
+  HloAsyncInstruction* a2a_start_async = Cast<HloAsyncInstruction>(a2a_start);
+  EXPECT_EQ(a2a_start_async->async_wrapped_opcode(), HloOpcode::kAllToAll);
+  EXPECT_EQ(IsAsync(a2a_start_async), enable_async_all_to_all);
+
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
+                          ExecuteReplicated(executable.get(), kNumReplicas));
+  ASSERT_EQ(results.size(), kNumReplicas);
+
+  // Sanity check only a few elements in each result, because checking all 2GB
+  // would take too long.
+  EXPECT_EQ(results[0].Get<uint8_t>({0, 0, 0}), 0);
+  EXPECT_EQ(results[0].Get<uint8_t>({1, 0, 0}), 1);
+
+  EXPECT_EQ(results[1].Get<uint8_t>({0, 0, 0}), 0);
+  EXPECT_EQ(results[1].Get<uint8_t>({1, 0, 0}), 1);
+}
+
 XLA_TEST_P(AsyncMemcpyCollectiveOps, AsyncAllToAllMultipleReplicaGroups) {
   const absl::string_view kModuleStr = R"(
   HloModule test

From 09b8aaaf5f1107c7fc0c9fbb3c913d928be3a4cb Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Mon, 14 Apr 2025 08:49:40 -0700
Subject: [PATCH 0705/1324] [XLA:GPU] Implement a more detailed heuristic for
 `SortRewriter` for H100 and A100.
PiperOrigin-RevId: 747436596 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/gpu_compiler.cc | 2 +- .../service/gpu/gpu_device_info_for_tests.cc | 26 +++ .../service/gpu/gpu_device_info_for_tests.h | 4 + .../service/gpu/tests/gpu_cub_sort_test.cc | 10 +- .../xla/xla/service/gpu/transforms/BUILD | 12 +- .../service/gpu/transforms/sort_rewriter.cc | 157 +++++++++++++----- .../service/gpu/transforms/sort_rewriter.h | 27 +-- .../gpu/transforms/sort_rewriter_test.cc | 60 ++++++- 9 files changed, 232 insertions(+), 67 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index ba0470ad748b6f..1804a153fb6266 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -280,6 +280,7 @@ cc_library( deps = [ "//xla/stream_executor:device_description", "//xla/stream_executor:semantic_version", + "//xla/stream_executor/cuda:cuda_compute_capability", ], ) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 7c9340d6ef38ca..8ce282559600c4 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -708,7 +708,7 @@ absl::Status RunOptimizationPasses( // which isn't feasible if we don't have a device. if (hlo_module->config().debug_options().xla_gpu_enable_cub_radix_sort()) { if (stream_exec != nullptr) { - pipeline.AddPass(); + pipeline.AddPass(gpu_target_config.device_description); } else { LOG(WARNING) << "Using fallback sort algorithm rather than SortRewriter, " "which will be slower at runtime. To avoid this, " diff --git a/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.cc b/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.cc index 150b3fe3ebe430..ed12324b4f672b 100644 --- a/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.cc +++ b/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.cc @@ -47,6 +47,32 @@ stream_executor::DeviceDescription TestGpuDeviceInfo::RTXA6000DeviceInfo( return b; } +stream_executor::DeviceDescription TestGpuDeviceInfo::RTXH100SXMDeviceInfo( + stream_executor::GpuComputeCapability cc) { + stream_executor::DeviceDescription b; + b.set_gpu_compute_capability(cc); + b.set_threads_per_block_limit(1024); + b.set_threads_per_warp(32); + b.set_shared_memory_per_block(48 * 1024); + b.set_shared_memory_per_block_optin(227 * 1024); + b.set_shared_memory_per_core(228 * 1024); + b.set_threads_per_core_limit(2048); + b.set_core_count(132); + b.set_fpus_per_core(128); + b.set_block_dim_limit_x(2'147'483'647); + b.set_block_dim_limit_y(65535); + b.set_block_dim_limit_z(65535); + b.set_memory_bandwidth(3'352'320'000'000); + b.set_l2_cache_size(50 * 1024 * 1024); + b.set_clock_rate_ghz(1.98); + b.set_device_memory_size(84'978'434'048); + b.set_registers_per_core_limit(65536); + b.set_registers_per_block_limit(65536); + b.set_runtime_version(stream_executor::SemanticVersion{12, 4, 0}); + b.set_driver_version(stream_executor::SemanticVersion{12, 4, 0}); + return b; +} + stream_executor::DeviceDescription TestGpuDeviceInfo::AMDMI210DeviceInfo() { stream_executor::DeviceDescription b; b.set_gpu_compute_capability( diff --git a/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.h b/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.h index 9085763fee43c8..54501ace5a6efb 100644 --- a/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.h +++ b/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.h @@ -16,6 +16,7 
@@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_DEVICE_INFO_FOR_TESTS_H_ #define XLA_SERVICE_GPU_GPU_DEVICE_INFO_FOR_TESTS_H_ +#include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -26,6 +27,9 @@ class TestGpuDeviceInfo { static stream_executor::DeviceDescription RTXA6000DeviceInfo( stream_executor::GpuComputeCapability cc = stream_executor::CudaComputeCapability(8, 9)); + static stream_executor::DeviceDescription RTXH100SXMDeviceInfo( + stream_executor::GpuComputeCapability cc = + stream_executor::CudaComputeCapability(9, 0)); static stream_executor::DeviceDescription AMDMI210DeviceInfo(); // Returns deafult RTXA6000 or AMDMI210 device info static stream_executor::DeviceDescription CudaOrRocmDeviceInfo(); diff --git a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc index 907fe384c3208d..d37d322ddb32f2 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc @@ -52,13 +52,12 @@ class CubSortKeysTest : public HloTestBase, public: void SetUp() override { HloTestBase::SetUp(); - SortRewriter::SetSortSizeThresholdForTestingOnly( - 0); // Always use CUB sort. + SortRewriter::SetSortModeForTestingOnly(SortRewriter::Mode::kAlways); } }; TEST_F(CubSortKeysTest, AlwaysUsesCubSort) { - EXPECT_EQ(SortRewriter::SortSizeThreshold(), 0); + EXPECT_EQ(SortRewriter::SortMode(), SortRewriter::Mode::kAlways); } TEST_P(CubSortKeysTest, CompareToReference) { @@ -208,13 +207,12 @@ class CubSortPairsTest public: void SetUp() override { HloTestBase::SetUp(); - SortRewriter::SetSortSizeThresholdForTestingOnly( - 0); // Always use CUB sort. 
+ SortRewriter::SetSortModeForTestingOnly(SortRewriter::Mode::kAlways); } }; TEST_F(CubSortPairsTest, AlwaysUsesCubSort) { - EXPECT_EQ(SortRewriter::SortSizeThreshold(), 0); + EXPECT_EQ(SortRewriter::SortMode(), SortRewriter::Mode::kAlways); } TEST_P(CubSortPairsTest, CompareToReference) { diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 6b21f9d1c50a58..307b89e6a59d39 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -3028,12 +3028,15 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:pattern_matcher", "//xla/service/gpu:cublas_cudnn", + "//xla/stream_executor:device_description", + "//xla/stream_executor/cuda:cuda_compute_capability", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) @@ -3054,14 +3057,15 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:pattern_matcher", "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:gpu_device_info_for_tests", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc index 9f8967c5c56a36..c431ad1560e299 100644 --- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc @@ -21,9 +21,12 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/gpu/runtime/cub_sort_thunk.h" @@ -40,11 +43,12 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -302,6 +306,117 @@ HloInstruction* AddNumpySortKey(HloInstruction* operand, PrimitiveType key_type, return sort_keys; } +bool IsCubSortFasterOnH100(int bitwidth, int batch_size, int num_elements, + int sm_count) { + // The numbers below are based on extensive benchmarks: see + // b/407689882#comment35 and b/410480351 for more details. 
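+  // As an illustrative reading of these thresholds (derived from the table
+  // below, not an additional benchmark): with batch_size == 1, CUB wins for
+  // 32-bit keys once a row exceeds 22000 elements, while for 16-bit keys the
+  // crossover is already at 512 (1 << 9) elements.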
+  switch (bitwidth) {
+    case 8:
+      return batch_size == 1 ||
+             (num_elements > 1300 && (batch_size > 8 || num_elements < 26000));
+    case 16:
+      return (batch_size == 1 && num_elements > (1 << 9)) ||
+             (batch_size > 12 && num_elements > (1 << 16)) ||
+             (batch_size > 14 && num_elements > (1 << 15)) ||
+             (batch_size > 16 && num_elements > (1 << 14)) ||
+             (batch_size > 18 && num_elements > (1 << 13)) ||
+             (batch_size > 33 && num_elements > (1 << 12)) ||
+             (batch_size > 66 && num_elements > (1 << 11));
+    case 32:
+      return (batch_size == 1 && num_elements > 22000) ||
+             (batch_size > 26 && num_elements > (1 << 17)) ||
+             (batch_size > 31 && num_elements > (1 << 16)) ||
+             (batch_size > 38 && num_elements > (1 << 15)) ||
+             (batch_size > 44 && num_elements > (1 << 14)) ||
+             (batch_size > 52 && num_elements > (1 << 13)) ||
+             (batch_size > 88 && batch_size <= sm_count &&
+              num_elements > (1 << 12));
+    case 64:
+      return (batch_size == 1 && num_elements > (1 << 17)) ||
+             (batch_size > 55 && num_elements > (1 << 17)) ||
+             (batch_size > 70 && num_elements > (1 << 16)) ||
+             (batch_size > 92 && num_elements > (1 << 15)) ||
+             (((batch_size > 160 && batch_size <= 2 * sm_count) ||
+               (batch_size > 354)) &&
+              num_elements > (1 << 14));
+    default:
+      return false;
+  }
+}
+
+// Returns whether a compatible sort should be rewritten based on the current
+// sort mode and possibly a heuristic.
+bool ShouldRewriteCompatibleSort(se::DeviceDescription device_description,
+                                 const HloSortInstruction* sort_op) {
+  if (SortRewriter::SortMode() == SortRewriter::Mode::kAlways) {
+    return true;
+  }
+
+  const Shape& operand_shape = sort_op->operand(0)->shape();
+  int num_elements = operand_shape.dimensions().back();
+  if (num_elements == 0) {
+    return false;
+  }
+
+  if (SortRewriter::SortMode() == SortRewriter::Mode::kAuto) {
+    if (auto cuda_cc = std::get_if<se::CudaComputeCapability>(
+            &device_description.gpu_compute_capability())) {
+      int bitwidth = primitive_util::BitWidth(operand_shape.element_type());
+      int batch_size = Product(operand_shape.dimensions()) / num_elements;
+
+      if (cuda_cc->IsHopper()) {
+        return IsCubSortFasterOnH100(bitwidth, batch_size, num_elements,
+                                     device_description.core_count());
+      }
+      if (cuda_cc->IsAmpere()) {
+        // TODO(b/410480351): Verify that the H100 heuristic also works well
+        // for Ampere or implement a custom heuristic.
+        return IsCubSortFasterOnH100(bitwidth, batch_size, num_elements,
+                                     device_description.core_count());
+      }
+    }
+  }
+
+  // TODO(b/410480351): The default heuristic below is pretty bad in the
+  // general case. Run benchmarks on different devices and add a heuristic per
+  // device.
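+  // As a worked example of this fallback: an f32[128,128] sort (128 * 128 =
+  // 16384 elements) keeps the XLA sort, while f32[129,128] (16512 elements)
+  // is rewritten to use CUB.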
+  return Product(operand_shape.dimensions()) > 16384;
+}
+
+bool IsCubCompatibleSort(const se::DeviceDescription& device_description,
+                         const HloSortInstruction* sort_op) {
+  VLOG(1) << "Sort instruction: " << sort_op->name();
+  if (sort_op->operand_count() != 1 && sort_op->operand_count() != 2) {
+    VLOG(2) << "Unsupported operand count: " << sort_op->operand_count();
+    return false;
+  }
+
+  const Shape& operand_shape = sort_op->operand(0)->shape();
+  if (sort_op->sort_dimension() != operand_shape.dimensions().size() - 1) {
+    VLOG(2) << "Sort dimension should be the minor one";
+    return false;
+  }
+  if (!ShouldRewriteCompatibleSort(device_description, sort_op)) {
+    VLOG(2) << "Tensor shape and type will not see an improvement.";
+    return false;
+  }
+
+  auto sort_analysis = AnalyzeSortOp(*sort_op);
+  if (!sort_analysis.has_value()) {
+    VLOG(2) << "Only simple compare computations are supported";
+    return false;
+  }
+  if (!CreateRunner(*sort_analysis).ok()) {
+    VLOG(2) << "Unsupported operand types (no compiled CUB kernels): "
+            << PrimitiveType_Name(sort_analysis->key_type) << " "
+            << (sort_analysis->value_type.has_value()
+                    ? PrimitiveType_Name(sort_analysis->value_type.value())
+                    : "");
+    return false;
+  }
+  VLOG(2) << "Sort operation is compatible";
+  return true;
+}
+
 }  // namespace

 // Rewrites a single sort instruction with a custom call.
@@ -392,7 +507,7 @@ absl::StatusOr<bool> SortRewriter::RunOnComputation(
   std::vector<HloSortInstruction*> sort_ops;
   for (auto* inst : computation->instructions()) {
     HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
-    if (sort != nullptr && IsCubCompatibleSort(sort)) {
+    if (sort != nullptr && IsCubCompatibleSort(device_description_, sort)) {
       sort_ops.push_back(sort);
     }
   }
@@ -419,39 +534,5 @@ absl::StatusOr<bool> SortRewriter::Run(
   return changed;
 }

-bool IsCubCompatibleSort(const HloSortInstruction* sort_op) {
-  VLOG(1) << "Sort instruction: " << sort_op->name();
-  if (sort_op->operand_count() != 1 && sort_op->operand_count() != 2) {
-    VLOG(2) << "Unsupported operand count: " << sort_op->operand_count();
-    return false;
-  }
-
-  const Shape& operand_shape = sort_op->operand(0)->shape();
-  if (sort_op->sort_dimension() != operand_shape.dimensions().size() - 1) {
-    VLOG(2) << "Sort dimension should be the minor one";
-    return false;
-  }
-  if (Product(operand_shape.dimensions()) < SortRewriter::SortSizeThreshold()) {
-    VLOG(2) << "Tensor shape size is too small to see an improvement";
-    return false;
-  }
-
-  auto sort_analysis = AnalyzeSortOp(*sort_op);
-  if (!sort_analysis.has_value()) {
-    VLOG(2) << "Only simple compare computations are supported";
-    return false;
-  }
-  if (!CreateRunner(*sort_analysis).ok()) {
-    VLOG(2) << "Unsupported operand types (no compiled CUB kernels): "
-            << PrimitiveType_Name(sort_analysis->key_type) << " "
-            << (sort_analysis->value_type.has_value()
-                    ? PrimitiveType_Name(sort_analysis->value_type.value())
-                    : "");
-    return false;
-  }
-  VLOG(2) << "Sort operation is compatible";
-  return true;
-}
-
 }  // namespace gpu
 }  // namespace xla

diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h
index 96835aab6306e0..cbe17e8603fe49 100644
--- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -33,16 +34,24 @@ namespace gpu { class SortRewriter : public HloModulePass { public: + explicit SortRewriter(const se::DeviceDescription& device_description) + : device_description_(device_description) {} absl::string_view name() const override { return "sort-rewriter"; } + enum class Mode { + kAuto, // Decide whether to rewrite compatible sorts based on a heuristic. + kAlways // Always rewrite compatible sorts. Used for testing. + }; + // CUB radix sort is slower than XLA sort on small shapes, so do not rewrite // tensors with sizes below this limit. - static int SortSizeThreshold() { return sort_size_threshold_; } - static void SetSortSizeThresholdForTestingOnly(int threshold) { - // We need to be able to reduce the threshold for testing, so that the tests - // can run and compare against the reference interpreter, which is quite - // slow. - sort_size_threshold_ = threshold; + static Mode SortMode() { return sort_mode_; } + static void SetSortModeForTestingOnly(Mode sort_mode) { + // We need to be able to force rewrites for testing for arbitrary shapes. + // This enables the tests to run and compare against the reference + // interpreter, which is quite slow and needs smaller shapes that would + // normally not be rewritten. + sort_mode_ = sort_mode; } using HloPassInterface::Run; @@ -54,12 +63,10 @@ class SortRewriter : public HloModulePass { absl::StatusOr RunOnInstruction(HloSortInstruction* sort_op); absl::StatusOr RunOnComputation(HloComputation* computation); - static inline int sort_size_threshold_ = 16385; + static inline Mode sort_mode_ = Mode::kAuto; + const se::DeviceDescription device_description_; }; -// Verify that the sort tensor shape is supported by CUB. -bool IsCubCompatibleSort(const HloSortInstruction* sort_op); - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc index a01271e4155bca..5509441423e676 100644 --- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc @@ -15,10 +15,12 @@ limitations under the License. #include "xla/service/gpu/transforms/sort_rewriter.h" +#include #include #include #include +#include #include #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" @@ -26,13 +28,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/primitive_util.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/pattern_matcher.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace gpu { @@ -46,13 +49,14 @@ class SortRewriterTest public: void SetUp() override { HloTestBase::SetUp(); - SortRewriter::SetSortSizeThresholdForTestingOnly( - 0); // Always use CUB sort. 
+ SortRewriter::SetSortModeForTestingOnly(SortRewriter::Mode::kAlways); } bool RunModuleAndPass(HloModule* module) { auto cloned = module->Clone(); - bool changed = SortRewriter().Run(module).value(); + bool changed = SortRewriter(TestGpuDeviceInfo::CudaOrRocmDeviceInfo()) + .Run(module) + .value(); if (changed) { // Here we run an end to end test to make sure that SortRewriter does // not introduce an incorrect rewrite. To do this, we need to clone the @@ -315,7 +319,7 @@ ENTRY %main { // Small shapes do not see improvement from CUB sort. TEST_F(SortRewriterTest, NoRewriteSmallSize) { - SortRewriter::SetSortSizeThresholdForTestingOnly(16385); + SortRewriter::SetSortModeForTestingOnly(SortRewriter::Mode::kAuto); constexpr char kHlo[] = R"( HloModule TestModule @@ -334,6 +338,44 @@ ENTRY %main { EXPECT_FALSE(RunModuleAndPass(module.get())); } +TEST_F(SortRewriterTest, H100Heuristic) { + SortRewriter::SetSortModeForTestingOnly(SortRewriter::Mode::kAuto); + constexpr char kHloTmpl[] = R"( +HloModule TestModule + +%compare { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT +} + +ENTRY %main { + %input = f32[$0,100000] parameter(0) + ROOT %sort = f32[$0,100000] sort(%input), dimensions={1}, to_apply=%compare +})"; + + auto pass = SortRewriter(TestGpuDeviceInfo::RTXH100SXMDeviceInfo()); + + // Batch 1 + std::string hlo = absl::Substitute(kHloTmpl, "1"); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + // Batch 3 + hlo = absl::Substitute(kHloTmpl, "3"); + TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(&pass, module.get())); + EXPECT_FALSE(changed); + + // Batch 70 + hlo = absl::Substitute(kHloTmpl, "70"); + TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); +} + // Basic sort: with batch dimension. TEST_F(SortRewriterTest, SortWithBatchDim) { constexpr char kHlo[] = R"( @@ -404,7 +446,9 @@ ENTRY %main { constexpr char kExpectedPattern[] = R"( // CHECK: %[[CC:.*]] = (u16[1000]{0}, u8[1]{0}) custom-call({{.*}}), custom_call_target="__cub$DeviceRadixSort", metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68}, backend_config={"descending":true} )"; - RunAndFilecheckHloRewrite(kHlo, SortRewriter(), kExpectedPattern); + RunAndFilecheckHloRewrite( + kHlo, SortRewriter(TestGpuDeviceInfo::CudaOrRocmDeviceInfo()), + kExpectedPattern); } TEST_P(SortRewriterTest, SortNumpyOrder) { @@ -455,7 +499,7 @@ INSTANTIATE_TEST_SUITE_P( }); TEST_F(SortRewriterTest, AlwaysUsesCubSort) { - EXPECT_EQ(SortRewriter::SortSizeThreshold(), 0); + EXPECT_EQ(SortRewriter::SortMode(), SortRewriter::Mode::kAlways); } } // namespace From dfd9216e4774bb0200d57a89b7be8578a5532909 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 14 Apr 2025 08:59:00 -0700 Subject: [PATCH 0706/1324] [xla:gpu] Fix DotUnsetAlgorithmEmitterTest.UnsetAlgorithmIsEmittedCorrectly test. The previous change called `AreTypesSupportedByAlgUnsetDot()` with swapped `result_type` and `input_type` by accident. Just using the correct order doesn't work because the implementation is not symmetric. Instead, use `IsTritonSupportedComputation()` to skip unsupported configs. 
This generates more test cases, but skipping them should not be prohibitively expensive. PiperOrigin-RevId: 747439799 --- .../triton/fusion_emitter_device_test.cc | 34 ++----- .../backends/gpu/codegen/triton/support.cc | 88 +++++++++---------- .../xla/backends/gpu/codegen/triton/support.h | 6 -- 3 files changed, 52 insertions(+), 76 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index 452f6583e0b17e..3b08f5939367db 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -2880,9 +2880,12 @@ class DotUnsetAlgorithmEmitterTest }; TEST_P(DotUnsetAlgorithmEmitterTest, UnsetAlgorithmIsEmittedCorrectly) { - auto [input_type, result_type] = GetParam(); - if (!internal::AreTypesSupportedByAlgUnsetDot(input_type, result_type, - GpuComputeCapability())) { + auto [result_type, input_type] = GetParam(); + const std::string kHloText = + GetDotAlgorithmHlo(input_type, result_type, PrecisionConfig::ALG_UNSET); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + if (!IsTritonSupportedComputation(*module->entry_computation(), + GpuComputeCapability())) { GTEST_SKIP() << "Not supported on this platform."; } @@ -2892,34 +2895,13 @@ TEST_P(DotUnsetAlgorithmEmitterTest, UnsetAlgorithmIsEmittedCorrectly) { primitive_util::BitWidth(result_type) == 8) { error_spec = ErrorSpec{/*aabs=*/1e0, /*arel=*/1e-1}; } - - const std::string kHloText = - GetDotAlgorithmHlo(input_type, result_type, PrecisionConfig::ALG_UNSET); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, error_spec)); } -auto AllXlaDataTypesSupportedByAlgUnsetDotLowering() { - // We don't have a pointer to stream executor available here so we can't - // detect the particular device we're running on with a canonical API call. - // Instead, we just return a superset of the supported types (i.e. those that - // are supported on the latest device), and filter out the unsupported types - // in the test body. 
- std::vector supported_types; - ::testing::internal::ParamGenerator - all_types = ::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()), - ::testing::ValuesIn(AllXlaDataTypes())); - absl::c_copy_if( - all_types, std::back_inserter(supported_types), [](const auto& types) { - auto [result_type, input_type] = types; - return static_cast(internal::AreTypesSupportedByAlgUnsetDot( - input_type, result_type, se::CudaComputeCapability::Blackwell())); - }); - return supported_types; -} - INSTANTIATE_TEST_SUITE_P( DotUnsetAlgorithmEmitterTestSuite, DotUnsetAlgorithmEmitterTest, - ::testing::ValuesIn(AllXlaDataTypesSupportedByAlgUnsetDotLowering()), + ::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()), + ::testing::ValuesIn(AllXlaDataTypes())), DotUnsetAlgorithmEmitterTest::ParamToString); } // namespace diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index 85f2ea80c64cc5..22512f3fa9188b 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -309,6 +309,48 @@ bool IsSupportedDotAlgorithm(PrecisionConfig::Algorithm algorithm) { return false; } +CodegenDecision AreTypesSupportedByAlgUnsetDot( + PrimitiveType input_type, PrimitiveType result_type, + const se::GpuComputeCapability& gpu_version) { + if (input_type == F64 && result_type != F64) { + return CodegenDecision::Forbid( + "Dot operation only supports F64 result type for F64 input type."); + } + + if (input_type == F8E4M3FN || result_type == F8E4M3FN) { + if (auto* cuda_cc = std::get_if(&gpu_version); + cuda_cc && !cuda_cc->IsAtLeastHopper()) { + return CodegenDecision::Forbid( + "Dot operation for F8E4M3FN is not supported before Hopper."); + } + } + + auto supported_float_types = {BF16, F16, F32, F64, F8E5M2, F8E4M3FN}; + if (absl::c_linear_search(supported_float_types, input_type)) { + return CodegenDecision::Allow(); + } + + if (input_type == S8 && result_type == S32) { + return CodegenDecision::Allow(); + } + + auto partially_supported_signed_types = {S8, S16, S32, S64}; + if (absl::c_linear_search(partially_supported_signed_types, input_type)) { + if (absl::c_linear_search(partially_supported_signed_types, result_type)) { + return CodegenDecision::Forbid( + "Dot operation does not support these signed integer types."); + } + if (primitive_util::IsFloatingPointType(result_type)) { + return CodegenDecision::Forbid( + "Dot operation does not support floating point input and signed " + "integer result types."); + } + return CodegenDecision::Allow(); + } + + return CodegenDecision::Forbid("Unsupported types."); +} + // Checks whether the conversions generated during the lowering of the relevant // dot algorithm for the relevant input and output types are supported by // Triton. 
@@ -409,8 +451,8 @@ CodegenDecision IsTritonSupportedDot( } if (algorithm == PrecisionConfig::ALG_UNSET) { - if (CodegenDecision decision = internal::AreTypesSupportedByAlgUnsetDot( - lhs_type, result_type, gpu_version); + if (CodegenDecision decision = + AreTypesSupportedByAlgUnsetDot(lhs_type, result_type, gpu_version); !decision) { return decision; } @@ -619,48 +661,6 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { } } -CodegenDecision AreTypesSupportedByAlgUnsetDot( - PrimitiveType input_type, PrimitiveType result_type, - const se::GpuComputeCapability& gpu_version) { - if (input_type == F64 && result_type != F64) { - return CodegenDecision::Forbid( - "Dot operation only supports F64 result type for F64 input type."); - } - - if (input_type == F8E4M3FN || result_type == F8E4M3FN) { - if (auto* cuda_cc = std::get_if(&gpu_version); - cuda_cc && !cuda_cc->IsAtLeastHopper()) { - return CodegenDecision::Forbid( - "Dot operation for F8E4M3FN is not supported before Hopper."); - } - } - - auto supported_float_types = {BF16, F16, F32, F64, F8E5M2, F8E4M3FN}; - if (absl::c_linear_search(supported_float_types, input_type)) { - return CodegenDecision::Allow(); - } - - if (input_type == S8 && result_type == S32) { - return CodegenDecision::Allow(); - } - - auto partially_supported_signed_types = {S8, S16, S32, S64}; - if (absl::c_linear_search(partially_supported_signed_types, input_type)) { - if (absl::c_linear_search(partially_supported_signed_types, result_type)) { - return CodegenDecision::Forbid( - "Dot operation does not support these signed integer types."); - } - if (primitive_util::IsFloatingPointType(result_type)) { - return CodegenDecision::Forbid( - "Dot operation does not support floating point input and signed " - "integer result types."); - } - return CodegenDecision::Allow(); - } - - return CodegenDecision::Forbid("Unsupported types."); -} - } // namespace internal absl::Status EnsureTritonSupportsComputeCapability( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.h b/third_party/xla/xla/backends/gpu/codegen/triton/support.h index 37301424eb06c8..de2c15c6c47011 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.h @@ -71,12 +71,6 @@ namespace internal { // HLOs. This is exposed for testing purposes only and will be removed in the // near future. Do not use. This functions only returns a partial result. bool IsTritonUnsupportedOpcode(HloOpcode opcode); - -// This is exposed for testing purposes only. Do not use. -CodegenDecision AreTypesSupportedByAlgUnsetDot( - PrimitiveType input_type, PrimitiveType result_type, - const se::GpuComputeCapability& gpu_version); - } // namespace internal } // namespace gpu From 50cd97d31a27ad6cb97fb6613e0fb43f5bb746c2 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 14 Apr 2025 09:00:33 -0700 Subject: [PATCH 0707/1324] [XLA:GPU] Add support for cub-sort kernels with `f32` keys. 
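
For example (an illustrative sketch, mirroring the dispatch added in
cub_sort_thunk.cc below), sorting f32 keys with f16 values now resolves to
the new kernel:

  if (key_type == F32 && value_width == 16) sort_fn = CubSortPairs_f32_b16;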
PiperOrigin-RevId: 747440269 --- .../xla/xla/backends/gpu/runtime/cub_sort_thunk.cc | 8 ++++++-- third_party/xla/xla/service/gpu/build_defs.bzl | 3 +++ third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc | 9 +++++++++ third_party/xla/xla/service/gpu/cub_sort_kernel.h | 4 ++++ .../xla/xla/service/gpu/tests/gpu_cub_sort_test.cc | 2 +- 5 files changed, 23 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc index 8195bf0a1ab226..f4b35813f4f4c2 100644 --- a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc @@ -241,8 +241,8 @@ absl::StatusOr> CreateCubSortRunner( } // Returns an interface for calling CubSortPairs on the given key and value -// types. key_type can be only unsigned integer types. value_type can be any -// type of 16/32/64 bit width. +// types. key_type can be any unsigned integer types or F32. value_type can be +// any type of 16/32/64 bit width. absl::StatusOr> CreateCubSortRunner( PrimitiveType key_type, PrimitiveType value_type) { int value_width = primitive_util::BitWidth(value_type); @@ -260,6 +260,10 @@ absl::StatusOr> CreateCubSortRunner( if (key_type == U64 && value_width == 32) sort_fn = CubSortPairs_u64_b32; if (key_type == U64 && value_width == 64) sort_fn = CubSortPairs_u64_b64; + if (key_type == F32 && value_width == 16) sort_fn = CubSortPairs_f32_b16; + if (key_type == F32 && value_width == 32) sort_fn = CubSortPairs_f32_b32; + if (key_type == F32 && value_width == 64) sort_fn = CubSortPairs_f32_b64; + if (sort_fn == nullptr) { return InvalidArgument( "Unsupported key/value type combination for CubSortPairs: %s/%s", diff --git a/third_party/xla/xla/service/gpu/build_defs.bzl b/third_party/xla/xla/service/gpu/build_defs.bzl index 9ae3e2ab0b08f9..e5f3036865a225 100644 --- a/third_party/xla/xla/service/gpu/build_defs.bzl +++ b/third_party/xla/xla/service/gpu/build_defs.bzl @@ -42,6 +42,9 @@ def get_cub_sort_kernel_types(name = ""): "u8_b16", "u8_b32", "u8_b64", + "f32_b16", + "f32_b32", + "f32_b64", ] def build_cub_sort_kernels(name, types, local_defines = [], **kwargs): diff --git a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc index 507bd4d7da953f..e5d9db3f52ee75 100644 --- a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc @@ -244,6 +244,15 @@ XLA_CUB_DEFINE_SORT_PAIRS(u32_b32, uint32_t, uint32_t) #ifdef CUB_TYPE_U32_B64 XLA_CUB_DEFINE_SORT_PAIRS(u32_b64, uint32_t, uint64_t) #endif +#ifdef CUB_TYPE_F32_B16 +XLA_CUB_DEFINE_SORT_PAIRS(f32_b16, float, uint16_t) +#endif +#ifdef CUB_TYPE_F32_B32 +XLA_CUB_DEFINE_SORT_PAIRS(f32_b32, float, uint32_t) +#endif +#ifdef CUB_TYPE_F32_B64 +XLA_CUB_DEFINE_SORT_PAIRS(f32_b64, float, uint64_t) +#endif // Pairs with 64-bit key. 
#ifdef CUB_TYPE_U64_B16 diff --git a/third_party/xla/xla/service/gpu/cub_sort_kernel.h b/third_party/xla/xla/service/gpu/cub_sort_kernel.h index 29b163e7b1bf0b..627dd7ef079b84 100644 --- a/third_party/xla/xla/service/gpu/cub_sort_kernel.h +++ b/third_party/xla/xla/service/gpu/cub_sort_kernel.h @@ -67,6 +67,10 @@ XLA_CUB_DECLARE_SORT_PAIRS(u64_b16) XLA_CUB_DECLARE_SORT_PAIRS(u64_b32) XLA_CUB_DECLARE_SORT_PAIRS(u64_b64) +XLA_CUB_DECLARE_SORT_PAIRS(f32_b16) +XLA_CUB_DECLARE_SORT_PAIRS(f32_b32) +XLA_CUB_DECLARE_SORT_PAIRS(f32_b64) + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc index d37d322ddb32f2..eb22488d981ffa 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc @@ -306,7 +306,7 @@ ENTRY m { INSTANTIATE_TEST_SUITE_P( CubSort, CubSortPairsTest, - ::testing::Combine(::testing::Values(U8, U16, U32, U64), + ::testing::Combine(::testing::Values(U8, U16, U32, U64, F32), ::testing::Values(F16, F32, F64), ::testing::Bool(), ::testing::Values(1, 10)), [](const ::testing::TestParamInfo& info) { From 4dd87b5325363fd434bf9d697cbf99d655eac247 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 09:40:25 -0700 Subject: [PATCH 0708/1324] Add `--repo_env=USE_PYWRAP_RULES=True` to cross-compile linux arm64 builds. PiperOrigin-RevId: 747454934 --- ci/official/envs/linux_arm64_cross_compile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/official/envs/linux_arm64_cross_compile b/ci/official/envs/linux_arm64_cross_compile index e4e9004b4f1c3a..7333be2ff9fff8 100644 --- a/ci/official/envs/linux_arm64_cross_compile +++ b/ci/official/envs/linux_arm64_cross_compile @@ -13,5 +13,5 @@ # limitations under the License. # ============================================================================== source ci/official/envs/linux_arm64 -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config cross_compile_linux_arm64" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config cross_compile_linux_arm64 --repo_env=USE_PYWRAP_RULES=True" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=cross_compile_linux_arm64 From 7092acdb463e07e8ae306170f8ec57573aad0f31 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 14 Apr 2025 09:40:34 -0700 Subject: [PATCH 0709/1324] Removed unused and deprecated BUILD rule. PiperOrigin-RevId: 747454992 --- .../xla/xla/tsl/distributed_runtime/coordination/BUILD | 4 ---- 1 file changed, 4 deletions(-) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index 22ac47cdc8834d..df5838794b63b9 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -81,10 +81,6 @@ cc_library( ], ) -# TODO(mwhittaker): Remove this once tensorflow only relies on the -# coordination_service BUILD rule. -cc_library(name = "coordination_service_impl") - tf_proto_library( name = "test_device_proto", testonly = 1, From 42af8b9d73314ecc4dfb42910bd96fccd9a54eea Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 11:14:25 -0700 Subject: [PATCH 0710/1324] Fix MacOS ci failure Fix ci failure: ``` TypeError: non-default argument 'repo' follows default argument ``` PiperOrigin-RevId: 747494308 --- third_party/xla/build_tools/ci/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/build_tools/ci/build.py b/third_party/xla/build_tools/ci/build.py index 2b5faa2f8ef1cf..ce9b9464df91d4 100755 --- a/third_party/xla/build_tools/ci/build.py +++ b/third_party/xla/build_tools/ci/build.py @@ -137,8 +137,8 @@ class Build: _builds: ClassVar[Dict[BuildType, "Build"]] = {} type_: BuildType - subcommand: str = "test" repo: str + subcommand: str = "test" target_patterns: Tuple[str, ...] configs: Tuple[str, ...] = () build_tag_filters: Tuple[str, ...] = () From 1913a4f4b3642c48168939e1008165f18b20d6a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 11:18:51 -0700 Subject: [PATCH 0711/1324] Add test for triangular solve re-writer PiperOrigin-RevId: 747496012 --- .../xla/xla/service/gpu/transforms/BUILD | 18 +++- .../transforms/triangular_solve_rewriter.cc | 1 + .../triangular_solve_rewriter_test.cc | 84 +++++++++++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter_test.cc diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 307b89e6a59d39..2a643979a48205 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -3311,7 +3311,6 @@ xla_cc_test( ], ) -# TODO(b/358278858): Currently lacking test coverage. cc_library( name = "triangular_solve_rewriter", srcs = ["triangular_solve_rewriter.cc"], @@ -3331,6 +3330,23 @@ cc_library( ], ) +xla_cc_test( + name = "triangular_solve_rewriter_test", + srcs = [ + "triangular_solve_rewriter_test.cc", + ], + deps = [ + ":triangular_solve_rewriter", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/service:pattern_matcher", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "triton_fusion_numerics_verifier", srcs = ["triton_fusion_numerics_verifier.cc"], diff --git a/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc index b6cbcce83a4af1..d4562437b0d03e 100644 --- a/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc @@ -79,6 +79,7 @@ absl::StatusOr TriangularSolveRewriter::Run( TF_ASSIGN_OR_RETURN(HloInstruction * gte, MakeGetTupleElementHlo(custom_call, 0)); TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); + changed = true; } } return changed; diff --git a/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter_test.cc new file mode 100644 index 00000000000000..7936039c922e74 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/triangular_solve_rewriter.h" + +#include +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/service/pattern_matcher.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" + +namespace m = ::xla::match; + +namespace xla { +namespace gpu { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +using TriangularSolveRewriterTest = HloHardwareIndependentTestBase; + +TEST_F(TriangularSolveRewriterTest, TriangularSolveWithTranspose) { + const char* const hlo_string = R"( +HloModule TriangularSolve + +ENTRY main { + a = f32[4,4]{1,0} parameter(0) + b = f32[3,4]{1,0} parameter(1) + ROOT triangular-solve = f32[3,4]{1,0} triangular-solve(a, b), lower=true, + transpose_a=TRANSPOSE +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TriangularSolveRewriter rewriter; + EXPECT_THAT(rewriter.Run(module.get()), IsOkAndHolds(true)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement( + m::CustomCall({"__cublas$triangularSolve"})))); +} + +TEST_F(TriangularSolveRewriterTest, RightLowerNoTranspose) { + const char* const hlo_string = R"( +HloModule TriangularSolve + +ENTRY %RightLowerNoTranspose (a: f32[4,4], b: f32[3,4]) -> f32[3,4] { + a = f32[4,4]{1,0} parameter(0) + b = f32[3,4]{1,0} parameter(1) + ROOT %solve = f32[3,4]{1,0} triangular-solve(a, b), lower=true, transpose_a=NO_TRANSPOSE +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TriangularSolveRewriter rewriter; + EXPECT_THAT(rewriter.Run(module.get()), IsOkAndHolds(true)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement( + m::CustomCall({"__cublas$triangularSolve"})))); +} + +} // namespace +} // namespace gpu +} // namespace xla From 76fd4df2ea3663c550bd419bdc7df717024e0fae Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 11:25:11 -0700 Subject: [PATCH 0712/1324] Fix clang-tidy warning in xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc ``` third_party/tensorflow/compiler/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc third_party/tensorflow/compiler/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc ``` PiperOrigin-RevId: 747498431 --- third_party/xla/xla/hlo/transforms/collectives/BUILD | 3 ++- .../transforms/collectives/collectives_schedule_linearizer.cc | 2 +- .../collectives/collectives_schedule_linearizer_test.cc | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/collectives/BUILD b/third_party/xla/xla/hlo/transforms/collectives/BUILD index d1b86c51f67cb3..3fb783e8f05ef6 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/BUILD +++ b/third_party/xla/xla/hlo/transforms/collectives/BUILD @@ -182,12 +182,12 @@ cc_library( "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:errors", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", ], ) @@ -201,6 +201,7 @@ xla_cc_test( "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:test_helpers", "//xla/service:pattern_matcher", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", diff --git a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc index f5acfa5a8672c3..2508100d6c846a 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc index d93acafcffa692..ff07396caa003e 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc @@ -24,9 +24,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/test_helpers.h" #include "xla/service/pattern_matcher.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" namespace xla { From 76b5670536d41034b485611bbd54456127f3f320 Mon Sep 17 00:00:00 2001 From: Ranko Sredojevic Date: Mon, 14 Apr 2025 11:33:09 -0700 Subject: [PATCH 0713/1324] Add PassFuelIsSet to check if the pass fuel was explicitly limited, or defaults to INF. 
PiperOrigin-RevId: 747501566 --- third_party/xla/xla/debug_options_flags.cc | 7 +++++ third_party/xla/xla/debug_options_flags.h | 3 ++ .../xla/xla/debug_options_parsers_test.cc | 28 +++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 6f5e06a6e65277..16f94aa0022b59 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -2395,6 +2395,13 @@ void ResetThreadLocalFuel() { } } +bool PassFuelIsSet(absl::string_view pass) { + absl::call_once(flags_init, &AllocateFlags, nullptr); + auto* fuel_pool = thread_fuel ? thread_fuel.get() : global_fuel; + auto it = fuel_pool->find(pass); + return it != fuel_pool->end(); +} + bool ConsumeFuel(absl::string_view pass, bool* just_ran_out) { absl::call_once(flags_init, &AllocateFlags, nullptr); if (just_ran_out != nullptr) { diff --git a/third_party/xla/xla/debug_options_flags.h b/third_party/xla/xla/debug_options_flags.h index 7438750351c6c4..32ba39eb5a06a3 100644 --- a/third_party/xla/xla/debug_options_flags.h +++ b/third_party/xla/xla/debug_options_flags.h @@ -46,6 +46,9 @@ DebugOptions GetDebugOptionsFromFlags(); // Gets a DebugOptions proto that reflects the defaults as if no flags were set. DebugOptions DefaultDebugOptionsIgnoringFlags(); +// Checks whether the pass fuel was explicitly set. +bool PassFuelIsSet(absl::string_view pass); + // Consumes a unit of "compiler fuel" for the given pass, and returns false if // we're out of fuel for that pass. // diff --git a/third_party/xla/xla/debug_options_parsers_test.cc b/third_party/xla/xla/debug_options_parsers_test.cc index 2636ef7ec312cd..a1f829247b5767 100644 --- a/third_party/xla/xla/debug_options_parsers_test.cc +++ b/third_party/xla/xla/debug_options_parsers_test.cc @@ -97,6 +97,34 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(GetUppercaseStringSetterTestCases()), UppercaseStringSetterTest::Name); +TEST(FuelTest, FuelPassCountsAreSeparate) { + tsl::setenv("XLA_FLAGS", "--xla_fuel=ABC=1,PQR=2", /*overwrite=*/true); + // Parse flags from the environment variable. + int* pargc; + std::vector<char*>* pargv; + ResetFlagsFromEnvForTesting("XLA_FLAGS", &pargc, &pargv); + + EXPECT_TRUE(ConsumeFuel("ABC")); + EXPECT_FALSE(ConsumeFuel("ABC")); + + EXPECT_TRUE(ConsumeFuel("PQR")); + EXPECT_TRUE(ConsumeFuel("PQR")); + EXPECT_FALSE(ConsumeFuel("PQR")); +} + +TEST(FuelTest, + PassFuelIsSetReturnsTrueOnExplicitlyFueledPassesAndFalseOtherwise) { + tsl::setenv("XLA_FLAGS", "--xla_fuel=ABC=1,PQR=2", /*overwrite=*/true); + // Parse flags from the environment variable. + int* pargc; + std::vector<char*>* pargv; + ResetFlagsFromEnvForTesting("XLA_FLAGS", &pargc, &pargv); + + EXPECT_TRUE(PassFuelIsSet("ABC")); + EXPECT_FALSE(PassFuelIsSet("MNO")); + EXPECT_TRUE(PassFuelIsSet("PQR")); + EXPECT_FALSE(PassFuelIsSet("XYZ")); +} } // namespace } // namespace xla From 5a11d4f1b682ce94adfb75ddab473d955eae65dc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 11:34:03 -0700 Subject: [PATCH 0714/1324] Fix undefined behavior in the `pjrt` directory of OpenXLA. It is undefined behavior to cast a `const unique_ptr<TfrtCpuDevice>*` to a `const unique_ptr<PjRtDevice>*` and dereference it, even if the code doesn't attempt to mutate the `unique_ptr`, as it violates the strict aliasing rule. By making `Create()` templated so that it accepts any subclass of `PjRtDevice`, we get rid of the need for the cast and remove the undefined behavior.
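For illustration only, with hypothetical `Base`/`Derived` classes standing in for `PjRtDevice`/`TfrtCpuDevice`, the problematic pattern and the templated alternative look roughly like this:

```
#include <memory>
#include <type_traits>
#include <vector>

struct Base { virtual ~Base() = default; virtual int id() const = 0; };
struct Derived : Base { int id() const override { return 42; } };

// UB: unique_ptr<Derived> and unique_ptr<Base> are unrelated types, so
// reading unique_ptr<Derived> storage through a unique_ptr<Base>* violates
// strict aliasing even without mutation:
//   auto* bad = reinterpret_cast<const std::unique_ptr<Base>*>(v.data());
//   bad[0]->id();  // undefined behavior
//
// Safe: template the consumer over the concrete subclass instead.
template <typename T>
int SumIds(const std::vector<std::unique_ptr<T>>& devices) {
  static_assert(std::is_base_of_v<Base, T>, "T must derive from Base");
  int sum = 0;
  for (const auto& d : devices) sum += d->id();  // each element used as T
  return sum;
}
```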
Also fix a style violation by replacing a `reinterpret_cast` with `down_cast`. PiperOrigin-RevId: 747501920 --- .../mlir/tfrt/transforms/ifrt/tf2hlo_test.cc | 63 +++++++++---------- third_party/xla/xla/pjrt/c/BUILD | 1 + .../pjrt/c/pjrt_c_api_raw_buffer_external.cc | 6 +- third_party/xla/xla/pjrt/cpu/cpu_client.cc | 30 ++++----- .../xla_cpu/cpu_topology_description.cc | 15 ----- .../plugin/xla_cpu/cpu_topology_description.h | 9 +-- 6 files changed, 50 insertions(+), 74 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index 24252c40ae7da9..1bd737b98c3787 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -118,11 +118,10 @@ TEST_F(Tf2HloTest, Empty) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, {})); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -168,11 +167,10 @@ TEST_F(Tf2HloTest, Tuple) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -219,11 +217,10 @@ TEST_F(Tf2HloTest, Spmd) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -307,11 +304,10 @@ TEST_F(Tf2HloTest, UsingDefaultDeviceAssignment) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -420,11 +416,10 @@ TEST_F(Tf2HloTest, XlaCallHostCallback) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - 
xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -530,11 +525,10 @@ TEST_F(Tf2HloTest, SameArgProduceSameKeyFingerprint) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -592,11 +586,10 @@ TEST_F(Tf2HloTest, DifferentCompileMetadataProduceDifferentKeyFingerprint) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index e95f741002efcd..9b5d8c019f1a53 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -166,6 +166,7 @@ cc_library( "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:casts", ], ) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_external.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_external.cc index 4d98cc067a9f38..31143a85369f70 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_external.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_raw_buffer_external.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/pjrt/raw_buffer.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" +#include "tsl/platform/casts.h" #define PJRT_RETURN_FUTURE_IF_ERROR(expr, c_api) \ do { \ @@ -172,8 +173,9 @@ PjRtCApiBuffer_CreateRawAliasOfBuffer_Factory(PjRtBuffer* buffer) { pjrt::PjRtCApiBuffer_CreateRawAliasOfBuffer( c_api, extension, c_api_buffer->c_buffer())); return tsl::MakeRef( - raw_buffer, reinterpret_cast(c_api_buffer->client()), - c_api, extension); + raw_buffer, + tensorflow::down_cast(c_api_buffer->client()), c_api, + extension); } return std::nullopt; } diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index bde9dab1e928a5..757176a6d0d210 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -282,18 +282,6 @@ class CpuAsyncHostToDeviceTransferManager TfrtCpuDevice* device_; }; -// Converts a const span of unique_ptr to a const span of -// unique_ptr. 
This is a safe operation because the resulting span -// only permits access to elements via pointer dereference, and unique_ptr -// values remain immutable. -absl::Span> GetPjRtDeviceSpan( - absl::Span> devices) { - static_assert(std::is_base_of_v); - return absl::Span>( - reinterpret_cast*>(devices.data()), - devices.size()); -} - } // namespace static int CpuDeviceCount() { @@ -338,6 +326,19 @@ static tsl::ThreadOptions GetThreadOptions() { return thread_options; } +// Returns the CPU devices from the given TfrtCpuDevices. +// Precondition: `devices` doesn't contain nullptr. +static std::vector GetCpuDevices( + absl::Span> devices) { + std::vector cpu_devices; + cpu_devices.reserve(devices.size()); + for (const auto& device : devices) { + cpu_devices.push_back(CpuTopology::CpuDevice{ + device->process_index(), device->local_hardware_id().value()}); + } + return cpu_devices; +} + TfrtCpuClient::TfrtCpuClient( int process_index, std::vector> devices, std::shared_ptr collectives, size_t num_threads, @@ -361,9 +362,8 @@ TfrtCpuClient::TfrtCpuClient( tsl::MakeAvailableAsyncValueRef()), transpose_cache_(1024), collectives_(std::move(collectives)), - topology_(CpuTopologyDescription::Create( - platform_id(), platform_name(), platform_version(), - GetPjRtDeviceSpan(owned_devices_), cpu::DetectMachineAttributes())), + topology_(platform_id(), platform_name(), platform_version(), + GetCpuDevices(owned_devices_), cpu::DetectMachineAttributes()), asynchronous_(asynchronous), customize_hlo_module_config_(std::move(customize_hlo_module_config)) { for (const std::unique_ptr& device : owned_devices_) { diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.cc b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.cc index 60a9054588d6c8..38cb2e5a901e31 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.cc +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.cc @@ -37,21 +37,6 @@ limitations under the License. namespace xla { -/*static*/ CpuTopologyDescription CpuTopologyDescription::Create( - PjRtPlatformId platform_id, absl::string_view platform_name, - absl::string_view platform_version, - absl::Span> devices, - absl::Span machine_attributes) { - std::vector cpu_devices; - cpu_devices.reserve(devices.size()); - for (const auto& device : devices) { - cpu_devices.push_back(CpuTopology::CpuDevice{ - device->process_index(), device->local_hardware_id().value()}); - } - return CpuTopologyDescription(platform_id, platform_name, platform_version, - cpu_devices, machine_attributes); -} - absl::StatusOr CpuTopologyDescription::GetDefaultLayout( PrimitiveType element_type, absl::Span dims) const { Shape shape = ShapeUtil::MakeShape(element_type, dims); diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.h b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.h index 545644c0c7eaec..76521e2326e4ac 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.h +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -36,18 +37,12 @@ namespace xla { class CpuTopologyDescription : public PjRtTopologyDescription { public: - static CpuTopologyDescription Create( - PjRtPlatformId platform_id, absl::string_view platform_name, - absl::string_view platform_version, - absl::Span> devices, - absl::Span machine_attributes); - // `cpu_device_ids` is the list of logical device ids for the CPU devices and // will be used to initialize the CPU topology. CpuTopologyDescription(const PjRtPlatformId platform_id, const absl::string_view platform_name, const absl::string_view platform_version, - const std::vector cpu_devices, + std::vector cpu_devices, absl::Span machine_attributes) : platform_id_(platform_id), platform_name_(platform_name), From d11567ff589918888a71ca2824c89cdbbababa9d Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 14 Apr 2025 11:46:29 -0700 Subject: [PATCH 0715/1324] Make `AddTransferMetadata` warning less noisy There's little value in logging this many times. PiperOrigin-RevId: 747506385 --- third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index 0d0a18ee2b8606..bf94cf22e728dc 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -271,8 +271,8 @@ class AbstractAsyncHostToHostMemoryTransferManager void SetBufferError(int buffer_index, absl::Status error) override; void AddTransferMetadata(const TransferMetadata& meta) override { - LOG(WARNING) << "AddTransferMetadata not implemented for " - "AbstractAsyncHostToHostMemoryTransferManager"; + LOG_FIRST_N(WARNING, 1) << "AddTransferMetadata not implemented for " + "AbstractAsyncHostToHostMemoryTransferManager"; } protected: From 02edd5ba6f62ced535e89ef402a4ece426457bdb Mon Sep 17 00:00:00 2001 From: Weiyi Wang Date: Mon, 14 Apr 2025 11:50:07 -0700 Subject: [PATCH 0716/1324] Lower jnp.stack to tfl.pack. Without the rewrite, the op is lowered to a number of concatenate and reshape ops. Currently TD isn't able to represent variadic of variadic, later we need to move to C++ rewriter to support more cases. PiperOrigin-RevId: 747507670 --- .../transforms/composite_lowering_patterns.td | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index aa2d50438c0c54..4f833459cc3a07 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -212,4 +212,14 @@ def LegalizeCompositeUnpack : Pat< ConstantStrAttr, $attrs, $_, $_), (TFL_UnpackOp $inputs, (GetCompositeAttributeAs<"num", "IntegerAttr"> $attrs), - (GetCompositeAttributeAs<"axis", "IntegerAttr"> $attrs))>; \ No newline at end of file + (GetCompositeAttributeAs<"axis", "IntegerAttr"> $attrs))>; + +def LegalizeCompositePack4Elements : Pat< + (MHLO_CompositeOp:$composite + // TD not able to represent variadic of variadic now. + // Move to C++ matcher to support more cases. 
+ (variadic $i0, $i1, $i2, $i3), + ConstantStrAttr, $attrs, $_, $_), + (TFL_PackOp (variadic $i0, $i1, $i2, $i3), + (GetCompositeAttributeAs<"values_count", "IntegerAttr"> $attrs), + (GetCompositeAttributeAs<"axis", "IntegerAttr"> $attrs))>; From 86dbdc77c8b350c54a43895b47f70df52c98497a Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 14 Apr 2025 11:56:22 -0700 Subject: [PATCH 0717/1324] Implement `AbslStringify` for `PjRtLayout` PiperOrigin-RevId: 747510252 --- third_party/xla/xla/pjrt/BUILD | 1 + third_party/xla/xla/pjrt/pjrt_layout.h | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index cb842dc4687f12..887a16a348a435 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -455,6 +455,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/pjrt/pjrt_layout.h b/third_party/xla/xla/pjrt/pjrt_layout.h index e4318102bf7c1c..e3b6d44cf72a15 100644 --- a/third_party/xla/xla/pjrt/pjrt_layout.h +++ b/third_party/xla/xla/pjrt/pjrt_layout.h @@ -17,11 +17,13 @@ limitations under the License. #define XLA_PJRT_PJRT_LAYOUT_H_ #include +#include #include #include #include "absl/log/check.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/layout.h" @@ -64,10 +66,19 @@ class PjRtLayout { return H::combine(std::move(state), layout.xla_layout_); } + template + friend void AbslStringify(Sink& sink, const PjRtLayout& layout) { + absl::Format(&sink, "%s", layout.ToString()); + } + private: Layout xla_layout_; }; +inline std::ostream& operator<<(std::ostream& out, const PjRtLayout& layout) { + return out << layout.ToString(); +} + } // namespace xla #endif // XLA_PJRT_PJRT_LAYOUT_H_ From ca76136aeac05c508924df1907852b94814f7715 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 14 Apr 2025 12:04:56 -0700 Subject: [PATCH 0718/1324] [xla:gpu] CommandBuffer: move command dependencies tracking to CommandBufferCmdSequence This change doesn't require any tests as it must generate CUDA graphs with exactly the same structure, and it is already covered by existing tests. 
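A simplified sketch of the sequential rule that the new `IsSource`/`IsSink`/`Dependencies` helpers encode (plain integers here; the real code resolves recorded `se::CommandBuffer::Command` handles through the record-state manager):

```
#include <cstddef>
#include <cstdint>
#include <vector>

using CommandId = int64_t;

// Under sequential synchronization the recorded commands form the chain
// 0 -> 1 -> ... -> n-1: command 0 is the only source, command n-1 the only
// sink, and every other command depends on its predecessor.
bool IsSource(CommandId id) { return id == 0; }
bool IsSink(CommandId id, size_t n) {
  return id + 1 == static_cast<CommandId>(n);
}

std::vector<CommandId> Dependencies(CommandId id) {
  if (IsSource(id)) return {};  // sources take the caller's dependencies
  return {id - 1};              // everyone else waits on the previous command
}
```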
PiperOrigin-RevId: 747513778 --- .../xla/xla/backends/gpu/runtime/BUILD | 1 + .../gpu/runtime/command_buffer_cmd.cc | 115 +++++++++++++----- .../backends/gpu/runtime/command_buffer_cmd.h | 18 ++- 3 files changed, 99 insertions(+), 35 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 4057f2d192f699..2bcf42526d9210 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -110,6 +110,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 67fc4e89526ab5..c2e9eea2ecf047 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -155,7 +156,7 @@ static se::CommandBuffer::Builder CreateExecutionScopeBuilder( using CreateCommand = absl::FunctionRef( - absl::Span dependencies)>; + absl::Span dependencies)>; using UpdateCommand = absl::FunctionRef; @@ -301,42 +302,55 @@ CommandBufferCmdSequence::RecordCreate( // Keep a state associated with commands in the sequence in the state manager. CommandBufferCmd::StateManager& state = record_params.state; - for (std::unique_ptr& command : commands_) { + // Collect sink commands while recording the command sequence. + std::vector sink_commands; + + for (CommandId id = 0; id < commands_.size(); ++id) { + CommandBufferCmd* command = commands_[id].get(); + std::optional annotation = GetKernelAnnotation(command->profile_annotation()); // Skip recording collective commands if mock collectives are enabled. if (execute_params.mock_collectives && - dynamic_cast(command.get())) { + dynamic_cast(command)) { continue; } // Create new commands by recording them into the command buffer. - DCHECK(!state.GetOrNull(command.get(), command_buffer)) + DCHECK(!state.GetOrNull(command, command_buffer)) << "Record state must be null for " << command->ToString(); auto* record_state = - state.GetOrCreate(command.get(), command_buffer); + state.GetOrCreate(command, command_buffer); - // TODO(b/406370928): Fetch command dependencies computed from the command - // sequence, today we rely on implicit synchronization of all commands. - auto record_action = CommandBufferCmd::RecordCreate{}; + std::vector command_dependencies = + Dependencies(record_params, command_buffer, id); + + // Source command must depend on external dependencies passed by the caller, + // internal commands dependencies are defined by the command sequence + // structure (buffer and resource dependencies). + auto record_action = + IsSource(id) ? 
CommandBufferCmd::RecordCreate{dependencies} + : CommandBufferCmd::RecordCreate{command_dependencies}; TF_ASSIGN_OR_RETURN( record_state->command, command->Record(execute_params, record_params, std::move(record_action), command_buffer)); + + // Collect sink commands as external dependencies for the next command + // sequence recorded into the same command buffer. + if (IsSink(id)) { + sink_commands.push_back(record_state->command); + } } uint64_t end_micros = tsl::Env::Default()->NowMicros(); - VLOG(3) << "Created " << commands_.size() << " commands in " - << (end_micros - start_micros) << " μs"; + VLOG(3) << absl::StrFormat( + "Created %d commands in %d μs (num sink commands: %d)", commands_.size(), + end_micros - start_micros, sink_commands.size()); - // TODO(b/406370928): Depending on synchronization mode we must collect - // commands created for all sink nodes in the execution graph. - auto* last_recorded = - state.GetOrNull(commands_.back().get(), command_buffer); - DCHECK(last_recorded) << "Last recorded command state must be not null"; - return std::vector{last_recorded->command}; + return sink_commands; } absl::Status CommandBufferCmdSequence::RecordUpdate( @@ -357,24 +371,24 @@ absl::Status CommandBufferCmdSequence::RecordUpdate( // Keep a state associated with commands in the sequence in the state manager. CommandBufferCmd::StateManager& state = record_params.state; - for (std::unique_ptr& command : commands_) { + for (CommandId id = 0; id < commands_.size(); ++id) { + CommandBufferCmd* command = commands_[id].get(); + std::optional annotation = GetKernelAnnotation(command->profile_annotation()); // Skip updating collective commands if mock collectives are enabled. if (execute_params.mock_collectives && - dynamic_cast(command.get())) { + dynamic_cast(command)) { continue; } // Update existing commands in the command buffer. - auto* record_state = - state.GetOrNull(command.get(), command_buffer); + auto* record_state = state.GetOrNull(command, command_buffer); DCHECK(record_state) << "Record state must be not null for " << command->ToString(); auto record_action = CommandBufferCmd::RecordUpdate{record_state->command}; - TF_ASSIGN_OR_RETURN( record_state->command, command->Record(execute_params, record_params, std::move(record_action), @@ -398,6 +412,41 @@ absl::Status CommandBufferCmdSequence::CheckCommandBufferState( return absl::OkStatus(); } +// TODO(b/406370928): Currently we assume sequential execution order of all +// recorded commands, so we use a very simple rule for identifying source and +// sink nodes and computing dependencies. Long term we should get that from the +// ExecutionGraph helper, when using automatic synchronization mode. + +bool CommandBufferCmdSequence::IsSource(CommandId id) const { return id == 0; } + +bool CommandBufferCmdSequence::IsSink(CommandId id) const { + return id + 1 == commands_.size(); +} + +std::vector +CommandBufferCmdSequence::Dependencies(const RecordParams& record_params, + se::CommandBuffer* command_buffer, + CommandId id) const { + // Source commands have no dependencies. + if (IsSource(id)) { + return {}; + } + + // Find recorded command state for the previous command in the sequence. + auto* record_state = record_params.state.GetOrNull( + commands_[id - 1].get(), command_buffer); + DCHECK(record_state) << "Record state must be not null for " + << commands_[id - 1]->ToString(); + + // Some commands might end up not recording anything into the command buffer, + // e.g. memcpy commands where source and destination are the same. 
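+ // Such commands leave a null handle behind and must not contribute a
+ // dependency edge.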
+ if (record_state->command == nullptr) { + return {}; + } + + return {record_state->command}; +} + const absl::flat_hash_set& CommandBufferCmdSequence::buffers() const { return buffers_; @@ -512,7 +561,7 @@ TracedCommandBufferCmd::RecordTracedCommand( VLOG(5) << "Record traced command into command buffer: " << command_buffer; return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateNestedCommand(*nested_cmd, dependencies); }, [&](const se::CommandBuffer::Command* command) { @@ -559,7 +608,7 @@ absl::StatusOr ComputationIdCmd::Record( return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateMemset(&dst, value, /*num_elements=*/1, dependencies); }, @@ -632,7 +681,7 @@ absl::StatusOr LaunchCmd::Record( return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateLaunch(dims_.thread_counts_per_block(), dims_.block_counts(), *kernel, *kernel_args, dependencies); @@ -712,7 +761,7 @@ absl::StatusOr CustomKernelLaunchCmd::Record( return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateLaunch(custom_kernel_.thread_dims(), custom_kernel_.block_dims(), *kernel, kernel_args, dependencies); @@ -766,7 +815,7 @@ MemcpyDeviceToDeviceCmd::Record(const Thunk::ExecuteParams& execute_params, return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateMemcpyD2D(&dst, src, num_bytes_, dependencies); }, @@ -805,7 +854,7 @@ absl::StatusOr MemzeroCmd::Record( return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateMemset(&dst, uint8_t{0}, /*num_elements=*/dst_.size(), dependencies); @@ -847,7 +896,7 @@ absl::StatusOr Memset32Cmd::Record( return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateMemset( &dst, bit_pattern_, /*num_elements=*/dst_.size() / sizeof(uint32_t), dependencies); @@ -898,7 +947,7 @@ absl::StatusOr CaseCmd::Record( return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { if (index_is_bool_) { return command_buffer->CreateCase(se::DeviceMemory(index), std::move(branches), dependencies); @@ -970,7 +1019,7 @@ absl::StatusOr WhileCmd::Record( return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateWhile(se::DeviceMemory(pred), std::move(cond), std::move(body), dependencies); @@ -1369,7 +1418,7 @@ CustomCallCmd::RecordLegacyCustomCall( return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateNestedCommand(*nested_cmd, dependencies); }, [&](const se::CommandBuffer::Command* command) { @@ -1450,7 +1499,7 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span dependencies) { return command_buffer->CreateNestedCommand(*nested_cmd, dependencies); }, [&](const se::CommandBuffer::Command* command) { @@ -1507,7 +1556,7 @@ CollectiveCmd::RecordTracedCommand( return Handle( std::move(record_action), - [&](absl::Span dependencies) { + [&](absl::Span 
dependencies) { return command_buffer->CreateNestedCommand(*nested_cmd, dependencies); }, [&](const se::CommandBuffer::Command* command) { diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index c21bea10e80133..4f0a333a3e1a75 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -202,7 +202,7 @@ class CommandBufferCmd { // Create new commands in the command buffer using the given dependencies. struct RecordCreate { - absl::Span dependencies; + absl::Span dependencies; }; // Update previously recorded commands in the command buffer. @@ -369,7 +369,7 @@ class CommandBufferCmdSequence { // Records commands into the command buffer. This method automatically // switches between `RecordCreate` or `RecordUpdate` depending on the command - // buffer state. This method assumes that no other command buffer sequences is + // buffer state. This method assumes that no other command buffer sequence is // recorded into the same command buffer, and doesn't set up initial // dependencies for recorded commands. // @@ -411,6 +411,9 @@ class CommandBufferCmdSequence { } private: + // We use index into the `commands_` vector as a command id. + using CommandId = int64_t; + // A state associated with commands in the sequence. We rely on this state to // efficiently update command recorded into the command buffer. struct RecordState : public CommandBufferCmd::State { @@ -424,6 +427,17 @@ class CommandBufferCmdSequence { absl::Status CheckCommandBufferState(se::CommandBuffer* command_buffer, se::CommandBuffer::State expected_state); + // Returns true if command has no dependencies. + bool IsSource(CommandId id) const; + + // Returns true if command is not a dependency of any other commands. + bool IsSink(CommandId id) const; + + // Returns dependencies of the command with the given id. + std::vector Dependencies( + const RecordParams& record_params, se::CommandBuffer* command_buffer, + CommandId id) const; + SynchronizationMode synchronization_mode_; std::vector> commands_; From a1c4864e166fe125d54262daa2252e8cfb6b585d Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Mon, 14 Apr 2025 12:21:08 -0700 Subject: [PATCH 0719/1324] Fix the order of population of `ManualAxesHierarchy`. The hierarchy's region has the following order. * Before: local + outer parent + inner parent * After: outer parent + inner parent + local Given a nested manual computations with 3 levels, ``` manual_computation_level_0, manaul_axes_level_0 manual_computation_level_1, manaul_axes_level_1 manual_computation_level_2, manaul_axes_level_2 ``` The region manual axes for the level_2 manual computation has the following order * Before: manaul_axes_level_2 + manaul_axes_level_0 + manaul_axes_level_1 * After: manaul_axes_level_0 + manaul_axes_level_1 + manaul_axes_level_2 The new order are sorted, while the old one is unsorted. 
PiperOrigin-RevId: 747519973 --- .../stablehlo_round_trip/shard_map_export.cc | 12 ++-- ...stablehlo_round_trip_shard_map_export.mlir | 55 +++++++++++-------- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc index 5e1997d6f92c21..4b9d1de4b450a6 100644 --- a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc +++ b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/shard_map_export.cc @@ -89,8 +89,8 @@ using ManualComputationToParentManualAxes = // Given an ManualComputationOp `op`, `op.getManualAxes()` is the local manual // axes. `parent` is the manual axes of its parent ManualComputationOp, -// recursively. `region` is the concatenation of `op.getManualAxes()` and -// `parent`. +// recursively. `region` is the concatenation of `parent` and +// `op.getManualAxes()`. struct ManualAxesHierarchy { mlir::ArrayRef parent; SmallVector region; @@ -100,15 +100,15 @@ ManualAxesHierarchy getManualAxesHierarchy( ManualComputationOp op, const ManualComputationToParentManualAxes& parentManualCompAxes) { ManualAxesHierarchy hierarchy; - hierarchy.region = SmallVector(op.getManualAxes().begin(), - op.getManualAxes().end()); if (auto parentManualAxes = parentManualCompAxes.find(op); parentManualAxes != parentManualCompAxes.end()) { hierarchy.parent = parentManualAxes->getSecond(); - hierarchy.region.append(parentManualAxes->getSecond().begin(), - parentManualAxes->getSecond().end()); } + + hierarchy.region = + SmallVector(hierarchy.parent.begin(), hierarchy.parent.end()); + hierarchy.region.append(op.getManualAxes().begin(), op.getManualAxes().end()); return hierarchy; } diff --git a/third_party/xla/xla/service/spmd/shardy/test/stablehlo_round_trip_shard_map_export.mlir b/third_party/xla/xla/service/spmd/shardy/test/stablehlo_round_trip_shard_map_export.mlir index 03fb7347393061..20fb846fe06898 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/stablehlo_round_trip_shard_map_export.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/stablehlo_round_trip_shard_map_export.mlir @@ -85,25 +85,32 @@ func.func @call_op_with_no_operands_or_results() { } // CHECK-LABEL: func @nested_shmaps -func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { - // CHECK-NEXT: %[[COPY_OPERAND_OUTER:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : tensor<4x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> - // CHECK-NEXT: %[[COPY_OPERAND_INNER:.*]] = mhlo.copy %[[FULL_TO_SHARD_OUTER]] {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> - // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : 
tensor<2x4xf32> - // CHECK-NEXT: %[[COPY_RESULT_INNER:.*]] = mhlo.copy %[[MULT]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> - // CHECK-NEXT: %[[COPY_RESULT_OUTER:.*]] = mhlo.copy %[[SHARD_TO_FULL_INNER]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_OUTER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_OUTER]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<4x8xf32> - // CHECK-NEXT: return %[[SHARD_TO_FULL_OUTER]] : tensor<4x8xf32> - %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { - %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { - %2 = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> - sdy.return %2 : tensor<2x4xf32> - } : (tensor<2x8xf32>) -> tensor<2x8xf32> - sdy.return %1 : tensor<2x8xf32> - } : (tensor<4x8xf32>) -> tensor<4x8xf32> - return %0 : tensor<4x8xf32> +func.func @nested_shmaps(%arg0: tensor<4x8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}, {"c"}]>}) -> (tensor<4x8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}, {?}]>}) { + // CHECK-NEXT: %[[COPY_OPERAND_LEVEL_0:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,1,1,8]<=[16] last_tile_dim_replicate}"} : tensor<4x8x16xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_LEVEL_0:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_LEVEL_0]]) {mhlo.sharding = "{devices=[1,1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8x16xf32>) -> tensor<2x8x16xf32> + // CHECK-NEXT: %[[COPY_OPERAND_LEVEL_1:.*]] = mhlo.copy %[[FULL_TO_SHARD_LEVEL_0]] {mhlo.sharding = "{devices=[1,2,1,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x8x16xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_LEVEL_1:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_LEVEL_1]]) {mhlo.sharding = "{devices=[1,1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<2x8x16xf32>) -> tensor<2x4x16xf32> + // CHECK-NEXT: %[[COPY_OPERAND_LEVEL_2:.*]] = mhlo.copy %[[FULL_TO_SHARD_LEVEL_1]] {mhlo.sharding = "{devices=[1,1,2,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4x16xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_LEVEL_2:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_LEVEL_2]]) {mhlo.sharding = "{devices=[1,1,1,8,2]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<2x4x16xf32>) -> tensor<2x4x8xf32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD_LEVEL_2]], %[[FULL_TO_SHARD_LEVEL_2]] {mhlo.sharding = "{devices=[1,1,1,8,2]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x4x8xf32> + // CHECK-NEXT: %[[COPY_RESULT_LEVEL_2:.*]] = mhlo.copy %[[MULT]] {mhlo.sharding = "{devices=[1,1,1,8,2]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x4x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_LEVEL_2:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_LEVEL_2]]) {mhlo.sharding = "{devices=[1,1,2,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, 
replicated}}"} : (tensor<2x4x8xf32>) -> tensor<2x4x16xf32> + // CHECK-NEXT: %[[COPY_RESULT_LEVEL_1:.*]] = mhlo.copy %[[SHARD_TO_FULL_LEVEL_2]] {mhlo.sharding = "{devices=[1,1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x4x16xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_LEVEL_1:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_LEVEL_1]]) {mhlo.sharding = "{devices=[1,2,1,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4x16xf32>) -> tensor<2x8x16xf32> + // CHECK-NEXT: %[[COPY_RESULT_LEVEL_0:.*]] = mhlo.copy %[[SHARD_TO_FULL_LEVEL_1]] {mhlo.sharding = "{devices=[1,1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8x16xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_LEVEL_0:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_LEVEL_0]]) {mhlo.sharding = "{devices=[2,1,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8x16xf32>) -> tensor<4x8x16xf32> + // CHECK-NEXT: return %[[SHARD_TO_FULL_LEVEL_0]] : tensor<4x8x16xf32> + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8x16xf32>) { + %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}, {}]>] out_shardings=[<@mesh_1, [{}, {"b"}, {}]>] manual_axes={"b"} (%arg2: tensor<2x4x16xf32>) { + %2 = sdy.manual_computation(%arg2) in_shardings=[<@mesh_1, [{}, {}, {"c"}]>] out_shardings=[<@mesh_1, [{}, {}, {"c"}]>] manual_axes={"c"} (%arg3: tensor<2x4x8xf32>) { + %3 = stablehlo.multiply %arg3, %arg3 : tensor<2x4x8xf32> + sdy.return %3 : tensor<2x4x8xf32> + } : (tensor<2x4x16xf32>) -> tensor<2x4x16xf32> + sdy.return %2 : tensor<2x4x16xf32> + } : (tensor<2x8x16xf32>) -> tensor<2x8x16xf32> + sdy.return %1 : tensor<2x8x16xf32> + } : (tensor<4x8x16xf32>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> } // CHECK-LABEL: func @nested_shmaps_extra_op @@ -111,11 +118,11 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sh // CHECK-NEXT: %[[COPY_OPERAND_OUTER:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : tensor<4x8xf32> // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[COPY_OPERAND_INNER:.*]] = mhlo.copy %[[FULL_TO_SHARD_OUTER]] {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> - // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[MULT]], %[[MULT]] {mhlo.sharding = "{devices=[2,1,4,2]<=[2,2,2,2]T(2,1,0,3) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[SUB:.*]] = stablehlo.subtract %[[ADD]], %[[ADD]] {mhlo.sharding = "{devices=[4,1,4]<=[2,2,4]T(2,1,0) last_tile_dims={manual}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[COPY_RESULT_INNER:.*]] = mhlo.copy %[[SUB]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : 
tensor<2x4xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[MULT]], %[[MULT]] {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[SUB:.*]] = stablehlo.subtract %[[ADD]], %[[ADD]] {mhlo.sharding = "{devices=[4,1,4]<=[4,4]T(1,0) last_tile_dims={manual}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[COPY_RESULT_INNER:.*]] = mhlo.copy %[[SUB]] {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_TO_FULL_INNER]], %[[SHARD_TO_FULL_INNER]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> // CHECK-NEXT: %[[COPY_RESULT_OUTER:.*]] = mhlo.copy %[[ADD]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> From c8acf1b46151b776cc60a4f1e949b01e91fdfe27 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 12:21:50 -0700 Subject: [PATCH 0720/1324] Fix an issue in the `StatsCalculator::OrderNodesByMetric` function to correct run_order based ordering behavior. PiperOrigin-RevId: 747520228 --- .../xla/xla/tsl/util/stats_calculator.cc | 38 ++++++++++++++----- .../xla/xla/tsl/util/stats_calculator.h | 3 ++ .../xla/xla/tsl/util/stats_calculator_test.cc | 26 +++++++++++++ 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/tsl/util/stats_calculator.cc b/third_party/xla/xla/tsl/util/stats_calculator.cc index cdfa46c94417c3..eddeea382f6f58 100644 --- a/third_party/xla/xla/tsl/util/stats_calculator.cc +++ b/third_party/xla/xla/tsl/util/stats_calculator.cc @@ -15,11 +15,16 @@ limitations under the License. #include "xla/tsl/util/stats_calculator.h" +#include +#include #include +#include #include #include #include #include +#include +#include namespace tsl { @@ -103,37 +108,52 @@ std::string StatsCalculator::ColumnString(const Detail& detail, void StatsCalculator::OrderNodesByMetric( SortingMetric metric, std::vector* details) const { + // We convert each metric value to a string and use a priority queue to keep + // them sorted in descending order. For cases where we want to sort in + // ascending order, we transform the metric value to a string that + // represents the inverse of the original value. std::priority_queue> sorted_list; - const int num_nodes = details_.size(); - for (const auto& det : details_) { - const Detail* detail = &(det.second); + // We keep the run order metric in ascending order, so we need to know the + // maximum run order in order to sort the nodes correctly. 
+ int max_run_order = 0; + for (const auto& [_, detail] : details_) { + max_run_order = std::max(max_run_order, detail.run_order); + } + const int num_nodes = max_run_order; + + for (const auto& [_, detail] : details_) { std::stringstream stream; stream << std::setw(20) << std::right << std::setprecision(10) << std::fixed; switch (metric) { case BY_NAME: - stream << detail->name; + // Sorted in ascending order of length of name. + stream << detail.name; break; case BY_RUN_ORDER: - stream << num_nodes - detail->run_order; + // Sorted in ascending order. + stream << num_nodes - detail.run_order; break; case BY_TIME: - stream << detail->elapsed_time.avg(); + // Sorted in descending order. + stream << detail.elapsed_time.avg(); break; case BY_MEMORY: - stream << detail->mem_used.avg(); + // Sorted in descending order. + stream << detail.mem_used.avg(); break; case BY_TYPE: - stream << detail->type; + // Sorted in ascending order of length of type. + stream << detail.type; break; default: stream << ""; break; } - sorted_list.emplace(stream.str(), detail); + sorted_list.emplace(stream.str(), &detail); } while (!sorted_list.empty()) { diff --git a/third_party/xla/xla/tsl/util/stats_calculator.h b/third_party/xla/xla/tsl/util/stats_calculator.h index 253895ca605fae..e5b0d6854f4bdb 100644 --- a/third_party/xla/xla/tsl/util/stats_calculator.h +++ b/third_party/xla/xla/tsl/util/stats_calculator.h @@ -217,6 +217,9 @@ class StatsCalculator { int64_t run_order, int64_t rel_end_us, int64_t mem_used); private: + // Orders the nodes in the details_ map by the given sorting metric. The + // details vector is populated with pointers to the Detail objects in the + // details_ map. void OrderNodesByMetric(SortingMetric sorting_metric, std::vector* details) const; diff --git a/third_party/xla/xla/tsl/util/stats_calculator_test.cc b/third_party/xla/xla/tsl/util/stats_calculator_test.cc index bbd75845f583d6..bf9401e4a5cb78 100644 --- a/third_party/xla/xla/tsl/util/stats_calculator_test.cc +++ b/third_party/xla/xla/tsl/util/stats_calculator_test.cc @@ -18,7 +18,10 @@ limitations under the License. #include #include #include +#include +#include +#include #include "xla/tsl/platform/test.h" namespace tsl { @@ -140,5 +143,28 @@ TEST(StatsCalculatorTest, StatWithPercentiles) { EXPECT_EQ(150, stat.percentile(100)); } +TEST(StatsCalculatorTest, + VerifyOrderStatsByRunOrderForMaxRunOrderLargerThanDetailsSize) { + auto options = StatSummarizerOptions(); + StatsCalculator calc(options); + EXPECT_TRUE(calc.GetDetails().empty()); + + calc.AddNodeStats("node1", "type_1", 1, 10, 20); + ASSERT_EQ(calc.GetDetails().size(), 1); + + calc.AddNodeStats("node1", "type_1", 2, 11, 21); + ASSERT_EQ(calc.GetDetails().size(), 1); + calc.AddNodeStats("node2", "type_2", 3, 10, 100); + ASSERT_EQ(calc.GetDetails().size(), 2); + calc.UpdateRunTotalUs(100); + std::string stats = calc.GetStatsByMetric( + "test", StatsCalculator::SortingMetric::BY_RUN_ORDER, 0); + ASSERT_GT(stats.size(), 0); + ASSERT_THAT(stats, ::testing::HasSubstr("node1")); + ASSERT_THAT(stats, ::testing::HasSubstr("node2")); + // Ensure that node1 has a lower run order than node2 in the stats. 
+ ASSERT_LT(stats.find("node1"), stats.find("node2")); +} + } // namespace } // namespace tsl From 688234a972009b0f2bf3868c294dcde98f0f8122 Mon Sep 17 00:00:00 2001 From: gaikwadrahul8 Date: Tue, 15 Apr 2025 01:39:40 +0530 Subject: [PATCH 0721/1324] Fix C compatibility issue in TfLiteQuantizationType enum --- tensorflow/lite/core/c/common.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 57caa3b759a35d..3f1fe32b8b4f47 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -322,7 +322,11 @@ typedef struct TfLiteBFloat16 { const char* TfLiteTypeGetName(TfLiteType type); /// SupportedQuantizationTypes. +#ifdef __cplusplus typedef enum TfLiteQuantizationType : int { +#else +typedef enum TfLiteQuantizationType { +#endif /// No quantization. kTfLiteNoQuantization = 0, /// Affine quantization (with support for per-channel quantization). From de481154c75b70c9fedda6ab259dde48f3ad34b8 Mon Sep 17 00:00:00 2001 From: Xuefei Jiang Date: Mon, 14 Apr 2025 12:42:30 -0700 Subject: [PATCH 0722/1324] PR #25171: [ROCm] adjust workspace size for gfx950 Imported from GitHub PR https://github.com/openxla/xla/pull/25171 This PR adjusts the workspace size for gfx950 to be 64MB, which is needed in hipblaslt. Copybara import of the project: -- 78ce17eb70f5252e6882c5c2a6f564056f53dad4 by scxfjiang : enable ocp fp8 for gfx950 Merging this change closes #25171 PiperOrigin-RevId: 747527967 --- third_party/xla/xla/service/gpu/matmul_utils.h | 1 + .../xla/service/gpu/transforms/gemm_rewriter.cc | 15 +++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h index 7f3be061fc59d5..b9bf8b1408eb88 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.h +++ b/third_party/xla/xla/service/gpu/matmul_utils.h @@ -82,6 +82,7 @@ struct GemmConfig : public se::gpu::GemmConfig { // Size of the workspace based on NVIDIA recommendation: // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace static constexpr int64_t kHopperWorkspace = 32 * 1024 * 1024; // 32 MiB + static constexpr int64_t kGFX950Workspace = 64 * 1024 * 1024; // 64 MiB static constexpr int64_t kDefaultWorkspace = 4 * 1024 * 1024; // 4 MiB static absl::StatusOr For( diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index 342e1e667654a9..76ad29ef84a637 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -2441,15 +2441,18 @@ class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - auto *cuda_cc = std::get_if(&gpu_version_); - // Pass a user-managed workspace to legacy cuBLAS operations, as // otherwise cuBLAS will use its own internal pool which will be competing // with XLA allocator for device memory. - int64_t workspace = cuda_cc == nullptr ? GemmConfig::kDefaultWorkspace - : cuda_cc->IsAtLeastHopper() - ? 
GemmConfig::kHopperWorkspace - : GemmConfig::kDefaultWorkspace; + int64_t workspace = GemmConfig::kDefaultWorkspace; + auto *cuda_cc = std::get_if(&gpu_version_); + if (cuda_cc != nullptr && cuda_cc->IsAtLeastHopper()) { + workspace = GemmConfig::kHopperWorkspace; + } + auto *rocm_cc = std::get_if(&gpu_version_); + if (rocm_cc != nullptr && rocm_cc->gfx_version() == "gfx950") { + workspace = GemmConfig::kGFX950Workspace; + } // We do not know the workspace size required by cuBLAS, but we can guess // that in a worst case cuBLAS will transpose all operands into tiled From 2ee6516c779efe45c38dc09dbbd59e5d02e42804 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 14 Apr 2025 12:43:15 -0700 Subject: [PATCH 0723/1324] PR #25131: [ROCM] Clean up rocm ci asan build Imported from GitHub PR https://github.com/openxla/xla/pull/25131 Modify rocm ci build to include asan flags only in case of hermetic build. Copybara import of the project: -- 4d1d706893d20e7efbcd4fa6c6aad4e12ad8443c by alekstheod : ignore abseil asan issue and use asan args -- 5c4086ef60dc197b3b01779066bb58d23a4baf7a by alekstheod : Fix asan issue -- d192a22989a09af3fce86994eab8d73b02f088dc by alekstheod : Revert switch to std::map Merging this change closes #25131 PiperOrigin-RevId: 747528270 --- third_party/xla/build_tools/rocm/run_xla_ci_build.sh | 10 ++++++++-- third_party/xla/tensorflow.bazelrc | 1 - 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/third_party/xla/build_tools/rocm/run_xla_ci_build.sh b/third_party/xla/build_tools/rocm/run_xla_ci_build.sh index dbc6b0321f31be..e7de2182f3d1bc 100755 --- a/third_party/xla/build_tools/rocm/run_xla_ci_build.sh +++ b/third_party/xla/build_tools/rocm/run_xla_ci_build.sh @@ -21,6 +21,13 @@ set -x CONFIG=$1 DISK_CACHE_PATH=$2 +ASAN_ARGS=() +if [[ $CONFIG == "rocm_ci_hermetic" ]]; then + ASAN_ARGS+=("--test_env=ASAN_OPTIONS=suppressions=$(realpath $(dirname $0))/asan_ignore_list.txt") + ASAN_ARGS+=("--test_env=LSAN_OPTIONS=suppressions=$(realpath $(dirname $0))/lsan_ignore_list.txt") + ASAN_ARGS+=("--config=asan") +fi + bazel --bazelrc=/usertools/rocm.bazelrc test \ --config=${CONFIG} \ --config=xla_cpp \ @@ -36,5 +43,4 @@ bazel --bazelrc=/usertools/rocm.bazelrc test \ --test_output=errors \ --local_test_jobs=2 \ --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute \ - --test_env="ASAN_OPTIONS=suppressions=$(realpath $(dirname $0))/asan_ignore_list.txt" \ - --test_env="LSAN_OPTIONS=suppressions=$(realpath $(dirname $0))/lsan_ignore_list.txt" + "${ASAN_ARGS[@]}" diff --git a/third_party/xla/tensorflow.bazelrc b/third_party/xla/tensorflow.bazelrc index 98fdeb4f3fcca3..f745e058cb5f23 100644 --- a/third_party/xla/tensorflow.bazelrc +++ b/third_party/xla/tensorflow.bazelrc @@ -246,7 +246,6 @@ build:rocm_clang_official --host_linkopt="-fuse-ld=lld" build:rocm_ci --config=rocm_clang_official -build:rocm_ci_hermetic --config=asan build:rocm_ci_hermetic --config=rocm_clang_official build:rocm_ci_hermetic --repo_env="OS=ubuntu_22.04" build:rocm_ci_hermetic --repo_env="ROCM_VERSION=6.2.0" From 6f276f3bc059dcf69f32dcea80741cfaf7751029 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 14 Apr 2025 12:44:08 -0700 Subject: [PATCH 0724/1324] PR #25059: Fix clamping of constants in dynamic memcpy thunk. Imported from GitHub PR https://github.com/openxla/xla/pull/25059 We correctly clamp dynamic offsets, but constant offsets are simply taken as is, leading to incorrect results and possible out-of-bounds accesses. 
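
As a sketch of the fix (a hypothetical standalone helper mirroring the clamp
added to copy.cc below; it is not itself part of the patch):

    #include <algorithm>
    #include <cstdint>

    // Clamp a constant offset to [0, dimension size - slice size].
    int64_t ClampConstantOffset(int64_t value, int64_t dim_size,
                                int64_t slice_size) {
      const int64_t max = dim_size - slice_size;
      return std::max<int64_t>(0, std::min(value, max));
    }

    // For an f32[200] operand dynamic-sliced to f32[100], as in the new tests:
    //   ClampConstantOffset(195, 200, 100) == 100  (byte offset 100 * 4 = 400)
    //   ClampConstantOffset(-1, 200, 100)  == 0    (byte offset 0)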
Copybara import of the project: -- 5940d5f96a8f3e106b0cc7051cc0a06f183a07a0 by Johannes Reifferscheid : Fix clamping of constants in dynamic memcpy thunk. We correctly clamp dynamic offsets, but constant offsets are simply taken as is, leading to incorrect results and possible out-of-bounds accesses. Merging this change closes #25059 PiperOrigin-RevId: 747528625 --- .../xla/xla/backends/gpu/codegen/copy.cc | 22 +++++--- .../xla/xla/backends/gpu/codegen/copy_test.cc | 52 +++++++++++++++++++ 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/copy.cc b/third_party/xla/xla/backends/gpu/codegen/copy.cc index 7a779ce7f2876d..b7598bc8f035eb 100644 --- a/third_party/xla/xla/backends/gpu/codegen/copy.cc +++ b/third_party/xla/xla/backends/gpu/codegen/copy.cc @@ -133,17 +133,21 @@ absl::StatusOr DynamicMemcpyFusion::Emit( namespace { +// Returns the slice size in the given dimension for a dynamic-(update-)slice +// instruction. +int64_t GetSliceSize(const HloInstruction* slice, int dim) { + if (slice->opcode() == HloOpcode::kDynamicSlice) { + return slice->dynamic_slice_sizes()[dim]; + } + CHECK_EQ(slice->opcode(), HloOpcode::kDynamicUpdateSlice); + return slice->operand(1)->shape().dimensions(dim); +} + // Whether the offset in the given dimension of the slice operation is // guaranteed to be clamped to 0. This is the case if the slice size is the // same as the size of the dimension in the unsliced shape. bool IsZeroOffset(const HloInstruction* slice, int dim) { - if (slice->opcode() == HloOpcode::kDynamicSlice) { - return slice->dynamic_slice_sizes()[dim] == - slice->operand(0)->shape().dimensions(dim); - } - CHECK_EQ(slice->opcode(), HloOpcode::kDynamicUpdateSlice); - return slice->operand(1)->shape().dimensions(dim) == - slice->operand(0)->shape().dimensions(dim); + return GetSliceSize(slice, dim) == slice->operand(0)->shape().dimensions(dim); } std::vector GetCallStack( @@ -264,6 +268,10 @@ DynamicMemcpyFusion::GetMemcpyDescriptorForFusion( return std::nullopt; } + // Clamp the offset to [0; dimension size - slice size]. 
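+  // Constant offsets get the same treatment as dynamic ones; HLO
+  // dynamic-slice semantics clamp start indices to this same range.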
+ int64_t max = + slice->operand(0)->shape().dimensions(i) - GetSliceSize(slice, i); + *value = std::max(0, std::min(*value, max)); VLOG(5) << "Offset for dimension " << i << " is constant: " << *value << "."; static_offset += *value * (*strides)[i]; diff --git a/third_party/xla/xla/backends/gpu/codegen/copy_test.cc b/third_party/xla/xla/backends/gpu/codegen/copy_test.cc index 3f606e84e6aea4..523be7646081f9 100644 --- a/third_party/xla/xla/backends/gpu/codegen/copy_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/copy_test.cc @@ -25,6 +25,8 @@ namespace { using CopyFusionTest = HloHardwareIndependentTestBase; +using ::testing::IsEmpty; + const HloFusionInstruction& GetFusion(HloModule* module) { const HloInstruction* fusion = module->GetComputationWithName("dynamic_slice")->FusionInstruction(); @@ -92,6 +94,56 @@ TEST_F(CopyFusionTest, ValidCandidateClamped) { EXPECT_TRUE(DynamicMemcpyFusion::IsCandidateFusion(GetFusion(module.get()))); } +TEST_F(CopyFusionTest, ClampedConstantPositive) { + auto module = ParseAndReturnVerifiedModule(R"( + dynamic_slice { + p0 = f32[200] parameter(0) + c195 = s32[] constant(195) + ROOT slice = f32[100] dynamic-slice(p0, c195), dynamic_slice_sizes={100} + } + + ENTRY main { + p0 = f32[200] parameter(0) + ROOT fusion = f32[100] fusion(p0), kind=kLoop, calls=dynamic_slice + } + )") + .value(); + + auto descriptor = DynamicMemcpyFusion::GetMemcpyDescriptorForFusion( + GetFusion(module.get())); + + ASSERT_TRUE(descriptor.has_value()); + EXPECT_THAT(descriptor->src_dynamic_offsets, IsEmpty()); + EXPECT_THAT(descriptor->dst_dynamic_offsets, IsEmpty()); + EXPECT_EQ(descriptor->src_byte_static_offset, sizeof(float) * 100); + EXPECT_EQ(descriptor->dst_byte_static_offset, 0); +} + +TEST_F(CopyFusionTest, ClampedConstantNegative) { + auto module = ParseAndReturnVerifiedModule(R"( + dynamic_slice { + p0 = f32[200] parameter(0) + cn1 = s32[] constant(-1) + ROOT slice = f32[100] dynamic-slice(p0, cn1), dynamic_slice_sizes={100} + } + + ENTRY main { + p0 = f32[200] parameter(0) + ROOT fusion = f32[100] fusion(p0), kind=kLoop, calls=dynamic_slice + } + )") + .value(); + + auto descriptor = DynamicMemcpyFusion::GetMemcpyDescriptorForFusion( + GetFusion(module.get())); + + ASSERT_TRUE(descriptor.has_value()); + EXPECT_THAT(descriptor->src_dynamic_offsets, IsEmpty()); + EXPECT_THAT(descriptor->dst_dynamic_offsets, IsEmpty()); + EXPECT_EQ(descriptor->src_byte_static_offset, 0); + EXPECT_EQ(descriptor->dst_byte_static_offset, 0); +} + constexpr char kSliceMemcpyModule[] = R"( dynamic_slice { p0 = s32[4,8,8] parameter(0) From 805650700b78e038ad396cce3c4ce7fdb01f37c2 Mon Sep 17 00:00:00 2001 From: Arian Arfaian Date: Mon, 14 Apr 2025 12:44:12 -0700 Subject: [PATCH 0725/1324] Fix incorrect predicate for input and filter rank on FC LHS const swap. 
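
With the previous `&&`, the rewrite only bailed out when both the input and
the filter rank differed from 2, so a rank-3 LHS constant (exercised by the
new FullyConnectedSwapOperandsWhenLHSIsConstLHSRank3 test) was still swapped.
With `||`, the rewrite bails out as soon as either rank is not 2.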
PiperOrigin-RevId: 747528648 --- tensorflow/compiler/mlir/lite/tests/optimize.mlir | 9 +++++++++ .../compiler/mlir/lite/transforms/optimize_pass.cc | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 83b82d50fc064f..72edff4f21f5f7 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -4756,3 +4756,12 @@ func.func @FullyConnectedSwapOperandsWhenLHSIsConstFusedActivationFunction(%arg0 // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], %arg1) } +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstLHSRank3 +func.func @FullyConnectedSwapOperandsWhenLHSIsConstLHSRank3(%arg0: tensor<512x512xf32>, %arg1: none) -> tensor<1x1x512xf32> { + %cst = arith.constant dense<1.0> : tensor<1x1x512xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "RELU", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x512xf32>, tensor<512x512xf32>, none) -> tensor<1x1x512xf32> + func.return %0 : tensor<1x1x512xf32> + + // CHECK: %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) +} + diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc index 1de8398dbd8bd9..6cb79a635427ca 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc @@ -2978,7 +2978,7 @@ struct FullyConnectedSwapOperandsWhenLHSIsConst if (!input_type || !filter_type || !output_type) return failure(); - if (input_type.getRank() != 2 && filter_type.getRank() != 2) + if (input_type.getRank() != 2 || filter_type.getRank() != 2) return failure(); // Dimensions: B=Batch, I=InputDepth, O=OutputDepth From 90ed039dd960d23a3d15438ddda639568cffc4e3 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 14 Apr 2025 12:58:05 -0700 Subject: [PATCH 0726/1324] [xla:gpu] CommandBuffer: introduce UpdateCommands callback and explicitly use RecordUpdate to update conditional command buffers PiperOrigin-RevId: 747534044 --- .../gpu/runtime/command_buffer_cmd.cc | 69 +++++++++++-------- .../backends/gpu/runtime/command_buffer_cmd.h | 12 ++-- .../xla/xla/stream_executor/command_buffer.h | 23 ++++--- .../stream_executor/gpu/gpu_command_buffer.cc | 39 +++++------ .../stream_executor/gpu/gpu_command_buffer.h | 10 +-- 5 files changed, 87 insertions(+), 66 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index c2e9eea2ecf047..86bcc4ae7a7038 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -138,18 +138,30 @@ static std::vector CreateBuilders( return builders; } -// Creates command buffer execution scope builder from a cmd sequence. -static se::CommandBuffer::Builder CreateExecutionScopeBuilder( - CommandBufferCmdSequence* commands, +// Create a callback to update a command buffer with command sequence. 
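+// Unlike a Builder, the returned callback may only refresh the parameters of
+// previously recorded commands (e.g. device memory pointers); it cannot
+// change the dependency structure of the command buffer.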
+static se::CommandBuffer::UpdateCommands UpdateCommands( + const CommandBufferCmdSequence* commands, const Thunk::ExecuteParams* execute_params, const CommandBufferCmd::RecordParams* record_params) { return [=](se::CommandBuffer* command_buffer) { - CommandBufferCmd::RecordParams params = *record_params; - return commands->Record(*execute_params, params, command_buffer, - CommandBufferCmdSequence::RecordMode::kConditional); + return commands->RecordUpdate(*execute_params, *record_params, + command_buffer); }; } +// Create callbacks to update a command buffer with command sequence. +static std::vector UpdateCommands( + absl::Span commands, + const Thunk::ExecuteParams* execute_params, + const CommandBufferCmd::RecordParams* record_params) { + std::vector update_commands; + for (const CommandBufferCmdSequence& cmd : commands) { + update_commands.push_back( + UpdateCommands(&cmd, execute_params, record_params)); + } + return update_commands; +} + //===----------------------------------------------------------------------===// // CommandBufferCmd::RecordAction helpers. //===----------------------------------------------------------------------===// @@ -355,7 +367,8 @@ CommandBufferCmdSequence::RecordCreate( absl::Status CommandBufferCmdSequence::RecordUpdate( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { + const RecordParams& record_params, + se::CommandBuffer* command_buffer) const { // Command buffer must be already prepared for recording updates. TF_RETURN_IF_ERROR(CheckCommandBufferState( command_buffer, se::CommandBuffer::State::kUpdate)); @@ -404,7 +417,7 @@ absl::Status CommandBufferCmdSequence::RecordUpdate( absl::Status CommandBufferCmdSequence::CheckCommandBufferState( se::CommandBuffer* command_buffer, - se::CommandBuffer::State expected_state) { + se::CommandBuffer::State expected_state) const { if (command_buffer->state() != expected_state) { return Internal("Command buffer must be in %v state, got %v", expected_state, command_buffer->state()); @@ -918,15 +931,15 @@ CommandBufferCmd::BufferUseVector Memset32Cmd::buffers() { CaseCmd::CaseCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice index, bool index_is_bool, - std::vector branches_commands) + std::vector branches) : CommandBufferCmd(CommandBufferCmdType::kCaseCmd, execution_stream_id), index_(index), index_is_bool_(index_is_bool), - branches_commands_(std::move(branches_commands)) {} + branches_(std::move(branches)) {} absl::Status CaseCmd::Initialize(const Thunk::InitializeParams& params, StateManager& state) { - for (auto& branch : branches_commands_) { + for (auto& branch : branches_) { TF_RETURN_IF_ERROR(branch.Initialize(params, state)); } return absl::OkStatus(); @@ -942,12 +955,11 @@ absl::StatusOr CaseCmd::Record( VLOG(5) << "CaseCmd:"; VLOG(5) << " index: " << index_ << " (" << index.opaque() << ")"; - auto branches = CreateBuilders(absl::MakeSpan(branches_commands_), - &execute_params, &record_params); - return Handle( std::move(record_action), [&](absl::Span dependencies) { + auto branches = CreateBuilders(absl::MakeSpan(branches_), + &execute_params, &record_params); if (index_is_bool_) { return command_buffer->CreateCase(se::DeviceMemory(index), std::move(branches), dependencies); @@ -960,24 +972,26 @@ absl::StatusOr CaseCmd::Record( [&](const se::CommandBuffer::Command* command) { if (index_is_bool_) { return command_buffer->UpdateCase( - command, se::DeviceMemory(index), std::move(branches)); + command, 
se::DeviceMemory(index), + UpdateCommands(branches_, &execute_params, &record_params)); } else { return command_buffer->UpdateCase( - command, se::DeviceMemory(index), std::move(branches)); + command, se::DeviceMemory(index), + UpdateCommands(branches_, &execute_params, &record_params)); } }); } bool CaseCmd::force_update() { - return absl::c_any_of(branches_commands_, + return absl::c_any_of(branches_, [](const auto& seq) { return seq.force_update(); }); } CommandBufferCmd::BufferUseVector CaseCmd::buffers() { absl::flat_hash_set buffers; buffers.emplace(index_, MemoryAccess::kRead); - for (auto& branch : branches_commands_) { + for (auto& branch : branches_) { buffers.insert(branch.buffers().begin(), branch.buffers().end()); } return {buffers.begin(), buffers.end()}; @@ -1013,21 +1027,20 @@ absl::StatusOr WhileCmd::Record( << " body_commands=" << body_commands_.size(); VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")"; - auto cond = CreateExecutionScopeBuilder(&cond_commands_, &execute_params, - &record_params); - auto body = CreateBuilder(&body_commands_, &execute_params, &record_params); - return Handle( std::move(record_action), [&](absl::Span dependencies) { - return command_buffer->CreateWhile(se::DeviceMemory(pred), - std::move(cond), std::move(body), - dependencies); + return command_buffer->CreateWhile( + se::DeviceMemory(pred), + CreateBuilder(&cond_commands_, &execute_params, &record_params), + CreateBuilder(&body_commands_, &execute_params, &record_params), + dependencies); }, [&](const se::CommandBuffer::Command* command) { - return command_buffer->UpdateWhile(command, - se::DeviceMemory(pred), - std::move(cond), std::move(body)); + return command_buffer->UpdateWhile( + command, se::DeviceMemory(pred), + UpdateCommands(&cond_commands_, &execute_params, &record_params), + UpdateCommands(&body_commands_, &execute_params, &record_params)); }); } diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index 4f0a333a3e1a75..c415a789008384 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -394,7 +394,7 @@ class CommandBufferCmdSequence { // in update state. absl::Status RecordUpdate(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, - se::CommandBuffer* command_buffer); + se::CommandBuffer* command_buffer) const; // Returns buffers referenced by commands in this sequence. const absl::flat_hash_set& buffers() const; @@ -424,8 +424,9 @@ class CommandBufferCmdSequence { SynchronizationMode synchronization_mode, std::vector> commands); - absl::Status CheckCommandBufferState(se::CommandBuffer* command_buffer, - se::CommandBuffer::State expected_state); + absl::Status CheckCommandBufferState( + se::CommandBuffer* command_buffer, + se::CommandBuffer::State expected_state) const; // Returns true if command has no dependencies. 
bool IsSource(CommandId id) const; @@ -662,8 +663,7 @@ class Memset32Cmd : public CommandBufferCmd { class CaseCmd : public CommandBufferCmd { public: CaseCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice index, - bool index_is_bool, - std::vector branches_commands); + bool index_is_bool, std::vector branches); absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; @@ -680,7 +680,7 @@ class CaseCmd : public CommandBufferCmd { private: BufferAllocation::Slice index_; bool index_is_bool_; - std::vector branches_commands_; + std::vector branches_; }; //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index fa98781c27733b..e09633a92725ac 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -68,6 +68,13 @@ class CommandBuffer { // Builder constructs nested command buffers owned by a parent command buffer. using Builder = std::function; + // A callback to update a nested `command_buffer` owned by a conditional + // command. At command buffer update time we can't change the dependency + // structure of the previously created commands, and can only update the + // parameters of the commands (i.e. device memory pointers). + using UpdateCommands = + absl::AnyInvocable; + CommandBuffer() = default; virtual ~CommandBuffer() = default; @@ -220,13 +227,13 @@ class CommandBuffer { absl::Span dependencies) = 0; // Updates a Case command. - virtual absl::Status UpdateCase(const Command* command, - DeviceMemory index, - std::vector branches) = 0; + virtual absl::Status UpdateCase( + const Command* command, DeviceMemory index, + std::vector update_branches) = 0; - virtual absl::Status UpdateCase(const Command* command, - DeviceMemory index, - std::vector branches) = 0; + virtual absl::Status UpdateCase( + const Command* command, DeviceMemory index, + std::vector update_branches) = 0; // Creates a conditional operation that will execute a command buffer // constructed by the `cond_builder` that must update `pred` value, and then @@ -248,8 +255,8 @@ class CommandBuffer { // Updates a While command. virtual absl::Status UpdateWhile(const Command* command, DeviceMemory pred, - Builder cond_builder, - Builder body_builder) = 0; + UpdateCommands update_cond, + UpdateCommands update_body) = 0; // Submits the command buffer for execution. virtual absl::Status Submit(Stream* stream) { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 8b81658ee401df..3b28c8c943a221 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -410,10 +410,9 @@ absl::StatusOr GpuCommandBuffer::CreateCase( return AppendCommand(std::move(command)); } -absl::Status GpuCommandBuffer::UpdateCase(const Command* command, - DeviceMemory index, - bool index_is_bool, - std::vector branches) { +absl::Status GpuCommandBuffer::UpdateCase( + const Command* command, DeviceMemory index, bool index_is_bool, + std::vector update_branches) { TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); constexpr size_t kBranchBatchSize = 8; @@ -423,8 +422,8 @@ absl::Status GpuCommandBuffer::UpdateCase(const Command* command, // Update branch conditionals. 
size_t batch_index = 0; int32_t batch_offset = 0; - while (batch_offset < branches.size()) { - int32_t remaining_branches = branches.size() - batch_offset; + while (batch_offset < update_branches.size()) { + int32_t remaining_branches = update_branches.size() - batch_offset; int32_t batch_size; bool enable_conditional_default; if (remaining_branches <= kBranchBatchSize) { @@ -451,7 +450,7 @@ absl::Status GpuCommandBuffer::UpdateCase(const Command* command, gpu_command->conditional_nodes[i].command_buffer.get(); auto scoped_update_mode = ActivateUpdateMode(case_command_buffer); TF_RETURN_IF_ERROR(case_command_buffer->Update()); - TF_RETURN_IF_ERROR(branches[i](case_command_buffer)); + TF_RETURN_IF_ERROR(update_branches[i](case_command_buffer)); TF_RETURN_IF_ERROR(case_command_buffer->Finalize()); } @@ -474,22 +473,22 @@ absl::StatusOr GpuCommandBuffer::CreateCase( /*index_is_bool=*/true, branches, dependencies); } -absl::Status GpuCommandBuffer::UpdateCase(const Command* command, - DeviceMemory index, - std::vector branches) { +absl::Status GpuCommandBuffer::UpdateCase( + const Command* command, DeviceMemory index, + std::vector update_branches) { return UpdateCase( command, DeviceMemory::MakeFromByteSize(index.opaque(), index.size()), - /*index_is_bool=*/false, branches); + /*index_is_bool=*/false, std::move(update_branches)); } -absl::Status GpuCommandBuffer::UpdateCase(const Command* command, - DeviceMemory index, - std::vector branches) { +absl::Status GpuCommandBuffer::UpdateCase( + const Command* command, DeviceMemory index, + std::vector update_branches) { return UpdateCase( command, DeviceMemory::MakeFromByteSize(index.opaque(), index.size()), - /*index_is_bool=*/true, branches); + /*index_is_bool=*/true, std::move(update_branches)); } absl::StatusOr GpuCommandBuffer::CreateWhile( @@ -531,13 +530,13 @@ absl::StatusOr GpuCommandBuffer::CreateWhile( absl::Status GpuCommandBuffer::UpdateWhile(const Command* command, DeviceMemory pred, - Builder cond_builder, - Builder body_builder) { + UpdateCommands update_cond, + UpdateCommands update_body) { TF_RETURN_IF_ERROR(CheckInState(State::kUpdate)); auto* gpu_command = tsl::down_cast(command); - TF_RETURN_IF_ERROR(cond_builder(this)); + TF_RETURN_IF_ERROR(update_cond(this)); TF_RETURN_IF_ERROR(UpdateSetWhileConditionNode( gpu_command->set_init_condition_node, gpu_command->conditional, pred)); @@ -547,8 +546,8 @@ absl::Status GpuCommandBuffer::UpdateWhile(const Command* command, // Update command buffer using user-provided builder callback. 
TF_RETURN_IF_ERROR(body->Update()); - TF_RETURN_IF_ERROR(body_builder(body)); - TF_RETURN_IF_ERROR(cond_builder(body)); + TF_RETURN_IF_ERROR(update_body(body)); + TF_RETURN_IF_ERROR(update_cond(body)); TF_RETURN_IF_ERROR(body->UpdateSetWhileConditionNode( gpu_command->set_body_condition_node, gpu_command->conditional, pred)); TF_RETURN_IF_ERROR(body->Finalize()); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 13c379cb057943..8989327eef298e 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -176,17 +176,18 @@ class GpuCommandBuffer : public CommandBuffer { absl::Span dependencies) override; absl::Status UpdateCase(const Command* command, DeviceMemory index, - std::vector branches) override; + std::vector update_branches) override; absl::Status UpdateCase(const Command* command, DeviceMemory index, - std::vector branches) override; + std::vector update_branches) override; absl::StatusOr CreateWhile( DeviceMemory pred, Builder cond_builder, Builder body_builder, absl::Span dependencies) override; absl::Status UpdateWhile(const Command* command, DeviceMemory pred, - Builder cond_builder, Builder body_builder) override; + UpdateCommands update_cond, + UpdateCommands update_body) override; absl::Status Finalize() override; absl::Status Update() override; @@ -285,7 +286,8 @@ class GpuCommandBuffer : public CommandBuffer { absl::Span dependencies); absl::Status UpdateCase(const Command* command, DeviceMemory index, - bool index_is_bool, std::vector branches); + bool index_is_bool, + std::vector update_branches); // Appends a new command to the command buffer. template From 9b5993cc5c7691c9714e9a5ce48de96869aede09 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 14 Apr 2025 13:00:31 -0700 Subject: [PATCH 0727/1324] PR #25135: Treat arrays and tuples equally in Literal construction. Imported from GitHub PR https://github.com/openxla/xla/pull/25135 Currently, arrays without a layout are treated as having an implicit default layout (see the first added test, which passes at HEAD). For tuples, this does not apply, so it is possible to have Literals whose shape does not have a layout. After this change, Literals should always have a set layout. This should hopefully make landing https://github.com/openxla/xla/pull/24744/files easier. I'm sending this separately because the other PR affects a lot of tests that I have no access to. I'm guessing this change solves some of the issues we're seeing there. Copybara import of the project: -- 21c8e354484357cb47ef7c90db978639dc371107 by Johannes Reifferscheid : Treat arrays and tuples equally in Literal construction. Currently, arrays without a layout are treated as having an implicit default layout (see the first added test, which passes at HEAD). For tuples, this does not apply, so it is possible to have Literals whose shape does not have a layout. After this change, Literals should always have a set layout. This should hopefully make landing https://github.com/openxla/xla/pull/24744/files easier. I'm sending this separately because the other PR affects a lot of tests that I have no access to. I'm guessing this change solves some of the issues we're seeing there. 
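
For illustration, condensed from the new tests (all helpers shown here appear
in literal_test.cc; the final CHECK is just illustrative):

    Shape array = ShapeUtil::MakeShape(F32, {2, 1});
    array.clear_layout();
    Shape tuple = ShapeUtil::MakeTupleShape({array});

    // Before this change, only a bare array shape got the default layout
    // back; a layout-less leaf inside a tuple stayed layout-less. Now both
    // cases always end up with a layout.
    auto literal =
        LiteralBase::CreateFromShapeWithUndeterminedLeafArrays(tuple);
    CHECK(LayoutUtil::HasLayout(literal.shape()));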
Merging this change closes #25135 PiperOrigin-RevId: 747534947 --- third_party/xla/xla/literal.cc | 10 +++++++--- third_party/xla/xla/literal_test.cc | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index f37c390738859d..610e6a1453f0a0 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -257,9 +257,13 @@ void Literal::SetShape(const Shape& shape) { return; } auto owning_shape_ptr = std::make_unique(shape); - if (owning_shape_ptr->IsArray() && !owning_shape_ptr->has_layout()) { - *owning_shape_ptr->mutable_layout() = - LayoutUtil::GetDefaultLayoutForShape(*owning_shape_ptr); + if (!LayoutUtil::HasLayout(*owning_shape_ptr)) { + ShapeUtil::ForEachMutableLeafShape( + owning_shape_ptr.get(), [](Shape* subshape, const ShapeIndex& index) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); } if (owning_shape_ptr->IsArray() && LayoutUtil::HasCustomElementSizeInBits(*owning_shape_ptr)) { diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index 19b0eeb0b1949b..b2c8af1a94d100 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -551,6 +551,27 @@ TEST_F(LiteralUtilTest, DifferentLayoutInEquality) { EXPECT_FALSE(colmajor.Equal(rowmajor, true)); } +TEST_F(LiteralUtilTest, CreateWithoutLayout) { + Shape default_layout_shape = ShapeUtil::MakeShape(F32, {2, 1}); + Shape no_layout_shape = default_layout_shape; + no_layout_shape.clear_layout(); + auto literal = + LiteralBase::CreateFromShapeWithUndeterminedLeafArrays(no_layout_shape); + // The default Layout should have been added back. + EXPECT_EQ(literal.shape(), default_layout_shape); +} + +TEST_F(LiteralUtilTest, CreateWithoutLayout_Tuple) { + Shape default_layout_shape = ShapeUtil::MakeShape(F32, {2, 1}); + Shape no_layout_shape = default_layout_shape; + no_layout_shape.clear_layout(); + Shape literal_shape = ShapeUtil::MakeTupleShape({no_layout_shape}); + auto literal = + LiteralBase::CreateFromShapeWithUndeterminedLeafArrays(literal_shape); + // The default Layout should have been added back. + EXPECT_EQ(literal.shape().tuple_shapes(0), default_layout_shape); +} + TEST_F(LiteralUtilTest, TupleEquality) { // Test equality with tuples. auto scalar = LiteralUtil::CreateR0(1.0); From 1db2214f0d7d89d998009826c0add7fcf927b7f6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 13:51:12 -0700 Subject: [PATCH 0728/1324] Add support for DT_FLOAT8_E4M3B11FNUZ to tensor summary. 
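
Tensor::SummarizeValue now widens each float8_e4m3b11fnuz element to float
via a PrintOneElement overload, mirroring the existing DT_FLOAT8_E4M3FN
handling.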
PiperOrigin-RevId: 747554607 --- tensorflow/core/framework/tensor.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 73bbc554ea1b05..e04c3e2f12eb72 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -1268,6 +1268,10 @@ inline float PrintOneElement(float8_e4m3fn f, bool print_v2) { return static_cast(f); } +inline float PrintOneElement(float8_e4m3b11fnuz f, bool print_v2) { + return static_cast(f); +} + inline int16_t PrintOneElement(int4 a, bool print_v2) { return static_cast(a); } @@ -1454,6 +1458,9 @@ string Tensor::SummarizeValue(int64_t max_entries, bool print_v2) const { case DT_FLOAT8_E4M3FN: return SummarizeArray(limit, num_elts, shape_, data, print_v2); + case DT_FLOAT8_E4M3B11FNUZ: + return SummarizeArray(limit, num_elts, shape_, data, + print_v2); case DT_FLOAT: return SummarizeArray(limit, num_elts, shape_, data, print_v2); break; From c62fcc7a7f15967f920b7c0561f972b2532defe2 Mon Sep 17 00:00:00 2001 From: Vamsi Manchala Date: Mon, 14 Apr 2025 13:52:27 -0700 Subject: [PATCH 0729/1324] Propagate volatility attribute when creating QConst ops during quantization * This ensures that output of a qconst will have the same volatility as the source quantize op. * This will help with subsequently identify and remove unnecessary QConst->Dequantize ops after quantization. PiperOrigin-RevId: 747555208 --- tensorflow/compiler/mlir/lite/BUILD | 1 + .../mlir/lite/tests/post-quantize.mlir | 14 +- .../mlir/lite/tests/quantize-strict.mlir | 13 +- .../mlir/lite/transforms/post_quantize.cc | 137 ++++++++++++++++++ .../compiler/mlir/lite/transforms/quantize.cc | 8 +- 5 files changed, 170 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index fbd280f063032e..a666df4cab651a 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1293,6 +1293,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir index 005aec23403c7f..8971ca0d6d3788 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir @@ -188,9 +188,21 @@ func.func @FoldPerAxisReshape() -> tensor<1x2x2x!quant.uniform>, value = dense<[[-127, 127], [-85, -80]]> : tensor<2x2xi8>}> : () -> tensor<2x2x!quant.uniform> %1 = "tfl.reshape"(%0, %cst) : (tensor<2x2x!quant.uniform>, tensor<3xi32>) -> tensor<1x2x2x!quant.uniform> return %1 : tensor<1x2x2x!quant.uniform> - + // CHECK{LITERAL}: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x2x!quant.uniform>, value = dense<[[[-127, 127], [-85, -80]]]> : tensor<1x2x2xi8>}> : () -> tensor<1x2x2x!quant.uniform> // CHECK-NOT: tfl.reshape // CHECK: return %0 : tensor<1x2x2x!quant.uniform> } + +// CHECK-LABEL: RemoveVolatileQConstOps +func.func @RemoveVolatileQConstOps() -> tensor<640xf32> { + %1 = "tfl.pseudo_qconst"() <{qtype = tensor<640x!quant.uniform>, value = dense<0> : tensor<640xi32>}> {volatile} : () -> tensor<640x!quant.uniform> + %2 = 
"tfl.dequantize"(%1) : (tensor<640x!quant.uniform>) -> tensor<640xf32> + func.return %2 : tensor<640xf32> + // CHECK: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<640x!quant.uniform>, value = dense<0> : tensor<640xi32>}> {volatile} : () -> tensor<640x!quant.uniform> + // CHECK: return %0 : tensor<640x!quant.uniform> + + // QDQ-CHECK: %cst = arith.constant dense<0.000000e+00> : tensor<640xf32> + // QDQ-CHECK: return %cst : tensor<640xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir index 4240ea65988461..43984be64310e3 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir @@ -73,7 +73,7 @@ func.func @QuantizeConvWithBiasAndReluWeightOnly(%arg0: tensor<1x4x4x3xf32>) -> func.func @QuantizeConvWithBiasAndReluSRQ(%arg0: tensor<1x4x4x3xf32>) -> (tensor<1x4x4x1xf32>) { %cst = arith.constant dense<1.14751196> : tensor<1xf32> - %0 = "tfl.quantize"(%cst) <{qtype = tensor<1x!quant.uniform>}> {volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> + %0 = "tfl.quantize"(%cst) <{qtype = tensor<1x!quant.uniform>}> : (tensor<1xf32>) -> tensor<1x!quant.uniform> %1 = "tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> %cst_0 = arith.constant dense<[[[[1.76285899, -0.257785767, 0.20429258], [1.16310906, 0.23124367, 0.529797196]], [[0.348971426, -0.319283515, -0.772461354], [0.316666812, 1.88180697, -1.78054631]]]]> : tensor<1x2x2x3xf32> %2 = "tfl.quantize"(%arg0) <{qtype = tensor<1x4x4x3x!quant.uniform>}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3x!quant.uniform> @@ -105,3 +105,14 @@ func.func @DQQToRequantize(%arg0: tensor<1x128x128x320x!quant.uniform> } +// ----- + +func.func @VolatileQuantizeConst() -> (tensor<1xf32>) { + %cst = arith.constant dense<1.14751196> : tensor<1xf32> + %0 = "tfl.quantize"(%cst) <{qtype = tensor<1x!quant.uniform>}> {volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> + %1 = "tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> + return %1 : tensor<1xf32> +// CHECK: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<1x!quant.uniform>, value = dense<20578> : tensor<1xi32>}> {volatile} : () -> tensor<1x!quant.uniform> +// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> +// CHECK: return %1 : tensor<1xf32> +} diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 914d426f278d66..65e5368b7faf96 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -15,13 +15,21 @@ limitations under the License. // This transformation pass applies some clean up steps after quantization. 
+#include 
+#include 
 #include 
 #include 
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
@@ -31,6 +39,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h"
 #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h"
 
 //===----------------------------------------------------------------------===//
 // The post-quantize Passes.
@@ -155,6 +164,92 @@ enum RemoveVolatileOpsType {
   kPreserveInputsAndOutputs,
 };
 
+// Returns a constant tensor with the given scalar/vector value and shape.
+template <typename T>
+std::optional<Value> GetConstTensor(PatternRewriter& rewriter, Location loc,
+                                    llvm::ArrayRef<T> vec,
+                                    llvm::ArrayRef<int64_t> shape) {
+  int64_t num_total_elements = 1;
+  for (int64_t a : shape) {
+    num_total_elements *= a;
+  }
+
+  if (vec.size() != num_total_elements) {
+    return std::nullopt;
+  }
+
+  auto const_type = tensorflow::GetTypeFromTFTensorShape(
+      shape, rewriter.getIntegerType(sizeof(T) * 8));
+  auto const_attr = DenseElementsAttr::get(const_type, vec);
+
+  auto const_op =
+      rewriter.create<arith::ConstantOp>(loc, const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Converts a dequantize op to (scale * (input - zeropoint)). The expectation
+// is that the qconst value will be constant folded to retain the original
+// constant value. This is essentially a constant fold of the dequantize op,
+// provided that the value, zero point, and scale are all constants.
+std::optional ConvertDequantizeOp( + PatternRewriter& rewriter, mlir::Operation* op, + mlir::ShapedType output_type, mlir::Value input_value, + llvm::ArrayRef scale, llvm::ArrayRef zeropoint, + int64_t dim) { + RankedTensorType input_type = + dyn_cast(input_value.getType()); + if (!input_type) return std::nullopt; + + std::optional zp_val; + if (zeropoint.size() == 1) { + auto const_type = + tensorflow::GetTypeFromTFTensorShape({}, rewriter.getF32Type()); + auto const_attr = + DenseElementsAttr::get(const_type, static_cast(zeropoint[0])); + + auto const_op = rewriter.create(op->getLoc(), const_type, + const_attr); + zp_val = const_op.getResult(); + } else { + SmallVector shape; + shape.resize(input_type.getRank(), 1); + shape[dim] = zeropoint.size(); + zp_val = GetConstTensor(rewriter, op->getLoc(), zeropoint, shape); + } + + std::optional scale_val; + if (scale.size() == 1) { + auto const_type = + tensorflow::GetTypeFromTFTensorShape({}, rewriter.getF32Type()); + auto const_attr = + DenseElementsAttr::get(const_type, static_cast(scale[0])); + + auto const_op = rewriter.create(op->getLoc(), const_type, + const_attr); + scale_val = const_op.getResult(); + } else { + SmallVector shape; + shape.resize(input_type.getRank(), 1); + shape[dim] = scale.size(); + scale_val = GetConstTensor(rewriter, op->getLoc(), scale, shape); + } + + if (!zp_val || !scale_val) return std::nullopt; + + auto op1_cast_in = + rewriter.create(op->getLoc(), output_type, input_value); + + auto op2_sub_op1 = rewriter.create( + op->getLoc(), output_type, op1_cast_in.getResult(), zp_val.value(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + + return rewriter + .create( + op->getLoc(), output_type, op2_sub_op1.getResult(), scale_val.value(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE")) + .getResult(); +} + // Remove the back-to-back quantize and dequantize ops with volatile attribute. 
template struct RemoveVolatileOps : public OpRewritePattern { @@ -188,6 +283,48 @@ struct RemoveVolatileOps : public OpRewritePattern { op.replaceAllUsesWith(q.getInput()); return success(); + } else if (auto qconst_op = llvm::dyn_cast_or_null(input_op)) { + if (!qconst_op->getAttr(mlir::quant::kVolatileOpAttrName)) + return failure(); + + auto qtype = + quant::QuantizedType::getQuantizedElementType(qconst_op.getType()); + if (!qtype) return failure(); + SmallVector scale; + SmallVector zeropoint; + int64_t dim = 0; + + if (auto uniform_qtype = + mlir::dyn_cast(qtype)) { + scale.push_back(uniform_qtype.getScale()); + zeropoint.push_back(uniform_qtype.getZeroPoint()); + } else if (auto per_axis_qtype = + mlir::dyn_cast( + qtype)) { + scale.assign(per_axis_qtype.getScales().begin(), + per_axis_qtype.getScales().end()); + zeropoint.assign(per_axis_qtype.getZeroPoints().begin(), + per_axis_qtype.getZeroPoints().end()); + dim = per_axis_qtype.getQuantizedDimension(); + } else { + return failure(); + } + + auto output_type = mlir::cast(op.getOutput().getType()); + + auto const_type = tensorflow::GetTypeFromTFTensorShape( + output_type.getShape(), qtype.getStorageType()); + auto const_op = rewriter.create( + op->getLoc(), const_type, qconst_op.getValue()); + + auto new_value = + ConvertDequantizeOp(rewriter, op, output_type, const_op.getResult(), + scale, zeropoint, dim); + if (!new_value) return failure(); + + op.replaceAllUsesWith(new_value.value()); + op->erase(); + return success(); } return failure(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index ae1674b5862986..be0c1803543140 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -582,7 +582,13 @@ class QuantizeConstPattern : public OpRewritePattern { quantized_attr = quant::Quantize(attr, qtype.getValue()); } if (quantized_attr) { - rewriter.replaceOpWithNewOp(op, qtype, quantized_attr); + auto qconst_op = + rewriter.create(op.getLoc(), qtype, quantized_attr); + if (auto volatile_attr = op->getAttr(quant::kVolatileOpAttrName)) { + qconst_op->setAttr(quant::kVolatileOpAttrName, volatile_attr); + } + op.replaceAllUsesWith(qconst_op.getOutput()); + rewriter.eraseOp(op); return success(); } } From c2662d08013ee35a87982f62ecfb1c51292749aa Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 14 Apr 2025 14:13:31 -0700 Subject: [PATCH 0730/1324] Removed `CoordinationServiceInterface`. Previously, there was a pure virtual `CoordinationServiceInterface` interface with a single subclass called `CoordinationServiceStandaloneImpl`. I think the idea was that there would be multiple implementations of the `CoordinationServiceInterface`, but there are not, and the extra layer of indirection is a nuisance. This CL renames `CoordinationServiceStandaloneImpl` to just `CoordinationService` and replaces `CoordinationServiceInterface` with it. 
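
Schematically (signatures abbreviated; see the diff for the actual API):

    // Before: an interface with exactly one subclass.
    class CoordinationServiceInterface { /* pure virtual methods */ };
    class CoordinationServiceStandaloneImpl
        : public CoordinationServiceInterface { /* ... */ };

    // After: a single concrete class, constructed directly.
    std::unique_ptr<CoordinationService> service =
        CoordinationService::Create(env, config, /*cache=*/nullptr);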
PiperOrigin-RevId: 747563412 --- .../xla/xla/pjrt/distributed/service.cc | 6 +- .../xla/xla/pjrt/distributed/service.h | 2 +- .../distributed_runtime/coordination/BUILD | 3 +- .../coordination/client_server_test.cc | 11 +- .../coordination/coordination_service.cc | 677 ++++-------------- .../coordination/coordination_service.h | 499 +++++++++++-- .../coordination/coordination_service_agent.h | 4 +- ...ordination_service_recoverable_job_test.cc | 6 +- .../coordination_service_rpc_handler.cc | 3 +- .../coordination_service_rpc_handler.h | 4 +- .../coordination/coordination_service_test.cc | 78 +- .../preemption_sync_manager_test.cc | 8 +- .../grpc_coordination_service_impl.h | 2 +- 13 files changed, 626 insertions(+), 677 deletions(-) diff --git a/third_party/xla/xla/pjrt/distributed/service.cc b/third_party/xla/xla/pjrt/distributed/service.cc index 8d66d9b37d47d9..9103c154d95ab8 100644 --- a/third_party/xla/xla/pjrt/distributed/service.cc +++ b/third_party/xla/xla/pjrt/distributed/service.cc @@ -34,7 +34,7 @@ limitations under the License. namespace { -std::unique_ptr EnableCoordinationService( +std::unique_ptr EnableCoordinationService( const xla::CoordinationServiceImpl::Options& options) { const std::string job_name = "jax_worker"; tensorflow::CoordinationServiceConfig config; @@ -51,8 +51,8 @@ std::unique_ptr EnableCoordinationService( config.mutable_coordinated_job_list()->Add(); job->set_name(job_name); job->set_num_tasks(options.num_nodes); - auto service = tsl::CoordinationServiceInterface::EnableCoordinationService( - options.env, config, /*cache=*/nullptr); + auto service = + tsl::CoordinationService::Create(options.env, config, /*cache=*/nullptr); return service; } } // namespace diff --git a/third_party/xla/xla/pjrt/distributed/service.h b/third_party/xla/xla/pjrt/distributed/service.h index d1e3279a7a4a3d..8d1142feffa9f5 100644 --- a/third_party/xla/xla/pjrt/distributed/service.h +++ b/third_party/xla/xla/pjrt/distributed/service.h @@ -76,7 +76,7 @@ class CoordinationServiceImpl { private: tsl::Env* env_ = nullptr; // Not owned. 
- std::unique_ptr coord_service_; + std::unique_ptr coord_service_; std::unique_ptr coord_compute_pool_; std::unique_ptr coord_rpc_service_; std::unique_ptr coord_rpc_thread_; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index df5838794b63b9..58f8bcbc9928e6 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -57,16 +57,15 @@ cc_library( ":coordination_service_error_util", "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/platform:env", - "//xla/tsl/platform:macros", "//xla/tsl/platform:status", "//xla/tsl/protobuf:coordination_config_proto_cc", "//xla/tsl/protobuf:coordination_service_proto_cc", "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc index 0e33414822cd84..af6897e028282d 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc @@ -117,8 +117,8 @@ class ClientServerTest : public ::testing::Test { config.mutable_coordinated_job_list()->Add(); job->set_name("agent"); job->set_num_tasks(num_nodes); - auto service = tsl::CoordinationServiceInterface::EnableCoordinationService( - Env::Default(), config, /*cache=*/nullptr); + auto service = tsl::CoordinationService::Create(Env::Default(), config, + /*cache=*/nullptr); return config; } @@ -162,9 +162,8 @@ class ClientServerTest : public ::testing::Test { grpc::InsecureServerCredentials()); // Set up the actual coordination service (where all the real logic // lives). - coord_service_ = - tsl::CoordinationServiceInterface::EnableCoordinationService( - Env::Default(), config, /*cache=*/nullptr); + coord_service_ = tsl::CoordinationService::Create(Env::Default(), config, + /*cache=*/nullptr); // Set up threads and RPC service. coord_compute_pool_ = std::make_unique( Env::Default(), "CoordinationServiceRpcHandler", @@ -211,7 +210,7 @@ class ClientServerTest : public ::testing::Test { private: std::string service_address_; std::unique_ptr server_; - std::unique_ptr coord_service_; + std::unique_ptr coord_service_; std::unique_ptr coord_compute_pool_; std::unique_ptr coord_rpc_service_; std::unique_ptr coord_rpc_thread_; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index 445040905d84d0..04224e68d1d7bb 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -24,16 +24,12 @@ limitations under the License. 
#include #include #include -#include #include #include #include "absl/algorithm/container.h" -#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/functional/bind_front.h" -#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -53,7 +49,6 @@ limitations under the License. #include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/tsl/protobuf/coordination_service.pb.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/platform/random.h" namespace tsl { namespace { @@ -77,7 +72,7 @@ constexpr char kHealthCheckThread[] = "CoordinationServiceHealthCheck"; constexpr int kPendingStragglerLogLimit = 3; constexpr int kUniqueBarrierCounter = 0; -std::string GetTaskName(std::string_view job_name, int task_id) { +std::string GetTaskName(absl::string_view job_name, int task_id) { return absl::StrCat("/job:", job_name, "/replica:", 0, "/task:", task_id); } @@ -85,7 +80,7 @@ std::string GetTaskName(const CoordinatedTask& task) { return GetTaskName(task.job_name(), task.task_id()); } -CoordinatedTask GetTaskFromName(std::string_view task_name) { +CoordinatedTask GetTaskFromName(absl::string_view task_name) { DeviceNameUtils::ParsedName parsed; DeviceNameUtils::ParseFullName(task_name, &parsed); CoordinatedTask task; @@ -94,426 +89,14 @@ CoordinatedTask GetTaskFromName(std::string_view task_name) { return task; } -// Convenience structs to allow using CoordinatedTask as container keys. -struct CoordinatedTaskHash { - uint64_t operator()(const CoordinatedTask& task) const { - return absl::HashOf(task.job_name(), task.task_id()); - } -}; -struct CoordinatedTaskEqual { - bool operator()(const CoordinatedTask& lhs, - const CoordinatedTask& rhs) const { - return lhs.job_name() == rhs.job_name() && lhs.task_id() == rhs.task_id(); - } -}; - -using CoordinatedTaskSet = - absl::flat_hash_set; - absl::Status MakeShutdownBarrierError(const absl::Status& error) { return MakeCoordinationError(absl::InternalError(absl::StrCat( "Shutdown barrier has failed.\nBarrier result: '", error.ToString()))); } -// Standalone implementation of the coordination service. 
-class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { - public: - CoordinationServiceStandaloneImpl( - Env* env, const CoordinationServiceConfig& config, - std::unique_ptr client_cache); - ~CoordinationServiceStandaloneImpl() override { - absl::MutexLock lock(&state_mu_); - Stop(); - } - - void SetDeviceAggregationFunction( - std::function - post_aggregate_device_fn) override; - - void LogConnectStatusLocked() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - - absl::Status RegisterTask(const CoordinatedTask& task, - uint64_t incarnation) override; - void RegisterTaskAsync(const CoordinatedTask& task, uint64_t incarnation, - StatusCallback done) override; - void WaitForAllTasks(const CoordinatedTask& task, const DeviceInfo& devices, - StatusCallback done) override; - void ShutdownTaskAsync(const CoordinatedTask& task, - StatusCallback done) override; - absl::Status ResetTask(const CoordinatedTask& task) override; - absl::Status RecordHeartbeat(const CoordinatedTask& task, - uint64_t incarnation) override; - absl::Status ReportTaskError(const CoordinatedTask& task, - const absl::Status& error) override; - std::vector GetTaskState( - const std::vector& task) override; - std::vector GetJobState( - absl::string_view job) override; - absl::Status InsertKeyValue(std::string_view key, - std::string_view value) override; - absl::Status InsertKeyValue(std::string_view key, std::string_view value, - bool allow_overwrite) override; - void GetKeyValueAsync(std::string_view key, - StatusOrValueCallback done) override; - absl::StatusOr TryGetKeyValue(std::string_view key) override; - std::vector GetKeyValueDir( - std::string_view directory_key) override; - absl::Status DeleteKeyValue(std::string_view key) override; - void BarrierAsync(std::string barrier_id, int64_t counter, - absl::Duration timeout, const CoordinatedTask& task, - const std::vector& participating_tasks, - BarrierCallback done) override; - absl::Status CancelBarrier(std::string barrier_id, int64_t counter, - const CoordinatedTask& task) override; - void GetAliveTasksAsync(const tensorflow::CoordinatedTask& requesting_task, - const std::vector& tasks, - GetAliveTasksCallback done) override; - void PollForErrorAsync(const CoordinatedTask& task, - StatusCallback done) override; - - private: - const DeviceInfo& ListClusterDevices() override - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - uint64_t GetServiceIncarnation() override; - void BarrierAsyncLocked( - std::string barrier_id, int64_t counter, absl::Duration timeout, - const CoordinatedTask& task, - const std::vector& participating_tasks, - BarrierCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - BarrierCallback ConnectAfterBarrierPasses(absl::string_view task_name, - uint64_t incarnation, - StatusCallback done); - // Connects a task to the service, and leaves any previously ongoing barriers - // for recoverable tasks. - void ConnectTask(const CoordinatedTask& task, uint64_t incarnation) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Checks if any task has stopped sending heartbeats. - void CheckHeartbeatTimeout(); - // Checks if any barrier has timed out. - void CheckBarrierTimeout(); - // Checks both heartbeat and barrier timeouts. Use a single function so they - // can be run in the same thread as threads are a constrained resource. - void CheckStaleness(); - // Starts a thread to check staleness. 
- void StartCheckStaleness(); - void Stop() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - bool ServiceHasStopped() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Report error from a task to all other connected tasks if the task is not - // recoverable. - // Note: SetTaskError() must be called before propagating its error. - void PropagateError(const absl::Status& error, - const std::vector& source_tasks, - bool is_reported_by_task = false) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - void PropagateError(const absl::Status& error, - const std::vector& source_task_names, - bool is_reported_by_task = false) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Checks if all tasks are from recoverable jobs. - bool AllTasksAreRecoverable(const std::vector& tasks) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - void SetTaskError(std::string_view task_name, const absl::Status& error) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Used for cluster-wide errors (e.g. register or shutdown barrier fails). - void SetAllTasksError(const absl::Status& error) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - absl::Status DisconnectTask(const CoordinatedTask& task) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - void DisconnectAllNonRecoverableTasks() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - std::vector GetTasksForShutdownBarrier(); - - struct BarrierState { - std::string id = ""; - // Counter is incremented for each new barrier using the same id. - // No two barriers with the same id (and different counters) can be ongoing - // at the same time. - int64_t counter = 0; - bool passed = false; - absl::Status result = absl::UnknownError( - "Invalid barrier result."); // Only valid if `passed` is true. - uint64_t deadline_in_micros = 0; - int num_pending_tasks = 0; - // Specifies which tasks have called the barrier so far. - absl::flat_hash_map - tasks_at_barrier; - absl::flat_hash_map - done_callbacks; - // Specifies the task that initiated the barrier (the first task to call the - // barrier). - CoordinatedTask initiating_task; - }; - bool BarrierIsUninitialized(const BarrierState& barrier) { - return barrier.id.empty() && barrier.counter == 0 && !barrier.passed && - barrier.deadline_in_micros == 0 && barrier.num_pending_tasks == 0; - } - std::string BarrierName(std::string_view barrier_id, int64_t counter) { - return absl::StrCat(barrier_id, "::", counter); - } - std::string BarrierName(const BarrierState& barrier) { - return BarrierName(barrier.id, barrier.counter); - } - // Initializes a new barrier. Returns false if the barrier should fail - // immediately. - bool InitializeBarrier( - BarrierState* barrier, std::string_view barrier_id, int64_t counter, - absl::Duration timeout, const CoordinatedTask& task, - const std::vector& participating_tasks, - BarrierCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Initialize `BarrierState`'s tasks_at_barrier map. - bool InitializeTasksAtBarrier( - BarrierState* barrier, - const std::vector& participating_tasks, - BarrierCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Adds a callback to be called when the barrier is done. - // If there is an existing callback for that task, it will be overwritten, - // cancelling the previous callback. - void AddBarrierCallback(BarrierState* barrier, const CoordinatedTask& task, - BarrierCallback done) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Ends the barrier with a result (ok or error). 
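// Annotation: BarrierName() above simply joins id and counter, and the counter
// serializes reuse of an id (illustrative, not part of the patch):
//
//   BarrierName("sync", 0);  // "sync::0"
//   BarrierName("sync", 1);  // "sync::1" -- the next barrier reusing the same
//                            // id; it cannot be ongoing while "sync::0" is.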
- void PassBarrier(BarrierState* barrier, const absl::Status& result) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // A task reaches the barrier. - void ReachBarrier(BarrierState* barrier, const CoordinatedTask& task) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - void FailBarrierWithCounterMismatch(BarrierState* barrier, int64_t counter) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Propagates same result back to task. - void RepeatBarrierResult(BarrierState* barrier, const CoordinatedTask& task) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Leaves any ongoing barriers. - // If the task is non-recoverable, the barrier exits with an error. - // If the task is recoverable, the barrier will 'unregister' a task and allow - // it to join back again later before the timeout. - void LeaveOngoingBarriers(const CoordinatedTask& task, - std::string_view reason) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Post-barrier hook to connect all tasks. - void ConnectAllTasks() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Post-barrier hook to aggregate device info. - void AggregateClusterDevices() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Post-shutdown barrier hook to disconnect tasks that acked and propagate - // errors to those that have not. - void CompleteShutdownAfterBarrier(const absl::Status& result, - BarrierState* barrier) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Checks if the participating tasks are specified correctly across barrier - // calls and that the caller task is one of the participating tasks. - bool ValidateTaskArgs(BarrierState* barrier, - const CoordinatedTask& caller_task, - const std::vector& tasks_args) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - bool isRecoverableJob(std::string_view task_name) const; - // Sends responses to error polling requests when an error is encountered. - void SendErrorPollingResponse(const absl::Status& error) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Responds to error polling or fails all tasks when an error is - // encountered. Should only be called when there is no service to client - // connection. - void SendErrorPollingResponseOrFailAllTasks(const absl::Status& error) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Returns whether the clients are polling for error from the service. If the - // clients are not polling for error from the service, the service should stop - // when there is an error. Otherwise, the service should not stop. - bool IsClientPollingForError() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - - class ErrorPollingState { - public: - // Returns whether the error polling requests have been responded. - bool Responded() const { return responded_; } - // Sets the error and executes the status callbacks. - void SetError(const absl::Status& error); - // Gets the error that is propagated to the agents. - const absl::Status& GetError() const { return error_; } - // Returns true if the task has sent request to poll for error from the - // service. - bool IsTaskPolling(absl::string_view task_name) const { - return polling_task_names_.contains(task_name); - } - // Adds a task to the error polling state. - void AddTask(const CoordinatedTask& task, StatusCallback&& done); - - // Removes a task from the error polling state. - // If an existing polling request is present, we will invoke the callback - // with the `reason` argument. - // Note: for disconnected tasks, this does not actually propagate the error - // back, but prevents memory leaks by removing stale callbacks. 
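// Annotation: an illustrative sequence for the ErrorPollingState API above
// (hypothetical driver, not part of the patch):
//
//   ErrorPollingState state;
//   state.AddTask(task_a, done_a);            // task_a starts polling
//   state.SetError(absl::AbortedError("x"));  // invokes done_a with the error
//   state.Responded();                        // now true; later AddTask()
//                                             // calls are ignored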
- void RemoveTask(const CoordinatedTask& task, absl::string_view reason); - - private: - bool responded_ = false; - absl::Status error_ = absl::OkStatus(); - absl::flat_hash_map - done_callbacks_; - absl::flat_hash_set polling_task_names_; - }; - - class TaskState { - public: - // Task state maintained on the coordination service side. - // State transition: - // Register Heartbeat - // DISCONNECTED -------> CONNECTED --------> ERROR (timeout) - // | ReportError - // +--------------> ERROR - // - // When task state becomes ERROR, propagate this status to other CONNECTED - // tasks in the cluster. - - explicit TaskState(absl::string_view task) { task_name_ = task; } - - CoordinatedTaskState GetState() const { return state_; } - absl::Status GetStatus() const { return status_; } - bool IsRecoverable() const { return recoverable_; } - void SetRecoverable(bool recoverable) { recoverable_ = recoverable; } - uint64_t GetTaskIncarnation() const { return task_incarnation_; } - void SetTaskIncarnation(uint64_t task_incarnation) { - task_incarnation_ = task_incarnation; - } - void Connect() { - SetConnected(task_incarnation_); - LOG(INFO) << task_name_ - << " has connected to coordination service. Incarnation: " - << task_incarnation_; - } - void SetConnected(uint64_t task_incarnation); - void Disconnect(uint64_t grace_period_duration_us); - absl::Status RecordHeartbeat(uint64_t task_incarnation); - int64_t TimeSinceLastHeartbeatMs(); - // Sets the error and returns true if the task state is not ERROR. - // Otherwise, don't overwrite the error and return false. - bool SetError(const absl::Status& status); - DeviceInfo GetDeviceInfo() { return devices_; } - void CollectDeviceInfo(const DeviceInfo& devices) { devices_ = devices; } - // Checks if task has called WaitForAllTasks() previously, which gathers the - // local device info. - bool DeviceInfoIsCollected() { return devices_.device_size() != 0; } - - // This is used to propagate state changes (disconnect, error) to ongoing - // barriers. - absl::flat_hash_set GetOngoingBarriers(); - // The task has a new ongoing barrier. This does not mean that it has - // reached the barrier. - void JoinBarrier(std::string_view barrier_id); - // The task has exited a barrier (because a barrier has passed). - void ExitBarrier(std::string_view barrier_id); - // Returns true if the task has been disconnected beyond the grace period - // and no further agent requests are expected. Note that the grace period - // accounts for the lag time between the service recording the state change - // and the agent stopping heartbeats/error polling. - bool IsDisconnectedBeyondGracePeriod(); - - private: - std::string task_name_; - // Incarnation ID for CPU:0 on remote task. - uint64_t task_incarnation_ = 0; - - CoordinatedTaskState state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED; - absl::Status status_; - absl::Mutex last_heartbeat_mu_; - uint64_t last_heartbeat_us_ ABSL_GUARDED_BY(last_heartbeat_mu_); - // This denotes the deadline after which we stop accepting heartbeats or - // error polling requests from a disconnected task. This grace period - // accounts for the lag time between the service recording the state change - // and the agent stopping heartbeats/error polling. - uint64_t disconnect_grace_period_us_ = 0; - DeviceInfo devices_; - // For now, we assume there won't be many simultaneous barriers so we simply - // use a set. - absl::flat_hash_set ongoing_barriers_for_task_; - // TODO(b/342448688): Re-use config's recoverable jobs instead. 
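// Annotation: a state walk matching the transition diagram above (hypothetical
// driver, not part of the patch):
//
//   TaskState t("/job:worker/replica:0/task:0");
//   t.SetTaskIncarnation(7);
//   t.Connect();                               // DISCONNECTED -> CONNECTED
//   t.RecordHeartbeat(7);                      // ok while incarnation matches
//   t.SetError(absl::InternalError("oops"));   // CONNECTED -> ERROR, true
//   t.SetError(absl::InternalError("again"));  // false: ERROR is sticky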
- bool recoverable_ = false; - }; - - // AlivenessState tracks the state of pending GetAliveTasks calls. - struct AlivenessState { - // All tasks that can participate in the GetAliveTasks barrier. - CoordinatedTaskSet tasks; - // All tasks currently blocked on the barrier. - CoordinatedTaskSet in_barrier; - // Done callbacks for the tasks blocked on the barrier. - std::vector dones; - }; - - // Returns the set of alive tasks drawn from the provided set of tasks. - CoordinatedTaskSet AliveTasks(const CoordinatedTaskSet& tasks) const - ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - - // Refreshes the AlivenessStates of all pending GetAliveTasks call, - // potentially finishing some of the pending calls. The AlivenessStates should - // be refreshed, for example, after a task has failed. - void RefreshAliveness() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - - static CoordinatedTaskStateInfo CreateTaskStateInfo( - const CoordinatedTask& task, const TaskState& state); - - std::unique_ptr client_cache_; - Env& env_; - const uint64_t service_incarnation_ = random::New64(); - const uint64_t heartbeat_timeout_ms_; - bool cluster_register_with_barrier_ = false; - const absl::Duration cluster_register_timeout_; - const absl::Duration shutdown_barrier_timeout_; - // If a task restarts with a new incarnation, we may allow it to reconnect - // silently if configured. This is useful when we know that a task can - // immediately resume work upon re-connecting to the service. - bool allow_new_incarnation_to_reconnect_ = false; - std::function - post_aggregate_device_fn_; - - const std::string device_propagation_barrier_id_ = - absl::StrCat("WaitForAllTasks::", std::to_string(service_incarnation_)); - const std::string shutdown_barrier_id_ = - absl::StrCat("Shutdown::", std::to_string(service_incarnation_)); - std::vector shutdown_barrier_tasks_ - ABSL_GUARDED_BY(state_mu_); - - absl::Mutex state_mu_; - absl::flat_hash_map> cluster_state_ - ABSL_GUARDED_BY(state_mu_); - DeviceInfo cluster_devices_ ABSL_GUARDED_BY(state_mu_); - - absl::Mutex kv_mu_; - // Ordered map to store config key-values - std::map kv_store_ ABSL_GUARDED_BY(kv_mu_); - absl::flat_hash_map> get_cb_ - ABSL_GUARDED_BY(kv_mu_); - - absl::flat_hash_map barriers_ - ABSL_GUARDED_BY(state_mu_); - // For now, we assume there won't be many simultaneous barriers so we simply - // use a set. - absl::flat_hash_set ongoing_barriers_ ABSL_GUARDED_BY(state_mu_); - - // The state of all pending GetAliveTasks calls. - std::vector aliveness_states_ ABSL_GUARDED_BY(state_mu_); - - absl::flat_hash_set recoverable_jobs_; - - // Whether the agents are polling for error from the service. It will be set - // to true when the service sees the first error polling request. Once set to - // true, the value will never change back to false. - bool client_polling_for_error_ ABSL_GUARDED_BY(state_mu_) = false; - ErrorPollingState error_polling_state_ ABSL_GUARDED_BY(state_mu_); - - absl::CondVar check_staleness_thread_cv_; - bool shutting_down_ ABSL_GUARDED_BY(state_mu_) = false; - // Note: sequence matters here, we must destroy the staleness thread before - // the other state related to barriers and heartbeats to prevent illegal - // memory access. 
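// Annotation: the ordering note above relies on C++ destroying members in
// reverse declaration order; declaring the thread last means it is joined
// first. A generic sketch of the idiom (not part of the patch):
//
//   struct S {
//     absl::flat_hash_map<std::string, int> state_;  // destroyed second
//     std::unique_ptr<Thread> thread_;  // destroyed first, joining the loop
//   };                                  // before state_ goes away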
- std::unique_ptr check_staleness_thread_; - - CoordinationServiceStandaloneImpl(const CoordinationServiceStandaloneImpl&) = - delete; - void operator=(const CoordinationServiceStandaloneImpl&) = delete; -}; +} // namespace -void CoordinationServiceStandaloneImpl::ErrorPollingState::SetError( +void CoordinationService::ErrorPollingState::SetError( const absl::Status& error) { if (responded_) return; responded_ = true; @@ -524,7 +107,7 @@ void CoordinationServiceStandaloneImpl::ErrorPollingState::SetError( done_callbacks_.clear(); } -void CoordinationServiceStandaloneImpl::ErrorPollingState::RemoveTask( +void CoordinationService::ErrorPollingState::RemoveTask( const CoordinatedTask& task, absl::string_view reason) { if (done_callbacks_.contains(task)) { done_callbacks_[task](MakeCoordinationError(absl::CancelledError( @@ -533,7 +116,7 @@ void CoordinationServiceStandaloneImpl::ErrorPollingState::RemoveTask( done_callbacks_.erase(task); } -void CoordinationServiceStandaloneImpl::ErrorPollingState::AddTask( +void CoordinationService::ErrorPollingState::AddTask( const CoordinatedTask& task, StatusCallback&& done) { // Do not allow to insert a task if the service has already responded. if (Responded()) return; @@ -542,8 +125,7 @@ void CoordinationServiceStandaloneImpl::ErrorPollingState::AddTask( done_callbacks_[task] = done; } -void CoordinationServiceStandaloneImpl::TaskState::SetConnected( - uint64_t task_incarnation) { +void CoordinationService::TaskState::SetConnected(uint64_t task_incarnation) { state_ = CoordinatedTaskState::TASKSTATE_CONNECTED; status_ = absl::OkStatus(); task_incarnation_ = task_incarnation; @@ -551,7 +133,7 @@ void CoordinationServiceStandaloneImpl::TaskState::SetConnected( last_heartbeat_us_ = Env::Default()->NowMicros(); } -void CoordinationServiceStandaloneImpl::TaskState::Disconnect( +void CoordinationService::TaskState::Disconnect( uint64_t grace_period_duration_us) { disconnect_grace_period_us_ = Env::Default()->NowMicros() + grace_period_duration_us; @@ -559,15 +141,14 @@ void CoordinationServiceStandaloneImpl::TaskState::Disconnect( status_ = absl::OkStatus(); } -bool CoordinationServiceStandaloneImpl::TaskState::SetError( - const absl::Status& status) { +bool CoordinationService::TaskState::SetError(const absl::Status& status) { if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return false; state_ = CoordinatedTaskState::TASKSTATE_ERROR; status_ = status; return true; } -absl::Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat( +absl::Status CoordinationService::TaskState::RecordHeartbeat( uint64_t task_incarnation) { if (!status_.ok()) return status_; // Record heartbeat. 
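// Annotation: the incarnation check is what detects silent restarts. An
// illustrative sequence against the TaskState methods above (not part of the
// patch):
//
//   state.SetConnected(/*task_incarnation=*/7);
//   state.RecordHeartbeat(7);  // accepted, refreshes the heartbeat timestamp
//   state.RecordHeartbeat(8);  // rejected: a new incarnation means the task
//                              // restarted since it registered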
@@ -589,40 +170,36 @@ absl::Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat( } } -int64_t -CoordinationServiceStandaloneImpl::TaskState::TimeSinceLastHeartbeatMs() { +int64_t CoordinationService::TaskState::TimeSinceLastHeartbeatMs() { absl::MutexLock l(&last_heartbeat_mu_); return (Env::Default()->NowMicros() - last_heartbeat_us_) / 1000; } absl::flat_hash_set -CoordinationServiceStandaloneImpl::TaskState::GetOngoingBarriers() { +CoordinationService::TaskState::GetOngoingBarriers() { return ongoing_barriers_for_task_; } -void CoordinationServiceStandaloneImpl::TaskState::JoinBarrier( - std::string_view barrier_id) { +void CoordinationService::TaskState::JoinBarrier(absl::string_view barrier_id) { ongoing_barriers_for_task_.emplace(barrier_id); } -void CoordinationServiceStandaloneImpl::TaskState::ExitBarrier( - std::string_view barrier_id) { +void CoordinationService::TaskState::ExitBarrier(absl::string_view barrier_id) { ongoing_barriers_for_task_.erase(barrier_id); } -bool CoordinationServiceStandaloneImpl::TaskState:: - IsDisconnectedBeyondGracePeriod() { +bool CoordinationService::TaskState::IsDisconnectedBeyondGracePeriod() { return GetState() == CoordinatedTaskState::TASKSTATE_DISCONNECTED && Env::Default()->NowMicros() > disconnect_grace_period_us_; } -void CoordinationServiceStandaloneImpl::SetDeviceAggregationFunction( +void CoordinationService::SetDeviceAggregationFunction( std::function post_aggregate_device_fn) { post_aggregate_device_fn_ = std::move(post_aggregate_device_fn); } -CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl( +CoordinationService::CoordinationService( Env* env, const CoordinationServiceConfig& config, std::unique_ptr client_cache) : client_cache_(std::move(client_cache)), @@ -651,9 +228,9 @@ CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl( StartCheckStaleness(); } -void CoordinationServiceStandaloneImpl::CheckHeartbeatTimeout() { +void CoordinationService::CheckHeartbeatTimeout() { absl::Status status = absl::OkStatus(); - std::vector stale_task_names; + std::vector stale_task_names; absl::MutexLock l(&state_mu_); for (const auto& [task_name, task_state] : cluster_state_) { // Skip tasks that are not registered or in error state. @@ -691,12 +268,12 @@ void CoordinationServiceStandaloneImpl::CheckHeartbeatTimeout() { } } -void CoordinationServiceStandaloneImpl::CheckBarrierTimeout() { +void CoordinationService::CheckBarrierTimeout() { absl::flat_hash_map expired_barriers; uint64_t current_time_micros = Env::Default()->NowMicros(); absl::MutexLock l(&state_mu_); // Gather barriers which have timed out. - for (std::string_view barrier_id : ongoing_barriers_) { + for (absl::string_view barrier_id : ongoing_barriers_) { auto* barrier = &barriers_[barrier_id]; if (current_time_micros > barrier->deadline_in_micros) { expired_barriers[barrier_id] = barrier; @@ -734,7 +311,7 @@ void CoordinationServiceStandaloneImpl::CheckBarrierTimeout() { } } -void CoordinationServiceStandaloneImpl::CheckStaleness() { +void CoordinationService::CheckStaleness() { // Used to store stale tasks and barriers. 
while (true) { { @@ -749,14 +326,12 @@ void CoordinationServiceStandaloneImpl::CheckStaleness() { } } -void CoordinationServiceStandaloneImpl::StartCheckStaleness() { - check_staleness_thread_.reset(env_.StartThread( - {}, kHealthCheckThread, - absl::bind_front(&CoordinationServiceStandaloneImpl::CheckStaleness, - this))); +void CoordinationService::StartCheckStaleness() { + check_staleness_thread_.reset( + env_.StartThread({}, kHealthCheckThread, [this]() { CheckStaleness(); })); } -void CoordinationServiceStandaloneImpl::Stop() { +void CoordinationService::Stop() { // Prevent recursion. if (shutting_down_) { return; @@ -801,12 +376,10 @@ void CoordinationServiceStandaloneImpl::Stop() { } } -bool CoordinationServiceStandaloneImpl::ServiceHasStopped() const { - return shutting_down_; -} +bool CoordinationService::ServiceHasStopped() const { return shutting_down_; } // Helper to log progress to having waited for all tasks. -void CoordinationServiceStandaloneImpl::LogConnectStatusLocked() const { +void CoordinationService::LogConnectStatusLocked() const { const int num_tasks = cluster_state_.size(); int pending_tasks = 0; std::vector task_names; @@ -825,8 +398,8 @@ void CoordinationServiceStandaloneImpl::LogConnectStatusLocked() const { } } -absl::Status CoordinationServiceStandaloneImpl::RegisterTask( - const CoordinatedTask& task, uint64_t incarnation) { +absl::Status CoordinationService::RegisterTask(const CoordinatedTask& task, + uint64_t incarnation) { absl::Notification done; absl::Status status; RegisterTaskAsync(task, incarnation, [&](absl::Status s) { @@ -837,9 +410,10 @@ absl::Status CoordinationServiceStandaloneImpl::RegisterTask( return status; } -CoordinationServiceInterface::BarrierCallback -CoordinationServiceStandaloneImpl::ConnectAfterBarrierPasses( - absl::string_view task_name, uint64_t incarnation, StatusCallback done) { +CoordinationService::BarrierCallback +CoordinationService::ConnectAfterBarrierPasses(absl::string_view task_name, + uint64_t incarnation, + StatusCallback done) { return [this, task = std::string(task_name), incarnation, done = std::move(done)](absl::Status s, int64_t unused_counter) mutable { @@ -860,8 +434,8 @@ CoordinationServiceStandaloneImpl::ConnectAfterBarrierPasses( }; } -void CoordinationServiceStandaloneImpl::ConnectTask(const CoordinatedTask& task, - uint64_t incarnation) { +void CoordinationService::ConnectTask(const CoordinatedTask& task, + uint64_t incarnation) { const std::string task_name = GetTaskName(task); const std::unique_ptr& task_state = cluster_state_[task_name]; @@ -872,8 +446,9 @@ void CoordinationServiceStandaloneImpl::ConnectTask(const CoordinatedTask& task, } } -void CoordinationServiceStandaloneImpl::RegisterTaskAsync( - const CoordinatedTask& task, uint64_t incarnation, StatusCallback done) { +void CoordinationService::RegisterTaskAsync(const CoordinatedTask& task, + uint64_t incarnation, + StatusCallback done) { const std::string task_name = GetTaskName(task); std::string error_message; @@ -972,9 +547,9 @@ void CoordinationServiceStandaloneImpl::RegisterTaskAsync( done(error); } -void CoordinationServiceStandaloneImpl::WaitForAllTasks( - const CoordinatedTask& task, const DeviceInfo& devices, - StatusCallback done) { +void CoordinationService::WaitForAllTasks(const CoordinatedTask& task, + const DeviceInfo& devices, + StatusCallback done) { { absl::MutexLock l(&state_mu_); if (ServiceHasStopped()) { @@ -997,8 +572,8 @@ void CoordinationServiceStandaloneImpl::WaitForAllTasks( int64_t unused_counter) { done(s); }); } -void 
CoordinationServiceStandaloneImpl::ShutdownTaskAsync( - const CoordinatedTask& task, StatusCallback done) { +void CoordinationService::ShutdownTaskAsync(const CoordinatedTask& task, + StatusCallback done) { VLOG(3) << "Task " << GetTaskName(task) << " invoked ShutdownTaskAsync()"; if (shutdown_barrier_timeout_ > absl::ZeroDuration() && !task.recoverable()) { // Impose shutdown barrier so that all (non-recoverable) tasks can @@ -1039,14 +614,12 @@ void CoordinationServiceStandaloneImpl::ShutdownTaskAsync( } } -absl::Status CoordinationServiceStandaloneImpl::ResetTask( - const CoordinatedTask& task) { +absl::Status CoordinationService::ResetTask(const CoordinatedTask& task) { absl::MutexLock l(&state_mu_); return DisconnectTask(task); } -absl::Status CoordinationServiceStandaloneImpl::DisconnectTask( - const CoordinatedTask& task) { +absl::Status CoordinationService::DisconnectTask(const CoordinatedTask& task) { const std::string task_name = GetTaskName(task); // Check if task is valid and not already disconnected. if (ServiceHasStopped()) { @@ -1075,16 +648,16 @@ absl::Status CoordinationServiceStandaloneImpl::DisconnectTask( return absl::OkStatus(); } -const DeviceInfo& CoordinationServiceStandaloneImpl::ListClusterDevices() { +const DeviceInfo& CoordinationService::ListClusterDevices() { return cluster_devices_; } -uint64_t CoordinationServiceStandaloneImpl::GetServiceIncarnation() { +uint64_t CoordinationService::GetServiceIncarnation() { return service_incarnation_; } -absl::Status CoordinationServiceStandaloneImpl::ReportTaskError( - const CoordinatedTask& task, const absl::Status& error) { +absl::Status CoordinationService::ReportTaskError(const CoordinatedTask& task, + const absl::Status& error) { const std::string task_name = GetTaskName(task); absl::MutexLock l(&state_mu_); if (ServiceHasStopped()) { @@ -1103,7 +676,7 @@ absl::Status CoordinationServiceStandaloneImpl::ReportTaskError( return absl::OkStatus(); } -CoordinatedTaskStateInfo CoordinationServiceStandaloneImpl::CreateTaskStateInfo( +CoordinatedTaskStateInfo CoordinationService::CreateTaskStateInfo( const CoordinatedTask& task, const TaskState& state) { CoordinatedTaskStateInfo info; info.set_state(state.GetState()); @@ -1119,8 +692,7 @@ CoordinatedTaskStateInfo CoordinationServiceStandaloneImpl::CreateTaskStateInfo( return info; } -std::vector -CoordinationServiceStandaloneImpl::GetTaskState( +std::vector CoordinationService::GetTaskState( const std::vector& tasks) { std::vector states_info; states_info.reserve(tasks.size()); @@ -1133,8 +705,8 @@ CoordinationServiceStandaloneImpl::GetTaskState( return states_info; } -std::vector -CoordinationServiceStandaloneImpl::GetJobState(absl::string_view job_name) { +std::vector CoordinationService::GetJobState( + absl::string_view job_name) { absl::MutexLock l(&state_mu_); std::vector states_info; for (const auto& [name, task_state] : cluster_state_) { @@ -1147,8 +719,8 @@ CoordinationServiceStandaloneImpl::GetJobState(absl::string_view job_name) { return states_info; } -absl::Status CoordinationServiceStandaloneImpl::RecordHeartbeat( - const CoordinatedTask& task, uint64_t incarnation) { +absl::Status CoordinationService::RecordHeartbeat(const CoordinatedTask& task, + uint64_t incarnation) { const std::string task_name = GetTaskName(task); absl::Status s = absl::OkStatus(); absl::MutexLock l(&state_mu_); @@ -1195,7 +767,7 @@ absl::Status CoordinationServiceStandaloneImpl::RecordHeartbeat( return s; } -bool CoordinationServiceStandaloneImpl::AllTasksAreRecoverable( +bool 
CoordinationService::AllTasksAreRecoverable(
     const std::vector& tasks) {
   for (const auto& task : tasks) {
     if (!cluster_state_[GetTaskName(task)]->IsRecoverable() &&
@@ -1206,9 +778,9 @@ bool CoordinationServiceStandaloneImpl::AllTasksAreRecoverable(
   return true;
 }
 
-void CoordinationServiceStandaloneImpl::PropagateError(
+void CoordinationService::PropagateError(
     const absl::Status& error,
-    const std::vector& source_task_names,
+    const std::vector& source_task_names,
     bool is_reported_by_task) {
   std::vector source_tasks;
   source_tasks.reserve(source_task_names.size());
@@ -1218,7 +790,7 @@ void CoordinationServiceStandaloneImpl::PropagateError(
   return PropagateError(error, source_tasks, is_reported_by_task);
 }
 
-void CoordinationServiceStandaloneImpl::PropagateError(
+void CoordinationService::PropagateError(
     const absl::Status& error, const std::vector& source_tasks,
     bool is_reported_by_task) {
   VLOG(3) << "PropagateError(): " << error;
@@ -1280,7 +852,7 @@ void CoordinationServiceStandaloneImpl::PropagateError(
 // The normalized key will not have leading or trailing slashes, and all parts
 // in the key path are separated by exactly one slash ('/').
 // E.g., ///a//b/c// --> a/b/c
-std::string NormalizeKey(std::string_view orig_key) {
+std::string NormalizeKey(absl::string_view orig_key) {
   std::string norm_key = std::string(orig_key);
   const char* src = norm_key.c_str();
   std::string::iterator dst = norm_key.begin();
@@ -1304,13 +876,14 @@ std::string NormalizeKey(std::string_view orig_key) {
   return norm_key;
 }
 
-absl::Status CoordinationServiceStandaloneImpl::InsertKeyValue(
-    std::string_view key, std::string_view value) {
+absl::Status CoordinationService::InsertKeyValue(absl::string_view key,
+                                                 absl::string_view value) {
   return InsertKeyValue(key, value, /*allow_overwrite=*/false);
 }
 
-absl::Status CoordinationServiceStandaloneImpl::InsertKeyValue(
-    std::string_view key, std::string_view value, bool allow_overwrite) {
+absl::Status CoordinationService::InsertKeyValue(absl::string_view key,
+                                                 absl::string_view value,
+                                                 bool allow_overwrite) {
   VLOG(3) << "InsertKeyValue(): " << key << ": " << value
           << " allow_overwrite: " << allow_overwrite;
   const std::string norm_key = NormalizeKey(key);
@@ -1330,8 +903,8 @@ absl::Status CoordinationServiceStandaloneImpl::InsertKeyValue(
   return absl::OkStatus();
 }
 
-void CoordinationServiceStandaloneImpl::GetKeyValueAsync(
-    std::string_view key, StatusOrValueCallback done) {
+void CoordinationService::GetKeyValueAsync(absl::string_view key,
+                                           StatusOrValueCallback done) {
   VLOG(3) << "GetKeyValue(): " << key;
   const std::string norm_key = NormalizeKey(key);
   absl::MutexLock l(&kv_mu_);
@@ -1348,8 +921,8 @@ void CoordinationServiceStandaloneImpl::GetKeyValueAsync(
   cb_iter->second.emplace_back(std::move(done));
 }
 
-absl::StatusOr CoordinationServiceStandaloneImpl::TryGetKeyValue(
-    std::string_view key) {
+absl::StatusOr CoordinationService::TryGetKeyValue(
+    absl::string_view key) {
   VLOG(3) << "TryGetKeyValue(): " << key;
   const std::string norm_key = NormalizeKey(key);
   absl::MutexLock l(&kv_mu_);
@@ -1360,8 +933,8 @@ absl::StatusOr CoordinationServiceStandaloneImpl::TryGetKeyValue(
   return iter->second;
 }
 
-std::vector CoordinationServiceStandaloneImpl::GetKeyValueDir(
-    std::string_view directory_key) {
+std::vector CoordinationService::GetKeyValueDir(
+    absl::string_view directory_key) {
   VLOG(3) << "TryGetKeyValueDir(): " << directory_key;
   std::vector kvs_in_directory;
   const std::string norm_key = NormalizeKey(directory_key);
@@ -1370,9 +943,9 @@ std::vector
CoordinationServiceStandaloneImpl::GetKeyValueDir( absl::MutexLock l(&kv_mu_); // Find first key in ordered map that has the directory prefix. auto begin = kv_store_.lower_bound(dir); - std::map::iterator it; + auto it = begin; // Iterate through key range that match directory prefix. - for (it = begin; it != kv_store_.end(); ++it) { + for (; it != kv_store_.end(); ++it) { // Stop once the next key does not have the directory prefix. Since keys are // ordered, none of the other keys would have a matching prefix. if (std::mismatch(dir.begin(), dir.end(), it->first.begin(), @@ -1389,16 +962,15 @@ std::vector CoordinationServiceStandaloneImpl::GetKeyValueDir( return kvs_in_directory; } -absl::Status CoordinationServiceStandaloneImpl::DeleteKeyValue( - std::string_view key) { +absl::Status CoordinationService::DeleteKeyValue(absl::string_view key) { VLOG(3) << "DeleteKeyValue(): " << key; const std::string norm_key = NormalizeKey(key); absl::MutexLock l(&kv_mu_); // Delete directory: find key range that match directory prefix const std::string dir = absl::StrCat(norm_key, "/"); auto begin = kv_store_.lower_bound(dir); - std::map::iterator end; - for (end = begin; end != kv_store_.end(); end++) { + auto end = begin; + for (; end != kv_store_.end(); end++) { if (std::mismatch(dir.begin(), dir.end(), end->first.begin(), end->first.end()) .first != dir.end()) @@ -1412,15 +984,14 @@ absl::Status CoordinationServiceStandaloneImpl::DeleteKeyValue( return absl::OkStatus(); } -void CoordinationServiceStandaloneImpl::SetAllTasksError( - const absl::Status& error) { +void CoordinationService::SetAllTasksError(const absl::Status& error) { for (const auto& task_state : cluster_state_) { SetTaskError(task_state.first, error); } } -void CoordinationServiceStandaloneImpl::SetTaskError( - std::string_view task_name, const absl::Status& error) { +void CoordinationService::SetTaskError(absl::string_view task_name, + const absl::Status& error) { const CoordinatedTask task = GetTaskFromName(task_name); const std::unique_ptr& task_state = cluster_state_[task_name]; if (task_state->SetError(error)) { @@ -1430,8 +1001,8 @@ void CoordinationServiceStandaloneImpl::SetTaskError( } } -void CoordinationServiceStandaloneImpl::PollForErrorAsync( - const CoordinatedTask& task, StatusCallback done) { +void CoordinationService::PollForErrorAsync(const CoordinatedTask& task, + StatusCallback done) { const std::string task_name = GetTaskName(task); VLOG(3) << "Task " << task_name << " invoked PollForErrorAsync()."; @@ -1491,8 +1062,8 @@ void CoordinationServiceStandaloneImpl::PollForErrorAsync( // Initializes a new barrier. Returns false if the barrier should fail // immediately. -bool CoordinationServiceStandaloneImpl::InitializeBarrier( - BarrierState* barrier, std::string_view barrier_id, int64_t counter, +bool CoordinationService::InitializeBarrier( + BarrierState* barrier, absl::string_view barrier_id, int64_t counter, absl::Duration timeout, const CoordinatedTask& task, const std::vector& participating_tasks, BarrierCallback done) { @@ -1548,7 +1119,7 @@ bool CoordinationServiceStandaloneImpl::InitializeBarrier( return true; } -bool CoordinationServiceStandaloneImpl::InitializeTasksAtBarrier( +bool CoordinationService::InitializeTasksAtBarrier( BarrierState* barrier, const std::vector& participating_tasks, BarrierCallback done) { @@ -1563,7 +1134,7 @@ bool CoordinationServiceStandaloneImpl::InitializeTasksAtBarrier( if (participating_tasks.empty()) { // Assume barrier is for entire cluster if no tasks are specified. 
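      // (Annotation: e.g. BarrierAsync("sync", 0, timeout, task,
      // /*participating_tasks=*/{}, done) enrolls every task currently in
      // cluster_state_, as the loop below shows. Illustrative call only.)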
for (const auto& task_state : cluster_state_) { - std::string_view task_name = task_state.first; + absl::string_view task_name = task_state.first; barrier->tasks_at_barrier[GetTaskFromName(task_name)] = false; } return true; @@ -1590,8 +1161,9 @@ bool CoordinationServiceStandaloneImpl::InitializeTasksAtBarrier( return true; } -void CoordinationServiceStandaloneImpl::AddBarrierCallback( - BarrierState* barrier, const CoordinatedTask& task, BarrierCallback done) { +void CoordinationService::AddBarrierCallback(BarrierState* barrier, + const CoordinatedTask& task, + BarrierCallback done) { auto it = barrier->done_callbacks.find(task); if (it != barrier->done_callbacks.end()) { it->second(absl::CancelledError( @@ -1603,7 +1175,7 @@ void CoordinationServiceStandaloneImpl::AddBarrierCallback( barrier->done_callbacks[task] = std::move(done); } -void CoordinationServiceStandaloneImpl::BarrierAsync( +void CoordinationService::BarrierAsync( // Note: `barrier_id` uses a `std::string` instead of `string_view` as the // RPC may end (i.e. done callback is invoked) before this handler // completes, which would invalidate the `string_view`. @@ -1616,8 +1188,8 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( participating_tasks, std::move(done)); }; -void CoordinationServiceStandaloneImpl::BarrierAsyncLocked( - std::string barrier_id, int64_t counter, absl::Duration timeout, +void CoordinationService::BarrierAsyncLocked( + absl::string_view barrier_id, int64_t counter, absl::Duration timeout, const CoordinatedTask& task, const std::vector& participating_tasks, BarrierCallback done) { @@ -1701,8 +1273,8 @@ void CoordinationServiceStandaloneImpl::BarrierAsyncLocked( ReachBarrier(barrier, task); } -void CoordinationServiceStandaloneImpl::FailBarrierWithCounterMismatch( - BarrierState* barrier, int64_t counter) { +void CoordinationService::FailBarrierWithCounterMismatch(BarrierState* barrier, + int64_t counter) { std::string reason; if (counter == 0 || barrier->counter == 0) { reason = @@ -1726,7 +1298,7 @@ void CoordinationServiceStandaloneImpl::FailBarrierWithCounterMismatch( PassBarrier(barrier, error); } -absl::Status CoordinationServiceStandaloneImpl::CancelBarrier( +absl::Status CoordinationService::CancelBarrier( // Note: `barrier_id` uses a `std::string` instead of `string_view` as the // RPC may end (i.e. done callback is invoked) before this handler // completes, which would invalidate the `string_view`. @@ -1781,8 +1353,8 @@ absl::Status CoordinationServiceStandaloneImpl::CancelBarrier( } // Mark barrier as passed. -void CoordinationServiceStandaloneImpl::PassBarrier( - BarrierState* barrier, const absl::Status& result) { +void CoordinationService::PassBarrier(BarrierState* barrier, + const absl::Status& result) { barrier->passed = true; barrier->result = result; VLOG(3) << "Barrier(" << BarrierName(*barrier) @@ -1828,7 +1400,8 @@ void CoordinationServiceStandaloneImpl::PassBarrier( } // Returns true if x is a (non-strict) subset of y. -bool TaskSetSubset(const CoordinatedTaskSet& x, const CoordinatedTaskSet& y) { +bool TaskSetSubset(const CoordinationService::CoordinatedTaskSet& x, + const CoordinationService::CoordinatedTaskSet& y) { return std::all_of(x.begin(), x.end(), [&y](const CoordinatedTask& task) { return y.contains(task); }); @@ -1840,11 +1413,12 @@ bool TaskSetSubset(const CoordinatedTaskSet& x, const CoordinatedTaskSet& y) { // the equal operator on the underlying elements in the sets, but the equal // operator is not defined on protos. 
Thus, we have to implement our own // equality function. -bool TaskSetEqual(const CoordinatedTaskSet& x, const CoordinatedTaskSet& y) { +bool TaskSetEqual(const CoordinationService::CoordinatedTaskSet& x, + const CoordinationService::CoordinatedTaskSet& y) { return x.size() == y.size() && TaskSetSubset(x, y); } -CoordinatedTaskSet CoordinationServiceStandaloneImpl::AliveTasks( +CoordinationService::CoordinatedTaskSet CoordinationService::AliveTasks( const CoordinatedTaskSet& tasks) const { CoordinatedTaskSet alive_tasks; for (const CoordinatedTask& task : tasks) { @@ -1858,7 +1432,7 @@ CoordinatedTaskSet CoordinationServiceStandaloneImpl::AliveTasks( return alive_tasks; } -void CoordinationServiceStandaloneImpl::RefreshAliveness() { +void CoordinationService::RefreshAliveness() { // Try to finish every pending GetAliveTasks call. auto it = aliveness_states_.begin(); while (it != aliveness_states_.end()) { @@ -1880,7 +1454,7 @@ void CoordinationServiceStandaloneImpl::RefreshAliveness() { } } -void CoordinationServiceStandaloneImpl::GetAliveTasksAsync( +void CoordinationService::GetAliveTasksAsync( const tensorflow::CoordinatedTask& requesting_task, const std::vector& tasks, GetAliveTasksCallback done) { @@ -1924,8 +1498,7 @@ void CoordinationServiceStandaloneImpl::GetAliveTasksAsync( } } -void CoordinationServiceStandaloneImpl::SendErrorPollingResponse( - const absl::Status& error) { +void CoordinationService::SendErrorPollingResponse(const absl::Status& error) { CHECK(IsClientPollingForError()) << "`SendErrorPollingResponse` should only be called after agents poll " "errors from the service."; @@ -1953,7 +1526,7 @@ void CoordinationServiceStandaloneImpl::SendErrorPollingResponse( } } -bool CoordinationServiceStandaloneImpl::ValidateTaskArgs( +bool CoordinationService::ValidateTaskArgs( BarrierState* barrier, const CoordinatedTask& task, const std::vector& tasks_args) { // Assume all tasks are participating if no task is specified. @@ -2002,8 +1575,8 @@ bool CoordinationServiceStandaloneImpl::ValidateTaskArgs( return true; } -void CoordinationServiceStandaloneImpl::RepeatBarrierResult( - BarrierState* barrier, const CoordinatedTask& task) { +void CoordinationService::RepeatBarrierResult(BarrierState* barrier, + const CoordinatedTask& task) { BarrierCallback done = barrier->done_callbacks[task]; barrier->done_callbacks.erase(task); // Special hook for shutdown barrier to disconnect task. @@ -2019,8 +1592,8 @@ void CoordinationServiceStandaloneImpl::RepeatBarrierResult( done(barrier->result, barrier->counter); } -void CoordinationServiceStandaloneImpl::LeaveOngoingBarriers( - const CoordinatedTask& task, std::string_view reason) { +void CoordinationService::LeaveOngoingBarriers(const CoordinatedTask& task, + absl::string_view reason) { const std::string task_name = GetTaskName(task); const std::unique_ptr& task_state = cluster_state_[task_name]; // Unregister recoverable task from ongoing barriers. @@ -2059,8 +1632,8 @@ void CoordinationServiceStandaloneImpl::LeaveOngoingBarriers( } } -void CoordinationServiceStandaloneImpl::ReachBarrier( - BarrierState* barrier, const CoordinatedTask& task) { +void CoordinationService::ReachBarrier(BarrierState* barrier, + const CoordinatedTask& task) { // Remove pending task. // We need to check if task made a repeated call after reaching the // barrier. 
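// Annotation: the two set helpers above compose as "equal = same size +
// subset", since operator== is unavailable on the proto elements. Sketch with
// illustrative tasks A, B, C (not part of the patch):
//
//   CoordinatedTaskSet x = {A, B};
//   CoordinatedTaskSet y = {A, B, C};
//   TaskSetSubset(x, y);  // true: every element of x is in y
//   TaskSetEqual(x, y);   // false: sizes differ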
@@ -2076,7 +1649,7 @@ void CoordinationServiceStandaloneImpl::ReachBarrier( } }; -void CoordinationServiceStandaloneImpl::AggregateClusterDevices() { +void CoordinationService::AggregateClusterDevices() { assert(cluster_devices_.device_size() == 0); std::vector ordered_tasks; // Sort by task name to set deterministic order for cluster devices. @@ -2103,7 +1676,7 @@ void CoordinationServiceStandaloneImpl::AggregateClusterDevices() { } } -void CoordinationServiceStandaloneImpl::DisconnectAllNonRecoverableTasks() { +void CoordinationService::DisconnectAllNonRecoverableTasks() { for (const auto& [task_name, state] : cluster_state_) { if (state->IsRecoverable()) { // Recoverable tasks will disconnect independently without the @@ -2117,8 +1690,7 @@ void CoordinationServiceStandaloneImpl::DisconnectAllNonRecoverableTasks() { } } -std::vector -CoordinationServiceStandaloneImpl::GetTasksForShutdownBarrier() { +std::vector CoordinationService::GetTasksForShutdownBarrier() { absl::MutexLock l(&state_mu_); if (shutdown_barrier_tasks_.empty()) { for (const auto& [task_name, task_state] : cluster_state_) { @@ -2130,7 +1702,7 @@ CoordinationServiceStandaloneImpl::GetTasksForShutdownBarrier() { return shutdown_barrier_tasks_; } -void CoordinationServiceStandaloneImpl::CompleteShutdownAfterBarrier( +void CoordinationService::CompleteShutdownAfterBarrier( const absl::Status& result, BarrierState* barrier) { if (result.ok()) { LOG(INFO) << "Shutdown barrier in coordination service has passed."; @@ -2164,22 +1736,13 @@ void CoordinationServiceStandaloneImpl::CompleteShutdownAfterBarrier( SetAllTasksError(shutdown_error); } } -} // namespace - -std::unique_ptr -CoordinationServiceInterface::EnableCoordinationService( - Env* env, const tensorflow::CoordinationServiceConfig& config, - std::unique_ptr cache) { - return std::make_unique(env, config, - std::move(cache)); -} -bool CoordinationServiceStandaloneImpl::isRecoverableJob( - const std::string_view task_name) const { +bool CoordinationService::isRecoverableJob( + const absl::string_view task_name) const { return recoverable_jobs_.find(task_name) != recoverable_jobs_.end(); } -void CoordinationServiceStandaloneImpl::SendErrorPollingResponseOrFailAllTasks( +void CoordinationService::SendErrorPollingResponseOrFailAllTasks( const absl::Status& error) { CHECK(!error.ok()) << "SendErrorPollingResponseOrFailAllTasks called with OK " "status. Should always return an error."; @@ -2201,7 +1764,7 @@ void CoordinationServiceStandaloneImpl::SendErrorPollingResponseOrFailAllTasks( } } -bool CoordinationServiceStandaloneImpl::IsClientPollingForError() const { +bool CoordinationService::IsClientPollingForError() const { return client_polling_for_error_; } diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h index d7e42d48981301..abd6debaac5039 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h @@ -20,18 +20,27 @@ limitations under the License. 
#include #include #include -#include +#include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" +#include "xla/tsl/platform/env.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/tsl/protobuf/coordination_service.pb.h" +#include "tsl/platform/random.h" namespace tsl { class Env; @@ -52,28 +61,60 @@ class Env; // coordination. One instance of the service should be deployed in a cluster, // handling various requests and stores configuration key-value data for the // tasks. Each task interacts with the service through CoordinationServiceAgent. -class CoordinationServiceInterface { +class CoordinationService { public: using StatusOrValueCallback = - std::function&)>; + std::function&)>; using BarrierCallback = std::function; using GetAliveTasksCallback = std::function&)>; - virtual ~CoordinationServiceInterface() = default; - - static std::unique_ptr - EnableCoordinationService(Env* env, - const tensorflow::CoordinationServiceConfig& config, - std::unique_ptr cache); + // Convenience structs to allow using CoordinatedTask as container keys. + struct CoordinatedTaskHash { + uint64_t operator()(const tensorflow::CoordinatedTask& task) const { + return absl::HashOf(task.job_name(), task.task_id()); + } + }; + struct CoordinatedTaskEqual { + bool operator()(const tensorflow::CoordinatedTask& lhs, + const tensorflow::CoordinatedTask& rhs) const { + return lhs.job_name() == rhs.job_name() && lhs.task_id() == rhs.task_id(); + } + }; + + using CoordinatedTaskSet = + absl::flat_hash_set; + + static std::unique_ptr Create( + Env* env, const tensorflow::CoordinationServiceConfig& config, + std::unique_ptr cache) { + return std::make_unique(env, config, std::move(cache)); + } + + // TODO: b/410607726 - Remove once deprecated EnableCoordinationService is + // unused. + static std::unique_ptr EnableCoordinationService( + Env* env, const tensorflow::CoordinationServiceConfig& config, + std::unique_ptr cache) { + return Create(env, config, std::move(cache)); + } + + CoordinationService(Env* env, + const tensorflow::CoordinationServiceConfig& config, + std::unique_ptr client_cache); + + ~CoordinationService() { + absl::MutexLock lock(&state_mu_); + Stop(); + } // This function is invoked after each task's local devices are appended in a // deterministic order during WaitForAllTasks(). This is useful to convert the // result into another message, or set global device ids. - virtual void SetDeviceAggregationFunction( - std::function< - tensorflow::DeviceInfo(const tensorflow::DeviceInfo& devices)> - post_aggregate_device_fn) = 0; + void SetDeviceAggregationFunction(std::function + post_aggregate_device_fn); // Register a task to the service. // Possible service errors: @@ -82,19 +123,19 @@ class CoordinationServiceInterface { // - Aborted: (1) task is in error state, or (2) task is in connected state // with a different incarnation, indicating that it restarted. // - DeadlineExceeded: waited too long for straggler tasks to register. 
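// (Annotation: concretely, under the contract above -- illustrative calls,
// not part of the patch:
//
//    service->RegisterTask(task, /*incarnation=*/7);  // ok, task CONNECTED
//    service->RegisterTask(task, /*incarnation=*/8);  // Aborted: the same
//                                                     // task re-registering
//                                                     // with a new incarnation
//                                                     // means it restarted.)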
- virtual absl::Status RegisterTask(const tensorflow::CoordinatedTask& task, - uint64_t incarnation) = 0; - virtual void RegisterTaskAsync(const tensorflow::CoordinatedTask& task, - uint64_t incarnation, StatusCallback done) = 0; + absl::Status RegisterTask(const tensorflow::CoordinatedTask& task, + uint64_t incarnation); + void RegisterTaskAsync(const tensorflow::CoordinatedTask& task, + uint64_t incarnation, StatusCallback done); // Wait for all tasks to be up and running, and register local device // info. The callback is invoked when all tasks are up and registered, or some // error occurs. // Each task's local devices will be appended in a deterministic order, and // post-processed by the callback in SetDeviceAggregationFunction() (if set). - virtual void WaitForAllTasks(const tensorflow::CoordinatedTask& task, - const tensorflow::DeviceInfo& devices, - StatusCallback done) = 0; + void WaitForAllTasks(const tensorflow::CoordinatedTask& task, + const tensorflow::DeviceInfo& devices, + StatusCallback done); // Disconnects task from the service. If `shutdown_barrier_timeout_in_ms` is // specified in the config, blocks until all tasks reach the barrier before @@ -103,62 +144,59 @@ class CoordinationServiceInterface { // - Internal: Service has shut down. // - InvalidArgument: Unexpected task request. // - FailedPrecondition: task has already disconnected. - virtual void ShutdownTaskAsync(const tensorflow::CoordinatedTask& task, - StatusCallback done) = 0; + void ShutdownTaskAsync(const tensorflow::CoordinatedTask& task, + StatusCallback done); // Disconnects task from the service and cleans up its internal error state. // Possible service errors: // - Internal: Service has shut down. // - InvalidArgument: Unexpected task request. // - FailedPrecondition: task has already disconnected. - virtual absl::Status ResetTask(const tensorflow::CoordinatedTask& task) = 0; + absl::Status ResetTask(const tensorflow::CoordinatedTask& task); // Update the heartbeat timestamp of a task. This should only be invoked on // the leader of the cluster. // - Internal: Service has shut down. - virtual absl::Status RecordHeartbeat(const tensorflow::CoordinatedTask& task, - uint64_t incarnation) = 0; + absl::Status RecordHeartbeat(const tensorflow::CoordinatedTask& task, + uint64_t incarnation); // Set a task in error state permanently. - virtual absl::Status ReportTaskError(const tensorflow::CoordinatedTask& task, - const absl::Status& error) = 0; + absl::Status ReportTaskError(const tensorflow::CoordinatedTask& task, + const absl::Status& error); // Get the state and the error status of the tasks. - virtual std::vector GetTaskState( - const std::vector& task) = 0; + std::vector GetTaskState( + const std::vector& task); // Gets the state and the error status of the job. - virtual std::vector GetJobState( - absl::string_view job) = 0; + std::vector GetJobState( + absl::string_view job); // Insert a configuration key-value in the coordination service. // For now, a key-value can only be inserted once and cannot be updated. // The key-values are not persisted and will be lost if the leader fails. 
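// (Annotation: a usage sketch of the contract above -- hypothetical keys, not
// part of the patch:
//
//    service->InsertKeyValue("init/addr", "10.0.0.1:8470");  // ok
//    service->InsertKeyValue("init/addr", "other");   // error: no overwrite
//                                                     // by default
//    service->InsertKeyValue("init/addr", "other",
//                            /*allow_overwrite=*/true);       // ok
//    service->TryGetKeyValue("init/addr");            // "other"
//    service->TryGetKeyValue("missing");              // NotFound)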
- virtual absl::Status InsertKeyValue(std::string_view key, - std::string_view value) = 0; - virtual absl::Status InsertKeyValue(std::string_view key, - std::string_view value, - bool allow_overwrite) = 0; + absl::Status InsertKeyValue(absl::string_view key, absl::string_view value); + absl::Status InsertKeyValue(absl::string_view key, absl::string_view value, + bool allow_overwrite); // Get a configuration key-value from the coordination service. The `done` // callback is invoked when the key-value becomes available. - virtual void GetKeyValueAsync(std::string_view key, - StatusOrValueCallback done) = 0; + void GetKeyValueAsync(absl::string_view key, StatusOrValueCallback done); // Get a configuration key-value from the coordination service. If the key // does not exist, return NotFound error. - virtual absl::StatusOr TryGetKeyValue(std::string_view key) = 0; + absl::StatusOr TryGetKeyValue(absl::string_view key); // Gets all values under a directory (key). // A value is considered to be in the directory if its key is prefixed with // the directory. This is not a blocking call. Agent does not need to be // connected to utilize the distributed key-value store. - virtual std::vector GetKeyValueDir( - std::string_view directory_key) = 0; + std::vector GetKeyValueDir( + absl::string_view directory_key); // Delete configuration key-value. If key is a directory, recursively clean // up all key-values under the directory. - virtual absl::Status DeleteKeyValue(std::string_view key) = 0; + absl::Status DeleteKeyValue(absl::string_view key); // Blocks until all (or a subset of) tasks are at the barrier or the barrier // fails. @@ -194,11 +232,11 @@ class CoordinationServiceInterface { // TODO(b/342448688): Allow re-use of ids by specifying different counters. // The counter field is mostly ignored at the moment with no user-facing // effect. - virtual void BarrierAsync( + void BarrierAsync( std::string barrier_id, int64_t counter, absl::Duration timeout, const tensorflow::CoordinatedTask& task, const std::vector& participating_tasks, - BarrierCallback done) = 0; + BarrierCallback done); // Aborts the barrier if it is ongoing. // Current and future WaitAtBarrier() calls with the same id will return a @@ -208,9 +246,8 @@ class CoordinationServiceInterface { // TODO(b/342448688): Allow re-use of ids by specifying different counters. // The counter field is mostly ignored at the moment with no user-facing // effect. - virtual absl::Status CancelBarrier( - std::string barrier_id, int64_t counter, - const tensorflow::CoordinatedTask& task) = 0; + absl::Status CancelBarrier(std::string barrier_id, int64_t counter, + const tensorflow::CoordinatedTask& task); // Returns the set of currently alive tasks. More specifically, given a set of // tasks T, GetAliveTasks(T) returns the subset T of alive tasks. Note that @@ -241,10 +278,9 @@ class CoordinationServiceInterface { // has failed and that every task calls GetAliveTasks([A, B, C, D]). The // invocation will return tasks [A, B, C]. The GetAliveTasks call acts as a // barrier across tasks A, B, and C. Task D, which failed, is ignored. - virtual void GetAliveTasksAsync( - const tensorflow::CoordinatedTask& requesting_task, - const std::vector& tasks, - GetAliveTasksCallback done) = 0; + void GetAliveTasksAsync(const tensorflow::CoordinatedTask& requesting_task, + const std::vector& tasks, + GetAliveTasksCallback done); // Gets error from the coordination service. Block until the service // returns an error or the task/service is shutdown. 
@@ -255,8 +291,8 @@ class CoordinationServiceInterface {
   // coordination service, so once an error occurs after the first call, the
   // service will use the error polling mode to propagate the error to all
   // connected tasks instead of simply shutting down.
-  virtual void PollForErrorAsync(const tensorflow::CoordinatedTask& task,
-                                 StatusCallback done) = 0;
+  void PollForErrorAsync(const tensorflow::CoordinatedTask& task,
+                         StatusCallback done);
 
  private:
  friend class CoordinationServiceRpcHandler;
@@ -265,10 +301,363 @@ class CoordinationServiceInterface {
  friend class CoordinationServiceTest_ListClusterDevices_DevicesAreNotAddedTwice_Test;
 
-  virtual const tensorflow::DeviceInfo& ListClusterDevices() = 0;
-  virtual uint64_t GetServiceIncarnation() = 0;
+  void LogConnectStatusLocked() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+
+  const tensorflow::DeviceInfo& ListClusterDevices()
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  uint64_t GetServiceIncarnation();
+  void BarrierAsyncLocked(
+      absl::string_view barrier_id, int64_t counter, absl::Duration timeout,
+      const tensorflow::CoordinatedTask& task,
+      const std::vector<tensorflow::CoordinatedTask>& participating_tasks,
+      BarrierCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  BarrierCallback ConnectAfterBarrierPasses(absl::string_view task_name,
+                                            uint64_t incarnation,
+                                            StatusCallback done);
+  // Connects a task to the service, and leaves any previously ongoing barriers
+  // for recoverable tasks.
+  void ConnectTask(const tensorflow::CoordinatedTask& task,
+                   uint64_t incarnation)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  // Checks if any task has stopped sending heartbeats.
+  void CheckHeartbeatTimeout();
+  // Checks if any barrier has timed out.
+  void CheckBarrierTimeout();
+  // Checks both heartbeat and barrier timeouts. Use a single function so they
+  // can be run in the same thread as threads are a constrained resource.
+  void CheckStaleness();
+  // Starts a thread to check staleness.
+  void StartCheckStaleness();
+  void Stop() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  bool ServiceHasStopped() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  // Report error from a task to all other connected tasks if the task is not
+  // recoverable.
+  // Note: SetTaskError() must be called before propagating its error.
+  void PropagateError(
+      const absl::Status& error,
+      const std::vector<tensorflow::CoordinatedTask>& source_tasks,
+      bool is_reported_by_task = false)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  void PropagateError(const absl::Status& error,
+                      const std::vector<absl::string_view>& source_task_names,
+                      bool is_reported_by_task = false)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  // Checks if all tasks are from recoverable jobs.
+  bool AllTasksAreRecoverable(
+      const std::vector<tensorflow::CoordinatedTask>& tasks)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  void SetTaskError(absl::string_view task_name, const absl::Status& error)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  // Used for cluster-wide errors (e.g. register or shutdown barrier fails).
+  void SetAllTasksError(const absl::Status& error)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  absl::Status DisconnectTask(const tensorflow::CoordinatedTask& task)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  void DisconnectAllNonRecoverableTasks()
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  std::vector<tensorflow::CoordinatedTask> GetTasksForShutdownBarrier();
+
+  struct BarrierState {
+    std::string id = "";
+    // Counter is incremented for each new barrier using the same id.
+ // No two barriers with the same id (and different counters) can be ongoing + // at the same time. + int64_t counter = 0; + bool passed = false; + absl::Status result = absl::UnknownError( + "Invalid barrier result."); // Only valid if `passed` is true. + uint64_t deadline_in_micros = 0; + int num_pending_tasks = 0; + // Specifies which tasks have called the barrier so far. + absl::flat_hash_map + tasks_at_barrier; + absl::flat_hash_map + done_callbacks; + // Specifies the task that initiated the barrier (the first task to call the + // barrier). + tensorflow::CoordinatedTask initiating_task; + }; + bool BarrierIsUninitialized(const BarrierState& barrier) { + return barrier.id.empty() && barrier.counter == 0 && !barrier.passed && + barrier.deadline_in_micros == 0 && barrier.num_pending_tasks == 0; + } + std::string BarrierName(absl::string_view barrier_id, int64_t counter) { + return absl::StrCat(barrier_id, "::", counter); + } + std::string BarrierName(const BarrierState& barrier) { + return BarrierName(barrier.id, barrier.counter); + } + // Initializes a new barrier. Returns false if the barrier should fail + // immediately. + bool InitializeBarrier( + BarrierState* barrier, absl::string_view barrier_id, int64_t counter, + absl::Duration timeout, const tensorflow::CoordinatedTask& task, + const std::vector& participating_tasks, + BarrierCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Initialize `BarrierState`'s tasks_at_barrier map. + bool InitializeTasksAtBarrier( + BarrierState* barrier, + const std::vector& participating_tasks, + BarrierCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Adds a callback to be called when the barrier is done. + // If there is an existing callback for that task, it will be overwritten, + // cancelling the previous callback. + void AddBarrierCallback(BarrierState* barrier, + const tensorflow::CoordinatedTask& task, + BarrierCallback done) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Ends the barrier with a result (ok or error). + void PassBarrier(BarrierState* barrier, const absl::Status& result) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // A task reaches the barrier. + void ReachBarrier(BarrierState* barrier, + const tensorflow::CoordinatedTask& task) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + void FailBarrierWithCounterMismatch(BarrierState* barrier, int64_t counter) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Propagates same result back to task. + void RepeatBarrierResult(BarrierState* barrier, + const tensorflow::CoordinatedTask& task) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Leaves any ongoing barriers. + // If the task is non-recoverable, the barrier exits with an error. + // If the task is recoverable, the barrier will 'unregister' a task and allow + // it to join back again later before the timeout. + void LeaveOngoingBarriers(const tensorflow::CoordinatedTask& task, + absl::string_view reason) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Post-barrier hook to connect all tasks. + void ConnectAllTasks() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Post-barrier hook to aggregate device info. + void AggregateClusterDevices() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Post-shutdown barrier hook to disconnect tasks that acked and propagate + // errors to those that have not. 
+  void CompleteShutdownAfterBarrier(const absl::Status& result,
+                                    BarrierState* barrier)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  // Checks if the participating tasks are specified correctly across barrier
+  // calls and that the caller task is one of the participating tasks.
+  bool ValidateTaskArgs(
+      BarrierState* barrier, const tensorflow::CoordinatedTask& caller_task,
+      const std::vector<tensorflow::CoordinatedTask>& tasks_args)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  bool isRecoverableJob(absl::string_view task_name) const;
+  // Sends responses to error polling requests when an error is encountered.
+  void SendErrorPollingResponse(const absl::Status& error)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  // Responds to error polling or fails all tasks when an error is
+  // encountered. Should only be called when there is no service to client
+  // connection.
+  void SendErrorPollingResponseOrFailAllTasks(const absl::Status& error)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+  // Returns whether the clients are polling for error from the service. If the
+  // clients are not polling for error from the service, the service should stop
+  // when there is an error. Otherwise, the service should not stop.
+  bool IsClientPollingForError() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+
+  class ErrorPollingState {
+   public:
+    // Returns whether the error polling requests have been responded to.
+    bool Responded() const { return responded_; }
+    // Sets the error and executes the status callbacks.
+    void SetError(const absl::Status& error);
+    // Gets the error that is propagated to the agents.
+    const absl::Status& GetError() const { return error_; }
+    // Returns true if the task has sent a request to poll for error from the
+    // service.
+    bool IsTaskPolling(absl::string_view task_name) const {
+      return polling_task_names_.contains(task_name);
+    }
+    // Adds a task to the error polling state.
+    void AddTask(const tensorflow::CoordinatedTask& task,
+                 StatusCallback&& done);
+
+    // Removes a task from the error polling state.
+    // If an existing polling request is present, we will invoke the callback
+    // with the `reason` argument.
+    // Note: for disconnected tasks, this does not actually propagate the error
+    // back, but prevents memory leaks by removing stale callbacks.
+    void RemoveTask(const tensorflow::CoordinatedTask& task,
+                    absl::string_view reason);
+
+   private:
+    bool responded_ = false;
+    absl::Status error_ = absl::OkStatus();
+    absl::flat_hash_map<tensorflow::CoordinatedTask, StatusCallback,
+                        CoordinatedTaskHash, CoordinatedTaskEqual>
+        done_callbacks_;
+    absl::flat_hash_set<std::string> polling_task_names_;
+  };
+
+  class TaskState {
+   public:
+    // Task state maintained on the coordination service side.
+    // State transition:
+    //                Register           Heartbeat
+    //   DISCONNECTED -------> CONNECTED --------> ERROR (timeout)
+    //                             |   ReportError
+    //                             +-------------> ERROR
+    //
+    // When task state becomes ERROR, propagate this status to other CONNECTED
+    // tasks in the cluster.
+
+    explicit TaskState(absl::string_view task) { task_name_ = task; }
+
+    tensorflow::CoordinatedTaskState GetState() const { return state_; }
+    absl::Status GetStatus() const { return status_; }
+    bool IsRecoverable() const { return recoverable_; }
+    void SetRecoverable(bool recoverable) { recoverable_ = recoverable; }
+    uint64_t GetTaskIncarnation() const { return task_incarnation_; }
+    void SetTaskIncarnation(uint64_t task_incarnation) {
+      task_incarnation_ = task_incarnation;
+    }
+    void Connect() {
+      SetConnected(task_incarnation_);
+      LOG(INFO) << task_name_
+                << " has connected to coordination service.
Incarnation: " + << task_incarnation_; + } + void SetConnected(uint64_t task_incarnation); + void Disconnect(uint64_t grace_period_duration_us); + absl::Status RecordHeartbeat(uint64_t task_incarnation); + int64_t TimeSinceLastHeartbeatMs(); + // Sets the error and returns true if the task state is not ERROR. + // Otherwise, don't overwrite the error and return false. + bool SetError(const absl::Status& status); + tensorflow::DeviceInfo GetDeviceInfo() { return devices_; } + void CollectDeviceInfo(const tensorflow::DeviceInfo& devices) { + devices_ = devices; + } + // Checks if task has called WaitForAllTasks() previously, which gathers the + // local device info. + bool DeviceInfoIsCollected() { return !devices_.device().empty(); } + + // This is used to propagate state changes (disconnect, error) to ongoing + // barriers. + absl::flat_hash_set GetOngoingBarriers(); + // The task has a new ongoing barrier. This does not mean that it has + // reached the barrier. + void JoinBarrier(absl::string_view barrier_id); + // The task has exited a barrier (because a barrier has passed). + void ExitBarrier(absl::string_view barrier_id); + // Returns true if the task has been disconnected beyond the grace period + // and no further agent requests are expected. Note that the grace period + // accounts for the lag time between the service recording the state change + // and the agent stopping heartbeats/error polling. + bool IsDisconnectedBeyondGracePeriod(); + + private: + std::string task_name_; + // Incarnation ID for CPU:0 on remote task. + uint64_t task_incarnation_ = 0; + + tensorflow::CoordinatedTaskState state_ = + tensorflow::CoordinatedTaskState::TASKSTATE_DISCONNECTED; + absl::Status status_; + absl::Mutex last_heartbeat_mu_; + uint64_t last_heartbeat_us_ ABSL_GUARDED_BY(last_heartbeat_mu_); + // This denotes the deadline after which we stop accepting heartbeats or + // error polling requests from a disconnected task. This grace period + // accounts for the lag time between the service recording the state change + // and the agent stopping heartbeats/error polling. + uint64_t disconnect_grace_period_us_ = 0; + tensorflow::DeviceInfo devices_; + // For now, we assume there won't be many simultaneous barriers so we simply + // use a set. + absl::flat_hash_set ongoing_barriers_for_task_; + // TODO(b/342448688): Re-use config's recoverable jobs instead. + bool recoverable_ = false; + }; + + // AlivenessState tracks the state of pending GetAliveTasks calls. + struct AlivenessState { + // All tasks that can participate in the GetAliveTasks barrier. + CoordinatedTaskSet tasks; + // All tasks currently blocked on the barrier. + CoordinatedTaskSet in_barrier; + // Done callbacks for the tasks blocked on the barrier. + std::vector dones; + }; + + // Returns the set of alive tasks drawn from the provided set of tasks. + CoordinatedTaskSet AliveTasks(const CoordinatedTaskSet& tasks) const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + + // Refreshes the AlivenessStates of all pending GetAliveTasks call, + // potentially finishing some of the pending calls. The AlivenessStates should + // be refreshed, for example, after a task has failed. 
+  void RefreshAliveness() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+
+  static tensorflow::CoordinatedTaskStateInfo CreateTaskStateInfo(
+      const tensorflow::CoordinatedTask& task, const TaskState& state);
+
+  std::unique_ptr<CoordinationClientCache> client_cache_;
+  Env& env_;
+  const uint64_t service_incarnation_ = random::New64();
+  const uint64_t heartbeat_timeout_ms_;
+  bool cluster_register_with_barrier_ = false;
+  const absl::Duration cluster_register_timeout_;
+  const absl::Duration shutdown_barrier_timeout_;
+  // If a task restarts with a new incarnation, we may allow it to reconnect
+  // silently if configured. This is useful when we know that a task can
+  // immediately resume work upon re-connecting to the service.
+  bool allow_new_incarnation_to_reconnect_ = false;
+
+  std::function<tensorflow::DeviceInfo(const tensorflow::DeviceInfo&)>
+      post_aggregate_device_fn_;
+
+  const std::string device_propagation_barrier_id_ =
+      absl::StrCat("WaitForAllTasks::", service_incarnation_);
+  const std::string shutdown_barrier_id_ =
+      absl::StrCat("Shutdown::", service_incarnation_);
+  std::vector<tensorflow::CoordinatedTask> shutdown_barrier_tasks_
+      ABSL_GUARDED_BY(state_mu_);
+
+  absl::Mutex state_mu_;
+  absl::flat_hash_map<std::string, std::unique_ptr<TaskState>> cluster_state_
+      ABSL_GUARDED_BY(state_mu_);
+  tensorflow::DeviceInfo cluster_devices_ ABSL_GUARDED_BY(state_mu_);
+
+  absl::Mutex kv_mu_;
+  // Ordered map to store config key-values
+  absl::btree_map<std::string, std::string> kv_store_ ABSL_GUARDED_BY(kv_mu_);
+  absl::flat_hash_map<std::string, std::vector<StatusOrValueCallback>> get_cb_
+      ABSL_GUARDED_BY(kv_mu_);
+
+  absl::flat_hash_map<std::string, BarrierState> barriers_
+      ABSL_GUARDED_BY(state_mu_);
+  // For now, we assume there won't be many simultaneous barriers so we simply
+  // use a set.
+  absl::flat_hash_set<std::string> ongoing_barriers_ ABSL_GUARDED_BY(state_mu_);
+
+  // The state of all pending GetAliveTasks calls.
+  std::vector<AlivenessState> aliveness_states_ ABSL_GUARDED_BY(state_mu_);
+
+  absl::flat_hash_set<std::string> recoverable_jobs_;
+
+  // Whether the agents are polling for error from the service. It will be set
+  // to true when the service sees the first error polling request. Once set to
+  // true, the value will never change back to false.
+  bool client_polling_for_error_ ABSL_GUARDED_BY(state_mu_) = false;
+  ErrorPollingState error_polling_state_ ABSL_GUARDED_BY(state_mu_);
+
+  absl::CondVar check_staleness_thread_cv_;
+  bool shutting_down_ ABSL_GUARDED_BY(state_mu_) = false;
+  // Note: sequence matters here, we must destroy the staleness thread before
+  // the other state related to barriers and heartbeats to prevent illegal
+  // memory access.
+  std::unique_ptr<Thread> check_staleness_thread_;
+
+  CoordinationService(const CoordinationService&) = delete;
+  void operator=(const CoordinationService&) = delete;
 };
 
+// TODO: b/410607726 - Remove once deprecated CoordinationServiceInterface is
+// removed.
+using CoordinationServiceInterface = CoordinationService;
+
 }  // namespace tsl
 
 #endif  // XLA_TSL_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h
index 4b33e65f68ca28..345ed5b1eb2c3a 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h
@@ -42,11 +42,11 @@ class Env;
 
 // CoordinationServiceAgent defines the interface for tasks to communicate with
 // the coordination service instance (which implements
-// CoordinationServiceInterface). One instance of the agent should be deployed
+// CoordinationService). One instance of the agent should be deployed
 // on each task for it to send various requests and store / retrieve config
 // key-value data to and from the service.
 //
-// See CoordinationServiceInterface for more details on coordination service.
+// See CoordinationService for more details on coordination service.
 //
 // All coordination service errors will have an additional
 // CoordinationServiceError payload to distinguish themselves from RPC failures.
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc
index 982bcd5d58a214..dd8d809277ce5f 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc
@@ -109,7 +109,7 @@ class TestCoordinationServiceTaskState {
         [service = coord_rpc_service_.get()]() { service->HandleRPCsLoop(); }));
   }
 
-  void SetCoordinationService(CoordinationServiceInterface* service) {
+  void SetCoordinationService(CoordinationService* service) {
     auto* grpc_coord_service =
        static_cast<GrpcCoordinationServiceImpl*>(coord_rpc_service_.get());
     grpc_coord_service->SetCoordinationServiceInstance(service);
@@ -181,7 +181,7 @@ class CoordinationServiceRecoverableJobTest : public ::testing::Test {
     client_cache->AddTask(
         /*target=*/"/job:worker/replica:0/task:1",
        state_worker_1_.GetCoordinationClient());
-    coord_service_ = CoordinationServiceInterface::EnableCoordinationService(
+    coord_service_ = CoordinationService::Create(
         Env::Default(), coordination_config_, std::move(client_cache));
 
     // Set the service pointer for all the tasks since it is needed for handling
     // error propagations. In reality, every task has its own service pointer.
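For illustration, the factory rename is mechanical at call sites. A minimal sketch of creating the service under the new name, assuming a standalone config like the ones used in the tests below (the job name and task count are illustrative, not from this patch):

    CoordinationServiceConfig config;
    config.set_service_type("standalone");
    CoordinatedJob* job = config.mutable_coordinated_job_list()->Add();
    job->set_name("worker");  // illustrative job name
    job->set_num_tasks(2);    // illustrative task count
    // Previously: CoordinationServiceInterface::EnableCoordinationService(...).
    std::unique_ptr<CoordinationService> coord_service =
        CoordinationService::Create(Env::Default(), config, /*cache=*/nullptr);
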
@@ -224,7 +224,7 @@ class CoordinationServiceRecoverableJobTest : public ::testing::Test {
 
  protected:
   CoordinationServiceConfig coordination_config_;
-  std::unique_ptr<CoordinationServiceInterface> coord_service_;
+  std::unique_ptr<CoordinationService> coord_service_;
   TestCoordinationServiceTaskState state_ps_0_;
   TestCoordinationServiceTaskState state_ps_1_;
   TestCoordinationServiceTaskState state_worker_0_;
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc
index 8f73235a3a2883..0f50e5d33044ac 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc
@@ -48,7 +48,7 @@ void CoordinationServiceRpcHandler::SetAgentInstance(
 }
 
 void CoordinationServiceRpcHandler::SetServiceInstance(
-    CoordinationServiceInterface* service) {
+    CoordinationService* service) {
   absl::MutexLock l(&mu_);
   service_ = service;
 }
@@ -103,6 +103,7 @@ void CoordinationServiceRpcHandler::WaitForAllTasksAsync(
       request->source_task(), request->device_info(),
       [response, service = service_, done = std::move(done)](absl::Status s) {
         if (s.ok()) {
+          service->state_mu_.AssertHeld();
           *response->mutable_device_info() = service->ListClusterDevices();
         }
         done(s);
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h
index 2b365df9b67ca0..5e2986ac8aa951 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h
@@ -31,7 +31,7 @@ class CoordinationServiceRpcHandler {
 
   void SetAgentInstance(CoordinationServiceAgent* agent);
 
-  void SetServiceInstance(CoordinationServiceInterface* service);
+  void SetServiceInstance(CoordinationService* service);
 
   void RegisterTaskAsync(const tensorflow::RegisterTaskRequest* request,
                          tensorflow::RegisterTaskResponse* response,
@@ -107,7 +107,7 @@ class CoordinationServiceRpcHandler {
  private:
   absl::Mutex mu_;
   CoordinationServiceAgent* agent_ TF_GUARDED_BY(mu_) = nullptr;
-  CoordinationServiceInterface* service_ TF_GUARDED_BY(mu_) = nullptr;
+  CoordinationService* service_ TF_GUARDED_BY(mu_) = nullptr;
 };
 
 }  // namespace tsl
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc
index 265caeb9c9823a..9d31d421604e9d 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc
@@ -193,8 +193,8 @@ class CoordinationBarrierTest : public ::testing::Test {
     }
     CoordinationServiceConfig config = GetCoordinationServiceConfig(num_tasks);
 
-    coord_service_ = CoordinationServiceInterface::EnableCoordinationService(
-        Env::Default(), config, std::move(client_cache));
+    coord_service_ = CoordinationService::Create(Env::Default(), config,
+                                                 std::move(client_cache));
 
     // Register the tasks.
    for (int i = 0; i < num_tasks; ++i) {
       absl::Status s =
@@ -206,9 +206,7 @@ class CoordinationBarrierTest : public ::testing::Test {
     }
   }
 
-  CoordinationServiceInterface* GetCoordinationService() {
-    return coord_service_.get();
-  }
+  CoordinationService* GetCoordinationService() { return coord_service_.get(); }
   CoordinatedTask GetTask(int i) const { return tasks_[i]; }
 
   const std::vector<CoordinatedTask>& GetTasks() const { return tasks_; }
@@ -236,7 +234,7 @@ class CoordinationBarrierTest : public ::testing::Test {
   }
 
  private:
-  std::unique_ptr<CoordinationServiceInterface> coord_service_;
+  std::unique_ptr<CoordinationService> coord_service_;
   std::vector<CoordinatedTask> tasks_;
   std::vector<std::unique_ptr<TestCoordinationClient>> clients_;
 };
@@ -285,8 +283,8 @@ class CoordinateTwoTasksTest : public ::testing::Test {
       config.set_allow_new_incarnation_to_reconnect(true);
     }
     // Init service.
-    coord_service_ = CoordinationServiceInterface::EnableCoordinationService(
-        Env::Default(), config, std::move(client_cache));
+    coord_service_ = CoordinationService::Create(Env::Default(), config,
+                                                 std::move(client_cache));
   }
 
   CoordinatedTask task_0_;
@@ -297,7 +295,7 @@ class CoordinateTwoTasksTest : public ::testing::Test {
   const uint64_t incarnation_1_ = random::New64();
   const uint64_t incarnation_1_new_ = random::New64();
   TestCoordinationClient client_1_;
-  std::unique_ptr<CoordinationServiceInterface> coord_service_;
+  std::unique_ptr<CoordinationService> coord_service_;
 };
 
 // Construct fake device protos.
@@ -375,9 +373,9 @@ TEST(CoordinationServiceTest, TestCoordinatedJobs) {
   client_cache->AddTask("/job:worker/replica:0/task:1", &wi1);
   TestCoordinationClient ei;
   client_cache->AddTask("/job:evaluator/replica:0/task:0", &ei);
-  std::unique_ptr<CoordinationServiceInterface> coord_service =
-      CoordinationServiceInterface::EnableCoordinationService(
-          Env::Default(), config, std::move(client_cache));
+  std::unique_ptr<CoordinationService> coord_service =
+      CoordinationService::Create(Env::Default(), config,
+                                  std::move(client_cache));
 
   // Each coordinated task registers and waits for other tasks.
   absl::Notification register_chief;
@@ -419,10 +417,9 @@ TEST(CoordinationServiceTest, RegisterTask_AlreadyConnected_Succeeds) {
   CoordinatedTask task_0;
   task_0.set_job_name("worker");
   task_0.set_task_id(0);
-  std::unique_ptr<CoordinationServiceInterface> coord_service =
-      CoordinationServiceInterface::EnableCoordinationService(
-          Env::Default(), config,
-          /*cache=*/nullptr);
+  std::unique_ptr<CoordinationService> coord_service =
+      CoordinationService::Create(Env::Default(), config,
+                                  /*cache=*/nullptr);
 
   // Task connects to coordination service.
   ASSERT_OK(coord_service->RegisterTask(task_0, /*incarnation=*/0));
@@ -440,10 +437,9 @@ TEST(CoordinationServiceTest,
   CoordinatedTask task_0;
   task_0.set_job_name("worker");
   task_0.set_task_id(0);
-  std::unique_ptr<CoordinationServiceInterface> coord_service =
-      CoordinationServiceInterface::EnableCoordinationService(
-          Env::Default(), config,
-          /*cache=*/nullptr);
+  std::unique_ptr<CoordinationService> coord_service =
+      CoordinationService::Create(Env::Default(), config,
+                                  /*cache=*/nullptr);
 
   // Task connects to coordination service.
   ASSERT_OK(coord_service->RegisterTask(task_0, /*incarnation=*/0));
@@ -462,10 +458,9 @@ TEST(CoordinationServiceTest, RegisterTask_AlreadyInError_Fails) {
   CoordinatedTask task_0;
   task_0.set_job_name("worker");
   task_0.set_task_id(0);
-  std::unique_ptr<CoordinationServiceInterface> coord_service =
-      CoordinationServiceInterface::EnableCoordinationService(
-          Env::Default(), config,
-          /*cache=*/nullptr);
+  std::unique_ptr<CoordinationService> coord_service =
+      CoordinationService::Create(Env::Default(), config,
+                                  /*cache=*/nullptr);
 
   // Task connects to coordination service.
   ASSERT_OK(coord_service->RegisterTask(task_0, /*incarnation=*/0));
   // Arbitrarily set task to be in error.
@@ -782,9 +777,9 @@ TEST(CoordinationServiceTest, TryGetKeyValue) { const CoordinationServiceConfig config = GetCoordinationServiceConfig(/*num_tasks=*/1); auto client_cache = std::make_unique(); - std::unique_ptr coord_service = - CoordinationServiceInterface::EnableCoordinationService( - Env::Default(), config, std::move(client_cache)); + std::unique_ptr coord_service = + CoordinationService::Create(Env::Default(), config, + std::move(client_cache)); // Try to get nonexistent key. absl::StatusOr result = @@ -896,9 +891,9 @@ TEST(CoordinationServiceTest, ListClusterDevices_TfDevice) { task_2.set_task_id(2); absl::Status status = absl::OkStatus(); auto client_cache = std::make_unique(); - std::unique_ptr coord_service = - CoordinationServiceInterface::EnableCoordinationService( - Env::Default(), config, std::move(client_cache)); + std::unique_ptr coord_service = + CoordinationService::Create(Env::Default(), config, + std::move(client_cache)); absl::Notification n; // Map fake devices to each task. DeviceInfo local_devices_0; @@ -922,6 +917,7 @@ TEST(CoordinationServiceTest, ListClusterDevices_TfDevice) { coord_service->WaitForAllTasks(task_2, local_devices_2, [&](absl::Status s) { ASSERT_OK(s); // Gather the cluster device info. + coord_service->state_mu_.AssertHeld(); cluster_devices = coord_service->ListClusterDevices(); n.Notify(); }); @@ -952,9 +948,9 @@ TEST(CoordinationServiceTest, ListClusterDevices_XlaDevice) { task_2.set_task_id(2); absl::Status status = absl::OkStatus(); auto client_cache = std::make_unique(); - std::unique_ptr coord_service = - CoordinationServiceInterface::EnableCoordinationService( - Env::Default(), config, std::move(client_cache)); + std::unique_ptr coord_service = + CoordinationService::Create(Env::Default(), config, + std::move(client_cache)); coord_service->SetDeviceAggregationFunction( [](const DeviceInfo& raw_global_devices) { TestDeviceList global_device_list; @@ -997,6 +993,7 @@ TEST(CoordinationServiceTest, ListClusterDevices_XlaDevice) { coord_service->WaitForAllTasks(task_2, local_devices_2, [&](absl::Status s) { ASSERT_OK(s); // Gather the cluster device info. + coord_service->state_mu_.AssertHeld(); cluster_devices = coord_service->ListClusterDevices(); n.Notify(); }); @@ -1031,9 +1028,9 @@ TEST(CoordinationServiceTest, ListClusterDevices_DevicesAreNotAddedTwice) { absl::Status status = absl::OkStatus(); absl::Status initial_wait_for_all_tasks_status; auto client_cache = std::make_unique(); - std::unique_ptr coord_service = - CoordinationServiceInterface::EnableCoordinationService( - Env::Default(), config, std::move(client_cache)); + std::unique_ptr coord_service = + CoordinationService::Create(Env::Default(), config, + std::move(client_cache)); absl::Notification n; // Map fake devices to each task. DeviceInfo local_devices_0; @@ -1060,6 +1057,7 @@ TEST(CoordinationServiceTest, ListClusterDevices_DevicesAreNotAddedTwice) { &cluster_devices, &n](absl::Status s) { ASSERT_OK(s); // Gather the cluster device info. 
+ coord_service->state_mu_.AssertHeld(); cluster_devices = coord_service->ListClusterDevices(); n.Notify(); @@ -2139,9 +2137,9 @@ TEST(CoordinationServiceTest, RecoverableAndNonRecoverableTasks) { worker_job->set_name("worker"); worker_job->set_num_tasks(2); - std::unique_ptr coord_service = - CoordinationServiceInterface::EnableCoordinationService( - Env::Default(), config, /*cache=*/nullptr); + std::unique_ptr coord_service = + CoordinationService::Create(Env::Default(), config, + /*cache=*/nullptr); // Each coordinated task registers and polls for errors. ASSERT_OK(coord_service->RegisterTask(chief, /*incarnation=*/0)); @@ -2581,7 +2579,7 @@ TEST_F(GetAliveTasksTest, CallingGetAliveTasksWithoutBeingAMember) { finished.DecrementCount(); }; - CoordinationServiceInterface* s = GetCoordinationService(); + CoordinationService* s = GetCoordinationService(); s->GetAliveTasksAsync(GetTask(0), {GetTask(1), GetTask(2)}, done); s->GetAliveTasksAsync(GetTask(1), {GetTask(0), GetTask(2)}, done); s->GetAliveTasksAsync(GetTask(2), {GetTask(0), GetTask(1)}, done); diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc index 8598a4a56e7ef5..2ecefb7dce5f16 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc @@ -143,14 +143,14 @@ class PreemptionSyncManagerTest : public ::testing::Test { /*thread_options=*/{}, /*name=*/"CoordinationServiceHandleRPCsLoop", [service = coord_rpc_service_.get()]() { service->HandleRPCsLoop(); })); } - std::unique_ptr EnableCoordinationService() { + std::unique_ptr EnableCoordinationService() { CoordinationServiceConfig config; config.set_service_type("standalone"); CoordinatedJob* job = config.mutable_coordinated_job_list()->Add(); job->set_name(kJobName); job->set_num_tasks(2); - return CoordinationServiceInterface::EnableCoordinationService( - Env::Default(), config, /*cache=*/nullptr); + return CoordinationService::Create(Env::Default(), config, + /*cache=*/nullptr); } void InitializeAndConnectCoordinationAgents() { std::unique_ptr coord_client = @@ -175,7 +175,7 @@ class PreemptionSyncManagerTest : public ::testing::Test { } // Coordination service. 
- std::unique_ptr coord_service_; + std::unique_ptr coord_service_; std::unique_ptr<::grpc::Server> grpc_server_; std::unique_ptr coord_compute_pool_; std::unique_ptr coord_rpc_service_; diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h index 26b71faf1be2f9..8f02cac64e1862 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h @@ -52,7 +52,7 @@ class GrpcCoordinationServiceImpl : public AsyncServiceInterface { void SetCoordinationServiceAgentInstance(CoordinationServiceAgent* agent) { rpc_handler_.SetAgentInstance(agent); } - void SetCoordinationServiceInstance(CoordinationServiceInterface* service) { + void SetCoordinationServiceInstance(CoordinationService* service) { rpc_handler_.SetServiceInstance(service); } CoordinationServiceRpcHandler* GetRpcHandler() { return &rpc_handler_; } From 3146b8375e4430947406b13284af2f82b84ba1cd Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 14 Apr 2025 14:27:58 -0700 Subject: [PATCH 0731/1324] [xla:gpu] CommandBuffer: introduce CreateCommands callback and explicitly use RecordCreate to create conditional command buffers Remove automatic command dependency tracking and rely on CommandBufferCmdSequence passing correct dependencies to command buffer. Also delete dead code from GpuCommandBuffer. PiperOrigin-RevId: 747568918 --- .../gpu/runtime/command_buffer_cmd.cc | 65 ++++---- .../backends/gpu/runtime/command_buffer_cmd.h | 36 +---- .../xla/xla/stream_executor/command_buffer.h | 18 ++- .../cuda/cuda_command_buffer.cc | 36 +---- .../cuda/cuda_command_buffer.h | 6 - .../stream_executor/gpu/gpu_command_buffer.cc | 151 ++++++++---------- .../stream_executor/gpu/gpu_command_buffer.h | 48 +----- .../gpu/gpu_command_buffer_test.cc | 118 ++++++++------ .../rocm/rocm_command_buffer.cc | 27 ---- .../rocm/rocm_command_buffer.h | 6 - 10 files changed, 187 insertions(+), 324 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 86bcc4ae7a7038..8b22edcf66494a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -115,27 +115,29 @@ static absl::string_view ReductionKindString(ReductionKind kind) { } } -// Creates command buffer builder from a cmd sequence. -static se::CommandBuffer::Builder CreateBuilder( - CommandBufferCmdSequence* commands, +// Create a callback to create a command buffer from a command sequence. +static se::CommandBuffer::CreateCommands CreateCommands( + const CommandBufferCmdSequence* commands, const Thunk::ExecuteParams* execute_params, const CommandBufferCmd::RecordParams* record_params) { - return [=](se::CommandBuffer* command_buffer) { - return commands->Record(*execute_params, *record_params, command_buffer, - CommandBufferCmdSequence::RecordMode::kConditional); + return [=](se::CommandBuffer* command_buffer, + absl::Span dependencies) { + return commands->RecordCreate(*execute_params, *record_params, + command_buffer, dependencies); }; } -// Creates command buffer builders from a span of cmd sequences. -static std::vector CreateBuilders( - absl::Span commands, +// Create callbacks to create a command buffer from command sequences. 
+static std::vector<se::CommandBuffer::CreateCommands> CreateCommands(
+    absl::Span<const CommandBufferCmdSequence> commands,
     const Thunk::ExecuteParams* execute_params,
     const CommandBufferCmd::RecordParams* record_params) {
-  std::vector<se::CommandBuffer::Builder> builders;
-  for (CommandBufferCmdSequence& cmd : commands) {
-    builders.push_back(CreateBuilder(&cmd, execute_params, record_params));
+  std::vector<se::CommandBuffer::CreateCommands> create_commands;
+  for (const CommandBufferCmdSequence& cmd : commands) {
+    create_commands.push_back(
+        CreateCommands(&cmd, execute_params, record_params));
   }
-  return builders;
+  return create_commands;
 }
 
 // Create a callback to update a command buffer with command sequence.
@@ -268,14 +270,11 @@ absl::Status CommandBufferCmdSequence::Record(
 absl::Status CommandBufferCmdSequence::Record(
     const Thunk::ExecuteParams& execute_params,
     const CommandBufferCmd::RecordParams& record_params,
-    se::CommandBuffer* command_buffer, RecordMode mode) {
-  VLOG(3) << "Record " << commands_.size()
-          << " commands into command buffer; mode=" << absl::StrCat(mode);
+    se::CommandBuffer* command_buffer) {
+  VLOG(3) << "Record " << commands_.size() << " commands into command buffer";
 
-  if (mode == RecordMode::kExclusive) {
-    if (command_buffer->state() == se::CommandBuffer::State::kFinalized) {
-      TF_RETURN_IF_ERROR(command_buffer->Update());
-    }
+  if (command_buffer->state() == se::CommandBuffer::State::kFinalized) {
+    TF_RETURN_IF_ERROR(command_buffer->Update());
   }
 
   if (command_buffer->state() == se::CommandBuffer::State::kUpdate) {
@@ -287,18 +286,14 @@ absl::Status CommandBufferCmdSequence::Record(
             .status());
   }
 
-  if (mode == RecordMode::kExclusive) {
-    TF_RETURN_IF_ERROR(command_buffer->Finalize());
-  }
-
-  return absl::OkStatus();
+  return command_buffer->Finalize();
 }
 
 absl::StatusOr<std::vector<const se::CommandBuffer::Command*>>
 CommandBufferCmdSequence::RecordCreate(
     const Thunk::ExecuteParams& execute_params,
     const RecordParams& record_params, se::CommandBuffer* command_buffer,
-    absl::Span<const se::CommandBuffer::Command* const> dependencies) {
+    absl::Span<const se::CommandBuffer::Command* const> dependencies) const {
   // Command buffer must be in create state.
TF_RETURN_IF_ERROR(CheckCommandBufferState( command_buffer, se::CommandBuffer::State::kCreate)); @@ -958,15 +953,17 @@ absl::StatusOr CaseCmd::Record( return Handle( std::move(record_action), [&](absl::Span dependencies) { - auto branches = CreateBuilders(absl::MakeSpan(branches_), - &execute_params, &record_params); if (index_is_bool_) { - return command_buffer->CreateCase(se::DeviceMemory(index), - std::move(branches), dependencies); + return command_buffer->CreateCase( + se::DeviceMemory(index), + CreateCommands(branches_, &execute_params, &record_params), + dependencies); } else { - return command_buffer->CreateCase(se::DeviceMemory(index), - std::move(branches), dependencies); + return command_buffer->CreateCase( + se::DeviceMemory(index), + CreateCommands(branches_, &execute_params, &record_params), + dependencies); } }, [&](const se::CommandBuffer::Command* command) { @@ -1032,8 +1029,8 @@ absl::StatusOr WhileCmd::Record( [&](absl::Span dependencies) { return command_buffer->CreateWhile( se::DeviceMemory(pred), - CreateBuilder(&cond_commands_, &execute_params, &record_params), - CreateBuilder(&body_commands_, &execute_params, &record_params), + CreateCommands(&cond_commands_, &execute_params, &record_params), + CreateCommands(&body_commands_, &execute_params, &record_params), dependencies); }, [&](const se::CommandBuffer::Command* command) { diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index c415a789008384..eab8981d693e3c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -291,33 +291,6 @@ class CommandBufferCmdSequence { using RecordParams = CommandBufferCmd::RecordParams; - enum class RecordMode { - // In exclusive mode no one else is recording commands into the command - // buffer argument, and cmd sequence is responsible for updating command - // buffer state: finalizing after all commands recorded, and - // switching to update state before recording updates. - kExclusive, - - // In conditional mode multiple cmd sequences can be recorded into the - // command buffer argument, and with command buffer state managed externally - // cmd sequence should not finalize or update it. This mode is used when - // command buffer cmd sequence is recorded into conditional command buffers - // owned by the parent command buffer. - kConditional - }; - - template - friend void AbslStringify(Sink& sink, RecordMode mode) { - switch (mode) { - case RecordMode::kExclusive: - sink.Append("exclusive"); - break; - case RecordMode::kConditional: - sink.Append("conditional"); - break; - } - } - // Synchronization mode defines how much concurrency is allowed between // commands in the sequence. enum class SynchronizationMode { @@ -372,14 +345,9 @@ class CommandBufferCmdSequence { // buffer state. This method assumes that no other command buffer sequence is // recorded into the same command buffer, and doesn't set up initial // dependencies for recorded commands. - // - // TODO(b/406370928): This API must be removed, and instead users should - // explicitly call `RecordCreate` or `RecordUpdate` depending on what they - // want to do. absl::Status Record(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, - se::CommandBuffer* command_buffer, - RecordMode mode = RecordMode::kExclusive); + se::CommandBuffer* command_buffer); // Records command creation into the command buffer. 
Command buffer must be
 // in create state. The next command sequence recorded into the same command
@@ -388,7 +356,7 @@ class CommandBufferCmdSequence {
   absl::StatusOr<std::vector<const se::CommandBuffer::Command*>> RecordCreate(
       const Thunk::ExecuteParams& execute_params,
       const RecordParams& record_params, se::CommandBuffer* command_buffer,
-      absl::Span<const se::CommandBuffer::Command* const> dependencies);
+      absl::Span<const se::CommandBuffer::Command* const> dependencies) const;
 
   // Records command updates into the command buffer. Command buffer must be
   // in update state.
diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h
index e09633a92725ac..a6b2861313f137 100644
--- a/third_party/xla/xla/stream_executor/command_buffer.h
+++ b/third_party/xla/xla/stream_executor/command_buffer.h
@@ -18,7 +18,6 @@ limitations under the License.
 
 #include <cstddef>
 #include <cstdint>
-#include <functional>
 #include <memory>
 
 #include "absl/functional/any_invocable.h"
@@ -65,8 +64,14 @@ class CommandBuffer {
     Command& operator=(Command&&) = default;
   };
 
-  // Builder constructs nested command buffers owned by a parent command buffer.
-  using Builder = std::function<absl::Status(CommandBuffer*)>;
+  // A callback to construct a nested `command_buffer` by creating commands in
+  // it. Created commands must execute after `dependencies`, and the callback
+  // must return a vector of commands that will be used as external dependencies
+  // for the next callback recording into the same command buffer.
+  using CreateCommands =
+      absl::AnyInvocable<absl::StatusOr<std::vector<const Command*>>(
+          CommandBuffer* command_buffer,
+          absl::Span<const Command* const> dependencies)>;
 
   // A callback to update a nested `command_buffer` owned by a conditional
   // command. At command buffer update time we can't change the dependency
@@ -219,11 +224,11 @@ class CommandBuffer {
   //
   // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case
   virtual absl::StatusOr<const Command*> CreateCase(
-      DeviceMemory<int32_t> index, std::vector<Builder> branches,
+      DeviceMemory<int32_t> index, std::vector<CreateCommands> create_branches,
       absl::Span<const Command* const> dependencies) = 0;
 
   virtual absl::StatusOr<const Command*> CreateCase(
-      DeviceMemory<bool> index, std::vector<Builder> branches,
+      DeviceMemory<bool> index, std::vector<CreateCommands> create_branches,
       absl::Span<const Command* const> dependencies) = 0;
 
   // Updates a Case command.
@@ -249,7 +254,8 @@ class CommandBuffer {
   //     cond_builder()
   //
   virtual absl::StatusOr<const Command*> CreateWhile(
-      DeviceMemory<bool> pred, Builder cond_builder, Builder body_builder,
+      DeviceMemory<bool> pred, CreateCommands create_cond,
+      CreateCommands create_body,
       absl::Span<const Command* const> dependencies) = 0;
 
   // Updates a While command.
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc
index e156f4a4c88260..4b3bdf9c7a547d 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc
@@ -529,30 +529,6 @@ absl::Status CudaCommandBuffer::UpdateKernelNode(
       "Failed to set CUDA graph kernel node params");
 }
 
-absl::StatusOr<GraphNodeHandle> CudaCommandBuffer::CreateBarrierNode(
-    absl::Span<const GraphNodeHandle> dependencies) {
-  if (parent_->GetDeviceDescription().driver_version() <
-      SemanticVersion(12, 4, 0)) {
-    // Instead of empty nodes we create no-op kernel nodes as barriers because
-    // CUDA 12.3 does not support empty nodes inside conditional command
-    // buffers.
- TF_ASSIGN_OR_RETURN(NoOpKernel * noop, GetNoOpKernel()); - return CreateKernelNode(dependencies, ThreadDim{1, 1, 1}, BlockDim{1, 1, 1}, - **noop, KernelArgsPackedArray<0>()); - } - - VLOG(2) << "Add empty node to a graph " << graph_ - << "; deps: " << dependencies.size(); - - CUgraphNode barrier_handle = nullptr; - std::vector deps = ToCudaGraphHandles(dependencies); - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuGraphAddEmptyNode(&barrier_handle, graph_, deps.data(), deps.size()), - "Failed to add empty node to a CUDA graph")); - - return FromCudaGraphHandle(barrier_handle); -} - absl::Status CudaCommandBuffer::Trace( Stream* stream, absl::AnyInvocable function) { #if CUDA_VERSION < 12030 @@ -606,23 +582,13 @@ absl::Status CudaCommandBuffer::Trace( #endif } -absl::Status CudaCommandBuffer::SetNodeExecutionEnabled( - GraphNodeHandle node_handle, bool enabled) { - // Node is enabled if value != 0, otherwise the node is disabled. - unsigned value = enabled ? 1 : 0; - VLOG(2) << "Set CUDA executable graph " << exec_ << " node " << node_handle - << " enabled flag to " << value; - return cuda::ToStatus( - cuGraphNodeSetEnabled(exec_, ToCudaGraphHandle(node_handle), value), - "Failed to set CUDA graph node enabled flag"); -} - absl::Status CudaCommandBuffer::LaunchGraph(Stream* stream) { VLOG(3) << "Launch command buffer executable graph " << exec_ << " on a stream: " << stream; return cuda::ToStatus(cuGraphLaunch(exec_, AsGpuStreamValue(stream)), "Failed to launch CUDA graph"); } + absl::StatusOr CudaCommandBuffer::GetNodeCount() const { size_t num_nodes; TF_RETURN_IF_ERROR( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h index 3cb1fa14203c41..05a2c475a6ed81 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h @@ -148,15 +148,9 @@ class CudaCommandBuffer final : public GpuCommandBuffer { const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& args) override; - absl::StatusOr CreateBarrierNode( - absl::Span dependencies) override; - absl::Status Trace(Stream* stream, absl::AnyInvocable function) override; - absl::Status SetNodeExecutionEnabled(GraphNodeHandle node_handle, - bool enabled) override; - absl::Status LaunchGraph(Stream* stream) override; absl::StatusOr GetNodeCount() const override; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 3b28c8c943a221..041b886b84b2d4 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -80,40 +80,6 @@ int64_t GpuCommandBuffer::AliveExecs() { GpuCommandBuffer::GpuCommandBuffer(Mode mode, StreamExecutor* parent) : mode_(mode), parent_(parent) {} -GpuCommandBuffer::Dependencies GpuCommandBuffer::GetAutoDependencies() const { - if (commands_.empty()) return Dependencies{}; - - const Command* command = commands_.back().get(); - - if (auto* gpu_command = dynamic_cast(command)) { - return Dependencies{gpu_command->handle}; - } - - if (auto* gpu_command = dynamic_cast(command)) { - return Dependencies{gpu_command->then_conditional_node.handle}; - } - - if (auto* gpu_command = dynamic_cast(command)) { - return Dependencies{gpu_command->then_conditional_node.handle, - gpu_command->else_conditional_node.handle}; - } - - if (auto* gpu_command = dynamic_cast(command)) { - 
Dependencies dependencies; - for (const auto& conditional_node : gpu_command->conditional_nodes) { - dependencies.push_back(conditional_node.handle); - } - return dependencies; - } - - if (auto* gpu_command = dynamic_cast(command)) { - return Dependencies{gpu_command->conditional_node.handle}; - } - - CHECK(false) << "Unsupported command type"; // Crash OK - return Dependencies{}; -} - absl::Status GpuCommandBuffer::CheckNotFinalized() { if (state_ == State::kFinalized) return absl::InternalError( @@ -130,6 +96,31 @@ absl::Status GpuCommandBuffer::CheckInState(State state) { return absl::OkStatus(); } +std::vector +GpuCommandBuffer::ToGraphNodeDependencies( + absl::Span dependencies) { + std::vector handles; + + for (const Command* dep : dependencies) { + if (auto* gpu_command = dynamic_cast(dep)) { + handles.push_back(gpu_command->handle); + + } else if (auto* gpu_command = dynamic_cast(dep)) { + for (const auto& conditional_node : gpu_command->conditional_nodes) { + handles.push_back(conditional_node.handle); + } + + } else if (auto* gpu_command = dynamic_cast(dep)) { + handles.push_back(gpu_command->conditional_node.handle); + + } else { + LOG(FATAL) << "Unsupported command type"; // Crash OK + } + } + + return handles; +} + absl::StatusOr GpuCommandBuffer::CreateLaunchWithPackedArgs( const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, @@ -141,12 +132,9 @@ GpuCommandBuffer::CreateLaunchWithPackedArgs( packed_args.number_of_arguments()); // Adds a new kernel node to the graph under construction. - Dependencies barrier = dependencies.empty() - ? GetAutoDependencies() - : ToGraphNodeDependencies(dependencies); - TF_ASSIGN_OR_RETURN( - GraphNodeHandle handle, - CreateKernelNode(barrier, threads, blocks, kernel, packed_args)); + TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, + CreateKernelNode(ToGraphNodeDependencies(dependencies), + threads, blocks, kernel, packed_args)); return AppendCommand(GpuCommand{handle}); } @@ -224,10 +212,9 @@ GpuCommandBuffer::CreateNestedCommand( absl::Span dependencies) { TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); - Dependencies barrier = dependencies.empty() - ? GetAutoDependencies() - : ToGraphNodeDependencies(dependencies); - TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, CreateChildNode(barrier, nested)); + TF_ASSIGN_OR_RETURN( + GraphNodeHandle handle, + CreateChildNode(ToGraphNodeDependencies(dependencies), nested)); return AppendCommand(GpuCommand{handle}); } @@ -244,11 +231,9 @@ absl::StatusOr GpuCommandBuffer::CreateMemcpyD2D( absl::Span dependencies) { TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); - Dependencies barrier = dependencies.empty() - ? GetAutoDependencies() - : ToGraphNodeDependencies(dependencies); TF_ASSIGN_OR_RETURN(GraphNodeHandle handle, - CreateMemcpyD2DNode(barrier, *dst, src, size)); + CreateMemcpyD2DNode(ToGraphNodeDependencies(dependencies), + *dst, src, size)); return AppendCommand(GpuCommand{handle}); } @@ -267,12 +252,9 @@ absl::StatusOr GpuCommandBuffer::CreateMemset( absl::Span dependencies) { TF_RETURN_IF_ERROR(CheckInState(State::kCreate)); - Dependencies barrier = dependencies.empty() - ? 
GetAutoDependencies()
-                             : ToGraphNodeDependencies(dependencies);
-  TF_ASSIGN_OR_RETURN(
-      GraphNodeHandle handle,
-      CreateMemsetNode(barrier, *dst, bit_pattern, num_elements));
+  TF_ASSIGN_OR_RETURN(GraphNodeHandle handle,
+                      CreateMemsetNode(ToGraphNodeDependencies(dependencies),
+                                       *dst, bit_pattern, num_elements));
 
   return AppendCommand(GpuCommand{handle});
 }
@@ -303,11 +285,10 @@ GpuCommandBuffer::CreateDnnGraphCommand(
       tensorflow::down_cast<const GpuCommandBuffer&>(*nested);
   TF_RETURN_IF_ERROR(
       nested_gpu.PopulateDnnGraphNode(dnn_graph, stream, operands));
-  Dependencies barrier = dependencies.empty()
-                             ? GetAutoDependencies()
-                             : ToGraphNodeDependencies(dependencies);
-  TF_ASSIGN_OR_RETURN(GraphNodeHandle handle,
-                      CreateChildNode(barrier, *nested));
+
+  TF_ASSIGN_OR_RETURN(
+      GraphNodeHandle handle,
+      CreateChildNode(ToGraphNodeDependencies(dependencies), *nested));
 
   return AppendCommand(GpuCommand{handle});
 }
@@ -343,7 +324,7 @@ GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) {
 
 absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateCase(
     DeviceMemory<uint8_t> index, bool index_is_bool,
-    std::vector<Builder> branches,
+    std::vector<CreateCommands> create_branches,
     absl::Span<const Command* const> dependencies) {
   TF_RETURN_IF_ERROR(CheckInState(State::kCreate));
 
   GpuCaseCommand command = {};
 
-  Dependencies barrier = dependencies.empty()
-                             ? GetAutoDependencies()
-                             : ToGraphNodeDependencies(dependencies);
+  std::vector<GraphNodeHandle> node_dependencies =
+      ToGraphNodeDependencies(dependencies);
 
   int32_t batch_offset = 0;
-  while (batch_offset < branches.size()) {
+  while (batch_offset < create_branches.size()) {
    // Conditionals will by default run branches[branches.size()-1] if index is
    // `< 0` or `>= branches.size()`. See
    // https://openxla.org/xla/operation_semantics#conditional.
    // To break down a large case with back to back ConditionalCommands, only
    // the last batch should accept this default case.
-    int32_t remaining_branches = branches.size() - batch_offset;
+    int32_t remaining_branches = create_branches.size() - batch_offset;
     int32_t batch_size;
     bool enable_conditional_default;
     if (remaining_branches <= kBranchBatchSize) {
@@ -379,7 +359,7 @@ absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateCase(
     TF_ASSIGN_OR_RETURN(auto set_condition_node,
                         CreateSetCaseConditionNode(
                             conditionals, index, index_is_bool, batch_offset,
-                            enable_conditional_default, barrier));
+                            enable_conditional_default, node_dependencies));
 
     std::vector<GraphConditionalNodeHandle> conditional_nodes;
     for (int z = 0; z < batch_size; ++z) {
@@ -391,7 +371,9 @@ absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateCase(
       GpuCommandBuffer* case_command_buffer =
           conditional_nodes.back().command_buffer.get();
-      TF_RETURN_IF_ERROR(branches[branch_offset](case_command_buffer));
+      TF_RETURN_IF_ERROR(create_branches[branch_offset](case_command_buffer,
+                                                        /*dependencies=*/{})
+                             .status());
       TF_RETURN_IF_ERROR(case_command_buffer->Finalize());
     }
 
@@ -458,19 +440,19 @@ absl::Status GpuCommandBuffer::UpdateCase(
 }
 
 absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateCase(
-    DeviceMemory<int32_t> index, std::vector<Builder> branches,
+    DeviceMemory<int32_t> index, std::vector<CreateCommands> create_branches,
     absl::Span<const Command* const> dependencies) {
   return CreateCase(
      DeviceMemory<uint8_t>::MakeFromByteSize(index.opaque(), index.size()),
-      /*index_is_bool=*/false, branches, dependencies);
+      /*index_is_bool=*/false, std::move(create_branches), dependencies);
 }
 
 absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateCase(
-    DeviceMemory<bool> index, std::vector<Builder> branches,
+    DeviceMemory<bool> index, std::vector<CreateCommands> create_branches,
     absl::Span<const Command* const> dependencies) {
   return CreateCase(
      DeviceMemory<uint8_t>::MakeFromByteSize(index.opaque(), index.size()),
-      /*index_is_bool=*/true, branches, dependencies);
+      /*index_is_bool=*/true, std::move(create_branches), dependencies);
 }
 
 absl::Status GpuCommandBuffer::UpdateCase(
@@ -492,37 +474,32 @@ absl::Status GpuCommandBuffer::UpdateCase(
 }
 
 absl::StatusOr<const CommandBuffer::Command*> GpuCommandBuffer::CreateWhile(
-    DeviceMemory<bool> pred, Builder cond_builder, Builder body_builder,
-    absl::Span<const Command* const> dependencies) {
+    DeviceMemory<bool> pred, CreateCommands create_cond,
+    CreateCommands create_body, absl::Span<const Command* const> dependencies) {
   TF_RETURN_IF_ERROR(CheckInState(State::kCreate));
 
   GpuWhileCommand command = {};
 
-  Dependencies barrier = dependencies.empty()
-                             ? GetAutoDependencies()
-                             : ToGraphNodeDependencies(dependencies);
-
-  // TODO(ezhulenev): cond_builder should be able to take dependencies.
- (void)barrier; - - TF_RETURN_IF_ERROR(cond_builder(this)); + TF_ASSIGN_OR_RETURN(auto init_cond, create_cond(this, dependencies)); TF_ASSIGN_OR_RETURN(command.conditional, CreateConditionalHandle()); - TF_ASSIGN_OR_RETURN(command.set_init_condition_node, - CreateSetWhileConditionNode(command.conditional, pred, - GetAutoDependencies())); + TF_ASSIGN_OR_RETURN( + command.set_init_condition_node, + CreateSetWhileConditionNode(command.conditional, pred, + ToGraphNodeDependencies(init_cond))); TF_ASSIGN_OR_RETURN( command.conditional_node, CreateConditionalNode({command.set_init_condition_node}, command.conditional, ConditionType::kWhile)); GpuCommandBuffer* body = command.conditional_node.command_buffer.get(); - TF_RETURN_IF_ERROR(body_builder(body)); - TF_RETURN_IF_ERROR(cond_builder(body)); + TF_ASSIGN_OR_RETURN(auto body_commands, + create_body(body, /*dependencies=*/{})); + TF_ASSIGN_OR_RETURN(auto update_cond, create_cond(body, body_commands)); TF_ASSIGN_OR_RETURN( command.set_body_condition_node, body->CreateSetWhileConditionNode(command.conditional, pred, - body->GetAutoDependencies())); + ToGraphNodeDependencies(update_cond))); TF_RETURN_IF_ERROR(command.conditional_node.command_buffer->Finalize()); return AppendCommand(std::move(command)); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 8989327eef298e..49f02d3ea1164a 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -35,7 +36,6 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/casts.h" namespace stream_executor::gpu { @@ -89,22 +89,6 @@ class GpuCommandBuffer : public CommandBuffer { GraphNodeHandle handle = nullptr; }; - // A GPU command recorded for the If operation. - struct GpuIfCommand : public CommandBuffer::Command { - GraphConditionalHandle then_conditional; - GraphNodeHandle set_condition_node; - GraphConditionalNodeHandle then_conditional_node; - }; - - // A GPU command recorded for the IfElse operation. - struct GpuIfElseCommand : public CommandBuffer::Command { - GraphConditionalHandle then_conditional; - GraphConditionalHandle else_conditional; - GraphNodeHandle set_condition_node; - GraphConditionalNodeHandle then_conditional_node; - GraphConditionalNodeHandle else_conditional_node; - }; - // A GPU command recorded for the Case operation. 
struct GpuCaseCommand : public CommandBuffer::Command { std::vector conditionals; @@ -168,11 +152,11 @@ class GpuCommandBuffer : public CommandBuffer { absl::Span operands) override; absl::StatusOr CreateCase( - DeviceMemory index, std::vector branches, + DeviceMemory index, std::vector create_branches, absl::Span dependencies) override; absl::StatusOr CreateCase( - DeviceMemory index, std::vector branches, + DeviceMemory index, std::vector create_branches, absl::Span dependencies) override; absl::Status UpdateCase(const Command* command, DeviceMemory index, @@ -182,7 +166,8 @@ class GpuCommandBuffer : public CommandBuffer { std::vector update_branches) override; absl::StatusOr CreateWhile( - DeviceMemory pred, Builder cond_builder, Builder body_builder, + DeviceMemory pred, CreateCommands create_cond, + CreateCommands create_body, absl::Span dependencies) override; absl::Status UpdateWhile(const Command* command, DeviceMemory pred, @@ -221,8 +206,6 @@ class GpuCommandBuffer : public CommandBuffer { absl::StatusOr> CreateConditionalHandles( size_t num_handles); - Dependencies GetAutoDependencies() const; - //===--------------------------------------------------------------------===// // APIs for launching kernels to update conditional handles. //===--------------------------------------------------------------------===// @@ -282,7 +265,7 @@ class GpuCommandBuffer : public CommandBuffer { private: absl::StatusOr CreateCase( DeviceMemory index, bool index_is_bool, - std::vector branches, + std::vector create_branches, absl::Span dependencies); absl::Status UpdateCase(const Command* command, DeviceMemory index, @@ -297,15 +280,8 @@ class GpuCommandBuffer : public CommandBuffer { } // Converts a list of command dependencies to a list of graph node handles. - Dependencies ToGraphNodeDependencies( - absl::Span dependencies) { - Dependencies handles; - for (const Command* dependency : dependencies) { - auto* gpu_command = tsl::down_cast(dependency); - handles.push_back(gpu_command->handle); - } - return handles; - } + std::vector ToGraphNodeDependencies( + absl::Span dependencies); //===--------------------------------------------------------------------===// // APIs for creating and updating underlying GPU graph nodes. @@ -372,14 +348,6 @@ class GpuCommandBuffer : public CommandBuffer { //===--------------------------------------------------------------------===// - // Creates a new no-op node acting as a barrier and adds it to the graph. - virtual absl::StatusOr CreateBarrierNode( - absl::Span dependencies) = 0; - - // Enables or disables the execution of the given node in the graph. - virtual absl::Status SetNodeExecutionEnabled(GraphNodeHandle node_handle, - bool enabled) = 0; - // Launches an instantiated graph. Only supported on primary command buffers. virtual absl::Status LaunchGraph(Stream* stream) = 0; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index 121b6378fa9970..2f7697b8a22e07 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include @@ -32,13 +33,13 @@ limitations under the License. 
#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" -#include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/trace_command_buffer_factory.h" #include "xla/stream_executor/typed_kernel_factory.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test_benchmark.h" @@ -80,6 +81,12 @@ static bool IsAtLeastCuda12300( return true; } +absl::StatusOr> Wrap( + absl::StatusOr command) { + TF_RETURN_IF_ERROR(command.status()); + return std::vector{*command}; +} + TEST(GpuCommandBufferTest, LaunchSingleKernel) { Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); @@ -381,20 +388,23 @@ TEST(GpuCommandBufferTest, ConditionalCaseEmptyGraph) { TF_ASSERT_OK(stream->MemZero(&c, byte_length)); // if (index == 0) c = a + b - CommandBuffer::Builder branch0 = [&](CommandBuffer* branch0_cmd) { - return branch0_cmd->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, c) - .status(); + CommandBuffer::CreateCommands branch0 = [&](CommandBuffer* b0, auto deps) { + return Wrap(b0->CreateLaunch(add, ThreadDim(), BlockDim(4), deps, a, b, c)); }; - // if (index == 1) c = a * b - CommandBuffer::Builder branch1 = [&](CommandBuffer* branch1_cmd) { - return absl::OkStatus(); + // if (index == 1) + CommandBuffer::CreateCommands branch1 = [&](CommandBuffer*, auto deps) { + return std::vector{}; }; + std::vector branches; + branches.push_back(std::move(branch0)); + branches.push_back(std::move(branch1)); + // Create a command buffer with a single conditional operation. TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer, executor->CreateCommandBuffer(primary)); - TF_ASSERT_OK(cmd_buffer->CreateCase(index, {branch0, branch1}, {})); + TF_ASSERT_OK(cmd_buffer->CreateCase(index, std::move(branches), {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -471,7 +481,7 @@ TEST_P(GpuCommandBufferCaseTest, ConditionalMultiCase) { const int kNumCases = GetNumCases(); std::vector> values; std::vector> results; - std::vector branches; + std::vector branches; values.resize(kNumCases); results.resize(kNumCases); branches.resize(kNumCases); @@ -480,19 +490,18 @@ TEST_P(GpuCommandBufferCaseTest, ConditionalMultiCase) { TF_ASSERT_OK(stream->Memset32(&values[i], i, byte_length)); results[i] = executor->AllocateArray(kLength, 0); TF_ASSERT_OK(stream->Memset32(&results[i], 0, byte_length)); - branches[i] = [&, i](CommandBuffer* branch_cmd) { + branches[i] = [&, i](CommandBuffer* branch_cmd, auto dependencies) { // result = i * i; - return branch_cmd - ->CreateLaunch(mul, ThreadDim(), BlockDim(kLength), {}, values[i], - values[i], results[i]) - .status(); + return Wrap(branch_cmd->CreateLaunch(mul, ThreadDim(), BlockDim(kLength), + dependencies, values[i], values[i], + results[i])); }; } // Create a command buffer with a single conditional operation. TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer, executor->CreateCommandBuffer(primary)); - TF_ASSERT_OK(cmd_buffer->CreateCase(index, branches, {})); + TF_ASSERT_OK(cmd_buffer->CreateCase(index, std::move(branches), {})); TF_ASSERT_OK(cmd_buffer->Finalize()); // We test the out of bounds cases as well ( i < 0, i >= kNumCases). 
@@ -568,20 +577,22 @@ TEST(GpuCommandBufferTest, ConditionalCase) { TF_ASSERT_OK(stream->MemZero(&c, byte_length)); // if (index == 0) c = a + b - CommandBuffer::Builder branch0 = [&](CommandBuffer* branch0_cmd) { - return branch0_cmd->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, c) - .status(); + CommandBuffer::CreateCommands branch0 = [&](CommandBuffer* b0, auto deps) { + return Wrap(b0->CreateLaunch(add, ThreadDim(), BlockDim(4), deps, a, b, c)); }; // if (index == 1) c = a * b - CommandBuffer::Builder branch1 = [&](CommandBuffer* branch1_cmd) { - return branch1_cmd->CreateLaunch(mul, ThreadDim(), BlockDim(4), {}, a, b, c) - .status(); + CommandBuffer::CreateCommands branch1 = [&](CommandBuffer* b1, auto deps) { + return Wrap(b1->CreateLaunch(mul, ThreadDim(), BlockDim(4), deps, a, b, c)); }; + std::vector branches; + branches.push_back(std::move(branch0)); + branches.push_back(std::move(branch1)); + // Create a command buffer with a single conditional operation. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); - TF_ASSERT_OK(cmd_buffer->CreateCase(index, {branch0, branch1}, {})); + TF_ASSERT_OK(cmd_buffer->CreateCase(index, std::move(branches), {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -665,23 +676,23 @@ TEST(GpuCommandBufferTest, ConditionalWhile) { TF_ASSERT_OK(stream->MemZero(&b, byte_length)); // Loop cond: loop_counter++ < num_iters; - CommandBuffer::Builder cond_builder = [&](CommandBuffer* cond_cmd) { - return cond_cmd - ->CreateLaunch(inc_and_cmp, ThreadDim(), BlockDim(), {}, loop_counter, - pred, num_iters) - .status(); + CommandBuffer::CreateCommands create_cond = [&](CommandBuffer* cond_cmd, + auto deps) { + return Wrap(cond_cmd->CreateLaunch(inc_and_cmp, ThreadDim(), BlockDim(), {}, + loop_counter, pred, num_iters)); }; // Loop body: b = a + b - CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { - return body_cmd - ->CreateLaunch(add, ThreadDim(), BlockDim(length), {}, a, b, b) - .status(); + CommandBuffer::CreateCommands create_body = [&](CommandBuffer* body_cmd, + auto deps) { + return Wrap(body_cmd->CreateLaunch(add, ThreadDim(), BlockDim(length), {}, + a, b, b)); }; // Create a command buffer with a single conditional operation. 
auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); - TF_ASSERT_OK(cmd_buffer->CreateWhile(pred, cond_builder, body_builder, {})); + TF_ASSERT_OK(cmd_buffer->CreateWhile(pred, std::move(create_cond), + std::move(create_body), {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); @@ -738,37 +749,46 @@ TEST(GpuCommandBufferTest, DISABLED_WhileNestedConditional) { TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); TF_ASSERT_OK(stream->MemZero(&b, byte_length)); - CommandBuffer::Builder then_builder = + CommandBuffer::CreateCommands create_then = // Then body: b = a + b - [&](CommandBuffer* then_cmd) { - return then_cmd - ->CreateLaunch(add, ThreadDim(), BlockDim(length), {}, a, b, b) - .status(); + [&](CommandBuffer* then_cmd, auto deps) { + return Wrap(then_cmd->CreateLaunch(add, ThreadDim(), BlockDim(length), + deps, a, b, b)); }; + CommandBuffer::CreateCommands create_else = + // Else body: b = a + b + [&](CommandBuffer* then_cmd, auto deps) { + return Wrap(then_cmd->CreateLaunch(add, ThreadDim(), BlockDim(length), + deps, a, b, b)); + }; + + std::vector branches; + branches.push_back(std::move(create_then)); + branches.push_back(std::move(create_else)); + auto nested_cmd = executor->CreateCommandBuffer(nested).value(); // TODO(b/339653343): Adding this Case condition causes AddNestedCommandBuffer // to fail. - TF_ASSERT_OK( - nested_cmd->CreateCase(pred_then, {then_builder, then_builder}, {})); + TF_ASSERT_OK(nested_cmd->CreateCase(pred_then, std::move(branches), {})); // Loop cond: loop_counter++ < num_iters; - CommandBuffer::Builder cond_builder = [&](CommandBuffer* cond_cmd) { - return cond_cmd - ->CreateLaunch(inc_and_cmp, ThreadDim(), BlockDim(length), {}, - loop_counter, pred, num_iters) - .status(); + CommandBuffer::CreateCommands create_cond = [&](CommandBuffer* cond_cmd, + auto deps) { + return Wrap(cond_cmd->CreateLaunch(inc_and_cmp, ThreadDim(), + BlockDim(length), deps, loop_counter, + pred, num_iters)); }; - CommandBuffer::Builder body_builder = - [&](CommandBuffer* body_cmd) -> absl::Status { - CHECK_OK(body_cmd->CreateNestedCommand(*nested_cmd, {})); - return absl::OkStatus(); + CommandBuffer::CreateCommands create_body = [&](CommandBuffer* body_cmd, + auto deps) { + return Wrap(body_cmd->CreateNestedCommand(*nested_cmd, deps)); }; // Create a command buffer with a single conditional operation. 
auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); - TF_ASSERT_OK(cmd_buffer->CreateWhile(pred, cond_builder, body_builder, {})); + TF_ASSERT_OK(cmd_buffer->CreateWhile(pred, std::move(create_cond), + std::move(create_body), {})); TF_ASSERT_OK(cmd_buffer->Finalize()); TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc index 27c4e5d0572046..a3c73f53dad809 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc @@ -337,22 +337,6 @@ absl::Status RocmCommandBuffer::UpdateKernelNode( "Failed to set HIP graph kernel node params"); } -absl::StatusOr RocmCommandBuffer::CreateBarrierNode( - absl::Span dependencies) { - VLOG(2) << "Add empty node to a graph " << graph_ - << "; deps: " << dependencies.size(); - - hipGraphNode_t barrier_handle = nullptr; - std::vector deps = ToHipGraphHandles(dependencies); - - TF_RETURN_IF_ERROR( - ToStatus(wrap::hipGraphAddEmptyNode(&barrier_handle, graph_, deps.data(), - deps.size()), - "Failed to add empty node to a HIP graph")); - - return FromHipGraphHandle(barrier_handle); -} - absl::Status RocmCommandBuffer::Trace( Stream* stream, absl::AnyInvocable function) { TF_RETURN_IF_ERROR(CheckNotFinalized()); @@ -396,17 +380,6 @@ absl::Status RocmCommandBuffer::Trace( return absl::OkStatus(); } -absl::Status RocmCommandBuffer::SetNodeExecutionEnabled( - GraphNodeHandle node_handle, bool enabled) { - // Node is enabled if value != 0, otherwise the node is disabled. - unsigned value = enabled ? 1 : 0; - VLOG(2) << "Set HIP executable graph " << exec_ << " node " << node_handle - << " enabled flag to " << value; - return ToStatus( - wrap::hipGraphNodeSetEnabled(exec_, ToHipGraphHandle(node_handle), value), - "Failed to set HIP graph node enabled flag"); -} - absl::Status RocmCommandBuffer::LaunchGraph(Stream* stream) { VLOG(3) << "Launch command buffer executable graph " << exec_ << " on a stream: " << stream; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h index 1ac7412a8b8955..2edce7679ce437 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h @@ -134,15 +134,9 @@ class RocmCommandBuffer : public GpuCommandBuffer { const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& args) override; - absl::StatusOr CreateBarrierNode( - absl::Span dependencies) override; - absl::Status Trace(Stream* stream, absl::AnyInvocable function) override; - absl::Status SetNodeExecutionEnabled(GraphNodeHandle node_handle, - bool enabled) override; - absl::Status LaunchGraph(Stream* stream) override; absl::StatusOr GetNodeCount() const override; From 36c70c1fba22a0c3002926d32b5f95013d87c945 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 14:30:04 -0700 Subject: [PATCH 0732/1324] Run build_cleaner on BUILD file(s) located in /xla/backends/cpu/codegen. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. 
Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747569717 --- third_party/xla/xla/backends/cpu/codegen/BUILD | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/codegen/BUILD b/third_party/xla/xla/backends/cpu/codegen/BUILD index a6f6cae3d5a51c..ed4c057836cc95 100644 --- a/third_party/xla/xla/backends/cpu/codegen/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/BUILD @@ -290,7 +290,6 @@ cc_library( deps = [ ":target_machine_features", "//xla:shape_util", - "//xla:types", "//xla:xla_data_proto_cc", "//xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/algorithm:container", @@ -312,7 +311,6 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:hlo_module_config", "//xla/service/llvm_ir:ir_array", @@ -445,7 +443,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:ExecutionEngine", "@llvm-project//llvm:JITLink", - "@llvm-project//llvm:OrcJIT", + "@llvm-project//llvm:OrcJIT", # buildcleaner: keep "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", ], @@ -538,8 +536,6 @@ xla_cc_test( ":symbol_name_util", "//xla/tsl/platform:statusor", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", ], ) From fa1246bee1de9f0e59f2e4830deb4f8f962e5b17 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 14:36:04 -0700 Subject: [PATCH 0733/1324] Run build_cleaner on BUILD file(s) located in /xla/pjrt/cpu. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. 
Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747571833 --- third_party/xla/xla/pjrt/cpu/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 299bb4f33f2d1b..3150a55b46010b 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -19,7 +19,6 @@ xla_cc_test( ":cpu_client", "//xla/pjrt:pjrt_client_test_common", "//xla/tsl/platform:test_main", - "@com_google_googletest//:gtest", ], ) From 8ca308fb689b4cce78d8d53e7988c3a8c2ea215e Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 14:40:03 -0700 Subject: [PATCH 0734/1324] Run build_cleaner on BUILD file(s) located in /xla/pjrt/gpu. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747573310 --- third_party/xla/xla/pjrt/gpu/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 317ce13bec7a7c..06efc16f215a44 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -323,9 +323,6 @@ cc_library( "//xla/service/gpu:executable_proto_cc", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", - "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/stream_executor/rocm:rocm_platform_id", - "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", From 6e193b32b6bafe82d8de78c1fb3d72b9e6d2a5f9 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 14:40:19 -0700 Subject: [PATCH 0735/1324] Run build_cleaner on BUILD file(s) located in /xla/pjrt/gpu/tfrt. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747573415 --- third_party/xla/xla/pjrt/gpu/tfrt/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD index 3380c24718838b..548ec464d64eec 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD +++ b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD @@ -198,8 +198,6 @@ xla_cc_test( "//xla/tsl/platform:status_matchers", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", ], ) From c9a3840b0d729b889f7dfcf8a4e48ded24b1079a Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 14:40:27 -0700 Subject: [PATCH 0736/1324] Run build_cleaner on BUILD file(s) located in /xla/pjrt/plugin/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. 
It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747573471 --- third_party/xla/xla/pjrt/plugin/example_plugin/BUILD | 1 - third_party/xla/xla/pjrt/plugin/xla_cpu/BUILD | 1 - 2 files changed, 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/plugin/example_plugin/BUILD b/third_party/xla/xla/pjrt/plugin/example_plugin/BUILD index f9a174cc97fde2..101d50dbb35a1a 100644 --- a/third_party/xla/xla/pjrt/plugin/example_plugin/BUILD +++ b/third_party/xla/xla/pjrt/plugin/example_plugin/BUILD @@ -58,7 +58,6 @@ xla_cc_test( "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/BUILD b/third_party/xla/xla/pjrt/plugin/xla_cpu/BUILD index 979a1145132fe2..2e3bc46d7cf5d5 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_cpu/BUILD +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/BUILD @@ -15,7 +15,6 @@ cc_library( ], hdrs = ["xla_cpu_pjrt_client.h"], deps = [ - ":cpu_client_options", "//xla/pjrt:pjrt_client", "//xla/pjrt/cpu:cpu_client", "@com_google_absl//absl/status:statusor", From 44104cbac6239f06410a14b2b8026cd55fc17aee Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 14:40:38 -0700 Subject: [PATCH 0737/1324] Run build_cleaner on BUILD file(s) located in /xla/runtime/. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. 
Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747573524 --- .../xla/xla/runtime/large_hlo_snapshot_serialization/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD index 97172f3aec2f81..98932ee3ad6401 100644 --- a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD +++ b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD @@ -30,8 +30,6 @@ xla_cc_test( "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", ], ) @@ -72,6 +70,5 @@ xla_cc_test( "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:statusor", ], ) From 4b5f3506c36646cca842a2bfc732d210b29695d6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 14:42:24 -0700 Subject: [PATCH 0738/1324] Increase the size of `dlpack_test` to prevent timeouts. PiperOrigin-RevId: 747574192 --- tensorflow/c/eager/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 4e8f5156761f9c..e4c2c92783d4d8 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -1119,7 +1119,7 @@ cc_library( tf_cuda_cc_test( name = "dlpack_test", - size = "small", + size = "medium", srcs = [ "dlpack_test.cc", ], From 15e6753a92a2a8f36b4a34a9582f2e514ff95fc8 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 14:43:15 -0700 Subject: [PATCH 0739/1324] Run build_cleaner on BUILD file(s) located in /xla/gpu/autotuning. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. 
Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747574566 --- third_party/xla/xla/service/gpu/autotuning/BUILD | 7 ------- 1 file changed, 7 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 1e7d3629511875..fec7cb04b678d4 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -358,17 +358,13 @@ cc_library( ":autotuner_status_key", "//xla:autotune_results_proto_cc", "//xla:autotuning_proto_cc", - "//xla:shape_util", "//xla:status_macros", "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:dump", - "//xla/service/gpu:stream_executor_util", "//xla/stream_executor:device_description", - "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/platform:env", @@ -442,7 +438,6 @@ xla_test( "//xla/stream_executor:platform", "//xla/tests:hlo_test_base", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], @@ -477,7 +472,6 @@ xla_test( "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", ], ) @@ -582,7 +576,6 @@ xla_test( "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", ], ) From 714787ecdfa0a48ba45e774c136a4f5a0837bf6a Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Mon, 14 Apr 2025 14:49:21 -0700 Subject: [PATCH 0740/1324] [XLA:Upkeep] Resolve the following technical debt issue: Todo PiperOrigin-RevId: 747576928 --- third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc index 1b64f384984041..9d35a4242f8168 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -73,7 +73,6 @@ PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api) { destroy_args.client = client; PJRT_Error* error = api->PJRT_Client_Destroy(&destroy_args); - // TODO(b/236710439): handle the error and remove this CHECK() call CHECK(error == nullptr); }; } From ef331c949931006a887fa4065c00c538468c1088 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 14:51:03 -0700 Subject: [PATCH 0741/1324] Run build_cleaner on BUILD file(s) located in /xla/service/cpu. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix: * any conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747577516 --- third_party/xla/xla/service/cpu/BUILD | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index fb7865d60a4b54..759ede6a3771f6 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -168,7 +168,6 @@ cc_library( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_platform_id", - "@com_google_absl//absl/base", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -506,7 +505,6 @@ xla_test( "//xla/tsl/lib/monitoring:collected_metrics", "//xla/tsl/lib/monitoring:collection_registry", "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -542,7 +540,6 @@ xla_test( ], deps = [ ":cpu_aot_compilation_result", - ":cpu_compiler_pure", ":test_header_helper", "//xla:literal", "//xla:literal_util", @@ -711,7 +708,6 @@ cc_library( "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], @@ -1040,15 +1036,12 @@ cc_library( ":backend_config_proto_cc", ":cpu_options", ":cpu_runtime", - ":ir_emission_utils", ":tiled_dot_emitter", "//xla:shape_util", "//xla:status_macros", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/codegen:target_machine_features", - "//xla/backends/cpu/codegen:vector_ir_builder", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/service/llvm_ir:ir_array", @@ -1057,13 +1050,11 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", 
"@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", ], ) @@ -1112,7 +1103,6 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/backends/cpu/collectives:cpu_clique", "//xla/backends/cpu/collectives:cpu_clique_key", "//xla/backends/cpu/collectives:cpu_cliques", "//xla/backends/cpu/collectives:cpu_collectives", @@ -1169,10 +1159,7 @@ cc_library( ":runtime_lightweight_check", "//xla:executable_run_options", "//xla/backends/cpu/runtime:convolution_thunk_internal", - "//xla/tsl/framework/contraction:eigen_contraction_kernel", - "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:dynamic_annotations", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:mutex", # build_cleaner: keep ], @@ -1187,7 +1174,6 @@ cc_library( deps = [ "//xla/service:custom_call_status_internal", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:dynamic_annotations", ], ) @@ -1299,10 +1285,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla/backends/cpu/runtime:convolution_thunk_internal", - "//xla/tsl/framework/contraction:eigen_contraction_kernel", - "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:dynamic_annotations", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:mutex", # build_cleaner: keep ], @@ -1316,10 +1299,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla/backends/cpu/runtime:convolution_thunk_internal", - "//xla/tsl/framework/contraction:eigen_contraction_kernel", - "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:dynamic_annotations", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:mutex", # build_cleaner: keep ], @@ -1538,7 +1518,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:fusion_node_indexing_evaluation", "//xla/service:instruction_fusion", - "//xla/service/llvm_ir:fused_ir_emitter", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1555,7 +1534,6 @@ cc_library( deps = [ "//xla/codegen/emitters:fusion_wrapper_base", "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_fusible", "@com_google_absl//absl/strings:string_view", ], ) @@ -1729,7 +1707,6 @@ xla_cc_test( "//xla/service:hlo_cost_analysis", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", ], @@ -1989,6 +1966,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":onednn_util", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", ] + mkl_deps(), ) From a644848de4e733532f1ff65b790ef14a5ff76cf2 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 14 Apr 2025 15:02:36 -0700 Subject: [PATCH 0742/1324] Removed reference to deprecated `CoordinationServiceInterface`. 
PiperOrigin-RevId: 747581922 --- tensorflow/core/distributed_runtime/rpc/BUILD | 1 + tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc | 3 ++- tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h | 2 +- tensorflow/core/distributed_runtime/session_mgr.cc | 5 ++--- tensorflow/core/distributed_runtime/session_mgr.h | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 5ef36545817eeb..c0351bbe448535 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -335,6 +335,7 @@ cc_library( "//tensorflow/core/nccl:collective_communicator", "//tensorflow/core/profiler/rpc:profiler_service_impl", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service", "@local_xla//xla/tsl/distributed_runtime/rpc:async_service_interface", ] + tf_protos_profiler_service() + tf_grpc_dependencies() + tf_grpc_cc_dependencies(), alwayslink = 1, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 213047b6165e84..cc5ab2bd5a2cb5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -27,6 +27,7 @@ limitations under the License. #include "grpcpp/security/credentials.h" #include "grpcpp/server_builder.h" #include "absl/strings/numbers.h" +#include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -524,7 +525,7 @@ absl::Status GrpcServer::SetCoordinationServiceAgentInstance( } absl::Status GrpcServer::SetCoordinationServiceInstance( - tsl::CoordinationServiceInterface* service) { + tsl::CoordinationService* service) { auto* coord_service = static_cast(coordination_service_); coord_service->SetCoordinationServiceInstance(service); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 8aadbc3077732e..431e4c4490be2a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -176,7 +176,7 @@ class GrpcServer : public ServerInterface { GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); } absl::Status SetCoordinationServiceInstance( - tsl::CoordinationServiceInterface* service); + tsl::CoordinationService* service); private: Env* env_; diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index a881b2952fa5fa..44e3eaf1ecc2d5 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -289,9 +289,8 @@ absl::Status SessionMgr::CreateSession( // Initialize coordination service if it is the leader. 
if (IsMultiClientLeader(server_def, coordination_config)) { - coordination_service_ = - tsl::CoordinationServiceInterface::EnableCoordinationService( - worker_env_->env, coordination_config, std::move(client_cache)); + coordination_service_ = tsl::CoordinationService::Create( + worker_env_->env, coordination_config, std::move(client_cache)); if (coordination_handler_ != nullptr) { coordination_handler_->SetServiceInstance(coordination_service_.get()); } diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index 55c64f45c9daeb..0a2bddddb1aeb7 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -134,7 +134,7 @@ class SessionMgr { std::unique_ptr default_worker_cache_; std::shared_ptr legacy_session_; - std::unique_ptr coordination_service_; + std::unique_ptr coordination_service_; std::unique_ptr coordination_service_agent_; bool is_logging_active_ = false; From cc42a3148ba1668519aa6c72831e13bf93cf8007 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 15:25:23 -0700 Subject: [PATCH 0743/1324] Add `CustomCombiner` to TPU embedding V2 API. PiperOrigin-RevId: 747590280 --- tensorflow/python/tpu/BUILD | 4 + .../python/tpu/tpu_embedding_v2_utils.py | 311 +++++++++++++----- .../python/tpu/tpu_embedding_v2_utils_test.py | 72 ++++ ...rimental.embedding.-adagrad-momentum.pbtxt | 1 + ....tpu.experimental.embedding.-adagrad.pbtxt | 1 + ...low.tpu.experimental.embedding.-adam.pbtxt | 1 + ...erimental.embedding.-custom-combiner.pbtxt | 10 + ...rimental.embedding.-custom-optimizer.pbtxt | 1 + ....tpu.experimental.embedding.-f-t-r-l.pbtxt | 1 + ...ow.tpu.experimental.embedding.-s-g-d.pbtxt | 1 + ...ensorflow.tpu.experimental.embedding.pbtxt | 4 + ...rimental.embedding.-adagrad-momentum.pbtxt | 1 + ....tpu.experimental.embedding.-adagrad.pbtxt | 1 + ...low.tpu.experimental.embedding.-adam.pbtxt | 1 + ...erimental.embedding.-custom-combiner.pbtxt | 10 + ...rimental.embedding.-custom-optimizer.pbtxt | 1 + ....tpu.experimental.embedding.-f-t-r-l.pbtxt | 1 + ...ow.tpu.experimental.embedding.-s-g-d.pbtxt | 1 + ...ensorflow.tpu.experimental.embedding.pbtxt | 4 + 19 files changed, 343 insertions(+), 84 deletions(-) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-custom-combiner.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-custom-combiner.pbtxt diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index 9a19ae4eaeaf1c..92d59574c8ed30 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -918,6 +918,10 @@ tf_py_strict_test( ":tpu_embedding_v2_utils", "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py", "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:init_ops_v2", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils.py b/tensorflow/python/tpu/tpu_embedding_v2_utils.py index efe9cad29a4578..1ba7da1f2c9132 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_utils.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils.py @@ -47,47 +47,21 @@ ClipValueType = Union[Tuple[float, float], float] -class _Optimizer(metaclass=abc.ABCMeta): - """Base class for all optimizers, with common 
parameters.""" +class _WithSlotVariables(metaclass=abc.ABCMeta): + """Base class that allows slot variables to be created.""" def __init__( self, - learning_rate: Union[float, Callable[[], float]], - use_gradient_accumulation: bool, - clip_weight_min: Optional[float], - clip_weight_max: Optional[float], - weight_decay_factor: Optional[float], - multiply_weight_decay_factor_by_learning_rate: bool, - clipvalue: Optional[ClipValueType] = None, slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None, - low_dimensional_packing_status: bool = False, ): - self.learning_rate = learning_rate - self.use_gradient_accumulation = use_gradient_accumulation - self.clip_weight_min = clip_weight_min - self.clip_weight_max = clip_weight_max - if not use_gradient_accumulation and clipvalue is not None: - raise ValueError( - f"When `use_gradient_accumulation` is False, gradient clipping " - f"cannot be used and `clipvalue` should be left as None. " - f"Received value {clipvalue} for argument `clipvalue`.") - if clipvalue is None: - clipvalue = (None, None) - elif not isinstance(clipvalue, tuple): - clipvalue = (-1. * clipvalue, clipvalue) - self.clip_gradient_min, self.clip_gradient_max = clipvalue - - self.weight_decay_factor = weight_decay_factor - self.multiply_weight_decay_factor_by_learning_rate = ( - multiply_weight_decay_factor_by_learning_rate) - - if (slot_variable_creation_fn is not None and - not callable(slot_variable_creation_fn)): + if slot_variable_creation_fn is not None and not callable( + slot_variable_creation_fn + ): raise ValueError( - f"Argument `slot_variable_creation_fn` must be either None or a " - f"callable. Received: {slot_variable_creation_fn}") + "Argument `slot_variable_creation_fn` must be either None or a " + f"callable. Received: {slot_variable_creation_fn}" + ) self.slot_variable_creation_fn = slot_variable_creation_fn - self.low_dimensional_packing_status = low_dimensional_packing_status @abc.abstractmethod def _slot_names(self) -> List[Text]: @@ -107,47 +81,6 @@ def _slot_initializers(self) -> List[init_ops_v2.Initializer]: """ raise NotImplementedError - def _set_optimization_parameters( - self, parameters: optimization_parameters_pb2.OptimizationParameters): - """Sets the optimizer fields in the OptimizationParameters.""" - if self.use_gradient_accumulation: - parameters.gradient_accumulation_status = ( - optimization_parameters_pb2.GradientAccumulationStatus.ENABLED) - else: - parameters.gradient_accumulation_status = ( - optimization_parameters_pb2.GradientAccumulationStatus.DISABLED) - - if self.clip_weight_min is not None: - parameters.clipping_limits.lower.value = self.clip_weight_min - - if self.clip_weight_max is not None: - parameters.clipping_limits.upper.value = self.clip_weight_max - - if self.clip_gradient_min is not None: - parameters.gradient_clipping_limits.lower.value = self.clip_gradient_min - - if self.clip_gradient_max is not None: - parameters.gradient_clipping_limits.upper.value = self.clip_gradient_max - - if self.weight_decay_factor: - parameters.weight_decay_factor = self.weight_decay_factor - if self.multiply_weight_decay_factor_by_learning_rate: - parameters.multiply_weight_decay_factor_by_learning_rate = True - - parameters.low_dimensional_packing_status = ( - self.low_dimensional_packing_status - ) - - @abc.abstractmethod - def _load(self) -> Callable[..., ops.Operation]: - """Returns the load function for the optimizer.""" - raise NotImplementedError - - @abc.abstractmethod - def _retrieve(self) -> Callable[..., core.Tensor]: - 
"""Returns the retrieve function for the optimizer.""" - raise NotImplementedError - def _create_slots( self, table: "TableConfig", @@ -199,6 +132,90 @@ def __hash__(self) -> int: return hash(tuple(self.__dict__.items())) +class _Optimizer(_WithSlotVariables): + """Base class for all optimizers, with common parameters.""" + + def __init__( + self, + learning_rate: Union[float, Callable[[], float]], + use_gradient_accumulation: bool, + clip_weight_min: Optional[float], + clip_weight_max: Optional[float], + weight_decay_factor: Optional[float], + multiply_weight_decay_factor_by_learning_rate: bool, + clipvalue: Optional[ClipValueType] = None, + slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None, + low_dimensional_packing_status: bool = False, + ): + super().__init__(slot_variable_creation_fn=slot_variable_creation_fn) + self.learning_rate = learning_rate + self.use_gradient_accumulation = use_gradient_accumulation + self.clip_weight_min = clip_weight_min + self.clip_weight_max = clip_weight_max + if not use_gradient_accumulation and clipvalue is not None: + raise ValueError( + "When `use_gradient_accumulation` is False, gradient clipping " + "cannot be used and `clipvalue` should be left as None. " + f"Received value {clipvalue} for argument `clipvalue`." + ) + if clipvalue is None: + clipvalue = (None, None) + elif not isinstance(clipvalue, tuple): + clipvalue = (-1.0 * clipvalue, clipvalue) + self.clip_gradient_min, self.clip_gradient_max = clipvalue + + self.weight_decay_factor = weight_decay_factor + self.multiply_weight_decay_factor_by_learning_rate = ( + multiply_weight_decay_factor_by_learning_rate + ) + + self.low_dimensional_packing_status = low_dimensional_packing_status + + def _set_optimization_parameters( + self, parameters: optimization_parameters_pb2.OptimizationParameters + ): + """Sets the optimizer fields in the OptimizationParameters.""" + if self.use_gradient_accumulation: + parameters.gradient_accumulation_status = ( + optimization_parameters_pb2.GradientAccumulationStatus.ENABLED + ) + else: + parameters.gradient_accumulation_status = ( + optimization_parameters_pb2.GradientAccumulationStatus.DISABLED + ) + + if self.clip_weight_min is not None: + parameters.clipping_limits.lower.value = self.clip_weight_min + + if self.clip_weight_max is not None: + parameters.clipping_limits.upper.value = self.clip_weight_max + + if self.clip_gradient_min is not None: + parameters.gradient_clipping_limits.lower.value = self.clip_gradient_min + + if self.clip_gradient_max is not None: + parameters.gradient_clipping_limits.upper.value = self.clip_gradient_max + + if self.weight_decay_factor: + parameters.weight_decay_factor = self.weight_decay_factor + if self.multiply_weight_decay_factor_by_learning_rate: + parameters.multiply_weight_decay_factor_by_learning_rate = True + + parameters.low_dimensional_packing_status = ( + self.low_dimensional_packing_status + ) + + @abc.abstractmethod + def _load(self) -> Callable[..., ops.Operation]: + """Returns the load function for the optimizer.""" + raise NotImplementedError + + @abc.abstractmethod + def _retrieve(self) -> Callable[..., core.Tensor]: + """Returns the retrieve function for the optimizer.""" + raise NotImplementedError + + @tf_export("tpu.experimental.embedding.CustomOptimizer") class CustomOptimizer(_Optimizer): """Optimization parameters for custom optimizer for TPU embeddings. 
@@ -1094,6 +1111,123 @@ def _retrieve(self) -> Callable[..., core.Tensor]: return tpu_ops.retrieve_tpu_embedding_adam_parameters +@tf_export("tpu.experimental.embedding.CustomCombiner") +class CustomCombiner(_WithSlotVariables): + """Custom combiner for TPU embeddings. + + This class gives the user the ability to define a custom combiner for running + embedding lookups on TPU with SparseCores. + + The custom computation should be a function which takes the following + arguments: + (1) valency: an integer scalar that indicates the actual number of valent + IDs in the current sample. + (2) vectors: a 2D tensor of shape [max_valency, embedding_dim] that + represents the embedding lookup results to be combined. The + vectors are guaranteed to be in the same order as the embedding + IDs appear in the input sample. + (3) weights: this argument is only present if `num_weights` of this class + is greater than 0. It is a 1D tensor of shape [num_weights] + that will be back-propagated to during the backward pass. + Currently only SGD optimizer is supported on these weights, + with the learning rate being the + `combiner_weights_learning_rate` argument of the constructor + of this class. + + The custom computation should return a 1D tensor of shape [embedding_dim] that + represents the combined embedding vector. An example combiner computation is + as follows (it simply sums over the valent embedding vectors and is + semantically equivalent to the `sum` combiner): + + ```python + @tf.function + def sum_combiner(valency, vectors): + max_valency = vectors.shape[0] + valid_mask = tf.range(max_valency) < valency + vectors_masked = tf.where( + tf.expand_dims(valid_mask, axis=-1), + vectors, + tf.zeros_like(vectors), + ) + return tf.reduce_sum(vectors_masked, axis=0) + ``` + + The custom computation is defined as a per-sample combiner function and it + will be vectorized to the given batch size. This means certain constructs + (such as control flow ops) may not be supported in the custom computation. + + NOTE: This combiner can only be used with the `TPUEmbeddingV2` class. + + This class can be used in a `tf.tpu.experimental.embedding.TableConfig` as the + combiner parameter to set a table specific combiner. + + ```python + custom_combiner = tpu_embedding_v2_utils.CustomCombiner( + sum_combiner, + max_valency=16, + ) + + table_one = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=..., + combiner=custom_combiner) + table_two = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + + feature_config = ( + tf.tpu.experimental.embedding.FeatureConfig( + table=table_one), + tf.tpu.experimental.embedding.FeatureConfig( + table=table_two)) + + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + feature_config=feature_config, + batch_size=...) + ``` + In this example, the combiner of the first table will be the `sum_combiner`. + The second table will use the default `mean` combiner. 
+ """ + + def __init__( + self, + combiner_computation: core.PolymorphicFunction, + max_valency: int, + num_weights: int = 0, + initializer: Optional[init_ops_v2.Initializer] = None, + combiner_weights_learning_rate: Union[float, Callable[[], float]] = 0.01, + ) -> Any: + super().__init__() + self.combiner = "custom_combiner_" + str(hash(combiner_computation)) + self.combiner_computation = combiner_computation + self.max_valency = max_valency + self.num_weights = num_weights + self.combiner_weights_learning_rate = combiner_weights_learning_rate + + if num_weights > 0 and initializer is None: + raise ValueError( + "When `num_weights` is greater than 0, `initializer` must be set." + ) + + self._slot_names_attr = tuple(["custom_combiner_variables"]) + self._slot_initializers_attr = tuple( + [initializer] if initializer is not None else () + ) + + def _slot_names(self) -> List[Text]: + if self.num_weights > 0: + return list(self._slot_names_attr) + return [] + + def _slot_initializers(self) -> List[init_ops_v2.Initializer]: + if self.num_weights > 0: + return list(self._slot_initializers_attr) + return [] + + def __str__(self) -> str: + return self.combiner + + @tf_export("tpu.experimental.embedding.QuantizationConfig") class QuantizationConfig: """Settings for simulated quantization of the tpu embedding table. @@ -1201,7 +1335,7 @@ def __init__( dim: int, initializer: Optional[Callable[[Any], None]] = None, optimizer: Optional[_Optimizer] = None, - combiner: Text = "mean", + combiner: Union[Text, CustomCombiner] = "mean", name: Optional[Text] = None, quantization_config: QuantizationConfig = None, # TODO(b/295372790): Change the type to SparseCoreTableLayout after it is @@ -1223,10 +1357,11 @@ def __init__( `tf.tpu.experimental.embedding.Adagrad` or `tf.tpu.experimental.embedding.Adam`. If set will override the global optimizer passed to `tf.tpu.experimental.embedding.TPUEmbedding`. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently 'mean', 'sqrtn', 'sum' are supported, with - 'mean' the default. 'sqrtn' often achieves good accuracy, in particular - with bag-of-words columns. For more information, see + combiner: A string or instance of a combiner class specifying how to + reduce if there are multiple entries in a single row. Currently 'mean', + 'sqrtn', 'sum', and custom combiners are supported, with 'mean' the + default. 'sqrtn' often achieves good accuracy, in particular with + bag-of-words columns. For more information, see `tf.nn.embedding_lookup_sparse`. name: An optional string used to name the table. Must be defined if running on SparseCore. @@ -1263,11 +1398,19 @@ def __init__( if initializer is None: initializer = init_ops_v2.TruncatedNormal(mean=0.0, stddev=1/math.sqrt(dim)) - accepted_combiners = ("mean", "sum", "sqrtn") - if combiner not in accepted_combiners: + + accepted_str_combiners = ("mean", "sum", "sqrtn") + if isinstance(combiner, str): + if combiner not in accepted_str_combiners: + raise ValueError( + f"String argument `combiner` must be in {accepted_str_combiners}. " + f"Received: {combiner}") + + elif not isinstance(combiner, CustomCombiner): raise ValueError( - f"Argument `combiner` must be one of {accepted_combiners}. " - f"Received: {combiner}") + f"Argument `combiner` should either be a str or a CustomCombiner. 
" + f"Received: {type(combiner)}" + ) if name is None: logging.warning( diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py index 9e18c8c3a0cf88..e0ba452d68802a 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py @@ -18,6 +18,10 @@ from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 from tensorflow.python.compat import v2_compat +from tensorflow.python.eager import def_function +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops_v2 +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.tpu import tpu_embedding_v2_utils @@ -74,6 +78,74 @@ def test_equal_and_hash_function(self, optimizer): self.assertNotEqual(hash(opt1), hash(opt3)) +class TPUEmbeddingCustomCombinerTest(parameterized.TestCase, test.TestCase): + + def get_sum_combiner(self): + + @def_function.function + def sum_combiner(valency, vectors): + max_valency = vectors.shape[0] + valid_mask = array_ops.range(max_valency) < valency + vectors_masked = array_ops.where( + array_ops.expand_dims(valid_mask, axis=-1), + vectors, + array_ops.zeros_like(vectors), + ) + return math_ops.reduce_sum(vectors_masked, axis=0) + + return sum_combiner + + def get_positional_weight_combiner(self): + + @def_function.function + def positional_weight_combiner(valency, vectors, weights): + max_valency = vectors.shape[0] + valid_mask = array_ops.range(max_valency) < valency + vectors_masked = array_ops.where( + array_ops.expand_dims(valid_mask, axis=-1), + vectors, + array_ops.zeros_like(vectors), + ) + return math_ops.matvec(vectors_masked, weights, transpose_a=True) + + return positional_weight_combiner + + def test_zero_num_weights_combiner_has_no_slots(self): + combiner = tpu_embedding_v2_utils.CustomCombiner( + self.get_sum_combiner(), + max_valency=16, + num_weights=0, + ) + self.assertEmpty(combiner._slot_names()) + self.assertEmpty(combiner._slot_initializers()) + + def test_name_starts_with_custom_combiner(self): + combiner = tpu_embedding_v2_utils.CustomCombiner( + self.get_sum_combiner(), + max_valency=16, + ) + self.assertStartsWith(str(combiner), 'custom_combiner') + + def test_non_zero_weights_requires_initializer(self): + with self.assertRaisesRegex(ValueError, '`initializer` must be set'): + tpu_embedding_v2_utils.CustomCombiner( + self.get_positional_weight_combiner(), + max_valency=16, + num_weights=16, + ) + + def test_non_zero_weights_has_one_slot_variable(self): + combiner = tpu_embedding_v2_utils.CustomCombiner( + self.get_positional_weight_combiner(), + max_valency=16, + num_weights=16, + initializer=init_ops_v2.zeros_initializer, + ) + self.assertLen(combiner._slot_names(), 1) + self.assertLen(combiner._slot_initializers(), 1) + self.assertStartsWith(combiner._slot_names()[0], 'custom_combiner') + + class ConfigTest(test.TestCase): def test_table_config_repr(self): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad-momentum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad-momentum.pbtxt index 829d45dc73f567..09f08b08d49219 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad-momentum.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad-momentum.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.AdagradMomentum" 
tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt index 0e3dda0dad1346..3f5cacc7100105 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.Adagrad" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adam.pbtxt index f52fec3b97fb27..febe4a76a5ca0f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adam.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.Adam" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-custom-combiner.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-custom-combiner.pbtxt new file mode 100644 index 00000000000000..976d08d8130d05 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-custom-combiner.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.tpu.experimental.embedding.CustomCombiner" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'combiner_computation\', \'max_valency\', \'num_weights\', \'initializer\', \'combiner_weights_learning_rate\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'0.01\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-custom-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-custom-optimizer.pbtxt index 08de88d791b171..699d074c904d6c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-custom-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-custom-optimizer.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.CustomOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "custom_computation" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-f-t-r-l.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-f-t-r-l.pbtxt index 4a1ec7116405a4..61946268a9498a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-f-t-r-l.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-f-t-r-l.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.FTRL" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt index 37fb55880ac4b2..cef70784fac716 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.SGD" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.pbtxt index 052d0fd257ed66..1c3789e2d0a6b9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "Adam" mtype: "" } + member { + name: "CustomCombiner" + mtype: "" + } member { name: "CustomOptimizer" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad-momentum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad-momentum.pbtxt index 829d45dc73f567..09f08b08d49219 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad-momentum.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad-momentum.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.AdagradMomentum" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt index 0e3dda0dad1346..3f5cacc7100105 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.Adagrad" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adam.pbtxt index f52fec3b97fb27..febe4a76a5ca0f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adam.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.Adam" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-custom-combiner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-custom-combiner.pbtxt new file mode 100644 index 00000000000000..976d08d8130d05 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-custom-combiner.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.tpu.experimental.embedding.CustomCombiner" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'combiner_computation\', \'max_valency\', \'num_weights\', \'initializer\', \'combiner_weights_learning_rate\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'0.01\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-custom-optimizer.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-custom-optimizer.pbtxt index 08de88d791b171..699d074c904d6c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-custom-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-custom-optimizer.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.CustomOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "custom_computation" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-f-t-r-l.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-f-t-r-l.pbtxt index 4a1ec7116405a4..61946268a9498a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-f-t-r-l.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-f-t-r-l.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.FTRL" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt index 37fb55880ac4b2..cef70784fac716 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.tpu.experimental.embedding.SGD" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.pbtxt index 052d0fd257ed66..1c3789e2d0a6b9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "Adam" mtype: "" } + member { + name: "CustomCombiner" + mtype: "" + } member { name: "CustomOptimizer" mtype: "" From 65218d8f8fffd51ba297bd0c3cdefe35b66343ce Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 14 Apr 2025 15:48:21 -0700 Subject: [PATCH 0744/1324] Disable SVE instructions also for PLATFORM_GOOGLE. PiperOrigin-RevId: 747598229 --- third_party/xla/xla/debug_options_flags.cc | 4 ---- .../xla/xla/service/cpu/tests/cpu_vectorization_test.cc | 4 ---- 2 files changed, 8 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 16f94aa0022b59..1b226e3456a56e 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -51,15 +51,11 @@ limitations under the License. namespace xla { inline std::string DefaultMaxIsa() { -#ifdef PLATFORM_GOOGLE - return ""; -#else // There are many missing SVE lowerings in LLVM. Limit features to NEON for // now. There shouldn't be significant performance impact as most AAarch64 // CPUs still use 128-bit registers. // TODO(penporn): Remove this once SVE is fully supported. return tsl::port::IsAarch64CPU() ? 
"NEON" : ""; -#endif // PLATFORM_GOOGLE } DebugOptions DefaultDebugOptionsIgnoringFlags() { diff --git a/third_party/xla/xla/service/cpu/tests/cpu_vectorization_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_vectorization_test.cc index 624dd2f53caeef..f11d99052be43e 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_vectorization_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_vectorization_test.cc @@ -241,11 +241,7 @@ TEST_F(DefaultMaxIsaTest, NeonForOssAArch64) { GTEST_SKIP() << "This test is for AArch64 CPUs."; } DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); -#ifdef PLATFORM_GOOGLE - EXPECT_EQ(debug_options.xla_cpu_max_isa(), ""); -#else EXPECT_EQ(debug_options.xla_cpu_max_isa(), "NEON"); -#endif // PLATFORM_GOOGLE } struct JitVectorizationTestSpec { From 6610710a94cb7a87dc57ba86eb59eeb42ce400a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 16:00:54 -0700 Subject: [PATCH 0745/1324] Update gloo_collectives to support shared pointers gloo has a breaking API change starting with commit 08c094bcdb4d6f6879046b2f5fca5aa510661ae7 (2025-04-01). When built with `GLOO_SHARED_STORE`, this library will now work with the newer version of the API. PiperOrigin-RevId: 747602570 --- .../cpu/collectives/gloo_collectives.cc | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.cc b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.cc index 20c59ac9c186c9..5206a9fabb071a 100644 --- a/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.cc +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.cc @@ -61,16 +61,29 @@ GlooCollectives::CreateCommunicators(const CliqueKey& clique_key, auto gloo_context = std::make_shared( rank, clique_key.num_devices()); - auto prefix_store = gloo::rendezvous::PrefixStore( + +#ifdef GLOO_SHARED_STORE + auto store_pointer = std::shared_ptr( + store_.get(), [](gloo::rendezvous::Store*) {}); +#else + auto& store_pointer = *store_; +#endif // GLOO_SHARED_STORE + + auto prefix_store = std::make_shared( absl::StrCat("gloo/", absl::StrJoin(clique_key.devices(), ",", [](std::string* out, GlobalDeviceId id) { absl::StrAppend(out, id.value()); })), - *store_); + store_pointer); try { - gloo_context->connectFullMesh(prefix_store, device_); +#ifdef GLOO_SHARED_STORE + auto prefix_store_pointer = prefix_store; +#else + auto& prefix_store_pointer = *prefix_store; +#endif // GLOO_SHARED_STORE + gloo_context->connectFullMesh(prefix_store_pointer, device_); } catch (std::exception& e) { return absl::UnknownError( absl::StrCat("Gloo context initialization failed: ", e.what())); From c54f15283f25217706869649d0dfaa3533389f1c Mon Sep 17 00:00:00 2001 From: Karlo Basioli Date: Mon, 14 Apr 2025 16:10:10 -0700 Subject: [PATCH 0746/1324] [XLA:CPU] VLOG thunk executor graph on construction PiperOrigin-RevId: 747606108 --- .../xla/backends/cpu/runtime/thunk_executor.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc index 7cf253ed4283eb..aa7ad60ab18029 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc @@ -126,6 +126,8 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, "#sink_nodes=%d, is_sequential=%v, small_buffers=%v", num_thunks_, 
execution_graph_.source().size(), execution_graph_.sink().size(), is_sequential_, small_buffers); + + VLOG(6) << "ThunkExecutor execution graph:\n" << ToString(); } absl::StatusOr ThunkExecutor::Create( @@ -561,12 +563,13 @@ std::string ThunkExecutor::ToString() const { const Thunk& thunk = *thunk_sequence_[i]; bool is_source = absl::c_find(source, i) != source.end(); bool is_sink = absl::c_find(sink, i) != sink.end(); - absl::StrAppendFormat(&str, - "\n thunk #%05d: op_name=%s, dependencies=[%s], " - "source=%v, sink=%v, priority=%d", - i, thunk.info().op_name, - absl::StrJoin(in_edges[i], ", "), is_source, is_sink, - execution_graph_.priority(i)); + absl::StrAppendFormat( + &str, + "\n thunk #%05d: op_name=%s, kind=%s, dependencies=[%s], " + "source=%v, sink=%v, priority=%d", + i, thunk.info().op_name, Thunk::KindToString(thunk.kind()), + absl::StrJoin(in_edges[i], ", "), is_source, is_sink, + execution_graph_.priority(i)); } return str; From a115df9ae7bc75d8ef3d7a3b2c1956addf391695 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 14 Apr 2025 16:12:54 -0700 Subject: [PATCH 0747/1324] Update rules_python2.patch to download archives from python-build-standalone's new location Stewardship of python-build-standalone repo was transferred to Astral: https://github.com/astral-sh/python-build-standalone/discussions/396 This also fixes JAX's Windows Python 3.13 build: https://btx.cloud.google.com/invocations/89927854-5ee3-4cee-bf6a-f9f607d95ef8/log PiperOrigin-RevId: 747607057 --- third_party/py/rules_python2.patch | 25 ++++++++++++++++--- .../xla/third_party/py/rules_python2.patch | 25 ++++++++++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/third_party/py/rules_python2.patch b/third_party/py/rules_python2.patch index 304d0606ca08f7..93d79b206c8c80 100644 --- a/third_party/py/rules_python2.patch +++ b/third_party/py/rules_python2.patch @@ -1,7 +1,16 @@ diff --git a/python/versions.bzl b/python/versions.bzl -index 91e59f9b..223c90d4 100644 +index 91e59f9b..d32195f4 100644 --- a/python/versions.bzl +++ b/python/versions.bzl +@@ -21,7 +21,7 @@ LINUX_NAME = "linux" + WINDOWS_NAME = "windows" + FREETHREADED = "freethreaded" + +-DEFAULT_RELEASE_BASE_URL = "https://github.com/indygreg/python-build-standalone/releases/download" ++DEFAULT_RELEASE_BASE_URL = "https://github.com/astral-sh/python-build-standalone/releases/download" + + # When updating the versions and releases, run the following command to get + # the hashes: @@ -575,25 +575,42 @@ TOOL_VERSIONS = { }, "strip_prefix": "python", @@ -62,7 +71,7 @@ index 91e59f9b..223c90d4 100644 - "strip_prefix": "python", }, } - + @@ -604,7 +621,7 @@ MINOR_MAPPING = { "3.10": "3.10.15", "3.11": "3.11.10", @@ -70,5 +79,15 @@ index 91e59f9b..223c90d4 100644 - "3.13": "3.13.0", + "3.13": "3.13.2", } - + def _generate_platforms(): +@@ -793,9 +810,6 @@ def get_release_info(platform, python_version, base_url = DEFAULT_RELEASE_BASE_U + else: + build = "install_only" + +- if WINDOWS_NAME in platform: +- build = "shared-" + build +- + release_filename = u.format( + platform = p, + python_version = python_version, diff --git a/third_party/xla/third_party/py/rules_python2.patch b/third_party/xla/third_party/py/rules_python2.patch index 304d0606ca08f7..93d79b206c8c80 100644 --- a/third_party/xla/third_party/py/rules_python2.patch +++ b/third_party/xla/third_party/py/rules_python2.patch @@ -1,7 +1,16 @@ diff --git a/python/versions.bzl b/python/versions.bzl -index 91e59f9b..223c90d4 100644 +index 91e59f9b..d32195f4 100644 --- 
a/python/versions.bzl +++ b/python/versions.bzl +@@ -21,7 +21,7 @@ LINUX_NAME = "linux" + WINDOWS_NAME = "windows" + FREETHREADED = "freethreaded" + +-DEFAULT_RELEASE_BASE_URL = "https://github.com/indygreg/python-build-standalone/releases/download" ++DEFAULT_RELEASE_BASE_URL = "https://github.com/astral-sh/python-build-standalone/releases/download" + + # When updating the versions and releases, run the following command to get + # the hashes: @@ -575,25 +575,42 @@ TOOL_VERSIONS = { }, "strip_prefix": "python", @@ -62,7 +71,7 @@ index 91e59f9b..223c90d4 100644 - "strip_prefix": "python", }, } - + @@ -604,7 +621,7 @@ MINOR_MAPPING = { "3.10": "3.10.15", "3.11": "3.11.10", @@ -70,5 +79,15 @@ index 91e59f9b..223c90d4 100644 - "3.13": "3.13.0", + "3.13": "3.13.2", } - + def _generate_platforms(): +@@ -793,9 +810,6 @@ def get_release_info(platform, python_version, base_url = DEFAULT_RELEASE_BASE_U + else: + build = "install_only" + +- if WINDOWS_NAME in platform: +- build = "shared-" + build +- + release_filename = u.format( + platform = p, + python_version = python_version, From 889d891ad7a711d3c3bc0e99365d615553837f2a Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 16:17:32 -0700 Subject: [PATCH 0748/1324] Run build_cleaner on BUILD file(s) located in /xla/service/cpu/tests. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely: * conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747608758 --- third_party/xla/xla/service/cpu/tests/BUILD | 7 ------- 1 file changed, 7 deletions(-) diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD index a3b1c7aac2a8f1..c7bfb02453a949 100644 --- a/third_party/xla/xla/service/cpu/tests/BUILD +++ b/third_party/xla/xla/service/cpu/tests/BUILD @@ -36,7 +36,6 @@ cc_library( "//xla/service:cpu_plugin", "//xla/tests:llvm_irgen_test_base", "//xla/tsl/platform:test_main", - "@com_google_googletest//:gtest", ], ) @@ -97,7 +96,6 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", - "@com_google_googletest//:gtest", ], ) @@ -110,7 +108,6 @@ xla_cc_test( "//xla/service/llvm_ir:llvm_util", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", - "@com_google_googletest//:gtest", "@llvm-project//llvm:ir_headers", ], ) @@ -149,7 +146,6 @@ xla_cc_test( "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/status", - "@com_google_googletest//:gtest", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", ], @@ -216,7 +212,6 @@ xla_cc_test( ":cpu_codegen_test_main", "//xla/service/cpu:cpu_compiler", "//xla/tsl/platform:test", - "@com_google_googletest//:gtest", ], ) @@ -241,7 +236,6 @@ xla_cc_test( "//xla/tsl/platform:env", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", - "@com_google_googletest//:gtest", ], ) @@ -295,7 +289,6 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", ], ) From 91a3e3ce50631d2d53326762dd5da14870a3182b Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 16:18:23 -0700 Subject: [PATCH 0749/1324] Run build_cleaner on BUILD file(s) located in /xla/gpu/tests. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely: * conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747609073 --- third_party/xla/xla/service/gpu/tests/BUILD | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index ff6bd494ba3d95..06c4e4b03a9916 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -642,7 +642,8 @@ lit_test_suite( # "//xla/backends/gpu/codegen/emitters/transforms:passes", # "//xla/backends/gpu/codegen/triton/ir:triton_xla", # "//xla/backends/gpu/codegen/triton/transforms:passes", -# "//xla/codegen/emitters/ir:xla", +# # Needed for xla_ops.h +# "//xla/codegen/emitters/ir:xla", # buildcleaner: keep # "//xla/codegen/emitters/transforms:passes", # "@triton//:AllPassesAndDialects", # "@triton//third_party/amd:TestAMDRangeAnalysis", From 7e13ff380c2882ae8435ae7437a880a4e27d541d Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 16:22:52 -0700 Subject: [PATCH 0750/1324] Run build_cleaner on BUILD file(s) located in /xla/gpu/model. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely: * conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747610707 --- third_party/xla/xla/service/gpu/model/BUILD | 7 ------- 1 file changed, 7 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 72c5f4a26ef42c..d1b6c583e76627 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -105,8 +105,6 @@ xla_cc_test( deps = [ ":sol_gpu_cost_model", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], @@ -198,9 +196,7 @@ xla_cc_test( deps = [ ":gpu_cost_model_stats_collection", ":gpu_hlo_cost_analysis", - "//xla:shape_util", "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:verified_hlo_module", "//xla/service:hlo_cost_analysis", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", @@ -294,7 +290,6 @@ xla_cc_test( deps = [ ":gpu_hlo_cost_analysis", ":gpu_performance_model_base", - "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test_helpers", "//xla/service/gpu:backend_configs_cc", @@ -618,7 +613,6 @@ xla_cc_test( ":symbolic_tiled_hlo_instruction", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_traversal", "//xla/tests:hlo_test_base", "@com_google_googletest//:gtest_main", @@ -831,7 +825,6 @@ xla_cc_test( srcs = ["coalescing_analysis_test.cc"], deps = [ ":coalescing_analysis", - ":symbolic_tile", ":symbolic_tile_analysis", ":tiled_hlo_instruction_or_computation", "//xla:shape_util", From e9bbab45f3423394360330c8be6b09418105cad7 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 16:26:44 -0700 Subject: [PATCH 0751/1324] Run build_cleaner on BUILD file(s) located in /xla/tests. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely: * conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747611914 --- third_party/xla/xla/tests/BUILD | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index bf2069617ad78a..4b466a0aadd576 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1385,7 +1385,10 @@ xla_test( "optonly", ], deps = [ + ":client_library_test_base", ":client_library_test_runner_mixin", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", @@ -1405,9 +1408,6 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/tests:client_library_test_base", - "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1432,8 +1432,12 @@ xla_test( "test_xla_cpu_no_thunks", ], deps = [ + ":client_library_test_base", ":client_library_test_runner_mixin", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", "//xla:array4d", @@ -1451,10 +1455,6 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/tests:client_library_test_base", - "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", - "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1517,7 +1517,10 @@ xla_test( "optonly", ], deps = [ + ":client_library_test_base", ":client_library_test_runner_mixin", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", @@ -1537,9 +1540,6 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/tests:client_library_test_base", - "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1591,8 +1591,12 @@ xla_test( shard_count = 25, tags = ["cuda-only"], deps = [ + ":client_library_test_base", ":client_library_test_runner_mixin", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", "//xla:array4d", @@ -1610,10 +1614,6 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/tests:client_library_test_base", - "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", - "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", From 040190f21785f3e37d6beac07ecc4272c0ea4441 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 16:28:48 -0700 Subject: [PATCH 0752/1324] Run build_cleaner on BUILD file(s) located in /xla/pjrt/profiling. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. 
Here are the issues that came up that I didn't attempt to fix entirely: * conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747612541 --- third_party/xla/xla/pjrt/profiling/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/profiling/BUILD b/third_party/xla/xla/pjrt/profiling/BUILD index d271e9c8232af4..70432ceec15638 100644 --- a/third_party/xla/xla/pjrt/profiling/BUILD +++ b/third_party/xla/xla/pjrt/profiling/BUILD @@ -56,7 +56,7 @@ cc_library( ":no_op_device_time_measurement", # copybara:comment_end # copybara:uncomment_begin(google-only) - # "//learning/brain/google/runtime:device_runtime_profiling", + # "//learning/brain/google/runtime:device_runtime_profiling", # buildcleaner: keep # copybara:uncomment_end "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/synchronization", From 6083be93cf8f0f1c29ce8591aa7950c17cad1653 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 16:29:08 -0700 Subject: [PATCH 0753/1324] Run build_cleaner on BUILD file(s) located in /xla/stream_executor/cuda. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely: * conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. 
Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747612643 --- .../xla/xla/stream_executor/cuda/BUILD | 27 +------------------ 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 8e0d321ebef7c3..8567306a767a9c 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -160,7 +160,6 @@ xla_test( "//xla/stream_executor:platform_manager", "@com_google_absl//absl/debugging:leak_check", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/log:globals", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -256,7 +255,6 @@ cc_library( ], deps = [ "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/gpu:context", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], @@ -288,16 +286,12 @@ xla_test( "cuda-only", ], deps = [ - ":cuda_context", ":cuda_diagnostics", ":cuda_status", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", "@com_google_googletest//:gtest_main", "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", ], ) @@ -316,7 +310,6 @@ cc_library( deps = [ "//xla:types", "//xla/stream_executor:blas", - "//xla/stream_executor:device_memory", "//xla/stream_executor:scratch_allocator", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", @@ -631,7 +624,6 @@ cc_library( "gpu", ], deps = [ - "//xla/stream_executor/gpu:gpu_helpers_header", "@com_google_absl//absl/log:check", "@local_config_cuda//cuda:cuda_headers", ], @@ -872,7 +864,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:nvjitlink", # buildcleaner: keep @@ -1320,12 +1311,10 @@ xla_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream", "//xla/stream_executor:typed_kernel_factory", - "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:gpu_test_kernels_cuda", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -1463,14 +1452,7 @@ cc_binary( name = "dummy_cuda_binary", testonly = True, srcs = ["dummy_cuda_binary.cc"], - deps = [ - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], + deps = ["@com_google_absl//absl/strings"], ) stage_in_bin_subdirectory( @@ -1547,7 +1529,6 @@ xla_cc_test( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", ], ) @@ -1588,7 +1569,6 @@ cc_library( ":subprocess_compilation", "//xla/stream_executor:device_description", 
"//xla/stream_executor/gpu:gpu_asm_opts", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", @@ -1612,7 +1592,6 @@ cc_library( ":compilation_options", ":compilation_provider", ":cuda_platform", # buildcleaner: keep - ":cuda_platform_id", ":driver_compilation_provider", ":nvjitlink_compilation_provider", ":nvjitlink_support", @@ -1621,9 +1600,6 @@ cc_library( ":subprocess_compilation", ":subprocess_compilation_provider", "//xla/stream_executor:device_description", - "//xla/stream_executor:platform", - "//xla/stream_executor:platform_manager", - "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1830,7 +1806,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", From 0f0262600ea2655d661edaec69195acdfd926bf7 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 16:29:14 -0700 Subject: [PATCH 0754/1324] Run build_cleaner on BUILD file(s) located in /xla/stream_executor/gpu. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely: * conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747612671 --- third_party/xla/xla/stream_executor/gpu/BUILD | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 59a0f4c680e5f8..c2699334093c9a 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -166,7 +166,6 @@ gpu_only_cc_library( "gpu_command_buffer.h", ], deps = [ - ":gpu_executor_header", ":scoped_update_mode", "//xla/stream_executor:bit_pattern", "//xla/stream_executor:command_buffer", @@ -175,15 +174,11 @@ gpu_only_cc_library( "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", - "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/cuda:cuda_platform_id", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", - "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -266,13 +261,7 @@ gpu_only_cc_library( hdrs = ["gpu_stream.h"], deps = [ ":gpu_types_header", - "//xla/stream_executor:kernel", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor:platform", "//xla/stream_executor:stream", - "//xla/stream_executor:stream_common", - "//xla/stream_executor:stream_executor_h", - "@com_google_absl//absl/log:check", ], ) @@ -384,7 +373,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:statusor", ], ) From 527e4d0d3d57d7346253e07af18d2d39f65e8592 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 16:29:22 -0700 Subject: [PATCH 0755/1324] Run build_cleaner on BUILD file(s) located in /xla/stream_executor/host. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely: * conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. 
PiperOrigin-RevId: 747612716
---
 third_party/xla/xla/stream_executor/host/BUILD | 1 -
 1 file changed, 1 deletion(-)

diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD
index 970581e9ac2268..053eb937e3c9f2 100644
--- a/third_party/xla/xla/stream_executor/host/BUILD
+++ b/third_party/xla/xla/stream_executor/host/BUILD
@@ -82,7 +82,6 @@ cc_library(
         "//xla/stream_executor:stream",
         "//xla/stream_executor:stream_common",
         "//xla/stream_executor:stream_executor_h",
-        "//xla/tsl/platform:env",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/log:check",

From 1059f7f19fab3e72abdbf0825cedb7d75af055a9 Mon Sep 17 00:00:00 2001
From: Ezekiel Calubaquib
Date: Mon, 14 Apr 2025 16:29:30 -0700
Subject: [PATCH 0756/1324] Fork RegisterExtraTfOpDefs and ImportSavedModel so
 that TF doesn't depend on TFL

PiperOrigin-RevId: 747612755
---
 tensorflow/compiler/mlir/utils/BUILD          | 21 +++++
 .../mlir/utils/saved_model_converter_utils.cc | 94 +++++++++++++++++++
 .../mlir/utils/saved_model_converter_utils.h  | 46 +++++++++
 tensorflow/core/framework/BUILD               |  2 +
 4 files changed, 163 insertions(+)
 create mode 100644 tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc
 create mode 100644 tensorflow/compiler/mlir/utils/saved_model_converter_utils.h

diff --git a/tensorflow/compiler/mlir/utils/BUILD b/tensorflow/compiler/mlir/utils/BUILD
index 13cdb3e51d33a9..ae6a01df20e1b2 100644
--- a/tensorflow/compiler/mlir/utils/BUILD
+++ b/tensorflow/compiler/mlir/utils/BUILD
@@ -38,6 +38,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "saved_model_converter_utils",
+    srcs = ["saved_model_converter_utils.cc"],
+    hdrs = ["saved_model_converter_utils.h"],
+    visibility = [
+        "//tensorflow/cc/experimental/tfa:__subpackages__",
+    ],
+    deps = [
+        "//tensorflow/cc/saved_model:loader",
+        "//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
+        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+        "//tensorflow/compiler/mlir/tf2xla/api/v2:mlir_roundtrip_flags",
+        "//tensorflow/core/framework:op",
+        "//tensorflow/core/framework:op_def_builder",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
 cc_library(
     name = "validators",
     srcs = [
diff --git a/tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc
new file mode 100644
index 00000000000000..d818acf6ee528d
--- /dev/null
+++ b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc
@@ -0,0 +1,94 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/utils/saved_model_converter_utils.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/log/log.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/compiler/mlir/tf2xla/api/v2/mlir_roundtrip_flags.h"
+
+
+namespace tensorflow {
+namespace utils {
+
+// Util that registers 'extra_tf_opdefs' to the TF global registry.
+// Return OK on success, failure if registering failed.
+absl::Status RegisterExtraTfOpDefs(
+    absl::Span<const std::string> extra_tf_opdefs) {
+  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
+    OpDef opdef;
+    // NOLINTNEXTLINE: Use tsl::protobuf to be compatible with OSS.
+    if (!tsl::protobuf::TextFormat::ParseFromString(tf_opdefs_string, &opdef)) {
+      LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
+      return absl::InvalidArgumentError("fail to parse extra OpDef");
+    }
+    // Register extra opdefs.
+    // TODO: b/133770952 - Support shape functions.
+    OpRegistry::Global()->Register(
+        [opdef](OpRegistrationData* op_reg_data) -> absl::Status {
+          *op_reg_data = OpRegistrationData(opdef);
+          return absl::OkStatus();
+        });
+  }
+  return absl::OkStatus();
+}
+
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
+    const std::string& input_filename, const int saved_model_version,
+    const std::unordered_set<std::string>& tags,
+    absl::Span<const std::string> extra_tf_opdefs,
+    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
+    bool enable_variable_lifting, mlir::MLIRContext* context,
+    std::unique_ptr<SavedModelBundle>* saved_model_bundle) {
+  // Register extra TF ops passed as OpDef.
+  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
+  if (!extra_opdefs_status.ok()) return extra_opdefs_status;
+
+  if (saved_model_version == 2) {
+    auto module_or = SavedModelObjectGraphToMlirImport(
+        input_filename, tags, exported_names, context,
+        /*unconditionally_use_set_output_shapes=*/true);
+    if (!module_or.status().ok()) return module_or.status();
+    return std::move(module_or).value();
+  } else if (saved_model_version == 1) {
+    MLIRImportOptions options;
+    options.upgrade_legacy = specs.upgrade_legacy;
+    options.unconditionally_use_set_output_shapes = true;
+    options.lift_variables = enable_variable_lifting;
+    auto module_or = SavedModelSignatureDefsToMlirImport(
+        input_filename, tags, exported_names, context, options,
+        saved_model_bundle);
+
+    if (!module_or.status().ok()) return module_or.status();
+    return std::move(module_or).value();
+  } else {
+    return absl::InvalidArgumentError("Should be either saved model v1 or v2.");
+  }
+}
+
+}  // namespace utils
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/utils/saved_model_converter_utils.h b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.h
new file mode 100644
index 00000000000000..fc4440fb918a37
--- /dev/null
+++ b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.h
@@ -0,0 +1,46 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_SAVED_MODEL_CONVERTER_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_UTILS_SAVED_MODEL_CONVERTER_UTILS_H_
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/OwningOpRef.h" // from @llvm-project
+#include "tensorflow/cc/saved_model/loader.h"
+#include "tensorflow/compiler/mlir/tf2xla/api/v2/mlir_roundtrip_flags.h"
+
+namespace tensorflow {
+namespace utils {
+
+// 'saved_model_bundle' will be initialized if V1 model was loaded.
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
+    const std::string& input_filename, int saved_model_version,
+    const std::unordered_set<std::string>& tags,
+    absl::Span<const std::string> extra_tf_opdefs,
+    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
+    bool enable_variable_lifting, mlir::MLIRContext* context,
+    std::unique_ptr<SavedModelBundle>* saved_model_bundle);
+
+} // namespace utils
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_MLIR_UTILS_SAVED_MODEL_CONVERTER_UTILS_H_
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index cc9215059ee06f..1bff2fba6601b2 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -46,6 +46,8 @@ default_visibility = [ #internal library,
     # TODO(matthurd): to be removed when summary.proto.h deps moves to TSL
     "@org_xprof//xprof:__subpackages__",
+    "//tensorflow/cc/experimental/tfa:__subpackages__",
+    "//tensorflow/compiler/mlir/utils:__subpackages__",
 ]
 
 package(

From abe8ca2dc69f0ccfaa12271b9f04df945826e38b Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 16:29:36 -0700
Subject: [PATCH 0757/1324] Run build_cleaner on BUILD file(s) located in
 /xla/stream_executor/tpu.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I'd run this for all of xla to fix the low-hanging fruit.

It resolves several unnecessary & missing dependencies and simplifies target
paths, but not all of them. Here are the issues that came up that I didn't
attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need a choice between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median            Δ    1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median             Δ                 1-pval
cpu:      3599.015s ±131.4s  +8.3s, +0.2%      0.03 (not significant)
memory:   4533MB ±2.3MB      +0.0MB, +0.0%     0.00 (not significant)
system:   582.305s ±9.1s     -11.9s, -2.0%     0.25 (not significant)
wall:     808.958s ±95.5s    -98.6s, -10.9%    0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time.
Since I've excluded some execution tests under `stream_executor/` and
`service/`, the estimated savings may be greater. Overall, it's a small
improvement but should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of
subdirectories to simplify merging.

PiperOrigin-RevId: 747612789
---
 third_party/xla/xla/stream_executor/tpu/BUILD | 2 --
 1 file changed, 2 deletions(-)

diff --git a/third_party/xla/xla/stream_executor/tpu/BUILD b/third_party/xla/xla/stream_executor/tpu/BUILD
index 09f7da009fcbdf..4c7323da2a7286 100644
--- a/third_party/xla/xla/stream_executor/tpu/BUILD
+++ b/third_party/xla/xla/stream_executor/tpu/BUILD
@@ -133,7 +133,6 @@ cc_library(
         ":tpu_executor_api",
         ":tpu_executor_c_api_hdrs",
         "@com_google_absl//absl/status",
-        "@local_tsl//tsl/platform:status",
     ],
 )
 
@@ -145,7 +144,6 @@ cc_library(
         "//xla/tsl/c:tsl_status",
         "//xla/tsl/c:tsl_status_helper",
         "@com_google_absl//absl/status",
-        "@local_tsl//tsl/platform:status",
     ],
 )

From 37031fa6e01afafbb0663a590f35f1dde4ba1cde Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 16:30:01 -0700
Subject: [PATCH 0758/1324] Run build_cleaner on BUILD file(s) located in
 /xla/service/spmd/shardy/.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I'd run this for all of xla to fix the low-hanging fruit.

It resolves several unnecessary & missing dependencies and simplifies target
paths, but not all of them. Here are the issues that came up that I didn't
attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need a choice between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median            Δ    1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median             Δ                 1-pval
cpu:      3599.015s ±131.4s  +8.3s, +0.2%      0.03 (not significant)
memory:   4533MB ±2.3MB      +0.0MB, +0.0%     0.00 (not significant)
system:   582.305s ±9.1s     -11.9s, -2.0%     0.25 (not significant)
wall:     808.958s ±95.5s    -98.6s, -10.9%    0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time.
Since I've excluded some execution tests under `stream_executor/` and
`service/`, the estimated savings may be greater. Overall, it's a small
improvement but should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of
subdirectories to simplify merging.
PiperOrigin-RevId: 747612916
---
 .../xla/xla/service/spmd/shardy/sdy_round_trip/BUILD       | 6 ------
 .../xla/xla/service/spmd/shardy/stablehlo_round_trip/BUILD | 2 --
 2 files changed, 8 deletions(-)

diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD
index bf8a2a3bcc8b39..f433c1261d9ecb 100644
--- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD
+++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD
@@ -100,7 +100,6 @@ cc_library(
         "//xla/service/spmd/shardy:constants",
         "//xla/service/spmd/shardy:utils",
         "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
@@ -139,8 +138,6 @@ cc_library(
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TransformUtils",
-        "@shardy//shardy/dialect/sdy/ir:dialect",
         "@stablehlo//:stablehlo_ops",
     ],
 )
@@ -150,15 +147,12 @@ cc_library(
     srcs = ["dedup_meshes.cc"],
     hdrs = ["dedup_meshes.h"],
     deps = [
-        "//xla/service/spmd/shardy:utils",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TransformUtils",
         "@shardy//shardy/dialect/sdy/ir:dialect",
         "@shardy//shardy/dialect/sdy/transforms/common:sharding_walker",
-        "@stablehlo//:stablehlo_ops",
     ],
 )

diff --git a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/BUILD
index c8555043e9de89..9418fa967f2ae0 100644
--- a/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/BUILD
+++ b/third_party/xla/xla/service/spmd/shardy/stablehlo_round_trip/BUILD
@@ -48,9 +48,7 @@ cc_library(
     srcs = ["export_ops.cc"],
    hdrs = ["export_ops.h"],
     deps = [
-        "//xla:sharding_op_util",
         "//xla/mlir_hlo",
-        "//xla/service/spmd/shardy:constants",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",

From 42f38e792fcdbfa3d223828601c78acf5aa0df11 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 16:32:13 -0700
Subject: [PATCH 0759/1324] Run build_cleaner on BUILD file(s) located in
 /xla/service/.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I'd run this for all of xla to fix the low-hanging fruit.

It resolves several unnecessary & missing dependencies and simplifies target
paths, but not all of them. Here are the issues that came up that I didn't
attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need a choice between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median            Δ    1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median             Δ                 1-pval
cpu:      3599.015s ±131.4s  +8.3s, +0.2%      0.03 (not significant)
memory:   4533MB ±2.3MB      +0.0MB, +0.0%     0.00 (not significant)
system:   582.305s ±9.1s     -11.9s, -2.0%     0.25 (not significant)
wall:     808.958s ±95.5s    -98.6s, -10.9%    0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time.
Since I've excluded some execution tests under `stream_executor/` and
`service/`, the estimated savings may be greater. Overall, it's a small
improvement but should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of
subdirectories to simplify merging.

PiperOrigin-RevId: 747613693
---
 third_party/xla/xla/service/graphcycles/BUILD             | 1 -
 third_party/xla/xla/service/llvm_ir/BUILD                 | 1 -
 third_party/xla/xla/service/memory_space_assignment/BUILD | 2 --
 3 files changed, 4 deletions(-)

diff --git a/third_party/xla/xla/service/graphcycles/BUILD b/third_party/xla/xla/service/graphcycles/BUILD
index 5243e8fdb8e264..38b87179c39e3c 100644
--- a/third_party/xla/xla/service/graphcycles/BUILD
+++ b/third_party/xla/xla/service/graphcycles/BUILD
@@ -18,7 +18,6 @@ cc_library(
         ":ordered_set",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:logging",
diff --git a/third_party/xla/xla/service/llvm_ir/BUILD b/third_party/xla/xla/service/llvm_ir/BUILD
index 11d5131a30056a..47133d198cfb73 100644
--- a/third_party/xla/xla/service/llvm_ir/BUILD
+++ b/third_party/xla/xla/service/llvm_ir/BUILD
@@ -227,7 +227,6 @@ cc_library(
         "//xla:shape_util",
         "//xla/hlo/ir:hlo",
         "//xla/service:buffer_assignment",
-        "//xla/service:elemental_ir_emitter",
         "//xla/service/cpu:backend_config_proto_cc",
         "//xla/service/gpu:launch_dimensions",
         "//xla/service/gpu:parallel_loop_emitter",
diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD
index 0a3650e19da5b6..43d988f00fc870 100644
--- a/third_party/xla/xla/service/memory_space_assignment/BUILD
+++ b/third_party/xla/xla/service/memory_space_assignment/BUILD
@@ -38,7 +38,6 @@ cc_library(
     deps = [
         ":algorithm",
         ":allocation",
-        ":cost_analysis",
         ":memory_space_assignment_proto_cc",
         ":options",
         ":simulator",
@@ -637,7 +636,6 @@ cc_library(
     hdrs = ["allocation_value.h"],
     deps = [
         ":allocation",
-        "//xla:shape_tree",
         "//xla:shape_util",
         "//xla/hlo/ir:hlo",
         "//xla/service:hlo_value",

From 9ca468abd63e70b99428c0999e05bb1bbbef0d9d Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 16:32:57 -0700
Subject: [PATCH 0760/1324] Run build_cleaner on BUILD file(s) located in
 /xla/stream_executor/rocm.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I'd run this for all of xla to fix the low-hanging fruit.

It resolves several unnecessary & missing dependencies and simplifies target
paths, but not all of them. Here are the issues that came up that I didn't
attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need a choice between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.
Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747613934 --- third_party/xla/xla/stream_executor/rocm/BUILD | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 84f2f9551f698a..787389ea4344b4 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -40,7 +40,6 @@ cc_library( deps = [ "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/tsl/platform:logging", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -129,7 +128,6 @@ xla_test( "//xla/stream_executor:event", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", - "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_googletest//:gtest_main", @@ -346,9 +344,7 @@ cc_library( "rocm-only", ], deps = [ - ":rocblas_if_static", ":rocm_executor", - ":rocm_platform_id", "//xla/tsl/platform:env", "//xla/tsl/util:determinism_for_kernels", "@local_config_rocm//rocm:rocm_headers", @@ -430,7 +426,6 @@ cc_library( "rocm-only", ], deps = [ - ":hipblas_lt_header", ":hipsolver_wrapper", ":rocblas_wrapper", ":rocsolver_wrapper", @@ -524,6 +519,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/platform:initialize", + "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:macros", @@ -541,6 +537,7 @@ cc_library( "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:hash", ], alwayslink = True, @@ -778,8 +775,6 @@ cc_library( ], deps = [ ":rocm_executor", - ":rocm_platform_id", - ":roctracer_if_static", "//xla/tsl/platform:env", "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform", @@ -992,7 +987,6 @@ xla_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream", "//xla/stream_executor:typed_kernel_factory", - "//xla/stream_executor/gpu:gpu_test_kernels_rocm", "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", From a11d5971a4fe50a5f8951699de3c6c64d6fff88d Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 16:36:06 -0700 Subject: [PATCH 0761/1324] Run build_cleaner on BUILD file(s) located in /xla/gpu/transforms. 
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I'd run this for all of xla to fix the low-hanging fruit.

It resolves several unnecessary & missing dependencies and simplifies target
paths, but not all of them. Here are the issues that came up that I didn't
attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need a choice between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median            Δ    1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median             Δ                 1-pval
cpu:      3599.015s ±131.4s  +8.3s, +0.2%      0.03 (not significant)
memory:   4533MB ±2.3MB      +0.0MB, +0.0%     0.00 (not significant)
system:   582.305s ±9.1s     -11.9s, -2.0%     0.25 (not significant)
wall:     808.958s ±95.5s    -98.6s, -10.9%    0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time.
Since I've excluded some execution tests under `stream_executor/` and
`service/`, the estimated savings may be greater. Overall, it's a small
improvement but should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of
subdirectories to simplify merging.

PiperOrigin-RevId: 747615030
---
 third_party/xla/xla/service/gpu/transforms/BUILD | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD
index 2a643979a48205..f9fabb01c4709b 100644
--- a/third_party/xla/xla/service/gpu/transforms/BUILD
+++ b/third_party/xla/xla/service/gpu/transforms/BUILD
@@ -443,7 +443,6 @@ cc_library(
         "//xla/backends/gpu/codegen:copy",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/pass:hlo_pass",
-        "//xla/service:call_graph",
         "//xla/service/gpu:backend_configs_cc",
         "//xla/service/gpu:ir_emission_utils",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -468,7 +467,6 @@ xla_cc_test(
         "//xla/tests:xla_internal_test_main",  # fixdeps: keep
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
@@ -680,12 +678,9 @@ xla_cc_test(
         "//xla/hlo/testlib:hlo_hardware_independent_test_base",
         "//xla/hlo/transforms:while_loop_trip_count_annotator",
         "//xla/service:collective_ops_utils",
-        "//xla/tsl/lib/core:status_test_util",
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest_main",

From 18dd3eb546e16573e00d29d96146826f4e0a4565 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 16:38:16 -0700
Subject: [PATCH 0762/1324] Run build_cleaner on BUILD file(s) located in
 /xla/stream_executor/integrations.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I'd run this for all of xla to fix the low-hanging fruit.
It resolves several unnecessary & missing dependencies and simplifies target
paths, but not all of them. Here are the issues that came up that I didn't
attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need a choice between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median            Δ    1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median             Δ                 1-pval
cpu:      3599.015s ±131.4s  +8.3s, +0.2%      0.03 (not significant)
memory:   4533MB ±2.3MB      +0.0MB, +0.0%     0.00 (not significant)
system:   582.305s ±9.1s     -11.9s, -2.0%     0.25 (not significant)
wall:     808.958s ±95.5s    -98.6s, -10.9%    0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time.
Since I've excluded some execution tests under `stream_executor/` and
`service/`, the estimated savings may be greater. Overall, it's a small
improvement but should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of
subdirectories to simplify merging.

PiperOrigin-RevId: 747615880
---
 third_party/xla/xla/stream_executor/integrations/BUILD | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/third_party/xla/xla/stream_executor/integrations/BUILD b/third_party/xla/xla/stream_executor/integrations/BUILD
index d4fb1406426cdf..06f1f40a93ed09 100644
--- a/third_party/xla/xla/stream_executor/integrations/BUILD
+++ b/third_party/xla/xla/stream_executor/integrations/BUILD
@@ -71,14 +71,9 @@ cc_library(
         "device_mem_allocator.h",
     ],
     deps = [
-        "//xla/stream_executor:memory_allocation",
         "//xla/stream_executor:stream_executor_h",
         "//xla/tsl/framework:allocator",
         "//xla/tsl/framework:device_id",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/synchronization",
-        "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/profiler/lib:traceme",
     ],
 )

From 1782f3e2dcc78b4144da9729bad027b0c1433a0a Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Mon, 14 Apr 2025 16:48:06 -0700
Subject: [PATCH 0763/1324] Accurately print tuple value sizes in
 HloLiveRange::ToString.

HloLiveRange::ToString prints out the live HloValues and their sizes at the
point of peak memory usage. Before, for buffers whose defining value was
within a tuple, it would incorrectly print the size of the tuple itself, not
the value inside the tuple. Now the correct size is printed. Also, the name
of the value now shows which tuple index the value is in.

This makes the last section of the "buffer-assignment.txt" file dumped by
--xla_dump_to more accurate, as the last section is simply the output of
HloLiveRange::ToString.
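For illustration, consider a hypothetical HloValue at tuple index {1} of a
constant with shape (f32[], f32[4]) (an assumed example, not from a real
module):

```
// Before: measured the defining instruction's tuple shape, i.e. two
// 8-byte pointers = 16 bytes, regardless of which element is live.
//   int64_t bytes = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8);
// After: measures the value's own shape (4 bytes for f32[], 16 for f32[4]),
// and the printed name carries the tuple index, e.g. "constant{1}".
int64_t bytes = ShapeUtil::ByteSizeOf(value->shape(), 8);
```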
PiperOrigin-RevId: 747619031 --- .../xla/xla/hlo/utils/hlo_live_range.cc | 7 +- .../xla/xla/hlo/utils/hlo_live_range_test.cc | 79 +++++++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range.cc b/third_party/xla/xla/hlo/utils/hlo_live_range.cc index 12ee1fbc45ebab..23bae083a4ed96 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range.cc +++ b/third_party/xla/xla/hlo/utils/hlo_live_range.cc @@ -321,9 +321,10 @@ std::string HloLiveRange::ToString() const { auto it = buffer_live_ranges_.find(value); if (it != buffer_live_ranges_.end()) { if (it->second.start <= peak_moment && peak_moment <= it->second.end) { - int64_t bytes = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8); - absl::StrAppendFormat(&output, " %s: %lld bytes\n", - value->instruction()->name(), bytes); + int64_t bytes = ShapeUtil::ByteSizeOf(value->shape(), 8); + absl::StrAppendFormat(&output, " %s%s: %lld bytes\n", + value->instruction()->name(), + value->index().ToString(), bytes); } } } diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc b/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc index 5dc63e4434f042..fb1df152d5c134 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc @@ -520,5 +520,84 @@ TEST_F(HloLiveRangeTest, Call) { EXPECT_EQ(inst_ranges["e"], std::make_pair(6, 7)); } +TEST_F(HloLiveRangeTest, ToString) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto paramX = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec4_, "paramX")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + module_->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module_.get()); + + schedule.set_sequence(module_->entry_computation(), {paramA, paramX, mul}); + + Analyze(schedule); + + // The peak is at LogicalTime=2, where all three buffers are live. Each array + // of four F32 elements takes 16 bytes. + std::string expected_string = R"(HloLiveRange (max 3): + InstructionSequence: + 0:paramA + 1:paramX + 2:multiply + BufferLiveRange: + paramA{}:0-3 + paramX{}:0-3 + multiply{}:2-3 + Live ranges at 2 (peak): + paramA{}: 16 bytes + paramX{}: 16 bytes + multiply{}: 16 bytes +)"; + EXPECT_EQ(hlo_live_range_->ToString(), expected_string); +} + +TEST_F(HloLiveRangeTest, ToStringTuple) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto tuple_const = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR0(1.0f), + LiteralUtil::CreateR1({2.0f, 3.0f, 4.0f, 5.0f})))); + auto get_tuple_element = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32vec4_, tuple_const, 1)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec4_, HloOpcode::kAdd, paramA, get_tuple_element)); + module_->AddEntryComputation(builder.Build()); + HloSchedule schedule(module_.get()); + + schedule.set_sequence(module_->entry_computation(), + {paramA, tuple_const, get_tuple_element, add}); + + Analyze(schedule); + + // The peak time is at LogicalTime=1, when both constants in the tuple are + // live. 
The tuple itself has two pointers of 8 bytes each, and the two + // constants in the tuple have 4 and 16 bytes respectively. + std::string expected_string = R"(HloLiveRange (max 4): + InstructionSequence: + 0:paramA + 1:constant + 2:get-tuple-element + 3:add + BufferLiveRange: + paramA{}:0-4 + constant{}:1-2 + constant{0}:1-1 + constant{1}:1-3 + add{}:3-4 + Live ranges at 1 (peak): + paramA{}: 16 bytes + constant{}: 16 bytes + constant{0}: 4 bytes + constant{1}: 16 bytes +)"; + EXPECT_EQ(hlo_live_range_->ToString(), expected_string); +} + } // namespace } // namespace xla From 533b679f0f4eee17f36ef304a76867d27202502e Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Mon, 14 Apr 2025 16:55:23 -0700 Subject: [PATCH 0764/1324] Update graphviz to use neato engine. PiperOrigin-RevId: 747621381 --- .../core/profiler/convert/hlo_proto_to_graph_view.cc | 11 +++++------ .../core/profiler/convert/hlo_proto_to_graph_view.h | 3 ++- tensorflow/core/profiler/convert/hlo_to_tools_data.cc | 3 ++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc index 87a2f123e7f917..43351cada4ca7c 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc +++ b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc @@ -446,8 +446,9 @@ absl::StatusOr WrapDotInFormat(std::string dot, } } -std::string WrapDotInHtml(std::string dot) { - return absl::StrReplaceAll(R"html( +std::string WrapDotInHtml(std::string dot, absl::string_view layout_engine) { + return absl::StrReplaceAll( + R"html( @@ -528,14 +529,12 @@ std::string WrapDotInHtml(std::string dot) { }); add_controls(svg); }; - hpccWasm.graphviz.layout(dot_data, "svg", "dot").then(render_callback); + hpccWasm.graphviz.layout(dot_data, "svg", "$LAYOUT_ENGINE").then(render_callback); )html", - { - {"$DOT", dot}, - }); + {{"$DOT", dot}, {"$LAYOUT_ENGINE", layout_engine}}); } void RegisterGraphvizURLRenderer( diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h index b3a3a7c45e1175..6f91f1c10feae4 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h +++ b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h @@ -86,7 +86,8 @@ absl::StatusOr WrapDotInFormat(std::string dot, xla::RenderedGraphFormat format); // Convert dot into visual graph in html -std::string WrapDotInHtml(std::string dot); +std::string WrapDotInHtml(std::string dot, + absl::string_view layout_engine = "dot"); // Registers a function which implements RenderedGraphFormat::kUrl. // The input to the function is dot, and the output should be a URL or an error. diff --git a/tensorflow/core/profiler/convert/hlo_to_tools_data.cc b/tensorflow/core/profiler/convert/hlo_to_tools_data.cc index ba4a13fa6c52ba..608bc6df8d0d71 100644 --- a/tensorflow/core/profiler/convert/hlo_to_tools_data.cc +++ b/tensorflow/core/profiler/convert/hlo_to_tools_data.cc @@ -83,7 +83,8 @@ absl::StatusOr ConvertHloProtoToAllocationTimeline( return result_or.status(); } - return WrapDotInHtml(std::move(result_or.value().allocation_timeline())); + return WrapDotInHtml(std::move(result_or.value().allocation_timeline()), + "neato"); } absl::StatusOr ConvertHloProtoToGraphViewer( From 7743831af5080de5057a1fddf43b1aa1eefa25a2 Mon Sep 17 00:00:00 2001 From: Majid Dadashi Date: Mon, 14 Apr 2025 16:58:03 -0700 Subject: [PATCH 0765/1324] Make the fully_connected ref kernel do requant in float. 
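For context, a minimal standalone sketch of the requantization math this
change adopts (the helper name is illustrative only; the reference kernels
in this patch inline the same logic, and the int16 variants clamp before
adding the offset):

```
#include <algorithm>
#include <cmath>
#include <cstdint>

// Scales an int32 accumulator into the quantized output domain entirely in
// floating point, rather than via a precomputed fixed-point multiplier and
// shift.
inline int32_t RequantizeInFloat(int32_t acc, float input_scale,
                                 float filter_scale, float output_scale,
                                 int32_t output_offset, int32_t qmin,
                                 int32_t qmax) {
  const double effective_scale = static_cast<double>(input_scale) *
                                 static_cast<double>(filter_scale) /
                                 static_cast<double>(output_scale);
  int32_t scaled = static_cast<int32_t>(
      std::round(static_cast<double>(acc) * effective_scale));
  scaled += output_offset;
  return std::min(std::max(scaled, qmin), qmax);
}
```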
PiperOrigin-RevId: 747622232 --- tensorflow/lite/kernels/fully_connected.cc | 46 +++++--- .../internal/reference/fully_connected.h | 109 +++++++++++++++++ .../reference/integer_ops/fully_connected.h | 110 ++++++++++++++++++ 3 files changed, 248 insertions(+), 17 deletions(-) diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index b4620c202cd674..287cf22365d4b4 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -1214,6 +1214,7 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), filter_data, GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter->params.scale, GetTensorData(output)); } else { optimized_integer_ops::FullyConnected( @@ -1242,12 +1243,14 @@ void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input, op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), filter_data, GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter->params.scale, GetTensorData(output)); } else { reference_integer_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), filter_data, GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter->params.scale, GetTensorData(output)); } } @@ -1271,12 +1274,16 @@ void FullyConnectedPerChannelInt8(const OpData* data, const TfLiteTensor* input, op_params.rhs_cacheable = IsConstantTensor(input); if (kernel_type == kReference) { + const auto* affine_quantization = + reinterpret_cast( + filter->quantization.params); + const float* filter_scales = affine_quantization->scale->data; reference_integer_ops::FullyConnectedPerChannel( - op_params, data->per_channel_output_multiplier.data(), - data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), filter_data, - GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter_scales, + GetTensorData(output)); } else { optimized_integer_ops::FullyConnectedPerChannel( op_params, data->per_channel_output_multiplier.data(), @@ -1300,21 +1307,24 @@ void FullyConnectedPerChannelInt16( op_params.output_offset = output->params.zero_point; op_params.quantized_activation_min = data->output_activation_min; op_params.quantized_activation_max = data->output_activation_max; + const auto* affine_quantization = + reinterpret_cast(filter->quantization.params); + const float* filter_scales = affine_quantization->scale->data; if (data->quantized_bias_type == kTfLiteInt32) { reference_integer_ops::FullyConnectedPerChannel( - op_params, data->per_channel_output_multiplier.data(), - data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), filter_data, - GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + input->params.scale, 
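+// Concretely, the body below scales the accumulator by
+// input_scale * filter_scale / output_scale in double precision, rounds,
+// adds output_offset, and clamps to the activation range.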
output->params.scale, filter_scales, + GetTensorData(output)); } else { reference_integer_ops::FullyConnectedPerChannel( - op_params, data->per_channel_output_multiplier.data(), - data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), filter_data, - GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter_scales, + GetTensorData(output)); } } @@ -1413,7 +1423,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), GetTensorData(filter), GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + GetTensorShape(output), input->params.scale, output->params.scale, + filter->params.scale, GetTensorData(output)); } else { optimized_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), @@ -1534,7 +1545,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), GetTensorData(filter), GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + GetTensorShape(output), input->params.scale, output->params.scale, + filter->params.scale, GetTensorData(output)); } else { optimized_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), diff --git a/tensorflow/lite/kernels/internal/reference/fully_connected.h b/tensorflow/lite/kernels/internal/reference/fully_connected.h index ba51cbcfe3e8a0..bccc6220062564 100644 --- a/tensorflow/lite/kernels/internal/reference/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/fully_connected.h @@ -16,6 +16,8 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_ #include +#include +#include #include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" @@ -62,6 +64,59 @@ inline void FullyConnected( } } +// This implementation receives the scales in float and performs requant in +// float to avoid loss of precision. 
+inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + float input_scale, float output_scale, float filter_scale, + uint8_t* output_data) { + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + // TODO(b/62193649): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int output_dim_count = output_shape.DimensionsCount(); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + int32_t acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32_t input_val = input_data[b * accum_depth + d]; + int32_t filter_val = filter_data[out_c * accum_depth + d]; + acc += (filter_val + filter_offset) * (input_val + input_offset); + } + if (bias_data) { + acc += bias_data[out_c]; + } + const double effective_output_scale = static_cast(input_scale) * + static_cast(filter_scale) / + static_cast(output_scale); + int32_t acc_scaled = static_cast( + round(static_cast(acc) * effective_output_scale)); + acc_scaled += output_offset; + acc_scaled = std::max(acc_scaled, output_activation_min); + acc_scaled = std::min(acc_scaled, output_activation_max); + output_data[out_c + output_depth * b] = static_cast(acc_scaled); + } + } +} + inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, const uint8_t* input_data, const RuntimeShape& filter_shape, @@ -164,6 +219,60 @@ inline void FullyConnected( } } +// This implementation receives the scales in float and performs requant in +// float to avoid loss of precision. 
+inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + float input_scale, float output_scale, float filter_scale, + int16_t* output_data) { + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_EQ(output_offset, 0); + // TODO(b/62193649): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int output_dim_count = output_shape.DimensionsCount(); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + // Internal accumulation. + // Initialize accumulator with the bias-value. + int32_t accum = bias_data[out_c]; + // Accumulation loop. + for (int d = 0; d < accum_depth; ++d) { + int16_t input_val = input_data[b * accum_depth + d] + input_offset; + int16_t filter_val = + filter_data[out_c * accum_depth + d] + filter_offset; + accum += filter_val * input_val; + } + const double effective_output_scale = static_cast(input_scale) * + static_cast(filter_scale) / + static_cast(output_scale); + int32_t acc_scaled = static_cast( + round(static_cast(accum) * effective_output_scale)); + // Saturate, cast to int16_t, and store to output array. + acc_scaled = std::max(acc_scaled, output_activation_min - output_offset); + acc_scaled = std::min(acc_scaled, output_activation_max - output_offset); + acc_scaled += output_offset; + output_data[out_c + output_depth * b] = acc_scaled; + } + } +} + inline void ShuffledFullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, const uint8_t* input_data, const RuntimeShape& weights_shape, diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h index c6d06077934839..f249beef8503f6 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h @@ -16,6 +16,8 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_FULLY_CONNECTED_H_ #include +#include +#include #include "tensorflow/lite/kernels/internal/common.h" @@ -74,6 +76,61 @@ void FullyConnectedPerChannel( } } +// This implementation receives the scales in float and performs requant in +// float to avoid loss of precision. 
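+// Unlike the per-tensor variants, the effective scale is computed per
+// output channel here, using filter_scales[out_c].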
+template +void FullyConnectedPerChannel( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const InputType* input_data, const RuntimeShape& filter_shape, + const WeightType* filter_data, const RuntimeShape& bias_shape, + const BiasType* bias_data, const RuntimeShape& output_shape, + float input_scale, float output_scale, const float* filter_scales, + OutputType* output_data) { + const int32_t input_offset = params.input_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int filter_dim_count = filter_shape.DimensionsCount(); + + const int output_dim_count = output_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = output_shape.Dims(output_dim_count - 1); + TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2)); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + BiasType acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32_t input_val = input_data[b * accum_depth + d]; + int32_t filter_val = filter_data[out_c * accum_depth + d]; + acc += filter_val * (input_val + input_offset); + } + if (bias_data) { + acc += bias_data[out_c]; + } + + const float scale = filter_scales[out_c]; + const double filter_scale = static_cast(scale); + const double effective_output_scale = static_cast(input_scale) * + filter_scale / + static_cast(output_scale); + int32_t acc_scaled = static_cast( + round(static_cast(acc) * effective_output_scale)); + + acc_scaled += output_offset; + acc_scaled = std::max(acc_scaled, output_activation_min); + acc_scaled = std::min(acc_scaled, output_activation_max); + output_data[out_c + output_depth * b] = + static_cast(acc_scaled); + } + } +} + template void FullyConnected(const FullyConnectedParams& params, @@ -122,6 +179,59 @@ void FullyConnected(const FullyConnectedParams& params, } } +// This implementation receives the scales in float and performs requant in +// float to avoid loss of precision. 
+template <typename InputType, typename WeightType, typename OutputType,
+          typename BiasType>
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape,
+                    const InputType* input_data,
+                    const RuntimeShape& filter_shape,
+                    const WeightType* filter_data,
+                    const RuntimeShape& bias_shape, const BiasType* bias_data,
+                    const RuntimeShape& output_shape, float input_scale,
+                    float output_scale, float filter_scale,
+                    OutputType* output_data) {
+  const int32_t input_offset = params.input_offset;
+  const int32_t filter_offset = params.weights_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+  const int output_depth = output_shape.Dims(output_dim_count - 1);
+  TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
+  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+  for (int b = 0; b < batches; ++b) {
+    for (int out_c = 0; out_c < output_depth; ++out_c) {
+      BiasType acc = 0;
+      for (int d = 0; d < accum_depth; ++d) {
+        int32_t input_val = input_data[b * accum_depth + d];
+        int32_t filter_val = filter_data[out_c * accum_depth + d];
+        acc += (filter_val + filter_offset) * (input_val + input_offset);
+      }
+      if (bias_data) {
+        acc += bias_data[out_c];
+      }
+      const double effective_output_scale = static_cast<double>(input_scale) *
+                                            static_cast<double>(filter_scale) /
+                                            static_cast<double>(output_scale);
+      int32_t acc_scaled = static_cast<int32_t>(
+          round(static_cast<double>(acc) * effective_output_scale));
+      acc_scaled += output_offset;
+      acc_scaled = std::max(acc_scaled, output_activation_min);
+      acc_scaled = std::min(acc_scaled, output_activation_max);
+      output_data[out_c + output_depth * b] =
+          static_cast<OutputType>(acc_scaled);
+    }
+  }
+}
+
 } // namespace reference_integer_ops
 } // namespace tflite

From 138f00abf6cf678ef4c173f1e6526a9e23e7a0ec Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 17:22:50 -0700
Subject: [PATCH 0766/1324] Run build_cleaner on BUILD file(s) located in
 /xla/service/gpu.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I'd run this for all of xla to fix the low-hanging fruit.

It resolves several unnecessary & missing dependencies and simplifies target
paths, but not all of them. Here are the issues that came up that I didn't
attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need a choice between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.
Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747630095 --- third_party/xla/xla/service/gpu/BUILD | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 1804a153fb6266..de44c6df57fdcc 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -93,12 +93,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla:executable_run_options", - "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", "//xla/service:global_device_id", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) @@ -136,10 +134,7 @@ cc_library( "//xla/service:platform_util", "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", ], ) @@ -472,7 +467,6 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", "//xla/service/llvm_ir:tuple_ops", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", @@ -564,6 +558,7 @@ cc_library( ":gpu_constants", ":gpu_executable_run_options", ":ir_emission_utils", + ":resource_requests", ":stream_executor_util", "//xla:executable_run_options", "//xla:shape_tree", @@ -586,7 +581,6 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:stream_pool", "//xla/service:xla_debug_info_manager", - "//xla/service/gpu:resource_requests", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", @@ -935,15 +929,12 @@ xla_cc_test( deps = [ ":matmul_indexing_utils", "//xla:shape_util", - "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:test", - "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", ], ) @@ -1767,7 +1758,6 @@ cc_library( ], deps = [ ":nvptx_compiler_impl", - "//xla:debug_options_flags", "//xla/service:compiler", "//xla/stream_executor/cuda:cuda_platform_id", "@local_tsl//tsl/platform:path", @@ -2124,12 +2114,10 @@ cc_library( "//xla:util", "//xla/stream_executor:device_memory_handle", "//xla/stream_executor:stream_executor_h", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", 
"@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:notification", @@ -2271,7 +2259,6 @@ xla_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", ], ) @@ -2571,7 +2558,6 @@ xla_cc_test( "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -2849,7 +2835,6 @@ xla_cc_test( "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:test", ], ) @@ -3041,7 +3026,6 @@ xla_cc_test( "//xla:xla_proto_cc", "//xla/stream_executor/cuda:compilation_options", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", ], ) @@ -3094,7 +3078,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/lib:traceme_encode", From b5e2e2c8b024ac5feb50f14e3904d35a64e802d2 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Mon, 14 Apr 2025 18:00:50 -0700 Subject: [PATCH 0767/1324] Binary Elementwise Ops : Direct StableHLO -> HLO Translation PiperOrigin-RevId: 747640422 --- .../mhlo_to_hlo/gen_hlo_op_writer.td | 26 ++++++------- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 16 ++++++++ .../xla/hlo/translate/tests/stablehlo.mlir | 37 +++++++++++++++++++ .../stablehlo_legalize_to_hlo_pass.cc | 18 +++++---- 4 files changed, 77 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td index 98cc0524fdf4d5..1f1c88b0c5730c 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td @@ -26,7 +26,7 @@ defvar HloConversionAllowedOps = [ // StableHLO_AllReduceOp, // StableHLO_AllToAllOp, // StableHLO_AndOp, - // StableHLO_Atan2Op, + StableHLO_Atan2Op, // StableHLO_BatchNormGradOp, // StableHLO_BatchNormInferenceOp, // StableHLO_BatchNormTrainingOp, @@ -42,7 +42,7 @@ defvar HloConversionAllowedOps = [ // StableHLO_CollectiveBroadcastOp, // StableHLO_CollectivePermuteOp, // StableHLO_CompareOp, - // StableHLO_ComplexOp, + StableHLO_ComplexOp, // StableHLO_CompositeOp, // StableHLO_ConcatenateOp, StableHLO_ConstantOp, @@ -52,7 +52,7 @@ defvar HloConversionAllowedOps = [ // StableHLO_CreateTokenOp, // StableHLO_CrossReplicaSumOp, // StableHLO_CustomCallOp, - // StableHLO_DivOp, + StableHLO_DivOp, // StableHLO_DotGeneralOp, // StableHLO_DotOp, StableHLO_DynamicBroadcastInDimOp, @@ -80,9 +80,9 @@ defvar HloConversionAllowedOps = [ StableHLO_LogisticOp, StableHLO_LogOp, // StableHLO_MapOp, - // StableHLO_MaxOp, - // StableHLO_MinOp, - // StableHLO_MulOp, + StableHLO_MaxOp, + StableHLO_MinOp, + StableHLO_MulOp, StableHLO_NegOp, StableHLO_NotOp, // StableHLO_OptimizationBarrierOp, @@ -91,7 +91,7 @@ defvar HloConversionAllowedOps = [ // StableHLO_PadOp, // StableHLO_PartitionIdOp, StableHLO_PopulationCountOp, - // StableHLO_PowOp, + StableHLO_PowOp, // StableHLO_RealDynamicSliceOp, StableHLO_RealOp, // StableHLO_RecvOp, @@ -99,7 +99,7 @@ defvar 
HloConversionAllowedOps = [ // StableHLO_ReducePrecisionOp, // StableHLO_ReduceScatterOp, // StableHLO_ReduceWindowOp, - // StableHLO_RemOp, + StableHLO_RemOp, // StableHLO_ReplicaIdOp, // StableHLO_ReshapeOp, // StableHLO_ReturnOp, @@ -114,15 +114,15 @@ defvar HloConversionAllowedOps = [ // StableHLO_SelectOp, // StableHLO_SendOp, // StableHLO_SetDimensionSizeOp, - // StableHLO_ShiftLeftOp, - // StableHLO_ShiftRightArithmeticOp, - // StableHLO_ShiftRightLogicalOp, + StableHLO_ShiftLeftOp, + StableHLO_ShiftRightArithmeticOp, + StableHLO_ShiftRightLogicalOp, StableHLO_SignOp, StableHLO_SineOp, StableHLO_SliceOp, // StableHLO_SortOp, StableHLO_SqrtOp, - // StableHLO_SubtractOp, + StableHLO_SubtractOp, StableHLO_TanhOp, StableHLO_TanOp, // StableHLO_TorchIndexSelectOp, @@ -187,7 +187,7 @@ defvar CustomHloConverterOps = [ // StableHLO_SetDimensionSizeOp, StableHLO_SineOp, // StableHLO_SortOp, - // StableHLO_SubtractOp, + StableHLO_SubtractOp, StableHLO_TanOp, // StableHLO_UniformDequantizeOp, // StableHLO_UniformQuantizeOp, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 58da59d75e2f3a..c9254a295d739b 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -1248,6 +1248,22 @@ LogicalResult ExportXlaOp(TanOp op, OpLoweringContext ctx) { return mlir::success(); } +LogicalResult ExportXlaOp(SubtractOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + auto result = op.getResult(); + xla::XlaOp lhs; + if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &lhs, op))) + return mlir::failure(); + + xla::XlaOp rhs; + if (failed(GetXlaOp(*op.getODSOperands(1).begin(), value_map, &rhs, op))) + return mlir::failure(); + + auto xla_result = xla::Sub(Unwrap(lhs), Unwrap(rhs)); + value_map[result] = xla_result; + return mlir::success(); +} + } // namespace } // namespace stablehlo diff --git a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir index ec933fcfb5c12c..d4c4aec7fdb8d3 100644 --- a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir +++ b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir @@ -118,3 +118,40 @@ module { } } // CHECK-DIRECT: stablehlo.convolution + +// ----- +// Binary elementwise ops + +// CHECK-LABEL: HloModule main, entry_computation_layout={(f32[2,2]{1,0}, f32[2,2]{1,0})->f32[2,2]{1,0}} + +// CHECK: ENTRY %[[$main_11:[^ ]+]] +// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = f32[2,2] parameter(0) +// CHECK-NEXT: %[[Arg_1_2:[^ ]+]] = f32[2,2] parameter(1) +// CHECK-NEXT: %[[add_3:[^ ]+]] = f32[2,2] add(%[[Arg_0_1]], %[[Arg_1_2]]), +// CHECK-NEXT: %[[atan2_4:[^ ]+]] = f32[2,2] atan2(%[[add_3]], %[[Arg_1_2]]), +// CHECK-NEXT: %[[divide_5:[^ ]+]] = f32[2,2] divide(%[[atan2_4]], %[[Arg_1_2]]), +// CHECK-NEXT: %[[maximum_6:[^ ]+]] = f32[2,2] maximum(%[[divide_5]], %[[Arg_1_2]]), +// CHECK-NEXT: %[[minimum_7:[^ ]+]] = f32[2,2] minimum(%[[maximum_6]], %[[Arg_1_2]]), +// CHECK-NEXT: %[[multiply_8:[^ ]+]] = f32[2,2] multiply(%[[minimum_7]], %[[Arg_1_2]]), +// CHECK-NEXT: %[[power_9:[^ ]+]] = f32[2,2] power(%[[multiply_8]], %[[Arg_1_2]]), +// CHECK-NEXT: ROOT %[[subtract_10:[^ ]+]] = f32[2,2] subtract(%[[power_9]], %[[Arg_1_2]]), + +func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<2x2xf32> + %1 = stablehlo.atan2 %0, %arg1 : tensor<2x2xf32> + 
%2 = stablehlo.divide %1, %arg1 : tensor<2x2xf32>
+  %3 = stablehlo.maximum %2, %arg1 : tensor<2x2xf32>
+  %4 = stablehlo.minimum %3, %arg1 : tensor<2x2xf32>
+  %5 = stablehlo.multiply %4, %arg1 : tensor<2x2xf32>
+  %6 = stablehlo.power %5, %arg1 : tensor<2x2xf32>
+  %7 = stablehlo.subtract %6, %arg1 : tensor<2x2xf32>
+  return %7 : tensor<2x2xf32>
+}
+// CHECK-DIRECT: stablehlo.add
+// CHECK-DIRECT: stablehlo.atan2
+// CHECK-DIRECT: stablehlo.divide
+// CHECK-DIRECT: stablehlo.maximum
+// CHECK-DIRECT: stablehlo.minimum
+// CHECK-DIRECT: stablehlo.multiply
+// CHECK-DIRECT: stablehlo.power
+// CHECK-DIRECT: stablehlo.subtract
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
index 58c1ab8d9f70cc..7acd9e3b3ea980 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
@@ -68,17 +68,21 @@ void legalDirectStablehloToHloConversionOps(ConversionTarget& target) {
   target.addLegalOp<
       // go/keep-sorted start
      stablehlo::AbsOp, stablehlo::CbrtOp, stablehlo::SqrtOp, stablehlo::TanOp,
-      stablehlo::AddOp, stablehlo::BroadcastInDimOp, stablehlo::BroadcastOp,
-      stablehlo::CeilOp, stablehlo::ClzOp, stablehlo::ConvertOp,
-      stablehlo::ConstantOp, stablehlo::ConvolutionOp, stablehlo::CosineOp,
-      stablehlo::DynamicSliceOp, stablehlo::FloorOp, stablehlo::ImagOp,
-      stablehlo::ExpOp, stablehlo::Expm1Op, stablehlo::DynamicBroadcastInDimOp,
+      stablehlo::AddOp, stablehlo::Atan2Op,
+      stablehlo::BroadcastInDimOp, stablehlo::BroadcastOp, stablehlo::CeilOp,
+      stablehlo::ClzOp, stablehlo::ConvertOp, stablehlo::ComplexOp,
+      stablehlo::ConvolutionOp, stablehlo::CosineOp, stablehlo::DynamicSliceOp,
+      stablehlo::DivOp, stablehlo::MaxOp, stablehlo::ConstantOp,
+      stablehlo::Expm1Op, stablehlo::DynamicBroadcastInDimOp,
+      stablehlo::FloorOp, stablehlo::ImagOp, stablehlo::ExpOp,
      stablehlo::IsFiniteOp, stablehlo::Log1pOp, stablehlo::LogOp,
      stablehlo::LogisticOp, stablehlo::NegOp, stablehlo::NotOp,
+      stablehlo::MinOp, stablehlo::MulOp, stablehlo::PowOp, stablehlo::RemOp,
      stablehlo::PopulationCountOp, stablehlo::RealOp,
      stablehlo::RoundNearestEvenOp, stablehlo::RoundOp, stablehlo::RsqrtOp,
-      stablehlo::SignOp, stablehlo::SineOp, stablehlo::SliceOp,
-      stablehlo::TanhOp
+      stablehlo::ShiftLeftOp, stablehlo::ShiftRightArithmeticOp,
+      stablehlo::ShiftRightLogicalOp, stablehlo::SubtractOp, stablehlo::SignOp,
+      stablehlo::SineOp, stablehlo::SliceOp, stablehlo::TanhOp
       // go/keep-sorted end
       >();
 }

From b5e2e2c8b024ac5feb50f14e3904d35a64e802d2 Mon Sep 17 00:00:00 2001
From: Junwhan Ahn
Date: Mon, 14 Apr 2025 18:10:59 -0700
Subject: [PATCH 0768/1324] Use `ToString` instead of `DebugString` for
 `Device`'s debug string

`DeviceList` already uses `device->ToString()`, so it makes sense, for
consistency, to use `ToString` for devices as well. Also, `ToString` is
generally a lot more readable, so this change is desirable.
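For illustration only (editor's sketch, not part of this change): `Device` is printable through Abseil's `AbslStringify` extension point, so the sink body is the single place that decides what gets printed. `ToyDevice` below is a hypothetical stand-in for the real `Device` class:

```
// Minimal sketch, assuming only Abseil; ToyDevice and its ToString()
// are illustrative stand-ins, not the real xla::ifrt::Device.
#include <string>

#include "absl/strings/str_cat.h"

class ToyDevice {
 public:
  std::string ToString() const { return "ToyDevice(id=0)"; }

  // absl::StrCat/StrFormat find this friend via ADL, mirroring the
  // AbslStringify overloads in device.h below.
  template <typename Sink>
  friend void AbslStringify(Sink& sink, const ToyDevice& device) {
    sink.Append(device.ToString());
  }
};

int main() {
  ToyDevice device;
  // Call sites are unchanged; only the sink body picks the string.
  std::string message = absl::StrCat("running on ", device);
  return message.empty() ? 1 : 0;
}
```

Because every `StrCat`, `StrFormat`, and logging call site routes through the same sink, swapping `DebugString()` for `ToString()` there updates all of them at once.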
PiperOrigin-RevId: 747643529 --- third_party/xla/xla/python/ifrt/device.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/device.h b/third_party/xla/xla/python/ifrt/device.h index c20b40008d1941..a9d86a7d17e3ed 100644 --- a/third_party/xla/xla/python/ifrt/device.h +++ b/third_party/xla/xla/python/ifrt/device.h @@ -68,8 +68,6 @@ class Device : public llvm::RTTIExtends { // Debug string suitable for logging when errors occur. Should be verbose // enough to describe the current device unambiguously. - // - // TODO(hyeontaek): Remove this method in favor of AbslStringify. virtual absl::string_view DebugString() const = 0; // Returns the default memory space attached to this device. @@ -91,7 +89,7 @@ class Device : public llvm::RTTIExtends { template friend void AbslStringify(Sink& sink, const Device& device) { - sink.Append(device.DebugString()); + sink.Append(device.ToString()); } template @@ -99,7 +97,7 @@ class Device : public llvm::RTTIExtends { if (device == nullptr) { sink.Append(""); } else { - sink.Append(device->DebugString()); + sink.Append(device->ToString()); } } From 8599ba0771d114c6be0b4d394e934c799f57a937 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 14 Apr 2025 18:12:53 -0700 Subject: [PATCH 0769/1324] [xla:gpu] CommandBuffer: use execution graph to automatically infer command dependencies If `xla_gpu_graph_enable_concurrent_region` is set to true, use `ExecutionGraph` to infer command dependencies and create command buffer as a DAG. PiperOrigin-RevId: 747644032 --- .../xla/xla/backends/gpu/runtime/BUILD | 2 + .../gpu/runtime/command_buffer_cmd.cc | 97 +++++++++++++++---- .../backends/gpu/runtime/command_buffer_cmd.h | 11 ++- .../gpu/runtime/command_buffer_cmd_test.cc | 18 ++-- .../gpu/runtime/command_buffer_thunk_test.cc | 63 +++++++----- third_party/xla/xla/runtime/execution_graph.h | 11 +++ .../service/gpu/tests/command_buffer_test.cc | 20 +++- .../stream_executor/gpu/gpu_command_buffer.cc | 2 + 8 files changed, 167 insertions(+), 57 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 2bcf42526d9210..a370f43c61e8a2 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -72,6 +72,8 @@ cc_library( "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", "//xla/runtime:buffer_use", + "//xla/runtime:execution_graph", + "//xla/runtime:resource_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 8b22edcf66494a..375cc6e514fc01 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -57,6 +57,8 @@ limitations under the License. 
#include "xla/ffi/call_frame.h" #include "xla/ffi/ffi_api.h" #include "xla/runtime/buffer_use.h" +#include "xla/runtime/execution_graph.h" +#include "xla/runtime/resource_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" @@ -225,21 +227,58 @@ CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrCreate( // CommandBufferCmdSequence //===----------------------------------------------------------------------===// +namespace { + +// An adaptor from CommandBufferCmd to ExecutionGraph::Operation for building an +// execution graph from a command sequence. +class CommandOperation : public ExecutionGraph::Operation { + public: + explicit CommandOperation(CommandBufferCmd::BufferUseVector buffers) + : buffers_(std::move(buffers)) {} + + absl::Span BufferUses() const final { return buffers_; } + absl::Span ResourceUses() const final { return {}; } + + private: + CommandBufferCmd::BufferUseVector buffers_; +}; + +} // namespace + void CommandBufferCmdSequence::Builder::Append( std::unique_ptr cmd) { commands_.push_back({std::move(cmd)}); } -CommandBufferCmdSequence CommandBufferCmdSequence::Builder::Build( +absl::StatusOr +CommandBufferCmdSequence::Builder::Build( SynchronizationMode synchronization_mode) && { - return CommandBufferCmdSequence(synchronization_mode, std::move(commands_)); + std::optional execution_graph = std::nullopt; + + // In automatic synchronization mode construct an execution graph for the + // sequence of commands and derive the structure of command dependencies + // from the buffer use conflicts. + if (synchronization_mode == SynchronizationMode::kAutomatic) { + std::vector operations; + operations.reserve(commands_.size()); + for (const std::unique_ptr& cmd : commands_) { + operations.emplace_back(cmd->buffers()); + } + TF_ASSIGN_OR_RETURN(execution_graph, + ExecutionGraph::Create(operations)); + } + + return CommandBufferCmdSequence(synchronization_mode, std::move(commands_), + std::move(execution_graph)); } CommandBufferCmdSequence::CommandBufferCmdSequence( SynchronizationMode synchronization_mode, - std::vector> commands) + std::vector> commands, + std::optional execution_graph) : synchronization_mode_(synchronization_mode), - commands_(std::move(commands)) { + commands_(std::move(commands)), + execution_graph_(std::move(execution_graph)) { // Record all buffers used by commands in the sequence. for (const std::unique_ptr& cmd : commands_) { for (const BufferUse& buffer : cmd->buffers()) { @@ -420,15 +459,13 @@ absl::Status CommandBufferCmdSequence::CheckCommandBufferState( return absl::OkStatus(); } -// TODO(b/406370928): Currently we assume sequential execution order of all -// recorded commands, so we use a very simple rule for identifying source and -// sink nodes and computing dependencies. Long term we should get that from the -// ExecutionGraph helper, when using automatic synchronization mode. - -bool CommandBufferCmdSequence::IsSource(CommandId id) const { return id == 0; } +bool CommandBufferCmdSequence::IsSource(CommandId id) const { + return execution_graph_ ? execution_graph_->is_source(id) : id == 0; +} bool CommandBufferCmdSequence::IsSink(CommandId id) const { - return id + 1 == commands_.size(); + return execution_graph_ ? 
execution_graph_->is_sink(id) + : id + 1 == commands_.size(); } std::vector @@ -440,19 +477,37 @@ CommandBufferCmdSequence::Dependencies(const RecordParams& record_params, return {}; } - // Find recorded command state for the previous command in the sequence. - auto* record_state = record_params.state.GetOrNull( - commands_[id - 1].get(), command_buffer); - DCHECK(record_state) << "Record state must be not null for " - << commands_[id - 1]->ToString(); + // Collect commands that are dependencies of the command `id`. + absl::InlinedVector dependencies_ids; + if (execution_graph_) { + dependencies_ids.assign(execution_graph_->in_edges(id).begin(), + execution_graph_->in_edges(id).end()); + } else { + dependencies_ids.push_back(id - 1); + } - // Some commands might end up not recording anything into the command buffer, - // e.g. memcpy commands where source and destination are the same. - if (record_state->command == nullptr) { - return {}; + // Collect dependencies from the recorded command state. + std::vector dependencies; + for (CommandId dependency_id : dependencies_ids) { + auto* record_state = record_params.state.GetOrNull( + commands_[dependency_id].get(), command_buffer); + DCHECK(record_state) << "Record state must be not null for " + << commands_[dependency_id]->ToString(); + + if (record_state->command == nullptr) { + // Some commands might end up not recording anything into the command + // buffer, e.g. memcpy commands where source and destination are the same. + // We have to follow dependencies of such commands to find the real + // dependencies, so we don't record a command that is immediately ready to + // execute, as it will create data races. + auto deps = Dependencies(record_params, command_buffer, dependency_id); + dependencies.insert(dependencies.end(), deps.begin(), deps.end()); + } else { + dependencies.push_back(record_state->command); + } } - return {record_state->command}; + return dependencies; } const absl::flat_hash_set& CommandBufferCmdSequence::buffers() diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index eab8981d693e3c..aeee085b44bc89 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -45,6 +45,7 @@ limitations under the License. #include "xla/ffi/api/c_api.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/runtime/buffer_use.h" +#include "xla/runtime/execution_graph.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/buffer_allocations.h" @@ -326,7 +327,8 @@ class CommandBufferCmdSequence { Append(std::make_unique(std::forward(args)...)); } - CommandBufferCmdSequence Build(SynchronizationMode synchronization_mode) &&; + absl::StatusOr Build( + SynchronizationMode synchronization_mode) &&; private: std::vector> commands_; @@ -390,7 +392,8 @@ class CommandBufferCmdSequence { CommandBufferCmdSequence( SynchronizationMode synchronization_mode, - std::vector> commands); + std::vector> commands, + std::optional execution_graph); absl::Status CheckCommandBufferState( se::CommandBuffer* command_buffer, @@ -410,6 +413,10 @@ class CommandBufferCmdSequence { SynchronizationMode synchronization_mode_; std::vector> commands_; + // In automatic synchronization mode we build an execution graph for the + // sequence of commands and use it to set up dependencies between commands. 
+ std::optional execution_graph_; + // Buffers referenced by commands in this sequence. absl::flat_hash_set buffers_; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index 740a8008d74a04..26889e6ce09457 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -156,7 +156,8 @@ TEST(CommandBufferCmdTest, SerializeExecution) { CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, BufferUseVector{use0}); builder.Emplace(s0, BufferUseVector{use1}); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -174,7 +175,8 @@ TEST(CommandBufferCmdTest, NoReadBarrier) { CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, BufferUseVector{use0}); builder.Emplace(s0, BufferUseVector{use1}); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -192,7 +194,8 @@ TEST(CommandBufferCmdTest, NoWriteBarrier) { CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, BufferUseVector{use0}); builder.Emplace(s0, BufferUseVector{use1}); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -213,7 +216,8 @@ TEST(CommandBufferCmdTest, WriteConflictBarrier) { builder.Emplace(s0, BufferUseVector{use0}); builder.Emplace(s0, BufferUseVector{use1}); builder.Emplace(s0, BufferUseVector{use2}); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // TODO(ezhulenev): Check that commands correctly infer dependencies. } @@ -242,7 +246,8 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_b, slice_a, byte_length); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); ServiceExecutableRunOptions run_options; se::StreamExecutorMemoryAllocator allocator(executor); @@ -298,7 +303,8 @@ TEST(CommandBufferCmdTest, LaunchCmd) { builder.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Initialize command sequence and load device kernels. 
TF_ASSERT_OK_AND_ASSIGN(std::vector fatbin, diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index 6fea36551f864c..9c2185a95afde7 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -163,7 +163,8 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_b, slice_a, byte_length); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -218,7 +219,8 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_a); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -261,7 +263,8 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_a, int32_t{84}); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -311,7 +314,8 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersDisabledDuringProfiling) { // be used. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_a, int32_t{12}); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); constexpr bool kProfileCommandBuffersEnabled = false; // Construct a thunk with command sequence. @@ -366,7 +370,8 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersEnabledDuringProfiling) { // be used. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_a, int32_t{12}); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); constexpr bool kProfileCommandBuffersEnabled = true; // Construct a thunk with command sequence. @@ -411,7 +416,8 @@ TEST(CommandBufferThunkTest, Memset32CmdOnDifferentStreams) { CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice0, int32_t{12}); builder.Emplace(s1, slice1, int32_t{34}); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. 
CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -465,7 +471,8 @@ TEST(CommandBufferThunkTest, LaunchCmd) { builder.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -563,7 +570,8 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { builder.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -681,7 +689,8 @@ TEST(CommandBufferThunkTest, GemmCmd) { builder.Emplace(s0, config.value(), slice_lhs, slice_rhs, slice_out, slice_workspace, /*deterministic=*/true); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -808,8 +817,8 @@ TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { embed_builder.Emplace(s0, config.value(), fake_slice_lhs, slice_rhs, slice_out, slice_workspace, /*deterministic=*/true); - CommandBufferCmdSequence embed_commands = - std::move(embed_builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence embed_commands, + std::move(embed_builder).Build(serialize)); BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); @@ -839,7 +848,8 @@ TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { builder.Emplace( s0, std::move(embed_commands), arguments, std::move(fake_allocations), offsets, orig_shapes, sliced_shapes, offset_byte_sizes); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -952,7 +962,8 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), slice_workspace); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1086,7 +1097,8 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { builder.Emplace(s0, "AddI32", args_1, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. 
CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1210,13 +1222,16 @@ TEST(CommandBufferThunkTest, CaseCmd) { } std::vector branches(2); - branches[0] = std::move(branches_builder[0]).Build(serialize); - branches[1] = std::move(branches_builder[1]).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(branches[0], + std::move(branches_builder[0]).Build(serialize)); + TF_ASSERT_OK_AND_ASSIGN(branches[1], + std::move(branches_builder[1]).Build(serialize)); // Prepare commands sequence for thunk. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_i, false, std::move(branches)); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1303,22 +1318,23 @@ TEST(CommandBufferThunkTest, WhileCmd) { cond_commands_builder.Emplace( s0, "IncAndCmp", cond_args, cond_args_access, LaunchDimensions(1, 1), /*shmem_bytes=*/0); - CommandBufferCmdSequence cond_commands = - std::move(cond_commands_builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence cond_commands, + std::move(cond_commands_builder).Build(serialize)); // Prepare commands sequence for loop `body`. CommandBufferCmdSequence::Builder body_commands_builder; body_commands_builder.Emplace( s0, "AddI32", body_args, body_args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - CommandBufferCmdSequence body_commands = - std::move(body_commands_builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence body_commands, + std::move(body_commands_builder).Build(serialize)); // Prepare commands sequence for thunk. CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_pred, std::move(cond_commands), std::move(body_commands)); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); @@ -1469,7 +1485,8 @@ TEST(CommandBufferThunkTest, ToStringPrintsNestedThunks) { BufferAllocation::Slice slice_a(&alloc_a, /*offset=*/0, /*size=*/4); CommandBufferCmdSequence::Builder builder; builder.Emplace(s0, slice_a, int32_t{42}); - CommandBufferCmdSequence commands = std::move(builder).Build(serialize); + TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, + std::move(builder).Build(serialize)); std::vector> thunks; thunks.emplace_back( std::make_unique(Thunk::ThunkInfo(), 42, slice_a)); diff --git a/third_party/xla/xla/runtime/execution_graph.h b/third_party/xla/xla/runtime/execution_graph.h index 2c2304b109e351..71d4999567a272 100644 --- a/third_party/xla/xla/runtime/execution_graph.h +++ b/third_party/xla/xla/runtime/execution_graph.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/statusor.h" @@ -112,6 +113,16 @@ class ExecutionGraph { // Sink nodes are the nodes that do not have any out-edges. absl::Span sink() const { return sink_; } + // Returns true if a given node id is a source node. + bool is_source(NodeId id) const { + return absl::c_find(source_, id) != source_.end(); + } + + // Returns true if a given node id is a sink node. 
+  bool is_sink(NodeId id) const {
+    return absl::c_find(sink_, id) != sink_.end();
+  }
+
   // Returns in-edges for a given node id.
   absl::Span in_edges(NodeId id) const {
     DCHECK_EQ(id, nodes_defs_[id].id);
diff --git a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc
index f7c6f95c016f54..446904b0019deb 100644
--- a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc
@@ -27,9 +27,16 @@ limitations under the License.
 namespace xla::gpu {
 namespace {
-class CommandBufferTest : public HloTestBase {};
+class CommandBufferTest : public HloTestBase,
+                          public ::testing::WithParamInterface<bool> {
+  DebugOptions GetDebugOptionsForTest() const override {
+    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_graph_enable_concurrent_region(GetParam());
+    return debug_options;
+  }
+};
-TEST_F(CommandBufferTest, Fusions) {
+TEST_P(CommandBufferTest, Fusions) {
 constexpr absl::string_view hlo_text = R"(
 HloModule m, is_scheduled=true
@@ -70,7 +77,7 @@ TEST_F(CommandBufferTest, Fusions) {
   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
-TEST_F(CommandBufferTest, TrueFalseConditional) {
+TEST_P(CommandBufferTest, TrueFalseConditional) {
 constexpr absl::string_view hlo_text = R"(
 HloModule m, is_scheduled=true
@@ -129,7 +136,7 @@ TEST_F(CommandBufferTest, TrueFalseConditional) {
   }
 }
-TEST_F(CommandBufferTest, IndexConditional) {
+TEST_P(CommandBufferTest, IndexConditional) {
 constexpr absl::string_view hlo_text = R"(
 HloModule m, is_scheduled=true
@@ -196,7 +203,7 @@ TEST_F(CommandBufferTest, IndexConditional) {
   }
 }
-TEST_F(CommandBufferTest, WhileLoop) {
+TEST_P(CommandBufferTest, WhileLoop) {
 constexpr absl::string_view hlo_text = R"(
 HloModule m, is_scheduled=true
@@ -257,5 +264,8 @@ TEST_F(CommandBufferTest, WhileLoop) {
   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
+INSTANTIATE_TEST_SUITE_P(CommandBufferTests, CommandBufferTest,
+                         ::testing::Values(false, true));
+
 }  // namespace
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
index 041b886b84b2d4..66d9f98b30088a 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
@@ -102,6 +102,8 @@ GpuCommandBuffer::ToGraphNodeDependencies(
   std::vector handles;
   for (const Command* dep : dependencies) {
+    DCHECK(dep) << "Dependency command must not be null";
+
     if (auto* gpu_command = dynamic_cast(dep)) {
       handles.push_back(gpu_command->handle);

From f72f40e8f8eaf9cfe0390072e381e1247a0453b9 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 18:22:32 -0700
Subject: [PATCH 0770/1324] Run build_cleaner on BUILD file(s) located in
 /xla/stream_executor.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I'd run this for all of xla to fix the low-hanging fruit.

It resolves several unnecessary and missing dependencies and simplifies some
target paths, but not all of them. Here are the issues that came up that I
didn't attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need a choice between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a Python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median              Δ    1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median               Δ               1-pval
cpu:      3599.015s ±131.4s    +8.3s, +0.2%    0.03 (not significant)
memory:   4533MB ±2.3MB        +0.0MB, +0.0%   0.00 (not significant)
system:   582.305s ±9.1s       -11.9s, -2.0%   0.25 (not significant)
wall:     808.958s ±95.5s      -98.6s, -10.9%  0.57 (not significant)
```

Overall, this has modest savings of ~1 minute of wall (physical) time. Since
I've excluded some execution tests under `stream_executor/` and `service/`,
the estimated savings may be greater. It's a small improvement, but it should
pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of
subdirectories to simplify merging.

PiperOrigin-RevId: 747646319
---
 third_party/xla/xla/stream_executor/BUILD | 23 +++--------------------
 1 file changed, 3 insertions(+), 20 deletions(-)

diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD
index e99aeff6c0484b..30611b377bc593 100644
--- a/third_party/xla/xla/stream_executor/BUILD
+++ b/third_party/xla/xla/stream_executor/BUILD
@@ -113,10 +113,10 @@ cc_library(
     name = "gpu_solver_context",
     hdrs = ["gpu_solver_context.h"],
     deps = [
+        ":blas",
+        ":device_memory",
+        ":stream",
         "//xla:xla_data_proto_cc",
-        "//xla/stream_executor:blas",
-        "//xla/stream_executor:device_memory",
-        "//xla/stream_executor:stream",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
     ],
@@ -158,7 +158,6 @@ cc_library(
     ":device_memory",
     ":event",
     ":event_based_timer",
-    ":kernel",
     ":launch_dim",
     ":platform",
     ":stream",
@@ -216,18 +215,11 @@ cc_library(
     deps = [
         ":device_memory",
         ":platform",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/synchronization",
-        "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:numbers",
         "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
     ],
 )
@@ -539,7 +531,6 @@ cc_library(
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
     ],
 )
@@ -610,7 +601,6 @@ cc_library(
     ":dnn",
     ":kernel",
     ":launch_dim",
-    "//xla/tsl/platform:errors",
     "@com_google_absl//absl/functional:any_invocable",
     "@com_google_absl//absl/status",
     "@com_google_absl//absl/status:statusor",
@@ -660,7 +650,6 @@ cc_library(
     srcs = ["stream_common.cc"],
     hdrs = ["stream_common.h"],
     deps = [
-        ":blas",
         ":device_description",
         ":fft",
         ":platform",
@@ -671,7 +660,6 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/synchronization",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:statusor",
@@ -724,7 +712,6 @@ xla_cc_test(
     srcs = ["generic_memory_allocation_test.cc"],
     deps = [
         ":generic_memory_allocation",
-        "//xla/tsl/platform:macros",
         "@com_google_googletest//:gtest",
         "@local_tsl//tsl/platform:test_main",
     ],
 )
@@ -737,7 +724,6 @@ xla_cc_test(
":generic_memory_allocation", ":generic_memory_allocator", ":memory_allocation", - "//xla/tsl/platform:macros", "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", @@ -828,11 +814,9 @@ xla_cc_test( name = "scoped_module_handle_test", srcs = ["scoped_module_handle_test.cc"], deps = [ - ":device_description", ":mock_stream_executor", ":module_spec", ":scoped_module_handle", - "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test", ], @@ -857,7 +841,6 @@ xla_cc_test( ":device_description", ":semantic_version", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:test", ], ) From 14046e1121261c514598198f8074f34e4da7954d Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Mon, 14 Apr 2025 18:50:18 -0700 Subject: [PATCH 0771/1324] Port `xla/hlo/builder` Client tests. This change is the first step of two in the ClientLibraryTestBase process of our migration to the PjRt-based runners. PiperOrigin-RevId: 747652166 --- third_party/xla/xla/hlo/builder/lib/BUILD | 53 ++++-- .../xla/hlo/builder/lib/arithmetic_test.cc | 15 +- .../xla/hlo/builder/lib/comparators_test.cc | 23 +-- .../xla/xla/hlo/builder/lib/math_test.cc | 171 +++++++++--------- .../xla/xla/hlo/builder/lib/matrix_test.cc | 79 ++++---- .../xla/xla/hlo/builder/lib/pooling_test.cc | 77 ++++---- .../xla/xla/hlo/builder/lib/prng_test.cc | 9 +- .../hlo/builder/lib/self_adjoint_eig_test.cc | 56 +++--- .../xla/xla/hlo/builder/lib/svd_test.cc | 65 +++---- .../xla/hlo/builder/lib/tridiagonal_test.cc | 52 +++--- .../xla/xla/hlo/builder/lib/tuple_test.cc | 21 +-- 11 files changed, 305 insertions(+), 316 deletions(-) diff --git a/third_party/xla/xla/hlo/builder/lib/BUILD b/third_party/xla/xla/hlo/builder/lib/BUILD index cb27b469a5e3cc..3c6faf303adec8 100644 --- a/third_party/xla/xla/hlo/builder/lib/BUILD +++ b/third_party/xla/xla/hlo/builder/lib/BUILD @@ -47,8 +47,9 @@ xla_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:test", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], @@ -85,8 +86,9 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test", "//xla/service:hlo_proto_cc", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:test", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", @@ -245,9 +247,11 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", "//xla/service", - "//xla/tests:client_library_test_base", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:test", "@com_google_googletest//:gtest_main", ], ) @@ -290,11 +294,14 @@ xla_test( "//xla:array2d", "//xla:array3d", "//xla:array4d", + "//xla:literal", + "//xla:literal_util", "//xla:types", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:test", "@com_google_absl//absl/status", 
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -333,8 +340,9 @@ xla_test( "//xla:shape_util", "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:test", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", @@ -369,8 +377,9 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:test", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", @@ -603,13 +612,18 @@ xla_test( "//xla:array2d", "//xla:array3d", "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", "//xla/tests:client_library_test_base", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -659,8 +673,10 @@ xla_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", - "//xla/tests:client_library_test_base", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", ], @@ -696,18 +712,19 @@ xla_test( deps = [ ":slicing", ":tridiagonal", - "//xla:array", "//xla:array3d", "//xla:literal", "//xla:shape_util", "//xla:util", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", - "//xla/tests:client_library_test_base", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", ], ) @@ -777,8 +794,10 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/service", - "//xla/tests:client_library_test_base", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc index 2e5b546f801e84..20689b19cf1c48 100644 --- a/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc @@ -22,14 +22,15 @@ limitations under the License. 
#include "absl/types/span.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" namespace xla { namespace { -class ArithmeticTest : public ClientLibraryTestBase { +class ArithmeticTest : public ClientLibraryTestRunnerMixin { public: template void TestArgMin(std::initializer_list> input, @@ -63,22 +64,22 @@ class ArithmeticTest : public ClientLibraryTestBase { std::function MinMaxImpl) {} }; -XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) { +TEST_F(ArithmeticTest, ArgMinR2Axis0) { TestArgMin({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2}, /*axis=*/0); } -XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) { +TEST_F(ArithmeticTest, ArgMinR2Axis1) { TestArgMin({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 1}, /*axis=*/1); } -XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) { +TEST_F(ArithmeticTest, ArgMaxR2Axis0) { TestArgMax({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {2, 0, 1}, /*axis=*/0); } -XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) { +TEST_F(ArithmeticTest, ArgMaxR2Axis1) { TestArgMax({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {1, 0, 0}, /*axis=*/1); } diff --git a/third_party/xla/xla/hlo/builder/lib/comparators_test.cc b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc index 8c831b1a7a2acb..51eb6f10c38fb7 100644 --- a/third_party/xla/xla/hlo/builder/lib/comparators_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc @@ -29,15 +29,16 @@ limitations under the License. #include "xla/hlo/testlib/test.h" #include "xla/primitive_util.h" #include "xla/service/hlo.pb.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" namespace xla { namespace { -class ComparatorsTest : public ClientLibraryTestBase { +class ComparatorsTest : public ClientLibraryTestRunnerMixin { public: ComparatorsTest() : builder_(TestName()) {} XlaBuilder* builder() { return &builder_; } @@ -96,56 +97,56 @@ void BuildComparatorAndComparisons(ComparatorsTest* test, } } -XLA_TEST_F(ComparatorsTest, CompareLtBF16) { +TEST_F(ComparatorsTest, CompareLtBF16) { absl::InlinedVector expected; BuildComparatorAndComparisons(this, /*compare_less_than=*/true, &expected); ComputeAndCompareR1(builder(), expected, {}); } -XLA_TEST_F(ComparatorsTest, CompareGtBF16) { +TEST_F(ComparatorsTest, CompareGtBF16) { absl::InlinedVector expected; BuildComparatorAndComparisons(this, /*compare_less_than=*/false, &expected); ComputeAndCompareR1(builder(), expected, {}); } -XLA_TEST_F(ComparatorsTest, CompareLtF16) { +TEST_F(ComparatorsTest, CompareLtF16) { absl::InlinedVector expected; BuildComparatorAndComparisons(this, /*compare_less_than=*/true, &expected); ComputeAndCompareR1(builder(), expected, {}); } -XLA_TEST_F(ComparatorsTest, CompareGtF16) { +TEST_F(ComparatorsTest, CompareGtF16) { absl::InlinedVector expected; BuildComparatorAndComparisons(this, /*compare_less_than=*/false, &expected); ComputeAndCompareR1(builder(), expected, {}); } -XLA_TEST_F(ComparatorsTest, CompareLtF32) { +TEST_F(ComparatorsTest, CompareLtF32) { absl::InlinedVector expected; BuildComparatorAndComparisons(this, /*compare_less_than=*/true, &expected); ComputeAndCompareR1(builder(), expected, {}); } 
-XLA_TEST_F(ComparatorsTest, CompareGtF32) { +TEST_F(ComparatorsTest, CompareGtF32) { absl::InlinedVector expected; BuildComparatorAndComparisons(this, /*compare_less_than=*/false, &expected); ComputeAndCompareR1(builder(), expected, {}); } -XLA_TEST_F(ComparatorsTest, CompareLtF64) { +TEST_F(ComparatorsTest, CompareLtF64) { absl::InlinedVector expected; BuildComparatorAndComparisons(this, /*compare_less_than=*/true, &expected); ComputeAndCompareR1(builder(), expected, {}); } -XLA_TEST_F(ComparatorsTest, CompareGtF64) { +TEST_F(ComparatorsTest, CompareGtF64) { absl::InlinedVector expected; BuildComparatorAndComparisons(this, /*compare_less_than=*/false, &expected); diff --git a/third_party/xla/xla/hlo/builder/lib/math_test.cc b/third_party/xla/xla/hlo/builder/lib/math_test.cc index 50326da5c0ee65..7dd27f0911efa9 100644 --- a/third_party/xla/xla/hlo/builder/lib/math_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/math_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -34,22 +33,22 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" -#include "xla/service/service.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla { namespace { -class MathTest : public ClientLibraryTestBase { - public: - ErrorSpec error_spec_{0.0001}; -}; +constexpr ErrorSpec kErrorSpec{0.0001}; + +using MathTest = ClientLibraryTestRunnerMixin; // Write TYPED_TESTs within the class definition so that we don't have to litter // "this->" everywhere. 
@@ -60,21 +59,23 @@ class MathTypedTest : public MathTest { SetFastMathDisabled(true); XlaBuilder b(TestName()); - Log(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}}), &b)); + const Literal param0 = LiteralUtil::CreateR1({T{0.0}, T{-0.0}}); + Log(Parameter(&b, 0, param0.shape(), "")); ComputeAndCompareR1(&b, {-std::numeric_limits::infinity(), -std::numeric_limits::infinity()}, - {}, error_spec_); + {¶m0}, kErrorSpec); } void TestLog1pEdgeCases() { SetFastMathDisabled(true); XlaBuilder b(TestName()); - Log1p(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}, T{-1.0}}), &b)); + const Literal param0 = LiteralUtil::CreateR1({T{0.0}, T{-0.0}, T{-1.0}}); + Log1p(Parameter(&b, 0, param0.shape(), "")); ComputeAndCompareR1( - &b, {T{0.0}, T{-0.0}, -std::numeric_limits::infinity()}, {}, - error_spec_); + &b, {T{0.0}, T{-0.0}, -std::numeric_limits::infinity()}, {¶m0}, + kErrorSpec); } void TestIsInfOrNan() { @@ -120,16 +121,16 @@ class MathTypedTest : public MathTest { XlaBuilder b(TestName()); T inf(std::numeric_limits::infinity()); T nan(std::numeric_limits::quiet_NaN()); - IsNegZero(AddParam( - LiteralUtil::CreateR1({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}), - &b)); + const Literal param0 = + LiteralUtil::CreateR1({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}); + IsNegZero(Parameter(&b, 0, param0.shape(), "")); bool is_mx = std::is_same_v; ComputeAndCompareLiteral( &b, LiteralUtil::CreateR1( {has_negative_zero_v, false, false, false, false, false, is_mx}), - {}, error_spec_); + {¶m0}, kErrorSpec); } // sqrt(x) == pow(x, 0.5) except that @@ -157,25 +158,27 @@ class MathTypedTest : public MathTest { const T nan(std::numeric_limits::quiet_NaN()); XlaBuilder b(TestName()); - auto x = AddParam(LiteralUtil::CreateR1({-inf}), &b); + const Literal param0 = LiteralUtil::CreateR1({-inf}); + XlaOp x = Parameter(&b, 0, param0.shape(), ""); ConcatInDim( &b, {Sqrt(x), Pow(x, ScalarLike(x, 0.5)), Pow(x, ScalarLike(x, 0.3))}, 0); std::vector expected = {nan, inf, inf}; - ComputeAndCompareR1(&b, expected, {}, error_spec_); + ComputeAndCompareR1(&b, expected, {¶m0}, kErrorSpec); } void TestErfInvEdgeCases() { SetFastMathDisabled(true); XlaBuilder b(TestName()); - auto x = AddParam(LiteralUtil::CreateR1({T{-1}, T{1}, T{0}}), &b); + const Literal param0 = LiteralUtil::CreateR1({T{-1}, T{1}, T{0}}); + XlaOp x = Parameter(&b, 0, param0.shape(), ""); ErfInv(x); const T inf(std::numeric_limits::infinity()); std::vector expected = {-inf, inf, T{0}}; - ComputeAndCompareR1(&b, expected, {}, error_spec_); + ComputeAndCompareR1(&b, expected, {¶m0}, kErrorSpec); } void TestErfEdgeCases() { @@ -185,10 +188,10 @@ class MathTypedTest : public MathTest { const T nan(std::numeric_limits::quiet_NaN()); XlaBuilder b(TestName()); - auto x = AddParam(LiteralUtil::CreateR1({T{-inf}, T{inf}, T{-0}, T{0}, - T{-kErfInvOneMinusHalfULP}, - T{kErfInvOneMinusHalfULP}}), - &b); + const Literal param0 = LiteralUtil::CreateR1( + {T{-inf}, T{inf}, T{-0}, T{0}, T{-kErfInvOneMinusHalfULP}, + T{kErfInvOneMinusHalfULP}}); + XlaOp x = Parameter(&b, 0, param0.shape(), ""); Erf(x); bool inf_as_nan = !std::numeric_limits::has_infinity && @@ -200,7 +203,7 @@ class MathTypedTest : public MathTest { T(-1), T(1)}; - ComputeAndCompareR1(&b, expected, {}, error_spec_); + ComputeAndCompareR1(&b, expected, {¶m0}, kErrorSpec); } }; @@ -237,7 +240,7 @@ XLA_TYPED_TEST(MathTypedTest, ErfInvEdgeCases) { this->TestErfInvEdgeCases(); } XLA_TYPED_TEST(MathTypedTest, ErfEdgeCases) { this->TestErfEdgeCases(); } // Check that certain ops only support real, floating-point 
inputs. -XLA_TEST_F(MathTest, RealFpOnlyOps) { +TEST_F(MathTest, RealFpOnlyOps) { for (int64_t i = PrimitiveType_MIN; i <= PrimitiveType_MAX; ++i) { auto ty = static_cast(i); SCOPED_TRACE(PrimitiveType_Name(ty)); @@ -284,34 +287,27 @@ XLA_TEST_F(MathTest, RealFpOnlyOps) { } } -XLA_TEST_F(MathTest, SqrtF32) { +TEST_F(MathTest, SqrtF32) { XlaBuilder builder(TestName()); - Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); - - std::unique_ptr zero_data = - client_->TransferToServer(zero_literal).value(); - + const Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); Sqrt(zero); - ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); + ComputeAndCompareR0(&builder, 0.0f, {&zero_literal}, kErrorSpec); } -XLA_TEST_F(MathTest, SqrtF64) { +TEST_F(MathTest, SqrtF64) { XlaBuilder builder(TestName()); - Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F64); - - std::unique_ptr zero_data = - client_->TransferToServer(zero_literal).value(); + const Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F64); XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); Sqrt(zero); - ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); + ComputeAndCompareR0(&builder, 0.0f, {&zero_literal}, kErrorSpec); } #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64 -XLA_TEST_F(MathTest, ErfInvF64) { +TEST_F(MathTest, ErfInvF64) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {-0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, @@ -332,7 +328,7 @@ XLA_TEST_F(MathTest, ErfInvF64) { } #endif -XLA_TEST_F(MathTest, SquareTenValues) { +TEST_F(MathTest, SquareTenValues) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); @@ -340,10 +336,10 @@ XLA_TEST_F(MathTest, SquareTenValues) { std::vector expected = {4.41, 6.76, 6.76, 16., 4.41, 5.29, 25., 0.81, 5.76, 2.56}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, ReciprocalTenValues) { +TEST_F(MathTest, ReciprocalTenValues) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); @@ -352,27 +348,27 @@ XLA_TEST_F(MathTest, ReciprocalTenValues) { std::vector expected = { 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, SqrtZeroes) { +TEST_F(MathTest, SqrtZeroes) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {0.0, -0.0}); Sqrt(x); - ComputeAndCompareR1(&builder, {0, 0}, {}, error_spec_); + ComputeAndCompareR1(&builder, {0, 0}, {}, kErrorSpec); } -XLA_TEST_F(MathTest, SqrtSixValues) { +TEST_F(MathTest, SqrtSixValues) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); Sqrt(x); std::vector expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, CbrtSixF32Values) { +TEST_F(MathTest, CbrtSixF32Values) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {8.0, 1.0, 4096.0, -64.0, 1.728, 1331}); Cbrt(x); @@ -381,7 +377,7 @@ XLA_TEST_F(MathTest, CbrtSixF32Values) { ComputeAndCompareR1(&builder, 
expected, {}, ErrorSpec(0.001)); } -XLA_TEST_F(MathTest, CbrtSixF64Values) { +TEST_F(MathTest, CbrtSixF64Values) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {8.0, 1.0, 4096.0, -64.0, 1.728, 1331}); Cbrt(x); @@ -390,31 +386,31 @@ XLA_TEST_F(MathTest, CbrtSixF64Values) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.001)); } -XLA_TEST_F(MathTest, SinhSmallValues) { +TEST_F(MathTest, SinhSmallValues) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}); Sinh(x); std::vector expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, AsinhSmallValues) { +TEST_F(MathTest, AsinhSmallValues) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}); Asinh(x); std::vector expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, AtanhSmallValues) { +TEST_F(MathTest, AtanhSmallValues) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1e-8, 1e-9, 1e-10, 1e-11}); Atanh(x); std::vector expected = {1e-8, 1e-9, 1e-10, 1e-11}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, Lgamma) { +TEST_F(MathTest, Lgamma) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, 2.5, -1.5, -3.5, -5.5}); @@ -433,12 +429,11 @@ XLA_TEST_F(MathTest, Lgamma) { static_cast(std::log(M_PI) / 2 - std::log(3) + std::log(4)), static_cast(std::log(M_PI) / 2 - std::log(105) + std::log(16)), static_cast(std::log(M_PI) / 2 - std::log(10395) + std::log(64))}; - error_spec_ = ErrorSpec{0.001}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec{0.001}); } #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -XLA_TEST_F(MathTest, LgammaF16) { +TEST_F(MathTest, LgammaF16) { SetFastMathDisabled(true); XlaBuilder b(TestName()); @@ -460,7 +455,7 @@ XLA_TEST_F(MathTest, LgammaF16) { } #endif -XLA_TEST_F(MathTest, Digamma) { +TEST_F(MathTest, Digamma) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125, 2.0, 3.0, 4.0, 6.0, 8.0, 9.0}); @@ -487,10 +482,10 @@ XLA_TEST_F(MathTest, Digamma) { static_cast(137 / 60.0 - euler_mascheroni), static_cast(363 / 140.0 - euler_mascheroni), static_cast(761 / 280.0 - euler_mascheroni)}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, Igamma) { +TEST_F(MathTest, Igamma) { XlaBuilder builder(TestName()); auto a = ConstantR3FromArray3D( &builder, @@ -509,10 +504,10 @@ XLA_TEST_F(MathTest, Igamma) { {{0.78746926, 0.99940502, 0.98028261, 0.97033807, 0.99054696}, {0.33265522, 0.99983558, 0.32599159, 0.99923275, 0.99980893}, {0.74343963, 0.46703197, 0.33923541, 0.99978511, 0.99460685}}}; - ComputeAndCompareR3(&builder, expected, {}, error_spec_); + ComputeAndCompareR3(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, IgammaSpecialValues) { +TEST_F(MathTest, IgammaSpecialValues) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); const float nan = std::numeric_limits::quiet_NaN(); @@ -524,11 +519,11 @@ XLA_TEST_F(MathTest, IgammaSpecialValues) { Igamma(a, x); std::vector 
expected = {nan, nan, nan, nan, nan, nan}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -XLA_TEST_F(MathTest, IgammaF16) { +TEST_F(MathTest, IgammaF16) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); @@ -550,7 +545,7 @@ XLA_TEST_F(MathTest, IgammaF16) { } #endif -XLA_TEST_F(MathTest, Igammac) { +TEST_F(MathTest, Igammac) { XlaBuilder builder(TestName()); auto a = ConstantR3FromArray3D( &builder, @@ -571,11 +566,11 @@ XLA_TEST_F(MathTest, Igammac) { 7.67252602e-04, 1.91071108e-04}, {2.56560373e-01, 5.32968026e-01, 6.60764593e-01, 2.14889688e-04, 5.39314824e-03}}}; - ComputeAndCompareR3(&builder, expected, {}, error_spec_); + ComputeAndCompareR3(&builder, expected, {}, kErrorSpec); } #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -XLA_TEST_F(MathTest, IgammacF16) { +TEST_F(MathTest, IgammacF16) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); @@ -598,7 +593,7 @@ XLA_TEST_F(MathTest, IgammacF16) { } #endif -XLA_TEST_F(MathTest, RoundToEven) { +TEST_F(MathTest, RoundToEven) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {-1.4, -1.5, -2.5, -0.5, 0, 0.5, 1.5, 2.5, 3.5, 4.5}); @@ -607,45 +602,45 @@ XLA_TEST_F(MathTest, RoundToEven) { std::vector expected = {-1.0, -2.0, -2.0, -0.0, 0, 0.0, 2.0, 2.0, 4.0, 4.0}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, ErfRejectsComplexInputs) { +TEST_F(MathTest, ErfRejectsComplexInputs) { XlaBuilder b(TestName()); auto x = ConstantR1>(&b, {{0, 0}}); Erf(x); EXPECT_FALSE(b.Build().status().ok()); } -XLA_TEST_F(MathTest, ErfcRejectsComplexInputs) { +TEST_F(MathTest, ErfcRejectsComplexInputs) { XlaBuilder b(TestName()); auto x = ConstantR1>(&b, {{0, 0}}); Erfc(x); EXPECT_FALSE(b.Build().status().ok()); } -XLA_TEST_F(MathTest, LgammaRejectsComplexInputs) { +TEST_F(MathTest, LgammaRejectsComplexInputs) { XlaBuilder b(TestName()); auto x = ConstantR1>(&b, {{0, 0}}); Lgamma(x); EXPECT_FALSE(b.Build().status().ok()); } -XLA_TEST_F(MathTest, DigammaRejectsComplexInputs) { +TEST_F(MathTest, DigammaRejectsComplexInputs) { XlaBuilder b(TestName()); auto x = ConstantR1>(&b, {{0, 0}}); Digamma(x); EXPECT_FALSE(b.Build().status().ok()); } -XLA_TEST_F(MathTest, RoundToEvenRejectsComplexInputs) { +TEST_F(MathTest, RoundToEvenRejectsComplexInputs) { XlaBuilder b(TestName()); auto x = ConstantR1>(&b, {{0, 0}}); RoundToEven(x); EXPECT_FALSE(b.Build().status().ok()); } -XLA_TEST_F(MathTest, BesselI0eFloat) { +TEST_F(MathTest, BesselI0eFloat) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, @@ -675,10 +670,10 @@ XLA_TEST_F(MathTest, BesselI0eFloat) { 0.100544127361, 0.0947062952128, 0.0897803118848}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, DISABLED_ON_TPU(BesselI0eDouble)) { +TEST_F(MathTest, DISABLED_ON_TPU(BesselI0eDouble)) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, @@ -708,10 +703,10 @@ XLA_TEST_F(MathTest, DISABLED_ON_TPU(BesselI0eDouble)) { 0.100544127361, 0.0947062952128, 0.0897803118848}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, BesselI1eFloat) { +TEST_F(MathTest, BesselI1eFloat) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, @@ -741,10 +736,10 @@ 
XLA_TEST_F(MathTest, BesselI1eFloat) { 0.0973496147565, 0.092036796872, 0.0875062221833}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, DISABLED_ON_TPU(BesselI1eDouble)) { +TEST_F(MathTest, DISABLED_ON_TPU(BesselI1eDouble)) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, @@ -774,10 +769,10 @@ XLA_TEST_F(MathTest, DISABLED_ON_TPU(BesselI1eDouble)) { 0.0973496147565, 0.092036796872, 0.0875062221833}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, AcosComplexValues) { +TEST_F(MathTest, AcosComplexValues) { XlaBuilder builder(TestName()); auto x = ConstantR1>( &builder, {{0, 0}, {0, 1}, {1, 1}, {0.8, 0.2}}); @@ -788,10 +783,10 @@ XLA_TEST_F(MathTest, AcosComplexValues) { {1.5707963267948966, -0.881373587019543}, {0.9045568943023814, -1.0612750619050357}, {0.7011246914497526, -0.30527648462436596}}; - ComputeAndCompareR1>(&builder, expected, {}, error_spec_); + ComputeAndCompareR1>(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(MathTest, ZetaF64) { +TEST_F(MathTest, ZetaF64) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {2.0}); auto q = ConstantR1(&builder, {1.0}); diff --git a/third_party/xla/xla/hlo/builder/lib/matrix_test.cc b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc index 9924a106834484..3a21e6240d8959 100644 --- a/third_party/xla/xla/hlo/builder/lib/matrix_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc @@ -34,14 +34,17 @@ limitations under the License. #include "xla/hlo/builder/lib/slicing.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" namespace xla { namespace { -class MatrixTest : public ClientLibraryTestBase { +class MatrixTest : public ClientLibraryTestRunnerMixin { protected: template void TestMatrixDiagonal(); @@ -66,7 +69,7 @@ class MatrixTest : public ClientLibraryTestBase { } }; -XLA_TEST_F(MatrixTest, Triangle) { +TEST_F(MatrixTest, Triangle) { XlaBuilder builder(TestName()); Array3D input(2, 3, 4); input.FillIota(0); @@ -77,10 +80,10 @@ XLA_TEST_F(MatrixTest, Triangle) { Array3D expected({{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}}, {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}}); - ComputeAndCompareR3(&builder, expected, {a_data.get()}); + ComputeAndCompareR3(&builder, expected, {&a_data}); } -XLA_TEST_F(MatrixTest, Symmetrize) { +TEST_F(MatrixTest, Symmetrize) { for (bool lower : {false, true}) { XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); @@ -90,8 +93,8 @@ XLA_TEST_F(MatrixTest, Symmetrize) { {4, 5, 6}, }; - XlaOp a; - auto a_data = CreateParameter(input, 0, "a", &builder, &a); + const Literal a_literal = LiteralUtil::CreateFromArray(input); + XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a"); Symmetrize(lower ? 
a : TransposeInMinorDims(a), /*lower=*/lower); Array expected = { @@ -100,11 +103,12 @@ XLA_TEST_F(MatrixTest, Symmetrize) { {4, 5, 6}, }; - ComputeAndCompare(&builder, expected, {a_data.get()}); + ComputeAndCompareLiteral(&builder, LiteralUtil::CreateFromArray(expected), + {&a_literal}); } } -XLA_TEST_F(MatrixTest, SymmetrizeComplex) { +TEST_F(MatrixTest, SymmetrizeComplex) { for (bool lower : {false, true}) { XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); @@ -114,8 +118,8 @@ XLA_TEST_F(MatrixTest, SymmetrizeComplex) { {complex64{4, 8}, complex64{5, 9}, complex64{6, nan}}, }; - XlaOp a; - auto a_data = CreateParameter(input, 0, "a", &builder, &a); + const Literal a_literal = LiteralUtil::CreateFromArray(input); + XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a"); Symmetrize(lower ? a : Conj(TransposeInMinorDims(a)), /*lower=*/lower); Array expected = { @@ -124,7 +128,8 @@ XLA_TEST_F(MatrixTest, SymmetrizeComplex) { {complex64{4, 8}, complex64{5, 9}, 6}, }; - ComputeAndCompare(&builder, expected, {a_data.get()}); + ComputeAndCompareLiteral(&builder, LiteralUtil::CreateFromArray(expected), + {&a_literal}); } } @@ -138,7 +143,7 @@ void MatrixTest::TestMatrixDiagonal() { auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); GetMatrixDiagonal(a, kv.first); - ComputeAndCompareR2(&builder, kv.second, {a_data.get()}); + ComputeAndCompareR2(&builder, kv.second, {&a_data}); } } @@ -158,25 +163,19 @@ void MatrixTest::TestSetMatrixDiagonal() { kv.first) - ScalarLike(b, 1); - ComputeAndCompareR2(&builder, kv.second, {a_data.get(), new_diag.get()}); + ComputeAndCompareR2(&builder, kv.second, {&a_data, &new_diag}); } } -XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S32) { - TestSetMatrixDiagonal(); -} -XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S64) { - TestSetMatrixDiagonal(); -} -XLA_TEST_F(MatrixTest, SetMatrixDiagonal_F32) { - TestSetMatrixDiagonal(); -} +TEST_F(MatrixTest, SetMatrixDiagonal_S32) { TestSetMatrixDiagonal(); } +TEST_F(MatrixTest, SetMatrixDiagonal_S64) { TestSetMatrixDiagonal(); } +TEST_F(MatrixTest, SetMatrixDiagonal_F32) { TestSetMatrixDiagonal(); } -XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } +TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } -XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } +TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } -XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } +TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } template void MatrixTest::TestMatrixDiagonal4D() { @@ -199,21 +198,15 @@ void MatrixTest::TestMatrixDiagonal4D() { auto a_data = CreateR4Parameter(input, 0, "a", &builder, &a); GetMatrixDiagonal(a, kv.first); - ComputeAndCompareR3(&builder, kv.second, {a_data.get()}); + ComputeAndCompareR3(&builder, kv.second, {&a_data}); } } -XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S32) { - TestMatrixDiagonal4D(); -} +TEST_F(MatrixTest, GetMatrixDiagonal4D_S32) { TestMatrixDiagonal4D(); } -XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S64) { - TestMatrixDiagonal4D(); -} +TEST_F(MatrixTest, GetMatrixDiagonal4D_S64) { TestMatrixDiagonal4D(); } -XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_F32) { - TestMatrixDiagonal4D(); -} +TEST_F(MatrixTest, GetMatrixDiagonal4D_F32) { TestMatrixDiagonal4D(); } Array3D BatchedAValsFull() { return {{ @@ -230,7 +223,7 @@ Array3D BatchedAValsFull() { }}; } -XLA_TEST_F(MatrixTest, RowBatchDot) { +TEST_F(MatrixTest, RowBatchDot) { XlaBuilder builder(TestName()); int n = 
4; @@ -247,10 +240,10 @@ XLA_TEST_F(MatrixTest, RowBatchDot) { BatchDot(l_index, TransposeInMinorDims(row)); ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, - {a_data.get(), row_data.get(), index_data.get()}); + {&a_data, &row_data, &index_data}); } -XLA_TEST_F(MatrixTest, Einsum) { +TEST_F(MatrixTest, Einsum) { XlaBuilder builder(TestName()); int n = 4; @@ -268,10 +261,10 @@ XLA_TEST_F(MatrixTest, Einsum) { Einsum(l_index, row, "abc,adc->abd"); ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, - {a_data.get(), row_data.get(), index_data.get()}); + {&a_data, &row_data, &index_data}); } -XLA_TEST_F(MatrixTest, ParseEinsumString) { +TEST_F(MatrixTest, ParseEinsumString) { auto to_vec = [](absl::string_view s) { std::vector v; v.reserve(s.size()); @@ -326,7 +319,7 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) { } } -XLA_TEST_F(MatrixTest, NormalizeEinsumString) { +TEST_F(MatrixTest, NormalizeEinsumString) { EXPECT_EQ(NormalizeEinsumString("a,b->ab"), ""); EXPECT_EQ(NormalizeEinsumString("ba"), "ba->ab"); EXPECT_EQ(NormalizeEinsumString("ab,dc"), "ab,dc->abcd"); diff --git a/third_party/xla/xla/hlo/builder/lib/pooling_test.cc b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc index 83ebbb50337fdb..f068938c9a20f0 100644 --- a/third_party/xla/xla/hlo/builder/lib/pooling_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc @@ -25,12 +25,15 @@ limitations under the License. #include "xla/hlo/builder/padding.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { +constexpr ErrorSpec kErrorSpec{0.0001}; + TensorFormat MakeNCHWFormat(int num_spatial_dims) { absl::InlinedVector spatial_dimensions; for (int i = 0; i < num_spatial_dims; ++i) { @@ -66,12 +69,9 @@ std::vector ExpandWithBatchAndFeatureDimensions( return tensor_sizes; } -class PoolingTest : public ClientLibraryTestBase { - public: - ErrorSpec error_spec_{0.0001}; -}; +using PoolingTest = ClientLibraryTestRunnerMixin; -XLA_TEST_F(PoolingTest, MaxPool2D) { +TEST_F(PoolingTest, MaxPool2D) { XlaBuilder builder(TestName()); XlaOp input = ConstantR4FromArray4D( @@ -81,10 +81,10 @@ XLA_TEST_F(PoolingTest, MaxPool2D) { auto stride = kernel_size; MaxPool(input, kernel_size, stride, Padding::kValid, data_format); - ComputeAndCompareR4(&builder, {{{{5, 4}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{5, 4}}}}, {}, kErrorSpec); } -XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) { +TEST_F(PoolingTest, MaxPool2DWithPadding) { XlaBuilder builder(TestName()); XlaOp input = ConstantR4FromArray4D( @@ -94,10 +94,10 @@ XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) { auto stride = kernel_size; MaxPool(input, kernel_size, stride, Padding::kSame, data_format); - ComputeAndCompareR4(&builder, {{{{5, 4, 5}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{5, 4, 5}}}}, {}, kErrorSpec); } -XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) { +TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) { XlaBuilder builder(TestName()); XlaOp input = ConstantR4FromArray4D( @@ -108,10 +108,10 @@ XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) { MaxPool(input, kernel_size, stride, Padding::kSame, data_format); ComputeAndCompareR4(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}}, - {}, error_spec_); + {}, kErrorSpec); } -XLA_TEST_F(PoolingTest, AvgPool2D) { +TEST_F(PoolingTest, 
AvgPool2D) { XlaBuilder builder(TestName()); XlaOp input = ConstantR4FromArray4D( @@ -124,10 +124,10 @@ XLA_TEST_F(PoolingTest, AvgPool2D) { AvgPool(input, kernel_size, stride, padding, data_format, /*counts_include_padding=*/true); - ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, kErrorSpec); } -XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) { +TEST_F(PoolingTest, AvgPool2DWithPadding) { XlaBuilder builder(TestName()); XlaOp input = ConstantR4FromArray4D( @@ -140,10 +140,10 @@ XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) { AvgPool(input, kernel_size, stride, padding, data_format, /*counts_include_padding=*/false); - ComputeAndCompareR4(&builder, {{{{3, 3, 3}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{3, 3, 3}}}}, {}, kErrorSpec); } -XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) { +TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) { XlaBuilder builder(TestName()); XlaOp input = ConstantR4FromArray4D( @@ -156,12 +156,11 @@ XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) { AvgPool(input, kernel_size, stride, padding, data_format, /*counts_include_padding=*/false); - ComputeAndCompareR4(&builder, - {{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {}, - error_spec_); + ComputeAndCompareR4( + &builder, {{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {}, kErrorSpec); } -XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) { +TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) { XlaBuilder builder(TestName()); XlaOp input = ConstantR4FromArray4D( @@ -172,11 +171,11 @@ XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) { AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format, /*counts_include_padding=*/false); - ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, kErrorSpec); } -XLA_TEST_F(PoolingTest, - AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) { +TEST_F(PoolingTest, + AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) { XlaBuilder builder(TestName()); XlaOp input = ConstantR4FromArray4D( @@ -188,10 +187,10 @@ XLA_TEST_F(PoolingTest, /*counts_include_padding=*/false); ComputeAndCompareR4(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {}, - error_spec_); + kErrorSpec); } -XLA_TEST_F(PoolingTest, AvgPool2DGradNoPadding) { +TEST_F(PoolingTest, AvgPool2DGradNoPadding) { XlaBuilder builder(TestName()); for (bool counts_include_padding : {false, true}) { XlaOp out_backprop = ConstantR4FromArray4D(&builder, {{{{1.}}}}); @@ -204,11 +203,11 @@ XLA_TEST_F(PoolingTest, AvgPool2DGradNoPadding) { // Without padding, counts_include_padding makes no difference. ComputeAndCompareR4( &builder, {{{{0.25, 0.25, 0.}, {0.25, 0.25, 0.}, {0., 0., 0.}}}}, {}, - error_spec_); + kErrorSpec); } } -XLA_TEST_F(PoolingTest, AvgPool2DGradNoPaddingWithStride) { +TEST_F(PoolingTest, AvgPool2DGradNoPaddingWithStride) { XlaBuilder builder(TestName()); for (bool counts_include_padding : {false, true}) { XlaOp out_backprop = @@ -222,11 +221,11 @@ XLA_TEST_F(PoolingTest, AvgPool2DGradNoPaddingWithStride) { // Without padding, counts_include_padding makes no difference. 
ComputeAndCompareR4( &builder, {{{{0.25, 0.5, 0.25}, {0.5, 1., 0.5}, {0.25, 0.5, 0.25}}}}, - {}, error_spec_); + {}, kErrorSpec); } } -XLA_TEST_F(PoolingTest, AvgPool2DGradWithPadding) { +TEST_F(PoolingTest, AvgPool2DGradWithPadding) { XlaBuilder builder(TestName()); XlaOp out_backprop = @@ -240,10 +239,10 @@ XLA_TEST_F(PoolingTest, AvgPool2DGradWithPadding) { ComputeAndCompareR4( &builder, {{{{0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}}}}, {}, - error_spec_); + kErrorSpec); } -XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountNotIncludePadding) { +TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountNotIncludePadding) { XlaBuilder builder(TestName()); XlaOp out_backprop = @@ -255,10 +254,10 @@ XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountNotIncludePadding) { MakeNCHWFormat(2), false); ComputeAndCompareR4( &builder, {{{{1., 0.5, 0.5}, {0.5, 0.25, 0.25}, {0.5, 0.25, 0.25}}}}, {}, - error_spec_); + kErrorSpec); } -XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountWithStride) { +TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountWithStride) { XlaBuilder builder(TestName()); XlaOp out_backprop = @@ -271,13 +270,11 @@ XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountWithStride) { auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}}, MakeNCHWFormat(2), true); - ComputeAndCompareR4(&builder, - {{{{1., 1., 1.}, {1., 1., 1.}, {1., 1., 1.}}}}, {}, - error_spec_); + ComputeAndCompareR4( + &builder, {{{{1., 1., 1.}, {1., 1., 1.}, {1., 1., 1.}}}}, {}, kErrorSpec); } -XLA_TEST_F(PoolingTest, - AvgPool2DGradWithPaddingCountWithStrideNotIncludePadding) { +TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountWithStrideNotIncludePadding) { XlaBuilder builder(TestName()); XlaOp out_backprop = @@ -292,7 +289,7 @@ XLA_TEST_F(PoolingTest, MakeNCHWFormat(2), false); ComputeAndCompareR4( &builder, {{{{2.25, 1.5, 2.25}, {1.5, 1., 1.5}, {2.25, 1.5, 2.25}}}}, {}, - error_spec_); + kErrorSpec); } } // namespace diff --git a/third_party/xla/xla/hlo/builder/lib/prng_test.cc b/third_party/xla/xla/hlo/builder/lib/prng_test.cc index ae180736a541f5..6117d49524aeb4 100644 --- a/third_party/xla/xla/hlo/builder/lib/prng_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng_test.cc @@ -25,14 +25,15 @@ limitations under the License. #include "xla/hlo/testlib/test.h" #include "xla/primitive_util.h" #include "xla/shape.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" namespace xla { namespace { -class PrngTest : public ClientLibraryTestBase { +class PrngTest : public ClientLibraryTestRunnerMixin { public: template +#include #include #include @@ -32,19 +33,23 @@ limitations under the License. 
#include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla { -class SelfAdjointEigTest : public ClientLibraryTestBase { +class SelfAdjointEigTest : public ClientLibraryTestRunnerMixin { protected: void SetUp() override { - ClientLibraryTestBase::SetUp(); + ClientLibraryTestRunnerMixin::SetUp(); batch_3d_4x4_ = Array3D{ { {4, 6, 8, 10}, @@ -78,7 +83,6 @@ class SelfAdjointEigTest : public ClientLibraryTestBase { {3, 9, 11, 17}, }; } - void TearDown() override { ClientLibraryTestBase::TearDown(); } Array3D GetUnitMatrix3D(const Array3D& matrix) { Array3D result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0); @@ -138,7 +142,7 @@ XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { PrecisionConfig::HIGHEST); } -XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { +TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { for (bool sort_eigenvalues : {false, true}) { XlaBuilder builder(TestName()); @@ -148,28 +152,28 @@ XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { /*tol=*/1e-5, sort_eigenvalues); ComputeMatmulVWVt(result, &builder); - ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ComputeAndCompareR3(&builder, batch_3d_4x4_, {&a_data}, ErrorSpec(1e-3, 1e-3)); } } -XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_3x3_Complex) { +TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_3x3_Complex) { XlaBuilder builder(TestName()); Array input = { {1, complex64{2, -7}, complex64{4, -8}}, {complex64{2, 7}, 3, complex64{5, -9}}, {complex64{4, 8}, complex64{5, 9}, 6}, }; - XlaOp a; - auto a_data = CreateParameter(input, 0, "a", &builder, &a); + const Literal a_literal = LiteralUtil::CreateFromArray(input); + XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a"); auto result = SelfAdjointEig(a); ComputeMatmulVWVt(result, &builder); - ComputeAndCompare(&builder, input, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder, LiteralUtil::CreateFromArray(input), + {&a_literal}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { +TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; @@ -178,11 +182,11 @@ XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { auto result = SelfAdjointEig(a); ComputeMatmulVWVt(result, &builder); - ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ComputeAndCompareR3(&builder, batch_3d_4x4_, {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) { +TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; @@ -191,11 +195,11 @@ XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) { auto result = SelfAdjointEig(a, false); ComputeMatmulVWVt(result, &builder); - ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ComputeAndCompareR3(&builder, batch_3d_4x4_, {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) { +TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; @@ -204,10 +208,10 @@ XLA_TEST_F(SelfAdjointEigTest, 
Test_Orthogonality_2x4x4) { BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); ComputeAndCompareR3(&builder, GetUnitMatrix3D(batch_3d_4x4_), - {a_data.get()}, ErrorSpec(1e-3, 1e-3)); + {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { +TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { XlaBuilder builder(TestName()); XlaOp a; @@ -215,11 +219,11 @@ XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { auto result = SelfAdjointEig(a); ComputeMatmulVWVt(result, &builder); - ComputeAndCompareR2(&builder, low_rank_4x4_, {a_data.get()}, + ComputeAndCompareR2(&builder, low_rank_4x4_, {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) { +TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) { XlaBuilder builder(TestName()); // This is computed by numpy.linalg.eigh with float32. @@ -231,11 +235,11 @@ XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) { auto result = SelfAdjointEig(a); Add(result.w, ZerosLike(result.w)); - ComputeAndCompareR1(&builder, expected, {a_data.get()}, + ComputeAndCompareR1(&builder, expected, {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) { +TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) { XlaBuilder builder(TestName()); float expected_vals = 1e-3; @@ -248,11 +252,11 @@ XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) { BatchDot(TransposeInMinorDims(result.v), result.v), &builder); - ComputeAndCompareR0(&builder, expected_vals, {a_data.get()}, + ComputeAndCompareR0(&builder, expected_vals, {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) { +TEST_F(SelfAdjointEigTest, Wrong_Type_Int) { XlaBuilder builder(TestName()); XlaOp a; @@ -276,7 +280,7 @@ Array2D GenerateRandomSymmetricMatrix(int size) { } using EighTestCase = int64_t; -class RandomEighTest : public ClientLibraryTestBase, +class RandomEighTest : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface {}; XLA_TEST_P(RandomEighTest, Random) { @@ -290,7 +294,7 @@ XLA_TEST_P(RandomEighTest, Random) { // TODO(phawkins): this would be better expressed as <= 6e-3. double kExpected = 0.00300000003; - ComputeAndCompareR0(&builder, kExpected, {a_data.get()}, + ComputeAndCompareR0(&builder, kExpected, {&a_data}, ErrorSpec(kExpected, 0)); } diff --git a/third_party/xla/xla/hlo/builder/lib/svd_test.cc b/third_party/xla/xla/hlo/builder/lib/svd_test.cc index 0bb0a63a011fe4..28b7c0172e8184 100644 --- a/third_party/xla/xla/hlo/builder/lib/svd_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd_test.cc @@ -30,16 +30,18 @@ limitations under the License. 
#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" namespace xla { -class SVDTest : public ClientLibraryTestBase { +class SVDTest : public ClientLibraryTestRunnerMixin { protected: void SetUp() override { - ClientLibraryTestBase::SetUp(); + ClientLibraryTestRunnerMixin::SetUp(); batch_3d_4x5_ = Array3D{ { {4, 6, 8, 10, 1}, @@ -55,7 +57,6 @@ class SVDTest : public ClientLibraryTestBase { }, }; } - void TearDown() override { ClientLibraryTestBase::TearDown(); } Array3D GetUnitMatrix3D(int32_t batch_dim, int32_t mat_dim) { Array3D result(batch_dim, mat_dim, mat_dim, 0.0); @@ -112,7 +113,7 @@ class SVDTest : public ClientLibraryTestBase { Array3D batch_3d_4x5_; }; -XLA_TEST_F(SVDTest, Simple2D) { +TEST_F(SVDTest, Simple2D) { XlaBuilder builder(TestName()); Array2D simple_2d_4x4_ = Array2D{ @@ -126,11 +127,11 @@ XLA_TEST_F(SVDTest, Simple2D) { auto result = SVD(a, 100, 1e-6); ComputeMatmulUDVT(result, &builder); - ComputeAndCompareR2(&builder, simple_2d_4x4_, {a_data.get()}, + ComputeAndCompareR2(&builder, simple_2d_4x4_, {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) { +TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) { XlaBuilder builder(TestName()); XlaOp a; @@ -138,11 +139,11 @@ XLA_TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) { auto result = SVD(a, 100, 1e-8); ComputeMatmulUDVT(result, &builder); - ComputeAndCompareR3(&builder, batch_3d_4x5_, {a_data.get()}, + ComputeAndCompareR3(&builder, batch_3d_4x5_, {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Test_Orthogonality_U) { +TEST_F(SVDTest, Test_Orthogonality_U) { XlaBuilder builder(TestName()); XlaOp a; @@ -151,11 +152,11 @@ XLA_TEST_F(SVDTest, Test_Orthogonality_U) { ComputeMatmulUDVT(result, &builder); BatchDot(result.u, TransposeInMinorDims(result.u)); - ComputeAndCompareR3(&builder, GetUnitMatrix3D(2, 4), {a_data.get()}, + ComputeAndCompareR3(&builder, GetUnitMatrix3D(2, 4), {&a_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(SVDTest, Test_Orthogonality_V) { +TEST_F(SVDTest, Test_Orthogonality_V) { XlaBuilder builder(TestName()); XlaOp a; @@ -163,11 +164,11 @@ XLA_TEST_F(SVDTest, Test_Orthogonality_V) { auto result = SVD(a, 100, 1e-8); BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); - ComputeAndCompareR3(&builder, GetUnitMatrix3D(2, 5), {a_data.get()}, + ComputeAndCompareR3(&builder, GetUnitMatrix3D(2, 5), {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) { +TEST_F(SVDTest, TestSingleValuesMatchNumpy) { XlaBuilder builder(TestName()); auto singular_values = Array2D{ @@ -180,13 +181,12 @@ XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) { auto result = SVD(a, 100, 1e-8); Add(result.d, ZerosLike(result.d)); - ComputeAndCompareR2(&builder, singular_values, {a_data.get()}, + ComputeAndCompareR2(&builder, singular_values, {&a_data}, ErrorSpec(1e-3, 1e-3)); } // Too slow on the interpreter backend. 
-XLA_TEST_F(SVDTest, - DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_512x128)) { +TEST_F(SVDTest, DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_512x128)) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(512, 128); XlaOp a; @@ -194,11 +194,10 @@ XLA_TEST_F(SVDTest, auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareR0(&builder, 1e-3, {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) { +TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(128, 256); XlaOp a; @@ -206,11 +205,10 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) { auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareR0(&builder, 1e-3, {&a_data}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) { +TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(256, 128); XlaOp a; @@ -218,13 +216,11 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) { auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareR0(&builder, 1e-3, {&a_data}, ErrorSpec(1e-3, 1e-3)); } // Too slow on the interpreter backend. -XLA_TEST_F(SVDTest, - DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_128x512)) { +TEST_F(SVDTest, DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_128x512)) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(128, 512); XlaOp a; @@ -232,13 +228,12 @@ XLA_TEST_F(SVDTest, auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareR0(&builder, 1e-3, {&a_data}, ErrorSpec(1e-3, 1e-3)); } // Too slow on the interpreter and CPU backends. -XLA_TEST_F(SVDTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( - Various_Size_Random_Matrix_512x256))) { +TEST_F(SVDTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( + Various_Size_Random_Matrix_512x256))) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(512, 256); XlaOp a; @@ -246,13 +241,12 @@ XLA_TEST_F(SVDTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareR0(&builder, 1e-3, {&a_data}, ErrorSpec(1e-3, 1e-3)); } // Too slow on the CPU, GPU and interpreter backends. 
-XLA_TEST_F(SVDTest, DISABLED_ON_GPU(DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( - Various_Size_Random_Matrix_512x512)))) { +TEST_F(SVDTest, DISABLED_ON_GPU(DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( + Various_Size_Random_Matrix_512x512)))) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(512, 512); XlaOp a; @@ -260,8 +254,7 @@ XLA_TEST_F(SVDTest, DISABLED_ON_GPU(DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareR0(&builder, 1e-3, {&a_data}, ErrorSpec(1e-3, 1e-3)); } } // namespace xla diff --git a/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc index f1410530108a5b..76795ba29cb1bf 100644 --- a/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc @@ -21,24 +21,25 @@ limitations under the License. #include #include "absl/status/status.h" -#include "xla/array.h" #include "xla/array3d.h" #include "xla/hlo/builder/lib/slicing.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" namespace xla { namespace tridiagonal { namespace { class TridiagonalTest - : public ClientLibraryTestBase, + : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface> {}; XLA_TEST_P(TridiagonalTest, SimpleTridiagonalMatMulOk) { @@ -75,10 +76,10 @@ XLA_TEST_P(TridiagonalTest, SimpleTridiagonalMatMulOk) { std::vector expected_values{191, 246, 301, 356, 435, 502, 569, 636, 707, 830, 953, 1076}; TF_ASSERT_OK_AND_ASSIGN( - auto result, - ComputeAndTransfer(x.builder(), - {upper_diagonal_data.get(), main_diagonal_data.get(), - lower_diagonal_data.get(), rhs_data.get()})); + const Literal result, + ExecuteAndTransfer(x.builder(), + {&upper_diagonal_data, &main_diagonal_data, + &lower_diagonal_data, &rhs_data})); EXPECT_EQ(result.shape().dimensions(), expected_shape); EXPECT_EQ(result.data({}), expected_values); } @@ -86,23 +87,14 @@ XLA_TEST_P(TridiagonalTest, SimpleTridiagonalMatMulOk) { XLA_TEST_P(TridiagonalTest, TridiagonalMatMulWrongShape) { xla::XlaBuilder builder(TestName()); - Array upper_diagonal = Array({5, 3, 7}, 1); - Array main_diagonal = Array({5, 3, 7}, 1); - Array lower_diagonal = Array({5, 3, 7}, 1); - Array rhs = Array({5, 3, 7, 6}, 1); - - XlaOp upper_diagonal_xla; - XlaOp main_diagonal_xla; - XlaOp lower_diagonal_xla; - XlaOp rhs_xla; - - auto upper_diagonal_data = CreateParameter( - upper_diagonal, 0, "upper_diagonal", &builder, &upper_diagonal_xla); - auto main_diagonal_data = CreateParameter( - main_diagonal, 1, "main_diagonal", &builder, &main_diagonal_xla); - auto lower_diagonal_data = CreateParameter( - lower_diagonal, 2, "lower_diagonal", &builder, &lower_diagonal_xla); - auto rhs_data = CreateParameter(rhs, 3, "rhs", &builder, &rhs_xla); + XlaOp upper_diagonal_xla = Parameter( + &builder, 0, ShapeUtil::MakeShape(F32, {5, 3, 7}), "upper_diagonal"); + XlaOp main_diagonal_xla = Parameter( + &builder, 1, ShapeUtil::MakeShape(F32, {5, 3, 7}), "main_diagonal"); + 
XlaOp lower_diagonal_xla = Parameter( + &builder, 2, ShapeUtil::MakeShape(F32, {5, 3, 7}), "lower_diagonal"); + XlaOp rhs_xla = + Parameter(&builder, 3, ShapeUtil::MakeShape(F32, {5, 3, 7, 6}), "rhs"); auto result = TridiagonalMatMul(upper_diagonal_xla, main_diagonal_xla, lower_diagonal_xla, rhs_xla); @@ -177,13 +169,11 @@ XLA_TEST_P(TridiagonalTest, Solves) { Abs(ConcatInDim(&builder, relative_errors, 2)); TF_ASSERT_OK_AND_ASSIGN( - auto result, - ComputeAndTransfer(&builder, - {lower_diagonal_data.get(), main_diagonal_data.get(), - upper_diagonal_data.get(), rhs_data.get()})); + const Literal result, + ExecuteAndTransfer(&builder, {&lower_diagonal_data, &main_diagonal_data, + &upper_diagonal_data, &rhs_data})); - auto result_data = result.data({}); - for (auto result_component : result_data) { + for (const float result_component : result.data({})) { EXPECT_TRUE(result_component < 5e-3); } } diff --git a/third_party/xla/xla/hlo/builder/lib/tuple_test.cc b/third_party/xla/xla/hlo/builder/lib/tuple_test.cc index 67f270300acce4..c81ceb4d834905 100644 --- a/third_party/xla/xla/hlo/builder/lib/tuple_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/tuple_test.cc @@ -16,28 +16,28 @@ limitations under the License. #include "xla/hlo/builder/lib/tuple.h" #include -#include #include #include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/service/service.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { -class TupleTest : public ClientLibraryTestBase {}; +using TupleTest = ClientLibraryTestRunnerMixin; -XLA_TEST_F(TupleTest, DisassembleAssemble) { +TEST_F(TupleTest, DisassembleAssemble) { XlaBuilder builder(TestName()); Shape shape = ShapeUtil::MakeTupleShape({ @@ -70,18 +70,13 @@ XLA_TEST_F(TupleTest, DisassembleAssemble) { }); AssembleTuple(&builder, std::move(disassembled_tuple)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr data, - client_->TransferToServer(input)); - - Literal expected = LiteralUtil::MakeTupleOwned( + const Literal expected = LiteralUtil::MakeTupleOwned( LiteralUtil::CreateFullWithDescendingLayout({3}, int32_t{43}), LiteralUtil::MakeTupleOwned( LiteralUtil::CreateFullWithDescendingLayout({4}, int32_t{45}), LiteralUtil::CreateFullWithDescendingLayout({5}, int32_t{47})), LiteralUtil::CreateFullWithDescendingLayout({6}, int32_t{49})); - - ComputeAndCompareLiteral(&builder, expected, {data.get()}, ErrorSpec(0), - &shape); + ComputeAndCompareLiteral(&builder, expected, {&input}, ErrorSpec(0), &shape); } } // namespace From 2095cf4b7c6f6ebf10feade0a7d41edeb7494bdf Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 19:15:43 -0700 Subject: [PATCH 0772/1324] Fix an bug that incorrectly removes send-done instruction PiperOrigin-RevId: 747658901 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 5 +++++ third_party/xla/xla/hlo/ir/hlo_instruction.h | 3 +++ 2 files changed, 8 insertions(+) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index ed93d31470e51d..e258a26dc22c19 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -2993,6 +2993,11 @@ bool HloInstruction::HasControlDependencies() const { return (!r->control_predecessors.empty() || !r->control_successors.empty()); } +bool HloInstruction::HasSuccessorControlDependencies() const { + const Rare* r = rare(); + return (!r->control_successors.empty()); +} + absl::Status HloInstruction::CopyAllControlDepsTo(HloInstruction* start, HloInstruction* end) const { for (auto* ctrl_pred : control_predecessors()) { diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 12e1804a69e1d6..25069c37d4642b 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1268,6 +1268,9 @@ class HloInstruction { // Returns if instruction has any control dependencies. bool HasControlDependencies() const; + // Returns if instruction has successor control dependencies. + bool HasSuccessorControlDependencies() const; + // Copies the control predecessors and successors on this HLO instruction to // `inst`. Does not do a deep copy so this makes sense only if `inst` and // this HLO are in the same module. From cbd32ac9087f68b1d155732cf9eff2724869d207 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 20:48:07 -0700 Subject: [PATCH 0773/1324] Small refactor of annotation tracker PiperOrigin-RevId: 747687358 --- .../xla/service/latency_hiding_scheduler.cc | 53 ++++++++++++----- .../xla/service/latency_hiding_scheduler.h | 58 ++++++------------- 2 files changed, 58 insertions(+), 53 deletions(-) diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index 88ab9546cef438..cba625a5a0d43e 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -1831,19 +1831,18 @@ absl::StatusOr DefaultSchedulerCore::ScheduleNode( if (!annotations.empty()) { VLOG(2) << "Scheduled node is a frontier: " << n->GetInstr().name(); for (int64_t annotation : annotations) { - sched_state->num_scheduled_successors_for_annotation[annotation]++; - VLOG(2) - << "Annotation: " << annotation << " scheduled num successors: " - << sched_state->num_scheduled_successors_for_annotation[annotation] - << " total num successors: " - << annotation_tracker_->GetNumSuccessors(n->GetInstr().parent(), - annotation); + DefaultSchedulerCore::SchedulingState::NumSuccessorsForAnnotation& + num_successors_for_annotation = + sched_state->num_successors_for_annotation[annotation]; + num_successors_for_annotation.scheduled++; + VLOG(2) << "Annotation: " << annotation << " scheduled num successors: " + << num_successors_for_annotation.scheduled + << " total num successors: " << num_successors_for_annotation.all; // LegalizeSchedulingAnnotations pass should have made sure that we will // eventually reach a state where all successors of the annotation are // scheduled. 
- if (annotation_tracker_->GetNumSuccessors(n->GetInstr().parent(), - annotation) == - sched_state->num_scheduled_successors_for_annotation[annotation]) { + if (num_successors_for_annotation.scheduled == + num_successors_for_annotation.all) { sched_state->ready_annotations.push_back(annotation); } } @@ -2468,6 +2467,28 @@ DefaultSchedulerCore::GetNumResourcesNeededForAnnotation( return num_resources_needed; } +int64_t DefaultSchedulerCore::GetNumSuccessorsForAnnotation( + const SchedulingState& sched_state, int64_t annotation) const { + const HloComputation* comp = + sched_state.sched_graph.GetOriginalInstrList()[0]->parent(); + int64_t num_successors = 0; + std::vector instrs = + annotation_tracker_->GetInstructions(comp, annotation); + absl::flat_hash_set seen_instrs(instrs.begin(), + instrs.end()); + for (const HloInstruction* instr : instrs) { + for (const HloEdge& edge : + sched_state.sched_graph.GetNode(instr).GetSuccessors()) { + const HloGraphNode& user = edge.Target(); + if (seen_instrs.insert(&user.GetInstr()).second && + (user.GetAnnotation() != annotation)) { + ++num_successors; + } + } + } + return num_successors; +} + bool DefaultSchedulerCore::SchedulingAnnotationCrossesOverlapLimit( const SchedulingState& sched_state, int64_t annotation) { absl::flat_hash_map num_resources_needed = @@ -2496,19 +2517,25 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) { latency_estimator_, async_tracker_, &memory_pressure_tracker, config_); async_tracker_->PostProcessScheduleGraph(&sched_state.sched_graph, latency_estimator_); + sched_state.sched_graph.InitializeGraphAnalysis(async_tracker_); + VLOG(5) << "Just built graph:"; + if (annotation_tracker_->HasAnnotations(computation)) { sched_state.sched_graph.AnnotateGraph(annotation_tracker_.get()); for (int64_t annotation : annotation_tracker_->GetAnnotations(computation)) { - if (annotation_tracker_->GetSuccessors(computation, annotation).empty()) { + int64_t num_successors = + GetNumSuccessorsForAnnotation(sched_state, annotation); + sched_state.num_successors_for_annotation[annotation].all = + num_successors; + if (num_successors == 0) { VLOG(3) << "Annotation " << annotation << " does not have any successors, is ready to be scheduled"; sched_state.ready_annotations.push_back(annotation); } } } - sched_state.sched_graph.InitializeGraphAnalysis(async_tracker_); - VLOG(5) << "Just built graph:"; + XLA_VLOG_LINES(5, sched_state.sched_graph.ToString(async_tracker_)); async_tracker_->SetConcurrentResourceLimits( sched_state.max_concurrent_resource); diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index 9f0cb987d78e46..3687d46fcf0723 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -56,6 +56,8 @@ limitations under the License. 
namespace xla { +inline constexpr int64_t kInvalidAnnotation = -1; + struct CanonicalAsyncOp { HloOpcode outer; // kAsyncStart or kAsyncDone HloOpcode inner; // kAllReduce, kAllGather, kAllToAll, kCollectiveBroadcast, @@ -379,46 +381,15 @@ class AnnotationTracker { } std::vector GetInstructions( const HloComputation* comp, const int64_t annotation) const { + if (annotation == kInvalidAnnotation) { + return {}; + } return annotations_.at(annotation).at(comp); } int64_t GetNumInstructions(const HloComputation* comp, const int64_t annotation) { return annotations_[annotation][comp].size(); } - void FindSuccessors(const HloComputation* comp, const int64_t annotation) { - absl::flat_hash_set seen_instructions( - annotations_[annotation][comp].begin(), - annotations_[annotation][comp].end()); - for (const HloInstruction* instr : annotations_.at(annotation).at(comp)) { - for (const PtrVec& users : - {instr->users(), instr->control_successors()}) { - for (HloInstruction* user : users) { - if (!seen_instructions.contains(user) && - (GetAnnotation(user) == std::nullopt || - GetAnnotation(user).value() != annotation)) { - annotation_successors_[annotation][comp].push_back(user); - VLOG(3) << "Annotation group: " << annotation - << ", successor: " << user->name(); - } - seen_instructions.insert(user); - } - } - } - } - int64_t GetNumSuccessors(const HloComputation* comp, - const int64_t annotation) { - if (!annotation_successors_[annotation].contains(comp)) { - FindSuccessors(comp, annotation); - } - return annotation_successors_[annotation][comp].size(); - } - std::vector GetSuccessors(const HloComputation* comp, - const int64_t annotation) { - if (!annotation_successors_[annotation].contains(comp)) { - FindSuccessors(comp, annotation); - } - return annotation_successors_[annotation][comp]; - } void PrintAnnotationSets(int64_t level) const { for (const auto& [annotation, comp_instr_vector] : annotations_) { for (const auto& [comp, instrs] : comp_instr_vector) { @@ -664,7 +635,7 @@ class HloGraphNode { int64_t GetOriginalPosition() const { return original_position_; } int64_t GetAnnotation() const { return annotation_; } absl::Status SetAnnotation(int64_t annotation) { - TF_RET_CHECK(annotation_ == -1) + TF_RET_CHECK(annotation_ == kInvalidAnnotation) << "Instruction " << instr_->name() << " has an existing annotation: " << annotation_; annotation_ = annotation; @@ -752,7 +723,7 @@ class HloGraphNode { // Nums hops to closest selective resource occupier. int64_t num_hops_to_closest_selective_resource_occupier_ = std::numeric_limits::max(); - int64_t annotation_ = -1; + int64_t annotation_ = kInvalidAnnotation; }; // Schedule graph that can be used to drive scheduling @@ -1090,15 +1061,20 @@ class DefaultSchedulerCore : public SchedulerCore { std::vector selective_resource_releasers; // Similar to ready set, but only contains the no-op instructions. ReadyQueueSet nop_set; - // Number of scheduled nodes that are a successor for the given annotation. - absl::flat_hash_map - num_scheduled_successors_for_annotation; + // Number of {scheduled, all} nodes that are a successor for the given + // annotation. + struct NumSuccessorsForAnnotation { + int64_t scheduled = 0; + int64_t all = 0; + }; + absl::flat_hash_map + num_successors_for_annotation; // List of annotations that are ready to be scheduled. absl::InlinedVector ready_annotations; // List of annotated nodes that are ready to be scheduled. ReadyQueueSet annotation_ready; // Annotation that is currently being scheduled. 
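// It holds kInvalidAnnotation (defined above) whenever no annotation is
// currently in flight.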
-    int64_t ongoing_annotation = -1;
+    int64_t ongoing_annotation = kInvalidAnnotation;
    // Reference to this scheduler run configuration.
    const SchedulerConfig& config;
    SchedulingState(const HloInstructionSequence* instr_sequence,
@@ -1173,6 +1149,8 @@ class DefaultSchedulerCore : public SchedulerCore {
      const SchedulingState& sched_state, int64_t annotation);
  absl::flat_hash_map GetNumResourcesNeededForAnnotation(
      const SchedulingState& sched_state, int64_t annotation);
+  int64_t GetNumSuccessorsForAnnotation(const SchedulingState& sched_state,
+                                        int64_t annotation) const;
  ScheduleProto::ComputationScheduleProto ComputationScheduleToProto(
      const HloComputation* computation, const HloScheduleGraph& schedule_graph,

From fe81783f0996d6e6d12209af5415a9e298e5ba46 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 14 Apr 2025 21:31:34 -0700
Subject: [PATCH 0774/1324] Prevent calling `dimensions()` or `dimensions_size()` on non-array shapes.

It's a bug to use an array shape as a tuple or use a tuple shape as an array, for example. To catch such bugs, we make these methods fail if called on non-array shapes.

PiperOrigin-RevId: 747700020
---
 .../xla/service/gpu/model/triton_emitter_constraints.cc | 8 +++++---
 third_party/xla/xla/shape.h | 7 +------
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc
index 50c5726b535f92..c893575fc5443e 100644
--- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc
+++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc
@@ -182,9 +182,11 @@ TritonEmitterConstraints::GetBuilder(
            const HloInstructionAdaptor& instr) {
          return &instr.instruction() == tiled_hlo_instruction->hlo();
        })) {
-      root_infos.push_back(RootTileInfo{
-          tiled_hlo_instruction->symbolic_tile().size_map(),
-          SpanToVector(tiled_hlo_instruction->hlo()->shape().dimensions())});
+      const auto& shape = tiled_hlo_instruction->hlo()->shape();
+      root_infos.push_back(
+          RootTileInfo{tiled_hlo_instruction->symbolic_tile().size_map(),
+                       shape.IsArray() ? SpanToVector(shape.dimensions())
+                                       : std::vector()});
    }
  }

diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h
index ee8a84cf293b20..0275d8a0218b7f 100644
--- a/third_party/xla/xla/shape.h
+++ b/third_party/xla/xla/shape.h
@@ -25,7 +25,6 @@ limitations under the License.
 #include
 #include

-#include "absl/base/attributes.h"
 #include "absl/base/macros.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/log/check.h"
@@ -294,11 +293,7 @@ class Shape {
  // Returns a span to indicate the size of each dimension.
  // Precondition: this is an array shape.
  absl::Span dimensions() const {
-    if (const auto* const state = if_array_state()) {
-      return state->dimensions;
-    }
-    // TODO(b/404276923): ensure that this is never called on non-array shapes.
-    return {};
+    return array_state().dimensions;
  }
  absl::Span mutable_dimensions() {
    return absl::MakeSpan(array_state().dimensions);

From 1d6d2b78ba5511865fae0796221ec24a3863eb89 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 22:07:40 -0700
Subject: [PATCH 0775/1324] Run build_cleaner on BUILD file(s) located in /xla/python.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories, so I figured I'd run this for all of xla to fix the low-hanging fruit.
It resolves several unnecessary & missing dependencies and simplifies target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix:
* any conflicts that need manual handling
* conflicts that need to choose between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median             Δ    1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median              Δ                1-pval
cpu:      3599.015s ±131.4s   +8.3s, +0.2%     0.03 (not significant)
memory:   4533MB ±2.3MB       +0.0MB, +0.0%    0.00 (not significant)
system:   582.305s ±9.1s      -11.9s, -2.0%    0.25 (not significant)
wall:     808.958s ±95.5s     -98.6s, -10.9%   0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/`, the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging.

PiperOrigin-RevId: 747709449
---
 third_party/xla/xla/python/BUILD | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD
index cc85f7b928294a..4176e7e28f92a9 100644
--- a/third_party/xla/xla/python/BUILD
+++ b/third_party/xla/xla/python/BUILD
@@ -315,19 +315,15 @@ cc_library(
    deps = [
        "//xla/backends/profiler:profiler_backends",
        "//xla/backends/profiler/cpu:python_tracer",
        "//xla/backends/profiler/plugin:plugin_tracer",
-        "//xla/backends/profiler/plugin:profiler_c_api_hdrs",
        "//xla/pjrt:exceptions",
        "//xla/pjrt:status_casters",
        "//xla/pjrt/c:pjrt_c_api_hdrs",
-        "//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
        "//xla/python/profiler:profile_data_lib",
        "//xla/tsl/platform:macros",
        "//xla/tsl/profiler/rpc:profiler_server_impl",
        "//xla/tsl/profiler/rpc/client:capture_profile",
        "//xla/tsl/profiler/rpc/client:profiler_client_impl",
        "@local_tsl//tsl/platform:protobuf",
-        "@local_tsl//tsl/profiler/lib:profiler_factory",
-        "@local_tsl//tsl/profiler/lib:profiler_interface",
        "@local_tsl//tsl/profiler/lib:profiler_session",
        "@local_tsl//tsl/profiler/lib:traceme",
        "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc",
@@ -384,7 +380,6 @@ xla_cc_test(
        "//xla/tsl/profiler/utils:file_system_utils",
        "//xla/tsl/profiler/utils:xplane_builder",
        "//xla/tsl/profiler/utils:xplane_schema",
-        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl",
        "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc",
    ],
@@ -415,7 +410,9 @@ cc_library(
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@nanobind",
-        # copybara:uncomment "//third_party/py/numpy:multiarray",
+        # copybara:uncomment_begin
+        # "//third_party/py/numpy:multiarray",  # build_cleaner: keep
+        # copybara:uncomment_end
        "@local_config_python//:python_headers",
        "//xla/tsl/python/lib/core:numpy",
    ],
@@ -493,7 +490,6 @@ xla_cc_test(
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:test",
        "//xla/tsl/platform:test_main",
        "@com_google_absl//absl/types:span",
-        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc",
        "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl",
    ],
)

From
5a9264b2b852416374a8381a696745ce2ddda6f2 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 22:07:56 -0700
Subject: [PATCH 0776/1324] Run build_cleaner on BUILD file(s) located in /xla/gpu/kernels.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories, so I figured I'd run this for all of xla to fix the low-hanging fruit. It resolves several unnecessary & missing dependencies and simplifies target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need to choose between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median             Δ    1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median              Δ                1-pval
cpu:      3599.015s ±131.4s   +8.3s, +0.2%     0.03 (not significant)
memory:   4533MB ±2.3MB       +0.0MB, +0.0%    0.00 (not significant)
system:   582.305s ±9.1s      -11.9s, -2.0%    0.25 (not significant)
wall:     808.958s ±95.5s     -98.6s, -10.9%   0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/`, the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging.

PiperOrigin-RevId: 747709504
---
 third_party/xla/xla/service/gpu/kernels/BUILD | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD
index def94e22faef1c..95eba50628a334 100644
--- a/third_party/xla/xla/service/gpu/kernels/BUILD
+++ b/third_party/xla/xla/service/gpu/kernels/BUILD
@@ -129,7 +129,6 @@ xla_test(
        "//xla:literal_util",
        "//xla:types",
        "//xla:xla_data_proto_cc",
-        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
        "//xla/stream_executor:device_description",
@@ -428,6 +427,7 @@ cuda_library(
    deps = [
        ":cutlass_gemm",
        "@cutlass_archive//:cutlass",
+        "@local_config_cuda//cuda:cuda_headers",
    ],
)

@@ -481,7 +481,6 @@ cuda_library(
    deps = [
        ":cutlass_gemm_adaptor",
        "@cutlass_archive//:cutlass",
-        "@local_config_cuda//cuda:cuda_headers",
    ],
)

@@ -496,7 +495,6 @@ cuda_library(
    deps = [
        ":cutlass_gemm_adaptor",
        "@cutlass_archive//:cutlass",
-        "@local_config_cuda//cuda:cuda_headers",
    ],
)

@@ -514,7 +512,6 @@ cuda_library(
    deps = [
        ":cutlass_gemm_adaptor",
        "@cutlass_archive//:cutlass",
-        "@local_config_cuda//cuda:cuda_headers",
    ],
)

@@ -532,7 +529,6 @@ cuda_library(
    deps = [
        ":cutlass_gemm_adaptor",
        "@cutlass_archive//:cutlass",
-        "@local_config_cuda//cuda:cuda_headers",
    ],
)

@@ -547,7 +543,6 @@ cuda_library(
    deps = [
        ":cutlass_gemm_adaptor",
        "@cutlass_archive//:cutlass",
-        "@local_config_cuda//cuda:cuda_headers",
    ],
)

@@ -584,7 +579,6 @@ cc_library(
        "//xla/stream_executor:kernel",
        "//xla/stream_executor:kernel_spec",
        "//xla/stream_executor:launch_dim",
-        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],

From
5d0154db59673793a532aed0d34d53f85a57e573 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 22:08:11 -0700
Subject: [PATCH 0777/1324] Run build_cleaner on BUILD file(s) located in /xla/tsl.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories, so I figured I'd run this for all of xla to fix the low-hanging fruit. It resolves several unnecessary & missing dependencies and simplifies target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need to choose between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.

Before:
```
metric    median             Δ    1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```

After:
```
metric    median              Δ                1-pval
cpu:      3599.015s ±131.4s   +8.3s, +0.2%     0.03 (not significant)
memory:   4533MB ±2.3MB       +0.0MB, +0.0%    0.00 (not significant)
system:   582.305s ±9.1s      -11.9s, -2.0%    0.25 (not significant)
wall:     808.958s ±95.5s     -98.6s, -10.9%   0.57 (not significant)
```

Overall, it has modest savings of ~1 minute of wall (physical) time. Since I've excluded some execution tests under `stream_executor/` and `service/`, the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run.

Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging.

PiperOrigin-RevId: 747709556
---
 .../xla/xla/tsl/framework/convolution/BUILD | 6 +-
 third_party/xla/xla/tsl/platform/BUILD | 100 +++++++++---------
 2 files changed, 53 insertions(+), 53 deletions(-)

diff --git a/third_party/xla/xla/tsl/framework/convolution/BUILD b/third_party/xla/xla/tsl/framework/convolution/BUILD
index 80cab0e4d97eab..9687b8a9fd13ed 100644
--- a/third_party/xla/xla/tsl/framework/convolution/BUILD
+++ b/third_party/xla/xla/tsl/framework/convolution/BUILD
@@ -21,7 +21,7 @@ cc_library(
    ],
    compatible_with = get_compatible_with_portable(),
    deps = [
-        "//xla/tsl/framework/convolution:eigen_convolution_helpers",
+        ":eigen_convolution_helpers",
    ],
)

@@ -45,9 +45,9 @@ cc_library(
    compatible_with = get_compatible_with_portable(),
    defines = ["EIGEN_NEON_GEBP_NR=4"],
    deps = [
+        ":eigen_convolution_helpers",
+        ":eigen_spatial_convolutions-inl",
        "//xla/tsl/framework/contraction:eigen_contraction_kernel",
-        "//xla/tsl/framework/convolution:eigen_convolution_helpers",
-        "//xla/tsl/framework/convolution:eigen_spatial_convolutions-inl",
        "@eigen_archive//:eigen3",
    ],
)

diff --git a/third_party/xla/xla/tsl/platform/BUILD b/third_party/xla/xla/tsl/platform/BUILD
index db27bed86811cd..16c69ae2a1a28a 100644
--- a/third_party/xla/xla/tsl/platform/BUILD
+++ b/third_party/xla/xla/tsl/platform/BUILD
@@ -352,9 +352,9 @@ tsl_cc_test(
    ],
    tags = ["no_oss"],  # TODO(b/327036247): revisit after this moves to XLA
    deps = [
+        ":subprocess",
+        ":test",
        "//xla/tsl/lib/core:status_test_util",
-        "//xla/tsl/platform:subprocess",
-        "//xla/tsl/platform:test",
"@com_google_googletest//:gtest_main", ], @@ -444,10 +444,10 @@ tsl_cc_test( size = "small", srcs = ["file_system_helper_test.cc"], deps = [ + ":env", + ":env_impl", ":file_system_helper", "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:env", - "//xla/tsl/platform:env_impl", "@com_google_googletest//:gtest_main", ], ) @@ -456,7 +456,7 @@ cc_library( name = "file_statistics", hdrs = ["file_statistics.h"], deps = [ - "//xla/tsl/platform:types", + ":types", ], ) @@ -478,8 +478,8 @@ tsl_cc_test( ], deps = [ ":logging", - "//xla/tsl/platform:statusor", - "//xla/tsl/platform:test", + ":statusor", + ":test", "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -501,10 +501,10 @@ cc_library( srcs = ["status.cc"], hdrs = ["status.h"], deps = [ - "//xla/tsl/platform:logging", - "//xla/tsl/platform:macros", - "//xla/tsl/platform:stack_frame", - "//xla/tsl/platform:types", + ":logging", + ":macros", + ":stack_frame", + ":types", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -527,12 +527,12 @@ tsl_cc_test( size = "small", srcs = ["status_test.cc"], deps = [ + ":errors", + ":stack_frame", ":status", - "//xla/tsl/platform:errors", - "//xla/tsl/platform:stack_frame", - "//xla/tsl/platform:status_matchers", - "//xla/tsl/platform:status_to_from_proto", - "//xla/tsl/platform:test", + ":status_matchers", + ":status_to_from_proto", + ":test", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "//xla/tsl/protobuf:status_proto_cc", "@com_google_absl//absl/status", @@ -548,9 +548,9 @@ cc_library( srcs = ["status_matchers.cc"], hdrs = ["status_matchers.h"], deps = [ - "//xla/tsl/platform:status", - "//xla/tsl/platform:statusor", - "//xla/tsl/platform:test", + ":status", + ":statusor", + ":test", "//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -560,11 +560,11 @@ tsl_cc_test( size = "small", srcs = ["status_matchers_test.cc"], deps = [ - "//xla/tsl/platform:errors", - "//xla/tsl/platform:status", - "//xla/tsl/platform:status_matchers", - "//xla/tsl/platform:statusor", - "//xla/tsl/platform:test", + ":errors", + ":status", + ":status_matchers", + ":statusor", + ":test", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_googletest//:gtest_main", ], @@ -577,7 +577,7 @@ cc_library( ], hdrs = ["status_to_from_proto.h"], deps = [ - "//xla/tsl/platform:status", + ":status", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "//xla/tsl/protobuf:status_proto_cc", "@com_google_absl//absl/strings", @@ -589,10 +589,10 @@ cc_library( name = "statusor", hdrs = ["statusor.h"], deps = [ - "//xla/tsl/platform:errors", - "//xla/tsl/platform:logging", - "//xla/tsl/platform:macros", - "//xla/tsl/platform:status", + ":errors", + ":logging", + ":macros", + ":status", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -607,12 +607,12 @@ tsl_cc_test( size = "small", srcs = ["statusor_test.cc"], deps = [ + ":errors", + ":macros", ":statusor", + ":test", + ":test_benchmark", ":test_main", - "//xla/tsl/platform:errors", - "//xla/tsl/platform:macros", - "//xla/tsl/platform:test", - "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/base:config", ], ) @@ -624,9 +624,9 @@ cc_library( compatible_with = get_compatible_with_portable(), textual_hdrs = ["test.h"], deps = [ - "//xla/tsl/platform:logging", - "//xla/tsl/platform:macros", - "//xla/tsl/platform:types", + ":logging", + ":macros", + 
":types", "@com_google_googletest//:gtest_for_library", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:net", @@ -655,8 +655,8 @@ cc_library( "//conditions:default": ["-lm"], }), deps = [ - "//xla/tsl/platform:test", - "//xla/tsl/platform:test_benchmark", + ":test", + ":test_benchmark", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:stacktrace_handler", @@ -668,8 +668,8 @@ cc_library( name = "threadpool_async_executor", hdrs = ["threadpool_async_executor.h"], deps = [ + ":env", "//xla/tsl/concurrency:async_value", - "//xla/tsl/platform:env", ], ) @@ -677,10 +677,10 @@ tsl_cc_test( name = "threadpool_async_executor_test", srcs = ["threadpool_async_executor_test.cc"], deps = [ + ":env", + ":env_impl", + ":test", ":threadpool_async_executor", - "//xla/tsl/platform:env", - "//xla/tsl/platform:env_impl", - "//xla/tsl/platform:test", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", ], @@ -691,7 +691,7 @@ cc_library( hdrs = ["threadpool_interface.h"], compatible_with = get_compatible_with_portable(), deps = [ - "//xla/tsl/platform:types", + ":types", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:mutex", ], @@ -702,7 +702,7 @@ cc_library( hdrs = ["threadpool_options.h"], compatible_with = get_compatible_with_portable(), deps = [ - "//xla/tsl/platform:threadpool_interface", + ":threadpool_interface", ], ) @@ -763,8 +763,8 @@ tsl_cc_test( ], deps = [ ":intrusive_ptr", - "//xla/tsl/platform:test", - "//xla/tsl/platform:test_main", + ":test", + ":test_main", "@local_tsl//tsl/platform:refcount", ], ) @@ -798,8 +798,8 @@ cc_library( srcs = ["resource_loader.cc"], textual_hdrs = ["resource_loader.h"], deps = [ - "//xla/tsl/platform:logging", - "//xla/tsl/platform:test", + ":logging", + ":test", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:path", ], @@ -832,7 +832,7 @@ tsl_cc_test( ], deps = [ ":criticality", - "//xla/tsl/platform:test", + ":test", "@com_google_googletest//:gtest_main", ], ) From cd199279aeeee067fd99fecc26a0660afc606069 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 14 Apr 2025 22:09:30 -0700 Subject: [PATCH 0778/1324] Run build_cleaner on BUILD file(s) located in /xla/service. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've encountered a few CLs that attempted to fix this in local directories, so I figured I run this for all of xla to fix the low-hanging fruits. It resolves several unnecessary & missing dependencies and simplifying target paths, but not all of them. Here are the issues that came up that I didn't attempt to fix entirely: * conflicts that needs manual handling * conflicts that needs to choose between two "valid" targets * missing BUILD in a directory * missing target for a file (e.g. a python script) * missing targets for some `bzl_library` * platform-specific code (e.g. rocm) * ones that use filegroup instead of individual cc_library * and more. Before: ``` metric median Δ 1-pval cpu: 3590.690s ±91.6s memory: 4533MB ±2.6MB system: 594.230s ±10.5s wall: 907.605s ±83.0s ``` After: ``` metric median Δ 1-pval cpu: 3599.015s ±131.4s +8.3s, +0.2% 0.03 (not significant) memory: 4533MB ±2.3MB +0.0MB, +0.0% 0.00 (not significant) system: 582.305s ±9.1s -11.9s, -2.0% 0.25 (not significant) wall: 808.958s ±95.5s -98.6s, -10.9% 0.57 (not significant) ``` Overall, it has modest savings of ~1 minute of wall (physical) time. 
Since I've excluded some execution tests under `stream_executor/` and `service/` the estimated savings may be greater. Overall, it's a small improvement but should pay dividends in the long run. Note: I'll be sending a series of CLs to fix them in batches of subdirectories to simplify merging. PiperOrigin-RevId: 747709960 --- third_party/xla/xla/service/BUILD | 57 ------------------------------- 1 file changed, 57 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 2c7b067514974c..413a03404eb67d 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -245,17 +245,7 @@ cc_library( hdrs = ["collective_permute_cycle.h"], deps = [ ":source_target_pairs", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/parser:hlo_parser", - "//xla/service/graphcycles", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", ], ) @@ -265,11 +255,7 @@ xla_cc_test( deps = [ ":collective_permute_cycle", ":source_target_pairs", - "//xla:shape_util", - "//xla/hlo/ir:hlo", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", ], ) @@ -285,9 +271,7 @@ cc_library( ":pattern_matcher", ":source_target_pairs", "//xla:shape_util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:backend_configs_cc", "//xla/tsl/platform:errors", @@ -308,7 +292,6 @@ xla_cc_test( ":collective_permute_decomposer", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", @@ -328,7 +311,6 @@ cc_library( ":hlo_domain_map", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/log", ], ) @@ -342,7 +324,6 @@ cc_library( "//xla/hlo/parser:hlo_parser", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", ], ) @@ -607,7 +588,6 @@ xla_cc_test( deps = [ "//xla/hlo/ir:hlo", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", ], ) @@ -710,7 +690,6 @@ xla_cc_test( ], deps = [ ":sharding_remover", - "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", @@ -980,8 +959,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) @@ -1434,7 +1411,6 @@ xla_cc_test( "//xla/hlo/testlib:test", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log", @@ -1745,7 +1721,6 @@ xla_cc_test( ":cpu_plugin", ":hlo_buffer", ":hlo_proto_cc", - ":hlo_proto_util", ":hlo_value", ":logical_buffer", "//xla:comparison_util", @@ -1761,7 +1736,6 @@ xla_cc_test( "//xla/hlo/testlib:test", "//xla/hlo/testlib:test_helpers", "//xla/hlo/transforms/simplifiers:flatten_call_graph", - 
"//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/service/memory_space_assignment", "//xla/tests:hlo_test_base", @@ -2193,7 +2167,6 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", ], ) @@ -2271,7 +2244,6 @@ xla_test( deps = [ ":batchnorm_expander", "//xla:error_spec", - "//xla:literal", "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", @@ -2411,7 +2383,6 @@ xla_cc_test( "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep - "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -2634,13 +2605,11 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/lib/core:bitmap", - "@com_google_absl//absl/algorithm", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2649,7 +2618,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -2703,12 +2671,9 @@ xla_cc_test( deps = [ ":copy_insertion", ":scan_loop_accumulator_input_unification", - "//xla:literal", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", - "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log", "@com_google_googletest//:gtest", @@ -2727,10 +2692,8 @@ cc_library( ":while_loop_unroller", "//xla:shape_util", "//xla:util", - "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2776,7 +2739,6 @@ cc_library( "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:side_effect_util", "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", @@ -3051,7 +3013,6 @@ xla_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", "//xla/tests:llvm_irgen_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", @@ -3066,7 +3027,6 @@ xla_test( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_benchmark", ], ) @@ -3435,7 +3395,6 @@ cc_library( "//xla:lazy", "//xla:shape_tree", "//xla:shape_util", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -3446,7 +3405,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", ], ) @@ -3547,10 +3505,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", - 
"@local_tsl//tsl/platform:statusor", ], ) @@ -4078,7 +4033,6 @@ cc_library( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo_sharding", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", @@ -4479,7 +4433,6 @@ xla_cc_test( name = "host_offload_utils_test", srcs = ["host_offload_utils_test.cc"], deps = [ - ":hlo_verifier", ":host_offload_utils", ":memory_annotations_hdr", ":pattern_matcher", @@ -4489,11 +4442,6 @@ xla_cc_test( "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", ], @@ -5143,7 +5091,6 @@ cc_library( srcs = ["global_device_id.cc"], hdrs = ["global_device_id.h"], deps = [ - "//xla:types", "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -5253,11 +5200,9 @@ cc_library( srcs = ["rendezvous.cc"], hdrs = ["rendezvous.h"], deps = [ - "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -5410,8 +5355,6 @@ xla_cc_test( "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/pass:hlo_pass_pipeline", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", From 6ae3ee532de2a8f4a4fcba1e41603aaef8aacd9d Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Mon, 14 Apr 2025 22:13:23 -0700 Subject: [PATCH 0779/1324] AllGatherOp : Direct StableHLO -> HLO Translation PiperOrigin-RevId: 747711024 --- .../mhlo_to_hlo/gen_hlo_op_writer.td | 4 +- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 63 +++++++++++++++++++ .../xla/hlo/translate/tests/stablehlo.mlir | 53 ++++++++++++++++ .../stablehlo_legalize_to_hlo_pass.cc | 22 +++---- 4 files changed, 129 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td index 1f1c88b0c5730c..5635afaf7d5502 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td @@ -22,7 +22,7 @@ defvar HloConversionAllowedOps = [ StableHLO_AbsOp, StableHLO_AddOp, // StableHLO_AfterAllOp, - // StableHLO_AllGatherOp, + StableHLO_AllGatherOp, // StableHLO_AllReduceOp, // StableHLO_AllToAllOp, // StableHLO_AndOp, @@ -140,7 +140,7 @@ defvar HloConversionAllowedOps = [ defvar CustomHloConverterOps = [ // StableHLO ops // go/keep-sorted start - // StableHLO_AllGatherOp, + StableHLO_AllGatherOp, // StableHLO_AllReduceOp, // StableHLO_AllToAllOp, // StableHLO_BatchNormGradOp, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index c9254a295d739b..ad182b90fb2499 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ 
PiperOrigin-RevId: 747711024
---
 .../mhlo_to_hlo/gen_hlo_op_writer.td          |  4 +-
 .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc  | 63 +++++++++++++++++++
 .../xla/hlo/translate/tests/stablehlo.mlir    | 53 ++++++++++++++++
 .../stablehlo_legalize_to_hlo_pass.cc         | 22 +++----
 4 files changed, 129 insertions(+), 13 deletions(-)

diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td
index 1f1c88b0c5730c..5635afaf7d5502 100644
--- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td
+++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td
@@ -22,7 +22,7 @@ defvar HloConversionAllowedOps = [
   StableHLO_AbsOp,
   StableHLO_AddOp,
   // StableHLO_AfterAllOp,
-  // StableHLO_AllGatherOp,
+  StableHLO_AllGatherOp,
   // StableHLO_AllReduceOp,
   // StableHLO_AllToAllOp,
   // StableHLO_AndOp,
@@ -140,7 +140,7 @@ defvar CustomHloConverterOps = [
 
   // StableHLO ops
   // go/keep-sorted start
-  // StableHLO_AllGatherOp,
+  StableHLO_AllGatherOp,
   // StableHLO_AllReduceOp,
   // StableHLO_AllToAllOp,
   // StableHLO_BatchNormGradOp,
diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
index c9254a295d739b..ad182b90fb2499 100644
--- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
+++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -575,6 +575,17 @@ std::optional<xla::ChannelHandle> Convert_channel_handle(
   return Convert_channel_handle(attr.value());
 }
 
+std::optional<xla::ChannelHandle> Convert_channel_handle(
+    const std::optional<mlir::stablehlo::ChannelHandleAttr> attr) {
+  if (!attr.has_value()) return std::nullopt;
+
+  xla::ChannelHandle channel_handle;
+  channel_handle.set_handle(attr->getHandle());
+  channel_handle.set_type(
+      static_cast<xla::ChannelHandle::ChannelType>(attr->getType()));
+  return channel_handle;
+}
+
 // Converts the comparison_direction string attribute into the XLA enum. The
 // string is assumed to correspond to exactly one of the allowed strings
 // representing the enum. This should have been checked in the op verify method.
@@ -1264,6 +1275,58 @@ LogicalResult ExportXlaOp(SubtractOp op, OpLoweringContext ctx) {
   return mlir::success();
 }
 
+LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) {
+  auto& value_map = *ctx.values;
+
+  SmallVector<xla::XlaOp> operands;
+  if (failed(GetTuple(op.getOperation(), op.getOperands(), ctx, operands)))
+    return op.emitOpError("failed to get tuple");
+
+  mlir::FailureOr<xla::Shape> shape_or = ExtractXlaShape(op.getOperation());
+  if (failed(shape_or)) return op.emitOpError("failed to extract XLA shape");
+
+  auto all_gather_dim = op.getAllGatherDim();
+  int64_t shard_count = 0;
+
+  for (const auto& indexed_pair :
+       llvm::enumerate(llvm::zip(op.getOperandTypes(), op.getResultTypes()))) {
+    auto [operand_type, result_type] = indexed_pair.value();
+    TensorType operand_ttype = mlir::cast<TensorType>(operand_type);
+    TensorType result_ttype = mlir::cast<TensorType>(result_type);
+    if (!operand_ttype || !result_ttype)
+      return op.emitOpError("operands/results must be TensorTypes");
+
+    if (!operand_ttype.hasStaticShape() || !result_ttype.hasStaticShape())
+      return op.emitOpError("operands/results must have static shapes");
+
+    if (indexed_pair.index() == 0) {
+      shard_count = result_ttype.getDimSize(all_gather_dim) /
+                    operand_ttype.getDimSize(all_gather_dim);
+    }
+  }
+
+  if (shape_or->IsTuple()) {
+    std::optional<xla::Layout> layout = std::nullopt;
+    if (shape_or->has_layout()) layout = shape_or->layout();
+
+    auto tuple = xla::AllGatherTuple(
+        operands, all_gather_dim, shard_count,
+        Convert_replica_groups(op.getReplicaGroups()),
+        Convert_channel_handle(op.getChannelHandle()), layout,
+        Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
+    BuildGetTupleElementsForTupleResults(op, tuple, ctx);
+    return success();
+  }
+
+  value_map[op->getResults()[0]] = xla::AllGather(
+      operands[0], all_gather_dim, shard_count,
+      Convert_replica_groups(op.getReplicaGroups()),
+      Convert_channel_handle(op.getChannelHandle()), std::nullopt,
+      Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
+
+  return success();
+}
+
 }  // namespace
 }  // namespace stablehlo
 
diff --git a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir
index d4c4aec7fdb8d3..6ea8d547cee342 100644
--- a/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir
+++ b/third_party/xla/xla/hlo/translate/tests/stablehlo.mlir
@@ -155,3 +155,56 @@ func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf3
 // CHECK-DIRECT: stablehlo.multiply
 // CHECK-DIRECT: stablehlo.power
 // CHECK-DIRECT: stablehlo.subtract
+
+// -----
+
+// CHECK-LABEL: HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}}
+
+// CHECK: ENTRY %[[$main_3:[^ ]+]]
+// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = f32[128,32] parameter(0)
+// CHECK-NEXT: ROOT %[[all_gather_2:[^ ]+]] = f32[128,128] all-gather(%[[Arg_0_1]]), channel_id=1,
+// CHECK-SAME{{LITERAL}}: replica_groups={{0,2,4,6},{1,3,5,7}},
+// CHECK-SAME: dimensions={1},
+func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
+%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 1 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>}> {shard_count = 4 : i64} : (tensor<128x32xf32>) -> tensor<128x128xf32>
+return %0 : tensor<128x128xf32>
+}
+// CHECK-DIRECT: stablehlo.all_gather
+
+// -----
+
+// CHECK-LABEL: HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}}
+
+// CHECK: ENTRY %[[$main_3:[^ ]+]]
+// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = f32[128,32] parameter(0)
+// CHECK-NEXT: ROOT %[[all_gather_2:[^ ]+]] = f32[128,128] all-gather(%[[Arg_0_1]]), channel_id=1,
+// CHECK-SAME{{LITERAL}}: replica_groups={{0,2,4,6},{1,3,5,7}},
+// CHECK-SAME: dimensions={1}, use_global_device_ids=true,
+func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
+  %0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 1 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> {shard_count = 4 : i64} : (tensor<128x32xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+// CHECK-DIRECT: stablehlo.all_gather
+
+// -----
+
+// CHECK-LABEL: HloModule main, entry_computation_layout={(f32[8,2]{1,0}, f32[8,4]{1,0})->(f32[8,8]{1,0}, f32[8,16]{1,0})}
+
+// CHECK: ENTRY %[[$main_10:[^ ]+]]
+// CHECK-NEXT: %[[Arg_0_1:[^ ]+]] = f32[8,2] parameter(0)
+// CHECK-NEXT: %[[Arg_1_2:[^ ]+]] = f32[8,4] parameter(1)
+// CHECK-NEXT: %[[tuple_3:[^ ]+]] = (f32[8,2], f32[8,4]) tuple(%[[Arg_0_1]], %[[Arg_1_2]]),
+// CHECK-NEXT: %[[get_tuple_element_4:[^ ]+]] = f32[8,2] get-tuple-element(%[[tuple_3]]), index=0,
+// CHECK-NEXT: %[[get_tuple_element_5:[^ ]+]] = f32[8,4] get-tuple-element(%[[tuple_3]]), index=1,
+// CHECK-NEXT: %[[all_gather_6:[^ ]+]] = (f32[8,8], f32[8,16]) all-gather(%[[get_tuple_element_4]], %[[get_tuple_element_5]]), channel_id=1,
+// CHECK-SAME{{LITERAL}}: replica_groups={{0,2,4,6},{1,3,5,7}},
+// CHECK-SAME: dimensions={1}, use_global_device_ids=true,
+// CHECK-NEXT: %[[get_tuple_element_7:[^ ]+]] = f32[8,8] get-tuple-element(%[[all_gather_6]]), index=0,
+// CHECK-NEXT: %[[get_tuple_element_8:[^ ]+]] = f32[8,16] get-tuple-element(%[[all_gather_6]]), index=1,
+// CHECK-NEXT: ROOT %[[tuple_9:[^ ]+]] = (f32[8,8], f32[8,16]) tuple(%[[get_tuple_element_7]], %[[get_tuple_element_8]]),
+func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> tuple<tensor<8x8xf32>, tensor<8x16xf32>> {
+  %0:2 = "stablehlo.all_gather"(%arg0, %arg1) <{all_gather_dim = 1 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>)
+  %1 = stablehlo.tuple %0#0, %0#1 {xla_shape = "(f32[8,8]{0,1}, f32[8,16]{0,1})"} : tuple<tensor<8x8xf32>, tensor<8x16xf32>>
+  return %1 : tuple<tensor<8x8xf32>, tensor<8x16xf32>>
+}
+// CHECK-DIRECT: stablehlo.all_gather
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
index 7acd9e3b3ea980..9112c4c6876e09 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc
@@ -68,17 +68,17 @@ void legalDirectStablehloToHloConversionOps(ConversionTarget& target) {
   target.addLegalOp<
       // go/keep-sorted start
       stablehlo::AbsOp, stablehlo::CbrtOp, stablehlo::SqrtOp, stablehlo::TanOp,
-      stablehlo::AddOp, stablehlo::Atan2Op, stablehlo::AddOp,
-      stablehlo::BroadcastInDimOp, stablehlo::BroadcastOp, stablehlo::CeilOp,
-      stablehlo::ClzOp, stablehlo::ConvertOp, stablehlo::ComplexOp,
-      stablehlo::ConvolutionOp, stablehlo::CosineOp, stablehlo::DynamicSliceOp,
-      stablehlo::DivOp, stablehlo::MaxOp, stablehlo::ConstantOp,
-      stablehlo::Expm1Op, stablehlo::DynamicBroadcastInDimOp,
-      stablehlo::FloorOp, stablehlo::ImagOp, stablehlo::ExpOp,
-      stablehlo::IsFiniteOp, stablehlo::Log1pOp, stablehlo::LogOp,
-      stablehlo::LogisticOp, stablehlo::NegOp, stablehlo::NotOp,
-      stablehlo::MinOp, stablehlo::MulOp, stablehlo::PowOp, stablehlo::RemOp,
-      stablehlo::PopulationCountOp, stablehlo::RealOp,
+      stablehlo::AddOp, stablehlo::AddOp, stablehlo::AllGatherOp,
+      stablehlo::Atan2Op, stablehlo::BroadcastInDimOp, stablehlo::BroadcastOp,
+      stablehlo::CeilOp, stablehlo::ClzOp, stablehlo::ConvertOp,
+      stablehlo::ComplexOp, stablehlo::ConvolutionOp, stablehlo::CosineOp,
+      stablehlo::ConstantOp, stablehlo::Expm1Op,
+      stablehlo::DynamicBroadcastInDimOp, stablehlo::FloorOp, stablehlo::ImagOp,
+      stablehlo::DynamicSliceOp, stablehlo::DivOp, stablehlo::MaxOp,
+      stablehlo::ExpOp, stablehlo::IsFiniteOp, stablehlo::Log1pOp,
+      stablehlo::LogOp, stablehlo::LogisticOp, stablehlo::NegOp,
+      stablehlo::NotOp, stablehlo::MinOp, stablehlo::MulOp, stablehlo::PowOp,
+      stablehlo::RemOp, stablehlo::PopulationCountOp, stablehlo::RealOp,
       stablehlo::RoundNearestEvenOp, stablehlo::RoundOp, stablehlo::RsqrtOp,
       stablehlo::ShiftLeftOp, stablehlo::ShiftRightArithmeticOp,
       stablehlo::ShiftRightLogicalOp, stablehlo::SubtractOp, stablehlo::SignOp,

From 4717156fd10f5a9ca2da82dae380a787b55b77fd Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 14 Apr 2025 22:21:44 -0700
Subject: [PATCH 0780/1324] Automated Code Change

PiperOrigin-RevId: 747713162
---
 .../xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc
index 2f07003cfc6131..70713335cbda74 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 

From 481a26cb7e3d9e00323a06ec378ba324a9f66ba9 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Mon, 14 Apr 2025 22:22:22 -0700
Subject: [PATCH 0781/1324] [xla:gpu] CommandBuffer: split command executor
 from command sequence

- command sequence is simply a vector of commands (just like ThunkSequence is
  just a vector of thunks)

- command executor takes a command sequence and provides an API to
  prepare/initialize/record all commands in a sequence into the underlying
  command buffer, and manages dependencies between commands

CommandBufferCmdExecutor should be extracted into a separate file, but it
requires splitting each individual command into a separate target. This will
come in follow-up changes.
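The intended call pattern, pieced together from the test updates below (a
simplified sketch, not a verbatim excerpt: `s0`, the buffer slices, the
params structs, the state manager, and the command buffer all come from the
surrounding test fixtures, and the concrete command type is illustrative):

```
// A command sequence is now a plain vector of commands, filled in place.
CommandBufferCmdSequence commands;
commands.Emplace<MemcpyDeviceToDeviceCmd>(s0, slice_b, slice_a, byte_length);

// The executor is created from the sequence and owns dependency handling:
// kSerialize records commands back-to-back, kAutomatic derives a DAG from
// buffer uses.
TF_ASSIGN_OR_RETURN(
    CommandBufferCmdExecutor executor,
    CommandBufferCmdExecutor::Create(
        std::move(commands),
        CommandBufferCmdExecutor::SynchronizationMode::kSerialize));

// Prepare/Initialize/Record mirror the Thunk lifecycle.
TF_RETURN_IF_ERROR(executor.Initialize(initialize_params, state));
TF_RETURN_IF_ERROR(
    executor.Record(execute_params, record_params, command_buffer.get()));
```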
PiperOrigin-RevId: 747713278 --- .../xla/xla/backends/gpu/runtime/BUILD | 70 +-- .../gpu/runtime/command_buffer_cmd.cc | 66 +-- .../backends/gpu/runtime/command_buffer_cmd.h | 72 ++- .../gpu/runtime/command_buffer_cmd_emitter.cc | 54 +- .../gpu/runtime/command_buffer_cmd_emitter.h | 6 +- .../gpu/runtime/command_buffer_cmd_test.cc | 126 +++-- .../gpu/runtime/command_buffer_thunk.cc | 4 +- .../gpu/runtime/command_buffer_thunk.h | 9 +- .../gpu/runtime/command_buffer_thunk_test.cc | 509 ++++++++++-------- .../gpu/runtime/for_all_thunks_test.cc | 2 +- .../xla/service/gpu/ir_emitter_unnested.cc | 10 +- 11 files changed, 488 insertions(+), 440 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index a370f43c61e8a2..a16c0bf2495891 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -120,40 +120,6 @@ cc_library( ], ) -cc_library( - name = "command_buffer_cmd_emitter", - srcs = ["command_buffer_cmd_emitter.cc"], - hdrs = ["command_buffer_cmd_emitter.h"], - deps = [ - ":all_gather_thunk", - ":all_reduce_thunk", - ":all_to_all_thunk", - ":command_buffer_cmd", - ":conditional_thunk", - ":copy_thunk", - ":cudnn_thunk", - ":custom_call_thunk", - ":dynamic_slice_thunk", - ":gemm_thunk", - ":gpublas_lt_matmul_thunk", - ":kernel_thunk", - ":memset_thunk", - ":replica_id_thunk", - ":sequential_thunk", - ":thunk", - ":while_thunk", - "//xla:util", - "//xla/runtime:buffer_use", - "//xla/service:buffer_assignment", - "//xla/tsl/platform:errors", - "//xla/tsl/platform:statusor", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - xla_test( name = "command_buffer_cmd_test", srcs = if_gpu_is_configured(["command_buffer_cmd_test.cc"]), @@ -191,6 +157,42 @@ xla_test( ], ) +cc_library( + name = "command_buffer_cmd_emitter", + srcs = ["command_buffer_cmd_emitter.cc"], + hdrs = ["command_buffer_cmd_emitter.h"], + deps = [ + ":all_gather_thunk", + ":all_reduce_thunk", + ":all_to_all_thunk", + ":collective_thunk", + ":command_buffer_cmd", + ":conditional_thunk", + ":copy_thunk", + ":cudnn_thunk", + ":custom_call_thunk", + ":dynamic_slice_thunk", + ":gemm_thunk", + ":gpublas_lt_matmul_thunk", + ":kernel_thunk", + ":memset_thunk", + ":replica_id_thunk", + ":sequential_thunk", + ":thunk", + ":wait_for_streams_thunk", + ":while_thunk", + "//xla:util", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + #===-------------------------------------------------------------------------------------------===// # XLA Thunks Runtime #===-------------------------------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 375cc6e514fc01..7b5bf945c2922a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -119,7 +119,7 @@ static absl::string_view ReductionKindString(ReductionKind kind) { // Create a callback to create a command buffer from a command sequence. 
static se::CommandBuffer::CreateCommands CreateCommands( - const CommandBufferCmdSequence* commands, + const CommandBufferCmdExecutor* commands, const Thunk::ExecuteParams* execute_params, const CommandBufferCmd::RecordParams* record_params) { return [=](se::CommandBuffer* command_buffer, @@ -131,11 +131,11 @@ static se::CommandBuffer::CreateCommands CreateCommands( // Create callbacks to create a command buffer from command sequences. static std::vector CreateCommands( - absl::Span commands, + absl::Span commands, const Thunk::ExecuteParams* execute_params, const CommandBufferCmd::RecordParams* record_params) { std::vector create_commands; - for (const CommandBufferCmdSequence& cmd : commands) { + for (const CommandBufferCmdExecutor& cmd : commands) { create_commands.push_back( CreateCommands(&cmd, execute_params, record_params)); } @@ -144,7 +144,7 @@ static std::vector CreateCommands( // Create a callback to update a command buffer with command sequence. static se::CommandBuffer::UpdateCommands UpdateCommands( - const CommandBufferCmdSequence* commands, + const CommandBufferCmdExecutor* commands, const Thunk::ExecuteParams* execute_params, const CommandBufferCmd::RecordParams* record_params) { return [=](se::CommandBuffer* command_buffer) { @@ -155,11 +155,11 @@ static se::CommandBuffer::UpdateCommands UpdateCommands( // Create callbacks to update a command buffer with command sequence. static std::vector UpdateCommands( - absl::Span commands, + absl::Span commands, const Thunk::ExecuteParams* execute_params, const CommandBufferCmd::RecordParams* record_params) { std::vector update_commands; - for (const CommandBufferCmdSequence& cmd : commands) { + for (const CommandBufferCmdExecutor& cmd : commands) { update_commands.push_back( UpdateCommands(&cmd, execute_params, record_params)); } @@ -228,7 +228,6 @@ CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrCreate( //===----------------------------------------------------------------------===// namespace { - // An adaptor from CommandBufferCmd to ExecutionGraph::Operation for building an // execution graph from a command sequence. class CommandOperation : public ExecutionGraph::Operation { @@ -242,17 +241,11 @@ class CommandOperation : public ExecutionGraph::Operation { private: CommandBufferCmd::BufferUseVector buffers_; }; - } // namespace -void CommandBufferCmdSequence::Builder::Append( - std::unique_ptr cmd) { - commands_.push_back({std::move(cmd)}); -} - -absl::StatusOr -CommandBufferCmdSequence::Builder::Build( - SynchronizationMode synchronization_mode) && { +absl::StatusOr CommandBufferCmdExecutor::Create( + CommandBufferCmdSequence commands, + SynchronizationMode synchronization_mode) { std::optional execution_graph = std::nullopt; // In automatic synchronization mode construct an execution graph for the @@ -260,21 +253,20 @@ CommandBufferCmdSequence::Builder::Build( // from the buffer use conflicts. 
if (synchronization_mode == SynchronizationMode::kAutomatic) { std::vector operations; - operations.reserve(commands_.size()); - for (const std::unique_ptr& cmd : commands_) { + operations.reserve(commands.size()); + for (const std::unique_ptr& cmd : commands) { operations.emplace_back(cmd->buffers()); } TF_ASSIGN_OR_RETURN(execution_graph, ExecutionGraph::Create(operations)); } - return CommandBufferCmdSequence(synchronization_mode, std::move(commands_), + return CommandBufferCmdExecutor(synchronization_mode, std::move(commands), std::move(execution_graph)); } -CommandBufferCmdSequence::CommandBufferCmdSequence( - SynchronizationMode synchronization_mode, - std::vector> commands, +CommandBufferCmdExecutor::CommandBufferCmdExecutor( + SynchronizationMode synchronization_mode, CommandBufferCmdSequence commands, std::optional execution_graph) : synchronization_mode_(synchronization_mode), commands_(std::move(commands)), @@ -288,7 +280,7 @@ CommandBufferCmdSequence::CommandBufferCmdSequence( } } -absl::Status CommandBufferCmdSequence::Prepare( +absl::Status CommandBufferCmdExecutor::Prepare( const Thunk::PrepareParams& params, Thunk::ResourceRequestsInterface& resource_requests) { for (auto& command : commands_) { @@ -297,7 +289,7 @@ absl::Status CommandBufferCmdSequence::Prepare( return absl::OkStatus(); } -absl::Status CommandBufferCmdSequence::Initialize( +absl::Status CommandBufferCmdExecutor::Initialize( const Thunk::InitializeParams& params, CommandBufferCmd::StateManager& state) { for (auto& command : commands_) { @@ -306,7 +298,7 @@ absl::Status CommandBufferCmdSequence::Initialize( return absl::OkStatus(); } -absl::Status CommandBufferCmdSequence::Record( +absl::Status CommandBufferCmdExecutor::Record( const Thunk::ExecuteParams& execute_params, const CommandBufferCmd::RecordParams& record_params, se::CommandBuffer* command_buffer) { @@ -329,7 +321,7 @@ absl::Status CommandBufferCmdSequence::Record( } absl::StatusOr> -CommandBufferCmdSequence::RecordCreate( +CommandBufferCmdExecutor::RecordCreate( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer, absl::Span dependencies) const { @@ -399,7 +391,7 @@ CommandBufferCmdSequence::RecordCreate( return sink_commands; } -absl::Status CommandBufferCmdSequence::RecordUpdate( +absl::Status CommandBufferCmdExecutor::RecordUpdate( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer) const { @@ -449,7 +441,7 @@ absl::Status CommandBufferCmdSequence::RecordUpdate( return absl::OkStatus(); } -absl::Status CommandBufferCmdSequence::CheckCommandBufferState( +absl::Status CommandBufferCmdExecutor::CheckCommandBufferState( se::CommandBuffer* command_buffer, se::CommandBuffer::State expected_state) const { if (command_buffer->state() != expected_state) { @@ -459,17 +451,17 @@ absl::Status CommandBufferCmdSequence::CheckCommandBufferState( return absl::OkStatus(); } -bool CommandBufferCmdSequence::IsSource(CommandId id) const { +bool CommandBufferCmdExecutor::IsSource(CommandId id) const { return execution_graph_ ? execution_graph_->is_source(id) : id == 0; } -bool CommandBufferCmdSequence::IsSink(CommandId id) const { +bool CommandBufferCmdExecutor::IsSink(CommandId id) const { return execution_graph_ ? 
execution_graph_->is_sink(id) : id + 1 == commands_.size(); } std::vector -CommandBufferCmdSequence::Dependencies(const RecordParams& record_params, +CommandBufferCmdExecutor::Dependencies(const RecordParams& record_params, se::CommandBuffer* command_buffer, CommandId id) const { // Source commands have no dependencies. @@ -510,13 +502,13 @@ CommandBufferCmdSequence::Dependencies(const RecordParams& record_params, return dependencies; } -const absl::flat_hash_set& CommandBufferCmdSequence::buffers() +const absl::flat_hash_set& CommandBufferCmdExecutor::buffers() const { return buffers_; } const absl::flat_hash_set& -CommandBufferCmdSequence::allocs_indices() const { +CommandBufferCmdExecutor::allocs_indices() const { return allocs_indices_; } @@ -981,7 +973,7 @@ CommandBufferCmd::BufferUseVector Memset32Cmd::buffers() { CaseCmd::CaseCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice index, bool index_is_bool, - std::vector branches) + std::vector branches) : CommandBufferCmd(CommandBufferCmdType::kCaseCmd, execution_stream_id), index_(index), index_is_bool_(index_is_bool), @@ -1055,8 +1047,8 @@ CommandBufferCmd::BufferUseVector CaseCmd::buffers() { WhileCmd::WhileCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice pred, - CommandBufferCmdSequence cond_commands, - CommandBufferCmdSequence body_commands) + CommandBufferCmdExecutor cond_commands, + CommandBufferCmdExecutor body_commands) : CommandBufferCmd(CommandBufferCmdType::kWhileCmd, execution_stream_id), pred_(pred), cond_commands_(std::move(cond_commands)), @@ -1947,7 +1939,7 @@ CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() { DynamicSliceFusionCmd::DynamicSliceFusionCmd( ExecutionStreamId execution_stream_id, - CommandBufferCmdSequence embedded_commands, + CommandBufferCmdExecutor embedded_commands, std::vector> arguments, std::vector> fake_allocations, std::vector>> offsets, diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index aeee085b44bc89..a7c8b56b3e576b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -277,18 +277,27 @@ class CommandBufferCmd { ExecutionStreamId execution_stream_id_; }; +// A sequence of commands (corresponds to a ThunkSequence from the Thunk API). +class CommandBufferCmdSequence + : public std::vector> { + public: + template + void Emplace(Args&&... args) { + this->emplace_back(std::make_unique(std::forward(args)...)); + } +}; + //===----------------------------------------------------------------------===// -// CommandBufferCmdSequence +// CommandBufferCmdExecutor //===----------------------------------------------------------------------===// -// A sequence of command buffer commands that create or update a command buffer. -// You can think of CommandBufferCmdSequence as a mini interpreter whose sole -// purpose is to manipulate command buffers at run time. -class CommandBufferCmdSequence { +// Command executor is responsible for recording commands sequence into the +// underlying command buffer and setting up dependencies between commands. 
+class CommandBufferCmdExecutor { public: - CommandBufferCmdSequence() = default; - CommandBufferCmdSequence(CommandBufferCmdSequence&&) = default; - CommandBufferCmdSequence& operator=(CommandBufferCmdSequence&&) = default; + CommandBufferCmdExecutor() = default; + CommandBufferCmdExecutor(CommandBufferCmdExecutor&&) = default; + CommandBufferCmdExecutor& operator=(CommandBufferCmdExecutor&&) = default; using RecordParams = CommandBufferCmd::RecordParams; @@ -316,23 +325,11 @@ class CommandBufferCmdSequence { } } - // A command buffer cmd sequence builder for lazy command sequence - // construction. - class Builder { - public: - void Append(std::unique_ptr cmd); - - template - void Emplace(Args... args) { - Append(std::make_unique(std::forward(args)...)); - } - - absl::StatusOr Build( - SynchronizationMode synchronization_mode) &&; - - private: - std::vector> commands_; - }; + // Creates a command executor from a sequence of commands using given + // synchronization mode. + static absl::StatusOr Create( + CommandBufferCmdSequence commands, + SynchronizationMode synchronization_mode); // Prepares all commands added to a sequence. absl::Status Prepare(const Thunk::PrepareParams& params, @@ -390,10 +387,9 @@ class CommandBufferCmdSequence { const se::CommandBuffer::Command* command; }; - CommandBufferCmdSequence( - SynchronizationMode synchronization_mode, - std::vector> commands, - std::optional execution_graph); + CommandBufferCmdExecutor(SynchronizationMode synchronization_mode, + CommandBufferCmdSequence commands, + std::optional execution_graph); absl::Status CheckCommandBufferState( se::CommandBuffer* command_buffer, @@ -411,7 +407,7 @@ class CommandBufferCmdSequence { CommandId id) const; SynchronizationMode synchronization_mode_; - std::vector> commands_; + CommandBufferCmdSequence commands_; // In automatic synchronization mode we build an execution graph for the // sequence of commands and use it to set up dependencies between commands. 
@@ -638,7 +634,7 @@ class Memset32Cmd : public CommandBufferCmd { class CaseCmd : public CommandBufferCmd { public: CaseCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice index, - bool index_is_bool, std::vector branches); + bool index_is_bool, std::vector branches); absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; @@ -655,7 +651,7 @@ class CaseCmd : public CommandBufferCmd { private: BufferAllocation::Slice index_; bool index_is_bool_; - std::vector branches_; + std::vector branches_; }; //===----------------------------------------------------------------------===// @@ -665,8 +661,8 @@ class CaseCmd : public CommandBufferCmd { class WhileCmd : public CommandBufferCmd { public: WhileCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice pred, - CommandBufferCmdSequence cond_commands, - CommandBufferCmdSequence body_commands); + CommandBufferCmdExecutor cond_commands, + CommandBufferCmdExecutor body_commands); absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; @@ -682,8 +678,8 @@ class WhileCmd : public CommandBufferCmd { private: BufferAllocation::Slice pred_; - CommandBufferCmdSequence cond_commands_; - CommandBufferCmdSequence body_commands_; + CommandBufferCmdExecutor cond_commands_; + CommandBufferCmdExecutor body_commands_; }; //===----------------------------------------------------------------------===// @@ -1074,7 +1070,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd { public: DynamicSliceFusionCmd( ExecutionStreamId execution_stream_id, - CommandBufferCmdSequence embedded_commands, + CommandBufferCmdExecutor embedded_commands, std::vector> arguments, std::vector> fake_allocations_, std::vector>> @@ -1102,7 +1098,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd { bool IsNestedCommandBuffer() const final { return true; } private: - CommandBufferCmdSequence embedded_commands_; + CommandBufferCmdExecutor embedded_commands_; std::vector slices_; std::vector> fake_allocations_; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc index 3aa4c001822d22..a9296db0ae9c98 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/backends/gpu/runtime/command_buffer_cmd_emitter.h" -#include #include #include #include @@ -50,15 +49,15 @@ limitations under the License. namespace xla::gpu { -// Appends command(s) converted from `thunk` to `cmd_sequence_builder`. -static absl::Status AppendCommands( - CommandBufferCmdSequence::Builder& cmd_sequence_builder, const Thunk& thunk, - const ConvertToCommandsOptions& options); +// Appends command(s) converted from `thunk` to `cmd_sequence`. +static absl::Status AppendCommands(CommandBufferCmdSequence& cmd_sequence, + const Thunk& thunk, + const ConvertToCommandsOptions& options); -// Appends command(s) converted from `sequence` to `cmd_sequence_builder`. -static absl::Status AppendCommands( - CommandBufferCmdSequence::Builder& cmd_sequence_builder, - const ThunkSequence& sequence, const ConvertToCommandsOptions& options); +// Appends command(s) converted from `sequence` to `cmd_sequence`. 
+static absl::Status AppendCommands(CommandBufferCmdSequence& cmd_sequence, + const ThunkSequence& sequence, + const ConvertToCommandsOptions& options); //===----------------------------------------------------------------------===// // Conversions from Thunk to Command @@ -108,10 +107,10 @@ static absl::StatusOr Convert(const Memset32BitValueThunk& thunk) { static absl::StatusOr Convert( const WhileThunk& thunk, const ConvertToCommandsOptions& options) { TF_ASSIGN_OR_RETURN( - CommandBufferCmdSequence cond_cmds, + CommandBufferCmdExecutor cond_cmds, ConvertToCommands(thunk.condition_thunk_sequence()->thunks(), options)); TF_ASSIGN_OR_RETURN( - CommandBufferCmdSequence body_cmds, + CommandBufferCmdExecutor body_cmds, ConvertToCommands(thunk.body_thunk_sequence()->thunks(), options)); return std::make_unique(thunk.execution_stream_id(), @@ -146,7 +145,7 @@ static absl::StatusOr Convert(const CublasLtMatmulThunk& thunk) { static absl::StatusOr Convert( const ConditionalThunk& thunk, const ConvertToCommandsOptions& options) { - std::vector branch_cmds; + std::vector branch_cmds; branch_cmds.reserve(thunk.branch_thunks().size()); if (thunk.branch_index_is_bool()) { // For boolean predicates, we need to convert the branches in reverse order @@ -160,7 +159,7 @@ static absl::StatusOr Convert( ConvertToCommands(thunk.branch_thunks()[0]->thunks(), options)); } else { for (auto& branch_thunk : thunk.branch_thunks()) { - TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence cmds, + TF_ASSIGN_OR_RETURN(CommandBufferCmdExecutor cmds, ConvertToCommands(branch_thunk->thunks(), options)); branch_cmds.emplace_back(std::move(cmds)); } @@ -197,7 +196,7 @@ static absl::StatusOr Convert(const AllGatherStartThunk& thunk) { static absl::StatusOr Convert( const DynamicSliceThunk& thunk, const ConvertToCommandsOptions& options) { TF_ASSIGN_OR_RETURN( - CommandBufferCmdSequence embedded_cmds, + CommandBufferCmdExecutor embedded_cmds, ConvertToCommands(thunk.get_embedded_thunk()->thunks(), options)); auto& thunk_fake_allocations = thunk.get_fake_allocations(); @@ -260,12 +259,12 @@ static absl::StatusOr Convert(const Thunk& thunk, Args&&... args) { thunk); } -static absl::Status AppendCommands( - CommandBufferCmdSequence::Builder& cmd_sequence_builder, const Thunk& thunk, - const ConvertToCommandsOptions& options) { +static absl::Status AppendCommands(CommandBufferCmdSequence& cmd_sequence, + const Thunk& thunk, + const ConvertToCommandsOptions& options) { auto append = [&](absl::StatusOr command) -> absl::Status { if (command.ok()) { - cmd_sequence_builder.Append(std::move(*command)); + cmd_sequence.push_back(std::move(*command)); return absl::OkStatus(); } return command.status(); @@ -312,7 +311,7 @@ static absl::Status AppendCommands( // Sequential thunk does not have any special semantics and we simply inline // all nested thunks into command buffer. 
case Thunk::Kind::kSequential: - return AppendCommands(cmd_sequence_builder, + return AppendCommands(cmd_sequence, static_cast(thunk).thunks(), options); @@ -338,19 +337,20 @@ static absl::Status AppendCommands( } } -static absl::Status AppendCommands( - CommandBufferCmdSequence::Builder& cmd_sequence_builder, - const ThunkSequence& sequence, const ConvertToCommandsOptions& options) { +static absl::Status AppendCommands(CommandBufferCmdSequence& cmd_sequence, + const ThunkSequence& sequence, + const ConvertToCommandsOptions& options) { for (const std::unique_ptr& thunk : sequence) - TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence_builder, *thunk, options)); + TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence, *thunk, options)); return absl::OkStatus(); } -absl::StatusOr ConvertToCommands( +absl::StatusOr ConvertToCommands( const ThunkSequence& sequence, const ConvertToCommandsOptions& options) { - CommandBufferCmdSequence::Builder cmd_sequence_builder; - TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence_builder, sequence, options)); - return std::move(cmd_sequence_builder).Build(options.synchronization_mode); + CommandBufferCmdSequence cmd_sequence; + TF_RETURN_IF_ERROR(AppendCommands(cmd_sequence, sequence, options)); + return CommandBufferCmdExecutor::Create(std::move(cmd_sequence), + options.synchronization_mode); } } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h index f3a3781b39cc6e..96fb4c0c84f33e 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.h @@ -24,12 +24,12 @@ namespace xla::gpu { // Options for converting from thunks to command buffer commands. struct ConvertToCommandsOptions { - CommandBufferCmdSequence::SynchronizationMode synchronization_mode = - CommandBufferCmdSequence::SynchronizationMode::kSerialize; + CommandBufferCmdExecutor::SynchronizationMode synchronization_mode = + CommandBufferCmdExecutor::SynchronizationMode::kSerialize; }; // Converts thunk sequence to a command buffer cmd sequence. -absl::StatusOr ConvertToCommands( +absl::StatusOr ConvertToCommands( const ThunkSequence& sequence, const ConvertToCommandsOptions& options); } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index 26889e6ce09457..0b55c1a5c49e83 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -64,11 +64,11 @@ static constexpr auto s0 = ExecutionStreamId(0); // Give a short alias to synchronization mode. static constexpr auto serialize = - CommandBufferCmdSequence::SynchronizationMode::kSerialize; + CommandBufferCmdExecutor::SynchronizationMode::kSerialize; // A command buffer cmd for testing automatic barriers insertion by the command -// buffer cmd sequence. We never execute this command, we need it only to pass -// buffer usage vector to the command buffer cmd sequence. +// buffer cmd commands. We never execute this command, we need it only to pass +// buffer usage vector to the command buffer cmd commands. 
struct TestOnlyCommandBufferCmd : public CommandBufferCmd { TestOnlyCommandBufferCmd(ExecutionStreamId execution_stream_id, BufferUseVector buffer_usage) @@ -153,13 +153,14 @@ TEST(CommandBufferCmdTest, SerializeExecution) { auto use0 = BufferUse(slice0, BufferUse::kRead); auto use1 = BufferUse(slice1, BufferUse::kRead); - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, BufferUseVector{use0}); - builder.Emplace(s0, BufferUseVector{use1}); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); - // TODO(ezhulenev): Check that commands correctly infer dependencies. + // TODO(ezhulenev): Check that executor correctly infer dependencies. } TEST(CommandBufferCmdTest, NoReadBarrier) { @@ -172,13 +173,14 @@ TEST(CommandBufferCmdTest, NoReadBarrier) { auto use0 = BufferUse(slice0, BufferUse::kRead); auto use1 = BufferUse(slice1, BufferUse::kRead); - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, BufferUseVector{use0}); - builder.Emplace(s0, BufferUseVector{use1}); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); - // TODO(ezhulenev): Check that commands correctly infer dependencies. + // TODO(ezhulenev): Check that executor correctly infer dependencies. } TEST(CommandBufferCmdTest, NoWriteBarrier) { @@ -191,13 +193,14 @@ TEST(CommandBufferCmdTest, NoWriteBarrier) { auto use0 = BufferUse(slice0, BufferUse::kWrite); auto use1 = BufferUse(slice1, BufferUse::kWrite); - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, BufferUseVector{use0}); - builder.Emplace(s0, BufferUseVector{use1}); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); - // TODO(ezhulenev): Check that commands correctly infer dependencies. + // TODO(ezhulenev): Check that executor correctly infer dependencies. 
} TEST(CommandBufferCmdTest, WriteConflictBarrier) { @@ -212,26 +215,29 @@ TEST(CommandBufferCmdTest, WriteConflictBarrier) { auto use1 = BufferUse(slice0, BufferUse::kRead); auto use2 = BufferUse(slice1, BufferUse::kWrite); - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, BufferUseVector{use0}); - builder.Emplace(s0, BufferUseVector{use1}); - builder.Emplace(s0, BufferUseVector{use2}); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); + commands.Emplace(s0, BufferUseVector{use2}); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); - // TODO(ezhulenev): Check that commands correctly infer dependencies. + // TODO(ezhulenev): Check that executor correctly infer dependencies. } TEST(CommandBufferCmdTest, MemcpyCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - auto stream = executor->CreateStream().value(); + auto stream = stream_executor->CreateStream().value(); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); + se::DeviceMemory b = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); TF_ASSERT_OK(stream->MemZero(&b, byte_length)); @@ -244,13 +250,14 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice_b, slice_a, byte_length); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_b, slice_a, byte_length); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a, b}, 0, &allocator); CommandBufferCmd::StateManager state; @@ -260,9 +267,10 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { CommandBufferCmd::RecordParams record_params = {state}; - auto command_buffer = - executor->CreateCommandBuffer(se::CommandBuffer::Mode::kPrimary).value(); - TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get())); + TF_ASSERT_OK_AND_ASSIGN( + auto command_buffer, + stream_executor->CreateCommandBuffer(se::CommandBuffer::Mode::kPrimary)); + TF_ASSERT_OK(executor.Record(params, record_params, command_buffer.get())); // Execute command buffer and verify that it copied the memory. 
   TF_ASSERT_OK(command_buffer->Submit(stream.get()));
@@ -275,15 +283,17 @@ TEST(CommandBufferCmdTest, MemcpyCmd) {
 }
 
 TEST(CommandBufferCmdTest, LaunchCmd) {
-  se::StreamExecutor* executor = GpuExecutor();
+  se::StreamExecutor* stream_executor = GpuExecutor();
 
-  auto stream = executor->CreateStream().value();
+  auto stream = stream_executor->CreateStream().value();
 
   int64_t length = 4;
   int64_t byte_length = sizeof(int32_t) * length;
 
   // Prepare arguments: a=42, b=0
-  se::DeviceMemory<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
-  se::DeviceMemory<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
+  se::DeviceMemory<int32_t> a =
+      stream_executor->AllocateArray<int32_t>(length, 0);
+  se::DeviceMemory<int32_t> b =
+      stream_executor->AllocateArray<int32_t>(length, 0);
 
   TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length));
   TF_ASSERT_OK(stream->MemZero(&b, byte_length));
@@ -299,24 +309,25 @@ TEST(CommandBufferCmdTest, LaunchCmd) {
   auto args_access = {BufferUse::kRead, MemoryAccess::kRead, BufferUse::kWrite};
 
   // Prepare commands sequence for constructing command buffer.
-  CommandBufferCmdSequence::Builder builder;
-  builder.Emplace<LaunchCmd>(s0, "AddI32", args, args_access,
-                             LaunchDimensions(1, 4),
-                             /*shmem_bytes=*/0);
-  TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands,
-                          std::move(builder).Build(serialize));
-
-  // Initialize command sequence and load device kernels.
+  CommandBufferCmdSequence commands;
+  commands.Emplace<LaunchCmd>(s0, "AddI32", args, args_access,
+                              LaunchDimensions(1, 4),
+                              /*shmem_bytes=*/0);
+  TF_ASSERT_OK_AND_ASSIGN(
+      CommandBufferCmdExecutor executor,
+      CommandBufferCmdExecutor::Create(std::move(commands), serialize));
+
+  // Initialize commands executor and load device kernels.
   TF_ASSERT_OK_AND_ASSIGN(std::vector<uint8_t> fatbin,
                           se::gpu::GetGpuTestKernelsFatbin());
   Thunk::ExecutableSource source = {/*text=*/{}, /*binary=*/fatbin};
 
   CommandBufferCmd::StateManager state;
-  TF_ASSERT_OK(commands.Initialize({executor, source}, state));
+  TF_ASSERT_OK(executor.Initialize({stream_executor, source}, state));
 
   ServiceExecutableRunOptions run_options;
-  se::StreamExecutorMemoryAllocator allocator(executor);
+  se::StreamExecutorMemoryAllocator allocator(stream_executor);
   BufferAllocations allocations({a, b}, 0, &allocator);
 
   Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(
@@ -324,9 +335,10 @@ TEST(CommandBufferCmdTest, LaunchCmd) {
 
   CommandBufferCmd::RecordParams record_params = {state};
 
-  auto command_buffer =
-      executor->CreateCommandBuffer(se::CommandBuffer::Mode::kPrimary).value();
-  TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get()));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto command_buffer,
+      stream_executor->CreateCommandBuffer(se::CommandBuffer::Mode::kPrimary));
+  TF_ASSERT_OK(executor.Record(params, record_params, command_buffer.get()));
 
   // Execute command buffer and verify that it copied the memory.
TF_ASSERT_OK(command_buffer->Submit(stream.get())); diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc index 4232e9d57a4586..584f37a4c0e9b1 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc @@ -55,7 +55,7 @@ CommandBufferThunk::ExecutorCommandBuffer::ExecutorCommandBuffer( : command_buffer(std::move(command_buffer)) {} CommandBufferThunk::CommandBufferThunk( - CommandBufferCmdSequence commands, ThunkInfo thunk_info, + CommandBufferCmdExecutor commands, ThunkInfo thunk_info, std::unique_ptr thunks, bool enable_command_buffers_during_profiling) : Thunk(Thunk::kCommandBuffer, std::move(thunk_info)), @@ -81,7 +81,7 @@ CommandBufferThunk::CommandBufferThunk( } bool CommandBufferThunk::ExecutorCommandBuffer::ShouldUpdateCommandBuffer( - const CommandBufferCmdSequence& commands, + const CommandBufferCmdExecutor& commands, const Thunk::ExecuteParams& params) { if (commands.force_update()) { return true; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.h index c8f248b67be198..8016648c17e081 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "absl/base/thread_annotations.h" @@ -38,7 +37,7 @@ namespace xla::gpu { class CommandBufferThunk : public Thunk { public: - CommandBufferThunk(CommandBufferCmdSequence commands, ThunkInfo thunk_info, + CommandBufferThunk(CommandBufferCmdExecutor commands, ThunkInfo thunk_info, std::unique_ptr thunks = nullptr, bool enable_command_buffers_during_profiling = false); @@ -68,7 +67,7 @@ class CommandBufferThunk : public Thunk { // Returns true if `commands` cmd sequence has to be recorded into // `command_buffer` to update it (see `recorded_allocs` below). - bool ShouldUpdateCommandBuffer(const CommandBufferCmdSequence& commands, + bool ShouldUpdateCommandBuffer(const CommandBufferCmdExecutor& commands, const Thunk::ExecuteParams& params) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex); @@ -126,8 +125,8 @@ class CommandBufferThunk : public Thunk { // Evicts all previously instantiated command buffers. static void EvictCommandBuffers(); - // Command sequence that initializes command buffers on each executor. - CommandBufferCmdSequence commands_; + // Commands executor that initializes command buffers on each stream executor. + CommandBufferCmdExecutor commands_; // Thunk sequence that executes the same commands as in `commands_` but using // thunk mechanism. We use it as a fallback mechanism to work around CUPTI diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index 9c2185a95afde7..69db4164ce44da 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -114,8 +114,8 @@ KernelArgsPacking CreateDefaultArgsPacking() { } // Some of the tests rely on CUDA 12.3+ features. 
-bool IsAtLeastCuda12300(const se::StreamExecutor* executor) { - const auto& device_description = executor->GetDeviceDescription(); +bool IsAtLeastCuda12300(const se::StreamExecutor* stream_executor) { + const auto& device_description = stream_executor->GetDeviceDescription(); const auto* cuda_cc = std::get_if( &device_description.gpu_compute_capability()); if (cuda_cc != nullptr) { @@ -134,21 +134,23 @@ constexpr auto s1 = ExecutionStreamId(1); // Give a short alias to synchronization mode. static constexpr auto serialize = - CommandBufferCmdSequence::SynchronizationMode::kSerialize; + CommandBufferCmdExecutor::SynchronizationMode::kSerialize; } // namespace TEST(CommandBufferThunkTest, MemcpyCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); + se::DeviceMemory b = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); TF_ASSERT_OK(stream->MemZero(&b, byte_length)); @@ -161,15 +163,16 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice_b, slice_a, byte_length); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_b, slice_a, byte_length); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); ServiceExecutableRunOptions run_options; BufferAllocations allocations({a, b}, 0, &allocator); @@ -201,15 +204,16 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { } TEST(CommandBufferThunkTest, MemzeroCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42 - se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); // Prepare buffer allocations for recording command buffer. @@ -217,16 +221,17 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); // Prepare commands sequence for constructing command buffer. 
- CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice_a); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_a); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( @@ -244,15 +249,16 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { } TEST(CommandBufferThunkTest, Memset32Cmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42 - se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); @@ -261,16 +267,17 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice_a, int32_t{84}); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_a, int32_t{84}); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( @@ -288,15 +295,16 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { } TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersDisabledDuringProfiling) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42 - se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); @@ -312,19 +320,20 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersDisabledDuringProfiling) { std::make_unique(Thunk::ThunkInfo(), std::move(thunks)); // Prepare commands sequence for constructing command buffer that should not // be used. 
- CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice_a, int32_t{12}); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_a, int32_t{12}); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); constexpr bool kProfileCommandBuffersEnabled = false; // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(), + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo(), std::move(seq_thunks), kProfileCommandBuffersEnabled); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( @@ -344,15 +353,16 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersDisabledDuringProfiling) { } TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersEnabledDuringProfiling) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42 - se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); @@ -368,19 +378,20 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersEnabledDuringProfiling) { std::make_unique(Thunk::ThunkInfo(), std::move(thunks)); // Prepare commands sequence for constructing command buffer that should not // be used. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice_a, int32_t{12}); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_a, int32_t{12}); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); constexpr bool kProfileCommandBuffersEnabled = true; // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(), + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo(), std::move(seq_thunks), kProfileCommandBuffersEnabled); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( @@ -400,11 +411,11 @@ TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersEnabledDuringProfiling) { } TEST(CommandBufferThunkTest, Memset32CmdOnDifferentStreams) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); - se::DeviceMemory a = executor->AllocateArray(2, 0); + se::DeviceMemory a = stream_executor->AllocateArray(2, 0); TF_ASSERT_OK(stream->MemZero(&a, 2 * sizeof(int32_t))); // Prepare buffer allocations for recording command buffer. 
@@ -413,17 +424,18 @@ TEST(CommandBufferThunkTest, Memset32CmdOnDifferentStreams) { BufferAllocation::Slice slice1(&alloc, 1 * sizeof(int32_t), sizeof(int32_t)); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice0, int32_t{12}); - builder.Emplace(s1, slice1, int32_t{34}); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice0, int32_t{12}); + commands.Emplace(s1, slice1, int32_t{34}); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( @@ -441,16 +453,18 @@ TEST(CommandBufferThunkTest, Memset32CmdOnDifferentStreams) { } TEST(CommandBufferThunkTest, LaunchCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); + se::DeviceMemory b = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); TF_ASSERT_OK(stream->MemZero(&b, byte_length)); @@ -467,27 +481,28 @@ TEST(CommandBufferThunkTest, LaunchCmd) { MemoryAccess::kWrite}; // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); - TF_ASSERT_OK( - thunk.Initialize({executor, static_cast(source), - &allocations, stream.get()})); + TF_ASSERT_OK(thunk.Initialize({stream_executor, + static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. 
TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -500,7 +515,8 @@ TEST(CommandBufferThunkTest, LaunchCmd) { ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Prepare buffer allocation for updating command buffer: c=0 - se::DeviceMemory c = executor->AllocateArray(length, 0); + se::DeviceMemory c = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->MemZero(&c, byte_length)); // Update buffer allocation #1 to buffer `c`. @@ -531,9 +547,9 @@ TEST(CommandBufferThunkTest, LaunchCmd) { } TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); auto packing = CreateDefaultArgsPacking(); @@ -548,8 +564,10 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); + se::DeviceMemory b = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); TF_ASSERT_OK(stream->MemZero(&b, byte_length)); @@ -566,27 +584,28 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { MemoryAccess::kWrite}; // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); - TF_ASSERT_OK( - thunk.Initialize({executor, static_cast(source), - &allocations, stream.get()})); + TF_ASSERT_OK(thunk.Initialize({stream_executor, + static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -599,7 +618,8 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Prepare buffer allocation for updating command buffer: c=0 - se::DeviceMemory c = executor->AllocateArray(length, 0); + se::DeviceMemory c = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->MemZero(&c, byte_length)); // Update buffer allocation #1 to buffer `c`. 
@@ -630,13 +650,13 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { } TEST(CommandBufferThunkTest, GemmCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - if (!IsAtLeastCuda12300(executor)) { + if (!IsAtLeastCuda12300(stream_executor)) { GTEST_SKIP() << "CUDA graph tracing is not supported"; } - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t lhs_length = sizeof(float) * 2 * 4; int64_t rhs_length = sizeof(float) * 4 * 3; @@ -649,19 +669,19 @@ TEST(CommandBufferThunkTest, GemmCmd) { // 1.0, 1.0, 1.0 // 1.0, 1.0, 1.0 // 1.0, 1.0, 1.0] - se::DeviceMemory lhs = executor->AllocateArray(2 * 4); + se::DeviceMemory lhs = stream_executor->AllocateArray(2 * 4); std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); - se::DeviceMemory rhs = executor->AllocateArray(4 * 3); + se::DeviceMemory rhs = stream_executor->AllocateArray(4 * 3); std::vector rhs_arr(12, 1); TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); - se::DeviceMemory out = executor->AllocateArray(2 * 3); + se::DeviceMemory out = stream_executor->AllocateArray(2 * 3); TF_ASSERT_OK(stream->MemZero(&out, out_length)); se::DeviceMemory workspace = - executor->AllocateArray(1024 * 1024); + stream_executor->AllocateArray(1024 * 1024); TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); // Prepare buffer allocations for recording command buffer. @@ -681,22 +701,23 @@ TEST(CommandBufferThunkTest, GemmCmd) { ShapeUtil::MakeShape(PrimitiveType::F32, {2, 3}), 1.0, 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, se::blas::kDefaultComputePrecision, false, false, - executor->GetDeviceDescription().gpu_compute_capability()); + stream_executor->GetDeviceDescription().gpu_compute_capability()); ASSERT_TRUE(config.ok()); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, config.value(), slice_lhs, slice_rhs, slice_out, - slice_workspace, - /*deterministic=*/true); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, config.value(), slice_lhs, slice_rhs, slice_out, + slice_workspace, + /*deterministic=*/true); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({lhs, rhs, out, workspace}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( @@ -704,7 +725,7 @@ TEST(CommandBufferThunkTest, GemmCmd) { Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; TF_ASSERT_OK(thunk.Initialize( - {executor, source, &allocations, stream.get(), stream.get()})); + {stream_executor, source, &allocations, stream.get(), stream.get()})); // Execute command buffer thunk and verify that it executed a GEMM. 
TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -717,7 +738,8 @@ TEST(CommandBufferThunkTest, GemmCmd) { ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); // Prepare buffer allocation for updating command buffer. - se::DeviceMemory updated_out = executor->AllocateArray(2 * 3); + se::DeviceMemory updated_out = + stream_executor->AllocateArray(2 * 3); TF_ASSERT_OK(stream->MemZero(&updated_out, out_length)); // Update buffer allocation to updated `out` buffer. @@ -749,13 +771,13 @@ TEST(CommandBufferThunkTest, GemmCmd) { } TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - if (!IsAtLeastCuda12300(executor)) { + if (!IsAtLeastCuda12300(stream_executor)) { GTEST_SKIP() << "CUDA graph tracing is not supported"; } - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t lhs_length = sizeof(float) * 4 * 4; int64_t fake_lhs_length = sizeof(float) * 2 * 4; @@ -769,19 +791,19 @@ TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { // 1.0, 1.0, 1.0 // 1.0, 1.0, 1.0 // 1.0, 1.0, 1.0] - se::DeviceMemory lhs = executor->AllocateArray(4 * 4); + se::DeviceMemory lhs = stream_executor->AllocateArray(4 * 4); std::vector lhs_arr{0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8}; TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); - se::DeviceMemory rhs = executor->AllocateArray(4 * 3); + se::DeviceMemory rhs = stream_executor->AllocateArray(4 * 3); std::vector rhs_arr(12, 1); TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); - se::DeviceMemory out = executor->AllocateArray(2 * 3); + se::DeviceMemory out = stream_executor->AllocateArray(2 * 3); TF_ASSERT_OK(stream->MemZero(&out, out_length)); se::DeviceMemory workspace = - executor->AllocateArray(1024 * 1024); + stream_executor->AllocateArray(1024 * 1024); TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); // Prepare buffer allocations for recording command buffer. @@ -809,16 +831,17 @@ TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { ShapeUtil::MakeShape(PrimitiveType::F32, {2, 3}), 1.0, 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, se::blas::kDefaultComputePrecision, false, false, - executor->GetDeviceDescription().gpu_compute_capability()); + stream_executor->GetDeviceDescription().gpu_compute_capability()); ASSERT_TRUE(config.ok()); // Prepare commands sequence for constructing command buffer. 
- CommandBufferCmdSequence::Builder embed_builder; - embed_builder.Emplace(s0, config.value(), fake_slice_lhs, slice_rhs, - slice_out, slice_workspace, - /*deterministic=*/true); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence embed_commands, - std::move(embed_builder).Build(serialize)); + CommandBufferCmdSequence embed_commands; + embed_commands.Emplace(s0, config.value(), fake_slice_lhs, slice_rhs, + slice_out, slice_workspace, + /*deterministic=*/true); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor embed_executor, + CommandBufferCmdExecutor::Create(std::move(embed_commands), serialize)); BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); @@ -844,18 +867,19 @@ TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { std::vector> offset_byte_sizes = { sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt}; - CommandBufferCmdSequence::Builder builder; - builder.Emplace( - s0, std::move(embed_commands), arguments, std::move(fake_allocations), + CommandBufferCmdSequence commands; + commands.Emplace( + s0, std::move(embed_executor), arguments, std::move(fake_allocations), offsets, orig_shapes, sliced_shapes, offset_byte_sizes); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({lhs, rhs, out, workspace}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( @@ -863,7 +887,7 @@ TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; TF_ASSERT_OK(thunk.Initialize( - {executor, source, &allocations, stream.get(), stream.get()})); + {stream_executor, source, &allocations, stream.get(), stream.get()})); // Execute command buffer thunk and verify that it executed a GEMM. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -876,7 +900,8 @@ TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); // Prepare buffer allocation for updating command buffer. - se::DeviceMemory updated_out = executor->AllocateArray(2 * 3); + se::DeviceMemory updated_out = + stream_executor->AllocateArray(2 * 3); TF_ASSERT_OK(stream->MemZero(&updated_out, out_length)); // Update buffer allocation to updated `out` buffer. 
@@ -908,14 +933,14 @@ TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) { } TEST(CommandBufferThunkTest, CublasLtCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - if (!IsAtLeastCuda12300(executor)) { + if (!IsAtLeastCuda12300(stream_executor)) { GTEST_SKIP() << "CUDA graph tracing is not supported"; } - TF_ASSERT_OK_AND_ASSIGN(auto stream1, executor->CreateStream()); - TF_ASSERT_OK_AND_ASSIGN(auto stream2, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream1, stream_executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream2, stream_executor->CreateStream()); // CublasLt formula: D = alpha*(A*B) + beta*(C), @@ -951,22 +976,23 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { /*algorithm*/ std::nullopt, /*compute_precision*/ se::blas::kDefaultComputePrecision, /*grad_x*/ false, /*grad_y*/ false, - executor->GetDeviceDescription().gpu_compute_capability()); + stream_executor->GetDeviceDescription().gpu_compute_capability()); ASSERT_TRUE(config.ok()); // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence::Builder builder; - builder.Emplace( + CommandBufferCmdSequence commands; + commands.Emplace( s0, config.value(), se::gpu::BlasLt::Epilogue::kDefault, 0, slice_a, slice_b, slice_c, slice_d, BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), BufferAllocation::Slice(), slice_workspace); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. 
- CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); std::vector a_arr_1{1, 2, 3, 4, 5, 6, 7, 8}; std::vector a_arr_2{2, 3, 4, 5, 6, 7, 8, 9}; @@ -976,26 +1002,26 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { auto run_cublaslt_test = [&](std::unique_ptr& stream, std::vector a_arr, std::vector result) { - se::DeviceMemory a = executor->AllocateArray(2 * 4); + se::DeviceMemory a = stream_executor->AllocateArray(2 * 4); TF_ASSERT_OK(stream->Memcpy(&a, a_arr.data(), a_length)); - se::DeviceMemory b = executor->AllocateArray(4 * 3); + se::DeviceMemory b = stream_executor->AllocateArray(4 * 3); std::vector b_arr(12, 1); TF_ASSERT_OK(stream->Memcpy(&b, b_arr.data(), b_length)); - se::DeviceMemory c = executor->AllocateArray(2 * 3); + se::DeviceMemory c = stream_executor->AllocateArray(2 * 3); std::vector c_arr(6, 1); TF_ASSERT_OK(stream->Memcpy(&c, c_arr.data(), c_length)); - se::DeviceMemory d = executor->AllocateArray(2 * 3); + se::DeviceMemory d = stream_executor->AllocateArray(2 * 3); TF_ASSERT_OK(stream->MemZero(&d, d_length)); se::DeviceMemory workspace = - executor->AllocateArray(1024 * 1024); + stream_executor->AllocateArray(1024 * 1024); TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a, b, c, d, workspace}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( @@ -1003,7 +1029,7 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; TF_ASSERT_OK(thunk.Initialize( - {executor, source, &allocations, stream.get(), stream.get()})); + {stream_executor, source, &allocations, stream.get(), stream.get()})); // Execute command buffer thunk and verify that it executed a GEMM. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1016,7 +1042,8 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { ASSERT_EQ(dst, result); // Prepare buffer allocation for updating command buffer. - se::DeviceMemory updated_d = executor->AllocateArray(2 * 3); + se::DeviceMemory updated_d = + stream_executor->AllocateArray(2 * 3); TF_ASSERT_OK(stream->MemZero(&updated_d, d_length)); // Update buffer allocation to updated `d` buffer. 
@@ -1055,18 +1082,22 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { } TEST(CommandBufferThunkTest, MultipleLaunchCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); - se::DeviceMemory c = executor->AllocateArray(length, 0); - se::DeviceMemory d = executor->AllocateArray(length, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); + se::DeviceMemory b = + stream_executor->AllocateArray(length, 0); + se::DeviceMemory c = + stream_executor->AllocateArray(length, 0); + se::DeviceMemory d = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); TF_ASSERT_OK(stream->MemZero(&b, byte_length)); @@ -1090,30 +1121,31 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { MemoryAccess::kWrite}; // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - builder.Emplace(s0, "AddI32", args_1, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + commands.Emplace(s0, "AddI32", args_1, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({a, b, c, d}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); - TF_ASSERT_OK( - thunk.Initialize({executor, static_cast(source), - &allocations, stream.get()})); + TF_ASSERT_OK(thunk.Initialize({stream_executor, + static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1133,7 +1165,8 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { BufferAllocation::Slice slice_e(&alloc_e, 0, byte_length); // Prepare buffer allocation for updating command buffer: e=0 - se::DeviceMemory e = executor->AllocateArray(length, 0); + se::DeviceMemory e = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->MemZero(&e, byte_length)); // Update buffer allocation #1 to buffer `c`. 
@@ -1172,21 +1205,24 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { } TEST(CommandBufferThunkTest, CaseCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - if (!IsAtLeastCuda12300(executor)) { + if (!IsAtLeastCuda12300(stream_executor)) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: index=0, a=42, b=0 - se::DeviceMemory index = executor->AllocateArray(1, 0); - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory index = + stream_executor->AllocateArray(1, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); + se::DeviceMemory b = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&index, 0, sizeof(int32_t))); TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); @@ -1202,51 +1238,54 @@ TEST(CommandBufferThunkTest, CaseCmd) { BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); // Prepare commands sequence for branches. - std::vector branches_builder(2); + std::vector branches_sequence(2); auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, MemoryAccess::kWrite}; { // Case 0: b = a + a auto args = {slice_a, slice_a, slice_b}; - branches_builder[0].Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); + branches_sequence[0].Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); } { // Case 1: b = b + b auto args = {slice_b, slice_b, slice_b}; - branches_builder[1].Emplace(s0, "AddI32", args, args_access, - LaunchDimensions(1, 4), - /*shmem_bytes=*/0); + branches_sequence[1].Emplace(s0, "AddI32", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); } - std::vector branches(2); + std::vector branches(2); TF_ASSERT_OK_AND_ASSIGN(branches[0], - std::move(branches_builder[0]).Build(serialize)); + CommandBufferCmdExecutor::Create( + std::move(branches_sequence[0]), serialize)); TF_ASSERT_OK_AND_ASSIGN(branches[1], - std::move(branches_builder[1]).Build(serialize)); + CommandBufferCmdExecutor::Create( + std::move(branches_sequence[1]), serialize)); // Prepare commands sequence for thunk. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice_i, false, std::move(branches)); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_i, false, std::move(branches)); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. 
- CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({index, a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); - TF_ASSERT_OK( - thunk.Initialize({executor, static_cast(source), - &allocations, stream.get()})); + TF_ASSERT_OK(thunk.Initialize({stream_executor, + static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1269,23 +1308,27 @@ TEST(CommandBufferThunkTest, CaseCmd) { } TEST(CommandBufferThunkTest, WhileCmd) { - se::StreamExecutor* executor = GpuExecutor(); + se::StreamExecutor* stream_executor = GpuExecutor(); - if (!IsAtLeastCuda12300(executor)) { + if (!IsAtLeastCuda12300(stream_executor)) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_executor->CreateStream()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: loop_cnt=0, num_iters=10, a=1, b=0 - se::DeviceMemory pred = executor->AllocateArray(1, 0); - se::DeviceMemory loop_cnt = executor->AllocateArray(1, 0); - se::DeviceMemory num_iters = executor->AllocateArray(1, 0); - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory pred = stream_executor->AllocateArray(1, 0); + se::DeviceMemory loop_cnt = + stream_executor->AllocateArray(1, 0); + se::DeviceMemory num_iters = + stream_executor->AllocateArray(1, 0); + se::DeviceMemory a = + stream_executor->AllocateArray(length, 0); + se::DeviceMemory b = + stream_executor->AllocateArray(length, 0); TF_ASSERT_OK(stream->Memset32(&loop_cnt, 0, sizeof(int32_t))); TF_ASSERT_OK(stream->Memset32(&num_iters, 10, sizeof(int32_t))); @@ -1314,33 +1357,36 @@ TEST(CommandBufferThunkTest, WhileCmd) { MemoryAccess::kWrite}; // Prepare commands sequence for loop `cond`. - CommandBufferCmdSequence::Builder cond_commands_builder; - cond_commands_builder.Emplace( - s0, "IncAndCmp", cond_args, cond_args_access, LaunchDimensions(1, 1), - /*shmem_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence cond_commands, - std::move(cond_commands_builder).Build(serialize)); + CommandBufferCmdSequence cond_commands; + cond_commands.Emplace(s0, "IncAndCmp", cond_args, cond_args_access, + LaunchDimensions(1, 1), + /*shmem_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor cond_executor, + CommandBufferCmdExecutor::Create(std::move(cond_commands), serialize)); // Prepare commands sequence for loop `body`. 
- CommandBufferCmdSequence::Builder body_commands_builder; - body_commands_builder.Emplace( - s0, "AddI32", body_args, body_args_access, LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence body_commands, - std::move(body_commands_builder).Build(serialize)); + CommandBufferCmdSequence body_commands; + body_commands.Emplace(s0, "AddI32", body_args, body_args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor body_executor, + CommandBufferCmdExecutor::Create(std::move(body_commands), serialize)); // Prepare commands sequence for thunk. - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice_pred, std::move(cond_commands), - std::move(body_commands)); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_pred, std::move(cond_executor), + std::move(body_executor)); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); + se::StreamExecutorMemoryAllocator allocator(stream_executor); BufferAllocations allocations({pred, loop_cnt, num_iters, a, b}, 0, &allocator); @@ -1348,9 +1394,9 @@ TEST(CommandBufferThunkTest, WhileCmd) { run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); - TF_ASSERT_OK( - thunk.Initialize({executor, static_cast(source), - &allocations, stream.get()})); + TF_ASSERT_OK(thunk.Initialize({stream_executor, + static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value 10 times. 
TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1483,15 +1529,16 @@ ENTRY main.49 { TEST(CommandBufferThunkTest, ToStringPrintsNestedThunks) { BufferAllocation alloc_a(/*index=*/0, /*size=*/4, /*color=*/0); BufferAllocation::Slice slice_a(&alloc_a, /*offset=*/0, /*size=*/4); - CommandBufferCmdSequence::Builder builder; - builder.Emplace(s0, slice_a, int32_t{42}); - TF_ASSERT_OK_AND_ASSIGN(CommandBufferCmdSequence commands, - std::move(builder).Build(serialize)); + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_a, int32_t{42}); + TF_ASSERT_OK_AND_ASSIGN( + CommandBufferCmdExecutor executor, + CommandBufferCmdExecutor::Create(std::move(commands), serialize)); std::vector> thunks; thunks.emplace_back( std::make_unique(Thunk::ThunkInfo(), 42, slice_a)); CommandBufferThunk thunk( - std::move(commands), Thunk::ThunkInfo(), + std::move(executor), Thunk::ThunkInfo(), std::make_unique(Thunk::ThunkInfo(), std::move(thunks))); EXPECT_TRUE( absl::StrContains(thunk.ToString(/*indent=*/1), " kMemset32BitValue")); diff --git a/third_party/xla/xla/backends/gpu/runtime/for_all_thunks_test.cc b/third_party/xla/xla/backends/gpu/runtime/for_all_thunks_test.cc index 777634369c385b..7323f9e7318a63 100644 --- a/third_party/xla/xla/backends/gpu/runtime/for_all_thunks_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/for_all_thunks_test.cc @@ -83,7 +83,7 @@ TEST(ForAllThunksTest, CommandBufferThunk) { Thunk::ThunkInfo(), std::move(thunk_sequence)); Thunk* sequential_thunk_ptr = sequential_thunk.get(); - CommandBufferThunk command_buffer_thunk(CommandBufferCmdSequence(), + CommandBufferThunk command_buffer_thunk(CommandBufferCmdExecutor(), Thunk::ThunkInfo(), std::move(sequential_thunk)); EXPECT_THAT(GetAllThunks(&command_buffer_thunk), diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 7e156db5f36730..63ccf0f9e8abf5 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -567,19 +567,19 @@ absl::Status IrEmitterUnnested::EmitCommandBufferThunk( // Maybe serialize all commands in a sequence by forcing barriers between all // recorded commands. This guarantees that we execute all device operations // in the exact same order as a thunk sequence. - CommandBufferCmdSequence::SynchronizationMode synchronization_mode = + CommandBufferCmdExecutor::SynchronizationMode synchronization_mode = ir_emitter_context_->debug_options() .xla_gpu_graph_enable_concurrent_region() - ? CommandBufferCmdSequence::SynchronizationMode::kAutomatic - : CommandBufferCmdSequence::SynchronizationMode::kSerialize; + ? CommandBufferCmdExecutor::SynchronizationMode::kAutomatic + : CommandBufferCmdExecutor::SynchronizationMode::kSerialize; TF_ASSIGN_OR_RETURN( - CommandBufferCmdSequence cmd_sequence, + CommandBufferCmdExecutor cmd_executor, ConvertToCommands(thunk_sequence->thunks(), ConvertToCommandsOptions{synchronization_mode})); AddThunkToThunkSequence(std::make_unique( - std::move(cmd_sequence), Thunk::ThunkInfo::WithProfileAnnotation(instr), + std::move(cmd_executor), Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(thunk_sequence), ir_emitter_context_->debug_options() .xla_enable_command_buffers_during_profiling())); From 0320d552d03c76f653531bceeb8e596f896423dd Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower"
Date: Mon, 14 Apr 2025 22:25:37 -0700
Subject: [PATCH 0782/1324] [XLA:GPU] Move collective-specific allocation logic to the gpu_collectives

PiperOrigin-RevId: 747714042
---
 .../gpu/collectives/gpu_collectives.h         |  5 +++
 .../gpu/collectives/gpu_collectives_stub.h    |  5 +++
 .../gpu/collectives/nccl_collectives.cc       | 27 +++++++++++
 .../gpu/collectives/nccl_collectives.h        |  4 ++
 .../xla/xla/stream_executor/cuda/BUILD        |  4 +-
 .../stream_executor/cuda/cuda_collectives.cc  | 45 +++++++------------
 6 files changed, 61 insertions(+), 29 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h
index 7e6fbe79cfbf32..fbd88b28063585 100644
--- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h
+++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h
@@ -103,6 +103,11 @@ class GpuCollectives : public Collectives {
   // Tries to cast a Collectives::Config to a GpuCollectives::Config.
   static absl::StatusOr TryCast(
       const Collectives::Config* config);
+
+  // TODO(b/410686553): Use smart wrapper instead of void*.
+  virtual absl::StatusOr<void*> Allocate(uint64_t bytes) = 0;
+
+  virtual absl::Status Deallocate(void* buffer) = 0;
 };
 
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h
index f034e3dd01eaba..f217f4ccd3621d 100644
--- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h
+++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h
@@ -64,6 +64,11 @@ class GpuCollectivesStub : public GpuCollectives {
   absl::Status GroupStart() final { return UnimplementedError(); }
   absl::Status GroupEnd() final { return UnimplementedError(); }
 
+  absl::StatusOr<void*> Allocate(uint64_t bytes) final {
+    return UnimplementedError();
+  }
+
+  absl::Status Deallocate(void* buffer) final { return UnimplementedError(); }
 
 protected:
  static absl::Status UnimplementedError() {
diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc
index 832d67ed32600d..bc84fbef3c62d0 100644
--- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc
+++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc
@@ -227,6 +227,33 @@ absl::Status NcclCollectives::GroupEnd() {
   return XLA_NCCL_STATUS(ncclGroupEnd());
 }
 
+absl::StatusOr<void*> NcclCollectives::Allocate(uint64_t bytes) {
+  void* ptr = nullptr;
+  ncclResult_t res = ncclMemAlloc(&ptr, bytes);
+  if (res != ncclSuccess) {
+    return absl::InternalError(absl::StrFormat(
+        "failed to allocate %s (%llu bytes) from device collective memory: %s, "
+        "Last NCCL warning(error) log entry (may be unrelated): %s",
+        tsl::strings::HumanReadableNumBytes(bytes), bytes,
+        ncclGetErrorString(res), ncclGetLastError(nullptr)));
+  }
+  VLOG(2) << "Allocated collective memory " << ptr << " of " << bytes
+          << " bytes";
+  return ptr;
+}
+
+absl::Status NcclCollectives::Deallocate(void* location) {
+  ncclResult_t res = ncclMemFree(location);
+  if (res != ncclSuccess) {
+    return absl::InternalError(absl::StrFormat(
+        "failed to free device collective memory at %p; result: %s, Last NCCL "
+        "warning(error) log entry (may be unrelated): %s",
+        location, ncclGetErrorString(res), ncclGetLastError(nullptr)));
+  }
+
+  VLOG(2) << "Deallocated collective memory " << location;
+  return absl::OkStatus();
+}
 }  // namespace xla::gpu
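Downstream, the CUDA change at the end of this patch consumes these new hooks by resolving the registered "gpu" backend and forwarding to it. A minimal caller sketch; `bytes` is a placeholder, and the lookup-and-cast mirrors the `GetGpuCollectives` helper added to cuda_collectives.cc later in this patch:

```
// Sketch: backend-agnostic collective memory allocation via the new
// GpuCollectives interface. Assumes a "gpu" backend (NCCL here) is
// registered; `bytes` is a hypothetical size.
TF_ASSIGN_OR_RETURN(xla::Collectives * collectives,
                    xla::CollectivesRegistry::Default("gpu"));
auto* gpu = tsl::down_cast<xla::gpu::GpuCollectives*>(collectives);

TF_ASSIGN_OR_RETURN(void* buffer, gpu->Allocate(bytes));
// ... hand `buffer` to collective operations ...
TF_RETURN_IF_ERROR(gpu->Deallocate(buffer));
```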
XLA_COLLECTIVES_REGISTER("gpu", "nccl", 1, diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h index 8245c8fbdacaef..36860f1360f00b 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h @@ -57,6 +57,10 @@ class NcclCollectives : public GpuCollectives { absl::StatusOr>> SplitCommunicators( absl::Span comms, int32_t color, absl::Span keys, const Collectives::Config& config) final; + + absl::StatusOr Allocate(uint64_t bytes) final; + + absl::Status Deallocate(void* location) final; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 8567306a767a9c..5f49fe61368c14 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -235,6 +235,9 @@ cc_library( "manual", ], deps = [ + "//xla/backends/gpu/collectives:gpu_collectives", + "//xla/core/collectives", + "//xla/core/collectives:collectives_registry", "//xla/stream_executor:activate_context", "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log", @@ -243,7 +246,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_config_nccl//:nccl", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", ], ) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc index 8721ddef66f2d6..c8abf51384b10e 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc @@ -21,50 +21,39 @@ limitations under the License. 
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
-#include "third_party/nccl/nccl.h"
+#include "xla/backends/gpu/collectives/gpu_collectives.h"
+#include "xla/core/collectives/collectives.h"
+#include "xla/core/collectives/collectives_registry.h"
 #include "xla/stream_executor/activate_context.h"
 #include "xla/stream_executor/stream_executor.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/numbers.h"
 
 namespace stream_executor::gpu {
 
+absl::StatusOr<xla::gpu::GpuCollectives*> GetGpuCollectives(
+    StreamExecutor* executor) {
+  std::unique_ptr<ActivateContext> activation = executor->Activate();
+  TF_ASSIGN_OR_RETURN(xla::Collectives * collectives,
+                      xla::CollectivesRegistry::Default("gpu"));
+  return tsl::down_cast<xla::gpu::GpuCollectives*>(collectives);
+}
+
 /* static */ absl::StatusOr<void*> CudaCollectives::CollectiveMemoryAllocate(
     StreamExecutor* executor, uint64_t bytes) {
   if (bytes == 0) return nullptr;
 
   std::unique_ptr<ActivateContext> activation = executor->Activate();
-
-  void* ptr = nullptr;
-  ncclResult_t res = ncclMemAlloc(&ptr, bytes);
-  if (res != ncclSuccess) {
-    return absl::InternalError(absl::StrFormat(
-        "failed to allocate %s (%llu bytes) from device collective memory: %s, "
-        "Last NCCL warning(error) log entry (may be unrelated): %s",
-        tsl::strings::HumanReadableNumBytes(bytes), bytes,
-        ncclGetErrorString(res), ncclGetLastError(nullptr)));
-  }
-  VLOG(2) << "Allocated collective memory " << ptr << " for executor "
-          << executor << " of " << bytes << " bytes";
-  return ptr;
+  TF_ASSIGN_OR_RETURN(xla::gpu::GpuCollectives * gpu_collectives,
+                      GetGpuCollectives(executor));
+  return gpu_collectives->Allocate(bytes);
 }
 
 /* static */ absl::Status CudaCollectives::CollectiveMemoryDeallocate(
     StreamExecutor* executor, void* location) {
   std::unique_ptr<ActivateContext> activation = executor->Activate();
 
-  ncclResult_t res = ncclMemFree(location);
-  if (res != ncclSuccess) {
-    return absl::InternalError(absl::StrFormat(
-        "failed to free device collective memory at %p; result: %s, Last NCCL "
-        "warning(error) log entry (may be unrelated): %s",
-        location, ncclGetErrorString(res), ncclGetLastError(nullptr)));
-  }
-
-  VLOG(2) << "Deallocated collective memory " << location << " for executor "
-          << executor;
-  return absl::OkStatus();
+  TF_ASSIGN_OR_RETURN(xla::gpu::GpuCollectives * gpu_collectives,
+                      GetGpuCollectives(executor));
+  return gpu_collectives->Deallocate(location);
 }
 
 }  // namespace stream_executor::gpu

From ea60ee22f7544973af125e86165de5b57cc26c18 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Mon, 14 Apr 2025 22:29:53 -0700
Subject: [PATCH 0783/1324] Run build_cleaner on BUILD file(s) located in /xla/gpu/llvm_gpu_backend.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I've encountered a few CLs that attempted to fix this in local directories,
so I figured I'd run this for all of xla to fix the low-hanging fruit.

It resolves several unnecessary & missing dependencies and simplifies
target paths, but not all of them. Here are the issues that came up that I
didn't attempt to fix entirely:
* conflicts that need manual handling
* conflicts that need choosing between two "valid" targets
* missing BUILD in a directory
* missing target for a file (e.g. a python script)
* missing targets for some `bzl_library`
* platform-specific code (e.g. rocm)
* ones that use filegroup instead of individual cc_library
* and more.
Before:
```
metric    median             Δ               1-pval
cpu:      3590.690s ±91.6s
memory:   4533MB ±2.6MB
system:   594.230s ±10.5s
wall:     907.605s ±83.0s
```
After:
```
metric    median             Δ               1-pval
cpu:      3599.015s ±131.4s  +8.3s, +0.2%    0.03 (not significant)
memory:   4533MB ±2.3MB      +0.0MB, +0.0%   0.00 (not significant)
system:   582.305s ±9.1s     -11.9s, -2.0%   0.25 (not significant)
wall:     808.958s ±95.5s    -98.6s, -10.9%  0.57 (not significant)
```

Overall, this yields a modest saving of ~1 minute of wall (physical) time.
Since I've excluded some execution tests under `stream_executor/` and
`service/`, the actual savings may be greater. It's a small improvement, but it
should pay dividends in the long run.

Note: I'll be sending a series of CLs that fix these in batches of
subdirectories to simplify merging.

PiperOrigin-RevId: 747715298
---
 .../xla/xla/service/gpu/llvm_gpu_backend/BUILD | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
index e5e70db9a6b8a7..2bde15b7094cc1 100644
--- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -138,38 +138,26 @@ cc_library(
     deps = [
         ":llvm_gpu_backend",
         ":load_ir_module",
-        ":utils",
-        "//xla:status_macros",
-        "//xla:types",
         "//xla:util",
         "//xla:xla_proto_cc",
-        "//xla/service/gpu:metrics",
         "//xla/service/llvm_ir:llvm_command_line_options",
         "//xla/service/llvm_ir:llvm_type_conversion_util",
         "//xla/stream_executor:device_description",
-        "//xla/stream_executor:semantic_version",
-        "//xla/stream_executor/cuda:subprocess_compilation",
         "//xla/tsl/platform:rocm_rocdl_path",
         "//xla/tsl/util:env_var",
         "@com_google_absl//absl/base",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",
-        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/synchronization",
         "@llvm-project//llvm:AMDGPUAsmParser",
-        "@llvm-project//llvm:AMDGPUCodeGen",
         "@llvm-project//llvm:Analysis",
         "@llvm-project//llvm:BitReader",
         "@llvm-project//llvm:BitWriter",
         "@llvm-project//llvm:CodeGen",
         "@llvm-project//llvm:Core",
         "@llvm-project//llvm:IPO",
-        "@llvm-project//llvm:IRReader",
         "@llvm-project//llvm:Linker",
         "@llvm-project//llvm:MC",
         "@llvm-project//llvm:ObjCARC",  # buildcleaner: keep
@@ -177,10 +165,6 @@ cc_library(
         "@llvm-project//llvm:Scalar",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:Target",
-        "@llvm-project//mlir:NVVMDialect",
-        "@local_config_cuda//cuda:cuda_headers",
-        "@local_config_rocm//rocm:rocm_headers",
-        "@local_tsl//tsl/platform:cuda_root_path",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:logging",
@@ -188,7 +172,6 @@ cc_library(
         "@local_tsl//tsl/platform:random",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/profiler/lib:scoped_annotation",
         "@local_tsl//tsl/profiler/lib:traceme",
     ],
 )
@@ -249,7 +232,6 @@ xla_cc_test(
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:semantic_version",
         "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:test",
     ],
 )

From 937c36163c3a993cfdaee274cca308eac2a173f9 Mon Sep 17 00:00:00 2001
From: "A.
Unique TensorFlower" Date: Mon, 14 Apr 2025 22:30:30 -0700 Subject: [PATCH 0784/1324] Automated Code Change PiperOrigin-RevId: 747715488 --- .../xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc | 6 +++--- third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc | 2 +- third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h | 2 +- .../hlo/translate/hlo_to_mhlo/module_attributes_importer.cc | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc index 8b35075984b148..100bd69ed7b084 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -124,15 +124,15 @@ std::string SanitizeFunctionName(llvm::StringRef name) { bool DotIsDefault(const HloInstruction* instruction) { // If LHS/RHS has rank greater than 2, not default dot const auto& operands = instruction->operands(); - if (operands[0]->shape().dimensions_size() > 2 || - operands[1]->shape().dimensions_size() > 2) { + if (operands[0]->shape().dimensions().size() > 2 || + operands[1]->shape().dimensions().size() > 2) { return false; } auto dnums = instruction->dot_dimension_numbers(); DotDimensionNumbers default_dimension_numbers; default_dimension_numbers.add_lhs_contracting_dimensions( - instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1); + instruction->operand(0)->shape().dimensions().size() == 1 ? 0 : 1); default_dimension_numbers.add_rhs_contracting_dimensions(0); return protobuf_util::HaveSameSerialization(dnums, default_dimension_numbers); } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index ccf55a9739a41e..5bd16a05ee9489 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -115,7 +115,7 @@ absl::StatusOr GetPermutationIfAvailable(const Shape& shape, return Internal("Permutations for dynamic shapes are not yet supported"); } int64_t accumulated_stride = 1; - llvm::SmallVector strides(shape.dimensions_size(), 1); + llvm::SmallVector strides(shape.dimensions().size(), 1); for (int64_t dim : LayoutUtil::MinorToMajor(shape)) { strides[dim] = accumulated_stride; accumulated_stride *= shape.dimensions(dim); diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h index 19ef3e82fe760c..a19d0da23e5fc1 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h @@ -67,7 +67,7 @@ static absl::StatusOr ConvertTensorShapeToType(const Shape& xla_ty, if (!element_type_or.ok()) return element_type_or.status(); bool is_bounded_dynamic = false; - int64_t rank = xla_ty.dimensions_size(); + int64_t rank = xla_ty.dimensions().size(); llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); llvm::SmallVector bounds(rank, mlir::ShapedType::kDynamic); for (int64_t dim = 0; dim < rank; ++dim) { diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc index 821b2317997206..f6e12ff07e7cfe 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc +++ 
b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc @@ -348,7 +348,7 @@ void ImportParameterLayoutModes(mlir::func::FuncOp main, CHECK_EQ(parameter_shapes.size(), main.getNumArguments()); for (size_t i = 0; i < main.getNumArguments(); ++i) { const Shape& shape = *parameter_shapes[i]; - if (shape.IsTuple() || (shape.IsArray() && shape.dimensions_size() == 0)) + if (shape.IsTuple() || (shape.IsArray() && shape.dimensions().size() == 0)) continue; if (LayoutUtil::HasAnyLayout(*parameter_shapes[i])) continue; main.setArgAttrs( @@ -368,7 +368,7 @@ void ImportResultLayoutModes(mlir::func::FuncOp main, CHECK_EQ(result_shapes.size(), main.getNumResults()); for (size_t i = 0; i < main.getNumResults(); ++i) { const Shape& shape = *result_shapes[i]; - if (shape.IsTuple() || (shape.IsArray() && shape.dimensions_size() == 0)) + if (shape.IsTuple() || (shape.IsArray() && shape.dimensions().size() == 0)) continue; if (LayoutUtil::HasAnyLayout(shape)) continue; main.setResultAttrs( From 8bf8bedfee88382acf5259a4939b60d350d488e3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 22:39:10 -0700 Subject: [PATCH 0785/1324] Automated Code Change PiperOrigin-RevId: 747717957 --- .../xla/xla/hlo/evaluator/hlo_evaluator.cc | 64 +++++++++---------- .../evaluator/hlo_evaluator_typed_visitor.h | 22 +++---- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index e89706164bf22b..d2d0ad281ede78 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -402,7 +402,7 @@ std::optional PatternMatchLoopCondRoot( if (Match(loop_cond_root, match::GetTupleElement().WithOperand( 0, match::Parameter().WithParameterNum(0)))) { if (loop_cond_root->shape().element_type() != PrimitiveType::PRED && - loop_cond_root->shape().dimensions_size() != 0) { + loop_cond_root->shape().dimensions().size() != 0) { return std::nullopt; } return ParamIndexAndValue{{/*param_index=*/loop_cond_root->tuple_index()}}; @@ -1172,7 +1172,7 @@ std::vector HloEvaluator::GetS64Indices( } DimensionVector HloEvaluator::MakeDimMultipliers(const Shape& shape) { - DimensionVector v(shape.dimensions_size()); + DimensionVector v(shape.dimensions().size()); int64_t scale = 1; for (auto dim : LayoutUtil::MinorToMajor(shape)) { v[dim] = scale; @@ -1434,7 +1434,7 @@ absl::Status HloEvaluator::HandleConcatenate( // concatenate dimensions of the operands taking part of the operation. 
const Shape& reference_shape = operands[0]->shape(); CHECK(reference_shape.IsArray()); - const int64_t rank = reference_shape.dimensions_size(); + const int64_t rank = reference_shape.dimensions().size(); const int64_t concat_dim = concatenate->dimensions()[0]; CHECK_GE(concat_dim, 0); CHECK_LT(concat_dim, rank); @@ -1811,7 +1811,7 @@ class FftTransform { return false; }; GenerateIndices(output_lengths, output_strides, input_lengths, - input_strides, input_shape.dimensions_size(), 0, 0, + input_strides, input_shape.dimensions().size(), 0, 0, base_case); } @@ -2276,7 +2276,7 @@ class FftTransform { return InvalidArgument("Invalid input type: %d, must be %d (complex64).", input_elt_type, PrimitiveType::C64); } - const int64_t input_rank = input_shape.dimensions_size(); + const int64_t input_rank = input_shape.dimensions().size(); if (input_rank < fft_rank_) { return InvalidArgument("Input shape rank is smaller than FFT rank."); } @@ -2295,7 +2295,7 @@ class FftTransform { return InvalidArgument("Invalid output type: %d, must be %d (complex64).", output_elt_type, PrimitiveType::C64); } - const int64_t output_rank = output_shape.dimensions_size(); + const int64_t output_rank = output_shape.dimensions().size(); if (output_rank < fft_rank_) { return InvalidArgument("Output shape rank is smaller than FFT rank."); } @@ -2340,7 +2340,7 @@ absl::Status HloEvaluator::HandleFft(const HloInstruction* fft) { // dimensions while keeping the rest of the output dimensions clamped to 0. ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) { - int64_t output_rank = output_shape.dimensions_size(); + int64_t output_rank = output_shape.dimensions().size(); std::vector index_base(output_rank, 0); std::vector index_count; index_count.reserve(output_rank); @@ -2391,12 +2391,12 @@ class OutputBatchIndexToInputIndex { const GatherDimensionNumbers* dim_numbers, const Shape& input_shape, const Shape& output_shape, const Literal* start_indices) : dim_numbers_(*dim_numbers), start_indices_(*start_indices) { - for (int64_t i = 0; i < output_shape.dimensions_size(); i++) { + for (int64_t i = 0; i < output_shape.dimensions().size(); i++) { output_dim_is_batch_dims_.push_back( !absl::c_binary_search(dim_numbers_.offset_dims(), i)); } - for (int64_t i = 0; i < input_shape.dimensions_size(); i++) { + for (int64_t i = 0; i < input_shape.dimensions().size(); i++) { int64_t index_of_input_dim_in_index_vector = std::distance(dim_numbers_.start_index_map().begin(), absl::c_find(dim_numbers_.start_index_map(), i)); @@ -2409,8 +2409,8 @@ class OutputBatchIndexToInputIndex { } } - index_vector_index_.resize(start_indices_.shape().dimensions_size()); - input_index_.resize(input_shape.dimensions_size()); + index_vector_index_.resize(start_indices_.shape().dimensions().size()); + input_index_.resize(input_shape.dimensions().size()); int64_t index_vector_size = start_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); index_vector_.resize(index_vector_size); @@ -2419,8 +2419,8 @@ class OutputBatchIndexToInputIndex { GetStartIndicesDimToOutputDimForExplicitBatchingDims( dim_numbers_.start_indices_batching_dims(), dim_numbers_.index_vector_dim(), dim_numbers_.offset_dims(), - start_indices_.shape().dimensions_size(), - output_shape.dimensions_size()); + start_indices_.shape().dimensions().size(), + output_shape.dimensions().size()); for (int64_t i = 0; i < dim_numbers->operand_batching_dims().size(); ++i) { int64_t operand_dim = 
dim_numbers->operand_batching_dims(i); int64_t start_indices_dim = dim_numbers->start_indices_batching_dims(i); @@ -2544,7 +2544,7 @@ class OutputOffsetIndexToInputIndex { const GatherDimensionNumbers& dim_numbers, const Shape& input_shape) { CHECK(absl::c_is_sorted(dim_numbers.offset_dims())); int64_t window_dim_count = 0; - for (int64_t i = 0; i < input_shape.dimensions_size(); i++) { + for (int64_t i = 0; i < input_shape.dimensions().size(); i++) { if (IsCollapsedOrBatchingDim(dim_numbers.collapsed_slice_dims(), dim_numbers.operand_batching_dims(), i)) { input_dim_value_to_output_index_.push_back(-1); @@ -2554,7 +2554,7 @@ class OutputOffsetIndexToInputIndex { } } - input_index_.resize(input_shape.dimensions_size()); + input_index_.resize(input_shape.dimensions().size()); } // Returns the contribution of the window indices to the input index @@ -2609,7 +2609,7 @@ class OutputOffsetIndexToInputIndex { static absl::StatusOr> ReshapedGatherIndices(int64_t index_vector_dim, const Literal& start_indices, Literal* reshaped_start_indices) { - if (start_indices.shape().dimensions_size() != index_vector_dim) { + if (start_indices.shape().dimensions().size() != index_vector_dim) { return std::cref(start_indices); } @@ -2649,13 +2649,13 @@ absl::Status HloEvaluator::HandleGather(const HloInstruction* gather) { IterationSpaceForOutputBatchIndices(shape, dim_numbers); ShapeUtil::IndexIterationSpace offset_indices_iteration_space = IterationSpaceForOutputOffsetIndices( - shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers); + shape.dimensions().size(), gather->gather_slice_sizes(), dim_numbers); // Scratch buffers that hold an index in the output shape and the // corresponding index in the input shape. - std::vector input_index(operand.shape().dimensions_size()); - std::vector output_index(gather->shape().dimensions_size()); - std::vector input_index_clamped(operand.shape().dimensions_size()); + std::vector input_index(operand.shape().dimensions().size()); + std::vector output_index(gather->shape().dimensions().size()); + std::vector input_index_clamped(operand.shape().dimensions().size()); OutputBatchIndexToInputIndex output_batch_index_to_input_index( &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), @@ -2731,7 +2731,7 @@ namespace { absl::StatusOr> ReshapedScatterIndices( int64_t index_vector_dim, const Literal& indices, Literal* reshaped_indices) { - if (indices.shape().dimensions_size() != index_vector_dim) { + if (indices.shape().dimensions().size() != index_vector_dim) { return std::cref(indices); } @@ -2829,7 +2829,7 @@ class UpdateScatterIndexToInputIndex { } } - index_vector_index_.resize(scatter_indices_.shape().dimensions_size()); + index_vector_index_.resize(scatter_indices_.shape().dimensions().size()); input_index_.resize(input_rank); int64_t index_vector_size = scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); @@ -2839,7 +2839,7 @@ class UpdateScatterIndexToInputIndex { GetStartIndicesDimToOutputDimForExplicitBatchingDims( dim_numbers_.scatter_indices_batching_dims(), dim_numbers_.index_vector_dim(), dim_numbers_.update_window_dims(), - scatter_indices_.shape().dimensions_size(), updates_rank); + scatter_indices_.shape().dimensions().size(), updates_rank); for (int64_t i = 0; i < dim_numbers.input_batching_dims().size(); ++i) { int64_t input_dim = dim_numbers.input_batching_dims(i); int64_t scatter_indices_dim = @@ -3183,10 +3183,10 @@ absl::Status HloEvaluator::HandleBroadcast(const HloInstruction* broadcast) { 
operand.shape().element_type()) << " broadcast from a different data type is not supported"; TF_RET_CHECK(broadcast->dimensions().size() == - operand.shape().dimensions_size()) + operand.shape().dimensions().size()) << "broadcast dimensions is of size: " << broadcast->dimensions().size() << " and rank of operand_to_broadcast is: " - << operand.shape().dimensions_size(); + << operand.shape().dimensions().size(); // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. for (int64_t i = 0; i < broadcast->dimensions().size(); ++i) { @@ -3536,7 +3536,7 @@ absl::Status HloEvaluator::HandleDynamicUpdateSlice(const HloInstruction* dus) { const Literal& update_literal = GetEvaluatedLiteralFor(update); auto result = operand_literal.Clone(); - const auto rank = result.shape().dimensions_size(); + const auto rank = result.shape().dimensions().size(); std::vector start = GetS64Indices(absl::MakeConstSpan(dus->operands()).subspan(2)); @@ -3556,8 +3556,8 @@ absl::Status HloEvaluator::HandleDynamicUpdateSlice(const HloInstruction* dus) { return true; }; - std::vector base(update_literal.shape().dimensions_size(), 0); - std::vector step(update_literal.shape().dimensions_size(), 1); + std::vector base(update_literal.shape().dimensions().size(), 0); + std::vector step(update_literal.shape().dimensions().size(), 1); ShapeUtil::ForEachIndexNoStatus(update_literal.shape(), base, update_literal.shape().dimensions(), step, func); @@ -3741,7 +3741,7 @@ void IterateThroughWindow( const Shape& window_shape, const Window& window, const Shape& base_shape, const absl::Span window_count_index, const std::function)>& f) { - const int64_t rank = base_shape.dimensions_size(); + const int64_t rank = base_shape.dimensions().size(); DimensionVector window_index(rank); std::fill(window_index.begin(), window_index.end(), 0); do { @@ -3964,7 +3964,7 @@ absl::Status HloEvaluator::HandleSelectAndScatter( const Literal& operand_literal = GetEvaluatedLiteralFor(operand); const Literal& source_literal = GetEvaluatedLiteralFor(source); - int64_t rank = operand_literal.shape().dimensions_size(); + int64_t rank = operand_literal.shape().dimensions().size(); HloEvaluator embedded_evaluator(max_loop_iterations_); DimensionVector source_index(rank, 0); @@ -4043,7 +4043,7 @@ absl::Status HloEvaluator::HandleSlice(const HloInstruction* slice) { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - const int64_t rank = operand->shape().dimensions_size(); + const int64_t rank = operand->shape().dimensions().size(); const Literal& operand_literal = GetEvaluatedLiteralFor(operand); const size_t element_byte_size = primitive_util::ByteWidth(shape.element_type()); @@ -4082,7 +4082,7 @@ absl::Status HloEvaluator::HandleSort(const HloInstruction* sort) { } } Shape key_shape = sort->operand(0)->shape(); - auto rank = key_shape.dimensions_size(); + auto rank = key_shape.dimensions().size(); std::vector result_literals; result_literals.reserve(sort->operand_count()); for (int64_t i = 0; i < sort->operand_count(); ++i) { diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 2bfb9cdc9f0e39..27eb48a6c0989c 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1026,8 +1026,8 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { CHECK_GE(num_spatial_dims, 
0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); - const auto lhs_rank = lhs_shape.dimensions_size(); - const auto rhs_rank = rhs_shape.dimensions_size(); + const auto lhs_rank = lhs_shape.dimensions().size(); + const auto rhs_rank = rhs_shape.dimensions().size(); CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); @@ -1086,8 +1086,8 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { const auto& dnums = dot->dot_dimension_numbers(); - const int64_t lhs_rank = lhs->shape().dimensions_size(); - const int64_t rhs_rank = rhs->shape().dimensions_size(); + const int64_t lhs_rank = lhs->shape().dimensions().size(); + const int64_t rhs_rank = rhs->shape().dimensions().size(); CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); @@ -1151,8 +1151,8 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { const Literal& rhs_literal) { const auto& dnums = dot->dot_dimension_numbers(); - const auto lhs_rank = lhs_literal.shape().dimensions_size(); - const auto rhs_rank = rhs_literal.shape().dimensions_size(); + const auto lhs_rank = lhs_literal.shape().dimensions().size(); + const auto rhs_rank = rhs_literal.shape().dimensions().size(); CHECK(ShapeUtil::SameElementType(lhs_literal.shape(), rhs_literal.shape())); CHECK(ShapeUtil::SameElementType(lhs_literal.shape(), dot->shape())); @@ -1284,7 +1284,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { CHECK(pad->operand(0)->shape().IsArray()); // Padding value must be scalar. CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); - CHECK_EQ(pad->operand(0)->shape().dimensions_size(), + CHECK_EQ(pad->operand(0)->shape().dimensions().size(), pad->padding_config().dimensions_size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, @@ -1321,7 +1321,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); - std::vector target_index(result.shape().dimensions_size(), 0); + std::vector target_index(result.shape().dimensions().size(), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. @@ -1348,9 +1348,9 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { return true; }; - std::vector zero_base(evaluated_operand.shape().dimensions_size(), - 0); - std::vector step(evaluated_operand.shape().dimensions_size(), 1); + std::vector zero_base( + evaluated_operand.shape().dimensions().size(), 0); + std::vector step(evaluated_operand.shape().dimensions().size(), 1); ShapeUtil::ForEachIndexNoStatus(evaluated_operand.shape(), zero_base, evaluated_operand.shape().dimensions(), From 3b4f6665a9f16779997d32a617b0a848815ba204 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 22:48:07 -0700 Subject: [PATCH 0786/1324] Automated Code Change PiperOrigin-RevId: 747720131 --- tensorflow/core/profiler/convert/BUILD | 60 +++++++++++++++++++ .../convert/op_stats_to_pod_stats_test.cc | 3 + .../convert/op_stats_to_pod_viewer.cc | 3 + .../profiler/convert/op_stats_to_pod_viewer.h | 2 + .../convert/op_stats_to_pod_viewer_test.cc | 4 ++ .../convert/op_stats_to_roofline_model.cc | 4 ++ .../convert/op_stats_to_roofline_model.h | 3 + .../profiler/convert/op_stats_to_tf_stats.cc | 3 + .../profiler/convert/op_stats_to_tf_stats.h | 2 + .../convert/op_stats_to_tf_stats_test.cc | 2 + .../convert/preprocess_single_host_xplane.cc | 1 + .../profiler/convert/process_megascale_dcn.cc | 1 + .../convert/step_events_to_steps_db.cc | 1 + .../convert/step_events_to_steps_db.h | 1 + .../convert/xplane_to_dcn_collective_stats.cc | 1 + .../convert/xplane_to_dcn_collective_stats.h | 2 + .../xplane_to_dcn_collective_stats_test.cc | 1 + .../core/profiler/convert/xplane_to_hlo.h | 1 + .../convert/xplane_to_kernel_stats_db.cc | 1 + .../convert/xplane_to_kernel_stats_db.h | 1 + .../convert/xplane_to_kernel_stats_db_test.cc | 1 + .../convert/xplane_to_memory_profile.cc | 2 + .../convert/xplane_to_memory_profile.h | 2 + .../convert/xplane_to_memory_profile_test.cc | 1 + .../convert/xplane_to_op_metrics_db.cc | 1 + .../convert/xplane_to_op_metrics_db.h | 1 + .../convert/xplane_to_op_metrics_db_test.cc | 2 + .../profiler/convert/xplane_to_op_stats.cc | 5 ++ .../profiler/convert/xplane_to_op_stats.h | 1 + .../convert/xplane_to_op_stats_test.cc | 3 + .../profiler/convert/xplane_to_step_events.cc | 2 + .../convert/xplane_to_tf_data_stats.cc | 2 + .../convert/xplane_to_tf_data_stats.h | 1 + .../convert/xplane_to_tf_data_stats_test.cc | 1 + .../convert/xplane_to_tf_functions.cc | 1 + .../profiler/convert/xplane_to_tf_functions.h | 1 + .../convert/xplane_to_tf_functions_test.cc | 1 + .../profiler/convert/xplane_to_tool_names.cc | 1 + .../profiler/convert/xplane_to_tool_names.h | 1 + .../profiler/convert/xplane_to_tools_data.cc | 8 +++ .../profiler/convert/xplane_to_tools_data.h | 1 + .../convert/xplane_to_trace_container.cc | 3 + .../convert/xplane_to_trace_container.h | 2 + .../convert/xplane_to_trace_container_test.cc | 3 + .../convert/xspace_to_dcn_slack_analysis.cc | 3 + .../convert/xspace_to_dcn_slack_analysis.h | 2 + 46 files changed, 149 insertions(+) diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 64675cee9b6c90..f56a989dd45613 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -33,6 +33,7 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:timespan", "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", "@org_xprof//xprof/utils:cost_utils", "@org_xprof//xprof/utils:gpu_event_stats", "@org_xprof//xprof/utils:hlo_module_map", @@ -56,9 +57,11 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", "@org_xprof//xprof/utils:hlo_cost_analysis_wrapper", 
"@org_xprof//xprof/utils:hlo_module_map", "@org_xprof//xprof/utils:op_metrics_db_utils", @@ -137,6 +140,10 @@ cc_library( "@local_tsl//tsl/platform:protobuf", "@local_xla//xla/tsl/profiler/convert:xla_op_utils", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hardware_types_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:roofline_model_proto_cc", "@org_xprof//xprof/utils:diagnostics", ], ) @@ -228,6 +235,9 @@ tf_cc_test( "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", "@org_xprof//xprof/utils:diagnostics", "@org_xprof//xprof/utils:event_span", ], @@ -245,6 +255,9 @@ cc_library( "//tensorflow/core/profiler/protobuf:pod_viewer_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "@com_google_absl//absl/log:check", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_viewer_proto_cc", "@org_xprof//xprof/utils:diagnostics", ], ) @@ -261,6 +274,10 @@ tf_cc_test( "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_viewer_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", "@org_xprof//xprof/utils:diagnostics", "@org_xprof//xprof/utils:event_span", ], @@ -276,7 +293,10 @@ tf_cc_test( "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_raw_proto_cc", ], ) @@ -351,6 +371,9 @@ cc_library( "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_stats_proto_cc", "@org_xprof//xprof/utils:kernel_stats_utils", "@org_xprof//xprof/utils:op_metrics_db_utils", ], @@ -373,6 +396,8 @@ tf_cc_test( "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_stats_proto_cc", ], ) @@ -390,6 +415,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", 
"@com_google_absl//absl/log", "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", "@org_xprof//xprof/utils:event_span", "@org_xprof//xprof/utils:op_metrics_db_utils", ], @@ -432,6 +458,11 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:timespan", "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hardware_types_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:kernel_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_function_proto_cc", "@org_xprof//xprof/utils:device_caps_utils", "@org_xprof//xprof/utils:event_span", "@org_xprof//xprof/utils:gpu_event_stats", @@ -495,6 +526,9 @@ tf_cc_test( "@local_xla//xla/tsl/profiler/convert:xla_op_utils", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", ], ) @@ -519,6 +553,8 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", "@org_xprof//xprof/utils:event_span", "@org_xprof//xprof/utils:op_metrics_db_utils", ], @@ -558,6 +594,7 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:tf_op_utils", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:kernel_stats_proto_cc", "@org_xprof//xprof/utils:gpu_event_stats", "@org_xprof//xprof/utils:hlo_module_map", "@org_xprof//xprof/utils:kernel_stats_utils", @@ -579,6 +616,7 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_test_utils", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:kernel_stats_proto_cc", "@org_xprof//xprof/utils:kernel_stats_utils", ], ) @@ -601,6 +639,7 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_function_proto_cc", ], ) @@ -622,6 +661,7 @@ tf_cc_test( "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_function_proto_cc", ], ) @@ -647,6 +687,7 @@ cc_library( "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:memory_profile_proto_cc", ], ) @@ -665,6 +706,7 @@ tf_cc_test( "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:group_events", + 
"@org_xprof//plugin/tensorboard_plugin_profile/protobuf:memory_profile_proto_cc", ], ) @@ -773,6 +815,14 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/convert:xplane_to_trace_events", "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_slack_analysis_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hlo_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:inference_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_profile_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:overview_page_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_data_stats_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_raw_proto_cc", "@org_xprof//xprof/convert/trace_viewer:trace_events_to_json", "@org_xprof//xprof/convert/trace_viewer:trace_viewer_visibility", "@org_xprof//xprof/utils:hardware_type_utils", @@ -800,6 +850,7 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:tf_op_utils", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_data_stats_proto_cc", "@org_xprof//xprof/utils:html_utils", ], ) @@ -822,6 +873,7 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_test_utils", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_data_stats_proto_cc", ], ) @@ -1042,6 +1094,7 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/platform:statusor", ], ) @@ -1082,6 +1135,8 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", "@local_xla//xla/tsl/profiler/utils:xplane_visitor", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_raw_proto_cc", "@org_xprof//xprof/convert/trace_viewer:trace_event_arguments_builder", "@org_xprof//xprof/convert/trace_viewer:trace_events", "@org_xprof//xprof/convert/trace_viewer:trace_events_util", @@ -1193,6 +1248,9 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", "@local_xla//xla/tsl/profiler/utils:xplane_visitor", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_collective_info_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_slack_analysis_proto_cc", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:topology_proto_cc", "@org_xprof//xprof/utils:hlo_proto_to_module", ], ) @@ -1227,6 +1285,7 @@ cc_library( "@local_xla//xla/tsl/platform:statusor", "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_visitor", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_slack_analysis_proto_cc", ], ) @@ -1248,6 +1307,7 @@ tf_cc_test( "@com_google_googletest//:gtest_main", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/platform:status", + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_slack_analysis_proto_cc", ], ) diff --git 
a/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc index 90909301b11e10..15a8caf84af5fe 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc @@ -21,6 +21,9 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/pod_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof #include "xprof/utils/diagnostics.h" // from @org_xprof #include "xprof/utils/event_span.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc index 6dc5f8e870b2de..536ca0427d2dae 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc @@ -21,6 +21,9 @@ limitations under the License. #include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h" #include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/pod_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/pod_viewer.pb.h" // from @org_xprof #include "xprof/utils/diagnostics.h" // from @org_xprof namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h index c45c99393758b0..0d690000527479 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/pod_viewer.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/pod_viewer.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc index d01fba32ee1cfb..0a5bb58ed7470d 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc @@ -22,6 +22,10 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/pod_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/pod_viewer.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof #include "xprof/utils/diagnostics.h" // from @org_xprof #include "xprof/utils/event_span.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc index 44cca8bdba14a2..613f7048aa92ed 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc @@ -29,6 +29,10 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tsl/platform/protobuf.h" +#include "plugin/tensorboard_plugin_profile/protobuf/hardware_types.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/roofline_model.pb.h" // from @org_xprof #include "xprof/utils/diagnostics.h" // from @org_xprof namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h index f2ed42f783d86e..65b501bc851725 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h +++ b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h @@ -23,6 +23,9 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tsl/platform/protobuf.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/roofline_model.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc index c5b88a817e3766..88da45e892ecd5 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc @@ -21,6 +21,9 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/tf_stats.pb.h" // from @org_xprof #include "xprof/utils/kernel_stats_utils.h" // from @org_xprof #include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h index 3b8a06ef1c6619..79994d5570ec3e 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h +++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/tf_stats.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc index abe9d599d971a9..daaae982635d13 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc @@ -30,6 +30,8 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/tf_stats.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc index 7a1fc581fd3104..9824ef17fc9b53 100644 --- a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc +++ b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc @@ -21,6 +21,7 @@ limitations under the License. #include "xla/tsl/profiler/utils/preprocess_xplane.h" #include "xla/tsl/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "tsl/profiler/protobuf/xplane.pb.h" #include "xprof/utils/derived_timeline.h" // from @org_xprof namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/process_megascale_dcn.cc b/tensorflow/core/profiler/convert/process_megascale_dcn.cc index ab2b9fbefe60bd..febad594d6349b 100644 --- a/tensorflow/core/profiler/convert/process_megascale_dcn.cc +++ b/tensorflow/core/profiler/convert/process_megascale_dcn.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/profiler/convert/dcn_analysis.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc index 9351d1c3b2baa3..ce340db50cb6b3 100644 --- a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc +++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof #include "xprof/utils/event_span.h" // from @org_xprof #include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.h b/tensorflow/core/profiler/convert/step_events_to_steps_db.h index 18c0f6a34819e2..5bb980c32f1e01 100644 --- a/tensorflow/core/profiler/convert/step_events_to_steps_db.h +++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof #include "xprof/utils/event_span.h" // from @org_xprof namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc index ad3bea87341162..9838f57816c584 100644 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h index 68e0b491331bdd..9b4f9ff1bf5845 100644 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h +++ b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h @@ -16,9 +16,11 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ +#include "absl/status/statusor.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/profiler/convert/repository.h" #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc index 2d73bbf8b929d6..068efbd752282e 100644 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_hlo.h b/tensorflow/core/profiler/convert/xplane_to_hlo.h index 2361ba6e13d194..c102f6a24f6a78 100644 --- a/tensorflow/core/profiler/convert/xplane_to_hlo.h +++ b/tensorflow/core/profiler/convert/xplane_to_hlo.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/service/hlo.pb.h" #include "tensorflow/core/platform/statusor.h" diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc index d2360ecefd3924..a6aefc3fe7e188 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/kernel_stats.pb.h" // from @org_xprof #include "xprof/utils/gpu_event_stats.h" // from @org_xprof #include "xprof/utils/kernel_stats_utils.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h index ca6f98fd1515d3..4c003a325de2ee 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/log/log.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/kernel_stats.pb.h" // from @org_xprof #include "xprof/utils/gpu_event_stats.h" // from @org_xprof #include "xprof/utils/hlo_module_map.h" // from @org_xprof #include "xprof/utils/kernel_stats_utils.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc index 03429987be30a9..3e30beeffb41b9 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/kernel_stats.pb.h" // from @org_xprof #include "xprof/utils/kernel_stats_utils.h" // from @org_xprof namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc index f996000579301a..3d60ee248ee48d 100644 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc +++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc @@ -40,6 +40,8 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/platform/protobuf.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/memory_profile.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.h b/tensorflow/core/profiler/convert/xplane_to_memory_profile.h index 00f919d4dbd42e..bc6ceef062132d 100644 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.h +++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.h @@ -18,10 +18,12 @@ limitations under the License. #include +#include "absl/status/status.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/memory_profile.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/memory_profile.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc index a60d505cfc786f..512e81b3b740bd 100644 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/memory_profile.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index f48acaea10d9d3..225171766e12fc 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof #include "xprof/utils/cost_utils.h" // from @org_xprof #include "xprof/utils/gpu_event_stats.h" // from @org_xprof #include "xprof/utils/hlo_module_map.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h index 39fea9a2ef786b..fbc52d3e27a5b7 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof #include "xprof/utils/op_utils.h" // from @org_xprof namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc index ee05ed31341025..2ed6a52c49947e 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/tsl/profiler/utils/math_utils.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof #include "xprof/utils/hlo_cost_analysis_wrapper.h" // from @org_xprof #include "xprof/utils/hlo_module_map.h" // from @org_xprof #include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 73f35d7ae0dc65..f01e8a833d64ba 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -50,6 +50,11 @@ limitations under the License. 
#include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/hardware_types.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/kernel_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h" // from @org_xprof #include "xprof/utils/device_caps_utils.h" // from @org_xprof #include "xprof/utils/event_span.h" // from @org_xprof #include "xprof/utils/gpu_event_stats.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.h b/tensorflow/core/profiler/convert/xplane_to_op_stats.h index cd180e7c8dcd0e..7a2f103c714420 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc index c1a310e0127165..c62b36428c9a28 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc @@ -45,6 +45,9 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc index 251dc6b72a9150..104ed52a3eb02b 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events.cc @@ -38,6 +38,8 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof #include "xprof/utils/event_span.h" // from @org_xprof #include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc index fafdcc386c1295..1f0009c44ed4eb 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc @@ -36,6 +36,8 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/tf_data_stats.pb.h" // from @org_xprof #include "xprof/utils/html_utils.h" // from @org_xprof namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h index f5f53488791942..1727dcadefa7dd 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/tf_data_stats.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc index 64f1f68fe3226e..6a7eb75194c73e 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/tf_data_stats.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc index 1a61c032d442d5..fc44baf39d8406 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/platform/protobuf.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h b/tensorflow/core/profiler/convert/xplane_to_tf_functions.h index fbff7ccecc72d2..54c7d127d36b8d 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h +++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/tf_function.pb.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" +#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc index e77883c847e53c..c1764961aa948d 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc b/tensorflow/core/profiler/convert/xplane_to_tool_names.cc index df3ccc129e5922..b70573db79458a 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tool_names.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/profiler/convert/xplane_to_hlo.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names.h b/tensorflow/core/profiler/convert/xplane_to_tool_names.h index a1e936940d2b91..3a23604ee7fcfd 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names.h +++ b/tensorflow/core/profiler/convert/xplane_to_tool_names.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/profiler/convert/repository.h" diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc index 885c729af15b50..b743e25c586d2f 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc @@ -72,6 +72,14 @@ limitations under the License. #include "tsl/profiler/protobuf/xplane.pb.h" #include "xprof/convert/trace_viewer/trace_events_to_json.h" // from @org_xprof #include "xprof/convert/trace_viewer/trace_viewer_visibility.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/hlo_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/inference_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_profile.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/overview_page.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/tf_data_stats.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/trace_events_raw.pb.h" // from @org_xprof #include "xprof/utils/hardware_type_utils.h" // from @org_xprof namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.h b/tensorflow/core/profiler/convert/xplane_to_tools_data.h index 8a40e03a7cd1dd..49ab1ea588f41d 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tools_data.h +++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/profiler/convert/repository.h" diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc index 27aaa7af86d039..96fdf62e77129b 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc +++ b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc @@ -32,8 +32,11 @@ limitations under the License. #include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" #include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" #include "xprof/convert/trace_viewer/trace_event_arguments_builder.h" // from @org_xprof #include "xprof/convert/trace_viewer/trace_events_util.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/trace_events.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/trace_events_raw.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.h b/tensorflow/core/profiler/convert/xplane_to_trace_container.h index 644848460661e6..157c16aa6c38fa 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.h +++ b/tensorflow/core/profiler/convert/xplane_to_trace_container.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "xprof/convert/trace_viewer/trace_events.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/trace_events_raw.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc index 821582610fd6ec..86ba79a34870e1 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc @@ -26,6 +26,9 @@ limitations under the License. #include "absl/strings/substitute.h" #include "xla/tsl/profiler/utils/math_utils.h" #include "tensorflow/core/util/proto/proto_utils.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/trace_events.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/trace_events_raw.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc index c59b06d6bc049f..27ccff245411e3 100644 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc +++ b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc @@ -49,6 +49,9 @@ limitations under the License. 
#include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tsl/platform/regexp.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/dcn_collective_info.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/topology.pb.h" // from @org_xprof #include "xprof/utils/hlo_proto_to_module.h" // from @org_xprof namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h index 388fe80d22d3b6..8f98c452e0eace 100644 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h +++ b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h @@ -33,6 +33,8 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/topology.pb.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof +#include "plugin/tensorboard_plugin_profile/protobuf/topology.pb.h" // from @org_xprof namespace tensorflow { namespace profiler { From f4eb58acd2ed1d0b099cdaf68b80bbe570e96d0b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Apr 2025 23:15:25 -0700 Subject: [PATCH 0787/1324] [XLA:GPU] Move collective topology initialization from pjrt to collectives. PiperOrigin-RevId: 747727239 --- .../xla/xla/backends/gpu/collectives/BUILD | 17 ++++- .../gpu/collectives/gpu_collectives.h | 17 +++++ .../gpu/collectives/gpu_collectives_stub.h | 4 + .../gpu/collectives/nccl_collectives.cc | 76 +++++++++++++++++++ .../gpu/collectives/nccl_collectives.h | 2 + .../xla/xla/backends/gpu/runtime/BUILD | 6 +- third_party/xla/xla/pjrt/gpu/BUILD | 28 +------ third_party/xla/xla/pjrt/gpu/nccl_id_store.cc | 69 ----------------- third_party/xla/xla/pjrt/gpu/nccl_id_store.h | 59 -------------- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 26 ++++--- 10 files changed, 133 insertions(+), 171 deletions(-) delete mode 100644 third_party/xla/xla/pjrt/gpu/nccl_id_store.cc delete mode 100644 third_party/xla/xla/pjrt/gpu/nccl_id_store.h diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index 3595fe88020edf..2f8c015eea9676 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -1,7 +1,7 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("//xla:xla.default.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") -load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_google", "if_nccl", "internal_visibility") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") @@ -24,8 +24,7 @@ cc_library( name = "gpu_collectives_plugin", deps = [ ":gpu_collectives_stub", - ":nccl_collectives", - ], + ] + if_nccl([":nccl_collectives"]), ) cc_library( @@ -128,6 +127,7 @@ cc_library( srcs = ["gpu_collectives.cc"], hdrs = ["gpu_collectives.h"], deps = [ + "//xla:executable_run_options", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", @@ -136,9 +136,12 @@ cc_library( "//xla/core/collectives:clique_key", "//xla/core/collectives:collectives_registry", 
"//xla/core/collectives:communicator", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -189,9 +192,11 @@ cc_library( ]), visibility = ["//visibility:private"], deps = [ + ":gpu_clique_key", ":gpu_collectives", ":nccl_communicator", ":nccl_errors", + "//xla:debug_options_flags", "//xla:status_macros", "//xla:util", "//xla/core/collectives", @@ -200,15 +205,21 @@ cc_library( "//xla/core/collectives:collectives_registry", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:global_device_id", + "//xla/service/gpu:gpu_executable_run_options", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", ] + if_cuda_is_configured([ diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h index fbd88b28063585..556699506336b0 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h @@ -19,13 +19,18 @@ limitations under the License. #include #include #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/collectives.h" #include "xla/core/collectives/communicator.h" +#include "xla/executable_run_options.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/service/global_device_id.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -108,6 +113,18 @@ class GpuCollectives : public Collectives { virtual absl::StatusOr Allocate(uint64_t bytes) = 0; virtual absl::Status Deallocate(void* buffer) = 0; + + struct Topology { + int32_t node_id; + int32_t num_nodes; + size_t device_count_per_process; + std::shared_ptr kv_store; + absl::flat_hash_map device_id_to_node_id; + gpu::GpuExecutableRunOptions* gpu_executable_run_options; + }; + + // Initializes the topology information for the collectives backend. 
+ virtual absl::Status InitializeTopology(Topology topology) = 0; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h index f217f4ccd3621d..11b50eed094073 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h @@ -70,6 +70,10 @@ class GpuCollectivesStub : public GpuCollectives { absl::Status Deallocate(void* buffer) final { return UnimplementedError(); } + absl::Status InitializeTopology(Topology topology) final { + return UnimplementedError(); + } + protected: static absl::Status UnimplementedError() { return Unimplemented("XLA compiled without GPU collectives support"); diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc index bc84fbef3c62d0..d54c91f00b0e8d 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc @@ -20,16 +20,21 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/collectives/nccl_communicator.h" #include "xla/backends/gpu/collectives/nccl_errors.h" @@ -39,6 +44,9 @@ limitations under the License. #include "xla/core/collectives/collectives_registry.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/status_macros.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" @@ -254,6 +262,74 @@ absl::Status NcclCollectives::Deallocate(void* location) { VLOG(2) << "Deallocated collective memory " << location; return absl::OkStatus(); } + +class NcclIdStore { + public: + NcclIdStore(int node_id, + absl::flat_hash_map device_to_node, + std::shared_ptr kv_store) + : node_id_(node_id), + device_to_node_(std::move(device_to_node)), + kv_store_(std::move(kv_store)) {} + + absl::StatusOr GetNcclUniqueId(const CliqueKey& key) { + auto* gpu_key = tsl::down_cast(&key); + if (gpu_key == nullptr) { + return InvalidArgument("Expected GPU clique key"); + } + + // The caller must ensure that threads calling this method concurrently have + // unique keys, otherwise the global key-value store may hold the wrong + // value. 
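The GetNcclUniqueId body that continues in the hunk below is a double-checked cache in front of a leader/follower exchange through the distributed key-value store: the first participant for a clique creates the id and publishes it, everyone else blocks until it appears. The same pattern, distilled into a minimal standalone sketch; KeyValueStore, IdRendezvous, and the placeholder id generator are hypothetical stand-ins, not the real XLA or NCCL API:

#include <map>
#include <mutex>
#include <string>
#include <utility>

// Hypothetical stand-in for the distributed key-value store interface.
class KeyValueStore {
 public:
  virtual ~KeyValueStore() = default;
  virtual void Set(const std::string& key, const std::string& value) = 0;
  // Blocks until some process has published `key`.
  virtual std::string BlockingGet(const std::string& key) = 0;
};

class IdRendezvous {
 public:
  IdRendezvous(int node_id, KeyValueStore* kv) : node_id_(node_id), kv_(kv) {}

  std::string GetId(const std::string& key, int leader_node_id) {
    {
      std::lock_guard<std::mutex> lock(mu_);  // fast path: already cached
      auto it = cache_.find(key);
      if (it != cache_.end()) return it->second;
    }
    std::string id;
    if (node_id_ == leader_node_id) {
      id = "fresh-id-for-" + key;  // placeholder for ncclGetUniqueId et al.
      kv_->Set(key, id);           // publish so follower nodes can fetch it
    } else {
      id = kv_->BlockingGet(key);  // wait for the leader to publish
    }
    std::lock_guard<std::mutex> lock(mu_);
    // emplace deduplicates racing writers: every caller ends up returning
    // whichever id landed in the cache first.
    return cache_.emplace(key, std::move(id)).first->second;
  }

 private:
  const int node_id_;
  KeyValueStore* const kv_;
  std::mutex mu_;
  std::map<std::string, std::string> cache_;
};

The first lock only serves the fast path; holding no lock during the key-value round trip is what makes the "callers must use unique keys" contract in the comment above necessary.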
+ { + absl::MutexLock lock(&mu_); + auto it = cache_.find(*gpu_key); + if (it != cache_.end()) { + return it->second; + } + } + CliqueId clique_id; + int primary_node_id = device_to_node_.at(gpu_key->root_device()); + if (node_id_ == primary_node_id) { + TF_ASSIGN_OR_RETURN( + clique_id, gpu::GpuCollectives::Default()->CreateUniqueCliqueId()); + TF_RETURN_IF_ERROR( + kv_store_->Set(gpu_key->ToString(), clique_id.ToString())); + } else { + TF_ASSIGN_OR_RETURN( + std::string id_str, + kv_store_->Get(gpu_key->ToString(), absl::Minutes(10))); + clique_id = CliqueId(id_str); + } + absl::MutexLock lock(&mu_); + auto result = cache_.emplace(*gpu_key, std::move(clique_id)); + TF_RET_CHECK(result.second) << "Unique ID already in cache."; + return result.first->second; + } + + private: + const int node_id_; + const absl::flat_hash_map device_to_node_; + const std::shared_ptr kv_store_; + + absl::Mutex mu_; + absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); +}; + +absl::Status NcclCollectives::InitializeTopology( + NcclCollectives::Topology topology) { + if (topology.num_nodes > 1) { + auto nccl_id_store = std::make_shared( + topology.node_id, topology.device_id_to_node_id, + std::move(topology.kv_store)); + topology.gpu_executable_run_options->set_clique_id_callback( + [nccl_id_store](const CliqueKey& key) { + return nccl_id_store->GetNcclUniqueId(key); + }); + } + return absl::OkStatus(); +} + } // namespace xla::gpu XLA_COLLECTIVES_REGISTER("gpu", "nccl", 1, diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h index 36860f1360f00b..b89a9360728f18 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h @@ -61,6 +61,8 @@ class NcclCollectives : public GpuCollectives { absl::StatusOr Allocate(uint64_t bytes) final; absl::Status Deallocate(void* location) final; + + absl::Status InitializeTopology(Topology topology) final; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index a16c0bf2495891..ea7cb55a8059f5 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -949,11 +949,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@llvm-project//mlir:IR", - ] + if_cuda_is_configured([ - "@local_config_nccl//:nccl", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rccl", - ]), + ], ) cc_library( diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 06efc16f215a44..ee0363a8f51e7e 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -66,8 +66,11 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_proto_cc", + "//xla/backends/gpu/collectives:gpu_collectives", "//xla/client:client_library", "//xla/client:local_client", + "//xla/core/collectives", + "//xla/core/collectives:collectives_registry", "//xla/hlo/builder:xla_computation", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:event_pool", @@ -143,7 +146,6 @@ cc_library( "@local_tsl//tsl/profiler/lib:traceme", ] + if_cuda_or_rocm([ # keep sorted - ":nccl_id_store", "//xla:debug_options_flags", "//xla/service/gpu:gpu_compiler", "//xla/service/gpu:gpu_constants", @@ -235,30 +237,6 @@ xla_test( ], ) -cc_library( - name = "nccl_id_store", - srcs = ["nccl_id_store.cc"], - hdrs = ["nccl_id_store.h"], - deps 
= [ - "//xla:status_macros", - "//xla:util", - "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_collectives", - "//xla/core/collectives:clique_id", - "//xla/core/collectives:clique_key", - "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:global_device_id", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - xla_test( name = "pjrt_client_test_se_gpu", srcs = ["pjrt_client_test_se_gpu.cc"], diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc b/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc deleted file mode 100644 index a2a72856e6d9f3..00000000000000 --- a/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/pjrt/gpu/nccl_id_store.h" - -#include -#include - -#include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/time.h" -#include "xla/backends/gpu/collectives/gpu_clique_key.h" -#include "xla/backends/gpu/collectives/gpu_collectives.h" -#include "xla/core/collectives/clique_id.h" -#include "xla/core/collectives/clique_key.h" -#include "xla/status_macros.h" -#include "xla/util.h" -#include "tsl/platform/casts.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace xla { - -absl::StatusOr NcclIdStore::GetNcclUniqueId(const CliqueKey& key) { - auto* gpu_key = tsl::down_cast(&key); - if (gpu_key == nullptr) { - return InvalidArgument("Expected GPU clique key"); - } - - // The caller must ensure that threads calling this method concurrently have - // unique keys, otherwise the global key-value store may hold the wrong value. 
- { - absl::MutexLock lock(&mu_); - auto it = cache_.find(*gpu_key); - if (it != cache_.end()) { - return it->second; - } - } - CliqueId clique_id; - int primary_node_id = device_to_node_.at(gpu_key->root_device()); - if (node_id_ == primary_node_id) { - TF_ASSIGN_OR_RETURN(clique_id, - gpu::GpuCollectives::Default()->CreateUniqueCliqueId()); - TF_RETURN_IF_ERROR( - kv_store_->Set(gpu_key->ToString(), clique_id.ToString())); - } else { - TF_ASSIGN_OR_RETURN(std::string id_str, - kv_store_->Get(gpu_key->ToString(), absl::Minutes(10))); - clique_id = CliqueId(id_str); - } - absl::MutexLock lock(&mu_); - auto result = cache_.emplace(*gpu_key, std::move(clique_id)); - TF_RET_CHECK(result.second) << "Unique ID already in cache."; - return result.first->second; -} - -} // namespace xla diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.h b/third_party/xla/xla/pjrt/gpu/nccl_id_store.h deleted file mode 100644 index fe8b060cb946a7..00000000000000 --- a/third_party/xla/xla/pjrt/gpu/nccl_id_store.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PJRT_GPU_NCCL_ID_STORE_H_ -#define XLA_PJRT_GPU_NCCL_ID_STORE_H_ - -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" -#include "xla/backends/gpu/collectives/gpu_clique_key.h" -#include "xla/core/collectives/clique_id.h" -#include "xla/core/collectives/clique_key.h" -#include "xla/pjrt/distributed/key_value_store_interface.h" -#include "xla/service/global_device_id.h" - -namespace xla { - -// A table mapping GpuCliqueKeys to CliqueIds. In a distributed setup the -// table of NCCL IDs is kept on the master node (node 0). The node of the first -// participating device will create the unique id. -class NcclIdStore { - public: - NcclIdStore(int node_id, - absl::flat_hash_map device_to_node, - std::shared_ptr kv_store) - : node_id_(node_id), - device_to_node_(std::move(device_to_node)), - kv_store_(std::move(kv_store)) {} - - absl::StatusOr GetNcclUniqueId(const CliqueKey& key); - - private: - const int node_id_; - const absl::flat_hash_map device_to_node_; - const std::shared_ptr kv_store_; - - absl::Mutex mu_; - absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); -}; - -} // namespace xla - -#endif // XLA_PJRT_GPU_NCCL_ID_STORE_H_ diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index d9a0e2df827359..8ae9b7210d6534 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -46,7 +46,10 @@ limitations under the License. 
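The se_gpu_pjrt_client.cc hunk that follows replaces the old call-site wiring, previously guarded by GOOGLE_CUDA/TENSORFLOW_USE_ROCM, with a registry lookup plus a single InitializeTopology call. Roughly, the new call site reduces to the sketch below; the template arguments on the device map and on down_cast are assumptions, since this copy of the diff elides angle-bracket contents:

// Sketch of the caller-side wiring landed in se_gpu_pjrt_client.cc below;
// assumes the headers this patch adds.
absl::Status InitializeGpuTopology(
    int32_t node_id, int32_t num_nodes, size_t devices_per_process,
    std::shared_ptr<xla::KeyValueStoreInterface> kv_store,
    absl::flat_hash_map<xla::GlobalDeviceId, int32_t> device_id_to_node_id,
    xla::gpu::GpuExecutableRunOptions* run_options) {
  // Resolve whichever "gpu" collectives implementation is registered
  // (NCCL when built with it, otherwise the stub).
  TF_ASSIGN_OR_RETURN(xla::Collectives * collectives,
                      xla::CollectivesRegistry::Default("gpu"));
  auto* gpu_collectives =
      tsl::down_cast<xla::gpu::GpuCollectives*>(collectives);
  if (gpu_collectives == nullptr) {
    return absl::InternalError("Failed to get GPU collectives");
  }
  // One call hands the distributed topology to the backend; the NCCL
  // implementation installs the clique-id callback shown earlier.
  return gpu_collectives->InitializeTopology(
      {node_id, num_nodes, devices_per_process, std::move(kv_store),
       std::move(device_id_to_node_id), run_options});
}

Keeping the NCCL-specific rendezvous behind the GpuCollectives interface lets the PJRT client drop its preprocessor guards; the stub backend simply reports Unimplemented from InitializeTopology.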
#include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/client/local_client.h" +#include "xla/core/collectives/collectives.h" +#include "xla/core/collectives/collectives_registry.h" #include "xla/executable_run_options.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" @@ -106,7 +109,6 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/gpu/gpu_metrics.h" -#include "xla/pjrt/gpu/nccl_id_store.h" #include "xla/pjrt/stream_executor_executable.pb.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_constants.h" @@ -1209,16 +1211,20 @@ absl::StatusOr BuildDistributedDevices( } gpu_executable_run_options->set_gpu_global_device_ids( std::move(gpu_device_ids)); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (num_nodes > 1) { - auto nccl_id_store = std::make_shared(node_id, device_to_node, - std::move(kv_store)); - gpu_executable_run_options->set_clique_id_callback( - [nccl_id_store](const CliqueKey& key) { - return nccl_id_store->GetNcclUniqueId(key); - }); + + TF_ASSIGN_OR_RETURN(xla::Collectives * collectives, + xla::CollectivesRegistry::Default("gpu")); + xla::gpu::GpuCollectives* gpu_collectives = + tsl::down_cast(collectives); + + if (gpu_collectives == nullptr) { + return absl::InternalError("Failed to get GPU collectives"); } -#endif // GOOGLE_CUDA + + TF_RETURN_IF_ERROR(gpu_collectives->InitializeTopology( + {node_id, global_topology.nodes().size(), local_device_states.size(), + kv_store, device_to_node, gpu_executable_run_options})); + TF_ASSIGN_OR_RETURN(GpuTopologyProto gpu_topology, BuildGpuTopology(global_topology)); return std::make_pair(std::move(devices), gpu_topology); From 20002a4ac2fa24aa77b7df989238b012175f7cc1 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 23:51:29 -0700 Subject: [PATCH 0788/1324] Automated Code Change PiperOrigin-RevId: 747736876 --- .../hlo/analysis/hlo_replication_analysis.cc | 2 +- .../hlo/analysis/indexed_array_analysis.cc | 30 +++++---- .../xla/xla/hlo/analysis/indexing_analysis.cc | 66 ++++++++++--------- 3 files changed, 51 insertions(+), 47 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc index e75030e8665b50..430de88bb2c6a6 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc @@ -266,7 +266,7 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( if (hlo->dynamic_slice_sizes().size() == 1 && hlo->dynamic_slice_sizes()[0] == 1 && ds_buffer->opcode() == HloOpcode::kConstant && - ds_buffer->shape().dimensions_size() == 1 && + ds_buffer->shape().dimensions().size() == 1 && ds_buffer->shape().element_type() == PrimitiveType::S32 && ((cross_partition_spmd && hlo->operand(1)->opcode() == HloOpcode::kPartitionId) || diff --git a/third_party/xla/xla/hlo/analysis/indexed_array_analysis.cc b/third_party/xla/xla/hlo/analysis/indexed_array_analysis.cc index 90e0c5ec4f7bc6..2cd08d3129bcda 100644 --- a/third_party/xla/xla/hlo/analysis/indexed_array_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/indexed_array_analysis.cc @@ -222,7 +222,7 @@ absl::StatusOr IndexedArrayAnalysis::FoldGatherOfGather( enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond }; - std::vector simulated_index(a->shape().dimensions_size(), + std::vector simulated_index(a->shape().dimensions().size(), IndexComponent::Ungathered); // Simulate the first gather. @@ -275,7 +275,7 @@ absl::StatusOr IndexedArrayAnalysis::FoldGatherOfGather( absl::StatusOr IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, absl::Span slice_sizes, Array* source, Array* indices) { - if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + if (dim_numbers.index_vector_dim() != indices->shape().dimensions().size()) { VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; } @@ -297,7 +297,7 @@ absl::StatusOr IndexedArrayAnalysis::ComputeArrayForGather( // dimensions -- for instance it cannot represent a gather that picks 5 [2,3] // arrays from an array of size [7,4,6]. 
We check that condition down below: - for (int64_t i = 0, e = source->shape().dimensions_size(); i < e; i++) { + for (int64_t i = 0, e = source->shape().dimensions().size(); i < e; i++) { if (i != dim_numbers.collapsed_slice_dims(0) && source->shape().dimensions(i) != slice_sizes[i]) { VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i @@ -311,7 +311,7 @@ absl::StatusOr IndexedArrayAnalysis::ComputeArrayForGather( int64_t source_dim = dim_numbers.start_index_map(0); std::vector output_dims; - for (int64_t i = 0, e = shape.dimensions_size(); i < e; i++) { + for (int64_t i = 0, e = shape.dimensions().size(); i < e; i++) { if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { output_dims.push_back(i); } @@ -498,7 +498,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( const Shape& source_shape = operand->source()->shape(); DimensionVector new_source_shape_dims; - for (int64_t i = 0, e = source_shape.dimensions_size(); i < e; i++) { + for (int64_t i = 0, e = source_shape.dimensions().size(); i < e; i++) { if (i == operand->source_dim() || source_shape.dimensions(i) != 1) { new_source_shape_dims.push_back(source_shape.dimensions(i)); } @@ -520,7 +520,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( // will no longer be present. DimensionVector new_output_dims; int64_t degenerate_dims_seen = 0; - for (int64_t i = 0, e = shape.dimensions_size(); i < e; i++) { + for (int64_t i = 0, e = shape.dimensions().size(); i < e; i++) { if (shape.dimensions(i) == 1) { degenerate_dims_seen++; } else if (absl::c_linear_search(operand->output_dims(), i)) { @@ -557,7 +557,7 @@ IndexedArrayAnalysis::ReshapeToAddDegenerateDims( // index. absl::InlinedVector output_dims_bitvector( - operand->shape().dimensions_size()); + operand->shape().dimensions().size()); for (int64_t output_dim : operand->output_dims()) { output_dims_bitvector[output_dim] = true; } @@ -644,7 +644,7 @@ absl::StatusOr IndexedArrayAnalysis::FoldReshapeOfGather( } DimensionVector degenerate_result_dims; - for (int64_t i = 0, e = shape.dimensions_size(); i < e; i++) { + for (int64_t i = 0, e = shape.dimensions().size(); i < e; i++) { if (shape.dimensions(i) == 1) { degenerate_result_dims.push_back(i); } @@ -911,7 +911,8 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // existing broadcast: enum class IndexComponent { Broadcasted, NotBroadcasted }; std::vector simulated_index( - broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted); + broadcast_instr->shape().dimensions().size(), + IndexComponent::Broadcasted); for (int64_t broadcast_dim : broadcast_dims) { simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted; } @@ -1023,8 +1024,9 @@ bool CanFoldDotIntoIndexedArray( absl::Span contracting_dims, absl::Span batch_dims) { std::optional non_contracting_non_batch_dim = - GetOnlyNonContractingNonBatchDim(indexed_array->shape().dimensions_size(), - contracting_dims, batch_dims); + GetOnlyNonContractingNonBatchDim( + indexed_array->shape().dimensions().size(), contracting_dims, + batch_dims); if (!non_contracting_non_batch_dim.has_value()) { VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions"; return false; @@ -1036,7 +1038,7 @@ bool CanFoldDotIntoIndexedArray( return false; } - int64_t indexed_array_rank = indexed_array->shape().dimensions_size(); + int64_t indexed_array_rank = indexed_array->shape().dimensions().size(); if (indexed_array->source_dim() < (indexed_array_rank - 2)) { // This restriction can be lifted by inserting reshape nodes. 
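The hunks in this automated change are one mechanical migration: rank queries move from the shape's dimensions_size() accessor to the size of the span returned by dimensions(). A minimal before/after sketch with a toy Shape; the real xla::Shape::dimensions() returns an absl::Span<const int64_t>, but a vector reference keeps the sketch standard-library only:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

class Shape {  // toy stand-in for xla::Shape, only the accessors at issue
 public:
  explicit Shape(std::vector<int64_t> dims) : dims_(std::move(dims)) {}
  const std::vector<int64_t>& dimensions() const { return dims_; }
  // Legacy accessor these hunks migrate away from.
  int64_t dimensions_size() const {
    return static_cast<int64_t>(dims_.size());
  }

 private:
  std::vector<int64_t> dims_;
};

int main() {
  Shape s({2, 1, 2});
  int64_t rank_before = s.dimensions_size();   // old spelling
  int64_t rank_after = s.dimensions().size();  // new spelling
  std::cout << rank_before << " == " << rank_after << "\n";  // both are 3
}

Both spellings return the rank; routing everything through the span lets the size-only accessor be retired.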
VLOG(3) << tag @@ -1064,7 +1066,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( return nullptr; } - int64_t lhs_rank = lhs->shape().dimensions_size(); + int64_t lhs_rank = lhs->shape().dimensions().size(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); @@ -1099,7 +1101,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( return nullptr; } - int64_t rhs_rank = rhs->shape().dimensions_size(); + int64_t rhs_rank = rhs->shape().dimensions().size(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_rhs_contracting_dimensions( diff --git a/third_party/xla/xla/hlo/analysis/indexing_analysis.cc b/third_party/xla/xla/hlo/analysis/indexing_analysis.cc index eff320f1819115..e5c543019ca585 100644 --- a/third_party/xla/xla/hlo/analysis/indexing_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/indexing_analysis.cc @@ -290,8 +290,8 @@ HloInstructionIndexing ComputeOutputToInputCwiseOpIndexing( // Select allows implicit broadcasting in the predicate. We just handle it // generically here. auto* operand = instr->operand(operand_id); - if (operand->shape().dimensions_size() == 0 && - instr->shape().dimensions_size() > 0) { + if (operand->shape().dimensions().size() == 0 && + instr->shape().dimensions().size() > 0) { instr_indexing.indexing_maps[operand_id].insert(unit_map); } else { instr_indexing.indexing_maps[operand_id].insert(identity_map); @@ -331,7 +331,7 @@ HloInstructionIndexing ComputeInputToOutputBroadcastOpIndexing( std::vector added_dims_sizes; std::vector exprs; - exprs.reserve(output_shape.dimensions_size()); + exprs.reserve(output_shape.dimensions().size()); for (auto [output_dim_id, output_dim] : llvm::enumerate(output_shape.dimensions())) { auto bcast_dim = @@ -346,7 +346,7 @@ HloInstructionIndexing ComputeInputToOutputBroadcastOpIndexing( std::distance(bcast_dims.begin(), bcast_dim), mlir_context)); } IndexingMap indexing_map = IndexingMap::FromTensorSizes( - AffineMap::get(input_shape.dimensions_size(), added_dims_sizes.size(), + AffineMap::get(input_shape.dimensions().size(), added_dims_sizes.size(), exprs, mlir_context), input_shape.dimensions(), added_dims_sizes); @@ -438,8 +438,8 @@ HloInstructionIndexing ComputeOutputToInputDotOpIndexing( // According to the StableHLO specification, the dimensions of the output // shape are ordered as follows: // lhs_batch_dims | lhs_non_contracting_dims | rhs_non_contracting_dims - SmallVector lhs_exprs(lhs_shape.dimensions_size()); - SmallVector rhs_exprs(rhs_shape.dimensions_size()); + SmallVector lhs_exprs(lhs_shape.dimensions().size()); + SmallVector rhs_exprs(rhs_shape.dimensions().size()); int64_t output_dim_id = 0; // lhs_batch_dims @@ -487,12 +487,12 @@ HloInstructionIndexing ComputeOutputToInputDotOpIndexing( } IndexingMap lhs_indexing_map = IndexingMap::FromTensorSizes( - AffineMap::get(dot->shape().dimensions_size(), input_dim_sizes.size(), + AffineMap::get(dot->shape().dimensions().size(), input_dim_sizes.size(), lhs_exprs, mlir_context), dot->shape().dimensions(), input_dim_sizes); IndexingMap rhs_indexing_map = IndexingMap::FromTensorSizes( - AffineMap::get(dot->shape().dimensions_size(), input_dim_sizes.size(), + AffineMap::get(dot->shape().dimensions().size(), input_dim_sizes.size(), rhs_exprs, mlir_context), dot->shape().dimensions(), input_dim_sizes); return HloInstructionIndexing::FromIndexingMaps( @@ -504,10 +504,11 @@ HloInstructionIndexing 
ComputeOutputToInputDynamicSliceOpIndexing( MLIRContext* mlir_context) { const Shape& input_shape = dynamic_slice->operand(0)->shape(); const Shape& output_shape = dynamic_slice->shape(); - int64_t rank = output_shape.dimensions_size(); + int64_t rank = output_shape.dimensions().size(); const int64_t first_index_num = dynamic_slice->first_index_operand_number(); - CHECK(dynamic_slice->operand(first_index_num)->shape().dimensions_size() == 0) + CHECK(dynamic_slice->operand(first_index_num)->shape().dimensions().size() == + 0) << "b/118437727: Old form, not supported."; // A map from tensor iteration space to (), because index operands are 0d // tensors. @@ -542,7 +543,7 @@ HloInstructionIndexing ComputeOutputToInputDynamicUpdateSliceOpIndexing( const HloDynamicUpdateSliceInstruction* dus, MLIRContext* mlir_context) { const Shape& update_shape = dus->update()->shape(); const Shape& output_shape = dus->shape(); - int64_t rank = output_shape.dimensions_size(); + int64_t rank = output_shape.dimensions().size(); // operand: (d0, ... d_{N-1}) -> (d0, ... d_{N-1}) std::vector identity; @@ -598,7 +599,7 @@ HloInstructionIndexing ComputeOutputToInputGatherOpIndexing( indices_shape.dimensions(dimension_numbers.index_vector_dim()); const Shape& output_shape = gather->shape(); - int64_t output_rank = output_shape.dimensions_size(); + int64_t output_rank = output_shape.dimensions().size(); // A map for the `indices` operand of gather. It is always // (d_0, ... d_{rank - 1}) -> (d_0, s_0), @@ -619,7 +620,7 @@ HloInstructionIndexing ComputeOutputToInputGatherOpIndexing( // where s_i are RTVars that extract indices from the `indices` operand. std::vector rt_vars; std::vector exprs; - exprs.reserve(operand_shape.dimensions_size()); + exprs.reserve(operand_shape.dimensions().size()); for (auto [operand_dim_id, slice_size] : llvm::enumerate(gather->gather_slice_sizes())) { int64_t output_dim_id = dimension_numbers.offset_dims(operand_dim_id); @@ -682,7 +683,7 @@ IndexingMap ComputeOutputToInputPadOpIndexingImpl( HloInstructionIndexing ComputeOutputToInputPadOpIndexing( const HloPadInstruction* pad, MLIRContext* mlir_context) { const Shape& output_shape = pad->shape(); - int64_t rank = output_shape.dimensions_size(); + int64_t rank = output_shape.dimensions().size(); SmallVector padding_low, padding_high, padding_interior; padding_low.reserve(rank); padding_high.reserve(rank); @@ -696,7 +697,7 @@ HloInstructionIndexing ComputeOutputToInputPadOpIndexing( output_shape.dimensions(), padding_low, padding_high, padding_interior, mlir_context); IndexingMap padding_value_indexing_map = IndexingMap::FromTensorSizes( - AffineMap::get(output_shape.dimensions_size(), /*symbolCount=*/0, {}, + AffineMap::get(output_shape.dimensions().size(), /*symbolCount=*/0, {}, mlir_context), output_shape.dimensions(), /*symbol_upper_bounds=*/{}); return HloInstructionIndexing::FromIndexingMaps( @@ -714,7 +715,7 @@ HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( std::vector parallel_dims_sizes; int64_t output_dim_id = 0; std::vector exprs; - exprs.reserve(input_shape.dimensions_size()); + exprs.reserve(input_shape.dimensions().size()); for (auto [input_dim_id, input_dim] : llvm::enumerate(input_shape.dimensions())) { if (reduce_dims_ids.contains(input_dim_id)) { @@ -726,11 +727,11 @@ HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( exprs.push_back(getAffineDimExpr(output_dim_id++, mlir_context)); } IndexingMap inputs_indexing_map = IndexingMap::FromTensorSizes( - AffineMap::get(output_shape.dimensions_size(), 
reduce_dims_ids.size(), + AffineMap::get(output_shape.dimensions().size(), reduce_dims_ids.size(), exprs, mlir_context), output_shape.dimensions(), parallel_dims_sizes); IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( - AffineMap::get(output_shape.dimensions_size(), /*symbolCount=*/0, {}, + AffineMap::get(output_shape.dimensions().size(), /*symbolCount=*/0, {}, mlir_context), output_shape.dimensions(), {}); @@ -749,7 +750,7 @@ HloInstructionIndexing ComputeInputToOutputReduceOpIndexing( const HloReduceInstruction* reduce, int input_id, MLIRContext* mlir_context) { const Shape& output_shape = GetOutputShape(reduce, 0); - int64_t output_rank = output_shape.dimensions_size(); + int64_t output_rank = output_shape.dimensions().size(); HloInstructionIndexing instr_indexing; int arity = reduce->input_count(); @@ -783,7 +784,7 @@ HloInstructionIndexing ComputeInputToOutputReduceOpIndexing( } } IndexingMap inputs_indexing_map = IndexingMap::FromTensorSizes( - AffineMap::get(input_shape.dimensions_size(), /*symbolCount=*/0, + AffineMap::get(input_shape.dimensions().size(), /*symbolCount=*/0, inputs_exprs, mlir_context), input_shape.dimensions(), {}); for (int64_t id = 0; id < arity; ++id) { @@ -865,7 +866,7 @@ HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing( // Indexing map for the init value. IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( - AffineMap::get(output_shape.dimensions_size(), /*symbolCount=*/0, {}, + AffineMap::get(output_shape.dimensions().size(), /*symbolCount=*/0, {}, mlir_context), output_shape.dimensions(), /*symbol_upper_bounds=*/{}); @@ -888,7 +889,7 @@ HloInstructionIndexing ComputeOutputToInputConvolutionOpIndexing( const Shape& output_shape = convolution->shape(); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - size_t rank = output_shape.dimensions_size(); + size_t rank = output_shape.dimensions().size(); // Collect sizes for input/output spatial dimensions. size_t spatial_rank = rank - 2; @@ -1105,12 +1106,12 @@ AffineMap ComputeReshapeIndexingMap(const Shape& input, const Shape& output, absl::Span output_dims = output.dimensions(); std::vector exprs; - exprs.reserve(input.dimensions_size()); + exprs.reserve(input.dimensions().size()); // If the input shape has no elements (e.g. 1000x10x0 -> 100x100x0), just set // everything to 0. 
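A quick sanity check on the zero-element special case described in the comment above, which the ElementsIn test just below implements: a shape with any zero dimension has an empty iteration space, so the all-zeros map is never evaluated at a concrete index and is vacuously correct. Sketch, with a local ElementsIn standing in for ShapeUtil::ElementsIn:

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Local stand-in for ShapeUtil::ElementsIn: product of the dimensions.
int64_t ElementsIn(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         [](int64_t a, int64_t b) { return a * b; });
}

int main() {
  // The 1000x10x0 -> 100x100x0 reshape from the comment: both sides are
  // empty, so any well-formed affine map relates them correctly.
  assert(ElementsIn({1000, 10, 0}) == 0);
  assert(ElementsIn({100, 100, 0}) == 0);
}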
if (ShapeUtil::ElementsIn(input) == 0) { - for (int i = 0; i < input.dimensions_size(); ++i) { + for (int i = 0; i < input.dimensions().size(); ++i) { exprs.push_back(getAffineConstantExpr(0, mlir_context)); } return AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, @@ -1124,9 +1125,10 @@ AffineMap ComputeReshapeIndexingMap(const Shape& input, const Shape& output, int64_t output_num_elements = 1; std::vector input_subshape, output_subshape; size_t input_dim_id = 0, output_dim_id = 0; - while (input_dim_id < input.dimensions_size() || - output_dim_id < output.dimensions_size() || !input_subshape.empty()) { - if (input_dim_id < input.dimensions_size() && + while (input_dim_id < input.dimensions().size() || + output_dim_id < output.dimensions().size() || + !input_subshape.empty()) { + if (input_dim_id < input.dimensions().size() && (input_subshape.empty() || input_num_elements < output_num_elements || input_dims[input_dim_id] == 1)) { input_num_elements *= input_dims[input_dim_id]; @@ -1134,7 +1136,7 @@ AffineMap ComputeReshapeIndexingMap(const Shape& input, const Shape& output, ++input_dim_id; continue; } - if (output_dim_id < output.dimensions_size() && + if (output_dim_id < output.dimensions().size() && (output_subshape.empty() || output_num_elements < input_num_elements || output_dims[output_dim_id] == 1)) { output_num_elements *= output_dims[output_dim_id]; @@ -1206,7 +1208,7 @@ HloInstructionIndexing ComputeReverseOpIndexing( HloInstructionIndexing ComputeOutputToInputSliceOpIndexing( const HloSliceInstruction* slice, MLIRContext* mlir_context) { - auto output_rank = slice->shape().dimensions_size(); + auto output_rank = slice->shape().dimensions().size(); std::vector exprs; exprs.reserve(output_rank); @@ -1223,7 +1225,7 @@ HloInstructionIndexing ComputeOutputToInputSliceOpIndexing( HloInstructionIndexing ComputeInputToOutputSliceOpIndexing( const HloSliceInstruction* slice, MLIRContext* mlir_context) { - auto output_rank = slice->shape().dimensions_size(); + auto output_rank = slice->shape().dimensions().size(); std::vector exprs; exprs.reserve(output_rank); @@ -1402,7 +1404,7 @@ llvm::SmallVector DelinearizeInBoundsIndex( IndexingMap GetIndexingMapFromPhysicalLayoutToLogical( const Shape& shape, MLIRContext* mlir_context) { - if (shape.dimensions_size() == 0) { + if (shape.dimensions().size() == 0) { return IndexingMap(AffineMap::get(mlir_context), /*dimensions=*/{}, /*range vars=*/{}, /*rt_vars=*/{}); } @@ -1417,7 +1419,7 @@ IndexingMap GetIndexingMapFromPhysicalLayoutToLogical( IndexingMap GetIndexingMapFromLogicalToPhysicalLayout( const Shape& shape, MLIRContext* mlir_context) { - if (shape.dimensions_size() == 0) { + if (shape.dimensions().size() == 0) { return IndexingMap(AffineMap::get(mlir_context), /*dimensions=*/{}, /*range vars=*/{}, /*rt_vars=*/{}); } From f321eb36fb34104174b6295921ef1e0238e16e0c Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 23:53:23 -0700 Subject: [PATCH 0789/1324] Automated Code Change PiperOrigin-RevId: 747737381 --- third_party/xla/xla/hlo/parser/hlo_parser.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/hlo/parser/hlo_parser.cc b/third_party/xla/xla/hlo/parser/hlo_parser.cc index 1019c02b772456..bdfe7b3729d62c 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser.cc @@ -2700,8 +2700,8 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT return nullptr; } if (!(operands.size() == 2 && - operands[1]->shape().dimensions_size() == 1) && - operands.size() != 1 + operands[0]->shape().dimensions_size()) { + operands[1]->shape().dimensions().size() == 1) && + operands.size() != 1 + operands[0]->shape().dimensions().size()) { TokenError("Wrong number of operands."); return nullptr; } @@ -2720,8 +2720,8 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT return nullptr; } if (!(operands.size() == 3 && - operands[2]->shape().dimensions_size() == 1) && - operands.size() != 2 + operands[0]->shape().dimensions_size()) { + operands[2]->shape().dimensions().size() == 1) && + operands.size() != 2 + operands[0]->shape().dimensions().size()) { TokenError("Wrong number of operands."); return nullptr; } @@ -4504,7 +4504,7 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) { // Cast `rank` to int because we call shape.dimensions(int rank) below, and if // `rank` is an int64_t, that's an implicit narrowing conversion, which is // implementation-defined behavior. - const int rank = static_cast(shape.dimensions_size()); + const int rank = static_cast(shape.dimensions().size()); // Create a literal with the given shape in default layout. *literal = LiteralUtil::CreateFromDimensions(shape.element_type(), @@ -6337,17 +6337,18 @@ bool HloParserImpl::ParseShape(Shape* result, return false; } if (layout.dim_level_types_size() != 0 && - layout.dim_level_types_size() != result->dimensions_size()) { + layout.dim_level_types_size() != result->dimensions().size()) { return Error( lexer_.GetLoc(), StrFormat("Dimensions size is %ld, but dim level types size is %ld.", - result->dimensions_size(), layout.dim_level_types_size())); + result->dimensions().size(), + layout.dim_level_types_size())); } - if (layout.minor_to_major_size() != result->dimensions_size()) { + if (layout.minor_to_major_size() != result->dimensions().size()) { return Error( lexer_.GetLoc(), StrFormat("Dimensions size is %ld, but minor to major size is %ld.", - result->dimensions_size(), layout.minor_to_major_size())); + result->dimensions().size(), layout.minor_to_major_size())); } if (LayoutUtil::IsSparse(layout) && layout.tiles_size() > 0) { return Error(lexer_.GetLoc(), From f0a678a8789a844da64263229bbdd9568da87632 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 14 Apr 2025 23:53:36 -0700 Subject: [PATCH 0790/1324] Automated Code Change PiperOrigin-RevId: 747737424 --- .../xla/xla/hlo/transforms/host_offload_legalize.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc b/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc index 5bd8c863c9a7ee..695c383c369bb2 100644 --- a/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc +++ b/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc @@ -485,8 +485,8 @@ absl::Status MoveCopyDown( "Expecting copy to only change instructions layout. Copy: %s", copy_to_move->ToString())); } - if (after_bitcast_shape.dimensions_size() == - before_bitcast_shape.dimensions_size() - 1) { + if (after_bitcast_shape.dimensions().size() == + before_bitcast_shape.dimensions().size() - 1) { if (!(ShapeUtil::IsEffectivelyMostMajorDimension(before_bitcast_shape, 0) && before_bitcast_shape.dimensions(0) == 1)) { @@ -509,8 +509,8 @@ absl::Status MoveCopyDown( " Also updating shape after copy from %s to %s", shape_after_copy.ToString(true), new_copy_shape.ToString(true)); shape_after_copy = new_copy_shape; - } else if (after_bitcast_shape.dimensions_size() == - before_bitcast_shape.dimensions_size() + 1) { + } else if (after_bitcast_shape.dimensions().size() == + before_bitcast_shape.dimensions().size() + 1) { if (!(ShapeUtil::IsEffectivelyMostMajorDimension(after_bitcast_shape, 0) && after_bitcast_shape.dimensions(0) == 1)) { From cb2256642ff149713cde7da6b0c5af89898a0a2f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 00:49:08 -0700 Subject: [PATCH 0791/1324] Performance improvements: * Compute hash of large vectors in HloReplication once (on construction) * Cache results of GroupsForReplicas Cleanups: * Move static functions to member functions, stop passing around class variables * Simplify/tidy up HloReplication caching scheme PiperOrigin-RevId: 747753245 --- .../hlo/analysis/hlo_replication_analysis.cc | 183 ++++++++---------- .../hlo/analysis/hlo_replication_analysis.h | 84 +++++--- 2 files changed, 139 insertions(+), 128 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc index 430de88bb2c6a6..df4204631d7354 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/hlo/analysis/hlo_replication_analysis.h" +#include + #include #include #include @@ -48,7 +50,7 @@ limitations under the License. #include "xla/xla_data.pb.h" namespace xla { -namespace { + // When cross_partition_spmd is true, returns the partition IDs of all // replica groups in which a given replica participates. Specfically, the k-th // element of the outermost vector in the returned data structure holds the @@ -60,28 +62,30 @@ namespace { // element of the outermost vector in the returned data structure holds the // replica IDs converted from the global IDs in a collective's replica_groups // field for partition k. -std::vector>> GroupsForReplicas( - absl::Span groups, int64_t num_partitions, - int64_t replica_count, bool cross_partition_spmd) { - int64_t num_replicas = cross_partition_spmd ? replica_count : num_partitions; + +std::vector>> +HloReplicationAnalysis::GroupsForReplicas( + absl::Span groups) { + int64_t num_replicas = + cross_partition_spmd_ ? 
replica_count_ : num_partitions_; std::vector>> groups_for_replicas( num_replicas); for (const ReplicaGroup& group : groups) { absl::flat_hash_map> id_to_ids; for (int64_t id : group.replica_ids()) { - int64_t rid = id / num_partitions; - int64_t pid = id % num_partitions; - if (cross_partition_spmd) { + int64_t rid = id / num_partitions_; + int64_t pid = id % num_partitions_; + if (cross_partition_spmd_) { CHECK_LT(rid, num_replicas) << "Got replica ID " << rid << " which is greater or equal to the number of replicas: " << num_replicas; id_to_ids[rid].push_back(pid); } else { - CHECK_LT(pid, num_partitions) + CHECK_LT(pid, num_partitions_) << "Got partition ID " << rid << " which is greater or equal to the number of partitions: " - << num_partitions; + << num_partitions_; id_to_ids[pid].push_back(rid); } } @@ -93,8 +97,6 @@ std::vector>> GroupsForReplicas( return groups_for_replicas; } -} // namespace - // Determines whether an HLO instruction is replicated at index based on current // knowledge in hlo_replication. When cross_partition_spmd is true, the // instruction must be replicated across all partitions on each replica. @@ -102,41 +104,30 @@ std::vector>> GroupsForReplicas( // replicated across all replicas on each partition. HloReplicationAnalysis::HloReplication HloReplicationAnalysis::DetermineHloInstructionIsReplicated( - const HloInstruction* hlo, const ShapeIndex& index, - bool cross_partition_spmd, - const absl::flat_hash_map>& - hlo_replication, - bool support_partial_replication, - const absl::flat_hash_map*>& - replica_group_dedup_map, - absl::flat_hash_map, - HloReplication>& replication_merge_map) { - const auto merge_operand_replication = - [&hlo_replication, &replication_merge_map](const HloInstruction* inst) { - HloReplication replication = HloReplication::ReplicatedOnAllDevices(); - for (auto operand : inst->operands()) { - auto operand_it = hlo_replication.find(operand); - if (operand_it == hlo_replication.end()) { - replication = MergeReplications( - replication, HloReplication::UniqueOnAllDevices(), - replication_merge_map); - } else { - replication = - MergeReplications(replication, operand_it->second.element({}), - replication_merge_map); - } - } - return replication; - }; + const HloInstruction* hlo, const ShapeIndex& index) { + const auto merge_operand_replication = [this](const HloInstruction* inst) { + HloReplication replication = HloReplication::ReplicatedOnAllDevices(); + for (auto operand : inst->operands()) { + auto operand_it = hlo_replication_.find(operand); + if (operand_it == hlo_replication_.end()) { + replication = MergeReplications(replication, + HloReplication::UniqueOnAllDevices()); + } else { + replication = + MergeReplications(replication, operand_it->second.element({})); + } + } + return replication; + }; - auto calculate_all_reduce_all_gather_replication = [&](const HloInstruction* + auto calculate_all_reduce_all_gather_replication = [this]( + const HloInstruction* hlo) { if (!hlo->channel_id().has_value()) { if (hlo->replica_groups().empty() || hlo->replica_groups().size() == 1) { return HloReplication::ReplicatedOnAllDevices(); } - if (support_partial_replication) { + if (support_partial_replication_) { std::vector>> device_sets_per_replica( 1); for (const ReplicaGroup& replica_group : hlo->replica_groups()) { @@ -157,34 +148,39 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( global_id = Cast(hlo)->use_global_device_ids(); } if (global_id) { - const int64_t num_partitions = - hlo->GetModule()->config().num_partitions(); - 
const int64_t replica_count = - hlo->GetModule()->config().replica_count(); - std::vector<std::vector<absl::flat_hash_set<int64_t>>> device_sets_per_replica = - GroupsForReplicas(hlo->replica_groups(), num_partitions, - replica_count, cross_partition_spmd); + // Wrap the replica_groups() in a HashableReplicaGroupSpan to enable + // hashing, then cache the result of GroupsForReplicas(). + HashableReplicaGroupSpan hashable_replica_groups(hlo->replica_groups()); + size_t key = absl::HashOf(hashable_replica_groups); + auto [it, inserted] = device_sets_per_replica_map_.try_emplace(key); + const std::vector<std::vector<absl::flat_hash_set<int64_t>>>* + device_sets_per_replica; + if (inserted) { + it->second = std::vector<std::vector<absl::flat_hash_set<int64_t>>>( + GroupsForReplicas(hlo->replica_groups())); + } + device_sets_per_replica = &it->second; // In the fully replicated case, there is one set of partition or // replica IDs on each replica or partition. Since the flattened ID // replica groups must contain every device, the size of the set is the // number of partitions or replicas. bool fully_replicated = true; - for (const auto& device_sets : device_sets_per_replica) { + for (const auto& device_sets : *device_sets_per_replica) { fully_replicated &= device_sets.size() == 1 && (*device_sets.begin()).size() == - (cross_partition_spmd ? num_partitions : replica_count); + (cross_partition_spmd_ ? num_partitions_ : replica_count_); } if (fully_replicated) { return HloReplication::ReplicatedOnAllDevices(); - } else if (support_partial_replication) { - return HloReplication::PartiallyReplicated(device_sets_per_replica); + } else if (support_partial_replication_) { + return HloReplication::PartiallyReplicated(*device_sets_per_replica); } else { return HloReplication::UniqueOnAllDevices(); } } - if (cross_partition_spmd) { + if (cross_partition_spmd_) { return HloReplication::ReplicatedOnAllDevices(); } if (hlo->replica_groups().empty() || hlo->replica_groups().size() == 1) { @@ -203,15 +199,15 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( return replication; } // This is cross-replica-only. - if (!hlo->channel_id().has_value() && cross_partition_spmd) { + if (!hlo->channel_id().has_value() && cross_partition_spmd_) { return replication; } // To save compile time on very large replica groups, check first if the // replica group dedup map has an entry already populated with the // replication and if so return that. - auto unique_replication_it = replica_group_dedup_map.find(hlo); - if (unique_replication_it == replica_group_dedup_map.end()) { + auto unique_replication_it = replica_group_dedup_map_.find(hlo); + if (unique_replication_it == replica_group_dedup_map_.end()) { VLOG(1) << "No dedup entry for " << hlo->name(); return calculate_all_reduce_all_gather_replication(hlo); } @@ -227,21 +223,21 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( } if (hlo->opcode() == HloOpcode::kReplicaId) { // ReplicaId returns the same value for all partitions in each replica. - return cross_partition_spmd ? HloReplication::ReplicatedOnAllDevices() - : HloReplication::UniqueOnAllDevices(); + return cross_partition_spmd_ ? HloReplication::ReplicatedOnAllDevices() + : HloReplication::UniqueOnAllDevices(); } if (hlo->opcode() == HloOpcode::kPartitionId) { // PartitionId returns the same value for all replicas in each partition. - return cross_partition_spmd ? HloReplication::UniqueOnAllDevices() - : HloReplication::ReplicatedOnAllDevices(); + return cross_partition_spmd_ ?
HloReplication::UniqueOnAllDevices() + : HloReplication::ReplicatedOnAllDevices(); } - auto it = hlo_replication.find(hlo); + auto it = hlo_replication_.find(hlo); if (hlo->opcode() == HloOpcode::kParameter) { // Parameters should have been processed. - CHECK(it != hlo_replication.end()); + CHECK(it != hlo_replication_.end()); return it->second.element(index); } - if (it != hlo_replication.end() && + if (it != hlo_replication_.end() && it->second.element(index).IsUniqueOnAllDevices()) { // The HLO is already marked as non-replicated. return it->second.element(index); @@ -259,7 +255,7 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( } // Pattern-match and process cases where the HLO is partially replicated. - if (support_partial_replication) { + if (support_partial_replication_) { // Below is a very specific pattern to match the SPMD pipeline case. if (hlo->opcode() == HloOpcode::kDynamicSlice) { const HloInstruction* ds_buffer = hlo->operand(0); @@ -268,12 +264,12 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( ds_buffer->opcode() == HloOpcode::kConstant && ds_buffer->shape().dimensions().size() == 1 && ds_buffer->shape().element_type() == PrimitiveType::S32 && - ((cross_partition_spmd && + ((cross_partition_spmd_ && hlo->operand(1)->opcode() == HloOpcode::kPartitionId) || - (!cross_partition_spmd && + (!cross_partition_spmd_ && hlo->operand(1)->opcode() == HloOpcode::kReplicaId))) { const HloModule* hlo_module = hlo->GetModule(); - int64_t num_devices = cross_partition_spmd + int64_t num_devices = cross_partition_spmd_ ? hlo_module->config().num_partitions() : hlo_module->config().replica_count(); absl::flat_hash_map> value_to_device_set; @@ -319,7 +315,7 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( const HloComputation* computation, bool mark_everything_not_replicated) { bool changed = false; - for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { + for (const HloInstruction* inst : computation->MakeInstructionPostOrder()) { // Assigns the shape tree to dest if dest doesn't have one yet, or combines // it with the existing one by and'ing them. Returns if anything is updated. auto assign_or_combine_shapetree = @@ -331,15 +327,15 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( return true; } bool updated = false; - it->second.ForEachMutableElement([&](const ShapeIndex& index, - HloReplication* element) { - HloReplication new_replication = MergeReplications( - *element, to_combine.element(index), replication_merge_map_); - if (!element->Equal(new_replication)) { - *element = std::move(new_replication); - updated = true; - } - }); + it->second.ForEachMutableElement( + [&](const ShapeIndex& index, HloReplication* element) { + HloReplication new_replication = + MergeReplications(*element, to_combine.element(index)); + if (!element->Equal(new_replication)) { + *element = std::move(new_replication); + updated = true; + } + }); return updated; }; // Assigns or combines source's shape tree to dest. 
Returns if anything is @@ -476,10 +472,7 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( ShapeUtil::ForEachSubshape( inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { *shape_tree.mutable_element(index) = - DetermineHloInstructionIsReplicated( - inst, index, cross_partition_spmd_, hlo_replication_, - support_partial_replication_, replica_group_dedup_map_, - replication_merge_map_); + DetermineHloInstructionIsReplicated(inst, index); }); changed |= assign_or_combine_shapetree(std::move(shape_tree), inst); } @@ -675,9 +668,9 @@ HloReplicationAnalysis::HloReplication::HloReplication( absl::Span> device_set_root_per_replica) : state_(state), device_set_root_per_replica_( - std::make_shared>>( - device_set_root_per_replica.begin(), - device_set_root_per_replica.end())) { + std::make_shared< + HashOnConstruction>>>( + device_set_root_per_replica)) { CHECK(state == State::kPartiallyReplicated || device_set_root_per_replica_->empty()); } @@ -771,27 +764,11 @@ HloReplicationAnalysis::HloReplication::Merge( } } -HloReplicationAnalysis::HloReplication::HloReplication( - const std::pair& merge_pair) { - auto merged_replication = merge_pair.first.Merge(merge_pair.second); - state_ = merged_replication.state_; - device_set_root_per_replica_ = - std::move(merged_replication.device_set_root_per_replica_); -} - bool HloReplicationAnalysis::HloReplication::Equal( const HloReplication& other) const { - if (state_ != other.state_) { - return false; - } - for (int i = 0; i < device_set_root_per_replica_->size(); ++i) { - if (device_set_root_per_replica_->at(i) != - other.device_set_root_per_replica_->at(i)) { - return false; - } - } - - return true; + return state_ == other.state_ && + device_set_root_per_replica_->hash_ == + other.device_set_root_per_replica_->hash_; } bool HloReplicationAnalysis::HloReplication::operator==( diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h index 21f64e0d0e6ab7..2d1fa0f6462535 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_ANALYSIS_HLO_REPLICATION_ANALYSIS_H_ #define XLA_HLO_ANALYSIS_HLO_REPLICATION_ANALYSIS_H_ +#include #include #include #include @@ -34,6 +35,23 @@ limitations under the License. namespace xla { +// A wrapper around absl::Span that allows us to hash it +class HashableReplicaGroupSpan : absl::Span { + public: + explicit HashableReplicaGroupSpan(const absl::Span groups) + : absl::Span(groups) {} + + template + friend H AbslHashValue(H h, const HashableReplicaGroupSpan& a) { + for (const auto& group : a) { + for (int64_t id : group.replica_ids()) { + h = H::combine(std::move(h), id); + } + } + return H::combine(std::move(h), a.size()); + } +}; + // An HLO pass that determines whether each instruction in the module outputs // the same value across replicas or across partitions (depending on the value // `cross_partition_spmd`). It propagates sources of replicated values to @@ -81,11 +99,6 @@ class HloReplicationAnalysis { HloReplication(); HloReplication(const HloReplication& other) = default; HloReplication(HloReplication&& other) = default; - // Create a new HloReplication that is the merge of two other HloReplication - // objects using the Merge() method, useful for lazy construction with - // try_emplace. 
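// [Editorial note, hedged -- not part of the patch] The constructor deleted
// just below existed only so the merge cache could build values in place. The
// replacement scheme, visible in MergeReplications further down, instead
// default-constructs the map slot and assigns the merged value only on a
// cache miss, roughly:
//
//   auto [iter, inserted] = replication_merge_map_.try_emplace(key);
//   if (inserted) iter->second = replication_a.Merge(replication_b);
//   return iter->second;  // repeat lookups of the same pair become free
//
// so the pair-taking constructor no longer has a caller.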
- explicit HloReplication( - const std::pair& merge_pair); HloReplication& operator=(HloReplication&& other) = default; HloReplication Merge(const HloReplication& other) const; bool Equal(const HloReplication& other) const; @@ -111,6 +124,25 @@ class HloReplicationAnalysis { State state, absl::Span> device_set_root_per_replica); State state_; + // Helper class that subclasses T, and computes the hash once on + // construction, and intercepts the hash function to use the precomputed + // hash. + template + class HashOnConstruction : public T { + public: + template + explicit HashOnConstruction(V& device_set_root_per_replica) + : T(device_set_root_per_replica.begin(), + device_set_root_per_replica.end()), + hash_(absl::HashOf(device_set_root_per_replica)) {} + + const size_t hash_; + + template + friend H AbslHashValue(H h, const HashOnConstruction& r) { + return H::combine(std::move(h), r.hash_); + } + }; // Empty if state_ is kReplicatedOnAllDevices or kUniqueOnAllDevices. // If cross_partition_spmd is true, groups_for_replicas_[k]'s size equals @@ -120,33 +152,28 @@ class HloReplicationAnalysis { // If cross_partition_spmd is false, groups_for_replicas_[k]'s size equals // the number of replicas, and within partition k, groups_for_replicas_[k] // maps each replica to the smallest replica ID in the set. - std::shared_ptr>> + std::shared_ptr>>> device_set_root_per_replica_; }; - static HloReplication DetermineHloInstructionIsReplicated( - const HloInstruction* hlo, const ShapeIndex& index, - bool cross_partition_spmd, - const absl::flat_hash_map>& hlo_replication, - bool support_partial_replication, - const absl::flat_hash_map*>& - replica_group_dedup_map, - absl::flat_hash_map, - HloReplication>& replication_merge_map); - - static HloReplication MergeReplications( - const HloReplication& replication_a, const HloReplication& replication_b, - absl::flat_hash_map, - HloReplication>& replication_merge_map) { + std::vector>> GroupsForReplicas( + absl::Span groups); + + HloReplication DetermineHloInstructionIsReplicated(const HloInstruction* hlo, + const ShapeIndex& index); + + HloReplication MergeReplications(const HloReplication& replication_a, + const HloReplication& replication_b) { std::pair key = {replication_a, replication_b}; // Look replication pair up in map: if not found we pass the pair to an // overloaded constructor of HloReplication which constructs and returns // a merged HloReplication. - auto [iter, inserted] = replication_merge_map.try_emplace(key, key); + auto [iter, inserted] = replication_merge_map_.try_emplace(key); + if (inserted) { + iter->second = replication_a.Merge(replication_b); + } return iter->second; } @@ -157,7 +184,9 @@ class HloReplicationAnalysis { : module_(module), cross_partition_spmd_(cross_partition_spmd), loops_known_with_same_iterations_(*loops_known_with_same_iterations), - support_partial_replication_(support_partial_replication) {} + support_partial_replication_(support_partial_replication), + num_partitions_(module_->config().num_partitions()), + replica_count_(module_->config().replica_count()) {} // Computes hlo_replication_. absl::Status ComputeHloReplication(); @@ -183,7 +212,7 @@ class HloReplicationAnalysis { // are identical across partitions. // // If false, HloReplicationAnalysis runs across replicas. - bool cross_partition_spmd_; + const bool cross_partition_spmd_; // A set of while loops that are known to have the same iteration counts // across replicas or partitions. 
This is provided by the caller as additional @@ -193,6 +222,9 @@ class HloReplicationAnalysis { const bool support_partial_replication_; + // Capture the number of partitions / replicas for the module. + const int64_t num_partitions_, replica_count_; + // A map from each analyzed HLO instruction to a shape tree that represents // whether the instruction outputs the same value across replicas or // partitions at each shape index. @@ -207,6 +239,8 @@ class HloReplicationAnalysis { absl::flat_hash_map<std::pair<HloReplication, HloReplication>, HloReplication> replication_merge_map_; std::vector> unique_replications_; + absl::flat_hash_map<size_t, std::vector<std::vector<absl::flat_hash_set<int64_t>>>> + device_sets_per_replica_map_; }; } // namespace xla From c7bba2448005d7b668777e2c71f8a5dc12681628 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 15 Apr 2025 10:05:16 +0100 Subject: [PATCH 0792/1324] [mlir][tosa] Remove softmax legalization changes causing numerical divergence (#90802) This commit removes some changes that were unintentionally added in #89907. The softmax changes diverge from the numerical behaviour specified by the reference kernels, so they are undesirable at this time. Change-Id: I00e6360f3d5dd24cc817a8415c20592bd578e2fc --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 131 +++++++++--------- .../mlir/tosa/transforms/legalize_common.cc | 27 +--- 2 files changed, 70 insertions(+), 88 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 2c4a64af956ecf..7ed4a52665b6cf 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -2502,73 +2502,70 @@ func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<5> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<31> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<-1010580540> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<1515870810> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<536870912> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<4> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<13> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<7> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<9> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<17> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> -// CHECK-DAG: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<23> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_17:.*]] = "tosa.const"() <{values = dense<"0x5{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[VAL_18:.*]] = "tosa.const"() <{values = dense<"0xE{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[VAL_19:.*]] = "tosa.const"() <{values = dense<"0x4{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<"0x0{{.*}}"> :
tensor<513xi16>}> -// CHECK-DAG: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> -// CHECK-DAG: %[[VAL_22:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_23:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_24:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> -// CHECK-DAG: %[[VAL_25:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_24]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32> -// CHECK-DAG: %[[VAL_26:.*]] = tosa.reduce_max %[[VAL_25]] -// CHECK-DAG: %[[VAL_27:.*]] = tosa.sub %[[VAL_25]], %[[VAL_26]] -// CHECK-DAG: %[[VAL_28:.*]] = tosa.rescale %[[VAL_27]], %[[VAL_21]], %[[VAL_16]], %[[VAL_24]], %[[VAL_15]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3x!quant.uniform> -// CHECK-DAG: %[[VAL_29:.*]] = tosa.table %[[VAL_28]], %[[VAL_20]] -// CHECK-DAG: %[[VAL_30:.*]] = tosa.table %[[VAL_28]], %[[VAL_19]] -// CHECK-DAG: %[[VAL_31:.*]] = tosa.table %[[VAL_28]], %[[VAL_18]] -// CHECK-DAG: %[[VAL_32:.*]] = tosa.table %[[VAL_28]], %[[VAL_17]] -// CHECK-DAG: %[[VAL_33:.*]] = tosa.logical_left_shift %[[VAL_29]], %[[VAL_14]] -// CHECK-DAG: %[[VAL_34:.*]] = tosa.logical_left_shift %[[VAL_30]], %[[VAL_13]] -// CHECK-DAG: %[[VAL_35:.*]] = tosa.logical_left_shift %[[VAL_31]], %[[VAL_12]] -// CHECK-DAG: %[[VAL_36:.*]] = tosa.arithmetic_right_shift %[[VAL_32]], %[[VAL_11]] -// CHECK-DAG: %[[VAL_37:.*]] = tosa.add %[[VAL_33]], %[[VAL_34]] -// CHECK-DAG: %[[VAL_38:.*]] = tosa.add %[[VAL_37]], %[[VAL_35]] -// CHECK-DAG: %[[VAL_39:.*]] = tosa.add %[[VAL_38]], %[[VAL_36]] -// CHECK-DAG: %[[VAL_40:.*]] = tosa.arithmetic_right_shift %[[VAL_39]], %[[VAL_10]] -// CHECK-DAG: %[[VAL_41:.*]] = tosa.reduce_sum %[[VAL_40]] -// CHECK-DAG: %[[VAL_42:.*]] = tosa.clz %[[VAL_41]] -// CHECK-DAG: %[[VAL_43:.*]] = tosa.sub %[[VAL_42]], %[[VAL_12]] -// CHECK-DAG: %[[VAL_44:.*]] = tosa.logical_left_shift %[[VAL_41]], %[[VAL_43]] -// CHECK-DAG: %[[VAL_45:.*]] = tosa.mul %[[VAL_44]], %[[VAL_6]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_46:.*]] = tosa.add %[[VAL_45]], %[[VAL_7]] -// CHECK-DAG: %[[VAL_47:.*]] = tosa.mul %[[VAL_46]], %[[VAL_44]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_48:.*]] = tosa.sub %[[VAL_8]], %[[VAL_47]] -// CHECK-DAG: %[[VAL_49:.*]] = tosa.mul %[[VAL_46]], %[[VAL_48]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_50:.*]] = tosa.mul %[[VAL_49]], %[[VAL_9]], %[[VAL_4]] -// CHECK-DAG: %[[VAL_51:.*]] = tosa.add %[[VAL_46]], %[[VAL_50]] -// CHECK-DAG: %[[VAL_52:.*]] = tosa.mul %[[VAL_51]], %[[VAL_44]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_53:.*]] = tosa.sub %[[VAL_8]], %[[VAL_52]] -// CHECK-DAG: %[[VAL_54:.*]] = tosa.mul %[[VAL_51]], %[[VAL_53]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_55:.*]] = tosa.mul %[[VAL_54]], %[[VAL_9]], %[[VAL_4]] -// CHECK-DAG: %[[VAL_56:.*]] = tosa.add %[[VAL_51]], %[[VAL_55]] -// CHECK-DAG: %[[VAL_57:.*]] = tosa.mul %[[VAL_56]], %[[VAL_44]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_58:.*]] = tosa.sub %[[VAL_8]], %[[VAL_57]] -// CHECK-DAG: %[[VAL_59:.*]] = tosa.mul %[[VAL_56]], %[[VAL_58]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_60:.*]] = tosa.mul %[[VAL_59]], %[[VAL_9]], %[[VAL_4]] -// CHECK-DAG: %[[VAL_61:.*]] = tosa.add %[[VAL_56]], 
%[[VAL_60]] -// CHECK-DAG: %[[VAL_62:.*]] = tosa.mul %[[VAL_39]], %[[VAL_61]], %[[VAL_22]] -// CHECK-DAG: %[[VAL_63:.*]] = tosa.sub %[[VAL_3]], %[[VAL_42]] -// CHECK-DAG: %[[VAL_64:.*]] = tosa.arithmetic_right_shift %[[VAL_62]], %[[VAL_2]] -// CHECK-DAG: %[[VAL_65:.*]] = tosa.arithmetic_right_shift %[[VAL_64]], %[[VAL_63]] -// CHECK-DAG: %[[VAL_66:.*]] = tosa.rescale %[[VAL_65]], %[[VAL_21]], %[[VAL_22]], %[[VAL_24]], %[[VAL_1]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<35> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<4> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<536870912> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{values = dense<1515870810> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() <{values = dense<-1010580540> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() <{values = dense<12> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() <{values = dense<7> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR9:.*]] = "tosa.const"() <{values = dense<9> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR10:.*]] = "tosa.const"() <{values = dense<17> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR11:.*]] = "tosa.const"() <{values = dense<"0x5{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR12:.*]] = "tosa.const"() <{values = dense<"0xE{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR13:.*]] = "tosa.const"() <{values = dense<"0x4{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR14:.*]] = "tosa.const"() <{values = dense<"0x0{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT_31:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> +// CHECK-DAG: %[[mult1073741824:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-DAG: %[[shift30:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[shift23:.*]] = "tosa.const"() <{values = dense<23> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[input_zp1:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> +// CHECK-DAG: %[[zp0i32:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> +// CHECK-DAG: %[[output_zp128:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL27:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> +// CHECK-DAG: %[[VAR15:.*]] = tosa.rescale %arg0, %[[mult1073741824]], %[[shift30]], %[[input_zp1]], %[[zp0i32]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} +// CHECK-DAG: %[[VAR16:.*]] = tosa.reduce_max %[[VAR15]] {axis = 2 : i32} +// CHECK-DAG: %[[VAR17:.*]] = tosa.sub %[[VAR15]], %[[VAR16]] +// CHECK-DAG: %[[VAR18:.*]] = tosa.rescale %[[VAR17]], %[[mult1073741824]], %[[shift23]], %[[zp0i32]], %[[VAL27]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} +// CHECK-DAG: %[[VAR19:.*]] = tosa.table %[[VAR18]], %[[VAR14]] +// CHECK-DAG: %[[VAR20:.*]] = tosa.table %[[VAR18]], %[[VAR13]] +// CHECK-DAG: %[[VAR21:.*]] = tosa.table 
%[[VAR18]], %[[VAR12]] +// CHECK-DAG: %[[VAR22:.*]] = tosa.table %[[VAR18]], %[[VAR11]] +// CHECK-DAG: %[[VAR23:.*]] = tosa.logical_left_shift %[[VAR19]], %[[VAR10]] +// CHECK-DAG: %[[VAR24:.*]] = tosa.logical_left_shift %[[VAR20]], %[[VAR9]] +// CHECK-DAG: %[[VAR25:.*]] = tosa.logical_left_shift %[[VAR21]], %[[VAR6]] +// CHECK-DAG: %[[VAR26:.*]] = tosa.arithmetic_right_shift %[[VAR22]], %[[VAR8]] {round = true} +// CHECK-DAG: %[[VAR27:.*]] = tosa.add %[[VAR23]], %[[VAR24]] +// CHECK-DAG: %[[VAR28:.*]] = tosa.add %[[VAR27]], %[[VAR25]] +// CHECK-DAG: %[[VAR29:.*]] = tosa.add %[[VAR28]], %[[VAR26]] +// CHECK-DAG: %[[VAR30:.*]] = tosa.arithmetic_right_shift %[[VAR29]], %[[VAR7]] {round = true} +// CHECK-DAG: %[[VAR31:.*]] = tosa.reduce_sum %[[VAR30]] {axis = 2 : i32} +// CHECK-DAG: %[[VAR32:.*]] = tosa.clz %[[VAR31]] +// CHECK-DAG: %[[VAR33:.*]] = tosa.sub %[[VAR32]], %[[VAR6]] +// CHECK-DAG: %[[VAR34:.*]] = tosa.logical_left_shift %[[VAR31]], %[[VAR33]] +// CHECK-DAG: %[[VAR35:.*]] = tosa.mul %[[VAR34]], %[[VAR5]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR36:.*]] = tosa.add %[[VAR35]], %[[VAR4]] +// CHECK-DAG: %[[VAR37:.*]] = tosa.mul %[[VAR36]], %[[VAR34]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR38:.*]] = tosa.sub %[[VAR3]], %[[VAR37]] +// CHECK-DAG: %[[VAR39:.*]] = tosa.mul %[[VAR36]], %[[VAR38]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR40:.*]] = tosa.mul %[[VAR39]], %[[VAR2]], %[[SHIFT]] +// CHECK-DAG: %[[VAR41:.*]] = tosa.add %[[VAR36]], %[[VAR40]] +// CHECK-DAG: %[[VAR42:.*]] = tosa.mul %[[VAR41]], %[[VAR34]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR43:.*]] = tosa.sub %[[VAR3]], %[[VAR42]] +// CHECK-DAG: %[[VAR44:.*]] = tosa.mul %[[VAR41]], %[[VAR43]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR45:.*]] = tosa.mul %[[VAR44]], %[[VAR2]], %[[SHIFT]] +// CHECK-DAG: %[[VAR46:.*]] = tosa.add %[[VAR41]], %[[VAR45]] +// CHECK-DAG: %[[VAR47:.*]] = tosa.mul %[[VAR46]], %[[VAR34]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR48:.*]] = tosa.sub %[[VAR3]], %[[VAR47]] +// CHECK-DAG: %[[VAR49:.*]] = tosa.mul %[[VAR46]], %[[VAR48]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR50:.*]] = tosa.mul %[[VAR49]], %[[VAR2]], %[[SHIFT]] +// CHECK-DAG: %[[VAR51:.*]] = tosa.add %[[VAR46]], %[[VAR50]] +// CHECK-DAG: %[[VAR52:.*]] = tosa.mul %[[VAR29]], %[[VAR51]], %[[shift30]] +// CHECK-DAG: %[[VAR53:.*]] = tosa.sub %[[VAR1]], %[[VAR32]] +// CHECK-DAG: %[[VAR54:.*]] = tosa.arithmetic_right_shift %[[VAR52]], %[[VAR53]] {round = true} +// CHECK: %[[VAR55:.*]] = tosa.rescale %[[VAR54]], %[[mult1073741824]], %[[shift30]], %[[zp0i32]], %[[output_zp128]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} func.func @test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 2cdace08b6d193..471149d7a99165 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -1699,11 +1699,11 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, rewriter, op->getLoc(), int32_logits_type, op12_add_op11_op9.getResult(), op10_rshift_op8.getResult()); - // Step 3. get sum(exp()). output 13.18 + // Step 3. get sum(exp()). 
output 12.19 auto op14_rshift_op13_12 = CreateOpAndInfer( rewriter, op->getLoc(), int32_logits_type, op13_add_op12_op10.getResult(), - getTosaConstTensorSingleI32(rewriter, op, 13, input_rank), true); + getTosaConstTensorSingleI32(rewriter, op, 12, input_rank), true); auto op15_reducesum_op14 = CreateOpAndInfer( rewriter, op->getLoc(), int32_rsum_type, @@ -1790,32 +1790,17 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, // Right shift amount is // num_bits_over_unit + 31 - (sizeof(OutputT) * 8 = - // (13 - headroom_plus_one) + 31 - 8 = - // (13 + 31 - 8) - headroom_plus_one - - // The calculated shift amount can be larger than 31, which is invalid - // in TOSA. In this case, the output should be the quantized equivalent - // to all 0's. To emulate this behaviour, we can use two shifts: - // 1. Right shift of 5, calculated by: - // max_headroom_plus_one_value = 31; - // 13 + 31 - 8 - max_headroom_plus_one_value - // 2. Right shift by the remainder - constexpr int constant_shift_amount = 5; - + // (12 - headroom_plus_one) + 31 - 8 = + // (12 + 31 - 8) - headroom_plus_one auto op27_sub_op16 = CreateOpAndInfer( rewriter, op->getLoc(), int32_rsum_type, - getTosaConstTensorSingleI32(rewriter, op, 13 + 31 - 8 - constant_shift_amount, input_rank), + getTosaConstTensorSingleI32(rewriter, op, 12 + 31 - 8, input_rank), op16_clz_op15.getResult()); - auto constant_shift = CreateOpAndInfer( - rewriter, op->getLoc(), int32_logits_type, - op26_mul_op13_x.getResult(), getTosaConstTensorSingleI32(rewriter, op, constant_shift_amount, input_rank), - false); - auto op28_rshift_op26_op27 = CreateOpAndInfer( rewriter, op->getLoc(), int32_logits_type, - constant_shift.getResult(), op27_sub_op16.getResult(), true); + op26_mul_op13_x.getResult(), op27_sub_op16.getResult(), true); return buildRescale(rewriter, op, output_type, op28_rshift_op26_op27.getResult(), 1.0, 0, From acaf725b333b4c713360b8e29325076e560108e8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 01:35:03 -0700 Subject: [PATCH 0793/1324] [XLA:GPU][Emitters] Refactor emitter filter for int4 datatypes PiperOrigin-RevId: 747766201 --- .../xla/service/gpu/hlo_fusion_analysis.cc | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index 0d318843944d33..99a7f664ca8931 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -246,21 +246,11 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() return EmitterFusionKind::kCuDnn; } - if (input_output_info_.smallest_input_dtype_bits < 8 || - input_output_info_.smallest_output_dtype_bits < 8) { - // Only loop and input slice fusions currently can handle packed - // inputs/outputs, due to the special handling with IrArray needed to deal - // with multiple values occupying a single byte. - if (fusion_roots_.size() > 1 && - IsInputFusibleNonStridedSlices(fusion_roots_) && - AllSliceInputsAreCompatible(fusion_roots_)) { - return EmitterFusionKind::kInputSlices; - } - if (fusion_roots_[0].opcode() == HloOpcode::kScatter) { - return EmitterFusionKind::kScatter; - } - return EmitterFusionKind::kLoop; - } + // TODO(b/406763726): Only some emitters currently can handle packed + // inputs/outputs, due to the special handling with IrArray needed to deal + // with multiple values occupying a single byte. 
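// [Editorial sketch, hedged -- added for illustration, not from the patch]
// "Packed" refers to sub-byte element types such as S4/U4, where two 4-bit
// values share one byte, so plain element indexing breaks. One common layout
// (illustrative only; the real handling lives in IrArray):
//
//   uint8_t byte = data[i / 2];                               // element i lives in byte i/2
//   uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
//
// The predicate computed below is what steers such fusions away from the
// emitter kinds that cannot do this unpacking yet.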
+ bool has_subtype_type = input_output_info_.smallest_input_dtype_bits < 8 || + input_output_info_.smallest_output_dtype_bits < 8; std::optional first_reduce_hero; for (auto [root, hero] : llvm::zip(fusion_roots_, fusion_heroes_)) { @@ -292,13 +282,13 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() break; } } - if (valid_shapes) { + if (valid_shapes && !has_subtype_type) { return EmitterFusionKind::kReduction; } } // We expect that the last dimension is swapped with a different dimension. - if (HasConsistentTransposeHeros()) { + if (HasConsistentTransposeHeros() && !has_subtype_type) { return EmitterFusionKind::kTranspose; } @@ -314,7 +304,8 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() return EmitterFusionKind::kScatter; } - if (UseConcatenateFusion(fusion_roots_, fusion_heroes_)) { + if (UseConcatenateFusion(fusion_roots_, fusion_heroes_) && + !has_subtype_type) { return EmitterFusionKind::kConcatenate; } From 67c4f3890322d918ce7de7286d1b237c87d5cc86 Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Tue, 15 Apr 2025 01:45:17 -0700 Subject: [PATCH 0794/1324] [xla:gpu] Use accumulators with the correct precision for Split-K GEMMs. This fixes numerics of GEMMs with 16bit input types. It removes the need to work around said issues with `--xla_gpu_enable_split_k_autotuning=false`. PiperOrigin-RevId: 747769209 --- third_party/xla/xla/debug_options_flags.cc | 2 +- .../xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc | 4 ++++ third_party/xla/xla/xla.proto | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 1b226e3456a56e..fa4446a4f78a93 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -241,7 +241,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_ratio(1.1); - opts.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(false); + opts.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(true); opts.set_xla_gpu_unsafe_pipelined_loop_annotator(false); opts.set_xla_gpu_copy_insertion_use_region_analysis(false); diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index e2daf1825f7a56..a08cfd484e71d8 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -370,6 +370,10 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest { VLOG(5) << m->ToString(); const HloInstruction* dot_fusion = m->entry_computation()->root_instruction(); + // Split-K rewriting may introduce a convert and / or a reduce op. + if (dot_fusion->opcode() == HloOpcode::kConvert) { + dot_fusion = dot_fusion->operand(0); + } if (dot_fusion->opcode() == HloOpcode::kReduce) { dot_fusion = dot_fusion->operand(0); } diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index b0389816056c9d..31007c9e0d54d6 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -754,6 +754,8 @@ message DebugOptions { // `xla_gpu_cublas_fallback` set to false. bool xla_gpu_triton_gemm_any = 190; + // TODO(b/409940111): Remove this flag and use high precision reductions for + // Split-K GEMMs unconditionally. 
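// [Editorial note, hedged] For context on the TODO above: split-K rewrites
// dot(lhs[M,K], rhs[K,N]) into S partial products over K/S-sized slices plus
// a reduction over the split dimension. An illustrative (not verbatim) shape
// of the rewritten computation:
//
//   partial = f32[S,M,N] dot(lhs_slices, rhs_slices)
//   summed  = f32[M,N] reduce(partial), dimensions={0}
//   result  = bf16[M,N] convert(summed)
//
// Flipping this flag's default to true (done in debug_options_flags.cc in
// this same patch) keeps that reduction in f32 rather than the 16-bit output
// type, which is the numerics fix the commit message describes; the extra
// convert/reduce pair is what the autotuner test above now skips over.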
bool xla_gpu_triton_gemm_disable_reduced_precision_reduction = 226; // It is usually preferable to not fallback to the driver; it can consume more From 5ea25afd71164ced12527bfbf8a12330a201d416 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 01:52:55 -0700 Subject: [PATCH 0795/1324] Reverts 50cd97d31a27ad6cb97fb6613e0fb43f5bb746c2 PiperOrigin-RevId: 747771981 --- .../xla/xla/backends/gpu/runtime/cub_sort_thunk.cc | 8 ++------ third_party/xla/xla/service/gpu/build_defs.bzl | 3 --- third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc | 9 --------- third_party/xla/xla/service/gpu/cub_sort_kernel.h | 4 ---- .../xla/xla/service/gpu/tests/gpu_cub_sort_test.cc | 2 +- 5 files changed, 3 insertions(+), 23 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc index f4b35813f4f4c2..8195bf0a1ab226 100644 --- a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc @@ -241,8 +241,8 @@ absl::StatusOr> CreateCubSortRunner( } // Returns an interface for calling CubSortPairs on the given key and value -// types. key_type can be any unsigned integer types or F32. value_type can be -// any type of 16/32/64 bit width. +// types. key_type can be only unsigned integer types. value_type can be any +// type of 16/32/64 bit width. absl::StatusOr> CreateCubSortRunner( PrimitiveType key_type, PrimitiveType value_type) { int value_width = primitive_util::BitWidth(value_type); @@ -260,10 +260,6 @@ absl::StatusOr> CreateCubSortRunner( if (key_type == U64 && value_width == 32) sort_fn = CubSortPairs_u64_b32; if (key_type == U64 && value_width == 64) sort_fn = CubSortPairs_u64_b64; - if (key_type == F32 && value_width == 16) sort_fn = CubSortPairs_f32_b16; - if (key_type == F32 && value_width == 32) sort_fn = CubSortPairs_f32_b32; - if (key_type == F32 && value_width == 64) sort_fn = CubSortPairs_f32_b64; - if (sort_fn == nullptr) { return InvalidArgument( "Unsupported key/value type combination for CubSortPairs: %s/%s", diff --git a/third_party/xla/xla/service/gpu/build_defs.bzl b/third_party/xla/xla/service/gpu/build_defs.bzl index e5f3036865a225..9ae3e2ab0b08f9 100644 --- a/third_party/xla/xla/service/gpu/build_defs.bzl +++ b/third_party/xla/xla/service/gpu/build_defs.bzl @@ -42,9 +42,6 @@ def get_cub_sort_kernel_types(name = ""): "u8_b16", "u8_b32", "u8_b64", - "f32_b16", - "f32_b32", - "f32_b64", ] def build_cub_sort_kernels(name, types, local_defines = [], **kwargs): diff --git a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc index e5d9db3f52ee75..507bd4d7da953f 100644 --- a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc @@ -244,15 +244,6 @@ XLA_CUB_DEFINE_SORT_PAIRS(u32_b32, uint32_t, uint32_t) #ifdef CUB_TYPE_U32_B64 XLA_CUB_DEFINE_SORT_PAIRS(u32_b64, uint32_t, uint64_t) #endif -#ifdef CUB_TYPE_F32_B16 -XLA_CUB_DEFINE_SORT_PAIRS(f32_b16, float, uint16_t) -#endif -#ifdef CUB_TYPE_F32_B32 -XLA_CUB_DEFINE_SORT_PAIRS(f32_b32, float, uint32_t) -#endif -#ifdef CUB_TYPE_F32_B64 -XLA_CUB_DEFINE_SORT_PAIRS(f32_b64, float, uint64_t) -#endif // Pairs with 64-bit key. 
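// [Editorial note, hedged] The pattern at work here: build_defs.bzl compiles
// this file once per entry of get_cub_sort_kernel_types() with the matching
// CUB_TYPE_* define set, XLA_CUB_DEFINE_SORT_PAIRS(suffix, KeyT, ValT) emits
// the entry point declared in cub_sort_kernel.h (e.g. CubSortPairs_u64_b32),
// and CreateCubSortRunner dispatches on (key_type, value bit-width), as in:
//
//   if (key_type == U64 && value_width == 32) sort_fn = CubSortPairs_u64_b32;
//
// The revert therefore removes the f32-keyed variants at every layer: the bzl
// type list, the instantiations below, the declarations, the dispatch, and
// the F32 key case in the test suite.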
#ifdef CUB_TYPE_U64_B16 diff --git a/third_party/xla/xla/service/gpu/cub_sort_kernel.h b/third_party/xla/xla/service/gpu/cub_sort_kernel.h index 627dd7ef079b84..29b163e7b1bf0b 100644 --- a/third_party/xla/xla/service/gpu/cub_sort_kernel.h +++ b/third_party/xla/xla/service/gpu/cub_sort_kernel.h @@ -67,10 +67,6 @@ XLA_CUB_DECLARE_SORT_PAIRS(u64_b16) XLA_CUB_DECLARE_SORT_PAIRS(u64_b32) XLA_CUB_DECLARE_SORT_PAIRS(u64_b64) -XLA_CUB_DECLARE_SORT_PAIRS(f32_b16) -XLA_CUB_DECLARE_SORT_PAIRS(f32_b32) -XLA_CUB_DECLARE_SORT_PAIRS(f32_b64) - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc index eb22488d981ffa..d37d322ddb32f2 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc @@ -306,7 +306,7 @@ ENTRY m { INSTANTIATE_TEST_SUITE_P( CubSort, CubSortPairsTest, - ::testing::Combine(::testing::Values(U8, U16, U32, U64, F32), + ::testing::Combine(::testing::Values(U8, U16, U32, U64), ::testing::Values(F16, F32, F64), ::testing::Bool(), ::testing::Values(1, 10)), [](const ::testing::TestParamInfo& info) { From 3e645bfb3362e0cb8b459a7e5470042442c64c32 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 02:02:30 -0700 Subject: [PATCH 0796/1324] Update GraphDef version to 2198. PiperOrigin-RevId: 747774829 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 7e517bef019aa0..e98dbab9f65a0e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2197 // Updated: 2025/4/14 +#define TF_GRAPH_DEF_VERSION 2198 // Updated: 2025/4/15 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From a4810fda0fe9f5f6b1c483fed55c77e9bef6d658 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 02:05:15 -0700 Subject: [PATCH 0797/1324] compat: Update forward compatibility horizon to 2025-04-15 PiperOrigin-RevId: 747776041 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 61c797a99bedbc..02ca88b73f8257 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 14) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 15) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From e4ba5595b5f9769ad1ec6be0dc0ed5624b1deaee Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Tue, 15 Apr 2025 13:26:56 +0200 Subject: [PATCH 0798/1324] Add `use_default_shell_env = True` to `build_pip_package_py` rule This is similar to https://github.com/tensorflow/tensorflow/pull/44549 where the missing parameter causes failures to find required binaries on non-default locations. 
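[Editorial addition, hedged] Background: Bazel runs ctx.actions.run actions
with a scrubbed environment by default, so a tool that the wheel-build script
resolves via PATH from a non-default location is simply not found. The
one-line fix in the diff below amounts to:

    ctx.actions.run(
        inputs = srcs + headers + xla_aot,
        outputs = [output_file],
        executable = executable,
        use_default_shell_env = True,  # inherit PATH etc. from the client env
    )

with the surrounding arguments exactly as they already appear in tf_wheel.bzl.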
--- tensorflow/tools/pip_package/utils/tf_wheel.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/tools/pip_package/utils/tf_wheel.bzl b/tensorflow/tools/pip_package/utils/tf_wheel.bzl index e541c1310eba44..c4dcd4682c8a23 100644 --- a/tensorflow/tools/pip_package/utils/tf_wheel.bzl +++ b/tensorflow/tools/pip_package/utils/tf_wheel.bzl @@ -130,6 +130,7 @@ def _tf_wheel_impl(ctx): inputs = srcs + headers + xla_aot, outputs = [output_file], executable = executable, + use_default_shell_env = True, ) return [DefaultInfo(files = depset(direct = [output_file]))] From 646344a0866d984e497d5ef1e1f2f9729846f27c Mon Sep 17 00:00:00 2001 From: Theotime Combes Date: Tue, 15 Apr 2025 02:36:06 -0700 Subject: [PATCH 0799/1324] [XLA:GPU] Add triton support test for copy-start & copy-done PiperOrigin-RevId: 747785683 --- .../backends/gpu/codegen/triton/support.cc | 2 -- .../gpu/codegen/triton/support_test.cc | 35 +++++++++++++++++-- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index 22512f3fa9188b..0d6f445476b2e0 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -630,8 +630,6 @@ namespace internal { bool IsTritonUnsupportedOpcode(HloOpcode opcode) { switch (opcode) { case HloOpcode::kConvolution: - case HloOpcode::kCopyDone: - case HloOpcode::kCopyStart: case HloOpcode::kDynamicReshape: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index f117ba9f0b01e6..edf16a0115fa93 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -2664,12 +2664,42 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(AllDevicesToTest())), TritonSupportTestTypeAndDeviceToString); +using CopyStartDoneTest = TritonSupportTestWithTypeAndDeviceParam; + +TEST_P(CopyStartDoneTest, CopyStartDone) { + auto [data_type, cc] = GetParam(); + const std::string kHloTestTemplate = R"( + ENTRY triton_computation { + parameter = $0[10,10,10] parameter(0) + cp_start = ($0[10,10,10], $0[10,10,10], u32[]) copy-start(parameter) + ROOT cp_done = $0[10,10,10] copy-done(cp_start) + })"; + + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti_start, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, + HloOpcode::kCopyStart)); + RunSupportTest(std::move(ti_start), /*output_tile_sizes=*/{1, 1, 1}, cc); + + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti_done, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, + HloOpcode::kCopyDone)); + RunSupportTest(std::move(ti_done), /*output_tile_sizes=*/{1, 1, 1}, cc); +} +constexpr std::array kTestedOpsCopy = {HloOpcode::kCopyStart, + HloOpcode::kCopyDone}; + +INSTANTIATE_TEST_SUITE_P( + CopyStartDoneSuite, CopyStartDoneTest, + ::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()), + ::testing::ValuesIn(AllDevicesToTest())), + TritonSupportTestTypeAndDeviceToString); + constexpr std::array kUnsupportedOps = { // clang-format off // go/keep-sorted start HloOpcode::kConvolution, - HloOpcode::kCopyDone, - HloOpcode::kCopyStart, HloOpcode::kDynamicReshape, HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, @@ -2717,6 +2747,7 @@ absl::flat_hash_set AllTestedOpcodes() { ret.insert(kTestedOpsConstant.begin(), 
kTestedOpsConstant.end()); ret.insert(kTestedOpsIota.begin(), kTestedOpsIota.end()); ret.insert(kTestedOpsRng.begin(), kTestedOpsRng.end()); + ret.insert(kTestedOpsCopy.begin(), kTestedOpsCopy.end()); ret.emplace(HloOpcode::kAfterAll); ret.emplace(HloOpcode::kAddDependency); From 0cfdb0066ae8b9d2c22e354797ccab2656508c67 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 15 Apr 2025 04:02:16 -0700 Subject: [PATCH 0800/1324] [XLA:GPU] Reject dynamic shapes in `SortRewriter`. The custom call does not support dynamic shapes yet. PiperOrigin-RevId: 747808534 --- .../service/gpu/transforms/sort_rewriter.cc | 8 ++++ .../gpu/transforms/sort_rewriter_test.cc | 38 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc index c431ad1560e299..90e42982294a67 100644 --- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc @@ -390,11 +390,19 @@ bool IsCubCompatibleSort(const se::DeviceDescription& device_description, return false; } + for (const auto& op : sort_op->operands()) { + if (op->shape().is_dynamic()) { + VLOG(2) << "Dynamic shape is not supported: " << op->shape().ToString(); + return false; + } + } + const Shape& operand_shape = sort_op->operand(0)->shape(); if (sort_op->sort_dimension() != operand_shape.dimensions().size() - 1) { VLOG(2) << "Sort dimension should be the minor one"; return false; } + if (!ShouldRewriteCompatibleSort(device_description, sort_op)) { VLOG(2) << "Tensor shape and type will not see an improvement."; return false; diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc index 5509441423e676..a7ba46633bf5d5 100644 --- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc @@ -252,6 +252,44 @@ ENTRY %main { EXPECT_FALSE(RunModuleAndPass(module.get())); } +TEST_F(SortRewriterTest, NoRewriteDynamicSize) { + constexpr char kHlo[] = R"( +HloModule TestModule + +%compare { + %lhs = u8[] parameter(0) + %rhs = u8[] parameter(1) + ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT +} + +ENTRY %main { + %input = u8[100,<=100] parameter(0) + ROOT %sort = u8[100,<=100] sort(%input), dimensions={1}, to_apply=%compare +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + EXPECT_FALSE(RunModuleAndPass(module.get())); +} + +TEST_F(SortRewriterTest, NoRewriteDynamicBatch) { + constexpr char kHlo[] = R"( +HloModule TestModule + +%compare { + %lhs = u8[] parameter(0) + %rhs = u8[] parameter(1) + ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT +} + +ENTRY %main { + %input = u8[<=100,100] parameter(0) + ROOT %sort = u8[<=100,100] sort(%input), dimensions={1}, to_apply=%compare +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + EXPECT_FALSE(RunModuleAndPass(module.get())); +} + // Kernels are compiled for a subset of types.
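// [Editorial note, hedged] On the two dynamic-shape tests above:
// Shape::is_dynamic() is true when any dimension is bounded-dynamic, written
// <=N in HLO text, so u8[100,<=100] makes the sort dimension dynamic and
// u8[<=100,100] the batch dimension; the new operand loop in
// IsCubCompatibleSort rejects both before any CUB custom call is formed. The
// test that follows exercises the separate type restriction named in the
// comment directly above.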
TEST_F(SortRewriterTest, NoRewriteUnsupportedType) { constexpr char kHlo[] = R"( From e514c1fd0faecc5862be8df57a27faf88deb387f Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 15 Apr 2025 04:30:52 -0700 Subject: [PATCH 0801/1324] PR #21683: [XLA:GPU] NVSHMEM allocation Imported from GitHub PR https://github.com/openxla/xla/pull/21683 Requires https://github.com/openxla/xla/pull/20395 which adds the NVSHMEM library dependency. This PR adds the following: 1. Nvshmem flag to enable nvshmem 2. Set nvshmem initialization issue when GPU PJRT client is created. The first time NVSHMEM is used, it will be initialized. 3. Uses the user buffer memory pool for nvshmem. If nvshmem is enabled, it will be allocated using `nvshmem_malloc`. This same memory can be used by user buffers if nccl user buffers is also enabled. 4. Update the `CollectiveColorer` so that mosaic_gpu custom calls use the nvshmem memory space. Copybara import of the project: -- aee33791e16ab2149118de728dbb9e62f5e7cc31 by Trevor Morris : Add nvshmem flag, memory allocation, and memory space assignment Set Nvshmem env info during client creation Rename flag and use absl::string_view -- f8fca39300b3915eb6320142f58fa9c0ec7a1eaa by Trevor Morris : Use explicit types in test -- e41faa3f72b778fcf8ea8111d3cde59548b8f9f5 by Trevor Morris : Add user buffer allgather and allreduce tests with and without nvshmem alloc Set nvshmem in XLA_FLAGS test fixes formatting -- cf0c36865de8b8a010caaf62c3a36b64e36037bd by Trevor Morris : Fixes -- 3b4d11123cdb794d0a60e65b94d22ded04b7b2b4 by Trevor Morris : Remove early dso check -- 359f2b243ec97b1f8003c27f0b07dde82407ff6c by Trevor Morris : Add flag comment -- fd15a7cac745adc1971bec63e148047b9b811729 by Trevor Morris : Also assign memory space for mosaic_gpu_v2 Merging this change closes #21683 PiperOrigin-RevId: 747816712 --- .../xla/xla/backends/gpu/collectives/BUILD | 26 +- .../gpu/collectives/nccl_collectives.cc | 29 ++ .../gpu/collectives/nvshmem_collectives.cc | 6 + .../gpu/collectives/nvshmem_collectives.h | 29 +- third_party/xla/xla/debug_options_flags.cc | 6 + third_party/xla/xla/pjrt/gpu/BUILD | 50 ++++ .../gpu/se_gpu_pjrt_client_nvshmem_test.cc | 258 ++++++++++++++++++ third_party/xla/xla/service/gpu/BUILD | 2 + .../service/gpu/compile_module_to_llvm_ir.cc | 9 +- .../service/gpu/gpu_memory_space_assignment.h | 30 +- third_party/xla/xla/xla.proto | 6 +- 11 files changed, 425 insertions(+), 26 deletions(-) create mode 100644 third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_nvshmem_test.cc diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index 2f8c015eea9676..5d8d3c71afb69a 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -18,13 +18,22 @@ package_group( ], ) +config_setting( + name = "arm_build", + values = {"cpu": "arm"}, +) + # Build target that registers all available GPU collectives implementations with the collectives # registry at link time. cc_library( name = "gpu_collectives_plugin", deps = [ ":gpu_collectives_stub", - ] + if_nccl([":nccl_collectives"]), + ] + if_nccl([":nccl_collectives"]) + select({ + # TODO(b/409709288): Fix nvshmem ARM issues and remove this condition. 
+ ":arm_build": [], + "//conditions:default": [":nvshmem_collectives"], + }), ) cc_library( @@ -222,6 +231,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", + "@local_tsl//tsl/platform:numbers", ] + if_cuda_is_configured([ "@local_config_nccl//:nccl", ]) + if_rocm_is_configured([ @@ -271,14 +281,11 @@ cc_library( cc_library( name = "nvshmem_collectives", - srcs = ["nvshmem_collectives.cc"], - hdrs = ["nvshmem_collectives.h"], - tags = [ - "cuda-only", - "gpu", - ], + srcs = if_cuda_is_configured(["nvshmem_collectives.cc"]), + hdrs = if_cuda_is_configured(["nvshmem_collectives.h"]), visibility = ["//visibility:private"], deps = [ + ":gpu_collectives", "//xla/core/collectives", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", @@ -299,9 +306,8 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:numbers", - "@nvshmem//:nvshmem_lib", - ], - alwayslink = True, # registers collectives implementation + ] + if_cuda_is_configured(["@nvshmem//:nvshmem_lib"]), + alwayslink = True, ) xla_cc_test( diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc index d54c91f00b0e8d..cfcd6469edf2b9 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/core/collectives/collectives_registry.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" +#include "xla/debug_options_flags.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/gpu_executable_run_options.h" @@ -53,6 +54,7 @@ limitations under the License. 
#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "tsl/platform/casts.h" +#include "tsl/platform/numbers.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" @@ -235,7 +237,24 @@ absl::Status NcclCollectives::GroupEnd() { return XLA_NCCL_STATUS(ncclGroupEnd()); } +static absl::StatusOr GetNvshmemCollectives() { + TF_ASSIGN_OR_RETURN(xla::Collectives * collectives, + xla::CollectivesRegistry::Get("gpu", "nvshmem")); + xla::gpu::GpuCollectives* nvshmem_collectives = + tsl::down_cast(collectives); + if (nvshmem_collectives == nullptr) { + return absl::InternalError("Failed to get NVSHMEM collectives"); + } + + return nvshmem_collectives; +} + absl::StatusOr NcclCollectives::Allocate(uint64_t bytes) { + if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) { + TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives()); + return nvshmem_collectives->Allocate(bytes); + } + void* ptr = nullptr; ncclResult_t res = ncclMemAlloc(&ptr, bytes); if (res != ncclSuccess) { @@ -251,6 +270,11 @@ absl::StatusOr NcclCollectives::Allocate(uint64_t bytes) { } absl::Status NcclCollectives::Deallocate(void* location) { + if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) { + TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives()); + return nvshmem_collectives->Deallocate(location); + } + ncclResult_t res = ncclMemFree(location); if (res != ncclSuccess) { return absl::InternalError(absl::StrFormat( @@ -318,6 +342,11 @@ class NcclIdStore { absl::Status NcclCollectives::InitializeTopology( NcclCollectives::Topology topology) { + if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) { + TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives()); + TF_RETURN_IF_ERROR(nvshmem_collectives->InitializeTopology(topology)); + } + if (topology.num_nodes > 1) { auto nccl_id_store = std::make_shared( topology.node_id, topology.device_id_to_node_id, diff --git a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc index acb5c8cf13eba5..37462433ab9b99 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc @@ -57,6 +57,12 @@ NvshmemCollectives* NvshmemCollectives::Default() { LOG(FATAL) << "Unsupported collectives implementation for NVSHMEM"; } +absl::Status NvshmemCollectives::InitializeTopology(Topology topology) { + SetEnvInfo(topology.node_id, topology.num_nodes, + topology.device_count_per_process, topology.kv_store); + return absl::OkStatus(); +} + void NvshmemCollectives::SetEnvInfo( int process_id, size_t num_processes, size_t device_count_per_process, std::weak_ptr kv_store) { diff --git a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h index f7cd034e1c8a86..3019feae0cc091 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/collectives.h" @@ -35,7 +36,7 @@ limitations under the License. 
diff --git a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h
index f7cd034e1c8a86..3019feae0cc091 100644
--- a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h
+++ b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h
@@ -25,6 +25,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
+#include "xla/backends/gpu/collectives/gpu_collectives.h"
 #include "xla/core/collectives/clique_id.h"
 #include "xla/core/collectives/clique_key.h"
 #include "xla/core/collectives/collectives.h"
@@ -35,7 +36,7 @@ limitations under the License.
 namespace xla::gpu {
 
 // NVIDIA NVSHMEM library
-class NvshmemCollectives : public Collectives {
+class NvshmemCollectives : public GpuCollectives {
  public:
   ~NvshmemCollectives() override;
 
@@ -45,28 +46,46 @@ class NvshmemCollectives : public Collectives {
                          size_t device_count_per_process,
                          std::weak_ptr<KeyValueStoreInterface> kv_store);
 
-  absl::StatusOr<void*> Allocate(uint64_t bytes);
+  absl::StatusOr<void*> Allocate(uint64_t bytes) final;
 
-  absl::Status Deallocate(void* buffer);
+  absl::Status Deallocate(void* buffer) final;
 
   absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final {
     return absl::UnimplementedError("Not implemented.");
   }
 
+  absl::Status GroupStart() final {
+    return absl::UnimplementedError("Not implemented.");
+  }
+  absl::Status GroupEnd() final {
+    return absl::UnimplementedError("Not implemented.");
+  }
+
+  bool IsImplemented() const final { return true; }
+
+  bool IsGlobalConfig() const final { return false; }
+
+  absl::StatusOr<const CliqueIdCallback*> GetCliqueIdCallback(
+      const CliqueIdCallback* clique_id_callback, bool is_local) final {
+    return absl::UnimplementedError("Not implemented.");
+  }
+
   absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
   CreateCommunicators(const CliqueKey& clique_key,
                       const std::optional<CliqueIds>& clique_ids,
                       absl::Span<const DeviceRank> ranks,
-                      const Config& config) final {
+                      const Collectives::Config& config) {
     return absl::UnimplementedError("Not implemented.");
   }
 
   absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> SplitCommunicators(
       absl::Span<const Communicator* const> comms, int32_t color,
-      absl::Span<const RankId> keys, const Config& config) final {
+      absl::Span<const RankId> keys, const Collectives::Config& config) final {
     return absl::UnimplementedError("Not implemented.");
   }
 
+  absl::Status InitializeTopology(Topology topology) final;
+
  private:
   absl::Status InitializeOnce();
 
diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index fa4446a4f78a93..34b9637d51316d 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -167,6 +167,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_nccl_termination_timeout_seconds(-1);
   opts.set_xla_gpu_enable_shared_constants(true);
   opts.set_xla_gpu_enable_nccl_user_buffers(false);
+  opts.set_xla_gpu_experimental_enable_nvshmem(false);
   opts.set_xla_gpu_enable_nccl_comm_splitting(true);
   opts.set_xla_gpu_nccl_init_max_rank_per_root_ratio(0);
 
@@ -1581,6 +1582,11 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "Enables NCCL User Buffer Registration. collective_memory_size in the "
       "allocator config must also be set to a non-zero value that is large "
       "enough to meet peak collective memory usage."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_experimental_enable_nvshmem",
+      bool_setter_for(&DebugOptions::set_xla_gpu_experimental_enable_nvshmem),
+      debug_options->xla_gpu_experimental_enable_nvshmem(),
+      "Enables NVSHMEM."));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_temp_buffer_use_separate_color",
       bool_setter_for(
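Because the flag is read through GetDebugOptionsFromFlags() at allocation time, it must be present in XLA_FLAGS before the first collective allocation; the test target below does this via its `env` attribute. A minimal sketch of the same setup in a standalone driver (illustrative only, assuming a POSIX environment):

#include <cstdlib>

int main(int argc, char** argv) {
  // Must be set before the GPU PJRT client triggers any collective
  // allocation; flipping the flag later has no effect on flags already read.
  setenv("XLA_FLAGS", "--xla_gpu_experimental_enable_nvshmem=true",
         /*overwrite=*/1);
  // ... create the client (e.g. via GetStreamExecutorGpuClient) and run ...
  return 0;
}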
+ "nomsan", + ]}, + backends = ["gpu"], + env = { + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + }, + deps = [ + ":gpu_topology_proto_cc", + ":se_gpu_pjrt_client", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/backends/gpu/collectives:gpu_collectives", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", + "//xla/hlo/utils:hlo_query", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_executable", + "//xla/pjrt:raw_buffer", + "//xla/pjrt/distributed", + "//xla/pjrt/distributed:client", + "//xla/pjrt/distributed:in_memory_key_value_store", + "//xla/pjrt/distributed:service", + "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", + "//xla/service:platform_util", + "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], +) + xla_test( name = "pjrt_client_test_se_gpu", srcs = ["pjrt_client_test_se_gpu.cc"], diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_nvshmem_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_nvshmem_test.cc new file mode 100644 index 00000000000000..3200b37414d00e --- /dev/null +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_nvshmem_test.cc @@ -0,0 +1,258 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include "absl/log/check.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "xla/ffi/ffi.h"
+#include "xla/ffi/ffi_api.h"
+#include "xla/hlo/builder/xla_computation.h"
+#include "xla/hlo/parser/hlo_parser.h"
+#include "xla/hlo/testlib/test.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/layout.h"
+#include "xla/pjrt/distributed/client.h"
+#include "xla/pjrt/distributed/distributed.h"
+#include "xla/pjrt/distributed/in_memory_key_value_store.h"
+#include "xla/pjrt/distributed/service.h"
+#include "xla/pjrt/gpu/gpu_topology.pb.h"
+#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
+#include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_compiler.h"
+#include "xla/pjrt/pjrt_executable.h"
+#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
+#include "xla/pjrt/raw_buffer.h"
+#include "xla/service/platform_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/literal_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+using ::testing::NotNull;
+using ::testing::SizeIs;
+
+HloInstruction* FindInstruction(const HloModule* module, HloOpcode opcode) {
+  for (const HloComputation* computation : module->computations()) {
+    if (HloInstruction* instruction =
+            hlo_query::FindInstruction(computation, opcode)) {
+      return instruction;
+    }
+  }
+  return nullptr;
+}
+
+absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileExecutable(
+    absl::string_view program, xla::PjRtClient& client,
+    xla::CompileOptions compile_options = xla::CompileOptions()) {
+  TF_ASSIGN_OR_RETURN(auto hlo_module,
+                      ParseAndReturnUnverifiedModule(program, {}));
+
+  xla::XlaComputation xla_computation(hlo_module->ToProto());
+  return client.CompileAndLoad(xla_computation, compile_options);
+}
+
+// Register a mock "mosaic_gpu" custom call op for NvshmemMemoryTest, since
+// mosaic_gpu is defined in JAX and won't be available to the unit test.
+static absl::Status MockMosaicGpu(ffi::AnyBuffer arg,
+                                  ffi::Result<ffi::AnyBuffer> ret,
+                                  absl::string_view module) {
+  return absl::OkStatus();
+}
+
+XLA_FFI_DEFINE_HANDLER(kMockMosaicGpu, MockMosaicGpu,
+                       ffi::Ffi::Bind()
+                           .Arg<ffi::AnyBuffer>()
+                           .Ret<ffi::AnyBuffer>()
+                           .Attr<absl::string_view>("module"));
+
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu",
+                         PlatformUtil::CanonicalPlatformName("GPU").value(),
+                         kMockMosaicGpu);
+
+// Verify that the client can initialize NVSHMEM and that buffers used by
+// mosaic_gpu custom calls are assigned to the collective memory space.
+TEST(StreamExecutorGpuClientTest, NvshmemMemoryTest) {
+  static constexpr char const* kProgram = R"(
+    HloModule ffi_handler
+    ENTRY main {
+      param = s32[1,4]{1,0} parameter(0)
+      reshape = s32[4]{0} reshape(param)
+      ROOT %custom-call = s32[4] custom-call(param),
+          custom_call_target="mosaic_gpu",
+          api_version=API_VERSION_TYPED_FFI,
+          backend_config={"custom_call_backend_config": {"attributes": "{module = \"nvshmem\"}"}}
+    })";
+  // NVSHMEM requires one GPU per process.
+  GpuClientOptions client_options;
+  client_options.node_id = 0;
+  client_options.allowed_devices = {0};
+  client_options.num_nodes = 1;
+  client_options.kv_store = std::make_shared<InMemoryKeyValueStore>();
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
+                          GetStreamExecutorGpuClient(client_options));
+  xla::CompileOptions options;
+  options.executable_build_options.mutable_debug_options()
+      ->set_xla_gpu_experimental_enable_nvshmem(true);
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtLoadedExecutable> executable,
+                          CompileExecutable(kProgram, *client, options));
+  std::vector<int32_t> data{1, 2, 3, 4};
+  Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {1, 4},
+                                                    /*minor_to_major=*/{1, 0});
+  shape.mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
+
+  PjRtDevice* const device = client->addressable_devices()[0];
+  TF_EXPECT_OK(device->default_memory_space());
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<PjRtBuffer> input,
+      client->BufferFromHostBuffer(
+          data.data(), shape.element_type(), shape.dimensions(),
+          /*byte_strides=*/std::nullopt,
+          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+          /*on_done_with_host_buffer=*/nullptr, *device->default_memory_space(),
+          /*device_layout=*/nullptr));
+  EXPECT_EQ(input->memory_space()->kind(), "device");
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<std::vector<absl::string_view>> memory_kinds,
+      executable->GetOutputMemoryKinds());
+  EXPECT_EQ(memory_kinds.size(), 1);
+  EXPECT_EQ(memory_kinds[0].size(), 1);
+  EXPECT_EQ(memory_kinds[0][0], "device");
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> result,
+      executable->Execute({{input.get()}}, ExecuteOptions()));
+  std::vector<std::unique_ptr<PjRtBuffer>>& result_buffers = result[0];
+  EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "device");
+  Shape result_shape = result_buffers[0]->on_device_shape();
+  int64_t memory_space = result_shape.layout().memory_space();
+  EXPECT_EQ(memory_space, 1);
+}
+
+absl::Status UserBufferWithNvshmemMallocTestBody(const int node_id,
+                                                 const int num_nodes) {
+  const absl::string_view kModuleStr = R"(
+HloModule test
+apply_op {
+x = u32[] parameter(0)
+y = u32[] parameter(1)
+ROOT apply_op = u32[] add(x, y)
+}
+ENTRY test_computation {
+id = u32[] replica-id()
+ROOT all-reduce = u32[] all-reduce(id), to_apply=apply_op
+}
+)";
+  std::unique_ptr<xla::DistributedRuntimeService> service;
+  if (node_id == 0) {
+    xla::CoordinationServiceImpl::Options service_options;
+    service_options.num_nodes = num_nodes;
+    TF_ASSIGN_OR_RETURN(service, xla::GetDistributedRuntimeService(
+                                     "[::]:12346", service_options));
+  }
+
+  xla::DistributedRuntimeClient::Options distributed_options;
+  distributed_options.node_id = node_id;
+  distributed_options.init_timeout = absl::Seconds(120);
+  auto distributed_client =
+      GetDistributedRuntimeClient("127.0.0.1:12346", distributed_options);
+  TF_QCHECK_OK(distributed_client->Connect());
+  GpuClientOptions client_options;
+  client_options.node_id = node_id;
+  client_options.allowed_devices = {node_id};
+  client_options.num_nodes = num_nodes;
+  client_options.kv_store =
+      GetDistributedKeyValueStore(distributed_client, /*key_prefix=*/"gpu:");
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
+                      GetStreamExecutorGpuClient(client_options));
+  xla::CompileOptions options;
+  options.executable_build_options.set_num_replicas(num_nodes);
+  options.executable_build_options.mutable_debug_options()
+      ->set_xla_gpu_experimental_enable_nvshmem(true);
+  options.executable_build_options.mutable_debug_options()
+      ->set_xla_gpu_enable_nccl_user_buffers(true);
+
+  TF_ASSIGN_OR_RETURN(auto hlo_module,
+                      ParseAndReturnUnverifiedModule(kModuleStr, {}));
+  xla::XlaComputation xla_computation(hlo_module->ToProto());
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
+                      client->CompileAndLoad(xla_computation, options));
+
+  // Verify that the collective memory space is used.
+  TF_ASSIGN_OR_RETURN(auto modules, executable->GetHloModules());
+  HloInstruction* all_reduce_start =
+      FindInstruction(modules[0].get(), HloOpcode::kAllReduceStart);
+  EXPECT_THAT(all_reduce_start, NotNull());
+  EXPECT_EQ(all_reduce_start->shape().layout().memory_space(), 1);
+  EXPECT_THAT(all_reduce_start->operands(), SizeIs(1));
+  const HloInstruction* input = all_reduce_start->operand(0);
+  EXPECT_EQ(input->shape().layout().memory_space(), 1);
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> results,
+      executable->Execute(/*argument_handles=*/{{}}, /*options=*/{}));
+  EXPECT_EQ(results.size(), 1);
+  EXPECT_EQ(results[0].size(), 1);
+  TF_ASSIGN_OR_RETURN(auto literal, results[0][0]->ToLiteralSync());
+  if (node_id == 0) {
+    LiteralTestUtil::ExpectR1Equal<uint32_t>({10, 15, 11, 16}, *literal);
+  } else if (node_id == 1) {
+    LiteralTestUtil::ExpectR1Equal<uint32_t>({20, 25, 21, 26}, *literal);
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace
+}  // namespace xla
+
+int main(int argc, char* argv[]) {
+  int node_id = -1;
+  int num_nodes = -1;
+  std::vector<tsl::Flag> flag_list = {
+      tsl::Flag("node_id", &node_id, "Node ID for multiprocess tests."),
+      tsl::Flag("num_nodes", &num_nodes,
+                "Number of nodes for multiprocess tests."),
+  };
+  std::string usage = tsl::Flags::Usage(argv[0], flag_list);
+  tsl::Flags::Parse(&argc, argv, flag_list);
+  testing::InitGoogleTest(&argc, argv);
+  if (node_id >= 0) {
+    absl::Status result =
+        xla::UserBufferWithNvshmemMallocTestBody(node_id, num_nodes);
+    if (!result.ok()) {
+      LOG(ERROR) << result;
+    }
+    return result.raw_code();
+  }
+  return RUN_ALL_TESTS();
+}
diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index de44c6df57fdcc..218adadff90a5e 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -115,7 +115,9 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/service:buffer_assignment",
         "//xla/service:hlo_value",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
index c15682a45aeb91..930ef2126c250b 100644
--- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
+++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
@@ -250,9 +250,12 @@ absl::StatusOr<std::unique_ptr<BufferAssignment>> RunBufferAssignment(
   ScopedAnnotation annotation(Phase("XlaBufferAssignment", module));
   const DebugOptions& options = module->config().debug_options();
 
-  BufferAssigner::Colorer colorer = options.xla_gpu_enable_nccl_user_buffers()
-                                        ? CollectiveColorer()
-                                        : BufferAssigner::DefaultColorer();
+  BufferAssigner::Colorer colorer =
+      (options.xla_gpu_enable_nccl_user_buffers() ||
+       options.xla_gpu_experimental_enable_nvshmem())
+          ? CollectiveColorer(options.xla_gpu_enable_nccl_user_buffers(),
+                              options.xla_gpu_experimental_enable_nvshmem())
+          : BufferAssigner::DefaultColorer();
 
   std::optional color =
       options.xla_gpu_temp_buffer_use_separate_color()
diff --git a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h
index 32d1906c30dc07..220e15ede9cd3f 100644
--- a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h
+++ b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
 #include "xla/hlo/analysis/hlo_alias_analysis.h"
 #include "xla/hlo/analysis/hlo_ordering.h"
@@ -34,8 +35,10 @@ inline constexpr int64_t kTempBufferMemorySpaceColor = 2;
 // Set memory space to kCollectiveMemorySpaceColor for all allocations used by
 // all-reduce, all-gather, and reduce-scatter. This memory space maps to
 // collective memory using ncclMemAlloc in the runtime.
-inline BufferAssigner::Colorer CollectiveColorer() {
-  return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
+inline BufferAssigner::Colorer CollectiveColorer(bool use_user_buffers,
+                                                 bool use_nvshmem) {
+  return [use_user_buffers, use_nvshmem](HloAliasAnalysis* alias_analysis,
+                                         const HloOrdering&) {
     static const auto* kSupportedOpcodes = new absl::flat_hash_set<HloOpcode>{
         HloOpcode::kAllReduce,
         HloOpcode::kAllReduceStart,
@@ -49,12 +52,25 @@ inline BufferAssigner::Colorer CollectiveColorer() {
         HloOpcode::kCollectivePermuteDone,
         HloOpcode::kAllToAll,
     };
+    auto is_mosaic_gpu_nvshmem_instr = [](const HloInstruction* instr) {
+      return instr->opcode() == HloOpcode::kCustomCall &&
+             (instr->custom_call_target() == "mosaic_gpu" ||
+              instr->custom_call_target() == "mosaic_gpu_v2") &&
+             instr->raw_backend_config_string().find("nvshmem") !=
+                 std::string::npos;
+    };
     auto is_collective_memory_instr = [&](const HloInstruction* instr) {
-      return kSupportedOpcodes->contains(instr->opcode()) ||
-             // opcode or async wrapped opcode is in kSupportedOpcodes.
-             ((instr->opcode() == HloOpcode::kAsyncStart ||
-               instr->opcode() == HloOpcode::kAsyncDone) &&
-              kSupportedOpcodes->contains(instr->async_wrapped_opcode()));
+      if (use_user_buffers) {
+        return kSupportedOpcodes->contains(instr->opcode()) ||
+               // opcode or async wrapped opcode is in kSupportedOpcodes.
+               ((instr->opcode() == HloOpcode::kAsyncStart ||
+                 instr->opcode() == HloOpcode::kAsyncDone) &&
+                kSupportedOpcodes->contains(instr->async_wrapped_opcode()));
+      }
+      if (use_nvshmem) {
+        return is_mosaic_gpu_nvshmem_instr(instr);
+      }
+      return false;
     };
     auto has_collective_memory_in_uses = [&](const HloValue* input_alias) {
       // If any use is a collective instruction, we must color the value to use
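The flag interplay in the updated colorer is easy to misread, so here is a condensed restatement as a sketch. The helper names are hypothetical stand-ins for the lambdas above, not real functions:

// With user buffers enabled, the classic collective-opcode check applies;
// with only NVSHMEM enabled, only mosaic_gpu/mosaic_gpu_v2 custom calls whose
// backend config mentions "nvshmem" are colored; with neither, nothing is.
bool UsesCollectiveMemorySpace(const HloInstruction* instr,
                               bool use_user_buffers, bool use_nvshmem) {
  if (use_user_buffers) return IsCollectiveMemoryOpcode(instr);
  if (use_nvshmem) return IsMosaicGpuNvshmemCustomCall(instr);
  return false;
}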
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index 31007c9e0d54d6..813fe8813a26f3 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -551,6 +551,10 @@ message DebugOptions {
   // Pre-existing block-level fusions are left unmodified.
   bool xla_gpu_experimental_enable_fusion_block_level_rewriter = 334;
 
+  // Enables NVSHMEM. Must be set via the XLA_FLAGS environment variable before
+  // the XLA client is initialized; it cannot be set through
+  // HLO Config->ExecutionOptions alone.
+  bool xla_gpu_experimental_enable_nvshmem = 387;
+
   // Enable the pass that splits GEMMs that underutilize the GPU load by
   // splitting the K dimension using a heuristic.
   bool xla_gpu_experimental_enable_split_k_rewrite = 386;
@@ -1202,7 +1206,7 @@ message DebugOptions {
 
   // Note: when adding a new flag, please add it to one of the hardware-specific
   // or hardware-agnostic sections at the top of this proto message.
-  // Next id: 387
+  // Next id: 388
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.

From 80ddf326c6f873a4637ed33d991e0c415784886f Mon Sep 17 00:00:00 2001
From: Tom Natan
Date: Tue, 15 Apr 2025 04:31:14 -0700
Subject: [PATCH 0802/1324] Clean up unused arg in
 `xla::sdy::setFrontendAttribute`

PiperOrigin-RevId: 747816802
---
 .../ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc | 6 +++---
 third_party/xla/xla/service/spmd/shardy/utils.cc         | 3 +--
 third_party/xla/xla/service/spmd/shardy/utils.h          | 3 +--
 3 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc
index 556b5088ec9c76..e5a828bfbab81a 100644
--- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc
+++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc
@@ -160,9 +160,9 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
                          "sdy.");
            return mlir::WalkResult::advance();
          }
-          xla::sdy::setFrontendAttribute(
-              callee_module, xla::sdy::kMeshesRoundTripAttr,
-              sdy_meshes_round_trip_attr, /*escapeAttr=*/false);
+          xla::sdy::setFrontendAttribute(callee_module,
+                                         xla::sdy::kMeshesRoundTripAttr,
+                                         sdy_meshes_round_trip_attr);
         }
 
         absl::StatusOr compile_future =
diff --git a/third_party/xla/xla/service/spmd/shardy/utils.cc b/third_party/xla/xla/service/spmd/shardy/utils.cc
index e0ff8017583062..1a71dd491a89b2 100644
--- a/third_party/xla/xla/service/spmd/shardy/utils.cc
+++ b/third_party/xla/xla/service/spmd/shardy/utils.cc
@@ -133,8 +133,7 @@ void setFuncArgFrontendAttrs(FuncOp funcOp, unsigned int index,
 
 }  // namespace
 
-void setFrontendAttribute(Operation* op, StringRef name, Attribute value,
-                          bool) {
+void setFrontendAttribute(Operation* op, StringRef name, Attribute value) {
   SmallVector<NamedAttribute> existingAttributes =
       getExistingFrontendAttributes(getFrontendAttrs(op), "");
   setFrontendAttribute(existingAttributes, name, value);
diff --git a/third_party/xla/xla/service/spmd/shardy/utils.h b/third_party/xla/xla/service/spmd/shardy/utils.h
index 54134ce9986ed9..a04daa33541f96 100644
--- a/third_party/xla/xla/service/spmd/shardy/utils.h
+++ b/third_party/xla/xla/service/spmd/shardy/utils.h
@@ -49,9 +49,8 @@ mlir::DictionaryAttr getFuncArgFrontendAttrs(mlir::func::FuncOp funcOp,
 // Adds `name` into the frontend attributes of `op` with value `value`. If
 // `name` already exists, it will be overwritten. Note that `value` will be
 // turned into a `StringAttr`.
-// TODO(tomnatan): cleanup `escapeAttr`
 void setFrontendAttribute(mlir::Operation* op, mlir::StringRef name,
-                          mlir::Attribute value, bool escapeAttr = true);
+                          mlir::Attribute value);
 
 // Adds `name` into the argument at `argNum`'s frontend attributes of `funcOp`
 // with value `value`. If `name` already exists, it will be overwritten. Note
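After this cleanup, call sites simply drop the old fourth argument. A minimal usage sketch, assuming an MLIR module op is in scope (the attribute value is illustrative only):

mlir::OpBuilder builder(moduleOp.getContext());
// `value` is turned into a StringAttr internally; there is no escaping flag
// to pass anymore.
xla::sdy::setFrontendAttribute(moduleOp, xla::sdy::kMeshesRoundTripAttr,
                               builder.getStringAttr("meshes"));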
Note

From ae6e19ef66ded2f308e50a51a717407ce474913e Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Tue, 15 Apr 2025 04:51:24 -0700
Subject: [PATCH 0803/1324] Reverts 5ea25afd71164ced12527bfbf8a12330a201d416

PiperOrigin-RevId: 747821896
---
 .../xla/xla/backends/gpu/runtime/cub_sort_thunk.cc    | 8 ++++++--
 third_party/xla/xla/service/gpu/build_defs.bzl        | 3 +++
 third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc | 9 +++++++++
 third_party/xla/xla/service/gpu/cub_sort_kernel.h     | 4 ++++
 .../xla/xla/service/gpu/tests/gpu_cub_sort_test.cc    | 2 +-
 5 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc
index 8195bf0a1ab226..f4b35813f4f4c2 100644
--- a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc
+++ b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc
@@ -241,8 +241,8 @@ absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateCubSortRunner(
 }
 
 // Returns an interface for calling CubSortPairs on the given key and value
-// types. key_type can be only unsigned integer types. value_type can be any
-// type of 16/32/64 bit width.
+// types. key_type can be any unsigned integer type or F32. value_type can be
+// any type of 16/32/64 bit width.
 absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateCubSortRunner(
     PrimitiveType key_type, PrimitiveType value_type) {
   int value_width = primitive_util::BitWidth(value_type);
@@ -260,6 +260,10 @@ absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateCubSortRunner(
   if (key_type == U64 && value_width == 32) sort_fn = CubSortPairs_u64_b32;
   if (key_type == U64 && value_width == 64) sort_fn = CubSortPairs_u64_b64;
 
+  if (key_type == F32 && value_width == 16) sort_fn = CubSortPairs_f32_b16;
+  if (key_type == F32 && value_width == 32) sort_fn = CubSortPairs_f32_b32;
+  if (key_type == F32 && value_width == 64) sort_fn = CubSortPairs_f32_b64;
+
   if (sort_fn == nullptr) {
     return InvalidArgument(
         "Unsupported key/value type combination for CubSortPairs: %s/%s",
diff --git a/third_party/xla/xla/service/gpu/build_defs.bzl b/third_party/xla/xla/service/gpu/build_defs.bzl
index 9ae3e2ab0b08f9..e5f3036865a225 100644
--- a/third_party/xla/xla/service/gpu/build_defs.bzl
+++ b/third_party/xla/xla/service/gpu/build_defs.bzl
@@ -42,6 +42,9 @@ def get_cub_sort_kernel_types(name = ""):
         "u8_b16",
         "u8_b32",
         "u8_b64",
+        "f32_b16",
+        "f32_b32",
+        "f32_b64",
     ]
 
 def build_cub_sort_kernels(name, types, local_defines = [], **kwargs):
diff --git a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc
index 507bd4d7da953f..e5d9db3f52ee75 100644
--- a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc
+++ b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc
@@ -244,6 +244,15 @@ XLA_CUB_DEFINE_SORT_PAIRS(u32_b32, uint32_t, uint32_t)
 #ifdef CUB_TYPE_U32_B64
 XLA_CUB_DEFINE_SORT_PAIRS(u32_b64, uint32_t, uint64_t)
 #endif
+#ifdef CUB_TYPE_F32_B16
+XLA_CUB_DEFINE_SORT_PAIRS(f32_b16, float, uint16_t)
+#endif
+#ifdef CUB_TYPE_F32_B32
+XLA_CUB_DEFINE_SORT_PAIRS(f32_b32, float, uint32_t)
+#endif
+#ifdef CUB_TYPE_F32_B64
+XLA_CUB_DEFINE_SORT_PAIRS(f32_b64, float, uint64_t)
+#endif
 
 // Pairs with 64-bit key.
 #ifdef CUB_TYPE_U64_B16
diff --git a/third_party/xla/xla/service/gpu/cub_sort_kernel.h b/third_party/xla/xla/service/gpu/cub_sort_kernel.h
index 29b163e7b1bf0b..627dd7ef079b84 100644
--- a/third_party/xla/xla/service/gpu/cub_sort_kernel.h
+++ b/third_party/xla/xla/service/gpu/cub_sort_kernel.h
@@ -67,6 +67,10 @@ XLA_CUB_DECLARE_SORT_PAIRS(u64_b16)
 XLA_CUB_DECLARE_SORT_PAIRS(u64_b32)
 XLA_CUB_DECLARE_SORT_PAIRS(u64_b64)
 
+XLA_CUB_DECLARE_SORT_PAIRS(f32_b16)
+XLA_CUB_DECLARE_SORT_PAIRS(f32_b32)
+XLA_CUB_DECLARE_SORT_PAIRS(f32_b64)
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
index d37d322ddb32f2..eb22488d981ffa 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
@@ -306,7 +306,7 @@ ENTRY m {
 INSTANTIATE_TEST_SUITE_P(
     CubSort, CubSortPairsTest,
-    ::testing::Combine(::testing::Values(U8, U16, U32, U64),
+    ::testing::Combine(::testing::Values(U8, U16, U32, U64, F32),
                        ::testing::Values(F16, F32, F64), ::testing::Bool(),
                        ::testing::Values(1, 10)),
     [](const ::testing::TestParamInfo& info) {

From 3fc7a2e7debe61e049345318be31aa1733105e45 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Tue, 15 Apr 2025 04:53:25 -0700
Subject: [PATCH 0804/1324] [XLA:GPU] Enable remaining passing tests in
 `fusion_emitter_device_legacy_port_test.cc` and add TODOs for the others.

Removed 15 additional tests that were no longer relevant as well. All the
removed tests were confirmed to work manually before being deleted. We list
them here, along with a reason for deleting them:

* `TritonGemmTestWithSplitK.WorksWhenKIsDivisibleByBlockKButNotByBlockKTimesSplitK`:
  this test just works, but the condition being tested is unclear, and easily
  falsifiable based on the documented expectations.
* `TritonGemmTest.CanCodegenNonBatchedDotWithConcatenationCorrectly`: covered
  by an existing test (`TritonEmitterTest.ConcatenateOfNestsIsEmittedCorrectly`
  in `fusion_emitter_device_test.cc`);
* `TritonGemmTest.CanCodegenBatchedDotWithConcatenationCorrectly`: this is a
  combination of two orthogonal features: supporting batch dimensions, and
  supporting `concatenate`s. Batch dimensions are well-covered, and the same
  argument as for `TritonGemmTest.CanCodegenNonBatchedDotWithConcatenationCorrectly`
  holds for `concatenate`;
* `TritonGemmTest.TritonCompilerDoesNotFailOnConstants`: this is a regression
  test, but it's not written in a very reliable way: it basically only tested
  that autotuning wouldn't fail. There are already similar tests that
  `broadcast` constants whose lowering is tested and that would catch the same
  issue;
* `TritonGemmTest.TritonEmitterCanHandleTransposes`: this is meant to be a
  standalone test for `transpose`. This is already covered by
  `TritonEmitterTest.Transpose3D`, for instance.
* `TritonGemmTest.DoF32F32`: covered by existing `AlgUnset` tests;
* `TritonGemmTest.DoAddConstantToScalarAndBroadcastThat`: just another test of
  compositionality; it does not bring much coverage;
* `TritonGemmTest.BroadcastOfVectorParameterIsFused`: this test starts from a
  post-optimized HLO, so it's not even testing any fusion behaviour.
As a `RunAndCompare` test, it's only testing compositionality again; * `TritonGemmTest.BroadcastOfScalarWorksCorrectly`: covered more extensively by `TritonEmitterTest.Multiple0DBroadcastsAreSupported`; * `TritonTest.Fp8LoweringIsSupportedPostHopper`: covered by `DotUnsetAlgorithmEmitterTest/UnsetAlgorithmIsEmittedCorrectly`; * `TritonTest.BF16ToFP8EndToEnd`: covered by `DotUnsetAlgorithmEmitterTest/UnsetAlgorithmIsEmittedCorrectly`; * `TritonTest.FP8ToFP8EndToEnd`: covered by `DotUnsetAlgorithmEmitterTest/UnsetAlgorithmIsEmittedCorrectly`; * `TritonGemmTest.S8xS8`: covered by `DotUnsetAlgorithmEmitterTest/UnsetAlgorithmIsEmittedCorrectly`; * `TritonGemmTest.FP8DotSmallTileDoesNotCrash`: covered by `DotUnsetAlgorithmEmitterTest/UnsetAlgorithmIsEmittedCorrectly`; * `TritonTest.FloatToSignedIntConversion`: the (legacy) HLO in that test was lowered to a `dot` fusion that would crash further down the compilation pipeline. We should make a proper and extensive `convert` test in the future, but this wasn't it. PiperOrigin-RevId: 747822366 --- .../gpu/codegen/triton/fusion_emitter.cc | 4 + .../fusion_emitter_device_legacy_port_test.cc | 848 +++++------------- 2 files changed, 236 insertions(+), 616 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc index f93698bd992c99..75ab9257215452 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -1388,6 +1388,10 @@ absl::StatusOr> EmitGeneric( const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, mlir::FunctionOpInterface fn, const BlockLevelParameters& block_level_parameters) { + if (VLOG_IS_ON(6)) { + VLOG(6) << "Emitting Triton IR for fusion\n" + << ExtractInstructionIntoNewModule(*fusion)->ToString(); + } const HloComputation* computation = fusion->fused_instructions_computation(); SymbolicTileAnalysisOrError symbolic_tile_analysis_or = SymbolicTileAnalysis::AnalyzeComputation( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc index d652df9e2332f8..8811f4edde2bd2 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_port_test.cc @@ -74,12 +74,26 @@ struct ModuleAndNestedFusionMetadata { BlockLevelParameters block_level_parameters; }; +// Returns the "real" root instruction of a computation, which is either the +// root instruction itself, or the first instruction feeding into the root that +// is not a bitcast. +HloInstruction* GetNonBitcastRoot(const HloComputation* computation) { + HloInstruction* root = computation->root_instruction(); + while (root->opcode() == HloOpcode::kBitcast) { + root = root->mutable_operand(0); + } + return root; +} + class TritonTest : public GpuCodegenTest { public: DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options .set_xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms(true); + // Disable autotuning by default, re-enable it on a per-test basis in order + // to avoid unnecessary slowness. 
+    debug_options.set_xla_gpu_autotune_level(0);
     return debug_options;
   }
 
@@ -99,11 +113,8 @@ class TritonTest : public GpuCodegenTest {
             GpuComputeCapability())) {
       return stream_executor::GpuComputeCapability{
           device_desc().rocm_compute_capability()};
-    } else {
-      return stream_executor::GpuComputeCapability{
-          stream_executor::CudaComputeCapability{
-              stream_executor::CudaComputeCapability::kAmpere, 0}};
     }
+    return se::CudaComputeCapability::Ampere();
   }
 
   // Returns the module, its fusion computation and associated block level
@@ -172,35 +183,6 @@ class TritonGemmTestWithSplitK : public TritonGemmTest {
   }
 };
 
-// TODO(b/393299275): requires enabling mixed-type dots for f8xf8->bf16.
-TEST_F(TritonGemmTest, DISABLED_FP8DotSmallTileDoesNotCrash) {
-  if (!GetCudaComputeCapability().IsAtLeastHopper()) {
-    GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs.";
-  }
-
-  constexpr absl::string_view kHloText = R"(
-triton_dot {
-  p0 = f8e4m3fn[32,32]{1,0} parameter(0)
-  p1 = f8e4m3fn[32,32]{1,0} parameter(1)
-  ROOT dot = bf16[32,32]{1,0} dot(p0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY e {
-  p0 = f8e4m3fn[32,32]{1,0} parameter(0)
-  p1 = f8e4m3fn[32,32]{1,0} parameter(1)
-  ROOT _ = bf16[32,32] fusion(p0, p1), kind=kCustom, calls=triton_dot,
-    backend_config={"fusion_backend_config": {kind: "__triton_gemm",
-      triton_gemm_config: {"block_m":16,"block_n":16,"block_k":16,
-                           "split_k":1,"num_stages":2,"num_warps":2,
-                           "num_ctas":1}}}
-})";
-  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata,
-                          GetModuleAndNestedFusionMetadata(kHloText));
-  EXPECT_TRUE(Run(std::move(module_and_metadata.module),
-                  /*run_hlo_passes=*/false));
-}
-
 TEST_F(TritonTest, TestGemmWithTrivialNonContractingDimension) {
   constexpr absl::string_view kHloText = R"(
 HloModule t, is_scheduled=true
@@ -312,6 +294,8 @@ CHECK: tt.dot {{.*}} : tensor<16x32xf32> * tensor<32x64xf32> -> tensor<16x64xf32>
 )"));
 }
 
+// TODO(b/393299275): this requires adding support for dynamic-slice in the
+// generic Triton emitter.
 TEST_F(TritonTest, DISABLED_CodegenDynamicSliceWithCorrectOffsets) {
   // The start index(es) for the non-majormost dimension(s) are constant zero(s)
   // because we don't support dynamic slice on those dimensions.
@@ -477,19 +461,22 @@ CHECK: mma
 )");
 }
 
-TEST_F(TritonGemmTest, DISABLED_FailIfTooMuchShmem) {
-  if (std::holds_alternative<se::RocmComputeCapability>(
-          GpuComputeCapability())) {
-    GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet.";
-  }
-  constexpr absl::string_view kHloText = R"(
-HloModule module, is_scheduled=true
+// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be
+// moved to deviceless test file.
+TEST_F(TritonGemmTest, FailIfTooMuchShmem) {
+  auto cc = se::CudaComputeCapability::Ampere();
+  const se::DeviceDescription device_info =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  llvm::LLVMContext llvm_ctx;
+  llvm::Module llvm_module("module", llvm_ctx);
+  mlir::MLIRContext mlir_context;
+
+  constexpr absl::string_view kHloTextTemplate = R"(
 triton_gemm_dot {
   p0 = s8[1024,1024] parameter(0)
   p1 = f32[1024,1024] parameter(1)
   c0 = f32[1024,1024] convert(p0)
-  ROOT dot.0 = f32[1024,1024] dot(c0, p1),
+  ROOT dot = f32[1024,1024] dot(c0, p1),
     lhs_contracting_dims={1}, rhs_contracting_dims={0}
 }
 
 ENTRY entry {
   p0 = s8[1024,1024] parameter(0)
   p1 = f32[1024,1024] parameter(1)
   ROOT r = f32[1024,1024] fusion(p0, p1), kind=kCustom,
     calls=triton_gemm_dot,
-    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+    backend_config={"fusion_backend_config": {kind: "__triton_gemm",
+      triton_gemm_config: {"block_m":$0,"block_n":$1,"block_k":$2,
+                           "split_k":1,"num_stages":$3,"num_warps":4,
+                           "num_ctas":1}}}
 })";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
-                          ParseAndReturnVerifiedModule(kHloText));
-  HloFusionInstruction* triton_dot_fusion = Cast<HloFusionInstruction>(
-      hlo_module->entry_computation()->root_instruction());
-  const se::DeviceDescription dev_info =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  llvm::LLVMContext llvm_ctx;
-  llvm::Module llvm_module("module", llvm_ctx);
-  mlir::MLIRContext mlir_context;
-
-  auto backend_config_or =
-      triton_dot_fusion->backend_config<GpuBackendConfig>();
-  TF_ASSERT_OK(backend_config_or);
-  GpuBackendConfig& backend_config = *backend_config_or;
-
-  FusionBackendConfig& fusion_backend_config =
-      *backend_config.mutable_fusion_backend_config();
-  auto& config = *fusion_backend_config.mutable_triton_gemm_config();
-  config.set_block_m(16);
-  config.set_block_n(32);
-  config.set_block_k(512);
-  config.set_split_k(1);
-  config.set_num_ctas(1);
-  config.set_num_warps(8);
-  config.set_num_stages(4);
-
-  TF_ASSERT_OK(triton_dot_fusion->set_backend_config(backend_config));
-
-  BlockLevelParameters block_level_parameters;
-  block_level_parameters.num_ctas = 1;
-  block_level_parameters.num_stages = 4;
-  block_level_parameters.num_warps = 8;
+  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module1_and_metadata,
+                          GetModuleAndNestedFusionMetadata(absl::Substitute(
+                              kHloTextTemplate, 16, 32, 512, 8)));
 
+  const HloFusionInstruction* fusion1 = Cast<HloFusionInstruction>(
+      module1_and_metadata.computation->FusionInstruction());
   EXPECT_THAT(
-      TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info,
-                    block_level_parameters, &llvm_module, mlir_context),
+      TritonWrapper("test_fn", fusion1, cc, device_info,
+                    module1_and_metadata.block_level_parameters, &llvm_module,
+                    mlir_context),
       StatusIs(tsl::error::RESOURCE_EXHAUSTED,
                ::testing::HasSubstr("Shared memory size limit exceeded")));
 
-  config.set_block_m(64);
-  config.set_block_n(128);
-  config.set_block_k(128);
-  block_level_parameters.num_stages = 1;
-  TF_ASSERT_OK(triton_dot_fusion->set_backend_config(backend_config));
+  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module2_and_metadata,
+                          GetModuleAndNestedFusionMetadata(absl::Substitute(
+                              kHloTextTemplate, 64, 128, 128, 1)));
+
+  const HloFusionInstruction* fusion2 = Cast<HloFusionInstruction>(
+      module2_and_metadata.computation->FusionInstruction());
 
   TF_ASSERT_OK_AND_ASSIGN(
       const auto result,
-      TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info,
-                    block_level_parameters, &llvm_module, mlir_context));
+      TritonWrapper("test_fn", fusion2, cc, device_info,
+                    module2_and_metadata.block_level_parameters, &llvm_module,
+                    mlir_context));
 
   // Use optin shared memory which is > shared_memory_per_block.
-  EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block());
-}
-
-TEST_F(TritonGemmTestWithSplitK,
-       DISABLED_WorksWhenKIsDivisibleByBlockKButNotByBlockKTimesSplitK) {
-  // The condition mentioned in the test name is fulfilled by
-  // GemmKey(16, 64, 256, 8, 1, 4), which was part of the default configs for
-  // Ampere at the time of the addition of this test case.
-  constexpr absl::string_view kHloText = R"(
-HloModule extracted
-
-ENTRY e {
-  a = f16[16,5120]{1,0} parameter(0)
-  b = s8[5120,10240]{1,0} parameter(1)
-  converted_b = f16[5120,10240]{1,0} convert(b)
-  ROOT r = f16[16,10240]{1,0} dot(a, converted_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  // This check tests if Triton is used at all plus it runs GemmFusionAutotuner,
-  // which verifies if the generated kernels can run without errors such as
-  // CUDA_ERROR_ILLEGAL_ADDRESS.
-  MatchOptimizedHlo(kHloText, R"(
-; CHECK: ENTRY
-; CHECK-NEXT: parameter
-; CHECK-NEXT: parameter
-; CHECK-NEXT: fusion(
-; CHECK-SAME: kind=kCustom
-; CHECK-PTX-SAME: "block_m":
-  )");
-
-  // Not doing a comparison here, because the input matrices are quite big.
-  // If I reduce their size then they can no longer trigger the error, that I
-  // want to avoid with this test case.
+  EXPECT_GT(result.shmem_bytes, device_info.shared_memory_per_block());
 }
 
+// TODO(b/393299275): there is a miscompile here.
 TEST_F(TritonGemmTest, DISABLED_MultipleDims) {
   constexpr absl::string_view kHloText = R"(
 HloModule t
@@ ENTRY e {
 
   MatchOptimizedHlo(kHloText, R"(
 ; CHECK: ENTRY
-; CHECK-NEXT: parameter
-; CHECK-NEXT: parameter
+; CHECK-NOT: convert
 ; CHECK-NEXT: fusion(
 ; CHECK-SAME: kind=kCustom
-; CHECK-PTX-SAME: "block_m":
+; CHECK-SAME: "__triton_nested_gemm_fusion"
 )");
 
   EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
@@ -636,21 +569,33 @@ ENTRY e {
                             ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
-// TODO(b/393299275): requires enabling mixed-type dots for s8xs8->s32.
-TEST_F(TritonGemmTest, DISABLED_S8xS8) {
-  constexpr absl::string_view kHloText = R"(
-HloModule t
-
-ENTRY f {
-  x = s8[1024,1024]{1,0} parameter(0)
-  y = s8[1024,1024]{1,0} parameter(1)
-  ROOT z = s32[1024,1024]{1,0} dot(x, y),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-  MatchOptimizedHlo(kHloText, "CHECK: __triton_nested_gemm_fusion");
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
-}
-
+// TODO(b/393299275, b/410085031): requires canonicalizing the transpose in
+// order to be able to go through tile constraints. We end up trying to
+// propagate a tile with sizes (32, 32) upwards through the following ops:
+//
+//   p0 = pred[3,122,96,12]{3,2,1,0} parameter(0)
+//   transpose = pred[3,96,12,122]{3,2,1,0} transpose(p0), dimensions={0,2,3,1}
+//   bitcast = pred[3456,122]{1,0} bitcast(transpose)
+//
+// Unfortunately, there is no way to propagate such tile sizes through the
+// bitcast, since the trailing dimension of the reshaped dimension has size 12,
+// which is not divisible by any power of 2. BUT! The legacy emitter also has to
+// work around this problem, since it has to generate a tensor pointer, which
+// requires coming up with a tile-like structure for the parameter load.
+//
+// The reason this works is that we can actually rewrite the HLO to collapse
+// the dimensions of size 96 and 12 at every step, and that those dimensions are
+// initially contiguous:
+//
+//   p0 = pred[3,122,1152]{3,2,1,0} parameter(0)
+//   transpose = pred[3,1152,122]{3,2,1,0} transpose(p0), dimensions={0,2,1}
+//   bitcast = pred[3456,122]{1,0} bitcast(transpose)
+//
+// +// The reason this works is that we can actually rewrite the HLO to collapse +// the dimensions of size 96 and 12 at every step, and that those dimensions are +// initially contiguous: +// +// p0 = pred[3,122,1152]{3,2,1,0} parameter(0) +// transpose = pred[3,1152,122]{3,2,1,0} transpose(p0), dimensions={0,2,1} +// bitcast = pred[3456,122]{1,0} bitcast(transpose) +// +// (with a hoisted bitcast in the caller giving the right logical shape to the +// parameter). The resulting dimension has length 1152, which is divisible by +// 128 and therefore allows tile propagation to proceed smoothly. The legacy +// emitter essentially does this implicitly in code generation instead of +// materializing it in HLO. TEST_F(TritonGemmTest, DISABLED_SplitLhsNoncontractingTransposeRhs) { constexpr absl::string_view kHloText = R"( HloModule t @@ -670,16 +615,15 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: "block_m": +; CHECK-SAME: __triton_nested_gemm_fusion )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); } +// TODO(b/393299275): requires hoisting bitcasts through transposes. TEST_F(TritonGemmTest, DISABLED_SplitLhsNoncontracting) { constexpr absl::string_view kHloText = R"( -Hl–oModule t - ENTRY e { p0 = f32[72,72] parameter(0) bc1 = f32[4,3,3,2,4,3,3,2] reshape(p0) @@ -697,7 +641,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: "block_m": +; CHECK-SAME: __triton_nested_gemm_fusion )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -883,102 +827,21 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); } -TEST_F(TritonGemmTest, - DISABLED_CanCodegenNonBatchedDotWithConcatenationCorrectly) { - constexpr absl::string_view kHloText = R"( -ENTRY e { - parameter_0 = f32[3,10]{1,0} parameter(0) - parameter_1 = f32[10,128]{1,0} parameter(1) - parameter_2 = f32[10,256]{1,0} parameter(2) - concatenate = f32[10,384]{1,0} concatenate(parameter_1, parameter_2), dimensions={1} - ROOT dot = f32[3,384]{1,0} dot(parameter_0, concatenate), - lhs_batch_dims={}, lhs_contracting_dims={1}, - rhs_batch_dims={}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK-NOT: concatenate -; CHECK: fusion -; CHECK-SAME: kind=kCustom -; CHECK-SAME: backend_config={{.*}}"kind":"__triton_nested_gemm_fusion" -)"); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, - DISABLED_CanCodegenBatchedDotWithConcatenationCorrectly) { - constexpr absl::string_view kHloText = R"( -ENTRY e { - parameter_0 = f32[2,3,10]{2,1,0} parameter(0) - parameter_1 = f32[2,10,128]{2,1,0} parameter(1) - parameter_2 = f32[2,10,256]{2,1,0} parameter(2) - concatenate = f32[2,10,384]{2,1,0} concatenate(parameter_1, parameter_2), dimensions={2} - ROOT dot = f32[2,3,384]{2,1,0} dot(parameter_0, concatenate), - lhs_batch_dims={0}, lhs_contracting_dims={2}, - rhs_batch_dims={0}, rhs_contracting_dims={1} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK-NOT: concatenate -; CHECK: fusion -; CHECK-SAME: kind=kCustom -; CHECK-SAME: backend_config={{.*}}"kind":"__triton_nested_gemm_fusion" -)"); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, 
/*arel=*/1e-3})); -} - -TEST_F(TritonTest, DISABLED_FloatToSignedIntConversion) { - constexpr absl::string_view kHloText = R"( -HloModule t, is_scheduled=true - -triton_gemm_r { - p_0 = s8[32,32]{1,0} parameter(0) - p_1 = f16[32,32]{1,0} parameter(1) - cvt_1 = s8[32,32]{1,0} convert(p_1) - ROOT r.1 = f32[32,32]{1,0} dot(p_0, cvt_1), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -} - -ENTRY e { - p_0 = s8[32,32]{1,0} parameter(0) - p_1 = f16[32,32]{1,0} parameter(1) - ROOT triton_gemm_r = f32[32,32]{1,0} fusion(p_0, p_1), kind=kCustom, - calls=triton_gemm_r, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, - "split_k":1,"num_stages":1,"num_warps":4, - "num_ctas":1}}} -})"; - TF_EXPECT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm_r", R"( -CHECK: tt.func @triton_fn -CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> -CHECK-DAG: %[[FMIN:.*]] = arith.constant dense<-1.280000e+02> -CHECK-DAG: %[[IMIN:.*]] = arith.constant dense<-128> -CHECK-DAG: %[[FMAX:.*]] = arith.constant dense<1.270000e+02> -CHECK-DAG: %[[IMAX:.*]] = arith.constant dense<127> -CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[IN:.*]] : -CHECK: %[[CMP1:.*]] = arith.cmpf ole, %[[IN]], %[[FMIN]] -CHECK: %[[RES1:.*]] = arith.select %[[CMP1]], %[[IMIN]], %[[FPTOSI]] -CHECK: %[[CMP2:.*]] = arith.cmpf oge, %[[IN]], %[[FMAX]] -CHECK: %[[RES2:.*]] = arith.select %[[CMP2]], %[[IMAX]], %[[RES1]] -CHECK: %[[CMP3:.*]] = arith.cmpf uno, %[[IN]], %[[IN]] -CHECK: %[[RES3:.*]] = arith.select %[[CMP3]], %[[ZERO]], %[[RES2]] -})")); -} - // This tests the complexity heuristics in TritonWrapper. +// TODO(b/393299275): this is not worth keeping as a codegen test. Really, we +// should not reject tilings that are slow/spill in codegen. If this has use in +// autotuning, then this should be tested/called in the autotuner. +// The generic Triton emitter does not want to deal with this. 
TEST_F(TritonGemmTest, DISABLED_FailForTooComplexTiling) {
-  constexpr absl::string_view kHloText = R"(
-HloModule module, is_scheduled=true
+  auto cc = se::CudaComputeCapability::Ampere();
+  const se::DeviceDescription device_info =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  llvm::LLVMContext llvm_ctx;
+  llvm::Module llvm_module("module", llvm_ctx);
+  mlir::MLIRContext mlir_context;
+
+  constexpr absl::string_view kHloTextTemplate = R"(
+HloModule module
 
 triton_gemm_dot {
   p0 = s8[1024,1024] parameter(0)
   p1 = f32[1024,1024] parameter(1)
   c0 = f32[1024,1024] convert(p0)
   ROOT dot = f32[1024,1024] dot(c0, p1),
     lhs_contracting_dims={1}, rhs_contracting_dims={0}
 }
 
 ENTRY entry {
   p0 = s8[1024,1024] parameter(0)
   p1 = f32[1024,1024] parameter(1)
   ROOT r = f32[1024,1024] fusion(p0, p1), kind=kCustom,
     calls=triton_gemm_dot,
-    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+    backend_config={"fusion_backend_config": {kind: "__triton_gemm",
+      triton_gemm_config: {"block_m":$0,"block_n":$1,"block_k":$2,
+                           "split_k":1,"num_stages":1,"num_warps":2,
+                           "num_ctas":1}}}
 })";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
-                          ParseAndReturnVerifiedModule(kHloText));
-  HloFusionInstruction* triton_dot_fusion = Cast<HloFusionInstruction>(
-      hlo_module->entry_computation()->root_instruction());
-  const se::DeviceDescription dev_info =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  llvm::LLVMContext llvm_ctx;
-  llvm::Module llvm_module("module", llvm_ctx);
-  mlir::MLIRContext mlir_context;
 
-  auto backend_config_or =
-      triton_dot_fusion->backend_config<GpuBackendConfig>();
-  TF_ASSERT_OK(backend_config_or);
-  GpuBackendConfig& backend_config = *backend_config_or;
-
-  FusionBackendConfig& fusion_backend_config =
-      *backend_config.mutable_fusion_backend_config();
-  auto& config = *fusion_backend_config.mutable_triton_gemm_config();
-  // Fails if the tiling is too complex.
-  config.set_block_m(512);
-  config.set_block_n(512);
-  config.set_block_k(32);
-  config.set_split_k(1);
-  config.set_num_ctas(1);
-  config.set_num_stages(1);
-  config.set_num_warps(2);
-  TF_ASSERT_OK(triton_dot_fusion->set_backend_config(backend_config));
+  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module1_and_metadata,
+                          GetModuleAndNestedFusionMetadata(absl::Substitute(
+                              kHloTextTemplate, 512, 512, 32)));
 
-  BlockLevelParameters block_level_parameters;
-  block_level_parameters.num_ctas = 1;
-  block_level_parameters.num_stages = 1;
-  block_level_parameters.num_warps = 2;
-
-  EXPECT_THAT(
-      TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info,
-                    block_level_parameters, &llvm_module, mlir_context),
-      StatusIs(tsl::error::RESOURCE_EXHAUSTED,
-               "Tiling complexity heuristic exceeded: 147456 > 9000"));
+  const HloFusionInstruction* fusion1 = Cast<HloFusionInstruction>(
+      module1_and_metadata.computation->FusionInstruction());
+  EXPECT_THAT(TritonWrapper("test_fn", fusion1, cc, device_info,
+                            module1_and_metadata.block_level_parameters,
+                            &llvm_module, mlir_context),
+              StatusIs(tsl::error::RESOURCE_EXHAUSTED,
+                       "Tiling complexity heuristic exceeded"));
 
   // Succeeds if the tiling is not too complex.
-  config.set_block_m(32);
-  config.set_block_n(32);
-  config.set_block_k(32);
-  TF_ASSERT_OK(triton_dot_fusion->set_backend_config(backend_config));
-
-  TF_ASSERT_OK(TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(),
-                             dev_info, block_level_parameters, &llvm_module,
-                             mlir_context)
-                   .status());
-}
+  TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module2_and_metadata,
+                          GetModuleAndNestedFusionMetadata(
+                              absl::Substitute(kHloTextTemplate, 32, 32, 32)));
 
-// Triton compiler used to have an issue with reordering constants:
-// https://github.com/openai/triton/issues/1864
-TEST_F(TritonGemmTest, DISABLED_TritonCompilerDoesNotFailOnConstants) {
-  TF_ASSERT_OK(GetOptimizedModule(R"(
-HloModule m
-
-triton_gemm___computation {
-  parameter_0 = f32[92,11]{1,0} parameter(0)
-  c = f32[] constant(0)
-  b = f32[11,63] broadcast(c)
-  ROOT _.1 = f32[92,63]{1,0} dot(parameter_0, b),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY e {
-  p0 = f32[92,11]{1,0} parameter(0)
-  ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0), kind=kCustom,
-    calls=triton_gemm___computation,
-    backend_config={"fusion_backend_config": {"kind":"__triton_gemm",
-      "triton_gemm_config":{"block_m":"16","block_n":"64",
-                            "block_k":"16","split_k":"1",
-                            "num_stages":"3","num_warps":"2",
-                            "num_ctas":"1"}}}
-})")
+  const HloFusionInstruction* fusion2 = Cast<HloFusionInstruction>(
+      module2_and_metadata.computation->FusionInstruction());
+
+  TF_EXPECT_OK(TritonWrapper("test_fn", fusion2, cc, device_info,
+                             module2_and_metadata.block_level_parameters,
+                             &llvm_module, mlir_context)
                    .status());
 }
 
-// Normally optimized HLO should contain `copy` instead of `transpose` but
-// it's also possible to get transposes by modifying the compiler's pipeline.
-// The emitter just has to skip through the transpose - it's handled by the
-// tiled fusion analysis.
-TEST_F(TritonGemmTest, DISABLED_TritonEmitterCanHandleTransposes) {
-  MatchOptimizedHlo(R"(
-t {
-  p0 = f16[55,77,111]{2,1,0} parameter(0)
-  p1 = f16[111,77,99]{2,1,0} parameter(1)
-  t = f16[77,99,111]{2,1,0} transpose(p1), dimensions={1,2,0}
-  ROOT d = f16[77,55,99]{2,1,0} dot(p0, t),
-    lhs_batch_dims={1}, lhs_contracting_dims={2},
-    rhs_batch_dims={0}, rhs_contracting_dims={2}
-}
-
-ENTRY e {
-  p0 = f16[55,77,111]{2,1,0} parameter(0)
-  p1 = f16[111,77,99]{2,1,0} parameter(1)
-  ROOT r = f16[77,55,99]{2,1,0} fusion(p0, p1), kind=kCustom,
-    calls=t, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}}
-})",
-                    // This partially optimized HLO will go through the
-                    // autotuner which will run the fusion through the emitter
-                    // multiple times and assign block sizes on success.
-                    R"(
-; CHECK: f16[77,99,111]{2,1,0} transpose
-; CHECK-PTX: block_m
-)");
-}
-
-
+// TODO(b/393299275): this test may have some value while Triton tiling
+// propagation is being replaced, but has little worth as a codegen test.
+// Consider moving this.
 TEST_F(TritonGemmTest,
        BroadcastsOfTriviallySizedNonContractingDimensionsAreSupported) {
   constexpr absl::string_view kHloText = R"(
@@ -1135,9 +923,14 @@ e {
                             ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
+// TODO(b/393299275): this test may have some value while Triton tiling
+// propagation is being replaced, but has little worth as a codegen test.
+// Consider moving this.
+// TODO(b/393299275): likely uncovered a bug in `NestGemmFusion`, where after
+// transformations and collapse of a dimension, broadcast dimensions are wrong.
TEST_F(TritonGemmTest, DISABLED_BroadcastsOfTriviallySizedContractingDimensionsAreSupported) { - EXPECT_TRUE(RunAndCompare(R"( + constexpr absl::string_view kHloText = R"( f { a = f16[2] parameter(0) bc0 = f16[1,2] bitcast(a) @@ -1152,56 +945,21 @@ e { a = f16[2] parameter(0) b = f16[3,4000] parameter(1) f = f16[2,3] fusion(a, b), - kind=kCustom, calls=f, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} -})", - ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_DoF32F32) { - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[3,5] parameter(0) - p1 = f32[5,7] parameter(1) - ROOT _ = f32[3,7] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion( -; CHECK-SAME: kind=kCustom -; CHECK-PTX-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, DISABLED_DoAddConstantToScalarAndBroadcastThat) { - if (std::holds_alternative( - GpuComputeCapability())) { - GTEST_SKIP() << "Not using autotuner on ROCM yet."; - } - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[] parameter(0) - p1 = f32[5,5] parameter(1) - %constant = f32[] constant(8) - add = add(p0, constant) - broadcast = f32[5,5] broadcast(add), dimensions={} - ROOT _ = f32[5,5] dot(broadcast, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} + kind=kCustom, calls=f, backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config": {"block_m":"16","block_n":"16","block_k":"16","split_k":"1", + "num_stages":"1","num_warps":"1","num_ctas":"1"}}} })"; - MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion({{.*}} kind=kCustom, {{.*}}block_m -)"); + TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, + GetModuleAndNestedFusionMetadata(kHloText)); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + EXPECT_TRUE( + RunAndCompareNoHloPasses(std::move(module_and_metadata.module), + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +// TODO(b/393299275): this requires adding support for dynamic-slice in the +// generic Triton emitter. TEST_F(TritonGemmTest, DISABLED_DynamicSliceIsSupportedInLhsEndToEnd) { // The select is used to restrict the start index to values that make sense. // If it was constant, then the dynamic-slice would be optimized to slice. It @@ -1236,6 +994,8 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } +// TODO(b/393299275): this requires adding support for dynamic-slice in the +// generic Triton emitter. TEST_F(TritonGemmTest, DISABLED_DynamicSliceIsSupportedInRhs) { // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. @@ -1275,6 +1035,8 @@ class TritonGemmDynamicSliceClampingTest : public TritonTest, public ::testing::WithParamInterface {}; +// TODO(b/393299275): this requires adding support for dynamic-slice in the +// generic Triton emitter. 
TEST_P(TritonGemmDynamicSliceClampingTest, DISABLED_DynamicSliceIsSupportedWhenTheStartIndexNeedsClamping) { // The start index(es) for the non-majormost dimension(s) are constant zero(s) @@ -1321,6 +1083,8 @@ std::string OffsetParamToString(const ::testing::TestParamInfo& data) { INSTANTIATE_TEST_SUITE_P(All, TritonGemmDynamicSliceClampingTest, ::testing::Values(-100, 3, 999), OffsetParamToString); +// TODO(b/393299275): this requires adding support for dynamic-slice in the +// generic Triton emitter. TEST_F(TritonGemmTest, DISABLED_DynamicSliceOfMajormostContractingDimIsSupported) { // Tests that dynamic-slice works on the majormost dimension even if that @@ -1359,6 +1123,8 @@ ENTRY e { kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } +// TODO(b/393299275): this requires adding support for dynamic-slice in the +// generic Triton emitter. TEST_F(TritonGemmTest, DISABLED_DynamicSliceOfMajormostBatchDimIsSupported) { // Tests that dynamic-slice works on the majormost dimension even if that // dimension is a batch. @@ -1398,6 +1164,8 @@ ENTRY e { kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } +// TODO(b/393299275): this requires adding support for dynamic-slice in the +// generic Triton emitter. TEST_F(TritonGemmTest, DISABLED_DynamicSliceSingleDimensionIntoReshapeIsSupported) { // This directly tests the targeted use case (b/307922364) of iterating over @@ -1439,8 +1207,9 @@ ENTRY e { kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } -TEST_F(TritonGemmTest, - DISABLED_DoNotFuseConcatenationOfSplitNonContractingDimension) { +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. +TEST_F(TritonGemmTest, DoNotFuseConcatenationOfSplitNonContractingDimension) { if (std::holds_alternative( GpuComputeCapability())) { GTEST_SKIP() << "Not using autotuner on ROCM yet."; @@ -1462,67 +1231,14 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: ENTRY ; CHECK: concatenate -; CHECK: ROOT -; CHECK-SAME: fusion -; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m" +; CHECK: fusion +; CHECK-SAME: kind=kCustom +; CHECK-SAME: "__triton_nested_gemm_fusion" )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, DISABLED_BroadcastOfScalarWorksCorrectly) { - constexpr absl::string_view kHloText = R"( -fusion { - p0 = f16[2,18] parameter(0) - p1 = f16[256,2] parameter(1) - d = f16[18,256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={1} - p2 = f16[] parameter(2) - p3 = f16[] parameter(3) - multiply = f16[] multiply(p2, p3) - broadcast = f16[18,256] broadcast(multiply), dimensions={} - ROOT multiply.3 = f16[18,256] multiply(d, broadcast) -} -ENTRY e { - p0 = f16[2,18] parameter(0) - p1 = f16[256,2] parameter(1) - p2 = f16[] parameter(2) - p3 = f16[] parameter(3) - ROOT gemm_fusion = f16[18,256]{1,0} fusion(p0, p1, p2, p3), kind=kCustom, calls=fusion, backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"16","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}}} -})"; - - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(this, kHloText, "fusion", R"( - CHECK: tt.dot - CHECK: arith.mulf %{{.*}}, %{{.*}} : tensor - CHECK: tt.broadcast %{{.*}} : tensor<1x1xf16> -> tensor<32x32xf16> - CHECK: arith.mulf %{{.*}}, %{{.*}} : tensor<32x32xf16> - )")); - const se::DeviceDescription dev_info = - backend().default_stream_executor()->GetDeviceDescription(); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr 
hlo_module, - ParseAndReturnVerifiedModule(kHloText)); - const HloFusionInstruction* triton_dot_fusion = Cast( - hlo_module->entry_computation()->root_instruction()); - llvm::LLVMContext llvm_ctx; - llvm::Module llvm_module("module", llvm_ctx); - mlir::MLIRContext mlir_context; - - TF_ASSERT_OK_AND_ASSIGN( - auto gpu_config, triton_dot_fusion->backend_config()); - const FusionBackendConfig& config = gpu_config.fusion_backend_config(); - auto gemm_config = config.triton_gemm_config(); - BlockLevelParameters block_level_parameters; - block_level_parameters.num_ctas = gemm_config.num_ctas(); - block_level_parameters.num_warps = gemm_config.num_warps(); - block_level_parameters.num_stages = gemm_config.num_stages(); - - TF_ASSERT_OK(TritonWrapper("test_fn", triton_dot_fusion, - GpuComputeCapability(), dev_info, - block_level_parameters, &llvm_module, mlir_context) - .status()); -} - TEST_F(TritonGemmTest, BinaryOperationWithSmallInputsIsFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -1654,6 +1370,8 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. TEST_F(TritonGemmTest, BroadcastOfScalarParameterIsFused) { constexpr absl::string_view kHloText = R"( ENTRY e { @@ -1671,10 +1389,10 @@ ENTRY e { module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. TEST_F(TritonGemmTest, BroadcastOfScalarConstantIsFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -1696,10 +1414,9 @@ ENTRY e { .WithFusionKind(HloInstruction::FusionKind::kCustom))); } -TEST_F(TritonGemmTest, DISABLED_DoubleBroadcastOfScalarConstantIsHandled) { - if (!SupportsBF16(GpuComputeCapability())) { - GTEST_SKIP() << "BF16 not supported."; - } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. +TEST_F(TritonGemmTest, DoubleBroadcastOfScalarConstantIsFused) { constexpr absl::string_view kHloText = R"( ENTRY e { c = s32[] constant(1) @@ -1716,14 +1433,14 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); EXPECT_THAT( - module->entry_computation()->root_instruction(), + GetNonBitcastRoot(module->entry_computation()), GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } -TEST_F(TritonGemmTest, DISABLED_BroadcastOfVectorConstantIsFused) { +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. 
+TEST_F(TritonGemmTest, BroadcastOfVectorConstantIsFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -1738,16 +1455,14 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); EXPECT_THAT( - module->entry_computation()->root_instruction(), + GetNonBitcastRoot(module->entry_computation()), GmockMatch(m::Fusion(m::Parameter(), m::Constant()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. TEST_F(TritonGemmTest, AlwaysFuseScalarConstantAtBroadcastInput) { - if (!SupportsBF16(GpuComputeCapability())) { - GTEST_SKIP() << "BF16 not supported."; - } constexpr absl::string_view kHloText = R"( ENTRY e { p0 = bf16[2,3,3]{2,1,0} parameter(0) @@ -1773,31 +1488,9 @@ ENTRY e { )"); } -TEST_F(TritonGemmTest, BroadcastOfVectorParameterIsFused) { - constexpr absl::string_view kHloText = R"( -triton_dot { - p0 = f16[75] parameter(0) - bc0 = f16[75,67] broadcast(p0), dimensions={0} - p1 = f16[92,75] parameter(1) - ROOT d = f16[92,67] dot(p1, bc0), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f16[75] parameter(0) - p1 = f16[92,75] parameter(1) - ROOT _ = f16[92,67] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: - {"block_m":32,"block_n":64,"block_k":32, - "split_k":1,"num_stages":1,"num_warps":1, - "num_ctas":1}}} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); -} - -TEST_F(TritonGemmTest, DISABLED_FuseConcatenation) { +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. +TEST_F(TritonGemmTest, FuseConcatenation) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } @@ -1815,9 +1508,8 @@ e { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); - EXPECT_THAT( - module->entry_computation()->root_instruction(), + GetNonBitcastRoot(module->entry_computation()), GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); @@ -1826,6 +1518,8 @@ e { /*arel=*/1e-2})); } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. TEST_F(TritonGemmTest, SineOutputIsNotFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -1847,6 +1541,8 @@ ENTRY e { .WithFusionKind(HloInstruction::FusionKind::kCustom)))); } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. TEST_F(TritonGemmTest, SliceInputIsFused) { constexpr absl::string_view kHloText = R"( ENTRY e { @@ -1866,7 +1562,9 @@ ENTRY e { .WithFusionKind(HloInstruction::FusionKind::kCustom))); } -TEST_F(TritonGemmTest, DISABLED_SliceInputWithReshapeIsFused) { +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. 
+TEST_F(TritonGemmTest, SliceInputWithReshapeIsFused) { constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f32[363,1536] parameter(0) @@ -1881,13 +1579,16 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); EXPECT_THAT( - module->entry_computation()->root_instruction(), + GetNonBitcastRoot(module->entry_computation()), GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. From the point of view of codegen, this is purely +// about compositionality. TEST_F(TritonGemmTest, NestedSlicingWorks) { constexpr absl::string_view kHloText = R"( ENTRY e { @@ -1908,6 +1609,8 @@ ENTRY e { .WithFusionKind(HloInstruction::FusionKind::kCustom))); } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. TEST_F(TritonGemmTest, SlicedBatchDimensionIsSupported) { constexpr absl::string_view kHloText = R"( ENTRY e { @@ -1930,6 +1633,12 @@ ENTRY e { .WithFusionKind(HloInstruction::FusionKind::kCustom))); } +// TODO(b/393299275): symbolic tile analysis fails to derive a tile for one +// outer parameter here. However, we shouldn't be deriving this tile anyway, +// and the underlying indexing map is incorrect. This requires a fix in +// symbolic tile derivation. +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. TEST_F(TritonGemmTestWithSplitK, DISABLED_SplitKDoesNotBreakSlicedFragmentedContractingDimension) { constexpr absl::string_view kHloText = R"( @@ -1955,7 +1664,9 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); } -TEST_F(TritonGemmTestWithSplitK, DISABLED_SplitKWithTrivialDimension) { +// TODO(b/393299275): this should be rewritten to work on post-optimization HLO, +// and potentially have an associated fusion test. +TEST_F(TritonGemmTestWithSplitK, SplitKWithTrivialDimension) { constexpr absl::string_view kHloText = R"( ENTRY entry_computation { p0 = f16[1001,1]{1,0} parameter(0) @@ -1968,7 +1679,9 @@ ENTRY entry_computation { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); } -TEST_F(TritonGemmTest, DISABLED_NarrowingConvertOutputIsFused) { +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. +TEST_F(TritonGemmTest, NarrowingConvertOutputIsFused) { constexpr absl::string_view kHloText = R"( HloModule m @@ -1987,10 +1700,11 @@ ENTRY e { module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/3e-2, /*arel=*/3e-2})); } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. +// TODO(b/393299275): looks like another miscompile. TEST_F(TritonGemmTest, DISABLED_ParameterAfterDotIsFused) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; @@ -2060,7 +1774,9 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2})); } -TEST_F(TritonGemmTest, DISABLED_SplitLHSOutputTransposeAloneIsNotFused) { +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. 
+TEST_F(TritonGemmTest, SplitLHSOutputTransposeAloneIsNotFused) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; } @@ -2089,6 +1805,12 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +// TODO(b/393299275): this should just be a fusion test and does not need to be +// in the codegen directory. +// TODO(b/393299275): symbolic tile analysis fails to derive a tile for one +// outer parameter here. However, we shouldn't be deriving this tile anyway, +// and the underlying indexing map is incorrect. This requires a fix in +// symbolic tile derivation. TEST_F(TritonGemmTest, DISABLED_SplitLHSInputOutputIsFused) { if (!SupportsBF16(GpuComputeCapability())) { GTEST_SKIP() << "BF16 not supported."; @@ -2100,8 +1822,7 @@ TEST_F(TritonGemmTest, DISABLED_SplitLHSInputOutputIsFused) { constexpr absl::string_view kHloText = R"( ENTRY e { - p0t = (s8[5,18,20,150]) parameter(0) - p0 = s8[5,18,20,150] get-tuple-element(p0t), index=0 + p0 = s8[5,18,20,150] parameter(0) p0c = bf16[5,18,20,150] convert(p0) t0 = bf16[18,5,20,150] transpose(p0c), dimensions={1,0,2,3} r0 = bf16[18,15000] reshape(t0) @@ -2116,7 +1837,7 @@ ENTRY e { GetOptimizedModule(kHloText)); EXPECT_THAT( module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::GetTupleElement(), m::Parameter()) + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -2147,48 +1868,50 @@ ENTRY e { .WithFusionKind(HloInstruction::FusionKind::kCustom))); } -TEST_F(TritonGemmTest, - DISABLED_LowerDotWithLhsWithoutNonContractingDimThroughTriton) { +// TODO(b/393299275): This test name might be a bit misleading, since the dot is +// given a non-contracting dimension by the time it gets passed down to the +// Triton emitter. This should probably be a fusion test. +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. +TEST_F(TritonGemmTest, LowerDotWithLhsWithoutNonContractingDimThroughTriton) { constexpr absl::string_view kHloText = R"( -HloModule t - ENTRY e { parameter_0 = f32[1,40] parameter(0) parameter_1 = f32[1,40,250000] parameter(1) ROOT dot = f32[1,250000] dot(parameter_0, parameter_1), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom))); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + GetNonBitcastRoot(module->entry_computation()), + GmockMatch( + m::Fusion(m::Bitcast(m::Parameter()), m::Bitcast(m::Parameter())) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); } -TEST_F(TritonGemmTest, - DISABLED_LowerDotWithRhsWithoutNonContractingDimThroughTriton) { +// TODO(b/393299275): This test name might be a bit misleading, since the dot is +// given a non-contracting dimension by the time it gets passed down to the +// Triton emitter. This should probably be a fusion test. +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. 
+TEST_F(TritonGemmTest, LowerDotWithRhsWithoutNonContractingDimThroughTriton) { constexpr absl::string_view kHloText = R"( -HloModule t - ENTRY e { parameter_0 = f32[1,40,250000] parameter(0) parameter_1 = f32[1,40] parameter(1) ROOT dot = f32[1,250000] dot(parameter_0, parameter_1), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom))); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + GetNonBitcastRoot(module->entry_computation()), + GmockMatch( + m::Fusion(m::Bitcast(m::Parameter()), m::Bitcast(m::Parameter())) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); } // This group of tests compares GPU results of dots already rewritten @@ -2307,8 +2030,8 @@ ENTRY e { optin_shmem_module_and_metadata.block_level_parameters, &llvm_module, mlir_context)); // The config is chosen so that the used memory size is slightly above the - // 48 kB boundary of standard / optin shared memory so that any GPU that - // has the optin one should be able to execute the test. + // 48 kB boundary of standard / opt-in shared memory so that any GPU that + // has the opt-in one should be able to execute the test. EXPECT_EQ(result.shmem_bytes, kBytesOfSharedMemoryTested); // Make sure the written config indeed has to use optin shared memory. EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block()); @@ -3327,113 +3050,6 @@ CHECK: inputPrecision = tf32 ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonTest, Fp8LoweringIsSupportedPostHopper) { - if (!GetCudaComputeCapability().IsAtLeastHopper()) { - GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; - } - constexpr absl::string_view kHloText = R"( -HloModule t - -triton_dot { - parameter_0 = f8e4m3fn[1600,1600]{1,0} parameter(0) - parameter_1 = f8e4m3fn[1600,1600]{1,0} parameter(1) - transpose = f8e4m3fn[1600,1600]{0,1} transpose(parameter_1), dimensions={1,0} - ROOT dot = f16[1600,1600]{1,0} dot(parameter_0, transpose), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -} - -ENTRY main { - parameter_1 = f8e4m3fn[1600,1600]{1,0} parameter(1) - parameter_0 = f8e4m3fn[1600,1600]{1,0} parameter(0) - ROOT gemm_fusion_dot = f16[1600,1600]{1,0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_dot, - backend_config={ - "fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config": - {"block_m":"128","block_n":"32","block_k":"64","split_k":"1", - "num_stages":"4","num_warps":"4","num_ctas":"1"}}} -})"; - - TF_ASSERT_OK_AND_ASSIGN(ModuleAndNestedFusionMetadata module_and_metadata, - GetModuleAndNestedFusionMetadata(kHloText)); - TF_ASSERT_OK( - CreateTritonIrAndFileCheck(*module_and_metadata.computation, - module_and_metadata.block_level_parameters, - R"( -CHECK: tt.dot {{.*}}{maxNumImpreciseAcc = 2147483647 : i32} : tensor<128x64xf8E4M3FN> * tensor<64x32xf8E4M3FN> -> tensor<128x32xf32> - )")); - - EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module_and_metadata.module), - ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); -} - -// TODO(b/393299275): this test requires us to allow actual mixed type GEMMs -// in the lowering. We need to expand support tests and the lowering to model -// mixed types as needed. 
(f8e4m3fn x f8e4m3fn -> f32) -TEST_F(TritonTest, DISABLED_BF16ToFP8EndToEnd) { - if (!GetCudaComputeCapability().IsAtLeastHopper()) { - GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; - } - - constexpr absl::string_view kHloText = R"( -HloModule t - -triton_dot { - parameter_0 = bf16[32,32]{1,0} parameter(0) - parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) - convert = f8e4m3fn[32,32]{1,0} convert(parameter_0) - ROOT dot = f32[32,32]{1,0} dot(convert, parameter_1), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -} - -ENTRY main { - parameter_0 = bf16[32,32]{1,0} parameter(0) - parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) - ROOT gemm_fusion_dot = f32[32,32]{1,0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_dot, - backend_config={ - "fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config": - {"block_m":"32","block_n":"32","block_k":"32","split_k":"1", - "num_stages":"1","num_warps":"4","num_ctas":"1"}}} -})"; - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); -} - -// TODO(b/393299275): this test requires us to allow actual mixed type GEMMs -// in the lowering. We need to expand support tests and the lowering to model -// mixed types as needed. -TEST_F(TritonTest, DISABLED_FP8ToFP8EndToEnd) { - if (!GetCudaComputeCapability().IsAtLeastHopper()) { - GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; - } - - constexpr absl::string_view kHloText = R"( -HloModule t - -triton_dot { - parameter_0 = f8e5m2[32,32]{1,0} parameter(0) - parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) - convert = f8e4m3fn[32,32]{1,0} convert(parameter_0) - ROOT dot = f32[32,32]{1,0} dot(convert, parameter_1), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -} - -ENTRY main { - parameter_0 = f8e5m2[32,32]{1,0} parameter(0) - parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) - ROOT gemm_fusion_dot = f32[32,32]{1,0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_dot, - backend_config={ - "fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config": - {"block_m":"32","block_n":"32","block_k":"32","split_k":"1", - "num_stages":"1","num_warps":"4","num_ctas":"1"}}} -})"; - ASSERT_TRUE( - GetDebugOptionsForTest() - .xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms()); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); -} - // Test PreventMmaV3LoopUnrolling pass in order to keep compile time low. // See b/344841434. 
// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be

From b3557b53427c84eb8c5ff7807d267dc4d153665b Mon Sep 17 00:00:00 2001
From: Alex
Date: Tue, 15 Apr 2025 04:59:01 -0700
Subject: [PATCH 0805/1324] PR #25269: [ROCM] Fix asan issue due to a singleton
 in header file

Imported from GitHub PR https://github.com/openxla/xla/pull/25269

Reported issue:
```
exec ${PAGER:-/usr/bin/less} "$0" || exit 1
Executing tests from //xla/service:compiler_test_gpu_amd_any
-----------------------------------------------------------------------------
Running test /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-dbg/bin/xla/service/compiler_test_gpu_amd_any.runfiles/xla/xla/service/compiler_test_gpu_amd_any --gtest_shuffle --gtest_fail_if_no_test_linked on GPU 0
=================================================================
==168009==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x50400002c1c0 at pc 0x7f59e50b52e7 bp 0x7ffc8c2358d0 sp 0x7ffc8c2358c8
READ of size 8 at 0x50400002c1c0 thread T0
    #0 0x7f59e50b52e6 in absl::lts_20230802::container_internal::CommonFields::capacity() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:990:36
    #1 0x7f59e50b52e6 in absl::lts_20230802::container_internal::probe(absl::lts_20230802::container_internal::CommonFields const&, unsigned long) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:1298:41
    #2-#9: [absl::lts_20230802::container_internal raw_hash_set / flat_hash_map
        insertion internals (find_or_prepare_insert, EmplaceDecomposable,
        DecomposePair, hash_policy_traits::apply, emplace, insert) on the
        registry's map of stream_executor::MultiKernelLoaderSpec entries; the
        heavily templated frames were mangled during extraction and are
        summarized here]
    #10 0x7f59e50af8a8 in stream_executor::gpu::GpuKernelRegistry::RegisterKernel(std::type_info const&, void*, stream_executor::MultiKernelLoaderSpec const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/stream_executor/gpu/gpu_kernel_registry.cc:67:45
    #11 0x7f59e50d1982 in absl::lts_20230802::Status stream_executor::gpu::GpuKernelRegistry::RegisterKernel(void*, stream_executor::MultiKernelLoaderSpec const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/stream_executor/gpu/gpu_kernel_registry.h:86:12
    #12 0x7f59e50d1982 in RegisterKernelMakeBatchPointersKernelRocmImpl() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/stream_executor/rocm/make_batch_pointers_kernel_rocm.cu.cc:35:1
    #13 0x7f59e50d1982 in 'lambda'()::operator()() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/stream_executor/rocm/make_batch_pointers_kernel_rocm.cu.cc:35:1
    #14 0x7f59e50d1982 in 'lambda'()::__invoke() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/stream_executor/rocm/make_batch_pointers_kernel_rocm.cu.cc:35:1
    #15 0x7f59e50d1982 in stream_executor::port::Initializer::Initializer(void (*)()) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/stream_executor/platform/default/initialize.h:26:42
    #16 0x7f59e50d1982 in __cxx_global_var_init.1 /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/stream_executor/rocm/make_batch_pointers_kernel_rocm.cu.cc:35:1
    #17 0x7f59e50d1982 in _GLOBAL__sub_I_make_batch_pointers_kernel_rocm.cu.cc /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/stream_executor/rocm/make_batch_pointers_kernel_rocm.cu.cc
    #18 0x7f5a5b27a47d in call_init elf/dl-init.c:70:3
    #19 0x7f5a5b27a567 in call_init elf/dl-init.c:33:6
    #20 0x7f5a5b27a567 in _dl_init elf/dl-init.c:117:5
    #21 0x7f5a5b2942c9 (/lib64/ld-linux-x86-64.so.2+0x202c9) (BuildId: e4de036b19e4768e7591b596c4be9f9015f2d28a)

0x50400002c1c0 is located 8 bytes after 40-byte region [0x50400002c190,0x50400002c1b8)
allocated by thread T0 here:
    #0 0x557d0f77fcdf in malloc (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-dbg/bin/xla/service/compiler_test_gpu_amd_any+0x1e8cdf) (BuildId: e96972f8c7f880083ff6ad5985d3c06d)
    #1 0x7f59d733098b in operator new(unsigned long) (/lib/x86_64-linux-gnu/libstdc++.so.6+0xae98b) (BuildId: e37fe1a879783838de78cbc8c80621fa685d58a2)

SUMMARY: AddressSanitizer: heap-buffer-overflow /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:990:36 in absl::lts_20230802::container_internal::CommonFields::capacity() const
Shadow bytes around the buggy address:
  0x50400002bf00: fa fa fd fd fd fd fd fd fa fa fd fd fd fd fd fa
  0x50400002bf80: fa fa fd fd fd fd fd fa fa fa fd fd fd fd fd fa
  0x50400002c000: fa fa fd fd fd fd fd fa fa fa 00 00 00 00 00 fa
  0x50400002c080: fa fa fd fd fd fd fd fd fa fa 00 00 00 00 00 00
  0x50400002c100: fa fa 00 00 00 00 00 00 fa fa 00 00 00 00 00 00
=>0x50400002c180: fa fa 00 00 00 00 00 fa[fa]fa fa fa fa fa fa fa
  0x50400002c200: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x50400002c280: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x50400002c300: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x50400002c380: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
  0x50400002c400: fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa fa
Shadow byte legend (one shadow byte represents 8 application bytes):
  Addressable:           00
  Partially addressable: 01 02 03 04 05 06 07
  Heap left redzone:       fa
  Freed heap region:       fd
  Stack left redzone:      f1
  Stack mid redzone:       f2
  Stack right redzone:     f3
  Stack after return:      f5
  Stack use after scope:   f8
  Global redzone:          f9
  Global init order:       f6
  Poisoned by user:        f7
  Container overflow:      fc
  Array cookie:            ac
  Intra object redzone:    bb
  ASan internal:           fe
  Left alloca redzone:     ca
  Right alloca redzone:    cb
==168009==ABORTING
```

Why this fixes the issue:
* When this class is compiled into different .so files, the accessor defined
  in the header can get inlined into each of them, so we end up with several
  instances while we still want a single process-wide singleton (a minimal
  sketch of the before/after follows below).
* In the ROCm compiler wrapper script we do not yet support sanitizer flags,
  so our .cu.cc files are not instrumented while our normal .cc files are!
  This might cause a memory misalignment while running with ASan (theory).
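The failure mode described above is easy to reproduce outside of XLA. Below is a minimal standalone C++ sketch, not part of the patch, with illustrative names rather than the actual XLA types; it shows the header-defined singleton this patch removes and the out-of-line replacement it introduces.

```cpp
// registry.h -- BEFORE: the accessor is *defined* in the header. Every shared
// object that compiles this header may inline its own copy of the function,
// and with it its own `registry` static local, so the process can end up with
// several distinct "singletons".
class Registry {
 public:
  static Registry& GetGlobalRegistryBroken() {
    static auto* registry = new Registry();  // one per copy of the function
    return *registry;
  }

  static Registry& GetGlobalRegistry();  // AFTER: declaration only
};

// registry.cc -- AFTER: exactly one definition, in one translation unit, so
// every shared object resolves to the same instance.
Registry& Registry::GetGlobalRegistry() {
  static auto* registry = new Registry();  // intentionally leaked singleton
  return *registry;
}
```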
Copybara import of the project:

--
ffcd58918137191cdba6db571e0e5af0e57de2e1 by alekstheod :

Fix asan issue due to a singleton in header file

Merging this change closes #25269

PiperOrigin-RevId: 747823659
---
 .../xla/xla/stream_executor/gpu/gpu_kernel_registry.cc | 5 +++++
 .../xla/xla/stream_executor/gpu/gpu_kernel_registry.h  | 5 +----
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_registry.cc b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_registry.cc
index f5262cdbac34db..c92bf44e115c84 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_registry.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_registry.cc
@@ -41,6 +41,11 @@ std::string GetPlatformName(Platform::Id platform_id) {
 }
 }  // namespace
 
+GpuKernelRegistry& GpuKernelRegistry::GetGlobalRegistry() {
+  static auto registry = new GpuKernelRegistry();
+  return *registry;
+}
+
 absl::StatusOr>
 GpuKernelRegistry::GetKernelSpec(const std::type_info& type,
                                  Platform::Id platform_id) {
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_registry.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_registry.h
index c079f93f213ef5..2ad466de3e26fb 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_registry.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_registry.h
@@ -87,10 +87,7 @@ class GpuKernelRegistry {
   }
 
   // Returns a reference to the process-wide instance of the registry.
-  static GpuKernelRegistry& GetGlobalRegistry() {
-    static auto registry = new GpuKernelRegistry();
-    return *registry;
-  }
+  static GpuKernelRegistry& GetGlobalRegistry();
 
  private:
   absl::Status RegisterKernel(const std::type_info& type,

From 58a8dd7f772dd6ef1680ef8a8aad58296ebe6af5 Mon Sep 17 00:00:00 2001
From: Aliia Khasanova
Date: Tue, 15 Apr 2025 06:05:40 -0700
Subject: [PATCH 0806/1324] Move `AllReduceKernel` behind `GpuKernelRegistry`.

* Moves the `AllReduce` logic into `backends/gpu/runtime` since it is a
  runtime component.
* Defines a trait for the `AllReduce` kernel in `stream_executor/gpu/`.
* Moves the implementations of this kernel into stream_executor/{cuda|rocm}
  and registers them with the registry for each supported type (see the
  sketch after this list).
* Makes `AllReduce` retrieve the kernel by using the kernel registry.
* Adds the kernel implementations as dependencies to the `all_runtime`
  targets for CUDA and ROCm.
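For readers unfamiliar with the registry, the self-contained C++ sketch below models the trait-plus-static-registrar pattern this commit adopts. All names here are simplified stand-ins, not the real API: the actual `GpuKernelRegistry` is keyed by trait type *and* platform id, and stores a `MultiKernelLoaderSpec` rather than a bare symbol pointer.

```cpp
#include <cstdio>
#include <typeindex>
#include <unordered_map>

// Simplified stand-in for GpuKernelRegistry: maps a trait's type to the
// kernel symbol that was registered for it.
class KernelRegistry {
 public:
  static KernelRegistry& Get() {  // defined out of line in the real code
    static auto* registry = new KernelRegistry();
    return *registry;
  }
  void Register(std::type_index trait, void* symbol) {
    kernels_.emplace(trait, symbol);
  }
  void* Load(std::type_index trait) const { return kernels_.at(trait); }

 private:
  std::unordered_map<std::type_index, void*> kernels_;
};

// Trait: an empty tag type per (kernel, element type) combination, mirroring
// stream_executor::gpu::AllReduceKernel<T>.
template <typename T>
struct AllReduceKernelTrait {};

// Static registrar, mirroring GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY:
// the constructor runs at load time, which is why the implementation targets
// in the BUILD files below are marked alwayslink = 1.
template <typename Trait>
struct Registrar {
  explicit Registrar(void* symbol) {
    KernelRegistry::Get().Register(typeid(Trait), symbol);
  }
};

void AllReduceF32Impl() {}  // stand-in for the device kernel symbol
static Registrar<AllReduceKernelTrait<float>> kRegisterF32(
    reinterpret_cast<void*>(&AllReduceF32Impl));

int main() {
  // Launch-time lookup, mirroring LoadKernel<AllReduceKernel<float>>().
  void* symbol =
      KernelRegistry::Get().Load(typeid(AllReduceKernelTrait<float>));
  std::printf("found symbol: %p\n", symbol);
}
```

The point of the indirection is that the runtime component (`all_reduce.cc`) no longer links against any .cu.cc translation unit directly; it only names the trait, and whichever platform library is linked in provides the symbol.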
PiperOrigin-RevId: 747843484 --- .../xla/xla/backends/gpu/runtime/BUILD | 50 +++++++++++++++ .../gpu/runtime/all_reduce.cc} | 64 ++++++++++++------- .../gpu/runtime/all_reduce.h} | 6 +- .../gpu/runtime/all_reduce_test.cc} | 2 +- third_party/xla/xla/service/gpu/kernels/BUILD | 59 ----------------- .../xla/xla/stream_executor/cuda/BUILD | 22 +++++++ .../cuda/all_reduce_kernel_cuda.cc | 37 +++++++++++ third_party/xla/xla/stream_executor/gpu/BUILD | 10 +++ .../gpu/all_reduce_kernel.h} | 27 +++++--- .../gpu/all_reduce_kernel_lib.cu.h} | 23 ++----- .../xla/xla/stream_executor/rocm/BUILD | 22 +++++++ .../rocm/all_reduce_kernel_rocm.cc | 37 +++++++++++ 12 files changed, 247 insertions(+), 112 deletions(-) rename third_party/xla/xla/{service/gpu/kernels/all_reduce_kernel.cc => backends/gpu/runtime/all_reduce.cc} (55%) rename third_party/xla/xla/{service/gpu/kernels/all_reduce_kernel.h => backends/gpu/runtime/all_reduce.h} (92%) rename third_party/xla/xla/{service/gpu/kernels/all_reduce_kernel_test.cc => backends/gpu/runtime/all_reduce_test.cc} (98%) create mode 100644 third_party/xla/xla/stream_executor/cuda/all_reduce_kernel_cuda.cc rename third_party/xla/xla/{service/gpu/kernels/all_reduce_kernel_common.h => stream_executor/gpu/all_reduce_kernel.h} (52%) rename third_party/xla/xla/{service/gpu/kernels/all_reduce_kernel.cu.cc => stream_executor/gpu/all_reduce_kernel_lib.cu.h} (73%) create mode 100644 third_party/xla/xla/stream_executor/rocm/all_reduce_kernel_rocm.cc diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index ea7cb55a8059f5..b7c1085d8b2a1d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -1428,6 +1428,56 @@ xla_test( ], ) +cc_library( + name = "all_reduce", + srcs = ["all_reduce.cc"], + hdrs = ["all_reduce.h"], + deps = [ + "//xla:shape_util", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:all_reduce_kernel", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "all_reduce_test", + srcs = ["all_reduce_test.cc"], + backends = ["gpu"], + disabled_backends = [], + deps = [ + ":all_reduce", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor/gpu:gpu_init", + "//xla/stream_executor/host:host_platform", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "ragged_all_to_all", srcs = ["ragged_all_to_all.cc"], diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cc b/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc similarity index 55% rename from third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cc rename to 
third_party/xla/xla/backends/gpu/runtime/all_reduce.cc
index d17217e82ead88..c9d654c7a3c2be 100644
--- a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cc
+++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "xla/service/gpu/kernels/all_reduce_kernel.h"
+#include "xla/backends/gpu/runtime/all_reduce.h"
 
 #include <array>
 #include <cstdint>
@@ -22,9 +22,12 @@ limitations under the License.
 #include "absl/algorithm/container.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
-#include "xla/service/gpu/kernels/all_reduce_kernel_common.h"
+#include "xla/primitive_util.h"
 #include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/gpu/all_reduce_kernel.h"
+#include "xla/stream_executor/gpu/gpu_kernel_registry.h"
 #include "xla/stream_executor/kernel.h"
 #include "xla/stream_executor/launch_dim.h"
 #include "xla/stream_executor/stream.h"
@@ -36,22 +39,27 @@ limitations under the License.
 namespace xla::gpu {
 namespace {
 
+template <typename T>
+absl::Status LaunchTypedKernel(
+    se::Stream* stream, se::StreamExecutor* executor,
+    const se::ThreadDim& thread_dims, const se::BlockDim& block_dims,
+    const std::array<void*, stream_executor::gpu::kMaxNumAllReduceInputPtrs>&
+        input_ptrs,
+    se::DeviceMemoryBase output_buffer, int64_t num_inputs,
+    int64_t num_elements) {
+  TF_ASSIGN_OR_RETURN(auto kernel,
+                      se::gpu::GpuKernelRegistry::GetGlobalRegistry()
+                          .LoadKernel<se::gpu::AllReduceKernel<T>>(executor));
 
-void* GetKernel(PrimitiveType element_type) {
-  switch (element_type) {
-    case F32:
-      return GetAllReduceKernel<float>();
-    default:
-      return nullptr;
-  }
+  return kernel.Launch(thread_dims, block_dims, stream, input_ptrs,
+                       output_buffer, num_inputs, num_elements);
 }
-
 }  // namespace
 
 bool IsAllReduceKernelSupported(int64_t num_outputs,
                                 PrimitiveType element_type) {
-  return num_outputs <= kMaxNumAllReduceInputPtrs &&
-         GetKernel(element_type) != nullptr;
+  return num_outputs <= stream_executor::gpu::kMaxNumAllReduceInputPtrs &&
+         element_type == F32;
 }
 
 absl::Status RunAllReduceKernel(
@@ -59,7 +67,7 @@ absl::Status RunAllReduceKernel(
     absl::Span<const se::DeviceMemoryBase> input_buffers,
     se::DeviceMemoryBase output_buffer, int64_t num_inputs,
     int64_t num_elements) {
-  if (input_buffers.size() > kMaxNumAllReduceInputPtrs) {
+  if (input_buffers.size() > stream_executor::gpu::kMaxNumAllReduceInputPtrs) {
     return absl::InvalidArgumentError(
         "Number of input pointers exceeds the maximum supported number of "
         "input pointers.");
@@ -70,22 +78,30 @@ absl::Status RunAllReduceKernel(
   // TODO(b/383125489): Fine tune the block and thread dimensions.
   static constexpr size_t kBlocks = 8;
   static constexpr size_t kThreads = 512;
+  se::ThreadDim thread_dims(kThreads, 1, 1);
+  se::BlockDim block_dims(kBlocks, 1, 1);
 
-  TF_ASSIGN_OR_RETURN(
-      auto kernel,
-      (se::TypedKernelFactory<std::array<void*, kMaxNumAllReduceInputPtrs>,
-                              se::DeviceMemoryBase, int64_t,
-                              int64_t>::Create(executor, "one_shot_all_reduce",
-                                               GetKernel(element_type))));
-
-  std::array<void*, kMaxNumAllReduceInputPtrs> input_ptrs;
+  std::array<void*, stream_executor::gpu::kMaxNumAllReduceInputPtrs>
+      input_ptrs;
   absl::c_transform(
       input_buffers, input_ptrs.begin(),
       [](se::DeviceMemoryBase buffer) { return buffer.opaque(); });
 
-  return kernel.Launch(se::ThreadDim(kThreads, 1, 1),
-                       se::BlockDim(kBlocks, 1, 1), stream, input_ptrs,
-                       output_buffer, num_inputs, num_elements);
+  auto launch_kernel = [&](auto type) -> absl::Status {
+    using T = decltype(type);
+    return LaunchTypedKernel<T>(stream, executor, thread_dims, block_dims,
+                                input_ptrs, output_buffer, num_inputs,
+                                num_elements);
+  };
+
+  switch (element_type) {
+    case F32:
+      return launch_kernel(float{});
+    default:
+      return absl::InvalidArgumentError(
+          absl::StrCat("Unsupported element type: ",
+                       primitive_util::LowercasePrimitiveTypeName(element_type),
+                       " for AllReduce kernel."));
+  }
 }
 
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.h b/third_party/xla/xla/backends/gpu/runtime/all_reduce.h
similarity index 92%
rename from third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.h
rename to third_party/xla/xla/backends/gpu/runtime/all_reduce.h
index 7c354752e545f9..df0888a6b179c6 100644
--- a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.h
+++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_H_
-#define XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_H_
+#ifndef XLA_BACKENDS_GPU_RUNTIME_ALL_REDUCE_H_
+#define XLA_BACKENDS_GPU_RUNTIME_ALL_REDUCE_H_
 
 #include <cstdint>
@@ -51,4 +51,4 @@ absl::Status RunAllReduceKernel(
 
 }  // namespace xla::gpu
 
-#endif  // XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_H_
+#endif  // XLA_BACKENDS_GPU_RUNTIME_ALL_REDUCE_H_
diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_test.cc b/third_party/xla/xla/backends/gpu/runtime/all_reduce_test.cc
similarity index 98%
rename from third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_test.cc
rename to third_party/xla/xla/backends/gpu/runtime/all_reduce_test.cc
index 32e3ac51b4a558..b496188b31f705 100644
--- a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_test.cc
+++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/ -#include "xla/service/gpu/kernels/all_reduce_kernel.h" +#include "xla/backends/gpu/runtime/all_reduce.h" #include #include diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 95eba50628a334..7df0e4c78e04b3 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -262,65 +262,6 @@ xla_test( ], ) -cc_library( - name = "all_reduce_kernel", - srcs = ["all_reduce_kernel.cc"], - hdrs = ["all_reduce_kernel.h"], - tags = ["gpu"], - visibility = [":friends"], - deps = [ - ":all_reduce_kernel_gpu", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:kernel", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor:stream", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor:typed_kernel_factory", - "//xla/tsl/platform:statusor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - ], -) - -gpu_kernel_library( - name = "all_reduce_kernel_gpu", - srcs = ["all_reduce_kernel.cu.cc"], - hdrs = ["all_reduce_kernel_common.h"], - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", # build_cleaner: keep - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -xla_test( - name = "all_reduce_kernel_test", - srcs = ["all_reduce_kernel_test.cc"], - backends = ["gpu"], - deps = [ - ":all_reduce_kernel", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_handle", - "//xla/stream_executor:platform", - "//xla/stream_executor:platform_manager", - "//xla/stream_executor:stream", - "//xla/stream_executor/gpu:gpu_init", - "//xla/stream_executor/host:host_platform", - "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:test", - "//xla/tsl/platform:test_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest", - ], -) - #===--------------------------------------------------------------------------------------------===# # CUTLASS Gemm <-> xla::gpu::kernel::CustomKernel adaptor #===--------------------------------------------------------------------------------------------===# diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 5f49fe61368c14..91e9bd16c352b9 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1110,6 +1110,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":all_reduce_kernel_cuda", ":buffer_comparator_kernel_cuda", ":cublas_plugin", ":cuda_platform", @@ -2041,3 +2042,24 @@ cuda_library( ], alwayslink = 1, ) + +cuda_library( + name = "all_reduce_kernel_cuda", + srcs = [ + "all_reduce_kernel_cuda.cc", + "//xla/stream_executor/gpu:all_reduce_kernel_lib.cu.h", + ], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":cuda_platform_id", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor/gpu:all_reduce_kernel", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "@local_config_cuda//cuda:cuda_headers", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/stream_executor/cuda/all_reduce_kernel_cuda.cc 
b/third_party/xla/xla/stream_executor/cuda/all_reduce_kernel_cuda.cc
new file mode 100644
index 00000000000000..8c1069ddb87175
--- /dev/null
+++ b/third_party/xla/xla/stream_executor/cuda/all_reduce_kernel_cuda.cc
@@ -0,0 +1,37 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+#include
+
+#include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/gpu/all_reduce_kernel.h"
+#include "xla/stream_executor/gpu/all_reduce_kernel_lib.cu.h"
+#include "xla/stream_executor/gpu/gpu_kernel_registry.h"
+
+#define REGISTER_ALL_REDUCE_KERNEL(TYPE)                                     \
+  GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY(                            \
+      AllReduceKernelCuda##TYPE,                                             \
+      stream_executor::gpu::AllReduceKernel<TYPE>,                           \
+      stream_executor::cuda::kCudaPlatformId, ([] {                          \
+        stream_executor::MultiKernelLoaderSpec spec(4);                      \
+        spec.AddInProcessSymbol(                                             \
+            absl::bit_cast<void*>(                                           \
+                &stream_executor::gpu::AllReduceKernelImpl<TYPE>),           \
+            "one_shot_all_reduce_" #TYPE);                                   \
+        return spec;                                                         \
+      }));
+
+// Register the kernel for different types using the macro
+REGISTER_ALL_REDUCE_KERNEL(float);
diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD
index c2699334093c9a..f72bc7ced73c13 100644
--- a/third_party/xla/xla/stream_executor/gpu/BUILD
+++ b/third_party/xla/xla/stream_executor/gpu/BUILD
@@ -854,6 +854,7 @@ cc_library(
 
 exports_files([
     "buffer_comparator_kernel_lib.cu.h",
+    "all_reduce_kernel_lib.cu.h",
     "ragged_all_to_all_kernel_lib.cu.h",
 ])
 
@@ -874,3 +875,12 @@ cc_library(
         "//xla/stream_executor:kernel",
     ],
 )
+
+cc_library(
+    name = "all_reduce_kernel",
+    hdrs = ["all_reduce_kernel.h"],
+    deps = [
+        "//xla/stream_executor:device_memory",
+        "//xla/stream_executor:kernel",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_common.h b/third_party/xla/xla/stream_executor/gpu/all_reduce_kernel.h
similarity index 52%
rename from third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_common.h
rename to third_party/xla/xla/stream_executor/gpu/all_reduce_kernel.h
index dc199258e58dd9..5144d17a36a05c 100644
--- a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel_common.h
+++ b/third_party/xla/xla/stream_executor/gpu/all_reduce_kernel.h
@@ -13,22 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_COMMON_H_
-#define XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_COMMON_H_
+#ifndef XLA_STREAM_EXECUTOR_GPU_ALL_REDUCE_KERNEL_H_
+#define XLA_STREAM_EXECUTOR_GPU_ALL_REDUCE_KERNEL_H_
 
+#include <array>
 #include <cstdint>
 
-namespace xla::gpu {
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/kernel.h"
+
+namespace stream_executor::gpu {
 
 // The maximum number of input pointers that can be passed to the all-reduce
 // kernel.
inline constexpr int64_t kMaxNumAllReduceInputPtrs = 8;
 
-// Returns a pointer to the all-reduce kernel for the given element type.
-// Returns nullptr if the element type is not supported.
-template <typename T>
-void* GetAllReduceKernel();
+// Defines a trait for the AllReduce kernel that can be used to register
+// and look up the kernel in the GPU kernel registry.
+template <typename T>
+struct AllReduceKernel {
+  using KernelType =
+      stream_executor::TypedKernel<std::array<T*, kMaxNumAllReduceInputPtrs>,
+                                   stream_executor::DeviceMemoryBase, int64_t,
+                                   int64_t>;
+};
 
-}  // namespace xla::gpu
+}  // namespace stream_executor::gpu
 
-#endif  // XLA_SERVICE_GPU_KERNELS_ALL_REDUCE_KERNEL_COMMON_H_
+#endif  // XLA_STREAM_EXECUTOR_GPU_ALL_REDUCE_KERNEL_H_
diff --git a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cu.cc b/third_party/xla/xla/stream_executor/gpu/all_reduce_kernel_lib.cu.h
similarity index 73%
rename from third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cu.cc
rename to third_party/xla/xla/stream_executor/gpu/all_reduce_kernel_lib.cu.h
index f819ab286a3390..f1568548f1748c 100644
--- a/third_party/xla/xla/service/gpu/kernels/all_reduce_kernel.cu.cc
+++ b/third_party/xla/xla/stream_executor/gpu/all_reduce_kernel_lib.cu.h
@@ -12,17 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#ifndef XLA_STREAM_EXECUTOR_GPU_ALL_REDUCE_KERNEL_LIB_CU_H_
+#define XLA_STREAM_EXECUTOR_GPU_ALL_REDUCE_KERNEL_LIB_CU_H_
 
 #include <array>
 #include <cstdint>
 
-#include "xla/service/gpu/kernels/all_reduce_kernel_common.h"
+#include "xla/stream_executor/gpu/all_reduce_kernel.h"
 
-namespace xla::gpu {
-namespace {
+namespace stream_executor::gpu {
 
 template <typename T>
-__global__ void AllReduceKernel(
+__global__ void AllReduceKernelImpl(
     std::array<T*, kMaxNumAllReduceInputPtrs> input_ptrs,
     T* __restrict__ output_ptr, int64_t num_inputs, int64_t num_elements) {
   int64_t offset = blockIdx.x * blockDim.x + threadIdx.x;
@@ -47,16 +48,6 @@ __global__ void AllReduceKernel(
   }
 }
 
-}  // namespace
+}  // namespace stream_executor::gpu
 
-template <typename T>
-void* GetAllReduceKernel() {
-  return reinterpret_cast<  // REINTERPRET_CAST_OK=tsl::safe_reinterpret_cast
-                            // doesn't support this cast, but it's necessary to
-                            // conform to se::TypedKernelFactory<>::Create().
- void*>(&AllReduceKernel); -} - -template void* GetAllReduceKernel(); - -} // namespace xla::gpu +#endif // XLA_STREAM_EXECUTOR_GPU_ALL_REDUCE_KERNEL_LIB_CU_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 787389ea4344b4..d98a17bc663f28 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -818,6 +818,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":all_reduce_kernel_rocm", ":amdhipblaslt_plugin", ":buffer_comparator_kernel_rocm", ":hipfft_plugin", @@ -1137,3 +1138,24 @@ rocm_library( ], alwayslink = 1, ) + +rocm_library( + name = "all_reduce_kernel_rocm", + srcs = [ + "all_reduce_kernel_rocm.cc", + "//xla/stream_executor/gpu:all_reduce_kernel_lib.cu.h", + ], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "gpu", + "rocm-only", + ], + deps = [ + ":rocm_platform_id", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor/gpu:all_reduce_kernel", + "//xla/stream_executor/gpu:gpu_kernel_registry", + "@local_config_rocm//rocm:rocm_headers", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/stream_executor/rocm/all_reduce_kernel_rocm.cc b/third_party/xla/xla/stream_executor/rocm/all_reduce_kernel_rocm.cc new file mode 100644 index 00000000000000..3f9ef765558614 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/all_reduce_kernel_rocm.cc @@ -0,0 +1,37 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "xla/stream_executor/gpu/all_reduce_kernel.h" +#include "xla/stream_executor/gpu/all_reduce_kernel_lib.cu.h" +#include "xla/stream_executor/gpu/gpu_kernel_registry.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" + +#define REGISTER_ALL_REDUCE_KERNEL(TYPE) \ + GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY( \ + AllReduceKernelRocm##TYPE, stream_executor::gpu::AllReduceKernel, \ + stream_executor::rocm::kROCmPlatformId, ([] { \ + stream_executor::MultiKernelLoaderSpec spec(4); \ + spec.AddInProcessSymbol( \ + absl::bit_cast( \ + &stream_executor::gpu::AllReduceKernelImpl), \ + "one_shot_all_reduce_" #TYPE); \ + return spec; \ + })); + +// Register the kernel for different types using the macro +REGISTER_ALL_REDUCE_KERNEL(float); From 921c0337fdb79c64ab2b550293f3d57971812164 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Tue, 15 Apr 2025 06:57:27 -0700 Subject: [PATCH 0807/1324] [XLA:GPU/Runtime] Allow XLA:GPU runtime to propagate by-value arguments to kernels. This is important as we need to propagate TMA descriptors using 128 byte arrays. 
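For illustration, a minimal sketch of a call site after this change (the helper name LaunchWithTma, the namespace alias, and the surrounding setup are hypothetical; CreateTensorMap, se::KernelArgument, se::PackKernelArgs, and Kernel::Launch are the entry points this patch touches):

#include <vector>

#include "absl/status/status.h"
#include "absl/types/span.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/tma_metadata.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/platform/statusor.h"

namespace se = ::stream_executor;

// Hypothetical helper: launches `kernel` with one device buffer and one
// by-value TMA descriptor argument.
absl::Status LaunchWithTma(se::StreamExecutor* executor, se::Stream* stream,
                           se::Kernel& kernel, se::DeviceMemoryBase buf,
                           const se::gpu::TmaDescriptor& tma_desc) {
  // CreateTensorMap now returns a 128-byte se::TensorMap by value instead of
  // a device allocation holding the CUtensorMap.
  TF_ASSIGN_OR_RETURN(se::TensorMap tensor_map,
                      executor->CreateTensorMap(tma_desc, buf.opaque()));

  // se::KernelArgument is a std::variant of DeviceMemoryBase and TensorMap,
  // so both argument kinds travel through the same span.
  std::vector<se::KernelArgument> args = {buf, tensor_map};
  TF_ASSIGN_OR_RETURN(
      auto packed, se::PackKernelArgs<se::KernelArgument>(
                       absl::MakeConstSpan(args), /*shared_mem_bytes=*/0));
  return kernel.Launch(se::ThreadDim(), se::BlockDim(), stream, *packed);
}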
PiperOrigin-RevId: 747856919 --- .../xla/xla/backends/gpu/runtime/BUILD | 2 + .../gpu/runtime/command_buffer_cmd.cc | 5 +- .../xla/backends/gpu/runtime/kernel_thunk.cc | 70 ++++++++------ third_party/xla/xla/service/gpu/kernels/BUILD | 1 + .../service/gpu/kernels/ptx_custom_kernel.cc | 5 +- .../xla/service/gpu/stream_executor_util.cc | 26 ++---- .../xla/service/gpu/stream_executor_util.h | 16 +--- third_party/xla/xla/service/gpu/tests/BUILD | 4 + .../gpu/tests/dynamic_shared_memory_test.cc | 12 ++- third_party/xla/xla/stream_executor/BUILD | 2 + .../xla/stream_executor/cuda/cuda_executor.cc | 10 +- .../xla/stream_executor/cuda/cuda_executor.h | 8 +- .../xla/xla/stream_executor/device_memory.h | 6 ++ third_party/xla/xla/stream_executor/gpu/BUILD | 2 + .../stream_executor/gpu/gpu_kernel_test.cc | 92 ++++++++++++++++++- third_party/xla/xla/stream_executor/kernel.h | 50 +++++++++- .../xla/xla/stream_executor/kernel_test.cc | 6 +- .../xla/xla/stream_executor/stream_executor.h | 8 +- 18 files changed, 240 insertions(+), 85 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index b7c1085d8b2a1d..07b0df207aa4c8 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -681,6 +681,7 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:kernel", "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:tma_metadata", "//xla/tsl/platform:statusor", @@ -688,6 +689,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 7b5bf945c2922a..38e43d30be22ec 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -731,8 +731,9 @@ absl::StatusOr LaunchCmd::Record( buffers.push_back(buf); } - TF_ASSIGN_OR_RETURN(auto kernel_args, - se::PackKernelArgs(buffers, shmem_bytes_)); + TF_ASSIGN_OR_RETURN( + auto kernel_args, + se::PackKernelArgs(buffers, shmem_bytes_)); return Handle( std::move(record_action), diff --git a/third_party/xla/xla/backends/gpu/runtime/kernel_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/kernel_thunk.cc index 97df5aa688fdb3..c83872d3e53b33 100644 --- a/third_party/xla/xla/backends/gpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/kernel_thunk.cc @@ -20,9 +20,11 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" @@ -40,6 +42,7 @@ limitations under the License. 
#include "xla/stream_executor/gpu/tma_metadata.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/platform/statusor.h" #include "tsl/platform/logging.h" @@ -103,19 +106,29 @@ absl::Status KernelThunk::Initialize(const InitializeParams& params) { } static void PrintBufferContents( - se::Stream* stream, absl::Span buffer_args) { + se::Stream* stream, absl::Span kernel_args) { int input_idx = 0; - for (const se::DeviceMemoryBase& buf : buffer_args) { - auto host_buffer = std::make_unique(buf.size()); - CHECK_OK(stream->Memcpy(host_buffer.get(), buf, buf.size())); - CHECK_OK(stream->BlockHostUntilDone()); - - std::string buffer_contents; - for (int i = 0; i < buf.size(); i++) { - absl::StrAppendFormat(&buffer_contents, "%x ", - static_cast(host_buffer[i])); + for (const auto& arg : kernel_args) { + if (std::holds_alternative(arg)) { + se::DeviceMemoryBase buf = std::get(arg); + + auto host_buffer = std::make_unique(buf.size()); + CHECK_OK(stream->Memcpy(host_buffer.get(), buf, buf.size())); + CHECK_OK(stream->BlockHostUntilDone()); + + std::string buffer_contents; + for (int i = 0; i < buf.size(); i++) { + absl::StrAppendFormat(&buffer_contents, "%x ", + static_cast(host_buffer[i])); + } + VLOG(100) << "BUF(" << input_idx++ << ") = " << buffer_contents; + } else { + se::TensorMap tensor_map = std::get(arg); + VLOG(100) << "TENSOR_MAP(" << input_idx++ << ") = "; + for (auto element : tensor_map.storage) { + VLOG(100) << absl::StrFormat("%x ", static_cast(element)); + } } - VLOG(100) << "BUF(" << input_idx++ << ") = " << buffer_contents; } } @@ -141,38 +154,39 @@ absl::Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { } VLOG(3) << "Launching " << kernel->name(); - absl::InlinedVector buffer_args; + absl::InlinedVector, 4> + kernel_args; stream_executor::gpu::TmaMetadata tma_metadata = tma_metadata_.value_or(stream_executor::gpu::TmaMetadata{}); for (const auto& [idx, arg] : llvm::enumerate(args_)) { se::DeviceMemoryBase buf = params.buffer_allocations->GetDeviceAddress(arg); VLOG(3) << " Arg: alloc #" << arg.index() << ", offset: " << arg.offset() << ": " << buf.opaque() << " (" << buf.size() << "B)"; + auto it = tma_metadata.arg_index_to_tma_info.find(idx); if (it != tma_metadata.arg_index_to_tma_info.end()) { + // TMA descriptor argument. stream_executor::gpu::TmaDescriptor tma_desc = it->second; - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase tensor_map, + TF_ASSIGN_OR_RETURN(se::TensorMap tensor_map, executor->CreateTensorMap(tma_desc, buf.opaque())); VLOG(3) << " Using TensorMap for arg #" << arg.index() << ": " - << tma_desc.ToString() << "; buffer: " << tensor_map.opaque() - << " (" << tensor_map.size() << "B)"; - buffer_args.push_back(tensor_map); + << tma_desc.ToString(); + kernel_args.push_back(std::move(tensor_map)); } else { - buffer_args.push_back(buf); + // Buffer argument. 
+ kernel_args.push_back(buf); } } if (VLOG_IS_ON(100)) { - PrintBufferContents(stream, buffer_args); + PrintBufferContents(stream, kernel_args); } - if (cluster_dim.has_value()) { - return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, - cluster_dim.value(), stream); - } else { - return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, - stream); - } + return ExecuteKernelOnStream( + *kernel, + absl::Span>( + kernel_args.data(), kernel_args.size()), + launch_dimensions, cluster_dim, stream); } //===----------------------------------------------------------------------===// @@ -233,7 +247,11 @@ absl::Status CustomKernelThunk::ExecuteOnStream(const ExecuteParams& params) { } if (VLOG_IS_ON(100)) { - PrintBufferContents(params.stream, buffer_args); + absl::InlinedVector kernel_args; + for (const auto& arg : buffer_args) { + kernel_args.push_back(arg); + } + PrintBufferContents(params.stream, kernel_args); } se::KernelArgsDeviceMemoryArray args(buffer_args, diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 7df0e4c78e04b3..056fcc92315514 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -517,6 +517,7 @@ cc_library( ], deps = [ ":custom_kernel", + "//xla/stream_executor:device_memory", "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", diff --git a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc index 21e6e56b7c7113..dad8648f321001 100644 --- a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -34,8 +35,8 @@ absl::StatusOr> KernelArgsPacking(const se::Kernel &kernel, const se::KernelArgs &args) { auto *mem_args = se::Cast(&args); - return se::PackKernelArgs(mem_args->device_memory_args(), - mem_args->number_of_shared_bytes()); + return se::PackKernelArgs( + mem_args->device_memory_args(), mem_args->number_of_shared_bytes()); } // Note: Make sure that the kernel_name matches the kernel name in the ptx, diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index 4aaef2204a5913..27f927d3ba18c4 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -57,6 +57,7 @@ limitations under the License. 
#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/typed_kernel_factory.h" #include "xla/tsl/protobuf/dnn.pb.h" #include "xla/tsl/util/proto/proto_utils.h" @@ -386,31 +387,22 @@ absl::StatusOr> CreateKernel( return kernel; } -absl::Status ExecuteKernelOnStream(se::Kernel& kernel, - absl::Span args, - const LaunchDimensions& dims, - se::Stream* stream) { +absl::Status ExecuteKernelOnStream( + se::Kernel& kernel, absl::Span args, + const LaunchDimensions& dims, + const std::optional& cluster_dim, se::Stream* stream) { TF_ASSIGN_OR_RETURN( std::unique_ptr kernel_args, se::PackKernelArgs(args, kernel.metadata())); + if (cluster_dim.has_value()) { + return kernel.Launch(dims.thread_counts_per_block(), dims.block_counts(), + cluster_dim.value(), stream, *kernel_args); + } return kernel.Launch(dims.thread_counts_per_block(), dims.block_counts(), stream, *kernel_args); } -absl::Status ExecuteKernelOnStream(se::Kernel& kernel, - absl::Span args, - const LaunchDimensions& dims, - const se::ClusterDim& cluster_dim, - se::Stream* stream) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr kernel_args, - se::PackKernelArgs(args, kernel.metadata())); - - return kernel.Launch(dims.thread_counts_per_block(), dims.block_counts(), - cluster_dim, stream, *kernel_args); -} - // Unimplemented for integers yet. template typename std::enable_if::value, diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.h b/third_party/xla/xla/service/gpu/stream_executor_util.h index 87a91c0bd10fbb..e261be542016fd 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.h +++ b/third_party/xla/xla/service/gpu/stream_executor_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -104,17 +105,10 @@ absl::StatusOr> CreateKernel( uint32_t shared_mem_bytes = 0); // Runs loaded kernel on the stream with the provided arguments. -absl::Status ExecuteKernelOnStream(se::Kernel& kernel, - absl::Span args, - const LaunchDimensions& dims, - se::Stream* stream); - -// Runs loaded kernel on the stream with the provided arguments. -absl::Status ExecuteKernelOnStream(se::Kernel& kernel, - absl::Span args, - const LaunchDimensions& dims, - const se::ClusterDim& cluster_dim, - se::Stream* stream); +absl::Status ExecuteKernelOnStream( + se::Kernel& kernel, absl::Span args, + const LaunchDimensions& dims, + const std::optional& cluster_dim, se::Stream* stream); // Initializes `buffer` with random data on `stream`. // `rng_state` is an inout parameter for the pseudorandom generator state. 
diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 06c4e4b03a9916..bfd970cfcd3e14 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -693,11 +693,15 @@ xla_test( "//xla:xla_proto_cc", "//xla/stream_executor:device_description", "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/service/gpu/tests/dynamic_shared_memory_test.cc b/third_party/xla/xla/service/gpu/tests/dynamic_shared_memory_test.cc index 2e1f9d80611285..b384c51b54634b 100644 --- a/third_party/xla/xla/service/gpu/tests/dynamic_shared_memory_test.cc +++ b/third_party/xla/xla/service/gpu/tests/dynamic_shared_memory_test.cc @@ -15,21 +15,24 @@ limitations under the License. #include #include +#include #include +#include #include "absl/log/log.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace gpu { @@ -173,10 +176,11 @@ TEST(SharedMemoryUseTest, ArrayReversalWorks) { se::DeviceMemory dev_n_rows = executor->AllocateScalar(); TF_CHECK_OK(stream->Memcpy(&dev_n_rows, &n_rows, sizeof(uint32_t))); TF_CHECK_OK(stream->BlockHostUntilDone()); + TF_CHECK_OK(ExecuteKernelOnStream( *kernel, {device_buffer, dev_n_cols, dev_n_rows}, {/*block_x_count=*/1, /*thread_x_count_per_block=*/n_cols}, - stream.get())); + /*cluster_dim=*/{}, stream.get())); TF_CHECK_OK(stream->BlockHostUntilDone()); TF_CHECK_OK( stream->Memcpy(host_buffer.data(), device_buffer, buffer_size_bytes)); diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 30611b377bc593..7154ae5187c98c 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -762,6 +762,8 @@ xla_cc_test( ":typed_kernel_factory", "//xla/stream_executor/host:host_platform", "//xla/tsl/platform:test_main", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ], diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index e50d8f791476cf..109bf5cc592ed4 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -21,6 +21,7 @@ limitations under the License. 
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -1422,8 +1423,8 @@ absl::StatusOr<CudaKernel*> CudaExecutor::GetCudaKernel(
   return static_cast<CudaKernel*>(*it);
 }
 
-absl::StatusOr<DeviceMemoryBase> CudaExecutor::CreateTensorMap(
-    TmaDescriptor tma_desc, void* global_address) {
+absl::StatusOr<TensorMap> CudaExecutor::CreateTensorMap(TmaDescriptor tma_desc,
+                                                        void* global_address) {
   TF_ASSIGN_OR_RETURN(CUtensorMapDataType data_type,
                       GetTensorMapDataType(tma_desc.element_size()));
   CUtensorMapSwizzle swizzle = GetTensorMapSwizzle(tma_desc.swizzle());
@@ -1447,10 +1448,7 @@ absl::StatusOr<DeviceMemoryBase> CudaExecutor::CreateTensorMap(
         "Failed to create tensormap with cuTensorMapEncodeTiled: %s",
         error_message));
   }
-  DeviceMemoryBase device_tensor_map = Allocate(sizeof(tensor_map), 0);
-  TF_RETURN_IF_ERROR(
-      SynchronousMemcpy(&device_tensor_map, &tensor_map, sizeof(tensor_map)));
-  return device_tensor_map;
+  return absl::bit_cast<TensorMap>(tensor_map);
 }
 
 }  // namespace gpu
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h
index 99a150a956b118..b0c8e5b9154aa6 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h
@@ -133,10 +133,10 @@ class CudaExecutor : public GpuExecutor {
   absl::StatusOr<CudaKernel*> GetCudaKernel(const Kernel* kernel);
 
   // Creates, allocates, and copies a CUtensorMap object for the given TMA
-  // descriptor. Returns a DeviceMemoryBase pointing to the allocated
-  // CUtensorMap object to be used as an argument to a kernel.
-  absl::StatusOr<DeviceMemoryBase> CreateTensorMap(
-      TmaDescriptor tma_desc, void* global_address) override;
+  // descriptor. Returns a TensorMap, which is 128 bytes of storage, to be
+  // passed by value to the kernel.
+  absl::StatusOr<TensorMap> CreateTensorMap(TmaDescriptor tma_desc,
+                                            void* global_address) override;
 
   absl::StatusOr<std::unique_ptr<MemoryAllocator>> CreateMemoryAllocator(
       MemoryType type) override;
diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h
index d599faadf7562f..fe43687094794d 100644
--- a/third_party/xla/xla/stream_executor/device_memory.h
+++ b/third_party/xla/xla/stream_executor/device_memory.h
@@ -174,6 +174,12 @@ class DeviceMemory final : public DeviceMemoryBase {
   DeviceMemory(void *opaque, uint64_t size) : DeviceMemoryBase(opaque, size) {}
 };
 
+// TensorMap is a wrapper around 128 bytes of storage. It is used to pass TMA
+// descriptors to the kernel.
+struct TensorMap { + alignas(64) std::byte storage[128]; +}; + } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_DEVICE_MEMORY_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index f72bc7ced73c13..721cb7a298826b 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -604,6 +604,7 @@ xla_test( ":gpu_test_kernels_fatbin", "//xla/service:platform_util", "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", @@ -613,6 +614,7 @@ xla_test( "//xla/stream_executor:typed_kernel_factory", "//xla/stream_executor/rocm:rocm_platform_id", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc index 3f5500cb60b46e..ae9a2b602555f4 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -27,6 +29,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" +#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" @@ -36,8 +39,7 @@ limitations under the License. 
#include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/typed_kernel_factory.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" namespace stream_executor::gpu { namespace { @@ -106,5 +108,91 @@ TEST_F(GpuKernelTest, LoadAndRunKernelFromSymbol) { RunAddI32Kernel(GetAddI32KernelSpec()); } +TEST_F(GpuKernelTest, ArrayArgByValue) { + constexpr absl::string_view copy_kernel = R"( + .version 8.0 + .target sm_60 + .address_size 64 + + .visible .entry copy_kernel( + .param .u64 foo_param_0, + .param .align 1 .b8 foo_param_1[16] +) +{ + .reg .b16 %rs<17>; + .reg .b64 %rd<3>; + .loc 1 5 0 + + ld.param.u64 %rd1, [foo_param_0]; + cvta.to.global.u64 %rd2, %rd1; + ld.param.u8 %rs1, [foo_param_1+15]; + ld.param.u8 %rs2, [foo_param_1+14]; + ld.param.u8 %rs3, [foo_param_1+13]; + ld.param.u8 %rs4, [foo_param_1+12]; + ld.param.u8 %rs5, [foo_param_1+11]; + ld.param.u8 %rs6, [foo_param_1+10]; + ld.param.u8 %rs7, [foo_param_1+9]; + ld.param.u8 %rs8, [foo_param_1+8]; + ld.param.u8 %rs9, [foo_param_1+7]; + ld.param.u8 %rs10, [foo_param_1+6]; + ld.param.u8 %rs11, [foo_param_1+5]; + ld.param.u8 %rs12, [foo_param_1+4]; + ld.param.u8 %rs13, [foo_param_1+3]; + ld.param.u8 %rs14, [foo_param_1+2]; + ld.param.u8 %rs15, [foo_param_1+1]; + ld.param.u8 %rs16, [foo_param_1]; + .loc 1 6 5 + st.global.u8 [%rd2], %rs16; + st.global.u8 [%rd2+1], %rs15; + st.global.u8 [%rd2+2], %rs14; + st.global.u8 [%rd2+3], %rs13; + st.global.u8 [%rd2+4], %rs12; + st.global.u8 [%rd2+5], %rs11; + st.global.u8 [%rd2+6], %rs10; + st.global.u8 [%rd2+7], %rs9; + st.global.u8 [%rd2+8], %rs8; + st.global.u8 [%rd2+9], %rs7; + st.global.u8 [%rd2+10], %rs6; + st.global.u8 [%rd2+11], %rs5; + st.global.u8 [%rd2+12], %rs4; + st.global.u8 [%rd2+13], %rs3; + st.global.u8 [%rd2+14], %rs2; + st.global.u8 [%rd2+15], %rs1; + .loc 1 7 1 + ret; + } + )"; + + MultiKernelLoaderSpec spec(/*arity=*/2); + spec.AddCudaPtxInMemory(copy_kernel, "copy_kernel"); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor_->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto kernel, executor_->LoadKernel(spec)); + + constexpr int64_t kLength = 16; + + DeviceMemory dst = executor_->AllocateArray(kLength, 0); + TF_ASSERT_OK(stream->MemZero(&dst, kLength)); + + struct ByValArg { + std::byte storage[16]; + }; + ByValArg arg; + int i = 0; + for (auto& element : arg.storage) { + element = static_cast(i++); + } + + // Launch kernel. + auto args = stream_executor::PackKernelArgs(/*shmem_bytes=*/0, dst, arg); + TF_ASSERT_OK(kernel->Launch(ThreadDim(), BlockDim(), stream.get(), *args)); + + // Copy data back to host. + std::byte dst_host[16] = {}; + TF_ASSERT_OK(stream->Memcpy(dst_host, dst, kLength)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + EXPECT_THAT(dst_host, ::testing::ElementsAreArray(arg.storage)); +} } // namespace } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index 2942eb4153e49e..c8ec5df54d52f7 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -79,6 +79,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/inlined_vector.h" #include "absl/meta/type_traits.h" @@ -91,7 +92,6 @@ limitations under the License. 
#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" -#include "tsl/platform/logging.h" namespace stream_executor { @@ -528,6 +528,8 @@ class KernelArgsPackedArray : public KernelArgsPackedArrayBase, ArgsStorage { size_t number_of_argument_addresses_ = 0; }; +using KernelArgument = std::variant; + namespace internal { template std::unique_ptr PackKernelArgs( @@ -541,11 +543,49 @@ std::unique_ptr PackKernelArgs( } return packed; } + +template +std::unique_ptr PackKernelArgs( + absl::Span args, uint32_t shared_mem_bytes) { + auto contains_tensor_map = [](absl::Span args) -> bool { + return absl::c_any_of(args, [](const auto &arg) { + return std::holds_alternative(arg); + }); + }; + + if (contains_tensor_map(args)) { + auto packed = + std::make_unique>>(); + for (auto &buf : args) { + if (std::holds_alternative(buf)) { + // Buffer argument. + packed->add_device_memory_argument(std::get(buf)); + } else { + // TMA descriptor argument. + packed->add_argument(std::get(buf).storage); + } + } + if (shared_mem_bytes > 0) { + packed->add_shared_bytes(shared_mem_bytes); + } + return packed; + } + + // No TensorMap arguments -> Can use EmptyArgs. + auto packed = std::make_unique>(); + for (auto &buf : args) { + packed->add_device_memory_argument(std::get(buf)); + } + if (shared_mem_bytes > 0) { + packed->add_shared_bytes(shared_mem_bytes); + } + return packed; +} } // namespace internal +template inline absl::StatusOr> -PackKernelArgs(absl::Span args, - uint32_t shared_mem_bytes) { +PackKernelArgs(absl::Span args, uint32_t shared_mem_bytes) { static constexpr int kKernelArgsLimit = 1024; if (args.size() > kKernelArgsLimit) @@ -575,9 +615,9 @@ PackKernelArgs(absl::Span args, return internal::PackKernelArgs(args, shared_mem_bytes); } +template inline absl::StatusOr> -PackKernelArgs(absl::Span args, - const KernelMetadata &metadata) { +PackKernelArgs(absl::Span args, const KernelMetadata &metadata) { return PackKernelArgs(args, metadata.shared_memory_bytes().value_or(0)); } diff --git a/third_party/xla/xla/stream_executor/kernel_test.cc b/third_party/xla/xla/stream_executor/kernel_test.cc index a554785735d3cd..7dcb889e11967a 100644 --- a/third_party/xla/xla/stream_executor/kernel_test.cc +++ b/third_party/xla/xla/stream_executor/kernel_test.cc @@ -21,6 +21,8 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" +#include "benchmark/benchmark.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/platform.h" @@ -75,7 +77,7 @@ TEST(KernelTest, PackDeviceMemoryArguments) { DeviceMemoryBase a(reinterpret_cast(0x12345678)); DeviceMemoryBase b(reinterpret_cast(0x87654321)); - auto args = PackKernelArgs({a, b}, 0).value(); + auto args = PackKernelArgs({a, b}, 0).value(); ASSERT_EQ(args->number_of_arguments(), 2); auto packed = args->argument_addresses(); @@ -137,7 +139,7 @@ static void BM_PackDeviceMemoryArgs(benchmark::State& state) { } for (auto s : state) { - auto packed = PackKernelArgs(args, 0); + auto packed = PackKernelArgs(args, 0); benchmark::DoNotOptimize(packed); } } diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index 1e93114ccb059d..0724222e046180 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -325,11 +325,11 @@ class StreamExecutor { virtual bool SetArgumentLoggingMode(uint64_t mode) { return false; } // Creates, allocates, and copies a CUtensorMap object for the given TMA - // descriptor. Returns a DeviceMemoryBase pointing to the allocated - // CUtensorMap object to be used as an argument to a kernel. + // descriptor. Returns a TensorMap, which is 128 bytes of storage, to be + // passed by value to the kernel. // Only implemented on CUDA GPUs. - virtual absl::StatusOr CreateTensorMap( - gpu::TmaDescriptor tma_desc, void* global_address) { + virtual absl::StatusOr CreateTensorMap(gpu::TmaDescriptor tma_desc, + void* global_address) { return absl::UnimplementedError("Not Implemented"); } }; From c3bcea4fec9b79c4e929e4aa037e635242c8a3a5 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 15 Apr 2025 07:02:00 -0700 Subject: [PATCH 0808/1324] Deprecate a few `tsl::errors` functions in favor of their Abseil counterparts. PiperOrigin-RevId: 747858092 --- third_party/xla/xla/tsl/platform/BUILD | 1 + third_party/xla/xla/tsl/platform/errors.h | 58 +++++++++++++++++++---- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/tsl/platform/BUILD b/third_party/xla/xla/tsl/platform/BUILD index 16c69ae2a1a28a..a4f638a4d4fe85 100644 --- a/third_party/xla/xla/tsl/platform/BUILD +++ b/third_party/xla/xla/tsl/platform/BUILD @@ -392,6 +392,7 @@ cc_library( ":logging", ":macros", ":status", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", diff --git a/third_party/xla/xla/tsl/platform/errors.h b/third_party/xla/xla/tsl/platform/errors.h index 6ef49cfb1889b0..e86e77a39de69e 100644 --- a/third_party/xla/xla/tsl/platform/errors.h +++ b/third_party/xla/xla/tsl/platform/errors.h @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/base/attributes.h" +#include "absl/base/macros.h" #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" @@ -199,6 +201,7 @@ void AppendToMessage(absl::Status* status, Args... args) { // CANCELLED template +ABSL_DEPRECATED("Use absl::CancelledError() instead.") absl::Status Cancelled(Args... 
args) { return absl::CancelledError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -212,6 +215,7 @@ absl::Status CancelledWithPayloads( // InvalidArgument template +ABSL_DEPRECATED("Use absl::InvalidArgumentError() instead.") absl::Status InvalidArgument(Args... args) { return absl::InvalidArgumentError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -219,6 +223,7 @@ absl::Status InvalidArgument(Args... args) { // Specialized overloads to capture source location for up to four arguments. #if defined(PLATFORM_GOOGLE) template +ABSL_DEPRECATED("Use absl::InvalidArgumentError() instead.") absl::Status InvalidArgument( Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, absl::SourceLocation loc = absl::SourceLocation::current()) { @@ -230,6 +235,7 @@ absl::Status InvalidArgument( loc); } template +ABSL_DEPRECATED("Use absl::InvalidArgumentError() instead.") absl::Status InvalidArgument( Arg1 arg1, Arg2 arg2, Arg3 arg3, absl::SourceLocation loc = absl::SourceLocation::current()) { @@ -240,6 +246,7 @@ absl::Status InvalidArgument( loc); } template +ABSL_DEPRECATED("Use absl::InvalidArgumentError() instead.") absl::Status InvalidArgument( Arg1 arg1, Arg2 arg2, absl::SourceLocation loc = absl::SourceLocation::current()) { @@ -249,6 +256,7 @@ absl::Status InvalidArgument( loc); } template +ABSL_DEPRECATED("Use absl::InvalidArgumentError() instead.") absl::Status InvalidArgument( Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { return absl::InvalidArgumentError( @@ -272,6 +280,7 @@ inline absl::Status InvalidArgumentWithPayloads( // NotFound template +ABSL_DEPRECATED("Use absl::NotFoundError() instead.") absl::Status NotFound(Args... args) { return absl::NotFoundError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -279,9 +288,10 @@ absl::Status NotFound(Args... args) { // Specialized overloads to capture source location for up to three arguments. #if defined(PLATFORM_GOOGLE) template -absl::Status NotFound( - Arg1 arg1, Arg2 arg2, Arg3 arg3, - absl::SourceLocation loc = absl::SourceLocation::current()) { +ABSL_DEPRECATED("Use absl::NotFoundError() instead.") +absl::Status + NotFound(Arg1 arg1, Arg2 arg2, Arg3 arg3, + absl::SourceLocation loc = absl::SourceLocation::current()) { return absl::NotFoundError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), ::tsl::errors::internal::PrepareForStrCat(arg2), @@ -289,17 +299,20 @@ absl::Status NotFound( loc); } template -absl::Status NotFound( - Arg1 arg1, Arg2 arg2, - absl::SourceLocation loc = absl::SourceLocation::current()) { +ABSL_DEPRECATED("Use absl::NotFoundError() instead.") +absl::Status + NotFound(Arg1 arg1, Arg2 arg2, + absl::SourceLocation loc = absl::SourceLocation::current()) { return absl::NotFoundError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), ::tsl::errors::internal::PrepareForStrCat(arg2)), loc); } template -absl::Status NotFound( - Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { +ABSL_DEPRECATED("Use absl::NotFoundError() instead.") +absl::Status + NotFound(Arg1 arg1, + absl::SourceLocation loc = absl::SourceLocation::current()) { return absl::NotFoundError( ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)), loc); @@ -320,6 +333,7 @@ inline absl::Status NotFoundWithPayloads( // AlreadyExists template +ABSL_DEPRECATED("Use absl::AlreadyExistsError() instead.") absl::Status AlreadyExists(Args... 
args) { return absl::AlreadyExistsError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -332,6 +346,7 @@ inline absl::Status AlreadyExistsWithPayloads( // ResourceExhausted template +ABSL_DEPRECATED("Use absl::ResourceExhaustedError() instead.") absl::Status ResourceExhausted(Args... args) { return absl::ResourceExhaustedError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -345,6 +360,7 @@ inline absl::Status ResourceExhaustedWithPayloads( // Unavailable template +ABSL_DEPRECATED("Use absl::UnavailableError() instead.") absl::Status Unavailable(Args... args) { return absl::UnavailableError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -357,6 +373,7 @@ inline absl::Status UnavailableWithPayloads( // FailedPrecondition template +ABSL_DEPRECATED("Use absl::FailedPreconditionError() instead.") absl::Status FailedPrecondition(Args... args) { return absl::FailedPreconditionError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -370,6 +387,7 @@ inline absl::Status FailedPreconditionWithPayloads( // OutOfRange template +ABSL_DEPRECATED("Use absl::OutOfRangeError() instead.") absl::Status OutOfRange(Args... args) { return absl::OutOfRangeError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -382,6 +400,7 @@ inline absl::Status OutOfRangeWithPayloads( // Unimplemented template +ABSL_DEPRECATED("Use absl::UnimplementedError() instead.") absl::Status Unimplemented(Args... args) { return absl::UnimplementedError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -394,6 +413,7 @@ inline absl::Status UnimplementedWithPayloads( // Internal template +ABSL_DEPRECATED("Use absl::InternalError() instead.") absl::Status Internal(Args... args) { return absl::InternalError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -406,6 +426,7 @@ inline absl::Status InternalWithPayloads( // Aborted template +ABSL_DEPRECATED("Use absl::AbortedError() instead.") absl::Status Aborted(Args... args) { return absl::AbortedError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -418,6 +439,7 @@ inline absl::Status AbortedWithPayloads( // DeadlineExceeded template +ABSL_DEPRECATED("Use absl::DeadlineExceededError() instead.") absl::Status DeadlineExceeded(Args... args) { return absl::DeadlineExceededError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -430,6 +452,7 @@ inline absl::Status DeadlineExceededWithPayloads( // DataLoss template +ABSL_DEPRECATED("Use absl::DataLossError() instead.") absl::Status DataLoss(Args... args) { return absl::DataLossError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -442,6 +465,7 @@ inline absl::Status DataLossWithPayloads( // Unknown template +ABSL_DEPRECATED("Use absl::UnknownError() instead.") absl::Status Unknown(Args... args) { return absl::UnknownError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -453,6 +477,7 @@ inline absl::Status UnknownPayloads( } // PermissionDenied template +ABSL_DEPRECATED("Use absl::PermissionDeniedError() instead.") absl::Status PermissionDenied(Args... 
args) { return absl::PermissionDeniedError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -465,6 +490,7 @@ inline absl::Status PermissionDeniedWithPayloads( // Unauthenticated template +ABSL_DEPRECATED("Use absl::UnauthenticatedError() instead.") absl::Status Unauthenticated(Args... args) { return absl::UnauthenticatedError(::tsl::strings::StrCat( ::tsl::errors::internal::PrepareForStrCat(args)...)); @@ -475,51 +501,67 @@ inline absl::Status UnauthenticatedWithPayloads( return errors::Create(absl::StatusCode::kUnauthenticated, message, payloads); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsAborted(const absl::Status& status) { return absl::IsAborted(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsAlreadyExists(const absl::Status& status) { return absl::IsAlreadyExists(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsCancelled(const absl::Status& status) { return absl::IsCancelled(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsDataLoss(const absl::Status& status) { return absl::IsDataLoss(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsDeadlineExceeded(const absl::Status& status) { return absl::IsDeadlineExceeded(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsFailedPrecondition(const absl::Status& status) { return absl::IsFailedPrecondition(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsInternal(const absl::Status& status) { return absl::IsInternal(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsInvalidArgument(const absl::Status& status) { return absl::IsInvalidArgument(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsNotFound(const absl::Status& status) { return absl::IsNotFound(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsOutOfRange(const absl::Status& status) { return absl::IsOutOfRange(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsPermissionDenied(const absl::Status& status) { return absl::IsPermissionDenied(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsResourceExhausted(const absl::Status& status) { return absl::IsResourceExhausted(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsUnauthenticated(const absl::Status& status) { return absl::IsUnauthenticated(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsUnavailable(const absl::Status& status) { return absl::IsUnavailable(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsUnimplemented(const absl::Status& status) { return absl::IsUnimplemented(status); } +ABSL_DEPRECATE_AND_INLINE() inline bool IsUnknown(const absl::Status& status) { return absl::IsUnknown(status); } From cf76cb9995af787f742ada0d0fcc64dac85855fd Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 15 Apr 2025 17:09:12 +0200 Subject: [PATCH 0809/1324] [mlir][tosa] Fix fully connected legalization when output shape is dynamic (#90800) Currently the legalization uses the output shape to determine the 'shape' input of the final reshape operation inserted during the legalization. When the output type is not static, the 'shape' input contains '-1' values that represent dynamic dims from the output type shape. This commit uses information from the input type and "keep_num_dims" attribute to determine the final output shape to reshape to instead. This can help avoid creating a 'shape' input with non-static dimensions in some cases. 
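The new shape selection can be summarized with a standalone sketch (FcOutputShape is a hypothetical name; N and OC are assumed to be computed as before):

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Mirrors the logic added to ConvertTFLFullyConnectedOp below.
llvm::SmallVector<int64_t> FcOutputShape(llvm::ArrayRef<int64_t> input_shape,
                                         int64_t N, int64_t OC,
                                         bool keep_num_dims) {
  llvm::SmallVector<int64_t> output_shape;
  if (keep_num_dims) {
    // Keep every static input dim except the reduced channel dim, then
    // append OC: e.g. input 1x64x64x768 with OC=3072 -> {1, 64, 64, 3072}.
    output_shape.append(input_shape.begin(), input_shape.end() - 1);
    output_shape.push_back(OC);
  } else {
    // Collapse to rank 2: the same input gives {N, OC} = {4096, 3072}.
    output_shape.append({N, OC});
  }
  return output_shape;
}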
Change-Id: Id0cce68a0930b2c9f9d6b6d1142f68e59286769c Signed-off-by: Luke Hutton --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 40 +++++++++++++++++++ .../mlir/tosa/transforms/legalize_tfl.cc | 17 ++++---- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 7ed4a52665b6cf..7b1cac84df042e 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -2945,6 +2945,46 @@ func.func @test_fullyconnected_qi16(%input: tensor<1x7x!quant.uniform, %arg1: tensor<1000x2048xf32>, %arg2: tensor<1000xf32>) -> tensor { + // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2048]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[1000, 1, 1, 2048]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[1, 1000]> : tensor<2xindex>} + // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> + // CHECK: %[[VAL0:.*]] = tosa.reshape %arg0, %[[CONST0]] + // CHECK: %[[VAL1:.*]] = tosa.reshape %arg1, %[[CONST1]] + // CHECK: %[[VAL2:.*]] = tosa.conv2d %[[VAL0]], %[[VAL1]], %arg2, %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} + // CHECK: %[[VAL3:.*]] = tosa.reshape %[[VAL2]], %[[CONST2]] + // return %[[VAL3]] + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x2048xf32>, tensor<1000x2048xf32>, tensor<1000xf32>) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_fullyconnected_keep_dims +func.func @test_fullyconnected_keep_dims(%arg0: tensor<1x64x64x768x!quant.uniform>, %arg1: tensor<3072x768x!quant.uniform:f32, 0.003333511995151639>>, %arg2: tensor<3072x!quant.uniform>) -> tensor<1x64x64x3072x!quant.uniform> { + // CHECK-DAG: %[[CONST_SHAPE0:.*]] = tosa.const_shape {values = dense<[1, 64, 64, 3072]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST0:.*]] = "tosa.const"() <{values = dense<38> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST1:.*]] = "tosa.const"() <{values = dense<1241512252> : tensor<1xi32>}> + // CHECK-DAG: %[[CONST2:.*]] = "tosa.const"() <{values = dense<45> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK-DAG: %[[CONST4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST5:.*]] = "tosa.const"() <{values = dense<5> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST_SHAPE1:.*]] = tosa.const_shape {values = dense<[3072, 1, 1, 768]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST_SHAPE2:.*]] = tosa.const_shape {values = dense<[4096, 1, 1, 768]> : tensor<4xindex>} + // CHECK: %[[RESHAPE_IN:.*]] = tosa.reshape %arg0, %[[CONST_SHAPE2]] : (tensor<1x64x64x768x!quant.uniform>, !tosa.shape<4>) + // CHECK: %[[RESHAPE_FILT:.*]] = tosa.reshape %arg1, %[[CONST_SHAPE1]] : (tensor<3072x768x!quant.uniform:f32, 0.003333511995151639>>, !tosa.shape<4>) + // CHECK: %[[CONV:.*]] = tosa.conv2d %[[RESHAPE_IN]], %[[RESHAPE_FILT]], %arg2, %[[CONST5]], %[[CONST4]] {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<4096x1x1x768x!quant.uniform>, tensor<3072x1x1x768x!quant.uniform:f32, 0.003333511995151639>>, tensor<3072x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) + // CHECK: %[[RESCALE:.*]] = tosa.rescale %[[CONV]], 
%[[CONST1]], %[[CONST0]], %[[CONST3]], %[[CONST2]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<4096x1x1x3072xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) + // CHECK: %[[RESHAPE_OUT:.*]] = tosa.reshape %[[RESCALE]], %[[CONST_SHAPE0]] : (tensor<4096x1x1x3072x!quant.uniform>, !tosa.shape<4>) -> tensor<1x64x64x3072x!quant.uniform> + // CHECK: return %[[RESHAPE_OUT]] + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<1x64x64x768x!quant.uniform>, tensor<3072x768x!quant.uniform:f32, 0.003333511995151639>>, tensor<3072x!quant.uniform>) -> tensor<1x64x64x3072x!quant.uniform> + func.return %0 : tensor<1x64x64x3072x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_gather // CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 13, 63]> : tensor<3xindex>} // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg0, %[[VAR10]] diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index f3e6e371f47ca9..ac0c4b0c9c8266 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -2349,19 +2349,16 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( // If we know the output rank, we need to ensure the output shape is correct. ShapedType fc_type = mlir::cast(fc_output.getType()); - DenseI64ArrayAttr output_shape_attr; - if (output_type.hasRank()) { - output_shape_attr = rewriter.getDenseI64ArrayAttr(output_type.getShape()); + llvm::SmallVector output_shape; + if (tfl_fc_op.getKeepNumDims()) { + const llvm::ArrayRef orig_input_shape = tfl_fc_op.getInput().getType().getShape(); + output_shape.append(orig_input_shape.begin(), orig_input_shape.end() - 1); + output_shape.push_back(OC); } else { - // set output_shape to {N, OC} to match previous results - // with tosa::FullyConnectedOp - output_shape_attr = rewriter.getDenseI64ArrayAttr({N, OC}); + output_shape.append({N, OC}); } - auto output_shape_value = - (output_type.hasRank()) - ? getTosaConstShape(rewriter, op->getLoc(), output_type.getShape()) - : getTosaConstShape(rewriter, op->getLoc(), {N, OC}); + auto output_shape_value = getTosaConstShape(rewriter, op->getLoc(), output_shape); fc_output = CreateOpAndInfer( rewriter, op->getLoc(), UnrankedTensorType::get(fc_type.getElementType()), fc_output, output_shape_value); From 7245544dff515a9565b1a294bb906dd923df72e7 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 15 Apr 2025 08:44:34 -0700 Subject: [PATCH 0810/1324] [xla:cpu] Remove clang compilation error workaround PiperOrigin-RevId: 747891268 --- .../xla/backends/cpu/runtime/thunk_executor.h | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h index f26ca4bfca0af2..b9261b88cb3c04 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/thread_annotations.h" @@ -39,27 +40,6 @@ limitations under the License. 
namespace xla::cpu { -namespace internal { -// Clang does not allow defining a nested struct with member initializer, as -// a workaround we define a struct in internal namespace and create an alias. -struct ThunkExecutorOptions { - enum class ReadyQueueType { kFifo, kLifo, kPriority }; - - // If all thunks in a sequence use buffers of size less than or equal to the - // given threshold, we mark execution as sequential, as concurrency overheads - // will likely dominate the overall execution time. - size_t execute_sequential_buffer_threshold = 512; - - // If thunk sequence length is less than or equal to the given threshold, we - // mark execution as sequential, as concurrency overheads will likely dominate - // the overall execution time. - size_t execute_sequential_num_thunks_threshold = 8; - - // The type of a queue for ready thunks. - ReadyQueueType ready_queue_type = ReadyQueueType::kFifo; -}; -} // namespace internal - // A dataflow-style (run when ready) executor for a ThunkSequence that depends // on buffer uses to build a DAG defining execution order. At run time executes // thunks concurrently in a given thread pool. @@ -68,13 +48,33 @@ class ThunkExecutor { using BufferUses = Thunk::BufferUses; using ResourceUses = Thunk::ResourceUses; using ExecuteEvent = Thunk::ExecuteEvent; - using Options = internal::ThunkExecutorOptions; ThunkExecutor(ThunkExecutor&&) = default; ThunkExecutor& operator=(ThunkExecutor&&) = default; - static absl::StatusOr Create( - ThunkSequence thunk_sequence, const Options& options = Options()); + struct Options { + enum class ReadyQueueType { kFifo, kLifo, kPriority }; + + // If all thunks in a sequence use buffers of size less than or equal to the + // given threshold, we mark execution as sequential, as concurrency + // overheads will likely dominate the overall execution time. + size_t execute_sequential_buffer_threshold = 512; + + // If thunk sequence length is less than or equal to the given threshold, we + // mark execution as sequential, as concurrency overheads will likely + // dominate the overall execution time. + size_t execute_sequential_num_thunks_threshold = 8; + + // The type of a queue for ready thunks. + ReadyQueueType ready_queue_type = ReadyQueueType::kFifo; + }; + + static absl::StatusOr Create(ThunkSequence thunk_sequence, + const Options& options); + + static absl::StatusOr Create(ThunkSequence thunk_sequence) { + return Create(std::move(thunk_sequence), Options()); + } // Executes the thunk sequence using the prepared dataflow graph. Executor // uses runner to execute ready tasks concurrently. If runner is not provided, From 19841353c3084359e2ea99950cc535bdf1f95a2b Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 15 Apr 2025 08:46:13 -0700 Subject: [PATCH 0811/1324] Rollback nvshmem allocator usage due to tests failure Reverts e514c1fd0faecc5862be8df57a27faf88deb387f PiperOrigin-RevId: 747891808 --- .../xla/xla/backends/gpu/collectives/BUILD | 26 +- .../gpu/collectives/nccl_collectives.cc | 29 -- .../gpu/collectives/nvshmem_collectives.cc | 6 - .../gpu/collectives/nvshmem_collectives.h | 29 +- third_party/xla/xla/debug_options_flags.cc | 6 - third_party/xla/xla/pjrt/gpu/BUILD | 50 ---- .../gpu/se_gpu_pjrt_client_nvshmem_test.cc | 258 ------------------ third_party/xla/xla/service/gpu/BUILD | 2 - .../service/gpu/compile_module_to_llvm_ir.cc | 9 +- .../service/gpu/gpu_memory_space_assignment.h | 30 +- third_party/xla/xla/xla.proto | 6 +- 11 files changed, 26 insertions(+), 425 deletions(-) delete mode 100644 third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_nvshmem_test.cc diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index 5d8d3c71afb69a..2f8c015eea9676 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -18,22 +18,13 @@ package_group( ], ) -config_setting( - name = "arm_build", - values = {"cpu": "arm"}, -) - # Build target that registers all available GPU collectives implementations with the collectives # registry at link time. cc_library( name = "gpu_collectives_plugin", deps = [ ":gpu_collectives_stub", - ] + if_nccl([":nccl_collectives"]) + select({ - # TODO(b/409709288): Fix nvshmem ARM issues and remove this condition. - ":arm_build": [], - "//conditions:default": [":nvshmem_collectives"], - }), + ] + if_nccl([":nccl_collectives"]), ) cc_library( @@ -231,7 +222,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:numbers", ] + if_cuda_is_configured([ "@local_config_nccl//:nccl", ]) + if_rocm_is_configured([ @@ -281,11 +271,14 @@ cc_library( cc_library( name = "nvshmem_collectives", - srcs = if_cuda_is_configured(["nvshmem_collectives.cc"]), - hdrs = if_cuda_is_configured(["nvshmem_collectives.h"]), + srcs = ["nvshmem_collectives.cc"], + hdrs = ["nvshmem_collectives.h"], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:private"], deps = [ - ":gpu_collectives", "//xla/core/collectives", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", @@ -306,8 +299,9 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:numbers", - ] + if_cuda_is_configured(["@nvshmem//:nvshmem_lib"]), - alwayslink = True, + "@nvshmem//:nvshmem_lib", + ], + alwayslink = True, # registers collectives implementation ) xla_cc_test( diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc index cfcd6469edf2b9..d54c91f00b0e8d 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc @@ -44,7 +44,6 @@ limitations under the License. 
#include "xla/core/collectives/collectives_registry.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" -#include "xla/debug_options_flags.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/gpu_executable_run_options.h" @@ -54,7 +53,6 @@ limitations under the License. #include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "tsl/platform/casts.h" -#include "tsl/platform/numbers.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" @@ -237,24 +235,7 @@ absl::Status NcclCollectives::GroupEnd() { return XLA_NCCL_STATUS(ncclGroupEnd()); } -static absl::StatusOr GetNvshmemCollectives() { - TF_ASSIGN_OR_RETURN(xla::Collectives * collectives, - xla::CollectivesRegistry::Get("gpu", "nvshmem")); - xla::gpu::GpuCollectives* nvshmem_collectives = - tsl::down_cast(collectives); - if (nvshmem_collectives == nullptr) { - return absl::InternalError("Failed to get NVSHMEM collectives"); - } - - return nvshmem_collectives; -} - absl::StatusOr NcclCollectives::Allocate(uint64_t bytes) { - if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) { - TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives()); - return nvshmem_collectives->Allocate(bytes); - } - void* ptr = nullptr; ncclResult_t res = ncclMemAlloc(&ptr, bytes); if (res != ncclSuccess) { @@ -270,11 +251,6 @@ absl::StatusOr NcclCollectives::Allocate(uint64_t bytes) { } absl::Status NcclCollectives::Deallocate(void* location) { - if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) { - TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives()); - return nvshmem_collectives->Deallocate(location); - } - ncclResult_t res = ncclMemFree(location); if (res != ncclSuccess) { return absl::InternalError(absl::StrFormat( @@ -342,11 +318,6 @@ class NcclIdStore { absl::Status NcclCollectives::InitializeTopology( NcclCollectives::Topology topology) { - if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) { - TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives()); - TF_RETURN_IF_ERROR(nvshmem_collectives->InitializeTopology(topology)); - } - if (topology.num_nodes > 1) { auto nccl_id_store = std::make_shared( topology.node_id, topology.device_id_to_node_id, diff --git a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc index 37462433ab9b99..acb5c8cf13eba5 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc @@ -57,12 +57,6 @@ NvshmemCollectives* NvshmemCollectives::Default() { LOG(FATAL) << "Unsupported collectives implementation for NVSHMEM"; } -absl::Status NvshmemCollectives::InitializeTopology(Topology topology) { - SetEnvInfo(topology.node_id, topology.num_nodes, - topology.device_count_per_process, topology.kv_store); - return absl::OkStatus(); -} - void NvshmemCollectives::SetEnvInfo( int process_id, size_t num_processes, size_t device_count_per_process, std::weak_ptr kv_store) { diff --git a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h index 3019feae0cc091..f7cd034e1c8a86 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h @@ -25,7 +25,6 @@ 
limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/collectives.h" @@ -36,7 +35,7 @@ limitations under the License. namespace xla::gpu { // NVIDIA NVSHMEM library -class NvshmemCollectives : public GpuCollectives { +class NvshmemCollectives : public Collectives { public: ~NvshmemCollectives() override; @@ -46,46 +45,28 @@ class NvshmemCollectives : public GpuCollectives { size_t device_count_per_process, std::weak_ptr kv_store); - absl::StatusOr Allocate(uint64_t bytes) final; + absl::StatusOr Allocate(uint64_t bytes); - absl::Status Deallocate(void* buffer) final; + absl::Status Deallocate(void* buffer); absl::StatusOr CreateUniqueCliqueId() const final { return absl::UnimplementedError("Not implemented."); } - absl::Status GroupStart() final { - return absl::UnimplementedError("Not implemented."); - } - absl::Status GroupEnd() final { - return absl::UnimplementedError("Not implemented."); - } - - bool IsImplemented() const final { return true; } - - bool IsGlobalConfig() const final { return false; } - - absl::StatusOr GetCliqueIdCallback( - const CliqueIdCallback* clique_id_callback, bool is_local) final { - return absl::UnimplementedError("Not implemented."); - } - absl::StatusOr>> CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_ids, absl::Span ranks, - const Collectives::Config& config) { + const Config& config) final { return absl::UnimplementedError("Not implemented."); } absl::StatusOr>> SplitCommunicators( absl::Span comms, int32_t color, - absl::Span keys, const Collectives::Config& config) final { + absl::Span keys, const Config& config) final { return absl::UnimplementedError("Not implemented."); } - absl::Status InitializeTopology(Topology topology) final; - private: absl::Status InitializeOnce(); diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 34b9637d51316d..fa4446a4f78a93 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -167,7 +167,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); opts.set_xla_gpu_enable_nccl_user_buffers(false); - opts.set_xla_gpu_experimental_enable_nvshmem(false); opts.set_xla_gpu_enable_nccl_comm_splitting(true); opts.set_xla_gpu_nccl_init_max_rank_per_root_ratio(0); @@ -1582,11 +1581,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Enables NCCL User Buffer Registration. collective_memory_size in the " "allocator config must also be set to a non-zero value that is large " "enough to meet peak collective memory usage.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_experimental_enable_nvshmem", - bool_setter_for(&DebugOptions::set_xla_gpu_experimental_enable_nvshmem), - debug_options->xla_gpu_experimental_enable_nvshmem(), - "Enables NVSHMEM.")); flag_list->push_back(tsl::Flag( "xla_gpu_temp_buffer_use_separate_color", bool_setter_for( diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 91382cafe00752..ee0363a8f51e7e 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -237,56 +237,6 @@ xla_test( ], ) -# TODO(b/409713313): Move this test to collectives directory. 
-xla_test( - name = "se_gpu_pjrt_client_nvshmem_test", - srcs = ["se_gpu_pjrt_client_nvshmem_test.cc"], - backend_tags = {"gpu": [ - "multi_gpu_h100", - "no_oss", - "noasan", - "notap", # TODO(b/399931591): Re-enable once flakiness is resolved. - "nomsan", - ]}, - backends = ["gpu"], - env = { - "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", - }, - deps = [ - ":gpu_topology_proto_cc", - ":se_gpu_pjrt_client", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/backends/gpu/collectives:gpu_collectives", - "//xla/ffi", - "//xla/ffi:ffi_api", - "//xla/hlo/builder:xla_computation", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/testlib:test", - "//xla/hlo/utils:hlo_query", - "//xla/pjrt:pjrt_client", - "//xla/pjrt:pjrt_compiler", - "//xla/pjrt:pjrt_executable", - "//xla/pjrt:raw_buffer", - "//xla/pjrt/distributed", - "//xla/pjrt/distributed:client", - "//xla/pjrt/distributed:in_memory_key_value_store", - "//xla/pjrt/distributed:service", - "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", - "//xla/service:platform_util", - "//xla/tests:literal_test_util", - "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:statusor", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - ], -) - xla_test( name = "pjrt_client_test_se_gpu", srcs = ["pjrt_client_test_se_gpu.cc"], diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_nvshmem_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_nvshmem_test.cc deleted file mode 100644 index 3200b37414d00e..00000000000000 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_nvshmem_test.cc +++ /dev/null @@ -1,258 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "xla/ffi/ffi.h" -#include "xla/ffi/ffi_api.h" -#include "xla/hlo/builder/xla_computation.h" -#include "xla/hlo/parser/hlo_parser.h" -#include "xla/hlo/testlib/test.h" -#include "xla/hlo/utils/hlo_query.h" -#include "xla/layout.h" -#include "xla/pjrt/distributed/client.h" -#include "xla/pjrt/distributed/distributed.h" -#include "xla/pjrt/distributed/in_memory_key_value_store.h" -#include "xla/pjrt/distributed/service.h" -#include "xla/pjrt/gpu/gpu_topology.pb.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" -#include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_compiler.h" -#include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" -#include "xla/pjrt/raw_buffer.h" -#include "xla/service/platform_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/util.h" -#include "xla/xla.pb.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace { - -using ::testing::NotNull; -using ::testing::SizeIs; - -HloInstruction* FindInstruction(const HloModule* module, HloOpcode opcode) { - for (const HloComputation* computation : module->computations()) { - if (HloInstruction* instruction = - hlo_query::FindInstruction(computation, opcode)) { - return instruction; - } - } - return nullptr; -} - -absl::StatusOr> CompileExecutable( - absl::string_view program, xla::PjRtClient& client, - xla::CompileOptions compile_options = xla::CompileOptions()) { - TF_ASSIGN_OR_RETURN(auto hlo_module, - ParseAndReturnUnverifiedModule(program, {})); - - xla::XlaComputation xla_computation(hlo_module->ToProto()); - return client.CompileAndLoad(xla_computation, compile_options); -} - -// Register a mock "mosaic_gpu" custom call op for NvshmemMemoryTest, since -// mosaic_gpu is defined in JAX and won't be available to the unit test. -static absl::Status MockMosaicGpu(ffi::AnyBuffer arg, - ffi::Result ret, - absl::string_view module) { - return absl::OkStatus(); -} - -XLA_FFI_DEFINE_HANDLER(kMockMosaicGpu, MockMosaicGpu, - ffi::Ffi::Bind() - .Arg() - .Ret() - .Attr("module")); - -XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu", - PlatformUtil::CanonicalPlatformName("GPU").value(), - kMockMosaicGpu); - -// Verify that the client can initialize NVSHMEM and that buffers used by -// mosaic_gpu custom calls are assigned to the collective memory space. -TEST(StreamExecutorGpuClientTest, NvshmemMemoryTest) { - static constexpr char const* kProgram = R"( - HloModule ffi_handler - ENTRY main { - param = s32[1,4]{1,0} parameter(0) - reshape = s32[4]{0} reshape(param) - ROOT %custom-call = s32[4] custom-call(param), - custom_call_target="mosaic_gpu", - api_version=API_VERSION_TYPED_FFI, - backend_config={"custom_call_backend_config": {"attributes": "{module = \"nvshmem\"}"}} - })"; - // Nvshmem requires one gpu per process. 
- GpuClientOptions client_options; - client_options.node_id = 0; - client_options.allowed_devices = {0}; - client_options.num_nodes = 1; - client_options.kv_store = std::make_shared(); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - GetStreamExecutorGpuClient(client_options)); - xla::CompileOptions options; - options.executable_build_options.mutable_debug_options() - ->set_xla_gpu_experimental_enable_nvshmem(true); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, - CompileExecutable(kProgram, *client, options)); - std::vector data{1, 2, 3, 4}; - Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {1, 4}, - /*minor_to_major=*/{1, 0}); - shape.mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace); - - PjRtDevice* const device = client->addressable_devices()[0]; - TF_EXPECT_OK(device->default_memory_space()); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr input, - client->BufferFromHostBuffer( - data.data(), shape.element_type(), shape.dimensions(), - /*byte_strides=*/std::nullopt, - PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, - /*on_done_with_host_buffer=*/nullptr, *device->default_memory_space(), - /*device_layout=*/nullptr)); - EXPECT_EQ(input->memory_space()->kind(), "device"); - - TF_ASSERT_OK_AND_ASSIGN( - std::vector> memory_kinds, - executable->GetOutputMemoryKinds()); - EXPECT_EQ(memory_kinds.size(), 1); - EXPECT_EQ(memory_kinds[0].size(), 1); - EXPECT_EQ(memory_kinds[0][0], "device"); - - TF_ASSERT_OK_AND_ASSIGN( - std::vector>> result, - executable->Execute({{input.get()}}, ExecuteOptions())); - std::vector>& result_buffers = result[0]; - EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "device"); - Shape result_shape = result_buffers[0]->on_device_shape(); - int64_t memory_space = result_shape.layout().memory_space(); - EXPECT_EQ(memory_space, 1); -} - -absl::Status UserBufferWithNvshmemMallocTestBody(const int node_id, - const int num_nodes) { - const absl::string_view kModuleStr = R"( -HloModule test -apply_op { -x = u32[] parameter(0) -y = u32[] parameter(1) -ROOT apply_op = u32[] add(x, y) -} -ENTRY test_computation { -id = u32[] replica-id() -ROOT all-reduce = u32[] all-reduce(id), to_apply=apply_op -} -)"; - std::unique_ptr service; - if (node_id == 0) { - xla::CoordinationServiceImpl::Options service_options; - service_options.num_nodes = num_nodes; - TF_ASSIGN_OR_RETURN(service, xla::GetDistributedRuntimeService( - "[::]:12346", service_options)); - } - - xla::DistributedRuntimeClient::Options distributed_options; - distributed_options.node_id = node_id; - distributed_options.init_timeout = absl::Seconds(120); - auto distributed_client = - GetDistributedRuntimeClient("127.0.0.1:12346", distributed_options); - TF_QCHECK_OK(distributed_client->Connect()); - GpuClientOptions client_options; - client_options.node_id = node_id; - client_options.allowed_devices = {node_id}; - client_options.num_nodes = num_nodes; - client_options.kv_store = - GetDistributedKeyValueStore(distributed_client, /*key_prefix=*/"gpu:"); - ; - TF_ASSIGN_OR_RETURN(std::unique_ptr client, - GetStreamExecutorGpuClient(client_options)); - xla::CompileOptions options; - options.executable_build_options.set_num_replicas(num_nodes); - options.executable_build_options.mutable_debug_options() - ->set_xla_gpu_experimental_enable_nvshmem(true); - options.executable_build_options.mutable_debug_options() - ->set_xla_gpu_enable_nccl_user_buffers(true); - - TF_ASSIGN_OR_RETURN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr, {})); - xla::XlaComputation 
xla_computation(hlo_module->ToProto()); - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - client->CompileAndLoad(xla_computation, options)); - - // Verify that the collective memory space is used. - TF_ASSIGN_OR_RETURN(auto modules, executable->GetHloModules()); - HloInstruction* all_reduce_start = - FindInstruction(modules[0].get(), HloOpcode::kAllReduceStart); - EXPECT_THAT(all_reduce_start, NotNull()); - EXPECT_EQ(all_reduce_start->shape().layout().memory_space(), 1); - EXPECT_THAT(all_reduce_start->operands(), SizeIs(1)); - const HloInstruction* input = all_reduce_start->operand(0); - EXPECT_EQ(input->shape().layout().memory_space(), 1); - - TF_ASSIGN_OR_RETURN( - std::vector>> results, - executable->Execute(/*argument_handles=*/{{}}, /*options=*/{})); - EXPECT_EQ(results.size(), 1); - EXPECT_EQ(results[0].size(), 1); - TF_ASSIGN_OR_RETURN(auto literal, results[0][0]->ToLiteralSync()); - if (node_id == 0) { - LiteralTestUtil::ExpectR1Equal({10, 15, 11, 16}, *literal); - } else if (node_id == 1) { - LiteralTestUtil::ExpectR1Equal({20, 25, 21, 26}, *literal); - } - return absl::OkStatus(); -} - -} // namespace -} // namespace xla - -int main(int argc, char* argv[]) { - int node_id = -1; - int num_nodes = -1; - std::vector flag_list = { - tsl::Flag("node_id", &node_id, "Node ID for multiprocess tests."), - tsl::Flag("num_nodes", &num_nodes, - "Number of nodes for multiprocess tests."), - }; - std::string usage = tsl::Flags::Usage(argv[0], flag_list); - tsl::Flags::Parse(&argc, argv, flag_list); - testing::InitGoogleTest(&argc, argv); - if (node_id >= 0) { - absl::Status result = - xla::UserBufferWithNvshmemMallocTestBody(node_id, num_nodes); - if (!result.ok()) { - LOG(ERROR) << result; - } - return result.raw_code(); - } - return RUN_ALL_TESTS(); -} diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 218adadff90a5e..de44c6df57fdcc 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -115,9 +115,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:hlo_value", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", ], ) diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index 930ef2126c250b..c15682a45aeb91 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -250,12 +250,9 @@ absl::StatusOr> RunBufferAssignment( ScopedAnnotation annotation(Phase("XlaBufferAssignment", module)); const DebugOptions& options = module->config().debug_options(); - BufferAssigner::Colorer colorer = - (options.xla_gpu_enable_nccl_user_buffers() || - options.xla_gpu_experimental_enable_nvshmem()) - ? CollectiveColorer(options.xla_gpu_enable_nccl_user_buffers(), - options.xla_gpu_experimental_enable_nvshmem()) - : BufferAssigner::DefaultColorer(); + BufferAssigner::Colorer colorer = options.xla_gpu_enable_nccl_user_buffers() + ? 
CollectiveColorer() + : BufferAssigner::DefaultColorer(); std::optional color = options.xla_gpu_temp_buffer_use_separate_color() diff --git a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h index 220e15ede9cd3f..32d1906c30dc07 100644 --- a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h +++ b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h @@ -18,7 +18,6 @@ limitations under the License. #include -#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/analysis/hlo_ordering.h" @@ -35,10 +34,8 @@ inline constexpr int64_t kTempBufferMemorySpaceColor = 2; // Set memory space to kCollectiveMemorySpaceColor for all allocations used by // all-reduce, all-gather, and reduce-scatter. This memory space maps to // collective memory using ncclMemAlloc in the runtime. -inline BufferAssigner::Colorer CollectiveColorer(bool use_user_buffers, - bool use_nvshmem) { - return [use_user_buffers, use_nvshmem](HloAliasAnalysis* alias_analysis, - const HloOrdering&) { +inline BufferAssigner::Colorer CollectiveColorer() { + return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { static const auto* kSupportedOpcodes = new absl::flat_hash_set{ HloOpcode::kAllReduce, HloOpcode::kAllReduceStart, @@ -52,25 +49,12 @@ inline BufferAssigner::Colorer CollectiveColorer(bool use_user_buffers, HloOpcode::kCollectivePermuteDone, HloOpcode::kAllToAll, }; - auto is_mosaic_gpu_nvshmem_instr = [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kCustomCall && - (instr->custom_call_target() == "mosaic_gpu" || - instr->custom_call_target() == "mosaic_gpu_v2") && - instr->raw_backend_config_string().find("nvshmem") != - std::string::npos; - }; auto is_collective_memory_instr = [&](const HloInstruction* instr) { - if (use_user_buffers) { - return kSupportedOpcodes->contains(instr->opcode()) || - // opcode or async wrapped opcode is in kSupportedOpcodes. - ((instr->opcode() == HloOpcode::kAsyncStart || - instr->opcode() == HloOpcode::kAsyncDone) && - kSupportedOpcodes->contains(instr->async_wrapped_opcode())); - } - if (use_nvshmem) { - return is_mosaic_gpu_nvshmem_instr(instr); - } - return false; + return kSupportedOpcodes->contains(instr->opcode()) || + // opcode or async wrapped opcode is in kSupportedOpcodes. + ((instr->opcode() == HloOpcode::kAsyncStart || + instr->opcode() == HloOpcode::kAsyncDone) && + kSupportedOpcodes->contains(instr->async_wrapped_opcode())); }; auto has_collective_memory_in_uses = [&](const HloValue* input_alias) { // If any use is a collective instruction, we must color the value to use diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 813fe8813a26f3..31007c9e0d54d6 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -551,10 +551,6 @@ message DebugOptions { // Pre-existing block-level fusions are left unmodified. bool xla_gpu_experimental_enable_fusion_block_level_rewriter = 334; - // Enable NVSHMEM. Must be set via XLA_FLAGS variable before XLA client is - // initialized and can't be set just through HLO Config->ExecutionOptions. - bool xla_gpu_experimental_enable_nvshmem = 387; - // Enable the pass that splits GEMMs that underutilize the GPU load by // splitting the K dimension using a heuristic. 
bool xla_gpu_experimental_enable_split_k_rewrite = 386; @@ -1206,7 +1202,7 @@ message DebugOptions { // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 388 + // Next id: 387 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 942a6402f896304d6acc1afce582691f0c2636fc Mon Sep 17 00:00:00 2001 From: Zviki Nozadze Date: Tue, 15 Apr 2025 08:54:11 -0700 Subject: [PATCH 0812/1324] Remove unnecessary compatibility checks in `InstructionValueSet::AssignUnionOf`. PiperOrigin-RevId: 747894444 --- .../analysis/hlo_dataflow_analysis_test.cc | 51 +++++++++++++++++++ third_party/xla/xla/service/hlo_value.cc | 10 ---- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc index fe1ce5f513807f..06e6d133504f82 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc @@ -3674,5 +3674,56 @@ TEST_P(HloDataflowAnalysisTest, b409416499) { EXPECT_THAT(defining_instructions, UnorderedElementsAre(param2, add0)); } +TEST_P(HloDataflowAnalysisTest, b409756077) { + const char* after_layout_bitcast = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f32[1,256,256]{2,1,0:T(8,128)})->f32[1,256,256]{2,1,0:T(8,128)}} + add_f32 { + %add_lhs = f32[] parameter(0) + %add_rhs = f32[] parameter(1) + ROOT %add = f32[] add(%add_lhs, %add_rhs) + } + + %while_body (param.1: f32[256,256]) -> f32[256,256] { + %param.1 = f32[256,256]{1,0:T(8,128)} parameter(0) + %constant.0 = f32[]{:T(8,128)} constant(1) + %constant.1 = f32[256,256]{1,0:T(8,128)} broadcast(%constant.0), dimensions={} + ROOT %add.0 = f32[256,256]{1,0:T(8,128)} add(%param.1, %constant.1) + } + + %while_condition (param: f32[256,256]) -> pred[] { + %param.0 = f32[256,256]{1,0:T(8,128)} parameter(0) + %zero = f32[]{:T(8,128)} constant(0) + %sum_of_values_in_param = f32[]{:T(8,128)} reduce(%param.0, %zero), dimensions={0,1}, to_apply=%add_f32 + %constant = f32[]{:T(8,128)} constant(512) + ROOT %compare.0 = pred[] compare(%sum_of_values_in_param, %constant), direction=LT + } + + ENTRY %main (param.2: f32[1,256,256]) -> f32[1,256,256] { + %param.2 = f32[1,256,256]{2,1,0:T(8,128)} parameter(0) + %bitcast.2 = f32[256,256]{1,0:T(8,128)} bitcast(%param.2) + %while.1 = f32[256,256]{1,0:T(8,128)} while(%bitcast.2), condition=%while_condition, body=%while_body + ROOT %bitcast.3 = f32[1,256,256]{2,1,0:T(8,128)} bitcast(%while.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto after_layout_bitcast_module, + ParseAndReturnVerifiedModule(after_layout_bitcast)); + TF_ASSERT_OK_AND_ASSIGN(auto analysis, + HloDataflowAnalysis::Run(*after_layout_bitcast_module, + /*ssa_form=*/false)); + HloInstruction* bitcast3 = + FindInstruction(after_layout_bitcast_module.get(), "bitcast.3"); + HloInstruction* param2 = + FindInstruction(after_layout_bitcast_module.get(), "param.2"); + HloComputation* while_body = + FindComputation(after_layout_bitcast_module.get(), "while_body"); + HloInstruction* add0 = while_body->root_instruction(); + std::vector defining_instructions; + for (const HloValue* value : + analysis->GetValueSet(bitcast3, {}).TakeValues()) { + defining_instructions.push_back(value->defining_instruction()); + } + EXPECT_THAT(defining_instructions, 
UnorderedElementsAre(param2, add0)); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/hlo_value.cc b/third_party/xla/xla/service/hlo_value.cc index 699b4ed6ad474c..03c204c5ec0bb1 100644 --- a/third_party/xla/xla/service/hlo_value.cc +++ b/third_party/xla/xla/service/hlo_value.cc @@ -286,16 +286,6 @@ bool InstructionValueSet::IsAmbiguous() const { bool InstructionValueSet::AssignUnionOf( absl::Span<const InstructionValueSet* const> inputs) { CHECK_GT(inputs.size(), 0); - for (int i = 1; i < inputs.size(); ++i) { - // It is possible that some values come from effective scalar shapes, i.e., - // X[1] that was bitcasted to X[]. In such cases, shapes are not compatible - // but it is still valid to get the union of the values. - bool shapes_are_effective_scalar = - ShapeUtil::IsEffectiveScalar(inputs[0]->shape()) && - ShapeUtil::IsEffectiveScalar(inputs[i]->shape()); - DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()) || - shapes_are_effective_scalar); - } bool changed = false; for (auto& pair : *this) { const ShapeIndex& index = pair.first; From 6ba2c4a59c787e5665a9a4e32f7607ecb05581c4 Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Tue, 15 Apr 2025 09:04:34 -0700 Subject: [PATCH 0813/1324] Show fractions as percentages in HLO Stats. PiperOrigin-RevId: 747898322 --- tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc index 66ceccf0af3efc..fed7128c887ac0 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc @@ -152,8 +152,8 @@ std::unique_ptr CreateHloStatsDataTable( row->AddCell(record.avg_time_in_us()); row->AddCell(record.total_self_time_in_us()); row->AddCell(record.avg_self_time_in_us()); - row->AddCell(record.total_self_time_as_fraction()); - row->AddCell(record.cumulative_total_self_time_as_fraction()); + row->AddCell(record.total_self_time_as_fraction() * 100); + row->AddCell(record.cumulative_total_self_time_as_fraction() * 100); row->AddCell(record.dma_stall_fraction()); row->AddCell(record.model_flop_rate()); row->AddCell(record.measured_flop_rate()); From 6c303829a841bc4d82c8c7266a0a7bee29a6eb51 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 09:14:13 -0700 Subject: [PATCH 0814/1324] Restrict new default tilings to Hopper & Ampere to fix issues on Blackwell. We will later update the default tiling set specifically for Blackwell. 
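For illustration, the new behavior keys the default tiling set on the device's compute capability. Below is a minimal editorial sketch of the dispatch pattern, with simplified stand-in types rather than the exact XLA signatures: TilingConfig stands in for TritonGemmConfig, whose constructor arguments are assumed to follow the (block_m, block_n, block_k, split_k, num_stages, num_warps) order, and Arch stands in for the se::CudaComputeCapability checks used in the diff below.

#include <vector>

// Simplified stand-in for xla::gpu::TritonGemmConfig.
struct TilingConfig {
  int block_m, block_n, block_k;  // tile sizes for the M, N, K GEMM dimensions
  int split_k;                    // split factor for the contracting dimension
  int num_stages;                 // software-pipelining depth
  int num_warps;                  // warps per CTA
};

// Simplified stand-in for the compute-capability checks.
enum class Arch { kAmpere, kHopper, kBlackwell };

std::vector<TilingConfig> DefaultTilings(Arch arch) {
  if (arch == Arch::kHopper || arch == Arch::kAmpere) {
    // Keep the previously tuned set on the architectures it was tuned for
    // (first entries of the list in this change shown here).
    return {{16, 16, 64, 1, 4, 2}, {16, 16, 128, 1, 4, 4}};
  }
  // Other architectures, currently Blackwell, fall through to a separate
  // set that can be retuned independently later.
  return {{32, 32, 256, 1, 1, 4}, {64, 32, 32, 16, 1, 4}};
}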
PiperOrigin-RevId: 747901700 --- .../xla/xla/service/gpu/autotuning/BUILD | 2 +- .../autotuning/gemm_fusion_autotuner_cuda.cc | 63 ++++++++++++------- .../autotuning/gemm_fusion_autotuner_test.cc | 40 ++++++++++++ 3 files changed, 81 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index fec7cb04b678d4..4d462dc3c9ff73 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -58,8 +58,8 @@ cc_library( "//xla/service/gpu/transforms:cudnn_fusion_compiler", "//xla/stream_executor:device_description", "//xla/stream_executor:semantic_version", + "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:env", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc index 860f44f560be82..bc62e1cb545a6d 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc @@ -14,10 +14,8 @@ limitations under the License. ==============================================================================*/ #include -#include #include -#include "absl/algorithm/container.h" #include "third_party/gpus/cuda/include/cublas_v2.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -30,7 +28,7 @@ limitations under the License. #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" -#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" namespace xla { namespace gpu { @@ -82,27 +80,46 @@ bool GemmFusionAutotunerImpl::AddLibConfigs( std::vector GemmFusionAutotunerImpl::GetDefaultTritonConfigs() const { using Config = TritonGemmConfig; + auto compute_capability = + std::get(GetComputeCapability()); - std::vector configs = { - Config(16, 16, 64, 1, 4, 2), Config(16, 16, 128, 1, 4, 4), - Config(16, 16, 128, 128, 4, 2), Config(16, 16, 128, 16, 1, 2), - Config(16, 256, 16, 1, 1, 2), Config(32, 32, 128, 16, 1, 4), - Config(32, 256, 32, 1, 3, 4), Config(32, 256, 32, 16, 3, 8), - Config(64, 16, 32, 1, 4, 2), Config(64, 16, 32, 16, 4, 2), - Config(64, 16, 64, 1, 1, 4), Config(64, 16, 64, 4, 3, 2), - Config(64, 16, 64, 16, 4, 4), Config(64, 16, 128, 1, 4, 2), - Config(64, 16, 128, 16, 4, 4), Config(64, 32, 32, 1, 4, 4), - Config(64, 32, 64, 16, 3, 4), Config(64, 32, 128, 1, 3, 2), - Config(64, 32, 128, 128, 2, 4), Config(64, 64, 32, 1, 4, 4), - Config(64, 64, 64, 1, 4, 4), Config(64, 64, 64, 4, 4, 4), - Config(64, 64, 128, 16, 3, 4), Config(64, 64, 256, 16, 4, 8), - Config(64, 128, 16, 1, 4, 2), Config(64, 128, 64, 1, 3, 4), - Config(64, 128, 128, 8, 1, 4), Config(64, 256, 32, 1, 4, 4), - Config(128, 16, 32, 8, 4, 2), Config(128, 16, 64, 16, 3, 2), - Config(128, 16, 64, 16, 1, 4), Config(128, 32, 32, 8, 4, 2), - Config(128, 128, 32, 8, 4, 8), Config(128, 256, 32, 1, 4, 8), - Config(128, 256, 64, 1, 4, 8)}; - return configs; + if (compute_capability.IsHopper() || compute_capability.IsAmpere()) { + return {Config(16, 16, 64, 1, 4, 2), Config(16, 16, 128, 1, 4, 4), + Config(16, 16, 128, 128, 4, 
2), Config(16, 16, 128, 16, 1, 2), + Config(16, 256, 16, 1, 1, 2), Config(32, 32, 128, 16, 1, 4), + Config(32, 256, 32, 1, 3, 4), Config(32, 256, 32, 16, 3, 8), + Config(64, 16, 32, 1, 4, 2), Config(64, 16, 32, 16, 4, 2), + Config(64, 16, 64, 1, 1, 4), Config(64, 16, 64, 4, 3, 2), + Config(64, 16, 64, 16, 4, 4), Config(64, 16, 128, 1, 4, 2), + Config(64, 16, 128, 16, 4, 4), Config(64, 32, 32, 1, 4, 4), + Config(64, 32, 64, 16, 3, 4), Config(64, 32, 128, 1, 3, 2), + Config(64, 32, 128, 128, 2, 4), Config(64, 64, 32, 1, 4, 4), + Config(64, 64, 64, 1, 4, 4), Config(64, 64, 64, 4, 4, 4), + Config(64, 64, 128, 16, 3, 4), Config(64, 64, 256, 16, 4, 8), + Config(64, 128, 16, 1, 4, 2), Config(64, 128, 64, 1, 3, 4), + Config(64, 128, 128, 8, 1, 4), Config(64, 256, 32, 1, 4, 4), + Config(128, 16, 32, 8, 4, 2), Config(128, 16, 64, 16, 3, 2), + Config(128, 16, 64, 16, 1, 4), Config(128, 32, 32, 8, 4, 2), + Config(128, 128, 32, 8, 4, 8), Config(128, 256, 32, 1, 4, 8), + Config(128, 256, 64, 1, 4, 8)}; + } + + return {Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), + Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), + Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4), + Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4), + Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4), + Config(64, 32, 64, 1, 2, 8), Config(128, 256, 32, 1, 3, 8), + Config(256, 128, 32, 1, 3, 8), Config(256, 64, 32, 1, 4, 4), + Config(64, 256, 32, 1, 4, 4), Config(128, 64, 32, 1, 4, 4), + Config(64, 128, 32, 1, 4, 4), Config(256, 128, 128, 1, 3, 8), + Config(256, 64, 128, 1, 4, 4), Config(64, 256, 128, 1, 4, 4), + Config(128, 128, 128, 1, 4, 4), Config(128, 64, 64, 1, 4, 4), + Config(64, 128, 64, 1, 4, 4), Config(128, 32, 64, 1, 4, 4), + Config(64, 32, 64, 1, 4, 4), Config(32, 128, 32, 1, 4, 4), + Config(128, 128, 32, 1, 4, 4), Config(16, 16, 256, 1, 3, 4), + Config(128, 128, 64, 2, 1, 8), Config(64, 64, 64, 1, 2, 4), + Config(16, 64, 256, 8, 1, 4), Config(256, 256, 128, 1, 3, 8)}; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index a08cfd484e71d8..ed4a583745ede9 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include +#include #include #include #include @@ -1812,6 +1813,45 @@ ENTRY e { )"); } +TEST_F(GemmFusionAutotunerTest, VerifyHopperConfigsAreDifferentFromBlackwell) { + if (isRocm()) { + GTEST_SKIP() << "Not supported on ROCm."; + } + + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( + ENTRY e { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + ROOT r = f32[1024,1024] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + })") + .value(); + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector blackwell_configs, + GetPossibleMatmulAutotuneTritonConfigs( + *Cast( + module->entry_computation()->root_instruction()), + se::CudaComputeCapability(se::CudaComputeCapability::kBlackwell, 0), + GetToolkitVersion(), GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN( + const std::vector hopper_configs, + GetPossibleMatmulAutotuneTritonConfigs( + *Cast( + module->entry_computation()->root_instruction()), + se::CudaComputeCapability(se::CudaComputeCapability::kHopper, 0), + GetToolkitVersion(), GetDebugOptionsForTest())); + + std::set blackwell_configs_set(blackwell_configs.begin(), + blackwell_configs.end()); + std::set hopper_configs_set(hopper_configs.begin(), + hopper_configs.end()); + + EXPECT_GT(blackwell_configs_set.size(), 0); + EXPECT_GT(hopper_configs_set.size(), 0); + EXPECT_NE(blackwell_configs_set, hopper_configs_set); +} + } // namespace } // namespace gpu } // namespace xla From 8eab9f933505024385027d8bdbb0383bd83488d3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 09:40:07 -0700 Subject: [PATCH 0815/1324] Creates an AutoShardingSolverParams struct with various fields absent from the IOPDDL formulation. PiperOrigin-RevId: 747910368 --- .../auto_sharding/auto_sharding.cc | 6 +- .../auto_sharding/auto_sharding_solver.cc | 107 ++++++++++++------ .../auto_sharding/auto_sharding_solver.h | 13 ++- .../auto_sharding_solver_test.cc | 61 +++++----- 4 files changed, 112 insertions(+), 75 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index ed8a7fc0aa547c..39b02c3d885531 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -2000,12 +2000,8 @@ CreateAutoShardingSolverRequestAndCallSolver( const auto converted_problem = ConvertToProblem(request); const auto converted_request = ConvertToSolverRequest(converted_problem); - const std::optional overbudget_coeff = - option.memory_overbudget_coeff >= 0.0 - ? 
std::make_optional(option.memory_overbudget_coeff) - : std::nullopt; return FormulateAndSolveMIPFromSolverRequest(converted_request, - overbudget_coeff); + GetParams(request)); } void CheckHloSharding( diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index cd2b5f813fd08d..26455ba4c23a9f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -118,6 +118,25 @@ double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { return min_memory_budget_required_estimate; } +AutoShardingSolverParams GetParams(const AutoShardingSolverRequest& request) { + AutoShardingSolverParams params; + for (const auto& departure_cost : request.departure_costs()) { + std::vector departure_cost_vector(departure_cost.costs().begin(), + departure_cost.costs().end()); + params.departure_costs.push_back(departure_cost_vector); + } + params.max_departures = + request.has_max_departures() + ? std::make_optional(request.max_departures().coeff()) + : std::nullopt; + params.minimize_departures = request.minimize_departures(); + params.overbudget_coeff = + request.has_overbudget_coeff() + ? std::make_optional(request.overbudget_coeff().coeff()) + : std::nullopt; + return params; +} + namespace { std::vector GetChosenNodeStrategy( @@ -199,10 +218,11 @@ void PrintLargestInstructions( absl::StatusOr SolveAndExtractSolution( const AutoShardingSolverRequest& request, + const AutoShardingSolverParams& params, const std::vector>& s, const std::vector>& e, const MPVariable* overbudget_var, const MPVariable* makespan_var, - const std::optional overbudget_coeff, MPSolver& solver) { + MPSolver& solver) { auto status = solver.Solve(); LOG(INFO) << "Solver absl::Status: " << status; @@ -238,14 +258,17 @@ absl::StatusOr SolveAndExtractSolution( #endif return absl::InternalError( "MPSolver could not find any feasible solution."); - } else if (status == operations_research::MPSolver::MODEL_INVALID) { + } + if (status == operations_research::MPSolver::MODEL_INVALID) { LOG(FATAL) << "The MIP fed to the solver is invalid. This is most likely a " "bug and should be reported."; return absl::InternalError("Invalid MIP."); - } else if (status == operations_research::MPSolver::NOT_SOLVED) { + } + if (status == operations_research::MPSolver::NOT_SOLVED) { LOG(WARNING) << "Solver timeout; no solution was produced"; return absl::InternalError("Solver timed out."); - } else if (status != operations_research::MPSolver::OPTIMAL) { + } + if (status != operations_research::MPSolver::OPTIMAL) { LOG(WARNING) << "Solver timeout; moving forward with a suboptimal solution"; } else { is_optimal = true; @@ -298,7 +321,8 @@ absl::StatusOr SolveAndExtractSolution( unsalted_objective += request.resharding_costs(edge_idx).costs(j); } if (overbudget_var) { - unsalted_objective += *overbudget_coeff * overbudget_var->solution_value() * + unsalted_objective += *params.overbudget_coeff * + overbudget_var->solution_value() * request.memory_budget(); } if (makespan_var) { @@ -498,7 +522,7 @@ void AddMemoryTerms( // however. 
absl::StatusOr FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& request, - std::optional overbudget_coeff) { + const AutoShardingSolverParams& params) { const absl::Time start_time = absl::Now(); const size_t num_edges = request.edges_size(); const int num_workers = 32; @@ -581,7 +605,7 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( edge_map.insert({followed_edge, edge_idx}); } - if (request.memory_budget() > 0 && overbudget_coeff.has_value()) { + if (request.memory_budget() > 0 && params.overbudget_coeff.has_value()) { overbudget_var = solver->MakeNumVar(0.0, MPSolver::infinity(), "overbudget"); } @@ -601,7 +625,9 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( infinity_vars.insert(s[node_idx][j]); continue; } - if (request.minimize_departures()) continue; + if (params.minimize_departures) { + continue; + } double accumulated_coefficient = solver->MutableObjective()->GetCoefficient(s[node_idx][j]); solver->MutableObjective()->SetCoefficient( @@ -616,7 +642,9 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( infinity_vars.insert(e[edge_idx][j]); continue; } - if (request.minimize_departures()) continue; + if (params.minimize_departures) { + continue; + } double accumulated_coefficient = solver->MutableObjective()->GetCoefficient(e[edge_idx][j]); solver->MutableObjective()->SetCoefficient( @@ -702,9 +730,9 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( reduced_intervals_nodes, request.memory_costs(), overbudget_var, reduced_times, s, group_node_vars, constraints); - if (overbudget_var && !request.minimize_departures()) { + if (overbudget_var && !params.minimize_departures) { solver->MutableObjective()->SetCoefficient( - overbudget_var, *overbudget_coeff * request.memory_budget()); + overbudget_var, *params.overbudget_coeff * request.memory_budget()); } LOG(INFO) << "Minimum memory budget estimate: " << MinimumMemoryBudgetRequired(request); @@ -715,7 +743,9 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( // d. specified via "BoolVarArray" // e. 
for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { - if (e_follow[edge_idx] >= 0) continue; + if (e_follow[edge_idx] >= 0) { + continue; + } const auto& edge = request.edges(edge_idx); for (NodeStrategyIdx p = 0; p < s[edge.first()].size(); ++p) { for (NodeStrategyIdx q = 0; q < s[edge.second()].size(); ++q) { @@ -739,7 +769,9 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( const auto& raw_alias = request.aliases(alias_idx); const std::pair alias(raw_alias.first(), raw_alias.second()); - if (alias_set.contains(alias)) continue; + if (alias_set.contains(alias)) { + continue; + } alias_set.insert(alias); const auto& value_costs = request.value_costs(alias_idx).costs(); for (NodeStrategyIdx p = 0; p < s[alias.first].size(); ++p) { @@ -756,10 +788,10 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( } } } - if (request.has_max_departures()) { + if (params.max_departures.has_value()) { MPConstraint* constraint = solver->MakeRowConstraint( - 0, request.max_departures().coeff(), - absl::StrCat("departures <= ", request.max_departures().coeff())); + 0, *params.max_departures, + absl::StrCat("departures <= ", *params.max_departures)); for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { double accumulated_coefficient = @@ -770,7 +802,7 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( } } } - if (request.minimize_departures()) { + if (params.minimize_departures) { for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { double accumulated_coefficient = @@ -795,7 +827,9 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( (!request.has_max_cost() || request.max_cost().coeff() < kMaxCostValue)) { std::vector> hint; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { - if (request.s_follow(node_idx) >= 0) continue; + if (request.s_follow(node_idx) >= 0) { + continue; + } for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { double hint_val = (request.s_hint(node_idx) == j) ? 
1.0 : 0.0; hint.push_back({s[node_idx][j], hint_val}); @@ -861,19 +895,19 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( << "Number variables for ILP: " << solver->NumVariables() << "\n" << "Number of ILP constraints: " << solver->NumConstraints() << "\n" << "Deterministic mode: " << request.deterministic_mode() << "\n" - << "Minimize departures: " << request.minimize_departures() << "\n" + << "Minimize departures: " << params.minimize_departures << "\n" << "Module name: " << request.module_name(); if (request.has_max_cost()) { VLOG(0) << "Max cost: " << request.max_cost().coeff(); } - if (request.has_max_departures()) { - VLOG(0) << "Max departures: " << request.max_departures().coeff(); + if (params.max_departures.has_value()) { + VLOG(0) << "Max departures: " << *params.max_departures; } - auto result = SolveAndExtractSolution( - request, s, e, overbudget_var, makespan_var, overbudget_coeff, *solver); + auto result = SolveAndExtractSolution(request, params, s, e, overbudget_var, + makespan_var, *solver); if (result.ok()) { const AutoShardingEvaluation evaluation = - Evaluate(request, *result, overbudget_coeff); + Evaluate(request, *result, params); LOG(INFO) << "*** Total costs for the solver request ***"; LOG(INFO) << "Total Communication Cost: " << evaluation.total.communication_cost @@ -979,7 +1013,8 @@ AutoShardingSolverOutput SolveRandom(const AutoShardingSolverRequest& request, bool candidate_is_feasible = (cost >= 0.0); if (have_feasible_solution && !candidate_is_feasible) { continue; - } else if (have_feasible_solution && candidate_is_feasible) { + } + if (have_feasible_solution && candidate_is_feasible) { if (cost < best_cost) { best_node_strategies = node_strategies; best_cost = cost; @@ -1089,12 +1124,12 @@ bool AutoShardingEvaluation::operator==( AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, const AutoShardingSolverOutput& result, - std::optional overbudget_coeff) { + const AutoShardingSolverParams& params) { const auto& c = request.computation_costs(); const auto& d = request.communication_costs(); const auto& r = request.resharding_costs(); const auto& v = request.value_costs(); - const auto& p = request.departure_costs(); + const auto& p = params.departure_costs; const std::vector& s_val = result.s_val; const auto e_val = [&](EdgeIdx edge_idx) { const auto& edge = request.edges(edge_idx); @@ -1130,10 +1165,13 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, } } for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { - if (p.empty()) continue; - evaluation.total_departures += p.at(node_idx).costs(s_val[node_idx]); - if (request.has_max_departures() && - evaluation.total_departures > request.max_departures().coeff()) { + if (p.empty()) { + continue; + } + evaluation.total_departures += p[node_idx][s_val[node_idx]]; + + if (params.max_departures.has_value() && + evaluation.total_departures > *params.max_departures) { evaluation.violation_codes.insert(kMaxDeparturesViolationCode); } } @@ -1180,10 +1218,11 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, evaluation.violation_codes.insert(kMemoryViolationCode); } } - if (overbudget_coeff.has_value()) { - evaluation.total.overbudget_cost = *overbudget_coeff * total_overbudget; + if (params.overbudget_coeff.has_value()) { + evaluation.total.overbudget_cost = + *params.overbudget_coeff * total_overbudget; evaluation.lower_bound.overbudget_cost = - *overbudget_coeff * lower_bound_overbudget; + *params.overbudget_coeff * 
lower_bound_overbudget; } } // Compute metrics and lower bounds. diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index 551f9c7aff111c..5fe51df206b8ce 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -42,9 +42,18 @@ struct AutoShardingSolverOutput { // Determines the minimum memory budget required to avoid memory violations. double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request); +struct AutoShardingSolverParams { + std::vector<std::vector<double>> departure_costs; + std::optional<double> max_departures; + bool minimize_departures = false; + std::optional<double> overbudget_coeff; +}; + +AutoShardingSolverParams GetParams(const AutoShardingSolverRequest& request); + absl::StatusOr<AutoShardingSolverOutput> FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& request, - std::optional<double> overbudget_coeff); + const AutoShardingSolverParams& params); // TODO(fahrbach): Create AutoShardingHeuristicOptions proto with a oneof field. // Runs a heuristic specified by one of the following values of `algorithm`: @@ -100,7 +109,7 @@ struct AutoShardingEvaluation { // solution quality metrics and validating the consistency of hard constraints. AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, const AutoShardingSolverOutput& result, - std::optional<double> overbudget_coeff); + const AutoShardingSolverParams& params); // Computes the objective value of the sharding strategy. If the objective value // is infinite or the sharding is infeasible (e.g., violates the peak-memory diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index e7668e29224f0b..aecf03d5cf5acb 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -85,13 +85,6 @@ void AddGroups( } } -std::optional<double> GetOverbudgetCoeff( - const AutoShardingSolverRequest& request) { - return request.has_overbudget_coeff() - ? 
std::make_optional(request.overbudget_coeff().coeff()) - : std::nullopt; -} - // clang-format off AutoShardingSolverRequest DefaultAutoShardingSolverRequest() { @@ -265,7 +258,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOptimally) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -280,7 +273,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 9007650.0; @@ -294,7 +287,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesMaxDepartures) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -308,7 +301,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, MinimizesDepartures) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 1, 0, 0, 1}; const double objective_value = 3.0; @@ -324,7 +317,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {3, 0, 0, 0, 0}; const double objective_value = 10683.0; @@ -338,7 +331,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteEdgeCosts) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -364,7 +357,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 12650.0; @@ -392,7 +385,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 13972.0; @@ -407,7 +400,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -421,7 +414,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HonorsMaxCost) { const absl::StatusOr result = FormulateAndSolveMIPFromSolverRequest(request, - GetOverbudgetCoeff(request)); + GetParams(request)); 
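// With max_cost constrained below the request's optimal objective, the MIP
// has no feasible solution; per SolveAndExtractSolution above, the solver
// then reports an internal error instead of producing an output.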
EXPECT_TRUE(absl::IsInternal(result.status())); } @@ -432,7 +425,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -457,7 +450,7 @@ TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -480,7 +473,7 @@ TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -515,7 +508,7 @@ TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -537,7 +530,7 @@ TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -580,7 +573,7 @@ TEST(DISABLED_FormulateAndSolveMIPFromSolverRequestTest, TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -594,7 +587,7 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesWithEquivalences) { TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, FormulateAndSolveMIPFromSolverRequest( - request, GetOverbudgetCoeff(request))); + request, GetParams(request))); const std::vector s_val = {0, 0, 5, 5, 1}; const double objective_value = 7650.0; @@ -609,7 +602,7 @@ TEST(AutoShardingEvaluatorTest, NoViolations) { const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 159.0; // 13+21+32+42+51 @@ -633,7 +626,7 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) { const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -659,7 +652,7 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudgetWithIntervals) { const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); 
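// Evaluate() replays the chosen strategies against the request, accumulating
// computation, communication, and resharding costs and recording any violated
// hard constraints in violation_codes; the expectations below enumerate each
// cost component of a violation-free solution.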
AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -693,7 +686,7 @@ TEST(DISABLED_AutoShardingEvaluatorTest, const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -717,7 +710,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesFollower) { const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kFollowerViolationCode}; @@ -740,7 +733,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesAlias) { const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kAliasViolationCode}; @@ -763,7 +756,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesMemory) { const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kMemoryViolationCode}; @@ -789,7 +782,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) { const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kInfiniteCostViolationCode}; @@ -813,7 +806,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) { const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kInfiniteCostViolationCode}; @@ -837,7 +830,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) { const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingEvaluation evaluation = - Evaluate(request, output, GetOverbudgetCoeff(request)); + Evaluate(request, output, GetParams(request)); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kMaxDeparturesViolationCode}; From 12bec512cffdfdd94fd61cc0a5de7d127e0aa754 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 15 Apr 2025 10:02:31 -0700 Subject: [PATCH 0816/1324] [code cleanup] deprecate the `xla_gpu_enable_pipelined_collectives` flag in XLA GPU PiperOrigin-RevId: 747918456 --- third_party/xla/xla/debug_options_flags.cc | 7 ------- third_party/xla/xla/service/gpu/gpu_compiler.cc | 17 +++++------------ .../xla/xla/service/gpu/gpu_compiler_test.cc | 9 +-------- .../xla/xla/tests/collective_ops_e2e_test.cc | 3 --- third_party/xla/xla/xla.proto | 2 +- 5 files changed, 7 insertions(+), 31 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index fa4446a4f78a93..052e27833acef2 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -204,7 +204,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_memory_limit_slop_factor(95); opts.set_xla_gpu_enable_highest_priority_async_stream(true); - opts.set_xla_gpu_enable_pipelined_collectives(false); opts.set_xla_gpu_enable_pipelined_all_reduce(false); opts.set_xla_gpu_enable_pipelined_all_gather(false); opts.set_xla_gpu_enable_pipelined_reduce_scatter(true); @@ -1712,12 +1711,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, &DebugOptions::set_xla_gpu_enable_highest_priority_async_stream), debug_options->xla_gpu_enable_highest_priority_async_stream(), "Enable async stream to have the highest priority.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_enable_pipelined_collectives", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_collectives), - debug_options->xla_gpu_enable_pipelined_collectives(), - "Enable pipelinling of collective instructions (all-reduce, all-gather, " - "and reduce-scatter).")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_pipelined_all_reduce", bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_all_reduce), diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 8ce282559600c4..c78aef3f249562 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -916,8 +916,7 @@ absl::Status RunCollectiveOptimizationPasses( // Remove dead computations after collective quantization. 
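// [Editor's aside, hedged -- not part of the patch. With the umbrella flag
// xla_gpu_enable_pipelined_collectives deprecated by this change, callers
// that previously set it are expected to opt in per collective, using the
// flags this patch keeps (the same setters the tests later in this patch
// use):
//
//   DebugOptions opts;
//   opts.set_xla_gpu_enable_pipelined_all_reduce(true);
//   opts.set_xla_gpu_enable_pipelined_all_gather(true);
//   opts.set_xla_gpu_enable_pipelined_reduce_scatter(true);
//   opts.set_xla_gpu_enable_pipelined_p2p(true);
//
// The next line continues the original pass pipeline.]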
collectives_pipeline.AddPass(); - if (debug_options.xla_gpu_enable_pipelined_collectives() || - debug_options.xla_gpu_enable_pipelined_all_reduce() || + if (debug_options.xla_gpu_enable_pipelined_all_reduce() || IsPassEnabledAtOptimizationEffort(*hlo_module)) { CollectivePipeliner::Config config{ /*level_to_operate_on=*/0, @@ -940,8 +939,7 @@ absl::Status RunCollectiveOptimizationPasses( }; collectives_pipeline.AddPass(config); } - if (debug_options.xla_gpu_enable_pipelined_collectives() || - debug_options.xla_gpu_enable_pipelined_all_gather() || + if (debug_options.xla_gpu_enable_pipelined_all_gather() || IsPassEnabledAtOptimizationEffort(*hlo_module)) { CollectivePipeliner::Config config{ /*level_to_operate_on=*/0, @@ -964,8 +962,7 @@ absl::Status RunCollectiveOptimizationPasses( }; collectives_pipeline.AddPass(config); } - if (debug_options.xla_gpu_enable_pipelined_collectives() || - debug_options.xla_gpu_enable_pipelined_reduce_scatter() || + if (debug_options.xla_gpu_enable_pipelined_reduce_scatter() || IsPassEnabledAtOptimizationEffort(*hlo_module)) { CollectivePipeliner::Config config{ /*level_to_operate_on=*/0, @@ -1017,8 +1014,7 @@ absl::Status RunCollectiveOptimizationPasses( bool enable_partial_send_recv_pipelining = pipeline_parallelism_opt_level != DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE; - if (debug_options.xla_gpu_enable_pipelined_collectives() || - debug_options.xla_gpu_enable_pipelined_p2p() || + if (debug_options.xla_gpu_enable_pipelined_p2p() || enable_partial_send_recv_pipelining) { collectives_pipeline.AddPass( enable_partial_send_recv_pipelining); @@ -2728,10 +2724,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( .debug_options() .xla_gpu_experimental_pipeline_parallelism_opt_level() == DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE && - (module->config() - .debug_options() - .xla_gpu_enable_pipelined_collectives() || - module->config().debug_options().xla_gpu_enable_pipelined_p2p())) { + (module->config().debug_options().xla_gpu_enable_pipelined_p2p())) { pipeline.AddPass(); } pipeline.AddPass(); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index bc8bd7a3fde0ec..cff463690608fc 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -1633,14 +1633,6 @@ TEST_F(PassOrderTest, /*last_pass_regex=*/"comparison-expander"); } -TEST_F(PassOrderTest, CollectivePipelinerRunsAfterCollectiveQuantizer) { - DebugOptions options = GetDebugOptionsForTest(); - options.set_xla_gpu_enable_pipelined_collectives(true); - SetDebugOptions(options); - - VerifyPassOrder(/*first_pass_regex=*/"collective-quantizer", - /*last_pass_regex=*/"collective-pipeliner.*"); -} TEST_F(PassOrderTest, AllGatherDynamicSliceSimplifierRunsAfterAllGatherOptimizer) { @@ -1867,6 +1859,7 @@ TEST_F(GpuCompilerTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr exec, test_runner_as_hlo_runner().ExecutableFromWrapped( std::move(wrapped_exec))); + std::cout << "exec module: " << exec->module().ToString() << "\n"; const char* kExpected = R"( // CHECK: dynamic-slice-fusion{{.+}} { // CHECK: %[[slice:.+]] = {{.+}} slice({{.+}}), slice={[4:8], [0:32]} diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 216ae9fc2685f5..a84445ff72a19f 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -1643,7 +1643,6 @@ 
ENTRY entry { HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); auto opts = GetDebugOptionsForTest(); - opts.set_xla_gpu_enable_pipelined_collectives(true); opts.set_xla_gpu_enable_triton_gemm(false); CollectiveOpsVerifyF8Matmul( absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts); @@ -1664,7 +1663,6 @@ class CollectiveOpsTestE2EPipelinedNonPipelined : public CollectiveOpsTestE2E { HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions); auto opts = GetDebugOptionsForTest(); - opts.set_xla_gpu_enable_pipelined_collectives(true); config.set_debug_options(opts); TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string, config)); @@ -1690,7 +1688,6 @@ class CollectiveOpsTestE2EPipelinedNonPipelined : public CollectiveOpsTestE2E { HloModuleConfig ref_config = GetModuleConfigForTest(kNumReplicas, kNumPartitions); auto ref_opts = GetDebugOptionsForTest(); - ref_opts.set_xla_gpu_enable_pipelined_collectives(false); ref_opts.set_xla_gpu_enable_pipelined_all_reduce(false); ref_opts.set_xla_gpu_enable_pipelined_all_gather(false); ref_opts.set_xla_gpu_enable_pipelined_reduce_scatter(false); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 31007c9e0d54d6..2e4d2c5b8df94d 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -456,7 +456,7 @@ message DebugOptions { bool xla_gpu_enable_pipelined_all_reduce = 217; - bool xla_gpu_enable_pipelined_collectives = 239; + bool xla_gpu_enable_pipelined_collectives = 239 [deprecated = true]; bool xla_gpu_enable_pipelined_p2p = 246; From f07cbcd1e89a0f3bb6f3f1a90f5c532b10be6322 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Tue, 15 Apr 2025 10:03:27 -0700 Subject: [PATCH 0817/1324] Implement ComputeAndCompareR1U8 in ClientLibraryTestRunnerMixin. PiperOrigin-RevId: 747918962 --- third_party/xla/xla/tests/BUILD | 2 ++ .../xla/tests/client_library_test_runner_mixin.h | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 4b466a0aadd576..70befa3997bfd2 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -386,10 +386,12 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/tsl/lib/core:bitmap", + "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/tests/client_library_test_runner_mixin.h b/third_party/xla/xla/tests/client_library_test_runner_mixin.h index 0dc15530243b07..1af00e7fb6a212 100644 --- a/third_party/xla/xla/tests/client_library_test_runner_mixin.h +++ b/third_party/xla/xla/tests/client_library_test_runner_mixin.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array2d.h" #include "xla/array3d.h" @@ -40,6 +41,7 @@ limitations under the License. 
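// [Editor's sketch, hedged -- not part of the patch. The ComputeAndCompareR1U8
// helper added below compares the u8[] result of a computation against an
// expected string via GetR1U8AsString(). A hypothetical use, assuming a
// builder whose root produces a u8[] literal:
//
//   XlaBuilder builder("r1u8");
//   ConstantR1<uint8_t>(&builder, {'h', 'i'});
//   ComputeAndCompareR1U8(&builder, "hi", /*arguments=*/{});
// ]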
#include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/bitmap.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -262,6 +264,20 @@ class ClientLibraryTestRunnerMixin : public T { ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } + // Compare with string. + // Side effect: EXPECT + void ComputeAndCompareR1U8(XlaBuilder* builder, + const absl::string_view expected, + absl::Span arguments) { + const absl::StatusOr actual = + ExecuteAndTransfer(builder, arguments); + TF_EXPECT_OK(actual.status()); + if (!actual.ok()) { + return; + } + EXPECT_EQ(actual->GetR1U8AsString(), expected); + } + XlaComputation CreateScalarMax() { return xla::CreateScalarMax(test_type_); } Literal CreateParameterAndTransferLiteral(const int64_t parameter_number, From 888de33af96f867b4ac270f526ec1d57919112cb Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Tue, 15 Apr 2025 10:24:20 -0700 Subject: [PATCH 0818/1324] Add Python 3.13 classifier Since https://pypi.org/project/tf-nightly/2.20.0.dev20250413/#files shows wheels for Python 3.13, the classifier also should be updated. --- tensorflow/tools/pip_package/setup.py.tpl | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/tools/pip_package/setup.py.tpl b/tensorflow/tools/pip_package/setup.py.tpl index e2495eed665c43..5a18f9a109d00f 100644 --- a/tensorflow/tools/pip_package/setup.py.tpl +++ b/tensorflow/tools/pip_package/setup.py.tpl @@ -430,6 +430,7 @@ setup( 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Programming Language :: Python :: 3 :: Only', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Mathematics', From 42b2132c7637c72bd1006ec54948df4bef86daf6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 10:04:08 -0700 Subject: [PATCH 0819/1324] Replace string, char* etc with absl::string_view. PiperOrigin-RevId: 747919293 --- .../hlo/analysis/while_loop_analysis_test.cc | 56 ++++++++++--------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc b/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc index 1604b4b798937b..ba165eb1f1c541 100644 --- a/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
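// [Editor's note, hedged -- not part of the patch. The rewrite below is
// mechanical: HLO snippets held as `const char*` or std::string become
// absl::string_view, which avoids copies and carries an explicit length.
// For string literals it is a drop-in replacement, e.g.:
//
//   constexpr absl::string_view kHlo = R"(
//     HloModule m
//     ENTRY e { ROOT p = s32[] parameter(0) }
//   )";
// ]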
#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -46,10 +47,11 @@ namespace { class WhileLoopAnalysisTest : public HloHardwareIndependentTestBase { protected: - [[nodiscard]] absl::StatusOr MakeWhileLoopAndGetTripCount( - int init, int limit, int step, ComparisonDirection dir); - [[nodiscard]] absl::StatusOr MakeWhileLoopAndGetRange( - int init, int limit, int step, ComparisonDirection dir); + absl::StatusOr MakeWhileLoopAndGetTripCount(int init, int limit, + int step, + ComparisonDirection dir); + absl::StatusOr MakeWhileLoopAndGetRange(int init, int limit, int step, + ComparisonDirection dir); }; absl::StatusOr WhileLoopAnalysisTest::MakeWhileLoopAndGetTripCount( @@ -152,7 +154,7 @@ absl::StatusOr WhileLoopAnalysisTest::MakeWhileLoopAndGetRange( } TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { - const char* const kHloModule = R"( + absl::string_view kHloModule = R"( HloModule ModuleWithWhile body { @@ -183,7 +185,7 @@ TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { } TEST_F(WhileLoopAnalysisTest, SimpleLoopWithCustomCallNonTuple) { - std::string hlo_string = R"( + absl::string_view hlo_string = R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[]{:T(128)}, s32[3]{0}) parameter(0) @@ -215,7 +217,7 @@ TEST_F(WhileLoopAnalysisTest, SimpleLoopWithCustomCallNonTuple) { } TEST_F(WhileLoopAnalysisTest, SimpleLoopWithCustomCall) { - std::string hlo_string = R"( + absl::string_view hlo_string = R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[]{:T(128)}, s32[3]{0}) parameter(0) @@ -248,7 +250,7 @@ TEST_F(WhileLoopAnalysisTest, SimpleLoopWithCustomCall) { } TEST_F(WhileLoopAnalysisTest, NoUpperBound) { - const char* const kHloModule = R"( + absl::string_view kHloModule = R"( HloModule ModuleWithWhile body { @@ -370,7 +372,7 @@ TEST_F(WhileLoopAnalysisTest, ExactBoundTrivialTripCount) { } TEST_F(WhileLoopAnalysisTest, NoAIVNoConstChain) { - const char* const kHloModule = R"( + absl::string_view kHloModule = R"( HloModule ModuleWithWhile body { @@ -407,7 +409,7 @@ TEST_F(WhileLoopAnalysisTest, NoAIVNoConstChain) { } TEST_F(WhileLoopAnalysisTest, AIVMultiChain) { - const char* const kHloModule = R"( + absl::string_view kHloModule = R"( HloModule ModuleWithWhile body { @@ -448,7 +450,7 @@ TEST_F(WhileLoopAnalysisTest, AIVMultiChain) { } TEST_F(WhileLoopAnalysisTest, NoAIV) { - const char* const kHloModule = R"( + absl::string_view kHloModule = R"( HloModule ModuleWithWhile body { @@ -485,7 +487,7 @@ TEST_F(WhileLoopAnalysisTest, NoAIV) { } TEST_F(WhileLoopAnalysisTest, AIVNoChain) { - const char* const kHloModule = R"( + absl::string_view kHloModule = R"( HloModule ModuleWithWhile body { @@ -522,7 +524,7 @@ TEST_F(WhileLoopAnalysisTest, AIVNoChain) { } TEST_F(WhileLoopAnalysisTest, NonScalarUpdateOp) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test, replica_count=2 add { param.3 = s32[] parameter(0) @@ -558,7 +560,7 @@ TEST_F(WhileLoopAnalysisTest, NonScalarUpdateOp) { } TEST_F(WhileLoopAnalysisTest, UpdateOnIndVarCopySuccess) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test, replica_count=2 body { param.0 = (s32[], s32[]) parameter(0) @@ -591,7 +593,7 @@ TEST_F(WhileLoopAnalysisTest, UpdateOnIndVarCopySuccess) { } TEST_F(WhileLoopAnalysisTest, IndVarInitialiationNotConstantSuccess) { - const char* hlo = 
R"( + absl::string_view hlo = R"( HloModule test, replica_count=2 body { param.0 = (s32[], s32[]) parameter(0) @@ -624,7 +626,7 @@ TEST_F(WhileLoopAnalysisTest, IndVarInitialiationNotConstantSuccess) { } TEST_F(WhileLoopAnalysisTest, FusedUpdateOp) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test, replica_count=2 add { param.3 = s32[] parameter(0) @@ -661,7 +663,7 @@ TEST_F(WhileLoopAnalysisTest, FusedUpdateOp) { } TEST_F(WhileLoopAnalysisTest, NonScalarConditionOp) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test, replica_count=2 add { param.3 = s32[] parameter(0) @@ -704,7 +706,7 @@ TEST_F(WhileLoopAnalysisTest, NonScalarConditionOp) { } TEST_F(WhileLoopAnalysisTest, IndvarWithNonScalarShape) { - const std::string hlo_string = R"( + absl::string_view hlo_string = R"( HloModule test loop.body { @@ -786,7 +788,7 @@ TEST_F(WhileLoopAnalysisTest, FusedConditionOp) { } TEST_F(WhileLoopAnalysisTest, AvoidBruteForceForHugeParams) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test fused_comp { p.0 = pred[100000000]{0} parameter(0) @@ -834,7 +836,7 @@ TEST_F(WhileLoopAnalysisTest, AvoidBruteForceForHugeParams) { TEST_F(WhileLoopAnalysisTest, LoopFusionForLoopVariable) { // This test verifies that fusions in initialization, condition and update are // accepted by while loop analysis. - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test fused_add.11 { param_0.968 = s32[] parameter(0) @@ -886,7 +888,7 @@ TEST_F(WhileLoopAnalysisTest, LoopFusionForLoopVariable) { } TEST_F(WhileLoopAnalysisTest, UpdateIsMultipleOperationsWithConstantOperand) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test body { param.1 = (s32[], s32[8,8]) parameter(0) @@ -920,7 +922,7 @@ TEST_F(WhileLoopAnalysisTest, UpdateIsMultipleOperationsWithConstantOperand) { TEST_F(WhileLoopAnalysisTest, UpdateIsMultipleOperationsWithoutConstantOperand) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test body { param.1 = (s32[], s32[8,8]) parameter(0) @@ -954,7 +956,7 @@ TEST_F(WhileLoopAnalysisTest, TEST_F(WhileLoopAnalysisTest, ConditionIsMultipleOperationsWithConstantOperand) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test body { param.1 = (s32[], s32[8,8]) parameter(0) @@ -988,7 +990,7 @@ TEST_F(WhileLoopAnalysisTest, TEST_F(WhileLoopAnalysisTest, ConditionIsMultipleOperationsWithoutConstantOperand) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test body { param.1 = (s32[], s32[8,8]) parameter(0) @@ -1021,7 +1023,7 @@ TEST_F(WhileLoopAnalysisTest, } TEST_F(WhileLoopAnalysisTest, GetIndvarIndexShouldWorkWhenParamIsCopied) { - const char* hlo = R"( + absl::string_view hlo = R"( HloModule test fused_copy { @@ -1063,7 +1065,7 @@ TEST_F(WhileLoopAnalysisTest, GetIndvarIndexShouldWorkWhenParamIsCopied) { TEST_F(WhileLoopAnalysisTest, MatchTrivialLoopCountFailsWhenIndvarIsNotIncrementedByConstant) { - const char* hlo_with_constant = R"( + absl::string_view hlo_with_constant = R"( HloModule test body { param.1 = (s32[], s32[]) parameter(0) @@ -1085,7 +1087,7 @@ TEST_F(WhileLoopAnalysisTest, tuple = (s32[], s32[]) tuple(c0, data) ROOT while = (s32[], s32[]) while(tuple), body=body, condition=condition })"; - const char* hlo_without_constant = R"( + absl::string_view hlo_without_constant = R"( HloModule test body { param.1 = (s32[], s32[]) parameter(0) From 05dc457ddb4c36494844a62512e326854d9f4fca Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Tue, 15 Apr 2025 10:47:21 -0700 Subject: 
[PATCH 0820/1324] [PjRt-GPU] Populate PjRt GPU client device's `process_index` property for multi-host `xla::PjRtStreamExecutorDeviceDescription` takes a `process_index` so that it can report a correct `process_index` property. This information is not fully propagated in a few places, so all `xla::PjRtStreamExecutorDeviceDescription` instances return `process_index() == 0` regardless of their owner process. This CL makes a minimal change to populate `xla::PjRtStreamExecutorDeviceDescription::process_index`, either by forwarding the owner client's process index or by computing it from the global device ids that a device description represents. This CL also removes the default value for `process_index` in both the device and the device description constructors, to help future code propagate `process_index` without accidentally assuming it to be 0. PiperOrigin-RevId: 747937501 --- third_party/xla/xla/pjrt/gpu/BUILD | 1 + .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 3 ++- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 24 +++++++++++++++++++ .../pjrt/gpu/se_gpu_topology_description.h | 9 ++++++- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 4 +++- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h | 1 + .../xla/pjrt/pjrt_stream_executor_client.h | 8 +++---- .../pjrt/pjrt_stream_executor_client_test.cc | 2 +- .../pjrt_stream_executor_device_description.h | 4 ++-- 9 files changed, 46 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index ee0363a8f51e7e..2fcfb4809bd114 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -194,6 +194,7 @@ xla_test( "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_stream_executor_client", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 8ae9b7210d6534..5be6c178af91ae 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1249,7 +1249,8 @@ StreamExecutorGpuDevice::StreamExecutorGpuDevice( std::string compute_capability, int core_count, int node_id, int slice_index) : PjRtStreamExecutorDevice(id, std::move(local_device_state), - std::move(device_kind), node_id), + /*process_index=*/node_id, + std::move(device_kind)), device_vendor_(std::move(device_vendor)), slice_index_(slice_index) { std::array coords = {local_device_id().value()}; diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index c2c9683e8bbc15..0800b2ef630ded 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -61,6 +61,7 @@ limitations under the License.
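// [Editor's sketch, hedged -- it mirrors the se_gpu_topology_description.h
// change in this patch rather than adding anything new. With global device
// ids assigned host by host, the owning process falls out of integer
// division:
//
//   int InferProcessIndex(int device_id, int num_devices_per_host) {
//     // -1 means the per-host device count is unknown; fall back to 0.
//     return num_devices_per_host == -1 ? 0
//                                       : device_id / num_devices_per_host;
//   }
//
// For example, with two devices per host, as in the new test's mock "2x2x2"
// topology, device 5 maps to process 5 / 2 == 2.]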
#include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_stream_executor_client.h" @@ -1198,6 +1199,29 @@ TEST(StreamExecutorGpuClientTest, GpuDeviceDescriptionTest) { } } +TEST(StreamExecutorGpuClientTest, GetTopologyDescriptionWithGlobalDevicesTest) { + const int num_nodes = 4; + GpuClientOptions options; + options.num_nodes = num_nodes; + options.enable_mock_nccl = true; + options.mock_gpu_topology = "2x2x2"; + + TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(options)); + int devices_per_host = client->addressable_device_count(); + + TF_ASSERT_OK_AND_ASSIGN(const PjRtTopologyDescription* topology, + client->GetTopologyDescription()); + + std::vector> + device_descriptions = topology->DeviceDescriptions(); + EXPECT_EQ(client->device_count(), device_descriptions.size()); + + for (const auto& device_description : device_descriptions) { + EXPECT_EQ(device_description->process_index(), + device_description->id() / devices_per_host); + } +} + TEST(TfrtCpuClientTest, CopyToMemorySpace) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_topology_description.h b/third_party/xla/xla/pjrt/gpu/se_gpu_topology_description.h index 13441e1dd938e4..59d7fa9dca9613 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_topology_description.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_topology_description.h @@ -66,9 +66,16 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription { const override { std::vector> devices; devices.reserve(gpu_topology_->number_of_devices()); + int32_t num_devices_per_host = gpu_topology_->num_devices_per_host(); for (const int device_id : gpu_topology_->device_ids()) { + // The process index of a device can be inferred from its global device id + // because global device ids are always assigned to each node in the + // topology in the order they appear in the input when constructing the + // global view. + const int process_index = + num_devices_per_host == -1 ? 
0 : (device_id / num_devices_per_host); devices.push_back(std::make_unique( - device_id, std::string(platform_version()))); + device_id, process_index, std::string(platform_version()))); } return devices; } diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc index 185c921b8a0667..02a639cfefc073 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc @@ -922,7 +922,7 @@ TfrtGpuDevice::TfrtGpuDevice(Options&& options) std::numeric_limits::max()), last_collective_launch_event_( tsl::MakeAvailableAsyncValueRef()), - description_(options.id, options.platform_version), + description_(options.id, options.process_index, options.platform_version), max_inflight_computations_semaphore_( /*capacity=*/options.max_inflight_computations) { description_.SetDebugString(absl::StrCat("TFRT_GPU_", id_)); @@ -1768,6 +1768,8 @@ GetTfrtGpuDevices(LocalClient* xla_client) { TfrtGpuDevice::Options options; options.id = i; + // TODO: b/382117736 - Support multi-host + options.process_index = 0; options.local_device_id = PjRtLocalDeviceId(i); options.local_hardware_id = PjRtLocalHardwareId(i); options.executor = executor; diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h index 32f8dd2766d0b8..ad53ae6e3a0607 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h @@ -120,6 +120,7 @@ class TfrtGpuDevice final : public PjRtDevice { public: struct Options { int id; + int32_t process_index; PjRtLocalDeviceId local_device_id; PjRtLocalHardwareId local_hardware_id; se::StreamExecutor* executor; diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index bf3a98c0998eb6..57b058c2a434bf 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -96,10 +96,10 @@ struct PjRtStreamExecutorExecutionOutput { class PjRtStreamExecutorDevice : public PjRtDevice { public: - explicit PjRtStreamExecutorDevice( - int id, std::unique_ptr local_device_state, - std::string device_kind, int process_index = 0) - : description_(id, std::move(device_kind), process_index), + PjRtStreamExecutorDevice(int id, + std::unique_ptr local_device_state, + int process_index, std::string device_kind) + : description_(id, process_index, std::move(device_kind)), local_device_id_(local_device_state ? 
local_device_state->local_device_id() : PjRtLocalDeviceId(-1)), diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc index cf6510a0b6da38..2f70860cddd67c 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc @@ -62,7 +62,7 @@ absl::StatusOr> GetClient() { /*allow_event_reuse=*/false, /*use_callback_stream=*/false); std::vector> devices; devices.emplace_back(std::make_unique( - 0, std::move(device_state), "cpu")); + 0, std::move(device_state), /*process_index=*/0, "cpu")); std::vector> memory_spaces; memory_spaces.emplace_back(std::make_unique( 0, devices.back().get(), "cpu", 0)); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_device_description.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_device_description.h index 25bc380f823fd2..2680d6826a8d82 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_device_description.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_device_description.h @@ -25,8 +25,8 @@ namespace xla { class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription { public: - explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind, - int process_index = 0) + PjRtStreamExecutorDeviceDescription(int id, int process_index, + std::string device_kind) : id_(id), process_index_(process_index), device_kind_(std::move(device_kind)) {} From 185f2f58bafc6410125080264d5d7730e1fa1eb2 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh Date: Tue, 15 Apr 2025 10:56:59 -0700 Subject: [PATCH 0821/1324] Fix support of weights with `RaggedTensor`s in `TPUEmbeddingV2`. Right now, when using a `RaggedTensor` and weights, `TPUEmbeddingV2` fails with error "RaggedTensor cannot be used as a bool". PiperOrigin-RevId: 747941247 --- tensorflow/python/tpu/tpu_embedding_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/tpu/tpu_embedding_v3.py b/tensorflow/python/tpu/tpu_embedding_v3.py index 87b23148287fa1..58e1ea430a91f7 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3.py +++ b/tensorflow/python/tpu/tpu_embedding_v3.py @@ -1703,7 +1703,7 @@ def _convert_input_feature_to_list_of_coo_tensors( ) ) elif isinstance(input_feature, ragged_tensor.RaggedTensor): - if not weight: + if weight is None: weight = array_ops.ones_like(input_feature.values, dtype=dtypes.float32) elif isinstance(weight, ragged_tensor.RaggedTensor): weight = weight.values From aa4b9c65cec41e21b49905c79c143b29ee683d2d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 11:01:37 -0700 Subject: [PATCH 0822/1324] Internal change only PiperOrigin-RevId: 747943148 --- third_party/xla/.bazelrc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index 1f88d5d7d18aa5..83bce85c4fbbf5 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -9,3 +9,9 @@ import %workspace%/tensorflow.bazelrc import %workspace%/warnings.bazelrc try-import %workspace%/xla_configure.bazelrc + +# absl_nonnull, absl_nullable, and absl_nullability_unknown are not yet present # in the version of absl we are using. # This can be removed when the absl version used is bumped to commit 48f0f91 or # newer, likely after July 2025.
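# [Editor's aside on the TPUEmbeddingV2 fix above, hedged. In Python,
# `if not weight:` asks the object for its truth value, which a RaggedTensor
# refuses to supply, while `if weight is None:` only tests for the
# missing-argument sentinel:
#
#   weight = make_ragged_tensor()  # hypothetical RaggedTensor value
#   if not weight: ...             # raises "RaggedTensor cannot be used as a bool"
#   if weight is None: ...         # fine; False for any tensor value
#
# The .bazelrc addition continues on the next line.]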
+common --copt=-Dabsl_nonnull='' --copt=-Dabsl_nullable='' --copt=-Dabsl_nullability_unknown='' \ No newline at end of file From aa03f7723b1aee3d31f6e2a1735014dca8b09f7a Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 15 Apr 2025 11:15:21 -0700 Subject: [PATCH 0823/1324] Disable `//xla/hlo/tools/tests:hlo_opt_emit_proto.hlo.test` due to flakiness Example log: https://github.com/openxla/xla/actions/runs/14475189517/job/40599018424 PiperOrigin-RevId: 747949036 --- third_party/xla/xla/hlo/tools/tests/BUILD | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/xla/xla/hlo/tools/tests/BUILD b/third_party/xla/xla/hlo/tools/tests/BUILD index 0fb22f00c5f56c..3592fdba116916 100644 --- a/third_party/xla/xla/hlo/tools/tests/BUILD +++ b/third_party/xla/xla/hlo/tools/tests/BUILD @@ -1,7 +1,6 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") load("//xla:py_strict.bzl", "py_strict_test") -load("//xla/tsl:tsl.bzl", "if_oss") load("//xla/tsl:tsl.default.bzl", "filegroup") package( @@ -42,7 +41,7 @@ lit_test_suite( "hlo_opt_hlo_protobinary.pb", ], tags_override = { - "hlo_opt_emit_proto.hlo": if_oss(["not_run:arm"]), # TODO(b/394180263) + "hlo_opt_emit_proto.hlo": ["no_oss"], # TODO(b/394180263) }, tools = [ "//xla/hlo/tools:hlo-opt", From 4548251674b8a4e66bf3830a0bef3df95436328f Mon Sep 17 00:00:00 2001 From: "Ryan M. Lefever" Date: Tue, 15 Apr 2025 11:50:50 -0700 Subject: [PATCH 0824/1324] Add a feature to expand scoped alternate memory to the biggest free chunk at the end of MSA. PiperOrigin-RevId: 747963240 --- .../memory_space_assignment/algorithm.cc | 106 ++++++++++++++ .../memory_space_assignment/algorithm.h | 4 + .../memory_space_assignment.proto | 11 ++ .../memory_space_assignment_test.cc | 131 ++++++++++++++++++ .../service/memory_space_assignment/options.h | 7 + 5 files changed, 259 insertions(+) diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index 6c8bff8b26f21e..e77fc625a51087 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -92,6 +92,10 @@ const HeapSimulator::Chunk kDummyChunk = // if the buffer occupies less of the execution time ratio than this value. const float kCrossProgramPrefetchOccupyFreeingLimit = 0.6; +int64_t GetAlignedOffset(int64_t offset, int64_t alignment) { + return CeilOfRatio(offset, alignment) * alignment; +} + template std::string VectorToString(const std::vector& v, bool include_indices = false, int start = 0, @@ -627,6 +631,103 @@ void MsaAlgorithm::FindAliases( } } +void MsaAlgorithm::ExtendScopedAlternateMemoryAllocations() { + VLOG(1) << "Starting vmem expansion"; + + // Iterate through all scoped allocations and try to expand them to the + // largest contiguous open space available. + for (std::unique_ptr& allocation : *allocations_) { + if (!allocation->is_scoped_allocation()) { + continue; + } + + // Find the set of nodes that are live during allocation. + std::vector live_nodes = interval_tree_.ChunksOverlappingInTime( + allocation->start_time(), allocation->end_time()); + absl::c_sort(live_nodes, [](const Chunk& a, const Chunk& b) { + return a.offset < b.offset; + }); + + // Loop over live_nodes to compute 2 things: + // 1. The largest contiguous free chunk (biggest_free_chunk) + // 2.
The largest chunk we can get by moving the start time of the scoped + // allocation earlier (i.e., to max_end_before_scoped_allocation), and + // the end time later (i.e., to min_offset_after_scoped_allocation). + int64_t min_offset_after_scoped_allocation = available_heap_size(); + int64_t max_end_before_scoped_allocation = 0; + Chunk biggest_free_chunk = Chunk::FromOffsetSize(0, 0); + for (int i = 0; i < live_nodes.size(); ++i) { + const Chunk& chunk = live_nodes[i]; + if (allocation->chunk().chunk_end() <= chunk.offset) { + min_offset_after_scoped_allocation = + std::min(min_offset_after_scoped_allocation, chunk.offset); + } + if (allocation->chunk().offset >= chunk.chunk_end()) { + max_end_before_scoped_allocation = + std::max(max_end_before_scoped_allocation, chunk.chunk_end()); + } + + Chunk next_free_chunk = Chunk::FromOffsetEnd( + GetAlignedOffset(chunk.chunk_end(), options_.alignment_in_bytes), + (i + 1) < live_nodes.size() ? live_nodes[i + 1].offset : available_heap_size()); + if (next_free_chunk.size > biggest_free_chunk.size) { + biggest_free_chunk = next_free_chunk; + } + } + + Chunk proposed_extended_chunk = + Chunk::FromOffsetEnd(GetAlignedOffset(max_end_before_scoped_allocation, + options_.alignment_in_bytes), + min_offset_after_scoped_allocation); + + // Check if we should extend the boundaries of the scoped allocation or + // move it. + Chunk proposed_chunk = allocation->chunk(); + std::string source; + if (proposed_extended_chunk.size > proposed_chunk.size) { + proposed_chunk = proposed_extended_chunk; + source = "extended"; + } + if (biggest_free_chunk.size > proposed_chunk.size) { + proposed_chunk = biggest_free_chunk; + source = "free"; + } + if (source.empty()) { + VLOG(3) << "Could not move the scoped allocation for " + << allocation->defining_position().ToString() + << "; Current fragmentation: " << + [&]() { + int64_t occupied_size = 0; + for (const Chunk& chunk : live_nodes) { + occupied_size += chunk.size; + } + double fragmentation = + static_cast(available_heap_size() - occupied_size) / + static_cast(available_heap_size()); + return 100.0 * fragmentation; + }() << "%"; + continue; + } + + VLOG(1) << "Moving the scoped allocation for " + << allocation->defining_position().ToString() << " from " + << allocation->chunk().ToString() << " to " + << proposed_chunk.ToString() << " (" << source + << "); Size increase: " + << (100.0 * + static_cast(proposed_chunk.size - + allocation->chunk().size) / + static_cast(allocation->chunk().size)) + << "%"; + + // Update the allocation. We don't need to update result_.chunk_map. It's + // not used by MSA. + *allocation->mutable_chunk() = proposed_chunk; + result_.UpdatedHeapSize(proposed_chunk); + } +} + std::string MsaAlgorithm::RequiredMemoryAssignment::ToString() const { std::string memory_space_str = memory_space == MemorySpace::kDefault ? "def" : "alt"; @@ -2093,6 +2194,11 @@ absl::StatusOr> MsaAlgorithm::Finish() { } } + if (options_.expanded_scoped_alternate_memory_mode == + ExpandedScopedAlternateMemoryMode::ENABLED) { + ExtendScopedAlternateMemoryAllocations(); + } + HeapSimulator::Result result; result.heap_size = result_.heap_size; result.heap_results.emplace_back(std::move(result_)); diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h index b32340cb8927ae..d45e909ebf1d74 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h @@ -350,6 +350,10 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { const HloAliasAnalysis& alias_analysis() { return alias_analysis_; } const HloLiveRange& hlo_live_range() { return hlo_live_range_; } + // Runs a feature that attempts to expand the size of scoped alternate memory + // allocations to the largest contiguous open space available. + void ExtendScopedAlternateMemoryAllocations(); + private: // We inherit AllocationBlock struct to attach the Allocation information to // make importing repacked offsets easier. diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto index 0588a1d294530c..1c8903b810ebe8 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto @@ -193,3 +193,14 @@ message MsaSortOrderOverride { message MsaSortOrderOverrides { repeated MsaSortOrderOverride overrides = 1; } + +// Expanded scoped alternate memory is a feature used at the end of MSA, +// in which we attempt to expand the size of allocated scoped alternate memory +// buffers to the largest contiguous open space available. +message ExpandedScopedAlternateMemoryMode { + enum Value { + UNDEFINED = 0; + DISABLED = 1; + ENABLED = 2; + } +} diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 982e87481ac26d..887258c212dab9 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -11705,6 +11705,137 @@ ENTRY main { EXPECT_EQ(f_index, p1_copy_end + 1); } +TEST_F(MemorySpaceAssignmentTest, ExpandScopedAlternateMemory) { + absl::string_view hlo_string = R"( + HloModule TestModule, is_scheduled=true + ENTRY Main { + p0 = f32[8,8] parameter(0) + p1 = f32[8,8] parameter(1) + p2 = f32[8,8] parameter(2) + p3 = f32[8,8] parameter(3) + + v0 = add(p0, p1) + v1 = add(v0, p1) + v2 = add(v1, p1) + + v3 = multiply(v2, p2) + v4 = multiply(v3, p3) + + ROOT t = tuple(v3, v4) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& lhs, const MsaBufferInterval& rhs) { + auto lookup = [](const MsaBufferInterval& x) { + // An arbitrary value that is greater than that used for 'prefetch'.
+ int priority = 100; + if (x.buffer->instruction()->name() == "p2") { + priority = 1; + } else if (x.buffer->instruction()->name() == "p3") { + priority = 2; + } + return std::make_tuple(priority, x.buffer->instruction()->name()); + }; + + return lookup(lhs) < lookup(rhs); + }; + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 1000); + + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 600; + options.reserved_scoped_memory_fn = + [](const HloInstruction* instruction, + const absl::flat_hash_set< + std::pair>& /*operands_in_alternate_memory*/, + const absl::flat_hash_set< + ShapeIndex>& /*outputs_in_alternate_memory*/) { return 10; }; + options.expanded_scoped_alternate_memory_mode = + ExpandedScopedAlternateMemoryMode::ENABLED; + options.alignment_in_bytes = 10; + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + + VLOG(1) << "Post-MSA module:\n" << module->ToString(); + + // We expect MSA to do the following: + // A. Initially allocate [0, 10) for scoped alternate memory, for each + // instruction. + // B. Since p2 comes first in the buffer sorting, we expect it to be + // allocated [10, 266) for a prefetch. + // C. Since p3 comes next in the buffer sorting, we expect it to be allocated + // [270, 526) for a prefetch. + // D. Finally, MSA will try to expand the scoped alternate memory allocations + // to the largest available buffers, keeping in mind the prefetches. + + // Check B and C. + for (const auto& [position, chunk] : preset_assignments->chunks()) { + if (position.instruction->opcode() == HloOpcode::kCopyDone) { + ASSERT_EQ(position.instruction->operand_count(), 1); + const HloInstruction* copy_start = position.instruction->operand(0); + ASSERT_EQ(copy_start->operand_count(), 1); + const HloInstruction* copy_operand = copy_start->operand(0); + if (copy_operand->name() == "p2") { + EXPECT_EQ(chunk.offset, 10); + EXPECT_EQ(chunk.size, 256); + } else if (copy_operand->name() == "p3") { + EXPECT_EQ(chunk.offset, 270); + EXPECT_EQ(chunk.size, 256); + } + } + } + + // Check D. + for (const auto& [instruction, chunk] : + preset_assignments->scoped_allocation_chunks()) { + if (instruction->name() == "p0") { + // Extended scoped allocation. + EXPECT_EQ(chunk.offset, 0); + EXPECT_EQ(chunk.size, 600); + } else if (instruction->name() == "p1") { + // Extended scoped allocation. + EXPECT_EQ(chunk.offset, 0); + EXPECT_EQ(chunk.size, 600); + } else if (instruction->name() == "p2") { + // Extended scoped allocation. + EXPECT_EQ(chunk.offset, 0); + EXPECT_EQ(chunk.size, 600); + } else if (instruction->name() == "p3") { + // Moved scoped allocation. + EXPECT_EQ(chunk.offset, 270); + EXPECT_EQ(chunk.size, 330); + } else if (instruction->name() == "v0") { + // Moved scoped allocation. + EXPECT_EQ(chunk.offset, 530); + EXPECT_EQ(chunk.size, 70); + } else if (instruction->name() == "v1") { + // Moved scoped allocation. + EXPECT_EQ(chunk.offset, 530); + EXPECT_EQ(chunk.size, 70); + } else if (instruction->name() == "v2") { + // Moved scoped allocation. + EXPECT_EQ(chunk.offset, 530); + EXPECT_EQ(chunk.size, 70); + } else if (instruction->name() == "v3") { + // Moved scoped allocation. + EXPECT_EQ(chunk.offset, 530); + EXPECT_EQ(chunk.size, 70); + } else if (instruction->name() == "v4") { + // Extended scoped allocation. + EXPECT_EQ(chunk.offset, 0); + EXPECT_EQ(chunk.size, 270); + } else if (instruction->name() == "t") { + // Extended scoped allocation.
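// [Editor's worked example, hedged. GetAlignedOffset, introduced earlier in
// this patch, rounds an offset up to the next multiple of the alignment:
//
//   GetAlignedOffset(/*offset=*/266, /*alignment=*/10) == 270
//
// which is why the expectations above place the p3 prefetch at offset 270:
// the p2 prefetch occupies [10, 266), and 266 rounds up to 270.
// The expectations for instruction `t` continue on the next line.]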
+ EXPECT_EQ(chunk.offset, 0); + EXPECT_EQ(chunk.size, 600); + } + } +} + class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { protected: // Used by CheckSchedule() to classify instructions in the schedule. diff --git a/third_party/xla/xla/service/memory_space_assignment/options.h b/third_party/xla/xla/service/memory_space_assignment/options.h index ff1a951d25a713..ef9825b29df8b4 100644 --- a/third_party/xla/xla/service/memory_space_assignment/options.h +++ b/third_party/xla/xla/service/memory_space_assignment/options.h @@ -374,7 +374,14 @@ struct Options { WindowPrefetchMode window_prefetch_mode = WindowPrefetchMode::kWindowExposure; MsaSortOrderOverrides msa_sort_order_overrides; + + // A mode that enables expanding scoped alternate memory allocations to the + // largest contiguous open space available. + ExpandedScopedAlternateMemoryMode::Value + expanded_scoped_alternate_memory_mode = + ExpandedScopedAlternateMemoryMode::DISABLED; }; + } // namespace memory_space_assignment } // namespace xla From 0054b0455108f8660508ad64aa3b6b7b870886db Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Tue, 15 Apr 2025 11:54:55 -0700 Subject: [PATCH 0825/1324] Reverts 67c4f3890322d918ce7de7286d1b237c87d5cc86 PiperOrigin-RevId: 747964889 --- third_party/xla/xla/debug_options_flags.cc | 2 +- third_party/xla/xla/xla.proto | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 052e27833acef2..d3a1683fd35e92 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -240,7 +240,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_ratio(1.1); - opts.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(true); + opts.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(false); opts.set_xla_gpu_unsafe_pipelined_loop_annotator(false); opts.set_xla_gpu_copy_insertion_use_region_analysis(false); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 2e4d2c5b8df94d..af97d839525c49 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -754,8 +754,6 @@ message DebugOptions { // `xla_gpu_cublas_fallback` set to false. bool xla_gpu_triton_gemm_any = 190; - // TODO(b/409940111): Remove this flag and use high precision reductions for - // Split-K GEMMs unconditionally. bool xla_gpu_triton_gemm_disable_reduced_precision_reduction = 226; // It is usually preferable to not fallback to the driver; it can consume more From 224b164ea34d3876511bd4ef156d7ec19e4b2af5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 11:57:02 -0700 Subject: [PATCH 0826/1324] Add unit tests for hlo_computation APIs. 
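(Editor's note, hedged.) The new tests below exercise two HloComputation
behaviors: MakeInstructionPostOrder, which must list every operand before its
users and return the identical vector when called again, and the
caller/callee bookkeeping, which must be updated when a callee computation is
swapped via ReplaceWithNewInstruction. A minimal check in the same spirit:

  auto post_order = computation->MakeInstructionPostOrder();
  // Every instruction appears after all of its operands, and the result is
  // stable across calls.
  EXPECT_EQ(post_order, computation->MakeInstructionPostOrder());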
PiperOrigin-RevId: 747965621 --- third_party/xla/xla/hlo/ir/BUILD | 14 ++ .../xla/xla/hlo/ir/hlo_computation_test.cc | 215 ++++++++++++++++++ 2 files changed, 229 insertions(+) create mode 100644 third_party/xla/xla/hlo/ir/hlo_computation_test.cc diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 72bca9a992350a..82d03ec0abd9e7 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -122,6 +122,20 @@ cc_library( ], ) +xla_cc_test( + name = "hlo_computation_test", + srcs = ["hlo_computation_test.cc"], + deps = [ + ":hlo", + "//xla:shape_util", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + ], +) + xla_cc_test( name = "hlo_instruction_test", srcs = ["hlo_instruction_test.cc"], diff --git a/third_party/xla/xla/hlo/ir/hlo_computation_test.cc b/third_party/xla/xla/hlo/ir/hlo_computation_test.cc new file mode 100644 index 00000000000000..e9ef0eaf8efc48 --- /dev/null +++ b/third_party/xla/xla/hlo/ir/hlo_computation_test.cc @@ -0,0 +1,215 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_computation.h" + +#include +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { +namespace { + +using HLOComputationTest = HloHardwareIndependentTestBase; + +int64_t CountControlEdges(const HloComputation &computation) { + int64_t count = 0; + for (const auto &instruction : computation.instructions()) { + count += instruction->control_successors().size(); + } + return count; +} + +TEST_F(HLOComputationTest, DefUseOrder) { + absl::string_view hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT out = f32[] add(a, b) +} + +ENTRY entry { + p0 = f32[100] parameter(0), parameter_replication={false} + p1 = f32[100] parameter(1), parameter_replication={false} + add0 = f32[100] add(p0, p1) + mul0 = f32[100] multiply(p0, p1) + div0 = f32[100] divide(p0, p1) + reduce0 = f32[100] all-reduce(add0), replica_groups={}, to_apply=sum, channel_id=1 + reduce1 = f32[100] all-reduce(mul0), replica_groups={}, to_apply=sum, channel_id=1 + reduce2 = f32[100] all-reduce(div0), replica_groups={}, to_apply=sum, channel_id=1 + add1 = f32[100] add(reduce0, reduce1) + ROOT out = f32[100] add(add1, reduce2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + EXPECT_EQ(CountControlEdges(*module->entry_computation()), 0); + + const HloInstruction *root = module->entry_computation()->root_instruction(); + const HloInstruction *add1 = root->operand(0); 
// t = add(c1, c2) + const HloInstruction *reduce2 = root->operand(1); // c3 = all-reduce(i2)... + EXPECT_EQ(add1->opcode(), HloOpcode::kAdd); + EXPECT_EQ(reduce2->opcode(), HloOpcode::kAllReduce); + + const HloInstruction *reduce0 = add1->operand(0); + const HloInstruction *reduce1 = add1->operand(1); + EXPECT_EQ(reduce0->opcode(), HloOpcode::kAllReduce); + EXPECT_EQ(reduce1->opcode(), HloOpcode::kAllReduce); + + bool found_add0 = false; + // Verify that i0 is before c1. + auto post_order = module->entry_computation()->MakeInstructionPostOrder(); + for (const auto &instruction : post_order) { + if (instruction->name() == "reduce0") { + EXPECT_TRUE(found_add0); + } + if (instruction->name() == "add0") { + found_add0 = true; + } + } + + // Verify that MakeInstructionPostOrder() is idempotent. + auto post_order_2 = module->entry_computation()->MakeInstructionPostOrder(); + EXPECT_EQ(post_order, post_order_2); +} + +TEST_F(HLOComputationTest, MakeInstructionPostOrder) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add0 = f32[100] add(p0, p1) + mul0 = f32[100] multiply(p0, add0) + ROOT div0 = f32[100] divide(p1, mul0) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto post_order = module->entry_computation()->MakeInstructionPostOrder(); + + // Verify the order of instructions in the post order. + bool found_p0 = false; + bool found_p1 = false; + bool found_add0 = false; + bool found_mul0 = false; + for (HloInstruction *instruction : post_order) { + if (instruction->name() == "add0") { + EXPECT_TRUE(found_p0); + EXPECT_TRUE(found_p1); + found_add0 = true; + } else if (instruction->name() == "mul0") { + EXPECT_TRUE(found_p0); + EXPECT_TRUE(found_add0); + found_mul0 = true; + } else if (instruction->name() == "div0") { + EXPECT_TRUE(found_p1); + EXPECT_TRUE(found_mul0); + } else if (instruction->name() == "p0") { + found_p0 = true; + } else if (instruction->name() == "p1") { + found_p1 = true; + } + } + + // Verify that MakeInstructionPostOrder() is idempotent. + auto post_order_2 = module->entry_computation()->MakeInstructionPostOrder(); + EXPECT_EQ(post_order, post_order_2); +} + +// Test AddCallee +TEST_F(HLOComputationTest, AddCallee) { + absl::string_view hlo_string = R"( +HloModule module +diff { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT out = f32[] add(a, b) +} + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT out = f32[] add(a, b) +} + +ENTRY entry { + p0 = f32[100] parameter(0), parameter_replication={false} + p1 = f32[100] parameter(1), parameter_replication={false} + map0 = f32[100] map(p0, p1), to_apply=diff + ROOT map1 = f32[100] map(p0, map0), to_apply=sum +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + HloComputation *entry = module->entry_computation(); + HloComputation *sum = module->GetComputationWithName("sum"); + ASSERT_NE(entry, nullptr); + ASSERT_NE(sum, nullptr); + + EXPECT_EQ(entry->callee_computations().size(), 2); + EXPECT_TRUE(entry->callee_computations().contains(sum)); + EXPECT_EQ(sum->caller_computations().size(), 1); + EXPECT_EQ(sum->caller_computations().count(entry), 1); + + // Get the operands of the add. + HloInstruction *entry_a = entry->root_instruction()->mutable_operand(0); + HloInstruction *entry_b = entry->root_instruction()->mutable_operand(1); + + // Create a new computation and add it as a callee. 
+ auto builder = HloComputation::Builder("mul"); + auto a = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "a")); + auto b = builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "b")); + + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kMultiply, a, b)); + + HloComputation *mul_comp = module->AddEmbeddedComputation(builder.Build()); + + auto map = HloInstruction::CreateMap(entry->root_instruction()->shape(), + {entry_a, entry_b}, mul_comp); + + // Add the new computation as a callee of the entry computation. + EXPECT_EQ(entry->ReplaceWithNewInstruction(entry->root_instruction(), + std::move(map)), + absl::OkStatus()); + + HloComputation *mul_int = module->GetComputationWithName("mul"); + + EXPECT_EQ(entry->callee_computations().size(), 2); + EXPECT_FALSE(entry->callee_computations().contains(sum)); + EXPECT_EQ(entry->callee_computations().count(mul_int), 1); + EXPECT_EQ(sum->caller_computations().size(), 0); + EXPECT_EQ(mul_int->caller_computations().size(), 1); + EXPECT_TRUE(mul_int->caller_computations().contains(entry)); +} +} // namespace +} // namespace xla From 8b12b0b645cf7a7523893b9a631b2b830bbd7df7 Mon Sep 17 00:00:00 2001 From: Bryan Massoth Date: Tue, 15 Apr 2025 11:59:37 -0700 Subject: [PATCH 0827/1324] Fix duty cycle calculations to include xla modules to account for idle time at the beginning/end of profiles PiperOrigin-RevId: 747966685 --- tensorflow/core/profiler/convert/BUILD | 1 + .../profiler/convert/xplane_to_op_stats.cc | 20 ++++++++++++++----- .../convert/xplane_to_op_stats_test.cc | 11 +++++++++- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index f56a989dd45613..858cee63d4cb79 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -457,6 +457,7 @@ cc_library( "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", "@local_xla//xla/tsl/profiler/utils:timespan", "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/utils:xplane_utils", "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hardware_types_proto_cc", "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:kernel_stats_proto_cc", diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index f01e8a833d64ba..bc3a6397c41fc3 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" +#include + #include #include #include @@ -30,6 +32,7 @@ limitations under the License. 
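// [Editor's note, hedged -- summarizing the fix below, not extending it.
// XLA-module and SparseCore-module line events are now folded into the duty
// cycle tracker as inactive intervals, so idle time at the beginning and end
// of a module is no longer dropped. With the updated test's numbers: the op
// events contribute 20ps of active time, and a module event spanning [5, 55)
// widens the tracked window so that idle time becomes 30ps rather than
// 20ps.]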
#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "xla/tsl/profiler/utils/timespan.h" #include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/convert/duty_cycle_combiner.h" #include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" @@ -274,16 +277,23 @@ DutyCycleTracker ConstructDutyCycleTracker(XPlaneVisitor& visitor) { line.ForEachEvent([&](const XEventVisitor& event) { auto hlo_category_stat = event.GetStat(StatType::kHloCategory); duty_cycle_tracker.AddInterval( - Timespan(event.OffsetPs(), event.DurationPs()), + event.GetTimespan(), !(hlo_category_stat && tsl::profiler::IsOffDutyOp(hlo_category_stat->StrOrRefValue()))); }); - } else if (line.Name() == kSparseCoreOpLineName || - line.Name() == kSparseCoreModuleLineName) { + } else if (line.Name() == kSparseCoreOpLineName) { line.ForEachEvent([&](const XEventVisitor& event) { duty_cycle_tracker.AddInterval( - Timespan(event.OffsetPs(), event.DurationPs()), - /*is_active=*/line.Name() == kSparseCoreOpLineName); + event.GetTimespan(), + // TODO(b/397774568): Add support for SparseCore off-duty ops. + /*is_active=*/true); + }); + } else if (line.Name() == tsl::profiler::kXlaModuleLineName || + line.Name() == tsl::profiler::kSparseCoreModuleLineName) { + line.ForEachEvent([&](const XEventVisitor& event) { + duty_cycle_tracker.AddInterval(event.GetTimespan(), + /*is_active=*/false); + return; }); } }); diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc index c62b36428c9a28..91009537b23aca 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc @@ -606,11 +606,16 @@ TEST(ConvertXPlaneToOpStats, ConstructDutyCycleTrackerFromXlaOps) { CreateXEvent(&device_plane_builder, &op_line, "op.4", /*offset_ps=*/40, /*duration_ps=*/10, {{StatType::kHloCategory, tsl::profiler::kHloOutfeed}}); + XLineBuilder xla_module_line = device_plane_builder.GetOrCreateLine(1); + xla_module_line.SetName(kXlaModuleLineName); + CreateXEvent(&device_plane_builder, &xla_module_line, "module.1", + /*offset_ps=*/5, + /*duration_ps=*/50); XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(device_plane); DutyCycleTracker tracker = ConstructDutyCycleTracker(visitor); EXPECT_EQ(tracker.GetActiveTimePs(), 20); - EXPECT_EQ(tracker.GetIdleTimePs(), 20); + EXPECT_EQ(tracker.GetIdleTimePs(), 30); } TEST(ConvertXPlaneToOpStats, ConstructDutyCycleTrackerFromSparseCore) { @@ -668,6 +673,10 @@ TEST(ConvertXPlaneToOpStats, MultiCoreChipBusyAndIdleTimeTest) { CreateXEvent(&tc_plane_builder, &xla_op_line, "op.4", /*offset_ps=*/40, /*duration_ps=*/10, {{StatType::kHloCategory, tsl::profiler::kHloOutfeed}}); + XLineBuilder xla_module_line = tc_plane_builder.GetOrCreateLine(1); + xla_module_line.SetName(kXlaModuleLineName); + CreateXEvent(&tc_plane_builder, &xla_module_line, "module.1", /*offset_ps=*/5, + /*duration_ps=*/50); XPlane* sc_plane = GetOrCreateTpuXPlane( &space, /*device_ordinal=*/1, /*device_type=*/"TPU v4", From 8deb4a59e3523abdc10d99d404956af0566b23dc Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Tue, 15 Apr 2025 12:15:32 -0700 Subject: [PATCH 0828/1324] Fix treatment of `shape_with_output_layout` in `ClientLibraryTestRunnerMixin`. 
PiperOrigin-RevId: 747972915
---
 third_party/xla/xla/tests/BUILD               |  1 +
 .../tests/client_library_test_runner_mixin.h  | 35 +++++++++++++------
 2 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD
index 70befa3997bfd2..536790d796894b 100644
--- a/third_party/xla/xla/tests/BUILD
+++ b/third_party/xla/xla/tests/BUILD
@@ -385,6 +385,7 @@ cc_library(
         "//xla/hlo/builder:xla_computation",
         "//xla/hlo/ir:hlo",
         "//xla/service:hlo_module_config",
+        "//xla/service:hlo_module_util",
         "//xla/tsl/lib/core:bitmap",
         "//xla/tsl/lib/core:status_test_util",
         "//xla/tsl/platform:errors",
diff --git a/third_party/xla/xla/tests/client_library_test_runner_mixin.h b/third_party/xla/xla/tests/client_library_test_runner_mixin.h
index 1af00e7fb6a212..51e95c328c4aaf 100644
--- a/third_party/xla/xla/tests/client_library_test_runner_mixin.h
+++ b/third_party/xla/xla/tests/client_library_test_runner_mixin.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include
 #include
 #include
+#include <vector>

 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
@@ -35,6 +36,7 @@ limitations under the License.
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_module_util.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/tests/client_library_test_runner_utils.h"
@@ -103,9 +105,14 @@ class ClientLibraryTestRunnerMixin : public T {
       *execution_options.mutable_shape_with_output_layout() =
           shape_with_output_layout->ToProto();
     }
-    TF_ASSIGN_OR_RETURN(
-        std::unique_ptr<HloModule> module,
-        BuildAndVerifyHloModule(computation, &execution_options));
+    std::vector<const Shape*> argument_shapes;
+    argument_shapes.reserve(arguments.size());
+    for (const Literal* argument : arguments) {
+      argument_shapes.push_back(&argument->shape());
+    }
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+                        BuildAndVerifyHloModule(computation, argument_shapes,
+                                                &execution_options));
     return this->Execute(std::move(module), arguments);
   }

@@ -138,9 +145,15 @@ class ClientLibraryTestRunnerMixin : public T {
   void ComputeAndCompare(XlaBuilder* const builder,
                          const absl::Span<const Literal* const> arguments,
                          const std::optional<ErrorSpec> error = std::nullopt) {
+    std::vector<const Shape*> argument_shapes;
+    argument_shapes.reserve(arguments.size());
+    for (const Literal* argument : arguments) {
+      argument_shapes.push_back(&argument->shape());
+    }
     TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder->Build());
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                            BuildAndVerifyHloModule(computation));
+    TF_ASSERT_OK_AND_ASSIGN(
+        std::unique_ptr<HloModule> module,
+        BuildAndVerifyHloModule(computation, argument_shapes));
     EXPECT_TRUE(this->RunAndCompare(std::move(module), arguments, error));
   }

@@ -372,18 +385,20 @@ class ClientLibraryTestRunnerMixin : public T {
  private:
  absl::StatusOr<std::unique_ptr<HloModule>> BuildAndVerifyHloModule(
      const XlaComputation& computation,
+      absl::Span<const Shape* const> argument_shapes,
      const ExecutionOptions* execution_options = nullptr) const {
    if (execution_options == nullptr) {
      execution_options = &execution_options_;
    }
+    TF_ASSIGN_OR_RETURN(const ProgramShape program_shape,
+                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
-        HloModuleConfig module_config,
-        HloModule::CreateModuleConfigFromProto(
-            computation.proto(), execution_options->debug_options(),
-            execution_options));
+        std::unique_ptr<HloModuleConfig> module_config,
+        CreateModuleConfig(program_shape, argument_shapes, execution_options,
+                           /*default_num_replicas=*/1));
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModule> module,
-
HloModule::CreateFromProto(computation.proto(), module_config)); + HloModule::CreateFromProto(computation.proto(), *module_config)); TF_RETURN_IF_ERROR(this->verifier().Run(module.get()).status()); return module; } From 3a198832c8e025e43fa70c68d52a03e17fe0008d Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Tue, 15 Apr 2025 12:24:21 -0700 Subject: [PATCH 0829/1324] [cuDNN] Always set is_disabled_x32 to false This flag appears to always be false in current usage. This change is a prerequisite for removing legacy cuDNN API usage in an upcoming PR. PiperOrigin-RevId: 747976303 --- .../xla/xla/stream_executor/cuda/cuda_dnn.cc | 55 +------------------ 1 file changed, 2 insertions(+), 53 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index d7135689aa9448..1c3867318e9ee2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -6283,32 +6283,7 @@ absl::Status CudnnSupport::GetConvolveRunners( ScratchAllocator* /*scratch_allocator*/, const NumericOptions& numeric_options, std::vector>* out_exec_plans) { - // cuDNN frontend support for Tx32 convolutions added in 8.3. - // If the filter is not reordered, do not use frontend (it is slow). - const bool is_disabled_x32 = - input_descriptor.layout() == dnn::kBatchDepthYX32 && - (filter_descriptor.layout() != - dnn::FilterLayout::kOutputInputYX32_CudnnReordered); - - const bool actually_use_cudnn_frontend = - use_cudnn_frontend && !is_disabled_x32; - - if (use_cudnn_frontend && !actually_use_cudnn_frontend) { - // This will happen once per unique conv configuration/shape that gets - // affected (and not, for example, on every conv launch). Confusion over - // whether this has happened or not has repeatedly wasted a lot of time - // debugging, so be sure it shows up in the logs. - LOG(INFO) << "Disabling cuDNN frontend for the following convolution:\n" - << " input: " << input_descriptor.ToString() << "\n" - << " filter: " << filter_descriptor.ToString() << "\n" - << " " << convolution_descriptor.ToString() << "\n" - << " ... because " - << (is_disabled_x32 - ? "Tx32 convolutions are disabled." - : "the current cuDNN version does not support it."); - } - - if (!actually_use_cudnn_frontend) { + if (!use_cudnn_frontend) { auto cuda_compute_capability = stream->GetCudaComputeCapability(); std::vector algorithms; bool got_algos = false; @@ -6758,32 +6733,6 @@ absl::Status CudnnSupport::GetFusedConvolveRunners( // implicitly do ReLU on some engines, and we can't reliably detect which // ones. - // If the filter is not reordered, do not use frontend (it is slow). - const bool is_disabled_x32 = - input_descriptor.layout() == dnn::kBatchDepthYX32 && - (filter_descriptor.layout() != - dnn::FilterLayout::kOutputInputYX32_CudnnReordered); - - const bool actually_use_cudnn_frontend = - use_cudnn_frontend && !is_disabled_x32; - - if (use_cudnn_frontend && !actually_use_cudnn_frontend) { - const char* reason = "the current cuDNN version does not support it."; - if (is_disabled_x32) { - reason = "Tx32 convolutions are disabled."; - } - - // This will happen once per unique conv configuration/shape that gets - // affected (and not, for example, on every conv launch). Confusion over - // whether this has happened or not has repeatedly wasted a lot of time - // debugging, so be sure it shows up in the logs. 
-    LOG(INFO) << "Disabling cuDNN frontend for the following convolution:\n"
-              << "  input: " << input_descriptor.ToString() << "\n"
-              << "  filter: " << filter_descriptor.ToString() << "\n"
-              << "  " << convolution_descriptor.ToString() << "\n"
-              << "  ... because " << reason;
-  }
-
   if (input_type == dnn::DataType::kInt8 &&
       !stream->GetCudaComputeCapability().IsAtLeast(6, 1)) {
     return tsl::errors::Unimplemented(
@@ -6801,7 +6750,7 @@ absl::Status CudnnSupport::GetFusedConvolveRunners(
         "{Relu, Relu6, Elu, }.");
   }

-  if (!actually_use_cudnn_frontend) {
+  if (!use_cudnn_frontend) {
     std::vector<dnn::AlgorithmDesc> algorithms;
     auto cuda_compute_capability = stream->GetCudaComputeCapability();

From 2c7783c36cd4c2a1c1ab421f16a76b4ff449be23 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 15 Apr 2025 12:36:09 -0700
Subject: [PATCH 0830/1324] *** Reason for rollback ***

Fix the regression. Instead of not broadcasting scalar constants, we now
explicitly check whether a constant is the padding value.

Reverts c2a3c368e79be0292faeac380086c42169763908

PiperOrigin-RevId: 747980341
---
 .../xla/xla/service/collective_pipeliner.cc   | 36 +++++++++++--------
 .../xla/service/collective_pipeliner_test.cc  |  5 ++-
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc
index 60e57b4672641f..01fadac3d7d9ef 100644
--- a/third_party/xla/xla/service/collective_pipeliner.cc
+++ b/third_party/xla/xla/service/collective_pipeliner.cc
@@ -211,7 +211,7 @@ CollectDynamicSliceIndicesIfConstant(HloInstruction* instr) {
   for (int64_t i = dyn_slice->first_index_operand_number();
        i < instr->operand_count(); ++i) {
     HloInstruction* operand = dyn_slice->mutable_operand(i);
-    CHECK_EQ(operand->shape().dimensions_size(), 0);
+    CHECK(operand->shape().dimensions().empty());
     std::vector<std::pair<HloInstruction*, int>> stack(
         1, std::make_pair(operand, 0));
     absl::flat_hash_set<HloInstruction*> visited;
@@ -343,12 +343,11 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr,
                ShapeUtil::ElementsIn(instr->operand(0)->shape()) < 1024)) {
     return true;
   }
-  // TODO(b/409716406): Reconsider cases where Pad can be supported.
   return HloPredicateIsOp(i) ||
+                          HloOpcode::kPad, HloOpcode::kCollectivePermute,
+                          HloOpcode::kConvert, HloOpcode::kReshape,
+                          HloOpcode::kAllReduce, HloOpcode::kTranspose,
+                          HloOpcode::kBroadcast, HloOpcode::kAllGather>(i) ||
          (multi_uses_pipelining && i->IsElementwise()) ||
          i->IsCustomCall(CollectivePipeliner::kInsertedByPreviousStep) ||
          i->IsCustomCall(CollectivePipeliner::kSunkByPreviousStep);
@@ -1529,7 +1528,7 @@ Shape ComputeFullOutputShape(const WhileMoveInfo& move_info,
 // Create zero of base type ptype and broadcast it to shape.
 HloInstruction* CreateZero(HloComputation* comp, const Shape& shape,
                            PrimitiveType ptype) {
-  if (shape.dimensions_size() == 0) {
+  if (shape.dimensions().empty()) {
     return comp->AddInstruction(
         HloInstruction::CreateConstant(LiteralUtil::Zero(ptype)));
   }
@@ -2009,8 +2008,8 @@ absl::Status TransformLoopForward(
   if (slice_target_shape != data_to_slice->shape()) {
     // Slice matrix.
absl::InlinedVector dynamic_slice_sizes; - dynamic_slice_sizes.reserve(slice_target_shape.dimensions_size()); - for (int i = 0; i < slice_target_shape.dimensions_size(); ++i) { + dynamic_slice_sizes.reserve(slice_target_shape.dimensions().size()); + for (int i = 0; i < slice_target_shape.dimensions().size(); ++i) { dynamic_slice_sizes.push_back(slice_target_shape.dimensions(i)); } sliced_data = @@ -2268,7 +2267,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, Shape index_shape = move_info.dynamic_update_slices.front()->index_shapes()[0]; std::vector indices( - expanded_shape.dimensions_size(), + expanded_shape.dimensions().size(), CreateZero(body_computation, index_shape, index_shape.element_type())); indices[0] = move_info.dynamic_update_slices.front()->index_operands()[0]; @@ -2313,7 +2312,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, HloDynamicUpdateSliceInstruction* dyn_update = to_move.dynamic_update_slices[0]; std::vector indices( - expanded_shape.dimensions_size(), + expanded_shape.dimensions().size(), CreateZero(body_computation, dyn_update->index_shapes()[0], dyn_update->index_shapes()[0].element_type())); indices[0] = dyn_update->index_operands()[0]; @@ -2434,7 +2433,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, if (is_loop_invariant) { Shape full_shape = ComputeFullOutputShape(to_move, pipelined->shape()); absl::InlinedVector operand_dims; - operand_dims.resize(pipelined->shape().dimensions_size()); + operand_dims.resize(pipelined->shape().dimensions().size()); absl::c_iota(operand_dims, 1); HloInstruction* broadcasted = loop_computation->AddInstruction(HloInstruction::CreateBroadcast( @@ -2460,6 +2459,15 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, std::vector operands; for (auto* operand : instr->mutable_operands()) { if (operand->opcode() == HloOpcode::kConstant) { + if (instr->opcode() == HloOpcode::kPad && + instr->operand_index(operand) == 1) { + // No need to broadcast the padding value. + operands.push_back(loop_computation->AddInstruction( + operand->CloneWithNewOperands(operand->shape(), {}))); + continue; + } + + // Broadcast constant into full shape. HloInstruction* cloned_constant = loop_computation->AddInstruction( operand->CloneWithNewOperands(operand->shape(), {})); if (!to_add_batch_set.contains(instr)) { @@ -2469,7 +2477,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, Shape full_shape = ComputeFullOutputShape(to_move, cloned_constant->shape()); absl::InlinedVector operand_dims; - operand_dims.resize(cloned_constant->shape().dimensions_size()); + operand_dims.resize(cloned_constant->shape().dimensions().size()); absl::c_iota(operand_dims, 1); HloInstruction* broadcasted = loop_computation->AddInstruction(HloInstruction::CreateBroadcast( @@ -2546,7 +2554,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, } // Constant scalars don't get expanded ahead of time and are kept // scalar. 
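      // (Why keeping them scalar matters: some consumers require scalar
      // operands. The clearest case is the padding value, operand 1 of kPad,
      // handled by the special case earlier in this function:
      //   pad = bf16[1,8,128] pad(slice, c), padding=0_0x0_0x0_8
      // is only valid HLO when `c` has shape bf16[], so broadcasting that
      // constant to the full per-iteration shape would break the program.
      // That is exactly the regression this rollback fixes.)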
-      if (operands[0]->shape().dimensions_size() == 0) {
+      if (operands[0]->shape().dimensions().empty()) {
         dimensions.clear();
       }
       HloInstruction* expanded_broadcast =
diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc
index 556d0234db0500..e1c72ab46888c0 100644
--- a/third_party/xla/xla/service/collective_pipeliner_test.cc
+++ b/third_party/xla/xla/service/collective_pipeliner_test.cc
@@ -3253,7 +3253,10 @@ while_body {
   dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
   mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
   ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
-  b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2}
+  slice = bf16[1,8,120] slice(ar.1), slice={[0:1], [0:8], [0:120]}
+  constant.2563 = bf16[] constant(5.0)
+  pad = bf16[1,8,128] pad(slice, constant.2563), padding=0_0x0_0x0_8
+  b.1 = bf16[1,8,128,32] broadcast(pad), dimensions={0,1,2}
   constant = bf16[] constant(0)
   reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1
   dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561)

From 8d1fe68b1fc3affc5f660790afb7019415e79ffc Mon Sep 17 00:00:00 2001
From: Vamsi Manchala
Date: Tue, 15 Apr 2025 12:37:54 -0700
Subject: [PATCH 0831/1324] Fix integer overflow in
 TFL::FullyConnectedOp::verify() by using int64_t to store num_elements.

Consider the following tfl.fully_connected op: the numerics in the verify
function suffer from an integer overflow and produce undesired TFL Converter
errors.

```
"tfl.fully_connected"(...) <{...}> : (tensor<2048x128xf32>, tensor<1049088x128xf32>, none) -> tensor<2048x1049088xf32>
```

This produces the error `'tfl.fully_connected' op expect 'output' num_elements
% 1049088 == 0, got 'tensor<2048x1049088xf32>'`, because `num_elements` is
calculated to be 2,148,532,224, which is slightly larger than the maximum
value representable by a standard 32-bit signed integer (int32_t),
2,147,483,647. We therefore use int64_t in the code.

PiperOrigin-RevId: 747981025
---
 tensorflow/compiler/mlir/lite/ir/tfl_ops.cc  | 4 ++--
 tensorflow/compiler/mlir/lite/tests/ops.mlir | 7 +++++++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 0b8cda7329f11d..afbb885d30388c 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -1529,7 +1529,7 @@ LogicalResult FullyConnectedOp::verify() {
   // Input's element size must be multiple of parameter's z_in dimension.
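  // Worked numbers behind this fix (from the commit message above): with
  // input tensor<2048x128xf32> and filter tensor<1049088x128xf32>, the output
  // holds 2048 * 1049088 = 2,148,532,224 elements. That exceeds INT32_MAX
  // (2,147,483,647), so storing MLIR's int64_t getNumElements() result in a
  // plain `int` wraps to a negative value on typical targets, and the
  // `num_elements % z_out == 0` check fails spuriously. An int64_t holds the
  // count exactly.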
const int z_in = filter_type.getDimSize(1); - const int num_input_elements = input_type.getNumElements(); + const int64_t num_input_elements = input_type.getNumElements(); if (z_in != 0 && num_input_elements % z_in != 0) { return op.emitOpError(llvm::formatv( "expect 'input' num_elements % {0} == 0, got input type ", z_in)) @@ -1545,7 +1545,7 @@ LogicalResult FullyConnectedOp::verify() { return mlir::success(); } - const int num_output_elements = output_type.getNumElements(); + const int64_t num_output_elements = output_type.getNumElements(); const int z_out = filter_type.getDimSize(0); if (num_output_elements % z_out != 0) { return op.emitOpError(llvm::formatv( diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index c2096033859fd4..56b82b9042593f 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -2609,6 +2609,13 @@ func.func @fully_connected(%arg0: tensor<1x37xf32>, %arg1: tensor<40x37xf32>, %a // ----- +func.func @fully_connected_with_int64_num_elements(%arg0: tensor<2048x128xf32>, %arg1: tensor<1049088x128xf32>, %arg2: none) -> tensor<2048x1049088xf32> { + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<2048x128xf32>, tensor<1049088x128xf32>, none) -> tensor<2048x1049088xf32> + func.return %0 : tensor<2048x1049088xf32> +} + +// ----- + func.func @fully_connected_no_bias(%arg0: tensor<2x2x10xf32>, %arg1: tensor<40x40xf32>, %arg2: none) -> tensor<1x40xf32> { %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2x10xf32>, tensor<40x40xf32>, none) -> tensor<1x40xf32> func.return %0 : tensor<1x40xf32> From a03f7d8cd0a3c03915bb9ef245f126a000e3db5d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 12:47:33 -0700 Subject: [PATCH 0832/1324] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/6dc35e57bf5a76b034a79af695cd7752de274d68. PiperOrigin-RevId: 747984597 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 3ab4952432ed9c..398e1b398d7bb9 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "345ae8d48df934a5e440febea9a5e1ec856868b0" - TFRT_SHA256 = "07f2ad8b51175ea1521a20a8d42b136ffa89db4aa5fae0935dba8e4a11a9542a" + TFRT_COMMIT = "6dc35e57bf5a76b034a79af695cd7752de274d68" + TFRT_SHA256 = "604c2a2a9c0d24981fe0a62f1ed300f57ea4681992c7dae63ab2b49639809442" tf_http_archive( name = "tf_runtime", From fa0784cbe887c3a00e84d6f0943dd5e628fc135f Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 15 Apr 2025 13:06:32 -0700 Subject: [PATCH 0833/1324] Automated Code Change PiperOrigin-RevId: 747991934 --- tensorflow/core/framework/resource_mgr_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index 0f4ab03958a53e..ee1873c3a37a50 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -67,7 +67,7 @@ class Other : public ResourceBase { class Finalizable : public ResourceBase { public: - explicit Finalizable(absl::Nonnull finalize_count) + explicit Finalizable(int* /*absl_nonnull*/ finalize_count) : finalize_count_(*finalize_count) {} ~Finalizable() override = default; From 81c6e20274d1093079d901f4b6bd60e70eab18ad Mon Sep 17 00:00:00 2001 From: Taehee Jeong Date: Tue, 15 Apr 2025 13:12:35 -0700 Subject: [PATCH 0834/1324] Rollback for failing tests. Reverts f4eb58acd2ed1d0b099cdaf68b80bbe570e96d0b PiperOrigin-RevId: 747994305 --- .../xla/xla/backends/gpu/collectives/BUILD | 17 +---- .../gpu/collectives/gpu_collectives.h | 17 ----- .../gpu/collectives/gpu_collectives_stub.h | 4 - .../gpu/collectives/nccl_collectives.cc | 76 ------------------- .../gpu/collectives/nccl_collectives.h | 2 - .../xla/xla/backends/gpu/runtime/BUILD | 6 +- third_party/xla/xla/pjrt/gpu/BUILD | 28 ++++++- third_party/xla/xla/pjrt/gpu/nccl_id_store.cc | 69 +++++++++++++++++ third_party/xla/xla/pjrt/gpu/nccl_id_store.h | 59 ++++++++++++++ .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 26 +++---- 10 files changed, 171 insertions(+), 133 deletions(-) create mode 100644 third_party/xla/xla/pjrt/gpu/nccl_id_store.cc create mode 100644 third_party/xla/xla/pjrt/gpu/nccl_id_store.h diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index 2f8c015eea9676..3595fe88020edf 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -1,7 +1,7 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("//xla:xla.default.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") -load("//xla/tsl:tsl.bzl", "if_google", "if_nccl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") @@ -24,7 +24,8 @@ cc_library( name = "gpu_collectives_plugin", deps = [ ":gpu_collectives_stub", - ] + if_nccl([":nccl_collectives"]), + ":nccl_collectives", + ], ) cc_library( @@ -127,7 +128,6 @@ cc_library( srcs = ["gpu_collectives.cc"], hdrs = ["gpu_collectives.h"], deps = [ - "//xla:executable_run_options", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", @@ -136,12 +136,9 @@ cc_library( "//xla/core/collectives:clique_key", "//xla/core/collectives:collectives_registry", "//xla/core/collectives:communicator", - "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -192,11 +189,9 @@ cc_library( ]), visibility = ["//visibility:private"], deps = [ - ":gpu_clique_key", ":gpu_collectives", 
":nccl_communicator", ":nccl_errors", - "//xla:debug_options_flags", "//xla:status_macros", "//xla:util", "//xla/core/collectives", @@ -205,21 +200,15 @@ cc_library( "//xla/core/collectives:collectives_registry", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", - "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:global_device_id", - "//xla/service/gpu:gpu_executable_run_options", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", ] + if_cuda_is_configured([ diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h index 556699506336b0..fbd88b28063585 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h @@ -19,18 +19,13 @@ limitations under the License. #include #include #include -#include -#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/collectives.h" #include "xla/core/collectives/communicator.h" -#include "xla/executable_run_options.h" -#include "xla/pjrt/distributed/key_value_store_interface.h" -#include "xla/service/global_device_id.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -113,18 +108,6 @@ class GpuCollectives : public Collectives { virtual absl::StatusOr Allocate(uint64_t bytes) = 0; virtual absl::Status Deallocate(void* buffer) = 0; - - struct Topology { - int32_t node_id; - int32_t num_nodes; - size_t device_count_per_process; - std::shared_ptr kv_store; - absl::flat_hash_map device_id_to_node_id; - gpu::GpuExecutableRunOptions* gpu_executable_run_options; - }; - - // Initializes the topology information for the collectives backend. 
- virtual absl::Status InitializeTopology(Topology topology) = 0; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h index 11b50eed094073..f217f4ccd3621d 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h @@ -70,10 +70,6 @@ class GpuCollectivesStub : public GpuCollectives { absl::Status Deallocate(void* buffer) final { return UnimplementedError(); } - absl::Status InitializeTopology(Topology topology) final { - return UnimplementedError(); - } - protected: static absl::Status UnimplementedError() { return Unimplemented("XLA compiled without GPU collectives support"); diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc index d54c91f00b0e8d..bc84fbef3c62d0 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc @@ -20,21 +20,16 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/collectives/nccl_communicator.h" #include "xla/backends/gpu/collectives/nccl_errors.h" @@ -44,9 +39,6 @@ limitations under the License. #include "xla/core/collectives/collectives_registry.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" -#include "xla/pjrt/distributed/key_value_store_interface.h" -#include "xla/service/global_device_id.h" -#include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/status_macros.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" @@ -262,74 +254,6 @@ absl::Status NcclCollectives::Deallocate(void* location) { VLOG(2) << "Deallocated collective memory " << location; return absl::OkStatus(); } - -class NcclIdStore { - public: - NcclIdStore(int node_id, - absl::flat_hash_map device_to_node, - std::shared_ptr kv_store) - : node_id_(node_id), - device_to_node_(std::move(device_to_node)), - kv_store_(std::move(kv_store)) {} - - absl::StatusOr GetNcclUniqueId(const CliqueKey& key) { - auto* gpu_key = tsl::down_cast(&key); - if (gpu_key == nullptr) { - return InvalidArgument("Expected GPU clique key"); - } - - // The caller must ensure that threads calling this method concurrently have - // unique keys, otherwise the global key-value store may hold the wrong - // value. 
- { - absl::MutexLock lock(&mu_); - auto it = cache_.find(*gpu_key); - if (it != cache_.end()) { - return it->second; - } - } - CliqueId clique_id; - int primary_node_id = device_to_node_.at(gpu_key->root_device()); - if (node_id_ == primary_node_id) { - TF_ASSIGN_OR_RETURN( - clique_id, gpu::GpuCollectives::Default()->CreateUniqueCliqueId()); - TF_RETURN_IF_ERROR( - kv_store_->Set(gpu_key->ToString(), clique_id.ToString())); - } else { - TF_ASSIGN_OR_RETURN( - std::string id_str, - kv_store_->Get(gpu_key->ToString(), absl::Minutes(10))); - clique_id = CliqueId(id_str); - } - absl::MutexLock lock(&mu_); - auto result = cache_.emplace(*gpu_key, std::move(clique_id)); - TF_RET_CHECK(result.second) << "Unique ID already in cache."; - return result.first->second; - } - - private: - const int node_id_; - const absl::flat_hash_map device_to_node_; - const std::shared_ptr kv_store_; - - absl::Mutex mu_; - absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); -}; - -absl::Status NcclCollectives::InitializeTopology( - NcclCollectives::Topology topology) { - if (topology.num_nodes > 1) { - auto nccl_id_store = std::make_shared( - topology.node_id, topology.device_id_to_node_id, - std::move(topology.kv_store)); - topology.gpu_executable_run_options->set_clique_id_callback( - [nccl_id_store](const CliqueKey& key) { - return nccl_id_store->GetNcclUniqueId(key); - }); - } - return absl::OkStatus(); -} - } // namespace xla::gpu XLA_COLLECTIVES_REGISTER("gpu", "nccl", 1, diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h index b89a9360728f18..36860f1360f00b 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h @@ -61,8 +61,6 @@ class NcclCollectives : public GpuCollectives { absl::StatusOr Allocate(uint64_t bytes) final; absl::Status Deallocate(void* location) final; - - absl::Status InitializeTopology(Topology topology) final; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 07b0df207aa4c8..3fb164156f02c6 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -951,7 +951,11 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@llvm-project//mlir:IR", - ], + ] + if_cuda_is_configured([ + "@local_config_nccl//:nccl", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", + ]), ) cc_library( diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 2fcfb4809bd114..c91744bbf81189 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -66,11 +66,8 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_proto_cc", - "//xla/backends/gpu/collectives:gpu_collectives", "//xla/client:client_library", "//xla/client:local_client", - "//xla/core/collectives", - "//xla/core/collectives:collectives_registry", "//xla/hlo/builder:xla_computation", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:event_pool", @@ -146,6 +143,7 @@ cc_library( "@local_tsl//tsl/profiler/lib:traceme", ] + if_cuda_or_rocm([ # keep sorted + ":nccl_id_store", "//xla:debug_options_flags", "//xla/service/gpu:gpu_compiler", "//xla/service/gpu:gpu_constants", @@ -238,6 +236,30 @@ xla_test( ], ) +cc_library( + name = "nccl_id_store", + srcs = ["nccl_id_store.cc"], + hdrs = ["nccl_id_store.h"], + deps 
= [ + "//xla:status_macros", + "//xla:util", + "//xla/backends/gpu/collectives:gpu_clique_key", + "//xla/backends/gpu/collectives:gpu_collectives", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:clique_key", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:global_device_id", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:casts", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + xla_test( name = "pjrt_client_test_se_gpu", srcs = ["pjrt_client_test_se_gpu.cc"], diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc b/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc new file mode 100644 index 00000000000000..a2a72856e6d9f3 --- /dev/null +++ b/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc @@ -0,0 +1,69 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/gpu/nccl_id_store.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" +#include "xla/backends/gpu/collectives/gpu_collectives.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +absl::StatusOr NcclIdStore::GetNcclUniqueId(const CliqueKey& key) { + auto* gpu_key = tsl::down_cast(&key); + if (gpu_key == nullptr) { + return InvalidArgument("Expected GPU clique key"); + } + + // The caller must ensure that threads calling this method concurrently have + // unique keys, otherwise the global key-value store may hold the wrong value. 
+ { + absl::MutexLock lock(&mu_); + auto it = cache_.find(*gpu_key); + if (it != cache_.end()) { + return it->second; + } + } + CliqueId clique_id; + int primary_node_id = device_to_node_.at(gpu_key->root_device()); + if (node_id_ == primary_node_id) { + TF_ASSIGN_OR_RETURN(clique_id, + gpu::GpuCollectives::Default()->CreateUniqueCliqueId()); + TF_RETURN_IF_ERROR( + kv_store_->Set(gpu_key->ToString(), clique_id.ToString())); + } else { + TF_ASSIGN_OR_RETURN(std::string id_str, + kv_store_->Get(gpu_key->ToString(), absl::Minutes(10))); + clique_id = CliqueId(id_str); + } + absl::MutexLock lock(&mu_); + auto result = cache_.emplace(*gpu_key, std::move(clique_id)); + TF_RET_CHECK(result.second) << "Unique ID already in cache."; + return result.first->second; +} + +} // namespace xla diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.h b/third_party/xla/xla/pjrt/gpu/nccl_id_store.h new file mode 100644 index 00000000000000..fe8b060cb946a7 --- /dev/null +++ b/third_party/xla/xla/pjrt/gpu/nccl_id_store.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_GPU_NCCL_ID_STORE_H_ +#define XLA_PJRT_GPU_NCCL_ID_STORE_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/service/global_device_id.h" + +namespace xla { + +// A table mapping GpuCliqueKeys to CliqueIds. In a distributed setup the +// table of NCCL IDs is kept on the master node (node 0). The node of the first +// participating device will create the unique id. +class NcclIdStore { + public: + NcclIdStore(int node_id, + absl::flat_hash_map device_to_node, + std::shared_ptr kv_store) + : node_id_(node_id), + device_to_node_(std::move(device_to_node)), + kv_store_(std::move(kv_store)) {} + + absl::StatusOr GetNcclUniqueId(const CliqueKey& key); + + private: + const int node_id_; + const absl::flat_hash_map device_to_node_; + const std::shared_ptr kv_store_; + + absl::Mutex mu_; + absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // XLA_PJRT_GPU_NCCL_ID_STORE_H_ diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 5be6c178af91ae..a1f82c6e156656 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -46,10 +46,7 @@ limitations under the License. 
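// The restored NcclIdStore (see nccl_id_store.{h,cc} above) is a key-value
// rendezvous: the node that owns the clique's root device creates the NCCL
// unique id and publishes it, every other node blocks on a timed Get for the
// same key, and both sides memoize the result. Condensed from the code above
// (error handling and the cache lookup elided):
//
//   if (node_id_ == primary_node_id) {
//     TF_ASSIGN_OR_RETURN(
//         clique_id, gpu::GpuCollectives::Default()->CreateUniqueCliqueId());
//     TF_RETURN_IF_ERROR(kv_store_->Set(key, clique_id.ToString()));
//   } else {
//     TF_ASSIGN_OR_RETURN(std::string id_str,
//                         kv_store_->Get(key, absl::Minutes(10)));
//     clique_id = CliqueId(id_str);
//   }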
#include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" -#include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/client/local_client.h" -#include "xla/core/collectives/collectives.h" -#include "xla/core/collectives/collectives_registry.h" #include "xla/executable_run_options.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" @@ -109,6 +106,7 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/gpu/gpu_metrics.h" +#include "xla/pjrt/gpu/nccl_id_store.h" #include "xla/pjrt/stream_executor_executable.pb.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_constants.h" @@ -1211,20 +1209,16 @@ absl::StatusOr BuildDistributedDevices( } gpu_executable_run_options->set_gpu_global_device_ids( std::move(gpu_device_ids)); - - TF_ASSIGN_OR_RETURN(xla::Collectives * collectives, - xla::CollectivesRegistry::Default("gpu")); - xla::gpu::GpuCollectives* gpu_collectives = - tsl::down_cast(collectives); - - if (gpu_collectives == nullptr) { - return absl::InternalError("Failed to get GPU collectives"); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (num_nodes > 1) { + auto nccl_id_store = std::make_shared(node_id, device_to_node, + std::move(kv_store)); + gpu_executable_run_options->set_clique_id_callback( + [nccl_id_store](const CliqueKey& key) { + return nccl_id_store->GetNcclUniqueId(key); + }); } - - TF_RETURN_IF_ERROR(gpu_collectives->InitializeTopology( - {node_id, global_topology.nodes().size(), local_device_states.size(), - kv_store, device_to_node, gpu_executable_run_options})); - +#endif // GOOGLE_CUDA TF_ASSIGN_OR_RETURN(GpuTopologyProto gpu_topology, BuildGpuTopology(global_topology)); return std::make_pair(std::move(devices), gpu_topology); From d102c39142052dc8fc92c66a5afbb8e019717eab Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Tue, 15 Apr 2025 13:20:56 -0700 Subject: [PATCH 0835/1324] Removed spurious forward declaration of `Env`. PiperOrigin-RevId: 747997315 --- .../tsl/distributed_runtime/coordination/coordination_service.h | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h index abd6debaac5039..a103d461595f68 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h @@ -43,7 +43,6 @@ limitations under the License. #include "tsl/platform/random.h" namespace tsl { -class Env; // Coordination service is used for controlling and coordinating distributed // execution in a cluster of multiple tasks. From 76ad7aaa3674a7914189f844193c98df28b200dc Mon Sep 17 00:00:00 2001 From: Robert David Date: Tue, 15 Apr 2025 13:31:07 -0700 Subject: [PATCH 0836/1324] Use the platform-appropriate `printf` format specifier for `int32_t`. Also, don't bump the lookup index to 64 bit. PiperOrigin-RevId: 748001229 --- tensorflow/lite/kernels/embedding_lookup.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc index d92701059822f6..158bd8c63bf932 100644 --- a/tensorflow/lite/kernels/embedding_lookup.cc +++ b/tensorflow/lite/kernels/embedding_lookup.cc @@ -29,8 +29,8 @@ limitations under the License. 
// When indices are out of bound, the ops will not succeed. // -#include - +#include +#include #include #include "tensorflow/lite/c/c_api_types.h" @@ -104,17 +104,17 @@ TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node, // Propagate empty tensor if input is empty return kTfLiteOk; } - const int64_t row_bytes = value->bytes / row_size; + const size_t row_bytes = value->bytes / row_size; char* output_raw = GetTensorData(output); const char* value_raw = GetTensorData(value); const int32_t* lookup_data = GetTensorData(lookup); for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { - int64_t idx = lookup_data[i]; + const int32_t idx = lookup_data[i]; if (idx >= row_size || idx < 0) { TF_LITE_KERNEL_LOG(context, "Embedding Lookup: index out of bounds. " - "Got %d, and bounds are [0, %d]", + "Got %" PRId32 ", and bounds are [0, %d]", idx, row_size - 1); return kTfLiteError; } else { @@ -142,11 +142,11 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, const int32_t* lookup_data = GetTensorData(lookup); for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { - int idx = lookup_data[i]; + const int32_t idx = lookup_data[i]; if (idx >= row_size || idx < 0) { TF_LITE_KERNEL_LOG(context, "Embedding Lookup: index out of bounds. " - "Got %d, and bounds are [0, %d]", + "Got %" PRId32 ", and bounds are [0, %d]", idx, row_size - 1); return kTfLiteError; } else { From 37aa6c20a2fb7978de11abc2fa587f18cf43e7fd Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Tue, 15 Apr 2025 14:07:22 -0700 Subject: [PATCH 0837/1324] Removed deprecated `CoordinationService` methods and types. PiperOrigin-RevId: 748014992 --- .../coordination/coordination_service.h | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h index a103d461595f68..fc2e1e7addc70b 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h @@ -91,14 +91,6 @@ class CoordinationService { return std::make_unique(env, config, std::move(cache)); } - // TODO: b/410607726 - Remove once deprecated EnableCoordinationService is - // unused. - static std::unique_ptr EnableCoordinationService( - Env* env, const tensorflow::CoordinationServiceConfig& config, - std::unique_ptr cache) { - return Create(env, config, std::move(cache)); - } - CoordinationService(Env* env, const tensorflow::CoordinationServiceConfig& config, std::unique_ptr client_cache); @@ -653,10 +645,6 @@ class CoordinationService { void operator=(const CoordinationService&) = delete; }; -// TODO: b/410607726 - Remove once deprecated CoordinationServiceInterface is -// removed. -using CoordinationServiceInterface = CoordinationService; - } // namespace tsl #endif // XLA_TSL_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_ From 2560f65ddb2cfbfd219243b39aa7116a221fb286 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 14:41:24 -0700 Subject: [PATCH 0838/1324] Improve readability of code converting between `tpu::TpuTopology*` and `SE_TpuTopology*`. `tpu::TpuTopology` and `SE_TpuTopology` are unrelated types. Therefore, we have to use `reinterpret_cast` when we need to convert between pointers to these two types. This increases cognitive burden as it takes extra effort for the reader to convince themselves that the conversion is safe. 
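For illustration, the kind of cast in question looks roughly like this
(hypothetical call site; only the two type names are taken from this change):

```
// tpu::TpuTopology and SE_TpuTopology are unrelated types, so this cast is
// only safe by convention:
const SE_TpuTopology* c_topology =
    reinterpret_cast<const SE_TpuTopology*>(cpp_topology);
```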
To reduce the cognitive burden, we conceal the `reinterpret_cast` in a couple of conversion functions so that only readers of these conversion functions' implementation need to worry about the `reinterpret_cast` safety. Also remove some uses of `const_cast` by making the API const-correct. PiperOrigin-RevId: 748028526 --- .../stream_executor/tpu/tpu_executor_c_api.h | 34 ++++++++++--------- .../xla/stream_executor/tpu/tpu_ops_c_api.h | 2 +- .../xla/stream_executor/tpu/tpu_platform.cc | 2 +- .../xla/stream_executor/tpu/tpu_platform.h | 2 +- .../tpu/tpu_platform_interface.h | 6 +--- .../xla/stream_executor/tpu/tpu_topology.h | 4 +-- 6 files changed, 24 insertions(+), 26 deletions(-) diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h index 3b13c28cf109f1..7888144845fd90 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h @@ -33,7 +33,7 @@ SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform, int ordinal, SE_PlatformId TpuPlatform_Id(SE_Platform* platform); int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform); bool TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(SE_Platform* platform); -SE_TpuTopology* TpuPlatform_GetTopologyPtr(SE_Platform* platform); +const SE_TpuTopology* TpuPlatform_GetTopologyPtr(SE_Platform* platform); SE_TpuTopology_Host* TpuPlatform_GetHostLocation(SE_Platform* platform); TpuRuntimeVersion TpuPlatform_GetRuntimeVersion(SE_Platform* platform); @@ -213,31 +213,33 @@ void TpuComputationPlacer_AssignLocalDevices(SE_TpuTopology_Host* host, int* assignment, TF_Status* status); -int TpuTopology_LogicalDevicesPerHost(SE_TpuTopology* tpu_topology, +int TpuTopology_LogicalDevicesPerHost(const SE_TpuTopology* tpu_topology, TpuCoreTypeEnum tpu_core_type); -int TpuTopology_LogicalDevicesPerChip(SE_TpuTopology* tpu_topology, +int TpuTopology_LogicalDevicesPerChip(const SE_TpuTopology* tpu_topology, TpuCoreTypeEnum tpu_core_type); -int TpuTopology_HostCount(SE_TpuTopology* tpu_topology); -int TpuTopology_ChipsPerHost(SE_TpuTopology* tpu_topology); - -int TpuTopology_ChipBounds_X(SE_TpuTopology* tpu_topology); -int TpuTopology_ChipBounds_Y(SE_TpuTopology* tpu_topology); -int TpuTopology_ChipBounds_Z(SE_TpuTopology* tpu_topology); -bool TpuTopology_HasChip(SE_TpuTopology* tpu_topology, int x, int y, int z); -SE_TpuTopology_Core* TpuTopology_CoreForId(SE_TpuTopology* tpu_topology, +int TpuTopology_HostCount(const SE_TpuTopology* tpu_topology); +int TpuTopology_ChipsPerHost(const SE_TpuTopology* tpu_topology); + +int TpuTopology_ChipBounds_X(const SE_TpuTopology* tpu_topology); +int TpuTopology_ChipBounds_Y(const SE_TpuTopology* tpu_topology); +int TpuTopology_ChipBounds_Z(const SE_TpuTopology* tpu_topology); +bool TpuTopology_HasChip(const SE_TpuTopology* tpu_topology, int x, int y, + int z); +SE_TpuTopology_Core* TpuTopology_CoreForId(const SE_TpuTopology* tpu_topology, TpuCoreTypeEnum tpu_core_type, int id); -SE_TpuTopology_Core* TpuTopology_Core(SE_TpuTopology* tpu_topology, +SE_TpuTopology_Core* TpuTopology_Core(const SE_TpuTopology* tpu_topology, TpuCoreTypeEnum tpu_core_type, int x, int y, int z, int index); -int TpuTopology_NumCores(SE_TpuTopology* tpu_topology, +int TpuTopology_NumCores(const SE_TpuTopology* tpu_topology, TpuCoreTypeEnum tpu_core_type); // 'cores' should be a preallocated array of size TpuTopology_NumCores. 
-void TpuTopology_Cores(SE_TpuTopology* tpu_topology, +void TpuTopology_Cores(const SE_TpuTopology* tpu_topology, TpuCoreTypeEnum tpu_core_type, SE_TpuTopology_Core** cores); -int TpuTopology_IdForHost(SE_TpuTopology* tpu_topology, int x, int y, int z); -TpuVersionEnum TpuTopology_Version(SE_TpuTopology* tpu_topology); +int TpuTopology_IdForHost(const SE_TpuTopology* tpu_topology, int x, int y, + int z); +TpuVersionEnum TpuTopology_Version(const SE_TpuTopology* tpu_topology); void TpuCoreLocation_ChipCoordinates(SE_TpuTopology_Core* tpu_core_location, int* x, int* y, int* z); void TpuCoreLocation_HostCoordinates(SE_TpuTopology_Core* tpu_core_location, diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h b/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h index fbc8fcec6962d9..f44babaefbbc8a 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h @@ -463,7 +463,7 @@ TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint( uint64_t fingerprint, const char* data, size_t size); // Returns a pointer to the TPU topology struct. -TFTPU_CAPI_EXPORT SE_TpuTopology* TpuUtil_GetTopologyPtr(); +TFTPU_CAPI_EXPORT const SE_TpuTopology* TpuUtil_GetTopologyPtr(); // Returns XLA pad size from TPU topology. TFTPU_CAPI_EXPORT size_t TpuUtil_GetXlaPadSizeFromTpuTopology(); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc b/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc index 8a3c5997e76f2e..e141fa6f5a2cb9 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc @@ -105,7 +105,7 @@ bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() { ->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_); } -const tensorflow::tpu::TpuTopologyPtr TpuPlatform::GetTopologyPtr() { +const SE_TpuTopology* TpuPlatform::GetTopologyPtr() { return stream_executor::tpu::ExecutorApiFn()->TpuPlatform_GetTopologyPtrFn( platform_); } diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h index 65cdd0e294acbc..aabf10420f7dc6 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h @@ -62,7 +62,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface { bool ShouldRegisterTpuDeviceToDeviceCopy() override; - const tensorflow::tpu::TpuTopologyPtr GetTopologyPtr() override; + const SE_TpuTopology* GetTopologyPtr() override; const tensorflow::tpu::TpuHostLocationExternal GetTpuHostLocation() const override; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.h b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.h index 5f320adcd48228..f5e7b68cab5a4d 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.h @@ -28,10 +28,6 @@ limitations under the License. namespace tensorflow { namespace tpu { -// TODO(skyewm): get rid of TpuTopologyPtr and either use SE_TpuTopology* or -// return a TpuTopologyExternal. -typedef SE_TpuTopology* TpuTopologyPtr; - class TpuPlatformInterface : public stream_executor::Platform { public: // Returns a TPU platform to be used by TPU ops. 
If multiple TPU platforms are @@ -53,7 +49,7 @@ class TpuPlatformInterface : public stream_executor::Platform { virtual bool ShouldRegisterTpuDeviceToDeviceCopy() = 0; - virtual const TpuTopologyPtr GetTopologyPtr() = 0; + virtual const SE_TpuTopology* GetTopologyPtr() = 0; virtual const TpuHostLocationExternal GetTpuHostLocation() const = 0; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_topology.h b/third_party/xla/xla/stream_executor/tpu/tpu_topology.h index 8c5f6ae19285a0..f1ab8da3fe6b9e 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_topology.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_topology.h @@ -69,7 +69,7 @@ struct TpuTopologyChipBoundsExternal { class TpuTopologyExternal { public: - explicit TpuTopologyExternal(SE_TpuTopology* topology) + explicit TpuTopologyExternal(const SE_TpuTopology* topology) : topology_(topology) {} int32_t LogicalDevicesPerHost(TpuCoreTypeEnum core_type) const; int32_t LogicalDevicesPerChip(TpuCoreTypeEnum core_type) const; @@ -85,7 +85,7 @@ class TpuTopologyExternal { TpuVersionEnum version() const; private: - SE_TpuTopology* topology_; + const SE_TpuTopology* topology_; }; std::string TpuVersionEnumToString(TpuVersionEnum version); From 10457e59ce41cd1724e750df4bb835e38114f19f Mon Sep 17 00:00:00 2001 From: Mani Ananth Date: Tue, 15 Apr 2025 15:13:40 -0700 Subject: [PATCH 0839/1324] Implementation of a GEMM (without fusions) cost model that accounts for compute, memory (HBM) and L2 overheads. Subsequent work will refine heuristics and add support for cost modeling GEMM fusions. At the moment, the implementation is tuned for NVIDIA H100 GPUs. PiperOrigin-RevId: 748041075 --- third_party/xla/xla/service/gpu/model/BUILD | 45 ++ .../gpu/model/gpu_dot_fusion_cost_model.cc | 397 ++++++++++++++++++ .../gpu/model/gpu_dot_fusion_cost_model.h | 92 ++++ .../model/gpu_dot_fusion_cost_model_test.cc | 73 ++++ .../gpu/model/gpu_performance_model_base.cc | 16 +- .../gpu/model/gpu_performance_model_base.h | 4 + 6 files changed, 623 insertions(+), 4 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model.cc create mode 100644 third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model.h create mode 100644 third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model_test.cc diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index d1b6c583e76627..0ed134483688b6 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -358,6 +358,51 @@ xla_cc_test( ], ) +cc_library( + name = "gpu_dot_fusion_cost_model", + srcs = ["gpu_dot_fusion_cost_model.cc"], + hdrs = ["gpu_dot_fusion_cost_model.h"], + deps = [ + ":gpu_performance_model_base", + ":tiled_hlo_instruction_or_computation", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", + "//xla/stream_executor/cuda:cuda_compute_capability", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "gpu_dot_fusion_cost_model_test", + srcs = ["gpu_dot_fusion_cost_model_test.cc"], + deps = [ + ":gpu_dot_fusion_cost_model", + 
":tiled_hlo_instruction_or_computation", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test_helpers", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "gpu_collective_performance_model", srcs = ["gpu_collective_performance_model.cc"], diff --git a/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model.cc b/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model.cc new file mode 100644 index 00000000000000..06bf8c283fc0e1 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model.cc @@ -0,0 +1,397 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_dot_fusion_cost_model.h" + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/shape.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +using primitive_util::BitWidth; + +namespace { + +bool TileFitsInRegisters(int64_t block_m, int64_t block_n, + const PrimitiveType& element_type, + const se::DeviceDescription& device_info) { + int bits_per_output_elem = BitWidth(element_type); + int registers_per_block = device_info.registers_per_block_limit(); + int64_t block_size = block_m * block_n; + int64_t bytes_per_block = + CeilOfRatio(block_size * bits_per_output_elem, 8); + constexpr double kFractionOfRegistersAvailableForAccumulators = 0.8; + return bytes_per_block <= + (registers_per_block * kFractionOfRegistersAvailableForAccumulators); +} + +absl::StatusOr> +GetDotAlgorithmValidConfigs(const HloDotInstruction* dot, + const se::DeviceDescription& device_info) { + absl::InlinedVector valid_configs; + + for (int64_t block_m = detail::kMinBlockDim; block_m <= detail::kMaxBlockDim; + block_m *= 2) { + for (int64_t block_n = detail::kMinBlockDim; + block_n <= detail::kMaxBlockDim; block_n *= 2) { + if (!TileFitsInRegisters(block_m, block_n, dot->shape().element_type(), + device_info)) { + continue; + } + + // TODO(maniananth): Add the logic to find valid kBlock stages. 
+      BlockLevelParameters block_level_parameters;
+      block_level_parameters.output_tile_sizes.push_back(
+          std::vector<int64_t>{block_m, block_n});
+      // TODO(maniananth): Add the logic to sweep num warps per block.
+      block_level_parameters.num_warps = detail::kNumWarpsPerBlock;
+      valid_configs.push_back(block_level_parameters);
+    }
+  }
+
+  return valid_configs;
+}
+
+int64_t CalculateNumThreadblocks(const HloDotInstruction* dot, int64_t tile_m,
+                                 int64_t tile_n) {
+  GpuDotFusionCostModel::DotProblemDimensions dims(*dot);
+  int64_t tile_k = dims.k;
+  // TODO(maniananth): Add special handling for grouped matmuls here.
+  int64_t num_tiles_along_m_dimension = CeilOfRatio(dims.m, tile_m);
+  int64_t num_tiles_along_n_dimension = CeilOfRatio(dims.n, tile_n);
+  int64_t num_tiles_along_k_dimension = CeilOfRatio(dims.k, tile_k);
+  int64_t num_threadblocks = dims.b * num_tiles_along_m_dimension *
+                             num_tiles_along_n_dimension *
+                             num_tiles_along_k_dimension;
+
+  return num_threadblocks;
+}
+
+int64_t CalculateNumWaves(int64_t threadblock_count,
+                          const se::DeviceDescription& device_info) {
+  int64_t core_count = device_info.core_count();
+  return CeilOfRatio(threadblock_count, core_count);
+}
+
+int64_t CalculateTileFlops(int64_t tile_m, int64_t tile_n, int64_t problem_k) {
+  return /*flops per MAC*/ 2 * tile_m * tile_n * problem_k;
+}
+
+// Calculates the effective flops for a GPU DOT operation as a function of the
+// tile size (excludes clock throttling). Not all tile sizes are equally able
+// to extract utilization on the same generation GPUs even if the workload is
+// compute bound. GEMM performance is sensitive to the tensor core
+// instruction throughputs that the programming model exposes.
+double GetEffectiveFlopsPerNsForTileSize(
+    const int64_t tile_m, const se::DeviceDescription& device_info) {
+  se::CudaComputeCapability cuda_compute_capability =
+      device_info.cuda_compute_capability();
+
+  // Peak flops per ns for device.
+  int64_t peak_flops_per_ns =
+      GpuPerformanceModelBase::CalculateEffectiveFlopsPerNs(
+          device_info, device_info.fpus_per_core(), device_info.core_count());
+
+  // Final flops derate factor.
+  double flops_derate = 1.0;
+
+  if (cuda_compute_capability.IsBlackwell()) {
+    if (tile_m < 128) {
+      // TODO(maniananth): Update this derate once we have more data from
+      // actual measurements on Blackwell. For now, we are applying a 50%
+      // derate to account for smaller M shapes.
+      flops_derate = 0.5;
+    }
+  } else if (cuda_compute_capability.IsHopper()) {
+    if (tile_m < 64) {
+      // Having a tile size M < 64 will lead to not being able to use the H100
+      // tensor core instructions (wgmma). Defaulting to wmma instructions from
+      // A100 can result in a 63% derate in flops as benchmarked by
+      // HazyResearch as part of ThunderKittens work.
+      // (https://hazyresearch.stanford.edu/blog/2024-05-12-tk)
+      flops_derate = 0.63;
+    }
+  } else if (cuda_compute_capability.IsAmpere()) {
+    if (tile_m < 16) {
+      // A100 tensor core instructions are effective at tile_m >= 16. We're
+      // applying a 50% derate to account for this.
+      flops_derate = 0.5;
+    }
+  }
+  return peak_flops_per_ns * flops_derate;
+}
+
+int64_t CalculateL2Bytes(absl::Span<const int64_t> tile_shape,
+                         int64_t problem_k, int64_t threadblock_count) {
+  // When tiling the GEMM problem on the outputs and mapping one tile per SM,
+  // the problem of data replication (or extra loads of the same data) between
+  // multiple SMs occurs. This leads to more data loads than what’s expected
+  // algorithmically, and increases bandwidth needs on the L2 → SM paths.
+
+  // Input data loaded by each tile is equal to (Tile_M + Tile_N) * Tile_K
+  // bytes.
+  int64_t l2_data_per_tile = (tile_shape[0] + tile_shape[1]) * problem_k;
+
+  // Across all the tiles, data loads will be equal to: (l2_data_per_tile *
+  // threadblock_count).
+
+  // TODO(maniananth): Since H100, threadblocks within the same cluster will
+  // avoid redundant loads by reading from L2 cache once and multicasting the
+  // data to all threadblocks within the cluster. This is controlled
+  // programmatically and most performant GEMM implementations will use this
+  // feature. To model this, we scale the total data loads by the total number
+  // of threadblocks in a cluster.
+
+  // On A100 and older GPUs, we will not see this behavior and the total data
+  // loads will be equal to (l2_data_per_tile * threadblock_count). Hence the
+  // cluster shape can be set to (1x1).
+  // TODO(maniananth): Account for Threadblock clusters here.
+  int64_t total_l2_data = l2_data_per_tile * threadblock_count;
+  return total_l2_data;
+}
+
+}  // namespace
+
+namespace detail {
+
+absl::StatusOr<absl::Duration> CalculateComputeTimeWithTileAndWaveQuantization(
+    const HloDotInstruction* dot, absl::Span<const int64_t> tile_shape,
+    const se::DeviceDescription& device_info) {
+  if (tile_shape.size() != 2) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("Tile shape must be of size 2, got ", tile_shape.size()));
+  }
+
+  GpuDotFusionCostModel::DotProblemDimensions dims(*dot);
+  int64_t tile_m = tile_shape[0], tile_n = tile_shape[1];
+  int64_t threadblock_count = CalculateNumThreadblocks(dot, tile_m, tile_n);
+  int64_t wave_count = CalculateNumWaves(threadblock_count, device_info);
+  int64_t flops_per_tile = CalculateTileFlops(tile_m, tile_n, dims.k);
+  // The following is not the actual number of threadblocks launched, but due
+  // to how wave quantization works, we get the effect of running extra
+  // threadblocks when adding to roofline projections.
+  int64_t cta_count_with_wave_quant = wave_count * device_info.core_count();
+  int64_t total_flops_with_wave_quant =
+      flops_per_tile * cta_count_with_wave_quant;
+  double effective_flops =
+      GetEffectiveFlopsPerNsForTileSize(tile_m, device_info);
+  // TODO(maniananth): Add a cap for power throttling here.
+  return absl::Nanoseconds(1.0f * total_flops_with_wave_quant /
+                           effective_flops);
+}
+
+absl::StatusOr<absl::Duration> CalculateL2Time(
+    const HloDotInstruction* dot, absl::Span<const int64_t> tile_shape,
+    const se::DeviceDescription& device_info) {
+  if (tile_shape.size() != 2) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("Tile shape must be of size 2, got ", tile_shape.size()));
+  }
+  // TODO(maniananth): L2 bandwidth has been hardcoded for H100 based on
+  // microbenchmarking L2 bandwidth within a partition, but we should add this
+  // to the device info and extend for more GPUs.
+  // TODO(maniananth): Enforcing this check will cause unit tests written for
+  // RTX A6000 device descriptions to fail. We should enable this check once
+  // we have the L2 bandwidth for RTX A6000 or move unit tests to use H100
+  // device description.
+  // if (device_info.cuda_compute_capability() !=
+  //     se::CudaComputeCapability(9, 0)) {
+  //   return absl::InvalidArgumentError(
+  //       "L2 time calculation is only supported for H100 GPUs.");
+  // }
+
+  GpuDotFusionCostModel::DotProblemDimensions dims(*dot);
+  int64_t tile_m = tile_shape[0], tile_n = tile_shape[1];
+  int64_t threadblock_count = CalculateNumThreadblocks(dot, tile_m, tile_n);
+  double device_l2_bandwidth = 6.65 * 1e12;  // Measured H100 L2 bandwidth.
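+  // A hedged worked example (hypothetical shapes, mirroring the arithmetic in
+  // CalculateL2Bytes above): an 8192x8192x8192 problem with 128x128 tiles
+  // launches (8192/128)^2 = 4096 threadblocks, each loading
+  // (128 + 128) * 8192 = 2097152 units, for ~8.6e9 in total; at the
+  // 6.65 TB/s figure above that is roughly 1.3 ms of L2 time.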
+
+  return absl::Seconds(1.0f *
+                       CalculateL2Bytes(tile_shape, dims.k,
+                                        threadblock_count) /
+                       device_l2_bandwidth);
+}
+
+absl::Duration CalculateHbmTime(const HloDotInstruction* dot,
+                                const se::DeviceDescription& device_info) {
+  // TODO(maniananth): Implement HBM derate lookup using profiled tables.
+  float hbm_bandwidth_utilization_rate = 0.8;
+  float dram_bandwidth =
+      device_info.memory_bandwidth() * hbm_bandwidth_utilization_rate;
+
+  GpuDotFusionCostModel::DotProblemDimensions dims(*dot);
+  PrimitiveType lhs_element_type = dot->operand(0)->shape().element_type();
+  PrimitiveType rhs_element_type = dot->operand(1)->shape().element_type();
+  PrimitiveType output_element_type = dot->shape().element_type();
+
+  // Calculate the number of bytes for input reads and output writes to HBM.
+  int64_t lhs_tile_bytes = CeilOfRatio<int64_t>(
+      dims.b * dims.m * dims.k * BitWidth(lhs_element_type), 8);
+  int64_t rhs_tile_bytes = CeilOfRatio<int64_t>(
+      dims.b * dims.k * dims.n * BitWidth(rhs_element_type), 8);
+  int64_t output_tile_bytes = CeilOfRatio<int64_t>(
+      dims.b * dims.m * dims.n * BitWidth(output_element_type), 8);
+
+  // Main loop loads the input matrices from HBM using SW pipelining and
+  // updates accumulators stored in register files (within the SM/compute
+  // unit). The epilogue loop writes the output matrices from register files
+  // to HBM. Main loop and epilogue loop are executed sequentially.
+  int64_t main_loop_bytes = lhs_tile_bytes + rhs_tile_bytes;
+  int64_t epilogue_bytes = output_tile_bytes;
+
+  // Calculate the HBM time using the effective bandwidth for each transfer
+  // size. In the current implementation, we are assuming that the main loop
+  // and epilogue loop have the same effective DRAM bandwidth. This could
+  // change in the future, if we choose to model it based on their respective
+  // transfer sizes.
+  absl::Duration hbm_time = absl::Seconds(
+      1.0f * (main_loop_bytes + epilogue_bytes) / dram_bandwidth);
+
+  return hbm_time;
+}
+
+}  // namespace detail
+
+namespace GpuDotFusionCostModel {
+
+absl::Status IsSupported(const HloDotInstruction* dot) {
+  const Shape& lhs_shape = dot->operand(0)->shape();
+  const Shape& rhs_shape = dot->operand(1)->shape();
+  const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers();
+
+  DimensionVector lhs_non_contracting_dims = GetNonContractingDims(
+      lhs_shape.dimensions().size(), dim_numbers.lhs_batch_dimensions(),
+      dim_numbers.lhs_contracting_dimensions());
+  DimensionVector rhs_non_contracting_dims = GetNonContractingDims(
+      rhs_shape.dimensions().size(), dim_numbers.rhs_batch_dimensions(),
+      dim_numbers.rhs_contracting_dimensions());
+
+  if (lhs_non_contracting_dims.size() > 1 ||
+      rhs_non_contracting_dims.size() > 1) {
+    return absl::UnimplementedError(absl::StrCat(
+        "Multiple non-contracting dimensions are not supported, got LHS: [",
+        absl::StrJoin(lhs_non_contracting_dims, ","), "], RHS: [",
+        absl::StrJoin(rhs_non_contracting_dims, ","), "]"));
+  }
+  // Only checking one side of batch and contracting dimensions, since they
+  // must be the same for left and right.
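+  // As a hedged illustration (hypothetical shapes): a plain f32[4096,1024] x
+  // f32[1024,512] dot with lhs_contracting_dims={1} and
+  // rhs_contracting_dims={0} passes every check below, while a dot that
+  // contracts the LHS over dimension 0 (e.g. via a transposed operand) is
+  // rejected by the last check.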
+  if (dim_numbers.lhs_batch_dimensions_size() > 1) {
+    return absl::UnimplementedError(
+        absl::StrCat("Batch dimension > 1 is not supported, got ",
+                     absl::StrJoin(dim_numbers.lhs_batch_dimensions(), ",")));
+  }
+  if (dim_numbers.lhs_contracting_dimensions_size() != 1) {
+    return absl::UnimplementedError(absl::StrCat(
+        "Exactly one contracting dimension is supported, got ",
+        absl::StrJoin(dim_numbers.lhs_contracting_dimensions(), ",")));
+  }
+  if (dim_numbers.lhs_contracting_dimensions(0) != 1 ||
+      dim_numbers.rhs_contracting_dimensions(0) != 0) {
+    return absl::UnimplementedError(absl::StrCat(
+        "Only lhs_contracting_dimensions=1 (got ",
+        absl::StrJoin(dim_numbers.lhs_contracting_dimensions(), ","),
+        ") and rhs_contracting_dimensions=0 (got ",
+        absl::StrJoin(dim_numbers.rhs_contracting_dimensions(), ","),
+        ") are supported."));
+  }
+
+  return absl::OkStatus();
+}
+
+DotProblemDimensions::DotProblemDimensions(const HloDotInstruction& dot) {
+  const Shape& lhs_shape = dot.operand(0)->shape();
+  const Shape& rhs_shape = dot.operand(1)->shape();
+  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
+
+  DimensionVector lhs_non_contracting_dims = GetNonContractingDims(
+      lhs_shape.dimensions().size(), dim_numbers.lhs_contracting_dimensions(),
+      dim_numbers.lhs_batch_dimensions());
+  DimensionVector rhs_non_contracting_dims = GetNonContractingDims(
+      rhs_shape.dimensions().size(), dim_numbers.rhs_contracting_dimensions(),
+      dim_numbers.rhs_batch_dimensions());
+
+  b = dim_numbers.lhs_batch_dimensions_size() > 0
+          ? dim_numbers.lhs_batch_dimensions(0)
+          : 1;
+  m = lhs_shape.dimensions(lhs_non_contracting_dims[0]);
+  n = rhs_shape.dimensions(rhs_non_contracting_dims[0]);
+  k = lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions()[0]);
+}
+
+absl::StatusOr<absl::Duration> EstimateRunTimeForDotOpWithBlockParameters(
+    const HloDotInstruction* dot, const BlockLevelParameters& block_params,
+    const se::DeviceDescription& device_info) {
+  TF_RETURN_IF_ERROR(IsSupported(dot));
+  if (block_params.output_tile_sizes.size() != 1) {
+    return absl::UnimplementedError(
+        absl::StrCat("Only single tile size is supported, got ",
+                     block_params.output_tile_sizes.size()));
+  }
+
+  // Calculate compute roofline with tile and wave quantization.
+  TF_ASSIGN_OR_RETURN(
+      absl::Duration compute_time,
+      detail::CalculateComputeTimeWithTileAndWaveQuantization(
+          dot, block_params.output_tile_sizes[0], device_info));
+  // Calculate HBM roofline.
+  absl::Duration hbm_time = detail::CalculateHbmTime(dot, device_info);
+  // Calculate L2 time.
+  TF_ASSIGN_OR_RETURN(
+      absl::Duration l2_time,
+      detail::CalculateL2Time(dot, block_params.output_tile_sizes[0],
+                              device_info));
+
+  // Assuming perfect overlap between compute and memory.
+  return std::max({compute_time, hbm_time, l2_time});
+}
+
+absl::StatusOr<absl::Duration> EstimateRunTimeForDotOp(
+    const HloDotInstruction* dot, const se::DeviceDescription& device_info) {
+  TF_RETURN_IF_ERROR(IsSupported(dot));
+
+  // TODO(maniananth): Implement this.
+  return absl::UnimplementedError("Not implemented yet");
+}
+
+absl::StatusOr<BlockLevelParameters> FindBestBlockLevelParameters(
+    const HloDotInstruction* dot, const se::DeviceDescription& device_info) {
+  TF_RETURN_IF_ERROR(IsSupported(dot));
+
+  // TODO(maniananth): Implement this.
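+  // A plausible implementation sketch (not part of this change): enumerate
+  // GetDotAlgorithmValidConfigs(dot, device_info) and return the candidate
+  // minimizing EstimateRunTimeForDotOpWithBlockParameters(dot, config,
+  // device_info).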
+  return absl::UnimplementedError("Not implemented yet");
+}
+
+}  // namespace GpuDotFusionCostModel
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model.h b/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model.h
new file mode 100644
index 00000000000000..be7893f3de1c8d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model.h
@@ -0,0 +1,92 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_MODEL_GPU_DOT_FUSION_COST_MODEL_H_
+#define XLA_SERVICE_GPU_MODEL_GPU_DOT_FUSION_COST_MODEL_H_
+
+#include <cstdint>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/model/tiled_hlo_computation.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+namespace GpuDotFusionCostModel {
+
+struct DotProblemDimensions {
+  int64_t b;
+  int64_t m;
+  int64_t n;
+  int64_t k;
+
+  explicit DotProblemDimensions(const HloDotInstruction& dot);
+};
+
+// Returns OkStatus if the dot operation is supported by the cost model.
+absl::Status IsSupported(const HloDotInstruction* dot);
+
+// Estimates the run time for a GPU DOT operation with the given set of block
+// parameters.
+absl::StatusOr<absl::Duration> EstimateRunTimeForDotOpWithBlockParameters(
+    const HloDotInstruction* dot, const BlockLevelParameters& block_params,
+    const se::DeviceDescription& device_info);
+
+// Estimates the run time for a GPU DOT operation.
+absl::StatusOr<absl::Duration> EstimateRunTimeForDotOp(
+    const HloDotInstruction* dot, const se::DeviceDescription& device_info);
+
+absl::StatusOr<BlockLevelParameters> FindBestBlockLevelParameters(
+    const HloDotInstruction* dot, const se::DeviceDescription& device_info);
+}  // namespace GpuDotFusionCostModel
+
+namespace detail {
+
+// Calculates the HBM time for a GPU DOT operation. The current implementation
+// uses a flat derate on top of the spec bandwidth. An HBM bandwidth derate
+// lookup based on profiled data will be added in the future.
+absl::Duration CalculateHbmTime(const HloDotInstruction* dot,
+                                const se::DeviceDescription& device_info);
+
+// Calculates the L2 time for a GPU DOT operation.
+absl::StatusOr<absl::Duration> CalculateL2Time(
+    const HloDotInstruction* dot, absl::Span<const int64_t> tile_shape,
+    const se::DeviceDescription& device_info);
+
+// Calculates the compute time for a GPU DOT operation with tile and wave
+// quantization effects taken into account.
+// (1) Tile quantization effects occur when the input problem dimensions are
+// quantized to the tile shape.
+// (2) Wave quantization effects occur when the number of threadblocks is
+// quantized to the number of SMs per GPU.
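+// As a hedged illustration (hypothetical numbers): 4096 threadblocks on a
+// 132-SM GPU round up to ceil(4096 / 132) = 32 waves, so the model costs the
+// launch as 32 * 132 = 4224 threadblocks even though only 4096 are launched.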
+absl::StatusOr<absl::Duration> CalculateComputeTimeWithTileAndWaveQuantization(
+    const HloDotInstruction* dot, absl::Span<const int64_t> tile_shape,
+    const se::DeviceDescription& device_info);
+
+const int kMinBlockDim = 32;
+const int kMaxBlockDim = 256;
+const int kMaxSplitK = 128;
+const int kNumWarpsPerBlock = 4;
+}  // namespace detail
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_MODEL_GPU_DOT_FUSION_COST_MODEL_H_
diff --git a/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model_test.cc
new file mode 100644
index 00000000000000..4e434abd360a43
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model_test.cc
@@ -0,0 +1,73 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/model/gpu_dot_fusion_cost_model.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/time/time.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
+#include "xla/hlo/testlib/test_helpers.h"
+#include "xla/hlo/testlib/verified_hlo_module.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/model/tiled_hlo_computation.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class GpuDotFusionCostModelTest : public HloHardwareIndependentTestBase {
+ protected:
+  se::DeviceDescription device_description_{
+      TestGpuDeviceInfo::RTXA6000DeviceInfo()};
+};
+
+TEST_F(GpuDotFusionCostModelTest, GpuDotComputeBoundBf16) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+p0 = bf16[8192,8192] parameter(0)
+p1 = bf16[8192,8192] parameter(1)
+ROOT r = bf16[8192,8192] dot(p0, p1),
+lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_bf16
+})"));
+
+  BlockLevelParameters block_params;
+  block_params.output_tile_sizes = {{64, 64}};
+  block_params.num_warps = 4;
+  block_params.num_ctas = 1;
+  block_params.num_stages = 1;
+  auto* dot = Cast<HloDotInstruction>(
+      module->entry_computation()->root_instruction());
+  ASSERT_IS_OK(GpuDotFusionCostModel::IsSupported(dot));
+  absl::Duration runtime =
+      GpuDotFusionCostModel::EstimateRunTimeForDotOpWithBlockParameters(
+          dot, block_params, device_description_)
+          .value();
+  absl::Duration expected_runtime_compute_bound =
+      detail::CalculateComputeTimeWithTileAndWaveQuantization(
+          dot, block_params.output_tile_sizes[0], device_description_)
+          .value();
+  ASSERT_EQ(runtime, expected_runtime_compute_bound);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc
index 87a9554112b464..d769f2b0f02926 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc
@@ -334,9 +334,9 @@ absl::Duration GpuPerformanceModelBase::WriteTime(
 }
 
 /*static*/
-absl::Duration GpuPerformanceModelBase::ComputeTime(
-    const se::DeviceDescription& gpu_device_info, int64_t flops,
-    int64_t num_blocks, int64_t num_threads_per_block) {
+int64_t GpuPerformanceModelBase::CalculateEffectiveFlopsPerNs(
+    const se::DeviceDescription& gpu_device_info, int64_t num_blocks,
+    int64_t num_threads_per_block) {
   int64_t n_active_fpus_per_core =
       std::min<int64_t>(num_threads_per_block, gpu_device_info.fpus_per_core());
 
@@ -345,7 +345,15 @@ absl::Duration GpuPerformanceModelBase::ComputeTime(
   int64_t fpu_count = n_active_core * n_active_fpus_per_core;
   int64_t flop_per_ns_per_fpu = gpu_device_info.clock_rate_ghz() * /*fma:*/ 2;
 
-  int64_t flop_per_ns_effective = flop_per_ns_per_fpu * fpu_count;
+  return flop_per_ns_per_fpu * fpu_count;
+}
+
+/*static*/
+absl::Duration GpuPerformanceModelBase::ComputeTime(
+    const se::DeviceDescription& gpu_device_info, int64_t flops,
+    int64_t num_blocks, int64_t num_threads_per_block) {
+  int64_t flop_per_ns_effective = CalculateEffectiveFlopsPerNs(
+      gpu_device_info, num_blocks, num_threads_per_block);
   return absl::Nanoseconds(1.0f * flops / flop_per_ns_effective);
 }
 
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h
index 64963c286daf7b..a4f5c0f7e877f0 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h
@@ -219,6 +219,10 @@ class GpuPerformanceModelBase {
   static absl::Duration WriteTime(const se::DeviceDescription& gpu_device_info,
                                   int64_t bytes_written);
 
+  static int64_t CalculateEffectiveFlopsPerNs(
+      const se::DeviceDescription& gpu_device_info, int64_t num_blocks,
+      int64_t num_threads_per_block);
+
   static absl::Duration ComputeTime(
       const se::DeviceDescription& gpu_device_info, int64_t flops,
       int64_t num_blocks, int64_t num_threads_per_block);

From fd40ae35050610ab993d44071f8628793b3df3d4 Mon Sep 17 00:00:00 2001
From: Matt Hurd
Date: Tue, 15 Apr 2025 16:33:50 -0700
Subject: [PATCH 0840/1324] Move xprof code from tensorflow/core/profiler/convert
 to xprof.

This constitutes the last major chunk of code being moved over.
PiperOrigin-RevId: 748067374 --- tensorflow/core/lib/gtl/BUILD | 2 +- tensorflow/core/profiler/convert/BUILD | 1546 +--------------- .../convert/compute_inference_latency.cc | 144 -- .../convert/compute_inference_latency.h | 33 - .../convert/compute_inference_latency_test.cc | 72 - .../core/profiler/convert/data_table_utils.h | 155 -- .../profiler/convert/data_table_utils_test.cc | 100 - .../core/profiler/convert/dcn_analysis.cc | 471 ----- .../core/profiler/convert/dcn_analysis.h | 227 --- .../profiler/convert/dcn_analysis_test.cc | 363 ---- .../convert/dcn_slack_analysis_combiner.cc | 92 - .../convert/dcn_slack_analysis_combiner.h | 47 - tensorflow/core/profiler/convert/dcn_utils.cc | 121 -- tensorflow/core/profiler/convert/dcn_utils.h | 76 - .../core/profiler/convert/dcn_utils_test.cc | 141 -- .../profiler/convert/duty_cycle_combiner.h | 72 - .../convert/duty_cycle_combiner_test.cc | 82 - .../profiler/convert/duty_cycle_tracker.cc | 103 -- .../profiler/convert/duty_cycle_tracker.h | 71 - .../convert/duty_cycle_tracker_test.cc | 148 -- .../convert/hlo_proto_to_graph_view.cc | 554 ------ .../convert/hlo_proto_to_graph_view.h | 102 - .../convert/hlo_proto_to_graph_view_test.cc | 123 -- ...hlo_proto_to_memory_visualization_utils.cc | 1108 ----------- .../hlo_proto_to_memory_visualization_utils.h | 25 +- ...roto_to_memory_visualization_utils_test.cc | 114 -- .../profiler/convert/hlo_to_tools_data.cc | 145 -- .../core/profiler/convert/hlo_to_tools_data.h | 41 - .../core/profiler/convert/inference_stats.cc | 1510 --------------- .../core/profiler/convert/inference_stats.h | 53 - .../convert/inference_stats_combiner.cc | 171 -- .../convert/inference_stats_combiner.h | 25 - .../convert/inference_stats_grouping.cc | 475 ----- .../convert/inference_stats_grouping.h | 29 - .../convert/inference_stats_grouping_test.cc | 508 ----- .../convert/inference_stats_sampler.cc | 311 ---- .../convert/inference_stats_sampler.h | 53 - .../convert/inference_stats_sampler_test.cc | 131 -- .../convert/multi_xplanes_to_op_stats.cc | 71 - .../convert/multi_xplanes_to_op_stats.h | 38 - .../multi_xspace_to_inference_stats.cc | 128 -- .../convert/multi_xspace_to_inference_stats.h | 34 - .../convert/op_metrics_db_combiner.cc | 143 -- .../profiler/convert/op_metrics_db_combiner.h | 54 - .../profiler/convert/op_metrics_to_record.cc | 50 - .../profiler/convert/op_metrics_to_record.h | 341 ---- .../profiler/convert/op_profile_builder.cc | 445 ----- .../profiler/convert/op_profile_builder.h | 157 -- tensorflow/core/profiler/convert/op_stack.h | 69 - .../profiler/convert/op_stats_combiner.cc | 318 ---- .../core/profiler/convert/op_stats_combiner.h | 86 - .../convert/op_stats_combiner_test.cc | 124 -- .../profiler/convert/op_stats_to_hlo_stats.cc | 178 -- .../profiler/convert/op_stats_to_hlo_stats.h | 42 - .../op_stats_to_input_pipeline_analysis.cc | 1648 ----------------- .../op_stats_to_input_pipeline_analysis.h | 133 -- ...p_stats_to_input_pipeline_analysis_test.cc | 205 -- .../convert/op_stats_to_op_profile.cc | 103 -- .../profiler/convert/op_stats_to_op_profile.h | 56 - .../convert/op_stats_to_overview_page.cc | 408 ---- .../convert/op_stats_to_overview_page.h | 81 - .../profiler/convert/op_stats_to_pod_stats.cc | 108 -- .../profiler/convert/op_stats_to_pod_stats.h | 30 - .../convert/op_stats_to_pod_stats_test.cc | 127 -- .../convert/op_stats_to_pod_viewer.cc | 67 - .../profiler/convert/op_stats_to_pod_viewer.h | 32 - .../convert/op_stats_to_pod_viewer_test.cc | 139 -- .../convert/op_stats_to_roofline_model.cc | 271 --- 
.../convert/op_stats_to_roofline_model.h | 101 - .../profiler/convert/op_stats_to_tf_stats.cc | 126 -- .../profiler/convert/op_stats_to_tf_stats.h | 32 - .../convert/op_stats_to_tf_stats_test.cc | 169 -- .../convert/preprocess_single_host_xplane.cc | 70 - .../convert/preprocess_single_host_xplane.h | 35 - .../profiler/convert/process_megascale_dcn.cc | 57 - .../profiler/convert/process_megascale_dcn.h | 29 - .../convert/profile_time_breakdown.cc | 79 - .../profiler/convert/profile_time_breakdown.h | 226 +-- .../core/profiler/convert/repository.cc | 179 -- tensorflow/core/profiler/convert/repository.h | 206 --- .../core/profiler/convert/repository_test.cc | 138 -- .../convert/step_events_to_steps_db.cc | 225 --- .../convert/step_events_to_steps_db.h | 38 - .../core/profiler/convert/tool_options.h | 71 - .../tpu_input_pipeline_analysis_constants.h | 30 - .../convert/xplane_to_dcn_collective_stats.cc | 161 -- .../convert/xplane_to_dcn_collective_stats.h | 45 - .../xplane_to_dcn_collective_stats_test.cc | 208 --- .../core/profiler/convert/xplane_to_hlo.cc | 132 -- .../core/profiler/convert/xplane_to_hlo.h | 43 - .../convert/xplane_to_kernel_stats_db.cc | 91 - .../convert/xplane_to_kernel_stats_db.h | 42 - .../convert/xplane_to_kernel_stats_db_test.cc | 139 -- .../convert/xplane_to_memory_profile.cc | 573 ------ .../convert/xplane_to_memory_profile.h | 42 - .../convert/xplane_to_memory_profile_test.cc | 129 -- .../convert/xplane_to_op_metrics_db.cc | 389 ---- .../convert/xplane_to_op_metrics_db.h | 63 - .../convert/xplane_to_op_metrics_db_test.cc | 305 --- .../profiler/convert/xplane_to_op_stats.cc | 495 ----- .../profiler/convert/xplane_to_op_stats.h | 48 +- .../convert/xplane_to_op_stats_test.cc | 826 --------- .../profiler/convert/xplane_to_step_events.cc | 394 ---- .../profiler/convert/xplane_to_step_events.h | 47 - .../convert/xplane_to_step_events_test.cc | 190 -- .../convert/xplane_to_tf_data_stats.cc | 525 ------ .../convert/xplane_to_tf_data_stats.h | 63 - .../convert/xplane_to_tf_data_stats_test.cc | 420 ----- .../convert/xplane_to_tf_functions.cc | 306 --- .../profiler/convert/xplane_to_tf_functions.h | 40 - .../convert/xplane_to_tf_functions_test.cc | 186 -- .../profiler/convert/xplane_to_tool_names.cc | 79 - .../profiler/convert/xplane_to_tool_names.h | 36 - .../convert/xplane_to_tool_names_test.cc | 169 -- .../profiler/convert/xplane_to_tools_data.cc | 432 ----- .../profiler/convert/xplane_to_tools_data.h | 40 - .../convert/xplane_to_trace_container.cc | 256 --- .../convert/xplane_to_trace_container.h | 38 - .../convert/xplane_to_trace_container_test.cc | 125 -- .../convert/xspace_to_dcn_slack_analysis.cc | 528 ------ .../convert/xspace_to_dcn_slack_analysis.h | 167 -- tensorflow/core/profiler/lib/BUILD | 2 + tensorflow/python/profiler/internal/BUILD | 13 +- .../profiler/internal/profiler_pywrap_impl.cc | 9 +- .../profiler/internal/profiler_wrapper.cc | 6 +- .../internal/pywrap_profiler_plugin.cc | 2 +- tensorflow/workspace2.bzl | 6 +- 127 files changed, 50 insertions(+), 24801 deletions(-) delete mode 100644 tensorflow/core/profiler/convert/compute_inference_latency.cc delete mode 100644 tensorflow/core/profiler/convert/compute_inference_latency.h delete mode 100644 tensorflow/core/profiler/convert/compute_inference_latency_test.cc delete mode 100644 tensorflow/core/profiler/convert/data_table_utils.h delete mode 100644 tensorflow/core/profiler/convert/data_table_utils_test.cc delete mode 100644 tensorflow/core/profiler/convert/dcn_analysis.cc delete mode 100644 
tensorflow/core/profiler/convert/dcn_analysis.h delete mode 100644 tensorflow/core/profiler/convert/dcn_analysis_test.cc delete mode 100644 tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc delete mode 100644 tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h delete mode 100644 tensorflow/core/profiler/convert/dcn_utils.cc delete mode 100644 tensorflow/core/profiler/convert/dcn_utils.h delete mode 100644 tensorflow/core/profiler/convert/dcn_utils_test.cc delete mode 100644 tensorflow/core/profiler/convert/duty_cycle_combiner.h delete mode 100644 tensorflow/core/profiler/convert/duty_cycle_combiner_test.cc delete mode 100644 tensorflow/core/profiler/convert/duty_cycle_tracker.cc delete mode 100644 tensorflow/core/profiler/convert/duty_cycle_tracker.h delete mode 100644 tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc delete mode 100644 tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc delete mode 100644 tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h delete mode 100644 tensorflow/core/profiler/convert/hlo_proto_to_graph_view_test.cc delete mode 100644 tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc delete mode 100644 tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils_test.cc delete mode 100644 tensorflow/core/profiler/convert/hlo_to_tools_data.cc delete mode 100644 tensorflow/core/profiler/convert/hlo_to_tools_data.h delete mode 100644 tensorflow/core/profiler/convert/inference_stats.cc delete mode 100644 tensorflow/core/profiler/convert/inference_stats.h delete mode 100644 tensorflow/core/profiler/convert/inference_stats_combiner.cc delete mode 100644 tensorflow/core/profiler/convert/inference_stats_combiner.h delete mode 100644 tensorflow/core/profiler/convert/inference_stats_grouping.cc delete mode 100644 tensorflow/core/profiler/convert/inference_stats_grouping.h delete mode 100644 tensorflow/core/profiler/convert/inference_stats_grouping_test.cc delete mode 100644 tensorflow/core/profiler/convert/inference_stats_sampler.cc delete mode 100644 tensorflow/core/profiler/convert/inference_stats_sampler.h delete mode 100644 tensorflow/core/profiler/convert/inference_stats_sampler_test.cc delete mode 100644 tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc delete mode 100644 tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h delete mode 100644 tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc delete mode 100644 tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h delete mode 100644 tensorflow/core/profiler/convert/op_metrics_db_combiner.cc delete mode 100644 tensorflow/core/profiler/convert/op_metrics_db_combiner.h delete mode 100644 tensorflow/core/profiler/convert/op_metrics_to_record.cc delete mode 100644 tensorflow/core/profiler/convert/op_metrics_to_record.h delete mode 100644 tensorflow/core/profiler/convert/op_profile_builder.cc delete mode 100644 tensorflow/core/profiler/convert/op_profile_builder.h delete mode 100644 tensorflow/core/profiler/convert/op_stack.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_combiner.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_combiner.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_combiner_test.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc delete 
mode 100644 tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_op_profile.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_op_profile.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_overview_page.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_overview_page.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_pod_stats.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_roofline_model.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_tf_stats.h delete mode 100644 tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc delete mode 100644 tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc delete mode 100644 tensorflow/core/profiler/convert/preprocess_single_host_xplane.h delete mode 100644 tensorflow/core/profiler/convert/process_megascale_dcn.cc delete mode 100644 tensorflow/core/profiler/convert/process_megascale_dcn.h delete mode 100644 tensorflow/core/profiler/convert/profile_time_breakdown.cc delete mode 100644 tensorflow/core/profiler/convert/repository.cc delete mode 100644 tensorflow/core/profiler/convert/repository.h delete mode 100644 tensorflow/core/profiler/convert/repository_test.cc delete mode 100644 tensorflow/core/profiler/convert/step_events_to_steps_db.cc delete mode 100644 tensorflow/core/profiler/convert/step_events_to_steps_db.h delete mode 100644 tensorflow/core/profiler/convert/tool_options.h delete mode 100644 tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_hlo.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_hlo.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_memory_profile.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_memory_profile.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_op_stats.cc delete mode 100644 
tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_step_events.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_step_events.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_step_events_test.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tf_functions.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tf_functions.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tool_names.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tool_names.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tools_data.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_tools_data.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_trace_container.cc delete mode 100644 tensorflow/core/profiler/convert/xplane_to_trace_container.h delete mode 100644 tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc delete mode 100644 tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc delete mode 100644 tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h diff --git a/tensorflow/core/lib/gtl/BUILD b/tensorflow/core/lib/gtl/BUILD index 31a74dca33bb07..338d6fe6fb4529 100644 --- a/tensorflow/core/lib/gtl/BUILD +++ b/tensorflow/core/lib/gtl/BUILD @@ -22,7 +22,7 @@ package( # tensorflow/examples/custom_ops_doc/simple_hash_table uses map_util "//tensorflow/examples/custom_ops_doc/simple_hash_table:__pkg__", # tensorflow/core/profiler/convert uses map_util - "//tensorflow/core/profiler/convert:__pkg__", + "@org_xprof//xprof/convert:__pkg__", ], licenses = ["notice"], ) diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 858cee63d4cb79..0d5c900684814b 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -1,1560 +1,74 @@ -load("//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_alias", "tf_profiler_copts") +load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], # TODO(matthurd): Update to profiler:internal after xprof migration. 
+ default_visibility = ["//tensorflow/core/profiler:internal"], licenses = ["notice"], ) cc_library( - name = "xplane_to_op_metrics_db", - srcs = ["xplane_to_op_metrics_db.cc"], - hdrs = ["xplane_to_op_metrics_db.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_db_combiner", - ":op_stack", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/utils:trace_utils", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", - "@org_xprof//xprof/utils:cost_utils", - "@org_xprof//xprof/utils:gpu_event_stats", - "@org_xprof//xprof/utils:hlo_module_map", - "@org_xprof//xprof/utils:op_metrics_db_utils", - "@org_xprof//xprof/utils:op_utils", - ], -) - -tf_cc_test( - name = "xplane_to_op_metrics_db_test", - size = "small", - srcs = ["xplane_to_op_metrics_db_test.cc"], - deps = [ - ":xplane_to_op_metrics_db", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", - "@org_xprof//xprof/utils:hlo_cost_analysis_wrapper", - "@org_xprof//xprof/utils:hlo_module_map", - "@org_xprof//xprof/utils:op_metrics_db_utils", - "@org_xprof//xprof/utils:xprof_gpu_cost_analysis", - ], -) - -cc_library( - name = "op_metrics_db_combiner", - srcs = ["op_metrics_db_combiner.cc"], - hdrs = ["op_metrics_db_combiner.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@org_xprof//xprof/utils:op_metrics_db_utils", - ], -) - -cc_library( - name = "op_metrics_to_record", - srcs = ["op_metrics_to_record.cc"], - hdrs = ["op_metrics_to_record.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "op_stack", - hdrs = ["op_stack.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - ], -) - -cc_library( - name = 
"op_stats_to_hlo_stats", - srcs = ["op_stats_to_hlo_stats.cc"], - hdrs = ["op_stats_to_hlo_stats.h"], - deps = [ - ":data_table_utils", - ":op_metrics_to_record", - "//tensorflow/core/profiler/protobuf:hlo_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - ], -) - -cc_library( - name = "op_stats_to_roofline_model", - srcs = ["op_stats_to_roofline_model.cc"], - hdrs = ["op_stats_to_roofline_model.h"], - deps = [ - ":op_metrics_db_combiner", - ":op_metrics_to_record", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:roofline_model_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/log:check", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hardware_types_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:roofline_model_proto_cc", - "@org_xprof//xprof/utils:diagnostics", - ], -) - -cc_library( - name = "op_stats_to_op_profile", - srcs = ["op_stats_to_op_profile.cc"], - hdrs = ["op_stats_to_op_profile.h"], - deps = [ - ":op_profile_builder", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_profile_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hardware_types_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//xprof/utils:op_metrics_db_utils", - ], -) - -cc_library( - name = "op_stats_to_overview_page", - srcs = ["op_stats_to_overview_page.cc"], - hdrs = ["op_stats_to_overview_page.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_to_record", - ":op_stats_to_input_pipeline_analysis", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:overview_page_proto_cc", - "//tensorflow/core/profiler/protobuf:power_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:format_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - 
"@org_xprof//plugin/tensorboard_plugin_profile/protobuf:input_pipeline_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:overview_page_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:power_metrics_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_function_proto_cc", - "@org_xprof//xprof/utils:diagnostics", - "@org_xprof//xprof/utils:hardware_type_utils", - "@org_xprof//xprof/utils:html_utils", - "@org_xprof//xprof/utils:kernel_stats_utils", - "@org_xprof//xprof/utils:op_metrics_db_utils", - ], -) - -cc_library( - name = "op_stats_to_pod_stats", - srcs = ["op_stats_to_pod_stats.cc"], - hdrs = ["op_stats_to_pod_stats.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@org_xprof//xprof/utils:diagnostics", - "@org_xprof//xprof/utils:event_span", - ], -) - -tf_cc_test( - name = "op_stats_to_pod_stats_test", - srcs = ["op_stats_to_pod_stats_test.cc"], - deps = [ - ":op_stats_to_pod_stats", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", - "@org_xprof//xprof/utils:diagnostics", - "@org_xprof//xprof/utils:event_span", - ], -) - -cc_library( - name = "op_stats_to_pod_viewer", - srcs = ["op_stats_to_pod_viewer.cc"], - hdrs = ["op_stats_to_pod_viewer.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_stats_to_pod_stats", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:pod_viewer_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/log:check", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_viewer_proto_cc", - "@org_xprof//xprof/utils:diagnostics", - ], -) - -tf_cc_test( - name = "op_stats_to_pod_viewer_test", - srcs = ["op_stats_to_pod_viewer_test.cc"], - deps = [ - ":op_stats_to_pod_viewer", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:pod_viewer_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", - 
"@org_xprof//xprof/utils:diagnostics", - "@org_xprof//xprof/utils:event_span", - ], -) - -tf_cc_test( - name = "xplane_to_trace_container_test", - srcs = ["xplane_to_trace_container_test.cc"], - deps = [ - ":xplane_to_trace_container", - "//tensorflow/core/util/proto:proto_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_raw_proto_cc", - ], -) - -cc_library( - name = "op_stats_to_input_pipeline_analysis", - srcs = ["op_stats_to_input_pipeline_analysis.cc"], - hdrs = ["op_stats_to_input_pipeline_analysis.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_to_record", - ":profile_time_breakdown", - ":step_events_to_steps_db", - ":tpu_input_pipeline_analysis_constants", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/platform:logging", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:tpu_input_pipeline_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:format_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/util:stats_calculator_portable", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hardware_types_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:input_pipeline_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tpu_input_pipeline_proto_cc", - "@org_xprof//xprof/utils:diagnostics", - "@org_xprof//xprof/utils:event_span", - "@org_xprof//xprof/utils:html_utils", - "@org_xprof//xprof/utils:op_metrics_db_utils", - "@org_xprof//xprof/utils:tpu_step_breakdown_utils", - "@org_xprof//xprof/utils:tpu_step_details_utils", - ], -) - -tf_cc_test( - name = "op_stats_to_input_pipeline_analysis_test", - srcs = ["op_stats_to_input_pipeline_analysis_test.cc"], - deps = [ - ":op_stats_to_input_pipeline_analysis", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@org_xprof//xprof/utils:event_span", - "@org_xprof//xprof/utils:op_metrics_db_utils", - ], -) - -cc_library( - name = "op_stats_to_tf_stats", - srcs = ["op_stats_to_tf_stats.cc"], - hdrs = ["op_stats_to_tf_stats.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_to_record", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - 
"//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_stats_proto_cc", - "@org_xprof//xprof/utils:kernel_stats_utils", - "@org_xprof//xprof/utils:op_metrics_db_utils", - ], -) - -tf_cc_test( - name = "op_stats_to_tf_stats_test", - size = "small", - srcs = ["op_stats_to_tf_stats_test.cc"], - deps = [ - ":op_stats_to_tf_stats", - ":xplane_to_op_stats", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_stats_proto_cc", - ], -) - -cc_library( - name = "step_events_to_steps_db", - srcs = ["step_events_to_steps_db.cc"], - hdrs = ["step_events_to_steps_db.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_db_combiner", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", - "@org_xprof//xprof/utils:event_span", - "@org_xprof//xprof/utils:op_metrics_db_utils", - ], -) - -cc_library( - name = "xplane_to_op_stats", - srcs = ["xplane_to_op_stats.cc"], - hdrs = ["xplane_to_op_stats.h"], + name = "xplane_to_step_stats", + srcs = ["xplane_to_step_stats.cc"], + hdrs = ["xplane_to_step_stats.h"], copts = tf_profiler_copts(), - visibility = ["@local_xla//xla/tsl/profiler:friends"], deps = [ - ":duty_cycle_combiner", - ":duty_cycle_tracker", - ":op_metrics_db_combiner", - ":repository", - ":step_events_to_steps_db", - ":xplane_to_kernel_stats_db", - ":xplane_to_op_metrics_db", - ":xplane_to_step_events", - ":xplane_to_tf_functions", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:hlo_proto_map", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/utils:gpu_event_stats", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", "@local_xla//xla/tsl/profiler/utils:math_utils", 
"@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hardware_types_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:kernel_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_function_proto_cc", - "@org_xprof//xprof/utils:device_caps_utils", - "@org_xprof//xprof/utils:event_span", - "@org_xprof//xprof/utils:gpu_event_stats", - "@org_xprof//xprof/utils:hardware_type_utils", - "@org_xprof//xprof/utils:hlo_cost_analysis_wrapper", - "@org_xprof//xprof/utils:hlo_module_map", - "@org_xprof//xprof/utils:kernel_stats_utils", - "@org_xprof//xprof/utils:op_utils", - "@org_xprof//xprof/utils:xprof_gpu_cost_analysis", - ], -) - -cc_library( - name = "multi_xplanes_to_op_stats", - srcs = ["multi_xplanes_to_op_stats.cc"], - hdrs = ["multi_xplanes_to_op_stats.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_stats_combiner", - ":preprocess_single_host_xplane", - ":repository", - ":xplane_to_op_stats", - "//tensorflow/core:portable_gif_internal", - "//tensorflow/core/platform:status", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "@com_google_absl//absl/status", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:statusor", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//xprof/utils:hardware_type_utils", - "@org_xprof//xprof/utils:step_intersection", - ], -) - -tf_cc_test( - name = "xplane_to_op_stats_test", - size = "small", - srcs = ["xplane_to_op_stats_test.cc"], - deps = [ - ":duty_cycle_tracker", - ":multi_xplanes_to_op_stats", - ":repository", - ":step_events_to_steps_db", - ":xplane_to_op_stats", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", - ], -) - -cc_library( - name = "xplane_to_step_events", - srcs = ["xplane_to_step_events.cc"], - hdrs = ["xplane_to_step_events.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - 
"//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:trace_utils", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_metrics_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:steps_db_proto_cc", - "@org_xprof//xprof/utils:event_span", - "@org_xprof//xprof/utils:op_metrics_db_utils", - ], -) - -tf_cc_test( - name = "xplane_to_step_events_test", - size = "small", - srcs = ["xplane_to_step_events_test.cc"], - deps = [ - ":xplane_to_step_events", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@org_xprof//xprof/utils:event_span", - ], -) - -cc_library( - name = "xplane_to_kernel_stats_db", - srcs = ["xplane_to_kernel_stats_db.cc"], - hdrs = ["xplane_to_kernel_stats_db.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/utils:trace_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:kernel_stats_proto_cc", "@org_xprof//xprof/utils:gpu_event_stats", - "@org_xprof//xprof/utils:hlo_module_map", - "@org_xprof//xprof/utils:kernel_stats_utils", ], ) -tf_cc_test( - name = "xplane_to_kernel_stats_db_test", - size = "small", - srcs = ["xplane_to_kernel_stats_db_test.cc"], - deps = [ - ":xplane_to_kernel_stats_db", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:kernel_stats_proto_cc", - "@org_xprof//xprof/utils:kernel_stats_utils", - ], -) +# DO NOT ADD NEW DEPENDENCIES TO ANY TARGET IN THIS FILE. +# Instead, use //third_party/xprof/convert. 
cc_library( - name = "xplane_to_tf_functions", - srcs = ["xplane_to_tf_functions.cc"], - hdrs = ["xplane_to_tf_functions.h"], + name = "hlo_proto_to_memory_visualization_utils", + hdrs = ["hlo_proto_to_memory_visualization_utils.h"], copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_function_proto_cc", + visibility = [ + "//learning/deepmind/jax/statix:__subpackages__", + "//platforms/xla/tools/shardy_migration:__subpackages__", + "//smartass/brain/tpu_worker:__subpackages__", + "//tensorflow/core/profiler:internal", ], -) - -tf_cc_test( - name = "xplane_to_tf_functions_test", - size = "small", - srcs = ["xplane_to_tf_functions_test.cc"], deps = [ - ":xplane_to_tf_functions", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "//tensorflow/core/profiler/utils:xplane_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_function_proto_cc", + "@org_xprof//xprof/convert:hlo_proto_to_memory_visualization_utils", ], ) cc_library( - name = "xplane_to_memory_profile", - srcs = ["xplane_to_memory_profile.cc"], - hdrs = ["xplane_to_memory_profile.h"], + name = "profile_time_breakdown", + hdrs = ["profile_time_breakdown.h"], copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/framework:protos_all_cc", - "//tensorflow/core/profiler/protobuf:memory_profile_proto_cc", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:memory_profile_proto_cc", + visibility = [ + "//platforms/performance/autograppler/utils:__subpackages__", + "//tensorflow/core/profiler:internal", ], -) - -tf_cc_test( - name = "xplane_to_memory_profile_test", - size = "small", - srcs = ["xplane_to_memory_profile_test.cc"], deps = [ - ":xplane_to_memory_profile", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:memory_profile_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - 
"//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:memory_profile_proto_cc", + "@org_xprof//xprof/convert:profile_time_breakdown", ], ) cc_library( - name = "op_stats_combiner", - srcs = ["op_stats_combiner.cc"], - hdrs = ["op_stats_combiner.h"], + name = "xplane_to_op_stats", + hdrs = ["xplane_to_op_stats.h"], copts = tf_profiler_copts(), - deps = [ - ":op_metrics_db_combiner", - ":xplane_to_tf_functions", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:power_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:topology_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@org_xprof//xprof/utils:hardware_type_utils", - "@org_xprof//xprof/utils:kernel_stats_utils", - "@org_xprof//xprof/utils:step_intersection", - ], -) - -tf_cc_test( - name = "op_stats_combiner_test", - srcs = ["op_stats_combiner_test.cc"], - deps = [ - ":op_stats_combiner", - "//tensorflow/core:portable_gif_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@org_xprof//xprof/utils:step_intersection", + visibility = [ + "//platforms/xla/tools/multihost_hlo_runner/hybrid_sim:__subpackages__", + "//tensorflow/core/profiler:internal", ], -) - -cc_library( - name = "preprocess_single_host_xplane", - srcs = ["preprocess_single_host_xplane.cc"], - hdrs = ["preprocess_single_host_xplane.h"], - copts = tf_profiler_copts(), - visibility = ["//tensorflow/core/profiler:internal"], - deps = [ - "//tensorflow/core/profiler/utils:xplane_schema", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:preprocess_xplane", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@org_xprof//xprof/utils:derived_timeline", - ], -) - -cc_library( - name = "xplane_to_tools_data", - srcs = ["xplane_to_tools_data.cc"], - hdrs = ["xplane_to_tools_data.h"], - copts = tf_profiler_copts(), - deps = [ - ":compute_inference_latency", - ":hlo_to_tools_data", - ":multi_xplanes_to_op_stats", - ":multi_xspace_to_inference_stats", - ":op_stats_to_hlo_stats", - ":op_stats_to_input_pipeline_analysis", - ":op_stats_to_op_profile", - ":op_stats_to_overview_page", - ":op_stats_to_pod_viewer", - ":op_stats_to_roofline_model", - ":op_stats_to_tf_stats", - ":preprocess_single_host_xplane", - ":process_megascale_dcn", - ":repository", - ":tool_options", - ":xplane_to_dcn_collective_stats", - ":xplane_to_memory_profile", - ":xplane_to_op_stats", - ":xplane_to_tf_data_stats", - ":xplane_to_tool_names", - ":xplane_to_trace_container", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - 
"//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:op_profile_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:overview_page_proto_cc", - "//tensorflow/core/profiler/protobuf:roofline_model_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xplane_to_trace_events", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_slack_analysis_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:hlo_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:inference_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_profile_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:overview_page_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_data_stats_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_raw_proto_cc", - "@org_xprof//xprof/convert/trace_viewer:trace_events_to_json", - "@org_xprof//xprof/convert/trace_viewer:trace_viewer_visibility", - "@org_xprof//xprof/utils:hardware_type_utils", - ], -) - -cc_library( - name = "xplane_to_tf_data_stats", - srcs = ["xplane_to_tf_data_stats.cc"], - hdrs = ["xplane_to_tf_data_stats.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_data_stats_proto_cc", - "@org_xprof//xprof/utils:html_utils", - ], -) - -tf_cc_test( - name = "xplane_to_tf_data_stats_test", - size = "small", - srcs = ["xplane_to_tf_data_stats_test.cc"], - tags = if_oss([ - "manual", - "no_oss", - ]), # b/169705709, no protobuf matchers in OSS. 
- deps = [ - ":xplane_to_tf_data_stats", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:tf_data_stats_proto_cc", - ], -) - -cc_library( - name = "xplane_to_step_stats", - srcs = ["xplane_to_step_stats.cc"], - hdrs = ["xplane_to_step_stats.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/profiler/utils:gpu_event_stats", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@org_xprof//xprof/utils:gpu_event_stats", - ], -) - -cc_library( - name = "hlo_to_tools_data", - srcs = ["hlo_to_tools_data.cc"], - hdrs = ["hlo_to_tools_data.h"], - copts = tf_profiler_copts(), - visibility = ["//visibility:private"], - deps = [ - ":hlo_proto_to_graph_view", - ":hlo_proto_to_memory_visualization_utils", - ":repository", - ":tool_options", - ":xplane_to_hlo", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:memory_viewer_preprocess_proto_cc", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_xla//xla/service:hlo_proto_cc", - ], -) - -cc_library( - name = "hlo_proto_to_memory_visualization_utils", - srcs = ["hlo_proto_to_memory_visualization_utils.cc"], - hdrs = ["hlo_proto_to_memory_visualization_utils.h"], - copts = tf_profiler_copts(), - visibility = ["//tensorflow/core/profiler/protobuf:memory_viewer_friends"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:memory_viewer_preprocess_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_xla//xla:shape_util", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/service:hlo_proto_cc", - ], -) - -tf_cc_test( - name = "hlo_proto_to_memory_visualization_utils_test", - srcs = ["hlo_proto_to_memory_visualization_utils_test.cc"], - deps = [ - ":hlo_proto_to_memory_visualization_utils", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:memory_viewer_preprocess_proto_cc", - "//tensorflow/core/util/proto:proto_utils", - "@com_google_absl//absl/strings:str_format", - "@local_xla//xla/service:hlo_proto_cc", - ], -) - -cc_library( - name = "xplane_to_hlo", - srcs = ["xplane_to_hlo.cc"], - hdrs = ["xplane_to_hlo.h"], - copts = tf_profiler_copts(), - deps = [ - ":repository", - "//tensorflow/core:lib", - "//tensorflow/core/platform:errors", - 
"//tensorflow/core/platform:statusor", - "//tensorflow/core/profiler/utils:hlo_proto_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/profiler/utils:file_system_utils", - ], -) - -cc_library( - name = "op_profile_builder", - srcs = ["op_profile_builder.cc"], - hdrs = ["op_profile_builder.h"], - deps = [ - ":op_metrics_db_combiner", - ":op_metrics_to_record", - "//tensorflow/core:lib", - "//tensorflow/core:lib_headers_for_pybind", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_profile_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@org_xprof//xprof/utils:op_metrics_db_utils", - ], -) - -cc_library( - name = "hlo_proto_to_graph_view", - srcs = ["hlo_proto_to_graph_view.cc"], - hdrs = ["hlo_proto_to_graph_view.h"], - deps = [ - ":tool_options", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - # copybara:uncomment(b/360874576) "//tensorflow/compiler/mlir/lite/experimental/google/tooling/hlo_adapter:direct_hlo_to_json_graph_convert", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_graph_dumper", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/platform:errors", - "@local_xla//xla/tsl/platform:statusor", - "//tensorflow/core/platform:statusor", - "//tensorflow/core/profiler/utils:hlo_module_utils", - # copybara:uncomment "@com_github_nlohmann_json//:json", - "@org_xprof//xprof/utils:hlo_proto_to_module", - ], -) - -tf_cc_test( - name = "hlo_proto_to_graph_view_test", - size = "small", - srcs = ["hlo_proto_to_graph_view_test.cc"], - deps = [ - ":hlo_proto_to_graph_view", - ":tool_options", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/protobuf:error_codes_proto_impl_cc", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/service:hlo_graph_dumper", - "@local_xla//xla/tsl/platform:status_matchers", - "@local_xla//xla/tsl/platform:statusor", - ], -) - -cc_library( - name = "tool_options", - hdrs = ["tool_options.h"], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "repository", - srcs = ["repository.cc"], - hdrs = ["repository.h"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/platform:errors", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:statusor", - "@local_xla//xla/tsl/profiler/utils:file_system_utils", - "@org_xprof//xprof/utils:hlo_module_map", - ], -) - -tf_cc_test( - name = "repository_test", - size = "small", - srcs = ["repository_test.cc"], - deps = [ - ":repository", - "//tensorflow/core/platform:errors", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:status", - ], -) - -cc_library( - name = 
"xplane_to_tool_names", - srcs = ["xplane_to_tool_names.cc"], - hdrs = ["xplane_to_tool_names.h"], - deps = [ - ":repository", - ":xplane_to_dcn_collective_stats", - ":xplane_to_hlo", - "//tensorflow/core/platform:statusor", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:statusor", - ], -) - -tf_cc_test( - name = "xplane_to_tool_names_test", - size = "small", - srcs = ["xplane_to_tool_names_test.cc"], - deps = [ - ":repository", - ":xplane_to_tool_names", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:status", - ], -) - -cc_library( - name = "xplane_to_trace_container", - srcs = ["xplane_to_trace_container.cc"], - hdrs = ["xplane_to_trace_container.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", - "//tensorflow/core/profiler/protobuf:trace_events_raw_proto_cc", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:trace_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:trace_events_raw_proto_cc", - "@org_xprof//xprof/convert/trace_viewer:trace_event_arguments_builder", - "@org_xprof//xprof/convert/trace_viewer:trace_events", - "@org_xprof//xprof/convert/trace_viewer:trace_events_util", - ], -) - -cc_library( - name = "dcn_utils", - srcs = ["dcn_utils.cc"], - hdrs = ["dcn_utils.h"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -tf_cc_test( - name = "dcn_utils_test", - srcs = ["dcn_utils_test.cc"], - deps = [ - ":dcn_utils", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_builder", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -cc_library( - name = "dcn_analysis", - srcs = ["dcn_analysis.cc"], - hdrs = ["dcn_analysis.h"], - visibility = ["//visibility:public"], - deps = [ - ":dcn_utils", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_builder", - 
"@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -cc_library( - name = "process_megascale_dcn", - srcs = ["process_megascale_dcn.cc"], - hdrs = ["process_megascale_dcn.h"], - deps = [ - ":dcn_analysis", - "//tensorflow/core/profiler/utils:xplane_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - ], -) - -tf_cc_test( - name = "dcn_analysis_test", - srcs = ["dcn_analysis_test.cc"], - deps = [ - ":dcn_analysis", - ":dcn_utils", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_builder", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -cc_library( - name = "xspace_to_dcn_slack_analysis", - srcs = ["xspace_to_dcn_slack_analysis.cc"], - hdrs = ["xspace_to_dcn_slack_analysis.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:dcn_collective_info_proto_cc", - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - "//tensorflow/core/profiler/protobuf:topology_proto_cc", - "//tensorflow/core/profiler/utils:hlo_module_utils", - "//tensorflow/core/profiler/utils:hlo_proto_map", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:regexp", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla:shape_util", - "@local_xla//xla:side_effect_util", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/platform:statusor", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_collective_info_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_slack_analysis_proto_cc", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:topology_proto_cc", - "@org_xprof//xprof/utils:hlo_proto_to_module", - ], -) - -cc_library( - name = "dcn_slack_analysis_combiner", - srcs = ["dcn_slack_analysis_combiner.cc"], - hdrs = ["dcn_slack_analysis_combiner.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "xplane_to_dcn_collective_stats", - srcs = ["xplane_to_dcn_collective_stats.cc"], - hdrs = ["xplane_to_dcn_collective_stats.h"], - copts = tf_profiler_copts(), - deps = [ - ":dcn_slack_analysis_combiner", - ":repository", - ":xspace_to_dcn_slack_analysis", - "//tensorflow/core:lib", - "//tensorflow/core/platform:statusor", - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - 
"@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:statusor", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_slack_analysis_proto_cc", - ], -) - -tf_cc_test( - name = "xplane_to_dcn_collective_stats_test", - size = "small", - srcs = ["xplane_to_dcn_collective_stats_test.cc"], - deps = [ - ":repository", - ":xplane_to_dcn_collective_stats", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:status", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:dcn_slack_analysis_proto_cc", - ], -) - -cc_library( - name = "inference_stats", - srcs = ["inference_stats.cc"], - hdrs = ["inference_stats.h"], - deps = [ - "//tensorflow/core/lib/gtl:map_util", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/utils:xplane_schema", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:logging", - "@local_xla//xla/tsl/profiler/utils:device_utils", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:xplane_builder", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - "@org_xprof//xprof/utils:event_span", - ], -) - -cc_library( - name = "inference_stats_sampler", - srcs = ["inference_stats_sampler.cc"], - hdrs = ["inference_stats_sampler.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "inference_stats_grouping", - srcs = ["inference_stats_grouping.cc"], - hdrs = ["inference_stats_grouping.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/tsl/lib/gtl:map_util", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -cc_library( - name = "inference_stats_combiner", - srcs = ["inference_stats_combiner.cc"], - hdrs = ["inference_stats_combiner.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla/tsl/lib/gtl:map_util", - ], -) - -cc_library( - name = "multi_xspace_to_inference_stats", - 
srcs = ["multi_xspace_to_inference_stats.cc"], - hdrs = ["multi_xspace_to_inference_stats.h"], - deps = [ - ":inference_stats", - ":inference_stats_combiner", - ":inference_stats_grouping", - ":inference_stats_sampler", - ":preprocess_single_host_xplane", - ":repository", - ":xplane_to_step_events", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:statusor", - "@local_xla//xla/tsl/profiler/utils:device_utils", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:inference_stats_proto_cc", - "@org_xprof//xprof/utils:event_span", - ], -) - -tf_cc_test( - name = "inference_stats_grouping_test", - srcs = ["inference_stats_grouping_test.cc"], - tags = [ - "no_oss", - ], - deps = [ - ":inference_stats_grouping", - "//tensorflow/core:test", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/tests:test_utils", - ], -) - -tf_cc_test( - name = "inference_stats_sampler_test", - srcs = ["inference_stats_sampler_test.cc"], - tags = [ - "no_oss", - ], - deps = [ - ":inference_stats_sampler", - "//tensorflow/core:test", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/tests:test_utils", - ], -) - -cc_library( - name = "compute_inference_latency", - srcs = ["compute_inference_latency.cc"], - hdrs = ["compute_inference_latency.h"], - visibility = ["//perftools/accelerators/xprof/convert:__pkg__"], - deps = [ - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:overview_page_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "profile_time_breakdown", - srcs = ["profile_time_breakdown.cc"], - hdrs = ["profile_time_breakdown.h"], - visibility = ["@local_xla//xla/tsl/profiler:friends"], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "tpu_input_pipeline_analysis_constants", - srcs = [tf_profiler_alias("@org_xprof//xprof/convert/", "tpu_input_pipeline_analysis_constants.cc")], - hdrs = ["tpu_input_pipeline_analysis_constants.h"], - visibility = ["@local_xla//xla/tsl/profiler:friends"], - deps = [ - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla/tsl/platform:macros", - ], -) - -cc_library( - name = "duty_cycle_tracker", - srcs = ["duty_cycle_tracker.cc"], - hdrs = ["duty_cycle_tracker.h"], - deps = [ - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/log:check", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -cc_library( - name = "duty_cycle_combiner", - hdrs = ["duty_cycle_combiner.h"], - deps = [ - ":duty_cycle_tracker", - 
"//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - ], -) - -tf_cc_test( - name = "duty_cycle_tracker_test", - srcs = ["duty_cycle_tracker_test.cc"], - deps = [ - ":duty_cycle_tracker", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -tf_cc_test( - name = "compute_inference_latency_test", - srcs = ["compute_inference_latency_test.cc"], - deps = [ - ":compute_inference_latency", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_googletest//:gtest_main", - ], -) - -tf_cc_test( - name = "duty_cycle_combiner_test", - srcs = ["duty_cycle_combiner_test.cc"], - deps = [ - ":duty_cycle_combiner", - ":duty_cycle_tracker", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_googletest//:gtest", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -cc_library( - name = "data_table_utils", - hdrs = ["data_table_utils.h"], - deps = [ - "@com_github_nlohmann_json//:json", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "data_table_utils_test", - srcs = ["data_table_utils_test.cc"], deps = [ - ":data_table_utils", - "@com_github_nlohmann_json//:json", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", + "@org_xprof//xprof/convert:xplane_to_op_stats", ], ) diff --git a/tensorflow/core/profiler/convert/compute_inference_latency.cc b/tensorflow/core/profiler/convert/compute_inference_latency.cc deleted file mode 100644 index ba0c8245fd033e..00000000000000 --- a/tensorflow/core/profiler/convert/compute_inference_latency.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/compute_inference_latency.h" - -#include -#include -#include -#include -#include - -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" - -namespace tensorflow::profiler { - -struct LatencyBreakdown { - double total_latency_us = 0.0; - double host_latency_us = 0.0; - double device_latency_us = 0.0; - double communication_latency_us = 0.0; -}; - -void SetLatencyBreakdown(const LatencyBreakdown& src, - OverviewLatencyBreakdown* res) { - res->set_total_latency_us(src.total_latency_us); - res->set_host_latency_us(src.host_latency_us); - res->set_device_latency_us(src.device_latency_us); - res->set_communication_latency_us(src.communication_latency_us); -} - -void SafeDivide(int64_t count, double* num) { - constexpr double kEpsilon = 1.0e-20; - if (count == 0 || std::abs(*num) < kEpsilon) { - *num = 0.0; - } else { - *num /= count; - } -} - -void ComputeAverage(int64_t count, LatencyBreakdown* breakdown) { - SafeDivide(count, &breakdown->total_latency_us); - SafeDivide(count, &breakdown->host_latency_us); - SafeDivide(count, &breakdown->device_latency_us); - SafeDivide(count, &breakdown->communication_latency_us); -} - -void ComputeBreakdownFromSessionRun( - const tensorflow::profiler::RequestDetail& request_detail, - LatencyBreakdown* res, LatencyBreakdown* avg) { - double session_run_duration_us = tsl::profiler::PicoToMicro( - request_detail.end_time_ps() - request_detail.start_time_ps()); - double device_time_us = - tsl::profiler::PicoToMicro(request_detail.device_time_ps()); - double communication_time_us = - tsl::profiler::PicoToMicro(request_detail.read_from_device_time_ps() + - request_detail.write_to_device_time_ps()); - double host_time_us = - session_run_duration_us - device_time_us - communication_time_us; - *res = {session_run_duration_us, host_time_us, device_time_us, - communication_time_us}; - - avg->total_latency_us += session_run_duration_us; - avg->device_latency_us += device_time_us; - avg->communication_latency_us += communication_time_us; - avg->host_latency_us += - session_run_duration_us - device_time_us - communication_time_us; -} - -// Compute the inference latency from inference stats proto. -OverviewInferenceLatency ComputeInferenceLatencyResult( - const tensorflow::profiler::InferenceStats& inference_stats) { - OverviewInferenceLatency result; - // If inference_stats is empty, return early with empty result. - // The following code is able to return empty result even - // without early return. - if (inference_stats.inference_stats_per_model_size() == 0) return result; - - // Target percentiles over all session runs. - // Default is [50.0, 75.0, 90.0, 99.0, 99.9]. - constexpr double kTargetPercentiles[] = {50.0, 75.0, 90.0, 99.0, 99.9}; - // Saves the latency corresponding to each percentile. - - std::vector sessions; - double total_sessioins_per_sec = 0; - double max_latency = 0.0; - double min_latency = std::numeric_limits::max(); - LatencyBreakdown avg; - // Iterate over all session runs from all models, calculate the device, - // communication, and host time for each session run, and push in the - // vector sessions. Also update the max, min, count, avg. 
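// Worked example (illustrative, not part of the original source): a request
// with end - start = 10,000,000 ps, device_time_ps = 1,000,000, and
// write_to_device_time_ps = 1,000,000 breaks down to total = 10 us,
// device = 1 us, communication = 1 us, and host = 10 - 1 - 1 = 8 us;
// host latency is whatever wall-clock time remains after subtracting
// device and communication time.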
- for (const auto& model_inference_stats : - inference_stats.inference_stats_per_model()) { - total_sessioins_per_sec += - model_inference_stats.second.request_throughput(); - for (const auto& request_detail : - model_inference_stats.second.request_details()) { - LatencyBreakdown session_breakdown; - ComputeBreakdownFromSessionRun(request_detail, &session_breakdown, &avg); - sessions.push_back(session_breakdown); - double session_run_duration_us = tsl::profiler::PicoToMicro( - request_detail.end_time_ps() - request_detail.start_time_ps()); - max_latency = std::max(max_latency, session_run_duration_us); - min_latency = std::min(min_latency, session_run_duration_us); - } - } - // Return empty result if there is no session found. - if (sessions.empty()) return result; - result.set_sessions_per_second(total_sessioins_per_sec); - result.set_max_latency_us(max_latency); - result.set_min_latency_us(min_latency); - ComputeAverage(sessions.size(), &avg); - - // Sort the sessions based on session run duration. For a specified - // percentile, get the corresponding session with the (lower-bound) index. - std::sort(sessions.begin(), sessions.end(), - [](const LatencyBreakdown& a, const LatencyBreakdown& b) { - return a.total_latency_us < b.total_latency_us; - }); - for (const auto& percent : kTargetPercentiles) { - result.add_percentile_numbers(percent); - int64_t index = percent / 100.0 * sessions.size(); - SetLatencyBreakdown(sessions[index], result.add_latency_breakdowns()); - } - // Set the average latency stats. - SetLatencyBreakdown(avg, result.add_latency_breakdowns()); - - return result; -} - -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/compute_inference_latency.h b/tensorflow/core/profiler/convert/compute_inference_latency.h deleted file mode 100644 index 91632c907cbdf2..00000000000000 --- a/tensorflow/core/profiler/convert/compute_inference_latency.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_COMPUTE_INFERENCE_LATENCY_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_COMPUTE_INFERENCE_LATENCY_H_ - -#include -#include - -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" - -namespace tensorflow::profiler { - -// Compute the inference latency from inference stats proto. 
-OverviewInferenceLatency ComputeInferenceLatencyResult(
-    const InferenceStats& inference_stats);
-
-}  // namespace tensorflow::profiler
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_COMPUTE_INFERENCE_LATENCY_H_
diff --git a/tensorflow/core/profiler/convert/compute_inference_latency_test.cc b/tensorflow/core/profiler/convert/compute_inference_latency_test.cc
deleted file mode 100644
index efd931c1384739..00000000000000
--- a/tensorflow/core/profiler/convert/compute_inference_latency_test.cc
+++ /dev/null
@@ -1,72 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/compute_inference_latency.h"
-
-#include
-#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h"
-
-namespace tensorflow::profiler {
-namespace {
-
-constexpr double kMaxError = 0.0001;
-
-TEST(ComputeInferenceLatencyResult, InferenceLatencyTest) {
-  InferenceStats inference_stats;
-  auto& model = (*inference_stats.mutable_inference_stats_per_model())[0];
-
-  // Generates requests for testing.
-  for (int i = 0; i < 100; i++) {
-    RequestDetail request_detail;
-    request_detail.set_start_time_ps(0);
-    request_detail.set_end_time_ps(i * 10000);
-    request_detail.set_device_time_ps(i * 1000);
-    request_detail.set_write_to_device_time_ps(i * 1000);
-    model.add_request_details()->Swap(&request_detail);
-  }
-
-  auto result = ComputeInferenceLatencyResult(inference_stats);
-
-  // 5 percentiles and 1 average, so 6 results in total.
-  ASSERT_EQ(result.latency_breakdowns_size(), 6);
-
-  // Verify 50 percentile result.
-  EXPECT_NEAR(result.latency_breakdowns(0).total_latency_us(), 0.5, kMaxError);
-  EXPECT_NEAR(result.latency_breakdowns(0).host_latency_us(), 0.4, kMaxError);
-  EXPECT_NEAR(result.latency_breakdowns(0).device_latency_us(), 0.05,
-              kMaxError);
-  EXPECT_NEAR(result.latency_breakdowns(0).communication_latency_us(), 0.05,
-              kMaxError);
-
-  // Verify 99.9 percentile result.
-  EXPECT_NEAR(result.latency_breakdowns(4).total_latency_us(), 0.99, kMaxError);
-  EXPECT_NEAR(result.latency_breakdowns(4).host_latency_us(), 0.792, kMaxError);
-  EXPECT_NEAR(result.latency_breakdowns(4).device_latency_us(), 0.099,
-              kMaxError);
-  EXPECT_NEAR(result.latency_breakdowns(4).communication_latency_us(), 0.099,
-              kMaxError);
-
-  // Verify average result.
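// Worked check (illustrative, not from the original file): request i has
// total = i*10000 ps = i*0.01 us, device = communication = i*0.001 us, and
// host = total - device - communication = i*0.008 us. Averaging over
// i = 0..99 (mean 49.5) gives 0.495, 0.396, 0.0495, and 0.0495 us, which
// are the values asserted below.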
- EXPECT_NEAR(result.latency_breakdowns(5).total_latency_us(), 0.495, - kMaxError); - EXPECT_NEAR(result.latency_breakdowns(5).host_latency_us(), 0.396, kMaxError); - EXPECT_NEAR(result.latency_breakdowns(5).device_latency_us(), 0.0495, - kMaxError); - EXPECT_NEAR(result.latency_breakdowns(5).communication_latency_us(), 0.0495, - kMaxError); -} - -} // namespace -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/data_table_utils.h b/tensorflow/core/profiler/convert/data_table_utils.h deleted file mode 100644 index 34bc248b356db5..00000000000000 --- a/tensorflow/core/profiler/convert/data_table_utils.h +++ /dev/null @@ -1,155 +0,0 @@ - -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DATA_TABLE_UTILS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_DATA_TABLE_UTILS_H_ -#include -#include -#include - -#include "absl/container/btree_map.h" -#include "absl/strings/str_replace.h" -#include "nlohmann/json_fwd.hpp" -#include "nlohmann/json.hpp" -namespace tensorflow { -namespace profiler { -// We Don't deal with formatted values on backend now. -struct TableCell { - TableCell() = default; - explicit TableCell(nlohmann::json value) : value(value) {}; - explicit TableCell( - nlohmann::json value, - absl::btree_map custom_properties) - : value(value), custom_properties(custom_properties) {}; - std::string value_str() const { - return absl::StrReplaceAll(value.dump(), {{"\"", ""}}); - } - nlohmann::json value; - absl::btree_map custom_properties; -}; -struct TableColumn { - TableColumn() = default; - explicit TableColumn(std::string id, std::string type, std::string label) - : id(id), type(type), label(label) {}; - explicit TableColumn( - std::string id, std::string type, std::string label, - absl::btree_map custom_properties) - : id(id), type(type), label(label), custom_properties(custom_properties) { - }; - std::string id; - std::string type; - std::string label; - absl::btree_map custom_properties; -}; -class TableRow { - public: - TableRow() = default; - virtual ~TableRow() = default; - // Adds a value of a single cell to the end of the row. - // Memory will be freed by the TableRow. 
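// (Ownership note: cells_ holds std::unique_ptr<TableCell>, so the pointer
// returned by AddCell stays valid for the row's lifetime and must not be
// deleted by the caller.)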
- TableCell* AddCell(nlohmann::json value) { - cells_.push_back(std::make_unique(value)); - return cells_.back().get(); - } - std::vector GetCells() const { - std::vector cells; - cells.reserve(cells_.size()); - for (const std::unique_ptr& cell : cells_) { - cells.push_back(cell.get()); - } - return cells; - } - void SetCustomProperties( - const absl::btree_map& custom_properties) { - custom_properties_ = custom_properties; - } - void AddCustomProperty(std::string name, std::string value) { - custom_properties_[name] = value; - } - const absl::btree_map& GetCustomProperties() const { - return custom_properties_; - } - int RowSize() const { return cells_.size(); } - - private: - std::vector> cells_; - absl::btree_map custom_properties_; -}; -// A DataTable class that can be used to create a DataTable JSON/CSV -// serialization. We need this class instead raw JSON manipulation because we -// need to support custom properties. -class DataTable { - public: - DataTable() = default; - void AddColumn(TableColumn column) { table_descriptions_.push_back(column); } - const std::vector& GetColumns() { return table_descriptions_; } - // Create an empty row and return a pointer to it. - // DataTable takes the ownership of the returned TableRow. - TableRow* AddRow() { - table_rows_.push_back(std::make_unique()); - return table_rows_.back().get(); - } - std::vector GetRows() { - std::vector rows; - rows.reserve(table_rows_.size()); - for (const std::unique_ptr& row : table_rows_) { - rows.push_back(row.get()); - } - return rows; - } - void AddCustomProperty(std::string name, std::string value) { - custom_properties_[name] = value; - } - std::string ToJson() { - nlohmann::json table; - table["cols"] = nlohmann::json::array(); - table["rows"] = nlohmann::json::array(); - if (!custom_properties_.empty()) { - table["p"] = custom_properties_; - } - for (const TableColumn& col : table_descriptions_) { - nlohmann::json column_json; - column_json["id"] = col.id; - column_json["type"] = col.type; - column_json["label"] = col.label; - if (!col.custom_properties.empty()) { - column_json["p"] = col.custom_properties; - } - table["cols"].push_back(column_json); - } - for (const std::unique_ptr& row : table_rows_) { - nlohmann::json row_json; - row_json["c"] = nlohmann::json::array(); - for (const TableCell* cell : row->GetCells()) { - nlohmann::json cell_json; - cell_json["v"] = cell->value; - if (!cell->custom_properties.empty()) { - cell_json["p"] = cell->custom_properties; - } - row_json["c"].push_back(cell_json); - } - if (!row->GetCustomProperties().empty()) { - row_json["p"] = row->GetCustomProperties(); - } - table["rows"].push_back(row_json); - } - return table.dump(); - } - - private: - std::vector table_descriptions_; - std::vector> table_rows_; - absl::btree_map custom_properties_; -}; -} // namespace profiler -} // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DATA_TABLE_UTILS_H_ diff --git a/tensorflow/core/profiler/convert/data_table_utils_test.cc b/tensorflow/core/profiler/convert/data_table_utils_test.cc deleted file mode 100644 index 58c89f0e25f306..00000000000000 --- a/tensorflow/core/profiler/convert/data_table_utils_test.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
diff --git a/tensorflow/core/profiler/convert/data_table_utils_test.cc b/tensorflow/core/profiler/convert/data_table_utils_test.cc
deleted file mode 100644
index 58c89f0e25f306..00000000000000
--- a/tensorflow/core/profiler/convert/data_table_utils_test.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/data_table_utils.h"
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "nlohmann/json_fwd.hpp"
-#include "nlohmann/json.hpp"
-
-namespace tensorflow::profiler {
-namespace {
-
-std::vector<std::vector<std::string>> GetTestColumns() {
-  return {{"rank", "number", "Rank"},
-          {"program_id", "string", "Program Id"},
-          {"op_category", "string", "Op Category"},
-          {"op_name", "string", "Op Name"},
-          {"bytes_accessed", "number", "Bytes Accessed"},
-          {"model_flops", "number", "Model Flops"},
-          {"occurrences", "number", "#Occurrences"}};
-}
-
-std::vector<nlohmann::json> GetTestRows() {
-  return {{1, "11111", "category1", "op1", 200000000, 123123123, 10},
-          {2, "22222", "category2", "op2", 1000000, 0, 20},
-          {3, "33333", "category3", "op3", 3000000, 565656, 30}};
-}
-
-std::unique_ptr<DataTable> CreateTestDataTable() {
-  auto data_table = std::make_unique<DataTable>();
-  for (const std::vector<std::string>& col : GetTestColumns()) {
-    data_table->AddColumn(TableColumn(col[0], col[1], col[2]));
-  }
-  for (const nlohmann::json& row_json : GetTestRows()) {
-    TableRow* row = data_table->AddRow();
-    for (int i = 0; i < row_json.size(); ++i) {
-      row->AddCell(row_json[i]);
-    }
-  }
-  return data_table;
-}
-
-std::unique_ptr<DataTable> CreateTestDataTableWithCustomProperties() {
-  auto data_table = std::make_unique<DataTable>();
-  data_table->AddCustomProperty("key1", "value1");
-  data_table->AddCustomProperty("key2", "value2");
-  return data_table;
-}
-
-TEST(DataTableUtilsTest, ToJson) {
-  std::unique_ptr<DataTable> data_table = CreateTestDataTable();
-  std::string json_string = data_table->ToJson();
-  const nlohmann::basic_json<> parsed_json = nlohmann::json::parse(json_string);
-  auto test_columns = GetTestColumns();
-  auto test_rows = GetTestRows();
-  EXPECT_EQ(parsed_json["cols"].size(), test_columns.size());
-  EXPECT_EQ(parsed_json["rows"].size(), test_rows.size());
-  for (int i = 0; i < test_columns.size(); ++i) {
-    EXPECT_EQ(parsed_json["cols"][i]["id"], test_columns[i][0]);
-    EXPECT_EQ(parsed_json["cols"][i]["label"], test_columns[i][2]);
-    EXPECT_EQ(parsed_json["cols"][i]["type"], test_columns[i][1]);
-  }
-  for (int i = 0; i < test_rows.size(); ++i) {
-    for (int j = 0; j < test_columns.size(); ++j) {
-      EXPECT_EQ(parsed_json["rows"][i]["c"][j]["v"], GetTestRows()[i][j]);
-    }
-  }
-}
-
-TEST(DataTableUtilsTest, ToJsonWithCustomProperties) {
-  std::unique_ptr<DataTable> data_table =
-      CreateTestDataTableWithCustomProperties();
-  std::string table_json_string = data_table->ToJson();
-  const nlohmann::basic_json<> parsed_json =
-      nlohmann::json::parse(table_json_string);
-  EXPECT_EQ(parsed_json.find("p")->size(), 2);
-  EXPECT_EQ(parsed_json.find("p")->at("key1"), "value1");
-  EXPECT_EQ(parsed_json.find("p")->at("key2"), "value2");
-}
-
-}  // namespace
-}  // namespace tensorflow::profiler
diff --git a/tensorflow/core/profiler/convert/dcn_analysis.cc b/tensorflow/core/profiler/convert/dcn_analysis.cc
deleted file mode 100644
index 15de6d44400def..00000000000000
--- a/tensorflow/core/profiler/convert/dcn_analysis.cc
+++ /dev/null
@@ -1,471 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/profiler/convert/dcn_analysis.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/log/log.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/math_utils.h"
-#include "xla/tsl/profiler/utils/tpu_xplane_utils.h"
-#include "xla/tsl/profiler/utils/xplane_schema.h"
-#include "tensorflow/core/profiler/convert/dcn_utils.h"
-#include "tensorflow/core/profiler/utils/xplane_builder.h"
-#include "tensorflow/core/profiler/utils/xplane_visitor.h"
-
-namespace tensorflow {
-namespace profiler {
-
-using tsl::profiler::kMaxCollectivesToDisplay;
-using tsl::profiler::kMegaScaleDcnReceive;
-using tsl::profiler::LineIdType;
-using tsl::profiler::MicroToNano;
-
-void DcnBurstManager::ResetBurstState() {
-  active_burst_messages_ = 0;
-  straggler_idx_ = 0;
-  active_burst_.num_messages = 0;
-  active_burst_.max_overlapping_messages = 0;
-  active_burst_.start_timestamp_ns = 0;
-  active_burst_.end_timestamp_ns = 0;
-  active_burst_.burst_size_bytes = 0;
-}
-
-void DcnBurstManager::CreateBursts(const TimestampMap& tm_events) {
-  ResetBurstState();
-  for (const auto& tm_event : tm_events) {
-    if (active_burst_messages_ < 0) {
-      LOG_FIRST_N(WARNING, 10)
-          << "Negative messages in burst, bursts will be incorrect.";
-    }
-    if (active_burst_messages_ == 0) {
-      // When no messages are active, the next event starts a new burst.
-      active_burst_.start_timestamp_ns = tm_event.first;
-    }
-    active_burst_messages_ += tm_event.second->message_diff;
-    if (tm_event.second->message_diff > 0) {
-      // On the beginning of a message, increase messages and bytes.
-      active_burst_.num_messages += tm_event.second->message_diff;
-      active_burst_.burst_size_bytes += tm_event.second->size_diff;
-    } else {
-      // On the end of a message, register a straggler.
-      Straggler straggler = {tm_event.second->duration_ns,       // duration_ns
-                             tm_event.second->timestamp_ns,      // end_timestamp_ns
-                             tm_event.second->size_diff * (-1),  // size_bytes
-                             tm_event.second->src_slice_id};     // src_slice_id
-      active_burst_.stragglers[straggler_idx_] = straggler;
-      straggler_idx_ = (straggler_idx_ + 1) % kMaxStragglersPerBurst;
-    }
-    active_burst_.max_overlapping_messages =
-        std::max(active_burst_.max_overlapping_messages,
-                 static_cast<uint64_t>(active_burst_messages_));
-    // If we are back at 0 messages, the burst has finished and can be added
-    // to the bursts_ vector.
-    if (active_burst_messages_ == 0) {
-      active_burst_.end_timestamp_ns = tm_event.first;
-      total_latency_ +=
-          (active_burst_.end_timestamp_ns - active_burst_.start_timestamp_ns);
-      bursts_.emplace_back(std::move(active_burst_));
-      ResetBurstState();
-    }
-  }
-}
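// Editor's aside (illustrative numbers, not part of the patch): tracing
// CreateBursts() above with two overlapping messages shows how the +1/-1
// message_diff bookkeeping forms a single burst:
//
//   events: +1/+100B @1000ns, +1/+50B @2000ns, -1/-50B @3000ns, -1/-100B @5000ns
//
//   @1000ns: active 0->1, burst starts; num_messages=1, bytes=100
//   @2000ns: active 1->2; num_messages=2, bytes=150, max_overlapping=2
//   @3000ns: active 2->1; the 50B message is recorded as a straggler
//   @5000ns: active 1->0; burst closes as [1000ns, 5000ns] with both messages
//            as stragglers, and total_latency_ grows by 4000ns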
-
-DcnEventsProcessor::DcnEventsProcessor(uint32_t num_tpu_tensor_cores,
-                                       bool is_megacore)
-    : num_tpu_tensor_cores_(num_tpu_tensor_cores), is_megacore_(is_megacore) {
-  // Register all MSXLA messages we may need to analyze. Currently only
-  // receive messages are processed.
-  registered_dcn_messages_.push_back(kMegaScaleDcnReceive);
-  tpu_collective_ts_map_.resize(num_tpu_tensor_cores_);
-  tpu_collective_bursts_.resize(num_tpu_tensor_cores_);
-}
-
-// Sets up a map between registered Megascale messages and their event metadata
-// so they can be captured from host events.
-void DcnEventsProcessor::SetupMessageInfo(const XPlaneVisitor& plane) {
-  plane.ForEachEventMetadata([&](const XEventMetadataVisitor& event_metadata) {
-    if (std::find(registered_dcn_messages_.begin(),
-                  registered_dcn_messages_.end(),
-                  event_metadata.Name()) != registered_dcn_messages_.end()) {
-      megascale_msg_[event_metadata.Name()] = event_metadata.Id();
-    }
-  });
-}
-
-// If we use megacore, collective traffic goes to even TPU tensor cores.
-// Odd ones are woken up from their even pair (e.g. 0 wakes up 1).
-uint32_t DcnEventsProcessor::FindTpuIdx(int tpu) {
-  uint32_t num_tpus = num_tpu_tensor_cores_;
-  if (is_megacore_) {
-    num_tpus /= 2;
-  }
-  uint32_t tpu_idx = tpu % num_tpus;
-  if (is_megacore_) {
-    tpu_idx = tpu_idx * 2;
-  }
-  return tpu_idx;
-}
-
-void DcnEventsProcessor::GenerateTimestampEvents(
-    const DcnMessage& dcn_message) {
-  // Create one event for the beginning and one for the end of the message.
-  std::shared_ptr<TimestampEvent> start_event(
-      new TimestampEvent{dcn_message.start_timestamp_ns, 0, 1,
-                         dcn_message.size_bytes, dcn_message.slice_src});
-  std::shared_ptr<TimestampEvent> end_event(new TimestampEvent{
-      dcn_message.end_timestamp_ns,
-      static_cast<uint64_t>(MicroToNano(dcn_message.duration_us)), -1,
-      -1 * dcn_message.size_bytes, dcn_message.slice_src});
-
-  // Add messages to the host timestamp event map.
-  std::pair<uint64_t, std::shared_ptr<TimestampEvent>> start_event_entry =
-      std::make_pair(dcn_message.start_timestamp_ns, start_event);
-  std::pair<uint64_t, std::shared_ptr<TimestampEvent>> end_event_entry =
-      std::make_pair(dcn_message.end_timestamp_ns, end_event);
-  host_ts_map_.insert(start_event_entry);
-  host_ts_map_.insert(end_event_entry);
-
-  // Add messages to the proper TPU collective timestamp event map.
-  const std::string& collective_name = dcn_message.collective_name;
-  uint32_t tpu_idx = FindTpuIdx(dcn_message.tpu_dst);
-  auto& m = tpu_collective_ts_map_[tpu_idx][collective_name];
-  m.insert(start_event_entry);
-  m.insert(end_event_entry);
-}
-
-void DcnEventsProcessor::PrintTimestampEvents() {
-  for (const auto& host_ts : host_ts_map_) {
-    LOG(INFO) << host_ts.first << ": " << host_ts.second->timestamp_ns << " "
-              << host_ts.second->duration_ns << " "
-              << host_ts.second->message_diff << " "
-              << host_ts.second->size_diff << " "
-              << host_ts.second->src_slice_id;
-  }
-  for (uint32_t tpu_idx = 0; tpu_idx < num_tpu_tensor_cores_; tpu_idx++) {
-    LOG(INFO) << "TPU: " << tpu_idx;
-    for (const auto& col_id : tpu_collective_ts_map_[tpu_idx]) {
-      LOG(INFO) << col_id.first;
-      for (const auto& tpu_col_ts :
-           tpu_collective_ts_map_[tpu_idx][col_id.first]) {
-        LOG(INFO) << tpu_col_ts.first << ": " << tpu_col_ts.second->timestamp_ns
-                  << " " << tpu_col_ts.second->duration_ns << " "
-                  << tpu_col_ts.second->message_diff << " "
-                  << tpu_col_ts.second->size_diff << " "
-                  << tpu_col_ts.second->src_slice_id;
-      }
-    }
-  }
-}
-
-// Uses heuristics to qualify a good enough number of collectives.
-// At most kMaxCollectivesToDisplay - 1 are displayed.
-// Collectives with < 5% of total host BW time are never qualified.
-// Collectives with < 20% of total host BW time are qualified only if fewer
-// than 4 collectives have already been qualified.
-// The top 8 collectives with > 20% of total host BW time are qualified.
-uint32_t DcnEventsProcessor::NumCollectivesQualified(
-    const std::vector<uint64_t>& latencies) {
-  uint32_t num_collectives_qualified = 0;
-  // Allow for 1 line to display stragglers of non-qualified collectives.
-  uint32_t max_collectives = kMaxCollectivesToDisplay - 1;
-  for (const auto& lat : latencies) {
-    if (lat < host_dcn_bursts_.TotalLatency() * 0.05) {
-      return num_collectives_qualified;
-    } else if (lat < host_dcn_bursts_.TotalLatency() * 0.2 &&
-               num_collectives_qualified >= (max_collectives / 2)) {
-      return num_collectives_qualified;
-    } else if (num_collectives_qualified >= max_collectives) {
-      return num_collectives_qualified;
-    } else {
-      num_collectives_qualified++;
-    }
-  }
-  return latencies.size();
-}
-
-// Decide which collectives to display in detail (dedicated line) and which
-// not (shared line for stragglers).
-// Order collectives by burst latency, then qualify the top ones based on
-// the NumCollectivesQualified function.
-void DcnEventsProcessor::QualifyCollectives() {
-  for (auto tpu_idx = 0; tpu_idx < num_tpu_tensor_cores_; tpu_idx++) {
-    std::vector<uint64_t> latency_to_order;
-    latency_to_order.reserve(tpu_collective_bursts_[tpu_idx].size());
-    for (const auto& col_info : tpu_collective_bursts_[tpu_idx]) {
-      latency_to_order.emplace_back(col_info.second.TotalLatency());
-    }
-    std::sort(latency_to_order.begin(), latency_to_order.end(),
-              std::greater<uint64_t>());
-    uint32_t num_collectives_qualified =
-        NumCollectivesQualified(latency_to_order);
-    if (num_collectives_qualified > 0) {
-      uint32_t min_latency_to_qualify =
-          latency_to_order[num_collectives_qualified - 1];
-      uint32_t col_num = 0;
-      for (auto& col_info : tpu_collective_bursts_[tpu_idx]) {
-        if (col_info.second.TotalLatency() >= min_latency_to_qualify) {
-          col_info.second.SetToDisplay(true);
-          if (++col_num == kMaxCollectivesToDisplay - 1) break;
-        }
-      }
-    }
-  }
-}
-
-void DcnEventsProcessor::GenerateBursts() {
-  host_dcn_bursts_.CreateBursts(host_ts_map_);
-  host_dcn_bursts_.SetToDisplay(true);
-
-  for (auto tpu_idx = 0; tpu_idx < num_tpu_tensor_cores_; tpu_idx++) {
-    for (const auto& col_info : tpu_collective_ts_map_[tpu_idx]) {
-      tpu_collective_bursts_[tpu_idx][col_info.first].CreateBursts(
-          tpu_collective_ts_map_[tpu_idx][col_info.first]);
-    }
-  }
-  QualifyCollectives();
-}
-
-void DcnEventsProcessor::ProcessReceiveMessages(const XPlaneVisitor& plane) {
-  plane.ForEachLine([&](const XLineVisitor& line) {
-    uint32_t recv_msg_id = megascale_msg_[kMegaScaleDcnReceive];
-    line.ForEachEvent([&](const XEventVisitor& event) {
-      if (event.Id() == recv_msg_id) {
-        DcnMessage dcn_message = GetDcnMessageFromXEvent(event);
-        // TODO(emizan): Report invalid and clock skew messages somehow.
-        // TODO(emizan): Bring back loopback messages when MSXLA fixes them.
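// Editor's aside on NumCollectivesQualified() above (illustrative numbers,
// not from the patch): assuming kMaxCollectivesToDisplay == 9, max_collectives
// is 8 and the mid-tier cap (max_collectives / 2) is 4. With a host
// TotalLatency() of 1000us the thresholds are 50us (5%) and 200us (20%), so
// sorted latencies {500, 300, 150, 140, 130, ...}us qualify as follows:
//   500, 300 -> above 200us, qualified unconditionally
//   150, 140 -> in [50us, 200us), qualified while fewer than 4 are taken
//   130      -> in [50us, 200us), but 4 already qualified -> stop; result: 4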
- if (dcn_message.validity_info == DCN_MESSAGE_VALID) { - GenerateTimestampEvents(dcn_message); - } - received_messages_.emplace_back(std::move(dcn_message)); - } - }); - }); - GenerateBursts(); -} - -absl::string_view DcnEventsProcessor::GetBwInfo(bool is_per_tpu, - const DcnBurst& burst, - float& burst_mean_bw, - float& burst_bw_utilization) { - absl::string_view bw_level; - uint32_t bw_divider = 1; - burst_mean_bw = static_cast(burst.burst_size_bytes) / - (burst.end_timestamp_ns - burst.start_timestamp_ns); - if (is_per_tpu) { - bw_divider = num_tpu_tensor_cores_; - if (is_megacore_) { - bw_divider /= 2; - } - } - // Have 3 BW categories (low/med/high) to limit the amount of colors in the - // trace viewer - if (burst_mean_bw < kLimitLowHostDcnBw / bw_divider) { - bw_level = "Low BW"; - } else if (burst_mean_bw < kLimitMedHostDcnBw / bw_divider) { - bw_level = "Med BW"; - } else { - bw_level = "High BW"; - } - burst_bw_utilization = burst_mean_bw / (kMaxHostDcnBw / bw_divider); - return bw_level; -} - -void DcnEventsProcessor::AddHostDcnTrafficToXPlane(XPlane* host_xplane) { - if (!host_dcn_bursts_.ToDisplay()) return; - XPlaneBuilder plane_builder(host_xplane); - XLineBuilder line = - plane_builder.GetOrCreateLine(LineIdType::kDcnHostTraffic); - line.SetNameIfEmpty("DCN Host Bandwidth"); - line.SetTimestampNs(0); - XStatMetadata* bw_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Bandwidth (GBytes/sec)"); - XStatMetadata* bw_util_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Bandwidth Utilization"); - XStatMetadata* num_msg_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Total Messages"); - XStatMetadata* max_overlap_msg_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Max Overlapping Messages"); - XStatMetadata* avg_msg_size_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Average Message Size (Bytes)"); - for (const auto& host_burst : host_dcn_bursts_.GetBursts()) { - float burst_mean_bw, bw_utilization; - absl::string_view bw_level = - GetBwInfo(false, host_burst, burst_mean_bw, bw_utilization); - XEventMetadata* event_metadata = - plane_builder.GetOrCreateEventMetadata(bw_level); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetNs(host_burst.start_timestamp_ns); - event.SetDurationNs(host_burst.end_timestamp_ns - - host_burst.start_timestamp_ns); - - // Using std::string to limit number of decimals. 
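// Editor's aside (not part of the patch): because GetBwInfo() above divides
// bytes by nanoseconds, burst_mean_bw is numerically in GBytes/sec, which is
// why kLimitLowHostDcnBw (4.17), kLimitMedHostDcnBw (8.34) and kMaxHostDcnBw
// (12.5) read directly as GB/s. For example, a host burst of 1,000,000 bytes
// over 200,000 ns yields 5.0 GB/s -> "Med BW", with bandwidth utilization
// 5.0 / 12.5 = 0.4.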
- event.ParseAndAddStatValue(*bw_stat_metadata, - std::to_string(burst_mean_bw)); - event.ParseAndAddStatValue(*bw_util_stat_metadata, - std::to_string(bw_utilization)); - event.AddStatValue(*num_msg_stat_metadata, host_burst.num_messages); - event.AddStatValue(*max_overlap_msg_stat_metadata, - host_burst.max_overlapping_messages); - uint32_t avg_message_size = - host_burst.burst_size_bytes / host_burst.num_messages; - event.AddStatValue(*avg_msg_size_stat_metadata, avg_message_size); - } -} - -void DcnEventsProcessor::AddUnqualifiedCollectivesToXPlane( - XPlaneBuilder& plane_builder, uint32_t tpu_idx) { - XLineBuilder line = - plane_builder.GetOrCreateLine(LineIdType::kDcnCollectiveTrafficMax); - line.SetNameIfEmpty("Remaining collectives"); - line.SetTimestampNs(0); - for (const auto& col_item : tpu_collective_bursts_[tpu_idx]) { - if (col_item.second.ToDisplay()) continue; - for (const auto& col_burst : col_item.second.GetBursts()) { - XEventMetadata* straggler_event_metadata = - plane_builder.GetOrCreateEventMetadata(col_item.first); - uint32_t stragglers_processed = 0; - XStatMetadata* straggler_src_slice_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Source slice"); - XStatMetadata* straggler_duration_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Duration ns"); - XStatMetadata* straggler_send_time_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Send timestamp ns"); - XStatMetadata* straggler_recv_time_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Recv timestamp ns"); - for (const auto& straggler : col_burst.stragglers) { - XEventBuilder straggler_event = - line.AddEvent(*straggler_event_metadata); - straggler_event.SetOffsetNs(straggler.end_timestamp_ns - 10000); - straggler_event.SetDurationNs(10000); - straggler_event.AddStatValue(*straggler_src_slice_stat_metadata, - straggler.src_slice_id); - straggler_event.AddStatValue(*straggler_duration_ns_stat_metadata, - straggler.duration_ns); - straggler_event.AddStatValue( - *straggler_send_time_ns_stat_metadata, - straggler.end_timestamp_ns - straggler.duration_ns); - straggler_event.AddStatValue(*straggler_recv_time_ns_stat_metadata, - straggler.end_timestamp_ns); - if (++stragglers_processed >= col_burst.num_messages) break; - } - } - } -} - -void DcnEventsProcessor::AddQualifiedCollectivesToXPlane( - XPlaneBuilder& plane_builder, uint32_t tpu_idx) { - uint32_t total_collectives = 0; - for (const auto& col_item : tpu_collective_bursts_[tpu_idx]) { - // Skip collectives not enabled for display. 
- if (!col_item.second.ToDisplay()) continue; - const std::string& col_name = col_item.first; - XLineBuilder line = plane_builder.GetOrCreateLine( - LineIdType::kDcnCollectiveTraffic + total_collectives++); - line.SetNameIfEmpty(col_name); - line.SetTimestampNs(0); - XStatMetadata* bw_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Bandwidth (GBytes/sec)"); - XStatMetadata* bw_util_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Bandwidth Utilization"); - XStatMetadata* num_msg_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Total Messages"); - XStatMetadata* max_overlap_msg_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Max Overlapping Messages"); - XStatMetadata* avg_msg_size_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Average Message Size (Bytes)"); - XStatMetadata* straggler_details_metadata = - plane_builder.GetOrCreateStatMetadata("Straggler info:"); - XStatMetadata* straggler_src_slice_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Source slice"); - XStatMetadata* straggler_duration_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Duration ns"); - XStatMetadata* straggler_send_time_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Send timestamp ns"); - XStatMetadata* straggler_recv_time_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Recv timestamp ns"); - for (const auto& col_burst : col_item.second.GetBursts()) { - float burst_mean_bw, bw_utilization; - absl::string_view bw_level = - GetBwInfo(true, col_burst, burst_mean_bw, bw_utilization); - XEventMetadata* event_metadata = - plane_builder.GetOrCreateEventMetadata(bw_level); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetNs(col_burst.start_timestamp_ns); - event.SetDurationNs(col_burst.end_timestamp_ns - - col_burst.start_timestamp_ns); - event.ParseAndAddStatValue(*bw_stat_metadata, - std::to_string(burst_mean_bw)); - event.ParseAndAddStatValue(*bw_util_stat_metadata, - std::to_string(bw_utilization)); - event.AddStatValue(*num_msg_stat_metadata, col_burst.num_messages); - event.AddStatValue(*max_overlap_msg_stat_metadata, - col_burst.max_overlapping_messages); - event.AddStatValue(*avg_msg_size_stat_metadata, - col_burst.burst_size_bytes / col_burst.num_messages); - // Add straggler info. - XEventMetadata* straggler_event_metadata = - plane_builder.GetOrCreateEventMetadata("Straggler"); - uint32_t stragglers_processed = 0; - std::string straggler_details = "Stragglers:\n"; - for (const auto& straggler : col_burst.stragglers) { - // Add an event for the last straggler - if (straggler.end_timestamp_ns == col_burst.end_timestamp_ns) { - XEventBuilder straggler_event = - line.AddEvent(*straggler_event_metadata); - straggler_event.SetOffsetNs(straggler.end_timestamp_ns - - straggler.duration_ns); - straggler_event.SetDurationNs(straggler.duration_ns); - straggler_event.AddStatValue(*straggler_src_slice_stat_metadata, - straggler.src_slice_id); - straggler_event.AddStatValue(*straggler_duration_ns_stat_metadata, - straggler.duration_ns); - straggler_event.AddStatValue( - *straggler_send_time_ns_stat_metadata, - straggler.end_timestamp_ns - straggler.duration_ns); - straggler_event.AddStatValue(*straggler_recv_time_ns_stat_metadata, - straggler.end_timestamp_ns); - } - // Add text metadata for all stragglers. 
-        straggler_details +=
-            " Src slice: " + std::to_string(straggler.src_slice_id) +
-            " -- Duration (ns): " + std::to_string(straggler.duration_ns) +
-            " -- [Send Timestamp, Recv Timestamp]: [" +
-            std::to_string(straggler.end_timestamp_ns - straggler.duration_ns) +
-            ", " + std::to_string(straggler.end_timestamp_ns) + "]\n";
-        if (++stragglers_processed >= col_burst.num_messages) break;
-      }
-      event.AddStatValue(*straggler_details_metadata, straggler_details);
-    }
-  }
-}
-
-void DcnEventsProcessor::AddTpuCollectiveDcnTrafficToXPlane(
-    XPlane* device_xplane) {
-  XPlaneBuilder plane_builder(device_xplane);
-  auto tpu = tsl::profiler::GetTensorCoreId(plane_builder.Name());
-  if (!tpu.has_value()) return;
-  uint32_t tpu_idx = FindTpuIdx(tpu.value());
-  AddQualifiedCollectivesToXPlane(plane_builder, tpu_idx);
-  AddUnqualifiedCollectivesToXPlane(plane_builder, tpu_idx);
-}
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/dcn_analysis.h b/tensorflow/core/profiler/convert/dcn_analysis.h
deleted file mode 100644
index d17cfc9f31764a..00000000000000
--- a/tensorflow/core/profiler/convert/dcn_analysis.h
+++ /dev/null
@@ -1,227 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DCN_ANALYSIS_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_DCN_ANALYSIS_H_
-
-#include <array>
-#include <cstdint>
-#include <map>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/log/log.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/xplane_builder.h"
-#include "xla/tsl/profiler/utils/xplane_visitor.h"
-#include "tensorflow/core/profiler/convert/dcn_utils.h"
-
-namespace tensorflow {
-namespace profiler {
-
-// Structure representing a DcnMessage using two entries:
-// one for the start of the message and one for the end.
-struct TimestampEvent {
-  uint64_t timestamp_ns;  // TraceMe logging timestamp
-  uint64_t duration_ns;   // 0 for start of message, duration for end of message
-  int32_t message_diff;   // +1/-1 for start/end of message.
-                          // Makes handling 0-sized messages easier and is
-                          // convenient for the burst generation algorithm.
-  size_t size_diff;       // +size/-size for start/end of message.
-  int32_t src_slice_id;   // Source slice for message, used for stragglers
-};
-
-// We use a multimap since TimestampEvents will be kept ordered, and we need
-// separate entries for events that may happen at exactly the same time.
-typedef std::multimap<uint64_t, std::shared_ptr<TimestampEvent>> TimestampMap;
-typedef absl::flat_hash_map<std::string, TimestampMap> CollectiveTimestampMap;
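// Editor's note (not part of the original header): the mapped type above is a
// shared_ptr because GenerateTimestampEvents() in dcn_analysis.cc earlier in
// this patch inserts the *same* TimestampEvent into two maps -- the host-wide
// host_ts_map_ and the per-collective map of the destination TPU:
//
//   auto entry = std::make_pair(ts, std::shared_ptr<TimestampEvent>(...));
//   host_ts_map_.insert(entry);                                 // host view
//   tpu_collective_ts_map_[tpu_idx][collective].insert(entry);  // TPU view
//
// And it is a *multimap* so that two messages starting or ending at exactly
// the same timestamp both keep their entries.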
-
-// Straggler messages. These are shown at the end of the bursts they belong to.
-struct Straggler {
-  uint64_t duration_ns;       // Message duration in ns
-  uint64_t end_timestamp_ns;  // End of the message. For the last straggler,
-                              // this will be the end of the burst.
-  size_t size_bytes;          // Size of the message in bytes
-  int32_t src_slice_id;       // Source slice of the message
-  // TODO(emizan) Add host info.
-};
-
-static constexpr uint32_t kMaxStragglersPerBurst = 4;
-
-// DCN burst description.
-// A burst is defined as a period of time during which there is at least one
-// message in the network. Since DCN traffic is bursty, this structure is
-// convenient for summarizing 100K+ messages in a few tens of bursts.
-// Burst scope is flexible. In this analysis we have per-host bursts, which
-// include messages arriving on a single host independent of sender/target
-// TPU and collective. We also have per-collective/TPU bursts, which include
-// messages for a single collective+TPU combination.
-struct DcnBurst {
-  uint64_t start_timestamp_ns;        // Beginning of burst in ns
-  uint64_t end_timestamp_ns;          // End of burst in ns
-  uint64_t burst_size_bytes;          // Total number of bytes in burst
-  uint64_t num_messages;              // Messages in burst
-  uint64_t max_overlapping_messages;  // Max overlapping messages in burst
-  // Buffer of stragglers in a burst. Contains the last few messages in a
-  // burst.
-  std::array<Straggler, kMaxStragglersPerBurst> stragglers;
-};
-
-// Class with functionality to generate DcnBursts out of TimestampEvents.
-// Burst creation is a non-trivial state machine.
-class DcnBurstManager {
- public:
-  DcnBurstManager() = default;
-  uint64_t TotalLatency() const { return total_latency_; }
-  void SetToDisplay(bool to_display) { to_display_ = to_display; }
-  bool ToDisplay() const { return to_display_; }
-  const std::vector<DcnBurst> &GetBursts() const { return bursts_; }
-
-  // Run burst state machine creation out of the timestamp map.
-  void CreateBursts(const TimestampMap &tm_events);
-  // For debugging purposes.
-  void PrintBursts() {
-    for (const auto &burst : bursts_) {
-      LOG(INFO) << burst.start_timestamp_ns << " " << burst.end_timestamp_ns
-                << " " << burst.num_messages << " " << burst.burst_size_bytes
-                << " " << burst.max_overlapping_messages;
-    }
-  }
-
- private:
-  std::vector<DcnBurst> bursts_;  // Bursts created by this manager
-  uint64_t total_latency_ = 0;    // Total latency of all bursts created
-  // Used to see if bursts will be displayed
-  bool to_display_ = false;  // Set to true to enable burst display
-
-  int32_t active_burst_messages_;  // Used by burst creation state machine.
-  DcnBurst active_burst_;          // Active burst in creation
-  uint32_t straggler_idx_;
-
-  // Initializes state machine when a new burst is detected.
-  void ResetBurstState();
-};
-
-typedef absl::flat_hash_map<std::string, DcnBurstManager>
-    CollectiveBurstManager;
-
-class DcnEventsProcessor {
- public:
-  DcnEventsProcessor() = delete;
-  DcnEventsProcessor(uint32_t num_tpu_tensor_cores, bool is_megacore);
-
-  uint32_t NumTpuTensorCores() const { return num_tpu_tensor_cores_; }
-  bool IsMegacore() const { return is_megacore_; }
-
-  // Populates available megascale messages from event metadata.
-  void SetupMessageInfo(const tsl::profiler::XPlaneVisitor &plane);
-
-  std::optional<int32_t> MegaScaleMessageId(absl::string_view msg_name) const {
-    auto iter = megascale_msg_.find(msg_name);
-    if (iter != megascale_msg_.end()) {
-      return iter->second;
-    }
-    return std::nullopt;
-  }
-
-  uint32_t NumReceivedMessages() const { return received_messages_.size(); }
-  const tensorflow::profiler::DcnMessage &GetMessage(uint32_t i) const {
-    return received_messages_[i];
-  }
-
-  // Checks if messages with a given event name have been found in event
-  // metadata.
-  bool HasDcnMessages(absl::string_view msg_name) const {
-    return (megascale_msg_.find(msg_name) != megascale_msg_.end());
-  }
-
-  const TimestampMap &HostTsMap() const { return host_ts_map_; }
-  const std::vector<DcnBurst> &GetHostBursts() const {
-    return host_dcn_bursts_.GetBursts();
-  }
-
-  // Main entry point: processes receive messages and calls the other
-  // functions that generate timestamp events and bursts.
-  void ProcessReceiveMessages(const tsl::profiler::XPlaneVisitor &plane);
-
-  // Update XPlanes using DCN traffic info.
-  void AddHostDcnTrafficToXPlane(tsl::profiler::XPlane *host_xplane);
-  void AddTpuCollectiveDcnTrafficToXPlane(tsl::profiler::XPlane *device_xplane);
-
- private:
-  // Tensor cores and megacore flag for this host. DCN messages are sent to a
-  // TPU chip, so we need to know the number of tensor cores and whether
-  // megacore is used to map DCN traffic to the proper tensor core.
-  const uint32_t num_tpu_tensor_cores_;
-  const bool is_megacore_;
-
-  // Used for visualization of BW and computation of BW utilization.
-  static constexpr float kLimitLowHostDcnBw = 4.17;
-  static constexpr float kLimitMedHostDcnBw = 8.34;
-  static constexpr float kMaxHostDcnBw = 12.5;
-
-  std::vector<absl::string_view> registered_dcn_messages_;
-
-  // Available megascale messages for this trace.
-  absl::flat_hash_map<absl::string_view, int32_t> megascale_msg_;
-
-  std::vector<DcnMessage> received_messages_;
-
-  // TimestampMaps for messages that arrive at this host
-  // and for messages of distinct collectives going to different TPUs.
-  TimestampMap host_ts_map_;
-  std::vector<CollectiveTimestampMap> tpu_collective_ts_map_;
-
-  // DcnBurstManagers for bursts that arrive at this host
-  // and for bursts from distinct collectives going to different TPUs.
-  DcnBurstManager host_dcn_bursts_;
-  std::vector<CollectiveBurstManager> tpu_collective_bursts_;
-
-  // Find the TPU index a DCN message goes to.
-  uint32_t FindTpuIdx(int tpu);
-
-  // Generates BW info to display in the trace viewer.
-  // This includes the trace event BW level string, mean BW per burst and
-  // utilization.
-  absl::string_view GetBwInfo(bool is_per_tpu, const DcnBurst &burst,
-                              float &burst_mean_bw,
-                              float &burst_bw_utilization);
-
-  // Qualify collectives to display on the trace viewer.
-  // Qualified collectives are given a dedicated line, while the rest
-  // share a single line for their stragglers.
-  uint32_t NumCollectivesQualified(const std::vector<uint64_t> &latencies);
-  void QualifyCollectives();
-  // Export collective DCN activity to the trace viewer.
-  void AddQualifiedCollectivesToXPlane(
-      tsl::profiler::XPlaneBuilder &plane_builder, uint32_t tpu_idx);
-  void AddUnqualifiedCollectivesToXPlane(
-      tsl::profiler::XPlaneBuilder &plane_builder, uint32_t tpu_idx);
-
-  // Create timestamp events for every message.
-  void GenerateTimestampEvents(
-      const tensorflow::profiler::DcnMessage &dcn_message);
-  // For debugging purposes.
-  void PrintTimestampEvents();
-  // Generate bursts (host and TPU/collective) from timestamp events.
-  void GenerateBursts();
-};
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_DCN_ANALYSIS_H_
diff --git a/tensorflow/core/profiler/convert/dcn_analysis_test.cc b/tensorflow/core/profiler/convert/dcn_analysis_test.cc
deleted file mode 100644
index b71a583bf26d65..00000000000000
--- a/tensorflow/core/profiler/convert/dcn_analysis_test.cc
+++ /dev/null
@@ -1,363 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/dcn_analysis.h" - -#include -#include -#include - -#include -#include -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/xplane_builder.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/profiler/convert/dcn_utils.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -using tensorflow::profiler::DCN_MESSAGE_INVALID_BAD_KEY; -using tensorflow::profiler::DCN_MESSAGE_INVALID_CLOCK_SKEW; -using tensorflow::profiler::DCN_MESSAGE_VALID; -using tensorflow::profiler::DCN_MESSAGE_VALID_LOOPBACK; -using ::testing::FieldsAre; -using tsl::profiler::kMegaScaleDcnReceive; -using tsl::profiler::kMegaScaleDcnSend; -using tsl::profiler::XEventBuilder; -using tsl::profiler::XEventMetadata; -using tsl::profiler::XLineBuilder; -using tsl::profiler::XPlane; -using tsl::profiler::XPlaneBuilder; -using tsl::profiler::XPlaneVisitor; -using tsl::profiler::XSpace; - -TEST(DcnAnalysis, SetupMessageInfoTest) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder host_trace_builder(host_trace); - - XEventMetadata *event_metadata_1 = - host_trace_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - XEventMetadata *event_metadata_2 = - host_trace_builder.GetOrCreateEventMetadata(2); - event_metadata_2->set_name(std::string(kMegaScaleDcnSend)); - - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor dcn_events_processor(/*num_tpu_tensor_cores*/ 4, - /*is_megacore*/ false); - dcn_events_processor.SetupMessageInfo(plane); - ASSERT_FALSE(dcn_events_processor.HasDcnMessages(kMegaScaleDcnSend)); - ASSERT_TRUE(dcn_events_processor.HasDcnMessages(kMegaScaleDcnReceive)); - ASSERT_FALSE(dcn_events_processor.HasDcnMessages("Another Message")); - ASSERT_EQ(dcn_events_processor.MegaScaleMessageId(kMegaScaleDcnReceive), 1); - ASSERT_EQ(dcn_events_processor.MegaScaleMessageId(kMegaScaleDcnSend), - std::nullopt); -} - -// Test processing of valid messages and that all of them are received. 
-TEST(DcnAnalysis, CreateMessageTestValidMessages) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder xplane_builder(host_trace); - - XEventMetadata *event_metadata_1 = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - - XLineBuilder xline_builder_0 = xplane_builder.GetOrCreateLine(0); - XLineBuilder xline_builder_1 = xplane_builder.GetOrCreateLine(1); - - // 1st event - XEventBuilder event_builder = xline_builder_0.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(100000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), - "all-reduce.273_312"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 2); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 1); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), 24); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 50); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 32768); - - // 2nd event, same line - event_builder = xline_builder_0.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(175000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), - "super-collective.1234"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 112); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 1); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 34); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 2); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), 4); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 50); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 1); - - // 3rd event event, new line, no chunk/loop index - event_builder = xline_builder_1.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(150000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), "super-collective"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 9); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 0); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 75); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 10); - - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor 
dcn_events_processor(4, false); - dcn_events_processor.SetupMessageInfo(plane); - dcn_events_processor.ProcessReceiveMessages(plane); - - ASSERT_EQ(dcn_events_processor.NumReceivedMessages(), 3); - EXPECT_THAT(dcn_events_processor.GetMessage(0), - FieldsAre("all-reduce.273_312", /* collective name */ - 2, 3, 1, 3, /* slice_src, tpu_src, slice_dst, tpu_dst */ - /* start_timestamp_ns, end_timestamp_ns, duration_us */ - 50000, 100000, 50, - /* size_bytes, chunk_id, loop_index_id */ - 32768, 0, 24, - /* validity_info */ - DCN_MESSAGE_VALID)); - EXPECT_THAT(dcn_events_processor.GetMessage(1), - FieldsAre("super-collective.1234", /* collective name */ - /* slice_src, tpu_src, slice_dst, tpu_dst */ - 112, 1, 34, 2, - /* start_timestamp_ns. end_timestamp_ns, duration_us */ - 125000, 175000, 50, - /* size_bytes, chunk_id, loop_index_id */ - 1, 4, 0, - /* validity_info */ - DCN_MESSAGE_VALID)); - EXPECT_THAT( - dcn_events_processor.GetMessage(2), - FieldsAre("super-collective", /* collective name */ - 9, 3, 0, 0, /* slice_src, tpu_src, slice_dst, tpu_dst */ - 75000, 150000, /* start_timestamp_ns. end_timestamp_ns */ - 75, /* duration_us */ - 10, -1, -1, /* size_bytes, chunk_id, loop_index_id */ - /* validity_info */ - DCN_MESSAGE_VALID)); - TimestampMap host_ts_map = dcn_events_processor.HostTsMap(); - ASSERT_EQ(host_ts_map.size(), 6); - for (const auto &ts_map_item : host_ts_map) { - ASSERT_EQ(ts_map_item.first, ts_map_item.second->timestamp_ns); - if (ts_map_item.first == 50000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 0); - ASSERT_EQ(ts_map_item.second->message_diff, 1); - ASSERT_EQ(ts_map_item.second->size_diff, 32768); - } else if (ts_map_item.first == 125000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 0); - ASSERT_EQ(ts_map_item.second->message_diff, 1); - ASSERT_EQ(ts_map_item.second->size_diff, 1); - } else if (ts_map_item.first == 75000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 0); - ASSERT_EQ(ts_map_item.second->message_diff, 1); - ASSERT_EQ(ts_map_item.second->size_diff, 10); - } else if (ts_map_item.first == 100000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 50000); - ASSERT_EQ(ts_map_item.second->message_diff, -1); - ASSERT_EQ(ts_map_item.second->size_diff, -32768); - } else if (ts_map_item.first == 175000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 50000); - ASSERT_EQ(ts_map_item.second->message_diff, -1); - ASSERT_EQ(ts_map_item.second->size_diff, -1); - } else if (ts_map_item.first == 150000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 75000); - ASSERT_EQ(ts_map_item.second->message_diff, -1); - ASSERT_EQ(ts_map_item.second->size_diff, -10); - } else { - FAIL() << "Unexpected timestamp entry."; - } - } - const std::vector &host_bursts = - dcn_events_processor.GetHostBursts(); - ASSERT_EQ(host_bursts.size(), 1); - ASSERT_EQ(host_bursts[0].num_messages, 3); - ASSERT_EQ(host_bursts[0].start_timestamp_ns, 50000); - ASSERT_EQ(host_bursts[0].end_timestamp_ns, 175000); - ASSERT_EQ(host_bursts[0].burst_size_bytes, 32779); - ASSERT_EQ(host_bursts[0].max_overlapping_messages, 2); -} - -// Loopback message test, currently interpreted as valid. 
-TEST(DcnAnalysis, CreateLoopBackMessageTest) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder xplane_builder(host_trace); - - XEventMetadata *event_metadata_1 = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - - XLineBuilder xline_builder = xplane_builder.GetOrCreateLine(0); - XEventBuilder event_builder = xline_builder.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(5000000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), "all-gather.1234"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 2); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 2); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 1); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), 4); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), 40); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 1000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 1000); - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor dcn_events_processor(4, false); - dcn_events_processor.SetupMessageInfo(plane); - dcn_events_processor.ProcessReceiveMessages(plane); - ASSERT_EQ(dcn_events_processor.NumReceivedMessages(), 1); - EXPECT_THAT(dcn_events_processor.GetMessage(0), - FieldsAre("all-gather.1234", /* collective name */ - 2, 3, 2, 1, /* slice_src, tpu_src, slice_dst, tpu_dst */ - /* start_timestamp_ns. end_timestamp_ns, duration_us */ - 4000000, 5000000, 1000, - /* size_bytes, chunk_id, loop_index_id */ - 1000, 4, 40, - /* validity_info */ - DCN_MESSAGE_VALID_LOOPBACK)); -} - -// Zero duration message, this is due to a bug or clock skew between source -// and destination. Any analysis will just cause confusion, mark it as invalid. 
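// Editor's aside, not part of the original test file: SetMessageValidity()
// (in dcn_utils.cc later in this patch) applies its checks in a fixed order,
// which this hypothetical, abridged mirror makes explicit:
DcnMessageValidity ClassifyForIllustration(const DcnMessage& m) {
  if (m.collective_name.empty() || m.slice_src == -1 || m.tpu_src == -1 ||
      m.slice_dst == -1 || m.tpu_dst == -1) {
    return DCN_MESSAGE_INVALID_BAD_KEY;  // undecodable keys are checked first
  }
  if (m.duration_us == 0) return DCN_MESSAGE_INVALID_CLOCK_SKEW;
  if (m.slice_src == m.slice_dst) return DCN_MESSAGE_VALID_LOOPBACK;
  return DCN_MESSAGE_VALID;
}
// Because clock skew is checked before loopback, a zero-duration message is
// classified as DCN_MESSAGE_INVALID_CLOCK_SKEW even when it is also loopback,
// which is the behavior the next test exercises.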
-TEST(DcnAnalysis, CreateZeroDurationMessageTest) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder xplane_builder(host_trace); - - XEventMetadata *event_metadata_1 = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - - XLineBuilder xline_builder = xplane_builder.GetOrCreateLine(0); - XEventBuilder event_builder = xline_builder.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(20000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), - "all-reduce.273_312"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 2); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 1); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 1); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), 25); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 512); - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor dcn_events_processor(4, false); - dcn_events_processor.SetupMessageInfo(plane); - dcn_events_processor.ProcessReceiveMessages(plane); - EXPECT_THAT( - dcn_events_processor.GetMessage(0), - FieldsAre("all-reduce.273_312", /* collective name */ - 2, 3, 1, 1, /* slice_src, tpu_src, slice_dst, tpu_dst */ - 20000, 20000, - 0, /* start_timestamp_ns. end_timestamp_ns, duration_us */ - 512, 0, 25, /* size_bytes, chunk_id, loop_index_id */ - /* validity_info */ - DCN_MESSAGE_INVALID_CLOCK_SKEW)); -} - -// Missing key test, make sure it is invalid and correctly initialized. -TEST(DcnAnalysis, CreateMissingKeyTest) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder xplane_builder(host_trace); - - XEventMetadata *event_metadata_1 = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - - XLineBuilder xline_builder = xplane_builder.GetOrCreateLine(0); - XEventBuilder event_builder = xline_builder.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(50000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 10); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 100); - - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor dcn_events_processor(4, false); - dcn_events_processor.SetupMessageInfo(plane); - dcn_events_processor.ProcessReceiveMessages(plane); - EXPECT_THAT( - dcn_events_processor.GetMessage(0), - FieldsAre("", /* collective name */ - -1, -1, -1, -1, /* slice_src, tpu_src, slice_dst, tpu_dst */ - 40000, 50000, /* start_timestamp_ns. 
end_timestamp_ns, */ - 10, /* duration_us */ - 100, -1, -1, /* size_bytes, chunk_id, loop_index_id */ - /* validity_info */ - DCN_MESSAGE_INVALID_BAD_KEY)); -} - -} // namespace - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc b/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc deleted file mode 100644 index d2b1e7abd59a3b..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc +++ /dev/null @@ -1,92 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h" - -#include - -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" - -namespace tensorflow { -namespace profiler { - -using tensorflow::profiler::DcnSlackAnalysis; -using tensorflow::profiler::DcnSlackSummary; -using tsl::profiler::SafeDivide; - -void DcnSlackAnalysisCombiner::Combine(const DcnSlackAnalysis& slack_analysis) { - for (const auto& slack : slack_analysis.dcn_slack_summary()) { - uint64_t occurrences = slack.occurrences(); - DcnSlackSummary& summary = slack_summary_[slack.rendezvous()]; - summary.set_slack_us(summary.slack_us() + slack.slack_us() * occurrences); - summary.set_observed_duration_us(summary.observed_duration_us() + - slack.observed_duration_us() * - occurrences); - summary.set_stall_duration_us(summary.stall_duration_us() + - slack.stall_duration_us() * occurrences); - summary.set_send_done_duration_us(summary.send_done_duration_us() + - slack.send_done_duration_us() * - occurrences); - summary.set_recv_done_duration_us(summary.recv_done_duration_us() + - slack.recv_done_duration_us() * - occurrences); - summary.set_send_duration_us(summary.send_duration_us() + - slack.send_duration_us() * occurrences); - summary.set_recv_duration_us(summary.recv_duration_us() + - slack.recv_duration_us() * occurrences); - summary.set_host_stall_us(summary.host_stall_us() + - slack.host_stall_us() * occurrences); - summary.set_occurrences(summary.occurrences() + slack.occurrences()); - summary.set_bytes_transmitted_over_network( - slack.bytes_transmitted_over_network()); - summary.set_recv_op_name(slack.recv_op_name()); - summary.set_send_op_name(slack.send_op_name()); - summary.set_transfer_type(slack.transfer_type()); - } -} - -DcnSlackAnalysis DcnSlackAnalysisCombiner::Finalize() { - DcnSlackAnalysis analysis; - for (const auto& [rendezvous, summary] : slack_summary_) { - auto* slack = analysis.add_dcn_slack_summary(); - slack->set_rendezvous(rendezvous); - slack->set_recv_op_name(summary.recv_op_name()); - slack->set_send_op_name(summary.send_op_name()); - slack->set_transfer_type(summary.transfer_type()); - slack->set_slack_us(SafeDivide(summary.slack_us(), summary.occurrences())); - slack->set_observed_duration_us( - 
SafeDivide(summary.observed_duration_us(), summary.occurrences()));
-    slack->set_stall_duration_us(
-        SafeDivide(summary.stall_duration_us(), summary.occurrences()));
-    slack->set_send_done_duration_us(
-        SafeDivide(summary.send_done_duration_us(), summary.occurrences()));
-    slack->set_recv_done_duration_us(
-        SafeDivide(summary.recv_done_duration_us(), summary.occurrences()));
-    slack->set_send_duration_us(
-        SafeDivide(summary.send_duration_us(), summary.occurrences()));
-    slack->set_recv_duration_us(
-        SafeDivide(summary.recv_duration_us(), summary.occurrences()));
-    slack->set_host_stall_us(
-        SafeDivide(summary.host_stall_us(), summary.occurrences()));
-    slack->set_occurrences(summary.occurrences());
-    slack->set_bytes_transmitted_over_network(
-        summary.bytes_transmitted_over_network());
-  }
-
-  return analysis;
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h b/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h
deleted file mode 100644
index f0fc727a62dcc1..00000000000000
--- a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DCN_SLACK_ANALYSIS_COMBINER_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_DCN_SLACK_ANALYSIS_COMBINER_H_
-
-#include <string>
-
-#include "absl/container/flat_hash_map.h"
-#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-
-using tensorflow::profiler::DcnSlackAnalysis;
-using tensorflow::profiler::DcnSlackSummary;
-
-class DcnSlackAnalysisCombiner {
- private:
-  absl::flat_hash_map<std::string, DcnSlackSummary> slack_summary_;
-
- public:
-  // Combines the DCN slack summaries from a DcnSlackAnalysis. A
-  // DcnSlackAnalysis stores average durations, so during the combine phase
-  // the summary accumulates the total duration over all occurrences.
-  // Finalize must be called to get accurate averages back.
-  void Combine(const DcnSlackAnalysis& slack_analysis);
-
-  // Finalizes the DcnSlackSummary by converting total durations to averages.
-  DcnSlackAnalysis Finalize();
-};
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_DCN_SLACK_ANALYSIS_COMBINER_H_
diff --git a/tensorflow/core/profiler/convert/dcn_utils.cc b/tensorflow/core/profiler/convert/dcn_utils.cc
deleted file mode 100644
index 6a457053c30b85..00000000000000
--- a/tensorflow/core/profiler/convert/dcn_utils.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
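Editor's note on the combiner above: Combine()/Finalize() implement a
weighted-average merge. Combine() multiplies each shard's averages back into
totals using the occurrence counts, and Finalize() divides by the combined
occurrence count. A small hypothetical sketch (rendezvous name and numbers
invented; proto setters as used in the files above):

#include "tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h"

tensorflow::profiler::DcnSlackAnalysis MakeShard(double slack_us,
                                                 uint64_t occurrences) {
  tensorflow::profiler::DcnSlackAnalysis analysis;
  auto* summary = analysis.add_dcn_slack_summary();
  summary->set_rendezvous("rendezvous_0");  // hypothetical key
  summary->set_slack_us(slack_us);
  summary->set_occurrences(occurrences);
  return analysis;
}

void CombineExample() {
  tensorflow::profiler::DcnSlackAnalysisCombiner combiner;
  combiner.Combine(MakeShard(/*slack_us=*/10, /*occurrences=*/2));  // total 20
  combiner.Combine(MakeShard(/*slack_us=*/40, /*occurrences=*/1));  // total 40
  // Finalize() divides the accumulated 60us by 3 occurrences, so the combined
  // summary reports slack_us == 20.
  tensorflow::profiler::DcnSlackAnalysis combined = combiner.Finalize();
}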
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/dcn_utils.h" - -#include "absl/strings/match.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using tsl::profiler::MicroToNano; -using tsl::profiler::StatType; -using tsl::profiler::XEventVisitor; -using tsl::profiler::XStatVisitor; - -DcnMessage CreateDcnMessageFromStats(const XEventVisitor& event_visitor) { - DcnMessage dcn_message; - event_visitor.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type()) return; - switch (static_cast(*stat.Type())) { - case StatType::kDcnLabel: { - dcn_message.collective_name = stat.ToString(); - break; - } - case StatType::kDcnSourceSliceId: { - dcn_message.slice_src = stat.IntValue(); - break; - } - case StatType::kDcnSourcePerSliceDeviceId: { - dcn_message.tpu_src = stat.IntValue(); - break; - } - case StatType::kDcnDestinationSliceId: { - dcn_message.slice_dst = stat.IntValue(); - break; - } - case StatType::kDcnDestinationPerSliceDeviceId: { - dcn_message.tpu_dst = stat.IntValue(); - break; - } - case StatType::kDcnChunk: { - dcn_message.chunk_id = stat.IntValue(); - break; - } - case StatType::kDcnLoopIndex: { - dcn_message.loop_index_id = stat.IntValue(); - - break; - } - case StatType::kPayloadSizeBytes: { - dcn_message.size_bytes = stat.IntValue(); - break; - } - case StatType::kDuration: { - dcn_message.duration_us = stat.IntOrUintValue(); - dcn_message.start_timestamp_ns = - event_visitor.TimestampNs() - MicroToNano(dcn_message.duration_us); - dcn_message.end_timestamp_ns = event_visitor.TimestampNs(); - break; - } - default: - break; - } - }); - return dcn_message; -} - -// Analyze message to see if it can be directly processed or it falls under -// corner-case categories, or if there is something wrong with it. -void SetMessageValidity(DcnMessage& dcn_message) { - // Message should not be valid if fields have not been set properly - // The main use of that is to detect unexpected key format changes that do - // not cause crashes. - if (dcn_message.collective_name.empty() || dcn_message.slice_src == -1 || - dcn_message.tpu_src == -1 || dcn_message.slice_dst == -1 || - dcn_message.tpu_dst == -1 || dcn_message.size_bytes == -1) { - dcn_message.validity_info = DCN_MESSAGE_INVALID_BAD_KEY; - } else if (dcn_message.duration_us == 0) { - // Destination timestamp smaller than the source timestamp likely due to - // clock skew - dcn_message.validity_info = DCN_MESSAGE_INVALID_CLOCK_SKEW; - } else if (dcn_message.slice_src == dcn_message.slice_dst) { - // Loopback messages remain on the same host, so they are valid - // even though they should not go through DCN. - // TODO(emizan): Get host/TPU info and check host, not slice. 
-// Analyze a message to see whether it can be processed directly, whether it
-// falls under a corner-case category, or whether there is something wrong
-// with it.
-void SetMessageValidity(DcnMessage& dcn_message) {
-  // A message is not valid if its fields have not been set properly. The main
-  // use of this is to detect unexpected key format changes that do not cause
-  // crashes.
-  if (dcn_message.collective_name.empty() || dcn_message.slice_src == -1 ||
-      dcn_message.tpu_src == -1 || dcn_message.slice_dst == -1 ||
-      dcn_message.tpu_dst == -1 || dcn_message.size_bytes == -1) {
-    dcn_message.validity_info = DCN_MESSAGE_INVALID_BAD_KEY;
-  } else if (dcn_message.duration_us == 0) {
-    // Destination timestamp smaller than the source timestamp, likely due to
-    // clock skew.
-    dcn_message.validity_info = DCN_MESSAGE_INVALID_CLOCK_SKEW;
-  } else if (dcn_message.slice_src == dcn_message.slice_dst) {
-    // Loopback messages remain on the same host, so they are valid even
-    // though they should not go through DCN.
-    // TODO(emizan): Get host/TPU info and check host, not slice.
-    dcn_message.validity_info = DCN_MESSAGE_VALID_LOOPBACK;
-  } else {
-    dcn_message.validity_info = DCN_MESSAGE_VALID;
-  }
-}
-}  // namespace
-
-DcnMessage GetDcnMessageFromXEvent(const XEventVisitor& event_visitor) {
-  DcnMessage dcn_message = CreateDcnMessageFromStats(event_visitor);
-  SetMessageValidity(dcn_message);
-  return dcn_message;
-}
-
-bool IsDcnEvent(const tsl::profiler::XEventVisitor& event) {
-  return absl::StartsWith(event.Name(), "MegaScale:");
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/dcn_utils.h b/tensorflow/core/profiler/convert/dcn_utils.h
deleted file mode 100644
index e0dd3a174df919..00000000000000
--- a/tensorflow/core/profiler/convert/dcn_utils.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DCN_UTILS_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_DCN_UTILS_H_
-
-#include <string>
-
-#include "xla/tsl/profiler/utils/xplane_visitor.h"
-
-namespace tensorflow {
-namespace profiler {
-
-// DCN Message Validity
-enum DcnMessageValidity {
-  // Valid message.
-  DCN_MESSAGE_VALID = 1,
-  // Valid message, but it should not go through DCN, so it should not use BW.
-  DCN_MESSAGE_VALID_LOOPBACK = 2,
-  // Invalid message with 0 duration due to clock skew. Should be ignored.
-  DCN_MESSAGE_INVALID_CLOCK_SKEW = 3,
-  // Message that cannot be decoded. Should be ignored.
-  DCN_MESSAGE_INVALID_BAD_KEY = 4
-};
-
-// Structure representing a DCN event.
-struct DcnMessage {
-  // Unique collective that generated this message; the expected format is
-  // <collective name>_<number>, e.g. all_gather_34.
-  std::string collective_name = "";
-  // Src info
-  // TODO(emizan) Add host info when you figure out how to get it from
-  // slice+tpu.
-  int32_t slice_src = -1;
-  int32_t tpu_src = -1;
-  // Dst info
-  int32_t slice_dst = -1;
-  int32_t tpu_dst = -1;
-  // Timing info in ns. Since MSXLA TraceMe's carry timestamps in us, we need
-  // to multiply by 1000 to get these timestamps.
- uint64_t start_timestamp_ns = 0; - uint64_t end_timestamp_ns = 0; - uint64_t duration_us = 0; - // Size info - size_t size_bytes = 0; - // Chunk and Loop index - int32_t chunk_id = -1; - int32_t loop_index_id = -1; - // Is message valid/invalid and why - DcnMessageValidity validity_info = DCN_MESSAGE_INVALID_BAD_KEY; - // TBD: Add flow events in case you need to connect to other events pointed to - // by MSXLA TraceMe's -}; - -DcnMessage GetDcnMessageFromXEvent( - const tsl::profiler::XEventVisitor& event_visitor); - -// Check if the XEventVisitor is a DCN Message -bool IsDcnEvent(const tsl::profiler::XEventVisitor& event); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DCN_UTILS_H_ diff --git a/tensorflow/core/profiler/convert/dcn_utils_test.cc b/tensorflow/core/profiler/convert/dcn_utils_test.cc deleted file mode 100644 index 8789da9d07b8f8..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_utils_test.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/dcn_utils.h" - -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/xplane_builder.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using tsl::profiler::kMegaScaleDcnReceive; -using tsl::profiler::XEventBuilder; -using tsl::profiler::XEventVisitor; -using tsl::profiler::XLineBuilder; -using tsl::profiler::XPlaneBuilder; -using tsl::profiler::XPlaneVisitor; - -void PopulateXPlane(XPlane &xplane, absl::string_view event_name, int offset, - absl::string_view label, int64_t source_slice_id, - int64_t source_per_slice_device_id, - int64_t destination_slice_id, - int64_t destination_per_slice_device_id, int64_t chunk, - int64_t loop_index, int64_t payload_size, - int64_t duration) { - XPlaneBuilder xplane_builder(&xplane); - - XEventMetadata *event_metadata = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata->set_name(std::string(event_name)); - - XLineBuilder xline_builder = xplane_builder.GetOrCreateLine(0); - XEventBuilder event_builder = xline_builder.AddEvent(*event_metadata); - event_builder.SetOffsetNs(offset); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), label); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), - source_slice_id); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - source_per_slice_device_id); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), - destination_slice_id); - 
event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata(
-                                 "dcn_destination_per_slice_device_id"),
-                             destination_per_slice_device_id);
-  event_builder.AddStatValue(
-      *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), chunk);
-  event_builder.AddStatValue(
-      *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), loop_index);
-  event_builder.AddStatValue(
-      *xplane_builder.GetOrCreateStatMetadata("duration_us"), duration);
-  event_builder.AddStatValue(
-      *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"),
-      payload_size);
-}
-
-TEST(DcnUtilsTest, IsDcnEvent) {
-  XPlane xplane;
-  PopulateXPlane(xplane, kMegaScaleDcnReceive, 0, "test", 0, 0, 0, 0, 0, 0, 0,
-                 0);
-  XLine line = xplane.lines()[0];
-  XPlaneVisitor xplane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane);
-
-  XEventVisitor visitor(&xplane_visitor, &line, &line.events()[0]);
-  EXPECT_TRUE(IsDcnEvent(visitor));
-}
-
-TEST(DcnUtilsTest, IsNotDcnEvent) {
-  XPlane xplane;
-  PopulateXPlane(xplane, "test", 0, "test", 0, 0, 0, 0, 0, 0, 0, 0);
-  XLine line = xplane.lines()[0];
-  XPlaneVisitor xplane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane);
-
-  XEventVisitor visitor(&xplane_visitor, &line, &line.events()[0]);
-  EXPECT_FALSE(IsDcnEvent(visitor));
-}
-
-TEST(DcnUtilsTest, GetDcnMessageFromXEvent) {
-  XPlane xplane;
-  PopulateXPlane(xplane, kMegaScaleDcnReceive, 100000, "all-reduce.273_312", 2,
-                 3, 1, 3, 0, 24, 32768, 50);
-  XPlaneVisitor xplane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane);
-  XEventVisitor visitor(&xplane_visitor, &xplane.lines()[0],
-                        &xplane.lines()[0].events()[0]);
-  EXPECT_THAT(GetDcnMessageFromXEvent(visitor),
-              testing::FieldsAre(
-                  "all-reduce.273_312", /* collective name */
-                  2, 3, 1, 3, /* slice_src, tpu_src, slice_dst, tpu_dst */
-                  /* start_timestamp_ns, end_timestamp_ns, duration_us */
-                  50000, 100000, 50,
-                  /* size_bytes, chunk_id, loop_index_id */
-                  32768, 0, 24,
-                  /* validity_info */
-                  DCN_MESSAGE_VALID));
-}
-
-TEST(DcnUtilsTest, GetDcnMessageFromXEventLoopBack) {
-  XPlane xplane;
-  PopulateXPlane(xplane, kMegaScaleDcnReceive, 5000000, "all-gather.1234", 2, 3,
-                 2, 1, 4, 40, 1000, 1000);
-  XPlaneVisitor xplane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane);
-  XEventVisitor visitor(&xplane_visitor, &xplane.lines()[0],
-                        &xplane.lines()[0].events()[0]);
-  EXPECT_THAT(GetDcnMessageFromXEvent(visitor),
-              testing::FieldsAre(
-                  "all-gather.1234", /* collective name */
-                  2, 3, 2, 1, /* slice_src, tpu_src, slice_dst, tpu_dst */
-                  /* start_timestamp_ns, end_timestamp_ns, duration_us */
-                  4000000, 5000000, 1000,
-                  /* size_bytes, chunk_id, loop_index_id */
-                  1000, 4, 40,
-                  /* validity_info */
-                  DCN_MESSAGE_VALID_LOOPBACK));
-}
-
-}  // namespace
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/duty_cycle_combiner.h b/tensorflow/core/profiler/convert/duty_cycle_combiner.h
deleted file mode 100644
index 74b2e0ebdc9aba..00000000000000
--- a/tensorflow/core/profiler/convert/duty_cycle_combiner.h
+++ /dev/null
@@ -1,72 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_COMBINER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_COMBINER_H_ - -#include - -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Responsible for combining the duty cycle trackers for all cores and chips. -class DutyCycleCombiner { - public: - // Combines the given core tracker with the tracker for the given chip. - // NOTE: The given chip_id should be unique across all chips being combined. - void CombineCore(const DutyCycleTracker& core_tracker, uint32_t chip_id) { - chip_duty_cycle_trackers_[chip_id].Union(core_tracker); - } - - // Combines the given chip tracker with the tracker for other chips. - void CombineChip(const DutyCycleTracker& chip_tracker) { - chip_active_time_ps_ += chip_tracker.GetActiveTimePs(); - chip_idle_time_ps_ += chip_tracker.GetIdleTimePs(); - } - - // Returns the total active time across all chips and cores. - uint64_t GetTotalActiveTimePs() const { - uint64_t total_busy_time_ps = chip_active_time_ps_; - for (const auto& [chip_id, tracker] : chip_duty_cycle_trackers_) { - total_busy_time_ps += tracker.GetActiveTimePs(); - } - return total_busy_time_ps; - } - - // Returns the total idle time across all chips and cores. - uint64_t GetTotalIdleTimePs() const { - uint64_t total_idle_time_ps = chip_idle_time_ps_; - for (const auto& [chip_id, tracker] : chip_duty_cycle_trackers_) { - total_idle_time_ps += tracker.GetIdleTimePs(); - } - return total_idle_time_ps; - } - - private: - absl::flat_hash_map chip_duty_cycle_trackers_; - uint64_t chip_active_time_ps_ = 0; - uint64_t chip_idle_time_ps_ = 0; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_COMBINER_H_ diff --git a/tensorflow/core/profiler/convert/duty_cycle_combiner_test.cc b/tensorflow/core/profiler/convert/duty_cycle_combiner_test.cc deleted file mode 100644 index 6a9e158b43da5b..00000000000000 --- a/tensorflow/core/profiler/convert/duty_cycle_combiner_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/duty_cycle_combiner.h" - -#include -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tsl::profiler::Timespan; - -TEST(DutyCycleAnalysisTest, CombineMultiCoreChipTest) { - DutyCycleTracker core0_tracker; - core0_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - core0_tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - DutyCycleTracker core1_tracker; - core1_tracker.AddInterval(Timespan::FromEndPoints(10, 20), false); - core1_tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); - - DutyCycleCombiner combiner; - combiner.CombineCore(core0_tracker, 0); - combiner.CombineCore(core1_tracker, 0); - - EXPECT_EQ(combiner.GetTotalActiveTimePs(), 20); - EXPECT_EQ(combiner.GetTotalIdleTimePs(), 0); -} - -TEST(DutyCycleAnalysisTest, CombineMultiChipTest) { - DutyCycleTracker chip0_tracker; - chip0_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - chip0_tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - DutyCycleTracker chip1_tracker; - chip1_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - chip1_tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - - DutyCycleCombiner combiner; - combiner.CombineChip(chip0_tracker); - combiner.CombineChip(chip1_tracker); - - EXPECT_EQ(combiner.GetTotalActiveTimePs(), 20); - EXPECT_EQ(combiner.GetTotalIdleTimePs(), 20); -} - -TEST(DutyCycleAnalysisTest, CombineMultiChipAndCoreTest) { - DutyCycleTracker chip0_core0_tracker; - chip0_core0_tracker.AddInterval(Timespan::FromEndPoints(10, 20), false); - chip0_core0_tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); - DutyCycleTracker chip0_core1_tracker; - chip0_core1_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - chip0_core1_tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - DutyCycleTracker chip1_tracker; - chip1_tracker.AddInterval(Timespan::FromEndPoints(15, 25), true); - chip1_tracker.AddInterval(Timespan::FromEndPoints(10, 30), false); - - DutyCycleCombiner combiner; - combiner.CombineCore(chip0_core0_tracker, 0); - combiner.CombineCore(chip0_core1_tracker, 0); - combiner.CombineChip(chip1_tracker); - - EXPECT_EQ(combiner.GetTotalActiveTimePs(), 30); - EXPECT_EQ(combiner.GetTotalIdleTimePs(), 10); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/duty_cycle_tracker.cc b/tensorflow/core/profiler/convert/duty_cycle_tracker.cc deleted file mode 100644 index 96d793e86b31ea..00000000000000 --- a/tensorflow/core/profiler/convert/duty_cycle_tracker.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" - -#include - -#include -#include -#include - -#include "absl/container/btree_set.h" -#include "absl/log/check.h" -#include "xla/tsl/profiler/utils/timespan.h" - -namespace tensorflow { -namespace profiler { - -using tsl::profiler::Timespan; - -DutyCycleTracker::ActiveTimeSpans::const_iterator -DutyCycleTracker::MergeOrInsert(const Timespan& timespan, - ActiveTimeSpans::const_iterator hint) { - DCHECK(hint == active_time_spans_.end() || - hint == active_time_spans_.begin() || - hint->begin_ps() <= timespan.begin_ps()); - ActiveTimeSpans::const_iterator merge_begin = hint; - while (merge_begin != active_time_spans_.end() && - merge_begin->end_ps() < timespan.begin_ps()) { - ++merge_begin; - } - - // timespan is fully contained in an existing timespan. - if (merge_begin != active_time_spans_.end() && - merge_begin->Includes(timespan)) { - return merge_begin; - } - - ActiveTimeSpans::const_iterator merge_end = merge_begin; - while (merge_end != active_time_spans_.end() && - merge_end->begin_ps() <= timespan.end_ps()) { - ++merge_end; - } - if (merge_begin != merge_end) { - Timespan merged = Timespan::FromEndPoints( - std::min(timespan.begin_ps(), merge_begin->begin_ps()), - std::max(timespan.end_ps(), std::prev(merge_end)->end_ps())); - merge_end = active_time_spans_.erase(merge_begin, merge_end); - return active_time_spans_.insert(merge_end, merged); - } else { - // There is no overlap with the existing timespans. - return active_time_spans_.insert(merge_begin, timespan); - } -} - -void DutyCycleTracker::AddInterval(tsl::profiler::Timespan time_span, - bool is_active) { - total_time_span_.ExpandToInclude(time_span); - if (!is_active) { - return; - } - - auto hint = active_time_spans_.lower_bound(time_span); - if (hint != active_time_spans_.begin()) --hint; - MergeOrInsert(time_span, hint); -} - -void DutyCycleTracker::Union(const DutyCycleTracker& other) { - total_time_span_.ExpandToInclude(other.total_time_span_); - if (other.active_time_spans_.empty()) return; - ActiveTimeSpans::const_iterator hint_it = - active_time_spans_.lower_bound(*other.active_time_spans_.begin()); - if (hint_it != active_time_spans_.begin()) --hint_it; - for (const auto& interval : other.active_time_spans_) { - hint_it = MergeOrInsert(interval, hint_it); - } -} - -uint64_t DutyCycleTracker::GetActiveTimePs() const { - uint64_t active_time_ps = 0; - for (const auto& interval : active_time_spans_) { - DCHECK(!interval.Empty()); - active_time_ps += interval.duration_ps(); - } - return active_time_ps; -} - -uint64_t DutyCycleTracker::GetIdleTimePs() const { - return total_time_span_.duration_ps() - GetActiveTimePs(); -} -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/duty_cycle_tracker.h b/tensorflow/core/profiler/convert/duty_cycle_tracker.h deleted file mode 100644 index bf5160d97d3037..00000000000000 --- a/tensorflow/core/profiler/convert/duty_cycle_tracker.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_TRACKER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_TRACKER_H_ - -#include - -#include "absl/container/btree_set.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/timespan.h" - -namespace tensorflow { -namespace profiler { - -// Tracks the active time intervals for a given TPU core. -// Disjoint intervals of time in ps for which this core was active. -class DutyCycleTracker { - public: - DutyCycleTracker() : active_time_spans_() {} - void AddInterval(tsl::profiler::Timespan time_span, bool is_active); - void Union(const DutyCycleTracker& other); - uint64_t GetActiveTimePs() const; - uint64_t GetIdleTimePs() const; - uint64_t GetDurationPs() const { return total_time_span_.duration_ps(); } - double DutyCycle() const { - return tsl::profiler::SafeDivide(GetActiveTimePs(), GetDurationPs()); - } - - private: - struct TimespanComparator { - // Order by increasing begin_ps, then decreasing duration_ps. - bool operator()(const tsl::profiler::Timespan& a, - const tsl::profiler::Timespan& b) const { - return a.begin_ps() < b.begin_ps() || (a.begin_ps() == b.begin_ps() && - a.duration_ps() > b.duration_ps()); - } - }; - using ActiveTimeSpans = - absl::btree_set; - - /** - * Merge or insert the given timespan into the set of active time spans. - * - * @param timespan The timespan to merge or insert. - * @param hint The iterator indicating where to begin the merge search. - * @return The iterator where the timespan was merged or inserted. - */ - ActiveTimeSpans::const_iterator MergeOrInsert( - const tsl::profiler::Timespan& timespan, - ActiveTimeSpans::const_iterator hint); - - ActiveTimeSpans active_time_spans_; - tsl::profiler::Timespan total_time_span_; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_TRACKER_H_ diff --git a/tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc b/tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc deleted file mode 100644 index 2ee0218d986f54..00000000000000 --- a/tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" - -#include - -#include -#include - -#include -#include "absl/log/check.h" -#include "xla/tsl/platform/test_benchmark.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tsl::profiler::Timespan; - -TEST(DutyCycleTrackerTest, NonOverlappingIntervalsTest) { - DutyCycleTracker tracker; - tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - tracker.AddInterval(Timespan::FromEndPoints(30, 40), true); - EXPECT_EQ(tracker.GetActiveTimePs(), 20); - EXPECT_EQ(tracker.GetIdleTimePs(), 10); - EXPECT_EQ(tracker.GetDurationPs(), 30); - EXPECT_NEAR(tracker.DutyCycle(), 0.6666, 0.0001); -} - -TEST(DutyCycleTrackerTest, OverlappingIntervalsTest) { - DutyCycleTracker tracker; - tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - tracker.AddInterval(Timespan::FromEndPoints(30, 40), true); - tracker.AddInterval(Timespan::FromEndPoints(20, 35), true); - EXPECT_EQ(tracker.GetActiveTimePs(), 30); - EXPECT_EQ(tracker.GetIdleTimePs(), 0); - EXPECT_EQ(tracker.GetDurationPs(), 30); - EXPECT_EQ(tracker.DutyCycle(), 1.0); -} - -TEST(DutyCycleTrackerTest, DutyCycleTestWithIncludedIntervals) { - DutyCycleTracker tracker; - tracker.AddInterval(Timespan::FromEndPoints(10, 40), true); - tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); - EXPECT_EQ(tracker.GetActiveTimePs(), 30); - EXPECT_EQ(tracker.GetIdleTimePs(), 0); - EXPECT_EQ(tracker.GetDurationPs(), 30); - EXPECT_EQ(tracker.DutyCycle(), 1.0); -} - -TEST(DutyCycleTrackerTest, UnionTest) { - DutyCycleTracker tracker; - tracker.AddInterval(Timespan::FromEndPoints(0, 10), true); - tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); - - DutyCycleTracker other_tracker; - other_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - other_tracker.AddInterval(Timespan::FromEndPoints(30, 40), true); - - tracker.Union(other_tracker); - EXPECT_EQ(tracker.GetActiveTimePs(), 40); - EXPECT_EQ(tracker.GetIdleTimePs(), 0); - EXPECT_EQ(tracker.GetDurationPs(), 40); -} - -TEST(DutyCycleTrackerTest, OverlappingMixedIntervalsTest) { - DutyCycleTracker tracker; - EXPECT_EQ(tracker.GetActiveTimePs(), 0); - tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - EXPECT_EQ(tracker.GetActiveTimePs(), 10); - EXPECT_EQ(tracker.GetIdleTimePs(), 10); -} - -void BM_DutyCycleTracker_AddInterval(::testing::benchmark::State& state) { - std::vector timespans; - timespans.reserve(state.range(0)); - for (uint64_t i = 0; i < state.range(0); ++i) { - timespans.push_back(Timespan::FromEndPoints(i * 2, i * 2 + 1)); - } - for (auto s : state) { - DutyCycleTracker tracker; - for (const auto& timespan : timespans) { - tracker.AddInterval(timespan, true); - } - } - state.SetItemsProcessed(state.iterations() * timespans.size()); -} - -BENCHMARK(BM_DutyCycleTracker_AddInterval)->Range(1 << 15, 1 << 21); - -void BM_DutyCycleTracker_AddInterval_Merge(::testing::benchmark::State& state) { - std::vector timespans; - timespans.reserve(state.range(0)); - for (uint64_t i = 0; i < state.range(0); ++i) { - timespans.push_back(Timespan::FromEndPoints(i, i + 1)); - } - for (auto s : state) { - DutyCycleTracker tracker; - for (const auto& timespan : timespans) { - tracker.AddInterval(timespan, true); - } - } - state.SetItemsProcessed(state.iterations() * 
timespans.size()); -} - -BENCHMARK(BM_DutyCycleTracker_AddInterval_Merge)->Range(1 << 15, 1 << 21); - -void BM_DutyCycleTracker_Union(::testing::benchmark::State& state) { - DCHECK_GT(state.range(1), 1); - DCHECK_LT(state.range(1), state.range(0)); - DutyCycleTracker tracker_a; - DutyCycleTracker tracker_b; - uint64_t merge_rate = state.range(1); - for (uint64_t i = 0; i < state.range(0); ++i) { - tracker_a.AddInterval(Timespan(i * 2, 1), true); - if (i % merge_rate == 0) { - tracker_b.AddInterval(Timespan(i * 2 + 1, merge_rate * 2 - 1), true); - } - } - for (auto s : state) { - DutyCycleTracker unioned_tracker; - unioned_tracker.Union(tracker_a); - unioned_tracker.Union(tracker_b); - } - state.SetItemsProcessed(state.iterations() * - (state.range(0) + state.range(0) / merge_rate)); -} - -BENCHMARK(BM_DutyCycleTracker_Union)->RangePair(1 << 10, 1 << 16, 2, 10); - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc deleted file mode 100644 index 43351cada4ca7c..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc +++ /dev/null @@ -1,554 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h"
-
-#include <functional>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_print_options.h"
-#include "xla/tsl/platform/statusor.h"
-#ifdef PLATFORM_GOOGLE
-#include "nlohmann/json.hpp"
-#include "tensorflow/compiler/mlir/lite/experimental/google/tooling/hlo_adapter/direct_hlo_to_json_graph_convert.h"
-#endif  // PLATFORM_GOOGLE
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo.pb.h"
-#include "xla/service/hlo_graph_dumper.h"
-#include "xla/tsl/platform/errors.h"
-#include "tensorflow/core/profiler/convert/tool_options.h"
-#include "tensorflow/core/profiler/utils/hlo_module_utils.h"
-#include "xprof/utils/hlo_proto_to_module.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-using ::tsl::StatusOr;
-using ::tsl::errors::InvalidArgument;
-using ::xla::HloComputation;
-using ::xla::HloInstruction;
-using ::xla::HloModule;
-using ::xla::HloPrintOptions;
-using ::xla::HloProto;
-using ::xla::HloRenderOptions;
-using ::xla::RenderedGraphFormat;
-
-constexpr char kCenterNodeKey[] = "centerNode";
-
-void CleanUpHloModuleForGraphviz(HloModule* hlo_module) {
-  // Infeed config is an escaped serialized proto, and the graphviz server
-  // complains about it.
-  for (HloComputation* computation : hlo_module->computations()) {
-    for (HloInstruction* inst : computation->instructions()) {
-      if (inst->opcode() == xla::HloOpcode::kInfeed) {
-        inst->set_infeed_config("");
-      } else if (inst->opcode() == xla::HloOpcode::kOutfeed) {
-        inst->set_outfeed_config("");
-      }
-    }
-  }
-}
-
-#ifdef PLATFORM_GOOGLE
-// Add a custom group node at the graph level; for the center node chosen by
-// the user, set its attributes (e.g. `id`, `name`, or `opcode`) in
-// `graph_json`.
-void AddCenterNodeMetadata(nlohmann::json& graph_json, std::string id,
-                           absl::string_view name, absl::string_view opcode) {
-  nlohmann::json centerGroupNodeAttributes;
-  centerGroupNodeAttributes["name"] = name;
-  centerGroupNodeAttributes["id"] = id;
-  if (!opcode.empty()) {
-    centerGroupNodeAttributes["opcode"] = opcode;
-  }
-  // Follow ModelExplorer's Graph typing: GraphCollectionFromBuiltinAdapters
-  graph_json[0]["subgraphs"][0]["groupNodeAttributes"][kCenterNodeKey] =
-      centerGroupNodeAttributes;
-}
-#endif  // PLATFORM_GOOGLE
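For readers unfamiliar with the JSON shape being edited here, the sketch below shows the kind of document AddCenterNodeMetadata manipulates: a `groupNodeAttributes` entry keyed by `centerNode` on the first subgraph of the first graph in the collection. The sample values are made up, and this standalone program only mimics the structure:

#include <iostream>
#include "nlohmann/json.hpp"

int main() {
  // A minimal graph collection shaped like ModelExplorer's
  // GraphCollectionFromBuiltinAdapters output.
  nlohmann::json graph_json = nlohmann::json::parse(R"([{"subgraphs": [{}]}])");

  nlohmann::json center;
  center["name"] = "fusion.42";  // hypothetical node name
  center["id"] = "fusion.42";
  center["opcode"] = "fusion";
  graph_json[0]["subgraphs"][0]["groupNodeAttributes"]["centerNode"] = center;

  std::cout << graph_json.dump(/*indent=*/2) << std::endl;
}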
-void AddGraphMetadata(std::string& graph_json_str,
-                      const HloInstruction& instr) {
-#ifdef PLATFORM_GOOGLE
-  nlohmann::json graph_json = nlohmann::json::parse(graph_json_str);
-  // 1. A fusion instruction is represented as a layer on the client; use its
-  //    pinned node as the center node. The id of the pinned node is the
-  //    fusion name.
-  // 2. Other instructions are represented as nodes on the client; use the
-  //    instruction itself as the center node, where the node id is the
-  //    instruction name.
-  std::string id = absl::StrCat(instr.name());
-  AddCenterNodeMetadata(graph_json, id, instr.name(),
-                        HloOpcodeString(instr.opcode()));
-  graph_json_str = graph_json.dump();
-#endif  // PLATFORM_GOOGLE
-}
-
-void AddGraphMetadata(std::string& graph_json_str, const HloComputation& comp) {
-#ifdef PLATFORM_GOOGLE
-  nlohmann::json graph_json = nlohmann::json::parse(graph_json_str);
-  // A computation is represented as a layer on the client; use its pinned
-  // node as the center node. The id of the pinned node is the computation
-  // name.
-  AddCenterNodeMetadata(graph_json, absl::StrCat(comp.name()), comp.name(), "");
-  graph_json_str = graph_json.dump();
-#endif  // PLATFORM_GOOGLE
-}
-
-// This function does the same thing as Plot() but uses ModelExplorer
-// instead of graphviz.
-absl::StatusOr<std::string> PlotMe(std::unique_ptr<HloModule> module,
-                                   const std::string& node_name,
-                                   int graph_width) {
-  if (node_name.empty()) {
-    // This should not happen.
-    return InvalidArgument("node_name should not be empty");
-  }
-  // Find the node with the given name.
-  const HloInstruction* instr = FindInstruction(*module, node_name);
-  const HloComputation* comp = FindComputation(*module, node_name);
-
-  if (!instr && !comp) {
-    return InvalidArgument(
-        absl::StrCat("Couldn't find HloInstruction or HloComputation named ",
-                     node_name, "."));
-  }
-  // Generate the graph and print the resulting string.
-  absl::StatusOr<std::string> graph_handle;
-  std::string graph_json_str;
-// b/360874576: Enable when the adapter is open sourced.
-#ifdef PLATFORM_GOOGLE
-  if (comp) {
-    graph_handle = tooling::visualization_client::HloGraphAdapter(*comp);
-  } else {
-    graph_handle =
-        tooling::visualization_client::HloGraphAdapter(*instr, graph_width);
-  }
-#endif  // PLATFORM_GOOGLE
-  if (graph_handle.ok()) {
-    VLOG(1) << graph_handle.value();
-    graph_json_str = graph_handle.value();
-    if (comp) {
-      AddGraphMetadata(graph_json_str, *comp);
-    } else {
-      AddGraphMetadata(graph_json_str, *instr);
-    }
-    return graph_json_str;
-  } else {
-    LOG(ERROR) << "Unable to render graph: " << graph_handle.status();
-  }
-
-  return graph_handle;
-}
-
-absl::StatusOr<std::string> Plot(std::unique_ptr<HloModule> module,
-                                 const std::string& node_name, int graph_width,
-                                 const HloRenderOptions& render_options,
-                                 const RenderedGraphFormat& format) {
-  if (node_name.empty()) {
-    // This should not happen.
-    return InvalidArgument("node_name should not be empty");
-  }
-  // Find the node with the given name.
-  const HloInstruction* instr = FindInstruction(*module, node_name);
-  const HloComputation* comp = FindComputation(*module, node_name);
-  if (!instr && !comp) {
-    return InvalidArgument(
-        absl::StrCat("Couldn't find HloInstruction or HloComputation named ",
-                     node_name, "."));
-  }
-  // Generate the graph and print the resulting string.
-  absl::StatusOr<std::string> graph_handle;
-
-  CleanUpHloModuleForGraphviz(module.get());
-  if (comp) {
-    graph_handle =
-        RenderGraphView(*comp, "", comp->parent()->config().debug_options(),
-                        format, render_options);
-  } else {
-    graph_handle = RenderGraphNeighborhoodAround(*instr, graph_width, format,
-                                                 render_options);
-  }
-  if (graph_handle.ok()) {
-    VLOG(1) << graph_handle.value();
-  } else {
-    LOG(ERROR) << "Unable to render graph: " << graph_handle.status();
-  }
-
-  return graph_handle;
-}
-
-// Default parameter constants for graph viewer.
-static constexpr char kGraphTypeName[] = "graph"; -static constexpr char kShortTxtTypeName[] = "short_txt"; -static constexpr char kLongTxtTypeName[] = "long_txt"; -static constexpr char kDefaultFormatString[] = "url"; -static constexpr int kDefaultWidth = 3; -static constexpr int kDefaultShowMetadata = 0; -static constexpr int kDefaultMergeFusion = 0; - -} // namespace - -absl::StatusOr GetNodeStyles() { - std::vector async_op_codes = {xla::HloOpcode::kAsyncStart, - xla::HloOpcode::kAsyncUpdate, - xla::HloOpcode::kAsyncDone}; - std::vector brown_op_codes = { - xla::HloOpcode::kAllGather, - xla::HloOpcode::kAllGatherStart, - xla::HloOpcode::kAllGatherDone, - xla::HloOpcode::kAllReduce, - xla::HloOpcode::kReduceScatter, - xla::HloOpcode::kAllReduceStart, - xla::HloOpcode::kAllReduceDone, - xla::HloOpcode::kAllToAll, - xla::HloOpcode::kCollectiveBroadcast, - xla::HloOpcode::kCollectivePermute, - xla::HloOpcode::kCollectivePermuteStart, - xla::HloOpcode::kCollectivePermuteDone, - xla::HloOpcode::kInfeed, - xla::HloOpcode::kOutfeed, - xla::HloOpcode::kPartitionId, - xla::HloOpcode::kRecv, - xla::HloOpcode::kRecvDone, - xla::HloOpcode::kSend, - xla::HloOpcode::kSendDone, - xla::HloOpcode::kReplicaId}; - std::vector dark_blue_op_codes = { - xla::HloOpcode::kConvolution, xla::HloOpcode::kDot, xla::HloOpcode::kFft, - xla::HloOpcode::kTriangularSolve, xla::HloOpcode::kCholesky}; - std::vector dark_green_op_codes = { - xla::HloOpcode::kCall, xla::HloOpcode::kConditional, - xla::HloOpcode::kCustomCall, xla::HloOpcode::kWhile}; - std::vector gray_op_codes = { - xla::HloOpcode::kDomain, xla::HloOpcode::kFusion, xla::HloOpcode::kMap, - xla::HloOpcode::kGetDimensionSize, xla::HloOpcode::kSetDimensionSize}; - std::vector green_op_codes = { - xla::HloOpcode::kConcatenate, xla::HloOpcode::kDynamicSlice, - xla::HloOpcode::kReshape, xla::HloOpcode::kDynamicReshape, - xla::HloOpcode::kReverse, xla::HloOpcode::kTranspose, - xla::HloOpcode::kCopy, xla::HloOpcode::kCopyStart, - xla::HloOpcode::kCopyDone}; - std::vector orange_op_codes = {xla::HloOpcode::kParameter}; - std::vector purple_op_codes = { - xla::HloOpcode::kBatchNormGrad, xla::HloOpcode::kBatchNormInference, - xla::HloOpcode::kBatchNormTraining, xla::HloOpcode::kReduce, - xla::HloOpcode::kReduceWindow, xla::HloOpcode::kScatter, - xla::HloOpcode::kSelectAndScatter, xla::HloOpcode::kGather}; - std::vector yellow_op_codes = { - xla::HloOpcode::kBroadcast, xla::HloOpcode::kDynamicUpdateSlice}; - - auto OpCodesToNames = - [&](std::vector op_codes) -> std::string { - std::string op_names = ""; - for (const auto& op_code : op_codes) { - if (!op_names.empty()) { - op_names += ","; - } - op_names += std::string(xla::HloOpcodeString(op_code)); - } - return op_names; - }; - - return absl::StrReplaceAll( - R"json({ - "kBlue": "$asyncOpNames", - "kBrown": "$brownOpNames", - "kDarkBlue": "$darkBlueOpNames", - "kDarkGreen": "$darkGreenOpNames", - "kGray": "$grayOpNames", - "kGreen": "$greenOpNames", - "kOrange": "$orangeOpNames", - "kPurple": "$purpleOpNames", - "kYellow": "$yellowOpNames" - })json", - { - {"$asyncOpNames", OpCodesToNames(async_op_codes)}, - {"$brownOpNames", OpCodesToNames(brown_op_codes)}, - {"$darkBlueOpNames", OpCodesToNames(dark_blue_op_codes)}, - {"$darkGreenOpNames", OpCodesToNames(dark_green_op_codes)}, - {"$grayOpNames", OpCodesToNames(gray_op_codes)}, - {"$greenOpNames", OpCodesToNames(green_op_codes)}, - {"$orangeOpNames", OpCodesToNames(orange_op_codes)}, - {"$purpleOpNames", OpCodesToNames(purple_op_codes)}, - {"$yellowOpNames", 
OpCodesToNames(yellow_op_codes)},
-      });
-}
-
-absl::StatusOr<GraphViewerParams> ParseGraphViewerParams(
-    const ToolOptions& options) {
-  GraphViewerParams params;
-  std::optional<std::string> type = GetParam<std::string>(options, "type");
-  if (!type.has_value()) {
-    return InvalidArgument("Graph viewer must provide a type option.");
-  }
-
-  // For graph type.
-  if (type == kGraphTypeName) {
-    params.type = type.value();
-    if (std::optional<std::string> node_name =
-            GetParam<std::string>(options, "node_name")) {
-      params.node_name = node_name.value();
-    }
-
-    params.graph_width =
-        GetParamWithDefault<int>(options, "graph_width", kDefaultWidth);
-    params.render_options.show_backend_config = GetParamWithDefault<int>(
-        options, "show_metadata", kDefaultShowMetadata);
-    params.render_options.show_fusion_subcomputations =
-        !GetParamWithDefault<int>(options, "merge_fusion", kDefaultMergeFusion);
-    params.format = GetRenderFormat(GetParamWithDefault<std::string>(
-        options, "format", kDefaultFormatString));
-
-    return params;
-  }
-
-  // For txt type.
-  if (type == kShortTxtTypeName || type == kLongTxtTypeName) {
-    params.type = type.value();
-    params.verbose = (type == kLongTxtTypeName);
-    params.show_metadata =
-        GetParamWithDefault<int>(options, "show_metadata", kDefaultShowMetadata);
-    return params;
-  }
-
-  // Unknown type.
-  return InvalidArgument("Unknown graph viewer type option: ", type.value());
-}
-
-xla::RenderedGraphFormat GetRenderFormat(const std::string& format_string) {
-  if (format_string == "html") {
-    return xla::RenderedGraphFormat::kHtml;
-  } else if (format_string == "dot") {
-    return xla::RenderedGraphFormat::kDot;
-  } else if (format_string == "url") {
-    return xla::RenderedGraphFormat::kUrl;
-  } else {
-    LOG(ERROR) << "Invalid graph format argument: " << format_string
-               << ", falling back to the default format (url)";
-    return xla::RenderedGraphFormat::kUrl;
-  }
-}
-
-absl::StatusOr<std::string> ConvertHloProtoToGraph(
-    const HloProto& hlo_proto, const std::string& node_name, int graph_width,
-    const HloRenderOptions& render_options, const RenderedGraphFormat& format) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
-                      ConvertHloProtoToModule(hlo_proto));
-  return Plot(std::move(hlo_module), node_name, graph_width, render_options,
-              format);
-}
-
-absl::StatusOr<std::string> ConvertHloProtoToMeGraph(
-    const HloProto& hlo_proto, const std::string& node_name, int graph_width) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
-                      ConvertHloProtoToModule(hlo_proto));
-  return PlotMe(std::move(hlo_module), node_name, graph_width);
-}
-
-absl::StatusOr<std::string> ConvertHloProtoToStringView(
-    const HloProto& hlo_proto, bool verbose, bool metadata) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
-                      ConvertHloProtoToModule(hlo_proto));
-  HloPrintOptions options;
-  if (!verbose) {
-    options = HloPrintOptions::ShortParsable();
-  }
-  options.set_print_large_constants(verbose);
-  options.set_print_metadata(metadata);
-  return hlo_module->ToString(options);
-}
-
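The three converters above share one pattern: rebuild an xla::HloModule from the proto, then delegate to a plotting or printing helper. A hedged usage fragment from a hypothetical caller that links against this file (it assumes a populated xla::HloProto named hlo_proto is in scope):

// Render the module as short, parsable text without metadata.
absl::StatusOr<std::string> text =
    ConvertHloProtoToStringView(hlo_proto, /*verbose=*/false,
                                /*metadata=*/false);
if (text.ok()) {
  LOG(INFO) << *text;
}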
-std::function<absl::StatusOr<std::string>(absl::string_view)>* url_renderer =
-    nullptr;
-
-// Precondition: (url_renderer != nullptr || format != kUrl).
-//
-// (We specify this as a precondition rather than checking it here and
-// returning an error because we want to fail quickly when there's no URL
-// renderer available, and this function runs only after we've done all the
-// work of producing dot for the graph.)
-absl::Status CheckPrecondition(xla::RenderedGraphFormat format) {
-  if (format == xla::RenderedGraphFormat::kUrl && url_renderer == nullptr) {
-    return absl::FailedPreconditionError(
-        "Can't render as URL; no URL renderer was registered.");
-  }
-  return absl::OkStatus();
-}
-
-absl::StatusOr<std::string> RenderGraphView(
-    const xla::HloComputation& computation, absl::string_view label,
-    const xla::DebugOptions& debug_options, xla::RenderedGraphFormat format,
-    xla::HloRenderOptions hlo_render_options) {
-  auto precheck_status = CheckPrecondition(format);
-  if (!precheck_status.ok()) {
-    return precheck_status;
-  }
-  auto rendered_dot =
-      xla::RenderGraph(computation, label, debug_options,
-                       RenderedGraphFormat::kDot, hlo_render_options);
-  if (!rendered_dot.ok()) {
-    return rendered_dot.status();
-  }
-  return WrapDotInFormat(rendered_dot.value(), format);
-}
-
-absl::StatusOr<std::string> RenderGraphNeighborhoodAround(
-    const xla::HloInstruction& node, int radius,
-    xla::RenderedGraphFormat format, xla::HloRenderOptions hlo_render_options,
-    const absl::flat_hash_set<const xla::HloInstruction*>& boundary) {
-  auto precheck_status = CheckPrecondition(format);
-  if (!precheck_status.ok()) {
-    return precheck_status;
-  }
-  auto rendered_dot = xla::RenderNeighborhoodAround(
-      node, radius, RenderedGraphFormat::kDot, hlo_render_options, boundary);
-  if (!rendered_dot.ok()) {
-    return rendered_dot.status();
-  }
-  return WrapDotInFormat(rendered_dot.value(), format);
-}
-
-absl::StatusOr<std::string> WrapDotInFormat(std::string dot,
-                                            xla::RenderedGraphFormat format) {
-  switch (format) {
-    case xla::RenderedGraphFormat::kUrl:
-      if (url_renderer == nullptr) {
-        return absl::InternalError("url_renderer is null");
-      }
-      return (*url_renderer)(dot);
-    case xla::RenderedGraphFormat::kHtml:
-      return WrapDotInHtml(dot);
-    case xla::RenderedGraphFormat::kDot:
-      return std::string(dot);
-  }
-}
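WrapDotInFormat above can only satisfy kUrl once a renderer has been registered through RegisterGraphvizURLRenderer (defined just below). A hedged registration sketch; the lambda body is a stand-in, not a real upload service:

RegisterGraphvizURLRenderer(
    [](absl::string_view dot) -> absl::StatusOr<std::string> {
      // A real renderer would upload `dot` somewhere and return a link.
      return absl::StrCat("https://graphviz.example/render?size=", dot.size());
    });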
-std::string WrapDotInHtml(std::string dot, absl::string_view layout_engine) {
-  return absl::StrReplaceAll(
-      R"html(
-<!-- HTML page template elided: it embeds the $DOT graph source and renders
-     it in the browser with a JavaScript graphviz renderer using the
-     $LAYOUT_ENGINE layout engine. -->
-)html",
-      {{"$DOT", dot}, {"$LAYOUT_ENGINE", layout_engine}});
-}
-
-void RegisterGraphvizURLRenderer(
-    std::function<absl::StatusOr<std::string>(absl::string_view)> renderer) {
-  if (url_renderer != nullptr) {
-    LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call "
-                    "wins, but because order of initialization in C++ is "
-                    "nondeterministic, this may not be what you want.";
-  }
-  delete url_renderer;
-  url_renderer =
-      new std::function<absl::StatusOr<std::string>(absl::string_view)>(
-          std::move(renderer));
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h
deleted file mode 100644
index 6f91f1c10feae4..00000000000000
--- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h
+++ /dev/null
@@ -1,102 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_GRAPH_VIEW_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_GRAPH_VIEW_H_
-
-#include <functional>
-#include <string>
-#include
-
-#include "xla/service/hlo.pb.h"
-#include "xla/service/hlo_graph_dumper.h"
-#include "tensorflow/core/platform/statusor.h"
-#include "tensorflow/core/profiler/convert/tool_options.h"
-
-namespace tensorflow {
-namespace profiler {
-
-// All the parameters for graph viewer.
-struct GraphViewerParams {
-  // Whether to use GraphView or TxtView.
-  std::string type;
-  // Parameters for GraphView.
-  std::string node_name;
-  int graph_width;
-  xla::HloRenderOptions render_options;
-  xla::RenderedGraphFormat format;
-  // Parameters for TxtView.
-  bool verbose;
-  bool show_metadata;
-};
-
-// Returns a mapping from style keyword to comma-separated op names, following
-// hlo_graph_dumper styling.
-absl::StatusOr<std::string> GetNodeStyles();
-
-// Parse tool options to get the parameters for graph viewer.
-absl::StatusOr<GraphViewerParams> ParseGraphViewerParams(
-    const ToolOptions& options);
-
-// Get graph render format.
-xla::RenderedGraphFormat GetRenderFormat(const std::string& format_string);
-
-// Convert `hlo_proto` to GraphView with the provided render options.
-absl::StatusOr<std::string> ConvertHloProtoToGraph(
-    const xla::HloProto& hlo_proto, const std::string& node_name,
-    int graph_width, const xla::HloRenderOptions& render_options,
-    const xla::RenderedGraphFormat& format);
-
-// Convert `hlo_proto` to ModelExplorer Graph JSON data.
-absl::StatusOr<std::string> ConvertHloProtoToMeGraph(
-    const xla::HloProto& hlo_proto, const std::string& node_name,
-    int graph_width);
-
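Taken together, the declarations above support an end-to-end flow: parse tool options, then render. A hedged caller fragment (it assumes a populated xla::HloProto named hlo_proto; the option values are illustrative and mirror the tests further down):

ToolOptions options;
options["type"] = "graph";
options["node_name"] = "fusion.111";  // hypothetical node
options["format"] = "html";

absl::StatusOr<GraphViewerParams> params = ParseGraphViewerParams(options);
if (params.ok() && params->type == "graph") {
  absl::StatusOr<std::string> rendered = ConvertHloProtoToGraph(
      hlo_proto, params->node_name, params->graph_width,
      params->render_options, params->format);
}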
-// Render graph with the provided render options.
-absl::StatusOr<std::string> RenderGraphView(
-    const xla::HloComputation& computation, absl::string_view label,
-    const xla::DebugOptions& debug_options, xla::RenderedGraphFormat format,
-    xla::HloRenderOptions hlo_render_options = {});
-
-// Render the graph neighborhood around a centered node, up to the given depth.
-absl::StatusOr<std::string> RenderGraphNeighborhoodAround(
-    const xla::HloInstruction& node, int radius,
-    xla::RenderedGraphFormat format,
-    xla::HloRenderOptions hlo_render_options = {},
-    const absl::flat_hash_set<const xla::HloInstruction*>& boundary = {});
-
-// Convert `hlo_proto` to StringView.
-absl::StatusOr<std::string> ConvertHloProtoToStringView(
-    const xla::HloProto& hlo_proto, bool verbose, bool metadata);
-
-// Convert dot into the requested format.
-absl::StatusOr<std::string> WrapDotInFormat(std::string dot,
-                                            xla::RenderedGraphFormat format);
-
-// Wrap dot in an HTML page for visualization.
-std::string WrapDotInHtml(std::string dot,
-                          absl::string_view layout_engine = "dot");
-
-// Registers a function which implements RenderedGraphFormat::kUrl.
-// The input to the function is dot, and the output should be a URL or an
-// error. There can only be one active renderer, and the last call to this
-// function wins.
-void RegisterGraphvizURLRenderer(
-    std::function<absl::StatusOr<std::string>(absl::string_view dot)> renderer);
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_GRAPH_VIEW_H_
diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view_test.cc b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view_test.cc
deleted file mode 100644
index b53ec03de2822e..00000000000000
--- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view_test.cc
+++ /dev/null
@@ -1,123 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h"
-
-#include
-
-#include
-#include "xla/service/hlo_graph_dumper.h"
-#include "xla/tsl/platform/status_matchers.h"
-#include "xla/tsl/platform/statusor.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/profiler/convert/tool_options.h"
-#include "tensorflow/core/protobuf/error_codes.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-using ::testing::HasSubstr;
-using ::tsl::testing::StatusIs;
-
-TEST(GraphViewerParamsTest, GraphType) {
-  // Default for graph type.
-  ToolOptions options1;
-  options1["type"] = "graph";
-  TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params1,
-                          ParseGraphViewerParams(options1));
-  EXPECT_EQ(params1.type, "graph");
-  EXPECT_EQ(params1.node_name, "");
-  EXPECT_EQ(params1.graph_width, 3);
-  EXPECT_EQ(params1.render_options.show_backend_config, false);
-  EXPECT_EQ(params1.render_options.show_fusion_subcomputations, true);
-  EXPECT_EQ(params1.format, xla::RenderedGraphFormat::kUrl);
-
-  // User defined options for graph type.
- ToolOptions options2; - options2["type"] = "graph"; - options2["node_name"] = "fusion.111"; - options2["graph_width"] = 10; - options2["show_metadata"] = 1; - options2["merge_fusion"] = 1; - options2["format"] = "html"; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params2, - ParseGraphViewerParams(options2)); - EXPECT_EQ(params2.type, "graph"); - EXPECT_EQ(params2.node_name, "fusion.111"); - EXPECT_EQ(params2.graph_width, 10); - EXPECT_EQ(params2.render_options.show_backend_config, true); - EXPECT_EQ(params2.render_options.show_fusion_subcomputations, false); - EXPECT_EQ(params2.format, xla::RenderedGraphFormat::kHtml); -} - -TEST(GraphViewerParamsTest, ShortTxtType) { - // Default for short txt type. - ToolOptions options1; - options1["type"] = "short_txt"; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params1, - ParseGraphViewerParams(options1)); - EXPECT_EQ(params1.type, "short_txt"); - EXPECT_EQ(params1.verbose, false); - EXPECT_EQ(params1.show_metadata, false); - - // User defined options for short txt type. - ToolOptions options2; - options2["type"] = "short_txt"; - options2["show_metadata"] = 1; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params2, - ParseGraphViewerParams(options2)); - EXPECT_EQ(params2.type, "short_txt"); - EXPECT_EQ(params2.verbose, false); - EXPECT_EQ(params2.show_metadata, true); -} - -TEST(GraphViewerParamsTest, LongTxtType) { - // Default for long txt type. - ToolOptions options1; - options1["type"] = "long_txt"; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params1, - ParseGraphViewerParams(options1)); - EXPECT_EQ(params1.type, "long_txt"); - EXPECT_EQ(params1.verbose, true); - EXPECT_EQ(params1.show_metadata, false); - - // User defined options for long txt type. - ToolOptions options2; - options2["type"] = "long_txt"; - options2["show_metadata"] = 1; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params2, - ParseGraphViewerParams(options2)); - EXPECT_EQ(params2.type, "long_txt"); - EXPECT_EQ(params2.verbose, true); - EXPECT_EQ(params2.show_metadata, true); -} - -TEST(GraphViewerParamsTest, OtherTypes) { - ToolOptions options1; - EXPECT_THAT(ParseGraphViewerParams(options1), - StatusIs(error::INVALID_ARGUMENT, - HasSubstr("Graph viewer must provide a type option"))); - - ToolOptions options2; - options2["type"] = "abcd"; - EXPECT_THAT(ParseGraphViewerParams(options2), - StatusIs(error::INVALID_ARGUMENT, - HasSubstr("Unknown graph viewer type option: abcd"))); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc deleted file mode 100644 index cf4fce7aecde69..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc +++ /dev/null @@ -1,1108 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/layout_util.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/tsl/platform/errors.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::xla::BufferAllocationProto; -using ::xla::HeapSimulatorTrace; -using ::xla::HloInstructionProto; -using ::xla::HloProto; -using ::xla::LayoutUtil; -using ::xla::LogicalBufferProto; -using ::xla::Shape; -using ::xla::ShapeUtil; - -Shape ResolveShapeIndex(const xla::ShapeProto& shape_proto, - absl::Span shape_index) { - if (shape_index.empty()) return Shape(shape_proto); - // Choosing the last subshape to maintain historical behavior. - int64_t i = shape_index.back(); - if (i >= shape_proto.tuple_shapes_size()) { - return Shape(shape_proto); - } - return Shape(shape_proto.tuple_shapes(i)); -} - -std::string ShapeDescription(const Shape& shape) { - return ShapeUtil::HumanStringWithLayout(shape); -} - -// A wrapper around ShapeUtil::ByteSizeOf that clears out the layout/padding, -// since that is considered in the ByteSizeOf calculation. -int64_t ShapeUnpaddedSize(Shape shape) { - // Ensure the layout has no padding by making it the default layout. - LayoutUtil::SetToDefaultLayout(&shape); - // Note: we make a simplifying assumption here that a "minimal" size for a - // tuple member would be the size of a `void*` -- there may be even fancier - // ways of doing things, but this should give a good enough approximation of - // what a minimal tuple size is. - return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); -} - -class BufferAllocationStruct { - public: - explicit BufferAllocationStruct(const BufferAllocationProto& proto) - : buffer_allocation_((proto)) {} - bool IsIndefinite() const { - return buffer_allocation_.is_thread_local() || - buffer_allocation_.is_entry_computation_parameter() || - buffer_allocation_.is_constant() || - buffer_allocation_.maybe_live_out(); - } - const BufferAllocationProto& proto() const { return buffer_allocation_; } - size_t size() const { return buffer_allocation_.size(); } - int64_t color() const { return buffer_allocation_.color(); } - int64_t index() const { return buffer_allocation_.index(); } - std::optional heap_simulator_trace_id() const { - return heap_simulator_trace_id_; - } - void set_heap_simulator_trace_id(int64_t id) { - heap_simulator_trace_id_ = id; - } - - // Get buffer allocation category. 
- std::string category() const { - if (buffer_allocation_.is_entry_computation_parameter()) { - return "Parameter"; - } else if (buffer_allocation_.maybe_live_out()) { - return "Output"; - } else if (buffer_allocation_.is_thread_local()) { - return "Thread-local"; - } else if (buffer_allocation_.is_constant()) { - return "Constant"; - } else { - return "Temporary"; - } - } - - std::string description() const { - return absl::StrFormat( - "buffer_allocation_id:%d\nsize:%d\nbuffer_counts:%d\n", - buffer_allocation_.index(), size(), buffer_allocation_.assigned_size()); - } - - private: - const BufferAllocationProto& buffer_allocation_; - std::optional heap_simulator_trace_id_; -}; - -struct LogicalBufferStruct { - LogicalBufferStruct(const LogicalBufferProto& p, - const BufferAllocationStruct& b, - const ::xla::HloInstructionProto& i, uint64_t offset) - : proto(p), - buffer_allocation(b), - hlo_instruction(i), - offset(offset), - shape(ResolveShapeIndex(hlo_instruction.shape(), - proto.defined_at().shape_index())) {} - - absl::string_view instruction_name() const { return hlo_instruction.name(); } - - int64_t color() const { return proto.color(); } - size_t size() const { return proto.size(); } - size_t unpadded_size() const { return ShapeUnpaddedSize(shape); } - - // reference counting related - int64_t inc() { - if (canonical_buffer) return canonical_buffer->inc(); - return ++ref_count; - } - int64_t dec() { - if (canonical_buffer) return canonical_buffer->dec(); - return --ref_count; - } - int64_t share_with(LogicalBufferStruct* buffer) { - canonical_buffer = buffer; - return canonical_buffer->inc(); - } - LogicalBufferStruct* get_canonical_buffer() { - return canonical_buffer ? canonical_buffer->get_canonical_buffer() : this; - } - - // Get the instruction name with shape index for a logical buffer. - std::string GetInstructionNameWithShapeIndex() const { - if (proto.defined_at().shape_index().empty()) { - return std::string(instruction_name()); - } else { - return absl::StrCat(instruction_name(), "{", - absl::StrJoin(proto.defined_at().shape_index(), ","), - "}"); - } - } - - std::string description() const { - return absl::StrFormat( - "buffer_id:%d\nhlo_op:%s\nshape:%s\nsize:%d\nunpadded_size:%d\n" - "offset:%d\nspan:(%lld,%lld)", - proto.id(), instruction_name(), ShapeDescription(shape), size(), - unpadded_size(), offset, span ? span->first : -1, - span ? span->second : -1); - } - - const LogicalBufferProto& proto; - const BufferAllocationStruct& buffer_allocation; - const ::xla::HloInstructionProto& hlo_instruction; - uint64_t offset; // within the buffer allocation; - // Span within the specific simulator trace. - std::optional> span; - xla::Shape shape; - int64_t ref_count = 0; - LogicalBufferStruct* canonical_buffer = nullptr; -}; - -// A wrapper of HLO BufferAssignment, with lookup maps for logical buffers and -// buffer allocations. -class HloProtoBufferWrapper { - public: - explicit HloProtoBufferWrapper(const ::xla::HloProto& hlo_proto) - : hlo_proto_(hlo_proto) { - Init(); - } - - // Get the heap simulator trace ID using memory color. - // If unable to find the heap simulator trace, return -1. - int64_t GetHeapSimulatorTraceId(const int64_t memory_color) const { - int64_t id = GetHeapSimulatorTraceIdFromBufferAllocationIndex(memory_color); - if (id != -1) { - return id; - } - return GetHeapSimulatorTraceIdFromEvents(memory_color); - } - - // Get the raw HLO proto. 
- // Get the raw HLO proto. - const ::xla::HloProto& GetHloProto() const { return hlo_proto_; } - - std::vector<const BufferAllocationStruct*> GetBufferAllocations( - int64_t memory_color) const { - std::vector<const BufferAllocationStruct*> buffer_allocations; - for (const auto& iter : id_to_buffer_allocation_) { - if (iter.second->proto().color() != memory_color) continue; - buffer_allocations.push_back(iter.second.get()); - } - return buffer_allocations; - } - - LogicalBufferStruct* GetLogicalBuffer(int64_t logical_buffer_id) const { - if (!id_to_logical_buffer_.contains(logical_buffer_id)) { - LOG(DFATAL) << "logical_buffer_id " << logical_buffer_id << " not found."; - return nullptr; - } - return id_to_logical_buffer_.at(logical_buffer_id).get(); - } - - // Get the logical buffers with indefinite lifetime (excluding thread_local). - std::vector<const LogicalBufferStruct*> LogicalBuffersWithIndefiniteLifetime( - int64_t memory_color) const { - std::vector<const LogicalBufferStruct*> indefinite_logical_buffers; - - for (const auto& buffer_assignment : GetBufferAllocations(memory_color)) { - if (!buffer_assignment->IsIndefinite()) continue; - if (buffer_assignment->proto().is_thread_local()) continue; - // An indefinite buffer allocation may contain multiple logical buffers. - // None of them has an offset, and their sizes may differ from the buffer - // allocation's size. In most cases, if not all, one of the logical - // buffers has a size equal to the buffer allocation's size. We pick the - // biggest logical buffer. - const LogicalBufferStruct* best_logical_buffer = nullptr; - size_t best_size = 0; - for (const auto& assigned : buffer_assignment->proto().assigned()) { - const LogicalBufferStruct* logical_buffer_struct = - GetLogicalBuffer(assigned.logical_buffer_id()); - if (logical_buffer_struct == nullptr) continue; - if (logical_buffer_struct->size() > best_size) { - best_size = logical_buffer_struct->size(); - best_logical_buffer = logical_buffer_struct; - } - } - if (best_logical_buffer) { - indefinite_logical_buffers.push_back(best_logical_buffer); - } - } - return indefinite_logical_buffers; - } - - private: - // Initialize the mappings of logical buffers and buffer allocations. - void Init() { - // A mapping from name to HLO instruction.
- absl::flat_hash_map<absl::string_view, const HloInstructionProto*> - name_to_hlo; - absl::flat_hash_map<int64_t, const HloInstructionProto*> - unique_id_to_hlo; - - for (const auto& computation : hlo_proto_.hlo_module().computations()) { - for (const auto& instruction : computation.instructions()) { - name_to_hlo[instruction.name()] = &instruction; - unique_id_to_hlo[instruction.id()] = &instruction; - } - } - - absl::flat_hash_map<int64_t, const LogicalBufferProto*> - id_to_logical_buffer_proto; - for (const auto& logical_buffer : - hlo_proto_.buffer_assignment().logical_buffers()) { - id_to_logical_buffer_proto[logical_buffer.id()] = &logical_buffer; - } - - for (const auto& buffer_allocation : - hlo_proto_.buffer_assignment().buffer_allocations()) { - auto& buffer_allocation_s = - id_to_buffer_allocation_[buffer_allocation.index()]; - buffer_allocation_s = - std::make_unique<BufferAllocationStruct>(buffer_allocation); - for (const auto& assigned : buffer_allocation.assigned()) { - const auto id = assigned.logical_buffer_id(); - if (!id_to_logical_buffer_proto.contains(id)) { - LOG(DFATAL) << "logical_buffer_id " << id << " not found."; - continue; - } - const auto* logical_buffer = id_to_logical_buffer_proto.at(id); - int64_t inst_id = logical_buffer->defined_at().instruction_id(); - if (!unique_id_to_hlo.contains(inst_id)) { - LOG(DFATAL) << "instruction_id " << inst_id << " not found."; - continue; - } - const auto* instruction = unique_id_to_hlo.at(inst_id); - id_to_logical_buffer_[id] = std::make_unique<LogicalBufferStruct>( - *logical_buffer, *buffer_allocation_s, *instruction, - assigned.offset()); - } - } - - const auto& heap_simulator_traces = - hlo_proto_.buffer_assignment().heap_simulator_traces(); - for (int64_t i = 0; i < heap_simulator_traces.size(); i++) { - // The trace's buffer_allocation_index is not trustworthy, so we are - // trying to obtain the buffer allocation index ourselves. - if (heap_simulator_traces[i].events().empty()) continue; - int logical_buffer_id = heap_simulator_traces[i].events(0).buffer_id(); - if (!id_to_logical_buffer_.contains(logical_buffer_id)) continue; - auto* logical_buffer = id_to_logical_buffer_[logical_buffer_id].get(); - auto buffer_allocation_index = logical_buffer->buffer_allocation.index(); - id_to_buffer_allocation_[buffer_allocation_index] - ->set_heap_simulator_trace_id(i); - } - } - - // From a list of heap simulator traces, identify the one that has the largest - // number of memory events with color <memory_color>. - int64_t GetHeapSimulatorTraceIdFromEvents(const int64_t memory_color) const { - int64_t best_index = -1; - int64_t best_event_count = 0; - for (int64_t i = 0; - i < hlo_proto_.buffer_assignment().heap_simulator_traces_size(); i++) { - const auto& heap_simulator_trace = - hlo_proto_.buffer_assignment().heap_simulator_traces(i); - int64_t event_count = 0; - for (const auto& event : heap_simulator_trace.events()) { - if (!id_to_logical_buffer_.contains(event.buffer_id())) { - LOG(DFATAL) << "buffer_id " << event.buffer_id() << " not found."; - continue; - } - const auto& logical_buffer = - id_to_logical_buffer_.at(event.buffer_id()); - if (logical_buffer->color() == memory_color) { - event_count++; - } - } - if (event_count > best_event_count) { - best_index = i; - best_event_count = event_count; - } - } - return best_index; - } -
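- // (Added note with hypothetical numbers: if trace 0 has 10 events whose - // logical buffers have color 0 while trace 1 has only 3, the heuristic above - // picks trace 0 for memory color 0; a color with no matching events yields - // -1, which callers treat as "no usable trace".) -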
- // Tries to get heap simulator trace based on buffer_allocation_index. - int64_t GetHeapSimulatorTraceIdFromBufferAllocationIndex( - const int64_t memory_color) const { - auto buffer_allocations = GetBufferAllocations(memory_color); - for (const auto* buffer_allocation : buffer_allocations) { - if (buffer_allocation->IsIndefinite()) continue; - // TODO(xprof): handle multiple temporary buffer allocations for the same - // color. - if (buffer_allocation->heap_simulator_trace_id()) { - return *buffer_allocation->heap_simulator_trace_id(); - } - } - return -1; - } - - // Reference to the original HLO proto. - const ::xla::HloProto& hlo_proto_; - - // A mapping from logical buffer ID to logical buffer. - absl::flat_hash_map<int64_t, std::unique_ptr<LogicalBufferStruct>> - id_to_logical_buffer_; - - // A mapping from buffer allocation ID to BufferAllocationStruct. - absl::flat_hash_map<int64_t, std::unique_ptr<BufferAllocationStruct>> - id_to_buffer_allocation_; -}; - -double BytesToMiB(int64_t bytes) { - return static_cast<double>(bytes) / (1ULL << 20); -} - -HeapObject MakeHeapObjectCommon(std::string label, int32_t color, - int64_t logical_buffer_id, - int64_t logical_buffer_size_bytes, - int64_t unpadded_shape_bytes) { - HeapObject result; - result.set_numbered(color); - result.set_label(std::move(label)); - result.set_logical_buffer_id(logical_buffer_id); - result.set_logical_buffer_size_mib(BytesToMiB(logical_buffer_size_bytes)); - result.set_unpadded_shape_mib(BytesToMiB(unpadded_shape_bytes)); - return result; -} - -HeapObject MakeHeapObject(const LogicalBufferStruct& logical_buffer, - int32_t color) { - const HloInstructionProto& hlo_instruction = logical_buffer.hlo_instruction; - std::string shape_string = ShapeDescription(logical_buffer.shape); - std::string label = - absl::StrFormat("%s: %s # %s", logical_buffer.instruction_name(), - shape_string, hlo_instruction.metadata().op_name()); - HeapObject result = MakeHeapObjectCommon( - std::move(label), color, logical_buffer.proto.id(), logical_buffer.size(), - logical_buffer.unpadded_size()); - result.set_instruction_name( - logical_buffer.GetInstructionNameWithShapeIndex()); - result.set_group_name(logical_buffer.buffer_allocation.category()); - result.set_tf_op_name(hlo_instruction.metadata().op_name()); - result.set_shape_string(shape_string); - result.set_op_code(hlo_instruction.opcode()); - return result; -} - -BufferSpan MakeBufferSpan(int32 start, int32 limit) { - BufferSpan result; - result.set_start(start); - result.set_limit(limit); - return result; -} - -void Convert(const xla::BufferAllocationProto_Assigned& assigned, - const HloProtoBufferWrapper& wrapper, LogicalBuffer* result) { - result->set_id(assigned.logical_buffer_id()); - result->set_size_mib(BytesToMiB(assigned.size())); - const LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(assigned.logical_buffer_id()); - if (logical_buffer == nullptr) return; - result->set_hlo_name(std::string(logical_buffer->instruction_name())); - result->mutable_shape_index()->CopyFrom( - logical_buffer->proto.defined_at().shape_index()); - result->set_shape(ShapeDescription(logical_buffer->shape)); -} - -bool IsReusable(const BufferAllocationProto& buffer_allocation) { - return !buffer_allocation.is_thread_local() && !buffer_allocation.is_tuple(); -} - -void Convert(const BufferAllocationProto& proto, - const HloProtoBufferWrapper& wrapper, BufferAllocation* result) { - result->set_id(proto.index()); - result->set_size_mib(BytesToMiB(proto.size())); - if (proto.is_entry_computation_parameter()) { - result->add_attributes("entry computation parameter"); - } - if (proto.maybe_live_out()) { - result->add_attributes("may-be live out"); - } - if (IsReusable(proto)) {
- result->add_attributes("reusable"); - } - for (const auto& assigned : proto.assigned()) { - Convert(assigned, wrapper, result->add_logical_buffers()); - } - // Check whether all logical buffers for this buffer allocation have a common - // shape. - if (!result->logical_buffers().empty()) { - std::string common_shape = result->logical_buffers(0).shape(); - for (int64_t i = 1; i < result->logical_buffers_size(); ++i) { - if (result->logical_buffers(i).shape() != common_shape) { - common_shape = ""; - break; - } - } - if (!common_shape.empty()) { - result->set_common_shape(common_shape); - } - } -} - -void NoteSpecialAllocations(const HloProtoBufferWrapper& wrapper, - int64_t memory_color, int64_t small_buffer_size, - PreprocessResult* result) { - int64_t entry_parameters_bytes = 0; - int64_t non_reusable_bytes = 0; - int64_t maybe_live_out_bytes = 0; - int64_t total_buffer_allocation_bytes = 0; - int64_t indefinite_buffer_allocation_bytes = 0; - for (const auto* buffer_allocation_struct : - wrapper.GetBufferAllocations(memory_color)) { - const auto& buffer_allocation = buffer_allocation_struct->proto(); - if (buffer_allocation.is_entry_computation_parameter()) { - entry_parameters_bytes += buffer_allocation.size(); - } - if (!IsReusable(buffer_allocation)) { - non_reusable_bytes += buffer_allocation.size(); - } - if (buffer_allocation.maybe_live_out()) { - if (buffer_allocation.size() > small_buffer_size) { - VLOG(1) << "Maybe live out buffer allocation: " - << buffer_allocation.size() - << " bytes :: " << buffer_allocation.ShortDebugString(); - } - maybe_live_out_bytes += buffer_allocation.size(); - } - if (buffer_allocation_struct->IsIndefinite()) { - indefinite_buffer_allocation_bytes += buffer_allocation.size(); - Convert(buffer_allocation, wrapper, result->add_indefinite_lifetimes()); - } - total_buffer_allocation_bytes += buffer_allocation.size(); - } - - result->set_entry_computation_parameters_mib( - BytesToMiB(entry_parameters_bytes)); - result->set_non_reusable_mib(BytesToMiB(non_reusable_bytes)); - result->set_maybe_live_out_mib(BytesToMiB(maybe_live_out_bytes)); - result->set_total_buffer_allocation_mib( - BytesToMiB(total_buffer_allocation_bytes)); - result->set_indefinite_buffer_allocation_mib( - BytesToMiB(indefinite_buffer_allocation_bytes)); -} - -// Memory usage statistics collected from heap simulator trace. -struct HeapSimulatorStats { - explicit HeapSimulatorStats(const HloProtoBufferWrapper& wrapper) - : wrapper(wrapper) {} - - void SetSimulatorTraceEventSize(int64_t size) { - simulator_trace_event_size = size; - } - - // Update stats for general simulator event. - void UpdateOnSimulatorEvent(const HeapSimulatorTrace::Event& event) { - // Update memory timelines and seen buffers. - heap_size_bytes_timeline.push_back(heap_size_bytes); - unpadded_heap_size_bytes_timeline.push_back(unpadded_heap_size_bytes); - hlo_instruction_name_timeline.push_back(event.instruction_name()); - const LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(event.buffer_id()); - if (logical_buffer == nullptr) return; - seen_logical_buffers.insert(logical_buffer); - seen_buffer_allocations.insert(&logical_buffer->buffer_allocation.proto()); - } - - // Update stats when memory usage increases.
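- // (Added sketch, illustrative numbers only: ALLOCs of 64 and 32 bytes - // followed by a FREE of the 64-byte buffer move heap_size_bytes - // 64 -> 96 -> 32, so peak_heap_size_bytes stays 96 and - // peak_heap_size_position points at the timeline entry where 96 was first - // reached.)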
- void IncreaseMemoryUsage(LogicalBufferStruct* canonical_logical_buffer, - bool init_buffer_span) { - logical_buffers.push_back(canonical_logical_buffer->proto.id()); - heap_size_bytes += canonical_logical_buffer->size(); - unpadded_heap_size_bytes += canonical_logical_buffer->unpadded_size(); - - // Increase peak memory usage if needed. - int64_t prior_peak_heap_size_bytes = peak_heap_size_bytes; - peak_heap_size_bytes = std::max(peak_heap_size_bytes, heap_size_bytes); - if (prior_peak_heap_size_bytes != peak_heap_size_bytes) { - peak_heap_size_position = heap_size_bytes_timeline.size() - 1; - peak_unpadded_heap_size_bytes = unpadded_heap_size_bytes; - VLOG(1) << absl::StrFormat("New peak heap size on %d :: %d bytes", - peak_heap_size_position, peak_heap_size_bytes); - peak_logical_buffers = logical_buffers; - } - // Initialize the buffer lifespan if needed. - if (init_buffer_span) { - // Initialize the buffer span from the current event to the last event in - // heap simulator trace. - canonical_logical_buffer->span.emplace( - heap_size_bytes_timeline.size() - 1, simulator_trace_event_size - 1); - } - } - - // Update stats when memory usage decreases. - absl::Status DecreaseMemoryUsage( - LogicalBufferStruct* canonical_logical_buffer) { - int64_t canonical_buffer_id = canonical_logical_buffer->proto.id(); - logical_buffers.remove(canonical_buffer_id); - heap_size_bytes -= canonical_logical_buffer->size(); - if (heap_size_bytes < 0) { - return errors::InvalidArgument(absl::StrCat( - "Heap size should be non-negative, but got: ", heap_size_bytes)); - } - unpadded_heap_size_bytes -= canonical_logical_buffer->unpadded_size(); - // Mark the end of this buffer. - if (canonical_logical_buffer->span) { - canonical_logical_buffer->span->second = - heap_size_bytes_timeline.size() - 1; - } - return absl::OkStatus(); - } - - // Finalize the memory usage stats from heap simulator trace. - absl::Status FinalizeMemoryUsage() { - // Add the final heap size after simulating the entire heap trace. - heap_size_bytes_timeline.push_back(heap_size_bytes); - unpadded_heap_size_bytes_timeline.push_back(unpadded_heap_size_bytes); - // Add an empty instruction name just so that this array is the same size as - // the other two. - hlo_instruction_name_timeline.push_back(""); - - if (seen_buffer_allocations.size() != 1) { - return errors::InvalidArgument( - absl::StrCat("All heap simulation events should come from a single " - "buffer allocation, actual seen_buffer_allocations.size(): ", - seen_buffer_allocations.size())); - } - - // Log stats. - VLOG(1) << "Found " << peak_logical_buffers.size() - << " logical buffers alive at point of peak heap usage."; - - VLOG(1) << "Peak logical buffers: [" - << absl::StrJoin(peak_logical_buffers, ", ") << "]"; - - return absl::OkStatus(); - } - - // Keep track of memory usage when iterating through heap simulator trace - // events. - int64_t heap_size_bytes = 0; - int64_t unpadded_heap_size_bytes = 0; - // Memory usage at peak. - int64_t peak_heap_size_bytes = 0; - int64_t peak_unpadded_heap_size_bytes = 0; - - // Keep track of logical buffer IDs when iterating through heap simulator - // trace events. It is important that this is in "program order", i.e. heap - // simulator's order. - std::list<int64_t> logical_buffers; - // Logical buffer IDs at peak. - std::list<int64_t> peak_logical_buffers; - - // Heap size timeline.
- std::vector<int64_t> heap_size_bytes_timeline; - std::vector<int64_t> unpadded_heap_size_bytes_timeline; - std::vector<std::string> hlo_instruction_name_timeline; - - // Position of peak memory usage in the timeline. - int64_t peak_heap_size_position = 0; - - // Logical buffers and buffer allocations that exist in heap simulator trace. - absl::flat_hash_set<const LogicalBufferStruct*> seen_logical_buffers; - absl::flat_hash_set<const BufferAllocationProto*> seen_buffer_allocations; - - // Constants while iterating through heap simulator trace. - const HloProtoBufferWrapper& wrapper; - int64_t simulator_trace_event_size; -}; - -absl::Status ProcessHeapSimulatorTrace(const HloProtoBufferWrapper& wrapper, - const int64_t memory_color, - HeapSimulatorStats* stats) { - int64_t heap_simulator_trace_id = - wrapper.GetHeapSimulatorTraceId(memory_color); - - // If unable to get a valid heap simulator trace id, skip heap simulator - // trace and process the rest of the buffers. - if (heap_simulator_trace_id < 0 || - heap_simulator_trace_id >= wrapper.GetHloProto() - .buffer_assignment() - .heap_simulator_traces_size()) { - return absl::OkStatus(); - } - - // Run through all the simulator events in the given trace, and simulate the - // heap in order to find the point of peak memory usage and record its - // associated metadata. - const auto& trace = - wrapper.GetHloProto().buffer_assignment().heap_simulator_traces( - heap_simulator_trace_id); - - stats->SetSimulatorTraceEventSize(trace.events_size()); - for (const auto& event : trace.events()) { - stats->UpdateOnSimulatorEvent(event); - LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(event.buffer_id()); - if (logical_buffer == nullptr) { - continue; - } - if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { - // ALLOC event increases memory usage and initializes the buffer lifetime - // span. - logical_buffer->inc(); - stats->IncreaseMemoryUsage(logical_buffer, - /*init_buffer_span=*/true); - } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { - auto ref_count = logical_buffer->dec(); - if (ref_count < 0) { - return errors::InvalidArgument(absl::StrCat( - "Buffer ", logical_buffer->proto.id(), " is freed multiple times.")); - } - if (ref_count == 0) { - // There is no more reference to the canonical buffer, the canonical - // buffer is finally freed. Update memory usage and memory timespan - // using the metadata of canonical buffer. - auto& canonical_buffer = *logical_buffer->get_canonical_buffer(); - TF_RETURN_IF_ERROR(stats->DecreaseMemoryUsage(&canonical_buffer)); - } - } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { - int64_t canonical_buffer_id = event.share_with_canonical_id(); - LogicalBufferStruct* canonical_buffer = - wrapper.GetLogicalBuffer(canonical_buffer_id); - if (canonical_buffer == nullptr) { - continue; - } - auto ref_count = logical_buffer->share_with(canonical_buffer); - - if (ref_count == 1) { - // SHARE_WITH happens after the FREE of a canonical buffer. - // SHARE_WITH event does not initialize buffer lifetime span, it was - // initialized by ALLOC event using the canonical logical buffer. - stats->IncreaseMemoryUsage(canonical_buffer, - /*init_buffer_span=*/false); - } - } else { - return errors::InvalidArgument( - absl::StrCat("Unhandled event kind: ", event.kind())); - } - } - TF_RETURN_IF_ERROR(stats->FinalizeMemoryUsage()); - return absl::OkStatus(); -} -
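-// (Added note, not in the original file: ProcessHeapSimulatorTrace above -// assumes a well formed trace. For example, ALLOC(1), SHARE_WITH(2->1), -// FREE(1), FREE(2) charges the shared pair once at allocation and releases it -// on the final FREE, which is what keeps peak_heap_mib at a single buffer's -// size in the unit tests further down in this patch.) -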
-// The stats when processing buffer allocations and logical buffers. -struct PeakUsageSnapshot { - PeakUsageSnapshot(const HloProtoBufferWrapper& wrapper, - const HeapSimulatorStats& simulator_stats, - int64_t small_buffer_size) - : wrapper(wrapper), - simulator_stats(simulator_stats), - small_buffer_size(small_buffer_size) {} - - // Add a HeapObject derived from logical buffer and buffer allocation. - void AddHeapObject(const LogicalBufferStruct& logical_buffer) { - if (logical_buffer.size() < small_buffer_size) { - // Accumulate small buffers, don't make a HeapObject. - total_small_buffer_size_bytes += logical_buffer.size(); - } else { - // Make a new HeapObject, assign a new color to visualize it. - max_heap_objects.push_back(MakeHeapObject(logical_buffer, colorno++)); - } - } - - void FinalizeBufferUsage() { - // Buffers from HeapSimulatorTrace. - for (const int64_t logical_buffer_id : - simulator_stats.peak_logical_buffers) { - const LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(logical_buffer_id); - if (logical_buffer == nullptr) return; - AddHeapObject(*logical_buffer); - } - - // Make a single HeapObject out of all the small buffers. - if (total_small_buffer_size_bytes != 0) { - max_heap_objects.push_back(MakeHeapObjectCommon( - absl::StrFormat("small (<%d bytes)", small_buffer_size), colorno++, - /*logical_buffer_id=*/-1, total_small_buffer_size_bytes, - /*unpadded_shape_bytes=*/0)); - } - } - - // All the HeapObjects at peak memory time. - std::vector<HeapObject> max_heap_objects; - // The total size of all memory buffers with indefinite lifetime. - int64_t indefinite_memory_usage_bytes = 0; - // The accumulated size of all small buffers. - int64_t total_small_buffer_size_bytes = 0; - // Tracker of memory viewer color. - int32_t colorno = 0; - - const HloProtoBufferWrapper& wrapper; - const HeapSimulatorStats& simulator_stats; - const int64_t small_buffer_size; -}; - -void CreatePeakUsageSnapshot(const HloProtoBufferWrapper& wrapper, - int64_t memory_color, - PeakUsageSnapshot* peak_snapshot) { - // Add indefinite (global) buffers to peak usage snapshot. - for (const auto* logical_buffer : - wrapper.LogicalBuffersWithIndefiniteLifetime(memory_color)) { - const auto& buffer_allocation = logical_buffer->buffer_allocation; - peak_snapshot->indefinite_memory_usage_bytes += buffer_allocation.size(); - peak_snapshot->AddHeapObject(*logical_buffer); - } - - // Add temporary buffers (traced by heap simulator) to peak usage snapshot. - peak_snapshot->FinalizeBufferUsage(); -} - -void ConvertAllocationTimeline(const HloProtoBufferWrapper& wrapper, - const HeapSimulatorStats& simulator_stats, - const int64_t memory_color, - PreprocessResult* result) { - // The color constants from https://graphviz.org/doc/info/colors.html.
- const char* lb_colors[] = { - "antiquewhite3", - "aqua", - "aquamarine", - "bisque", - "blanchedalmond", - "blue", - "blueviolet", - "brown", - "burlywood", - "cadetblue", - "chartreuse", - "chocolate", - "coral", - "cornflowerblue", - "crimson", - "cyan", - "darkblue", - "darkcyan", - "darkgoldenrod", - "darkgray", - "darkgreen", - "darkkhaki", - "darkmagenta", - "darkolivegreen", - "darkorange", - "darkorchid", - "darkred", - "darksalmon", - "darkseagreen", - "darkslateblue", - "darkslategray", - "darkturquoise", - "darkviolet", - "deeppink", - "deepskyblue", - "dimgray", - "dodgerblue", - "firebrick", - "floralwhite", - "forestgreen", - "fuchsia", - "gainsboro", - "gold", - "goldenrod", - "green", - "greenyellow", - "honeydew", - "hotpink", - "indianred", - "indigo", - "ivory3", - "khaki", - "lavender", - "lavenderblush", - "lawngreen", - "lemonchiffon", - "lightblue", - "lightcoral", - "lightcyan", - "lightpink", - "limegreen", - "lightsalmon", - "lightseagreen", - "lightskyblue", - "lime", - "magenta", - "maroon", - "mediumaquamarine", - "mediumblue", - "mediumorchid", - "mediumpurple", - "midnightblue", - "mediumvioletred", - "mistyrose", - "moccasin", - "olive", - "orange", - "orangered", - "orchid", - "palegoldenrod", - "palegreen", - "paleturquoise", - "palevioletred", - "papayawhip", - "peachpuff", - "pink", - "plum", - "powderblue", - "purple", - "rebeccapurple", - "red", - "rosybrown", - "royalblue", - "salmon", - "sandybrown", - "seagreen", - "seashell", - "sienna", - "skyblue", - "tan", - "teal", - "turquoise", - "tomato", - "violet", - "violetred", - "yellow", - }; - - struct RenderOptions { - size_t graph_width = 2048; - size_t graph_height = 2048; - } render_options; - - const char* ba_colors[] = { - "azure", - "beige", - "cornsilk", - }; - - int num_lb_colors = sizeof(lb_colors) / sizeof(lb_colors[0]); - int num_ba_colors = sizeof(ba_colors) / sizeof(ba_colors[0]); - std::vector<size_t> buffer_allocation_offsets; - size_t total_y_size = 0; // Range of y dimension. - size_t total_x_size = 0; // Range of x dimension. - std::vector<std::string> rects; - auto buffer_allocations = wrapper.GetBufferAllocations(memory_color); - const auto& heap_simulator_traces = - wrapper.GetHloProto().buffer_assignment().heap_simulator_traces(); - for (const auto& buffer_allocation : buffer_allocations) { - // Exclude BAs for "global variables". The timeline provides little value.
- if (buffer_allocation->IsIndefinite()) continue; - auto heap_simulator_trace_id = buffer_allocation->heap_simulator_trace_id(); - if (!heap_simulator_trace_id) continue; - buffer_allocation_offsets.push_back(total_y_size); - total_y_size += buffer_allocation->size(); - if (*heap_simulator_trace_id >= heap_simulator_traces.size()) { - LOG(DFATAL) << "heap_simulator_trace_id " << *heap_simulator_trace_id - << " out of bounds."; - continue; - } - total_x_size = std::max<size_t>( - total_x_size, - heap_simulator_traces.at(*heap_simulator_trace_id).events_size()); - } - if (!total_y_size || !total_x_size) return; - double scale_x = - static_cast<double>(render_options.graph_width) / total_x_size; - double scale_y = - static_cast<double>(render_options.graph_height) / total_y_size; - - int node_id = 0; - auto add_rect = [&](size_t x, size_t y, size_t width, size_t height, - const string& description, const char* color) { - size_t center_x = x + (width >> 1); - size_t center_y = y + (height >> 1); - int pos_x = center_x * scale_x; - int pos_y = center_y * scale_y; - int rect_w = width * scale_x; - int rect_h = height * scale_y; - // Skip when block size is smaller than half a pixel in output size. - if (height * scale_y < 0.5) return; - rect_h = std::max(rect_h, 1); // Rounding up. - std::string rect = absl::StrFormat( - R"("%d" [tooltip="%s", pos="%d,%d!", width="%d!", height="%d!", color=%s];)", - node_id++, description, pos_x, pos_y, rect_w, rect_h, color); - rects.push_back(rect); - }; - int buffer_id = 0; - for (const auto& buffer_allocation : buffer_allocations) { - // Exclude BAs for "global variables". The timeline provides little value. - if (buffer_allocation->IsIndefinite()) continue; - auto buffer_allocation_offset = buffer_allocation_offsets[buffer_id++]; - add_rect(0, buffer_allocation_offset, total_x_size, - buffer_allocation->size(), buffer_allocation->description(), - ba_colors[buffer_id % num_ba_colors]); - - for (const auto& assigned : buffer_allocation->proto().assigned()) { - const LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(assigned.logical_buffer_id()); - if (logical_buffer == nullptr) continue; - // Exclude non-canonical logical buffers. - if (!logical_buffer->span || logical_buffer->canonical_buffer) continue; - size_t width = logical_buffer->span->second - logical_buffer->span->first; - size_t height = buffer_allocation_offset + logical_buffer->size(); - add_rect(logical_buffer->span->first, logical_buffer->offset, width, - height, logical_buffer->description(), - lb_colors[node_id % num_lb_colors]); - } - } - VLOG(1) << "rects: " << rects.size(); - result->set_allocation_timeline( - absl::StrFormat("graph G {\n node [shape=box,style=filled];\n %s\n}", - absl::StrJoin(rects, "\n"))); -} - -void GeneratePreprocessResult(const HloProtoBufferWrapper& wrapper, - const HeapSimulatorStats& simulator_stats, - const PeakUsageSnapshot& peak_snapshot, - const int64_t memory_color, - PreprocessResult* result) { - // Module info. - result->set_module_name(wrapper.GetHloProto().hlo_module().name()); - result->set_entry_computation_name( - wrapper.GetHloProto().hlo_module().entry_computation_name()); - - // Build HeapObjects and index.
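- // (Added worked example for the index maps built below, illustrative: if - // max_heap holds objects of sizes [1, 4, 2] MiB, then max_heap_by_size - // orders them [4, 2, 1], max_heap_to_by_size is [2, 0, 1] and - // by_size_to_max_heap is [1, 2, 0]; the two arrays are inverse permutations - // that let the UI flip between insertion order and size order.)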
- std::vector<const HeapObject*> max_heap_by_size; - max_heap_by_size.reserve(peak_snapshot.max_heap_objects.size()); - for (const auto& object : peak_snapshot.max_heap_objects) { - max_heap_by_size.push_back(&object); - } - std::sort(max_heap_by_size.begin(), max_heap_by_size.end(), - [](const HeapObject* a, const HeapObject* b) { - return a->logical_buffer_size_mib() > - b->logical_buffer_size_mib(); - }); - - std::vector<int> max_heap_to_by_size; - max_heap_to_by_size.reserve(max_heap_by_size.size()); - for (const auto& object : peak_snapshot.max_heap_objects) { - auto it = - std::find(max_heap_by_size.begin(), max_heap_by_size.end(), &object); - int index = std::distance(max_heap_by_size.begin(), it); - max_heap_to_by_size.push_back(index); - } - - std::vector<int> by_size_to_max_heap; - for (const auto* object : max_heap_by_size) { - int index = object - &peak_snapshot.max_heap_objects[0]; - by_size_to_max_heap.push_back(index); - } - - *result->mutable_max_heap() = {peak_snapshot.max_heap_objects.begin(), - peak_snapshot.max_heap_objects.end()}; - result->mutable_max_heap_by_size()->Reserve(max_heap_by_size.size()); - for (const HeapObject* o : max_heap_by_size) { - *result->add_max_heap_by_size() = *o; - } - *result->mutable_max_heap_to_by_size() = {max_heap_to_by_size.begin(), - max_heap_to_by_size.end()}; - *result->mutable_by_size_to_max_heap() = {by_size_to_max_heap.begin(), - by_size_to_max_heap.end()}; - - // For the buffers that have indefinite lifetime (that is, lifetime not - // reflected by the heap simulation) add them to the peak values and the - // vectors of heap sizes. - size_t timeline_size = simulator_stats.heap_size_bytes_timeline.size(); - double add_mib = BytesToMiB(peak_snapshot.indefinite_memory_usage_bytes); - result->mutable_heap_sizes()->Reserve(timeline_size); - result->mutable_unpadded_heap_sizes()->Reserve(timeline_size); - for (size_t i = 0; i < timeline_size; i++) { - result->add_heap_sizes( - BytesToMiB(simulator_stats.heap_size_bytes_timeline[i]) + add_mib); - result->add_unpadded_heap_sizes( - BytesToMiB(simulator_stats.unpadded_heap_size_bytes_timeline[i]) + - add_mib); - result->add_hlo_instruction_names( - simulator_stats.hlo_instruction_name_timeline[i]); - } - - result->set_peak_heap_mib(BytesToMiB(simulator_stats.peak_heap_size_bytes) + - add_mib); - result->set_peak_unpadded_heap_mib( - BytesToMiB(simulator_stats.peak_unpadded_heap_size_bytes) + add_mib); - result->set_peak_heap_size_position(simulator_stats.peak_heap_size_position); - - // Build buffer lifespan. - for (const auto* logical_buffer : simulator_stats.seen_logical_buffers) { - if (!logical_buffer->span) continue; - (*result->mutable_logical_buffer_spans())[logical_buffer->proto.id()] = - MakeBufferSpan(logical_buffer->span->first, - logical_buffer->span->second); - } - - NoteSpecialAllocations(wrapper, memory_color, peak_snapshot.small_buffer_size, - result); - - ConvertAllocationTimeline(wrapper, simulator_stats, memory_color, result); -} - -} // namespace - -absl::StatusOr<PreprocessResult> ConvertHloProtoToPreprocessResult( - const HloProto& hlo_proto, int64_t small_buffer_size, - int64_t memory_color) { - HloProtoBufferWrapper wrapper(hlo_proto); - - // Process heap simulator trace. - HeapSimulatorStats simulator_stats(wrapper); - auto status = - ProcessHeapSimulatorTrace(wrapper, memory_color, &simulator_stats); - if (!status.ok()) { - return absl::InvalidArgumentError(absl::StrCat( - "Failed to process heap simulator trace: ", status.message())); - } - - // Process buffers with indefinite lifetime.
- PeakUsageSnapshot peak_snapshot(wrapper, simulator_stats, small_buffer_size); - CreatePeakUsageSnapshot(wrapper, memory_color, &peak_snapshot); - - PreprocessResult result; - GeneratePreprocessResult(wrapper, simulator_stats, peak_snapshot, - memory_color, &result); - return result; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h index e7a681de51c393..d5a6061f90187a 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h +++ b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h @@ -16,29 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_ -#include <cstdint> - -#include "absl/status/statusor.h" -#include "xla/service/hlo.pb.h" -#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h" - -namespace tensorflow { -namespace profiler { - -constexpr int kSmallBufferSize = 16 * 1024; - -// Convert HloProto to PreprocessResult proto for memory visualization. -// small_buffer_size sets the byte size within which we collapse buffer entries -// for the max-heap display. -// <heap_simulator_trace_id> is the index of heap simulator trace to be -// displayed. By default it is -1, which means the profiler will infer the heap -// simulator trace id from <memory_color>. -// By default the memory color is 0, which is HBM. -absl::StatusOr<PreprocessResult> ConvertHloProtoToPreprocessResult( - const xla::HloProto& hlo_proto, - int64_t small_buffer_size = kSmallBufferSize, int64_t memory_color = 0); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/convert/hlo_proto_to_memory_visualization_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_ diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils_test.cc b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils_test.cc deleted file mode 100644 index d92dea32152a36..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils_test.cc +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h" - -#include <string> - -#include "absl/strings/str_format.h" -#include "xla/service/hlo.pb.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h" -#include "tensorflow/core/util/proto/proto_utils.h" - -namespace tensorflow { -namespace profiler { -namespace { - -// 1 buffer allocation of 1MB -// 2 logical buffers, each is 0.5MB -static constexpr char kHLOBase[] = R"pb( - hlo_module { - name: "test_module" - entry_computation_name: "test_computation" - computations { - name: "test_computation" - instructions { - name: "fusion.1" - id: 0 - shape { tuple_shapes { element_type: U64 } } - } - instructions { - name: "fusion.2" - id: 1 - shape { tuple_shapes { element_type: U64 } } - } - } - } - buffer_assignment { - buffer_allocations { - index: 0 - size: 1048576 - color: 0 - assigned { logical_buffer_id: 1 offset: 0 size: 524288 } - assigned { logical_buffer_id: 2 offset: 524288 size: 524288 } - } - logical_buffers { - id: 1 - size: 524288 - color: 0 - defined_at { instruction_id: 0 shape_index: 0 } - } - logical_buffers { - id: 2 - size: 524288 - color: 0 - defined_at { instruction_id: 1 shape_index: 0 } - } - heap_simulator_traces { %s } - } -)pb"; - -TEST(MemoryViewerTest, TestHeapSimulatorTraceShareWith_1) { - // Allocate and then share, the memory usage is not doubled. - static constexpr char kHeapSimulatorTrace[] = R"pb( - events { kind: ALLOC buffer_id: 1 } - events { kind: SHARE_WITH buffer_id: 2 share_with_canonical_id: 1 } - events { kind: FREE buffer_id: 1 } - events { kind: FREE buffer_id: 2 } - )pb"; - std::string hlo_string = absl::StrFormat(kHLOBase, kHeapSimulatorTrace); - xla::HloProto hlo_proto; - ASSERT_TRUE( - proto_utils::ParseTextFormatFromString(hlo_string, &hlo_proto).ok()); - TF_ASSERT_OK_AND_ASSIGN( - PreprocessResult preprocess_result, - ConvertHloProtoToPreprocessResult(hlo_proto, /*small_buffer_size=*/0)); - EXPECT_EQ(preprocess_result.peak_heap_mib(), 0.5); -} - -TEST(MemoryViewerTest, TestHeapSimulatorTraceShareWith_2) { - // Allocate, free and then share, the memory usage is not doubled. - static constexpr char kHeapSimulatorTrace[] = R"pb( - events { kind: ALLOC buffer_id: 1 } - events { kind: FREE buffer_id: 1 } - events { kind: SHARE_WITH buffer_id: 2 share_with_canonical_id: 1 } - events { kind: FREE buffer_id: 2 } - )pb"; - std::string hlo_string = absl::StrFormat(kHLOBase, kHeapSimulatorTrace); - xla::HloProto hlo_proto; - ASSERT_TRUE( - proto_utils::ParseTextFormatFromString(hlo_string, &hlo_proto).ok()); - TF_ASSERT_OK_AND_ASSIGN( - PreprocessResult preprocess_result, - ConvertHloProtoToPreprocessResult(hlo_proto, /*small_buffer_size=*/0)); - EXPECT_EQ(preprocess_result.peak_heap_mib(), 0.5); - EXPECT_FALSE(preprocess_result.allocation_timeline().empty()); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_to_tools_data.cc b/tensorflow/core/profiler/convert/hlo_to_tools_data.cc deleted file mode 100644 index 608bc6df8d0d71..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_to_tools_data.cc +++ /dev/null @@ -1,145 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/hlo_to_tools_data.h" - -#include <optional> -#include <string> -#include <utility> - -#include "absl/status/statusor.h" -#include "absl/strings/numbers.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h" -#include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/tool_options.h" -#include "tensorflow/core/profiler/convert/xplane_to_hlo.h" -#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -absl::StatusOr<PreprocessResult> GetMemoryViewerPreprocessResult( - const xla::HloProto& hlo_proto, int memory_space_color) { - static constexpr int kSmallBufferSize = 16 * 1024; // 16KB - - auto result_or = ConvertHloProtoToPreprocessResult( - hlo_proto, kSmallBufferSize, memory_space_color); - if (!result_or.ok()) { - return errors::Internal( - "Failed to convert HLO proto to memory viewer result: ", - result_or.status().message()); - } - return result_or; -} - -absl::StatusOr<std::string> ConvertHloProtoToMemoryViewer( - const xla::HloProto& hlo_proto, int memory_space_color) { - auto result_or = - GetMemoryViewerPreprocessResult(hlo_proto, memory_space_color); - if (!result_or.ok()) { - return result_or.status(); - } - - std::string json_output; - tsl::protobuf::util::JsonPrintOptions options; - options.always_print_primitive_fields = true; - auto encoded_status = tsl::protobuf::util::MessageToJsonString( - result_or.value(), &json_output, options); - if (!encoded_status.ok()) { - const auto& error_message = encoded_status.message(); - return errors::Internal( - "Failed to convert memory viewer result to JSON format: ", - absl::string_view(error_message.data(), error_message.length())); - } - - return json_output; -} - -absl::StatusOr<std::string> ConvertHloProtoToAllocationTimeline( - const xla::HloProto& hlo_proto, int memory_space_color) { - auto result_or = - GetMemoryViewerPreprocessResult(hlo_proto, memory_space_color); - if (!result_or.ok()) { - return result_or.status(); - } - - return WrapDotInHtml(std::move(result_or.value().allocation_timeline()), - "neato"); -} - -absl::StatusOr<std::string> ConvertHloProtoToGraphViewer( - const xla::HloProto& hlo_proto, const ToolOptions& options) { - TF_ASSIGN_OR_RETURN(GraphViewerParams params, - ParseGraphViewerParams(options)); - if (params.type == "graph") { - return ConvertHloProtoToGraph(hlo_proto, params.node_name, - params.graph_width, params.render_options, - params.format); - } else { - return ConvertHloProtoToStringView(hlo_proto, params.verbose, - params.show_metadata); - } -} - -} // namespace - -absl::StatusOr<std::string> ConvertHloProtoToToolData( - const
SessionSnapshot& session_snapshot, const absl::string_view tool_name, - const ToolOptions& options) { - // <options> must provide a hlo module_name field to identify the HLO module. - std::optional<std::string> hlo_module_name = - GetParam<std::string>(options, "module_name"); - if (!hlo_module_name.has_value() || hlo_module_name->empty()) { - return errors::InvalidArgument( - "Can not find HLO module name from options."); - } - - // Load HLO module from file. - TF_ASSIGN_OR_RETURN( - xla::HloProto hlo_proto, - GetHloProtoByModuleName(session_snapshot, *hlo_module_name)); - - // Convert from HLO proto to tools data. - int memory_space_color = 0; - if (!absl::SimpleAtoi( - GetParamWithDefault(options, "memory_space", std::string("0")), - &memory_space_color)) { - memory_space_color = 0; - } - - if (tool_name == "memory_viewer") { - if (GetParamWithDefault(options, "view_memory_allocation_timeline", 0)) { - return ConvertHloProtoToAllocationTimeline(hlo_proto, memory_space_color); - } - return ConvertHloProtoToMemoryViewer(hlo_proto, memory_space_color); - } else if (tool_name == "graph_viewer") { - return ConvertHloProtoToGraphViewer(hlo_proto, options); - } else { - return errors::InvalidArgument( - "Can not find tool: ", tool_name, - ". Please update to the latest version of Tensorflow."); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_to_tools_data.h b/tensorflow/core/profiler/convert/hlo_to_tools_data.h deleted file mode 100644 index b567c973382997..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_to_tools_data.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_ - -#include <string> - -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/tool_options.h" - -namespace tensorflow { -namespace profiler { - -// Convert HLO proto to tool specific data. -// <options> must provide a "module_name" field to identify which HLO proto -// is used for the conversion. -// Return the serialized string of tool specific data when the conversion is -// successful, else return an error status. -absl::StatusOr<std::string> ConvertHloProtoToToolData( - const SessionSnapshot& session_snapshot, absl::string_view tool_name, - const ToolOptions& options); - -} // namespace profiler -} // namespace tensorflow - -#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_ diff --git a/tensorflow/core/profiler/convert/inference_stats.cc b/tensorflow/core/profiler/convert/inference_stats.cc deleted file mode 100644 index 904e97b1615735..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats.cc +++ /dev/null @@ -1,1510 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors.
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/inference_stats.h" - -#include <algorithm> -#include <array> -#include <cstdint> -#include <functional> -#include <memory> -#include <optional> -#include <string> -#include <utility> -#include <vector> - -#include "absl/algorithm/container.h" -#include "absl/base/macros.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/tsl/platform/logging.h" -#include "xla/tsl/profiler/utils/device_utils.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tsl/platform/protobuf.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "xprof/utils/event_span.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tensorflow::profiler::EventType; -using ::tensorflow::profiler::EventTypeSpan; -using ::tensorflow::profiler::StepEvents; -using ::tensorflow::profiler::ToNonOverlappedEvents; -using ::tsl::profiler::CreateTfXPlaneVisitor; -using ::tsl::profiler::DeviceType; -using ::tsl::profiler::GroupMetadata; -using ::tsl::profiler::GroupMetadataMap; -using ::tsl::profiler::HostEventType; -using ::tsl::profiler::StatType; -using ::tsl::profiler::Timespan; -using ::tsl::profiler::XEventVisitor; -using ::tsl::profiler::XLineVisitor; -using ::tsl::profiler::XPlane; -using ::tsl::profiler::XPlaneVisitor; -using ::tsl::profiler::XSpace; -using ::tsl::profiler::XStatVisitor; - -using EventsByType = - absl::flat_hash_map<int64_t, std::vector<XEventVisitor>>; - -// Holds all the events within a user facing request. -// A user facing request can be a Session.Run without batching, or a -// BatchingSession.Run with Batching, or a Session.Run with -// BatchingFunctionOp. -struct RequestEvents { - // Index to the model id. - int32_t model_id_index; - // The timespan of the entire request (including both host and device). - Timespan request_timespan; - // The latency between when a request is scheduled and when it is processed - // in a batch. - int64_t batching_request_delay_ps; - // Size of a request in batching mode. - int32_t batching_request_size; - - // Timestamps of the events used for the detailed execution time breakdown.
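- // (Added, illustrative: ts_batch_schedule marks when a request was queued - // for batching and ts_tpu_execute when the TPU runtime picked the batch up; - // gaps between such timestamps are what feed fields like - // batching_request_delay_ps.)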
- struct EventTimestamps { - std::optional<int64_t> ts_batch_schedule; - std::optional<int64_t> ts_batch_concat_input; - std::optional<int64_t> ts_tpu_execute; - std::optional<int64_t> ts_tpu_program_launch; - std::optional<int64_t> ts_tpu_complete_callback; - }; - // Mapping from group ID to the timestamps, there can be multiple group IDs - // in a single request, because if request splitting is enabled, one request - // can be split to multiple batches for execution, and each batch has - // different group ID. - absl::flat_hash_map<int64_t, EventTimestamps> timestamps; - - // The events that record tensor details like shape, type and layout. - std::vector<XEventVisitor> tensor_events; - // The final tensor details in proto format. - std::vector<tensorflow::profiler::TensorEventDetail> - tensor_event_detail_protos; - - // The batch ids related to this request. - std::vector<int64_t> related_batch_ids; - // All the events. - std::vector<EventTypeSpan> events; -}; - -// Helper functions to handle std::optional -void MinOfOptional(std::optional<int64_t>& min, std::optional<int64_t> value) { - if (!min.has_value()) - min = value; - else - min = std::min(min, value); -} -void MaxOfOptional(std::optional<int64_t>& max, std::optional<int64_t> value) { - if (!max.has_value()) - max = value; - else - max = std::max(max, value); -} - -// Helper functions to set timestamps in RequestEvents. -void UpdateTsBatchSchedule(int64_t group_id, int64_t value, - RequestEvents* events) { - events->timestamps[group_id].ts_batch_schedule = value; -} -void UpdateTsBatchConcatInput(int64_t group_id, int64_t value, - RequestEvents* events) { - events->timestamps[group_id].ts_batch_concat_input = value; -} -void UpdateTsTPUExecute(int64_t group_id, int64_t value, - RequestEvents* events) { - events->timestamps[group_id].ts_tpu_execute = value; -} -void UpdateTsTPUProgramLaunch(int64_t group_id, int64_t value, - RequestEvents* events) { - // There might be multiple TPUProgramLaunch events in a single request. - // Set ts_tpu_program_launch to the earliest timestamp. - MinOfOptional(events->timestamps[group_id].ts_tpu_program_launch, value); -} -void UpdateTsTPUCompleteCallback(int64_t group_id, int64_t value, - RequestEvents* events) { - events->timestamps[group_id].ts_tpu_complete_callback = value; -} - -// Map from the ID of a request to its events. -using RequestEventsMap = - absl::flat_hash_map<int64_t, RequestEvents>; - -// An internal data structure that holds all the events within a batch. -struct BatchEvents { - // The events that record tensor details like shape, type and layout. - std::vector<XEventVisitor> tensor_events; - - // The BatchDetail proto. - tensorflow::profiler::BatchDetail batch_detail_proto; - - // All the events. - std::vector<EventTypeSpan> events; -}; - -// Map from the ID of a batch to its events. -using BatchEventsMap = absl::flat_hash_map<int64_t, BatchEvents>; - -// Map from the ID of a request to its model ID. -using ModelIdMap = absl::flat_hash_map<int64_t, std::string>; - -int32_t AssignIndexToModelId( - const std::string& model_id, - tensorflow::profiler::ModelIdDatabase* model_id_db) { - if (model_id.empty()) return -1; - auto [iter, inserted] = model_id_db->mutable_id_to_index()->insert( - {model_id, model_id_db->ids_size()}); - if (inserted) { - model_id_db->add_ids(model_id); - } - return iter->second; -} - -// Updates timestamps in RequestEvents. -// <function> selects the timestamp to update, <value> is the updated value. -void UpdateEventTimestamps( - const GroupMetadataMap& group_metadata_map, int64_t group_id, int64_t value, - std::function<void(int64_t, int64_t, RequestEvents*)> function, - RequestEventsMap* request_events_map, - BatchEventsMap* batch_events_map = nullptr) { - // Update RequestEvents that are directly associated with <group_id>.
- if (request_events_map != nullptr) { - if (auto request_events = gtl::FindOrNull(*request_events_map, group_id)) { - function(group_id, value, request_events); - } - - // Update all the parent RequestEvents of <group_id>. - const GroupMetadata* group_metadata = - gtl::FindOrNull(group_metadata_map, group_id); - if (!group_metadata) return; - for (const int64_t parent_group_id : group_metadata->parents) { - if (auto parent_request_events = - gtl::FindOrNull(*request_events_map, parent_group_id)) { - // Update parent events, but still use <group_id> instead of - // <parent_group_id>, because xprof needs to track where these event - // timestamps originally come from. - function(group_id, value, parent_request_events); - } - } - } - // Note: Timestamp updates for batch analysis are not supported yet. -} - -void UpdateBatchEvents(const GroupMetadataMap& group_metadata_map, - absl::Span<const EventTypeSpan> events, int64_t group_id, - BatchEventsMap* batch_events_map) { - // Update BatchEvents that are directly associated with <group_id>. - if (auto batch_events = gtl::FindOrNull(*batch_events_map, group_id)) { - batch_events->events.insert(batch_events->events.end(), events.begin(), - events.end()); - } -} - -// Updates RequestEvents using ReadFromDevice, WriteToDevice and DeviceRun. -void UpdateRequestEvents(const GroupMetadataMap& group_metadata_map, - absl::Span<const EventTypeSpan> events, - int64_t group_id, - RequestEventsMap* request_events_map) { - // Update RequestEvents that are directly associated with <group_id>. - if (auto request_events = gtl::FindOrNull(*request_events_map, group_id)) { - request_events->events.insert(request_events->events.end(), events.begin(), - events.end()); - } - - // Update all the parent RequestEvents of <group_id> with the same <events> - // and <group_id>. Parent RequestEvents are all the requests - // in a batch. - const GroupMetadata* group_metadata = - gtl::FindOrNull(group_metadata_map, group_id); - if (!group_metadata) return; - for (const int64_t parent_group_id : group_metadata->parents) { - if (auto parent_request_events = - gtl::FindOrNull(*request_events_map, parent_group_id)) { - parent_request_events->events.insert(parent_request_events->events.end(), - events.begin(), events.end()); - } - } -} - -// Initializes RequestEvents. -// <is_batching_request> determines whether this event is a -// BatchingSession.Run -void InitializeRequestEvents( - const XEventVisitor& event, const GroupMetadataMap& group_metadata_map, - const absl::flat_hash_set<int64_t>& process_batch_group_ids, - const ModelIdMap& model_id_map, bool is_batching_request, - bool is_user_defined_request, - tensorflow::profiler::ModelIdDatabase* model_id_db, - RequestEventsMap* request_events_map) { - std::optional<XStatVisitor> optional_group_id = - event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) return; - int64_t group_id = optional_group_id->IntValue(); - - // If the event has ProcessBatch event as a parent, then do not consider - // it as a request. - if (process_batch_group_ids.contains(group_id)) return; - - RequestEvents& request_events = (*request_events_map)[group_id]; - const GroupMetadata* group_metadata = - gtl::FindOrNull(group_metadata_map, group_id); - if (!group_metadata) return; - // The children group_ids of a request are the batches related to this - // request. - for (const int64_t child_group_id : group_metadata->children) { - request_events.related_batch_ids.push_back(child_group_id); - } - // Sort related_batch_ids to get a deterministic result.
- absl::c_sort(request_events.related_batch_ids); - if (is_batching_request) { - // The children events of BatchingSession.Run are multiple Session.Run, - // use the first child event to initialize ModelId information, because - // all the children events should have the same ModelId. - if (group_metadata->children.empty()) return; - int64_t children_group_id = *group_metadata->children.begin(); - const std::string* children_model_id = - gtl::FindOrNull(model_id_map, children_group_id); - request_events.model_id_index = AssignIndexToModelId( - children_model_id ? *children_model_id : "", model_id_db); - } else if (is_user_defined_request) { - const std::string* model_id = gtl::FindOrNull(model_id_map, group_id); - if (model_id) { - request_events.model_id_index = - AssignIndexToModelId(*model_id, model_id_db); - } else { - // In some cases (e.g., BrainServer::Estimate), a single request might - // dispatch batches for multiple models. If all children events - // have the same ModelId, we assign that ModelId to the request. - if (group_metadata->children.empty()) return; - int32_t model_id_index_for_all_children = -1; - bool all_children_have_same_model_id = true; - for (int64_t children_group_id : group_metadata->children) { - const std::string* children_model_id = - gtl::FindOrNull(model_id_map, children_group_id); - int32_t child_model_id_index = AssignIndexToModelId( - children_model_id ? *children_model_id : "", model_id_db); - if (model_id_index_for_all_children == -1) { - model_id_index_for_all_children = child_model_id_index; - } else if (child_model_id_index != model_id_index_for_all_children) { - all_children_have_same_model_id = false; - } - } - request_events.model_id_index = - all_children_have_same_model_id - ? model_id_index_for_all_children - : AssignIndexToModelId("", model_id_db); - } - } else { - const std::string* model_id = gtl::FindOrNull(model_id_map, group_id); - request_events.model_id_index = - AssignIndexToModelId(model_id ? *model_id : "", model_id_db); - } -} - -// Set the begin and end timestamp of the request. -// The timespan of the request is marked by the earliest timestamp and latest -// timestamp of the events with the same group_id. -void UpdateRequestTimespan(const EventsByType& host_events_by_type, - RequestEventsMap* request_events_map) { - for (const auto& [_, events] : host_events_by_type) { - for (const auto& event : events) { - auto optional_group_id = event.GetStat(StatType::kGroupId); - if (optional_group_id.has_value()) { - if (RequestEvents* request = gtl::FindOrNull( - *request_events_map, optional_group_id->IntValue())) { - auto begin_ps = request->request_timespan.begin_ps() == 0 - ? event.GetTimespan().begin_ps() - : std::min(request->request_timespan.begin_ps(), - event.GetTimespan().begin_ps()); - auto end_ps = std::max(request->request_timespan.end_ps(), - event.GetTimespan().end_ps()); - request->request_timespan = Timespan::FromEndPoints(begin_ps, end_ps); - } - } - } - } -} - -// Update RequestEventsMap using data transfer events in tpu::system. -// Each data transfer is associated with a start event, an end event, and a -// transfer type (H2D or D2H). 
-void UpdateTpuDataTransferEventsInTpuSystem(
-    const EventsByType& host_events_by_type,
-    const GroupMetadataMap& group_metadata_map,
-    const HostEventType data_transfer_start_event,
-    const HostEventType data_transfer_end_event,
-    const EventType data_transfer_type, RequestEventsMap* request_events_map,
-    BatchEventsMap* batch_events_map) {
-  absl::flat_hash_map<uint64_t, std::array<const XEventVisitor*, 2>>
-      events_per_transfer;
-
-  auto build_events =
-      [&](const HostEventType event_type,
-          std::function<void(uint64_t, const XEventVisitor*)> func) {
-        if (const auto* events =
-                gtl::FindOrNull(host_events_by_type, event_type)) {
-          for (const XEventVisitor& event : *events) {
-            std::optional<XStatVisitor> optional_group_id =
-                event.GetStat(StatType::kGroupId);
-            if (!optional_group_id.has_value()) continue;
-            std::optional<XStatVisitor> context_id =
-                event.GetStat(StatType::kConsumerId);
-            if (!context_id.has_value()) continue;
-            func(context_id->IntValue(), &event);
-          }
-        }
-      };
-
-  // Build start events.
-  build_events(data_transfer_start_event,
-               [&](uint64_t id, const XEventVisitor* start_event) {
-                 events_per_transfer[id] = {start_event, nullptr};
-               });
-
-  // Build end events.
-  // An end event is recorded only when the start event exists, the end event
-  // has the same transfer context ID as the start event, and the end event
-  // timestamp is larger than the start event timestamp.
-  build_events(data_transfer_end_event,
-               [&](uint64_t id, const XEventVisitor* end_event) {
-                 if (auto* value = gtl::FindOrNull(events_per_transfer, id)) {
-                   const XEventVisitor* start_event = value->at(0);
-                   if (start_event->TimestampPs() < end_event->TimestampPs()) {
-                     value->at(1) = end_event;
-                   }
-                 }
-               });
-
-  std::vector<EventTypeSpan> event_to_update = {
-      {data_transfer_type, Timespan(0, 0)}};
-  for (const auto& [id, events] : events_per_transfer) {
-    if (events[0] != nullptr && events[1] != nullptr) {
-      // Duration of the data transfer is measured as the timespan between
-      // start and end events.
-      event_to_update[0].span =
-          Timespan(events[0]->TimestampPs(),
-                   events[1]->EndTimestampPs() - events[0]->TimestampPs());
-      if (request_events_map != nullptr) {
-        UpdateRequestEvents(group_metadata_map, event_to_update,
-                            events[0]->GetStat(StatType::kGroupId)->IntValue(),
-                            request_events_map);
-      }
-      if (batch_events_map != nullptr) {
-        UpdateBatchEvents(group_metadata_map, event_to_update,
-                          events[0]->GetStat(StatType::kGroupId)->IntValue(),
-                          batch_events_map);
-      }
-    }
-  }
-}
-
-// Initializes device side events for TPU.
-void BuildTPUDeviceEvents(const std::vector<const XPlane*>& device_traces,
-                          const EventsByType& host_events_by_type,
-                          const GroupMetadataMap& group_metadata_map,
-                          RequestEventsMap* request_events_map,
-                          BatchEventsMap* batch_events_map) {
-  static constexpr int64_t kDataTransferTypes[] = {
-      HostEventType::kReadHbm, HostEventType::kTransferD2HRequest,
-      HostEventType::kWriteHbm, HostEventType::kTransferH2DRequest,
-      HostEventType::kTransferPreprocessedH2DRequest};
-  auto data_transfer_type_to_enum = [](const int64_t type) {
-    switch (type) {
-      case HostEventType::kReadHbm:
-      case HostEventType::kTransferD2HRequest:
-        return EventType::DEVICE_TO_HOST;
-      case HostEventType::kWriteHbm:
-      case HostEventType::kTransferH2DRequest:
-      case HostEventType::kTransferPreprocessedH2DRequest:
-        return EventType::HOST_TO_DEVICE;
-      default:
-        return EventType::UNKNOWN_TIME;
-    }
-  };
-
-  // Initialize a TPU device event for future updates.
-  // In order to reuse the same UpdateRequestEvents function with GPU device
-  // events, here we create a vector of size 1 for the TPU event.
-  std::vector<EventTypeSpan> event_to_update = {
-      {EventType::UNKNOWN_TIME, Timespan(0, 0)}};
-
-  // Update RequestEventsMap using data transfer events.
-  for (const int64_t data_transfer_type : kDataTransferTypes) {
-    if (const auto* data_transfer_events =
-            gtl::FindOrNull(host_events_by_type, data_transfer_type)) {
-      for (const XEventVisitor& data_transfer_event : *data_transfer_events) {
-        std::optional<XStatVisitor> optional_group_id =
-            data_transfer_event.GetStat(StatType::kGroupId);
-        if (!optional_group_id.has_value()) continue;
-        int64_t group_id = optional_group_id->IntValue();
-        event_to_update[0] = {data_transfer_type_to_enum(data_transfer_type),
-                              data_transfer_event.GetTimespan()};
-        if (request_events_map != nullptr) {
-          UpdateRequestEvents(group_metadata_map, event_to_update, group_id,
-                              request_events_map);
-        }
-        if (batch_events_map != nullptr) {
-          UpdateBatchEvents(group_metadata_map, event_to_update, group_id,
-                            batch_events_map);
-        }
-      }
-    }
-  }
-
-  UpdateTpuDataTransferEventsInTpuSystem(
-      host_events_by_type, group_metadata_map,
-      HostEventType::kTransferToDeviceIssueEvent,
-      HostEventType::kTransferToDeviceDone, EventType::HOST_TO_DEVICE,
-      request_events_map, batch_events_map);
-
-  UpdateTpuDataTransferEventsInTpuSystem(
-      host_events_by_type, group_metadata_map,
-      HostEventType::kTransferFromDeviceIssueEvent,
-      HostEventType::kTransferFromDeviceDone, EventType::DEVICE_TO_HOST,
-      request_events_map, batch_events_map);
-
-  for (const XPlane* device_trace : device_traces) {
-    XPlaneVisitor device_plane = CreateTfXPlaneVisitor(device_trace);
-    device_plane.ForEachLine([request_events_map, batch_events_map,
-                              &event_to_update,
-                              &group_metadata_map](const XLineVisitor& line) {
-      if (line.Name() != tsl::profiler::kXlaModuleLineName) return;
-      line.ForEachEvent([request_events_map, batch_events_map,
-                         &event_to_update,
-                         &group_metadata_map](const XEventVisitor& event) {
-        std::optional<XStatVisitor> group_id =
-            event.GetStat(StatType::kGroupId);
-        if (!group_id) return;
-        // TPU compute does not specify 32bit or 16bit; use
-        // DEVICE_COMPUTE_32 to annotate that this is a compute event.
-        event_to_update[0] = {EventType::DEVICE_COMPUTE_32,
-                              event.GetTimespan()};
-        if (request_events_map != nullptr) {
-          UpdateRequestEvents(group_metadata_map, event_to_update,
-                              group_id->IntValue(), request_events_map);
-        }
-        if (batch_events_map != nullptr) {
-          UpdateBatchEvents(group_metadata_map, event_to_update,
-                            group_id->IntValue(), batch_events_map);
-        }
-      });
-    });
-  }
-
-  // Update the timestamp for TPU execute events. It is used as the beginning
-  // of the TPU runtime. For the old TPU runtime, these are the
-  // TPUPartitionedCall events; for the new TPU runtime, it is the
-  // tpu::system::Execute event. There might be multiple TPU execute events in
-  // the same request, so UpdateTsTPUExecute is implemented as taking the
-  // earliest timestamp of the TPU execute events.
-  static constexpr int64_t kTPUExecuteTypes[] = {
-      HostEventType::kTpuPartitionedCallOpExecuteLocal,
-      HostEventType::kTpuPartitionedCallOpExecuteRemote,
-      HostEventType::kTpuPartitionedCallOpInitializeVarOnTpu,
-      HostEventType::kTpuSystemExecute};
-  for (const int64_t tpu_execute_type : kTPUExecuteTypes) {
-    if (const auto* tpu_execute_events =
-            gtl::FindOrNull(host_events_by_type, tpu_execute_type)) {
-      for (const XEventVisitor& tpu_execute_event : *tpu_execute_events) {
-        std::optional<XStatVisitor> optional_group_id =
-            tpu_execute_event.GetStat(StatType::kGroupId);
-        if (!optional_group_id.has_value()) continue;
-        int64_t group_id = optional_group_id->IntValue();
-        UpdateEventTimestamps(
-            group_metadata_map, group_id, tpu_execute_event.TimestampPs(),
-            UpdateTsTPUExecute, request_events_map, batch_events_map);
-      }
-    }
-  }
-
-  // Update the timestamp for TPU program launch events. This is used as the
-  // end of the TPU runtime. Only one of the following program launch events
-  // will appear in a single profile.
-  static constexpr int64_t kTPUProgramLaunchTypes[] = {
-      HostEventType::kDoEnqueueProgram,
-      HostEventType::kDoEnqueueContinuationProgram};
-  for (const int64_t tpu_program_launch_type : kTPUProgramLaunchTypes) {
-    if (const auto* tpu_program_launch_events =
-            gtl::FindOrNull(host_events_by_type, tpu_program_launch_type)) {
-      for (const XEventVisitor& tpu_program_launch_event :
-           *tpu_program_launch_events) {
-        std::optional<XStatVisitor> optional_group_id =
-            tpu_program_launch_event.GetStat(StatType::kGroupId);
-        if (!optional_group_id.has_value()) continue;
-        int64_t group_id = optional_group_id->IntValue();
-        UpdateEventTimestamps(group_metadata_map, group_id,
-                              tpu_program_launch_event.TimestampPs(),
-                              UpdateTsTPUProgramLaunch, request_events_map,
-                              batch_events_map);
-      }
-    }
-  }
-
-  // Update the timestamp for TPU complete callbacks. This is used as the
-  // start of host postprocessing.
-  if (const auto* tpu_complete_callback_events = gtl::FindOrNull(
-          host_events_by_type, HostEventType::kCompleteCallbacks)) {
-    for (const XEventVisitor& tpu_complete_callback_event :
-         *tpu_complete_callback_events) {
-      std::optional<XStatVisitor> optional_group_id =
-          tpu_complete_callback_event.GetStat(StatType::kGroupId);
-      if (!optional_group_id.has_value()) continue;
-      int64_t group_id = optional_group_id->IntValue();
-      UpdateEventTimestamps(group_metadata_map, group_id,
-                            tpu_complete_callback_event.TimestampPs(),
-                            UpdateTsTPUCompleteCallback, request_events_map,
-                            batch_events_map);
-    }
-  }
-}
-
-// Initializes device side events for GPU.
-void BuildGPUDeviceEvents(const StepEvents& nonoverlapped_step_events,
-                          const GroupMetadataMap& group_metadata_map,
-                          RequestEventsMap* request_events_map,
-                          BatchEventsMap* batch_events_map) {
-  if (request_events_map != nullptr) {
-    for (const auto& [step_id, step_details] : nonoverlapped_step_events) {
-      UpdateRequestEvents(group_metadata_map, step_details.Events(), step_id,
-                          request_events_map);
-    }
-  }
-  if (batch_events_map != nullptr) {
-    for (const auto& [step_id, step_details] : nonoverlapped_step_events) {
-      UpdateBatchEvents(group_metadata_map, step_details.Events(), step_id,
-                        batch_events_map);
-    }
-  }
-}
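[Editorial note: the GPU path above assumes the step events are already non-overlapped in time. As a rough, self-contained illustration of what such flattening means, the sketch below merges possibly-overlapping picosecond intervals into a disjoint union; the Interval struct is a made-up stand-in for the profiler's Timespan/EventTypeSpan types.]

    // Illustration only: flatten possibly-overlapping [begin, end) intervals
    // into a disjoint union, similar in spirit to "non-overlapped" step events.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct Interval { uint64_t begin_ps, end_ps; };  // hypothetical stand-in

    std::vector<Interval> Flatten(std::vector<Interval> spans) {
      std::sort(spans.begin(), spans.end(),
                [](const Interval& a, const Interval& b) {
                  return a.begin_ps < b.begin_ps;
                });
      std::vector<Interval> out;
      for (const Interval& s : spans) {
        if (!out.empty() && s.begin_ps <= out.back().end_ps) {
          // Overlapping or adjacent: extend the previous interval.
          out.back().end_ps = std::max(out.back().end_ps, s.end_ps);
        } else {
          out.push_back(s);
        }
      }
      return out;
    }

    int main() {
      for (const Interval& i : Flatten({{0, 10}, {5, 20}, {30, 40}})) {
        std::cout << i.begin_ps << " " << i.end_ps << "\n";  // 0 20, 30 40
      }
    }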
-// Initializes the mapping from group_id to model_id. Skips the event if it
-// doesn't have a group_id or model_id.
-ModelIdMap InitializeModelIdMap(
-    const EventsByType& host_events_by_type,
-    const std::vector<const XEventVisitor*>& user_defined_root_events) {
-  ModelIdMap model_id_map;
-
-  // Helper function to process the model id.
-  auto process_model_id = [&](const XEventVisitor& event) {
-    auto group_id = event.GetStat(StatType::kGroupId);
-    if (!group_id.has_value()) return;
-    std::optional<XStatVisitor> model_id = event.GetStat(StatType::kModelId);
-    if (!model_id.has_value()) return;
-    model_id_map[group_id->IntValue()] = model_id->ToString();
-  };
-
-  static constexpr int64_t kModelIdRequestTypes[] = {
-      HostEventType::kSessionRun, HostEventType::kTfrtModelRun,
-      HostEventType::kServingModelRun};
-  for (const int64_t event_type : kModelIdRequestTypes) {
-    auto event_list = gtl::FindOrNull(host_events_by_type, event_type);
-    if (!event_list) continue;
-    for (const XEventVisitor& event : *event_list) {
-      process_model_id(event);
-    }
-  }
-
-  for (const XEventVisitor* event : user_defined_root_events) {
-    process_model_id(*event);
-  }
-
-  return model_id_map;
-}
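[Editorial note: the kModelId stat consumed above is attached to the request event on the serving side. A hedged sketch of how a client could attach it with TraceMe/TraceMeEncode (both are real TSL tracing helpers, though the event name, model name, and include paths here are assumptions and may differ by TensorFlow version):]

    // Illustration only: annotating a request so the profiler can recover its
    // model_id. "SessionRun" and "my_model:1" are invented values.
    #include "tsl/profiler/lib/traceme.h"
    #include "tsl/profiler/lib/traceme_encode.h"

    void HandleRequest() {
      tsl::profiler::TraceMe trace([] {
        return tsl::profiler::TraceMeEncode(
            "SessionRun", {{"model_id", "my_model:1"}});  // read as kModelId
      });
      // ... run the model ...
    }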
-// Builds a request_events_map from the given trace events.
-void BuildRequestEventsMap(const std::vector<const XPlane*>& device_traces,
-                           const EventsByType& host_events_by_type,
-                           const GroupMetadataMap& group_metadata_map,
-                           const StepEvents& nonoverlapped_step_events,
-                           DeviceType device_type,
-                           tensorflow::profiler::ModelIdDatabase* model_id_db,
-                           RequestEventsMap* request_events_map) {
-  static constexpr int64_t kBatchingRequestTypes[] = {
-      HostEventType::kBatchingSessionRun};
-  static constexpr int64_t kNonBatchingRequestTypes[] = {
-      HostEventType::kSessionRun, HostEventType::kRunGraph};
-  // TODO(wffw): Merge them once go/pathways-tfrt-serving-unification is done.
-  static constexpr int64_t kTfrtRequestTypes[] = {HostEventType::kTfrtModelRun};
-  static constexpr int64_t kPathwayRequestTypes[] = {
-      HostEventType::kServingModelRun};
-
-  static constexpr int64_t kScheduleEventTypes[] = {
-      HostEventType::kScheduleWithSplit, HostEventType::kScheduleWithoutSplit,
-      HostEventType::kScheduleWithEagerSplit,
-      HostEventType::kASBSQueueSchedule};
-
-  // Events marked with "_r:-1" are user defined root events.
-  std::vector<const XEventVisitor*> user_defined_root_events;
-  for (const auto& [_, events] : host_events_by_type) {
-    for (const auto& event : events) {
-      std::optional<XStatVisitor> stat = event.GetStat(StatType::kIsRoot);
-      if (stat.has_value() && stat->IntValue() == -1) {
-        user_defined_root_events.push_back(&event);
-      }
-    }
-  }
-
-  // Group IDs of ProcessBatch events.
-  absl::flat_hash_set<int64_t> process_batch_group_ids;
-  if (const auto* process_batch_events =
-          gtl::FindOrNull(host_events_by_type, HostEventType::kProcessBatch)) {
-    for (const XEventVisitor& process_batch_event : *process_batch_events) {
-      std::optional<XStatVisitor> optional_group_id =
-          process_batch_event.GetStat(StatType::kGroupId);
-      if (!optional_group_id.has_value()) continue;
-      process_batch_group_ids.insert(optional_group_id->IntValue());
-    }
-  }
-
-  ModelIdMap model_id_map =
-      InitializeModelIdMap(host_events_by_type, user_defined_root_events);
-
-  // Initialize RequestEventsMap.
-  bool is_batching_request =
-      host_events_by_type.contains(HostEventType::kBatchingSessionRun);
-  bool is_tfrt_request =
-      host_events_by_type.contains(HostEventType::kTfrtModelRun);
-  // TODO(wffw): Merge them once go/pathways-tfrt-serving-unification is done.
-  bool is_pathway_request =
-      host_events_by_type.contains(HostEventType::kServingModelRun);
-  absl::Span<const int64_t> request_types;
-  if (is_batching_request) {
-    request_types = absl::Span<const int64_t>(kBatchingRequestTypes);
-  } else if (is_tfrt_request) {
-    request_types = absl::Span<const int64_t>(kTfrtRequestTypes);
-  } else if (is_pathway_request) {
-    request_types = absl::Span<const int64_t>(kPathwayRequestTypes);
-  } else {
-    request_types = absl::Span<const int64_t>(kNonBatchingRequestTypes);
-  }
-  for (const int64_t request_type : request_types) {
-    if (const auto* request_events =
-            gtl::FindOrNull(host_events_by_type, request_type)) {
-      for (const XEventVisitor& request_event : *request_events) {
-        InitializeRequestEvents(request_event, group_metadata_map,
-                                process_batch_group_ids, model_id_map,
-                                is_batching_request,
-                                /*is_user_defined_request=*/false, model_id_db,
-                                request_events_map);
-      }
-    }
-  }
-
-  for (const XEventVisitor* event : user_defined_root_events) {
-    InitializeRequestEvents(
-        *event, group_metadata_map, process_batch_group_ids, model_id_map,
-        /*is_batching_request=*/false,
-        /*is_user_defined_request=*/true, model_id_db, request_events_map);
-  }
-
-  // Set the begin and end timestamps of the request.
-  UpdateRequestTimespan(host_events_by_type, request_events_map);
-
-  // Update RequestEventsMap using the request size in schedule events.
-  for (const int64_t schedule_type : kScheduleEventTypes) {
-    if (const auto* schedule_events =
-            gtl::FindOrNull(host_events_by_type, schedule_type)) {
-      for (const XEventVisitor& schedule_event : *schedule_events) {
-        std::optional<XStatVisitor> optional_group_id =
-            schedule_event.GetStat(StatType::kGroupId);
-        if (!optional_group_id.has_value()) continue;
-        int64_t group_id = optional_group_id->IntValue();
-        // Update the timestamp for schedule events. It is used as the
-        // beginning of batch formation.
-        UpdateEventTimestamps(group_metadata_map, group_id,
-                              schedule_event.TimestampPs(),
-                              UpdateTsBatchSchedule, request_events_map);
-        if (auto* request_events =
-                gtl::FindOrNull(*request_events_map, group_id)) {
-          std::optional<XStatVisitor> batching_request_size =
-              schedule_event.GetStat(StatType::kBatchingInputTaskSize);
-          if (!batching_request_size.has_value()) continue;
-          request_events->batching_request_size =
-              batching_request_size->IntValue();
-        }
-      }
-    }
-  }
-
-  if (device_type == DeviceType::kTpu) {
-    BuildTPUDeviceEvents(device_traces, host_events_by_type,
-                         group_metadata_map, request_events_map, nullptr);
-  } else if (device_type == DeviceType::kGpu) {
-    BuildGPUDeviceEvents(nonoverlapped_step_events, group_metadata_map,
-                         request_events_map, nullptr);
-  }
-}
-
-// Extracts batch details from <host_events_by_type>.
-void BuildBatchEventsMap(const std::vector<const XPlane*>& device_traces,
-                         const EventsByType& host_events_by_type,
-                         const GroupMetadataMap& group_metadata_map,
-                         const StepEvents& nonoverlapped_step_events,
-                         DeviceType device_type,
-                         RequestEventsMap* request_events_map,
-                         BatchEventsMap* batch_events_map) {
-  // Initialize BatchDetails from ProcessBatch events.
-  if (const auto* process_batch_events =
-          gtl::FindOrNull(host_events_by_type, HostEventType::kProcessBatch)) {
-    for (const XEventVisitor& process_batch_event : *process_batch_events) {
-      std::optional<XStatVisitor> optional_group_id =
-          process_batch_event.GetStat(StatType::kGroupId);
-      if (!optional_group_id.has_value()) continue;
-      int64_t group_id = optional_group_id->IntValue();
-      const GroupMetadata* group_metadata =
-          gtl::FindOrNull(group_metadata_map, group_id);
-      if (!group_metadata) continue;
-      BatchEvents& batch_events = (*batch_events_map)[group_id];
-      tensorflow::profiler::BatchDetail& batch_detail =
-          batch_events.batch_detail_proto;
-      batch_detail.set_batch_id(group_id);
-      batch_detail.set_start_time_ps(process_batch_event.TimestampPs());
-      batch_detail.set_end_time_ps(process_batch_event.EndTimestampPs());
-      // The parent group_ids of a batch are the requests related to this
-      // batch.
-      for (const int64_t parent_group_id : group_metadata->parents) {
-        batch_detail.add_related_request_ids(parent_group_id);
-      }
-      // Sort related_request_ids to get a deterministic result.
-      std::sort(batch_detail.mutable_related_request_ids()->begin(),
-                batch_detail.mutable_related_request_ids()->end());
-    }
-  }
-
-  // Update BatchDetailsMap with padding information. Only one of
-  // ConcatInputTensors (for in-graph batching), MergeInputTensors (for
-  // BatchingSession), or BrainSessionRun will appear in the same profile.
-  static constexpr int64_t kPaddingEventTypes[] = {
-      HostEventType::kConcatInputTensors,
-      HostEventType::kMergeInputTensors,
-      HostEventType::kBrainSessionRun,
-  };
-  for (const int64_t padding_event_type : kPaddingEventTypes) {
-    if (const auto* padding_events =
-            gtl::FindOrNull(host_events_by_type, padding_event_type)) {
-      for (const XEventVisitor& padding_event : *padding_events) {
-        // Update the timestamp for padding events. They are used as the
-        // beginning of batch processing.
-        std::optional<XStatVisitor> optional_group_id =
-            padding_event.GetStat(StatType::kGroupId);
-        if (!optional_group_id.has_value()) continue;
-        int64_t group_id = optional_group_id->IntValue();
-        UpdateEventTimestamps(group_metadata_map, group_id,
-                              padding_event.TimestampPs(),
-                              UpdateTsBatchConcatInput, request_events_map);
-        BatchEvents* batch_events =
-            gtl::FindOrNull(*batch_events_map, group_id);
-        if (!batch_events) continue;
-        std::optional<XStatVisitor> padding_amount =
-            padding_event.GetStat(StatType::kPaddingAmount);
-        if (!padding_amount.has_value()) continue;
-        std::optional<XStatVisitor> batch_size_after_padding =
-            padding_event.GetStat(StatType::kBatchSizeAfterPadding);
-        if (!batch_size_after_padding.has_value()) continue;
-        tensorflow::profiler::BatchDetail* batch_detail =
-            &batch_events->batch_detail_proto;
-        batch_detail->set_batch_size_after_padding(
-            batch_size_after_padding->IntValue());
-        batch_detail->set_padding_amount(padding_amount->IntValue());
-      }
-    }
-  }
-
-  // Populate BatchDetailsMap with model_id information from the corresponding
-  // requests in RequestEventsMap.
-  for (auto& [batch_id, batch_events] : *batch_events_map) {
-    tensorflow::profiler::BatchDetail& batch_detail =
-        batch_events.batch_detail_proto;
-    if (!batch_detail.related_request_ids().empty()) {
-      // Set the model_id of a batch using the model_id of the corresponding
-      // request. All requests in the same batch must share the same model_id,
-      // so we can pick any request in the batch here.
-      int64_t first_request_id = batch_detail.related_request_ids(0);
-      const RequestEvents* request_events =
-          gtl::FindOrNull(*request_events_map, first_request_id);
-      if (request_events) {
-        batch_detail.set_model_id_index(request_events->model_id_index);
-      }
-    }
-  }
-
-  if (device_type == DeviceType::kTpu) {
-    BuildTPUDeviceEvents(device_traces, host_events_by_type,
-                         group_metadata_map, nullptr, batch_events_map);
-  } else if (device_type == DeviceType::kGpu) {
-    BuildGPUDeviceEvents(nonoverlapped_step_events, group_metadata_map,
-                         nullptr, batch_events_map);
-  }
-}
-
-// Calculates the delay between requests and batches.
-void GenerateRequestAndBatchDelay(RequestEventsMap* request_events_map,
-                                  BatchEventsMap* batch_events_map) {
-  for (auto& [request_id, request_event] : *request_events_map) {
-    const tensorflow::profiler::BatchDetail* first_batch_detail = nullptr;
-    const tensorflow::profiler::BatchDetail* last_batch_detail = nullptr;
-    // For each request, measure the latency between the request and the first
-    // batch that processes this request.
-    for (const int64_t batch_id : request_event.related_batch_ids) {
-      const auto* batch_events = gtl::FindOrNull(*batch_events_map, batch_id);
-      if (!batch_events) continue;
-      const tensorflow::profiler::BatchDetail* batch_detail =
-          &batch_events->batch_detail_proto;
-      if (!first_batch_detail || (first_batch_detail->start_time_ps() >
-                                  batch_detail->start_time_ps())) {
-        first_batch_detail = batch_detail;
-      }
-      if (!last_batch_detail || (last_batch_detail->end_time_ps() <
-                                 batch_detail->end_time_ps())) {
-        last_batch_detail = batch_detail;
-      }
-    }
-    if (first_batch_detail) {
-      request_event.batching_request_delay_ps =
-          first_batch_detail->start_time_ps() -
-          request_event.request_timespan.begin_ps();
-    }
-    if (last_batch_detail && request_event.request_timespan.end_ps() <
-                                 last_batch_detail->end_time_ps()) {
-      request_event.request_timespan =
-          Timespan::FromEndPoints(request_event.request_timespan.begin_ps(),
-                                  last_batch_detail->end_time_ps());
-    }
-  }
-
-  for (auto& [batch_id, batch_events] : *batch_events_map) {
-    const RequestEvents* first_request_events = nullptr;
-    tensorflow::profiler::BatchDetail& batch_detail =
-        batch_events.batch_detail_proto;
-    // For each batch, measure the latency between the first request in this
-    // batch and the start time of this batch.
-    for (const int64_t request_id : batch_detail.related_request_ids()) {
-      const auto* request_events =
-          gtl::FindOrNull(*request_events_map, request_id);
-      if (!request_events) continue;
-      if (!first_request_events ||
-          (first_request_events->request_timespan.begin_ps() >
-           request_events->request_timespan.begin_ps())) {
-        first_request_events = request_events;
-      }
-    }
-    if (first_request_events) {
-      batch_detail.set_batch_delay_ps(
-          batch_detail.start_time_ps() -
-          first_request_events->request_timespan.begin_ps());
-    }
-  }
-}
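[Editorial note: to make the delay definitions concrete, a small worked example with invented picosecond values. batching_request_delay_ps is the gap between a request's arrival and the start of its first batch, and the request span is extended to cover its last batch:]

    // Worked example (all values invented): a request arriving at t=100 whose
    // first batch starts at t=400 and last batch ends at t=900.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      uint64_t request_begin_ps = 100, request_end_ps = 700;
      uint64_t first_batch_start_ps = 400, last_batch_end_ps = 900;
      uint64_t batching_request_delay_ps =
          first_batch_start_ps - request_begin_ps;      // 300
      request_end_ps = std::max(request_end_ps, last_batch_end_ps);  // 900
      std::cout << batching_request_delay_ps << " " << request_end_ps << "\n";
    }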
-// Generates a detailed breakdown for a request by generating events using
-// the timestamps in RequestEvents.
-void GenerateRequestDetailedBreakdown(RequestEventsMap* request_events_map) {
-  for (auto& [_, request] : *request_events_map) {
-    std::optional<uint64_t> first_tpu_execute;
-    std::optional<uint64_t> first_batch_concat_input;
-    std::optional<uint64_t> last_tpu_complete_callback;
-    std::optional<uint64_t> only_batch_schedule;
-    for (const auto& [group_id, timestamps] : request.timestamps) {
-      if (timestamps.ts_tpu_execute.has_value()) {
-        MinOfOptional(first_tpu_execute, timestamps.ts_tpu_execute);
-
-        // Host runtime: from the start of the TPU execute event to the start
-        // of the TPU program launch. Because of request splitting, there can
-        // be multiple host runtimes in a single request, one for each batch.
-        if (timestamps.ts_tpu_program_launch.has_value()) {
-          request.events.push_back(
-              {EventType::HOST_RUNTIME,
-               Timespan::FromEndPoints(
-                   timestamps.ts_tpu_execute.value(),
-                   timestamps.ts_tpu_program_launch.value())});
-        }
-      }
-
-      if (timestamps.ts_batch_concat_input.has_value()) {
-        MinOfOptional(first_batch_concat_input,
-                      timestamps.ts_batch_concat_input);
-      }
-
-      if (timestamps.ts_tpu_complete_callback.has_value()) {
-        MaxOfOptional(last_tpu_complete_callback,
-                      timestamps.ts_tpu_complete_callback);
-      }
-
-      if (timestamps.ts_batch_schedule.has_value()) {
-        if (only_batch_schedule.has_value()) {
-          LOG(ERROR) << "Found multiple batch schedule events in a single "
-                     << "request.";
-        } else {
-          only_batch_schedule = timestamps.ts_batch_schedule;
-        }
-      }
-    }
-
-    // Host preprocessing: from the start of the request to the start of the
-    // first execute event. There is only one host preprocessing even if there
-    // are multiple batches caused by request splitting.
-    if (first_tpu_execute.has_value()) {
-      request.events.push_back(
-          {EventType::HOST_PREPROCESS,
-           Timespan::FromEndPoints(request.request_timespan.begin_ps(),
-                                   first_tpu_execute.value())});
-    }
-
-    // Host postprocessing: if there are CompleteCallback events for this
-    // request, use the last CompleteCallback event as the beginning of host
-    // postprocessing. Else, use the end time of the last TPU device compute
-    // event. There is only one host postprocessing even if there are
-    // multiple batches caused by request splitting.
-    if (last_tpu_complete_callback.has_value()) {
-      request.events.push_back(
-          {EventType::HOST_POSTPROCESS,
-           Timespan::FromEndPoints(last_tpu_complete_callback.value(),
-                                   request.request_timespan.end_ps())});
-    } else {
-      // Get the latest end time of TPU device compute events.
-      // These events are annotated with type DEVICE_COMPUTE_32.
-      // TODO(tianrun): Deprecate this code path after CompleteCallback is
-      // enabled in all Tensorflow binaries.
-      uint64_t device_compute_end = 0;
-      for (const auto& event : request.events) {
-        if (event.type == EventType::DEVICE_COMPUTE_32) {
-          device_compute_end =
-              std::max(device_compute_end, event.span.end_ps());
-        }
-      }
-      if (device_compute_end != 0) {
-        request.events.push_back(
-            {EventType::HOST_POSTPROCESS,
-             Timespan::FromEndPoints(device_compute_end,
-                                     request.request_timespan.end_ps())});
-      }
-    }
-
-    // Batch formation: from the start of batch schedule to the start of the
-    // first concat input. This is only applicable when batching is enabled,
-    // and it overlaps with host preprocessing.
-    if (only_batch_schedule.has_value() &&
-        first_batch_concat_input.has_value()) {
-      request.events.push_back(
-          {EventType::HOST_BATCH_FORMATION,
-           Timespan::FromEndPoints(only_batch_schedule.value(),
-                                   first_batch_concat_input.value())});
-    }
-  }
-}
-
-// Generates tensor patterns from tensor related EventNodes.
-// If there is any error during the generation, returns an empty string.
-std::string GenerateTensorPattern(
-    const std::vector<const XEventVisitor*>& tensor_events) {
-  // Generate one sub-pattern for each tensor event; the sub-pattern records
-  // the tensor shape, type, and layout.
-  std::vector<std::string> sub_patterns;
-  sub_patterns.reserve(tensor_events.size());
-  for (const XEventVisitor* tensor_event : tensor_events) {
-    std::optional<XStatVisitor> shape =
-        tensor_event->GetStat(StatType::kTensorShapes);
-    if (!shape.has_value()) return "";
-    std::optional<XStatVisitor> layout =
-        tensor_event->GetStat(StatType::kTensorLayout);
-    if (!layout.has_value()) return "";
-    sub_patterns.push_back(absl::StrCat(tensor_event->Name(), " ",
-                                        shape->StrOrRefValue(), " ",
-                                        layout->StrOrRefValue()));
-  }
-  // Sort the sub-patterns to get a deterministic result.
-  std::sort(sub_patterns.begin(), sub_patterns.end());
-  // The final tensor pattern is generated as the concatenation of all
-  // sub-patterns. Use "<br>" as the separator so it can be displayed properly
-  // in the frontend.
-  return absl::StrJoin(sub_patterns, "<br>");
-}
-
-// Generates the total time spent on linearizing and delinearizing tensors.
-uint64_t GenerateTensorLinearizeDelinearizeTime(
-    const std::vector<const XEventVisitor*>& tensor_events) {
-  uint64_t result = 0;
-  for (const XEventVisitor* tensor_event : tensor_events) {
-    result += tensor_event->DurationPs();
-  }
-  return result;
-}
-
-// Generates the details related to tensor shape, type, and layout.
-void GenerateTensorDetails(
-    const EventsByType& host_events_by_type,
-    RequestEventsMap* request_events_map, BatchEventsMap* batch_events_map,
-    tensorflow::profiler::InferenceStats* inference_stats) {
-  static constexpr int64_t kTensorDetailEventTypes[] = {
-      HostEventType::kLinearize, HostEventType::kDelinearize,
-      HostEventType::kTransferBufferFromDeviceFastPath};
-
-  for (const int64_t tensor_detail_event_type : kTensorDetailEventTypes) {
-    if (const auto* tensor_detail_events =
-            gtl::FindOrNull(host_events_by_type, tensor_detail_event_type)) {
-      for (const XEventVisitor& tensor_detail_event : *tensor_detail_events) {
-        std::optional<XStatVisitor> optional_group_id =
-            tensor_detail_event.GetStat(StatType::kGroupId);
-        if (!optional_group_id.has_value()) continue;
-        int64_t group_id = optional_group_id->IntValue();
-        // Add events to the corresponding requests and batches.
-        if (auto* request_events =
-                gtl::FindOrNull(*request_events_map, group_id)) {
-          request_events->tensor_events.push_back(&tensor_detail_event);
-        } else if (auto* batch_events =
-                       gtl::FindOrNull(*batch_events_map, group_id)) {
-          batch_events->tensor_events.push_back(&tensor_detail_event);
-        }
-      }
-    }
-  }
-
-  absl::flat_hash_map<std::string, int> tensor_patterns;
-  auto get_tensor_pattern_index =
-      [&tensor_patterns](const std::string& tensor_pattern) {
-        if (int* index = gtl::FindOrNull(tensor_patterns, tensor_pattern)) {
-          return *index;
-        }
-        int index = tensor_patterns.size();
-        tensor_patterns.insert(std::make_pair(tensor_pattern, index));
-        return index;
-      };
-
-  // Generates the tensor details that are owned by requests.
-  for (auto& [group_id, request_events] : *request_events_map) {
-    if (request_events.tensor_events.empty()) continue;
-    std::string tensor_pattern =
-        GenerateTensorPattern(request_events.tensor_events);
-    if (tensor_pattern.empty()) continue;
-    int index = get_tensor_pattern_index(tensor_pattern);
-    tensorflow::profiler::TensorEventDetail tensor_event_detail;
-    tensor_event_detail.set_tensor_pattern_index(index);
-    tensor_event_detail.set_owner(
-        tensorflow::profiler::TensorEventDetail::REQUEST);
-    tensor_event_detail.set_linearize_delinearize_time_ps(
-        GenerateTensorLinearizeDelinearizeTime(request_events.tensor_events));
-    request_events.tensor_event_detail_protos.push_back(
-        std::move(tensor_event_detail));
-  }
-
-  // Generates the tensor details that are owned by batches.
-  for (auto& [group_id, batch_events] : *batch_events_map) {
-    if (batch_events.tensor_events.empty()) continue;
-    std::string tensor_pattern =
-        GenerateTensorPattern(batch_events.tensor_events);
-    if (tensor_pattern.empty()) continue;
-    int index = get_tensor_pattern_index(tensor_pattern);
-    auto* tensor_event_detail =
-        batch_events.batch_detail_proto.mutable_tensor_event_detail();
-    tensor_event_detail->set_tensor_pattern_index(index);
-    tensor_event_detail->set_owner(
-        tensorflow::profiler::TensorEventDetail::BATCH);
-    tensor_event_detail->set_linearize_delinearize_time_ps(
-        GenerateTensorLinearizeDelinearizeTime(batch_events.tensor_events));
-  }
-
-  // Populates the tensor details from batches to the related requests. These
-  // tensor details are still owned by the batches and will not be used to
-  // calculate statistics like the number of occurrences of each tensor
-  // pattern.
-  for (const auto& [group_id, batch_events] : *batch_events_map) {
-    if (!batch_events.batch_detail_proto.has_tensor_event_detail()) continue;
-    for (const int64_t request_id :
-         batch_events.batch_detail_proto.related_request_ids()) {
-      if (auto* request_events =
-              gtl::FindOrNull(*request_events_map, request_id)) {
-        request_events->tensor_event_detail_protos.push_back(
-            batch_events.batch_detail_proto.tensor_event_detail());
-      }
-    }
-  }
-
-  // Generates the TensorPatternDatabase.
-  if (tensor_patterns.empty()) {
-    return;
-  }
-  absl::flat_hash_map<int, const std::string*> reversed_tensor_patterns;
-  for (const auto& tensor_pattern : tensor_patterns) {
-    reversed_tensor_patterns[tensor_pattern.second] = &tensor_pattern.first;
-  }
-  for (int i = 0; i < static_cast<int>(tensor_patterns.size()); i++) {
-    inference_stats->mutable_tensor_pattern_db()->add_tensor_pattern(
-        *reversed_tensor_patterns.at(i));
-  }
-}
-
-// Generates batch details from batch events.
-// The host runtime breakdown (added in request details) is not supported.
-void BatchEventsToDetails(DeviceType device_type, int64_t group_id,
-                          const BatchEvents& batch_events,
-                          tensorflow::profiler::BatchDetail* batch_detail) {
-  std::vector<EventTypeSpan> tpu_non_overlapped_events;
-  const std::vector<EventTypeSpan>* non_overlapped_events =
-      &tpu_non_overlapped_events;
-  if (device_type == DeviceType::kTpu) {
-    // TPU device events in batch_events.events may overlap in the timeline,
-    // so first convert them to non-overlapped events before the breakdown.
-    tpu_non_overlapped_events = ToNonOverlappedEvents(batch_events.events);
-  } else if (device_type == DeviceType::kGpu) {
-    // GPU device events in batch_events.events come from non-overlapped
-    // StepEvents, so there is no need to convert them again.
-    non_overlapped_events = &(batch_events.events);
-  }
-
-  int64_t device_time_ps = 0;
-  for (const auto& event : *non_overlapped_events) {
-    const auto& duration_ps = event.span.duration_ps();
-    switch (event.type) {
-      case EventType::DEVICE_COMPUTE_16:
-      case EventType::DEVICE_COMPUTE_32:
-        device_time_ps += duration_ps;
-        break;
-      default:
-        break;
-    }
-  }
-  batch_detail->set_device_time_ps(device_time_ps);
-}
-
-// Generates the request details proto from its events.
-void RequestEventsToDetails(
-    DeviceType device_type, int64_t group_id,
-    const RequestEvents& request_events,
-    tensorflow::profiler::RequestDetail* request_detail) {
-  request_detail->set_request_id(group_id);
-  request_detail->set_model_id_index(request_events.model_id_index);
-  request_detail->set_start_time_ps(
-      request_events.request_timespan.begin_ps());
-  request_detail->set_end_time_ps(request_events.request_timespan.end_ps());
-  request_detail->set_batching_request_delay_ps(
-      request_events.batching_request_delay_ps);
-  request_detail->set_batching_request_size(
-      request_events.batching_request_size);
-  for (const auto& tensor_event_detail :
-       request_events.tensor_event_detail_protos) {
-    *request_detail->add_tensor_event_details() = tensor_event_detail;
-  }
-  for (const int64_t batch_id : request_events.related_batch_ids) {
-    request_detail->add_related_batch_ids(batch_id);
-  }
-
-  std::vector<EventTypeSpan> tpu_non_overlapped_events;
-  const std::vector<EventTypeSpan>* non_overlapped_events =
-      &tpu_non_overlapped_events;
-  if (device_type == DeviceType::kTpu) {
-    // TPU device events in request_events.events may overlap in the timeline,
-    // so first convert them to non-overlapped events before the breakdown.
-    tpu_non_overlapped_events = ToNonOverlappedEvents(request_events.events);
-  } else if (device_type == DeviceType::kGpu) {
-    // GPU device events in request_events.events come from non-overlapped
-    // StepEvents, so there is no need to convert them again.
-    non_overlapped_events = &(request_events.events);
-  }
-
-  int64_t device_time_ps = 0;
-  int64_t write_time_ps = 0;
-  int64_t read_time_ps = 0;
-  int64_t host_preprocess_ps = 0;
-  int64_t host_postprocess_ps = 0;
-  int64_t host_runtime_ps = 0;
-  int64_t host_batch_formation_ps = 0;
-  int64_t idle_time_ps = 0;
-  for (const auto& event : *non_overlapped_events) {
-    const auto& duration_ps = event.span.duration_ps();
-    switch (event.type) {
-      case EventType::DEVICE_COMPUTE_16:
-      case EventType::DEVICE_COMPUTE_32:
-        device_time_ps += duration_ps;
-        break;
-      case EventType::HOST_TO_DEVICE:
-        write_time_ps += duration_ps;
-        break;
-      case EventType::DEVICE_TO_HOST:
-        read_time_ps += duration_ps;
-        break;
-      case EventType::HOST_PREPROCESS:
-        host_preprocess_ps += duration_ps;
-        break;
-      case EventType::HOST_POSTPROCESS:
-        host_postprocess_ps += duration_ps;
-        break;
-      case EventType::HOST_RUNTIME:
-        host_runtime_ps += duration_ps;
-        break;
-      case EventType::HOST_BATCH_FORMATION:
-        host_batch_formation_ps += duration_ps;
-        break;
-      case EventType::UNKNOWN_TIME:
-        idle_time_ps += duration_ps;
-        break;
-      default:
-        break;
-    }
-  }
-  request_detail->set_device_time_ps(device_time_ps);
-  request_detail->set_write_to_device_time_ps(write_time_ps);
-  request_detail->set_read_from_device_time_ps(read_time_ps);
-  request_detail->set_host_preprocessing_ps(host_preprocess_ps);
-  request_detail->set_host_postprocessing_ps(host_postprocess_ps);
-  request_detail->set_host_runtime_ps(host_runtime_ps);
-  request_detail->set_host_batch_formation_ps(host_batch_formation_ps);
-  request_detail->set_idle_time_ps(idle_time_ps);
-}
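[Editorial note: the per-request breakdown above is a single pass over non-overlapped spans, bucketing durations by event type. The same accumulation, reduced to a runnable sketch with simplified stand-in types (EventKind and Span below are not the profiler's EventType/EventTypeSpan):]

    #include <cstdint>
    #include <iostream>
    #include <vector>

    enum class EventKind { kDeviceCompute, kHostToDevice, kDeviceToHost, kIdle };
    struct Span { EventKind kind; uint64_t duration_ps; };

    int main() {
      const std::vector<Span> spans = {{EventKind::kDeviceCompute, 700},
                                       {EventKind::kHostToDevice, 120},
                                       {EventKind::kDeviceToHost, 80},
                                       {EventKind::kIdle, 100}};
      uint64_t device = 0, write = 0, read = 0, idle = 0, total = 0;
      for (const Span& s : spans) {
        total += s.duration_ps;
        switch (s.kind) {
          case EventKind::kDeviceCompute: device += s.duration_ps; break;
          case EventKind::kHostToDevice:  write  += s.duration_ps; break;
          case EventKind::kDeviceToHost:  read   += s.duration_ps; break;
          case EventKind::kIdle:          idle   += s.duration_ps; break;
        }
      }
      // 70% device, 12% write, 8% read, 10% idle for this made-up request.
      std::cout << 100.0 * device / total << "% device\n";
    }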
-// Compares two data points by duration.
-// DataType can be either RequestDetail or BatchDetail.
-template <typename DataType>
-bool CompareByDuration(const DataType& a, const DataType& b) {
-  return Timespan::ByDuration(
-      Timespan::FromEndPoints(a.start_time_ps(), a.end_time_ps()),
-      Timespan::FromEndPoints(b.start_time_ps(), b.end_time_ps()));
-}
-
-void BuildRequestDetails(
-    const RequestEventsMap& request_events_map, DeviceType device_type,
-    const int32_t host_id,
-    tsl::protobuf::RepeatedPtrField<tensorflow::profiler::RequestDetail>*
-        request_details) {
-  for (auto& [group_id, request_events] : request_events_map) {
-    if (request_events.request_timespan.duration_ps() == 0) continue;
-    tensorflow::profiler::RequestDetail* request_detail =
-        request_details->Add();
-    request_detail->set_host_id(host_id);
-    RequestEventsToDetails(device_type, group_id, request_events,
-                           request_detail);
-  }
-  std::sort(request_details->begin(), request_details->end(),
-            CompareByDuration<tensorflow::profiler::RequestDetail>);
-}
-
-void BuildBatchDetails(
-    BatchEventsMap batch_events_map, DeviceType device_type,
-    const int32_t host_id,
-    tsl::protobuf::RepeatedPtrField<tensorflow::profiler::BatchDetail>*
-        batch_details) {
-  for (auto& [group_id, batch_events] : batch_events_map) {
-    tensorflow::profiler::BatchDetail* batch_detail = batch_details->Add();
-    *batch_detail = std::move(batch_events.batch_detail_proto);
-    batch_detail->set_host_id(host_id);
-    BatchEventsToDetails(device_type, group_id, batch_events, batch_detail);
-  }
-  std::sort(batch_details->begin(), batch_details->end(),
-            CompareByDuration<tensorflow::profiler::BatchDetail>);
-}
-
-// Parses the TFstreamz xplane to get batching parameters, and stores the
-// parameters in <model_id_db>.
-void ParseTfstreamzForBatchingParameter(
-    const XSpace& xspace, tensorflow::profiler::ModelIdDatabase* model_id_db) {
-  const XPlane* tfstreamz_plane = ::tsl::profiler::FindPlaneWithName(
-      xspace, tsl::profiler::kTFStreamzPlaneName);
-  // There are two TFStreamz events per profile, one at the beginning and one
-  // at the end of the profile; each represents a snapshot of the TFstreamz
-  // counters. Use the last one as the source of batching parameters because
-  // the first snapshot might be taken before Tensorflow sets up the batching
-  // parameters.
-  if (tfstreamz_plane == nullptr || tfstreamz_plane->lines().empty() ||
-      tfstreamz_plane->lines(0).events_size() != 2) {
-    return;
-  }
-  XPlaneVisitor plane(tfstreamz_plane);
-  XEventVisitor event(&plane, &tfstreamz_plane->lines(0),
-                      &tfstreamz_plane->lines(0).events(1));
-
-  static constexpr char kBatchingParamPrefix[] =
-      "/tensorflow/serving/batching/";
-  static constexpr char kBatchingParamNumBatchThreads[] = "num_batch_threads";
-  static constexpr char kBatchingParamBatchTimeoutMicros[] =
-      "batch_timeout_micros";
-  static constexpr char kBatchingParamMaxBatchSize[] = "max_batch_size";
-  static constexpr char kBatchingParamMaxEnqueuedBatches[] =
-      "max_enqueued_batches";
-  static constexpr char kBatchingParamAllowedBatchSizes[] =
-      "allowed_batch_sizes";
-
-  // Parse the batching parameters from TFstreamz and associate them with
-  // model IDs.
-  absl::flat_hash_map<std::string, tensorflow::profiler::BatchingParameters>
-      model_params;
-  event.ForEachStat([&](const XStatVisitor& stat) {
-    if (!absl::StartsWith(stat.Name(), kBatchingParamPrefix)) return;
-
-    absl::string_view param_detail =
-        stat.Name().substr(ABSL_ARRAYSIZE(kBatchingParamPrefix) - 1);
-    auto [parse_success, model_id_tfstreamz] = ParseModelName(param_detail);
-    if (!parse_success) {
-      return;
-    }
-
-    if (absl::StartsWith(param_detail, kBatchingParamNumBatchThreads)) {
-      model_params[model_id_tfstreamz].set_num_batch_threads(stat.IntValue());
-    } else if (absl::StartsWith(param_detail,
-                                kBatchingParamBatchTimeoutMicros)) {
-      model_params[model_id_tfstreamz].set_batch_timeout_micros(
-          stat.IntValue());
-    } else if (absl::StartsWith(param_detail, kBatchingParamMaxBatchSize)) {
-      model_params[model_id_tfstreamz].set_max_batch_size(stat.IntValue());
-    } else if (absl::StartsWith(param_detail,
-                                kBatchingParamMaxEnqueuedBatches)) {
-      model_params[model_id_tfstreamz].set_max_enqueued_batches(
-          stat.IntValue());
-    } else if (absl::StartsWith(param_detail,
-                                kBatchingParamAllowedBatchSizes)) {
-      model_params[model_id_tfstreamz].set_allowed_batch_sizes(
-          std::string(stat.StrOrRefValue()));
-    }
-  });
-
-  // It is possible that the model IDs from Session.Run are in the format of
-  // <model_name>:<version>, while the model IDs in TFstreamz are in the
-  // format of <model_name> (without the version number). Build a map to
-  // connect the model IDs in TFstreamz and Session.Run.
-  absl::flat_hash_map<absl::string_view, std::vector<std::string>>
-      model_id_map;
-  for (const auto& model_id_and_version : model_id_db->ids()) {
-    size_t i = model_id_and_version.find_last_of(':');
-    if (i == std::string::npos) {
-      model_id_map[model_id_and_version].push_back(model_id_and_version);
-    } else {
-      // If there is a version number at the end of the model_id, remove the
-      // version number.
-      absl::string_view version_str(model_id_and_version.data() + i + 1);
-      int64_t version;
-      bool success = absl::SimpleAtoi(version_str, &version);
-      if (success) {
-        absl::string_view model_id_only(model_id_and_version.data(), i);
-        model_id_map[model_id_only].push_back(model_id_and_version);
-      } else {
-        LOG(ERROR) << "Cannot parse model version number: " << version_str;
-      }
-    }
-  }
-
-  // One model ID from TFstreamz might map to multiple model IDs in
-  // Session.Run, so update the batching parameters of all the model IDs in
-  // Session.Run.
-  for (const auto& [model_id_tfstreamz, params] : model_params) {
-    if (const std::vector<std::string>* model_ids_session_run =
-            gtl::FindOrNull(model_id_map, model_id_tfstreamz)) {
-      for (const absl::string_view model_id_session_run :
-           *model_ids_session_run) {
-        (*model_id_db->mutable_id_to_batching_params())[model_id_session_run] =
-            params;
-      }
-    }
-  }
-}
-
-}  // namespace
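[Editorial note: a usage sketch for ParseModelName, which is defined just below. The metric and model names are invented; the input is the stat name with the "/tensorflow/serving/batching/" prefix already stripped, as done in ParseTfstreamzForBatchingParameter above.]

    // Illustration only: expected results for the documented label formats.
    auto [ok, name] =
        ParseModelName("max_batch_size{model_name=resnet, op_name=batch_1}");
    // ok == true, name == "resnet"
    auto [ok2, name2] = ParseModelName("num_batch_threads");  // no {...} labels
    // ok2 == false, name2 == ""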
-std::pair<bool, absl::string_view> ParseModelName(absl::string_view param) {
-  // Param can be in one of the following two formats:
-  // batching_param{model_name=<model_name>}
-  // batching_param{model_name=<model_name>, op_name=<op_name>}
-  size_t label_begin = param.find_first_of('{');
-  size_t label_end = param.find_last_of('}');
-  if (label_begin == absl::string_view::npos ||
-      label_end == absl::string_view::npos || label_end <= label_begin) {
-    return {false, ""};
-  }
-  // Go over all the labels to look for the model name.
-  std::vector<absl::string_view> labels = absl::StrSplit(
-      param.substr(label_begin + 1, label_end - label_begin - 1), ", ");
-  for (const absl::string_view label : labels) {
-    std::vector<absl::string_view> key_value = absl::StrSplit(label, '=');
-    if (key_value.size() != 2) continue;
-    if (key_value[0] == "model_name") {
-      return {true, key_value[1]};
-    }
-  }
-  // Unable to find the model name.
-  return {false, ""};
-}
-
-void GenerateInferenceStats(
-    const std::vector<const XPlane*>& device_traces,
-    const StepEvents& nonoverlapped_step_events,
-    const GroupMetadataMap& group_metadata_map, const XSpace& xspace,
-    DeviceType device_type, int32_t host_id,
-    tensorflow::profiler::InferenceStats* inference_stats) {
-  tensorflow::profiler::PerHostInferenceStats* per_host_inference_stats =
-      &(*inference_stats->mutable_inference_stats_per_host())[host_id];
-  RequestEventsMap request_events_map;
-
-  // Build the mapping from host event type to events.
-  EventsByType host_events_by_type;
-  const XPlane* host = tsl::profiler::FindPlaneWithName(
-      xspace, tsl::profiler::kHostThreadsPlaneName);
-  if (!host) return;
-  XPlaneVisitor host_plane = CreateTfXPlaneVisitor(host);
-  for (const auto& line : host->lines()) {
-    for (const auto& event : line.events()) {
-      XEventVisitor event_visitor(&host_plane, &line, &event);
-      auto type = event_visitor.Type();
-      if (!type.has_value()) {
-        type = HostEventType::kUnknownHostEventType;
-      }
-      host_events_by_type[type.value()].push_back(event_visitor);
-    }
-  }
-
-  BuildRequestEventsMap(device_traces, host_events_by_type, group_metadata_map,
-                        nonoverlapped_step_events, device_type,
-                        inference_stats->mutable_model_id_db(),
-                        &request_events_map);
-  BatchEventsMap batch_events_map;
-  BuildBatchEventsMap(device_traces, host_events_by_type, group_metadata_map,
-                      nonoverlapped_step_events, device_type,
-                      &request_events_map, &batch_events_map);
-
-  GenerateRequestAndBatchDelay(&request_events_map, &batch_events_map);
-  GenerateRequestDetailedBreakdown(&request_events_map);
-
-  GenerateTensorDetails(host_events_by_type, &request_events_map,
-                        &batch_events_map, inference_stats);
-
-  auto* request_details = per_host_inference_stats->mutable_request_details();
-  BuildRequestDetails(request_events_map, device_type, host_id,
-                      request_details);
-  auto* batch_details = per_host_inference_stats->mutable_batch_details();
-  BuildBatchDetails(std::move(batch_events_map), device_type, host_id,
-                    batch_details);
-
-  ParseTfstreamzForBatchingParameter(xspace,
-                                     inference_stats->mutable_model_id_db());
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
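[Editorial note: the EventsByType index built at the start of GenerateInferenceStats is just a map from host event type to the event visitors of that type. A reduced stand-alone sketch, with simplified stand-ins for XEventVisitor and HostEventType:]

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Simplified stand-ins; the real code uses XEventVisitor and typed stats.
    struct Event { int64_t type; std::string name; };

    int main() {
      const std::vector<Event> host_events = {
          {1, "SessionRun"}, {2, "ProcessBatch"}, {1, "SessionRun"}};
      std::unordered_map<int64_t, std::vector<const Event*>> events_by_type;
      for (const Event& e : host_events) events_by_type[e.type].push_back(&e);
      std::cout << events_by_type[1].size() << "\n";  // 2 SessionRun events
    }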
diff --git a/tensorflow/core/profiler/convert/inference_stats.h b/tensorflow/core/profiler/convert/inference_stats.h
deleted file mode 100644
index cc291fa9e336f4..00000000000000
--- a/tensorflow/core/profiler/convert/inference_stats.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_H_
-
-#include <cstdint>
-#include <utility>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/device_utils.h"
-#include "xla/tsl/profiler/utils/group_events.h"
-#include "xla/tsl/profiler/utils/xplane_builder.h"
-#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-#include "xprof/utils/event_span.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-// Generates PerHostInferenceStats from the given trace events.
-// For TPU, gets the time breakdown from device_traces. For GPU, gets the time
-// breakdown from nonoverlapped_step_events.
-// Gets batching parameters from the TFstreamz xplane in <xspace>.
-void GenerateInferenceStats(
-    const std::vector<const tsl::profiler::XPlane*>& device_traces,
-    const tensorflow::profiler::StepEvents& nonoverlapped_step_events,
-    const tsl::profiler::GroupMetadataMap& group_metadata_map,
-    const tsl::profiler::XSpace& xspace, tsl::profiler::DeviceType device_type,
-    int32_t host_id, tensorflow::profiler::InferenceStats* inference_stats);
-
-// Parses the model name from TFstreamz.
-// Returns whether the parsing is successful and the actual model name. If
-// parsing failed, returns false and an empty string.
-std::pair<bool, absl::string_view> ParseModelName(absl::string_view param);
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_H_
diff --git a/tensorflow/core/profiler/convert/inference_stats_combiner.cc b/tensorflow/core/profiler/convert/inference_stats_combiner.cc
deleted file mode 100644
index fcca1310061d16..00000000000000
--- a/tensorflow/core/profiler/convert/inference_stats_combiner.cc
+++ /dev/null
@@ -1,171 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/profiler/convert/inference_stats_combiner.h"
-
-#include <string>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/lib/gtl/map_util.h"
-
-namespace tensorflow::profiler {
-namespace {
-// Combines two ModelIdDatabases. Returns true if this combination requires
-// updating the model_id_index in the SessionRunTimes of dst. This will be
-// the case if: (1) Src has a model name that doesn't already exist in dst;
-// or (2) Src has a model name that does exist in dst but has a different
-// index.
-bool CombineModelIdDatabases(const ModelIdDatabase& src,
-                             ModelIdDatabase* dst) {
-  if (dst->ids_size() == 0) {
-    // dst is empty. Simply copy src to dst. This avoids rebuilding
-    // dst from src from scratch, which may change the name-to-index mapping.
-    *dst = src;
-    return false;
-  }
-  // TODO(tianrun): For now, assume a model is always served with the same
-  // parameters on different hosts. In the future, we might consider the case
-  // when the same model is served with different batching parameters on
-  // different hosts.
-  for (const auto& id_and_param : src.id_to_batching_params()) {
-    dst->mutable_id_to_batching_params()->insert(id_and_param);
-  }
-  bool need_update = false;
-  for (const auto& [src_id, index] : src.id_to_index()) {
-    auto [iter, was_inserted] =
-        dst->mutable_id_to_index()->insert({src_id, dst->ids_size()});
-    if (was_inserted) {
-      *dst->add_ids() = src_id;
-      need_update = true;
-      continue;
-    }
-    if (iter->second != index) {
-      // src_id is already in dst but has a different index.
-      need_update = true;
-    }
-  }
-  return need_update;
-}
-
-// Combines two TensorPatternDatabases. Returns true if this combination
-// requires updating the tensor_pattern_index. This will be the case if:
-// (1) Src has a tensor pattern that doesn't exist in dst; or (2) Src has a
-// tensor pattern that does exist in dst but has a different index.
-bool CombineTensorPatternDatabase(
-    const TensorPatternDatabase& src, TensorPatternDatabase* dst,
-    absl::flat_hash_map<absl::string_view, int>* dst_pattern_to_index) {
-  if (dst->tensor_pattern().empty()) {
-    *dst = src;
-    return false;
-  }
-
-  bool need_update = false;
-  for (int i = 0; i < static_cast<int>(src.tensor_pattern_size()); i++) {
-    auto [iter, inserted] = dst_pattern_to_index->insert(
-        {src.tensor_pattern(i), dst_pattern_to_index->size()});
-    if (inserted) {
-      // Src has a tensor pattern that doesn't exist in dst.
-      dst->add_tensor_pattern(src.tensor_pattern(i));
-      need_update = true;
-    } else if (iter->second != i) {
-      // Src has a tensor pattern with a different index than dst.
-      need_update = true;
-    }
-  }
-  return need_update;
-}
-
-void UpdateTensorPatternIndex(
-    const TensorPatternDatabase& src,
-    const absl::flat_hash_map<absl::string_view, int>& dst_pattern_to_index,
-    TensorEventDetail* detail) {
-  absl::string_view tensor_pattern =
-      src.tensor_pattern(detail->tensor_pattern_index());
-  if (const int* new_index =
-          tsl::gtl::FindOrNull(dst_pattern_to_index, tensor_pattern)) {
-    detail->set_tensor_pattern_index(*new_index);
-  } else {
-    LOG(WARNING) << "Tensor pattern " << tensor_pattern
-                 << " is not found in dst->tensor_pattern_db()";
-  }
-}
-}  // namespace
-
-void CombineInferenceStatsResult(int src_host_id, const InferenceStats& src,
-                                 InferenceStats* dst) {
-  // There should be at most one key-value pair inside
-  // src.inference_stats_per_host(), because the src comes from one
-  // XprofResponse (i.e., one host).
-  DCHECK_LE(src.inference_stats_per_host_size(), 1);
-  bool need_update_model_id =
-      CombineModelIdDatabases(src.model_id_db(), dst->mutable_model_id_db());
-  absl::flat_hash_map<absl::string_view, int> dst_pattern_to_index;
-  for (int i = 0;
-       i < static_cast<int>(dst->tensor_pattern_db().tensor_pattern_size());
-       i++) {
-    dst_pattern_to_index[dst->tensor_pattern_db().tensor_pattern(i)] = i;
-  }
-  bool need_update_tensor_pattern = CombineTensorPatternDatabase(
-      src.tensor_pattern_db(), dst->mutable_tensor_pattern_db(),
-      &dst_pattern_to_index);
-  for (const auto& [host_id, inf_stats] : src.inference_stats_per_host()) {
-    auto [iter, was_inserted] =
-        dst->mutable_inference_stats_per_host()->insert(
-            {src_host_id, inf_stats});
-    if (!was_inserted) {
-      LOG(INFO) << "Duplicate host_id: " << iter->first;
-    }
-    if (need_update_model_id || need_update_tensor_pattern) {
-      // Needs to update the model_id_index in the dst.
-      PerHostInferenceStats* dst_inference_stats =
-          &(*dst->mutable_inference_stats_per_host())[src_host_id];
-      for (RequestDetail& request_detail :
-           *dst_inference_stats->mutable_request_details()) {
-        if (need_update_model_id && request_detail.model_id_index() != -1) {
-          // "model_id_index = -1" means there is no model_id associated with
-          // the group id in this event if the client doesn't specify
-          // "model_id" in TraceMeEncode, so we don't need to update the
-          // model_id if it doesn't have a model.
-          const std::string& model_id =
-              src.model_id_db().ids(request_detail.model_id_index());
-          auto iter = dst->model_id_db().id_to_index().find(model_id);
-          if (iter == dst->model_id_db().id_to_index().end()) {
-            LOG(WARNING) << "Model ID " << model_id
-                         << " is not found in dst->model_id_db()";
-            continue;
-          }
-          request_detail.set_model_id_index(iter->second);
-        }
-        if (need_update_tensor_pattern) {
-          for (auto& tensor_event_details :
-               *request_detail.mutable_tensor_event_details()) {
-            UpdateTensorPatternIndex(src.tensor_pattern_db(),
-                                     dst_pattern_to_index,
-                                     &tensor_event_details);
-          }
-        }
-      }
-    }
-    if (need_update_tensor_pattern) {
-      PerHostInferenceStats* dst_inference_stats =
-          &(*dst->mutable_inference_stats_per_host())[src_host_id];
-      for (BatchDetail& batch_detail :
-           *dst_inference_stats->mutable_batch_details()) {
-        UpdateTensorPatternIndex(src.tensor_pattern_db(), dst_pattern_to_index,
-                                 batch_detail.mutable_tensor_event_detail());
-      }
-    }
-  }
-}
-}  // namespace tensorflow::profiler
diff --git a/tensorflow/core/profiler/convert/inference_stats_combiner.h b/tensorflow/core/profiler/convert/inference_stats_combiner.h
deleted file mode 100644
index ceccc9cca2608a..00000000000000
--- a/tensorflow/core/profiler/convert/inference_stats_combiner.h
+++ /dev/null
@@ -1,25 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_COMBINER_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_COMBINER_H_
-#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h"
-
-namespace tensorflow::profiler {
-void CombineInferenceStatsResult(int src_host_id, const InferenceStats& src,
-                                 InferenceStats* dst);
-}  // namespace tensorflow::profiler
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_COMBINER_H_
diff --git a/tensorflow/core/profiler/convert/inference_stats_grouping.cc b/tensorflow/core/profiler/convert/inference_stats_grouping.cc
deleted file mode 100644
index fad8330ac72f63..00000000000000
--- a/tensorflow/core/profiler/convert/inference_stats_grouping.cc
+++ /dev/null
@@ -1,475 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/profiler/convert/inference_stats_grouping.h"
-
-#include <algorithm>
-#include <cassert>
-#include <cstdint>
-#include <iterator>
-#include <limits>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "xla/tsl/lib/gtl/map_util.h"
-#include "xla/tsl/profiler/utils/math_utils.h"
-#include "xla/tsl/profiler/utils/timespan.h"
-#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h"
-#include "tsl/platform/protobuf.h"
-
-namespace tensorflow::profiler {
-
-namespace {
-
-using ::tensorflow::profiler::BatchDetail;
-using ::tensorflow::profiler::InferenceStats;
-using ::tensorflow::profiler::ModelIdDatabase;
-using ::tensorflow::profiler::PerBatchSizeAggregatedResult;
-using ::tensorflow::profiler::PerModelInferenceStats;
-using ::tensorflow::profiler::RequestDetail;
-using ::tensorflow::profiler::TensorEventDetail;
-using ::tsl::profiler::Timespan;
-
-// Pushes the element at index <hole> down the heap until it is in place.
-template <typename RandIt, typename Compare>
-void push_down_heap(size_t hole, RandIt first, RandIt last, Compare comp) {
-  size_t size = last - first;
-  assert(hole < size);
-  auto value = std::move(first[hole]);
-  while (true) {
-    size_t l_child = 2 * hole + 1;
-    size_t r_child = l_child + 1;
-    size_t max_child = l_child;
-    if (r_child < size && comp(first[l_child], first[r_child])) {
-      max_child = r_child;
-    }
-    if (max_child >= size) break;
-    if (!comp(value, first[max_child])) break;
-    first[hole] = std::move(first[max_child]);
-    hole = max_child;
-  }
-  first[hole] = std::move(value);
-}
-// Pushes the root down the heap.
-template <typename RandIt, typename Compare>
-void push_root_heap(RandIt first, RandIt last, Compare comp) {
-  push_down_heap(0, std::move(first), std::move(last), std::move(comp));
-}
-
-// Merges multiple sorted containers into a single sorted output range.
-template <typename ContainerContainer, typename Out, typename Cmp>
-Out nway_merge(const ContainerContainer& containers, Out out, Cmp cmp) {
-  using std::begin;
-  using std::end;
-  using In = decltype(begin(*begin(containers)));  // The input iterator type.
-  using Range = std::pair<In, In>;
-  std::vector<Range> sources;
-  for (const auto& container : containers) {
-    Range r(begin(container), end(container));
-    if (r.first != r.second) sources.push_back(std::move(r));
-  }
-  // Zero, one or two collections can be merged without a priority queue.
-  switch (sources.size()) {
-    case 0:
-      return out;
-    case 1:
-      return std::copy(sources[0].first, sources[0].second, out);
-    case 2:
-      return std::merge(sources[0].first, sources[0].second, sources[1].first,
-                        sources[1].second, out, cmp);
-  }
-  // Take a comparator for T and produce an inverse comparator for
-  // std::pair<In, In>, inverted so as to produce a min-heap.
-  auto heap_cmp = [&](const Range& a, const Range& b) {
-    // Compares b < a instead of a < b.
-    return cmp(*b.first, *a.first);
-  };
-  auto heap_data = sources.data();
-  auto heap_size = sources.size();
-  std::make_heap(heap_data, heap_data + heap_size, heap_cmp);
-  auto& top = sources.front();
-  auto pop = [&]() {
-    *out = *top.first;
-    ++out;
-    ++top.first;
-  };
-
-  for (; heap_size > 2;) {
-    for (pop(); top.first != top.second; pop()) {
-      push_root_heap(heap_data, heap_data + heap_size, heap_cmp);
-    }
-    top = std::move(sources[--heap_size]);
-    push_root_heap(heap_data, heap_data + heap_size, heap_cmp);
-  }
-
-  return std::merge(sources[0].first, sources[0].second, sources[1].first,
-                    sources[1].second, out, cmp);
-}
-
-double GetThroughput(size_t data_size, uint64_t start_time_ps,
-                     uint64_t end_time_ps) {
-  return data_size / tsl::profiler::PicoToUni(end_time_ps - start_time_ps);
-}
-
-// Compute throughput and average latency.
-// DataType can either be RequestDetail or BatchDetail.
-template <typename DataType>
-std::pair<double, double> ComputeThroughputAndAverageLatencyUs(
-    const std::vector<const DataType*>& all_data) {
-  if (all_data.empty()) {
-    // Return 0 immediately to avoid divide by zero error.
-    return std::make_pair(0.0, 0.0);
-  }
-
-  uint64_t min_start_time_ps = std::numeric_limits<uint64_t>::max();
-  uint64_t max_end_time_ps = 0;
-  uint64_t total_latency_ps = 0;
-
-  for (const DataType* data : all_data) {
-    min_start_time_ps = std::min(min_start_time_ps, data->start_time_ps());
-    max_end_time_ps = std::max(max_end_time_ps, data->end_time_ps());
-    total_latency_ps += (data->end_time_ps() - data->start_time_ps());
-  }
-
-  double throughput =
-      GetThroughput(all_data.size(), min_start_time_ps, max_end_time_ps);
-  double average_latency_us =
-      tsl::profiler::PicoToMicro(total_latency_ps) / all_data.size();
-  return std::make_pair(throughput, average_latency_us);
-}
-
-template <typename DataType>
-bool CompareByDuration(const DataType* a, const DataType* b) {
-  return Timespan::ByDuration(
-      Timespan::FromEndPoints(a->start_time_ps(), a->end_time_ps()),
-      Timespan::FromEndPoints(b->start_time_ps(), b->end_time_ps()));
-}
-
-// Regroup data in <data_by_host> using model id for future analysis.
-// DataType can be either RequestDetail or BatchDetail.
-template <typename DataType>
-void RegroupDataByModelId(
-    const ModelIdDatabase& model_id_db,
-    const std::vector<const tsl::protobuf::RepeatedPtrField<DataType>*>&
-        data_by_host,
-    std::vector<std::vector<const DataType*>>* data_by_model_id) {
-  // First group data by model_id and host.
-  std::vector<std::vector<std::vector<const DataType*>>>
-      data_by_model_id_by_host;
-
-  // If model_id_db is empty, this means model_id is not available in the
-  // trace, so we simply consider the entire execution as a single model_id.
-  bool no_model_id = model_id_db.ids_size() == 0;
-  int model_index_size = no_model_id ? 1 : model_id_db.ids_size();
-  int host_index_size = data_by_host.size();
-  data_by_model_id_by_host.resize(model_index_size);
-  for (size_t model_index = 0; model_index < model_index_size; ++model_index) {
-    data_by_model_id_by_host[model_index].resize(host_index_size);
-  }
-
-  int32_t host_index = 0;
-  for (const tsl::protobuf::RepeatedPtrField<DataType>* single_host_data :
-       data_by_host) {
-    for (const DataType& data : *single_host_data) {
-      int model_index = no_model_id ? 0 : data.model_id_index();
-      // If model_id_db is not empty, and a session/batch does not have
-      // model_id, ignore it in per model analysis.
-      if (model_index == -1) {
-        continue;
-      }
-      data_by_model_id_by_host[model_index][host_index].push_back(&data);
-    }
-    ++host_index;
-  }
-
-  // data_by_host is already sorted by the latency, so
-  // data_by_model_id_by_host is also sorted by the latency. Therefore,
-  // we just need to do an n-way merge instead of a real sort.
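-  // (Illustrative aside, not in the original file: with the nway_merge above,
-  // already-sorted per-host lists merge into one sorted list, e.g.
-  //   std::vector<std::vector<int>> per_host =
-  //       {{1, 4, 9}, {2, 3, 8}, {5, 6, 7}};
-  //   std::vector<int> merged;
-  //   nway_merge(per_host, std::back_inserter(merged),
-  //              [](int a, int b) { return a < b; });
-  //   // merged == {1, 2, 3, 4, 5, 6, 7, 8, 9}
-  // which is exactly how the per-model merge below avoids a full sort.)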
-  data_by_model_id->resize(model_index_size);
-  for (size_t model_index = 0; model_index < model_index_size; ++model_index) {
-    int total_size = 0;
-    for (const auto& per_model_per_host :
-         data_by_model_id_by_host[model_index]) {
-      total_size += per_model_per_host.size();
-    }
-    data_by_model_id->at(model_index).reserve(total_size);
-  }
-  for (size_t model_index = 0; model_index < model_index_size; ++model_index) {
-    nway_merge(data_by_model_id_by_host[model_index],
-               std::back_inserter(data_by_model_id->at(model_index)),
-               CompareByDuration<DataType>);
-  }
-}
-
-// Generates the tensor transfer aggregated result using the per model data in
-// <per_model>.
-void GenerateTensorTransferAggregatedResult(PerModelInferenceStats* per_model) {
-  absl::flat_hash_map<int, std::vector<const TensorEventDetail*>>
-      tensor_events_by_index;
-  // For requests, only count the tensor events with owner REQUEST, because if
-  // inference batching is enabled, there will be tensor events that are owned
-  // by batches and just inherited by requests. Counting these tensor events
-  // will lead to double counting.
-  for (const auto& request : per_model->request_details()) {
-    for (const auto& tensor_event : request.tensor_event_details()) {
-      if (tensor_event.owner() == TensorEventDetail::REQUEST) {
-        tensor_events_by_index[tensor_event.tensor_pattern_index()].push_back(
-            &tensor_event);
-      }
-    }
-  }
-  for (const auto& batch : per_model->batch_details()) {
-    if (batch.has_tensor_event_detail()) {
-      tensor_events_by_index[batch.tensor_event_detail().tensor_pattern_index()]
-          .push_back(&batch.tensor_event_detail());
-    }
-  }
-
-  if (tensor_events_by_index.empty()) return;
-
-  static constexpr double kPercentiles[] = {50.0, 75.0, 90.0, 95.0, 99.0, 99.9};
-  for (auto& [index, events] : tensor_events_by_index) {
-    auto* tensor_pattern_result =
-        per_model->mutable_tensor_transfer_aggregated_result()
-            ->add_tensor_pattern_results();
-    tensor_pattern_result->set_tensor_pattern_index(index);
-    tensor_pattern_result->set_count(events.size());
-    std::sort(events.begin(), events.end(),
-              [](const TensorEventDetail* a, const TensorEventDetail* b) {
-                return a->linearize_delinearize_time_ps() <
-                       b->linearize_delinearize_time_ps();
-              });
-    for (const double percentile : kPercentiles) {
-      int index = static_cast<int>(percentile / 100.0 * events.size());
-      auto* percentile_time =
-          tensor_pattern_result->add_linearize_delinearize_percentile_time();
-      percentile_time->set_percentile(percentile);
-      percentile_time->set_time_ps(
-          events[index]->linearize_delinearize_time_ps());
-    }
-  }
-}
-
-void AggregateRequest(const RequestDetail& input, RequestDetail* result) {
-  // In the aggregated result, start_time is set to 0, and end_time is set to
-  // the sum of the durations of the input requests.
- result->set_end_time_ps(input.end_time_ps() - input.start_time_ps() + - result->end_time_ps()); - result->set_device_time_ps(result->device_time_ps() + input.device_time_ps()); - result->set_read_from_device_time_ps(result->read_from_device_time_ps() + - input.read_from_device_time_ps()); - result->set_write_to_device_time_ps(result->write_to_device_time_ps() + - input.write_to_device_time_ps()); - result->set_batching_request_delay_ps(result->batching_request_delay_ps() + - input.batching_request_delay_ps()); - result->set_batching_request_size(result->batching_request_size() + - input.batching_request_size()); - result->set_host_preprocessing_ps(result->host_preprocessing_ps() + - input.host_preprocessing_ps()); - result->set_host_batch_formation_ps(result->host_batch_formation_ps() + - input.host_batch_formation_ps()); - result->set_host_runtime_ps(result->host_runtime_ps() + - input.host_runtime_ps()); - result->set_host_postprocessing_ps(result->host_postprocessing_ps() + - input.host_postprocessing_ps()); - result->set_idle_time_ps(result->idle_time_ps() + input.idle_time_ps()); -} - -RequestDetail GetAverageRequestDetails(const RequestDetail& request, - int64_t size) { - RequestDetail result; - if (size == 0) return result; - // Average request detail does not have a request ID. - result.set_request_id(-1); - result.set_start_time_ps(0); - // Calculating average by dividing aggregated request by size. - result.set_end_time_ps(request.end_time_ps() / size); - result.set_device_time_ps(request.device_time_ps() / size); - result.set_write_to_device_time_ps(request.write_to_device_time_ps() / size); - result.set_read_from_device_time_ps(request.read_from_device_time_ps() / - size); - result.set_batching_request_delay_ps(request.batching_request_delay_ps() / - size); - result.set_batching_request_size(request.batching_request_size() / size); - result.set_host_preprocessing_ps(request.host_preprocessing_ps() / size); - result.set_host_batch_formation_ps(request.host_batch_formation_ps() / size); - result.set_host_runtime_ps(request.host_runtime_ps() / size); - result.set_host_postprocessing_ps(request.host_postprocessing_ps() / size); - result.set_idle_time_ps(request.idle_time_ps() / size); - return result; -} - -void AggregateBatch(const BatchDetail& input, BatchDetail* result) { - // In aggregated result, start_time is set to 0, and end time is set to the - // sum of the duration of the input batches. - result->set_end_time_ps(input.end_time_ps() - input.start_time_ps() + - result->end_time_ps()); - result->set_batch_delay_ps(result->batch_delay_ps() + input.batch_delay_ps()); - result->set_padding_amount(result->padding_amount() + input.padding_amount()); - result->set_batch_size_after_padding(result->batch_size_after_padding() + - input.batch_size_after_padding()); - result->set_device_time_ps(result->device_time_ps() + input.device_time_ps()); -} - -BatchDetail GetAverageBatchDetails(const BatchDetail& batch, int64_t size) { - BatchDetail result; - if (size == 0) return result; - // Average batch detail does not have a batch ID. - result.set_batch_id(-1); - result.set_start_time_ps(0); - // Calculating average by dividing aggregated batch by size. 
- result.set_end_time_ps(batch.end_time_ps() / size); - result.set_batch_delay_ps(batch.batch_delay_ps() / size); - result.set_padding_amount(batch.padding_amount() / size); - result.set_batch_size_after_padding(batch.batch_size_after_padding() / size); - result.set_device_time_ps(batch.device_time_ps() / size); - return result; -} - -void AggregatePerModelInferenceStats(InferenceStats* inference_stats) { - for (auto& [model_index, per_model_stats] : - *inference_stats->mutable_inference_stats_per_model()) { - // TODO: remove batch size aggregation from request table. - absl::flat_hash_map batch_id_to_batch; - for (const BatchDetail& b : per_model_stats.batch_details()) { - batch_id_to_batch[b.batch_id()] = &b; - } - - // Aggregated result for all data. - RequestDetail aggregated_r; - BatchDetail aggregated_b; - - struct PerBatchSizeInfo { - PerBatchSizeAggregatedResult result; - int request_count; - int batch_count; - }; - // Aggregated result per batch size. - absl::flat_hash_map per_batch_size_info; - - for (const RequestDetail& r : per_model_stats.request_details()) { - // Aggregate all data. - AggregateRequest(r, &aggregated_r); - // Aggregate per batch size. - // TODO: remove batch size aggregation from request table. - for (const auto batch_id : r.related_batch_ids()) { - if (const BatchDetail* batch = - ::tsl::gtl::FindPtrOrNull(batch_id_to_batch, batch_id)) { - int batch_size = batch->batch_size_after_padding(); - auto& info = per_batch_size_info[batch_size]; - AggregateRequest(r, info.result.mutable_aggregated_request_result()); - info.request_count++; - } - } - } - - for (const BatchDetail& b : per_model_stats.batch_details()) { - // Aggregate all data. - AggregateBatch(b, &aggregated_b); - // Aggregate per batch size. - int batch_size = b.batch_size_after_padding(); - auto& info = per_batch_size_info[batch_size]; - AggregateBatch(b, info.result.mutable_aggregated_batch_result()); - info.batch_count++; - } - - *per_model_stats.mutable_aggregated_request_detail() = - GetAverageRequestDetails(aggregated_r, - per_model_stats.request_details().size()); - *per_model_stats.mutable_aggregated_batch_detail() = GetAverageBatchDetails( - aggregated_b, per_model_stats.batch_details().size()); - - std::vector sorted_batch_sizes; - for (const auto& [batch_size, _] : per_batch_size_info) { - sorted_batch_sizes.push_back(batch_size); - } - std::sort(sorted_batch_sizes.begin(), sorted_batch_sizes.end()); - for (const int batch_size : sorted_batch_sizes) { - auto* result = per_model_stats.add_per_batch_size_aggregated_result(); - result->set_batch_size(batch_size); - auto& info = per_batch_size_info[batch_size]; - *result->mutable_aggregated_request_result() = GetAverageRequestDetails( - info.result.aggregated_request_result(), info.request_count); - result->set_request_throughput(info.request_count * - per_model_stats.request_throughput() / - per_model_stats.request_details_size()); - *result->mutable_aggregated_batch_result() = GetAverageBatchDetails( - info.result.aggregated_batch_result(), info.batch_count); - result->set_batch_throughput(info.batch_count * - per_model_stats.batch_throughput() / - per_model_stats.batch_details_size()); - } - } -} - -} // namespace - -void RegroupInferenceStatsByModel(InferenceStats* inference_stats) { - if (inference_stats->inference_stats_per_host().empty()) { - return; - } - std::vector*> - all_requests_by_host; - for (const auto& [host_id, per_host_inference_stats] : - inference_stats->inference_stats_per_host()) { - 
all_requests_by_host.push_back(&per_host_inference_stats.request_details()); - } - std::vector> requests_by_model_id; - RegroupDataByModelId(inference_stats->model_id_db(), all_requests_by_host, - &requests_by_model_id); - - std::vector*> - all_batches_by_host; - for (const auto& [host_id, per_host_inference_stats] : - inference_stats->inference_stats_per_host()) { - all_batches_by_host.push_back(&per_host_inference_stats.batch_details()); - } - std::vector> batches_by_model_id; - RegroupDataByModelId(inference_stats->model_id_db(), all_batches_by_host, - &batches_by_model_id); - - for (size_t index = 0; index < requests_by_model_id.size(); index++) { - auto* per_model = - &(*inference_stats->mutable_inference_stats_per_model())[index]; - for (const RequestDetail* request : requests_by_model_id[index]) { - *per_model->add_request_details() = *request; - } - for (const BatchDetail* batch : batches_by_model_id[index]) { - *per_model->add_batch_details() = *batch; - } - auto [request_throughput, request_latency] = - ComputeThroughputAndAverageLatencyUs(requests_by_model_id[index]); - per_model->set_request_throughput(request_throughput); - per_model->set_request_average_latency_us(request_latency); - auto [batch_throughput, batch_latency] = - ComputeThroughputAndAverageLatencyUs(batches_by_model_id[index]); - per_model->set_batch_throughput(batch_throughput); - per_model->set_batch_average_latency_us(batch_latency); - GenerateTensorTransferAggregatedResult(per_model); - } - - AggregatePerModelInferenceStats(inference_stats); - - // If there is no model id provided by user, create a fake "ALL" model id to - // represent all the requests during profiling. - // This ALL model id is mapped to index 0, which is consistent with the index - // used by RegroupDataByModelId. - if (inference_stats->model_id_db().ids().empty()) { - inference_stats->mutable_model_id_db()->add_ids("ALL"); - inference_stats->mutable_model_id_db()->mutable_id_to_index()->insert( - {"ALL", 0}); - } - inference_stats->clear_inference_stats_per_host(); -} - -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/inference_stats_grouping.h b/tensorflow/core/profiler/convert/inference_stats_grouping.h deleted file mode 100644 index 7d60da0f311826..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_grouping.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_GROUPING_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_GROUPING_H_ - -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { - -// Change inference stats from per host to per model_id by doing a regroup. -// Future analysis of inference_stats will be on a per model_id basis. 
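-//
-// Example usage (illustrative only; the InferenceStats is assumed to have
-// been populated per host, e.g. by the combiner deleted in this patch):
-//   RegroupInferenceStatsByModel(&inference_stats);
-//   for (const auto& [model_index, per_model] :
-//        inference_stats.inference_stats_per_model()) {
-//     // per_model now carries requests, batches and aggregated results.
-//   }
-// Note that the per-host map is cleared once regrouping succeeds.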
-void RegroupInferenceStatsByModel( - tensorflow::profiler::InferenceStats* inference_stats); - -} // namespace tensorflow::profiler - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_GROUPING_H_ diff --git a/tensorflow/core/profiler/convert/inference_stats_grouping_test.cc b/tensorflow/core/profiler/convert/inference_stats_grouping_test.cc deleted file mode 100644 index 5d6d43e5ba8150..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_grouping_test.cc +++ /dev/null @@ -1,508 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/inference_stats_grouping.h" - -#include -#include "xla/tests/test_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { -namespace { - -using ::testing::EqualsProto; -using ::xla::ParseTextProto; - -TEST(InferenceStatsGroupingTest, TestWithModelId) { - // An inference stats with two hosts, two models. - InferenceStats inference_stats = ParseTextProto(R"pb( - inference_stats_per_host { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - model_id_index: 0 - request_id: 0 - device_time_ps: 100 - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - model_id_index: 1 - request_id: 1 - device_time_ps: 100 - } - } - } - inference_stats_per_host { - key: 1 - value { - request_details { - start_time_ps: 3000 - end_time_ps: 4000 - model_id_index: 0 - request_id: 2 - device_time_ps: 100 - } - request_details { - start_time_ps: 4000 - end_time_ps: 5000 - model_id_index: 1 - request_id: 3 - device_time_ps: 100 - } - } - } - model_id_db { - ids: "Model-A:1" - ids: "Model-B:1" - id_to_index { key: "Model-A:1" value: 0 } - id_to_index { key: "Model-B:1" value: 1 } - } - )pb") - .value(); - - RegroupInferenceStatsByModel(&inference_stats); - - // Verifies that requests with the same model ID are grouped together. 
- EXPECT_THAT(inference_stats, EqualsProto(R"pb( - model_id_db { - ids: "Model-A:1" - ids: "Model-B:1" - id_to_index { key: "Model-A:1" value: 0 } - id_to_index { key: "Model-B:1" value: 1 } - } - inference_stats_per_model { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - model_id_index: 0 - request_id: 0 - device_time_ps: 100 - } - request_details { - start_time_ps: 3000 - end_time_ps: 4000 - model_id_index: 0 - request_id: 2 - device_time_ps: 100 - } - aggregated_request_detail { - request_id: -1 - start_time_ps: 0 - end_time_ps: 1000 - write_to_device_time_ps: 0 - read_from_device_time_ps: 0 - device_time_ps: 100 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 0 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_detail {} - request_throughput: 666666666.66666663 - request_average_latency_us: 0.001 - batch_throughput: 0 - batch_average_latency_us: 0 - } - } - inference_stats_per_model { - key: 1 - value { - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - model_id_index: 1 - request_id: 1 - device_time_ps: 100 - } - request_details { - start_time_ps: 4000 - end_time_ps: 5000 - model_id_index: 1 - request_id: 3 - device_time_ps: 100 - } - aggregated_request_detail { - request_id: -1 - start_time_ps: 0 - end_time_ps: 1000 - write_to_device_time_ps: 0 - read_from_device_time_ps: 0 - device_time_ps: 100 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 0 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_detail {} - request_throughput: 666666666.66666663 - request_average_latency_us: 0.001 - batch_throughput: 0 - batch_average_latency_us: 0 - } - })pb")); -} - -TEST(InferenceStatsGroupingTest, TestTensorPatternPercentile) { - // Generates an inference stats for test, 6 requests have tensor events owned - // by REQUEST, 2 requests have tensor events owned by BATCH. 
- InferenceStats inference_stats = - ParseTextProto(R"pb( - inference_stats_per_host { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 0 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 600000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 1 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 500000 - } - } - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 2 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 400000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 3 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 300000 - } - } - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 4 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 200000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 5 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 100000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 6 - tensor_event_details { - tensor_pattern_index: 0 - owner: BATCH - linearize_delinearize_time_ps: 700000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 7 - tensor_event_details { - tensor_pattern_index: 0 - owner: BATCH - linearize_delinearize_time_ps: 800000 - } - } - } - } - )pb") - .value(); - - RegroupInferenceStatsByModel(&inference_stats); - - // Count equals to 6 because request tensor events owned by BATCH are ignored. - // Percentile selector selects linearize and delinearize time at 50.0, 75.0, - // 90.0, 95.0, 99.0, 99.9 percentiles. - EXPECT_THAT(inference_stats.inference_stats_per_model() - .at(0) - .tensor_transfer_aggregated_result(), - EqualsProto(R"pb( - tensor_pattern_results { - tensor_pattern_index: 0 - count: 6 - linearize_delinearize_percentile_time { - percentile: 50 - time_ps: 400000 - } - linearize_delinearize_percentile_time { - percentile: 75 - time_ps: 500000 - } - linearize_delinearize_percentile_time { - percentile: 90 - time_ps: 600000 - } - linearize_delinearize_percentile_time { - percentile: 95 - time_ps: 600000 - } - linearize_delinearize_percentile_time { - percentile: 99 - time_ps: 600000 - } - linearize_delinearize_percentile_time { - percentile: 99.9 - time_ps: 600000 - } - } - )pb")); -} - -TEST(InferenceStatsGroupingTest, TestWithoutModelId) { - // An inference stats with two hosts, no model ID data. 
- InferenceStats inference_stats = ParseTextProto(R"pb( - inference_stats_per_host { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 0 - related_batch_ids: 0 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 2000 - end_time_ps: 4000 - request_id: 1 - related_batch_ids: 0 - host_runtime_ps: 100 - } - batch_details { - batch_id: 0 - related_request_ids: 0 - related_request_ids: 1 - start_time_ps: 1000 - end_time_ps: 2000 - batch_size_after_padding: 128 - } - } - } - inference_stats_per_host { - key: 1 - value { - request_details { - start_time_ps: 3000 - end_time_ps: 6000 - request_id: 2 - related_batch_ids: 1 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 4000 - end_time_ps: 8000 - request_id: 3 - related_batch_ids: 1 - host_runtime_ps: 100 - } - batch_details { - batch_id: 1 - related_request_ids: 2 - related_request_ids: 3 - start_time_ps: 3000 - end_time_ps: 4000 - batch_size_after_padding: 256 - } - } - } - )pb") - .value(); - - RegroupInferenceStatsByModel(&inference_stats); - - // Verifies that all requests are grouped into a single model, and a "ALL" - // model ID is added. - EXPECT_THAT(inference_stats, EqualsProto(R"pb( - model_id_db { - ids: "ALL" - id_to_index { key: "ALL" value: 0 } - } - inference_stats_per_model { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 0 - related_batch_ids: 0 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 2000 - end_time_ps: 4000 - request_id: 1 - related_batch_ids: 0 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 3000 - end_time_ps: 6000 - request_id: 2 - related_batch_ids: 1 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 4000 - end_time_ps: 8000 - request_id: 3 - related_batch_ids: 1 - host_runtime_ps: 100 - } - batch_details { - batch_id: 0 - related_request_ids: 0 - related_request_ids: 1 - start_time_ps: 1000 - end_time_ps: 2000 - batch_size_after_padding: 128 - } - batch_details { - batch_id: 1 - related_request_ids: 2 - related_request_ids: 3 - start_time_ps: 3000 - end_time_ps: 4000 - batch_size_after_padding: 256 - } - aggregated_request_detail { - request_id: -1 - start_time_ps: 0 - end_time_ps: 2500 - write_to_device_time_ps: 0 - read_from_device_time_ps: 0 - device_time_ps: 0 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 100 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_detail { - batch_id: -1 - start_time_ps: 0 - end_time_ps: 1000 - batch_delay_ps: 0 - padding_amount: 0 - device_time_ps: 0 - batch_size_after_padding: 192 - } - per_batch_size_aggregated_result { - batch_size: 128 - aggregated_request_result { - start_time_ps: 0 - end_time_ps: 1500 - write_to_device_time_ps: 0 - read_from_device_time_ps: 0 - device_time_ps: 0 - request_id: -1 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 100 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_result { - batch_id: -1 - start_time_ps: 0 - end_time_ps: 1000 - batch_delay_ps: 0 - padding_amount: 0 - device_time_ps: 0 - batch_size_after_padding: 128 - } - request_throughput: 285714285.71428573 - batch_throughput: 333333333.33333331 - } - per_batch_size_aggregated_result { - batch_size: 256 - aggregated_request_result { - start_time_ps: 0 - end_time_ps: 3500 - write_to_device_time_ps: 0 - 
read_from_device_time_ps: 0 - device_time_ps: 0 - request_id: -1 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 100 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_result { - batch_id: -1 - start_time_ps: 0 - end_time_ps: 1000 - batch_delay_ps: 0 - padding_amount: 0 - device_time_ps: 0 - batch_size_after_padding: 256 - } - request_throughput: 285714285.71428573 - batch_throughput: 333333333.33333331 - } - request_throughput: 571428571.42857146 - request_average_latency_us: 0.0025 - batch_throughput: 666666666.66666663 - batch_average_latency_us: 0.001 - } - })pb")); -} - -} // namespace -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/inference_stats_sampler.cc b/tensorflow/core/profiler/convert/inference_stats_sampler.cc deleted file mode 100644 index be3e392ec85e78..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_sampler.cc +++ /dev/null @@ -1,311 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/inference_stats_sampler.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { - -namespace { - -using ::tensorflow::profiler::BatchDetail; -using ::tensorflow::profiler::InferenceStats; -using ::tensorflow::profiler::PerModelInferenceStats; -using ::tensorflow::profiler::RequestDetail; - -// Column names that can be used to do percentile selection. 
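-// (Illustrative aside, not in the original file: a caller picks one request
-// column and one batch column from the lists below, e.g.
-//   SampledInferenceStats sampled = SampleInferenceStats(
-//       "Request size", "Padding amount", inference_stats);
-// and any unrecognized column name falls back to latency ordering.)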
-// For request:
-constexpr char kColumnLatencyUs[] = "Latency";
-constexpr char kColumnBatchingRequestDelayUs[] = "Request delay for batching";
-constexpr char kColumnBatchingRequestSize[] = "Request size";
-constexpr char kColumnHostPreprocessing[] = "Host preprocess";
-constexpr char kColumnHostBatchFormation[] = "Host batch formation";
-constexpr char kColumnHostRuntime[] = "Host runtime";
-constexpr char kColumnHostToDevice[] = "Data transfer H2D";
-constexpr char kColumnDeviceToHost[] = "Data transfer D2H";
-constexpr char kColumnDeviceCompute[] = "Device compute";
-constexpr char kColumnHostPostprocessing[] = "Host postprocess";
-constexpr char kColumnIdleTime[] = "Idle time";
-// For batch:
-constexpr char kColumnBatchingDelayUs[] = "Batching delay";
-constexpr char kColumnPaddingAmount[] = "Padding amount";
-constexpr char kColumnBatchSizeAfterPadding[] = "Batch size after padding";
-constexpr char kColumnBatchingEfficiency[] = "Batching efficiency";
-
-double CalculateBatchingEfficiency(const BatchDetail& batch) {
-  return tsl::profiler::SafeDivide(
-      static_cast<double>(batch.batch_size_after_padding() -
-                          batch.padding_amount()),
-      static_cast<double>(batch.batch_size_after_padding()));
-}
-
-// Comparator for RequestDetail proto.
-bool CompareByRequestLatency(const RequestDetail* a, const RequestDetail* b) {
-  return (a->end_time_ps() - a->start_time_ps()) <
-         (b->end_time_ps() - b->start_time_ps());
-}
-bool CompareByBatchingRequestDelay(const RequestDetail* a,
-                                   const RequestDetail* b) {
-  return a->batching_request_delay_ps() < b->batching_request_delay_ps();
-}
-bool CompareByBatchingRequestSize(const RequestDetail* a,
-                                  const RequestDetail* b) {
-  return a->batching_request_size() < b->batching_request_size();
-}
-bool CompareByHostPreprocessing(const RequestDetail* a,
-                                const RequestDetail* b) {
-  return a->host_preprocessing_ps() < b->host_preprocessing_ps();
-}
-bool CompareByHostBatchFormation(const RequestDetail* a,
-                                 const RequestDetail* b) {
-  return a->host_batch_formation_ps() < b->host_batch_formation_ps();
-}
-bool CompareByHostRuntime(const RequestDetail* a, const RequestDetail* b) {
-  return a->host_runtime_ps() < b->host_runtime_ps();
-}
-bool CompareByHostToDevice(const RequestDetail* a, const RequestDetail* b) {
-  return a->write_to_device_time_ps() < b->write_to_device_time_ps();
-}
-bool CompareByDeviceToHost(const RequestDetail* a, const RequestDetail* b) {
-  return a->read_from_device_time_ps() < b->read_from_device_time_ps();
-}
-bool CompareByDeviceCompute(const RequestDetail* a, const RequestDetail* b) {
-  return a->device_time_ps() < b->device_time_ps();
-}
-bool CompareByPostProcessing(const RequestDetail* a, const RequestDetail* b) {
-  return a->host_postprocessing_ps() < b->host_postprocessing_ps();
-}
-bool CompareByIdleTime(const RequestDetail* a, const RequestDetail* b) {
-  return a->idle_time_ps() < b->idle_time_ps();
-}
-// Use percentile column name to get the corresponding compare function.
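-// (Illustrative aside, not in the original file: the returned std::function
-// wraps a plain comparator and can feed std::sort directly, e.g.
-//   std::sort(requests.begin(), requests.end(),
-//             GetRequestCompareFunction("Host runtime"));
-// sorts ascending by host_runtime_ps.)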
-std::function<bool(const RequestDetail*, const RequestDetail*)>
-GetRequestCompareFunction(absl::string_view column_name) {
-  if (column_name == kColumnBatchingRequestDelayUs) {
-    return CompareByBatchingRequestDelay;
-  } else if (column_name == kColumnBatchingRequestSize) {
-    return CompareByBatchingRequestSize;
-  } else if (column_name == kColumnHostPreprocessing) {
-    return CompareByHostPreprocessing;
-  } else if (column_name == kColumnHostBatchFormation) {
-    return CompareByHostBatchFormation;
-  } else if (column_name == kColumnHostRuntime) {
-    return CompareByHostRuntime;
-  } else if (column_name == kColumnHostToDevice) {
-    return CompareByHostToDevice;
-  } else if (column_name == kColumnDeviceToHost) {
-    return CompareByDeviceToHost;
-  } else if (column_name == kColumnDeviceCompute) {
-    return CompareByDeviceCompute;
-  } else if (column_name == kColumnHostPostprocessing) {
-    return CompareByPostProcessing;
-  } else if (column_name == kColumnIdleTime) {
-    return CompareByIdleTime;
-  } else {
-    // Return CompareByRequestLatency by default.
-    return CompareByRequestLatency;
-  }
-}
-
-// Comparator for BatchDetail proto.
-bool CompareByBatchLatency(const BatchDetail* a, const BatchDetail* b) {
-  return (a->end_time_ps() - a->start_time_ps()) <
-         (b->end_time_ps() - b->start_time_ps());
-}
-bool CompareByBatchDelay(const BatchDetail* a, const BatchDetail* b) {
-  return a->batch_delay_ps() < b->batch_delay_ps();
-}
-bool CompareByPaddingAmount(const BatchDetail* a, const BatchDetail* b) {
-  return a->padding_amount() < b->padding_amount();
-}
-bool CompareByBatchSizeAfterPadding(const BatchDetail* a,
-                                    const BatchDetail* b) {
-  return a->batch_size_after_padding() < b->batch_size_after_padding();
-}
-bool CompareByBatchingEfficiency(const BatchDetail* a, const BatchDetail* b) {
-  return CalculateBatchingEfficiency(*a) < CalculateBatchingEfficiency(*b);
-}
-// Use percentile column name to get the corresponding compare function.
-std::function<bool(const BatchDetail*, const BatchDetail*)>
-GetBatchCompareFunction(absl::string_view column_name) {
-  if (column_name == kColumnBatchingDelayUs) {
-    return CompareByBatchDelay;
-  } else if (column_name == kColumnPaddingAmount) {
-    return CompareByPaddingAmount;
-  } else if (column_name == kColumnBatchSizeAfterPadding) {
-    return CompareByBatchSizeAfterPadding;
-  } else if (column_name == kColumnBatchingEfficiency) {
-    return CompareByBatchingEfficiency;
-  } else {
-    // Return CompareByBatchLatency by default.
-    return CompareByBatchLatency;
-  }
-}
-
-// A static helper class to select a subset of inference data (request or
-// batch) to show in the frontend.
-// DataType can be either RequestDetail or BatchDetail.
-template <typename DataType>
-class PercentileSelector {
- public:
-  // The range of values in [percentile, percentile + error) is still regarded
-  // as percentile.
-  struct PercentileRange {
-    double percentile;
-    double error;
-  };
-
-  // The percentiles (with the corresponding error bounds) that will be
-  // included in the inference profile result.
-  static constexpr std::array<PercentileRange, 6> kWantedPercentiles = {
-      {{50.0, 1},
-       {75.0, 1},
-       {90.0, 1},
-       {99.0, 0.5},
-       {99.9, 0.05},
-       {99.99, 0.005}}};
-
-  // Maximum number of values included for each percentile range.
-  static constexpr size_t kMaxNumDataSelectedPerPercentile = 10;
-
-  // Selects a subset of data from <all_data>, returning pointers to the
-  // original data together with each datum's percentile.
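-  // (Illustrative aside, not in the original file: with n = 10000 sorted data
-  // points, index k maps to percentile 100.0 * k / n, so the window
-  // {99.9, 0.05} accepts 99.9 <= percentile < 99.95, i.e. indices
-  // 9990..9994; all five fall under the kMaxNumDataSelectedPerPercentile cap
-  // of 10 and are kept.)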
-  static std::vector<std::pair<const DataType*, double>> Select(
-      const std::vector<const DataType*>& all_data) {
-    return SelectInternal(all_data);
-  }
-
- private:
-  static bool GreaterThan(double percentile, const PercentileRange& wanted) {
-    // Uses ">=" instead of ">" so that the round-up value is not included.
-    return percentile >= (wanted.percentile + wanted.error);
-  }
-
-  static bool LessThan(double percentile, const PercentileRange& wanted) {
-    return percentile < wanted.percentile;
-  }
-
-  static bool WithinRange(double percentile, const PercentileRange& wanted) {
-    return !GreaterThan(percentile, wanted) && !LessThan(percentile, wanted);
-  }
-
-  static std::vector<std::pair<const DataType*, double>> SelectInternal(
-      const std::vector<const DataType*>& all_data) {
-    std::vector<std::pair<const DataType*, double>> result;
-    // If the number of data points is too small (smaller than the result size
-    // when selecting by percentile, like in a unit test), it does not make
-    // sense to select by percentile; just select all the data points and the
-    // frontend is able to display all of them.
-    if (all_data.size() <=
-        kWantedPercentiles.size() * kMaxNumDataSelectedPerPercentile) {
-      for (size_t i = 0; i < all_data.size(); i++) {
-        double percentile = 100.0 * i / all_data.size();
-        result.push_back(std::make_pair(all_data[i], percentile));
-      }
-      return result;
-    }
-
-    // Select by percentile.
-    size_t idx_to_next_data = 0;
-    for (size_t i = 0; i < kWantedPercentiles.size(); i++) {
-      const auto& wanted = kWantedPercentiles[i];
-      size_t num_data_selected = 0;
-      for (size_t k = idx_to_next_data; k < all_data.size(); k++) {
-        double percentile = 100.0 * k / all_data.size();
-        if (GreaterThan(percentile, wanted)) {
-          // Updates idx_to_next_data to k so that when we select data for the
-          // next percentile we don't need to consider the data with smaller
-          // latencies than that for the next percentile.
-          idx_to_next_data = k;
-          break;
-        }
-        if (WithinRange(percentile, wanted)) {
-          if (num_data_selected < kMaxNumDataSelectedPerPercentile) {
-            // Selects this data only if we have not hit the limit for this
-            // percentile.
-            result.push_back(std::make_pair(all_data[k], percentile));
-            ++num_data_selected;
-          }
-        }
-      }
-    }
-    return result;
-  }
-};
-
-// Samples the requests and batches in <per_model_stats> using sampling
-// columns <request_percentile_column> and <batch_percentile_column>.
-void SamplePerModelInferenceStats(
-    absl::string_view request_percentile_column,
-    absl::string_view batch_percentile_column,
-    const PerModelInferenceStats& per_model_stats,
-    SampledPerModelInferenceStats* sampled_per_model_stats) {
-  // Select a subset of requests and batches based on percentile and generate
-  // the final result.
-  std::vector<const RequestDetail*> requests(
-      per_model_stats.request_details_size());
-  for (size_t i = 0; i < per_model_stats.request_details_size(); i++) {
-    requests[i] = &per_model_stats.request_details(i);
-  }
-  // Requests in per model stats are already sorted by latency. Only redo the
-  // sorting when the percentile column is not latency.
-  if (request_percentile_column != kColumnLatencyUs) {
-    std::sort(requests.begin(), requests.end(),
-              GetRequestCompareFunction(request_percentile_column));
-  }
-  sampled_per_model_stats->sampled_requests =
-      PercentileSelector<RequestDetail>::Select(requests);
-
-  std::vector<const BatchDetail*> batches(
-      per_model_stats.batch_details_size());
-  for (size_t i = 0; i < per_model_stats.batch_details_size(); i++) {
-    batches[i] = &per_model_stats.batch_details(i);
-  }
-  // Batches in per model stats are already sorted by latency. Only redo the
-  // sorting when the percentile column is not latency.
-  if (batch_percentile_column != kColumnLatencyUs) {
-    std::sort(batches.begin(), batches.end(),
-              GetBatchCompareFunction(batch_percentile_column));
-  }
-  sampled_per_model_stats->sampled_batches =
-      PercentileSelector<BatchDetail>::Select(batches);
-}
-
-}  // namespace
-
-SampledInferenceStats SampleInferenceStats(
-    absl::string_view request_percentile_column,
-    absl::string_view batch_percentile_column,
-    const InferenceStats& inference_stats) {
-  SampledInferenceStats result;
-  for (const auto& [model_index, model_inference_stats] :
-       inference_stats.inference_stats_per_model()) {
-    SamplePerModelInferenceStats(request_percentile_column,
-                                 batch_percentile_column, model_inference_stats,
-                                 &(result[model_index]));
-  }
-
-  return result;
-}
-
-}  // namespace tensorflow::profiler
diff --git a/tensorflow/core/profiler/convert/inference_stats_sampler.h b/tensorflow/core/profiler/convert/inference_stats_sampler.h
deleted file mode 100644
index 2706c16a8ff97a..00000000000000
--- a/tensorflow/core/profiler/convert/inference_stats_sampler.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_SAMPLER_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_SAMPLER_H_
-
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/string_view.h"
-#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h"
-
-namespace tensorflow::profiler {
-
-// Sampled inference stats of a model.
-// The pointers to RequestDetail and BatchDetail point to the actual data
-// stored in TfOpStats.InferenceStats.
-struct SampledPerModelInferenceStats {
-  // Sampled requests and their percentiles.
-  std::vector<std::pair<const RequestDetail*, double>> sampled_requests;
-  // Sampled batches and their percentiles.
-  std::vector<std::pair<const BatchDetail*, double>> sampled_batches;
-};
-
-// All the sampled inference stats of a profile.
-// TODO: Move to use SampledInferenceStatsProto if feasible.
-using SampledInferenceStats =
-    absl::flat_hash_map<int, SampledPerModelInferenceStats>;
-
-// Samples a subset of InferenceStats from <inference_stats> based on sampling
-// columns <request_percentile_column> and <batch_percentile_column>.
-SampledInferenceStats SampleInferenceStats(
-    absl::string_view request_percentile_column,
-    absl::string_view batch_percentile_column,
-    const tensorflow::profiler::InferenceStats& inference_stats);
-
-}  // namespace tensorflow::profiler
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_SAMPLER_H_
diff --git a/tensorflow/core/profiler/convert/inference_stats_sampler_test.cc b/tensorflow/core/profiler/convert/inference_stats_sampler_test.cc
deleted file mode 100644
index 72c35a520a4cf4..00000000000000
--- a/tensorflow/core/profiler/convert/inference_stats_sampler_test.cc
+++ /dev/null
@@ -1,131 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/inference_stats_sampler.h" - -#include "absl/status/statusor.h" -#include "xla/tests/test_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { -namespace { -using ::tensorflow::profiler::InferenceStats; -using xla::ParseTextProto; - -TEST(ConvertInferenceStatsToInferenceProfileTest, TestSort) { - // Generate an inference stats for test. - // Requests and batches are ordered by latency (end_time_ps - start_time_ps), - // this is guaranteed by inference_stats.cc - InferenceStats inference_stats = ParseTextProto( - R"pb( - inference_stats_per_model { - key: 1 - value { - request_details { - request_id: 0 - start_time_ps: 0 - end_time_ps: 10000 - batching_request_delay_ps: 2000 - batching_request_size: 200 - } - request_details { - request_id: 1 - start_time_ps: 0 - end_time_ps: 20000 - batching_request_delay_ps: 1000 - batching_request_size: 100 - } - request_details { - request_id: 2 - start_time_ps: 0 - end_time_ps: 30000 - batching_request_delay_ps: 3000 - batching_request_size: 300 - } - batch_details { - batch_id: 3 - start_time_ps: 0 - end_time_ps: 10000 - batch_delay_ps: 2000 - padding_amount: 20 - batch_size_after_padding: 200 - } - batch_details { - batch_id: 4 - start_time_ps: 0 - end_time_ps: 20000 - batch_delay_ps: 1000 - padding_amount: 10 - batch_size_after_padding: 100 - } - batch_details { - batch_id: 5 - start_time_ps: 0 - end_time_ps: 30000 - batch_delay_ps: 3000 - padding_amount: 30 - batch_size_after_padding: 300 - } - } - } - )pb") - .value(); - - // Sort by latency, the result does not change. - auto result_1 = SampleInferenceStats("Latency", "Latency", inference_stats); - const auto& per_model_1 = result_1.at(1); - EXPECT_EQ(per_model_1.sampled_requests.at(0).first->request_id(), 0); - EXPECT_EQ(per_model_1.sampled_requests.at(1).first->request_id(), 1); - EXPECT_EQ(per_model_1.sampled_requests.at(2).first->request_id(), 2); - EXPECT_EQ(per_model_1.sampled_batches.at(0).first->batch_id(), 3); - EXPECT_EQ(per_model_1.sampled_batches.at(1).first->batch_id(), 4); - EXPECT_EQ(per_model_1.sampled_batches.at(2).first->batch_id(), 5); - - // Sort requests by Request size, sort batches by Padding amount. - // Verifies the values are in increasing order. - auto result_2 = - SampleInferenceStats("Request size", "Padding amount", inference_stats); - const auto& per_model_2 = result_2.at(1); - EXPECT_EQ(per_model_2.sampled_requests.at(0).first->batching_request_size(), - 100); - EXPECT_EQ(per_model_2.sampled_requests.at(1).first->batching_request_size(), - 200); - EXPECT_EQ(per_model_2.sampled_requests.at(2).first->batching_request_size(), - 300); - EXPECT_EQ(per_model_2.sampled_batches.at(0).first->padding_amount(), 10); - EXPECT_EQ(per_model_2.sampled_batches.at(1).first->padding_amount(), 20); - EXPECT_EQ(per_model_2.sampled_batches.at(2).first->padding_amount(), 30); - - // Sort requests by Request delay for batching, sort batches by - // Batching delay. 
Verifies the values are in increasing order. - auto result_3 = SampleInferenceStats("Request delay for batching", - "Batching delay", inference_stats); - const auto& per_model_3 = result_3.at(1); - EXPECT_EQ( - per_model_3.sampled_requests.at(0).first->batching_request_delay_ps(), - 1000); - EXPECT_EQ( - per_model_3.sampled_requests.at(1).first->batching_request_delay_ps(), - 2000); - EXPECT_EQ( - per_model_3.sampled_requests.at(2).first->batching_request_delay_ps(), - 3000); - EXPECT_EQ(per_model_3.sampled_batches.at(0).first->batch_delay_ps(), 1000); - EXPECT_EQ(per_model_3.sampled_batches.at(1).first->batch_delay_ps(), 2000); - EXPECT_EQ(per_model_3.sampled_batches.at(2).first->batch_delay_ps(), 3000); -} - -} // namespace -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc deleted file mode 100644 index 38cfb2ea2ffc4e..00000000000000 --- a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h" - -#include -#include - -#include "absl/status/status.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_stats_combiner.h" -#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "xprof/utils/hardware_type_utils.h" // from @org_xprof -#include "xprof/utils/step_intersection.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -absl::Status ConvertMultiXSpacesToCombinedOpStats( - const SessionSnapshot& session_snapshot, const OpStatsOptions& options, - OpStats* combined_op_stats) { - // Read multiple XSpaces and convert to multiple OpStats. - // TODO(profiler): Change the combiner to convert and combine one OpStats at a - // time, to reduce peak memory usage. - std::vector all_op_stats; - all_op_stats.reserve(session_snapshot.XSpaceSize()); - for (int i = 0; i < session_snapshot.XSpaceSize(); i++) { - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(i)); - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/true); - all_op_stats.push_back(ConvertXSpaceToOpStats(*xspace, options)); - } - - // Combine OpStats. 
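-  // (Illustrative aside, not in the original file: from the caller's side the
-  // whole conversion is just
-  //   OpStats combined;
-  //   absl::Status s = ConvertMultiXSpacesToCombinedOpStats(
-  //       session_snapshot, options, &combined);
-  // one OpStats per XSpace above, then a single merge below over the step
-  // intersection; kuint32max means no cap on the number of merged steps.)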
- std::vector all_op_stats_info; - all_op_stats_info.reserve(all_op_stats.size()); - for (int i = 0; i < all_op_stats.size(); i++) { - all_op_stats_info.emplace_back( - &all_op_stats[i], - ParseHardwareType(all_op_stats[i].run_environment().device_type()), i); - } - - // Do not limit the maximum number of steps during the merge of OpStats. - StepIntersection step_intersection = - ComputeStepIntersectionToMergeOpStats(all_op_stats_info, kuint32max); - CombineAllOpStats(all_op_stats_info, step_intersection, combined_op_stats); - - return absl::OkStatus(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h deleted file mode 100644 index 51348097d321f3..00000000000000 --- a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XPLANES_TO_OP_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XPLANES_TO_OP_STATS_H_ - -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -// Converts and combines multiple XSpace protos into a single OpStats -// . -// Return the first error status during conversion, or return OkStatus() if -// there is no error. -absl::Status ConvertMultiXSpacesToCombinedOpStats( - const SessionSnapshot& session_snapshot, const OpStatsOptions& options, - OpStats* combined_op_stats); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XPLANES_TO_OP_STATS_H_ diff --git a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc b/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc deleted file mode 100644 index f5cbb9b62a4b66..00000000000000 --- a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h" - -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/device_utils.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/convert/inference_stats.h" -#include "tensorflow/core/profiler/convert/inference_stats_combiner.h" -#include "tensorflow/core/profiler/convert/inference_stats_grouping.h" -#include "tensorflow/core/profiler/convert/inference_stats_sampler.h" -#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xplane_to_step_events.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/inference_stats.pb.h" // from @org_xprof -#include "xprof/utils/event_span.h" // from @org_xprof - -namespace tensorflow::profiler { - -namespace { -using tsl::profiler::FindMutablePlanesWithPrefix; -using tsl::profiler::FindMutablePlaneWithName; - -SampledInferenceStatsProto GetSampledInferenceStatsProto( - const InferenceStats& inference_stats, absl::string_view request_column, - absl::string_view batch_column) { - SampledInferenceStatsProto result; - SampledInferenceStats sampled_stats = - SampleInferenceStats(request_column, batch_column, inference_stats); - for (const auto& [model_index, samples] : sampled_stats) { - SampledPerModelInferenceStatsProto per_model_stats; - for (const auto& [request, percentile] : samples.sampled_requests) { - RequestDetail request_detail = *request; - request_detail.set_percentile(percentile); - *per_model_stats.add_sampled_requests() = request_detail; - } - for (const auto& [batch, percentile] : samples.sampled_batches) { - BatchDetail batch_detail = *batch; - batch_detail.set_percentile(percentile); - *per_model_stats.add_sampled_batches() = batch_detail; - } - result.mutable_sampled_inference_stats_per_model()->insert( - {model_index, per_model_stats}); - } - return result; -} -} // namespace - -StepEvents GetNonOverlappedStepEvents(XSpace* xspace) { - StepEvents non_overlapped_step_events; - - std::vector device_traces = - FindMutablePlanesWithPrefix(xspace, kGpuPlanePrefix); - if (device_traces.empty()) return non_overlapped_step_events; - - StepEvents device_step_events; - StepEvents host_step_events; - for (XPlane* device_trace : device_traces) { - StepEvents events = ConvertDeviceTraceXPlaneToStepEvents(*device_trace); - UnionCombineStepEvents(events, &device_step_events); - } - - XPlaneVisitor host_plane = tsl::profiler::CreateTfXPlaneVisitor( - FindMutablePlaneWithName(xspace, kHostThreadsPlaneName)); - - host_plane.ForEachLine([&](const XLineVisitor& line) { - StepEvents events = - ConvertHostThreadsXLineToStepEvents(line, &device_step_events); - UnionCombineStepEvents(events, &host_step_events); - }); - StepEvents overlapped_step_events; - UnionCombineStepEvents(device_step_events, &overlapped_step_events); - 
UnionCombineStepEvents(host_step_events, &overlapped_step_events); - non_overlapped_step_events = - ToNonOverlappedStepEvents(overlapped_step_events); - return non_overlapped_step_events; -} - -absl::Status ConvertMultiXSpaceToInferenceStats( - const SessionSnapshot& session_snapshot, absl::string_view request_column, - absl::string_view batch_column, InferenceStats* inference_stats) { - for (int i = 0; i < session_snapshot.XSpaceSize(); ++i) { - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(i)); - tsl::profiler::GroupMetadataMap metadata_map; - InferenceStats inference_stats_per_host; - std::vector device_traces = - tsl::profiler::FindMutableTensorCorePlanes(xspace.get()); - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/false, &metadata_map); - StepEvents non_overlapped_step_events = - GetNonOverlappedStepEvents(xspace.get()); - GenerateInferenceStats( - device_traces, non_overlapped_step_events, metadata_map, *xspace, - tsl::profiler::DeviceType::kTpu, i, &inference_stats_per_host); - CombineInferenceStatsResult(i, inference_stats_per_host, inference_stats); - } - RegroupInferenceStatsByModel(inference_stats); - *inference_stats->mutable_sampled_inference_stats() = - GetSampledInferenceStatsProto(*inference_stats, request_column, - batch_column); - return absl::OkStatus(); -} -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h b/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h deleted file mode 100644 index 9d8399f2f43f62..00000000000000 --- a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_ - -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/inference_stats.pb.h" // from @org_xprof -#include "xprof/utils/event_span.h" // from @org_xprof - -namespace tensorflow::profiler { -// Get non overlapped step events from xspace for GPU. 
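-// Example usage (illustrative only; not part of the original header):
-//   InferenceStats stats;
-//   absl::Status s = ConvertMultiXSpaceToInferenceStats(
-//       session_snapshot, /*request_column=*/"Latency",
-//       /*batch_column=*/"Latency", &stats);
-//   // On success, stats holds per-model groups plus sampled
-//   // requests/batches.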
-StepEvents GetNonOverlappedStepEvents(XSpace* xspace); - -absl::Status ConvertMultiXSpaceToInferenceStats( - const SessionSnapshot& session_snapshot, absl::string_view request_column, - absl::string_view batch_column, InferenceStats* inference_stats); -} // namespace tensorflow::profiler - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_ diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc deleted file mode 100644 index ab0c25b6f38c33..00000000000000 --- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using OperationType = OpMetrics::MemoryAccessed::OperationType; - -void CombinePrecisionStats(const PrecisionStats& src, PrecisionStats* dst) { - dst->set_compute_16bit_ps(src.compute_16bit_ps() + dst->compute_16bit_ps()); - dst->set_compute_32bit_ps(src.compute_32bit_ps() + dst->compute_32bit_ps()); -} - -} // namespace - -void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst) { - DCHECK(dst != nullptr); - DCHECK_EQ(src.hlo_module_id(), dst->hlo_module_id()); - DCHECK_EQ(src.name(), dst->name()); - if (dst->long_name().empty()) { - dst->set_long_name(src.long_name()); - } - if (dst->fingerprint() == 0) { - dst->set_fingerprint(src.fingerprint()); - } - if (dst->category().empty()) { - dst->set_category(src.category()); - } - if (dst->provenance().empty()) { - dst->set_provenance(src.provenance()); - } - if (dst->deduplicated_name().empty()) { - dst->set_deduplicated_name(src.deduplicated_name()); - } - if (!dst->has_layout() && src.has_layout()) { - *dst->mutable_layout() = src.layout(); - } - if (!dst->has_children() && src.has_children()) { - *dst->mutable_children() = src.children(); - } -} - -void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst, - bool update_num_cores) { - DCHECK(dst != nullptr); - if (dst->occurrences() == 0) { - dst->set_min_time_ps(src.min_time_ps()); - } else { - dst->set_min_time_ps(std::min(src.min_time_ps(), dst->min_time_ps())); - } - dst->set_is_eager(dst->is_eager() || src.is_eager()); - dst->set_occurrences(src.occurrences() + dst->occurrences()); - dst->set_time_ps(src.time_ps() + dst->time_ps()); - dst->set_self_time_ps(src.self_time_ps() + dst->self_time_ps()); - dst->set_flops(src.flops() + dst->flops()); - dst->set_model_flops(src.model_flops() + dst->model_flops()); - dst->set_bytes_accessed(src.bytes_accessed() + dst->bytes_accessed()); - 
dst->set_autotuned(dst->autotuned() || src.autotuned());
-  if (update_num_cores) {
-    dst->set_num_cores(src.num_cores() + dst->num_cores());
-  }
-  CombineMemoryAccessedBreakdown(src.memory_accessed_breakdown(),
-                                 dst->mutable_memory_accessed_breakdown());
-  dst->set_dma_stall_ps(src.dma_stall_ps() + dst->dma_stall_ps());
-}
-
-void CombineMemoryAccessedBreakdown(
-    const tsl::protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>& src,
-    tsl::protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>* dst) {
-  if (src.empty()) return;
-  absl::flat_hash_map<std::pair<uint64 /*memory_space*/, OperationType>,
-                      OpMetrics_MemoryAccessed*>
-      dst_memory_accessed_map;
-  for (auto& dst_memory_accessed : *dst) {
-    dst_memory_accessed_map[{dst_memory_accessed.memory_space(),
-                             dst_memory_accessed.operation_type()}] =
-        &dst_memory_accessed;
-  }
-  for (const auto& src_memory_accessed : src) {
-    uint64 memory_space = src_memory_accessed.memory_space();
-    OperationType operation_type = src_memory_accessed.operation_type();
-    auto*& dst_memory_accessed =
-        dst_memory_accessed_map[{memory_space, operation_type}];
-    if (dst_memory_accessed == nullptr) {
-      dst_memory_accessed = dst->Add();
-      dst_memory_accessed->set_memory_space(memory_space);
-      dst_memory_accessed->set_operation_type(operation_type);
-    }
-    dst_memory_accessed->set_bytes_accessed(
-        src_memory_accessed.bytes_accessed() +
-        dst_memory_accessed->bytes_accessed());
-  }
-}
-
-void OpMetricsDbCombiner::Combine(const OpMetricsDb& src,
-                                  bool update_num_cores) {
-  OpMetricsDb* dst = db();
-  dst->set_total_host_infeed_enq_duration_ps(
-      src.total_host_infeed_enq_duration_ps() +
-      dst->total_host_infeed_enq_duration_ps());
-  dst->set_total_host_infeed_enq_start_timestamp_ps_diff(
-      src.total_host_infeed_enq_start_timestamp_ps_diff() +
-      dst->total_host_infeed_enq_start_timestamp_ps_diff());
-  dst->set_total_time_ps(src.total_time_ps() + dst->total_time_ps());
-  dst->set_total_op_time_ps(src.total_op_time_ps() + dst->total_op_time_ps());
-  dst->set_idle_time_ps(src.idle_time_ps() + dst->idle_time_ps());
-  dst->set_busy_time_ps(src.busy_time_ps() + dst->busy_time_ps());
-  CombinePrecisionStats(src.precision_stats(), dst->mutable_precision_stats());
-
-  for (const auto& src_metrics : src.metrics_db()) {
-    auto* dst_metrics = LookupOrInsertNewOpMetrics(src_metrics.hlo_module_id(),
-                                                   src_metrics.name());
-    CopyOpMetricsMetadata(src_metrics, dst_metrics);
-    CombineOpMetrics(src_metrics, dst_metrics, update_num_cores);
-  }
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.h b/tensorflow/core/profiler/convert/op_metrics_db_combiner.h
deleted file mode 100644
index d538a232e4f41e..00000000000000
--- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
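
For reference, the accumulation rules that CombineOpMetrics applies in the
hunk above reduce to a few lines: counters add, minima take the min (guarded
by a "first combine" check on occurrences), and boolean flags OR together.
This is a minimal sketch over a plain struct whose fields mirror a subset of
the proto; it illustrates the combine semantics only, not the full method.

#include <algorithm>
#include <cstdint>

struct Metrics {
  uint64_t occurrences = 0;
  uint64_t time_ps = 0;
  uint64_t min_time_ps = 0;
  bool is_eager = false;
};

void Combine(const Metrics& src, Metrics* dst) {
  // A zero-occurrence dst has no valid minimum yet, so adopt src's.
  dst->min_time_ps = (dst->occurrences == 0)
                         ? src.min_time_ps
                         : std::min(src.min_time_ps, dst->min_time_ps);
  dst->is_eager = dst->is_eager || src.is_eager;  // flags OR together
  dst->occurrences += src.occurrences;            // counters add
  dst->time_ps += src.time_ps;
}
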
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_
-
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tsl/platform/protobuf.h"
-#include "xprof/utils/op_metrics_db_utils.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-// Copies OpMetrics metadata (e.g., category, provenance) from src to dst.
-void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst);
-
-// Combines OpMetrics data (e.g., occurrences, time) from src into dst.
-// If <update_num_cores> is set to true, updates dst->num_cores to count the
-// number of cores on which a certain op occurs.
-void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst,
-                      bool update_num_cores);
-
-// Combines the memory access breakdown.
-void CombineMemoryAccessedBreakdown(
-    const tsl::protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>& src,
-    tsl::protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>* dst);
-
-// Helper to combine op metrics databases.
-class OpMetricsDbCombiner : public OpMetricsDbBuilder {
- public:
-  explicit OpMetricsDbCombiner(OpMetricsDb* dst) : OpMetricsDbBuilder(dst) {}
-
-  // Combines the OpMetrics in the given OpMetricsDb into this combiner's db.
-  // If <update_num_cores> is set to true, updates OpMetrics.num_cores to
-  // count the number of cores on which a certain op occurs.
-  void Combine(const OpMetricsDb& src, bool update_num_cores = true);
-};
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_
diff --git a/tensorflow/core/profiler/convert/op_metrics_to_record.cc b/tensorflow/core/profiler/convert/op_metrics_to_record.cc
deleted file mode 100644
index b6f1cadb59388c..00000000000000
--- a/tensorflow/core/profiler/convert/op_metrics_to_record.cc
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" - -#include -#include - -#include "absl/algorithm/container.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" - -namespace tensorflow { -namespace profiler { - -std::vector SortedOpMetricsDb(const OpMetricsDb& metrics_db, - int max_records) { - std::vector result; - result.reserve(metrics_db.metrics_db_size()); - for (const OpMetrics& metrics : metrics_db.metrics_db()) { - result.push_back(&metrics); - } - - auto comp = [](const OpMetrics* a, const OpMetrics* b) { - return std::make_tuple(a->self_time_ps(), b->name()) > - std::make_tuple(b->self_time_ps(), a->name()); - }; - int result_size = result.size(); - if (max_records != -1 && result_size > max_records) { - absl::c_partial_sort(result, result.begin() + max_records, comp); - result.resize(max_records); - } else { - absl::c_sort(result, comp); - } - return result; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_metrics_to_record.h b/tensorflow/core/profiler/convert/op_metrics_to_record.h deleted file mode 100644 index 4884fb64adc24c..00000000000000 --- a/tensorflow/core/profiler/convert/op_metrics_to_record.h +++ /dev/null @@ -1,341 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_TO_RECORD_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_TO_RECORD_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -std::vector SortedOpMetricsDb(const OpMetricsDb& metrics_db, - int max_records = -1); - -inline double GigaFlopsPerSecondPerCore(const OpMetrics& metrics) { - // flops and time_ps are accumulated across all occurrences on all cores. - // time_ps is used instead of self_time_ps because flops for an op includes - // the flops executed by children (nested) ops. - return tsl::profiler::SafeDivide( - metrics.flops(), tsl::profiler::PicoToNano(metrics.time_ps())); -} - -inline double GigaModelFlopsPerSecondPerCore(const OpMetrics& metrics) { - // flops and time_ps are accumulated across all occurrences on all cores. - // time_ps is used instead of self_time_ps because flops for an op includes - // the flops executed by children (nested) ops. - return tsl::profiler::SafeDivide( - metrics.model_flops(), tsl::profiler::PicoToNano(metrics.time_ps())); -} - -// Return ByteAccessed for memory_space and operation_type. 
-inline double BytesAccessedPerCore(
-    const OpMetrics& metrics, uint64_t memory_space,
-    OpMetrics::MemoryAccessed::OperationType operation_type) {
-  uint64_t bytes = 0;
-  if (memory_space == MemorySpace::MEMORY_SPACE_ALL) {
-    bytes = metrics.bytes_accessed();
-  } else {
-    for (const auto& breakdown : metrics.memory_accessed_breakdown()) {
-      // Count either on-chip or off-chip bytes.
-      if ((breakdown.operation_type() != operation_type) &&
-          (operation_type != OpMetrics::MemoryAccessed::UNKNOWN)) {
-        continue;
-      }
-      if (((memory_space == MemorySpace::MEMORY_SPACE_HBM) &&
-           (breakdown.memory_space() == MemorySpace::MEMORY_SPACE_HBM)) ||
-          ((memory_space == MemorySpace::MEMORY_SPACE_ON_CHIP) &&
-           (breakdown.memory_space() != MemorySpace::MEMORY_SPACE_HBM))) {
-        bytes += breakdown.bytes_accessed();
-      }
-    }
-  }
-  return bytes;
-}
-
-inline double GigaBytesPerSecondPerCore(
-    const OpMetrics& metrics, uint64_t memory_space,
-    OpMetrics::MemoryAccessed::OperationType operation_type) {
-  // bytes_accessed and time_ps are accumulated across all occurrences on all
-  // cores.
-  // time_ps is used instead of self_time_ps because bytes_accessed for an op
-  // includes the bytes accessed by children (nested) ops.
-  return tsl::profiler::SafeDivide(
-      BytesAccessedPerCore(metrics, memory_space, operation_type),
-      tsl::profiler::PicoToNano(metrics.time_ps()));
-}
-
-inline double GibiBytesPerSecondPerCore(
-    const OpMetrics& metrics, uint64_t memory_space,
-    OpMetrics::MemoryAccessed::OperationType op_type) {
-  return tsl::profiler::GigaToGibi(
-      GigaBytesPerSecondPerCore(metrics, memory_space, op_type));
-}
-
-template <typename Record>
-inline void SetExecutionTimes(const OpMetrics& metrics, Record* record) {
-  record->set_occurrences(metrics.occurrences());
-  record->set_total_time_in_us(tsl::profiler::PicoToMicro(metrics.time_ps()));
-  record->set_avg_time_in_us(tsl::profiler::SafeDivide(
-      record->total_time_in_us(), metrics.occurrences()));
-  record->set_total_self_time_in_us(
-      tsl::profiler::PicoToMicro(metrics.self_time_ps()));
-  record->set_avg_self_time_in_us(tsl::profiler::SafeDivide(
-      record->total_self_time_in_us(), metrics.occurrences()));
-}
-
-template <typename Record>
-inline void SetTpuUnitFractions(const OpMetrics& metrics, Record* record) {
-  record->set_dma_stall_fraction(
-      tsl::profiler::SafeDivide(metrics.dma_stall_ps(), metrics.time_ps()));
-}
-
-template <typename Record>
-inline void SetRankAndTimeFractions(double total_time_us,
-                                    const Record& prev_record, Record* record) {
-  record->set_rank(prev_record.rank() + 1);
-  record->set_total_self_time_as_fraction(tsl::profiler::SafeDivide(
-      record->total_self_time_in_us(), total_time_us));
-  record->set_cumulative_total_self_time_as_fraction(
-      prev_record.cumulative_total_self_time_as_fraction() +
-      record->total_self_time_as_fraction());
-}
-
-template <typename Record>
-inline void SetRankAndDeviceTimeFractions(double total_time_us,
-                                          const Record& prev_record,
-                                          Record* record) {
-  record->set_rank(prev_record.rank() + 1);
-  record->set_device_total_self_time_as_fraction(tsl::profiler::SafeDivide(
-      record->total_self_time_in_us(), total_time_us));
-  record->set_device_cumulative_total_self_time_as_fraction(
-      prev_record.device_cumulative_total_self_time_as_fraction() +
-      record->device_total_self_time_as_fraction());
-}
-
-template <typename Record>
-inline void SetRankAndHostTimeFractions(double total_time_us,
-                                        const Record& prev_record,
-                                        Record* record) {
-  record->set_rank(prev_record.rank() + 1);
-  record->set_host_total_self_time_as_fraction(tsl::profiler::SafeDivide(
-      record->total_self_time_in_us(),
total_time_us)); - record->set_host_cumulative_total_self_time_as_fraction( - prev_record.host_cumulative_total_self_time_as_fraction() + - record->host_total_self_time_as_fraction()); -} - -// Returns the memory bandwidth in GigaBytes/s in the PerfEnv. -// memory space is chosen by index following order in xplane_to_op_stats.cc -static inline double GetMemoryPeakBandwidth(const PerfEnv& perf_env, - const int index) { - if (perf_env.peak_bws_giga_bytes_per_second_size() > index) { - return perf_env.peak_bws_giga_bytes_per_second(index); - } - return perf_env.peak_hbm_bw_giga_bytes_per_second(); -} - -template -inline void SetRooflineMetrics(const OpMetrics& metrics, const PerfEnv perf_env, - const RunEnvironment& run_env, Record* record) { - using ::tensorflow::profiler::MemorySpace; - using ::tensorflow::profiler::PerformanceInfo; - - // Set overall performance metrics. - record->set_measured_flop_rate(GigaFlopsPerSecondPerCore(metrics)); - record->set_model_flop_rate(GigaModelFlopsPerSecondPerCore(metrics)); - record->set_measured_memory_bw(GibiBytesPerSecondPerCore( - metrics, tensorflow::profiler::MemorySpace::MEMORY_SPACE_ALL, - OpMetrics::MemoryAccessed::UNKNOWN)); - record->set_flops(metrics.flops()); - record->set_bytes_accessed(metrics.bytes_accessed()); - record->set_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), metrics.bytes_accessed())); - // Set performance metrics per memory access type. - uint64_t hbm_bytes = 0; - uint64_t cmem_read_bytes = 0; - uint64_t cmem_write_bytes = 0; - uint64_t vmem_read_bytes = 0; - uint64_t vmem_write_bytes = 0; - for (const auto& memory_access : metrics.memory_accessed_breakdown()) { - if (memory_access.memory_space() == PerformanceInfo::MemoryAccessed::HBM) { - hbm_bytes += memory_access.bytes_accessed(); - } else if (memory_access.memory_space() == - PerformanceInfo::MemoryAccessed::CMEM) { - if (memory_access.operation_type() == OpMetrics::MemoryAccessed::READ) { - cmem_read_bytes += memory_access.bytes_accessed(); - } else if (memory_access.operation_type() == - OpMetrics::MemoryAccessed::WRITE) { - cmem_write_bytes += memory_access.bytes_accessed(); - } - } else if (memory_access.memory_space() == - PerformanceInfo::MemoryAccessed::VMEM) { - if (memory_access.operation_type() == OpMetrics::MemoryAccessed::READ) { - vmem_read_bytes += memory_access.bytes_accessed(); - } else if (memory_access.operation_type() == - OpMetrics::MemoryAccessed::WRITE) { - vmem_write_bytes += memory_access.bytes_accessed(); - } - } - } - if (metrics.memory_accessed_breakdown_size() == 0) { - // For legacy profiles without memory access breakdown, consider all memory - // access as HBM access. 
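
SetRooflineMetrics above turns byte counts and picosecond times into GiB/s via
SafeDivide, PicoToNano, and GigaToGibi. Since bytes divided by nanoseconds is
exactly gigabytes per second, the time is converted pico to nano before the
division. A self-contained worked example of those conversions, with the
helpers re-implemented locally for illustration (they stand in for the
tsl::profiler math utilities):

#include <cstdio>

double PicoToNano(double ps) { return ps / 1e3; }
double SafeDivide(double a, double b) { return b == 0.0 ? 0.0 : a / b; }
// Converts a value expressed in giga (1e9) units to gibi (2^30) units.
double GigaToGibi(double g) { return g * 1e9 / (1024.0 * 1024.0 * 1024.0); }

int main() {
  double bytes = 48.0 * 1024 * 1024;  // 48 MiB moved by an op (assumed)
  double time_ps = 2.0e9;             // 2 ms of op time (assumed)
  // bytes / ns == 1e9 bytes per second == GB/s.
  double gb_per_s = SafeDivide(bytes, PicoToNano(time_ps));
  double gib_per_s = GigaToGibi(gb_per_s);
  std::printf("%.2f GB/s = %.2f GiB/s\n", gb_per_s, gib_per_s);
  return 0;
}

Running this prints roughly 25.17 GB/s = 23.44 GiB/s, the same pattern the
record->set_*_bw(...) calls above compute per memory space.
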
- hbm_bytes = metrics.bytes_accessed(); - } - record->set_hbm_bw(tsl::profiler::GibibytesPerSecond( - hbm_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_cmem_read_bw(tsl::profiler::GibibytesPerSecond( - cmem_read_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_cmem_write_bw(tsl::profiler::GibibytesPerSecond( - cmem_write_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_vmem_read_bw(tsl::profiler::GibibytesPerSecond( - vmem_read_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_vmem_write_bw(tsl::profiler::GibibytesPerSecond( - vmem_write_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_hbm_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), hbm_bytes)); - record->set_cmem_read_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), cmem_read_bytes)); - record->set_cmem_write_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), cmem_write_bytes)); - record->set_vmem_read_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), vmem_read_bytes)); - record->set_vmem_write_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), vmem_write_bytes)); - // Resources considered for roofline analysis. - constexpr absl::string_view kUnknown = "Unknown"; - constexpr absl::string_view kCompute = "Compute"; - constexpr absl::string_view kHbm = "HBM"; - constexpr absl::string_view kCmemRead = "CMEM Read"; - constexpr absl::string_view kCmemWrite = "CMEM Write"; - constexpr absl::string_view kVmemRead = "VMEM Read"; - constexpr absl::string_view kVmemWrite = "VMEM Write"; - constexpr absl::string_view kShmL1 = "Shm/L1"; - // Compute the bound time assuming the peak capacity of each resource and - // choose the highest one as the bottleneck. See go/xprof-roofline-pxc for - // more details. - // NOTE: The roofline analysis result is the same for Megacore because every - // resource's capacity is doubled for Megacore so the comparison result is the - // same. 
- absl::string_view bottleneck_resource = kUnknown; - double bottleneck_utilization = 0; - double bottleneck_operational_intensity = 0; - double peak_flops = - tsl::profiler::TeraToGiga(perf_env.peak_tera_flops_per_second()); - double flops_utilization = - tsl::profiler::SafeDivide(record->measured_flop_rate(), peak_flops); - if (bottleneck_utilization < flops_utilization) { - bottleneck_resource = kCompute; - bottleneck_utilization = flops_utilization; - bottleneck_operational_intensity = record->operational_intensity(); - } - double peak_hbm_bw = GetMemoryPeakBandwidth(perf_env, 0); - double hbm_bw_utilization = tsl::profiler::SafeDivide( - record->hbm_bw(), tsl::profiler::GigaToGibi(peak_hbm_bw)); - if (bottleneck_utilization < hbm_bw_utilization) { - bottleneck_resource = kHbm; - bottleneck_utilization = hbm_bw_utilization; - bottleneck_operational_intensity = record->hbm_operational_intensity(); - } - tensorflow::profiler::HardwareType hardware_type = run_env.hardware_type(); - if (hardware_type == tensorflow::profiler::HardwareType::TPU) { - if (cmem_read_bytes) { - double peak_cmem_read_bw = GetMemoryPeakBandwidth(perf_env, 3); - if (peak_cmem_read_bw) { - double cmem_read_bw_utilization = tsl::profiler::SafeDivide( - record->cmem_read_bw(), - tsl::profiler::GigaToGibi(peak_cmem_read_bw)); - if (bottleneck_utilization < cmem_read_bw_utilization) { - bottleneck_resource = kCmemRead; - bottleneck_utilization = cmem_read_bw_utilization; - bottleneck_operational_intensity = - record->cmem_read_operational_intensity(); - } - } - } - if (cmem_write_bytes) { - double peak_cmem_write_bw = GetMemoryPeakBandwidth(perf_env, 4); - if (peak_cmem_write_bw) { - double cmem_write_bw_utilization = tsl::profiler::SafeDivide( - record->cmem_write_bw(), - tsl::profiler::GigaToGibi(peak_cmem_write_bw)); - if (bottleneck_utilization < cmem_write_bw_utilization) { - bottleneck_resource = kCmemWrite; - bottleneck_utilization = cmem_write_bw_utilization; - bottleneck_operational_intensity = - record->cmem_write_operational_intensity(); - } - } - } - if (vmem_read_bytes) { - double peak_vmem_read_bw = GetMemoryPeakBandwidth(perf_env, 5); - if (peak_vmem_read_bw) { - double vmem_read_bw_utilization = tsl::profiler::SafeDivide( - record->vmem_read_bw(), - tsl::profiler::GigaToGibi(peak_vmem_read_bw)); - if (bottleneck_utilization < vmem_read_bw_utilization) { - bottleneck_resource = kVmemRead; - bottleneck_utilization = vmem_read_bw_utilization; - bottleneck_operational_intensity = - record->vmem_read_operational_intensity(); - } - } - } - if (vmem_write_bytes) { - double peak_vmem_write_bw = GetMemoryPeakBandwidth(perf_env, 6); - if (peak_vmem_write_bw) { - double vmem_write_bw_utilization = tsl::profiler::SafeDivide( - record->vmem_write_bw(), - tsl::profiler::GigaToGibi(peak_vmem_write_bw)); - if (bottleneck_utilization < vmem_write_bw_utilization) { - bottleneck_resource = kVmemWrite; - bottleneck_utilization = vmem_write_bw_utilization; - bottleneck_operational_intensity = - record->vmem_write_operational_intensity(); - } - } - } - } - if (hardware_type == tensorflow::profiler::HardwareType::GPU) { - double peak_shm_l1_bw = GetMemoryPeakBandwidth(perf_env, 2); - if (peak_shm_l1_bw) { - // Currently, we only have general read/write bandwidth in record. 
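
The chain of comparisons above is an argmax over resource utilizations: each
candidate's measured rate is divided by its peak, and the op is "bound by"
whichever resource is most utilized. The same logic condensed into one loop,
with placeholder resource names and peak values rather than profiler data:

#include <string>
#include <vector>

struct Resource {
  std::string name;
  double measured;  // e.g. GFLOP/s or GiB/s for this op
  double peak;      // hardware peak in the same unit (assumed known)
};

std::string BoundBy(const std::vector<Resource>& resources) {
  std::string bottleneck = "Unknown";
  double best_utilization = 0.0;
  for (const Resource& r : resources) {
    double utilization = r.peak > 0 ? r.measured / r.peak : 0.0;
    if (utilization > best_utilization) {
      best_utilization = utilization;
      bottleneck = r.name;
    }
  }
  return bottleneck;
}
// Example: BoundBy({{"Compute", 50.0, 100.0}, {"HBM", 720.0, 800.0}})
// returns "HBM" (0.9 utilization beats 0.5).
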
- double shm_l1_bw_utilization = tsl::profiler::SafeDivide( - record->hbm_bw(), tsl::profiler::GigaToGibi(peak_shm_l1_bw)); - if (bottleneck_utilization < shm_l1_bw_utilization) { - bottleneck_resource = kShmL1; - bottleneck_utilization = shm_l1_bw_utilization; - bottleneck_operational_intensity = record->hbm_operational_intensity(); - } - } - } - record->set_bound_by(std::string(bottleneck_resource)); - record->set_bottleneck_operational_intensity( - bottleneck_operational_intensity); -} - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_TO_RECORD_H_ diff --git a/tensorflow/core/profiler/convert/op_profile_builder.cc b/tensorflow/core/profiler/convert/op_profile_builder.cc deleted file mode 100644 index 8741b92df8bfff..00000000000000 --- a/tensorflow/core/profiler/convert/op_profile_builder.cc +++ /dev/null @@ -1,445 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_profile_builder.h" - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/node_hash_map.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/ascii.h" -#include "absl/strings/str_cat.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/lib/gtl/top_n.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" -#include "tsl/platform/protobuf.h" -#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using op_profile::Metrics; -using op_profile::Node; -using tsl::profiler::IsFusion; - -double CapUtilization(double utilization) { return std::min(utilization, 1.0); } - -// Fill symbol details into a node. 
-void PopulateSymbolNode(const OpMetrics& op_metrics, Node* node) {
-  node->set_name(op_metrics.name());
-  Node::XLAInstruction& xla = *node->mutable_xla();
-  xla.set_program_id(op_metrics.hlo_module_id());
-  xla.set_expression(op_metrics.long_name());
-  xla.set_fingerprint(op_metrics.fingerprint());
-  xla.set_category(op_metrics.category());
-  xla.set_provenance(op_metrics.provenance());
-  if (op_metrics.has_layout()) {
-    for (const auto& dimension : op_metrics.layout().dimensions()) {
-      auto* dim = xla.mutable_layout()->add_dimensions();
-      dim->set_size(dimension.size());
-      dim->set_alignment(dimension.alignment());
-      dim->set_semantics(absl::AsciiStrToLower(
-          LayoutDimensionSemantics_Name(dimension.semantics())));
-    }
-  }
-  xla.set_computation_primitive_size(op_metrics.computation_primitive_size());
-}
-
-// Sort the children and only keep the top K children.
-template <typename Cmp>
-Node TopKChildren(const Node* root, int k, Cmp cmp) {
-  tensorflow::gtl::TopN<const Node*, Cmp> top_n(k, cmp);
-  for (const Node& node : root->children()) {
-    top_n.push(&node);
-  }
-  Node output;
-  std::unique_ptr<std::vector<const Node*>> extracted_nodes(top_n.Extract());
-  for (const Node* node : *extracted_nodes) {
-    *output.add_children() = *node;
-  }
-  return output;
-}
-
-// Copy symbol details into a deduplicated node from the top child node.
-void CopySymbolDetailsToDeduplicatedNode(Node* top_child_node,
-                                         Node* deduplicated_node) {
-  deduplicated_node->set_name(
-      absl::StrCat(top_child_node->name(), " and its duplicate(s)"));
-  Node::XLAInstruction& xla = *deduplicated_node->mutable_xla();
-  const Node::XLAInstruction& top_child_node_xla = top_child_node->xla();
-  xla.set_program_id(top_child_node_xla.program_id());
-  xla.set_expression(top_child_node_xla.expression());
-  xla.set_fingerprint(top_child_node_xla.fingerprint());
-  xla.set_category(top_child_node_xla.category());
-  if (IsFusion(top_child_node_xla.category())) return;
-  xla.set_provenance(top_child_node_xla.provenance());
-  *xla.mutable_layout() = top_child_node_xla.layout();
-}
-
-void SortAndPruneChildren(int k, int level, Node* root) {
-  // Set the total number of children before pruning.
-  root->set_num_children(root->children_size());
-  for (Node& node : *root->mutable_children()) {
-    SortAndPruneChildren(k, level - 1, &node);
-  }
-  k = level > 0 ? root->children_size() : k;
-
-  if (root->children_size() > 1) {
-    if (root->has_xla() && IsFusion(root->xla().category())) {
-      // Sort the children under fusion node by raw flops.
-      *root->mutable_children() =
-          TopKChildren(root, k, [](const Node* a, const Node* b) {
-            return a->metrics().raw_flops() > b->metrics().raw_flops();
-          }).children();
-    } else {
-      *root->mutable_children() =
-          TopKChildren(root, k, [](const Node* a, const Node* b) {
-            return a->metrics().raw_time() > b->metrics().raw_time();
-          }).children();
-    }
-  }
-}
-
-// Finalize deduplicated nodes by copying symbol details from the top child
-// node.
-void FinalizeDeduplicatedNodes(bool by_program, Node* root) {
-  if (by_program) {
-    for (Node& program_node : *root->mutable_children()) {
-      for (Node& category_node : *program_node.mutable_children()) {
-        for (Node& deduplicated_node : *category_node.mutable_children()) {
-          // Node with 1 child doesn't have deduplication, the child is itself.
-          // Removing the dedup layer.
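
TopKChildren above delegates the bounded selection to tensorflow::gtl::TopN.
The same keep-k behavior can be sketched with the standard library: with the
"better" comparator used as the heap's less-than, std::priority_queue keeps
the worst retained element on top, so it is the one evicted once the heap
exceeds k. This generic stand-in is not the gtl implementation:

#include <algorithm>
#include <cstddef>
#include <queue>
#include <vector>

template <typename T, typename Cmp>
std::vector<T> TopK(const std::vector<T>& items, std::size_t k, Cmp better) {
  // With "better" as the less-than comparator, top() is the element nothing
  // else compares below, i.e. the worst one kept so far.
  std::priority_queue<T, std::vector<T>, Cmp> heap(better);
  for (const T& item : items) {
    heap.push(item);
    if (heap.size() > k) heap.pop();  // evict the current worst
  }
  std::vector<T> result;
  while (!heap.empty()) {
    result.push_back(heap.top());  // drains worst -> best
    heap.pop();
  }
  std::reverse(result.begin(), result.end());  // best first
  return result;
}
// Usage: TopK<int>({5, 1, 9, 3}, 2, [](int a, int b) { return a > b; })
// yields {9, 5}.
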
- if (deduplicated_node.children_size() == 1) { - Node child = *deduplicated_node.mutable_children(0); - deduplicated_node = child; - continue; - } - CopySymbolDetailsToDeduplicatedNode( - deduplicated_node.mutable_children(0), &deduplicated_node); - } - } - } - } else { - for (Node& category_node : *root->mutable_children()) { - for (Node& deduplicated_node : *category_node.mutable_children()) { - // Node with 1 child doesn't have deduplication, the child is itself. - // Removing the dedup layer. - if (deduplicated_node.children_size() == 1) { - Node child = *deduplicated_node.mutable_children(0); - deduplicated_node = child; - continue; - } - CopySymbolDetailsToDeduplicatedNode( - deduplicated_node.mutable_children(0), &deduplicated_node); - } - } - } -} - -// Fills op metrics into a node. -void PopulateOpMetricsNode( - const OpMetrics& op_metrics, double peak_gigaflops_per_second_per_core, - std::vector peak_mem_gibibytes_per_second_per_core, - uint64_t total_time_ps, Node* node) { - - Metrics* metrics = node->mutable_metrics(); - // The UI computes flops_rate = raw_flops / raw_time - // and memory_bandwidth = raw_bytes_accessed / raw_time. See: - // https://github.com/tensorflow/profiler/blob/master/frontend/app/common/utils/utils.ts - metrics->set_raw_time(op_metrics.time_ps()); - metrics->set_raw_flops(op_metrics.model_flops()); - metrics->set_occurrences(op_metrics.occurrences()); - metrics->set_avg_time_ps(tsl::profiler::SafeDivide(op_metrics.time_ps(), - op_metrics.occurrences())); - - double flops_utilization = CapUtilization( - tsl::profiler::SafeDivide(GigaFlopsPerSecondPerCore(op_metrics), - peak_gigaflops_per_second_per_core)); - // The UI expects flops_utilization = flop_util / time_fraction. See: - // https://github.com/tensorflow/profiler/blob/master/frontend/app/common/utils/utils.ts - const double time_fraction = - tsl::profiler::SafeDivide(op_metrics.time_ps(), total_time_ps); - metrics->set_flops(flops_utilization * time_fraction); - - // Capture both on-chip and off-chip memory utilization. 
- const double hbm_gibibytes_per_second = - tsl::profiler::GigaToGibi( - GigaBytesPerSecondPerCore(op_metrics, MemorySpace::MEMORY_SPACE_HBM, - OpMetrics::MemoryAccessed::READ)) + - tsl::profiler::GigaToGibi( - GigaBytesPerSecondPerCore(op_metrics, MemorySpace::MEMORY_SPACE_HBM, - OpMetrics::MemoryAccessed::WRITE)); - const double hbm_bw_utilization = CapUtilization(tsl::profiler::SafeDivide( - hbm_gibibytes_per_second, - peak_mem_gibibytes_per_second_per_core[MemBwType::MEM_BW_TYPE_HBM_RW])); - metrics->add_bandwidth_utils(hbm_bw_utilization); - double hbm_bytes = tsl::profiler::GibiToGiga(hbm_gibibytes_per_second) * - tsl::profiler::PicoToNano(op_metrics.time_ps()); - - const double sram_rd_gibibytes_per_second = tsl::profiler::GigaToGibi( - GigaBytesPerSecondPerCore(op_metrics, MemorySpace::MEMORY_SPACE_ON_CHIP, - OpMetrics::MemoryAccessed::READ)); - const double sram_rd_bw_utilization = - CapUtilization(tsl::profiler::SafeDivide( - sram_rd_gibibytes_per_second, peak_mem_gibibytes_per_second_per_core - [MemBwType::MEM_BW_TYPE_SRAM_RD])); - metrics->add_bandwidth_utils(sram_rd_bw_utilization); - double sram_rd_bytes = - tsl::profiler::GibiToGiga(sram_rd_gibibytes_per_second) * - tsl::profiler::PicoToNano(op_metrics.time_ps()); - - const double sram_wr_gibibytes_per_second = tsl::profiler::GigaToGibi( - GigaBytesPerSecondPerCore(op_metrics, MemorySpace::MEMORY_SPACE_ON_CHIP, - OpMetrics::MemoryAccessed::WRITE)); - const double sram_wr_bw_utilization = - CapUtilization(tsl::profiler::SafeDivide( - sram_wr_gibibytes_per_second, peak_mem_gibibytes_per_second_per_core - [MemBwType::MEM_BW_TYPE_SRAM_WR])); - metrics->add_bandwidth_utils(sram_wr_bw_utilization); - double sram_wr_bytes = - tsl::profiler::GibiToGiga(sram_wr_gibibytes_per_second) * - tsl::profiler::PicoToNano(op_metrics.time_ps()); - - metrics->add_raw_bytes_accessed_array(hbm_bytes); - metrics->add_raw_bytes_accessed_array(sram_rd_bytes); - metrics->add_raw_bytes_accessed_array(sram_wr_bytes); -} - -// Recursively insert "fused instruction" nodes (with raw flops). 
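
PopulateOpMetricsNode above normalizes each measured rate against a hardware
peak, caps the result at 1.0, and additionally weights the flops utilization
by the op's share of total time so the UI can divide it back out. A condensed
sketch of that normalization, with placeholder peaks and without the
per-memory-space breakdown:

#include <algorithm>

double CapUtilization(double u) { return std::min(u, 1.0); }

struct NodeMetrics {
  double flops_util_times_time_fraction;
  double hbm_bw_util;
};

NodeMetrics Populate(double op_gflops_per_s, double peak_gflops_per_s,
                     double op_gibps, double peak_gibps, double op_time_ps,
                     double total_time_ps) {
  NodeMetrics m;
  double flops_util = CapUtilization(op_gflops_per_s / peak_gflops_per_s);
  double time_fraction = op_time_ps / total_time_ps;
  // Stored pre-multiplied, matching the UI convention noted above.
  m.flops_util_times_time_fraction = flops_util * time_fraction;
  m.hbm_bw_util = CapUtilization(op_gibps / peak_gibps);
  return m;
}
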
-void InsertFusedInstructions(const OpMetrics& op_metrics, Node* node) { - if (!op_metrics.has_children()) return; - for (const auto& child : op_metrics.children().metrics_db()) { - Node* new_node = node->add_children(); - PopulateSymbolNode(child, new_node); - new_node->mutable_metrics()->set_raw_flops(child.flops()); - if (child.has_children()) { - InsertFusedInstructions(child, new_node); - } - } -} - -void UpdateNodeMetrics(const OpMetrics& child, OpMetrics* parent) { - DCHECK(parent != nullptr); - parent->set_time_ps(child.self_time_ps() + parent->time_ps()); - if (ChildrenTimePs(child) == 0) { - parent->set_flops(child.flops() + parent->flops()); - parent->set_model_flops(child.model_flops() + parent->model_flops()); - parent->set_bytes_accessed(child.bytes_accessed() + - parent->bytes_accessed()); - parent->set_dma_stall_ps(child.dma_stall_ps() + parent->dma_stall_ps()); - CombineMemoryAccessedBreakdown(child.memory_accessed_breakdown(), - parent->mutable_memory_accessed_breakdown()); - } -} - -} // namespace - -std::string OpProfileBuilder::GenerateProgramName(uint64_t program_id) const { - DCHECK(program_name_map_ != nullptr); - auto iter = program_name_map_->find(program_id); - if (iter == program_name_map_->end()) return "main"; - return tsl::profiler::HloModuleNameWithProgramId(iter->second, program_id); -} - -Node* OpProfileBuilder::AddOpNode(const OpMetrics& op_metrics, - Category* category, Node* deduplicated_node) { - Node* leaf; - if (deduplicated_node != nullptr) { - leaf = deduplicated_node->add_children(); - } else if (category != nullptr) { - leaf = category->node->add_children(); - } else { - leaf = root_->add_children(); - } - PopulateSymbolNode(op_metrics, leaf); - InsertFusedInstructions(op_metrics, leaf); - return leaf; -} - -// Function to create deduplicated aggregation layer. -// 1. Empty deduplicated_name in op_metrics means either: -// (1) a grouping op of a deduplicated op list. (fusion.3 in the example below) -// (2) an op that does not have duplicates. (fusion.4 in the example below) -// We create dedup layer for both cases due to lack of clue which case it is. -// The op name is used directly as the hash key for the dedup group. The dedup -// layer will be removed in the 2nd pass for case (2). -// 2. Non-empty deduplicated_name means this op can be grouped to a -// deduplicated op list (fusion.1 in the example below). -// Example: -// op_metrics { -// name: "fusion.1" -// deduplicated_name: "fusion.3" -// category: "convolution" -// } -// op_metrics { -// name: "fusion.3" -// deduplicated_name: "" -// category: "convolution" -// } -// op_metrics { -// name: "fusion.4" -// deduplicated_name: "" -// category: "convolution" -// } -// The data above will create the following tree after calling the function -// repeatedly: -// root(by_program) -// - jit.xx -// - convolution -// - fusion.3 -// - fusion.1 -// - fusion.2 -// - fusion.3 -// - fusion.4 -// - fusion.4 -// After finalization, the tree will look like: -// root(by_program) -// - jit.xx -// - convolution -// - fusion.3 and its duplicate(s) -// - fusion.1 -// - fusion.2 -// - fusion.3 -// - fusion.4 -Node* OpProfileBuilder::LookupOrAddDeduplicatedNode(const OpMetrics& op_metrics, - Category* category) { - std::string deduplicated_name = op_metrics.deduplicated_name().empty() - ? 
op_metrics.name() - : op_metrics.deduplicated_name(); - Node*& deduplicated_node = category->deduplicated_nodes[deduplicated_name]; - if (deduplicated_node == nullptr) { - deduplicated_node = category->node->add_children(); - // Set deduplicated name which is the hash key for the dedup group. - // Symbol details will be added in finalization step. - deduplicated_node->set_name(deduplicated_name); - } - return deduplicated_node; -} - -OpProfileBuilder::Category* OpProfileBuilder::LookupOrAddCategoryNode( - const OpMetrics& op_metrics, Program* program) { - Category* category; - Node* category_parent; - if (program != nullptr) { - category = &program->categories[op_metrics.category()]; - category_parent = program->node; - } else { - category = &category_map_[op_metrics.category()]; - category_parent = root_; - } - if (category->node == nullptr) { - category->node = category_parent->add_children(); - category->node->set_name(op_metrics.category()); - } - return category; -} - -OpProfileBuilder::Program* OpProfileBuilder::LookupOrAddProgramNode( - const OpMetrics& op_metrics) { - uint64_t program_id = op_metrics.hlo_module_id(); - Program* program = &programs_map_[program_id]; - if (program->node == nullptr) { - program->node = root_->add_children(); - program->node->set_name(GenerateProgramName(program_id)); - } - return program; -} - -void OpProfileBuilder::AddOp(const OpMetrics& op_metrics) { - // 1. Deal with nested parent nodes - // op_metrics.time_ps in root node will be reset to total_time_ps later - UpdateNodeMetrics(op_metrics, &metrics_[root_]); - Program* program = nullptr; - if (!IsIdleOp(op_metrics) && options_.group_by_program) { - program = LookupOrAddProgramNode(op_metrics); - UpdateNodeMetrics(op_metrics, &metrics_[program->node]); - } - - // 2. Deal with nested grouping nodes, only accumulate non-child ops - if (ChildrenTimePs(op_metrics) > 0) return; - std::vector nested_grouping_nodes; - if (IsIdleOp(op_metrics)) { - Node* leaf = AddOpNode(op_metrics); - nested_grouping_nodes.push_back(leaf); - } else { - Category* category = LookupOrAddCategoryNode(op_metrics, program); - nested_grouping_nodes.push_back(category->node); - - Node* deduplicated_node = nullptr; - if (options_.group_by_deduplicated_name) { - deduplicated_node = LookupOrAddDeduplicatedNode(op_metrics, category); - nested_grouping_nodes.push_back(deduplicated_node); - } - - Node* leaf = AddOpNode(op_metrics, category, deduplicated_node); - nested_grouping_nodes.push_back(leaf); - } - - for (auto* node : nested_grouping_nodes) { - // Per program combiner does not need to update OpMetrics.num_cores - CombineOpMetrics(op_metrics, &metrics_[node], /*update_num_cores=*/false); - } -} - -void OpProfileBuilder::Finalize( - double peak_gigaflops_per_second_per_core, - std::vector peak_mem_gibibytes_per_second_per_core, - uint64_t total_time_ps) { - // Call to `PopulateOpMetricsNode` depends on node time_ps to calculate - // flops, bandwidth_utils..etc. The root / program node time_ps might - // be off a bit, missing its own self_time when calling `UpdateNodeMetrics`. - // This is best effort to at least reset the time_ps for root node to be more - // precise. 
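
AddOp above accumulates in two phases: every enclosing node (root, program)
receives the op's self time, while flops-like counters are added only for
leaf ops, i.e. ops with no children time, to avoid double counting. A
stripped-down sketch of that scheme; the Node and Acc types are stand-ins for
op_profile::Node and the metrics_ side map, not the builder's actual types:

#include <cstdint>
#include <map>
#include <vector>

struct Node {};  // identity only; metrics live in the side map below

struct Acc {
  uint64_t time_ps = 0;
  uint64_t flops = 0;
};

void AddOp(uint64_t self_time_ps, uint64_t flops, bool is_leaf,
           const std::vector<Node*>& ancestors, Node* leaf,
           std::map<Node*, Acc>* metrics) {
  for (Node* n : ancestors) {
    (*metrics)[n].time_ps += self_time_ps;  // parents: self time only
  }
  if (is_leaf) {
    (*metrics)[leaf].time_ps += self_time_ps;
    (*metrics)[leaf].flops += flops;  // leaves: full counters, counted once
  }
}
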
- metrics_[root_].set_time_ps(total_time_ps); - for (const auto& [node, op_metrics] : metrics_) { - PopulateOpMetricsNode(op_metrics, peak_gigaflops_per_second_per_core, - peak_mem_gibibytes_per_second_per_core, total_time_ps, - node); - } - // If grouping by program, we build a two-level pruned tree: the first level - // is per program and the second level is per category. Otherwise we build a - // single-level per category pruned tree. - int level = options_.group_by_program ? 2 : 1; - SortAndPruneChildren(options_.children_per_node, level, root_); - if (options_.group_by_deduplicated_name) { - FinalizeDeduplicatedNodes(options_.group_by_program, root_); - } -} - -OpProfileBuilder::OpProfileBuilder( - const OpProfileOptions& options, - tensorflow::profiler::op_profile::Node* root, - const tsl::protobuf::Map* program_name_map) - : options_(options), root_(root), program_name_map_(program_name_map) { - if (root == nullptr) { - LOG(DFATAL) << "root is null."; - return; - } - DCHECK(!options_.group_by_program || program_name_map_ != nullptr); - root->set_name(options_.group_by_program ? "by_program" : "by_category"); -} -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_profile_builder.h b/tensorflow/core/profiler/convert/op_profile_builder.h deleted file mode 100644 index 3d4e7abd1f6b18..00000000000000 --- a/tensorflow/core/profiler/convert/op_profile_builder.h +++ /dev/null @@ -1,157 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_PROFILE_BUILDER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_PROFILE_BUILDER_H_ - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/node_hash_map.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" - -namespace tensorflow { -namespace profiler { - -struct OpProfileOptions { - bool group_by_program = true; - bool group_by_deduplicated_name = true; - int children_per_node = 100; -}; - -// The structure of an op profile tree may looks like below: -// 1. group "by_program" -// - It starts from the root node, named as "by_program", and this node does -// not show up in op profile. -// - The children of root node is a list of hlo program node, named as the -// program/module name (eg. cluster.xx). -// - The children of a program node is hlo op category node, named as the -// category name (eg. data formatting). -// - The children of a category node is a list of op node or deduplicated -// group node: -// - For op that has duplicates, the child will be a deduplicated node, -// named like "copy.1111 and its deduplicate(s)". Its children will be all op -// nodes that are deduplicated. 
-// - For op that does not have duplicates, the child will be an op node -// under the op category (eg. copy.2222). -// -// Example path: "by_program" -> "main(...)" -// -> "data_formatting" -> "copy.12345 and its duplicate(s) -> "copy.12345" -// -// 2. group "by_category" -// Similarly to how the `by_program` op profile tree is constructed, -// `by_category` just removed the "program_node" layer: -// - It starts from the root node, named as "by_category", this node also does -// not show up in op profile. -// - The children of root node is a list of op category node, everything below -// is similar to above. -// - ... -// -// Example path: "by_category" -> "data_formatting" -> "copy.12345 and its -// duplicate(s) -> "copy.12345" -// -// How the op profile metrics are calculated: -// 1. For parent node in the nested structure like root node and program node: -// - time_ps will be accumulated from the self_time of all op nodes under it -// (might still be off a bit if the parent node has self_time, more details in -// b/333608397#comment5) -// - flops and memory access will only be accumulated from leaf op node under -// it to avoid double counting -// - unable to get occurrences of program executions now -// 2. For conceptual horizontal grouping node (eg.category, deduplicated) -// - all op_metris fields will be accumulated from leaf op node only in the -// group, to avoid double counting -class OpProfileBuilder { - public: - OpProfileBuilder(const OpProfileOptions& options, op_profile::Node* root, - const tensorflow::protobuf::Map* - program_name_map = nullptr); - - // Accumulate the op_metrics to the op_profile node tree - void AddOp(const OpMetrics& op_metrics); - - // Finalize the op_profile proto in a few steps (inter-dependent): - // 1. Reset time_ps for root node for more precise total time - // 2. Loop over the node to op_metrics map, populate corresponding op_metrics - // to the node.metrics - // 3. `SortAndPruneChildren` given query param `op_profile_limit` - // 4. `FinalizeDeduplicatedNodes` by coping the first op node data to the - // deduplicated node - void Finalize(double peak_gigaflops_per_second_per_core, - std::vector peak_mem_gibibytes_per_second_per_core, - uint64_t total_time_ps); - - private: - struct Category { - op_profile::Node* node; - absl::flat_hash_map deduplicated_nodes; - }; - - struct Program { - op_profile::Node* node; - absl::flat_hash_map categories; - }; - - std::string GenerateProgramName(uint64_t program_id) const; - - // Adds and returns a node for op_metrics. - // If op_metrics corresponds to a fusion, adds children to the node for the - // fused instructions. - // If deduplicated_node is not null, adds the node under it. - // Otherwise, if category is not null, adds the node under category. - // Otherwise, adds the node under root. - op_profile::Node* AddOpNode(const OpMetrics& op_metrics, - Category* category = nullptr, - op_profile::Node* deduplicated_node = nullptr); - - // Returns a node for op_metrics.deduplicated_name(). - // Adds a node to the tree if necessary. - op_profile::Node* LookupOrAddDeduplicatedNode(const OpMetrics& op_metrics, - Category* category); - - // Returns a node for op_metrics.category(). - // Adds a node to the tree if necessary. - // If program is not null, the category node is added under program. - // Otherwise, the category node is added under root. - Category* LookupOrAddCategoryNode(const OpMetrics& op_metrics, - Program* program); - - // Returns a node for op_metrics.hlo_module_id(). 
-  // Adds a node to the Node tree if necessary.
-  Program* LookupOrAddProgramNode(const OpMetrics& op_metrics);
-
-  OpProfileOptions options_;
-  op_profile::Node* root_;
-
-  // Map to look up and aggregate OpMetrics.
-  absl::node_hash_map<op_profile::Node*, OpMetrics> metrics_;
-
-  // Maps to look up if a category / program / deduplicated node has
-  // already been added to the tree.
-  absl::flat_hash_map<uint64_t, Program> programs_map_;
-  absl::flat_hash_map<std::string, Category> category_map_;
-
-  // Map to look up program names by id.
-  const tensorflow::protobuf::Map<uint64_t, std::string>* program_name_map_ =
-      nullptr;
-};
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_OP_PROFILE_BUILDER_H_
diff --git a/tensorflow/core/profiler/convert/op_stack.h b/tensorflow/core/profiler/convert/op_stack.h
deleted file mode 100644
index 6bfa4d776436da..00000000000000
--- a/tensorflow/core/profiler/convert/op_stack.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STACK_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STACK_H_
-
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace profiler {
-
-template <typename OpInfo>
-class OpStack {
- public:
-  // Pushes an Op onto the stack.
-  void Push(uint32 op_id, std::unique_ptr<OpInfo> op_info) {
-    stack_.emplace_back(op_id, std::move(op_info));
-  }
-
-  // Pops the Op with the given op_id from the stack.
-  std::unique_ptr<OpInfo> Pop(uint32 op_id) {
-    // Pop until match or stack_ is empty.
-    std::unique_ptr<OpInfo> result;
-    while (!stack_.empty()) {
-      auto back = std::move(stack_.back());
-      stack_.pop_back();
-      if (op_id == back.first) {
-        result = std::move(back.second);
-        break;
-      }
-    }
-    return result;
-  }
-
-  // Returns the Op at the top of the stack.
-  OpInfo* Top() const {
-    return stack_.empty() ? nullptr : stack_.back().second.get();
-  }
-
-  // Returns true if the stack is empty.
-  bool Empty() const { return stack_.empty(); }
-
-  // Clears the stack.
-  void Clear() { stack_.clear(); }
-
- private:
-  std::vector<std::pair<uint32 /*op_id*/, std::unique_ptr<OpInfo>>> stack_;
-};
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STACK_H_
diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.cc b/tensorflow/core/profiler/convert/op_stats_combiner.cc
deleted file mode 100644
index 5c4b1bf08abb27..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_combiner.cc
+++ /dev/null
@@ -1,318 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_combiner.h" - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/power_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/protobuf/topology.pb.h" -#include "xprof/utils/hardware_type_utils.h" // from @org_xprof -#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof -#include "xprof/utils/step_intersection.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -namespace { - -// Combines the src PerCoreStepInfo into the dst PerCoreStepInfo. -void CombinePerCoreStepInfo( - int src_host_id, const PerCoreStepInfo& src, bool use_incomplete_step, - PerCoreStepInfo* dst, - OpMetricsDbCombiner* hlo_metrics_db_complete_steps_only_combiner, - OpMetricsDbCombiner* hlo_metrics_db_per_step_combiner) { - CombineCoreIdMap(src_host_id, src.step_info_per_core(), - dst->mutable_step_info_per_core()); - - // Since we have assigned a new step number to the combined result, update - // the step number on each core to this new step number. 
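
CombinePerCoreStepInfo below relies on CombineCoreIdMap to move per-core
entries from one host into a global core-id space before insertion. A sketch
of what that re-keying amounts to, assuming the host_id * 1000 + ordinal
scheme that GlobalCoreId uses later in this file; std::map stands in for the
proto map type:

#include <cstdint>
#include <map>

constexpr uint32_t kMaxDevicesPerHost = 1000;

uint32_t GlobalCoreId(int host_id, uint32_t device_ordinal) {
  return host_id * kMaxDevicesPerHost + device_ordinal;
}

template <typename V>
void CombineCoreIdMap(int src_host_id, const std::map<uint32_t, V>& src,
                      std::map<uint32_t, V>* dst) {
  for (const auto& [core_id, value] : src) {
    (*dst)[GlobalCoreId(src_host_id, core_id)] = value;  // re-keyed insert
  }
}
// E.g. host 2, core 3 lands at key 2003, so cores from different hosts never
// collide as long as each host has fewer than 1000 devices.
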
- uint32 new_step_num = dst->step_num(); - for (auto& percore_stepinfo : *dst->mutable_step_info_per_core()) { - auto& stepinfo = percore_stepinfo.second; - stepinfo.set_step_num(new_step_num); - } - - if (!use_incomplete_step) { - hlo_metrics_db_complete_steps_only_combiner->Combine(src.hlo_metrics_db()); - } - hlo_metrics_db_per_step_combiner->Combine(src.hlo_metrics_db()); - CombineCoreIdMap(src_host_id, src.all_reduce_db_per_core(), - dst->mutable_all_reduce_db_per_core()); - CombineCoreIdMap(src_host_id, src.core_id_to_replica_id_map(), - dst->mutable_core_id_to_replica_id_map()); -} - -void CombineStepDatabase( - int src_host_id, const StepIntersection& step_intersection, - const StepDatabaseResult& src, StepDatabaseResult* dst, - OpMetricsDbCombiner* hlo_metrics_db_complete_steps_only_combiner, - std::vector* hlo_metrics_db_per_step_combiners) { - if (src.use_incomplete_step()) dst->set_use_incomplete_step(true); - uint32 src_first_step_idx = step_intersection.FirstStepIndex(src_host_id); - for (uint32 i = 0; i < step_intersection.NumSteps(); i++) { - CombinePerCoreStepInfo( - src_host_id, src.step_sequence(src_first_step_idx + i), - src.use_incomplete_step(), dst->mutable_step_sequence(i), - hlo_metrics_db_complete_steps_only_combiner, - &(*hlo_metrics_db_per_step_combiners)[i]); - } -} - -void CombinePowerMetrics(const RunEnvironment& src, RunEnvironment* dst) { - const size_t src_hosts = src.hostnames_size(); - const size_t dst_hosts = dst->hostnames_size(); - const double src_weight = src_hosts * 1.0 / (src_hosts + dst_hosts); - const double dst_weight = dst_hosts * 1.0 / (src_hosts + dst_hosts); - // Always assume src/dst have the same number of power components. - for (const auto& src_metric : src.power_metrics().power_component_metrics()) { - for (auto& dst_metric : - *dst->mutable_power_metrics()->mutable_power_component_metrics()) { - if (src_metric.component_name() != dst_metric.component_name()) continue; - dst_metric.set_max_power( - std::max(src_metric.max_power(), dst_metric.max_power())); - dst_metric.set_avg_power(src_metric.avg_power() * src_weight + - dst_metric.avg_power() * dst_weight); - } - } -} - -void CombineRunEnvironment(const RunEnvironment& src, RunEnvironment* dst) { - dst->mutable_hostnames()->insert(src.hostnames().begin(), - src.hostnames().end()); - dst->set_host_count(dst->hostnames_size()); - // Ignore CPU and Unknown Device type for device type selection if the - // destination does not have a device type already. - if (src.device_type() != "CPU" && src.device_type() != "Device") { - dst->set_device_type(src.device_type()); - dst->set_device_core_count(src.device_core_count() + - dst->device_core_count()); - // Replica count and num cores per replica must be same for all copies. - dst->set_replica_count(std::max(src.replica_count(), dst->replica_count())); - dst->set_num_cores_per_replica( - std::max(src.num_cores_per_replica(), dst->num_cores_per_replica())); - *dst->mutable_system_topology() = src.system_topology(); - } else if (dst->device_type().empty()) { - dst->set_device_type(src.device_type()); - } - if (src.hardware_type() != dst->hardware_type()) { - // Select the highest hardware type as TPU/GPU should override CPU_ONLY - // (e.g. coordinator). - dst->set_hardware_type(std::max(src.hardware_type(), dst->hardware_type())); - } - dst->set_task_count(src.task_count() + dst->task_count()); - // Only overwrite the dst if profile_duration_ms in dst is not defined or - // is zero and profile_duration_ms in src is greater than zero. 
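
CombinePowerMetrics above averages power weighted by each side's share of the
total host count, while max power is a plain max. The same rule in isolation,
with a worked example (Power is an illustrative struct, not the proto):

#include <algorithm>
#include <cstddef>

struct Power {
  double max_power;
  double avg_power;
};

Power CombinePower(const Power& src, std::size_t src_hosts, const Power& dst,
                   std::size_t dst_hosts) {
  const double src_weight =
      static_cast<double>(src_hosts) / (src_hosts + dst_hosts);
  const double dst_weight =
      static_cast<double>(dst_hosts) / (src_hosts + dst_hosts);
  return Power{std::max(src.max_power, dst.max_power),
               src.avg_power * src_weight + dst.avg_power * dst_weight};
}
// Example: src = {300 W max, 250 W avg} over 1 host combined with
// dst = {280 W max, 200 W avg} over 3 hosts -> {300 W max, 212.5 W avg}.
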
-  if (src.host_independent_job_info().profile_duration_ms() > 0) {
-    (*dst->mutable_host_independent_job_info()) =
-        src.host_independent_job_info();
-  }
-  for (const auto& job_info : src.host_dependent_job_info()) {
-    *(dst->add_host_dependent_job_info()) = job_info;
-  }
-  dst->set_host_trace_level(src.host_trace_level());
-  dst->set_is_training(src.is_training());
-  CombinePowerMetrics(src, dst);
-}
-
-// Combines the src PerfEnv into the dst PerfEnv.
-void CombinePerfEnv(const PerfEnv& src, PerfEnv* dst) {
-  if (src.peak_tera_flops_per_second() > 0) {
-    dst->set_peak_tera_flops_per_second(src.peak_tera_flops_per_second());
-  }
-
-  if (src.peak_bws_giga_bytes_per_second_size() > 0 &&
-      dst->peak_bws_giga_bytes_per_second_size() == 0) {
-    *dst->mutable_peak_bws_giga_bytes_per_second() =
-        src.peak_bws_giga_bytes_per_second();
-  }
-  if (src.ridge_point() > 0) {
-    dst->set_ridge_point(src.ridge_point());
-  }
-}
-
-// Combines the src Diagnostics into the dst Diagnostics.
-void CombineDiagnostics(const Diagnostics& src, Diagnostics* dst) {
-  dst->mutable_info()->MergeFrom(src.info());
-  dst->mutable_warnings()->MergeFrom(src.warnings());
-  dst->mutable_errors()->MergeFrom(src.errors());
-}
-
-// Combines the src OpStats into the dst OpStats.
-void CombineOpStats(
-    bool no_accelerator_in_system, int src_host_id, HardwareType hardware_type,
-    const StepIntersection& step_intersection, const OpStats& src, OpStats* dst,
-    OpMetricsDbCombiner* host_op_metrics_db_combiner,
-    OpMetricsDbCombiner* device_op_metrics_db_combiner,
-    OpMetricsDbCombiner* hlo_metrics_db_complete_steps_only_combiner,
-    std::vector<OpMetricsDbCombiner>* hlo_metrics_db_per_step_combiners) {
-  // Combine host_metrics_db.
-  // The host OpMetricsDb does not need to update the number of cores on which
-  // a certain op occurs.
-  host_op_metrics_db_combiner->Combine(src.host_op_metrics_db(),
-                                       /*update_num_cores=*/false);
-  // Combine device_metrics_db.
-  device_op_metrics_db_combiner->Combine(src.device_op_metrics_db());
-
-  // Combine step_db.
-  if (!IsCoordinator(no_accelerator_in_system, hardware_type)) {
-    CombineStepDatabase(src_host_id, step_intersection, src.step_db(),
-                        dst->mutable_step_db(),
-                        hlo_metrics_db_complete_steps_only_combiner,
-                        hlo_metrics_db_per_step_combiners);
-  }
-
-  // Combine run environment info.
-  CombineRunEnvironment(src.run_environment(), dst->mutable_run_environment());
-
-  // Combine the perf environment info.
-  CombinePerfEnv(src.perf_env(), dst->mutable_perf_env());
-
-  // Combine diagnostics.
-  CombineDiagnostics(src.diagnostics(), dst->mutable_diagnostics());
-
-  // Combine kernel stats.
-  dst->mutable_kernel_stats_db()->mutable_reports()->MergeFrom(
-      src.kernel_stats_db().reports());
-
-  // Combine tf-function stats.
-  CombineTfFunctionDb(src.tf_function_db(), dst->mutable_tf_function_db());
-
-  // Combine the mapping from core ID to details.
-  CombineCoreIdMap(src_host_id, src.core_id_to_details(),
-                   dst->mutable_core_id_to_details());
-
-  // Combine performance counter result.
-  dst->mutable_performance_counter_result()
-      ->set_matrix_unit_utilization_percent(
-          dst->performance_counter_result().matrix_unit_utilization_percent() +
-          src.performance_counter_result().matrix_unit_utilization_percent());
-}
-
-}  // namespace
-
-bool IsCoordinator(bool no_accelerator_in_system, HardwareType hardware_type) {
-  // A host is a coordinator if:
-  // (1) The host doesn't have a device, and
-  // (2) The system does use accelerators (if not, it uses CPU only and so this
-  //     host should be regarded as a worker as well).
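-  // For example, a CPU-only host in a TPU job is a coordinator, whereas in a
-  // CPU-only job every host counts as a worker.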
-  return !HasDevice(hardware_type) && !no_accelerator_in_system;
-}
-
-bool NoAcceleratorInSystem(const std::vector<OpStatsInfo>& all_op_stats_info) {
-  for (const auto& op_stats_info : all_op_stats_info) {
-    if (HasDevice(op_stats_info.hardware_type)) {
-      return false;
-    }
-  }
-  return true;
-}
-
-uint32 GlobalCoreId(int host_id, uint32 device_ordinal) {
-  constexpr uint32 kMaxDevicesPerHost = 1000;  // power-of-10 for debuggability
-  return host_id * kMaxDevicesPerHost + device_ordinal;
-}
-
-StepIntersection ComputeStepIntersectionToMergeOpStats(
-    const std::vector<OpStatsInfo>& all_op_stats_info,
-    uint32 max_step_per_host) {
-  bool no_accelerator_in_system = NoAcceleratorInSystem(all_op_stats_info);
-
-  absl::flat_hash_map<uint32, const StepDatabaseResult*> per_host_step_db;
-  for (const auto& op_stats_info : all_op_stats_info) {
-    if (IsCoordinator(no_accelerator_in_system, op_stats_info.hardware_type))
-      continue;
-    // Includes only workers in per_host_step_db.
-    per_host_step_db[op_stats_info.src_host_id] =
-        &op_stats_info.op_stats->step_db();
-  }
-
-  return StepIntersection(max_step_per_host, per_host_step_db);
-}
-
-void CombineAllOpStats(const std::vector<OpStatsInfo>& all_op_stats_info,
-                       const StepIntersection& step_intersection,
-                       OpStats* combined_op_stats) {
-  // A shortcut code path for a single OpStats. There is no need to merge.
-  if (all_op_stats_info.size() == 1) {
-    *combined_op_stats = *all_op_stats_info[0].op_stats;
-    return;
-  }
-
-  StepDatabaseResult* combined_step_db = combined_op_stats->mutable_step_db();
-  // Initialize the StepDatabaseResult field that depends on the number of
-  // steps.
-  for (uint32 dst_step_num : step_intersection.DstStepNumbers()) {
-    combined_step_db->add_step_sequence()->set_step_num(dst_step_num);
-  }
-  // Record the number of steps that are dropped.
-  combined_step_db->set_num_steps_dropped(step_intersection.StepsDropped());
-
-  combined_step_db->set_empty_intersect(step_intersection.EmptyIntersect());
-
-  // Initialize all the OpMetricsDbCombiners.
-  OpMetricsDbCombiner host_op_metrics_db_combiner(
-      combined_op_stats->mutable_host_op_metrics_db());
-  OpMetricsDbCombiner device_op_metrics_db_combiner(
-      combined_op_stats->mutable_device_op_metrics_db());
-  OpMetricsDbCombiner hlo_metrics_db_complete_steps_only_combiner(
-      combined_op_stats->mutable_hlo_metrics_db_complete_steps_only());
-  std::vector<OpMetricsDbCombiner> hlo_metrics_db_per_step_combiners;
-  hlo_metrics_db_per_step_combiners.reserve(
-      combined_step_db->step_sequence_size());
-  for (PerCoreStepInfo& step_info :
-       *combined_step_db->mutable_step_sequence()) {
-    hlo_metrics_db_per_step_combiners.emplace_back(
-        step_info.mutable_hlo_metrics_db());
-  }
-
-  bool no_accelerator_in_system = NoAcceleratorInSystem(all_op_stats_info);
-
-  for (const auto& op_stats_info : all_op_stats_info) {
-    CombineOpStats(no_accelerator_in_system, op_stats_info.src_host_id,
-                   op_stats_info.hardware_type, step_intersection,
-                   *op_stats_info.op_stats, combined_op_stats,
-                   &host_op_metrics_db_combiner, &device_op_metrics_db_combiner,
-                   &hlo_metrics_db_complete_steps_only_combiner,
-                   &hlo_metrics_db_per_step_combiners);
-  }
-
-  // Sorts all the kernel reports that have been merged by CombineOpStats and
-  // keeps only the top kernel reports with long kernel duration.
-  SortAndKeepTopKDurationKernelReportsInDb(
-      combined_op_stats->mutable_kernel_stats_db());
-
-  // Process performance counter results.
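-  // CombineOpStats() accumulated the per-host matrix-unit utilization
-  // percentages by summation, so dividing by the number of hosts yields the
-  // mean across hosts.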
-  combined_op_stats->mutable_performance_counter_result()
-      ->set_matrix_unit_utilization_percent(
-          combined_op_stats->performance_counter_result()
-              .matrix_unit_utilization_percent() /
-          all_op_stats_info.size());
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.h b/tensorflow/core/profiler/convert/op_stats_combiner.h
deleted file mode 100644
index e2a8bf25db0556..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_combiner.h
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_COMBINER_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_COMBINER_H_
-
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-#include "xprof/utils/step_intersection.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-// Whether a host is a coordinator.
-bool IsCoordinator(bool no_accelerator_in_system, HardwareType hardware_type);
-
-// Translates a single-host core id into a global, multi-host core id.
-// We need this translation because the device_ordinal was assigned when each
-// host responded on its own. Now we need a global core_id to distinguish
-// cores across multiple hosts.
-uint32 GlobalCoreId(int host_id, uint32 device_ordinal);
-
-// Combines the src map into the dst map.
-// The src map keys are local core_ids. The src_host_id is used to convert them
-// into global core_ids used as keys in the dst map.
-// REQUIRED: cores from src_host_id are not already in dst.
-template <typename CoreIdMap>
-void CombineCoreIdMap(int src_host_id, const CoreIdMap& src, CoreIdMap* dst) {
-  for (const auto& core_id_and_value : src) {
-    uint32 global_core_id = GlobalCoreId(src_host_id, core_id_and_value.first);
-    auto iter_and_inserted =
-        dst->insert({global_core_id, core_id_and_value.second});
-    DCHECK(iter_and_inserted.second)
-        << "Duplicated core_id: " << iter_and_inserted.first->first;
-  }
-}
-
-// A struct that contains all the information that is needed to combine OpStats.
-struct OpStatsInfo {
-  OpStatsInfo(const OpStats* op_stats, HardwareType hardware_type,
-              int src_host_id)
-      : op_stats(op_stats),
-        hardware_type(hardware_type),
-        src_host_id(src_host_id) {}
-  const OpStats* op_stats;
-  HardwareType hardware_type;
-  int src_host_id;
-};
-
-// Returns true if there is no device (accelerator) in any of the hosts.
-bool NoAcceleratorInSystem(const std::vector<OpStatsInfo>& all_op_stats_info);
-
-// Computes the StepIntersection to merge OpStats.
-// The profiler limits the number of steps to at most <max_step_per_host>.
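-// For example, with two workers and <max_step_per_host> = 500, the result
-// covers only step numbers observed on both hosts, capped at 500 steps.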
-StepIntersection ComputeStepIntersectionToMergeOpStats(
-    const std::vector<OpStatsInfo>& all_op_stats_info,
-    uint32 max_step_per_host);
-
-// Combines all the OpStats in <all_op_stats_info> using the steps in range
-// <step_intersection>. The result is stored in <combined_op_stats>.
-void CombineAllOpStats(const std::vector<OpStatsInfo>& all_op_stats_info,
-                       const StepIntersection& step_intersection,
-                       OpStats* combined_op_stats);
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_COMBINER_H_
diff --git a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc b/tensorflow/core/profiler/convert/op_stats_combiner_test.cc
deleted file mode 100644
index b4da91a61e1611..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/op_stats_combiner.h"
-
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "xprof/utils/step_intersection.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-// Tests that the run_environment field of the combined op stats is set
-// correctly.
-TEST(CombineAllOpStatsTest, CombineRunEnvironment) {
-  // Construct OpStatsInfo and all_op_stats_info.
-  OpStats dst_op_stats, op_stats_1, op_stats_2;
-  op_stats_1.mutable_run_environment()
-      ->mutable_host_independent_job_info()
-      ->set_profile_duration_ms(100);
-  op_stats_2.mutable_run_environment()
-      ->mutable_host_independent_job_info()
-      ->set_profile_duration_ms(0);
-  OpStatsInfo op_stats_info_1(&op_stats_1, TPU, 0),
-      op_stats_info_2(&op_stats_2, TPU, 0);
-  std::vector<OpStatsInfo> all_op_stats_info = {op_stats_info_1,
-                                                op_stats_info_2};
-
-  // Construct a dummy step_intersection.
-  StepDatabaseResult dummy_step_db_result;
-  absl::flat_hash_map<uint32, const StepDatabaseResult*> result;
-  result.insert({0, &dummy_step_db_result});
-  StepIntersection dummy_step_intersection = StepIntersection(1, result);
-
-  // Combine all op stats.
-  CombineAllOpStats(all_op_stats_info, dummy_step_intersection, &dst_op_stats);
-
-  // Verify that the combined op stats pick up the non-zero
-  // profile_duration_ms.
-  EXPECT_EQ(100, dst_op_stats.run_environment()
-                     .host_independent_job_info()
-                     .profile_duration_ms());
-}
-
-TEST(CombineAllOpStatsTest, CombineRunEnvironmentWithUnknownDevice) {
-  OpStats dst_op_stats, op_stats_1, op_stats_2;
-  op_stats_1.mutable_run_environment()->set_device_type("TPU");
-  op_stats_2.mutable_run_environment()->set_device_type("Device");
-  OpStatsInfo op_stats_info_1(&op_stats_1, TPU, 0),
-      op_stats_info_2(&op_stats_2, TPU, 0);
-  std::vector<OpStatsInfo> all_op_stats_info = {op_stats_info_1,
-                                                op_stats_info_2};
-
-  // Construct a dummy step_intersection.
-  StepDatabaseResult dummy_step_db_result;
-  absl::flat_hash_map<uint32, const StepDatabaseResult*> result;
-  result.insert({0, &dummy_step_db_result});
-  StepIntersection dummy_step_intersection = StepIntersection(1, result);
-
-  CombineAllOpStats(all_op_stats_info, dummy_step_intersection, &dst_op_stats);
-
-  EXPECT_EQ("TPU", dst_op_stats.run_environment().device_type());
-}
-
-TEST(CombineAllOpStatsTest, CombinePerfEnvOrderZero) {
-  // Ensure CombinePerfEnv behaves consistently regardless of the order of the
-  // op stats.
-  OpStats dst_op_stats1, dst_op_stats2, op_stats_1, op_stats_2;
-  op_stats_1.mutable_perf_env()->set_peak_tera_flops_per_second(100);
-  op_stats_2.mutable_perf_env()->set_peak_tera_flops_per_second(0);
-  // Construct a dummy step_intersection, which is required by
-  // CombineAllOpStats().
-  absl::flat_hash_map<uint32, const StepDatabaseResult*> result;
-  StepIntersection dummy_step_intersection = StepIntersection(1, result);
-
-  OpStatsInfo op_stats_info_1(&op_stats_1, TPU, 0),
-      op_stats_info_2(&op_stats_2, TPU, 0);
-
-  // Test order 1.
-  std::vector<OpStatsInfo> all_op_stats_info = {op_stats_info_1,
-                                                op_stats_info_2};
-  CombineAllOpStats(all_op_stats_info, dummy_step_intersection, &dst_op_stats1);
-  EXPECT_EQ(100, dst_op_stats1.perf_env().peak_tera_flops_per_second());
-
-  // Test order 2.
-  all_op_stats_info = {
-      op_stats_info_2,
-      op_stats_info_1,
-  };
-  CombineAllOpStats(all_op_stats_info, dummy_step_intersection, &dst_op_stats2);
-  EXPECT_EQ(100, dst_op_stats2.perf_env().peak_tera_flops_per_second());
-}
-
-TEST(CombineAllOpStatsTest, CombineRunEnvironmentWithMismatchHardwareType) {
-  OpStats coordinator_op_stats, device_op_stats, dst_op_stats;
-  coordinator_op_stats.mutable_run_environment()->set_hardware_type(
-      HardwareType::CPU_ONLY);
-  device_op_stats.mutable_run_environment()->set_hardware_type(
-      HardwareType::TPU);
-  CombineAllOpStats({OpStatsInfo(&coordinator_op_stats, CPU_ONLY, 0),
-                     OpStatsInfo(&device_op_stats, TPU, 1)},
-                    StepIntersection(1, {}), &dst_op_stats);
-  EXPECT_EQ(dst_op_stats.run_environment().hardware_type(), HardwareType::TPU);
-}
-
-}  // namespace
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc
deleted file mode 100644
index fed7128c887ac0..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc
+++ /dev/null
@@ -1,178 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h"
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_split.h"
-#include "xla/tsl/profiler/convert/xla_op_utils.h"
-#include "xla/tsl/profiler/utils/math_utils.h"
-#include "xla/tsl/profiler/utils/tf_op_utils.h"
-#include "tensorflow/core/profiler/convert/data_table_utils.h"
-#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
-#include "tensorflow/core/profiler/protobuf/hlo_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-using tensorflow::profiler::OpMetrics;
-using tensorflow::profiler::OpMetricsDb;
-using ::tensorflow::profiler::OpStats;
-using ::tensorflow::profiler::PerfEnv;
-using ::tensorflow::profiler::RunEnvironment;
-using tensorflow::profiler::hlo_stats::HloStatsDatabase;
-using tensorflow::profiler::hlo_stats::HloStatsRecord;
-using tsl::profiler::IsOutsideCompilationOp;
-
-HloStatsRecord ConvertOpMetricsToHloStatsRecord(const OpMetrics& metrics,
-                                                const PerfEnv& perf_env,
-                                                const RunEnvironment& run_env) {
-  HloStatsRecord record;
-  record.set_program_id(metrics.hlo_module_id());
-  record.set_hlo_expression(metrics.long_name());
-  record.set_tf_op_name(metrics.provenance());
-  record.set_hlo_category(metrics.category());
-  record.set_autotuned(metrics.autotuned());
-  tensorflow::profiler::SetExecutionTimes(metrics, &record);
-  tensorflow::profiler::SetTpuUnitFractions(metrics, &record);
-  SetRooflineMetrics(metrics, perf_env, run_env, &record);
-  record.set_rematerialization(tsl::profiler::IsRematerialization(
-      /*hlo_expression=*/metrics.long_name(),
-      /*framework_op_name=*/metrics.provenance()));
-  record.set_outside_compilation(
-      IsOutsideCompilationOp(metrics.provenance(), metrics.long_name()));
-  return record;
-}
-
-}  // namespace
-
-HloStatsDatabase ConvertOpStatsToHloStats(const OpStats& op_stats) {
-  HloStatsDatabase hlo_stats_db;
-  const OpMetricsDb& hlo_metrics_db = op_stats.device_op_metrics_db();
-  double total_device_time_us =
-      tsl::profiler::PicoToMicro(hlo_metrics_db.total_time_ps());
-  HloStatsRecord sentinel;
-  sentinel.set_rank(0);
-  sentinel.set_cumulative_total_self_time_as_fraction(0.0);
-  const HloStatsRecord* prev_record = &sentinel;
-  for (const OpMetrics* metrics :
-       tensorflow::profiler::SortedOpMetricsDb(hlo_metrics_db)) {
-    if (metrics->occurrences() == 0) continue;
-    HloStatsRecord* record = hlo_stats_db.add_hlo_stats_record();
-    *record = ConvertOpMetricsToHloStatsRecord(*metrics, op_stats.perf_env(),
-                                               op_stats.run_environment());
-    tensorflow::profiler::SetRankAndTimeFractions(total_device_time_us,
-                                                  *prev_record, record);
-    prev_record = record;
-  }
-  return hlo_stats_db;
-}
-
-// The parsing logic assumes that the HLO op text is in the format
-// '%op_name = <expression>'.
-std::string GetHloOpNameFromExpression(std::string expression) {
-  std::vector<::std::string> parts = absl::StrSplit(expression, " = ");
-  std::string hlo_op_name = parts[0];
-  if (hlo_op_name[0] == '%') {
-    hlo_op_name = hlo_op_name.substr(1);
-  }
-  return hlo_op_name;
-}
-
-std::vector<std::vector<std::string>> HloStatsDataTableColumns() {
-  const std::vector<std::vector<std::string>> kColumns = {
-      {"rank", "number", "Rank"},
-      {"program_id", "string", "Program id"},
-      {"category", "string", "HLO op category"},
-      {"hlo_op_name", "string", "HLO op name"},
{"hlo_op_expression", "string", "HLO op text"}, - {"tf_op_name", "string", "Framework op name"}, - {"occurrences", "number", "#Occurrences"}, - {"total_time", "number", "Total time (us)"}, - {"avg_time", "number", "Avg. time (us)"}, - {"total_self_time", "number", "Total self time (us)"}, - {"avg_self_time", "number", "Avg. self time (us)"}, - {"total_self_time_percent", "number", "Total self time (%)"}, - { - "cumulative_total_self_time_percent", - "number", - "Cumulative total self time (%)", - }, - {"dma_stall_percent", "number", "%time stalled by DMA"}, - {"model_flop_rate", "number", "Model GFLOP/s"}, - {"normalized_flop_rate", "number", "Normalized GFLOP/s"}, - {"measured_memory_bw", "number", "Measured memory BW (GiB/s)"}, - {"hbm_bw", "number", "HBM BW (GiB/s)"}, - {"cmem_read_bw", "number", "CMEM Read BW (GiB/s)"}, - {"cmem_write_bw", "number", "CMEM Write BW (GiB/s)"}, - {"operational_intensity", "number", "Operational intensity (FLOPS/Byte)"}, - {"bound_by", "string", "Bound by"}, - {"hlo_rematerialization", "string", "Rematerialization"}, - {"outside_compilation", "string", "Outside Compilation"}, - {"autotuned", "string", "Autotuned"}, - }; - return kColumns; -} - -std::unique_ptr CreateHloStatsDataTable( - const HloStatsDatabase& hlo_stats_db) { - auto data_table = std::make_unique(); - for (const std::vector& col : HloStatsDataTableColumns()) { - data_table->AddColumn(TableColumn(col[0], col[1], col[2])); - } - for (const HloStatsRecord& record : hlo_stats_db.hlo_stats_record()) { - TableRow* row = data_table->AddRow(); - row->AddCell(record.rank()); - row->AddCell(absl::StrCat(record.program_id())); - row->AddCell(record.hlo_category()); - row->AddCell(GetHloOpNameFromExpression(record.hlo_expression())); - row->AddCell(record.hlo_expression()); - row->AddCell(record.tf_op_name()); - row->AddCell(record.occurrences()); - row->AddCell(record.total_time_in_us()); - row->AddCell(record.avg_time_in_us()); - row->AddCell(record.total_self_time_in_us()); - row->AddCell(record.avg_self_time_in_us()); - row->AddCell(record.total_self_time_as_fraction() * 100); - row->AddCell(record.cumulative_total_self_time_as_fraction() * 100); - row->AddCell(record.dma_stall_fraction()); - row->AddCell(record.model_flop_rate()); - row->AddCell(record.measured_flop_rate()); - row->AddCell(record.measured_memory_bw()); - row->AddCell(record.hbm_bw()); - row->AddCell(record.cmem_read_bw()); - row->AddCell(record.cmem_write_bw()); - row->AddCell(record.operational_intensity()); - row->AddCell(absl::StrCat(record.bound_by())); - row->AddCell(record.rematerialization() ? "Yes" : "No"); - row->AddCell(record.outside_compilation() ? "Yes" : "No"); - row->AddCell(record.autotuned() ? "Yes" : "No"); - } - return data_table; -} - -std::string HloStatsToDataTableJson(const HloStatsDatabase& hlo_stats_db) { - return CreateHloStatsDataTable(hlo_stats_db)->ToJson(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h b/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h deleted file mode 100644 index 359024df04b221..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_HLO_STATS_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_HLO_STATS_H_
-
-#include <memory>
-#include <string>
-
-#include "tensorflow/core/profiler/convert/data_table_utils.h"
-#include "tensorflow/core/profiler/protobuf/hlo_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-tensorflow::profiler::hlo_stats::HloStatsDatabase ConvertOpStatsToHloStats(
-    const tensorflow::profiler::OpStats& op_stats);
-
-// Converts to JSON aligned with the current DataTable JSON format.
-std::string HloStatsToDataTableJson(
-    const hlo_stats::HloStatsDatabase& hlo_stats_db);
-
-// Constructs a DataTable object from an HloStatsDatabase.
-std::unique_ptr<DataTable> CreateHloStatsDataTable(
-    const hlo_stats::HloStatsDatabase& hlo_stats_db);
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_HLO_STATS_H_
diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
deleted file mode 100644
index 956e7e46c8b34e..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
+++ /dev/null
@@ -1,1648 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
-
-#include <math.h>
-
-#include <algorithm>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "google/protobuf/any.pb.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/strings/match.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/tsl/profiler/convert/xla_op_utils.h"
-#include "xla/tsl/profiler/utils/format_utils.h"
-#include "xla/tsl/profiler/utils/math_utils.h"
-#include "xla/tsl/profiler/utils/tf_op_utils.h"
-#include "xla/tsl/util/stats_calculator.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
-#include "tensorflow/core/profiler/convert/profile_time_breakdown.h"
-#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h"
-#include "tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h"
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
-#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "tsl/platform/protobuf.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/hardware_types.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/input_pipeline.pb.h"  // from @org_xprof
-#include "xprof/utils/diagnostics.h"  // from @org_xprof
-#include "xprof/utils/event_span.h"  // from @org_xprof
-#include "xprof/utils/html_utils.h"  // from @org_xprof
-#include "xprof/utils/op_metrics_db_utils.h"  // from @org_xprof
-#include "xprof/utils/tpu_step_breakdown_utils.h"  // from @org_xprof
-#include "xprof/utils/tpu_step_details_utils.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-namespace {
-
-using tsl::profiler::OneDigit;
-
-// If the percentage of step time that is spent on SparseCoreV0 is more than
-// kModeratelySparseCoreV0BoundThresholdInPercent, it is considered highly
-// SparseCoreV0 bound.
-constexpr double kModeratelySparseCoreV0BoundThresholdInPercent = 10;
-// If the percentage of step time that is spent on all-reduce is more than
-// kAllReduceBoundThresholdInPercent, it is considered all-reduce bound.
-constexpr double kAllReduceBoundThresholdInPercent = 6;
-// If the percentage of step time that is idle due to host overhead (but not
-// input-related) is >= kTcIdleThresholdInPercent, it will be highlighted in the
-// recommendation section of the Overview Page.
-constexpr double kTcIdleThresholdInPercent = 3;
-// Public doc on how to run multiple steps in a tf-function.
-constexpr absl::string_view kMultipleStepsInTffunctionDoc =
-    "https://www.tensorflow.org/guide/"
-    "tpu#improving_performance_by_multiple_steps_within_tffunction";
-
-const double kNumPsPerMs = 1000000000.0;
-
-// If the percentage of step time that is due to infeed is less than
-// kModeratelyInfeedBoundThresholdInPercent, it is considered NOT
-// input-bound; else if it is less than
-// kHighlyInfeedBoundThresholdInPercent, it is considered MODERATELY
-// input-bound; else it is considered HIGHLY input-bound.
-constexpr double kModeratelyInfeedBoundThresholdInPercent = 5;
-constexpr double kHighlyInfeedBoundThresholdInPercent = 20;
-
-// If the percentage of step time that is due to outfeed is less than
-// kModeratelyOutfeedBoundThresholdInPercent, it is considered NOT
-// output-bound; else if it is less than
-// kHighlyOutfeedBoundThresholdInPercent, it is considered MODERATELY
-// output-bound; else it is considered HIGHLY output-bound.
-constexpr double kModeratelyOutfeedBoundThresholdInPercent = 5;
-constexpr double kHighlyOutfeedBoundThresholdInPercent = 20;
-
-// If the percentage of step time that is due to kernel launch is less than
-// kModeratelyKernelLaunchBoundThresholdInPercent, it is considered NOT
-// kernel-launch bound; else if it is less than
-// kHighlyKernelLaunchBoundThresholdInPercent, it is considered MODERATELY
-// kernel-launch bound; else it is considered HIGHLY kernel-launch bound.
-constexpr double kModeratelyKernelLaunchBoundThresholdInPercent = 3;
-constexpr double kHighlyKernelLaunchBoundThresholdInPercent = 15;
-
-// If the percentage of step time that is due to all other time is less than
-// kModeratelyAllOtherBoundThresholdInPercent, it is considered NOT
-// all-other bound; else if it is less than
-// kHighlyAllOtherBoundThresholdInPercent, it is considered MODERATELY
-// all-other bound; else it is considered HIGHLY all-other bound.
-constexpr double kModeratelyAllOtherBoundThresholdInPercent = 3;
-constexpr double kHighlyAllOtherBoundThresholdInPercent = 15;
-
-// If the percentage of step time that is due to device collectives is less
-// than kModeratelyDeviceCollectivesBoundThresholdInPercent, it is considered
-// NOT device-collectives bound; else if it is less than
-// kHighlyDeviceCollectivesBoundThresholdInPercent, it is considered MODERATELY
-// device-collectives bound; else it is considered HIGHLY device-collectives
-// bound.
-constexpr double kModeratelyDeviceCollectivesBoundThresholdInPercent = 3;
-constexpr double kHighlyDeviceCollectivesBoundThresholdInPercent = 15;
-
-// Section number of the host-analysis section in the input-pipeline analysis.
-constexpr int kHostAnalysisSectionNumber = 3;
-// Python-only explanation for "All Others" time.
-const char* kAllOthersPythonExplanation =
-    " % of the total step time sampled is spent on 'All Others' time. "
-    "This could be due to Python execution overhead.";
-// Explanation for "Kernel Launch" time due to CPU contention with tf.data.
-const char* kKernelLaunchTfDataContention =
-    " It could be due to CPU contention with tf.data. In this case, you may "
-    "try to set the environment variable TF_GPU_THREAD_MODE=gpu_private.";
-
-template <typename Collection>
-double GetTimeInMs(const Collection& type_ps, EventType event_type) {
-  return tsl::profiler::PicoToMilli(
-      gtl::FindWithDefault(type_ps, event_type, /*value=*/0));
-}
-
-GenericStepTimeBreakdown ComputeGenericStepTimeBreakdownInMs(
-    const InputPipelineAnalysisResult& analysis) {
-  tsl::Stat<double> unknown_time_ms;
-  tsl::Stat<double> host_wait_input_ms;
-  tsl::Stat<double> host_to_device_ms;
-  tsl::Stat<double> input_ms;
-  tsl::Stat<double> output_ms;
-  tsl::Stat<double> device_compute_ms;
-  tsl::Stat<double> device_to_device_ms;
-  tsl::Stat<double> device_collectives_ms;
-  tsl::Stat<double> host_compute_ms;
-  tsl::Stat<double> host_prepare_ms;
-  tsl::Stat<double> host_compile_ms;
-  GenericStepTimeBreakdown result;
-
-  for (const google::protobuf::Any& step_details : analysis.step_details()) {
-    PerGenericStepDetails details;
-    bool success = step_details.UnpackTo(&details);
-    if (!success && !step_details.type_url().empty()) {
-      LOG(ERROR) << "Unable to unpack step_breakdown. Expected: generic"
-                 << std::endl;
-      return {};
-    }
-    unknown_time_ms.UpdateStat(details.unknown_time_ms());
-    host_wait_input_ms.UpdateStat(details.host_wait_input_ms());
-    host_to_device_ms.UpdateStat(details.host_to_device_ms());
-    input_ms.UpdateStat(details.host_wait_input_ms() +
-                        details.host_to_device_ms());
-    output_ms.UpdateStat(details.output_ms());
-    device_compute_ms.UpdateStat(details.device_compute_ms());
-    device_to_device_ms.UpdateStat(details.device_to_device_ms());
-    device_collectives_ms.UpdateStat(details.device_collectives_ms());
-    host_compute_ms.UpdateStat(details.host_compute_ms());
-    host_prepare_ms.UpdateStat(details.host_prepare_ms());
-    host_compile_ms.UpdateStat(details.host_compile_ms());
-  }
-  *result.mutable_unknown_time_ms_summary() =
-      GetStepSummaryForSampleStats(unknown_time_ms);
-  *result.mutable_host_wait_input_ms_summary() =
-      GetStepSummaryForSampleStats(host_wait_input_ms);
-  *result.mutable_host_to_device_ms_summary() =
-      GetStepSummaryForSampleStats(host_to_device_ms);
-  *result.mutable_input_ms_summary() = GetStepSummaryForSampleStats(input_ms);
-  *result.mutable_output_ms_summary() = GetStepSummaryForSampleStats(output_ms);
-  *result.mutable_device_compute_ms_summary() =
-      GetStepSummaryForSampleStats(device_compute_ms);
-  *result.mutable_device_to_device_ms_summary() =
-      GetStepSummaryForSampleStats(device_to_device_ms);
-  *result.mutable_device_collectives_ms_summary() =
-      GetStepSummaryForSampleStats(device_collectives_ms);
-  *result.mutable_host_compute_ms_summary() =
-      GetStepSummaryForSampleStats(host_compute_ms);
-  *result.mutable_host_prepare_ms_summary() =
-      GetStepSummaryForSampleStats(host_prepare_ms);
-  *result.mutable_host_compile_ms_summary() =
-      GetStepSummaryForSampleStats(host_compile_ms);
-  return result;
-}
-
-InputPipelineAnalysisResult ComputeGenericInputPipelineAnalysisResult(
-    const tsl::protobuf::RepeatedPtrField<PerCoreStepInfo>& grouped_by_step) {
-  InputPipelineAnalysisResult result;
-  result.set_tag(false);
-
-  // Computes the summary of step time in ms.
-  *result.mutable_step_time_summary() =
-      ComputeStepTimeSummaryInMs(grouped_by_step);
-
-  tsl::Stat<double> input_summary_stats_in_percent;
-  for (const auto& coreid_stepinfo_map : grouped_by_step) {
-    // Iterates over each step.
-    const auto* ptr = gtl::FindOrNull(coreid_stepinfo_map.step_info_per_core(),
-                                      kDefaultGpuLocalCoreId);
-    if (ptr == nullptr) {
-      // For generic hardware, all step-info is put under core-0. If ptr
-      // is nullptr, it means there is no step at all.
-      continue;
-    }
-    const StepInfoResult& step_info = *ptr;
-    // Adds the details for a new step.
-    PerGenericStepDetails details;
-    details.set_step_number(step_info.step_num());
-    if (step_info.step_name().empty()) {
-      details.set_step_name(absl::StrCat(step_info.step_num()));
-    } else {
-      details.set_step_name(step_info.step_name());
-    }
-    details.set_step_time_ms(
-        tsl::profiler::PicoToMilli(step_info.duration_ps()));
-    GenericStepBreakdown generic;
-    bool success = step_info.step_breakdown().UnpackTo(&generic);
-    if (!success && !step_info.step_breakdown().type_url().empty()) {
-      LOG(ERROR) << "Unable to unpack step_breakdown. Expected: generic"
-                 << std::endl;
-      return {};
-    }
-    const auto& type_ps = generic.type_ps();
-    details.set_unknown_time_ms(GetTimeInMs(type_ps, UNKNOWN_TIME));
-    details.set_host_wait_input_ms(GetTimeInMs(type_ps, HOST_WAIT_INPUT));
-    details.set_host_to_device_ms(GetTimeInMs(type_ps, HOST_TO_DEVICE) +
-                                  GetTimeInMs(type_ps, DEVICE_WAIT_HOST));
-    details.set_output_ms(GetTimeInMs(type_ps, DEVICE_TO_HOST));
-    details.set_device_compute_ms(GetTimeInMs(type_ps, DEVICE_COMPUTE_16) +
-                                  GetTimeInMs(type_ps, DEVICE_COMPUTE_32));
-    details.set_device_to_device_ms(GetTimeInMs(type_ps, DEVICE_TO_DEVICE) +
-                                    GetTimeInMs(type_ps, DEVICE_WAIT_DEVICE));
-    details.set_device_collectives_ms(GetTimeInMs(type_ps, DEVICE_COLLECTIVES));
-    details.set_host_compute_ms(GetTimeInMs(type_ps, HOST_COMPUTE));
-    details.set_host_prepare_ms(GetTimeInMs(type_ps, HOST_PREPARE));
-    details.set_host_compile_ms(GetTimeInMs(type_ps, HOST_COMPILE));
-    result.add_step_details()->PackFrom(details);
-
-    const double input_percent_of_step_time =
-        100.0 * tsl::profiler::SafeDivide(
-                    details.host_wait_input_ms() + details.host_to_device_ms(),
-                    details.step_time_ms());
-    input_summary_stats_in_percent.UpdateStat(input_percent_of_step_time);
-  }
-
-  // Computes the summary of input time as percentage of step time.
-  *result.mutable_input_percent_summary() =
-      GetStepSummaryForSampleStats(input_summary_stats_in_percent);
-
-  // Computes the breakdown of step time.
-  GenericStepTimeBreakdown generic_step_time_breakdown =
-      ComputeGenericStepTimeBreakdownInMs(result);
-  result.mutable_step_time_breakdown()->PackFrom(generic_step_time_breakdown);
-
-  return result;
-}
-
-// Classification of input processing on the host.
-enum class InputOpCategory {
-  kEnqueue,           // enqueue data to be transferred to device.
-  kDemandedFileRead,  // demanded read from file.
-  kAdvancedFileRead,  // advanced read from file (including cached,
-                      // prefetch, parallel-map, interleave).
-  kPreprocessing      // data preprocessing.
-};
-
-std::string InputOpCategoryString(InputOpCategory category) {
-  switch (category) {
-    case InputOpCategory::kEnqueue:
-      return "Enqueue";
-    case InputOpCategory::kDemandedFileRead:
-      return "Demanded file read";
-    case InputOpCategory::kAdvancedFileRead:
-      return "Advanced file read";
-    case InputOpCategory::kPreprocessing:
-      return "Preprocessing";
-  }
-}
-
-inline bool IsInputOp(absl::string_view category) {
-  // Do not include "IteratorGetNext*" here, because IteratorGetNext is an Op
-  // that experiences the input stall, not an Op that causes the input stall.
-  return tsl::profiler::IsInfeedEnqueueOp(category) ||
-         tsl::profiler::IsDatasetOp(category) ||
-         tsl::profiler::IsMemcpyHToDOp(category);
-}
-
-// TODO(ckluk):
-//  Confirm with the tf.data team if the classification below is correct.
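-// Illustrative examples: an op named "Iterator::Prefetch::TFRecord" ends with
-// "::TFRecord" and contains "::Prefetch", so it is classified as an advanced
-// (background) file read, while a plain "Iterator::TFRecord" is a demanded
-// file read; infeed-enqueue and memcpy-H2D categories map to kEnqueue.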
-InputOpCategory CategorizeInputOp(absl::string_view name,
-                                  absl::string_view category) {
-  if (tsl::profiler::IsInfeedEnqueueOp(category) ||
-      tsl::profiler::IsMemcpyHToDOp(category)) {
-    // Ops for sending input from host to device.
-    return InputOpCategory::kEnqueue;
-  }
-  DCHECK(tsl::profiler::IsDatasetOp(category));
-  if (absl::EndsWith(name, "::TFRecord") ||
-      absl::EndsWith(name, "::TextLine") ||
-      absl::EndsWith(name, "::FixedLengthRecord") ||
-      absl::EndsWith(name, "::SSTable") || absl::EndsWith(name, "::RecordIO")) {
-    // Ops that read files.
-    if (absl::StrContains(name, "::MemoryReader") ||
-        absl::StrContains(name, "::MemoryWriter") ||
-        absl::StrContains(name, "::Interleave") ||
-        absl::StrContains(name, "::Prefetch") ||
-        absl::StrContains(name, "::ParallelMap")) {
-      // Ops that read files in advance, including caching, interleaving, and
-      // prefetching.
-      return InputOpCategory::kAdvancedFileRead;
-    } else {
-      // Ops that read files on demand.
-      return InputOpCategory::kDemandedFileRead;
-    }
-  } else {
-    // All other ops are classified as preprocessing.
-    return InputOpCategory::kPreprocessing;
-  }
-}
-
-struct InputOpMetrics {
-  std::vector<const OpMetrics*> input_op_metrics;
-  uint64 input_op_time_ps = 0;
-};
-
-InputOpMetrics SelectInputOpMetrics(const OpMetricsDb& all_op_metrics) {
-  InputOpMetrics input_op_metrics;
-  for (const OpMetrics* op_metrics : SortedOpMetricsDb(all_op_metrics)) {
-    if (IsInputOp(op_metrics->category())) {
-      input_op_metrics.input_op_metrics.push_back(op_metrics);
-      input_op_metrics.input_op_time_ps += op_metrics->self_time_ps();
-    }
-  }
-  return input_op_metrics;
-}
-
-InputOpDetails ConvertOpMetricsToInputOpDetails(const OpMetrics& op_metrics,
-                                                uint64 input_op_time_ps,
-                                                InputOpCategory category) {
-  InputOpDetails details;
-  details.set_op_name(op_metrics.name());
-  details.set_count(op_metrics.occurrences());
-  details.set_time_in_ms(tsl::profiler::PicoToMilli(op_metrics.time_ps()));
-  details.set_self_time_in_ms(
-      tsl::profiler::PicoToMilli(op_metrics.self_time_ps()));
-  details.set_time_in_percent(
-      100.0 *
-      tsl::profiler::SafeDivide(op_metrics.time_ps(), input_op_time_ps));
-  details.set_self_time_in_percent(
-      100.0 *
-      tsl::profiler::SafeDivide(op_metrics.self_time_ps(), input_op_time_ps));
-  details.set_category(InputOpCategoryString(category));
-  return details;
-}
-
-// Returns the ratio of the host-to-device time in each step to the step-time.
-double RatioOfHostToDeviceTimeToStepTime(
-    const OpMetricsDb& host_tf_metrics_db,
-    const InputPipelineAnalysisResult& input_pipeline_analysis) {
-  // For TPU execution that uses infeed.
-  std::optional<double> host_infeed_enqueue_ratio =
-      HostInfeedEnqueueRatio(host_tf_metrics_db);
-  if (host_infeed_enqueue_ratio.has_value()) {
-    return host_infeed_enqueue_ratio.value();
-  }
-  // For GPU and TPU execution that does not use infeed.
-  double avg_step_time_ms =
-      input_pipeline_analysis.step_time_summary().average();
-  if (avg_step_time_ms > 0) {
-    // Uses the on-device step time.
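-    // That is, the ratio below is avg(host_to_device_ms) / avg(step_time_ms),
-    // both taken from the generic step-time breakdown.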
-    GenericStepTimeBreakdown generic_breakdown;
-    if (input_pipeline_analysis.step_time_breakdown().UnpackTo(
-            &generic_breakdown)) {
-      double avg_host_to_device_time_ms =
-          generic_breakdown.host_to_device_ms_summary().average();
-      return tsl::profiler::SafeDivide(avg_host_to_device_time_ms,
-                                       avg_step_time_ms);
-    }
-  }
-  return 0.0;
-}
-
-void DeviceCollectivesAnalysis(double device_collectives_percent,
-                               std::string* device_collectives_classification,
-                               std::string* device_collectives_statement) {
-  if (device_collectives_percent >=
-      kHighlyDeviceCollectivesBoundThresholdInPercent) {
-    *device_collectives_classification = "high";
-    *device_collectives_statement =
-        absl::StrCat(OneDigit(device_collectives_percent),
-                     " % of the total step time sampled is spent on 'Device "
-                     "Collective Communication'.");
-  } else if (device_collectives_percent >=
-             kModeratelyDeviceCollectivesBoundThresholdInPercent) {
-    *device_collectives_classification = "moderate";
-    *device_collectives_statement =
-        absl::StrCat(OneDigit(device_collectives_percent),
-                     " % of the total step time sampled is spent on 'Device "
-                     "Collective Communication'.");
-  } else {
-    *device_collectives_classification = "no";
-    *device_collectives_statement = "";
-  }
-}
-
-void KernelLaunchAnalysis(bool tfdata_used, double kernel_launch_percent,
-                          std::string* kernel_launch_classification,
-                          std::string* kernel_launch_statement) {
-  if (kernel_launch_percent >= kHighlyKernelLaunchBoundThresholdInPercent) {
-    *kernel_launch_classification = "high";
-    *kernel_launch_statement = absl::StrCat(
-        OneDigit(kernel_launch_percent),
-        " % of the total step time sampled is spent on 'Kernel Launch'.");
-    if (tfdata_used) {
-      absl::StrAppend(kernel_launch_statement, kKernelLaunchTfDataContention);
-    }
-  } else if (kernel_launch_percent >=
-             kModeratelyKernelLaunchBoundThresholdInPercent) {
-    *kernel_launch_classification = "moderate";
-    *kernel_launch_statement = absl::StrCat(
-        OneDigit(kernel_launch_percent),
-        " % of the total step time sampled is spent on 'Kernel Launch'.");
-    if (tfdata_used) {
-      absl::StrAppend(kernel_launch_statement, kKernelLaunchTfDataContention);
-    }
-  } else {
-    *kernel_launch_classification = "no";
-    *kernel_launch_statement = "";
-  }
-}
-
-void AllOtherAnalysis(bool all_other_reported, double all_other_percent,
-                      std::string* all_other_classification,
-                      std::string* all_other_statement) {
-  if (all_other_reported) {
-    *all_other_classification = "no";
-    *all_other_statement = "";
-    return;
-  }
-  if (all_other_percent >= kHighlyAllOtherBoundThresholdInPercent) {
-    *all_other_classification = "high";
-    *all_other_statement =
-        absl::StrCat(OneDigit(all_other_percent), kAllOthersPythonExplanation);
-  } else if (all_other_percent >= kModeratelyAllOtherBoundThresholdInPercent) {
-    *all_other_classification = "moderate";
-    *all_other_statement =
-        absl::StrCat(OneDigit(all_other_percent), kAllOthersPythonExplanation);
-  } else {
-    *all_other_classification = "no";
-    *all_other_statement = "";
-  }
-}
-
-// Tests if the tf.data API is in use.
-bool TfDataInUse(const InputTimeBreakdown& breakdown) {
-  // Do not include enqueue_us because the "enqueue" Op that Xprof recognizes is
-  // not part of tf.data.
-  return breakdown.demanded_file_read_us() > 0 ||
-         breakdown.advanced_file_read_us() > 0 ||
-         breakdown.preprocessing_us() > 0;
-}
-
-// Returns an HTML link with the given text.
-std::string MakeDocLink(absl::string_view doc_link, absl::string_view text) {
-  return absl::StrCat("<a href=\"", doc_link, "\" target=\"_blank\">", text,
", text, - ""); -} - -// Returns the HTML link to the introduction to the tf.data API. -std::string DatasetIntroDoc() { - return "https://www.tensorflow.org/guide/data"; -} - -struct WaitForScV0Breakdown { - uint64_t DurationPs() const { - return scv0_infeed_duration_ps + scv0_compute_duration_ps; - } - - uint64_t scv0_infeed_duration_ps = 0; - uint64_t scv0_compute_duration_ps = 0; -}; - -struct TcInfeed { - std::optional core_id; - uint64_t duration_ps = 0; -}; - -void ConvertGenericStepBreakdownToTpuStepBreakdown( - const tensorflow::profiler::GenericStepBreakdown& generic_step_breakdown, - uint64_t step_time_ps, TpuStepBreakdown& tpu_step_breakdown) { - auto& category_ps = generic_step_breakdown.category_ps(); - tensorflow::profiler::ProfileTimeBreakdown time_breakdown; - for (const auto& [category, time_ps] : category_ps) { - // Don't add idle time to time_breakdown as the idle time is inferred. - if (category == "IDLE") continue; - time_breakdown.IncrementCategoryTimePs(category, time_ps); - } - time_breakdown.SetProfileTimePs(step_time_ps); - time_breakdown.BreakdownSparseCoreV0Infeed(); - - tpu_step_breakdown.set_infeed_duration_ps(time_breakdown.InfeedTimePs()); - tpu_step_breakdown.set_host_outfeed_ps(time_breakdown.OutfeedTimePs()); - tpu_step_breakdown.set_wait_for_scv0_duration_ps( - time_breakdown.SparseCoreV0InfeedWaitTimePs()); - tpu_step_breakdown.set_scv0_infeed_transform_ps( - time_breakdown.SparseCoreV0InfeedTransformTimePs()); - tpu_step_breakdown.set_scv0_outfeed_ps( - time_breakdown.SparseCoreV0OutfeedTimePs()); - tpu_step_breakdown.set_crs_duration_ps( - time_breakdown.AllReduceOrAllToAllTimePs()); - tpu_step_breakdown.set_send_duration_ps(time_breakdown.SendTimePs()); - tpu_step_breakdown.set_recv_duration_ps(time_breakdown.RecvTimePs()); - tpu_step_breakdown.set_host_send_duration_ps(time_breakdown.HostSendTimePs()); - tpu_step_breakdown.set_host_recv_duration_ps(time_breakdown.HostRecvTimePs()); - tpu_step_breakdown.set_wait_for_megacore_fusion_peer_duration_ps( - time_breakdown.MegacoreFusionTimePs()); - tpu_step_breakdown.set_high_flops_compute_ps( - time_breakdown.HighFlopsComputeTimePs()); - tpu_step_breakdown.set_tc_idle_ps(time_breakdown.IdleTimePs()); - tpu_step_breakdown.set_tc_busy_ps(time_breakdown.TensorCoreBusyTimePs()); -} - -TpuStepTimeBreakdown ComputeTpuStepTimeBreakdownInMs( - const InputPipelineAnalysisResult& analysis, bool has_sparse_core) { - tsl::Stat tc_compute_ms; - tsl::Stat tc_infeed_ms; - tsl::Stat tc_outfeed_ms; - tsl::Stat tc_idle_ms; - tsl::Stat scv0_compute_ms; - tsl::Stat scv0_infeed_ms; - tsl::Stat host_transfer_ms; - tsl::Stat sc_compute_ms; - tsl::Stat sc_infeed_ms; - tsl::Stat sc_outfeed_ms; - tsl::Stat sc_idle_ms; - tsl::Stat sc_step_time_ms; - TpuStepTimeBreakdown result; - - for (const google::protobuf::Any& step_details : analysis.step_details()) { - PerTpuStepDetails details; - if (!step_details.UnpackTo(&details)) { - LOG(ERROR) << "Unable to unpack step_details. Expected: tpu"; - // TODO(b/302086111): Switch back to DFATAL once absl is updated. 
-      DCHECK(false);
-      return result;
-    }
-    tc_compute_ms.UpdateStat(details.tc_compute_time_ms());
-    tc_idle_ms.UpdateStat(details.tc_idle_time_ms());
-    tc_infeed_ms.UpdateStat(details.tc_infeed_time_ms());
-    tc_outfeed_ms.UpdateStat(details.tc_outfeed_time_ms());
-    scv0_compute_ms.UpdateStat(details.scv0_compute_time_ms());
-    scv0_infeed_ms.UpdateStat(details.scv0_infeed_time_ms());
-    host_transfer_ms.UpdateStat(details.host_transfer_ms());
-    sc_compute_ms.UpdateStat(details.sc_compute_time_ms());
-    sc_idle_ms.UpdateStat(details.sc_idle_time_ms());
-    sc_infeed_ms.UpdateStat(details.sc_infeed_time_ms());
-    sc_outfeed_ms.UpdateStat(details.sc_outfeed_time_ms());
-    sc_step_time_ms.UpdateStat(details.sc_step_time_ms());
-  }
-  *result.mutable_tc_compute_ms_summary() =
-      GetStepSummaryForSampleStats(tc_compute_ms);
-  *result.mutable_scv0_compute_ms_summary() =
-      GetStepSummaryForSampleStats(scv0_compute_ms);
-  *result.mutable_tc_infeed_ms_summary() =
-      GetStepSummaryForSampleStats(tc_infeed_ms);
-  *result.mutable_tc_outfeed_ms_summary() =
-      GetStepSummaryForSampleStats(tc_outfeed_ms);
-  *result.mutable_scv0_infeed_ms_summary() =
-      GetStepSummaryForSampleStats(scv0_infeed_ms);
-  *result.mutable_tc_idle_ms_summary() =
-      GetStepSummaryForSampleStats(tc_idle_ms);
-  *result.mutable_host_transfer_ms_summary() =
-      GetStepSummaryForSampleStats(host_transfer_ms);
-  if (has_sparse_core) {
-    auto* sparse_core_step_summary = result.mutable_sparse_core_step_summary();
-    *sparse_core_step_summary->mutable_sc_compute_ms_summary() =
-        GetStepSummaryForSampleStats(sc_compute_ms);
-    *sparse_core_step_summary->mutable_sc_infeed_ms_summary() =
-        GetStepSummaryForSampleStats(sc_infeed_ms);
-    *sparse_core_step_summary->mutable_sc_outfeed_ms_summary() =
-        GetStepSummaryForSampleStats(sc_outfeed_ms);
-    *sparse_core_step_summary->mutable_sc_idle_ms_summary() =
-        GetStepSummaryForSampleStats(sc_idle_ms);
-    *sparse_core_step_summary->mutable_sc_step_time_ms_summary() =
-        GetStepSummaryForSampleStats(sc_step_time_ms);
-  }
-  return result;
-}
-
-// Given the step sequence on each core, computes the result proto of the
-// input-pipeline analysis tool (the InputPipelineAnalysisResult defined in
-// input_pipeline.proto).
-// Note on grouped_by_step: There is one element for each step executed (on
-// multiple cores). Each element is a map from the core_id to the information
-// of the step that runs on that core. Elements are in the order in which the
-// steps are executed over time.
-InputPipelineAnalysisResult ComputeTpuInputPipelineAnalysisResult(
-    const tsl::protobuf::RepeatedPtrField<PerCoreStepInfo>& grouped_by_step,
-    const tsl::protobuf::Map<uint32, CoreDetails>& core_details_map) {
-  InputPipelineAnalysisResult result;
-  bool has_sparse_core = false;
-  for (const auto& [core_id, core_details] : core_details_map) {
-    has_sparse_core |= core_details.is_sparse_core();
-  }
-
-  // Computes the summary of step time in ms.
-  *result.mutable_step_time_summary() =
-      ComputeStepTimeSummaryInMs(grouped_by_step);
-
-  // Summary of the statistics of infeed time as percentage of the step
-  // time.
-  tsl::Stat<double> infeed_summary_stats_in_percent;
-  for (const auto& coreid_stepinfo_map : grouped_by_step) {
-    // Computes the stats of each TPU step.
-    const PerTpuStepDetails& per_step_data =
-        ComputeTpuPerStepDataAcrossCores(coreid_stepinfo_map, core_details_map);
-    result.add_step_details()->PackFrom(per_step_data);
-
-    // The infeed summary is based on the maximum infeed time across cores at
-    // each step.
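-    // (In synchronous training the slowest core gates the step, so the
-    // maximum is the relevant statistic.)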
-    infeed_summary_stats_in_percent.UpdateStat(
-        per_step_data.infeed_percent_maximum());
-  }
-
-  // Computes the summary of infeed time as percentage of step time.
-  *result.mutable_input_percent_summary() =
-      GetStepSummaryForSampleStats(infeed_summary_stats_in_percent);
-
-  // Computes the breakdown of step time.
-  TpuStepTimeBreakdown tpu_step_time_breakdown =
-      ComputeTpuStepTimeBreakdownInMs(result, has_sparse_core);
-  result.mutable_step_time_breakdown()->PackFrom(tpu_step_time_breakdown);
-  result.set_tag(true);
-
-  return result;
-}
-
-// Returns true if device_op_metrics_db contains an infeed op.
-bool HasTpuInfeedOp(const OpMetricsDb& device_op_metrics_db) {
-  for (const OpMetrics& metrics : device_op_metrics_db.metrics_db()) {
-    if (tsl::profiler::IsHostOrSparseCoreV0Infeed(metrics.category())) {
-      return true;
-    }
-  }
-  return false;
-}
-
-// Returns the time spent waiting for input for generic hardware.
-uint64_t TotalInputPs(const StepDetails& step_details) {
-  uint64_t total_input_ps = 0;
-  for (const auto& event : step_details.Events()) {
-    if (event.type == HOST_WAIT_INPUT || event.type == HOST_TO_DEVICE) {
-      // Includes both the time where the host was waiting for input and the
-      // time where the host was sending data to the device.
-      total_input_ps += event.span.duration_ps();
-    }
-  }
-  return total_input_ps;
-}
-
-void TensorCoreIdleAnalysis(bool all_cores_profiled, double tc_idle_percent,
-                            std::string* input_classification,
-                            std::string* input_statement,
-                            std::string* tc_idle_classification,
-                            std::string* tc_idle_statement) {
-  // In MayFixTpuStepAnalysis(), we have already separated the idle time from
-  // the input time. So, we don't need to subtract the input time from the
-  // idle time here.
-  if (tc_idle_percent < kTcIdleThresholdInPercent) {
-    *tc_idle_classification = "no";
-    *tc_idle_statement = "";
-    return;
-  }
-  std::string idle_percent_str = absl::StrFormat("%.1lf", tc_idle_percent);
-  if (all_cores_profiled) {
-    // Significant idle time with all cores profiled.
-    *tc_idle_classification = "yes";
-    *tc_idle_statement =
-        absl::StrCat(idle_percent_str,
-                     " % of the total step time sampled is due to host "
-                     "overhead that is not input-related. For TF 2.x, you may "
-                     "want to use a ",
-                     AnchorElement(kMultipleStepsInTffunctionDoc,
-                                   "host-training loop (i.e. running multiple "
-                                   "steps within a tf.function)."));
-    return;
-  }
-
-  // Significant idle time without all cores profiled.
-  if (*input_classification == "host") {
-    // We've already identified that it is input bound. So, no need to issue
-    // more warnings.
-    *tc_idle_classification = "no";
-    *tc_idle_statement = "";
-    return;
-  }
-
-  *input_classification = "host";  // focuses on "host" first.
-  *input_statement = absl::StrCat(
-      "Your program COULD be input-bound because ", idle_percent_str,
-      "% of the total step time is idle. This may be a manifestation of an "
-      "input issue on a worker machine that was not profiled. "
-      "To be certain, please profile ALL worker machines in your job by "
-      "following ",
-      AnchorElement(kProfileAllHostsDoc, "this instruction."));
-  *tc_idle_classification = "no";
-  *tc_idle_statement = "";
-}
-
-void AllReduceAnalysis(bool all_cores_profiled,
-                       double all_reduce_compute_percent,
-                       double all_reduce_sync_percent, double input_percent,
-                       std::string* input_classification,
-                       std::string* input_statement,
-                       std::string* all_reduce_classification,
-                       std::string* all_reduce_statement) {
-  double all_reduce_percent =
-      all_reduce_compute_percent + all_reduce_sync_percent;
-  // Since all-reduce time is overlapped with the input time, we consider the
-  // all-reduce time that is not input related.
-  double all_reduce_not_input_related_percent =
-      all_reduce_percent - input_percent;
-
-  if (all_reduce_not_input_related_percent <
-      kAllReduceBoundThresholdInPercent) {
-    // Insignificant time spent on all-reduce.
-    *all_reduce_classification = "no";
-    *all_reduce_statement = "";
-    return;
-  }
-
-  if (all_cores_profiled) {
-    // Significant time spent on all-reduce with all cores profiled.
-    std::string all_reduce_compute_percent_str =
-        absl::StrFormat("%.1lf", all_reduce_compute_percent);
-    std::string all_reduce_sync_percent_str =
-        absl::StrFormat("%.1lf", all_reduce_sync_percent);
-    *all_reduce_classification = "yes";
-    *all_reduce_statement = absl::StrCat(
-        "Also, ", all_reduce_sync_percent_str,
-        " % of the total step time sampled is spent on synchronization with "
-        "other TPU cores, and ",
-        all_reduce_compute_percent_str,
-        " % of the total step time sampled is spent on actual AllReduce.");
-    return;
-  }
-
-  // Significant time spent on all-reduce and not all cores were profiled.
-  std::string all_reduce_percent_str =
-      absl::StrFormat("%.1lf", all_reduce_percent);
-
-  if (*input_classification != "device") {
-    // InputAnalysis() already indicates some potential input issue. So, we
-    // can focus on all-reduce performance.
-    *all_reduce_classification = "yes";
-    *all_reduce_statement = absl::StrCat(
-        "Also, ", all_reduce_percent_str,
-        " % of the total step time sampled is spent on synchronization with "
-        "other TPU cores and AllReduce. Not all worker machines are profiled, "
-        "therefore we cannot disambiguate the actual time for AllReduce from "
-        "the synchronization. To be certain, please profile ALL worker "
-        "machines in your job by following ",
-        AnchorElement(kProfileAllHostsDoc, "this instruction."));
-    return;
-  }
-
-  // InputAnalysis() indicates that it is NOT input-bound. However, it may
-  // be because the input delay is manifested as all-reduce time. So,
-  // attribute it to a possible input issue.
-  *input_classification = "host";  // focuses on "host" first.
-  *input_statement = absl::StrCat(
-      "Your program COULD be input-bound because ", all_reduce_percent_str,
-      "% of the total step time is spent on synchronization with other "
-      "TPU cores. This may be a manifestation of an input issue on a "
-      "worker machine that was not profiled. "
To be certain, please profile ALL "
-      "worker machines in your job by following ",
-      AnchorElement(kProfileAllHostsDoc, "this instruction."));
-  *all_reduce_classification = "no";
-  *all_reduce_statement = "";
-}
-
-void ScV0Analysis(double scv0_percent, std::string* scv0_classification,
-                  std::string* scv0_statement) {
-  if (scv0_percent == 0) {
-    *scv0_classification = "no";
-    *scv0_statement = "";
-    return;
-  }
-  std::string scv0_percent_str = absl::StrFormat("%.1lf", scv0_percent);
-  if (scv0_percent < kModeratelySparseCoreV0BoundThresholdInPercent) {
-    *scv0_classification = "moderate";
-    *scv0_statement = absl::StrCat(
-        "Also, ", scv0_percent_str,
-        " % of the total step time sampled is spent on the ", kSparseCoreV0Name,
-        " compute. You may also want to reduce the ", kSparseCoreV0Name,
-        " compute time.");
-    return;
-  }
-  *scv0_classification = "high";
-  *scv0_statement = absl::StrCat(
-      "Also, ", scv0_percent_str,
-      " % of the total step time sampled is spent on the ", kSparseCoreV0Name,
-      " compute. You should focus on reducing the ", kSparseCoreV0Name,
-      " compute time as well.");
-}
-
-// A map that keeps track of the minimum value associated with an id.
-class MinMap {
- public:
-  void Observe(uint64_t id, uint64_t value) {
-    auto [iter, inserted] = min_map_.try_emplace(id, value);
-    if (!inserted && iter->second > value) {
-      iter->second = value;
-    }
-  }
-
-  uint64_t Min(uint64_t id) const {
-    auto iter = min_map_.find(id);
-    return (iter != min_map_.end()) ? iter->second : 0;
-  }
-
- private:
-  absl::flat_hash_map<uint64_t, uint64_t> min_map_;
-};
-
-}  // namespace
-
-PerTpuStepDetails ComputeTpuPerStepDataAcrossCores(
-    const PerCoreStepInfo& coreid_stepinfo_map,
-    const tsl::protobuf::Map<uint32_t, CoreDetails>& core_details_map) {
-  PerTpuStepDetails per_step_data;
-
-  PerCoreAllReduceBreakdown all_reduce_breakdown =
-      ComputePerStepAllReduceBreakdownAcrossCores(coreid_stepinfo_map);
-
-  tsl::Stat<double> infeed_percent_stats;
-  tsl::Stat<uint64_t> step_stats_in_ps;
-  tsl::Stat<uint64_t> optimal_step_time_ps;
-  // Take the average TC outfeed time in result.
-  tsl::Stat<uint64_t> tc_outfeed_time_in_ps;
-  tsl::Stat<uint64_t> sc_compute_time_ps;
-  tsl::Stat<uint64_t> sc_step_stats_in_ps;
-  tsl::Stat<uint64_t> sc_outfeed_time_in_ps;
-  tsl::Stat<uint64_t> sc_infeed_time_in_ps;
-  tsl::Stat<uint64_t> sc_idle_time_in_ps;
-
-  tsl::Stat<uint64_t> host_send_recv_time_ps;
-
-  // For the core with the max wait-for-scv0 duration, break it down into
-  // compute and infeed time.
-  WaitForScV0Breakdown max_wait_for_scv0;
-
-  TcInfeed max_infeed;
-
-  // For the core with the max all-reduce duration, break it down into
-  // compute and synchronization time.
-  AllReduceBreakdown max_all_reduce;
-
-  per_step_data.set_step_number(-1);
-  auto process_step_for_sc =
-      [&](const tensorflow::profiler::StepInfoResult& step_info,
-          const SparseCoreStepBreakdown& sc_step) {
-        if (per_step_data.step_number() < 0) {
-          per_step_data.set_step_number(step_info.step_num());
-        } else {
-          if (per_step_data.step_number() != step_info.step_num()) {
-            VLOG(1) << "Inconsistent step numbers across cores ("
-                    << per_step_data.step_number() << " vs. "
-                    << step_info.step_num() << ").";
-          }
-        }
-        sc_step_stats_in_ps.UpdateStat(step_info.duration_ps());
-        sc_outfeed_time_in_ps.UpdateStat(sc_step.sc_outfeed_ps());
-        sc_infeed_time_in_ps.UpdateStat(sc_step.sc_infeed_ps());
-        sc_compute_time_ps.UpdateStat(step_info.duration_ps() -
-                                      sc_step.sc_infeed_ps() -
-                                      sc_step.sc_outfeed_ps());
-        sc_idle_time_in_ps.UpdateStat(sc_step.sc_idle_ps());
-      };
-  for (const auto& [core_id, step_info] :
-       coreid_stepinfo_map.step_info_per_core()) {
-    // iterates over each core.
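-    // For each core: unpack its step breakdown (TpuStepBreakdown,
-    // GenericStepBreakdown, or SparseCoreStepBreakdown), fold its durations
-    // into the per-step stats above, and remember the cores with the largest
-    // infeed, wait-for-scv0, and all-reduce durations.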
- TpuStepBreakdown tpu; - if (!step_info.step_breakdown().UnpackTo(&tpu)) { - VLOG(1) << "Unable to unpack step_breakdown from tpu, try unpacking from " - "generic"; - tensorflow::profiler::GenericStepBreakdown generic_step_breakdown; - if (!step_info.step_breakdown().UnpackTo(&generic_step_breakdown)) { - SparseCoreStepBreakdown sc_step; - if (step_info.step_breakdown().UnpackTo(&sc_step)) { - process_step_for_sc(step_info, sc_step); - continue; - } else { - LOG(ERROR) << "Unable to unpack step_breakdown from " - "GenericStepBreakdown or SparseCoreStepBreakdown"; - // TODO(b/302086111): Switch back to DFATAL once absl is updated. - DCHECK(false); - return per_step_data; - } - } - if (core_id >= kSparseCoreIndexStart) { - // Sparse core step breakdown from xspace. - uint64_t idle_time_ps = 0; - uint64_t busy_time_ps = 0; - for (const auto& [category, time_ps] : - generic_step_breakdown.category_ps()) { - if (category == kIdle) { - idle_time_ps = time_ps; - } else if (category == "sparse_core_busy_ops") { - busy_time_ps = time_ps; - } - } - sc_step_stats_in_ps.UpdateStat(step_info.duration_ps()); - sc_compute_time_ps.UpdateStat(busy_time_ps); - sc_idle_time_in_ps.UpdateStat(idle_time_ps); - continue; - } else { - // Tensor core step breakdown from xspace. - ConvertGenericStepBreakdownToTpuStepBreakdown( - generic_step_breakdown, step_info.duration_ps(), tpu); - } - } - step_stats_in_ps.UpdateStat(step_info.duration_ps()); - if (tpu.wait_for_scv0_duration_ps() > max_wait_for_scv0.DurationPs()) { - max_wait_for_scv0.scv0_infeed_duration_ps = ScV0InfeedDurationPs(tpu); - max_wait_for_scv0.scv0_compute_duration_ps = ScV0ComputeDurationPs(tpu); - } - - tc_outfeed_time_in_ps.UpdateStat(tpu.host_outfeed_ps()); - - const AllReduceBreakdown& breakdown = all_reduce_breakdown[core_id]; - if (breakdown.DurationPs() > max_all_reduce.DurationPs()) { - max_all_reduce = breakdown; - } - - infeed_percent_stats.UpdateStat(100.0 * TcPlusScV0InfeedDurationPs(tpu) / - step_info.duration_ps()); - // The optimal step time is the actual step time minus the time tensor - // core spends waiting for host or sparsecorev0 (but not other tensor - // cores). - optimal_step_time_ps.UpdateStat(step_info.duration_ps() - - WaitForHostOrScV0DurationPs(tpu)); - host_send_recv_time_ps.UpdateStat(HostSendRecvDurationPs(tpu)); - - if (per_step_data.step_number() < 0) { - // Sets the step number of the current step from the first core. - per_step_data.set_step_number(step_info.step_num()); - } else { - // The step number of the current step is already set. Checks if it is - // the same across cores. In case of multi-host tracing, we may have - // some inconsistent steps as tracing is not exactly guaranteed to be - // synchronized across all hosts. - if (per_step_data.step_number() != step_info.step_num()) { - VLOG(1) << "Inconsistent step numbers across cores (" - << per_step_data.step_number() << " vs. " - << step_info.step_num() << ")."; - } - } - if (tpu.infeed_duration_ps() > max_infeed.duration_ps) { - max_infeed.core_id = core_id; - max_infeed.duration_ps = tpu.infeed_duration_ps(); - } - } - - per_step_data.set_tc_outfeed_time_ms( - tsl::profiler::PicoToMilli(tc_outfeed_time_in_ps.avg())); - // The TC compute time is the minimum of the optimal step time across cores. - per_step_data.set_tc_compute_time_ms( - tsl::profiler::PicoToMilli(optimal_step_time_ps.min())); - per_step_data.set_host_transfer_ms( - tsl::profiler::PicoToMilli(host_send_recv_time_ps.max())); - // TODO(b/153730997): Use the maximum step time. 
-  // The infeed time is the step time across cores minus all other times.
-  // Previously, we used the maximum step time but changed to use the minimum
-  // step time to work around b/153730997.
-  // Uses the max TC infeed duration across cores as the step's TC infeed
-  // duration.
-  per_step_data.set_tc_infeed_time_ms(
-      tsl::profiler::PicoToMilli(max_infeed.duration_ps));
-  if (max_infeed.core_id.has_value()) {
-    per_step_data.set_coreid_max_infeed_time(max_infeed.core_id.value());
-    if (core_details_map.contains(max_infeed.core_id.value())) {
-      const CoreDetails& core_details =
-          core_details_map.at(max_infeed.core_id.value());
-      per_step_data.set_max_infeed_time_core_name(absl::StrCat(
-          core_details.hostname(), ":", core_details.device_ordinal()));
-    }
-  }
-
-  per_step_data.set_scv0_compute_time_ms(
-      tsl::profiler::PicoToMilli(max_wait_for_scv0.scv0_compute_duration_ps));
-  per_step_data.set_scv0_infeed_time_ms(
-      tsl::profiler::PicoToMilli(max_wait_for_scv0.scv0_infeed_duration_ps));
-
-  // The TC idle time is the time TC spends waiting for the host but not
-  // waiting for input.
-  per_step_data.set_tc_idle_time_ms(
-      tsl::profiler::PicoToMilli(step_stats_in_ps.min()) -
-      NonIdleTimeMs(per_step_data));
-  if (per_step_data.tc_idle_time_ms() < 0) {
-    per_step_data.set_tc_idle_time_ms(0);
-  }
-
-  per_step_data.set_all_reduce_compute_time_ms(
-      tsl::profiler::PicoToMilli(max_all_reduce.compute_duration_ps));
-  per_step_data.set_all_reduce_sync_time_ms(
-      tsl::profiler::PicoToMilli(max_all_reduce.sync_duration_ps));
-
-  per_step_data.set_infeed_percent_average(infeed_percent_stats.avg());
-  per_step_data.set_infeed_percent_minimum(infeed_percent_stats.min());
-  per_step_data.set_infeed_percent_maximum(infeed_percent_stats.max());
-
-  per_step_data.set_sc_infeed_time_ms(
-      tsl::profiler::PicoToMilli(sc_infeed_time_in_ps.avg()));
-  per_step_data.set_sc_outfeed_time_ms(
-      tsl::profiler::PicoToMilli(sc_outfeed_time_in_ps.avg()));
-  per_step_data.set_sc_compute_time_ms(
-      tsl::profiler::PicoToMilli(sc_compute_time_ps.min()));
-  per_step_data.set_sc_idle_time_ms(
-      tsl::profiler::PicoToMilli(sc_idle_time_in_ps.avg()));
-  per_step_data.set_sc_step_time_ms(
-      tsl::profiler::PicoToMilli(sc_step_stats_in_ps.avg()));
-  if (per_step_data.sc_idle_time_ms() < 0) {
-    per_step_data.set_sc_idle_time_ms(0);
-  }
-  return per_step_data;
-}
-
-StepSummary GetStepSummaryForSampleStats(
-    const tsl::Stat<double>& sample_stats) {
-  StepSummary step_time_summary;
-  double avg, sdv, min, max;
-  if (sample_stats.empty()) {
-    // If sample_stats is empty, sample_stats.avg() will return NaN. However,
-    // we prefer to show a 0 instead. 
-    avg = sdv = min = max = 0.0;
-  } else {
-    avg = sample_stats.avg();
-    sdv = sqrt(sample_stats.sample_variance());
-    min = sample_stats.min();
-    max = sample_stats.max();
-  }
-  step_time_summary.set_average(avg);
-  step_time_summary.set_standard_deviation(sdv);
-  step_time_summary.set_minimum(min);
-  step_time_summary.set_maximum(max);
-  return step_time_summary;
-}
-
-PerCoreAllReduceBreakdown ComputePerStepAllReduceBreakdownAcrossCores(
-    const PerCoreStepInfo& coreid_stepinfo_map) {
-  PerCoreAllReduceBreakdown result;
-  MinMap min_duration_map;
-  for (const auto& [core_id, all_reduce_db] :
-       coreid_stepinfo_map.all_reduce_db_per_core()) {
-    for (const auto& all_reduce : all_reduce_db.all_reduce_info()) {
-      uint64_t duration_ps =
-          all_reduce.end_time_ps() - all_reduce.start_time_ps();
-      min_duration_map.Observe(all_reduce.id(), duration_ps);
-    }
-  }
-  for (const auto& [core_id, all_reduce_db] :
-       coreid_stepinfo_map.all_reduce_db_per_core()) {
-    AllReduceBreakdown& breakdown = result[core_id];
-    for (const auto& all_reduce : all_reduce_db.all_reduce_info()) {
-      uint64_t duration_ps =
-          all_reduce.end_time_ps() - all_reduce.start_time_ps();
-      uint64_t min_duration_ps = min_duration_map.Min(all_reduce.id());
-      // The fastest core's duration approximates the actual compute time;
-      // anything beyond it is counted as synchronization.
-      breakdown.compute_duration_ps += min_duration_ps;
-      breakdown.sync_duration_ps += duration_ps - min_duration_ps;
-    }
-  }
-  return result;
-}
-
-void MayFixTpuStepAnalysis(
-    const StepEvents& host_step_events, const OpMetricsDb& device_op_metrics_db,
-    StepDatabaseResult& step_db,
-    const tsl::protobuf::Map<uint32_t, CoreDetails>& core_details_map) {
-  // This code is only applicable when input is received by the tensor core
-  // from the host without the use of infeed. If the tensor core receives
-  // input via host infeed or via sparsecorev0 infeed, there's nothing to do.
-  if (HasTpuInfeedOp(device_op_metrics_db)) return;
-
-  for (PerCoreStepInfo& per_core_step_info :
-       *(step_db.mutable_step_sequence())) {
-    uint32_t step_num = per_core_step_info.step_num();
-    // TODO(ckluk): step_num is obtained from tf_op_stats, which is based on
-    // the step-tracking mechanism with the on-device training loop. However,
-    // this step_num is different from the group_id. So, what we are doing
-    // here is only an approximation, assuming that all steps exhibit a
-    // similar breakdown. Once grouping works on the TPU device, we need to
-    // replace step_num by the group_id from the TPU device.
-    const StepDetails* step_details =
-        gtl::FindOrNull(host_step_events, step_num);
-    if (step_details == nullptr) {
-      continue;  // step_num not in host_step_events; we don't know how to fix.
-    }
-    uint64_t total_input_ps = TotalInputPs(*step_details);
-    if (total_input_ps == 0) {
-      continue;  // no host input events.
-    }
-    PerTpuStepDetails tpu_step_data =
-        ComputeTpuPerStepDataAcrossCores(per_core_step_info, core_details_map);
-    double tc_idle_ms = tpu_step_data.tc_idle_time_ms();
-    double adjusted_input_ratio =
-        std::min(tsl::profiler::SafeDivide(
-                     tsl::profiler::PicoToMilli(total_input_ps), tc_idle_ms),
-                 1.0);
-    for (auto& [core_id, step_info] :
-         *per_core_step_info.mutable_step_info_per_core()) {
-      // Skip sparse cores for this.
-      if (core_id >= kSparseCoreIndexStart) continue;
-      if (TpuStepBreakdown tpu; step_info.step_breakdown().UnpackTo(&tpu)) {
-        DCHECK_EQ(tpu.infeed_duration_ps(), 0);
-        if (tpu.tc_idle_ps() > 0) {
-          // Extract the infeed fraction of idle time.
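-          // Worked example (numbers mirror the accompanying unit test, not a
-          // real profile): if the host-side input events total 210 ps and the
-          // step's TC idle time is 300 ps, then adjusted_input_ratio =
-          // min(210 / 300, 1.0) = 0.7; a core with tc_idle_ps = 300 is then
-          // rebooked as 300 * 0.7 = 210 ps of infeed and 300 - 210 = 90 ps of
-          // genuine idle time.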
- tpu.set_infeed_duration_ps(tpu.tc_idle_ps() * adjusted_input_ratio); - tpu.set_tc_idle_ps(tpu.tc_idle_ps() - tpu.infeed_duration_ps()); - step_info.mutable_step_breakdown()->PackFrom(tpu); - } - } else if (tensorflow::profiler::GenericStepBreakdown generic; - step_info.step_breakdown().UnpackTo(&generic)) { - uint64_t& infeed_time_ps = - (*generic.mutable_category_ps())[xla::HloOpcodeString( - xla::HloOpcode::kInfeed)]; - uint64_t& idle_time_ps = - (*generic.mutable_category_ps())[tensorflow::profiler::kIdle]; - DCHECK_EQ(infeed_time_ps, 0); - if (idle_time_ps > 0) { - infeed_time_ps = idle_time_ps * adjusted_input_ratio; - idle_time_ps -= infeed_time_ps; - step_info.mutable_step_breakdown()->PackFrom(generic); - } - } else { - // Likely encountered an ScStepBreakdown instance which can be skipped - // as we only care about attributing TC idle time to host. - LOG(INFO) << "Unable to unpack step_breakdown."; - } - } - } -} - -TpuBottleneckAnalysis ComputeTpuBottleneckAnalysis( - bool all_cores_profiled, const InputPipelineAnalysisResult& result) { - double total_step_time_ms = 0; - double total_infeed_time_ms = 0; - double total_tc_outfeed_time_ms = 0; - double total_scv0_compute_time_ms = 0; - double total_all_reduce_compute_time_ms = 0; - double total_all_reduce_sync_time_ms = 0; - double total_tc_idle_time_ms = 0; - - TpuBottleneckAnalysis analysis; - for (const google::protobuf::Any& step_details : result.step_details()) { - PerTpuStepDetails details; - if (!step_details.UnpackTo(&details)) { - LOG(ERROR) << "Unable to unpack step_details. Expected: tpu"; - // TODO(b/302086111): Switch back to DFATAL once absl is updated. - DCHECK(false); - return analysis; - } - total_step_time_ms += StepTimeMs(details); - total_infeed_time_ms += InfeedTimeMs(details); - total_tc_outfeed_time_ms += details.tc_outfeed_time_ms(); - total_scv0_compute_time_ms += details.scv0_compute_time_ms(); - total_all_reduce_compute_time_ms += details.all_reduce_compute_time_ms(); - total_all_reduce_sync_time_ms += details.all_reduce_sync_time_ms(); - total_tc_idle_time_ms += details.tc_idle_time_ms(); - } - if (total_step_time_ms == 0) { - analysis.set_input_classification("unknown"); - analysis.set_input_statement( - "No step time measured. 
Therefore we cannot tell where the performance "
-        "bottleneck is.");
-    analysis.set_tc_idle_classification("no");
-    analysis.set_tc_idle_statement("");
-    analysis.set_scv0_classification("no");
-    analysis.set_scv0_statement("");
-    analysis.set_all_reduce_classification("no");
-    analysis.set_all_reduce_statement("");
-    return analysis;
-  }
-
-  double infeed_percent = 100.0 * total_infeed_time_ms / total_step_time_ms;
-  std::string input_classification;
-  std::string input_statement;
-  InputAnalysis(infeed_percent, /*all_other_percent=*/0, &input_classification,
-                &input_statement);
-
-  double tc_outfeed_percent =
-      100.0 * total_tc_outfeed_time_ms / total_step_time_ms;
-  std::string output_classification;
-  std::string output_statement;
-  OutputAnalysis(tc_outfeed_percent, &output_classification,
-                 &output_statement);
-
-  double tc_idle_percent = 100.0 * total_tc_idle_time_ms / total_step_time_ms;
-  std::string tc_idle_classification;
-  std::string tc_idle_statement;
-  TensorCoreIdleAnalysis(all_cores_profiled, tc_idle_percent,
-                         &input_classification, &input_statement,
-                         &tc_idle_classification, &tc_idle_statement);
-
-  double all_reduce_compute_percent =
-      100.0 * total_all_reduce_compute_time_ms / total_step_time_ms;
-  double all_reduce_sync_percent =
-      100.0 * total_all_reduce_sync_time_ms / total_step_time_ms;
-  std::string all_reduce_classification;
-  std::string all_reduce_statement;
-  AllReduceAnalysis(all_cores_profiled, all_reduce_compute_percent,
-                    all_reduce_sync_percent, infeed_percent,
-                    &input_classification, &input_statement,
-                    &all_reduce_classification, &all_reduce_statement);
-
-  double scv0_percent =
-      100.0 * total_scv0_compute_time_ms / total_step_time_ms;
-  std::string scv0_classification;
-  std::string scv0_statement;
-  ScV0Analysis(scv0_percent, &scv0_classification, &scv0_statement);
-
-  // compute_percent includes both TC and ScV0 compute.
-  double compute_percent = std::max(
-      0.0, 100.0 - infeed_percent - tc_outfeed_percent - tc_idle_percent);
-
-  analysis.set_compute_percent(compute_percent);
-  analysis.set_input_percent(infeed_percent);
-  analysis.set_output_percent(tc_outfeed_percent);
-  analysis.set_tc_idle_percent(tc_idle_percent);
-  analysis.set_input_classification(input_classification);
-  analysis.set_input_statement(input_statement);
-  analysis.set_output_statement(output_statement);
-  analysis.set_tc_idle_classification(tc_idle_classification);
-  analysis.set_tc_idle_statement(tc_idle_statement);
-  analysis.set_scv0_classification(scv0_classification);
-  analysis.set_scv0_statement(scv0_statement);
-  analysis.set_all_reduce_classification(all_reduce_classification);
-  analysis.set_all_reduce_statement(all_reduce_statement);
-  return analysis;
-}
-
-void GenerateHostResult(const OpMetricsDb& host_tf_metrics_db,
-                        InputPipelineAnalysisResult* result) {
-  InputOpMetrics input_op_metrics = SelectInputOpMetrics(host_tf_metrics_db);
-  // Returns early if the program is not using an input pipeline with
-  // instrumentation and hence no input ops are found.
-  if (input_op_metrics.input_op_metrics.empty()) return;
-
-  absl::flat_hash_map<InputOpCategory, double> aggregated_input_op_times_us;
-  for (const OpMetrics* op_metrics : input_op_metrics.input_op_metrics) {
-    InputOpCategory category =
-        CategorizeInputOp(op_metrics->name(), op_metrics->category());
-    *result->add_input_op_details() = ConvertOpMetricsToInputOpDetails(
-        *op_metrics, input_op_metrics.input_op_time_ps, category);
-    aggregated_input_op_times_us[category] +=
-        tsl::profiler::PicoToMicro(op_metrics->self_time_ps());
-  }
-
-  double enqueue_time_us =
-      aggregated_input_op_times_us[InputOpCategory::kEnqueue];
-  double total_input_op_time_us =
-      aggregated_input_op_times_us[InputOpCategory::kDemandedFileRead] +
-      aggregated_input_op_times_us[InputOpCategory::kAdvancedFileRead] +
-      aggregated_input_op_times_us[InputOpCategory::kPreprocessing];
-
-  double ratio = std::min(
-      1.0, RatioOfHostToDeviceTimeToStepTime(host_tf_metrics_db, *result));
-  DCHECK_GE(ratio, 0.0);
-  double non_enqueue_time_us = (ratio != 0.0)
-                                   ? (enqueue_time_us * (1.0 - ratio) / ratio)
-                                   : total_input_op_time_us;
-
-  // Scales the various input-time components with respect to
-  // non_enqueue_time_us.
-  double scaled_demanded_fileread_time_us = tsl::profiler::SafeDivide(
-      non_enqueue_time_us *
-          aggregated_input_op_times_us[InputOpCategory::kDemandedFileRead],
-      total_input_op_time_us);
-  double scaled_advanced_fileread_time_us = tsl::profiler::SafeDivide(
-      non_enqueue_time_us *
-          aggregated_input_op_times_us[InputOpCategory::kAdvancedFileRead],
-      total_input_op_time_us);
-  double scaled_preprocessing_time_us = tsl::profiler::SafeDivide(
-      non_enqueue_time_us *
-          aggregated_input_op_times_us[InputOpCategory::kPreprocessing],
-      total_input_op_time_us);
-  double unclassified_non_enqueue_time_us = std::max(
-      0.0, non_enqueue_time_us - scaled_demanded_fileread_time_us -
-               scaled_advanced_fileread_time_us - scaled_preprocessing_time_us);
-
-  InputTimeBreakdown* input_time_breakdown =
-      result->mutable_input_time_breakdown();
-  input_time_breakdown->set_enqueue_us(enqueue_time_us);
-  input_time_breakdown->set_demanded_file_read_us(
-      scaled_demanded_fileread_time_us);
-  input_time_breakdown->set_advanced_file_read_us(
-      scaled_advanced_fileread_time_us);
-  input_time_breakdown->set_preprocessing_us(scaled_preprocessing_time_us);
-  input_time_breakdown->set_unclassified_non_enqueue_us(
-      unclassified_non_enqueue_time_us);
-}
-
-InputPipelineAnalysisRecommendation GenerateRecommendation() {
-  const absl::string_view kDatasetIntro =
-      "https://www.tensorflow.org/programmers_guide/datasets";
-
-  const absl::string_view kDatasetTopic =
-      "https://www.tensorflow.org/api_docs/python/tf/data/Dataset#";
-
-  const absl::string_view kTfRecordDataset =
-      "https://www.tensorflow.org/api_docs/python/tf/data/"
-      "TFRecordDataset#class_tfrecorddataset";
-
-  InputPipelineAnalysisRecommendation recommendation;
-  *recommendation.add_details() =
-      "Enqueuing data: you may want to combine small input data chunks "
-      "into fewer but larger chunks.";
-  *recommendation.add_details() = absl::StrCat(
-      "Data preprocessing: you may increase num_parallel_calls in ",
-      AnchorElement(absl::StrCat(kDatasetTopic, "map"), "Dataset map()"),
-      " or preprocess the data OFFLINE.");
-  *recommendation.add_details() = absl::StrCat(
-      "Reading data from files in advance: you may tune parameters in the "
-      "following tf.data API (",
-      AnchorElement(absl::StrCat(kDatasetTopic, "prefetch"), "prefetch size"),
-      ", ",
-      AnchorElement(absl::StrCat(kDatasetTopic, "interleave"),
-                    "interleave 
cycle_length"),
-      ", ", AnchorElement(kTfRecordDataset, "reader buffer_size"), ")");
-  *recommendation.add_details() = absl::StrCat(
-      "Reading data from files on demand: you should read data IN ADVANCE "
-      "using the following tf.data API (",
-      AnchorElement(absl::StrCat(kDatasetTopic, "prefetch"), "prefetch"), ", ",
-      AnchorElement(absl::StrCat(kDatasetTopic, "interleave"), "interleave"),
-      ", ", AnchorElement(kTfRecordDataset, "reader buffer"), ")");
-  *recommendation.add_details() = absl::StrCat(
-      "Other data reading or processing: you may consider using the ",
-      AnchorElement(kDatasetIntro, "tf.data API"),
-      " (if you are not using it now)");
-
-  return recommendation;
-}
-
-StepSummary ComputeStepTimeSummaryInMs(
-    const tsl::protobuf::RepeatedPtrField<PerCoreStepInfo>& grouped_by_step) {
-  tsl::Stat<double> total_step_stats_in_ms;
-  // iterates over each step.
-  for (const auto& coreid_stepinfo_map : grouped_by_step) {
-    double max_per_step_stats_in_ms = 0.0;
-    // iterates over each core.
-    for (const auto& coreid_and_stepinfo :
-         coreid_stepinfo_map.step_info_per_core()) {
-      if (coreid_and_stepinfo.first >= kSparseCoreIndexStart) continue;
-      const auto& step_info = coreid_and_stepinfo.second;
-      max_per_step_stats_in_ms = std::max(
-          step_info.duration_ps() / kNumPsPerMs, max_per_step_stats_in_ms);
-    }
-    // Step time of each step is determined by the slowest core.
-    total_step_stats_in_ms.UpdateStat(max_per_step_stats_in_ms);
-  }
-
-  return GetStepSummaryForSampleStats(total_step_stats_in_ms);
-}
-
-InputPipelineAnalysisResult ConvertOpStatsToInputPipelineAnalysis(
-    const OpStats& op_stats) {
-  const HardwareType hardware_type =
-      op_stats.run_environment().hardware_type();
-
-  InputPipelineAnalysisResult result;
-  if (hardware_type == tensorflow::profiler::TPU) {
-    result = ComputeTpuInputPipelineAnalysisResult(
-        op_stats.step_db().step_sequence(), op_stats.core_id_to_details());
-  } else {
-    result = ComputeGenericInputPipelineAnalysisResult(
-        op_stats.step_db().step_sequence());
-  }
-  result.set_hardware_type(HardwareType_Name(hardware_type));
-
-  PopulateStepDiagnostics(op_stats, result.mutable_diagnostics());
-  GenerateHostResult(op_stats.host_op_metrics_db(), &result);
-
-  InputPipelineAnalysisRecommendation recommendation =
-      GenerateRecommendation();
-  if (hardware_type == tensorflow::profiler::TPU) {
-    TpuBottleneckAnalysis bottleneck_analysis = ComputeTpuBottleneckAnalysis(
-        /*all_cores_profiled=*/true, result);
-    result.set_input_percent(bottleneck_analysis.input_percent());
-    result.set_output_percent(bottleneck_analysis.output_percent());
-    result.set_idle_percent(bottleneck_analysis.tc_idle_percent());
-    result.set_compute_percent(bottleneck_analysis.compute_percent());
-
-    recommendation.mutable_bottleneck_analysis()->PackFrom(
-        bottleneck_analysis);
-    *recommendation.mutable_summary_next_step() =
-        GetSummaryNextStep(bottleneck_analysis.input_classification(),
-                           result.input_time_breakdown());
-  } else {
-    BottleneckAnalysis bottleneck_analysis = ComputeBottleneckAnalysis(
-        result.input_time_breakdown(), result.step_details());
-    result.set_input_percent(bottleneck_analysis.input_percent());
-    result.set_output_percent(bottleneck_analysis.output_percent());
-    result.set_idle_percent(bottleneck_analysis.idle_percent());
-    result.set_compute_percent(bottleneck_analysis.compute_percent());
-    recommendation.mutable_bottleneck_analysis()->PackFrom(
-        bottleneck_analysis);
-    *recommendation.mutable_summary_next_step() =
-        GetSummaryNextStep(bottleneck_analysis.input_classification(),
-                           
result.input_time_breakdown()); - } - - *result.mutable_recommendation() = recommendation; - return result; -} - -bool InputAnalysis(double input_percent, double all_other_percent, - std::string* input_classification, - std::string* input_statement) { - absl::string_view non_input_time = "other time"; - if (input_percent >= kHighlyInfeedBoundThresholdInPercent) { - *input_classification = "host"; - *input_statement = absl::StrCat( - "Your program is HIGHLY input-bound because ", OneDigit(input_percent), - "% of the total step time sampled is waiting for input. Therefore, you " - "should first focus on reducing the input time."); - return false; - } else if (input_percent >= kModeratelyInfeedBoundThresholdInPercent) { - *input_classification = "both"; - *input_statement = absl::StrCat( - "Your program is MODERATELY input-bound because ", - OneDigit(input_percent), - "% of the total step time sampled is waiting for input. Therefore, " - "you would need to reduce both the input time and ", - non_input_time, "."); - return false; - } else if (all_other_percent >= kModeratelyAllOtherBoundThresholdInPercent) { - // Input analysis says it is not input-bound, but "All-Other" time - // is significant. It could still be input-bound (or Python overhead). - *input_classification = "both"; - *input_statement = absl::StrCat( - "Your program is POTENTIALLY input-bound because ", - OneDigit(all_other_percent), - "% of the total step time sampled is spent on 'All Others' time (which " - "could be due to I/O or Python execution or both)."); - return true; - } else { - // Definitely not input-bound. - *input_classification = "device"; - *input_statement = - absl::StrCat("Your program is NOT input-bound because only ", - OneDigit(input_percent), - "% of the total step time sampled is waiting for " - "input. Therefore, you should focus on " - "reducing ", - non_input_time, "."); - return false; - } -} - -void OutputAnalysis(double output_percent, std::string* output_classification, - std::string* output_statement) { - if (output_percent >= kHighlyOutfeedBoundThresholdInPercent) { - *output_classification = "host"; - *output_statement = absl::StrCat( - "Your program is HIGHLY output-bound because ", - OneDigit(output_percent), - "% of the total step time sampled is spent on output. Therefore, you " - "should first focus on reducing the output time."); - } else if (output_percent >= kModeratelyOutfeedBoundThresholdInPercent) { - *output_classification = "both"; - *output_statement = absl::StrCat( - "Your program is MODERATELY output-bound because ", - OneDigit(output_percent), - "% of the total step time sampled is spent on output. 
Therefore, "
-        "you would need to reduce both the output time and other time.");
-  } else {
-    *output_classification = "device";
-    *output_statement = "";
-  }
-}
-
-BottleneckAnalysis ComputeBottleneckAnalysis(
-    const InputTimeBreakdown& input_time_breakdown,
-    const tsl::protobuf::RepeatedPtrField<::google::protobuf::Any>&
-        any_step_details) {
-  double total_step_time_ms = 0;
-  double total_input_ms = 0;
-  double total_output_ms = 0;
-  double total_host_compute_ms = 0;
-  double total_host_prepare_ms = 0;
-  double total_host_compile_ms = 0;
-  double total_device_compute_ms = 0;
-  double total_device_to_device_ms = 0;
-  double total_device_collectives_ms = 0;
-  double total_unknown_ms = 0;
-
-  for (const google::protobuf::Any& step_details : any_step_details) {
-    PerGenericStepDetails details;
-    bool success = step_details.UnpackTo(&details);
-    if (!success && !step_details.type_url().empty()) {
-      LOG(ERROR) << "Unable to unpack step_breakdown. Expected: generic";
-      return {};
-    }
-    total_step_time_ms += details.step_time_ms();
-    total_input_ms +=
-        details.host_wait_input_ms() + details.host_to_device_ms();
-    total_output_ms += details.output_ms();
-    total_host_prepare_ms += details.host_prepare_ms();
-    total_device_compute_ms += details.device_compute_ms();
-    total_device_to_device_ms += details.device_to_device_ms();
-    total_device_collectives_ms += details.device_collectives_ms();
-    total_host_compute_ms += details.host_compute_ms();
-    total_host_compile_ms += details.host_compile_ms();
-    total_unknown_ms += details.unknown_time_ms();
-  }
-
-  if (total_step_time_ms == 0) {
-    BottleneckAnalysis analysis;
-    analysis.set_input_classification("unknown");
-    analysis.set_input_statement(
-        "No step time measured. Therefore we cannot tell where the "
-        "performance bottleneck is.");
-    analysis.set_kernel_launch_classification("no");
-    analysis.set_kernel_launch_statement("");
-    analysis.set_all_other_classification("no");
-    analysis.set_all_other_statement("");
-    analysis.set_device_collectives_classification("no");
-    analysis.set_device_collectives_statement("");
-    return analysis;
-  }
-  double input_percent = 100.0 * total_input_ms / total_step_time_ms;
-  double output_percent = 100.0 * total_output_ms / total_step_time_ms;
-  double compute_percent =
-      100.0 * total_device_compute_ms / total_step_time_ms;
-  double device_collectives_percent =
-      100.0 * total_device_collectives_ms / total_step_time_ms;
-
-  // idle_percent includes host_prepare (i.e. kernel launch), device-to-device,
-  // host compute, host compile, and unknown.
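-  // For instance (hypothetical percentages, not from a real profile): with
-  // input 25%, output 5%, device compute 50%, and device collectives 10%,
-  // idle_percent below comes out to max(0, 100 - 25 - 5 - 50 - 10) = 10.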
- double idle_percent = - std::max(0.0, 100.0 - input_percent - output_percent - compute_percent - - device_collectives_percent); - double kernel_launch_percent = - 100.0 * total_host_prepare_ms / total_step_time_ms; - double all_other_percent = 100.0 * total_unknown_ms / total_step_time_ms; - - std::string input_classification; - std::string input_statement; - bool all_other_reported = - InputAnalysis(input_percent, all_other_percent, &input_classification, - &input_statement); - - std::string device_collectives_classification; - std::string device_collectives_statement; - DeviceCollectivesAnalysis(device_collectives_percent, - &device_collectives_classification, - &device_collectives_statement); - - std::string kernel_launch_classification; - std::string kernel_launch_statement; - KernelLaunchAnalysis(TfDataInUse(input_time_breakdown), kernel_launch_percent, - &kernel_launch_classification, &kernel_launch_statement); - - std::string all_other_classification; - std::string all_other_statement; - AllOtherAnalysis(all_other_reported, all_other_percent, - &all_other_classification, &all_other_statement); - - BottleneckAnalysis analysis; - analysis.set_input_percent(input_percent); - analysis.set_output_percent(output_percent); - analysis.set_idle_percent(idle_percent); - analysis.set_compute_percent(compute_percent); - - analysis.set_input_classification(input_classification); - analysis.set_input_statement(input_statement); - analysis.set_kernel_launch_classification(kernel_launch_classification); - analysis.set_kernel_launch_statement(kernel_launch_statement); - analysis.set_all_other_classification(all_other_classification); - analysis.set_all_other_statement(all_other_statement); - analysis.set_device_collectives_classification( - device_collectives_classification); - analysis.set_device_collectives_statement(device_collectives_statement); - - return analysis; -} - -std::string GetSummaryNextStep(absl::string_view input_classification, - const InputTimeBreakdown& breakdown) { - std::string summary_next_step; - if (input_classification == "host" || input_classification == "both") { - if (!TfDataInUse(breakdown)) { - summary_next_step = absl::StrCat( - "Consider using ", MakeDocLink(DatasetIntroDoc(), "the tf.data API"), - " to enable profiler's host-side analysis for input pipeline. " - "Profiler currently does not support custom input pipeline (please " - "ignore " - "Section ", - kHostAnalysisSectionNumber, " below)."); - } else { - summary_next_step = - absl::StrCat("Look at Section ", kHostAnalysisSectionNumber, - " for the breakdown of input time on the host."); - } - } else { - summary_next_step = "You may skip the rest of this page."; - } - - return summary_next_step; -} - -double HostToDeviceTransferAsPercentOfInputTime( - const InputTimeBreakdown& breakdown) { - // Thanks to the scaling trick we did in GenerateHostResult(), we can - // estimate the percentage of input-time spent on host-to-device transfer in - // the following way. 
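-  // Concretely, the estimate below is
-  //   100 * enqueue_us / (demanded_file_read_us + advanced_file_read_us +
-  //                       preprocessing_us + enqueue_us +
-  //                       unclassified_non_enqueue_us),
-  // i.e. enqueue (host-to-device transfer) time as a share of total input
-  // time.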
-  double total_input_time_us =
-      breakdown.demanded_file_read_us() + breakdown.advanced_file_read_us() +
-      breakdown.preprocessing_us() + breakdown.enqueue_us() +
-      breakdown.unclassified_non_enqueue_us();
-  return 100.0 * tsl::profiler::SafeDivide(breakdown.enqueue_us(),
-                                           total_input_time_us);
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
deleted file mode 100644
index 79c874212d8da1..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
-
-#include 
-#include 
-
-#include "google/protobuf/any.pb.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/util/stats_calculator.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
-#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "tensorflow/core/profiler/protobuf/tpu_input_pipeline.pb.h"
-#include "tsl/platform/protobuf.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/input_pipeline.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/tpu_input_pipeline.pb.h"  // from @org_xprof
-#include "xprof/utils/event_span.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-struct AllReduceBreakdown {
-  uint64_t compute_duration_ps = 0;
-  uint64_t sync_duration_ps = 0;
-
-  uint64_t DurationPs() const {
-    return compute_duration_ps + sync_duration_ps;
-  }
-};
-
-// Used to store AllReduceBreakdown per core id. Just an alias for user
-// convenience.
-using PerCoreAllReduceBreakdown =
-    absl::flat_hash_map<uint32_t, AllReduceBreakdown>;
-
-// Breaks down AllReduce time into synchronization time and actual compute
-// time for each core and step.
-PerCoreAllReduceBreakdown ComputePerStepAllReduceBreakdownAcrossCores(
-    const PerCoreStepInfo& coreid_stepinfo_map);
-
-// Computes the fields in PerStepData by considering the different StepInfos
-// of the same step across cores.
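-// For example, the returned TC compute time is the minimum "optimal step
-// time" over cores, while the TC infeed time is the maximum infeed duration
-// over cores (see the implementation above for the exact aggregation).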
-PerTpuStepDetails ComputeTpuPerStepDataAcrossCores(
-    const PerCoreStepInfo& coreid_stepinfo_map,
-    const tsl::protobuf::Map<uint32_t, CoreDetails>& core_details_map);
-
-StepSummary GetStepSummaryForSampleStats(
-    const tsl::Stat<double>& sample_stats);
-
-// If the percent of input-time spent on host-to-device transfer is greater
-// than kHostToDeviceTimePercentAsSignificant, we should advise the user to
-// optimize this transfer.
-constexpr double kHostToDeviceTimePercentAsSignificant = 10.0;
-
-// If the percent of input-time spent on host-to-device transfer is greater
-// than kHostToDeviceTimePercentAsDominant, we should ONLY advise the user to
-// optimize this transfer; we won't bother to suggest optimization for
-// tf.data.
-constexpr double kHostToDeviceTimePercentAsDominant = 90.0;
-
-// Computes the summary of step time in milliseconds.
-StepSummary ComputeStepTimeSummaryInMs(
-    const tsl::protobuf::RepeatedPtrField<PerCoreStepInfo>& grouped_by_step);
-
-void GenerateHostResult(const OpMetricsDb& host_tf_metrics_db,
-                        InputPipelineAnalysisResult* result);
-
-InputPipelineAnalysisRecommendation GenerateRecommendation();
-
-// For TPU, we may have mis-regarded some host overhead as idle time.
-// This function checks if this is the case using host_step_events. If it is,
-// it performs the correction in op_stats.
-void MayFixTpuStepAnalysis(
-    const StepEvents& host_step_events, const OpMetricsDb& device_op_metrics_db,
-    StepDatabaseResult& step_db,
-    const tsl::protobuf::Map<uint32_t, CoreDetails>& core_details_map);
-
-// Returns a struct that describes the performance bottleneck of the
-// program executed on TPU.
-TpuBottleneckAnalysis ComputeTpuBottleneckAnalysis(
-    bool all_cores_profiled, const InputPipelineAnalysisResult& result);
-
-// Returns the performance bottleneck of the program executed.
-BottleneckAnalysis ComputeBottleneckAnalysis(
-    const InputTimeBreakdown& input_time_breakdown,
-    const tsl::protobuf::RepeatedPtrField<::google::protobuf::Any>&
-        any_step_details);
-
-InputPipelineAnalysisResult ConvertOpStatsToInputPipelineAnalysis(
-    const OpStats& op_stats);
-
-// Returns true if an explanation for "All Others" time is also included in
-// input_statement.
-bool InputAnalysis(double input_percent, double all_other_percent,
-                   std::string* input_classification,
-                   std::string* input_statement);
-
-void OutputAnalysis(double output_percent, std::string* output_classification,
-                    std::string* output_statement);
-
-std::string GetSummaryNextStep(absl::string_view input_classification,
-                               const InputTimeBreakdown& breakdown);
-
-// Returns the percentage of the input time that is spent on transferring the
-// data from host to device.
-double HostToDeviceTransferAsPercentOfInputTime(
-    const InputTimeBreakdown& breakdown);
-
-void AddErrorMessages(const OpStats& op_stats,
-                      InputPipelineAnalysisResult* result);
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc
deleted file mode 100644
index 663fc62ed80d83..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc
+++ /dev/null
@@ -1,205 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
-
-#include 
-#include 
-
-#include "google/protobuf/any.pb.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/tsl/profiler/utils/timespan.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "tsl/platform/protobuf.h"
-#include "xprof/utils/event_span.h"  // from @org_xprof
-#include "xprof/utils/op_metrics_db_utils.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-using ::tensorflow::profiler::CoreDetails;
-using ::tensorflow::profiler::OpMetricsDb;
-using ::tensorflow::profiler::StepDatabaseResult;
-using ::tensorflow::profiler::StepEvents;
-
-TEST(TfOpStatsToInputPipelineAnalysisTest,
-     AttributeHostInputTimeToTCWhenInfeedMissing) {
-  uint64_t step_num = 1;
-  tensorflow::profiler::StepDetails step_details;
-  step_details.AddEvent(tensorflow::profiler::EventTypeSpan(
-      tensorflow::profiler::EventType::HOST_WAIT_INPUT,
-      tsl::profiler::Timespan::FromEndPoints(50, 100)));
-  step_details.AddEvent(tensorflow::profiler::EventTypeSpan(
-      tensorflow::profiler::EventType::HOST_TO_DEVICE,
-      tsl::profiler::Timespan::FromEndPoints(110, 200)));
-  step_details.AddEvent(tensorflow::profiler::EventTypeSpan(
-      tensorflow::profiler::EventType::HOST_TO_DEVICE,
-      tsl::profiler::Timespan::FromEndPoints(430, 500)));
-  StepEvents host_step_events = {{step_num, step_details}};
-  StepDatabaseResult step_db;
-  tensorflow::profiler::PerCoreStepInfo* pcsi = step_db.add_step_sequence();
-  pcsi->set_step_num(step_num);
-  auto& sipc_map = *pcsi->mutable_step_info_per_core();
-  tensorflow::profiler::StepInfoResult& sir = sipc_map[/* core_id= */ 2];
-  sir.set_step_num(step_num);
-  sir.set_begin_ps(40);
-  sir.set_duration_ps(1000);
-  tensorflow::profiler::GenericStepBreakdown step_breakdown;
-  tsl::protobuf::Map<std::string, uint64_t>& category_ps =
-      *step_breakdown.mutable_category_ps();
-  category_ps[tensorflow::profiler::kIdle] = 300;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kMultiply)] = 300;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kAllGather)] = 300;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncStart)] = 50;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncDone)] = 50;
-  sir.mutable_step_breakdown()->PackFrom(step_breakdown);
-  tsl::protobuf::Map<uint32_t, CoreDetails> core_details_map;
-  MayFixTpuStepAnalysis(host_step_events, OpMetricsDb(), step_db,
-                        core_details_map);
-  tensorflow::profiler::GenericStepBreakdown updated_step_breakdown;
-  sir.step_breakdown().UnpackTo(&updated_step_breakdown);
-  const tsl::protobuf::Map<std::string, uint64_t>& updated_category_ps =
-      updated_step_breakdown.category_ps();
-  EXPECT_EQ(updated_category_ps.at(tensorflow::profiler::kIdle), 90);
-  ASSERT_TRUE(updated_category_ps.contains(
-      xla::HloOpcodeString(xla::HloOpcode::kInfeed)));
-  EXPECT_EQ(
-      updated_category_ps.at(xla::HloOpcodeString(xla::HloOpcode::kInfeed)),
-      210);
-}
-
-TEST(TfOpStatsToInputPipelineAnalysisTest,
-     AttributeHostInputTimeToTCWhenInfeedMissingMultiCore) {
-  uint64_t step_num = 
1;
-  tensorflow::profiler::StepDetails step_details;
-  step_details.AddEvent(tensorflow::profiler::EventTypeSpan(
-      tensorflow::profiler::EventType::HOST_WAIT_INPUT,
-      tsl::profiler::Timespan::FromEndPoints(50, 100)));
-  step_details.AddEvent(tensorflow::profiler::EventTypeSpan(
-      tensorflow::profiler::EventType::HOST_TO_DEVICE,
-      tsl::profiler::Timespan::FromEndPoints(110, 200)));
-  step_details.AddEvent(tensorflow::profiler::EventTypeSpan(
-      tensorflow::profiler::EventType::HOST_TO_DEVICE,
-      tsl::profiler::Timespan::FromEndPoints(430, 500)));
-  StepEvents host_step_events = {{step_num, step_details}};
-  StepDatabaseResult step_db;
-  tensorflow::profiler::PerCoreStepInfo* pcsi = step_db.add_step_sequence();
-  pcsi->set_step_num(step_num);
-  tsl::protobuf::Map<uint32_t, tensorflow::profiler::StepInfoResult>&
-      sipc_map = *pcsi->mutable_step_info_per_core();
-  tensorflow::profiler::StepInfoResult& sir = sipc_map[/* core_id= */ 2];
-  sir.set_step_num(step_num);
-  sir.set_begin_ps(40);
-  sir.set_duration_ps(1000);
-  tensorflow::profiler::GenericStepBreakdown step_breakdown;
-  tsl::protobuf::Map<std::string, uint64_t>& category_ps =
-      *step_breakdown.mutable_category_ps();
-  category_ps[tensorflow::profiler::kIdle] = 300;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kMultiply)] = 300;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kAllGather)] = 300;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncStart)] = 50;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncDone)] = 50;
-  sir.mutable_step_breakdown()->PackFrom(step_breakdown);
-  tensorflow::profiler::StepInfoResult& sir2 = sipc_map[/* core_id= */ 1];
-  sir2.set_step_num(step_num);
-  sir2.set_begin_ps(45);
-  sir2.set_duration_ps(900);
-  tensorflow::profiler::GenericStepBreakdown step_breakdown2;
-  tsl::protobuf::Map<std::string, uint64_t>& category_ps2 =
-      *step_breakdown2.mutable_category_ps();
-  category_ps2[tensorflow::profiler::kIdle] = 250;
-  category_ps2[xla::HloOpcodeString(xla::HloOpcode::kMultiply)] = 300;
-  category_ps2[xla::HloOpcodeString(xla::HloOpcode::kAllGather)] = 250;
-  category_ps2[xla::HloOpcodeString(xla::HloOpcode::kAsyncStart)] = 50;
-  category_ps2[xla::HloOpcodeString(xla::HloOpcode::kAsyncDone)] = 50;
-  sir2.mutable_step_breakdown()->PackFrom(step_breakdown2);
-  tsl::protobuf::Map<uint32_t, CoreDetails> core_details_map;
-  OpMetricsDb device_op_metrics_db;
-  MayFixTpuStepAnalysis(host_step_events, device_op_metrics_db, step_db,
-                        core_details_map);
-  tensorflow::profiler::GenericStepBreakdown updated_step_breakdown;
-  sir.step_breakdown().UnpackTo(&updated_step_breakdown);
-  const tsl::protobuf::Map<std::string, uint64_t>& updated_category_ps =
-      updated_step_breakdown.category_ps();
-  EXPECT_EQ(updated_category_ps.at(tensorflow::profiler::kIdle), 48);
-  ASSERT_TRUE(updated_category_ps.contains(
-      xla::HloOpcodeString(xla::HloOpcode::kInfeed)));
-  EXPECT_EQ(
-      updated_category_ps.at(xla::HloOpcodeString(xla::HloOpcode::kInfeed)),
-      252);
-  tensorflow::profiler::GenericStepBreakdown updated_step_breakdown2;
-  sir2.step_breakdown().UnpackTo(&updated_step_breakdown2);
-  const tsl::protobuf::Map<std::string, uint64_t>& updated_category_ps2 =
-      updated_step_breakdown2.category_ps();
-  EXPECT_EQ(updated_category_ps2.at(tensorflow::profiler::kIdle), 40);
-  ASSERT_TRUE(updated_category_ps2.contains(
-      xla::HloOpcodeString(xla::HloOpcode::kInfeed)));
-  EXPECT_EQ(
-      updated_category_ps2.at(xla::HloOpcodeString(xla::HloOpcode::kInfeed)),
-      210);
-}
-
-TEST(TfOpStatsToInputPipelineAnalysisTest,
-     SkipMayFixTpuStepAnalysisWhenInfeedExists) {
-  uint64_t step_num = 1;
-  tensorflow::profiler::StepDetails step_details;
-  
step_details.AddEvent(tensorflow::profiler::EventTypeSpan(
-      tensorflow::profiler::EventType::HOST_WAIT_INPUT,
-      tsl::profiler::Timespan::FromEndPoints(50, 100)));
-  step_details.AddEvent(tensorflow::profiler::EventTypeSpan(
-      tensorflow::profiler::EventType::HOST_TO_DEVICE,
-      tsl::profiler::Timespan::FromEndPoints(110, 200)));
-  step_details.AddEvent(tensorflow::profiler::EventTypeSpan(
-      tensorflow::profiler::EventType::HOST_TO_DEVICE,
-      tsl::profiler::Timespan::FromEndPoints(430, 500)));
-  StepEvents host_step_events = {{step_num, step_details}};
-  StepDatabaseResult step_db;
-  tensorflow::profiler::PerCoreStepInfo* pcsi = step_db.add_step_sequence();
-  pcsi->set_step_num(step_num);
-  tsl::protobuf::Map<uint32_t, tensorflow::profiler::StepInfoResult>&
-      sipc_map = *pcsi->mutable_step_info_per_core();
-  tensorflow::profiler::StepInfoResult& sir = sipc_map[/* core_id= */ 2];
-  sir.set_step_num(step_num);
-  sir.set_begin_ps(40);
-  sir.set_duration_ps(1000);
-  tensorflow::profiler::GenericStepBreakdown step_breakdown;
-  tsl::protobuf::Map<std::string, uint64_t>& category_ps =
-      *step_breakdown.mutable_category_ps();
-  category_ps[tensorflow::profiler::kIdle] = 300;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kMultiply)] = 300;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kAllGather)] = 300;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncStart)] = 50;
-  category_ps[xla::HloOpcodeString(xla::HloOpcode::kInfeed)] = 50;
-  sir.mutable_step_breakdown()->PackFrom(step_breakdown);
-  tsl::protobuf::Map<uint32_t, CoreDetails> core_details_map;
-  OpMetricsDb device_op_metrics_db;
-  device_op_metrics_db.add_metrics_db()->set_category(
-      std::string(xla::HloOpcodeString(xla::HloOpcode::kInfeed)));
-  MayFixTpuStepAnalysis(host_step_events, device_op_metrics_db, step_db,
-                        core_details_map);
-  tensorflow::profiler::GenericStepBreakdown updated_step_breakdown;
-  sir.step_breakdown().UnpackTo(&updated_step_breakdown);
-  const tsl::protobuf::Map<std::string, uint64_t>& updated_category_ps =
-      updated_step_breakdown.category_ps();
-  EXPECT_EQ(updated_category_ps.at(tensorflow::profiler::kIdle), 300);
-  EXPECT_EQ(
-      updated_category_ps.at(xla::HloOpcodeString(xla::HloOpcode::kInfeed)),
-      50);
-}
-
-}  // namespace
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc b/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc
deleted file mode 100644
index 6e3119e1b3931d..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc
+++ /dev/null
@@ -1,103 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_op_profile.h" - -#include -#include - -#include "absl/log/check.h" -#include "absl/strings/match.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/convert/op_profile_builder.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/hardware_types.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tensorflow::profiler::IsIdleOp; -using ::tensorflow::profiler::OpMetrics; -using ::tensorflow::profiler::OpProfileBuilder; -using ::tensorflow::profiler::OpProfileOptions; -using ::tensorflow::profiler::OpStats; -using ::tensorflow::profiler::TotalTimePs; -using ::tensorflow::profiler::op_profile::Node; - -void BuildOpProfileNodeTree(const OpStats& op_stats, bool group_by_program, - bool exclude_idle_ops, int op_profile_limit, - Node* root) { - const auto& metrics_db = op_stats.device_op_metrics_db(); - if (metrics_db.metrics_db().empty()) return; - - OpProfileOptions options = {group_by_program, - /*group_by_deduplicated_name=*/true, - /*children_per_node=*/op_profile_limit}; - OpProfileBuilder builder(options, root, &op_stats.program_id_to_name_map()); - - for (const OpMetrics& op_metrics : metrics_db.metrics_db()) { - DCHECK(!op_metrics.name().empty()); - // Don't add ops that cannot be symbolized. 
-    if (absl::StartsWith(op_metrics.name(), "region")) continue;
-    if (exclude_idle_ops && IsIdleOp(op_metrics)) continue;
-    builder.AddOp(op_metrics);
-  }
-
-  const auto& perf_env = op_stats.perf_env();
-  double max_gigaflops_per_second_per_core =
-      tsl::profiler::TeraToGiga(perf_env.peak_tera_flops_per_second());
-  std::vector<double> peak_bws;
-  for (auto bw : perf_env.peak_bws_giga_bytes_per_second()) {
-    peak_bws.push_back(tsl::profiler::GigaToGibi(bw));
-  }
-  builder.Finalize(max_gigaflops_per_second_per_core, peak_bws,
-                   TotalTimePs(metrics_db, exclude_idle_ops));
-}
-
-}  // namespace
-
-void ConvertOpStatsToOpProfile(
-    const OpStats& op_stats, tensorflow::profiler::HardwareType hardware_type,
-    tensorflow::profiler::op_profile::Profile& profile, int op_profile_limit) {
-  profile.set_device_type(HardwareType_Name(hardware_type));
-  BuildOpProfileNodeTree(op_stats,
-                         /*group_by_program=*/false,
-                         /*exclude_idle_ops=*/false, op_profile_limit,
-                         profile.mutable_by_category());
-
-  BuildOpProfileNodeTree(op_stats,
-                         /*group_by_program=*/false,
-                         /*exclude_idle_ops=*/true, op_profile_limit,
-                         profile.mutable_by_category_exclude_idle());
-
-  BuildOpProfileNodeTree(op_stats,
-                         /*group_by_program=*/true,
-                         /*exclude_idle_ops=*/false, op_profile_limit,
-                         profile.mutable_by_program());
-
-  BuildOpProfileNodeTree(op_stats,
-                         /*group_by_program=*/true,
-                         /*exclude_idle_ops=*/true, op_profile_limit,
-                         profile.mutable_by_program_exclude_idle());
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_to_op_profile.h b/tensorflow/core/profiler/convert/op_stats_to_op_profile.h
deleted file mode 100644
index 1fcfefb510d454..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_to_op_profile.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OP_PROFILE_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OP_PROFILE_H_
-
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_profile.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-
-// Assembles a hierarchical performance profile based on HLOs in the op
-// metrics db. The node hierarchy is as follows:
-//   by_category
-//     - combined_root
-//       - category 1
-//       - category 2
-//       - ...
-//       - idle
-//   by_program
-//     - program_1_root
-//       - category 1
-//       - category 2
-//       - ...
-//     - program_2_root
-//       - category 1
-//       - ...
-//     - idle
-// The nodes in the profile are sorted by time in decreasing order and pruned
-// to reduce the profile size. Only 100 nodes are kept for level >= 3.
-// See op_profile.proto for the detailed semantics of the returned profile.
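-//
-// Example use (sketch; obtaining a populated OpStats is assumed to happen
-// elsewhere in the profile-processing pipeline):
-//   tensorflow::profiler::op_profile::Profile profile;
-//   ConvertOpStatsToOpProfile(op_stats, tensorflow::profiler::TPU, profile);
-//   // profile.by_category() / profile.by_program() now hold the node trees
-//   // with idle time included; the *_exclude_idle variants omit it.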
-void ConvertOpStatsToOpProfile(
-    const tensorflow::profiler::OpStats& op_stats,
-    tensorflow::profiler::HardwareType hardware_type,
-    tensorflow::profiler::op_profile::Profile& profile,
-    int op_profile_limit = 100);
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OP_PROFILE_H_
diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
deleted file mode 100644
index f582933c782aeb..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
+++ /dev/null
@@ -1,408 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h"
-
-#include 
-#include 
-#include 
-#include 
-
-#include "google/protobuf/any.pb.h"
-#include "absl/algorithm/container.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/format_utils.h"
-#include "xla/tsl/profiler/utils/math_utils.h"
-#include "xla/tsl/profiler/utils/tf_op_utils.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
-#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
-#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
-#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/overview_page.pb.h"
-#include "tensorflow/core/profiler/protobuf/power_metrics.pb.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/overview_page.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/power_metrics.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h"  // from @org_xprof
-#include "xprof/utils/diagnostics.h"  // from @org_xprof
-#include "xprof/utils/hardware_type_utils.h"  // from @org_xprof
-#include "xprof/utils/html_utils.h"  // from @org_xprof
-#include "xprof/utils/kernel_stats_utils.h"  // from @org_xprof
-#include "xprof/utils/op_metrics_db_utils.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-namespace {
-
-using tsl::profiler::OneDigit;
-
-// If the use of low-precision ops is less than this percentage threshold, a
-// suggestion will be made.
-constexpr double kLowPrecisionPercentThreshold = 10;
-
-struct TfFunctionInfo {
-  absl::string_view function_name;
-  double expensive_call_percent;
-};
-
-OverviewPageTip MakeOverviewPageTip(std::string text) {
-  OverviewPageTip tip;
-  tip.set_link(std::move(text));
-  return tip;
-}
-
-// Makes a recommendation for looking up a document.
-// doc_url is expected to already be escaped suitably for use in an HTML
-// attribute.
-OverviewPageTip MakeOverviewPageTipDocLink(absl::string_view doc_url,
-                                           absl::string_view text) {
-  return MakeOverviewPageTip(AnchorElement(doc_url, text));
-}
-
-void ComputeHostTips(OverviewPageRecommendation* re) {
-  *re->add_host_tips() = MakeOverviewPageTip(
-      "input_pipeline_analyzer (especially Section 3 for the breakdown of "
-      "input operations on the Host)");
-  *re->add_host_tips() = MakeOverviewPageTip(
-      "trace_viewer (look at the activities on the timeline of each Host "
-      "Thread near the bottom of the trace view)");
-}
-
-void ComputeDeviceTips(HardwareType hardware_type,
-                       OverviewPageRecommendation* re) {
-  absl::string_view device_name = HardwareType_Name(hardware_type);
-  absl::string_view timeline_name = device_name;
-  absl::string_view op_stats_toolname = "framework_op_stats";
-  if (hardware_type == tensorflow::profiler::TPU) {
-    timeline_name = "TPU core";
-    op_stats_toolname = "op_profile";
-  }
-  *re->add_device_tips() = MakeOverviewPageTip(
-      absl::StrCat(op_stats_toolname,
-                   " (identify the time-consuming operations "
-                   "executed on the ",
-                   device_name, ")"));
-  *re->add_device_tips() = MakeOverviewPageTip(absl::StrCat(
-      "trace_viewer (look at the activities on the timeline of each ",
-      timeline_name, " in the trace view)"));
-}
-
-void ComputeFaqTips(OverviewPageRecommendation* re) {
-  *re->add_faq_tips() = MakeOverviewPageTip("Refer to the TF2 Profiler FAQ");
-}
-
-void ComputeDocumentationTips(OverviewPageRecommendation* re) {
-  *re->add_documentation_tips() = MakeOverviewPageTipDocLink(
-      "https://www.tensorflow.org/guide/data_performance_analysis",
-      "Analyze tf.data performance with the TF Profiler");
-  *re->add_documentation_tips() = MakeOverviewPageTipDocLink(
-      "https://www.tensorflow.org/guide/"
-      "data_performance",
-      "Better performance with the tf.data API");
-}
-
-std::string GeneratePrecisionStatement(const PrecisionStats& precision_stats) {
-  uint64 total_compute_ps =
-      precision_stats.compute_16bit_ps() + precision_stats.compute_32bit_ps();
-  if (total_compute_ps > 0) {
-    double percent_16bit =
-        (100.0 * precision_stats.compute_16bit_ps()) / total_compute_ps;
-    if (percent_16bit < kLowPrecisionPercentThreshold) {
-      return absl::StrCat(
-          "Only ", OneDigit(percent_16bit),
-          "% of device computation is 16 bit. 
So you might want to replace " - "more 32-bit Ops by 16-bit Ops to improve performance (if the " - "reduced accuracy is acceptable)."); - } - } - return ""; -} - -} // namespace - -void SetCommonRecommendation( - absl::string_view input_classification, absl::string_view input_statement, - absl::string_view output_statement, HardwareType hardware_type, - absl::string_view tf_function_statement_html, - absl::string_view eager_statement_html, - absl::string_view outside_compilation_statement_html, - OverviewPageRecommendation* re) { - re->set_bottleneck(std::string(input_classification)); - re->set_statement(std::string(input_statement)); - re->set_output_statement(std::string(output_statement)); - re->set_tf_function_statement_html(std::string(tf_function_statement_html)); - re->set_eager_statement_html(std::string(eager_statement_html)); - re->set_outside_compilation_statement_html( - std::string(outside_compilation_statement_html)); - ComputeHostTips(re); - ComputeDeviceTips(hardware_type, re); - ComputeDocumentationTips(re); - ComputeFaqTips(re); -} - -OverviewPageRecommendation ComputeGenericRecommendation( - const BottleneckAnalysis& bottleneck, - const PrecisionStats& precision_stats) { - OverviewPageRecommendation re; - GenericRecommendation generic; - generic.set_device_collectives_bottleneck( - bottleneck.device_collectives_classification()); - generic.set_device_collectives_statement( - bottleneck.device_collectives_statement()); - generic.set_kernel_launch_bottleneck( - bottleneck.kernel_launch_classification()); - generic.set_kernel_launch_statement(bottleneck.kernel_launch_statement()); - generic.set_all_other_bottleneck(bottleneck.all_other_classification()); - generic.set_all_other_statement(bottleneck.all_other_statement()); - generic.set_precision_statement(GeneratePrecisionStatement(precision_stats)); - re.mutable_recommendation()->PackFrom(generic); - return re; -} - -OverviewPageAnalysis ComputeAnalysisResult(const OpStats& op_stats) { - OverviewPageAnalysis analysis; - OpMetricsDb device_tf_op_metrics_db = CreateTfMetricsDbFromDeviceOpMetricsDb( - op_stats.device_op_metrics_db(), /*with_idle=*/false); - KernelStatsByOpName kernel_stats_by_op_name = - GroupKernelReportsByOpName(op_stats.kernel_stats_db()); - uint64 total_device_time_ps = device_tf_op_metrics_db.total_time_ps(); - constexpr int kNumTopOpsShown = 10; - double device_cumulative_fraction = 0.0; - for (const OpMetrics* metrics : - SortedOpMetricsDb(device_tf_op_metrics_db, kNumTopOpsShown)) { - OverviewTfOp* op = analysis.add_top_device_ops(); - op->set_name(metrics->name()); - op->set_category(metrics->category()); - op->set_self_time_fraction(tsl::profiler::SafeDivide( - metrics->self_time_ps(), total_device_time_ps)); - device_cumulative_fraction += op->self_time_fraction(); - op->set_cumulative_time_fraction(device_cumulative_fraction); - op->set_flop_rate(tsl::profiler::SafeDivide( - metrics->flops(), tsl::profiler::PicoToNano(metrics->time_ps()))); - auto iter = kernel_stats_by_op_name.find(op->name()); - if (iter != kernel_stats_by_op_name.end()) { - op->set_is_op_tensorcore_eligible( - iter->second.is_op_tensor_core_eligible); - op->set_is_op_using_tensorcore(iter->second.tensor_core_duration_ns != 0); - } - } - uint64 total_device_compute_ps = - op_stats.device_op_metrics_db().precision_stats().compute_16bit_ps() + - op_stats.device_op_metrics_db().precision_stats().compute_32bit_ps(); - analysis.set_device_compute_16bit_percent( - 100.0 * - tsl::profiler::SafeDivide( - 
op_stats.device_op_metrics_db().precision_stats().compute_16bit_ps(), - total_device_compute_ps)); - analysis.set_device_compute_32bit_percent( - 100.0 * - tsl::profiler::SafeDivide( - op_stats.device_op_metrics_db().precision_stats().compute_32bit_ps(), - total_device_compute_ps)); - - uint64 num_host_tf_ops = 0; - uint64 total_host_op_time_ps_exclude_idle = 0; - uint64 eager_host_op_time_ps = 0; - for (const OpMetrics& metrics : op_stats.host_op_metrics_db().metrics_db()) { - num_host_tf_ops += metrics.occurrences(); - if (!IsIdleOp(metrics)) { - total_host_op_time_ps_exclude_idle += metrics.self_time_ps(); - if (metrics.is_eager()) eager_host_op_time_ps += metrics.self_time_ps(); - } - } - uint64 num_device_tf_ops = 0; - uint64 total_device_op_time_ps_exclude_idle = 0; - uint64 eager_device_op_time_ps = 0; - for (const OpMetrics& metrics : device_tf_op_metrics_db.metrics_db()) { - num_device_tf_ops += metrics.occurrences(); - if (!IsIdleOp(metrics)) { - total_device_op_time_ps_exclude_idle += metrics.self_time_ps(); - if (metrics.is_eager()) eager_device_op_time_ps += metrics.self_time_ps(); - } - } - // Figures out outside_compilation time from - // op_stats.device_op_metrics_db().metrics_db(). We don't use the - // {metrics.provenance(), metrics.name()} from - // device_tf_op_metrics_db.metrics_db(), because metrics.provenance() there is - // not set and metrics.name() can be either HLO-Op name or TF-Op name, which - // will confuse tsl::profiler::IsOutsideCompilationOp(). - uint64 outside_compilation_device_op_time_ps = 0; - for (const OpMetrics& metrics : - op_stats.device_op_metrics_db().metrics_db()) { - if (!tsl::profiler::IsOutsideCompilationOp(metrics.provenance(), - metrics.long_name())) - continue; - outside_compilation_device_op_time_ps += metrics.self_time_ps(); - } - uint64 num_total_tf_ops = num_host_tf_ops + num_device_tf_ops; - analysis.set_host_tf_op_percent( - 100.0 * tsl::profiler::SafeDivide(num_host_tf_ops, num_total_tf_ops)); - analysis.set_device_tf_op_percent( - 100.0 * tsl::profiler::SafeDivide(num_device_tf_ops, num_total_tf_ops)); - analysis.set_host_trace_level(op_stats.run_environment().host_trace_level()); - analysis.set_host_op_time_eager_percent( - 100.0 * tsl::profiler::SafeDivide(eager_host_op_time_ps, - total_host_op_time_ps_exclude_idle)); - analysis.set_device_op_time_eager_percent( - 100.0 * tsl::profiler::SafeDivide(eager_device_op_time_ps, - total_device_op_time_ps_exclude_idle)); - analysis.set_device_op_time_outside_compilation_percent( - 100.0 * tsl::profiler::SafeDivide(outside_compilation_device_op_time_ps, - total_device_op_time_ps_exclude_idle)); - return analysis; -} - -// Converts from HostIndependentJobInfo to OverviewPageHostIndependentJobInfo. -OverviewPageHostIndependentJobInfo ToOverviewPageHostIndependentJobInfo( - const HostIndependentJobInfoResult& host_independent_job_info) { - OverviewPageHostIndependentJobInfo result; - result.set_change_list(host_independent_job_info.change_list()); - result.set_build_time(host_independent_job_info.build_time()); - result.set_build_target(host_independent_job_info.build_target()); - result.set_profile_duration_ms( - host_independent_job_info.profile_duration_ms()); - return result; -} - -// Converts from HostDependentJobInfo to OverviewPageHostDependentJobInfo. 
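A worked instance of the eager-time percentages computed in ComputeAnalysisResult above, with hypothetical numbers: if total_host_op_time_ps_exclude_idle = 400,000 and eager_host_op_time_ps = 100,000, then host_op_time_eager_percent = 100.0 * SafeDivide(100,000, 400,000) = 25.0. SafeDivide returns 0 when the divisor is (near) zero, so a trace with no non-idle host ops reports 0% instead of dividing by zero; the device-eager and outside-compilation percentages follow the same pattern.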
-OverviewPageHostDependentJobInfo ToOverviewPageHostDependentJobInfo(
-    const HostDependentJobInfoResult& host_dependent_job_info) {
-  OverviewPageHostDependentJobInfo result;
-  result.set_host_id(host_dependent_job_info.host_id());
-  result.set_command_line(host_dependent_job_info.command_line());
-  result.set_start_time(host_dependent_job_info.start_time());
-  result.set_bns_address(host_dependent_job_info.bns_address());
-  result.set_profile_time_ns(host_dependent_job_info.profile_time_ns());
-  return result;
-}
-
-OverviewPageRunEnvironment ComputeRunEnvironment(
-    const RunEnvironment& run_environment) {
-  OverviewPageRunEnvironment re;
-  re.set_host_count(run_environment.host_count());
-  re.set_task_count(run_environment.task_count());
-  re.set_device_type(run_environment.device_type());
-  re.set_device_core_count(run_environment.device_core_count());
-  re.set_replica_count(run_environment.replica_count());
-  re.set_num_cores_per_replica(run_environment.num_cores_per_replica());
-  re.set_is_training(run_environment.is_training());
-  if (run_environment.has_power_metrics()) {
-    *re.mutable_power_metrics() = run_environment.power_metrics();
-  }
-  *re.mutable_host_independent_job_info() =
-      ToOverviewPageHostIndependentJobInfo(
-          run_environment.host_independent_job_info());
-  for (const auto& host_dependent_job_info :
-       run_environment.host_dependent_job_info()) {
-    *re.add_host_dependent_job_info() =
-        ToOverviewPageHostDependentJobInfo(host_dependent_job_info);
-  }
-  return re;
-}
-
-std::string TfFunctionRecommendationHtml(const TfFunctionDb& tf_function_db) {
-  std::vector<TfFunctionInfo> candidates;
-  for (const auto& name_fun : tf_function_db.tf_functions()) {
-    const auto& fun = name_fun.second;
-    if (fun.expensive_call_percent() >= kTfFunctionReportThresholdInPercent) {
-      candidates.push_back({name_fun.first, fun.expensive_call_percent()});
-    }
-  }
-  if (candidates.empty()) return "";
-  auto cmp = [](const TfFunctionInfo& a, const TfFunctionInfo& b) {
-    return a.expensive_call_percent > b.expensive_call_percent;
-  };
-  // Sorts candidates in descending order of expensive_call_percent.
-  absl::c_sort(candidates, cmp);
-  std::string expensive_functions = "";
-  auto num_functions_shown = std::min(
-      static_cast<decltype(candidates)::size_type>(3), candidates.size());
-
-  for (decltype(candidates)::size_type i = 0; i < num_functions_shown; i++) {
-    if (i > 0) absl::StrAppend(&expensive_functions, ", ");
-    absl::StrAppend(&expensive_functions, "\"", candidates[i].function_name,
-                    "\"");
-  }
-  if (candidates.size() > num_functions_shown)
-    absl::StrAppend(&expensive_functions, " and more");
-  return absl::StrCat("Expensive tf-functions detected (", expensive_functions,
-                      ") due to either retracing or eager execution.");
-}
-
-std::string EagerRecommendationHtml(double host_op_time_eager_percent,
-                                    double device_op_time_eager_percent) {
-  std::string recommendation = "";
-  if (host_op_time_eager_percent > kEagerReportThresholdInPercent)
-    absl::StrAppend(&recommendation, OneDigit(host_op_time_eager_percent),
-                    "% of Op time on the host used eager execution. ");
-  if (device_op_time_eager_percent > kEagerReportThresholdInPercent)
-    absl::StrAppend(&recommendation, OneDigit(device_op_time_eager_percent),
-                    "% of Op time on the device used eager execution. ");
"); - if (!recommendation.empty()) - absl::StrAppend(&recommendation, "Performance could be improved with ", - AnchorElement("https://www.tensorflow.org/guide/function", - "tf.function.")); - return recommendation; -} - -std::string OutsideCompilationRecommendationHtml( - double device_op_time_outside_compilation_percent) { - if (device_op_time_outside_compilation_percent <= - kOutsideCompilationThresholdInPercent) - return ""; - return absl::StrCat( - OneDigit(device_op_time_outside_compilation_percent), - " % of Op time on the device are for outside compilation. Performance " - "could be improved by avoiding outside compilation."); -} - -OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats) { - OverviewPage overview_page; - *overview_page.mutable_run_environment() = - ComputeRunEnvironment(op_stats.run_environment()); - *overview_page.mutable_analysis() = ComputeAnalysisResult(op_stats); - *overview_page.mutable_input_analysis() = - ConvertOpStatsToInputPipelineAnalysis(op_stats); - BottleneckAnalysis bottleneck = ComputeBottleneckAnalysis( - overview_page.input_analysis().input_time_breakdown(), - overview_page.input_analysis().step_details()); - *overview_page.mutable_recommendation() = ComputeGenericRecommendation( - bottleneck, op_stats.device_op_metrics_db().precision_stats()); - SetCommonRecommendation( - bottleneck.input_classification(), bottleneck.input_statement(), "", - ParseHardwareType(op_stats.run_environment().device_type()), - TfFunctionRecommendationHtml(op_stats.tf_function_db()), - EagerRecommendationHtml( - overview_page.analysis().host_op_time_eager_percent(), - overview_page.analysis().device_op_time_eager_percent()), - OutsideCompilationRecommendationHtml( - overview_page.analysis() - .device_op_time_outside_compilation_percent()), - overview_page.mutable_recommendation()); - PopulateOverviewDiagnostics(op_stats, overview_page.mutable_diagnostics()); - overview_page.mutable_analysis()->set_mxu_utilization_percent( - op_stats.performance_counter_result().matrix_unit_utilization_percent()); - return overview_page; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h b/tensorflow/core/profiler/convert/op_stats_to_overview_page.h deleted file mode 100644 index ba6d906e325d96..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_
-
-#include <string>
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
-#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/overview_page.pb.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-
-// Reports tf-function optimization opportunity in the Overview Page if the
-// expensive-call-time percentage is over this threshold for at least one of
-// the tf-functions profiled.
-const double kTfFunctionReportThresholdInPercent = 20;
-
-// Reports eager-mode optimization opportunity in the Overview Page if the
-// percent of Op time on host (or device) that is spent on eager mode is over
-// this threshold.
-const double kEagerReportThresholdInPercent = 10;
-
-// Reports outside-compilation opportunity in the Overview Page if the
-// percent of Op time on device that is for outside compilation is over
-// this threshold.
-const double kOutsideCompilationThresholdInPercent = 5;
-
-void SetCommonRecommendation(
-    absl::string_view input_classification, absl::string_view input_statement,
-    absl::string_view output_statement, HardwareType hardware_type,
-    absl::string_view tf_function_statement_html,
-    absl::string_view eager_statement_html,
-    absl::string_view outside_compilation_statement_html,
-    OverviewPageRecommendation* re);
-
-OverviewPageRecommendation ComputeGenericRecommendation(
-    const BottleneckAnalysis& bottleneck,
-    const PrecisionStats& precision_stats);
-
-OverviewPageAnalysis ComputeAnalysisResult(const OpStats& op_stats);
-
-OverviewPageRunEnvironment ComputeRunEnvironment(
-    const RunEnvironment& run_environment);
-
-OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats);
-
-// Returns an HTML string that provides a tf-function related recommendation.
-std::string TfFunctionRecommendationHtml(const TfFunctionDb& tf_function_db);
-
-// Returns an HTML string that provides an eager-mode related recommendation.
-std::string EagerRecommendationHtml(double host_op_time_eager_percent,
-                                    double device_op_time_eager_percent);
-
-// Returns an HTML string that provides an outside-compilation related
-// recommendation.
-std::string OutsideCompilationRecommendationHtml(
-    double device_op_time_outside_compilation_percent);
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_
diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc
deleted file mode 100644
index 13fcef0ca25dec..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc
+++ /dev/null
@@ -1,108 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h"
-
-#include <algorithm>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "google/protobuf/any.pb.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/math_utils.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "xprof/utils/diagnostics.h"  // from @org_xprof
-#include "xprof/utils/event_span.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-namespace {
-
-PodStatsRecord CreatePodStatsRecord(absl::string_view host_name,
-                                    const StepInfoResult& step_info) {
-  PodStatsRecord record;
-  GenericStepBreakdown generic;
-  bool success = step_info.step_breakdown().UnpackTo(&generic);
-  DCHECK(success);
-  record.set_host_name(string(host_name));
-  record.set_step_num(step_info.step_num());
-  record.set_total_duration_us(
-      tsl::profiler::PicoToMicro(step_info.duration_ps()));
-  auto& step_breakdown_map = *record.mutable_step_breakdown_us();
-  std::vector<std::pair<uint64, absl::string_view>> metrics;
-
-  auto add_event = [&](GenericEventType type,
-                       std::initializer_list<EventType> event_list) {
-    uint64 ps = 0;
-    for (const auto& event_type : event_list) {
-      ps += gtl::FindWithDefault(generic.type_ps(), event_type, /*value=*/0);
-    }
-    step_breakdown_map[type] = tsl::profiler::PicoToMicro(ps);
-    metrics.emplace_back(ps, GetGenericEventTypeStr(type));
-  };
-
-  add_event(kDeviceCompute, {DEVICE_COMPUTE_32, DEVICE_COMPUTE_16});
-  add_event(kDeviceToDevice, {DEVICE_TO_DEVICE, DEVICE_WAIT_DEVICE});
-  add_event(kDeviceCollectives, {DEVICE_COLLECTIVES});
-  add_event(kHostCompute, {HOST_COMPUTE});
-  add_event(kHostPrepare, {HOST_PREPARE});
-  add_event(kInput, {HOST_WAIT_INPUT, HOST_TO_DEVICE, DEVICE_WAIT_HOST});
-  add_event(kOutput, {DEVICE_TO_HOST});
-  add_event(kCompile, {HOST_COMPILE});
-  add_event(kAllOthers, {UNKNOWN_TIME});
-
-  std::sort(metrics.begin(), metrics.end());
-  record.set_bottleneck(metrics.back().second.data(),
-                        metrics.back().second.size());
-  return record;
-}
-
-}  // namespace
-
-PodStatsDatabase ConvertOpStatsToPodStats(const OpStats& op_stats) {
-  PodStatsDatabase pod_stats_db;
-  const auto& core_id_map = op_stats.core_id_to_details();
-  for (int i = GenericEventType::kFirstGenericEventType;
-       i <= GenericEventType::kLastGenericEventType; i++) {
-    auto& event = *pod_stats_db.add_step_breakdown_events();
-    event.set_id(i);
-    absl::string_view type_str =
-        GetGenericEventTypeStr(static_cast<GenericEventType>(i));
-    event.set_name(type_str.data(), type_str.size());
-  }
-
-  for (const auto& step_sequence : op_stats.step_db().step_sequence()) {
-    for (const auto& entry : step_sequence.step_info_per_core()) {
-      if (!core_id_map.contains(entry.first)) {
-        LOG(WARNING) << "core_id_map does not contain " << entry.first;
-        continue;
-      }
-      const CoreDetails& details = core_id_map.at(entry.first);
-      *pod_stats_db.add_pod_stats_record() =
-          CreatePodStatsRecord(details.hostname(), entry.second);
-    }
-  }
- 
PopulateStepDiagnostics(op_stats, pod_stats_db.mutable_diagnostics()); - return pod_stats_db; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h deleted file mode 100644 index bd3d74068d8a6c..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ - -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -PodStatsDatabase ConvertOpStatsToPodStats(const OpStats& op_stats); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc deleted file mode 100644 index 15a8caf84af5fe..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h" - -#include "google/protobuf/any.pb.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/pod_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof -#include "xprof/utils/diagnostics.h" // from @org_xprof -#include "xprof/utils/event_span.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -const double kMaxError = 1e-6; -constexpr int kStepNum = 2; -constexpr int kCoreId = 1001; -constexpr int kStepTimePs = 1000; -constexpr int kHostComputePs = 50; -constexpr int kHostCompilePs = 50; -constexpr int kHostToHostPs = 50; -constexpr int kHostToDevicePs = 50; -constexpr int kHostPreparePs = 50; -constexpr int kDeviceCollectivePs = 350; -constexpr int kHostWaitInputPs = 50; -constexpr int kDeviceToDevicePs = 50; -constexpr int kDeviceToHostPs = 50; -constexpr int kDeviceCompute32Ps = 50; -constexpr int kDeviceCompute16Ps = 50; -constexpr int kDeviceWaitDevicePs = 50; -constexpr int kDeviceWaitHostPs = 50; -constexpr int kUnknownTimePs = 50; -static constexpr char kHostname[] = "host:123"; - -void CreateOpStats(OpStats* op_stats) { - PerCoreStepInfo* info = op_stats->mutable_step_db()->add_step_sequence(); - info->set_step_num(kStepNum); - StepInfoResult& step_info = (*info->mutable_step_info_per_core())[kCoreId]; - step_info.set_step_num(kStepNum); - step_info.set_duration_ps(kStepTimePs); - GenericStepBreakdown breakdown; - auto& type_ps = *breakdown.mutable_type_ps(); - type_ps[HOST_COMPUTE] = kHostComputePs; - type_ps[HOST_COMPILE] = kHostCompilePs; - type_ps[HOST_TO_HOST] = kHostToHostPs; - type_ps[HOST_TO_DEVICE] = kHostToDevicePs; - type_ps[HOST_PREPARE] = kHostPreparePs; - type_ps[DEVICE_COLLECTIVES] = kDeviceCollectivePs; - type_ps[HOST_WAIT_INPUT] = kHostWaitInputPs; - type_ps[DEVICE_TO_DEVICE] = kDeviceToDevicePs; - type_ps[DEVICE_TO_HOST] = kDeviceToHostPs; - type_ps[DEVICE_COMPUTE_32] = kDeviceCompute32Ps; - type_ps[DEVICE_COMPUTE_16] = kDeviceCompute16Ps; - type_ps[DEVICE_WAIT_DEVICE] = kDeviceWaitDevicePs; - type_ps[DEVICE_WAIT_HOST] = kDeviceWaitHostPs; - type_ps[UNKNOWN_TIME] = kUnknownTimePs; - step_info.mutable_step_breakdown()->PackFrom(breakdown); - CoreDetails& details = (*op_stats->mutable_core_id_to_details())[kCoreId]; - details.set_hostname(kHostname); -} - -TEST(OpStatsToPodStats, GpuPodStats) { - OpStats op_stats; - CreateOpStats(&op_stats); - PodStatsDatabase pod_stats_db = ConvertOpStatsToPodStats(op_stats); - EXPECT_EQ(1, pod_stats_db.pod_stats_record_size()); - const PodStatsRecord& record = pod_stats_db.pod_stats_record(0); - EXPECT_EQ(kStepNum, record.step_num()); - EXPECT_EQ(kHostname, record.host_name()); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kStepTimePs), - record.total_duration_us(), kMaxError); - const auto& breakdown = record.step_breakdown_us(); - EXPECT_NEAR( - tsl::profiler::PicoToMicro(kDeviceCompute32Ps + kDeviceCompute16Ps), - breakdown.at(kDeviceCompute), kMaxError); - EXPECT_NEAR( - tsl::profiler::PicoToMicro(kDeviceToDevicePs + 
kDeviceWaitDevicePs),
-      breakdown.at(kDeviceToDevice), kMaxError);
-  EXPECT_NEAR(tsl::profiler::PicoToMicro(kDeviceCollectivePs),
-              breakdown.at(kDeviceCollectives), kMaxError);
-  EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostComputePs),
-              breakdown.at(kHostCompute), kMaxError);
-  EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostPreparePs),
-              breakdown.at(kHostPrepare), kMaxError);
-  EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostWaitInputPs + kHostToDevicePs +
-                                         kDeviceWaitHostPs),
-              breakdown.at(kInput), kMaxError);
-  EXPECT_NEAR(tsl::profiler::PicoToMicro(kDeviceToHostPs),
-              breakdown.at(kOutput), kMaxError);
-  EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostCompilePs),
-              breakdown.at(kCompile), kMaxError);
-  EXPECT_NEAR(tsl::profiler::PicoToMicro(kUnknownTimePs),
-              breakdown.at(kAllOthers), kMaxError);
-
-  EXPECT_EQ(GetGenericEventTypeStr(kDeviceCollectives), record.bottleneck());
-}
-
-TEST(OpStatsToPodStats, Diagnostics) {
-  OpStats op_stats;
-  op_stats.mutable_step_db()->set_use_incomplete_step(true);
-  PodStatsDatabase pod_stats_db = ConvertOpStatsToPodStats(op_stats);
-  EXPECT_EQ(1, pod_stats_db.diagnostics().warnings_size());
-  EXPECT_EQ(kErrorIncompleteStep, pod_stats_db.diagnostics().warnings(0));
-}
-
-}  // namespace
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc
deleted file mode 100644
index 536ca0427d2dae..00000000000000
--- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h"
-
-#include <utility>
-
-#include "absl/log/check.h"
-#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h"
-#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/pod_stats.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/pod_viewer.pb.h"  // from @org_xprof
-#include "xprof/utils/diagnostics.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-PodStatsSequence ConvertOpStatsToPodStatsSequence(const OpStats& op_stats,
-                                                  PodStatsDatabase pod_stats) {
-  PodStatsSequence result_db;
-  // PodStatsDatabase is created using the same iteration order below.
-  // Thus, we just need to move one record at a time.
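For example, with hypothetical ids: if step_db has steps {10, 11}, each with step_info_per_core entries for cores {0, 1}, then ConvertOpStatsToPodStats appended records in the order (10, 0), (10, 1), (11, 0), (11, 1). The per-core order is whatever the protobuf map yields, but the loop below walks the very same map instances in a second pass, so consuming pod_stats.pod_stats_record(i) with the single increasing index i in the code that follows lines up record for record.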
- int i = 0; - for (const auto& step_sequence : op_stats.step_db().step_sequence()) { - PodStatsMap* pod_stats_map = result_db.add_pod_stats_map(); - pod_stats_map->set_step_num(step_sequence.step_num()); - for (const auto& entry : step_sequence.step_info_per_core()) { - PodStatsRecord& record = - (*pod_stats_map->mutable_pod_stats_per_core())[entry.first]; - DCHECK_LE(i, pod_stats.pod_stats_record_size()); - record = std::move(*pod_stats.mutable_pod_stats_record(i++)); - } - } - return result_db; -} - -} // namespace - -PodViewerDatabase ConvertOpStatsToPodViewer(const OpStats& op_stats) { - PodViewerDatabase database; - database.set_device_type(op_stats.run_environment().device_type()); - PodStatsDatabase pod_stats = ConvertOpStatsToPodStats(op_stats); - database.mutable_step_breakdown_events()->Swap( - pod_stats.mutable_step_breakdown_events()); - *database.mutable_pod_stats_sequence() = - ConvertOpStatsToPodStatsSequence(op_stats, std::move(pod_stats)); - PopulateStepDiagnostics(op_stats, database.mutable_diagnostics()); - return database; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h deleted file mode 100644 index 0d690000527479..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ - -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/pod_viewer.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/pod_viewer.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -PodViewerDatabase ConvertOpStatsToPodViewer(const OpStats& op_stats); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc deleted file mode 100644 index 0a5bb58ed7470d..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h" - -#include "google/protobuf/any.pb.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/pod_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/pod_viewer.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof -#include "xprof/utils/diagnostics.h" // from @org_xprof -#include "xprof/utils/event_span.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -const double kMaxError = 1e-6; -constexpr int kStepNum = 2; -constexpr int kCoreId = 1001; -constexpr int kStepTimePs = 1000; -constexpr int kHostComputePs = 50; -constexpr int kHostCompilePs = 50; -constexpr int kHostToHostPs = 50; -constexpr int kHostToDevicePs = 50; -constexpr int kHostPreparePs = 50; -constexpr int kDeviceCollectivePs = 350; -constexpr int kHostWaitInputPs = 50; -constexpr int kDeviceToDevicePs = 50; -constexpr int kDeviceToHostPs = 50; -constexpr int kDeviceCompute32Ps = 50; -constexpr int kDeviceCompute16Ps = 50; -constexpr int kDeviceWaitDevicePs = 50; -constexpr int kDeviceWaitHostPs = 50; -constexpr int kUnknownTimePs = 50; -static constexpr char kHostname[] = "host:123"; - -void CreateOpStats(OpStats* op_stats) { - PerCoreStepInfo* info = op_stats->mutable_step_db()->add_step_sequence(); - info->set_step_num(kStepNum); - StepInfoResult& step_info = (*info->mutable_step_info_per_core())[kCoreId]; - step_info.set_step_num(kStepNum); - step_info.set_duration_ps(kStepTimePs); - GenericStepBreakdown breakdown; - auto& type_ps = *breakdown.mutable_type_ps(); - type_ps[HOST_COMPUTE] = kHostComputePs; - type_ps[HOST_COMPILE] = kHostCompilePs; - type_ps[HOST_TO_HOST] = kHostToHostPs; - type_ps[HOST_TO_DEVICE] = kHostToDevicePs; - type_ps[HOST_PREPARE] = kHostPreparePs; - type_ps[DEVICE_COLLECTIVES] = kDeviceCollectivePs; - type_ps[HOST_WAIT_INPUT] = kHostWaitInputPs; - type_ps[DEVICE_TO_DEVICE] = kDeviceToDevicePs; - type_ps[DEVICE_TO_HOST] = kDeviceToHostPs; - type_ps[DEVICE_COMPUTE_32] = kDeviceCompute32Ps; - type_ps[DEVICE_COMPUTE_16] = kDeviceCompute16Ps; - type_ps[DEVICE_WAIT_DEVICE] = kDeviceWaitDevicePs; - type_ps[DEVICE_WAIT_HOST] = kDeviceWaitHostPs; - type_ps[UNKNOWN_TIME] = kUnknownTimePs; - step_info.mutable_step_breakdown()->PackFrom(breakdown); - CoreDetails& details = (*op_stats->mutable_core_id_to_details())[kCoreId]; - details.set_hostname(kHostname); -} - -TEST(OpStatsToPodViewer, GpuPodViewer) { - OpStats op_stats; - CreateOpStats(&op_stats); - PodViewerDatabase pod_viewer_db = ConvertOpStatsToPodViewer(op_stats); - 
EXPECT_EQ(1, pod_viewer_db.pod_stats_sequence().pod_stats_map_size()); - const PodStatsMap& pod_stats_map = - pod_viewer_db.pod_stats_sequence().pod_stats_map(0); - EXPECT_EQ(kStepNum, pod_stats_map.step_num()); - const PodStatsRecord& record = pod_stats_map.pod_stats_per_core().at(kCoreId); - EXPECT_EQ(kStepNum, record.step_num()); - EXPECT_EQ(kHostname, record.host_name()); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kStepTimePs), - record.total_duration_us(), kMaxError); - const auto& breakdown = record.step_breakdown_us(); - EXPECT_NEAR( - tsl::profiler::PicoToMicro(kDeviceCompute32Ps + kDeviceCompute16Ps), - breakdown.at(kDeviceCompute), kMaxError); - EXPECT_NEAR( - tsl::profiler::PicoToMicro(kDeviceToDevicePs + kDeviceWaitDevicePs), - breakdown.at(kDeviceToDevice), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kDeviceCollectivePs), - breakdown.at(kDeviceCollectives), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostComputePs), - breakdown.at(kHostCompute), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostPreparePs), - breakdown.at(kHostPrepare), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostWaitInputPs + kHostToDevicePs + - kDeviceWaitHostPs), - breakdown.at(kInput), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kDeviceToHostPs), - breakdown.at(kOutput), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostCompilePs), - breakdown.at(kCompile), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kUnknownTimePs), - breakdown.at(kAllOthers), kMaxError); - - EXPECT_EQ(GetGenericEventTypeStr(kDeviceCollectives), record.bottleneck()); -} - -TEST(OpStatsToPodViewer, Diagnostics) { - OpStats op_stats; - op_stats.mutable_step_db()->set_use_incomplete_step(true); - PodViewerDatabase pod_viewer_db = ConvertOpStatsToPodViewer(op_stats); - EXPECT_EQ(1, pod_viewer_db.diagnostics().warnings_size()); - EXPECT_EQ(kErrorIncompleteStep, pod_viewer_db.diagnostics().warnings(0)); -} - -TEST(OpStatsToPodViewer, DeviceType) { - OpStats op_stats; - op_stats.mutable_run_environment()->set_device_type("GPU"); - PodViewerDatabase pod_viewer_db = ConvertOpStatsToPodViewer(op_stats); - EXPECT_EQ("GPU", pod_viewer_db.device_type()); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc deleted file mode 100644 index 613f7048aa92ed..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc +++ /dev/null @@ -1,271 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/op_stats_to_roofline_model.h"
-
-#include <algorithm>
-#include <cstdint>
-
-#include "absl/log/check.h"
-#include "xla/tsl/profiler/convert/xla_op_utils.h"
-#include "xla/tsl/profiler/utils/math_utils.h"
-#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
-#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
-#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/roofline_model.pb.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "tsl/platform/protobuf.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/hardware_types.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/roofline_model.pb.h"  // from @org_xprof
-#include "xprof/utils/diagnostics.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-using tensorflow::profiler::OpMetrics;
-using tensorflow::profiler::OpMetricsDb;
-using tensorflow::profiler::PerfEnv;
-using tensorflow::profiler::roofline_model::RecordType;
-using tensorflow::profiler::roofline_model::RooflineModelDatabase;
-using tensorflow::profiler::roofline_model::RooflineModelRecord;
-
-// The maximum number of records to generate.
-const uint32_t kMaxNumRecords = 1000;
-}  // namespace
-
-RooflineModelRecord ConvertOpMetricsToRooflineModelRecord(
-    const OpStats& op_stats, const OpMetrics& metrics, RecordType record_type,
-    uint32_t step_num, uint64_t total_time_ps,
-    const RooflineModelDatabase& roofline_model_db,
-    bool include_infeed_outfeed) {
-  RooflineModelRecord record;
-  record.set_hlo_name(metrics.name());
-  record.set_hlo_category(metrics.category());
-  record.set_hlo_module_id(metrics.hlo_module_id());
-  record.set_record_type(record_type);
-  record.set_step_num(step_num);
-  SetExecutionTimes(metrics, &record);
-  if (record_type == RecordType::AVERAGE_STEP) {
-    // For RecordType::AVERAGE_STEP, divide by num_steps to show per-step
-    // numbers when appropriate.
-    int num_steps = op_stats.step_db().step_sequence_size();
-    record.set_total_time_in_us(
-        tsl::profiler::SafeDivide(record.total_time_in_us(), num_steps));
-    record.set_total_self_time_in_us(
-        tsl::profiler::SafeDivide(record.total_self_time_in_us(), num_steps));
-  }
-  record.set_total_time_per_core_in_us(tsl::profiler::SafeDivide(
-      record.total_time_in_us(),
-      op_stats.run_environment().device_core_count()));
-  record.set_total_time_in_percentage(
-      tsl::profiler::SafeDivide(metrics.time_ps(), total_time_ps));
-
-  tensorflow::profiler::SetTpuUnitFractions(metrics, &record);
-
-  // Set the roofline-specific fields.
-  SetRooflineMetrics(metrics, op_stats.perf_env(), op_stats.run_environment(),
-                     &record);
-  const double cmem_wr_utilization =
-      roofline_model_db.has_cmem()
-          ? tsl::profiler::SafeDivide(record.cmem_write_bw(),
-                                      roofline_model_db.peak_cmem_write_bw())
-          : 0;
-  const double cmem_rd_utilization =
-      roofline_model_db.has_cmem()
-          ? tsl::profiler::SafeDivide(record.cmem_read_bw(),
-                                      roofline_model_db.peak_cmem_read_bw())
-          : 0;
-  const double vmem_rd_utilization =
-      roofline_model_db.has_merged_vmem()
-          ? 
tsl::profiler::SafeDivide(record.vmem_read_bw(), - roofline_model_db.peak_vmem_read_bw()) - : 0; - const double vmem_wr_utilization = - roofline_model_db.has_merged_vmem() - ? tsl::profiler::SafeDivide(record.vmem_write_bw(), - roofline_model_db.peak_vmem_write_bw()) - : 0; - const double flops_utilization = tsl::profiler::SafeDivide( - record.measured_flop_rate(), roofline_model_db.peak_flop_rate()); - const double hbm_utilization = tsl::profiler::SafeDivide( - record.hbm_bw(), roofline_model_db.peak_hbm_bw()); - - const double max_mem_utilization = - std::max({cmem_wr_utilization, cmem_rd_utilization, hbm_utilization, - vmem_wr_utilization, vmem_rd_utilization}); - const double roofline_efficiency = - std::max({max_mem_utilization, flops_utilization}); - // Note, copy-start/done can have utilizations above 1.0 since their - // bytes/time are not accurate as they are asynchronous. - record.set_optimal_flop_rate(tsl::profiler::SafeDivide( - record.measured_flop_rate(), roofline_efficiency)); - record.set_roofline_efficiency(roofline_efficiency); - record.set_flop_rate_relative_to_hw_limit(flops_utilization); - record.set_memory_bw_relative_to_hw_limit(max_mem_utilization); - - record.set_include_infeed_outfeed(include_infeed_outfeed); - - return record; -} - -RooflineModelRecord GenerateRooflineModelProgramRecord( - const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type, - uint32_t step_num, const RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed) { - OpMetrics program_metrics; - program_metrics.set_name("Program"); - program_metrics.set_category("Program"); - program_metrics.set_occurrences(1); - uint64_t infeed_outfeed_time = 0; - for (const OpMetrics& metrics : db.metrics_db()) { - // Aggregate innermost ops only to avoid redundant counting. 
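Two notes on the loop step that follows, with hypothetical values: an op whose category MayHaveInnerOps() flags (for example, a fusion-style wrapper that contains nested ops) is skipped, since its flops and bytes are already represented by the innermost ops it contains and counting both would double the Program totals. And if db.total_time_ps() = 1,000,000 with infeed/outfeed ops accounting for 100,000, then with include_infeed_outfeed = false the Program record is built over total_time_ps = 900,000 and those ops contribute nothing to its flops or bytes.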
-    if (tsl::profiler::MayHaveInnerOps(metrics.category())) continue;
-    if (!include_infeed_outfeed &&
-        tsl::profiler::IsInfeedOrOutfeed(metrics.category())) {
-      infeed_outfeed_time += metrics.time_ps();
-      continue;
-    }
-    program_metrics.set_flops(program_metrics.flops() + metrics.flops());
-    program_metrics.set_model_flops(program_metrics.model_flops() +
-                                    metrics.model_flops());
-    program_metrics.set_bytes_accessed(program_metrics.bytes_accessed() +
-                                       metrics.bytes_accessed());
-    CombineMemoryAccessedBreakdown(
-        metrics.memory_accessed_breakdown(),
-        program_metrics.mutable_memory_accessed_breakdown());
-  }
-  uint64_t total_time_ps = db.total_time_ps();
-  if (!include_infeed_outfeed) total_time_ps -= infeed_outfeed_time;
-  program_metrics.set_time_ps(total_time_ps);
-  RooflineModelRecord program_record = ConvertOpMetricsToRooflineModelRecord(
-      op_stats, program_metrics, record_type, step_num, total_time_ps,
-      roofline_model_db, include_infeed_outfeed);
-  program_record.set_rank(0);
-  program_record.set_total_self_time_as_fraction(0.0);
-  program_record.set_cumulative_total_self_time_as_fraction(0.0);
-  return program_record;
-}
-
-tsl::protobuf::RepeatedPtrField<RooflineModelRecord>
-ConvertOpMetricsDbToRooflineModelRecords(
-    const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type,
-    uint32_t step_num, const RooflineModelDatabase& roofline_model_db,
-    bool include_infeed_outfeed) {
-  tsl::protobuf::RepeatedPtrField<RooflineModelRecord> roofline_model_records;
-  RooflineModelRecord* program_record = roofline_model_records.Add();
-  *program_record = GenerateRooflineModelProgramRecord(
-      op_stats, db, record_type, step_num, roofline_model_db,
-      include_infeed_outfeed);
-  const RooflineModelRecord* prev_record = program_record;
-  uint64_t infeed_outfeed_time = 0;
-  if (!include_infeed_outfeed) {
-    // Calculate the total time spent on infeed and outfeed ops.
-    for (const OpMetrics& metrics : db.metrics_db()) {
-      if (tsl::profiler::IsInfeedOrOutfeed(metrics.category())) {
-        infeed_outfeed_time += metrics.time_ps();
-      }
-    }
-  }
-  uint64_t total_time_ps = db.total_time_ps() - infeed_outfeed_time;
-  double total_time_us = tsl::profiler::PicoToMicro(total_time_ps);
-  for (const auto* metrics : SortedOpMetricsDb(db, kMaxNumRecords)) {
-    if (metrics->occurrences() == 0) continue;
-    if (!include_infeed_outfeed &&
-        tsl::profiler::IsInfeedOrOutfeed(metrics->category())) {
-      continue;
-    }
-    RooflineModelRecord* record = roofline_model_records.Add();
-    *record = ConvertOpMetricsToRooflineModelRecord(
-        op_stats, *metrics, record_type, step_num, total_time_ps,
-        roofline_model_db, include_infeed_outfeed);
-    SetRankAndTimeFractions(total_time_us, *prev_record, record);
-    prev_record = record;
-  }
-  return roofline_model_records;
-}
-
-RooflineModelDatabase InitializeRooflineModelDatabaseFromOpStats(
-    const OpStats& op_stats, bool include_infeed_outfeed) {
-  tensorflow::profiler::HardwareType hardware_type =
-      op_stats.run_environment().hardware_type();
-  DCHECK(hardware_type == GPU || hardware_type == TPU);
-
-  RooflineModelDatabase roofline_model_db;
-  const PerfEnv& perf_env = op_stats.perf_env();
-  roofline_model_db.set_device_type(op_stats.run_environment().device_type());
-
-  // Set peak flop rate in GFLOPs/s.
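On the unit conversions in the setters that follow, with hypothetical peaks: TeraToGiga multiplies by 1,000, so a 275 TFLOP/s device yields peak_flop_rate = 275,000 GFLOP/s, while GigaToGibi rescales decimal giga to binary gibi by a factor of 1e9 / 2^30, about 0.9313, so a 900 GB/s HBM peak is stored as roughly 838.2 GiB/s.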
- roofline_model_db.set_peak_flop_rate( - tsl::profiler::TeraToGiga((perf_env.peak_tera_flops_per_second()))); - roofline_model_db.set_peak_hbm_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 0))); - - if (hardware_type == HardwareType::TPU) { - roofline_model_db.set_megacore(perf_env.has_megacore()); - - roofline_model_db.set_has_cmem(perf_env.has_cmem()); - roofline_model_db.set_has_merged_vmem(perf_env.has_merged_vmem()); - if (roofline_model_db.has_cmem()) { - roofline_model_db.set_peak_cmem_read_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 3))); - roofline_model_db.set_peak_cmem_write_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 4))); - } else if (roofline_model_db.has_merged_vmem()) { - roofline_model_db.set_peak_vmem_read_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 5))); - roofline_model_db.set_peak_vmem_write_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 6))); - } - } else if (hardware_type == HardwareType::GPU) { - roofline_model_db.set_megacore(false); - roofline_model_db.set_has_cmem(false); - roofline_model_db.set_has_merged_vmem(true); - roofline_model_db.set_peak_vmem_read_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 1))); - roofline_model_db.set_peak_vmem_write_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 2))); - } - - return roofline_model_db; -} - -RooflineModelDatabase ConvertOpStatsToRooflineModel( - const OpStats& op_stats, bool include_infeed_outfeed) { - HardwareType hardware_type = op_stats.run_environment().hardware_type(); - if (hardware_type != GPU && hardware_type != TPU) { - return RooflineModelDatabase(); - } - - RooflineModelDatabase roofline_model_db = - InitializeRooflineModelDatabaseFromOpStats(op_stats, - include_infeed_outfeed); - - AddRooflineModelRecordForProfileDuration(op_stats, roofline_model_db, - include_infeed_outfeed); - AddRooflineModelRecordsForCompleteSteps(op_stats, roofline_model_db, - include_infeed_outfeed); - AddRooflineModelRecordsPerStep(op_stats, roofline_model_db, - include_infeed_outfeed); - PopulateStepDiagnostics(op_stats, roofline_model_db.mutable_diagnostics()); - return roofline_model_db; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h deleted file mode 100644 index 65b501bc851725..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_ROOFLINE_MODEL_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_ROOFLINE_MODEL_H_
-
-#include <cstdint>
-
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-#include "tensorflow/core/profiler/protobuf/roofline_model.pb.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "tsl/platform/protobuf.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/roofline_model.pb.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-using tensorflow::profiler::OpMetrics;
-using tensorflow::profiler::roofline_model::RecordType;
-using tensorflow::profiler::roofline_model::RooflineModelDatabase;
-using tensorflow::profiler::roofline_model::RooflineModelRecord;
-
-RooflineModelRecord ConvertOpMetricsToRooflineModelRecord(
-    const OpStats& op_stats, const OpMetrics& metrics, RecordType record_type,
-    uint32_t step_num, uint64_t total_time_ps,
-    const RooflineModelDatabase& roofline_model_db,
-    bool include_infeed_outfeed);
-
-RooflineModelRecord GenerateRooflineModelProgramRecord(
-    const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type,
-    uint32_t step_num, const RooflineModelDatabase& roofline_model_db,
-    bool include_infeed_outfeed);
-
-tsl::protobuf::RepeatedPtrField<RooflineModelRecord>
-ConvertOpMetricsDbToRooflineModelRecords(
-    const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type,
-    uint32_t step_num, const RooflineModelDatabase& roofline_model_db,
-    bool include_infeed_outfeed);
-
-tensorflow::profiler::roofline_model::RooflineModelDatabase
-ConvertOpStatsToRooflineModel(const tensorflow::profiler::OpStats& tf_op_stats,
-                              bool include_infeed_outfeed);
-
-tensorflow::profiler::roofline_model::RooflineModelDatabase
-InitializeRooflineModelDatabaseFromOpStats(const OpStats& op_stats,
-                                           bool include_infeed_outfeed);
-
-// Generates the RooflineModelRecord for the HLO DB over the entire profiling
-// duration, including incomplete steps.
-inline void AddRooflineModelRecordForProfileDuration(
-    const OpStats& op_stats, RooflineModelDatabase& roofline_model_db,
-    bool include_infeed_outfeed) {
-  *roofline_model_db.mutable_roofline_model_record() =
-      ConvertOpMetricsDbToRooflineModelRecords(
-          op_stats, op_stats.device_op_metrics_db(), RecordType::ALL,
-          /*step_num=*/0, roofline_model_db, include_infeed_outfeed);
-}
-
-// Generates the RooflineModelRecord for the HLO DB over complete steps only.
-inline void AddRooflineModelRecordsForCompleteSteps(
-    const OpStats& op_stats, RooflineModelDatabase& roofline_model_db,
-    bool include_infeed_outfeed) {
-  if (op_stats.has_hlo_metrics_db_complete_steps_only()) {
-    *roofline_model_db.add_roofline_model_record() =
-        GenerateRooflineModelProgramRecord(
-            op_stats, op_stats.hlo_metrics_db_complete_steps_only(),
-            RecordType::AVERAGE_STEP, /*step_num=*/0, roofline_model_db,
-            include_infeed_outfeed);
-  }
-}
-
-// Generates RooflineModelRecords for the per-step DBs.
-inline void AddRooflineModelRecordsPerStep( - const OpStats& op_stats, RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed) { - for (const auto& step_info : op_stats.step_db().step_sequence()) { - *roofline_model_db.add_roofline_model_record() = - GenerateRooflineModelProgramRecord( - op_stats, step_info.hlo_metrics_db(), RecordType::PER_STEP, - step_info.step_num(), roofline_model_db, include_infeed_outfeed); - } -} - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_ROOFLINE_MODEL_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc deleted file mode 100644 index 88da45e892ecd5..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc +++ /dev/null @@ -1,126 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" - -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/tf_stats.pb.h" // from @org_xprof -#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof -#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -// The maximum number of Tensorflow Ops displayed on Tensorflow Stats page. -// 500 device side ops and 500 host side ops. -const int kMaxNumOfOps = 500; - -TfStatsRecord ConvertOpMetricsToTfStatsRecord(bool on_device, - const OpMetrics& metrics, - const PerfEnv& perf_env, - const RunEnvironment& run_env) { - TfStatsRecord record; - record.set_host_or_device(on_device ? 
"Device" : "Host"); - record.set_is_eager(metrics.is_eager()); - record.set_op_type(metrics.category()); - record.set_op_name(metrics.name()); - SetExecutionTimes(metrics, &record); - SetRooflineMetrics(metrics, perf_env, run_env, &record); - return record; -} - -TfStatsTable GenerateTfStatsTable( - const OpMetricsDb& host_tf_metrics_db, - const OpMetricsDb& device_tf_metrics_db, - const KernelStatsByOpName& kernel_stats_by_op_name, const PerfEnv& perf_env, - const RunEnvironment& run_env, bool exclude_idle) { - TfStatsTable tf_stats_table; - TfStatsRecord sentinel; - sentinel.set_rank(0); - sentinel.set_device_cumulative_total_self_time_as_fraction(0.0); - sentinel.set_host_cumulative_total_self_time_as_fraction(0.0); - const TfStatsRecord* prev_record = &sentinel; - - // Sets device-side TF stats. - uint64 total_device_time_ps = TotalTimePs(device_tf_metrics_db, exclude_idle); - double total_device_time_us = - tsl::profiler::PicoToMicro(total_device_time_ps); - for (const OpMetrics* metrics : - SortedOpMetricsDb(device_tf_metrics_db, kMaxNumOfOps)) { - if (exclude_idle && IsIdleOp(*metrics)) continue; - TfStatsRecord* record = tf_stats_table.add_tf_stats_record(); - *record = ConvertOpMetricsToTfStatsRecord( - /*on_device=*/true, *metrics, perf_env, run_env); - // Compute TensorCore utilization only on device side. - auto iter = kernel_stats_by_op_name.find(record->op_name()); - if (iter != kernel_stats_by_op_name.end()) { - record->set_gpu_tensorcore_utilization( - tsl::profiler::SafeDivide(iter->second.tensor_core_duration_ns, - iter->second.total_duration_ns)); - } else { - record->set_gpu_tensorcore_utilization(0.0); - } - SetRankAndDeviceTimeFractions(total_device_time_us, *prev_record, record); - prev_record = record; - } - - // Sets host-side TF stats. 
- uint64 total_host_time_ps = TotalTimePs(host_tf_metrics_db, exclude_idle); - double total_host_time_us = tsl::profiler::PicoToMicro(total_host_time_ps); - for (const OpMetrics* metrics : tensorflow::profiler::SortedOpMetricsDb( - host_tf_metrics_db, kMaxNumOfOps)) { - if (exclude_idle && IsIdleOp(*metrics)) continue; - TfStatsRecord* record = tf_stats_table.add_tf_stats_record(); - *record = ConvertOpMetricsToTfStatsRecord( - /*on_device=*/false, *metrics, perf_env, run_env); - // Host side TensorCore utilization is always 0.0 - record->set_gpu_tensorcore_utilization(0.0); - SetRankAndHostTimeFractions(total_host_time_us, *prev_record, record); - prev_record = record; - } - return tf_stats_table; -} - -} // namespace - -TfStatsDatabase ConvertOpStatsToTfStats(const OpStats& op_stats) { - const OpMetricsDb& host_tf_metrics_db = op_stats.host_op_metrics_db(); - OpMetricsDb device_tf_metrics_db = - CreateTfMetricsDbFromDeviceOpMetricsDb(op_stats.device_op_metrics_db()); - const PerfEnv perf_env = op_stats.perf_env(); - const RunEnvironment run_env = op_stats.run_environment(); - KernelStatsByOpName kernel_stats_by_op_name = - GroupKernelReportsByOpName(op_stats.kernel_stats_db()); - TfStatsDatabase tf_stats_db; - *tf_stats_db.mutable_with_idle() = GenerateTfStatsTable( - host_tf_metrics_db, device_tf_metrics_db, kernel_stats_by_op_name, - perf_env, run_env, /*exclude_idle=*/false); - *tf_stats_db.mutable_without_idle() = GenerateTfStatsTable( - host_tf_metrics_db, device_tf_metrics_db, kernel_stats_by_op_name, - perf_env, run_env, /*exclude_idle=*/true); - tf_stats_db.set_device_type(op_stats.run_environment().device_type()); - return tf_stats_db; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h deleted file mode 100644 index 79994d5570ec3e..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_TF_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_TF_STATS_H_ - -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/tf_stats.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -TfStatsDatabase ConvertOpStatsToTfStats(const OpStats& op_stats); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_TF_STATS_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc deleted file mode 100644 index daaae982635d13..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc +++ /dev/null @@ -1,169 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/tf_stats.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -XEventBuilder AddTensorFlowOpEvent(std::string&& tf_op_fullname, - int64_t start_timestamp_ns, - int64_t duration_ns, bool on_device, - absl::string_view kernel_name, - XPlaneBuilder* plane, XLineBuilder* line) { - absl::string_view name = on_device ? 
kernel_name : tf_op_fullname; - XEventBuilder event = line->AddEvent(*plane->GetOrCreateEventMetadata(name)); - event.SetTimestampNs(start_timestamp_ns); - event.SetDurationNs(duration_ns); - if (!on_device) return event; - event.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), - *plane->GetOrCreateStatMetadata(std::move(tf_op_fullname))); - return event; -} - -void AddTensorFlowOpEventWithKernelDetails(std::string&& tf_op_fullname, - int64_t start_timestamp_ns, - int64_t duration_ns, bool on_device, - absl::string_view kernel_name, - absl::string_view kernel_details, - XPlaneBuilder* plane, - XLineBuilder* line) { - XEventBuilder event = - AddTensorFlowOpEvent(std::move(tf_op_fullname), start_timestamp_ns, - duration_ns, on_device, kernel_name, plane, line); - if (!on_device) return; - event.ParseAndAddStatValue(*plane->GetOrCreateStatMetadata("kernel_details"), - kernel_details); -} - -TEST(OpStatsToTfStats, GpuTfStats) { - // TfOp1 has kernel1 and kernel2; TfOp2 has kernel3; - // TfOp3 has kernel4 and kernel5 and is TensorCore eligible. - static constexpr char kTfOp1[] = "TfOp1"; - static constexpr char kTfOp2[] = "TfOp2"; - static constexpr char kTfOp3[] = "Conv2D"; - static constexpr char kKernel1[] = "kernel1"; - static constexpr char kKernel2[] = "kernel2"; - static constexpr char kKernel3[] = "kernel3"; - // Kernel4 is a kernel using TensorCore - static constexpr char kKernel4[] = "volta_fp16_s884gemm"; - static constexpr char kKernel5[] = "kernel5"; - constexpr int64_t kKernel1StartNs = 100000; - constexpr int64_t kKernel1DurationNs = 8000; - constexpr int64_t kKernel2StartNs = 110000; - constexpr int64_t kKernel2DurationNs = 10000; - constexpr int64_t kKernel3StartNs = 120000; - constexpr int64_t kKernel3DurationNs = 10000; - constexpr int64_t kKernel4StartNs = 130000; - constexpr int64_t kKernel4DurationNs = 10000; - constexpr int64_t kKernel5StartNs = 150000; - constexpr int64_t kKernel5DurationNs = 10000; - - // Mock kernel details for both kernel4 and kernel5. 
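The kernel_details stat set via ParseAndAddStatValue earlier in this file is a newline-separated key:value payload. A hedged parser sketch for that shape (field semantics follow the mock payload below; the production parsing lives in kernel_stats_utils, not here):

#include <string>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

// Splits "regs:32\nstatic_shared:0\n..." into a key -> value map.
// Assumes every line contains a ':' separator, as in the mock below.
absl::flat_hash_map<std::string, std::string> ParseKernelDetails(
    absl::string_view details) {
  absl::flat_hash_map<std::string, std::string> fields;
  for (absl::string_view line : absl::StrSplit(details, '\n')) {
    std::pair<absl::string_view, absl::string_view> kv =
        absl::StrSplit(line, absl::MaxSplits(':', 1));
    fields[std::string(kv.first)] = std::string(kv.second);
  }
  return fields;
}

The mock payload below follows exactly that format: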
- const std::string kKernelDetails = R"MULTI(regs:32 -static_shared:0 -dynamic_shared:16384 -grid:2,1,1 -block:32,1,1 -occ_pct:100)MULTI"; - - XSpace space; - XPlaneBuilder device_plane( - GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0)); - XLineBuilder stream1 = device_plane.GetOrCreateLine(/*line_id=*/10); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel1StartNs, - kKernel1DurationNs, /*on_device=*/true, kKernel1, - &device_plane, &stream1); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel2StartNs, - kKernel2DurationNs, /*on_device=*/true, kKernel2, - &device_plane, &stream1); - XLineBuilder stream2 = device_plane.GetOrCreateLine(/*line_id=*/20); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel1StartNs, - kKernel1DurationNs, /*on_device=*/true, kKernel1, - &device_plane, &stream2); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel2StartNs, - kKernel2DurationNs, /*on_device=*/true, kKernel2, - &device_plane, &stream2); - AddTensorFlowOpEvent(absl::StrCat(kTfOp2, ":", kTfOp2), kKernel3StartNs, - kKernel3DurationNs, /*on_device=*/true, kKernel3, - &device_plane, &stream2); - AddTensorFlowOpEventWithKernelDetails( - absl::StrCat(kTfOp3, ":", kTfOp3), kKernel4StartNs, kKernel4DurationNs, - /*on_device=*/true, kKernel4, kKernelDetails, &device_plane, &stream2); - AddTensorFlowOpEventWithKernelDetails( - absl::StrCat(kTfOp3, ":", kTfOp3), kKernel5StartNs, kKernel5DurationNs, - /*on_device=*/true, kKernel5, kKernelDetails, &device_plane, &stream2); - - OpStatsOptions options; - options.generate_kernel_stats_db = true; - options.generate_op_metrics_db = true; - const OpStats op_stats = ConvertXSpaceToOpStats(space, options); - const TfStatsDatabase tf_stats = ConvertOpStatsToTfStats(op_stats); - - EXPECT_EQ(tf_stats.device_type(), op_stats.run_environment().device_type()); - - // TfOp1, TfOp3, TfOp2, Idle - EXPECT_EQ(4, tf_stats.with_idle().tf_stats_record_size()); - - const TfStatsRecord& record_0 = tf_stats.with_idle().tf_stats_record(0); - EXPECT_EQ(kTfOp1, record_0.op_name()); - EXPECT_EQ(kTfOp1, record_0.op_type()); - EXPECT_EQ(2, record_0.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToMicro(kKernel1DurationNs) * 2 + - tsl::profiler::NanoToMicro(kKernel2DurationNs) * 2, - record_0.total_self_time_in_us()); - - const TfStatsRecord& record_1 = tf_stats.with_idle().tf_stats_record(1); - EXPECT_EQ(kTfOp3, record_1.op_name()); - EXPECT_EQ(kTfOp3, record_1.op_type()); - EXPECT_EQ(1, record_1.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToMicro(kKernel4DurationNs) + - tsl::profiler::NanoToMicro(kKernel5DurationNs), - record_1.total_self_time_in_us()); - // GPU TensorCore utilization is 0.5 because kernel4 is using TensorCore and - // kernel5 is not using TensorCore, and they have the same duration. 
- EXPECT_DOUBLE_EQ(0.5, record_1.gpu_tensorcore_utilization()); - - const TfStatsRecord& record_2 = tf_stats.with_idle().tf_stats_record(2); - EXPECT_EQ(kTfOp2, record_2.op_name()); - EXPECT_EQ(kTfOp2, record_2.op_type()); - EXPECT_EQ(1, record_2.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToMicro(kKernel3DurationNs), - record_2.total_self_time_in_us()); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc deleted file mode 100644 index 9824ef17fc9b53..00000000000000 --- a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h" - -#include - -#include "absl/strings/match.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/preprocess_xplane.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "xprof/utils/derived_timeline.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -void PreprocessSingleHostXSpace( - XSpace* space, bool step_grouping, bool derived_timeline, - tsl::profiler::GroupMetadataMap* group_metadata_map) { - if (step_grouping && !tsl::profiler::IsXSpaceGrouped(*space)) { - // Grouping (i.e. marking step number) events in the XSpace. - std::vector device_traces; - bool isTpu = false; - for (XPlane& plane : *space->mutable_planes()) { - if (tsl::profiler::IsDevicePlane(plane)) { - device_traces.push_back(&plane); - } - // Preprocess XPlane to convert stats to Traceme2 semantics - tsl::profiler::PreprocessXPlane(&plane); - - if (!isTpu && absl::StartsWith(plane.name(), kTpuPlanePrefix)) { - isTpu = true; - } - } - - tsl::profiler::EventForest event_forest; - if (isTpu) { - // group TPU events - GroupTpuEventsOSS(space, device_traces, &event_forest); - } else { - // group GPU events - tsl::profiler::GroupTfEvents(space, &event_forest); - } - - if (derived_timeline) { - // Generated miscellaneous derived time lines for device planes. - GenerateDerivedTimeLines(event_forest.GetGroupMetadataMap(), space); - } - - if (group_metadata_map != nullptr) { - *group_metadata_map = event_forest.GetGroupMetadataMap(); - } - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.h b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.h deleted file mode 100644 index 4c86ed8758bc4a..00000000000000 --- a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PREPROCESS_SINGLE_HOST_XPLANE_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_PREPROCESS_SINGLE_HOST_XPLANE_H_ - -#include "xla/tsl/profiler/utils/group_events.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Preprocess XSpaces before tools conversion. -// If step_grouping = true, perform events grouping for step tracking. -// If derived_timeline, generate derived timeline (XLines). -// If group_metadata_map is not nullptr, populate the group metadata map. -void PreprocessSingleHostXSpace( - XSpace* space, bool step_grouping, bool derived_timeline, - tsl::profiler::GroupMetadataMap* group_metadata_map = nullptr); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_PREPROCESS_SINGLE_HOST_XPLANE_H_ diff --git a/tensorflow/core/profiler/convert/process_megascale_dcn.cc b/tensorflow/core/profiler/convert/process_megascale_dcn.cc deleted file mode 100644 index febad594d6349b..00000000000000 --- a/tensorflow/core/profiler/convert/process_megascale_dcn.cc +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/process_megascale_dcn.h" - -#include - -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/convert/dcn_analysis.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -using tsl::profiler::CreateTfXPlaneVisitor; -using tsl::profiler::FindMutableTensorCorePlanes; - -void ProcessMegascaleDcn(XSpace* space) { - std::vector device_xplanes = FindMutableTensorCorePlanes(space); - int num_tpu_cores = device_xplanes.size(); - // DCN TraceMe's are in the Host XPlane - XPlane* host_plane = - FindMutablePlaneWithName(space, tsl::profiler::kHostThreadsPlaneName); - const XPlaneVisitor plane_visitor = CreateTfXPlaneVisitor(host_plane); - // TODO(yashjs): Update parameter value for `is_megacore`. 
- DcnEventsProcessor dcn_events_processor(num_tpu_cores, false); - dcn_events_processor.SetupMessageInfo(plane_visitor); - if (dcn_events_processor.HasDcnMessages( - tsl::profiler::kMegaScaleDcnReceive)) { - dcn_events_processor.ProcessReceiveMessages(plane_visitor); - } - // Update host XPlane with DCN traffic - dcn_events_processor.AddHostDcnTrafficToXPlane(host_plane); - // Update device XPlanes with per collective TPU traffic. - for (XPlane* device_xplane : device_xplanes) { - dcn_events_processor.AddTpuCollectiveDcnTrafficToXPlane(device_xplane); - } - - SortXSpace(space); -} -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/process_megascale_dcn.h b/tensorflow/core/profiler/convert/process_megascale_dcn.h deleted file mode 100644 index 794c2bea66462a..00000000000000 --- a/tensorflow/core/profiler/convert/process_megascale_dcn.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PROCESS_MEGASCALE_DCN_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_PROCESS_MEGASCALE_DCN_H_ - -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Process Dcn Megascale TraceMe info. -void ProcessMegascaleDcn(XSpace* space); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_PROCESS_MEGASCALE_DCN_H_ diff --git a/tensorflow/core/profiler/convert/profile_time_breakdown.cc b/tensorflow/core/profiler/convert/profile_time_breakdown.cc deleted file mode 100644 index e1826a7119f9a2..00000000000000 --- a/tensorflow/core/profiler/convert/profile_time_breakdown.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/profile_time_breakdown.h" - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" - -namespace tensorflow { -namespace profiler { - -void ProfileTimeBreakdown::SetCategoryTimePs(absl::string_view category, - uint64_t time_ps) { - time_ps_by_category_.insert_or_assign(category, time_ps); -} - -uint64_t ProfileTimeBreakdown::PopCategoryTimePs(absl::string_view category) { - uint64_t time_ps = 0; - auto iter = time_ps_by_category_.find(category); - if (iter != time_ps_by_category_.end()) { - time_ps = iter->second; - time_ps_by_category_.erase(iter); - } - return time_ps; -} - -void ProfileTimeBreakdown::BreakdownSparseCoreV0Infeed() { - // Infeed from SparseCoreV0 and outfeed to SparseCoreV0 are mostly identical - // in compute since they do the same transformation. We can subtract out the - // outfeed time from the infeed time to know how much time the TensorCore - // actually spent waiting on SparseCoreV0. - uint64_t bc_infeed_ps = - PopCategoryTimePs(tsl::profiler::kHloSparseCoreV0Infeed); - if (bc_infeed_ps == 0) return; - uint64_t bc_outfeed_ps = - CategoryTimePs(tsl::profiler::kHloSparseCoreV0Outfeed); - - uint64_t bc_infeed_transform_ps = std::min(bc_infeed_ps, bc_outfeed_ps); - uint64_t bc_infeed_wait_ps = bc_infeed_ps - bc_infeed_transform_ps; - - SetCategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedWait, - bc_infeed_wait_ps); - SetCategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedTransform, - bc_infeed_transform_ps); -} - -std::string ProfileTimeBreakdown::DebugString() const { - std::string str; - for (const auto& [category, time_ps] : time_ps_by_category_) { - absl::StrAppend(&str, category, ": ", tsl::profiler::PicoToUni(time_ps), - "\n"); - } - absl::StrAppend( - &str, "total_time: ", tsl::profiler::PicoToUni(total_time_ps_), "\n"); - absl::StrAppend( - &str, "profile_time: ", tsl::profiler::PicoToUni(profile_time_ps_), "\n"); - return str; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/profile_time_breakdown.h b/tensorflow/core/profiler/convert/profile_time_breakdown.h index 1e3379beb4c457..9b68baad5ecf79 100644 --- a/tensorflow/core/profiler/convert/profile_time_breakdown.h +++ b/tensorflow/core/profiler/convert/profile_time_breakdown.h @@ -15,230 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" - -namespace tensorflow { -namespace profiler { - -// Allows accumulating time spent in different HLO instruction categories to -// breakdown the total profile time and compute metrics of interest. -class ProfileTimeBreakdown { - public: - // Category should be the operator category disambiguated by xprof instead of - // the original category from XLA. - // For a correct time breakdown, we need to use the self time of operators, - // instead of total time to avoid double counting. Note that for leaf ops, - // self time and total time are the same. 
- void IncrementCategoryTimePs(absl::string_view category, - uint64_t self_time_ps) { - time_ps_by_category_[category] += self_time_ps; - total_time_ps_ += self_time_ps; - } - - // Profile time cannot be smaller than the total time in all categories. - // If combining profiles across multiple cores, profile time should be the - // profiling duration multiplied by the number of cores that were profiled. - // go/autograppler_profile_time - void SetProfileTimePs(uint64_t profile_time_ps) { - DCHECK_LE(total_time_ps_, profile_time_ps); - profile_time_ps_ = profile_time_ps; - } - - // Breaks down "sparsecorev0 infeed" into two components: - // 1) "sparsecorev0 infeed wait": Time spent waiting on the SparseCoreV0. - // 2) "sparsecorev0 infeed transform": Time spent transforming activations in - // SparseCoreV0 layout into XLA layout. - // Even though 2) is part of the overall embedding computation, it is time - // spent doing work on the TensorCore. - void BreakdownSparseCoreV0Infeed(); - - // Duty cycle is the fraction of time an accelerator is being actively used. - // go/accelerator-metrics-definitions#common-accelerator-metrics - // go/ag-tpu-duty-cycle - double DutyCycle() const { return TimeFraction(OnDutyTimePs()); } - - double IdleFraction() const { return TimeFraction(IdleTimePs()); } - - double InfeedFraction() const { - return CategoryFraction(tsl::profiler::kHloInfeed); - } - - double OutfeedFraction() const { - return CategoryFraction(tsl::profiler::kHloOutfeed); - } - - double SparseCoreV0InfeedFraction() const { - return CategoriesFraction({tsl::profiler::kHloSparseCoreV0Infeed, - tsl::profiler::kHloSparseCoreV0InfeedWait, - tsl::profiler::kHloSparseCoreV0InfeedTransform}); - } - - double SparseCoreV0OutfeedFraction() const { - return CategoryFraction(tsl::profiler::kHloSparseCoreV0Outfeed); - } - - double AllReduceFraction() const { - return CategoryFraction(tsl::profiler::kHloAllReduce); - } - - double AllReduceFusionFraction() const { - return CategoryFraction(tsl::profiler::kHloAllReduceFusion); - } - - double SendRecvFraction() const { - return CategoriesFraction( - {tsl::profiler::kHloSend, tsl::profiler::kHloSendDone, - tsl::profiler::kHloRecv, tsl::profiler::kHloRecvDone}); - } - - double HostSendRecvFraction() const { - return CategoriesFraction( - {tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone, - tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone}); - } - - double CategoriesFraction( - const std::initializer_list& categories) const { - return TimeFraction(CategoriesTimePs(categories)); - } - - double CategoryFraction(absl::string_view category) const { - return TimeFraction(CategoryTimePs(category)); - } - - uint64_t ProfileTimePs() const { return profile_time_ps_; } - - uint64_t TotalTimePs() const { return total_time_ps_; } - - uint64_t IdleTimePs() const { return profile_time_ps_ - total_time_ps_; } - - uint64_t OnDutyTimePs() const { return profile_time_ps_ - OffDutyTimePs(); } - - uint64_t OffDutyTimePs() const { - return IdleTimePs() + - CategoriesTimePs( - {tsl::profiler::kHloInfeed, tsl::profiler::kHloOutfeed, - tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone, - tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone, - tsl::profiler::kHloMegacoreFusion}); - } - - uint64_t InfeedTimePs() const { - return CategoryTimePs(tsl::profiler::kHloInfeed); - } - - uint64_t OutfeedTimePs() const { - return CategoryTimePs(tsl::profiler::kHloOutfeed); - } - - uint64_t SparseCoreV0InfeedWaitTimePs() const { - return 
CategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedWait); - } - - uint64_t SparseCoreV0InfeedTransformTimePs() const { - return CategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedTransform); - } - - uint64_t SparseCoreV0OutfeedTimePs() const { - return CategoryTimePs(tsl::profiler::kHloSparseCoreV0Outfeed); - } - - uint64_t AllReduceOrAllToAllTimePs() const { - return CategoriesTimePs({tsl::profiler::kHloAllReduce, - tsl::profiler::kHloAllReduceFusion, - tsl::profiler::kHloAllToAll}); - } - - uint64_t SendTimePs() const { - return CategoriesTimePs( - {tsl::profiler::kHloSend, tsl::profiler::kHloSendDone}); - } - - uint64_t RecvTimePs() const { - return CategoriesTimePs( - {tsl::profiler::kHloRecv, tsl::profiler::kHloRecvDone}); - } - - uint64_t HostSendTimePs() const { - return CategoriesTimePs( - {tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone}); - } - - uint64_t HostRecvTimePs() const { - return CategoriesTimePs( - {tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone}); - } - - // Megacore fusion runs different operations on each core, e.g., a convolution - // on one core and an all-reduce on the other core. In a trace, megacore - // fusion is the parent operation, and its self time is the time that the core - // executing the faster operation waits for the core executing the slower - // operation to reach the synchronization point. - uint64_t MegacoreFusionTimePs() const { - return CategoryTimePs(tsl::profiler::kHloMegacoreFusion); - } - - uint64_t HighFlopsComputeTimePs() const { - return CategoriesTimePs({tsl::profiler::kHloConvolution, - tsl::profiler::kHloConvolutionBaseDilated, - tsl::profiler::kHloConvolutionWindowDilated, - tsl::profiler::kHloConvolutionFusion, - tsl::profiler::kHloOutputFusion}); - } - - // Calculated according to the "TC busy time" defined in go/tpu_kpis - uint64_t TensorCoreBusyTimePs() const { - return profile_time_ps_ - OffDutyTimePs() - SparseCoreV0InfeedWaitTimePs(); - } - - uint64_t CategoriesTimePs( - const std::initializer_list& categories) const { - uint64_t time_ps = 0; - for (auto category : categories) { - time_ps += CategoryTimePs(category); - } - return time_ps; - } - - uint64_t CategoryTimePs(absl::string_view category) const { - auto iter = time_ps_by_category_.find(category); - return (iter == time_ps_by_category_.end()) ? 0 : iter->second; - } - - template - void ComputeCategoryFractions(Map& category_fractions) { - for (const auto& [category, time_ps] : time_ps_by_category_) { - category_fractions[category] = TimeFraction(time_ps); - } - } - - std::string DebugString() const; - - private: - // Overwrites the time attributed to the given category. - void SetCategoryTimePs(absl::string_view category, uint64_t time_ps); - - // Removes and returns the time attributed to the given category. - uint64_t PopCategoryTimePs(absl::string_view category); - - double TimeFraction(uint64_t time_ps) const { - return tsl::profiler::SafeDivide(time_ps, profile_time_ps_); - } - - absl::flat_hash_map time_ps_by_category_; - uint64_t total_time_ps_ = 0; // Sum of values in time_ps_by_category_. 
-  uint64_t profile_time_ps_ = 0;
-};
-
-}  // namespace profiler
-}  // namespace tensorflow
+#include "xprof/convert/profile_time_breakdown.h"  // from @org_xprof  // IWYU pragma: export
 
 #endif  // TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_
diff --git a/tensorflow/core/profiler/convert/repository.cc b/tensorflow/core/profiler/convert/repository.cc
deleted file mode 100644
index 6fcadd8caf65c0..00000000000000
--- a/tensorflow/core/profiler/convert/repository.cc
+++ /dev/null
@@ -1,179 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/repository.h"
-
-#include <memory>
-#include <optional>
-#include <string>
-#include <string_view>
-#include <utility>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/match.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/strings/strip.h"
-#include "xla/tsl/platform/env.h"
-#include "xla/tsl/platform/errors.h"
-#include "xla/tsl/platform/statusor.h"
-#include "xla/tsl/profiler/utils/file_system_utils.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tsl/platform/path.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-std::string GetHostnameByPath(absl::string_view xspace_path) {
-  std::string_view file_name = tsl::io::Basename(xspace_path);
-  // Remove suffix from file_name, preserving entire prefix.
-  absl::ConsumeSuffix(&file_name, ".xplane.pb");
-  return std::string(file_name);
-}
-}  // namespace
-
-absl::StatusOr<SessionSnapshot> SessionSnapshot::Create(
-    std::vector<std::string> xspace_paths,
-    std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces) {
-  if (xspace_paths.empty()) {
-    return errors::InvalidArgument("Can not find XSpace path.");
-  }
-
-  if (xspaces.has_value()) {
-    if (xspaces->size() != xspace_paths.size()) {
-      return errors::InvalidArgument(
-          "The size of the XSpace paths: ", xspace_paths.size(),
-          " is not equal ",
-          "to the size of the XSpace proto: ", xspaces->size());
-    }
-    for (size_t i = 0; i < xspace_paths.size(); ++i) {
-      auto host_name = GetHostnameByPath(xspace_paths.at(i));
-      if (xspaces->at(i)->hostnames_size() > 0 && !host_name.empty()) {
-        if (!absl::StrContains(host_name, xspaces->at(i)->hostnames(0))) {
-          return errors::InvalidArgument(
-              "The hostname of xspace path and preloaded xspace don't match "
-              "at index: ",
-              i, ". \nThe host name of xspace path is ", host_name,
-              " but the host name of preloaded xspace is ",
-              xspaces->at(i)->hostnames(0), ".");
-        }
-      }
-    }
-  }
-
-  return SessionSnapshot(std::move(xspace_paths), std::move(xspaces));
-}
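Create is the only validated entry point: it rejects an empty path list, a preloaded-XSpace vector whose size disagrees with the path list, and a preloaded XSpace whose recorded hostname is absent from the corresponding file name. A hedged usage sketch (paths illustrative):

// Sketch: open a two-host session, reading XSpaces from disk on demand.
// Passing std::nullopt keeps the run directory accessible for file IO.
auto snapshot_or = tensorflow::profiler::SessionSnapshot::Create(
    {"logdir/plugins/profile/run1/host_a.xplane.pb",
     "logdir/plugins/profile/run1/host_b.xplane.pb"},
    /*xspaces=*/std::nullopt);
// snapshot_or.status() carries the InvalidArgument errors listed above.

GetXSpace then serves each proto, from memory in preloaded mode (at most once, since it moves the pointer out) or from disk otherwise: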
-absl::StatusOr<std::unique_ptr<XSpace>> SessionSnapshot::GetXSpace(
-    size_t index) const {
-  if (index >= xspace_paths_.size()) {
-    return errors::InvalidArgument("Can not get the ", index,
-                                   "th XSpace. The total number of XSpace is ",
-                                   xspace_paths_.size());
-  }
-
-  // Return the pre-loaded XSpace proto.
-  if (xspaces_.has_value()) {
-    if (xspaces_->at(index) == nullptr) {
-      return errors::Internal("");
-    }
-    return std::move(xspaces_->at(index));
-  }
-
-  // Return the XSpace proto from file.
-  auto xspace_from_file = std::make_unique<XSpace>();
-  TF_RETURN_IF_ERROR(tsl::ReadBinaryProto(
-      tsl::Env::Default(), xspace_paths_.at(index), xspace_from_file.get()));
-  return xspace_from_file;
-}
-
-absl::StatusOr<std::unique_ptr<XSpace>> SessionSnapshot::GetXSpaceByName(
-    absl::string_view name) const {
-  if (auto it = hostname_map_.find(name); it != hostname_map_.end()) {
-    return GetXSpace(it->second);
-  }
-
-  return errors::InvalidArgument("Can not find the XSpace by name: ", name,
-                                 ". The total number of XSpace is ",
-                                 xspace_paths_.size());
-}
-
-std::string SessionSnapshot::GetHostname(size_t index) const {
-  return GetHostnameByPath(xspace_paths_.at(index));
-}
-
-std::optional<std::string> SessionSnapshot::GetFilePath(
-    absl::string_view toolname, absl::string_view hostname) const {
-  if (!has_accessible_run_dir_) return std::nullopt;
-  std::string file_name = "";
-  if (toolname == "trace_viewer@")
-    file_name = absl::StrCat(hostname, ".", "SSTABLE");
-  if (!file_name.empty())
-    return tsl::io::JoinPath(session_run_dir_, file_name);
-  return std::nullopt;
-}
-
-absl::StatusOr<std::string> SessionSnapshot::GetHostDataFileName(
-    const StoredDataType data_type, const std::string host) const {
-  for (const auto& format : *kHostDataSuffixes) {
-    if (data_type == format.first) return absl::StrCat(host, format.second);
-  }
-  // StrCat the enum value; the previous expression indexed into the string
-  // literal and truncated the message.
-  return absl::InternalError(
-      absl::StrCat("Unknown StoredDataType: ", data_type));
-}
-
-absl::StatusOr<std::optional<std::string>>
-SessionSnapshot::GetHostDataFilePath(const StoredDataType data_type,
-                                     const std::string host) const {
-  // Gets all the files in session run directory.
-  std::vector<std::string> results;
-  TF_RETURN_IF_ERROR(::tsl::Env::Default()->GetChildren(
-      std::string(GetSessionRunDir()), &results));
-
-  TF_ASSIGN_OR_RETURN(std::string filename,
-                      GetHostDataFileName(data_type, host));
-
-  for (const std::string& path : results) {
-    if (absl::EndsWith(path, filename)) {
-      return ::tsl::profiler::ProfilerJoinPath(GetSessionRunDir(), filename);
-    }
-  }
-
-  return std::nullopt;
-}
-
-absl::StatusOr<std::pair<bool, std::string>> SessionSnapshot::HasCacheFile(
-    const StoredDataType data_type) const {
-  std::optional<std::string> filepath;
-  TF_ASSIGN_OR_RETURN(filepath,
-                      GetHostDataFilePath(data_type, kNoHostIdentifier));
-  if (filepath) {
-    // cache file is present but file contains no data_type events
-    return std::pair(true, std::string());
-  }
-
-  TF_ASSIGN_OR_RETURN(filepath,
-                      GetHostDataFilePath(data_type, kAllHostsIdentifier));
-  if (filepath) {
-    // cache file is present and file contains data_type events
-    return std::pair(true, filepath.value());
-  }
-
-  // no cache file present
-  return std::pair(false, std::string());
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/repository.h b/tensorflow/core/profiler/convert/repository.h
deleted file mode 100644
index 2db0ad41777384..00000000000000
--- a/tensorflow/core/profiler/convert/repository.h
+++ /dev/null
@@ -1,206 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_REPOSITORY_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_REPOSITORY_H_
-
-#include <cstddef>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/platform/env.h"
-#include "xla/tsl/platform/statusor.h"
-#include "xla/tsl/profiler/utils/file_system_utils.h"
-#include "tsl/platform/path.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-#include "xprof/utils/hlo_module_map.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-constexpr char kAllHostsIdentifier[] = "ALL_HOSTS";
-constexpr char kNoHostIdentifier[] = "NO_HOST";
-
-enum StoredDataType {
-  DCN_COLLECTIVE_STATS,
-  OP_STATS,
-};
-
-static auto* kHostDataSuffixes =
-    new std::vector<std::pair<StoredDataType, std::string>>(
-        {{StoredDataType::DCN_COLLECTIVE_STATS, ".dcn_collective_stats.pb"},
-         {StoredDataType::OP_STATS, ".op_stats.pb"}});
-
-// File system directory snapshot of a profile session.
-class SessionSnapshot {
- public:
-  // Performs validation and creates SessionSnapshot.
-  // <xspace_paths> are the file paths to XSpace protos.
-  // Optionally, <xspaces> can contain the XSpace protos pre-loaded by the
-  // profiler plugin.
-  static absl::StatusOr<SessionSnapshot> Create(
-      std::vector<std::string> xspace_paths,
-      std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces);
-
-  // Returns the number of XSpaces in the profile session.
-  size_t XSpaceSize() const { return xspace_paths_.size(); }
-
-  // Gets XSpace proto.
-  // The caller of this function will take ownership of the XSpace.
-  absl::StatusOr<std::unique_ptr<XSpace>> GetXSpace(size_t index) const;
-
-  // Gets XSpace proto.
-  // The caller of this function will take ownership of the XSpace.
-  absl::StatusOr<std::unique_ptr<XSpace>> GetXSpaceByName(
-      absl::string_view name) const;
-
-  // Gets host name.
-  std::string GetHostname(size_t index) const;
-
-  // Gets the run directory of the profile session.
-  absl::string_view GetSessionRunDir() const { return session_run_dir_; }
-
-  // Gets whether the session has an accessible run dir. If false, any
-  // path-based file read will be disabled in this mode.
-  bool HasAccessibleRunDir() const { return has_accessible_run_dir_; }
-
-  // Gets the path of the fast file for a given tool.
-  std::optional<std::string> GetFilePath(absl::string_view toolname,
-                                         absl::string_view host) const;
-
-  // Gets the name of the host data file.
-  absl::StatusOr<std::string> GetHostDataFileName(StoredDataType data_type,
-                                                  std::string host) const;
-
-  // Gets the path of the host data file.
-  absl::StatusOr<std::optional<std::string>> GetHostDataFilePath(
-      StoredDataType data_type, std::string host) const;
-
-  /* Gets whether the cache file is present in run dir. First value indicates
-  whether cache file is present or not. Second value indicates the path of
-  cache file. Possible cases are:
-  1. <false, "">: If no cache file is present
-  2. <true, "">: If cache file is present but file contains no data_type
-  events
-  3. <true, filepath>: If cache file is present and file contains data_type
-  events
-  */
-  absl::StatusOr<std::pair<bool, std::string>> HasCacheFile(
-      StoredDataType data_type) const;
-
-  template <typename T>
-  absl::Status WriteBinaryProto(const StoredDataType data_type,
-                                const std::string host, T& proto) const {
-    // Gets name for host data file.
-    TF_ASSIGN_OR_RETURN(std::string filename,
-                        GetHostDataFileName(data_type, host));
-
-    std::string filepath =
-        tsl::profiler::ProfilerJoinPath(GetSessionRunDir(), filename);
-
-    return tsl::WriteBinaryProto(tsl::Env::Default(), filepath, proto);
-  }
-
-  template <typename T>
-  absl::Status ReadBinaryProto(const StoredDataType data_type,
-                               const std::string host, T* proto) const {
-    // Gets file path for host data.
-    TF_ASSIGN_OR_RETURN(std::optional<std::string> filepath,
-                        GetHostDataFilePath(data_type, host));
-    if (filepath) {
-      return tsl::ReadBinaryProto(tsl::Env::Default(), filepath.value(),
-                                  proto);
-    }
-
-    return absl::NotFoundError(
-        absl::StrCat("No binary proto found for ", host, " and ", data_type));
-  }
-
- private:
-  SessionSnapshot(std::vector<std::string> xspace_paths,
-                  std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces)
-      : xspace_paths_(std::move(xspace_paths)),
-        // If the snapshot was initialized by xspaces, the file path and run
-        // dir is a path tensorflow can't read from or write to so any file IO
-        // encapsulated in this class will be disabled in this mode.
-        has_accessible_run_dir_(!xspaces.has_value()),
-        xspaces_(std::move(xspaces)) {
-    session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0));
-    for (size_t i = 0; i < xspace_paths_.size(); ++i) {
-      std::string host_name = GetHostname(i);
-      hostname_map_[host_name] = i;
-    }
-  }
-
-  // File paths to XSpace protos.
-  std::vector<std::string> xspace_paths_;
-  // The run directory of the profile session.
-  absl::string_view session_run_dir_;
-
-  absl::flat_hash_map<std::string, size_t> hostname_map_;
-
-  const bool has_accessible_run_dir_;
-
-  // XSpace protos pre-loaded by the profiler plugin.
-  // TODO(profiler): Use blobstore paths to initialize SessionSnapshot instead
-  // of using pre-loaded XSpaces.
-  mutable std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces_;
-};
-
-// Writes binary proto format T for a host and data_type to a session.
-template <typename T>
-absl::Status WriteBinaryProto(const SessionSnapshot& session_snapshot,
-                              const StoredDataType data_type,
-                              const std::string& host, T& proto) {
-  return session_snapshot.WriteBinaryProto(data_type, host, proto);
-}
-
-// Reads binary proto format T for a host and data_type to a session.
-template <typename T>
-absl::Status ReadBinaryProto(const SessionSnapshot& session_snapshot,
-                             const StoredDataType data_type,
-                             const std::string& host, T* proto) {
-  return session_snapshot.ReadBinaryProto(data_type, host, proto);
-}
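Together with the suffix table at the top of this header, the Write/ReadBinaryProto pair gives every tool a per-host, per-StoredDataType cache file inside the session run directory. A hedged round-trip sketch (host name illustrative; OpStats stands in for any proto type):

// Sketch: cache a host's OpStats next to its xplane.pb, then reload it.
absl::Status RoundTrip(
    const tensorflow::profiler::SessionSnapshot& snapshot,
    tensorflow::profiler::OpStats& stats) {
  using tensorflow::profiler::StoredDataType;
  // Writes "<host>.op_stats.pb" into the session run directory.
  TF_RETURN_IF_ERROR(
      WriteBinaryProto(snapshot, StoredDataType::OP_STATS, "host_a", stats));
  tensorflow::profiler::OpStats reloaded;
  // Scans the run directory for the matching suffix and deserializes.
  return ReadBinaryProto(snapshot, StoredDataType::OP_STATS, "host_a",
                         &reloaded);
}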
-
-// TODO(b/408280338) Remove this function as 0 reference is found.
-// Add a dummy cost_analysis factory function as a no-op now.
-// Process HloModuleMap from all XSpaces in a session.
-inline absl::StatusOr<HloModuleMap> ProcessHloModuleMap(
-    const SessionSnapshot& session_snapshot) {
-  HloModuleMap hlo_module_map;
-  tensorflow::profiler::HloCostAnalysisWrapper::Factory create_cost_analysis =
-      []() { return nullptr; };
-  for (int i = 0; i < session_snapshot.XSpaceSize(); i++) {
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<XSpace> xspace,
-                        session_snapshot.GetXSpace(i));
-    ProcessHloModuleMapFromXSpace(hlo_module_map, xspace.get(),
-                                  create_cost_analysis);
-  }
-  return hlo_module_map;
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_REPOSITORY_H_
diff --git a/tensorflow/core/profiler/convert/repository_test.cc b/tensorflow/core/profiler/convert/repository_test.cc
deleted file mode 100644
index 3f3872bd13fd8b..00000000000000
--- a/tensorflow/core/profiler/convert/repository_test.cc
+++ /dev/null
@@ -1,138 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/repository.h"
-
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/tsl/platform/status.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-using ::testing::Eq;
-
-TEST(Repository, GetHostName) {
-  auto session_snapshot_or =
-      SessionSnapshot::Create({"log/plugins/profile/hostname0.xplane.pb",
-                               "log/plugins/profile/hostname1.xplane.pb"},
-                              /*xspaces=*/std::nullopt);
-  TF_CHECK_OK(session_snapshot_or.status());
-  EXPECT_THAT(session_snapshot_or.value().GetHostname(0), Eq("hostname0"));
-  EXPECT_THAT(session_snapshot_or.value().GetHostname(1), Eq("hostname1"));
-  EXPECT_TRUE(session_snapshot_or.value().HasAccessibleRunDir());
-}
-
-TEST(Repository, GetHostNameWithPeriods) {
-  auto session_snapshot_or =
-      SessionSnapshot::Create({"log/plugins/profile/127.0.0.1_6009.xplane.pb"},
-                              /*xspaces=*/std::nullopt);
-  TF_CHECK_OK(session_snapshot_or.status());
-  EXPECT_THAT(session_snapshot_or.value().GetHostname(0),
-              Eq("127.0.0.1_6009"));
-  EXPECT_TRUE(session_snapshot_or.value().HasAccessibleRunDir());
-}
-
-TEST(Repository, GetSpaceByHostName) {
-  std::vector<std::unique_ptr<XSpace>> xspaces;
-  // prepare host 1.
-  auto space1 = std::make_unique<XSpace>();
-  *(space1->add_hostnames()) = "hostname1";
-  // with index 0 which shouldn't impact the space finding by name.
-  xspaces.push_back(std::move(space1));
-
-  // prepare host 0.
-  auto space0 = std::make_unique<XSpace>();
-  *(space0->add_hostnames()) = "hostname0";
-  // with index 1 which shouldn't impact the space finding by name.
-  xspaces.push_back(std::move(space0));
-
-  auto session_snapshot_or =
-      SessionSnapshot::Create({"log/plugins/profile/hostname1.xplane.pb",
-                               "log/plugins/profile/hostname0.xplane.pb"},
-                              std::move(xspaces));
-  TF_CHECK_OK(session_snapshot_or.status());
-  auto xspace0_or = session_snapshot_or.value().GetXSpaceByName("hostname0");
-  TF_CHECK_OK(xspace0_or.status());
-  auto xspace1_or = session_snapshot_or.value().GetXSpaceByName("hostname1");
-  EXPECT_FALSE(session_snapshot_or.value().HasAccessibleRunDir());
-  TF_CHECK_OK(xspace1_or.status());
-  EXPECT_THAT(xspace0_or.value()->hostnames(0), Eq("hostname0"));
-  EXPECT_THAT(xspace1_or.value()->hostnames(0), Eq("hostname1"));
-}
-
-TEST(Repository, GetSSTableFile) {
-  auto session_snapshot_or =
-      SessionSnapshot::Create({"log/plugins/profile/hostname0.xplane.pb"},
-                              /*xspaces=*/std::nullopt);
-  TF_CHECK_OK(session_snapshot_or.status());
-  auto sstable_path =
-      session_snapshot_or.value().GetFilePath("trace_viewer@", "hostname0");
-  auto not_found_path =
-      session_snapshot_or.value().GetFilePath("memory_viewer", "hostname0");
-  EXPECT_THAT(sstable_path, Eq("log/plugins/profile/hostname0.SSTABLE"));
-  EXPECT_THAT(not_found_path, Eq(std::nullopt));
-}
-
-TEST(Repository, GetSSTableFileWithXSpace) {
-  std::vector<std::unique_ptr<XSpace>> xspaces;
-  // prepare host 0.
-  auto space0 = std::make_unique<XSpace>();
-  *(space0->add_hostnames()) = "hostname0";
-  // with index 1 which shouldn't impact the space finding by name.
-  xspaces.push_back(std::move(space0));
-  auto session_snapshot_or = SessionSnapshot::Create(
-      {"log/plugins/profile/hostname0.xplane.pb"}, std::move(xspaces));
-  TF_CHECK_OK(session_snapshot_or.status());
-  auto file_path_init_by_xspace =
-      session_snapshot_or.value().GetFilePath("trace_viewer@", "hostname0");
-  // The file path should be disabled in this mode.
-  EXPECT_THAT(file_path_init_by_xspace, Eq(std::nullopt));
-}
-
-TEST(Repository, MismatchedXSpaceAndPath) {
-  std::vector<std::unique_ptr<XSpace>> xspaces;
-  // prepare host 1.
-  auto space1 = std::make_unique<XSpace>();
-  *(space1->add_hostnames()) = "hostname1";
-  // with index 0 which shouldn't impact the space finding by name.
-  xspaces.push_back(std::move(space1));
-
-  // prepare host 0.
-  auto space0 = std::make_unique<XSpace>();
-  *(space0->add_hostnames()) = "hostname0";
-  // with index 1 which shouldn't impact the space finding by name.
-  xspaces.push_back(std::move(space0));
-
-  auto session_snapshot_or =
-      SessionSnapshot::Create({"log/plugins/profile/hostname0.xplane.pb",
-                               "log/plugins/profile/hostname1.xplane.pb"},
-                              std::move(xspaces));
-  auto error =
-      R"(The hostname of xspace path and preloaded xspace don't match at index: 0. 
-The host name of xspace path is hostname0 but the host name of preloaded xspace is hostname1.)";
-  EXPECT_THAT(session_snapshot_or.status(), Eq(errors::InvalidArgument(error)));
-}
-
-}  // namespace
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc
deleted file mode 100644
index ce340db50cb6b3..00000000000000
--- a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc
+++ /dev/null
@@ -1,225 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h"
-
-#include <cstdint>
-#include <sstream>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "google/protobuf/any.pb.h"
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/log/log.h"
-#include "xla/tsl/profiler/utils/timespan.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h"  // from @org_xprof
-#include "xprof/utils/event_span.h"  // from @org_xprof
-#include "xprof/utils/op_metrics_db_utils.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-// Local core id should start from 1.
-const uint32 kDefaultGpuLocalCoreId = 1;
-
-namespace {
-
-void StepEventsToPerCoreStepInfo(uint32_t step_num, StepDetails& step_details,
-                                 PerCoreStepInfo& per_core_step_info) {
-  per_core_step_info.set_step_num(step_num);
-  OpMetricsDbCombiner combiner(per_core_step_info.mutable_hlo_metrics_db());
-  auto step_time = step_details.StepTime();
-  if (step_time.duration_ps() == 0) {
-    // If no step markers were observed for this particular step, skip it.
-    VLOG(1) << "Skipping step " << step_details.StepName()
-            << " with no step markers";
-    return;
-  }
-  for (auto& [core_id, metrics_db] : step_details.PerCoreOpMetricsDb()) {
-    SetTotalTimePs(metrics_db, step_time.duration_ps());
-    AddIdleOp(metrics_db);
-    // TODO(b/397774568): Remove this once the SparseCore OpMetricsDb is
-    // implemented.
-    if (core_id < kSparseCoreIndexStart) combiner.Combine(metrics_db);
-
-    GenericStepBreakdown step_breakdown;
-    auto& category_ps = *(step_breakdown.mutable_category_ps());
-    for (auto& metric : metrics_db.metrics_db()) {
-      category_ps[metric.category()] += metric.self_time_ps();
-    }
-
-    StepInfoResult step_info;
-    step_info.set_step_num(step_num);
-    step_info.set_step_name(step_details.StepName());
-    step_info.set_begin_ps(step_time.begin_ps());
-    step_info.set_duration_ps(step_time.duration_ps());
-    step_info.mutable_step_breakdown()->PackFrom(step_breakdown);
-    (*per_core_step_info.mutable_step_info_per_core())[core_id] =
-        std::move(step_info);
-  }
-  auto& all_reduce_db_per_core_map =
-      *per_core_step_info.mutable_all_reduce_db_per_core();
-  for (const auto& [core_id, all_reduce_db] : step_details.Collectives()) {
-    all_reduce_db_per_core_map[core_id].CopyFrom(all_reduce_db);
-  }
-}
-
-// Converts from StepDetails to StepInfoResult.
-StepInfoResult ConvertStepDetailsToStepInfo(bool has_device, int64_t step_num,
-                                            StepDetails& step_details) {
-  GenericStepBreakdown generic;
-  tsl::profiler::Timespan step_time = step_details.StepTime();
-  auto& type_ps = *(generic.mutable_type_ps());
-  uint64 total_event_duration = 0;
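Clipping against the step span matters here because events collected from other threads can begin before the step marker or end after it; only the overlapped portion should be charged to this step. A small stand-in for the Timespan::OverlappedDurationPs call used below:

#include <algorithm>
#include <cstdint>

// Sketch of OverlappedDurationPs semantics: duration of the intersection
// of [begin1, end1) and [begin2, end2), or 0 when they do not overlap.
uint64_t OverlapPs(uint64_t begin1, uint64_t end1, uint64_t begin2,
                   uint64_t end2) {
  const uint64_t begin = std::max(begin1, begin2);
  const uint64_t end = std::min(end1, end2);
  return end > begin ? end - begin : 0;
}

Each event in the loop below is clipped this way before being attributed to its event type: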
-    uint64 event_duration = step_time.OverlappedDurationPs(event.span);
-    type_ps[event.type] += event_duration;
-    total_event_duration += event_duration;
-  }
-  if (total_event_duration < step_time.duration_ps()) {
-    // Some time in the step is not associated with any event. Classify it as
-    // "unknown time".
-    type_ps[UNKNOWN_TIME] += step_time.duration_ps() - total_event_duration;
-  }
-  // Determines if this particular step is a well-formed one.
-  bool well_formed_step = has_device ? type_ps.contains(DEVICE_COMPUTE_16) ||
-                                           type_ps.contains(DEVICE_COMPUTE_32)
-                                     : type_ps.contains(HOST_COMPUTE);
-  StepInfoResult step_info;
-  step_info.mutable_step_breakdown()->PackFrom(generic);
-  if (well_formed_step) {
-    step_info.set_step_num(step_num);
-    step_info.set_step_name(step_details.StepName());
-    step_info.set_begin_ps(step_time.begin_ps());
-    step_info.set_duration_ps(step_time.duration_ps());
-  } else {
-    // For a non-well-formed step, sets its duration to 0 so that it will be
-    // ignored by the caller of this function.
-    step_info.set_duration_ps(0);
-  }
-  return step_info;
-}
-
-string DebugGenericStepBreakdown(const GenericStepBreakdown& generic) {
-  std::ostringstream out;
-  uint64 total_ps = 0;
-  const auto& type_ps_map = generic.type_ps();
-  for (const auto& type_ps : type_ps_map) {
-    total_ps += type_ps.second;
-  }
-  out << "Total ps = " << total_ps << std::endl;
-  for (int type = LAST_EVENT_TYPE; type >= 0; --type) {
-    const auto* ps = gtl::FindOrNull(type_ps_map, type);
-    if (ps == nullptr) continue;
-    double percent = (*ps * 100.0) / total_ps;
-    auto event_type = static_cast<EventType>(type);
-    out << PrintEventType(event_type) << ": " << percent << "%"
-        << ", ps = " << *ps << std::endl;
-  }
-  return out.str();
-}
-
-string DebugStepInfo(const StepInfoResult& step_info) {
-  std::ostringstream out;
-  out << "step_num=" << step_info.step_num()
-      << ", duration_ps=" << step_info.duration_ps()
-      << ", begin_ps=" << step_info.begin_ps() << std::endl;
-  GenericStepBreakdown generic;
-  if (step_info.step_breakdown().UnpackTo(&generic)) {
-    out << "Generic step breakdown:" << std::endl;
-    out << DebugGenericStepBreakdown(generic) << std::endl;
-  } else {
-    out << step_info.step_breakdown().DebugString() << std::endl;
-  }
-  return out.str();
-}
-
-}  // namespace
-
-StepDatabaseResult ConvertStepEventsToStepDb(
-    bool has_device, bool maybe_drop_incomplete_steps,
-    StepEvents& nonoverlapped_step_events) {
-  StepDatabaseResult step_db;
-  // Gets sorted step numbers.
-  std::vector<int64_t> step_numbers;
-  step_numbers.reserve(nonoverlapped_step_events.size());
-  for (const auto& step_events : nonoverlapped_step_events) {
-    step_numbers.push_back(step_events.first);
-  }
-  absl::c_sort(step_numbers);
-  for (const auto& step : step_numbers) {
-    auto* step_details = gtl::FindOrNull(nonoverlapped_step_events, step);
-    if (step_details == nullptr) continue;
-    PerCoreStepInfo per_core_step_info;
-    per_core_step_info.set_step_num(step);
-    if (!step_details->PerCoreOpMetricsDb().empty()) {
-      StepEventsToPerCoreStepInfo(step, *step_details, per_core_step_info);
-    } else {
-      StepInfoResult step_info =
-          ConvertStepDetailsToStepInfo(has_device, step, *step_details);
-      if (step_info.duration_ps() == 0)
-        continue;  // Do not include non-well-formed steps.
-      // When we generated StepEvents, we already put events from all device
-      // cores and cpu threads on this host into a single event stream,
-      // therefore we can't separate them anymore. Simply assigns all events to
-      // Core-0.
-      (*per_core_step_info
-            .mutable_step_info_per_core())[kDefaultGpuLocalCoreId] =
-          std::move(step_info);
-      VLOG(2)
-          << std::endl
-          << "step_id: " << step << ", step_info:" << std::endl
-          << DebugStepInfo(
-                 (*per_core_step_info
-                       .mutable_step_info_per_core())[kDefaultGpuLocalCoreId]);
-      // Populates the collective ops information.
-      auto& collectives = *per_core_step_info.mutable_all_reduce_db_per_core();
-      for (const auto& it : step_details->Collectives()) {
-        collectives[it.first] = it.second;
-      }
-      // Populates the device transfer stats for this step.
-      auto& device_memory_transfers =
-          *per_core_step_info.mutable_device_memory_transfers();
-      for (const auto& dma : step_details->DeviceMemoryTransfers()) {
-        *device_memory_transfers.Add() = dma;
-      }
-    }
-    // The remaining fields in PerCoreStepInfo are not filled.
-    *step_db.add_step_sequence() = per_core_step_info;
-  }
-
-  // If we are using sampling mode and we get enough steps, we would like to
-  // drop the incomplete steps at the beginning and the end.
-  // (Sometimes CUPTI instrumentation will prolong the first step too).
-  int kDropIncompleteStepThreshold = 5;
-  if (maybe_drop_incomplete_steps &&
-      step_db.step_sequence_size() > kDropIncompleteStepThreshold) {
-    step_db.mutable_step_sequence()->erase(
-        step_db.mutable_step_sequence()->begin());
-    step_db.mutable_step_sequence()->RemoveLast();
-  }
-  return step_db;
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.h b/tensorflow/core/profiler/convert/step_events_to_steps_db.h
deleted file mode 100644
index 5bb980c32f1e01..00000000000000
--- a/tensorflow/core/profiler/convert/step_events_to_steps_db.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_
-
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h"  // from @org_xprof
-#include "xprof/utils/event_span.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-TF_CONST_INIT extern const uint32 kDefaultGpuLocalCoreId;
-
-// Converts from overlapped Step-Events to StepDatabaseResult.
-StepDatabaseResult ConvertStepEventsToStepDb(
-    bool has_device, bool maybe_drop_incomplete_steps,
-    StepEvents& overlapped_step_events);
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_
diff --git a/tensorflow/core/profiler/convert/tool_options.h b/tensorflow/core/profiler/convert/tool_options.h
deleted file mode 100644
index b3f787df943058..00000000000000
--- a/tensorflow/core/profiler/convert/tool_options.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TOOL_OPTIONS_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_TOOL_OPTIONS_H_
-
-#include <optional>
-#include <string>
-#include <variant>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/str_cat.h"
-
-namespace tensorflow {
-namespace profiler {
-
-using ToolOptions =
-    absl::flat_hash_map<std::string, std::variant<bool, int, std::string>>;
-
-// Helper function to get parameter from tool options.
-template <typename T>
-std::optional<T> GetParam(const ToolOptions& options, const std::string& key) {
-  const auto iter = options.find(key);
-  if (iter == options.end()) {
-    return std::nullopt;
-  }
-
-  const T* result = std::get_if<T>(&iter->second);
-  if (!result) {
-    return std::nullopt;
-  }
-  return *result;
-}
-
-// Helper function to get parameter from tool options with default value.
-template <typename T>
-T GetParamWithDefault(const ToolOptions& options, const std::string& key,
-                      const T& default_param) {
-  if (auto param = GetParam<T>(options, key)) {
-    return *param;
-  }
-  return default_param;
-}
-
-inline std::string DebugString(const ToolOptions& options) {
-  std::string output;
-  for (const auto& [k, v] : options) {
-    absl::StrAppend(
-        &output, k, ":",
-        std::visit([](const auto& value) { return absl::StrCat(value); }, v),
-        ":", v.index(), ";");
-  }
-  return absl::StrCat("{", output, "}");
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_TOOL_OPTIONS_H_
diff --git a/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h b/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h
deleted file mode 100644
index ba0fcf1919e414..00000000000000
--- a/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
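A short usage sketch for the ToolOptions helpers above may be useful, since GetParam only succeeds when the stored variant alternative matches T exactly; the option keys below are made up for illustration, and the snippet assumes the header above is included:

#include <iostream>

void ToolOptionsDemo() {
  ToolOptions options;
  options["host"] = std::string("hostname0");  // hypothetical keys
  options["max_steps"] = 10;

  // Exact-type lookup: succeeds for std::string, yields nullopt for int.
  std::optional<std::string> host = GetParam<std::string>(options, "host");
  std::optional<int> wrong = GetParam<int>(options, "host");  // nullopt

  // Falls back to the default when the key is absent.
  int limit = GetParamWithDefault<int>(options, "missing_key", 100);

  std::cout << host.value_or("?") << " " << wrong.has_value() << " " << limit
            << " " << DebugString(options) << std::endl;
}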
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ - -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/macros.h" - -namespace tensorflow { -namespace profiler { - -TF_CONST_INIT extern const absl::string_view kProfileAllHostsDoc; -TF_CONST_INIT extern const absl::string_view kSparseCoreV0Name; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc deleted file mode 100644 index 9838f57816c584..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h" - -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -namespace { - -bool HasDcnCollectiveStatsInXSpace(const XSpace& xspace) { - if (const tsl::profiler::XPlane* xplane = - FindPlaneWithName(xspace, tsl::profiler::kHostThreadsPlaneName); - xplane != nullptr) { - for (const auto& [_, metadata] : xplane->event_metadata()) { - if (absl::StartsWith(metadata.name(), "MegaScale:")) { - return true; - } - } - } - return false; -} - -absl::StatusOr GetDcnCollectiveStatsFromMultiXSpaceAndSaveToFile( - const SessionSnapshot& session_snapshot) { - DcnSlackAnalysisCombiner combiner; - for (int idx = 0; idx < session_snapshot.XSpaceSize(); idx++) { - std::string hostname = session_snapshot.GetHostname(idx); - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(idx)); - - // The profile does not have dcn collective stats. 
- if (!HasDcnCollectiveStatsInXSpace(*xspace)) { - DcnSlackAnalysis dcnSlackAnalysis; - TF_RETURN_IF_ERROR(WriteBinaryProto(session_snapshot, - StoredDataType::DCN_COLLECTIVE_STATS, - kNoHostIdentifier, dcnSlackAnalysis)); - return false; - } - - DcnSlackAnalysis dcnSlackAnalysis = - ConvertXSpaceToDcnSlackAnalysis(*xspace, nullptr, nullptr); - - TF_RETURN_IF_ERROR(WriteBinaryProto(session_snapshot, - StoredDataType::DCN_COLLECTIVE_STATS, - hostname, dcnSlackAnalysis)); - - combiner.Combine(dcnSlackAnalysis); - } - - DcnSlackAnalysis dcnSlackAnalysis = combiner.Finalize(); - TF_RETURN_IF_ERROR(WriteBinaryProto(session_snapshot, - StoredDataType::DCN_COLLECTIVE_STATS, - kAllHostsIdentifier, dcnSlackAnalysis)); - - // The profile has dcn collective stats. - return true; -} - -} // namespace - -absl::StatusOr HasDcnCollectiveStatsInMultiXSpace( - const SessionSnapshot& session_snapshot) { - std::pair hasCacheFile; - TF_ASSIGN_OR_RETURN(hasCacheFile, session_snapshot.HasCacheFile( - StoredDataType::DCN_COLLECTIVE_STATS)); - - // Cache file not present, check if trace contains dcn collective stats. - if (!hasCacheFile.first) { - for (int idx = 0; idx < session_snapshot.XSpaceSize(); idx++) { - std::string hostname = session_snapshot.GetHostname(idx); - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(idx)); - - if (HasDcnCollectiveStatsInXSpace(*xspace)) { - return true; - } - } - return false; - } - - if (hasCacheFile.second.empty()) { - // If the profiler finds a file NO_HOST.dcn_collective_stats.pb, this means - // dcn collective stats are not present in the profile. - return false; - } else { - // If the profiler finds a file ALL_HOSTS.dcn_collective_stats.pb, this - // means dcn collective stats are present in the profile. - return true; - } -} - -absl::StatusOr ConvertMultiXSpaceToDcnCollectiveStats( - const SessionSnapshot& session_snapshot) { - std::pair hasCacheFile; - TF_ASSIGN_OR_RETURN(hasCacheFile, session_snapshot.HasCacheFile( - StoredDataType::DCN_COLLECTIVE_STATS)); - - // Cache file not present, generate dcn collective stats. - if (!hasCacheFile.first) { - return GetDcnCollectiveStatsFromMultiXSpaceAndSaveToFile(session_snapshot); - } - - if (hasCacheFile.second.empty()) { - // If the profiler finds a file NO_HOST.dcn_collective_stats.pb, this means - // dcn collective stats are not present in the profile. - return false; - } else { - // If the profiler finds a file ALL_HOSTS.dcn_collective_stats.pb, this - // means dcn collective stats are present in the profile. - return true; - } -} - -absl::StatusOr GetDcnSlackAnalysisByHostName( - const SessionSnapshot& session_snapshot, const std::string hostname) { - TF_ASSIGN_OR_RETURN(bool hasDcnCollectiveStats, - ConvertMultiXSpaceToDcnCollectiveStats(session_snapshot)); - - DcnSlackAnalysis dcnSlackAnalysis; - if (hasDcnCollectiveStats) { - TF_RETURN_IF_ERROR(ReadBinaryProto(session_snapshot, - StoredDataType::DCN_COLLECTIVE_STATS, - hostname, &dcnSlackAnalysis)); - } - - return dcnSlackAnalysis; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h deleted file mode 100644 index 9b4f9ff1bf5845..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
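The branching above is easier to read as a three-state classification of the cache directory; this sketch restates it, assuming the pair semantics returned by SessionSnapshot::HasCacheFile as used above (first: whether any cache file exists, second: its host prefix, where empty encodes the NO_HOST marker):

#include <string>
#include <utility>

enum class DcnStatsState { kUnknown, kAbsent, kPresent };

// Condensed restatement of the cache logic above; not part of the patch.
DcnStatsState ClassifyCache(const std::pair<bool, std::string>& has_cache_file) {
  if (!has_cache_file.first) {
    return DcnStatsState::kUnknown;  // no cache file yet: must scan XSpaces
  }
  return has_cache_file.second.empty() ? DcnStatsState::kAbsent   // NO_HOST
                                       : DcnStatsState::kPresent; // ALL_HOSTS
}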
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ - -#include "absl/status/statusor.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -// Converts multiple XSpaces to dcn collective stats. -// Stores the dcn collective stats as files in the same directory -// as the xspace files. -absl::StatusOr ConvertMultiXSpaceToDcnCollectiveStats( - const SessionSnapshot& session_snapshot); - -// Returns whether there are dcn collective stats in the profile. -absl::StatusOr HasDcnCollectiveStatsInMultiXSpace( - const SessionSnapshot& session_snapshot); - -// Gets DcnSlackAnalysis proto for a host. -absl::StatusOr GetDcnSlackAnalysisByHostName( - const SessionSnapshot& session_snapshot, std::string hostname); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc deleted file mode 100644 index 068efbd752282e..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
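For readers of the declarations above, a sketch of consuming the returned proto; the only field assumed here is slack_us, which the test below also exercises:

#include <algorithm>
#include <cstdint>

// Assumes the DcnSlackAnalysis proto included via the header above.
uint64_t MaxSlackUs(const DcnSlackAnalysis& analysis) {
  uint64_t max_slack = 0;
  for (const auto& summary : analysis.dcn_slack_summary()) {
    max_slack = std::max<uint64_t>(max_slack, summary.slack_us());
  }
  return max_slack;
}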
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h" - -#include -#include -#include -#include -#include - -#include -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/env.h" -#include "xla/tsl/platform/status.h" -#include "tensorflow/core/platform/file_system.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -DcnSlackAnalysis CreateDcnSlackAnalysisProto() { - DcnSlackAnalysis dcn_slack_analysis; - DcnSlackSummary* dcn_slack_summary = - dcn_slack_analysis.add_dcn_slack_summary(); - dcn_slack_summary->set_rendezvous("collective"); - dcn_slack_summary->set_recv_op_name("recv-done"); - dcn_slack_summary->set_send_op_name("send"); - dcn_slack_summary->set_slack_us(2); - dcn_slack_summary->set_observed_duration_us(12); - dcn_slack_summary->set_stall_duration_us(5); - dcn_slack_summary->set_occurrences(4); - dcn_slack_summary->set_bytes_transmitted_over_network(819200); - return dcn_slack_analysis; -} - -SessionSnapshot CreateSessionSnapshot(bool create_cache_file, - bool has_dcn_collective_stats) { - std::string test_name = - ::testing::UnitTest::GetInstance()->current_test_info()->name(); - std::string path = absl::StrCat("ram://", test_name, "/"); - std::unique_ptr xplane_file; - std::vector paths = {absl::StrCat(path, "hostname.xplane.pb")}; - - auto xspace = std::make_unique(); - XPlane* xplane = FindOrAddMutablePlaneWithName(xspace.get(), "/host:CPU"); - if (has_dcn_collective_stats) { - XPlaneBuilder xplane_builder(xplane); - xplane_builder.GetOrCreateEventMetadata("MegaScale:"); - } - - if (create_cache_file) { - if (has_dcn_collective_stats) { - tsl::Env::Default() - ->NewAppendableFile( - absl::StrCat(path, "hostname.dcn_collective_stats.pb"), - &xplane_file) - .IgnoreError(); - tsl::Env::Default() - ->NewAppendableFile( - absl::StrCat(path, "ALL_HOSTS.dcn_collective_stats.pb"), - &xplane_file) - .IgnoreError(); - } else { - tsl::Env::Default() - ->NewAppendableFile( - absl::StrCat(path, "NO_HOST.dcn_collective_stats.pb"), - &xplane_file) - .IgnoreError(); - } - } - - std::vector> xspaces; - xspaces.push_back(std::move(xspace)); - - absl::StatusOr session_snapshot_status = - SessionSnapshot::Create(paths, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_status.status()); - SessionSnapshot session_snapshot = std::move(session_snapshot_status.value()); - if (has_dcn_collective_stats) { - DcnSlackAnalysis dcn_slack_analysis = CreateDcnSlackAnalysisProto(); - TF_CHECK_OK(session_snapshot.WriteBinaryProto( - DCN_COLLECTIVE_STATS, "hostname", dcn_slack_analysis)); - TF_CHECK_OK(session_snapshot.WriteBinaryProto( - DCN_COLLECTIVE_STATS, kAllHostsIdentifier, dcn_slack_analysis)); - } - return session_snapshot; -} - -TEST(ConvertXplaneToDcnCollectiveStats, - HasAllHostsDcnCollectiveStatsCacheFile) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(true, true); - - absl::StatusOr status = - HasDcnCollectiveStatsInMultiXSpace(session_snapshot); - 
EXPECT_EQ(status.value(), true); -} - -TEST(ConvertXplaneToDcnCollectiveStats, HasNoHostDcnCollectiveStatsCacheFile) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(true, false); - - absl::StatusOr status = - HasDcnCollectiveStatsInMultiXSpace(session_snapshot); - EXPECT_EQ(status.value(), false); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - NoCacheFileButTraceHasDcnCollectiveStats) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, true); - - absl::StatusOr status = - HasDcnCollectiveStatsInMultiXSpace(session_snapshot); - EXPECT_EQ(status.value(), true); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - NoCacheFileNoDcnCollectiveStatsPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, false); - - absl::StatusOr status = - HasDcnCollectiveStatsInMultiXSpace(session_snapshot); - EXPECT_EQ(status.value(), false); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - ConvertXSpaceToDcnCollectiveStatsWhenStatsPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, true); - - absl::StatusOr status = - ConvertMultiXSpaceToDcnCollectiveStats(session_snapshot); - absl::StatusOr> all_hosts_filepath = - session_snapshot.GetHostDataFilePath(StoredDataType::DCN_COLLECTIVE_STATS, - kAllHostsIdentifier); - absl::StatusOr> host_filepath = - session_snapshot.GetHostDataFilePath(StoredDataType::DCN_COLLECTIVE_STATS, - "hostname"); - - EXPECT_EQ(status.value(), true); - TF_EXPECT_OK(all_hosts_filepath.status()); - EXPECT_TRUE(all_hosts_filepath.value().has_value()); - EXPECT_FALSE(all_hosts_filepath.value().value().empty()); - TF_EXPECT_OK(host_filepath.status()); - EXPECT_TRUE(host_filepath.value().has_value()); - EXPECT_FALSE(host_filepath.value().value().empty()); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - ConvertXSpaceToDcnCollectiveStatsWhenStatsNotPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, false); - - absl::StatusOr status = - ConvertMultiXSpaceToDcnCollectiveStats(session_snapshot); - absl::StatusOr> filepath = - session_snapshot.GetHostDataFilePath(StoredDataType::DCN_COLLECTIVE_STATS, - kNoHostIdentifier); - - EXPECT_EQ(status.value(), false); - TF_EXPECT_OK(filepath.status()); - EXPECT_TRUE(filepath.value().has_value()); - EXPECT_FALSE(filepath.value().value().empty()); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - GetHostDcnSlackAnalysisWhenStatsNotPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, false); - - absl::StatusOr host_dcn_slack_analysis = - GetDcnSlackAnalysisByHostName(session_snapshot, "hostname"); - - TF_EXPECT_OK(host_dcn_slack_analysis.status()); - EXPECT_EQ(host_dcn_slack_analysis.value().dcn_slack_summary_size(), 0); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - GetHostDcnSlackAnalysisWhenStatsPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(true, true); - - absl::StatusOr host_dcn_slack_analysis = - GetDcnSlackAnalysisByHostName(session_snapshot, "hostname"); - - TF_EXPECT_OK(host_dcn_slack_analysis.status()); - EXPECT_EQ(host_dcn_slack_analysis.value().dcn_slack_summary_size(), 1); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_hlo.cc b/tensorflow/core/profiler/convert/xplane_to_hlo.cc deleted file mode 100644 index 62ee1c487b41a7..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_hlo.cc +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
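The fixture above stays hermetic by writing marker files to the in-process "ram://" file system; a minimal sketch of that pattern, reusing the same tsl::Env calls (the path below is illustrative):

#include <memory>
#include <string>

#include "absl/strings/str_cat.h"
#include "xla/tsl/platform/env.h"

// Creates an empty marker file under ram://, as CreateSessionSnapshot does.
void TouchMarkerFile(const std::string& dir, const std::string& name) {
  std::unique_ptr<tsl::WritableFile> file;
  tsl::Env::Default()
      ->NewAppendableFile(absl::StrCat(dir, name), &file)
      .IgnoreError();
}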
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_hlo.h" - -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "xla/tsl/platform/env.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/file_system_utils.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -using tsl::profiler::ProfilerJoinPath; - -constexpr char kNoModuleIdentifier[] = "NO_MODULE"; -constexpr char kHloProtoSuffix[] = ".hlo_proto.pb"; - -// Extracts and deduplicates the HLO protos from all the XSpaces. -// Stores the HLO protos as files in the same directory as the xspace files. -absl::StatusOr GetHloProtoFromMultiXSpaceAndSaveToFile( - const SessionSnapshot& session_snapshot) { - // Get all HLO protos from XSpaces and deduplicate. - HloProtoMap hlo_proto_map; - for (int i = 0; i < session_snapshot.XSpaceSize(); i++) { - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(i)); - hlo_proto_map.AddHloProtosFromXSpace(*xspace); - } - - std::vector module_list = hlo_proto_map.GetModuleList(); - // Write an empty identifier if there is no HLO module. - if (module_list.empty()) { - std::string file_name = - ProfilerJoinPath(session_snapshot.GetSessionRunDir(), - absl::StrCat(kNoModuleIdentifier, kHloProtoSuffix)); - xla::HloProto empty_hlo; - TF_RETURN_IF_ERROR( - tsl::WriteBinaryProto(tsl::Env::Default(), file_name, empty_hlo)); - // The profile does not have HLO proto. - return false; - } - - // Save HLO protos to session run directory. - for (const absl::string_view module_name : module_list) { - auto hlo_proto_or = hlo_proto_map.GetHloProtoByModuleName(module_name); - if (!hlo_proto_or.ok()) { - return errors::Internal(hlo_proto_or.status().message()); - } - std::string file_name = - ProfilerJoinPath(session_snapshot.GetSessionRunDir(), - absl::StrCat(module_name, kHloProtoSuffix)); - TF_RETURN_IF_ERROR(tsl::WriteBinaryProto(tsl::Env::Default(), file_name, - *hlo_proto_or.value())); - } - - // The profile has HLO proto. 
- return true; -} - -} // namespace - -absl::StatusOr GetHloProtoByModuleName( - const SessionSnapshot& session_snapshot, - const absl::string_view module_name) { - std::string file_name = - ProfilerJoinPath(session_snapshot.GetSessionRunDir(), - absl::StrCat(module_name, kHloProtoSuffix)); - xla::HloProto hlo_proto; - TF_RETURN_IF_ERROR( - tsl::ReadBinaryProto(tsl::Env::Default(), file_name, &hlo_proto)); - return hlo_proto; -} - -absl::StatusOr ConvertMultiXSpaceToHloProto( - const SessionSnapshot& session_snapshot) { - // Gets all the files in session run directory. - // TODO(profiler): Move this glob to SessionSnapshot and build a map from file - // type to file paths. - std::vector results; - TF_RETURN_IF_ERROR(tsl::Env::Default()->GetChildren( - std::string(session_snapshot.GetSessionRunDir()), &results)); - - // If the profiler finds a filename with hlo proto suffix, this means HLO - // proto was already generated previously. - for (const std::string& path : results) { - if (absl::EndsWith(path, kHloProtoSuffix)) { - if (absl::EndsWith(path, - absl::StrCat(kNoModuleIdentifier, kHloProtoSuffix))) { - return false; - } else { - return true; - } - } - } - - // Generate HLO proto. - // TODO(jiesun): Maybe generate a tag file at profile collection time, so - // don't need to read XSpace files for checking whether HLO proto exists or - // not. - return GetHloProtoFromMultiXSpaceAndSaveToFile(session_snapshot); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_hlo.h b/tensorflow/core/profiler/convert/xplane_to_hlo.h deleted file mode 100644 index c102f6a24f6a78..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_hlo.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_HLO_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_HLO_H_ - -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" - -namespace tensorflow { -namespace profiler { - -// Get HLO proto by module name. -absl::StatusOr GetHloProtoByModuleName( - const SessionSnapshot& session_snapshot, absl::string_view module_name); - -// Converts multiple XSpaces to HLO protos. -// Stores the HLO protos as files in the same directory as the xspace files. -// Returns whether there are HLO protos in this profile. 
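ConvertMultiXSpaceToHloProto above short-circuits on files written by a previous run; condensed, the scan is a three-way decision. A standard-library-only sketch:

#include <string>
#include <vector>

// +1 if real HLO protos were already extracted, 0 if only the NO_MODULE
// marker exists, -1 if the run directory has not been processed yet.
int CachedHloState(const std::vector<std::string>& children) {
  auto ends_with = [](const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
  };
  for (const std::string& path : children) {
    if (!ends_with(path, ".hlo_proto.pb")) continue;
    return ends_with(path, "NO_MODULE.hlo_proto.pb") ? 0 : 1;
  }
  return -1;
}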
-absl::StatusOr ConvertMultiXSpaceToHloProto( - const SessionSnapshot& session_snapshot); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_HLO_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc deleted file mode 100644 index a6aefc3fe7e188..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h" - -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/kernel_stats.pb.h" // from @org_xprof -#include "xprof/utils/gpu_event_stats.h" // from @org_xprof -#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -void ConvertDeviceTraceXPlaneToKernelReports( - const XPlane& device_trace, - const std::function& - on_kernel_fn, - KernelReportMap* reports) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); - plane.ForEachLine([&](const XLineVisitor& line) { - if (IsDerivedThreadId(line.Id())) { - return; - } - line.ForEachEvent([&](const XEventVisitor& event) { - if (event.DurationNs() == 0) return; - KernelReport kernel; - GpuEventStats stats(&event); - if (!stats.IsKernel()) return; - - kernel.set_name(std::string(event.Name())); - kernel.set_is_kernel_using_tensor_core( - IsKernelUsingTensorCore(event.Name())); - kernel.set_total_duration_ns(event.DurationNs()); - kernel.set_min_duration_ns(event.DurationNs()); - kernel.set_max_duration_ns(event.DurationNs()); - ParseKernelLaunchParams(stats.kernel_details, &kernel); - - if (stats.IsTfOp()) { - tsl::profiler::TfOp tf_op = - tsl::profiler::ParseTfOpFullname(stats.tf_op_fullname); - kernel.set_op_name(std::string(tf_op.name)); - bool tensor_core_eligible = - IsEinsumTensorCoreEligible(stats.equation) || - IsOpTensorCoreEligible(kernel.op_name()); - if (!tensor_core_eligible && kernel.is_kernel_using_tensor_core()) { - VLOG(1) << "Detected new Op using TensorCores: " << kernel.op_name() - << std::endl; - tensor_core_eligible = true; - } - kernel.set_is_op_tensor_core_eligible(tensor_core_eligible); - } - - if (on_kernel_fn) { - on_kernel_fn(stats, &kernel); - } - - KernelReportValue value; - value.total_duration_ns = event.DurationNs(); - value.min_duration_ns = event.DurationNs(); - 
value.max_duration_ns = event.DurationNs(); - value.occurrences = 1; - InsertOrUpdateKernelReport(kernel, value, reports); - }); - }); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h deleted file mode 100644 index 4c003a325de2ee..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_ - -#include -#include - -#include "absl/log/log.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/kernel_stats.pb.h" // from @org_xprof -#include "xprof/utils/gpu_event_stats.h" // from @org_xprof -#include "xprof/utils/hlo_module_map.h" // from @org_xprof -#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -void ConvertDeviceTraceXPlaneToKernelReports( - const XPlane& device_trace, - const std::function& - on_kernel_fn, - KernelReportMap* reports); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc deleted file mode 100644 index 3e30beeffb41b9..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
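Each kernel event above contributes a single-occurrence value; InsertOrUpdateKernelReport (from kernel_stats_utils, not shown in this patch) then folds duplicates per kernel, roughly as in this sketch with a stand-in struct:

#include <algorithm>
#include <cstdint>

// Stand-in for KernelReportValue; the real type lives in kernel_stats_utils.h.
struct ReportValue {
  uint64_t total_duration_ns = 0;
  uint64_t min_duration_ns = 0;
  uint64_t max_duration_ns = 0;
  uint64_t occurrences = 0;
};

// Rough sketch of the per-kernel fold: totals add up, min/max widen.
void MergeReportValue(const ReportValue& src, ReportValue* dst) {
  dst->total_duration_ns += src.total_duration_ns;
  dst->min_duration_ns =
      dst->occurrences == 0 ? src.min_duration_ns
                            : std::min(dst->min_duration_ns, src.min_duration_ns);
  dst->max_duration_ns = std::max(dst->max_duration_ns, src.max_duration_ns);
  dst->occurrences += src.occurrences;
}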
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h" - -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/kernel_stats.pb.h" // from @org_xprof -#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -TEST(ConvertXplaneToKernelStats, MultiKernels) { - XSpace space; - XPlane* device_trace = space.add_planes(); - XPlaneBuilder device_trace_builder(device_trace); - - // Empty default stream - device_trace_builder.GetOrCreateLine(0); - - XLineBuilder line_builder = device_trace_builder.GetOrCreateLine(0); - CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_shortest", - /*offset_ps=*/10000, /*duration_ps=*/1000, - {{StatType::kTfOp, "mul_786"}, - {StatType::kKernelDetails, R"MULTI(regs:16 -static_shared:0 -dynamic_shared:0 -grid:1,1,1 -block:1,1,1 -occ_pct:50.0)MULTI"}, - {StatType::kEquation, ""}}); - - CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_middle", - /*offset_ps=*/20000, /*duration_ps=*/2000, - {{StatType::kTfOp, "Conv2D"}, - {StatType::kKernelDetails, R"MULTI(regs:32 -static_shared:0 -dynamic_shared:16384 -grid:2,1,1 -block:32,1,1 -occ_pct=13.0)MULTI"}, - {StatType::kEquation, ""}}); - - CreateXEvent(&device_trace_builder, &line_builder, - "volta_fp16_s884gemm_fp16_128x128_ldg8_f2f_tn", - /*offset_ps=*/30000, /*duration_ps=*/3000, - {{StatType::kTfOp, "Einsum_80"}, - {StatType::kKernelDetails, R"MULTI(regs:32 -static_shared:0 -dynamic_shared:16384 -grid:3,1,1 -block:64,1,1 -occ_pct:25.0)MULTI"}, - {StatType::kEquation, ""}}); - - KernelReportMap reports; - ConvertDeviceTraceXPlaneToKernelReports(*device_trace, {}, &reports); - KernelStatsDb kernel_stats; - CopyTopKDurationKernelReportsToDb(reports, &kernel_stats); - - EXPECT_EQ(kernel_stats.reports_size(), 3); - - { - const auto& kernel = kernel_stats.reports().at(2); - EXPECT_EQ(kernel.name(), "kernel_name_shortest"); - EXPECT_EQ(kernel.registers_per_thread(), 16); - EXPECT_EQ(kernel.static_shmem_bytes(), 0); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 0); - EXPECT_EQ(kernel.grid_dim().at(0), 1); - EXPECT_EQ(kernel.grid_dim().at(1), 1); - EXPECT_EQ(kernel.grid_dim().at(2), 1); - EXPECT_EQ(kernel.block_dim().at(0), 1); - EXPECT_EQ(kernel.block_dim().at(1), 1); - EXPECT_EQ(kernel.block_dim().at(2), 1); - EXPECT_EQ(kernel.total_duration_ns(), 1); - EXPECT_FALSE(kernel.is_kernel_using_tensor_core()); - EXPECT_FALSE(kernel.is_op_tensor_core_eligible()); - EXPECT_EQ(kernel.op_name(), "mul_786"); - } - - { - const auto& kernel = kernel_stats.reports().at(1); - EXPECT_EQ(kernel.name(), "kernel_name_middle"); - EXPECT_EQ(kernel.registers_per_thread(), 32); - EXPECT_EQ(kernel.static_shmem_bytes(), 0); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 16384); - EXPECT_EQ(kernel.grid_dim().at(0), 2); - EXPECT_EQ(kernel.grid_dim().at(1), 1); - EXPECT_EQ(kernel.grid_dim().at(2), 1); - EXPECT_EQ(kernel.block_dim().at(0), 32); - EXPECT_EQ(kernel.block_dim().at(1), 1); - EXPECT_EQ(kernel.block_dim().at(2), 1); - EXPECT_EQ(kernel.total_duration_ns(), 2); - 
EXPECT_FALSE(kernel.is_kernel_using_tensor_core()); - EXPECT_TRUE(kernel.is_op_tensor_core_eligible()); - EXPECT_EQ(kernel.op_name(), "Conv2D"); - } - - { - const auto& kernel = kernel_stats.reports().at(0); - EXPECT_EQ(kernel.name(), "volta_fp16_s884gemm_fp16_128x128_ldg8_f2f_tn"); - EXPECT_EQ(kernel.registers_per_thread(), 32); - EXPECT_EQ(kernel.static_shmem_bytes(), 0); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 16384); - EXPECT_EQ(kernel.grid_dim().at(0), 3); - EXPECT_EQ(kernel.grid_dim().at(1), 1); - EXPECT_EQ(kernel.grid_dim().at(2), 1); - EXPECT_EQ(kernel.block_dim().at(0), 64); - EXPECT_EQ(kernel.block_dim().at(1), 1); - EXPECT_EQ(kernel.block_dim().at(2), 1); - EXPECT_EQ(kernel.total_duration_ns(), 3); - EXPECT_TRUE(kernel.is_kernel_using_tensor_core()); - EXPECT_TRUE(kernel.is_op_tensor_core_eligible()); - EXPECT_EQ(kernel.op_name(), "Einsum_80"); - } -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc deleted file mode 100644 index 3d60ee248ee48d..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc +++ /dev/null @@ -1,573 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/platform/protobuf.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/memory_profile.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -namespace { - -constexpr int64_t kInvalidStepId = -1; - -// Index of the time-sorted memory_profile_snapshots list, and the -// MemoryActivityMetadata proto it contains. 
-using IndexMetaPair =
-    std::pair<int64_t /*index*/, const MemoryActivityMetadata*>;
-
-bool IsMemoryAllocation(int64_t event_type) {
-  return event_type == HostEventType::kMemoryAllocation;
-}
-
-bool IsMemoryDeallocation(int64_t event_type) {
-  return event_type == HostEventType::kMemoryDeallocation;
-}
-
-void UpdateProfileSummary(const MemoryAggregationStats& stats,
-                          int64_t time_offset_ps,
-                          MemoryProfileSummary* summary) {
-  // Update the peak memory usage over allocator's lifetime.
-  summary->set_peak_bytes_usage_lifetime(stats.peak_bytes_in_use());
-  MemoryAggregationStats* peak_stats = summary->mutable_peak_stats();
-  // If we reach (or stay at) peak memory usage within the profiling window,
-  // update memory profile summary.
-  if (stats.stack_reserved_bytes() + stats.heap_allocated_bytes() >=
-      peak_stats->peak_bytes_in_use()) {
-    *peak_stats = stats;
-    peak_stats->set_peak_bytes_in_use(stats.stack_reserved_bytes() +
-                                      stats.heap_allocated_bytes());
-    summary->set_peak_stats_time_ps(time_offset_ps);
-    summary->set_memory_capacity(stats.stack_reserved_bytes() +
-                                 stats.heap_allocated_bytes() +
-                                 stats.free_memory_bytes());
-  }
-}
-
-// Generate memory profile proto by processing host trace XPlane.
-MemoryProfile GenerateMemoryProfile(const XPlane* host_trace) {
-  XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace);
-  MemoryProfile memory_profile;
-  // Iterate over all XEvents in the XPlane, and add the XStats to a new
-  // MemoryProfileSnapshot if the EventType is kMemoryAllocation or
-  // kMemoryDeallocation.
-  plane.ForEachLine([&](const XLineVisitor& line) {
-    line.ForEachEvent([&](const XEventVisitor& event) {
-      int64_t event_type =
-          event.Type().value_or(HostEventType::kUnknownHostEventType);
-      if (!(IsMemoryAllocation(event_type) ||
-            IsMemoryDeallocation(event_type))) {
-        return;
-      }
-
-      MemoryAggregationStats stats;
-      MemoryActivityMetadata metadata;
-      if (IsMemoryAllocation(event_type)) {
-        metadata.set_memory_activity(ALLOCATION);
-      } else if (IsMemoryDeallocation(event_type)) {
-        metadata.set_memory_activity(DEALLOCATION);
-      }
-      metadata.set_step_id(kInvalidStepId);
-
-      std::string memory_id;
-      event.ForEachStat([&](const XStatVisitor& stat) {
-        if (!stat.Type().has_value()) return;
-        switch (stat.Type().value()) {
-          case StatType::kIndexOnHost:
-          case StatType::kDeviceOrdinal:
-            memory_id = absl::StrCat(stat.IntValue());
-            break;
-          case StatType::kAllocatorName:
-            memory_id = std::string(stat.StrOrRefValue());
-            break;
-          case StatType::kBytesReserved:
-            stats.set_stack_reserved_bytes(stat.IntValue());
-            break;
-          case StatType::kBytesAllocated:
-            stats.set_heap_allocated_bytes(stat.IntValue());
-            break;
-          case StatType::kBytesAvailable:
-            stats.set_free_memory_bytes(stat.IntValue());
-            break;
-          case StatType::kFragmentation:
-            stats.set_fragmentation(stat.DoubleValue());
-            break;
-          case StatType::kPeakBytesInUse:
-            stats.set_peak_bytes_in_use(stat.IntValue());
-            break;
-          case StatType::kRequestedBytes:
-            metadata.set_requested_bytes(stat.IntValue());
-            break;
-          case StatType::kAllocationBytes:
-            metadata.set_allocation_bytes(stat.IntValue());
-            break;
-          case StatType::kAddress:
-            metadata.set_address(stat.IntValue());
-            break;
-          case StatType::kTfOp:
-            metadata.set_tf_op_name(std::string(stat.StrOrRefValue()));
-            break;
-          case StatType::kGroupId:
-            metadata.set_step_id(stat.IntValue());
-            break;
-          case StatType::kRegionType:
-            metadata.set_region_type(std::string(stat.StrOrRefValue()));
-            break;
-          case StatType::kDataType:
-            metadata.set_data_type(tensorflow::DataTypeString(
-                static_cast<DataType>(stat.IntValue())));
-            break;
-          case StatType::kTensorShapes:
-            metadata.set_tensor_shape(std::string(stat.StrOrRefValue()));
-            break;
-        }
-      });
-
-      MemoryProfileSummary* summary =
-          (*memory_profile.mutable_memory_profile_per_allocator())[memory_id]
-              .mutable_profile_summary();
-      UpdateProfileSummary(stats, event.OffsetPs(), summary);
-
-      MemoryProfileSnapshot* snapshot =
-          (*memory_profile.mutable_memory_profile_per_allocator())[memory_id]
-              .add_memory_profile_snapshots();
-      snapshot->set_time_offset_ps(event.OffsetPs());
-      *snapshot->mutable_aggregation_stats() = std::move(stats);
-      *snapshot->mutable_activity_metadata() = std::move(metadata);
-    });
-  });
-  return memory_profile;
-}
-
-// Fix invalid step ids of snapshots at the beginning/end of the profile or at
-// the step boundaries. The snapshots with invalid step ids at the beginning
-// get 0 for their step ids. Those at the step boundaries or at the end get
-// the previous snapshot's step id + 1.
-void UpdateStepId(PerAllocatorMemoryProfile* memory_profile) {
-  int64_t last_valid_step_id = -1;
-  // Snapshots are already sorted in time.
-  for (auto& snapshot : *memory_profile->mutable_memory_profile_snapshots()) {
-    DCHECK(snapshot.has_activity_metadata());
-    if (snapshot.mutable_activity_metadata()->step_id() == kInvalidStepId) {
-      snapshot.mutable_activity_metadata()->set_step_id(last_valid_step_id + 1);
-    } else {
-      last_valid_step_id = snapshot.mutable_activity_metadata()->step_id();
-    }
-  }
-}
-
-// Update the MemoryActivityMetadata for each deallocation event by copying
-// from matching allocation.
-void UpdateDeallocation(PerAllocatorMemoryProfile* memory_profile) {
-  absl::flat_hash_map<uint64, const MemoryActivityMetadata*>
-      addr_metadata_map;
-  for (auto& snapshot : *memory_profile->mutable_memory_profile_snapshots()) {
-    // Match the deallocation with previous allocation based on address.
-    uint64 address = snapshot.activity_metadata().address();
-    if (snapshot.activity_metadata().memory_activity() == DEALLOCATION) {
-      if (addr_metadata_map.contains(address)) {
-        const MemoryActivityMetadata* alloc_meta = addr_metadata_map[address];
-        snapshot.mutable_activity_metadata()->set_tf_op_name(
-            alloc_meta->tf_op_name());
-        snapshot.mutable_activity_metadata()->set_region_type(
-            alloc_meta->region_type());
-        snapshot.mutable_activity_metadata()->set_data_type(
-            alloc_meta->data_type());
-        snapshot.mutable_activity_metadata()->set_tensor_shape(
-            alloc_meta->tensor_shape());
-        // In case of following (unexpected) deallocations to the same chunk
-        // address, leave the metadata as it is (empty or already captured).
-        addr_metadata_map.erase(address);
-      } else {
-        VLOG(2)
-            << "Can't find matching memory allocation for this deallocation: "
-            << snapshot.DebugString();
-      }
-    } else if (!addr_metadata_map.contains(address)) {  // Allocation.
-      addr_metadata_map[address] = &snapshot.activity_metadata();
-    } else {
-      VLOG(2) << "There are two allocations recorded for the same address: "
-              << address
-              << ". The later allocation event is: " << snapshot.DebugString();
-    }
-  }
-  VLOG(2) << "Number of allocations that cannot find matching deallocations: "
-          << addr_metadata_map.size();
-}
-
-// Return the step id for the peak memory usage data point.
-int64_t GetPeakMemoryStep(int64_t peak_bytes_profile,
-                          const PerAllocatorMemoryProfile* memory_profile) {
-  int64_t peak_bytes_profile_step_id = 0;
-  for (const auto& snapshot : memory_profile->memory_profile_snapshots()) {
-    // Get the step id of the peak memory usage.
-    if (peak_bytes_profile ==
-        snapshot.aggregation_stats().heap_allocated_bytes() +
-            snapshot.aggregation_stats().stack_reserved_bytes()) {
-      DCHECK(snapshot.has_activity_metadata());
-      peak_bytes_profile_step_id = snapshot.activity_metadata().step_id();
-    }
-  }
-  return peak_bytes_profile_step_id;
-}
-
-// Functor that compares (index, metadata) pair to sort in the order of
-// allocation bytes and requested bytes (descending), as well as TF Op name,
-// region type, data type, and tensor shape (ascending).
-struct MetadataComparator {
-  bool operator()(const IndexMetaPair& a, const IndexMetaPair& b) const {
-    const MemoryActivityMetadata* a_meta = a.second;
-    const MemoryActivityMetadata* b_meta = b.second;
-    DCHECK_NE(a_meta, nullptr);
-    DCHECK_NE(b_meta, nullptr);
-
-    auto lhs =
-        std::make_tuple(-a_meta->allocation_bytes(), -a_meta->requested_bytes(),
-                        a_meta->tf_op_name(), a_meta->region_type(),
-                        a_meta->data_type(), a_meta->tensor_shape());
-    auto rhs =
-        std::make_tuple(-b_meta->allocation_bytes(), -b_meta->requested_bytes(),
-                        b_meta->tf_op_name(), b_meta->region_type(),
-                        b_meta->data_type(), b_meta->tensor_shape());
-    return lhs < rhs;
-  }
-};
-
-// If applicable, add items into active_allocs vector and special_allocations
-// proto for the unmapped memory usage (in heap) and stack reservation at peak.
-void InsertSpecialAllocations(int64_t unmapped_allocation_bytes,
-                              int64_t step_id,
-                              PerAllocatorMemoryProfile* memory_profile,
-                              std::vector<IndexMetaPair>* active_allocs) {
-  int index = 0;
-  if (unmapped_allocation_bytes > 0) {
-    MemoryActivityMetadata* special_allocation =
-        memory_profile->add_special_allocations();
-    special_allocation->set_memory_activity(ALLOCATION);
-    special_allocation->set_requested_bytes(unmapped_allocation_bytes);
-    special_allocation->set_allocation_bytes(unmapped_allocation_bytes);
-    special_allocation->set_address(0);
-    special_allocation->set_tf_op_name("unused preallocated device memory");
-    special_allocation->set_step_id(step_id);
-    special_allocation->set_region_type("persist/dynamic");
-    special_allocation->set_data_type(
-        tensorflow::DataTypeString(static_cast<DataType>(0)));
-    special_allocation->set_tensor_shape("unknown");
-    active_allocs->push_back({--index, special_allocation});
-  }
-  int64_t stack_bytes =
-      memory_profile->profile_summary().peak_stats().stack_reserved_bytes();
-  if (stack_bytes > 0) {
-    MemoryActivityMetadata* special_allocation =
-        memory_profile->add_special_allocations();
-    special_allocation->set_memory_activity(ALLOCATION);
-    special_allocation->set_requested_bytes(stack_bytes);
-    special_allocation->set_allocation_bytes(stack_bytes);
-    special_allocation->set_address(0);
-    special_allocation->set_tf_op_name("stack");
-    special_allocation->set_step_id(step_id);
-    special_allocation->set_region_type("stack");
-    special_allocation->set_data_type(
-        tensorflow::DataTypeString(static_cast<DataType>(0)));
-    special_allocation->set_tensor_shape("unknown");
-    active_allocs->push_back({--index, special_allocation});
-  }
-}
-
-bool operator==(const IndexMetaPair& a, const IndexMetaPair& b) {
-  const MemoryActivityMetadata* a_meta = a.second;
-  const MemoryActivityMetadata* b_meta = b.second;
-  return a_meta->allocation_bytes() == b_meta->allocation_bytes() &&
-         a_meta->requested_bytes() == b_meta->requested_bytes() &&
-         a_meta->tf_op_name() == b_meta->tf_op_name() &&
-         a_meta->region_type() == b_meta->region_type() &&
-         a_meta->data_type() == b_meta->data_type() &&
-         a_meta->tensor_shape() == b_meta->tensor_shape();
-}
-// Generate the memory breakdown table of active allocations at the peak usage
-// (within profiling window) and fill each ActiveAllocation proto (i.e. a row).
-void ProcessActiveAllocations(int64_t peak_bytes_profile_step_id,
-                              PerAllocatorMemoryProfile* memory_profile) {
-  int64_t unmapped_allocation_bytes =
-      memory_profile->profile_summary().peak_stats().heap_allocated_bytes();
-  int64_t unmapped_deallocation_bytes = 0;
-  absl::flat_hash_map<uint64, IndexMetaPair> active_alloc_map;
-  // Only account for the memory activities in the step that includes peak
-  // memory usage.
-  for (int i = 0; i < memory_profile->memory_profile_snapshots_size(); i++) {
-    const auto& snapshot = memory_profile->memory_profile_snapshots().at(i);
-    DCHECK(snapshot.has_activity_metadata());
-    const MemoryActivityMetadata& metadata = snapshot.activity_metadata();
-    if (snapshot.time_offset_ps() >
-        memory_profile->profile_summary().peak_stats_time_ps())
-      break;
-    if (metadata.step_id() != peak_bytes_profile_step_id) continue;
-
-    if (metadata.memory_activity() == ALLOCATION) {
-      active_alloc_map[metadata.address()] = {i, &metadata};
-      unmapped_allocation_bytes -= metadata.allocation_bytes();
-    } else {
-      DCHECK_EQ(metadata.memory_activity(), DEALLOCATION);
-      if (active_alloc_map.contains(metadata.address())) {
-        active_alloc_map.erase(metadata.address());
-      } else {
-        unmapped_deallocation_bytes += metadata.allocation_bytes();
-      }
-      unmapped_allocation_bytes += metadata.allocation_bytes();
-    }
-  }
-  // This separates the persistent memory from the freed memory from last
-  // step's allocations.
-  unmapped_allocation_bytes -= unmapped_deallocation_bytes;
-
-  VLOG(2) << "unmapped_allocation_bytes=" << unmapped_allocation_bytes
-          << ", unmapped_deallocation_bytes=" << unmapped_deallocation_bytes;
-
-  // Using pair of (index, MemoryActivityMetadata*) so that we can sort by the
-  // metadata, and fetch metadata by indexing the time-sorted snapshots at
-  // frontend.
-  std::vector<IndexMetaPair> active_allocs;
-  for (const auto& address_and_index_meta : active_alloc_map) {
-    active_allocs.push_back(address_and_index_meta.second);
-  }
-
-  InsertSpecialAllocations(unmapped_allocation_bytes,
-                           peak_bytes_profile_step_id, memory_profile,
-                           &active_allocs);
-
-  std::sort(active_allocs.begin(), active_allocs.end(), MetadataComparator());
-
-  // Fill the sorted active_allocations proto messages at peak memory usage.
-  // Merge identical allocations and show occurrences.
-  for (int i = 0, end = active_allocs.size(); i < end; i++) {
-    ActiveAllocation* allocation = memory_profile->add_active_allocations();
-    allocation->set_snapshot_index(active_allocs[i].first);
-    if (active_allocs[i].first < 0) {
-      allocation->set_special_index(-active_allocs[i].first - 1);
-    } else {
-      allocation->set_special_index(-1);
-    }
-    allocation->set_num_occurrences(1);
-    const int last_alloc = active_allocs.size() - 1;
-    while (i < last_alloc && active_allocs[i] == active_allocs[i + 1]) {
-      allocation->set_num_occurrences(allocation->num_occurrences() + 1);
-      i++;
-    }
-  }
-
-  VLOG(2) << "Distinctive active allocation count="
-          << memory_profile->active_allocations_size();
-}
-
-// This function saves only the MemoryProfileSnapshots referenced by
-// <active_allocations>.
-void SaveActiveAllocationSnapshots(
-    tsl::protobuf::RepeatedPtrField<MemoryProfileSnapshot>* snapshots,
-    tsl::protobuf::RepeatedPtrField<ActiveAllocation>* active_allocations) {
-  std::vector<MemoryProfileSnapshot*> samples;
-  // Puts the snapshots referenced by <active_allocations> in <samples>.
-// This function saves the MemoryProfileSnapshots referenced by
-// <active_allocations>.
-void SaveActiveAllocationSnapshots(
-    tsl::protobuf::RepeatedPtrField<MemoryProfileSnapshot>* snapshots,
-    tsl::protobuf::RepeatedPtrField<ActiveAllocation>* active_allocations) {
-  std::vector<MemoryProfileSnapshot*> samples;
-  // Puts the snapshots referenced by active_allocations in <samples>.
-  for (const auto& allocation : *active_allocations) {
-    auto orig_index = allocation.snapshot_index();
-    if (orig_index < 0) continue;
-    samples.push_back(&(*snapshots)[orig_index]);
-  }
-
-  // Change the reference index in <active_allocations>.
-  int new_index = 0;
-  for (auto& allocation : *active_allocations) {
-    int64_t origin_index = allocation.snapshot_index();
-    if (origin_index < 0) continue;
-    allocation.set_snapshot_index(new_index);
-    new_index++;
-  }
-
-  tsl::protobuf::RepeatedPtrField<MemoryProfileSnapshot> new_snapshots;
-  new_snapshots.Reserve(samples.size());
-  for (const auto& sample : samples) {
-    *new_snapshots.Add() = std::move(*sample);
-  }
-  *snapshots = std::move(new_snapshots);
-}
-
-// Sample <max_num_snapshots> memory profile snapshots from the original memory
-// profile data.
-void SampleMemoryProfileTimeline(int64_t max_num_snapshots,
-                                 PerAllocatorMemoryProfile* memory_profile) {
-  const tsl::protobuf::RepeatedPtrField<MemoryProfileSnapshot>&
-      original_snapshots = memory_profile->memory_profile_snapshots();
-  tsl::protobuf::RepeatedPtrField<MemoryProfileSnapshot>* timeline_snapshots =
-      memory_profile->mutable_sampled_timeline_snapshots();
-  int64_t snapshot_count = original_snapshots.size();
-  if (snapshot_count > max_num_snapshots) {
-    // When there are more memory profile data than <max_num_snapshots>, we
-    // sample the original data using a max box filter. Filter width is
-    // <filter_width>, and <count> samples are collected starting from the
-    // <start> index in the original snapshots.
-    auto max_box_filter = [&](int filter_width, int count, int start) {
-      for (int i = 0; i < count; i++) {
-        // Use a max function to get the MemoryProfileSnapshot with the largest
-        // memory usage in the box filter.
-        const MemoryProfileSnapshot* max_snapshot =
-            &original_snapshots[start + filter_width * i];
-        int64_t max_bytes =
-            max_snapshot->aggregation_stats().heap_allocated_bytes() +
-            max_snapshot->aggregation_stats().stack_reserved_bytes();
-        for (int index = start + filter_width * i + 1;
-             index < start + filter_width * (i + 1); index++) {
-          int64_t bytes = original_snapshots[index]
-                              .aggregation_stats()
-                              .heap_allocated_bytes() +
-                          original_snapshots[index]
-                              .aggregation_stats()
-                              .stack_reserved_bytes();
-          if (bytes > max_bytes) {
-            max_snapshot = &original_snapshots[index];
-            max_bytes = bytes;
-          }
-        }
-        *timeline_snapshots->Add() = *max_snapshot;
-      }
-    };
-
-    int width = snapshot_count / max_num_snapshots;
-    int count1 = max_num_snapshots * (width + 1) - snapshot_count;
-    int count2 = max_num_snapshots - count1;
-
-    // Collect <count1> samples with box filter width <width>, then collect
-    // <count2> samples with box filter width <width + 1>; the total number of
-    // samples collected will be <max_num_snapshots>.
-    max_box_filter(width, count1, 0);
-    max_box_filter(width + 1, count2, width * count1);
-  } else {
-    // When the number of original snapshots is smaller than
-    // <max_num_snapshots>, just copy all the data points to the timeline.
-    *timeline_snapshots = original_snapshots;
-  }
-}
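// Aside: the width/count1/count2 arithmetic above partitions the snapshots
// exactly. A small, stand-alone check of that invariant (illustration only):
#include <cassert>
#include <cstdint>

inline void CheckBoxFilterPartition(int64_t snapshot_count,
                                    int64_t max_num_snapshots) {
  const int64_t width = snapshot_count / max_num_snapshots;
  const int64_t count1 = max_num_snapshots * (width + 1) - snapshot_count;
  const int64_t count2 = max_num_snapshots - count1;
  // E.g. snapshot_count = 10, max_num_snapshots = 4: width = 2,
  // count1 = 4 * 3 - 10 = 2, count2 = 2, and 2*2 + 2*3 = 10. The count1
  // buckets of width `width` plus count2 buckets of width `width + 1` cover
  // every original snapshot exactly once.
  assert(count1 * width + count2 * (width + 1) == snapshot_count);
}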
-// Post-process the memory profile to correctly update proto fields, and break
-// down peak memory usage for each allocator.
-void ProcessMemoryProfileProto(int64_t max_num_snapshots,
-                               MemoryProfile* memory_profile) {
-  memory_profile->set_num_hosts(1);
-  // Add sorted memory ids within memory profile data to the selection list.
-  for (const auto& id_and_allocator_profile :
-       memory_profile->memory_profile_per_allocator()) {
-    if (!id_and_allocator_profile.second.memory_profile_snapshots().empty()) {
-      memory_profile->add_memory_ids(id_and_allocator_profile.first);
-    }
-  }
-  absl::c_sort(*memory_profile->mutable_memory_ids());
-
-  for (auto& id_and_allocator_profile :
-       *memory_profile->mutable_memory_profile_per_allocator()) {
-    PerAllocatorMemoryProfile* allocator_memory_profile =
-        &id_and_allocator_profile.second;
-    tsl::protobuf::RepeatedPtrField<MemoryProfileSnapshot>* snapshots =
-        allocator_memory_profile->mutable_memory_profile_snapshots();
-    // Sort the memory_profile_snapshots by time_offset_ps (ascending) in
-    // proto.
-    absl::c_sort(*snapshots, [](const MemoryProfileSnapshot& a,
-                                const MemoryProfileSnapshot& b) {
-      return a.time_offset_ps() < b.time_offset_ps();
-    });
-
-    UpdateStepId(allocator_memory_profile);
-    UpdateDeallocation(allocator_memory_profile);
-
-    // Sample a subset of MemoryProfileSnapshots to display in the frontend
-    // memory timeline graph.
-    SampleMemoryProfileTimeline(max_num_snapshots, allocator_memory_profile);
-
-    int64_t peak_step_id =
-        GetPeakMemoryStep(allocator_memory_profile->profile_summary()
-                              .peak_stats()
-                              .peak_bytes_in_use(),
-                          allocator_memory_profile);
-    ProcessActiveAllocations(peak_step_id, allocator_memory_profile);
-    SaveActiveAllocationSnapshots(
-        snapshots, allocator_memory_profile->mutable_active_allocations());
-  }
-}
-
-template <typename Proto>
-absl::Status ConvertProtoToJson(const Proto& proto_output,
-                                std::string* json_output) {
-  tsl::protobuf::util::JsonPrintOptions json_options;
-  json_options.always_print_primitive_fields = true;
-  auto status = tsl::protobuf::util::MessageToJsonString(
-      proto_output, json_output, json_options);
-  if (!status.ok()) {
-    // Convert error_msg google::protobuf::StringPiece (or absl::string_view)
-    // to tensorflow::StringPiece.
-    auto error_msg = status.message();
-    return errors::Internal(
-        "Could not convert proto to JSON string: ",
-        absl::string_view(error_msg.data(), error_msg.length()));
-  }
-  return absl::OkStatus();
-}
-
-}  // namespace
-
-MemoryProfile ConvertXPlaneToMemoryProfile(const XPlane& host_plane,
-                                           int64_t max_num_snapshots) {
-  MemoryProfile memory_profile = GenerateMemoryProfile(&host_plane);
-  ProcessMemoryProfileProto(max_num_snapshots, &memory_profile);
-  // Default version number is 0, set version number to 1 here due to the new
-  // memory profile sampling algorithm.
-  memory_profile.set_version(1);
-  return memory_profile;
-}
-
-absl::Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace,
-                                              std::string* json_output) {
-  if (const XPlane* host_plane =
-          FindPlaneWithName(xspace, kHostThreadsPlaneName)) {
-    MemoryProfile memory_profile = ConvertXPlaneToMemoryProfile(*host_plane);
-    TF_RETURN_IF_ERROR(ConvertProtoToJson(memory_profile, json_output));
-  }
-  return absl::OkStatus();
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.h b/tensorflow/core/profiler/convert/xplane_to_memory_profile.h
deleted file mode 100644
index bc6ceef062132d..00000000000000
--- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_MEMORY_PROFILE_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_MEMORY_PROFILE_H_
-
-#include <string>
-
-#include "absl/status/status.h"
-#include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/memory_profile.pb.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-// Process the host threads XPlane and generate MemoryProfile result; at most
-// max_num_snapshots will be displayed on the UI.
-// REQUIRED: host_plane should have been grouped by calling GroupTfEvents().
-MemoryProfile ConvertXPlaneToMemoryProfile(const XPlane& host_plane,
-                                           int64_t max_num_snapshots = 1000);
-
-absl::Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace,
-                                              std::string* json_output);
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_MEMORY_PROFILE_H_
diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc
deleted file mode 100644
index 512e81b3b740bd..00000000000000
--- a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc
+++ /dev/null
@@ -1,129 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h"
-
-#include
-
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/group_events.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h"
-#include "tensorflow/core/profiler/utils/xplane_builder.h"
-#include "tensorflow/core/profiler/utils/xplane_schema.h"
-#include "tensorflow/core/profiler/utils/xplane_test_utils.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/memory_profile.pb.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-// Tests with a sample profile with multiple memory allocation and deallocation
-// activities within one memory allocator captured in host trace.
-TEST(ConvertXPlaneToMemoryProfile, OneAllocatorMultiActivitiesTest) { - XSpace space; - XPlane* host_plane = GetOrCreateHostXPlane(&space); - XPlaneBuilder host_plane_builder(host_plane); - host_plane_builder.ReserveLines(1); - - auto tf_executor_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "MemoryAllocation", - 40000, 1000, - {{StatType::kBytesReserved, int64_t{2000}}, - {StatType::kBytesAllocated, int64_t{3000}}, - {StatType::kBytesAvailable, int64_t{5000}}, - {StatType::kPeakBytesInUse, int64_t{8500}}, - {StatType::kRequestedBytes, int64_t{200}}, - {StatType::kAllocationBytes, int64_t{256}}, - {StatType::kAddress, int64_t{222333}}, - {StatType::kStepId, int64_t{-93746}}, - {StatType::kDataType, int64_t{1}}, - {StatType::kAllocatorName, "GPU_0_bfc"}, - {StatType::kTfOp, "foo/bar"}, - {StatType::kRegionType, "output"}, - {StatType::kTensorShapes, "[3, 3, 512, 512]"}}); - - CreateXEvent(&host_plane_builder, &tf_executor_thread, "MemoryDeallocation", - 50000, 1000, - {{StatType::kBytesReserved, int64_t{2000}}, - {StatType::kBytesAllocated, int64_t{2744}}, - {StatType::kBytesAvailable, int64_t{5256}}, - {StatType::kPeakBytesInUse, int64_t{8500}}, - {StatType::kRequestedBytes, int64_t{200}}, - {StatType::kAllocationBytes, int64_t{256}}, - {StatType::kAddress, int64_t{222333}}, - {StatType::kStepId, int64_t{0}}, - {StatType::kDataType, int64_t{0}}, - {StatType::kAllocatorName, "GPU_0_bfc"}, - {StatType::kRegionType, ""}, - {StatType::kTensorShapes, ""}}); - - CreateXEvent(&host_plane_builder, &tf_executor_thread, "MemoryAllocation", - 70000, 1000, - {{StatType::kBytesReserved, int64_t{2000}}, - {StatType::kBytesAllocated, int64_t{5000}}, - {StatType::kBytesAvailable, int64_t{3000}}, - {StatType::kPeakBytesInUse, int64_t{9500}}, - {StatType::kRequestedBytes, int64_t{300}}, - {StatType::kAllocationBytes, int64_t{300}}, - {StatType::kAddress, int64_t{345678}}, - {StatType::kStepId, int64_t{-93746}}, - {StatType::kDataType, int64_t{9}}, - {StatType::kAllocatorName, "GPU_0_bfc"}, - {StatType::kTfOp, "mul_grad/Sum"}, - {StatType::kRegionType, "temp"}, - {StatType::kTensorShapes, "[1, 2]"}}); - - tsl::profiler::GroupTfEvents(&space); - MemoryProfile memory_profile = ConvertXPlaneToMemoryProfile(*host_plane); - EXPECT_EQ(memory_profile.memory_profile_per_allocator().size(), 1); - EXPECT_EQ(memory_profile.num_hosts(), 1); - EXPECT_EQ(memory_profile.memory_ids_size(), 1); - EXPECT_EQ(memory_profile.memory_profile_per_allocator().begin()->first, - "GPU_0_bfc"); - EXPECT_EQ(memory_profile.version(), 1); - const auto& allocator_memory_profile = - memory_profile.memory_profile_per_allocator().begin()->second; - EXPECT_EQ( - allocator_memory_profile.profile_summary().peak_bytes_usage_lifetime(), - 9500); - EXPECT_EQ(allocator_memory_profile.profile_summary() - .peak_stats() - .peak_bytes_in_use(), - 7000); - EXPECT_EQ(allocator_memory_profile.profile_summary().peak_stats_time_ps(), - 70000); - EXPECT_EQ(allocator_memory_profile.sampled_timeline_snapshots_size(), 3); - EXPECT_EQ(allocator_memory_profile.memory_profile_snapshots_size(), 1); - EXPECT_EQ(allocator_memory_profile.memory_profile_snapshots() - .at(0) - .activity_metadata() - .tf_op_name(), - "mul_grad/Sum"); - EXPECT_EQ(allocator_memory_profile.active_allocations_size(), 3); - EXPECT_EQ( - allocator_memory_profile.active_allocations().at(2).snapshot_index(), 0); - EXPECT_EQ(allocator_memory_profile.special_allocations_size(), 2); - 
EXPECT_EQ(allocator_memory_profile.special_allocations().at(1).tf_op_name(), - "stack"); - EXPECT_EQ( - allocator_memory_profile.special_allocations().at(1).allocation_bytes(), - 2000); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc deleted file mode 100644 index 225171766e12fc..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ /dev/null @@ -1,389 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/convert/op_stack.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof -#include "xprof/utils/cost_utils.h" // from @org_xprof -#include "xprof/utils/gpu_event_stats.h" // from @org_xprof -#include "xprof/utils/hlo_module_map.h" // from @org_xprof -#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof -#include "xprof/utils/op_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tensorflow::profiler::GpuEventStats; -using tsl::profiler::GetDeviceEventTimespan; - -struct HLOTracker { - uint64_t duration = 0; - uint64_t program_id = 0; - uint64_t group_id = 0; - bool is_eager; - const HloInstructionWrapper* hlo_instruction = nullptr; - std::string hlo_op_name; - - void Reset() { - duration = program_id = group_id = 0; - hlo_op_name.clear(); - hlo_instruction = nullptr; - } -}; - -// Type of a TensorFlow Op activity, which is either beginning or ending an Op. -enum TfActivityType { kTfOpBegin, kTfOpEnd }; - -// Instant activity representing the begin or end of a host-side TF Op. -struct TfActivity { - // The timestamp in picoseconds when this activity happened. 
-  uint64 timestamp_ps;
-  // The ID of this Op.
-  uint32 tf_op_id;
-  // Type of this activity.
-  TfActivityType activity_type;
-  // Full TF op name and type of this activity (backed by XEvent::name).
-  tsl::profiler::TfOp tf_op;
-  // Whether it is eagerly executed.
-  bool is_eager;
-};
-
-// TF Op metrics stored as element in OpStack.
-struct TfOpInfo {
-  explicit TfOpInfo(uint64 ts) : start_timestamp_ps(ts) {}
-
-  // Start timestamp in picoseconds.
-  uint64 start_timestamp_ps;
-  // Children duration in picoseconds.
-  uint64 children_duration_ps = 0;
-};
-
-// Processes a TF-activity on particular core.
-void ProcessOneTfActivity(const TfActivity& activity,
-                          OpStack<TfOpInfo>* tf_op_stack,
-                          TfMetricsDbData* tf_metrics_data) {
-  uint32 tf_op_id = activity.tf_op_id;
-  switch (activity.activity_type) {
-    case kTfOpBegin: {
-      tf_op_stack->Push(tf_op_id,
-                        std::make_unique<TfOpInfo>(activity.timestamp_ps));
-      break;
-    }
-    case kTfOpEnd: {
-      std::unique_ptr<TfOpInfo> info = tf_op_stack->Pop(tf_op_id);
-      if (info == nullptr) {
-        // This happens if TraceMes overlap.
-        VLOG(1) << "No begin event found for TF activity id=" << tf_op_id
-                << " name=" << activity.tf_op.name
-                << " type=" << activity.tf_op.type;
-        break;
-      }
-      tsl::profiler::Timespan tf_op_span = tsl::profiler::PicoSpan(
-          info->start_timestamp_ps, activity.timestamp_ps);
-      tf_metrics_data->tf_metrics_db_builder.EnterOp(
-          activity.tf_op.name, activity.tf_op.type, activity.is_eager,
-          tf_op_span.duration_ps(), info->children_duration_ps);
-      TfOpInfo* parent_info = tf_op_stack->Top();
-      if (parent_info != nullptr) {
-        parent_info->children_duration_ps += tf_op_span.duration_ps();
-      }
-      if (tsl::profiler::IsInfeedEnqueueOp(activity.tf_op.type)) {
-        tf_metrics_data->tf_metrics_db_builder.EnterHostInfeedEnqueue(
-            tf_op_span);
-      }
-      break;
-    }
-  }
-}
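// Aside: ProcessOneTfActivity above computes self time as "own span minus
// children's spans" using a stack. A hedged, dependency-free sketch of the
// same idea (the Frame/SelfTimeStack names are illustrative only):
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

class SelfTimeStack {
 public:
  void Begin(const std::string& name, uint64_t ts_ps) {
    stack_.push_back(Frame{ts_ps, 0, name});
  }
  // Ends the innermost open op; returns {name, self_time_ps}.
  std::pair<std::string, uint64_t> End(uint64_t ts_ps) {
    Frame frame = stack_.back();
    stack_.pop_back();
    const uint64_t total_ps = ts_ps - frame.start_ps;
    // Credit the enclosing op: its self time shrinks by this child's span.
    if (!stack_.empty()) stack_.back().children_ps += total_ps;
    return {frame.name, total_ps - frame.children_ps};
  }

 private:
  struct Frame {
    uint64_t start_ps;
    uint64_t children_ps;
    std::string name;
  };
  std::vector<Frame> stack_;
};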
-// Processes all TF-activities on the given core.
-void ProcessTfActivities(std::vector<TfActivity>* tf_activities,
-                         TfMetricsDbData* tf_metrics_db_data) {
-  if (tf_activities->empty()) return;
-  absl::c_stable_sort(*tf_activities,
-                      [](const TfActivity& a, const TfActivity& b) {
-                        return a.timestamp_ps < b.timestamp_ps;
-                      });
-  OpStack<TfOpInfo> tf_op_stack;
-  for (const auto& tf_activity : *tf_activities) {
-    ProcessOneTfActivity(tf_activity, &tf_op_stack, tf_metrics_db_data);
-  }
-  SetTotalTimePs(
-      tf_metrics_db_data->tf_metrics_db,
-      tf_activities->back().timestamp_ps - tf_activities->front().timestamp_ps);
-}
-
-void CollectTfActivities(
-    const XLineVisitor& line,
-    const absl::flat_hash_map<int64_t, tsl::profiler::TfOp>& tf_ops,
-    std::vector<TfActivity>* tf_activities) {
-  uint32 tf_op_id = 0;
-  if (IsDerivedThreadId(line.Id())) return;
-  tf_activities->reserve(line.NumEvents() * 2);
-  line.ForEachEvent(
-      [&tf_ops, &tf_op_id, &tf_activities](const XEventVisitor& event) {
-        const tsl::profiler::TfOp* tf_op = gtl::FindOrNull(tf_ops, event.Id());
-        if (tf_op != nullptr) {
-          ++tf_op_id;
-          bool is_eager = false;
-          if (std::optional<XStatVisitor> stat =
-                  event.GetStat(StatType::kIsEager)) {
-            is_eager = stat->IntValue();
-          }
-          tsl::profiler::Timespan span = event.GetTimespan();
-          tf_activities->push_back(
-              {span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager});
-          tf_activities->push_back(
-              {span.end_ps(), tf_op_id, kTfOpEnd, *tf_op, is_eager});
-        }
-        if (auto tf_op_stat = event.GetStat(StatType::kTfOp);
-            tf_op_stat.has_value()) {
-          ++tf_op_id;
-          tsl::profiler::TfOp tf_op =
-              tsl::profiler::ParseTfOpFullname(tf_op_stat->StrOrRefValue());
-          tsl::profiler::Timespan span = event.GetTimespan();
-          tf_activities->push_back(
-              {span.begin_ps(), tf_op_id, kTfOpBegin, tf_op, false});
-          tf_activities->push_back(
-              {span.end_ps(), tf_op_id, kTfOpEnd, tf_op, false});
-        }
-      });
-}
-
-}  // namespace
-
-absl::flat_hash_map<int64_t, tsl::profiler::TfOp>
-CollectTfOpsFromHostThreadsXPlane(const XPlane& host_trace) {
-  absl::flat_hash_map<int64_t, tsl::profiler::TfOp> tf_ops;
-  for (const auto& id_metadata : host_trace.event_metadata()) {
-    const XEventMetadata& metadata = id_metadata.second;
-    // On the host, we have added some user-specified TraceMe's in addition to
-    // the TraceMe's added to every TensorFlow op by the system. These
-    // user-inserted TraceMe's have "unknown" type. We don't count them in
-    // Tf-stats.
-    tsl::profiler::TfOp tf_op =
-        tsl::profiler::ParseTfOpFullname(metadata.name());
-    if (tf_op.category != tsl::profiler::Category::kUnknown) {
-      tf_ops.try_emplace(metadata.id(), tf_op);
-    }
-  }
-  return tf_ops;
-}
-
-TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData(
-    const XLineVisitor& line,
-    const absl::flat_hash_map<int64_t, tsl::profiler::TfOp>& tf_ops) {
-  TfMetricsDbData tf_metrics_db_data;
-  std::vector<TfActivity> tf_activities;
-  CollectTfActivities(line, tf_ops, &tf_activities);
-  ProcessTfActivities(&tf_activities, &tf_metrics_db_data);
-  return tf_metrics_db_data;
-}
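// Aside: CollectTfActivities above splits each op span into two instant
// events so that a single time-sorted pass can rebuild nesting. A minimal
// stand-alone version of that flattening step (toy Span/Instant types):
#include <algorithm>
#include <cstdint>
#include <vector>

struct Span { uint64_t begin_ps, end_ps; uint32_t id; };
struct Instant { uint64_t ts_ps; uint32_t id; bool is_begin; };

inline std::vector<Instant> FlattenToInstants(const std::vector<Span>& spans) {
  std::vector<Instant> instants;
  instants.reserve(spans.size() * 2);
  for (const Span& span : spans) {
    instants.push_back({span.begin_ps, span.id, true});
    instants.push_back({span.end_ps, span.id, false});
  }
  // Stable sort keeps each begin before its end for zero-length spans.
  std::stable_sort(instants.begin(), instants.end(),
                   [](const Instant& a, const Instant& b) {
                     return a.ts_ps < b.ts_ps;
                   });
  return instants;
}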
-void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst) {
-  AddIdleOp(src.tf_metrics_db);
-  // Host OpMetricsDb does not need to update the number of cores a certain op
-  // occurs.
-  dst->Combine(src.tf_metrics_db, /*update_num_cores=*/false);
-  src.tf_metrics_db.Clear();
-}
-
-OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace) {
-  absl::flat_hash_map<int64_t, tsl::profiler::TfOp> tf_ops =
-      CollectTfOpsFromHostThreadsXPlane(host_trace);
-  OpMetricsDb result;
-  OpMetricsDbCombiner combiner(&result);
-  XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&host_trace);
-  plane.ForEachLine([&tf_ops, &combiner](const XLineVisitor& line) {
-    ConsumeTfMetricsDbData(
-        ConvertHostThreadsXLineToTfMetricsDbData(line, tf_ops), &combiner);
-  });
-  return result;
-}
-
-OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb(
-    const XPlane& device_trace) {
-  XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace);
-  XEventsOpMetricsDbBuilder builder;
-  uint64_t first_op_timestamp_ps = std::numeric_limits<uint64_t>::max();
-  uint64_t last_op_timestamp_ps = 0;
-
-  struct ParentReference {
-    const XEventVisitor event;
-    tsl::profiler::Timespan device_timespan;
-    uint64_t children_duration_ps = 0;
-  };
-
-  tsl::profiler::AncestorStack<ParentReference> event_stack(
-      [&](const ParentReference& parent) {
-        OpMetrics op_metrics = FromXEvent(parent.event);
-        op_metrics.set_time_ps(parent.device_timespan.duration_ps());
-        op_metrics.set_self_time_ps(op_metrics.time_ps() -
-                                    parent.children_duration_ps);
-        builder.AddOpMetric(op_metrics, GetOpKeyFromXEvent(parent.event));
-      },
-      [](const ParentReference& parent, const ParentReference& child) {
-        return parent.device_timespan.Includes(child.device_timespan);
-      },
-      [](ParentReference& parent, ParentReference& child) {
-        parent.children_duration_ps += child.device_timespan.duration_ps();
-      });
-
-  auto track_first_and_last_op_timestamps = [&](const XEventVisitor& event) {
-    tsl::profiler::Timespan timespan = GetDeviceEventTimespan(event);
-    first_op_timestamp_ps =
-        std::min(first_op_timestamp_ps, timespan.begin_ps());
-    last_op_timestamp_ps = std::max(last_op_timestamp_ps, timespan.end_ps());
-  };
-
-  plane.ForEachLine([&](const XLineVisitor& line) {
-    if (line.Name() == tsl::profiler::kSparseCoreStepLineName ||
-        line.Name() == tsl::profiler::kStepLineName) {
-      line.ForEachEvent(track_first_and_last_op_timestamps);
-    }
-    if (!tsl::profiler::IsOpLineName(line.Name())) return;
-    line.ForEachEvent([&](const XEventVisitor& event) {
-      tsl::profiler::Timespan timespan = GetDeviceEventTimespan(event);
-      track_first_and_last_op_timestamps(event);
-
-      event_stack.Push({.event = event, .device_timespan = timespan});
-    });
-    event_stack.Flush();
-  });
-
-  return builder.Finalize(last_op_timestamp_ps - first_op_timestamp_ps);
-}
-
-void AggregateHloFunc(HLOTracker& current, DeviceOpMetricsDbBuilder& metricDb) {
-  if (current.hlo_instruction == nullptr) return;
-  auto performance_info_wrapper =
-      current.hlo_instruction->GetPerformanceInfoWrapper();
-  auto flops = 0;
-  auto bytes_accessed = 0;
-  if (performance_info_wrapper != nullptr) {
-    flops = performance_info_wrapper->flops();
-    bytes_accessed = performance_info_wrapper->bytes_accessed();
-  }
-  metricDb.EnterOp(
-      current.program_id, current.hlo_op_name,
-      current.hlo_instruction->Category(), current.hlo_instruction->TfOpName(),
-      current.hlo_instruction->DeduplicatedName(), current.is_eager, 1,
-      current.duration, 0, performance_info_wrapper->DeviceFlops(),
-      performance_info_wrapper->bytes_accessed(),
-      ConvertPerformanceInfo(
-          performance_info_wrapper->memory_accessed_breakdown(), 1),
-      performance_info_wrapper->ModelFlops(),
-      current.hlo_instruction->Expression());
-  current.Reset();
-}
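// Aside: ConvertDeviceTraceXPlaneToOpMetricsDb below coalesces runs of
// contiguous events that share an HLO op name, accumulating duration before a
// single EnterOp call. The core of that merge, as a stand-alone sketch over
// (name, duration) pairs:
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

inline std::vector<std::pair<std::string, uint64_t>> MergeContiguousRuns(
    const std::vector<std::pair<std::string, uint64_t>>& events) {
  std::vector<std::pair<std::string, uint64_t>> merged;
  for (const auto& [name, duration_ps] : events) {
    if (!merged.empty() && merged.back().first == name) {
      merged.back().second += duration_ps;  // same op as previous: accumulate
    } else {
      merged.emplace_back(name, duration_ps);
    }
  }
  return merged;
}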
-OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(
-    const XPlane& device_trace, const HloModuleMap& hlo_module_map) {
-  OpMetricsDb result;
-  DeviceOpMetricsDbBuilder device_op_metrics_db_builder(&result);
-
-  int64_t first_op_offset_ps = kint64max;
-  int64_t last_op_offset_ps = 0;
-
-  TfOpRoofLineCostEstimator op_level_cost_estimator;
-  XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace);
-  HLOTracker current;
-  plane.ForEachLine([&](const XLineVisitor& line) {
-    if (IsDerivedThreadId(line.Id())) return;
-    line.ForEachEvent([&](const XEventVisitor& event) {
-      first_op_offset_ps = std::min(first_op_offset_ps, event.OffsetPs());
-      last_op_offset_ps = std::max(last_op_offset_ps, event.EndOffsetPs());
-
-      GpuEventStats stats(&event);
-      if (stats.IsXlaOp()) {
-        const auto* hlo_instruction = GetHloInstruction(
-            hlo_module_map, stats.program_id, stats.hlo_op_names.back());
-        if (hlo_instruction != nullptr) {
-          if (stats.hlo_op_names.back() != current.hlo_op_name ||
-              stats.group_id != current.group_id) {
-            AggregateHloFunc(current, device_op_metrics_db_builder);
-          }
-          // Merge identical and contiguous HLOs.
-          current.hlo_instruction = hlo_instruction;
-          current.hlo_op_name = stats.hlo_op_names.back();
-          current.duration += event.DurationPs();
-          current.is_eager = stats.is_eager;
-          current.program_id = *stats.program_id;
-          if (stats.group_id.has_value()) {
-            current.group_id = *stats.group_id;
-          }
-        }
-      } else if (stats.IsTfOp()) {
-        AggregateHloFunc(current, device_op_metrics_db_builder);
-        tsl::profiler::TfOp tf_op =
-            tsl::profiler::ParseTfOpFullname(stats.tf_op_fullname);
-        PerformanceInfo perf_info;
-        if (tf_op.category != tsl::profiler::Category::kUnknown) {
-          auto costs = op_level_cost_estimator.Predict(event);
-          // NOTE: events are per kernel, but costs are per tf-ops.
-          perf_info.set_flops(costs.flops);
-          perf_info.set_bytes_accessed(costs.bytes_accessed);
-        }
-        std::string name = absl::StrCat(tf_op.name, "/", event.Name());
-        device_op_metrics_db_builder.EnterOp(
-            /*program_id=*/0,
-            /*name=*/name,
-            /*category=*/tf_op.type,
-            /*provenance=*/stats.tf_op_fullname, "", stats.is_eager,
-            /*occurrences=*/1, event.DurationPs(),
-            /*children_time_ps=*/0, perf_info.flops(),
-            perf_info.bytes_accessed());
-      }
-    });
-    AggregateHloFunc(current, device_op_metrics_db_builder);
-  });
-  SetTotalTimePs(
-      result, last_op_offset_ps ? last_op_offset_ps - first_op_offset_ps : 0);
-  AddIdleOp(result);
-  return result;
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
deleted file mode 100644
index fbc52d3e27a5b7..00000000000000
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_METRICS_DB_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_METRICS_DB_H_ - -#include "absl/container/flat_hash_map.h" -#include "absl/types/optional.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof -#include "xprof/utils/op_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -// Data per host thread for TensorFlow Op Metrics Database. -struct TfMetricsDbData { - // A database of TF-Op metrics for this core. - OpMetricsDb tf_metrics_db; - HostOpMetricsDbBuilder tf_metrics_db_builder{&tf_metrics_db}; -}; - -absl::flat_hash_map -CollectTfOpsFromHostThreadsXPlane(const XPlane& host_trace); - -TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData( - const XLineVisitor& line, - const absl::flat_hash_map& tf_ops); - -void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst); - -OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace); - -// Converts GPU device trace to OpMetricsDb. -// Will use HloModuleMap to source performance info for cost analysis. -OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( - const XPlane& device_trace, const HloModuleMap& hlo_module_map); - -// Convert TPU DeviceTrace XPlane to OpMetricDb -OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb( - const XPlane& device_trace); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_METRICS_DB_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc deleted file mode 100644 index 2ed6a52c49947e..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ /dev/null @@ -1,305 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h" - -#include -#include -#include -#include - -#include -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof -#include "xprof/utils/hlo_cost_analysis_wrapper.h" // from @org_xprof -#include "xprof/utils/hlo_module_map.h" // from @org_xprof -#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof -#include "xprof/utils/xprof_gpu_cost_analysis.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -#if defined(PLATFORM_GOOGLE) -// NOLINTNEXTLINE: clang-tidy missing-includes -using ::testing::EqualsProto; -#endif - -void AddTensorFlowTpuOpEvent(std::string&& name, std::string&& tf_op_fullname, - int64_t start_timestamp_ns, int64_t duration_ns, - std::string&& hlo_category, uint64 flops, - uint64 bytes_accessed, int64_t occurences, - int64_t self_duration, int64_t program_id, - int64_t symbol_id, XPlaneBuilder* plane, - XLineBuilder* line) { - XEventBuilder event = line->AddEvent(*plane->GetOrCreateEventMetadata(name)); - event.SetTimestampNs(start_timestamp_ns); - event.SetDurationNs(duration_ns); - event.SetNumOccurrences(occurences); - XStatsBuilder event_metadata( - plane->GetOrCreateEventMetadata(name), plane); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), - tf_op_fullname); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kHloCategory)), - hlo_category); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kFlops)), flops); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), - symbol_id); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)), - program_id); -} - -void AddTensorFlowOpEvent(std::string&& tf_op_fullname, - int64_t start_timestamp_ns, int64_t duration_ns, - bool on_device, absl::string_view kernel_name, - XPlaneBuilder* plane, XLineBuilder* line) { - absl::string_view name = on_device ? 
kernel_name : tf_op_fullname; - XEventBuilder event = line->AddEvent(*plane->GetOrCreateEventMetadata(name)); - event.SetTimestampNs(start_timestamp_ns); - event.SetDurationNs(duration_ns); - if (!on_device) return; - event.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), - *plane->GetOrCreateStatMetadata(std::move(tf_op_fullname))); -} - -void AddXlaCpuOpEvent(std::string&& hlo_op_name, std::string&& tf_op, - int64_t start_timestamp_ns, int64_t duration_ns, - XPlaneBuilder* plane, XLineBuilder* line) { - XEventBuilder event = - line->AddEvent(*plane->GetOrCreateEventMetadata(hlo_op_name)); - event.SetTimestampNs(start_timestamp_ns); - event.SetDurationNs(duration_ns); - event.ParseAndAddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), tf_op); -} - -TEST(ConvertXPlaneToOpMetricsDb, HostOpMetricsDb) { - static constexpr char kTfOp1[] = "TfOp1"; - static constexpr char kTfOp2[] = "TfOp2"; - constexpr int64_t kTfOp1StartNs = 100000; - constexpr int64_t kTfOp1DurationNs = 8000; - constexpr int64_t kTfOp2StartNs = 110000; - constexpr int64_t kTfOp2DurationNs = 10000; - - XSpace xspace; - XPlane* xplane = GetOrCreateHostXPlane(&xspace); - XPlaneBuilder host_plane(xplane); - XLineBuilder thread1 = host_plane.GetOrCreateLine(/*line_id=*/10); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kTfOp1StartNs, - kTfOp1DurationNs, /*on_device=*/false, - /*kernel_name=*/"", &host_plane, &thread1); - XLineBuilder thread2 = host_plane.GetOrCreateLine(/*line_id=*/20); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kTfOp1StartNs, - kTfOp1DurationNs, /*on_device=*/false, - /*kernel_name=*/"", &host_plane, &thread2); - AddTensorFlowOpEvent(absl::StrCat(kTfOp2, ":", kTfOp2), kTfOp2StartNs, - kTfOp2DurationNs, /*on_device=*/false, - /*kernel_name=*/"", &host_plane, &thread2); - - OpMetricsDb op_metrics = ConvertHostThreadsXPlaneToOpMetricsDb(*xplane); - // Op1, Op2, Idle. - EXPECT_EQ(3, op_metrics.metrics_db_size()); - uint64 total_op_duration = - tsl::profiler::NanoToPico(kTfOp1DurationNs * 2 + kTfOp2DurationNs); - EXPECT_EQ(total_op_duration, op_metrics.total_op_time_ps()); - uint64 total_duration = tsl::profiler::NanoToPico( - kTfOp2StartNs - kTfOp1StartNs + kTfOp2DurationNs + kTfOp1DurationNs); - EXPECT_EQ(total_duration, op_metrics.total_time_ps()); - - // Verifies OpMetricsDb is built correctly. - const OpMetrics& op_1 = op_metrics.metrics_db().at(0); - EXPECT_EQ(kTfOp1, op_1.name()); - EXPECT_EQ(kTfOp1, op_1.category()); - EXPECT_EQ(2, op_1.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kTfOp1DurationNs) * 2, op_1.time_ps()); - - const OpMetrics& idle = op_metrics.metrics_db().at(1); - EXPECT_EQ(kIdle, idle.name()); - EXPECT_EQ(kIdle, idle.category()); - // Idle time is the gap between Op2 start and the end of Op1, which is 2000ns. - EXPECT_EQ(tsl::profiler::NanoToPico(2000), idle.time_ps()); - - const OpMetrics& op_2 = op_metrics.metrics_db().at(2); - EXPECT_EQ(kTfOp2, op_2.name()); - EXPECT_EQ(kTfOp2, op_2.category()); - EXPECT_EQ(1, op_2.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kTfOp2DurationNs), op_2.time_ps()); -} - -TEST(ConvertXPlaneToOpMetricsDb, DeviceOpMetricsDb) { - // TfOp1 has kernel1 and kernel2; TfOp2 has kernel3. 
- static constexpr char kTfOp1[] = "TfOp1"; - static constexpr char kTfOp2[] = "TfOp2"; - static constexpr char kKernel1[] = "kernel1"; - static constexpr char kKernel2[] = "kernel2"; - static constexpr char kKernel3[] = "kernel3"; - constexpr int64_t kKernel1StartNs = 100000; - constexpr int64_t kKernel1DurationNs = 8000; - constexpr int64_t kKernel2StartNs = 110000; - constexpr int64_t kKernel2DurationNs = 10000; - constexpr int64_t kKernel3StartNs = 120000; - constexpr int64_t kKernel3DurationNs = 10000; - - XSpace xspace; - XPlane* xplane = GetOrCreateGpuXPlane(&xspace, /*device_ordinal=*/0); - XPlaneBuilder device_plane(xplane); - XLineBuilder stream1 = device_plane.GetOrCreateLine(/*line_id=*/10); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel1StartNs, - kKernel1DurationNs, /*on_device=*/true, kKernel1, - &device_plane, &stream1); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel2StartNs, - kKernel2DurationNs, /*on_device=*/true, kKernel2, - &device_plane, &stream1); - XLineBuilder stream2 = device_plane.GetOrCreateLine(/*line_id=*/20); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel1StartNs, - kKernel1DurationNs, /*on_device=*/true, kKernel1, - &device_plane, &stream2); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel2StartNs, - kKernel2DurationNs, /*on_device=*/true, kKernel2, - &device_plane, &stream2); - AddTensorFlowOpEvent(absl::StrCat(kTfOp2, ":", kTfOp2), kKernel3StartNs, - kKernel3DurationNs, /*on_device=*/true, kKernel3, - &device_plane, &stream2); - HloModuleMap hlo_module_map; - tensorflow::profiler::HloCostAnalysisWrapper::Factory create_cost_analysis = - []() { return tensorflow::profiler::CreateXprofGpuCostAnalysis(); }; - ProcessHloModuleMapFromXSpace(hlo_module_map, &xspace, create_cost_analysis); - OpMetricsDb op_metrics = - ConvertDeviceTraceXPlaneToOpMetricsDb(*xplane, hlo_module_map); - - // kernel1, kernel2, kernel3, Idle. - EXPECT_EQ(4, op_metrics.metrics_db_size()); - uint64 total_op_duration = tsl::profiler::NanoToPico( - kKernel1DurationNs * 2 + kKernel2DurationNs * 2 + kKernel3DurationNs); - EXPECT_EQ(total_op_duration, op_metrics.total_op_time_ps()); - // For device, the total_duration for each device is the total duration - // merged from all GPU streams, which is from 100000 to 130000. - uint64 total_duration = tsl::profiler::NanoToPico( - kKernel3StartNs + kKernel3DurationNs - kKernel1StartNs); - EXPECT_EQ(std::max(total_duration, total_op_duration), - op_metrics.total_time_ps()); - - // Verifies OpMetricsDb is built correctly. 
- const OpMetrics& op_1 = op_metrics.metrics_db().at(0); - EXPECT_EQ(absl::StrCat(kTfOp1, "/", kKernel1), op_1.name()); - EXPECT_EQ(kTfOp1, op_1.category()); - EXPECT_EQ(2, op_1.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kKernel1DurationNs) * 2, op_1.time_ps()); - - const OpMetrics& op_2 = op_metrics.metrics_db().at(1); - EXPECT_EQ(absl::StrCat(kTfOp1, "/", kKernel2), op_2.name()); - EXPECT_EQ(kTfOp1, op_2.category()); - EXPECT_EQ(2, op_2.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kKernel2DurationNs) * 2, op_2.time_ps()); - - const OpMetrics& op_3 = op_metrics.metrics_db().at(2); - EXPECT_EQ(absl::StrCat(kTfOp2, "/", kKernel3), op_3.name()); - EXPECT_EQ(kTfOp2, op_3.category()); - EXPECT_EQ(1, op_3.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kKernel3DurationNs), op_3.time_ps()); - - const OpMetrics& idle = op_metrics.metrics_db().at(3); - EXPECT_EQ(kIdle, idle.name()); - EXPECT_EQ(kIdle, idle.category()); - // GPU is always busy in this example. - EXPECT_EQ(tsl::profiler::NanoToPico(0), idle.time_ps()); -} - -TEST(ConvertXPlaneToOpMetricsDb, TpuDeviceOpMetricsDb) { - XSpace xspace; - XPlane* xplane = GetOrCreateTpuXPlane(&xspace, /*device_ordinal=*/0, "TPU V4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder device_plane(xplane); - XLineBuilder stream1 = device_plane.GetOrCreateLine(/*line_id=*/10); - stream1.SetName(tsl::profiler::kTensorFlowOpLineName); - AddTensorFlowTpuOpEvent("MatMul", "while:MatMul", 0, 10, "MatMul", 34, 45, 2, - 5, 1, 1, &device_plane, &stream1); - OpMetricsDb op_metrics = ConvertTpuDeviceTraceXPlaneToOpMetricsDb(*xplane); -#if defined(PLATFORM_GOOGLE) - EXPECT_THAT(op_metrics, - EqualsProto(R"pb(metrics_db { - hlo_module_id: 1 - self_time_ps: 10000 - flops: 68 - model_flops: 68 - num_cores: 1 - occurrences: 2 - name: "MatMul" - time_ps: 10000 - category: "MatMul" - provenance: "while:MatMul" - min_time_ps: 10000 - } - metrics_db { name: "IDLE" category: "IDLE" } - total_time_ps: 10000 - total_op_time_ps: 10000 - )pb")); -#endif -} - -TEST(ConvertXPlaneToOpMetricsDb, HostXPlaneWithXlaOps) { - XPlane xplane; - XPlaneBuilder plane(&xplane); - XLineBuilder line = plane.GetOrCreateLine(/*line_id=*/10); - AddXlaCpuOpEvent("xla_op", "tf_op", 100000, 8000, &plane, &line); - AddXlaCpuOpEvent("xla_op2", "tf_op2", 110000, 10000, &plane, &line); - OpMetricsDb op_metrics = ConvertHostThreadsXPlaneToOpMetricsDb(xplane); -#if defined(PLATFORM_GOOGLE) - EXPECT_THAT(op_metrics, EqualsProto(R"pb(metrics_db { - self_time_ps: 8000000 - occurrences: 1 - name: "tf_op" - time_ps: 8000000 - } - metrics_db { - self_time_ps: 10000000 - occurrences: 1 - name: "tf_op2" - time_ps: 10000000 - } - metrics_db { - self_time_ps: 2000000 - name: "IDLE" - time_ps: 2000000 - category: "IDLE" - } - total_time_ps: 20000000 - total_op_time_ps: 18000000 - precision_stats {} - )pb")); -#endif -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc deleted file mode 100644 index bc3a6397c41fc3..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ /dev/null @@ -1,495 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" - -#include - -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/convert/duty_cycle_combiner.h" -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" -#include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h" -#include "tensorflow/core/profiler/convert/xplane_to_step_events.h" -#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/hardware_types.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/kernel_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h" // from @org_xprof -#include "xprof/utils/device_caps_utils.h" // from @org_xprof -#include "xprof/utils/event_span.h" // from @org_xprof -#include "xprof/utils/gpu_event_stats.h" // from @org_xprof -#include "xprof/utils/hardware_type_utils.h" // from @org_xprof -#include "xprof/utils/hlo_cost_analysis_wrapper.h" // from @org_xprof -#include "xprof/utils/hlo_module_map.h" // from @org_xprof -#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof -#include "xprof/utils/op_utils.h" // from @org_xprof -#include "xprof/utils/xprof_gpu_cost_analysis.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using tsl::profiler::FindPlanesWithPrefix; -using tsl::profiler::FindTensorCorePlanes; -using tsl::profiler::Timespan; - 
-std::string Hostname(const XSpace& space) {
-  if (space.hostnames().empty()) return "localhost";
-  DCHECK_EQ(space.hostnames_size(), 1);
-  const std::string& hostname = space.hostnames(0);
-  return hostname;
-}
-
-}  // namespace
-
-PerfEnv MakePerfEnv(double peak_tera_flops_per_second,
-                    std::vector<double> peak_bws) {
-  PerfEnv result;
-  result.set_peak_tera_flops_per_second(peak_tera_flops_per_second);
-
-  for (const auto bw : peak_bws) {
-    result.add_peak_bws_giga_bytes_per_second(bw);
-  }
-  result.set_ridge_point(tsl::profiler::TeraToGiga(peak_tera_flops_per_second) /
-                         peak_bws[MemBwType::MEM_BW_TYPE_HBM_RW]);
-  return result;
-}
-
-PerfEnv MakePerfEnvForTpu(double peak_tera_flops_per_second,
-                          std::vector<double> peak_bws, bool has_merged_vmem,
-                          bool has_megacore) {
-  PerfEnv result = MakePerfEnv(peak_tera_flops_per_second, peak_bws);
-  result.set_has_cmem(peak_bws[MemBwType::MEM_BW_TYPE_CMEM_RD] > 0 ||
-                      peak_bws[MemBwType::MEM_BW_TYPE_CMEM_WR] > 0);
-  result.set_has_merged_vmem(has_merged_vmem);
-  result.set_has_megacore(has_megacore);
-  return result;
-}
-
-PerfEnv MakePerfEnvForGpu(double peak_tera_flops_per_second,
-                          std::vector<double> peak_bws) {
-  return MakePerfEnv(peak_tera_flops_per_second, peak_bws);
-}
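// Aside: the ridge point set in MakePerfEnv above is the roofline crossover,
// i.e. the operational intensity (FLOPs per byte) at which peak compute and
// peak HBM bandwidth balance. A stand-alone version of the same arithmetic,
// with made-up numbers for illustration:
inline double RidgePointFlopsPerByte(double peak_tera_flops_per_second,
                                     double hbm_giga_bytes_per_second) {
  // Tera -> giga keeps the units consistent with GB/s, so the ratio comes out
  // in FLOPs per byte.
  return (peak_tera_flops_per_second * 1000.0) / hbm_giga_bytes_per_second;
}
// E.g. a device with 100 TFLOP/s and 1000 GB/s HBM: 100000 / 1000 = 100
// FLOPs/byte; kernels below that intensity are expected to be HBM-bound.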
-PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane) {
-  DeviceCapabilities cap = GetDeviceCaps(device_plane);
-  if (!absl::StartsWith(device_plane.name(), kTpuPlanePrefix)) {
-    double peak_tera_flops_per_second =
-        cap.num_cores() *
-        tsl::profiler::GigaToTera(GetFlopMaxThroughputPerSM(cap));
-    double hbm_bw_giga_bytes_per_second =
-        tsl::profiler::UniToGiga(cap.memory_bandwidth());
-    double shm_giga_bytes_per_second =
-        cap.num_cores() *
-        tsl::profiler::UniToGiga(GetSharedMemoryBandwidthPerSM(cap));
-    // Note that we treat SRAM_RD and SRAM_WR as the same, so in the future we
-    // could use one for shared memory / L1 cache and the other for another
-    // level such as L2.
-    return MakePerfEnvForGpu(peak_tera_flops_per_second,
-                             {/*HBM_RW=*/hbm_bw_giga_bytes_per_second,
-                              /*SRAM_RD=*/shm_giga_bytes_per_second,
-                              /*SRAM_WR=*/shm_giga_bytes_per_second});
-  } else {
-    XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(&device_plane);
-    std::optional<XStatVisitor> peak_tera_flops_per_second =
-        visitor.GetStat(StatType::kDevCapPeakTeraflopsPerSecond);
-    double peak_tera_flops_per_second_val =
-        peak_tera_flops_per_second.has_value()
-            ? peak_tera_flops_per_second->DoubleValue()
-            : 0.0;
-    std::optional<XStatVisitor> peak_hbm_bw_giga_bytes_per_second =
-        visitor.GetStat(StatType::kDevCapPeakHbmBwGigabytesPerSecond);
-    double peak_hbm_bw_giga_bytes_per_second_val =
-        peak_hbm_bw_giga_bytes_per_second.has_value()
-            ? peak_hbm_bw_giga_bytes_per_second->DoubleValue()
-            : 0.0;
-    std::optional<XStatVisitor> peak_sram_rd_bw_giga_bytes_per_second =
-        visitor.GetStat(StatType::kDevCapPeakSramRdBwGigabytesPerSecond);
-    double peak_sram_rd_bw_giga_bytes_per_second_val =
-        peak_sram_rd_bw_giga_bytes_per_second.has_value()
-            ? peak_sram_rd_bw_giga_bytes_per_second->DoubleValue()
-            : 0.0;
-    std::optional<XStatVisitor> peak_sram_wr_bw_giga_bytes_per_second =
-        visitor.GetStat(StatType::kDevCapPeakSramWrBwGigabytesPerSecond);
-    double peak_sram_wr_bw_giga_bytes_per_second_val =
-        peak_sram_wr_bw_giga_bytes_per_second.has_value()
-            ? peak_sram_wr_bw_giga_bytes_per_second->DoubleValue()
-            : 0.0;
-    std::optional<XStatVisitor> cmem_rd_bw_giga_bytes_per_second =
-        visitor.GetStat(StatType::kDevCapPeakCmemRdBwGigabytesPerSecond);
-    double cmem_rd_bw_giga_bytes_per_second_val =
-        cmem_rd_bw_giga_bytes_per_second.has_value()
-            ? cmem_rd_bw_giga_bytes_per_second->DoubleValue()
-            : 0.0;
-    std::optional<XStatVisitor> cmem_wr_bw_giga_bytes_per_second =
-        visitor.GetStat(StatType::kDevCapPeakCmemWrBwGigabytesPerSecond);
-    double cmem_wr_bw_giga_bytes_per_second_val =
-        cmem_wr_bw_giga_bytes_per_second.has_value()
-            ? cmem_wr_bw_giga_bytes_per_second->DoubleValue()
-            : 0.0;
-    std::optional<XStatVisitor> vmem_rd_bw_giga_bytes_per_second =
-        visitor.GetStat(StatType::kDevCapPeakVmemRdBwGigabytesPerSecond);
-    double vmem_rd_bw_giga_bytes_per_second_val =
-        vmem_rd_bw_giga_bytes_per_second.has_value()
-            ? vmem_rd_bw_giga_bytes_per_second->DoubleValue()
-            : 0.0;
-    std::optional<XStatVisitor> vmem_wr_bw_giga_bytes_per_second =
-        visitor.GetStat(StatType::kDevCapPeakVmemWrBwGigabytesPerSecond);
-    double vmem_wr_bw_giga_bytes_per_second_val =
-        vmem_wr_bw_giga_bytes_per_second.has_value()
-            ? vmem_wr_bw_giga_bytes_per_second->DoubleValue()
-            : 0.0;
-    std::optional<XStatVisitor> has_megacore =
-        visitor.GetStat(StatType::kDevHasMegacore);
-    bool has_megacore_val =
-        has_megacore.has_value() ? has_megacore->BoolValue() : false;
-    std::optional<XStatVisitor> has_merged_vmem =
-        visitor.GetStat(StatType::kDevHasMergedVmem);
-    bool has_merged_vmem_val =
-        has_merged_vmem.has_value() ? has_merged_vmem->BoolValue() : false;
-    return MakePerfEnvForTpu(
-        peak_tera_flops_per_second_val,
-        {/*HBM_RW=*/peak_hbm_bw_giga_bytes_per_second_val,
-         /*SRAM_RD=*/peak_sram_rd_bw_giga_bytes_per_second_val,
-         /*SRAM_WR=*/peak_sram_wr_bw_giga_bytes_per_second_val,
-         /*CMEM_RD=*/cmem_rd_bw_giga_bytes_per_second_val,
-         /*CMEM_WR=*/cmem_wr_bw_giga_bytes_per_second_val,
-         /*VMEM_RD=*/vmem_rd_bw_giga_bytes_per_second_val,
-         /*VMEM_WR=*/vmem_wr_bw_giga_bytes_per_second_val},
-        has_merged_vmem_val, has_megacore_val);
-  }
-}
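// Aside: the repeated has_value() checks above all follow one pattern: read
// an optional stat, fall back to a default. A hedged helper sketch (it
// assumes a visitor type whose GetStat() returns an optional exposing
// DoubleValue(), as the XPlaneVisitor above does):
template <typename Visitor, typename StatT>
double DoubleStatOr(const Visitor& visitor, StatT stat_type,
                    double default_value = 0.0) {
  if (auto stat = visitor.GetStat(stat_type)) return stat->DoubleValue();
  return default_value;
}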
-void SetRunEnvironment(const XSpace& space, RunEnvironment* env) {
-  // Currently, we only support profiling one host and one program.
-  env->set_host_count(1);
-  env->set_task_count(1);
-  env->mutable_hostnames()->insert({Hostname(space), true});
-
-  std::vector<const XPlane*> gpu_planes =
-      FindPlanesWithPrefix(space, kGpuPlanePrefix);
-  if (!gpu_planes.empty()) {
-    absl::string_view gpu_model =
-        GpuModelName(GetDeviceCaps(*gpu_planes.front()));
-    if (!gpu_model.empty()) {
-      env->set_device_type(std::string(gpu_model));
-    } else {
-      env->set_device_type("GPU");
-    }
-    env->set_device_core_count(gpu_planes.size());
-    env->set_hardware_type(tensorflow::profiler::HardwareType::GPU);
-  } else if (std::vector<const XPlane*> tpu_planes =
-                 FindTensorCorePlanes(space);
-             !tpu_planes.empty()) {
-    XPlaneVisitor visitor =
-        tsl::profiler::CreateTfXPlaneVisitor(tpu_planes.at(0));
-    auto xstat = visitor.GetStat(StatType::kDeviceTypeString);
-    if (xstat.has_value()) {
-      env->set_device_type(std::string(xstat->StrOrRefValue()));
-    }
-    env->set_device_core_count(tpu_planes.size());
-    env->set_hardware_type(tensorflow::profiler::HardwareType::TPU);
-  } else {
-    env->set_device_type("CPU");
-    env->set_device_core_count(0);
-    env->set_hardware_type(tensorflow::profiler::HardwareType::CPU_ONLY);
-  }
-}
-
-void PropagateXSpaceDiagnosticsToOpStats(const XSpace& space,
-                                         OpStats* op_stats) {
-  if (!space.errors().empty()) {
-    absl::flat_hash_set<std::string> unique_errors;
-    unique_errors.insert(space.errors().begin(), space.errors().end());
-    *op_stats->mutable_diagnostics()->mutable_errors() = {unique_errors.begin(),
-                                                          unique_errors.end()};
-  }
-  if (!space.warnings().empty()) {
-    absl::flat_hash_set<std::string> unique_warnings;
-    unique_warnings.insert(space.warnings().begin(), space.warnings().end());
-    *op_stats->mutable_diagnostics()->mutable_warnings() = {
-        unique_warnings.begin(), unique_warnings.end()};
-  }
-}
-
-// This function should be idempotent when called multiple times.
-void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map,
-                           tensorflow::profiler::OpStats& op_stats) {
-  auto& program_id_to_name_map = *op_stats.mutable_program_id_to_name_map();
-  for (const auto& [program_id, hlo_proto] : hlo_proto_map) {
-    program_id_to_name_map[program_id] = hlo_proto->hlo_module().name();
-  }
-}
-
-void UpdateOpMetricsDbFromHloModuleMap(OpMetricsDb& op_metrics_db,
-                                       const HloModuleMap& hlo_module_map) {
-  for (OpMetrics& op_metrics : *op_metrics_db.mutable_metrics_db()) {
-    EnterOpMetadataFromHloModuleMap(&op_metrics, hlo_module_map);
-  }
-}
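// Aside: ConstructDutyCycleTracker below classifies event timespans as active
// or idle. The essence of "total active time" is a union of intervals; a
// hedged, stand-alone sketch (the real DutyCycleTracker in
// duty_cycle_tracker.h is more involved):
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Intervals are [begin_ps, end_ps); overlaps are merged before summing.
inline uint64_t TotalActivePs(
    std::vector<std::pair<uint64_t, uint64_t>> intervals) {
  std::sort(intervals.begin(), intervals.end());
  uint64_t total_ps = 0;
  uint64_t cur_begin = 0, cur_end = 0;
  bool open = false;
  for (const auto& [begin_ps, end_ps] : intervals) {
    if (!open || begin_ps > cur_end) {
      if (open) total_ps += cur_end - cur_begin;  // close the merged interval
      cur_begin = begin_ps;
      cur_end = end_ps;
      open = true;
    } else {
      cur_end = std::max(cur_end, end_ps);  // extend the current interval
    }
  }
  if (open) total_ps += cur_end - cur_begin;
  return total_ps;
}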
-DutyCycleTracker ConstructDutyCycleTracker(XPlaneVisitor& visitor) {
-  DutyCycleTracker duty_cycle_tracker;
-  visitor.ForEachLine([&](const XLineVisitor& line) {
-    if (line.Name() == kXlaOpLineName) {
-      line.ForEachEvent([&](const XEventVisitor& event) {
-        auto hlo_category_stat = event.GetStat(StatType::kHloCategory);
-        duty_cycle_tracker.AddInterval(
-            event.GetTimespan(),
-            !(hlo_category_stat &&
-              tsl::profiler::IsOffDutyOp(hlo_category_stat->StrOrRefValue())));
-      });
-    } else if (line.Name() == kSparseCoreOpLineName) {
-      line.ForEachEvent([&](const XEventVisitor& event) {
-        duty_cycle_tracker.AddInterval(
-            event.GetTimespan(),
-            // TODO(b/397774568): Add support for SparseCore off-duty ops.
-            /*is_active=*/true);
-      });
-    } else if (line.Name() == tsl::profiler::kXlaModuleLineName ||
-               line.Name() == tsl::profiler::kSparseCoreModuleLineName) {
-      line.ForEachEvent([&](const XEventVisitor& event) {
-        duty_cycle_tracker.AddInterval(event.GetTimespan(),
-                                       /*is_active=*/false);
-        return;
-      });
-    }
-  });
-  return duty_cycle_tracker;
-}
-
-OpStats ConvertXSpaceToOpStats(const XSpace& space,
-                               const OpStatsOptions& options) {
-  OpStats op_stats;
-  StepEvents step_events;
-  PropagateXSpaceDiagnosticsToOpStats(space, &op_stats);
-  // Convert device planes.
-  OpMetricsDbCombiner op_metrics_db_combiner(
-      op_stats.mutable_device_op_metrics_db());
-  SetRunEnvironment(space, op_stats.mutable_run_environment());
-
-  KernelReportMap reports;
-
-  // Handle device planes first. device_planes will contain either GPU or TPU.
-  std::vector<const XPlane*> device_planes =
-      FindPlanesWithPrefix(space, kTpuPlanePrefix);
-  const bool is_gpu = device_planes.empty();
-  if (is_gpu) {
-    device_planes = FindPlanesWithPrefix(space, kGpuPlanePrefix);
-  }
-  const bool is_tpu = !is_gpu;
-  std::string hostname = Hostname(space);
-  auto& core_id_to_details_map = *op_stats.mutable_core_id_to_details();
-  if (is_gpu) {
-    core_id_to_details_map[kDefaultGpuLocalCoreId].set_hostname(hostname);
-  }
-  DutyCycleCombiner duty_cycle_combiner;
-  // TODO(b/161942993) parallelize XPlane processing per thread.
-  HloModuleMap hlo_module_map;
-  if (options.generate_kernel_stats_db ||
-      (is_tpu && options.generate_op_metrics_db)) {
-    tensorflow::profiler::HloCostAnalysisWrapper::Factory create_cost_analysis =
-        []() { return nullptr; };
-    if (is_gpu) {
-      create_cost_analysis = []() {
-        return tensorflow::profiler::CreateXprofGpuCostAnalysis();
-      };
-    }
-    ProcessHloModuleMapFromXSpace(hlo_module_map, &space, create_cost_analysis);
-  }
-  for (const XPlane* device_trace : device_planes) {
-    if (options.generate_op_metrics_db) {
-      if (!op_stats.has_perf_env()) {
-        *op_stats.mutable_perf_env() = GetPerfEnvFromXPlane(*device_trace);
-      }
-      if (!is_tpu) {
-        OpMetricsDb device_op_metrics_db =
-            ConvertDeviceTraceXPlaneToOpMetricsDb(*device_trace,
-                                                  hlo_module_map);
-        op_metrics_db_combiner.Combine(device_op_metrics_db);
-      } else {
-        // TODO(b/397774568): Remove this once the SparseCore OpMetricsDb is
-        // implemented.
-        if (!tsl::profiler::GetSparseCoreId(device_trace->name()).has_value()) {
-          OpMetricsDb device_op_metrics_db =
-              ConvertTpuDeviceTraceXPlaneToOpMetricsDb(*device_trace);
-          UpdateOpMetricsDbFromHloModuleMap(device_op_metrics_db,
-                                            hlo_module_map);
-          op_metrics_db_combiner.Combine(device_op_metrics_db);
-        }
-      }
-    }
-    if (options.generate_step_db) {
-      StepEvents device_step_events =
-          ConvertDeviceTraceXPlaneToStepEvents(*device_trace);
-      if (is_tpu) {
-        // In TPU, we take the intersection of step events across cores as well
-        // as hosts. See b/158249775 and cl/331842545.
-  for (const XPlane* device_trace : device_planes) {
-    if (options.generate_op_metrics_db) {
-      if (!op_stats.has_perf_env()) {
-        *op_stats.mutable_perf_env() = GetPerfEnvFromXPlane(*device_trace);
-      }
-      if (!is_tpu) {
-        OpMetricsDb device_op_metrics_db =
-            ConvertDeviceTraceXPlaneToOpMetricsDb(*device_trace,
-                                                  hlo_module_map);
-        op_metrics_db_combiner.Combine(device_op_metrics_db);
-      } else {
-        // TODO(b/397774568): Remove this once the SparseCore OpMetricsDb is
-        // implemented.
-        if (!tsl::profiler::GetSparseCoreId(device_trace->name()).has_value()) {
-          OpMetricsDb device_op_metrics_db =
-              ConvertTpuDeviceTraceXPlaneToOpMetricsDb(*device_trace);
-          UpdateOpMetricsDbFromHloModuleMap(device_op_metrics_db,
-                                            hlo_module_map);
-          op_metrics_db_combiner.Combine(device_op_metrics_db);
-        }
-      }
-    }
-    if (options.generate_step_db) {
-      StepEvents device_step_events =
-          ConvertDeviceTraceXPlaneToStepEvents(*device_trace);
-      if (is_tpu) {
-        // On TPU, we take the intersection of step events across cores as
-        // well as hosts. See b/158249775 and cl/331842545.
-        IntersectCombineStepEvents(device_step_events, &step_events);
-      } else {
-        UnionCombineStepEvents(device_step_events, &step_events);
-      }
-    }
-    if (options.generate_kernel_stats_db) {
-      ConvertDeviceTraceXPlaneToKernelReports(
-          *device_trace,
-          // TODO(cleanup): Move this to xplane_to_kernel_stats_db.cc.
-          [&](const GpuEventStats& stats, KernelReport* kernel) {
-            if (!stats.IsXlaOp()) return;
-            const HloInstructionWrapper* hlo_instruction = GetHloInstruction(
-                hlo_module_map, stats.program_id, stats.hlo_op_names.back());
-            if (hlo_instruction != nullptr) {
-              kernel->set_op_name(std::string(hlo_instruction->TfOpName()));
-              bool tc_eligible = IsOpTensorCoreEligible(kernel->op_name());
-              if (VLOG_IS_ON(1) && !tc_eligible &&
-                  kernel->is_kernel_using_tensor_core()) {
-                VLOG(1) << "Detected new Op using TensorCores: "
-                        << kernel->op_name() << std::endl;
-              }
-              kernel->set_is_op_tensor_core_eligible(
-                  tc_eligible || kernel->is_op_tensor_core_eligible());
-            }
-          },
-          &reports);
-    }
-    XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(device_trace);
-    DutyCycleTracker duty_cycle_tracker = ConstructDutyCycleTracker(visitor);
-    if (std::optional<XStatVisitor> core_details_stat =
-            visitor.GetStat(StatType::kCoreDetails)) {
-      CoreDetails core_details;
-      absl::string_view core_details_bytes = core_details_stat->BytesValue();
-      if (core_details.ParseFromArray(core_details_bytes.data(),
-                                      core_details_bytes.size())) {
-        core_details.set_hostname(hostname);
-        // This is a backfill for XPlanes that were created before this field
-        // was added.
-        core_details.set_is_sparse_core(
-            tsl::profiler::GetSparseCoreId(device_trace->name()).has_value());
-        core_id_to_details_map[device_trace->id()] = core_details;
-      }
-    }
-    if (core_id_to_details_map.contains(device_trace->id())) {
-      CoreDetails& core_details = core_id_to_details_map[device_trace->id()];
-      duty_cycle_combiner.CombineCore(duty_cycle_tracker,
-                                      core_details.local_chip_id());
-    } else {
-      LOG(WARNING) << "No CoreDetails found for TPU device plane: "
-                   << device_trace->name();
-      duty_cycle_combiner.CombineChip(duty_cycle_tracker);
-    }
-  }
-
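-  // On TPU, idle/busy time in the device OpMetricsDb comes from the combined
-  // duty cycle rather than from per-op accounting. E.g. one TensorCore plus
-  // one SparseCore on the same chip (see MultiCoreChipBusyAndIdleTimeTest
-  // below) combine to busy_time_ps = 40 and idle_time_ps = 10.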
-  if (is_tpu) {
-    OpMetricsDb& op_metrics_db = *op_stats.mutable_device_op_metrics_db();
-    op_metrics_db.set_idle_time_ps(duty_cycle_combiner.GetTotalIdleTimePs());
-    op_metrics_db.set_busy_time_ps(duty_cycle_combiner.GetTotalActiveTimePs());
-  }
-
-  // Combine into reports.
-  if (options.generate_kernel_stats_db) {
-    CopyTopKDurationKernelReportsToDb(reports,
-                                      op_stats.mutable_kernel_stats_db());
-  }
-
-  bool has_device = !device_planes.empty();
-
-  // Convert a host plane.
-  const XPlane* host_plane = FindPlaneWithName(space, kHostThreadsPlaneName);
-  if (host_plane) {
-    if (options.generate_op_metrics_db) {
-      *op_stats.mutable_host_op_metrics_db() =
-          ConvertHostThreadsXPlaneToOpMetricsDb(*host_plane);
-    }
-    if (options.generate_step_db && !has_device) {
-      StepEvents host_step_events =
-          ConvertHostThreadsXPlaneToStepEvents(*host_plane, nullptr);
-      UnionCombineStepEvents(host_step_events, &step_events);
-    }
-    XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(host_plane);
-    auto stat = visitor.GetStat(StatType::kMatrixUnitUtilizationPercent);
-    if (stat.has_value()) {
-      op_stats.mutable_performance_counter_result()
-          ->set_matrix_unit_utilization_percent(stat->DoubleValue());
-    }
-    TfFunctionDb* tf_function_db = op_stats.mutable_tf_function_db();
-    visitor.ForEachLine([&](const XLineVisitor& line) {
-      CombineTfFunctionDb(ConvertHostThreadsXLineToTfFunctionDb(line),
-                          tf_function_db);
-    });
-  }
-  if (options.generate_step_db) {
-    if (is_tpu) {
-      // TPU steps rely on the step numbers in the XPlane's step line, which
-      // has already dropped incomplete steps at both the beginning and the
-      // end.
-      *op_stats.mutable_step_db() = ConvertStepEventsToStepDb(
-          has_device, /*maybe_drop_incomplete_steps=*/false, step_events);
-      *op_stats.mutable_device_op_metrics_db()->mutable_precision_stats() =
-          ComputePrecisionStats(step_events);
-      OpMetricsDbCombiner combiner(
-          op_stats.mutable_hlo_metrics_db_complete_steps_only());
-      for (const auto& step_info : op_stats.step_db().step_sequence()) {
-        combiner.Combine(step_info.hlo_metrics_db());
-      }
-    } else {
-      StepEvents nonoverlapped_step_events =
-          ToNonOverlappedStepEvents(step_events);
-      *op_stats.mutable_step_db() = ConvertStepEventsToStepDb(
-          has_device, options.maybe_drop_incomplete_steps,
-          nonoverlapped_step_events);
-      *op_stats.mutable_device_op_metrics_db()->mutable_precision_stats() =
-          ComputePrecisionStats(nonoverlapped_step_events);
-    }
-  }
-
-  // Set the program_id_to_name map in OpStats from the XSpace.
-  // This is a no-op if the space does not have materialized device traces.
-  HloProtoMap hlo_proto_map;
-  hlo_proto_map.AddHloProtosFromXSpace(space);
-  SetProgramIdToNameMap(hlo_proto_map, op_stats);
-
-  return op_stats;
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.h b/tensorflow/core/profiler/convert/xplane_to_op_stats.h
index 7a2f103c714420..4f116494393761 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats.h
+++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.h
@@ -16,52 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_STATS_H_
 #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_STATS_H_
 
-#include <vector>
-
-#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h"
-#include "tensorflow/core/profiler/convert/repository.h"
-#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
-#include "tensorflow/core/profiler/utils/hlo_proto_map.h"
-#include "tensorflow/core/profiler/utils/xplane_visitor.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-struct OpStatsOptions {
-  bool maybe_drop_incomplete_steps = false;
-  bool generate_op_metrics_db = false;
-  bool generate_step_db = false;
-  bool generate_kernel_stats_db = false;
-};
-
-// NOTE: call GroupTfEvents beforehand if OpStats.step_db needs to be
-// generated.
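-// Converts one host's XSpace into OpStats. The OpStatsOptions flags above
-// select which sub-databases are populated.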
-OpStats ConvertXSpaceToOpStats(const XSpace& space,
-                               const OpStatsOptions& options);
-
-// Populates the program_id_to_name map in OpStats.
-void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map,
-                           tensorflow::profiler::OpStats& op_stats);
-
-// Populates the given RunEnvironment with data from XSpace.
-void SetRunEnvironment(const XSpace& space, RunEnvironment* env);
-
-// Propagate and dedup the diagnostics in XSpace and add to OpStats.
-void PropagateXSpaceDiagnosticsToOpStats(const XSpace& space,
-                                         OpStats* op_stats);
-
-// Populates PerfEnv.
-PerfEnv MakePerfEnv(double peak_tera_flops_per_second,
-                    std::vector<double> peak_bws);
-
-// Extracts PerfEnv from XPlane stats.
-PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane);
-
-// Constructs a DutyCycleTracker from the given XPlaneVisitor.
-DutyCycleTracker ConstructDutyCycleTracker(XPlaneVisitor& visitor);
-
-}  // namespace profiler
-}  // namespace tensorflow
+#include "xprof/convert/xplane_to_op_stats.h"  // from @org_xprof  // IWYU pragma: export
 
 #endif  // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_STATS_H_
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
deleted file mode 100644
index 91009537b23aca..00000000000000
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
+++ /dev/null
@@ -1,826 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" - -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/status.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" -#include "tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using ::testing::Property; -using ::testing::UnorderedElementsAre; - -TEST(ConvertXPlaneToOpStats, GpuPerfEnv) { - auto space = std::make_unique(); - constexpr double kMaxError = 0.01; - constexpr int kClockRateKHz = 1530000; - constexpr int kCoreCount = 80; - constexpr uint64 kMemoryBandwidthBytesPerSecond = - uint64{900} * 1000 * 1000 * 1000; - // Volta. 
- constexpr int kComputeCapMajor = 7; - constexpr int kComputeCapMinor = 0; - - XPlaneBuilder device_plane( - GetOrCreateGpuXPlane(space.get(), /*device_ordinal=*/0)); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevVendor)), - kDeviceVendorNvidia); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("clock_rate"), - kClockRateKHz); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("core_count"), - kCoreCount); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("memory_bandwidth"), - kMemoryBandwidthBytesPerSecond); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("compute_cap_major"), - kComputeCapMajor); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("compute_cap_minor"), - kComputeCapMinor); - - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStatsOptions options; - options.generate_op_metrics_db = true; - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - const PerfEnv& perf_env = op_stats.perf_env(); - // Change to lower flops number that we do not use sum of the tensor core peak - // flops and the cuda core peak flops together as peak flops. Only use the - // tensor core peak flops as all those white papers are using. - EXPECT_NEAR(125.34, perf_env.peak_tera_flops_per_second(), kMaxError); - EXPECT_NEAR( - 900, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_HBM_RW), - kMaxError); - // Ridge point changed accordingly from above peak flops change. - EXPECT_NEAR(139.26, perf_env.ridge_point(), kMaxError); -} - -TEST(ConvertXPlaneToOpStats, GpuRunEnvironment) { - auto space = std::make_unique(); - XPlaneBuilder device_plane1( - GetOrCreateGpuXPlane(space.get(), /*device_ordinal=*/0)); - device_plane1.AddStatValue(*device_plane1.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevVendor)), - kDeviceVendorNvidia); - XPlaneBuilder device_plane2( - GetOrCreateGpuXPlane(space.get(), /*device_ordinal=*/1)); - device_plane2.AddStatValue(*device_plane2.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevVendor)), - kDeviceVendorNvidia); - - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot_or.value(), OpStatsOptions(), &op_stats)); - const RunEnvironment& run_env = op_stats.run_environment(); - - EXPECT_EQ("Nvidia GPU", run_env.device_type()); - EXPECT_EQ(1, run_env.host_count()); - EXPECT_EQ(1, run_env.task_count()); - EXPECT_EQ(2, run_env.device_core_count()); -} - -TEST(ConvertXPlaneToOpStats, CpuOnlyStepDbTest) { - constexpr int64_t kStepNum = 123; - constexpr int64_t kStepId = 0; - - auto space = std::make_unique(); - XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(space.get())); - host_plane_builder.ReserveLines(2); - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 0, 100, {{StatType::kStepNum, kStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 10, 90, - {{StatType::kStepId, kStepId}, - {StatType::kProducerType, int64_t{1}}, - 
{StatType::kProducerId, kStepId}}); - - auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &tf_executor_thread, - HostEventType::kExecutorStateProcess, 20, 80, - {{StatType::kStepId, kStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kStepId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 70); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - const StepDatabaseResult& step_db = op_stats.step_db(); - - EXPECT_EQ(step_db.step_sequence_size(), 1); -} - -TEST(ConvertXPlaneToOpStats, GpuStepDbTest) { - constexpr int64_t kStepNum = 123; - constexpr int64_t kStepId = 0; - constexpr int64_t kCorrelationId = 100; - - auto space = std::make_unique(); - XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(space.get())); - host_plane_builder.ReserveLines(2); - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 0, 100, {{StatType::kStepNum, kStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 10, 90, - {{StatType::kStepId, kStepId}, - {StatType::kProducerType, int64_t{1}}, - {StatType::kProducerId, kStepId}}); - - auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &tf_executor_thread, - HostEventType::kExecutorStateProcess, 20, 20, - {{StatType::kStepId, kStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kStepId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 10, - {{StatType::kCorrelationId, kCorrelationId}}); - - XPlaneBuilder device_plane_builder( - GetOrCreateGpuXPlane(space.get(), /*device_ordinal=*/0)); - device_plane_builder.ReserveLines(1); - - auto stream = device_plane_builder.GetOrCreateLine(0); - CreateXEvent(&device_plane_builder, &stream, "matmul", 50, 40, - {{StatType::kCorrelationId, kCorrelationId}}); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - const StepDatabaseResult& step_db = op_stats.step_db(); - - EXPECT_EQ(step_db.step_sequence_size(), 1); - - PrecisionStats precision_stats = - op_stats.device_op_metrics_db().precision_stats(); - EXPECT_EQ(precision_stats.compute_16bit_ps(), 0); - EXPECT_EQ(precision_stats.compute_32bit_ps(), 40); -} - -TEST(ConvertXPlaneToOpStats, PropagateAndDedupErrors) { - XSpace space; - static constexpr char kError[] = "host: error"; - *space.add_errors() = kError; - *space.add_errors() = kError; - - OpStats op_stats = ConvertXSpaceToOpStats(space, OpStatsOptions()); - - EXPECT_EQ(1, op_stats.diagnostics().errors_size()); - EXPECT_EQ(kError, op_stats.diagnostics().errors(/*index=*/0)); -} - -TEST(ConvertXPlaneToOpStats, Hostnames) { - XSpace space; - static constexpr 
char kHost[] = "host1"; - *space.add_hostnames() = kHost; - - OpStats op_stats = ConvertXSpaceToOpStats(space, OpStatsOptions()); - EXPECT_EQ( - kHost, - op_stats.core_id_to_details().at(kDefaultGpuLocalCoreId).hostname()); -} - -void BuildXSpaceForTest(XSpace& xspace, absl::string_view hostname) { - constexpr int64_t kStepNum = 123; - constexpr int64_t kStepId = 456; - // Create a host only XSpace for test. - XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(&xspace)); - host_plane_builder.ReserveLines(2); - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 0, 100, {{StatType::kStepNum, kStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 10, 90, - {{StatType::kStepId, kStepId}, - {StatType::kProducerType, int64_t{1}}, - {StatType::kProducerId, kStepId}}); - - auto executor_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &executor_thread, - HostEventType::kExecutorStateProcess, 20, 80, - {{StatType::kStepId, kStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kStepId}}); - // Create a TensorFlow op that runs for 70 ps. - CreateXEvent(&host_plane_builder, &executor_thread, "aaa:bbb", 30, 70); - xspace.add_hostnames(std::string(hostname)); -} - -TEST(ConvertXPlaneToOpStats, TestConvertMultiXSpacesToCombinedOpStats) { - static constexpr char kHost1[] = "host1"; - static constexpr char kHost2[] = "host2"; - - auto xspace1 = std::make_unique(); - auto xspace2 = std::make_unique(); - - BuildXSpaceForTest(*xspace1, kHost1); - BuildXSpaceForTest(*xspace2, kHost2); - - std::vector xspace_paths; - xspace_paths.push_back("host1.pb"); - xspace_paths.push_back("host2.pb"); - - std::vector> xspaces; - xspaces.push_back(std::move(xspace1)); - xspaces.push_back(std::move(xspace2)); - - auto session_snapshot_or = - SessionSnapshot::Create(std::move(xspace_paths), std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - OpStats combined_op_stats; - - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &combined_op_stats)) - << "Failed to convert multi XSpace to OpStats"; - - // Result OpStats has 2 Host Ops, "IDLE" and "aaa:bbb". - ASSERT_EQ(combined_op_stats.host_op_metrics_db().metrics_db_size(), 2); - const auto& metric = combined_op_stats.host_op_metrics_db().metrics_db(1); - EXPECT_EQ(metric.name(), "aaa"); - EXPECT_EQ(metric.category(), "bbb"); - // Each host has the HostOp "aaa:bbb" running for 70 ps, so the combined - // OpStats has "aaa:bbb" running for 140 ps in total. - EXPECT_EQ(metric.self_time_ps(), 140); - - // Result OpStats has 1 step, 2 cores. - ASSERT_EQ(combined_op_stats.step_db().step_sequence_size(), 1); - ASSERT_EQ( - combined_op_stats.step_db().step_sequence(0).step_info_per_core_size(), - 2); - const auto& step_info_per_core = - combined_op_stats.step_db().step_sequence(0).step_info_per_core(); - // global_core_id is computed using: 1000 * host_id + local_core_id. 
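-  // Here host1 maps to host_id 0 and host2 to host_id 1, so the expected
-  // core ids are kDefaultGpuLocalCoreId and 1000 + kDefaultGpuLocalCoreId.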
- EXPECT_TRUE(step_info_per_core.contains(kDefaultGpuLocalCoreId)); - EXPECT_TRUE(step_info_per_core.contains(1000 + kDefaultGpuLocalCoreId)); - - const auto& core_details_map = combined_op_stats.core_id_to_details(); - EXPECT_EQ(kHost1, core_details_map.at(kDefaultGpuLocalCoreId).hostname()); - EXPECT_EQ(kHost2, - core_details_map.at(1000 + kDefaultGpuLocalCoreId).hostname()); -} - -TEST(ConvertXPlaneToOpStats, RunEnvironmentExtractedFromTpuPlane) { - XSpace xspace; - for (int i : {0, 1, 2, 3}) { - GetOrCreateTpuXPlane(&xspace, i, "TPU V4", 0, 0); - } - - OpStats op_stats = ConvertXSpaceToOpStats(xspace, OpStatsOptions()); - - EXPECT_EQ(op_stats.run_environment().device_type(), "TPU V4"); - EXPECT_EQ(op_stats.run_environment().device_core_count(), 4); -} - -TEST(ConvertXPlaneToOpStats, TpuPerfEnv) { - auto space = std::make_unique(); - constexpr double kMaxError = 0.01; - constexpr int kClockRateKHz = 1530000; - constexpr int kCoreCount = 80; - constexpr uint64 kMemoryBandwidthBytesPerSecond = - uint64{900} * 1000 * 1000 * 1000; - // Volta. - constexpr int kComputeCapMajor = 7; - constexpr int kComputeCapMinor = 0; - constexpr double kDevCapPeakTeraflopsPerSecond = 141.0; - constexpr double kDevCapPeakHbmBwGigabytesPerSecond = 900.0; - constexpr double kDevCapPeakSramRdBwGigabytesPerSecond = 101.0; - constexpr double kDevCapPeakSramWrBwGigabytesPerSecond = 102.0; - constexpr double kDevCapPeakCmemRdBwGigabytesPerSecond = 101.0; - constexpr double kDevCapPeakCmemWrBwGigabytesPerSecond = 102.0; - constexpr double kDevCapPeakVmemRdBwGigabytesPerSecond = 201.0; - constexpr double kDevCapPeakVmemWrBwGigabytesPerSecond = 202.0; - - XPlaneBuilder device_plane(GetOrCreateTpuXPlane( - space.get(), /*device_ordinal=*/0, "TPU V4", - kDevCapPeakTeraflopsPerSecond, kDevCapPeakHbmBwGigabytesPerSecond)); - /*device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevVendor)), - kDeviceVendorNvidia); // "Google, Inc.");*/ - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("clock_rate"), - kClockRateKHz); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("core_count"), - kCoreCount); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("memory_bandwidth"), - kMemoryBandwidthBytesPerSecond); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("compute_cap_major"), - kComputeCapMajor); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("compute_cap_minor"), - kComputeCapMinor); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_sram_rd_bw_gigabytes_per_second"), - kDevCapPeakSramRdBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_sram_wr_bw_gigabytes_per_second"), - kDevCapPeakSramWrBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_cmem_rd_bw_gigabytes_per_second"), - kDevCapPeakCmemRdBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_cmem_wr_bw_gigabytes_per_second"), - kDevCapPeakCmemWrBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_vmem_rd_bw_gigabytes_per_second"), - kDevCapPeakVmemRdBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_vmem_wr_bw_gigabytes_per_second"), - kDevCapPeakVmemWrBwGigabytesPerSecond); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - std::vector> xspaces; - 
xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - const PerfEnv& perf_env = op_stats.perf_env(); - EXPECT_NEAR(kDevCapPeakTeraflopsPerSecond, - perf_env.peak_tera_flops_per_second(), kMaxError); - EXPECT_NEAR( - kDevCapPeakHbmBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_HBM_RW), - kMaxError); - EXPECT_NEAR( - kDevCapPeakSramRdBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_SRAM_RD), - kMaxError); - EXPECT_NEAR( - kDevCapPeakSramWrBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_SRAM_WR), - kMaxError); - EXPECT_NEAR( - kDevCapPeakCmemRdBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_CMEM_RD), - kMaxError); - EXPECT_NEAR( - kDevCapPeakCmemWrBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_CMEM_WR), - kMaxError); - EXPECT_NEAR( - kDevCapPeakVmemRdBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_VMEM_RD), - kMaxError); - EXPECT_NEAR( - kDevCapPeakVmemWrBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_VMEM_WR), - kMaxError); - EXPECT_NEAR(156.67, perf_env.ridge_point(), kMaxError); -} - -TEST(ConvertXPlaneToOpStats, TpuRunEnvironment) { - auto space = std::make_unique(); - XPlaneBuilder device_plane1( - GetOrCreateTpuXPlane(space.get(), /*device_ordinal=*/0, "TPU V4", 0, 0)); - XPlaneBuilder device_plane2( - GetOrCreateTpuXPlane(space.get(), /*device_ordinal=*/1, "TPU V4", 0, 0)); - - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot_or.value(), OpStatsOptions(), &op_stats)); - const RunEnvironment& run_env = op_stats.run_environment(); - - EXPECT_EQ("TPU V4", run_env.device_type()); - EXPECT_EQ(1, run_env.host_count()); - EXPECT_EQ(1, run_env.task_count()); - EXPECT_EQ(2, run_env.device_core_count()); -} - -TEST(ConvertXPlaneToOpStats, TpuDeviceTraceToStepDb) { - auto space = std::make_unique(); - constexpr double kDevCapPeakTeraflopsPerSecond = 141.0; - constexpr double kDevCapPeakHbmBwGigabytesPerSecond = 1000.0; - XPlaneBuilder xplane_builder(GetOrCreateTpuXPlane( - space.get(), /*device_ordinal=*/0, "TPU V4", - kDevCapPeakTeraflopsPerSecond, kDevCapPeakHbmBwGigabytesPerSecond)); - - XEventMetadata* event_metadata = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata->set_name("op_name"); - XStatsBuilder stats(event_metadata, &xplane_builder); - - stats.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kProgramId)), - 1); - stats.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kSymbolId)), - 1); - stats.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kSelfDurationPs)), - 10); - stats.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), - "tf_op_name"); - stats.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kHloCategory)), - "category"); - XLineBuilder line = 
xplane_builder.GetOrCreateLine(1); - line.SetName(kTensorFlowOpLineName); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetNs(0); - event.SetDurationNs(10); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - EXPECT_THAT(op_stats.device_op_metrics_db().metrics_db(), - UnorderedElementsAre(Property(&OpMetrics::name, "op_name"), - Property(&OpMetrics::name, "IDLE"))); -} - -// Verifies that the step db is generated correctly by intersecting for -// multi-device TPU. -TEST(ConvertXPlaneToOpStats, TpuMultiDeviceStepDbTest) { - auto space = std::make_unique(); - - XPlaneBuilder device_plane_builder1( - GetOrCreateTpuXPlane(space.get(), /*device_ordinal=*/0, "TPU V4", 0, 0)); - XPlaneBuilder device_plane_builder2( - GetOrCreateTpuXPlane(space.get(), /*device_ordinal=*/1, "TPU V4", 0, 0)); - device_plane_builder1.ReserveLines(1); - device_plane_builder2.ReserveLines(1); - - // Create 1 step in xplane in TPU ordinal 0. - XStatMetadata* kGroupId1 = device_plane_builder1.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kGroupId)); - XLineBuilder line = device_plane_builder1.GetOrCreateLine(1); - line.SetName(kXlaOpLineName); - // Step 1 - XEventMetadata* event_metadata = - device_plane_builder1.GetOrCreateEventMetadata(1); - event_metadata->set_name("Step 1"); - XEventBuilder event_builder = line.AddEvent(*event_metadata); - event_builder.AddStatValue(*kGroupId1, 1); // step num - event_builder.SetDurationNs(100); - event_builder.SetOffsetNs(100); - - // Create 2 steps in xplane in TPU ordinal 1. - line = device_plane_builder2.GetOrCreateLine(1); - line.SetName(kXlaOpLineName); - // Step 1 - XStatMetadata* kGroupId2 = device_plane_builder2.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kGroupId)); - XEventMetadata* event_metadata2 = - device_plane_builder2.GetOrCreateEventMetadata(2); - event_metadata2->set_name("Step 1"); - XEventBuilder event_builder2 = line.AddEvent(*event_metadata2); - event_builder2.AddStatValue(*kGroupId2, 1); // step num - event_builder2.SetDurationNs(100); - event_builder2.SetOffsetNs(300); - // Step 2 - XStatMetadata* kGroupId3 = device_plane_builder2.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kGroupId)); - XEventMetadata* event_metadata3 = - device_plane_builder2.GetOrCreateEventMetadata(2); - event_metadata3->set_name("Step 2"); - XEventBuilder event_builder3 = line.AddEvent(*event_metadata3); - event_builder3.AddStatValue(*kGroupId3, 2); // step num - event_builder3.SetDurationNs(100); - event_builder3.SetOffsetNs(300); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - OpStats op_stats = ConvertXSpaceToOpStats(*space, options); - const StepDatabaseResult& step_db = op_stats.step_db(); - // For TPU step events, we intersect the step events by step num across - // different TPU devices. 
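-  // Device 0 recorded only step 1 while device 1 recorded steps 1 and 2, so
-  // the intersection keeps a single step.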
- EXPECT_EQ(step_db.step_sequence_size(), 1); -} - -TEST(ConvertXPlaneToOpStats, ConstructDutyCycleTrackerFromXlaOps) { - XSpace space; - XPlane* device_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/0, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder device_plane_builder(device_plane); - XLineBuilder op_line = device_plane_builder.GetOrCreateLine(0); - op_line.SetName(kXlaOpLineName); - CreateXEvent(&device_plane_builder, &op_line, "op.1", /*offset_ps=*/10, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloInfeed}}); - CreateXEvent(&device_plane_builder, &op_line, "op.2", /*offset_ps=*/20, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloCall}}); - CreateXEvent(&device_plane_builder, &op_line, "op.3", /*offset_ps=*/30, - /*duration_ps=*/10); - CreateXEvent(&device_plane_builder, &op_line, "op.4", /*offset_ps=*/40, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloOutfeed}}); - XLineBuilder xla_module_line = device_plane_builder.GetOrCreateLine(1); - xla_module_line.SetName(kXlaModuleLineName); - CreateXEvent(&device_plane_builder, &xla_module_line, "module.1", - /*offset_ps=*/5, - /*duration_ps=*/50); - - XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(device_plane); - DutyCycleTracker tracker = ConstructDutyCycleTracker(visitor); - EXPECT_EQ(tracker.GetActiveTimePs(), 20); - EXPECT_EQ(tracker.GetIdleTimePs(), 30); -} - -TEST(ConvertXPlaneToOpStats, ConstructDutyCycleTrackerFromSparseCore) { - XSpace space; - XPlane* sc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/0, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder sc_plane_builder(sc_plane); - XLineBuilder op_line = sc_plane_builder.GetOrCreateLine(0); - op_line.SetName(kSparseCoreOpLineName); - CreateXEvent(&sc_plane_builder, &op_line, "op.1", /*offset_ps=*/10, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &op_line, "op.2", /*offset_ps=*/20, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &op_line, "op.3", /*offset_ps=*/30, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &op_line, "op.4", /*offset_ps=*/40, - /*duration_ps=*/10); - XLineBuilder module_line = sc_plane_builder.GetOrCreateLine(1); - module_line.SetName(kSparseCoreModuleLineName); - CreateXEvent(&sc_plane_builder, &module_line, "module.1", /*offset_ps=*/5, - /*duration_ps=*/50); - - XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(sc_plane); - DutyCycleTracker tracker = ConstructDutyCycleTracker(visitor); - EXPECT_EQ(tracker.GetActiveTimePs(), 40); - EXPECT_EQ(tracker.GetIdleTimePs(), 10); -} - -TEST(ConvertXPlaneToOpStats, MultiCoreChipBusyAndIdleTimeTest) { - XSpace space; - CoreDetails tc_core_details; - tc_core_details.set_local_chip_id(0); - CoreDetails sc_core_details; - sc_core_details.set_local_chip_id(0); - XPlane* tc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/0, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder tc_plane_builder(tc_plane); - tc_plane_builder.AddStatValue(*tc_plane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCoreDetails)), - tc_core_details); - XLineBuilder xla_op_line = tc_plane_builder.GetOrCreateLine(0); - xla_op_line.SetName(kXlaOpLineName); - CreateXEvent(&tc_plane_builder, &xla_op_line, "op.1", /*offset_ps=*/10, - /*duration_ps=*/10, - {{StatType::kHloCategory, 
tsl::profiler::kHloInfeed}}); - CreateXEvent(&tc_plane_builder, &xla_op_line, "op.2", /*offset_ps=*/20, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloCall}}); - CreateXEvent(&tc_plane_builder, &xla_op_line, "op.3", /*offset_ps=*/30, - /*duration_ps=*/10); - CreateXEvent(&tc_plane_builder, &xla_op_line, "op.4", /*offset_ps=*/40, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloOutfeed}}); - XLineBuilder xla_module_line = tc_plane_builder.GetOrCreateLine(1); - xla_module_line.SetName(kXlaModuleLineName); - CreateXEvent(&tc_plane_builder, &xla_module_line, "module.1", /*offset_ps=*/5, - /*duration_ps=*/50); - - XPlane* sc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/1, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder sc_plane_builder(sc_plane); - sc_plane_builder.AddStatValue(*sc_plane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCoreDetails)), - sc_core_details); - XLineBuilder sc_op_line = sc_plane_builder.GetOrCreateLine(0); - sc_op_line.SetName(kSparseCoreOpLineName); - CreateXEvent(&sc_plane_builder, &sc_op_line, "op.1", /*offset_ps=*/10, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &sc_op_line, "op.2", /*offset_ps=*/20, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &sc_op_line, "op.3", /*offset_ps=*/30, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &sc_op_line, "op.4", /*offset_ps=*/40, - /*duration_ps=*/10); - XLineBuilder sc_module_line = sc_plane_builder.GetOrCreateLine(1); - sc_module_line.SetName(kSparseCoreModuleLineName); - CreateXEvent(&sc_plane_builder, &sc_module_line, "module.1", /*offset_ps=*/5, - /*duration_ps=*/50); - - OpStats op_stats = ConvertXSpaceToOpStats(space, OpStatsOptions()); - EXPECT_EQ(op_stats.device_op_metrics_db().idle_time_ps(), 10); - EXPECT_EQ(op_stats.device_op_metrics_db().busy_time_ps(), 40); -} - -TEST(ConvertXPlaneToOpStats, HandleSparseCoreBusyOpMetrics) { - XSpace space; - XPlane* tc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/0, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder tc_plane_builder(tc_plane); - tc_plane_builder.SetId(0); - XLineBuilder tc_step_line = tc_plane_builder.GetOrCreateLine(0); - tc_step_line.SetName(tsl::profiler::kStepLineName); - CreateXEvent(&tc_plane_builder, &tc_step_line, "step.1", /*offset_ps=*/10, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&tc_plane_builder, &tc_step_line, "step.2", /*offset_ps=*/20, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&tc_plane_builder, &tc_step_line, "step.3", /*offset_ps=*/30, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&tc_plane_builder, &tc_step_line, "step.4", /*offset_ps=*/40, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{4}}}); - XLineBuilder tc_module_line = tc_plane_builder.GetOrCreateLine(1); - tc_module_line.SetName(tsl::profiler::kXlaModuleLineName); - CreateXEvent(&tc_plane_builder, &tc_module_line, "module.1", /*offset_ps=*/10, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&tc_plane_builder, &tc_module_line, "module.2", /*offset_ps=*/20, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&tc_plane_builder, &tc_module_line, "module.3", /*offset_ps=*/30, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&tc_plane_builder, &tc_module_line, "module.4", 
/*offset_ps=*/40, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{4}}}); - XLineBuilder tc_op_line = tc_plane_builder.GetOrCreateLine(2); - tc_op_line.SetName(kXlaOpLineName); - auto& program_id_stat = *tc_plane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kProgramId)); - auto& symbol_id_stat = *tc_plane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kSymbolId)); - XStatsBuilder op1_stats( - tc_plane_builder.GetOrCreateEventMetadata("op.1"), &tc_plane_builder); - op1_stats.AddStatValue(program_id_stat, 1); - op1_stats.AddStatValue(symbol_id_stat, 1); - XStatsBuilder op2_stats( - tc_plane_builder.GetOrCreateEventMetadata("op.2"), &tc_plane_builder); - op2_stats.AddStatValue(program_id_stat, 1); - op2_stats.AddStatValue(symbol_id_stat, 2); - XStatsBuilder op3_stats( - tc_plane_builder.GetOrCreateEventMetadata("op.3"), &tc_plane_builder); - op3_stats.AddStatValue(program_id_stat, 1); - op3_stats.AddStatValue(symbol_id_stat, 3); - XStatsBuilder op4_stats( - tc_plane_builder.GetOrCreateEventMetadata("op.4"), &tc_plane_builder); - op4_stats.AddStatValue(program_id_stat, 1); - op4_stats.AddStatValue(symbol_id_stat, 4); - CreateXEvent(&tc_plane_builder, &tc_op_line, "op.1", /*offset_ps=*/15, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&tc_plane_builder, &tc_op_line, "op.2", /*offset_ps=*/25, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&tc_plane_builder, &tc_op_line, "op.3", /*offset_ps=*/35, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&tc_plane_builder, &tc_op_line, "op.4", /*offset_ps=*/45, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{4}}}); - XPlane* sc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/1, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder sc_plane_builder(sc_plane); - sc_plane_builder.SetId(1); - sc_plane_builder.SetName( - absl::StrCat(sc_plane->name(), " SparseCore ", sc_plane->id())); - XLineBuilder sc_step_line = sc_plane_builder.GetOrCreateLine(0); - sc_step_line.SetName(tsl::profiler::kSparseCoreStepLineName); - CreateXEvent(&sc_plane_builder, &sc_step_line, "step.1", /*offset_ps=*/10, - /*duration_ps=*/10, - {{StatType::kStepIdleTimePs, int64_t{5}}, - {StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&sc_plane_builder, &sc_step_line, "step.2", /*offset_ps=*/20, - /*duration_ps=*/10, - {{StatType::kStepIdleTimePs, int64_t{5}}, - {StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&sc_plane_builder, &sc_step_line, "step.3", /*offset_ps=*/30, - /*duration_ps=*/10, - {{StatType::kStepIdleTimePs, int64_t{5}}, - {StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&sc_plane_builder, &sc_step_line, "step.4", /*offset_ps=*/40, - /*duration_ps=*/10, - {{StatType::kStepIdleTimePs, int64_t{5}}, - {StatType::kGroupId, int64_t{4}}}); - XLineBuilder sc_module_line = sc_plane_builder.GetOrCreateLine(1); - sc_module_line.SetName(kSparseCoreModuleLineName); - CreateXEvent(&sc_plane_builder, &sc_module_line, "module.1", /*offset_ps=*/10, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&sc_plane_builder, &sc_module_line, "module.2", /*offset_ps=*/20, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&sc_plane_builder, &sc_module_line, "module.3", /*offset_ps=*/30, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&sc_plane_builder, &sc_module_line, "module.4", /*offset_ps=*/40, - /*duration_ps=*/10, 
{{StatType::kGroupId, int64_t{4}}}); - XLineBuilder sc_op_line = sc_plane_builder.GetOrCreateLine(2); - sc_op_line.SetName(kSparseCoreOpLineName); - CreateXEvent(&sc_plane_builder, &sc_op_line, "scs op.1", /*offset_ps=*/15, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&sc_plane_builder, &sc_op_line, "scs op.2", /*offset_ps=*/25, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&sc_plane_builder, &sc_op_line, "scs op.3", /*offset_ps=*/35, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&sc_plane_builder, &sc_op_line, "scs op.4", /*offset_ps=*/45, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{4}}}); - OpStats op_stats = ConvertXSpaceToOpStats( - space, - OpStatsOptions{.generate_op_metrics_db = true, .generate_step_db = true}); - EXPECT_EQ(op_stats.device_op_metrics_db().total_time_ps(), 40); - EXPECT_EQ(op_stats.device_op_metrics_db().total_op_time_ps(), 20); - EXPECT_EQ(op_stats.step_db().step_sequence_size(), 4); - EXPECT_EQ(op_stats.hlo_metrics_db_complete_steps_only().total_time_ps(), 40); - EXPECT_EQ(op_stats.hlo_metrics_db_complete_steps_only().total_op_time_ps(), - 20); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc deleted file mode 100644 index 104ed52a3eb02b..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc +++ /dev/null @@ -1,394 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_step_events.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/op_metrics.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/steps_db.pb.h" // from @org_xprof -#include "xprof/utils/event_span.h" // from @org_xprof -#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -inline AllReduceInfo GetAllReduceInfo(const XEventVisitor& event, - uint64_t all_reduce_unique_id) { - AllReduceInfo collective_ops; - collective_ops.set_id(all_reduce_unique_id); - collective_ops.set_start_time_ps(event.TimestampPs()); - if (auto device_offset_ps_stat = event.GetStat(StatType::kDeviceOffsetPs)) { - collective_ops.set_start_time_ps(device_offset_ps_stat->IntOrUintValue()); - } - collective_ops.set_end_time_ps(event.EndTimestampPs()); - if (auto device_duration_ps_stat = - event.GetStat(StatType::kDeviceDurationPs)) { - collective_ops.set_end_time_ps(collective_ops.start_time_ps() + - device_duration_ps_stat->IntOrUintValue()); - } - if (auto all_reduce_id_stat = event.GetStat(StatType::kAllReduceId)) { - collective_ops.set_all_reduce_id(all_reduce_id_stat->IntOrUintValue()); - } - if (auto bytes_accessed_stat = - event.Metadata().GetStat(StatType::kBytesAccessed)) { - collective_ops.set_byte_size(bytes_accessed_stat->IntOrUintValue()); - } - return collective_ops; -} - -inline bool IsExplicitHostStepMarker(absl::string_view event_name) { - return (absl::StartsWith(event_name, "train") || - absl::StartsWith(event_name, "test") || - absl::StartsWith(event_name, "TraceContext")) && - !absl::StrContains(event_name, "/"); -} - -// Returns true if the given event_name should be considered as real computation -// on CPU. -inline bool IsRealCpuCompute(absl::string_view event_name) { - bool not_real = absl::StartsWith(event_name, "EagerExecute") || - absl::StartsWith(event_name, "EagerLocalExecute") || - absl::StartsWith(event_name, "EagerKernelExecute") || - absl::StartsWith(event_name, "FunctionRun") || - IsExplicitHostStepMarker(event_name); - return !not_real; -} - -uint64 ParseNumBytesFromMemcpyDetail(absl::string_view memcpy_detail) { - const std::vector params = - absl::StrSplit(memcpy_detail, absl::ByAnyChar(":\n")); - - // Processes value pairs. 
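-  // memcpy_detail is a list of key/value pairs separated by ':' and '\n'
-  // (e.g. "num_bytes:1024"), so keys land at even indices after the split.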
- for (uint32 ii = 0; ii < params.size(); ii += 2) { - if (params[ii] != "num_bytes") continue; - uint64 value = 0; - if (absl::SimpleAtoi(params[ii + 1], &value)) return value; - break; - } - return 0ULL; -} - -EventType ClassifyGpuCompute(absl::string_view event_name, - absl::string_view tensor_shapes) { - if (tensor_shapes.empty()) { - // Deduces the precision from the name. - return (absl::StrContains(event_name, "half") || - absl::StrContains(event_name, "fp16")) - ? DEVICE_COMPUTE_16 - : DEVICE_COMPUTE_32; - } else { - // Deduces the precision from the shapes. - return (absl::StrContains(tensor_shapes, "half")) ? DEVICE_COMPUTE_16 - : DEVICE_COMPUTE_32; - } -} - -EventType ClassifyGpuEvent(absl::string_view event_name, - absl::string_view tensor_shapes) { - tsl::profiler::TfOp tf_op = tsl::profiler::ParseTfOpFullname(event_name); - if (tsl::profiler::IsMemcpyHToDOp(tf_op)) { - return HOST_TO_DEVICE; - } else if (tsl::profiler::IsMemcpyDToHOp(tf_op)) { - return DEVICE_TO_HOST; - } else if (tsl::profiler::IsMemcpyDToDOp(tf_op)) { - return DEVICE_TO_DEVICE; - } else if (absl::StartsWithIgnoreCase(event_name, "nccl")) { - return DEVICE_COLLECTIVES; - } else { - return ClassifyGpuCompute(event_name, tensor_shapes); - } -} - -EventType ClassifyCpuEvent(absl::string_view event_name, bool has_device, - bool has_correlation_id) { - tsl::profiler::TfOp tf_op = tsl::profiler::ParseTfOpFullname(event_name); - if (tsl::profiler::IsInfeedEnqueueOp(tf_op) || - tsl::profiler::IsMemcpyHToDOp(tf_op)) { - return HOST_TO_DEVICE; - } else if (tsl::profiler::IsMemcpyHToHOp(tf_op)) { - return HOST_TO_HOST; - } else if (has_device && (has_correlation_id || - absl::StartsWithIgnoreCase( - event_name, "ExecutorState::Process"))) { - // TODO(b/150420972): Separate runtime overhead from actual compute for - // CPU-only. - return HOST_PREPARE; - } else if (absl::StartsWithIgnoreCase(event_name, "IteratorGetNext")) { - return HOST_WAIT_INPUT; - } else { - return HOST_COMPUTE; - } -} - -} // namespace - -StepEvents ConvertHostThreadsXLineToStepEvents( - const XLineVisitor& line, const StepEvents* device_step_events) { - StepEvents result; - line.ForEachEvent([&](const XEventVisitor& event) { - int64_t correlation_id = -1; - int64_t group_id = -1; - absl::string_view step_name; - event.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kCorrelationId: - correlation_id = stat.IntValue(); - break; - case StatType::kGroupId: - group_id = stat.IntValue(); - break; - case StatType::kStepName: - step_name = stat.StrOrRefValue(); - break; - } - }); - if (group_id < 0) return; - // Don't add CPU events when (1) it includes device step events and (2) it - // doesn't have a device and that the group_id (i.e. step number) already - // appears on the device. This will filter out all cpu events that do not - // correspond to any steps executed on the device. - bool has_device = (device_step_events != nullptr); - if (has_device && !device_step_events->contains(group_id)) return; - if (IsExplicitHostStepMarker(event.Name())) { - result[group_id].AddMarker( - StepMarker(StepMarkerType::kExplicitHostStepMarker, event.Name(), - event.GetTimespan())); - } else if (!step_name.empty()) { - // Grouping adds a step_name stat to implicit host step markers. 
- result[group_id].AddMarker( - StepMarker(StepMarkerType::kImplicitHostStepMarker, event.Name(), - event.GetTimespan())); - } else if (IsRealCpuCompute(event.Name())) { - result[group_id].AddEvent(EventTypeSpan( - ClassifyCpuEvent(event.Name(), has_device, correlation_id >= 0), - event.GetTimespan())); - } - if (!step_name.empty()) { - result[group_id].SetStepName(std::string(step_name)); - } - }); - return result; -} - -StepEvents ConvertHostThreadsXPlaneToStepEvents( - const XPlane& host_trace, const StepEvents* device_step_events) { - StepEvents host_step_events; - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&host_trace); - plane.ForEachLine([&](const XLineVisitor& line) { - StepEvents thread_step_events = - ConvertHostThreadsXLineToStepEvents(line, device_step_events); - UnionCombineStepEvents(thread_step_events, &host_step_events); - }); - return host_step_events; -} - -StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) { - StepEvents result; - line.ForEachEvent([&](const XEventVisitor& event) { - if (std::optional stat = event.GetStat(StatType::kGroupId)) { - result[stat->IntValue()].AddMarker( - StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(), - event.GetTimespan())); - } - }); - return result; -} - -StepEvents ConvertDeviceTraceXLineToStepEvents(const uint64 device_id, - const XLineVisitor& line) { - StepEvents result; - line.ForEachEvent([&](const XEventVisitor& event) { - int64_t correlation_id = -1; - int64_t group_id = -1; - absl::string_view tensor_shapes; - absl::string_view memcpy_details; - event.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kCorrelationId: - correlation_id = stat.IntValue(); - break; - case StatType::kGroupId: - group_id = stat.IntValue(); - break; - case StatType::kTensorShapes: - tensor_shapes = stat.StrOrRefValue(); - break; - case StatType::kMemcpyDetails: - memcpy_details = stat.StrOrRefValue(); - break; - } - }); - - if (correlation_id >= 0 && group_id >= 0) { - EventType event_type = ClassifyGpuEvent(event.Name(), tensor_shapes); - EventTypeSpan event_type_span(event_type, event.GetTimespan()); - result[group_id].AddEvent(event_type_span); - switch (event_type) { - case DEVICE_COLLECTIVES: { - AllReduceInfo collective_ops; - collective_ops.set_start_time_ps(event.TimestampPs()); - collective_ops.set_end_time_ps(event.EndOffsetPs()); - // TODO(jiesun): figure out how to get size info etc. - result[group_id].AddCollectiveOpEvent(device_id, collective_ops); - break; - } - case HOST_TO_DEVICE: - case DEVICE_TO_DEVICE: - case DEVICE_TO_HOST: { - // TODO(jiesun): not all memcpy events are grouped, figure out a - // better way to attribute them to steps. - uint64 bytes_transferred = - ParseNumBytesFromMemcpyDetail(memcpy_details); - result[group_id].AddDeviceMemoryTransferEvent( - event_type, event.GetTimespan(), bytes_transferred); - break; - } - default: - return; - } - } - }); - return result; -} - -StepEvents ConvertTpuDeviceTraceXLineToStepEvents(const uint64 device_id, - const XLineVisitor& line) { - StepEvents result; - absl::flat_hash_map - op_metrics_builder; - struct ParentRef { - const XEventVisitor event; - tsl::profiler::Timespan device_timespan; - uint64_t children_duration_ps = 0; - int64_t group_id = -1; - }; - tsl::profiler::AncestorStack event_stack( - // Adds an OpMetric to the builder based on the provided parent reference. 
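-      // (Invoked when a finished parent event is popped off the ancestor
-      //  stack, i.e. after all of its children have been visited.)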
- [&](const ParentRef& parent) { - OpMetrics op_metrics = FromXEvent(parent.event); - op_metrics.set_time_ps(parent.device_timespan.duration_ps()); - // TODO(b/397774568): Remove this once the SparseCore OpMetricsDb is - // implemented. - if (device_id < kSparseCoreIndexStart) { - op_metrics.set_self_time_ps(op_metrics.time_ps() - - parent.children_duration_ps); - } - op_metrics_builder[parent.group_id].AddOpMetric( - op_metrics, GetOpKeyFromXEvent(parent.event)); - }, - // Checks if the child event is a child of the parent event. - [](const ParentRef& parent, const ParentRef& child) { - return parent.device_timespan.Includes(child.device_timespan); - }, - // Adds the child duration to the parent. - [](ParentRef& parent, ParentRef& child) { - parent.children_duration_ps += child.device_timespan.duration_ps(); - }); - line.ForEachEvent([&](const XEventVisitor& event) { - auto group_id_stat = event.GetStat(StatType::kGroupId); - if (!group_id_stat.has_value()) return; - int64_t group_id = group_id_stat->IntOrUintValue(); - event_stack.Push(ParentRef{ - .event = event, - .device_timespan = tsl::profiler::GetDeviceEventTimespan(event), - .group_id = group_id, - }); - - if (auto all_reduce_unique_id_stat = - event.GetStat(StatType::kAllReduceUniqueId)) { - result[group_id].AddCollectiveOpEvent( - device_id, - GetAllReduceInfo(event, all_reduce_unique_id_stat->IntOrUintValue())); - } - }); - event_stack.Flush(); - for (auto& [group_id, builder] : op_metrics_builder) { - // Finalize Without the step time now. - result[group_id].SetPerCoreOpMetricsDb(builder.Finalize(), device_id); - } - return result; -} - -StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace) { - StepEvents device_step_events; - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); - std::optional tpu_core_id = tsl::profiler::GetTensorCoreId(plane.Name()); - std::optional sc_core_id = tsl::profiler::GetSparseCoreId(plane.Name()); - plane.ForEachLine([&](const XLineVisitor& line) { - int64_t line_id = line.Id(); - if (line_id == kThreadIdStepInfo || - (tpu_core_id.has_value() && - line.Name() == tsl::profiler::kStepLineName)) { - // TODO(b/397774568): Re-add processing of SparseCore steps once the - // SparseCore OpMetricsDb is implemented. - StepEvents step_marker_events = ConvertDeviceStepInfoToStepMarkers(line); - UnionCombineStepEvents(step_marker_events, &device_step_events); - } else if (IsDerivedThreadId(line_id)) { - return; - } else { - StepEvents stream_step_events; - if (tpu_core_id.has_value()) { - if (!tsl::profiler::IsOpLineName(line.Name())) return; - // In TPU sampling mode, the profiling session could stop in the middle - // of a training step. In this case, the "XLA Ops" line will have - // one more step than the "Step" line. We need to intersect them to get - // the common step numbers. - stream_step_events = - ConvertTpuDeviceTraceXLineToStepEvents(plane.Id(), line); - IntersectCombineStepEvents(stream_step_events, &device_step_events); - } else if (sc_core_id.has_value()) { - // TODO(b/397774568): Switch to IsOpLineName once SparseCore OpMetricsDb - // is implemented. 
- if (line.Name() != tsl::profiler::kSparseCoreStepLineName) return; - stream_step_events = ConvertTpuDeviceTraceXLineToStepEvents( - kSparseCoreIndexStart + plane.Id(), line); - IntersectCombineStepEvents(stream_step_events, &device_step_events); - } else { - stream_step_events = - ConvertDeviceTraceXLineToStepEvents(plane.Id(), line); - UnionCombineStepEvents(stream_step_events, &device_step_events); - } - } - }); - return device_step_events; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.h b/tensorflow/core/profiler/convert/xplane_to_step_events.h deleted file mode 100644 index 7d343b746b9fff..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ - -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "xprof/utils/event_span.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -// Convert the host threads in XLine format to StepEvents format. If -// device_step_events is non-null, we will filter out events that only happens -// on CPU. -StepEvents ConvertHostThreadsXLineToStepEvents( - const XLineVisitor& line, const StepEvents* device_step_events); - -// Convert the host threads in XPlane format to StepEvents format. If -// device_step_events is non-null, we will filter out events that only happens -// on CPU. -StepEvents ConvertHostThreadsXPlaneToStepEvents( - const XPlane& host_trace, const StepEvents* device_step_events); - -// Convert the device trace in XLine format to StepEvents. -StepEvents ConvertDeviceTraceXLineToStepEvents(const XLineVisitor& line); - -// Convert the device trace in XPlane format to StepEvents. -StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc deleted file mode 100644 index 5c97b045622a56..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc +++ /dev/null @@ -1,190 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_step_events.h" - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "xprof/utils/event_span.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -// Tests with a sample profile with two steps captured on the host but only one -// step on the device. On the host, each step consists of TraceContext -> -// FunctionRun -> ExecutorState::Process -> matmul. On the device, each step -// consists of matmul. The host's step db should be created only for the step -// observed on the device. -TEST(ConvertXPlaneToOpStats, CpuOnlyStepDbTest) { - constexpr int64_t kFirstStepNum = 123; - constexpr int64_t kSecondStepNum = 456; - constexpr int64_t kFirstStepId = 0; - constexpr int64_t kSecondStepId = 1; - constexpr int64_t kFirstCorrelationId = 100; - constexpr int64_t kSecondCorrelationId = 200; - - XSpace space; - XPlane* host_plane = GetOrCreateHostXPlane(&space); - XPlaneBuilder host_plane_builder(host_plane); - host_plane_builder.ReserveLines(2); - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 0, 100, {{StatType::kStepNum, kFirstStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 10, 90, - {{StatType::kStepId, kFirstStepId}, - {StatType::kProducerType, int64_t{1}}, - {StatType::kProducerId, kFirstStepId}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 300, 100, {{StatType::kStepNum, kSecondStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 310, 90, - {{StatType::kStepId, kSecondStepId}, - {StatType::kProducerType, int64_t{1}}, - {StatType::kProducerId, kSecondStepId}}); - - auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &tf_executor_thread, - HostEventType::kExecutorStateProcess, 20, 20, - {{StatType::kStepId, kFirstStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kFirstStepId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 10, - {{StatType::kCorrelationId, kFirstCorrelationId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, - HostEventType::kExecutorStateProcess, 320, 20, - {{StatType::kStepId, kSecondStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kSecondStepId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 330, 10, - {{StatType::kCorrelationId, kSecondCorrelationId}}); - - XPlane* device_plane = space.add_planes(); -
XPlaneBuilder device_plane_builder(device_plane); - device_plane_builder.ReserveLines(1); - - auto stream = device_plane_builder.GetOrCreateLine(0); - CreateXEvent(&device_plane_builder, &stream, "matmul", 50, 40, - {{StatType::kCorrelationId, kFirstCorrelationId}}); - - tsl::profiler::GroupTfEvents(&space); - StepEvents device_step_events = - ConvertDeviceTraceXPlaneToStepEvents(*device_plane); - EXPECT_EQ(device_step_events.size(), 1); - EXPECT_EQ(device_step_events[0].Events().size(), 1); - StepEvents host_step_events = - ConvertHostThreadsXPlaneToStepEvents(*host_plane, &device_step_events); - // Should contain only the step which is also present on the device. - EXPECT_EQ(host_step_events.size(), 1); - // TraceContext should be added as a step marker. - EXPECT_EQ(host_step_events[0].Markers().size(), 1); - // FunctionRun shouldn't be added. - EXPECT_EQ(host_step_events[0].Events().size(), 2); -} - -TEST(ConvertXPlaneToStepEvents, TpuDevicePlaneToStepEvents) { - XPlane raw_plane; - XPlaneBuilder plane(&raw_plane); - int64_t device_id = 1; - plane.SetId(device_id); - plane.SetName("/device:TPU:0"); - XLineBuilder op_line = plane.GetOrCreateLine(0); - op_line.SetName(tsl::profiler::kXlaOpLineName); - const XStatMetadata& program_id_stat = - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)); - const XStatMetadata& symbol_id_stat = - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)); - const XStatMetadata& group_id_stat = - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId)); - { - XEventMetadata* event_metadata = - plane.GetOrCreateEventMetadata("op_long_name"); - event_metadata->set_display_name("op_name"); - XStatsBuilder stats(event_metadata, &plane); - stats.AddStatValue(program_id_stat, 1); - stats.AddStatValue(symbol_id_stat, 1); - { - XEventBuilder event = op_line.AddEvent(*event_metadata); - event.SetOffsetPs(0); - event.SetDurationPs(50); - event.AddStatValue(group_id_stat, 1); - } - { - XEventBuilder event = op_line.AddEvent(*event_metadata); - event.SetOffsetPs(100); - event.SetDurationPs(50); - event.AddStatValue(group_id_stat, 2); - } - } - { - XEventMetadata* event_metadata = - plane.GetOrCreateEventMetadata("op_long_name2"); - event_metadata->set_display_name("op_name2"); - XStatsBuilder stats(event_metadata, &plane); - stats.AddStatValue(program_id_stat, 1); - stats.AddStatValue(symbol_id_stat, 2); - XEventBuilder event = op_line.AddEvent(*event_metadata); - event.SetOffsetPs(50); - event.SetDurationPs(50); - event.AddStatValue(group_id_stat, 1); - } - XLineBuilder step_line = plane.GetOrCreateLine(1); - step_line.SetName(tsl::profiler::kStepLineName); - { - XEventMetadata* event_metadata = plane.CreateEventMetadata(); - XStatsBuilder stats(event_metadata, &plane); - { - XEventBuilder event = step_line.AddEvent(*event_metadata); - event.SetOffsetPs(0); - event.SetDurationPs(100); - event.AddStatValue(group_id_stat, 1); - } - { - XEventBuilder event = step_line.AddEvent(*event_metadata); - event.SetOffsetPs(100); - event.SetDurationPs(100); - event.AddStatValue(group_id_stat, 2); - } - } - - StepEvents step_events = ConvertDeviceTraceXPlaneToStepEvents(raw_plane); - EXPECT_EQ(step_events.size(), 2); - EXPECT_TRUE(step_events.contains(1)); - StepDetails step_1 = step_events[/*group_id=*/1]; - ASSERT_TRUE(step_1.PerCoreOpMetricsDb().contains(device_id)); - EXPECT_EQ(step_1.PerCoreOpMetricsDb().at(device_id).metrics_db_size(), 2); - EXPECT_EQ(step_1.Markers().size(), 1); - EXPECT_TRUE(step_events.contains(2)); - StepDetails 
step_2 = step_events[/*group_id=*/2]; - ASSERT_TRUE(step_2.PerCoreOpMetricsDb().contains(device_id)); - EXPECT_EQ(step_2.PerCoreOpMetricsDb().at(device_id).metrics_db_size(), 1); - EXPECT_EQ(step_2.Markers().size(), 1); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc deleted file mode 100644 index 1f0009c44ed4eb..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc +++ /dev/null @@ -1,525 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h" - -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/tf_data_stats.pb.h" // from @org_xprof -#include "xprof/utils/html_utils.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -// 50 us from https://www.tensorflow.org/guide/data_performance_analysis -const int64_t kSlowCallThresholdPs = 50 * 1000000; - -namespace { - -// Returns true if the given iterator event is for a root iterator. -bool IsRootIteratorEvent(const XEventVisitor& iterator_event) { - std::vector split_result = - absl::StrSplit(iterator_event.Name(), "::"); - // The root iterator's name contains only its own name (no parent - // information). - return split_result.size() == 2; -} - -// Returns true if the given iterator event name is for an async iterator. 
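// For example (illustrative, derived from the set below): IsAsyncIterator is
// called with the short IteratorName, so an event named "Iterator::Prefetch"
// yields "Prefetch" and is marked is_async, while "Iterator::Map" yields
// "Map" (only "ParallelMap" is in the async set) and is not.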
-bool IsAsyncIterator(absl::string_view iterator_event_name) { - static auto* kAsyncIterators = new absl::flat_hash_set( - {"Prefetch", "ParallelInterleave", "ParallelMap", "ParseExample", - "MapAndBatch", "DataService", "LegacyParallelInterleave", - "ParallelBatch"}); - return kAsyncIterators->contains(iterator_event_name); -} - -void SetIteratorMetadata(int64_t id, const XEventVisitor& event, - IteratorMetadata* metadata) { - metadata->set_id(id); - auto parent_id_stat = event.GetStat(StatType::kParentId); - if (parent_id_stat.has_value()) { - metadata->set_parent_id(parent_id_stat->IntValue()); - } - metadata->set_name(tsl::profiler::IteratorName(event.Name())); - metadata->set_long_name(event.Name().data(), event.Name().size()); - metadata->set_is_async(IsAsyncIterator(metadata->name())); - // TODO(b/161831651): Set params. -} - -// Returns the parent iterator's id if it is a root of a device input -// pipeline. -std::optional FindDeviceInputPipeline(const XEventVisitor& event) { - if (event.Type() == HostEventType::kDeviceInputPipelineSecondIterator) { - auto parent_id_stat = event.GetStat(StatType::kParentId); - if (parent_id_stat.has_value()) return parent_id_stat->IntValue(); - } - return std::nullopt; -} - -// Processes tsl::profiler::EventForest to do the following: -// (1) set iterator metadata -// (2) find root iterator events -// (3) find device input pipeline ids -void ProcessEventForest( - const tsl::profiler::EventForest& event_forest, - absl::flat_hash_set* device_input_pipeline_ids, - absl::flat_hash_map>* - root_iterator_event_map, - TfDataStats* tf_data_stats) { - const tsl::profiler::EventNodeMap& event_node_map = - event_forest.GetEventNodeMap(); - auto* iterator_event_list = - gtl::FindOrNull(event_node_map, HostEventType::kIterator); - if (!iterator_event_list) return; - for (const tsl::profiler::EventNode& iterator_event : *iterator_event_list) { - const XEventVisitor& iterator_event_visitor = - iterator_event.GetEventVisitor(); - auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId); - if (!iterator_id_stat.has_value()) continue; - int64_t iterator_id = iterator_id_stat->IntValue(); - auto result = tf_data_stats->mutable_iterator_metadata()->insert( - {iterator_id, IteratorMetadata()}); - IteratorMetadata& metadata = result.first->second; - if (result.second) { - // First time processing this iterator. - SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata); - } - if (IsRootIteratorEvent(iterator_event_visitor)) { - // Record root iterator events. - (*root_iterator_event_map)[iterator_id].push_back(&iterator_event); - } - } - auto* device_input_pipeline_second_iterator_events = gtl::FindOrNull( - event_node_map, HostEventType::kDeviceInputPipelineSecondIterator); - if (!device_input_pipeline_second_iterator_events) return; - for (const tsl::profiler::EventNode& iterator_event : - *device_input_pipeline_second_iterator_events) { - const XEventVisitor& iterator_event_visitor = - iterator_event.GetEventVisitor(); - auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId); - if (!iterator_id_stat.has_value()) continue; - int64_t iterator_id = iterator_id_stat->IntValue(); - auto result = tf_data_stats->mutable_iterator_metadata()->insert( - {iterator_id, IteratorMetadata()}); - IteratorMetadata& metadata = result.first->second; - if (result.second) { - // First time processing this iterator. - SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata); - // Find and record device input pipeline ids. 
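// (FindDeviceInputPipeline, defined above, returns the parent id of a
// DeviceInputPipelineSecondIterator event; that parent is the root of a
// device-side pipeline, which SetBottleneckAnalysis later skips when looking
// for bottlenecks.)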
- std::optional device_input_pipeline_id = - FindDeviceInputPipeline(iterator_event_visitor); - if (device_input_pipeline_id.has_value()) { - device_input_pipeline_ids->insert(*device_input_pipeline_id); - } - } - } -} - -void SetInputPipelineMetadata(int64_t id, int64_t name_id, - bool is_device_input_pipeline, - InputPipelineMetadata* metadata) { - constexpr absl::string_view kHostInputPipelinePrefix = "Host:"; - constexpr absl::string_view kDeviceInputPipelinePrefix = "Device:"; - metadata->set_id(id); - if (is_device_input_pipeline) { - metadata->set_type(InputPipelineMetadata::DEVICE); - metadata->set_name(absl::StrCat(kDeviceInputPipelinePrefix, name_id)); - } else { - metadata->set_type(InputPipelineMetadata::HOST); - metadata->set_name(absl::StrCat(kHostInputPipelinePrefix, name_id)); - } -} - -void ProcessIteratorEvent(const tsl::profiler::EventNode& iterator_event, - InputPipelineStat* input_pipeline_stat, - bool is_blocking, int level = 0) { - if (level > 100) return; - const XEventVisitor& visitor = iterator_event.GetEventVisitor(); - auto iterator_id_stat = visitor.GetStat(StatType::kStepId); - if (!iterator_id_stat.has_value()) return; - int64_t iterator_id = iterator_id_stat->IntValue(); - auto result = input_pipeline_stat->mutable_iterator_stats()->insert( - {iterator_id, IteratorStat()}); - IteratorStat& iterator_stat = result.first->second; - if (result.second) { - iterator_stat.set_id(iterator_id); - iterator_stat.set_start_time_ps(visitor.TimestampPs()); - } - iterator_stat.set_duration_ps(iterator_stat.duration_ps() + - visitor.DurationPs()); - int64_t self_time_ps = visitor.DurationPs(); - tsl::profiler::Timespan self_time_span = visitor.GetTimespan(); - for (const tsl::profiler::EventNode* child : iterator_event.GetChildren()) { - const XEventVisitor& child_visitor = child->GetEventVisitor(); - if (tsl::profiler::ParseTfOpFullname(child_visitor.Name()).category == - tsl::profiler::Category::kTfData) { - int64_t overlap_duration_ps = - self_time_span.OverlappedDurationPs(child_visitor.GetTimespan()); - ProcessIteratorEvent(*child, input_pipeline_stat, - is_blocking && overlap_duration_ps, level + 1); - // Note: Assume no overlap between child events. 
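// Worked example (numbers from the HostInputPipeline test below): a Prefetch
// call with duration_ps 100000000 whose Range child overlaps 80000000 ps of
// it keeps 100000000 - 80000000 = 20000000 ps of self time, and the child
// recurses with is_blocking still true because the overlap is nonzero.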
- self_time_ps -= overlap_duration_ps; - } - } - iterator_stat.set_self_time_ps(iterator_stat.self_time_ps() + self_time_ps); - iterator_stat.set_is_blocking(iterator_stat.is_blocking() || is_blocking); - iterator_stat.set_num_calls(iterator_stat.num_calls() + 1); -} - -void SetBottleneckIteratorId(InputPipelineStat* input_pipeline_stat) { - int64_t bottleneck_iterator_id = 0; - int64_t max_self_time = 0; - for (const auto& pair : input_pipeline_stat->iterator_stats()) { - const auto& id = pair.first; - const auto& iterator_stat = pair.second; - if (iterator_stat.is_blocking() && - iterator_stat.self_time_ps() > max_self_time) { - bottleneck_iterator_id = id; - max_self_time = iterator_stat.self_time_ps(); - } - } - input_pipeline_stat->set_bottleneck_iterator_id(bottleneck_iterator_id); - input_pipeline_stat->set_bottleneck_iterator_latency_ps(max_self_time); -} - -void ProcessInputPipelines( - const absl::flat_hash_set& device_input_pipeline_ids, - absl::flat_hash_map>* - root_iterator_event_map, - TfDataStats* tf_data_stats) { - auto* input_pipelines = tf_data_stats->mutable_input_pipelines(); - int64_t num_host_input_pipelines = 0; - int64_t num_device_input_pipelines = 0; - for (auto& id_and_events : *root_iterator_event_map) { - auto& root_iterator_id = id_and_events.first; - auto& root_iterator_events = id_and_events.second; - absl::c_sort(root_iterator_events, [](const tsl::profiler::EventNode* lhs, - const tsl::profiler::EventNode* rhs) { - return lhs->GetEventVisitor().DurationPs() > - rhs->GetEventVisitor().DurationPs(); - }); - auto result = - input_pipelines->insert({root_iterator_id, InputPipelineStats()}); - InputPipelineStats& input_pipeline_stats = result.first->second; - InputPipelineMetadata* metadata = input_pipeline_stats.mutable_metadata(); - if (result.second) { - bool is_device_input_pipeline = - device_input_pipeline_ids.contains(root_iterator_id); - int64_t name_id = is_device_input_pipeline ? 
num_device_input_pipelines++ - : num_host_input_pipelines++; - SetInputPipelineMetadata(root_iterator_id, name_id, - is_device_input_pipeline, metadata); - } - int64_t sum_latency_ps = 0; - int64_t min_latency_ps = INT64_MAX; - int64_t max_latency_ps = 0; - int64_t num_slow_calls = 0; - for (const tsl::profiler::EventNode* root_iterator_event : - root_iterator_events) { - InputPipelineStat* stat = input_pipeline_stats.add_stats(); - ProcessIteratorEvent(*root_iterator_event, stat, - /*is_blocking*/ true); - SetBottleneckIteratorId(stat); - int64_t latency_ps = root_iterator_event->GetEventVisitor().DurationPs(); - sum_latency_ps += latency_ps; - min_latency_ps = std::min(min_latency_ps, latency_ps); - max_latency_ps = std::max(max_latency_ps, latency_ps); - if (latency_ps > kSlowCallThresholdPs) num_slow_calls++; - } - input_pipeline_stats.set_avg_latency_ps(sum_latency_ps / - root_iterator_events.size()); - input_pipeline_stats.set_min_latency_ps(min_latency_ps); - input_pipeline_stats.set_max_latency_ps(max_latency_ps); - input_pipeline_stats.set_num_slow_calls(num_slow_calls); - } -} - -void SetBottleneckAnalysis(CombinedTfDataStats* combined_tf_data_stats) { - struct InputPipeline { - InputPipeline(absl::string_view host_name, - absl::string_view input_pipeline_name, int64_t max_latency_ps, - absl::string_view iterator_name, - absl::string_view iterator_long_name, - int64_t iterator_latency_ps) - : host_name(host_name), - input_pipeline_name(input_pipeline_name), - max_latency_ps(max_latency_ps), - iterator_name(iterator_name), - iterator_long_name(iterator_long_name), - iterator_latency_ps(iterator_latency_ps) {} - absl::string_view host_name; - absl::string_view input_pipeline_name; - int64_t max_latency_ps; - absl::string_view iterator_name; - absl::string_view iterator_long_name; - int64_t iterator_latency_ps; - - bool operator<(const InputPipeline& rhs) const { - return max_latency_ps > rhs.max_latency_ps; - } - }; - std::vector slow_input_pipelines; - for (const auto& host_name_and_tf_data_stats : - combined_tf_data_stats->tf_data_stats()) { - absl::string_view host_name = host_name_and_tf_data_stats.first; - const TfDataStats& tf_data_stats = host_name_and_tf_data_stats.second; - for (const auto& id_and_stats : tf_data_stats.input_pipelines()) { - const InputPipelineStats& input_pipeline_stats = id_and_stats.second; - if (input_pipeline_stats.metadata().type() == - InputPipelineMetadata::DEVICE) { - // Ignore device input pipelines. - continue; - } - // Choose the slowest execution trace of the input pipeline. - // `input_pipeline_stats.stats` is already sorted so choose the first one. 
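// ("Sorted" refers to the absl::c_sort above, which ordered the root iterator
// events by descending duration before their stats were appended, so stats(0)
// is the slowest recorded call of this input pipeline.)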
- const InputPipelineStat& input_pipeline_stat = - input_pipeline_stats.stats(0); - const IteratorMetadata& metadata = tf_data_stats.iterator_metadata().at( - input_pipeline_stat.bottleneck_iterator_id()); - slow_input_pipelines.emplace_back( - host_name, input_pipeline_stats.metadata().name(), - input_pipeline_stats.max_latency_ps(), metadata.name(), - metadata.long_name(), - input_pipeline_stat.bottleneck_iterator_latency_ps()); - } - } - std::sort(slow_input_pipelines.begin(), slow_input_pipelines.end()); - for (const auto& input_pipeline : slow_input_pipelines) { - TfDataBottleneckAnalysis* bottleneck_analysis = - combined_tf_data_stats->add_bottleneck_analysis(); - bottleneck_analysis->set_host(input_pipeline.host_name.data(), - input_pipeline.host_name.size()); - bottleneck_analysis->set_input_pipeline( - input_pipeline.input_pipeline_name.data(), - input_pipeline.input_pipeline_name.size()); - bottleneck_analysis->set_max_latency_ps(input_pipeline.max_latency_ps); - bottleneck_analysis->set_iterator_name(input_pipeline.iterator_name.data(), - input_pipeline.iterator_name.size()); - bottleneck_analysis->set_iterator_long_name( - input_pipeline.iterator_long_name.data(), - input_pipeline.iterator_long_name.size()); - bottleneck_analysis->set_iterator_latency_ps( - input_pipeline.iterator_latency_ps); - } -} - -std::string GetSuggestion(BottleneckType type) { - constexpr absl::string_view kPlaybookLink = - "https://www.tensorflow.org/guide/data_performance_analysis"; - constexpr absl::string_view kPlaybookSourceDatasetLink = - "https://www.tensorflow.org/guide/" - "data_performance_analysis#source_datasets"; - constexpr absl::string_view kPlaybookCpuUtilizationLink = - "https://www.tensorflow.org/guide/" - "data_performance_analysis#3_are_you_reaching_high_cpu_utilization"; - constexpr absl::string_view kPlaybookTransformationLink = - "https://www.tensorflow.org/guide/" - "data_performance_analysis#transformation_datasets"; - constexpr absl::string_view kTfGuideParallelDataExtractionLink = - "https://www.tensorflow.org/guide/" - "data_performance#parallelizing_data_extraction"; - constexpr absl::string_view kTfGuideParallelTransformationLink = - "https://www.tensorflow.org/guide/" - "data_performance#parallelizing_data_transformation"; - constexpr absl::string_view kTfGuideCacheLink = - "https://www.tensorflow.org/guide/data_performance#caching"; - constexpr absl::string_view kTfDataServiceLink = - "https://www.tensorflow.org/api_docs/python/tf/data/experimental/" - "service?version=nightly"; - switch (type) { - case BottleneckType::kSlowSource: - return absl::StrFormat( - "1. Check the locality of a host and input data. Ideally, they " - "should be in the same cell (or very close, like the same " - "region).
" - "2. Parallelize reading from this dataset source. See %s and %s for " - "more details.
", - AnchorElement(kPlaybookSourceDatasetLink, "here"), - AnchorElement(kTfGuideParallelDataExtractionLink, "here")); - case BottleneckType::kSlowDataService: - return absl::StrFormat( - "1. Fetching data from tf.data service took a while. Profile the " - "tf.data service worker to analyze the issue further.
" - "2. See %s for more details on tf.data service.
" - "3. See %s for other suggestions.", - AnchorElement(kTfDataServiceLink, "this"), - AnchorElement(kPlaybookLink, "this")); - case BottleneckType::kSlowRemoteSource: - return absl::StrFormat( - "1. The remote data source is slow. Profile its host to analyze the " - "issue further.
" - "2. See %s for other suggestions.", - AnchorElement(kPlaybookLink, "this")); - case BottleneckType::kSlowTransformationWithParallelVersion: - return absl::StrFormat( - "1. Parallelize this transformation by setting " - "num_parallel_calls=tf.data.experimental.AUTOTUNE. See " - "%s for more details.
" - "2. Consider adding cache after this transformation if " - "your data fits into memory and it is appropriate (e.g., there is no " - "randomness in upstream transformations like shuffle). " - "See %s for more details.
" - "3. Find more resources %s.", - AnchorElement(kTfGuideParallelTransformationLink, "this"), - AnchorElement(kTfGuideCacheLink, "this"), - AnchorElement(kPlaybookTransformationLink, "here")); - case BottleneckType::kSlowTransformationWithoutParallelVersion: - return absl::StrFormat( - "1. This transformation is inherently sequential. Add outer " - "parallelism by running multiple copies of the input pipeline over " - "sharded inputs and combining the results. See %s for more " - "details.
" - "2. Consider adding cache after this transformation if " - "your data fits into memory and it is appropriate (e.g., there is no " - "randomness in upstream transformations like shuffle). " - "See %s for more details.
" - "3. Find more resources %s.", - AnchorElement(kPlaybookTransformationLink, "this"), - AnchorElement(kTfGuideCacheLink, "this"), - AnchorElement(kPlaybookCpuUtilizationLink, "here")); - default: - return absl::StrFormat("See %s for suggestions.", - AnchorElement(kPlaybookLink, "this")); - } -} - -void SetSuggestion(CombinedTfDataStats* combined_tf_data_stats) { - for (TfDataBottleneckAnalysis& bottleneck_analysis : - *combined_tf_data_stats->mutable_bottleneck_analysis()) { - bottleneck_analysis.set_suggestion( - GetSuggestion(GetBottleneckType(bottleneck_analysis.iterator_name()))); - } -} - -void SetSummary(CombinedTfDataStats* combined_tf_data_stats) { - int64_t max_latency_ps = 0; - if (combined_tf_data_stats->bottleneck_analysis_size()) { - max_latency_ps = - combined_tf_data_stats->bottleneck_analysis().at(0).max_latency_ps(); - } - if (max_latency_ps > kSlowCallThresholdPs) { - combined_tf_data_stats->set_is_input_bound(true); - combined_tf_data_stats->set_summary( - "Your profile has a tf.data input pipeline slower than 50 us. For each " - "slow input pipeline, below shows a bottleneck in the input pipeline " - "and a suggestion on how to fix it."); - } else if (max_latency_ps > 0) { - combined_tf_data_stats->set_is_input_bound(false); - combined_tf_data_stats->set_summary( - "Your profile does not have any tf.data input pipeline slower than 50 " - "us. Your job could be still input bound if this profile didn't " - "capture all workers."); - } else { - combined_tf_data_stats->set_is_input_bound(false); - combined_tf_data_stats->set_summary( - "No tf.data activity captured in your profile. If your job uses " - "tf.data, try to capture a longer profile."); - } -} - -} // namespace - -BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name) { - static auto* kBottleneckTypeMap = new absl::flat_hash_map( - {// Read from storage. - {"TFRecord", BottleneckType::kSlowSource}, - {"SSTable", BottleneckType::kSlowSource}, - {"RecordIO", BottleneckType::kSlowSource}, - {"Spanner", BottleneckType::kSlowSource}, - {"TFColumn", BottleneckType::kSlowSource}, - {"SleepwalkRemoteDataset", BottleneckType::kSlowSource}, - {"TextLine", BottleneckType::kSlowSource}, - {"StitchedTimelineDataset", BottleneckType::kSlowSource}, - {"DateKeyDataset", BottleneckType::kSlowSource}, - {"CapacitorProto", BottleneckType::kSlowSource}, - {"LMDB", BottleneckType::kSlowSource}, - {"ExternalDataset", BottleneckType::kSlowSource}, - {"PearModel", BottleneckType::kSlowSource}, - {"FixedLengthRecordV2", BottleneckType::kSlowSource}, - // Read from local memory. - {"FromTensor", BottleneckType::kSlowSource}, - {"TensorSlice", BottleneckType::kSlowSource}, - {"Generator", BottleneckType::kSlowSource}, - {"SyntheticDatasetOp", BottleneckType::kSlowSource}, - // tf.data service. - {"DataService", BottleneckType::kSlowDataService}, - // Read from remote memory. - {"GuzzlerDataGuzzlerRemoteDataset", BottleneckType::kSlowRemoteSource}, - {"ReverbDataset", BottleneckType::kSlowRemoteSource}, - {"DatasetSampleGame", BottleneckType::kSlowRemoteSource}, - {"Courier", BottleneckType::kSlowRemoteSource}, - {"ReverbEpisodeDataset", BottleneckType::kSlowRemoteSource}, - // Transformations with parallel version. - {"Map", BottleneckType::kSlowTransformationWithParallelVersion}, - {"Interleave", BottleneckType::kSlowTransformationWithParallelVersion}, - // Transformations without parallel version. 
- {"Filter", BottleneckType::kSlowTransformationWithoutParallelVersion}, - {"Batch", BottleneckType::kSlowTransformationWithoutParallelVersion}, - {"Unbatch", BottleneckType::kSlowTransformationWithoutParallelVersion}}); - if (auto type = - gtl::FindOrNull(*kBottleneckTypeMap, bottleneck_iterator_name)) { - return *type; - } - return BottleneckType::kOther; -} - -void CombinedTfDataStatsBuilder::Add(absl::string_view host_name, - XPlane* host_plane) { - TfDataStats& tf_data_stats = - (*combined_tf_data_stats_ - ->mutable_tf_data_stats())[std::string(host_name)]; - tsl::profiler::EventForest event_forest; - event_forest.AddPlanes(tsl::profiler::CreateTfXPlaneVisitor, {host_plane}); - event_forest.ConnectEvents(); - event_forest.ConnectTfDataEvents(); - absl::flat_hash_set device_input_pipeline_ids; - absl::flat_hash_map> - root_iterator_event_map; - ProcessEventForest(event_forest, &device_input_pipeline_ids, - &root_iterator_event_map, &tf_data_stats); - ProcessInputPipelines(device_input_pipeline_ids, &root_iterator_event_map, - &tf_data_stats); -} - -void CombinedTfDataStatsBuilder::Finalize() { - SetBottleneckAnalysis(combined_tf_data_stats_); - if (generate_suggestion_) SetSuggestion(combined_tf_data_stats_); - SetSummary(combined_tf_data_stats_); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h deleted file mode 100644 index 1727dcadefa7dd..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ - -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/tf_data_stats.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -TF_CONST_INIT extern const int64_t kSlowCallThresholdPs; - -enum class BottleneckType { - kSlowSource, - kSlowDataService, - kSlowRemoteSource, - kSlowTransformationWithParallelVersion, - kSlowTransformationWithoutParallelVersion, - kOther, -}; - -BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name); - -class CombinedTfDataStatsBuilder { - public: - explicit CombinedTfDataStatsBuilder( - CombinedTfDataStats* combined_tf_data_stats, - bool generate_suggestion = true) - : combined_tf_data_stats_(combined_tf_data_stats), - generate_suggestion_(generate_suggestion) {} - - void Add(absl::string_view host_name, XPlane* host_plane); - - // Finalizes by populating TfDataBottleneckAnalysis. 
- void Finalize(); - - private: - CombinedTfDataStats* combined_tf_data_stats_; - bool generate_suggestion_; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc deleted file mode 100644 index 6a7eb75194c73e..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc +++ /dev/null @@ -1,420 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h" - -#include - -#include -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/tf_data_stats.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using ::testing::EqualsProto; - -// Test with the following example dataset: -// dataset = tf.data.Dataset.range(8) -// dataset = dataset.prefetch(2) -// for _ in dataset: -// pass -TEST(XPlaneToTfDataStatsTest, HostInputPipeline) { - constexpr int64_t kPrefetchIteratorId = 123; - constexpr int64_t kRangeIteratorId = 456; - constexpr int64_t kFirstElementId = 100; - constexpr int64_t kSecondElementId = 200; - - XPlane host_plane; - XPlaneBuilder host_plane_builder(&host_plane); - host_plane_builder.ReserveLines(2); - - auto consumer_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 0, - 100000000, {{StatType::kStepId, kPrefetchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kPrefetchConsume, 80000000, 20000000, - {{StatType::kElementId, kFirstElementId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", - 200000000, 20000000, {{StatType::kStepId, kPrefetchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kPrefetchConsume, 210000000, 10000000, - {{StatType::kElementId, kSecondElementId}}); - - auto producer_thread = host_plane_builder.GetOrCreateLine(1); - // Blocking producer. - CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kPrefetchProduce, 0, 80000000, - {{StatType::kElementId, kFirstElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::Prefetch::Range", 0, 80000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kPrefetchIteratorId}}); - // Non-blocking producer. 
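// (Illustrative reading of the timestamps below: the element is produced
// during [100000000, 180000000] ps but only consumed from 210000000 ps, so
// the producer ran ahead of the consumer and this Range invocation shows up
// with is_blocking: false in the expected proto.)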
- CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kPrefetchProduce, 100000000, 80000000, - {{StatType::kElementId, kSecondElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::Prefetch::Range", 100000000, 80000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kPrefetchIteratorId}}); - - CombinedTfDataStats combined_tf_data_stats; - CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); - builder.Add("host1", &host_plane); - builder.Finalize(); - EXPECT_THAT( - combined_tf_data_stats, EqualsProto(R"pb( - bottleneck_analysis: { - host: "host1" - input_pipeline: "Host:0" - max_latency_ps: 100000000 - iterator_name: "Range" - iterator_long_name: "Iterator::Prefetch::Range" - iterator_latency_ps: 80000000 - suggestion: "See this for suggestions." - } - tf_data_stats: { - key: "host1" - value: { - iterator_metadata: { - key: 123, - value: { - id: 123 - name: "Prefetch" - long_name: "Iterator::Prefetch" - is_async: true - } - } - iterator_metadata: { - key: 456, - value: { - id: 456 - parent_id: 123 - name: "Range" - long_name: "Iterator::Prefetch::Range" - is_async: false - } - } - input_pipelines { - key: 123, - value: { - metadata { id: 123 type: HOST name: "Host:0" } - avg_latency_ps: 60000000 - min_latency_ps: 20000000 - max_latency_ps: 100000000 - num_slow_calls: 1 - stats { - bottleneck_iterator_id: 456 - bottleneck_iterator_latency_ps: 80000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 0 - duration_ps: 100000000 - self_time_ps: 20000000 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 0 - duration_ps: 80000000 - self_time_ps: 80000000 - is_blocking: true - num_calls: 1 - } - } - } - stats { - bottleneck_iterator_id: 123 - bottleneck_iterator_latency_ps: 20000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 200000000 - duration_ps: 20000000 - self_time_ps: 20000000 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 100000000 - duration_ps: 80000000 - self_time_ps: 80000000 - is_blocking: false - num_calls: 1 - } - } - } - } - } - } - } - is_input_bound: true - summary: "Your profile has a tf.data input pipeline slower than 50 us. For each slow input pipeline, below shows a bottleneck in the input pipeline and a suggestion on how to fix it." 
- )pb")); -} - -TEST(XPlaneToTfDataStatsTest, DeviceInputPipeline) { - constexpr int64_t kPrefetchIteratorId = 123; - constexpr int64_t kRangeIteratorId = 456; - constexpr int64_t kElementId = 100; - - XPlane host_plane; - XPlaneBuilder host_plane_builder(&host_plane); - host_plane_builder.ReserveLines(2); - - auto consumer_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 0, - 30000000, {{StatType::kStepId, kPrefetchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", - 100000000, 100000000, - {{StatType::kStepId, kPrefetchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kPrefetchConsume, 180000000, 20000000, - {{StatType::kElementId, kElementId}}); - - auto producer_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kPrefetchProduce, 100000000, 80000000, - {{StatType::kElementId, kElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::Prefetch::Generator", 100000000, 80000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kPrefetchIteratorId}}); - - CombinedTfDataStats combined_tf_data_stats; - CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); - builder.Add("host1", &host_plane); - builder.Finalize(); - // Device input pipeline is not considered for bottleneck analysis. - EXPECT_THAT( - combined_tf_data_stats, EqualsProto(R"pb( - tf_data_stats: { - key: "host1" - value: { - iterator_metadata: { - key: 123, - value: { - id: 123 - name: "Prefetch" - long_name: "Iterator::Prefetch" - is_async: true - } - } - iterator_metadata: { - key: 456, - value: { - id: 456 - parent_id: 123 - name: "Generator" - long_name: "Iterator::Prefetch::Generator" - is_async: false - } - } - input_pipelines { - key: 123, - value: { - metadata { id: 123 type: DEVICE name: "Device:0" } - avg_latency_ps: 65000000 - min_latency_ps: 30000000 - max_latency_ps: 100000000 - num_slow_calls: 1 - stats { - bottleneck_iterator_id: 456 - bottleneck_iterator_latency_ps: 80000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 100000000 - duration_ps: 100000000 - self_time_ps: 20000000 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 100000000 - duration_ps: 80000000 - self_time_ps: 80000000 - is_blocking: true - num_calls: 1 - } - } - } - stats { - bottleneck_iterator_id: 123 - bottleneck_iterator_latency_ps: 30000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 0 - duration_ps: 30000000 - self_time_ps: 30000000 - is_blocking: true - num_calls: 1 - } - } - } - } - } - } - } - summary: "No tf.data activity captured in your profile. If your job uses tf.data, try to capture a longer profile." 
- )pb")); -} - -// Test with the following example dataset: -// dataset = tf.data.Dataset.range(8) -// dataset = dataset.map(lambda x: x + 1) -// dataset = dataset.batch(2) -// for _ in dataset: -// pass -TEST(XPlaneToTfDataStatsTest, MapAndBatch) { - constexpr int64_t kMapAndBatchIteratorId = 123; - constexpr int64_t kRangeIteratorId = 456; - constexpr int64_t kElementId = 100; - - XPlane host_plane; - XPlaneBuilder host_plane_builder(&host_plane); - host_plane_builder.ReserveLines(2); - - XLineBuilder consumer_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::MapAndBatch", - 0, 100000000, {{StatType::kStepId, kMapAndBatchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kMapAndBatchConsume, 80000000, 20000000, - {{StatType::kElementId, kElementId}}); - - XLineBuilder producer_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kMapAndBatchProduce, 0, 30000000, - {{StatType::kElementId, kElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::MapAndBatch::Range", 0, 30000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kMapAndBatchIteratorId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kMapAndBatchProduce, 40000000, 30000000, - {{StatType::kElementId, kElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::MapAndBatch::Range", 40000000, 30000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kMapAndBatchIteratorId}}); - - CombinedTfDataStats combined_tf_data_stats; - CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); - builder.Add("host1", &host_plane); - builder.Finalize(); - EXPECT_THAT( - combined_tf_data_stats, EqualsProto(R"pb( - bottleneck_analysis: { - host: "host1" - input_pipeline: "Host:0" - max_latency_ps: 100000000 - iterator_name: "Range" - iterator_long_name: "Iterator::MapAndBatch::Range" - iterator_latency_ps: 60000000 - suggestion: "See this for suggestions." - } - tf_data_stats: { - key: "host1" - value: { - iterator_metadata: { - key: 123, - value: { - id: 123 - name: "MapAndBatch" - long_name: "Iterator::MapAndBatch" - is_async: true - } - } - iterator_metadata: { - key: 456, - value: { - id: 456 - parent_id: 123 - name: "Range" - long_name: "Iterator::MapAndBatch::Range" - is_async: false - } - } - input_pipelines { - key: 123, - value: { - metadata { id: 123 type: HOST name: "Host:0" } - avg_latency_ps: 100000000 - min_latency_ps: 100000000 - max_latency_ps: 100000000 - num_slow_calls: 1 - stats { - bottleneck_iterator_id: 456 - bottleneck_iterator_latency_ps: 60000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 0 - duration_ps: 100000000 - self_time_ps: 40000000 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 0 - duration_ps: 60000000 - self_time_ps: 60000000 - is_blocking: true - num_calls: 2 - } - } - } - } - } - } - } - is_input_bound: true - summary: "Your profile has a tf.data input pipeline slower than 50 us. For each slow input pipeline, below shows a bottleneck in the input pipeline and a suggestion on how to fix it." 
- )pb")); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc deleted file mode 100644 index fc44baf39d8406..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc +++ /dev/null @@ -1,306 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -You may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/platform/protobuf.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -namespace { - -std::pair Decode( - absl::string_view function_name, absl::string_view mode) { - // mode is one of ["eager", "concrete", "traced-xla", "traced-nonXla", - // "notTraced-xla", "notTraced-nonXla"] - if (mode == "eager") return {EAGER_MODE, INVALID_COMPILER}; - if (mode == "concrete") return {CONCRETE_MODE, INVALID_COMPILER}; - if (mode == "traced-xla") return {TRACED_MODE, XLA_COMPILER}; - if (mode == "traced-nonXla") return {TRACED_MODE, OTHER_COMPILER}; - if (mode == "notTraced-xla") return {NOT_TRACED_MODE, XLA_COMPILER}; - if (mode == "notTraced-nonXla") return {NOT_TRACED_MODE, OTHER_COMPILER}; - // Shouldn't reach here. - LOG(ERROR) << absl::StrCat("tf-function '", function_name, - "' has an unexpected execution mode '", mode, "'") - << std::endl; - return {INVALID_MODE, INVALID_COMPILER}; - DCHECK(false); -} - -double ComputeExpensiveCallPercent(const TfFunction& tf_function) { - // Computes the expensiveness in terms of time (rather than count). - uint64 total_call_time_ps = 0; - uint64 expensive_call_time_ps = 0; - for (const auto& mode_metrics : tf_function.metrics()) { - const auto mode = mode_metrics.first; - const auto& metrics = mode_metrics.second; - total_call_time_ps += metrics.self_time_ps(); - if (mode == TRACED_MODE || mode == EAGER_MODE) { - expensive_call_time_ps += metrics.self_time_ps(); - } - } - return tsl::profiler::SafeDivide(100.0 * expensive_call_time_ps, - total_call_time_ps); -} - -// Each invocation of a tf-function creates an ActivationRecord. -struct ActivationRecord { - std::string function_name; // name of the tf-function. 
- tsl::profiler::Timespan timespan; // timespan of this invocation. - TfFunctionExecutionMode execution_mode; // execution mode. - TfFunctionCompiler compiler; // compiler used. - int64_t tracing_count; // the total tracing count of this function when this - // invocation happened. - uint64 children_duration_ps; // Sum of the duration of all (immediate) - // children tf-functions of this function. - ActivationRecord() - : function_name(""), - execution_mode(INVALID_MODE), - compiler(INVALID_COMPILER), - tracing_count(0), - children_duration_ps(0) {} - ActivationRecord(absl::string_view name, - const tsl::profiler::Timespan& timespan, - TfFunctionExecutionMode exe_mode, - TfFunctionCompiler compiler, int64_t tracing_cnt) - : function_name(std::string(name)), - timespan(timespan), - execution_mode(exe_mode), - compiler(compiler), - tracing_count(tracing_cnt), - children_duration_ps(0) {} - std::string DebugString() const { - return absl::StrCat("{", function_name, ", ", - TfFunctionExecutionMode_Name(execution_mode), ", ", - TfFunctionCompiler_Name(compiler), - ", tracing_count:", tracing_count, - ", children_duration:", children_duration_ps, - " ps, timespan:", timespan.DebugString(), "}"); - } -}; - -// Entry or exit point of a tf-function. -struct EntryOrExit { - bool is_entry; // true for entry, false for exit. - int64_t index; // index to the ActivationRecord. - uint64 timestamp_ps; // the time when this entry/exit happens. - EntryOrExit() : is_entry(false), index(-1), timestamp_ps(0) {} - EntryOrExit(bool is_entry, int64_t index, uint64 timestamp_ps) - : is_entry(is_entry), index(index), timestamp_ps(timestamp_ps) {} - std::string DebugString() const { - std::string entry_or_exit = is_entry ? "entry, " : "exit, "; - return absl::StrCat("{", entry_or_exit, "idx:", index, - ", timestamp:", timestamp_ps, "}"); - } -}; - -TfFunctionCompiler CombineCompilers(TfFunctionCompiler a, - TfFunctionCompiler b) { - if (a == INVALID_COMPILER) return b; - if (b == INVALID_COMPILER) return a; - if (a == b) return a; - return MIXED_COMPILER; -} - -void CombineTfFunctionMetrics(const TfFunctionMetrics& src, - TfFunctionMetrics* dst) { - dst->set_count(src.count() + dst->count()); - dst->set_self_time_ps(src.self_time_ps() + dst->self_time_ps()); -} - -void CombineTfFunction(const TfFunction& src, TfFunction* dst) { - dst->set_total_tracing_count( - std::max(src.total_tracing_count(), dst->total_tracing_count())); - dst->set_compiler(CombineCompilers(src.compiler(), dst->compiler())); - for (const auto& mode_metrics : src.metrics()) { - int32_t execution_mode = mode_metrics.first; - const TfFunctionMetrics& src_metrics = mode_metrics.second; - TfFunctionMetrics* dst_metrics = - gtl::FindOrNull(*dst->mutable_metrics(), execution_mode); - if (dst_metrics == nullptr) { - (*dst->mutable_metrics())[execution_mode] = src_metrics; - } else { - CombineTfFunctionMetrics(src_metrics, dst_metrics); - } - } - dst->set_expensive_call_percent(ComputeExpensiveCallPercent(*dst)); -} - -// Execution history of all tf-functions invoked. -class TfFunctionExecutions { - public: - explicit TfFunctionExecutions(const XLineVisitor& line) { - // Creates points_ and activations_ from line. 
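// (Sketch of the approach implemented below: every tf-function event yields
// one ActivationRecord plus an entry point and an exit point; sorting all
// points by timestamp lets CalculateChildrenDurations() replay the
// well-nested calls with a plain stack.)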
- line.ForEachEvent([&](const XEventVisitor& event) { - absl::string_view mode; - int64_t tracing_count = 0; - event.ForEachStat([&mode, &tracing_count](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kTfFunctionCall: - mode = stat.StrOrRefValue(); - break; - case StatType::kTfFunctionTracingCount: - tracing_count = stat.IntValue(); - break; - } - }); - if (mode.empty()) return; - - // event is a tf-function. - int64_t index = activations_.size(); - auto timespan = event.GetTimespan(); - auto mode_compiler = Decode(event.Name(), mode); - ActivationRecord activation_record = - ActivationRecord(event.Name(), timespan, mode_compiler.first, - mode_compiler.second, tracing_count); - activations_.push_back(activation_record); - EntryOrExit entry_point = - EntryOrExit(/*is_entry=*/true, index, timespan.begin_ps()); - EntryOrExit exit_point = - EntryOrExit(/*is_entry=*/false, index, timespan.end_ps()); - points_.push_back(entry_point); - points_.push_back(exit_point); - }); - - // Sorts points_ in ascending order of timestamps. - auto ascending_in_timestamp = [](const EntryOrExit& a, - const EntryOrExit& b) { - return a.timestamp_ps < b.timestamp_ps; - }; - absl::c_sort(points_, ascending_in_timestamp); - - // Calculates the children duration for each activation record. - CalculateChildrenDurations(); - } - - std::string DebugString() const { - std::string result = "\nActivations:\n"; - for (int i = 0, end = activations_.size(); i < end; i++) { - absl::StrAppend(&result, "[", i, "] ", activations_[i].DebugString(), - "\n"); - } - absl::StrAppend(&result, "tf-function Entry/Exit Points:\n"); - for (const auto& pt : points_) { - absl::StrAppend(&result, pt.DebugString(), "\n"); - } - return result; - } - - // Converts this execution history to a TfFunctionDb. - TfFunctionDb ConvertToTfFunctionDb() { - TfFunctionDb result; - for (const auto& record : activations_) { - TfFunction* fun = &(*result.mutable_tf_functions())[record.function_name]; - fun->set_total_tracing_count( - std::max(static_cast(fun->total_tracing_count()), - record.tracing_count)); - fun->set_compiler(CombineCompilers(fun->compiler(), record.compiler)); - // The self-time of this function is the difference between the duration - // of this function and the duration of its children. - uint64 self_time_ps = - record.timespan.duration_ps() - record.children_duration_ps; - // Updates the metrics for this execution mode with this invocation. - TfFunctionMetrics* metrics = - &(*fun->mutable_metrics())[record.execution_mode]; - metrics->set_count(metrics->count() + 1); - metrics->set_self_time_ps(metrics->self_time_ps() + self_time_ps); - } - for (auto& name_fun : *result.mutable_tf_functions()) { - TfFunction& fun = name_fun.second; - fun.set_expensive_call_percent(ComputeExpensiveCallPercent(fun)); - } - return result; - } - - // Calculates the children duration of every tf-function. - void CalculateChildrenDurations() { - std::stack call_stack; - for (const auto& pt : points_) { - if (pt.is_entry) { - // Function entry. - call_stack.push(pt.index); - } else { - // Function exit. - DCHECK(call_stack.top() == pt.index); // must be well nested. - uint64 call_duration = activations_[pt.index].timespan.duration_ps(); - call_stack.pop(); - if (!call_stack.empty()) { - // call_stack.top() is the parent tf-function; adds call_duration to - // its children_duration. 
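// Worked example (mirroring the NestedFunctions test below): "outer" spans
// [10, 110] ps and "inner" spans [30, 70] ps, so the sorted points are
// entry(outer, 10), entry(inner, 30), exit(inner, 70), exit(outer, 110).
// Popping inner adds its 40 ps duration to outer's children_duration_ps,
// and outer's self time later evaluates to 100 - 40 = 60 ps.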
- activations_[call_stack.top()].children_duration_ps += call_duration; - } - } - } - } - - private: - // ActivationRecords for all tf-function invocations. - std::vector activations_; - // Entry and exit points of all invocations. - std::vector points_; -}; - -} // namespace - -std::string DebugString(const TfFunctionDb& tf_function_db) { - std::string str; - tsl::protobuf::TextFormat::PrintToString(tf_function_db, &str); - return str; -} - -void CombineTfFunctionDb(const TfFunctionDb& src, TfFunctionDb* dst) { - for (const auto& name_function : src.tf_functions()) { - const auto& name = name_function.first; - const auto& src_fun = name_function.second; - TfFunction* dst_fun = gtl::FindOrNull(*dst->mutable_tf_functions(), name); - if (dst_fun == nullptr) { - (*dst->mutable_tf_functions())[name] = src_fun; - } else { - CombineTfFunction(src_fun, dst_fun); - } - } -} - -TfFunctionDb ConvertHostThreadsXLineToTfFunctionDb(const XLineVisitor& line) { - TfFunctionExecutions tf_function_executions = TfFunctionExecutions(line); - return tf_function_executions.ConvertToTfFunctionDb(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h b/tensorflow/core/profiler/convert/xplane_to_tf_functions.h deleted file mode 100644 index 54c7d127d36b8d..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_ - -#include - -#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -// Converts from the given XLine to a TfFunctionDb. -TfFunctionDb ConvertHostThreadsXLineToTfFunctionDb(const XLineVisitor& line); - -// Returns a debugging string for the given TfFunctionDb. -std::string DebugString(TfFunctionDb tf_function_db); - -// Combines the tf-function statistics from src and dst into dst. -void CombineTfFunctionDb(const TfFunctionDb& src, TfFunctionDb* dst); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc deleted file mode 100644 index c1764961aa948d..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h" - -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/tf_function.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -const absl::string_view kEager = "eager"; -const absl::string_view kConcrete = "concrete"; -const absl::string_view kTracedNonXla = "traced-nonXla"; -const absl::string_view kTracedXla = "traced-xla"; -const absl::string_view kNotTracedNonXla = "notTraced-nonXla"; -const absl::string_view kNotTracedXla = "notTraced-xla"; - -constexpr double kMaxError = 0.001; - -TfFunctionDb ConvertXSpaceToTfFunctionDb(const XSpace& space) { - TfFunctionDb result; - const XPlane* host_plane = FindPlaneWithName(space, kHostThreadsPlaneName); - if (host_plane) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_plane); - plane.ForEachLine([&result](const XLineVisitor& line) { - TfFunctionDb tf_function_db = ConvertHostThreadsXLineToTfFunctionDb(line); - CombineTfFunctionDb(tf_function_db, &result); - }); - } - return result; -} - -TEST(ConvertXPlaneToTfFunctions, CombineTwoThreads) { - XSpace space; - XPlaneBuilder host_plane_builder(space.add_planes()); - host_plane_builder.SetName(kHostThreadsPlaneName); - host_plane_builder.ReserveLines(2); - std::string kFunctionName = "decrement"; - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, kFunctionName, - 10, 100, kTracedNonXla, 1); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, kFunctionName, - 150, 20, kNotTracedNonXla, 2); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, kFunctionName, - 200, 80, kTracedNonXla, 3); - - auto other_thread = host_plane_builder.GetOrCreateLine(1); - CreateTfFunctionCallEvent(&host_plane_builder, &other_thread, kFunctionName, - 20, 100, kTracedNonXla, 2); - CreateTfFunctionCallEvent(&host_plane_builder, &other_thread, kFunctionName, - 160, 20, kNotTracedNonXla, 2); - CreateTfFunctionCallEvent(&host_plane_builder, &other_thread, kFunctionName, - 210, 80, kTracedXla, 4); - - TfFunctionDb tf_function_db = ConvertXSpaceToTfFunctionDb(space); - EXPECT_EQ(tf_function_db.tf_functions().size(), 1); - EXPECT_EQ(tf_function_db.tf_functions().count(kFunctionName), 1); - const TfFunction& tf_function = - tf_function_db.tf_functions().at(kFunctionName); - EXPECT_EQ(tf_function.total_tracing_count(), 4); - EXPECT_EQ(tf_function.compiler(), MIXED_COMPILER); - 
EXPECT_NEAR(tf_function.expensive_call_percent(), 90, kMaxError); - - const auto& metrics = tf_function.metrics(); - EXPECT_EQ(metrics.size(), 2); - EXPECT_EQ(metrics.count(TRACED_MODE), 1); - EXPECT_EQ(metrics.count(NOT_TRACED_MODE), 1); - const auto& traced_mode = metrics.at(TRACED_MODE); - EXPECT_EQ(traced_mode.count(), 4); - EXPECT_EQ(traced_mode.self_time_ps(), 360); - const auto& not_traced_mode = metrics.at(NOT_TRACED_MODE); - EXPECT_EQ(not_traced_mode.count(), 2); - EXPECT_EQ(not_traced_mode.self_time_ps(), 40); -} - -TEST(ConvertXPlaneToTfFunctions, NestedFunctions) { - XSpace space; - XPlaneBuilder host_plane_builder(space.add_planes()); - host_plane_builder.SetName(kHostThreadsPlaneName); - host_plane_builder.ReserveLines(1); - std::string kOuterFunctionName = "outer"; - std::string kInnerFunctionName = "inner"; - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, - kOuterFunctionName, 10, 100, kTracedNonXla, 1); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, - kInnerFunctionName, 30, 40, kNotTracedXla, 0); - TfFunctionDb tf_function_db = ConvertXSpaceToTfFunctionDb(space); - EXPECT_EQ(tf_function_db.tf_functions().size(), 2); - EXPECT_EQ(tf_function_db.tf_functions().count(kOuterFunctionName), 1); - EXPECT_EQ(tf_function_db.tf_functions().count(kInnerFunctionName), 1); - const TfFunction& outer = - tf_function_db.tf_functions().at(kOuterFunctionName); - EXPECT_EQ(outer.total_tracing_count(), 1); - EXPECT_EQ(outer.compiler(), OTHER_COMPILER); - EXPECT_NEAR(outer.expensive_call_percent(), 100, kMaxError); - const auto& outer_metrics = outer.metrics(); - EXPECT_EQ(outer_metrics.size(), 1); - EXPECT_EQ(outer_metrics.count(TRACED_MODE), 1); - const auto& traced_mode = outer_metrics.at(TRACED_MODE); - EXPECT_EQ(traced_mode.count(), 1); - EXPECT_EQ(traced_mode.self_time_ps(), 60); - const TfFunction& inner = - tf_function_db.tf_functions().at(kInnerFunctionName); - EXPECT_EQ(inner.total_tracing_count(), 0); - EXPECT_EQ(inner.compiler(), XLA_COMPILER); - EXPECT_NEAR(inner.expensive_call_percent(), 0, kMaxError); - const auto& inner_metrics = inner.metrics(); - EXPECT_EQ(inner_metrics.size(), 1); - EXPECT_EQ(inner_metrics.count(NOT_TRACED_MODE), 1); - const auto& not_traced_mode = inner_metrics.at(NOT_TRACED_MODE); - EXPECT_EQ(not_traced_mode.count(), 1); - EXPECT_EQ(not_traced_mode.self_time_ps(), 40); -} - -TEST(ConvertXPlaneToTfFunctions, EagerPlusConcrete) { - XSpace space; - XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(&space)); - host_plane_builder.ReserveLines(2); - std::string kEagerFunctionName = "i_am_eager"; - std::string kConcreteFunctionName = "i_am_concrete"; - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, - kEagerFunctionName, 10, 200, kEager); - auto other_thread = host_plane_builder.GetOrCreateLine(1); - CreateTfFunctionCallEvent(&host_plane_builder, &other_thread, - kConcreteFunctionName, 20, 40, kConcrete); - TfFunctionDb tf_function_db = ConvertXSpaceToTfFunctionDb(space); - EXPECT_EQ(tf_function_db.tf_functions().size(), 2); - EXPECT_EQ(tf_function_db.tf_functions().count(kEagerFunctionName), 1); - EXPECT_EQ(tf_function_db.tf_functions().count(kConcreteFunctionName), 1); - const TfFunction& eager = - tf_function_db.tf_functions().at(kEagerFunctionName); - EXPECT_EQ(eager.total_tracing_count(), 0); - EXPECT_EQ(eager.compiler(), INVALID_COMPILER); - EXPECT_NEAR(eager.expensive_call_percent(), 
100, kMaxError); - const auto& eager_metrics = eager.metrics(); - EXPECT_EQ(eager_metrics.size(), 1); - EXPECT_EQ(eager_metrics.count(EAGER_MODE), 1); - const auto& eager_mode = eager_metrics.at(EAGER_MODE); - EXPECT_EQ(eager_mode.count(), 1); - EXPECT_EQ(eager_mode.self_time_ps(), 200); - const TfFunction& concrete = - tf_function_db.tf_functions().at(kConcreteFunctionName); - EXPECT_EQ(concrete.total_tracing_count(), 0); - EXPECT_EQ(concrete.compiler(), INVALID_COMPILER); - EXPECT_NEAR(concrete.expensive_call_percent(), 0, kMaxError); - const auto& concrete_metrics = concrete.metrics(); - EXPECT_EQ(concrete_metrics.size(), 1); - EXPECT_EQ(concrete_metrics.count(CONCRETE_MODE), 1); - const auto& concrete_mode = concrete_metrics.at(CONCRETE_MODE); - EXPECT_EQ(concrete_mode.count(), 1); - EXPECT_EQ(concrete_mode.self_time_ps(), 40); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc b/tensorflow/core/profiler/convert/xplane_to_tool_names.cc deleted file mode 100644 index b70573db79458a..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tool_names.h" - -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/str_join.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h" -#include "tensorflow/core/profiler/convert/xplane_to_hlo.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -absl::StatusOr GetAvailableToolNames( - const SessionSnapshot& session_snapshot) { - std::vector tools; - bool is_cloud_vertex_ai = !session_snapshot.HasAccessibleRunDir(); - if (session_snapshot.XSpaceSize() != 0) { - tools.reserve(11); - tools.push_back(is_cloud_vertex_ai ? "trace_viewer" : "trace_viewer@"); - tools.push_back("overview_page"); - // TODO(jonahweaver): Re-enable input_pipeline_analyzer when it is ready. 
-    // b/407096031
-    // tools.push_back("input_pipeline_analyzer");
-    tools.push_back("framework_op_stats");
-    tools.push_back("memory_profile");
-    tools.push_back("pod_viewer");
-    tools.push_back("op_profile");
-    tools.push_back("inference_profile");
-    tools.push_back("hlo_stats");
-    tools.push_back("roofline_model");
-
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<XSpace> xspace,
-                        session_snapshot.GetXSpace(0));
-
-    if (!FindPlanesWithPrefix(*xspace, kGpuPlanePrefix).empty()) {
-      tools.push_back("kernel_stats");
-    }
-
-    TF_ASSIGN_OR_RETURN(bool has_hlo,
-                        ConvertMultiXSpaceToHloProto(session_snapshot));
-    if (has_hlo) {
-      tools.push_back("memory_viewer");
-      tools.push_back("graph_viewer");
-    }
-
-    TF_ASSIGN_OR_RETURN(bool has_dcn_collective_stats,
-                        HasDcnCollectiveStatsInMultiXSpace(session_snapshot));
-    if (has_dcn_collective_stats) {
-      tools.push_back("dcn_collective_stats");
-    }
-  }
-
-  return absl::StrJoin(tools, ",");
-}
-
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names.h b/tensorflow/core/profiler/convert/xplane_to_tool_names.h
deleted file mode 100644
index 3a23604ee7fcfd..00000000000000
--- a/tensorflow/core/profiler/convert/xplane_to_tool_names.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOL_NAMES_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOL_NAMES_H_
-
-#include <string>
-
-#include "absl/status/statusor.h"
-#include "tensorflow/core/platform/statusor.h"
-#include "tensorflow/core/profiler/convert/repository.h"
-
-namespace tensorflow {
-namespace profiler {
-
-// Gets the names of the available tools given a session snapshot.
-// Returns a comma separated list of tool names.
-absl::StatusOr<std::string> GetAvailableToolNames(
-    const SessionSnapshot& session_snapshot);
-
-}  // namespace profiler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOL_NAMES_H_
diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc b/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc
deleted file mode 100644
index d10e81519e563a..00000000000000
--- a/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc
+++ /dev/null
@@ -1,169 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/xplane_to_tool_names.h"
-
-#include <memory>
-#include <string>
-#include <string_view>
-#include <utility>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_split.h"
-#include "xla/tsl/platform/env.h"
-#include "xla/tsl/platform/status.h"
-#include "tensorflow/core/platform/file_system.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/profiler/convert/repository.h"
-#include "tensorflow/core/profiler/utils/xplane_schema.h"
-#include "tensorflow/core/profiler/utils/xplane_utils.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-struct XPlaneToToolsTestCase {
-  std::string test_name;
-  std::string_view plane_name;
-  bool has_hlo_module;
-  bool has_dcn_collective_stats;
-  std::vector<std::string> expected_tools;
-};
-
-SessionSnapshot CreateSessionSnapshot(std::unique_ptr<XSpace> xspace,
-                                      bool has_hlo_module,
-                                      bool has_dcn_collective_stats) {
-  std::string test_name =
-      ::testing::UnitTest::GetInstance()->current_test_info()->name();
-  std::string path = absl::StrCat("ram://", test_name, "/");
-  std::unique_ptr<tsl::WritableFile> xplane_file;
-  tsl::Env::Default()
-      ->NewAppendableFile(absl::StrCat(path, "hostname.xplane.pb"),
-                          &xplane_file)
-      .IgnoreError();
-  std::vector<std::string> paths = {path};
-
-  if (has_hlo_module) {
-    tsl::Env::Default()
-        ->NewAppendableFile(absl::StrCat(path, "module_name.hlo_proto.pb"),
-                            &xplane_file)
-        .IgnoreError();
-  } else {
-    tsl::Env::Default()
-        ->NewAppendableFile(absl::StrCat(path, "NO_MODULE.hlo_proto.pb"),
-                            &xplane_file)
-        .IgnoreError();
-  }
-
-  if (has_dcn_collective_stats) {
-    tsl::Env::Default()
-        ->NewAppendableFile(
-            absl::StrCat(path, "hostname.dcn_collective_stats.pb"),
-            &xplane_file)
-        .IgnoreError();
-    tsl::Env::Default()
-        ->NewAppendableFile(
-            absl::StrCat(path, "ALL_HOSTS.dcn_collective_stats.pb"),
-            &xplane_file)
-        .IgnoreError();
-  } else {
-    tsl::Env::Default()
-        ->NewAppendableFile(
-            absl::StrCat(path, "NO_HOST.dcn_collective_stats.pb"), &xplane_file)
-        .IgnoreError();
-  }
-
-  std::vector<std::unique_ptr<XSpace>> xspaces;
-  xspaces.push_back(std::move(xspace));
-
-  absl::StatusOr<SessionSnapshot> session_snapshot =
-      SessionSnapshot::Create(paths, std::move(xspaces));
-  TF_CHECK_OK(session_snapshot.status());
-  return std::move(session_snapshot.value());
-}
-
-using XPlaneToToolsTest = ::testing::TestWithParam<XPlaneToToolsTestCase>;
-
-TEST_P(XPlaneToToolsTest, ToolsList) {
-  const XPlaneToToolsTestCase& test_case = GetParam();
-  auto xspace = std::make_unique<XSpace>();
-  FindOrAddMutablePlaneWithName(xspace.get(), test_case.plane_name);
-
-  SessionSnapshot sessionSnapshot =
-      CreateSessionSnapshot(std::move(xspace), test_case.has_hlo_module,
-                            test_case.has_dcn_collective_stats);
-
-  absl::StatusOr<std::string> toolsString =
-      GetAvailableToolNames(sessionSnapshot);
-  ASSERT_TRUE(toolsString.ok());
-
-  std::vector<std::string> tools = absl::StrSplit(toolsString.value(), ',');
-
-  std::vector<std::string> expected_tools = {
-      "trace_viewer",
-      "overview_page",
-      // TODO(jonahweaver): Re-enable input_pipeline_analyzer when it is ready.
-      // b/407096031
-      // "input_pipeline_analyzer",
-      "framework_op_stats",
-      "memory_profile",
-      "pod_viewer",
-      "op_profile",
-      "hlo_stats",
-      "roofline_model",
-      "inference_profile",
-  };
-  expected_tools.insert(expected_tools.end(), test_case.expected_tools.begin(),
-                        test_case.expected_tools.end());
-  EXPECT_THAT(tools, ::testing::UnorderedElementsAreArray(expected_tools));
-}
-
-INSTANTIATE_TEST_SUITE_P(
-    XPlaneToToolsTests, XPlaneToToolsTest,
-    ::testing::ValuesIn<XPlaneToToolsTestCase>({
-        {"ToolsForTpuWithoutHloModule", kTpuPlanePrefix, false, false, {}},
-        {"ToolsForTpuWithHloModule",
-         kTpuPlanePrefix,
-         true,
-         false,
-         {"graph_viewer", "memory_viewer"}},
-        {"ToolsForGpuWithoutHloModule",
-         kGpuPlanePrefix,
-         false,
-         false,
-         {"kernel_stats"}},
-        {"ToolsForGpuWithHloModule",
-         kGpuPlanePrefix,
-         true,
-         false,
-         {"kernel_stats", "graph_viewer", "memory_viewer"}},
-        {"ToolsForTpuWithDcnCollectiveStats",
-         kTpuPlanePrefix,
-         false,
-         true,
-         {"dcn_collective_stats"}},
-    }),
-    [](const ::testing::TestParamInfo<XPlaneToToolsTest::ParamType>& info) {
-      return info.param.test_name;
-    });
-
-}  // namespace
-}  // namespace profiler
-}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc
deleted file mode 100644
index b743e25c586d2f..00000000000000
--- a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc
+++ /dev/null
@@ -1,432 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tools_data.h" - -#include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/numbers.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/env.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/file_system.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/convert/xplane_to_trace_events.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/profiler/convert/compute_inference_latency.h" -#include "tensorflow/core/profiler/convert/hlo_to_tools_data.h" -#include "tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h" -#include "tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h" -#include "tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h" -#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" -#include "tensorflow/core/profiler/convert/op_stats_to_op_profile.h" -#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h" -#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h" -#include "tensorflow/core/profiler/convert/op_stats_to_roofline_model.h" -#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" -#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h" -#include "tensorflow/core/profiler/convert/process_megascale_dcn.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/tool_options.h" -#include "tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h" -#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" -#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h" -#include "tensorflow/core/profiler/convert/xplane_to_tool_names.h" -#include "tensorflow/core/profiler/convert/xplane_to_trace_container.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" -#include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/platform/protobuf.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "xprof/convert/trace_viewer/trace_events_to_json.h" // from @org_xprof -#include "xprof/convert/trace_viewer/trace_viewer_visibility.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/hlo_stats.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/inference_stats.pb.h" 
// from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/op_profile.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/overview_page.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/tf_data_stats.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/trace_events_raw.pb.h"  // from @org_xprof
-#include "xprof/utils/hardware_type_utils.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-namespace {
-
-struct TraceViewOption {
-  uint64_t resolution = 0;
-  double start_time_ms = 0.0;
-  double end_time_ms = 0.0;
-};
-
-absl::StatusOr<TraceViewOption> GetTraceViewOption(const ToolOptions& options) {
-  TraceViewOption trace_options;
-  auto start_time_ms_opt =
-      GetParamWithDefault<std::string>(options, "start_time_ms", "0.0");
-  auto end_time_ms_opt =
-      GetParamWithDefault<std::string>(options, "end_time_ms", "0.0");
-  auto resolution_opt =
-      GetParamWithDefault<std::string>(options, "resolution", "0");
-
-  if (!absl::SimpleAtoi(resolution_opt, &trace_options.resolution) ||
-      !absl::SimpleAtod(start_time_ms_opt, &trace_options.start_time_ms) ||
-      !absl::SimpleAtod(end_time_ms_opt, &trace_options.end_time_ms)) {
-    return errors::InvalidArgument("wrong arguments");
-  }
-  return trace_options;
-}
-
-absl::StatusOr<std::string> ConvertXSpaceToTraceEvents(
-    const SessionSnapshot& session_snapshot, const absl::string_view tool_name,
-    const ToolOptions& options) {
-  if (session_snapshot.XSpaceSize() != 1) {
-    return errors::InvalidArgument(
-        "Trace events tool expects only 1 XSpace path but gets ",
-        session_snapshot.XSpaceSize());
-  }
-
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<XSpace> xspace,
-                      session_snapshot.GetXSpace(0));
-  PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true,
-                             /*derived_timeline=*/true);
-  std::string content;
-  if (tool_name == "trace_viewer") {
-    tsl::profiler::ConvertXSpaceToTraceEventsString(*xspace, &content);
-    return content;
-  } else {  // streaming trace viewer.
-    std::string host_name = session_snapshot.GetHostname(0);
-    auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name);
-    if (!sstable_path) {
-      return errors::Unimplemented(
-          "streaming trace viewer hasn't been supported in Cloud AI");
-    }
-    if (!Env::Default()->FileExists(*sstable_path).ok()) {
-      ProcessMegascaleDcn(xspace.get());
-      TraceEventsContainer trace_container;
-      ConvertXSpaceToTraceEventsContainer(host_name, *xspace,
-                                          &trace_container);
-      std::unique_ptr<tsl::WritableFile> file;
-      TF_RETURN_IF_ERROR(
-          tsl::Env::Default()->NewWritableFile(*sstable_path, &file));
-      TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file)));
-    }
-    TF_ASSIGN_OR_RETURN(TraceViewOption trace_option,
-                        GetTraceViewOption(options));
-    auto visibility_filter = std::make_unique<TraceVisibilityFilter>(
-        tsl::profiler::MilliSpan(trace_option.start_time_ms,
-                                 trace_option.end_time_ms),
-        trace_option.resolution);
-    TraceEventsContainer trace_container;
-    // Trace smaller than threshold will be disabled from streaming.
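-    // (Presumably an event count: a trace below ~500k events is loaded in
-    // full, since streaming buys little for small traces.)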
- constexpr int64_t kDisableStreamingThreshold = 500000; - TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable( - *sstable_path, /*filter=*/nullptr, std::move(visibility_filter), - kDisableStreamingThreshold)); - JsonTraceOptions options; - IOBufferAdapter adapter(&content); - TraceEventsToJson( - options, trace_container, &adapter); - return content; - } -} - -absl::Status ConvertMultiXSpaceToCombinedOpStatsWithCache( - const SessionSnapshot& session_snapshot, OpStats* combined_op_stats) { - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - options.generate_kernel_stats_db = true; - TF_ASSIGN_OR_RETURN(auto has_cache, - session_snapshot.HasCacheFile(StoredDataType::OP_STATS)); - if (has_cache.first) { - TF_RETURN_IF_ERROR(ReadBinaryProto(session_snapshot, - StoredDataType::OP_STATS, - kAllHostsIdentifier, combined_op_stats)); - - } else { - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, combined_op_stats)); - if (!WriteBinaryProto(session_snapshot, StoredDataType::OP_STATS, - kAllHostsIdentifier, *combined_op_stats) - .ok()) { - LOG(WARNING) << "Failed to write op stats cache file."; - }; - } - return absl::OkStatus(); -} - -absl::StatusOr ConvertMultiXSpacesToOverviewPage( - const SessionSnapshot& session_snapshot) { - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( - session_snapshot, &combined_op_stats)); - OverviewPage overview_page = ConvertOpStatsToOverviewPage(combined_op_stats); - InferenceStats inference_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToInferenceStats(session_snapshot, "", - "", &inference_stats)); - *overview_page.mutable_inference_latency() = - ComputeInferenceLatencyResult(inference_stats); - return overview_page.SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToInputPipeline( - const SessionSnapshot& session_snapshot) { - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( - session_snapshot, &combined_op_stats)); - return ConvertOpStatsToInputPipelineAnalysis(combined_op_stats) - .SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToTfStats( - const SessionSnapshot& session_snapshot) { - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( - session_snapshot, &combined_op_stats)); - return ConvertOpStatsToTfStats(combined_op_stats).SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToKernelStats( - const SessionSnapshot& session_snapshot) { - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( - session_snapshot, &combined_op_stats)); - return combined_op_stats.kernel_stats_db().SerializeAsString(); -} - -absl::StatusOr ConvertXSpaceToMemoryProfile( - const SessionSnapshot& session_snapshot) { - if (session_snapshot.XSpaceSize() != 1) { - return errors::InvalidArgument( - "Memory profile tool expects only 1 XSpace path but gets ", - session_snapshot.XSpaceSize()); - } - - std::string json_output; - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(0)); - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/false); - TF_RETURN_IF_ERROR(ConvertXSpaceToMemoryProfileJson(*xspace, &json_output)); - return json_output; -} - -absl::StatusOr ConvertMultiXSpacesToPodViewer( - const SessionSnapshot& session_snapshot) { - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( - 
session_snapshot, &combined_op_stats)); - - std::string json_output; - tsl::protobuf::util::JsonPrintOptions opts; - opts.always_print_primitive_fields = true; - auto encode_status = tsl::protobuf::util::MessageToJsonString( - ConvertOpStatsToPodViewer(combined_op_stats), &json_output, opts); - if (!encode_status.ok()) { - const auto& error_message = encode_status.message(); - return errors::Internal( - "Could not convert pod viewer to json. Error: ", - absl::string_view(error_message.data(), error_message.length())); - } - return json_output; -} - -absl::StatusOr ConvertMultiXSpacesToTfDataBottleneckAnalysis( - const SessionSnapshot& session_snapshot) { - CombinedTfDataStats combined_tf_data_stats; - CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); - - for (int idx = 0; idx < session_snapshot.XSpaceSize(); ++idx) { - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(idx)); - - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/false); - XPlane* host_plane = - FindMutablePlaneWithName(xspace.get(), kHostThreadsPlaneName); - std::string host_name_from_file = session_snapshot.GetHostname(idx); - if (host_plane == nullptr) { - return errors::InvalidArgument( - "Could not find host XPlane for tf data stats: ", - host_name_from_file); - } - absl::string_view host_name = - xspace->hostnames_size() ? xspace->hostnames(0) : host_name_from_file; - builder.Add(host_name, host_plane); - } - builder.Finalize(); - return combined_tf_data_stats.SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToHloStats( - const SessionSnapshot& session_snapshot) { - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( - session_snapshot, &combined_op_stats)); - hlo_stats::HloStatsDatabase hlo_stats_db = - ConvertOpStatsToHloStats(combined_op_stats); - return HloStatsToDataTableJson(hlo_stats_db); -} - -absl::StatusOr ConvertMultiXSpacesToRooflineModel( - const SessionSnapshot& session_snapshot) { - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( - session_snapshot, &combined_op_stats)); - RooflineModelDatabase result = - ConvertOpStatsToRooflineModel(combined_op_stats, true); - RooflineModelDatabase result_without_infeed_outfeed = - ConvertOpStatsToRooflineModel(combined_op_stats, false); - result.mutable_roofline_model_record()->MergeFrom( - result_without_infeed_outfeed.roofline_model_record()); - return result.SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToOpProfileViewer( - const SessionSnapshot& session_snapshot) { - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToCombinedOpStatsWithCache( - session_snapshot, &combined_op_stats)); - - tensorflow::profiler::op_profile::Profile profile; - ConvertOpStatsToOpProfile( - combined_op_stats, - ParseHardwareType(combined_op_stats.run_environment().device_type()), - profile); - std::string json_output; - tsl::protobuf::util::JsonPrintOptions opts; - opts.always_print_primitive_fields = true; - - auto encode_status = - tsl::protobuf::util::MessageToJsonString(profile, &json_output, opts); - if (!encode_status.ok()) { - const auto& error_message = encode_status.message(); - return errors::Internal( - "Could not convert op profile proto to json. 
Error: ", - absl::string_view(error_message.data(), error_message.length())); - } - return json_output; -} - -absl::StatusOr PreprocessXSpace( - const SessionSnapshot& session_snapshot) { - if (session_snapshot.XSpaceSize() != 1) { - return errors::InvalidArgument( - "PreprocessXSpace tool expects only 1 XSpace path but gets ", - session_snapshot.XSpaceSize()); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(0)); - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/true); - return xspace->SerializeAsString(); -} - -absl::StatusOr ConvertDcnCollectiveStatsToToolData( - const SessionSnapshot& session_snapshot, const ToolOptions& options) { - // must provide a host_name field. - std::optional hostname = - GetParam(options, "host_name"); - if (!hostname.has_value() || hostname->empty()) { - return absl::InvalidArgumentError( - "Cannot find host_name from options for dcn_collective_stats tool."); - } - - // Load DcnSlackAnalysis for a host. - TF_ASSIGN_OR_RETURN( - DcnSlackAnalysis dcnSlackAnalysis, - GetDcnSlackAnalysisByHostName(session_snapshot, hostname.value())); - - return dcnSlackAnalysis.SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToInferenceStats( - const SessionSnapshot& session_snapshot, const ToolOptions& options) { - InferenceStats inference_stats; - std::string request_column = - GetParamWithDefault(options, "request_column", ""); - std::string batch_column = - GetParamWithDefault(options, "batch_column", ""); - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToInferenceStats( - session_snapshot, request_column, batch_column, &inference_stats)); - return inference_stats.SerializeAsString(); -} - -} // namespace - -absl::StatusOr ConvertMultiXSpacesToToolData( - const SessionSnapshot& session_snapshot, const absl::string_view tool_name, - const ToolOptions& options) { - LOG(INFO) << "serving tool: " << tool_name - << " with options: " << DebugString(options); - if (tool_name == "trace_viewer" || tool_name == "trace_viewer@") { - return ConvertXSpaceToTraceEvents(session_snapshot, tool_name, options); - } else if (tool_name == "overview_page") { - return ConvertMultiXSpacesToOverviewPage(session_snapshot); - } else if (tool_name == "input_pipeline_analyzer") { - return ConvertMultiXSpacesToInputPipeline(session_snapshot); - } else if (tool_name == "framework_op_stats") { - return ConvertMultiXSpacesToTfStats(session_snapshot); - } else if (tool_name == "kernel_stats") { - return ConvertMultiXSpacesToKernelStats(session_snapshot); - } else if (tool_name == "memory_profile") { - return ConvertXSpaceToMemoryProfile(session_snapshot); - } else if (tool_name == "pod_viewer") { - return ConvertMultiXSpacesToPodViewer(session_snapshot); - } else if (tool_name == "op_profile") { - return ConvertMultiXSpacesToOpProfileViewer(session_snapshot); - } else if (tool_name == "hlo_stats") { - return ConvertMultiXSpacesToHloStats(session_snapshot); - } else if (tool_name == "roofline_model") { - return ConvertMultiXSpacesToRooflineModel(session_snapshot); - } else if (tool_name == "memory_viewer" || tool_name == "graph_viewer") { - return ConvertHloProtoToToolData(session_snapshot, tool_name, options); - } else if (tool_name == "tool_names") { - return GetAvailableToolNames(session_snapshot); - } else if (tool_name == "_xplane.pb") { // internal test only. 
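-    // Returns the preprocessed XSpace itself, serialized: the input after
-    // step grouping and derived-timeline generation, rather than any
-    // tool-specific conversion.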
- return PreprocessXSpace(session_snapshot); - } else if (tool_name == "inference_profile") { - return ConvertMultiXSpacesToInferenceStats(session_snapshot, options); - } else { - return errors::InvalidArgument( - "Can not find tool: ", tool_name, - ". Please update to the latest version of Tensorflow."); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.h b/tensorflow/core/profiler/convert/xplane_to_tools_data.h deleted file mode 100644 index 49ab1ea588f41d..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tools_data.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_ - -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/tool_options.h" - -namespace tensorflow { -namespace profiler { - -// Convert XSpace protos to a tool specific data. -// Return the serialized string of tool specific data when the conversion is -// successful, else return error status. -absl::StatusOr ConvertMultiXSpacesToToolData( - const SessionSnapshot& session_snapshot, absl::string_view tool_name, - const ToolOptions& options); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc deleted file mode 100644 index 96fdf62e77129b..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#include "tensorflow/core/profiler/convert/xplane_to_trace_container.h"
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/tsl/profiler/utils/tf_xplane_visitor.h"
-#include "xla/tsl/profiler/utils/timespan.h"
-#include "xla/tsl/profiler/utils/trace_utils.h"
-#include "xla/tsl/profiler/utils/xplane_schema.h"
-#include "xla/tsl/profiler/utils/xplane_utils.h"
-#include "xla/tsl/profiler/utils/xplane_visitor.h"
-#include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
-#include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-#include "xprof/convert/trace_viewer/trace_event_arguments_builder.h"  // from @org_xprof
-#include "xprof/convert/trace_viewer/trace_events_util.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/trace_events.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/trace_events_raw.pb.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-namespace {
-
-using tsl::profiler::FindPlanesWithPrefix;
-using tsl::profiler::FindPlaneWithName;
-using tsl::profiler::HostEventType;
-using tsl::profiler::StatType;
-using tsl::profiler::XEventVisitor;
-using tsl::profiler::XFlow;
-using tsl::profiler::XLineVisitor;
-using tsl::profiler::XPlaneVisitor;
-using tsl::profiler::XStatVisitor;
-
-struct SpecialArguments {
-  std::optional<int64_t> group_id;
-  absl::string_view step_name;
-  bool is_async_event = false;
-  // Both flow and async events share the flow specification.
-  std::optional<XFlow> flow;
-};
-
-inline TraceEvent::FlowEntryType FlowEntryTypeFromDirection(
-    XFlow::FlowDirection direction) {
-  switch (direction) {
-    case XFlow::kFlowUnspecified:
-      return TraceEvent::FLOW_NONE;
-    case XFlow::kFlowIn:
-      return TraceEvent::FLOW_END;
-    case XFlow::kFlowOut:
-      return TraceEvent::FLOW_START;
-    case XFlow::kFlowInOut:
-      return TraceEvent::FLOW_MID;
-  }
-}
-
-template <typename T>
-void ConvertXStatToTraceEventArgument(const XStatVisitor& stat, T value,
-                                      SpecialArguments& special_args,
-                                      TraceEventArgumentsBuilder& args) {
-  if (stat.Type() == StatType::kFlow) {
-    special_args.flow = XFlow::FromStatValue(value);
-  } else if (stat.Type() == StatType::kGroupId) {
-    special_args.group_id = value;
-  } else if (stat.Type() == StatType::kIsAsync) {
-    special_args.is_async_event = true;
-  } else {
-    args.Append(stat.Name(), value);
-  }
-}
-
-SpecialArguments ConvertXStatsToTraceEventArguments(
-    const XEventVisitor& event, RawData* raw_data,
-    TraceEventArguments* raw_args) {
-  TraceEventArgumentsBuilder args(raw_args);
-  SpecialArguments special_args;
-  auto for_each_stat = [&special_args, &args](const XStatVisitor& stat) {
-    if (tsl::profiler::IsInternalStat(stat.Type())) return;
-    switch (stat.ValueCase()) {
-      case XStat::kInt64Value:
-        ConvertXStatToTraceEventArgument(stat, stat.IntValue(), special_args,
-                                         args);
-        break;
-      case XStat::kUint64Value:
-        ConvertXStatToTraceEventArgument(stat, stat.UintValue(), special_args,
-                                         args);
-        break;
-      case XStat::kDoubleValue:
-        args.Append(stat.Name(), stat.DoubleValue());
-        break;
-      case XStat::kStrValue:
-      case XStat::kRefValue: {
-        auto stat_value = stat.StrOrRefValue();
-        if (stat.Type() == StatType::kStepName) {
-          special_args.step_name = stat_value;
-        }
-        args.Append(stat.Name(), stat_value);
-        break;
-      }
-      case XStat::kBytesValue:
-        break;
-      case 
XStat::VALUE_NOT_SET: - break; - } - }; - // Ensure the metadata stats appear before the per-occurrence stats. - event.Metadata().ForEachStat(for_each_stat); - event.ForEachStat(for_each_stat); - return special_args; -} - -void ConvertXLineToTraceEventsContainer(uint32_t device_id, - const XLineVisitor& line, - TraceEventsContainer* container) { - std::optional resource_id; - - if (line.Name() != tsl::profiler::kCounterEventsLineName) { - resource_id = line.DisplayId(); - Resource* resource = container->MutableResource(*resource_id, device_id); - resource->set_resource_id(*resource_id); - resource->set_name(std::string(line.DisplayName())); - resource->set_num_events(line.NumEvents()); - } - - RawData raw_data; // hoisted for performance - line.ForEachEvent([device_id, resource_id, &raw_data, - container](const XEventVisitor& event) { - int64_t event_type = - event.Type().value_or(HostEventType::kUnknownHostEventType); - if (tsl::profiler::IsInternalEvent(event_type)) return; - TraceEventArguments* raw_args = raw_data.mutable_args(); - absl::string_view event_name; - if (event.HasDisplayName()) { - event_name = event.DisplayName(); - TraceEventArgumentsBuilder args(raw_args); - constexpr size_t kMaxLongName = 10000; - if (event.Name().size() > kMaxLongName) { - args.Append("long_name", - absl::StrCat(event.Name().substr(0, kMaxLongName), - "...")); - } else { - args.Append("long_name", event.Name()); - } - } else { - event_name = event.Name(); - } - SpecialArguments special_args = - ConvertXStatsToTraceEventArguments(event, &raw_data, raw_args); - if (!special_args.step_name.empty()) { - event_name = special_args.step_name; - } - if (!resource_id) { - container->AddCounterEvent(event_name, device_id, event.TimestampPs(), - raw_data); - } else if (special_args.flow) { - tsl::profiler::Timespan span(event.TimestampPs(), event.DurationPs()); - if (special_args.is_async_event) { - container->AddAsyncEvent( - event_name, device_id, span, special_args.flow->Id(), - FlowEntryTypeFromDirection(special_args.flow->Direction()), - special_args.flow->Category(), &raw_data, special_args.group_id); - } else { - container->AddFlowEvent( - event_name, *resource_id, device_id, span, special_args.flow->Id(), - FlowEntryTypeFromDirection(special_args.flow->Direction()), - special_args.flow->Category(), &raw_data, special_args.group_id); - } - } else { - tsl::profiler::Timespan span(event.TimestampPs(), event.DurationPs()); - container->AddCompleteEvent(event_name, *resource_id, device_id, span, - &raw_data, special_args.group_id); - } - // Cleanup hoisted structure for next event. - if (raw_data.has_args()) raw_args->clear_arg(); - }); -} - -void ConvertXPlaneToTraceEventsContainer(uint64_t device_id, - absl::string_view hostname, - const XPlane& xplane, - TraceEventsContainer* container) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&xplane); - std::unique_ptr resource_grouper = - CreateDefaultResourceGrouper(device_id, plane.Name()); - - if (plane.NumLines() == 0) return; - - for (const auto& [device_id, name] : resource_grouper->Devices()) { - Device* device = container->MutableDevice(device_id); - device->set_device_id(device_id); - device->set_name(absl::StrCat(hostname, " ", name)); - } - - plane.ForEachLine([&](const XLineVisitor& line) { - if (line.DisplayName() == tsl::profiler::kXlaAsyncOpLineName) return; - if (line.NumEvents() == 0) return; - // Capture a copy of XLineVisitor because it will go out of scope. 
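-    // The device id is resolved through the resource grouper rather than
-    // taken from the plane, since the grouper may map a line onto a
-    // different logical device.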
- uint32_t device_id = resource_grouper->GetDeviceId(line.DisplayId()); - ConvertXLineToTraceEventsContainer(device_id, line, container); - }); -} - -} // namespace - -void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, - const XSpace& space, - TraceEventsContainer* container) { - const XPlane* host_plane = - FindPlaneWithName(space, tsl::profiler::kHostThreadsPlaneName); - if (host_plane != nullptr) { - ConvertXPlaneToTraceEventsContainer(tsl::profiler::kHostThreadsDeviceId, - hostname, *host_plane, container); - } - - std::vector device_planes = - FindPlanesWithPrefix(space, tsl::profiler::kGpuPlanePrefix); - - if (device_planes.empty()) { - device_planes = FindPlanesWithPrefix(space, tsl::profiler::kTpuPlanePrefix); - } - - for (const XPlane* device_plane : device_planes) { - ConvertXPlaneToTraceEventsContainer( - tsl::profiler::kFirstDeviceId + device_plane->id(), hostname, - *device_plane, container); - } - for (const XPlane* custom_plane : - FindPlanesWithPrefix(space, tsl::profiler::kCustomPlanePrefix)) { - ConvertXPlaneToTraceEventsContainer( - tsl::profiler::kFirstCustomPlaneDeviceId + custom_plane->id(), hostname, - *custom_plane, container); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.h b/tensorflow/core/profiler/convert/xplane_to_trace_container.h deleted file mode 100644 index 157c16aa6c38fa..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ - -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "xprof/convert/trace_viewer/trace_events.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/trace_events_raw.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -using TraceEventsContainer = TraceEventsContainerBase; - -// Converts XEvents within the XSpace into trace_viewer events container. -void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, - const XSpace& xspace, - TraceEventsContainer* container); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc deleted file mode 100644 index 86ba79a34870e1..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_trace_container.h" - -#include -#include - -#include -#include -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/strings/match.h" -#include "absl/strings/substitute.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/util/proto/proto_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/trace_events.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/trace_events_raw.pb.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using ::testing::Pair; -using ::testing::UnorderedElementsAre; - -TEST(XPlaneToTraceContainerTest, CounterLine) { - XSpace xspace; - CHECK_OK(tensorflow::proto_utils::ParseTextFormatFromString( - absl::Substitute( - "planes {" - " name: \"/device:GPU:0\"" - " lines {" - " name: \"_counters_\"" - " events {" - " metadata_id: 100" - " offset_ps: $0" - " stats { metadata_id: 200 uint64_value: 100 }" - " }" - " events {" - " metadata_id: 100" - " offset_ps: $1" - " stats { metadata_id: 200 uint64_value: 200 }" - " }" - " events {" - " metadata_id: 101" - " offset_ps: $0" - " stats { metadata_id: 201 uint64_value: 300 }" - " }" - " events {" - " metadata_id: 101" - " offset_ps: $1" - " stats { metadata_id: 201 uint64_value: 400 }" - " }" - " }" - " lines {" - " id: 14" - " name: \"Stream #14(MemcpyH2D)\"" - " timestamp_ns: $3" - " events {" - " metadata_id: 10" - " offset_ps: 0" - " duration_ps: $1" - " stats { metadata_id: 8 uint64_value: 100 }" - " stats { metadata_id: 9 str_value: \"$$1\" }" - " }" - " events {" - " metadata_id: 10" - " offset_ps: $0" - " duration_ps: $3" - " stats { metadata_id: 8 uint64_value: 200 }" - " stats { metadata_id: 9 str_value: \"abcd\" }" - " }" - " }" - " event_metadata {key: 10 value: { id: 10 name: \"MemcpyD2D\" }}" - " event_metadata {key: 100 value: { id: 100 name: \"Counter 1\" }}" - " event_metadata {key: 101 value: { id: 101 name: \"Counter 2\" }}" - " stat_metadata {key: 8 value: { id: 8 name: \"RemoteCall\"}}" - " stat_metadata {key: 9 value: { id: 8 name: \"context_id\"}}" - " stat_metadata {key: 200 value: { id: 200 name: \"counter_1\"}}" - " stat_metadata {key: 201 value: { id: 201 name: \"counter_2\"}}" - "}", - tsl::profiler::UniToPico(1), tsl::profiler::UniToPico(2), - tsl::profiler::UniToNano(1), tsl::profiler::UniToNano(500)), - &xspace)); - TraceEventsContainer container; - ConvertXSpaceToTraceEventsContainer("localhost", xspace, &container); - absl::flat_hash_map> - counter_offset_to_values; - container.ForAllEvents([&counter_offset_to_values](const TraceEvent& event) { - if (absl::StrContains(event.name(), "Counter")) { - uint64_t offset = event.timestamp_ps(); - RawData raw_data; - raw_data.ParseFromString(event.raw_data()); - counter_offset_to_values[event.name()][offset] = - 
raw_data.args().arg(0).uint_value(); - } - }); - EXPECT_THAT( - counter_offset_to_values, - UnorderedElementsAre( - Pair("Counter 1", - UnorderedElementsAre(Pair(tsl::profiler::UniToPico(1), 100), - Pair(tsl::profiler::UniToPico(2), 200))), - Pair("Counter 2", - UnorderedElementsAre(Pair(tsl::profiler::UniToPico(1), 300), - Pair(tsl::profiler::UniToPico(2), 400))))); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc deleted file mode 100644 index 27ccff245411e3..00000000000000 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc +++ /dev/null @@ -1,528 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h" - -#include - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/shape_util.h" -#include "xla/side_effect_util.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tensorflow/core/profiler/protobuf/topology.pb.h" -#include "tensorflow/core/profiler/utils/hlo_module_utils.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/platform/regexp.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "plugin/tensorboard_plugin_profile/protobuf/dcn_collective_info.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h" // from @org_xprof -#include "plugin/tensorboard_plugin_profile/protobuf/topology.pb.h" // from @org_xprof -#include "xprof/utils/hlo_proto_to_module.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using tensorflow::profiler::DcnSlackSummary; -using tensorflow::profiler::Topology; -using tsl::profiler::CreateTfXPlaneVisitor; -using tsl::profiler::FindLineWithName; -using tsl::profiler::kXlaOpLineName; -using tsl::profiler::NanoToMicro; -using tsl::profiler::PicoToMicro; -using tsl::profiler::SafeDivide; -using tsl::profiler::StatType; -using 
tsl::profiler::Timespan;
-using tsl::profiler::XEventContextTracker;
-using tsl::profiler::XEventVisitor;
-using tsl::profiler::XLineVisitor;
-using tsl::profiler::XPlaneVisitor;
-using tsl::profiler::XStatVisitor;
-using xla::HloOpcode;
-
-// TODO: Identify mechanism to maintain consistency between producer and
-// consumer here.
-const char kHostEventRegex[] = {
-    "device_[0-9]+([0-9][0-9][0-9][0-9][0-9])_gid_(.*)"};
-
-std::optional<std::string> GetAttributeFromInstr(
-    const xla::HloInstruction* instr, std::string_view attribute) {
-  std::optional<std::string> attribute_value;
-  if (instr->frontend_attributes().IsInitialized() &&
-      !instr->frontend_attributes().map().empty() &&
-      instr->frontend_attributes().map().contains(attribute)) {
-    attribute_value = instr->frontend_attributes().map().at(attribute);
-  }
-  return attribute_value;
-}
-std::optional<std::string> GetRendezvous(const xla::HloInstruction* instr) {
-  return GetAttributeFromInstr(instr, xla::kXlaHostTransferRendezvousNameAttr);
-}
-
-dcn_analysis_internal::DcnHostEvent ParseDcnHostEvent(
-    const XEventVisitor& visitor) {
-  dcn_analysis_internal::DcnHostEvent event;
-  static const LazyRE2 re = {kHostEventRegex};
-  RE2::FullMatch(visitor.Name(), *re, &event.multi_slice_device_id,
-                 &event.rendezvous_name);
-
-  event.timespan = visitor.GetTimespan();
-  return event;
-}
-
-std::optional<std::string> GetTransferType(const xla::HloInstruction* instr) {
-  return GetAttributeFromInstr(instr, "_xla_megascale_transfer_type");
-}
-
-std::string HostCollectiveKey(int index_on_host,
-                              std::string_view rendezvous_name) {
-  return absl::StrCat(index_on_host, "_", rendezvous_name);
-}
-
-DcnCollectiveInfoProto GetDcnCollectiveInfoProto(const XEventVisitor& xevent) {
-  DcnCollectiveInfoProto dcn_collective_info;
-  xevent.Metadata().ForEachStat([&](const XStatVisitor& xstat) {
-    if (static_cast<StatType>(*xstat.Type()) == StatType::kDcnCollectiveInfo) {
-      absl::string_view byte_value = xstat.BytesValue();
-      if (!dcn_collective_info.ParseFromArray(byte_value.data(),
-                                              byte_value.size())) {
-        LOG(WARNING) << "Could not parse DcnCollectiveInfoProto from metadata.";
-      }
-    }
-  });
-
-  return dcn_collective_info;
-}
-
-}  // namespace
-
-namespace dcn_analysis_internal {
-
-void DcnHostEventList::insert(DcnHostEvent event) {
-  if (iter_ != events_.end() && event.timespan < iter_->timespan) {
-    // The event being inserted is from a new line; reset the iterator to the
-    // beginning.
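-    // Retaining iter_ between calls makes insertion amortized O(1) while
-    // events from one line arrive in timestamp order; only a line change
-    // pays for rescanning from the front of the list.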
- iter_ = events_.begin(); - } - while (iter_ != events_.end() && iter_->timespan < event.timespan) { - iter_++; - } - iter_ = events_.insert(iter_, event); -} - -std::optional DcnHostEventList::pop(const Timespan& timespan) { - while (!events_.empty() && events_.front().timespan < timespan) { - events_.pop_front(); - } - - if (!events_.empty() && - (timespan.Includes(events_.front().timespan.begin_ps()) || - events_.front().timespan.Includes(timespan.begin_ps()))) { - DcnHostEvent front = events_.front(); - events_.pop_front(); - return front; - } else { - return std::nullopt; - } -} - -absl::StatusOr DcnTracker::GetInstrMetadataFromHloModule( - std::string_view module_name, std::string_view instr_name) { - if (!hlo_module_cache_.contains(module_name)) { - TF_ASSIGN_OR_RETURN(auto hlo_proto, - hlo_proto_map_.GetHloProtoByModuleName(module_name)); - TF_ASSIGN_OR_RETURN(auto module, ConvertHloProtoToModule(*hlo_proto)); - hlo_module_cache_[module_name] = std::move(module); - } - const auto& hlo_module = hlo_module_cache_[module_name]; - dcn_analysis_internal::InstrMetadata instr_metadata; - auto instr = FindInstruction(*hlo_module, std::string(instr_name)); - - instr_metadata.opcode = instr->opcode(); - instr_metadata.channel_id = instr->channel_id().value(); - instr_metadata.rendezvous_name = GetRendezvous(instr); - instr_metadata.transfer_type = GetTransferType(instr); - instr_metadata.size = 0; - if (instr->shape().IsArray()) { - instr_metadata.size = xla::ShapeUtil::ByteSizeOfElements(instr->shape()); - } else if (instr->shape().IsTuple()) { - for (const auto& shape : instr->shape().tuple_shapes()) { - instr_metadata.size += xla::ShapeUtil::ByteSizeOf(shape); - } - } - return instr_metadata; -} - -absl::StatusOr DcnTracker::GetInstructionMetadata( - std::string_view module, std::string_view instr) { - std::string key = absl::StrCat(module, "_", instr); - if (const auto& it = instruction_metadata_map_.find(key); - it != instruction_metadata_map_.end()) { - return it->second; - } - - absl::StatusOr instr_metadata = - GetInstrMetadataFromHloModule(module, instr); - if (instr_metadata.ok()) { - instruction_metadata_map_[key] = *instr_metadata; - } - - return instr_metadata; -} - -DcnSlackAnalysis DcnTracker::Finalize() { - SummarizeDcnSlackAnalysis(); - return slack_analysis_; -} - -void DcnTracker::DebugString() { - for (const DcnSlack& analysis : slack_analysis_.dcn_slack()) { - LOG(INFO) << analysis.rendezvous() << " : " << analysis.slack_us(); - } -} - -void DcnTracker::UpdateActiveOps(uint64_t duration) { - for (auto& [rendezvous, opState] : rendezvous_to_op_map_) { - opState.overlapping_duration += duration; - } -} - -int DcnTracker::GetReplicaGroupSize(const std::string& rendezvous_name, - const XEventVisitor& visitor) { - if (rendezvous_to_replica_group_size_map_.contains(rendezvous_name)) { - return rendezvous_to_replica_group_size_map_[rendezvous_name]; - } - - DcnCollectiveInfoProto dcn_collective_info = - GetDcnCollectiveInfoProto(visitor); - - if (dcn_collective_info.one_to_one_groups_size() != 0) { - // OneToOneGroup has a source and a destination, which is one replica group - rendezvous_to_replica_group_size_map_[rendezvous_name] = 1; - } else if (dcn_collective_info.endpoint_groups_size() != 0) { - rendezvous_to_replica_group_size_map_[rendezvous_name] = - dcn_collective_info.endpoint_groups(0).endpoints().size(); - } else { - rendezvous_to_replica_group_size_map_[rendezvous_name] = 0; - } - - return rendezvous_to_replica_group_size_map_[rendezvous_name]; -} - -// 
ComputeTransmittedDataSize is called with the buffer_size for recv-done. -uint64_t DcnTracker::ComputeTransmittedDataSize( - const int64_t recv_buffer_size, const int group_size, - const std::string& transfer_type) { - uint64_t transmitted_bytes = 0; - if (group_size == 0) { - LOG(ERROR) << "Replica group size is 0."; - return transmitted_bytes; - } - - if (transfer_type == "ONE_TO_ONE") { - transmitted_bytes = group_size * recv_buffer_size; - } else if (transfer_type == "ALL_GATHER") { - transmitted_bytes = - SafeDivide((group_size - 1) * recv_buffer_size, group_size); - } else if (transfer_type == "ALL_REDUCE") { - // Since the reduced buffer now has to be sent back to the replicas, - // the total bytes transmitted over the network is 2x the shape of the op. - transmitted_bytes = - 2 * SafeDivide(group_size - 1, group_size) * recv_buffer_size; - } else if (transfer_type == "ALL_TO_ALL") { - transmitted_bytes = - SafeDivide(group_size - 1, group_size) * recv_buffer_size; - } else if (transfer_type == "REDUCE_SCATTER") { - transmitted_bytes = recv_buffer_size * (group_size - 1); - } else { - LOG(ERROR) << "Unsupported transfer type: " << transfer_type; - } - return transmitted_bytes; -} - -void DcnTracker::VisitOp(const InstrMetadata& instr, - const XEventVisitor& visitor) { - std::string rendezvous_name; - if (instr.rendezvous_name.has_value()) { - rendezvous_name = *instr.rendezvous_name; - channel_id_to_rendezvous_map_[instr.channel_id] = rendezvous_name; - } else { - if (auto it = channel_id_to_rendezvous_map_.find(instr.channel_id); - it != channel_id_to_rendezvous_map_.end()) { - rendezvous_name = it->second; - } else { - // Ignore ops as we have not seen the corresponding send/recv. - return; - } - } - - DcnOpState& opState = rendezvous_to_op_map_[rendezvous_name]; - opState.stall_duration_ns += visitor.DurationNs(); - - switch (instr.opcode) { - case HloOpcode::kSend: - opState.start_time = visitor.TimestampNs(); - opState.rendezvous_name = rendezvous_name; - opState.transfer_type = - instr.transfer_type.has_value() ? 
*instr.transfer_type : ""; - opState.overlapping_duration = 0; - opState.stall_duration_ns = visitor.DurationNs(); - opState.send_op_name = visitor.DisplayName(); - opState.send.set_duration_ps(visitor.DurationPs()); - opState.send.set_start_time_ps(visitor.TimestampPs()); - opState.replica_group_size = - GetReplicaGroupSize(rendezvous_name, visitor); - break; - case HloOpcode::kRecv: - opState.recv.set_duration_ps(visitor.DurationPs()); - opState.recv.set_start_time_ps(visitor.TimestampPs()); - break; - case HloOpcode::kSendDone: - opState.send_done.set_duration_ps(visitor.DurationPs()); - opState.send_done.set_start_time_ps(visitor.TimestampPs()); - break; - case HloOpcode::kRecvDone: { - opState.recv_done.set_duration_ps(visitor.DurationPs()); - opState.recv_done.set_start_time_ps(visitor.TimestampPs()); - if (opState.start_time != 0) { - DcnSlack* analysis = slack_analysis_.add_dcn_slack(); - analysis->set_rendezvous(rendezvous_name); - analysis->set_transfer_type(opState.transfer_type); - analysis->set_send_start_time_us(NanoToMicro(opState.start_time)); - analysis->set_recv_done_end_time_us( - NanoToMicro(visitor.EndTimestampNs())); - analysis->set_slack_us(NanoToMicro(visitor.TimestampNs() - - opState.start_time - - opState.overlapping_duration)); - analysis->set_bytes_transmitted_over_network(ComputeTransmittedDataSize( - instr.size, opState.replica_group_size, opState.transfer_type)); - analysis->set_stall_duration_us(NanoToMicro(opState.stall_duration_ns)); - analysis->set_recv_op_name(std::string(visitor.DisplayName())); - analysis->set_send_op_name(opState.send_op_name); - *analysis->mutable_send() = opState.send; - *analysis->mutable_recv() = opState.recv; - *analysis->mutable_send_done() = opState.send_done; - *analysis->mutable_recv_done() = opState.recv_done; - } - - break; - } - default: - LOG(ERROR) << "Received unexpected op"; - } - UpdateActiveOps(visitor.DurationNs()); -} - -std::optional DcnTracker::GetCollectiveHostEvent( - int core_id, std::string_view rendezvous, Timespan timespan) { - return core_id_to_host_event_map_[HostCollectiveKey(core_id, rendezvous)].pop( - timespan); -} - -void DcnTracker::SummarizeDcnSlackAnalysis() { - absl::flat_hash_map summary; - // TODO(b/302596260) : Expand to process all cores. 
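-  // Until then, host events are matched against core 0 only.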
- int core_id = 0; - for (DcnSlack& analysis : *slack_analysis_.mutable_dcn_slack()) { - DcnSlackSummary& s = summary[analysis.rendezvous()]; - s.set_slack_us(s.slack_us() + analysis.slack_us()); - s.set_occurrences(s.occurrences() + 1); - s.set_rendezvous(analysis.rendezvous()); - s.set_transfer_type(analysis.transfer_type()); - s.set_bytes_transmitted_over_network( - analysis.bytes_transmitted_over_network()); - s.set_stall_duration_us(s.stall_duration_us() + - analysis.stall_duration_us()); - s.set_observed_duration_us(s.observed_duration_us() + - analysis.recv_done_end_time_us() - - analysis.send_start_time_us()); - s.set_recv_op_name(analysis.recv_op_name()); - s.set_send_op_name(analysis.send_op_name()); - s.set_send_duration_us(s.send_duration_us() + - PicoToMicro(analysis.send().duration_ps())); - s.set_recv_duration_us(s.recv_duration_us() + - PicoToMicro(analysis.recv().duration_ps()) / 1E6); - s.set_send_done_duration_us( - s.send_done_duration_us() + - PicoToMicro(analysis.send_done().duration_ps())); - s.set_recv_done_duration_us( - s.recv_done_duration_us() + - PicoToMicro(analysis.recv_done().duration_ps())); - - // Populate Host summary to DcnSlackSummary - std::optional host_event = GetCollectiveHostEvent( - core_id, analysis.rendezvous(), - Timespan::FromEndPoints(analysis.send().start_time_ps(), - analysis.recv_done().start_time_ps() + - analysis.recv_done().duration_ps())); - if (host_event.has_value()) { - OpInstance* host_graph_execution = - analysis.mutable_host_graph_execution(); - host_graph_execution->set_start_time_ps(host_event->timespan.begin_ps()); - host_graph_execution->set_duration_ps(host_event->timespan.duration_ps()); - s.set_host_stall_us(s.host_stall_us() + - (((int64_t)host_event->timespan.end_ps() - - (int64_t)analysis.recv_done().start_time_ps()) / - 1E6)); - s.set_host_events_count(s.host_events_count() + 1); - } - } - - for (auto& [_, s] : summary) { - s.set_slack_us(SafeDivide(s.slack_us(), s.occurrences())); - s.set_stall_duration_us(SafeDivide(s.stall_duration_us(), s.occurrences())); - s.set_observed_duration_us( - SafeDivide(s.observed_duration_us(), s.occurrences())); - s.set_send_done_duration_us( - SafeDivide(s.send_done_duration_us(), s.occurrences())); - s.set_recv_done_duration_us( - SafeDivide(s.recv_done_duration_us(), s.occurrences())); - s.set_send_duration_us(SafeDivide(s.send_duration_us(), s.occurrences())); - s.set_recv_duration_us(SafeDivide(s.recv_duration_us(), s.occurrences())); - s.set_host_stall_us(SafeDivide(s.host_stall_us(), s.host_events_count())); - *slack_analysis_.add_dcn_slack_summary() = s; - } -} - -void DcnTracker::ProcessTopology(const Topology& topology) { - for (const auto& mesh_location : topology.mesh_location()) { - global_chip_id_to_local_index_map_[mesh_location.global_id()] = - mesh_location.index_on_host(); - } -} - -int DcnTracker::GetLocalIndex(int dcn_device_id) { - /* Based on if megacore was present or not, the LocalIndex calculation will - * differ, - * dcn device id would use the global index in cases of megacore, and use - * 2*global_index (+1) for non megacore instances - * TODO(b/302145703): Identify if transformation can be obtained from the - * TpuTopology directly - */ - int global_device_id = dcn_device_id; - if (!is_megacore_) { - if (global_chip_id_to_local_index_map_.contains(global_device_id)) { - return global_chip_id_to_local_index_map_[dcn_device_id / 2] + - dcn_device_id % 2; - } - } - if (global_chip_id_to_local_index_map_.contains(global_device_id)) { - return 
global_chip_id_to_local_index_map_[global_device_id]; - } - LOG(WARNING) << "Could not map dcn_device_id to Local index, Using " - "dcn_device_id : " - << global_device_id; - return global_device_id; -} - -void DcnTracker::VisitHostEvent(const DcnHostEvent& event) { - std::string key = HostCollectiveKey( - GetLocalIndex(event.multi_slice_device_id), event.rendezvous_name); - if (event.rendezvous_name.empty()) return; - core_id_to_host_event_map_[key].insert(event); -} - -void ProcessDcnTraces(const XPlane& xplane, DcnTracker& dcn_tracker) { - XPlaneVisitor xplane_visitor = CreateTfXPlaneVisitor(&xplane); - HloProtoMap hlo_proto_map; - xplane_visitor.ForEachLine([&](const XLineVisitor& line) { - line.ForEachEvent([&](const XEventVisitor& event) { - dcn_tracker.VisitHostEvent(ParseDcnHostEvent(event)); - }); - }); -} - -} // namespace dcn_analysis_internal - -DcnSlackAnalysis ConvertXSpaceToDcnSlackAnalysis(const XSpace& xspace, - const XPlane* dcn_host_plane, - const Topology* topology, - bool is_megacore) { - int num_cores = tsl::profiler::FindTensorCorePlanes(xspace).size(); - if (num_cores == 0) return DcnSlackAnalysis(); - const XPlane* xplane = - FindPlaneWithName(xspace, tsl::profiler::TpuPlaneName(0)); - XPlaneVisitor xplane_visitor = CreateTfXPlaneVisitor(xplane); - HloProtoMap hlo_proto_map; - hlo_proto_map.AddHloProtosFromXSpace(xspace); - dcn_analysis_internal::DcnTracker dcn_tracker(hlo_proto_map, is_megacore); - XEventContextTracker hlo_module_context( - &xplane_visitor, - FindLineWithName(*xplane, tsl::profiler::kXlaModuleLineName)); - xplane_visitor.ForEachLine([&](const XLineVisitor& xline) { - if (xline.Name() == kXlaOpLineName) { - xline.ForEachEvent([&](const XEventVisitor& xevent) { - std::string_view hlo_category; - - xevent.Metadata().ForEachStat([&](const XStatVisitor& xstat) { - switch (static_cast(*xstat.Type())) { - case StatType::kHloCategory: - hlo_category = xstat.StrOrRefValue(); - break; - default: - break; - } - }); - auto module = - hlo_module_context.GetContainingEvent(xevent.GetTimespan()); - if (!module.has_value()) return; - if (absl::StrContains(hlo_category, "host send") || - absl::StrContains(hlo_category, "host recv")) { - // All Dcn send/send-done/recv/recv-done ops. - auto instr = dcn_tracker.GetInstructionMetadata(module->Name(), - xevent.DisplayName()); - if (instr.ok()) { - dcn_tracker.VisitOp(*instr, xevent); - } - } - }); - } - }); - - if (dcn_host_plane != nullptr) { - VLOG(1) << "Processing host traces."; - if (topology != nullptr) { - dcn_tracker.ProcessTopology(*topology); - } - ProcessDcnTraces(*dcn_host_plane, dcn_tracker); - } - return dcn_tracker.Finalize(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h deleted file mode 100644 index 8f98c452e0eace..00000000000000 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XSPACE_TO_DCN_SLACK_ANALYSIS_H_
-#define TENSORFLOW_CORE_PROFILER_CONVERT_XSPACE_TO_DCN_SLACK_ANALYSIS_H_
-
-#include <cstdint>
-#include <list>
-#include <memory>
-#include <optional>
-#include <string>
-#include <string_view>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/statusor.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/tsl/profiler/utils/timespan.h"
-#include "xla/tsl/profiler/utils/xplane_visitor.h"
-#include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h"
-#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h"
-#include "tensorflow/core/profiler/protobuf/topology.pb.h"
-#include "tensorflow/core/profiler/utils/hlo_proto_map.h"
-#include "tsl/profiler/protobuf/xplane.pb.h"
-#include "plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.pb.h"  // from @org_xprof
-#include "plugin/tensorboard_plugin_profile/protobuf/topology.pb.h"  // from @org_xprof
-
-namespace tensorflow {
-namespace profiler {
-
-using tensorflow::profiler::DcnSlackAnalysis;
-
-namespace dcn_analysis_internal {
-
-struct DcnOpState {
-  uint64_t start_time = 0;
-  uint64_t end_time = 0;
-
-  // Combined duration of overlapping send/send-done/recv/recv-done ops that
-  // needs to be subtracted from the total duration.
-  uint64_t overlapping_duration = 0;
-  std::string rendezvous_name;
-  std::string transfer_type;
-  uint64_t stall_duration_ns = 0;
-  std::string send_op_name;
-  int replica_group_size = 0;
-
-  OpInstance send;
-  OpInstance send_done;
-  OpInstance recv;
-  OpInstance recv_done;
-};
-
-// Structure to extract and store the DcnHostEvents.
-struct DcnHostEvent {
-  std::string rendezvous_name;
-  tsl::profiler::Timespan timespan;
-  int multi_slice_device_id;
-};
-
-// When visiting DcnHostEvents from the megascale planes, the events are
-// stored in separate lines, each sorted in ascending time order. This list
-// allows insertion of multiple such sorted runs.
-class DcnHostEventList {
- public:
-  // Inserts the event into the sorted list.
-  void insert(DcnHostEvent event);
-
-  // Pops the front event if it overlaps the given timespan; events that end
-  // before the timespan are discarded.
-  std::optional<DcnHostEvent> pop(const tsl::profiler::Timespan& timespan);
-
-  // Number of events.
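-  // (Constant time: std::list::size() is O(1) since C++11.)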
-  int size() const { return events_.size(); }
-
- private:
-  std::list<DcnHostEvent> events_;
-  std::list<DcnHostEvent>::iterator iter_ = events_.begin();
-};
-
-struct InstrMetadata {
-  xla::HloOpcode opcode;
-  uint64_t channel_id;
-  std::optional<std::string> rendezvous_name;
-  int64_t size = 0;
-  std::optional<std::string> transfer_type;
-};
-
-class DcnTracker {
- public:
-  explicit DcnTracker(const tensorflow::profiler::HloProtoMap& hlo_proto_map,
-                      bool is_megacore)
-      : hlo_proto_map_(hlo_proto_map), is_megacore_(is_megacore) {}
-
-  absl::StatusOr<InstrMetadata> GetInstructionMetadata(std::string_view module,
-                                                       std::string_view instr);
-
-  DcnSlackAnalysis Finalize();
-
-  void DebugString();
-
-  void VisitOp(const InstrMetadata& instr,
-               const tsl::profiler::XEventVisitor& visitor);
-
-  void VisitHostEvent(const DcnHostEvent& event);
-
-  void ProcessTopology(const tensorflow::profiler::Topology& topology);
-
- private:
-  DcnSlackAnalysis slack_analysis_;
-  absl::flat_hash_map<std::string, DcnOpState> rendezvous_to_op_map_;
-  absl::flat_hash_map<uint64_t, std::string> channel_id_to_rendezvous_map_;
-  absl::flat_hash_map<std::string, InstrMetadata> instruction_metadata_map_;
-  absl::flat_hash_map<std::string, DcnHostEventList> core_id_to_host_event_map_;
-  const tensorflow::profiler::HloProtoMap& hlo_proto_map_;
-  absl::flat_hash_map<int, int> global_chip_id_to_local_index_map_;
-  absl::flat_hash_map<std::string, std::unique_ptr<xla::HloModule>>
-      hlo_module_cache_;
-  absl::flat_hash_map<std::string, int> rendezvous_to_replica_group_size_map_;
-  bool is_megacore_ = true;
-
-  absl::StatusOr<InstrMetadata> GetInstrMetadataFromHloModule(
-      std::string_view module, std::string_view instr);
-
-  void UpdateActiveOps(uint64_t duration);
-
-  void SummarizeDcnSlackAnalysis();
-
-  std::optional<DcnHostEvent> GetCollectiveHostEvent(
-      int core_id, std::string_view rendezvous_name,
-      tsl::profiler::Timespan timespan);
-
-  // Returns the local index when available; otherwise returns the
-  // global_device_id itself.
-  int GetLocalIndex(int dcn_device_id);
-
-  // Returns the replica group size for the given rendezvous.
-  int GetReplicaGroupSize(const std::string& rendezvous_name,
-                          const tsl::profiler::XEventVisitor& visitor);
-
-  // Computes the transmitted data size based on the replica group size and
-  // transfer type.
-  uint64_t ComputeTransmittedDataSize(int64_t buffer_size, int group_size,
-                                      const std::string& transfer_type);
-};
-
-}  // namespace dcn_analysis_internal
-
-// Converts HLO events in an XSpace to a DCN slack analysis.
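-// dcn_host_plane and topology may be null. When provided, the topology is
-// used to map DCN device ids to per-host indices so host events can be
-// matched to collectives; is_megacore controls that mapping.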
-DcnSlackAnalysis ConvertXSpaceToDcnSlackAnalysis( - const tsl::profiler::XSpace& xspace, - const tsl::profiler::XPlane* dcn_host_plane, - const tensorflow::profiler::Topology* topology, bool is_megacore = true); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XSPACE_TO_DCN_SLACK_ANALYSIS_H_ diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 7e52d2c96a6bbb..76cafa8a8aa196 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -111,6 +111,8 @@ cc_library( hdrs = ["profiler_controller.h"], deps = [ "@com_google_absl//absl/base:core_headers", + "@local_tsl//tsl/profiler/lib:profiler_controller", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index 5cc755babb4057..4767590ee13097 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -129,12 +129,12 @@ tf_python_pybind_extension( ], deps = [ ":profiler_pywrap_impl", - "//tensorflow/core/profiler/convert:repository", - "//tensorflow/core/profiler/convert:tool_options", - "//tensorflow/core/profiler/convert:xplane_to_tools_data", "//tensorflow/core/profiler/rpc:profiler_server_for_pybind", "//tensorflow/python/lib/core:pybind11_status", "@com_google_absl//absl/status", + "@org_xprof//xprof/convert:repository", + "@org_xprof//xprof/convert:tool_options", + "@org_xprof//xprof/convert:xplane_to_tools_data", "@pybind11", ], ) @@ -165,14 +165,13 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/convert:xplane_to_tools_data", "//tensorflow/core/profiler/lib:profiler_session_for_pybind", "//tensorflow/core/profiler/rpc:profiler_server_for_pybind", - "//tensorflow/core/profiler/rpc/client:save_profile", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", + "@local_tsl//tsl/profiler/lib:profiler_session", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/convert:xplane_to_trace_events", "@local_xla//xla/tsl/profiler/rpc/client:capture_profile", @@ -196,9 +195,6 @@ tsl_pybind_extension( "//tensorflow/core/framework:attr_value_proto_cc_impl", "//tensorflow/core/framework:op", "//tensorflow/core/framework:tensor", - "//tensorflow/core/profiler/convert:repository", - "//tensorflow/core/profiler/convert:tool_options", - "//tensorflow/core/profiler/convert:xplane_to_tools_data", "//tensorflow/python/lib/core:py_exception_registry", "//tensorflow/python/lib/core:pybind11_status", "@com_google_absl//absl/container:flat_hash_map", @@ -239,6 +235,7 @@ tsl_pybind_extension( "@local_xla//xla/tsl/protobuf:histogram_proto_cc_impl", "@local_xla//xla/tsl/protobuf:rpc_options_proto_cc_impl", "@local_xla//xla/tsl/protobuf:test_log_proto_cc_impl", + "@org_xprof//xprof/convert:tool_options", "@org_xprof//xprof/pywrap:profiler_plugin_impl", "@pybind11", ] + if_macos([ diff --git a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc index 8ef893b13fe176..e5a640b931063a 100644 --- a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc +++ b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc @@ -21,16 +21,12 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/profiler/convert/xplane_to_trace_events.h" #include "xla/tsl/profiler/rpc/client/capture_profile.h" #include "xla/tsl/profiler/utils/session_manager.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/xplane_to_tools_data.h" -#include "tensorflow/core/profiler/rpc/client/save_profile.h" -#include "tensorflow/core/profiler/rpc/profiler_server.h" +#include "tsl/profiler/lib/profiler_session.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { @@ -38,7 +34,6 @@ namespace profiler { namespace pywrap { using tsl::profiler::GetRemoteSessionManagerOptionsLocked; -using tsl::profiler::ValidateHostPortPair; absl::Status ProfilerSessionWrapper::Start( const char* logdir, diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index 409d8fec75bdac..359a0f7446053e 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -20,12 +20,12 @@ limitations under the License. #include "absl/status/status.h" #include "pybind11/pybind11.h" // from @pybind11 -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/tool_options.h" -#include "tensorflow/core/profiler/convert/xplane_to_tools_data.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" #include "tensorflow/python/lib/core/pybind11_status.h" #include "tensorflow/python/profiler/internal/profiler_pywrap_impl.h" +#include "xprof/convert/repository.h" // from @org_xprof +#include "xprof/convert/tool_options.h" // from @org_xprof +#include "xprof/convert/xplane_to_tools_data.h" // from @org_xprof namespace py = ::pybind11; diff --git a/tensorflow/python/profiler/internal/pywrap_profiler_plugin.cc b/tensorflow/python/profiler/internal/pywrap_profiler_plugin.cc index 26d9561e4b7d3e..05e07683a251ac 100644 --- a/tensorflow/python/profiler/internal/pywrap_profiler_plugin.cc +++ b/tensorflow/python/profiler/internal/pywrap_profiler_plugin.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "xla/pjrt/status_casters.h" #include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/rpc/client/capture_profile.h" -#include "tensorflow/core/profiler/convert/tool_options.h" +#include "xprof/convert/tool_options.h" // from @org_xprof #include "xprof/pywrap/profiler_plugin_impl.h" // from @org_xprof namespace py = ::pybind11; diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 6ce2f4be5b2a80..c908483fecc8fe 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -927,9 +927,9 @@ def _tf_repositories(): tf_http_archive( name = "org_xprof", - sha256 = "dec4889a6a5123fca0a775ba20f22717b2d0c3af1491f41bb52e1b502595271e", - strip_prefix = "xprof-c3dbeb2c69b48163c6156d6f4a8c82ac34736f49", - urls = tf_mirror_urls("https://github.com/openxla/xprof/archive/c3dbeb2c69b48163c6156d6f4a8c82ac34736f49.zip"), + sha256 = "a14e688d4145b4964bf1e9deac4cf52e0baadfb77906d513ebd397a43fa06d1f", + strip_prefix = "xprof-81409fb324525ba73ff204b7702db3f436773430", + urls = tf_mirror_urls("https://github.com/openxla/xprof/archive/81409fb324525ba73ff204b7702db3f436773430.zip"), ) # used for adding androidx.annotation dependencies in tflite android jni. From d106eb7157080ebc011ba7bc5f934a2538d75f5d Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 15 Apr 2025 16:38:49 -0700 Subject: [PATCH 0841/1324] Add check for client mismatch when constructing arrays and then fix broken codepath. PiperOrigin-RevId: 748068869 --- .../xla/xla/python/pjrt_ifrt/pjrt_array.cc | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 1e2b59f31bd2c5..65c65c9e5f851d 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -69,7 +69,7 @@ static const xla::ifrt::MemoryKind kPinnedHostMemoryKind( // Validates the sharding and PjRtBuffers have consistent device and memory // kind. 
 absl::Status ValidateArrayCreationInput(
-    std::shared_ptr<const Sharding> sharding,
+    PjRtCompatibleClient* client, std::shared_ptr<const Sharding> sharding,
     const PjRtArray::PjRtBuffers& pjrt_buffers) {
   absl::Span<Device* const> sharding_devices =
       sharding->devices()->AddressableDeviceList()->devices();
@@ -90,6 +90,11 @@ absl::Status ValidateArrayCreationInput(
     if (!device) {
       return InvalidArgument("Sharding device %d is not a PjRtDevice", i);
     }
+    if (device->client() != client) {
+      return InvalidArgument(
+          "Sharding client does not match array client: %s vs %s",
+          sharding_devices[i]->DebugString(), client->platform_version());
+    }
     if (pjrt_buffers[i]->device() != device->pjrt_device()) {
       return InvalidArgument(
           "PjRtBuffer's memory space is addressed by device %s vs sharding is "
@@ -148,7 +153,8 @@ absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
     PjRtCompatibleClient* client, DType dtype, Shape shape,
     std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers,
     std::shared_ptr<const PjRtLayout> layout) {
-  TF_RETURN_IF_ERROR(ValidateArrayCreationInput(sharding, pjrt_buffers));
+  TF_RETURN_IF_ERROR(
+      ValidateArrayCreationInput(client, sharding, pjrt_buffers));
   return tsl::MakeRef<PjRtArray>(client, dtype, std::move(shape),
                                  std::move(sharding), std::move(pjrt_buffers),
                                  std::move(layout));
@@ -158,7 +164,8 @@ absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
     PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape,
     std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers,
     std::shared_ptr<const PjRtLayout> layout) {
-  TF_RETURN_IF_ERROR(ValidateArrayCreationInput(sharding, pjrt_buffers));
+  TF_RETURN_IF_ERROR(
+      ValidateArrayCreationInput(client, sharding, pjrt_buffers));
   return tsl::MakeRef<PjRtArray>(client, dtype, std::move(dynamic_shape),
                                  std::move(sharding), std::move(pjrt_buffers),
                                  std::move(layout));
@@ -437,6 +444,7 @@ absl::StatusOr<tsl::RCReference<Array>> PjRtArray::Copy(
       canonicalized_sharding_memory_kind.memory_kind().has_value();
   const absl::Span<Device* const> new_sharding_devices =
       new_sharding->devices()->devices();
+  PjRtCompatibleClient* new_client = nullptr;
   for (int i = 0; i < pjrt_buffers_.size(); ++i) {
     TF_ASSIGN_OR_RETURN(Device * buffer_device,
                         client_->LookupPjRtDevice(pjrt_buffers_[i]->device()));
@@ -473,6 +481,9 @@ absl::StatusOr<tsl::RCReference<Array>> PjRtArray::Copy(
     } else {
       PjRtCompatibleDevice* pjrt_device =
           llvm::dyn_cast<PjRtCompatibleDevice>(new_sharding_devices[i]);
+      if (pjrt_device) {
+        new_client = llvm::dyn_cast<PjRtCompatibleClient>(pjrt_device->client());
+      }
       if (!pjrt_device) {
         return InvalidArgument(
             "The destination device is owned by a non-PjRt-compatible client. "
@@ -510,9 +521,12 @@ absl::StatusOr<tsl::RCReference<Array>> PjRtArray::Copy(
       buffers.push_back(std::move(copied_buffer));
     }
   }
+  if (new_client == nullptr) {
+    new_client = client_;
+  }
   return std::visit(
-      [this, &new_sharding, &buffers](const auto& shape) {
-        return PjRtArray::Create(client_, dtype_, shape,
+      [this, new_client, &new_sharding, &buffers](const auto& shape) {
+        return PjRtArray::Create(new_client, dtype_, shape,
                                  std::move(new_sharding), std::move(buffers),
                                  layout_);
       },

From 97f37848ed991702b772fbdbd0b61b46d6bff008 Mon Sep 17 00:00:00 2001
From: "A.
Unique TensorFlower" Date: Tue, 15 Apr 2025 16:38:55 -0700 Subject: [PATCH 0842/1324] Integrate LLVM at llvm/llvm-project@179d30f8c3fd Updates LLVM usage to match [179d30f8c3fd](https://github.com/llvm/llvm-project/commit/179d30f8c3fd) PiperOrigin-RevId: 748068900 --- .../tensorflow/utils/dump_mlir_util_test.cc | 1 - third_party/llvm/generated.patch | 169 ++++++++--- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 263 ++++++++++++++---- third_party/shardy/workspace.bzl | 4 +- third_party/stablehlo/temporary.patch | 24 ++ .../triton/llvm_integration/cl747619712.patch | 112 ++++++++ .../triton/llvm_integration/series.bzl | 1 + .../xla/third_party/shardy/temporary.patch | 263 ++++++++++++++---- .../xla/third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/stablehlo/temporary.patch | 24 ++ .../triton/llvm_integration/cl747619712.patch | 112 ++++++++ .../triton/llvm_integration/series.bzl | 1 + .../emitters/transforms/convert_float_amd.cc | 22 +- .../transforms/convert_float_nvidia.cc | 8 +- 15 files changed, 845 insertions(+), 167 deletions(-) create mode 100644 third_party/triton/llvm_integration/cl747619712.patch create mode 100644 third_party/xla/third_party/triton/llvm_integration/cl747619712.patch diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index 2efd63b29b04ef..aa818d2ae73bd2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -126,7 +126,6 @@ TEST(DumpCrashReproducerTest, RoundtripDumpAndReadValid) { registry, mlir::MlirOptMainConfig{} .splitInputFile("") - .verifyDiagnostics(false) .verifyPasses(false) .allowUnregisteredDialects(false) .setPassPipelineParser(passPipeline)) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index bbffc2ff4b7cc3..436c4e97f096ef 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,49 +1,132 @@ Auto generated patch. Do not edit or delete it, even if empty. 
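The first SemaExprCXX hunk below replaces an implicit enum-to-bool conversion with an explicit comparison against Sema::AR_inaccessible, so only a genuinely denied access reads as a failure. A minimal standalone sketch of the pitfall follows; the enum is a hypothetical stand-in modeled on Clang's AccessResult, not Clang's actual API:

// Sketch only: why returning an enum from a bool-returning function is
// error-prone. AccessResult below is hypothetical, modeled on
// Sema::AccessResult.
#include <cassert>

enum AccessResult { AR_accessible, AR_inaccessible, AR_dependent };

// Buggy form: the enum converts to bool, so every nonzero enumerator
// (including AR_dependent) is reported as a denial.
bool IsDeniedBuggy(AccessResult r) { return r; }

// Fixed form: compare against the one enumerator that actually means
// "denied".
bool IsDeniedFixed(AccessResult r) { return r == AR_inaccessible; }

int main() {
  assert(IsDeniedBuggy(AR_dependent));     // wrong: dependent is not a denial
  assert(!IsDeniedFixed(AR_dependent));    // correct under the fix
  assert(IsDeniedFixed(AR_inaccessible));  // denial still detected
  return 0;
}

This matches the accompanying regression test, which checks that an accessible operator delete is no longer misreported as inaccessible.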
-diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h ---- a/libc/src/__support/FPUtil/aarch64/sqrt.h -+++ b/libc/src/__support/FPUtil/aarch64/sqrt.h -@@ -18,6 +18,8 @@ - #error "Invalid include" - #endif +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp +--- a/clang/lib/Sema/SemaExprCXX.cpp ++++ b/clang/lib/Sema/SemaExprCXX.cpp +@@ -1929,8 +1929,9 @@ + } + return true; + } +- +- return S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); ++ Sema::AccessResult Accessible = ++ S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); ++ return Accessible == Sema::AR_inaccessible; + } -+#include "src/__support/FPUtil/generic/sqrt.h" -+ - namespace LIBC_NAMESPACE_DECL { - namespace fputil { - -diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h ---- a/libc/src/__support/FPUtil/arm/sqrt.h -+++ b/libc/src/__support/FPUtil/arm/sqrt.h -@@ -18,6 +18,8 @@ - #error "Invalid include" - #endif + /// Select the correct "usual" deallocation function to use from a selection of +diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp +--- a/clang/lib/Serialization/ASTReaderStmt.cpp ++++ b/clang/lib/Serialization/ASTReaderStmt.cpp +@@ -2226,10 +2226,7 @@ + E->AssociatedDeclAndRef.setPointer(readDeclAs()); + E->AssociatedDeclAndRef.setInt(CurrentUnpackingBits->getNextBit()); + E->Index = CurrentUnpackingBits->getNextBits(/*Width=*/12); +- if (CurrentUnpackingBits->getNextBit()) +- E->PackIndex = Record.readInt(); +- else +- E->PackIndex = 0; ++ E->PackIndex = Record.readUnsignedOrNone().toInternalRepresentation(); + E->Final = CurrentUnpackingBits->getNextBit(); + E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); + E->Replacement = Record.readSubExpr(); +@@ -2239,6 +2236,7 @@ + SubstNonTypeTemplateParmPackExpr *E) { + VisitExpr(E); + E->AssociatedDecl = readDeclAs(); ++ E->Final = CurrentUnpackingBits->getNextBit(); + E->Index = Record.readInt(); + TemplateArgument ArgPack = Record.readTemplateArgument(); + if (ArgPack.getKind() != TemplateArgument::Pack) +diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp +--- a/clang/lib/Serialization/ASTWriterStmt.cpp ++++ b/clang/lib/Serialization/ASTWriterStmt.cpp +@@ -2228,9 +2228,7 @@ + Record.AddDeclRef(E->getAssociatedDecl()); + CurrentPackingBits.addBit(E->isReferenceParameter()); + CurrentPackingBits.addBits(E->getIndex(), /*Width=*/12); +- CurrentPackingBits.addBit((bool)E->getPackIndex()); +- if (auto PackIndex = E->getPackIndex()) +- Record.push_back(*PackIndex + 1); ++ Record.writeUnsignedOrNone(E->getPackIndex()); + CurrentPackingBits.addBit(E->getFinal()); -+#include "src/__support/FPUtil/generic/sqrt.h" + Record.AddSourceLocation(E->getNameLoc()); +@@ -2242,6 +2240,7 @@ + SubstNonTypeTemplateParmPackExpr *E) { + VisitExpr(E); + Record.AddDeclRef(E->getAssociatedDecl()); ++ CurrentPackingBits.addBit(E->getFinal()); + Record.push_back(E->getIndex()); + Record.AddTemplateArgument(E->getArgumentPack()); + Record.AddSourceLocation(E->getParameterPackLocation()); +diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bug135668.cpp b/clang/test/CodeGenCXX/bug135668.cpp +--- a/clang/test/CodeGenCXX/bug135668.cpp ++++ b/clang/test/CodeGenCXX/bug135668.cpp +@@ -0,0 +1,38 @@ ++// RUN: %clang_cc1 %s -triple arm64-apple-macosx 
-emit-llvm -fcxx-exceptions -fexceptions -std=c++23 -o - | FileCheck %s + - namespace LIBC_NAMESPACE_DECL { - namespace fputil { - -diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h ---- a/libc/src/__support/FPUtil/riscv/sqrt.h -+++ b/libc/src/__support/FPUtil/riscv/sqrt.h -@@ -18,6 +18,8 @@ - #error "Invalid include" - #endif - -+#include "src/__support/FPUtil/generic/sqrt.h" ++class TestClass { ++ public: ++ TestClass(); ++ int field = 0; ++ friend class Foo; ++ static void * operator new(unsigned long size); ++ private: ++ static void operator delete(void *p); ++ }; + - namespace LIBC_NAMESPACE_DECL { - namespace fputil { - -diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h ---- a/libc/src/__support/FPUtil/x86_64/sqrt.h -+++ b/libc/src/__support/FPUtil/x86_64/sqrt.h -@@ -18,6 +18,8 @@ - #error "sqrtss / sqrtsd need SSE2" - #endif - -+#include "src/__support/FPUtil/generic/sqrt.h" ++class Foo { ++public: ++ int test_method(); ++}; + - namespace LIBC_NAMESPACE_DECL { - namespace fputil { - ++int Foo::test_method() { ++ TestClass *obj = new TestClass() ; ++ return obj->field; ++} ++ ++// CHECK-LABEL: define noundef i32 @_ZN3Foo11test_methodEv ++// CHECK: [[THIS_ADDR:%.*]] = alloca ptr, align 8 ++// CHECK: [[OBJ:%.*]] = alloca ptr, align 8 ++// CHECK: store ptr %this, ptr [[THIS_ADDR]], align 8 ++// CHECK: [[THIS1:%.*]] = load ptr, ptr [[THIS_ADDR]], align 8 ++// CHECK: [[ALLOCATION:%.*]] = call noundef ptr @_ZN9TestClassnwEm(i64 noundef 4) ++// CHECK: [[INITIALIZEDOBJ:%.*]] = invoke noundef ptr @_ZN9TestClassC1Ev(ptr noundef nonnull align 4 dereferenceable(4) [[ALLOCATION]]) ++// CHECK-NEXT: to label %[[INVOKE_CONT:.*]] unwind label %[[LPAD:.*]] ++// CHECK: [[INVOKE_CONT]]: ++// CHECK: store ptr [[ALLOCATION]], ptr [[OBJ]], align 8 ++// CHECK: [[OBJPTR:%.*]] = load ptr, ptr [[OBJ]], align 8 ++// CHECK: [[FIELDPTR:%.*]] = getelementptr inbounds nuw %class.TestClass, ptr [[OBJPTR]], i32 0, i32 0 ++// CHECK: [[FIELD:%.*]] = load i32, ptr [[FIELDPTR]], align 4 ++// CHECK: ret i32 [[FIELD]] ++// CHECK: [[LPAD]]: ++// CHECK: call void @_ZN9TestClassdlEPv(ptr noundef [[ALLOCATION]]) #3 +diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/bug135668.cpp b/clang/test/SemaCXX/bug135668.cpp +--- a/clang/test/SemaCXX/bug135668.cpp ++++ b/clang/test/SemaCXX/bug135668.cpp +@@ -0,0 +1,25 @@ ++// RUN: %clang_cc1 -triple arm64-apple-macosx -Wall -fsyntax-only -verify %s -std=c++26 -fexceptions -fcxx-exceptions ++// expected-no-diagnostics ++ ++// This test makes sure that we don't erroneously consider an accessible operator ++// delete to be inaccessible, and then discard the entire new expression. 
++ ++class TestClass { ++public: ++ TestClass(); ++ int field = 0; ++ friend class Foo; ++ static void * operator new(unsigned long size); ++private: ++ static void operator delete(void *p); ++}; ++ ++class Foo { ++public: ++ int test_method(); ++}; ++ ++int Foo::test_method() { ++ TestClass *obj = new TestClass() ; ++ return obj->field; ++} diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 0b67d8b3fd140f..3ec4c3ec618e1f 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" - LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" + LLVM_COMMIT = "179d30f8c3fddd3c85056fd2b8e877a4a8513158" + LLVM_SHA256 = "39f33d0ba77ca40d254c767519a0f3f5692c2caa271f413e7245ab63d0787bd5" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 34a45370f62ef2..4779b912722730 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,69 +1,230 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 509398d..bbffc2f 100644 +index bbffc2f..436c4e9 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1 +1,49 @@ +@@ -1,49 +1,132 @@ Auto generated patch. Do not edit or delete it, even if empty. -+diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h -+--- a/libc/src/__support/FPUtil/aarch64/sqrt.h -++++ b/libc/src/__support/FPUtil/aarch64/sqrt.h -+@@ -18,6 +18,8 @@ -+ #error "Invalid include" -+ #endif -+ -++#include "src/__support/FPUtil/generic/sqrt.h" +-diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h +---- a/libc/src/__support/FPUtil/aarch64/sqrt.h +-+++ b/libc/src/__support/FPUtil/aarch64/sqrt.h +-@@ -18,6 +18,8 @@ +- #error "Invalid include" +- #endif ++diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp ++--- a/clang/lib/Sema/SemaExprCXX.cpp +++++ b/clang/lib/Sema/SemaExprCXX.cpp ++@@ -1929,8 +1929,9 @@ ++ } ++ return true; ++ } ++- ++- return S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); +++ Sema::AccessResult Accessible = +++ S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); +++ return Accessible == Sema::AR_inaccessible; ++ } + +-+#include "src/__support/FPUtil/generic/sqrt.h" +-+ +- namespace LIBC_NAMESPACE_DECL { +- namespace fputil { +- +-diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h +---- a/libc/src/__support/FPUtil/arm/sqrt.h +-+++ b/libc/src/__support/FPUtil/arm/sqrt.h +-@@ -18,6 +18,8 @@ +- #error "Invalid include" +- #endif ++ /// Select the correct "usual" deallocation function to use from a selection of ++diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ++--- a/clang/lib/Serialization/ASTReaderStmt.cpp +++++ b/clang/lib/Serialization/ASTReaderStmt.cpp ++@@ -2226,10 +2226,7 @@ ++ E->AssociatedDeclAndRef.setPointer(readDeclAs()); ++ E->AssociatedDeclAndRef.setInt(CurrentUnpackingBits->getNextBit()); ++ E->Index = CurrentUnpackingBits->getNextBits(/*Width=*/12); ++- if (CurrentUnpackingBits->getNextBit()) ++- E->PackIndex = Record.readInt(); 
++- else ++- E->PackIndex = 0; +++ E->PackIndex = Record.readUnsignedOrNone().toInternalRepresentation(); ++ E->Final = CurrentUnpackingBits->getNextBit(); ++ E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); ++ E->Replacement = Record.readSubExpr(); ++@@ -2239,6 +2236,7 @@ ++ SubstNonTypeTemplateParmPackExpr *E) { ++ VisitExpr(E); ++ E->AssociatedDecl = readDeclAs(); +++ E->Final = CurrentUnpackingBits->getNextBit(); ++ E->Index = Record.readInt(); ++ TemplateArgument ArgPack = Record.readTemplateArgument(); ++ if (ArgPack.getKind() != TemplateArgument::Pack) ++diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ++--- a/clang/lib/Serialization/ASTWriterStmt.cpp +++++ b/clang/lib/Serialization/ASTWriterStmt.cpp ++@@ -2228,9 +2228,7 @@ ++ Record.AddDeclRef(E->getAssociatedDecl()); ++ CurrentPackingBits.addBit(E->isReferenceParameter()); ++ CurrentPackingBits.addBits(E->getIndex(), /*Width=*/12); ++- CurrentPackingBits.addBit((bool)E->getPackIndex()); ++- if (auto PackIndex = E->getPackIndex()) ++- Record.push_back(*PackIndex + 1); +++ Record.writeUnsignedOrNone(E->getPackIndex()); ++ CurrentPackingBits.addBit(E->getFinal()); + +-+#include "src/__support/FPUtil/generic/sqrt.h" ++ Record.AddSourceLocation(E->getNameLoc()); ++@@ -2242,6 +2240,7 @@ ++ SubstNonTypeTemplateParmPackExpr *E) { ++ VisitExpr(E); ++ Record.AddDeclRef(E->getAssociatedDecl()); +++ CurrentPackingBits.addBit(E->getFinal()); ++ Record.push_back(E->getIndex()); ++ Record.AddTemplateArgument(E->getArgumentPack()); ++ Record.AddSourceLocation(E->getParameterPackLocation()); ++diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bug135668.cpp b/clang/test/CodeGenCXX/bug135668.cpp ++--- a/clang/test/CodeGenCXX/bug135668.cpp +++++ b/clang/test/CodeGenCXX/bug135668.cpp ++@@ -0,0 +1,38 @@ +++// RUN: %clang_cc1 %s -triple arm64-apple-macosx -emit-llvm -fcxx-exceptions -fexceptions -std=c++23 -o - | FileCheck %s + + +- namespace LIBC_NAMESPACE_DECL { +- namespace fputil { +- +-diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h +---- a/libc/src/__support/FPUtil/riscv/sqrt.h +-+++ b/libc/src/__support/FPUtil/riscv/sqrt.h +-@@ -18,6 +18,8 @@ +- #error "Invalid include" +- #endif +- +-+#include "src/__support/FPUtil/generic/sqrt.h" +++class TestClass { +++ public: +++ TestClass(); +++ int field = 0; +++ friend class Foo; +++ static void * operator new(unsigned long size); +++ private: +++ static void operator delete(void *p); +++ }; + + +- namespace LIBC_NAMESPACE_DECL { +- namespace fputil { +- +-diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h +---- a/libc/src/__support/FPUtil/x86_64/sqrt.h +-+++ b/libc/src/__support/FPUtil/x86_64/sqrt.h +-@@ -18,6 +18,8 @@ +- #error "sqrtss / sqrtsd need SSE2" +- #endif +- +-+#include "src/__support/FPUtil/generic/sqrt.h" +++class Foo { +++public: +++ int test_method(); +++}; + + +- namespace LIBC_NAMESPACE_DECL { +- namespace fputil { +- +++int Foo::test_method() { +++ TestClass *obj = new TestClass() ; +++ return obj->field; +++} ++ -+ namespace LIBC_NAMESPACE_DECL { -+ namespace fputil { -+ -+diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h -+--- a/libc/src/__support/FPUtil/arm/sqrt.h -++++ b/libc/src/__support/FPUtil/arm/sqrt.h -+@@ -18,6 +18,8 @@ -+ #error "Invalid include" -+ #endif -+ -++#include 
"src/__support/FPUtil/generic/sqrt.h" +++// CHECK-LABEL: define noundef i32 @_ZN3Foo11test_methodEv +++// CHECK: [[THIS_ADDR:%.*]] = alloca ptr, align 8 +++// CHECK: [[OBJ:%.*]] = alloca ptr, align 8 +++// CHECK: store ptr %this, ptr [[THIS_ADDR]], align 8 +++// CHECK: [[THIS1:%.*]] = load ptr, ptr [[THIS_ADDR]], align 8 +++// CHECK: [[ALLOCATION:%.*]] = call noundef ptr @_ZN9TestClassnwEm(i64 noundef 4) +++// CHECK: [[INITIALIZEDOBJ:%.*]] = invoke noundef ptr @_ZN9TestClassC1Ev(ptr noundef nonnull align 4 dereferenceable(4) [[ALLOCATION]]) +++// CHECK-NEXT: to label %[[INVOKE_CONT:.*]] unwind label %[[LPAD:.*]] +++// CHECK: [[INVOKE_CONT]]: +++// CHECK: store ptr [[ALLOCATION]], ptr [[OBJ]], align 8 +++// CHECK: [[OBJPTR:%.*]] = load ptr, ptr [[OBJ]], align 8 +++// CHECK: [[FIELDPTR:%.*]] = getelementptr inbounds nuw %class.TestClass, ptr [[OBJPTR]], i32 0, i32 0 +++// CHECK: [[FIELD:%.*]] = load i32, ptr [[FIELDPTR]], align 4 +++// CHECK: ret i32 [[FIELD]] +++// CHECK: [[LPAD]]: +++// CHECK: call void @_ZN9TestClassdlEPv(ptr noundef [[ALLOCATION]]) #3 ++diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/bug135668.cpp b/clang/test/SemaCXX/bug135668.cpp ++--- a/clang/test/SemaCXX/bug135668.cpp +++++ b/clang/test/SemaCXX/bug135668.cpp ++@@ -0,0 +1,25 @@ +++// RUN: %clang_cc1 -triple arm64-apple-macosx -Wall -fsyntax-only -verify %s -std=c++26 -fexceptions -fcxx-exceptions +++// expected-no-diagnostics ++ -+ namespace LIBC_NAMESPACE_DECL { -+ namespace fputil { -+ -+diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h -+--- a/libc/src/__support/FPUtil/riscv/sqrt.h -++++ b/libc/src/__support/FPUtil/riscv/sqrt.h -+@@ -18,6 +18,8 @@ -+ #error "Invalid include" -+ #endif -+ -++#include "src/__support/FPUtil/generic/sqrt.h" +++// This test makes sure that we don't erroneously consider an accessible operator +++// delete to be inaccessible, and then discard the entire new expression. 
++ -+ namespace LIBC_NAMESPACE_DECL { -+ namespace fputil { -+ -+diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h -+--- a/libc/src/__support/FPUtil/x86_64/sqrt.h -++++ b/libc/src/__support/FPUtil/x86_64/sqrt.h -+@@ -18,6 +18,8 @@ -+ #error "sqrtss / sqrtsd need SSE2" -+ #endif -+ -++#include "src/__support/FPUtil/generic/sqrt.h" +++class TestClass { +++public: +++ TestClass(); +++ int field = 0; +++ friend class Foo; +++ static void * operator new(unsigned long size); +++private: +++ static void operator delete(void *p); +++}; ++ -+ namespace LIBC_NAMESPACE_DECL { -+ namespace fputil { -+ +++class Foo { +++public: +++ int test_method(); +++}; +++ +++int Foo::test_method() { +++ TestClass *obj = new TestClass() ; +++ return obj->field; +++} diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 7993194..0b67d8b 100644 +index 0b67d8b..3ec4c3e 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" -- LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" -+ LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" -+ LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" +- LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" +- LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" ++ LLVM_COMMIT = "179d30f8c3fddd3c85056fd2b8e877a4a8513158" ++ LLVM_SHA256 = "39f33d0ba77ca40d254c767519a0f3f5692c2caa271f413e7245ab63d0787bd5" tf_http_archive( name = name, +diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch +index ff118b8..839d98a 100755 +--- a/third_party/stablehlo/temporary.patch ++++ b/third_party/stablehlo/temporary.patch +@@ -607,6 +607,30 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ ++diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir ++--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir +++++ stablehlo/stablehlo/conversions/tosa/tests/binary.mlir ++@@ -45,7 +45,7 @@ ++ ++ // CHECK-LABEL: @divide ++ func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> { ++- // CHECK: tosa.int_div +++ // CHECK: tosa.intdiv ++ %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> ++ return %0 : tensor<10xi32> ++ } ++diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll ++--- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +++++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll ++@@ -156,7 +156,7 @@ ++ Pattern => ++ replace op(input0 : Value<_: Tosa_Int32Tensor>, ++ input1 : Value<_: Tosa_Int32Tensor>) ++- with op(input0, input1); +++ with op(input0, input1); ++ Pattern => ++ replace op(input0 : Value<_: Tosa_Tensor>, ++ input1 : Value<_: Tosa_Tensor>) + diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel + --- stablehlo/stablehlo/tests/BUILD.bazel + +++ stablehlo/stablehlo/tests/BUILD.bazel diff --git a/third_party/shardy/workspace.bzl 
b/third_party/shardy/workspace.bzl index 1e0b188c3f5b28..e2bc747f2d6d05 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "0d88b5d25971bd66272195ceeb2288cde72997d0" - SHARDY_SHA256 = "e2cb1a9d409c49c724739e77156e7ca69b51b68e07e6017f149769f6fdafed42" + SHARDY_COMMIT = "9585bea76e06fff5574ddf20bd88cbdfd0b98985" + SHARDY_SHA256 = "f93f37639f8ec1cd5ae4c26e0b3291bfa85923567c01e9404c6e070d061c598a" tf_http_archive( name = "shardy", diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index ff118b89c8b1ba..839d98a4ab4b74 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -607,6 +607,30 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "stablehlo/dialect/VhloOps.td", deps = [ +diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir +--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir ++++ stablehlo/stablehlo/conversions/tosa/tests/binary.mlir +@@ -45,7 +45,7 @@ + + // CHECK-LABEL: @divide + func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> { +- // CHECK: tosa.int_div ++ // CHECK: tosa.intdiv + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %0 : tensor<10xi32> + } +diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +--- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll ++++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +@@ -156,7 +156,7 @@ + Pattern => + replace op(input0 : Value<_: Tosa_Int32Tensor>, + input1 : Value<_: Tosa_Int32Tensor>) +- with op(input0, input1); ++ with op(input0, input1); + Pattern => + replace op(input0 : Value<_: Tosa_Tensor>, + input1 : Value<_: Tosa_Tensor>) diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel --- stablehlo/stablehlo/tests/BUILD.bazel +++ stablehlo/stablehlo/tests/BUILD.bazel diff --git a/third_party/triton/llvm_integration/cl747619712.patch b/third_party/triton/llvm_integration/cl747619712.patch new file mode 100644 index 00000000000000..00d8b36216b3eb --- /dev/null +++ b/third_party/triton/llvm_integration/cl747619712.patch @@ -0,0 +1,112 @@ + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp 2025-04-11 01:29:32.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp 2025-04-14 17:05:58.000000000 -0700 +@@ -12,6 +12,7 @@ + #include "mlir/IR/TypeUtilities.h" + #include "mlir/IR/ValueRange.h" + #include "mlir/Transforms/DialectConversion.h" ++#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" + #include "triton/Conversion/TritonGPUToLLVM/Utility.h" + #include "triton/Dialect/Triton/IR/Types.h" + #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +@@ -366,7 +367,7 @@ + + auto cacheMod = op.getCache(); + SmallVector loadedVals; +- Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ Type vecTy = VectorType::get(vec, valueElemTy); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + const size_t maxWordWidth = std::max(32, valueElemNBits); + const size_t totalWidth = valueElemNBits * vec; +@@ 
-466,7 +467,7 @@ + // Create the resource descriptor and then emit the buffer_load intrinsic(s) + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride); + SmallVector loadedVals; +- Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ Type vecTy = VectorType::get(vec, valueElemTy); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Value pred = mask ? maskElems[vecStart] : b.int_val(1, 1); + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); +@@ -818,7 +819,7 @@ + Value pred = + llMask ? b.and_(threadPred, maskElems[vecStart]) : threadPred; + +- auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ auto vecTy = VectorType::get(vec, valueElemTy); + + const size_t maxWordWidth = std::max(32, valueElemNBits); + const size_t totalWidth = valueElemNBits * vec; +@@ -1038,7 +1039,7 @@ + Value pred = + llMask ? b.and_(threadPred, maskElems[vecStart]) : threadPred; + +- Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ Type vecTy = VectorType::get(vec, valueElemTy); + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); + // Create the store val + Value storeVal = packElementRangeIntoVector( +@@ -1148,7 +1149,7 @@ + Value pred = + llMask ? b.and_(threadPred, maskElems[vecStart]) : threadPred; + +- Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ Type vecTy = VectorType::get(vec, valueElemTy); + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp 2025-04-11 01:29:32.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp 2025-04-14 17:05:59.000000000 -0700 +@@ -1,9 +1,11 @@ + #include "Utility.h" ++ + #include "Dialect/TritonAMDGPU/IR/Dialect.h" + #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" + #include "mlir/Dialect/LLVMIR/LLVMTypes.h" + #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" + #include "mlir/IR/PatternMatch.h" ++#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" + #include "triton/Conversion/TritonGPUToLLVM/Utility.h" + #include "triton/Dialect/Triton/IR/Dialect.h" + #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +@@ -49,7 +51,7 @@ + Value createVectorMaskFromPredicate(RewriterBase &rewriter, Location loc, + Value pred, int64_t vecSize) { + auto b = TritonLLVMOpBuilder(loc, rewriter); +- auto vecMaskTy = LLVM::getFixedVectorType(rewriter.getI1Type(), vecSize); ++ auto vecMaskTy = VectorType::get(vecSize, rewriter.getI1Type()); + Value maskVal = b.undef(vecMaskTy); + for (size_t s = 0; s < vecSize; ++s) { + Value indexVal = +@@ -70,7 +72,7 @@ + Type castToVectorType(Type ty) { + if (isa(ty)) + return ty; +- return LLVM::getFixedVectorType(ty, 1); ++ return VectorType::get(1, ty); + } + + } // namespace + +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp 2025-04-11 01:29:32.000000000 -0700 ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp 2025-04-14 17:05:59.000000000 -0700 +@@ -305,7 +305,7 @@ + + size_t size = width / valueElemNBits; + +- auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); ++ auto vecTy = VectorType::get(size, valueElemTy); + Value v = b.undef(vecTy); + for (size_t s = 0; s < size; ++s) { + Value falseVal = otherElems[vecStart + ii * size + s]; +@@ -376,8 +376,8 @@ + } else { + curr = ret; + } +- curr = b.bitcast(curr, LLVM::getFixedVectorType( +- valueElemTy, width / valueElemNBits)); ++ curr = b.bitcast(curr, ++ VectorType::get(width / 
valueElemNBits, valueElemTy)); + rets.push_back(curr); + } + int tmp = width / valueElemNBits; diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index d820528c8a38f6..d48952dd39c0aa 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -9,5 +9,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. llvm_patch_list = [ "//third_party/triton:llvm_integration/cl744822685.patch", + "//third_party/triton:llvm_integration/cl747619712.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 34a45370f62ef2..4779b912722730 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,69 +1,230 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 509398d..bbffc2f 100644 +index bbffc2f..436c4e9 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1 +1,49 @@ +@@ -1,49 +1,132 @@ Auto generated patch. Do not edit or delete it, even if empty. -+diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h -+--- a/libc/src/__support/FPUtil/aarch64/sqrt.h -++++ b/libc/src/__support/FPUtil/aarch64/sqrt.h -+@@ -18,6 +18,8 @@ -+ #error "Invalid include" -+ #endif -+ -++#include "src/__support/FPUtil/generic/sqrt.h" +-diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h +---- a/libc/src/__support/FPUtil/aarch64/sqrt.h +-+++ b/libc/src/__support/FPUtil/aarch64/sqrt.h +-@@ -18,6 +18,8 @@ +- #error "Invalid include" +- #endif ++diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp ++--- a/clang/lib/Sema/SemaExprCXX.cpp +++++ b/clang/lib/Sema/SemaExprCXX.cpp ++@@ -1929,8 +1929,9 @@ ++ } ++ return true; ++ } ++- ++- return S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); +++ Sema::AccessResult Accessible = +++ S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); +++ return Accessible == Sema::AR_inaccessible; ++ } + +-+#include "src/__support/FPUtil/generic/sqrt.h" +-+ +- namespace LIBC_NAMESPACE_DECL { +- namespace fputil { +- +-diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h +---- a/libc/src/__support/FPUtil/arm/sqrt.h +-+++ b/libc/src/__support/FPUtil/arm/sqrt.h +-@@ -18,6 +18,8 @@ +- #error "Invalid include" +- #endif ++ /// Select the correct "usual" deallocation function to use from a selection of ++diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ++--- a/clang/lib/Serialization/ASTReaderStmt.cpp +++++ b/clang/lib/Serialization/ASTReaderStmt.cpp ++@@ -2226,10 +2226,7 @@ ++ E->AssociatedDeclAndRef.setPointer(readDeclAs()); ++ E->AssociatedDeclAndRef.setInt(CurrentUnpackingBits->getNextBit()); ++ E->Index = CurrentUnpackingBits->getNextBits(/*Width=*/12); ++- if (CurrentUnpackingBits->getNextBit()) ++- E->PackIndex = Record.readInt(); ++- else ++- E->PackIndex = 0; +++ E->PackIndex = Record.readUnsignedOrNone().toInternalRepresentation(); ++ E->Final = CurrentUnpackingBits->getNextBit(); ++ E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); ++ E->Replacement = Record.readSubExpr(); 
++@@ -2239,6 +2236,7 @@ ++ SubstNonTypeTemplateParmPackExpr *E) { ++ VisitExpr(E); ++ E->AssociatedDecl = readDeclAs(); +++ E->Final = CurrentUnpackingBits->getNextBit(); ++ E->Index = Record.readInt(); ++ TemplateArgument ArgPack = Record.readTemplateArgument(); ++ if (ArgPack.getKind() != TemplateArgument::Pack) ++diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ++--- a/clang/lib/Serialization/ASTWriterStmt.cpp +++++ b/clang/lib/Serialization/ASTWriterStmt.cpp ++@@ -2228,9 +2228,7 @@ ++ Record.AddDeclRef(E->getAssociatedDecl()); ++ CurrentPackingBits.addBit(E->isReferenceParameter()); ++ CurrentPackingBits.addBits(E->getIndex(), /*Width=*/12); ++- CurrentPackingBits.addBit((bool)E->getPackIndex()); ++- if (auto PackIndex = E->getPackIndex()) ++- Record.push_back(*PackIndex + 1); +++ Record.writeUnsignedOrNone(E->getPackIndex()); ++ CurrentPackingBits.addBit(E->getFinal()); + +-+#include "src/__support/FPUtil/generic/sqrt.h" ++ Record.AddSourceLocation(E->getNameLoc()); ++@@ -2242,6 +2240,7 @@ ++ SubstNonTypeTemplateParmPackExpr *E) { ++ VisitExpr(E); ++ Record.AddDeclRef(E->getAssociatedDecl()); +++ CurrentPackingBits.addBit(E->getFinal()); ++ Record.push_back(E->getIndex()); ++ Record.AddTemplateArgument(E->getArgumentPack()); ++ Record.AddSourceLocation(E->getParameterPackLocation()); ++diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bug135668.cpp b/clang/test/CodeGenCXX/bug135668.cpp ++--- a/clang/test/CodeGenCXX/bug135668.cpp +++++ b/clang/test/CodeGenCXX/bug135668.cpp ++@@ -0,0 +1,38 @@ +++// RUN: %clang_cc1 %s -triple arm64-apple-macosx -emit-llvm -fcxx-exceptions -fexceptions -std=c++23 -o - | FileCheck %s + + +- namespace LIBC_NAMESPACE_DECL { +- namespace fputil { +- +-diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h +---- a/libc/src/__support/FPUtil/riscv/sqrt.h +-+++ b/libc/src/__support/FPUtil/riscv/sqrt.h +-@@ -18,6 +18,8 @@ +- #error "Invalid include" +- #endif +- +-+#include "src/__support/FPUtil/generic/sqrt.h" +++class TestClass { +++ public: +++ TestClass(); +++ int field = 0; +++ friend class Foo; +++ static void * operator new(unsigned long size); +++ private: +++ static void operator delete(void *p); +++ }; + + +- namespace LIBC_NAMESPACE_DECL { +- namespace fputil { +- +-diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h +---- a/libc/src/__support/FPUtil/x86_64/sqrt.h +-+++ b/libc/src/__support/FPUtil/x86_64/sqrt.h +-@@ -18,6 +18,8 @@ +- #error "sqrtss / sqrtsd need SSE2" +- #endif +- +-+#include "src/__support/FPUtil/generic/sqrt.h" +++class Foo { +++public: +++ int test_method(); +++}; + + +- namespace LIBC_NAMESPACE_DECL { +- namespace fputil { +- +++int Foo::test_method() { +++ TestClass *obj = new TestClass() ; +++ return obj->field; +++} ++ -+ namespace LIBC_NAMESPACE_DECL { -+ namespace fputil { -+ -+diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h -+--- a/libc/src/__support/FPUtil/arm/sqrt.h -++++ b/libc/src/__support/FPUtil/arm/sqrt.h -+@@ -18,6 +18,8 @@ -+ #error "Invalid include" -+ #endif -+ -++#include "src/__support/FPUtil/generic/sqrt.h" +++// CHECK-LABEL: define noundef i32 @_ZN3Foo11test_methodEv +++// CHECK: [[THIS_ADDR:%.*]] = alloca ptr, align 8 +++// CHECK: [[OBJ:%.*]] = alloca ptr, align 8 +++// CHECK: store ptr %this, ptr [[THIS_ADDR]], align 8 +++// CHECK: [[THIS1:%.*]] = load ptr, ptr 
[[THIS_ADDR]], align 8 +++// CHECK: [[ALLOCATION:%.*]] = call noundef ptr @_ZN9TestClassnwEm(i64 noundef 4) +++// CHECK: [[INITIALIZEDOBJ:%.*]] = invoke noundef ptr @_ZN9TestClassC1Ev(ptr noundef nonnull align 4 dereferenceable(4) [[ALLOCATION]]) +++// CHECK-NEXT: to label %[[INVOKE_CONT:.*]] unwind label %[[LPAD:.*]] +++// CHECK: [[INVOKE_CONT]]: +++// CHECK: store ptr [[ALLOCATION]], ptr [[OBJ]], align 8 +++// CHECK: [[OBJPTR:%.*]] = load ptr, ptr [[OBJ]], align 8 +++// CHECK: [[FIELDPTR:%.*]] = getelementptr inbounds nuw %class.TestClass, ptr [[OBJPTR]], i32 0, i32 0 +++// CHECK: [[FIELD:%.*]] = load i32, ptr [[FIELDPTR]], align 4 +++// CHECK: ret i32 [[FIELD]] +++// CHECK: [[LPAD]]: +++// CHECK: call void @_ZN9TestClassdlEPv(ptr noundef [[ALLOCATION]]) #3 ++diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/bug135668.cpp b/clang/test/SemaCXX/bug135668.cpp ++--- a/clang/test/SemaCXX/bug135668.cpp +++++ b/clang/test/SemaCXX/bug135668.cpp ++@@ -0,0 +1,25 @@ +++// RUN: %clang_cc1 -triple arm64-apple-macosx -Wall -fsyntax-only -verify %s -std=c++26 -fexceptions -fcxx-exceptions +++// expected-no-diagnostics ++ -+ namespace LIBC_NAMESPACE_DECL { -+ namespace fputil { -+ -+diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h -+--- a/libc/src/__support/FPUtil/riscv/sqrt.h -++++ b/libc/src/__support/FPUtil/riscv/sqrt.h -+@@ -18,6 +18,8 @@ -+ #error "Invalid include" -+ #endif -+ -++#include "src/__support/FPUtil/generic/sqrt.h" +++// This test makes sure that we don't erroneously consider an accessible operator +++// delete to be inaccessible, and then discard the entire new expression. ++ -+ namespace LIBC_NAMESPACE_DECL { -+ namespace fputil { -+ -+diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h -+--- a/libc/src/__support/FPUtil/x86_64/sqrt.h -++++ b/libc/src/__support/FPUtil/x86_64/sqrt.h -+@@ -18,6 +18,8 @@ -+ #error "sqrtss / sqrtsd need SSE2" -+ #endif -+ -++#include "src/__support/FPUtil/generic/sqrt.h" +++class TestClass { +++public: +++ TestClass(); +++ int field = 0; +++ friend class Foo; +++ static void * operator new(unsigned long size); +++private: +++ static void operator delete(void *p); +++}; ++ -+ namespace LIBC_NAMESPACE_DECL { -+ namespace fputil { -+ +++class Foo { +++public: +++ int test_method(); +++}; +++ +++int Foo::test_method() { +++ TestClass *obj = new TestClass() ; +++ return obj->field; +++} diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 7993194..0b67d8b 100644 +index 0b67d8b..3ec4c3e 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "836476660e5c068a8b3034c2bc21dbb70683f0fe" -- LLVM_SHA256 = "5f04042bc59cf156cea0f4a03eb9408371e50e4337e7256b4dced10dfa43dec9" -+ LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" -+ LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" +- LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" +- LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" ++ LLVM_COMMIT = "179d30f8c3fddd3c85056fd2b8e877a4a8513158" ++ LLVM_SHA256 = "39f33d0ba77ca40d254c767519a0f3f5692c2caa271f413e7245ab63d0787bd5" tf_http_archive( name = name, +diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch +index ff118b8..839d98a 100755 +--- 
a/third_party/stablehlo/temporary.patch ++++ b/third_party/stablehlo/temporary.patch +@@ -607,6 +607,30 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/dialect/VhloOps.td", + deps = [ ++diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir ++--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir +++++ stablehlo/stablehlo/conversions/tosa/tests/binary.mlir ++@@ -45,7 +45,7 @@ ++ ++ // CHECK-LABEL: @divide ++ func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> { ++- // CHECK: tosa.int_div +++ // CHECK: tosa.intdiv ++ %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> ++ return %0 : tensor<10xi32> ++ } ++diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll ++--- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +++++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll ++@@ -156,7 +156,7 @@ ++ Pattern => ++ replace op(input0 : Value<_: Tosa_Int32Tensor>, ++ input1 : Value<_: Tosa_Int32Tensor>) ++- with op(input0, input1); +++ with op(input0, input1); ++ Pattern => ++ replace op(input0 : Value<_: Tosa_Tensor>, ++ input1 : Value<_: Tosa_Tensor>) + diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel + --- stablehlo/stablehlo/tests/BUILD.bazel + +++ stablehlo/stablehlo/tests/BUILD.bazel diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index 1e0b188c3f5b28..e2bc747f2d6d05 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "0d88b5d25971bd66272195ceeb2288cde72997d0" - SHARDY_SHA256 = "e2cb1a9d409c49c724739e77156e7ca69b51b68e07e6017f149769f6fdafed42" + SHARDY_COMMIT = "9585bea76e06fff5574ddf20bd88cbdfd0b98985" + SHARDY_SHA256 = "f93f37639f8ec1cd5ae4c26e0b3291bfa85923567c01e9404c6e070d061c598a" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index ff118b89c8b1ba..839d98a4ab4b74 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -607,6 +607,30 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "stablehlo/dialect/VhloOps.td", deps = [ +diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir +--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir ++++ stablehlo/stablehlo/conversions/tosa/tests/binary.mlir +@@ -45,7 +45,7 @@ + + // CHECK-LABEL: @divide + func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> { +- // CHECK: tosa.int_div ++ // CHECK: tosa.intdiv + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %0 : tensor<10xi32> + } +diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +--- 
stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll ++++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +@@ -156,7 +156,7 @@ + Pattern => + replace op(input0 : Value<_: Tosa_Int32Tensor>, + input1 : Value<_: Tosa_Int32Tensor>) +- with op(input0, input1); ++ with op(input0, input1); + Pattern => + replace op(input0 : Value<_: Tosa_Tensor>, + input1 : Value<_: Tosa_Tensor>) diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel --- stablehlo/stablehlo/tests/BUILD.bazel +++ stablehlo/stablehlo/tests/BUILD.bazel diff --git a/third_party/xla/third_party/triton/llvm_integration/cl747619712.patch b/third_party/xla/third_party/triton/llvm_integration/cl747619712.patch new file mode 100644 index 00000000000000..00d8b36216b3eb --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl747619712.patch @@ -0,0 +1,112 @@ + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp 2025-04-11 01:29:32.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp 2025-04-14 17:05:58.000000000 -0700 +@@ -12,6 +12,7 @@ + #include "mlir/IR/TypeUtilities.h" + #include "mlir/IR/ValueRange.h" + #include "mlir/Transforms/DialectConversion.h" ++#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" + #include "triton/Conversion/TritonGPUToLLVM/Utility.h" + #include "triton/Dialect/Triton/IR/Types.h" + #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +@@ -366,7 +367,7 @@ + + auto cacheMod = op.getCache(); + SmallVector loadedVals; +- Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ Type vecTy = VectorType::get(vec, valueElemTy); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + const size_t maxWordWidth = std::max(32, valueElemNBits); + const size_t totalWidth = valueElemNBits * vec; +@@ -466,7 +467,7 @@ + // Create the resource descriptor and then emit the buffer_load intrinsic(s) + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride); + SmallVector loadedVals; +- Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ Type vecTy = VectorType::get(vec, valueElemTy); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Value pred = mask ? maskElems[vecStart] : b.int_val(1, 1); + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); +@@ -818,7 +819,7 @@ + Value pred = + llMask ? b.and_(threadPred, maskElems[vecStart]) : threadPred; + +- auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ auto vecTy = VectorType::get(vec, valueElemTy); + + const size_t maxWordWidth = std::max(32, valueElemNBits); + const size_t totalWidth = valueElemNBits * vec; +@@ -1038,7 +1039,7 @@ + Value pred = + llMask ? b.and_(threadPred, maskElems[vecStart]) : threadPred; + +- Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ Type vecTy = VectorType::get(vec, valueElemTy); + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); + // Create the store val + Value storeVal = packElementRangeIntoVector( +@@ -1148,7 +1149,7 @@ + Value pred = + llMask ? 
b.and_(threadPred, maskElems[vecStart]) : threadPred; + +- Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); ++ Type vecTy = VectorType::get(vec, valueElemTy); + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp 2025-04-11 01:29:32.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp 2025-04-14 17:05:59.000000000 -0700 +@@ -1,9 +1,11 @@ + #include "Utility.h" ++ + #include "Dialect/TritonAMDGPU/IR/Dialect.h" + #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" + #include "mlir/Dialect/LLVMIR/LLVMTypes.h" + #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" + #include "mlir/IR/PatternMatch.h" ++#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" + #include "triton/Conversion/TritonGPUToLLVM/Utility.h" + #include "triton/Dialect/Triton/IR/Dialect.h" + #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +@@ -49,7 +51,7 @@ + Value createVectorMaskFromPredicate(RewriterBase &rewriter, Location loc, + Value pred, int64_t vecSize) { + auto b = TritonLLVMOpBuilder(loc, rewriter); +- auto vecMaskTy = LLVM::getFixedVectorType(rewriter.getI1Type(), vecSize); ++ auto vecMaskTy = VectorType::get(vecSize, rewriter.getI1Type()); + Value maskVal = b.undef(vecMaskTy); + for (size_t s = 0; s < vecSize; ++s) { + Value indexVal = +@@ -70,7 +72,7 @@ + Type castToVectorType(Type ty) { + if (isa(ty)) + return ty; +- return LLVM::getFixedVectorType(ty, 1); ++ return VectorType::get(1, ty); + } + + } // namespace + +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp 2025-04-11 01:29:32.000000000 -0700 ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp 2025-04-14 17:05:59.000000000 -0700 +@@ -305,7 +305,7 @@ + + size_t size = width / valueElemNBits; + +- auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); ++ auto vecTy = VectorType::get(size, valueElemTy); + Value v = b.undef(vecTy); + for (size_t s = 0; s < size; ++s) { + Value falseVal = otherElems[vecStart + ii * size + s]; +@@ -376,8 +376,8 @@ + } else { + curr = ret; + } +- curr = b.bitcast(curr, LLVM::getFixedVectorType( +- valueElemTy, width / valueElemNBits)); ++ curr = b.bitcast(curr, ++ VectorType::get(width / valueElemNBits, valueElemTy)); + rets.push_back(curr); + } + int tmp = width / valueElemNBits; diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl index d820528c8a38f6..d48952dd39c0aa 100644 --- a/third_party/xla/third_party/triton/llvm_integration/series.bzl +++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl @@ -9,5 +9,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. 
llvm_patch_list = [ "//third_party/triton:llvm_integration/cl744822685.patch", + "//third_party/triton:llvm_integration/cl747619712.patch", # Add new patches just above this line ] diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc index 4075a4977572a8..2679fb22ed0a58 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_amd.cc @@ -221,7 +221,7 @@ struct RewriteFp8TruncFPattern : public Fp8OpRewritePattern { size_t num_chunks = (num_elements + 2) / 4; - mlir::Type chunks_ty = LLVM::getFixedVectorType(i32_ty, num_chunks); + mlir::Type chunks_ty = mlir::VectorType::get(num_chunks, i32_ty); mlir::Value chunks = b.create(chunks_ty); bool pos = false; for (size_t i = 0; i < inputs.size() / 2; i++) { @@ -241,10 +241,10 @@ struct RewriteFp8TruncFPattern : public Fp8OpRewritePattern { .create( to_ty, mlir::ValueRange{b.create( - LLVM::getFixedVectorType(i8_ty, num_elements), + mlir::VectorType::get(num_elements, i8_ty), b.create( b.create( - LLVM::getFixedVectorType(b.getI16Type(), 2), chunks), + mlir::VectorType::get(2, b.getI16Type()), chunks), b.create(i32_ty, 0)))}) .getResult(0); } @@ -252,7 +252,7 @@ struct RewriteFp8TruncFPattern : public Fp8OpRewritePattern { return b .create( to_ty, mlir::ValueRange{b.create( - LLVM::getFixedVectorType(i8_ty, num_elements), chunks)}) + mlir::VectorType::get(num_elements, i8_ty), chunks)}) .getResult(0); } @@ -435,24 +435,24 @@ struct RewriteFp8ExtFPattern : public Fp8OpRewritePattern { assert(num_elements == 2 || num_elements % 4 == 0); size_t num_chunks = (num_elements + 2) / 4; - mlir::Type chunks_ty = LLVM::getFixedVectorType(i32_ty, num_chunks); + mlir::Type chunks_ty = mlir::VectorType::get(num_chunks, i32_ty); mlir::Value chunks; if (num_elements == 2) { chunks = b.create( chunks_ty, b.create( - b.create(LLVM::getFixedVectorType(i16_ty, 2)), + b.create(mlir::VectorType::get(2, i16_ty)), b.create( i16_ty, b.create( - LLVM::getFixedVectorType(i8_ty, num_elements), + mlir::VectorType::get(num_elements, i8_ty), mlir::ValueRange{value}) .getResult(0)), zero_cst)); } else { chunks = b.create( chunks_ty, b.create( - LLVM::getFixedVectorType(i8_ty, num_elements), + mlir::VectorType::get(num_elements, i8_ty), mlir::ValueRange{value}) .getResult(0)); } @@ -461,7 +461,7 @@ struct RewriteFp8ExtFPattern : public Fp8OpRewritePattern { mlir::StringAttr cvtIntr = b.getStringAttr( isFp8(value.getType().getElementType()) ? 
"llvm.amdgcn.cvt.pk.f32.fp8" : "llvm.amdgcn.cvt.pk.f32.bf8"); - mlir::Type result_ty = LLVM::getFixedVectorType(f32_ty, 2); + mlir::Type result_ty = mlir::VectorType::get(2, f32_ty); LLVM::FastmathFlagsAttr flags = LLVM::FastmathFlagsAttr::get(b.getContext(), LLVM::FastmathFlags::ninf); for (size_t i = 0; i < num_elements / 2; i++) { @@ -480,7 +480,7 @@ struct RewriteFp8ExtFPattern : public Fp8OpRewritePattern { } if (to_ty.isF16()) { - result_ty = LLVM::getFixedVectorType(b.getF16Type(), 2); + result_ty = mlir::VectorType::get(2, b.getF16Type()); cvtIntr = b.getStringAttr("llvm.amdgcn.cvt.pkrtz"); for (size_t i = 0; i < num_elements / 2; i++) { LLVM::CallIntrinsicOp cvtOp = b.create( @@ -513,7 +513,7 @@ struct RewriteFp8ExtFPattern : public Fp8OpRewritePattern { // Emulate anyext mlir::Value input = b.create( i32_ty, b.create( - b.create(LLVM::getFixedVectorType(i8_ty, 4)), + b.create(mlir::VectorType::get(4, i8_ty)), b.create( i8_ty, mlir::ValueRange{value}) .getResult(0), diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_nvidia.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_nvidia.cc index ff280e5a38059f..fd763ade93cb09 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_nvidia.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/transforms/convert_float_nvidia.cc @@ -79,7 +79,7 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern { if (value.getType() == b.getF16Type()) { // Fast path for truncating F16 type. Value vec = - b.create(ml::getFixedVectorType(value.getType(), 2)); + b.create(mlir::VectorType::get(2, value.getType())); vec = b.create(vec, value, b.create(0, 8)); auto cvtIntr = llvm::isa(to_ty) @@ -215,9 +215,9 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern { ? "llvm.nvvm.e4m3x2.to.f16x2.rn" : "llvm.nvvm.e5m2x2.to.f16x2.rn"; mlir::FloatType f16_ty = b.getF16Type(); - auto cvtOp = b.create( - ml::getFixedVectorType(f16_ty, 2), b.getStringAttr(cvtIntr), - mlir::ValueRange{input}); + auto cvtOp = b.create(mlir::VectorType::get(2, f16_ty), + b.getStringAttr(cvtIntr), + mlir::ValueRange{input}); Value res = b.create( cvtOp.getResults(), b.create(0, 8)); if (to_ty.getWidth() > f16_ty.getWidth()) { From d3d86f7e7db4bda2eba6dfce2df0ca42ee927715 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Tue, 15 Apr 2025 16:50:34 -0700 Subject: [PATCH 0843/1324] Add setter for ExecutionOptions seed to ClientLibraryTestRunnerMixin. PiperOrigin-RevId: 748071969 --- third_party/xla/xla/tests/client_library_test_runner_mixin.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/tests/client_library_test_runner_mixin.h b/third_party/xla/xla/tests/client_library_test_runner_mixin.h index 51e95c328c4aaf..159ef5318b421f 100644 --- a/third_party/xla/xla/tests/client_library_test_runner_mixin.h +++ b/third_party/xla/xla/tests/client_library_test_runner_mixin.h @@ -376,6 +376,9 @@ class ClientLibraryTestRunnerMixin : public T { opts->set_xla_gpu_enable_fast_min_max(!disabled); } + void SetSeed(const uint64_t seed) { execution_options_.set_seed(seed); } + void ClearSeed() { execution_options_.clear_seed(); } + // Provides mutable access to the execution DebugOptions field; this lets // tests tweak the options that will be used to compile/run the graph. DebugOptions* mutable_debug_options() { From 074fcaa018108b91580d93e866fe5c0b874153e0 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 15 Apr 2025 17:25:07 -0700 Subject: [PATCH 0844/1324] Changed a few hot array iterations to use a new templated Array::Each API variation that avoids type-erasure and a virtual call per element. PiperOrigin-RevId: 748082231 --- third_party/xla/xla/array.h | 19 ++++++++++++++++++- third_party/xla/xla/hlo/ir/hlo_sharding.cc | 5 +++-- third_party/xla/xla/hlo/ir/tile_assignment.h | 8 ++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h index fbe6cb944b1adb..216e868a97b369 100644 --- a/third_party/xla/xla/array.h +++ b/third_party/xla/xla/array.h @@ -324,6 +324,23 @@ class Array { } } + // Templated variants of Each() that avoid virtual function call + // overhead per element. Useful for hot code paths. + template // void(absl::Span, T*) + void TemplatedEach(const Fn& f) { + OwnedBuffer index(sizes_.size, default_init_t{}); + for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { + f(index.span(), &values_[i]); + } + } + template // void(absl::Span, T) + void TemplatedEach(const Fn& f) const { + OwnedBuffer index(sizes_.size, default_init_t{}); + for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { + f(index.span(), values_[i]); + } + } + // Invokes a callback with the (indices, value_ptr) for each cell in the // array. If a callback returns a non-OK status, returns that else returns // absl::OkStatus(). @@ -541,7 +558,7 @@ class Array { } Array permuted(permuted_dims.span()); OwnedBuffer src_indices(sizes_.size, -1); - permuted.Each([&](absl::Span indices, T* value) { + permuted.TemplatedEach([&](absl::Span indices, T* value) { for (int64_t i = 0; i < sizes_.size; ++i) { src_indices[permutation[i]] = indices[i]; } diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.cc b/third_party/xla/xla/hlo/ir/hlo_sharding.cc index 54551b70278e87..ad30ce559ea90a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.cc +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.cc @@ -186,7 +186,7 @@ HloSharding HloSharding::PartialTile( } return group_id; }; - tile_assignment_last_dim_replicate.Each( + tile_assignment_last_dim_replicate.TemplatedEach( [&](absl::Span indices, const int64_t device) { const int64_t group_id = get_group_id(indices); sorted_groups[group_id * group_size + current_group_idx[group_id]++] = @@ -199,7 +199,8 @@ HloSharding HloSharding::PartialTile( absl::c_fill(current_group_idx, 0); auto sorted_tile = std::make_shared>( tile_assignment_last_dim_replicate.dimensions()); - sorted_tile->Each([&](absl::Span indices, int64_t* device) { + sorted_tile->TemplatedEach([&](absl::Span indices, + int64_t* device) { const int64_t group_id = get_group_id(indices); *device = sorted_groups[group_id * group_size + current_group_idx[group_id]++]; diff --git a/third_party/xla/xla/hlo/ir/tile_assignment.h b/third_party/xla/xla/hlo/ir/tile_assignment.h index 31d874328b64cb..9e32af53d6cdf1 100644 --- a/third_party/xla/xla/hlo/ir/tile_assignment.h +++ b/third_party/xla/xla/hlo/ir/tile_assignment.h @@ -214,6 +214,14 @@ class TileAssignment { void Each( absl::FunctionRef, int64_t)> f) const; + // Templated variant of Each() that avoids virtual function call + // overhead per element. Useful for hot code paths. + template + void TemplatedEach(const Fn& fn) const { + MaybeMaterializeFullArray(); + array_->TemplatedEach(fn); + } + absl::Status EachStatus( absl::FunctionRef, int64_t)> f) const; From 766638b395aa08b573fe220a878e1e66a4ffe430 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 15 Apr 2025 18:36:16 -0700 Subject: [PATCH 0845/1324] Deprecate `Shape::tuple_shapes_size()` in favor of `Shape::tuple_shapes().size()`. This is to be consistent with `Shape::dimensions().size()`. PiperOrigin-RevId: 748100466 --- third_party/xla/xla/shape.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 0275d8a0218b7f..af59dc4edc589b 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -301,7 +301,8 @@ class Shape { // Returns the number of top-level tuple components in this shape. // Precondition: this is a tuple shape. - int tuple_shapes_size() const { return tuple_state().tuple_shapes.size(); } + ABSL_DEPRECATE_AND_INLINE() + inline int tuple_shapes_size() const { return tuple_shapes().size(); } // Returns the shape of the i-th tuple component. // Precondition: this is a tuple shape and `index` is a valid tuple component From eda96e35bac457dd4051da84b4b26985130ad028 Mon Sep 17 00:00:00 2001 From: Siqiao Wu Date: Tue, 15 Apr 2025 19:00:31 -0700 Subject: [PATCH 0846/1324] Internal change only PiperOrigin-RevId: 748105651 --- .../compiler/mlir/tfrt/transforms/ifrt/BUILD | 1 + .../core/tfrt/tfrt_session/tfrt_session.cc | 18 +++++++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index c957df04f2ff2f..abeca3e730bb7e 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -23,6 +23,7 @@ package_group( "//learning/pathways/serving/runtime/...", "//learning/pathways/serving/tests/...", "//learning/brain/tfrt/ifrt/...", + "//learning/brain/tfrt/tfrt_session/...", "//learning/brain/tfrt/mlir/mlrt/application/pathways/compiler/...", # Allow visibility from the mlir language server. 
"//learning/brain/mlir/mlir_lsp_server/...", diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc index 05f1a9c3521822..ea46dc949dbd77 100644 --- a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc +++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc @@ -503,7 +503,7 @@ class TfrtSession : public tensorflow::Session { compile_options.device_target = device_target_; compile_options.tpu_fuse_ops = tpu_use_tpu_runner_; compile_options.hoist_invariant_ops = true; - compile_options.sink_in_invariant_ops = false; + compile_options.sink_in_invariant_ops = true; compile_options.cost_threshold = 1024; if (use_gpu_) { @@ -779,18 +779,22 @@ void TfrtSessionFactory::RegisterInitializer(RuntimeInitializer initializer) { absl::Status TfrtSessionFactory::InitializeLocked( const TfrtSessionOptions& options) { mutex_.AssertHeld(); + if (options.backend_compiler) { + backend_compiler_ = options.backend_compiler; + } if (options.use_tpu) { - DCHECK(!options.backend_compiler); DCHECK(!options.use_gpu); device_target_ = TfrtDeviceInfraTarget::kTpurt; - tpu_use_tpu_runner_ = true; + if (!options.backend_compiler) { + tpu_use_tpu_runner_ = true; + } } else if (options.use_gpu) { - DCHECK(!options.backend_compiler); device_target_ = TfrtDeviceInfraTarget::kGpu; - use_gpu_ = true; - } else if (options.backend_compiler) { - backend_compiler_ = options.backend_compiler; + if (!options.backend_compiler) { + use_gpu_ = true; + } } + LOG(INFO) << "Start initializing TfrtSession"; if (options.runtime != nullptr) { runtime_ = options.runtime; From 5305b0a17b9145000135581355d4e1b9e5107e36 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 19:18:24 -0700 Subject: [PATCH 0847/1324] Add `--xnnpack_runtime_flags` option to pass raw XNNPACK flags to XNNPACK PiperOrigin-RevId: 748109461 --- .../configuration/configuration.proto | 2 ++ .../configuration/configuration_generated.h | 29 +++++++++++++++---- .../testdata/configuration.proto_prev | 2 ++ .../configuration/c/xnnpack_plugin.cc | 1 + .../delegates/xnnpack/xnnpack_delegate.cc | 3 ++ .../lite/delegates/xnnpack/xnnpack_delegate.h | 2 ++ .../delegates/xnnpack_delegate_provider.cc | 8 +++++ tensorflow/lite/tools/evaluation/utils.cc | 3 ++ 8 files changed, 44 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/acceleration/configuration/configuration.proto b/tensorflow/lite/acceleration/configuration/configuration.proto index 29d911bd1b05b1..f9e480ec2b3e6b 100644 --- a/tensorflow/lite/acceleration/configuration/configuration.proto +++ b/tensorflow/lite/acceleration/configuration/configuration.proto @@ -337,6 +337,8 @@ message XNNPackSettings { // reloaded from this cache which can reduce initialization time and the // packing memory footprint. optional string weight_cache_file_path = 3; + // Extra flags to pass to xnn_create_runtime. + optional int32 runtime_flags = 4; } // CoreML Delegate settings. 
diff --git a/tensorflow/lite/acceleration/configuration/configuration_generated.h b/tensorflow/lite/acceleration/configuration/configuration_generated.h index 4cb4861e78f4f4..0e7d3219ef4974 100644 --- a/tensorflow/lite/acceleration/configuration/configuration_generated.h +++ b/tensorflow/lite/acceleration/configuration/configuration_generated.h @@ -1701,6 +1701,7 @@ struct XNNPackSettingsT : public ::flatbuffers::NativeTable { int32_t num_threads = 0; tflite::XNNPackFlags flags = tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS; std::string weight_cache_file_path{}; + int32_t runtime_flags = 0; }; struct XNNPackSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1709,7 +1710,8 @@ struct XNNPackSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_NUM_THREADS = 4, VT_FLAGS = 6, - VT_WEIGHT_CACHE_FILE_PATH = 8 + VT_WEIGHT_CACHE_FILE_PATH = 8, + VT_RUNTIME_FLAGS = 10 }; int32_t num_threads() const { return GetField(VT_NUM_THREADS, 0); @@ -1720,12 +1722,16 @@ struct XNNPackSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::String *weight_cache_file_path() const { return GetPointer(VT_WEIGHT_CACHE_FILE_PATH); } + int32_t runtime_flags() const { + return GetField(VT_RUNTIME_FLAGS, 0); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_NUM_THREADS, 4) && VerifyField(verifier, VT_FLAGS, 4) && VerifyOffset(verifier, VT_WEIGHT_CACHE_FILE_PATH) && verifier.VerifyString(weight_cache_file_path()) && + VerifyField(verifier, VT_RUNTIME_FLAGS, 4) && verifier.EndTable(); } XNNPackSettingsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -1746,6 +1752,9 @@ struct XNNPackSettingsBuilder { void add_weight_cache_file_path(::flatbuffers::Offset<::flatbuffers::String> weight_cache_file_path) { fbb_.AddOffset(XNNPackSettings::VT_WEIGHT_CACHE_FILE_PATH, weight_cache_file_path); } + void add_runtime_flags(int32_t runtime_flags) { + fbb_.AddElement(XNNPackSettings::VT_RUNTIME_FLAGS, runtime_flags, 0); + } explicit XNNPackSettingsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1761,8 +1770,10 @@ inline ::flatbuffers::Offset CreateXNNPackSettings( ::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_threads = 0, tflite::XNNPackFlags flags = tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS, - ::flatbuffers::Offset<::flatbuffers::String> weight_cache_file_path = 0) { + ::flatbuffers::Offset<::flatbuffers::String> weight_cache_file_path = 0, + int32_t runtime_flags = 0) { XNNPackSettingsBuilder builder_(_fbb); + builder_.add_runtime_flags(runtime_flags); builder_.add_weight_cache_file_path(weight_cache_file_path); builder_.add_flags(flags); builder_.add_num_threads(num_threads); @@ -1773,13 +1784,15 @@ inline ::flatbuffers::Offset CreateXNNPackSettingsDirect( ::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_threads = 0, tflite::XNNPackFlags flags = tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS, - const char *weight_cache_file_path = nullptr) { + const char *weight_cache_file_path = nullptr, + int32_t runtime_flags = 0) { auto weight_cache_file_path__ = weight_cache_file_path ? 
_fbb.CreateString(weight_cache_file_path) : 0; return tflite::CreateXNNPackSettings( _fbb, num_threads, flags, - weight_cache_file_path__); + weight_cache_file_path__, + runtime_flags); } ::flatbuffers::Offset CreateXNNPackSettings(::flatbuffers::FlatBufferBuilder &_fbb, const XNNPackSettingsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -4971,7 +4984,8 @@ inline bool operator==(const XNNPackSettingsT &lhs, const XNNPackSettingsT &rhs) return (lhs.num_threads == rhs.num_threads) && (lhs.flags == rhs.flags) && - (lhs.weight_cache_file_path == rhs.weight_cache_file_path); + (lhs.weight_cache_file_path == rhs.weight_cache_file_path) && + (lhs.runtime_flags == rhs.runtime_flags); } inline bool operator!=(const XNNPackSettingsT &lhs, const XNNPackSettingsT &rhs) { @@ -4991,6 +5005,7 @@ inline void XNNPackSettings::UnPackTo(XNNPackSettingsT *_o, const ::flatbuffers: { auto _e = num_threads(); _o->num_threads = _e; } { auto _e = flags(); _o->flags = _e; } { auto _e = weight_cache_file_path(); if (_e) _o->weight_cache_file_path = _e->str(); } + { auto _e = runtime_flags(); _o->runtime_flags = _e; } } inline ::flatbuffers::Offset XNNPackSettings::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const XNNPackSettingsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -5004,11 +5019,13 @@ inline ::flatbuffers::Offset CreateXNNPackSettings(::flatbuffer auto _num_threads = _o->num_threads; auto _flags = _o->flags; auto _weight_cache_file_path = _o->weight_cache_file_path.empty() ? 0 : _fbb.CreateString(_o->weight_cache_file_path); + auto _runtime_flags = _o->runtime_flags; return tflite::CreateXNNPackSettings( _fbb, _num_threads, _flags, - _weight_cache_file_path); + _weight_cache_file_path, + _runtime_flags); } diff --git a/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev b/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev index 569042d3c88e7b..b8881307b3aa33 100644 --- a/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev +++ b/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev @@ -335,6 +335,8 @@ message XNNPackSettings { // reloaded from this cache which can reduce initialization time and the // packing memory footprint. optional string weight_cache_file_path = 3; + // Extra flags to pass to xnn_create_runtime + optional int32 runtime_flags = 4; } // CoreML Delegate settings. 
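The plumbing added below (plugin -> delegate options -> xnn_create_runtime) can also be driven directly through the public delegate C API. A minimal sketch, again with a placeholder flag value; TfLiteXNNPackDelegateOptionsDefault/Create/Delete are the pre-existing delegate entry points, and runtime_flags is the field this patch adds to the options struct:

  #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

  TfLiteXNNPackDelegateOptions options = TfLiteXNNPackDelegateOptionsDefault();
  options.num_threads = 2;
  options.runtime_flags = 0x1;  // Placeholder: forwarded verbatim to xnn_create_runtime.
  TfLiteDelegate* delegate = TfLiteXNNPackDelegateCreate(&options);
  // ... attach via interpreter->ModifyGraphWithDelegate(delegate) ...
  TfLiteXNNPackDelegateDelete(delegate);

From the benchmark tooling, the same path is exposed by the new flag, e.g. --use_xnnpack=true --xnnpack_runtime_flags=1.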
diff --git a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc index 1133b1b69c0e84..8154615931a43b 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc +++ b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc @@ -40,6 +40,7 @@ static TfLiteDelegate* CreateDelegate(const void* settings) { if (xnnpack_settings->flags()) { options.flags = xnnpack_settings->flags(); } + options.runtime_flags = xnnpack_settings->runtime_flags(); if (xnnpack_settings->weight_cache_file_path()) { options.weight_cache_file_path = xnnpack_settings->weight_cache_file_path()->c_str(); diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 48c8c03eedad52..d53dd1fd530b90 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -717,6 +717,8 @@ class Delegate { return (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SLINKY) != 0; } + uint32_t runtime_flags() const { return options_.runtime_flags; } + bool support_variable_ops() const { if (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_VARIABLE_OPERATORS) { return true; @@ -1314,6 +1316,7 @@ class Subgraph { constexpr uint32_t XNN_FLAG_SLINKY_ENABLED = 0x40000000; flags |= XNN_FLAG_SLINKY_ENABLED; } + flags |= delegate.runtime_flags(); if (delegate.weight_cache_provider_.IsActive() && delegate.weight_cache_provider_.CanStartBuildStep()) { diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h index e7f1713072776a..ccd2840a8a13ed 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h @@ -58,6 +58,8 @@ typedef struct { // Number of threads to use in the thread pool. // 0 or negative value means no thread pool used. int32_t num_threads; + // Flags to pass to `xnn_create_runtime` + uint32_t runtime_flags; // Bitfield with any combination of the following binary options: // - TFLITE_XNNPACK_DELEGATE_FLAG_QS8 // - TFLITE_XNNPACK_DELEGATE_FLAG_QU8 diff --git a/tensorflow/lite/tools/delegates/xnnpack_delegate_provider.cc b/tensorflow/lite/tools/delegates/xnnpack_delegate_provider.cc index c6cbcf8e7aab6a..5465e3613f6816 100644 --- a/tensorflow/lite/tools/delegates/xnnpack_delegate_provider.cc +++ b/tensorflow/lite/tools/delegates/xnnpack_delegate_provider.cc @@ -34,6 +34,8 @@ class XnnpackDelegateProvider : public DelegateProvider { default_params_.AddParam("xnnpack_weight_cache_file_path", ToolParam::Create("")); default_params_.AddParam("xnnpack_slinky", ToolParam::Create(false)); + default_params_.AddParam("xnnpack_runtime_flags", + ToolParam::Create(0)); } std::vector CreateFlags(ToolParams* params) const final; @@ -67,6 +69,8 @@ std::vector XnnpackDelegateProvider::CreateFlags( "enable the Slinky optimizer. 
" "(Ignored if --use_xnnpack is false, or if XNNPACK is " "built without Slinky.)"), + CreateFlag("xnnpack_runtime_flags", params, + "Extra flags to pass to XNNPACK runtime."), }; return flags; } @@ -79,6 +83,8 @@ void XnnpackDelegateProvider::LogParams(const ToolParams& params, LOG_TOOL_PARAM(params, std::string, "xnnpack_weight_cache_file_path", "xnnpack_weight_cache_file_path", verbose); LOG_TOOL_PARAM(params, bool, "xnnpack_slinky", "Use Slinky", verbose); + LOG_TOOL_PARAM(params, int, "xnnpack_runtime_flags", + "Extra flags for XNNPACK runtime", verbose); } TfLiteDelegatePtr XnnpackDelegateProvider::CreateTfLiteDelegate( @@ -94,6 +100,8 @@ TfLiteDelegatePtr XnnpackDelegateProvider::CreateTfLiteDelegate( if (params.Get("xnnpack_slinky")) { opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SLINKY; } + opts.runtime_flags = params.Get("xnnpack_runtime_flags"); + const std::string path = params.Get("xnnpack_weight_cache_file_path"); if (!path.empty()) { diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc index 671bdacd42e7f2..aa14fe169c127c 100644 --- a/tensorflow/lite/tools/evaluation/utils.cc +++ b/tensorflow/lite/tools/evaluation/utils.cc @@ -243,6 +243,9 @@ TfLiteDelegatePtr CreateXNNPACKDelegate( xnnpack_settings_builder.fbb_.AddElement( XNNPackSettings::VT_FLAGS, static_cast(xnnpack_options->flags), 0); + xnnpack_settings_builder.fbb_.AddElement( + XNNPackSettings::VT_RUNTIME_FLAGS, + static_cast(xnnpack_options->runtime_flags), 0); xnnpack_settings_builder.add_weight_cache_file_path(weight_cache_file_path); flatbuffers::Offset xnnpack_settings = xnnpack_settings_builder.Finish(); From 1806f6893dad58f29b8e72a8a332b65257c425b1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 22:00:03 -0700 Subject: [PATCH 0848/1324] Automated Code Change PiperOrigin-RevId: 748150643 --- tensorflow/lite/toco/BUILD | 2 ++ tensorflow/lite/toco/model_cmdline_flags.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index d31a2f83d67785..aa41da3c01f42b 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -151,6 +151,8 @@ cc_library( ":types_proto_cc", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/lite/toco/model_cmdline_flags.cc b/tensorflow/lite/toco/model_cmdline_flags.cc index b916d80c43baa6..bc2b8ec50264ad 100644 --- a/tensorflow/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/lite/toco/model_cmdline_flags.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" From 1e553b8724542b0efcc9e670dfc7cb5f9b522560 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Apr 2025 23:39:52 -0700 Subject: [PATCH 0849/1324] Automated Code Change PiperOrigin-RevId: 748174023 --- tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc b/tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc index 12d032bde9c56b..027f53cc3fc3e2 100644 --- a/tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc +++ b/tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.h" +#include #include #include From 6290ea06d9cf356f3058f8172f4aaf274e3062cc Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Tue, 15 Apr 2025 23:56:49 -0700 Subject: [PATCH 0850/1324] Preserve the original error code from the definition event if possible PiperOrigin-RevId: 748177595 --- .../xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index dfa9e753d30f69..3375f712543cc0 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -510,22 +510,23 @@ PjRtFuture<> AbstractTfrtCpuBuffer::GetReadyFuture() { if (definition_event.IsAvailable()) { if (definition_event.IsError()) { - return PjRtFuture<>( - FailedPrecondition("Buffer Definition Event: %s", - definition_event.GetError().message())); + const absl::Status& s = definition_event.GetError(); + return PjRtFuture<>(tsl::errors::CreateWithUpdatedMessage( + s, absl::StrCat("Buffer Definition Event: ", s.message()))); } return PjRtFuture<>(absl::OkStatus()); } else { PjRtFuture<>::Promise promise = PjRtFuture<>::CreatePromise(); - definition_event.AndThen([definition_event = definition_event.AsPtr(), - promise]() mutable { - if (definition_event.IsError()) { - promise.Set(FailedPrecondition("Buffer Definition Event: %s", - definition_event.GetError().message())); - } else { - promise.Set(); - } - }); + definition_event.AndThen( + [definition_event = definition_event.AsPtr(), promise]() mutable { + if (definition_event.IsError()) { + const absl::Status& s = definition_event.GetError(); + promise.Set(tsl::errors::CreateWithUpdatedMessage( + s, absl::StrCat("Buffer Definition Event: ", s.message()))); + } else { + promise.Set(); + } + }); std::string message = absl::StrCat(buffer_name(), "::Await"); return PjRtFuture<>( From 783aea3ef1ae6b46f40196a9485720fa5f58125d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 00:16:17 -0700 Subject: [PATCH 0851/1324] Automated Code Change PiperOrigin-RevId: 748183136 --- third_party/xla/xla/hlo/builder/lib/broadcast.cc | 1 + third_party/xla/xla/hlo/builder/lib/comparators.cc | 2 +- third_party/xla/xla/hlo/builder/lib/qr.cc | 2 +- third_party/xla/xla/hlo/builder/lib/tuple.cc | 2 -- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/builder/lib/broadcast.cc b/third_party/xla/xla/hlo/builder/lib/broadcast.cc index aaabe046cebb02..1baec363b53a35 100644 --- a/third_party/xla/xla/hlo/builder/lib/broadcast.cc +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/builder/lib/broadcast.h" +#include #include #include "absl/algorithm/container.h" diff --git a/third_party/xla/xla/hlo/builder/lib/comparators.cc b/third_party/xla/xla/hlo/builder/lib/comparators.cc index a4965caab0d931..b6bd2254c02835 100644 --- a/third_party/xla/xla/hlo/builder/lib/comparators.cc +++ b/third_party/xla/xla/hlo/builder/lib/comparators.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "xla/hlo/builder/lib/comparators.h" -#include +#include #include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/qr.cc b/third_party/xla/xla/hlo/builder/lib/qr.cc index 2118d54f345d4f..130e3ed754377d 100644 --- a/third_party/xla/xla/hlo/builder/lib/qr.cc +++ b/third_party/xla/xla/hlo/builder/lib/qr.cc @@ -16,7 +16,7 @@ limitations under the License. #include "xla/hlo/builder/lib/qr.h" #include -#include +#include #include #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/hlo/builder/lib/tuple.cc b/third_party/xla/xla/hlo/builder/lib/tuple.cc index 1edfa273aff452..78ec14c1fd8be7 100644 --- a/third_party/xla/xla/hlo/builder/lib/tuple.cc +++ b/third_party/xla/xla/hlo/builder/lib/tuple.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/hlo/builder/lib/tuple.h" -#include - #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "xla/hlo/builder/xla_builder.h" From a834efe8988a4296dbfbc55b11a0722204ced5b4 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 16 Apr 2025 00:38:33 -0700 Subject: [PATCH 0852/1324] [XLA:GPU] Use one-shot kernel in AllReduceThunk. This change adds support for one-shot all-reduce kernel. The kernel is currently limited to only 1 input and `F32` element type. PiperOrigin-RevId: 748188368 --- .../xla/xla/backends/gpu/runtime/BUILD | 16 + .../xla/backends/gpu/runtime/all_reduce.cc | 4 +- .../xla/xla/backends/gpu/runtime/all_reduce.h | 4 + .../backends/gpu/runtime/all_reduce_thunk.cc | 328 ++++++++++++++++-- .../backends/gpu/runtime/all_reduce_thunk.h | 33 ++ third_party/xla/xla/debug_options_flags.cc | 9 + .../xla/xla/tests/collective_ops_e2e_test.cc | 83 +++++ third_party/xla/xla/xla.proto | 6 +- 8 files changed, 447 insertions(+), 36 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 3fb164156f02c6..e600afa6b979c6 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -744,20 +744,36 @@ cc_library( deps = [ ":collective_thunk", ":thunk", + "//xla:shape_util", "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", + "//xla/backends/gpu/runtime:all_reduce", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", + "//xla/service:rendezvous", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu/transforms/collectives:collective_ops_utils", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:event", "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc b/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc index c9d654c7a3c2be..5ecd7469b79372 100644 --- 
a/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc @@ -56,9 +56,9 @@ absl::Status LaunchTypedKernel( } } // namespace -bool IsAllReduceKernelSupported(int64_t num_outputs, +bool IsAllReduceKernelSupported(int64_t num_inputs, PrimitiveType element_type) { - return num_outputs <= stream_executor::gpu::kMaxNumAllReduceInputPtrs && + return num_inputs <= stream_executor::gpu::kMaxNumAllReduceInputPtrs && element_type == F32; } diff --git a/third_party/xla/xla/backends/gpu/runtime/all_reduce.h b/third_party/xla/xla/backends/gpu/runtime/all_reduce.h index df0888a6b179c6..87baf4a24d0664 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_reduce.h +++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce.h @@ -27,6 +27,10 @@ limitations under the License. namespace xla::gpu { +// Returns true if the all-reduce kernel is supported for the given number of +// inputs and element type. +bool IsAllReduceKernelSupported(int64_t num_inputs, PrimitiveType element_type); + // Performs element-wise addition of all input buffers and stores the result in // the output buffer. // The kernel is intended to be used for all-reduce operations in environment diff --git a/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc index 2d6bf02820dc7e..d2a9e5ed05f43b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc @@ -15,24 +15,41 @@ limitations under the License. #include "xla/backends/gpu/runtime/all_reduce_thunk.h" +#include #include +#include #include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" +#include "xla/backends/gpu/runtime/all_reduce.h" #include "xla/backends/gpu/runtime/collective_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" +#include "xla/service/rendezvous.h" +#include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_handle.h" +#include "xla/stream_executor/event.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -40,27 +57,145 @@ limitations under the License. 
namespace xla { namespace gpu { +namespace { -absl::Status RunAllReduce(GpuCollectives* collectives, - ReductionKind reduction_kind, - std::vector& buffers, - se::Stream& stream, Communicator* comm) { - int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal; - TF_RETURN_IF_ERROR( - MaybeRegisterBuffers(collectives, stream.parent(), buffers, comm)); +constexpr int64_t kMaxOneShotAllReduceSizeBytes = 256 * 1024; - TF_RETURN_IF_ERROR(collectives->GroupStart()); - for (DeviceBufferPair& buffer : buffers) { - TF_RETURN_IF_ERROR(comm->AllReduce( - buffer.source_buffer, buffer.destination_buffer, buffer.element_type, - buffer.element_count, reduction_kind, GpuCollectives::On(stream))); +// Contains the values that are passed between host threads with rendezvous. +struct RendezvousValue { + RankId rank; + se::DeviceMemoryBase input_buffer; + se::Event* start_event; + se::Event* end_event; + + bool operator<(const RendezvousValue& other) const { + return rank < other.rank; + } +}; + +// Executes the rendezvous before the kernel start. +// Inserts CUDA events into the stream to ensure that all devices have reached +// the start event before the kernel starts. +absl::StatusOr>> +RendezvousBeforeKernelStart(const GpuCliqueKey& clique_key, RankId rank, + int64_t num_ranks, + const se::DeviceMemoryBase& input_buffer, + se::Stream& stream, se::Event* start_event, + se::Event* end_event) { + RendezvousValue rendezvous_value; + rendezvous_value.rank = rank; + rendezvous_value.input_buffer = input_buffer; + rendezvous_value.start_event = start_event; + rendezvous_value.end_event = end_event; + + // Record that this device has started executing the kernel. We do + // this before the rendezvous to make sure that RecordEvent is called before + // WaitFor on another stream. + TF_RETURN_IF_ERROR(stream.RecordEvent(start_event)); + + auto rendezvous_fn = [](absl::Span values) { + std::vector values_copy; + for (const auto& value : values) { + values_copy.push_back(*value); + } + // Sort to make sure that values are in the same order as the devices are + // ordered in the communicator. + absl::c_sort(values_copy); + return values_copy; + }; + + std::string start_rendezvous_key = + absl::StrFormat("start one-shot all-reduce for rank %d, clique %s", + rank.value(), clique_key.ToString()); + TF_ASSIGN_OR_RETURN( + std::shared_ptr> rendezvous_values, + Rendezvous>( + /*name=*/start_rendezvous_key, /*key=*/clique_key, + /*value=*/rendezvous_value, /*num_threads=*/num_ranks, + rendezvous_fn)); + + // Wait for all devices to reach the start event. This indicates that all + // output buffers are ready for transfer. + for (auto& value : *rendezvous_values) { + TF_RETURN_IF_ERROR(stream.WaitFor(value.start_event)); } - return collectives->GroupEnd(); + return rendezvous_values; +} + +// Executes the rendezvous after the kernel finish. Waits for all devices to +// reach the end event. +absl::Status RendezvousAfterKernelFinish( + const GpuCliqueKey& clique_key, RankId rank, int64_t num_ranks, + se::Stream& stream, se::Event* end_event, + const std::shared_ptr>& rendezvous_values) { + // Record that this device has finished executing the kernel. + TF_RETURN_IF_ERROR(stream.RecordEvent(end_event)); + + // Do another rendezvous to make sure that we call RecordEvent for end_event + // before WaitFor on another stream. 
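+  // (On CUDA, waiting on an event that has never been recorded completes
+  // immediately, so skipping this barrier could let a fast device WaitFor a
+  // peer's end_event before the peer has issued RecordEvent for it.)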
+  std::string finish_rendezvous_key =
+      absl::StrFormat("finish one-shot all-reduce for rank %d, clique %s",
+                      rank.value(), clique_key.ToString());
+  TF_RETURN_IF_ERROR(Rendezvous(/*name=*/finish_rendezvous_key,
+                                /*key=*/clique_key,
+                                /*num_threads=*/num_ranks));
+
+  // Wait for all devices to reach the end event. This indicates that all
+  // updates from other devices have arrived.
+  for (auto& value : *rendezvous_values) {
+    TF_RETURN_IF_ERROR(stream.WaitFor(value.end_event));
+  }
+
+  return absl::OkStatus();
 }

-namespace impl {
+absl::Status RunOneShotAllReduce(const GpuCliqueKey& clique_key, RankId rank,
+                                 std::vector& buffers,
+                                 se::Stream& stream, Communicator* comm,
+                                 se::DeviceMemoryBase local_buffer,
+                                 se::Event* start_event, se::Event* end_event) {
+  int device_ordinal = stream.parent()->device_ordinal();
+  VLOG(3) << "Performing one-shot all-reduce from device ordinal: "
+          << device_ordinal;
+
+  // TODO(b/407736956): Support variadic all-reduce.
+  if (buffers.size() > 1) {
+    return absl::UnimplementedError(
+        "One-shot kernel does not support variadic all-reduce");
+  }
+  TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());
+
+  const DeviceBufferPair& buffer = buffers[0];
+
+  // Buffer assignment aliases the source buffer to the destination buffer.
+  // This works for the NCCL implementation, but for the one-shot kernel,
+  // input and output buffers should be different. We do not have enough
+  // information at buffer assignment time to change aliasing, so we allocate
+  // a new device buffer ourselves and copy the data to it.
+  // TODO(b/407736956): Fuse the copy into the one-shot kernel.
+  TF_RETURN_IF_ERROR(stream.MemcpyD2D(&local_buffer, buffer.source_buffer,
+                                      buffer.source_buffer.size()));
+
+  TF_ASSIGN_OR_RETURN(
+      std::shared_ptr> rendezvous_values,
+      RendezvousBeforeKernelStart(clique_key, rank, num_ranks, local_buffer,
+                                  stream, start_event, end_event));
+
+  absl::InlinedVector input_ptrs;
+  for (auto& value : *rendezvous_values) {
+    input_ptrs.push_back(value.input_buffer);
+  }
+
+  TF_RETURN_IF_ERROR(RunAllReduceKernel(&stream, buffer.element_type,
+                                        input_ptrs, buffer.destination_buffer,
+                                        num_ranks, buffer.element_count));
+
+  TF_RETURN_IF_ERROR(RendezvousAfterKernelFinish(
+      clique_key, rank, num_ranks, stream, end_event, rendezvous_values));
+
+  return absl::OkStatus();
+}

 absl::Status CheckImplementableInst(const HloInstruction* inst,
                                     Thunk::Kind reduction_op) {
@@ -93,7 +228,26 @@ CollectiveOpGroupMode GetGroupModeInst(HloInstType* inst) {
   return GetAllReduceConfigInst(inst).config.group_mode;
 }

-}  // namespace impl
+}  // namespace
+
+absl::Status RunAllReduce(GpuCollectives* collectives,
+                          ReductionKind reduction_kind,
+                          std::vector& buffers,
+                          se::Stream& stream, Communicator* comm) {
+  int device_ordinal = stream.parent()->device_ordinal();
+  VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;
+  TF_RETURN_IF_ERROR(
+      MaybeRegisterBuffers(collectives, stream.parent(), buffers, comm));
+
+  TF_RETURN_IF_ERROR(collectives->GroupStart());
+  for (DeviceBufferPair& buffer : buffers) {
+    TF_RETURN_IF_ERROR(comm->AllReduce(
+        buffer.source_buffer, buffer.destination_buffer, buffer.element_type,
+        buffer.element_count, reduction_kind, GpuCollectives::On(stream)));
+  }
+
+  return collectives->GroupEnd();
+}

 AllReduceReduceScatterThunkBase::AllReduceReduceScatterThunkBase(
     Thunk::Kind kind, ThunkInfo thunk_info, AllReduceConfig config,
@@ -108,22 +262,107 @@ AllReduceStartThunk::AllReduceStartThunk(ThunkInfo thunk_info, const
HloAllReduceInstruction* inst, std::vector buffers, bool p2p_memcpy_enabled) - : AllReduceReduceScatterThunkBase(Thunk::kAllReduceStart, thunk_info, - impl::GetAllReduceConfigInst(inst), - std::move(buffers), - IsGPUSyncCollective(*inst)) {} + : AllReduceReduceScatterThunkBase( + Thunk::kAllReduceStart, thunk_info, GetAllReduceConfigInst(inst), + std::move(buffers), IsGPUSyncCollective(*inst)), + one_shot_kernel_enabled_( + inst->GetModule() + ->config() + .debug_options() + .xla_gpu_unsupported_use_all_reduce_one_shot_kernel()) {} absl::Status AllReduceStartThunk::CheckImplementable( const HloAllReduceInstruction* inst, int64_t replica_count, int64_t partition_count) { return AddOpDescription( - impl::CheckImplementableInst(inst, Thunk::kAllReduceStart), inst, - replica_count, partition_count); + CheckImplementableInst(inst, Thunk::kAllReduceStart), inst, replica_count, + partition_count); } CollectiveOpGroupMode AllReduceStartThunk::GetGroupMode( const HloAllReduceInstruction* inst) { - return impl::GetGroupModeInst(inst); + return GetGroupModeInst(inst); +} + +absl::StatusOr AllReduceStartThunk::ShouldUseOneShotAllReduceKernel( + const GpuCliqueKey& clique_key, + const CollectiveCliques* collective_cliques) { + if (!one_shot_kernel_enabled_) { + return false; + } + + // TODO(b/407736956): Support variadic all-reduce. + if (buffers_.size() != 1) { + return false; + } + + int64_t num_elements = buffers_[0].element_count; + PrimitiveType element_type = config().operand_element_type[0]; + + int64_t input_size_bytes = + num_elements * ShapeUtil::ByteSizeOfPrimitiveType(element_type); + + // One-shot all-reduce is only beneficial for small inputs. + if (input_size_bytes > kMaxOneShotAllReduceSizeBytes) { + return false; + } + + TF_ASSIGN_OR_RETURN(bool peer_access_enabled, + collective_cliques->peer_access_enabled(clique_key)); + + // Check that peer access is enabled. 
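+  // The one-shot kernel reads the other participants' input buffers
+  // directly, which is only possible when every device in the clique can
+  // access its peers' memory.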
+ if (!peer_access_enabled) { + return false; + } + + return IsAllReduceKernelSupported(clique_key.num_local_participants(), + config().operand_element_type[0]); +} + +absl::Status AllReduceStartThunk::Initialize(const InitializeParams& params) { + TF_RETURN_IF_ERROR(CollectiveThunk::Initialize(params)); + + TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params)); + TF_ASSIGN_OR_RETURN( + GpuCliqueKey clique_key, + GetGpuCliqueKey(collectives, *params.collective_params, + config().replica_groups, config().group_mode, + GetAsyncStreamKind())); + + TF_ASSIGN_OR_RETURN( + bool use_one_shot_kernel, + ShouldUseOneShotAllReduceKernel(clique_key, params.collective_cliques)); + + if (use_one_shot_kernel) { + absl::MutexLock lock(&mutex_); + + if (!local_buffer_allocs_.contains(params.executor)) { + int64_t max_size = 0; + for (auto buffer : buffers_) { + max_size = std::max(max_size, buffer.source_buffer.size()); + } + + se::DeviceMemoryHandle local_buffer_alloc( + params.executor, params.executor->Allocate(max_size)); + + local_buffer_allocs_.emplace(params.executor, + std::move(local_buffer_alloc)); + } + + if (!start_events_.contains(params.executor)) { + TF_ASSIGN_OR_RETURN(std::unique_ptr event, + params.executor->CreateEvent()); + start_events_.emplace(params.executor, std::move(event)); + } + + if (!end_events_.contains(params.executor)) { + TF_ASSIGN_OR_RETURN(std::unique_ptr event, + params.executor->CreateEvent()); + end_events_.emplace(params.executor, std::move(event)); + } + } + + return absl::OkStatus(); } absl::Status AllReduceStartThunk::RunCollective( @@ -134,29 +373,52 @@ absl::Status AllReduceStartThunk::RunCollective( ConvertToDeviceBuffers(params, buffers_, config_.config.operand_element_type)); TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params)); - return ::xla::gpu::RunAllReduce(collectives, config_.reduction_kind, - device_buffers, stream, comm_handle.comm); + + TF_ASSIGN_OR_RETURN(bool use_one_shot_kernel, + ShouldUseOneShotAllReduceKernel( + comm_handle.clique_key, params.collective_cliques)); + + if (use_one_shot_kernel) { + se::Event* start_event = nullptr; + se::Event* end_event = nullptr; + se::DeviceMemoryBase local_buffer; + { + absl::MutexLock lock(&mutex_); + local_buffer = local_buffer_allocs_[stream.parent()].memory(); + start_event = start_events_[stream.parent()].get(); + end_event = end_events_[stream.parent()].get(); + } + + std::optional rank = + comm_handle.clique_key.rank(params.collective_params->global_device_id); + + return RunOneShotAllReduce(comm_handle.clique_key, *rank, device_buffers, + stream, comm_handle.comm, local_buffer, + start_event, end_event); + } + + return RunAllReduce(collectives, config_.reduction_kind, device_buffers, + stream, comm_handle.comm); } ReduceScatterStartThunk::ReduceScatterStartThunk( ThunkInfo thunk_info, const HloReduceScatterInstruction* inst, std::vector buffers, bool p2p_memcpy_enabled) - : AllReduceReduceScatterThunkBase(Thunk::kReduceScatterStart, thunk_info, - impl::GetAllReduceConfigInst(inst), - std::move(buffers), - IsGPUSyncCollective(*inst)) {} + : AllReduceReduceScatterThunkBase( + Thunk::kReduceScatterStart, thunk_info, GetAllReduceConfigInst(inst), + std::move(buffers), IsGPUSyncCollective(*inst)) {} /*static*/ absl::Status ReduceScatterStartThunk::CheckImplementable( const HloReduceScatterInstruction* inst, int64_t replica_count, int64_t partition_count) { return AddOpDescription( - impl::CheckImplementableInst(inst, Thunk::kReduceScatterStart), inst, + 
CheckImplementableInst(inst, Thunk::kReduceScatterStart), inst,
       replica_count, partition_count);
 }

 /*static*/ CollectiveOpGroupMode ReduceScatterStartThunk::GetGroupMode(
     const HloReduceScatterInstruction* inst) {
-  return impl::GetGroupModeInst(inst);
+  return GetGroupModeInst(inst);
 }

 absl::Status ReduceScatterStartThunk::RunCollective(
@@ -167,8 +429,8 @@ absl::Status ReduceScatterStartThunk::RunCollective(
       ConvertToDeviceBuffers(params, buffers_,
                              config_.config.operand_element_type));
   TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
-  return ::xla::gpu::RunReduceScatter(collectives, config_.reduction_kind,
-                                      device_buffers, stream, comm_handle.comm);
+  return RunReduceScatter(collectives, config_.reduction_kind, device_buffers,
+                          stream, comm_handle.comm);
 }

 absl::Status RunReduceScatter(GpuCollectives* collectives,
diff --git a/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.h b/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.h
index ac5422de15fa62..8dc2d14ba5cf6e 100644
--- a/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.h
+++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.h
@@ -17,15 +17,23 @@ limitations under the License.
 #define XLA_BACKENDS_GPU_RUNTIME_ALL_REDUCE_THUNK_H_

 #include
+#include
 #include

+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
+#include "xla/backends/gpu/collectives/gpu_clique_key.h"
 #include "xla/backends/gpu/collectives/gpu_collectives.h"
 #include "xla/backends/gpu/runtime/collective_thunk.h"
 #include "xla/core/collectives/communicator.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/service/collective_ops_utils.h"
+#include "xla/stream_executor/device_memory_handle.h"
+#include "xla/stream_executor/event.h"
 #include "xla/stream_executor/stream.h"

 namespace xla {
@@ -73,9 +81,34 @@ class AllReduceStartThunk : public AllReduceReduceScatterThunkBase {
   static CollectiveOpGroupMode GetGroupMode(
       const HloAllReduceInstruction* inst);

+  absl::StatusOr ShouldUseOneShotAllReduceKernel(
+      const GpuCliqueKey& clique_key,
+      const CollectiveCliques* collective_cliques);
+
+  absl::Status Initialize(const InitializeParams& params) override;
+
 protected:
   absl::Status RunCollective(const ExecuteParams& params, se::Stream& stream,
                              CommunicatorHandle comm_handle) override;
+
+ private:
+  bool one_shot_kernel_enabled_ = false;
+
+  absl::Mutex mutex_;
+
+  // Local buffer allocations to copy input data for the one-shot kernel.
+  absl::flat_hash_map
+      local_buffer_allocs_ ABSL_GUARDED_BY(mutex_);
+
+  // Events to synchronize streams on different devices at the start of the
+  // one-shot kernel.
+  absl::flat_hash_map>
+      start_events_ ABSL_GUARDED_BY(mutex_);
+
+  // Events to synchronize streams on different devices at the end of the
+  // one-shot kernel.
+  absl::flat_hash_map>
+      end_events_ ABSL_GUARDED_BY(mutex_);
 };

 // -----------------------------------------------------------------------------
diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index d3a1683fd35e92..43ce83ae9e346e 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -326,6 +326,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_pjrt_allow_auto_layout_in_hlo(false);
   opts.set_xla_gpu_enable_scatter_determinism_expander(true);
   opts.set_xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(false);
+  opts.set_xla_gpu_unsupported_use_all_reduce_one_shot_kernel(false);
   opts.set_xla_gpu_unsupported_use_ragged_all_to_all_one_shot_kernel(true);
   opts.set_xla_gpu_unsupported_enable_all_reduce_decomposer(false);
   opts.set_xla_gpu_experimental_pack_dot_operands_along_k_dimension(true);
@@ -2253,6 +2254,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list,
       debug_options->xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(),
       "Internal: Enable the RaggedAllToAllDecomposer, an experimental pass "
       "that rewrites ragged-all-to-all as a dense all-to-all operation."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_unsupported_use_all_reduce_one_shot_kernel",
+      bool_setter_for(
+          &DebugOptions::
+              set_xla_gpu_unsupported_use_all_reduce_one_shot_kernel),
+      debug_options->xla_gpu_unsupported_use_all_reduce_one_shot_kernel(),
+      "Internal: Enable the one-shot kernel for single-host all-reduce "
+      "operations."));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_unsupported_use_ragged_all_to_all_one_shot_kernel",
       bool_setter_for(
diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc
index a84445ff72a19f..5ab157d65790b7 100644
--- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc
+++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc
@@ -3146,5 +3146,88 @@ ENTRY main {
   EXPECT_EQ(ag_start->operand(0)->shape().layout().memory_space(), 1);
 }

+class AllReduceTest
+    : public CollectiveOpsWithFlagsBase,
+      public ::testing::WithParamInterface> {
+ public:
+  AllReduceTest()
+      : CollectiveOpsWithFlagsBase(std::get<0>(GetParam()),
+                                   /*enable_p2p_memcpy=*/false) {}
+
+ protected:
+  DebugOptions GetDebugOptionsForTest() const override {
+    DebugOptions opts = CollectiveOpsWithFlagsBase::GetDebugOptionsForTest();
+
+    opts.set_xla_gpu_unsupported_use_all_reduce_one_shot_kernel(
+        std::get<1>(GetParam()));
+
+    return opts;
+  }
+};
+
+TEST_P(AllReduceTest, AsyncAllReduce) {
+  const absl::string_view kModuleStr = R"(
+  HloModule test
+
+  apply_op {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    ROOT apply_op = f32[] add(x, y)
+  }
+
+  ENTRY test_computation {
+    param_0 = f32[65536] parameter(0)
+    ROOT all-reduce = f32[65536] all-reduce(param_0), to_apply=apply_op, replica_groups={{0,1}}
+  }
+  )";
+
+  const int64_t kNumReplicas = 2;
+  if (test_runner().device_count() < kNumReplicas) {
+    GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
+                 << test_runner().device_count() << " available)";
+  }
+
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+
+  int64_t num_elements =
+      module->entry_computation()->root_instruction()->shape().dimensions()[0];
+
+  Array input1({num_elements}), input2({num_elements});
+  input1.FillRandom(1.0f, 10.0f, /*seed=*/0);
+
input2.FillRandom(1.0f, 10.0f, /*seed=*/1); + Array expected_output({num_elements}); + expected_output.Each([&](absl::Span indices, float* val) { + *val = input1(indices) + input2(indices); + }); + + Literal input_literal1 = LiteralUtil::CreateFromArray(input1); + Literal input_literal2 = LiteralUtil::CreateFromArray(input2); + Literal expected_output_literal = + LiteralUtil::CreateFromArray(expected_output); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + HloTestBase::ExecuteReplicated(std::move(module), + {{&input_literal1}, {&input_literal2}}, + /*num_replicas=*/kNumReplicas, + /*run_hlo_passes=*/true, + /*device_assignment=*/nullptr)); + ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_output_literal, results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_output_literal, results[1])); +} + +INSTANTIATE_TEST_SUITE_P( + AllReduceTest, AllReduceTest, + ::testing::Combine(::testing::Bool(), ::testing::Bool()), + [](const ::testing::TestParamInfo>& info) { + return absl::StrCat(GetAsyncTestName(std::get<0>(info.param)), "_", + std::get<1>(info.param) ? "one_shot" : "nccl"); + }); + } // namespace } // namespace xla diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index af97d839525c49..0c32b62a0f472f 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -787,6 +787,10 @@ message DebugOptions { // TODO(b/390559452): Remove the flag once the feature is stable. bool xla_gpu_unsupported_enable_triton_multi_output_fusion = 382; + // Internal testing flag to enable one-shot kernel for single-host + // all-reduce operations. + bool xla_gpu_unsupported_use_all_reduce_one_shot_kernel = 387; + // Internal testing flag to enable one-shot kernel for single-host // ragged-all-to-all operations. bool xla_gpu_unsupported_use_ragged_all_to_all_one_shot_kernel = 375; @@ -1200,7 +1204,7 @@ message DebugOptions { // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 387 + // Next id: 388 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 028e650f1496d12ee0a680c0ab4b951be387defd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 02:02:37 -0700 Subject: [PATCH 0853/1324] Update GraphDef version to 2199. PiperOrigin-RevId: 748210353 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index e98dbab9f65a0e..e25c61fb6f0a5e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2198 // Updated: 2025/4/15 +#define TF_GRAPH_DEF_VERSION 2199 // Updated: 2025/4/16 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 88d786ed68ea6a4c0ae0889d8a5225c6dad7f454 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 16 Apr 2025 02:02:47 -0700 Subject: [PATCH 0854/1324] compat: Update forward compatibility horizon to 2025-04-16 PiperOrigin-RevId: 748210414 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 02ca88b73f8257..2ecb50367ed8df 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 15) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 16) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From bce8c64e27db1374b40e791fce002257a9813ed3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 02:11:12 -0700 Subject: [PATCH 0855/1324] Add DebugOptions to gpu backend base class. PiperOrigin-RevId: 748212735 --- .../xla/backends/autotuner/backends/gpu/BUILD | 1 + .../autotuner/backends/gpu/gpu_backend.h | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/autotuner/backends/gpu/BUILD b/third_party/xla/xla/backends/autotuner/backends/gpu/BUILD index 4de5fe9f3b1db2..60d643e2b5ce51 100644 --- a/third_party/xla/xla/backends/autotuner/backends/gpu/BUILD +++ b/third_party/xla/xla/backends/autotuner/backends/gpu/BUILD @@ -9,6 +9,7 @@ cc_library( name = "gpu_backend", hdrs = ["gpu_backend.h"], deps = [ + "//xla:xla_proto_cc", "//xla/backends/autotuner:backend", "//xla/hlo/ir:hlo", "//xla/service:compiler", diff --git a/third_party/xla/xla/backends/autotuner/backends/gpu/gpu_backend.h b/third_party/xla/xla/backends/autotuner/backends/gpu/gpu_backend.h index 46d23ba057fe6c..f3f25941276247 100644 --- a/third_party/xla/xla/backends/autotuner/backends/gpu/gpu_backend.h +++ b/third_party/xla/xla/backends/autotuner/backends/gpu/gpu_backend.h @@ -29,23 +29,32 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/tsl/platform/statusor.h" +#include "xla/xla.pb.h" namespace xla { // Abstract base class for GPU backends, implementing the Backend interface. class GpuBackend : public Backend { public: - // target_config and compiler should outlive the backend. + // target_config, debug_options and compiler should outlive the backend. 
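+  // (All three are held by reference or pointer rather than copied, so a
+  // shorter-lived argument would leave the backend dangling.)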
GpuBackend(absl::string_view name, - const Compiler::TargetConfig& target_config, Compiler* compiler) - : name_(name), target_config_(target_config), compiler_(compiler) {} + const Compiler::TargetConfig* target_config, + const DebugOptions* debug_options, Compiler* compiler) + : name_(name), + target_config_(*target_config), + debug_options_(*debug_options), + compiler_(compiler) {} absl::string_view name() const override { return name_; } + const Compiler::TargetConfig& target_config() const { return target_config_; } + const DebugOptions& debug_options() const { return debug_options_; } + absl::StatusOr> Compile( const HloInstruction& hlo_instruction, const BackendConfig& config) override { TF_ASSIGN_OR_RETURN(auto hlo_module, WrapInModule(hlo_instruction, config)); + hlo_module->mutable_config().set_debug_options(debug_options_); Compiler::CompileOptions options; options.target_config = target_config_; @@ -70,6 +79,7 @@ class GpuBackend : public Backend { std::string name_; const Compiler::TargetConfig& target_config_; + const DebugOptions& debug_options_; // TODO(b/407494653): remove compiler when we don't need to run any HLO passes // and the codegen backend can directly produce an executable without a // compiler instance. From ee1c384df386dc2eab0971564f4774e4aacc7c46 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 16 Apr 2025 02:38:41 -0700 Subject: [PATCH 0856/1324] NFC: Use the free function variants for dyn_cast/cast/isa/.... The member functions in Type/Attribute/Value/Location/AffineExpr got [removed](https://github.com/llvm/llvm-project/commit/0078cf79adc2f24a168bc774cba1f39dda5e3752). PiperOrigin-RevId: 748219915 --- .../compiler/mlir/lite/flatbuffer_export.cc | 78 +++++++++++-------- .../lite/stablehlo/odml_converter/folders.cc | 6 +- .../stablehlo/transforms/optimize_layout.cc | 4 +- .../compiler/mlir/lite/transforms/quantize.cc | 4 +- tensorflow/compiler/mlir/lite/utils/utils.h | 17 ++-- .../stablehlo/passes/insert_weight_param.cc | 4 +- .../passes/insert_custom_aggregation_ops.cc | 2 +- .../tensorflow/utils/xla_sharding_util.cc | 2 +- .../internal/passes/tpu_cluster_formation.cc | 8 +- .../codegen/emitters/elemental_hlo_to_mlir.cc | 2 +- .../emitters/transforms/lower_tensors.cc | 2 +- .../transforms/vectorize_loads_stores.cc | 2 +- .../xla/xla/hlo/translate/hlo_to_mhlo/BUILD | 2 + .../translate/hlo_to_mhlo/async_importer.cc | 9 ++- .../hlo_to_mhlo/hlo_function_importer.cc | 41 +++++----- .../hlo/translate/hlo_to_mhlo/hlo_utils.cc | 5 +- .../tools/mlir_interpreter/dialects/arith.cc | 6 +- .../dialects/bufferization.cc | 2 +- .../tools/mlir_interpreter/dialects/linalg.cc | 4 +- .../tools/mlir_interpreter/dialects/mhlo.cc | 8 +- .../tools/mlir_interpreter/dialects/tensor.cc | 6 +- .../tools/mlir_interpreter/dialects/util.cc | 8 +- .../tools/mlir_interpreter/dialects/vector.cc | 13 ++-- .../mlir/tools/mlir_replay/mlir_replay_lib.cc | 10 ++- .../public/execution_trace_utils.cc | 6 +- .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 2 +- 26 files changed, 137 insertions(+), 116 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 3429c156f1551b..6045278ffa541e 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -1716,30 +1716,34 @@ void CreateFlexbufferVector( const std::unique_ptr& flex_builder, std::string& name, const mlir::Attribute& attr) { auto start = flex_builder->StartVector(name.c_str()); - auto array = 
attr.cast().getValue(); + auto array = mlir::cast(attr).getValue(); for (int i = 0; i < array.size(); i++) { if (llvm::isa(array[i])) { flex_builder->Bool(name.c_str(), - array[i].cast().getValue()); + mlir::cast(array[i]).getValue()); } else if (llvm::isa(attr)) { - flex_builder->String(name.c_str(), - array[i].cast().getValue().str()); + flex_builder->String( + name.c_str(), + mlir::cast(array[i]).getValue().str()); } else if (llvm::isa(array[i])) { - flex_builder->Bool(name.c_str(), - array[i].cast().getValue()); + flex_builder->Bool( + name.c_str(), + mlir::cast(array[i]).getValue()); } else if (llvm::isa(array[i])) { flex_builder->String( name.c_str(), - array[i].cast().getValue().str()); + mlir::cast(array[i]).getValue().str()); } else if (llvm::isa(array[i])) { - flex_builder->Int( - name.c_str(), - array[i].cast().getValue().getSExtValue()); + flex_builder->Int(name.c_str(), + mlir::cast(array[i]) + .getValue() + .getSExtValue()); } else if (llvm::isa(array[i])) { - flex_builder->Float( - name.c_str(), - array[i].cast().getValue().convertToFloat()); + flex_builder->Float(name.c_str(), + mlir::cast(array[i]) + .getValue() + .convertToFloat()); } else if (llvm::isa(array[i])) { CreateFlexbufferVector(flex_builder, name, array[i]); @@ -1835,43 +1839,49 @@ Translator::BuildVhloCompositeV1Op(mlir::vhlo::CompositeOpV1 composite_op, uint32_t opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_COMPOSITE); - int32_t api_version = composite_op.getVersion() - .cast() - .getValue() - .getSExtValue(); + int32_t api_version = + mlir::cast(composite_op.getVersion()) + .getValue() + .getSExtValue(); auto name = builder_.CreateString( - composite_op.getName().cast().getValue().str()); + mlir::cast(composite_op.getName()) + .getValue() + .str()); - auto composite_attributes = composite_op.getCompositeAttributes() - .cast(); + auto composite_attributes = mlir::cast( + composite_op.getCompositeAttributes()); auto flex_builder = std::make_unique(); size_t map_start = flex_builder->StartMap(); for (auto namedAttr : composite_attributes.getValue()) { auto name = - namedAttr.first.cast().getValue().str(); + mlir::cast(namedAttr.first).getValue().str(); auto attr = namedAttr.second; if (llvm::isa(attr)) - flex_builder->Bool(name.c_str(), attr.cast().getValue()); + flex_builder->Bool(name.c_str(), + mlir::cast(attr).getValue()); else if (llvm::isa(attr)) flex_builder->String(name.c_str(), - attr.cast().getValue().str()); + mlir::cast(attr).getValue().str()); else if (llvm::isa(attr)) - flex_builder->Bool(name.c_str(), - attr.cast().getValue()); + flex_builder->Bool( + name.c_str(), mlir::cast(attr).getValue()); else if (llvm::isa(attr)) flex_builder->String( - name.c_str(), attr.cast().getValue().str()); - else if (llvm::isa(attr)) - flex_builder->Int( name.c_str(), - attr.cast().getValue().getSExtValue()); + mlir::cast(attr).getValue().str()); + else if (llvm::isa(attr)) + flex_builder->Int(name.c_str(), + mlir::cast(attr) + .getValue() + .getSExtValue()); else if (llvm::isa(attr)) - flex_builder->Float( - name.c_str(), - attr.cast().getValue().convertToFloat()); + flex_builder->Float(name.c_str(), + mlir::cast(attr) + .getValue() + .convertToFloat()); else if (llvm::isa(attr)) CreateFlexbufferVector(flex_builder, name, attr); else if (llvm::isa(attr)) { @@ -1932,8 +1942,8 @@ Translator::BuildVhloCompositeV1Op(mlir::vhlo::CompositeOpV1 composite_op, flex_builder->Finish(); int32_t decomposition_subgraph_index = - subgraph_index_map_[composite_op.getDecomposition() - .cast() + 
subgraph_index_map_[mlir::cast( + composite_op.getDecomposition()) .getValue() .str()]; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc index cb48050db47cb5..778e76c79c984b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc @@ -104,7 +104,7 @@ static LogicalResult FoldDivOpInternal(stablehlo::DivOp op, } auto res_attr = DenseElementsAttr::get( - const_oprs[0].getType().cast(), res); + mlir::cast(const_oprs[0].getType()), res); rewriter.replaceOpWithNewOp(adaptor.value().Op(), res_attr); return success(); @@ -112,10 +112,10 @@ static LogicalResult FoldDivOpInternal(stablehlo::DivOp op, static LogicalResult FoldDivOp(stablehlo::DivOp op, PatternRewriter& rewriter) { auto etype = op.getType().getElementType(); - if (etype.isa()) { + if (mlir::isa(etype)) { return FoldDivOpInternal(op, rewriter); } - if (etype.isa()) { + if (mlir::isa(etype)) { return FoldDivOpInternal(op, rewriter); } return failure(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc index d251f49cfa28bf..b0bbeb57c5a6ac 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc @@ -91,7 +91,7 @@ struct TransposeCommuteWithPad : public OpRewritePattern { LogicalResult matchAndRewrite(stablehlo::PadOp pad_op, PatternRewriter& rewriter) const override { Value pad_input = pad_op.getOperand(); - RankedTensorType pad_type = pad_op.getType().cast(); + RankedTensorType pad_type = mlir::cast(pad_op.getType()); auto transpose_op = pad_input.getDefiningOp(); if (!transpose_op || !transpose_op->hasOneUse()) return failure(); @@ -132,7 +132,7 @@ struct TransposeCommuteWithReduceWindow Value reduce_input = inputs[0]; RankedTensorType reduce_type = - reduce_op.getResultTypes()[0].cast(); + mlir::cast(reduce_op.getResultTypes()[0]); auto transpose_op = reduce_input.getDefiningOp(); if (!transpose_op || !transpose_op->hasOneUse()) return failure(); diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index be0c1803543140..eb545059806905 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -236,7 +236,7 @@ class StrictQuantizationPattern : public RewritePattern { inputs.reserve(quantizing_op->getNumOperands()); for (auto operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (mlir::isa(operand_type)) { inputs.push_back(operand); continue; } @@ -292,7 +292,7 @@ class StrictQuantizationPattern : public RewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none // type results. 
- if (result_type.isa()) { + if (mlir::isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h index 3cf71da1d31ac7..2d2fcd73fc39ff 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -83,7 +83,7 @@ inline bool OpHasSameStaticShapes(Operation* op) { int operand_num = 0; ArrayRef shape; for (Value value : values) { - auto shaped_type = value.getType().dyn_cast(); + auto shaped_type = mlir::dyn_cast(value.getType()); if (!shaped_type || !shaped_type.hasStaticShape()) { return false; } @@ -247,14 +247,14 @@ inline bool IsReshapeEquivalentToTranspose(mlir::ShapedType input_type, // Checks if all elements in the constant attribute value are 1. inline bool IsAllOnesConstant(Attribute value) { - auto values = value.cast().getValues(); + auto values = mlir::cast(value).getValues(); return !std::any_of(values.begin(), values.end(), [](int32_t element_value) { return element_value != 1; }); } // Checks if all elements in the constant attribute value are non-negative. inline bool HasNonNegativeValues(Attribute value) { - auto values = value.cast().getValues(); + auto values = mlir::cast(value).getValues(); return !std::any_of( values.begin(), values.end(), [](const APInt& element_value) { return element_value.isNegative(); }); @@ -262,8 +262,8 @@ inline bool HasNonNegativeValues(Attribute value) { // Utility function to get the offset between two dense attribute values. inline TypedAttr GetOffSet(Attribute begin, Attribute end) { - auto begin_values = begin.cast().getValues(); - auto end_values = end.cast().getValues(); + auto begin_values = mlir::cast(begin).getValues(); + auto end_values = mlir::cast(end).getValues(); SmallVector offsets; if (begin_values.size() == end_values.size()) { @@ -301,7 +301,7 @@ inline bool AreLastTwoDimsTransposed(Value permutation) { // Gets the new type after transposing the last 2 dimensions. inline Type TransposeLastTwoDims(Type type) { - auto shaped_type = type.dyn_cast(); + auto shaped_type = mlir::dyn_cast(type); if (!shaped_type.hasStaticShape() || shaped_type.getRank() < 2) { return nullptr; } @@ -319,7 +319,7 @@ inline Type TransposeLastTwoDims(Type type) { // applying the permutation to the given shape through a transpose. inline mlir::ShapedType GetTransposedType( Value input, llvm::ArrayRef permutation_array) { - auto input_type = input.getType().cast(); + auto input_type = mlir::cast(input.getType()); if (permutation_array.size() != input_type.getRank()) { return nullptr; } @@ -371,7 +371,8 @@ inline mlir::ShapedType GetExpandedShapeType(Value input_val, int n) { // Returns a squeezed shape when `squeeze_leading_ones` is set to true. 
inline SmallVector GetShape(Value input_value, bool squeeze_leading_ones = false) { - auto output_shape = input_value.getType().dyn_cast().getShape(); + auto output_shape = + mlir::dyn_cast(input_value.getType()).getShape(); SmallVector shape; shape.reserve(output_shape.size()); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc index ac8649835f78ff..fb2e5caba7b59f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc @@ -219,7 +219,7 @@ class InsertWeightParamPattern dimension_numbers.getRhsContractingDimensions(); ArrayRef rhs_batching_dims = dimension_numbers.getRhsBatchingDimensions(); - int64_t rank = dot.getRhs().getType().cast().getRank(); + int64_t rank = mlir::cast(dot.getRhs().getType()).getRank(); for (int i = 0; i < rank; ++i) { // Return the first non-contracting, non-batching dimension of rhs. if (llvm::find(rhs_contracting_dims, i) == rhs_contracting_dims.end() && @@ -228,7 +228,7 @@ class InsertWeightParamPattern } } } - return op.getOperand(1).getType().cast().getRank() - 1; + return mlir::cast(op.getOperand(1).getType()).getRank() - 1; } }; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index 886f9cd28a127b..ec7ffefd2d43f7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -86,7 +86,7 @@ std::optional GetCompsiteFunctionName(Operation *op) { return entry_function_attr.getValue(); } else { TF::PartitionedCallOp call_op = dyn_cast_or_null(op); - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); if (!f_attr) return std::nullopt; return f_attr.getValue(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index ac8ecf1090b2a7..b87afe63412551 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -380,7 +380,7 @@ mlir::LogicalResult HandleTileShardedInputsUsingXlaSplitOps( std::vector paddings; paddings.reserve(rank); auto shape = llvm::to_vector<4>( - original_source.getType().cast().getShape()); + mlir::cast(original_source.getType()).getShape()); for (int dim = 0; dim < rank; ++dim) { paddings.push_back( GetPadding(dim, input_sharding.tile_assignment_dimensions(dim), diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc index db39ca12d9ce91..0da0cc4fc4ddfb 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc @@ -872,7 +872,7 @@ LogicalResult FormClustersInBlock( block, cluster_ops, results, cluster_successor_ops.getArrayRef()); auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr); - if (!num_replicas || !num_replicas.isa()) + if (!num_replicas || !mlir::isa(num_replicas)) return cluster.emitError() << "requires '" << kNumReplicasAttr << "' int attribute"; @@ -881,9 +881,9 @@ 
LogicalResult FormClustersInBlock( cluster_metadata->getSecond().get(kNumCoresPerReplicaAttr)); if (num_cores_per_replica_attr) num_cores_per_replica = num_cores_per_replica_attr.getInt(); - if (failed(ReplicateCluster(cluster, - num_replicas.cast().getInt(), - num_cores_per_replica))) + if (failed(ReplicateCluster( + cluster, mlir::cast(num_replicas).getInt(), + num_cores_per_replica))) return mlir::failure(); // Copy TPUReplicateMetadata attributes to `tf_device.cluster`. diff --git a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc index 79bca0e4e850fb..e710fc51e4e053 100644 --- a/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc @@ -524,7 +524,7 @@ absl::StatusOr> EmitDotLoop( const mlir::Type accumulator_type = result_element_type.isBF16() ? b.getF32Type() : result_element_type; Value accum_init_value; - if (auto complex_ty = accumulator_type.dyn_cast()) { + if (auto complex_ty = mlir::dyn_cast(accumulator_type)) { // For complex, build real-zero and imag-zero separately: mlir::Type element_ty = complex_ty.getElementType(); diff --git a/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc b/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc index f4cddfe8c0b9fb..ec9e6bfcc46d26 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc +++ b/third_party/xla/xla/codegen/emitters/transforms/lower_tensors.cc @@ -147,7 +147,7 @@ std::optional GetAlignmentFromArg(Value addr, ValueRange indices) { auto align_attr = func.getArgAttr(base.getArgNumber(), ml::LLVMDialect::getAlignAttrName()); if (!align_attr) return std::nullopt; - return align_attr.cast().getValue().getSExtValue(); + return mlir::cast(align_attr).getValue().getSExtValue(); } template diff --git a/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc b/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc index 0e4524087757a5..fa0e1e4afcfe6d 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc +++ b/third_party/xla/xla/codegen/emitters/transforms/vectorize_loads_stores.cc @@ -101,7 +101,7 @@ int64_t GetAlignmentOfRemainder(mlir::AffineExpr expr, std::optional rhs_cst = std::nullopt; if (binop.getRHS().getKind() == mlir::AffineExprKind::Constant) { - rhs_cst = binop.getRHS().cast().getValue(); + rhs_cst = mlir::cast(binop.getRHS()).getValue(); } switch (binop.getKind()) { diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD index 678814e990e69a..eb81655b95072e 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -49,6 +49,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@stablehlo//:stablehlo_ops", ], ) @@ -193,6 +194,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:SparseTensorEnums", + "@llvm-project//mlir:Support", "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc index bbd49177867e74..c453cd2e1b172a 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc 
@@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -81,7 +82,7 @@ absl::StatusOr ImportOldStyleAsyncStart( return tsl::errors::InvalidArgument( "expected async_bundle tuple result type"); } - auto result_types = result_type.cast().getTypes(); + auto result_types = mlir::cast(result_type).getTypes(); if (result_types.size() < 2) { return tsl::errors::InvalidArgument( "async_bundle must contain at least two values"); @@ -213,7 +214,7 @@ absl::StatusOr ImportSend( // format of (args, results, scratchpad), so to rewrite the `send` and // `send-done` ops to use the new-style async API, we need to reorder the // arguments to be in (args, token, sync flag) order. - auto result_types = result_type.cast().getTypes(); + auto result_types = mlir::cast(result_type).getTypes(); if (result_types.size() != 3) return InvalidArgument("send should return a 3-tuple"); auto async_arg_type = mlir::TupleType::get( @@ -447,8 +448,8 @@ absl::StatusOr ImportCopyStart( *cross_program_prefetch_index))); // Cross-program prefetch allows copy ops to accept tuples, in which // case, we need to double-wrap inputs and outputs in tuples. - if (operands[0].getType().isa()) { - auto result_types = result_type.cast().getTypes(); + if (mlir::isa(operands[0].getType())) { + auto result_types = mlir::cast(result_type).getTypes(); result_type = mlir::TupleType::get( context, {mlir::TupleType::get(context, {result_types[0]}), diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc index 100bd69ed7b084..2bedc76a2f25db 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -330,7 +330,7 @@ void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands( } static bool IsNestedTupleInData(Type type) { - auto tuple_type = type.dyn_cast(); + auto tuple_type = mlir::dyn_cast(type); if (!tuple_type) return false; assert((llvm::isa(tuple_type.getType(1)) || @@ -338,11 +338,11 @@ static bool IsNestedTupleInData(Type type) { "Infeed: Non token type"); auto data_type = tuple_type.getType(0); - auto data_tuple_type = data_type.dyn_cast(); + auto data_tuple_type = mlir::dyn_cast(data_type); if (!data_tuple_type) return false; for (auto child_type : data_tuple_type.getTypes()) { - if (child_type.isa()) return true; + if (mlir::isa(child_type)) return true; } return false; @@ -360,7 +360,7 @@ mlir::Attribute GetFrontendAttributes(mlir::Builder& b, void HloFunctionImporter::FlattenTupleType( Type type, llvm::SmallVectorImpl& flattened_types) { - auto tuple_type = type.dyn_cast(); + auto tuple_type = mlir::dyn_cast(type); if (!tuple_type) { flattened_types.push_back(type); return; @@ -660,7 +660,7 @@ absl::Status HloFunctionImporter::ImportInstructions( int flatten_idx = 0; for (Type computation_arg_type : computation_arg_types) { auto orig_tuple_arg_type = - computation_arg_type.dyn_cast(); + mlir::dyn_cast(computation_arg_type); // If the computation-parameter type is non-tuple, no action is needed. 
if (!orig_tuple_arg_type) { @@ -824,7 +824,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( if (instruction->opcode() == HloOpcode::kAsyncStart) { auto bundle_result_type = mlir::mhlo::AsyncBundleType::get( - context_, result_type.cast().getTypes()); + context_, mlir::cast(result_type).getTypes()); // XLA Feature -- MHLO Only return func_builder ->create(loc, bundle_result_type, @@ -832,7 +832,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( .getOperation(); } else if (instruction->opcode() == HloOpcode::kAsyncUpdate) { auto bundle_result_type = mlir::mhlo::AsyncBundleType::get( - context_, result_type.cast().getTypes()); + context_, mlir::cast(result_type).getTypes()); // XLA Feature -- MHLO Only return func_builder ->create(loc, bundle_result_type, @@ -1198,7 +1198,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } else { mlir::Attribute attr = mlir::parseAttribute(raw_backend_config, builder_->getContext()); - if (!attr.isa()) + if (!mlir::isa(attr)) return Internal( "Couldn't parse backend config into a dictionary attribute"); @@ -1451,7 +1451,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( auto sort_instruction = Cast(instruction); llvm::SmallVector return_types = {result_type}; - if (mlir::TupleType tuple_ty = result_type.dyn_cast()) { + if (mlir::TupleType tuple_ty = + mlir::dyn_cast(result_type)) { return_types = llvm::to_vector<6>(tuple_ty.getTypes()); } @@ -1473,8 +1474,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( auto topk_instruction = Cast(instruction); // XLA Feature -- MHLO Only auto topk_op = func_builder->create( - loc, result_type.dyn_cast().getTypes(), operands[0], - builder_->getI64IntegerAttr(topk_instruction->k()), + loc, mlir::dyn_cast(result_type).getTypes(), + operands[0], builder_->getI64IntegerAttr(topk_instruction->k()), builder_->getBoolAttr(topk_instruction->largest())); return WrapInTuple(func_builder, topk_op); } @@ -1513,7 +1514,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( flattened_operands.end()); mlir::Type pred_or_index_type = - operands[0].getType().cast().getElementType(); + mlir::cast(operands[0].getType()).getElementType(); // It is a predicated conditional if first argument is a boolean and // should be mapped to If op. if (pred_or_index_type.isInteger(1)) { @@ -1580,7 +1581,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kAllGather: { auto all_gather = Cast(instruction); - auto result_tuple_ty = result_type.dyn_cast(); + auto result_tuple_ty = mlir::dyn_cast(result_type); llvm::SmallVector result_types = {result_type}; if (result_tuple_ty) { @@ -1613,7 +1614,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kAllReduce: { auto all_reduce = Cast(instruction); - auto result_tuple_ty = result_type.dyn_cast(); + auto result_tuple_ty = mlir::dyn_cast(result_type); llvm::SmallVector result_types = {result_type}; if (result_tuple_ty) { @@ -1713,7 +1714,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( // operands are corresponding initial values. 
size_t num_inputs = operands.size() / 2; llvm::SmallVector return_types = {result_type}; - if (mlir::TupleType tuple_ty = result_type.dyn_cast()) { + if (mlir::TupleType tuple_ty = + mlir::dyn_cast(result_type)) { return_types = llvm::to_vector<6>(tuple_ty.getTypes()); } @@ -1747,7 +1749,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kRng: { auto shape = func_builder->create( - loc, Convert(result_type.cast().getShape())); + loc, Convert(mlir::cast(result_type).getShape())); switch (instruction->random_distribution()) { case RNG_UNIFORM: return func_builder @@ -1905,7 +1907,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kReduceWindow: { llvm::SmallVector return_types = {result_type}; - if (mlir::TupleType tuple_ty = result_type.dyn_cast()) { + if (mlir::TupleType tuple_ty = + mlir::dyn_cast(result_type)) { return_types = llvm::to_vector<6>(tuple_ty.getTypes()); } llvm::SmallVector sizes, strides, base_dilations, @@ -2084,10 +2087,10 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( case HloOpcode::kConvert: { // Convert to boolean is special, it requires a comparison to 0 instead of // a truncation to i1, otherwise it is a 1-1 translation. - auto ranked_type = result_type.dyn_cast(); + auto ranked_type = mlir::dyn_cast(result_type); mlir::IntegerType integer_type = (ranked_type) - ? ranked_type.getElementType().dyn_cast() + ? mlir::dyn_cast(ranked_type.getElementType()) : nullptr; if (!integer_type || integer_type.getWidth() != 1) { // Simple case: 1-1 mapping. diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index 5bd16a05ee9489..983efbc7987f29 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -195,7 +196,7 @@ bool HasMhloTokenType(mlir::TypeRange types) { mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, mlir::ValueRange& flatten_values, mlir::Type type) { - auto tuple_type = type.dyn_cast(); + auto tuple_type = mlir::dyn_cast(type); if (!tuple_type) { assert(!flatten_values.empty()); auto retval = flatten_values.front(); @@ -220,7 +221,7 @@ mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Operation* op, mlir::Type type) { - if (!type.isa()) return op; + if (!mlir::isa(type)) return op; mlir::ValueRange flattened_results_ref(op->getResults()); auto result = diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/arith.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/arith.cc index 3dcf6f601778f6..5cf5fe5a6454d0 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/arith.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/arith.cc @@ -78,10 +78,10 @@ InterpreterValue Constant(InterpreterState&, arith::ConstantOp constant) { } auto value = constant.getValue(); - if (auto integer = value.dyn_cast()) { + if (auto integer = mlir::dyn_cast(value)) { return {static_cast(integer.getInt())}; } - if (auto float_value = value.dyn_cast()) { + if (auto float_value = mlir::dyn_cast(value)) { return {static_cast(float_value.getValueAsDouble())}; } @@ -135,7 +135,7 @@ llvm::SmallVector UiToFP( MutableArrayRef args, mlir::Operation* op, InterpreterState&) { if (args[0].IsTensor()) { - auto ty = op->getResultTypes()[0].cast(); + auto ty = mlir::cast(op->getResultTypes()[0]); return {DispatchScalarType( ty.getElementType(), [&](auto dummy) -> InterpreterValue { auto result = diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/bufferization.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/bufferization.cc index 7feb91003c4a28..f5000e17d1b37d 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/bufferization.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/bufferization.cc @@ -45,7 +45,7 @@ InterpreterValue AllocTensor( InterpreterState&, bufferization::AllocTensorOp alloc, ArrayRef dynamic_sizes, std::optional copy, const std::optional& /*sizeHint*/) { - auto ty = alloc->getResultTypes().front().cast(); + auto ty = mlir::cast(alloc->getResultTypes().front()); auto shape = ReplaceDynamicVals(ty.getShape(), dynamic_sizes); if (copy) { diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/linalg.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/linalg.cc index 3974460117095d..bab976cd949197 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/linalg.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/linalg.cc @@ -276,7 +276,7 @@ SmallVector Dot(InterpreterState&, linalg::DotOp op, InterpreterValue acc) { const auto& lhs = inputs[0]; const auto& rhs = inputs[1]; - if (op.getOutputs()[0].getType().isa()) { + if (mlir::isa(op.getOutputs()[0].getType())) { acc = acc.Clone(); } DispatchScalarType(op.getOutputs()[0].getType(), [&](auto dummy) { @@ -300,7 +300,7 @@ SmallVector Vecmat(InterpreterState&, linalg::VecmatOp op, InterpreterValue acc) { const auto& lhs = inputs[0]; const auto& rhs = inputs[1]; - if 
(op.getOutputs()[0].getType().isa()) { + if (mlir::isa(op.getOutputs()[0].getType())) { acc = acc.Clone(); } DispatchScalarType(op.getOutputs()[0].getType(), [&](auto dummy) { diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/mhlo.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/mhlo.cc index 37d457ec768a0d..65abfec99d3c16 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/mhlo.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/mhlo.cc @@ -137,7 +137,7 @@ InterpreterValue Concatenate(InterpreterState&, mhlo::ConcatenateOp concat, InterpreterValue Reshape(InterpreterState&, mhlo::ReshapeOp reshape, const InterpreterValue& in) { - auto ty = reshape->getResultTypes()[0].cast(); + auto ty = mlir::cast(reshape->getResultTypes()[0]); return ReshapeTensor(in, ty.getShape()); } @@ -214,7 +214,7 @@ InterpreterValue Slice(InterpreterState&, mhlo::SliceOp slice, llvm::SmallVector Constant(InterpreterState&, mhlo::ConstantOp constant) { - auto ty = constant->getResultTypes()[0].cast(); + auto ty = mlir::cast(constant->getResultTypes()[0]); return {DispatchScalarType(ty, [&](auto dummy) -> InterpreterValue { if (ty.getElementType().isUnsignedInteger()) { if constexpr (!std::is_same_v && @@ -507,7 +507,7 @@ InterpreterValue Transpose(InterpreterState&, mhlo::TransposeOp transpose, InterpreterValue Iota(InterpreterState&, mhlo::IotaOp iota) { auto dim = iota.getIotaDimension(); - auto ty = iota->getResultTypes()[0].cast(); + auto ty = mlir::cast(iota->getResultTypes()[0]); return DispatchScalarType(ty, [&](auto dummy) -> InterpreterValue { auto result = TensorOrMemref::Empty({ty.getShape()[dim]}); for (const auto& index : result.view.Indices()) { @@ -834,7 +834,7 @@ InterpreterValue DotGeneral(InterpreterState&, mhlo::DotGeneralOp op, lhs, rhs, dims.getLhsContractingDimensions(), dims.getRhsContractingDimensions(), dims.getLhsBatchingDimensions(), dims.getRhsBatchingDimensions(), - op->getResultTypes()[0].cast().getElementType()); + mlir::cast(op->getResultTypes()[0]).getElementType()); } // TODO(jreiffers): Migrate remaining ops to the safer signature. 
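The hunks in this change all apply the same mechanical rewrite. A minimal
before/after sketch of the pattern (ShapedType and StringAttr are just
illustrative stand-ins, not taken from any particular hunk):

  // Before: member-function casts on Type/Attribute, since removed upstream
  // in LLVM.
  auto shaped = type.dyn_cast<mlir::ShapedType>();
  bool is_str = attr.isa<mlir::StringAttr>();

  // After: the equivalent free functions, available in the mlir namespace
  // via "mlir/Support/LLVM.h".
  auto shaped = mlir::dyn_cast<mlir::ShapedType>(type);
  bool is_str = mlir::isa<mlir::StringAttr>(attr);

The behavior is unchanged; only the call syntax differs.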
diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/tensor.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/tensor.cc index e7b45f5d33c42e..c35bfc8b87081c 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/tensor.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/tensor.cc @@ -40,7 +40,7 @@ int64_t Dim(InterpreterState& state, tensor::DimOp, InterpreterValue Empty(InterpreterState&, tensor::EmptyOp op, ArrayRef dynamic_sizes) { - auto ty = op->getResultTypes().front().cast(); + auto ty = mlir::cast(op->getResultTypes().front()); auto shape = ReplaceDynamicVals(ty.getShape(), dynamic_sizes); return InterpreterValue::MakeTensor(ty.getElementType(), shape); } @@ -56,7 +56,7 @@ InterpreterValue Extract(InterpreterState& state, tensor::ExtractOp, InterpreterValue FromElements(InterpreterState&, tensor::FromElementsOp op, MutableArrayRef elements) { - auto ty = op->getResultTypes().front().cast(); + auto ty = mlir::cast(op->getResultTypes().front()); auto result = InterpreterValue::MakeTensor(ty.getElementType(), llvm::to_vector(ty.getShape())); for (auto [index, element] : llvm::zip(result.View().Indices(), elements)) { @@ -186,7 +186,7 @@ llvm::SmallVector InsertSlice( InterpreterValue Generate(InterpreterState& state, tensor::GenerateOp generate, ArrayRef dynamic_sizes) { - auto ty = generate->getResultTypes().front().cast(); + auto ty = mlir::cast(generate->getResultTypes().front()); auto sizes = ReplaceDynamicVals(ty.getShape(), dynamic_sizes); auto result = InterpreterValue::MakeTensor(ty.getElementType(), sizes); diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/util.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/util.cc index 1bc5398aa666c6..7a638df252926a 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/util.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/util.cc @@ -137,7 +137,7 @@ llvm::SmallVector NoOpTerminator( int64_t EvalAffineExpr(AffineExpr expr, ArrayRef dims, ArrayRef symbols) { int64_t lhs = 0, rhs = 0; - if (auto bin = expr.dyn_cast()) { + if (auto bin = mlir::dyn_cast(expr)) { lhs = EvalAffineExpr(bin.getLHS(), dims, symbols); rhs = EvalAffineExpr(bin.getRHS(), dims, symbols); } @@ -153,11 +153,11 @@ int64_t EvalAffineExpr(AffineExpr expr, ArrayRef dims, case AffineExprKind::CeilDiv: return llvm::divideCeilSigned(lhs, rhs); case AffineExprKind::Constant: - return expr.cast().getValue(); + return mlir::cast(expr).getValue(); case AffineExprKind::DimId: - return dims[expr.cast().getPosition()]; + return dims[mlir::cast(expr).getPosition()]; case AffineExprKind::SymbolId: - return symbols[expr.cast().getPosition()]; + return symbols[mlir::cast(expr).getPosition()]; } } diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc index 80d851d89de6c3..4f08a10a18bb4b 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc @@ -242,7 +242,7 @@ InterpreterValue Contract(InterpreterState&, vector::ContractionOp contraction, contraction.getIterationBounds(iter.sizes); auto maps = contraction.getIndexingMapsArray(); auto result_ty = contraction->getResultTypes()[0]; - auto shaped_ty = result_ty.dyn_cast(); + auto shaped_ty = mlir::dyn_cast(result_ty); auto result = DispatchScalarType(result_ty, [&](auto dummy) -> InterpreterValue { using T = decltype(dummy); @@ 
-621,8 +621,8 @@ InterpreterValue ShapeCast(InterpreterState&, vector::ShapeCastOp op, const InterpreterValue& in) { auto out = in.CoerceLayout({}); auto& out_view = out.View(); - out_view.sizes = - llvm::to_vector(op->getResultTypes()[0].cast().getShape()); + out_view.sizes = llvm::to_vector( + mlir::cast(op->getResultTypes()[0]).getShape()); out_view.strides = BufferView::GetDefaultStrides(out_view.sizes); return out; } @@ -656,8 +656,8 @@ InterpreterValue Splat(InterpreterState&, vector::SplatOp op, const InterpreterValue& in) { auto out = in.AsUnitTensor(/*is_vector=*/true); auto& view = out.View(); - view.sizes = - llvm::to_vector(op->getResultTypes()[0].cast().getShape()); + view.sizes = llvm::to_vector( + mlir::cast(op->getResultTypes()[0]).getShape()); view.strides = SmallVector(view.sizes.size(), 0); return out; } @@ -805,7 +805,8 @@ llvm::SmallVector TransferWrite( src_view.num_dimensions() && "expected matching number of results"); - dst = transfer.getSource().getType().isa() ? dst.Clone() : dst; + dst = + mlir::isa(transfer.getSource().getType()) ? dst.Clone() : dst; auto dst_slice = ExtractMemorySlice(state, transfer.getPermutationMap(), dst, src, offsets, transfer.getInBounds()); if (!dst_slice) { diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc b/third_party/xla/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc index 6b25898377756f..74e72759f621e0 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc +++ b/third_party/xla/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc @@ -78,7 +78,7 @@ template class rng_t> mlir::interpreter::InterpreterValue RandomTensor(absl::BitGenRef bitgen, mlir::Type type) { llvm::SmallVector shape; - auto shaped_ty = type.dyn_cast(); + auto shaped_ty = mlir::dyn_cast(type); if (shaped_ty) { shape = llvm::to_vector(shaped_ty.getShape()); } @@ -102,8 +102,9 @@ mlir::interpreter::InterpreterValue RandomTensor(absl::BitGenRef bitgen, mlir::FailureOr MakeRandomInput( absl::BitGenRef bitgen, mlir::Type type) { - auto elem_ty = - type.isa() ? type.cast().getElementType() : type; + auto elem_ty = mlir::isa(type) + ? 
mlir::cast(type).getElementType() + : type; if (elem_ty.isF32()) { return RandomTensor(bitgen, type); } @@ -120,7 +121,8 @@ mlir::FailureOr MakeRandomInput( return RandomTensor(bitgen, type); } if (elem_ty.isInteger(1)) { - return {{TensorOrMemref::Empty(type.cast().getShape())}}; + return { + {TensorOrMemref::Empty(mlir::cast(type).getShape())}}; } llvm::errs() << "Unsupported type: "; diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc b/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc index 1584341e4f7cf5..8b7e0ec0c22d4d 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc @@ -177,7 +177,7 @@ void ExecutionTraceListener::LeaveRegion(ArrayRef yielded) { llvm::SmallVector ValueToAttribute( const InterpreterValue& value, mlir::Type type) { if (std::holds_alternative(value.storage)) { - auto types = type.cast().getTypes(); + auto types = mlir::cast(type).getTypes(); const auto& t = std::get(value.storage); llvm::SmallVector attrs; for (const auto& [v, ty] : llvm::zip(t.values, types)) { @@ -196,11 +196,11 @@ llvm::SmallVector ValueToAttribute( .getValues()[0]}; } - if (!type.isa()) { + if (!mlir::isa(type)) { return {}; } - auto shaped_ty = type.cast(); + auto shaped_ty = mlir::cast(type); return {DispatchScalarType(shaped_ty, [&](auto dummy) -> mlir::Attribute { using T = decltype(dummy); auto& t = std::get>(value.storage); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 998a7ac6a1969a..1b65ea14eaf838 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -3478,7 +3478,7 @@ OpFoldResult DynamicSliceOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); if (!operands[0]) return nullptr; - auto cst_attr = operands[0].dyn_cast(); + auto cst_attr = mlir::dyn_cast(operands[0]); if (cst_attr && cst_attr.isSplat()) { return cst_attr.resizeSplat(getResult().getType()); } From 4785cf12ac41b71210c2fbb565a33701270ca290 Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Wed, 16 Apr 2025 03:11:37 -0700 Subject: [PATCH 0857/1324] Delete a lonely `ragged_all_to_all_kernel_common.h`. PiperOrigin-RevId: 748227775 --- .../kernels/ragged_all_to_all_kernel_common.h | 31 ------------------- 1 file changed, 31 deletions(-) delete mode 100644 third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_common.h diff --git a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_common.h b/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_common.h deleted file mode 100644 index 05ac6ba8bb081e..00000000000000 --- a/third_party/xla/xla/service/gpu/kernels/ragged_all_to_all_kernel_common.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_KERNELS_RAGGED_ALL_TO_ALL_KERNEL_COMMON_H_ -#define XLA_SERVICE_GPU_KERNELS_RAGGED_ALL_TO_ALL_KERNEL_COMMON_H_ - -#include - -namespace xla::gpu { - -// Maximum number of output pointers that can be passed to the kernel. -inline constexpr int64_t kMaxNumRaggedAllToAllOutputPtrs = 8; - -template -void* GetRaggedAllToAllKernel(); - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_KERNELS_RAGGED_ALL_TO_ALL_KERNEL_COMMON_H_ From 6415be0c15c7f0fa2310530d5eb8a917cfb6bae0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 03:55:29 -0700 Subject: [PATCH 0858/1324] Automated Code Change PiperOrigin-RevId: 748237733 --- .../xla/xla/service/memory_space_assignment/allocation.cc | 4 ++-- .../memory_space_assignment/memory_space_assignment_test.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.cc b/third_party/xla/xla/service/memory_space_assignment/allocation.cc index 7a423dc36d507b..8cd5f7679ccf34 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation.cc +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.cc @@ -747,13 +747,13 @@ bool SlicedCopyAllocation::SliceDetail::operator==( absl::Status SlicedCopyAllocation::SliceDetail::CreateAsyncSlice( const Shape& original_shape, HloInstruction& producer, HloComputation& parent) { - if (original_shape.dimensions_size() != + if (original_shape.dimensions().size() != slice_decision.sizing.slice_params.size()) { return FailedPrecondition( "%s", absl::StrCat("The number of SlicedCopyAllocation parameters ", slice_decision.sizing.slice_params.size(), " does not match the rank ", - original_shape.dimensions_size(), + original_shape.dimensions().size(), " of the tensor we are slicing.")); } diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 887258c212dab9..bb50b34a309b05 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -12010,7 +12010,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { return false; } if (slice->slice_starts().size() != - copy_operand->shape().dimensions_size()) { + copy_operand->shape().dimensions().size()) { *listener << " has slice (" << slice->name() << "), with " << slice->slice_starts().size() From 3d2402282a4ce2d8494fa72f805de55419850785 Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Wed, 16 Apr 2025 04:16:59 -0700 Subject: [PATCH 0859/1324] [XLA:GPU] For SplitK rewrites, determine the accumulator type by the precision algorithm or the dot output. 
PiperOrigin-RevId: 748243385 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/split_k_gemm_rewriter.cc | 40 ++++++---- .../service/gpu/split_k_gemm_rewriter_test.cc | 76 ++++++++++++++++++- 3 files changed, 101 insertions(+), 16 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index de44c6df57fdcc..51e09b0abcb1f4 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -859,6 +859,7 @@ cc_library( "//xla/backends/gpu/codegen/triton:support", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:algorithm_util", "//xla/service:hlo_creation_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc index 53abd5d6fd0d78..d66798b64c2301 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/layout.h" #include "xla/literal_util.h" +#include "xla/service/algorithm_util.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/service/gpu/triton_fusion_analysis.h" @@ -128,19 +129,32 @@ absl::StatusOr MakeSparseMetaOperand( } PrimitiveType GetAccumulatorType(bool disable_reduced_precision_reduction, - HloComputation* computation, + HloDotInstruction* dot, HloInstruction* instr) { if (!disable_reduced_precision_reduction) { return instr->shape().element_type(); } - PrimitiveType output_type = - computation->root_instruction()->shape().element_type(); - PrimitiveType accumulator_type = output_type == PrimitiveType::F64 - ? PrimitiveType::F64 - : PrimitiveType::F32; - return accumulator_type; + // Return the accumulator type if it is explicitly specified as dot algorithm. + auto accumulator_type = algorithm_util::GetDotAccumulatorType( + dot->precision_config().algorithm()); + if (accumulator_type.ok()) { + return accumulator_type.value(); + } + // Otherwise, return the default accumulator type for the output type. + PrimitiveType output_type = dot->shape().element_type(); + switch (output_type) { + case PrimitiveType::F16: + case PrimitiveType::BF16: + return PrimitiveType::F32; + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::S32: + default: + return output_type; + } } + } // namespace absl::StatusOr MakeSplitKOperand( @@ -335,10 +349,10 @@ absl::Status MakeDotComputationSplitKBatch( MakeSparseMetaOperand(*dot, config)); } // Keep the precision of the accumulator type for the dot output. - PrimitiveType dot_dtype = GetAccumulatorType( - disable_reduced_precision_reduction, computation, dot); + PrimitiveType accumulator_dtype = + GetAccumulatorType(disable_reduced_precision_reduction, dot, current); expanded = MakeDotHlo(lhs, rhs, new_dim_numbers, dot->precision_config(), - dot_dtype, sparsity, sparse_meta) + accumulator_dtype, sparsity, sparse_meta) .value(); // Make the added batch dimension the major-most, keep the order of the // original dimensions. @@ -352,8 +366,8 @@ absl::Status MakeDotComputationSplitKBatch( dot->SetupDerivedInstruction(expanded); } else { // Propagate the precision of the accumulator to the GEMM fusion root. 
- PrimitiveType accumulator_dtype = GetAccumulatorType( - disable_reduced_precision_reduction, computation, current); + PrimitiveType accumulator_dtype = + GetAccumulatorType(disable_reduced_precision_reduction, dot, current); expanded = computation->AddInstruction( current->CloneWithNewShape(ShapeUtil::PrependMajorDimension( config.split_k, ShapeUtil::ChangeElementType( @@ -383,7 +397,7 @@ absl::Status MakeDotComputationSplitKBatch( // Broadcast the operand to the Split-K dimension and convert to the // accumulator dtype. auto accumulator_dtype = GetAccumulatorType( - disable_reduced_precision_reduction, computation, operand); + disable_reduced_precision_reduction, dot, operand); HloInstruction* convert = MakeConvertToHlo(operand, accumulator_dtype); std::vector broadcast_dimensions( operand->shape().dimensions().size()); diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc index d0a63b501feaf6..93f8c9593f1108 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -863,12 +863,82 @@ ENTRY e { HloInstruction* reduce; EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Convert(m::Op(&reduce) - .WithOpcode(HloOpcode::kReduce) - .WithOperand(0, m::Fusion())))); + GmockMatch(m::Convert().WithElementType(BF16).WithOperand( + 0, m::Op(&reduce) + .WithOpcode(HloOpcode::kReduce) + .WithElementType(F32) + .WithOperand(0, m::Fusion().WithElementType(F32))))) + << module->ToString(); EXPECT_EQ(reduce->metadata().op_name(), "foo"); } +TEST_F(SplitKTestWithMorePreciseReduction, MakeSplitKForInt32Dot) { + constexpr absl::string_view kHloText = R"( +HloModule t + +triton_gemm_dot { + parameter_0 = s8[480,128]{1,0} parameter(0) + parameter_1 = s8[16,128]{1,0} parameter(1) + ROOT dot.0 = s32[480,16]{1,0} dot(parameter_0, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = s8[480,128]{1,0} parameter(0) + p1 = s8[16,128]{1,0} parameter(1) + ROOT fusion = s32[480,16]{1,0} fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + + TritonGemmConfig config(16, 16, 16, 4, 1, 4); + TF_EXPECT_OK(MakeDotSplitKBatch( + module->entry_computation()->root_instruction(), config)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Reduce() + .WithElementType(S32) + .WithOpcode(HloOpcode::kReduce) + .WithOperand(0, m::Fusion().WithElementType(S32)))) + << module->ToString(); +} + +TEST_F(SplitKTestWithMorePreciseReduction, MakeSplitKHonorsDotAlgorithm) { + constexpr absl::string_view kHloText = R"( +HloModule t + +triton_gemm_dot { + parameter_0 = s8[3,128,5,32]{3,2,1,0} parameter(0) + bitcast.1 = s8[3,5,32,128]{2,1,3,0} bitcast(parameter_0) + copy.1 = s8[3,5,32,128]{3,2,1,0} copy(bitcast.1) + reshape.5 = s8[480,128]{1,0} reshape(copy.1) + convert.8 = bf16[480,128]{1,0} convert(reshape.5) + parameter_1 = bf16[16,128]{1,0} parameter(1) + ROOT dot.0 = bf16[480,16]{1,0} dot(convert.8, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + algorithm=dot_bf16_bf16_bf16 +} + +ENTRY e { + p0 = s8[3,128,5,32]{3,2,1,0} parameter(0) + p1 = bf16[16,128]{1,0} parameter(1) + ROOT fusion = bf16[480,16]{1,0} fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm" +})"; + 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + + TritonGemmConfig config(16, 16, 16, 4, 1, 4); + TF_EXPECT_OK(MakeDotSplitKBatch( + module->entry_computation()->root_instruction(), config)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Reduce().WithElementType(BF16).WithOperand( + 0, m::Fusion().WithElementType(BF16)))) + << module->ToString(); +} + TEST_F(SplitKTestWithMorePreciseReduction, MakeSplitKWithOutputFusion) { const std::string hlo_text = R"( HloModule t From 71d39ad59e7f9508ae5916127928594e463433a6 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Wed, 16 Apr 2025 04:26:24 -0700 Subject: [PATCH 0860/1324] Restrict output tilings to square-ish tiles It is very unlikely that anything except those would produce good results, since those are the only tiles that optimize data reuse for the given resource consumption. PiperOrigin-RevId: 748245698 --- .../gpu/autotuning/dot_search_space.cc | 24 ++++++++++++++++++- .../gpu/autotuning/dot_search_space_test.cc | 14 ++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc index f5295147feeeca..b9b258be13c2c2 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc @@ -402,7 +402,29 @@ void TritonDotFusionSearchSpace::AddOutputTilings( const int split = config.config.split_k; ConfigWithNotes new_config = config; for (int m = min_out_tile_.lhs_dim; m <= max_out_tile_.lhs_dim; m *= 2) { - for (int n = min_out_tile_.rhs_dim; n <= max_out_tile_.rhs_dim; n *= 2) { + int min_n = min_out_tile_.rhs_dim; + int max_n = max_out_tile_.rhs_dim; + // If there are square-ish tiles contained within the search space, it is + // extremely unlikely that a non-square-ish tile will perform better, since + // it does not optimize data reuse. The one exception to this is the + // edge-case where one of the dimensions is small: m >= LHS dim, or max_n >= + // RHS dim. + // + // Thus, as soon as there are square-ish tiles in the search space, and + // we're not in the edge case (i.e., m < LHS dim; the requirement on max_n + // is satisfied by construction as soon as [m/2, m*2] and [min_n, max_n] + // overlap), we can restrict the n-space to only these tiles. + auto overlaps = [](std::pair a, std::pair b) { + return !(a.second < b.first || b.second < a.first); + }; + if (m < lhs_parallel_size_ && overlaps({m / 2, m * 2}, {min_n, max_n})) { + min_n = std::max(m / 2, min_n); + max_n = std::min(m * 2, max_n); + VLOG(5) << "Computing output tile: For m = " << m + << ", restricting n-space to [" << min_n << "," << max_n + << "] to have square-ish tiles."; + } + for (int n = min_n; n <= max_n; n *= 2) { OutputTile tile = {m, n}; // We could make the tile size limits depend on split_k, but then we // need to implement the "inverse" of `GetMaxContractingSplit`. diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc index e1425791c8df11..282b39c283f4e7 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc @@ -34,6 +34,7 @@ limitations under the License. 
namespace xla::gpu { namespace { +using ::testing::AllOf; using ::testing::Eq; using ::testing::Field; using ::testing::Ge; @@ -192,6 +193,17 @@ TEST_F(DotSearchSpaceTest, FindsGoodDataReuseOutputTiles) { Contains(AllOf(BlockMIs(Ge(32)), BlockNIs(Ge(32)))).Times(Ge(2))); } +TEST_F(DotSearchSpaceTest, RestrictsOutputToSquareishTiles) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule(/*lhs_parallel_dim=*/1024, + /*rhs_parallel_dim=*/1024)); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + + EXPECT_THAT( + search_space.GenerateConfigs(), + WhenFilteredBy(BlockMIs(Eq(64)), BlockNIs(AllOf(Ge(32), Le(128))))); +} + TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForLowOccupancyProblem) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, @@ -285,7 +297,7 @@ TEST_F(DotSearchSpaceTest, ConsidersAppropriateCtaSizeForTileSize) { EXPECT_THAT(search_space.GenerateConfigs(), AllOf(Contains(AllOf(BlockMIs(Eq(64)), BlockNIs(Eq(32)), NumWarpsIs(Eq(4)))), - Contains(AllOf(BlockMIs(Eq(128)), BlockNIs(Eq(32)), + Contains(AllOf(BlockMIs(Eq(64)), BlockNIs(Eq(64)), NumWarpsIs(Eq(8)))))); } From df2446f015ca3d39dd09e75d33c3649545e7ef93 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 04:40:28 -0700 Subject: [PATCH 0861/1324] Prevent repeated vector allocations for temporary buffers PiperOrigin-RevId: 748249090 --- .../xla/service/memory_space_assignment/algorithm.cc | 12 ++++++------ .../xla/service/memory_space_assignment/algorithm.h | 3 +++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index e77fc625a51087..e38b8ee4f3298e 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -3565,18 +3565,18 @@ bool AsynchronousCopyResource::HasEnoughResource(int64_t exclusive_start_time, bool AsynchronousCopyResource::HasEnoughResourceMultiCheck( const std::vector& specs) { - std::vector> delay_changes; - delay_changes.reserve(delay_.size()); + delay_changes_.resize(0); + delay_changes_.reserve(delay_.size()); bool result = absl::c_all_of(specs, [&](const ResourceSpec& spec) { return ConsumeResource(spec.exclusive_start_time, spec.end_time, GetScaledIntegerResource(spec.resource), - &delay_changes); + &delay_changes_); }); // Apply the delay changes in reverse order. This ensures that the original // value of each delay is restored. - if (!delay_changes.empty()) { - for (int64_t i = delay_changes.size() - 1; i >= 0; --i) { - const auto& [time, delay] = delay_changes[i]; + if (!delay_changes_.empty()) { + for (int64_t i = delay_changes_.size() - 1; i >= 0; --i) { + const auto& [time, delay] = delay_changes_[i]; delay_[time] = delay; } } diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h index d45e909ebf1d74..7d71b245be81d4 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h @@ -276,6 +276,9 @@ class AsynchronousCopyResource { std::vector initial_resources_; std::vector initial_resources_scaled_; std::vector delay_; + // A vector of pairs of (time, delay) used by + // HasEnoughResourceMultiCheck(), stored here to avoid reallocations. 
+  std::vector<std::pair<int64_t, float>> delay_changes_;
 };

 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of

From 7e56ad8b615eaa0fb9cd055a9dba1311525ab405 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 16 Apr 2025 04:52:33 -0700
Subject: [PATCH 0862/1324] Internal change in a BUILD file.

PiperOrigin-RevId: 748251455
---
 tensorflow/tools/api/tests/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 6afade4298077b..9576b008f1e37b 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -28,6 +28,7 @@ py_strict_test(
         "//third_party/py/numpy/tf_numpy_api:api_golden",
     ],
     tags = [
+        "DO_NOT_USE_MIGRATION_ONLY_NO_PYTHON_NEXT",
         "no_mac",  # b/198669105
         "no_oss",  # Runs explicitly in OSS
         "no_pip",

From c222bcc4a0c4de6ca3da29f5493d64b0deed63c8 Mon Sep 17 00:00:00 2001
From: Aliia Khasanova
Date: Wed, 16 Apr 2025 05:27:06 -0700
Subject: [PATCH 0863/1324] Make xla compiler factory return absl::StatusOr.

This should allow us to catch errors during construction of the compiler
objects.

PiperOrigin-RevId: 748259191
---
 third_party/xla/xla/service/compiler.cc | 3 +--
 third_party/xla/xla/service/compiler.h  | 3 ++-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/third_party/xla/xla/service/compiler.cc b/third_party/xla/xla/service/compiler.cc
index a9f26dafb7d005..decd3e40636a42 100644
--- a/third_party/xla/xla/service/compiler.cc
+++ b/third_party/xla/xla/service/compiler.cc
@@ -99,8 +99,7 @@ Compiler::GetPlatformCompilers() {
 }

 /* static */ void Compiler::RegisterCompilerFactory(
-    se::Platform::Id platform_id,
-    std::function<std::unique_ptr<Compiler>()> compiler_factory) {
+    se::Platform::Id platform_id, CompilerFactory compiler_factory) {
   absl::MutexLock lock(&platform_compiler_mutex_);
   auto* factories = GetPlatformCompilerFactories();
   CHECK(factories->find(platform_id) == factories->end())
diff --git a/third_party/xla/xla/service/compiler.h b/third_party/xla/xla/service/compiler.h
index bbb7ccc35842c9..268258b520032c 100644
--- a/third_party/xla/xla/service/compiler.h
+++ b/third_party/xla/xla/service/compiler.h
@@ -282,7 +282,8 @@ class Compiler {
   // The Compiler class also serves as a point to register compiler objects
   // for the various platforms.
-  using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;
+  using CompilerFactory =
+      std::function<absl::StatusOr<std::unique_ptr<Compiler>>()>;

   // Registers the compiler singleton for the platform. This is assumed to
   // be a singleton, so no ownership is transferred.
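With the factory now returning `absl::StatusOr`, a registered factory can surface construction failures to its caller instead of crashing or handing back a null compiler. A hypothetical registration site (`MyCompiler`, its `Initialize()` method, and `kMyPlatformId` are illustrative, not part of this patch):

    #include <memory>

    #include "absl/status/statusor.h"
    #include "xla/service/compiler.h"

    // Sketch only: assumes MyCompiler derives from xla::Compiler and
    // kMyPlatformId is a valid se::Platform::Id.
    void RegisterMyCompiler() {
      xla::Compiler::RegisterCompilerFactory(
          kMyPlatformId,
          []() -> absl::StatusOr<std::unique_ptr<xla::Compiler>> {
            auto compiler = std::make_unique<MyCompiler>();
            if (!compiler->Initialize()) {
              return absl::InternalError("MyCompiler failed to initialize");
            }
            return compiler;
          });
    }

Callers that invoke the factory can then propagate the error with `TF_ASSIGN_OR_RETURN` instead of `CHECK`-failing on a bad compiler object.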
From 06a50c2226faafcd14028cb294aa6caaa0fff09f Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Wed, 16 Apr 2025 06:16:21 -0700 Subject: [PATCH 0864/1324] Reverts 0054b0455108f8660508ad64aa3b6b7b870886db PiperOrigin-RevId: 748270278 --- third_party/xla/xla/debug_options_flags.cc | 2 +- third_party/xla/xla/xla.proto | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 43ce83ae9e346e..5cd208112be184 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -240,7 +240,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_ratio(1.1); - opts.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(false); + opts.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(true); opts.set_xla_gpu_unsafe_pipelined_loop_annotator(false); opts.set_xla_gpu_copy_insertion_use_region_analysis(false); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 0c32b62a0f472f..c5d4426a370ea2 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -754,6 +754,8 @@ message DebugOptions { // `xla_gpu_cublas_fallback` set to false. bool xla_gpu_triton_gemm_any = 190; + // TODO(b/409940111): Remove this flag and use high precision reductions for + // Split-K GEMMs unconditionally. bool xla_gpu_triton_gemm_disable_reduced_precision_reduction = 226; // It is usually preferable to not fallback to the driver; it can consume more From da10705e0ad06be982a7291a54e87601d7a3378d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 07:06:40 -0700 Subject: [PATCH 0865/1324] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/7c9a510bdfd98a39f2639e7446a3119fdd3ccbfe. PiperOrigin-RevId: 748281648 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 398e1b398d7bb9..fb81926386c0dd 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "6dc35e57bf5a76b034a79af695cd7752de274d68" - TFRT_SHA256 = "604c2a2a9c0d24981fe0a62f1ed300f57ea4681992c7dae63ab2b49639809442" + TFRT_COMMIT = "7c9a510bdfd98a39f2639e7446a3119fdd3ccbfe" + TFRT_SHA256 = "67467e2e0914859e657d384917cb27bd362e7ebb007a21476a5d271ddd540001" tf_http_archive( name = "tf_runtime", From 14998007d52e1aa0dd0bb2e6954df2ed5dabaded Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 16 Apr 2025 07:29:29 -0700 Subject: [PATCH 0866/1324] Reverts 81c6e20274d1093079d901f4b6bd60e70eab18ad PiperOrigin-RevId: 748287087 --- .../xla/xla/backends/gpu/collectives/BUILD | 12 +++ .../gpu/collectives/gpu_collectives.h | 17 +++++ .../gpu/collectives/gpu_collectives_stub.h | 4 + .../gpu/collectives/nccl_collectives.cc | 76 +++++++++++++++++++ .../gpu/collectives/nccl_collectives.h | 2 + .../xla/xla/backends/gpu/runtime/BUILD | 6 +- third_party/xla/xla/pjrt/gpu/BUILD | 28 +------ third_party/xla/xla/pjrt/gpu/nccl_id_store.cc | 69 ----------------- third_party/xla/xla/pjrt/gpu/nccl_id_store.h | 59 -------------- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 26 ++++--- 10 files changed, 131 insertions(+), 168 deletions(-) delete mode 100644 third_party/xla/xla/pjrt/gpu/nccl_id_store.cc delete mode 100644 third_party/xla/xla/pjrt/gpu/nccl_id_store.h diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index 3595fe88020edf..a07dcd1e588ff9 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -128,6 +128,7 @@ cc_library( srcs = ["gpu_collectives.cc"], hdrs = ["gpu_collectives.h"], deps = [ + "//xla:executable_run_options", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", @@ -136,9 +137,12 @@ cc_library( "//xla/core/collectives:clique_key", "//xla/core/collectives:collectives_registry", "//xla/core/collectives:communicator", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -189,9 +193,11 @@ cc_library( ]), visibility = ["//visibility:private"], deps = [ + ":gpu_clique_key", ":gpu_collectives", ":nccl_communicator", ":nccl_errors", + "//xla:debug_options_flags", "//xla:status_macros", "//xla:util", "//xla/core/collectives", @@ -200,15 +206,21 @@ cc_library( "//xla/core/collectives:collectives_registry", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:global_device_id", + "//xla/service/gpu:gpu_executable_run_options", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", ] + if_cuda_is_configured([ diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h index fbd88b28063585..556699506336b0 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h @@ -19,13 +19,18 @@ limitations under the License. 
#include #include #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/collectives.h" #include "xla/core/collectives/communicator.h" +#include "xla/executable_run_options.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/service/global_device_id.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -108,6 +113,18 @@ class GpuCollectives : public Collectives { virtual absl::StatusOr Allocate(uint64_t bytes) = 0; virtual absl::Status Deallocate(void* buffer) = 0; + + struct Topology { + int32_t node_id; + int32_t num_nodes; + size_t device_count_per_process; + std::shared_ptr kv_store; + absl::flat_hash_map device_id_to_node_id; + gpu::GpuExecutableRunOptions* gpu_executable_run_options; + }; + + // Initializes the topology information for the collectives backend. + virtual absl::Status InitializeTopology(Topology topology) = 0; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h index f217f4ccd3621d..11b50eed094073 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h @@ -70,6 +70,10 @@ class GpuCollectivesStub : public GpuCollectives { absl::Status Deallocate(void* buffer) final { return UnimplementedError(); } + absl::Status InitializeTopology(Topology topology) final { + return UnimplementedError(); + } + protected: static absl::Status UnimplementedError() { return Unimplemented("XLA compiled without GPU collectives support"); diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc index bc84fbef3c62d0..d54c91f00b0e8d 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc @@ -20,16 +20,21 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/collectives/nccl_communicator.h" #include "xla/backends/gpu/collectives/nccl_errors.h" @@ -39,6 +44,9 @@ limitations under the License. 
#include "xla/core/collectives/collectives_registry.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/status_macros.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" @@ -254,6 +262,74 @@ absl::Status NcclCollectives::Deallocate(void* location) { VLOG(2) << "Deallocated collective memory " << location; return absl::OkStatus(); } + +class NcclIdStore { + public: + NcclIdStore(int node_id, + absl::flat_hash_map device_to_node, + std::shared_ptr kv_store) + : node_id_(node_id), + device_to_node_(std::move(device_to_node)), + kv_store_(std::move(kv_store)) {} + + absl::StatusOr GetNcclUniqueId(const CliqueKey& key) { + auto* gpu_key = tsl::down_cast(&key); + if (gpu_key == nullptr) { + return InvalidArgument("Expected GPU clique key"); + } + + // The caller must ensure that threads calling this method concurrently have + // unique keys, otherwise the global key-value store may hold the wrong + // value. + { + absl::MutexLock lock(&mu_); + auto it = cache_.find(*gpu_key); + if (it != cache_.end()) { + return it->second; + } + } + CliqueId clique_id; + int primary_node_id = device_to_node_.at(gpu_key->root_device()); + if (node_id_ == primary_node_id) { + TF_ASSIGN_OR_RETURN( + clique_id, gpu::GpuCollectives::Default()->CreateUniqueCliqueId()); + TF_RETURN_IF_ERROR( + kv_store_->Set(gpu_key->ToString(), clique_id.ToString())); + } else { + TF_ASSIGN_OR_RETURN( + std::string id_str, + kv_store_->Get(gpu_key->ToString(), absl::Minutes(10))); + clique_id = CliqueId(id_str); + } + absl::MutexLock lock(&mu_); + auto result = cache_.emplace(*gpu_key, std::move(clique_id)); + TF_RET_CHECK(result.second) << "Unique ID already in cache."; + return result.first->second; + } + + private: + const int node_id_; + const absl::flat_hash_map device_to_node_; + const std::shared_ptr kv_store_; + + absl::Mutex mu_; + absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); +}; + +absl::Status NcclCollectives::InitializeTopology( + NcclCollectives::Topology topology) { + if (topology.num_nodes > 1) { + auto nccl_id_store = std::make_shared( + topology.node_id, topology.device_id_to_node_id, + std::move(topology.kv_store)); + topology.gpu_executable_run_options->set_clique_id_callback( + [nccl_id_store](const CliqueKey& key) { + return nccl_id_store->GetNcclUniqueId(key); + }); + } + return absl::OkStatus(); +} + } // namespace xla::gpu XLA_COLLECTIVES_REGISTER("gpu", "nccl", 1, diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h index 36860f1360f00b..b89a9360728f18 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h @@ -61,6 +61,8 @@ class NcclCollectives : public GpuCollectives { absl::StatusOr Allocate(uint64_t bytes) final; absl::Status Deallocate(void* location) final; + + absl::Status InitializeTopology(Topology topology) final; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index e600afa6b979c6..41082e9801cb22 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -967,11 +967,7 @@ cc_library( "@com_google_absl//absl/synchronization", 
"@com_google_absl//absl/time", "@llvm-project//mlir:IR", - ] + if_cuda_is_configured([ - "@local_config_nccl//:nccl", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rccl", - ]), + ], ) cc_library( diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index c91744bbf81189..2fcfb4809bd114 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -66,8 +66,11 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_proto_cc", + "//xla/backends/gpu/collectives:gpu_collectives", "//xla/client:client_library", "//xla/client:local_client", + "//xla/core/collectives", + "//xla/core/collectives:collectives_registry", "//xla/hlo/builder:xla_computation", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:event_pool", @@ -143,7 +146,6 @@ cc_library( "@local_tsl//tsl/profiler/lib:traceme", ] + if_cuda_or_rocm([ # keep sorted - ":nccl_id_store", "//xla:debug_options_flags", "//xla/service/gpu:gpu_compiler", "//xla/service/gpu:gpu_constants", @@ -236,30 +238,6 @@ xla_test( ], ) -cc_library( - name = "nccl_id_store", - srcs = ["nccl_id_store.cc"], - hdrs = ["nccl_id_store.h"], - deps = [ - "//xla:status_macros", - "//xla:util", - "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_collectives", - "//xla/core/collectives:clique_id", - "//xla/core/collectives:clique_key", - "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:global_device_id", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - xla_test( name = "pjrt_client_test_se_gpu", srcs = ["pjrt_client_test_se_gpu.cc"], diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc b/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc deleted file mode 100644 index a2a72856e6d9f3..00000000000000 --- a/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/pjrt/gpu/nccl_id_store.h" - -#include -#include - -#include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/time.h" -#include "xla/backends/gpu/collectives/gpu_clique_key.h" -#include "xla/backends/gpu/collectives/gpu_collectives.h" -#include "xla/core/collectives/clique_id.h" -#include "xla/core/collectives/clique_key.h" -#include "xla/status_macros.h" -#include "xla/util.h" -#include "tsl/platform/casts.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace xla { - -absl::StatusOr NcclIdStore::GetNcclUniqueId(const CliqueKey& key) { - auto* gpu_key = tsl::down_cast(&key); - if (gpu_key == nullptr) { - return InvalidArgument("Expected GPU clique key"); - } - - // The caller must ensure that threads calling this method concurrently have - // unique keys, otherwise the global key-value store may hold the wrong value. - { - absl::MutexLock lock(&mu_); - auto it = cache_.find(*gpu_key); - if (it != cache_.end()) { - return it->second; - } - } - CliqueId clique_id; - int primary_node_id = device_to_node_.at(gpu_key->root_device()); - if (node_id_ == primary_node_id) { - TF_ASSIGN_OR_RETURN(clique_id, - gpu::GpuCollectives::Default()->CreateUniqueCliqueId()); - TF_RETURN_IF_ERROR( - kv_store_->Set(gpu_key->ToString(), clique_id.ToString())); - } else { - TF_ASSIGN_OR_RETURN(std::string id_str, - kv_store_->Get(gpu_key->ToString(), absl::Minutes(10))); - clique_id = CliqueId(id_str); - } - absl::MutexLock lock(&mu_); - auto result = cache_.emplace(*gpu_key, std::move(clique_id)); - TF_RET_CHECK(result.second) << "Unique ID already in cache."; - return result.first->second; -} - -} // namespace xla diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.h b/third_party/xla/xla/pjrt/gpu/nccl_id_store.h deleted file mode 100644 index fe8b060cb946a7..00000000000000 --- a/third_party/xla/xla/pjrt/gpu/nccl_id_store.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PJRT_GPU_NCCL_ID_STORE_H_ -#define XLA_PJRT_GPU_NCCL_ID_STORE_H_ - -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" -#include "xla/backends/gpu/collectives/gpu_clique_key.h" -#include "xla/core/collectives/clique_id.h" -#include "xla/core/collectives/clique_key.h" -#include "xla/pjrt/distributed/key_value_store_interface.h" -#include "xla/service/global_device_id.h" - -namespace xla { - -// A table mapping GpuCliqueKeys to CliqueIds. In a distributed setup the -// table of NCCL IDs is kept on the master node (node 0). The node of the first -// participating device will create the unique id. 
-class NcclIdStore { - public: - NcclIdStore(int node_id, - absl::flat_hash_map device_to_node, - std::shared_ptr kv_store) - : node_id_(node_id), - device_to_node_(std::move(device_to_node)), - kv_store_(std::move(kv_store)) {} - - absl::StatusOr GetNcclUniqueId(const CliqueKey& key); - - private: - const int node_id_; - const absl::flat_hash_map device_to_node_; - const std::shared_ptr kv_store_; - - absl::Mutex mu_; - absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); -}; - -} // namespace xla - -#endif // XLA_PJRT_GPU_NCCL_ID_STORE_H_ diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index a1f82c6e156656..5be6c178af91ae 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -46,7 +46,10 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/client/local_client.h" +#include "xla/core/collectives/collectives.h" +#include "xla/core/collectives/collectives_registry.h" #include "xla/executable_run_options.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" @@ -106,7 +109,6 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/gpu/gpu_metrics.h" -#include "xla/pjrt/gpu/nccl_id_store.h" #include "xla/pjrt/stream_executor_executable.pb.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_constants.h" @@ -1209,16 +1211,20 @@ absl::StatusOr BuildDistributedDevices( } gpu_executable_run_options->set_gpu_global_device_ids( std::move(gpu_device_ids)); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (num_nodes > 1) { - auto nccl_id_store = std::make_shared(node_id, device_to_node, - std::move(kv_store)); - gpu_executable_run_options->set_clique_id_callback( - [nccl_id_store](const CliqueKey& key) { - return nccl_id_store->GetNcclUniqueId(key); - }); + + TF_ASSIGN_OR_RETURN(xla::Collectives * collectives, + xla::CollectivesRegistry::Default("gpu")); + xla::gpu::GpuCollectives* gpu_collectives = + tsl::down_cast(collectives); + + if (gpu_collectives == nullptr) { + return absl::InternalError("Failed to get GPU collectives"); } -#endif // GOOGLE_CUDA + + TF_RETURN_IF_ERROR(gpu_collectives->InitializeTopology( + {node_id, global_topology.nodes().size(), local_device_states.size(), + kv_store, device_to_node, gpu_executable_run_options})); + TF_ASSIGN_OR_RETURN(GpuTopologyProto gpu_topology, BuildGpuTopology(global_topology)); return std::make_pair(std::move(devices), gpu_topology); From 1b9895119c28550ae8722d42551994593fc96b03 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 16 Apr 2025 08:01:56 -0700 Subject: [PATCH 0867/1324] Add newline to `.bazelrc` PiperOrigin-RevId: 748295173 --- third_party/xla/.bazelrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index 83bce85c4fbbf5..2f1ab7dd1421d2 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -14,4 +14,4 @@ try-import %workspace%/xla_configure.bazelrc # in the version of absl we are using. # This can be removed when the absl version used is bumped to commit 48f0f91 or # newer, likely after July 2025. 
-common --copt=-D/*absl_nonnull*/='' --copt=-D/*absl_nullable*/='' --copt=-D/*absl_nullability_unknown*/='' \ No newline at end of file +common --copt=-D/*absl_nonnull*/='' --copt=-D/*absl_nullable*/='' --copt=-D/*absl_nullability_unknown*/='' From f896e13427b1c24113ca91bd7df03563e8673046 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 08:35:41 -0700 Subject: [PATCH 0868/1324] Removes temporal values / makespan code (never used in production). PiperOrigin-RevId: 748304700 --- .../auto_sharding/auto_sharding.cc | 2 -- .../auto_sharding/auto_sharding.h | 4 --- .../auto_sharding/auto_sharding_impl.cc | 5 --- .../auto_sharding/auto_sharding_solver.cc | 31 +++++-------------- .../auto_sharding/auto_sharding_solver.h | 14 --------- 5 files changed, 7 insertions(+), 49 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 39b02c3d885531..033ca9695d33b9 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1996,8 +1996,6 @@ CreateAutoShardingSolverRequestAndCallSolver( *request.add_edge_intervals() = std::move(interval); } - PopulateTemporalValues(cost_graph, request); - const auto converted_problem = ConvertToProblem(request); const auto converted_request = ConvertToSolverRequest(converted_problem); return FormulateAndSolveMIPFromSolverRequest(converted_request, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index 2034dd68972415..9db3f047c91d25 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -209,10 +209,6 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, const ShardingStrategy& strategy, const ClusterEnvironment& cluster_env); -// Populates temporal distance values. -void PopulateTemporalValues(const CostGraph& cost_graph, - AutoShardingSolverRequest& request); - void AddReplicatedStrategy( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index bc208c48f6506e..0bb737a624c17d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -53,11 +53,6 @@ absl::StatusOr Solve( /*deterministic mode*/ true); } -void PopulateTemporalValues(const CostGraph& cost_graph, - AutoShardingSolverRequest& request) { - // TODO(moffitt): Implement this. 
-} - double GetDotConvReplicationPenalty(const HloInstruction* inst, size_t instruction_id, size_t window, const HloInstructionSequence& sequence, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 26455ba4c23a9f..1e05bce0aee984 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -221,8 +221,7 @@ absl::StatusOr SolveAndExtractSolution( const AutoShardingSolverParams& params, const std::vector>& s, const std::vector>& e, - const MPVariable* overbudget_var, const MPVariable* makespan_var, - MPSolver& solver) { + const MPVariable* overbudget_var, MPSolver& solver) { auto status = solver.Solve(); LOG(INFO) << "Solver absl::Status: " << status; @@ -325,10 +324,6 @@ absl::StatusOr SolveAndExtractSolution( overbudget_var->solution_value() * request.memory_budget(); } - if (makespan_var) { - unsalted_objective += - request.makespan_coeff().coeff() * makespan_var->solution_value(); - } LOG(INFO) << "Unsalted objective value: " << unsalted_objective; LOG(INFO) << "N = " << request.num_nodes(); @@ -511,10 +506,8 @@ void AddMemoryTerms( // the share same sharding as s_follow[i]. // 2. If request.overbudget_coeff is present, we turn the hard memory budget // constraint into a soft constraint instead. -// 3. If request.makespan_coeff is present, the objective additionally includes -// a makespan term. This is experimental and turned off by default. -// 4. request.max_departures is used only for debugging and can be ignored. -// 5. Note that due to our modeling of XLA's AllReduceReassociate optimization +// 3. request.max_departures is used only for debugging and can be ignored. +// 4. Note that due to our modeling of XLA's AllReduceReassociate optimization // (more details in CostGraph::CostGraph() in auto_sharding_cost_graph.cc, // and in CreateElementwiseOperatorStrategies() in auto_sharding.cc), there // can be a few (usually < 10) edges in the problem with negative costs. This @@ -562,7 +555,6 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( std::vector> s(request.num_nodes()); std::vector> e(num_edges); MPVariable* overbudget_var = nullptr; - MPVariable* makespan_var = nullptr; size_t unique_nodes = 0; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { @@ -610,10 +602,6 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( solver->MakeNumVar(0.0, MPSolver::infinity(), "overbudget"); } - if (request.has_makespan_coeff()) { - makespan_var = CreateMakespanVar(request, e, *solver); - } - // Construct objective function. 
// Node costs absl::flat_hash_set infinity_vars; @@ -903,8 +891,8 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( if (params.max_departures.has_value()) { VLOG(0) << "Max departures: " << *params.max_departures; } - auto result = SolveAndExtractSolution(request, params, s, e, overbudget_var, - makespan_var, *solver); + auto result = + SolveAndExtractSolution(request, params, s, e, overbudget_var, *solver); if (result.ok()) { const AutoShardingEvaluation evaluation = Evaluate(request, *result, params); @@ -922,13 +910,9 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( LOG(INFO) << "Total Overbudget Cost: " << evaluation.total.overbudget_cost << " (lower bound: " << evaluation.lower_bound.overbudget_cost << ")"; - LOG(INFO) << "Total Makespan Cost: " << evaluation.total.makespan_cost - << " (lower bound: " << evaluation.lower_bound.makespan_cost - << ")"; LOG(INFO) << "Total Cost: " << evaluation.total.cost() << " (lower bound: " << evaluation.lower_bound.cost() << ")"; LOG(INFO) << "Total Departures: " << evaluation.total_departures; - LOG(INFO) << "Total Makespan: " << evaluation.total_makespan; LOG(INFO) << "Total Violations: " << evaluation.violation_codes.size(); LOG(INFO) << "Total Maximum Memory: " << evaluation.total.max_memory << " (lower bound: " << evaluation.lower_bound.max_memory << ")"; @@ -1107,12 +1091,12 @@ bool CostComponents::operator==(const CostComponents& other) const { computation_cost == other.computation_cost && resharding_cost == other.resharding_cost && overbudget_cost == other.overbudget_cost && - makespan_cost == other.makespan_cost && max_memory == other.max_memory; + max_memory == other.max_memory; } double CostComponents::cost() const { return communication_cost + computation_cost + resharding_cost + - overbudget_cost + makespan_cost; + overbudget_cost; } bool AutoShardingEvaluation::operator==( @@ -1240,7 +1224,6 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, evaluation.lower_bound.resharding_cost += *std::min_element( r.at(edge_idx).costs().begin(), r.at(edge_idx).costs().end()); } - evaluation.total_makespan = EvaluateMakespan(request, result, evaluation); return evaluation; } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index 5fe51df206b8ce..60a4092e01c8ae 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -78,7 +78,6 @@ struct CostComponents { double computation_cost = 0.0; double resharding_cost = 0.0; double overbudget_cost = 0.0; - double makespan_cost = 0.0; double max_memory = 0.0; double cost() const; @@ -99,9 +98,6 @@ struct AutoShardingEvaluation { // How many instructions departed from the "default" sharding strategy. double total_departures = 0.0; - // The (raw) total makespan, i.e., not scaled by the makespan coefficient. - double total_makespan = 0.0; - bool operator==(const AutoShardingEvaluation& other) const; }; @@ -120,16 +116,6 @@ double ComputeShardingStrategyCost( const AutoShardingSolverRequest& request, const std::vector& node_strategies); -// Creates and returns a variable for makespan. 
-operations_research::MPVariable* CreateMakespanVar( - const AutoShardingSolverRequest& request, - const std::vector>& e, - operations_research::MPSolver& solver); - -double EvaluateMakespan(const AutoShardingSolverRequest& request, - const AutoShardingSolverOutput& result, - AutoShardingEvaluation& evaluation); - // Determines if strategy 'first' is dominated by strategy 'second' (i.e., its // costs are all equal or worse, and it has identical alias mappings). bool CheckDominance(const AutoShardingSolverRequest& request, From 927fb379454bbae7f53ebebc09e96978da6e37d9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 09:05:54 -0700 Subject: [PATCH 0869/1324] Make `Shape::DeleteDimensions()` less error-prone. This function requires its argument to be sorted, which is not obvious from the name. It also puts extra burden on the callers. Remove this requirement to make it easier to use and less error-prone. PiperOrigin-RevId: 748313411 --- third_party/xla/xla/shape.cc | 6 ++++-- third_party/xla/xla/shape.h | 4 +++- third_party/xla/xla/shape_test.cc | 7 +++++++ third_party/xla/xla/shape_util.cc | 5 +---- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 578c290eb0b503..851706e1530d56 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -228,9 +228,11 @@ void Shape::DeleteDimension(int64_t dim_to_delete) { } } -void Shape::DeleteDimensions(absl::Span sorted_dims_to_delete) { +void Shape::DeleteDimensions(absl::Span dims_to_delete) { auto& state = array_state(); - CHECK(absl::c_is_sorted(sorted_dims_to_delete)); + std::vector sorted_dims_to_delete(dims_to_delete.begin(), + dims_to_delete.end()); + absl::c_sort(sorted_dims_to_delete); state.dimensions = RemoveElements(sorted_dims_to_delete, state.dimensions); state.dynamic_dimensions = RemoveElements(sorted_dims_to_delete, state.dynamic_dimensions); diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index af59dc4edc589b..6774417451db3b 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -223,7 +223,9 @@ class Shape { // Precondition: this is an array shape, and the input dimension indices are // valid. void DeleteDimension(int64_t dim_to_delete); - void DeleteDimensions(absl::Span sorted_dims_to_delete); + // Like the above, but deletes multiple dimensions at once. The dimensions + // must not contain duplicates. + void DeleteDimensions(absl::Span dims_to_delete); // Returns the primitive type of the shape. PrimitiveType element_type() const { return element_type_; } diff --git a/third_party/xla/xla/shape_test.cc b/third_party/xla/xla/shape_test.cc index 72e2768196e619..8e4258591c760a 100644 --- a/third_party/xla/xla/shape_test.cc +++ b/third_party/xla/xla/shape_test.cc @@ -108,6 +108,13 @@ TEST_F(ShapeTest, DeleteDimensions) { EXPECT_EQ(shape, ShapeUtil::MakeShapeWithDenseLayout(F32, {5, 9}, {0, 1})); } +TEST_F(ShapeTest, DeleteDimensionsUnordered) { + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {5, 3, 2, 7, 9}, + {2, 0, 1, 4, 3}); + shape.DeleteDimensions({3, 1, 2}); + EXPECT_EQ(shape, ShapeUtil::MakeShapeWithDenseLayout(F32, {5, 9}, {0, 1})); +} + TEST_F(ShapeTest, EqualityTest) { // Different layouts. 
EXPECT_NE(ShapeUtil::MakeShapeWithDenseLayout(F32, {23, 44}, {1, 0}), diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index dd7af95c71f352..f0768043eef259 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -2075,10 +2075,7 @@ struct ParallelState { /* static */ Shape ShapeUtil::DeleteDimensions( absl::Span dims_to_delete, Shape shape) { - std::vector dims_to_delete_v(dims_to_delete.begin(), - dims_to_delete.end()); - absl::c_sort(dims_to_delete_v); - shape.DeleteDimensions(dims_to_delete_v); + shape.DeleteDimensions(dims_to_delete); return shape; } From 2495c486e8711a0e69e8e232a1e9f25c6d78f152 Mon Sep 17 00:00:00 2001 From: Marissa Ikonomidis Date: Wed, 16 Apr 2025 09:14:20 -0700 Subject: [PATCH 0870/1324] Add a helper library for memory/latency logging PiperOrigin-RevId: 748315913 --- tensorflow/lite/profiling/BUILD | 13 ++++ .../lite/profiling/memory_latency_logger.cc | 71 +++++++++++++++++++ .../lite/profiling/memory_latency_logger.h | 56 +++++++++++++++ 3 files changed, 140 insertions(+) create mode 100644 tensorflow/lite/profiling/memory_latency_logger.cc create mode 100644 tensorflow/lite/profiling/memory_latency_logger.h diff --git a/tensorflow/lite/profiling/BUILD b/tensorflow/lite/profiling/BUILD index e71e7ea1fb2d42..95e4319f981f8f 100644 --- a/tensorflow/lite/profiling/BUILD +++ b/tensorflow/lite/profiling/BUILD @@ -188,6 +188,19 @@ cc_test( ], ) +cc_library( + name = "memory_latency_logger", + srcs = ["memory_latency_logger.cc"], + hdrs = ["memory_latency_logger.h"], + deps = [ + ":memory_usage_monitor", + "//tensorflow/lite:framework_stable", + "//tensorflow/lite/tools:logging", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + ], +) + cc_library( name = "profile_summary_formatter", srcs = ["profile_summary_formatter.cc"], diff --git a/tensorflow/lite/profiling/memory_latency_logger.cc b/tensorflow/lite/profiling/memory_latency_logger.cc new file mode 100644 index 00000000000000..02204a6b93a7f8 --- /dev/null +++ b/tensorflow/lite/profiling/memory_latency_logger.cc @@ -0,0 +1,71 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/profiling/memory_latency_logger.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "tensorflow/lite/profiling/memory_usage_monitor.h" +#include "tensorflow/lite/tools/logging.h" + +namespace tflite { +namespace profiling { +namespace memory { + +MemoryLatencyLogger::MemoryLatencyLogger() { + mem_monitor_ = + std::make_unique(/*sampling_interval_ms=*/50); +} +void MemoryLatencyLogger::Start() { + if (start_ != absl::UnixEpoch()) { + TFLITE_LOG(INFO) << "MemoryLatencyLogger start called multiple times."; + return; + } + start_ = absl::Now(); + mem_monitor_->Start(); +} + +void MemoryLatencyLogger::Stop(absl::string_view log_message) { + if (start_ == absl::UnixEpoch()) { + TFLITE_LOG(INFO) + << "MemoryLatencyLogger hasn't started yet or has stopped!"; + return; + } + + absl::Time stop = absl::Now(); + mem_monitor_->Stop(); + int space_count = + 35 - log_message.size(); // used for better user readability. + std::string space(std::max(space_count, 0), '-'); + TFLITE_LOG(INFO) << log_message << " " << space << " latency: " << std::fixed + << std::setprecision(1) + << absl::ToDoubleMilliseconds(stop - start_) + << " ms, peak alloc: " << mem_monitor_->GetPeakMemUsageInMB() + << " MB, peak in-use: " + << mem_monitor_->GetPeakInUseMemoryInMB() + << " MB, current in-use: " + << mem_monitor_->GetCurrentInUseMemoryInMB() << " MB"; +} + +} // namespace memory +} // namespace profiling +} // namespace tflite diff --git a/tensorflow/lite/profiling/memory_latency_logger.h b/tensorflow/lite/profiling/memory_latency_logger.h new file mode 100644 index 00000000000000..8273ab1ee7f0ca --- /dev/null +++ b/tensorflow/lite/profiling/memory_latency_logger.h @@ -0,0 +1,56 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_PROFILING_MEMORY_LATENCY_LOGGER_H_ +#define TENSORFLOW_LITE_PROFILING_MEMORY_LATENCY_LOGGER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "tensorflow/lite/profiling/memory_usage_monitor.h" + +namespace tflite { +namespace profiling { +namespace memory { + +// This class is used to measure the memory and latency of the surrounding code +// block. Example usage: +// MemoryLatencyLogger logger; +// logger.Start(); +// Code block +// logger.Stop("Code block"); + +// This class is thread-unsafe. +class MemoryLatencyLogger { + public: + MemoryLatencyLogger(); + // Starts the memory and latency monitoring. + void Start(); + // Stops the memory and latency monitoring and logs the results. + void Stop(absl::string_view log_message); + + private: + // The memory usage monitor. + std::unique_ptr mem_monitor_; + // The start time of the memory and latency monitoring. 
+ absl::Time start_; +}; + +} // namespace memory +} // namespace profiling +} // namespace tflite + +#endif // TENSORFLOW_LITE_PROFILING_MEMORY_LATENCY_LOGGER_H_ From a7d84686eee7b9abf7cdf8e8ffab726ddbded4d2 Mon Sep 17 00:00:00 2001 From: Robert David Date: Wed, 16 Apr 2025 09:20:21 -0700 Subject: [PATCH 0871/1324] The second parameter to `TF_LITE_KERNEL_LOG` is expected to be a `printf` format string, known at compile time. Run IWYU to fix includes. PiperOrigin-RevId: 748317645 --- tensorflow/lite/delegates/gpu/BUILD | 14 +++++-- .../delegates/gpu/common/model_builder.cc | 7 +++- tensorflow/lite/delegates/gpu/delegate.cc | 42 +++++++++---------- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 37c4df92b4dad8..4e216c6677ffe8 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -42,25 +42,31 @@ _DELEGATE_NO_GL_DEPS = select({ ":tflite_profile", #"//third_party/GL:EGL_headers", #"//third_party/GL:GLES3_headers", + # go/keep-sorted start "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "//tensorflow/lite:kernel_api", - "//tensorflow/lite:minimal_logging", "//tensorflow/lite/async:backend_async_kernel_interface", "//tensorflow/lite/core/async/interop/c:types", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/delegates:serialization", "//tensorflow/lite/delegates/gpu/cl:util", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_builder", "//tensorflow/lite/delegates/gpu/common:model_builder_helper", "//tensorflow/lite/delegates/gpu/common:quantization_util", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates:serialization", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/profiling/telemetry", - "//tensorflow/lite/profiling/telemetry:telemetry_status", + "//tensorflow/lite/profiling/telemetry/c:telemetry_setting", "//tensorflow/lite/profiling/telemetry/c:telemetry_setting_internal", + "//tensorflow/lite/profiling/telemetry:telemetry_status", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:minimal_logging", + # go/keep-sorted end ] config_setting( diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 31d12a503dafc7..fc4966122aa18b 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder.h" #include +#include #include #include #include @@ -31,9 +32,13 @@ limitations under the License. 
#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/core/interpreter.h" +#include "tensorflow/lite/core/interpreter_builder.h" +#include "tensorflow/lite/core/model_builder.h" #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/lstm_parser.h" @@ -3369,7 +3374,7 @@ TfLiteIntArray* GetOpsToReplace( partition_helper.num_total_nodes()); } absl::StrAppend(&error_message, " operations will run on the CPU."); - TF_LITE_KERNEL_LOG(context, error_message.c_str()); + TF_LITE_KERNEL_LOG(context, "%s", error_message.c_str()); } return ConvertVectorToTfLiteIntArray(ops_to_replace); } diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 9fac6e598f1b1a..cfad378991585c 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -20,13 +20,8 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/delegate.h" -#include "tensorflow/lite/logger.h" - -#if defined(__ANDROID__) -#include -#endif - #include +#include #include #include #include @@ -40,28 +35,35 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/types/span.h" #include "tensorflow/lite/builtin_ops.h" - -#if defined(__ANDROID__) -#include "tensorflow/lite/async/backend_async_kernel_interface.h" -#include "tensorflow/lite/core/async/c/task.h" -#include "tensorflow/lite/core/async/interop/c/attribute_map.h" -#include "tensorflow/lite/core/async/interop/c/constants.h" -#include "tensorflow/lite/core/async/interop/c/types.h" -#endif - #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" #include "tensorflow/lite/delegates/gpu/api.h" #include "tensorflow/lite/delegates/gpu/cl/api.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_builder.h" #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" #include "tensorflow/lite/delegates/gpu/common/quantization_util.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/delegate_options.h" #include "tensorflow/lite/delegates/gpu/tflite_profile.h" #include "tensorflow/lite/delegates/serialization.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/logger.h" +#include "tensorflow/lite/minimal_logging.h" +#include "tensorflow/lite/profiling/telemetry/c/telemetry_setting.h" +#include "tensorflow/lite/profiling/telemetry/telemetry.h" +#include "tensorflow/lite/profiling/telemetry/telemetry_status.h" #if defined(__ANDROID__) +#include + +#include "tensorflow/lite/async/backend_async_kernel_interface.h" +#include "tensorflow/lite/core/async/c/task.h" +#include "tensorflow/lite/core/async/interop/c/attribute_map.h" +#include "tensorflow/lite/core/async/interop/c/constants.h" +#include "tensorflow/lite/core/async/interop/c/types.h" +#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" #include "tensorflow/lite/delegates/gpu/async_buffers.h" #include 
"tensorflow/lite/delegates/gpu/gl/android_sync.h" #include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" @@ -71,12 +73,6 @@ limitations under the License. #include "tensorflow/lite/delegates/utils/utils.h" #endif -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/minimal_logging.h" -#include "tensorflow/lite/profiling/telemetry/c/telemetry_setting_internal.h" -#include "tensorflow/lite/profiling/telemetry/telemetry.h" -#include "tensorflow/lite/profiling/telemetry/telemetry_status.h" - #ifndef CL_DELEGATE_NO_GL #include "tensorflow/lite/delegates/gpu/gl/api2.h" #endif @@ -469,7 +465,7 @@ absl::Status DelegateKernelCore::Setup( InitializeOpenClApi(&graph, &builder, &graph_is_destroyed, context, delegate_params, delegate_->serialization()); if (!status.ok()) { - TF_LITE_KERNEL_LOG(context, std::string(status.message()).c_str()); + TF_LITE_KERNEL_LOG(context, "%s", std::string(status.message()).c_str()); TF_LITE_KERNEL_LOG(context, "Falling back to OpenGL"); // Graph needs to be re-created because it is moved above. From 30f6528725afd35a26ef9641ff691fa3a26fb7c1 Mon Sep 17 00:00:00 2001 From: Robert David Date: Wed, 16 Apr 2025 09:22:06 -0700 Subject: [PATCH 0872/1324] The second parameter to `TF_LITE_KERNEL_LOG` is expected to be a `printf` format string, known at compile time. PiperOrigin-RevId: 748318175 --- tensorflow/lite/kernels/parse_example/BUILD | 11 +++++ .../kernels/parse_example/parse_example.cc | 44 ++++++++++++------- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/tensorflow/lite/kernels/parse_example/BUILD b/tensorflow/lite/kernels/parse_example/BUILD index b18d23b6c7d607..bbf62c1decb983 100644 --- a/tensorflow/lite/kernels/parse_example/BUILD +++ b/tensorflow/lite/kernels/parse_example/BUILD @@ -24,6 +24,7 @@ cc_library( compatible_with = get_compatible_with_portable(), features = tf_features_nolayering_check_if_ios(), deps = [ + "//tensorflow/core/platform:hash", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", "//tensorflow/lite/core/c:common", @@ -31,6 +32,11 @@ cc_library( "//tensorflow/lite/kernels/internal:tensor", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@flatbuffers", ] + select({ "//tensorflow:android": [ @@ -111,6 +117,11 @@ cc_library( "//tensorflow/lite/kernels/internal:tensor", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@flatbuffers", ], ) diff --git a/tensorflow/lite/kernels/parse_example/parse_example.cc b/tensorflow/lite/kernels/parse_example/parse_example.cc index ec87aabfc86c95..6d6e77f02bb6cc 100644 --- a/tensorflow/lite/kernels/parse_example/parse_example.cc +++ b/tensorflow/lite/kernels/parse_example/parse_example.cc @@ -16,24 +16,35 @@ limitations under the License. 
#include #include +#include +#include +#include +#include #include #include -#include #include +#include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "xla/tsl/platform/errors.h" #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/blocking_counter.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/fingerprint.h" -#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/util/example_proto_fast_parsing.h" #include "tensorflow/core/util/presized_cuckoo_map.h" #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h" #include "tensorflow/lite/mutable_op_resolver.h" @@ -46,7 +57,6 @@ namespace parse_example { namespace { namespace tf = ::tensorflow; -using tf::Status; using tf::StringPiece; using tf::tstring; using tf::example::CopyOrMoveBlock; @@ -116,7 +126,7 @@ bool ParseExample(StringRef serialized, Example* example) { return ParseExample(&stream, example); } -Status FastParseSerializedExample( +absl::Status FastParseSerializedExample( StringRef serialized_example, const tstring& example_name, const size_t example_index, const FastParseExampleConfig& config, bool* quick_filter, int quick_filter_size, @@ -139,7 +149,7 @@ Status FastParseSerializedExample( // I.e. last entry in the map overwrites all the previous ones. tensorflow::example::parsed::FeatureMapEntry& name_and_feature = parsed_example[parsed_example_size - i - 1]; - const StringPiece feature_name = name_and_feature.first; + const absl::string_view feature_name = name_and_feature.first; tensorflow::example::parsed::Feature& feature = name_and_feature.second; if (feature_name.length() >= quick_filter_size || !quick_filter[feature_name.length()]) { @@ -153,7 +163,7 @@ Status FastParseSerializedExample( size_t d = d_and_type.first; bool is_dense = d_and_type.second == Type::Dense; - auto example_error = [&](StringPiece suffix) { + auto example_error = [&](absl::string_view suffix) { return tf::errors::Internal("Name: ", example_name, ", Key: ", feature_name, ", Index: ", example_index, ". ", suffix); @@ -164,7 +174,7 @@ Status FastParseSerializedExample( }; tf::DataType example_dtype; - if (feature.ParseDataType(&example_dtype) != absl::OkStatus()) { + if (!feature.ParseDataType(&example_dtype).ok()) { return parse_error(); } if (is_dense) { @@ -184,7 +194,7 @@ Status FastParseSerializedExample( const std::size_t num_elements = config.dense[d].elements_per_stride; const std::size_t offset = example_index * num_elements; - auto shape_error = [&](size_t size, StringPiece type_str) { + auto shape_error = [&](size_t size, absl::string_view type_str) { return example_error(absl::StrCat( "Number of ", type_str, " values != expected. 
" @@ -238,7 +248,7 @@ Status FastParseSerializedExample( "Expected type: ", DataTypeString(config.dense[d].dtype))); } - auto shape_error = [&](size_t size, StringPiece type_str) { + auto shape_error = [&](size_t size, absl::string_view type_str) { return example_error( absl::StrCat("Number of ", type_str, " values is not a multiple of stride length. Saw ", @@ -452,7 +462,7 @@ inline void CopyToBuffer(absl::Span vec, char* tensor_buffer, } } -Status FastParseExampleLite( +absl::Status FastParseExampleLite( const FastParseExampleConfig& config, const TfLiteTensor* serialized, absl::Span example_names, bool* quick_filter, int quick_filter_size, const std::unique_ptr& config_index, @@ -465,7 +475,7 @@ Status FastParseExampleLite( std::vector fixed_dense_values(config.dense.size()); std::vector sparse_buffers(config.sparse.size()); std::vector varlen_dense_buffers(config.dense.size()); - Status status_of_minibatch; + absl::Status status_of_minibatch; for (size_t e = 0; e < count; ++e) { status_of_minibatch = FastParseSerializedExample( GetString(serialized, e), @@ -971,8 +981,8 @@ TfLiteStatus EvalParseExample(TfLiteContext* context, TfLiteNode* node) { data->config, serialized, {}, data->quick_filter, data->quick_filter_size, data->config_index, data->config_index_size, &data->hasher, &data->got, stats, context); - if (status != absl::OkStatus()) { - TF_LITE_KERNEL_LOG(context, status.ToString().c_str()); + if (!status.ok()) { + TF_LITE_KERNEL_LOG(context, "%s", status.ToString().c_str()); return kTfLiteError; } return kTfLiteOk; From 21db0cd50ef700bf1c0f8df9b9d22847b06d5607 Mon Sep 17 00:00:00 2001 From: James Ward Date: Wed, 16 Apr 2025 17:59:28 +0100 Subject: [PATCH 0873/1324] Fix TOSA HardSwish Table generation for int8 inputs (#62829) More closely match tflite kernel behaviour for integers when calculating HardSwish table Change-Id: I32a1338fb15d3505d4bea432a8523c79d8f5da7a Signed-off-by: James Ward --- tensorflow/compiler/mlir/tosa/BUILD | 1 + .../mlir/tosa/transforms/legalize_tfl.cc | 10 +-- .../mlir/tosa/transforms/legalize_utils.cc | 84 +++++++++++++++++++ .../mlir/tosa/transforms/legalize_utils.h | 5 ++ 4 files changed, 92 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index 238781aa6455eb..8d6cd90fa35a13 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -99,6 +99,7 @@ cc_library( "@llvm-project//mlir:TosaDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@gemmlowp", "@local_xla//xla/tsl/framework/fixedpoint", ], ) diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index ac0c4b0c9c8266..cfd80d61d1b245 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -3444,17 +3444,11 @@ LogicalResult ConvertTFLHardSwishOp::matchAndRewrite( mlir::dyn_cast_or_null( output_type.getElementType()); - auto hardswish_func = [](double v) -> double { - double w = v + 3.0; - w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w; - return v * w / 6.0; - }; - if (input_qtype.getStorageTypeIntegralWidth() == 8) { // Implement with 8-bit table lookup. 
- Value table_const = getTosaConst8bitTable( + Value table_const = getTosaConstHardSwish8bitTable( rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), - output_qtype.getScale(), output_qtype.getZeroPoint(), hardswish_func); + output_qtype.getScale(), output_qtype.getZeroPoint()); CreateReplaceOpAndInfer( rewriter, op, output_type, tfl_hardswish_op.getInput(), table_const); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index 101676be1a0110..43008bc0a81441 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -582,6 +582,90 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, } } +Value getTosaConstHardSwish8bitTable(PatternRewriter& rewriter, Operation* op, + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp) { + // Define tflite params: + // See: HardSwishPrepare / HardSwishParams + const float hires_input_scale = (1.0f / 128.0f) * input_scale; + const float reluish_scale = 3.0f / 32768.0f; + const float output_multiplier = hires_input_scale / output_scale; + + int16_t output_multiplier_fixedpoint_int16; + int output_multiplier_exponent; + + int16_t reluish_multiplier_fixedpoint_int16; + int reluish_multiplier_exponent; + + int32_t output_multiplier_fixedpoint_int32; + tflite::QuantizeMultiplier(output_multiplier, + &output_multiplier_fixedpoint_int32, + &output_multiplier_exponent); + tflite::DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32, + &output_multiplier_fixedpoint_int16); + assert(output_multiplier_exponent <= 0); + + const float reluish_multiplier = hires_input_scale / reluish_scale; + int32_t reluish_multiplier_fixedpoint_int32; + + tflite::QuantizeMultiplier(reluish_multiplier, + &reluish_multiplier_fixedpoint_int32, + &reluish_multiplier_exponent); + tflite::DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32, + &reluish_multiplier_fixedpoint_int16); + + // See HardSwish function in + // tensorflow/lite/kernels/internal/reference/hardswish.h + SmallVector table; + for (int32_t i = -128; i < 128; i++) { + const int16_t input_value = i - input_zp; + const int16_t input_value_on_hires_input_scale = input_value * (1 << 7); + const int16_t input_value_on_preshift_output_scale = + gemmlowp::SaturatingRoundingDoublingHighMul( + input_value_on_hires_input_scale, + output_multiplier_fixedpoint_int16); + int16_t reluish_value = input_value_on_hires_input_scale; + if (reluish_multiplier_exponent > 0) { + reluish_value = tflite::reference_ops::SaturatingLeftShift( + reluish_value, reluish_multiplier_exponent - 1); + } + reluish_value = gemmlowp::SaturatingRoundingDoublingHighMul( + reluish_value, reluish_multiplier_fixedpoint_int16); + if (reluish_multiplier_exponent > 0) { + reluish_value = + tflite::reference_ops::SaturatingLeftShift(reluish_value, 1); + } + if (reluish_multiplier_exponent < 0) { + reluish_value = gemmlowp::RoundingDivideByPOT( + reluish_value, -reluish_multiplier_exponent); + } + reluish_value = (reluish_value + (1 << 15)) >> 1; + const int16_t preshift_output_value = + tflite::reference_ops::SaturatingDoublingHighMul( + reluish_value, input_value_on_preshift_output_scale); + int16_t output_value = gemmlowp::RoundingDivideByPOT( + preshift_output_value, -output_multiplier_exponent); + output_value += output_zp; + output_value = + std::min(output_value, std::numeric_limits::max()); + output_value = + 
std::max(output_value, std::numeric_limits::min()); + table.push_back(output_value); + } + + auto element_qtype = + UniformQuantizedType::get(true, rewriter.getIntegerType(8), + rewriter.getF32Type(), 1.0f, 0, -128, 127); + auto const_type = tensorflow::GetTypeFromTFTensorShape({256}, element_qtype); + auto storage_type = tensorflow::GetTypeFromTFTensorShape( + {256}, element_qtype.getStorageType()); + auto const_attr = DenseElementsAttr::get(storage_type, llvm::ArrayRef(table)); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, float input_scale, int32_t input_zp, float output_scale, int32_t output_zp) { diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index 443054e5bf9f01..a2b990446924c9 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -126,6 +126,11 @@ Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, float input_scale, int32_t input_zp, float output_scale, int32_t output_zp); +// Create an 8-bit TOSA Table constant tensor for the HardSwish operator +Value getTosaConstHardSwish8bitTable(PatternRewriter& rewriter, Operation* op, + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp); + // Create a 32-bit float constant operator from a float Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op, float val, int rank); From f0dcade504021273a805f978ed3f8c1b9be12766 Mon Sep 17 00:00:00 2001 From: Karlo Basioli Date: Wed, 16 Apr 2025 09:35:07 -0700 Subject: [PATCH 0874/1324] [XLA:CPU] Support custom calls from NanoRt PiperOrigin-RevId: 748322026 --- third_party/xla/xla/backends/cpu/nanort/BUILD | 8 ++ .../backends/cpu/nanort/nanort_client_test.cc | 78 +++++++++++++++++++ .../backends/cpu/nanort/nanort_executable.cc | 48 ++++++++++-- .../backends/cpu/nanort/nanort_executable.h | 27 ++++++- .../xla/xla/backends/cpu/runtime/thunk.h | 1 - 5 files changed, 152 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/nanort/BUILD b/third_party/xla/xla/backends/cpu/nanort/BUILD index e7965bcfac6fcf..d770cc57b7328c 100644 --- a/third_party/xla/xla/backends/cpu/nanort/BUILD +++ b/third_party/xla/xla/backends/cpu/nanort/BUILD @@ -52,6 +52,9 @@ xla_cc_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/backends/cpu:alignment", + "//xla/ffi", + "//xla/ffi:execution_context", + "//xla/ffi:ffi_api", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", @@ -60,12 +63,14 @@ xla_cc_test( "//xla/pjrt:pjrt_executable", "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_main", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", @@ -81,6 +86,7 @@ cc_library( "//xla/backends/cpu/nanort:nanort_users", ]), deps = [ + "//xla:executable_run_options", "//xla:shape_util", "//xla:util", "//xla/backends/cpu:alignment", @@ -88,12 +94,14 @@ cc_library( "//xla/backends/cpu/runtime:function_library", 
"//xla/backends/cpu/runtime:thread_pool_task_runner", "//xla/backends/cpu/runtime:thunk", + "//xla/ffi:execution_context", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:computation_layout", "//xla/service:executable", "//xla/service:hlo_value", "//xla/service/cpu:cpu_executable", + "//xla/service/cpu:cpu_runtime", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc b/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc index 72a6ccb2098464..36967d953aaec6 100644 --- a/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc +++ b/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc @@ -18,15 +18,21 @@ limitations under the License. #include #include #include +#include +#include #include #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array2d.h" #include "xla/backends/cpu/alignment.h" #include "xla/backends/cpu/nanort/nanort_executable.h" +#include "xla/ffi/execution_context.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" @@ -38,6 +44,7 @@ limitations under the License. #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" @@ -274,6 +281,77 @@ ENTRY test_module { EXPECT_EQ(result_span[0], expected_result); } +//===----------------------------------------------------------------------===// +// Custom call tests below +//===----------------------------------------------------------------------===// + +struct StrUserData { + explicit StrUserData(std::string str) : str(std::move(str)) {} + std::string str; +}; + +static absl::Status Add(StrUserData* user_data, + ffi::BufferR0 a, + ffi::BufferR0 b, + ffi::Result> sum) { + EXPECT_EQ(user_data->str, "foo"); + sum->typed_data()[0] = a.typed_data()[0] + b.typed_data()[0]; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kAdd, Add, + ffi::Ffi::Bind() + .Ctx>() + .Arg>() + .Arg>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_nanort_test$$add", "Host", + kAdd); + +TEST(NanoRtClientTest, CustomCallTest) { + const char* kModuleStr = R"( + HloModule module + + ENTRY custom_call { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT custom-call = s32[] custom-call(a, b), + custom_call_target="__xla_nanort_test$$add", + api_version=API_VERSION_TYPED_FFI + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + XlaComputation computation(module->ToProto()); + + NanoRtClient client; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + client.Compile(computation)); + + int32_t a = 1.0f; + int32_t b = 2.0f; + int32_t result = 0.0f; + + std::vector arguments; + std::vector results; + arguments.push_back({&a, 1}); + arguments.push_back({&b, 1}); + results.push_back({&result, 1}); + + ffi::ExecutionContext context; + TF_ASSERT_OK(context.Emplace("foo")); + + NanoRtExecutable::ExecuteOptions execute_options; + execute_options.set_ffi_context(&context); + + auto event = executable->Execute(arguments, results, {}, 
execute_options); + tsl::BlockUntilReady(event); + + EXPECT_TRUE(event.IsConcrete()); + EXPECT_EQ(result, 3.0f); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc index 2930f5263e7366..f129a1983f61fc 100644 --- a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc +++ b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/backends/cpu/nanort/nanort_executable.h" #include +#include #include #include #include @@ -31,6 +32,8 @@ limitations under the License. #include "xla/backends/cpu/runtime/function_library.h" #include "xla/backends/cpu/runtime/thread_pool_task_runner.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/executable_run_options.h" +#include "xla/ffi/execution_context.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" #include "xla/service/computation_layout.h" @@ -190,6 +193,25 @@ NanoRtExecutable::ExecuteOptions::set_intra_op_thread_pool( return *this; } +NanoRtExecutable::ExecuteOptions& +NanoRtExecutable::ExecuteOptions::set_ffi_context( + const ffi::ExecutionContext* ffi_context) { + ffi_context_ = ffi_context; + return *this; +} + +NanoRtExecutable::ExecuteOptions& +NanoRtExecutable::ExecuteOptions::set_launch_id(int32_t launch_id) { + launch_id_ = launch_id; + return *this; +} + +NanoRtExecutable::ExecuteOptions& +NanoRtExecutable::ExecuteOptions::set_device_ordinal(int32_t device_ordinal) { + device_ordinal_ = device_ordinal; + return *this; +} + const Eigen::ThreadPoolDevice* NanoRtExecutable::ExecuteOptions::intra_op_thread_pool() const { return intra_op_thread_pool_; @@ -342,17 +364,26 @@ tsl::AsyncValueRef NanoRtExecutable::Execute( FunctionLibrary* function_library, const ExecuteOptions& options) : allocations(std::move(buffers)), - execute_params({function_library, &allocations, - /*xfeed=*/nullptr, options.intra_op_thread_pool(), - options.task_runner()}) {} + execute_params(Thunk::ExecuteParams{function_library, &allocations, + /*xfeed=*/nullptr, + options.intra_op_thread_pool(), + options.task_runner()}), + custom_call_execute_params( + RunId(options.launch_id()), options.device_ordinal(), + options.intra_op_thread_pool(), options.ffi_context()) { + execute_params.custom_call_params = &custom_call_execute_params; + } cpu::BufferAllocations allocations; Thunk::ExecuteParams execute_params; + Thunk::CustomCallExecuteParams custom_call_execute_params; }; - // Only do a heap allocation if we're running with a thread pool, this allows - // us to keep the execution context alive as long as we need it. - if (options.intra_op_thread_pool()) { + // Do a heap allocation if we're running with a thread pool or using + // custom calls. This allows us to keep the execution context + // alive as long as we need it, but also to skip a dynamic allocation when it + // is not required. 
+ if (options.intra_op_thread_pool() || options.ffi_context()) { auto execution_context = std::make_unique( std::move(buffers), executable->function_library(), options); @@ -366,8 +397,9 @@ tsl::AsyncValueRef NanoRtExecutable::Execute( } else { cpu::BufferAllocations allocations(std::move(buffers)); Thunk::ExecuteParams execute_params{ - executable->function_library(), &allocations, /*xfeed=*/nullptr, - options.intra_op_thread_pool(), options.task_runner()}; + executable->function_library(), &allocations, + /*xfeed=*/nullptr, options.intra_op_thread_pool(), + options.task_runner()}; return executable->thunks().Execute(execute_params); } } diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.h b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.h index 334f41cf7e304b..b0893cd959246d 100644 --- a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.h +++ b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.h @@ -29,6 +29,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/backends/cpu/alignment.h" #include "xla/backends/cpu/runtime/thread_pool_task_runner.h" +#include "xla/ffi/execution_context.h" #include "xla/service/executable.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" @@ -55,7 +56,12 @@ class NanoRtExecutable { class ExecuteOptions { public: - ExecuteOptions() : intra_op_thread_pool_(nullptr), task_runner_(nullptr) {} + ExecuteOptions() + : intra_op_thread_pool_(nullptr), + task_runner_(nullptr), + device_ordinal_(0), + launch_id_(0), + ffi_context_(nullptr) {} // Sets the thread pool device on which to run Eigen subcomputations. // // This field must be set for XLA:CPU models that call Eigen routines, but @@ -67,12 +73,31 @@ class NanoRtExecutable { ExecuteOptions& set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool); + ExecuteOptions& set_ffi_context(const ffi::ExecutionContext* ffi_context); + ExecuteOptions& set_collectives(CpuCollectives* collectives); + + ExecuteOptions& set_launch_id(int32_t launch_id); + + ExecuteOptions& set_device_ordinal(int32_t device_ordinal); + const Eigen::ThreadPoolDevice* intra_op_thread_pool() const; ThreadPoolTaskRunner* task_runner() const; + int32_t device_ordinal() const { return device_ordinal_; } + int32_t launch_id() const { return launch_id_; } + const ffi::ExecutionContext* ffi_context() const { return ffi_context_; } + private: const Eigen::ThreadPoolDevice* intra_op_thread_pool_; std::unique_ptr task_runner_; + + // If non-zero, identifies this execution as part of a potentially + // multi-device launch. This can be used to detect scheduling errors, e.g. + // if multi-host programs are launched in different orders on different + // hosts, the launch IDs may be used by the runtime to detect the mismatch. + int32_t device_ordinal_; + int32_t launch_id_; + const ffi::ExecutionContext* ffi_context_; }; // A non-owning read-only view into the XLA executable's argument buffer. 
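The `ExecuteOptions` additions above compose by chaining, since each setter returns `ExecuteOptions&`. A minimal caller-side sketch, assuming the buffers are already set up as in the CustomCallTest earlier in this patch (namespaces are elided, and `MyState` plus its string payload are hypothetical illustrations, not part of the change):

    // Hypothetical usage of the API added in this patch; mirrors CustomCallTest.
    // MyState is an illustrative user-defined type, not from this patch.
    ffi::ExecutionContext context;
    TF_CHECK_OK(context.Emplace<MyState>("user state"));

    NanoRtExecutable::ExecuteOptions options;
    options.set_ffi_context(&context)  // makes MyState visible to typed-FFI calls
        .set_launch_id(1)              // nonzero: tags a multi-device launch
        .set_device_ordinal(0);

    // arguments/results are the executable's buffer views, as in the test;
    // the third argument is the (here empty) preallocated temp span.
    auto event = executable->Execute(arguments, results, {}, options);
    tsl::BlockUntilReady(event);

Leaving `ffi_context` unset preserves the old fast path: as the `.cc` change above shows, the heap-allocated execution state is only created when a thread pool or an FFI context is supplied.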
diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk.h b/third_party/xla/xla/backends/cpu/runtime/thunk.h
index b5eb4b539335ac..99ad80cd5e7d47 100644
--- a/third_party/xla/xla/backends/cpu/runtime/thunk.h
+++ b/third_party/xla/xla/backends/cpu/runtime/thunk.h
@@ -181,7 +181,6 @@ class Thunk {
     const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr;
     const ffi::ExecutionContext* ffi_execution_context = nullptr;

-   private:
    CustomCallExecuteParams(RunId run_id, int32_t device_ordinal,
                            const Eigen::ThreadPoolDevice* intra_op_thread_pool,
                            const ffi::ExecutionContext* ffi_execution_context);

From 366387b603b0e33b109dc54b89fbbfaef9d0aa83 Mon Sep 17 00:00:00 2001
From: Luke Hutton
Date: Wed, 16 Apr 2025 19:21:13 +0200
Subject: [PATCH 0875/1324] [mlir][tosa] Fix reduction legalization when output is dynamic (#90801)

Currently the legalization uses the output shape to determine the
'shape' input of the final reshape operation inserted during the
legalization. When the output type is not static, the 'shape' input
contains '-1' values that represent dynamic dims from the output type
shape.

This commit uses information from the input type and "keep_dims"
attribute to determine the final output shape to reshape to instead.
This can help avoid creating a 'shape' input with non-static dimensions
in some cases.

Functionally the changes include adding `squeeze_axes` to compute the
final output shape, and the propagation of the `keep_dims` attribute to
the common reduction legalization to support this calculation.

Change-Id: I6a1a013b518bb2c84b8ed8d901225e13383bfab4
---
 .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 28 ++++++++-
 .../mlir/tosa/transforms/legalize_common.cc  | 60 ++++++++++++-----
 .../mlir/tosa/transforms/legalize_common.h   | 17 ++++--
 .../mlir/tosa/transforms/legalize_tf.cc      | 14 ++---
 .../mlir/tosa/transforms/legalize_tfl.cc     | 14 ++---
 5 files changed, 97 insertions(+), 36 deletions(-)

diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
index 7b1cac84df042e..b9be1ce6aa2840 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
@@ -752,6 +752,18 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {

 // -----

+// CHECK-LABEL: test_reduce_any_dynamic_output
+// CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_any %arg0 {axis = 0 : i32}
+// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>}
+// CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]]
+func.func @test_reduce_any_dynamic_output(%arg0: tensor<13x21x3xi1>) -> tensor {
+  %cst = arith.constant dense<0> : tensor<1xi32>
+  %0 = "tfl.reduce_any"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor
+  func.return %0 : tensor
+}
+
+// -----
+
 // CHECK-LABEL: test_reduce_min
 // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_min %arg0 {axis = 0 : i32}
 // CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>}
@@ -832,6 +844,21 @@ func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {

 // -----

+// CHECK-LABEL: test_reduce_mean_dynamic_output
+// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.0769230798> : tensor<1x1xf32>}>
+// CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_sum %arg0 {axis = 0 : i32}
+// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>}
+// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape 
%[[VAR1]], %[[VAR10]] +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK: %[[VAR4:.*]] = tosa.mul %[[VAR2]], %[[VAR0]], %[[SHIFT]] +func.func @test_reduce_mean_dynamic_output(%arg0: tensor<13x21x3xf32>) -> tensor { + %cst = arith.constant dense<0> : tensor<1xi32> + %0 = "tfl.mean"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: test_reduce_mean_out_of_bounds // CHECK: "tfl.mean" func.func @test_reduce_mean_out_of_bounds(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -846,7 +873,6 @@ func.func @test_reduce_mean_out_of_bounds(%arg0: tensor<13x21x3xf32>) -> tensor< // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x2x!quant.uniform> // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1105078632> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 471149d7a99165..d93198b710502a 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -3010,13 +3010,12 @@ std::optional convertReduceOpCommon( bool is_quantized, int32_t input_scale_multiplier, int32_t input_scale_shift, int64_t input_zp, int32_t output_scale_multiplier, int32_t output_scale_shift, - int64_t output_zp, StringRef nan_mode = "") { + int64_t output_zp, bool keep_dims, StringRef nan_mode = "") { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; ArrayRef input_shape = input_type.getShape(); - ArrayRef output_shape = output_type.getShape(); auto input_rank = input_shape.size(); Location loc = op->getLoc(); @@ -3083,7 +3082,29 @@ std::optional convertReduceOpCommon( /*scale32=*/true); } + // If keep dims, no reshaping of the output is required + if (keep_dims) { + return val; + } + // Squeeze out the reduced axes. 
+ const auto squeeze_axes = [](llvm::ArrayRef in, llvm::ArrayRef axes) { + llvm::SmallVector sorted_axes{axes}; + std::sort(sorted_axes.begin(), sorted_axes.end()); + auto current_axis = sorted_axes.begin(); + + llvm::SmallVector out; + out.reserve(in.size() - axes.size()); + for (const auto& [i, dim] : llvm::enumerate(in)) { + if (current_axis == sorted_axes.end() || i != *current_axis) + out.push_back(dim); + else + current_axis++; + } + return out; + }; + + const auto output_shape = squeeze_axes(input_shape, axes); auto output_shape_value = getTosaConstShape(rewriter, op->getLoc(), tensorflow::ConvertMlirShapeToTF(output_shape)); @@ -3099,7 +3120,7 @@ std::optional convertReduceOpCommon( PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, Type reduce_element_type, bool is_quantized, double input_scale, int64_t input_zp, - double output_scale, int64_t output_zp, StringRef nan_mode = "") { + double output_scale, int64_t output_zp, bool keep_dims, StringRef nan_mode = "") { const int32_t scale_width = 32; int32_t input_scale_multiplier; @@ -3115,7 +3136,7 @@ std::optional convertReduceOpCommon( return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, reduce_element_type, is_quantized, input_scale_multiplier, input_scale_shift, input_zp, - output_scale_multiplier, output_scale_shift, output_zp, nan_mode); + output_scale_multiplier, output_scale_shift, output_zp, keep_dims, nan_mode); } // Lowers ReduceAll to a sequence of TOSA ops. @@ -3123,14 +3144,15 @@ std::optional convertReduceAllOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims); } // Lowers ReduceAny to a sequence of TOSA ops. @@ -3138,14 +3160,15 @@ std::optional convertReduceAnyOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims); } // Lowers ReduceMin to a sequence of TOSA ops. @@ -3154,6 +3177,7 @@ std::optional convertReduceMinOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode) { RankedTensorType input_type = dyn_cast(input_value.getType()); @@ -3161,7 +3185,7 @@ std::optional convertReduceMinOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, nan_mode); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims, nan_mode); } // Lowers ReduceMax to a sequence of TOSA ops. 
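The effect of the new `squeeze_axes` helper is easiest to see on the shapes from the MLIR tests above: reducing axis 0 of a 13x21x3 input with keep_dims = false must reshape to {21, 3}, even when the op's result type is dynamic. A self-contained sketch of the same computation, using plain std types instead of llvm::ArrayRef purely for illustration:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Drops the reduced axes from `in`; mirrors the squeeze_axes lambda above.
    std::vector<int64_t> SqueezeAxes(const std::vector<int64_t>& in,
                                     std::vector<int64_t> axes) {
      std::sort(axes.begin(), axes.end());  // axes may arrive unsorted
      auto current_axis = axes.begin();
      std::vector<int64_t> out;
      out.reserve(in.size() - axes.size());
      for (int64_t i = 0; i < static_cast<int64_t>(in.size()); ++i) {
        if (current_axis == axes.end() || i != *current_axis) {
          out.push_back(in[i]);  // dimension i survives the reduction
        } else {
          ++current_axis;  // dimension i was reduced away; drop it
        }
      }
      return out;
    }

    // SqueezeAxes({13, 21, 3}, {0}) == {21, 3}, matching the reshape to
    // dense<[21, 3]> checked in test_reduce_any_dynamic_output above.

Like the lambda in the patch, this assumes the axes list contains no duplicates and that every axis is a valid dimension index.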
@@ -3170,6 +3194,7 @@ std::optional convertReduceMaxOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode) { RankedTensorType input_type = dyn_cast(input_value.getType()); @@ -3177,7 +3202,7 @@ std::optional convertReduceMaxOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, nan_mode); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims, nan_mode); } // Lowers ReduceProd to a sequence of TOSA ops. @@ -3185,7 +3210,8 @@ std::optional convertReduceProdOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -3203,7 +3229,7 @@ std::optional convertReduceProdOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims); } // Lowers ReduceSum to a sequence of TOSA ops. @@ -3211,7 +3237,8 @@ std::optional convertReduceSumOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -3254,7 +3281,7 @@ std::optional convertReduceSumOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, reduce_element_type, - input_is_qtype, input_scale, input_zp, output_scale, output_zp); + input_is_qtype, input_scale, input_zp, output_scale, output_zp, keep_dims); } // Lowers ReduceMean to a sequence of TOSA ops. @@ -3262,7 +3289,8 @@ std::optional convertReduceMeanOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { // reduce_mean is lowered as followed for quantized types: // op1 = reduce_sum(input) with the 1.0/num_elements_on_reduced_axis // integrated to the rescale layer, @@ -3355,7 +3383,7 @@ std::optional convertReduceMeanOp(PatternRewriter& rewriter, auto val = convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, reduce_element_type, input_is_qtype, input_scale_multiplier, input_scale_shift, input_zp, - output_scale_multiplier, output_scale_shift, output_zp); + output_scale_multiplier, output_scale_shift, output_zp, keep_dims); if (!val.has_value()) return std::nullopt; diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index 9b118ad6e73335..8cc74ee9bd5157 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -179,14 +179,16 @@ std::optional convertReduceAllOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems); + ElementsAttr axes_elems, + bool keep_dims); // Lowers ReduceAny to a sequence of TOSA ops. 
std::optional<Value> convertReduceAnyOp(PatternRewriter& rewriter,
                                         Operation* op,
                                         RankedTensorType output_type,
                                         Value input_value,
-                                        ElementsAttr axes_elems);
+                                        ElementsAttr axes_elems,
+                                        bool keep_dims);
 
 // Lowers ReduceMin to a sequence of TOSA ops.
 std::optional<Value> convertReduceMinOp(PatternRewriter& rewriter,
@@ -194,6 +196,7 @@ std::optional<Value> convertReduceMinOp(PatternRewriter& rewriter,
                                         RankedTensorType output_type,
                                         Value input_value,
                                         ElementsAttr axes_elems,
+                                        bool keep_dims,
                                         StringRef nan_mode = "PROPAGATE");
 
 // Lowers ReduceMax to a sequence of TOSA ops.
@@ -202,6 +205,7 @@ std::optional<Value> convertReduceMaxOp(PatternRewriter& rewriter,
                                         RankedTensorType output_type,
                                         Value input_value,
                                         ElementsAttr axes_elems,
+                                        bool keep_dims,
                                         StringRef nan_mode = "PROPAGATE");
 
 // Lowers ReduceProd to a sequence of TOSA ops.
@@ -209,21 +213,24 @@ std::optional<Value> convertReduceProdOp(PatternRewriter& rewriter,
                                          Operation* op,
                                          RankedTensorType output_type,
                                          Value input_value,
-                                         ElementsAttr axes_elems);
+                                         ElementsAttr axes_elems,
+                                         bool keep_dims);
 
 // Lowers ReduceSum to a sequence of TOSA ops.
 std::optional<Value> convertReduceSumOp(PatternRewriter& rewriter,
                                         Operation* op,
                                         RankedTensorType output_type,
                                         Value input_value,
-                                        ElementsAttr axes_elems);
+                                        ElementsAttr axes_elems,
+                                        bool keep_dims);
 
 // Lowers ReduceMean to a sequence of TOSA ops.
 std::optional<Value> convertReduceMeanOp(PatternRewriter& rewriter,
                                          Operation* op,
                                          RankedTensorType output_type,
                                          Value input_value,
-                                         ElementsAttr axes_elem);
+                                         ElementsAttr axes_elem,
+                                         bool keep_dims);
 
 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
 std::optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
index 9578d0ecf8c0aa..5264e52fa69907 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
@@ -1123,7 +1123,7 @@ LogicalResult ConvertTFAllOp::matchAndRewrite(Operation* op,
     return failure();
 
   std::optional<Value> result = convertReduceAllOp(
-      rewriter, op, output_type, tf_all_op.getInput(), axes_elems);
+      rewriter, op, output_type, tf_all_op.getInput(), axes_elems, tf_all_op.getKeepDims());
 
   if (!result) return failure();
@@ -1145,7 +1145,7 @@ LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation* op,
     return failure();
 
   std::optional<Value> result = convertReduceAnyOp(
-      rewriter, op, output_type, tf_any_op.getInput(), axes_elems);
+      rewriter, op, output_type, tf_any_op.getInput(), axes_elems, tf_any_op.getKeepDims());
 
   if (!result) return failure();
@@ -1167,7 +1167,7 @@ LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation* op,
     return failure();
 
   std::optional<Value> result = convertReduceMaxOp(
-      rewriter, op, output_type, tf_max_op.getInput(), axes_elems);
+      rewriter, op, output_type, tf_max_op.getInput(), axes_elems, tf_max_op.getKeepDims());
 
   if (!result) return failure();
@@ -1189,7 +1189,7 @@ LogicalResult ConvertTFMinOp::matchAndRewrite(Operation* op,
     return failure();
 
   std::optional<Value> result = convertReduceMinOp(
-      rewriter, op, output_type, tf_min_op.getInput(), axes_elems);
+      rewriter, op, output_type, tf_min_op.getInput(), axes_elems, tf_min_op.getKeepDims());
 
   if (!result) return failure();
@@ -1211,7 +1211,7 @@ LogicalResult ConvertTFMeanOp::matchAndRewrite(
     return failure();
 
   std::optional<Value> result = convertReduceMeanOp(
-      rewriter, op, output_type, tf_mean_op.getInput(), axes_elems);
+      rewriter, op, output_type, tf_mean_op.getInput(), axes_elems, tf_mean_op.getKeepDims());
 
   if (!result) return failure();
@@ -1233,7 +1233,7 @@ LogicalResult ConvertTFProdOp::matchAndRewrite(
     return failure();
 
   std::optional<Value> result = convertReduceProdOp(
-      rewriter, op, output_type, tf_prod_op.getInput(), axes_elems);
+      rewriter, op, output_type, tf_prod_op.getInput(), axes_elems, tf_prod_op.getKeepDims());
 
   if (!result) return failure();
@@ -1255,7 +1255,7 @@ LogicalResult ConvertTFSumOp::matchAndRewrite(Operation* op,
     return failure();
 
   std::optional<Value> result = convertReduceSumOp(
-      rewriter, op, output_type, tf_sum_op.getInput(), axes_elems);
+      rewriter, op, output_type, tf_sum_op.getInput(), axes_elems, tf_sum_op.getKeepDims());
 
   if (!result) return failure();
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
index cfd80d61d1b245..b69a5cdba4dea2 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -2612,7 +2612,7 @@ LogicalResult ConvertTFLReduceAllOp::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "fail to get reduction indices");
 
   std::optional<Value> result = convertReduceAllOp(
-      rewriter, op, output_type, tfl_all_op.getInput(), axes_elems);
+      rewriter, op, output_type, tfl_all_op.getInput(), axes_elems, tfl_all_op.getKeepDims());
 
   if (!result) return failure();
@@ -2634,7 +2634,7 @@ LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite(
     return failure();
 
   std::optional<Value> result = convertReduceAnyOp(
-      rewriter, op, output_type, tfl_any_op.getInput(), axes_elems);
+      rewriter, op, output_type, tfl_any_op.getInput(), axes_elems, tfl_any_op.getKeepDims());
 
   if (!result) return failure();
@@ -2656,7 +2656,7 @@ LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite(
     return failure();
 
   std::optional<Value> result = convertReduceMaxOp(
-      rewriter, op, output_type, tfl_max_op.getInput(), axes_elems);
+      rewriter, op, output_type, tfl_max_op.getInput(), axes_elems, tfl_max_op.getKeepDims());
 
   if (!result) return failure();
@@ -2678,7 +2678,7 @@ LogicalResult ConvertTFLReduceMinOp::matchAndRewrite(
     return failure();
 
   std::optional<Value> result = convertReduceMinOp(
-      rewriter, op, output_type, tfl_min_op.getInput(), axes_elems);
+      rewriter, op, output_type, tfl_min_op.getInput(), axes_elems, tfl_min_op.getKeepDims());
 
   if (!result) return failure();
@@ -2700,7 +2700,7 @@ LogicalResult ConvertTFLReduceProdOp::matchAndRewrite(
     return failure();
 
   std::optional<Value> result = convertReduceProdOp(
-      rewriter, op, output_type, tfl_prod_op.getInput(), axes_elems);
+      rewriter, op, output_type, tfl_prod_op.getInput(), axes_elems, tfl_prod_op.getKeepDims());
 
   if (!result) return failure();
@@ -2722,7 +2722,7 @@ LogicalResult ConvertTFLMeanOp::matchAndRewrite(
     return failure();
 
   std::optional<Value> result = convertReduceMeanOp(
-      rewriter, op, output_type, tfl_mean_op.getInput(), axes_elems);
+      rewriter, op, output_type, tfl_mean_op.getInput(), axes_elems, tfl_mean_op.getKeepDims());
 
   if (!result) return failure();
@@ -2744,7 +2744,7 @@ LogicalResult ConvertTFLSumOp::matchAndRewrite(
     return failure();
 
   std::optional<Value> result = convertReduceSumOp(
-      rewriter, op, output_type, tfl_sum_op.getInput(), axes_elems);
+      rewriter, op, output_type, tfl_sum_op.getInput(), axes_elems, tfl_sum_op.getKeepDims());
 
   if (!result) return failure();

From 0010fa3aa6473d53e9ab2fae6727d6c80e739f92 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 16 Apr 2025 09:46:14 -0700
Subject: [PATCH 0876/1324] Bug fix: canonicalize broadcasts after all-gather broadcast reorder.

PiperOrigin-RevId: 748325564
---
 .../xla/hlo/transforms/simplifiers/algebraic_simplifier.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc
index 7466c5d7d9c482..a7925cc1ab222a 100644
--- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc
+++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc
@@ -8256,6 +8256,10 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
   if (arg->opcode() == HloOpcode::kBroadcast &&
       Match(reduce->to_apply()->root_instruction(),
             m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
+    TF_RET_CHECK(
+        std::is_sorted(arg->dimensions().begin(), arg->dimensions().end()))
+        << "Broadcasts need to be canonicalized before algebraic "
+           "simplification.";
     bool only_reduce_dims_from_broadcast = true;
     int64_t common_dims_prod = 1;
     int64_t num_common_dims = 0;

From 015d11923a05bd051ef6bb822839dd8ba5e04a5f Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui <bchetioui@google.com>
Date: Wed, 16 Apr 2025 09:59:54 -0700
Subject: [PATCH 0877/1324] [XLA:GPU] Fork `fusion_emitter_parametrized_test` in order to pursue the GEMM infrastructure test migration.

PiperOrigin-RevId: 748330315
---
 .../xla/xla/backends/gpu/codegen/triton/BUILD |   27 +
 ...fusion_emitter_parametrized_legacy_test.cc | 2391 +++++++++++++++++
 .../fusion_emitter_parametrized_test.cc       |   82 +-
 3 files changed, 2462 insertions(+), 38 deletions(-)
 create mode 100644 third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_legacy_test.cc

diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD
index cc323e19e4f271..a177538b71f353 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD
@@ -829,6 +829,33 @@ xla_test(
         "gpu_b200",
         "gpu_amd_any",
     ],
+    tags = ["no_mac"],
+    deps = [
+        ":support",
+        ":test_utils",
+        "//xla:comparison_util",
+        "//xla:error_spec",
+        "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu/tests:gpu_codegen_test",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+xla_test(
+    name = "fusion_emitter_parametrized_legacy_test",
+    srcs = if_gpu_is_configured(["fusion_emitter_parametrized_legacy_test.cc"]),
+    backends = [
+        "gpu_a100",
+        "gpu_h100",
+        "gpu_b200",
+        "gpu_amd_any",
+    ],
     shard_count = 10,
     tags = ["no_mac"],
     deps = [
diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_legacy_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_legacy_test.cc
new file mode 100644
index 00000000000000..7dc1de19dc85d1
--- /dev/null
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_legacy_test.cc
@@ -0,0 +1,2391 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <array>
+#include <string>
+#include <tuple>
+#include <variant>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/base/optimization.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/backends/gpu/codegen/triton/support_legacy.h"
+#include "xla/backends/gpu/codegen/triton/test_utils.h"
+#include "xla/comparison_util.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+struct MixTypeParams {
+  PrimitiveType lhs_ty;
+  PrimitiveType rhs_ty;
+  int m;
+  int k;
+  int n;
+  float aabs = 1e-6;
+  float arel = 1e-6;
+};
+
+class MixedTypeTest : public GpuCodegenTest,
+                      public ::testing::WithParamInterface<MixTypeParams> {
+ public:
+  se::GpuComputeCapability GetGpuComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  void SetUp() override {
+    if (std::holds_alternative<se::RocmComputeCapability>(
+            GetGpuComputeCapability())) {
+      GTEST_SKIP()
+          << "Related fusions are not performed on ROCm without Triton.";
+    }
+  }
+
+  DebugOptions GetDebugOptionsForTest() const override {
+    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+    // We are testing Triton, remove cuBLAS fallback for these tests.
+    debug_options.set_xla_gpu_cublas_fallback(false);
+    // Always rewrite Gemms with Triton regardless of size.
+ debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + return debug_options; + } +}; + +TEST_P(MixedTypeTest, MixedTypeDotProducesCorrectResult) { + MixTypeParams params = GetParam(); + const std::string hlo_string_template = R"( +HloModule m + +ENTRY e { + p0 = $0[$2,$3] parameter(0) + p0c = $1[$2,$3] convert(p0) + p1 = $1[$3,$4] parameter(1) + ROOT _ = $1[$2,$4] dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + std::string hlo_string = absl::Substitute( + hlo_string_template, + primitive_util::LowercasePrimitiveTypeName(params.lhs_ty), + primitive_util::LowercasePrimitiveTypeName(params.rhs_ty), params.m, + params.k, params.n); + MatchOptimizedHlo(hlo_string, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: kCustom +)"); + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{params.aabs, params.arel})); +} + +std::string GemmTestParamsParamsToString( + const ::testing::TestParamInfo& data) { + return absl::StrCat( + primitive_util::LowercasePrimitiveTypeName(data.param.lhs_ty), "_", + primitive_util::LowercasePrimitiveTypeName(data.param.rhs_ty), "_", + data.param.m, "_", data.param.k, "_", data.param.n); +} + +INSTANTIATE_TEST_SUITE_P(RewriteTestSuite, MixedTypeTest, + ::testing::ValuesIn({ + MixTypeParams{PRED, F16, 16, 32, 8}, + MixTypeParams{PRED, BF16, 16, 32, 8}, + MixTypeParams{PRED, F32, 16, 32, 8, 2e-4, 2e-3}, + MixTypeParams{S8, F16, 16, 32, 8}, + MixTypeParams{S8, BF16, 16, 32, 8}, + MixTypeParams{S8, F32, 16, 32, 8, 5e-2, 1e-2}, + MixTypeParams{S8, F32, 101, 7, 303, 0.1, 0.1}, + MixTypeParams{S8, F32, 101, 32, 303, 0.1, 0.1}, + MixTypeParams{S8, F32, 101, 2048, 303, 0.5, 0.1}, + MixTypeParams{S8, F32, 101, 2555, 303, 0.5, 0.1}, + // Is supported but overflows. + // GemmTestParams{S32, F16}, + MixTypeParams{S16, F16, 30, 19, 12}, + MixTypeParams{S32, F32, 4, 4, 4, 1, 1e-2}, + MixTypeParams{F16, BF16, 16, 32, 8}, + MixTypeParams{F16, F32, 16, 32, 8, 1e-3, 1e-6}, + MixTypeParams{BF16, F16, 16, 32, 8, 1e-3, 1e-6}, + MixTypeParams{BF16, F32, 16, 32, 8, 1e-3, 1e-6}, + // Supported but disabled because narrowing + // converts should rather belong to producers. + // TODO(b/266862493): Move these to CompareTest. + // TritonRewriteTest2Params{S32, BF16}, + // TritonRewriteTest2Params{F32, F16}, + // TritonRewriteTest2Params{F32, BF16}, + MixTypeParams{S8, BF16, 24, 40, 8}, + MixTypeParams{S8, F16, 80, 16, 32, 1e-3, 1e-6}, + MixTypeParams{F16, F32, 127, 3, 300, 1e-2, 1e-2}, + MixTypeParams{F16, BF16, 544, 96, 16, 1e-3, 1e-3}, + MixTypeParams{BF16, F32, 77, 500, 333, 3e-3, 3e-3}, + }), + GemmTestParamsParamsToString); + +class TritonTest : public GpuCodegenTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_cublas_fallback(false); + // Always rewrite Gemms with Triton regardless of size. 
+ debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + return debug_options; + } + + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } +}; + +class ElementwiseTest : public TritonTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +std::string ElementwiseTestParamsToString( + const ::testing::TestParamInfo>& + data) { + PrimitiveType data_type; + HloOpcode opcode; + float tolerance; + std::tie(data_type, opcode, tolerance) = data.param; + return absl::StrCat( + primitive_util::LowercasePrimitiveTypeName(data_type), "_", + absl::StrReplaceAll(HloOpcodeString(opcode), {{"-", "_"}})); +} + +using UnaryElementwiseTest = ElementwiseTest; + +TEST_P(UnaryElementwiseTest, ElementwiseFusionExecutesCorrectly) { + PrimitiveType data_type; + HloOpcode opcode; + float tolerance; + std::tie(data_type, opcode, tolerance) = GetParam(); + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[15,33]{1,0} parameter(0) + parameter_1 = $0[33,68]{1,0} parameter(1) + f1.1 = $0[33,68]{1,0} $1(parameter_1) + c.1 = f32[33,68]{1,0} convert(f1.1) + ROOT _.1 = f32[15,68]{1,0} dot(parameter_0, c.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + p1 = $0[33,68]{1,0} parameter(1) + p0 = f32[15,33]{1,0} parameter(0) + ROOT triton_gemm__ = f32[15,68]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm", + "triton_gemm_config": + {"block_m":"32", + "block_n":"32", + "block_k":"32", + "split_k":"1", + "num_stages":"1", + "num_warps":"4", + "num_ctas":"1"}}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + const std::string kHloRefTemplate = R"( +fused_computation { + param_0.1 = $0[33,68]{1,0} parameter(0) + f.1 = $0[33,68]{1,0} $1(param_0.1) + ROOT convert.1 = f32[33,68]{1,0} convert(f.1) +} + +ENTRY e { + p1 = $0[33,68]{1,0} parameter(1) + p0 = f32[15,33]{1,0} parameter(0) + fusion = f32[33,68]{1,0} fusion(p1), kind=kLoop, calls=fused_computation + gemm = (f32[15,68]{1,0}, s8[0]{0}) custom-call(p0, fusion), + custom_call_target="__cublas$$gemm", + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers": + {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, + "alpha_imag":0,"precision_config": + {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[15,68]{1,0} get-tuple-element((f32[15,68]{1,0}, s8[0]{0}) gemm), index=0 +})"; + const std::string hlo_ref = absl::Substitute( + kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance}, + /*run_hlo_passes=*/false)); +} + +TEST_P(UnaryElementwiseTest, ElementwiseUnaryOpExecutesCorrectly) { + PrimitiveType data_type; + HloOpcode opcode; + float tolerance; + std::tie(data_type, opcode, tolerance) = GetParam(); + + const std::string kHloTestTemplate = R"( +triton_computation { + parameter_0 = $0[33,68]{1,0} parameter(0) + output = $0[33,68]{1,0} $1(parameter_0) + ROOT convert = f32[33,68]{1,0} convert(output) +} + +ENTRY e { + p0 = $0[33,68]{1,0} parameter(0) + 
ROOT triton_fusion = f32[33,68]{1,0} fusion(p0), kind=kCustom, + calls=triton_computation, backend_config={ + "fusion_backend_config":{ + "kind":"__triton", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["1", "1"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + const std::string kHloRefTemplate = R"( +fused_computation { + param_0.1 = $0[33,68]{1,0} parameter(0) + output = $0[33,68]{1,0} $1(param_0.1) + ROOT convert = f32[33,68]{1,0} convert(output) +} + +ENTRY e { + p0 = $0[33,68]{1,0} parameter(0) + ROOT fusion = f32[33,68]{1,0} fusion(p0), kind=kLoop, calls=fused_computation +})"; + const std::string hlo_ref = absl::Substitute( + kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance}, + /*run_hlo_passes=*/false)); +} + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuitePRED, UnaryElementwiseTest, + ::testing::Combine( + ::testing::Values(PRED), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(PRED)), + ::testing::Values(3e-2)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteS8, UnaryElementwiseTest, + ::testing::Combine( + ::testing::Values(S8), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(S8)), + ::testing::Values(3e-2)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteS16, UnaryElementwiseTest, + ::testing::Combine( + ::testing::Values(S16), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(S16)), + ::testing::Values(1e-3)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteS32, UnaryElementwiseTest, + ::testing::Combine( + ::testing::Values(S32), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(S32)), + ::testing::Values(1e-3)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteF16, UnaryElementwiseTest, + ::testing::Combine( + ::testing::Values(F16), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(F16)), + ::testing::Values(2e-4)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteF32, UnaryElementwiseTest, + ::testing::Combine( + ::testing::Values(F32), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(F32)), + ::testing::Values(1e-6)), + ElementwiseTestParamsToString); + +using BinaryElementwiseTest = ElementwiseTest; + +TEST_P(BinaryElementwiseTest, ElementwiseFusionExecutesCorrectly) { + PrimitiveType data_type; + HloOpcode opcode; + float tolerance; + std::tie(data_type, opcode, tolerance) = GetParam(); + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + f1.1 = $0[11,63]{1,0} $1(parameter_1, parameter_2) + c.1 = f32[11,63]{1,0} convert(f1.1) + ROOT _.1 = f32[92,63]{1,0} dot(parameter_0, c.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + p0 = f32[92,11]{1,0} 
parameter(0) + p1 = $0[11,63]{1,0} parameter(1) + p2 = $0[11,63]{1,0} parameter(2) + ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0, p1, p2), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm", + "triton_gemm_config": + {"block_m":"64", + "block_n":"32", + "block_k":"64", + "split_k":"1", + "num_stages":"2", + "num_warps":"2", + "num_ctas":"1"}}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + const std::string kHloRefTemplate = R"( +fused_computation { + p0 = $0[11,63]{1,0} parameter(0) + p1 = $0[11,63]{1,0} parameter(1) + f.1 = $0[11,63]{1,0} $1(p0, p1) + ROOT convert.1 = f32[11,63]{1,0} convert(f.1) +} + +ENTRY e { + p2 = $0[11,63]{1,0} parameter(2) + p1 = $0[11,63]{1,0} parameter(1) + p0 = f32[92,11]{1,0} parameter(0) + fusion = f32[11,63]{1,0} fusion(p1, p2), kind=kLoop, calls=fused_computation + gemm = (f32[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion), + custom_call_target="__cublas$$gemm", + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers": + {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, + "alpha_imag":0,"precision_config": + {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[92,63]{1,0} get-tuple-element((f32[92,63]{1,0}, s8[0]{0}) gemm), index=0 +})"; + const std::string hlo_ref = absl::Substitute( + kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance}, + /*run_hlo_passes=*/false, /*args_max_bits_of_precision=*/6)); +} + +TEST_P(BinaryElementwiseTest, ElementwiseBinaryOpExecutesCorrectly) { + PrimitiveType data_type; + HloOpcode opcode; + float tolerance; + std::tie(data_type, opcode, tolerance) = GetParam(); + + const std::string kHloTestTemplate = R"( +triton_computation { + parameter_0 = $0[11,63]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + output = $0[11,63]{1,0} $1(parameter_0, parameter_1) + ROOT c.1 = f32[11,63]{1,0} convert(output) +} + +ENTRY e { + p0 = $0[11,63]{1,0} parameter(0) + p1 = $0[11,63]{1,0} parameter(1) + ROOT triton_fusion = f32[11,63]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_computation, backend_config={ + "fusion_backend_config":{ + "kind":"__triton", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["1", "1"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + const std::string kHloRefTemplate = R"( +fused_computation { + p0 = $0[11,63]{1,0} parameter(0) + p1 = $0[11,63]{1,0} parameter(1) + output = $0[11,63]{1,0} $1(p0, p1) + ROOT convert.1 = f32[11,63]{1,0} convert(output) +} + +ENTRY e { + p1 = $0[11,63]{1,0} parameter(1) + p0 = $0[11,63]{1,0} parameter(0) + ROOT fusion = f32[11,63]{1,0} fusion(p0, p1), kind=kLoop, calls=fused_computation +})"; + const std::string hlo_ref = absl::Substitute( + kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance}, + /*run_hlo_passes=*/false, 
/*args_max_bits_of_precision=*/6)); +} + +bool HloOpcodeIsComparison(HloOpcode opcode) { + return opcode == HloOpcode::kCompare; +} +std::vector TestedBinaryElementwise(PrimitiveType element_type) { + std::vector ret = + legacy_triton::TritonSupportedBinaryElementwiseUpToFloatNormalization( + element_type); + // Comparison requires an additional property. + ret.erase(std::remove_if(ret.begin(), ret.end(), HloOpcodeIsComparison), + ret.end()); + return ret; +} + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuitePRED, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(PRED), + ::testing::ValuesIn(TestedBinaryElementwise(PRED)), + ::testing::Values(0)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteS8, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(S8), + ::testing::ValuesIn(TestedBinaryElementwise(S8)), + ::testing::Values(0)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteS16, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(S16), + ::testing::ValuesIn(TestedBinaryElementwise(S16)), + ::testing::Values(0)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteS32, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(S32), + ::testing::ValuesIn(TestedBinaryElementwise(S32)), + ::testing::Values(0)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteF16, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(F16), + ::testing::ValuesIn(TestedBinaryElementwise(F16)), + ::testing::Values(2e-4)), + ElementwiseTestParamsToString); + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteF32, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(F32), + ::testing::ValuesIn(TestedBinaryElementwise(F32)), + ::testing::Values(1e-6)), + ElementwiseTestParamsToString); + +class CompareTest : public TritonTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +std::string CompareTestParamsToString( + const ::testing::TestParamInfo< + std::tuple>& data) { + PrimitiveType data_type; + Comparison::Direction direction; + std::tie(data_type, direction) = data.param; + return absl::StrCat(primitive_util::LowercasePrimitiveTypeName(data_type), + "_", ComparisonDirectionToString(direction)); +} + +TEST_P(CompareTest, CompareFusionExecutesCorrectly) { + PrimitiveType data_type; + Comparison::Direction direction; + std::tie(data_type, direction) = GetParam(); + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + f1.1 = pred[11,63]{1,0} compare(parameter_1, parameter_2), direction=$1 + c.1 = f32[11,63]{1,0} convert(f1.1) + ROOT _.1 = f32[92,63]{1,0} dot(parameter_0, c.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + p0 = f32[92,11]{1,0} parameter(0) + p1 = $0[11,63]{1,0} parameter(1) + p2 = $0[11,63]{1,0} parameter(2) + ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0, p1, p2), kind=kCustom, + calls=triton_gemm___computation, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm", + "triton_gemm_config": { + "block_m":"16", + "block_n":"64", + "block_k":"16", + "split_k":"1", + "num_stages":"3", + "num_warps":"2", + "num_ctas":"1"}}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + 
ComparisonDirectionToString(direction)); + + const std::string kHloRefTemplate = R"( +fused_computation { + p0 = $0[11,63]{1,0} parameter(0) + p1 = $0[11,63]{1,0} parameter(1) + f.1 = pred[11,63]{1,0} compare(p0, p1), direction=$1 + ROOT convert.1 = f32[11,63]{1,0} convert(f.1) +} + +ENTRY e { + p2 = $0[11,63]{1,0} parameter(2) + p1 = $0[11,63]{1,0} parameter(1) + p0 = f32[92,11]{1,0} parameter(0) + fusion = f32[11,63]{1,0} fusion(p1, p2), kind=kLoop, calls=fused_computation + gemm = (f32[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion), + custom_call_target="__cublas$$gemm", + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers": + {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, + "alpha_imag":0,"precision_config": + {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[92,63]{1,0} get-tuple-element((f32[92,63]{1,0}, s8[0]{0}) gemm), index=0 +})"; + const std::string hlo_ref = absl::Substitute( + kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + ComparisonDirectionToString(direction)); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case PRED: + case S8: + tolerance = 3e-2; + break; + case S16: + tolerance = 1e-3; + break; + case S32: + tolerance = 1e-5; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance}, + /*run_hlo_passes=*/false)); +} + +using cd = Comparison::Direction; + +INSTANTIATE_TEST_SUITE_P( + CompareTestSuite, CompareTest, + ::testing::Combine(::testing::Values(PRED, S8, S16, S32, F16, F32), + ::testing::Values(cd::kEq, cd::kNe, cd::kGe, cd::kGt, + cd::kLe, cd::kLt)), + CompareTestParamsToString); + +class SelectTest : public TritonTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(SelectTest, SelectFusionExecutesCorrectly) { + PrimitiveType data_type1, data_type2; + std::tie(data_type1, data_type2) = GetParam(); + for (const PrimitiveType type : {data_type1, data_type2}) { + if (!legacy_triton::IsTritonSupportedDataType(type, + GetCudaComputeCapability())) { + GTEST_SKIP() << absl::Substitute( + "Unsupported data type: $0", + primitive_util::LowercasePrimitiveTypeName(type)); + } + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = $1[92,13]{1,0} parameter(0) + parameter_1 = $0[13,63]{1,0} parameter(1) + parameter_2 = $0[13,63]{1,0} parameter(2) + parameter_3 = pred[13,63]{1,0} parameter(3) + f1.1 = $0[13,63]{1,0} select(parameter_3, parameter_1, parameter_2) + c.1 = $1[13,63]{1,0} convert(f1.1) + ROOT _.1 = $1[92,63]{1,0} dot(parameter_0, c.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + p0 = $1[92,13]{1,0} parameter(0) + p1 = $0[13,63]{1,0} parameter(1) + p2 = $0[13,63]{1,0} parameter(2) + p3 = pred[13,63]{1,0} parameter(3) + ROOT triton_gemm__ = $1[92,63]{1,0} fusion(p0, p1, p2, p3), kind=kCustom, + calls=triton_gemm___computation, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm", + "triton_gemm_config": { + "block_m":"16", + "block_n":"64", + "block_k":"16", + "split_k":"1", + "num_stages":"3", + "num_warps":"2", + "num_ctas":"1"}}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type1), + 
primitive_util::LowercasePrimitiveTypeName(data_type2));
+
+  const std::string kHloRefTemplate = R"(
+fused_computation {
+  p0 = $0[13,63]{1,0} parameter(0)
+  p1 = $0[13,63]{1,0} parameter(1)
+  p2 = pred[13,63]{1,0} parameter(2)
+  f.1 = $0[13,63]{1,0} select(p2, p0, p1)
+  ROOT convert.1 = $1[13,63]{1,0} convert(f.1)
+}
+
+ENTRY e {
+  p3 = pred[13,63]{1,0} parameter(3)
+  p2 = $0[13,63]{1,0} parameter(2)
+  p1 = $0[13,63]{1,0} parameter(1)
+  p0 = $1[92,13]{1,0} parameter(0)
+  fusion = $1[13,63]{1,0} fusion(p1, p2, p3), kind=kLoop,
+    calls=fused_computation
+  gemm = ($1[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion),
+    custom_call_target="__cublas$$gemm",
+    backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":
+    {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],
+    "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},
+    "alpha_imag":0,"precision_config":
+    {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}}
+  ROOT get-tuple-element = $1[92,63]{1,0} get-tuple-element(($1[92,63]{1,0}, s8[0]{0}) gemm), index=0
+})";
+  const std::string hlo_ref = absl::Substitute(
+      kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type1),
+      primitive_util::LowercasePrimitiveTypeName(data_type2));
+
+  EXPECT_TRUE(RunAndCompareTwoModules(
+      hlo_ref, hlo_test, ErrorSpec{/*aabs=*/0, /*arel=*/0},
+      /*run_hlo_passes=*/false, /*args_max_bits_of_precision=*/9));
+}
+
+std::string TwoPrimitiveTypesToString(
+    const ::testing::TestParamInfo<std::tuple<PrimitiveType, PrimitiveType>>&
+        data) {
+  PrimitiveType data_type1;
+  PrimitiveType data_type2;
+  std::tie(data_type1, data_type2) = data.param;
+  return absl::StrCat(primitive_util::LowercasePrimitiveTypeName(data_type1),
+                      "_",
+                      primitive_util::LowercasePrimitiveTypeName(data_type2));
+}
+
+// BF16: supported or not depending on the GPU generation.
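+// Tests that draw from this list still skip unsupported types at runtime via
+// legacy_triton::IsTritonSupportedDataType for the detected compute
+// capability.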
+constexpr std::array kSupportedDataTypes{PRED, S8, S16, S32, + F16, F32, BF16}; + +INSTANTIATE_TEST_SUITE_P( + SelectTestSuite, SelectTest, + ::testing::Combine(::testing::ValuesIn(kSupportedDataTypes), + ::testing::Values(F16, BF16, F32)), + TwoPrimitiveTypesToString); + +class ConstantTest : public TritonTest, + public ::testing::WithParamInterface {}; + +TEST_P(ConstantTest, ConstantFusionExecutesCorrectly) { + const PrimitiveType data_type = GetParam(); + if (!legacy_triton::IsTritonSupportedDataType(data_type, + GetCudaComputeCapability())) { + GTEST_SKIP() << absl::Substitute( + "Unsupported data type: $0", + primitive_util::LowercasePrimitiveTypeName(data_type)); + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = f32[11,63]{1,0} parameter(1) + c = $0[] constant(123) + b = $0[11,63] broadcast(c) + cv = f32[11,63] convert(b) + m = f32[11,63] multiply(cv, parameter_1) + ROOT _.1 = f32[92,63]{1,0} dot(parameter_0, m), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + p0 = f32[92,11]{1,0} parameter(0) + p1 = f32[11,63]{1,0} parameter(1) + ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm___computation, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm", + "triton_gemm_config":{ + "block_m":"16", + "block_n":"64", + "block_k":"16", + "split_k":"1", + "num_stages":"3", + "num_warps":"2", + "num_ctas":"1"}}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string kHloRefTemplate = R"( +fused_computation { + p0 = f32[11,63]{1,0} parameter(0) + c = $0[] constant(123) + b = $0[11,63] broadcast(c) + cv = f32[11,63] convert(b) + ROOT m = f32[11,63] multiply(cv, p0) +} + +ENTRY e { + p1 = f32[11,63]{1,0} parameter(1) + p0 = f32[92,11]{1,0} parameter(0) + fusion = f32[11,63]{1,0} fusion(p1), kind=kLoop, + calls=fused_computation + gemm = (f32[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion), + custom_call_target="__cublas$$gemm", + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers": + {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, + "alpha_imag":0,"precision_config": + {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[92,63]{1, 0} get-tuple-element((f32[92,63]{1, 0}, s8[0]{0}) gemm), index=0 +})"; + const std::string hlo_ref = absl::Substitute( + kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type)); + + float tolerance; + switch (data_type) { + case F32: + case BF16: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case PRED: + case S8: + tolerance = 3e-2; + break; + case S16: + tolerance = 1e-3; + break; + case S32: + tolerance = 1e-5; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance}, + /*run_hlo_passes=*/false)); +} + +INSTANTIATE_TEST_SUITE_P(ConstantTestSuite, ConstantTest, + ::testing::ValuesIn(kSupportedDataTypes), + TritonSupportTestTypeToString); + +class ConvertTest : public TritonTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(ConvertTest, ConvertFusionExecutesCorrectly) { + PrimitiveType data_type1, data_type2; + std::tie(data_type1, data_type2) = GetParam(); + 
for (const PrimitiveType type : {data_type1, data_type2}) { + if (!legacy_triton::IsTritonSupportedDataType(type, + GetCudaComputeCapability())) { + GTEST_SKIP() << absl::Substitute( + "Unsupported data type: $0", + primitive_util::LowercasePrimitiveTypeName(type)); + } + } + + const std::string hlo_text = absl::Substitute( + R"( +t { + p0 = $0[2,2] parameter(0) + p0c = $1[2,2] convert(p0) + p0cc = f32[2,2] convert(p0c) + p1 = f32[2,2] parameter(1) + ROOT r = f32[2,2] dot(p0cc, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + p0 = $0[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT r = f32[2,2] fusion(p0, p1), kind=kCustom, calls=t, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})", + primitive_util::LowercasePrimitiveTypeName(data_type1), + primitive_util::LowercasePrimitiveTypeName(data_type2)); + + MatchOptimizedHlo(hlo_text, R"( +CHECK: block_m + )"); +} + +INSTANTIATE_TEST_SUITE_P( + ConvertTestSuite, ConvertTest, + ::testing::Combine(::testing::ValuesIn(kSupportedDataTypes), + ::testing::ValuesIn(kSupportedDataTypes)), + TwoPrimitiveTypesToString); + +class TritonNormalizationTest + : public GpuCodegenTest, + public ::testing::WithParamInterface { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // TODO(b/38354253): Remove once HloTestBase does not remove constant + // folding. + debug_options.clear_xla_disable_hlo_passes(); + return debug_options; + } +}; + +TEST_P(TritonNormalizationTest, CanFuseAndEmitExactSoftmax) { + PrimitiveType data_type = GetParam(); + + if (data_type == F16) { + GTEST_SKIP() << "Exponential op does not support F16."; + } + + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +add_computation { + arg_0.1 = $0[] parameter(0) + arg_1.1 = $0[] parameter(1) + ROOT add = $0[] add(arg_0.1, arg_1.1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + subtract = $0[127,125]{1,0} subtract(param_0, broadcast) + exponential = $0[127,125]{1,0} exponential(subtract) + constant_zero = $0[] constant(0) + second_reduce = $0[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation + second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0} + ROOT divide = $0[127,125]{1,0} divide(exponential, second_broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[param_0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[param_0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case BF16: + tolerance = 2e-4; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + 
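+// In the tests below, a "diamond" is the subgraph reduce -> broadcast ->
+// binary op, where the binary op also consumes the reduce's input directly;
+// this is the shape-restoring pattern the softmax rewriter matches on.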
+TEST_P(TritonNormalizationTest, CanFuseAndEmitFirstSoftmaxDiamond) { + PrimitiveType data_type = GetParam(); + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +add_computation { + arg_0.1 = $0[] parameter(0) + arg_1.1 = $0[] parameter(1) + ROOT add = $0[] add(arg_0.1, arg_1.1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = $0[127,125]{1,0} subtract(param_0, broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + case BF16: + tolerance = 2e-4; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P(TritonNormalizationTest, CanFuseAndEmitSoftmaxDiamondWithSmallRows) { + PrimitiveType data_type = GetParam(); + constexpr absl::string_view kHloTextTemplate = R"( +HloModule softmax +min_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT minimum = $0[] minimum(arg_0, arg_1) +} +ENTRY main { + param_0 = $0[127,7]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=min_computation + broadcast = $0[127,7]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = $0[127,7]{1,0} subtract(param_0, broadcast) +} +)"; + + const std::string hlo_text = absl::Substitute( + kHloTextTemplate, primitive_util::LowercasePrimitiveTypeName(data_type)); + + constexpr absl::string_view kHloRefTemplate = R"( +; CHECK: ENTRY +; CHECK: %[[param_0:.*]] = $0[127,7]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[param_0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(/*aabs=*/0, /*arel=*/0))); +} + +TEST_F(TritonNormalizationTest, CanFuseAndEmitDiamondWithBF16Converts) { + const std::string hlo_text = R"( +HloModule softmax +max_computation { + arg_0 = bf16[] parameter(0) + arg_1 = bf16[] parameter(1) + ROOT maximum = bf16[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = bf16[127,125]{1,0} parameter(0) + constant_neg_inf = bf16[] constant(-inf) + reduce = bf16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = bf16[127,125]{1,0} subtract(param_0, broadcast) +} +)"; + + const std::string hlo_ref = R"( +; CHECK: %[[P0_FUSION:.*]] = bf16[127,125]{1,0} parameter(0) +; CHECK: %[[convert:.*]] = f32[127,125]{1,0} convert(%[[P0_FUSION]]) +; CHECK: 
ENTRY +; CHECK: %[[P0_ENTRY:.*]] = bf16[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0_ENTRY]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance = 2e-4; + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P(TritonNormalizationTest, + CanFuseAndEmitDiamondWithMultipleBroadcastDimensions) { + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = $0[1,3,125,125]{3,2,1,0} parameter(0) + reshape = $0[3,125,125]{2,1,0} reshape($0[1,3,125,125]{3,2,1,0} param_0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[3,125]{1,0} reduce($0[3,125,125]{2,1,0} reshape, $0[] constant_neg_inf), dimensions={2}, to_apply=max_computation + broadcast = $0[1,3,125,125]{3,2,1,0} broadcast($0[3,125]{1,0} reduce), dimensions={1,2} + ROOT subtract = $0[1,3,125,125]{3,2,1,0} subtract($0[1,3,125,125]{3,2,1,0} param_0, $0[1,3,125,125]{3,2,1,0} broadcast) +})"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[1,3,125,125]{3,2,1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-3; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P(TritonNormalizationTest, + CanFuseAndEmitSoftmaxWithIntermediateUnaryElementwise) { + PrimitiveType data_type = GetParam(); + + if (data_type == F16) { + GTEST_SKIP() << "Exponential op does not support F16."; + } + + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +add_computation { + arg_0.1 = $0[] parameter(0) + arg_1.1 = $0[] parameter(1) + ROOT add = $0[] add(arg_0.1, arg_1.1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + subtract = $0[127,125]{1,0} subtract(param_0, broadcast) + abs = $0[127,125]{1,0} abs(subtract) + exponential = $0[127,125]{1,0} exponential(abs) + constant_zero = $0[] constant(0) + second_reduce = $0[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation + second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0} + ROOT divide = $0[127,125]{1,0} divide(exponential, second_broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: 
kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case BF16: + tolerance = 2e-3; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) { + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +add_computation { + arg_0.1 = $0[] parameter(0) + arg_1.1 = $0[] parameter(1) + ROOT add = $0[] add(arg_0.1, arg_1.1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + subtract = $0[127,125]{1,0} subtract(param_0, broadcast) + constant_zero = $0[] constant(0) + second_reduce = $0[127]{0} reduce(subtract, constant_zero), dimensions={1}, to_apply=add_computation + second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0} + ROOT multiply = $0[127,125]{1,0} multiply(subtract, second_broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + case BF16: + tolerance = 2e-2; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P(TritonNormalizationTest, + CanFuseAndEmitDiamondWithTrailingUnaryElementwiseAtTheRoot) { + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + subtract = $0[127,125]{1,0} subtract(param_0, broadcast) + ROOT abs = $0[127,125]{1,0} abs(subtract) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float 
tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-3; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P(TritonNormalizationTest, + CanFuseAndEmitDiamondWithUnaryElementwisePrefix) { + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + abs = $0[127,125]{1,0} abs(param_0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(abs, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = $0[127,125]{1,0} subtract(param_0, broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-3; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P(TritonNormalizationTest, + CanFuseAndEmitSoftmaxDiamondWithLastDimensionBitcastAfterReduce) { + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} + +ENTRY main { + param_0 = $0[3,127,125]{2,1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[3,127]{1,0} reduce(param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation + bitcasted_reduce = $0[381]{0} reshape(reduce) + broadcast = $0[381,125]{1,0} broadcast(bitcasted_reduce), dimensions={0} + bitcasted_broadcast = $0[3,127,125]{2,1,0} reshape(broadcast) + ROOT subtract = $0[3,127,125]{2,1,0} subtract(param_0, bitcasted_broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[3,127,125]{2,1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-3; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P(TritonNormalizationTest, + CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectly) { + PrimitiveType data_type = 
GetParam(); + + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = bf16[127,125]{1,0} parameter(0) + param_0_$0 = $0[127,125]{1,0} convert(param_0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0_$0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = $0[127,125]{1,0} subtract(param_0_$0, broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-3; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule fusible_diamond +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + multiply = $0[127,125]{1,0} multiply(param_0, param_0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = $0[127,125]{1,0} subtract(multiply, broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 3e-3; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule fusible_diamond +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + multiply = $0[127]{0} multiply(reduce, reduce) + broadcast = $0[127,125]{1,0} broadcast(multiply), 
dimensions={0} + ROOT subtract = $0[127,125]{1,0} subtract(param_0, broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-3; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule fusible_diamonds +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + subtract = $0[127,125]{1,0} subtract(param_0, broadcast) + add = $0[127,125]{1,0} add(subtract, subtract) + second_reduce = $0[127]{0} reduce(add, constant_neg_inf), dimensions={1}, to_apply=max_computation + second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0} + ROOT add_root = $0[127,125]{1,0} add(add, second_broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + // The precision-changing ops in the kernel above are add & subtract, meaning + // a value can be X*2(first add)*2(subtract)*2(second add) larger than it was + // originally. In order to fit this into a datatype, we do: + // X*2^3 <= 2^(fraction bits of the data type) + // 2^(max_bits_of_precision)*2^3 <= 2^(fraction bits of the data type) + // max_bits_of_precision = fraction_bits - 3. 
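+  // Concretely, with IEEE mantissa widths: F32 has 23 fraction bits, giving
+  // max_bits_of_precision = 20; F16 has 10, giving 7; BF16 has 7, giving 4.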
+  uint max_bits_of_precision;
+  switch (data_type) {
+    case F32:
+      max_bits_of_precision = 20;
+      break;
+    case F16:
+      max_bits_of_precision = 7;
+      break;
+    case BF16:
+      max_bits_of_precision = 4;
+      break;
+    default:
+      ABSL_UNREACHABLE();
+  }
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(/*aabs=*/0, /*arel=*/0),
+                            /*reference_preprocessor=*/nullptr,
+                            /*test_preprocessor=*/nullptr,
+                            max_bits_of_precision));
+}
+
+TEST_P(
+    TritonNormalizationTest,
+    CanFuseAndEmitBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) {  // NOLINT(whitespace/line_length)
+  PrimitiveType data_type = GetParam();
+
+  const std::string hlo_text_template = R"(
+HloModule fusible_diamond
+max_computation {
+  arg_0 = $0[] parameter(0)
+  arg_1 = $0[] parameter(1)
+  ROOT maximum = $0[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = $0[] parameter(0)
+  arg_1.1 = $0[] parameter(1)
+  ROOT add = $0[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = $0[127,125]{1,0} parameter(0)
+  constant_neg_inf = $0[] constant(-inf)
+  reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = $0[127,125]{1,0} subtract(param_0, broadcast)
+  ROOT multiply = $0[127,125]{1,0} multiply(subtract, subtract)
+}
+)";
+  const std::string hlo_text = absl::Substitute(
+      hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type));
+
+  const std::string hlo_ref_template = R"(
+; CHECK: ENTRY
+; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0)
+; CHECK: ROOT
+; CHECK-SAME: fusion(%[[P0]])
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: __triton
+)";
+
+  const std::string hlo_ref = absl::Substitute(
+      hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type));
+
+  MatchOptimizedHlo(hlo_text, hlo_ref);
+
+  float tolerance;
+  switch (data_type) {
+    case F32:
+      tolerance = 1e-6;
+      break;
+    case F16:
+      tolerance = 2e-4;
+      break;
+    case BF16:
+      tolerance = 2e-3;
+      break;
+    default:
+      ABSL_UNREACHABLE();
+  }
+  EXPECT_TRUE(RunAndCompare(hlo_text,
+                            ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance)));
+}
+
+TEST_P(
+    TritonNormalizationTest,
+    CanFuseAndEmitTwoBinaryElementwiseWhereBothOperandsAreTheSameBetweenDiamonds) {  // NOLINT(whitespace/line_length)
+  PrimitiveType data_type = GetParam();
+
+  const std::string hlo_text_template = R"(
+HloModule fusible_diamonds
+max_computation {
+  arg_0 = $0[] parameter(0)
+  arg_1 = $0[] parameter(1)
+  ROOT maximum = $0[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = $0[] parameter(0)
+  arg_1.1 = $0[] parameter(1)
+  ROOT add = $0[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = $0[127,125]{1,0} parameter(0)
+  constant_neg_inf = $0[] constant(-inf)
+  reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = $0[127,125]{1,0} subtract(param_0, broadcast)
+  add = $0[127,125]{1,0} add(subtract, subtract)
+  multiply = $0[127,125]{1,0} multiply(add, add)
+  constant_zero = $0[] constant(0)
+  second_reduce = $0[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  ROOT multiply_root = $0[127,125]{1,0} multiply(multiply, second_broadcast)
+}
+)";
+  const std::string hlo_text = absl::Substitute(
+      hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type));
+
+  const std::string hlo_ref_template =
R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + case BF16: + tolerance = 2e-2; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P(TritonNormalizationTest, DiamondEmitterIsNumericallyStable) { + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule softmax +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +min_computation { + arg_0.1 = $0[] parameter(0) + arg_1.1 = $0[] parameter(1) + ROOT minimum = $0[] minimum(arg_0.1, arg_1.1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + subtract = $0[127,125]{1,0} subtract(param_0, broadcast) + exponential = $0[127,125]{1,0} exponential(subtract) + constant_zero = $0[] constant(0) + second_reduce = $0[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=min_computation + second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0} + ROOT divide = $0[127,125]{1,0} divide(exponential, second_broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(/*aabs=*/0, /*arel=*/0))); +} + +TEST_P(TritonNormalizationTest, CanFuseAndEmitRMSNormDiamond) { + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule rms_norm +add_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT add.1 = $0[] add(arg_0, arg_1) +} +ENTRY main.30 { + param_0 = $0[10,10,10,128]{3,2,1,0} parameter(0) + multiply_param = $0[10,10,10,128]{3,2,1,0} multiply(param_0, param_0) + constant_0 = $0[] constant(0) + reduce = $0[10,10,10]{2,1,0} reduce(multiply_param, constant_0), dimensions={3}, to_apply=add_computation + constant_1 = $0[] constant(0.333333343) + splat = $0[10,10,10]{2,1,0} broadcast(constant_1), dimensions={} + multiply_splat = $0[10,10,10]{2,1,0} multiply(reduce, splat) + epsilon = $0[] constant(1e-06) + splat_epsilon = $0[10,10,10]{2,1,0} broadcast(epsilon), dimensions={} + add = $0[10,10,10]{2,1,0} add(multiply_splat, splat_epsilon) + rsqrt = $0[10,10,10]{2,1,0} rsqrt(add) + broadcast = $0[10,10,10,128]{3,2,1,0} broadcast(rsqrt), dimensions={0,1,2} + ROOT multiply = $0[10,10,10,128]{3,2,1,0} multiply(param_0, broadcast) +} +)"; + + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[10,10,10,128]{3,2,1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); 
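+  // The squared-sum reduction and rsqrt in the RMSNorm diamond amplify
+  // rounding error, hence the much looser tolerances for the
+  // reduced-precision types below.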
+ + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 5e-4; + break; + case BF16: + tolerance = 4e-2; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule fusible_diamonds +add_computation { + arg_0.1 = $0[] parameter(0) + arg_1.1 = $0[] parameter(1) + ROOT add = $0[] add(arg_0.1, arg_1.1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + subtract = $0[127,125]{1,0} subtract(param_0, broadcast) + constant = $0[] constant(0.333333343) + broadcast_splat = $0[127,125]{1,0} broadcast(constant), dimensions={} + multiply = $0[127,125]{1,0} multiply(broadcast_splat, subtract) + constant_zero = $0[] constant(0) + second_reduce = $0[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation + second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0} + ROOT second_subtract = $0[127,125]{1,0} subtract(multiply, second_broadcast) +} +)"; + + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-2; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule fusible_diamonds +add_computation { + arg_0.1 = $0[] parameter(0) + arg_1.1 = $0[] parameter(1) + ROOT add = $0[] add(arg_0.1, arg_1.1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + subtract = $0[127,125]{1,0} subtract(param_0, broadcast) + constant = $0[] constant(0.333333343) + broadcast_splat = $0[127,125]{1,0} broadcast(constant), dimensions={} + multiply = $0[127,125]{1,0} multiply(subtract, broadcast_splat) + constant_zero = $0[] constant(0) + second_reduce = $0[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation + second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0} + ROOT second_subtract = $0[127,125]{1,0} subtract(multiply, second_broadcast) +} +)"; + + const std::string hlo_text = 
absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-2; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) { // NOLINT(whitespace/line_length) + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule fusible_diamond +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + constant = $0[] constant(0.333333343) + broadcast_splat = $0[127]{0} broadcast(constant), dimensions={} + multiply = $0[127]{0} multiply(broadcast_splat, reduce) + broadcast = $0[127,125]{1,0} broadcast(multiply), dimensions={0} + ROOT subtract = $0[127,125]{1,0} subtract(param_0, broadcast) +} +)"; + + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-2; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule fusible_diamond +add_computation { + arg_0.1 = $0[] parameter(0) + arg_1.1 = $0[] parameter(1) + ROOT add = $0[] add(arg_0.1, arg_1.1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + subtract = $0[127,125]{1,0} subtract(param_0, broadcast) + constant = $0[] constant(0.333333343) + broadcast_splat = $0[127,125]{1,0} broadcast(constant), dimensions={} + ROOT multiply = $0[127,125]{1,0} multiply(broadcast_splat, subtract) +} +)"; + + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string 
hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-2; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitBinaryElementwiseProducerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule fusible_diamond +add_computation { + arg_0.1 = $0[] parameter(0) + arg_1.1 = $0[] parameter(1) + ROOT add = $0[] add(arg_0.1, arg_1.1) +} +ENTRY main { + + param_0 = $0[127,125]{1,0} parameter(0) + constant = $0[] constant(0.333333343) + broadcast_splat = $0[127,125]{1,0} broadcast(constant), dimensions={} + multiply = $0[127,125]{1,0} multiply(broadcast_splat, param_0) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=add_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = $0[127,125]{1,0} subtract(multiply, broadcast) +} +)"; + + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance; + switch (data_type) { + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-2; + break; + default: + ABSL_UNREACHABLE(); + } + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_P( + TritonNormalizationTest, + CanFuseAndEmitBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducerIntoDiamond) { // NOLINT(whitespace/line_length) + PrimitiveType data_type = GetParam(); + + const std::string hlo_text_template = R"( +HloModule nonfusible_diamond +max_computation { + arg_0 = $0[] parameter(0) + arg_1 = $0[] parameter(1) + ROOT max = $0[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = $0[127,125]{1,0} parameter(0) + param_1 = $0[127,125]{1,0} parameter(1) + constant_2 = $0[] constant(2) + broadcast_splat = $0[127,125]{1,0} broadcast(constant_2), dimensions={} + multiply = $0[127,125]{1,0} multiply(param_0, broadcast_splat) + constant_neg_inf = $0[] constant(-inf) + reduce = $0[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = $0[127,125]{1,0} subtract(param_0, broadcast) +} +)"; + const std::string hlo_text = absl::Substitute( + hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + const std::string hlo_ref_template = R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = 
$0[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + const std::string hlo_ref = absl::Substitute( + hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); + + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance = 0.0; + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +INSTANTIATE_TEST_SUITE_P(TritonNormalizationTestSuite, TritonNormalizationTest, + ::testing::Values(F32, F16, BF16)); + +TEST_F(TritonNormalizationTest, CanFuseAndEmitTritonSoftmaxWithTwoParameters) { + const std::string hlo_text = R"( +HloModule layernorm + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(0) + param_1 = f32[127]{0} parameter(1) + broadcast_0 = f32[125,127]{1,0} broadcast(param_1), dimensions={1} + multiply_0 = f32[125,127]{1,0} multiply(param_0, broadcast_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} +)"; + + const std::string hlo_ref = R"( +; CHECK: ENTRY +; CHECK-DAG: %[[param_0:.*]] = f32[125,127]{1,0} parameter(0) +; CHECK-DAG: %[[param_1:.*]] = f32[127]{0} parameter(1) +; CHECK: ROOT +; CHECK-SAME: f32[125,127]{1,0} fusion +; CHECK-SAME: %[[param_0]] +; CHECK-SAME: %[[param_1]] +; CHECK-SAME: kind=kCustom +; CHECK-SAME: triton_softmax +)"; + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance = 2e-6; + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_F(TritonNormalizationTest, CanFuseAndEmitTritonSoftmaxWithNonBatchReduce) { + const std::string hlo_text = R"( +HloModule layernorm + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(0) + param_1 = f32[10,125,127]{2,1,0} parameter(1) + constant = f32[] constant(0) + reduce_0 = f32[125,127]{1,0} reduce(param_1, constant), dimensions={0}, to_apply=add + multiply_0 = f32[125,127]{1,0} multiply(param_0, reduce_0) + constant_0 = f32[] constant(0) + reduce_1 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_1), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} +)"; + + // We expect to not fuse everything into the triton softmax, because of the + // reduce over the non-row dimension. 
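+  // (Hence the CHECK lines below match a generic __triton fusion rather than
+  // the triton_softmax fusion matched in the previous test.)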
+ const std::string hlo_ref = R"( +; CHECK: ENTRY +; CHECK-DAG: %[[P0:.*]] = f32[125,127]{1,0} parameter(0) +; CHECK-DAG: %[[P1:.*]] = f32[10,125,127]{2,1,0} parameter(1) +; CHECK: ROOT %[[FUSION:.*]] = f32[125,127]{1,0} fusion(%[[P0]], %[[P1]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton +)"; + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance = 2e-6; + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +class ReductionTypeTest : public TritonTest, + public ::testing::WithParamInterface { +}; + +TEST_P(ReductionTypeTest, DifferentReductionTypes) { + PrimitiveType data_type = GetParam(); + + const std::string kHloTestTemplate = R"( +max { + p0 = $0[] parameter(0) + p1 = $0[] parameter(1) + ROOT max = $0[] maximum(p0, p1) +} + +triton_computation { + p = $0[400,16] parameter(0) + zero = $0[] constant(0) + ROOT reduce = $0[400] reduce(p, zero), dimensions={1}, to_apply=max +} + +ENTRY entry_computation { + p = $0[400,16] parameter(0) + ROOT fusion = $0[400] fusion(p), kind=kCustom, calls=triton_computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__triton", + "block_level_fusion_config":{ + "output_tiles":[{"sizes":["400"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type)); + EXPECT_TRUE( + RunAndCompareNoHloPasses(hlo_test, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +constexpr std::array kReductionSupportedDataTypes{ + PRED, S8, S16, S32, S64, F16, F32, F64, BF16}; + +INSTANTIATE_TEST_SUITE_P(ReductionTypeTestSuite, ReductionTypeTest, + ::testing::ValuesIn(kReductionSupportedDataTypes), + TritonSupportTestTypeToString); + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc index 7dc1de19dc85d1..54218d257a7802 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_parametrized_test.cc @@ -79,7 +79,7 @@ class MixedTypeTest : public GpuCodegenTest, } }; -TEST_P(MixedTypeTest, MixedTypeDotProducesCorrectResult) { +TEST_P(MixedTypeTest, DISABLED_MixedTypeDotProducesCorrectResult) { MixTypeParams params = GetParam(); const std::string hlo_string_template = R"( HloModule m @@ -184,7 +184,7 @@ std::string ElementwiseTestParamsToString( using UnaryElementwiseTest = ElementwiseTest; -TEST_P(UnaryElementwiseTest, ElementwiseFusionExecutesCorrectly) { +TEST_P(UnaryElementwiseTest, DISABLED_ElementwiseFusionExecutesCorrectly) { PrimitiveType data_type; HloOpcode opcode; float tolerance; @@ -249,7 +249,7 @@ ENTRY e { /*run_hlo_passes=*/false)); } -TEST_P(UnaryElementwiseTest, ElementwiseUnaryOpExecutesCorrectly) { +TEST_P(UnaryElementwiseTest, DISABLED_ElementwiseUnaryOpExecutesCorrectly) { PrimitiveType data_type; HloOpcode opcode; float tolerance; @@ -360,7 +360,7 @@ INSTANTIATE_TEST_SUITE_P( using BinaryElementwiseTest = ElementwiseTest; -TEST_P(BinaryElementwiseTest, ElementwiseFusionExecutesCorrectly) { +TEST_P(BinaryElementwiseTest, DISABLED_ElementwiseFusionExecutesCorrectly) { PrimitiveType data_type; HloOpcode opcode; float tolerance; @@ -429,7 +429,7 @@ ENTRY e { /*run_hlo_passes=*/false, /*args_max_bits_of_precision=*/6)); } -TEST_P(BinaryElementwiseTest, 
ElementwiseBinaryOpExecutesCorrectly) { +TEST_P(BinaryElementwiseTest, DISABLED_ElementwiseBinaryOpExecutesCorrectly) { PrimitiveType data_type; HloOpcode opcode; float tolerance; @@ -551,7 +551,7 @@ std::string CompareTestParamsToString( "_", ComparisonDirectionToString(direction)); } -TEST_P(CompareTest, CompareFusionExecutesCorrectly) { +TEST_P(CompareTest, DISABLED_CompareFusionExecutesCorrectly) { PrimitiveType data_type; Comparison::Direction direction; std::tie(data_type, direction) = GetParam(); @@ -654,7 +654,7 @@ class SelectTest : public TritonTest, public ::testing::WithParamInterface< std::tuple> {}; -TEST_P(SelectTest, SelectFusionExecutesCorrectly) { +TEST_P(SelectTest, DISABLED_SelectFusionExecutesCorrectly) { PrimitiveType data_type1, data_type2; std::tie(data_type1, data_type2) = GetParam(); for (const PrimitiveType type : {data_type1, data_type2}) { @@ -759,7 +759,7 @@ INSTANTIATE_TEST_SUITE_P( class ConstantTest : public TritonTest, public ::testing::WithParamInterface {}; -TEST_P(ConstantTest, ConstantFusionExecutesCorrectly) { +TEST_P(ConstantTest, DISABLED_ConstantFusionExecutesCorrectly) { const PrimitiveType data_type = GetParam(); if (!legacy_triton::IsTritonSupportedDataType(data_type, GetCudaComputeCapability())) { @@ -861,7 +861,7 @@ class ConvertTest : public TritonTest, public ::testing::WithParamInterface< std::tuple> {}; -TEST_P(ConvertTest, ConvertFusionExecutesCorrectly) { +TEST_P(ConvertTest, DISABLED_ConvertFusionExecutesCorrectly) { PrimitiveType data_type1, data_type2; std::tie(data_type1, data_type2) = GetParam(); for (const PrimitiveType type : {data_type1, data_type2}) { @@ -918,7 +918,7 @@ class TritonNormalizationTest } }; -TEST_P(TritonNormalizationTest, CanFuseAndEmitExactSoftmax) { +TEST_P(TritonNormalizationTest, DISABLED_CanFuseAndEmitExactSoftmax) { PrimitiveType data_type = GetParam(); if (data_type == F16) { @@ -982,7 +982,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonNormalizationTest, CanFuseAndEmitFirstSoftmaxDiamond) { +TEST_P(TritonNormalizationTest, DISABLED_CanFuseAndEmitFirstSoftmaxDiamond) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( HloModule softmax @@ -1037,7 +1037,8 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonNormalizationTest, CanFuseAndEmitSoftmaxDiamondWithSmallRows) { +TEST_P(TritonNormalizationTest, + DISABLED_CanFuseAndEmitSoftmaxDiamondWithSmallRows) { PrimitiveType data_type = GetParam(); constexpr absl::string_view kHloTextTemplate = R"( HloModule softmax @@ -1074,7 +1075,8 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(/*aabs=*/0, /*arel=*/0))); } -TEST_F(TritonNormalizationTest, CanFuseAndEmitDiamondWithBF16Converts) { +TEST_F(TritonNormalizationTest, + DISABLED_CanFuseAndEmitDiamondWithBF16Converts) { const std::string hlo_text = R"( HloModule softmax max_computation { @@ -1110,7 +1112,7 @@ ENTRY main { } TEST_P(TritonNormalizationTest, - CanFuseAndEmitDiamondWithMultipleBroadcastDimensions) { + DISABLED_CanFuseAndEmitDiamondWithMultipleBroadcastDimensions) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1164,7 +1166,7 @@ ENTRY main { } TEST_P(TritonNormalizationTest, - CanFuseAndEmitSoftmaxWithIntermediateUnaryElementwise) { + DISABLED_CanFuseAndEmitSoftmaxWithIntermediateUnaryElementwise) { PrimitiveType data_type = GetParam(); if (data_type == F16) { @@ -1231,7 +1233,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - 
CanFuseAndEmitTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) { + DISABLED_CanFuseAndEmitTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1292,7 +1294,7 @@ ENTRY main { } TEST_P(TritonNormalizationTest, - CanFuseAndEmitDiamondWithTrailingUnaryElementwiseAtTheRoot) { + DISABLED_CanFuseAndEmitDiamondWithTrailingUnaryElementwiseAtTheRoot) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1347,7 +1349,7 @@ ENTRY main { } TEST_P(TritonNormalizationTest, - CanFuseAndEmitDiamondWithUnaryElementwisePrefix) { + DISABLED_CanFuseAndEmitDiamondWithUnaryElementwisePrefix) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1401,8 +1403,9 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonNormalizationTest, - CanFuseAndEmitSoftmaxDiamondWithLastDimensionBitcastAfterReduce) { +TEST_P( + TritonNormalizationTest, + DISABLED_CanFuseAndEmitSoftmaxDiamondWithLastDimensionBitcastAfterReduce) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1458,8 +1461,9 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonNormalizationTest, - CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectly) { +TEST_P( + TritonNormalizationTest, + DISABLED_CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectly) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1512,7 +1516,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1568,7 +1572,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1624,7 +1628,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1691,7 +1695,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1753,7 +1757,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitTwoBinaryElementwiseWhereBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitTwoBinaryElementwiseWhereBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); 
const std::string hlo_text_template = R"( @@ -1815,7 +1819,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonNormalizationTest, DiamondEmitterIsNumericallyStable) { +TEST_P(TritonNormalizationTest, DISABLED_DiamondEmitterIsNumericallyStable) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1849,7 +1853,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(/*aabs=*/0, /*arel=*/0))); } -TEST_P(TritonNormalizationTest, CanFuseAndEmitRMSNormDiamond) { +TEST_P(TritonNormalizationTest, DISABLED_CanFuseAndEmitRMSNormDiamond) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1913,7 +1917,7 @@ ENTRY main.30 { TEST_P( TritonNormalizationTest, - CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1976,7 +1980,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -2039,7 +2043,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -2098,7 +2102,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -2156,7 +2160,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitBinaryElementwiseProducerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitBinaryElementwiseProducerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -2215,7 +2219,7 @@ ENTRY main { TEST_P( TritonNormalizationTest, - CanFuseAndEmitBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducerIntoDiamond) { // NOLINT(whitespace/line_length) + DISABLED_CanFuseAndEmitBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducerIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -2261,7 +2265,8 @@ ENTRY main { INSTANTIATE_TEST_SUITE_P(TritonNormalizationTestSuite, TritonNormalizationTest, ::testing::Values(F32, F16, BF16)); -TEST_F(TritonNormalizationTest, CanFuseAndEmitTritonSoftmaxWithTwoParameters) { +TEST_F(TritonNormalizationTest, + DISABLED_CanFuseAndEmitTritonSoftmaxWithTwoParameters) { const std::string hlo_text = R"( HloModule layernorm @@ -2301,7 +2306,8 @@ ENTRY main { 
ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_F(TritonNormalizationTest, CanFuseAndEmitTritonSoftmaxWithNonBatchReduce) { +TEST_F(TritonNormalizationTest, + DISABLED_CanFuseAndEmitTritonSoftmaxWithNonBatchReduce) { const std::string hlo_text = R"( HloModule layernorm @@ -2345,7 +2351,7 @@ class ReductionTypeTest : public TritonTest, public ::testing::WithParamInterface { }; -TEST_P(ReductionTypeTest, DifferentReductionTypes) { +TEST_P(ReductionTypeTest, DISABLED_DifferentReductionTypes) { PrimitiveType data_type = GetParam(); const std::string kHloTestTemplate = R"( From 4f4774136fdd773668a2655f638193c1b1a1f624 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Wed, 16 Apr 2025 10:17:02 -0700 Subject: [PATCH 0878/1324] Merged `PreemptionSyncManager` and `PreemptionSyncManagerImpl`. PiperOrigin-RevId: 748337125 --- .../preemption/preemption_sync_manager.cc | 63 ++++--------------- .../preemption/preemption_sync_manager.h | 49 ++++++++++++--- 2 files changed, 52 insertions(+), 60 deletions(-) diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc index 6b70f803573653..a6aca40cdc435f 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include -#include "absl/base/thread_annotations.h" #include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" @@ -33,7 +32,6 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" -#include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/preemption/preemption_notifier.h" #include "xla/tsl/lib/monitoring/gauge.h" @@ -46,7 +44,6 @@ namespace { using tensorflow::CoordinatedTask; using tensorflow::KeyValueEntry; -constexpr int64_t kPreemptionSyncUnsetCounter = -1; constexpr char kPreemptionNoticeKey[] = "RECEIVED_PREEMPTION_NOTICE"; constexpr char kPreemptionCounterDirKey[] = "PREEMPTION_CURRENT_COUNTER/"; constexpr char kPreemptionBarrier[] = "PREEMPTION_SYNC_BARRIER"; @@ -75,49 +72,14 @@ auto* reached_sync_point_metric = monitoring::Gauge::New( // accommodate higher checkpoint durations. constexpr absl::Duration kProtocolDuration = absl::Minutes(15); -class PreemptionSyncManagerImpl : public PreemptionSyncManager { - public: - PreemptionSyncManagerImpl() = default; - ~PreemptionSyncManagerImpl() override { - shutdown_.Notify(); - } - absl::Status Initialize(CoordinationServiceAgent* agent) override; - absl::Status Initialize(CoordinationServiceAgent* agent, - const std::string& preemption_notifier_type) override; - absl::Status Initialize( - CoordinationServiceAgent* agent, - std::unique_ptr notifier) override; - bool ReachedSyncPoint(int step_counter) override; - - private: - // Determine the sync point upon receipt of preemption notice (death time). - void ComputeSyncCallCounter(absl::Time death_time); - // Notify other tasks to not wait at the barrier if the sync protocol failed - // midway. - void CancelPreemptionBarrier(); - - absl::Mutex mu_; - // Tracks the last step_counter passed into ReachedSyncPoint(); - int64_t call_counter_ ABSL_GUARDED_BY(mu_) = 0; - // If set, determines the sync point. 
- int64_t preemption_sync_counter_ ABSL_GUARDED_BY(mu_) = - kPreemptionSyncUnsetCounter; - std::string current_call_counter_key_; - - Env* env_; // Not owned; - CoordinationServiceAgent* agent_; // Not owned. - absl::Notification shutdown_; - std::unique_ptr sync_protocol_thread_; - std::unique_ptr preemption_notifier_; - std::shared_ptr call_opts_; -}; - -absl::Status PreemptionSyncManagerImpl::Initialize( +} // namespace + +absl::Status PreemptionSyncManager::Initialize( CoordinationServiceAgent* agent) { return Initialize(agent, "sigterm"); } -absl::Status PreemptionSyncManagerImpl::Initialize( +absl::Status PreemptionSyncManager::Initialize( CoordinationServiceAgent* agent, const std::string& preemption_notifier_type) { TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); @@ -125,7 +87,7 @@ absl::Status PreemptionSyncManagerImpl::Initialize( preemption_notifier_type, env)); } -absl::Status PreemptionSyncManagerImpl::Initialize( +absl::Status PreemptionSyncManager::Initialize( CoordinationServiceAgent* agent, std::unique_ptr notifier) { TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); @@ -176,7 +138,8 @@ absl::Status PreemptionSyncManagerImpl::Initialize( LOG(INFO) << "Cancelled call to retrieve preemption notice. This is " "expected upon program shutdown."; return; - } else if (!status_or_death_time.ok()) { + } + if (!status_or_death_time.ok()) { LOG(WARNING) << "Failed to retrieve preemption notice from " "coordination service: " @@ -211,14 +174,14 @@ absl::Status PreemptionSyncManagerImpl::Initialize( // Trigger protocol in a separate thread: compute max call counter. sync_protocol_thread_ = absl::WrapUnique(env_->StartThread( {}, "PreemptionSyncManager_SyncProtocol", - std::bind(&PreemptionSyncManagerImpl::ComputeSyncCallCounter, this, + std::bind(&PreemptionSyncManager::ComputeSyncCallCounter, this, death_time))); }); return absl::OkStatus(); } -void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { +void PreemptionSyncManager::ComputeSyncCallCounter(absl::Time death_time) { // 1. If death time is in the distant future, sleep until there's // `kProtocolDuration` left until death time before we begin the protocol. const absl::Duration remaining_time = death_time - absl::Now(); @@ -296,7 +259,7 @@ void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { set_sync_point_metric->GetCell()->Set(true); } -void PreemptionSyncManagerImpl::CancelPreemptionBarrier() { +void PreemptionSyncManager::CancelPreemptionBarrier() { agent_->CancelBarrierAsync( kPreemptionBarrier, [](const absl::Status& status) { if (!status.ok()) { @@ -305,7 +268,7 @@ void PreemptionSyncManagerImpl::CancelPreemptionBarrier() { }); } -bool PreemptionSyncManagerImpl::ReachedSyncPoint(int step_counter) { +bool PreemptionSyncManager::ReachedSyncPoint(int step_counter) { // Record that this API was called at least once. 
 sync_usage_metric->GetCell()->Set(true);
 // Note: if a preemption notice has been received and ComputeSyncCallCounter()
@@ -325,8 +288,8 @@ bool PreemptionSyncManagerImpl::ReachedSyncPoint(int step_counter) {
   }
   return reached_sync_point;
 }
-}  // namespace
+
 std::unique_ptr<PreemptionSyncManager> CreatePreemptionSyncManager() {
-  return std::make_unique<PreemptionSyncManagerImpl>();
+  return std::make_unique<PreemptionSyncManager>();
 }
 }  // namespace tsl
diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h
index 5d36540a7898fe..6a72b1256588ba 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h
+++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h
@@ -15,12 +15,19 @@ limitations under the License.
 #ifndef XLA_TSL_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_SYNC_MANAGER_H_
 #define XLA_TSL_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_SYNC_MANAGER_H_
 
+#include <cstdint>
 #include <memory>
 #include <string>
 
+#include "absl/base/thread_annotations.h"
 #include "absl/status/status.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/synchronization/notification.h"
+#include "absl/time/time.h"
+#include "xla/tsl/distributed_runtime/call_options.h"
 #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h"
 #include "xla/tsl/distributed_runtime/preemption/preemption_notifier.h"
+#include "xla/tsl/platform/env.h"
 
 namespace tsl {
 
@@ -33,15 +40,13 @@ namespace tsl {
 // TODO(b/230630494): Add Reset() to allow multiple sync points to be set.
 class PreemptionSyncManager {
  public:
-  virtual ~PreemptionSyncManager() = default;
-
-  virtual absl::Status Initialize(CoordinationServiceAgent* agent) = 0;
-  virtual absl::Status Initialize(
-      CoordinationServiceAgent* agent,
-      const std::string& preemption_notifier_type) = 0;
-  virtual absl::Status Initialize(
-      CoordinationServiceAgent* agent,
-      std::unique_ptr<PreemptionNotifier> notifier) = 0;
+  PreemptionSyncManager() = default;
+  ~PreemptionSyncManager() { shutdown_.Notify(); }
+  absl::Status Initialize(CoordinationServiceAgent* agent);
+  absl::Status Initialize(CoordinationServiceAgent* agent,
+                          const std::string& preemption_notifier_type);
+  absl::Status Initialize(CoordinationServiceAgent* agent,
+                          std::unique_ptr<PreemptionNotifier> notifier);
 
   // Check if the synchronized point has been reached. When a task has been
   // preempted, a safe sync point will be determined by using the fastest task's
@@ -57,7 +62,31 @@ class PreemptionSyncManager {
   // task. Once a preemption notice is received, all tasks will agree on a safe
   // step to pause training and handle the preemption (e.g. save checkpoint and
   // exit, or wait for preempted task to restart, then resume training).
-  virtual bool ReachedSyncPoint(int step_counter) = 0;
+  bool ReachedSyncPoint(int step_counter);
+
+ private:
+  static constexpr int64_t kPreemptionSyncUnsetCounter = -1;
+
+  // Determine the sync point upon receipt of preemption notice (death time).
+  void ComputeSyncCallCounter(absl::Time death_time);
+  // Notify other tasks to not wait at the barrier if the sync protocol failed
+  // midway.
+  void CancelPreemptionBarrier();
+
+  absl::Mutex mu_;
+  // Tracks the last step_counter passed into ReachedSyncPoint().
+  int64_t call_counter_ ABSL_GUARDED_BY(mu_) = 0;
+  // If set, determines the sync point.
+  int64_t preemption_sync_counter_ ABSL_GUARDED_BY(mu_) =
+      kPreemptionSyncUnsetCounter;
+  std::string current_call_counter_key_;
+
+  Env* env_;                         // Not owned.
+  CoordinationServiceAgent* agent_;  // Not owned.
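+  // Notified in the destructor so background work (e.g. the sync protocol
+  // thread) can stop waiting and exit early.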
+  absl::Notification shutdown_;
+  std::unique_ptr<Thread> sync_protocol_thread_;
+  std::unique_ptr<PreemptionNotifier> preemption_notifier_;
+  std::shared_ptr<CallOptions> call_opts_;
 };
 
 std::unique_ptr<PreemptionSyncManager> CreatePreemptionSyncManager();

From 3e2e371ad51e09ffb3687c5a69f4d53557f722b4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 16 Apr 2025 11:00:06 -0700
Subject: [PATCH 0879/1324] Update TFRT dependency to use revision
 http://github.com/tensorflow/runtime/commit/f16b1e19ccc0466f202630e3fd17832d097edceb.

PiperOrigin-RevId: 748351885
---
 third_party/tf_runtime/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl
index fb81926386c0dd..2d298bec758a78 100644
--- a/third_party/tf_runtime/workspace.bzl
+++ b/third_party/tf_runtime/workspace.bzl
@@ -6,8 +6,8 @@ def repo():
     """Imports TFRT."""
 
     # Attention: tools parse and update these lines.
-    TFRT_COMMIT = "7c9a510bdfd98a39f2639e7446a3119fdd3ccbfe"
-    TFRT_SHA256 = "67467e2e0914859e657d384917cb27bd362e7ebb007a21476a5d271ddd540001"
+    TFRT_COMMIT = "f16b1e19ccc0466f202630e3fd17832d097edceb"
+    TFRT_SHA256 = "c111345265c0c924e060100610fa2721f8f3996fc74bca3bc2ae0241b603f941"
 
     tf_http_archive(
         name = "tf_runtime",

From 969134bc86fa14518d9fe0d889d8d85846d8566f Mon Sep 17 00:00:00 2001
From: Laura Pak
Date: Wed, 16 Apr 2025 11:24:11 -0700
Subject: [PATCH 0880/1324] Make tf_quantization_lib rely on non-lite
 QuantizeUtils.h

PiperOrigin-RevId: 748362018
---
 .../mlir/quantization/common/tf_quantization_lib/BUILD    | 1 -
 .../common/tf_quantization_lib/tf_quantization_utils.cc   | 4 ++--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD
index f42c3ca6446c42..607faaab78f3b4 100644
--- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD
+++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD
@@ -30,7 +30,6 @@ cc_library(
     deps = [
         ":tf_quantization_config",
         ":tf_quantization_interfaces_inc_gen",
-        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
         "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:portable_tensor_utils",
         "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps",
         "//tensorflow/compiler/mlir/tools/optimize:quantization_utils",
diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc
index 80a6a5c9c9b442..2beccf116125d9 100644
--- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc
+++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc
@@ -45,10 +45,10 @@ limitations under the License.
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h"
 #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h"
 #include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h"
 #include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h"
+#include "tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.h"
 #include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h"
 #include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h"
 #include "tensorflow/compiler/mlir/tools/optimize/quantization_utils.h"
@@ -712,7 +712,7 @@ ElementsAttr Quantize(const Attribute real_value, const Type tensor_type) {
           quant::QuantizedType::getQuantizedElementType(tensor_type)) {
     Type converted_type;
     return dyn_cast_or_null<ElementsAttr>(
-        quantfork::quantizeAttr(real_value, q_type, converted_type));
+        mlir::quant::ir::quantizeAttr(real_value, q_type, converted_type));
   }
   return {};
 }

From df81eed13e588fd0d9e4bd5b909d8deb26bba0f9 Mon Sep 17 00:00:00 2001
From: David Dunleavy
Date: Wed, 16 Apr 2025 11:58:03 -0700
Subject: [PATCH 0881/1324] Remove `autorun_ci.{py,yml}`; it is no longer
 relevant now that we use GitHub Actions for all presubmit testing.

PiperOrigin-RevId: 748374097
---
 .../xla/.github/workflows/autorun_ci.py  | 43 ------------------
 .../xla/.github/workflows/autorun_ci.yml | 38 ----------------
 2 files changed, 81 deletions(-)
 delete mode 100644 third_party/xla/.github/workflows/autorun_ci.py
 delete mode 100644 third_party/xla/.github/workflows/autorun_ci.yml

diff --git a/third_party/xla/.github/workflows/autorun_ci.py b/third_party/xla/.github/workflows/autorun_ci.py
deleted file mode 100644
index 8221fdcd90cfb5..00000000000000
--- a/third_party/xla/.github/workflows/autorun_ci.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright 2024 The OpenXLA Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""Autoruns CI for OpenXLA org members with membership set to public."""
-import logging
-import os
-
-import github_api
-
-_OPENXLA_ORG_ID = 107584881  # https://api.github.com/orgs/107584881
-
-
-def main():
-  username = os.getenv("PR_AUTHOR_USERNAME")
-  pr_number = os.getenv("PR_NUMBER")
-  api = github_api.GitHubAPI(os.getenv("GH_TOKEN"))
-
-  orgs = api.get_user_orgs(username)
-  logging.info("Found public organizations for user %s: %s", username, orgs)
-
-  if _OPENXLA_ORG_ID in {org["id"] for org in orgs}:
-    logging.info(
-        "Found OpenXLA org in public memberships, so adding kokoro:force-run"
-        " label."
- ) - api.add_issue_labels("openxla/xla", pr_number, ["kokoro:force-run"]) - - -if __name__ == "__main__": - logging.basicConfig() - logging.getLogger().setLevel(logging.INFO) - main() diff --git a/third_party/xla/.github/workflows/autorun_ci.yml b/third_party/xla/.github/workflows/autorun_ci.yml deleted file mode 100644 index 92ebd74e75797f..00000000000000 --- a/third_party/xla/.github/workflows/autorun_ci.yml +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 The OpenXLA Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -name: Autorun CI for OpenXLA Public Members -permissions: - pull-requests: write -on: - pull_request_target: - branches: ["main"] - -jobs: - autorun-ci: - runs-on: ubuntu-22.04 - defaults: - run: - shell: bash - env: - GH_TOKEN: ${{ github.token }} - PR_NUMBER: ${{ github.event.number }} - PR_AUTHOR_USERNAME: ${{ github.event.pull_request.user.login }} - timeout-minutes: 6 - if: github.event.sender.type == 'User' - steps: - - name: "Checking out repository" - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: "Autorun CI for public OpenXLA org members" - run: python3 .github/workflows/autorun_ci.py From db349f81847ef39921b1aa978b6b5f0f6dd1995d Mon Sep 17 00:00:00 2001 From: Xuefei Jiang Date: Wed, 16 Apr 2025 12:23:19 -0700 Subject: [PATCH 0882/1324] PR #25288: [ROCm] fix kernel tiling test Imported from GitHub PR https://github.com/openxla/xla/pull/25288 In ROCm, the limit should be 65536. 
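For reference, a minimal sketch of the launch-bounds check whose error text the
test asserts (the name, signature, and status code here are illustrative
assumptions, not the real XLA helper; only the message format and the
per-platform y-limit -- 65535 blocks on CUDA, 65536 on ROCm -- come from the
test expectation below):

  // Needs absl/status/status.h and absl/strings/str_format.h.
  absl::Status CheckLaunchBounds(const std::string& kernel_name,
                                 int64_t blocks_x, int64_t blocks_y,
                                 int64_t limit_x, int64_t limit_y) {
    // Reject launches that need more blocks than the device allows.
    if (blocks_x > limit_x || blocks_y > limit_y) {
      return absl::UnimplementedError(absl::StrFormat(
          "Kernel '%s' launch needs more blocks (%d, %d) than allowed by "
          "hardware (%d, %d)",
          kernel_name, blocks_x, blocks_y, limit_x, limit_y));
    }
    return absl::OkStatus();
  }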
Copybara import of the project: -- 8d34eacdda3ad6eb36bdd899aff519dbf04ab972 by scxfjiang : fix kernel tiling Merging this change closes #25288 PiperOrigin-RevId: 748382202 --- third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index f6b2cb0b4fe60b..f1d43c8deefc6f 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -404,7 +404,7 @@ TEST_F(GpuKernelTilingTest, ReductionInputTooLarge) { EXPECT_THAT(status.message(), ::testing::ContainsRegex( "Kernel '.*' launch needs more blocks [(]2147483648, 1[)] " - "than allowed by hardware [(]2147483647, 65535[)]")); + "than allowed by hardware [(]2147483647, 65536[)]")); } else { EXPECT_THAT(status.message(), ::testing::ContainsRegex( From 250cb35fa23c3111e406f88100e6e1aefda27e04 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 16 Apr 2025 12:29:51 -0700 Subject: [PATCH 0883/1324] [xla:cpu] Disable SLP vectorizer PiperOrigin-RevId: 748384387 --- third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc b/third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc index c8be164e0f6e4b..d80f599710d6fe 100644 --- a/third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc +++ b/third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc @@ -123,6 +123,10 @@ static llvm::PipelineTuningOptions GetPipelineTuningOptions( pto.SLPVectorization = !opts.optimize_for_size && !opts.disable_slp_vectorizer; pto.LoopUnrolling = !opts.disable_loop_unrolling; + + // TODO(b/411125413): Re-enable SLPVectorization once the LLVM bug is fixed. + pto.SLPVectorization = false; + return pto; }; From da03edfe213b94aa346d3b3a508c1ebe93b77d5c Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Wed, 16 Apr 2025 13:21:38 -0700 Subject: [PATCH 0884/1324] Port gpu custom_call_test test to HloTestBase. 
PiperOrigin-RevId: 748401363 --- third_party/xla/xla/service/gpu/BUILD | 6 ++- .../xla/xla/service/gpu/custom_call_test.cc | 47 +++++++++---------- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 51e09b0abcb1f4..21803f2584a5ad 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -163,16 +163,18 @@ xla_test( "//xla/stream_executor:scratch_allocator", "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_types_header", - "//xla/tests:client_library_test_base", + "//xla/tests:client_library_test_runner_mixin", + "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ] + if_cuda_is_configured([ diff --git a/third_party/xla/xla/service/gpu/custom_call_test.cc b/third_party/xla/xla/service/gpu/custom_call_test.cc index a215ed0b578da3..0c8c83910565ca 100644 --- a/third_party/xla/xla/service/gpu/custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/custom_call_test.cc @@ -31,12 +31,11 @@ limitations under the License. #define PLATFORM "ROCM" #endif -#include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/ffi/execution_context.h" #include "xla/ffi/ffi.h" @@ -46,19 +45,19 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/testlib/test_helpers.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" -#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" #if GOOGLE_CUDA #define gpuSuccess cudaSuccess @@ -93,7 +92,7 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(::xla::Range, StructMember("lo"), namespace xla { namespace { -class CustomCallTest : public ClientLibraryTestBase {}; +using CustomCallTest = ClientLibraryTestRunnerMixin; bool is_invoked_called = false; void Callback_IsInvoked(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/, @@ -108,7 +107,7 @@ TEST_F(CustomCallTest, IsInvoked) { ShapeUtil::MakeShape(F32, {}), /*opaque=*/""); EXPECT_FALSE(is_invoked_called); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); EXPECT_TRUE(is_invoked_called); } @@ -117,7 +116,7 @@ TEST_F(CustomCallTest, UnknownTarget) { CustomCall(&b, "UnknownTarget", /*operands=*/{}, ShapeUtil::MakeShape(F32, {}), /*opaque=*/""); - ASSERT_FALSE(Execute(&b, {}).ok()); + ASSERT_FALSE(ExecuteAndTransfer(&b, {}).ok()); } void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers, const char* /*opaque*/, size_t /*opaque_len*/) { @@ -149,7 +148,7 @@ TEST_F(CustomCallTest, Opaque) { XlaBuilder b(TestName()); CustomCall(&b, "Callback_Opaque", /*operands=*/{}, ShapeUtil::MakeShape(F32, {}), kExpectedOpaque); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); } void Callback_SubBuffers(se::gpu::GpuStreamHandle stream, void** buffers, @@ -260,7 +259,7 @@ std::vector GetTokenTestCases() { class CustomCallTokensTest : public ::testing::WithParamInterface, - public ClientLibraryTestBase { + public ClientLibraryTestRunnerMixin { public: static std::vector BuildInputs(XlaBuilder& b, std::istringstream& str) { @@ -315,7 +314,7 @@ TEST_P(CustomCallTokensTest, TokensTest) { CustomCall(&b, "Callback_Tokens", call_inputs, call_output.front(), tc.opaque); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); } INSTANTIATE_TEST_CASE_P(CustomCallTokens, CustomCallTokensTest, @@ -338,7 +337,7 @@ TEST_F(CustomCallTest, WithStatusSucceeded) { /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_STATUS_RETURNING); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); } void Callback_WithStatusFailed(se::gpu::GpuStreamHandle /*stream*/, @@ -358,7 +357,7 @@ TEST_F(CustomCallTest, WithStatusFailed) { /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_STATUS_RETURNING); - auto status = Execute(&b, {}).status(); + auto status = ExecuteAndTransfer(&b, {}).status(); EXPECT_EQ(status.code(), 
absl::StatusCode::kInternal); EXPECT_THAT(status.message(), ::testing::HasSubstr("Failed")); } @@ -387,7 +386,7 @@ TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) { /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); - auto status = Execute(&b, {}).status(); + auto status = ExecuteAndTransfer(&b, {}).status(); EXPECT_EQ(status.code(), absl::StatusCode::kInternal); EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, wrong value: 42")); } @@ -404,7 +403,7 @@ TEST_F(CustomCallTest, PassAttributesByBackendConfig) { /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); - auto status = Execute(&b, {}).status(); + auto status = ExecuteAndTransfer(&b, {}).status(); EXPECT_EQ(status.code(), absl::StatusCode::kInternal); EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, wrong value: 42")); } @@ -464,7 +463,7 @@ TEST_F(CustomCallTest, PassUserPointerWithAttrs) { /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); - auto status = Execute(&b, {}).status(); + auto status = ExecuteAndTransfer(&b, {}).status(); EXPECT_EQ(status.code(), absl::StatusCode::kInternal); EXPECT_THAT(status.message(), ::testing::HasSubstr("User-defined message")); } @@ -502,7 +501,7 @@ TEST_F(CustomCallTest, ExportedFfiUnknownTarget) { /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); - auto status = Execute(&b, {}).status(); + auto status = ExecuteAndTransfer(&b, {}).status(); EXPECT_EQ(status.code(), absl::StatusCode::kUnimplemented); EXPECT_THAT(status.message(), ::testing::HasSubstr("No registered implementation")); @@ -541,7 +540,7 @@ TEST_F(CustomCallTest, ExportedFfiOpaque) { /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); } static absl::Status CheckTokens(std::vector args, @@ -612,7 +611,7 @@ TEST_P(CustomCallTokensTest, ExportedTokensTest) { /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); } INSTANTIATE_TEST_SUITE_P(CustomCallTokensTest, CustomCallTokensTest, @@ -636,7 +635,7 @@ TEST_F(CustomCallTest, ExportedFfiWithStatusSucceeded) { /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); } //===----------------------------------------------------------------------===// @@ -679,7 +678,7 @@ TEST_F(CustomCallTest, FfiAttributes) { /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); } //===----------------------------------------------------------------------===// @@ -816,7 +815,7 @@ TEST_F(CustomCallTest, 
FfiExecutionContext) { ffi::internal::ScopedExecutionContext scoped_execution_context( &execution_context); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); // Check that FFI handler was called during initialization and execution. TF_ASSERT_OK_AND_ASSIGN(auto* user_context, @@ -876,7 +875,7 @@ TEST_F(CustomCallTest, FfiExecutionState) { /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); - TF_ASSERT_OK(Execute(&b, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&b, {}).status()); } } // anonymous namespace From f6514a925f16d3196ecf3c8a477adca24ec7eb5a Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Wed, 16 Apr 2025 13:21:41 -0700 Subject: [PATCH 0885/1324] Port Client tests to HloTestBase where possible. PiperOrigin-RevId: 748401382 --- third_party/xla/xla/tests/BUILD | 546 +++++++++--------- third_party/xla/xla/tests/axpy_simple_test.cc | 15 +- .../tests/bad_rng_shape_validation_test.cc | 13 +- third_party/xla/xla/tests/bfloat16_test.cc | 47 +- .../xla/xla/tests/binop_scaling_test.cc | 13 +- .../xla/xla/tests/bitcast_convert_test.cc | 24 +- .../xla/xla/tests/broadcast_simple_test.cc | 242 ++++---- third_party/xla/xla/tests/call_test.cc | 27 +- .../xla/tests/check_execution_arity_test.cc | 52 +- .../xla/xla/tests/complex_unary_op_test.cc | 24 +- third_party/xla/xla/tests/conditional_test.cc | 145 ++--- third_party/xla/xla/tests/constants_test.cc | 52 +- .../convolution_dimension_numbers_test.cc | 27 +- .../xla/xla/tests/convolution_test_1d.cc | 191 ++---- .../xla/tests/convolution_variants_test.cc | 282 +++++---- .../xla/xla/tests/cpu_gpu_fusion_test.cc | 120 ++-- third_party/xla/xla/tests/dynamic_ops_test.cc | 236 ++++---- third_party/xla/xla/tests/float8_test.cc | 10 +- third_party/xla/xla/tests/floor_ceil_test.cc | 15 +- third_party/xla/xla/tests/fmax_fmin_test.cc | 29 +- third_party/xla/xla/tests/half_test.cc | 37 +- third_party/xla/xla/tests/iota_test.cc | 15 +- third_party/xla/xla/tests/log_test.cc | 13 +- third_party/xla/xla/tests/map_test.cc | 126 ++-- .../xla/tests/multidimensional_slice_test.cc | 15 +- third_party/xla/xla/tests/params_test.cc | 181 +++--- third_party/xla/xla/tests/pred_test.cc | 15 +- .../xla/tests/query_inferred_shape_test.cc | 22 +- .../xla/xla/tests/reduce_precision_test.cc | 25 +- third_party/xla/xla/tests/reshape_test.cc | 456 +++++++-------- third_party/xla/xla/tests/reverse_test.cc | 35 +- .../xla/xla/tests/scalar_computations_test.cc | 138 ++--- .../xla/xla/tests/select_and_scatter_test.cc | 45 +- third_party/xla/xla/tests/select_test.cc | 56 +- third_party/xla/xla/tests/slice_test.cc | 38 +- third_party/xla/xla/tests/transpose_test.cc | 88 ++- .../xla/xla/tests/triangular_solve_test.cc | 118 ++-- third_party/xla/xla/tests/tuple_test.cc | 132 ++--- third_party/xla/xla/tests/unary_op_test.cc | 37 +- .../xla/xla/tests/vector_ops_reduce_test.cc | 40 +- .../xla/xla/tests/vector_ops_simple_test.cc | 91 ++- 41 files changed, 1796 insertions(+), 2037 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 536790d796894b..7857f256b1cf19 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -487,15 +487,14 @@ xla_test( srcs = ["bad_rng_shape_validation_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", ":xla_internal_test_main", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", 
"//xla/hlo/testlib:test", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:logging", ], ) @@ -605,18 +604,22 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_xla_cpu_no_thunks", + ], deps = [ - ":client_library_test_base", - ":test_macros_header", - ":xla_internal_test_main", + ":client_library_test_runner_mixin", + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", "//xla/hlo/testlib:test_helpers", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", ], ) @@ -626,15 +629,12 @@ xla_test( srcs = ["query_inferred_shape_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", - "//xla/hlo/testlib:test_helpers", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", ], ) @@ -677,13 +677,15 @@ xla_test( srcs = ["axpy_simple_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", ":xla_internal_test_main", - "//xla/client:local_client", + "//xla:error_spec", + "//xla:shape_util", "//xla/hlo/builder:xla_builder", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", ], ) @@ -692,24 +694,27 @@ xla_test( srcs = ["map_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", - ":test_utils", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/testlib:test", "//xla/hlo/testlib:test_helpers", - "//xla/stream_executor:stream_executor_h", + "//xla/service", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -723,21 +728,22 @@ xla_test( "test_xla_cpu_no_thunks", ], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/service", + "//xla/tsl/platform:test", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:test", ], ) @@ -746,14 +752,13 @@ xla_test( srcs = ["pred_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = 
[ - ":client_library_test_base", - ":xla_internal_test_main", - "//xla:array2d", - "//xla/client:local_client", + ":client_library_test_runner_mixin", + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:arithmetic", - "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", + "@com_google_absl//absl/types:span", ], ) @@ -762,14 +767,16 @@ xla_test( srcs = ["select_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep + "//xla:error_spec", + "//xla:literal", "//xla:types", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", - "@local_tsl//tsl/platform:test", + "//xla/service", + "//xla/tsl/platform:test", ], ) @@ -779,19 +786,23 @@ xla_test( shard_count = 2, tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test_helpers", "//xla/tsl/platform:env", "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", ], ) @@ -800,14 +811,16 @@ xla_test( srcs = ["unary_op_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", - ":test_macros_header", - ":xla_internal_test_main", + ":client_library_test_runner_mixin", + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:types", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", ], ) @@ -823,15 +836,17 @@ xla_test( ], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:math", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", ], ) @@ -841,22 +856,23 @@ xla_test( shard_count = 32, tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":literal_test_util", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep + "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla:status_macros", + "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/testlib:test_helpers", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "//xla/service", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/types:span", - 
"@local_tsl//tsl/platform:test", ], ) @@ -957,20 +973,18 @@ xla_test( srcs = ["reduce_precision_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", - "//xla:array2d", + ":xla_internal_test_main", # fixdeps: keep "//xla:literal", - "//xla:shape_util", + "//xla:literal_util", "//xla:types", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", + "//xla/tsl/platform:test", "@com_google_absl//absl/base", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1292,18 +1306,19 @@ xla_test( srcs = ["transpose_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", "//xla:literal_util", "//xla:reference_util", "//xla:util", "//xla/hlo/builder:xla_builder", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", ], ) @@ -1312,24 +1327,24 @@ xla_test( srcs = ["constants_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", ":literal_test_util", ":test_macros_header", - ":test_utils", ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", "//xla:array4d", + "//xla:error_spec", + "//xla:literal", "//xla:literal_util", "//xla:types", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:constants", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:test", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:test", ], ) @@ -1350,7 +1365,7 @@ xla_test( ":hlo_pjrt_interpreter_reference_mixin", ":hlo_pjrt_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", "//xla:array4d", "//xla:error_spec", @@ -1391,32 +1406,18 @@ xla_test( ":client_library_test_base", ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", - "//xla:array2d", "//xla:array3d", - "//xla:array4d", "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla:reference_util", "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:window_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", - "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cuda_platform_id", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "//xla/tsl/platform:test", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", ], ) @@ -1435,36 +1436,20 @@ xla_test( "test_xla_cpu_no_thunks", ], deps = [ - ":client_library_test_base", ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", - ":xla_internal_test_main", - "//xla:array2d", + ":xla_internal_test_main", # fixdeps: keep "//xla:array3d", - "//xla:array4d", "//xla:error_spec", "//xla:literal", "//xla:literal_util", 
- "//xla:reference_util", "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:window_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", - "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cuda_platform_id", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", + "//xla/tsl/platform:test", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", ], ) @@ -1520,35 +1505,20 @@ xla_test( "optonly", ], deps = [ - ":client_library_test_base", ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", - ":xla_internal_test_main", - "//xla:array2d", + ":xla_internal_test_main", # fixdeps: keep "//xla:array3d", - "//xla:array4d", "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla:reference_util", "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:window_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", - "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cuda_platform_id", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "//xla/tsl/platform:test", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", ], ) @@ -1594,35 +1564,20 @@ xla_test( shard_count = 25, tags = ["cuda-only"], deps = [ - ":client_library_test_base", ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", - ":xla_internal_test_main", - "//xla:array2d", + ":xla_internal_test_main", # fixdeps: keep "//xla:array3d", - "//xla:array4d", "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla:reference_util", "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:window_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", - "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cuda_platform_id", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "//xla/tsl/platform:test", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", ], ) @@ -1637,19 +1592,21 @@ xla_test( shard_count = 30, tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array3d", "//xla:array4d", + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:reference_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", + "@com_google_absl//absl/types:span", ], ) @@ -1660,16 +1617,21 @@ xla_test( shard_count = 20, tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array4d", + "//xla:error_spec", + "//xla:literal", + 
"//xla:literal_util", "//xla:reference_util", - "//xla/client:local_client", + "//xla:shape_util", "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", ], ) @@ -1741,26 +1703,16 @@ xla_test( shard_count = 40, tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":hlo_test_base", - ":literal_test_util", - ":test_macros_header", - ":test_utils", - ":xla_internal_test_main", - "//xla:array2d", - "//xla:array4d", - "//xla:literal", - "//xla:reference_util", - "//xla:shape_util", - "//xla:util", - "//xla/client:local_client", + ":client_library_test_runner_mixin", + ":hlo_test_base", + ":test_macros_header", + ":xla_internal_test_main", + "//xla:array4d", + "//xla:error_spec", + "//xla:literal_util", + "//xla:types", "//xla/hlo/builder:xla_builder", - "//xla/hlo/builder/lib:arithmetic", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:test", - "//xla/hlo/testlib:test_helpers", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", ], ) @@ -1769,10 +1721,12 @@ xla_test( srcs = ["float8_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":xla_internal_test_main", + ":client_library_test_runner_mixin", + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", + "//xla/tsl/platform:test", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:ml_dtypes", ], @@ -1787,15 +1741,17 @@ xla_test( ], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":test_utils", - ":xla_internal_test_main", - "//xla:literal", + ":xla_internal_test_main", # fixdeps: keep + "//xla:error_spec", + "//xla:types", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", - "//xla/hlo/testlib:test_helpers", - "@com_google_absl//absl/status:statusor", + "//xla/tsl/platform:test", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", ], ) @@ -1829,19 +1785,23 @@ xla_test( shard_count = 40, tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", + "//xla:array3d", + "//xla:array4d", + "//xla:error_spec", + "//xla:literal_util", "//xla:reference_util", - "//xla/client:local_client", + "//xla:shape_util", "//xla/hlo/builder:xla_builder", + "//xla/tsl/platform:test", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:test", ], ) @@ -1850,15 +1810,15 @@ xla_test( srcs = ["multidimensional_slice_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", "//xla:array3d", - "//xla/client:local_client", + "//xla:error_spec", "//xla/hlo/builder:xla_builder", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", ], ) @@ -1871,27 +1831,36 @@ xla_test( "test_xla_cpu_no_thunks", ] + 
if_oss(["not_run:arm"]), deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", - "//xla:reference_util", + "//xla:array3d", + "//xla:error_spec", + "//xla:executable_run_options", + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", "//xla/client:client_library", + "//xla/client:executable_build_options", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test_helpers", + "//xla/service", "//xla/service:computation_placer", - "//xla/service:local_service", "//xla/service:platform_util", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", ], ) @@ -1900,23 +1869,25 @@ xla_test( srcs = ["tuple_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", ":literal_test_util", - ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", + "//xla:error_spec", "//xla:literal_util", "//xla:shape_util", + "//xla:types", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:test_helpers", - "//xla/tsl/lib/core:status_test_util", + "//xla/service", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:test", + "@com_google_absl//absl/types:span", ], ) @@ -1925,17 +1896,17 @@ xla_test( srcs = ["vector_ops_reduce_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", "//xla:array3d", + "//xla:error_spec", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:arithmetic", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", ], ) @@ -2058,9 +2029,10 @@ xla_test( "test_xla_cpu_no_thunks", ], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array", "//xla:array2d", "//xla:array4d", @@ -2071,7 +2043,7 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:arithmetic", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", ], ) @@ -2205,10 +2177,11 @@ xla_test( srcs = ["call_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # 
fixdeps: keep + "//xla:error_spec", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -2216,7 +2189,9 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/testlib:test_helpers", - "@local_tsl//tsl/platform:test", + "//xla/service", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", ], ) @@ -2273,16 +2248,15 @@ xla_test( srcs = ["binop_scaling_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", - ":test_macros_header", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":xla_internal_test_main", "//xla:array2d", "//xla:array4d", + "//xla:error_spec", "//xla:reference_util", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", ], ) @@ -2291,19 +2265,23 @@ xla_test( srcs = ["broadcast_simple_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", - "//xla:array4d", + "//xla:array3d", + "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla/client:local_client", + "//xla:shape_util", "//xla/hlo/builder:xla_builder", + "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test", - "@com_google_absl//absl/status:statusor", + "//xla/tsl/platform:test", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) @@ -2336,12 +2314,15 @@ xla_test( srcs = ["fmax_fmin_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", + "//xla:literal", "//xla/hlo/builder:xla_builder", "//xla/service", + "//xla/tsl/platform:test", ], ) @@ -2350,13 +2331,14 @@ xla_test( srcs = ["log_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", - "//xla/client:local_client", + ":xla_internal_test_main", # fixdeps: keep + "//xla:array3d", + "//xla:error_spec", "//xla/hlo/builder:xla_builder", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:test", ], ) @@ -2445,11 +2427,10 @@ xla_test( shard_count = 30, tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", ":literal_test_util", - ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", "//xla:array3d", "//xla:array4d", @@ -2460,16 +2441,15 @@ xla_test( "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/testlib:test", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", ], ) @@ -2496,9 +2476,10 @@ xla_test( srcs = ["reverse_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":test_macros_header", - 
":xla_internal_test_main", + ":client_library_test_runner_mixin", + ":client_library_test_runner_utils", + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep "//xla:array4d", "//xla:error_spec", "//xla:literal", @@ -2506,10 +2487,11 @@ xla_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:test", ], ) @@ -2537,20 +2519,19 @@ xla_test( srcs = ["vector_ops_simple_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", - ":literal_test_util", - ":test_macros_header", - ":xla_internal_test_main", + ":client_library_test_runner_mixin", + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep "//xla:array4d", + "//xla:error_spec", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/testlib:test_helpers", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:test", ], ) @@ -2838,17 +2819,17 @@ xla_test( srcs = ["bitcast_convert_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", + "//xla:error_spec", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/tsl/platform:test", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:test", ], ) @@ -2874,15 +2855,15 @@ xla_test( srcs = ["floor_ceil_test.cc"], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", + ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", - "//xla:error_spec", + ":xla_internal_test_main", # fixdeps: keep "//xla/hlo/builder:xla_builder", + "//xla/tsl/platform:test", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", ], ) @@ -3061,25 +3042,35 @@ xla_test( ], tags = ["test_xla_cpu_no_thunks"], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", ":literal_test_util", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", + "//xla:comparison_util", + "//xla:error_spec", + "//xla:executable_run_options", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/client:client_library", + "//xla/client:executable_build_options", "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/service:platform_util", + "//xla/service:shaped_buffer", + "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:env", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "@com_google_absl//absl/log", "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:test_benchmark", ], ) @@ -3384,15 +3375,16 @@ xla_test( "test_xla_cpu_no_thunks", 
], deps = [ - ":client_library_test_base", + ":client_library_test_runner_mixin", ":hlo_test_base", ":test_macros_header", - ":xla_internal_test_main", + ":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", + "//xla/tsl/platform:test", "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:ml_dtypes", @@ -3480,22 +3472,24 @@ xla_test( "test_xla_cpu_no_thunks", ], deps = [ - ":client_library_test_base", - ":literal_test_util", - ":test_macros_header", - ":xla_internal_test_main", + ":client_library_test_runner_mixin", + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep "//xla:array", "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", - "//xla/hlo/builder/lib:math", "//xla/hlo/builder/lib:matrix", "//xla/hlo/testlib:test", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status:statusor", + "//xla/tsl/platform:test", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/tests/axpy_simple_test.cc b/third_party/xla/xla/tests/axpy_simple_test.cc index d4f1a753c6e776..e463a7af2a388e 100644 --- a/third_party/xla/xla/tests/axpy_simple_test.cc +++ b/third_party/xla/xla/tests/axpy_simple_test.cc @@ -15,17 +15,20 @@ limitations under the License. #include -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class AxpySimpleTest : public ClientLibraryTestBase {}; +using AxpySimpleTest = ClientLibraryTestRunnerMixin<HloTestBase>; TEST_F(AxpySimpleTest, AxTenValues) { XlaBuilder builder("ax_10"); @@ -40,7 +43,7 @@ TEST_F(AxpySimpleTest, AxTenValues) { ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { +TEST_F(AxpySimpleTest, AxpyZeroValues) { XlaBuilder builder("axpy_10"); auto alpha = ConstantR0<float>(&builder, 3.1415926535); auto x = ConstantR1<float>(&builder, {}); diff --git a/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc b/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc index c4a8efbc7509e0..7b38c008837803 100644 --- a/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc +++ b/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc @@ -21,25 +21,22 @@ limitations under the License.
#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/testlib/test.h" #include "xla/shape.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" namespace xla { namespace { -class BadRngShapeValidationTest : public ClientLibraryTestBase {}; - -TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { - XlaBuilder builder(TestName()); +TEST(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { + XlaBuilder builder("BadRngShapeValidationTest"); auto zero = ConstantR0<float>(&builder, 0.0); auto one = ConstantR0<float>(&builder, 1.0); RngUniform(zero, one, Shape()); EXPECT_FALSE(builder.Build().ok()); } -TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { - XlaBuilder builder(TestName()); +TEST(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { + XlaBuilder builder("BadRngShapeValidationTest"); auto zero = ConstantR0<float>(&builder, 0.0); auto one = ConstantR0<float>(&builder, 1.0); Shape shape; diff --git a/third_party/xla/xla/tests/bfloat16_test.cc b/third_party/xla/xla/tests/bfloat16_test.cc index 2eb9dea66b8596..8e5c878d1f006f 100644 --- a/third_party/xla/xla/tests/bfloat16_test.cc +++ b/third_party/xla/xla/tests/bfloat16_test.cc @@ -13,51 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include - -#include "absl/status/statusor.h" -#include "xla/array2d.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" -#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/testlib/test.h" -#include "xla/hlo/testlib/test_helpers.h" -#include "xla/literal.h" -#include "xla/reference_util.h" -#include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/literal_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" -#include "xla/tests/test_utils.h" -#include "xla/util.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" +#include "xla/types.h" namespace xla { namespace { -class Bfloat16Test : public ClientLibraryTestBase { - protected: - const ErrorSpec error_spec_{0.001, 0.001}; -}; +constexpr ErrorSpec kErrorSpec{0.001, 0.001}; + +using Bfloat16Test = ClientLibraryTestRunnerMixin<HloTestBase>; -XLA_TEST_F(Bfloat16Test, ScalarOperation) { +TEST_F(Bfloat16Test, ScalarOperation) { XlaBuilder builder(TestName()); auto x = ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(2.0f)); auto y = ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(1.0f)); Add(x, y); ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {}, - error_spec_); + kErrorSpec); } -XLA_TEST_F(Bfloat16Test, LogOperation) { +TEST_F(Bfloat16Test, LogOperation) { XlaBuilder builder(TestName()); auto x = ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(4.0f)); Log(x); @@ -66,17 +49,17 @@ XLA_TEST_F(Bfloat16Test, LogOperation) { ErrorSpec(0.01, 0.01)); } -XLA_TEST_F(Bfloat16Test, NegateScalarF16) { +TEST_F(Bfloat16Test, NegateScalarF16) { XlaBuilder builder(TestName()); Neg(ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(2.1f))); ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {}, - error_spec_); + kErrorSpec); } // Disabled on interpreter since BatchNormExpander is not run by
default on the // interpreter backend. -XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) { +TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); @@ -112,7 +95,7 @@ XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) { // Disabled on interpreter since BatchNormExpander is not run by default on the // interpreter backend. -XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) { +TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); diff --git a/third_party/xla/xla/tests/binop_scaling_test.cc b/third_party/xla/xla/tests/binop_scaling_test.cc index 6aab7717b2a1f7..5fbed99aefc1a6 100644 --- a/third_party/xla/xla/tests/binop_scaling_test.cc +++ b/third_party/xla/xla/tests/binop_scaling_test.cc @@ -13,20 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "xla/array2d.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/reference_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tests/test_macros.h" -#include "tsl/platform/test.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class BinopScalingTest : public ClientLibraryTestBase {}; +using BinopScalingTest = ClientLibraryTestRunnerMixin<HloTestBase>; TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 32, 4); diff --git a/third_party/xla/xla/tests/bitcast_convert_test.cc b/third_party/xla/xla/tests/bitcast_convert_test.cc index 999c17680d6671..d102815aa86235 100644 --- a/third_party/xla/xla/tests/bitcast_convert_test.cc +++ b/third_party/xla/xla/tests/bitcast_convert_test.cc @@ -13,29 +13,29 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ +#include #include #include #include #include -#include "xla/client/local_client.h" +#include "absl/strings/string_view.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class BitcastConvertTest : public ClientLibraryTestBase { +class BitcastConvertTest : public ClientLibraryTestRunnerMixin<HloTestBase> { public: - explicit BitcastConvertTest(se::Platform* platform = nullptr) - : ClientLibraryTestBase(platform) { + BitcastConvertTest() { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); } @@ -71,7 +71,7 @@ TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { ComputeAndCompareR1<float>(&builder, expected, {}); } -XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { +TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { XlaBuilder builder(TestName()); auto a = ConstantR1<int32_t>(&builder, {}); BitcastConvertType(a, F32); @@ -149,7 +149,7 @@ TEST_F(BitcastConvertTest, ConvertReshape) { class BitcastConvertHloTest : public HloTestBase {}; -XLA_TEST_F(BitcastConvertHloTest, S32to4S8) { +TEST_F(BitcastConvertHloTest, S32to4S8) { absl::string_view hlo_string = R"( HloModule bitcast_to_smaller @@ -161,7 +161,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); } -XLA_TEST_F(BitcastConvertHloTest, FourS8toS32) { +TEST_F(BitcastConvertHloTest, FourS8toS32) { absl::string_view hlo_string = R"( HloModule bitcast_to_larger @@ -173,7 +173,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); } -XLA_TEST_F(BitcastConvertHloTest, F32to2F16) { +TEST_F(BitcastConvertHloTest, F32to2F16) { absl::string_view hlo_string = R"( HloModule bitcast_to_smaller @@ -185,7 +185,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); } -XLA_TEST_F(BitcastConvertHloTest, TwoF16toF32) { +TEST_F(BitcastConvertHloTest, TwoF16toF32) { absl::string_view hlo_string = R"( HloModule bitcast_to_smaller diff --git a/third_party/xla/xla/tests/broadcast_simple_test.cc b/third_party/xla/xla/tests/broadcast_simple_test.cc index 0f7f5656dc75ee..6d07b7afa34ad0 100644 --- a/third_party/xla/xla/tests/broadcast_simple_test.cc +++ b/third_party/xla/xla/tests/broadcast_simple_test.cc @@ -13,143 +13,137 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ -#include -#include -#include +#include +#include +#include -#include "absl/status/statusor.h" +#include "absl/log/log.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/array2d.h" -#include "xla/array4d.h" -#include "xla/client/local_client.h" +#include "xla/array3d.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/test.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class BroadcastSimpleTest : public ClientLibraryTestBase { - public: - static constexpr absl::string_view kIncompatibleBinaryOpShapeErrorMessage = - "Binary op with incompatible shapes"; +using ::testing::HasSubstr; - XlaOp BuildBinOp(HloOpcode op, const XlaOp lhs, const XlaOp rhs, - XlaBuilder* builder) { - switch (op) { - case HloOpcode::kMinimum: { - return Min(lhs, rhs); - } - case HloOpcode::kMaximum: { - return Max(lhs, rhs); - } - case HloOpcode::kMultiply: { - return Mul(lhs, rhs); - } - default: { - // Default to Add - return Add(lhs, rhs); - } +constexpr absl::string_view kIncompatibleBinaryOpShapeErrorMessage = + "Binary op with incompatible shapes"; + +XlaOp BuildBinOp(HloOpcode op, const XlaOp lhs, const XlaOp rhs, + XlaBuilder* builder) { + switch (op) { + case HloOpcode::kMinimum: { + return Min(lhs, rhs); + } + case HloOpcode::kMaximum: { + return Max(lhs, rhs); + } + case HloOpcode::kMultiply: { + return Mul(lhs, rhs); + } + default: { + // Default to Add + return Add(lhs, rhs); } } +} - std::unique_ptr MakeR3Data( - absl::Span bounds, - absl::Span minor_to_major, Shape* r3_shape, - Array3D* r3_array, float start, float end, int seed) { - *r3_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, bounds, minor_to_major); - r3_array->FillRandom(start, end, seed); - auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout( - LayoutUtil::MakeLayout(minor_to_major)); - std::unique_ptr r3_global_data = - client_->TransferToServer(r3_data).value(); - return r3_global_data; - } +Literal MakeR3Data(absl::Span bounds, + absl::Span minor_to_major, Shape* r3_shape, + Array3D* r3_array, float start, float end, int seed) { + *r3_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, bounds, minor_to_major); + r3_array->FillRandom(start, end, seed); + return LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout( + LayoutUtil::MakeLayout(minor_to_major)); +} - std::unique_ptr MakeR2Data( - absl::Span bounds, - absl::Span minor_to_major, Shape* r2_shape, - Array2D* r2_array, float start, float end, int seed) { - *r2_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, bounds, minor_to_major); - r2_array->FillRandom(start, end, seed); - auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout( - LayoutUtil::MakeLayout(minor_to_major)); - std::unique_ptr r2_global_data = - client_->TransferToServer(r2_data).value(); - return r2_global_data; - } +Literal MakeR2Data(absl::Span bounds, + absl::Span minor_to_major, Shape* r2_shape, + Array2D* r2_array, float start, float end, int seed) { + *r2_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, bounds, 
minor_to_major); + r2_array->FillRandom(start, end, seed); + return LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout( + LayoutUtil::MakeLayout(minor_to_major)); +} - float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) { - switch (op) { - case HloOpcode::kMinimum: { - return std::min(lhs, rhs); - } - case HloOpcode::kMaximum: { - return std::max(lhs, rhs); - } - case HloOpcode::kMultiply: { - return lhs * rhs; - } - case HloOpcode::kAdd: { - return lhs + rhs; - } - default: { - // Default to Add - LOG(FATAL); - } +float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) { + switch (op) { + case HloOpcode::kMinimum: { + return std::min(lhs, rhs); + } + case HloOpcode::kMaximum: { + return std::max(lhs, rhs); + } + case HloOpcode::kMultiply: { + return lhs * rhs; + } + case HloOpcode::kAdd: { + return lhs + rhs; + } + default: { + // Default to Add + LOG(FATAL); } } -}; +} -using ::testing::HasSubstr; +using BroadcastSimpleTest = ClientLibraryTestRunnerMixin; -XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { +TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { XlaBuilder b(TestName()); Broadcast(ConstantR0(&b, 1.5), {}); ComputeAndCompareR0(&b, 1.5, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { +TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { XlaBuilder b(TestName()); Broadcast(ConstantR0(&b, 2.25), {2, 3}); Array2D expected(2, 3, 2.25); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { +TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { XlaBuilder b(TestName()); XlaOp src; - std::unique_ptr param_data = + const Literal param_data = CreateR0Parameter(2.25f, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&b, /*data_handle=*/&src); Broadcast(src, {2, 3}); Array2D expected(2, 3, 2.25); - ComputeAndCompareR2(&b, expected, {param_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, expected, {¶m_data}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { +TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { XlaBuilder b(TestName()); Broadcast(ConstantR0(&b, 2.25), {2, 0}); Array2D expected(2, 0); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { +TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { XlaBuilder b(TestName()); Broadcast(ConstantR0(&b, 2.25), {0, 2}); Array2D expected(0, 2); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { +TEST_F(BroadcastSimpleTest, 1DTo2D) { XlaBuilder b(TestName()); Broadcast(ConstantR1(&b, {1, 2, 3}), {2}); @@ -163,7 +157,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { +TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { XlaBuilder b(TestName()); BroadcastInDim(ConstantR1(&b, {1, 2}), {2, 2}, {1}); @@ -176,7 +170,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { +TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { XlaBuilder b(TestName()); BroadcastInDim(ConstantR1(&b, {1, 2}), {2, 2}, {0}); @@ -189,7 +183,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { +TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { XlaBuilder b(TestName()); 
BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2}, {0, 1}); @@ -207,7 +201,7 @@ XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { ComputeAndCompareR3(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { +TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { XlaBuilder b(TestName()); BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2}, {0, 2}); @@ -225,7 +219,7 @@ XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { ComputeAndCompareR3(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { +TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { XlaBuilder b(TestName()); BroadcastInDim(ConstantR1(&b, {1, 2}), {3, 2}, {1}); @@ -241,7 +235,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { } // Tests implicit broadcasting of PREDs. -XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { +TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XlaBuilder b(TestName()); Array2D x_vals(2, 1); @@ -264,10 +258,10 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { expected(1, 0, 0) = true; expected(1, 1, 0) = false; - ComputeAndCompareR3(&b, expected, {x_data.get(), y_data.get()}); + ComputeAndCompareR3(&b, expected, {&x_data, &y_data}); } -XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { +TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { XlaBuilder b(TestName()); Broadcast(ConstantR1(&b, {}), {2}); @@ -275,7 +269,7 @@ XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { +TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { XlaBuilder b(TestName()); Broadcast(ConstantR1(&b, {1, 2, 3}), {0}); @@ -283,7 +277,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { +TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { // Verify that binary op and degenerate dimension broadcast work together in // the same operation. 
// @@ -337,10 +331,10 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { Array3D r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1], spec.input_bounds[2]); - std::unique_ptr r3_global_data = + const Literal r3_global_data = MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape, &r3_array, 1.0, 2.5, 56789); - std::unique_ptr r3_implicit_global_data = + const Literal r3_implicit_global_data = MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape, &r3_implicit_array, 1.0, 0.2, 56789); @@ -370,9 +364,9 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } } auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); - ComputeAndCompareLiteral( - &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()}, - ErrorSpec(1e-7, 1e-7)); + ComputeAndCompareLiteral(&builder, expected, + {&r3_implicit_global_data, &r3_global_data}, + ErrorSpec(1e-7, 1e-7)); } INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances, @@ -380,7 +374,7 @@ INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances, ::testing::ValuesIn(kR3ImplicitBroadcastTestCases)); // r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1: -XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { +TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XlaBuilder b(TestName()); XlaOp r1h; XlaOp r3h; @@ -395,11 +389,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); - ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {&r3, &r1}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { +TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}})); auto r3 = ConstantLiteral( @@ -412,7 +405,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { +TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}, {2}}})); auto r3 = ConstantLiteral( @@ -425,7 +418,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { +TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); @@ -439,7 +432,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { +TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); @@ -453,7 +446,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { +TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral( &b, LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); @@ -467,7 +460,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { 
+TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}}})); auto r3 = ConstantLiteral( @@ -584,13 +577,13 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { Array2D r2_implicit_array2(spec.input_bounds2[0], spec.input_bounds2[1]); - std::unique_ptr r2_global_data = + const Literal r2_global_data = MakeR2Data(spec.output_bounds, spec.minor2major_layout, &r2_shape, &r2_array, 1.0, 2.5, 56789); - std::unique_ptr r2_implicit_global_data1 = + const Literal r2_implicit_global_data1 = MakeR2Data(spec.input_bounds1, spec.minor2major_layout, &r2_implicit_shape1, &r2_implicit_array1, 1.0, 0.2, 56789); - std::unique_ptr r2_implicit_global_data2 = + const Literal r2_implicit_global_data2 = MakeR2Data(spec.input_bounds2, spec.minor2major_layout, &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789); @@ -619,8 +612,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( &builder, expected, - {r2_implicit_global_data1.get(), r2_global_data.get(), - r2_implicit_global_data2.get()}, + {&r2_implicit_global_data1, &r2_global_data, &r2_implicit_global_data2}, ErrorSpec(1e-6, 1e-6)); } @@ -628,7 +620,7 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, BroadcastR2ImplicitTest, ::testing::ValuesIn(kR2ImplicitBroadcastTestCases)); -XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { +TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}})); auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); @@ -639,7 +631,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { +TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1}, {2}})); auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); @@ -650,7 +642,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { +TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( @@ -663,7 +655,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { +TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( @@ -676,7 +668,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { +TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( @@ -689,7 +681,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { +TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { XlaBuilder b(TestName()); auto r1_0 = ConstantR1(&b, {1000, 2000}); auto r1_1 = ConstantR1(&b, {100, 200}); @@ -710,7 +702,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { ComputeAndCompareLiteral(&b, expected, {}, 
                            ErrorSpec(0.0001));
 }
 
-XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
+TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
   XlaBuilder b(TestName());
   auto r1_0 = ConstantR1<float>(&b, {1000, 2000});
   auto r1_1 = ConstantR1<float>(&b, {100, 200});
@@ -731,7 +723,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
   ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
-XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
+TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
   // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2])
   // results in a shape incompatible with the lhs [2, 3, 1].
   XlaBuilder b(TestName());
@@ -741,33 +733,33 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
                            {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
       /*broadcast_dimensions=*/{1, 2});
 
-  auto result_status = Execute(&b, {});
+  const absl::StatusOr<Literal> result_status = ExecuteAndTransfer(&b, {});
   EXPECT_FALSE(result_status.ok());
   EXPECT_THAT(result_status.status().message(),
               HasSubstr("dimension 0 mismatch"));
 }
 
-XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
+TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
   // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
   XlaBuilder b(TestName());
 
   Add(ConstantR2<float>(&b, {{1.0, 2.0}}),
       ConstantR2<float>(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
 
-  auto result_status = Execute(&b, {});
+  absl::StatusOr<Literal> result_status = ExecuteAndTransfer(&b, {});
   EXPECT_FALSE(result_status.ok());
   EXPECT_THAT(result_status.status().message(),
               HasSubstr(kIncompatibleBinaryOpShapeErrorMessage));
 }
 
-XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
+TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
   // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
   XlaBuilder b(TestName());
 
   Add(ConstantR2<float>(&b, {{1.0, 2.0}}),
       ConstantR2<float>(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
 
-  auto result_status = Execute(&b, {});
+  absl::StatusOr<Literal> result_status = ExecuteAndTransfer(&b, {});
   EXPECT_FALSE(result_status.ok());
   EXPECT_THAT(result_status.status().message(),
               HasSubstr(kIncompatibleBinaryOpShapeErrorMessage));
diff --git a/third_party/xla/xla/tests/call_test.cc b/third_party/xla/xla/tests/call_test.cc
index 4fdfc73db84296..45c2f4cb39c0ea 100644
--- a/third_party/xla/xla/tests/call_test.cc
+++ b/third_party/xla/xla/tests/call_test.cc
@@ -16,22 +16,25 @@ limitations under the License.
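The hunks above show the core mechanics of this migration: fixtures built on ClientLibraryTestRunnerMixin execute through ExecuteAndTransfer and take plain `const Literal*` arguments, where ClientLibraryTestBase required staging every argument on the service via client_->TransferToServer and passing GlobalData handles. A minimal before/after sketch of such a test (fixture name and values are hypothetical; the mixin's template argument is assumed to be HloTestBase, as the new includes suggest):

class ExampleTest : public ClientLibraryTestRunnerMixin<HloTestBase> {};

TEST_F(ExampleTest, PassesLiteralPointersDirectly) {
  XlaBuilder builder(TestName());
  auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "p");
  Add(p, ConstantR0<float>(&builder, 1.0f));
  // Old style required a round trip through the service:
  //   std::unique_ptr<GlobalData> arg =
  //       client_->TransferToServer(LiteralUtil::CreateR0<float>(41.0f)).value();
  //   ComputeAndCompareR0<float>(&builder, 42.0f, {arg.get()}, ErrorSpec(1e-4));
  // New style passes the literal itself by pointer.
  const Literal arg = LiteralUtil::CreateR0<float>(41.0f);
  ComputeAndCompareR0<float>(&builder, 42.0f, {&arg}, ErrorSpec(1e-4));
}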
#include #include +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class CallOpTest : public ClientLibraryTestBase { +class CallOpTest : public ClientLibraryTestRunnerMixin { protected: XlaComputation CreateR0F32IdentityComputation() { XlaBuilder builder("Identity"); @@ -74,7 +77,7 @@ class CallOpTest : public ClientLibraryTestBase { Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2}); }; -XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { +TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0(42.0)); @@ -83,7 +86,7 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); } -XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { +TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); @@ -93,7 +96,7 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); } -XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { +TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); auto x = @@ -105,7 +108,7 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); } -XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { +TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { XlaBuilder builder("inner"); { auto x = Parameter(&builder, 0, r0f32_, "x"); @@ -130,13 +133,11 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { x = Call(&builder3, outer, {x}); } - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr start, - client_->TransferToServer(LiteralUtil::CreateR0(1.0f))); - ComputeAndCompareR0(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); + const Literal start = LiteralUtil::CreateR0(1.0f); + ComputeAndCompareR0(&builder3, 10.0f, {&start}, ErrorSpec(0.0f)); } -XLA_TEST_F(CallOpTest, CallR0F32Tuple) { +TEST_F(CallOpTest, CallR0F32Tuple) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32TupleComputation(); auto elem = LiteralUtil::CreateR0(42.0); diff --git a/third_party/xla/xla/tests/check_execution_arity_test.cc b/third_party/xla/xla/tests/check_execution_arity_test.cc index f7d080f86d7292..c2fd56d738fc6c 100644 --- a/third_party/xla/xla/tests/check_execution_arity_test.cc +++ b/third_party/xla/xla/tests/check_execution_arity_test.cc @@ -13,18 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include #include #include "absl/status/statusor.h" -#include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" #include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" namespace xla { @@ -32,7 +33,8 @@ namespace { using ::testing::ContainsRegex; -class CheckExecutionArityTest : public ClientLibraryTestBase {}; +class CheckExecutionArityTest + : public ClientLibraryTestRunnerMixin {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); @@ -42,33 +44,30 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1"); Add(p0, p1); - auto param0_data = client_->TransferToServer(param_literal).value(); - auto param1_data = client_->TransferToServer(param_literal).value(); - auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); auto computation = std::move(computation_status).value(); // The arity of the UserComputation is 2 arguments. Execution will succeed // with 2 arguments, but fail with a different number. - auto result_two_args = client_->Execute( - computation, {param0_data.get(), param1_data.get()}, &execution_options_); + absl::StatusOr result_two_args = + ExecuteAndTransfer(computation, {¶m_literal, ¶m_literal}); ASSERT_IS_OK(result_two_args.status()); - auto result_one_arg = - client_->Execute(computation, {param0_data.get()}, &execution_options_); + absl::StatusOr result_one_arg = + ExecuteAndTransfer(computation, {¶m_literal}); ASSERT_FALSE(result_one_arg.ok()); ASSERT_EQ(result_one_arg.status().code(), tsl::error::INVALID_ARGUMENT); ASSERT_THAT(result_one_arg.status().message(), ContainsRegex("takes 2")); - auto result_zero_args = - client_->Execute(computation, {}, &execution_options_); + absl::StatusOr result_zero_args = + ExecuteAndTransfer(computation, {}); ASSERT_FALSE(result_zero_args.ok()); ASSERT_EQ(result_zero_args.status().code(), tsl::error::INVALID_ARGUMENT); ASSERT_THAT(result_zero_args.status().message(), ContainsRegex("takes 2")); } -XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { +TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { XlaBuilder builder("add_two_params"); auto p0 = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); @@ -79,21 +78,18 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_IS_OK(computation_status.status()); auto computation = std::move(computation_status).value(); - auto f32_literal = LiteralUtil::CreateR0(1.1f); - auto f32_data = client_->TransferToServer(f32_literal).value(); - auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); - auto f32_4_data = client_->TransferToServer(f32_4_literal).value(); - auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); - auto u8_4_data = client_->TransferToServer(u8_4_literal).value(); + const Literal f32_literal = LiteralUtil::CreateR0(1.1f); + const Literal f32_4_literal = + LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); + const Literal u8_4_literal = LiteralUtil::CreateR1U8("hola"); // Match - auto status = 
client_->Execute( - computation, {f32_data.get(), f32_4_data.get()}, &execution_options_); + absl::StatusOr status = + ExecuteAndTransfer(computation, {&f32_literal, &f32_4_literal}); ASSERT_IS_OK(status.status()); // Shape mismatch in parameter 0 - status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}, - &execution_options_); + status = ExecuteAndTransfer(computation, {&f32_4_literal, &f32_4_literal}); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tsl::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().message(), @@ -101,8 +97,7 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { "Argument does not match shape of computation parameter 0")); // Shape mismatch in parameter 1 (rank) - status = client_->Execute(computation, {f32_data.get(), f32_data.get()}, - &execution_options_); + status = ExecuteAndTransfer(computation, {&f32_literal, &f32_literal}); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tsl::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().message(), @@ -110,8 +105,7 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { "Argument does not match shape of computation parameter 1")); // Shape mismatch in parameter 1 (element type) - status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}, - &execution_options_); + status = ExecuteAndTransfer(computation, {&f32_literal, &u8_4_literal}); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tsl::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().message(), diff --git a/third_party/xla/xla/tests/complex_unary_op_test.cc b/third_party/xla/xla/tests/complex_unary_op_test.cc index 0bac77a9d69b7e..661a16fa958155 100644 --- a/third_party/xla/xla/tests/complex_unary_op_test.cc +++ b/third_party/xla/xla/tests/complex_unary_op_test.cc @@ -13,18 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. 
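The arity and shape checks above all funnel through one pattern: ExecuteAndTransfer returns an absl::StatusOr<Literal>, so the success path and the expected INVALID_ARGUMENT failures are asserted on the same object. A condensed sketch of that pattern (the computation and argument names are placeholders; the matched substring follows the service error text quoted above):

// Assumes `computation` expects two parameters, as in the test above.
absl::StatusOr<Literal> result = ExecuteAndTransfer(computation, {&arg0});
ASSERT_FALSE(result.ok());
ASSERT_EQ(result.status().code(), tsl::error::INVALID_ARGUMENT);
ASSERT_THAT(result.status().message(), ContainsRegex("takes 2"));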
==============================================================================*/ -#include +#include +#include +#include #include -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/lib/math.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/complex_unary_op_samples.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { @@ -32,7 +36,7 @@ namespace { template constexpr bool dependent_false = false; -class ComplexUnaryOpTest : public ClientLibraryTestBase { +class ComplexUnaryOpTest : public ClientLibraryTestRunnerMixin { protected: template std::vector get_column(const std::vector>& table) { @@ -95,26 +99,26 @@ class ComplexUnaryOpTest : public ClientLibraryTestBase { } }; -XLA_TEST_F(ComplexUnaryOpTest, Log1pTest) { +TEST_F(ComplexUnaryOpTest, Log1pTest) { UnaryTestHelper>( [](XlaOp x) { return Log1p(x); }); UnaryTestHelper>( [](XlaOp x) { return Log1p(x); }); } -XLA_TEST_F(ComplexUnaryOpTest, TanTest) { +TEST_F(ComplexUnaryOpTest, TanTest) { UnaryTestHelper>( [](XlaOp x) { return Tan(x); }); UnaryTestHelper>( [](XlaOp x) { return Tan(x); }); } -XLA_TEST_F(ComplexUnaryOpTest, AsinTest) { +TEST_F(ComplexUnaryOpTest, AsinTest) { UnaryTestHelper>(Asin); UnaryTestHelper>(Asin); } -XLA_TEST_F(ComplexUnaryOpTest, AsinhTest) { +TEST_F(ComplexUnaryOpTest, AsinhTest) { UnaryTestHelper>(Asinh); UnaryTestHelper>(Asinh); } diff --git a/third_party/xla/xla/tests/conditional_test.cc b/third_party/xla/xla/tests/conditional_test.cc index f221ad7674b5ec..95fa2b2ae09eae 100644 --- a/third_party/xla/xla/tests/conditional_test.cc +++ b/third_party/xla/xla/tests/conditional_test.cc @@ -13,32 +13,40 @@ See the License for the specific language governing permissions and limitations under the License. 
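For orientation on what UnaryTestHelper checks above: the complex_unary_op_samples tables pair inputs with precomputed outputs of the corresponding complex function. As a rough reference only (this is not the table format, and it ignores the cancellation-safe formulation a production Log1p uses), one expected Log1p value can be derived with the standard library:

#include <complex>
// Naive reference for complex Log1p: log(1 + z). The input is hypothetical.
std::complex<float> z(0.5f, -0.25f);
std::complex<float> expected = std::log(std::complex<float>(1.0f) + z);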
==============================================================================*/ -#include +#include +#include #include +#include -#include #include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "xla/array2d.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/literal_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/platform/threadpool.h" namespace xla { namespace { -class ConditionalOpTest : public ClientLibraryTestBase { +constexpr ErrorSpec kErrorSpec{0.001}; + +class ConditionalOpTest : public ClientLibraryTestRunnerMixin { protected: void SetUp() override { - ClientLibraryTestBase::SetUp(); - execution_options_.mutable_debug_options() - ->set_xla_test_add_command_buffer_mode(true); + ClientLibraryTestRunnerMixin::SetUp(); + mutable_debug_options()->set_xla_test_add_command_buffer_mode(true); } XlaComputation CreateR0ConstantComputation(float value) { @@ -183,7 +191,6 @@ class ConditionalOpTest : public ClientLibraryTestBase { Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})}); Shape empty_tuple_ = ShapeUtil::MakeTupleShape({}); - ErrorSpec error_spec_{0.001}; }; // Test fixture to run indexed conditional (switch/case) tests with varying @@ -192,7 +199,7 @@ class CaseOpTest : public ConditionalOpTest, public ::testing::WithParamInterface {}; // Test true and false computations that do not take any parameters. -XLA_TEST_F(ConditionalOpTest, Parameters0) { +TEST_F(ConditionalOpTest, Parameters0) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); @@ -201,7 +208,7 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) { auto false_computation = CreateR0ConstantComputation(12.0f); Conditional(pred, operands, true_computation, operands, false_computation); - ComputeAndCompareR0(&builder, 56.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 56.0f, {&pred_arg}, kErrorSpec); } // Test branch computations that do not take any parameters. @@ -229,13 +236,13 @@ XLA_TEST_P(CaseOpTest, Parameters0) { float expected = 10 * static_cast((bi < 0 || bi >= num_branches) ? num_branches - 1 : bi); - ComputeAndCompareR0(&builder, expected, {branch_index_arg.get()}, - error_spec_); + ComputeAndCompareR0(&builder, expected, {&branch_index_arg}, + kErrorSpec); } } // Test true and false computations that take in 1 parameter. -XLA_TEST_F(ConditionalOpTest, Parameters1) { +TEST_F(ConditionalOpTest, Parameters1) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); @@ -244,7 +251,7 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) { auto identity = CreateR0IdentityComputation(); Conditional(pred, operand1, identity, operand2, identity); - ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {&pred_arg}, kErrorSpec); } // Test branch computations that take in 1 parameter. 
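To make the Conditional tests that follow easier to read: Conditional takes a predicate, a per-branch operand, and a per-branch XlaComputation, and the mixin's CreateR0Parameter returns a Literal that is later passed by pointer (as the call sites in this file imply). A condensed sketch mirroring the ceil/floor tests below (values taken from those tests; the builder name is hypothetical):

XlaBuilder b("conditional_sketch");
XlaOp pred;
const Literal pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &b, &pred);
auto operand1 = ConstantR0<float>(&b, 56.4f);
auto operand2 = ConstantR0<float>(&b, 12.6f);
Conditional(pred, operand1, CreateR0CeilComputation(),  // taken: Ceil(56.4)
            operand2, CreateR0FloorComputation());      // not taken
ComputeAndCompareR0<float>(&b, 57.0f, {&pred_arg}, kErrorSpec);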
@@ -281,14 +288,14 @@ XLA_TEST_P(CaseOpTest, Parameters1) { float expected = (bi < 0 || bi >= num_branches) ? expecteds[num_branches - 1] : expecteds[bi]; - ComputeAndCompareR0(&builder, expected, {branch_index_arg.get()}, - error_spec_); + ComputeAndCompareR0(&builder, expected, {&branch_index_arg}, + kErrorSpec); } } // Test conditional with two different computations in the true and false cases // that take in different arguments. -XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { +TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); @@ -297,12 +304,12 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { Conditional(pred, operand1, CreateR0CeilComputation(), operand2, CreateR0FloorComputation()); - ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {&pred_arg}, kErrorSpec); } // Test conditional with two different computations in the true and false cases // that take in the same arguments. -XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { +TEST_F(ConditionalOpTest, DiffComputationsSameArg) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); @@ -310,12 +317,12 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { Conditional(pred, operand, CreateR0CeilComputation(), operand, CreateR0FloorComputation()); - ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {&pred_arg}, kErrorSpec); } // Test conditional with the same computation in the true and false cases but // take in different arguments. -XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { +TEST_F(ConditionalOpTest, SameComputationDiffArgs) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); @@ -324,12 +331,12 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { auto floor = CreateR0FloorComputation(); Conditional(pred, operand1, floor, operand2, floor); - ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {&pred_arg}, kErrorSpec); } // Test conditional with the same computation in the true and false cases that // take in the same arguments. -XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { +TEST_F(ConditionalOpTest, SameComputationSameArg) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); @@ -337,12 +344,12 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { auto floor = CreateR0FloorComputation(); Conditional(pred, operand, floor, operand, floor); - ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {&pred_arg}, kErrorSpec); } // Test conditional with different instances of the same computation in the true // and false cases. 
-XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { +TEST_F(ConditionalOpTest, SameComputationDiffInstances) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); @@ -351,11 +358,11 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { Conditional(pred, operand1, CreateR0FloorComputation(), operand2, CreateR0FloorComputation()); - ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {&pred_arg}, kErrorSpec); } // Test the case when a call invokes a computation that contains a conditional. -XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { +TEST_F(ConditionalOpTest, ConditionalWithCall) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); XlaBuilder inner_builder(TestName() + ".inner_conditional"); auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0"); @@ -372,12 +379,12 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { auto operand2 = ConstantR0(&builder, 12.6f); Call(&builder, inner_builder_result, {pred, operand1, operand2}); - ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {&pred_arg}, kErrorSpec); } // Test true and false computations that take in 2 parameters and predicate is // true. -XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { +TEST_F(ConditionalOpTest, Parameters2TrueBranch) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); @@ -387,12 +394,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { Conditional(pred, operands, CreateR0TupleAddComputation(), operands, CreateR0TupleSubComputation()); - ComputeAndCompareR0(&builder, 68.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 68.0f, {&pred_arg}, kErrorSpec); } // Test true and false computations that take in 2 parameters and predicate is // false. -XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { +TEST_F(ConditionalOpTest, Parameters2FalseBranch) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); @@ -402,12 +409,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { Conditional(pred, operands, CreateR0TupleAddComputation(), operands, CreateR0TupleSubComputation()); - ComputeAndCompareR0(&builder, 44.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 44.0f, {&pred_arg}, kErrorSpec); } // Test true and false computations that take in 2 array parameters and // predicate is true. -XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { +TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); @@ -417,8 +424,7 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { Conditional(pred, operands, CreateR1TupleAddComputation(), operands, CreateR1TupleSubComputation()); - ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {pred_arg.get()}, - error_spec_); + ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {&pred_arg}, kErrorSpec); } // Test branch computations that take in 2 array parameters. @@ -454,7 +460,7 @@ XLA_TEST_P(CaseOpTest, Parameters2Array) { (bi < 0 || bi >= num_branches) ? 
num_branches - 1 : bi); ComputeAndCompareR1( &builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11}, - {branch_index_arg.get()}, error_spec_); + {&branch_index_arg}, kErrorSpec); } } @@ -463,7 +469,7 @@ INSTANTIATE_TEST_SUITE_P(CaseOpTest_Instantiation, CaseOpTest, // Test true and false computations that take in 2 array parameters and // predicate is false. -XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { +TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); @@ -473,12 +479,11 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { Conditional(pred, operands, CreateR1TupleAddComputation(), operands, CreateR1TupleSubComputation()); - ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {pred_arg.get()}, - error_spec_); + ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {&pred_arg}, kErrorSpec); } // Test true and false computations that return a tuple of scalars. -XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { +TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); @@ -491,11 +496,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { &builder, LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(12.0f), LiteralUtil::CreateR0(25.0f)}), - {pred_arg.get()}, error_spec_); + {&pred_arg}, kErrorSpec); } // Test true and false computations that return a tuple of arrays. -XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { +TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { XlaBuilder builder(TestName()); XlaOp pred; auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); @@ -509,12 +514,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR1({13.0f, 16.0f}), LiteralUtil::CreateR1({26.0f, 30.0f})}), - {pred_arg.get()}, error_spec_); + {&pred_arg}, kErrorSpec); } // Test true and false computations that return a tuple of a predicate, a // scalar, and an array. -XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { +TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { XlaBuilder true_builder(TestName() + ".true"); { Parameter(&true_builder, 0, empty_tuple_, "tuple"); @@ -549,11 +554,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { {LiteralUtil::CreateR0(true), LiteralUtil::CreateR0(12.2f), LiteralUtil::CreateR1({12.8f, 14.6f})}), - {pred_arg.get()}, error_spec_); + {&pred_arg}, kErrorSpec); } // Test true and false computations that return a nested tuple. -XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { +TEST_F(ConditionalOpTest, ReturnNestedTuple) { XlaBuilder true_builder(TestName() + ".true"); { Parameter(&true_builder, 0, empty_tuple_, "tuple"); @@ -598,12 +603,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR1({62.1f, 67.4f}), LiteralUtil::CreateR0(9.3f)})}), - {pred_arg.get()}, error_spec_); + {&pred_arg}, kErrorSpec); } // Test conditional that takes in scalar operands in the form of external // params. 
-XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { +TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); XlaBuilder builder(TestName()); @@ -616,14 +621,13 @@ XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { Conditional(pred, operand1, CreateR0CeilComputation(), operand2, CreateR0FloorComputation()); - ComputeAndCompareR0( - &builder, 57.0f, - {pred_arg.get(), operand1_param.get(), operand2_param.get()}, - error_spec_); + ComputeAndCompareR0(&builder, 57.0f, + {&pred_arg, &operand1_param, &operand2_param}, + kErrorSpec); } // Test conditional that takes in array operands in the form of external params. -XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { +TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); XlaBuilder builder(TestName()); @@ -636,14 +640,13 @@ XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { Conditional(pred, operand1, CreateR1CeilComputation(), operand2, CreateR1FloorComputation()); - ComputeAndCompareR1( - &builder, {10.0f, 11.0f}, - {pred_arg.get(), operand1_param.get(), operand2_param.get()}, - error_spec_); + ComputeAndCompareR1(&builder, {10.0f, 11.0f}, + {&pred_arg, &operand1_param, &operand2_param}, + kErrorSpec); } // Test the case where one conditional is nested within another. -XLA_TEST_F(ConditionalOpTest, NestedConditionals) { +TEST_F(ConditionalOpTest, NestedConditionals) { XlaBuilder inner_builder(TestName() + ".inner_conditional"); { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); @@ -669,11 +672,11 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { Conditional(pred1, tuple_operand, std::move(inner_builder_result).value(), operand3, CreateR0IdentityComputation()); - ComputeAndCompareR0(&builder, 12.0f, - {pred1_arg.get(), pred2_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {&pred1_arg, &pred2_arg}, + kErrorSpec); } -XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { +TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { XlaBuilder inner_builder(TestName() + ".inner_conditional"); { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); @@ -696,11 +699,11 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { auto tuple_operand = Tuple(&builder, {pred, operand1, operand2}); Call(&builder, std::move(inner_builder_result).value(), {tuple_operand}); - ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {&pred_arg}, kErrorSpec); } // Test a mismatch in the shape of the true operand and true computation. 
-XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { +TEST_F(ConditionalOpTest, ShapeMismatch) { XlaBuilder builder(TestName()); auto pred = ConstantR0(&builder, true); auto operand1 = ConstantR0(&builder, 56.0f); @@ -716,7 +719,7 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { "only parameter of branch computation 0")); } -XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { +TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_}); XlaComputation swapper; { @@ -760,7 +763,7 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { &builder, LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR0(a), LiteralUtil::CreateR0(b)}), - {x_arg.get(), y_arg.get()}, error_spec_); + {&x_arg, &y_arg}, kErrorSpec); }; test_swap(3.11f, 9.4f); test_swap(11.24f, 5.55f); @@ -768,7 +771,7 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { // Test conditional that duplicates tuple elements in the then and else // computations. This is a regression test for b/112550242. -XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { +TEST_F(ConditionalOpTest, DuplicateElementsConditional) { const Shape scalar = ShapeUtil::MakeShape(S32, {}); const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar}); XlaComputation then_comp; @@ -801,7 +804,7 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { auto p = Parameter(&builder, 0, tuple2, "p0"); auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); Conditional(p_pred, p, then_comp, p, else_comp); - ComputeAndCompare(&builder, args); + ComputeAndCompare(&builder, {&args[0], &args[1]}); } { // Pred is false case. @@ -814,11 +817,13 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { auto p = Parameter(&builder, 0, tuple2, "p0"); auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); Conditional(p_pred, p, then_comp, p, else_comp); - ComputeAndCompare(&builder, args); + ComputeAndCompare(&builder, {&args[0], &args[1]}); } } -XLA_TEST_F(HloTestBase, ParallelExecution) { +using ConditionalOpHloTest = HloTestBase; + +TEST_F(ConditionalOpHloTest, ParallelExecution) { // Test conditional works when an executable is executed in parallel. const char* const hlo_string = R"( HloModule m diff --git a/third_party/xla/xla/tests/constants_test.cc b/third_party/xla/xla/tests/constants_test.cc index f6ba0406e31e7e..239c134af63052 100644 --- a/third_party/xla/xla/tests/constants_test.cc +++ b/third_party/xla/xla/tests/constants_test.cc @@ -17,33 +17,34 @@ limitations under the License. 
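One recurring mechanical change in this patch is worth calling out before the next file: the per-fixture ErrorSpec member goes away with the base class, so each file hoists the tolerance to a namespace-scope constant and the fixture collapses to an alias. In sketch form (the names follow the constants_test.cc hunk below; the mixin's template argument is assumed to be HloTestBase):

// Before: the tolerance lived on the fixture.
//   class ConstantsTest : public ClientLibraryTestBase {
//    protected:
//     const ErrorSpec error_spec_{1e-3, 1e-5};
//   };
// After: a file-scope constant plus a mixin alias.
constexpr ErrorSpec kErrorSpec{1e-3, 1e-5};
using ConstantsTest = ClientLibraryTestRunnerMixin<HloTestBase>;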
#include "xla/hlo/builder/lib/constants.h" +#include #include +#include #include #include #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" -#include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class ConstantsTest : public ClientLibraryTestBase { - protected: - const ErrorSpec error_spec_{1e-3, 1e-5}; -}; +constexpr ErrorSpec kErrorSpec{1e-3, 1e-5}; + +using ConstantsTest = ClientLibraryTestRunnerMixin; template class ConstantsFloatTest : public ConstantsTest {}; @@ -66,17 +67,17 @@ TEST_F(ConstantsTest, ZeroCellF32) { XlaBuilder builder(TestName()); ConstantR1(&builder, {}); - ComputeAndCompareR1(&builder, {}, {}, error_spec_); + ComputeAndCompareR1(&builder, {}, {}, kErrorSpec); } TYPED_TEST(ConstantsFloatTest, OneCellFloat) { std::vector constant = {TypeParam{2.0}}; - XlaBuilder builder(ClientLibraryTestBase::TestName()); + XlaBuilder builder(ConstantsTest::TestName()); ConstantR1(&builder, constant); - ClientLibraryTestBase::ComputeAndCompareR1(&builder, constant, {}, - this->error_spec_); + ConstantsTest::ComputeAndCompareR1(&builder, constant, {}, + kErrorSpec); } TEST_F(ConstantsTest, OneCellS32) { @@ -125,7 +126,7 @@ TEST_F(ConstantsTest, EightCells) { XlaBuilder builder(TestName()); ConstantR1(&builder, constant); - ComputeAndCompareR1(&builder, constant, {}, error_spec_); + ComputeAndCompareR1(&builder, constant, {}, kErrorSpec); } TEST_F(ConstantsTest, SixteenCells) { @@ -135,14 +136,14 @@ TEST_F(ConstantsTest, SixteenCells) { XlaBuilder builder(TestName()); ConstantR1(&builder, constant); - ComputeAndCompareR1(&builder, constant, {}, error_spec_); + ComputeAndCompareR1(&builder, constant, {}, kErrorSpec); } TEST_F(ConstantsTest, Empty_0x2) { XlaBuilder builder(TestName()); ConstantR2FromArray2D(&builder, Array2D(0, 2)); - ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); + ComputeAndCompareR2(&builder, Array2D(0, 2), {}, kErrorSpec); } TEST_F(ConstantsTest, Small_2x2) { @@ -152,7 +153,7 @@ TEST_F(ConstantsTest, Small_2x2) { XlaBuilder builder(TestName()); ConstantR2FromArray2D(&builder, *constant); - ComputeAndCompareR2(&builder, *constant, {}, error_spec_); + ComputeAndCompareR2(&builder, *constant, {}, kErrorSpec); } TEST_F(ConstantsTest, Empty_3x0x2) { @@ -192,13 +193,13 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { { XlaBuilder builder(TestName()); ConstantLiteral(&builder, input_literal); - ComputeAndCompareR4(&builder, input_array, {}, error_spec_); + ComputeAndCompareR4(&builder, input_array, {}, kErrorSpec); } { XlaBuilder builder(TestName()); ConstantR4FromArray4D(&builder, input_array); - ComputeAndCompareR4(&builder, input_array, {}, error_spec_); + ComputeAndCompareR4(&builder, input_array, {}, kErrorSpec); } } @@ -212,9 +213,9 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { Literal result = ExecuteAndTransfer(&builder, {}).value(); LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, - LiteralSlice(result, {0}), error_spec_); + LiteralSlice(result, {0}), kErrorSpec); 
LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(result, {1}), - error_spec_); + kErrorSpec); } TEST_F(ConstantsTest, Token) { @@ -222,7 +223,7 @@ TEST_F(ConstantsTest, Token) { ConstantLiteral(&builder, LiteralUtil::CreateToken()); // TODO(b/80000000): tokens cannot be returned from computations. Tuple(&builder, {}); - TF_ASSERT_OK(Execute(&builder, {}).status()); + TF_ASSERT_OK(ExecuteAndTransfer(&builder, {}).status()); } TEST_F(ConstantsTest, FullLike) { @@ -230,7 +231,7 @@ TEST_F(ConstantsTest, FullLike) { auto val1 = Iota(&b, F32, 3); auto val2 = FullLike(val1, 10); val1 + val2; - ComputeAndCompareR1(&b, {10, 11, 12}, {}, error_spec_); + ComputeAndCompareR1(&b, {10, 11, 12}, {}, kErrorSpec); } TEST_F(ConstantsTest, IllegalFullLikeOnTuple) { @@ -245,14 +246,13 @@ TEST_F(ConstantsTest, FullLikeScalar) { auto scalar1 = ConstantR0WithType(&b, F32, 1); auto scalar2 = FullLike(scalar1, 2); scalar1 - scalar2; - ComputeAndCompareR0(&b, -1, {}, error_spec_); + ComputeAndCompareR0(&b, -1, {}, kErrorSpec); } -class ConstantsHloTest : public HloTestBase {}; +using ConstantsHloTest = HloTestBase; // TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior. -XLA_TEST_F(ConstantsHloTest, - DISABLED_ON_TPU(DISABLED_ON_GPU(BitcastOfConstant))) { +TEST_F(ConstantsHloTest, DISABLED_ON_TPU(DISABLED_ON_GPU(BitcastOfConstant))) { const char* testcase = R"( HloModule module, is_scheduled=true diff --git a/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc b/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc index 833a9266afb3bf..2b7e718966e7b5 100644 --- a/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc +++ b/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc @@ -13,20 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include #include "absl/status/statusor.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/padding.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/shape_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { @@ -54,7 +58,8 @@ absl::StatusOr CreateConvDimensionNumbers( return dimension_numbers; } -class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; +class ConvolutionDimensionNumbersTest + : public ClientLibraryTestRunnerMixin {}; // Tests the convolution operation with invalid input dimension numbers. 
TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { @@ -83,15 +88,13 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { ::testing::HasSubstr("output are not unique")); } -XLA_TEST_F(ConvolutionDimensionNumbersTest, - TwoConvsWithDifferentDimensionNumbers) { +TEST_F(ConvolutionDimensionNumbersTest, TwoConvsWithDifferentDimensionNumbers) { auto input_array = std::make_unique>(2, 3, 5, 5); input_array->FillWithMultiples(0.1); auto weight_array = std::make_unique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); - auto weight_data = - client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array)) - .value(); + const Literal weight_literal = + LiteralUtil::CreateR4FromArray4D(*weight_array); XlaBuilder builder(TestName()); auto input = ConstantR4FromArray4D(&builder, *input_array); @@ -121,7 +124,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto expected_conv2 = ReferenceUtil::ConvArray4DGeneralDimensions( *input_array, *expected_conv1, {1, 1}, Padding::kValid, dim_nums); - ComputeAndCompareR4(&builder, *expected_conv2, {weight_data.get()}, + ComputeAndCompareR4(&builder, *expected_conv2, {&weight_literal}, ErrorSpec(0.001, 0.01)); } diff --git a/third_party/xla/xla/tests/convolution_test_1d.cc b/third_party/xla/xla/tests/convolution_test_1d.cc index 58d53efd2dc6a4..fe82387ec0924b 100644 --- a/third_party/xla/xla/tests/convolution_test_1d.cc +++ b/third_party/xla/xla/tests/convolution_test_1d.cc @@ -16,42 +16,36 @@ limitations under the License. // Tests of 1D convolution with trivial kernels and no special variations (like // strides and padding). -#include +#include +#include -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "xla/array2d.h" +#include "Eigen/Core" #include "xla/array3d.h" -#include "xla/array4d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/padding.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/reference_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class ConvolutionTest : public ClientLibraryTestBase { - protected: #if XLA_TEST_BACKEND_GPU - // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial - // convolution. So relax the absolute error threshold. - ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-3); +// XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial +// convolution. So relax the absolute error threshold. 
+constexpr ErrorSpec kErrorSpec(1e-2, 1e-3); #else - ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-3); +constexpr ErrorSpec kErrorSpec(1e-4, 1e-3); #endif -}; + +using ConvolutionTest = ClientLibraryTestRunnerMixin; #ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 using TestTypes = ::testing::Types; @@ -121,11 +115,8 @@ class Convolve1D1WindowTestBase auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature}).value(); - auto input_literal = client_->TransferToServer(input_r3).value(); - auto filter_literal = client_->TransferToServer(filter_r3).value(); - ComputeAndCompareLiteral(&builder, expected_r3, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompareLiteral(&builder, expected_r3, {&input_r3, &filter_r3}, + kErrorSpec); } }; @@ -202,7 +193,7 @@ INSTANTIATE_TEST_CASE_P( ); #endif -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { +TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); @@ -217,16 +208,11 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } template @@ -252,16 +238,11 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { Array3D expected({{{570.0f, 670.0f, 770.0f}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } }; // namespace @@ -269,8 +250,7 @@ TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes); TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); } // Basic test with LHS dilation (i.e. strided transposed convolution). 
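The dilation tests above and the LHS-dilation tests that follow exercise dilated convolutions; the hunks only show the literal plumbing, so here is a hedged sketch of what such a 1D dilated convolution call can look like with the builder API, using hypothetical shapes in the {batch, feature, spatial} layout these tests use:

// Sketch only: dimension numbers for rank-3 {batch, feature, spatial}
// operands, then a convolution whose filter taps are dilated by 2.
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.add_input_spatial_dimensions(2);
dnums.set_kernel_output_feature_dimension(0);
dnums.set_kernel_input_feature_dimension(1);
dnums.add_kernel_spatial_dimensions(2);
dnums.set_output_batch_dimension(0);
dnums.set_output_feature_dimension(1);
dnums.add_output_spatial_dimensions(2);
auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {1, 2, 5}), "input");
auto filter = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {1, 2, 2}), "filter");
ConvGeneralDilated(input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
                   /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, dnums);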
-XLA_TEST_F(ConvolutionTest, - Convolve1D_1x1x5_1x1x3_WithLHSDilation_FullPadding) { +TEST_F(ConvolutionTest, Convolve1D_1x1x5_1x1x3_WithLHSDilation_FullPadding) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 5}); @@ -289,19 +269,14 @@ XLA_TEST_F(ConvolutionTest, Array3D expected({{{34, 22, 56, 33, 78, 44, 100}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } -XLA_TEST_F(ConvolutionTest, Convolve1D_1x1x5_1x1x3_WithLHSDilation_NoPadding) { +TEST_F(ConvolutionTest, Convolve1D_1x1x5_1x1x3_WithLHSDilation_NoPadding) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 5}); @@ -319,20 +294,14 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x1x5_1x1x3_WithLHSDilation_NoPadding) { Array3D filter({{{10, 11, 12}}}); Array3D expected({{{12, 11, 34, 22, 56, 33, 78, 44, 100, 55, 50}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } -XLA_TEST_F(ConvolutionTest, - Convolve1D_1x1x5_1x1x3_WithLHSDilation_HalfPadding) { +TEST_F(ConvolutionTest, Convolve1D_1x1x5_1x1x3_WithLHSDilation_HalfPadding) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 5}); @@ -350,20 +319,15 @@ XLA_TEST_F(ConvolutionTest, Array3D filter({{{10, 11, 12}}}); Array3D expected({{{11, 34, 22, 56, 33, 78, 44, 100, 55}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } // Test multiple output channels. 
-XLA_TEST_F(ConvolutionTest, Convolve1D_1x1x5_2x1x3_WithLHSDilation) { +TEST_F(ConvolutionTest, Convolve1D_1x1x5_2x1x3_WithLHSDilation) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 5}); @@ -382,20 +346,15 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x1x5_2x1x3_WithLHSDilation) { Array3D expected( {{{34, 22, 56, 33, 78, 44, 100}, {68, 44, 112, 66, 156, 88, 200}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } // Test multiple input channels. -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x3_WithLHSDilation) { +TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x3_WithLHSDilation) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); @@ -414,20 +373,15 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x3_WithLHSDilation) { Array3D expected({{{730, 390, 870, 460, 1010, 530, 1150}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } // Batched version of the above test. -XLA_TEST_F(ConvolutionTest, Convolve1D_3x2x5_1x2x3_WithLHSDilation) { +TEST_F(ConvolutionTest, Convolve1D_3x2x5_1x2x3_WithLHSDilation) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {3, 2, 5}); @@ -451,20 +405,15 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_3x2x5_1x2x3_WithLHSDilation) { {{7300, 3900, 8700, 4600, 10100, 5300, 11500}}, {{1460, 780, 1740, 920, 2020, 1060, 2300}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } // Test all together: batched, multiple input and output channels. 
-XLA_TEST_F(ConvolutionTest, Convolve1D_3x2x5_2x2x3_WithLHSDilation) { +TEST_F(ConvolutionTest, Convolve1D_3x2x5_2x2x3_WithLHSDilation) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {3, 2, 5}); @@ -492,22 +441,17 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_3x2x5_2x2x3_WithLHSDilation) { {{1460, 780, 1740, 920, 2020, 1060, 2300}, {1606, 858, 1914, 1012, 2222, 1166, 2530}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } // Test LHS dilation (i.e. transposed convolution) and window strides at the // same time. That's probably never used in practice, but since the generic // algorithm covers it, we test it anyway with a simple case. -XLA_TEST_F(ConvolutionTest, Convolve1D_1x1x5_1x1x3_WithLHSDilationAndStrides) { +TEST_F(ConvolutionTest, Convolve1D_1x1x5_1x1x3_WithLHSDilationAndStrides) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 5}); @@ -526,19 +470,14 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x1x5_1x1x3_WithLHSDilationAndStrides) { Array3D expected({{{34, 56, 78, 100}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { +TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); @@ -557,16 +496,11 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } template @@ -593,16 +527,11 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { Array3D expected( {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .value(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .value(); + const Literal input_literal = LiteralUtil::CreateR3FromArray3D(input); + const Literal filter_literal = LiteralUtil::CreateR3FromArray3D(filter); ComputeAndCompareR3(&builder, expected, - 
{input_literal.get(), filter_literal.get()}, - error_spec_); + {&input_literal, &filter_literal}, kErrorSpec); } }; diff --git a/third_party/xla/xla/tests/convolution_variants_test.cc b/third_party/xla/xla/tests/convolution_variants_test.cc index 719e9b3d80e8bd..b7284a297715f3 100644 --- a/third_party/xla/xla/tests/convolution_variants_test.cc +++ b/third_party/xla/xla/tests/convolution_variants_test.cc @@ -17,51 +17,53 @@ limitations under the License. // in small sized data. #include +#include #include #include #include #include #include +#include "absl/types/span.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/padding.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class ConvolutionVariantsTest : public ClientLibraryTestBase { - protected: #if XLA_TEST_BACKEND_GPU - // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial - // convolution. So relax the absolute error threshold. - ErrorSpec error_spec_ = ErrorSpec(1e-1, 1e-5); +// XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial +// convolution. So relax the absolute error threshold. +ErrorSpec kErrorSpec(1e-1, 1e-5); #else - ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-2); +ErrorSpec kErrorSpec(1e-4, 1e-2); #endif - XlaOp ConvWithHighestPrecision(const XlaOp lhs, const XlaOp rhs, - absl::Span window_strides, - Padding padding) { - PrecisionConfig precision_config; - // Set the 2 operands to have the HIGHEST precision. - precision_config.add_operand_precision(PrecisionConfig::HIGHEST); - precision_config.add_operand_precision(PrecisionConfig::HIGHEST); - return Conv(lhs, rhs, window_strides, padding, /*feature_group_count=*/1, - /*batch_group_count=*/1, &precision_config); - } -}; +XlaOp ConvWithHighestPrecision(const XlaOp lhs, const XlaOp rhs, + absl::Span window_strides, + Padding padding) { + PrecisionConfig precision_config; + // Set the 2 operands to have the HIGHEST precision. 
+  precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
+  precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
+  return Conv(lhs, rhs, window_strides, padding, /*feature_group_count=*/1,
+              /*batch_group_count=*/1, &precision_config);
+}
+
+using ConvolutionVariantsTest = ClientLibraryTestRunnerMixin<HloTestBase>;
 
-XLA_TEST_F(ConvolutionVariantsTest, Minimal) {
+TEST_F(ConvolutionVariantsTest, Minimal) {
   XlaBuilder builder(TestName());
 
   const Array4D<float> input_array(1, 1, 1, 1, {2});
@@ -73,10 +75,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Minimal) {
   Conv(input, filter, {1, 1}, Padding::kValid);
 
   const Array4D<float> expected(1, 1, 1, 1, {6});
-  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+  ComputeAndCompareR4<float>(&builder, expected, {}, kErrorSpec);
 }
 
-XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
+TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
   XlaBuilder builder(TestName());
 
   const Array4D<float> input_array(5, 1, 1, 1, {1, 2, 3, 4, 5});
@@ -88,10 +90,10 @@ XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
   Conv(input, filter, {1, 1}, Padding::kValid);
 
   const Array4D<float> expected(5, 1, 1, 1, {2, 4, 6, 8, 10});
-  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+  ComputeAndCompareR4<float>(&builder, expected, {}, kErrorSpec);
 }
 
-XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) {
+TEST_F(ConvolutionVariantsTest, Flat1x1) {
   XlaBuilder builder(TestName());
 
   Array4D<float> input_array(2, 1, 3, 4);
@@ -105,10 +107,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) {
   Array4D<float> expected(2, 1, 3, 4);
   expected.FillWithMultiples(2.3);
-  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+  ComputeAndCompareR4<float>(&builder, expected, {}, kErrorSpec);
 }
 
-XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) {
+TEST_F(ConvolutionVariantsTest, Deep1x1) {
   XlaBuilder builder(TestName());
 
   Array4D<float> input_array(1, 2, 1, 1, {10, 1});
@@ -120,10 +122,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) {
   Conv(input, filter, {1, 1}, Padding::kValid);
 
   Array4D<float> expected(1, 3, 1, 1, {12, 34, 56});
-  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+  ComputeAndCompareR4<float>(&builder, expected, {}, kErrorSpec);
 }
 
-XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
+TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
   XlaBuilder builder(TestName());
 
   Array4D<float> input_array(1, 1, 1, 2, {1, 2});
@@ -135,10 +137,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
   Conv(input, filter, {1, 1}, Padding::kValid);
 
   Array4D<float> expected(1, 1, 1, 1, {12});
-  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+  ComputeAndCompareR4<float>(&builder, expected, {}, kErrorSpec);
 }
 
-XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
+TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
   XlaBuilder builder(TestName());
 
   Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
@@ -150,10 +152,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
   Conv(input, filter, {1, 1}, Padding::kValid);
 
   Array4D<float> expected(1, 1, 1, 2, {12, 23});
-  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+  ComputeAndCompareR4<float>(&builder, expected, {}, kErrorSpec);
 }
 
-XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
+TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
   XlaBuilder builder(TestName());
 
   Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
@@ -165,10 +167,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
   Conv(input, filter, {1, 1}, Padding::kValid);
 
   Array4D<float> expected(1, 1, 2, 1, {12, 34});
-  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+  ComputeAndCompareR4<float>(&builder, expected, {}, kErrorSpec);
 }
 
-XLA_TEST_F(ConvolutionVariantsTest,
Filter2x1in2x2) { +TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); @@ -180,10 +182,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {13, 24}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { +TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); @@ -195,10 +197,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 1, {1234}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { +TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { XlaBuilder builder(TestName()); Array4D input_array( @@ -216,10 +218,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { 2, 2, 2, 2, {167, 1278, 3490, 4500, 0.0167, 0.1278, 0.3490, 0.4500, // plane 0 334, 2556, 6980, 9000, 0.0334, 0.2556, 0.6980, 0.9000}); // plane 1 - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { +TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); @@ -231,10 +233,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 2, {10, 30}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { +TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); @@ -246,10 +248,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 3, {10, 30, 50}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { +TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); @@ -261,10 +263,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 1, {123}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { +TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); @@ -276,10 +278,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 2, {123, 345}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { +TEST_F(ConvolutionVariantsTest, 
Filter1x1stride2x2in3x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); @@ -291,10 +293,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { Conv(input, filter, {2, 2}, Padding::kValid); Array4D expected(1, 1, 2, 2, {10, 30, 70, 90}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { +TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 1, {1}); @@ -306,10 +308,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 1, {20}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { +TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); @@ -321,10 +323,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 3, {123, 1230, 12300}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { +TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); @@ -339,10 +341,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 2, 2, {104, 230, 2300, 10400}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { +TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 2, {1, 2, 3, 4}); @@ -354,10 +356,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 2, {13, 24}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { +TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); @@ -369,10 +371,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 2, 2, {216, 276, 396, 456}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { +TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); @@ -384,10 +386,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {33, 53}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { +TEST_F(ConvolutionVariantsTest, 
Filter2x1x8x8Input1x1x8x8) { XlaBuilder builder(TestName()); std::vector input_data(64); @@ -404,10 +406,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 2, 1, 1, {2016, 4032}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { +TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { XlaBuilder builder(TestName()); std::vector input_data(16 * 1 * 1 * 1); @@ -425,10 +427,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { std::vector expected_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; Array4D expected(16, 1, 1, 1, expected_data); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { +TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { XlaBuilder builder(TestName()); constexpr int bs = 16; @@ -456,10 +458,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { expected_data[i] = 10 * (i + 1); } Array4D expected(bs, 1, 1, 1, expected_data); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { +TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { XlaBuilder builder(TestName()); constexpr int kx = 2; @@ -488,10 +490,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { 43, }; Array4D expected(bs, 1, 1, 1, expected_data); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { +TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { XlaBuilder builder(TestName()); Array4D input_array(16, 1, 8, 8); @@ -516,10 +518,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { 36304, 38384, 40464, 42544, 44624, 46704, 48784, 50864, }; Array4D expected(16, 1, 1, 1, expected_data); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { +TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { XlaBuilder builder(TestName()); std::vector input_data(2 * 8 * 8); @@ -542,10 +544,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 2, 1, 1, {14240, 30496}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { +TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { XlaBuilder builder(TestName()); std::vector input_data(2 * 2 * 8 * 8); @@ -568,10 +570,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(2, 2, 1, 1, {14240, 30496, 38816, 87840}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { +TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { XlaBuilder builder(TestName()); std::vector input_data(32 * 
2 * 8 * 8); @@ -608,10 +610,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { Array4D expected(32, 2, 1, 1, expected_data); // The output elements can be larger than 1e+5, making the absolute error // large sometimes. So, we focus on relative errors for this test case. - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { +TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { XlaBuilder builder(TestName()); Array4D input_array(16, 16, 1, 1); @@ -634,10 +636,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { } } - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { +TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 4 * 6); @@ -653,10 +655,10 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 2, 2, {3924, 4257, 5922, 6255}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { +TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); @@ -672,10 +674,10 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 8, {10, 2, 20, 3, 30, 4, 40, 5}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { +TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 3 * 4); @@ -698,10 +700,10 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { {204, 40, 406, 60, 608, // 1518, 180, 1821, 210, 2124, // 4146, 460, 4651, 510, 5156}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { +TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); @@ -717,10 +719,10 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 2, {23, 34}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { +TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); @@ -736,10 +738,10 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 5, {23, 34, 45, 50, 0}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { +TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); @@ 
-755,10 +757,10 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 5, {0, 1, 12, 23, 34}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { +TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); @@ -781,9 +783,9 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { // [10, 1] --dilate-> [10, 0, 1] Array4D expected(1, 1, 1, 12, {0, 1, 0, 12, 0, 23, 0, 34, 0, 45, 0, 50}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { +TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); @@ -805,10 +807,10 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { // filter: // [10, 1] --dilate-> [10, 0, 1] Array4D expected(1, 1, 1, 2, {0, 34}); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { +TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { constexpr int bs = 1; constexpr int iz = 1; constexpr int oz = 2; @@ -838,10 +840,10 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); - ComputeAndCompareR4(&builder, *expected, {}, error_spec_); + ComputeAndCompareR4(&builder, *expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { +TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { constexpr int bs = 1; constexpr int iz = 16; constexpr int oz = 1; @@ -871,10 +873,10 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); - ComputeAndCompareR4(&builder, *expected, {}, error_spec_); + ComputeAndCompareR4(&builder, *expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { +TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { constexpr int bs = 16; constexpr int iz = 16; constexpr int oz = 1; @@ -904,10 +906,10 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); - ComputeAndCompareR4(&builder, *expected, {}, error_spec_); + ComputeAndCompareR4(&builder, *expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { +TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { constexpr int bs = 16; constexpr int iz = 16; constexpr int oz = 16; @@ -937,11 +939,10 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); - ComputeAndCompareR4(&builder, *expected, {}, error_spec_); + ComputeAndCompareR4(&builder, *expected, {}, kErrorSpec); 
} -XLA_TEST_F(ConvolutionVariantsTest, - RandomData_Input16x16x16x16_Filter16x16x16x16) { +TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) { constexpr int bs = 16; constexpr int iz = 16; constexpr int oz = 16; @@ -971,10 +972,10 @@ XLA_TEST_F(ConvolutionVariantsTest, std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); - ComputeAndCompareR4(&builder, *expected, {}, error_spec_); + ComputeAndCompareR4(&builder, *expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { +TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 1); @@ -1015,10 +1016,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { 0, 0, 0, 0, 0, 0, 0 // }; Array4D expected(1, 5, 7, 1, expected_data); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { +TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 1); @@ -1059,10 +1060,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { 0, 0, 0, 0, 0, 0, 0, 0 // }; Array4D expected(1, 5, 8, 1, expected_data); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { +TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 1); @@ -1100,10 +1101,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { 8, 10, 12, }; Array4D expected(1, 2, 3, 1, expected_data); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { +TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 2); @@ -1145,7 +1146,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { 82, 105, 128, // }; Array4D expected(1, 2, 3, 3, expected_data); - ComputeAndCompareR4(&builder, expected, {}, error_spec_); + ComputeAndCompareR4(&builder, expected, {}, kErrorSpec); } // Regression test for b/32034796. 
@@ -1154,8 +1155,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // Conv([1,2,3], Reverse([5,6]), padding_low=1) // into // BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1) -XLA_TEST_F(ConvolutionVariantsTest, - BackwardInputLowPaddingLessThanHighPadding) { +TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { XlaBuilder builder(TestName()); auto gradients = ConstantR4FromArray4D( @@ -1166,15 +1166,14 @@ XLA_TEST_F(ConvolutionVariantsTest, ConvWithGeneralPadding(gradients, mirrored_weights, /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {1, 0}}); - ComputeAndCompareR4(&builder, {{{{5, 16, 27}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{5, 16, 27}}}}, {}, kErrorSpec); } // XLA:GPU fuses // Conv([1], Reverse([1,10,100]), padding_high=3, base_dilation=3) // into // BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1)) -XLA_TEST_F(ConvolutionVariantsTest, - BackwardInputLowPaddingGreaterThanHighPadding) { +TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { XlaBuilder builder(TestName()); auto gradients = ConstantR4FromArray4D( @@ -1187,14 +1186,14 @@ XLA_TEST_F(ConvolutionVariantsTest, /*padding=*/{{0, 0}, {0, 3}}, /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); - ComputeAndCompareR4(&builder, {{{{100, 0}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{100, 0}}}}, {}, kErrorSpec); } // XLA:GPU fuses // Conv([1], Reverse([1,10,100]), padding=(1,1)) // into // BackwardInputConv([1], [1,10,100], padding=(1,1)) -XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { +TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { XlaBuilder builder(TestName()); auto gradients = ConstantR4FromArray4D( @@ -1205,7 +1204,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { ConvWithGeneralPadding(gradients, mirrored_weights, /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {1, 1}}); - ComputeAndCompareR4(&builder, {{{{10}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{10}}}}, {}, kErrorSpec); } // HLO pattern @@ -1215,7 +1214,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { // // However, XLA:GPU doesn't actually fuse it because PadInsertion doesn't // support negative padding on backward convolution yet (b/32744257). 
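 //
 // As a concrete check of that forward/backward equivalence:
 // Conv([1,2,3], Reverse([5,6]), padding_low=1) pads the input to
 // [0, 1, 2, 3] and slides the reversed kernel [6, 5] across it, giving
 // [0*6 + 1*5, 1*6 + 2*5, 2*6 + 3*5] = [5, 16, 27], exactly the expected
 // values in BackwardInputLowPaddingLessThanHighPadding above.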
-XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { +TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { XlaBuilder builder(TestName()); auto gradients = ConstantR4FromArray4D( @@ -1227,11 +1226,10 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {0, 2}}); - ComputeAndCompareR4(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{12, 23, 30, 0}}}}, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, - BackwardFilterLowPaddingLessThanHighPadding) { +TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) { XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0 @@ -1251,11 +1249,11 @@ XLA_TEST_F(ConvolutionVariantsTest, XlaBuilder::CreateDefaultConvDimensionNumbers()); Transpose(forward_conv, {0, 1, 2, 3}); - ComputeAndCompareR4(&builder, {{{{24, 130, 240}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{24, 130, 240}}}}, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, - BackwardFilterLowPaddingGreaterThanHighPadding) { +TEST_F(ConvolutionVariantsTest, + BackwardFilterLowPaddingGreaterThanHighPadding) { XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4 @@ -1277,10 +1275,10 @@ XLA_TEST_F(ConvolutionVariantsTest, XlaBuilder::CreateDefaultConvDimensionNumbers()); Transpose(forward_conv, {0, 1, 2, 3}); - ComputeAndCompareR4(&builder, {{{{13, 24}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{13, 24}}}}, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { +TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4,0 @@ -1304,10 +1302,10 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { XlaBuilder::CreateDefaultConvDimensionNumbers()); Transpose(forward_conv, {0, 1, 2, 3}); - ComputeAndCompareR4(&builder, {{{{13, 24, 130}}}}, {}, error_spec_); + ComputeAndCompareR4(&builder, {{{{13, 24, 130}}}}, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { +TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { XlaBuilder builder(TestName()); auto gradients = ConstantR3FromArray3D( @@ -1318,10 +1316,10 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { ConvWithGeneralPadding(gradients, mirrored_weights, /*window_strides=*/{1}, /*padding=*/{{1, 1}}); - ComputeAndCompareR3(&builder, {{{10}}}, {}, error_spec_); + ComputeAndCompareR3(&builder, {{{10}}}, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { +TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { XlaBuilder builder(TestName()); auto activations = @@ -1337,10 +1335,10 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { /*num_spatial_dims=*/1)); Transpose(forward_conv, {0, 1, 2}); - ComputeAndCompareR3(&builder, {{{13, 24, 130}}}, {}, error_spec_); + ComputeAndCompareR3(&builder, {{{13, 24, 130}}}, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { +TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { XlaBuilder builder(TestName()); auto gradients_flat = LiteralUtil::CreateR1({1}); @@ -1358,10 +1356,10 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { ConvWithGeneralPadding(gradients, mirrored_weights, /*window_strides=*/{1, 1, 1}, /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); - 
ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, kErrorSpec); } -XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { +TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder builder(TestName()); auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); @@ -1383,7 +1381,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder::CreateDefaultConvDimensionNumbers( /*num_spatial_dims=*/3)); Transpose(forward_conv, {0, 1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, kErrorSpec); } } // namespace diff --git a/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc b/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc index af903c53dabe53..d5ef5bb8de46e7 100644 --- a/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc +++ b/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc @@ -16,35 +16,50 @@ limitations under the License. #include #include +#include #include -#include +#include #include +#include #include +#include #define EIGEN_USE_THREADS +#include "absl/log/log.h" #include "absl/types/span.h" +#include "benchmark/benchmark.h" #include "unsupported/Eigen/CXX11/Tensor" #include "xla/array2d.h" #include "xla/client/client_library.h" +#include "xla/client/executable_build_options.h" +#include "xla/comparison_util.h" +#include "xla/error_spec.h" +#include "xla/executable_run_options.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/platform_util.h" +#include "xla/service/shaped_buffer.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/threadpool.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/test_benchmark.h" namespace xla { namespace { @@ -187,7 +202,7 @@ bool CpuGpuFusionTest::ComputeElementwiseAnswerCompare( } } -XLA_TEST_F(CpuGpuFusionTest, Test) { +TEST_F(CpuGpuFusionTest, Test) { // test expression: // slice(select({{T, F, T}, {F, T, F}}, // concat(transpose({{1.0}, {2.0}, {3.0}} + @@ -240,7 +255,7 @@ XLA_TEST_F(CpuGpuFusionTest, Test) { } // Test whether we emit appropriate code for parameters of fusion instructions. -XLA_TEST_F(CpuGpuFusionTest, Parameter) { +TEST_F(CpuGpuFusionTest, Parameter) { // Build a computation and fuse part of it so the fusion instruction has an // operand parameter. auto builder = HloComputation::Builder(TestName()); @@ -265,7 +280,7 @@ XLA_TEST_F(CpuGpuFusionTest, Parameter) { ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } -XLA_TEST_F(CpuGpuFusionTest, RandomizedParallelPartition) { +TEST_F(CpuGpuFusionTest, RandomizedParallelPartition) { // Tests parallel partitioning of a fusion instruction. 
// Create shape with random outer dimension size to generate random parallel // partition counts for each test run. @@ -301,7 +316,7 @@ XLA_TEST_F(CpuGpuFusionTest, RandomizedParallelPartition) { } } -XLA_TEST_F(CpuGpuFusionTest, BroadcastIntoBinaryOp) { +TEST_F(CpuGpuFusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( @@ -325,7 +340,7 @@ XLA_TEST_F(CpuGpuFusionTest, BroadcastIntoBinaryOp) { ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } -XLA_TEST_F(CpuGpuFusionTest, ReshapeToScalar) { +TEST_F(CpuGpuFusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto single_element_array = builder.AddInstruction( @@ -340,7 +355,7 @@ XLA_TEST_F(CpuGpuFusionTest, ReshapeToScalar) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, Reshape_3by2_1by2by3) { +TEST_F(CpuGpuFusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -355,7 +370,7 @@ XLA_TEST_F(CpuGpuFusionTest, Reshape_3by2_1by2by3) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, Reshape_1by2by3_3by2) { +TEST_F(CpuGpuFusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -370,7 +385,7 @@ XLA_TEST_F(CpuGpuFusionTest, Reshape_1by2by3_3by2) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, Reshape_1by1by1_) { +TEST_F(CpuGpuFusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -385,7 +400,7 @@ XLA_TEST_F(CpuGpuFusionTest, Reshape_1by1by1_) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, Reshape__1by1by1) { +TEST_F(CpuGpuFusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -400,7 +415,7 @@ XLA_TEST_F(CpuGpuFusionTest, Reshape__1by1by1) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, Reshape__) { +TEST_F(CpuGpuFusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -415,7 +430,7 @@ XLA_TEST_F(CpuGpuFusionTest, Reshape__) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, Reshape_3by3_3by3) { +TEST_F(CpuGpuFusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -430,7 +445,7 @@ XLA_TEST_F(CpuGpuFusionTest, Reshape_3by3_3by3) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, Transpose_2by3) { +TEST_F(CpuGpuFusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -445,7 +460,7 @@ XLA_TEST_F(CpuGpuFusionTest, Transpose_2by3) { ExecuteAndTransfer(std::move(hlo_module), 
{}))); } -XLA_TEST_F(CpuGpuFusionTest, Transpose_3by3) { +TEST_F(CpuGpuFusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -460,7 +475,7 @@ XLA_TEST_F(CpuGpuFusionTest, Transpose_3by3) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, Reverse) { +TEST_F(CpuGpuFusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -476,7 +491,7 @@ XLA_TEST_F(CpuGpuFusionTest, Reverse) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, ReverseNegate) { +TEST_F(CpuGpuFusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -494,7 +509,7 @@ XLA_TEST_F(CpuGpuFusionTest, ReverseNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, BroadcastNegate) { +TEST_F(CpuGpuFusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -512,7 +527,7 @@ XLA_TEST_F(CpuGpuFusionTest, BroadcastNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, SliceNegate) { +TEST_F(CpuGpuFusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -530,7 +545,7 @@ XLA_TEST_F(CpuGpuFusionTest, SliceNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, DynamicSliceNegate) { +TEST_F(CpuGpuFusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -552,7 +567,7 @@ XLA_TEST_F(CpuGpuFusionTest, DynamicSliceNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, ReshapeNegate) { +TEST_F(CpuGpuFusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -570,7 +585,7 @@ XLA_TEST_F(CpuGpuFusionTest, ReshapeNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, TransposeNegate) { +TEST_F(CpuGpuFusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -599,7 +614,7 @@ std::unique_ptr MakeReduceTestComputation() { return builder.Build(); } -XLA_TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(Reduce)) { +TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(Reduce)) { auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( @@ -618,7 +633,7 @@ XLA_TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(Reduce)) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, ReduceImplicitBroadcast) { +TEST_F(CpuGpuFusionTest, ReduceImplicitBroadcast) { auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -640,7 +655,7 @@ 
XLA_TEST_F(CpuGpuFusionTest, ReduceImplicitBroadcast) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(ReduceWindow)) { +TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -693,7 +708,7 @@ XLA_TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(ReduceWindow)) { // When a constant (or other op) which has multiple users is imported // into a fusion, it should remain shared, rather than being duplicated // within the fusion. -XLA_TEST_F(CpuGpuFusionTest, SharedConstant) { +TEST_F(CpuGpuFusionTest, SharedConstant) { auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -728,7 +743,7 @@ XLA_TEST_F(CpuGpuFusionTest, SharedConstant) { // Test that fusion can handle elementwise ops with more than one user. This // test case needs deduplication to avoid exponential compile time. -XLA_TEST_F(CpuGpuFusionTest, Fibonacci) { +TEST_F(CpuGpuFusionTest, Fibonacci) { const char* const kModuleStr = R"( HloModule fibonacci @@ -777,65 +792,66 @@ XLA_TEST_F(CpuGpuFusionTest, Fibonacci) { RunAndCompare(std::move(module), {&literal0, &literal1}, std::nullopt)); } -XLA_TEST_F(CpuGpuFusionTest, Add2D) { +TEST_F(CpuGpuFusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } -XLA_TEST_F(CpuGpuFusionTest, Subtract2D) { +TEST_F(CpuGpuFusionTest, Subtract2D) { TestElementwise2D(HloOpcode::kSubtract); } -XLA_TEST_F(CpuGpuFusionTest, Multiply2D) { +TEST_F(CpuGpuFusionTest, Multiply2D) { TestElementwise2D(HloOpcode::kMultiply); } -XLA_TEST_F(CpuGpuFusionTest, Divide2D) { +TEST_F(CpuGpuFusionTest, Divide2D) { TestElementwise2D(HloOpcode::kDivide); } -XLA_TEST_F(CpuGpuFusionTest, Power2D) { +TEST_F(CpuGpuFusionTest, Power2D) { TestElementwise2D(HloOpcode::kPower); } -XLA_TEST_F(CpuGpuFusionTest, Minimum2D) { +TEST_F(CpuGpuFusionTest, Minimum2D) { TestElementwise2D(HloOpcode::kMinimum); } -XLA_TEST_F(CpuGpuFusionTest, Maximum2D) { +TEST_F(CpuGpuFusionTest, Maximum2D) { TestElementwise2D(HloOpcode::kMaximum); } -XLA_TEST_F(CpuGpuFusionTest, Equal2D) { +TEST_F(CpuGpuFusionTest, Equal2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kEq); } -XLA_TEST_F(CpuGpuFusionTest, Inequal2D) { +TEST_F(CpuGpuFusionTest, Inequal2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kNe); } -XLA_TEST_F(CpuGpuFusionTest, Greater2D) { +TEST_F(CpuGpuFusionTest, Greater2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGt); } -XLA_TEST_F(CpuGpuFusionTest, Lesser2D) { +TEST_F(CpuGpuFusionTest, Lesser2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLt); } -XLA_TEST_F(CpuGpuFusionTest, GreaterOrEqual2D) { +TEST_F(CpuGpuFusionTest, GreaterOrEqual2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGe); } -XLA_TEST_F(CpuGpuFusionTest, LesserOrEqual2D) { +TEST_F(CpuGpuFusionTest, LesserOrEqual2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLe); } -XLA_TEST_F(CpuGpuFusionTest, Clamp2D) { +TEST_F(CpuGpuFusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } -class FusionClientLibraryTest : public ClientLibraryTestBase {}; +class FusionClientLibraryTest + : public ClientLibraryTestRunnerMixin {}; -XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { +TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { // On the GPU backend, it's possible to have too many transposes within one 
 // fusion, causing the kernel to run out of shared memory and thus not compile.
 // We want to check that doesn't happen.
@@ -863,17 +879,21 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
   Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
       LayoutUtil::MakeLayout({1, 0}));
 
-  XlaOp p0 = AddParam(l1, &b);
+  std::vector<const Literal*> params;
+  XlaOp p0 = Parameter(&b, 0, l1.shape(), "");
+  params.push_back(&l1);
   XlaOp sum = p0;
   for (int i = 1; i < kNumParams; ++i) {
-    auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
+    const Literal& l = i % 2 == 0 ? l1 : l2;
+    XlaOp pN = Parameter(&b, i, l.shape(), "");
+    params.push_back(&l);
     sum = sum + p0 * pN * pN;
   }
 
-  ComputeAndCompare(&b, {});
+  ComputeAndCompare(&b, params);
 }
 
-XLA_TEST_F(CpuGpuFusionTest, TransposeDiamondWithNonTrivialBranch) {
+TEST_F(CpuGpuFusionTest, TransposeDiamondWithNonTrivialBranch) {
   const char* hlo = R"(
 HloModule module
 
diff --git a/third_party/xla/xla/tests/dynamic_ops_test.cc b/third_party/xla/xla/tests/dynamic_ops_test.cc
index 747d073fc5dd83..5ca6bba5aeb48e 100644
--- a/third_party/xla/xla/tests/dynamic_ops_test.cc
+++ b/third_party/xla/xla/tests/dynamic_ops_test.cc
@@ -13,33 +13,48 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include
+#include
+#include
+#include
+#include
+#include
 #include
 
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "absl/types/span.h"
+#include "benchmark/benchmark.h"
 #include "xla/array2d.h"
+#include "xla/array3d.h"
 #include "xla/client/client_library.h"
+#include "xla/client/executable_build_options.h"
 #include "xla/client/local_client.h"
+#include "xla/error_spec.h"
+#include "xla/executable_run_options.h"
 #include "xla/hlo/builder/xla_builder.h"
 #include "xla/hlo/testlib/test_helpers.h"
-#include "xla/reference_util.h"
-#include "xla/service/local_service.h"
+#include "xla/literal_util.h"
 #include "xla/service/platform_util.h"
+#include "xla/service/service.h"
 #include "xla/service/shaped_buffer.h"
 #include "xla/service/transfer_manager.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/stream_executor.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/tests/client_library_test_base.h"
+#include "xla/tests/client_library_test_runner_mixin.h"
 #include "xla/tests/hlo_test_base.h"
-#include "xla/tests/literal_test_util.h"
 #include "xla/tests/test_macros.h"
-#include "tsl/platform/test.h"
-#include "tsl/platform/test_benchmark.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/tsl/platform/test.h"
+#include "xla/tsl/platform/test_benchmark.h"
+#include "xla/types.h"
 
 namespace xla {
 namespace {
 
-class DynamicSliceTest : public ClientLibraryTestBase {
+class DynamicSliceTest : public ClientLibraryTestRunnerMixin<HloTestBase> {
  protected:
   template <typename IndexT, typename DataT>
   void TestR1() {
@@ -136,13 +151,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
     XlaBuilder builder(TestName());
     // Initialize and transfer dynamic slice start indices parameter.
     XlaOp starts;
-    std::unique_ptr<GlobalData> start_data = CreateR0Parameter<IndexT>(
+    const Literal start_data = CreateR0Parameter<IndexT>(
         slice_starts[0], 0, "slice_starts", &builder, &starts);
     // Build dynamic slice computation.
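     // DynamicSlice clamps each start index into [0, dimension_size -
     // slice_size], so even the OOB test variants below read a well-defined,
     // in-bounds window.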
auto input = ConstantLiteral(&builder, input_values); DynamicSlice(input, absl::Span({starts}), slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {&start_data}); } template @@ -162,7 +177,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. std::vector starts(2); - std::vector> start_data(2); + std::vector start_data(2); for (int i = 0; i < 2; ++i) { start_data[i] = CreateR0Parameter( slice_starts[i], i, "slice_starts", &builder, &starts[i]); @@ -172,11 +187,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { auto input = ConstantLiteral(&builder, input_values); DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - std::vector argument_ptrs; + std::vector argument_ptrs; absl::c_transform(start_data, std::back_inserter(argument_ptrs), - [](const std::unique_ptr& argument) { - return argument.get(); - }); + [](const Literal& argument) { return &argument; }); ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } @@ -197,7 +210,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. std::vector starts(3); - std::vector> start_data(3); + std::vector start_data(3); for (int i = 0; i < 3; ++i) { start_data[i] = CreateR0Parameter( slice_starts[i], i, "slice_starts", &builder, &starts[i]); @@ -206,48 +219,46 @@ class DynamicSliceTest : public ClientLibraryTestBase { auto input = ConstantLiteral(&builder, input_values); DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. 
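     // The Literals in start_data own the parameter values; the transform
     // below only collects their addresses, which remain valid for the
     // duration of the ComputeAndCompareLiteral call.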
- std::vector argument_ptrs; + std::vector argument_ptrs; absl::c_transform(start_data, std::back_inserter(argument_ptrs), - [](const std::unique_ptr& argument) { - return argument.get(); - }); + [](const Literal& argument) { return &argument; }); ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } }; -XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } -XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, UInt32R1OOB) { +TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } +TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } +TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } +TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } +TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } +TEST_F(DynamicSliceTest, UInt32R1OOB) { RunR1({0, 1, 2, 3, 4}, {2147483648u}, {2}, {3, 4}); } -XLA_TEST_F(DynamicSliceTest, UInt8R1) { +TEST_F(DynamicSliceTest, UInt8R1) { std::vector data(129); absl::c_iota(data, 0); RunR1(data, {128}, {1}, {128}); } -XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } -XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, UInt32R2OOB) { +TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } +TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } +TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } +TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } +TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } +TEST_F(DynamicSliceTest, UInt32R2OOB) { RunR2({{0, 1}, {2, 3}}, {2147483648u, 0}, {1, 1}, {{2}}); } -XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } -XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, UInt32R3OOB) { +TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } +TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } +TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } +TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } +TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } +TEST_F(DynamicSliceTest, UInt32R3OOB) { RunR3({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, {2147483648u, 0, 2147483648u}, {1, 1, 1}, {{{5}}}); } -XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { +TEST_F(DynamicSliceTest, Int32R1Pred) { // Slice at dimension start. RunR1({true, false, false, true, false, true, true, false}, {0}, {5}, {true, false, false, true, false}); @@ -262,7 +273,7 @@ XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { {2}, {0}, {}); } -XLA_TEST_F(DynamicSliceTest, Int32R2Pred) { +TEST_F(DynamicSliceTest, Int32R2Pred) { // Slice at dimension start. 
RunR2( {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0}, @@ -285,7 +296,7 @@ XLA_TEST_F(DynamicSliceTest, Int32R2Pred) { {0, 2}, Array2D(0, 2)); } -XLA_TEST_F(DynamicSliceTest, Int32R3Pred) { +TEST_F(DynamicSliceTest, Int32R3Pred) { // R3 Shape: [2, 3, 2] // clang-format off @@ -306,14 +317,14 @@ XLA_TEST_F(DynamicSliceTest, Int32R3Pred) { // clang-format on } -class DynamicUpdateSliceTest : public ClientLibraryTestBase { +class DynamicUpdateSliceTest + : public ClientLibraryTestRunnerMixin { protected: template void TestR0() { // Disable algebraic simplifier, otherwise the op will be replaced by a // constant. - execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( - "algsimp"); + mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); RunR0(0, 123, {}, 123); } @@ -423,14 +434,14 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. XlaOp starts; - std::unique_ptr start_data = CreateR0Parameter( + const Literal start_data = CreateR0Parameter( slice_starts[0], 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); DynamicUpdateSlice(input, update, absl::Span({starts})); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {&start_data}); } template @@ -454,7 +465,7 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. std::vector starts(2); - std::vector> start_data(2); + std::vector start_data(2); for (int i = 0; i < 2; ++i) { start_data[i] = CreateR0Parameter( slice_starts[i], i, "slice_starts", &builder, &starts[i]); @@ -464,11 +475,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { auto update = ConstantLiteral(&builder, update_values); DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - std::vector argument_ptrs; + std::vector argument_ptrs; absl::c_transform(start_data, std::back_inserter(argument_ptrs), - [](const std::unique_ptr& argument) { - return argument.get(); - }); + [](const Literal& argument) { return &argument; }); ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } @@ -493,7 +502,7 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. std::vector starts(3); - std::vector> start_data(3); + std::vector start_data(3); for (int i = 0; i < 3; ++i) { start_data[i] = CreateR0Parameter( slice_starts[i], i, "slice_starts", &builder, &starts[i]); @@ -504,11 +513,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { auto update = ConstantLiteral(&builder, update_values); DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. 
- std::vector argument_ptrs; + std::vector argument_ptrs; absl::c_transform(start_data, std::back_inserter(argument_ptrs), - [](const std::unique_ptr& argument) { - return argument.get(); - }); + [](const Literal& argument) { return &argument; }); ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } @@ -547,11 +554,11 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer input parameter. XlaOp input; - std::unique_ptr input_data = + const Literal input_data = CreateR3Parameter(input_values, 0, "input_values", &builder, &input); // Initialize and transfer update parameter. XlaOp update; - std::unique_ptr update_data = CreateR3Parameter( + const Literal update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); auto constant_index = ConstantR0(&builder, index); auto zero = ConstantR0(&builder, 0); @@ -559,8 +566,7 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { // Run computation and compare against expected values. ComputeAndCompareR3(&builder, expected_values, - {input_data.get(), update_data.get()}, - ErrorSpec(0.000001)); + {&input_data, &update_data}, ErrorSpec(0.000001)); } template @@ -570,20 +576,20 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } }; -XLA_TEST_F(DynamicUpdateSliceTest, Int32R0BF16) { TestR0(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0(); } +TEST_F(DynamicUpdateSliceTest, Int32R0BF16) { TestR0(); } +TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0(); } +TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0(); } +TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt32R1OOB) { +TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1(); } +TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } +TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } +TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } +TEST_F(DynamicUpdateSliceTest, UInt32R1OOB) { RunR1({0, 1, 2, 3, 4}, {5, 6}, {2147483648u}, {0, 1, 2, 5, 6}); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt8R1) { +TEST_F(DynamicUpdateSliceTest, UInt8R1) { std::vector data(129); absl::c_iota(data, 0); std::vector expected = data; @@ -591,33 +597,31 @@ XLA_TEST_F(DynamicUpdateSliceTest, UInt8R1) { RunR1(data, {-1}, {128}, expected); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt32R2OOB) { +TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2(); } +TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } +TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } +TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } +TEST_F(DynamicUpdateSliceTest, UInt32R2OOB) { RunR2({{0, 1}, {2, 3}}, {{4}}, {2147483648u, 0}, {{0, 1}, {4, 3}}); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) 
{ TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt32R3OOB) { +TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3(); } +TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } +TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } +TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } +TEST_F(DynamicUpdateSliceTest, UInt32R3OOB) { RunR3({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, {{{8}}}, {2147483648u, 0, 2147483648u}, {{{0, 1}, {2, 3}}, {{4, 8}, {6, 7}}}); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { - TestOOB(); -} -XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB(); } +TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB(); } +TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } +TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB(); } +TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { +TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { // Slice at dimension start. RunR1({false, false, true, true, false, true, true, false}, {true, true, false}, {0}, @@ -636,7 +640,7 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { {false, false, true, true, false, true, true, false}); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R2Pred) { +TEST_F(DynamicUpdateSliceTest, Int32R2Pred) { // Slice at dimension start. RunR2( {{false, true, false}, {true, false, true}, {false, true, true}}, @@ -658,7 +662,7 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R2Pred) { {2, 1}, {{false, true, false}, {true, false, true}, {false, true, true}}); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { +TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { // R3 Shape: [2, 3, 2] // Slice at dimension start. RunR3( @@ -678,77 +682,77 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { // Tests for simple R3 case where the update is contiguous (i.e. the minor // two dimensions are not sliced). -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { +TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { +TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { +TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { +TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) { +TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) { // Multiple element, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) { +TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) { // Multiple element, index out of bounds. 
  std::vector<int32_t> operand_shape({4, 5, 2});
   RunR3Contiguous<bfloat16>(operand_shape, /*index=*/3, /*size=*/2);
 }
 
-XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) {
+TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) {
   // Multiple element, update size larger than operand.
   std::vector<int32_t> operand_shape({4, 5, 2});
   RunR3Contiguous<float>(operand_shape, /*index=*/5, /*size=*/2);
 }
 
-XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLargeBF16) {
+TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLargeBF16) {
   // Multiple element, update size larger than operand.
   std::vector<int32_t> operand_shape({4, 5, 2});
   RunR3Contiguous<bfloat16>(operand_shape, /*index=*/5, /*size=*/2);
 }
 
-XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
+TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
   std::vector<int32_t> operand_shape({3, 123, 247});
   RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
 }
 
-XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnalignedBF16) {
+TEST_F(DynamicUpdateSliceTest, R3ContiguousUnalignedBF16) {
   std::vector<int32_t> operand_shape({3, 123, 247});
   RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
 }
 
-XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousLarger) {
+TEST_F(DynamicUpdateSliceTest, R3ContiguousLarger) {
   std::vector<int32_t> operand_shape({32, 128, 1024});
   RunR3Contiguous<float>(operand_shape, /*index=*/7, /*size=*/1);
 }
 
-XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousLargerBF16) {
+TEST_F(DynamicUpdateSliceTest, R3ContiguousLargerBF16) {
   std::vector<int32_t> operand_shape({32, 128, 1024});
   RunR3Contiguous<bfloat16>(operand_shape, /*index=*/7, /*size=*/1);
 }
 
 // This test that buffer assignment does not alias constants with the output of
 // dynamic update slice.
-XLA_TEST_F(HloTestBase, AddOfDUS) {
+TEST_F(HloTestBase, AddOfDUS) {
   const char* hlo_string = R"(
 HloModule m
 test {
@@ -769,7 +773,7 @@ XLA_TEST_F(HloTestBase, AddOfDUS) {
 // and multiple output fusions of dynamic update slices produce the right
 // results. On some backends (e.g. GPU), this is done inplace.
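// ---------------------------------------------------------------------------
// [Editor's sketch, not part of the patch] The HLO-text tests in this file all
// follow the same recipe: build an HLO module as a raw string, then hand it to
// RunAndCompareNoHloPasses, which executes it on the test backend and checks
// the result against the reference backend. A minimal instance of that recipe,
// assuming the HloTestBase utilities used throughout this patch (the test name
// and HLO module below are illustrative only):
//
//   TEST_F(HloTestBase, MinimalDynamicUpdateSlice) {
//     const char* hlo_string = R"(
//   HloModule m
//
//   ENTRY main {
//     operand = f32[4] parameter(0)
//     update = f32[2] parameter(1)
//     index = s32[] constant(1)
//     ROOT dus = f32[4] dynamic-update-slice(operand, update, index)
//   })";
//     // Zero-tolerance comparison, as the in-place DUS tests below use.
//     EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0}));
//   }
// ---------------------------------------------------------------------------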
#ifdef XLA_TEST_BACKEND_GPU -XLA_TEST_F(HloTestBase, MultipleOutputFusedDynamicUpdateSlices) { +TEST_F(HloTestBase, MultipleOutputFusedDynamicUpdateSlices) { const char* hlo_string = R"( HloModule MultipleInplaceDus, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } @@ -799,8 +803,8 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -XLA_TEST_F(HloTestBase, - MultipleOutputFusedDynamicUpdateSlicesWithTransposeBitcastedRoot) { +TEST_F(HloTestBase, + MultipleOutputFusedDynamicUpdateSlicesWithTransposeBitcastedRoot) { const char* hlo_string = R"( HloModule MultipleInplaceDusWithTransposeBitcastToTheRoot, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } @@ -831,8 +835,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -XLA_TEST_F(HloTestBase, - SingleFusedDynamicUpdateSliceWithTransposeBitcastedRoot) { +TEST_F(HloTestBase, SingleFusedDynamicUpdateSliceWithTransposeBitcastedRoot) { const char* hlo_string = R"( HloModule SingleInplaceDusWithTransposeBitcastToTheRoot, input_output_alias={ {}: (0, {}) } @@ -859,7 +862,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -XLA_TEST_F(HloTestBase, SingleFusedDynamicUpdateSliceWithReshapeBitcastedRoot) { +TEST_F(HloTestBase, SingleFusedDynamicUpdateSliceWithReshapeBitcastedRoot) { const char* hlo_string = R"( HloModule SingleInplaceDusWithReshapeBitcastToTheRoot, input_output_alias={ {}: (0, {}) } @@ -886,8 +889,8 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -XLA_TEST_F(HloTestBase, - SingleFusedDynamicUpdateSliceWithBitcastedRootAndParameter) { +TEST_F(HloTestBase, + SingleFusedDynamicUpdateSliceWithBitcastedRootAndParameter) { const char* hlo_string = R"( HloModule SingleInplaceDusWithBitcastToTheRootAndFromTheParameter, input_output_alias={ {}: (0, {}) } @@ -916,8 +919,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -XLA_TEST_F(HloTestBase, - SingleFusedDynamicUpdateSliceWithSameDynamicSliceAccess) { +TEST_F(HloTestBase, SingleFusedDynamicUpdateSliceWithSameDynamicSliceAccess) { const char* hlo_string = R"( HloModule fusion, input_output_alias={ {}: (0, {}) } @@ -943,8 +945,8 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -XLA_TEST_F(HloTestBase, - SingleFusedDynamicUpdateSliceWithDynamicSliceAccessSlicesOfSizeOne) { +TEST_F(HloTestBase, + SingleFusedDynamicUpdateSliceWithDynamicSliceAccessSlicesOfSizeOne) { const char* hlo_string = R"( HloModule fusion, input_output_alias={ {}: (0, {}) } diff --git a/third_party/xla/xla/tests/float8_test.cc b/third_party/xla/xla/tests/float8_test.cc index 71d50ebd6f8676..2b110e521a39af 100644 --- a/third_party/xla/xla/tests/float8_test.cc +++ b/third_party/xla/xla/tests/float8_test.cc @@ -13,15 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include -#include - #include #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/ml_dtypes.h" namespace xla { @@ -29,7 +27,7 @@ namespace { // Test FP8 floating-point types template -class Float8Test : public ClientLibraryTestBase {}; +class Float8Test : public ClientLibraryTestRunnerMixin {}; using DataTypes = ::testing::Types; diff --git a/third_party/xla/xla/tests/floor_ceil_test.cc b/third_party/xla/xla/tests/floor_ceil_test.cc index c164645e954e7a..b55ff38888c3e7 100644 --- a/third_party/xla/xla/tests/floor_ceil_test.cc +++ b/third_party/xla/xla/tests/floor_ceil_test.cc @@ -14,21 +14,20 @@ limitations under the License. ==============================================================================*/ #include -#include +#include "absl/log/log.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class FloorCeilTest : public ClientLibraryTestBase { +class FloorCeilTest : public ClientLibraryTestRunnerMixin { public: enum Function { kFloor, @@ -63,8 +62,6 @@ class FloorCeilTest : public ClientLibraryTestBase { ComputeAndCompareR0(&builder, expected, /*arguments=*/{}); } - const ErrorSpec error_spec_{0.0001}; - float infinity_ = std::numeric_limits::infinity(); float minus_infinity_ = -std::numeric_limits::infinity(); }; @@ -74,7 +71,7 @@ class FloorCeilTest : public ClientLibraryTestBase { // * passing x86-based CPU's qnan to the GPU makes a different nan // "7fc00000=nan=nan vs 7fffffff=nan=nan" -XLA_TEST_F(FloorCeilTest, R1S0Floor) { TestR1F32({}, {}, kFloor); } +TEST_F(FloorCeilTest, R1S0Floor) { TestR1F32({}, {}, kFloor); } TEST_F(FloorCeilTest, R1Floor) { TestR1F32({0.0, -0.0, infinity_, minus_infinity_, 1.1, -0.1}, diff --git a/third_party/xla/xla/tests/fmax_fmin_test.cc b/third_party/xla/xla/tests/fmax_fmin_test.cc index b386de39ad20b3..2869afa6aa5b1e 100644 --- a/third_party/xla/xla/tests/fmax_fmin_test.cc +++ b/third_party/xla/xla/tests/fmax_fmin_test.cc @@ -13,20 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/
 
+#include <cmath>
 #include <vector>
 
 #include "xla/error_spec.h"
 #include "xla/hlo/builder/xla_builder.h"
-#include "xla/service/service.h"
-#include "xla/tests/client_library_test_base.h"
+#include "xla/literal.h"
+#include "xla/tests/client_library_test_runner_mixin.h"
+#include "xla/tests/hlo_test_base.h"
 #include "xla/tests/test_macros.h"
+#include "xla/tsl/platform/test.h"
 
 namespace xla {
 namespace {
 
-class FmaxSimpleTest : public ClientLibraryTestBase {};
+using FmaxSimpleTest = ClientLibraryTestRunnerMixin<HloTestBase>;
 
-XLA_TEST_F(FmaxSimpleTest, FmaxTenValues) {
+TEST_F(FmaxSimpleTest, FmaxTenValues) {
   SetFastMathDisabled(true);
   XlaBuilder builder(TestName());
   auto x = ConstantR1<float>(
@@ -40,16 +43,16 @@ XLA_TEST_F(FmaxSimpleTest, FmaxTenValues) {
   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
 }
 
-XLA_TEST_F(FmaxSimpleTest, FmaxEdgeCases) {
+TEST_F(FmaxSimpleTest, FmaxEdgeCases) {
   SetFastMathDisabled(true);
   XlaBuilder builder(TestName());
   XlaOp param0, param1;
-  std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
+  const Literal param0_data = CreateR1Parameter<float>(
       {INFINITY, INFINITY, INFINITY, -INFINITY, INFINITY, -INFINITY, NAN,
        INFINITY, -INFINITY, NAN},
       /*parameter_number=*/0, /*name=*/"param0", /*builder=*/&builder,
       /*data_handle=*/&param0);
-  std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
+  const Literal param1_data = CreateR1Parameter<float>(
       {INFINITY, -INFINITY, NAN, NAN, -4.0, -5.0, -6.0, 7.0, 8.0, 9.0},
       /*parameter_number=*/1, /*name=*/"param1", /*builder=*/&builder,
       /*data_handle=*/&param1);
@@ -57,21 +60,20 @@ XLA_TEST_F(FmaxSimpleTest, FmaxEdgeCases) {
   Max(param0, param1);
 
   std::vector<float> expected = {INFINITY, INFINITY, NAN, NAN,      INFINITY,
                                  -5,       NAN,      INFINITY, 8,   NAN};
-  ComputeAndCompareR1<float>(&builder, expected,
-                             {param0_data.get(), param1_data.get()},
+  ComputeAndCompareR1<float>(&builder, expected, {&param0_data, &param1_data},
                              ErrorSpec(0.0001));
 }
 
-XLA_TEST_F(FmaxSimpleTest, FminEdgeCases) {
+TEST_F(FmaxSimpleTest, FminEdgeCases) {
   SetFastMathDisabled(true);
   XlaBuilder builder(TestName());
   XlaOp param0, param1;
-  std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
+  const Literal param0_data = CreateR1Parameter<float>(
      {INFINITY, INFINITY, INFINITY, -INFINITY, INFINITY, -INFINITY, NAN,
       INFINITY, -INFINITY, NAN},
      /*parameter_number=*/0, /*name=*/"param0", /*builder=*/&builder,
      /*data_handle=*/&param0);
-  std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
+  const Literal param1_data = CreateR1Parameter<float>(
      {INFINITY, -INFINITY, NAN, NAN, -4.0, -5.0, -6.0, 7.0, 8.0, 9.0},
      /*parameter_number=*/1, /*name=*/"param1", /*builder=*/&builder,
      /*data_handle=*/&param1);
@@ -79,8 +81,7 @@ XLA_TEST_F(FmaxSimpleTest, FminEdgeCases) {
   Min(param0, param1);
 
   std::vector<float> expected = {INFINITY, -INFINITY, NAN, NAN,       -4,
                                  -INFINITY, NAN,      7,   -INFINITY, NAN};
-  ComputeAndCompareR1<float>(&builder, expected,
-                             {param0_data.get(), param1_data.get()},
+  ComputeAndCompareR1<float>(&builder, expected, {&param0_data, &param1_data},
                              ErrorSpec(0.0001));
 }
 
diff --git a/third_party/xla/xla/tests/half_test.cc b/third_party/xla/xla/tests/half_test.cc
index bc05e5f284ca65..639e529d0ffc61 100644
--- a/third_party/xla/xla/tests/half_test.cc
+++ b/third_party/xla/xla/tests/half_test.cc
@@ -14,32 +14,36 @@ limitations under the License.
==============================================================================*/ #include +#include +#include +#include #include -#include "absl/status/statusor.h" +#include "absl/log/check.h" +#include "absl/types/span.h" #include "Eigen/Core" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" -#include "xla/hlo/testlib/test_helpers.h" -#include "xla/literal.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "xla/tests/test_utils.h" +#include "xla/tsl/platform/test.h" +#include "xla/types.h" // Tests the handling of the basic mathematics operations with F16 operands. namespace xla { namespace { -class HalfTestBase : public ClientLibraryTestBase { - protected: - const ErrorSpec error_spec_{0.001, 0.001}; - // Number of elements in the input buffers. - static constexpr int kNumElements = 4; -}; - using UnaryBuildFuncTy = std::function; +// Number of elements in the input buffers. +constexpr int kNumElements = 4; +constexpr ErrorSpec kErrorSpec{0.001, 0.001}; + +using HalfTestBase = ClientLibraryTestRunnerMixin; + struct UnaryOpTestParam { std::function compute_func; UnaryBuildFuncTy build_func; @@ -67,7 +71,7 @@ XLA_TEST_P(UnaryOpTest, Ops) { UnaryBuildFuncTy build_func = GetParam().build_func; build_func(x_opnd); - ComputeAndCompareR1(&builder, expected, {x_data.get()}, error_spec_); + ComputeAndCompareR1(&builder, expected, {&x_data}, kErrorSpec); } half sign_imp(half value) { @@ -125,7 +129,7 @@ XLA_TEST_P(UnaryPredTest, Ops) { UnaryBuildFuncTy build_func = GetParam().build_func; build_func(x_opnd); - ComputeAndCompareR1(&builder, expected, {x_data.get()}); + ComputeAndCompareR1(&builder, expected, {&x_data}); } INSTANTIATE_TEST_SUITE_P(half, UnaryPredTest, @@ -166,8 +170,7 @@ XLA_TEST_P(BinaryOpTest, Ops) { BinaryBuildFuncTy build_func = GetParam().build_func; build_func(x_opnd, y_opnd, {}); - ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}, - error_spec_); + ComputeAndCompareR1(&builder, expected, {&x_data, &y_data}, kErrorSpec); } half atan2_imp(half x, half y) { @@ -221,7 +224,7 @@ XLA_TEST_P(BinaryPredTest, Ops) { BinaryBuildFuncTy build_func = GetParam().build_func; build_func(x_opnd, y_opnd, {}); - ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}); + ComputeAndCompareR1(&builder, expected, {&x_data, &y_data}); } INSTANTIATE_TEST_SUITE_P( diff --git a/third_party/xla/xla/tests/iota_test.cc b/third_party/xla/xla/tests/iota_test.cc index a9dddb816b4705..34afcad89c4efe 100644 --- a/third_party/xla/xla/tests/iota_test.cc +++ b/third_party/xla/xla/tests/iota_test.cc @@ -25,9 +25,10 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/ml_dtypes.h" @@ -35,7 +36,9 @@ limitations under the License. 
namespace xla { namespace { -XLA_TEST_F(HloTestBase, IotaReshapeR1) { +using IotaTest = HloTestBase; + +TEST_F(IotaTest, IotaReshapeR1) { const std::string hlo_text = R"( HloModule iota_reshape ENTRY main { @@ -46,7 +49,7 @@ XLA_TEST_F(HloTestBase, IotaReshapeR1) { EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt)); } -XLA_TEST_F(HloTestBase, IotaReshapeExtraDims) { +TEST_F(IotaTest, IotaReshapeExtraDims) { const std::string hlo_text = R"( HloModule iota_reshape ENTRY main { @@ -67,7 +70,7 @@ std::vector GetR1Expected(const int64_t num_elements) { } class IotaR1Test - : public ClientLibraryTestBase, + : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface> {}; XLA_TEST_P(IotaR1Test, DoIt) { @@ -113,7 +116,7 @@ INSTANTIATE_TEST_CASE_P( /*end=*/10001, /*step=*/10))); -class IotaR2Test : public ClientLibraryTestBase, +class IotaR2Test : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface< std::tuple> {}; @@ -151,7 +154,7 @@ INSTANTIATE_TEST_CASE_P( /*step=*/10), ::testing::Values(0, 1))); -class IotaR3Test : public ClientLibraryTestBase, +class IotaR3Test : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface< std::tuple> {}; diff --git a/third_party/xla/xla/tests/log_test.cc b/third_party/xla/xla/tests/log_test.cc index 114a00ee387682..e58fc2daff9efc 100644 --- a/third_party/xla/xla/tests/log_test.cc +++ b/third_party/xla/xla/tests/log_test.cc @@ -16,19 +16,20 @@ limitations under the License. #include #include -#include "xla/client/local_client.h" +#include "xla/array3d.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class LogTest : public ClientLibraryTestBase {}; +class LogTest : public ClientLibraryTestRunnerMixin {}; -XLA_TEST_F(LogTest, LogZeroValues) { +TEST_F(LogTest, LogZeroValues) { XlaBuilder builder(TestName()); auto x = ConstantR3FromArray3D(&builder, Array3D(3, 0, 0)); Log(x); diff --git a/third_party/xla/xla/tests/map_test.cc b/third_party/xla/xla/tests/map_test.cc index b35daacfa4b8d3..0a4f6bcad183b4 100644 --- a/third_party/xla/xla/tests/map_test.cc +++ b/third_party/xla/xla/tests/map_test.cc @@ -13,34 +13,38 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/array2d.h" -#include "xla/client/local_client.h" +#include "xla/array3d.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/testlib/test.h" #include "xla/hlo/testlib/test_helpers.h" +#include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" -#include "xla/tests/test_utils.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" namespace xla { namespace { -class MapTest : public ClientLibraryTestBase { +class MapTest : public ClientLibraryTestRunnerMixin { public: - explicit MapTest(se::Platform* platform = nullptr) - : ClientLibraryTestBase(platform) { + MapTest() { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); } @@ -170,28 +174,23 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR0(42.0); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {}); - ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, + ComputeAndCompareR0(&builder, 43.0, {¶m0_literal}, ErrorSpec(0.01f)); } -XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { +TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. 
XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); - ComputeAndCompareR1(&builder, {}, {param0_data.get()}, - ErrorSpec(0.01f)); + ComputeAndCompareR1(&builder, {}, {¶m0_literal}, ErrorSpec(0.01f)); } TEST_F(MapTest, MapEachElemPlusOneR1S4) { @@ -199,40 +198,34 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, - {param0_data.get()}, ErrorSpec(0.01f)); + {¶m0_literal}, ErrorSpec(0.01f)); } TEST_F(MapTest, MapEachF32ElementToS32Constant) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); - ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); + ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {¶m0_literal}); } TEST_F(MapTest, MapEachF32ElementToU32Constant) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); - ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); + ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {¶m0_literal}); } TEST_F(MapTest, MapEachElemLongerChainR1) { @@ -240,31 +233,26 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, - {param0_data.get()}, ErrorSpec(0.01f)); + {¶m0_literal}, ErrorSpec(0.01f)); } -XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { +TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. 
XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); - ComputeAndCompareR1(&builder, {}, {param0_data.get()}, - ErrorSpec(0.01f)); + ComputeAndCompareR1(&builder, {}, {¶m0_literal}, ErrorSpec(0.01f)); } TEST_F(MapTest, MapMultipleMapsR1S4) { @@ -273,15 +261,13 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, - {param0_data.get()}, ErrorSpec(0.01f)); + {¶m0_literal}, ErrorSpec(0.01f)); } TEST_F(MapTest, MapEachElemPlusOneR2) { @@ -289,19 +275,17 @@ TEST_F(MapTest, MapEachElemPlusOneR2) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); - ComputeAndCompareR2(&builder, expected_array, {param0_data.get()}, + ComputeAndCompareR2(&builder, expected_array, {¶m0_literal}, ErrorSpec(0.01f)); } -XLA_TEST_F(MapTest, ComplexNestedMaps) { +TEST_F(MapTest, ComplexNestedMaps) { // Constructs a complex graph of embedded computations to test the computation // lowering order. Python equivalent: // @@ -344,12 +328,8 @@ TEST_F(MapTest, MapBinaryAdder) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); - std::unique_ptr param1_data = - client_->TransferToServer(param1_literal).value(); auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); @@ -357,23 +337,18 @@ TEST_F(MapTest, MapBinaryAdder) { {0}); ComputeAndCompareR1(&builder, {7.3f, 7.7, 4.3f, 0}, - {param0_data.get(), param1_data.get()}, + {¶m0_literal, ¶m1_literal}, ErrorSpec(0.01f)); } // Adds two rank-2 arrays with different layouts. This test exercises a path // for Map that used to fail in shape inference (b/28989438). 
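// ---------------------------------------------------------------------------
// [Editor's sketch, not part of the patch] The mechanical rewrite in this file
// is always the same: drop the client_->TransferToServer(...) round-trip and
// pass the Literal to ComputeAndCompare* by pointer, which the
// ClientLibraryTestRunnerMixin-based fixtures accept directly. Schematically
// (test name illustrative only):
//
//   TEST_F(MapTest, SketchLiteralArguments) {
//     XlaBuilder builder(TestName());
//     Literal arg = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
//     auto param = Parameter(&builder, 0, arg.shape(), "param0");
//     Map(&builder, {param}, CreateAdderToOne(), {0});
//     // The literal itself is the argument; no GlobalData handle is created.
//     ComputeAndCompareR1<float>(&builder, {2.0f, 3.0f}, {&arg},
//                                ErrorSpec(0.01f));
//   }
// ---------------------------------------------------------------------------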
-XLA_TEST_F(MapTest, AddWithMixedLayouts) {
+TEST_F(MapTest, AddWithMixedLayouts) {
   XlaBuilder builder(TestName());
   Literal param0_literal = LiteralUtil::CreateR2WithLayout(
       {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
-  std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(param0_literal).value();
-
   Literal param1_literal = LiteralUtil::CreateR2WithLayout(
       {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
-  std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(param1_literal).value();
 
   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
@@ -386,20 +361,15 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) {
   expected(1, 0) = 33;
   expected(1, 1) = 44;
   ComputeAndCompareR2(&builder, expected,
-                      {param0_data.get(), param1_data.get()});
+                      {&param0_literal, &param1_literal});
 }
 
-XLA_TEST_F(MapTest, AddR3_3x0x2) {
+TEST_F(MapTest, AddR3_3x0x2) {
   XlaBuilder builder(TestName());
   Literal param0_literal =
       LiteralUtil::CreateR3FromArray3D(Array3D<int32_t>(3, 0, 2));
-  std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(param0_literal).value();
-
   Literal param1_literal =
       LiteralUtil::CreateR3FromArray3D(Array3D<int32_t>(3, 0, 2));
-  std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(param1_literal).value();
 
   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
@@ -407,7 +377,7 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) {
       {0, 1, 2});
 
   ComputeAndCompareR3(&builder, Array3D<int32_t>(3, 0, 2),
-                      {param0_data.get(), param1_data.get()});
+                      {&param0_literal, &param1_literal});
 }
 
 TEST_F(MapTest, MapTernaryAdder) {
   // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
   XlaBuilder builder(TestName());
   Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
-  std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(param0_literal).value();
   Literal param1_literal =
       LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
-  std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(param1_literal).value();
   Literal param2_literal =
       LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
-  std::unique_ptr<GlobalData> param2_data =
-      client_->TransferToServer(param2_literal).value();
 
   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
@@ -433,8 +397,7 @@ TEST_F(MapTest, MapTernaryAdder) {
 
   ComputeAndCompareR1<float>(
       &builder, {-2.7f, -92.3f, -895.7f, -400.0f},
-      {param0_data.get(), param1_data.get(), param2_data.get()},
-      ErrorSpec(0.01f));
+      {&param0_literal, &param1_literal, &param2_literal}, ErrorSpec(0.01f));
 }
 
 TEST_F(MapTest, MapGt) {
@@ -477,12 +440,8 @@ TEST_F(MapTest, MapOperationWithBuildError) {
 
   Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
-  std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(param0_literal).value();
   Literal param1_literal =
       LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
-  std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(param1_literal).value();
 
   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
@@ -498,7 +457,7 @@ TEST_F(MapTest, MapOperationWithBuildError) {
 
 class MapHloTest : public HloTestBase {};
 
 // TODO(b/230123847): Enable this on GPU once mhlo allows mixed-type map.
-XLA_TEST_F(MapHloTest, DISABLED_ON_GPU(MapWithMixedInputTypes)) { +TEST_F(MapHloTest, DISABLED_ON_GPU(MapWithMixedInputTypes)) { absl::string_view hlo_string = R"( HloModule MapMixedInputTypes @@ -522,7 +481,7 @@ XLA_TEST_F(MapHloTest, DISABLED_ON_GPU(MapWithMixedInputTypes)) { // MapTest disables inline and algsimp. MapTestWithFullOpt runs all // optimizations. -using MapTestWithFullOpt = ClientLibraryTestBase; +using MapTestWithFullOpt = ClientLibraryTestRunnerMixin; // Regression test for b/31466798. The inliner simplifies map(param0, param1, // power) to power(param0, param1) without deleting the old subcomputation which @@ -540,18 +499,13 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { Literal param0_literal = LiteralUtil::CreateR0(2.0f); Literal param1_literal = LiteralUtil::CreateR0(5.0f); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); - std::unique_ptr param1_data = - client_->TransferToServer(param1_literal).value(); auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, power, {}); - ComputeAndCompareR0(&builder, 32.0f, - {param0_data.get(), param1_data.get()}, - ErrorSpec(0.01f)); + ComputeAndCompareR0( + &builder, 32.0f, {¶m0_literal, ¶m1_literal}, ErrorSpec(0.01f)); } // Regression test for b/35786417, where the inliner would not notice the change @@ -567,17 +521,13 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { Literal param0_literal = LiteralUtil::CreateR0(2.0f); Literal param1_literal = LiteralUtil::CreateR0(5.0f); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); - std::unique_ptr param1_data = - client_->TransferToServer(param1_literal).value(); auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, sub_opposite, {}); - ComputeAndCompareR0( - &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f)); + ComputeAndCompareR0(&builder, 3.0f, {¶m0_literal, ¶m1_literal}, + ErrorSpec(0.01f)); } // Regression test for b/35786417, where the inliner would CHECK-fail due to the @@ -591,13 +541,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) { auto square = sub_builder->BuildAndNoteError(); Literal param0_literal = LiteralUtil::CreateR0(10.0f); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param0}, square, {}); - ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, + ComputeAndCompareR0(&builder, 100.0f, {¶m0_literal}, ErrorSpec(0.01f)); } diff --git a/third_party/xla/xla/tests/multidimensional_slice_test.cc b/third_party/xla/xla/tests/multidimensional_slice_test.cc index 0b89cbee3341f6..86361a4c638929 100644 --- a/third_party/xla/xla/tests/multidimensional_slice_test.cc +++ b/third_party/xla/xla/tests/multidimensional_slice_test.cc @@ -15,23 +15,22 @@ limitations under the License. // Tests that slice operations can be performed. 
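// ---------------------------------------------------------------------------
// [Editor's sketch, not part of the patch] Where a test owns many parameter
// literals (see params_test.cc below), the patch swaps the
// std::vector<std::unique_ptr<GlobalData>> owner for a std::vector<Literal>
// plus a parallel vector of raw pointers. Note the ordering constraint: the
// Parameter op is built from literal.shape() *before* the literal is moved
// into the owner, exactly as the rewritten loops below do:
//
//   std::vector<Literal> param_data_owner;
//   for (int i = 0; i < kParamCount; ++i) {
//     Literal literal = LiteralUtil::CreateR1<int32_t>({i, i});
//     XlaOp param = Parameter(&builder, i, literal.shape(), "param");
//     param_data_owner.push_back(std::move(literal));
//     sum_handle = Add(sum_handle, param);
//   }
//   std::vector<const Literal*> param_data;
//   param_data.reserve(param_data_owner.size());
//   for (const Literal& data : param_data_owner) {
//     param_data.push_back(&data);
//   }
// ---------------------------------------------------------------------------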
-#include #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class SliceTest : public ClientLibraryTestBase {}; +class SliceTest : public ClientLibraryTestRunnerMixin {}; -XLA_TEST_F(SliceTest, Slice2D) { +TEST_F(SliceTest, Slice2D) { XlaBuilder builder("slice_2d"); auto original = ConstantR2( &builder, @@ -42,7 +41,7 @@ XLA_TEST_F(SliceTest, Slice2D) { ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); } -XLA_TEST_F(SliceTest, Slice3D) { +TEST_F(SliceTest, Slice3D) { XlaBuilder builder("slice_3d"); Array3D array_3d( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}); diff --git a/third_party/xla/xla/tests/params_test.cc b/third_party/xla/xla/tests/params_test.cc index 025fe8de95b2ad..fe9b59eca0b415 100644 --- a/third_party/xla/xla/tests/params_test.cc +++ b/third_party/xla/xla/tests/params_test.cc @@ -14,120 +14,105 @@ limitations under the License. ==============================================================================*/ #include -#include +#include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/array2d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/protobuf.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class ParamsTest : public ClientLibraryTestBase {}; +using ParamsTest = ClientLibraryTestRunnerMixin; -XLA_TEST_F(ParamsTest, ConstantR0F32Param) { +TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR0(3.14159f); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); - ComputeAndCompareR0(&builder, 3.14159f, {param0_data.get()}, + ComputeAndCompareR0(&builder, 3.14159f, {¶m0_literal}, ErrorSpec(0.0001f)); } -XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { +TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0"); - ComputeAndCompareR1(&builder, {}, {param0_data.get()}, - ErrorSpec(0.01f)); + ComputeAndCompareR1(&builder, {}, {¶m0_literal}, ErrorSpec(0.01f)); } -XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { +TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); 
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); - ComputeAndCompareR1(&builder, {3.14f, -100.25f}, {param0_data.get()}, + ComputeAndCompareR1(&builder, {3.14f, -100.25f}, {¶m0_literal}, ErrorSpec(0.01f)); } -XLA_TEST_F(ParamsTest, ConstantR1U8Param) { +TEST_F(ParamsTest, ConstantR1U8Param) { XlaBuilder builder(TestName()); std::string str("hello world"); Literal param0_literal = LiteralUtil::CreateR1U8(str); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), "param0"); - ComputeAndCompareR1U8(&builder, str, {param0_data.get()}); + ComputeAndCompareR1U8(&builder, str, {¶m0_literal}); } -XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { +TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); - ComputeAndCompareR2(&builder, Array2D(3, 0), - {param0_data.get()}, ErrorSpec(0.01f)); + ComputeAndCompareR2(&builder, Array2D(3, 0), {¶m0_literal}, + ErrorSpec(0.01f)); } -XLA_TEST_F(ParamsTest, ConstantR2F32Param) { +TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); - std::unique_ptr param0_data = - client_->TransferToServer(param0_literal).value(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); Array2D expected_array( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); - ComputeAndCompareR2(&builder, expected_array, {param0_data.get()}, + ComputeAndCompareR2(&builder, expected_array, {¶m0_literal}, ErrorSpec(0.01f)); } -XLA_TEST_F(ParamsTest, TwoParameters) { +TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); Literal literal0 = LiteralUtil::CreateR1({1, 2}); - std::unique_ptr param0_data = - client_->TransferToServer(literal0).value(); auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); Literal literal1 = LiteralUtil::CreateR1({10, 20}); - std::unique_ptr param1_data = - client_->TransferToServer(literal1).value(); auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); // Use both parameters @@ -143,16 +128,14 @@ XLA_TEST_F(ParamsTest, TwoParameters) { // {11, 22} * {10, 20} = {110, 440} Mul(sum, param1); - ComputeAndCompareR1(&builder, {110, 440}, - {param0_data.get(), param1_data.get()}, + ComputeAndCompareR1(&builder, {110, 440}, {&literal0, &literal1}, ErrorSpec(0.0001f)); } -XLA_TEST_F(ParamsTest, MissingParameter) { +TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. 
Literal literal = LiteralUtil::CreateR0(3.14159f); - std::unique_ptr data = client_->TransferToServer(literal).value(); XlaBuilder builder(TestName()); Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2"); @@ -161,36 +144,27 @@ XLA_TEST_F(ParamsTest, MissingParameter) { ASSERT_NE(computation_status.status(), absl::OkStatus()); } -XLA_TEST_F(ParamsTest, UnusedParameter) { +TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); Literal literal0 = LiteralUtil::CreateR1({1, 2}); - std::unique_ptr param0_data = - client_->TransferToServer(literal0).value(); Parameter(&builder, 0, literal0.shape(), "param0"); Literal literal1 = LiteralUtil::CreateR1({10, 20}); - std::unique_ptr param1_data = - client_->TransferToServer(literal1).value(); Parameter(&builder, 1, literal1.shape(), "param1"); - ComputeAndCompareR1(&builder, {10, 20}, - {param0_data.get(), param1_data.get()}, + ComputeAndCompareR1(&builder, {10, 20}, {&literal0, &literal1}, ErrorSpec(0.0001f)); } -XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { +TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // Build a computation with a couple unused parameters which are used in an // unused expression. XlaBuilder builder(TestName()); Literal literal0 = LiteralUtil::CreateR1({1, 2}); - std::unique_ptr param0_data = - client_->TransferToServer(literal0).value(); Literal literal1 = LiteralUtil::CreateR1({10, 20, 30}); - std::unique_ptr param1_data = - client_->TransferToServer(literal1).value(); auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); @@ -201,13 +175,12 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { Neg(param0); - ComputeAndCompareR1( - &builder, {-1, -2}, - {param0_data.get(), param1_data.get(), param1_data.get()}, - ErrorSpec(0.0001f)); + ComputeAndCompareR1(&builder, {-1, -2}, + {&literal0, &literal1, &literal1}, + ErrorSpec(0.0001f)); } -XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { +TEST_F(ParamsTest, HundredLargeR1Parameters) { XlaBuilder builder(TestName()); constexpr int size = 8 * 128 * 2; @@ -217,7 +190,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum = {{0, 1}}; sum.resize(size); - std::vector> param_data_owner; + std::vector param_data_owner; constexpr int parameter_count = 100; for (int i = 0; i < parameter_count; ++i) { @@ -229,15 +202,15 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); Literal literal = LiteralUtil::CreateR1(sum_value); - param_data_owner.push_back(client_->TransferToServer(literal).value()); XlaOp param = Parameter(&builder, i, literal.shape(), "param"); + param_data_owner.push_back(std::move(literal)); sum_handle = Add(sum_handle, param); } - std::vector param_data; + std::vector param_data; param_data.reserve(param_data_owner.size()); - for (const std::unique_ptr& data : param_data_owner) { - param_data.push_back(data.get()); + for (const Literal& data : param_data_owner) { + param_data.push_back(&data); } ComputeAndCompareR1(&builder, sum, param_data, ErrorSpec(0.0001f)); @@ -249,26 +222,25 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { // TODO(b/65526061) Failed on CPU on 2017-09-10 due to timeout in LLVM // compilation. 
-XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(ThreeThousandParameters)) { +TEST_F(ParamsTest, DISABLED_ON_CPU(ThreeThousandParameters)) { XlaBuilder builder(TestName()); - std::vector> param_data_owner; + std::vector param_data_owner; XlaOp sum_handle = ConstantR0(&builder, 0.0f); float target = 0.0; constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; Literal literal = LiteralUtil::CreateR0(i); - param_data_owner.push_back( - std::move(client_->TransferToServer(literal)).value()); XlaOp param = Parameter(&builder, i, literal.shape(), "param"); + param_data_owner.push_back(std::move(literal)); sum_handle = Add(sum_handle, param); } - std::vector param_data; + std::vector param_data; param_data.reserve(param_data_owner.size()); - for (const std::unique_ptr& data : param_data_owner) { - param_data.push_back(data.get()); + for (const Literal& data : param_data_owner) { + param_data.push_back(&data); } ComputeAndCompareR0(&builder, target, param_data, ErrorSpec(0.0001f)); @@ -276,11 +248,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(ThreeThousandParameters)) { // TODO(b/65526061) Failed on CPU on 2017-09-10 due to timeout in LLVM // compilation. -XLA_TEST_F(ParamsTest, - DISABLED_ON_CPU(ThreeThousandParametersAndOutputElements)) { +TEST_F(ParamsTest, DISABLED_ON_CPU(ThreeThousandParametersAndOutputElements)) { XlaBuilder builder(TestName()); - std::vector> param_data_owner; + std::vector param_data_owner; XlaOp sum_handle = ConstantR1(&builder, {0, 0}); int32_t target = 0; constexpr int kParamCount = 3000; @@ -290,9 +261,8 @@ XLA_TEST_F(ParamsTest, for (int i = 0; i < kParamCount; ++i) { target += i; Literal literal = LiteralUtil::CreateR1({i, i}); - param_data_owner.push_back( - std::move(client_->TransferToServer(literal)).value()); XlaOp param = Parameter(&builder, i, literal.shape(), "param"); + param_data_owner.push_back(std::move(literal)); params.push_back(param); sum_handle = Add(sum_handle, param); } @@ -305,10 +275,10 @@ XLA_TEST_F(ParamsTest, Tuple(&builder, outputs); - std::vector param_data; + std::vector param_data; param_data.reserve(param_data_owner.size()); - for (const std::unique_ptr& data : param_data_owner) { - param_data.push_back(data.get()); + for (const Literal& data : param_data_owner) { + param_data.push_back(&data); } std::vector elements; @@ -336,10 +306,10 @@ XLA_TEST_F(ParamsTest, // pN += (1, 1) // } // result = {p0, p1, ..., pN} -XLA_TEST_F(ParamsTest, ManyParametersIntoWhileLoop) { +TEST_F(ParamsTest, ManyParametersIntoWhileLoop) { XlaBuilder builder(TestName()); - std::vector> param_data_owner; + std::vector param_data_owner; constexpr int kParamCount = 1900; std::vector params; std::vector parameter_shapes; @@ -348,22 +318,20 @@ XLA_TEST_F(ParamsTest, ManyParametersIntoWhileLoop) { parameter_shapes.reserve(kParamCount); for (int i = 0; i < kParamCount; ++i) { Literal literal = LiteralUtil::CreateR1({i, i}); - param_data_owner.push_back( - std::move(client_->TransferToServer(literal)).value()); XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); parameter_shapes.push_back(literal.shape()); + param_data_owner.push_back(std::move(literal)); } // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. 
 Literal bool_literal = LiteralUtil::CreateR0(false);
-  param_data_owner.push_back(
-      std::move(client_->TransferToServer(bool_literal)).value());
   XlaOp bool_param =
       Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param");
   params.push_back(bool_param);
   parameter_shapes.push_back(bool_literal.shape());
+  param_data_owner.push_back(std::move(bool_literal));
 
   auto init = Tuple(&builder, params);
 
@@ -407,10 +375,10 @@ XLA_TEST_F(ParamsTest, ManyParametersIntoWhileLoop) {
   }
   Tuple(&builder, outputs);
 
-  std::vector<GlobalData*> param_data;
+  std::vector<const Literal*> param_data;
   param_data.reserve(param_data_owner.size());
-  for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
-    param_data.push_back(data.get());
+  for (const Literal& data : param_data_owner) {
+    param_data.push_back(&data);
   }
 
   std::vector<Literal> elements;
@@ -425,7 +393,7 @@ XLA_TEST_F(ParamsTest, ManyParametersIntoWhileLoop) {
 }
 
 #endif
 
-XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
+TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
   XlaBuilder builder(TestName());
 
   Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3});
@@ -435,43 +403,43 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
   auto rhs = GetTupleElement(input, 1);
   Add(lhs, rhs);
 
-  std::unique_ptr<GlobalData> data =
-      client_
-          ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
-              LiteralUtil::CreateR1<float>({1, 2, 3}),
-              LiteralUtil::CreateR1<float>({4, 5, 6}),
-          }))
-          .value();
+  const Literal literal = LiteralUtil::MakeTupleFromSlices({
+      LiteralUtil::CreateR1<float>({1, 2, 3}),
+      LiteralUtil::CreateR1<float>({4, 5, 6}),
+  });
 
-  std::vector<GlobalData*> arguments = {data.get()};
   const std::vector<float> expected = {1 + 4, 2 + 5, 3 + 6};
-  ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
+  ComputeAndCompareR1<float>(&builder, expected, {&literal}, ErrorSpec(1e-5));
 }
 
 // Verifies that passing a 2x2 with {0, 1} layout returns the same value back
 // when (transferred to the server and) passed through a parameter.
-XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
+TEST_F(ParamsTest, R2_2x2_Layout_01) {
   Literal literal = LiteralUtil::CreateR2WithLayout(
       {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
   XlaBuilder builder(TestName());
   Parameter(&builder, 0, literal.shape(), "input");
 
-  std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
-  ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
+  ComputeAndCompareLiteral(&builder, literal, {&literal}, ErrorSpec(1e-3));
 }
 
 // As above, but for {1, 0} layout.
-XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
+TEST_F(ParamsTest, R2_2x2_Layout_10) {
   Literal literal = LiteralUtil::CreateR2WithLayout(
       {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
   XlaBuilder builder(TestName());
   Parameter(&builder, 0, literal.shape(), "input");
 
-  std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
-  ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
+  ComputeAndCompareLiteral(&builder, literal, {&literal}, ErrorSpec(1e-3));
 }
 
-XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
+// Disabled on CPU, GPU, and interpreter. Not all HLO-based runners support
+// using the layout of the parameter literal on all backends. The entry
+// computation layout is used instead. The way this test is set up, the ECL will
+// reflect the layout of the original literal, so it does not pass. This seems
+// to be a niche behavior that is not worth fixing.
+TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + R2_2x2_TryToPassReverseLayoutToParameter)))) { Literal literal = LiteralUtil::CreateR2({ {1, 3}, {2, 4}, @@ -494,11 +462,10 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { // Use the slice operator to get an off-diagonal element. Slice(input, {0, 1}, {1, 2}, {1, 1}); - std::unique_ptr data = client_->TransferToServer(literal).value(); // Check that we got the off-diagonal value that we expected. Array2D expected(1, 1); expected(0, 0) = 2; - ComputeAndCompareR2(&builder, expected, {data.get()}, ErrorSpec(1e-3)); + ComputeAndCompareR2(&builder, expected, {&literal}, ErrorSpec(1e-3)); } } // namespace diff --git a/third_party/xla/xla/tests/pred_test.cc b/third_party/xla/xla/tests/pred_test.cc index 060a433753aa8e..204a3cc538c72e 100644 --- a/third_party/xla/xla/tests/pred_test.cc +++ b/third_party/xla/xla/tests/pred_test.cc @@ -14,20 +14,21 @@ limitations under the License. ==============================================================================*/ // Miscellaneous tests with the PRED type that don't fit anywhere else. -#include +#include +#include +#include -#include "xla/array2d.h" -#include "xla/client/local_client.h" +#include "absl/types/span.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/test.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class PredTest : public ClientLibraryTestBase { +class PredTest : public ClientLibraryTestRunnerMixin { protected: void TestCompare(bool lhs, bool rhs, bool expected, std::function - -#include "absl/status/statusor.h" -#include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/hlo/testlib/test_helpers.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class QueryInferredShapeTest : public ClientLibraryTestBase {}; - -TEST_F(QueryInferredShapeTest, OnePlusOneShape) { +TEST(QueryInferredShapeTest, OnePlusOneShape) { XlaBuilder builder("one_plus_one"); - auto one = ConstantR0(&builder, 1.0); - auto result = Add(one, one); - absl::StatusOr shape_status = builder.GetShape(result); - ASSERT_IS_OK(shape_status.status()); - auto shape = shape_status.value(); + XlaOp one = ConstantR0(&builder, 1.0); + XlaOp result = Add(one, one); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, builder.GetShape(result)); ASSERT_TRUE(ShapeUtil::Equal(shape, ShapeUtil::MakeShape(F32, {}))); } diff --git a/third_party/xla/xla/tests/reduce_precision_test.cc b/third_party/xla/xla/tests/reduce_precision_test.cc index 5adbd3e50458fa..35e5179846018d 100644 --- a/third_party/xla/xla/tests/reduce_precision_test.cc +++ b/third_party/xla/xla/tests/reduce_precision_test.cc @@ -13,25 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include -#include -#include -#include +#include #include #include "absl/base/casts.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "xla/array2d.h" -#include "xla/client/local_client.h" +#include "absl/strings/str_format.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" -#include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/literal_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" namespace xla { @@ -452,8 +446,9 @@ static const uint64_t f64_test_values[][4] = { }, }; -class ReducedPrecisionAccuracyTest : public ClientLibraryTestBase, - public ::testing::WithParamInterface { +class ReducedPrecisionAccuracyTest + : public ClientLibraryTestRunnerMixin, + public ::testing::WithParamInterface { protected: template void DoIt(int exponent_bits, int mantissa_bits, @@ -517,7 +512,7 @@ void ReducedPrecisionAccuracyTest::DoIt( ReducePrecision(a, exponent_bits, mantissa_bits); - ComputeAndCompare(&builder, {std::move(a_literal)}); + ComputeAndCompare(&builder, {&a_literal}); } INSTANTIATE_TEST_CASE_P(ReducedPrecisionAccuracyTest, diff --git a/third_party/xla/xla/tests/reshape_test.cc b/third_party/xla/xla/tests/reshape_test.cc index 4aeb2950e034e9..19b6ae21226d53 100644 --- a/third_party/xla/xla/tests/reshape_test.cc +++ b/third_party/xla/xla/tests/reshape_test.cc @@ -14,20 +14,16 @@ limitations under the License. ==============================================================================*/ #include -#include -#include #include #include #include -#include #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" #include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" @@ -36,21 +32,22 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/reference_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/test_macros.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { class ReshapeTest : public ::testing::WithParamInterface, - public ClientLibraryTestBase { + public ClientLibraryTestRunnerMixin { public: ReshapeTest() { set_float_type(GetParam()); } @@ -58,342 +55,320 @@ class ReshapeTest : public ::testing::WithParamInterface, }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. 
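// ---------------------------------------------------------------------------
// [Editor's sketch, not part of the patch] In the reshape tests below,
// CreateParameterAndTransferLiteral now returns the transferred data as a
// plain Literal instead of a StatusOr-wrapped GlobalData handle, so the
// TF_ASSERT_OK_AND_ASSIGN wrapper disappears and the result feeds straight
// into ComputeAndCompareLiteral:
//
//   XlaOp parameter;
//   const Literal input = CreateParameterAndTransferLiteral(
//       0, input_literal, "parameter", &builder, &parameter);
//   Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
//   ComputeAndCompareLiteral(&builder, expected_literal, {&input},
//                            zero_error_spec_);
// ---------------------------------------------------------------------------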
// Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension.
-XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
+TEST_P(ReshapeTest, CollapseTrivial1x1) {
   XlaBuilder builder(TestName());
   Array2D<float> input_array(1, 1);
   input_array.Fill(1.0f);
   auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
   XlaOp parameter;
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto input, CreateParameterAndTransferLiteral(
-                      0, input_literal, "parameter", &builder, &parameter));
+  const Literal input = CreateParameterAndTransferLiteral(
+      0, input_literal, "parameter", &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
 
   auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
-  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {&input},
                            zero_error_spec_);
 }
 
-XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
+TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
   XlaOp parameter;
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto input, CreateParameterAndTransferLiteral(
-                      0, input_literal, "parameter", &builder, &parameter));
+  const Literal input = CreateParameterAndTransferLiteral(
+      0, input_literal, "parameter", &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{});
 
   auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
-  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {&input},
                            zero_error_spec_);
 }
 
-XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
+TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
   XlaOp parameter;
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto input, CreateParameterAndTransferLiteral(
-                      0, input_literal, "parameter", &builder, &parameter));
+  const Literal input = CreateParameterAndTransferLiteral(
+      0, input_literal, "parameter", &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{0});
 
   auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
-  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {&input},
                            zero_error_spec_);
 }
 
 // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar.
-XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { +TEST_P(ReshapeTest, SingleElementArrayToScalar) { XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral( - 0, input_literal, "parameter", &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder, ¶meter); auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{}); auto new_shape = builder.GetShape(reshape).value(); auto expected_literal = LiteralUtil::CreateR0(1.0f); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { +TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, param0_literal, "param0", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, param0_literal, "param0", &builder, ¶meter); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{1}); auto expected_literal = LiteralUtil::CreateR1({-1.0f}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, Trivial0x3) { +TEST_P(ReshapeTest, Trivial0x3) { XlaBuilder builder(TestName()); Array2D input_array(0, 3); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { +TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, param0_literal, "param0", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, param0_literal, "param0", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, Trivial3x0) { +TEST_P(ReshapeTest, Trivial3x0) { XlaBuilder builder(TestName()); Array2D input_array(3, 0); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); 
- ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Collapses a 2-dimensional row vector to 1 dimension. -XLA_TEST_P(ReshapeTest, Trivial1x3) { +TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Collapses a 2-dimensional column vector to 1 dimension. -XLA_TEST_P(ReshapeTest, Trivial3x1) { +TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Splits an empty vector into an empty matrix. -XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { +TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Splits a vector into a matrix. -XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { +TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{2, 3}); auto expected_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Transposes a 2x0 array to a 0x2 array. 
-XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { +TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Transposes a 2-dimensional row vector to a column vector. -XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { +TEST_P(ReshapeTest, ReshapeRowToCol) { XlaBuilder builder(TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Transposes a 2-dimensional array. -XLA_TEST_P(ReshapeTest, TransposeAsReshape) { +TEST_P(ReshapeTest, TransposeAsReshape) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{1, 0}), /*dimensions=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Transposes a 0x4 array with XlaBuilder::Transpose. -XLA_TEST_P(ReshapeTest, Transpose0x4) { +TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Transposes a 2-dimensional array with ComputationBuilder::Trans. 
-XLA_TEST_P(ReshapeTest, Transpose4x3) { +TEST_P(ReshapeTest, Transpose4x3) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { +TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{2, 3, 0, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 0, 0)); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { +TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{24, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). 
-XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { +TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { +TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{1, 0}), /*dimensions=*/{3, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), and reorder the input (shuffle). -XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { +TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{1, 0}), /*dimensions=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); auto expected_literal = LiteralUtil::CreateFromArray(expected); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } @@ -408,29 +383,27 @@ static Array3D ArrayForDocR3Tests() { {{40, 41, 42}, {45, 46, 47}}}); } -XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { +TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{0, 1, 2}), /*dimensions=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } 
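The expected literals in the DocR3 tests here are nothing more than a row-major walk over the transposed array; the reshape only re-buckets that flat order into the requested dimensions. A standalone check of the {1, 2, 0} permutation case used by the Collapse_120 tests below (plain C++, independent of XLA, added for illustration):

#include <cstdio>

int main() {
  // The 4x2x3 array built by ArrayForDocR3Tests().
  const int a[4][2][3] = {{{10, 11, 12}, {15, 16, 17}},
                          {{20, 21, 22}, {25, 26, 27}},
                          {{30, 31, 32}, {35, 36, 37}},
                          {{40, 41, 42}, {45, 46, 47}}};
  // Transpose with permutation {1, 2, 0} maps output index (i, j, k) to input
  // index (k, i, j); Reshape then reads that result in row-major order.
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      for (int k = 0; k < 4; ++k) std::printf("%d ", a[k][i][j]);
  std::printf("\n");  // Prints 10 20 30 40 11 21 31 41 ... 17 27 37 47.
  return 0;
}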
-XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { +TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{0, 1, 2}), /*dimensions=*/{8, 3}); auto expected_literal = LiteralUtil::CreateR2({{10, 11, 12}, @@ -441,33 +414,31 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { {35, 36, 37}, {40, 41, 42}, {45, 46, 47}}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { +TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{1, 2, 0}), /*dimensions=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { +TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{1, 2, 0}), /*dimensions=*/{8, 3}); @@ -479,23 +450,22 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { {45, 16, 26}, {36, 46, 17}, {27, 37, 47}}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { +TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{1, 2, 0}), /*dimensions=*/{2, 6, 2}); auto expected_literal = LiteralUtil::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } @@ -514,27 +484,26 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { // Then we collapse Z be collapsed 
so we just end up with planes: // // 1 2 3 4 5 6 1 2 3 4 5 6 -XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { +TEST_P(ReshapeTest, FullyConnectedCollapse) { XlaBuilder builder(TestName()); Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // As above, but uses reshape directly. -XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { +TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { XlaBuilder builder(TestName()); Array4D t(2, 1, 2, 2); t(0, 0, 0, 0) = 0; @@ -547,19 +516,18 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 1) = 7; auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{2, 4}); auto expected_literal = LiteralUtil::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Reshape various ranks to a scalar. -XLA_TEST_P(ReshapeTest, ToScalar) { +TEST_P(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { XlaBuilder b(TestName()); std::vector ones(rank, 1); // this is {1, ..., 1}. 
@@ -568,30 +536,27 @@ XLA_TEST_P(ReshapeTest, ToScalar) { input_literal.Set(zeros, 83.0f); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &b, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &b, ¶meter); Reshape(parameter, {}); auto expected_literal = LiteralUtil::CreateR0(83.0f); - ComputeAndCompareLiteral(&b, expected_literal, {input.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&b, expected_literal, {&input}, zero_error_spec_); } } -XLA_TEST_P(ReshapeTest, BadNewSizes) { +TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &b, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &b, ¶meter); Reshape(parameter, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); } -XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { +TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { XlaBuilder builder(TestName()); // clang-format off auto input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( @@ -620,9 +585,8 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{2, 8}); @@ -631,15 +595,12 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { {222, 333, 444, 555, 666, 777, 888, 999}, }); - XlaComputation computation = builder.Build().value(); - ExecutionOptions execution_options = execution_options_; - *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithDenseLayout(FloatType(), {2, 8}, {1, 0}) - .ToProto(); - Literal actual = - client_ - ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) - .value(); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); + const Shape shape_with_output_layout = + ShapeUtil::MakeShapeWithDenseLayout(FloatType(), {2, 8}, {1, 0}); + TF_ASSERT_OK_AND_ASSIGN( + const Literal actual, + ExecuteAndTransfer(computation, {&input}, &shape_with_output_layout)); Literal expected = LiteralUtil::CreateR2FromArray2D(expected_array); if (FloatType() != F32) { expected = MaybeConvertLiteralToTestType(expected); @@ -647,7 +608,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } -XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { +TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaBuilder builder(TestName()); Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, @@ -655,9 +616,8 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{3, 2, 1, 4}); // clang-format off @@ -670,12 +630,12 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { 
{{204, 205, 206, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. -XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { +TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaBuilder builder(TestName()); Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, @@ -683,9 +643,8 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{1, 0}), /*dimensions=*/{3, 2, 1, 4}); @@ -699,11 +658,11 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { {{206, 7, 107, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {&input}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { +TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; @@ -713,17 +672,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN(auto input_data, - CreateParameterAndTransferLiteral( - 0, input_literal, "input", &builder, ¶meter)); + const Literal input_data = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{2, 1}); Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal); - ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, expected, {&input_data}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { +TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; @@ -733,18 +690,16 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN(auto input_data, - CreateParameterAndTransferLiteral( - 0, input_literal, "input", &builder, ¶meter)); + const Literal input_data = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{4, 2}); Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal); - ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, expected, {&input_data}, zero_error_spec_); } // Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}. 
-XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { +TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; @@ -754,9 +709,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN(auto input_data, - CreateParameterAndTransferLiteral( - 0, input_literal, "input", &builder, ¶meter)); + const Literal input_data = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{0, 2, 1, 3}), /*dimensions=*/{5, 60}); @@ -767,11 +721,10 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { *cell; }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, expected, {&input_data}, zero_error_spec_); } -XLA_TEST_P(ReshapeTest, NoopReshape) { +TEST_P(ReshapeTest, NoopReshape) { XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; @@ -782,24 +735,18 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN(auto input_data, - CreateParameterAndTransferLiteral( - 0, input_literal, "input", &builder, ¶meter)); + const Literal input_data = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{3, 0, 1, 2}), /*dimensions=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().value(); - ExecutionOptions execution_options = execution_options_; - *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithDenseLayout(FloatType(), {7, 2, 3, 5}, - {2, 3, 0, 1}) - .ToProto(); - Literal output_literal = - client_ - ->ExecuteAndTransfer(computation, {input_data.get()}, - &execution_options) - .value(); + const Shape shape_with_output_layout = ShapeUtil::MakeShapeWithDenseLayout( + FloatType(), {7, 2, 3, 5}, {2, 3, 0, 1}); + TF_ASSERT_OK_AND_ASSIGN(const Literal output_literal, + ExecuteAndTransfer(computation, {&input_data}, + &shape_with_output_layout)); // Since the reshape is a no-op, verify that it does not change the underlying // data. 
@@ -829,31 +776,29 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { } } -XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { +TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { XlaBuilder builder(TestName()); auto literal_1x2x3x4 = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()}); + ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {&input}); } -XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { +TEST_P(ReshapeTest, R4ToR4Reshape) { auto literal_1x2x3x4 = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaBuilder builder(TestName()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN( - auto input, CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", - &builder, ¶meter)); + const Literal input = CreateParameterAndTransferLiteral( + 0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{1, 3, 2, 0}), /*dimensions=*/{2, 4, 3, 1}); @@ -869,10 +814,10 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()}); + ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {&input}); } -XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { +TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {2, 2, 2, 2}; @@ -885,9 +830,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN(auto input_data, - CreateParameterAndTransferLiteral( - 0, input_literal, "input", &builder, ¶meter)); + const Literal input_data = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, /*permutation=*/{0, 1, 3, 2}), /*dimensions=*/new_bounds); @@ -897,11 +841,11 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, - zero_error_spec_, &expected.shape()); + ComputeAndCompareLiteral(&builder, expected, {&input_data}, zero_error_spec_, + &expected.shape()); } -XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { +TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {1, 1, 250, 300}; @@ -914,9 +858,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN(auto input_data, - CreateParameterAndTransferLiteral( - 0, input_literal, "input", &builder, ¶meter)); + const Literal input_data = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, {0, 1, 3, 2}), /*dimensions=*/new_bounds); Literal expected = @@ -925,11 +868,11 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, - zero_error_spec_, &expected.shape()); + ComputeAndCompareLiteral(&builder, expected, {&input_data}, zero_error_spec_, + &expected.shape()); } -XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { +TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {5, 5, 1, 10}; @@ -942,9 +885,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN(auto input_data, - CreateParameterAndTransferLiteral( - 0, input_literal, "input", &builder, ¶meter)); + const Literal input_data = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, {0, 1, 3, 2}), /*dimensions=*/new_bounds); Literal expected = @@ -953,11 +895,11 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, - zero_error_spec_, &expected.shape()); + ComputeAndCompareLiteral(&builder, expected, {&input_data}, zero_error_spec_, + &expected.shape()); } -XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { +TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::mt19937 rng; std::uniform_real_distribution distribution; // This happens in NN-Builder MNIST. 
@@ -971,9 +913,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN(auto input_data, - CreateParameterAndTransferLiteral( - 0, input_literal, "input", &builder, ¶meter)); + const Literal input_data = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, {0, 1, 3, 2}), /*dimensions=*/new_bounds); Literal expected = @@ -982,11 +923,11 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, - zero_error_spec_, &expected.shape()); + ComputeAndCompareLiteral(&builder, expected, {&input_data}, zero_error_spec_, + &expected.shape()); } -XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { +TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {3, 3, 1, 3}; @@ -999,9 +940,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; - TF_ASSERT_OK_AND_ASSIGN(auto input_data, - CreateParameterAndTransferLiteral( - 0, input_literal, "input", &builder, ¶meter)); + const Literal input_data = CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter); Reshape(Transpose(parameter, {1, 0, 2, 3}), /*dimensions=*/new_bounds); Literal expected = @@ -1010,8 +950,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, - zero_error_spec_, &expected.shape()); + ComputeAndCompareLiteral(&builder, expected, {&input_data}, zero_error_spec_, + &expected.shape()); } INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, diff --git a/third_party/xla/xla/tests/reverse_test.cc b/third_party/xla/xla/tests/reverse_test.cc index a7991d930c7f85..469d8dc5d59c88 100644 --- a/third_party/xla/xla/tests/reverse_test.cc +++ b/third_party/xla/xla/tests/reverse_test.cc @@ -30,16 +30,22 @@ limitations under the License. 
#include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/client_library_test_runner_utils.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { -static std::array primitive_type_params{F32, BF16, F8E5M2, - F8E4M3FN}; +constexpr std::array kPrimitiveTypeParams{ + F32, + BF16, + F8E5M2, + F8E4M3FN, +}; struct ReverseSpec { std::vector input_dims; @@ -57,7 +63,7 @@ struct ReverseSpec { static std::vector GetTestCases() { // clang-format off return ExpandTestType( - primitive_type_params, + kPrimitiveTypeParams, {{{}, {}}, {{0, 0}, {0, 1}}, {{0, 1}, {0, 1}}, @@ -75,7 +81,7 @@ void PrintTo(const ReverseSpec& spec, std::ostream* os) { *os << spec.ToTestCaseName(); } -class FloatReverseTest : public ClientLibraryTestBase, +class FloatReverseTest : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface { public: FloatReverseTest() { set_float_type(GetParam().test_type); } @@ -86,11 +92,14 @@ TEST_P(FloatReverseTest, Reverses) { std::vector input_vector( ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); std::iota(input_vector.begin(), input_vector.end(), 0.0); - auto r1_literal = LiteralUtil::CreateR1(input_vector); - auto input_literal = r1_literal.Reshape(spec.input_dims).value(); + const Literal r1_literal = LiteralUtil::CreateR1(input_vector); + TF_ASSERT_OK_AND_ASSIGN(const Literal input_literal, + r1_literal.Reshape(spec.input_dims)); + const Literal conv_input_literal = + MaybeConvertLiteralToTestType(input_literal); XlaBuilder builder(TestName()); - auto a = AddParam(input_literal, &builder); + XlaOp a = Parameter(&builder, 0, conv_input_literal.shape(), "input"); Rev(a, spec.reversal); Literal expected = input_literal.Clone(); @@ -105,7 +114,7 @@ TEST_P(FloatReverseTest, Reverses) { } expected.Set(output_indices, value); }); - ComputeAndCompareLiteral(&builder, expected, {}); + ComputeAndCompareLiteral(&builder, expected, {&conv_input_literal}); } INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, @@ -113,10 +122,10 @@ INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, ::testing::PrintToStringParamName()); // A simple test class which not templated by float precision. -class ReverseTest : public ClientLibraryTestBase {}; +using ReverseTest = ClientLibraryTestRunnerMixin; // Tests the reverse operation on a 4D U8 array on dimension 0 and 3. -XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { +TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { XlaBuilder b(TestName()); // Input shape is U8[1x2x3x4]. // clang-format off diff --git a/third_party/xla/xla/tests/scalar_computations_test.cc b/third_party/xla/xla/tests/scalar_computations_test.cc index 664e8adc37f5b8..d71ea615a9032d 100644 --- a/third_party/xla/xla/tests/scalar_computations_test.cc +++ b/third_party/xla/xla/tests/scalar_computations_test.cc @@ -14,33 +14,37 @@ limitations under the License. 
diff --git a/third_party/xla/xla/tests/scalar_computations_test.cc b/third_party/xla/xla/tests/scalar_computations_test.cc
index 664e8adc37f5b8..d71ea615a9032d 100644
--- a/third_party/xla/xla/tests/scalar_computations_test.cc
+++ b/third_party/xla/xla/tests/scalar_computations_test.cc
@@ -14,33 +14,37 @@ limitations under the License.
 ==============================================================================*/
 
 #include
+#include
+#include
 #include
 #include
+#include
 #include
+#include
 
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
-#include "xla/client/local_client.h"
+#include "xla/error_spec.h"
 #include "xla/hlo/builder/xla_builder.h"
 #include "xla/hlo/builder/xla_computation.h"
 #include "xla/hlo/testlib/test_helpers.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
-#include "xla/status_macros.h"
-#include "xla/tests/client_library_test_base.h"
+#include "xla/shape_util.h"
+#include "xla/tests/client_library_test_runner_mixin.h"
+#include "xla/tests/hlo_test_base.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/tests/test_macros.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/tsl/platform/test.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/test.h"
 
 namespace xla {
 namespace {
 
-class ScalarComputationsTest : public ClientLibraryTestBase {
- public:
-  ErrorSpec error_spec_{0.0001};
+constexpr ErrorSpec kErrorSpec{0.0001};
 
+class ScalarComputationsTest
+    : public ClientLibraryTestRunnerMixin<HloTestBase> {
  protected:
   // A template for building and running a binary comparison test.
   template
@@ -76,14 +80,14 @@ XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) {
   XlaBuilder builder(TestName());
   ConstantR0<float>(&builder, 2.1f);
 
-  ComputeAndCompareR0<float>(&builder, 2.1f, {}, error_spec_);
+  ComputeAndCompareR0<float>(&builder, 2.1f, {}, kErrorSpec);
 }
 
 XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) {
   XlaBuilder builder(TestName());
   Neg(ConstantR0<float>(&builder, 2.1f));
 
-  ComputeAndCompareR0<float>(&builder, -2.1f, {}, error_spec_);
+  ComputeAndCompareR0<float>(&builder, -2.1f, {}, kErrorSpec);
 }
 
 XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) {
@@ -97,7 +101,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
   XlaBuilder builder(TestName());
   Add(ConstantR0<float>(&builder, 2.1f), ConstantR0<float>(&builder, 5.5f));
 
-  ComputeAndCompareR0<float>(&builder, 7.6f, {}, error_spec_);
+  ComputeAndCompareR0<float>(&builder, 7.6f, {}, kErrorSpec);
 }
 
 XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
@@ -150,7 +154,7 @@ XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
   XlaBuilder builder(TestName());
   Sub(ConstantR0<float>(&builder, 2.1f), ConstantR0<float>(&builder, 5.5f));
 
-  ComputeAndCompareR0<float>(&builder, -3.4f, {}, error_spec_);
+  ComputeAndCompareR0<float>(&builder, -3.4f, {}, kErrorSpec);
 }
 
 XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
@@ -167,10 +171,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
   int64_t value = 3LL << 35;
   Literal a_literal = LiteralUtil::CreateR0<int64_t>(value);
-  std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(a_literal).value();
-  ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
-                             {a_data.get()});
+  ComputeAndCompareR0<float>(&builder, static_cast<float>(value), {&a_literal});
 }
 
 XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
@@ -178,7 +179,7 @@
   Mul(Mul(ConstantR0<float>(&builder, 2.1f), ConstantR0<float>(&builder, 5.5f)),
       ConstantR0<float>(&builder, 0.5f));
 
-  ComputeAndCompareR0<float>(&builder, 5.775f, {}, error_spec_);
+  ComputeAndCompareR0<float>(&builder, 5.775f, {}, kErrorSpec);
 }
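Replacing the per-fixture member error_spec_ with a file-scope constant is a small but deliberate design choice: a constexpr ErrorSpec carries no fixture state, so free helper functions and parameterized subclasses in the same file can use one tolerance without inheriting it. A short sketch, with the value taken from the hunk above:

// File-scope tolerance, usable from any test or helper in this file.
constexpr ErrorSpec kErrorSpec{0.0001};

// Usage inside any test body, exactly as in the migrated hunks:
ComputeAndCompareR0<float>(&builder, 7.6f, {}, kErrorSpec);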
 XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF64) {
@@ -240,16 +241,9 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
 XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
   XlaBuilder builder(TestName());
-  Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
-  Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
-  Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);
-
-  std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(a_literal).value();
-  std::unique_ptr<GlobalData> b_data =
-      client_->TransferToServer(b_literal).value();
-  std::unique_ptr<GlobalData> c_data =
-      client_->TransferToServer(c_literal).value();
+  const Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
+  const Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
+  const Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);
 
   XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
   XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b");
@@ -257,22 +251,21 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
   Mul(Mul(a, b), c);
 
   ComputeAndCompareR0<float>(&builder, 5.775f,
-                             {a_data.get(), b_data.get(), c_data.get()},
-                             error_spec_);
+                             {&a_literal, &b_literal, &c_literal}, kErrorSpec);
 }
 
 XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
   XlaBuilder builder(TestName());
   Div(ConstantR0<float>(&builder, 5.0f), ConstantR0<float>(&builder, 2.5f));
 
-  ComputeAndCompareR0<float>(&builder, 2.0f, {}, error_spec_);
+  ComputeAndCompareR0<float>(&builder, 2.0f, {}, kErrorSpec);
 }
 
 XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
   XlaBuilder builder(TestName());
   Rem(ConstantR0<float>(&builder, 2.5f), ConstantR0<float>(&builder, 5.0f));
 
-  ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
+  ComputeAndCompareR0<float>(&builder, 2.5f, {}, kErrorSpec);
 }
 
 struct DivS32Params {
@@ -287,7 +280,7 @@ void PrintTo(const DivS32Params& p, std::ostream* os) {
      << p.remainder << "}";
 }
 
-class DivS32Test : public ClientLibraryTestBase,
+class DivS32Test : public ClientLibraryTestRunnerMixin<HloTestBase>,
                    public ::testing::WithParamInterface<DivS32Params> {};
 
 XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
@@ -319,8 +312,7 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) {
   CreateR0Parameter<int32_t>(p.divisor, 1, "divisor", &builder, &divisor);
   Div(dividend, divisor);
 
-  ComputeAndCompareR0<int32_t>(&builder, p.quotient,
-                               {dividendd.get(), divisord.get()});
+  ComputeAndCompareR0<int32_t>(&builder, p.quotient, {&dividendd, &divisord});
 }
 
 XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) {
@@ -334,8 +326,7 @@ XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) {
   CreateR0Parameter<int32_t>(p.divisor, 1, "divisor", &builder, &divisor);
   Rem(dividend, divisor);
 
-  ComputeAndCompareR0<int32_t>(&builder, p.remainder,
-                               {dividendd.get(), divisord.get()});
+  ComputeAndCompareR0<int32_t>(&builder, p.remainder, {&dividendd, &divisord});
 }
 
 INSTANTIATE_TEST_CASE_P(
@@ -389,19 +380,15 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
   for (uint32_t divisor : vals) {
     if (divisor != 0) {
       for (uint32_t dividend : vals) {
-        auto dividend_literal = LiteralUtil::CreateR0<uint32_t>(dividend);
-        auto divisor_literal = LiteralUtil::CreateR0<uint32_t>(divisor);
-        TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
-                                client_->TransferToServer(dividend_literal));
-        TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
-                                client_->TransferToServer(divisor_literal));
-        auto actual_literal =
-            client_
-                ->ExecuteAndTransfer(div_computation,
-                                     {dividend_data.get(), divisor_data.get()},
-                                     &execution_options_)
-                .value();
-        auto expected_literal =
+        const Literal dividend_literal =
+            LiteralUtil::CreateR0<uint32_t>(dividend);
+        const Literal divisor_literal =
+            LiteralUtil::CreateR0<uint32_t>(divisor);
+        TF_ASSERT_OK_AND_ASSIGN(
+            const Literal actual_literal,
+            ExecuteAndTransfer(div_computation,
+                               {&dividend_literal, &divisor_literal}));
+        const Literal expected_literal =
             LiteralUtil::CreateR0<uint32_t>(dividend / divisor);
         EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
       }
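A note on ownership implied by the hunks above: CreateR0Parameter now hands back the backing Literal by value instead of leaving a server-side handle behind, so the execution argument list is built from the addresses of locals that must outlive the call. A condensed, hedged sketch of that calling convention (the quotient check mirrors DivideTwoScalarsNonConstS32; the exact return type is assumed from the call sites):

XlaOp dividend, divisor;
const Literal dividend_literal =
    CreateR0Parameter<int32_t>(p.dividend, 0, "dividend", &builder, &dividend);
const Literal divisor_literal =
    CreateR0Parameter<int32_t>(p.divisor, 1, "divisor", &builder, &divisor);
Div(dividend, divisor);
// The two literals back parameters 0 and 1 of the computation.
ComputeAndCompareR0<int32_t>(&builder, p.quotient,
                             {&dividend_literal, &divisor_literal});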
@@ -431,19 +418,15 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
   for (uint32_t divisor : vals) {
     if (divisor != 0) {
       for (uint32_t dividend : vals) {
-        auto dividend_literal = LiteralUtil::CreateR0<uint32_t>(dividend);
-        auto divisor_literal = LiteralUtil::CreateR0<uint32_t>(divisor);
-        TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
-                                client_->TransferToServer(dividend_literal));
-        TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
-                                client_->TransferToServer(divisor_literal));
-        auto actual_literal =
-            client_
-                ->ExecuteAndTransfer(rem_computation,
-                                     {dividend_data.get(), divisor_data.get()},
-                                     &execution_options_)
-                .value();
-        auto expected_literal =
+        const Literal dividend_literal =
+            LiteralUtil::CreateR0<uint32_t>(dividend);
+        const Literal divisor_literal =
+            LiteralUtil::CreateR0<uint32_t>(divisor);
+        TF_ASSERT_OK_AND_ASSIGN(
+            const Literal actual_literal,
+            ExecuteAndTransfer(rem_computation,
+                               {&dividend_literal, &divisor_literal}));
+        const Literal expected_literal =
             LiteralUtil::CreateR0<uint32_t>(dividend % divisor);
         EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
       }
@@ -457,8 +440,7 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
   Rem(x, ConstantR0<int32_t>(&builder, 80000));
 
   Literal literal = LiteralUtil::CreateR0<int32_t>(87919);
-  TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
-  ComputeAndCompareR0<int32_t>(&builder, 7919, {input_data.get()});
+  ComputeAndCompareR0<int32_t>(&builder, 7919, {&literal});
 }
 
 XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) {
@@ -577,7 +559,7 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) {
          ConstantR0<float>(&builder, 123.0f),  // The value on true.
          ConstantR0<float>(&builder, 42.0f));  // The value on false.
 
-  ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
+  ComputeAndCompareR0<float>(&builder, 123.0f, {}, kErrorSpec);
 }
 
 XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
@@ -586,7 +568,7 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
          ConstantR0<float>(&builder, 123.0f),  // The value on true.
          ConstantR0<float>(&builder, 42.0f));  // The value on false.
- ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 42.0f, {}, kErrorSpec); } // This test is an explicit version of what is happening in the following @@ -716,42 +698,42 @@ XLA_TEST_F(ScalarComputationsTest, ExpScalar) { XlaBuilder builder(TestName()); Exp(ConstantR0(&builder, 2.0f)); - ComputeAndCompareR0(&builder, 7.3890562, {}, error_spec_); + ComputeAndCompareR0(&builder, 7.3890562, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, LogScalar) { XlaBuilder builder("log"); Log(ConstantR0(&builder, 2.0f)); - ComputeAndCompareR0(&builder, 0.6931471, {}, error_spec_); + ComputeAndCompareR0(&builder, 0.6931471, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, TanhScalar) { XlaBuilder builder(TestName()); Tanh(ConstantR0(&builder, 2.0f)); - ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); + ComputeAndCompareR0(&builder, 0.96402758, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) { XlaBuilder builder(TestName()); Tanh(ConstantR0(&builder, 2.0)); - ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); + ComputeAndCompareR0(&builder, 0.96402758, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, PowScalar) { XlaBuilder builder(TestName()); Pow(ConstantR0(&builder, 2.0f), ConstantR0(&builder, 3.0f)); - ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); + ComputeAndCompareR0(&builder, 8.0, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, CbrtScalar) { XlaBuilder builder(TestName()); Cbrt(ConstantR0(&builder, 2.0f)); - ComputeAndCompare(&builder, {}, error_spec_); + ComputeAndCompare(&builder, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { @@ -814,7 +796,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { ConstantR0(&builder, 5.0f), // The operand to be clamped. ConstantR0(&builder, 3.0f)); // The upper bound. - ComputeAndCompareR0(&builder, 3.0, {}, error_spec_); + ComputeAndCompareR0(&builder, 3.0, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { @@ -823,7 +805,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { ConstantR0(&builder, 2.5f), // The operand to be clamped. ConstantR0(&builder, 3.0f)); // The upper bound. - ComputeAndCompareR0(&builder, 2.5, {}, error_spec_); + ComputeAndCompareR0(&builder, 2.5, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { @@ -832,7 +814,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { ConstantR0(&builder, -5.0f), // The operand to be clamped. ConstantR0(&builder, 3.0f)); // The upper bound. 
- ComputeAndCompareR0(&builder, 2.0, {}, error_spec_); + ComputeAndCompareR0(&builder, 2.0, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { @@ -906,7 +888,7 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { ConstantR0(&b, 4)), ConstantR0(&b, 20)); - ComputeAndCompareR0(&b, 0.5, {}, error_spec_); + ComputeAndCompareR0(&b, 0.5, {}, kErrorSpec); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { @@ -924,7 +906,7 @@ XLA_TEST_F(ScalarComputationsTest, RoundScalar) { XlaBuilder builder(TestName()); Round(ConstantR0(&builder, 1.4f)); - ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 1.0f, {}, kErrorSpec); } } // namespace diff --git a/third_party/xla/xla/tests/select_and_scatter_test.cc b/third_party/xla/xla/tests/select_and_scatter_test.cc index 111f6cb662bdcb..cab2edef3838ad 100644 --- a/third_party/xla/xla/tests/select_and_scatter_test.cc +++ b/third_party/xla/xla/tests/select_and_scatter_test.cc @@ -33,10 +33,11 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/reference_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { @@ -50,7 +51,7 @@ struct SelectAndScatterTestParam { }; class SelectAndScatterTest - : public ClientLibraryTestBase, + : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface { public: SelectAndScatterTest() : builder_(TestName()) { @@ -90,11 +91,11 @@ class SelectAndScatterTest XlaComputation min_f32_; }; -XLA_TEST_P(SelectAndScatterTest, OVERSIZE_ON_GRM(ParamTest)) { DoIt(); } +TEST_P(SelectAndScatterTest, OVERSIZE_ON_GRM(ParamTest)) { DoIt(); } class SelectAndScatterLarge : public SelectAndScatterTest {}; -XLA_TEST_P(SelectAndScatterLarge, DISABLED_ON_ISS(OVERSIZE_ON_GRM(ParamTest))) { +TEST_P(SelectAndScatterLarge, DISABLED_ON_ISS(OVERSIZE_ON_GRM(ParamTest))) { DoIt(); } @@ -229,7 +230,7 @@ INSTANTIATE_TEST_CASE_P( {3000}, {1701}, Padding::kValid, {1300}, {1}})); // Test for F32 1D array, with a zero-element input. -XLA_TEST_F(SelectAndScatterTest, R1S0F32) { +TEST_F(SelectAndScatterTest, R1S0F32) { const auto operand = ConstantR1(&builder_, {}); const auto source = ConstantR1(&builder_, {}); SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, @@ -239,7 +240,7 @@ XLA_TEST_F(SelectAndScatterTest, R1S0F32) { } // Test for F32 1D array, when windows do not overlap. -XLA_TEST_F(SelectAndScatterTest, R1F32) { +TEST_F(SelectAndScatterTest, R1F32) { const auto operand = ConstantR1(&builder_, {1.f, 9.f, 3.f, 7.f, 5.f, 6.f}); const auto source = ConstantR1(&builder_, {34.f, 42.f}); @@ -251,7 +252,7 @@ XLA_TEST_F(SelectAndScatterTest, R1F32) { } // Test for S32 1D array, when windows do not overlap and the init value is 1. -XLA_TEST_F(SelectAndScatterTest, R1S32) { +TEST_F(SelectAndScatterTest, R1S32) { const auto operand = ConstantR1(&builder_, {-1, 0, 6, 4, -4, 10}); const auto source = ConstantR1(&builder_, {-10, 20}); const std::vector expected = {1, 1, -9, 1, 1, 21}; @@ -262,7 +263,7 @@ XLA_TEST_F(SelectAndScatterTest, R1S32) { } // Test for S32 1D array, when windows overlap with each other. 
-XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) { +TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) { const auto operand = ConstantR1(&builder_, {1, 9, 3, 7, 5, 6}); const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const std::vector expected = {0, 76, 0, 72, 0, 0}; @@ -273,7 +274,7 @@ XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) { } // Test for S32 2D array, when windows do not overlap. -XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2S32)) { +TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2S32)) { const auto operand = ConstantR2(&builder_, {{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}); const auto source = ConstantR2(&builder_, {{2, 6}}); @@ -286,7 +287,7 @@ XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2S32)) { // Test for tie breaking rule in ge_f32_. When a tie is present, the operand // that has the lower lexicographical order (smaller index) should be chosen. -XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2F32Tie)) { +TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2F32Tie)) { const auto operand = ConstantR2( &builder_, {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); const auto source = ConstantR2( @@ -300,7 +301,7 @@ XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2F32Tie)) { } // Similar to SelectAndScatterTest.R2S32 but the input is transposed. -XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(ReshapeR2S32)) { +TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(ReshapeR2S32)) { const auto operand = ConstantR2( &builder_, {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}}); const auto reshape = Reshape(Transpose(operand, /*permutation=*/{1, 0}), @@ -314,7 +315,7 @@ XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(ReshapeR2S32)) { } // Test for S32 2D array, when windows overlap with each other. -XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2S32OverlappingWindow)) { +TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2S32OverlappingWindow)) { const auto operand = ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); const auto source = ConstantR2(&builder_, {{2, 6, 4}}); @@ -326,7 +327,7 @@ XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2S32OverlappingWindow)) { } // Test for S32 2D array, when the padding is Padding::kSAME. -XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2S32SamePadding)) { +TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2S32SamePadding)) { const auto operand = ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); const auto source = ConstantR2(&builder_, {{2, 6, 4}}); @@ -339,8 +340,8 @@ XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2S32SamePadding)) { // Test for S32 2D array, when the padding is Padding::kSAME and windows overlap // with each other. -XLA_TEST_F(SelectAndScatterTest, - DISABLED_ON_TPU(R2S32SamePaddingOverlappingWindow)) { +TEST_F(SelectAndScatterTest, + DISABLED_ON_TPU(R2S32SamePaddingOverlappingWindow)) { const auto operand = ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); const auto source = @@ -352,7 +353,7 @@ XLA_TEST_F(SelectAndScatterTest, ComputeAndCompareR2(&builder_, expected, {}); } -XLA_TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2F32OverlappingR2Source)) { +TEST_F(SelectAndScatterTest, DISABLED_ON_TPU(R2F32OverlappingR2Source)) { const auto operand = ConstantR2( &builder_, {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}}); const auto source = @@ -464,7 +465,7 @@ TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) { } // Test for F32 4D array with negative padding on both ends. 
-XLA_TEST_F(SelectAndScatterTest, R4NegativePaddingOnBothEnds) { +TEST_F(SelectAndScatterTest, R4NegativePaddingOnBothEnds) { Array2D pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 3.0f}, {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f}, {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f}, @@ -491,7 +492,7 @@ XLA_TEST_F(SelectAndScatterTest, R4NegativePaddingOnBothEnds) { } // Test for F32 4D array with positive low padding and negative high padding. -XLA_TEST_F(SelectAndScatterTest, R4PositivePaddingLowAndNegativePaddingHigh) { +TEST_F(SelectAndScatterTest, R4PositivePaddingLowAndNegativePaddingHigh) { Array2D pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 3.0f}, {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f}, {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f}, @@ -518,7 +519,7 @@ XLA_TEST_F(SelectAndScatterTest, R4PositivePaddingLowAndNegativePaddingHigh) { } // Test for F32 4D array with negative low padding and positive high padding. -XLA_TEST_F(SelectAndScatterTest, R4NegativePaddingLowAndPositivePaddingHigh) { +TEST_F(SelectAndScatterTest, R4NegativePaddingLowAndPositivePaddingHigh) { Array2D pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 3.0f}, {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f}, {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f}, @@ -544,7 +545,7 @@ XLA_TEST_F(SelectAndScatterTest, R4NegativePaddingLowAndPositivePaddingHigh) { ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } -XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) { +TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) { const auto operand = ConstantR1(&builder_, {1, 2, 3, 100, 3, 2, 1}); const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const std::vector expected = {0, 0, 0, 53, 0, 0, 0}; @@ -554,7 +555,7 @@ XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) { ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } -XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { +TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { const auto operand = ConstantR1(&builder_, {1, 2, 3, 100, 3, 2, 1}); const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const float max_float = std::numeric_limits::max(); diff --git a/third_party/xla/xla/tests/select_test.cc b/third_party/xla/xla/tests/select_test.cc index 4c16ee26d9bae0..02339bf6fe1a5a 100644 --- a/third_party/xla/xla/tests/select_test.cc +++ b/third_party/xla/xla/tests/select_test.cc @@ -13,24 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include +#include #include -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/literal.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class SelectTest : public ClientLibraryTestBase { - public: - ErrorSpec error_spec_{0.0001}; -}; +constexpr ErrorSpec kErrorSpec{0.0001}; + +using SelectTest = ClientLibraryTestRunnerMixin; TEST_F(SelectTest, SelectScalarF32True) { XlaBuilder builder(TestName()); @@ -39,7 +39,7 @@ TEST_F(SelectTest, SelectScalarF32True) { auto on_false = ConstantR0(&builder, 42.0f); Select(pred, on_true, on_false); - ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 123.0f, {}, kErrorSpec); } TEST_F(SelectTest, SelectScalarS32True) { @@ -59,7 +59,7 @@ TEST_F(SelectTest, SelectScalarF32False) { auto on_false = ConstantR0(&builder, 42.0f); Select(pred, on_true, on_false); - ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 42.0f, {}, kErrorSpec); } XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { @@ -69,7 +69,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { auto on_false = ConstantR1(&builder, {}); Select(pred, on_true, on_false); - ComputeAndCompareR1(&builder, {}, {}, error_spec_); + ComputeAndCompareR1(&builder, {}, {}, kErrorSpec); } TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { @@ -82,7 +82,7 @@ TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, - error_spec_); + kErrorSpec); } XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { @@ -96,7 +96,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { auto on_false = ConstantR1(&builder, {}); Select(cmp, on_true, on_false); - ComputeAndCompareR1(&builder, {}, {}, error_spec_); + ComputeAndCompareR1(&builder, {}, {}, kErrorSpec); } TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { @@ -113,7 +113,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, - error_spec_); + kErrorSpec); } TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { @@ -129,7 +129,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {}, - error_spec_); + kErrorSpec); } TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { @@ -138,18 +138,17 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { XlaBuilder builder(TestName()); XlaOp v1, v2; - std::unique_ptr param0_data = CreateR1Parameter( + const Literal param0_data = CreateR1Parameter( {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); - std::unique_ptr param1_data = CreateR1Parameter( + const Literal param1_data = CreateR1Parameter( {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); auto cmp = Gt(v1, v2); Select(cmp, v1, v2); ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, - {param0_data.get(), param1_data.get()}, - 
error_spec_); + {&param0_data, &param1_data}, kErrorSpec); } TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { @@ -182,18 +181,17 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { } XlaOp v1, v2; - std::unique_ptr param0_data = + const Literal param0_data = CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); - std::unique_ptr param1_data = + const Literal param1_data = CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); auto cmp = Gt(v1, v2); Select(cmp, v1, v2); ComputeAndCompareR1(&builder, expected_vec, - {param0_data.get(), param1_data.get()}, - error_spec_); + {&param0_data, &param1_data}, kErrorSpec); } TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { @@ -210,7 +208,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {}, - error_spec_); + kErrorSpec); } TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { @@ -227,7 +225,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {}, - error_spec_); + kErrorSpec); } XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { @@ -238,7 +236,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { auto on_false = ConstantR1(&builder, {}); Select(pred, on_true, on_false); - ComputeAndCompareR1(&builder, {}, {}, error_spec_); + ComputeAndCompareR1(&builder, {}, {}, kErrorSpec); } } @@ -249,7 +247,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); Select(pred, on_true, on_false); - ComputeAndCompareR1(&builder, {-2.5f, 25.5f}, {}, error_spec_); + ComputeAndCompareR1(&builder, {-2.5f, 25.5f}, {}, kErrorSpec); } TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { @@ -259,7 +257,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); Select(pred, on_true, on_false); - ComputeAndCompareR1(&builder, {10.0f, 5.0f}, {}, error_spec_); + ComputeAndCompareR1(&builder, {10.0f, 5.0f}, {}, kErrorSpec); } TEST_F(SelectTest, SelectR1S4WithConstantR1PRED) { diff --git a/third_party/xla/xla/tests/slice_test.cc b/third_party/xla/xla/tests/slice_test.cc index 91f5989438b2df..44df437edbef07 100644 --- a/third_party/xla/xla/tests/slice_test.cc +++ b/third_party/xla/xla/tests/slice_test.cc @@ -15,8 +15,12 @@ limitations under the License. // Tests that slice operations can be performed. +#include +#include +#include +#include #include -#include #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" @@ -24,18 +28,22 @@ limitations under the License.
#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/array2d.h" -#include "xla/client/local_client.h" +#include "xla/array3d.h" +#include "xla/array4d.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/layout_util.h" +#include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class SliceTest : public ClientLibraryTestBase {}; +using SliceTest = ClientLibraryTestRunnerMixin; TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { Array3D values(3, 3, 3); @@ -208,7 +216,7 @@ struct R1Spec { // Parameterized test that generates R1 values, slices them according // to the R1Spec, and compares the result with a computed version. -class SliceR1Test : public ClientLibraryTestBase, +class SliceR1Test : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface { protected: template @@ -233,9 +241,7 @@ class SliceR1Test : public ClientLibraryTestBase, expected.push_back(i); } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(literal)); - ComputeAndCompareR1(&builder, expected, {arg.get()}); + ComputeAndCompareR1(&builder, expected, {&literal}); } }; @@ -397,7 +403,7 @@ struct R2Spec { // Parameterized test that generates patterned R2 values, slices them according // to the R2Spec, and compares the results with the ReferenceUtil version. -class SliceR2Test : public ClientLibraryTestBase, +class SliceR2Test : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface {}; XLA_TEST_P(SliceR2Test, DoIt) { @@ -411,11 +417,9 @@ XLA_TEST_P(SliceR2Test, DoIt) { auto a = Parameter(&builder, 0, literal.shape(), "p0"); Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputeAndCompareR2(&builder, *expected, {arg.get()}); + ComputeAndCompareR2(&builder, *expected, {&literal}); } INSTANTIATE_TEST_CASE_P( @@ -488,7 +492,7 @@ std::string R4SpecToString(const ::testing::TestParamInfo& data) { "__strides_", absl::StrJoin(spec.slice_strides, "x")); } -class SliceR4Test : public ClientLibraryTestBase, +class SliceR4Test : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface { protected: void Run(const R4Spec& spec) { @@ -501,10 +505,8 @@ class SliceR4Test : public ClientLibraryTestBase, auto literal = LiteralUtil::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); auto parameter = Parameter(&builder, 0, literal.shape(), "p0"); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(literal)); Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); + ComputeAndCompareR4(&builder, *expected, {&literal}, ErrorSpec(0.000001)); } }; diff --git a/third_party/xla/xla/tests/transpose_test.cc b/third_party/xla/xla/tests/transpose_test.cc index ffd1a5c6156faf..28abe480abfca4 100644 --- a/third_party/xla/xla/tests/transpose_test.cc +++ b/third_party/xla/xla/tests/transpose_test.cc @@ -19,49 +19,74 @@ limitations under the 
License. #include #include "xla/array2d.h" +#include "xla/array3d.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class TransposeTest : public ClientLibraryTestBase { - public: - ErrorSpec error_spec_{0.0001}; +constexpr ErrorSpec kErrorSpec{0.0001}; +class TransposeTest : public ClientLibraryTestRunnerMixin { protected: - void TestTransposeConstant(Vector3 sizes, Vector3 transpose_dims); + void TestTransposeConstant(Vector3 sizes, Vector3 transpose_dims) { + Array3D aoperand(sizes[0], sizes[1], sizes[2]); + std::vector expected(sizes[0] * sizes[1] * sizes[2]); + for (int64_t i = 0; i < sizes[0]; ++i) { + for (int64_t j = 0; j < sizes[1]; ++j) { + for (int64_t k = 0; k < sizes[2]; ++k) { + Vector3 indices{i, j, k}; + aoperand(i, j, k) = (i * sizes[1] + j) * sizes[2] + k; + expected[(indices[transpose_dims[0]] * sizes[transpose_dims[1]] + + indices[transpose_dims[1]]) * + sizes[transpose_dims[2]] + + indices[transpose_dims[2]]] = aoperand(i, j, k); + } + } + } + + XlaBuilder builder(TestName()); + auto operand = ConstantR3FromArray3D(&builder, aoperand); + auto transpose = Transpose(operand, transpose_dims); + // Add a reshape so that the transpose does not disappear during layout + // assignment. + Reshape(transpose, {sizes[0] * sizes[1] * sizes[2]}); + + ComputeAndCompareR1(&builder, expected, {}); + } }; -XLA_TEST_F(TransposeTest, Transpose0x0) { +TEST_F(TransposeTest, Transpose0x0) { XlaBuilder builder("Transpose"); auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 0)); Transpose(lhs, {1, 0}); - ComputeAndCompareR2(&builder, Array2D(0, 0), {}, error_spec_); + ComputeAndCompareR2(&builder, Array2D(0, 0), {}, kErrorSpec); } -XLA_TEST_F(TransposeTest, Transpose0x42) { +TEST_F(TransposeTest, Transpose0x42) { XlaBuilder builder("Transpose"); auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 42)); Transpose(lhs, {1, 0}); - ComputeAndCompareR2(&builder, Array2D(42, 0), {}, error_spec_); + ComputeAndCompareR2(&builder, Array2D(42, 0), {}, kErrorSpec); } -XLA_TEST_F(TransposeTest, Transpose7x0) { +TEST_F(TransposeTest, Transpose7x0) { XlaBuilder builder("Transpose"); auto lhs = ConstantR2FromArray2D(&builder, Array2D(7, 0)); Transpose(lhs, {1, 0}); - ComputeAndCompareR2(&builder, Array2D(0, 7), {}, error_spec_); + ComputeAndCompareR2(&builder, Array2D(0, 7), {}, kErrorSpec); } TEST_F(TransposeTest, Transpose2x2) { @@ -74,10 +99,10 @@ TEST_F(TransposeTest, Transpose2x2) { Array2D expected({{1.0f, 3.0f}, {2.0f, 4.0f}}); - ComputeAndCompareR2(&builder, expected, {}, error_spec_); + ComputeAndCompareR2(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { +TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { XlaBuilder builder("Transpose"); auto operand = ConstantR3FromArray3D(&builder, Array3D(0, 2, 3)); @@ -130,7 +155,7 @@ TEST_F(TransposeTest, MultiTranspose3x2) { computed = Transpose(computed, {1, 0}); } const Array2D& expected = transposes % 2 == 0 ? 
input : transposed; - ComputeAndCompareR2(&builder, expected, {}, error_spec_); + ComputeAndCompareR2(&builder, expected, {}, kErrorSpec); } } @@ -158,33 +183,6 @@ TEST_F(TransposeTest, Small_2x2) { ComputeAndCompareR2(&builder, *expected, {}, ErrorSpec(1e-4)); } -void TransposeTest::TestTransposeConstant(Vector3 sizes, - Vector3 transpose_dims) { - Array3D aoperand(sizes[0], sizes[1], sizes[2]); - std::vector expected(sizes[0] * sizes[1] * sizes[2]); - for (int64_t i = 0; i < sizes[0]; ++i) { - for (int64_t j = 0; j < sizes[1]; ++j) { - for (int64_t k = 0; k < sizes[2]; ++k) { - Vector3 indices{i, j, k}; - aoperand(i, j, k) = (i * sizes[1] + j) * sizes[2] + k; - expected[(indices[transpose_dims[0]] * sizes[transpose_dims[1]] + - indices[transpose_dims[1]]) * - sizes[transpose_dims[2]] + - indices[transpose_dims[2]]] = aoperand(i, j, k); - } - } - } - - XlaBuilder builder(TestName()); - auto operand = ConstantR3FromArray3D(&builder, aoperand); - auto transpose = Transpose(operand, transpose_dims); - // Add a reshape so that the transpose does not disappear during layout - // assignment. - Reshape(transpose, {sizes[0] * sizes[1] * sizes[2]}); - - ComputeAndCompareR1(&builder, expected, {}); -} - TEST_F(TransposeTest, TransposeConstant021_SingleIncompleteTilePerLayer) { TestTransposeConstant({2, 16, 17}, {0, 2, 1}); } @@ -204,8 +202,8 @@ TEST_F(TransposeTest, TransposeConstant210_DegenerateDim) { using HloTransposeTest = HloTestBase; // Disable HLO passes to verify the default behavior -XLA_TEST_F(HloTransposeTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( - DISABLED_ON_TPU(HloPassesDisabled)))) { +TEST_F(HloTransposeTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( + DISABLED_ON_TPU(HloPassesDisabled)))) { const char* const kModuleStr = R"( HloModule Transpose diff --git a/third_party/xla/xla/tests/triangular_solve_test.cc b/third_party/xla/xla/tests/triangular_solve_test.cc index a2e6334f69f99c..d5e95eca8c7345 100644 --- a/third_party/xla/xla/tests/triangular_solve_test.cc +++ b/third_party/xla/xla/tests/triangular_solve_test.cc @@ -13,33 +13,41 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include +#include +#include +#include +#include #include -#include "absl/status/statusor.h" +#include "absl/log/check.h" #include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/array.h" #include "xla/array2d.h" -#include "xla/hlo/builder/lib/math.h" +#include "xla/array3d.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" #include "xla/literal.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tests/test_macros.h" -#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/literal_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla { namespace { -using TriangularSolveTest = ClientLibraryTestBase; -using TriangularSolveLeftLookingTest = ClientLibraryTestBase; +constexpr float kNan = std::numeric_limits::quiet_NaN(); +constexpr complex64 kNanC64 = complex64(kNan, kNan); -static constexpr float kNan = std::numeric_limits::quiet_NaN(); +using TriangularSolveTest = ClientLibraryTestRunnerMixin; +using TriangularSolveLeftLookingTest = + ClientLibraryTestRunnerMixin; Array2D AValsLower() { return {{2, kNan, kNan, kNan}, @@ -77,8 +85,6 @@ Array2D BValsLeft() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -static constexpr complex64 kNanC64 = complex64(kNan, kNan); - Array2D AValsLowerComplex() { return {{2, kNanC64, kNanC64, kNanC64}, {complex64(3, 1), 6, kNanC64, kNanC64}, @@ -101,7 +107,7 @@ Array2D BValsLeftComplex() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -XLA_TEST_F(TriangularSolveTest, EmptyArrays) { +TEST_F(TriangularSolveTest, EmptyArrays) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -115,10 +121,10 @@ XLA_TEST_F(TriangularSolveTest, EmptyArrays) { /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); ComputeAndCompareR2(&builder, Array2D(0, 10), - {a_data.get(), b_data.get()}); + {&a_data, &b_data}); } -XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { +TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -135,11 +141,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { {4.5, -0.58333331, -0.32407406, -0.23569024}, }); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { +TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -156,11 +162,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { +TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -177,11 +183,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, 
+ ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { +TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -198,11 +204,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { {4.5, -0.58333331, -0.32407406, -0.23569024}, }); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { +TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -220,11 +226,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { {0.90909091, 1., 1.09090909}, }); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { +TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -242,11 +248,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { {0.16835017, 0.13468013, 0.1010101}, }); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) { +TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -261,11 +267,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) { Array2D expected( {{1., 2., 3.}, {1., -1., -3.}, {-4., 7., 18.}, {37., -61., -159.}}); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { +TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -283,11 +289,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { {0.16835017, 0.13468013, 0.1010101}, }); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { +TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -305,11 +311,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { {0.16835017, 0.13468013, 0.1010101}, }); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { +TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -327,11 +333,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { {0.90909091, 1., 1.09090909}, }); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) { +TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) { XlaBuilder builder(TestName()); XlaOp a, b; @@ 
-348,11 +354,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) { {-93., -102., -111.}, {10., 11., 12.}}); - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { +TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -374,11 +380,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { complex64(0.11026936, -0.03114478)}, }); - ComputeAndCompareR2( - &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, + ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { +TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { XlaBuilder builder(TestName()); XlaOp a, b; @@ -402,11 +408,11 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { complex64(0.15798226, 5.12749446e-01)}, }); - ComputeAndCompareR2( - &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2(&builder, expected, {&a_data, &b_data}, + ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) { +TEST_F(TriangularSolveTest, BatchedLeftUpper) { XlaBuilder builder(TestName()); Array3D bvals(7, 5, 5); @@ -430,7 +436,7 @@ XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) { /*unit_diagonal=*/false, /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE)); - ComputeAndCompareR3(&builder, bvals, {a_data.get(), b_data.get()}, + ComputeAndCompareR3(&builder, bvals, {&a_data, &b_data}, ErrorSpec(1e-2, 1e-2)); } @@ -442,13 +448,13 @@ struct TriangularSolveTestSpec { }; class TriangularSolveParametricTest - : public ClientLibraryTestBase, + : public ClientLibraryTestRunnerMixin, public ::testing::WithParamInterface {}; -XLA_TEST_P(TriangularSolveParametricTest, Random) { +TEST_P(TriangularSolveParametricTest, Random) { TriangularSolveTestSpec spec = GetParam(); - if (client_->backend() + if (backend() .default_stream_executor() ->GetDeviceDescription() .cuda_compute_capability() @@ -480,12 +486,13 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) { Array bvals(b_dims); bvals.FillRandom(1.0); - XlaOp a, b; - auto a_data = CreateParameter(avals, 0, "a", &builder, &a); - auto b_data = CreateParameter(bvals, 1, "b", &builder, &b); - auto x = TriangularSolve(a, b, spec.left_side, spec.lower, - /*unit_diagonal=*/false, spec.transpose_a); - auto a_tri = Triangle(a, spec.lower); + const Literal avals_literal = LiteralUtil::CreateFromArray(avals); + const Literal bvals_literal = LiteralUtil::CreateFromArray(bvals); + XlaOp a = Parameter(&builder, 0, avals_literal.shape(), "a"); + XlaOp b = Parameter(&builder, 1, bvals_literal.shape(), "b"); + XlaOp x = TriangularSolve(a, b, spec.left_side, spec.lower, + /*unit_diagonal=*/false, spec.transpose_a); + XlaOp a_tri = Triangle(a, spec.lower); a_tri = MaybeTransposeInMinorDims( a_tri, spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE); if (spec.left_side) { @@ -494,7 +501,8 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) { BatchDot(x, a_tri, xla::PrecisionConfig::HIGHEST); } - ComputeAndCompare(&builder, bvals, {a_data.get(), b_data.get()}, + ComputeAndCompareLiteral(&builder, LiteralUtil::CreateFromArray(bvals), + {&avals_literal, &bvals_literal}, ErrorSpec(3e-2, 3e-2)); } diff --git 
a/third_party/xla/xla/tests/tuple_test.cc b/third_party/xla/xla/tests/tuple_test.cc index 5469ac9ba48f56..d2f7fd5ff1bbee 100644 --- a/third_party/xla/xla/tests/tuple_test.cc +++ b/third_party/xla/xla/tests/tuple_test.cc @@ -13,36 +13,40 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include +#include +#include #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/array2d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" +#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/test_macros.h" -#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class TupleTest : public ClientLibraryTestBase { - public: - ErrorSpec error_spec_{0.0001}; -}; +constexpr ErrorSpec kErrorSpec{0.0001}; + +using TupleTest = ClientLibraryTestRunnerMixin; // Tests a tuple-shaped constant. -XLA_TEST_F(TupleTest, TupleConstant) { +TEST_F(TupleTest, TupleConstant) { XlaBuilder builder(TestName()); const float constant_scalar = 7.3f; @@ -57,11 +61,11 @@ XLA_TEST_F(TupleTest, TupleConstant) { LiteralUtil::CreateR2(constant_matrix)}); ConstantLiteral(&builder, value); - ComputeAndCompareTuple(&builder, value, {}, error_spec_); + ComputeAndCompareTuple(&builder, value, {}, kErrorSpec); } // Tests a tuple made of scalar constants. -XLA_TEST_F(TupleTest, TupleScalarConstant) { +TEST_F(TupleTest, TupleScalarConstant) { XlaBuilder builder(TestName()); const float constant_scalar1 = 7.3f; @@ -71,11 +75,11 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { LiteralUtil::CreateR0(constant_scalar2)}); ConstantLiteral(&builder, value); - ComputeAndCompareTuple(&builder, value, {}, error_spec_); + ComputeAndCompareTuple(&builder, value, {}, kErrorSpec); } // Tests the creation of tuple data. -XLA_TEST_F(TupleTest, TupleCreate) { +TEST_F(TupleTest, TupleCreate) { XlaBuilder builder(TestName()); const float constant_scalar = 7.3f; @@ -92,11 +96,11 @@ XLA_TEST_F(TupleTest, TupleCreate) { {LiteralUtil::CreateR0(constant_scalar), LiteralUtil::CreateR1(constant_vector), LiteralUtil::CreateR2(constant_matrix)}); - ComputeAndCompareTuple(&builder, expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, kErrorSpec); } // Tests the creation of tuple data. -XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { +TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { XlaBuilder builder(TestName()); Tuple(&builder, @@ -104,19 +108,19 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR0(7.0), LiteralUtil::CreateR1({})}); - ComputeAndCompareTuple(&builder, expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, kErrorSpec); } // Tests the creation of an empty tuple. 
-XLA_TEST_F(TupleTest, EmptyTupleCreate) { +TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); Tuple(&builder, {}); auto expected = LiteralUtil::MakeTuple({}); - ComputeAndCompareTuple(&builder, expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, kErrorSpec); } // Trivial test for extracting a tuple element with GetTupleElement. -XLA_TEST_F(TupleTest, GetTupleElement) { +TEST_F(TupleTest, GetTupleElement) { XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { @@ -128,21 +132,21 @@ XLA_TEST_F(TupleTest, GetTupleElement) { ConstantR2(&builder, constant_matrix)}); GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(constant_matrix), {}, - error_spec_); + kErrorSpec); } // Trivial test for extracting a tuple element with GetTupleElement. -XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) { +TEST_F(TupleTest, GetTupleElementWithZeroElements) { XlaBuilder builder(TestName()); auto tuple_data = Tuple(&builder, {ConstantR1(&builder, {}), ConstantR2FromArray2D(&builder, Array2D(0, 101))}); GetTupleElement(tuple_data, 1); - ComputeAndCompareR2(&builder, Array2D(0, 101), {}, error_spec_); + ComputeAndCompareR2(&builder, Array2D(0, 101), {}, kErrorSpec); } -XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { +TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { XlaBuilder builder(TestName()); auto value = ConstantR1(&builder, {4.5f}); GetTupleElement(value, 1); @@ -155,7 +159,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { // Extracts both elements from a tuple with GetTupleElement and then adds them // together. -XLA_TEST_F(TupleTest, AddTupleElements) { +TEST_F(TupleTest, AddTupleElements) { XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { @@ -179,12 +183,12 @@ XLA_TEST_F(TupleTest, AddTupleElements) { ASSERT_TRUE(ShapeUtil::Equal(vector_shape, ShapeUtil::MakeShape(F32, {3}))); ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3}))); - ComputeAndCompareR2(&builder, expected, {}, error_spec_); + ComputeAndCompareR2(&builder, expected, {}, kErrorSpec); } // Extracts both elements from a tuple and then puts them into a new tuple in // the opposite order. -XLA_TEST_F(TupleTest, TupleGTEToTuple) { +TEST_F(TupleTest, TupleGTEToTuple) { XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { @@ -199,13 +203,12 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR2(constant_matrix), LiteralUtil::CreateR1(constant_vector)}); - ComputeAndCompareTuple(&builder, expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, kErrorSpec); } - // Builds two new tuples from an existing tuple (by means of GetTupleElement), // then adds up the components of the new tuples. 
-XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { +TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { // // v------ --(GTE 0)-- --(GTE 0)---------- // \ / \ / \ @@ -248,10 +251,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { {4.f, 8.f, 12.f}, // row 0 {10.f, 14.f, 18.f}, // row 1 }); - ComputeAndCompareR2(&builder, expected, {}, error_spec_); + ComputeAndCompareR2(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(TupleTest, NestedTuples) { +TEST_F(TupleTest, NestedTuples) { XlaBuilder builder(TestName()); auto inner_tuple = Tuple(&builder, {ConstantR1(&builder, {1.0, 2.0}), ConstantR0(&builder, 42.0)}); @@ -264,10 +267,10 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2}); - ComputeAndCompareTuple(&builder, expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { +TEST_F(TupleTest, GetTupleElementOfNestedTuple) { XlaBuilder builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {3}); @@ -280,23 +283,18 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { auto gte1 = GetTupleElement(gte0, 1); Add(gte1, ConstantR1(&builder, {10.0, 11.0, 12.0})); - std::unique_ptr data = - client_ - ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ - LiteralUtil::MakeTupleFromSlices({ - LiteralUtil::CreateR1({1.0, 2.0, 3.0}), - LiteralUtil::CreateR1({4.0, 5.0, 6.0}), - }), - LiteralUtil::CreateR1({7.0, 8.0, 9.0}), - })) - .value(); - - std::vector arguments = {data.get()}; + const Literal data = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1.0, 2.0, 3.0}), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}), + }), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}), + }); const std::vector expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0}; - ComputeAndCompareR1(&builder, expected, arguments, ErrorSpec(1e-5)); + ComputeAndCompareR1(&builder, expected, {&data}, ErrorSpec(1e-5)); } -XLA_TEST_F(TupleTest, ComplexTuples) { +TEST_F(TupleTest, ComplexTuples) { XlaBuilder builder(TestName()); { Shape c64r0 = ShapeUtil::MakeShape(C64, {}); @@ -316,23 +314,16 @@ XLA_TEST_F(TupleTest, ComplexTuples) { ConstantR0(&builder, {123, 456})}); } - std::unique_ptr arg0 = - client_ - ->TransferToServer(LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::CreateR0({1, 2}), - LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::CreateR1({{10, 20}, {30, 40}}), - LiteralUtil::CreateR2( - {{{100, 200}, {300, 400}}, - {{1000, 2000}, {3000, 4000}}, - {{10000, 20000}, {30000, 40000}}})})})) - .value(); - std::unique_ptr arg1 = - client_ - ->TransferToServer( - LiteralUtil::CreateR1({{1, 2}, {1, -2}})) - .value(); - auto sum = + const Literal arg0 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0({1, 2}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({{10, 20}, {30, 40}}), + LiteralUtil::CreateR2( + {{{100, 200}, {300, 400}}, + {{1000, 2000}, {3000, 4000}}, + {{10000, 20000}, {30000, 40000}}})})}); + const Literal arg1 = LiteralUtil::CreateR1({{1, 2}, {1, -2}}); + const Literal sum = LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); @@ -345,16 +336,15 @@ XLA_TEST_F(TupleTest, ComplexTuples) { : complex64(1, -2)); }) .ok()); - auto expected = LiteralUtil::MakeTupleFromSlices( + const Literal expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::MakeTupleFromSlices({prod, sum}), 
LiteralUtil::CreateR0({123, 456})}); - ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()}, - error_spec_); + ComputeAndCompareTuple(&builder, expected, {&arg0, &arg1}, kErrorSpec); } -class TupleHloTest : public HloTestBase {}; +using TupleHloTest = HloTestBase; -XLA_TEST_F(TupleHloTest, BadTupleShapeFailsGracefully) { +TEST_F(TupleHloTest, BadTupleShapeFailsGracefully) { const char* testcase = R"( HloModule m, is_scheduled=true @@ -374,7 +364,7 @@ XLA_TEST_F(TupleHloTest, BadTupleShapeFailsGracefully) { EXPECT_THAT(status.message(), ::testing::HasSubstr("actual shape is")); } -XLA_TEST_F(TupleHloTest, BitcastAfterGTE) { +TEST_F(TupleHloTest, BitcastAfterGTE) { const char* testcase = R"( HloModule m, is_scheduled=true diff --git a/third_party/xla/xla/tests/unary_op_test.cc b/third_party/xla/xla/tests/unary_op_test.cc index aa1cf98d8509fc..017e39f08bf478 100644 --- a/third_party/xla/xla/tests/unary_op_test.cc +++ b/third_party/xla/xla/tests/unary_op_test.cc @@ -13,21 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include +#include +#include -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tests/test_macros.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class UnaryOpTest : public ClientLibraryTestBase { +class UnaryOpTest : public ClientLibraryTestRunnerMixin { protected: template T inf() { @@ -131,19 +134,19 @@ void UnaryOpTest::SignAbsTestHelper() { ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } -XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { +TEST_F(UnaryOpTest, AbsTestR1Size0) { AbsSize0TestHelper(); AbsSize0TestHelper(); AbsSize0TestHelper(); } -XLA_TEST_F(UnaryOpTest, AbsTestR1) { +TEST_F(UnaryOpTest, AbsTestR1) { AbsTestHelper(); AbsTestHelper(); AbsTestHelper(); } -XLA_TEST_F(UnaryOpTest, AbsTestR0) { +TEST_F(UnaryOpTest, AbsTestR0) { XlaBuilder builder(TestName()); auto argi = ConstantR0(&builder, -5); auto absi = Abs(argi); @@ -158,7 +161,7 @@ XLA_TEST_F(UnaryOpTest, AbsTestR0) { ComputeAndCompareR0(&builder, 8.5f, {}); } -XLA_TEST_F(UnaryOpTest, SignTestR0) { +TEST_F(UnaryOpTest, SignTestR0) { XlaBuilder builder(TestName()); auto argi = ConstantR0(&builder, -5); auto sgni = Sign(argi); // -1 @@ -175,20 +178,20 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } -XLA_TEST_F(UnaryOpTest, SignTestR1) { +TEST_F(UnaryOpTest, SignTestR1) { SignTestHelper(); SignTestHelper(); SignTestHelper(); SignTestHelper(); } -XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { +TEST_F(UnaryOpTest, SignAbsTestR1) { SignAbsTestHelper(); SignAbsTestHelper(); SignAbsTestHelper(); } -XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { +TEST_F(UnaryOpTest, SignAbsTestR2) { XlaBuilder builder(TestName()); auto arg = ConstantR2(&builder, {{1.0, -2.0}, {-3.0, 4.0}}); auto sign = Sign(arg); @@ -198,7 +201,7 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); } -XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { 
+TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {0, 1}); auto rhs = ConstantR1(&builder, {1, 1}); @@ -207,7 +210,7 @@ XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { ComputeAndCompareR1(&builder, {0, 1}, {}); } -XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { +TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {0, 1}); auto rhs = ConstantR1(&builder, {1, 1}); diff --git a/third_party/xla/xla/tests/vector_ops_reduce_test.cc b/third_party/xla/xla/tests/vector_ops_reduce_test.cc index f35beb32f78fde..834e8dcc867db6 100644 --- a/third_party/xla/xla/tests/vector_ops_reduce_test.cc +++ b/third_party/xla/xla/tests/vector_ops_reduce_test.cc @@ -13,25 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class VecOpsReduceTest : public ClientLibraryTestBase { +constexpr ErrorSpec kErrorSpec{1e-3, 0}; + +class VecOpsReduceTest : public ClientLibraryTestRunnerMixin { public: VecOpsReduceTest() : builder_(TestName()) {} @@ -50,7 +51,6 @@ class VecOpsReduceTest : public ClientLibraryTestBase { } XlaBuilder builder_; - ErrorSpec errspec_{1e-3, 0}; }; TEST_F(VecOpsReduceTest, AddReduceR1F32) { @@ -61,7 +61,7 @@ TEST_F(VecOpsReduceTest, AddReduceR1F32) { Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, /*dimensions_to_reduce=*/{0}); - ComputeAndCompareR0(&builder_, -4.2f, {}, errspec_); + ComputeAndCompareR0(&builder_, -4.2f, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { @@ -75,7 +75,7 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { /*dimensions_to_reduce=*/{0}); float expected = std::accumulate(input.begin(), input.end(), 0.0f); - ComputeAndCompareR0(&builder_, expected, {}, errspec_); + ComputeAndCompareR0(&builder_, expected, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, MaxReduceR1F32) { @@ -86,7 +86,7 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32) { Reduce(x, ConstantR0(&builder_, 0.0f), max_reducer, /*dimensions_to_reduce=*/{0}); - ComputeAndCompareR0(&builder_, 2.6f, {}, errspec_); + ComputeAndCompareR0(&builder_, 2.6f, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) { @@ -97,7 +97,7 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) { Reduce(x, ConstantR0(&builder_, 4.0f), max_reducer, /*dimensions_to_reduce=*/{0}); - ComputeAndCompareR0(&builder_, 4.0f, {}, errspec_); + ComputeAndCompareR0(&builder_, 4.0f, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) { @@ -113,7 +113,7 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) { Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, /*dimensions_to_reduce=*/{1}); - ComputeAndCompareR1(&builder_, {6.0, 15.0}, {}, errspec_); + ComputeAndCompareR1(&builder_, {6.0, 15.0}, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { @@ 
-127,7 +127,7 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, /*dimensions_to_reduce=*/{0}); - ComputeAndCompareR1(&builder_, {5.0, 7.0, 9.0}, {}, errspec_); + ComputeAndCompareR1(&builder_, {5.0, 7.0, 9.0}, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { @@ -138,7 +138,7 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { Array2D expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}}); - ComputeAndCompareR2(&builder_, expected_array, {}, errspec_); + ComputeAndCompareR2(&builder_, expected_array, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { @@ -150,7 +150,7 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { Array2D expected_array( {{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}}); - ComputeAndCompareR2(&builder_, expected_array, {}, errspec_); + ComputeAndCompareR2(&builder_, expected_array, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { @@ -161,7 +161,7 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { Array2D expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}}); - ComputeAndCompareR2(&builder_, expected_array, {}, errspec_); + ComputeAndCompareR2(&builder_, expected_array, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { @@ -170,7 +170,7 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, /*dimensions_to_reduce=*/{1, 2}); - ComputeAndCompareR1(&builder_, {21.0, 21.0, 21.0}, {}, errspec_); + ComputeAndCompareR1(&builder_, {21.0, 21.0, 21.0}, {}, kErrorSpec); } XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { @@ -179,7 +179,7 @@ XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, /*dimensions_to_reduce=*/{0, 2}); - ComputeAndCompareR1(&builder_, {18.0, 45.0}, {}, errspec_); + ComputeAndCompareR1(&builder_, {18.0, 45.0}, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { @@ -188,7 +188,7 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, /*dimensions_to_reduce=*/{0, 1}); - ComputeAndCompareR1(&builder_, {15.0, 21.0, 27.0}, {}, errspec_); + ComputeAndCompareR1(&builder_, {15.0, 21.0, 27.0}, {}, kErrorSpec); } TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { @@ -197,7 +197,7 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, /*dimensions_to_reduce=*/{0, 1, 2}); - ComputeAndCompareR0(&builder_, 63.0, {}, errspec_); + ComputeAndCompareR0(&builder_, 63.0, {}, kErrorSpec); } } // namespace diff --git a/third_party/xla/xla/tests/vector_ops_simple_test.cc b/third_party/xla/xla/tests/vector_ops_simple_test.cc index e9defab005758f..1b9b78a70c0b81 100644 --- a/third_party/xla/xla/tests/vector_ops_simple_test.cc +++ b/third_party/xla/xla/tests/vector_ops_simple_test.cc @@ -13,41 +13,40 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include +#include #include -#include #include #include #include "absl/status/statusor.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/testlib/test_helpers.h" #include "xla/shape_util.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/client_library_test_runner_mixin.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { -class VecOpsSimpleTest : public ClientLibraryTestBase { +constexpr ErrorSpec kErrorSpec{0.0001}; + +class VecOpsSimpleTest : public ClientLibraryTestRunnerMixin { public: - explicit VecOpsSimpleTest(se::Platform* platform = nullptr) - : ClientLibraryTestBase(platform) { + VecOpsSimpleTest() { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); } - - ErrorSpec error_spec_{0.0001}; }; -XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { +TEST_F(VecOpsSimpleTest, ExpTenValues) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); @@ -57,10 +56,10 @@ XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { 8.1662, 9.9742, 6.7379e-03, 4.0657e-01, 9.0718e-02, 4.9530}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { +TEST_F(VecOpsSimpleTest, ExpManyValues) { for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) { XlaBuilder builder(TestName()); std::vector exponents; @@ -82,7 +81,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { } } -XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { +TEST_F(VecOpsSimpleTest, ExpIn4D) { XlaBuilder builder(TestName()); Array4D exponents(2, 2, 2, 2); @@ -107,7 +106,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3)); } -XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { +TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); @@ -115,10 +114,10 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { std::vector expected = {-2.1, 2.6, -2.6, 4.0, -2.1, -2.3, 5.0, 0.9, 2.4, -1.6}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { +TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); Neg(x); @@ -127,7 +126,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { +TEST_F(VecOpsSimpleTest, NegateUint32Values) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {0, 1, 42, static_cast(-1), static_cast(-12)}); @@ -137,7 +136,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { +TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { XlaBuilder builder(TestName()); auto 
x = ConstantR1(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); @@ -146,10 +145,10 @@ XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { std::vector expected = {.25, 1, .03125, 2.5, 2.23607, .009000, .900025}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { +TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); @@ -161,10 +160,10 @@ std::vector expected = {1.7, -3.2, -0.4, -3.8, 5.9, 0.1, -6.8, 4., -1., 2.2}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); + ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } -XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) { +TEST_F(VecOpsSimpleTest, MaxTenValues) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); @@ -177,25 +176,24 @@ ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { +TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { // Similar to MaxTenValues, except that the inputs come from params rather // than constants. XlaBuilder builder(TestName()); XlaOp v1, v2; - std::unique_ptr param0_data = CreateR1Parameter( + const Literal param0_data = CreateR1Parameter( {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); - std::unique_ptr param1_data = CreateR1Parameter( + const Literal param1_data = CreateR1Parameter( {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); Max(v1, v2); ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, - {param0_data.get(), param1_data.get()}, - error_spec_); + {&param0_data, &param1_data}, kErrorSpec); } -XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { +TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { // Similar to MaxTenValuesFromParams, except that the data size passed in and // out is large.
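The param0_data/param1_data rewrites in these hunks all follow the argument-passing change that comes with the runner mixin: CreateR1Parameter now returns a host-side Literal rather than a GlobalData handle transferred to the service, and the ComputeAndCompare* helpers take Literal pointers directly. A minimal before/after sketch under that assumption, with illustrative values and names:

    // The same computation and expected values in both versions; only the
    // argument plumbing differs.
    const Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});

    // Before: explicit transfer, compare against a GlobalData handle.
    std::unique_ptr<GlobalData> arg =
        client_->TransferToServer(literal).value();
    ComputeAndCompareR1<float>(&builder, expected, {arg.get()});

    // After: the test runner consumes the host-side Literal directly.
    ComputeAndCompareR1<float>(&builder, expected, {&literal});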
XlaBuilder builder(TestName()); @@ -225,20 +223,19 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { } XlaOp v1, v2; - std::unique_ptr param0_data = + const Literal param0_data = CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); - std::unique_ptr param1_data = + const Literal param1_data = CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); Max(v1, v2); ComputeAndCompareR1(&builder, expected_vec, - {param0_data.get(), param1_data.get()}, - error_spec_); + {&param0_data, &param1_data}, kErrorSpec); } -XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { +TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); @@ -250,7 +247,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { +TEST_F(VecOpsSimpleTest, MinTenValues) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); @@ -263,7 +260,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { +TEST_F(VecOpsSimpleTest, MinMaxTenValues) { XlaBuilder builder(TestName()); auto zero = ConstantR0(&builder, 0); auto one = ConstantR0(&builder, 1); @@ -276,7 +273,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { +TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { XlaBuilder builder(TestName()); auto zero = ConstantR0(&builder, 0); auto one = ConstantR0(&builder, 1); @@ -289,7 +286,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { +TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { XlaBuilder builder(TestName()); auto zero = ConstantR1(&builder, {0.0f, 0.0f}); auto one = ConstantR1(&builder, {1.0f, 1.0f}); @@ -300,7 +297,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { +TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XlaBuilder builder(TestName()); auto one = ConstantR0(&builder, 1); auto two = ConstantR0(&builder, 2); @@ -313,7 +310,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) { +TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto low = ConstantR1(&builder, {NAN, 1, 1}); @@ -324,7 +321,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) { ComputeAndCompareR1(&builder, {false, false, false}, {}); } -XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { +TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { XlaBuilder builder(TestName()); auto zero = ConstantR0(&builder, 0); auto one = ConstantR0(&builder, 10); @@ -335,7 +332,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { +TEST_F(VecOpsSimpleTest, MapTenValues) { XlaComputation add_half; { // add_half(x) = x + 0.5 @@ -391,7 +388,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { +TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); auto y = ConstantR0(&builder, 3); @@ -401,7 +398,7 @@ XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { +TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {false, true}); auto y = ConstantR1(&builder, {true, false}); @@ -411,7 +408,7 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { +TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {false, true}); auto y = ConstantR1(&builder, {true, false}); @@ -421,7 +418,7 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, CbrtSevenValues) { +TEST_F(VecOpsSimpleTest, CbrtSevenValues) { XlaBuilder builder(TestName()); float inf = std::numeric_limits::infinity(); float qnan = std::numeric_limits::quiet_NaN(); From e450e986edd576614b424da03e425bb4bbf09d41 Mon Sep 17 00:00:00 2001 From: Theotime Combes Date: Wed, 16 Apr 2025 13:59:05 -0700 Subject: [PATCH 0886/1324] [XLA:GPU] Add simple triton support test for infeed/outfeed ops PiperOrigin-RevId: 748414319 --- .../backends/gpu/codegen/triton/support.cc | 2 - .../gpu/codegen/triton/support_test.cc | 50 +++++++++++++++++-- 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index 0d6f445476b2e0..d4cfaffdd5a026 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -635,9 +635,7 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kGather: case HloOpcode::kGetTupleElement: - case HloOpcode::kInfeed: case HloOpcode::kMap: - case HloOpcode::kOutfeed: case HloOpcode::kPad: case HloOpcode::kRaggedDot: case HloOpcode::kRecv: diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index edf16a0115fa93..f1ef2ef5e2247d 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -242,7 +242,7 @@ class TritonSupportTest : public TritonSupportTestBase { root_instruction->shape().tuple_shapes_size()); for (int64_t i = 0; i < output_tile_sizes.size(); ++i) { const auto& shape = root_instruction->shape().tuple_shapes(i); - if (shape.IsTuple()) { + if (shape.IsTuple() || shape.IsToken()) { continue; // No validation for nested tuples, as there is no way to // specify output tile sizes for them. 
} @@ -2696,6 +2696,50 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(AllDevicesToTest())), TritonSupportTestTypeAndDeviceToString); +using InfeedTest = TritonSupportTestWithTypeAndDeviceParam; + +TEST_P(InfeedTest, Infeed) { + auto [data_type, cc] = GetParam(); + const std::string kHloTestTemplate = R"( + ENTRY triton_computation { + token0 = token[] after-all() + ROOT infeed_op = ($0[10], token[]) infeed(token0) + })"; + TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, + ParseTemplateAndGetInstruction( + kHloTestTemplate, data_type, HloOpcode::kInfeed)); + RunSupportTestMultipleOutputTiles(std::move(ti), + /*output_tile_sizes=*/{{1}, {}}, cc); +} + +INSTANTIATE_TEST_SUITE_P( + InfeedSuite, InfeedTest, + ::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()), + ::testing::ValuesIn(AllDevicesToTest())), + TritonSupportTestTypeAndDeviceToString); + +using OutfeedTest = TritonSupportTestWithTypeAndDeviceParam; + +TEST_P(OutfeedTest, Outfeed) { + auto [data_type, cc] = GetParam(); + const std::string kHloTestTemplate = R"( + ENTRY triton_computation { + data = $0[10] parameter(0) + token0 = token[] after-all() + ROOT outfeed_op = token[] outfeed(data, token0) + })"; + TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( + kHloTestTemplate, data_type, + HloOpcode::kOutfeed)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{}, cc); +} + +INSTANTIATE_TEST_SUITE_P( + OutfeedSuite, OutfeedTest, + ::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()), + ::testing::ValuesIn(AllDevicesToTest())), + TritonSupportTestTypeAndDeviceToString); + constexpr std::array kUnsupportedOps = { // clang-format off // go/keep-sorted start @@ -2705,9 +2749,7 @@ constexpr std::array kUnsupportedOps = { HloOpcode::kDynamicUpdateSlice, HloOpcode::kGather, HloOpcode::kGetTupleElement, - HloOpcode::kInfeed, HloOpcode::kMap, - HloOpcode::kOutfeed, HloOpcode::kPad, HloOpcode::kRaggedDot, HloOpcode::kRecv, @@ -2769,6 +2811,8 @@ absl::flat_hash_set AllTestedOpcodes() { ret.emplace(HloOpcode::kRngGetAndUpdateState); ret.emplace(HloOpcode::kWhile); ret.emplace(HloOpcode::kFusion); + ret.emplace(HloOpcode::kInfeed); + ret.emplace(HloOpcode::kOutfeed); ret.insert(kUnsupportedOps.begin(), kUnsupportedOps.end()); return ret; From f8643ea5a61e1c7b0cb0c9a37e57e13784a3c320 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Wed, 16 Apr 2025 14:07:27 -0700 Subject: [PATCH 0887/1324] Remove auto sort util because it is not useful here. We don't want to add 50+ new lines, ok if ops are not sorted here. 
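For reference, op ordering in this list has no functional effect: ConversionTarget::addLegalOp keys legality by op name, so the template-argument order is purely cosmetic. A minimal sketch (illustrative only, not part of this change; assumes the usual MLIR and StableHLO headers):

#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

// Registering ops in any order produces the same ConversionTarget state.
void markLegalInAnyOrder(mlir::ConversionTarget& target) {
  target.addLegalOp<mlir::stablehlo::AddOp, mlir::stablehlo::AbsOp>();
}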
PiperOrigin-RevId: 748417724 --- .../stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc index 9112c4c6876e09..2909ada7898381 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc @@ -66,7 +66,6 @@ struct AddDependencyOpToMhoTokenConverter void legalDirectStablehloToHloConversionOps(ConversionTarget& target) { target.addLegalOp< - // go/keep-sorted start stablehlo::AbsOp, stablehlo::CbrtOp, stablehlo::SqrtOp, stablehlo::TanOp, stablehlo::AddOp, stablehlo::AddOp, stablehlo::AllGatherOp, stablehlo::Atan2Op, stablehlo::BroadcastInDimOp, stablehlo::BroadcastOp, @@ -83,7 +82,6 @@ void legalDirectStablehloToHloConversionOps(ConversionTarget& target) { stablehlo::ShiftLeftOp, stablehlo::ShiftRightArithmeticOp, stablehlo::ShiftRightLogicalOp, stablehlo::SubtractOp, stablehlo::SignOp, stablehlo::SineOp, stablehlo::SliceOp, stablehlo::TanhOp - // go/keep-sorted end >(); } From 73bbfd083b164187dd732bfc1681395435640c7f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 14:18:34 -0700 Subject: [PATCH 0888/1324] This change proactively relaxes floating point comparisons in unit tests to avoid breaking them when updating the Eigen library in cl/747497592. Background: The Eigen update contains a number of improvements to the use of fused-multiply-and-add instructions for float, bfloat16, and float16, as well as changes that may alter summation order in matrix multiplication. While such changes only cause minor numerical changes locally (relative changes on the order of 2^-23, 2^-10, 2^-7, for float, float16, and bfloat16, respectively), they may cause sensitive ("ill-conditioned") computations to deviate more significantly, as those small changes propagate.
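To make the quoted magnitudes concrete, here is a standalone sketch (illustrative only; NearlyEqual is a hypothetical stand-in for the ArrayFloatNear-style matchers used in the diff below) showing how a one-ulp float perturbation breaks bitwise equality while staying well inside a relative tolerance:

#include <cmath>
#include <cstdio>

// Relative-tolerance comparison, scaled by the expected value's magnitude.
bool NearlyEqual(float expected, float actual, float rel_tol) {
  return std::fabs(actual - expected) <= rel_tol * std::fabs(expected);
}

int main() {
  const float kRelErrF32 = 0x1.0p-23f;    // ~1.2e-7, one part in 2^23
  const float kRelErrF16 = 0x1.0p-10f;    // ~9.8e-4, one part in 2^10
  const float kRelErrBf16 = 0x1.0p-7f;    // ~7.8e-3, one part in 2^7
  const float expected = 2.75f;  // an average-pool output from the tests below
  const float actual = expected * (1.0f + kRelErrF32);  // FMA-induced drift
  std::printf("bitwise equal: %d\n", actual == expected);         // prints 0
  std::printf("within 1e-5:   %d\n",
              NearlyEqual(expected, actual, /*rel_tol=*/1e-5f));  // prints 1
  (void)kRelErrF16;
  (void)kRelErrBf16;
  return 0;
}

The same reasoning applies with proportionally larger tolerances for float16 and bfloat16 inputs.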
PiperOrigin-RevId: 748422021 --- tensorflow/lite/kernels/pooling_test.cc | 36 ++++++++++++------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tensorflow/lite/kernels/pooling_test.cc b/tensorflow/lite/kernels/pooling_test.cc index 401bee6a242470..4634c5ce5e7adf 100644 --- a/tensorflow/lite/kernels/pooling_test.cc +++ b/tensorflow/lite/kernels/pooling_test.cc @@ -147,7 +147,7 @@ TEST(FloatPoolingOpTest, AveragePool) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {2.75, 5.75})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({2.75, 5.75}))); } TEST(FloatPoolingOpTest, AveragePoolActivationRelu) { @@ -161,7 +161,7 @@ TEST(FloatPoolingOpTest, AveragePoolActivationRelu) { 3, 2, -10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {0.0, 0.75})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.0, 0.75}))); } TEST(FloatPoolingOpTest, AveragePoolActivationRelu1) { @@ -175,14 +175,14 @@ TEST(FloatPoolingOpTest, AveragePoolActivationRelu1) { -3, -2, -10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {-1.0, 0.75})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-1.0, 0.75}))); m.SetInput({ 0, -6, -2, -4, // -3, -2, 10, -7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {-1.0, -0.75})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-1.0, -0.75}))); } TEST(FloatPoolingOpTest, AveragePoolActivationRelu6) { @@ -196,14 +196,14 @@ TEST(FloatPoolingOpTest, AveragePoolActivationRelu6) { -3, -2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {0.0, 6.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.0, 6.0}))); m.SetInput({ 0, 6, 12, 4, // 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {2.75, 6.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({2.75, 6.0}))); } TEST(FloatPoolingOpTest, AveragePoolPaddingSameStride1) { @@ -217,9 +217,8 @@ TEST(FloatPoolingOpTest, AveragePoolPaddingSameStride1) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT( - m.GetOutput(), - Pointwise(FloatingPointEq(), {2.75, 5.0, 5.75, 5.5, 2.5, 6.0, 8.5, 7.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {2.75, 5.0, 5.75, 5.5, 2.5, 6.0, 8.5, 7.0}))); } TEST(FloatPoolingOpTest, AveragePoolPaddingValidStride1) { @@ -233,7 +232,8 @@ TEST(FloatPoolingOpTest, AveragePoolPaddingValidStride1) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {2.75, 5.0, 5.75})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({2.75, 5.0, 5.75}))); } TEST(QuantizedPoolingOpTest, AveragePool) { @@ -643,7 +643,7 @@ TEST(FloatPoolingOpTest, MaxPoolActivationRelu) { -3, -2, 10.5, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {0.0, 10.5})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.0, 10.5}))); } TEST(FloatPoolingOpTest, MaxPoolActivationRelu1) { @@ -657,14 +657,14 @@ TEST(FloatPoolingOpTest, MaxPoolActivationRelu1) { -3, -2, -0.3, 0.7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {-1.0, 0.7})); + EXPECT_THAT(m.GetOutput(), 
ElementsAreArray(ArrayFloatNear({-1.0, 0.7}))); m.SetInput({ -2.75, -6, -2, -4, // -3, -2, 10, -7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {-1.0, 1.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-1.0, 1.0}))); } TEST(FloatPoolingOpTest, MaxPoolActivationRelu6) { @@ -678,14 +678,14 @@ TEST(FloatPoolingOpTest, MaxPoolActivationRelu6) { -3, -2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {0.0, 6.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.0, 6.0}))); m.SetInput({ 0, 4.5, 12, 4, // 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {4.5, 6.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({4.5, 6.0}))); } TEST(FloatPoolingOpTest, MaxPoolPaddingSameStride1) { @@ -1063,7 +1063,7 @@ TEST(FloatPoolingOpTest, L2Pool) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {3.5, 6.5})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.5, 6.5}))); } TEST(FloatPoolingOpTest, L2PoolActivationRelu) { @@ -1118,7 +1118,7 @@ TEST(FloatPoolingOpTest, L2PoolPaddingSame) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {3.5, 6.5})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.5, 6.5}))); } TEST(FloatPoolingOpTest, L2PoolPaddingSameSlide1) { @@ -1149,7 +1149,7 @@ TEST(FloatPoolingOpTest, L2PoolPaddingValidSlide1) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {3.5, 6.0, 6.5})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.5, 6.0, 6.5}))); } #if GTEST_HAS_DEATH_TEST From 0f5713f9e95c119e448fc3b3b17414325bc442b9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 16:18:25 -0700 Subject: [PATCH 0889/1324] Add XlaOp for custom combiner BWD pass. 
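The combiner-weight half of this backward pass amounts to reducing the per-sample weight gradients locally, all-reducing them across replicas, and applying a plain SGD step (see EmitTensorCoreComputations in the kernel below). A hedged plain-C++ sketch of that update, with illustrative names only:

#include <cstddef>
#include <vector>

// per_sample_grads mirrors the f32[input_size, num_weights] output of the
// weights-VJP computation; weights is the f32[num_weights] combiner state.
std::vector<float> SgdUpdateCombinerWeights(
    const std::vector<std::vector<float>>& per_sample_grads,
    std::vector<float> weights, float learning_rate) {
  std::vector<float> reduced(weights.size(), 0.0f);
  // Local reduction over the sample dimension.
  for (const std::vector<float>& sample : per_sample_grads) {
    for (std::size_t w = 0; w < reduced.size(); ++w) reduced[w] += sample[w];
  }
  // A real run would additionally all-reduce `reduced` across replicas here.
  for (std::size_t w = 0; w < weights.size(); ++w) {
    weights[w] -= learning_rate * reduced[w];  // SGD step
  }
  return weights;
}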
PiperOrigin-RevId: 748459072 --- .../compiler/mlir/tensorflow/ir/tf_ops.td | 44 +++ .../transforms/legalization_op_config.cc | 2 + .../transforms/legalization_op_config_test.cc | 2 +- ...ulCustomCombinerOnTcGradWithCsrInput.pbtxt | 4 + ...ulCustomCombinerOnTcGradWithCsrInput.pbtxt | 4 + tensorflow/core/tpu/kernels/BUILD | 1 + .../core/tpu/kernels/sparse_core_xla_ops.cc | 364 ++++++++++++++++++ tensorflow/core/tpu/ops/sparse_core_ops.cc | 68 ++++ tensorflow/python/tpu/ops/BUILD | 1 + .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + 11 files changed, 497 insertions(+), 1 deletion(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 007cd3f652439e..1400785bc2d22d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -2849,6 +2849,50 @@ def TF_XlaSparseDenseMatmulGradWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulGradW TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<5>; } +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput", [AttrSizedOperandSegments, Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // The embedding table and the associated slot variables. + Variadic:$tables, + // Hyperparameters of the current optimizer. + Variadic:$hyperparameters, + // Learning rate of the custom combiner weights (using SGD). + TF_Float32Tensor:$combiner_weights_learning_rate, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + SymbolRefAttr:$optimizer_custom_computation, + StrAttr:$table_name + ); + + let results = (outs + Variadic:$updated_tables, + TF_Float32Tensor:$updated_weights + ); + + // Number of embedding table + its associated slot variables. + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<9>; + // Number of hyperparameters. 
+ TF_DerivedOperandSizeAttr M = TF_DerivedOperandSizeAttr<10>; +} + // b/394499589: move back to tf_generated_ops.td def TF_PartitionedCallOp : TF_Op<"PartitionedCall", [CallOpInterface, DeclareOpInterfaceMethods, Pure]> { let summary = [{ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc index 25c894b5785c91..06eec56d61607b 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc @@ -371,6 +371,8 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get< TF::XlaSparseDenseMatmulGradWithSgdAndStaticBufferSizeOp>(), // NOLINT TypeID::get(), + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp>(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index 3b0f27330be9c9..612f23b8d590ef 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -83,7 +83,7 @@ TEST(LegalizationOpConfigTest, CountLoweringsSet) { // from MLIR to TF2XLA), these numbers should change. Or if TF Dialect adds // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 67); - EXPECT_EQ(tf2xla_fallback_count, 324); + EXPECT_EQ(tf2xla_fallback_count, 325); EXPECT_EQ(non_categorized_count, 431); } diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt new file mode 100644 index 00000000000000..ccc3643bbf4345 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt new file mode 100644 index 00000000000000..ccc3643bbf4345 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 4378e884a65010..8d1f8565205b91 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -201,6 +201,7 @@ cc_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:arithmetic", "@local_xla//xla/hlo/builder/lib:slicing", "@local_xla//xla/stream_executor/tpu:c_api_decl", "@local_xla//xla/stream_executor/tpu:tpu_api", diff --git a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc index 560bd16ca94977..5590d08ba00a3a 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/slicing.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" @@ -874,6 +875,369 @@ class XlaSparseDenseMatmulGradWithCsrInputOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaSparseDenseMatmulGradWithCsrInput"), XlaSparseDenseMatmulGradWithCsrInputOp); +class XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp + : public XlaOpKernel { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp( + OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_valency", &max_valency_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_weights", &num_weights_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("N", &num_tables_)); + + // The updated weights is immediately after the updated tables in the output + // list. + updated_weights_index_ = num_tables_; + + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, + ctx->GetAttr("combiner_table_vjp_computation", &name_attr)); + combiner_lookups_custom_vjp_computation_ = *name_attr; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("combiner_weights_vjp_computation", &name_attr)); + combiner_weights_custom_vjp_computation_ = *name_attr; + OP_REQUIRES_OK(ctx, + ctx->GetAttr("optimizer_custom_computation", &name_attr)); + optimizer_custom_computation_ = *name_attr; + } + + absl::StatusOr BuildOptimizerComputation( + XlaOpKernelContext* ctx, absl::Span tables_inputs, + absl::Span hyperparameters_inputs, + int32_t feature_width) { + XlaCompiler::CompileOptions options; + + // We don't use tuple args and always return tuple for this computation. + options.use_tuple_arg = false; + options.always_return_tuple = true; + options.is_entry_computation = false; + + XlaCompiler* compiler = ctx->compiler(); + + XlaCompiler::CompilationResult custom_computation_result; + + // The number of arguments is the number of tables + the number of + // hyperparameters + 1 for the activation gradients. + int32_t num_arguments = + 1 + tables_inputs.size() + hyperparameters_inputs.size(); + + std::vector arguments(num_arguments); + + // For all the arguments, we use the float type and the shape is + // {1, feature_width}. 
+ for (int32_t i = 0; i < num_arguments; ++i) { + arguments[i].kind = XlaCompiler::Argument::kParameter; + arguments[i].type = DT_FLOAT; + arguments[i].shape = + xla::ShapeUtil::MakeShape(xla::F32, {1, feature_width}); + } + + TF_RETURN_IF_ERROR( + compiler->CompileFunction(options, optimizer_custom_computation_, + arguments, &custom_computation_result)); + + return std::move(*custom_computation_result.computation); + } + + std::vector BuildVjpArguments(XlaOpKernelContext* ctx, + int32_t input_size, + int32_t feature_width) { + std::vector arguments; + + XlaCompiler::Argument valencies_arg; + XlaCompiler::Argument vectors_arg; + XlaCompiler::Argument weights_arg; + XlaCompiler::Argument activation_gradients_arg; + + valencies_arg.kind = XlaCompiler::Argument::kParameter; + valencies_arg.type = DT_INT32; + valencies_arg.shape = xla::ShapeUtil::MakeShape(xla::S32, {input_size}); + valencies_arg.name = "valencies"; + + vectors_arg.kind = XlaCompiler::Argument::kParameter; + vectors_arg.type = DT_FLOAT; + vectors_arg.shape = xla::ShapeUtil::MakeShape( + xla::F32, {input_size, max_valency_, feature_width}); + vectors_arg.name = "vectors"; + + weights_arg.kind = XlaCompiler::Argument::kParameter; + weights_arg.type = DT_FLOAT; + weights_arg.shape = + xla::ShapeUtil::MakeShape(xla::F32, {input_size, num_weights_}); + weights_arg.name = "weights"; + arguments.push_back(weights_arg); + + activation_gradients_arg.kind = XlaCompiler::Argument::kParameter; + activation_gradients_arg.type = DT_FLOAT; + activation_gradients_arg.shape = + xla::ShapeUtil::MakeShape(xla::F32, {input_size, feature_width}); + activation_gradients_arg.name = "activation_gradients"; + arguments.push_back(activation_gradients_arg); + + if (num_weights_ > 0) { + arguments = {valencies_arg, vectors_arg, weights_arg, + activation_gradients_arg}; + } else { + // Don't add the weights argument if it's not needed. This helps avoid + // issues of passing around zero-sized tensors and Xla values. + arguments = {valencies_arg, vectors_arg, activation_gradients_arg}; + } + + return arguments; + } + + absl::StatusOr BuildCombinerVjpComputation( + XlaOpKernelContext* ctx, int32_t input_size, int32_t feature_width, + const NameAttrList& computation) { + XlaCompiler::CompileOptions options; + options.use_tuple_arg = false; + options.always_return_tuple = false; + options.is_entry_computation = false; + + XlaCompiler* compiler = ctx->compiler(); + XlaCompiler::CompilationResult vjp_computation_result; + + TF_RETURN_IF_ERROR(compiler->CompileFunction( + options, computation, BuildVjpArguments(ctx, input_size, feature_width), + &vjp_computation_result)); + return std::move(*vjp_computation_result.computation); + } + + xla::XlaOp EmitTensorCoreComputations( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder, + xla::XlaComputation&& combiner_vectors_vjp, + xla::XlaComputation&& combiner_weights_vjp, int32_t input_size) { + xla::XlaOp weights = ctx->Input("weights"); + xla::XlaOp activation_gradients = ctx->Input("activation_gradients"); + xla::XlaOp valencies = ctx->Input("preserved_valencies"); + xla::XlaOp vectors = ctx->Input("preserved_vectors"); + + std::vector vjp_args; + if (num_weights_ > 0) { + xla::XlaOp broadcasted_weights = xla::Broadcast(weights, {input_size}); + vjp_args = {valencies, vectors, broadcasted_weights, + activation_gradients}; + } else { + vjp_args = {valencies, vectors, activation_gradients}; + } + + // Compute the lookup gradients based on the activation gradients. 
This + // result will be passed to SC to drive the embedding table update. + xla::XlaOp lookup_gradients = + xla::Call(builder, combiner_vectors_vjp, vjp_args); + + // Compute the weights gradients based on the activation gradients. + if (num_weights_ > 0) { + // The weights VJP returns a tensor of shape f32[input_size, num_weights]. + xla::XlaOp weights_gradients_all_samples = + xla::Call(builder, combiner_weights_vjp, vjp_args); + // Local reduction, which aggregates the contributions from all samples + // and returns a tensor of shape f32[num_weights]. + xla::XlaOp per_replica_reduced_weights_gradients = xla::Reduce( + weights_gradients_all_samples, xla::ConstantR0(builder, 0.0), + xla::CreateScalarAddComputation(xla::F32, builder), {0}); + // Global reduction, which aggregates the contributions from all replicas + // and returns a tensor of shape f32[num_weights]. + // Here we assume that all replicas participate in the all-reduce (using + // default value of `replica_groups`) and that all-reduce from different + // modules do not participate in this reduction (using default value of + // `channel_id`). + xla::XlaOp global_reduced_weights_gradients = + xla::AllReduce(per_replica_reduced_weights_gradients, + xla::CreateScalarAddComputation(xla::F32, builder)); + // Use SGD optimizer on the weights. + // TODO(peitianpan): Add support for more optimizers. + xla::XlaOp learning_rate = ctx->Input("combiner_weights_learning_rate"); + xla::XlaOp updated_weights = + weights - learning_rate * global_reduced_weights_gradients; + ctx->SetOutput(updated_weights_index_, updated_weights); + } else { + // The caller is not supposed to rely on this output if num_weights is 0. + ctx->SetOutput(updated_weights_index_, + xla::ConstantR0(builder, 0)); + } + + return lookup_gradients; + } + + void EmitSparseCoreComputations( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder, + absl::Span tables_inputs, + absl::Span tables_shapes, + absl::Span hyperparameters_inputs, + xla::XlaOp lookup_gradients, xla::XlaComputation&& optimizer, + int32_t max_ids_per_partition, int32_t max_unique_ids_per_partition) { + xla::XlaOp row_pointers = ctx->Input("row_pointers"); + xla::XlaOp sorted_sample_ids = ctx->Input("sorted_sample_ids"); + xla::XlaOp sorted_token_ids = ctx->Input("sorted_token_ids"); + xla::XlaOp sorted_pos_ids = ctx->Input("sorted_pos_ids"); + xla::XlaOp sorted_gains = ctx->Input("sorted_gains"); + + xla::FrontendAttributes original_frontend_attributes = + builder->frontend_attributes(); + + xla::FrontendAttributes tuple_frontend_attributes; + + tuple_frontend_attributes.mutable_map()->insert( + {"_xla_compute_type", "sparse"}); + + builder->SetFrontendAttributes(tuple_frontend_attributes); + + xla::XlaOp tables = xla::Tuple(ctx->builder(), tables_inputs); + + xla::XlaOp hyperparameters = + xla::Tuple(ctx->builder(), hyperparameters_inputs); + + std::vector xla_tables_shapes; + + xla_tables_shapes.reserve(tables_shapes.size()); + for (const auto& table_shape : tables_shapes) { + xla_tables_shapes.push_back(xla::ShapeUtil::MakeShape( + xla::F32, {table_shape.dim_size(0), table_shape.dim_size(1)})); + } + + xla::Shape tables_shape = xla::ShapeUtil::MakeTupleShape(xla_tables_shapes); + + xla::FrontendAttributes custom_call_frontend_attributes; + + custom_call_frontend_attributes.mutable_map()->insert( + {"_xla_compute_type", "sparse"}); + + custom_call_frontend_attributes.mutable_map()->insert( + {"_xla_sharding_strategy", "mod"}); + + custom_call_frontend_attributes.mutable_map()->insert( + 
{"_xla_pad_value", absl::StrCat(kXlaPadValue)}); + + custom_call_frontend_attributes.mutable_map()->insert( + {"_xla_max_ids_per_partition", absl::StrCat(max_ids_per_partition)}); + + custom_call_frontend_attributes.mutable_map()->insert( + {"_xla_max_unique_ids_per_partition", + absl::StrCat(max_unique_ids_per_partition)}); + + builder->SetFrontendAttributes(custom_call_frontend_attributes); + + xla::XlaOp updated_tables = xla::CustomCallWithComputation( + builder, + "SparseDenseMatmulCustomCombinerTcCombinerGradOptimizerUpdateMegachipO" + "p", + {row_pointers, sorted_token_ids, sorted_sample_ids, sorted_pos_ids, + sorted_gains, tables, lookup_gradients, hyperparameters}, + optimizer, tables_shape); + + builder->SetFrontendAttributes(tuple_frontend_attributes); + + // Updated embedding table. + for (int i = 0; i < tables_shape.tuple_shapes_size(); ++i) { + ctx->SetOutput(i, xla::GetTupleElement(updated_tables, i)); + } + + builder->SetFrontendAttributes(original_frontend_attributes); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* builder = ctx->builder(); + + // Get the shape of the gradient. + OP_REQUIRES_VALUE(xla::Shape activation_shape, ctx, + ctx->InputXlaShape("activation_gradients")); + OP_REQUIRES( + ctx, + activation_shape.is_static() && activation_shape.dimensions_size() == 2, + absl::InvalidArgumentError(absl::StrCat( + "activations input has non static or non-rank 2 shape: ", + activation_shape.ToString()))); + OP_REQUIRES_VALUE(int64_t num_sparsecores_per_chip, ctx, + GetSparseCoresPerChip()); + int64_t num_samples_per_chip = activation_shape.dimensions(0); + OP_REQUIRES(ctx, num_samples_per_chip % num_sparsecores_per_chip == 0, + absl::InvalidArgumentError(absl::StrCat( + "num_samples_per_chip ", num_samples_per_chip, + " not divisible by the number of sparsecores per chip ", + num_sparsecores_per_chip))); + + std::vector tables_inputs; + std::vector tables_shapes; + OP_REQUIRES_OK(ctx, + ctx->InputList("tables", &tables_inputs, &tables_shapes)); + OP_REQUIRES(ctx, num_tables_ == tables_inputs.size(), + absl::InvalidArgumentError( + absl::StrCat("Expecting ", num_tables_, " tables, but got ", + tables_inputs.size()))); + + std::vector hyperparameters_inputs; + std::vector hyperparameters_shapes; + OP_REQUIRES_OK(ctx, + ctx->InputList("hyperparameters", &hyperparameters_inputs, + &hyperparameters_shapes)); + + int64_t per_sparse_core_batch_size = + num_samples_per_chip / num_sparsecores_per_chip; + int64_t max_ids_per_partition = 0; + int64_t max_unique_ids_per_partition = 0; + + const int32_t feature_width = tables_shapes[0].dim_size(1); + OP_REQUIRES_OK( + ctx, GetMaxIdsAndUniquesExternal(kUnknownProgramKey, table_name_, + per_sparse_core_batch_size, + feature_width, &max_ids_per_partition, + &max_unique_ids_per_partition)); + LOG(INFO) + << "Lowering XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp " + << "to HLO: table_name = '" << table_name_ + << "', max_ids = " << max_ids_per_partition + << ", max_uniques = " << max_unique_ids_per_partition; + + // Build the required computations -- one for the optimizer and two for the + // custom combiner. 
+ int32_t input_size = activation_shape.dimensions(0); + OP_REQUIRES_VALUE( + xla::XlaComputation optimizer, ctx, + BuildOptimizerComputation(ctx, tables_inputs, hyperparameters_inputs, + feature_width)); + OP_REQUIRES_VALUE( + xla::XlaComputation combiner_vectors_vjp, ctx, + BuildCombinerVjpComputation(ctx, input_size, feature_width, + combiner_lookups_custom_vjp_computation_)); + OP_REQUIRES_VALUE( + xla::XlaComputation combiner_weights_vjp, ctx, + BuildCombinerVjpComputation(ctx, input_size, feature_width, + combiner_weights_custom_vjp_computation_)); + + // Emit the two custom combiner VJP computations onto TC. + xla::XlaOp lookup_gradients = EmitTensorCoreComputations( + ctx, builder, std::move(combiner_vectors_vjp), + std::move(combiner_weights_vjp), input_size); + + // Pass the TC activation gradients back to SC for back-propagation with + // optimizer. + EmitSparseCoreComputations(ctx, builder, tables_inputs, tables_shapes, + hyperparameters_inputs, lookup_gradients, + std::move(optimizer), max_ids_per_partition, + max_unique_ids_per_partition); + } + + private: + int32_t max_valency_; + int32_t num_weights_; + int32_t num_tables_; + int32_t updated_weights_index_; + std::string table_name_; + NameAttrList optimizer_custom_computation_; + NameAttrList combiner_weights_custom_vjp_computation_; + NameAttrList combiner_lookups_custom_vjp_computation_; + + XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp&) = delete; + void operator=( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp&) = delete; +}; + +REGISTER_XLA_OP(Name("XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput"), + XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp); + // This TensorFlow op calculates the gradients and performs SGD update on the // embedding table on SparseCore. It takes the activation gradients, input // sparse tensor represented by the `row_pointers`, `sorted_embedding_ids`, diff --git a/tensorflow/core/tpu/ops/sparse_core_ops.cc b/tensorflow/core/tpu/ops/sparse_core_ops.cc index 36c9248aad14ab..b24197c2d40917 100644 --- a/tensorflow/core/tpu/ops/sparse_core_ops.cc +++ b/tensorflow/core/tpu/ops/sparse_core_ops.cc @@ -597,4 +597,72 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithCsrInput") return absl::OkStatus(); }); +REGISTER_OP("XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput") + .Input("row_pointers: int32") + .Input("sorted_sample_ids: int32") + .Input("sorted_token_ids: int32") + .Input("sorted_pos_ids: int32") + .Input("sorted_gains: float32") + .Input("weights: float32") + // We need to preserve the outputs of the SC forward pass and feed them into + // the VJP computations in the backward pass. 
+ .Input("preserved_valencies: int32") + .Input("preserved_vectors: float32") + .Input("activation_gradients: float32") + .Input("tables: N * float32") + .Input("hyperparameters: M * float32") + .Input("combiner_weights_learning_rate: float32") + .Output("updated_tables: N * float32") + .Output("updated_weights: float32") + .Attr("N: int >= 1") + .Attr("M: int >= 1") + .Attr("max_valency: int >= 0") + .Attr("num_weights: int >= 0") + .Attr("combiner_table_vjp_computation: func") + .Attr("combiner_weights_vjp_computation: func") + .Attr("optimizer_custom_computation: func") + .Attr("table_name: string") + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { + constexpr int kWeightsIndex = 5; + constexpr int kPreservedValenciesIndex = 6; + constexpr int kPreservedVectorsIndex = 7; + constexpr int kActivationGradientsIndex = 8; + constexpr int kTablesIndex = 9; + shape_inference::ShapeHandle shape; + int num_tables; + int num_weights; + int max_valency_int; + TF_RETURN_IF_ERROR(c->GetAttr("N", &num_tables)); + TF_RETURN_IF_ERROR(c->GetAttr("num_weights", &num_weights)); + TF_RETURN_IF_ERROR(c->GetAttr("max_valency", &max_valency_int)); + // Only check the shape of the weights when num_weights > 0 to avoid + // issues of 0-shaped values. + if (num_weights > 0) { + TF_RETURN_IF_ERROR(c->Merge(c->input(kWeightsIndex), + c->MakeShape({c->MakeDim(num_weights)}), + &shape)); + } + // Check that the preserved tensors have the expected shapes: + // valencies: [input_size]; + // vectors: [input_size, max_valency, feature_width]; + auto input_size = c->Dim(c->input(kActivationGradientsIndex), 0); + auto max_valency = c->MakeDim(max_valency_int); + auto feature_width = c->Dim(c->input(kTablesIndex), 1); + TF_RETURN_IF_ERROR(c->Merge(c->input(kPreservedValenciesIndex), + c->MakeShape({input_size}), &shape)); + TF_RETURN_IF_ERROR(c->Merge( + c->input(kPreservedVectorsIndex), + c->MakeShape({input_size, max_valency, feature_width}), &shape)); + // `updated_tables` refers to both the embedding table and the associated + // slot variables. They all have the same embedding table shape. + for (int i = 0; i < num_tables; ++i) { + c->set_output(i, c->input(kTablesIndex)); + } + // `updated_weights` simply have a 1D shape of `num_weights`. + // TODO(peitianpan): Do we need to account for the number of replicas + // here? 
+ c->set_output(num_tables, c->MakeShape({c->MakeDim(num_weights)})); + return absl::OkStatus(); + }); + } // namespace tensorflow diff --git a/tensorflow/python/tpu/ops/BUILD b/tensorflow/python/tpu/ops/BUILD index bf0bfc21f63a16..9aa2672336a35b 100644 --- a/tensorflow/python/tpu/ops/BUILD +++ b/tensorflow/python/tpu/ops/BUILD @@ -68,6 +68,7 @@ tf_gen_op_wrapper_py( "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput", "XlaSparseDenseMatmulGrad", "XlaSparseDenseMatmulGradWithCsrInput", + "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput", "XlaSparseDenseMatmulGradWithSgdAndCsrInput", "XlaSparseDenseMatmulGradWithAdagradAndCsrInput", "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput", diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index b3885a3b5b3d84..1a298ddbd4f412 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -5740,6 +5740,10 @@ tf_module { name: "XlaSparseDenseMatmul" argspec: "args=[\'row_ids\', \'col_ids\', \'values\', \'offsets\', \'embedding_table\', \'max_ids_per_partition\', \'max_unique_ids_per_partition\', \'input_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_pos_ids\', \'sorted_gains\', \'weights\', \'preserved_valencies\', \'preserved_vectors\', \'activation_gradients\', \'tables\', \'hyperparameters\', \'combiner_weights_learning_rate\', \'max_valency\', \'num_weights\', \'combiner_table_vjp_computation\', \'combiner_weights_vjp_computation\', \'optimizer_custom_computation\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_pos_ids\', \'sorted_gains\', \'embedding_table\', \'weights\', \'input_size\', \'max_valency\', \'num_weights\', \'combiner_computation\', \'quantization_config_low\', \'quantization_config_high\', \'quantization_config_num_buckets\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index b3885a3b5b3d84..1a298ddbd4f412 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -5740,6 +5740,10 @@ tf_module { name: "XlaSparseDenseMatmul" argspec: "args=[\'row_ids\', \'col_ids\', \'values\', \'offsets\', \'embedding_table\', \'max_ids_per_partition\', \'max_unique_ids_per_partition\', \'input_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_pos_ids\', \'sorted_gains\', \'weights\', \'preserved_valencies\', \'preserved_vectors\', \'activation_gradients\', \'tables\', \'hyperparameters\', \'combiner_weights_learning_rate\', \'max_valency\', \'num_weights\', \'combiner_table_vjp_computation\', \'combiner_weights_vjp_computation\', \'optimizer_custom_computation\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: 
"XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_pos_ids\', \'sorted_gains\', \'embedding_table\', \'weights\', \'input_size\', \'max_valency\', \'num_weights\', \'combiner_computation\', \'quantization_config_low\', \'quantization_config_high\', \'quantization_config_num_buckets\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " From df58c4957147ddb1db2947bc9ee17433d2510a74 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Wed, 16 Apr 2025 16:30:58 -0700 Subject: [PATCH 0890/1324] PR #25273: Move fusion-dispatch to RunBackend. Imported from GitHub PR https://github.com/openxla/xla/pull/25273 The intention is that this pass runs after all passes that can create or modify fusions. However, this is currently not the case, since there are a couple of passes in RunBackend that violate this condition. The tests didn't catch this because they only checked the passes in RunHloPasses, not the ones in RunBackend. Copybara import of the project: -- 9b09c79375964b2c6d2e3d262d56b49d49cacdba by Johannes Reifferscheid : Move fusion-dispatch to RunBackend. Currently, there are still a couple of passes that can create new fusions after fusion-dispatch. This should not happen. The tests didn't catch this because they only checked the passes in RunHloPasses, not the ones in RunBackend. -- d2365bec337485a786ecd70a0d6dd59f29c9e506 by Johannes Reifferscheid : Add more comments to the test. Merging this change closes #25273 PiperOrigin-RevId: 748462531 --- .../xla/xla/service/gpu/gpu_compiler.cc | 20 ++--- .../xla/xla/service/gpu/gpu_compiler_test.cc | 85 ++++++++++++------- 2 files changed, 61 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index c78aef3f249562..7ca3a342f3f06a 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1847,17 +1847,6 @@ absl::StatusOr> GpuCompiler::RunHloPasses( TF_RETURN_IF_ERROR( RunPreSchedulingCopyInsertion(*module, device_description)); - const auto* cuda_cc = std::get_if( - &device_description.gpu_compute_capability()); - if (cuda_cc != nullptr && cuda_cc->IsAtLeastAmpere()) { - // This needs to run after every pass affecting fusions, which includes - // `CopyFusion`, which runs just before. - TF_RETURN_IF_ERROR( - FusionDispatchPipeline(device_description, ShapeSizeBytesFunction()) - .Run(module.get()) - .status()); - } - uint64_t end_usecs = tsl::Env::Default()->NowMicros(); // This won't record values for calls that error out (because if they error @@ -2757,6 +2746,15 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( pipeline.AddPass(gpu_device_info); } + const auto* cuda_cc = std::get_if( + &gpu_device_info.gpu_compute_capability()); + if (cuda_cc != nullptr && cuda_cc->IsAtLeastAmpere()) { + // This needs to run after every pass affecting fusions. The last passes + // that create new fusions are FusionWrapper and StreamAttributeAnnotator. + main_pipeline.AddPass( + FusionDispatchPipeline(gpu_device_info, ShapeSizeBytesFunction())); + } + // Pipeline with passes which wrap a scheduled module into command buffers. 
{ HloPassPipeline& pipeline = diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index cff463690608fc..cf297b341e3c30 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -106,6 +106,24 @@ class GpuCompilerTest : public HloTestBase { return tensorflow::down_cast(compiler) ->RunPostSchedulingPipelines(module, 4 * 1024 * 1024, gpu_device_info); } + + // Like GetOptimizedModule, but also runs the backend. This is important for + // tests that need to verify behavior of passes that run in RunBackend. The + // former function will only run the passes in RunHloPasses. + // This returns the module and the executable because the latter owns the + // former. + absl::StatusOr>> + GetOptimizedModuleForExecutable(absl::string_view hlo, + const HloModuleConfig& config) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo, config)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + CreateExecutable(std::move(module), /*run_hlo_passes=*/true)); + TF_ASSIGN_OR_RETURN(const HloModule* optimized_module, + test_runner().HloModuleFromWrapped(executable.get())); + return {{optimized_module, std::move(executable)}}; + } }; // TODO(b/399912696): Fix and enable this test. @@ -1184,10 +1202,10 @@ ENTRY main { })"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - ParseAndReturnVerifiedModule(transpose_fusion_module)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - GetOptimizedModule(std::move(module))); + auto module_and_executable, + GetOptimizedModuleForExecutable(transpose_fusion_module, + GetModuleConfigForTest())); + const HloModule* optimized_module = module_and_executable.first; if (cc.IsAtLeastAmpere()) { EXPECT_TRUE(HasBlockLevelFusionConfig( @@ -1223,12 +1241,11 @@ ENTRY main { })"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr rewritable_transpose_module, - ParseAndReturnVerifiedModule(rewritable_transpose_string)); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr rewritable_transpose_optimized_module, - GetOptimizedModule(std::move(rewritable_transpose_module))); + auto rewritable_transpose_module_and_executable, + GetOptimizedModuleForExecutable(rewritable_transpose_string, + GetModuleConfigForTest())); + const HloModule* rewritable_transpose_optimized_module = + rewritable_transpose_module_and_executable.first; EXPECT_TRUE(HasBlockLevelFusionConfig( rewritable_transpose_optimized_module->entry_computation() ->root_instruction())); @@ -1247,8 +1264,11 @@ ENTRY main { ParseAndReturnVerifiedModule(unrewritable_transpose_string)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr unrewritable_transpose_optimized_module, - GetOptimizedModule(std::move(unrewritable_transpose_module))); + auto unrewritable_transpose_module_and_executable, + GetOptimizedModuleForExecutable(unrewritable_transpose_string, + GetModuleConfigForTest())); + const HloModule* unrewritable_transpose_optimized_module = + unrewritable_transpose_module_and_executable.first; EXPECT_FALSE(HasBlockLevelFusionConfig( unrewritable_transpose_optimized_module->entry_computation() ->root_instruction())); @@ -1476,7 +1496,7 @@ class PassOrderTest : public GpuCompilerTest { int other_pass_first_run = std::numeric_limits::max(); int run_index = 0; for (const HloPassMetadata& pass_metadata : - optimized_module_->metadata()->proto().pass_metadata()) { + optimized_module_->metadata().proto().pass_metadata()) { if (RE2::FullMatch(pass_metadata.pass_name(), 
first_pass_regex)) { VLOG(2) << "Pass " << pass_metadata.pass_name() << " matches first_pass_regex." << std::endl; @@ -1517,7 +1537,7 @@ class PassOrderTest : public GpuCompilerTest { int last_pass_earliest_run = std::numeric_limits::max(); int run_index = 0; for (const HloPassMetadata& pass_metadata : - optimized_module_->metadata()->proto().pass_metadata()) { + optimized_module_->metadata().proto().pass_metadata()) { std::string name = pass_metadata.pass_name(); if (include_pipeline_name) { name = absl::StrCat(pass_metadata.pipeline_name(), ".", @@ -1549,9 +1569,10 @@ class PassOrderTest : public GpuCompilerTest { // `pass_range.first_pass_run_index` and `pass_range.second_pass_run_index`. void VerifyNotRunInBetween(const PassRange& pass_range, absl::string_view pass_regex) { + CHECK(optimized_module_); int run_index = 0; for (const HloPassMetadata& pass_metadata : - optimized_module_->metadata()->proto().pass_metadata()) { + optimized_module_->metadata().proto().pass_metadata()) { if (run_index >= pass_range.second_pass_run_index) { break; } @@ -1564,21 +1585,22 @@ class PassOrderTest : public GpuCompilerTest { } protected: - absl::Status ScheduleModule() { return Schedule(optimized_module_.get()); } - + // Compiles a dummy module with the given configuration, running all passes, + // including the ones in RunBackend. This is important because otherwise, we + // might miss some passes when verifying pass order. void CompileModule(const HloModuleConfig& config) { constexpr absl::string_view constant_module = R"( -ENTRY main { - ROOT constant = f32[] constant(0) -})"; + ENTRY main { + ROOT constant = f32[] constant(0) + })"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - ParseAndReturnVerifiedModule(constant_module, config)); - TF_ASSERT_OK_AND_ASSIGN(optimized_module_, - GetOptimizedModule(std::move(module))); + std::tie(optimized_module_, compiled_executable_), + GetOptimizedModuleForExecutable(constant_module, config)); } - std::unique_ptr optimized_module_; + // Owns the optimized_module_ below. + std::unique_ptr compiled_executable_ = nullptr; + const HloModule* optimized_module_ = nullptr; }; TEST_F(PassOrderTest, PassesAreRunInCorrectOrder) { @@ -1620,9 +1642,10 @@ TEST_F(PassOrderTest, FusionDispatchRunsAfterAllFusionPasses) { true); SetDebugOptions(debug_options); - VerifyPassOrder(/*first_pass_regex=*/".*fusion.*", - /*last_pass_regex=*/"fusion-dispatch-pipeline.*", - /*include_pipeline_name=*/true); + VerifyPassOrder( + /*first_pass_regex=*/".*(fusion|stream-attribute-annotator).*", + /*last_pass_regex=*/"fusion-dispatch-pipeline.*", + /*include_pipeline_name=*/true); } TEST_F(PassOrderTest, @@ -1649,7 +1672,7 @@ TEST_F(PassOrderTest, StableSortExpanderRunsAfterDynamicPadder) { MATCHER_P(HasExpectedPasses, expected_pass_names, "") { std::vector run_pass_names; - auto metadata = arg->metadata()->proto(); + auto metadata = arg->metadata().proto(); run_pass_names.reserve(metadata.pass_metadata_size()); for (auto& pass_metadata : metadata.pass_metadata()) { run_pass_names.push_back(pass_metadata.pass_name()); @@ -1660,7 +1683,6 @@ MATCHER_P(HasExpectedPasses, expected_pass_names, "") { TEST_F(PassOrderTest, ExecEffortAt0point2RunsSpecifiedPasses) { HloModuleConfig config = GetModuleConfigForTest(); CompileModule(config); - TF_ASSERT_OK(ScheduleModule()); // Make sure passes are not enabled by default. std::vector kExpectedPasses = { @@ -1675,7 +1697,6 @@ TEST_F(PassOrderTest, ExecEffortAt0point2RunsSpecifiedPasses) { // enabled. 
config.set_exec_time_optimization_effort(0.2); CompileModule(config); - TF_ASSERT_OK(ScheduleModule()); EXPECT_THAT(optimized_module_, HasExpectedPasses(kExpectedPasses)); } @@ -1687,7 +1708,6 @@ TEST_F(PassOrderTest, LHSRunsIfProfileDataIsAvailable) { "latency-hiding-scheduler", }; CompileModule(config); - TF_ASSERT_OK(ScheduleModule()); EXPECT_THAT(optimized_module_, Not(HasExpectedPasses(kExpectedPasses))); // Make sure we turn the LHS on with we schedule with profile data. @@ -1696,7 +1716,6 @@ TEST_F(PassOrderTest, LHSRunsIfProfileDataIsAvailable) { )pb"; config.set_fdo_profile(kProfile); CompileModule(config); - TF_ASSERT_OK(ScheduleModule()); EXPECT_THAT(optimized_module_, HasExpectedPasses(kExpectedPasses)); } From 45a01d94be08a30cc8688424392dae354c79e767 Mon Sep 17 00:00:00 2001 From: Karlo Basioli Date: Wed, 16 Apr 2025 16:45:42 -0700 Subject: [PATCH 0891/1324] [XLA:CPU][XLA:GPU][collectives] Change communicator interface to return tsl::AsyncValueRef This change will allow us to switch collective implementations to async mode. PiperOrigin-RevId: 748466746 --- .../xla/xla/backends/cpu/collectives/BUILD | 4 + .../cpu/collectives/gloo_collectives_test.cc | 11 ++- .../cpu/collectives/gloo_communicator.cc | 40 ++++----- .../cpu/collectives/gloo_communicator.h | 70 +++++++-------- .../collectives/in_process_communicator.cc | 70 ++++++++++----- .../cpu/collectives/in_process_communicator.h | 70 +++++++-------- .../cpu/collectives/mpi_communicator.cc | 53 ++++++------ .../cpu/collectives/mpi_communicator.h | 66 +++++++------- .../xla/xla/backends/cpu/runtime/BUILD | 5 ++ .../backends/cpu/runtime/all_gather_thunk.cc | 22 ++++- .../backends/cpu/runtime/all_reduce_thunk.cc | 25 ++++-- .../backends/cpu/runtime/all_to_all_thunk.cc | 9 +- .../cpu/runtime/collective_permute_thunk.cc | 20 ++++- .../backends/cpu/runtime/collective_thunk.cc | 23 ++++- .../backends/cpu/runtime/collective_thunk.h | 5 +- .../cpu/runtime/reduce_scatter_thunk.cc | 18 +++- .../xla/xla/backends/gpu/collectives/BUILD | 4 + .../gpu/collectives/nccl_communicator.cc | 86 +++++++++++-------- .../gpu/collectives/nccl_communicator.h | 75 ++++++++-------- .../gpu/collectives/nccl_communicator_test.cc | 29 +++++-- .../xla/xla/backends/gpu/runtime/BUILD | 8 ++ .../backends/gpu/runtime/all_gather_thunk.cc | 10 ++- .../backends/gpu/runtime/all_reduce_thunk.cc | 19 +++- .../backends/gpu/runtime/all_to_all_thunk.cc | 43 ++++++++-- .../gpu/runtime/collective_broadcast_thunk.cc | 10 ++- .../gpu/runtime/collective_permute_thunk.cc | 10 ++- .../gpu/runtime/ragged_all_to_all_thunk.cc | 45 +++++++--- .../xla/backends/gpu/runtime/recv_thunk.cc | 10 ++- .../xla/backends/gpu/runtime/send_thunk.cc | 10 ++- third_party/xla/xla/core/collectives/BUILD | 1 + .../xla/xla/core/collectives/communicator.h | 78 ++++++++++------- third_party/xla/xla/service/cpu/BUILD | 1 + .../xla/xla/service/cpu/cpu_runtime.cc | 33 ++++--- 33 files changed, 627 insertions(+), 356 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/collectives/BUILD b/third_party/xla/xla/backends/cpu/collectives/BUILD index ac5ec02cab5e3d..ea47b4977c11a4 100644 --- a/third_party/xla/xla/backends/cpu/collectives/BUILD +++ b/third_party/xla/xla/backends/cpu/collectives/BUILD @@ -139,6 +139,7 @@ cc_library( "//xla/service:collective_ops_utils", "//xla/service:rendezvous", "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/math:math_util", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", @@ -221,6 +222,7 @@ xla_cc_test( 
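To sketch the two sides of the new contract (illustrative only; Event, RunCollective, and RunCollectiveAndWait are placeholder names, and only the tsl::AsyncValueRef API is assumed): implementations return an event that may complete later, while synchronous callers block on it and propagate any error, exactly as the updated gloo_collectives_test.cc below does.

#include "absl/status/status.h"
#include "xla/tsl/concurrency/async_value_ref.h"

struct Event {};  // stand-in for the communicator's completion-event type

// Producer side: a synchronous backend, like the Gloo one below, finishes
// the work inline and returns an already-available event. An async backend
// would construct the ref first and fulfill it from another thread.
tsl::AsyncValueRef<Event> RunCollective() {
  // ... perform the (blocking) collective here ...
  return tsl::MakeAvailableAsyncValueRef<Event>();
}

// Consumer side: blocking callers wait on the event and surface its error.
absl::Status RunCollectiveAndWait() {
  auto event = RunCollective();
  tsl::BlockUntilReady(event);
  return event.IsError() ? event.GetError() : absl::OkStatus();
}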
"//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", @@ -262,6 +264,7 @@ cc_library( "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", @@ -325,6 +328,7 @@ cc_library( "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/backends/cpu/collectives/gloo_collectives_test.cc b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives_test.cc index c4a9009e73c884..9a8583a75ca17d 100644 --- a/third_party/xla/xla/backends/cpu/collectives/gloo_collectives_test.cc +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives_test.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" @@ -110,9 +111,15 @@ absl::StatusOr> AllReduce( GetCommunicator(kNumParticipants, global_devices, kv_store, rank)); CpuCollectives::Executor executor(rendezvous_key, kTimeout); - TF_RETURN_IF_ERROR(communicator->AllReduce( + auto event = communicator->AllReduce( AsDeviceMemory(input_buffer), AsDeviceMemory(output_buffer), - xla::PrimitiveType::U8, kBufferSize, xla::ReductionKind::SUM, executor)); + xla::PrimitiveType::U8, kBufferSize, xla::ReductionKind::SUM, executor); + + tsl::BlockUntilReady(event); + + if (event.IsError()) { + return event.GetError(); + } return output_buffer; } diff --git a/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc index ff24895d48d62f..1ddd484f2e73d5 100644 --- a/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc @@ -46,6 +46,7 @@ limitations under the License. 
#include "xla/service/collective_ops_utils.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/types.h" @@ -101,11 +102,10 @@ static absl::Status SetAllReduceOptions(ReductionKind reduction_kind, return absl::OkStatus(); } -absl::Status GlooCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) { +tsl::AsyncValueRef GlooCommunicator::AllReduce( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, ReductionKind reduction_kind, + const Executor& executor) { TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); gloo::AllreduceOptions options(context_); @@ -181,12 +181,12 @@ absl::Status GlooCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, return absl::UnknownError( absl::StrCat("Gloo all-reduce failed: ", e.what())); } - return absl::OkStatus(); + return OkEvent(); } static constexpr uint8_t kCollectivePermuteSlotPrefix = 0x40; -absl::Status GlooCommunicator::CollectivePermute( +tsl::AsyncValueRef GlooCommunicator::CollectivePermute( se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, std::optional source_rank, absl::Span target_ranks, const Executor& executor) { @@ -234,10 +234,10 @@ absl::Status GlooCommunicator::CollectivePermute( return absl::UnknownError( absl::StrCat("Gloo collective permute failed: ", e.what())); } - return absl::OkStatus(); + return OkEvent(); } -absl::Status GlooCommunicator::AllToAll( +tsl::AsyncValueRef GlooCommunicator::AllToAll( absl::Span send_buffers, absl::Span recv_buffers, PrimitiveType dtype, size_t count, const Executor& executor) { @@ -290,13 +290,12 @@ absl::Status GlooCommunicator::AllToAll( return absl::UnknownError( absl::StrCat("Gloo all-to-all failed: ", e.what())); } - return absl::OkStatus(); + return OkEvent(); } -absl::Status GlooCommunicator::AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - const Executor& executor) { +tsl::AsyncValueRef GlooCommunicator::AllGather( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, const Executor& executor) { uint32_t tag = 0; // TODO(phawkins): use better tags. 
TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); @@ -315,7 +314,7 @@ absl::Status GlooCommunicator::AllGather(se::DeviceMemoryBase send_buffer, return absl::UnknownError( absl::StrCat("Gloo AllGather failed: ", e.what())); } - return absl::OkStatus(); + return OkEvent(); } template @@ -367,11 +366,10 @@ absl::Status ReduceScatterHelper(std::shared_ptr context, return absl::OkStatus(); } -absl::Status GlooCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) { +tsl::AsyncValueRef GlooCommunicator::ReduceScatter( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, ReductionKind reduction_kind, + const Executor& executor) { size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); std::unique_ptr temp(new char[chunk_bytes * context_->size]); std::memcpy(temp.get(), send_buffer.opaque(), chunk_bytes * context_->size); @@ -437,7 +435,7 @@ absl::Status GlooCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, return absl::InvalidArgumentError("Unknown datatype in reducescatter"); } std::memcpy(recv_buffer.opaque(), temp.get(), chunk_bytes); - return absl::OkStatus(); + return OkEvent(); } } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.h b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.h index 234716da759340..6ab36a9263f890 100644 --- a/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.h +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.h @@ -30,6 +30,7 @@ limitations under the License. #include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -42,46 +43,47 @@ class GlooCommunicator : public Communicator { size_t num_ranks); ~GlooCommunicator() override; - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) override; - - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) override; - - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) override; - - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, - PrimitiveType, size_t, RankId, - const Executor&) override { + tsl::AsyncValueRef AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + tsl::AsyncValueRef CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, 
std::optional source_rank, + absl::Span target_ranks, const Executor& executor) override; + + tsl::AsyncValueRef AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) override; + + tsl::AsyncValueRef AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) override; + + tsl::AsyncValueRef ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + tsl::AsyncValueRef Broadcast(se::DeviceMemoryBase, + se::DeviceMemoryBase, PrimitiveType, + size_t, RankId, + const Executor&) override { return Unimplemented("Broadcast is not implemented"); } - absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { + tsl::AsyncValueRef Send(se::DeviceMemoryBase, PrimitiveType, size_t, + RankId, const Executor&) override { return Unimplemented("Send is not implemented"); } - absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { + tsl::AsyncValueRef Recv(se::DeviceMemoryBase, PrimitiveType, size_t, + RankId, const Executor&) override { return Unimplemented("Recv is not implemented"); } diff --git a/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc index 95baf3874dbfc7..2906860609fb70 100644 --- a/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc +++ b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/rendezvous.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/lib/math/math_util.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" @@ -393,11 +394,12 @@ static absl::Status CollectivePermuteOp( InProcessCommunicator::InProcessCommunicator(size_t rank, size_t num_ranks) : rank_(rank), num_ranks_(num_ranks) {} -absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) { +tsl::AsyncValueRef +InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); const RendezvousKey& key = cpu_executor->rendezvous_key(); @@ -409,13 +411,18 @@ absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, name, key, partiticipant, key.num_local_participants, CollectParticipants)); - return op->Invoke(AllReduceOp, rank_, dtype, count, reduction_kind); + TF_RETURN_IF_ERROR( + op->Invoke(AllReduceOp, rank_, dtype, count, reduction_kind)); + + return OkEvent(); } -absl::Status InProcessCommunicator::ReduceScatter( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { +tsl::AsyncValueRef +InProcessCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, 
size_t count, + ReductionKind reduction_kind, + const Executor& executor) { TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); const RendezvousKey& key = cpu_executor->rendezvous_key(); @@ -427,13 +434,19 @@ absl::Status InProcessCommunicator::ReduceScatter( name, key, partiticipant, key.num_local_participants, CollectParticipants)); - return op->Invoke(ReduceScatterOp, rank_, dtype, count, reduction_kind); + TF_RETURN_IF_ERROR( + op->Invoke(ReduceScatterOp, rank_, dtype, count, reduction_kind)); + + return OkEvent(); } -absl::Status InProcessCommunicator::CollectivePermute( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, std::optional source_rank, - absl::Span target_ranks, const Executor& executor) { +tsl::AsyncValueRef +InProcessCommunicator::CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) { TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); const RendezvousKey& key = cpu_executor->rendezvous_key(); @@ -447,10 +460,14 @@ absl::Status InProcessCommunicator::CollectivePermute( CollectParticipants)); size_t num_bytes = count * primitive_util::ByteWidth(dtype); - return op->Invoke(CollectivePermuteOp, rank_, num_bytes); + + TF_RETURN_IF_ERROR(op->Invoke(CollectivePermuteOp, rank_, num_bytes)); + + return OkEvent(); } -absl::Status InProcessCommunicator::AllToAll( +tsl::AsyncValueRef +InProcessCommunicator::AllToAll( absl::Span send_buffers, absl::Span recv_buffers, PrimitiveType dtype, size_t count, const Executor& executor) { @@ -468,13 +485,17 @@ absl::Status InProcessCommunicator::AllToAll( CollectParticipants)); size_t num_bytes = count * primitive_util::ByteWidth(dtype); - return op->Invoke(AllToAllOp, rank_, num_bytes); + + TF_RETURN_IF_ERROR(op->Invoke(AllToAllOp, rank_, num_bytes)); + + return OkEvent(); } -absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - const Executor& executor) { +tsl::AsyncValueRef +InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) { TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); const RendezvousKey& key = cpu_executor->rendezvous_key(); @@ -487,7 +508,10 @@ absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer, CollectParticipants)); size_t num_bytes = count * primitive_util::ByteWidth(dtype); - return op->Invoke(AllGatherOp, rank_, num_bytes); + + TF_RETURN_IF_ERROR(op->Invoke(AllGatherOp, rank_, num_bytes)); + + return OkEvent(); } } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.h b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.h index f4366c858f6608..4a28ff99da14a5 100644 --- a/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.h +++ b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -39,46 +40,47 @@ class InProcessCommunicator : public Communicator { public: InProcessCommunicator(size_t rank, size_t num_ranks); - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) override; - - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) override; - - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) override; - - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, - PrimitiveType, size_t, RankId, - const Executor&) override { + tsl::AsyncValueRef AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + tsl::AsyncValueRef CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) override; + + tsl::AsyncValueRef AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) override; + + tsl::AsyncValueRef AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) override; + + tsl::AsyncValueRef ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + tsl::AsyncValueRef Broadcast(se::DeviceMemoryBase, + se::DeviceMemoryBase, PrimitiveType, + size_t, RankId, + const Executor&) override { return Unimplemented("Broadcast is not implemented"); } - absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { + tsl::AsyncValueRef Send(se::DeviceMemoryBase, PrimitiveType, size_t, + RankId, const Executor&) override { return Unimplemented("Send is not implemented"); } - absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { + tsl::AsyncValueRef Recv(se::DeviceMemoryBase, PrimitiveType, size_t, + RankId, const Executor&) override { return Unimplemented("Recv is not implemented"); } diff --git a/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc index b863c2eca62bc3..5d6e9fe01a940a 100644 --- a/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc +++ b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc @@ -30,6 +30,7 @@ limitations 
under the License. #include "xla/service/collective_ops_utils.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/xla_data.pb.h" @@ -123,18 +124,18 @@ MpiCommunicator::MpiCommunicator(int color, int key) { MpiCommunicator::~MpiCommunicator() { MPI_Comm_free(&comm_); }; -absl::Status MpiCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) { +tsl::AsyncValueRef MpiCommunicator::AllReduce( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, ReductionKind reduction_kind, + const Executor& executor) { TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); - return MpiErrorToAbslStatus(MPI_Allreduce( - send_buffer.opaque(), recv_buffer.opaque(), count, type, op, comm_)); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus(MPI_Allreduce( + send_buffer.opaque(), recv_buffer.opaque(), count, type, op, comm_))); + return OkEvent(); } -absl::Status MpiCommunicator::CollectivePermute( +tsl::AsyncValueRef MpiCommunicator::CollectivePermute( se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, std::optional source_rank, absl::Span target_ranks, const Executor& executor) { @@ -175,10 +176,10 @@ absl::Status MpiCommunicator::CollectivePermute( MpiErrorToAbslStatus(MPI_Wait(&request, MPI_STATUS_IGNORE))); } - return absl::OkStatus(); + return OkEvent(); } -absl::Status MpiCommunicator::AllToAll( +tsl::AsyncValueRef MpiCommunicator::AllToAll( absl::Span send_buffers, absl::Span recv_buffers, PrimitiveType dtype, size_t count, const Executor& executor) { @@ -212,31 +213,33 @@ absl::Status MpiCommunicator::AllToAll( recv_rank, tag, comm_, MPI_STATUS_IGNORE))); } - return absl::OkStatus(); + return OkEvent(); } -absl::Status MpiCommunicator::AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - const Executor& executor) { +tsl::AsyncValueRef MpiCommunicator::AllGather( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, const Executor& executor) { TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); - return MpiErrorToAbslStatus(MPI_Allgather(send_buffer.opaque(), count, type, - recv_buffer.opaque(), count, type, - comm_)); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Allgather(send_buffer.opaque(), count, type, recv_buffer.opaque(), + count, type, comm_))); + + return OkEvent(); } -absl::Status MpiCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) { +tsl::AsyncValueRef MpiCommunicator::ReduceScatter( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, ReductionKind reduction_kind, + const Executor& executor) { const int size = mpi_size_; std::vector recvcounts(size, count); TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); - return MpiErrorToAbslStatus( + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( 
MPI_Reduce_scatter(send_buffer.opaque(), recv_buffer.opaque(), - recvcounts.data(), type, op, comm_)); + recvcounts.data(), type, op, comm_))); + + return OkEvent(); } } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.h b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.h index cfed534b66bd51..0e6aff969d9a86 100644 --- a/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.h +++ b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.h @@ -29,6 +29,7 @@ limitations under the License. #include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -39,44 +40,45 @@ class MpiCommunicator : public Communicator { explicit MpiCommunicator(int color, int key); ~MpiCommunicator() override; - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) override; - - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) override; - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) override; - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, - PrimitiveType, size_t, RankId, - const Executor&) override { + tsl::AsyncValueRef AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + tsl::AsyncValueRef CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) override; + + tsl::AsyncValueRef AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) override; + tsl::AsyncValueRef AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) override; + tsl::AsyncValueRef ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + tsl::AsyncValueRef Broadcast(se::DeviceMemoryBase, + se::DeviceMemoryBase, PrimitiveType, + size_t, RankId, + const Executor&) override { return Unimplemented("Broadcast is not implemented"); } - absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { + tsl::AsyncValueRef Send(se::DeviceMemoryBase, PrimitiveType, size_t, + RankId, const Executor&) override { return Unimplemented("Send is not implemented"); } - absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, 
size_t, RankId, - const Executor&) override { + tsl::AsyncValueRef Recv(se::DeviceMemoryBase, PrimitiveType, size_t, + RankId, const Executor&) override { return Unimplemented("Recv is not implemented"); } diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index 009d8a8c6b5a73..b7d7a769fc5955 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -392,6 +392,7 @@ cc_library( "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", @@ -515,6 +516,7 @@ cc_library( "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -566,6 +568,7 @@ cc_library( "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", @@ -594,6 +597,7 @@ cc_library( "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", @@ -633,6 +637,7 @@ cc_library( "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc index e23905e5fc6f59..3a941a95aebc8e 100644 --- a/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/log/log.h" #include "absl/memory/memory.h" @@ -74,16 +75,29 @@ tsl::AsyncValueRef AllGatherThunk::Execute( return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, Communicator& comm) { + [&](const RendezvousKey& key, + Communicator& comm) -> tsl::AsyncValueRef { CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); + tsl::CountDownAsyncValueRef state( + data.source.size()); + for (int32_t i = 0; i < data.source.size(); ++i) { const Shape& shape = source_shape(i); - TF_RETURN_IF_ERROR(comm.AllGather( + auto communicator_event = comm.AllGather( data.source[i], data.destination[i], shape.element_type(), - ShapeUtil::ElementsIn(shape), executor)); + ShapeUtil::ElementsIn(shape), executor); + + communicator_event.AndThen([state, communicator_event]() mutable { + if (ABSL_PREDICT_FALSE(communicator_event.IsError())) { + state.CountDown(communicator_event.GetError()); + } else { + state.CountDown(); + } + }); } - return absl::OkStatus(); + + return state.AsRef(); }); } diff --git a/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc index 9db71d2d878d74..9376f80bd1776c 100644 --- a/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -68,7 +69,6 @@ AllReduceThunk::AllReduceThunk(Info info, ReductionKind reduction_kind, tsl::AsyncValueRef AllReduceThunk::Execute( const ExecuteParams& params) { - TF_ASSIGN_OR_RETURN(OpDeviceMemory data, GetOpDeviceMemory(params)); VLOG(3) << absl::StreamFormat( @@ -101,18 +101,29 @@ tsl::AsyncValueRef AllReduceThunk::Execute( return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, Communicator& comm) { + [&, data = std::move(data)](const RendezvousKey& key, Communicator& comm) + -> tsl::AsyncValueRef { + tsl::CountDownAsyncValueRef state( + data.source.size()); + CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); for (int32_t i = 0; i < data.source.size(); ++i) { const Shape& shape = destination_shape(i); - TF_RETURN_IF_ERROR(comm.AllReduce( + + auto communicator_event = comm.AllReduce( data.source[i], data.destination[i], shape.element_type(), - ShapeUtil::ElementsIn(shape), reduction_kind_, executor)); + ShapeUtil::ElementsIn(shape), reduction_kind_, executor); + + communicator_event.AndThen([state, communicator_event]() mutable { + if (ABSL_PREDICT_FALSE(communicator_event.IsError())) { + state.CountDown(communicator_event.GetError()); + } else { + state.CountDown(); + } + }); } - return absl::OkStatus(); + return state.AsRef(); }); - - return OkExecuteEvent(); } } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc index 42452761da6a31..26afa74dd38caf 100644 --- a/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc @@ -54,7 +54,6 @@ AllToAllThunk::AllToAllThunk(Info info, OpParams op_params, tsl::AsyncValueRef AllToAllThunk::Execute( const ExecuteParams& params) { - TF_ASSIGN_OR_RETURN(OpDeviceMemory data, 
GetOpDeviceMemory(params)); VLOG(3) << absl::StreamFormat( @@ -79,11 +78,9 @@ tsl::AsyncValueRef AllToAllThunk::Execute( CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); const Shape& shape = destination_shape(0); - TF_RETURN_IF_ERROR( - comm.AllToAll(data.source, data.destination, shape.element_type(), - ShapeUtil::ElementsIn(shape), executor)); - - return absl::OkStatus(); + return comm.AllToAll(data.source, data.destination, + shape.element_type(), ShapeUtil::ElementsIn(shape), + executor); }); } diff --git a/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc index 3d495355915087..e9630ee2d17363 100644 --- a/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/log/log.h" #include "absl/memory/memory.h" @@ -130,15 +131,26 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) { params.collective_params, [&](const RendezvousKey& key, Communicator& comm) { CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); - + tsl::CountDownAsyncValueRef state( + data.source.size()); for (int32_t i = 0; i < data.source.size(); ++i) { const Shape& shape = source_shape(i); - TF_RETURN_IF_ERROR(comm.CollectivePermute( + + auto communicator_event = comm.CollectivePermute( data.source[i], data.destination[i], shape.element_type(), ShapeUtil::ElementsIn(shape), source_replica_id, copy_to, - executor)); + executor); + + communicator_event.AndThen([state, communicator_event]() mutable { + if (ABSL_PREDICT_FALSE(communicator_event.IsError())) { + state.CountDown(communicator_event.GetError()); + } else { + state.CountDown(); + } + }); } - return absl::OkStatus(); + + return state.AsRef(); }); } diff --git a/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc index 452fdf762783e9..2abaf73ad88f9a 100644 --- a/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -216,9 +217,27 @@ CollectiveThunk::ExecuteWithCommunicator( Communicator * communicator, AcquireCommunicator(collectives, clique_key, RankId(rank))); - TF_RETURN_IF_ERROR(callback(key, *communicator)); + tsl::AsyncValueRef communicator_event = + callback(key, *communicator); - return OkExecuteEvent(); + auto event = tsl::MakeConstructedAsyncValueRef(); + + // Keeps communicator event alive until the event is ready. + event.AndThen([communicator_event]() {}); + + // Set the event to concrete state once the communicator event is ready. + communicator_event.AndThen( + // We pass `communicator_event` as a pointer because `event` will keep it + // alive. 
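(The per-buffer loops in the CPU thunks above all implement the same fan-in: N communicator events are joined into one execute event through a count-down. A compact sketch of that pattern; the CountDownAsyncValueRef template argument is elided in this dump, and Thunk::ExecuteEvent plus an `events` collection are assumed:)

  tsl::CountDownAsyncValueRef<Thunk::ExecuteEvent> state(events.size());
  for (auto& communicator_event : events) {
    communicator_event.AndThen([state, communicator_event]() mutable {
      if (communicator_event.IsError()) {
        state.CountDown(communicator_event.GetError());  // records the error
      } else {
        state.CountDown();
      }
    });
  }
  return state.AsRef();  // ready once every communicator event counted down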
+ [event, communicator_event_ptr = communicator_event.AsPtr()]() { + if (ABSL_PREDICT_FALSE(communicator_event_ptr.IsError())) { + event.SetError(communicator_event_ptr.GetError()); + } + + event.SetStateConcrete(); + }); + + return event; } const BufferAllocation::Slice& CollectiveThunk::source_buffer( diff --git a/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h index 7b7fb624c6822e..c9cdcec375d2bb 100644 --- a/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h @@ -36,6 +36,7 @@ limitations under the License. #include "xla/service/global_device_id.h" #include "xla/shape.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -104,8 +105,8 @@ class CollectiveThunk : public Thunk { ResourceUses resource_uses() const final; // Callback for collective thunk implementations. - using Callback = absl::AnyInvocable; + using Callback = absl::AnyInvocable( + const RendezvousKey& key, Communicator& comm)>; static bool IsDataTypeSupportedByCollectiveReduce(PrimitiveType datatype); diff --git a/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc index 8682ce958a4016..f17fabab0b4f21 100644 --- a/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/log/log.h" #include "absl/memory/memory.h" @@ -89,14 +90,25 @@ ReduceScatterThunk::Execute(const ExecuteParams& params) { [&](const RendezvousKey& key, Communicator& comm) { CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); + tsl::CountDownAsyncValueRef state( + data.source.size()); + for (int32_t i = 0; i < data.source.size(); ++i) { const Shape& shape = destination_shape(i); - TF_RETURN_IF_ERROR(comm.ReduceScatter( + auto communicator_event = comm.ReduceScatter( data.source[i], data.destination[i], shape.element_type(), - ShapeUtil::ElementsIn(shape), reduction_kind_, executor)); + ShapeUtil::ElementsIn(shape), reduction_kind_, executor); + + communicator_event.AndThen([state, communicator_event]() mutable { + if (ABSL_PREDICT_FALSE(communicator_event.IsError())) { + state.CountDown(communicator_event.GetError()); + } else { + state.CountDown(); + } + }); } - return absl::OkStatus(); + return state.AsRef(); }); } diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index a07dcd1e588ff9..eb2ee3c1b58ec6 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -254,6 +254,8 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_stream", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", @@ -327,9 +329,11 @@ xla_cc_test( ":gpu_collectives", ":nccl_communicator", ":nccl_errors", + "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", + 
"//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc index 05e5f91068ec44..29ab0a83cf07be 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc @@ -36,6 +36,8 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" @@ -231,7 +233,7 @@ NcclCommunicator::RegisterBuffer(stream_executor::DeviceMemoryBase buffer) { #endif // NCCL_VERSION_CODE >= 21901 } -absl::Status NcclCommunicator::AllReduce( +tsl::AsyncValueRef NcclCommunicator::AllReduce( se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, ReductionKind reduction_kind, const Communicator::Executor& executor) { @@ -250,17 +252,17 @@ absl::Status NcclCommunicator::AllReduce( TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); - return XLA_NCCL_STATUS(ncclAllReduce( + TF_RETURN_IF_ERROR(XLA_NCCL_STATUS(ncclAllReduce( send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, ToNcclReduction(reduction_kind), comm_, - se::gpu::AsGpuStreamValue(stream))); + se::gpu::AsGpuStreamValue(stream)))); + + return OkEvent(); } -absl::Status NcclCommunicator::Broadcast(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - RankId root, - const Executor& executor) { +tsl::AsyncValueRef NcclCommunicator::Broadcast( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, RankId root, const Executor& executor) { if (aborted_) { return absl::FailedPreconditionError("NcclCommunicator aborted"); } @@ -276,16 +278,17 @@ absl::Status NcclCommunicator::Broadcast(se::DeviceMemoryBase send_buffer, TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); - return XLA_NCCL_STATUS(ncclBroadcast( + TF_RETURN_IF_ERROR(XLA_NCCL_STATUS(ncclBroadcast( send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), - nccl_dtype, root.value(), comm_, se::gpu::AsGpuStreamValue(stream))); + nccl_dtype, root.value(), comm_, se::gpu::AsGpuStreamValue(stream)))); + + return OkEvent(); } -absl::Status NcclCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) { +tsl::AsyncValueRef NcclCommunicator::ReduceScatter( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, ReductionKind reduction_kind, + const Executor& executor) { if (aborted_) { return absl::FailedPreconditionError("NcclCommunicator aborted"); } @@ -301,16 +304,17 @@ absl::Status NcclCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); - return XLA_NCCL_STATUS(ncclReduceScatter( + TF_RETURN_IF_ERROR(XLA_NCCL_STATUS(ncclReduceScatter( send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, 
count), nccl_dtype, ToNcclReduction(reduction_kind), comm_, - se::gpu::AsGpuStreamValue(stream))); + se::gpu::AsGpuStreamValue(stream)))); + + return OkEvent(); } -absl::Status NcclCommunicator::AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - const Executor& executor) { +tsl::AsyncValueRef NcclCommunicator::AllGather( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, const Executor& executor) { if (aborted_) { return absl::FailedPreconditionError("NcclCommunicator aborted"); } @@ -325,12 +329,14 @@ absl::Status NcclCommunicator::AllGather(se::DeviceMemoryBase send_buffer, TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); - return XLA_NCCL_STATUS(ncclAllGather( + TF_RETURN_IF_ERROR(XLA_NCCL_STATUS(ncclAllGather( send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), - nccl_dtype, comm_, se::gpu::AsGpuStreamValue(stream))); + nccl_dtype, comm_, se::gpu::AsGpuStreamValue(stream)))); + + return OkEvent(); } -absl::Status NcclCommunicator::AllToAll( +tsl::AsyncValueRef NcclCommunicator::AllToAll( absl::Span send_buffers, absl::Span recv_buffers, PrimitiveType dtype, size_t count, const Executor& executor) { @@ -385,10 +391,10 @@ absl::Status NcclCommunicator::AllToAll( XLA_NCCL_RETURN_IF_ERROR(ncclGroupEnd()); - return absl::OkStatus(); + return OkEvent(); } -absl::Status NcclCommunicator::CollectivePermute( +tsl::AsyncValueRef NcclCommunicator::CollectivePermute( se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, std::optional source_rank, absl::Span target_ranks, const Executor& executor) { @@ -414,7 +420,7 @@ absl::Status NcclCommunicator::CollectivePermute( // Short-circuit if there is no source or target rank. 
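(A note on semantics: for the NCCL backend the returned event is already available when the call returns, because ncclAllReduce and friends merely enqueue work on the GPU stream. OkEvent() therefore reports enqueue success, not device-side completion; a caller that needs completion still synchronizes the stream. Sketch, with `stream` being the se::Stream the op was enqueued on:)

  tsl::BlockUntilReady(event);                        // enqueue finished
  if (event.IsError()) return event.GetError();       // enqueue failed
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());   // device completion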
if (!source_rank && target_ranks.empty()) { - return absl::OkStatus(); + return OkEvent(); } XLA_NCCL_RETURN_IF_ERROR(ncclGroupStart()); @@ -433,12 +439,12 @@ absl::Status NcclCommunicator::CollectivePermute( XLA_NCCL_RETURN_IF_ERROR(ncclGroupEnd()); - return absl::OkStatus(); + return OkEvent(); } -absl::Status NcclCommunicator::Send(se::DeviceMemoryBase send_buffer, - PrimitiveType dtype, size_t count, - RankId peer, const Executor& executor) { +tsl::AsyncValueRef NcclCommunicator::Send( + se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, + RankId peer, const Executor& executor) { if (aborted_) { return absl::FailedPreconditionError("NcclCommunicator aborted"); } @@ -453,14 +459,16 @@ absl::Status NcclCommunicator::Send(se::DeviceMemoryBase send_buffer, TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); - return XLA_NCCL_STATUS( + TF_RETURN_IF_ERROR(XLA_NCCL_STATUS( ncclSend(send_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, - peer.value(), comm_, se::gpu::AsGpuStreamValue(stream))); + peer.value(), comm_, se::gpu::AsGpuStreamValue(stream)))); + + return OkEvent(); } -absl::Status NcclCommunicator::Recv(se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - RankId peer, const Executor& executor) { +tsl::AsyncValueRef NcclCommunicator::Recv( + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, + RankId peer, const Executor& executor) { if (aborted_) { return absl::FailedPreconditionError("NcclCommunicator aborted"); } @@ -475,9 +483,11 @@ absl::Status NcclCommunicator::Recv(se::DeviceMemoryBase recv_buffer, TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); - return XLA_NCCL_STATUS( + TF_RETURN_IF_ERROR(XLA_NCCL_STATUS( ncclRecv(recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, - peer.value(), comm_, se::gpu::AsGpuStreamValue(stream))); + peer.value(), comm_, se::gpu::AsGpuStreamValue(stream)))); + + return OkEvent(); } std::string NcclCommunicator::ToString() const { diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h index 77bc950b983fa0..121269e6fddcda 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h @@ -30,6 +30,8 @@ limitations under the License. 
#include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/async_value.h" +#include "xla/tsl/concurrency/async_value_ref.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" @@ -63,43 +65,46 @@ class NcclCommunicator : public Communicator { absl::StatusOr> RegisterBuffer( se::DeviceMemoryBase buffer) final; - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) final; - - absl::Status Broadcast(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, RankId root, - const Executor& executor) final; - - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) final; - - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) final; - - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) final; - - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, + tsl::AsyncValueRef AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) final; + + tsl::AsyncValueRef Broadcast(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + RankId root, + const Executor& executor) final; + + tsl::AsyncValueRef ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) final; + + tsl::AsyncValueRef AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) final; + + tsl::AsyncValueRef AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) final; + + tsl::AsyncValueRef CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) final; + + tsl::AsyncValueRef Send(se::DeviceMemoryBase send_buffer, + PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) final; - absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, - size_t count, RankId peer, const Executor& executor) final; - - absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, RankId peer, const Executor& executor) final; + tsl::AsyncValueRef Recv(se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, RankId peer, + const Executor& executor) final; std::string ToString() const final; diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator_test.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator_test.cc index f69500373b1cdb..372079e66e2f50 100644 --- 
a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator_test.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator_test.cc @@ -27,9 +27,11 @@ limitations under the License. #include "absl/utility/utility.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/collectives/nccl_errors.h" +#include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/errors.h" @@ -112,6 +114,15 @@ TEST(NcclCommunicator, OperationsFailAfterAbort) { HasSubstr("aborted"))); }; + auto assert_event_aborted = + [](tsl::AsyncValueRef event) { + tsl::BlockUntilReady(event); + ASSERT_TRUE(event.IsError()); + ASSERT_THAT(event.GetError(), + StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("aborted"))); + }; + // Declare placeholder variables to make the operations below compile. se::DeviceMemoryBase buf; PrimitiveType dtype = PrimitiveType::U64; @@ -130,15 +141,17 @@ TEST(NcclCommunicator, OperationsFailAfterAbort) { assert_aborted(comm->HealthCheck()); assert_aborted(comm->NumRanks().status()); assert_aborted(comm->RegisterBuffer(buf).status()); - assert_aborted(comm->AllReduce(buf, buf, dtype, count, rk, executor)); - assert_aborted(comm->Broadcast(buf, buf, dtype, count, RankId(0), executor)); - assert_aborted(comm->ReduceScatter(buf, buf, dtype, count, rk, executor)); - assert_aborted(comm->AllGather(buf, buf, dtype, count, executor)); - assert_aborted(comm->AllToAll({}, {}, dtype, count, executor)); - assert_aborted( + assert_event_aborted(comm->AllReduce(buf, buf, dtype, count, rk, executor)); + assert_event_aborted( + comm->Broadcast(buf, buf, dtype, count, RankId(0), executor)); + assert_event_aborted( + comm->ReduceScatter(buf, buf, dtype, count, rk, executor)); + assert_event_aborted(comm->AllGather(buf, buf, dtype, count, executor)); + assert_event_aborted(comm->AllToAll({}, {}, dtype, count, executor)); + assert_event_aborted( comm->CollectivePermute(buf, buf, dtype, count, {}, {}, executor)); - assert_aborted(comm->Send(buf, dtype, count, RankId(0), executor)); - assert_aborted(comm->Recv(buf, dtype, count, RankId(0), executor)); + assert_event_aborted(comm->Send(buf, dtype, count, RankId(0), executor)); + assert_event_aborted(comm->Recv(buf, dtype, count, RankId(0), executor)); } } // namespace diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 41082e9801cb22..b9a102462006b4 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -728,6 +728,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu/transforms/collectives:collective_ops_utils", "//xla/stream_executor:stream", + "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", @@ -762,6 +763,7 @@ cc_library( "//xla/stream_executor:event", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", @@ -799,6 +801,7 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:stream", + "//xla/tsl/concurrency:async_value", 
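(The assert_event_aborted helper above works for any errored event; when unit-testing event consumers without a real communicator, an equivalent errored event can be built directly. A sketch, reusing the placeholder Event payload from earlier and assuming the usual AsyncValueRef error constructors:)

  tsl::AsyncValueRef<Event> failed = tsl::MakeErrorAsyncValueRef(
      absl::FailedPreconditionError("NcclCommunicator aborted"));
  tsl::BlockUntilReady(failed);
  ASSERT_TRUE(failed.IsError());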
"//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", @@ -838,6 +841,7 @@ cc_library( "//xla/stream_executor:event", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:stream", + "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", @@ -872,6 +876,7 @@ cc_library( "//xla/service/gpu/transforms/collectives:collective_ops_utils", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", + "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", @@ -1015,6 +1020,7 @@ cc_library( "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", + "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", @@ -1043,6 +1049,7 @@ cc_library( "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", + "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", @@ -1511,6 +1518,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/stream_executor/gpu:ragged_all_to_all_kernel", + "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/backends/gpu/runtime/all_gather_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/all_gather_thunk.cc index 07bf3899807961..77b6fc65ed4445 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_gather_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_gather_thunk.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -112,9 +113,14 @@ absl::Status RunAllGather(GpuCollectives* collectives, TF_RETURN_IF_ERROR(collectives->GroupStart()); for (DeviceBufferPair& buffer : buffers) { - TF_RETURN_IF_ERROR(comm->AllGather( + auto event = comm->AllGather( buffer.source_buffer, buffer.destination_buffer, buffer.element_type, - buffer.element_count, GpuCollectives::On(stream))); + buffer.element_count, GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } } TF_RETURN_IF_ERROR(collectives->GroupEnd()); diff --git a/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc index d2a9e5ed05f43b..76f7ca57e28fbc 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc @@ -50,6 +50,7 @@ limitations under the License. 
#include "xla/stream_executor/event.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -241,9 +242,14 @@ absl::Status RunAllReduce(GpuCollectives* collectives, TF_RETURN_IF_ERROR(collectives->GroupStart()); for (DeviceBufferPair& buffer : buffers) { - TF_RETURN_IF_ERROR(comm->AllReduce( + auto event = comm->AllReduce( buffer.source_buffer, buffer.destination_buffer, buffer.element_type, - buffer.element_count, reduction_kind, GpuCollectives::On(stream))); + buffer.element_count, reduction_kind, GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } } return collectives->GroupEnd(); @@ -454,10 +460,15 @@ absl::Status RunReduceScatter(GpuCollectives* collectives, << "Source buffer was not an exact multiple of the number of " "participants."; - TF_RETURN_IF_ERROR(comm->ReduceScatter( + auto event = comm->ReduceScatter( buffer.source_buffer, buffer.destination_buffer, buffer.element_type, buffer.element_count / num_ranks, reduction_kind, - GpuCollectives::On(stream))); + GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } } return collectives->GroupEnd(); diff --git a/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc index 3f5bb1c464902c..fd3fba1d2f8b28 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -237,18 +238,30 @@ absl::Status RunAllToAll(GpuCollectives* collectives, bool has_split_dimension, } } - return comm->AllToAll(send_buffers, recv_buffers, element_type, - chunk_element_count, GpuCollectives::On(stream)); + auto event = + comm->AllToAll(send_buffers, recv_buffers, element_type, + chunk_element_count, GpuCollectives::On(stream)); + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } } else { for (const DeviceBufferPair& buffer : buffers) { send_buffers.push_back(buffer.source_buffer); recv_buffers.push_back(buffer.destination_buffer); } - return comm->AllToAll(send_buffers, recv_buffers, element_type, - element_count, GpuCollectives::On(stream)); + auto event = comm->AllToAll(send_buffers, recv_buffers, element_type, + element_count, GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } } + + return absl::OkStatus(); } static absl::Status SendPtrToPeer(void* ptr, RankId peer, Communicator* comm, @@ -257,8 +270,15 @@ static absl::Status SendPtrToPeer(void* ptr, RankId peer, Communicator* comm, "RecvPtrFromPeer on device #%d; peer=%d; comm=%p; stream=%p", stream.parent()->device_ordinal(), peer.value(), comm, &stream); - return comm->Send(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, peer, - GpuCollectives::On(stream)); + auto event = comm->Send(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, + peer, GpuCollectives::On(stream)); + + 
tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } + + return absl::OkStatus(); } static absl::Status RecvPtrFromPeer(void* ptr, RankId peer, Communicator* comm, @@ -267,8 +287,15 @@ static absl::Status RecvPtrFromPeer(void* ptr, RankId peer, Communicator* comm, "RecvPtrFromPeer on device #%d; peer=%d; comm=%p; stream=%p", stream.parent()->device_ordinal(), peer.value(), comm, &stream); - return comm->Recv(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, peer, - GpuCollectives::On(stream)); + auto event = comm->Recv(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, + peer, GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } + + return absl::OkStatus(); } // TODO(b/380457503): Memcpy AllToAll implementation must be moved to diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_broadcast_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_broadcast_thunk.cc index 43e0e163da55d7..624d407fee1e9f 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_broadcast_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_broadcast_thunk.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/service/gpu/transforms/collectives/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" @@ -76,11 +77,16 @@ absl::Status RunCollectiveBroadcast(std::vector& buffers, for (auto buffer : buffers) { se::DeviceMemoryBase src_addr = buffer.source_buffer; se::DeviceMemoryBase dest_addr = buffer.destination_buffer; - TF_RETURN_IF_ERROR(comm->Broadcast( + auto event = comm->Broadcast( // Always use rank 0 since we always broadcast from the first id in // replica_groups src_addr, dest_addr, buffer.element_type, buffer.element_count, - RankId(0), GpuCollectives::On(stream))); + RankId(0), GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } } return collectives->GroupEnd(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc index 6bb17ce4a8fd13..f770c4dfee4f26 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc @@ -52,6 +52,7 @@ limitations under the License. 
#include "xla/service/rendezvous.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" @@ -409,9 +410,14 @@ absl::Status RunCollectivePermute( const auto src_addr = src_addrs.at(idx); const auto dest_addr = dest_addrs.at(idx); const auto buffer = buffers.at(idx); - TF_RETURN_IF_ERROR(comm->CollectivePermute( + auto event = comm->CollectivePermute( src_addr, dest_addr, buffer.element_type, buffer.element_count, - source_rank, target_ranks, GpuCollectives::On(stream))); + source_rank, target_ranks, GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } } if (is_nccl_group_needed) { diff --git a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc index feaae49dd850ca..fe200465da04d0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc @@ -53,6 +53,7 @@ limitations under the License. #include "xla/stream_executor/event.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" @@ -120,13 +121,23 @@ absl::Status RunAllToAllOnIndexBuffer( collectives->Slice(destination_buffer, element_type, offset, /*count=*/num_updates_per_replica); - TF_RETURN_IF_ERROR(comm->Send(send_slice, element_type, - /*count=*/num_updates_per_replica, - RankId(peer), GpuCollectives::On(stream))); + auto event_send = comm->Send(send_slice, element_type, + /*count=*/num_updates_per_replica, + RankId(peer), GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(comm->Recv(recv_slice, element_type, - /*count=*/num_updates_per_replica, - RankId(peer), GpuCollectives::On(stream))); + tsl::BlockUntilReady(event_send); + if (event_send.IsError()) { + return event_send.GetError(); + } + + auto event_recv = comm->Recv(recv_slice, element_type, + /*count=*/num_updates_per_replica, + RankId(peer), GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event_recv); + if (event_recv.IsError()) { + return event_recv.GetError(); + } } TF_RETURN_IF_ERROR(collectives->GroupEnd()); @@ -191,13 +202,23 @@ absl::Status RunRaggedAllToAll( output_offsets[idx] * ragged_row_element_size, recv_sizes[idx] * ragged_row_element_size); - TF_RETURN_IF_ERROR(comm->Send(send_slice, element_type, - send_sizes[idx] * ragged_row_element_size, - RankId(peer), GpuCollectives::On(stream))); + auto event_send = comm->Send(send_slice, element_type, + send_sizes[idx] * ragged_row_element_size, + RankId(peer), GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(comm->Recv(recv_slice, element_type, - recv_sizes[idx] * ragged_row_element_size, - RankId(peer), GpuCollectives::On(stream))); + tsl::BlockUntilReady(event_send); + if (event_send.IsError()) { + return event_send.GetError(); + } + + auto event_recv = comm->Recv(recv_slice, element_type, + recv_sizes[idx] * ragged_row_element_size, + RankId(peer), GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event_recv); + if (event_recv.IsError()) { + return event_recv.GetError(); + } } } diff --git a/third_party/xla/xla/backends/gpu/runtime/recv_thunk.cc 
b/third_party/xla/xla/backends/gpu/runtime/recv_thunk.cc index b9b0953eceded4..17c15e3190720a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/recv_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/recv_thunk.cc @@ -36,6 +36,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" @@ -130,9 +131,14 @@ absl::Status RecvThunk::RunCollective(const ExecuteParams& params, ++(*counter); } if (should_run) { - TF_RETURN_IF_ERROR(comm_handle.comm->Recv( + auto event = comm_handle.comm->Recv( dest_addr, buffer.element_type, buffer.element_count, - RankId(*source_id), GpuCollectives::On(stream))); + RankId(*source_id), GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } } else { VLOG(3) << "Skipping Recv"; } diff --git a/third_party/xla/xla/backends/gpu/runtime/send_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/send_thunk.cc index 0e4a227c9a9a1c..29a8aeb27e595e 100644 --- a/third_party/xla/xla/backends/gpu/runtime/send_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/send_thunk.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" @@ -131,9 +132,14 @@ absl::Status SendThunk::RunCollective(const ExecuteParams& params, } if (should_run) { - TF_RETURN_IF_ERROR(comm_handle.comm->Send( + auto event = comm_handle.comm->Send( src_addr, buffer.element_type, buffer.element_count, - RankId(*target_id), GpuCollectives::On(stream))); + RankId(*target_id), GpuCollectives::On(stream)); + + tsl::BlockUntilReady(event); + if (event.IsError()) { + return event.GetError(); + } } else { VLOG(3) << "Skipping Send"; } diff --git a/third_party/xla/xla/core/collectives/BUILD b/third_party/xla/xla/core/collectives/BUILD index 85f30c2c76b0bf..3a2658ba5ac928 100644 --- a/third_party/xla/xla/core/collectives/BUILD +++ b/third_party/xla/xla/core/collectives/BUILD @@ -73,6 +73,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/core/collectives/communicator.h b/third_party/xla/xla/core/collectives/communicator.h index af95f7063fc803..883faeeabab8fa 100644 --- a/third_party/xla/xla/core/collectives/communicator.h +++ b/third_party/xla/xla/core/collectives/communicator.h @@ -28,14 +28,25 @@ limitations under the License. #include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/concurrency/chain.h" #include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { // Collective communicator defines the set of communicating XLA processes. +// +// Returned async value signals that the communicator has successfully +// launched the operation on the underlying executor. +// Completion of the operation depends on the backend implementation. i.e. 
on
+// GPU the async value becomes available when the operation is scheduled on the
+// device stream, and on CPU it becomes available when the operation is
+// completed.
 class Communicator {
  public:
+  using Event = tsl::Chain;
+
   virtual ~Communicator() = default;
 
   // An executor is an abstraction for the underlying resource where collective
@@ -75,67 +86,72 @@ class Communicator {
   // Reduce buffers of length `count` in `send_buff` using `reduction_kind`
   // reduction and leaves identical copies of the result on each `recv_buff`.
-  virtual absl::Status AllReduce(stream_executor::DeviceMemoryBase send_buffer,
-                                 stream_executor::DeviceMemoryBase recv_buffer,
-                                 PrimitiveType dtype, size_t count,
-                                 ReductionKind reduction_kind,
-                                 const Executor& executor) = 0;
+  virtual tsl::AsyncValueRef<Event> AllReduce(
+      stream_executor::DeviceMemoryBase send_buffer,
+      stream_executor::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+      size_t count, ReductionKind reduction_kind, const Executor& executor) = 0;
 
   // Copy data in `send_buff` from the root device to the `recv_buff` on
   // all other devices.
-  virtual absl::Status Broadcast(se::DeviceMemoryBase send_buffer,
-                                 se::DeviceMemoryBase recv_buffer,
-                                 PrimitiveType dtype, size_t count, RankId root,
-                                 const Executor& executor) = 0;
+  virtual tsl::AsyncValueRef<Event> Broadcast(se::DeviceMemoryBase send_buffer,
+                                              se::DeviceMemoryBase recv_buffer,
+                                              PrimitiveType dtype, size_t count,
+                                              RankId root,
+                                              const Executor& executor) = 0;
 
   // Reduce data in `send_buff` from all devices using the `reduction_kind`
   // operation and leave the reduced result scattered over the devices so that
   // the `recv_buff` on rank `i` will contain the i-th block of the result.
-  virtual absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
-                                     se::DeviceMemoryBase recv_buffer,
-                                     PrimitiveType dtype, size_t count,
-                                     ReductionKind reduction_kind,
-                                     const Executor& executor) = 0;
+  virtual tsl::AsyncValueRef<Event> ReduceScatter(
+      se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
+      PrimitiveType dtype, size_t count, ReductionKind reduction_kind,
+      const Executor& executor) = 0;
 
   // Gather `count` values from all devices into `recv_buffer`, receiving data
   // from rank `i` at offset `i * sendcount`.
-  virtual absl::Status AllGather(se::DeviceMemoryBase send_buffer,
-                                 se::DeviceMemoryBase recv_buffer,
-                                 PrimitiveType dtype, size_t count,
-                                 const Executor& executor) = 0;
+  virtual tsl::AsyncValueRef<Event> AllGather(se::DeviceMemoryBase send_buffer,
+                                              se::DeviceMemoryBase recv_buffer,
+                                              PrimitiveType dtype, size_t count,
+                                              const Executor& executor) = 0;
 
   // Sends data from `send_buffer` to `target_ranks` and receives data from
   // `source_rank` into `recv_buffer`. If `source_rank` is not specified, the
   // output is filled with zeros.
-  virtual absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
-                                         se::DeviceMemoryBase recv_buffer,
-                                         PrimitiveType dtype, size_t count,
-                                         std::optional<RankId> source_rank,
-                                         absl::Span<const RankId> target_ranks,
-                                         const Executor& executor) = 0;
+  virtual tsl::AsyncValueRef<Event> CollectivePermute(
+      se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
+      PrimitiveType dtype, size_t count, std::optional<RankId> source_rank,
+      absl::Span<const RankId> target_ranks, const Executor& executor) = 0;
 
   // Sends `count` values from `send_buffers` to other ranks and receives data
   // from other ranks into `recv_buffers`.
-  virtual absl::Status AllToAll(
+  virtual tsl::AsyncValueRef<Event> AllToAll(
       absl::Span<const se::DeviceMemoryBase> send_buffers,
       absl::Span<const se::DeviceMemoryBase> recv_buffers,
      PrimitiveType dtype, size_t count, const Executor& executor) = 0;
 
   // Send data from `send_buff` to rank `peer`.
-  virtual absl::Status Send(se::DeviceMemoryBase send_buffer,
-                            PrimitiveType dtype, size_t count, RankId peer,
-                            const Executor& executor) = 0;
+  virtual tsl::AsyncValueRef<Event> Send(se::DeviceMemoryBase send_buffer,
+                                         PrimitiveType dtype, size_t count,
+                                         RankId peer,
+                                         const Executor& executor) = 0;
 
   // Receive data from rank `peer` into `recv_buff`.
-  virtual absl::Status Recv(se::DeviceMemoryBase recv_buffer,
-                            PrimitiveType dtype, size_t count, RankId peer,
-                            const Executor& executor) = 0;
+  virtual tsl::AsyncValueRef<Event> Recv(se::DeviceMemoryBase recv_buffer,
+                                         PrimitiveType dtype, size_t count,
+                                         RankId peer,
+                                         const Executor& executor) = 0;
 
   // Returns the number of ranks in the communicator.
   virtual absl::StatusOr<size_t> NumRanks() const = 0;
 
   // Returns a human-readable description of the communicator.
   virtual std::string ToString() const = 0;
+
+ protected:
+  // Returns an `Event` that is always available.
+  static tsl::AsyncValueRef<Event> OkEvent() {
+    return tsl::MakeAvailableAsyncValueRef<Event>();
+  }
 };
 
 inline std::ostream& operator<<(std::ostream& os, const Communicator& comm) {
diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD
index 759ede6a3771f6..dd1a5d4f89a629 100644
--- a/third_party/xla/xla/service/cpu/BUILD
+++ b/third_party/xla/xla/service/cpu/BUILD
@@ -1115,6 +1115,7 @@ cc_library(
         "//xla/service:global_device_id",
         "//xla/stream_executor:device_memory",
         "//xla/stream_executor:stream_executor_h",
+        "//xla/tsl/concurrency:async_value",
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:logging",
         "//xla/tsl/platform:status",
diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc
index 95c383f9d57d86..6f2d5731274f71 100644
--- a/third_party/xla/xla/service/cpu/cpu_runtime.cc
+++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc
@@ -57,6 +57,7 @@ limitations under the License.
#include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/status.h" @@ -409,9 +410,10 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, se::DeviceMemoryBase(destination_buffers[i], buffer_size)); } - TF_CHECK_OK(communicator->AllToAll(source_buffers_data, - destination_buffers_data, U8, buffer_size, - executor)); + auto event = communicator->AllToAll( + source_buffers_data, destination_buffers_data, U8, buffer_size, executor); + tsl::BlockUntilReady(event); + TF_CHECK_OK(event.GetError()); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY @@ -441,8 +443,10 @@ void AllGatherImpl(const ExecutableRunOptions* run_options, se::DeviceMemoryBase output_buffer_data(destination_buffer, buffer_size); CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout()); - TF_CHECK_OK(communicator->AllGather(input_buffer_data, output_buffer_data, U8, - buffer_size, executor)); + auto event = communicator->AllGather(input_buffer_data, output_buffer_data, + U8, buffer_size, executor); + tsl::BlockUntilReady(event); + TF_CHECK_OK(event.GetError()); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY @@ -479,9 +483,12 @@ void ReduceScatterImpl(const ExecutableRunOptions* run_options, primitive_util::ByteWidth(dtype)); CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout()); - TF_CHECK_OK(communicator->ReduceScatter( + auto event = communicator->ReduceScatter( input_buffer_data, output_buffer_data, dtype, chunk_elems, - static_cast(reduction_kind), executor)); + static_cast(reduction_kind), executor); + + tsl::BlockUntilReady(event); + TF_CHECK_OK(event.GetError()); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY @@ -532,10 +539,12 @@ void AllReduceImpl(const ExecutableRunOptions* run_options, for (int i = 0; i < num_buffers; i++) { Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i); - TF_CHECK_OK(communicator->AllReduce( + auto event = communicator->AllReduce( input_buffers_data[i], output_buffers_data[i], subshape.element_type(), ShapeUtil::ElementsIn(subshape), - static_cast(reduction_kind), executor)); + static_cast(reduction_kind), executor); + tsl::BlockUntilReady(event); + TF_CHECK_OK(event.GetError()); } } @@ -586,9 +595,11 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, se::DeviceMemoryBase input_buffer_data(input_buffer, byte_size); se::DeviceMemoryBase output_buffer_data(output_buffer, byte_size); - TF_CHECK_OK(communicator->CollectivePermute( + auto event = communicator->CollectivePermute( input_buffer_data, output_buffer_data, U8, byte_size, source_replica_id, - copy_to, executor)); + copy_to, executor); + tsl::BlockUntilReady(event); + TF_CHECK_OK(event.GetError()); } } // namespace } // namespace runtime From f84c6d38f805d50e0b6b937dddb1e52728ceb2ea Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 16 Apr 2025 17:13:00 -0700 Subject: [PATCH 0892/1324] Use kDeviceTypeString instead of kDeviceType PiperOrigin-RevId: 748473543 --- third_party/xla/xla/tsl/profiler/utils/xplane_schema.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h index 537c76f19de3fd..45d197e0b7ced2 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h @@ -315,7 +315,7 @@ enum StatType { kMinDurationPs, kTotalProfileDurationPs, kMaxIterationNum, - kDeviceType, + kDeviceType, // Do not use. Use kDeviceTypeString instead. kUsesMegaCore, kSymbolId, kTfOpName, From a8917ece166f755dade1a39653f9abe641f42589 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 17:17:36 -0700 Subject: [PATCH 0893/1324] Update ops-related pbtxt files. PiperOrigin-RevId: 748474753 --- ...ulCustomCombinerOnTcGradWithCsrInput.pbtxt | 100 ++++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 100 ++++++++++++++++++ 2 files changed, 200 insertions(+) create mode 100644 tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt new file mode 100644 index 00000000000000..901953a56a4fd7 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt @@ -0,0 +1,100 @@ +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "tables" + type: DT_FLOAT + number_attr: "N" + } + input_arg { + name: "hyperparameters" + type: DT_FLOAT + number_attr: "M" + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + output_arg { + name: "updated_tables" + type: DT_FLOAT + number_attr: "N" + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "M" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "optimizer_custom_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 3650e9d60bed3d..3fd3aed6945788 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -67609,6 +67609,106 @@ op { has_minimum: true } } +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + 
input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "tables" + type: DT_FLOAT + number_attr: "N" + } + input_arg { + name: "hyperparameters" + type: DT_FLOAT + number_attr: "M" + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + output_arg { + name: "updated_tables" + type: DT_FLOAT + number_attr: "N" + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "M" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "optimizer_custom_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} op { name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" input_arg { From 4dbae670e3a5c6e15c2dd9192ee229e56c84a3c7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 17:43:39 -0700 Subject: [PATCH 0894/1324] * Add a new Printer implementation, based on HighwayHash, that incrementally computes fingerprints. * Add a new Printer method that can print a series of int64s. PiperOrigin-RevId: 748480794 --- third_party/xla/third_party/highwayhash/BUILD | 3 + .../third_party/highwayhash/highwayhash.BUILD | 296 ++++++++++++++++++ .../xla/third_party/highwayhash/workspace.bzl | 12 + third_party/xla/tsl_workspace2.bzl | 2 + third_party/xla/xla/BUILD | 1 + third_party/xla/xla/hlo/ir/BUILD | 4 + .../xla/xla/hlo/ir/collective_device_list.cc | 20 ++ .../xla/xla/hlo/ir/collective_device_list.h | 7 + .../xla/xla/hlo/ir/hlo_instructions.cc | 4 +- third_party/xla/xla/hlo/ir/hlo_module.cc | 55 +++- third_party/xla/xla/hlo/ir/hlo_module.h | 3 + third_party/xla/xla/hlo/ir/hlo_module_test.cc | 25 ++ third_party/xla/xla/printer.cc | 12 + third_party/xla/xla/printer.h | 8 + 14 files changed, 448 insertions(+), 4 deletions(-) create mode 100644 third_party/xla/third_party/highwayhash/BUILD create mode 100644 third_party/xla/third_party/highwayhash/highwayhash.BUILD create mode 100644 third_party/xla/third_party/highwayhash/workspace.bzl diff --git a/third_party/xla/third_party/highwayhash/BUILD b/third_party/xla/third_party/highwayhash/BUILD new file mode 100644 index 00000000000000..9e2309bed15565 --- /dev/null +++ b/third_party/xla/third_party/highwayhash/BUILD @@ -0,0 +1,3 @@ +# Dummy BUILD file to make this directory a package. 
+ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/xla/third_party/highwayhash/highwayhash.BUILD b/third_party/xla/third_party/highwayhash/highwayhash.BUILD new file mode 100644 index 00000000000000..c24c987a276acd --- /dev/null +++ b/third_party/xla/third_party/highwayhash/highwayhash.BUILD @@ -0,0 +1,296 @@ +# Description: +# SipHash and HighwayHash: cryptographically-strong pseudorandom functions +package( + default_visibility = ["//visibility:public"], + features = ["header_modules"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +config_setting( + name = "haswell", + values = {"cpu": "haswell"}, +) + +config_setting( + name = "k8", + values = {"cpu": "k8"}, +) + +config_setting( + name = "cpu_ppc", + values = {"cpu": "ppc"}, +) + +config_setting( + name = "cpu_aarch64", + values = {"cpu": "aarch64"}, +) + +#----------------------------------------------------------------------------- +# Platform-specific + +cc_library( + name = "compiler_specific", + hdrs = ["highwayhash/compiler_specific.h"], +) + +cc_library( + name = "arch_specific", + srcs = ["highwayhash/arch_specific.cc"], + hdrs = ["highwayhash/arch_specific.h"], + deps = [":compiler_specific"], +) + +cc_library( + name = "endianess", + hdrs = ["highwayhash/endianess.h"], +) + +cc_library( + name = "instruction_sets", + srcs = ["highwayhash/instruction_sets.cc"], + hdrs = ["highwayhash/instruction_sets.h"], + deps = [ + ":arch_specific", + ":compiler_specific", + ], +) + +cc_library( + name = "iaca", + hdrs = ["highwayhash/iaca.h"], + deps = [":compiler_specific"], +) + +cc_library( + name = "os_specific", + srcs = ["highwayhash/os_specific.cc"], + hdrs = ["highwayhash/os_specific.h"], + deps = [ + ":arch_specific", + ":compiler_specific", + ], +) + +#----------------------------------------------------------------------------- +# Vectors + +cc_library( + name = "scalar", + textual_hdrs = ["highwayhash/scalar.h"], + deps = [ + ":arch_specific", + ":compiler_specific", + ], +) + +cc_library( + name = "vector128", + textual_hdrs = ["highwayhash/vector128.h"], + deps = [ + ":arch_specific", + ":compiler_specific", + ], +) + +cc_library( + name = "vector256", + textual_hdrs = ["highwayhash/vector256.h"], + deps = [ + ":arch_specific", + ":compiler_specific", + ], +) + +#----------------------------------------------------------------------------- +# SipHash + +cc_library( + name = "sip_hash", + srcs = ["highwayhash/sip_hash.cc"], + hdrs = [ + "highwayhash/sip_hash.h", + "highwayhash/state_helpers.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":arch_specific", + ":compiler_specific", + ":endianess", + ], +) + +#----------------------------------------------------------------------------- +# HighwayHash + +cc_library( + name = "hh_types", + hdrs = ["highwayhash/hh_types.h"], + deps = [":instruction_sets"], +) + +cc_library( + name = "load3", + textual_hdrs = ["highwayhash/load3.h"], + deps = [ + ":arch_specific", + ":compiler_specific", + ":endianess", + ], +) + +cc_library( + name = "hh_avx2", + srcs = ["highwayhash/hh_avx2.cc"], + hdrs = ["highwayhash/highwayhash_target.h"], + copts = select({ + ":k8": ["-mavx2"], + ":haswell": ["-mavx2"], + "//conditions:default": ["-DHH_DISABLE_TARGET_SPECIFIC"], + }), + textual_hdrs = [ + "highwayhash/hh_avx2.h", + "highwayhash/highwayhash_target.cc", + "highwayhash/highwayhash.h", + "highwayhash/hh_buffer.h", + ], + deps = [ + ":arch_specific", + ":compiler_specific", + ":hh_types", + 
":iaca", + ":load3", + ":vector128", + ":vector256", + ], +) + +cc_library( + name = "hh_sse41", + srcs = ["highwayhash/hh_sse41.cc"], + hdrs = ["highwayhash/highwayhash_target.h"], + copts = select({ + ":k8": ["-msse4.1"], + ":haswell": ["-msse4.1"], + "//conditions:default": ["-DHH_DISABLE_TARGET_SPECIFIC"], + }), + textual_hdrs = [ + "highwayhash/hh_sse41.h", + "highwayhash/highwayhash_target.cc", + "highwayhash/highwayhash.h", + "highwayhash/hh_buffer.h", + ], + deps = [ + ":arch_specific", + ":compiler_specific", + ":hh_types", + ":iaca", + ":load3", + ":vector128", + ], +) + +cc_library( + name = "hh_neon", + srcs = [ + "highwayhash/hh_neon.cc", + "highwayhash/vector_neon.h", + ], + hdrs = ["highwayhash/highwayhash_target.h"], + copts = select({ + ":cpu_aarch64": [], + "//conditions:default": ["-DHH_DISABLE_TARGET_SPECIFIC"], + }), + textual_hdrs = [ + "highwayhash/highwayhash_target.cc", + "highwayhash/highwayhash.h", + "highwayhash/hh_buffer.h", + "highwayhash/hh_neon.h", + ], + deps = [ + ":arch_specific", + ":compiler_specific", + ":hh_types", + ":load3", + ], +) + +cc_library( + name = "hh_vsx", + srcs = ["highwayhash/hh_vsx.cc"], + hdrs = ["highwayhash/highwayhash_target.h"], + textual_hdrs = [ + "highwayhash/highwayhash_target.cc", + "highwayhash/highwayhash.h", + "highwayhash/hh_buffer.h", + "highwayhash/hh_vsx.h", + ], + deps = [ + ":arch_specific", + ":compiler_specific", + ":hh_types", + ":load3", + ], +) + +cc_library( + name = "hh_portable", + srcs = ["highwayhash/hh_portable.cc"], + hdrs = ["highwayhash/highwayhash_target.h"], + textual_hdrs = [ + "highwayhash/hh_portable.h", + "highwayhash/highwayhash_target.cc", + "highwayhash/highwayhash.h", + "highwayhash/hh_buffer.h", + ], + deps = [ + ":arch_specific", + ":compiler_specific", + ":hh_types", + ":iaca", + ":load3", + ":scalar", + ], +) + +# For users of the HighwayHashT template +cc_library( + name = "highwayhash", + hdrs = ["highwayhash/highwayhash.h"], + deps = [ + ":arch_specific", + ":compiler_specific", + ":hh_portable", + ":hh_types", + ] + select({ + ":cpu_ppc": [":hh_vsx"], + ":cpu_aarch64": [":hh_neon"], + "//conditions:default": [ + ":hh_avx2", + ":hh_sse41", + ":iaca", + ], + }), +) + +# For users of InstructionSets runtime dispatch +cc_library( + name = "highwayhash_dynamic", + hdrs = ["highwayhash/highwayhash_target.h"], + deps = [ + ":arch_specific", + ":compiler_specific", + ":hh_portable", + ":hh_types", + ] + select({ + ":cpu_ppc": [":hh_vsx"], + ":cpu_aarch64": [":hh_neon"], + "//conditions:default": [ + ":hh_avx2", + ":hh_sse41", + ], + }), +) diff --git a/third_party/xla/third_party/highwayhash/workspace.bzl b/third_party/xla/third_party/highwayhash/workspace.bzl new file mode 100644 index 00000000000000..9b2c0ccbec0796 --- /dev/null +++ b/third_party/xla/third_party/highwayhash/workspace.bzl @@ -0,0 +1,12 @@ +"""loads the highwayhash library, used by TF.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + tf_http_archive( + name = "highwayhash", + urls = tf_mirror_urls("https://github.com/google/highwayhash/archive/c13d28517a4db259d738ea4886b1f00352a3cc33.tar.gz"), + sha256 = "c0e2b9931fbcce3bfbcd7999c3c114f404ac0f8b89775a5bbccbcaa501868e58", + strip_prefix = "highwayhash-c13d28517a4db259d738ea4886b1f00352a3cc33", + build_file = "//third_party/highwayhash:highwayhash.BUILD", + ) diff --git a/third_party/xla/tsl_workspace2.bzl b/third_party/xla/tsl_workspace2.bzl index a3243925c95f09..1a42a3389eddc7 100644 --- a/third_party/xla/tsl_workspace2.bzl +++ 
b/third_party/xla/tsl_workspace2.bzl @@ -21,6 +21,7 @@ load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo") load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/gpus:sycl_configure.bzl", "sycl_configure") +load("//third_party/highwayhash:workspace.bzl", highwayhash = "repo") load("//third_party/hwloc:workspace.bzl", hwloc = "repo") load("//third_party/implib_so:workspace.bzl", implib_so = "repo") load("//third_party/llvm:setup.bzl", "llvm_setup") @@ -49,6 +50,7 @@ def _initialize_third_party(): eigen3() farmhash() gemmlowp() + highwayhash() hwloc() implib_so() ml_dtypes() diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 1c1b04a744a478..e18d438b41efb3 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1324,6 +1324,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 82d03ec0abd9e7..85b86473825c97 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -117,6 +117,9 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@highwayhash", + "@highwayhash//:arch_specific", + "@highwayhash//:hh_types", "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:protobuf", ], @@ -201,6 +204,7 @@ xla_cc_test( "//xla/hlo/testlib:filecheck", "//xla/hlo/utils:hlo_query", "//xla/service:hlo_module_config", + "//xla/tsl/platform:status", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/hash", diff --git a/third_party/xla/xla/hlo/ir/collective_device_list.cc b/third_party/xla/xla/hlo/ir/collective_device_list.cc index 5f962c47cb7174..00eea926e2b610 100644 --- a/third_party/xla/xla/hlo/ir/collective_device_list.cc +++ b/third_party/xla/xla/hlo/ir/collective_device_list.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/array.h" +#include "xla/printer.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/platform/logging.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" @@ -58,6 +59,10 @@ std::string IotaReplicaGroupList::ToString() const { return iota_tile_assignment_.ToString(); } +void IotaReplicaGroupList::Print(Printer* printer) const { + iota_tile_assignment_.Print(printer); +} + IotaReplicaGroupListProto IotaReplicaGroupList::ToProto() const { IotaReplicaGroupListProto proto; proto.set_num_replica_groups(num_replica_groups_); @@ -122,6 +127,21 @@ std::string CollectiveDeviceList::ToString( return ReplicaGroupsToString(replica_groups()); } +void CollectiveDeviceList::Print(Printer* printer, + bool print_full_replica_group_list) const { + if (iota_replica_group_list_.has_value() && !print_full_replica_group_list) { + iota_replica_group_list_->Print(printer); + return; + } + printer->Append("{"); + bool leading_comma = false; + for (const ReplicaGroup& group : replica_groups()) { + printer->AppendInt64List(group.replica_ids(), leading_comma); + leading_comma = true; + } + printer->Append("}"); +} + CollectiveDeviceListProto CollectiveDeviceList::ToProto() const { CollectiveDeviceListProto proto; if (iota_replica_group_list_.has_value()) { diff --git a/third_party/xla/xla/hlo/ir/collective_device_list.h b/third_party/xla/xla/hlo/ir/collective_device_list.h index 0c2d3742e29783..b03b9dbf9d1d61 100644 --- a/third_party/xla/xla/hlo/ir/collective_device_list.h +++ b/third_party/xla/xla/hlo/ir/collective_device_list.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/printer.h" #include "xla/service/hlo.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" @@ -67,6 +68,8 @@ class IotaReplicaGroupList { } Array ToArray() const { return iota_tile_assignment_.ToArray(); } + void Print(Printer* printer) const; + std::string ToString() const; IotaReplicaGroupListProto ToProto() const; @@ -105,6 +108,10 @@ class CollectiveDeviceList { const std::optional& iota_replica_group_list() const { return iota_replica_group_list_; } + + void Print(Printer* printer, + bool print_full_replica_group_list = false) const; + std::string ToString(bool print_full_replica_group_list = false) const; CollectiveDeviceListProto ToProto() const; diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index bc5ace122c82d2..3801bbc631b936 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -934,8 +934,8 @@ void HloCollectiveInstruction::PrintExtraAttributesImpl( VLOG(4) << name() << " replica_groups=" << device_list_.ToString(options.print_full_replica_group_list()); - AppendCat(printer, "replica_groups=", - device_list_.ToString(options.print_full_replica_group_list())); + printer->Append("replica_groups="); + device_list_.Print(printer, options.print_full_replica_group_list()); }); if (constrain_layout_) { printer.Next( diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index 51e3bfa6ca744b..c2652e59a65101 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -37,6 +37,9 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "highwayhash/arch_specific.h" +#include "highwayhash/hh_types.h" +#include "highwayhash/highwayhash.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" @@ -488,6 +491,52 @@ absl::Cord HloModule::ToCord(const HloPrintOptions& options) const { return std::move(printer).ToCord(); } +namespace { +// Generated using openssl rand. +static constexpr highwayhash::HHKey kDefaultKey = { + 0x9e0433b546e065d2ull, + 0x0e7ecad49e703760ull, + 0x83d29f20dae229b0ull, + 0x40c1ce3ff9d19a42ull, +}; + +// HighwayHashPrinter is a Printer that computes the fingerprint of the added +// data using a HighwayHash hasher. +class HighwayHashPrinter : public Printer { + public: + HighwayHashPrinter() : hasher_(kDefaultKey) {} + + void Append(const absl::AlphaNum& a) override { + hasher_.Append(a.data(), a.size()); + } + + void AppendInt64List(absl::Span list, + bool _ /*leading_comma*/) override { + // Instead of separators, prefix with the length. This is fine since + // there's no way for the caller to distinguish between the two. + const uint64_t num = list.size(); + hasher_.Append(reinterpret_cast(&num), sizeof(num)); + hasher_.Append(reinterpret_cast(list.data()), + list.size() * sizeof(list[0])); + } + + uint64_t ToFingerprint() { + highwayhash::HHResult64 result; + hasher_.Finalize(&result); + return result; + } + + private: + highwayhash::HighwayHashCatT hasher_; +}; +} // namespace + +uint64_t HloModule::ToFingerprint(const HloPrintOptions& options) const { + HighwayHashPrinter printer; + Print(&printer, options); + return printer.ToFingerprint(); +} + HloModuleProto HloModule::ToProto() const { HloModuleProto proto; proto.set_id(unique_id_); @@ -1090,8 +1139,10 @@ class FingerprintMap { uint64_t GetFingerprint(const HloComputation* computation) { auto result = fingerprint_map_.try_emplace(computation, 0); if (result.second) { - result.first->second = - tsl::Fingerprint64(computation->ToString(print_options_)); + HighwayHashPrinter printer; + computation->Print(&printer, print_options_, + computation->MakeInstructionPostOrder()); + result.first->second = printer.ToFingerprint(); } return result.first->second; } diff --git a/third_party/xla/xla/hlo/ir/hlo_module.h b/third_party/xla/xla/hlo/ir/hlo_module.h index a8fc065dcfc5e0..e34ef01ac71a6f 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.h +++ b/third_party/xla/xla/hlo/ir/hlo_module.h @@ -425,6 +425,9 @@ class HloModule { absl::Cord ToCord() const { return ToCord(HloPrintOptions::Default()); } absl::Cord ToCord(const HloPrintOptions& options) const; + // Returns a stable fingerprint of the module using the given print options. + uint64_t ToFingerprint(const HloPrintOptions& options) const; + // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; static absl::StatusOr> CreateFromProto( diff --git a/third_party/xla/xla/hlo/ir/hlo_module_test.cc b/third_party/xla/xla/hlo/ir/hlo_module_test.cc index 342f13e89d4ce2..7db2120bf17937 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_test.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module_test.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/filecheck.h" @@ -35,6 +36,7 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" #include "xla/util.h" @@ -68,6 +70,29 @@ TEST(HloModuleTest, AbslHashValue) { EXPECT_NE(absl::HashOf(module1), absl::HashOf(*module4)); } +TEST(HloModuleTest, ToFingerprint) { + auto fp = [](const HloModule& module) { + return module.ToFingerprint(HloPrintOptions::ModuleFingerprint()); + }; + HloModule module1("m1", HloModuleConfig()); + HloModule module2("m2", HloModuleConfig()); + EXPECT_EQ(fp(module1), fp(module2)); + + absl::string_view hlo = R"( + HloModule m3 + ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT res = f32[] multiply(a, b) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module3, + ParseAndReturnUnverifiedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module4, + ParseAndReturnUnverifiedModule(hlo)); + EXPECT_EQ(fp(*module3), fp(*module4)); + EXPECT_NE(fp(module1), fp(*module4)); +} + TEST(HloModuleTest, MutableAndReadOnlyConfigEquals) { HloModuleConfig config1; config1.set_device_type("GPU"); diff --git a/third_party/xla/xla/printer.cc b/third_party/xla/xla/printer.cc index 3f4dc50c8d9932..2d0bfaf0225399 100644 --- a/third_party/xla/xla/printer.cc +++ b/third_party/xla/xla/printer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/printer.h" +#include #include #include #include @@ -23,10 +24,21 @@ limitations under the License. #include "absl/strings/cord.h" #include "absl/strings/cord_buffer.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "xla/tsl/platform/logging.h" namespace xla { +void Printer::AppendInt64List(absl::Span list, + bool leading_comma) { + if (leading_comma) { + Append(","); + } + Append("{"); + AppendJoin(this, list, ","); + Append("}"); +} + void StringPrinter::Append(const absl::AlphaNum& a) { absl::StrAppend(&result_, a); } diff --git a/third_party/xla/xla/printer.h b/third_party/xla/xla/printer.h index fd6c4314e85ce4..2e951bb0c7e6e1 100644 --- a/third_party/xla/xla/printer.h +++ b/third_party/xla/xla/printer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_PRINTER_H_ #define XLA_PRINTER_H_ +#include #include #include @@ -39,6 +40,13 @@ class Printer { // Appends the given string to the printer. virtual void Append(const absl::AlphaNum& a) = 0; + + // Prints a list of numbers in the format: + // {(,)*} + // , pre-pending a comma if `leading_comma` is true. + // May be overridden in some Printer implementations. + virtual void AppendInt64List(absl::Span list, + bool leading_comma); }; // A printer implementation that accumulates printed strings into `std::string`. From 9bba8c513f585f71c9e6300a6081404a4d2e267c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 16 Apr 2025 17:53:59 -0700 Subject: [PATCH 0895/1324] [xla:runtime] Add support for kExecution and kScheduling execution graph edges XLA:CPU ThunkExecutor steal treats all edges equally. Separating scheduling from execution edges at run time will come next. 
PiperOrigin-RevId: 748483207 --- .../backends/cpu/runtime/thunk_executor.cc | 17 ++- .../xla/backends/cpu/runtime/thunk_executor.h | 3 +- .../cpu/runtime/thunk_sequence_serdes_test.cc | 18 +-- .../gpu/runtime/command_buffer_cmd.cc | 6 +- .../xla/xla/runtime/execution_graph.cc | 104 +++++++++++++----- third_party/xla/xla/runtime/execution_graph.h | 58 ++++++++-- .../xla/xla/runtime/execution_graph_test.cc | 91 ++++++++++++--- third_party/xla/xla/runtime/resource_use.cc | 25 ++++- third_party/xla/xla/runtime/resource_use.h | 9 +- .../xla/xla/runtime/resource_use_test.cc | 10 ++ 10 files changed, 273 insertions(+), 68 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc index aa7ad60ab18029..55ad58f0082588 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc @@ -509,12 +509,14 @@ void ThunkExecutor::ProcessOutEdges( bool is_sink = node.out_edges.empty(); // Append ready nodes to the back of the ready queue. - for (NodeId out_edge : node.out_edges) { - ExecuteState::Node& out_node = state->node(out_edge); + for (NodeEdge out_edge : node.out_edges) { + ExecuteState::Node& out_node = state->node(out_edge.id); int64_t cnt = out_node.counter.fetch_sub(1, std::memory_order_release); DCHECK_GE(cnt, 1) << "Node counter can't drop below 0"; - if (cnt == 1) ready_queue.Push(out_edge); + if (cnt == 1) { + ready_queue.Push(out_edge.id); + } } // Drop the pending sink nodes counter if the node is a sink. @@ -524,7 +526,9 @@ void ThunkExecutor::ProcessOutEdges( // remaining memory writes are visible to the consumer of execute event. bool is_done = state->pending_sink_nodes.fetch_sub(1, std::memory_order_acq_rel) == 1; - if (ABSL_PREDICT_TRUE(!is_done)) return; + if (ABSL_PREDICT_TRUE(!is_done)) { + return; + } // In the unlikely event of an execution error during thunk execution, // forward it to the caller via the execute event. @@ -550,8 +554,9 @@ std::string ThunkExecutor::ToString() const { // Collect names of `in_edges`. std::vector> in_edges(num_thunks_); for (const auto& node_def : execution_graph_.nodes_defs()) { - for (NodeId in_edge : node_def.in_edges) { - in_edges[node_def.id].push_back(thunk_sequence_[in_edge]->info().op_name); + for (NodeEdge in_edge : node_def.in_edges) { + in_edges[node_def.id].push_back( + thunk_sequence_[in_edge.id]->info().op_name); } } diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h index b9261b88cb3c04..7bc634eb47c400 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h @@ -96,6 +96,7 @@ class ThunkExecutor { // We use underlying execution graph nodes to index into the thunk sequence. using NodeId = ExecutionGraph::NodeId; using NodeDef = ExecutionGraph::NodeDef; + using NodeEdge = ExecutionGraph::NodeEdge; // A ready queue that executes nodes in FIFO order. 
class FifoReadyQueue { @@ -185,7 +186,7 @@ class ThunkExecutor { explicit Node(const NodeDef& node_def); alignas(kAtomicAlignment) std::atomic counter; - absl::Span out_edges; + absl::Span out_edges; }; static_assert(std::is_trivially_destructible_v, diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_sequence_serdes_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_sequence_serdes_test.cc index 35041d55b187ff..a09676db37e848 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_sequence_serdes_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_sequence_serdes_test.cc @@ -219,7 +219,8 @@ class ThunkSequenceSerdesTest : public ::testing::Test { } // Thunk creation helper functions. absl::StatusOr> CreateAllGatherThunk( - std::shared_ptr communicator_resource = nullptr) { + std::shared_ptr communicator_resource = + Resource::Create(Resource::Kind::kCollectiveCommunicator)) { TF_RETURN_IF_ERROR(AddBufferAllocations(2)); return AllGatherThunk::Create( @@ -248,7 +249,8 @@ class ThunkSequenceSerdesTest : public ::testing::Test { } absl::StatusOr> CreateAllReduceThunk( - std::shared_ptr communicator_resource = nullptr) { + std::shared_ptr communicator_resource = + Resource::Create(Resource::Kind::kCollectiveCommunicator)) { TF_RETURN_IF_ERROR(AddBufferAllocations(2)); return AllReduceThunk::Create( @@ -278,7 +280,8 @@ class ThunkSequenceSerdesTest : public ::testing::Test { } absl::StatusOr> CreateAllToAllThunk( - std::shared_ptr communicator_resource = nullptr) { + std::shared_ptr communicator_resource = + Resource::Create(Resource::Kind::kCollectiveCommunicator)) { TF_RETURN_IF_ERROR(AddBufferAllocations(2)); return AllToAllThunk::Create( @@ -307,7 +310,8 @@ class ThunkSequenceSerdesTest : public ::testing::Test { } absl::StatusOr> CreateReduceScatterThunk( - std::shared_ptr communicator_resource = nullptr) { + std::shared_ptr communicator_resource = + Resource::Create(Resource::Kind::kCollectiveCommunicator)) { TF_RETURN_IF_ERROR(AddBufferAllocations(2)); return ReduceScatterThunk::Create( @@ -341,12 +345,12 @@ class ThunkSequenceSerdesTest : public ::testing::Test { TF_ASSIGN_OR_RETURN(called_sequence.emplace_back(), CreateAllReduceThunk()); TF_ASSIGN_OR_RETURN(called_sequence.emplace_back(), CreateAllToAllThunk()); return CallThunk::Create(Thunk::Info(), - /*called_sequence=*/ - std::move(called_sequence)); + /*called_sequence=*/std::move(called_sequence)); } absl::StatusOr> CreateCollectivePermuteThunk( - std::shared_ptr communicator_resource = nullptr) { + std::shared_ptr communicator_resource = + Resource::Create(Resource::Kind::kCollectiveCommunicator)) { TF_RETURN_IF_ERROR(AddBufferAllocations(2)); return CollectivePermuteThunk::Create( diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 38e43d30be22ec..6b44b4690ec836 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -472,8 +472,10 @@ CommandBufferCmdExecutor::Dependencies(const RecordParams& record_params, // Collect commands that are dependencies of the command `id`. 
absl::InlinedVector dependencies_ids; if (execution_graph_) { - dependencies_ids.assign(execution_graph_->in_edges(id).begin(), - execution_graph_->in_edges(id).end()); + for (const ExecutionGraph::NodeEdge& in_edge : + execution_graph_->in_edges(id)) { + dependencies_ids.push_back(in_edge.id); + } } else { dependencies_ids.push_back(id - 1); } diff --git a/third_party/xla/xla/runtime/execution_graph.cc b/third_party/xla/xla/runtime/execution_graph.cc index dd193ccf9a86ea..fd3e8aa0549ce6 100644 --- a/third_party/xla/xla/runtime/execution_graph.cc +++ b/third_party/xla/xla/runtime/execution_graph.cc @@ -34,6 +34,27 @@ limitations under the License. namespace xla { +// Give aliases to the edge kinds to make code more readable. +static constexpr auto kExecution = ExecutionGraph::NodeEdge::Kind::kExecution; +static constexpr auto kScheduling = ExecutionGraph::NodeEdge::Kind::kScheduling; + +// A helper function to create a predicate that checks if a given node edge +// points to a given node id. +static auto EdgePredicate(ExecutionGraph::NodeId id) { + return [id](const ExecutionGraph::NodeEdge& edge) { return edge.id == id; }; +} + +// If any of the resource uses requires execution edge, we return kExecution +// edge kind, otherwise we return kScheduling edge kind. +static auto EdgeKind(absl::Span resource_uses) { + auto requires_execution_edge = [](const ResourceUse& resource_use) { + auto kind = resource_use.resource()->kind(); + return ExecutionGraph::NodeEdge::KindOf(kind) == kExecution; + }; + return absl::c_any_of(resource_uses, requires_execution_edge) ? kExecution + : kScheduling; +} + ExecutionGraph::ExecutionGraph(NodesEdges nodes_in_edges, NodesEdges nodes_out_edges, std::vector nodes_defs) @@ -57,7 +78,8 @@ ExecutionGraph::ExecutionGraph(NodesEdges nodes_in_edges, // Check if constructed execution DAG is sequential: every node depends on the // completion of the previous node. for (NodeId i = 1; i < nodes_defs_.size() && is_sequential_; ++i) { - is_sequential_ &= (absl::c_count(nodes_defs_[i].in_edges, i - 1) != 0); + is_sequential_ &= + (absl::c_count_if(nodes_defs_[i].in_edges, EdgePredicate(i - 1)) != 0); } VLOG(2) << absl::StreamFormat( @@ -96,20 +118,30 @@ absl::StatusOr ExecutionGraph::Create( resource_rwsets[i].AddAll(op->ResourceUses()); for (NodeId j = 0; j < i; ++j) { - // Check if node `i` must be executed after node `j`. - if (buffer_rwsets[j].HasConflicts(buffer_rwsets[i]) || - resource_rwsets[j].HasConflicts(resource_rwsets[i])) { - builders[j].out_edges.push_back(i); - builders[i].in_edges.push_back(j); + if (buffer_rwsets[j].HasConflicts(buffer_rwsets[i])) { + // If we have buffer conflicts we must add an execution edge to + // guarantee that we don't have data races at run time. + builders[j].out_edges.push_back(NodeEdge{kExecution, i}); + builders[i].in_edges.push_back(NodeEdge{kExecution, j}); + + } else if (resource_rwsets[j].HasConflicts(resource_rwsets[i])) { + // If we have resource conflicts, we must check resources that are + // accessed by both nodes to find out what kind of edge we need to add. + auto kind = EdgeKind(resource_rwsets[j].Conflicts(resource_rwsets[i])); + builders[j].out_edges.push_back(NodeEdge{kind, i}); + builders[i].in_edges.push_back(NodeEdge{kind, j}); } } } - // Verify that both in-edges and out-edges are sorted in ascending order as we - // use this property later. + // Verify that both in-edges and out-edges are sorted in ascending order + // according to node id as we use this property later. 
for (NodeId i = 0; i < builders.size(); ++i) { - DCHECK(absl::c_is_sorted(builders[i].out_edges)); - DCHECK(absl::c_is_sorted(builders[i].in_edges)); + auto by_id = [](const NodeEdge& a, const NodeEdge& b) { + return a.id < b.id; + }; + DCHECK(absl::c_is_sorted(builders[i].out_edges, by_id)); + DCHECK(absl::c_is_sorted(builders[i].in_edges, by_id)); } // Erase redundant edges between nodes. @@ -155,9 +187,9 @@ ExecutionGraph::CreateNodeDefs(std::vector builders) { nodes_defs.push_back(NodeDef{ b.id, num_in_edges ? absl::MakeConstSpan(&*inserted_in_edges, num_in_edges) - : absl::Span(), + : absl::Span(), num_out_edges ? absl::MakeConstSpan(&*inserted_out_edges, num_out_edges) - : absl::Span(), + : absl::Span(), b.priority, }); } @@ -172,37 +204,51 @@ int64_t ExecutionGraph::EraseEdge(NodeDefBuilder& from, NodeDefBuilder& to) { // Short-circuit if out or in-edges are empty. if (from.out_edges.empty() || to.in_edges.empty()) { - DCHECK_EQ(absl::c_count(from.out_edges, to.id), 0) << "Unexpected out edge"; - DCHECK_EQ(absl::c_count(to.in_edges, from.id), 0) << "Unexpected in edge"; + DCHECK_EQ(absl::c_count_if(from.out_edges, EdgePredicate(to.id)), 0) + << "Unexpected out edge from " << from.id << " to " << to.id; + DCHECK_EQ(absl::c_count_if(to.in_edges, EdgePredicate(from.id)), 0) + << "Unexpected in edge from " << from.id << " to " << to.id; return 0; } // Short-circuit if out-edges or in-edges don't intersect with `to` or `from` // node ids (remember that edges are sorted). - if (from.out_edges.back() < to.id || to.in_edges.front() > from.id) { - DCHECK_EQ(absl::c_count(from.out_edges, to.id), 0) << "Unexpected out edge"; - DCHECK_EQ(absl::c_count(to.in_edges, from.id), 0) << "Unexpected in edge"; + if (from.out_edges.back().id < to.id || to.in_edges.front().id > from.id) { + DCHECK_EQ(absl::c_count_if(from.out_edges, EdgePredicate(to.id)), 0) + << "Unexpected out edge from " << from.id << " to " << to.id; + DCHECK_EQ(absl::c_count_if(to.in_edges, EdgePredicate(from.id)), 0) + << "Unexpected in edge from " << from.id << " to " << to.id; return 0; } + // Comparator to find a node edge with a given node id. + auto less_than = [](const NodeEdge& edge, NodeId id) { return edge.id < id; }; + // Check if `from` node has an out edge to `to` node. - auto out_edges_it = absl::c_lower_bound(from.out_edges, to.id); + auto out_edges_it = absl::c_lower_bound(from.out_edges, to.id, less_than); bool has_out_edge = - out_edges_it != from.out_edges.end() && *out_edges_it == to.id; + out_edges_it != from.out_edges.end() && out_edges_it->id == to.id; // Short-circuit if there is no out edge from `from` node to `to` node. if (!has_out_edge) { - DCHECK_EQ(absl::c_count(to.in_edges, from.id), 0) << "Unexpected in edge"; + DCHECK_EQ(absl::c_count_if(to.in_edges, EdgePredicate(from.id)), 0) + << "Unexpected in edge from " << from.id << " to " << to.id; return 0; } // Check if `to` node has an in edge from `from` node. - auto in_edges_it = absl::c_lower_bound(to.in_edges, from.id); + auto in_edges_it = absl::c_lower_bound(to.in_edges, from.id, less_than); bool has_in_edge = - in_edges_it != to.in_edges.end() && *in_edges_it == from.id; + in_edges_it != to.in_edges.end() && in_edges_it->id == from.id; DCHECK(has_in_edge) << "In-edge must exist if out-edge exists"; + // At this point we must have exactly one edge between `from` and `to` nodes. 
@@ -172,37 +204,51 @@ int64_t ExecutionGraph::EraseEdge(NodeDefBuilder& from, NodeDefBuilder& to) {
   // Short-circuit if out or in-edges are empty.
   if (from.out_edges.empty() || to.in_edges.empty()) {
-    DCHECK_EQ(absl::c_count(from.out_edges, to.id), 0) << "Unexpected out edge";
-    DCHECK_EQ(absl::c_count(to.in_edges, from.id), 0) << "Unexpected in edge";
+    DCHECK_EQ(absl::c_count_if(from.out_edges, EdgePredicate(to.id)), 0)
+        << "Unexpected out edge from " << from.id << " to " << to.id;
+    DCHECK_EQ(absl::c_count_if(to.in_edges, EdgePredicate(from.id)), 0)
+        << "Unexpected in edge from " << from.id << " to " << to.id;
     return 0;
   }
 
   // Short-circuit if out-edges or in-edges don't intersect with `to` or `from`
   // node ids (remember that edges are sorted).
-  if (from.out_edges.back() < to.id || to.in_edges.front() > from.id) {
-    DCHECK_EQ(absl::c_count(from.out_edges, to.id), 0) << "Unexpected out edge";
-    DCHECK_EQ(absl::c_count(to.in_edges, from.id), 0) << "Unexpected in edge";
+  if (from.out_edges.back().id < to.id || to.in_edges.front().id > from.id) {
+    DCHECK_EQ(absl::c_count_if(from.out_edges, EdgePredicate(to.id)), 0)
+        << "Unexpected out edge from " << from.id << " to " << to.id;
+    DCHECK_EQ(absl::c_count_if(to.in_edges, EdgePredicate(from.id)), 0)
+        << "Unexpected in edge from " << from.id << " to " << to.id;
     return 0;
   }
 
+  // Comparator to find a node edge with a given node id.
+  auto less_than = [](const NodeEdge& edge, NodeId id) { return edge.id < id; };
+
   // Check if `from` node has an out edge to `to` node.
-  auto out_edges_it = absl::c_lower_bound(from.out_edges, to.id);
+  auto out_edges_it = absl::c_lower_bound(from.out_edges, to.id, less_than);
   bool has_out_edge =
-      out_edges_it != from.out_edges.end() && *out_edges_it == to.id;
+      out_edges_it != from.out_edges.end() && out_edges_it->id == to.id;
 
   // Short-circuit if there is no out edge from `from` node to `to` node.
   if (!has_out_edge) {
-    DCHECK_EQ(absl::c_count(to.in_edges, from.id), 0) << "Unexpected in edge";
+    DCHECK_EQ(absl::c_count_if(to.in_edges, EdgePredicate(from.id)), 0)
+        << "Unexpected in edge from " << from.id << " to " << to.id;
     return 0;
   }
 
   // Check if `to` node has an in edge from `from` node.
-  auto in_edges_it = absl::c_lower_bound(to.in_edges, from.id);
+  auto in_edges_it = absl::c_lower_bound(to.in_edges, from.id, less_than);
   bool has_in_edge =
-      in_edges_it != to.in_edges.end() && *in_edges_it == from.id;
+      in_edges_it != to.in_edges.end() && in_edges_it->id == from.id;
   DCHECK(has_in_edge) << "In-edge must exist if out-edge exists";
 
+  // At this point we must have exactly one edge between `from` and `to` nodes.
+  DCHECK_EQ(absl::c_count_if(from.out_edges, EdgePredicate(to.id)), 1)
+      << "Expected exactly one out edge from " << from.id << " to " << to.id;
+  DCHECK_EQ(absl::c_count_if(to.in_edges, EdgePredicate(from.id)), 1)
+      << "Expected exactly one in edge from " << from.id << " to " << to.id;
+
   from.out_edges.erase(out_edges_it);
   to.in_edges.erase(in_edges_it);
 
@@ -237,10 +283,12 @@ int64_t ExecutionGraph::RunTransitiveReductionAndUpdatePriorities(
 
     // Initialize stack with nodes reachable via immediate out nodes. We mark
    // immediate out nodes as visited to correctly compute node priority below.
-    for (int64_t out_id : source_node.out_edges) {
-      NodeDefBuilder& out_node = builders[out_id];
-      visited[out_id] = true;
-      for (int64_t start_id : out_node.out_edges) add_to_stack(start_id);
+    for (NodeEdge out_edge : source_node.out_edges) {
+      NodeDefBuilder& out_node = builders[out_edge.id];
+      visited[out_edge.id] = true;
+      for (NodeEdge start_edge : out_node.out_edges) {
+        add_to_stack(start_edge.id);
+      }
     }
 
     // Traverse the graph and delete redundant edges.
@@ -251,7 +299,9 @@ int64_t ExecutionGraph::RunTransitiveReductionAndUpdatePriorities(
       NodeDefBuilder& node = builders[node_id];
       num_erased_edges += EraseEdge(source_node, node);
 
-      for (int64_t out_id : node.out_edges) add_to_stack(out_id);
+      for (NodeEdge out_edge : node.out_edges) {
+        add_to_stack(out_edge.id);
+      }
     }
 
     // Set node priority to the number of visited nodes in the DFS traversal.
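For reference, the traversal above amounts to a stack-based DFS transitive reduction. A simplified, self-contained restatement on plain adjacency lists (assuming a DAG, integer node ids, and no edge kinds or priority bookkeeping; this is an illustration, not the patch's implementation):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // For every source node, erase source->n edges for any n that is also
    // reachable through at least one intermediate node; returns the number of
    // erased edges.
    int64_t TransitiveReduction(std::vector<std::vector<int>>& out) {
      int64_t erased = 0;
      const int n = static_cast<int>(out.size());
      for (int source = 0; source < n; ++source) {
        std::vector<char> visited(n, 0);
        std::vector<int> stack;
        // Seed the stack with nodes two hops away; one-hop edges must survive
        // unless they are reachable through some other path.
        for (int mid : out[source])
          for (int next : out[mid]) stack.push_back(next);
        while (!stack.empty()) {
          int node = stack.back();
          stack.pop_back();
          if (visited[node]) continue;
          visited[node] = 1;
          auto& edges = out[source];
          auto it = std::find(edges.begin(), edges.end(), node);
          if (it != edges.end()) {
            edges.erase(it);  // source->node is implied transitively
            ++erased;
          }
          for (int next : out[node]) stack.push_back(next);
        }
      }
      return erased;
    }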
diff --git a/third_party/xla/xla/runtime/execution_graph.h b/third_party/xla/xla/runtime/execution_graph.h
index 71d4999567a272..0475a0aee7f4eb 100644
--- a/third_party/xla/xla/runtime/execution_graph.h
+++ b/third_party/xla/xla/runtime/execution_graph.h
@@ -45,9 +45,14 @@ namespace xla {
 //
 // At run time we can relax sequential schedule and execute operations
 // concurrently, as long as we don't create data races (reading and writing
-// from/To the same or overlapping buffer slices concurrently), or resource
+// from/to the same or overlapping buffer slices concurrently), or resource
 // races (using the same mutable resource concurrently).
 //
+// Resources can behave as buffers and require an execution order (operation
+// must wait for the completion of execution of all dependencies), or as a
+// scheduling barrier (operation must wait for the completion of scheduling of
+// all dependencies). See more details in the `NodeEdge::Kind` definition.
+//
 // We use buffer and resource use conflicts to define an execution order of
 // operations as a directed acyclic graph (DAG) that satisfies all dependencies.
 //
@@ -61,12 +66,49 @@ class ExecutionGraph {
 
   static constexpr NodeId kInvalidNodeId = std::numeric_limits<NodeId>::min();
 
+  struct NodeEdge {
+    enum class Kind {
+      // If two operations have a scheduling edge between them, then the
+      // dependent operation must be scheduled (start execution) after the
+      // dependency operation is scheduled (has started execution); however,
+      // it doesn't have to wait for the completion of execution. We use this
+      // type of edge to guarantee that operations that share the same
+      // resource (e.g. a collective communicator) start execution in a
+      // deterministic order across different ranks; their execution can still
+      // overlap and finish in any order, which is backend-implementation
+      // specific.
+      kScheduling,
+
+      // If two operations have an execution edge between them, then the
+      // dependent operation must wait for the completion of the dependency
+      // operation's execution. We use this type of edge to order execution of
+      // operations that read and write from/to the same buffers, as otherwise
+      // we may create data races.
+      kExecution,
+    };
+
+    static constexpr NodeEdge::Kind KindOf(Resource::Kind resource) {
+      switch (resource) {
+        case Resource::kToken:
+          return NodeEdge::Kind::kExecution;
+        case Resource::kCollectiveCommunicator:
+          return NodeEdge::Kind::kScheduling;
+      }
+    }
+
+    bool operator==(const NodeEdge& other) const {
+      return kind == other.kind && id == other.id;
+    }
+
+    Kind kind;
+    NodeId id;
+  };
+
   // NodeDef defines a dependency-based execution order for all operations.
   struct NodeDef {
     NodeId id = kInvalidNodeId;
-    absl::Span<const NodeId> in_edges;
-    absl::Span<const NodeId> out_edges;
+    absl::Span<const NodeEdge> in_edges;
+    absl::Span<const NodeEdge> out_edges;
 
     // When doing the transitive reduction, we assign a priority to each node
     // based on the number of nodes that are reachable from the given node. The
@@ -124,13 +166,13 @@ class ExecutionGraph {
   }
 
   // Returns in-edges for a given node id.
-  absl::Span<const NodeId> in_edges(NodeId id) const {
+  absl::Span<const NodeEdge> in_edges(NodeId id) const {
     DCHECK_EQ(id, nodes_defs_[id].id);
     return nodes_defs_[id].in_edges;
   }
 
   // Returns out-edges for a given node id.
-  absl::Span<const NodeId> out_edges(NodeId id) const {
+  absl::Span<const NodeEdge> out_edges(NodeId id) const {
     DCHECK_EQ(id, nodes_defs_[id].id);
     return nodes_defs_[id].out_edges;
   }
@@ -150,7 +192,7 @@ class ExecutionGraph {
 
   // We store all `in_edges` and `out_edges` referenced by the `NodeDef` inside
   // large vectors to optimize for data locality on a hot path.
-  using NodesEdges = std::vector<NodeId>;
+  using NodesEdges = std::vector<NodeEdge>;
 
   // A NodeDef builder to collect all in-edges and out-edges before constructing
   // a NodeDef. We use it at dependency graph construction time when we don't
@@ -158,8 +200,8 @@ class ExecutionGraph {
   struct NodeDefBuilder {
     NodeId id = kInvalidNodeId;
     int64_t priority = 0;
-    std::vector<NodeId> in_edges;
-    std::vector<NodeId> out_edges;
+    std::vector<NodeEdge> in_edges;
+    std::vector<NodeEdge> out_edges;
   };
 
   ExecutionGraph(NodesEdges nodes_in_edges, NodesEdges nodes_out_edges,
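One way to read the two edge kinds operationally: an execution edge waits on the dependency's completion, while a scheduling edge waits only on its launch. The sketch below is a purely hypothetical illustration of that contract, not code from this patch or from XLA:

    #include <future>

    struct NodeEvents {
      std::shared_future<void> scheduled;  // signaled when the node is launched
      std::shared_future<void> executed;   // signaled when the node completes
    };

    enum class Kind { kScheduling, kExecution };

    void WaitForDependency(const NodeEvents& dep, Kind kind) {
      if (kind == Kind::kExecution) {
        dep.executed.wait();   // full completion ordering (buffer safety)
      } else {
        dep.scheduled.wait();  // launch ordering only (deterministic order)
      }
    }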
diff --git a/third_party/xla/xla/runtime/execution_graph_test.cc b/third_party/xla/xla/runtime/execution_graph_test.cc
index a05ad08c21de18..606a51342328bd 100644
--- a/third_party/xla/xla/runtime/execution_graph_test.cc
+++ b/third_party/xla/xla/runtime/execution_graph_test.cc
@@ -31,6 +31,12 @@ namespace {
 
 using ::testing::ElementsAre;
 
+using NodeEdge = ExecutionGraph::NodeEdge;
+
+// Give aliases to the edge kinds to make tests more readable.
+static constexpr auto kExecution = NodeEdge::Kind::kExecution;
+static constexpr auto kScheduling = NodeEdge::Kind::kScheduling;
+
 // A test-only operation for verifying execution graph implementation.
 class Operation : public ExecutionGraph::Operation {
  public:
@@ -71,9 +77,12 @@ TEST(ExecutionGraphTest, DependencyOrdering) {
   EXPECT_THAT(execution_graph.source(), ElementsAre(0, 1));
   EXPECT_THAT(execution_graph.sink(), ElementsAre(2));
 
-  EXPECT_THAT(execution_graph.out_edges(0), ElementsAre(2));
-  EXPECT_THAT(execution_graph.out_edges(1), ElementsAre(2));
-  EXPECT_THAT(execution_graph.in_edges(2), ElementsAre(0, 1));
+  EXPECT_THAT(execution_graph.out_edges(0),
+              ElementsAre(NodeEdge{kExecution, 2}));
+  EXPECT_THAT(execution_graph.out_edges(1),
+              ElementsAre(NodeEdge{kExecution, 2}));
+  EXPECT_THAT(execution_graph.in_edges(2),
+              ElementsAre(NodeEdge{kExecution, 0}, NodeEdge{kExecution, 1}));
 
   EXPECT_EQ(execution_graph.priority(0), 1);
   EXPECT_EQ(execution_graph.priority(1), 1);
@@ -99,17 +108,21 @@ TEST(ExecutionGraphTest, SequentialOrdering) {
   EXPECT_THAT(execution_graph.source(), ElementsAre(0));
   EXPECT_THAT(execution_graph.sink(), ElementsAre(2));
 
-  EXPECT_THAT(execution_graph.out_edges(0), ElementsAre(1));
-  EXPECT_THAT(execution_graph.out_edges(1), ElementsAre(2));
-  EXPECT_THAT(execution_graph.in_edges(1), ElementsAre(0));
-  EXPECT_THAT(execution_graph.in_edges(2), ElementsAre(1));
+  EXPECT_THAT(execution_graph.out_edges(0),
+              ElementsAre(NodeEdge{kExecution, 1}));
+  EXPECT_THAT(execution_graph.out_edges(1),
+              ElementsAre(NodeEdge{kExecution, 2}));
+  EXPECT_THAT(execution_graph.in_edges(1),
+              ElementsAre(NodeEdge{kExecution, 0}));
+  EXPECT_THAT(execution_graph.in_edges(2),
+              ElementsAre(NodeEdge{kExecution, 1}));
 
   EXPECT_EQ(execution_graph.priority(0), 2);
   EXPECT_EQ(execution_graph.priority(1), 1);
   EXPECT_EQ(execution_graph.priority(2), 0);
 }
 
-TEST(ExecutionGraphTest, ResourceOrdering) {
+TEST(ExecutionGraphTest, TokenResourceOrdering) {
   BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
 
   BufferAllocation::Slice slice0(&alloc, /*offset=*/0, /*size=*/40);
@@ -132,13 +145,59 @@ TEST(ExecutionGraphTest, ResourceOrdering) {
   EXPECT_THAT(execution_graph.source(), ElementsAre(0));
   EXPECT_THAT(execution_graph.sink(), ElementsAre(1));
 
-  EXPECT_THAT(execution_graph.out_edges(0), ElementsAre(1));
-  EXPECT_THAT(execution_graph.in_edges(1), ElementsAre(0));
+  EXPECT_THAT(execution_graph.out_edges(0),
+              ElementsAre(NodeEdge{kExecution, 1}));
+  EXPECT_THAT(execution_graph.in_edges(1),
+              ElementsAre(NodeEdge{kExecution, 0}));
 
   EXPECT_EQ(execution_graph.priority(0), 1);
   EXPECT_EQ(execution_graph.priority(1), 0);
 }
 
+TEST(ExecutionGraphTest, CollectivesResourceOrdering) {
+  BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
+
+  BufferAllocation::Slice slice0(&alloc, /*offset=*/0, /*size=*/40);
+  BufferAllocation::Slice slice1(&alloc, /*offset=*/40, /*size=*/40);
+
+  auto resource = Resource::Create(Resource::Kind::kCollectiveCommunicator);
+
+  std::vector<Operation> operations;
+  operations.push_back(
+      Operation({BufferUse::Read(slice0), BufferUse::Write(slice0)},
+                {ResourceUse::Write(resource)}));
+  operations.push_back(
+      Operation({BufferUse::Read(slice1), BufferUse::Write(slice1)},
+                {ResourceUse::Write(resource)}));
+  operations.push_back(
+      Operation({BufferUse::Read(slice1), BufferUse::Write(slice1)},
+                {ResourceUse::Write(resource)}));
+
+  TF_ASSERT_OK_AND_ASSIGN(ExecutionGraph execution_graph,
+                          ExecutionGraph::Create(operations));
+
+  EXPECT_TRUE(execution_graph.is_sequential());
+  EXPECT_THAT(execution_graph.source(), ElementsAre(0));
+  EXPECT_THAT(execution_graph.sink(), ElementsAre(2));
+
+  EXPECT_THAT(execution_graph.out_edges(0),
+              ElementsAre(NodeEdge{kScheduling, 1}));
+  EXPECT_THAT(execution_graph.in_edges(1),
+              ElementsAre(NodeEdge{kScheduling, 0}));
+
+  // We have buffer conflicts and a resource conflict, so in this case we must
+  // add an execution edge, as it provides a stronger ordering guarantee.
+  EXPECT_THAT(execution_graph.out_edges(1),
+              ElementsAre(NodeEdge{kExecution, 2}));
+  EXPECT_THAT(execution_graph.in_edges(2),
+              ElementsAre(NodeEdge{kExecution, 1}));
+
+  EXPECT_EQ(execution_graph.priority(0), 2);
+  EXPECT_EQ(execution_graph.priority(1), 1);
+  EXPECT_EQ(execution_graph.priority(2), 0);
+}
+
 TEST(ExecutionGraphTest, TransitiveReduction) {
   BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
   BufferAllocation::Slice slice(&alloc, /*offset=*/0, /*size=*/40);
@@ -157,10 +216,14 @@ TEST(ExecutionGraphTest, TransitiveReduction) {
   EXPECT_THAT(execution_graph.source(), ElementsAre(0));
   EXPECT_THAT(execution_graph.sink(), ElementsAre(2));
 
-  EXPECT_THAT(execution_graph.out_edges(0), ElementsAre(1));
-  EXPECT_THAT(execution_graph.in_edges(1), ElementsAre(0));
-  EXPECT_THAT(execution_graph.out_edges(1), ElementsAre(2));
-  EXPECT_THAT(execution_graph.in_edges(2), ElementsAre(1));
+  EXPECT_THAT(execution_graph.out_edges(0),
+              ElementsAre(NodeEdge{kExecution, 1}));
+  EXPECT_THAT(execution_graph.in_edges(1),
+              ElementsAre(NodeEdge{kExecution, 0}));
+  EXPECT_THAT(execution_graph.out_edges(1),
+              ElementsAre(NodeEdge{kExecution, 2}));
+  EXPECT_THAT(execution_graph.in_edges(2),
+              ElementsAre(NodeEdge{kExecution, 1}));
 
   EXPECT_EQ(execution_graph.priority(0), 2);
   EXPECT_EQ(execution_graph.priority(1), 1);
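Note that CollectivesResourceOrdering asserts is_sequential() even though nodes 0 and 1 are linked only by a scheduling edge: the sequential check counts in-edges of any kind. The check itself is simple; a self-contained sketch with simplified types (hypothetical, not the patch's code):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // A graph is "sequential" iff every node has an in-edge from its
    // immediate predecessor, regardless of the edge kind.
    bool IsSequential(const std::vector<std::vector<int64_t>>& in_edges) {
      for (size_t i = 1; i < in_edges.size(); ++i) {
        const auto& edges = in_edges[i];
        if (std::find(edges.begin(), edges.end(),
                      static_cast<int64_t>(i) - 1) == edges.end()) {
          return false;
        }
      }
      return true;
    }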
diff --git a/third_party/xla/xla/runtime/resource_use.cc b/third_party/xla/xla/runtime/resource_use.cc
index 185842bd66c8b2..164323bfaa48b2 100644
--- a/third_party/xla/xla/runtime/resource_use.cc
+++ b/third_party/xla/xla/runtime/resource_use.cc
@@ -16,6 +16,8 @@ limitations under the License.
 #include "xla/runtime/resource_use.h"
 
 #include
+#include
+#include
 
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_set.h"
@@ -49,7 +51,9 @@ void ResourceUse::ReadWriteSet::Add(ResourceUse use) {
 }
 
 void ResourceUse::ReadWriteSet::AddAll(absl::Span<const ResourceUse> uses) {
-  for (const auto& use : uses) Add(use);
+  for (const auto& use : uses) {
+    Add(use);
+  }
 }
 
 bool ResourceUse::ReadWriteSet::HasConflicts(const ResourceUse& use) const {
@@ -75,4 +79,23 @@ bool ResourceUse::ReadWriteSet::HasConflicts(const ReadWriteSet& other) {
   });
 }
 
+std::vector<ResourceUse> ResourceUse::ReadWriteSet::Conflicts(
+    const ReadWriteSet& other) {
+  std::vector<ResourceUse> conflicts;
+
+  for (const std::shared_ptr<Resource>& resource : other.read_) {
+    if (auto read = ResourceUse::Read(resource); HasConflicts(read)) {
+      conflicts.push_back(std::move(read));
+    }
+  }
+
+  for (const std::shared_ptr<Resource>& resource : other.write_) {
+    if (auto write = ResourceUse::Write(resource); HasConflicts(write)) {
+      conflicts.push_back(std::move(write));
+    }
+  }
+
+  return conflicts;
+}
+
 } // namespace xla
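The conflict rule that Conflicts() inherits from HasConflicts() is: two uses of the same resource conflict unless both are reads. An abstract, self-contained restatement with minimal hypothetical types (the real implementation tracks reads and writes in hash sets rather than this quadratic scan):

    #include <memory>
    #include <vector>

    struct Res {};  // stand-in for xla::Resource

    struct Use {
      std::shared_ptr<Res> res;
      bool is_write;
    };

    // Same resource and at least one writer means a conflict.
    bool Conflict(const Use& a, const Use& b) {
      return a.res == b.res && (a.is_write || b.is_write);
    }

    // Mirrors ReadWriteSet::Conflicts: collect every use in `others` that
    // conflicts with at least one tracked use.
    std::vector<Use> Conflicts(const std::vector<Use>& tracked,
                               const std::vector<Use>& others) {
      std::vector<Use> out;
      for (const Use& o : others) {
        for (const Use& t : tracked) {
          if (Conflict(t, o)) {
            out.push_back(o);
            break;
          }
        }
      }
      return out;
    }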
diff --git a/third_party/xla/xla/runtime/resource_use.h b/third_party/xla/xla/runtime/resource_use.h
index bb105062ee913e..a43a9e8474b6a9 100644
--- a/third_party/xla/xla/runtime/resource_use.h
+++ b/third_party/xla/xla/runtime/resource_use.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include
 #include
+#include
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/types/span.h"
@@ -25,7 +26,7 @@ limitations under the License.
 namespace xla {
 
 // `Resource` models a run time resource that imposes ordering on the thunk
-// execution in addition to thunk buffer uses.
+// execution (scheduling) in addition to thunk buffer uses.
 class Resource {
  public:
   enum class Kind {
@@ -34,7 +35,7 @@ class Resource {
     // enforce ordering at run time.
     kToken,
 
-    // Collective operations must be executed in the same order as they are
+    // Collective operations must be scheduled in the same order as they are
     // defined in the HLO module. We rely on collective communicator resource
     // to enforce ordering at run time.
     kCollectiveCommunicator
@@ -87,6 +88,10 @@ class ResourceUse {
     bool HasConflicts(absl::Span<const ResourceUse> uses) const;
     bool HasConflicts(const ReadWriteSet& other);
 
+    // Collects all resource uses that have a conflict with tracked resource
+    // reads or writes.
+    std::vector<ResourceUse> Conflicts(const ReadWriteSet& other);
+
    private:
     absl::flat_hash_set<std::shared_ptr<Resource>> read_;
     absl::flat_hash_set<std::shared_ptr<Resource>> write_;
diff --git a/third_party/xla/xla/runtime/resource_use_test.cc b/third_party/xla/xla/runtime/resource_use_test.cc
index f42877484f65d3..b6a97adbd05242 100644
--- a/third_party/xla/xla/runtime/resource_use_test.cc
+++ b/third_party/xla/xla/runtime/resource_use_test.cc
@@ -15,6 +15,9 @@ limitations under the License.
 
 #include "xla/runtime/resource_use.h"
 
+#include
+
+#include
 #include "xla/tsl/platform/test.h"
 
 namespace xla {
@@ -47,6 +50,13 @@ TEST(ResourceUseTest, ReadWriteSet) {
   EXPECT_TRUE(rwset.HasConflicts({ResourceUse::Write(token0)}));
   EXPECT_FALSE(rwset.HasConflicts({ResourceUse::Read(token1)}));
   EXPECT_FALSE(rwset.HasConflicts({ResourceUse::Write(token1)}));
+
+  ResourceUse::ReadWriteSet rwset2;
+  rwset2.Add(ResourceUse::Write(token0));
+
+  std::vector<ResourceUse> conflicts = rwset.Conflicts(rwset2);
+  ASSERT_EQ(conflicts.size(), 1);
+  EXPECT_EQ(conflicts.front(), ResourceUse::Write(token0));
 }
 
 } // namespace
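Tying back to the executor change at the top of this patch: a typical consumer flattens each node's in-edges into a dependency counter and runs the node when the counter reaches zero. A hypothetical, self-contained sketch of that pattern (simplified types, not XLA code):

    #include <cstdint>
    #include <vector>

    struct Edge { int64_t id; };

    // Seed per-node dependency counters from the graph's in-edges; a node
    // becomes runnable when its counter drops to zero as dependencies retire.
    std::vector<int64_t> CountDependencies(
        const std::vector<std::vector<Edge>>& in_edges) {
      std::vector<int64_t> counters(in_edges.size(), 0);
      for (size_t node = 0; node < in_edges.size(); ++node) {
        counters[node] = static_cast<int64_t>(in_edges[node].size());
      }
      return counters;
    }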
From 2c9c485695d2f36f56264442540a39a1570648b0 Mon Sep 17 00:00:00 2001
From: Kevin Gleason
Date: Wed, 16 Apr 2025 18:00:23 -0700
Subject: [PATCH 0896/1324] Finish HLO->StableHLO direct conversions.

Begin cleanups after Direct HLO->StableHLO, remove tuple / token handling.

PiperOrigin-RevId: 748484626
---
 .../compiler/mlir/tensorflow/transforms/BUILD |   1 +
 .../tensorflow/transforms/shape_inference.cc  |   3 +-
 tensorflow/compiler/mlir/tf2xla/api/v1/BUILD  |   1 +
 .../mlir/tf2xla/api/v1/compile_mlir_util.cc   |   9 +-
 .../mlir/tf2xla/tests/legalize-tf.mlir        |  32 ++--
 .../compiler/mlir/tf2xla/transforms/BUILD     |   5 +-
 .../tf2xla/transforms/legalize_tf_patterns.td |  10 +-
 .../transforms/legalize_tf_with_tf2xla.cc     |  11 +-
 .../mlir/tf2xla/transforms/tf2xla_rewriter.cc |  11 +-
 .../transforms/verify_tfxla_legalization.cc   |   7 +-
 .../mlir/tf2xla/transforms/xla_legalize_tf.cc |  10 +-
 .../xla/xla/hlo/translate/hlo_to_mhlo/BUILD   |   5 +-
 .../hlo_to_mhlo/attribute_importer.cc         | 175 +++++-------------
 .../hlo_to_mhlo/attribute_importer.h          |  26 ++-
 .../hlo_to_mhlo/custom_call_importer.cc       |  15 +-
 .../hlo_to_mhlo/hlo_function_importer.cc      | 116 ++++--------
 .../hlo_to_mhlo/hlo_function_importer.h       |   6 +
 .../hlo/translate/hlo_to_mhlo/hlo_utils.cc    |  27 +--
 .../xla/hlo/translate/hlo_to_mhlo/hlo_utils.h |   3 +-
 .../translate/hlo_to_mhlo/hlo_utils_test.cc   |   7 +-
 .../tests/import_emit_stablehlo.hlo           |  38 ++--
 .../hlo_legalize_to_stablehlo_pass.cc         |  50 ++++-
 .../xla/mlir_hlo/mhlo/transforms/rewriters.h  |   4 +
 .../stablehlo_legalize_to_hlo.cc              |  75 ++++++++
 .../stablehlo_legalize_to_hlo_pass.cc         |  32 +---
 .../mhlo/hlo-legalize-to-stablehlo.mlir       |   8 +
 .../mhlo/stablehlo-legalize-to-hlo.mlir       |   8 +
 27 files changed, 326 insertions(+), 369 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
index 1e14755b0119f5..4c7aa232edbb1d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
@@ -796,6 +796,7 @@ cc_library(
         "@local_xla//xla:xla_data_proto_cc",
         "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_utils",
         "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape",
+        "@local_xla//xla/mlir_hlo",
         "@local_xla//xla/service:shape_inference",
         "@local_xla//xla/tsl/platform:errors",
         "@local_xla//xla/tsl/util:env_var",
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 50f6cc54c4e12c..106c65368a18f8 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -92,6 +92,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/tsl/platform/errors.h" @@ -510,7 +511,7 @@ Type GetNewArgType(Type old_arg_type, ArrayRef shape, } new_arg_type = tensorflow::GetTypeFromTFTensorShape( new_shape, element_type, - mhlo::TypeExtensionsAttr::get(context, new_bounds)); + mlir::mhlo::TypeExtensionsAttr::get(context, new_bounds)); } } return new_arg_type; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 3fe8f0cb052062..cf83f71d0a6629 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -77,6 +77,7 @@ cc_library( "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:statusor", + "@stablehlo//:base", "@stablehlo//:register", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc index 6281ea68e37807..829d3ca5819379 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc @@ -54,6 +54,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo #include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" @@ -201,12 +202,12 @@ mlir::RankedTensorType GetBufferType(mlir::Type ty) { int64_t rank = ranked_ty.getRank(); llvm::SmallVector dims = llvm::to_vector<4>(ranked_ty.getShape()); - auto encoding = mlir::dyn_cast_or_null( - ranked_ty.getEncoding()); - if (encoding && !encoding.getBounds().empty()) { + llvm::ArrayRef bounds = + mlir::hlo::encodingToBounds(ranked_ty.getEncoding()); + if (!bounds.empty()) { for (int64_t dim = 0; dim < rank; ++dim) { if (dims[dim] == mlir::ShapedType::kDynamic) { - dims[dim] = encoding.getBounds()[dim]; + dims[dim] = bounds[dim]; } } } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index c54bef4f4a6947..92754a181e8551 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -91,12 +91,12 @@ func.func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8x // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988> // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01> - // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %arg3, %[[ALPHA]] - // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[MEAN]], %[[BETA]] + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3 + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[MEAN]] // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] - // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %arg4, %[[ALPHA]] - // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = 
chlo.broadcast_multiply %[[CORRECTED_VAR]], %[[BETA]] + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4 + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[MEAN]], %[[VAR]] @@ -430,7 +430,8 @@ func.func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: // CHECK-LABEL: func @biasAdd_default func.func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) + // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> @@ -442,7 +443,8 @@ func.func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) // CHECK-LABEL: func @biasAdd_NHWC func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) + // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> @@ -454,7 +456,8 @@ func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> // CHECK-LABEL: func @biasAdd_NCHW func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) + // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> @@ -466,7 +469,8 @@ func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> // CHECK-LABEL: func @biasAdd_dynamic func.func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) + // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} 
: (tensor, tensor) -> tensor @@ -478,7 +482,8 @@ func.func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> // CHECK-LABEL: func @biasAdd_partial_dynamic func.func @biasAdd_partial_dynamic(%arg0: tensor, %arg1: tensor<512xi32>) -> tensor { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_SHAPE]]) + // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] // CHECK: %[[CAST:.+]] = tensor.cast %[[RESULT]] : tensor to tensor @@ -1792,7 +1797,7 @@ func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> // CHECK-LABEL: func @relu func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %arg0, %[[ZERO]] {broadcast_dimensions = array} : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> func.return %0: tensor<1xi32> } @@ -1802,7 +1807,7 @@ func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unsigned func.func @relu_unsigned(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %arg0, %[[ZERO]] {broadcast_dimensions = array} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor func.return %0: tensor } @@ -1877,7 +1882,7 @@ func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> te // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32> - // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ABS]], %[[ONE]] {broadcast_dimensions = array} : (tensor<4x10xf32>, tensor) -> tensor<4x10xf32> + // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = array} : (tensor, tensor<4x10xf32>) -> tensor<4x10xf32> // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32> // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32> @@ -1953,7 +1958,8 @@ func.func @select_batch_dynamic_r1(%arg0: tensor, %arg1: tensor // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[HEAD]] : tensor<1xindex>, tensor<1xindex> // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor) { - // CHECK-NEXT: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor, tensor<3xindex>) -> tensor + // CHECK-NEXT: %[[SHAPE1E:.*]] = shape.to_extent_tensor %[[SHAPE1]] : tensor<3xindex> -> tensor<3xindex> + // CHECK-NEXT: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1E]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor, tensor<3xindex>) -> tensor // CHECK-NEXT: %[[SELECT:.*]] = mhlo.select %[[BCAST]], 
%arg1, %arg2 : tensor, tensor // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 1da9a071a3c0f3..76d690d382f2c7 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -290,6 +290,7 @@ cc_library( "@local_xla//xla/mlir_hlo:type_conversion", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:tpu_api", + "@stablehlo//:base", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", ], @@ -345,12 +346,12 @@ cc_library( "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", - "@local_xla//xla/mlir_hlo", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/tsl/platform:env", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:status", "@local_xla//xla/tsl/platform:statusor", + "@stablehlo//:base", "@stablehlo//:stablehlo_ops", ], ) @@ -425,7 +426,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", - "@local_xla//xla/mlir_hlo", + "@stablehlo//:base", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 46f3ebfe19104d..a0404806ced750 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -692,12 +692,16 @@ def : Pat<(TF_RandomStandardNormalOp:$old $shape, $seed, $seed2), // Sigmoid grad op. //===----------------------------------------------------------------------===// -// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the -// shape of $l instead of having it as a constant. +// Only handle static shape here, dynamic shape is handled by +// ConvertSigmoidGradOpDynamic +def HasStaticShape : Constraint< + CPred<"::llvm::dyn_cast($0.getType()).hasStaticShape()">>; + def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), (MHLO_MulOp (MHLO_MulOp $r, $l), - (MHLO_SubtractOp (MHLO_ConstantOp (ConstantSplat<"1"> $l)), $l))>; + (MHLO_SubtractOp (MHLO_ConstantOp (ConstantSplat<"1"> $l)), $l)), + [(HasStaticShape $l)]>; //===----------------------------------------------------------------------===// // Softplus op. diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc index 2d9bc167d2c0a4..14cec354ddcb9e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project @@ -31,6 +32,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" @@ -43,7 +45,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" @@ -75,13 +76,11 @@ bool IsBounded(Type ty) { if (ranked_ty.hasStaticShape()) return true; - auto encoding = - mlir::dyn_cast_or_null(ranked_ty.getEncoding()); - if (!encoding) return false; + auto bounds = hlo::encodingToBounds(ranked_ty.getEncoding()); + if (bounds.empty()) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { - if (ranked_ty.isDynamicDim(i) && - encoding.getBounds()[i] == ShapedType::kDynamic) { + if (ranked_ty.isDynamicDim(i) && bounds[i] == ShapedType::kDynamic) { return false; } } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index f35af77a1e6082..7f3ec19a70967a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project @@ -50,6 +51,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -70,7 +72,6 @@ limitations under the License. 
#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" @@ -260,13 +261,11 @@ bool IsBounded(Type ty) { if (ranked_ty.hasStaticShape()) return true; - auto encoding = - mlir::dyn_cast_or_null(ranked_ty.getEncoding()); - if (!encoding) return false; + ArrayRef bounds = hlo::encodingToBounds(ranked_ty.getEncoding()); + if (bounds.empty()) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { - if (ranked_ty.isDynamicDim(i) && - encoding.getBounds()[i] == ShapedType::kDynamic) { + if (ranked_ty.isDynamicDim(i) && bounds[i] == ShapedType::kDynamic) { return false; } } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc index d99f80ff5eacd5..e3d5d5f1b5a5d3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc @@ -21,11 +21,13 @@ limitations under the License. #include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" @@ -87,9 +89,8 @@ static void IncrementCounterFor(tensorflow::monitoring::Counter<1>* counter, } bool HasBounds(RankedTensorType type) { - auto encoding = mlir::dyn_cast_or_null( - type.getEncoding()); - return (encoding && !encoding.getBounds().empty()); + auto bounds = hlo::encodingToBounds(type.getEncoding()); + return !bounds.empty(); } bool HasStaticShapeOrBounded(Value val) { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc index 6db7ccec27c710..89df1e1793d1d0 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc @@ -155,17 +155,15 @@ mlir::LogicalResult ApplyPatterns(Operation *op, RewritePatternSet &patterns, } mlir::LogicalResult StablehloToMhlo(Operation *op) { - RewritePatternSet patterns(op->getContext()); - stablehlo::StablehloToHloTypeConverter converter; - stablehlo::populateStablehloToHloPatterns(&patterns, &converter, - op->getContext()); ConversionTarget target(*op->getContext()); - target.addLegalDialect(); - target.addIllegalDialect(); + stablehlo::setupStablehloToHloConversionTarget(target); + + RewritePatternSet patterns(op->getContext()); stablehlo::StablehloToHloTypeConverter shlo_converter; stablehlo::populateStablehloToHloPatterns(&patterns, &shlo_converter, patterns.getContext()); stablehlo::registerFuncOpsForTypeConversion(target, patterns, shlo_converter); + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { return 
op->emitError("TF2XLA failed to convert StableHLO to MHLO"); } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD index eb81655b95072e..670f44a3c901b1 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -61,7 +61,6 @@ cc_library( deps = [ "//xla:util", "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", @@ -69,6 +68,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", ], ) @@ -185,7 +185,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/mlir/utils:type_util", - "//xla/mlir_hlo", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", @@ -206,12 +205,12 @@ xla_cc_test( ":hlo_utils", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/mlir_hlo", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test_main", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc index cf411420bbac99..ba76f5cce6cf7f 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc @@ -103,6 +103,31 @@ mlir::NamedAttribute ConvertChannelHandle(std::optional channel_id, return stablehlo::ConvertChannelHandle(channel_handle, builder); } +mlir::stablehlo::ConvDimensionNumbersAttr ConvertConvDimensionNumbers( + const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) { + auto arrayref = [](absl::Span array) { + return llvm::ArrayRef{array.data(), array.size()}; + }; + llvm::SmallVector input_spatial_dims( + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + llvm::SmallVector kernel_spatial_dims( + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + llvm::SmallVector output_spatial_dims( + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + return mlir::stablehlo::ConvDimensionNumbersAttr::get( + builder->getContext(), dnums.input_batch_dimension(), + dnums.input_feature_dimension(), + arrayref(dnums.input_spatial_dimensions()), + dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension(), + arrayref(dnums.kernel_spatial_dimensions()), + dnums.output_batch_dimension(), dnums.output_feature_dimension(), + arrayref(dnums.output_spatial_dimensions())); +} + absl::StatusOr ConvertCustomCallApiVersion(xla::CustomCallApiVersion api_version) { switch (api_version) { @@ -247,6 +272,28 @@ mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, return builder->getArrayAttr(operand_precision_attrs); } +mlir::stablehlo::ResultAccuracyAttr ConvertResultAccuracy( + const ResultAccuracy& result_accuracy, mlir::Builder* builder) { + if (result_accuracy.has_tolerance()) { + return mlir::stablehlo::ResultAccuracyAttr::get( + builder->getContext(), + llvm::APFloat(result_accuracy.tolerance().atol()), + llvm::APFloat(result_accuracy.tolerance().rtol()), + result_accuracy.tolerance().ulps(), + // Explicitly set the mode to TOLERANCE since ResultAccuracy has no + // TOLERANCE enum. 
+ mlir::stablehlo::ResultAccuracyModeAttr::get( + builder->getContext(), + mlir::stablehlo::ResultAccuracyMode::TOLERANCE)); + } + return mlir::stablehlo::ResultAccuracyAttr::get( + builder->getContext(), llvm::APFloat(0.0), llvm::APFloat(0.0), 0, + mlir::stablehlo::ResultAccuracyModeAttr::get( + builder->getContext(), + mlir::stablehlo::symbolizeResultAccuracyMode(result_accuracy.mode()) + .value())); +} + } // namespace stablehlo mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, @@ -307,82 +354,6 @@ mlir::mhlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers( scatter_dims_to_operand_dims, dnums.index_vector_dim()); } -mlir::mhlo::DotAlgorithmAttr ConvertDotAlgorithm( - const PrecisionConfig::Algorithm algorithm, mlir::Builder* builder) { - mlir::Type lhs, rhs, accum; - int64_t lhsComponentCount = 1, rhsComponentCount = 1, - numPrimitiveOperations = 1; - bool allowImpreciseAccumulation = false; - switch (algorithm) { - case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: { - lhs = rhs = builder->getType(); - accum = builder->getF32Type(); - break; - } - case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: { - lhs = rhs = builder->getType(); - accum = builder->getF32Type(); - allowImpreciseAccumulation = true; - break; - } - case PrecisionConfig::ALG_DOT_F16_F16_F16: { - lhs = rhs = accum = builder->getF16Type(); - break; - } - case PrecisionConfig::ALG_DOT_F16_F16_F32: { - lhs = rhs = builder->getF16Type(); - accum = builder->getF32Type(); - break; - } - case PrecisionConfig::ALG_DOT_BF16_BF16_BF16: { - lhs = rhs = accum = builder->getBF16Type(); - break; - } - case PrecisionConfig::ALG_DOT_BF16_BF16_F32: { - lhs = rhs = builder->getBF16Type(); - accum = builder->getF32Type(); - break; - } - case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: { - lhs = rhs = builder->getBF16Type(); - accum = builder->getF32Type(); - numPrimitiveOperations = 3; - break; - } - case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: { - lhs = rhs = builder->getBF16Type(); - accum = builder->getF32Type(); - numPrimitiveOperations = 6; - break; - } - case PrecisionConfig::ALG_DOT_TF32_TF32_F32: { - lhs = rhs = builder->getTF32Type(); - accum = builder->getF32Type(); - break; - } - case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: { - lhs = rhs = builder->getTF32Type(); - accum = builder->getF32Type(); - numPrimitiveOperations = 3; - break; - } - case PrecisionConfig::ALG_DOT_F32_F32_F32: { - lhs = rhs = accum = builder->getF32Type(); - break; - } - case PrecisionConfig::ALG_DOT_F64_F64_F64: { - lhs = rhs = accum = builder->getF64Type(); - break; - } - default: - // Unset, sentinels - return mlir::mhlo::DotAlgorithmAttr{}; - } - return mlir::mhlo::DotAlgorithmAttr::get( - builder->getContext(), lhs, rhs, accum, lhsComponentCount, - rhsComponentCount, numPrimitiveOperations, allowImpreciseAccumulation); -} - mlir::mhlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers( const DotDimensionNumbers& dnums, mlir::Builder* builder) { auto arrayref = [](absl::Span array) { @@ -461,37 +432,6 @@ absl::StatusOr ConvertSparsityDescriptor( } } -absl::StatusOr ConvertFftType(FftType type) { - switch (type) { - case FftType::FFT: - return mlir::mhlo::FftType::FFT; - case FftType::IFFT: - return mlir::mhlo::FftType::IFFT; - case FftType::RFFT: - return mlir::mhlo::FftType::RFFT; - case FftType::IRFFT: - return mlir::mhlo::FftType::IRFFT; - default: - return InvalidArgument("Unknown FFT type enum value #%d", type); - } -} - -absl::StatusOr ConvertTranspose( - xla::TriangularSolveOptions_Transpose 
transpose) { - switch (transpose) { - case TriangularSolveOptions::NO_TRANSPOSE: - return mlir::mhlo::Transpose::NO_TRANSPOSE; - case TriangularSolveOptions::TRANSPOSE: - return mlir::mhlo::Transpose::TRANSPOSE; - case TriangularSolveOptions::ADJOINT: - return mlir::mhlo::Transpose::ADJOINT; - case TriangularSolveOptions::TRANSPOSE_INVALID: - return mlir::mhlo::Transpose::TRANSPOSE_INVALID; - default: - return InvalidArgument("Unknown transpose enum value #%d", transpose); - } -} - absl::StatusOr ConvertCustomCallApiVersion( xla::CustomCallApiVersion api_version) { TF_ASSIGN_OR_RETURN(auto stablehlo_api_version, @@ -604,25 +544,4 @@ absl::StatusOr ExtractLayoutsFromTuple( return ExtractLayoutsFromShapes(shape.tuple_shapes(), builder); } -mlir::mhlo::ResultAccuracyAttr ConvertResultAccuracy( - const ResultAccuracy& result_accuracy, mlir::Builder* builder) { - if (result_accuracy.has_tolerance()) { - return mlir::mhlo::ResultAccuracyAttr::get( - builder->getContext(), - llvm::APFloat(result_accuracy.tolerance().atol()), - llvm::APFloat(result_accuracy.tolerance().rtol()), - result_accuracy.tolerance().ulps(), - // Explicitly set the mode to TOLERANCE since ResultAccuracy has no - // TOLERANCE enum. - mlir::mhlo::ResultAccuracyModeAttr::get( - builder->getContext(), mlir::mhlo::ResultAccuracyMode::TOLERANCE)); - } - return mlir::mhlo::ResultAccuracyAttr::get( - builder->getContext(), llvm::APFloat(0.0), llvm::APFloat(0.0), 0, - mlir::mhlo::ResultAccuracyModeAttr::get( - builder->getContext(), - mlir::mhlo::symbolizeResultAccuracyMode(result_accuracy.mode()) - .value())); -} - } // namespace xla diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h index eb857cf0053528..acc18302bcaa1e 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h @@ -56,20 +56,28 @@ mlir::stablehlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers( mlir::stablehlo::DotAlgorithmAttr ConvertDotAlgorithm( PrecisionConfig::Algorithm algorithm, mlir::Builder* builder); +// Converts the conv dimensions to attributes. +mlir::stablehlo::ConvDimensionNumbersAttr ConvertConvDimensionNumbers( + const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder); + // Converts the dot dimensions to attributes. mlir::stablehlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers( const DotDimensionNumbers& dnums, mlir::Builder* builder); // Converts the output operand aliasing to attributes. mlir::ArrayAttr ConvertOutputOperandAliasing( - const std::vector>>& aliaInfo, + const std::vector>>& aliasInfo, mlir::Builder* builder); // Converts an XLA PrecisionConfig to the corresponding MLIR attribute. mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, mlir::Builder* builder); +// Converts an XLA ResultAccuracy to the corresponding MLIR attribute. +mlir::stablehlo::ResultAccuracyAttr ConvertResultAccuracy( + const ResultAccuracy& result_accuracy, mlir::Builder* builder); + } // namespace stablehlo // Converts an XLA PrecisionConfig to the corresponding MLIR attribute. @@ -86,11 +94,6 @@ mlir::mhlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers( mlir::mhlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers( const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder); -// Converts the dot algorithm to attributes. -// Used by sparse dot. 
-mlir::mhlo::DotAlgorithmAttr ConvertDotAlgorithm( - PrecisionConfig::Algorithm algorithm, mlir::Builder* builder); - // Converts the dot dimensions to attributes. // Used by sparse dot. mlir::mhlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers( @@ -101,6 +104,7 @@ mlir::mhlo::RaggedDotDimensionNumbersAttr ConvertRaggedDotDimensionNumbers( const RaggedDotDimensionNumbers& dnums, mlir::Builder* builder); // Converts the conv dimensions to attributes. +// [Deprecated] Used in TF2XLA only. mlir::mhlo::ConvDimensionNumbersAttr ConvertConvDimensionNumbers( const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder); @@ -114,10 +118,6 @@ mlir::ArrayAttr ConvertOutputOperandAliasing( absl::StatusOr ConvertSparsityDescriptor( xla::SparsityDescriptor sparsity_descriptor, mlir::Builder* builder); -absl::StatusOr ConvertFftType(FftType type); -absl::StatusOr ConvertTranspose( - TriangularSolveOptions_Transpose transpose); - absl::StatusOr ConvertCustomCallApiVersion( xla::CustomCallApiVersion api_version); @@ -146,10 +146,6 @@ absl::StatusOr ExtractLayoutsFromShapes( absl::StatusOr ExtractLayoutsFromTuple(const xla::Shape shape, mlir::Builder* builder); -// Converts the ResultAccuracy to ResultAccuracyAttr. -mlir::mhlo::ResultAccuracyAttr ConvertResultAccuracy( - const ResultAccuracy& result_accuracy, mlir::Builder* builder); - } // namespace xla #endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc index 10a78e149ec973..ede0bec69c1bcb 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc @@ -31,8 +31,8 @@ limitations under the License. 
#include "mlir/IR/Operation.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/util.h" namespace xla { @@ -66,9 +66,9 @@ absl::StatusOr ImportDynamicBroadcastInDimOp( } return builder - ->create( + ->create( loc, result_type, operands[0], operands[1], - builder->getI64TensorAttr(broadcast_dimensions)) + builder->getDenseI64ArrayAttr(broadcast_dimensions)) .getOperation(); } @@ -79,7 +79,7 @@ absl::StatusOr ImportDynamicReshapeOp( return Internal("backend_config attribute must be empty."); } return builder - ->create(loc, result_type, operands) + ->create(loc, result_type, operands) .getOperation(); } @@ -90,7 +90,7 @@ absl::StatusOr ImportRealDynamicSliceOp( return Internal("backend_config attribute must be empty."); } return builder - ->create(loc, result_type, operands) + ->create(loc, result_type, operands) .getOperation(); } @@ -187,7 +187,7 @@ absl::StatusOr ImportCustomCallAsOp( if (custom_call_target == "mhlo.uniform_quantize") { return builder - ->create( + ->create( loc, mlir::RankedTensorType::get( mlir::cast(result_type).getShape(), @@ -198,7 +198,8 @@ absl::StatusOr ImportCustomCallAsOp( if (custom_call_target == "mhlo.uniform_dequantize") { return builder - ->create(loc, result_type, operands) + ->create(loc, result_type, + operands) .getOperation(); } return InvalidArgument("Unsupported MHLO op custom_call %s", diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc index 2bedc76a2f25db..01fcef5757bbbe 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -145,27 +145,10 @@ ArrayRef FlattenTupleSharding(const HloSharding& sharding) { } // Returns true if changed. 
-bool FoldGetTupleElementOfTuple(Operation* op) { - int64_t idx; - if (auto getTupleElementOp = - llvm::dyn_cast(op)) { - idx = getTupleElementOp.getIndex(); - } else if (auto getTupleElementOp = - llvm::dyn_cast(op)) { - idx = getTupleElementOp.getIndex(); - } else { - llvm::report_fatal_error("Unexpected op for tuple folding: " + - op->getName().getStringRef()); - } - - if (auto tupleOp = op->getOperand(0).getDefiningOp()) { - llvm::SmallVector new_operand{tupleOp.getOperand(idx)}; - op->replaceAllUsesWith(new_operand); - op->erase(); - return true; - } +bool FoldGetTupleElementOfTuple(mlir::stablehlo::GetTupleElementOp op) { if (auto tupleOp = op->getOperand(0).getDefiningOp()) { + int64_t idx = op.getIndex(); llvm::SmallVector new_operand{tupleOp.getOperand(idx)}; op->replaceAllUsesWith(new_operand); op->erase(); @@ -185,10 +168,10 @@ void CleanUpTupleOps(mlir::Block* block, mlir::OpBuilder* builder) { while (changed) { changed = false; for (Operation& op : llvm::make_early_inc_range(block->getOperations())) { - if (llvm::isa(op)) { - changed = FoldGetTupleElementOfTuple(&op); - } else if (llvm::isa(op) && + if (auto get_tuple_op = + llvm::dyn_cast(op)) { + changed = FoldGetTupleElementOfTuple(get_tuple_op); + } else if (llvm::isa(op) && mlir::isOpTriviallyDead(&op)) { op.erase(); changed = true; @@ -207,6 +190,7 @@ Operation* CreateReturnOp(mlir::OpBuilder& builder, mlir::Location loc, LLVM_DEBUG(llvm::dbgs() << "CreateReturnOp: " << parent_dialect->getNamespace() << '\n'); if (llvm::isa(parent_dialect)) { + // Potentially unused, but if future MHLO ops have bodies, will be needed. return builder.create(loc, operands); } if (llvm::isa(parent_dialect)) { @@ -215,26 +199,7 @@ Operation* CreateReturnOp(mlir::OpBuilder& builder, mlir::Location loc, return builder.create(loc, operands); } -bool HasMhloTokenType(mlir::TypeRange types) { - bool use_mhlo = false; - for (auto type : types) { - if (!use_mhlo) { - type.walk([&](Type type) { - use_mhlo |= llvm::isa(type); - if (use_mhlo) return mlir::WalkResult::interrupt(); - return mlir::WalkResult::advance(); - }); - } - } - return use_mhlo; -} - Operation* WrapInTuple(mlir::OpBuilder* builder, Operation* op) { - // TODO(b/408024772) ToStablehlo: Make StableHLO only once tokens migrated. - - if (HasMhloTokenType(op->getResultTypes())) { - return builder->create(op->getLoc(), op->getResults()); - } LLVM_DEBUG(llvm::dbgs() << "WrapInTuple: " << op->getName() << op->getResultTypes() << '\n'); return builder->create(op->getLoc(), @@ -244,13 +209,8 @@ Operation* WrapInTuple(mlir::OpBuilder* builder, Operation* op) { Operation* GetTupleElementOp(mlir::OpBuilder* builder, Value value, int64_t index, llvm::SmallVector&& attributes) { - // TODO(b/408024772) ToStablehlo: Inline once tokens migrated. 
attributes.push_back( builder->getNamedAttr("index", builder->getI32IntegerAttr(index))); - if (HasMhloTokenType(value.getType())) { - return builder->create( - value.getLoc(), value, builder->getI32IntegerAttr(index)); - } return builder->create( value.getLoc(), value, builder->getI32IntegerAttr(index)); } @@ -312,9 +272,7 @@ absl::StatusOr createConstantZeroLike(mlir::Value operand, void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands( mlir::Operation* op, llvm::ArrayRef implicit_operands) { - assert((mlir::dyn_cast(*op) || - mlir::dyn_cast(*op) || - mlir::dyn_cast(*op) || + assert((mlir::dyn_cast(*op) || mlir::dyn_cast(*op)) && "Unexpected mlir op in " "HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands!"); @@ -333,8 +291,7 @@ static bool IsNestedTupleInData(Type type) { auto tuple_type = mlir::dyn_cast(type); if (!tuple_type) return false; - assert((llvm::isa(tuple_type.getType(1)) || - llvm::isa(tuple_type.getType(1))) && + assert(llvm::isa(tuple_type.getType(1)) && "Infeed: Non token type"); auto data_type = tuple_type.getType(0); @@ -682,7 +639,6 @@ absl::Status HloFunctionImporter::ImportInstructions( arguments.begin() + flatten_idx, arguments.begin() + flatten_idx + flattened_arg_type.size())); - // TODO(b/408024772) ToStablehlo: CreateTupleValue auto tupleVal = CreateTupleValue(&builder, loc, sub_args, orig_tuple_arg_type); effective_arguments.push_back(tupleVal); @@ -824,7 +780,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( if (instruction->opcode() == HloOpcode::kAsyncStart) { auto bundle_result_type = mlir::mhlo::AsyncBundleType::get( - context_, mlir::cast(result_type).getTypes()); + context_, result_type.cast().getTypes()); // XLA Feature -- MHLO Only return func_builder ->create(loc, bundle_result_type, @@ -832,7 +788,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( .getOperation(); } else if (instruction->opcode() == HloOpcode::kAsyncUpdate) { auto bundle_result_type = mlir::mhlo::AsyncBundleType::get( - context_, mlir::cast(result_type).getTypes()); + context_, result_type.cast().getTypes()); // XLA Feature -- MHLO Only return func_builder ->create(loc, bundle_result_type, @@ -905,13 +861,6 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "precision_config", ConvertPrecisionConfig(&instruction->precision_config(), builder_))); - if (instruction->precision_config().algorithm() != - PrecisionConfig::ALG_UNSET) { - attributes.push_back(builder_->getNamedAttr( - "algorithm", - ConvertDotAlgorithm(instruction->precision_config().algorithm(), - builder_))); - } attributes.push_back(builder_->getNamedAttr( "dot_dimension_numbers", ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(), @@ -1677,7 +1626,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( .getValue()); } - // TODO(b/408024772) Fix StableHLO AllToAll to support tuple types the + // TODO(b/408024772): Fix StableHLO AllToAll to support tuple types the // way that XLA expects it. Currently it is mis-designed. 
// XLA Feature -- MHLO Only auto result = func_builder->create( @@ -1965,17 +1914,17 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } attributes.push_back( - builder_->getNamedAttr("window_strides", Convert(strides))); + builder_->getNamedAttr("window_strides", ConvertArray(strides))); attributes.push_back(ConvertPadding(paddings)); attributes.push_back( - builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations))); + builder_->getNamedAttr("lhs_dilation", ConvertArray(lhs_dilations))); attributes.push_back( - builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations))); + builder_->getNamedAttr("rhs_dilation", ConvertArray(rhs_dilations))); attributes.push_back( - builder_->getNamedAttr("window_reversal", Convert(reversals))); + builder_->getNamedAttr("window_reversal", ConvertArray(reversals))); attributes.push_back(builder_->getNamedAttr( "dimension_numbers", - ConvertConvDimensionNumbers( + stablehlo::ConvertConvDimensionNumbers( instruction->convolution_dimension_numbers(), builder_))); attributes.push_back(builder_->getNamedAttr( "feature_group_count", @@ -1984,8 +1933,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "batch_group_count", builder_->getI64IntegerAttr(instruction->batch_group_count()))); attributes.push_back(builder_->getNamedAttr( - "precision_config", - ConvertPrecisionConfig(&instruction->precision_config(), builder_))); + "precision_config", stablehlo::ConvertPrecisionConfig( + &instruction->precision_config(), builder_))); // If the element types of the operands for convolution are different, // insert a convert op to convert the operands to the common element type @@ -1995,6 +1944,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( auto lhs_element_type = instruction->operand(0)->shape().element_type(); auto rhs_element_type = instruction->operand(1)->shape().element_type(); if (lhs_element_type != rhs_element_type) { + // Cast LHS or RHS to the common element type. 
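+      // CastPreservesValues reports which direction is lossless, so the
+      // convert op is inserted on the operand that can be widened safely.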
if (primitive_util::CastPreservesValues(lhs_element_type, rhs_element_type)) { auto convert_op_return_type = @@ -2016,20 +1966,11 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( instruction->operand(0)->shape().ToString(), instruction->operand(1)->shape().ToString()); } - // TODO(b/408024772) ToStablehlo: ConvertPrecisionConfig, - // ConvertConvDimensionNumbers - return func_builder - ->create( - loc, result_type, std::vector{lhs, rhs}, - attributes) - .getOperation(); } - // TODO(b/408024772) ToStablehlo: ConvertPrecisionConfig, - // ConvertConvDimensionNumbers return func_builder - ->create(loc, result_type, operands, - attributes) + ->create( + loc, result_type, std::vector{lhs, rhs}, attributes) .getOperation(); } @@ -2224,16 +2165,16 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( #undef NO_ATTRIBUTE_CASE_MHLO -// TODO(b/408024772) ToStablehlo: ConvertResultAccuracy #define RESULT_ACCURACY_CASE(hlo_op_code, mlir_op) \ case HloOpcode::hlo_op_code: { \ if (instruction->has_result_accuracy()) { \ attributes.push_back(builder_->getNamedAttr( \ - "result_accuracy", \ - ConvertResultAccuracy(instruction->result_accuracy(), builder_))); \ + "result_accuracy", stablehlo::ConvertResultAccuracy( \ + instruction->result_accuracy(), builder_))); \ } \ return func_builder \ - ->create(loc, result_type, operands, attributes) \ + ->create(loc, result_type, operands, \ + attributes) \ .getOperation(); \ } @@ -2434,6 +2375,11 @@ mlir::DenseI64ArrayAttr HloFunctionImporter::ConvertArray( return builder_->getDenseI64ArrayAttr(elements); } +mlir::DenseBoolArrayAttr HloFunctionImporter::ConvertArray( + llvm::ArrayRef elements) { + return builder_->getDenseBoolArrayAttr(elements); +} + mlir::DenseIntElementsAttr HloFunctionImporter::Convert( llvm::ArrayRef elements) { return DenseIntElementsAttr::get( diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h index 386dff23fe83e6..7b34a32e720f59 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -135,6 +136,7 @@ class HloFunctionImporter { context_->loadDialect(); context_->loadDialect(); context_->loadDialect(); + context_->loadDialect(); context_->loadDialect(); } @@ -209,8 +211,12 @@ class HloFunctionImporter { mlir::DenseIntElementsAttr ConvertDimensions( absl::Span op_dimensions); + // Converts Array ref to a DenseI64ArrayAttr. mlir::DenseI64ArrayAttr ConvertArray(llvm::ArrayRef elements); + // Converts Array ref to a DenseBoolArrayAttr. + mlir::DenseBoolArrayAttr ConvertArray(llvm::ArrayRef elements); + // Converts Array ref to an DenseIntElementsAttr. mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index 983efbc7987f29..aa4e95e2ddafee 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -42,7 +42,6 @@ limitations under the License. 
#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir/utils/type_util.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/tsl/platform/statusor.h" @@ -176,23 +175,6 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( vector); } -namespace { -bool HasMhloTokenType(mlir::TypeRange types) { - bool use_mhlo = false; - for (auto type : types) { - if (!use_mhlo) { - type.walk([&](mlir::Type type) { - use_mhlo |= llvm::isa(type); - if (use_mhlo) return mlir::WalkResult::interrupt(); - return mlir::WalkResult::advance(); - }); - } - } - return use_mhlo; -} - -} // namespace - mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, mlir::ValueRange& flatten_values, mlir::Type type) { @@ -209,10 +191,6 @@ mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, flatten_sub_values.push_back( CreateTupleValue(func_builder, loc, flatten_values, child_type)); - if (HasMhloTokenType(mlir::TypeRange(flatten_sub_values))) { - return func_builder->create(loc, flatten_sub_values) - .getResult(); - } return func_builder->create(loc, flatten_sub_values) .getResult(); } @@ -226,10 +204,7 @@ mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, mlir::ValueRange flattened_results_ref(op->getResults()); auto result = CreateTupleValue(func_builder, loc, flattened_results_ref, type); - mlir::Operation* tuple_op = result.getDefiningOp(); - if (!tuple_op) { - tuple_op = result.getDefiningOp(); - } + mlir::Operation* tuple_op = result.getDefiningOp(); assert(tuple_op && "builder didn't return the right type"); return tuple_op; } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h index a19d0da23e5fc1..789180d85ec240 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h @@ -42,7 +42,6 @@ limitations under the License. #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir/utils/type_util.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -81,7 +80,7 @@ static absl::StatusOr ConvertTensorShapeToType(const Shape& xla_ty, shape[dim] = dim_size; } } - using mlir::mhlo::TypeExtensionsAttr; + using mlir::stablehlo::TypeExtensionsAttr; mlir::Attribute encoding; if (is_bounded_dynamic) { encoding = TypeExtensionsAttr::get(builder.getContext(), bounds); diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc index 68a9607a5e5db4..e7306858de667a 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/DebugStringHelper.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/shape_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" @@ -34,7 +34,7 @@ namespace { TEST(ConvertTensorShapeToType, Simple) { mlir::MLIRContext context; - context.loadDialect(); + context.loadDialect(); mlir::Builder builder(&context); // Static shape. 
@@ -59,7 +59,8 @@ TEST(ConvertTensorShapeToType, Simple) { ConvertTensorShapeToType(shape, builder)); int64_t bounds[] = {8, mlir::ShapedType::kDynamic}; - auto extensions = mlir::mhlo::TypeExtensionsAttr::get(&context, bounds); + auto extensions = + mlir::stablehlo::TypeExtensionsAttr::get(&context, bounds); auto expected = mlir::RankedTensorType::get( {mlir::ShapedType::kDynamic, 128}, builder.getI32Type(), extensions); EXPECT_TRUE(type == expected) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo index 1d13766c9e1b5d..eddb91986a124a 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_emit_stablehlo.hlo @@ -562,7 +562,7 @@ ENTRY %main.12 (Arg_0.1: s32[4]) -> (s32[4], s32[4]) { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { -// CHECK: %[[VAL_1:.*]] = mhlo.cosine %[[VAL_0]] : tensor<1x16x16x3xf32> +// CHECK: %[[VAL_1:.*]] = stablehlo.cosine %[[VAL_0]] : tensor<1x16x16x3xf32> // CHECK: return %[[VAL_1]] : tensor<1x16x16x3xf32> // CHECK: } // CHECK: } @@ -577,7 +577,7 @@ ENTRY %main.3 (Arg_0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { -// CHECK: %[[VAL_1:.*]] = mhlo.sine %[[VAL_0]] : tensor<1x16x16x3xf32> +// CHECK: %[[VAL_1:.*]] = stablehlo.sine %[[VAL_0]] : tensor<1x16x16x3xf32> // CHECK: return %[[VAL_1]] : tensor<1x16x16x3xf32> // CHECK: } // CHECK: } @@ -592,7 +592,7 @@ ENTRY %main.3 (Arg_0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor) -> tensor { -// CHECK: %[[VAL_1:.*]] = mhlo.exponential %[[VAL_0]] {result_accuracy = #mhlo.result_accuracy>} : tensor +// CHECK: %[[VAL_1:.*]] = stablehlo.exponential %[[VAL_0]] {result_accuracy = #stablehlo.result_accuracy>} : tensor // CHECK: return %[[VAL_1]] : tensor // CHECK: } // CHECK: } @@ -668,7 +668,7 @@ ENTRY %main.2 () -> () { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<100x26x26x32xf32>, %[[VAL_1:.*]]: tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { -// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> +// CHECK: %[[VAL_2:.*]] = stablehlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}2, 2], [2, 2]], lhs_dilate = [1, 1], 
rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> // CHECK: return %[[VAL_2]] : tensor<100x28x28x1xf32> // CHECK: } // CHECK: } @@ -684,7 +684,7 @@ ENTRY %main.4 (Arg_0.1: f32[100,26,26,32], Arg_1.2: f32[3,3,1,32]) -> f32[100,28 // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<100x26x26x32xi8>, %[[VAL_1:.*]]: tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> { -// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> +// CHECK: %[[VAL_2:.*]] = stablehlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> // CHECK: return %[[VAL_2]] : tensor<100x28x28x1xi32> // CHECK: } // CHECK: } @@ -700,7 +700,7 @@ ENTRY %main.4 (Arg_0.1: s8[100,26,26,32], Arg_1.2: s8[3,3,1,32]) -> s32[100,28,2 // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<100x26x26x32xi8>, %[[VAL_1:.*]]: tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> { -// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> +// CHECK: %[[VAL_2:.*]] = stablehlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [true, true]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> // CHECK: return %[[VAL_2]] : tensor<100x28x28x1xi32> // CHECK: } // CHECK: } @@ -1207,9 +1207,9 @@ ENTRY %main.4 (Arg_0.1: f32[4,2], Arg_1.2: s32[]) -> s32[] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor) -> tensor<4x?xf32, #mhlo.type_extensions> { -// CHECK: %[[VAL_2:.*]] = stablehlo.set_dimension_size %[[VAL_0]], %[[VAL_1]], dim = 1 : (tensor<4x4xf32>, tensor) -> tensor<4x?xf32, #mhlo.type_extensions> -// CHECK: return %[[VAL_2]] : 
tensor<4x?xf32, #mhlo.type_extensions> +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor) -> tensor<4x?xf32, #stablehlo.bounds> { +// CHECK: %[[VAL_2:.*]] = stablehlo.set_dimension_size %[[VAL_0]], %[[VAL_1]], dim = 1 : (tensor<4x4xf32>, tensor) -> tensor<4x?xf32, #stablehlo.bounds> +// CHECK: return %[[VAL_2]] : tensor<4x?xf32, #stablehlo.bounds> // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(f32[4,4]{1,0}, s32[])->f32[4,<=4]{1,0}} @@ -1865,9 +1865,9 @@ ENTRY %main.5 (Arg_0.1: token[]) -> token[] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor) -> tensor<4x?xf32, #mhlo.type_extensions> { -// CHECK: %[[VAL_2:.*]] = stablehlo.set_dimension_size %[[VAL_0]], %[[VAL_1]], dim = 1 : (tensor<4x4xf32>, tensor) -> tensor<4x?xf32, #mhlo.type_extensions> -// CHECK: return %[[VAL_2]] : tensor<4x?xf32, #mhlo.type_extensions> +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4x4xf32>, %[[VAL_1:.*]]: tensor) -> tensor<4x?xf32, #stablehlo.bounds> { +// CHECK: %[[VAL_2:.*]] = stablehlo.set_dimension_size %[[VAL_0]], %[[VAL_1]], dim = 1 : (tensor<4x4xf32>, tensor) -> tensor<4x?xf32, #stablehlo.bounds> +// CHECK: return %[[VAL_2]] : tensor<4x?xf32, #stablehlo.bounds> // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(f32[4,4]{1,0}, s32[])->f32[4,<=4]{1,0}} @@ -1961,8 +1961,8 @@ ENTRY %main.4 (Arg_0.1: f32[], Arg_1.2: s32[]) -> (f32[], s32[]) { // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { // CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xf32>, %[[VAL_1:.*]]: tensor<4xi32>) -> tuple, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>> { -// CHECK: %[[VAL_2:.*]] = mhlo.exponential_minus_one %[[VAL_0]] : tensor<4xf32> -// CHECK: %[[VAL_3:.*]] = mhlo.log_plus_one %[[VAL_0]] : tensor<4xf32> +// CHECK: %[[VAL_2:.*]] = stablehlo.exponential_minus_one %[[VAL_0]] : tensor<4xf32> +// CHECK: %[[VAL_3:.*]] = stablehlo.log_plus_one %[[VAL_0]] : tensor<4xf32> // CHECK: %[[VAL_4:.*]] = stablehlo.not %[[VAL_1]] : tensor<4xi32> // CHECK: %[[VAL_5:.*]] = stablehlo.popcnt %[[VAL_1]] : tensor<4xi32> // CHECK: %[[VAL_6:.*]] = stablehlo.tuple %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {xla_shape = "(f32[4]{0}, f32[4]{0}, s32[4]{0}, s32[4]{0})"} : tuple, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>> @@ -2120,9 +2120,9 @@ ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[4] { // ----- // CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { -// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xf32>, %[[VAL_1:.*]]: tensor) -> tensor> { -// CHECK: %[[VAL_2:.*]] = stablehlo.set_dimension_size %[[VAL_0]], %[[VAL_1]], dim = 0 : (tensor<4xf32>, tensor) -> tensor> -// CHECK: return %[[VAL_2]] : tensor> +// CHECK: func.func @main(%[[VAL_0:.*]]: tensor<4xf32>, %[[VAL_1:.*]]: tensor) -> tensor> { +// CHECK: %[[VAL_2:.*]] = stablehlo.set_dimension_size %[[VAL_0]], %[[VAL_1]], dim = 0 : (tensor<4xf32>, tensor) -> tensor> +// CHECK: return %[[VAL_2]] : tensor> // CHECK: } // CHECK: } HloModule main, entry_computation_layout={(f32[4]{0}, s32[])->f32[<=4]{0}} @@ -2214,7 +2214,7 @@ ENTRY %main.6 (Arg_0.1: u64[3]) -> (u64[3], u32[2,2]) { // 
CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
// CHECK:  func.func @main(%[[VAL_0:.*]]: tensor<3x4xf32>) -> tensor<3x4xf32> {
-// CHECK:  %[[VAL_1:.*]] = mhlo.cbrt %[[VAL_0]] : tensor<3x4xf32>
+// CHECK:  %[[VAL_1:.*]] = stablehlo.cbrt %[[VAL_0]] : tensor<3x4xf32>
// CHECK:  return %[[VAL_1]] : tensor<3x4xf32>
// CHECK:  }
// CHECK: }
@@ -2383,7 +2383,7 @@ ENTRY %main.3 (Arg_0.1: f32[2]) -> f32[2] {

// CHECK-LABEL: module @main attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
// CHECK:  func.func @main(%[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
-// CHECK:  %[[VAL_1:.*]] = mhlo.tan %[[VAL_0]] : tensor<2xf32>
+// CHECK:  %[[VAL_1:.*]] = stablehlo.tan %[[VAL_0]] : tensor<2xf32>
// CHECK:  return %[[VAL_1]] : tensor<2xf32>
// CHECK:  }
// CHECK: }
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc
index 9f5e164969decd..133067450d40bf 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc
@@ -15,6 +15,7 @@ limitations under the License.

 #include

+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
@@ -63,6 +64,41 @@ struct AddDependencyOpToStablehloTokenConverter
   }
 };

+bool hasMhloOperand(Operation* op) {
+  return llvm::any_of(op->getOperandTypes(), [](Type type) {
+    // Check for !mhlo.token
+    if (llvm::isa<mhlo::MhloDialect>(type.getDialect())) return true;
+
+    // Check for tensor types carrying an #mhlo.type_extensions encoding
+    if (auto rankedType = dyn_cast<RankedTensorType>(type)) {
+      return llvm::isa_and_nonnull<mhlo::TypeExtensionsAttr>(
+          rankedType.getEncoding());
+    }
+    // Not MHLO
+    return false;
+  });
+}
+
+struct UpdateOperandsInUnknownOp : public ConversionPattern {
+  UpdateOperandsInUnknownOp(TypeConverter& converter, MLIRContext* context)
+      : ConversionPattern(converter, MatchAnyOpTypeTag(), /*benefit=*/1,
+                          context) {}
+  LogicalResult matchAndRewrite(
+      Operation* op, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const override {
+    // Operand types of MHLO/StableHLO ops are already converted elsewhere.
+    if (llvm::isa<mhlo::MhloDialect, stablehlo::StablehloDialect>(
+            op->getDialect()))
+      return rewriter.notifyMatchFailure(op, "op is not an unknown op");
+
+    if (!hasMhloOperand(op))
+      return rewriter.notifyMatchFailure(op, "op has no mhlo operands");
+
+    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(operands); });
+    return success();
+  }
+};
+
 struct HloLegalizeToStablehloPass
     : public impl::HloLegalizeToStablehloPassBase<HloLegalizeToStablehloPass> {
   HloLegalizeToStablehloPass()
@@ -78,6 +114,9 @@ struct HloLegalizeToStablehloPass
     stablehlo::HloToStablehloTypeConverter converter;
     RewritePatternSet patterns(&getContext());
+    stablehlo::populateHloToStablehloPatterns(
+        &patterns, &converter, &getContext(), allow_experimental_features_);
+    stablehlo::registerFuncOpsForTypeConversion(target, patterns, converter);

     if (allow_xla_features_) {
       // These ops do not exist in StableHLO.
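       // (They stay legal conversion targets only because allow_xla_features_
       // is set above.)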
@@ -88,15 +127,14 @@ struct HloLegalizeToStablehloPass mhlo::SparseDotOp, mhlo::StochasticConvertOp, mhlo::TopKOp, mhlo::TraceOp, mhlo::XlaRngGetAndUpdateStateOp>(); target.addDynamicallyLegalOp( - [](mhlo::AddDependencyOp op) { - return llvm::isa(op.getToken().getType()); - }); + [](mhlo::AddDependencyOp op) { return !hasMhloOperand(op); }); patterns.add(&getContext()); } - stablehlo::populateHloToStablehloPatterns( - &patterns, &converter, &getContext(), allow_experimental_features_); - stablehlo::registerFuncOpsForTypeConversion(target, patterns, converter); + // Handle non-MHLO ops that may have bounded dynamism or token types. + target.markUnknownOpDynamicallyLegal( + [](Operation* op) { return !hasMhloOperand(op); }); + patterns.add(converter, &getContext()); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h index 194e92b14757e1..fd626abfaeaf02 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h @@ -142,6 +142,10 @@ void populateStablehloToHloPatterns(RewritePatternSet *patterns, TypeConverter *converter, MLIRContext *context); +// Sets up legality definitions for StableHLO ops and non-StableHLO ops that +// may have StableHLO operands. +void setupStablehloToHloConversionTarget(ConversionTarget &target); + } // namespace stablehlo } // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index cb39eb52ca7f29..352e9a669e9eb8 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "mhlo/IR/hlo_ops.h" #include "mhlo/transforms/map_stablehlo_to_hlo_op.h" #include "mhlo/transforms/rewriters.h" @@ -425,6 +426,64 @@ class StablehloToHloOpConverter : public OpConversionPattern { } }; +// AddDependencyOp is the only op that doesn't exist in StableHLO but uses +// token types. This led to two options (1) support either token type in +// AddDependencyOp or (2) Design a token conversion (or unrealized cast) between +// MHLO and StableHLO. Option (1) seems safer, and we can hopefully obsolete +// mhlo::TokenType all together and just use StableHLO tokens everywhere. +// +// Note: Only the second argument needs to be converted. All token creation and +// propagation is already handled by existing conversions. 
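+// (With option (1), setupStablehloToHloConversionTarget below only marks
+// AddDependencyOp legal once hasStablehloOperand(op) returns false.)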
+struct AddDependencyOpToMhoTokenConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::AddDependencyOp op, mhlo::AddDependencyOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // Only convert if input token type is MHLO token + if (!llvm::isa(adaptor.getToken().getType())) + return rewriter.notifyMatchFailure(op, "nothing to convert"); + rewriter.replaceOpWithNewOp(op, adaptor.getOperand(), + adaptor.getToken()); + return success(); + } +}; + +bool hasStablehloOperand(Operation* op) { + return llvm::any_of(op->getOperandTypes(), [](Type type) { + // Check for !stablehlo.token + if (llvm::isa(type.getDialect())) return true; + + // Check for tensor> + if (auto rankedType = dyn_cast(type)) { + return llvm::isa_and_nonnull( + rankedType.getEncoding()); + } + // Not StableHLO + return false; + }); +} + +struct UpdateOperandsInUnknownOp : public ConversionPattern { + UpdateOperandsInUnknownOp(TypeConverter& converter, MLIRContext* context) + : ConversionPattern(converter, MatchAnyOpTypeTag(), /*benefit=*/1, + context) {} + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const override { + // Input types already converted to MHLO. + if (llvm::isa( + op->getDialect())) + return rewriter.notifyMatchFailure(op, "op is not an unknown op"); + + if (!hasStablehloOperand(op)) + return rewriter.notifyMatchFailure(op, "op has no stablehlo operands"); + + rewriter.modifyOpInPlace(op, [&]() { op->setOperands(operands); }); + return success(); + } +}; + // Deprecated ops. template <> class StablehloToHloOpConverter @@ -458,6 +517,22 @@ void populateStablehloToHloPatterns(RewritePatternSet* patterns, #define GET_OP_LIST #include "stablehlo/dialect/StablehloOps.cpp.inc" >(patterns, converter, context); + + // Populate conversion patterns for ops that don't exist in StableHLO + // and unknown dialect ops that may have StableHLO operands. + patterns->add(context); + patterns->add(*converter, context); +} + +void setupStablehloToHloConversionTarget(ConversionTarget& target) { + target.addIllegalDialect(); + target.addLegalDialect(); + + // Some ops may have MHLO / StableHLO types in operands + target.addDynamicallyLegalOp( + [](mhlo::AddDependencyOp op) { return !hasStablehloOperand(op); }); + target.markUnknownOpDynamicallyLegal( + [](Operation* op) { return !hasStablehloOperand(op); }); } } // namespace stablehlo diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc index 2909ada7898381..492dc410be943f 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc @@ -29,7 +29,6 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Support/TypeID.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" @@ -41,29 +40,6 @@ namespace mhlo { namespace { -// AddDependencyOp is the only op that doesn't exist in StableHLO but uses -// token types. 
This led to two options (1) support either token type in -// AddDependencyOp or (2) Design a token conversion (or unrealized cast) between -// MHLO and StableHLO. Option (1) seems safer, and we can hopefully obsolete -// mhlo::TokenType all together and just use StableHLO tokens everywhere. -// -// Note: Only the second argument needs to be converted. All token creation and -// propagation is already handled by existing conversions. -struct AddDependencyOpToMhoTokenConverter - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - mhlo::AddDependencyOp op, mhlo::AddDependencyOpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - // Only convert if input token type is MHLO token - if (!llvm::isa(adaptor.getToken().getType())) - return rewriter.notifyMatchFailure(op, "nothing to convert"); - rewriter.replaceOpWithNewOp(op, adaptor.getOperand(), - adaptor.getToken()); - return success(); - } -}; - void legalDirectStablehloToHloConversionOps(ConversionTarget& target) { target.addLegalOp< stablehlo::AbsOp, stablehlo::CbrtOp, stablehlo::SqrtOp, stablehlo::TanOp, @@ -90,12 +66,7 @@ struct StablehloLegalizeToHloPass using StablehloLegalizeToHloPassBase::StablehloLegalizeToHloPassBase; void runOnOperation() override { ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addLegalDialect(); - target.addDynamicallyLegalOp( - [](mhlo::AddDependencyOp op) { - return llvm::isa(op.getToken().getType()); - }); + stablehlo::setupStablehloToHloConversionTarget(target); // Allow injecting legal ops to permit gradual migration. if (!convert_xla_supported_stablehlo_) { @@ -104,7 +75,6 @@ struct StablehloLegalizeToHloPass stablehlo::StablehloToHloTypeConverter converter; RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); stablehlo::populateStablehloToHloPatterns(&patterns, &converter, &getContext()); stablehlo::registerFuncOpsForTypeConversion(target, patterns, converter); diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 2196a2190d2d48..8c850c4af074da 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1838,6 +1838,14 @@ func.func @bounded_dynamism_broadcast_in_dim(%arg0: tensor<1x?xf32, #mhlo.type_e return %0 : tensor<2x1x?xf32, #mhlo.type_extensions> } +// CHECK-LABEL: bounded_dynamism_with_unknown_op +func.func @bounded_dynamism_with_unknown_op(%arg0: tensor<1x4xi32>, %arg1: tensor) -> tensor<1x4xi32> { + %0 = "mhlo.set_dimension_size"(%arg0, %arg1) <{dimension = 1 : i64}> : (tensor<1x4xi32>, tensor) -> tensor<1x?xi32, #mhlo.type_extensions> + // CHECK: "tensor.cast"({{.*}}) : (tensor<1x?xi32, #stablehlo.bounds>) -> tensor<1x4xi32> + %cast = tensor.cast %0 : tensor<1x?xi32, #mhlo.type_extensions> to tensor<1x4xi32> + return %cast : tensor<1x4xi32> +} + // ============ TYPES ============ // CHECK-LABEL: "type_i1" diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index a72ac0486c4f8b..27bdd8ebf0a9a3 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -2095,6 +2095,14 @@ func.func 
@bounded_dynamism_broadcast_in_dim(%arg0: tensor<1x?xf32, #stablehlo.b return %0 : tensor<2x1x?xf32, #stablehlo.bounds> } +// CHECK-LABEL: bounded_dynamism_with_unknown_op +func.func @bounded_dynamism_with_unknown_op(%arg0: tensor<1x4xi32>, %arg1: tensor) -> tensor<1x4xi32> { + %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) <{dimension = 1 : i64}> : (tensor<1x4xi32>, tensor) -> tensor<1x?xi32, #stablehlo.bounds> + // CHECK: "tensor.cast"({{.*}}) : (tensor<1x?xi32, #mhlo.type_extensions>) -> tensor<1x4xi32> + %cast = tensor.cast %0 : tensor<1x?xi32, #stablehlo.bounds> to tensor<1x4xi32> + return %cast : tensor<1x4xi32> +} + // ============ TYPES ============ // CHECK-LABEL: "type_i1" From c0868a5ff9ab6082065c3c5328909e46c3c4d4a4 Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Wed, 16 Apr 2025 19:27:38 -0700 Subject: [PATCH 0897/1324] Use argument instead of env variable to enable tfrt gpu client A Pathways flag will control which client to use. PiperOrigin-RevId: 748503152 --- third_party/xla/xla/pjrt/plugin/xla_gpu/BUILD | 2 -- .../pjrt/plugin/xla_gpu/xla_gpu_client_options.h | 2 ++ .../xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.cc | 13 +------------ .../xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h | 3 --- .../pjrt/plugin/xla_gpu/xla_gpu_pjrt_client_test.cc | 10 +++------- 5 files changed, 6 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/pjrt/plugin/xla_gpu/BUILD b/third_party/xla/xla/pjrt/plugin/xla_gpu/BUILD index 9f1ac114c5a992..17fad658a81966 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_gpu/BUILD +++ b/third_party/xla/xla/pjrt/plugin/xla_gpu/BUILD @@ -19,7 +19,6 @@ cc_library( "//xla/pjrt:pjrt_client", "//xla/pjrt/gpu:se_gpu_pjrt_client", "//xla/pjrt/gpu/tfrt:tfrt_gpu_client", - "//xla/tsl/util:env_var", "@com_google_absl//absl/status:statusor", ], ) @@ -51,7 +50,6 @@ xla_test( "no_oss", ] + if_google(["config-cuda-only"]), deps = [ - ":xla_gpu_client_options", ":xla_gpu_pjrt_client", "//xla/pjrt/gpu:se_gpu_pjrt_client", "//xla/pjrt/gpu/tfrt:tfrt_gpu_client", diff --git a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h index 58c2c67ba5101d..3e2693f9a09a50 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h +++ b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h @@ -48,6 +48,8 @@ struct GpuClientOptions { std::optional mock_gpu_topology; std::optional slice_index; + + bool use_tfrt_gpu_client = false; }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.cc index 8ee1b3dbeb4cbe..a7c9eadf1bc4ff 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.cc @@ -22,24 +22,13 @@ limitations under the License. 
#include "xla/pjrt/gpu/tfrt/tfrt_gpu_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" -#include "xla/tsl/util/env_var.h" namespace xla { -bool UseTfrtGpuClient() { - bool xla_pjrt_gpu_host_memory_preallocate; - if (!tsl::ReadBoolFromEnvVar("USE_TFRT_GPU_CLIENT", false, - &xla_pjrt_gpu_host_memory_preallocate) - .ok()) { - return false; - } - return xla_pjrt_gpu_host_memory_preallocate; -} - absl::StatusOr> GetXlaPjrtGpuClient( GpuClientOptions options) { // TODO(masonchang): Wrap the GPU Client inside the PJRT Sandwich - if (UseTfrtGpuClient()) { + if (options.use_tfrt_gpu_client) { return GetTfrtGpuClient(options); } return GetStreamExecutorGpuClient(options); diff --git a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h index f7f848ddf14528..2c1fcc13fd55e3 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h @@ -24,9 +24,6 @@ limitations under the License. namespace xla { -// Whether to use the TFRT GPU Client. -bool UseTfrtGpuClient(); - // Public entry point to get an XLA:GPU PjRtClient absl::StatusOr> GetXlaPjrtGpuClient( GpuClientOptions options); diff --git a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client_test.cc index b776970a3eb7f6..8c4106a1d86123 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client_test.cc @@ -19,24 +19,20 @@ limitations under the License. #include #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/gpu/tfrt/tfrt_gpu_client.h" -#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" namespace xla { TEST(XlaCpuPjrtClientTest, GetXlaPjrtGpuClient) { - GpuClientOptions options; - ASSERT_OK_AND_ASSIGN(auto client, GetXlaPjrtGpuClient(options)); + ASSERT_OK_AND_ASSIGN(auto client, GetXlaPjrtGpuClient({})); EXPECT_EQ(client->platform_name(), "cuda"); EXPECT_NE(dynamic_cast(client.get()), nullptr); } TEST(XlaCpuPjrtClientTest, GetXlaPjrtGpuClientWithTfrtClient) { - setenv("USE_TFRT_GPU_CLIENT", "true", 1); - GpuClientOptions options; - ASSERT_OK_AND_ASSIGN(auto client, GetXlaPjrtGpuClient(options)); + ASSERT_OK_AND_ASSIGN(auto client, + GetXlaPjrtGpuClient({.use_tfrt_gpu_client = true})); EXPECT_EQ(client->platform_name(), "cuda"); EXPECT_NE(dynamic_cast(client.get()), nullptr); - unsetenv("USE_TFRT_GPU_CLIENT"); } } // namespace xla From 8a6bbc0efb6f01ed0fe9a0bf8a8b12b1b2692746 Mon Sep 17 00:00:00 2001 From: Nicolas Perez Date: Wed, 16 Apr 2025 19:33:46 -0700 Subject: [PATCH 0898/1324] Get correct number of SparseCores per logical device and give better name to util. 
PiperOrigin-RevId: 748504154 --- .../xla/xla/stream_executor/tpu/tpu_library_init_fns.inc | 3 ++- third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_library_init_fns.inc b/third_party/xla/xla/stream_executor/tpu/tpu_library_init_fns.inc index 07fa04a011494a..e3d237e8671bd8 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_library_init_fns.inc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_library_init_fns.inc @@ -75,7 +75,8 @@ absl::Status SetTpuOpsStructFns(void* library_handle) { // TENSORFLOW_STATUS_OK TFTPU_SET_FN(ops_api_fn, TpuTopology_AvailableCoreCount); TFTPU_SET_FN(ops_api_fn, TpuTopology_AvailableCoresPerChip); - TFTPU_SET_FN(ops_api_fn, TpuTopology_MaybeAvailableCoresPerChip); + TFTPU_SET_FN(ops_api_fn, + TpuTopology_MaybeAvailableSparseCoresPerLogicalDevice); TFTPU_SET_FN(ops_api_fn, TpuNetUtil_RecycleUnusedPort); TFTPU_SET_FN(ops_api_fn, TpuCompile_IsTpuCompilationEnabled); TFTPU_SET_FN(ops_api_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h b/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h index f44babaefbbc8a..7ff7f31e7f1f87 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_ops_c_api.h @@ -439,7 +439,8 @@ TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoresPerChip( // Returns the number of cores per Chip or -1 if the TPU system is not // available. -TFTPU_CAPI_EXPORT absl::StatusOr TpuTopology_MaybeAvailableCoresPerChip( +TFTPU_CAPI_EXPORT absl::StatusOr +TpuTopology_MaybeAvailableSparseCoresPerLogicalDevice( TpuCoreTypeEnum tpu_core_type); // Recycle unused service port. 
@@ -808,7 +809,7 @@ struct TfTpu_OpsApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoresPerChip); - TFTPU_ADD_FN_IN_STRUCT(TpuTopology_MaybeAvailableCoresPerChip); + TFTPU_ADD_FN_IN_STRUCT(TpuTopology_MaybeAvailableSparseCoresPerLogicalDevice); TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort); TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey); TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey); From 88e017d80dfebfba8395bf4afd36375e06413962 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 16 Apr 2025 19:34:04 -0700 Subject: [PATCH 0899/1324] [xla] ExecutionGraph: do not weaker execution order when doing transitive reduction PiperOrigin-RevId: 748504208 --- .../xla/xla/runtime/execution_graph.cc | 125 +++++++++++++----- third_party/xla/xla/runtime/execution_graph.h | 21 ++- .../xla/xla/runtime/execution_graph_test.cc | 45 +++++++ 3 files changed, 153 insertions(+), 38 deletions(-) diff --git a/third_party/xla/xla/runtime/execution_graph.cc b/third_party/xla/xla/runtime/execution_graph.cc index fd3e8aa0549ce6..74eb39edbdec44 100644 --- a/third_party/xla/xla/runtime/execution_graph.cc +++ b/third_party/xla/xla/runtime/execution_graph.cc @@ -198,7 +198,8 @@ ExecutionGraph::CreateNodeDefs(std::vector builders) { std::move(nodes_defs)); } -int64_t ExecutionGraph::EraseEdge(NodeDefBuilder& from, NodeDefBuilder& to) { +int64_t ExecutionGraph::EraseEdge(NodeDefBuilder& from, NodeDefBuilder& to, + NodeEdge::Kind kind) { DCHECK_NE(from.id, to.id) << "Nodes must be different"; DCHECK_LT(from.id, to.id) << "Nodes must be ordered"; @@ -242,6 +243,7 @@ int64_t ExecutionGraph::EraseEdge(NodeDefBuilder& from, NodeDefBuilder& to) { in_edges_it != to.in_edges.end() && in_edges_it->id == from.id; DCHECK(has_in_edge) << "In-edge must exist if out-edge exists"; + DCHECK_EQ(in_edges_it->kind, out_edges_it->kind) << "Edges kind must match"; // At this point we must have exactly one edge between `from` and `to` nodes. DCHECK_EQ(absl::c_count_if(from.out_edges, EdgePredicate(to.id)), 1) @@ -249,27 +251,74 @@ int64_t ExecutionGraph::EraseEdge(NodeDefBuilder& from, NodeDefBuilder& to) { DCHECK_EQ(absl::c_count_if(to.in_edges, EdgePredicate(from.id)), 1) << "Expected exactly one in edge from " << from.id << " to " << to.id; + // We can't erase an edge with a stronger ordering guarantee. + if (in_edges_it->kind > kind) { + return 0; + } + + // We erased exactly one edge between `from` and `to` nodes. from.out_edges.erase(out_edges_it); to.in_edges.erase(in_edges_it); - - // We erased one edge between `from` and `to` nodes. return 1; } +namespace { + +// A state of a DFS traversal for transitive reduction. +class TransitiveReductionDfsState { + public: + void PushToStack(ExecutionGraph::NodeEdge edge) { + if (!visited_[edge.id]) { + ++(edge.kind == kExecution ? num_execution_edges_ + : num_scheduling_edges_); + stack_.push_back(edge); + visited_[edge.id] = true; + } + } + + void PushToStack(absl::Span edges) { + for (const ExecutionGraph::NodeEdge& edge : edges) { + PushToStack(edge); + } + } + + ExecutionGraph::NodeEdge PopFromStack() { + ExecutionGraph::NodeEdge edge = stack_.back(); + --(edge.kind == kExecution ? 
num_execution_edges_ : num_scheduling_edges_);
+      stack_.push_back(edge);
+      visited_[edge.id] = true;
+    }
+  }
+
+  void PushToStack(absl::Span<const ExecutionGraph::NodeEdge> edges) {
+    for (const ExecutionGraph::NodeEdge& edge : edges) {
+      PushToStack(edge);
+    }
+  }
+
+  ExecutionGraph::NodeEdge PopFromStack() {
+    ExecutionGraph::NodeEdge edge = stack_.back();
+    --(edge.kind == kExecution ? num_execution_edges_ : num_scheduling_edges_);
+    stack_.pop_back();
+    return edge;
+  }
+
+  bool Empty() const { return stack_.empty(); }
+
+  void Visited(ExecutionGraph::NodeId id) { visited_[id] = true; }
+  size_t NumVisited() const { return absl::c_count(visited_, true); }
+
+  void Clear(size_t num_nodes) {
+    stack_.clear();
+    visited_.assign(num_nodes, false);
+  }
+
+  size_t num_execution_edges() const { return num_execution_edges_; }
+  size_t num_scheduling_edges() const { return num_scheduling_edges_; }
+
+ private:
+  std::vector<ExecutionGraph::NodeEdge> stack_;
+  std::vector<bool> visited_;
+
+  // The number of execution and scheduling edges currently in the stack.
+  size_t num_execution_edges_ = 0;
+  size_t num_scheduling_edges_ = 0;
+};
+
+}  // namespace
+
 int64_t ExecutionGraph::RunTransitiveReductionAndUpdatePriorities(
     absl::Span<NodeDefBuilder> builders) {
   int64_t num_erased_edges = 0;
 
   // Keep workspace for DFS traversal between iterations.
-  std::vector<int64_t> stack;
-  std::vector<bool> visited;
-
-  auto add_to_stack = [&](int64_t node_id) {
-    if (!visited[node_id]) {
-      stack.push_back(node_id);
-      visited[node_id] = true;
-    }
-  };
+  TransitiveReductionDfsState state;
 
   // For each node we do a DFS traversal and delete redundant edges that
   // connect source node with the node reachable via DFS. We do traversal in
@@ -277,35 +326,41 @@
   for (int64_t i = builders.size() - 1; i >= 0; --i) {
     NodeDefBuilder& source_node = builders[i];
 
-    // Clear DFS workspace from previous iteration.
-    stack.clear();
-    visited.assign(builders.size(), false);
+    // Clear DFS state from previous iteration.
+    state.Clear(builders.size());
 
-    // Initialize stack with nodes reachable via immediate out nodes. We mark
-    // immediate out nodes as visited to correctly compute node priority below.
-    for (NodeEdge out_edge : source_node.out_edges) {
-      NodeDefBuilder& out_node = builders[out_edge.id];
-      visited[out_edge.id] = true;
-      for (NodeEdge start_edge : out_node.out_edges) {
-        add_to_stack(start_edge.id);
-      }
-    }
+    // Make a copy of out edges to avoid invalidating iterators.
+    for (NodeEdge out_edge : std::vector<NodeEdge>(source_node.out_edges)) {
+      DCHECK(state.Empty()) << "Stack must be empty at the start of the DFS";
 
-    // Traverse the graph and delete redundant edges.
-    while (!stack.empty()) {
-      int64_t node_id = stack.back();
-      stack.pop_back();
-
-      NodeDefBuilder& node = builders[node_id];
-      num_erased_edges += EraseEdge(source_node, node);
-
-      for (NodeEdge out_edge : node.out_edges) {
-        add_to_stack(out_edge.id);
+      // Initialize state with nodes reachable via `out_edge`. We mark immediate
+      // out nodes as visited to correctly compute node priority below.
+      NodeDefBuilder& out_node = builders[out_edge.id];
+      state.Visited(out_edge.id);
+      state.PushToStack(out_node.out_edges);
+
+      // Do a round of DFS traversal and delete redundant edges from the
+      // `source_node` to the nodes reachable via DFS.
+      while (!state.Empty()) {
+        NodeEdge node_edge = state.PopFromStack();
+        NodeDefBuilder& node = builders[node_edge.id];
+
+        // If we reached `node` via a scheduling edge, then we can't remove an
+        // execution edge from the `source_node`, as we might weaken the
+        // execution order and introduce a data race.
+        bool has_scheduling_edge = out_edge.kind == kScheduling ||
+                                   node_edge.kind == kScheduling ||
+                                   state.num_scheduling_edges();
+        NodeEdge::Kind kind = has_scheduling_edge ? kScheduling : kExecution;
+        num_erased_edges += EraseEdge(source_node, node, kind);
+
+        // Keep following nodes reachable via `node` out edges.
+        state.PushToStack(node.out_edges);
       }
     }
 
     // Set node priority to the number of visited nodes in the DFS traversal.
-    source_node.priority = absl::c_count(visited, true);
+    source_node.priority = state.NumVisited();
   }
 
   return num_erased_edges;
diff --git a/third_party/xla/xla/runtime/execution_graph.h b/third_party/xla/xla/runtime/execution_graph.h
index 0475a0aee7f4eb..2096fc98cacec7 100644
--- a/third_party/xla/xla/runtime/execution_graph.h
+++ b/third_party/xla/xla/runtime/execution_graph.h
@@ -67,6 +67,9 @@ class ExecutionGraph {
   static constexpr NodeId kInvalidNodeId = std::numeric_limits<NodeId>::min();
 
   struct NodeEdge {
+    // Edge kind defines the execution ordering between two operations. A
+    // scheduling edge is weaker than an execution edge, as it gives more
+    // flexibility to the backend runtime to execute operations concurrently.
     enum class Kind {
       // If two operations have a scheduling edge between them, then the
       // dependent operation must be scheduled (start execution) after the
@@ -99,6 +102,16 @@
       return kind == other.kind && id == other.id;
     }
 
+    template <typename Sink>
+    friend void AbslStringify(Sink& sink, Kind kind) {
+      sink.Append(kind == Kind::kScheduling ? "scheduling" : "execution");
+    }
+
+    template <typename Sink>
+    friend void AbslStringify(Sink& sink, const NodeEdge& edge) {
+      absl::Format(&sink, "NodeEdge {kind: %v, id: %v}", edge.kind, edge.id);
+    }
+
     Kind kind;
     NodeId id;
   };
@@ -212,9 +225,11 @@
   static std::tuple>
   CreateNodeDefs(std::vector<NodeDefBuilder> builders);
 
-  // Erases edge from `from` node to `to` node if it exists. We rely on the fact
-  // that out and in-edges are sorted and use binary search on a critical path.
-  static int64_t EraseEdge(NodeDefBuilder& from, NodeDefBuilder& to);
+  // Erases the edge from `from` node to `to` node if it exists and has a
+  // weaker ordering than the given `kind`. We rely on the fact that out and
+  // in-edges are sorted and use binary search on a critical path.
+  static int64_t EraseEdge(NodeDefBuilder& from, NodeDefBuilder& to,
+                           NodeEdge::Kind kind);
 
   // Runs a transitive reduction on the NodeDefBuilder graph to remove redundant
   // edges, and updates node priorities. Returns the number of removed edges.
diff --git a/third_party/xla/xla/runtime/execution_graph_test.cc b/third_party/xla/xla/runtime/execution_graph_test.cc
index 606a51342328bd..3f4d959c5e7367 100644
--- a/third_party/xla/xla/runtime/execution_graph_test.cc
+++ b/third_party/xla/xla/runtime/execution_graph_test.cc
@@ -55,6 +55,11 @@ class Operation : public ExecutionGraph::Operation {
   std::vector<ResourceUse> resources_;
 };
 
+TEST(ExecutionGraphTest, EdgePriority) {
+  // A scheduling edge has a weaker ordering guarantee than an execution edge.
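+  // EraseEdge relies on this enum ordering: it refuses to erase an edge whose
+  // kind compares greater than the kind it was asked to erase.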
+  EXPECT_LE(kScheduling, kExecution);
+}
+
 TEST(ExecutionGraphTest, DependencyOrdering) {
   BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
@@ -230,5 +235,45 @@ TEST(ExecutionGraphTest, TransitiveReduction) {
   EXPECT_EQ(execution_graph.priority(2), 0);
 }
 
+TEST(ExecutionGraphTest, TransitiveReductionKeepsExecutionEdge) {
+  BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
+  BufferAllocation::Slice slice(&alloc, /*offset=*/0, /*size=*/40);
+
+  auto resource = Resource::Create(Resource::Kind::kCollectiveCommunicator);
+
+  std::vector<Operation> operations;
+
+  // All three operations are connected with scheduling edges, but because an
+  // execution edge provides a stronger ordering guarantee, we must keep the
+  // 0-2 execution edge, or we might get a data race.
+  operations.push_back(
+      Operation({BufferUse::Write(slice)}, {ResourceUse::Write(resource)}));
+  operations.push_back(
+      Operation(/*buffers=*/{}, {ResourceUse::Write(resource)}));
+  operations.push_back(
+      Operation({BufferUse::Write(slice)}, {ResourceUse::Write(resource)}));
+
+  TF_ASSERT_OK_AND_ASSIGN(ExecutionGraph execution_graph,
+                          ExecutionGraph::Create(operations));
+
+  EXPECT_THAT(execution_graph.source(), ElementsAre(0));
+  EXPECT_THAT(execution_graph.sink(), ElementsAre(2));
+
+  EXPECT_THAT(execution_graph.out_edges(0),
+              ElementsAre(NodeEdge{kScheduling, 1}, NodeEdge{kExecution, 2}));
+
+  EXPECT_THAT(execution_graph.in_edges(1),
+              ElementsAre(NodeEdge{kScheduling, 0}));
+  EXPECT_THAT(execution_graph.out_edges(1),
+              ElementsAre(NodeEdge{kScheduling, 2}));
+
+  EXPECT_THAT(execution_graph.in_edges(2),
+              ElementsAre(NodeEdge{kExecution, 0}, NodeEdge{kScheduling, 1}));
+
+  EXPECT_EQ(execution_graph.priority(0), 2);
+  EXPECT_EQ(execution_graph.priority(1), 1);
+  EXPECT_EQ(execution_graph.priority(2), 0);
+}
+
 }  // namespace
 }  // namespace xla
From 6eb27cb6fffd5ab266f2c5b2c76349077383e5b7 Mon Sep 17 00:00:00 2001
From: Zixuan Jiang
Date: Wed, 16 Apr 2025 20:32:57 -0700
Subject: [PATCH 0900/1324] Polish
 `SpmdPartitioningVisitor::HandleDynamicUpdateSlice`.

Unify the strategies into two methods for handling the partitioned slice
dimensions.

1. Replicate the slice dimensions for all involved tensors.
2. If we ensure that the update is fully contained in a single partition, we
   can keep the sharding for input and output. There are two such cases:
   * The slice size is 1.
   * The index is a constant, and the start index and the end index reside in
     the same partition.

PiperOrigin-RevId: 748517627
---
 .../xla/xla/service/spmd/spmd_partitioner.cc  | 342 ++++++++----------
 1 file changed, 157 insertions(+), 185 deletions(-)

diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
index 63d6d9a48e93a2..3853807e504728 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
@@ -3753,212 +3753,184 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(
     return DefaultAction(hlo);
   }
 
+  std::vector<HloInstruction*> new_indices;
+  new_indices.reserve(hlo->shape().dimensions_size());
+  for (int64_t i = 0; i < hlo->shape().dimensions_size(); ++i) {
+    const HloInstruction* index = hlo->operand(i + 2);
+    if (hlo->operand(1)->shape().dimensions(i) == hlo->shape().dimensions(i)) {
+      new_indices.emplace_back(CreateZero(index->shape(), &b_));
+    } else {
+      // Replicate the indices.
+      new_indices.emplace_back(GetPartitionedHlo(index).Replicate().hlo());
+    }
+  }
+
+  // We always keep the sharding axes along the batch dimensions.
There are two + // methods to handle the partitioned slice dimensions. + // 1. Replicate the slice dimensions for all involved tensors. + // 2. If we ensure that the update is fully contained in a single partition, + // we can keep the sharding for input and output. There are two cases: + // (1) The slice size is 1. + // (2) The index is a constant. The start index and the end index reside in + // the same partition. std::vector partitioned_slice_dims; std::vector slice_dims; - std::vector partitioned_non_slice_dims; - std::vector partitioned_slice_offsets; - bool any_non_constant_sliced_dim = false; + bool will_replicate_slice_dims = false; for (int64_t i = 0; i < hlo->shape().dimensions_size(); ++i) { - if (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i)) { - slice_dims.push_back(i); - int64_t slice_size = hlo->operand(1)->shape().dimensions(i); - if (hlo->sharding().tile_assignment().dim(i) != 1) { - if (!hlo->operand(i + 2)->IsConstant() && slice_size != 1) { - any_non_constant_sliced_dim = true; - continue; - } - partitioned_slice_dims.push_back(i); - // Set partitioned_slice_offsets to -1 when slice_size is 1. - if (slice_size == 1) { - partitioned_slice_offsets.push_back(-1); - } else { - const PrimitiveType elemType = - hlo->operand(i + 2)->shape().element_type(); - partitioned_slice_offsets.push_back( - elemType == S64 ? hlo->operand(i + 2)->literal().Get({}) - : hlo->operand(i + 2)->literal().Get({})); + if (hlo->operand(1)->shape().dimensions(i) == hlo->shape().dimensions(i)) { + continue; + } + + slice_dims.push_back(i); + int64_t slice_size = hlo->operand(1)->shape().dimensions(i); + if (hlo->sharding().tile_assignment().dim(i) != 1) { + partitioned_slice_dims.push_back(i); + if (slice_size == 1) { + continue; + } + if (hlo->operand(i + 2)->IsConstant()) { + const PrimitiveType elemType = + hlo->operand(i + 2)->shape().element_type(); + int64_t start_index = + elemType == S64 ? hlo->operand(i + 2)->literal().Get({}) + : hlo->operand(i + 2)->literal().Get({}); + int64_t end_index = start_index + slice_size - 1; + + int64_t per_partition_size = + CeilOfRatio(hlo->shape().dimensions(i), + hlo->sharding().tile_assignment().dim(i)); + if (start_index / per_partition_size != + end_index / per_partition_size) { + // The update is not fully contained in a single partition. + will_replicate_slice_dims = true; } + } else { + will_replicate_slice_dims = true; } - } else if (hlo->sharding().tile_assignment().dim(i) != 1) { - partitioned_non_slice_dims.push_back(i); } } - auto handle_with_replicate_slice_dims = [&]() { + + // Method 1. Replicate the slice dimensions for all involved tensors. + if (will_replicate_slice_dims || partitioned_slice_dims.empty()) { + const HloSharding& input_sharding = hlo->operand(0)->sharding(); + const HloSharding& output_sharding = hlo->sharding(); + const HloSharding& better_sharding = + input_sharding.NumTiles() > output_sharding.NumTiles() + ? input_sharding + : output_sharding; + HloSharding replicated_sharding = - hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept( - hlo->operand(0)->sharding(), partitioned_non_slice_dims); + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + better_sharding, slice_dims); auto base = GetPartitionedHlo(hlo->operand(0)).Reshard(replicated_sharding); auto operand = GetPartitionedHlo(hlo->operand(1)).Reshard(replicated_sharding); - std::vector new_indices(hlo->shape().dimensions_size()); - for (int64_t i = 0; i < new_indices.size(); ++i) { - // Replicate the indices. 
- new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo(); - } auto dus = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( base.hlo()->shape(), base.hlo(), operand.hlo(), new_indices)); dus->set_sharding(replicated_sharding); SetPartitionedHlo(hlo, PartitionedHlo(dus, base.base_shape(), base.state()) .Reshard(hlo->sharding())); - }; - if (any_non_constant_sliced_dim) { - if (partitioned_non_slice_dims.empty()) { - return DefaultAction(hlo); - } - handle_with_replicate_slice_dims(); return absl::OkStatus(); } - // Handle when there is slice dim partitioned. - if (!partitioned_slice_dims.empty()) { - auto add_hlo = [&](std::unique_ptr to_add) { - return b_.AddInstruction(std::move(to_add)); - }; - std::vector new_indices(hlo->shape().dimensions_size()); - for (int64_t i = 0; i < new_indices.size(); ++i) { - if (hlo->operand(1)->shape().dimensions(i) == - hlo->shape().dimensions(i)) { - new_indices[i] = CreateZero(hlo->operand(i + 2)->shape(), &b_); - continue; - } - // Replicate the indices. - new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo(); - } - - // Get partitioned input. - const auto& dus_sharding = hlo->sharding(); - const auto& partitioned_input = - GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo(); - - // Get replicate update. - auto update_sharding = HloSharding::Replicate(); - if (!partitioned_non_slice_dims.empty()) { - // Do partial replicate for update if non slice dims are partitioned. - update_sharding = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding, - slice_dims); - } - - // TODO(wangtao): use collective permute for sharded update. - HloInstruction* replicate_update = - GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo(); - - const auto& update_shape = replicate_update->shape(); - const auto& partitioned_shape = partitioned_input->shape(); - auto partition_ordinals = MakeTiledPartitionOrdinals( - hlo->sharding(), MakePartitioningState().partition_id, &b_); - HloInstruction* all_dims_within_partition = add_hlo( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); - - for (int i = 0; i < partitioned_slice_dims.size(); ++i) { - int dim = partitioned_slice_dims[i]; - // Calculate per partition size. - const int64_t per_partition_size = partitioned_shape.dimensions(dim); - - // Only update within a single partition is supported. - // Will ignore this check when slice size is 1 where - // partitioned_slice_offsets[i] is -1. 
- if ((partitioned_slice_offsets[i] != -1) && - (partitioned_slice_offsets[i] / per_partition_size) != - ((partitioned_slice_offsets[i] + update_shape.dimensions(dim) - - 1) / - per_partition_size)) { - handle_with_replicate_slice_dims(); - return absl::OkStatus(); - } - - // within_partition = (offset >= partition_id * per_partition_size) && - // (offset < (partition_id + 1) * per_partition_size) - const Shape& compare_shape = - ShapeUtil::ChangeElementType(partition_id_->shape(), PRED); - auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(per_partition_size))); - const Shape& offset_shape = per_partition_size_hlo->shape(); - const Shape& index_shape = new_indices[dim]->shape(); - if (offset_shape.element_type() != index_shape.element_type()) - new_indices[dim] = add_hlo(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(index_shape, - offset_shape.element_type()), - new_indices[dim])); - auto partition_offset = add_hlo(HloInstruction::CreateBinary( - offset_shape, HloOpcode::kMultiply, partition_ordinals[dim], - per_partition_size_hlo)); - // offset >= partition_id * per_partition_size - auto offset_ge = add_hlo(HloInstruction::CreateCompare( - compare_shape, new_indices[dim], partition_offset, - ComparisonDirection::kGe)); - // offset < (partition_id + 1) * per_partition_size - auto offset_lt = add_hlo(HloInstruction::CreateCompare( - compare_shape, new_indices[dim], - add_hlo(HloInstruction::CreateBinary( - offset_shape, HloOpcode::kMultiply, - add_hlo(HloInstruction::CreateBinary( - offset_shape, HloOpcode::kAdd, partition_ordinals[dim], - add_hlo(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(1))))), - per_partition_size_hlo)), - ComparisonDirection::kLt)); - auto update_within_partition = add_hlo(HloInstruction::CreateBinary( - compare_shape, HloOpcode::kAnd, offset_ge, offset_lt)); - - all_dims_within_partition = add_hlo(HloInstruction::CreateBinary( - compare_shape, HloOpcode::kAnd, all_dims_within_partition, - update_within_partition)); - - // Calculate offset. - // slice dim offset = - // within_partition ? - // offset - partition_id * per_partition_size : 0 - new_indices[dim] = add_hlo(HloInstruction::CreateTernary( - new_indices[dim]->shape(), HloOpcode::kSelect, - update_within_partition, - add_hlo(HloInstruction::CreateBinary( - new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim], - partition_offset)), - add_hlo( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))))); - if (new_indices[dim]->shape().element_type() != - index_shape.element_type()) - new_indices[dim] = add_hlo(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(new_indices[dim]->shape(), - index_shape.element_type()), - new_indices[dim])); - } - - // Create dynamic update slice. - auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice( - partitioned_shape, partitioned_input, replicate_update, new_indices)); - SetPartitionedHlo(hlo, [&]() { - // Select if update is needed. - return add_hlo(HloInstruction::CreateTernary( - dus->shape(), HloOpcode::kSelect, - add_hlo(HloInstruction::CreateBroadcast( - ShapeUtil::ChangeElementType(dus->shape(), PRED), - all_dims_within_partition, {})), - dus, partitioned_input)); - }); - return absl::OkStatus(); - } + // Method 2. Keep the sharding for input and output since the update is fully + // contained in a single partition. + auto add_hlo = [&](std::unique_ptr to_add) { + return b_.AddInstruction(std::move(to_add)); + }; - // Partition non slice dims only. 
- std::vector new_indices(hlo->shape().dimensions_size()); - auto new_input = - GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); - auto new_update = - GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo(); - for (int64_t i = 0; i < new_indices.size(); ++i) { - if (hlo->operand(1)->shape().dimensions(i) == hlo->shape().dimensions(i)) { - new_indices[i] = CreateZero(hlo->operand(i + 2)->shape(), &b_); - continue; - } - // Replicate the indices. - new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo(); - } + // Get partitioned input. + const auto& dus_sharding = hlo->sharding(); + const auto& partitioned_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo(); + + auto update_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding, + slice_dims); + + // TODO(wangtao): use collective permute for sharded update. + HloInstruction* replicate_update = + GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo(); + + const auto& partitioned_shape = partitioned_input->shape(); + auto partition_ordinals = MakeTiledPartitionOrdinals( + hlo->sharding(), MakePartitioningState().partition_id, &b_); + HloInstruction* all_dims_within_partition = add_hlo( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + + for (int64_t dim : partitioned_slice_dims) { + // Calculate per partition size. + const int64_t per_partition_size = partitioned_shape.dimensions(dim); + + // within_partition = (offset >= partition_id * per_partition_size) && + // (offset < (partition_id + 1) * per_partition_size) + const Shape& compare_shape = + ShapeUtil::ChangeElementType(partition_id_->shape(), PRED); + auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(per_partition_size))); + const Shape& offset_shape = per_partition_size_hlo->shape(); + const Shape& index_shape = new_indices[dim]->shape(); + if (offset_shape.element_type() != index_shape.element_type()) { + new_indices[dim] = add_hlo(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(index_shape, + offset_shape.element_type()), + new_indices[dim])); + } + auto partition_offset = add_hlo(HloInstruction::CreateBinary( + offset_shape, HloOpcode::kMultiply, partition_ordinals[dim], + per_partition_size_hlo)); + // offset >= partition_id * per_partition_size + auto offset_ge = add_hlo(HloInstruction::CreateCompare( + compare_shape, new_indices[dim], partition_offset, + ComparisonDirection::kGe)); + // offset < (partition_id + 1) * per_partition_size + auto offset_lt = add_hlo(HloInstruction::CreateCompare( + compare_shape, new_indices[dim], + add_hlo(HloInstruction::CreateBinary( + offset_shape, HloOpcode::kMultiply, + add_hlo(HloInstruction::CreateBinary( + offset_shape, HloOpcode::kAdd, partition_ordinals[dim], + add_hlo(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(1))))), + per_partition_size_hlo)), + ComparisonDirection::kLt)); + auto update_within_partition = add_hlo(HloInstruction::CreateBinary( + compare_shape, HloOpcode::kAnd, offset_ge, offset_lt)); + + all_dims_within_partition = add_hlo(HloInstruction::CreateBinary( + compare_shape, HloOpcode::kAnd, all_dims_within_partition, + update_within_partition)); + + // Calculate offset. + // slice dim offset = within_partition ? 
+ // offset - partition_id * per_partition_size : 0 + new_indices[dim] = add_hlo(HloInstruction::CreateTernary( + new_indices[dim]->shape(), HloOpcode::kSelect, update_within_partition, + add_hlo(HloInstruction::CreateBinary( + new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim], + partition_offset)), + add_hlo( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))))); + if (new_indices[dim]->shape().element_type() != + index_shape.element_type()) { + new_indices[dim] = add_hlo(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(new_indices[dim]->shape(), + index_shape.element_type()), + new_indices[dim])); + } + } + + // Create dynamic update slice. + auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice( + partitioned_shape, partitioned_input, replicate_update, new_indices)); SetPartitionedHlo(hlo, [&]() { - auto partitioned_shape = - MakePartitionedShape(hlo->shape(), hlo->sharding()); - return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - partitioned_shape, new_input, new_update, new_indices)); + // Select if update is needed. + return add_hlo(HloInstruction::CreateTernary( + dus->shape(), HloOpcode::kSelect, + add_hlo(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType(dus->shape(), PRED), + all_dims_within_partition, {})), + dus, partitioned_input)); }); return absl::OkStatus(); } From ba688bf33bda31ac57375f939cd2cc9e89ec5873 Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Wed, 16 Apr 2025 21:18:19 -0700 Subject: [PATCH 0901/1324] Add traceme to `execute_fn` and `prepare_inputs` PiperOrigin-RevId: 748527512 --- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc index 02a639cfefc073..ae48957c63d25b 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc @@ -2508,7 +2508,10 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( const RunId& run_id, const ExecuteOptions& options, tsl::AsyncValueRef last_collective_launch_event, bool fill_future, TfrtGpuDevice* device) { - tsl::profiler::TraceMe traceme("TfrtGpuExecutable::ExecuteHelper"); + tsl::profiler::TraceMeProducer activity("TfrtGpuExecutable::ExecuteHelper", + tsl::profiler::ContextType::kPjRt, + run_id.ToInt()); + if (VLOG_IS_ON(2)) { LOG(INFO) << "ExecuteHelper " << name() << ": " << options.launch_id << "; replica: " << replica << "; partition: " << partition @@ -2704,8 +2707,12 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( // launch is delayed. 
VLOG(1) << "Going to get compute reservation for " << name() << ": " << options.launch_id << "; replica: " << replica; - auto compute_reservation = std::make_unique( - device->max_inflight_computations_semaphore().ScopedAcquire(1)); + std::unique_ptr compute_reservation; + { + tsl::profiler::TraceMe t("waiting for compute reservation"); + compute_reservation = std::make_unique( + device->max_inflight_computations_semaphore().ScopedAcquire(1)); + } VLOG(1) << "Got compute reservation for " << name() << ": " << options.launch_id << "; replica: " << replica; auto ffi_context = @@ -2735,7 +2742,8 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( client = client_](std::vector execution_inputs) mutable { VLOG(1) << "execute_fn for " << executable_name << ": " << launch_id << "; replica: " << replica; - tsl::profiler::TraceMe traceme("execute_fn"); + tsl::profiler::TraceMeConsumer activity( + "execute_fn", tsl::profiler::ContextType::kPjRt, run_id.ToInt()); auto set_error = [&](absl::Status status) { for (auto& output_buffer : output_buffers) { output_buffer.SetError(status); @@ -2835,8 +2843,8 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( }; auto prepare_inputs = - [blocking_thread_pool = client_->blocking_thread_pool(), device, - tracked_buffers(std::move(tracked_buffers)), + [blocking_thread_pool = client_->blocking_thread_pool(), run_id(run_id), + device, tracked_buffers(std::move(tracked_buffers)), buffer_is_donated(std::move(buffer_is_donated)), prepare_inputs_avs(CopyAsyncValues(prepare_input_deps)), execute_event(execute_event.CopyRef()), @@ -2847,7 +2855,9 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( arguments_are_tupled(options.arguments_are_tupled), input_buffer_sizes_in_bytes( input_buffer_sizes_in_bytes_[executable_idx])]() mutable { - tsl::profiler::TraceMe traceme("prepare_inputs"); + tsl::profiler::TraceMeConsumer activity( + "prepare_inputs", tsl::profiler::ContextType::kPjRt, + run_id.ToInt()); VLOG(2) << "prepare_inputs"; DCHECK_EQ(tracked_buffers.size(), buffer_is_donated.size()); From 9ff0f1cf5a375815ae5d1b24543ce050bbd0a8e1 Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Wed, 16 Apr 2025 21:27:33 -0700 Subject: [PATCH 0902/1324] PR #18410: [XLA:CPU][oneDNN] Absorb Transpose into matmul whenever possible Imported from GitHub PR https://github.com/openxla/xla/pull/18410 This PR tries to absorb transposes into matmuls and eliminate associated time consuming copy operations, whenever feasible. These optimizations are expected to benefit attention-based models, where transposed projections are used to compute attention scores. Additionally, this PR includes tests to ensure correct functionality. 
Copybara import of the project: -- 7b20cf42219d5c19c72f630373fcb42fb88b8284 by Akhil Goel : Absorb Transpose into matmul whenever possible -- 51d2927335f90436c033c34b458bb81e0693c4c9 by Akhil Goel : Remove unused variables Merging this change closes #18410 PiperOrigin-RevId: 748529894 --- .../xla/xla/service/cpu/onednn_config.proto | 3 + .../cpu/onednn_contraction_rewriter.cc | 66 ++++++++++++++++++- .../xla/xla/service/cpu/onednn_matmul.cc | 47 +++++++++---- third_party/xla/xla/service/cpu/tests/BUILD | 1 + .../service/cpu/tests/onednn_matmul_test.cc | 62 +++++++++++++++++ 5 files changed, 167 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/cpu/onednn_config.proto b/third_party/xla/xla/service/cpu/onednn_config.proto index ab534a8f00689c..bf11634e2c3bb7 100644 --- a/third_party/xla/xla/service/cpu/onednn_config.proto +++ b/third_party/xla/xla/service/cpu/onednn_config.proto @@ -83,6 +83,9 @@ message OneDnnMatMulConfig { reserved 7; // was user_scratchpad OneDnnOptimizationConfig optimization_config = 8; + OneDnnTensorLayoutProto lhs = 9; + OneDnnTensorLayoutProto rhs = 10; + OneDnnTensorLayoutProto result = 11; } message OneDnnWindowProto { diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc index 0cefc32d0b663b..76d346ee4edc9a 100644 --- a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc @@ -981,6 +981,31 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } + absl::Status HandleCopy(HloInstruction* instr) override { + HloInstruction *copy, *transpose, *custom_call; + if (Match(instr, + m::Copy(©, m::Transpose(&transpose, + OneDnnMatmulInstr(&custom_call))))) { + auto backend_config = custom_call->backend_config(); + auto dimensions = backend_config->mutable_onednn_matmul_config() + ->mutable_result() + ->mutable_tensor() + ->mutable_dimensions(); + dimensions->Resize(transpose->dimensions().size(), 0); + // Configure inverse transpose dimensions + int counter = 1; + for (auto x : transpose->dimensions()) { + dimensions->Set(x, counter++); + } + auto matmul_call = Cast( + custom_call->AddInstruction(custom_call->CloneWithNewOperands( + copy->shape(), custom_call->mutable_operands()))); + TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(copy, matmul_call)); + } + return absl::OkStatus(); + } + absl::Status FuseActivation(OneDnnFusionConfig_FusionKind kind, HloInstruction* activation, HloInstruction* contraction, @@ -1150,11 +1175,38 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { #endif } + void UpdateTransposeDimensions( + HloInstruction* matmul, absl::InlinedVector& new_ops, + int operand_idx, absl::StatusOr* backend_config) { + HloInstruction *transpose, *operand; + // Update the dimensions only when the transpose does not involve the batch + // dimension, as modifying it could significantly impact the performance. 
+ if (Match(matmul->mutable_operand(operand_idx), + m::Copy(m::Transpose(&transpose, m::Op(&operand)))) && + transpose->dimensions()[0] == 0) { + new_ops[operand_idx] = operand; + for (auto x : transpose->dimensions()) { + (*GetOperandTensor(operand_idx, backend_config))->Add(x + 1); + } + } + } + absl::Status HandleCustomCall(HloInstruction* custom_call) override { HloInstruction* contraction; if (Match(custom_call, OneDnnMatmulInstr(&contraction))) { + auto backend_config = contraction->backend_config(); + auto new_ops = contraction->mutable_operands(); + + UpdateTransposeDimensions(contraction, new_ops, 0, &backend_config); + UpdateTransposeDimensions(contraction, new_ops, 1, &backend_config); + + auto matmul_call = Cast( + contraction->AddInstruction(contraction->CloneWithNewOperands( + contraction->shape(), new_ops))); + TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(contraction, matmul_call)); return HandleCustomCallInternal( - custom_call); + matmul_call); } else if (Match(custom_call, OneDnnConvolutionInstr(&contraction))) { return HandleCustomCallInternal< dnnl::convolution_forward::primitive_desc>(custom_call); @@ -1264,6 +1316,18 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { } } + absl::StatusOr*> GetOperandTensor( + int operand_idx, absl::StatusOr* backend_config) { + if (operand_idx > 1) { + return absl::CancelledError("Operand index must be either 0 or 1"); + } + auto operand = + (operand_idx == 0) + ? (*backend_config)->mutable_onednn_matmul_config()->mutable_lhs() + : (*backend_config)->mutable_onednn_matmul_config()->mutable_rhs(); + return operand->mutable_tensor()->mutable_dimensions(); + } + void ReorderWeight(const dnnl::memory::desc& src_md, void* src_buf, const dnnl::memory::desc& dst_md, void* dst_buf) { auto onednn_threadpool = CreateOneDnnThreadPool(threadpool_device_.get()); diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.cc b/third_party/xla/xla/service/cpu/onednn_matmul.cc index 91d5979473bc62..e617387831de2f 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul.cc @@ -50,6 +50,20 @@ using dnnl::matmul; using dnnl::memory; using dnnl::stream; +void TransposeIfNecessary( + const tsl::protobuf::RepeatedField dimensions, + bool transpose_last_2_dims, dnnl::memory::desc& mem_desc) { + if (mem_desc.get_ndims() < 2) return; + std::vector permutation(mem_desc.get_ndims()); + std::iota(permutation.begin(), permutation.end(), 0); + int counter = 0; + for (auto it = dimensions.begin(); it != dimensions.end(); it++) { + permutation[*it - 1] = counter++; + } + mem_desc = mem_desc.permute_axes(permutation); + TRANSPOSE_LAST_TWO_DIMS_IF(transpose_last_2_dims, mem_desc); +} + dnnl::memory::desc OneDnnMatMulOptWeightsDesc( const dnnl::engine& engine, const dnnl::memory::desc& input_md, const dnnl::memory::desc& weights_md, const dnnl::memory::desc& bias_md, @@ -70,13 +84,17 @@ dnnl::memory::desc OneDnnMatMulOptWeightsDesc( const Shape& output_shape, const OneDnnMatMulConfig* matmul_config) { auto input_md = ShapeToMemDesc(input_shape); auto weights_md = ShapeToMemDesc(weights_shape); - TRANSPOSE_LAST_TWO_DIMS_IF(matmul_config->transpose_a(), input_md); - TRANSPOSE_LAST_TWO_DIMS_IF(matmul_config->transpose_b(), weights_md); + TransposeIfNecessary(matmul_config->lhs().tensor().dimensions(), + matmul_config->transpose_a(), input_md); + TransposeIfNecessary(matmul_config->rhs().tensor().dimensions(), + 
matmul_config->transpose_b(), weights_md); auto bias_md = absl::c_count(matmul_config->fusions().ops(), OneDnnFusionConfig::BIAS) > 0 ? ShapeToMemDesc(bias_shape) : dnnl::memory::desc{}; auto output_md = ShapeToMemDesc(output_shape); + TransposeIfNecessary(matmul_config->result().tensor().dimensions(), false, + output_md); // extend bias rank to match result rank auto missed_rank = output_md.get_ndims() - bias_md.get_ndims(); @@ -140,9 +158,13 @@ std::unique_ptr CreateMatMulPrimDesc( const OneDnnMatMulConfig& matmul_config) { auto input_md = ShapeToMemDesc(input_shape); auto weights_md = ShapeToMemDesc(weights_shape); - TRANSPOSE_LAST_TWO_DIMS_IF(matmul_config.transpose_a(), input_md); - TRANSPOSE_LAST_TWO_DIMS_IF(matmul_config.transpose_b(), weights_md); + TransposeIfNecessary(matmul_config.lhs().tensor().dimensions(), + matmul_config.transpose_a(), input_md); + TransposeIfNecessary(matmul_config.rhs().tensor().dimensions(), + matmul_config.transpose_b(), weights_md); auto output_md = ShapeToMemDesc(output_shape); + TransposeIfNecessary(matmul_config.result().tensor().dimensions(), false, + output_md); std::vector fused_mds; std::transform(fused_shapes.begin(), fused_shapes.end(), std::back_inserter(fused_mds), @@ -218,10 +240,10 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( auto weights_md = weights_minfo.GetOneDnnMemDesc(); // Input and weights memory::desc need to be in correct layout before matmul // primitive descriptor is created. - TRANSPOSE_LAST_TWO_DIMS_IF( - matmul_config.transpose_a() && input_md.get_ndims() > 1, input_md); - TRANSPOSE_LAST_TWO_DIMS_IF( - matmul_config.transpose_b() && weights_md.get_ndims() > 1, weights_md); + TransposeIfNecessary(matmul_config.lhs().tensor().dimensions(), + matmul_config.transpose_a(), input_md); + TransposeIfNecessary(matmul_config.rhs().tensor().dimensions(), + matmul_config.transpose_b(), weights_md); auto output_md = output_minfo.GetOneDnnMemDesc(); Literal* reordered_weights_literal = nullptr; @@ -278,7 +300,8 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( onednn_stream.wait(); weights_md = reordered_weights_md; } - + TransposeIfNecessary(matmul_config.result().tensor().dimensions(), false, + output_md); const int64_t num_fused_operands = num_args - arg_indx; std::vector fused_mds; std::vector fused_bufs; @@ -374,8 +397,10 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMulReorder( XLA_LIGHTWEIGHT_CHECK(num_args >= arg_indx); // Update dims and strides for transposed inputs. 
- TRANSPOSE_LAST_TWO_DIMS_IF(matmul_config.transpose_a(), input_md); - TRANSPOSE_LAST_TWO_DIMS_IF(matmul_config.transpose_b(), weight_md); + TransposeIfNecessary(matmul_config.lhs().tensor().dimensions(), + matmul_config.transpose_a(), input_md); + TransposeIfNecessary(matmul_config.rhs().tensor().dimensions(), + matmul_config.transpose_b(), weight_md); // extend bias rank to match result rank if (!bias_md.is_zero()) { diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD index c7bfb02453a949..70567c1365e98a 100644 --- a/third_party/xla/xla/service/cpu/tests/BUILD +++ b/third_party/xla/xla/service/cpu/tests/BUILD @@ -377,6 +377,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:platform_port", ], ) diff --git a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc index c2053ee8851eba..e7a39d97610f95 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc @@ -15,8 +15,10 @@ limitations under the License. #if defined(INTEL_MKL) +#include #include +#include "absl/strings/str_replace.h" #include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/testlib/test.h" #include "xla/hlo/testlib/test_helpers.h" @@ -148,6 +150,10 @@ class MatmulTest : public HloTestBase { ; CHECK-DAG: } ; CHECK: } )"; + const char* matmul_transpose_rewrite_str_ = R"( + ; CHECK-NOT: transpose(%{{[a-z,A-Z,0-9,_,\.]*}}), + ; CHECK: custom_call_target="__onednn$matmul", + )"; }; TEST_F(MatmulTest, SimpleTestF32) { @@ -1532,6 +1538,62 @@ TEST_F(MatmulTest, SimpleTestBF16WithMulAndAddFusion) { )"); } +std::string CreateTransposeFusionModuleText(std::string dtype) { + const char* matmul_module_str = R"( + ENTRY matmul.test { + arg0.1 = DTYPE[32,40,30,64] parameter(0), parameter_replication={false} + transpose.1 = DTYPE[32,30,40,64]{3,1,2,0} transpose(arg0.1), dimensions={0,2,1,3} + copy.1 = DTYPE[32,30,40,64]{3,2,1,0} copy(transpose.1) + arg0.2 = DTYPE[32,40,30,64]{3,2,1,0} parameter(1), parameter_replication={false} + transpose.2 = DTYPE[32,30,40,64]{3,1,2,0} transpose(arg0.2), dimensions={0,2,1,3} + copy.2 = DTYPE[32,30,40,64]{3,2,1,0} copy(transpose.2) + ROOT dot.201 = DTYPE[32,30,40,40]{3,2,1,0} dot(copy.1, copy.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + })"; + return absl::StrReplaceAll(matmul_module_str, {{"DTYPE", dtype}}); +} + +TEST_F(MatmulTest, SimpleTestTransposeFusionF32) { + const std::string matmul_module_str = CreateTransposeFusionModuleText("f32"); + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_transpose_rewrite_str_); +} + +TEST_F(MatmulTest, SimpleTestTransposeFusionBF16) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + const std::string matmul_module_str = CreateTransposeFusionModuleText("bf16"); + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(matmul_module_str, matmul_transpose_rewrite_str_); +} + +TEST_F(MatmulTest, SimpleTestTransposeFusionF16) { + if (!IsSupportedType(PrimitiveType::F16)) { + GTEST_SKIP() << "CPU does not support F16."; + } + const std::string matmul_module_str = CreateTransposeFusionModuleText("f16"); + 
EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(matmul_module_str, matmul_transpose_rewrite_str_); +} + +TEST_F(MatmulTest, SimpleTestNoTransposeFusion) { + const char* matmul_module_str = R"( + ENTRY matmul.test { + arg0.1 = f32[32,40,40,32] parameter(0), parameter_replication={false} + arg0.2 = f32[32,40,40,32] parameter(1), parameter_replication={false} + transpose.2 = f32[32,40,40,32] transpose(arg0.2), dimensions={3,2,1,0} + copy.2 = f32[32,40,40,32] copy(transpose.2) + ROOT dot.201 = f32[32,40,40,40] dot(arg0.1, copy.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: transpose(%{{[a-z,A-Z,0-9,_,\.]*}}), + ; CHECK: custom_call_target="__onednn$matmul", + )"); +} + TEST_F(MatmulTest, WeightsPrepackAndScratch) { const char* matmul_module_str = R"( HloModule matmul.test.f32 From c1868be49f864ba135f8465191fdf19509267b76 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 22:11:45 -0700 Subject: [PATCH 0903/1324] Remove unsafe `Shape::mutable_dynamic_dimensions()`. Some combinations of (dimension_size, is_dynamic) are invalid (e.g. a static dimension cannot have a size = `kUnboundedSize`). To ensure that `Shape`'s invariants aren't broken, we shouldn't expose `dimensions` and `dynamic_dimensions` as mutable values directly. Instead, all changes to these properties should go through safe `Shape` APIs that validate the arguments (e.g. `set_dynamic_dimension()`). Luckily, `mutable_dynamic_dimensions()` is unused, so we can just remove it. PiperOrigin-RevId: 748540904 --- third_party/xla/xla/shape.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 6774417451db3b..311590febe815b 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -214,9 +214,6 @@ class Shape { absl::Span dynamic_dimensions() const { return array_state().dynamic_dimensions; } - absl::Span mutable_dynamic_dimensions() { - return absl::MakeSpan(array_state().dynamic_dimensions); - } // Removes the given dimension from the shape. Layout, if it exists, is // adjusted to match the modified shape. From 578324f33ca4a1f189b0f6f44d6be73f13c13553 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 16 Apr 2025 22:17:25 -0700 Subject: [PATCH 0904/1324] Automated Code Change PiperOrigin-RevId: 748542713 --- tensorflow/core/tpu/kernels/infeed_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/tpu/kernels/infeed_ops.cc b/tensorflow/core/tpu/kernels/infeed_ops.cc index 17953799fff16c..d59c6c4b6d4683 100644 --- a/tensorflow/core/tpu/kernels/infeed_ops.cc +++ b/tensorflow/core/tpu/kernels/infeed_ops.cc @@ -89,7 +89,7 @@ absl::StatusOr TransposeTensor(OpKernelContext* ctx, const Tensor& input_tensor, const xla::Shape& xla_shape) { tsl::profiler::TraceMe trace_me("TransposeTensor", /*level=*/2); - const int64_t rank = xla_shape.dimensions_size(); + const int64_t rank = xla_shape.dimensions().size(); std::vector permutation(rank); std::vector transposed_shapes(rank); for (int64_t i = 0; i < rank; ++i) { From a68ff8b0d15119ce601bab87599e47973e4cc5f0 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Wed, 16 Apr 2025 22:21:21 -0700 Subject: [PATCH 0905/1324] HLO->StableHLO Direct conversion cleanup PiperOrigin-RevId: 748543737 --- third_party/xla/xla/hlo/translate/BUILD | 1 - .../xla/xla/hlo/translate/stablehlo.cc | 55 +------------------ 2 files changed, 2 insertions(+), 54 deletions(-) diff --git a/third_party/xla/xla/hlo/translate/BUILD b/third_party/xla/xla/hlo/translate/BUILD index ef35b6f67a6276..4a5bcf1e127909 100644 --- a/third_party/xla/xla/hlo/translate/BUILD +++ b/third_party/xla/xla/hlo/translate/BUILD @@ -95,7 +95,6 @@ cc_library( "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/hlo/translate/mhlo_to_hlo:module_attributes_exporter", "//xla/mlir/utils:error_util", - "//xla/mlir_hlo", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/mlir_hlo:mhlo_passes", "//xla/mlir_hlo:stablehlo_extension_passes", diff --git a/third_party/xla/xla/hlo/translate/stablehlo.cc b/third_party/xla/xla/hlo/translate/stablehlo.cc index cc4d0a761c9e46..9288a560e5780d 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo.cc @@ -16,15 +16,11 @@ limitations under the License. #include "xla/hlo/translate/stablehlo.h" #include -#include #include "mhlo/transforms/passes.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -32,11 +28,9 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" -#include "mlir/IR/Types.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "stablehlo/dialect/Register.h" @@ -46,7 +40,6 @@ limitations under the License. 
#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h" #include "xla/mlir/utils/error_util.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/mlir_hlo/stablehlo_ext/transforms/passes.h" @@ -60,48 +53,6 @@ namespace xla { namespace { -bool isBoundedDynamic(mlir::Type type) { - LLVM_DEBUG(llvm::dbgs() << "isBoundedDynamic: " << type << "\n"); - if (!llvm::isa(type)) { - return false; - } - auto encoding = llvm::cast(type).getEncoding(); - return encoding && llvm::isa(encoding); -} - -bool hasBoundedDynamism(mlir::ModuleOp module) { - bool has_bounded_dynamism = false; - module->walk([&](mlir::Operation* op) { - auto results = op->getResultTypes(); - has_bounded_dynamism |= llvm::any_of(results, isBoundedDynamic); - if (has_bounded_dynamism) { - return mlir::WalkResult::interrupt(); - } - return mlir::WalkResult::advance(); - }); - return has_bounded_dynamism; -} - -absl::Status MhloToStablehlo(mlir::ModuleOp module) { - LLVM_DEBUG(llvm::dbgs() << "MHLO to StableHLO\n"); - auto context = module.getContext(); - mlir::PassManager pm(context); - mlir::BaseScopedDiagnosticHandler diag_handler(context); - mlir::mhlo::HloLegalizeToStablehloPassOptions options; - options.allow_xla_features_ = true; - bool has_bounded_dynamism = hasBoundedDynamism(module); - if (has_bounded_dynamism) { - // Need to converge program to MHLO before StableHLO in the presence of - // bounded dynamism. - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - } - pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass(options)); - if (failed(pm.run(module))) { - return diag_handler.ConsumeStatus(); - } - return absl::OkStatus(); -} - // TODO(b/385393967) Separate createCanonicalizerPass from StableHLO -> HLO // Translation absl::Status StablehloToMhlo(mlir::ModuleOp module, bool run_canonicalizer) { @@ -190,9 +141,8 @@ absl::StatusOr> ConvertHloToStablehlo( TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), /*import_all_computation=*/true, /*flatten_computation_args_result=*/true, - /*emit_stablehlo=*/false) + /*emit_stablehlo=*/true) .Import(*hlo_module)); - TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); return mlir_module; } @@ -203,9 +153,8 @@ absl::StatusOr> ConvertHloToStablehlo( TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), /*import_all_computation=*/true, /*flatten_computation_args_result=*/true, - /*emit_stablehlo=*/false) + /*emit_stablehlo=*/true) .Import(*hlo_module_proto)); - TF_RETURN_IF_ERROR(MhloToStablehlo(mlir_module.get())); return mlir_module; } From 0924d71649aef2432708328c3153b4d66169eb69 Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Wed, 16 Apr 2025 23:11:04 -0700 Subject: [PATCH 0906/1324] PR #23315: [ROCM][NFC] GpuBlaslt matmul thunk cache refactoring part II Imported from GitHub PR https://github.com/openxla/xla/pull/23315 After this PR is merged: https://github.com/openxla/xla/pull/21886, GpuBlasLt MatmulPlan cache can now be refactored. It was originally introduced here: https://github.com/openxla/xla/pull/6595 Here the idea is that multiple matmul thunks can share the **same** matmul plans (allocated on the same device). This can significantly reduce the memory overhead for large training. 
Furthermore, the original cache used **stream pointer** as a cache key: this
might be inefficient when benchmarking the same HLO using XLA tools like
multi_host_hlo_runner (which could allocate **a new stream** for each
iteration).

I have also added the corresponding **gpublas_lt_matmul_thunk_test** to check
this functionality. Besides, I also refactored CublasLtCmd, which was
essentially a copy-paste of CublasLtMatmulThunk. Finally, I have removed the
magic constant 128 and replaced it with GemmConfig::kNumAlgorithms.

@xla-rotation could you please have a look?

I have also gathered some stats for LLAMA Maxtext model training with 8 GPUs:
```
2025-03-05 13:33:04.446763: I external/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc:131] Total matmul thunks created: 1039
................
2025-03-05 13:34:48.478186: I external/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc:188] 0x558aeca66fd8: Adding new MatmulPlan for stream: 0x558a9ee83a40
2025-03-05 13:34:48.478220: I external/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc:67] Plan created: cache size: 29
```
So, the XLA runtime created 1039 GpuBlasLt Thunk instances while the cache
contained only 29 unique entries (because most of the Gemm configs are
**duplicates**). Hence, with 8 GPUs, we create at least
(1039 - 29) * 8 = **8080** fewer matmul plans than before.

Copybara import of the project:

--
b7da2d5b06dee0ccfd521839449903b25f2fe8c2 by Pavel Emeliyanenko:

gpublaslt matmul thunk cache refactoring
adapted cublaslt
added test
enabled tests on rocm
update
restored command_buffer thunk tests (to be updated in a new PR)
moved matmul plan cache to a separate file
updated matmul_thunk test
updated thunk test
added a separate matmul_plan cache unit test
made SetAlgorithm function non-const
implemented matmul plan cache directly as part of BlasLt instance
added some cosmetics
fixing the unit tests
fixes

--
ce12834e7455afafbe0143b5000b23c513449099 by Pavel Emeliyanenko:

fixing the build

--
598441f8687f79c0f833436f188de57fdf424a14 by Pavel Emeliyanenko:

some last test fixes

--
b7c9db4bd59a7bf35717eb4bdc900ffcc0c0910a by Pavel Emeliyanenko:

removed unused include

Merging this change closes #23315

PiperOrigin-RevId: 748556645
---
 .../xla/xla/backends/gpu/runtime/BUILD        |  43 +-
 .../gpu/runtime/command_buffer_cmd.cc         | 178 ++-------
 .../backends/gpu/runtime/command_buffer_cmd.h |  57 +--
 .../gpu/runtime/command_buffer_cmd_emitter.cc |   8 +-
 .../gpu/runtime/command_buffer_thunk_test.cc  |  12 +-
 .../gpu/runtime/gpublas_lt_matmul_thunk.cc    | 187 +++++----
 .../gpu/runtime/gpublas_lt_matmul_thunk.h     |  98 ++---
 .../runtime/gpublas_lt_matmul_thunk_test.cc   | 371 ++++++++++++++++++
 .../gpu/autotuning/gemm_algorithm_picker.cc   |  10 +-
 .../xla/service/gpu/ir_emitter_unnested.cc    |  10 +-
 .../xla/xla/service/gpu/matmul_utils.h        |   2 +
 .../xla/stream_executor/cuda/cuda_blas_lt.cc  |  15 +-
 .../xla/stream_executor/cuda/cuda_blas_lt.h   |  10 +-
 third_party/xla/xla/stream_executor/gpu/BUILD |   2 +
 .../xla/stream_executor/gpu/gpu_blas_lt.cc    |  25 +-
 .../xla/xla/stream_executor/gpu/gpu_blas_lt.h |  56 ++-
 .../xla/stream_executor/rocm/hip_blas_lt.cc   |  15 +-
 .../xla/stream_executor/rocm/hip_blas_lt.h    |  10 +-
 18 files changed, 724 insertions(+), 385 deletions(-)
 create mode 100644 third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc

diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD
index b9a102462006b4..6aebea75c6f506 100644
--- a/third_party/xla/xla/backends/gpu/runtime/BUILD
+++ 
b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -58,6 +58,7 @@ cc_library( ":collective_thunk", ":custom_call_thunk", ":dynamic_slice_thunk", + ":gpublas_lt_matmul_thunk", ":thunk", "//xla:debug_options_flags", "//xla:executable_run_options", @@ -630,11 +631,13 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/autotuning:autotuner_util", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_blas_lt", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:logging", @@ -642,6 +645,44 @@ cc_library( ], ) +xla_test( + name = "gpublas_lt_matmul_thunk_test", + srcs = ["gpublas_lt_matmul_thunk_test.cc"], + backends = ["gpu"], + deps = [ + ":gpublas_lt_matmul_thunk", + ":thunk", + "//xla:error_spec", + "//xla:executable_run_options", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:executable", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/transforms:gemm_rewriter", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", + "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "infeed_thunk", srcs = ["infeed_thunk.cc"], diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 6b44b4690ec836..7b78f8aab99c60 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -1169,82 +1169,15 @@ CommandBufferCmd::BufferUseVector GemmCmd::buffers() { // CublasLtCmd //===----------------------------------------------------------------------===// -CublasLtCmd::CublasLtCmd( - ExecutionStreamId execution_stream_id, GemmConfig gemm_config, - se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, - BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, - BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, - BufferAllocation::Slice bias_buffer /* may be null */, - BufferAllocation::Slice aux_buffer /* may be null */, - BufferAllocation::Slice a_scale_buffer /* may be null */, - BufferAllocation::Slice b_scale_buffer /* may be null */, - BufferAllocation::Slice c_scale_buffer /* may be null */, - BufferAllocation::Slice d_scale_buffer /* may be null */, - BufferAllocation::Slice d_amax_buffer /* may be null */, - BufferAllocation::Slice workspace_buffer) +CublasLtCmd::CublasLtCmd(ExecutionStreamId execution_stream_id, + const CublasLtMatmulThunk& matmul_thunk) : 
TracedCommandBufferCmd(CommandBufferCmdType::kCublasLtCmd, execution_stream_id), - gemm_config_(std::move(gemm_config)), - epilogue_(epilogue), - algorithm_idx_(algorithm_idx), - a_buffer_(a_buffer), - b_buffer_(b_buffer), - c_buffer_(c_buffer), - d_buffer_(d_buffer), - bias_buffer_(bias_buffer), - aux_buffer_(aux_buffer), - a_scale_buffer_(a_scale_buffer), - b_scale_buffer_(b_scale_buffer), - c_scale_buffer_(c_scale_buffer), - d_scale_buffer_(d_scale_buffer), - d_amax_buffer_(d_amax_buffer), - workspace_buffer_(workspace_buffer) {} - -absl::StatusOr CublasLtCmd::GetMatmulPlan( - const se::Stream* stream) { - { - absl::MutexLock lock(&matmul_plans_cache_mutex_); - auto it = matmul_plans_cache_.find(stream); - if (it != matmul_plans_cache_.end()) return it->second.get(); - } - TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan( - stream, gemm_config_, epilogue_)); - - absl::MutexLock lock(&matmul_plans_cache_mutex_); - auto [it_insert, _] = matmul_plans_cache_.emplace(stream, std::move(plan)); - return it_insert->second.get(); -} - -absl::StatusOr -CublasLtCmd::GetMatmulAlgorithm(const se::Stream* stream, - const se::gpu::BlasLt::MatmulPlan* plan, - int64_t max_workspace) { - { - absl::MutexLock lock(&matmul_algorithm_cache_mutex_); - auto it = matmul_algorithm_cache_.find(plan); - if (it != matmul_algorithm_cache_.end()) return it->second; - } - TF_ASSIGN_OR_RETURN( - auto algorithms, - plan->GetAlgorithms(stream, /*max_algorithm_count*/ 128, - /*max_workspace_size*/ max_workspace)); - TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); - absl::MutexLock lock(&matmul_algorithm_cache_mutex_); - auto [it_insert, _] = - matmul_algorithm_cache_.emplace(plan, algorithms[algorithm_idx_]); - return it_insert->second; -} + CublasLtMatmulThunk(matmul_thunk) {} absl::Status CublasLtCmd::Initialize(const Thunk::InitializeParams& params, StateManager& state) { - if (!params.stream->parent()->AsBlas()) { - return absl::InternalError("Failed to initialize BLAS support for GemmCmd"); - } - // Populate plan and algorithm cache; - TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream)); - TF_RETURN_IF_ERROR( - GetMatmulAlgorithm(params.stream, plan, workspace_buffer_.size()) - .status()); + TF_RETURN_IF_ERROR(CublasLtMatmulThunk::Initialize(params)); return absl::OkStatus(); } @@ -1252,92 +1185,61 @@ absl::StatusOr CublasLtCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { - TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(execute_params.stream)); - TF_ASSIGN_OR_RETURN(auto algorithm, - GetMatmulAlgorithm(execute_params.stream, plan, - workspace_buffer_.size())); - - const BufferAllocations& allocs = *execute_params.buffer_allocations; - - se::DeviceMemoryBase bias, a_scale, b_scale, c_scale, d_scale, aux, d_amax; - if (bias_buffer_.allocation() != nullptr) { - bias = allocs.GetDeviceAddress(bias_buffer_); - } - if (a_scale_buffer_.allocation() != nullptr) { - a_scale = allocs.GetDeviceAddress(a_scale_buffer_); - } - if (b_scale_buffer_.allocation() != nullptr) { - b_scale = allocs.GetDeviceAddress(b_scale_buffer_); - } - if (c_scale_buffer_.allocation() != nullptr) { - c_scale = allocs.GetDeviceAddress(c_scale_buffer_); - } - if (d_scale_buffer_.allocation() != nullptr) { - d_scale = allocs.GetDeviceAddress(d_scale_buffer_); - } - if (d_amax_buffer_.allocation() != nullptr) { - d_amax = allocs.GetDeviceAddress(d_amax_buffer_); - } - if (aux_buffer_.allocation() != 
nullptr) { - aux = allocs.GetDeviceAddress(aux_buffer_); - } + // This call is required to make sure matmul plan is already created and + // cached before recording the command buffer. + TF_RETURN_IF_ERROR(GetCachedMatmulPlan(execute_params).status()); VLOG(5) << "CublasLtCmd:"; - VLOG(5) << " a_buffer: " << a_buffer_.ToString(); - VLOG(5) << " b_buffer: " << b_buffer_.ToString(); - VLOG(5) << " c_buffer: " << c_buffer_.ToString(); - VLOG(5) << " d_buffer: " << d_buffer_.ToString(); - VLOG(5) << " bias_buffer: " << bias_buffer_.ToString(); - VLOG(5) << " aux_buffer: " << aux_buffer_.ToString(); - VLOG(5) << " a_scale_buffer: " << a_scale_buffer_.ToString(); - VLOG(5) << " b_scale_buffer: " << b_scale_buffer_.ToString(); - VLOG(5) << " c_scale_buffer: " << c_scale_buffer_.ToString(); - VLOG(5) << " d_scale_buffer: " << d_scale_buffer_.ToString(); - VLOG(5) << " d_amax_buffer: " << d_amax_buffer_.ToString(); - VLOG(5) << " workspace_buffer: " << workspace_buffer_.ToString(); + VLOG(5) << " a_buffer: " << a_.ToString(); + VLOG(5) << " b_buffer: " << b_.ToString(); + VLOG(5) << " c_buffer: " << c_.ToString(); + VLOG(5) << " d_buffer: " << d_.ToString(); + VLOG(5) << " bias_buffer: " << bias_.ToString(); + VLOG(5) << " aux_buffer: " << aux_.ToString(); + VLOG(5) << " a_scale_buffer: " << a_scale_.ToString(); + VLOG(5) << " b_scale_buffer: " << b_scale_.ToString(); + VLOG(5) << " c_scale_buffer: " << c_scale_.ToString(); + VLOG(5) << " d_scale_buffer: " << d_scale_.ToString(); + VLOG(5) << " d_amax_buffer: " << d_amax_.ToString(); + // workspace buffer is guaranteed to be non-null here. + VLOG(5) << " workspace_buffer: " << workspace_->ToString(); return RecordTracedCommand( execute_params, record_params, std::move(record_action), command_buffer, [&](se::Stream* stream) { - return plan->ExecuteOnStream( - stream, allocs.GetDeviceAddress(a_buffer_), - allocs.GetDeviceAddress(b_buffer_), - allocs.GetDeviceAddress(c_buffer_), - allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, - c_scale, d_scale, d_amax, algorithm, - allocs.GetDeviceAddress(workspace_buffer_)); + return ExecuteOnStreamInternal(stream, execute_params); }); } CommandBufferCmd::BufferUseVector CublasLtCmd::buffers() { BufferUseVector buffer_usage; buffer_usage.reserve(13); - buffer_usage.push_back({a_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({b_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({c_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({workspace_buffer_, MemoryAccess::kWrite}); + buffer_usage.push_back({a_, MemoryAccess::kRead}); + buffer_usage.push_back({b_, MemoryAccess::kRead}); + buffer_usage.push_back({c_, MemoryAccess::kRead}); + buffer_usage.push_back({d_, MemoryAccess::kWrite}); + buffer_usage.push_back({*workspace_, MemoryAccess::kWrite}); - if (bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead}); + if (bias_.allocation() != nullptr) { + buffer_usage.push_back({bias_, MemoryAccess::kRead}); } - if (a_scale_buffer_.allocation() != nullptr) { - buffer_usage.push_back({a_scale_buffer_, MemoryAccess::kRead}); + if (a_scale_.allocation() != nullptr) { + buffer_usage.push_back({a_scale_, MemoryAccess::kRead}); } - if (b_scale_buffer_.allocation() != nullptr) { - buffer_usage.push_back({b_scale_buffer_, MemoryAccess::kRead}); + if (b_scale_.allocation() != nullptr) { + buffer_usage.push_back({b_scale_, MemoryAccess::kRead}); } - if (c_scale_buffer_.allocation() != 
nullptr) { - buffer_usage.push_back({c_scale_buffer_, MemoryAccess::kRead}); + if (c_scale_.allocation() != nullptr) { + buffer_usage.push_back({c_scale_, MemoryAccess::kRead}); } - if (d_scale_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_scale_buffer_, MemoryAccess::kRead}); + if (d_scale_.allocation() != nullptr) { + buffer_usage.push_back({d_scale_, MemoryAccess::kRead}); } - if (aux_buffer_.allocation() != nullptr) { - buffer_usage.push_back({aux_buffer_, MemoryAccess::kWrite}); + if (aux_.allocation() != nullptr) { + buffer_usage.push_back({aux_, MemoryAccess::kWrite}); } - if (d_amax_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_amax_buffer_, MemoryAccess::kRead}); + if (d_amax_.allocation() != nullptr) { + buffer_usage.push_back({d_amax_, MemoryAccess::kRead}); } return buffer_usage; } diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index a7c8b56b3e576b..d8abe1cf6d33e9 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -41,6 +41,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/collective_thunk.h" #include "xla/backends/gpu/runtime/custom_call_thunk.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" +#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/ffi/api/c_api.h" #include "xla/hlo/ir/hlo_computation.h" @@ -720,26 +721,19 @@ class GemmCmd : public TracedCommandBufferCmd { // CublasLtCmd //===----------------------------------------------------------------------===// -class CublasLtCmd : public TracedCommandBufferCmd { +class CublasLtCmd : public TracedCommandBufferCmd, public CublasLtMatmulThunk { public: - CublasLtCmd(ExecutionStreamId execution_stream_id, GemmConfig gemm_config, - se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, - BufferAllocation::Slice a_buffer, - BufferAllocation::Slice b_buffer, - BufferAllocation::Slice c_buffer, - BufferAllocation::Slice d_buffer, - BufferAllocation::Slice bias_buffer /* may be null */, - BufferAllocation::Slice aux_buffer /* may be null */, - BufferAllocation::Slice a_scale_buffer /* may be null */, - BufferAllocation::Slice b_scale_buffer /* may be null */, - BufferAllocation::Slice c_scale_buffer /* may be null */, - BufferAllocation::Slice d_scale_buffer /* may be null */, - BufferAllocation::Slice d_amax_buffer /* may be null */, - BufferAllocation::Slice workspace_buffer); + CublasLtCmd(ExecutionStreamId execution_stream_id, + const CublasLtMatmulThunk& matmul_thunk); absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; + // This is needed to avoid compile errors about "shadowed" virtual function + absl::Status Initialize(const InitializeParams& params) override { + return CublasLtMatmulThunk::Initialize(params); + } + absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, @@ -748,39 +742,6 @@ class CublasLtCmd : public TracedCommandBufferCmd { BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } - - private: - absl::StatusOr GetMatmulPlan( - const se::Stream* stream); - - absl::StatusOr GetMatmulAlgorithm( - const se::Stream* stream, const se::gpu::BlasLt::MatmulPlan* plan, - int64_t max_workspace); - - absl::Mutex matmul_plans_cache_mutex_; - 
absl::flat_hash_map - matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_); - - absl::Mutex matmul_algorithm_cache_mutex_; - absl::flat_hash_map - matmul_algorithm_cache_ ABSL_GUARDED_BY(matmul_algorithm_cache_mutex_); - - const GemmConfig gemm_config_; - const se::gpu::BlasLt::Epilogue epilogue_; - const int64_t algorithm_idx_; - const BufferAllocation::Slice a_buffer_; - const BufferAllocation::Slice b_buffer_; - const BufferAllocation::Slice c_buffer_; - const BufferAllocation::Slice d_buffer_; - const BufferAllocation::Slice bias_buffer_; - const BufferAllocation::Slice aux_buffer_; - const BufferAllocation::Slice a_scale_buffer_; - const BufferAllocation::Slice b_scale_buffer_; - const BufferAllocation::Slice c_scale_buffer_; - const BufferAllocation::Slice d_scale_buffer_; - const BufferAllocation::Slice d_amax_buffer_; - const BufferAllocation::Slice workspace_buffer_; }; //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc index a9296db0ae9c98..83037d4260931f 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc @@ -134,13 +134,7 @@ static absl::StatusOr Convert(const CublasLtMatmulThunk& thunk) { return absl::InternalError( "Gemm thunk does not contain a workspace buffer"); } - return std::make_unique( - thunk.execution_stream_id(), thunk.config(), thunk.epilogue(), - thunk.algorithm_idx(), thunk.a_buffer(), thunk.b_buffer(), - thunk.c_buffer(), thunk.d_buffer(), thunk.bias_buffer(), - thunk.aux_buffer(), thunk.a_scale_buffer(), thunk.b_scale_buffer(), - thunk.c_scale_buffer(), thunk.d_scale_buffer(), thunk.d_amax_buffer(), - thunk.workspace().value()); + return std::make_unique(thunk.execution_stream_id(), thunk); } static absl::StatusOr Convert( diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index 69db4164ce44da..e5bf0e10571ec0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -982,11 +982,13 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { // Prepare commands sequence for constructing command buffer. 
CommandBufferCmdSequence commands; commands.Emplace( - s0, config.value(), se::gpu::BlasLt::Epilogue::kDefault, 0, slice_a, - slice_b, slice_c, slice_d, BufferAllocation::Slice(), - BufferAllocation::Slice(), BufferAllocation::Slice(), - BufferAllocation::Slice(), BufferAllocation::Slice(), - BufferAllocation::Slice(), BufferAllocation::Slice(), slice_workspace); + s0, CublasLtMatmulThunk( + nullptr, config.value(), se::gpu::BlasLt::Epilogue::kDefault, 0, + slice_a, slice_b, slice_c, slice_d, BufferAllocation::Slice(), + BufferAllocation::Slice(), BufferAllocation::Slice(), + BufferAllocation::Slice(), BufferAllocation::Slice(), + BufferAllocation::Slice(), BufferAllocation::Slice(), + slice_workspace)); TF_ASSERT_OK_AND_ASSIGN( CommandBufferCmdExecutor executor, CommandBufferCmdExecutor::Create(std::move(commands), serialize)); diff --git a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc index 5f6be14eb1f74a..ef7f10627a7eba 100644 --- a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc @@ -19,12 +19,15 @@ limitations under the License. #include #include +#include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" @@ -35,117 +38,127 @@ limitations under the License. namespace xla { namespace gpu { +CublasLtMatmulThunk::CublasLtMatmulThunk(const CublasLtMatmulThunk& rhs) + : Thunk(Kind::kCublasLtMatmul, {}), + gemm_config_(rhs.gemm_config_), + epilogue_(rhs.epilogue_), + algorithm_idx_(rhs.algorithm_idx_), + canonical_hlo_(rhs.canonical_hlo_), + a_(rhs.a_), + b_(rhs.b_), + c_(rhs.c_), + d_(rhs.d_), + bias_(rhs.bias_), + aux_(rhs.aux_), + a_scale_(rhs.a_scale_), + b_scale_(rhs.b_scale_), + c_scale_(rhs.c_scale_), + d_scale_(rhs.d_scale_), + d_amax_(rhs.d_amax_), + workspace_(rhs.workspace_) {} + CublasLtMatmulThunk::CublasLtMatmulThunk( - ThunkInfo thunk_info, GemmConfig gemm_config, + const HloInstruction* instr, GemmConfig gemm_config, se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, - BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, - BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, - BufferAllocation::Slice bias_buffer, BufferAllocation::Slice aux_buffer, + BufferAllocation::Slice a, BufferAllocation::Slice b, + BufferAllocation::Slice c, BufferAllocation::Slice d, + BufferAllocation::Slice bias, BufferAllocation::Slice aux, BufferAllocation::Slice a_scale, BufferAllocation::Slice b_scale, BufferAllocation::Slice c_scale, BufferAllocation::Slice d_scale, BufferAllocation::Slice d_amax, - std::optional workspace_buffer) - : Thunk(Kind::kCublasLtMatmul, thunk_info), + std::optional workspace) + : Thunk(Kind::kCublasLtMatmul, + instr ? 
Thunk::ThunkInfo::WithProfileAnnotation(instr) + : Thunk::ThunkInfo{}), gemm_config_(std::move(gemm_config)), epilogue_(epilogue), algorithm_idx_(algorithm_idx), - a_buffer_(a_buffer), - b_buffer_(b_buffer), - c_buffer_(c_buffer), - d_buffer_(d_buffer), - bias_buffer_(bias_buffer), - aux_buffer_(aux_buffer), - a_scale_buffer_(a_scale), - b_scale_buffer_(b_scale), - c_scale_buffer_(c_scale), - d_scale_buffer_(d_scale), - d_amax_buffer_(d_amax), - workspace_buffer_(workspace_buffer) {} - -absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { - TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream)); - - TF_ASSIGN_OR_RETURN(auto algorithm, - GetMatmulAlgorithm(params.stream, plan, - workspace_buffer_.has_value() - ? workspace_buffer_.value().size() - : 0)); + a_(a), + b_(b), + c_(c), + d_(d), + bias_(bias), + aux_(aux), + a_scale_(a_scale), + b_scale_(b_scale), + c_scale_(c_scale), + d_scale_(d_scale), + d_amax_(d_amax), + workspace_(workspace) { + // The tests creating CublasLtMatmulThunk directly might not provide the + // pointer to the actual instruction, in this case Matmul plans are not + // cached. + if (instr != nullptr) { + canonical_hlo_ = xla::gpu::AutotuneCacheKey("unused", *instr).GetHlo(); + } +} + +absl::Status CublasLtMatmulThunk::ExecuteOnStreamInternal( + se::Stream* stream, const ExecuteParams& params) { + TF_ASSIGN_OR_RETURN(auto* plan, GetCachedMatmulPlan(params)); VLOG(3) << "Running cublas_lt matmul thunk"; const BufferAllocations& allocs = *params.buffer_allocations; - se::DeviceMemoryBase bias, a_scale, b_scale, c_scale, d_scale, d_amax; - if (bias_buffer_.allocation() != nullptr) { - bias = allocs.GetDeviceAddress(bias_buffer_); + se::DeviceMemoryBase bias, a_scale, b_scale, c_scale, d_scale, d_amax, aux, + workspace; + if (bias_.allocation() != nullptr) { + bias = allocs.GetDeviceAddress(bias_); } - if (a_scale_buffer_.allocation() != nullptr) { - a_scale = allocs.GetDeviceAddress(a_scale_buffer_); + if (a_scale_.allocation() != nullptr) { + a_scale = allocs.GetDeviceAddress(a_scale_); } - if (b_scale_buffer_.allocation() != nullptr) { - b_scale = allocs.GetDeviceAddress(b_scale_buffer_); + if (b_scale_.allocation() != nullptr) { + b_scale = allocs.GetDeviceAddress(b_scale_); } - if (c_scale_buffer_.allocation() != nullptr) { - c_scale = allocs.GetDeviceAddress(c_scale_buffer_); + if (c_scale_.allocation() != nullptr) { + c_scale = allocs.GetDeviceAddress(c_scale_); } - if (d_scale_buffer_.allocation() != nullptr) { - d_scale = allocs.GetDeviceAddress(d_scale_buffer_); + if (d_scale_.allocation() != nullptr) { + d_scale = allocs.GetDeviceAddress(d_scale_); } - if (d_amax_buffer_.allocation() != nullptr) { - d_amax = allocs.GetDeviceAddress(d_amax_buffer_); + if (d_amax_.allocation() != nullptr) { + d_amax = allocs.GetDeviceAddress(d_amax_); } - - se::DeviceMemoryBase aux; - if (aux_buffer_.allocation() != nullptr) { - aux = allocs.GetDeviceAddress(aux_buffer_); + if (aux_.allocation() != nullptr) { + aux = allocs.GetDeviceAddress(aux_); } - - se::DeviceMemoryBase workspace; - if (workspace_buffer_.has_value()) { - workspace = allocs.GetDeviceAddress(workspace_buffer_.value()); + if (workspace_.has_value()) { + workspace = allocs.GetDeviceAddress(workspace_.value()); } return plan->ExecuteOnStream( - params.stream, allocs.GetDeviceAddress(a_buffer_), - allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_), - allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, - d_scale, d_amax, algorithm, 
workspace);
+      stream, allocs.GetDeviceAddress(a_), allocs.GetDeviceAddress(b_),
+      allocs.GetDeviceAddress(c_), allocs.GetDeviceAddress(d_), bias, aux,
+      a_scale, b_scale, c_scale, d_scale, d_amax, workspace);
 }

-absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> CublasLtMatmulThunk::GetMatmulPlan(
-    const se::Stream* stream) {
-  {
-    absl::MutexLock lock(&matmul_plans_cache_mutex_);
-    auto it = matmul_plans_cache_.find(stream);
-    if (it != matmul_plans_cache_.end()) return it->second.get();
-  }
-  TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
-                                     stream, gemm_config_, epilogue_));
-
-  absl::MutexLock lock(&matmul_plans_cache_mutex_);
-  auto [it, _] = matmul_plans_cache_.emplace(stream, std::move(plan));
-  return it->second.get();
-}
-
-absl::StatusOr<se::gpu::BlasLt::MatmulAlgorithm>
-CublasLtMatmulThunk::GetMatmulAlgorithm(const se::Stream* stream,
-                                        const se::gpu::BlasLt::MatmulPlan* plan,
-                                        int64_t max_workspace) {
-  {
-    absl::MutexLock lock(&matmul_algorithm_cache_mutex_);
-    auto it = matmul_algorithm_cache_.find(plan);
-    if (it != matmul_algorithm_cache_.end()) return it->second;
-  }
-  TF_ASSIGN_OR_RETURN(
-      auto algorithms,
-      plan->GetAlgorithms(stream,
-                          /*max_algorithm_count*/ 128,
-                          /*max_workspace_size*/ max_workspace));
-  TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size());
-
-  absl::MutexLock lock(&matmul_algorithm_cache_mutex_);
-  auto [it, _] =
-      matmul_algorithm_cache_.emplace(plan, algorithms[algorithm_idx_]);
-  return it->second;
+absl::StatusOr<se::gpu::BlasLt::MatmulPlan*>
+CublasLtMatmulThunk::GetCachedMatmulPlan(const ExecuteParams& params) {
+  auto* blas_lt = se::gpu::BlasLt::Get(params.stream);
+  auto create = [&]() -> absl::StatusOr<se::gpu::BlasLt::MatmulPlanPtr> {
+    VLOG(2) << this << ": Adding new MatmulPlan for stream: " << params.stream
+            << " instr: " << canonical_hlo_;
+
+    TF_ASSIGN_OR_RETURN(auto plan,
+                        blas_lt->GetMatmulPlan(gemm_config_, epilogue_));
+    // If the workspace buffer is not provided, consider only the algorithms
+    // which do not require a scratch space.
+    int64_t max_workspace =
+        workspace_.has_value() ? workspace_.value().size() : 0;
+
+    // If autotuning is disabled, there is no point in retrieving all
+    // algorithms; it's enough to get the default one only.
+    int64_t num_algorithms =
+        algorithm_idx_ == 0 ? 1 : GemmConfig::kNumAlgorithms;
+    TF_ASSIGN_OR_RETURN(
+        auto algorithms,
+        plan->GetAlgorithms(params.stream, num_algorithms, max_workspace));
+
+    TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_]));
+    return std::move(plan);
+  };
+  return blas_lt->GetOrCreateMatmulPlan(canonical_hlo_, create);
 }

 absl::Status CublasLtMatmulThunk::Initialize(const InitializeParams& params) {
diff --git a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h
index 3f8570764024f3..60d7112b60c7d7 100644
--- a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h
+++ b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h
@@ -20,9 +20,7 @@ limitations under the License.
#include #include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" -#include "absl/synchronization/mutex.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/matmul_utils.h" @@ -34,74 +32,52 @@ namespace gpu { class CublasLtMatmulThunk : public Thunk { public: - CublasLtMatmulThunk( - ThunkInfo thunk_info, GemmConfig gemm_config, - se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, - BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, - BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, - BufferAllocation::Slice bias_buffer /* may be null */, - BufferAllocation::Slice aux_buffer /* may be null */, - BufferAllocation::Slice a_scale_buffer /* may be null */, - BufferAllocation::Slice b_scale_buffer /* may be null */, - BufferAllocation::Slice c_scale_buffer /* may be null */, - BufferAllocation::Slice d_scale_buffer /* may be null */, - BufferAllocation::Slice d_amax_buffer /* may be null */, - std::optional workspace_buffer); - - absl::Status ExecuteOnStream(const ExecuteParams& params) override; + CublasLtMatmulThunk(const HloInstruction* instr, GemmConfig gemm_config, + se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, + BufferAllocation::Slice a, BufferAllocation::Slice b, + BufferAllocation::Slice c, BufferAllocation::Slice d, + BufferAllocation::Slice bias /* may be null */, + BufferAllocation::Slice aux /* may be null */, + BufferAllocation::Slice a_scale /* may be null */, + BufferAllocation::Slice b_scale /* may be null */, + BufferAllocation::Slice c_scale /* may be null */, + BufferAllocation::Slice d_scale /* may be null */, + BufferAllocation::Slice d_amax /* may be null */, + std::optional workspace); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override { + return ExecuteOnStreamInternal(params.stream, params); + } absl::Status Initialize(const InitializeParams& params) override; - - GemmConfig config() const { return gemm_config_; } - se::gpu::BlasLt::Epilogue epilogue() const { return epilogue_; } - int64_t algorithm_idx() const { return algorithm_idx_; } - - BufferAllocation::Slice a_buffer() const { return a_buffer_; } - BufferAllocation::Slice b_buffer() const { return b_buffer_; } - BufferAllocation::Slice c_buffer() const { return c_buffer_; } - BufferAllocation::Slice d_buffer() const { return d_buffer_; } - BufferAllocation::Slice bias_buffer() const { return bias_buffer_; } - BufferAllocation::Slice aux_buffer() const { return aux_buffer_; } - BufferAllocation::Slice a_scale_buffer() const { return a_scale_buffer_; } - BufferAllocation::Slice b_scale_buffer() const { return b_scale_buffer_; } - BufferAllocation::Slice c_scale_buffer() const { return c_scale_buffer_; } - BufferAllocation::Slice d_scale_buffer() const { return d_scale_buffer_; } - BufferAllocation::Slice d_amax_buffer() const { return d_amax_buffer_; } std::optional workspace() const { - return workspace_buffer_; + return workspace_; } - private: - absl::StatusOr GetMatmulPlan( - const stream_executor::Stream* stream); - absl::StatusOr GetMatmulAlgorithm( - const se::Stream* stream, const se::gpu::BlasLt::MatmulPlan* plan, - int64_t max_workspace); - - absl::Mutex matmul_plans_cache_mutex_; - absl::flat_hash_map - matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_); + protected: + CublasLtMatmulThunk(const CublasLtMatmulThunk& rhs); - absl::Mutex matmul_algorithm_cache_mutex_; - absl::flat_hash_map - 
matmul_algorithm_cache_ ABSL_GUARDED_BY(matmul_algorithm_cache_mutex_); + absl::Status ExecuteOnStreamInternal(se::Stream* stream, + const ExecuteParams& params); + absl::StatusOr GetCachedMatmulPlan( + const ExecuteParams& params); + protected: GemmConfig gemm_config_; se::gpu::BlasLt::Epilogue epilogue_; int64_t algorithm_idx_; - BufferAllocation::Slice a_buffer_; - BufferAllocation::Slice b_buffer_; - BufferAllocation::Slice c_buffer_; - BufferAllocation::Slice d_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice aux_buffer_; - BufferAllocation::Slice a_scale_buffer_; - BufferAllocation::Slice b_scale_buffer_; - BufferAllocation::Slice c_scale_buffer_; - BufferAllocation::Slice d_scale_buffer_; - BufferAllocation::Slice d_amax_buffer_; - std::optional workspace_buffer_; + std::string canonical_hlo_; + BufferAllocation::Slice a_; + BufferAllocation::Slice b_; + BufferAllocation::Slice c_; + BufferAllocation::Slice d_; + BufferAllocation::Slice bias_; + BufferAllocation::Slice aux_; + BufferAllocation::Slice a_scale_; + BufferAllocation::Slice b_scale_; + BufferAllocation::Slice c_scale_; + BufferAllocation::Slice d_scale_; + BufferAllocation::Slice d_amax_; + std::optional workspace_; }; } // namespace gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc new file mode 100644 index 00000000000000..80c7f1d11865b1 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc @@ -0,0 +1,371 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "xla/backends/gpu/runtime/thunk.h" +#include "xla/error_spec.h" +#include "xla/executable_run_options.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/transforms/gemm_rewriter.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/semantic_version.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/threadpool.h" + +namespace xla::gpu { + +namespace { + +class GpuBlasLtMatmulThunkTest : public HloTestBase { + public: + DebugOptions GetDebugOptionsForTest() const override { + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_cublaslt(true); + debug_options.set_xla_gpu_enable_triton_gemm(false); + return debug_options; + } + se::StreamExecutor* default_exec() { + return backend().default_stream_executor(); + } + const se::DeviceDescription& device_desc(se::StreamExecutor* exec = nullptr) { + if (exec == nullptr) { + exec = default_exec(); + } + return exec->GetDeviceDescription(); + } + const se::GpuComputeCapability& gpu_comp(se::StreamExecutor* exec = nullptr) { + return device_desc(exec).gpu_compute_capability(); + } + + void SetUp() override { + if (auto* rocm = std::get_if(&gpu_comp()); + rocm != nullptr && !rocm->has_hipblaslt()) { + GTEST_SKIP() << "No hipblas-lt support on this architecture!"; + } + } + + void CreateExecuteThunksFromHLO(se::StreamExecutor* executor, + absl::string_view hlo_string); +}; + +struct GpuBlasLtThunkBuilder { + GpuBlasLtThunkBuilder(se::StreamExecutor* exec, + const se::GpuComputeCapability& gpu_comp) + : exec_(exec), allocator_(exec), gpu_comp_(gpu_comp) {} + + absl::StatusOr> CreateThunk( + HloInstruction* gemm) { + TF_ASSIGN_OR_RETURN(const auto gpu_config, + gemm->backend_config()); + const auto& backend_config = gpu_config.gemm_backend_config(); + + TF_ASSIGN_OR_RETURN( + bool has_vector_bias, + gpublas_lt::EpilogueAddsVectorBias(backend_config.epilogue())); + bool has_matrix_bias = backend_config.beta() != 0; + TF_ASSIGN_OR_RETURN( + auto epilogue, gpublas_lt::AsBlasLtEpilogue(backend_config.epilogue())); + + std::vector slices; + std::vector buf_sizes; + for (auto op : gemm->operands()) { + auto size = ShapeUtil::ByteSizeOf(op->shape()); + buf_sizes.push_back(size); + } + const auto& output_shape = + gemm->shape().IsTuple() ? 
gemm->shape().tuple_shapes(0) : gemm->shape(); + buf_sizes.push_back(ShapeUtil::ByteSizeOf(output_shape)); + + size_t idx = allocs_.size(); + slices.reserve(buf_sizes.size()); + for (auto size : buf_sizes) { + mem_buffers_.emplace_back(); + TF_ASSIGN_OR_RETURN(mem_buffers_.back(), + allocator_.Allocate(exec_->device_ordinal(), size)); + allocs_.emplace_back(/*index=*/idx++, size, /*color=*/0); + slices.emplace_back(&allocs_.back(), /*offset*/ 0, size); + } + // we need at least 3 buffers: lhs, rhs and output + EXPECT_EQ(slices.size(), + 3 + size_t{has_matrix_bias} + size_t{has_vector_bias}); + TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm, gpu_comp_)); + + BufferAllocation::Slice bias; + if (has_vector_bias) { + bias = slices[has_matrix_bias ? 3 : 2]; + } + + return std::make_unique( + gemm, std::move(gemm_config), epilogue, + /*algorithm_idx*/ 0, slices[0], slices[1], + has_matrix_bias ? slices[2] : slices.back(), slices.back(), bias, + BufferAllocation::Slice{} /* aux */, + BufferAllocation::Slice{} /* a_scale */, + BufferAllocation::Slice{} /* b_scale */, + BufferAllocation::Slice{} /* c_scale */, + BufferAllocation::Slice{} /* d_scale */, + BufferAllocation::Slice{} /* d_amax */, std::nullopt /* workspace */); + } + + std::unique_ptr buffer_allocations() { + std::vector buffers(mem_buffers_.size()); + for (size_t i = 0; i < buffers.size(); i++) { + buffers[i] = *mem_buffers_[i]; + } + return std::make_unique(buffers, exec_->device_ordinal(), + &allocator_); + } + + private: + se::StreamExecutor* exec_; + se::StreamExecutorMemoryAllocator allocator_; + se::GpuComputeCapability gpu_comp_; + std::deque allocs_; + std::vector mem_buffers_; +}; + +void GpuBlasLtMatmulThunkTest::CreateExecuteThunksFromHLO( + se::StreamExecutor* executor, absl::string_view hlo_string) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + this->ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + RunHloPass( + GemmRewriter(gpu_comp(executor), + /*toolkit_version=*/se::SemanticVersion{12, 4, 0}), + module.get())); + ASSERT_TRUE(changed); + + GpuBlasLtThunkBuilder builder(executor, gpu_comp(executor)); + std::vector> gemm_thunks; + + for (auto* instr : module->entry_computation()->instructions()) { + if (IsCublasLtMatmul(*instr)) { + TF_ASSERT_OK_AND_ASSIGN(auto thunk, builder.CreateThunk(instr)); + gemm_thunks.push_back(std::move(thunk)); + } + } + auto allocs = builder.buffer_allocations(); + ServiceExecutableRunOptions run_options; + + auto thread_func = [&](se::Stream* stream) -> absl::Status { + auto thunk_params = Thunk::ExecuteParams::Create( + run_options, *allocs, stream, stream, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + for (auto& thunk : gemm_thunks) { + TF_RETURN_IF_ERROR( + thunk->Initialize({executor, source, allocs.get(), stream, stream})); + TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params)); + } + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + return absl::OkStatus(); + }; + + // Running BlasLt thunks across multiple streams with shared matmul plan + int num_streams = 10; + struct StreamInfo { + std::unique_ptr stream; + absl::Status result; + }; + std::vector threads(num_streams); + { + tsl::thread::ThreadPool pool(tsl::Env::Default(), "test_streams", + num_streams); + // use two different loops to make sure all threads start at the same time + for (auto& [s, _] : threads) { + TF_ASSERT_OK_AND_ASSIGN(s, executor->CreateStream()); + } + // some compilers complain about lambda capture of structured 
bindings + for (auto& info : threads) { + pool.Schedule([&] { info.result = thread_func(info.stream.get()); }); + } + } + for (const auto& [_, res] : threads) { + TF_ASSERT_OK(res); + } +} + +const absl::string_view hlo_single_plan = R"( +HloModule SharedMatmulPlan + +ENTRY test { + x1 = f32[101,407] parameter(0) + x2 = f32[101,407] parameter(1) + x3 = f32[101,407] parameter(2) + y = f32[407,400] parameter(3) + z = f32[407,400] parameter(4) + w = f32[407,400] parameter(5) + dot_a = f32[101,400] dot(x1, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_b = f32[101,400] dot(x2, z), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_c = f32[101,400] dot(x3, w), lhs_contracting_dims={1}, rhs_contracting_dims={0} + mul_ab = f32[101,400] multiply(dot_a, dot_b) + ROOT abc = f32[101,400] subtract(mul_ab, dot_c) +})"; + +// same as above but now we have non-default epilogue for one dot operation +const absl::string_view hlo_two_plans = + R"( +HloModule SharedMatmulPlan + +ENTRY test { + x1 = f32[101,407] parameter(0) + x2 = f32[101,407] parameter(1) + x3 = f32[101,407] parameter(2) + y = f32[407,400] parameter(3) + z = f32[407,400] parameter(4) + w = f32[407,400] parameter(5) + c = f32[] constant(0) + c_bcast = f32[101,400] broadcast(c), dimensions={} + dot_a = f32[101,400] dot(x1, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + out_a = f32[101,400] maximum(dot_a, c_bcast) + dot_b = f32[101,400] dot(x2, z), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_c = f32[101,400] dot(x3, w), lhs_contracting_dims={1}, rhs_contracting_dims={0} + mul_ab = f32[101,400] multiply(out_a, dot_b) + ROOT abc = f32[101,400] subtract(mul_ab, dot_c) +})"; + +XLA_TEST_F(GpuBlasLtMatmulThunkTest, SharedMatmulPlansUnit) { + auto* exec = default_exec(); + auto* blas_lt = exec->AsBlas()->GetBlasLt(); + EXPECT_NE(blas_lt, nullptr); + blas_lt->ClearMatmulPlanCache(); + + CreateExecuteThunksFromHLO(exec, hlo_single_plan); + // Assert that only one matmul plan was created + EXPECT_EQ(blas_lt->GetMatmulPlanCacheSize(), 1); + + CreateExecuteThunksFromHLO(exec, hlo_two_plans); + // Assert that we have now 2 MatmulPlans (one more created for ReLu epilogue). + EXPECT_EQ(blas_lt->GetMatmulPlanCacheSize(), 2); +} + +// Same as above but instead of creating thunks manually, we use XLA runtime +XLA_TEST_F(GpuBlasLtMatmulThunkTest, SharedMatmulPlansFunctional) { + auto* exec = default_exec(); + auto* blas_lt = exec->AsBlas()->GetBlasLt(); + EXPECT_NE(blas_lt, nullptr); + blas_lt->ClearMatmulPlanCache(); + + EXPECT_TRUE(RunAndCompare(hlo_single_plan, ErrorSpec{1e-3, 1e-3})); + // Assert that only one MatmulPlan cache entry was created. + EXPECT_EQ(blas_lt->GetMatmulPlanCacheSize(), 1); + + EXPECT_TRUE(RunAndCompare(hlo_two_plans, ErrorSpec{1e-3, 1e-3})); + // Assert that we have now 2 MatmulPlans (one more created for ReLu epilogue). 
+ EXPECT_EQ(blas_lt->GetMatmulPlanCacheSize(), 2); +} + +// Mock BlasLt interface to test only the cache function +struct MockBlasLt : public se::gpu::BlasLt { + absl::Status Init() override { return absl::OkStatus(); } + + absl::StatusOr GetMatmulPlan(const se::gpu::GemmConfig&, + Epilogue) const override { + return MatmulPlanPtr{}; + } + ~MockBlasLt() override = default; +}; + +XLA_TEST_F(GpuBlasLtMatmulThunkTest, CacheUnitTest) { + auto thread_func = [&](MockBlasLt* blas_lt, const std::string& key, + int sleep_ms) -> absl::Status { + auto create_func = [&]() -> absl::StatusOr { + // We don't care about creation of matmul plans -> emulate it with a sleep + absl::SleepFor(absl::Milliseconds(sleep_ms)); + return se::gpu::BlasLt::MatmulPlanPtr{}; + }; + + return blas_lt->GetOrCreateMatmulPlan(key, create_func).status(); + }; // thread_func + + const int num_blas_lts = 30, num_streams = 30, + total = num_blas_lts * num_streams, mod = 11; + + std::vector results(total); + std::vector blas_lts(num_blas_lts); + + { + tsl::thread::ThreadPool pool(tsl::Env::Default(), "test_streams", total); + std::random_device rand_dev; + std::default_random_engine engine(rand_dev()); + std::uniform_int_distribution uniform_sleeps(1, 500); + + for (int j = 0, k = 0; j < num_blas_lts; j++) { + for (int i = 0; i < num_streams; i++, k++) { + int sleep_ms = uniform_sleeps(engine), x = i + j + 1; + // we could have same keys for different executors + auto key = std::to_string((x * x * x) % mod); + VLOG(1) << j << "," << i << " :" << key; + pool.Schedule([&, key, sleep_ms, k, j] { + results[k] = thread_func(&blas_lts[j], key, sleep_ms); + }); + } + } // for j + } // end block + for (auto& res : results) { + TF_ASSERT_OK(res); + } + + // We assert that we have the same number of cache entries for each executor + // and that this number is <= mod (based on our logic to create keys) + std::optional size; + for (const auto& blas_lt : blas_lts) { + if (!size) { + size = blas_lt.GetMatmulPlanCacheSize(); + } else { + EXPECT_EQ(*size, blas_lt.GetMatmulPlanCacheSize()); + } + } + EXPECT_TRUE(size.has_value() && static_cast(*size <= mod)); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc index c323def8b13a2b..460fa8a0682bb8 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc @@ -201,24 +201,24 @@ class GemmAutotuner { TF_ASSIGN_OR_RETURN( auto algorithms, - plan->GetAlgorithms(stream_, /*max_algorithm_count*/ 128, + plan->GetAlgorithms(stream_, GemmConfig::kNumAlgorithms, /*max_workspace_size*/ workspace_buffer.size())); auto tuned_func = [&](const BlasLt::MatmulAlgorithm& algorithm) -> absl::StatusOr { // Run a warmup iteration without the profiler active. 
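+      // The candidate algorithm is installed on the plan once via
+      // SetAlgorithm(); ExecuteOnStream() no longer takes an algorithm
+      // argument.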
+ TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithm)); TF_RETURN_IF_ERROR(plan->ExecuteOnStream( stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(), bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, - c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm, - workspace_buffer)); + c_scale_buffer, d_scale_buffer, d_amax_buffer, workspace_buffer)); se::blas::ProfileResult profile_result; profile_result.set_warmup_run_executed(true); TF_RETURN_IF_ERROR(plan->ExecuteOnStream( stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(), bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, - c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm, - workspace_buffer, &profile_result)); + c_scale_buffer, d_scale_buffer, d_amax_buffer, workspace_buffer, + &profile_result)); return std::move(profile_result); }; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 63ccf0f9e8abf5..5358655f76322e 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -733,9 +733,8 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk( TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue, gpublas_lt::AsBlasLtEpilogue(epilogue)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config), - blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale, - c_scale, d_scale, d_amax, workspace_buffer); + instr, std::move(gemm_config), blas_lt_epilogue, algorithm, a, b, c, d, + bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, workspace_buffer); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } @@ -825,9 +824,8 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue, gpublas_lt::AsBlasLtEpilogue(epilogue)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config), - blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale, - c_scale, d_scale, d_amax, workspace_buffer); + instr, std::move(gemm_config), blas_lt_epilogue, algorithm, a, b, c, d, + bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, workspace_buffer); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h index b9bf8b1408eb88..f8dec3249ced25 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.h +++ b/third_party/xla/xla/service/gpu/matmul_utils.h @@ -84,6 +84,8 @@ struct GemmConfig : public se::gpu::GemmConfig { static constexpr int64_t kHopperWorkspace = 32 * 1024 * 1024; // 32 MiB static constexpr int64_t kGFX950Workspace = 64 * 1024 * 1024; // 64 MiB static constexpr int64_t kDefaultWorkspace = 4 * 1024 * 1024; // 4 MiB + // the number of algorithms to consider for autotuning by default + static constexpr int64_t kNumAlgorithms = 128; static absl::StatusOr For( const HloInstruction* gemm, const se::GpuComputeCapability& gpu_version); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc index 17e7d6dc9fca89..317d218f6bd531 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -345,8 +345,12 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, absl::Status 
BlasLt::MatmulPlan::DoMatmul( Stream* stream, const void* alpha, const void* beta, - const MatmulAlgorithm& algorithm, const gpu::BlasLt::MemoryArgs& args, + const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const { + if (!algorithm_.has_value()) { + return absl::InternalError( + "Algorithm must be set before calling DoMatMul!"); + } DeviceMemoryBase a = args.a, b = args.b; if (must_swap_operands_) { std::swap(a, b); @@ -362,7 +366,7 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( } void* workspace_addr = nullptr; - uint64_t workspace_size = algorithm.workspace_size; + uint64_t workspace_size = algorithm_->workspace_size; if (workspace_size > 0) { if (args.scratch_allocator != nullptr) { TF_ASSIGN_OR_RETURN( @@ -377,7 +381,7 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( } } - auto palgo = std::any_cast(&algorithm.opaque_algo); + auto palgo = std::any_cast(&algorithm_->opaque_algo); { absl::MutexLock lock(&blas_lt->mu_); TF_RET_CHECK(blas_lt->blas_lt_ != nullptr); @@ -478,8 +482,7 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( } absl::Status BlasLt::MatmulPlan::ExecuteOnStream( - Stream* stream, const MatmulAlgorithm& algorithm, - const gpu::BlasLt::MemoryArgs& args, + Stream* stream, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const { auto wrapped_matmul = [&](auto scale) { using Scale = decltype(scale); @@ -491,7 +494,7 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream( salpha = static_cast(alpha_.real()); } Scale sbeta = static_cast(beta_); - return DoMatmul(stream, &salpha, &sbeta, algorithm, args, profile_result); + return DoMatmul(stream, &salpha, &sbeta, args, profile_result); }; std::tuple operand_types{a_desc_.type(), b_desc_.type(), c_desc_.type(), diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h index 8909c6ac7ed0e1..4fc21afa50db92 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h @@ -101,17 +101,20 @@ class BlasLt : public gpu::BlasLt { ~MatmulPlan() override = default; absl::Status ExecuteOnStream( - Stream* stream, const MatmulAlgorithm& algorithm, - const gpu::BlasLt::MemoryArgs& args, + Stream* stream, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const override; absl::StatusOr> GetAlgorithms( const Stream* stream, size_t max_algorithm_count, size_t max_workspace_size) const override; + absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) override { + algorithm_ = algorithm; + return absl::OkStatus(); + } + private: absl::Status DoMatmul(Stream* stream, const void* alpha, const void* beta, - const MatmulAlgorithm& algorithm, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const; @@ -124,6 +127,7 @@ class BlasLt : public gpu::BlasLt { xla::complex128 alpha_; double beta_; bool must_swap_operands_; + std::optional algorithm_; // selected algorithm }; // class MatmulPlan explicit BlasLt(StreamExecutor* parent) diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 721cb7a298826b..06a20a621cf838 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -528,6 +528,8 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc index 182af599af9e5c..456258fba638a5 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -234,6 +234,29 @@ DataType GetScaleType(DataType c_type, ComputationType computation_type) { : c_type); } -} // namespace gpu +absl::StatusOr BlasLt::GetOrCreateMatmulPlan( + const std::string& key, PlanCreateFunc create) { + absl::MutexLock lock(&plan_cache_mu_); // double mutex ??? + auto res = plan_cache_.emplace(key, MatmulPlanPtr{}); + // New entry inserted: always create a new matmul plan if key is empty, + // this is used by command_buffer_thunk test. + if (res.second || key.empty()) { + VLOG(2) << "Creating a plan for: " << key; + TF_ASSIGN_OR_RETURN(res.first->second, create()); + VLOG(2) << "Plan created: cache size: " << plan_cache_.size(); + } + return res.first->second.get(); +} + +void BlasLt::ClearMatmulPlanCache() { + absl::MutexLock lock(&plan_cache_mu_); + plan_cache_.clear(); +} + +size_t BlasLt::GetMatmulPlanCacheSize() const { + absl::MutexLock lock(&plan_cache_mu_); + return plan_cache_.size(); +} +} // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h index ca4aff2c8d42ab..68699016670c05 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h @@ -24,6 +24,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/stream_executor/blas.h" @@ -155,8 +157,8 @@ struct BlasLt { }; struct MatmulPlan { - // API that uses scratch_allocator to allocate workspace. - // This version is used by TF: see tensorflow/core/kernels/matmul_util.cc + // This function is to be removed once TF interface is fixed, + // see tensorflow/core/kernels/matmul_util.cc absl::Status ExecuteOnStream( Stream* stream, DeviceMemoryBase a, DeviceMemoryBase b, DeviceMemoryBase c, DeviceMemoryBase d, @@ -167,8 +169,29 @@ struct BlasLt { DeviceMemoryBase d_amax, const MatmulAlgorithm& algorithm, ScratchAllocator& scratch_allocator, blas::ProfileResult* profile_result = nullptr) const { + // Temporary hack until Tensorflow side is fixed + TF_RETURN_IF_ERROR( + const_cast(this)->SetAlgorithm(algorithm)); return ExecuteOnStream( - stream, algorithm, + stream, + MemoryArgs{a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale, + d_amax, DeviceMemoryBase{}, &scratch_allocator}, + profile_result); + } + + // API that uses scratch_allocator to allocate workspace. 
+    // This version is used by TF: see tensorflow/core/kernels/matmul_util.cc
+    absl::Status ExecuteOnStream(
+        Stream* stream, DeviceMemoryBase a, DeviceMemoryBase b,
+        DeviceMemoryBase c, DeviceMemoryBase d,
+        DeviceMemoryBase bias,  // may be null
+        DeviceMemoryBase aux,   // may be null
+        DeviceMemoryBase a_scale, DeviceMemoryBase b_scale,
+        DeviceMemoryBase c_scale, DeviceMemoryBase d_scale,
+        DeviceMemoryBase d_amax, ScratchAllocator& scratch_allocator,
+        blas::ProfileResult* profile_result = nullptr) const {
+      return ExecuteOnStream(
+          stream, MemoryArgs{a, b, c, d, bias, aux, a_scale, b_scale, c_scale,
+                             d_scale, d_amax, DeviceMemoryBase{},
+                             &scratch_allocator},
+          profile_result);
@@ -182,11 +205,10 @@ struct BlasLt {
         DeviceMemoryBase aux,  // may be null
         DeviceMemoryBase a_scale, DeviceMemoryBase b_scale,
         DeviceMemoryBase c_scale, DeviceMemoryBase d_scale,
-        DeviceMemoryBase d_amax, const MatmulAlgorithm& algorithm,
-        DeviceMemoryBase workspace,
+        DeviceMemoryBase d_amax, DeviceMemoryBase workspace,
         blas::ProfileResult* profile_result = nullptr) const {
       return ExecuteOnStream(
-          stream, algorithm,
+          stream,
           MemoryArgs{a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale,
                      d_amax, workspace, nullptr},
           profile_result);
@@ -194,8 +216,8 @@ struct BlasLt {
     // The most general form: to be implemented by derived classes.
     virtual absl::Status ExecuteOnStream(
-        Stream* stream, const MatmulAlgorithm& algorithm,
-        const MemoryArgs& args, blas::ProfileResult* profile_result) const = 0;
+        Stream* stream, const MemoryArgs& args,
+        blas::ProfileResult* profile_result) const = 0;

     // Returns a list of supported algorithms for DoMatmul. The algorithms are
     // returned in the order of increasing estimated compute time according to
@@ -204,10 +226,17 @@ struct BlasLt {
         const Stream* stream, size_t max_algorithm_count = 128,
         size_t max_workspace_size = 1ll << 32) const = 0;

+    // The algorithm must be set before calling the ExecuteOnStream function(s).
+    // Usually ExecuteOnStream is called with the same algorithm ID, so using
+    // a separate setter enables BlasLt implementations to do additional
+    // optimizations (like preloading matmul kernels) once the algorithm is set.
+ virtual absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) = 0; + virtual ~MatmulPlan() {} }; // class MatmulPlan using MatmulPlanPtr = std::unique_ptr; + using PlanCreateFunc = absl::AnyInvocable()>; virtual absl::Status Init() = 0; @@ -221,7 +250,18 @@ struct BlasLt { const GemmConfig& cfg, Epilogue epilogue); + absl::StatusOr GetOrCreateMatmulPlan(const std::string& key, + PlanCreateFunc create); + + void ClearMatmulPlanCache(); + size_t GetMatmulPlanCacheSize() const; + virtual ~BlasLt() {} + + protected: + mutable absl::Mutex plan_cache_mu_; + absl::flat_hash_map plan_cache_ + ABSL_GUARDED_BY(plan_cache_mu_); }; // class BlasLt } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index 37084b0ce38971..81fa80f8a5f798 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -378,8 +378,12 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, Epilogue epilogue) const absl::Status BlasLt::MatmulPlan::DoMatmul( Stream* stream, const void* alpha, const void* beta, - const MatmulAlgorithm& algorithm, const gpu::BlasLt::MemoryArgs& args, + const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const { + if (!algorithm_.has_value()) { + return absl::InternalError( + "Algorithm must be set before calling DoMatMul!"); + } DeviceMemoryBase a = args.a, b = args.b; if (must_swap_operands_) { std::swap(a, b); @@ -399,7 +403,7 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( } void* workspace_addr = nullptr; - uint64_t workspace_size = algorithm.workspace_size; + uint64_t workspace_size = algorithm_->workspace_size; if (workspace_size > 0) { if (args.scratch_allocator != nullptr) { TF_ASSIGN_OR_RETURN( @@ -414,7 +418,7 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( } } - auto palgo = std::any_cast(&algorithm.opaque_algo); + auto palgo = std::any_cast(&algorithm_->opaque_algo); { absl::MutexLock lock(&blas_lt->mu_); TF_RET_CHECK(blas_lt->blas_lt_ != nullptr); @@ -497,8 +501,7 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( } absl::Status BlasLt::MatmulPlan::ExecuteOnStream( - Stream* stream, const MatmulAlgorithm& algorithm, - const gpu::BlasLt::MemoryArgs& args, + Stream* stream, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const { auto wrapped_matmul = [&](auto scale) { using Scale = decltype(scale); @@ -510,7 +513,7 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream( salpha = static_cast(alpha_.real()); } Scale sbeta = static_cast(beta_); - return DoMatmul(stream, &salpha, &sbeta, algorithm, args, profile_result); + return DoMatmul(stream, &salpha, &sbeta, args, profile_result); }; std::tuple operand_types{a_desc_.type(), b_desc_.type(), c_desc_.type(), diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h index 1236bcec2e6e50..70740c05b43af9 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h @@ -102,17 +102,20 @@ class BlasLt : public gpu::BlasLt { ~MatmulPlan() override = default; absl::Status ExecuteOnStream( - Stream* stream, const MatmulAlgorithm& algorithm, - const gpu::BlasLt::MemoryArgs& args, + Stream* stream, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const override; absl::StatusOr> GetAlgorithms( const Stream* stream, size_t max_algorithm_count, size_t 
max_workspace_size) const override; + absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) override { + algorithm_ = algorithm; + return absl::OkStatus(); + } + protected: absl::Status DoMatmul(Stream* stream, const void* alpha, const void* beta, - const MatmulAlgorithm& algorithm, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const; @@ -126,6 +129,7 @@ class BlasLt : public gpu::BlasLt { xla::complex128 alpha_; double beta_; bool must_swap_operands_; + std::optional algorithm_; // selected algorithm }; // class MatmulPlan explicit BlasLt(StreamExecutor* parent) From d8931ce24a839fb459e849d1549a347c9b80dcb5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Apr 2025 23:20:18 -0700 Subject: [PATCH 0907/1324] Automated Code Change PiperOrigin-RevId: 748558996 --- third_party/xla/xla/hlo/transforms/despecializer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/transforms/despecializer.cc b/third_party/xla/xla/hlo/transforms/despecializer.cc index 15edce035c7bc0..11297956ae6cab 100644 --- a/third_party/xla/xla/hlo/transforms/despecializer.cc +++ b/third_party/xla/xla/hlo/transforms/despecializer.cc @@ -158,7 +158,7 @@ absl::StatusOr DeconstructReduceWindowToReduceBroadcast::Run( auto reduce_dim_index = rw.second; if (reduce_window == nullptr || reduce_dim_index < 0 || reduce_dim_index >= - reduce_window->operand(0)->shape().dimensions_size()) { + reduce_window->operand(0)->shape().dimensions().size()) { continue; } std::vector reduce_instr_dimensions; From 93013794dd4eb57d39950fd0191eaa2f696eb006 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 17 Apr 2025 01:19:02 -0700 Subject: [PATCH 0908/1324] Fix use of mlir::isa/dyn_cast. The member methods are deprecated in favour of the free functions. 
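
As a sketch of the mechanical pattern applied in this change (on a
hypothetical mlir::Value v; the free functions are the ones re-exported
into the mlir namespace by mlir/Support/LLVM.h):

  // Before: deprecated member methods.
  if (v.getType().isa<mlir::ShapedType>()) {
    auto shaped = v.getType().cast<mlir::ShapedType>();
    (void)shaped;
  }
  auto arg = v.dyn_cast<mlir::BlockArgument>();  // null if v is not an argument

  // After: free functions, same semantics.
  if (mlir::isa<mlir::ShapedType>(v.getType())) {
    auto shaped = mlir::cast<mlir::ShapedType>(v.getType());
    (void)shaped;
  }
  auto arg = mlir::dyn_cast<mlir::BlockArgument>(v);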
PiperOrigin-RevId: 748588024 --- .../quantization_lib/quantization_utils.h | 12 +++++------- .../quantization/import_quant_stats_pass.cc | 4 ++-- .../lite/quantization/quantization_context.cc | 4 ++-- .../transforms/lower_static_tensor_list.cc | 4 ++-- .../tf_quantization_utils.h | 12 +++++------- .../tensorflow/passes/quantize.cc | 12 ++++++------ .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 15 ++++++--------- .../tpu_merge_variables_with_execute.cc | 7 ++++--- .../tensor_array_ops_decomposition.cc | 18 +++++++++--------- .../transforms/tpu_resource_partitioning.cc | 3 ++- tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc | 19 +++++++++---------- .../core/ir/utils/shape_inference_utils.cc | 8 ++++---- tensorflow/dtensor/mlir/shape_utils.cc | 6 +++--- tensorflow/dtensor/mlir/spmd_expansion.cc | 4 ++-- tensorflow/dtensor/mlir/value_utils.cc | 17 ++++++++--------- 15 files changed, 69 insertions(+), 76 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h index 4bac179aff06fb..11f7a9aff366f7 100644 --- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h @@ -504,7 +504,7 @@ class QuantizationPattern : public RewritePattern { inputs.reserve(quantizing_op->getNumOperands()); for (auto operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (isa(operand_type)) { inputs.push_back(operand); continue; } @@ -569,7 +569,7 @@ class QuantizationPattern : public RewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none // type results. - if (result_type.isa()) { + if (isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; @@ -649,11 +649,9 @@ class QuantizationPattern : public RewritePattern { } for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!quantizing_op->getResult(i) - .getType() - .cast() - .getElementType() - .isa()) { + if (!isa( + cast(quantizing_op->getResult(i).getType()) + .getElementType())) { continue; } CreateVerifier(quantizing_op, quantized_op, rewriter, i, diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 9347e96330203e..ecbf52d08ec4f3 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -106,8 +106,8 @@ class ImportQuantStatsPass if (index < 0 || index >= static_cast(op->getNumResults())) return false; Value res = op->getResult(index); - return res.getType().isa() && - res.getType().cast().getElementType().isa(); + return isa(res.getType()) && + isa(cast(res.getType()).getElementType()); } // A method to retrieve the name for the given op. 
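(Note: mlir/Support/LLVM.h pulls llvm::isa, llvm::cast, llvm::dyn_cast and
llvm::dyn_cast_or_null into the mlir namespace, which is why the
tpu_merge_variables_with_execute.cc and tpu_resource_partitioning.cc hunks
below add that include.)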
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc index 8682cba5cdc5a9..02d84f906551b5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc @@ -259,7 +259,7 @@ LogicalResult QuantizeContext::PropagateQuantParams( // Use the final state to set all the operands' parameters. for (int i = 0, e = op->getNumOperands(); i != e; ++i) { auto ele = op->getOperand(i).getType().cast().getElementType(); - if (ele.isa() && SetOperandParams(op, i, params)) { + if (isa(ele) && SetOperandParams(op, i, params)) { *changed |= true; new_items->push_back(op->getOperand(i).getDefiningOp()); } @@ -268,7 +268,7 @@ LogicalResult QuantizeContext::PropagateQuantParams( // Use the final state to set all the results' parameters. for (int res = 0, e = op->getNumResults(); res != e; ++res) { auto ele = op->getResult(res).getType().cast().getElementType(); - if (ele.isa() && SetResultParams(op, res, params)) { + if (isa(ele) && SetResultParams(op, res, params)) { auto users = op->getResult(res).getUsers(); *changed |= !users.empty(); new_items->append(users.begin(), users.end()); diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 2b5b7537f5154c..f2e6cbb5596dbe 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -1336,7 +1336,7 @@ llvm::DenseMap MapTensorListResultToArgument(func::FuncOp func) { break; } } - if (auto block_arg = parent.dyn_cast()) { + if (auto block_arg = dyn_cast(parent)) { return block_arg.getArgNumber(); } // Returns -1 if we don't find which this result maps to. @@ -1547,7 +1547,7 @@ void LowerStaticTensorListPass::runOnOperation() { // still. auto is_legal = [](Operation *op) { auto is_not_variant = [](Type ty) { - return !ty.cast().getElementType().isa(); + return !isa(cast(ty).getElementType()); }; return llvm::all_of(op->getOperandTypes(), is_not_variant) && llvm::all_of(op->getResultTypes(), is_not_variant); diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h index 926adebdab3764..37a83191857a1e 100644 --- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h @@ -504,7 +504,7 @@ class QuantizationPattern : public RewritePattern { inputs.reserve(quantizing_op->getNumOperands()); for (auto operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (isa(operand_type)) { inputs.push_back(operand); continue; } @@ -569,7 +569,7 @@ class QuantizationPattern : public RewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none // type results. 
- if (result_type.isa()) { + if (isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; @@ -649,11 +649,9 @@ class QuantizationPattern : public RewritePattern { } for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!quantizing_op->getResult(i) - .getType() - .cast() - .getElementType() - .isa()) { + if (!isa( + cast(quantizing_op->getResult(i).getType()) + .getElementType())) { continue; } CreateVerifier(quantizing_op, quantized_op, rewriter, i, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc index c18d76327ca844..045128fe349f68 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc @@ -221,7 +221,7 @@ class QuantizeSameScaleOpsPattern inputs.reserve(quantizing_op->getNumOperands()); for (const auto& operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (isa(operand_type)) { inputs.push_back(operand); continue; } @@ -253,7 +253,7 @@ class QuantizeSameScaleOpsPattern llvm::enumerate(quantizing_op->getResults())) { Value result = enumerated_result.value(); Type result_type = result.getType(); - if (result_type.isa()) { + if (isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; @@ -389,20 +389,20 @@ class QuantizeSameScaleOpsPattern bool has_quantized_types = false; for (Value input : call_op.getArgs()) { if (auto type = input.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { return false; } - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { has_quantized_types = true; } } } for (Value output : call_op.getOutput()) { if (auto type = output.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { return false; } - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { has_quantized_types = true; } } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 905f4864655a33..c8d4b4e1f48e39 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -98,7 +98,7 @@ namespace { // Returns the equivalent Value skipping through identity nodes. 
Value LookThroughIdentity(Value result) { while (isa_and_nonnull(result.getDefiningOp())) { - auto op_result = result.cast(); + auto op_result = cast(result); result = op_result.getOwner()->getOperand(op_result.getResultNumber()); } return result; @@ -809,7 +809,7 @@ OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { } return BuildConstRangeTensor(elem_type, num_elements, start_attr, delta_attr); - } else if (elem_type.isa()) { + } else if (isa(elem_type)) { auto start_attr = start_tensor.getValues()[0]; auto limit_attr = limit_tensor.getValues()[0]; auto delta_attr = delta_tensor.getValues()[0]; @@ -2418,7 +2418,7 @@ LogicalResult TPUExecuteAndUpdateVariablesOp::verify() { TPUExecuteAndUpdateVariablesOp op = *this; int num_resource_args = 0; for (Type arg_type : op.getArgs().getTypes()) - if (arg_type.cast().getElementType().isa()) + if (isa(cast(arg_type).getElementType())) ++num_resource_args; auto check_attr = [&](ArrayAttr indices, llvm::StringRef name, @@ -2457,11 +2457,8 @@ void TPUExecuteAndUpdateVariablesOp::getEffects( ResourceEffects::TPUExecute::get()); auto resource_handles = llvm::make_filter_range(getArgsMutable(), [](OpOperand &op_operand) { - return op_operand.get() - .getType() - .cast() - .getElementType() - .isa(); + return isa( + cast(op_operand.get().getType()).getElementType()); }); for (const auto& entry : llvm::enumerate(resource_handles)) { @@ -2858,7 +2855,7 @@ class ToBoolOfRankedTensor : public OpRewritePattern { Attribute zero_attr; if (element_type.isIntOrFloat()) zero_attr = rewriter.getZeroAttr(type); - else if (element_type.isa()) + else if (isa(element_type)) zero_attr = DenseStringElementsAttr::get(type, {""}); if (!zero_attr) return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc index 714fefaca8cded..01cc0c55e54a97 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -196,7 +197,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( // Check device matching for the node defining the resource. if (!IsResourceMergeable(resource_attr, device_attr)) continue; } else { - auto resource_arg = resource.dyn_cast(); + auto resource_arg = dyn_cast(resource); assert(resource_arg); if (resource_arg.getOwner() != &func.front()) continue; // Check device matching for the argument defining the resource. @@ -518,8 +519,8 @@ LogicalResult MergeForOneTPUExecute( // Check that all resources are either read or written to. 
for (auto it : llvm::enumerate(var_access_info.new_operand_values)) { Type type = it.value().getType(); - if (type.isa() && - type.cast().getElementType().isa()) { + if (isa(type) && + isa(cast(type).getElementType())) { if (!llvm::is_contained(device_var_reads_indices, it.index()) && !llvm::is_contained(device_var_updates_indices, it.index())) { return execute_launch.GetBody().front().emitError("operand #") diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index 47b046d9fdaee2..5e0c94dbec2e3b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -478,7 +478,7 @@ llvm::SmallDenseMap> AccessedGradients( llvm::SmallDenseMap> result; llvm::SmallDenseMap> result_sets; auto insert = [&](Value v, const string& source, const Block& func_block) { - auto arg = v.dyn_cast(); + auto arg = dyn_cast(v); if (!arg || arg.getOwner() != &func_block) return; auto insert_res = result_sets[arg.getArgNumber()].insert(source); if (!insert_res.second) return; @@ -594,7 +594,7 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, for (int64_t i = 0; i < while_op.getNumResults(); ++i) { if (!ta_arg_buffer_type(i)) continue; auto retval = old_body_ret->getOperand(i); - auto arg = retval.dyn_cast(); + auto arg = dyn_cast(retval); if (!arg) { return while_op.emitOpError( "output tensor array does not alias input in a while loop"); @@ -702,13 +702,13 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, if_op->getAttrs()); auto ret_forwards_input = [](func::FuncOp f, int64_t ret_ind) -> int64_t { auto retval = f.front().getTerminator()->getOperand(ret_ind); - auto arg = retval.dyn_cast(); + auto arg = dyn_cast(retval); if (!arg) return -1; return arg.getArgNumber(); }; for (int64_t i = 0; i < if_op.getNumResults(); ++i) { - if (!getElementTypeOrSelf(if_op.getResult(i).getType()) - .isa()) { + if (!isa( + getElementTypeOrSelf(if_op.getResult(i).getType()))) { if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i)); continue; } @@ -811,8 +811,8 @@ LogicalResult HandlePartitionedCallOp( } for (int64_t i = 0; i < call.getNumResults(); ++i) { auto ret = lowered_callee.front().getTerminator()->getOperand(i); - if (!getElementTypeOrSelf(ret.getType()).isa()) continue; - auto arg = ret.dyn_cast(); + if (!isa(getElementTypeOrSelf(ret.getType()))) continue; + auto arg = dyn_cast(ret); if (!arg) continue; info.ret_forward_input.emplace_back(i, arg.getArgNumber()); } @@ -842,7 +842,7 @@ LogicalResult HandleRegionControlFlowOps( llvm::StringMap* decomposed_partitioned_call_callees) { for (OpOperand& operand : op.getOpOperands()) { - if (getElementTypeOrSelf(operand.get().getType()).isa()) { + if (isa(getElementTypeOrSelf(operand.get().getType()))) { return op.emitOpError() << "found unexpected type " << operand.get().getType() << " of operand #" << operand.getOperandNumber() @@ -851,7 +851,7 @@ LogicalResult HandleRegionControlFlowOps( } } for (OpResult result : op.getResults()) { - if (getElementTypeOrSelf(result.getType()).isa()) { + if (isa(getElementTypeOrSelf(result.getType()))) { return op.emitOpError() << "found unexpected type " << result.getType() << " of result #" << result.getResultNumber() diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc 
b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc index fdacf313d30240..64286aa678a683 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" @@ -118,7 +119,7 @@ mlir::Attribute GetDeviceOfResource(mlir::func::FuncOp func, if (auto* resource_op = resource.getDefiningOp()) { return resource_op->getAttr(kDeviceAttr); } else { - const auto resource_arg = resource.dyn_cast_or_null(); + const auto resource_arg = dyn_cast_or_null(resource); if (resource_arg && (resource_arg.getOwner() == &(func.front()))) { return func.getArgAttrOfType( resource_arg.getArgNumber(), kFuncDeviceAttr); diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc index 6780328b8e8975..755412d5908997 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -111,8 +111,7 @@ class TFRInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type result_type, Location conversion_loc) const final { - if (!input.getType().isa() || - !result_type.isa()) { + if (!isa(input.getType()) || !isa(result_type)) { return nullptr; } auto input_itype = input.getType().cast(); @@ -250,7 +249,7 @@ LogicalResult TFRFuncOp::verify() { continue; } - if (!arg_type.isa()) { + if (!isa(arg_type)) { if (first_attr == -1) { first_attr = arg.index(); } @@ -423,7 +422,7 @@ class ConvertConstToTensorConst : public OpRewritePattern { DenseElementsAttr attr = DenseElementsAttr::get(new_out_type, array.getValue()); new_cst = rewriter.create(loc, new_out_type, attr); - if (out_type.isa()) { + if (isa(out_type)) { new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); } rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); @@ -434,7 +433,7 @@ class ConvertConstToTensorConst : public OpRewritePattern { if (matchPattern(cst_tensor_op.getArg(), m_Constant(&scalar))) { Type new_out_type = RankedTensorType::get({}, scalar.getType()); new_cst = rewriter.create(loc, new_out_type, scalar); - if (out_type.isa()) { + if (isa(out_type)) { new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); } rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); @@ -447,7 +446,7 @@ class ConvertConstToTensorConst : public OpRewritePattern { inline bool isQuantizedType(Type type) { auto tensor_type = type.dyn_cast(); return (tensor_type && - tensor_type.getElementType().isa()); + isa(tensor_type.getElementType())); } class RemoveRedundantCast : public OpRewritePattern { @@ -493,7 +492,7 @@ class RemoveRedundantCast : public OpRewritePattern { // If the two types are the same, the back-to-back tfr.cast ops can be // removed. 
- if (input_type == output_type || output_type.isa()) { + if (input_type == output_type || isa(output_type)) { rewriter.replaceOp(cast_op, {input}); return success(); } @@ -501,7 +500,7 @@ class RemoveRedundantCast : public OpRewritePattern { // If the rank of the input tensor isn't ranked, we replace the pair // with tf.EnsureShape op so it can be removed after shape inference or // confirmed at runtime. - if (input_type.isa()) { + if (isa(input_type)) { auto shape = output_type.cast().getShape(); auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape); rewriter.replaceOpWithNewOp(cast_op, output_type, @@ -548,7 +547,7 @@ class RemoveRedundantGetElement : public OpRewritePattern { Value input = preceding_build_list.getOperand(index.getInt()); Type output_type = ge_op.getType(); if (input.getType() != output_type && - !output_type.isa()) { + !isa(output_type)) { return failure(); } rewriter.replaceOp(ge_op, {input}); @@ -995,7 +994,7 @@ Type TFRDialect::parseType(DialectAsmParser &parser) const { void TFRDialect::printType(Type type, DialectAsmPrinter &os) const { llvm::ArrayRef attrs; - if (type.isa()) { + if (isa(type)) { os << "attr"; return; } diff --git a/tensorflow/core/ir/utils/shape_inference_utils.cc b/tensorflow/core/ir/utils/shape_inference_utils.cc index 753ad1450b8a9e..e7227985360b19 100644 --- a/tensorflow/core/ir/utils/shape_inference_utils.cc +++ b/tensorflow/core/ir/utils/shape_inference_utils.cc @@ -106,7 +106,7 @@ std::optional GetShapeFromMlirType(Type t) { // Extracts a PartialTensorShape from the MLIR attr. std::optional GetShapeFromMlirAttr(Value v) { // Function arguments may have shape attr to describe its output shape. - if (auto arg = v.dyn_cast()) { + if (auto arg = dyn_cast(v)) { Operation* parent_op = arg.getOwner()->getParentOp(); if (auto func_op = llvm::dyn_cast(parent_op)) { int arg_idx = arg.getArgNumber(); @@ -336,7 +336,7 @@ LogicalResult InferReturnTypeComponentsForTFOp( if (c.requested_input_tensor_as_partial_shape(input) && !input_tensors[input] && !input_tensors_as_shapes[input].Handle()) { VLOG(4) << "Requesting " << input << " as shape\n"; - auto op_result = op->getOperand(input).dyn_cast(); + auto op_result = dyn_cast(op->getOperand(input)); if (!op_result) continue; // Resize on first valid shape computed. auto handle = op_result_as_shape_fn(c, op_result); @@ -370,7 +370,7 @@ LogicalResult InferReturnTypeComponentsForTFOp( Type new_element_type = result_element_type_fn(output); // Populate the handle shapes for a resource/variant. 
if (new_element_type && - new_element_type.isa()) { + isa(new_element_type)) { auto handle_shapes_types = c.output_handle_shapes_and_types(output); if (handle_shapes_types) { SmallVector subtypes; @@ -382,7 +382,7 @@ LogicalResult InferReturnTypeComponentsForTFOp( subtypes.push_back( CreateTensorType(c, shape_n_type.shape, element_type)); } - if (new_element_type.isa()) { + if (isa(new_element_type)) { new_element_type = tf_type::ResourceType::get(subtypes, op->getContext()); } else { diff --git a/tensorflow/dtensor/mlir/shape_utils.cc b/tensorflow/dtensor/mlir/shape_utils.cc index 0864fe28ba074a..787f1b7b149123 100644 --- a/tensorflow/dtensor/mlir/shape_utils.cc +++ b/tensorflow/dtensor/mlir/shape_utils.cc @@ -73,7 +73,7 @@ StatusOr> ExtractGlobalInputShape( return errors::Internal("global_shape does not have static rank"); return *global_shape; } - return ExtractGlobalOutputShape(input_value.get().cast()); + return ExtractGlobalOutputShape(cast(input_value.get())); } // If we reach this point, we're working with a function argument. @@ -85,7 +85,7 @@ StatusOr> ExtractGlobalInputShape( operand_index, op->getName()) .str()); - auto block_arg = input_value.get().dyn_cast(); + auto block_arg = mlir::dyn_cast(input_value.get()); auto global_shape_attr = enclosing_function.getArgAttrOfType( block_arg.getArgNumber(), kGlobalShapeDialectAttr); @@ -303,7 +303,7 @@ StatusOr> GetShapeOfValue(const mlir::Value& value, StatusOr> GetGlobalShapeOfValueFromDTensorLayout( const mlir::Value& value) { - if (value.isa() && + if (mlir::isa(value) && mlir::isa(value.getDefiningOp())) { auto layout_op = mlir::cast(value.getDefiningOp()); if (layout_op.getGlobalShape()) return layout_op.getGlobalShape().value(); diff --git a/tensorflow/dtensor/mlir/spmd_expansion.cc b/tensorflow/dtensor/mlir/spmd_expansion.cc index ff7e1444520af0..3fa9b115087bb9 100644 --- a/tensorflow/dtensor/mlir/spmd_expansion.cc +++ b/tensorflow/dtensor/mlir/spmd_expansion.cc @@ -190,7 +190,7 @@ bool GetResourceArgIndexIfUsedInAssignmentOp( GetForwardedDTensorLayoutInput(assign_variable_op.getResource()); if (llvm::isa(resource)) { *resource_argument_index_for_assign_variable = - resource.cast().getArgNumber(); + cast(resource).getArgNumber(); return true; } } @@ -223,7 +223,7 @@ mlir::LogicalResult UpdateFunctionArgsUsingLayout(mlir::func::FuncOp function) { // If argument is a resource type update the subtype shape information // to reflect local shape of resources. 
- if (arg_type.isa()) { + if (isa(arg_type)) { if (mlir::failed(UpdateResourceArgumentType(argument_index, function))) return mlir::failure(); continue; diff --git a/tensorflow/dtensor/mlir/value_utils.cc b/tensorflow/dtensor/mlir/value_utils.cc index aff45541759515..4fae043b60106f 100644 --- a/tensorflow/dtensor/mlir/value_utils.cc +++ b/tensorflow/dtensor/mlir/value_utils.cc @@ -57,7 +57,8 @@ mlir::Value GetForwardedInput(mlir::Value value) { bool value_updated; do { value_updated = false; - if (mlir::BlockArgument argument = value.dyn_cast()) { + if (mlir::BlockArgument argument = + mlir::dyn_cast(value)) { mlir::Region* region = argument.getParentRegion(); if (region == nullptr) break; mlir::Operation* parent_op = region->getParentOp(); @@ -176,7 +177,7 @@ mlir::Value IntConstWithMatchingType(mlir::OpBuilder& builder, StatusOr ExtractConstIntFromValue(mlir::Value value) { value = GetForwardedInput(value); - if (value.isa()) + if (mlir::isa(value)) return errors::Internal("unable get constant value from block argument"); mlir::DenseIntElementsAttr attr; if (!matchPattern(value, m_Constant(&attr))) { @@ -195,7 +196,7 @@ StatusOr ExtractConstIntFromValue(mlir::Value value) { absl::Status ExtractConstVectorFromValue( mlir::Value value, llvm::SmallVector* out_vector) { value = GetForwardedInput(value); - if (value.isa()) + if (mlir::isa(value)) return errors::Internal("unable get constant value from block argument"); mlir::DenseIntElementsAttr attr; if (!matchPattern(value, m_Constant(&attr))) { @@ -289,8 +290,8 @@ StatusOr SelectScalarValueFromArray(mlir::OpBuilder& builder, mlir::Type GetSubtypeOrSelf(mlir::Value val) { mlir::Type type = val.getType(); if (auto type_with_subtype = - mlir::getElementTypeOrSelf(val) - .dyn_cast()) { + mlir::dyn_cast( + mlir::getElementTypeOrSelf(val))) { if (type_with_subtype.GetSubtypes().size() == 1) { type = type_with_subtype.GetSubtypes().front(); } @@ -299,10 +300,8 @@ mlir::Type GetSubtypeOrSelf(mlir::Value val) { } bool IsResourceType(mlir::Value val) { - return val.getType() - .cast() - .getElementType() - .isa(); + return mlir::isa( + mlir::cast(val.getType()).getElementType()); } } // namespace dtensor From ac32c1652755de306d3de6880f047c2d0ae5e1b8 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Thu, 17 Apr 2025 01:40:50 -0700 Subject: [PATCH 0909/1324] Delegate 4 bit statically quantized FC PiperOrigin-RevId: 748593224 --- tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc | 6 +++++- tensorflow/lite/tools/cmake/modules/xnnpack.cmake | 2 +- tensorflow/workspace2.bzl | 6 +++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index d53dd1fd530b90..c665ced07237be 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -4499,8 +4499,12 @@ class Subgraph { (input_tensor.type == kTfLiteFloat32 && (filter_tensor.type == kTfLiteInt4 || filter_tensor.type == kTfLiteInt8))); + bool supported_srq = (input_tensor.type == kTfLiteInt8 && + (filter_tensor.type == kTfLiteInt4 || + filter_tensor.type == kTfLiteInt8)); if (input_tensor.type != output_tensor.type || - ((input_tensor.type != filter_tensor.type) && !dynamically_quantized)) { + ((input_tensor.type != filter_tensor.type) && + !(dynamically_quantized || supported_srq))) { TF_LITE_MAYBE_KERNEL_LOG( logging_context, "unsupported mixed types in FULLY_CONNECTED operator #%d", diff --git 
a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index 7dbda8f999e4d3..dee999559abb80 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG 8a2f5f441833b80806b58b5d704ec8335634182c + GIT_TAG a6654be72590651ebc1aec686cd41ac9c2461f59 GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index c908483fecc8fe..8628a185e9c257 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -160,9 +160,9 @@ def _tf_repositories(): # LINT.IfChange(xnnpack) tf_http_archive( name = "XNNPACK", - sha256 = "f25179a30775d9918670fb5fb07cd8e80c2ae0a8f4ec450a6d6c496d159ba66b", - strip_prefix = "XNNPACK-ece21c589be842fbeaee297b0d668194d6f3a35b", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/ece21c589be842fbeaee297b0d668194d6f3a35b.zip"), + sha256 = "7adcf80f9a10ad3eaf49777a97d16201f341970326cd52211391e2c93c65e3f4", + strip_prefix = "XNNPACK-a6654be72590651ebc1aec686cd41ac9c2461f59", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/a6654be72590651ebc1aec686cd41ac9c2461f59.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) From 3a0e9e1bc558e9d4bb5f97d247c1d32e270f1629 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 17 Apr 2025 01:49:01 -0700 Subject: [PATCH 0910/1324] [XLA:GPU] Adapt LayoutNormalization pass to skip already normalized ops. If we process every op, we needlessly add two extra HloInstructions to the HLO graph for every existing one. Given that replaced instructions are not deleted right away, that triples the amount of memory needed. We can improve this by skipping instructions that are already normalized. PiperOrigin-RevId: 748595524 --- third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h | 5 ++ third_party/xla/xla/hlo/ir/hlo_instruction.cc | 14 ++-- third_party/xla/xla/service/BUILD | 4 +- .../xla/xla/service/layout_normalization.cc | 82 ++++++++++++------- .../xla/service/layout_normalization_test.cc | 54 ++++++------ 5 files changed, 100 insertions(+), 59 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h index 37ca85f126a405..577ce6aaff62cc 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h @@ -418,6 +418,11 @@ class DfsHloVisitorBase { // own postprocessing. virtual absl::Status Postprocess(HloInstructionPtr hlo); + // This method should be overridden by subclasses that wish to skip some ops + // while traversing the HLO graph. If this method returns false, the calls to + // Preprocess(op), Handle/OpType/(op) and Postprocess(op) are skipped.
+ virtual bool ShouldProcessNode(HloInstructionPtr hlo) { return true; } + private: absl::flat_hash_map visit_state_; diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index e258a26dc22c19..7c96e1bbdd4372 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -4765,11 +4765,15 @@ static absl::Status PostOrderDFS( if (visit_state == Visitor::kVisiting) { dfs_stack.pop_back(); - TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); - VLOG(2) << "Visiting HLO %" << current_node->name(); - TF_RETURN_IF_ERROR(current_node->Visit(visitor)); - visitor->SetVisitState(current_id, Visitor::kVisited); - TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); + if (visitor->ShouldProcessNode(current_node)) { + TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); + VLOG(2) << "Visiting HLO %" << current_node->name(); + TF_RETURN_IF_ERROR(current_node->Visit(visitor)); + visitor->SetVisitState(current_id, Visitor::kVisited); + TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); + } else { + visitor->SetVisitState(current_id, Visitor::kVisited); + } continue; } diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 413a03404eb67d..18e05363012f9a 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -5312,6 +5312,8 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -5319,8 +5321,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/layout_normalization.cc b/third_party/xla/xla/service/layout_normalization.cc index 3a593f40b38f8a..22d52013f96615 100644 --- a/third_party/xla/xla/service/layout_normalization.cc +++ b/third_party/xla/xla/service/layout_normalization.cc @@ -41,10 +41,10 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -54,11 +54,15 @@ namespace { // applied to the shape itself). // // Local precondition for every call: -// -> Input is a bitcast from a normalized layout. +// -> Input either already has a normalized layout, or is a bitcast from a +// normalized layout. // // Local postcondition: // -> Input and output of a processed operation have descending layout* // +// Instructions that already have a normalized layout for operands and output +// are skipped. +// // *: For current fusion limitations this is currently not applicable to // unnested reductions only. 
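// For illustration, a minimal sketch of the new hook in use: a hypothetical
// visitor subclass (not part of this patch) that skips nodes which already
// have the default descending layout. It assumes only the DfsHloVisitorBase
// interface added above plus the existing LayoutUtil and Shape helpers.
//
//   class SkipDefaultLayoutVisitor : public DfsHloRewriteVisitor {
//    public:
//     bool ShouldProcessNode(HloInstruction* hlo) override {
//       // Returning false skips Preprocess/Handle.../Postprocess for `hlo`;
//       // PostOrderDFS still marks the node as visited, so the traversal
//       // continues past it.
//       return !(hlo->shape().IsArray() &&
//                LayoutUtil::IsMonotonicWithDim0Major(hlo->shape().layout()));
//     }
//   };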
class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { @@ -68,6 +72,17 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { const CustomCallTransformer& custom_call_transformer = nullptr) : normalization_(normalization), custom_call_transformer_(custom_call_transformer) {} + bool ShouldProcessNode(HloInstruction* hlo) override { + // Skip `hlo` if it already has a default layout and the operands have a + // default layout as well. + if (hlo->shape().IsArray() && HasDefaultLayout(hlo) && + absl::c_all_of(hlo->operands(), [this](HloInstruction* operand) { + return HasDefaultLayout(operand); + })) { + return false; + } + return true; + } // To handle a constant, just give the literal data a new layout. absl::Status HandleConstant(HloInstruction* hlo) override { @@ -77,7 +92,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - const Shape& shape = hlo->shape(); + Shape shape = hlo->shape(); Shape normalized_shape = Normalize(shape); *literal.mutable_shape_do_not_use() = normalized_shape; // Ensure element_size_in_bits of literal is 0, because literals do not @@ -86,8 +101,8 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { ->mutable_layout() ->set_element_size_in_bits(0); - HloInstruction* bc_to_orig = MakeBitcastHlo(hlo, shape); *hlo->mutable_shape() = normalized_shape; + HloInstruction* bc_to_orig = MaybeBitcast(hlo, shape); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape(bc_to_orig)); MarkAsChanged(); return absl::OkStatus(); @@ -120,7 +135,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { *normalized_slice->mutable_shape()->mutable_layout() = normalized_input->shape().layout(); SetVisited(*normalized_slice); - HloInstruction* bc_to_orig = MakeBitcastHlo(normalized_slice, s); + HloInstruction* bc_to_orig = MaybeBitcast(normalized_slice, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -144,9 +159,9 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { } auto normalized_shape = Normalize(shape); - auto bc_to_normalized = MakeBitcastHlo(hlo, normalized_shape); + auto bc_to_normalized = MaybeBitcast(hlo, normalized_shape); SetVisited(*bc_to_normalized); - auto bc_to_orig = MakeBitcastHlo(bc_to_normalized, shape); + auto bc_to_orig = MaybeBitcast(bc_to_normalized, shape); TF_RETURN_IF_ERROR(hlo->ReplaceUsesWith(users, bc_to_orig)); MarkAsChanged(); return absl::OkStatus(); @@ -173,7 +188,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { hlo->AddInstruction(HloInstruction::CreateConcatenate( normalized_shape, normalized_inputs, normalized_concat_dim)); SetVisited(*normalized_concat); - auto bc_to_orig = MakeBitcastHlo(normalized_concat, hlo->shape()); + auto bc_to_orig = MaybeBitcast(normalized_concat, hlo->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -211,7 +226,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { normalization_->UpdateLayout(rw->mutable_shape()); SetVisited(*rw); - HloInstruction* bc_to_orig = MakeBitcastHlo(rw, hlo->shape()); + HloInstruction* bc_to_orig = MaybeBitcast(rw, hlo->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -248,7 +263,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { normalized_input, br_dimensions, normalized_shape, &hlo->metadata()); SetVisited(*normalized_broadcast); VLOG(3) << "Generated broadcast: " << 
normalized_broadcast->ToString(); - auto bc_to_orig = MakeBitcastHlo(normalized_broadcast, s); + auto bc_to_orig = MaybeBitcast(normalized_broadcast, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -265,7 +280,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { HloInstruction::CreateIota(normalized_shape, new_iota_dimension)); SetVisited(*normalized_iota); VLOG(3) << "Generated iota: " << normalized_iota->ToString(); - auto bc_to_orig = MakeBitcastHlo(normalized_iota, s); + auto bc_to_orig = MaybeBitcast(normalized_iota, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -325,7 +340,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { // 'normalized_input' is already marked as visited. SetVisited(*new_unary); } - auto bc_to_orig = MakeBitcastHlo(new_unary, s); + auto bc_to_orig = MaybeBitcast(new_unary, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -367,7 +382,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { new_binary, MakeBinaryHlo(hlo->opcode(), a0, b0, &hlo->metadata())); } SetVisited(*new_binary); - auto bc_to_orig = MakeBitcastHlo(new_binary, s); + auto bc_to_orig = MaybeBitcast(new_binary, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -390,7 +405,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { TF_ASSIGN_OR_RETURN(auto new_reshape, MakeReshapeHlo(normalized_reshape_s, a0)); SetVisited(*new_reshape); - auto bc_to_orig = MakeBitcastHlo(new_reshape, s); + auto bc_to_orig = MaybeBitcast(new_reshape, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -511,7 +526,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { normalized_updates, scatter->to_apply(), normalized_dims, scatter->indices_are_sorted(), scatter->unique_indices())); SetVisited(*normalized_scatter); - auto bc_to_orig = MakeBitcastHlo(normalized_scatter, scatter->shape()); + auto bc_to_orig = MaybeBitcast(normalized_scatter, scatter->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(scatter, bc_to_orig)); return absl::OkStatus(); } @@ -560,7 +575,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { << normalized_transpose->ToString(); } - auto bc_to_orig = MakeBitcastHlo(normalized_transpose, s); + auto bc_to_orig = MaybeBitcast(normalized_transpose, s); return ReplaceInstruction(hlo, bc_to_orig); } @@ -589,7 +604,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto t = hlo->AddInstruction( HloInstruction::CreateTranspose(s_normalized, a0, dimensions)); SetVisited(*t); - auto bc_to_orig = MakeBitcastHlo(t, s); + auto bc_to_orig = MaybeBitcast(t, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -611,7 +626,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto normalized_reverse = hlo->AddInstruction( HloInstruction::CreateReverse(a0->shape(), a0, new_dimensions)); SetVisited(*normalized_reverse); - auto bc_to_orig = MakeBitcastHlo(normalized_reverse, s); + auto bc_to_orig = MaybeBitcast(normalized_reverse, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -644,7 +659,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto padded_normalized = hlo->AddInstruction(HloInstruction::CreatePad( s_normalized, normalized_input, padded_by, new_padding)); 
SetVisited(*padded_normalized); - auto bc_to_orig = MakeBitcastHlo(padded_normalized, s); + auto bc_to_orig = MaybeBitcast(padded_normalized, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -699,7 +714,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { *normalized_dynamic_slice->mutable_shape()->mutable_layout() = normalized_input->shape().layout(); SetVisited(*normalized_dynamic_slice); - HloInstruction* bc_to_orig = MakeBitcastHlo(normalized_dynamic_slice, s); + HloInstruction* bc_to_orig = MaybeBitcast(normalized_dynamic_slice, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -727,7 +742,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { *new_dus->mutable_shape()->mutable_layout() = new_operand->shape().layout(); SetVisited(*new_dus); - HloInstruction* bc_to_orig = MakeBitcastHlo(new_dus, s); + HloInstruction* bc_to_orig = MaybeBitcast(new_dus, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); @@ -771,7 +786,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { hlo->SetupDerivedInstruction(normalized); SetVisited(*normalized); - HloInstruction* bc_to_orig = MakeBitcastHlo(normalized, s); + HloInstruction* bc_to_orig = MaybeBitcast(normalized, s); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -798,15 +813,15 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { } // Due to Local Precondition we have, the input to all processed ops should - // be HLO in descending layout piped through bitcast. + // be HLO in descending layout (possibly piped through bitcast). absl::StatusOr GetNormalizedInput(HloInstruction* hlo) { + if (HasDefaultLayout(hlo)) { + return hlo; + } TF_RET_CHECK(hlo->opcode() == HloOpcode::kBitcast) << "Unexpected HLO input: " << hlo->ToString(); auto input = hlo->mutable_operand(0); - auto input_shape = input->shape(); - TF_RET_CHECK(Layout::Equal().IgnoreElementSize()( - input_shape.layout(), - LayoutUtil::GetDefaultLayoutForShape(input_shape))); + TF_RET_CHECK(HasDefaultLayout(input)); return input; } @@ -815,6 +830,17 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { return ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(s); } + bool HasDefaultLayout(HloInstruction* hlo) { + return hlo->shape().IsArray() && + LayoutUtil::IsMonotonicWithDim0Major(hlo->shape().layout()); + } + + HloInstruction* MaybeBitcast(HloInstruction* hlo, + const Shape& original_shape) { + return hlo->shape() == original_shape ? 
hlo + : MakeBitcastHlo(hlo, original_shape); + } + LayoutNormalization* normalization_; CustomCallTransformer custom_call_transformer_; }; diff --git a/third_party/xla/xla/service/layout_normalization_test.cc b/third_party/xla/xla/service/layout_normalization_test.cc index 7611ce70e78db1..67ced7fc1d731d 100644 --- a/third_party/xla/xla/service/layout_normalization_test.cc +++ b/third_party/xla/xla/service/layout_normalization_test.cc @@ -55,6 +55,19 @@ ENTRY main { )"); } +TEST_F(LayoutNormalizationTest, + TestInstructionsWithNormalizedLayoutAreSkipped) { + const char* hlo = R"( +HloModule module + +ENTRY main { + p = f32[5,4]{1,0} parameter(0) + ROOT o = f32[5,4]{1,0} abs(p) +} +)"; + CheckLayoutNormalization(hlo, /*expected=*/std::nullopt); +} + TEST_F(LayoutNormalizationTest, TestUnary) { const char* hlo = R"( HloModule module @@ -138,14 +151,13 @@ HloModule module ENTRY main { a = f32[5,4]{1,0} parameter(0) t = f32[4,5]{0,1} transpose(a), dimensions={1,0} - ROOT out = abs(t) + ROOT out = f32[4,5]{0,1} abs(t) } )"; CheckLayoutNormalization(hlo, R"( // CHECK: [[a_0:%[^ ]+]] = f32[5,4]{1,0} parameter(0) -// CHECK: [[bitcast_1:%[^ ]+]] = f32[5,4]{1,0} bitcast([[a_0]]) -// CHECK: [[abs_2:%[^ ]+]] = f32[5,4]{1,0} abs([[bitcast_1]]) +// CHECK: [[abs_2:%[^ ]+]] = f32[5,4]{1,0} abs([[a_0]]) // CHECK: ROOT [[bitcast_3_3:%[^ ]+]] = f32[4,5]{0,1} bitcast([[abs_2]]) )"); } @@ -265,14 +277,13 @@ HloModule module ENTRY main { a = f32[2,3]{1,0} parameter(0) b = f32[2,4,3]{1,2,0} broadcast(a), dimensions={0,2} - ROOT out = abs(b) + ROOT out = f32[2,4,3]{1,2,0} abs(b) } )"; CheckLayoutNormalization(hlo, R"( // CHECK: [[a_0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -// CHECK: [[bitcast_1:%[^ ]+]] = f32[2,3]{1,0} bitcast([[a_0]]) -// CHECK: [[broadcast_2:%[^ ]+]] = f32[2,3,4]{2,1,0} broadcast([[bitcast_1]]), dimensions={0,1} +// CHECK: [[broadcast_2:%[^ ]+]] = f32[2,3,4]{2,1,0} broadcast([[a_0]]), dimensions={0,1} // CHECK: [[abs_3:%[^ ]+]] = f32[2,3,4]{2,1,0} abs([[broadcast_2]]) // CHECK: ROOT [[bitcast_3_4:%[^ ]+]] = f32[2,4,3]{1,2,0} bitcast([[abs_3]]) )"); @@ -285,17 +296,11 @@ HloModule module ENTRY main { a = f32[2,3]{1,0} parameter(0) b = f32[3,4,2]{2,1,0} broadcast(a), dimensions={2,0} - ROOT out = abs(b) + ROOT out = f32[3,4,2]{2,1,0} abs(b) } )"; - CheckLayoutNormalization(hlo, R"( -// CHECK: [[a_0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -// CHECK: [[bitcast_1:%[^ ]+]] = f32[2,3]{1,0} bitcast([[a_0]]) -// CHECK: [[broadcast_2:%[^ ]+]] = f32[3,4,2]{2,1,0} broadcast([[bitcast_1]]), dimensions={2,0} -// CHECK: [[abs_3:%[^ ]+]] = f32[3,4,2]{2,1,0} abs([[broadcast_2]]) -// CHECK: ROOT [[bitcast_3_4:%[^ ]+]] = f32[3,4,2]{2,1,0} bitcast([[abs_3]]) -)"); + CheckLayoutNormalization(hlo, std::nullopt); } TEST_F(LayoutNormalizationTest, BroadcastCustomOutputLayoutWithDegenerate) { @@ -305,13 +310,13 @@ HloModule module ENTRY main { a = f32[9]{0} parameter(0) b = f32[2,1,4,9]{2,0,1,3} broadcast(a), dimensions={3} - ROOT out = abs(b) + ROOT out = f32[2,1,4,9]{2,0,1,3} abs(b) } )"; CheckLayoutNormalization(hlo, R"( -// CHECK: [[bitcast_0:%[^ ]+]] = f32[9]{0} bitcast([[a_1:%[^ ]+]]) -// CHECK: [[broadcast_2:%[^ ]+]] = f32[9,1,2,4]{3,2,1,0} broadcast([[bitcast_0]]), dimensions={0} +// CHECK: [[a:%[^ ]+]] = f32[9]{0} parameter(0) +// CHECK: [[broadcast_2:%[^ ]+]] = f32[9,1,2,4]{3,2,1,0} broadcast([[a]]), dimensions={0} // CHECK: [[abs_3:%[^ ]+]] = f32[9,1,2,4]{3,2,1,0} abs([[broadcast_2]]) // CHECK: ROOT [[bitcast_3_4:%[^ ]+]] = f32[2,1,4,9]{2,0,1,3} bitcast([[abs_3]]) )"); @@ -816,14 +821,15 @@ 
TEST_F(LayoutNormalizationTest, BitcastConvertToSmallerType) { HloModule m ENTRY main { - p0 = u64[4]{0} parameter(0) - ROOT out = u32[4,2]{0,1} bitcast-convert(u64[4]{0} p0), metadata={op_name="test"} + p0 = u64[3,4]{0,1} parameter(0) + bc_convert = u32[3,4,2]{1,0,2} bitcast-convert(p0), metadata={op_name="test"} + ROOT out = u32[3,4,2]{1,0,2} reverse(bc_convert), dimensions={0} } )"; CheckLayoutNormalization(hlo, R"( // CHECK: bitcast-convert({{.*}}), metadata={op_name="test"} -)"); + )"); } TEST_F(LayoutNormalizationTest, Scatter) { @@ -927,14 +933,14 @@ TEST_F(LayoutNormalizationTest, CompareInt4) { HloModule module ENTRY main { - a = s4[10]{0:E(4)} parameter(0) - b = s4[10]{0:E(4)} parameter(1) - ROOT out = compare(a, b), direction=EQ + a = s4[10,11]{0,1:E(4)} parameter(0) + b = s4[10,11]{0,1:E(4)} parameter(1) + ROOT out = pred[10,11]{0,1} compare(a, b), direction=EQ } )"; CheckLayoutNormalization(hlo, R"( -// CHECK: pred[10]{0} compare({{.*}}) +// CHECK: pred[11,10]{1,0} compare({{.*}}) )"); } From 246373a4138fb3a9e165afe6d027e0202344ced9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Apr 2025 02:03:00 -0700 Subject: [PATCH 0911/1324] Automated Code Change PiperOrigin-RevId: 748599417 --- third_party/xla/xla/tools/BUILD | 7 +++++++ third_party/xla/xla/tools/collective_perf_table_gen.cc | 1 + third_party/xla/xla/tools/compute_xspace_stats.cc | 1 - third_party/xla/xla/tools/compute_xspace_stats_test.cc | 1 + third_party/xla/xla/tools/prepare_reference_module.cc | 2 ++ third_party/xla/xla/tools/prepare_reference_module.h | 1 + third_party/xla/xla/tools/run_hlo_module.cc | 3 +++ third_party/xla/xla/tools/run_hlo_module_test.cc | 1 + third_party/xla/xla/tools/xla_compile_lib.cc | 2 ++ third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc | 1 + 10 files changed, 19 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index d38c371fe4f096..c51dd9ca1f40f0 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -400,6 +400,7 @@ cc_library( "//xla/service:hlo_module_config", "//xla/service:hlo_runner_interface", "//xla/stream_executor:platform", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", @@ -478,6 +479,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -713,6 +715,7 @@ cc_library( hdrs = ["collective_perf_table_gen.h"], deps = [ "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_query", @@ -789,6 +792,7 @@ xla_cc_test( ":compute_xspace_stats", "//xla/tsl/platform:statusor", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -906,6 +910,7 @@ tsl_gpu_library( "//xla:debug_options_flags", "//xla:shape_util", "//xla:util", + "//xla:xla_proto_cc", "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", @@ -923,6 +928,7 @@ tsl_gpu_library( "//xla/service/cpu:cpu_executable", "//xla/service/gpu:gpu_symbol_repository", "//xla/service/gpu/autotuning:autotuner_util", + "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform_manager", 
"//xla/stream_executor:stream_executor_h", @@ -1002,6 +1008,7 @@ xla_test( deps = [ ":xla_compile_lib", "//xla:util", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:platform_util", "//xla/service:symbol_repository", diff --git a/third_party/xla/xla/tools/collective_perf_table_gen.cc b/third_party/xla/xla/tools/collective_perf_table_gen.cc index 278e6f9ed7ce99..4cbfe8a5ca70db 100644 --- a/third_party/xla/xla/tools/collective_perf_table_gen.cc +++ b/third_party/xla/xla/tools/collective_perf_table_gen.cc @@ -47,6 +47,7 @@ limitations under the License. #include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" namespace xla::gpu { diff --git a/third_party/xla/xla/tools/compute_xspace_stats.cc b/third_party/xla/xla/tools/compute_xspace_stats.cc index 878a88537598e3..fe0a3961ae524e 100644 --- a/third_party/xla/xla/tools/compute_xspace_stats.cc +++ b/third_party/xla/xla/tools/compute_xspace_stats.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include "absl/container/flat_hash_map.h" diff --git a/third_party/xla/xla/tools/compute_xspace_stats_test.cc b/third_party/xla/xla/tools/compute_xspace_stats_test.cc index 5b7a30f2363d12..0bc522beaa3dec 100644 --- a/third_party/xla/xla/tools/compute_xspace_stats_test.cc +++ b/third_party/xla/xla/tools/compute_xspace_stats_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "xla/tsl/platform/statusor.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace xla::gpu { namespace { diff --git a/third_party/xla/xla/tools/prepare_reference_module.cc b/third_party/xla/xla/tools/prepare_reference_module.cc index 73b7307d9e33c8..6fd7ac111ec796 100644 --- a/third_party/xla/xla/tools/prepare_reference_module.cc +++ b/third_party/xla/xla/tools/prepare_reference_module.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/third_party/xla/xla/tools/prepare_reference_module.h b/third_party/xla/xla/tools/prepare_reference_module.h index f26e84745b40d3..3b52ea818523a1 100644 --- a/third_party/xla/xla/tools/prepare_reference_module.h +++ b/third_party/xla/xla/tools/prepare_reference_module.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_module_config.h" diff --git a/third_party/xla/xla/tools/run_hlo_module.cc b/third_party/xla/xla/tools/run_hlo_module.cc index 22c0c02cafde9b..fa7a86cecf0d60 100644 --- a/third_party/xla/xla/tools/run_hlo_module.cc +++ b/third_party/xla/xla/tools/run_hlo_module.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/tools/run_hlo_module.h" +#include +#include #include #include #include @@ -32,6 +34,7 @@ limitations under the License. 
#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" diff --git a/third_party/xla/xla/tools/run_hlo_module_test.cc b/third_party/xla/xla/tools/run_hlo_module_test.cc index 255563a5893657..7177eee5b019ea 100644 --- a/third_party/xla/xla/tools/run_hlo_module_test.cc +++ b/third_party/xla/xla/tools/run_hlo_module_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tools/run_hlo_module.pb.h" diff --git a/third_party/xla/xla/tools/xla_compile_lib.cc b/third_party/xla/xla/tools/xla_compile_lib.cc index e6fce7e44a8fb9..0a0670009fdf7e 100644 --- a/third_party/xla/xla/tools/xla_compile_lib.cc +++ b/third_party/xla/xla/tools/xla_compile_lib.cc @@ -56,12 +56,14 @@ limitations under the License. #include "xla/service/symbol_repository.h" #include "xla/service/xla_compile_result.pb.h" #include "xla/shape.h" +#include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tools/hlo_module_loader.h" #include "xla/tsl/platform/env.h" #include "xla/util.h" +#include "xla/xla.pb.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" diff --git a/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc index 83861fa062a61f..78c3407d365ae4 100644 --- a/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/tsl/protobuf/error_codes.pb.h" #include "xla/tsl/protobuf/status.pb.h" #include "xla/util.h" +#include "xla/xla.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" From 2a9125d48878d08a79b7e50ee6b935f341d7d5c0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Apr 2025 02:04:13 -0700 Subject: [PATCH 0912/1324] Update GraphDef version to 2200. PiperOrigin-RevId: 748599901 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index e25c61fb6f0a5e..b90d763e50a58e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2199 // Updated: 2025/4/16 +#define TF_GRAPH_DEF_VERSION 2200 // Updated: 2025/4/17 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 88c6471d9e835cafc548acf04b556a6fe3253b34 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 17 Apr 2025 02:04:18 -0700 Subject: [PATCH 0913/1324] compat: Update forward compatibility horizon to 2025-04-17 PiperOrigin-RevId: 748599925 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 2ecb50367ed8df..6e640f8cfa609e 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 16) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 4, 17) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 51700db32245c7e810bf5e19e98c6f5ca94a9076 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 17 Apr 2025 02:25:49 -0700 Subject: [PATCH 0914/1324] Fix mlir cast/dyn_cast/isa in tensorflow Use llvm::cast/dyn_cast/isa since alternatives are deprecated in https://github.com/llvm/llvm-project/pull/135556 PiperOrigin-RevId: 748605847 --- .../mlir/tensorflow/transforms/replicate_to_island.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 3928faaa280398..e680ac6e0618d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -94,7 +94,8 @@ LogicalResult GetDeviceOrdinal(const std::optional& devices, << " to be present in 'tf.device.replicate' op"; } llvm::StringRef tpu_device = - tpu_replica.cast()[replica_id].cast().getValue(); + llvm::cast(tpu_replica.cast()[replica_id]) + .getValue(); return tensorflow::GetDeviceOrdinalFromDeviceString(op->getLoc(), tpu_device, &device_ordinal); } @@ -136,9 +137,9 @@ LogicalResult UpdateRegionReplicateVariantOps( // Map aliased devices to explicit devices based on replica. 
if (auto launch = dyn_cast(op)) if (auto device_by_replica = devices.value().get(launch.getDevice())) - launch->setAttr( - kDeviceAttr, - device_by_replica.cast()[replica_id].cast()); + launch->setAttr(kDeviceAttr, + llvm::cast( + device_by_replica.cast()[replica_id])); return WalkResult::advance(); }); From 3476e653171dc864201b9c61bedb833c67ad8461 Mon Sep 17 00:00:00 2001 From: Mikhail Goncharov Date: Thu, 17 Apr 2025 02:41:25 -0700 Subject: [PATCH 0915/1324] [XLA:GPU] VLOG from filecheck makes it easier to see the text of the final module without source modifications PiperOrigin-RevId: 748609935 --- third_party/xla/xla/hlo/testlib/filecheck.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/testlib/filecheck.cc b/third_party/xla/xla/hlo/testlib/filecheck.cc index 2620742bd80fdf..a98490090873fa 100644 --- a/third_party/xla/xla/hlo/testlib/filecheck.cc +++ b/third_party/xla/xla/hlo/testlib/filecheck.cc @@ -38,7 +38,7 @@ absl::StatusOr RunFileCheck(const std::string& input, return tsl::errors::Internal("couldn't get a pattern file name"); } TF_RETURN_IF_ERROR(tsl::WriteStringToFile(env, pattern_path, pattern)); - // LOG(INFO) << "input: " << input; + VLOG(3) << "input: " << input; return RunFileCheckWithPatternFile(input, pattern_path); } From 87f5a72845938d78d009d33cbda96bc3986bf81f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Apr 2025 03:11:39 -0700 Subject: [PATCH 0916/1324] Add missing platform definitions for Android. PiperOrigin-RevId: 748617019 --- tensorflow/opensource_only.files | 1 + tensorflow/tools/toolchains/android/BUILD | 35 +++++++++++++++++++ third_party/xla/opensource_only.files | 1 + .../xla/tools/toolchains/android/BUILD | 35 +++++++++++++++++++ 4 files changed, 72 insertions(+) create mode 100644 tensorflow/tools/toolchains/android/BUILD create mode 100644 third_party/xla/tools/toolchains/android/BUILD diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 8e37b86fdf2f9c..3281c096528b7e 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -176,6 +176,7 @@ tf_staging/tensorflow/tools/pip_package/simple_console_for_windows:.py tf_staging/tensorflow/tools/pip_package/utils/BUILD: tf_staging/tensorflow/tools/pip_package/xla_build/CMakeLists.txt: tf_staging/tensorflow/tools/toolchains/BUILD: +tf_staging/tensorflow/tools/toolchains/android/BUILD: tf_staging/tensorflow/tools/toolchains/clang6/BUILD: tf_staging/tensorflow/tools/toolchains/cpus/py/BUILD: tf_staging/tensorflow/tools/toolchains/cpus/py3/BUILD: diff --git a/tensorflow/tools/toolchains/android/BUILD b/tensorflow/tools/toolchains/android/BUILD new file mode 100644 index 00000000000000..fe32baa142bbf1 --- /dev/null +++ b/tensorflow/tools/toolchains/android/BUILD @@ -0,0 +1,35 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +platform( + name = "x86", + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:x86_32", + ], +) + +platform( + name = "x86_64", + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:x86_64", + ], +) + +platform( + name = "armeabi-v7a", + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:armv7", + ], +) + +platform( + name = "arm64-v8a", + constraint_values = [ + "@platforms//cpu:arm64", + "@platforms//os:android", + ], +) diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index 2278c6c433fb0a..6681f73f8b73bb 100644 --- 
a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -146,6 +146,7 @@ tools/def_file_filter/BUILD: tools/def_file_filter/def_file_filter.py.tpl: tools/def_file_filter/def_file_filter_configure.bzl: tools/toolchains/BUILD: +tools/toolchains/android/BUILD: tools/toolchains/clang6/BUILD: tools/toolchains/cpus/py/BUILD: tools/toolchains/cpus/py3/BUILD: diff --git a/third_party/xla/tools/toolchains/android/BUILD b/third_party/xla/tools/toolchains/android/BUILD new file mode 100644 index 00000000000000..fe32baa142bbf1 --- /dev/null +++ b/third_party/xla/tools/toolchains/android/BUILD @@ -0,0 +1,35 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +platform( + name = "x86", + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:x86_32", + ], +) + +platform( + name = "x86_64", + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:x86_64", + ], +) + +platform( + name = "armeabi-v7a", + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:armv7", + ], +) + +platform( + name = "arm64-v8a", + constraint_values = [ + "@platforms//cpu:arm64", + "@platforms//os:android", + ], +) From 6234136131fe61604de59438733803e5ec5efe94 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 17 Apr 2025 03:46:20 -0700 Subject: [PATCH 0917/1324] Internal non functional BUILD file change allowing Shardy to be used in other places. PiperOrigin-RevId: 748624512 --- third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD index f433c1261d9ecb..2ff9318d820fca 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -12,6 +12,7 @@ package_group( packages = [ "//learning/deepmind/partir/compiler/mpmd/...", "//learning/deepmind/partir/compiler/shardonnay/...", + "//third_party/australis/google/ifrt/...", "//third_party/openxla/shardy/tools/...", "//third_party/py/jax/...", "//xla/...", From ab342c0f9411396eb5cf9a299affe06f15e3cf71 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Thu, 17 Apr 2025 04:02:47 -0700 Subject: [PATCH 0918/1324] #sdy don't add a sharding-constraint for a free var that already has a fully open sharding in -xla-sdy-open-while-free-vars-sharding This makes multiple calls to this pass no-op. PiperOrigin-RevId: 748628121 --- .../open_while_free_vars_sharding.cc | 9 +++++++-- .../shardy/test/open_while_free_vars_sharding.mlir | 12 +++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc index 6fe201ccb4fb4d..31bb0788367ec6 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc @@ -58,10 +58,15 @@ class OpenWhileFreeVarsShardingPass if (!sharding || sharding.getRank() == 0) { continue; } + auto fullyOpenSharding = TensorShardingAttr::getFullyOpenLike(sharding); + if (fullyOpenSharding == sharding) { + // The sharding of the `freeVar` is already fully open, no need to add + // a sharding constraint. 
+ continue; + } auto shardingConstraint = rewriter.create( - freeVar.getLoc(), freeVar, - TensorShardingAttr::getFullyOpenLike(sharding)); + freeVar.getLoc(), freeVar, fullyOpenSharding); // Only replace uses in the regions of the while op. rewriter.replaceUsesWithIf( freeVar, shardingConstraint, [op](mlir::OpOperand& use) { diff --git a/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir index fe13f45d4e09a4..be418bad663a6d 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir @@ -1,5 +1,8 @@ // RUN: sdy_opt %s -xla-sdy-open-while-free-vars-sharding 2>&1 | FileCheck %s +// Verify calling this pass a second time is a no-op. +// RUN: sdy_opt %s -xla-sdy-open-while-free-vars-sharding -xla-sdy-open-while-free-vars-sharding 2>&1 | FileCheck %s + sdy.mesh @mesh1 = <["a"=2]> sdy.mesh @mesh2 = <["b"=2]> @@ -7,7 +10,8 @@ sdy.mesh @mesh2 = <["b"=2]> func.func @while_with_free_variables( %arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>}, - %arg2: tensor<32x96xf32>) + %arg2: tensor<32x96xf32>, + %arg3: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{?}, {?}]>}) -> (tensor<32x96xf32>, tensor<32x96xf32>) { // CHECK-NEXT: %[[C0:.*]] = stablehlo.constant dense<0> // CHECK-NEXT: %[[C1:.*]] = stablehlo.constant dense<1> @@ -24,7 +28,8 @@ func.func @while_with_free_variables( // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %iterArg, %[[SC_0]] // CHECK-NEXT: %[[ADD_3:.*]] = stablehlo.add %[[ADD_2]], %arg2 // CHECK-NEXT: %[[ADD_4:.*]] = stablehlo.add %[[ADD_3]], %[[SC_1]] - // CHECK-NEXT: stablehlo.return %[[ADD_4]], %[[ADD_1]] + // CHECK-NEXT: %[[ADD_5:.*]] = stablehlo.add %[[ADD_4]], %arg3 + // CHECK-NEXT: stablehlo.return %[[ADD_5]], %[[ADD_1]] // CHECK-NEXT: } // CHECK-NEXT: return %[[ADD_0]], %[[WHILE]]#0 %0 = stablehlo.constant dense<0> : tensor @@ -40,7 +45,8 @@ func.func @while_with_free_variables( %6 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> %7 = stablehlo.add %6, %arg2 : tensor<32x96xf32> %8 = stablehlo.add %7, %3 : tensor<32x96xf32> - stablehlo.return %8, %5 : tensor<32x96xf32>, tensor + %9 = stablehlo.add %8, %arg3 : tensor<32x96xf32> + stablehlo.return %9, %5 : tensor<32x96xf32>, tensor } return %3, %4#0 : tensor<32x96xf32>, tensor<32x96xf32> } From c89ed3ef33633724dae06c28b3cd718bcbd3e206 Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Thu, 17 Apr 2025 05:51:54 -0700 Subject: [PATCH 0919/1324] PR #18410: [XLA:CPU][oneDNN] Absorb Transpose into matmul whenever possible Imported from GitHub PR https://github.com/openxla/xla/pull/18410 This PR tries to absorb transposes into matmuls and eliminate associated time consuming copy operations, whenever feasible. These optimizations are expected to benefit attention-based models, where transposed projections are used to compute attention scores. Additionally, this PR includes tests to ensure correct functionality. 
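As a rough illustration of the kind of condition involved, here is a hedged C++ sketch of a helper that recognizes transposes which only swap the two innermost (contraction) dimensions; the helper name and the exact check are hypothetical and written for this note, not taken from the PR, whose real logic lives in the oneDNN contraction rewriter:

#include <cstdint>

#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"

// A transpose whose permutation fixes every batch dimension and swaps only
// the last two dimensions can be absorbed: oneDNN can then read the operand
// with swapped strides instead of materializing a transposed copy.
bool OnlySwapsInnermostDims(const xla::HloInstruction* transpose) {
  absl::Span<const int64_t> perm = transpose->dimensions();
  const int64_t rank = static_cast<int64_t>(perm.size());
  if (rank < 2) return false;
  for (int64_t i = 0; i < rank - 2; ++i) {
    if (perm[i] != i) return false;  // batch dimensions must stay in place
  }
  return perm[rank - 2] == rank - 1 && perm[rank - 1] == rank - 2;
}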
Copybara import of the project: -- 7b20cf42219d5c19c72f630373fcb42fb88b8284 by Akhil Goel : Absorb Transpose into matmul whenever possible Merging this change closes #18410 PiperOrigin-RevId: 748652610 --- third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc index 76d346ee4edc9a..4f07aba6fb43c9 100644 --- a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc @@ -1195,6 +1195,7 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { HloInstruction* contraction; if (Match(custom_call, OneDnnMatmulInstr(&contraction))) { auto backend_config = contraction->backend_config(); + auto new_ops = contraction->mutable_operands(); UpdateTransposeDimensions(contraction, new_ops, 0, &backend_config); From 22644cf997cc14ea60092a1918da4fb0785d6baa Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 17 Apr 2025 06:29:00 -0700 Subject: [PATCH 0920/1324] Fix mlir cast/dyn_cast/isa in tensorflow Use llvm::cast/dyn_cast/isa since alternatives are deprecated in https://github.com/llvm/llvm-project/pull/135556 PiperOrigin-RevId: 748660830 --- .../compiler/mlir/lite/ir/tfl_op_enums.td | 4 +- .../common/quantization_lib/quantization.td | 8 +- .../mlir/lite/transforms/legalize_patterns.td | 36 +++--- .../lite/transforms/legalize_variables.td | 2 +- .../lite/transforms/optimize_batch_matmul.td | 4 +- .../mlir/lite/transforms/optimize_patterns.td | 106 +++++++++--------- .../mlir/lite/transforms/quantize_patterns.td | 2 +- tensorflow/compiler/mlir/lite/utils/utils.td | 64 +++++------ 8 files changed, 113 insertions(+), 113 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td index fa85389789e554..57e4ec22976df3 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td @@ -27,9 +27,9 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" // Referred TF_AnyStrAttrOf in tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td class TFL_AnyStrAttrOf cases> : StringBasedAttr< CPred().getValue() == \"" # !head(cases) # "\"", + "llvm::cast($_self).getValue() == \"" # !head(cases) # "\"", !foreach(case, !tail(cases), - "$_self.cast().getValue() == \"" # case # "\""), + "llvm::cast($_self).getValue() == \"" # case # "\""), prev, cur, prev # " || " # cur)>, "string attribute whose value is " # !foldl(/*init*/!head(cases), /*list*/!tail(cases), diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td index 690fe4be1d46eb..143996e8816cac 100644 --- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td @@ -31,12 +31,12 @@ include "mlir/Dialect/Quant/IR/QuantBase.td" // explicit signedness check to differentiate the signed/unsigned constraints // predicates from one another at the TD level. 
class QuantizedType params, bit signed> - : Type()">, - CPred<"$_self.cast()" # + : Type($_self)">, + CPred<"llvm::cast($_self)" # ".getStorageTypeIntegralWidth() == " # !head(params)>, - Or<[CPred<"$_self.cast()" # + Or<[CPred<"llvm::cast($_self)" # ".getStorageType().isSignlessInteger()">, - CPred<"$_self.cast()" # + CPred<"llvm::cast($_self)" # ".getStorageType().isSignedInteger() == " # signed>]>]>, "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { string name = n; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 26659b157933f2..a148fbd3685f5c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -26,20 +26,20 @@ include "tensorflow/compiler/mlir/lite/utils/utils.td" def CreateEmptyBoolAttr : NativeCodeCall<"::mlir::BoolAttr()">; def DenseElementsAttr : ElementsAttrBase< - CPred<"$_self.isa()">, + CPred<"llvm::isa($_self)">, "non-opaque constant tensor">; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isF32()">, "float constant tensor">; def Int64ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isInteger(64)">, "Int 64 constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isInteger(64)">, "Int 64 constant tensor">; // Extract the ith int element from an ArrayAttr $0 as an 32-bit IntegerAttr // with builder. class ExtractI32At : NativeCodeCall< - "$_builder.getI32IntegerAttr($_self.cast().getValue()[" # i # - "].cast().getInt())">; + "$_builder.getI32IntegerAttr(llvm::cast(llvm::cast($_self).getValue()[" # i # + "]).getInt())">; // Use the tensor type information from $0 and convert min $1, max $2 and // numBits $3 and narrowRange $4 to a QuantizedType. @@ -48,7 +48,7 @@ def ConvertToQuantTypeFromAttrs : NativeCodeCall< // Converts an integer attribute $0 to 32-bit with builder. def convertIntAttrTo32Bit : NativeCodeCall< - "$_builder.getI32IntegerAttr($0.cast().getInt())">; + "$_builder.getI32IntegerAttr(llvm::cast($0).getInt())">; // Builds a constant bool attribute. class GetBoolAttr : @@ -56,15 +56,15 @@ class GetBoolAttr : // Converts an integer attribute $0 to 64-bit with builder. def convertIntAttrTo64Bit : NativeCodeCall< - "$_builder.getI64IntegerAttr($0.cast().getInt())">; + "$_builder.getI64IntegerAttr(llvm::cast($0).getInt())">; // Extracts the single integer element from $_self. def ExtractSingleElementAsInteger : NativeCodeCall< - "ExtractSingleElementAsInteger($_self.cast())">; + "ExtractSingleElementAsInteger(llvm::cast($_self))">; // Extracts the single int32 element from $_self. def ExtractSingleElementAsInt32 : NativeCodeCall< - "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast()).getInt())">; + "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger(llvm::cast($_self)).getInt())">; // Converts tensor with int64 to int32. def CreateTFCastToInt32Op : NativeCodeCall< @@ -75,7 +75,7 @@ def CreateInt32ConstOrCast : NativeCodeCall< // Creates an int32 constant op from an integer attribute $0. 
def CreateInt32ConstOpFromIntAttr - : NativeCodeCall<"$_builder.create($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast($0.cast().getInt())}))">; + : NativeCodeCall<"$_builder.create($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast(llvm::cast($0).getInt())}))">; //===----------------------------------------------------------------------===// // Nullary ops patterns. @@ -100,8 +100,8 @@ def IsDataFormatNHWC : ConstantAttr; def IsDataFormatNCHW : ConstantAttr; class I32VectorElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() &&" - "$_self.cast().getType()." + CPred<"llvm::isa($_self) &&" + "llvm::cast($_self).getType()." "getElementType().isSignlessInteger(32)">, "32-bit int elements attribute of shape [" # len # "]"> { @@ -123,8 +123,8 @@ def IsAllOnes : AttrConstraint>; // Constraint that attribute is string with value either "SAME" or "VALID" def IsSameOrValid : AttrConstraint< - CPred<"$_self.cast().getValue() == \"SAME\" || " # - "$_self.cast().getValue() == \"VALID\"">, + CPred<"llvm::cast($_self).getValue() == \"SAME\" || " # + "llvm::cast($_self).getValue() == \"VALID\"">, "'SAME' or 'VALID' paddings">; def TFL_GetMirrorPaddingType : NativeCodeCall< @@ -443,8 +443,8 @@ def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>; -def ReductionDimensionIsLastDim : Constraint().getInt() == " - "$1.getType().cast().getRank() - 1 || $0.cast().getInt() == -1)">>; +def ReductionDimensionIsLastDim : Constraint($0).getInt() == " + "llvm::cast($1.getType()).getRank() - 1 || llvm::cast($0).getInt() == -1)">>; // Legalizes TF_ApproxTopKOp to TFL_TopKV2Op with the following constraints: // 1. It computes max k @@ -558,10 +558,10 @@ def LegalizeConv2DBackpropInput : Pat< /*fused_activation_function=*/TFL_AF_None)>; def IsRankZeroAttr - : CPred<"$_self.cast().getType().getRank() == 0">; + : CPred<"llvm::cast($_self).getType().getRank() == 0">; def HasValueZero - : CPred<"$_self.cast()." + : CPred<"llvm::cast($_self)." "getSplatValue<::mlir::IntegerAttr>().getInt() == 0">; // TFLite only supports MatrixSetDiag ops with scalar zero k attribute. 
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td index 5c26b6ea468565..72ec563930d7d2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td @@ -22,7 +22,7 @@ def HasSupportedElementType : Constraint>; def IsSupportedElementType : - Constraint())">>; + Constraint($0.getType()))">>; def LegalizeVarHandle : Pat< (TF_VarHandleOp:$result $container, $shared_name), diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td index 85bdf63babcbab..bc82b1f496acfb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td @@ -26,8 +26,8 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def NotFromDequant : Constraint>; def IsResultRankEqualTo : Constraint().getRank() == " - "$1.getType().cast().getRank()">>; + "llvm::cast($0.getType().front()).getRank() == " + "llvm::cast($1.getType()).getRank()">>; // Fuses TFL_FullyConnectedOp and TFL_TransposeOp Rhs to TFL_BatchMatMulOp when // it's used by TFL_BatchMatMulOp and "transpose_lhs" is true. diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 1e97e8f42584b6..fc09b1f6a55021 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -27,21 +27,21 @@ include "mlir/IR/CommonAttrConstraints.td" // Checks if the param passed is a F32 ElementsAttr. def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && $_self.cast().getShapedType().getElementType().isF32()">, + CPred<"llvm::isa($_self) && llvm::cast($_self).getShapedType().getElementType().isF32()">, "32 bit float constant tensor">; // Checks if the param passed is a float ElementsAttr. def FloatElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && $_self.cast().getShapedType().getElementType().isa()">, + CPred<"llvm::isa($_self) && llvm::isa(llvm::cast($_self).getShapedType().getElementType())">, "float constant tensor">; def ExtractSingleElementAsFloat : NativeCodeCall< - "ExtractSingleElementAsFloat($_self.cast())">; + "ExtractSingleElementAsFloat(llvm::cast($_self))">; // Checks if the value has rank 'n'. 
class HasRank : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() == " # n>>; class FloatValueEquals : Constraint>; @@ -57,9 +57,9 @@ def HasOneUse : Constraint>; def IsPermutationNCHW : Constraint>; def IsBiasShape : Constraint< - CPred<"$0.getType().cast().getRank() == 4 && " - "$0.getType().cast().getShape()[2] == 1 && " - "$0.getType().cast().getShape()[3] == 1">, + CPred<"llvm::cast($0.getType()).getRank() == 4 && " + "llvm::cast($0.getType()).getShape()[2] == 1 && " + "llvm::cast($0.getType()).getShape()[3] == 1">, "has shape consistent with a bias">; def ReshapeNCHWBiasToNHWC : NativeCodeCall<"ReshapeNCHWBiasToNHWC($0, $1)">; @@ -114,7 +114,7 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], } def GetBiasMultiplier: - NativeCodeCall<"GetBiasMultiplier($_builder, $0, $1.cast())">; + NativeCodeCall<"GetBiasMultiplier($_builder, $0, llvm::cast($1))">; class CanFuseConvOrDepthwiseConv : Constraint< CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>; @@ -372,22 +372,22 @@ def MatchHardSwishPattern6 : Pat< // Constraint that the attribute value is less than 'n' class ConstDoubleValueLessThan : Constraint< - CPred<"$0.isa() && " - "$0.cast().getNumElements() == 1 && " - "std::abs(*$0.cast().getValues().begin()) < " + CPred<"llvm::isa($0) && " + "llvm::cast($0).getNumElements() == 1 && " + "std::abs(*llvm::cast($0).getValues().begin()) < " # n>>; // Constraint that the attribute value is negative infinity or negative largest. // We use both -inf & flt_min due to the forward compatibility. def ConstAPFloatNegLargestOrNegInfinity : Constraint() && " - "$0.cast().getNumElements() == 1 && " - "(($0.cast().getValues()[0].isLargest() && " - "$0.cast().getValues()[0].isNegative()) || " - "$0.cast().getValues()[0].isNegInfinity())">>; + "llvm::isa($0) && " + "llvm::cast($0).getNumElements() == 1 && " + "((llvm::cast($0).getValues()[0].isLargest() && " + "llvm::cast($0).getValues()[0].isNegative()) || " + "llvm::cast($0).getValues()[0].isNegInfinity())">>; def L2NormValidReduceIndex : Constraint())">>; + "L2NormalizeReduceAxis($0, llvm::cast($1))">>; // Currently L2Normalization doesn't support activation function // in TFLite. @@ -456,9 +456,9 @@ def IsReducedTailOfShape : Constraint>; def Flatten : NativeCodeCall< - "$0.cast()" - ".reshape(RankedTensorType::get({$0.getType().cast().getNumElements()}, " - "$0.getType().cast().getElementType()))">; + "llvm::cast($0)" + ".reshape(RankedTensorType::get({llvm::cast($0.getType()).getNumElements()}, " + "llvm::cast($0.getType()).getElementType()))">; def IsLastDimEqualToNumElements : Constraint>; @@ -729,8 +729,8 @@ def GetPrefixTruncatedShape: NativeCodeCall<"GetShapeAttr($0, true)">; // Returns True if the operand type is RankedTensorType and valid. 
def HasValidRankedTensor : Constraint() && " - "$0.getType().cast().getNumDynamicDims() <= 1">>; + "llvm::isa($0.getType()) && " + "llvm::cast($0.getType()).getNumDynamicDims() <= 1">>; // Check if the truncated shape of the lhs is equal to the shape of rhs def IsPrefixTruncatedShapeEqualTo : Constraint().getRank() == " - "$1.getType().cast().getRank() - 1">>; + CPred<"llvm::cast($0.getType()).getRank() == " + "llvm::cast($1.getType()).getRank() - 1">>; // PReLU pattern from Keras: // f(x) = Relu(x) + (-alpha * Relu(-x)) @@ -979,7 +979,7 @@ def OptimizePow2ToRsqrt : Pat< def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint(), $2.getType())">>; + "$0, llvm::cast($1), $2.getType())">>; def OptimizeIdentityGatherNdOp : Pat< (TFL_GatherNdOp:$output $params, (Arith_ConstantOp I32ElementsAttr: $indices)), @@ -1013,9 +1013,9 @@ def IsSame : Constraint>; def HasTwoUse : Constraint>; def AxesIsLastDimension : Constraint().getNumElements() == 1 && " - "($0.cast().getValues()[0] == " - "$1.getType().cast().getRank() - 1 || $0.cast().getValues()[0] == -1)">>; + "llvm::cast($0).getNumElements() == 1 && " + "(llvm::cast($0).getValues()[0] == " + "llvm::cast($1.getType()).getRank() - 1 || llvm::cast($0).getValues()[0] == -1)">>; // Convert exp(x)/sum(exp(x)) into softmax. def OptimizeToSoftmax : Pat< @@ -1070,10 +1070,10 @@ def FoldNormalizationIntoSoftmaxJaxWithAxisMinus1 : Pat< def HaveSameType : Constraint>; class AllElementsAreF32 : Constraint() && " - "$0.cast().getType().cast().getElementType().isF32() && " - "std::all_of($0.cast().getValues().begin(), " - "$0.cast().getValues().end(), " + "(llvm::isa($0) && " + "llvm::cast(llvm::cast($0).getType()).getElementType().isF32() && " + "std::all_of(llvm::cast($0).getValues().begin(), " + "llvm::cast($0).getValues().end(), " "[](float v){ return v == " #val# ";}))">>; // Optimize X*1 to X @@ -1086,10 +1086,10 @@ def OptimizeMul1ToIdentity : Pat< (AllElementsAreF32<"1.0f"> $constant)]>; class AllElementsAreBool : Constraint() && " - "$0.cast().getType().cast().getElementType().isInteger(1) && " - "std::all_of($0.cast().getValues().begin(), " - "$0.cast().getValues().end(), " + "(llvm::isa($0) && " + "llvm::cast(llvm::cast($0).getType()).getElementType().isInteger(1) && " + "std::all_of(llvm::cast($0).getValues().begin(), " + "llvm::cast($0).getValues().end(), " "[](bool v){ return v == " #val# ";}))">>; // Remove select operators when the result is known in advance. @@ -1225,11 +1225,11 @@ def IsLastDimensionEqualOne : Constraint>; // As above but if shape is not static and rank 2 with last dim 1. 
def IsLastDimensionEqualOneOrDynamicBatchDimRank2 : Constraint< CPred<"IsLastDimensionEqualOne($0) || " - "(!$0.getType().cast().hasStaticShape() && " - " $0.getType().cast().hasRank() && " - " $0.getType().cast().getRank() == 2 && " - " !$0.getType().cast().getShape().empty() && " - " $0.getType().cast().getShape()[1] == 1)">>; + "(!llvm::cast($0.getType()).hasStaticShape() && " + " llvm::cast($0.getType()).hasRank() && " + " llvm::cast($0.getType()).getRank() == 2 && " + " !llvm::cast($0.getType()).getShape().empty() && " + " llvm::cast($0.getType()).getShape()[1] == 1)">>; // Replace // Equal(X, indices) @@ -1250,10 +1250,10 @@ def ReshapeEqualOpToOneHotOp : Pat< (IsOneHotIndexAttribute $series)]>; def F32ElementsVal : Constraint().getElementType().isF32()">, + "llvm::cast($0.getType()).getElementType().isF32()">, "32 bit float tensor">; def I32ElementsVal : Constraint().getElementType().isInteger(32)">, + "llvm::cast($0.getType()).getElementType().isInteger(32)">, "32 bit integer tensor">; def ConvertSingleElementAttrToFloatAttr : @@ -1664,7 +1664,7 @@ def isF32Splat : Constraint< CPred<"IsF32Splat($0)">>; def ExtractF32AtIndex0: NativeCodeCall< - "$_builder.getF32FloatAttr($_self.cast().getValues()[0])">; + "$_builder.getF32FloatAttr(llvm::cast($_self).getValues()[0])">; def FuseLeakyReluConst : Pat< (TFL_SelectOp @@ -1699,16 +1699,16 @@ class ContractingDimsProductEqual : Constraint : Constraint().getShape()" + "(llvm::dyn_cast($0.getType()).getShape()" ".drop_back("#skip_last#").drop_front("#skip_first#") ==" - "$1.getType().dyn_cast().getShape()" + "llvm::dyn_cast($1.getType()).getShape()" ".drop_back("#skip_last#").drop_front("#skip_first#"))">>; // Returns true if the broadcast dimension of a tensor is [1] // here- broadcast dimension is first prefix dimension // excluding the last two dimensions def IsBroadcastDimEqualToOne : Constraint().getShape()[0] == 1">>; + "llvm::dyn_cast($0.getType()).getShape()[0] == 1">>; // Pattern to fuse/fold the reshape ops around TFL_BatchMatMulOp // This pattern is applied when the rank of rhs is 2 @@ -1953,25 +1953,25 @@ def FuseSliceAndPack4D : Pat<( // Given a value, checks if dim `d` is static. class HasStaticDim : Constraint().isDynamicDim(" # d # ")">>; + "!llvm::cast($0.getType()).isDynamicDim(" # d # ")">>; class IsBalancedPaddingArray : Constraint())">>; + "llvm::cast($0))">>; // Given in_shape, out_shape, stride checks ceil(in_shape[d] / stride) == out_shape[d] def IsSameStridedShape2D : Constraint()," - "$1.getType().cast().getShape())">>; + "llvm::cast($1.getType()).getShape())">>; def IsSameStridedShapeDepthwise : Constraint()," - "$1.getType().cast().getShape())">>; + "llvm::cast($1.getType()).getShape())">>; def IsSameStridedShape3D : Constraint()," - "$1.getType().cast().getShape())">>; + "llvm::cast($1.getType()).getShape())">>; def IsValidPadding : Constraint>; diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td index ae8af0a99cc889..e6781e7ce30b7c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td @@ -27,7 +27,7 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" // Quantize attribute $0 by using quantization parameter from %1. 
def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isF32()">, "float constant tensor">; def HasSameType : Constraint>; diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index b1dce77f71b001..7583d48618f4fc 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -25,9 +25,9 @@ include "mlir/IR/PatternBase.td" //////////////////////////////////////////////////////////////////////////////// def IsQuantized : Constraint() && " - "$0.getType().dyn_cast().getElementType()" - ".isa()">>; + "llvm::dyn_cast($0.getType()) && " + "llvm::isa(" + "llvm::dyn_cast($0.getType()).getElementType())">>; def IsNotQuantized : Constraint>; @@ -38,42 +38,42 @@ def IsNotQuantized : Constraint>; // Checks if the rank of the value is less than or equal to the rank of the // other value. def IsRankLessThanEqualTo : Constraint().getRank() <= " - "$1.getType().cast().getRank()">>; + "llvm::cast($0.getType()).getRank() <= " + "llvm::cast($1.getType()).getRank()">>; // Checks if the value has rank at most 'n'. class HasRankAtMost : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() <= " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() <= " # n>>; //////////////////////////////////////////////////////////////////////////////// ///////////////// DENSE UTILITIES ///////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// -def DenseFPElementsAttrPred : CPred<"$_self.isa()">; -def DenseIntElementsAttrPred : CPred<"$_self.isa()">; +def DenseFPElementsAttrPred : CPred<"llvm::isa($_self)">; +def DenseIntElementsAttrPred : CPred<"llvm::isa($_self)">; //////////////////////////////////////////////////////////////////////////////// ///////////////// SPLAT CONSTANT UTILITIES ///////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// def DenseElementsAttrIsSplatPred - : CPred<"$_self.cast().isSplat()">; + : CPred<"llvm::cast($_self).isSplat()">; class DenseFPElementsAttrSplatValueEqualToPred - : CPred<"$_self.cast().getSplatValue()" + : CPred<"llvm::cast($_self).getSplatValue()" ".getValueAsDouble() == " # val>; class DenseFPElementsAttrSplatValueEqualToPredWithTolerance - : CPred<"std::abs($_self.cast().getSplatValue()" + : CPred<"std::abs(llvm::cast($_self).getSplatValue()" ".getValueAsDouble() - " # val # ") <= "#tolerance>; class DenseIntElementsAttrSplatValueEqualToPred - : CPred<"$_self.isa() && " - "$_self.cast().getElementType()" - " .isa() && " - "$_self.cast().isSplat() && " - "$_self.cast().getSplatValue()" + : CPred<"llvm::isa($_self) && " + "llvm::isa(" + "llvm::cast($_self).getElementType()) && " + "llvm::cast($_self).isSplat() && " + "llvm::cast($_self).getSplatValue()" " .getValue().getSExtValue() == " # val>; // AttrConstraint to match a floating point dense elements attribute with a @@ -110,8 +110,8 @@ def SplatIntElementsAttr : ElementsAttrBase< def GetScalarElementsAttrFromSplat : NativeCodeCall< "DenseElementsAttr::get(" " RankedTensorType::get({}," - " $0.cast().getType().getElementType())," - " $0.cast().getSplatValue())">; + " llvm::cast($0).getType().getElementType())," 
+ " llvm::cast($0).getSplatValue())">; //////////////////////////////////////////////////////////////////////////////// ///////////////// OP BROADCASTING UTILITIES //////////////////////////////////// @@ -129,10 +129,10 @@ def OperandsDontBroadcastToOutputType : Constraint().hasStaticShape() && " - "$1.getType().cast().hasStaticShape() && " - "$0.getType().cast().getShape() ==" - "$1.getType().cast().getShape()">, + CPred<"llvm::cast($0.getType()).hasStaticShape() && " + "llvm::cast($1.getType()).hasStaticShape() && " + "llvm::cast($0.getType()).getShape() ==" + "llvm::cast($1.getType()).getShape()">, "have the same static shape">; def CreateNoneValue : NativeCodeCall< @@ -159,7 +159,7 @@ def IsAllOnesConstant : Constraint>; // the permutation is a cyclic permutation of the original shape with only the // identity dimensions permuted. def IsTransposeTrivial : Constraint().getShape(), $1)">>; + "TFL::IsTransposeTrivial(llvm::cast($0.getType()).getShape(), $1)">>; // Constraint that checks if the transpose op is a no-op. def IsTransposeNoop : Constraint>; @@ -169,15 +169,15 @@ def IsTransposeNoop : Constraint>; // the order of non-identity dimensions. def IsReshapeEquivalentToTranspose : Constraint()," - "$1.getType().cast())">>; + "llvm::cast($0.getType())," + "llvm::cast($1.getType()))">>; // Returns the permutation of the trivial reshape op, this will be used to // construct the transpose op. def GetPermutationFromTrivialReshape : NativeCodeCall< "TFL::GetPermutationFromTrivialReshape(" - "$0.getType().cast()," - "$1.getType().cast())">; + "llvm::cast($0.getType())," + "llvm::cast($1.getType()))">; // Constraint that checks if all values in offset between two // attributes are non-negative. @@ -191,12 +191,12 @@ def GetOffSet : NativeCodeCall<"TFL::GetOffSet($0, $1)">; // Attribute Constraint that checks if the attribute value is zero. def ZeroIntAttr - : AttrConstraint().getInt() == 0">>; + : AttrConstraint($_self).getInt() == 0">>; // Checks if the value has rank at most 'n'. class HasRankAtLeast : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() >= " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() >= " # n>>; // Accepts two inputs and check if both have the same element type. def SameElementType : Constraint< @@ -227,7 +227,7 @@ def AreLastTwoDimsTransposed : Constraint>; // Checks if the param passed is of NoneType. -def IsNoneType : Constraint()">>; +def IsNoneType : Constraint($0.getType())">>; def ConstantLikePred : CPred<"::mlir::matchPattern($0, ::mlir::m_Constant())">; def IsConstantLike : Constraint; From 05dd91bb58b67067e249057101f4165a708b9bc2 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Thu, 17 Apr 2025 06:34:53 -0700 Subject: [PATCH 0921/1324] Use the first config returned by search space instead of the hardcoded default one This should be slightly better, since it allows us to adjust to the problem, rather than just hardcoding something that might not be good at all. 
PiperOrigin-RevId: 748662097 --- .../gpu/autotuning/gemm_fusion_autotuner.cc | 10 +++--- .../autotuning/gemm_fusion_autotuner_test.cc | 35 +++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 19fde26bcb1012..e521e9d07fb149 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -891,19 +891,21 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { debug_options_.xla_gpu_enable_split_k_autotuning(); if (debug_options_.xla_gpu_experimental_enable_dynamic_dot_search_space()) { - if (!IsAutotuningEnabled()) { - return {{kDefaultConfig}}; - } TritonDotFusionSearchSpace search_space(config_.GetDeviceDescription(), &dot); VLOG(1) << "Generating configs from search space: " << search_space.ToString(); // We don't need to consider small_dot here. The new search space will // already generate a unique config for small problems. - return search_space.GenerateConfigs( + std::vector configs = search_space.GenerateConfigs( /*force_contracting_split=*/autotune_contracting_split ? std::nullopt : std::make_optional(1)); + if (!IsAutotuningEnabled()) { + // Keep the first config, which likely does not spill registers. + configs.resize(1); + } + return configs; } // Retrieve the minimum bit-width participating in the dot. This is needed diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index ed4a583745ede9..5f6bb139292ad9 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -417,6 +417,19 @@ class DynamicSearchSpaceAutotunerTest : public GemmFusionAutotunerTest { } }; +// TODO: b/404470821 - Merge this with the autotuning levels test once dynamic +// search space is enabled by default. 
+class DynamicSearchSpaceAutotunerDisabledTest + : public DynamicSearchSpaceAutotunerTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = + GemmFusionAutotunerTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_autotune_level(0); + return debug_options; + } +}; + absl::StatusOr> GetPossibleMatmulAutotuneTritonConfigs( const HloDotInstruction& dot, @@ -1813,6 +1826,28 @@ ENTRY e { )"); } +TEST_F(DynamicSearchSpaceAutotunerDisabledTest, + ReturnsSingleConfigWhenAutotuningIsDisabled) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + ROOT r = f32[1024,1024] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})") + .value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::kAmpere, /*minor=*/0}; + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneTritonConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + + EXPECT_EQ(configs.size(), 1); +} + TEST_F(GemmFusionAutotunerTest, VerifyHopperConfigsAreDifferentFromBlackwell) { if (isRocm()) { GTEST_SKIP() << "Not supported on ROCm."; From 2fed685dbd7790b0184f295634451ae2440749c6 Mon Sep 17 00:00:00 2001 From: Ranko Sredojevic Date: Thu, 17 Apr 2025 06:45:20 -0700 Subject: [PATCH 0922/1324] Clean up HloModule constness annotations. PiperOrigin-RevId: 748664464 --- third_party/xla/xla/hlo/ir/hlo_module.cc | 8 +++++--- third_party/xla/xla/hlo/ir/hlo_module.h | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index c2652e59a65101..0c56fe60b656b9 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -244,8 +244,9 @@ HloComputation* HloModule::AddEmbeddedComputation( } void HloModule::MarkFusionDuplications( - const absl::flat_hash_map& replacements) { - for (std::unique_ptr& computation : computations_) { + const absl::flat_hash_map& replacements) + const { + for (const std::unique_ptr& computation : computations_) { for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kFusion) { auto rep = @@ -1343,7 +1344,8 @@ uint64_t HloModule::RandomNew64() const { return rng_(); } -HloComputation* HloModule::GetComputationWithName(absl::string_view name) { +HloComputation* HloModule::GetComputationWithName( + absl::string_view name) const { auto computations_in_module = computations(); auto it = absl::c_find_if( computations_in_module, diff --git a/third_party/xla/xla/hlo/ir/hlo_module.h b/third_party/xla/xla/hlo/ir/hlo_module.h index e34ef01ac71a6f..70356cea6d59ff 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.h +++ b/third_party/xla/xla/hlo/ir/hlo_module.h @@ -120,8 +120,8 @@ class HloModule { // Marks duplicate fusions with the same name to be able to group them for // analysis purposes (e.g. through Xprof). void MarkFusionDuplications( - const absl::flat_hash_map& - replacements); + const absl::flat_hash_map& replacements) + const; // Replaces all uses of computations that are keys of 'replacements' with // the corresponding values in 'replacements'. Replaces the entry computation, @@ -287,7 +287,7 @@ class HloModule { // Returns the computation in this module that has the name `name`. Returns // null if there is no such computation. 
- HloComputation* GetComputationWithName(absl::string_view name); + HloComputation* GetComputationWithName(absl::string_view name) const; // Gets the number of computations in this module. int64_t computation_count() const { return computations_.size(); } From e8112266c240a61523328e5f0c589d0b4995d226 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Thu, 17 Apr 2025 06:50:49 -0700 Subject: [PATCH 0923/1324] #sdy Make X64Rewriter handle `xla.sdy.FuncResultSharding` PiperOrigin-RevId: 748665867 --- .../service/spmd/shardy/sdy_round_trip/BUILD | 2 + .../sdy_round_trip/import_shardy_attrs.cc | 54 +++++++++++++++++-- .../sdy_round_trip/import_shardy_attrs.h | 4 +- .../test/sdy_round_trip_import_pipeline.mlir | 29 ++++++++-- 4 files changed, 79 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD index 2ff9318d820fca..85f777322ab702 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -62,6 +62,7 @@ cc_library( deps = [ "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", + "@com_google_absl//absl/log", "@llvm-project//llvm:Support", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:FuncDialect", @@ -134,6 +135,7 @@ cc_library( srcs = ["import_callback_custom_calls.cc"], hdrs = ["import_callback_custom_calls.h"], deps = [ + "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc index 53dbc0cf473d40..4ac910b252a685 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc @@ -19,9 +19,12 @@ limitations under the License. #include #include #include +#include +#include "absl/log/log.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/ScopedPrinter.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -45,6 +48,7 @@ limitations under the License. 
#include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" @@ -74,6 +78,29 @@ using ::mlir::sdy::TensorShardingPerValueAttr; namespace stablehlo = ::mlir::stablehlo; +stablehlo::CustomCallOp dynCastX64CombineCustomCall(Operation* op) { + auto customCallOp = mlir::dyn_cast(op); + if (!customCallOp || customCallOp.getCallTargetName() != "X64Combine") { + return nullptr; + } + return customCallOp; +} + +stablehlo::CustomCallOp getX64CombineOnFuncResultSharding( + stablehlo::CustomCallOp funcResultSharding) { + if (funcResultSharding.getNumResults() != 2 || + !funcResultSharding.getResult(0).hasOneUse() || + !funcResultSharding.getResult(1).hasOneUse()) { + return nullptr; + } + Operation* lhsUser = *funcResultSharding.getResult(0).user_begin(); + Operation* rhsUser = *funcResultSharding.getResult(1).user_begin(); + if (lhsUser != rhsUser) { + return nullptr; + } + return dynCastX64CombineCustomCall(lhsUser); +} + // Builds the shardy attributes coming from Shardy previously. This means // the module was exported from Shardy and we are now round-tripping back. // This should happen after the meshes were created from the `ModuleOp` attrs @@ -128,8 +155,19 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) { // func result and delete the CustomCallOp. auto shardingPerValueAttr = parseStringAttr( dictAttr, kShardingRoundTripAttr); - for (mlir::OpOperand& use : - llvm::make_early_inc_range(customCallOp->getUses())) { + + auto resultUses = customCallOp->getUses(); + if (auto x64CombineOp = + getX64CombineOnFuncResultSharding(customCallOp)) { + // X64Rewriter pass will pass through the two split 32-bit operands to + // the `kFuncResultShardingTargetName`, which will return two 32-bit + // results, that would then be passed to a `X64Combine` custom-call. + // Therefore, we need to look at the uses of the `X64Combine` instead + // to find the corresponding `func.return` op. + mlir::sdy::setShardings(x64CombineOp, shardingPerValueAttr); + resultUses = x64CombineOp->getUses(); + } + for (mlir::OpOperand& use : llvm::make_early_inc_range(resultUses)) { // We currently ignore users that are not the func return op. // This might happen due to inlined func ops that originally had // result shardings. @@ -137,10 +175,18 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) { if (mlir::isa(use.getOwner())) { funcOp.setResultAttr(use.getOperandNumber(), kShardingAttr, shardingPerValueAttr.getSharding(0)); - use.set(customCallOp.getOperand(0)); + } else if (!dynCastX64CombineCustomCall(use.getOwner())) { + LOG(WARNING) + << std::string_view( // non-absl ok + kFuncResultShardingTargetName) + << " custom-call has a user that isn't `func.return` (" + << std::string_view( // non-absl ok + use.getOwner()->getName().getStringRef()) + << "), which will be ignored. 
Please file a bug with a " + << "reproducer."; } } - rewriter.replaceOp(customCallOp, customCallOp.getOperand(0)); + rewriter.replaceOp(customCallOp, customCallOp.getOperands()); return; } if (targetName == kShardingCustomCallTargetName || diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h index 0e75e2fbc648a4..9e930e537017e5 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h @@ -25,10 +25,12 @@ namespace sdy { // Creates the pass to convert frontend attributes to SDY attributes: // +// - Converts meshes from `kMeshesRoundTripAttr` to sdy.mesh symbols // - Converts shardings from `kShardingRoundTripAttr` to `kShardingAttr` // - Converts sharding rules from `kShardingRuleRoundTripAttr` to // `kShardingRuleAttr` -// - Converts meshes from `kMeshesRoundTripAttr` to sdy.mesh symbols +// - Replaces `kFuncResultShardingTargetName` custom-calls with the +// corresponding func result sharding. std::unique_ptr createSdyRoundTripImportShardyAttrsPass(); // Registers the xla-sdy-round-trip-import-shardy-attrs pass. diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 379023c9db13ec..e853926bb33e5e 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -18,13 +18,12 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>}, // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p3]>}) { - // CHECK-NEXT: return %arg0, %arg1, %arg0, %arg1, %arg1, %arg2 - // CHECK-NEXT: } func.func @func_results_with_sharding( %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\"b\"}p2]>"}}, %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\"a\"}p1]>"}}, %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\"c\"}p0]>"}} ) -> (tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>) { + // CHECK-NEXT: return %arg0, %arg1, %arg0, %arg1, %arg1, %arg2 %0 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"b\"}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %2 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> @@ -33,18 +32,38 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x return %0, %1, %2, %3, %1, %4 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32> } + // CHECK-LABEL: func @func_result_shardings_used_by_x64_combine(%arg0: tensor<16xi64>) + // 
CHECK-SAME: -> (tensor<16xi64> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}) { + func.func @func_result_shardings_used_by_x64_combine( + %arg0: tensor<16xi64>) -> tensor<16xi64> { + // CHECK-NEXT: %[[SPLIT_LOW:.*]] = stablehlo.custom_call @X64SplitLow(%arg0) : (tensor<16xi64>) -> tensor<16xui32> + // CHECK-NEXT: %[[SPLIT_HIGH:.*]] = stablehlo.custom_call @X64SplitHigh(%arg0) : (tensor<16xi64>) -> tensor<16xui32> + // CHECK-NEXT: %[[SPLIT_COMBINE:.*]] = stablehlo.custom_call @X64Combine(%[[SPLIT_LOW]], %[[SPLIT_HIGH]]) + // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : (tensor<16xui32>, tensor<16xui32>) -> tensor<16xi64> + // CHECK-NEXT: return %[[SPLIT_COMBINE]] + %0 = stablehlo.custom_call @X64SplitLow(%arg0) : (tensor<16xi64>) -> tensor<16xui32> + %1 = stablehlo.custom_call @X64SplitHigh(%arg0) : (tensor<16xi64>) -> tensor<16xui32> + %2 = stablehlo.tuple %0, %1 : tuple, tensor<16xui32>> + %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%2) + {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}]>]>"}} : + (tuple, tensor<16xui32>>) -> tuple, tensor<16xui32>> + %4 = stablehlo.get_tuple_element %3[0] : (tuple, tensor<16xui32>>) -> tensor<16xui32> + %5 = stablehlo.get_tuple_element %3[1] : (tuple, tensor<16xui32>>) -> tensor<16xui32> + %6 = stablehlo.custom_call @X64Combine(%4, %5) : (tensor<16xui32>, tensor<16xui32>) -> tensor<16xi64> + return %6 : tensor<16xi64> + } + // This might happen due to inlined funcs that originally had result shardings // CHECK-LABEL: func @func_result_shardings_used_by_other_ops( // CHECK-SAME: %arg0: tensor<32xi32>, %arg1: tensor<32xi32> // CHECK-SAME: ) -> ( // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, // CHECK-SAME: tensor<32xi32>) { - // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 - // CHECK-NEXT: return %arg0, %[[ADD]] - // CHECK-NEXT: } func.func @func_result_shardings_used_by_other_ops( %arg0: tensor<32xi32>, %arg1: tensor<32xi32> ) -> (tensor<32xi32>, tensor<32xi32>) { + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 + // CHECK-NEXT: return %arg0, %[[ADD]] %0 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"b\"}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %2 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\"a\"}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> From afd404e84b5024abed12337b33010e7e0ccf943c Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Thu, 17 Apr 2025 07:54:05 -0700 Subject: [PATCH 0924/1324] Limit the list of search space configs when exhaustive search is disabled In this case, we limit the full set to only the configs compatible with the default set. 
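At its core, the restriction is a set intersection over the generated configs. A simplified sketch (the actual `OptimizeConfigSet` in the diff below additionally clamps each hint's tile sizes and split_k into the valid range for the problem before matching, and falls back to the full hints set if nothing survives the filter):

```
absl::flat_hash_set<TritonGemmConfig> filter(hints.begin(), hints.end());
std::vector<TritonGemmConfig> result;
for (const TritonGemmConfig& config : configs) {
  if (filter.contains(config)) result.push_back(config);
}
return result.empty() ? hints : result;
```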
PiperOrigin-RevId: 748680825 --- .../xla/xla/service/gpu/autotuning/BUILD | 1 + .../gpu/autotuning/dot_search_space.cc | 68 +++++++++++++-- .../service/gpu/autotuning/dot_search_space.h | 36 +++++--- .../gpu/autotuning/dot_search_space_test.cc | 85 +++++++++++++++++++ .../gpu/autotuning/gemm_fusion_autotuner.cc | 5 ++ 5 files changed, 176 insertions(+), 19 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 4d462dc3c9ff73..17678434dfd0af 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -270,6 +270,7 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/stream_executor:device_description", "//xla/tsl/lib/core:bits", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc index b9b258be13c2c2..4145f22c4a6881 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/str_format.h" @@ -111,7 +112,7 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace( } std::vector TritonDotFusionSearchSpace::GenerateConfigs( - std::optional force_contracting_split) { + std::optional force_contracting_split) const { std::vector configs; if (force_contracting_split.has_value()) { ConfigWithNotes config; @@ -151,6 +152,56 @@ std::vector TritonDotFusionSearchSpace::GenerateConfigs( return result; } +std::vector TritonDotFusionSearchSpace::OptimizeConfigSet( + const std::vector& configs, + const std::vector& hints) const { + if (hints.empty() || configs.empty()) { + return configs; + } + + auto split_limits = std::minmax_element( + configs.begin(), configs.end(), + [](const auto& a, const auto& b) { return a.split_k < b.split_k; }); + absl::flat_hash_set filter; + for (TritonGemmConfig config : hints) { + // Our default config set does not take problem size into account, so we + // might not even have some of them in the "exhaustive set", since they + // might be outside of the efficient config range. Hence, we limit the tile + // to what can appear in the exhaustive set. + config.block_m = std::clamp(config.block_m, min_out_tile_.lhs_dim, + max_out_tile_.lhs_dim); + config.block_n = std::clamp(config.block_n, min_out_tile_.rhs_dim, + max_out_tile_.rhs_dim); + config.block_k = + std::clamp(config.block_k, min_contracting_tile_size_, + GetMaxContractingTileSize({config.block_m, config.block_n}, + /*contracting_split=*/1)); + config.split_k = std::clamp(config.split_k, split_limits.first->split_k, + split_limits.second->split_k); + VLOG(10) << "Adding config to hint filter: " << config.ToString(); + filter.insert(config); + } + + std::vector result_configs; + for (const TritonGemmConfig& config : configs) { + if (!filter.contains(config)) { + continue; + } + VLOG(10) << "Filtering out configs based on hints: surviving config = " + << config.ToString(); + result_configs.push_back(config); + }; + + if (result_configs.empty()) { + LOG(WARNING) << "All configs were filtered out because none of them " + "sufficiently match the hints. 
Maybe the hints set does " "not contain a good representative set of valid configs? " "Working around this by using the full hints set instead."; return hints; } return result_configs; } std::string TritonDotFusionSearchSpace::ToString() const { return absl::StrFormat( "problem_size_BxMxNxKxE: %dx%dx%dx%dx(%d->%d) " @@ -370,7 +421,7 @@ int TritonDotFusionSearchSpace::GetMaxNumStages(OutputTile output_tile, } std::vector<ConfigWithNotes> -TritonDotFusionSearchSpace::GenerateContractingSplitFactors() { +TritonDotFusionSearchSpace::GenerateContractingSplitFactors() const { CHECK_GE(max_contracting_split_, 1); std::vector<ConfigWithNotes> configs; ConfigWithNotes config; @@ -384,7 +435,8 @@ TritonDotFusionSearchSpace::GenerateContractingSplitFactors() { } void TritonDotFusionSearchSpace::ExtendConfigs( - std::vector<ConfigWithNotes>& configs, ExtendConfigCallback extend_config) { + std::vector<ConfigWithNotes>& configs, + ExtendConfigCallback extend_config) const { CHECK(!configs.empty()); std::vector<ConfigWithNotes> updated_configs; for (ConfigWithNotes& config : configs) { @@ -396,7 +448,7 @@ void TritonDotFusionSearchSpace::ExtendConfigs( void TritonDotFusionSearchSpace::AddOutputTilings( const ConfigWithNotes& config, - std::vector<ConfigWithNotes>& updated_configs) { + std::vector<ConfigWithNotes>& updated_configs) const { CHECK_GT(config.config.split_k, 0) << "Need config with contracting split already set."; const int split = config.config.split_k; @@ -447,7 +499,7 @@ void TritonDotFusionSearchSpace::AddOutputTilings( void TritonDotFusionSearchSpace::AddCtaSizeParameter( const ConfigWithNotes& config, - std::vector<ConfigWithNotes>& updated_configs) { + std::vector<ConfigWithNotes>& updated_configs) const { ConfigWithNotes new_config = config; const int tile_rows = config.config.block_m; const int tile_cols = config.config.block_n; @@ -466,7 +518,7 @@ void TritonDotFusionSearchSpace::AddCtaSizeParameter( void TritonDotFusionSearchSpace::AddContractingTiling( const ConfigWithNotes& config, - std::vector<ConfigWithNotes>& updated_configs) { + std::vector<ConfigWithNotes>& updated_configs) const { const int tile_rows = config.config.block_m; const int tile_cols = config.config.block_n; const int split = config.config.split_k; @@ -486,7 +538,7 @@ void TritonDotFusionSearchSpace::AddContractingTiling( void TritonDotFusionSearchSpace::AddPipeliningParameter( const ConfigWithNotes& config, - std::vector<ConfigWithNotes>& updated_configs) { + std::vector<ConfigWithNotes>& updated_configs) const { const int tile_rows = config.config.block_m; const int tile_cols = config.config.block_n; const int tile_contracting = config.config.block_k; @@ -508,7 +560,7 @@ void TritonDotFusionSearchSpace::AddPipeliningParameter( } void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs( - std::vector<ConfigWithNotes>& configs) { + std::vector<ConfigWithNotes>& configs) const { CHECK(!configs.empty()); ConfigWithNotes last_config = configs.back(); // Largest split. auto has_too_few_tiles = [](const ConfigWithNotes& config) { diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h index d070f5496ef9cb..42c7c0fbcd8b5f 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space.h @@ -44,7 +44,18 @@ class TritonDotFusionSearchSpace { // autotuner to try. If `force_contracting_split` is set, the search space // will be restricted to only include configs with the given split_k factor.
std::vector GenerateConfigs( - std::optional force_contracting_split = std::nullopt); + std::optional force_contracting_split = std::nullopt) const; + + // Restrict the set of configs to the ones compatible with the hints list. + // Generally, this will mean that configs are restricted to the ones that + // appear in hints. The implementation is allowed to deviate though, and + // slightly change the hints list if it thinks that the exact configs in the + // hints are unlikely to be performant (e.g., if the RHS side of a config in + // hints list is larger than the problem's RHS side, it might restrict that + // config to the problem's RHS size). + std::vector OptimizeConfigSet( + const std::vector& configs, + const std::vector& hints) const; // Serializes the search space to a human-readable string. std::string ToString() const; @@ -89,13 +100,13 @@ class TritonDotFusionSearchSpace { // extensions of `config` to the `updated_configs` vector. using ExtendConfigCallback = void (TritonDotFusionSearchSpace::*)( const ConfigWithNotes& config, - std::vector& updated_configs); + std::vector& updated_configs) const; // Extends Triton gemm configs by repeatedly calling `*extend_config()` on // each config in `configs`. Expects that after all calls to `extend_config`, // the updated list of configs is non-empty. void ExtendConfigs(std::vector& configs, - ExtendConfigCallback extend_config); + ExtendConfigCallback extend_config) const; // Computes the maximum number of total warps we should have to sufficiently // saturate the GPU. @@ -155,40 +166,43 @@ class TritonDotFusionSearchSpace { // Finds all promising values for splitting the contracting dimension to // achieve sufficient occupancy (split_k). - std::vector GenerateContractingSplitFactors(); + std::vector GenerateContractingSplitFactors() const; // Finds all promising output shape tilings (block_m, block_n), based on // `config` with already determined contracting split value and appends them // to `updated_configs`. Each config in the input list might yield zero or // more configs in the output. void AddOutputTilings(const ConfigWithNotes& config, - std::vector& updated_configs); + std::vector& updated_configs) const; // Finds all promising values for the Cooperative Thread Array (aka. CTA, aka. // CUDA block) size (num_warps), based on `config` with already determined // output tiling and appends them to `updated_configs`. Each config in the // input list might yield zero or more configs in the output. void AddCtaSizeParameter(const ConfigWithNotes& config, - std::vector& updated_configs); + std::vector& updated_configs) const; // Finds all promising values for the contracting dimension tile size // (block_k), based on `config` with already determined contracting split and // output tiling, and appends them to `updated_configs`. Each config in the // input list might yield zero or more configs in the output. - void AddContractingTiling(const ConfigWithNotes& config, - std::vector& updated_configs); + void AddContractingTiling( + const ConfigWithNotes& config, + std::vector& updated_configs) const; // Finds all promising values for the pipelining parameter, based on // `config` with already determined contracting split, output tiling, and // contracting tile size, and appends them to `updated_configs`. Each config // in the input list might yield zero or more configs in the output. 
- void AddPipeliningParameter(const ConfigWithNotes& config, - std::vector& updated_configs); + void AddPipeliningParameter( + const ConfigWithNotes& config, + std::vector& updated_configs) const; // Removes configs that are marked with `not_enough_tiles` from the list. If // this results in an empty list, adds a config that should be the most // optimal one even though it does not occupy all cores. - void EliminateLowOccupancyConfigs(std::vector& configs); + void EliminateLowOccupancyConfigs( + std::vector& configs) const; // The order of these fields is important: the values of those defined earlier // are used to compute the values of later ones. diff --git a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc index 282b39c283f4e7..e7d8c507c56b67 100644 --- a/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/dot_search_space_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/autotuning/dot_search_space.h" #include +#include #include #include @@ -35,6 +36,8 @@ namespace xla::gpu { namespace { using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::Field; using ::testing::Ge; @@ -403,5 +406,87 @@ TEST_F(DotSearchSpaceTest, EnsuresWgmmaShapeForLargeProblem) { BlockNIs(Ge(16)))))); } +TEST_F(DotSearchSpaceTest, ReturnsAllConfigsIfNoHints) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + std::vector configs = search_space.GenerateConfigs(); + + EXPECT_THAT(search_space.OptimizeConfigSet(configs, {}), + ElementsAreArray(configs)); +} + +TEST_F(DotSearchSpaceTest, OptimizesEmptyConfigSet) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + TritonGemmConfig hint = {/*block_m=*/32, /*block_n=*/32, + /*block_k=*/32, /*split_k=*/1, + /*num_stages=*/1, /*num_warps=*/4, + /*num_ctas=*/1}; + + EXPECT_THAT(search_space.OptimizeConfigSet({}, {hint}), IsEmpty()); +} + +TEST_F(DotSearchSpaceTest, RestrictsConfigsToHints) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + TritonGemmConfig matching_hint = { + /*block_m=*/32, /*block_n=*/32, /*block_k=*/32, + /*split_k=*/1, /*num_stages=*/1, /*num_warps=*/4, + /* num_ctas=*/1}; + TritonGemmConfig non_matching_hint = { + /*block_m=*/64, /*block_n=*/32, /*block_k=*/32, + /*split_k=*/1, /*num_stages=*/1, /*num_warps=*/4, + /*num_ctas=*/1}; + TritonGemmConfig other_config = { + /*block_m=*/32, /*block_n=*/64, /*block_k=*/32, + /*split_k=*/1, /*num_stages=*/1, /*num_warps=*/4, + /*num_ctas=*/1}; + + EXPECT_THAT( + search_space.OptimizeConfigSet({other_config, matching_hint}, + {matching_hint, non_matching_hint}), + ElementsAre(matching_hint)); +} + +TEST_F(DotSearchSpaceTest, RestrictsConfigsWithPartialMatch) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetDefaultDotModule(/*lhs_parallel_dim=*/4096, /*rhs_parallel_dim=*/16, + /*contracting_dim=*/1024)); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + TritonGemmConfig hint = {/*block_m=*/32, /*block_n=*/32, + /*block_k=*/32, /*split_k=*/1, + /*num_stages=*/1, /*num_warps=*/4, + /*num_ctas=*/1}; + TritonGemmConfig 
expected = {/*block_m=*/32, /*block_n=*/16, + /*block_k=*/32, /*split_k=*/2, + /*num_stages=*/1, /*num_warps=*/4, + /*num_ctas=*/1}; + + EXPECT_THAT( + search_space.OptimizeConfigSet( + search_space.GenerateConfigs(/*force_contracting_split=*/2), {hint}), + ElementsAre(expected)); +} + +TEST_F(DotSearchSpaceTest, ReturnsNonEmptySetForUnusualHints) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule(/*lhs_parallel_dim=*/4096, + /*rhs_parallel_dim=*/4096)); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + + TritonGemmConfig hint = {/*block_m=*/1024, /*block_n=*/1024, + /*block_k=*/32, /*split_k=*/1, + /*num_stages=*/1, /*num_warps=*/4, + /*num_ctas=*/1}; + + EXPECT_THAT( + search_space.OptimizeConfigSet(search_space.GenerateConfigs(), {hint}), + Not(IsEmpty())); +} + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index e521e9d07fb149..fca9dc1a60a601 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -901,6 +901,11 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { /*force_contracting_split=*/autotune_contracting_split ? std::nullopt : std::make_optional(1)); + if (!debug_options_.xla_gpu_exhaustive_tiling_search()) { + VLOG(1) << "Restricting configs to the default set."; + configs = search_space.OptimizeConfigSet( + configs, /*hints=*/GetDefaultTritonConfigs()); + } if (!IsAutotuningEnabled()) { // Keep the first config, which likely does not spill registers. configs.resize(1); From 846203befdd86046854bcd7768eb685df9914120 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 17 Apr 2025 08:18:22 -0700 Subject: [PATCH 0925/1324] Fix algebraic simplifier for sharded pad PiperOrigin-RevId: 748687434 --- .../xla/xla/hlo/transforms/simplifiers/BUILD | 1 + .../simplifiers/algebraic_simplifier.cc | 1 + .../simplifiers/algebraic_simplifier_test.cc | 25 +++++++++++++++++++ 3 files changed, 27 insertions(+) diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD index de6d92b820ffce..430eaa10d4b882 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD +++ b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD @@ -411,6 +411,7 @@ xla_cc_test( "//xla/service:memory_annotations_hdr", "//xla/service:pattern_matcher", "//xla/service:shape_inference", + "//xla/tests:test_utils", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index a7925cc1ab222a..4a13474b48a805 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -5728,6 +5728,7 @@ absl::Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { } *broadcast->mutable_shape() = broadcast_shape1; *broadcast->mutable_dimensions() = broadcast_dimensions; + broadcast->clear_sharding(); simplifier_->UpdateLayout(broadcast->mutable_shape()); auto pad2 = pad->AddInstruction(pad->CloneWithNewShape(pad_shape1)); *pad2->mutable_padding_config() = pad_config; diff --git 
a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc index 984af5cdce37c8..7cff2113e4aceb 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc @@ -59,6 +59,7 @@ limitations under the License. #include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/window_util.h" @@ -13011,5 +13012,29 @@ TEST_F(AlgebraicSimplifierTest, CopyReshapeToReshapeCopyWithHostCopies) { EXPECT_FALSE(simplifier.Run(m.get()).value()); } +TEST_F(AlgebraicSimplifierTest, SimplifyShardedPad) { + const char* hlo = R"( +HloModule test, num_partitions=4 + +ENTRY main { + c0 = f32[] constant(0) + c1 = f32[] constant(1) + b0 = f32[512,34,5]{2,1,0} broadcast(c0), dimensions={}, sharding={devices=[1,2,2]<=[2,2]T(1,0)} + ROOT pad = f32[512,46,5]{2,1,0} pad(b0, c1), padding=0_0x6_6x0_0, sharding={devices=[1,2,2]<=[2,2]T(1,0)} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast( + m::Pad(m::Broadcast(m::Constant()), m::Constant())))); + TF_EXPECT_OK(VerifyHloModule(m.get(), + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/true)); +} + } // namespace } // namespace xla From 041625e7745ba2591a3569d80277aabd05bb11db Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 17 Apr 2025 08:19:14 -0700 Subject: [PATCH 0926/1324] PR #25374: [ROCM] Fix invalid include path in triton patch Imported from GitHub PR https://github.com/openxla/xla/pull/25374 Fix the invalid include path in the triton patch, which breaks the ROCm CI build. Discussed in the XLA Google chat. It seems preferable to adjust the patch itself:
It seems preferable to adjust the patch itself. Comment by @akuegel:
```
I think adjusting the patch itself is probably better
In principle we aim to upstream these patches to Triton, and upstreaming a broken patch does not make sense
```
The build break was introduced by https://github.com/openxla/xla/commit/ca404278c0ab0b92169727d5e90c6e32662133b9
Copybara import of the project:

--
105ce45069335de390dc34b3550944e36b1157b5 by Alexandros Theodoridis :

Fix invalid include path in triton patch

Merging this change closes #25374

PiperOrigin-RevId: 748687648
---
 .../triton/llvm_integration/cl747619712.patch | 19 +++++++++----------
 .../triton/llvm_integration/cl747619712.patch | 19 +++++++++----------
 2 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/third_party/triton/llvm_integration/cl747619712.patch b/third_party/triton/llvm_integration/cl747619712.patch
index 00d8b36216b3eb..3b796776b5c4ee 100644
--- a/third_party/triton/llvm_integration/cl747619712.patch
+++ b/third_party/triton/llvm_integration/cl747619712.patch
@@ -1,14 +1,14 @@
 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp	2025-04-11 01:29:32.000000000 -0700
 +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp	2025-04-14 17:05:58.000000000 -0700
-@@ -12,6 +12,7 @@
+@@ -7,6 +7,7 @@
+ #include "Utility.h"
+ #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+ #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
++#include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinTypes.h" + #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" - #include "mlir/IR/ValueRange.h" - #include "mlir/Transforms/DialectConversion.h" -+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" - #include "triton/Conversion/TritonGPUToLLVM/Utility.h" - #include "triton/Dialect/Triton/IR/Types.h" - #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -366,7 +367,7 @@ auto cacheMod = op.getCache(); @@ -57,18 +57,17 @@ --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp 2025-04-11 01:29:32.000000000 -0700 +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp 2025-04-14 17:05:59.000000000 -0700 -@@ -1,9 +1,11 @@ +@@ -1,8 +1,10 @@ #include "Utility.h" + #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" ++#include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" -+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" - #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" @@ -49,7 +51,7 @@ Value createVectorMaskFromPredicate(RewriterBase &rewriter, Location loc, Value pred, int64_t vecSize) { From 7061630e8824be2434e7b4dd57925cfb296ce232 Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Thu, 17 Apr 2025 08:22:07 -0700 Subject: [PATCH 0927/1324] Improve error reporting by adding source location to error message of ASSERT_TRUE PiperOrigin-RevId: 748688345 --- third_party/xla/xla/tsl/platform/statusor.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/tsl/platform/statusor.h b/third_party/xla/xla/tsl/platform/statusor.h index f638fe3f2cda32..041bd35648afc1 100644 --- a/third_party/xla/xla/tsl/platform/statusor.h +++ b/third_party/xla/xla/tsl/platform/statusor.h @@ -100,9 +100,10 @@ using StatusOr ABSL_DEPRECATE_AND_INLINE() = absl::StatusOr; TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ rexpr); -#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ +#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + ASSERT_TRUE(statusor.status().ok()) \ + << ADD_SOURCE_LOCATION(statusor.status()); \ lhs = std::move(statusor).value() #define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) From dfbc884452610cafd659118dccee126d5dec4a23 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 17 Apr 2025 08:44:27 -0700 Subject: [PATCH 0928/1324] Integrate LLVM at llvm/llvm-project@ffd5b148941a Updates LLVM usage to match [ffd5b148941a](https://github.com/llvm/llvm-project/commit/ffd5b148941a) PiperOrigin-RevId: 748693996 --- third_party/llvm/generated.patch | 336 +++++++-- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/temporary.patch | 698 ++++++++++++------ third_party/shardy/workspace.bzl | 4 +- .../xla/third_party/shardy/temporary.patch | 698 ++++++++++++------ .../xla/third_party/shardy/workspace.bzl | 4 +- 6 files changed, 1269 insertions(+), 475 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 436c4e97f096ef..2337741d5d8c43 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -14,51 +14,6 @@ diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/ } /// Select the correct "usual" deallocation function to use from a selection of -diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ---- a/clang/lib/Serialization/ASTReaderStmt.cpp -+++ b/clang/lib/Serialization/ASTReaderStmt.cpp -@@ -2226,10 +2226,7 @@ - E->AssociatedDeclAndRef.setPointer(readDeclAs()); - E->AssociatedDeclAndRef.setInt(CurrentUnpackingBits->getNextBit()); - E->Index = CurrentUnpackingBits->getNextBits(/*Width=*/12); -- if (CurrentUnpackingBits->getNextBit()) -- E->PackIndex = Record.readInt(); -- else -- E->PackIndex = 0; -+ E->PackIndex = Record.readUnsignedOrNone().toInternalRepresentation(); - E->Final = CurrentUnpackingBits->getNextBit(); - E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); - E->Replacement = Record.readSubExpr(); -@@ -2239,6 +2236,7 @@ - SubstNonTypeTemplateParmPackExpr *E) { - VisitExpr(E); - E->AssociatedDecl = readDeclAs(); -+ E->Final = CurrentUnpackingBits->getNextBit(); - E->Index = Record.readInt(); - TemplateArgument ArgPack = Record.readTemplateArgument(); - if (ArgPack.getKind() != TemplateArgument::Pack) -diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ---- a/clang/lib/Serialization/ASTWriterStmt.cpp -+++ b/clang/lib/Serialization/ASTWriterStmt.cpp -@@ -2228,9 +2228,7 @@ - Record.AddDeclRef(E->getAssociatedDecl()); - CurrentPackingBits.addBit(E->isReferenceParameter()); - CurrentPackingBits.addBits(E->getIndex(), /*Width=*/12); -- CurrentPackingBits.addBit((bool)E->getPackIndex()); -- if (auto PackIndex = E->getPackIndex()) -- Record.push_back(*PackIndex + 1); -+ Record.writeUnsignedOrNone(E->getPackIndex()); - CurrentPackingBits.addBit(E->getFinal()); - - Record.AddSourceLocation(E->getNameLoc()); -@@ -2242,6 +2240,7 @@ - SubstNonTypeTemplateParmPackExpr *E) { - VisitExpr(E); - Record.AddDeclRef(E->getAssociatedDecl()); -+ CurrentPackingBits.addBit(E->getFinal()); - Record.push_back(E->getIndex()); - Record.AddTemplateArgument(E->getArgumentPack()); - Record.AddSourceLocation(E->getParameterPackLocation()); diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bug135668.cpp b/clang/test/CodeGenCXX/bug135668.cpp --- a/clang/test/CodeGenCXX/bug135668.cpp +++ b/clang/test/CodeGenCXX/bug135668.cpp @@ -130,3 +85,294 @@ diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/bug135668.cpp b/clang/test/Se + TestClass *obj = new TestClass() ; + return obj->field; +} +diff -ruN --strip-trailing-cr a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp 
b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp ++++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +@@ -25183,7 +25183,7 @@ + return SDValue(); + + auto *Ld = dyn_cast(Extract->getOperand(0)); +- if (!Ld || Ld->getExtensionType() || !Ld->isSimple()) ++ if (!Ld || !ISD::isNormalLoad(Ld) || !Ld->isSimple()) + return SDValue(); + + // Allow targets to opt-out. +diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +@@ -7241,6 +7241,8 @@ + return Res.takeVector(); + }; + auto GetNumOperands = [](const TreeEntry *TE) { ++ if (TE->State == TreeEntry::SplitVectorize) ++ return TE->getNumOperands(); + if (auto *CI = dyn_cast(TE->getMainOp()); CI) + return CI->arg_size(); + return TE->getNumOperands(); +@@ -18064,8 +18066,14 @@ + // need to rebuild it. + EntryToLastInstruction.clear(); + // All blocks must be scheduled before any instructions are inserted. +- for (auto &BSIter : BlocksSchedules) { ++ for (auto &BSIter : BlocksSchedules) + scheduleBlock(BSIter.second.get()); ++ // Cache last instructions for the nodes to avoid side effects, which may ++ // appear during vectorization, like extra uses, etc. ++ for (const std::unique_ptr &TE : VectorizableTree) { ++ if (TE->isGather()) ++ continue; ++ (void)getLastInstructionInBundle(TE.get()); + } + + if (ReductionRoot) +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AArch64/pr135821.ll b/llvm/test/CodeGen/AArch64/pr135821.ll +--- a/llvm/test/CodeGen/AArch64/pr135821.ll ++++ b/llvm/test/CodeGen/AArch64/pr135821.ll +@@ -0,0 +1,27 @@ ++; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 ++; RUN: llc < %s -mtriple=aarch64-unknown-linux-gnu | FileCheck %s ++ ++define <4 x float> @f(ptr %0) { ++; CHECK-LABEL: f: ++; CHECK: // %bb.0: ++; CHECK-NEXT: sub sp, sp, #32 ++; CHECK-NEXT: str x30, [sp, #16] // 8-byte Folded Spill ++; CHECK-NEXT: .cfi_def_cfa_offset 32 ++; CHECK-NEXT: .cfi_offset w30, -16 ++; CHECK-NEXT: ldr q1, [x0, #56]! 
++; CHECK-NEXT: ldr d0, [x0, #16] ++; CHECK-NEXT: mov v1.d[1], v0.d[0] ++; CHECK-NEXT: str q1, [sp] // 16-byte Folded Spill ++; CHECK-NEXT: bl use ++; CHECK-NEXT: ldr q0, [sp] // 16-byte Folded Reload ++; CHECK-NEXT: ldr x30, [sp, #16] // 8-byte Folded Reload ++; CHECK-NEXT: add sp, sp, #32 ++; CHECK-NEXT: ret ++ %2 = getelementptr inbounds nuw i8, ptr %0, i64 56 ++ %3 = load <6 x float>, ptr %2, align 4 ++ %4 = shufflevector <6 x float> %3, <6 x float> poison, <4 x i32> ++ tail call void @use(ptr %2) ++ ret <4 x float> %4 ++} ++ ++declare void @use(ptr) +diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll b/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll +--- a/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll ++++ b/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll +@@ -0,0 +1,91 @@ ++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 ++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-generic-linux-gnu < %s | FileCheck %s ++ ++define void @test(ptr %nExp, float %0, i1 %cmp, float %1) { ++; CHECK-LABEL: define void @test( ++; CHECK-SAME: ptr [[NEXP:%.*]], float [[TMP0:%.*]], i1 [[CMP:%.*]], float [[TMP1:%.*]]) { ++; CHECK-NEXT: [[ENTRY:.*]]: ++; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x float> , float [[TMP1]], i32 2 ++; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x float> [[TMP2]], float [[TMP0]], i32 3 ++; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]] ++; CHECK: [[IF_THEN]]: ++; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[NEXP]], align 4 ++; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x float> [[TMP3]], <4 x float> poison, <2 x i32> ++; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x float> [[TMP5]], float [[TMP4]], i32 0 ++; CHECK-NEXT: [[TMP7:%.*]] = fmul <2 x float> [[TMP6]], zeroinitializer ++; CHECK-NEXT: [[TMP8:%.*]] = fmul <2 x float> [[TMP5]], zeroinitializer ++; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x float> , float [[TMP1]], i32 3 ++; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <4 x i32> ++; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x float> [[TMP9]], <4 x float> [[TMP10]], <4 x i32> ++; CHECK-NEXT: br label %[[IF_END]] ++; CHECK: [[IF_END]]: ++; CHECK-NEXT: [[TMP12:%.*]] = phi <4 x float> [ [[TMP11]], %[[IF_THEN]] ], [ [[TMP3]], %[[ENTRY]] ] ++; CHECK-NEXT: [[TMP13:%.*]] = phi <2 x float> [ [[TMP8]], %[[IF_THEN]] ], [ zeroinitializer, %[[ENTRY]] ] ++; CHECK-NEXT: [[TMP14:%.*]] = phi <2 x float> [ zeroinitializer, %[[IF_THEN]] ], [ , %[[ENTRY]] ] ++; CHECK-NEXT: [[TMP15:%.*]] = phi <2 x float> [ [[TMP7]], %[[IF_THEN]] ], [ zeroinitializer, %[[ENTRY]] ] ++; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <2 x float> [[TMP14]], <2 x float> , <2 x i32> ++; CHECK-NEXT: [[TMP17:%.*]] = fmul <2 x float> [[TMP15]], [[TMP16]] ++; CHECK-NEXT: [[TMP18:%.*]] = fmul <2 x float> [[TMP13]], [[TMP14]] ++; CHECK-NEXT: [[TMP19:%.*]] = fmul <4 x float> [[TMP12]], zeroinitializer ++; CHECK-NEXT: [[CALL25:%.*]] = load volatile ptr, ptr null, align 8 ++; CHECK-NEXT: [[TMP20:%.*]] = fadd <2 x float> [[TMP18]], [[TMP17]] ++; CHECK-NEXT: [[TMP21:%.*]] = fmul <2 x float> [[TMP20]], zeroinitializer ++; CHECK-NEXT: [[TMP22:%.*]] = fadd <2 x float> [[TMP21]], zeroinitializer ++; CHECK-NEXT: [[TMP23:%.*]] = fmul <4 x float> [[TMP19]], zeroinitializer ++; CHECK-NEXT: [[TMP24:%.*]] = fadd <4 x float> [[TMP19]], zeroinitializer ++; CHECK-NEXT: 
[[TMP25:%.*]] = shufflevector <4 x float> [[TMP23]], <4 x float> [[TMP24]], <4 x i32> ++; CHECK-NEXT: [[TMP26:%.*]] = call <4 x float> @llvm.vector.insert.v4f32.v2f32(<4 x float> , <2 x float> [[TMP22]], i64 2) ++; CHECK-NEXT: [[TMP27:%.*]] = fadd <4 x float> [[TMP25]], [[TMP26]] ++; CHECK-NEXT: store <4 x float> [[TMP27]], ptr [[CALL25]], align 4 ++; CHECK-NEXT: ret void ++; ++entry: ++ br i1 %cmp, label %if.then, label %if.end ++ ++if.then: ++ %div.i41 = fmul float %0, 0.000000e+00 ++ %2 = load float, ptr %nExp, align 4 ++ %div.1.i.i = fmul float %2, 0.000000e+00 ++ %div.2.i.i = fmul float %0, 0.000000e+00 ++ br label %if.end ++ ++if.end: ++ %3 = phi float [ %1, %if.then ], [ %0, %entry ] ++ %4 = phi float [ 0.000000e+00, %if.then ], [ %1, %entry ] ++ %5 = phi float [ 0.000000e+00, %if.then ], [ 0x7FF8000000000000, %entry ] ++ %6 = phi float [ 0.000000e+00, %if.then ], [ 1.000000e+00, %entry ] ++ %fa.sroa.9.0 = phi float [ %div.2.i.i, %if.then ], [ 0.000000e+00, %entry ] ++ %fa.sroa.7.0 = phi float [ %div.1.i.i, %if.then ], [ 0.000000e+00, %entry ] ++ %fa.sroa.0.0 = phi float [ %div.i41, %if.then ], [ 0.000000e+00, %entry ] ++ %mul.1.i.i58 = fmul float %fa.sroa.7.0, %6 ++ %mul.2.i.i60 = fmul float %fa.sroa.9.0, %6 ++ %mul.1.i.i.i63 = fmul float %fa.sroa.0.0, %5 ++ %mul.2.i.i.i65 = fmul float %fa.sroa.0.0, 0.000000e+00 ++ %mul.i66 = fmul float %fa.sroa.0.0, 0.000000e+00 ++ %add.1.i.i = fadd float %mul.1.i.i58, %mul.1.i.i.i63 ++ %add.2.i.i = fadd float %mul.2.i.i60, %mul.2.i.i.i65 ++ %mul.1.i.i74 = fmul float %add.1.i.i, 0.000000e+00 ++ %mul.2.i.i76 = fmul float %add.2.i.i, 0.000000e+00 ++ %mul.i.i.i78 = fmul float %mul.i66, 0.000000e+00 ++ %add.1.i.i85 = fadd float %mul.1.i.i74, 0.000000e+00 ++ %add.2.i.i86 = fadd float %mul.2.i.i76, 0.000000e+00 ++ %mul.i.i.i97 = fmul float %5, 0.000000e+00 ++ %mul.1.i.i.i99 = fmul float %4, 0.000000e+00 ++ %mul.2.i.i.i101 = fmul float %3, 0.000000e+00 ++ %add.i.i103 = fadd float %mul.i.i.i97, 0.000000e+00 ++ %add.1.i.i104 = fadd float %mul.1.i.i.i99, 0.000000e+00 ++ %add.2.i.i105 = fadd float %mul.2.i.i.i101, 0.000000e+00 ++ %add = fadd float %mul.i.i.i78, 0.000000e+00 ++ %add.i = fadd float %add.i.i103, 1.000000e+00 ++ %add.1.i = fadd float %add.1.i.i104, %add.1.i.i85 ++ %add.2.i = fadd float %add.2.i.i105, %add.2.i.i86 ++ %call25 = load volatile ptr, ptr null, align 8 ++ store float %add, ptr %call25, align 4 ++ %__trans_tmp_29.sroa.5.0.call25.sroa_idx = getelementptr i8, ptr %call25, i64 4 ++ store float %add.i, ptr %__trans_tmp_29.sroa.5.0.call25.sroa_idx, align 4 ++ %__trans_tmp_29.sroa.6.0.call25.sroa_idx = getelementptr i8, ptr %call25, i64 8 ++ store float %add.1.i, ptr %__trans_tmp_29.sroa.6.0.call25.sroa_idx, align 4 ++ %__trans_tmp_29.sroa.7.0.call25.sroa_idx = getelementptr i8, ptr %call25, i64 12 ++ store float %add.2.i, ptr %__trans_tmp_29.sroa.7.0.call25.sroa_idx, align 4 ++ ret void ++} +diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll +--- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll ++++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll +@@ -0,0 +1,121 @@ ++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 ++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu -mattr=+avx -slp-threshold=-1000 < %s | FileCheck %s ++ ++define i64 @Foo(ptr align 8 dereferenceable(344) %0, i64 %1) { ++; CHECK-LABEL: define i64 @Foo( ++; 
CHECK-SAME: ptr align 8 dereferenceable(344) [[TMP0:%.*]], i64 [[TMP1:%.*]]) #[[ATTR0:[0-9]+]] { ++; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP0]], i64 104 ++; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 112 ++; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 24 ++; CHECK-NEXT: [[TMP6:%.*]] = load i64, ptr [[TMP3]], align 8 ++; CHECK-NEXT: [[TMP7:%.*]] = load i64, ptr [[TMP4]], align 8 ++; CHECK-NEXT: [[TMP8:%.*]] = load i64, ptr [[TMP5]], align 8 ++; CHECK-NEXT: [[TMP9:%.*]] = load i64, ptr [[TMP0]], align 8 ++; CHECK-NEXT: [[TMP10:%.*]] = insertelement <2 x i64> poison, i64 [[TMP6]], i32 0 ++; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x i64> [[TMP10]], i64 [[TMP9]], i32 1 ++; CHECK-NEXT: [[TMP12:%.*]] = insertelement <2 x i64> poison, i64 [[TMP7]], i32 0 ++; CHECK-NEXT: [[TMP13:%.*]] = insertelement <2 x i64> [[TMP12]], i64 [[TMP8]], i32 1 ++; CHECK-NEXT: [[TMP14:%.*]] = insertelement <2 x i64> poison, i64 0, i32 0 ++; CHECK-NEXT: [[TMP15:%.*]] = insertelement <2 x i64> , i64 [[TMP1]], i32 1 ++; CHECK-NEXT: br label %[[BB16:.*]] ++; CHECK: [[BB16]]: ++; CHECK-NEXT: [[TMP17:%.*]] = phi <2 x i64> [ [[TMP11]], [[TMP2:%.*]] ], [ zeroinitializer, %[[TMP25:.*]] ] ++; CHECK-NEXT: [[TMP18:%.*]] = phi <2 x i64> [ [[TMP13]], [[TMP2]] ], [ [[TMP29:%.*]], %[[TMP25]] ] ++; CHECK-NEXT: switch i32 0, label %[[BB19:.*]] [ ++; CHECK-NEXT: i32 0, label %[[TMP25]] ++; CHECK-NEXT: ] ++; CHECK: [[BB19]]: ++; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x i64> [[TMP18]], <2 x i64> poison, <4 x i32> ++; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x i64> [[TMP20]], i64 0, i32 1 ++; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x i64> [[TMP21]], i64 0, i32 2 ++; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <4 x i64> [[TMP22]], <4 x i64> poison, <4 x i32> ++; CHECK-NEXT: [[TMP24:%.*]] = shufflevector <2 x i64> [[TMP14]], <2 x i64> [[TMP18]], <2 x i32> ++; CHECK-NEXT: br label %[[TMP25]] ++; CHECK: [[TMP25]]: ++; CHECK-NEXT: [[TMP26:%.*]] = phi <2 x i64> [ [[TMP17]], %[[BB19]] ], [ zeroinitializer, %[[BB16]] ] ++; CHECK-NEXT: [[TMP27:%.*]] = phi <4 x i64> [ [[TMP23]], %[[BB19]] ], [ zeroinitializer, %[[BB16]] ] ++; CHECK-NEXT: [[TMP28:%.*]] = phi <2 x i64> [ [[TMP24]], %[[BB19]] ], [ [[TMP15]], %[[BB16]] ] ++; CHECK-NEXT: [[TMP29]] = shufflevector <2 x i64> [[TMP18]], <2 x i64> , <2 x i32> ++; CHECK-NEXT: br i1 false, label %[[DOTLOOPEXIT206:.*]], label %[[BB16]] ++; CHECK: [[_LOOPEXIT206:.*:]] ++; CHECK-NEXT: switch i32 0, label %[[BB32:.*]] [ ++; CHECK-NEXT: i32 0, [[DOTCONT174:label %.*]] ++; CHECK-NEXT: i32 1, label %[[BB30:.*]] ++; CHECK-NEXT: ] ++; CHECK: [[BB30]]: ++; CHECK-NEXT: [[TMP31:%.*]] = shufflevector <4 x i64> [[TMP27]], <4 x i64> , <4 x i32> ++; CHECK-NEXT: br [[DOTCONT174]] ++; CHECK: [[BB32]]: ++; CHECK-NEXT: [[TMP33:%.*]] = insertelement <4 x i64> [[TMP27]], i64 0, i32 1 ++; CHECK-NEXT: [[TMP34:%.*]] = insertelement <4 x i64> [[TMP33]], i64 0, i32 2 ++; CHECK-NEXT: [[TMP35:%.*]] = shufflevector <4 x i64> [[TMP34]], <4 x i64> poison, <4 x i32> ++; CHECK-NEXT: [[TMP36:%.*]] = insertelement <2 x i64> [[TMP28]], i64 0, i32 0 ++; CHECK-NEXT: br [[DOTCONT174]] ++; CHECK: [[_CONT174:.*:]] ++; CHECK-NEXT: [[TMP37:%.*]] = phi <2 x i64> [ [[TMP26]], %[[BB32]] ], [ zeroinitializer, %[[BB30]] ], [ [[TMP26]], %[[DOTLOOPEXIT206]] ] ++; CHECK-NEXT: [[TMP38:%.*]] = phi <4 x i64> [ [[TMP35]], %[[BB32]] ], [ [[TMP31]], %[[BB30]] ], [ [[TMP27]], %[[DOTLOOPEXIT206]] ] ++; CHECK-NEXT: [[TMP39:%.*]] = phi <2 x i64> [ [[TMP36]], %[[BB32]] ], [ zeroinitializer, %[[BB30]] 
], [ [[TMP28]], %[[DOTLOOPEXIT206]] ] ++; CHECK-NEXT: ret i64 0 ++; ++ %3 = getelementptr i8, ptr %0, i64 104 ++ %4 = getelementptr i8, ptr %0, i64 112 ++ %5 = getelementptr i8, ptr %0, i64 24 ++ %6 = load i64, ptr %3, align 8 ++ %7 = load i64, ptr %4, align 8 ++ %8 = load i64, ptr %5, align 8 ++ %9 = load i64, ptr %0, align 8 ++ br label %10 ++ ++10: ++ %11 = phi i64 [ %9, %2 ], [ 0, %18 ] ++ %12 = phi i64 [ %8, %2 ], [ %12, %18 ] ++ %13 = phi i64 [ %7, %2 ], [ 0, %18 ] ++ %14 = phi i64 [ %6, %2 ], [ 0, %18 ] ++ switch i32 0, label %15 [ ++ i32 0, label %18 ++ ] ++ ++15: ++ %16 = tail call i64 @llvm.umin.i64(i64 0, i64 0) ++ %17 = tail call i64 @llvm.umax.i64(i64 0, i64 0) ++ br label %18 ++ ++18: ++ %19 = phi i64 [ %17, %15 ], [ 0, %10 ] ++ %20 = phi i64 [ %16, %15 ], [ 0, %10 ] ++ %21 = phi i64 [ %11, %15 ], [ 0, %10 ] ++ %22 = phi i64 [ %12, %15 ], [ 0, %10 ] ++ %23 = phi i64 [ %13, %15 ], [ %1, %10 ] ++ %24 = phi i64 [ %14, %15 ], [ 0, %10 ] ++ br i1 false, label %.loopexit206, label %10 ++ ++.loopexit206: ++ switch i32 0, label %26 [ ++ i32 0, label %.cont174 ++ i32 1, label %25 ++ ] ++ ++25: ++ br label %.cont174 ++ ++26: ++ %27 = tail call i64 @llvm.umin.i64(i64 0, i64 0) ++ %28 = tail call i64 @llvm.umax.i64(i64 0, i64 0) ++ br label %.cont174 ++ ++.cont174: ++ %.sroa.139.1 = phi i64 [ %28, %26 ], [ %19, %25 ], [ %19, %.loopexit206 ] ++ %.sroa.133.1 = phi i64 [ %27, %26 ], [ 0, %25 ], [ %20, %.loopexit206 ] ++ %.sroa.81.1 = phi i64 [ %23, %26 ], [ 0, %25 ], [ %23, %.loopexit206 ] ++ %.sroa.75.1 = phi i64 [ %24, %26 ], [ 0, %25 ], [ %24, %.loopexit206 ] ++ %.sroa.21.1 = phi i64 [ %21, %26 ], [ 0, %25 ], [ %21, %.loopexit206 ] ++ %.sroa.15.1 = phi i64 [ %22, %26 ], [ 0, %25 ], [ %22, %.loopexit206 ] ++ %29 = phi i64 [ %28, %26 ], [ 0, %25 ], [ %19, %.loopexit206 ] ++ %30 = phi i64 [ %27, %26 ], [ 0, %25 ], [ %20, %.loopexit206 ] ++ ret i64 0 ++} ++ ++declare i64 @llvm.umax.i64(i64, i64) ++ ++declare i64 @llvm.umin.i64(i64, i64) ++ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 3ec4c3ec618e1f..d44a9f6d84632d 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "179d30f8c3fddd3c85056fd2b8e877a4a8513158" - LLVM_SHA256 = "39f33d0ba77ca40d254c767519a0f3f5692c2caa271f413e7245ab63d0787bd5" + LLVM_COMMIT = "ffd5b148941a1146378a247c70c4faface3a1f96" + LLVM_SHA256 = "fc57e9b703ddfb6d888e1c5beb2a65ca8d84d439bcf88c63eb014ccb8bbea414" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 4779b912722730..b4bbb1e1fef6ea 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,230 +1,504 @@ +diff --git a/docs/sdy_dialect.md b/docs/sdy_dialect.md +index 7b9e18c..ef83d35 100755 +--- a/docs/sdy_dialect.md ++++ b/docs/sdy_dialect.md +@@ -467,7 +467,6 @@ the body on any free axes - those not in the manual_axes list. + - Elements in `in_shardings` and `out_shardings` must satisfy the constraints listed in `TensorShardingAttr`. + - The number of global and local tensor inputs/outputs of the op region must match. + - The manual axes must come before any free axes in each dim sharding. +-- The manual axes cannot introduce padding. Namely, the dimension size must be divisible by the corresponding manual axes size. + - The global and local shapes of the op regions arguments/results must match. + - No manual axes are split. 
+ +diff --git a/shardy/dialect/sdy/ir/ops.td b/shardy/dialect/sdy/ir/ops.td +index d598517..07bfa11 100644 +--- a/shardy/dialect/sdy/ir/ops.td ++++ b/shardy/dialect/sdy/ir/ops.td +@@ -145,7 +145,6 @@ def Sdy_ManualComputationOp : Sdy_Op<"manual_computation", + - Elements in `in_shardings` and `out_shardings` must satisfy the constraints listed in `TensorShardingAttr`. + - The number of global and local tensor inputs/outputs of the op region must match. + - The manual axes must come before any free axes in each dim sharding. +- - The manual axes cannot introduce padding. Namely, the dimension size must be divisible by the corresponding manual axes size. + - The global and local shapes of the op regions arguments/results must match. + - No manual axes are split. + }]; +diff --git a/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir b/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir +index c17ea23..139e1f2 100644 +--- a/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir ++++ b/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir +@@ -165,19 +165,6 @@ func.func @man_comp_result_rank_mistmatch(%arg0: tensor<16x32xf32>) -> tensor<16 + + // ----- + +-sdy.mesh @mesh = <["a"=4]> +- +-func.func @dimension_size_not_divisible_by_manual_axes_size(%arg0: tensor<6xf32>) -> tensor<6xf32> { +- // expected-error @+1 {{dimension size 6 is not divisible by the manual axes size 4}} +- %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"a"}]>] out_shardings=[<@mesh, [{"a"}]>] manual_axes={"a"} (%arg1: tensor<1xf32>) { +- %1 = stablehlo.add %arg1, %arg1 : tensor<1xf32> +- sdy.return %1 : tensor<1xf32> +- } : (tensor<6xf32>) -> tensor<6xf32> +- func.return %0: tensor<6xf32> +-} +- +-// ----- +- + sdy.mesh @mesh = <["a"=2]> + + func.func @man_comp_operand_shape_mismatch_replicated(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { +diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc +index e3edcfe..87fc0ee 100644 +--- a/shardy/dialect/sdy/ir/verifiers.cc ++++ b/shardy/dialect/sdy/ir/verifiers.cc +@@ -801,11 +801,9 @@ ArrayRef::iterator findManualAxisAfterFreeAxis( + // 3. the number of global and local tensor inputs/outputs of the op region + // match, + // 4. the manual axes come before any free axes in each dim sharding, +-// 5. The manual axes cannot introduce padding. The dimension size must be +-// divisible by the corresponding manual axes size. +-// 6. the global shape and local shapes of the op regions arguments/results ++// 5. the global shape and local shapes of the op regions arguments/results + // match, and +-// 7. No manual axes are split. ++// 6. No manual axes are split. + // + // `valueKindStr` is a string included in any verification error message + // specifying whether the values we are verifying are the operands or results. +@@ -864,6 +862,8 @@ LogicalResult verifyManualComputationValue( + } + } + ++ // 5. Verify the global shape and local shapes of the op regions ++ // arguments/results match. + SmallVector newDimSizes; + auto globalRankedType = mlir::cast(globalType); + for (auto [dimensionSize, dimSharding] : llvm::zip_equal( +@@ -871,24 +871,13 @@ LogicalResult verifyManualComputationValue( + if (dimensionSize == ShapedType::kDynamic) { + newDimSizes.push_back(ShapedType::kDynamic); + } else { +- // 5. The manual axes cannot introduce padding. The dimension size must +- // be divisible by the corresponding manual axes size. +- + // Safe to call `getMesh` because the sharding was already verified. 
+- int64_t manualAxesSize = ++ newDimSizes.push_back( ++ dimensionSize / + accumulatedManualAxesSize(op, dimSharding.getAxes(), manualAxesSet, +- sharding.getMesh(symbolTable)); +- if (dimensionSize % manualAxesSize != 0) { +- return op->emitOpError(valueKindStr) +- << " dimension size " << dimensionSize +- << " is not divisible by the manual axes size " +- << manualAxesSize; +- } +- newDimSizes.push_back(dimensionSize / manualAxesSize); ++ sharding.getMesh(symbolTable))); + } + } +- // 6. Verify the global shape and local shapes of the op regions +- // arguments/results match. + auto expectedLocalRankedType = + RankedTensorType::get(newDimSizes, globalRankedType.getElementType()); + auto localRankedType = mlir::cast(localType); +@@ -900,7 +889,7 @@ LogicalResult verifyManualComputationValue( + << ", actual local shape " << localRankedType; + } + +- // 7. No manual axes are split. ++ // 6. No manual axes are split. + if (sharding.anyOfAxisRef([&](AxisRefAttr axis) { + return axis.getSubAxisInfo() && + manualAxesSet.contains(axis.getName()); +diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir +index 96c631a..a4ca997 100644 +--- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir ++++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir +@@ -1716,7 +1716,7 @@ func.func @manual_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.shardi + } + + // CHECK-LABEL: func @manual_computation_with_manual_axes +-func.func @manual_computation_with_manual_axes(%arg0: tensor<208xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"x","y"}]>}) -> (tensor<208xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"x","z"}]>}) { ++func.func @manual_computation_with_manual_axes(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"x","y"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"x","z"}]>}) { + %0 = sdy.manual_computation(%arg0) + in_shardings=[<@mesh_xyzt, [{"x","y"}]>] out_shardings=[<@mesh_xyzt, [{"x", "z"}]>] manual_axes={"x"} (%arg1: tensor<52xf32>) { + // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_xyzt, [{"t"}]> : tensor<52xf32> +@@ -1725,9 +1725,9 @@ func.func @manual_computation_with_manual_axes(%arg0: tensor<208xf32> {sdy.shard + // CHECK-NEXT: sdy.return %[[RESHARD2]] : tensor<52xf32> + %2 = stablehlo.abs %arg1 {sdy.sharding=#sdy.sharding_per_value<[<@mesh_xyzt, [{"t"}]>]>} : tensor<52xf32> + sdy.return %2 : tensor<52xf32> +- } : (tensor<208xf32>) -> (tensor<208xf32>) +- %1 = stablehlo.negate %0 {sdy.sharding= #sdy.sharding_per_value<[<@mesh_xyzt, [{"x","z"}]>]>} : tensor<208xf32> +- return %1 : tensor<208xf32> ++ } : (tensor<210xf32>) -> (tensor<210xf32>) ++ %1 = stablehlo.negate %0 {sdy.sharding= #sdy.sharding_per_value<[<@mesh_xyzt, [{"x","z"}]>]>} : tensor<210xf32> ++ return %1 : tensor<210xf32> + } + + // CHECK-LABEL: func @optimization_barrier diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index bbffc2f..436c4e9 100644 +index 436c4e9..2337741 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,49 +1,132 @@ - Auto generated patch. Do not edit or delete it, even if empty. 
--diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h ----- a/libc/src/__support/FPUtil/aarch64/sqrt.h --+++ b/libc/src/__support/FPUtil/aarch64/sqrt.h --@@ -18,6 +18,8 @@ -- #error "Invalid include" -- #endif -+diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp -+--- a/clang/lib/Sema/SemaExprCXX.cpp -++++ b/clang/lib/Sema/SemaExprCXX.cpp -+@@ -1929,8 +1929,9 @@ -+ } -+ return true; -+ } -+- -+- return S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); -++ Sema::AccessResult Accessible = -++ S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); -++ return Accessible == Sema::AR_inaccessible; -+ } +@@ -14,51 +14,6 @@ diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/ + } --+#include "src/__support/FPUtil/generic/sqrt.h" --+ -- namespace LIBC_NAMESPACE_DECL { -- namespace fputil { + /// Select the correct "usual" deallocation function to use from a selection of +-diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp +---- a/clang/lib/Serialization/ASTReaderStmt.cpp +-+++ b/clang/lib/Serialization/ASTReaderStmt.cpp +-@@ -2226,10 +2226,7 @@ +- E->AssociatedDeclAndRef.setPointer(readDeclAs()); +- E->AssociatedDeclAndRef.setInt(CurrentUnpackingBits->getNextBit()); +- E->Index = CurrentUnpackingBits->getNextBits(/*Width=*/12); +-- if (CurrentUnpackingBits->getNextBit()) +-- E->PackIndex = Record.readInt(); +-- else +-- E->PackIndex = 0; +-+ E->PackIndex = Record.readUnsignedOrNone().toInternalRepresentation(); +- E->Final = CurrentUnpackingBits->getNextBit(); +- E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); +- E->Replacement = Record.readSubExpr(); +-@@ -2239,6 +2236,7 @@ +- SubstNonTypeTemplateParmPackExpr *E) { +- VisitExpr(E); +- E->AssociatedDecl = readDeclAs(); +-+ E->Final = CurrentUnpackingBits->getNextBit(); +- E->Index = Record.readInt(); +- TemplateArgument ArgPack = Record.readTemplateArgument(); +- if (ArgPack.getKind() != TemplateArgument::Pack) +-diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp +---- a/clang/lib/Serialization/ASTWriterStmt.cpp +-+++ b/clang/lib/Serialization/ASTWriterStmt.cpp +-@@ -2228,9 +2228,7 @@ +- Record.AddDeclRef(E->getAssociatedDecl()); +- CurrentPackingBits.addBit(E->isReferenceParameter()); +- CurrentPackingBits.addBits(E->getIndex(), /*Width=*/12); +-- CurrentPackingBits.addBit((bool)E->getPackIndex()); +-- if (auto PackIndex = E->getPackIndex()) +-- Record.push_back(*PackIndex + 1); +-+ Record.writeUnsignedOrNone(E->getPackIndex()); +- CurrentPackingBits.addBit(E->getFinal()); - --diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h ----- a/libc/src/__support/FPUtil/arm/sqrt.h --+++ b/libc/src/__support/FPUtil/arm/sqrt.h --@@ -18,6 +18,8 @@ -- #error "Invalid include" -- #endif -+ /// Select the correct "usual" deallocation function to use from a selection of -+diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp -+--- a/clang/lib/Serialization/ASTReaderStmt.cpp -++++ b/clang/lib/Serialization/ASTReaderStmt.cpp -+@@ -2226,10 +2226,7 @@ -+ E->AssociatedDeclAndRef.setPointer(readDeclAs()); -+ E->AssociatedDeclAndRef.setInt(CurrentUnpackingBits->getNextBit()); -+ E->Index = 
CurrentUnpackingBits->getNextBits(/*Width=*/12); -+- if (CurrentUnpackingBits->getNextBit()) -+- E->PackIndex = Record.readInt(); -+- else -+- E->PackIndex = 0; -++ E->PackIndex = Record.readUnsignedOrNone().toInternalRepresentation(); -+ E->Final = CurrentUnpackingBits->getNextBit(); -+ E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); -+ E->Replacement = Record.readSubExpr(); -+@@ -2239,6 +2236,7 @@ -+ SubstNonTypeTemplateParmPackExpr *E) { -+ VisitExpr(E); -+ E->AssociatedDecl = readDeclAs(); -++ E->Final = CurrentUnpackingBits->getNextBit(); -+ E->Index = Record.readInt(); -+ TemplateArgument ArgPack = Record.readTemplateArgument(); -+ if (ArgPack.getKind() != TemplateArgument::Pack) -+diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp -+--- a/clang/lib/Serialization/ASTWriterStmt.cpp -++++ b/clang/lib/Serialization/ASTWriterStmt.cpp -+@@ -2228,9 +2228,7 @@ -+ Record.AddDeclRef(E->getAssociatedDecl()); -+ CurrentPackingBits.addBit(E->isReferenceParameter()); -+ CurrentPackingBits.addBits(E->getIndex(), /*Width=*/12); -+- CurrentPackingBits.addBit((bool)E->getPackIndex()); -+- if (auto PackIndex = E->getPackIndex()) -+- Record.push_back(*PackIndex + 1); -++ Record.writeUnsignedOrNone(E->getPackIndex()); -+ CurrentPackingBits.addBit(E->getFinal()); - --+#include "src/__support/FPUtil/generic/sqrt.h" -+ Record.AddSourceLocation(E->getNameLoc()); -+@@ -2242,6 +2240,7 @@ -+ SubstNonTypeTemplateParmPackExpr *E) { -+ VisitExpr(E); -+ Record.AddDeclRef(E->getAssociatedDecl()); -++ CurrentPackingBits.addBit(E->getFinal()); -+ Record.push_back(E->getIndex()); -+ Record.AddTemplateArgument(E->getArgumentPack()); -+ Record.AddSourceLocation(E->getParameterPackLocation()); -+diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bug135668.cpp b/clang/test/CodeGenCXX/bug135668.cpp -+--- a/clang/test/CodeGenCXX/bug135668.cpp -++++ b/clang/test/CodeGenCXX/bug135668.cpp -+@@ -0,0 +1,38 @@ -++// RUN: %clang_cc1 %s -triple arm64-apple-macosx -emit-llvm -fcxx-exceptions -fexceptions -std=c++23 -o - | FileCheck %s - + -- namespace LIBC_NAMESPACE_DECL { -- namespace fputil { -- --diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h ----- a/libc/src/__support/FPUtil/riscv/sqrt.h --+++ b/libc/src/__support/FPUtil/riscv/sqrt.h --@@ -18,6 +18,8 @@ -- #error "Invalid include" -- #endif -- --+#include "src/__support/FPUtil/generic/sqrt.h" -++class TestClass { -++ public: -++ TestClass(); -++ int field = 0; -++ friend class Foo; -++ static void * operator new(unsigned long size); -++ private: -++ static void operator delete(void *p); -++ }; - + -- namespace LIBC_NAMESPACE_DECL { -- namespace fputil { -- --diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h ----- a/libc/src/__support/FPUtil/x86_64/sqrt.h --+++ b/libc/src/__support/FPUtil/x86_64/sqrt.h --@@ -18,6 +18,8 @@ -- #error "sqrtss / sqrtsd need SSE2" -- #endif -- --+#include "src/__support/FPUtil/generic/sqrt.h" -++class Foo { -++public: -++ int test_method(); -++}; - + -- namespace LIBC_NAMESPACE_DECL { -- namespace fputil { -- -++int Foo::test_method() { -++ TestClass *obj = new TestClass() ; -++ return obj->field; +- Record.AddSourceLocation(E->getNameLoc()); +-@@ -2242,6 +2240,7 @@ +- SubstNonTypeTemplateParmPackExpr *E) { +- VisitExpr(E); +- Record.AddDeclRef(E->getAssociatedDecl()); +-+ CurrentPackingBits.addBit(E->getFinal()); +- 
Record.push_back(E->getIndex()); +- Record.AddTemplateArgument(E->getArgumentPack()); +- Record.AddSourceLocation(E->getParameterPackLocation()); + diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bug135668.cpp b/clang/test/CodeGenCXX/bug135668.cpp + --- a/clang/test/CodeGenCXX/bug135668.cpp + +++ b/clang/test/CodeGenCXX/bug135668.cpp +@@ -130,3 +85,294 @@ diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/bug135668.cpp b/clang/test/Se + + TestClass *obj = new TestClass() ; + + return obj->field; + +} ++diff -ruN --strip-trailing-cr a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp ++--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp ++@@ -25183,7 +25183,7 @@ ++ return SDValue(); ++ ++ auto *Ld = dyn_cast(Extract->getOperand(0)); ++- if (!Ld || Ld->getExtensionType() || !Ld->isSimple()) +++ if (!Ld || !ISD::isNormalLoad(Ld) || !Ld->isSimple()) ++ return SDValue(); ++ ++ // Allow targets to opt-out. ++diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++@@ -7241,6 +7241,8 @@ ++ return Res.takeVector(); ++ }; ++ auto GetNumOperands = [](const TreeEntry *TE) { +++ if (TE->State == TreeEntry::SplitVectorize) +++ return TE->getNumOperands(); ++ if (auto *CI = dyn_cast(TE->getMainOp()); CI) ++ return CI->arg_size(); ++ return TE->getNumOperands(); ++@@ -18064,8 +18066,14 @@ ++ // need to rebuild it. ++ EntryToLastInstruction.clear(); ++ // All blocks must be scheduled before any instructions are inserted. ++- for (auto &BSIter : BlocksSchedules) { +++ for (auto &BSIter : BlocksSchedules) ++ scheduleBlock(BSIter.second.get()); +++ // Cache last instructions for the nodes to avoid side effects, which may +++ // appear during vectorization, like extra uses, etc. +++ for (const std::unique_ptr &TE : VectorizableTree) { +++ if (TE->isGather()) +++ continue; +++ (void)getLastInstructionInBundle(TE.get()); ++ } ++ ++ if (ReductionRoot) ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AArch64/pr135821.ll b/llvm/test/CodeGen/AArch64/pr135821.ll ++--- a/llvm/test/CodeGen/AArch64/pr135821.ll +++++ b/llvm/test/CodeGen/AArch64/pr135821.ll ++@@ -0,0 +1,27 @@ +++; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +++; RUN: llc < %s -mtriple=aarch64-unknown-linux-gnu | FileCheck %s +++ +++define <4 x float> @f(ptr %0) { +++; CHECK-LABEL: f: +++; CHECK: // %bb.0: +++; CHECK-NEXT: sub sp, sp, #32 +++; CHECK-NEXT: str x30, [sp, #16] // 8-byte Folded Spill +++; CHECK-NEXT: .cfi_def_cfa_offset 32 +++; CHECK-NEXT: .cfi_offset w30, -16 +++; CHECK-NEXT: ldr q1, [x0, #56]! 
+++; CHECK-NEXT: ldr d0, [x0, #16] +++; CHECK-NEXT: mov v1.d[1], v0.d[0] +++; CHECK-NEXT: str q1, [sp] // 16-byte Folded Spill +++; CHECK-NEXT: bl use +++; CHECK-NEXT: ldr q0, [sp] // 16-byte Folded Reload +++; CHECK-NEXT: ldr x30, [sp, #16] // 8-byte Folded Reload +++; CHECK-NEXT: add sp, sp, #32 +++; CHECK-NEXT: ret +++ %2 = getelementptr inbounds nuw i8, ptr %0, i64 56 +++ %3 = load <6 x float>, ptr %2, align 4 +++ %4 = shufflevector <6 x float> %3, <6 x float> poison, <4 x i32> +++ tail call void @use(ptr %2) +++ ret <4 x float> %4 ++} ++ -++// CHECK-LABEL: define noundef i32 @_ZN3Foo11test_methodEv -++// CHECK: [[THIS_ADDR:%.*]] = alloca ptr, align 8 -++// CHECK: [[OBJ:%.*]] = alloca ptr, align 8 -++// CHECK: store ptr %this, ptr [[THIS_ADDR]], align 8 -++// CHECK: [[THIS1:%.*]] = load ptr, ptr [[THIS_ADDR]], align 8 -++// CHECK: [[ALLOCATION:%.*]] = call noundef ptr @_ZN9TestClassnwEm(i64 noundef 4) -++// CHECK: [[INITIALIZEDOBJ:%.*]] = invoke noundef ptr @_ZN9TestClassC1Ev(ptr noundef nonnull align 4 dereferenceable(4) [[ALLOCATION]]) -++// CHECK-NEXT: to label %[[INVOKE_CONT:.*]] unwind label %[[LPAD:.*]] -++// CHECK: [[INVOKE_CONT]]: -++// CHECK: store ptr [[ALLOCATION]], ptr [[OBJ]], align 8 -++// CHECK: [[OBJPTR:%.*]] = load ptr, ptr [[OBJ]], align 8 -++// CHECK: [[FIELDPTR:%.*]] = getelementptr inbounds nuw %class.TestClass, ptr [[OBJPTR]], i32 0, i32 0 -++// CHECK: [[FIELD:%.*]] = load i32, ptr [[FIELDPTR]], align 4 -++// CHECK: ret i32 [[FIELD]] -++// CHECK: [[LPAD]]: -++// CHECK: call void @_ZN9TestClassdlEPv(ptr noundef [[ALLOCATION]]) #3 -+diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/bug135668.cpp b/clang/test/SemaCXX/bug135668.cpp -+--- a/clang/test/SemaCXX/bug135668.cpp -++++ b/clang/test/SemaCXX/bug135668.cpp -+@@ -0,0 +1,25 @@ -++// RUN: %clang_cc1 -triple arm64-apple-macosx -Wall -fsyntax-only -verify %s -std=c++26 -fexceptions -fcxx-exceptions -++// expected-no-diagnostics -++ -++// This test makes sure that we don't erroneously consider an accessible operator -++// delete to be inaccessible, and then discard the entire new expression. 
-++ -++class TestClass { -++public: -++ TestClass(); -++ int field = 0; -++ friend class Foo; -++ static void * operator new(unsigned long size); -++private: -++ static void operator delete(void *p); -++}; -++ -++class Foo { -++public: -++ int test_method(); -++}; -++ -++int Foo::test_method() { -++ TestClass *obj = new TestClass() ; -++ return obj->field; +++declare void @use(ptr) ++diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll b/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll ++--- a/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll +++++ b/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll ++@@ -0,0 +1,91 @@ +++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-generic-linux-gnu < %s | FileCheck %s +++ +++define void @test(ptr %nExp, float %0, i1 %cmp, float %1) { +++; CHECK-LABEL: define void @test( +++; CHECK-SAME: ptr [[NEXP:%.*]], float [[TMP0:%.*]], i1 [[CMP:%.*]], float [[TMP1:%.*]]) { +++; CHECK-NEXT: [[ENTRY:.*]]: +++; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x float> , float [[TMP1]], i32 2 +++; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x float> [[TMP2]], float [[TMP0]], i32 3 +++; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]] +++; CHECK: [[IF_THEN]]: +++; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[NEXP]], align 4 +++; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x float> [[TMP3]], <4 x float> poison, <2 x i32> +++; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x float> [[TMP5]], float [[TMP4]], i32 0 +++; CHECK-NEXT: [[TMP7:%.*]] = fmul <2 x float> [[TMP6]], zeroinitializer +++; CHECK-NEXT: [[TMP8:%.*]] = fmul <2 x float> [[TMP5]], zeroinitializer +++; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x float> , float [[TMP1]], i32 3 +++; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <4 x i32> +++; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x float> [[TMP9]], <4 x float> [[TMP10]], <4 x i32> +++; CHECK-NEXT: br label %[[IF_END]] +++; CHECK: [[IF_END]]: +++; CHECK-NEXT: [[TMP12:%.*]] = phi <4 x float> [ [[TMP11]], %[[IF_THEN]] ], [ [[TMP3]], %[[ENTRY]] ] +++; CHECK-NEXT: [[TMP13:%.*]] = phi <2 x float> [ [[TMP8]], %[[IF_THEN]] ], [ zeroinitializer, %[[ENTRY]] ] +++; CHECK-NEXT: [[TMP14:%.*]] = phi <2 x float> [ zeroinitializer, %[[IF_THEN]] ], [ , %[[ENTRY]] ] +++; CHECK-NEXT: [[TMP15:%.*]] = phi <2 x float> [ [[TMP7]], %[[IF_THEN]] ], [ zeroinitializer, %[[ENTRY]] ] +++; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <2 x float> [[TMP14]], <2 x float> , <2 x i32> +++; CHECK-NEXT: [[TMP17:%.*]] = fmul <2 x float> [[TMP15]], [[TMP16]] +++; CHECK-NEXT: [[TMP18:%.*]] = fmul <2 x float> [[TMP13]], [[TMP14]] +++; CHECK-NEXT: [[TMP19:%.*]] = fmul <4 x float> [[TMP12]], zeroinitializer +++; CHECK-NEXT: [[CALL25:%.*]] = load volatile ptr, ptr null, align 8 +++; CHECK-NEXT: [[TMP20:%.*]] = fadd <2 x float> [[TMP18]], [[TMP17]] +++; CHECK-NEXT: [[TMP21:%.*]] = fmul <2 x float> [[TMP20]], zeroinitializer +++; CHECK-NEXT: [[TMP22:%.*]] = fadd <2 x float> [[TMP21]], zeroinitializer +++; CHECK-NEXT: [[TMP23:%.*]] = fmul <4 x float> [[TMP19]], zeroinitializer +++; CHECK-NEXT: [[TMP24:%.*]] = fadd <4 x float> [[TMP19]], zeroinitializer +++; CHECK-NEXT: [[TMP25:%.*]] = shufflevector <4 x float> [[TMP23]], <4 x float> [[TMP24]], <4 x i32> +++; CHECK-NEXT: [[TMP26:%.*]] = call <4 x float> 
@llvm.vector.insert.v4f32.v2f32(<4 x float> , <2 x float> [[TMP22]], i64 2) +++; CHECK-NEXT: [[TMP27:%.*]] = fadd <4 x float> [[TMP25]], [[TMP26]] +++; CHECK-NEXT: store <4 x float> [[TMP27]], ptr [[CALL25]], align 4 +++; CHECK-NEXT: ret void +++; +++entry: +++ br i1 %cmp, label %if.then, label %if.end +++ +++if.then: +++ %div.i41 = fmul float %0, 0.000000e+00 +++ %2 = load float, ptr %nExp, align 4 +++ %div.1.i.i = fmul float %2, 0.000000e+00 +++ %div.2.i.i = fmul float %0, 0.000000e+00 +++ br label %if.end +++ +++if.end: +++ %3 = phi float [ %1, %if.then ], [ %0, %entry ] +++ %4 = phi float [ 0.000000e+00, %if.then ], [ %1, %entry ] +++ %5 = phi float [ 0.000000e+00, %if.then ], [ 0x7FF8000000000000, %entry ] +++ %6 = phi float [ 0.000000e+00, %if.then ], [ 1.000000e+00, %entry ] +++ %fa.sroa.9.0 = phi float [ %div.2.i.i, %if.then ], [ 0.000000e+00, %entry ] +++ %fa.sroa.7.0 = phi float [ %div.1.i.i, %if.then ], [ 0.000000e+00, %entry ] +++ %fa.sroa.0.0 = phi float [ %div.i41, %if.then ], [ 0.000000e+00, %entry ] +++ %mul.1.i.i58 = fmul float %fa.sroa.7.0, %6 +++ %mul.2.i.i60 = fmul float %fa.sroa.9.0, %6 +++ %mul.1.i.i.i63 = fmul float %fa.sroa.0.0, %5 +++ %mul.2.i.i.i65 = fmul float %fa.sroa.0.0, 0.000000e+00 +++ %mul.i66 = fmul float %fa.sroa.0.0, 0.000000e+00 +++ %add.1.i.i = fadd float %mul.1.i.i58, %mul.1.i.i.i63 +++ %add.2.i.i = fadd float %mul.2.i.i60, %mul.2.i.i.i65 +++ %mul.1.i.i74 = fmul float %add.1.i.i, 0.000000e+00 +++ %mul.2.i.i76 = fmul float %add.2.i.i, 0.000000e+00 +++ %mul.i.i.i78 = fmul float %mul.i66, 0.000000e+00 +++ %add.1.i.i85 = fadd float %mul.1.i.i74, 0.000000e+00 +++ %add.2.i.i86 = fadd float %mul.2.i.i76, 0.000000e+00 +++ %mul.i.i.i97 = fmul float %5, 0.000000e+00 +++ %mul.1.i.i.i99 = fmul float %4, 0.000000e+00 +++ %mul.2.i.i.i101 = fmul float %3, 0.000000e+00 +++ %add.i.i103 = fadd float %mul.i.i.i97, 0.000000e+00 +++ %add.1.i.i104 = fadd float %mul.1.i.i.i99, 0.000000e+00 +++ %add.2.i.i105 = fadd float %mul.2.i.i.i101, 0.000000e+00 +++ %add = fadd float %mul.i.i.i78, 0.000000e+00 +++ %add.i = fadd float %add.i.i103, 1.000000e+00 +++ %add.1.i = fadd float %add.1.i.i104, %add.1.i.i85 +++ %add.2.i = fadd float %add.2.i.i105, %add.2.i.i86 +++ %call25 = load volatile ptr, ptr null, align 8 +++ store float %add, ptr %call25, align 4 +++ %__trans_tmp_29.sroa.5.0.call25.sroa_idx = getelementptr i8, ptr %call25, i64 4 +++ store float %add.i, ptr %__trans_tmp_29.sroa.5.0.call25.sroa_idx, align 4 +++ %__trans_tmp_29.sroa.6.0.call25.sroa_idx = getelementptr i8, ptr %call25, i64 8 +++ store float %add.1.i, ptr %__trans_tmp_29.sroa.6.0.call25.sroa_idx, align 4 +++ %__trans_tmp_29.sroa.7.0.call25.sroa_idx = getelementptr i8, ptr %call25, i64 12 +++ store float %add.2.i, ptr %__trans_tmp_29.sroa.7.0.call25.sroa_idx, align 4 +++ ret void +++} ++diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll ++--- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll +++++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll ++@@ -0,0 +1,121 @@ +++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu -mattr=+avx -slp-threshold=-1000 < %s | FileCheck %s +++ +++define i64 @Foo(ptr align 8 dereferenceable(344) %0, i64 %1) { +++; CHECK-LABEL: define i64 @Foo( +++; CHECK-SAME: ptr align 8 dereferenceable(344) [[TMP0:%.*]], i64 
[[TMP1:%.*]]) #[[ATTR0:[0-9]+]] { +++; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP0]], i64 104 +++; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 112 +++; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 24 +++; CHECK-NEXT: [[TMP6:%.*]] = load i64, ptr [[TMP3]], align 8 +++; CHECK-NEXT: [[TMP7:%.*]] = load i64, ptr [[TMP4]], align 8 +++; CHECK-NEXT: [[TMP8:%.*]] = load i64, ptr [[TMP5]], align 8 +++; CHECK-NEXT: [[TMP9:%.*]] = load i64, ptr [[TMP0]], align 8 +++; CHECK-NEXT: [[TMP10:%.*]] = insertelement <2 x i64> poison, i64 [[TMP6]], i32 0 +++; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x i64> [[TMP10]], i64 [[TMP9]], i32 1 +++; CHECK-NEXT: [[TMP12:%.*]] = insertelement <2 x i64> poison, i64 [[TMP7]], i32 0 +++; CHECK-NEXT: [[TMP13:%.*]] = insertelement <2 x i64> [[TMP12]], i64 [[TMP8]], i32 1 +++; CHECK-NEXT: [[TMP14:%.*]] = insertelement <2 x i64> poison, i64 0, i32 0 +++; CHECK-NEXT: [[TMP15:%.*]] = insertelement <2 x i64> , i64 [[TMP1]], i32 1 +++; CHECK-NEXT: br label %[[BB16:.*]] +++; CHECK: [[BB16]]: +++; CHECK-NEXT: [[TMP17:%.*]] = phi <2 x i64> [ [[TMP11]], [[TMP2:%.*]] ], [ zeroinitializer, %[[TMP25:.*]] ] +++; CHECK-NEXT: [[TMP18:%.*]] = phi <2 x i64> [ [[TMP13]], [[TMP2]] ], [ [[TMP29:%.*]], %[[TMP25]] ] +++; CHECK-NEXT: switch i32 0, label %[[BB19:.*]] [ +++; CHECK-NEXT: i32 0, label %[[TMP25]] +++; CHECK-NEXT: ] +++; CHECK: [[BB19]]: +++; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x i64> [[TMP18]], <2 x i64> poison, <4 x i32> +++; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x i64> [[TMP20]], i64 0, i32 1 +++; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x i64> [[TMP21]], i64 0, i32 2 +++; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <4 x i64> [[TMP22]], <4 x i64> poison, <4 x i32> +++; CHECK-NEXT: [[TMP24:%.*]] = shufflevector <2 x i64> [[TMP14]], <2 x i64> [[TMP18]], <2 x i32> +++; CHECK-NEXT: br label %[[TMP25]] +++; CHECK: [[TMP25]]: +++; CHECK-NEXT: [[TMP26:%.*]] = phi <2 x i64> [ [[TMP17]], %[[BB19]] ], [ zeroinitializer, %[[BB16]] ] +++; CHECK-NEXT: [[TMP27:%.*]] = phi <4 x i64> [ [[TMP23]], %[[BB19]] ], [ zeroinitializer, %[[BB16]] ] +++; CHECK-NEXT: [[TMP28:%.*]] = phi <2 x i64> [ [[TMP24]], %[[BB19]] ], [ [[TMP15]], %[[BB16]] ] +++; CHECK-NEXT: [[TMP29]] = shufflevector <2 x i64> [[TMP18]], <2 x i64> , <2 x i32> +++; CHECK-NEXT: br i1 false, label %[[DOTLOOPEXIT206:.*]], label %[[BB16]] +++; CHECK: [[_LOOPEXIT206:.*:]] +++; CHECK-NEXT: switch i32 0, label %[[BB32:.*]] [ +++; CHECK-NEXT: i32 0, [[DOTCONT174:label %.*]] +++; CHECK-NEXT: i32 1, label %[[BB30:.*]] +++; CHECK-NEXT: ] +++; CHECK: [[BB30]]: +++; CHECK-NEXT: [[TMP31:%.*]] = shufflevector <4 x i64> [[TMP27]], <4 x i64> , <4 x i32> +++; CHECK-NEXT: br [[DOTCONT174]] +++; CHECK: [[BB32]]: +++; CHECK-NEXT: [[TMP33:%.*]] = insertelement <4 x i64> [[TMP27]], i64 0, i32 1 +++; CHECK-NEXT: [[TMP34:%.*]] = insertelement <4 x i64> [[TMP33]], i64 0, i32 2 +++; CHECK-NEXT: [[TMP35:%.*]] = shufflevector <4 x i64> [[TMP34]], <4 x i64> poison, <4 x i32> +++; CHECK-NEXT: [[TMP36:%.*]] = insertelement <2 x i64> [[TMP28]], i64 0, i32 0 +++; CHECK-NEXT: br [[DOTCONT174]] +++; CHECK: [[_CONT174:.*:]] +++; CHECK-NEXT: [[TMP37:%.*]] = phi <2 x i64> [ [[TMP26]], %[[BB32]] ], [ zeroinitializer, %[[BB30]] ], [ [[TMP26]], %[[DOTLOOPEXIT206]] ] +++; CHECK-NEXT: [[TMP38:%.*]] = phi <4 x i64> [ [[TMP35]], %[[BB32]] ], [ [[TMP31]], %[[BB30]] ], [ [[TMP27]], %[[DOTLOOPEXIT206]] ] +++; CHECK-NEXT: [[TMP39:%.*]] = phi <2 x i64> [ [[TMP36]], %[[BB32]] ], [ zeroinitializer, %[[BB30]] ], [ 
[[TMP28]], %[[DOTLOOPEXIT206]] ] +++; CHECK-NEXT: ret i64 0 +++; +++ %3 = getelementptr i8, ptr %0, i64 104 +++ %4 = getelementptr i8, ptr %0, i64 112 +++ %5 = getelementptr i8, ptr %0, i64 24 +++ %6 = load i64, ptr %3, align 8 +++ %7 = load i64, ptr %4, align 8 +++ %8 = load i64, ptr %5, align 8 +++ %9 = load i64, ptr %0, align 8 +++ br label %10 +++ +++10: +++ %11 = phi i64 [ %9, %2 ], [ 0, %18 ] +++ %12 = phi i64 [ %8, %2 ], [ %12, %18 ] +++ %13 = phi i64 [ %7, %2 ], [ 0, %18 ] +++ %14 = phi i64 [ %6, %2 ], [ 0, %18 ] +++ switch i32 0, label %15 [ +++ i32 0, label %18 +++ ] +++ +++15: +++ %16 = tail call i64 @llvm.umin.i64(i64 0, i64 0) +++ %17 = tail call i64 @llvm.umax.i64(i64 0, i64 0) +++ br label %18 +++ +++18: +++ %19 = phi i64 [ %17, %15 ], [ 0, %10 ] +++ %20 = phi i64 [ %16, %15 ], [ 0, %10 ] +++ %21 = phi i64 [ %11, %15 ], [ 0, %10 ] +++ %22 = phi i64 [ %12, %15 ], [ 0, %10 ] +++ %23 = phi i64 [ %13, %15 ], [ %1, %10 ] +++ %24 = phi i64 [ %14, %15 ], [ 0, %10 ] +++ br i1 false, label %.loopexit206, label %10 +++ +++.loopexit206: +++ switch i32 0, label %26 [ +++ i32 0, label %.cont174 +++ i32 1, label %25 +++ ] +++ +++25: +++ br label %.cont174 +++ +++26: +++ %27 = tail call i64 @llvm.umin.i64(i64 0, i64 0) +++ %28 = tail call i64 @llvm.umax.i64(i64 0, i64 0) +++ br label %.cont174 +++ +++.cont174: +++ %.sroa.139.1 = phi i64 [ %28, %26 ], [ %19, %25 ], [ %19, %.loopexit206 ] +++ %.sroa.133.1 = phi i64 [ %27, %26 ], [ 0, %25 ], [ %20, %.loopexit206 ] +++ %.sroa.81.1 = phi i64 [ %23, %26 ], [ 0, %25 ], [ %23, %.loopexit206 ] +++ %.sroa.75.1 = phi i64 [ %24, %26 ], [ 0, %25 ], [ %24, %.loopexit206 ] +++ %.sroa.21.1 = phi i64 [ %21, %26 ], [ 0, %25 ], [ %21, %.loopexit206 ] +++ %.sroa.15.1 = phi i64 [ %22, %26 ], [ 0, %25 ], [ %22, %.loopexit206 ] +++ %29 = phi i64 [ %28, %26 ], [ 0, %25 ], [ %19, %.loopexit206 ] +++ %30 = phi i64 [ %27, %26 ], [ 0, %25 ], [ %20, %.loopexit206 ] +++ ret i64 0 ++} +++ +++declare i64 @llvm.umax.i64(i64, i64) +++ +++declare i64 @llvm.umin.i64(i64, i64) +++ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 0b67d8b..3ec4c3e 100644 +index 3ec4c3e..d44a9f6 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" -- LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" -+ LLVM_COMMIT = "179d30f8c3fddd3c85056fd2b8e877a4a8513158" -+ LLVM_SHA256 = "39f33d0ba77ca40d254c767519a0f3f5692c2caa271f413e7245ab63d0787bd5" +- LLVM_COMMIT = "179d30f8c3fddd3c85056fd2b8e877a4a8513158" +- LLVM_SHA256 = "39f33d0ba77ca40d254c767519a0f3f5692c2caa271f413e7245ab63d0787bd5" ++ LLVM_COMMIT = "ffd5b148941a1146378a247c70c4faface3a1f96" ++ LLVM_SHA256 = "fc57e9b703ddfb6d888e1c5beb2a65ca8d84d439bcf88c63eb014ccb8bbea414" tf_http_archive( name = name, -diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch -index ff118b8..839d98a 100755 ---- a/third_party/stablehlo/temporary.patch -+++ b/third_party/stablehlo/temporary.patch -@@ -607,6 +607,30 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "stablehlo/dialect/VhloOps.td", - deps = [ -+diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir -+--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir -++++ 
stablehlo/stablehlo/conversions/tosa/tests/binary.mlir -+@@ -45,7 +45,7 @@ -+ -+ // CHECK-LABEL: @divide -+ func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> { -+- // CHECK: tosa.int_div -++ // CHECK: tosa.intdiv -+ %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> -+ return %0 : tensor<10xi32> -+ } -+diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll -+--- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll -++++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll -+@@ -156,7 +156,7 @@ -+ Pattern => -+ replace op(input0 : Value<_: Tosa_Int32Tensor>, -+ input1 : Value<_: Tosa_Int32Tensor>) -+- with op(input0, input1); -++ with op(input0, input1); -+ Pattern => -+ replace op(input0 : Value<_: Tosa_Tensor>, -+ input1 : Value<_: Tosa_Tensor>) - diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel - --- stablehlo/stablehlo/tests/BUILD.bazel - +++ stablehlo/stablehlo/tests/BUILD.bazel diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index e2bc747f2d6d05..70631415b77b9c 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "9585bea76e06fff5574ddf20bd88cbdfd0b98985" - SHARDY_SHA256 = "f93f37639f8ec1cd5ae4c26e0b3291bfa85923567c01e9404c6e070d061c598a" + SHARDY_COMMIT = "0f48503d743de99500cd6f0120988bf06abba91c" + SHARDY_SHA256 = "c7540bfd5a12eedb3a3b9c049943fc44135af653a90e13b969850ba87c4b464b" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 4779b912722730..b4bbb1e1fef6ea 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,230 +1,504 @@ +diff --git a/docs/sdy_dialect.md b/docs/sdy_dialect.md +index 7b9e18c..ef83d35 100755 +--- a/docs/sdy_dialect.md ++++ b/docs/sdy_dialect.md +@@ -467,7 +467,6 @@ the body on any free axes - those not in the manual_axes list. + - Elements in `in_shardings` and `out_shardings` must satisfy the constraints listed in `TensorShardingAttr`. + - The number of global and local tensor inputs/outputs of the op region must match. + - The manual axes must come before any free axes in each dim sharding. +-- The manual axes cannot introduce padding. Namely, the dimension size must be divisible by the corresponding manual axes size. + - The global and local shapes of the op regions arguments/results must match. + - No manual axes are split. + +diff --git a/shardy/dialect/sdy/ir/ops.td b/shardy/dialect/sdy/ir/ops.td +index d598517..07bfa11 100644 +--- a/shardy/dialect/sdy/ir/ops.td ++++ b/shardy/dialect/sdy/ir/ops.td +@@ -145,7 +145,6 @@ def Sdy_ManualComputationOp : Sdy_Op<"manual_computation", + - Elements in `in_shardings` and `out_shardings` must satisfy the constraints listed in `TensorShardingAttr`. + - The number of global and local tensor inputs/outputs of the op region must match. + - The manual axes must come before any free axes in each dim sharding. +- - The manual axes cannot introduce padding. Namely, the dimension size must be divisible by the corresponding manual axes size. 
+ - The global and local shapes of the op regions arguments/results must match. + - No manual axes are split. + }]; +diff --git a/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir b/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir +index c17ea23..139e1f2 100644 +--- a/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir ++++ b/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir +@@ -165,19 +165,6 @@ func.func @man_comp_result_rank_mistmatch(%arg0: tensor<16x32xf32>) -> tensor<16 + + // ----- + +-sdy.mesh @mesh = <["a"=4]> +- +-func.func @dimension_size_not_divisible_by_manual_axes_size(%arg0: tensor<6xf32>) -> tensor<6xf32> { +- // expected-error @+1 {{dimension size 6 is not divisible by the manual axes size 4}} +- %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"a"}]>] out_shardings=[<@mesh, [{"a"}]>] manual_axes={"a"} (%arg1: tensor<1xf32>) { +- %1 = stablehlo.add %arg1, %arg1 : tensor<1xf32> +- sdy.return %1 : tensor<1xf32> +- } : (tensor<6xf32>) -> tensor<6xf32> +- func.return %0: tensor<6xf32> +-} +- +-// ----- +- + sdy.mesh @mesh = <["a"=2]> + + func.func @man_comp_operand_shape_mismatch_replicated(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { +diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc +index e3edcfe..87fc0ee 100644 +--- a/shardy/dialect/sdy/ir/verifiers.cc ++++ b/shardy/dialect/sdy/ir/verifiers.cc +@@ -801,11 +801,9 @@ ArrayRef::iterator findManualAxisAfterFreeAxis( + // 3. the number of global and local tensor inputs/outputs of the op region + // match, + // 4. the manual axes come before any free axes in each dim sharding, +-// 5. The manual axes cannot introduce padding. The dimension size must be +-// divisible by the corresponding manual axes size. +-// 6. the global shape and local shapes of the op regions arguments/results ++// 5. the global shape and local shapes of the op regions arguments/results + // match, and +-// 7. No manual axes are split. ++// 6. No manual axes are split. + // + // `valueKindStr` is a string included in any verification error message + // specifying whether the values we are verifying are the operands or results. +@@ -864,6 +862,8 @@ LogicalResult verifyManualComputationValue( + } + } + ++ // 5. Verify the global shape and local shapes of the op regions ++ // arguments/results match. + SmallVector newDimSizes; + auto globalRankedType = mlir::cast(globalType); + for (auto [dimensionSize, dimSharding] : llvm::zip_equal( +@@ -871,24 +871,13 @@ LogicalResult verifyManualComputationValue( + if (dimensionSize == ShapedType::kDynamic) { + newDimSizes.push_back(ShapedType::kDynamic); + } else { +- // 5. The manual axes cannot introduce padding. The dimension size must +- // be divisible by the corresponding manual axes size. +- + // Safe to call `getMesh` because the sharding was already verified. +- int64_t manualAxesSize = ++ newDimSizes.push_back( ++ dimensionSize / + accumulatedManualAxesSize(op, dimSharding.getAxes(), manualAxesSet, +- sharding.getMesh(symbolTable)); +- if (dimensionSize % manualAxesSize != 0) { +- return op->emitOpError(valueKindStr) +- << " dimension size " << dimensionSize +- << " is not divisible by the manual axes size " +- << manualAxesSize; +- } +- newDimSizes.push_back(dimensionSize / manualAxesSize); ++ sharding.getMesh(symbolTable))); + } + } +- // 6. Verify the global shape and local shapes of the op regions +- // arguments/results match. 
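+// To make the local-shape rule concrete (a worked example; the values are
+// illustrative, not taken from the patch): each global dimension is divided
+// by the accumulated size of its manual axes, so a global tensor<16x32xf32>
+// whose dim 0 carries a manual axis of size 4 gets local type tensor<4x32xf32>:
+//
+//   int64_t localDimSize = /*dimensionSize=*/16 / /*manualAxesSize=*/4;  // == 4
+//
+// With the divisibility check removed by this hunk, a non-divisible case such
+// as tensor<6xf32> over an axis of size 4 now truncates (6 / 4 == 1) rather
+// than being rejected by the verifier, which is why the tensor<6xf32> test
+// case is deleted below.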
+ auto expectedLocalRankedType = + RankedTensorType::get(newDimSizes, globalRankedType.getElementType()); + auto localRankedType = mlir::cast(localType); +@@ -900,7 +889,7 @@ LogicalResult verifyManualComputationValue( + << ", actual local shape " << localRankedType; + } + +- // 7. No manual axes are split. ++ // 6. No manual axes are split. + if (sharding.anyOfAxisRef([&](AxisRefAttr axis) { + return axis.getSubAxisInfo() && + manualAxesSet.contains(axis.getName()); +diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir +index 96c631a..a4ca997 100644 +--- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir ++++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir +@@ -1716,7 +1716,7 @@ func.func @manual_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.shardi + } + + // CHECK-LABEL: func @manual_computation_with_manual_axes +-func.func @manual_computation_with_manual_axes(%arg0: tensor<208xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"x","y"}]>}) -> (tensor<208xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"x","z"}]>}) { ++func.func @manual_computation_with_manual_axes(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"x","y"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"x","z"}]>}) { + %0 = sdy.manual_computation(%arg0) + in_shardings=[<@mesh_xyzt, [{"x","y"}]>] out_shardings=[<@mesh_xyzt, [{"x", "z"}]>] manual_axes={"x"} (%arg1: tensor<52xf32>) { + // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_xyzt, [{"t"}]> : tensor<52xf32> +@@ -1725,9 +1725,9 @@ func.func @manual_computation_with_manual_axes(%arg0: tensor<208xf32> {sdy.shard + // CHECK-NEXT: sdy.return %[[RESHARD2]] : tensor<52xf32> + %2 = stablehlo.abs %arg1 {sdy.sharding=#sdy.sharding_per_value<[<@mesh_xyzt, [{"t"}]>]>} : tensor<52xf32> + sdy.return %2 : tensor<52xf32> +- } : (tensor<208xf32>) -> (tensor<208xf32>) +- %1 = stablehlo.negate %0 {sdy.sharding= #sdy.sharding_per_value<[<@mesh_xyzt, [{"x","z"}]>]>} : tensor<208xf32> +- return %1 : tensor<208xf32> ++ } : (tensor<210xf32>) -> (tensor<210xf32>) ++ %1 = stablehlo.negate %0 {sdy.sharding= #sdy.sharding_per_value<[<@mesh_xyzt, [{"x","z"}]>]>} : tensor<210xf32> ++ return %1 : tensor<210xf32> + } + + // CHECK-LABEL: func @optimization_barrier diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index bbffc2f..436c4e9 100644 +index 436c4e9..2337741 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,49 +1,132 @@ - Auto generated patch. Do not edit or delete it, even if empty. 
--diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h ----- a/libc/src/__support/FPUtil/aarch64/sqrt.h --+++ b/libc/src/__support/FPUtil/aarch64/sqrt.h --@@ -18,6 +18,8 @@ -- #error "Invalid include" -- #endif -+diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp -+--- a/clang/lib/Sema/SemaExprCXX.cpp -++++ b/clang/lib/Sema/SemaExprCXX.cpp -+@@ -1929,8 +1929,9 @@ -+ } -+ return true; -+ } -+- -+- return S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); -++ Sema::AccessResult Accessible = -++ S.CheckAllocationAccess(StartLoc, Range, NamingClass, Decl, Diagnose); -++ return Accessible == Sema::AR_inaccessible; -+ } +@@ -14,51 +14,6 @@ diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/ + } --+#include "src/__support/FPUtil/generic/sqrt.h" --+ -- namespace LIBC_NAMESPACE_DECL { -- namespace fputil { + /// Select the correct "usual" deallocation function to use from a selection of +-diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp +---- a/clang/lib/Serialization/ASTReaderStmt.cpp +-+++ b/clang/lib/Serialization/ASTReaderStmt.cpp +-@@ -2226,10 +2226,7 @@ +- E->AssociatedDeclAndRef.setPointer(readDeclAs()); +- E->AssociatedDeclAndRef.setInt(CurrentUnpackingBits->getNextBit()); +- E->Index = CurrentUnpackingBits->getNextBits(/*Width=*/12); +-- if (CurrentUnpackingBits->getNextBit()) +-- E->PackIndex = Record.readInt(); +-- else +-- E->PackIndex = 0; +-+ E->PackIndex = Record.readUnsignedOrNone().toInternalRepresentation(); +- E->Final = CurrentUnpackingBits->getNextBit(); +- E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); +- E->Replacement = Record.readSubExpr(); +-@@ -2239,6 +2236,7 @@ +- SubstNonTypeTemplateParmPackExpr *E) { +- VisitExpr(E); +- E->AssociatedDecl = readDeclAs(); +-+ E->Final = CurrentUnpackingBits->getNextBit(); +- E->Index = Record.readInt(); +- TemplateArgument ArgPack = Record.readTemplateArgument(); +- if (ArgPack.getKind() != TemplateArgument::Pack) +-diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp +---- a/clang/lib/Serialization/ASTWriterStmt.cpp +-+++ b/clang/lib/Serialization/ASTWriterStmt.cpp +-@@ -2228,9 +2228,7 @@ +- Record.AddDeclRef(E->getAssociatedDecl()); +- CurrentPackingBits.addBit(E->isReferenceParameter()); +- CurrentPackingBits.addBits(E->getIndex(), /*Width=*/12); +-- CurrentPackingBits.addBit((bool)E->getPackIndex()); +-- if (auto PackIndex = E->getPackIndex()) +-- Record.push_back(*PackIndex + 1); +-+ Record.writeUnsignedOrNone(E->getPackIndex()); +- CurrentPackingBits.addBit(E->getFinal()); - --diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/arm/sqrt.h b/libc/src/__support/FPUtil/arm/sqrt.h ----- a/libc/src/__support/FPUtil/arm/sqrt.h --+++ b/libc/src/__support/FPUtil/arm/sqrt.h --@@ -18,6 +18,8 @@ -- #error "Invalid include" -- #endif -+ /// Select the correct "usual" deallocation function to use from a selection of -+diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp -+--- a/clang/lib/Serialization/ASTReaderStmt.cpp -++++ b/clang/lib/Serialization/ASTReaderStmt.cpp -+@@ -2226,10 +2226,7 @@ -+ E->AssociatedDeclAndRef.setPointer(readDeclAs()); -+ E->AssociatedDeclAndRef.setInt(CurrentUnpackingBits->getNextBit()); -+ E->Index = 
CurrentUnpackingBits->getNextBits(/*Width=*/12); -+- if (CurrentUnpackingBits->getNextBit()) -+- E->PackIndex = Record.readInt(); -+- else -+- E->PackIndex = 0; -++ E->PackIndex = Record.readUnsignedOrNone().toInternalRepresentation(); -+ E->Final = CurrentUnpackingBits->getNextBit(); -+ E->SubstNonTypeTemplateParmExprBits.NameLoc = readSourceLocation(); -+ E->Replacement = Record.readSubExpr(); -+@@ -2239,6 +2236,7 @@ -+ SubstNonTypeTemplateParmPackExpr *E) { -+ VisitExpr(E); -+ E->AssociatedDecl = readDeclAs(); -++ E->Final = CurrentUnpackingBits->getNextBit(); -+ E->Index = Record.readInt(); -+ TemplateArgument ArgPack = Record.readTemplateArgument(); -+ if (ArgPack.getKind() != TemplateArgument::Pack) -+diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp -+--- a/clang/lib/Serialization/ASTWriterStmt.cpp -++++ b/clang/lib/Serialization/ASTWriterStmt.cpp -+@@ -2228,9 +2228,7 @@ -+ Record.AddDeclRef(E->getAssociatedDecl()); -+ CurrentPackingBits.addBit(E->isReferenceParameter()); -+ CurrentPackingBits.addBits(E->getIndex(), /*Width=*/12); -+- CurrentPackingBits.addBit((bool)E->getPackIndex()); -+- if (auto PackIndex = E->getPackIndex()) -+- Record.push_back(*PackIndex + 1); -++ Record.writeUnsignedOrNone(E->getPackIndex()); -+ CurrentPackingBits.addBit(E->getFinal()); - --+#include "src/__support/FPUtil/generic/sqrt.h" -+ Record.AddSourceLocation(E->getNameLoc()); -+@@ -2242,6 +2240,7 @@ -+ SubstNonTypeTemplateParmPackExpr *E) { -+ VisitExpr(E); -+ Record.AddDeclRef(E->getAssociatedDecl()); -++ CurrentPackingBits.addBit(E->getFinal()); -+ Record.push_back(E->getIndex()); -+ Record.AddTemplateArgument(E->getArgumentPack()); -+ Record.AddSourceLocation(E->getParameterPackLocation()); -+diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bug135668.cpp b/clang/test/CodeGenCXX/bug135668.cpp -+--- a/clang/test/CodeGenCXX/bug135668.cpp -++++ b/clang/test/CodeGenCXX/bug135668.cpp -+@@ -0,0 +1,38 @@ -++// RUN: %clang_cc1 %s -triple arm64-apple-macosx -emit-llvm -fcxx-exceptions -fexceptions -std=c++23 -o - | FileCheck %s - + -- namespace LIBC_NAMESPACE_DECL { -- namespace fputil { -- --diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/riscv/sqrt.h b/libc/src/__support/FPUtil/riscv/sqrt.h ----- a/libc/src/__support/FPUtil/riscv/sqrt.h --+++ b/libc/src/__support/FPUtil/riscv/sqrt.h --@@ -18,6 +18,8 @@ -- #error "Invalid include" -- #endif -- --+#include "src/__support/FPUtil/generic/sqrt.h" -++class TestClass { -++ public: -++ TestClass(); -++ int field = 0; -++ friend class Foo; -++ static void * operator new(unsigned long size); -++ private: -++ static void operator delete(void *p); -++ }; - + -- namespace LIBC_NAMESPACE_DECL { -- namespace fputil { -- --diff -ruN --strip-trailing-cr a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h ----- a/libc/src/__support/FPUtil/x86_64/sqrt.h --+++ b/libc/src/__support/FPUtil/x86_64/sqrt.h --@@ -18,6 +18,8 @@ -- #error "sqrtss / sqrtsd need SSE2" -- #endif -- --+#include "src/__support/FPUtil/generic/sqrt.h" -++class Foo { -++public: -++ int test_method(); -++}; - + -- namespace LIBC_NAMESPACE_DECL { -- namespace fputil { -- -++int Foo::test_method() { -++ TestClass *obj = new TestClass() ; -++ return obj->field; +- Record.AddSourceLocation(E->getNameLoc()); +-@@ -2242,6 +2240,7 @@ +- SubstNonTypeTemplateParmPackExpr *E) { +- VisitExpr(E); +- Record.AddDeclRef(E->getAssociatedDecl()); +-+ CurrentPackingBits.addBit(E->getFinal()); +- 
Record.push_back(E->getIndex()); +- Record.AddTemplateArgument(E->getArgumentPack()); +- Record.AddSourceLocation(E->getParameterPackLocation()); + diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bug135668.cpp b/clang/test/CodeGenCXX/bug135668.cpp + --- a/clang/test/CodeGenCXX/bug135668.cpp + +++ b/clang/test/CodeGenCXX/bug135668.cpp +@@ -130,3 +85,294 @@ diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/bug135668.cpp b/clang/test/Se + + TestClass *obj = new TestClass() ; + + return obj->field; + +} ++diff -ruN --strip-trailing-cr a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp ++--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp ++@@ -25183,7 +25183,7 @@ ++ return SDValue(); ++ ++ auto *Ld = dyn_cast(Extract->getOperand(0)); ++- if (!Ld || Ld->getExtensionType() || !Ld->isSimple()) +++ if (!Ld || !ISD::isNormalLoad(Ld) || !Ld->isSimple()) ++ return SDValue(); ++ ++ // Allow targets to opt-out. ++diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp ++@@ -7241,6 +7241,8 @@ ++ return Res.takeVector(); ++ }; ++ auto GetNumOperands = [](const TreeEntry *TE) { +++ if (TE->State == TreeEntry::SplitVectorize) +++ return TE->getNumOperands(); ++ if (auto *CI = dyn_cast(TE->getMainOp()); CI) ++ return CI->arg_size(); ++ return TE->getNumOperands(); ++@@ -18064,8 +18066,14 @@ ++ // need to rebuild it. ++ EntryToLastInstruction.clear(); ++ // All blocks must be scheduled before any instructions are inserted. ++- for (auto &BSIter : BlocksSchedules) { +++ for (auto &BSIter : BlocksSchedules) ++ scheduleBlock(BSIter.second.get()); +++ // Cache last instructions for the nodes to avoid side effects, which may +++ // appear during vectorization, like extra uses, etc. +++ for (const std::unique_ptr &TE : VectorizableTree) { +++ if (TE->isGather()) +++ continue; +++ (void)getLastInstructionInBundle(TE.get()); ++ } ++ ++ if (ReductionRoot) ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AArch64/pr135821.ll b/llvm/test/CodeGen/AArch64/pr135821.ll ++--- a/llvm/test/CodeGen/AArch64/pr135821.ll +++++ b/llvm/test/CodeGen/AArch64/pr135821.ll ++@@ -0,0 +1,27 @@ +++; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +++; RUN: llc < %s -mtriple=aarch64-unknown-linux-gnu | FileCheck %s +++ +++define <4 x float> @f(ptr %0) { +++; CHECK-LABEL: f: +++; CHECK: // %bb.0: +++; CHECK-NEXT: sub sp, sp, #32 +++; CHECK-NEXT: str x30, [sp, #16] // 8-byte Folded Spill +++; CHECK-NEXT: .cfi_def_cfa_offset 32 +++; CHECK-NEXT: .cfi_offset w30, -16 +++; CHECK-NEXT: ldr q1, [x0, #56]! 
+++; CHECK-NEXT: ldr d0, [x0, #16] +++; CHECK-NEXT: mov v1.d[1], v0.d[0] +++; CHECK-NEXT: str q1, [sp] // 16-byte Folded Spill +++; CHECK-NEXT: bl use +++; CHECK-NEXT: ldr q0, [sp] // 16-byte Folded Reload +++; CHECK-NEXT: ldr x30, [sp, #16] // 8-byte Folded Reload +++; CHECK-NEXT: add sp, sp, #32 +++; CHECK-NEXT: ret +++ %2 = getelementptr inbounds nuw i8, ptr %0, i64 56 +++ %3 = load <6 x float>, ptr %2, align 4 +++ %4 = shufflevector <6 x float> %3, <6 x float> poison, <4 x i32> +++ tail call void @use(ptr %2) +++ ret <4 x float> %4 ++} ++ -++// CHECK-LABEL: define noundef i32 @_ZN3Foo11test_methodEv -++// CHECK: [[THIS_ADDR:%.*]] = alloca ptr, align 8 -++// CHECK: [[OBJ:%.*]] = alloca ptr, align 8 -++// CHECK: store ptr %this, ptr [[THIS_ADDR]], align 8 -++// CHECK: [[THIS1:%.*]] = load ptr, ptr [[THIS_ADDR]], align 8 -++// CHECK: [[ALLOCATION:%.*]] = call noundef ptr @_ZN9TestClassnwEm(i64 noundef 4) -++// CHECK: [[INITIALIZEDOBJ:%.*]] = invoke noundef ptr @_ZN9TestClassC1Ev(ptr noundef nonnull align 4 dereferenceable(4) [[ALLOCATION]]) -++// CHECK-NEXT: to label %[[INVOKE_CONT:.*]] unwind label %[[LPAD:.*]] -++// CHECK: [[INVOKE_CONT]]: -++// CHECK: store ptr [[ALLOCATION]], ptr [[OBJ]], align 8 -++// CHECK: [[OBJPTR:%.*]] = load ptr, ptr [[OBJ]], align 8 -++// CHECK: [[FIELDPTR:%.*]] = getelementptr inbounds nuw %class.TestClass, ptr [[OBJPTR]], i32 0, i32 0 -++// CHECK: [[FIELD:%.*]] = load i32, ptr [[FIELDPTR]], align 4 -++// CHECK: ret i32 [[FIELD]] -++// CHECK: [[LPAD]]: -++// CHECK: call void @_ZN9TestClassdlEPv(ptr noundef [[ALLOCATION]]) #3 -+diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/bug135668.cpp b/clang/test/SemaCXX/bug135668.cpp -+--- a/clang/test/SemaCXX/bug135668.cpp -++++ b/clang/test/SemaCXX/bug135668.cpp -+@@ -0,0 +1,25 @@ -++// RUN: %clang_cc1 -triple arm64-apple-macosx -Wall -fsyntax-only -verify %s -std=c++26 -fexceptions -fcxx-exceptions -++// expected-no-diagnostics -++ -++// This test makes sure that we don't erroneously consider an accessible operator -++// delete to be inaccessible, and then discard the entire new expression. 
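-++// The SemaExprCXX change quoted earlier is the fix this test exercises: a
-++// Sema::AccessResult was being returned through a bool, so any nonzero
-++// result, not just a real access failure, read as "inaccessible". A minimal
-++// sketch of the bug class, with illustrative enumerator values that are not
-++// taken from the patch:
-++//
-++//   enum AccessResult { AR_accessible = 0, AR_inaccessible, AR_dependent };
-++//   bool isInaccessible(AccessResult r) {
-++//     // return r;                   // wrong: AR_dependent also reads as true
-++//     return r == AR_inaccessible;  // the patched comparison
-++//   }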
-++ -++class TestClass { -++public: -++ TestClass(); -++ int field = 0; -++ friend class Foo; -++ static void * operator new(unsigned long size); -++private: -++ static void operator delete(void *p); -++}; -++ -++class Foo { -++public: -++ int test_method(); -++}; -++ -++int Foo::test_method() { -++ TestClass *obj = new TestClass() ; -++ return obj->field; +++declare void @use(ptr) ++diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll b/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll ++--- a/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll +++++ b/llvm/test/Transforms/SLPVectorizer/X86/entry-no-bundle-but-extra-use-on-vec.ll ++@@ -0,0 +1,91 @@ +++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-generic-linux-gnu < %s | FileCheck %s +++ +++define void @test(ptr %nExp, float %0, i1 %cmp, float %1) { +++; CHECK-LABEL: define void @test( +++; CHECK-SAME: ptr [[NEXP:%.*]], float [[TMP0:%.*]], i1 [[CMP:%.*]], float [[TMP1:%.*]]) { +++; CHECK-NEXT: [[ENTRY:.*]]: +++; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x float> , float [[TMP1]], i32 2 +++; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x float> [[TMP2]], float [[TMP0]], i32 3 +++; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]] +++; CHECK: [[IF_THEN]]: +++; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[NEXP]], align 4 +++; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x float> [[TMP3]], <4 x float> poison, <2 x i32> +++; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x float> [[TMP5]], float [[TMP4]], i32 0 +++; CHECK-NEXT: [[TMP7:%.*]] = fmul <2 x float> [[TMP6]], zeroinitializer +++; CHECK-NEXT: [[TMP8:%.*]] = fmul <2 x float> [[TMP5]], zeroinitializer +++; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x float> , float [[TMP1]], i32 3 +++; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <4 x i32> +++; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x float> [[TMP9]], <4 x float> [[TMP10]], <4 x i32> +++; CHECK-NEXT: br label %[[IF_END]] +++; CHECK: [[IF_END]]: +++; CHECK-NEXT: [[TMP12:%.*]] = phi <4 x float> [ [[TMP11]], %[[IF_THEN]] ], [ [[TMP3]], %[[ENTRY]] ] +++; CHECK-NEXT: [[TMP13:%.*]] = phi <2 x float> [ [[TMP8]], %[[IF_THEN]] ], [ zeroinitializer, %[[ENTRY]] ] +++; CHECK-NEXT: [[TMP14:%.*]] = phi <2 x float> [ zeroinitializer, %[[IF_THEN]] ], [ , %[[ENTRY]] ] +++; CHECK-NEXT: [[TMP15:%.*]] = phi <2 x float> [ [[TMP7]], %[[IF_THEN]] ], [ zeroinitializer, %[[ENTRY]] ] +++; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <2 x float> [[TMP14]], <2 x float> , <2 x i32> +++; CHECK-NEXT: [[TMP17:%.*]] = fmul <2 x float> [[TMP15]], [[TMP16]] +++; CHECK-NEXT: [[TMP18:%.*]] = fmul <2 x float> [[TMP13]], [[TMP14]] +++; CHECK-NEXT: [[TMP19:%.*]] = fmul <4 x float> [[TMP12]], zeroinitializer +++; CHECK-NEXT: [[CALL25:%.*]] = load volatile ptr, ptr null, align 8 +++; CHECK-NEXT: [[TMP20:%.*]] = fadd <2 x float> [[TMP18]], [[TMP17]] +++; CHECK-NEXT: [[TMP21:%.*]] = fmul <2 x float> [[TMP20]], zeroinitializer +++; CHECK-NEXT: [[TMP22:%.*]] = fadd <2 x float> [[TMP21]], zeroinitializer +++; CHECK-NEXT: [[TMP23:%.*]] = fmul <4 x float> [[TMP19]], zeroinitializer +++; CHECK-NEXT: [[TMP24:%.*]] = fadd <4 x float> [[TMP19]], zeroinitializer +++; CHECK-NEXT: [[TMP25:%.*]] = shufflevector <4 x float> [[TMP23]], <4 x float> [[TMP24]], <4 x i32> +++; CHECK-NEXT: [[TMP26:%.*]] = call <4 x float> 
@llvm.vector.insert.v4f32.v2f32(<4 x float> , <2 x float> [[TMP22]], i64 2) +++; CHECK-NEXT: [[TMP27:%.*]] = fadd <4 x float> [[TMP25]], [[TMP26]] +++; CHECK-NEXT: store <4 x float> [[TMP27]], ptr [[CALL25]], align 4 +++; CHECK-NEXT: ret void +++; +++entry: +++ br i1 %cmp, label %if.then, label %if.end +++ +++if.then: +++ %div.i41 = fmul float %0, 0.000000e+00 +++ %2 = load float, ptr %nExp, align 4 +++ %div.1.i.i = fmul float %2, 0.000000e+00 +++ %div.2.i.i = fmul float %0, 0.000000e+00 +++ br label %if.end +++ +++if.end: +++ %3 = phi float [ %1, %if.then ], [ %0, %entry ] +++ %4 = phi float [ 0.000000e+00, %if.then ], [ %1, %entry ] +++ %5 = phi float [ 0.000000e+00, %if.then ], [ 0x7FF8000000000000, %entry ] +++ %6 = phi float [ 0.000000e+00, %if.then ], [ 1.000000e+00, %entry ] +++ %fa.sroa.9.0 = phi float [ %div.2.i.i, %if.then ], [ 0.000000e+00, %entry ] +++ %fa.sroa.7.0 = phi float [ %div.1.i.i, %if.then ], [ 0.000000e+00, %entry ] +++ %fa.sroa.0.0 = phi float [ %div.i41, %if.then ], [ 0.000000e+00, %entry ] +++ %mul.1.i.i58 = fmul float %fa.sroa.7.0, %6 +++ %mul.2.i.i60 = fmul float %fa.sroa.9.0, %6 +++ %mul.1.i.i.i63 = fmul float %fa.sroa.0.0, %5 +++ %mul.2.i.i.i65 = fmul float %fa.sroa.0.0, 0.000000e+00 +++ %mul.i66 = fmul float %fa.sroa.0.0, 0.000000e+00 +++ %add.1.i.i = fadd float %mul.1.i.i58, %mul.1.i.i.i63 +++ %add.2.i.i = fadd float %mul.2.i.i60, %mul.2.i.i.i65 +++ %mul.1.i.i74 = fmul float %add.1.i.i, 0.000000e+00 +++ %mul.2.i.i76 = fmul float %add.2.i.i, 0.000000e+00 +++ %mul.i.i.i78 = fmul float %mul.i66, 0.000000e+00 +++ %add.1.i.i85 = fadd float %mul.1.i.i74, 0.000000e+00 +++ %add.2.i.i86 = fadd float %mul.2.i.i76, 0.000000e+00 +++ %mul.i.i.i97 = fmul float %5, 0.000000e+00 +++ %mul.1.i.i.i99 = fmul float %4, 0.000000e+00 +++ %mul.2.i.i.i101 = fmul float %3, 0.000000e+00 +++ %add.i.i103 = fadd float %mul.i.i.i97, 0.000000e+00 +++ %add.1.i.i104 = fadd float %mul.1.i.i.i99, 0.000000e+00 +++ %add.2.i.i105 = fadd float %mul.2.i.i.i101, 0.000000e+00 +++ %add = fadd float %mul.i.i.i78, 0.000000e+00 +++ %add.i = fadd float %add.i.i103, 1.000000e+00 +++ %add.1.i = fadd float %add.1.i.i104, %add.1.i.i85 +++ %add.2.i = fadd float %add.2.i.i105, %add.2.i.i86 +++ %call25 = load volatile ptr, ptr null, align 8 +++ store float %add, ptr %call25, align 4 +++ %__trans_tmp_29.sroa.5.0.call25.sroa_idx = getelementptr i8, ptr %call25, i64 4 +++ store float %add.i, ptr %__trans_tmp_29.sroa.5.0.call25.sroa_idx, align 4 +++ %__trans_tmp_29.sroa.6.0.call25.sroa_idx = getelementptr i8, ptr %call25, i64 8 +++ store float %add.1.i, ptr %__trans_tmp_29.sroa.6.0.call25.sroa_idx, align 4 +++ %__trans_tmp_29.sroa.7.0.call25.sroa_idx = getelementptr i8, ptr %call25, i64 12 +++ store float %add.2.i, ptr %__trans_tmp_29.sroa.7.0.call25.sroa_idx, align 4 +++ ret void +++} ++diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll ++--- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll +++++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-num-operands.ll ++@@ -0,0 +1,121 @@ +++; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +++; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu -mattr=+avx -slp-threshold=-1000 < %s | FileCheck %s +++ +++define i64 @Foo(ptr align 8 dereferenceable(344) %0, i64 %1) { +++; CHECK-LABEL: define i64 @Foo( +++; CHECK-SAME: ptr align 8 dereferenceable(344) [[TMP0:%.*]], i64 
[[TMP1:%.*]]) #[[ATTR0:[0-9]+]] { +++; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP0]], i64 104 +++; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP0]], i64 112 +++; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP0]], i64 24 +++; CHECK-NEXT: [[TMP6:%.*]] = load i64, ptr [[TMP3]], align 8 +++; CHECK-NEXT: [[TMP7:%.*]] = load i64, ptr [[TMP4]], align 8 +++; CHECK-NEXT: [[TMP8:%.*]] = load i64, ptr [[TMP5]], align 8 +++; CHECK-NEXT: [[TMP9:%.*]] = load i64, ptr [[TMP0]], align 8 +++; CHECK-NEXT: [[TMP10:%.*]] = insertelement <2 x i64> poison, i64 [[TMP6]], i32 0 +++; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x i64> [[TMP10]], i64 [[TMP9]], i32 1 +++; CHECK-NEXT: [[TMP12:%.*]] = insertelement <2 x i64> poison, i64 [[TMP7]], i32 0 +++; CHECK-NEXT: [[TMP13:%.*]] = insertelement <2 x i64> [[TMP12]], i64 [[TMP8]], i32 1 +++; CHECK-NEXT: [[TMP14:%.*]] = insertelement <2 x i64> poison, i64 0, i32 0 +++; CHECK-NEXT: [[TMP15:%.*]] = insertelement <2 x i64> , i64 [[TMP1]], i32 1 +++; CHECK-NEXT: br label %[[BB16:.*]] +++; CHECK: [[BB16]]: +++; CHECK-NEXT: [[TMP17:%.*]] = phi <2 x i64> [ [[TMP11]], [[TMP2:%.*]] ], [ zeroinitializer, %[[TMP25:.*]] ] +++; CHECK-NEXT: [[TMP18:%.*]] = phi <2 x i64> [ [[TMP13]], [[TMP2]] ], [ [[TMP29:%.*]], %[[TMP25]] ] +++; CHECK-NEXT: switch i32 0, label %[[BB19:.*]] [ +++; CHECK-NEXT: i32 0, label %[[TMP25]] +++; CHECK-NEXT: ] +++; CHECK: [[BB19]]: +++; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x i64> [[TMP18]], <2 x i64> poison, <4 x i32> +++; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x i64> [[TMP20]], i64 0, i32 1 +++; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x i64> [[TMP21]], i64 0, i32 2 +++; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <4 x i64> [[TMP22]], <4 x i64> poison, <4 x i32> +++; CHECK-NEXT: [[TMP24:%.*]] = shufflevector <2 x i64> [[TMP14]], <2 x i64> [[TMP18]], <2 x i32> +++; CHECK-NEXT: br label %[[TMP25]] +++; CHECK: [[TMP25]]: +++; CHECK-NEXT: [[TMP26:%.*]] = phi <2 x i64> [ [[TMP17]], %[[BB19]] ], [ zeroinitializer, %[[BB16]] ] +++; CHECK-NEXT: [[TMP27:%.*]] = phi <4 x i64> [ [[TMP23]], %[[BB19]] ], [ zeroinitializer, %[[BB16]] ] +++; CHECK-NEXT: [[TMP28:%.*]] = phi <2 x i64> [ [[TMP24]], %[[BB19]] ], [ [[TMP15]], %[[BB16]] ] +++; CHECK-NEXT: [[TMP29]] = shufflevector <2 x i64> [[TMP18]], <2 x i64> , <2 x i32> +++; CHECK-NEXT: br i1 false, label %[[DOTLOOPEXIT206:.*]], label %[[BB16]] +++; CHECK: [[_LOOPEXIT206:.*:]] +++; CHECK-NEXT: switch i32 0, label %[[BB32:.*]] [ +++; CHECK-NEXT: i32 0, [[DOTCONT174:label %.*]] +++; CHECK-NEXT: i32 1, label %[[BB30:.*]] +++; CHECK-NEXT: ] +++; CHECK: [[BB30]]: +++; CHECK-NEXT: [[TMP31:%.*]] = shufflevector <4 x i64> [[TMP27]], <4 x i64> , <4 x i32> +++; CHECK-NEXT: br [[DOTCONT174]] +++; CHECK: [[BB32]]: +++; CHECK-NEXT: [[TMP33:%.*]] = insertelement <4 x i64> [[TMP27]], i64 0, i32 1 +++; CHECK-NEXT: [[TMP34:%.*]] = insertelement <4 x i64> [[TMP33]], i64 0, i32 2 +++; CHECK-NEXT: [[TMP35:%.*]] = shufflevector <4 x i64> [[TMP34]], <4 x i64> poison, <4 x i32> +++; CHECK-NEXT: [[TMP36:%.*]] = insertelement <2 x i64> [[TMP28]], i64 0, i32 0 +++; CHECK-NEXT: br [[DOTCONT174]] +++; CHECK: [[_CONT174:.*:]] +++; CHECK-NEXT: [[TMP37:%.*]] = phi <2 x i64> [ [[TMP26]], %[[BB32]] ], [ zeroinitializer, %[[BB30]] ], [ [[TMP26]], %[[DOTLOOPEXIT206]] ] +++; CHECK-NEXT: [[TMP38:%.*]] = phi <4 x i64> [ [[TMP35]], %[[BB32]] ], [ [[TMP31]], %[[BB30]] ], [ [[TMP27]], %[[DOTLOOPEXIT206]] ] +++; CHECK-NEXT: [[TMP39:%.*]] = phi <2 x i64> [ [[TMP36]], %[[BB32]] ], [ zeroinitializer, %[[BB30]] ], [ 
[[TMP28]], %[[DOTLOOPEXIT206]] ] +++; CHECK-NEXT: ret i64 0 +++; +++ %3 = getelementptr i8, ptr %0, i64 104 +++ %4 = getelementptr i8, ptr %0, i64 112 +++ %5 = getelementptr i8, ptr %0, i64 24 +++ %6 = load i64, ptr %3, align 8 +++ %7 = load i64, ptr %4, align 8 +++ %8 = load i64, ptr %5, align 8 +++ %9 = load i64, ptr %0, align 8 +++ br label %10 +++ +++10: +++ %11 = phi i64 [ %9, %2 ], [ 0, %18 ] +++ %12 = phi i64 [ %8, %2 ], [ %12, %18 ] +++ %13 = phi i64 [ %7, %2 ], [ 0, %18 ] +++ %14 = phi i64 [ %6, %2 ], [ 0, %18 ] +++ switch i32 0, label %15 [ +++ i32 0, label %18 +++ ] +++ +++15: +++ %16 = tail call i64 @llvm.umin.i64(i64 0, i64 0) +++ %17 = tail call i64 @llvm.umax.i64(i64 0, i64 0) +++ br label %18 +++ +++18: +++ %19 = phi i64 [ %17, %15 ], [ 0, %10 ] +++ %20 = phi i64 [ %16, %15 ], [ 0, %10 ] +++ %21 = phi i64 [ %11, %15 ], [ 0, %10 ] +++ %22 = phi i64 [ %12, %15 ], [ 0, %10 ] +++ %23 = phi i64 [ %13, %15 ], [ %1, %10 ] +++ %24 = phi i64 [ %14, %15 ], [ 0, %10 ] +++ br i1 false, label %.loopexit206, label %10 +++ +++.loopexit206: +++ switch i32 0, label %26 [ +++ i32 0, label %.cont174 +++ i32 1, label %25 +++ ] +++ +++25: +++ br label %.cont174 +++ +++26: +++ %27 = tail call i64 @llvm.umin.i64(i64 0, i64 0) +++ %28 = tail call i64 @llvm.umax.i64(i64 0, i64 0) +++ br label %.cont174 +++ +++.cont174: +++ %.sroa.139.1 = phi i64 [ %28, %26 ], [ %19, %25 ], [ %19, %.loopexit206 ] +++ %.sroa.133.1 = phi i64 [ %27, %26 ], [ 0, %25 ], [ %20, %.loopexit206 ] +++ %.sroa.81.1 = phi i64 [ %23, %26 ], [ 0, %25 ], [ %23, %.loopexit206 ] +++ %.sroa.75.1 = phi i64 [ %24, %26 ], [ 0, %25 ], [ %24, %.loopexit206 ] +++ %.sroa.21.1 = phi i64 [ %21, %26 ], [ 0, %25 ], [ %21, %.loopexit206 ] +++ %.sroa.15.1 = phi i64 [ %22, %26 ], [ 0, %25 ], [ %22, %.loopexit206 ] +++ %29 = phi i64 [ %28, %26 ], [ 0, %25 ], [ %19, %.loopexit206 ] +++ %30 = phi i64 [ %27, %26 ], [ 0, %25 ], [ %20, %.loopexit206 ] +++ ret i64 0 ++} +++ +++declare i64 @llvm.umax.i64(i64, i64) +++ +++declare i64 @llvm.umin.i64(i64, i64) +++ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 0b67d8b..3ec4c3e 100644 +index 3ec4c3e..d44a9f6 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "98feb05825a179c56f965d936b948a95d2a6b888" -- LLVM_SHA256 = "5b8d3c97a0340042380153919fb75fa50669c7266e32ce2cf42f62ad943eddb8" -+ LLVM_COMMIT = "179d30f8c3fddd3c85056fd2b8e877a4a8513158" -+ LLVM_SHA256 = "39f33d0ba77ca40d254c767519a0f3f5692c2caa271f413e7245ab63d0787bd5" +- LLVM_COMMIT = "179d30f8c3fddd3c85056fd2b8e877a4a8513158" +- LLVM_SHA256 = "39f33d0ba77ca40d254c767519a0f3f5692c2caa271f413e7245ab63d0787bd5" ++ LLVM_COMMIT = "ffd5b148941a1146378a247c70c4faface3a1f96" ++ LLVM_SHA256 = "fc57e9b703ddfb6d888e1c5beb2a65ca8d84d439bcf88c63eb014ccb8bbea414" tf_http_archive( name = name, -diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch -index ff118b8..839d98a 100755 ---- a/third_party/stablehlo/temporary.patch -+++ b/third_party/stablehlo/temporary.patch -@@ -607,6 +607,30 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "stablehlo/dialect/VhloOps.td", - deps = [ -+diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir -+--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir -++++ 
stablehlo/stablehlo/conversions/tosa/tests/binary.mlir -+@@ -45,7 +45,7 @@ -+ -+ // CHECK-LABEL: @divide -+ func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> { -+- // CHECK: tosa.int_div -++ // CHECK: tosa.intdiv -+ %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> -+ return %0 : tensor<10xi32> -+ } -+diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll -+--- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll -++++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll -+@@ -156,7 +156,7 @@ -+ Pattern => -+ replace op(input0 : Value<_: Tosa_Int32Tensor>, -+ input1 : Value<_: Tosa_Int32Tensor>) -+- with op(input0, input1); -++ with op(input0, input1); -+ Pattern => -+ replace op(input0 : Value<_: Tosa_Tensor>, -+ input1 : Value<_: Tosa_Tensor>) - diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel - --- stablehlo/stablehlo/tests/BUILD.bazel - +++ stablehlo/stablehlo/tests/BUILD.bazel diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index e2bc747f2d6d05..70631415b77b9c 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "9585bea76e06fff5574ddf20bd88cbdfd0b98985" - SHARDY_SHA256 = "f93f37639f8ec1cd5ae4c26e0b3291bfa85923567c01e9404c6e070d061c598a" + SHARDY_COMMIT = "0f48503d743de99500cd6f0120988bf06abba91c" + SHARDY_SHA256 = "c7540bfd5a12eedb3a3b9c049943fc44135af653a90e13b969850ba87c4b464b" tf_http_archive( name = "shardy", From 9fda08c6f1b346cb8988e7418bd561cf019ac495 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 17 Apr 2025 09:04:25 -0700 Subject: [PATCH 0929/1324] Fix mlir cast/dyn_cast/isa in tensorflow Use llvm::cast/dyn_cast/isa since alternatives are deprecated in https://github.com/llvm/llvm-project/pull/135556 PiperOrigin-RevId: 748699830 --- .../transforms/composite_lowering_patterns.td | 12 +++---- .../transforms/legalize_hlo_patterns.td | 4 +-- .../transforms/legalize_tf_patterns.td | 34 +++++++++---------- .../lite/stablehlo/transforms/prepare_hlo.td | 10 +++--- .../tflite_legalize_hlo_patterns.td | 4 +-- 5 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index 4f833459cc3a07..2cf060c6379d53 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -29,21 +29,21 @@ def LegalizeHardSwishComposite: Pat< (TFL_HardSwishOp $input)>; def IsNchwLayoutOp: Constraint() " + "$0.get(\"is_nchw_op\") && llvm::dyn_cast($0.get(\"is_nchw_op\")) " "== mlir::BoolAttr::get($_builder.getContext(), true)">>; def IsNhwcLayoutOp: Constraint>; class HasRank : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() == " # n>>; class HasRankAtLeast : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() >= " # n>>; + 
CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() >= " # n>>; def I32ElementsVal : Constraint().getElementType().isInteger(32)">, + "llvm::cast($0.getType()).getElementType().isInteger(32)">, "32 bit integer tensor">; // TODO(b/343278954): Move the creation of transposes to a separate prepare pass diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td index f9fd092f1e04fb..05a68b2cff370e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td @@ -283,7 +283,7 @@ def : Pat<(MHLO_ConcatenateOp $inputs, $dim), //===----------------------------------------------------------------------===// class HasChloCompareType : - CPred<"$_self.cast<::mlir::chlo::ComparisonTypeAttr>().getValue() == " # value>; + CPred<"llvm::cast<::mlir::chlo::ComparisonTypeAttr>($_self).getValue() == " # value>; // Attribute value should be such that it matches the comparison used by // TensorFlow, if the attribute is present. @@ -298,7 +298,7 @@ class CHLO_ComparisonDirectionValue : ConstantAttr; class HasMhloCompareType : - CPred<"$_self.cast<::mlir::mhlo::ComparisonTypeAttr>().getValue() == " # value>; + CPred<"llvm::cast<::mlir::mhlo::ComparisonTypeAttr>($_self).getValue() == " # value>; // Attribute value should be such that it matches the comparison used by // TensorFlow, if the attribute is present. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td index 5e01eea4ed3435..971940086d7010 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td @@ -33,8 +33,8 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; // BatchNorm op patterns. //===----------------------------------------------------------------------===// -def FalseBoolAttr : AttrConstraint().getValue()">>; -def TrueBoolAttr : AttrConstraint().getValue()">>; +def FalseBoolAttr : AttrConstraint($_self).getValue()">>; +def TrueBoolAttr : AttrConstraint($_self).getValue()">>; def CastValueToI64: NativeCodeCall< "CastValueToI64($0.getLoc(), $1, &$_builder)">; @@ -47,18 +47,18 @@ def CastValueToElementType: NativeCodeCall< // the corresponding value of ranked tensor type whose axis is referred in $0. def GetHLOAxisFromTFAxis : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, $1.getType().cast().getRank(), &$_builder)">; + "$0, llvm::cast($1.getType()).getRank(), &$_builder)">; // Same as the above but with $1 of type operand_range from variadic TensorFlow // input. 
def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, (*$1.begin()).getType().cast().getRank(), " + "$0, llvm::cast((*$1.begin()).getType()).getRank(), " "&$_builder)">; def CastElementsToI64Elements : NativeCodeCall< - "hlo::convertElementsAttr(" - "$0.cast(), $_builder.getIntegerType(64)).cast()">; + "llvm::cast(hlo::convertElementsAttr(" + "llvm::cast($0), $_builder.getIntegerType(64)))">; def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; @@ -274,17 +274,17 @@ def : EqualityPat>; //===----------------------------------------------------------------------===// def OneElementAttrPred - : CPred<"$_self.cast().getShapedType().getNumElements() == 1">; + : CPred<"llvm::cast($_self).getShapedType().getNumElements() == 1">; def OneElementAttr : ElementsAttrBase, "Scalar ElementsAttr">; def HasRankedFirstOperand - : Constraint()">>; + : Constraint((*$0.begin()).getType())">>; def IsShapedTensor - : Constraint()">>; + : Constraint($0.getType())">>; // This pattern converts TensorFlow axis format to HLO axis format which // doesn't wrap around like TensorFlow and is always positive. For this @@ -332,10 +332,10 @@ class MHLO_FftTypeValue : ConstantAttr; def GetInnerDimFromValue : NativeCodeCall< - "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; + "GetInnerDimFromValue(llvm::cast($0.getType()), &$_builder)">; def CheckInnerDimStatic - : Constraint(), &$_builder)">>; + : Constraint($0.getType()), &$_builder)">>; def : Pat<(TF_FFTOp:$res $input), (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), @@ -364,14 +364,14 @@ def LegalizeGatherV2 : //===----------------------------------------------------------------------===// class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< - "SliceDenseIntElementsAttrColumn2D($0.cast(), " # column # " )">; + "SliceDenseIntElementsAttrColumn2D(llvm::cast($0), " # column # " )">; class SliceDenseIntElementsAttr : NativeCodeCall< - "SliceDenseIntElementsAttr($0.cast(), " # index # ", " # axis # ")">; + "SliceDenseIntElementsAttr(llvm::cast($0), " # index # ", " # axis # ")">; // Interior padding attribute based on the TF padding. def GetInteriorPadding : NativeCodeCall < - "GetInteriorPadding($0.cast())">; + "GetInteriorPadding(llvm::cast($0))">; def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), (MHLO_PadOp $input, $c, @@ -511,10 +511,10 @@ def UnpackStartingIndices: NativeCodeCall< "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">; def CanBeTranslatedToDynamicSlice : Constraint())">>; + "CanBeTranslatedToDynamicSlice($0, $1, llvm::cast($2))">>; def TFSliceSizes2HLOSliceSizes : NativeCodeCall< - "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," + "TFSliceSizes2HLOSliceSizes($0, $1, llvm::cast($2)," "&$_builder)">; def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, @@ -560,7 +560,7 @@ def : Pat<(TF_LegacyCallOp:$op $args, $args_attrs, $res_attrs, FlatSymbolRefAttr //===----------------------------------------------------------------------===// // Handles axis conversion for TF reverse. 
-def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, llvm::cast($1), &$_builder)">; def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td index 9b6f6efbfcf4f6..c0b274ac1f852b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td @@ -56,10 +56,10 @@ def AreDnumsFullyDefined : Constraint()" + "llvm::cast($2.getType())" ".clone($0.PermuteShape(" "$1," - "$2.getType().cast().getShape()))">; + "llvm::cast($2.getType()).getShape()))">; def IsStandardConv : Constraint())">>; @@ -380,7 +380,7 @@ def GetExplicitPaddingArgs : NativeCodeCall< // Gets element type from Value. def GetElementType : NativeCodeCall< - "$0.getType().cast().getElementType()">; + "llvm::cast($0.getType()).getElementType()">; // Given element type, get a DenseElements with scalar shape and 0 value. def GetZeroScalarAttrFromType : NativeCodeCall< @@ -439,9 +439,9 @@ def UnfuseConvWithExplicitPadding : Pat<(MHLO_ConvolutionOp:$conv def TrivialStrides : NativeCodeCall< "DenseIntElementsAttr::get(" - "RankedTensorType::get({$0.getType().cast().getRank()}," + "RankedTensorType::get({llvm::cast($0.getType()).getRank()}," "$_builder.getI64Type())," - "llvm::SmallVector($0.getType().cast().getRank()," + "llvm::SmallVector(llvm::cast($0.getType()).getRank()," "1))">; def SliceStart : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td index 4fb22ae6bbe992..e438e9580697e2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td @@ -56,7 +56,7 @@ def : Pat< def I64AttrToI32Attr: NativeCodeCall< "$_builder.getI32IntegerAttr(" - "static_cast($0.cast().getInt()))">; + "static_cast(llvm::cast($0).getInt()))">; def : Pat< (MHLO_ConcatenateOp $inputs, $dim), @@ -298,7 +298,7 @@ foreach pair = [ // Check implicit bool cast of `$_self` to ensure Attribute is non-null before // casting. 
 def HasSupportedComparisonType : AttrConstraint<
-  CPred<"!$_self || SupportedComparisonType($_self.cast())">>;
+  CPred<"!$_self || SupportedComparisonType(llvm::cast($_self))">>;

 class MHLO_ComparisonDirectionValue : ConstantAttr;

From: Adrian Kuegel
Date: Thu, 17 Apr 2025 09:22:46 -0700
Subject: [PATCH 0930/1324] Fix mlir cast/dyn_cast/isa in tensorflow

Use llvm::cast/dyn_cast/isa since alternatives are deprecated in
https://github.com/llvm/llvm-project/pull/135556

PiperOrigin-RevId: 748704726
---
 .../common/attrs_and_constraints.td           | 16 ++++----
 .../common/quantization_lib/quantization.td   |  8 ++--
 .../tf_quantization_lib/tf_quantization.td    |  8 ++--
 .../passes/remove_sharding_custom_call.td     |  2 +-
 .../lift_quantizable_spots_as_functions.td    |  4 +-
 .../tensorflow/passes/prepare_lifting.td      | 10 ++---
 .../compiler/mlir/tensorflow/ir/tf_op_base.td | 38 +++++++++----------
 .../compiler/mlir/tensorflow/ir/tf_ops.td     |  8 ++--
 .../transforms/decompose_resource_ops.td      |  2 +-
 .../mlir/tensorflow/transforms/lower_tf.td    | 10 ++---
 .../mlir/tensorflow/transforms/optimize.td    | 16 ++++----
 .../tf2xla/transforms/legalize_tf_patterns.td | 34 ++++++++---------
 .../mlir/tfr/passes/decompose_patterns.td     |  2 +-
 .../tools/kernel_gen/ir/tf_framework_ops.td   |  6 +--
 tensorflow/core/ir/ops.td                     |  2 +-
 tensorflow/dtensor/mlir/ir/tf_dtensor.td      |  6 +--
 .../hlo_to_mhlo/hlo_function_importer.cc      |  4 +-
 .../mlir/framework/ir/xla_framework_ops.td    |  2 +-
 18 files changed, 89 insertions(+), 89 deletions(-)

diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td
index 1921345d601283..b6085d30f656c4 100644
--- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td
+++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td
@@ -17,7 +17,7 @@ include "mlir/IR/PatternBase.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"

 def DenseElementsAttr : ElementsAttrBase<
-  CPred<"$_self.isa()">,
+  CPred<"llvm::isa($_self)">,
   "non-opaque constant tensor">;

 // Checks if the data format is "NHWC".
@@ -31,13 +31,13 @@ def IsConstTensor : Constraint($0.getDefin

 // Checks if the element value has a float type.
 def IsFloatElementsAttr : ElementsAttrBase<
-  CPred<"$_self.isa() && "
-        "getElementTypeOrSelf($_self.cast().getType()).isa()">,
+  CPred<"llvm::isa($_self) && "
+        "llvm::isa(getElementTypeOrSelf(llvm::cast($_self).getType()))">,
   "float constant tensor">;

 // Checks if the boolean value is false.
 def IsFalseBoolAttr : AttrConstraint<
-  CPred<"!$_self.cast().getValue()">>;
+  CPred<"!llvm::cast($_self).getValue()">>;

 // Checks if the value has only one user.
 def HasOneUse : Constraint>;
@@ -63,7 +63,7 @@ def IsBF16ElementType : Constraint<

 // Checks if the value has the type of UniformQuantizedType.
 def IsUniformQuantizedType : Constraint<
-  CPred<"getElementTypeOrSelf($0).isa()">>;
+  CPred<"llvm::isa(getElementTypeOrSelf($0))">>;

 // Checks if the given two values have the same type.
 def AreTheSameElementType : Constraint<
@@ -75,12 +75,12 @@ def AreTheSameValue : Constraint<

 // Checks if the value has rank.
 def HasRank : Constraint<
-  CPred<"$0.getType().cast().hasRank()">>;
+  CPred<"llvm::cast($0.getType()).hasRank()">>;

 // Checks if the value has rank of `n`.
class HasRankOf : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>, + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() == " # n>, "Checks if the value has rank of 'n'.">; // Checks if the value has static shape. diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td index 0f9b6a74762f9b..706eb8552eb1ff 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td @@ -31,12 +31,12 @@ include "mlir/Dialect/Quant/IR/QuantBase.td" // explicit signedness check to differentiate the signed/unsigned constraints // predicates from one another at the TD level. class QuantizedType params, bit signed> - : Type()">, - CPred<"$_self.cast()" # + : Type($_self)">, + CPred<"llvm::cast($_self)" # ".getStorageTypeIntegralWidth() == " # !head(params)>, - Or<[CPred<"$_self.cast()" # + Or<[CPred<"llvm::cast($_self)" # ".getStorageType().isSignlessInteger()">, - CPred<"$_self.cast()" # + CPred<"llvm::cast($_self)" # ".getStorageType().isSignedInteger() == " # signed>]>]>, "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { string name = n; diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td index 050b87e8c08834..3909495ef239fb 100644 --- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td @@ -31,12 +31,12 @@ include "mlir/Dialect/Quant/IR/QuantBase.td" // explicit signedness check to differentiate the signed/unsigned constraints // predicates from one another at the TD level. class TFQuantizedType params, bit signed> - : Type()">, - CPred<"$_self.cast()" # + : Type($_self)">, + CPred<"llvm::cast($_self)" # ".getStorageTypeIntegralWidth() == " # !head(params)>, - Or<[CPred<"$_self.cast()" # + Or<[CPred<"llvm::cast($_self)" # ".getStorageType().isSignlessInteger()">, - CPred<"$_self.cast()" # + CPred<"llvm::cast($_self)" # ".getStorageType().isSignedInteger() == " # signed>]>]>, "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { string name = n; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td index 70ee6dc077ee11..0ff3ece326d242 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td @@ -15,7 +15,7 @@ limitations under the License. 
include "stablehlo/dialect/StablehloOps.td" class IsStringAttrOf : Constraint< - CPred<"::llvm::isa_and_nonnull($_self) && $_self.cast().getValue() == \"" # value # "\"">, + CPred<"::llvm::isa_and_nonnull($_self) && llvm::cast($_self).getValue() == \"" # value # "\"">, "Is a string attribute whose value is \"" # value # "\"" >; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td index d56ee05dc071dc..9e0f26d8793684 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td @@ -25,8 +25,8 @@ include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" //===----------------------------------------------------------------------===// class IsFusedOpEndsWith : AttrConstraint< - CPred<"!$_self.cast().empty() && " - "$_self.cast()[$_self.cast().size() - 1]." + CPred<"!llvm::cast($_self).empty() && " + "llvm::cast($_self)[llvm::cast($_self).size() - 1]." "cast<::mlir::StringAttr>().str() == \"" # OpName # "\"">, "Matching fused '" # OpName # "' op at the end">; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td index d75a01be7d2182..338fdc91fc521c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td @@ -83,21 +83,21 @@ class HasEqualElementSize shape_1, list shape_2> : Constraint< "Checks if the given dimensions contain the same number of elements.">; def ReshapableTo1DTensor : Constraint< - CPred<"quant::ReshapableTo1DTensor($0.getType().cast())">, + CPred<"quant::ReshapableTo1DTensor(llvm::cast($0.getType()))">, "Checks if the value dims are all ones except the right most dim">; def ReshapeTo1DTensor : NativeCodeCall< "quant::ReshapeTo1DTensor($_builder, $_loc, $0)">; def HasEqualShape : Constraint().hasRank() && " - "$1.getType().cast().hasRank() && " - "$0.getType().cast().getShape() == $1.getType().cast().getShape()">, + "llvm::cast($0.getType()).hasRank() && " + "llvm::cast($1.getType()).hasRank() && " + "llvm::cast($0.getType()).getShape() == llvm::cast($1.getType()).getShape()">, "Checks if the shapes of tensors are same.">; // Make the 1D value $0 broadcastable with the shape of $1. def MakeOneDimValueBroadcastable : NativeCodeCall< - "MakeOneDimValueBroadcastable($_builder, $_loc, $0, $1.getType().cast())">; + "MakeOneDimValueBroadcastable($_builder, $_loc, $0, llvm::cast($1.getType()))">; // Match convolution op with "NHWC" data format or matmul op. 
def SupportedAffineOpMatcher : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 127210340114a5..d7ae0542890a79 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -144,24 +144,24 @@ def TF_UniqueResourceAllocation: TraitList<[ //===----------------------------------------------------------------------===// class TF_OperandIsUnrankedPred : - CPred<"$_op.getOperand(" # n # ").getType().isa()">; + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">; class TF_ResultIsUnrankedPred : - CPred<"$_op.getResult(" # n # ").getType().isa()">; + CPred<"llvm::isa($_op.getResult(" # n # ").getType())">; // Returns true if the n-th operand has unknown rank or has rank m. class TF_OperandHasRank : PredOpTrait<"operand " # n # " is " # m # "-D", Or<[TF_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() == " # m>]>>; // Returns true if the n-th result has unknown rank or has rank m. class TF_ResultHasRank : PredOpTrait<"result " # n # " is " # m # "-D", Or<[TF_ResultIsUnrankedPred, - CPred<"$_op.getResult(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getResult(" # n # + ").getType()).getRank() == " # m>]>>; //===----------------------------------------------------------------------===// // TensorFlow resources and side effects @@ -282,12 +282,12 @@ class TF_Op traits = []> : //===----------------------------------------------------------------------===// class TF_TensorFlowAttr : - Attr()">, + Attr($_self)">, "TensorFlow " # description # " attribute">; def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> { let returnType = "std::optional>"; - let convertFromStorage = "$_self.cast().getValue()"; + let convertFromStorage = "llvm::cast($_self).getValue()"; // Create a ranked shape attr by default. let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)"; @@ -309,11 +309,11 @@ def TF_SymbolRefArrayAttr : // Any tensor element type defined in the TensorFlow dialect def TF_TFDialectType : - Type()">, "TensorFlow type">; + Type($_self)">, "TensorFlow type">; // Class for any TensorFlow dialect specific type class TF_TensorFlowType : - Type()">, + Type($_self)">, "TensorFlow " # description # " type">, BuildableType<"getType()">; @@ -547,9 +547,9 @@ def TF_Tensor : TensorOf<[TF_ElementType]>; // A string attribute whose value are one of the values in `cases`. 
 class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
   CPred<!foldl(
-      "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
+      "llvm::cast<StringAttr>($_self).getValue() == \"" # !head(cases) # "\"",
       !foreach(case, !tail(cases),
-        "$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
+        "llvm::cast<StringAttr>($_self).getValue() == \"" # case # "\""),
       prev, cur, prev # " || " # cur)>,
   "string attribute whose value is " #
     !foldl(/*init*/!head(cases), /*list*/!tail(cases),
@@ -558,8 +558,8 @@ class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
 
 // TODO: Use EnumAttr to define the common attribute cases
 def TF_ConvnetDataFormatAttr : StringBasedAttr<
-  CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
-        "$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
+  CPred<"llvm::cast<StringAttr>($_self).getValue() == \"NHWC\" || " #
+        "llvm::cast<StringAttr>($_self).getValue() == \"NCHW\"">,
   "'NHWC' or 'NCHW' convnet data format">;
 
 //===----------------------------------------------------------------------===//
@@ -679,7 +679,7 @@ class TF_DerivedResultShapeListAttr : DerivedAttr<
 
 // A derived attribute that returns the shape of the first result type.
 def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
-  "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
+  "return llvm::cast<ShapedType>((*getOperation()->result_type_begin()));",
   [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>;
 
 def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
@@ -713,14 +713,14 @@ class WithBroadcastableCmpOpBuilder {
   OpBuilder<(ins "Value":$x, "Value":$y),
   [{
     Type resultType;
-    if (x.getType().isa<UnrankedTensorType>() ||
-        y.getType().isa<UnrankedTensorType>()) {
+    if (llvm::isa<UnrankedTensorType>(x.getType()) ||
+        llvm::isa<UnrankedTensorType>(y.getType())) {
       resultType = UnrankedTensorType::get($_builder.getI1Type());
     } else {
       SmallVector<int64_t, 4> resultShape;
       if (!OpTrait::util::getBroadcastedShape(
-              x.getType().cast<ShapedType>().getShape(),
-              y.getType().cast<ShapedType>().getShape(), resultShape)) {
+              llvm::cast<ShapedType>(x.getType()).getShape(),
+              llvm::cast<ShapedType>(y.getType()).getShape(), resultShape)) {
         mlir::emitError($_state.location,
                         "operands have no broadcastable shapes");
       }
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 1400785bc2d22d..8f44953b10fa68 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -57,7 +57,7 @@ class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic,
   // Returns data type of the result handle. Returned type contains type of
   // the TensorList element as a subtype.
   VariantType handle_dtype() {
-    return getElementTypeOrSelf(getHandle().getType()).cast<VariantType>();
+    return llvm::cast<VariantType>(getElementTypeOrSelf(getHandle().getType()));
   }
 }];
 }
@@ -118,7 +118,7 @@ An n-way switch statement, implementing the following:
 
   // Prefer passing in SymbolTableCollection to reduce lookup costs by
   // enabling reusing cached symbol table lookup.
   func::FuncOp ResolveBranchFunction(::mlir::SymbolTableCollection* table, int index) {
-    auto flat_sym_ref = getBranches()[index].cast<FlatSymbolRefAttr>();
+    auto flat_sym_ref = llvm::cast<FlatSymbolRefAttr>(getBranches()[index]);
     if (table)
       return table->lookupNearestSymbolFrom<func::FuncOp>(*this, flat_sym_ref);
     return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, flat_sym_ref);
@@ -854,14 +854,14 @@ Example:
     "return getElementTypeOrSelf(resource_subtype());">;
   DerivedAttr shape = DerivedAttr<
     "ShapedType",
-    "return resource_subtype().cast<ShapedType>();",
+    "return llvm::cast<ShapedType>(resource_subtype());",
     [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>;
 
   let extraClassDeclaration = [{
     TensorType resource_subtype() { return resource_type().getSubtypes()[0]; }
 
     ResourceType resource_type() {
-      return getElementTypeOrSelf(getResource()).cast<ResourceType>();
+      return llvm::cast<ResourceType>(getElementTypeOrSelf(getResource()));
     }
   }];
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td
index f466c1d48d6835..1fc666da4a8d95 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td
@@ -30,7 +30,7 @@ def CreateTFReadVariableOp : NativeCodeCall<
     "$_builder.create<TF::ReadVariableOp>("
     "  $0.getLoc(),"
     "  GetResourceSubtypeOrDefault("
-    "    $2, $1.getType().cast<TF::ResourceType>().getElementType()),"
+    "    $2, llvm::cast<TF::ResourceType>($1.getType()).getElementType()),"
     "  $2)"
 >;
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
index 4c7810f8df51b1..a9ff5a8f76268a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
@@ -34,7 +34,7 @@ class GetI64ScalarElementsAttr<int value> :
 
 class GetF32Scalar<int value> :
   NativeCodeCall<"GetF32Scalar(&$_builder, " # value # ")">;
 
-def TrueBoolAttr : AttrConstraint<CPred<"$_self.cast<BoolAttr>().getValue()">>;
+def TrueBoolAttr : AttrConstraint<CPred<"llvm::cast<BoolAttr>($_self).getValue()">>;
 
 def CreateTFShapeOp : NativeCodeCall<
     "$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;
 
@@ -74,7 +74,7 @@ def LowerAddOp : Pat<(TF_AddOp TF_NumberNotQuantizedTensor:$x,
 
 def GetBiasAddGradReductionIndices : NativeCodeCall<
     "GetBiasAddGradReductionIndices("
-    "$0.getType().cast<RankedTensorType>().getRank(), $1, &$_builder)">;
+    "llvm::cast<RankedTensorType>($0.getType()).getRank(), $1, &$_builder)">;
 
 def LowerBiasAddGradOp :
   Pat<(TF_BiasAddGradOp AnyRankedTensor:$out_backprop, $data_format),
@@ -120,12 +120,12 @@ def LowerSoftmaxCrossEntropyWithLogitsOp : Pattern<
 // dimension should be known.
 class GetDimSizeOfType<int dim> : NativeCodeCall<
     "GetScalarOfType(getElementTypeOrSelf($1), "
-    "$0.getType().cast<ShapedType>().getDimSize(" # dim # "))">;
+    "llvm::cast<ShapedType>($0.getType()).getDimSize(" # dim # "))">;
 
 // Same as the above with i32 element type.
 class GetDimSizeAsI32<int dim> : NativeCodeCall<
     "GetScalarOfType($_builder.getIntegerType(32), "
-    "$0.getType().cast<ShapedType>().getDimSize(" # dim # "))">;
+    "llvm::cast<ShapedType>($0.getType()).getDimSize(" # dim # "))">;
 
 // Sparse version of SoftmaxCrossEntropyWithLogits is lowered to dense by
 // expanding the sparse labels using:
@@ -285,7 +285,7 @@ def LowerIsNanOp : Pat<(TF_IsNanOp $x),
 
 def GetAllAxes : NativeCodeCall<
     "GetI64ElementsAttrForSeq("
-    "0, $0.getType().cast<RankedTensorType>().getRank(), &$_builder)">;
+    "0, llvm::cast<RankedTensorType>($0.getType()).getRank(), &$_builder)">;
 
 // L2Loss is lowered using the formula,
 // L2Loss(input) = Sum(input * input) / 2
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
index 188fbbb6be532b..9ad34d2064c764 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
@@ -23,18 +23,18 @@ def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "\"NHWC\"">;
 
 // Get the last dimension size as a 1-d single element attr.
 def GetLastDimSizeAsI32 : NativeCodeCall<
     "DenseElementsAttr::get(RankedTensorType::get({1}, $_builder.getIntegerType(32)), "
-    "static_cast<int32_t>($0.getType().cast<ShapedType>().getDimSize( "
-    "  $0.getType().cast<ShapedType>().getRank() - 1)))">;
+    "static_cast<int32_t>(llvm::cast<ShapedType>($0.getType()).getDimSize( "
+    "  llvm::cast<ShapedType>($0.getType()).getRank() - 1)))">;
 
 // Check whether the tensor is ranked and whether its last dim is static.
 def IsRankedShapeLastDimStatic : Constraint<And<[
-  CPred<"$0.getType().isa<RankedTensorType>()">,
-  CPred<"!$0.getType().cast<RankedTensorType>().isDynamicDim( "
-  "  $0.getType().cast<RankedTensorType>().getRank() - 1)">]>>;
+  CPred<"llvm::isa<RankedTensorType>($0.getType())">,
+  CPred<"!llvm::cast<RankedTensorType>($0.getType()).isDynamicDim( "
+  "  llvm::cast<RankedTensorType>($0.getType()).getRank() - 1)">]>>;
 
 def IsNotComplexType : Constraint<And<[
-  CPred<"$0.getType().isa<RankedTensorType>()">,
-  CPred<"!$0.getType().cast<ShapedType>().getElementType().isa<ComplexType>()">
+  CPred<"llvm::isa<RankedTensorType>($0.getType())">,
+  CPred<"!llvm::isa<ComplexType>(llvm::cast<ShapedType>($0.getType()).getElementType())">
   ]>>;
 
 // Only fuse multiplier if all dimensions other than the channel dimension
@@ -43,7 +43,7 @@ def CanFuseMulAndConv2D : Constraint>;
 
 def F32ElementsAttr : ElementsAttrBase<
-  CPred<"$_self.cast<ElementsAttr>().getShapedType().getElementType().isF32()">, "float constant tensor">;
+  CPred<"llvm::cast<ElementsAttr>($_self).getShapedType().getElementType().isF32()">, "float constant tensor">;
 def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<TF::Conv2DOp>($0.getDefiningOp())">>;
 
 // Checks if the value has only one user.
 def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td
index a0404806ced750..e97e589709433c 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td
@@ -33,8 +33,8 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>;
 // BatchNorm op patterns.
 //===----------------------------------------------------------------------===//
 
-def FalseBoolAttr : AttrConstraint<CPred<"!$_self.cast<BoolAttr>().getValue()">>;
-def TrueBoolAttr : AttrConstraint<CPred<"$_self.cast<BoolAttr>().getValue()">>;
+def FalseBoolAttr : AttrConstraint<CPred<"!llvm::cast<BoolAttr>($_self).getValue()">>;
+def TrueBoolAttr : AttrConstraint<CPred<"llvm::cast<BoolAttr>($_self).getValue()">>;
 
 def CastValueToI64: NativeCodeCall<
   "CastValueToI64($0.getLoc(), $1, &$_builder)">;
 
@@ -47,18 +47,18 @@ def CastValueToElementType: NativeCodeCall<
 
 // the corresponding value of ranked tensor type whose axis is referred in $0.
 def GetHLOAxisFromTFAxis : NativeCodeCall<
     "GetHLOAxisFromTFAxis("
-    "$0, $1.getType().cast<RankedTensorType>().getRank(), &$_builder)">;
+    "$0, llvm::cast<RankedTensorType>($1.getType()).getRank(), &$_builder)">;
 
 // Same as the above but with $1 of type operand_range from variadic TensorFlow
 // input.
 def GetHLOAxisFromTFAxisVariadic : NativeCodeCall<
     "GetHLOAxisFromTFAxis("
-    "$0, (*$1.begin()).getType().cast<RankedTensorType>().getRank(), "
+    "$0, llvm::cast<RankedTensorType>((*$1.begin()).getType()).getRank(), "
     "&$_builder)">;
 
 def CastElementsToI64Elements : NativeCodeCall<
-    "hlo::convertElementsAttr("
-    "$0.cast<ElementsAttr>(), $_builder.getIntegerType(64)).cast<DenseIntElementsAttr>()">;
+    "llvm::cast<DenseIntElementsAttr>(hlo::convertElementsAttr("
+    "llvm::cast<ElementsAttr>($0), $_builder.getIntegerType(64)))">;
 
 def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">;
 
@@ -274,17 +274,17 @@ def : EqualityPat>;
 //===----------------------------------------------------------------------===//
 
 def OneElementAttrPred
-    : CPred<"$_self.cast<ElementsAttr>().getShapedType().getNumElements() == 1">;
+    : CPred<"llvm::cast<ElementsAttr>($_self).getShapedType().getNumElements() == 1">;
 
 def OneElementAttr
     : ElementsAttrBase<And<[ElementsAttr.predicate, OneElementAttrPred]>,
                        "Scalar ElementsAttr">;
 
 def HasRankedFirstOperand
-    : Constraint<CPred<"(*$0.begin()).getType().isa<RankedTensorType>()">>;
+    : Constraint<CPred<"llvm::isa<RankedTensorType>((*$0.begin()).getType())">>;
 
 def IsShapedTensor
-    : Constraint<CPred<"$0.getType().isa<RankedTensorType>()">>;
+    : Constraint<CPred<"llvm::isa<RankedTensorType>($0.getType())">>;
 
 // This pattern converts TensorFlow axis format to HLO axis format which
 // doesn't wrap around like TensorFlow and is always positive. For this
@@ -332,10 +332,10 @@ class MHLO_FftTypeValue<string val> :
   ConstantAttr<MHLO_FftTypeAttr, "::mlir::mhlo::FftType::" # val>;
 
 def GetInnerDimFromValue : NativeCodeCall<
-    "GetInnerDimFromValue($0.getType().cast<ShapedType>(), &$_builder)">;
+    "GetInnerDimFromValue(llvm::cast<ShapedType>($0.getType()), &$_builder)">;
 
 def CheckInnerDimStatic
-    : Constraint<CPred<"CheckInnerDimStatic($0.getType().cast<ShapedType>(), &$_builder)">>;
+    : Constraint<CPred<"CheckInnerDimStatic(llvm::cast<ShapedType>($0.getType()), &$_builder)">>;
 
 def : Pat<(TF_FFTOp:$res $input),
           (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">,
           (GetInnerDimFromValue $res)),
@@ -364,14 +364,14 @@ def LegalizeGatherV2 :
 //===----------------------------------------------------------------------===//
 
 class SliceDenseIntElementsAttrColumn2D<int column> : NativeCodeCall<
-    "SliceDenseIntElementsAttrColumn2D($0.cast<DenseIntElementsAttr>(), " # column # " )">;
+    "SliceDenseIntElementsAttrColumn2D(llvm::cast<DenseIntElementsAttr>($0), " # column # " )">;
 
 class SliceDenseIntElementsAttr<int index, int axis> : NativeCodeCall<
-    "SliceDenseIntElementsAttr($0.cast<DenseIntElementsAttr>(), " # index # ", " # axis # ")">;
+    "SliceDenseIntElementsAttr(llvm::cast<DenseIntElementsAttr>($0), " # index # ", " # axis # ")">;
 
 // Interior padding attribute based on the TF padding.
 def GetInteriorPadding : NativeCodeCall <
-    "GetInteriorPadding($0.cast<ElementsAttr>())">;
+    "GetInteriorPadding(llvm::cast<ElementsAttr>($0))">;
 
 def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c),
           (MHLO_PadOp $input, $c,
@@ -511,10 +511,10 @@ def UnpackStartingIndices: NativeCodeCall<
     "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">;
 
 def CanBeTranslatedToDynamicSlice : Constraint<CPred<
-    "CanBeTranslatedToDynamicSlice($0, $1, $2.cast<DenseIntElementsAttr>())">>;
+    "CanBeTranslatedToDynamicSlice($0, $1, llvm::cast<DenseIntElementsAttr>($2))">>;
 
 def TFSliceSizes2HLOSliceSizes : NativeCodeCall<
-    "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast<DenseIntElementsAttr>(),"
+    "TFSliceSizes2HLOSliceSizes($0, $1, llvm::cast<DenseIntElementsAttr>($2),"
     "&$_builder)">;
 
 def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices,
@@ -563,7 +563,7 @@ def : Pat<(TF_LegacyCallOp:$op $args, $args_attrs, $res_attrs,
 //===----------------------------------------------------------------------===//
 
 // Handles axis conversion for TF reverse.
-def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast<ElementsAttr>(), &$_builder)">;
+def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, llvm::cast<ElementsAttr>($1), &$_builder)">;
 
 def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)),
           (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>;
 
diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td b/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td
index 503fd6256f16ed..d3b0322095d8d7 100644
--- a/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td
+++ b/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td
@@ -21,7 +21,7 @@ include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.td"
 
 class Quantize<string value> : NativeCodeCall<"TFR::Quantize(" # value # ", $0, $1, $_builder)">;
 
 class HasStringAttr<string value> : AttrConstraint<
-  CPred<"$_self.cast<StringAttr>().getValue() == \"" # value # "\"">>;
+  CPred<"llvm::cast<StringAttr>($_self).getValue() == \"" # value # "\"">>;
 
 def QuantActRangeNonePattern :
   Pattern<
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td
index d8e7617cc352ba..64f782d02346e8 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td
@@ -43,7 +43,7 @@ def TFFramework_Dialect : Dialect {
 }
 
 def TFFramework_OpKernelContextType : DialectType<TFFramework_Dialect,
-    CPred<"$_self.isa<::mlir::kernel_gen::tf_framework::OpKernelContextType>()">,
+    CPred<"llvm::isa<::mlir::kernel_gen::tf_framework::OpKernelContextType>($_self)">,
     "op_kernel_construction">,
   BuildableType<"$_builder.getType<::mlir::kernel_gen::tf_framework::OpKernelContextType>()"> {
   let description = [{
@@ -53,7 +53,7 @@ def TFFramework_OpKernelContextType : DialectType<TFFramework_Dialect,
 
 def TFFramework_JITCallableType : DialectType<TFFramework_Dialect,
-    CPred<"$_self.isa<::mlir::kernel_gen::tf_framework::JITCallableType>()">>,
+    CPred<"llvm::isa<::mlir::kernel_gen::tf_framework::JITCallableType>($_self)">>,
   BuildableType<"$_builder.getType<::mlir::kernel_gen::tf_framework::JITCallableType>()"> {
   let description = [{
     A `callable` represents the result of JIT compilation.
 Conceptually, it
@@ -107,7 +107,7 @@ def TFFramework_TFAllocOp : TFFramework_Op<"alloc", [
   }]>];
 
   let extraClassDeclaration = [{
-    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+    MemRefType getType() { return llvm::cast<MemRefType>(getResult().getType()); }
 
     static constexpr StringRef kReuseOutputAttrName = "reuse_output";
     static constexpr StringRef kReuseInputCandidatesAttrName =
         "reuse_input_candidates";
diff --git a/tensorflow/core/ir/ops.td b/tensorflow/core/ir/ops.td
index 0cb9ea90d8b92e..b6bbbee3b6e88e 100644
--- a/tensorflow/core/ir/ops.td
+++ b/tensorflow/core/ir/ops.td
@@ -684,7 +684,7 @@ class TFGraph_CaseLikeRegionOp<string mnemonic> : TFGraph_RegionOp<
 
     RegionAttr getPreservedAttrs(unsigned index) {
       if (auto attrs = getRegionAttrsAttr())
-        return attrs[index].cast<RegionAttr>();
+        return llvm::cast<RegionAttr>(attrs[index]);
       return {};
     }
     void setPreservedAttrs(unsigned index, RegionAttr attrs) {
diff --git a/tensorflow/dtensor/mlir/ir/tf_dtensor.td b/tensorflow/dtensor/mlir/ir/tf_dtensor.td
index 999d8df041e74d..11a6ea761e00aa 100644
--- a/tensorflow/dtensor/mlir/ir/tf_dtensor.td
+++ b/tensorflow/dtensor/mlir/ir/tf_dtensor.td
@@ -31,17 +31,17 @@ include "mlir/IR/OpBase.td"
 //===----------------------------------------------------------------------===//
 
 class DTensor_DTensorAttr<string name, string description> :
-    Attr<CPred<"$_self.isa<mlir::dtensor::" # name # "Attr>()">,
+    Attr<CPred<"llvm::isa<mlir::dtensor::" # name # "Attr>($_self)">,
          "DTensor " # description # " attribute">;
 
 def DTensor_LayoutAttr : DTensor_DTensorAttr<"Layout", "layout"> {
   let returnType = "mlir::dtensor::LayoutAttr::Layout";
-  let convertFromStorage = "$_self.cast<mlir::dtensor::LayoutAttr>().getValue()";
+  let convertFromStorage = "llvm::cast<mlir::dtensor::LayoutAttr>($_self).getValue()";
 }
 
 def DTensor_MeshAttr : DTensor_DTensorAttr<"Mesh", "mesh"> {
   let returnType = "mlir::dtensor::MeshAttr::Mesh";
-  let convertFromStorage = "$_self.cast<mlir::dtensor::MeshAttr>().getValue()";
+  let convertFromStorage = "llvm::cast<mlir::dtensor::MeshAttr>($_self).getValue()";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
index 01fcef5757bbbe..09f65f9b4c6813 100644
--- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -780,7 +780,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
 
   if (instruction->opcode() == HloOpcode::kAsyncStart) {
     auto bundle_result_type = mlir::mhlo::AsyncBundleType::get(
-        context_, result_type.cast<mlir::TupleType>().getTypes());
+        context_, llvm::cast<mlir::TupleType>(result_type).getTypes());
     // XLA Feature -- MHLO Only
     return func_builder
         ->create<mlir::mhlo::AsyncStartOp>(loc, bundle_result_type,
@@ -788,7 +788,7 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
         .getOperation();
   } else if (instruction->opcode() == HloOpcode::kAsyncUpdate) {
     auto bundle_result_type = mlir::mhlo::AsyncBundleType::get(
-        context_, result_type.cast<mlir::TupleType>().getTypes());
+        context_, llvm::cast<mlir::TupleType>(result_type).getTypes());
     // XLA Feature -- MHLO Only
     return func_builder
         ->create<mlir::mhlo::AsyncUpdateOp>(loc, bundle_result_type,
diff --git a/third_party/xla/xla/mlir/framework/ir/xla_framework_ops.td b/third_party/xla/xla/mlir/framework/ir/xla_framework_ops.td
index 6a72799112f780..6b9d7f4005639b 100644
--- a/third_party/xla/xla/mlir/framework/ir/xla_framework_ops.td
+++ b/third_party/xla/xla/mlir/framework/ir/xla_framework_ops.td
@@ -81,7 +81,7 @@ def XLAFramework_XLABufferToMemOp : XLAFramework_Op<"buffer_to_mem",
   }]>];
 
   let extraClassDeclaration = [{
-    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+    MemRefType getType() { return llvm::cast<MemRefType>(getResult().getType()); }
   }];
 
   let assemblyFormat = [{
    $buffer attr-dict `:` type($result)
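A note on the recurring pattern in the hunks above (illustration only, not part of the patch): MLIR's member-style casts, `x.cast<T>()` and `x.isa<T>()`, are deprecated in favor of the free functions `llvm::cast<T>(x)` and `llvm::isa<T>(x)` from llvm/Support/Casting.h. The TableGen `CPred` strings change too, because they are pasted verbatim into generated C++. The sketch below shows the free-function API on an invented two-class hierarchy (`Type`/`ShapedType` here are stand-ins, not the real MLIR types); the `classof` hook is the actual mechanism the LLVM casting functions dispatch through.

// Illustrative sketch only: how llvm::isa/cast/dyn_cast resolve a dynamic
// type through a classof() hook. Assumes LLVM headers are available.
#include <cassert>

#include "llvm/Support/Casting.h"

struct Type {
  enum Kind { kShaped, kOpaque };
  explicit Type(Kind k) : kind(k) {}
  Kind kind;
};

struct ShapedType : Type {
  ShapedType() : Type(kShaped) {}
  int getRank() const { return 2; }
  // llvm::isa/cast/dyn_cast consult this hook to test the dynamic type.
  static bool classof(const Type* t) { return t->kind == kShaped; }
};

int main() {
  ShapedType shaped;
  Type* type = &shaped;
  // Old style removed by the patch: type->cast<ShapedType>(), type->isa<...>().
  // New style, as in the rewritten CPred strings:
  assert(llvm::isa<ShapedType>(type));             // checked type test
  ShapedType* st = llvm::cast<ShapedType>(type);   // asserts on mismatch
  assert(st->getRank() == 2);
  assert(llvm::dyn_cast<ShapedType>(type) != nullptr);  // nullptr on mismatch
  return 0;
}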
From c926ed633455917acccbce6c6770c1eb85556ff6 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Thu, 17 Apr 2025 10:01:34 -0700
Subject: [PATCH 0931/1324] [xla:cpu] Handle conditional operations in
 parallel task assignment

PiperOrigin-RevId: 748717934
---
 .../service/cpu/parallel_task_assignment.cc   | 21 +++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/xla/service/cpu/parallel_task_assignment.cc b/third_party/xla/xla/service/cpu/parallel_task_assignment.cc
index 3d9fd3169fd02d..da6a909b5d862c 100644
--- a/third_party/xla/xla/service/cpu/parallel_task_assignment.cc
+++ b/third_party/xla/xla/service/cpu/parallel_task_assignment.cc
@@ -230,18 +230,35 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper(
   std::vector<HloInstruction*> instructions(computation->instructions().begin(),
                                             computation->instructions().end());
   for (auto* instruction : instructions) {
-    // Assign parallel tasks to sub-computations for While and Call HLOs.
+    // Assign parallel tasks to sub-computations for While, Conditional and Call
+    // HLOs.
     // TODO(b/27458679) Evaluate alternative intra-op parallelism placement,
     // and support other callable computations like reduce.
+    bool control_flow_hlo = false;
     if (instruction->opcode() == HloOpcode::kWhile) {
+      control_flow_hlo = true;
       changed |= AssignParallelTasksHelper(module, instruction->while_body(),
                                            hlo_to_parallel_tasks);
-      continue;
+
+    } else if (instruction->opcode() == HloOpcode::kConditional) {
+      control_flow_hlo = true;
+      for (HloComputation* branch : instruction->branch_computations()) {
+        changed |=
+            AssignParallelTasksHelper(module, branch, hlo_to_parallel_tasks);
+      }
     } else if (instruction->opcode() == HloOpcode::kCall) {
+      control_flow_hlo = true;
       changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
                                            hlo_to_parallel_tasks);
+    }
+
+    // Continue to the next instruction if we handled control flow above.
+    if (control_flow_hlo) {
       continue;
     }
+
     // Skip if no parallel tasks were computed in first pass.
     auto it = hlo_to_parallel_tasks.find(instruction);
     if (it == hlo_to_parallel_tasks.end()) {

From 211695424f7497f7f4420ea9c647922e67c09d73 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 17 Apr 2025 10:11:31 -0700
Subject: [PATCH 0932/1324] Update minimum CMake version to 3.5

CMake 4.0 removed support for CMake older than 3.5.

PiperOrigin-RevId: 748721988
---
 tensorflow/tools/pip_package/xla_build/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/tools/pip_package/xla_build/CMakeLists.txt b/tensorflow/tools/pip_package/xla_build/CMakeLists.txt
index 1690338a625754..b45970adfe188f 100644
--- a/tensorflow/tools/pip_package/xla_build/CMakeLists.txt
+++ b/tensorflow/tools/pip_package/xla_build/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.4.3)
+cmake_minimum_required(VERSION 3.5)
 
 file(GLOB_RECURSE TF_RUNTIME_SRC "*.cc")
 add_library(tf_xla_runtime_objects OBJECT

From e5cd0f8bc7a7de595c26418abc5e5363be740170 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 17 Apr 2025 10:48:19 -0700
Subject: [PATCH 0933/1324] [HLO Diff] For each instruction cache its used
 HloValues.

This change significantly improves the runtime performance for larger graphs.
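The commit below applies a standard memoization: `GetAllValuesUsedByInstruction` is re-invoked many times for the same instruction while scoring candidate matches, so its result is cached in an `absl::flat_hash_map` owned by the caller. A minimal, self-contained sketch of the same compute-on-miss pattern (illustration only, not from the commit; `std::string` keys and the function names here are invented stand-ins for `HloInstruction*` and the real analysis):

// Minimal sketch of the memoization pattern (names invented; assumes Abseil).
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"

// Stand-in for the expensive per-instruction analysis.
std::vector<int> ComputeUsedValues(const std::string& instruction) {
  return std::vector<int>(instruction.size(), 0);  // placeholder work
}

// Compute on cache miss, then reuse. The cache is passed in by reference so
// a single map outlives many queries, exactly as in the matcher below.
const std::vector<int>& GetUsedValuesCached(
    const std::string& instruction,
    absl::flat_hash_map<std::string, std::vector<int>>& cache) {
  if (!cache.contains(instruction)) {
    cache.emplace(instruction, ComputeUsedValues(instruction));
  }
  return cache.at(instruction);
}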
PiperOrigin-RevId: 748735467
---
 .../hlo_diff/matchers/hlo_gumgraph_matcher.cc | 36 +++++++++++++------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc
index d3cc4a6ca0254c..1156a7bfe59180 100644
--- a/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/matchers/hlo_gumgraph_matcher.cc
@@ -290,18 +290,29 @@ std::vector<const HloValue*> GetAllValuesUsedByInstruction(
 
 // Returns true if all HloValues used by the left and right nodes have their
 // defining instructions matched.
-double AllOperandHloValuesMatchedScore(const HloInstructionNode* left_node,
-                                       const HloInstructionNode* right_node,
-                                       const HloGumgraph& left,
-                                       const HloGumgraph& right,
-                                       HloGumgraphMappings& mappings) {
-  std::vector<const HloValue*> left_hlo_values =
-      GetAllValuesUsedByInstruction(left_node->instruction, left);
-  std::vector<const HloValue*> right_hlo_values =
-      GetAllValuesUsedByInstruction(right_node->instruction, right);
+double AllOperandHloValuesMatchedScore(
+    const HloInstructionNode* left_node, const HloInstructionNode* right_node,
+    const HloGumgraph& left, const HloGumgraph& right,
+    absl::flat_hash_map<const HloInstruction*, std::vector<const HloValue*>>&
+        instruction_used_values_cache,
+    HloGumgraphMappings& mappings) {
+  if (!instruction_used_values_cache.contains(left_node->instruction)) {
+    instruction_used_values_cache.emplace(
+        left_node->instruction,
+        GetAllValuesUsedByInstruction(left_node->instruction, left));
+  }
+  if (!instruction_used_values_cache.contains(right_node->instruction)) {
+    instruction_used_values_cache.emplace(
+        right_node->instruction,
+        GetAllValuesUsedByInstruction(right_node->instruction, right));
+  }
+  auto& left_hlo_values = instruction_used_values_cache[left_node->instruction];
+  auto& right_hlo_values =
+      instruction_used_values_cache[right_node->instruction];
 
   if (left_hlo_values.empty() || right_hlo_values.empty() ||
-      left_hlo_values.size() != right_hlo_values.size()) {
+      (left_hlo_values.size() != right_hlo_values.size())) {
     return 0.0;
   }
@@ -431,6 +442,8 @@ void GreedyLimitedCandidatesBottomUpMatcher::Match(
     HloGumgraphMappings& mappings) const {
   LOG(INFO) << "Running GreedyLimitedCandidatesBottomUpMatcher: matching "
                "subgraphs that match based on Dice similarity";
+  absl::flat_hash_map<const HloInstruction*, std::vector<const HloValue*>>
+      instruction_used_values_cache;
   int current_mapping_count = mappings.left_to_right_instruction_map.size();
   std::vector<const HloInstructionNode*> left_postorder = GetAllNodesInDfsOrder(
       left_.GetRoot(), DfsTraversalOrder::kPostOrder, left_.GetNodeCount());
@@ -479,7 +492,8 @@ void GreedyLimitedCandidatesBottomUpMatcher::Match(
         node.instruction->opcode() == left_node->instruction->opcode()) {
       // Found candidate. Calculate similarity.
       double operands_match_similarity = AllOperandHloValuesMatchedScore(
-          left_node, &node, left_, right_, mappings);
+          left_node, &node, left_, right_, instruction_used_values_cache,
+          mappings);
       double dice_sim = DiceSimLimitedSubgraph(
           left_node, &node, mappings, max_dice_subgraph_size_,
           left_.GetNodeCount(), right_.GetNodeCount());

From 15b97889af9a5d6a2f44bb700c114b5ace479279 Mon Sep 17 00:00:00 2001
From: Niklas Vangerow
Date: Thu, 17 Apr 2025 10:56:23 -0700
Subject: [PATCH 0934/1324] Bulk migration of HloTestBase tests to
 HloPjRtTestBase, where it was possible.

PiperOrigin-RevId: 748738483
---
 .../xla/xla/backends/gpu/runtime/BUILD        |  22 +-
 .../gpu/runtime/command_buffer_thunk_test.cc  |  11 +-
 .../backends/gpu/runtime/while_thunk_test.cc  |  20 +-
 third_party/xla/xla/codegen/emitters/ir/BUILD |   3 +-
 .../xla/codegen/emitters/ir/xla_ops_test.cc   |   4 +-
 third_party/xla/xla/hlo/builder/lib/BUILD     |  48 +-
 .../xla/hlo/builder/lib/arithmetic_test.cc    |   6 +-
 .../xla/hlo/builder/lib/comparators_test.cc   |   6 +-
 .../xla/xla/hlo/builder/lib/math_test.cc      |   6 +-
 .../xla/xla/hlo/builder/lib/matrix_test.cc    |   6 +-
 .../xla/xla/hlo/builder/lib/pooling_test.cc   |   6 +-
 .../xla/xla/hlo/builder/lib/prng_test.cc      |   6 +-
 .../hlo/builder/lib/self_adjoint_eig_test.cc  |  13 +-
 .../xla/xla/hlo/builder/lib/svd_test.cc       |   9 +-
 .../xla/hlo/builder/lib/tridiagonal_test.cc   |   6 +-
 third_party/xla/xla/service/BUILD             |  32 +-
 .../algebraic_simplifier_overflow_test.cc     |   7 +-
 .../xla/service/batchnorm_expander_test.cc    |  10 +-
 .../xla/service/dynamic_update_slice_test.cc  |  23 +-
 third_party/xla/xla/service/gpu/tests/BUILD   |  35 +-
 .../gpu/tests/async_kernel_launch_test.cc     |   6 +-
 .../service/gpu/tests/command_buffer_test.cc  |   6 +-
 .../gpu/tests/dynamic_slice_fusion_test.cc    |  10 +-
 .../element_wise_row_vectorization_test.cc    |   7 +-
 .../xla/service/gpu/tests/in_place_op_test.cc |   5 +-
 .../service/gpu/tests/nop_custom_call_test.cc |   4 +-
 .../tests/tensor_float_32_global_var_test.cc  |  23 +-
 .../xla/xla/service/gpu/transforms/BUILD      |  17 +-
 .../block_scaling_rewriter_cudnn_test.cc      |  11 +-
 .../gpu/transforms/sort_rewriter_test.cc      |   7 +-
 .../triton_fusion_numerics_verifier_test.cc   |   8 +-
 .../scatter_determinism_expander_test.cc      |   6 +-
 third_party/xla/xla/tests/BUILD               | 423 ++++++++++++------
 third_party/xla/xla/tests/axpy_simple_test.cc |   7 +-
 .../xla/xla/tests/batch_norm_grad_test.cc     |   6 +-
 .../xla/xla/tests/batch_norm_training_test.cc |   6 +-
 third_party/xla/xla/tests/bfloat16_test.cc    |   6 +-
 .../xla/xla/tests/binop_scaling_test.cc       |   6 +-
 .../xla/xla/tests/bitcast_convert_test.cc     |  11 +-
 third_party/xla/xla/tests/call_test.cc        |   8 +-
 .../xla/tests/check_execution_arity_test.cc   |   6 +-
 .../collective_pipeliner_execution_test.cc    |   7 +-
 .../xla/xla/tests/complex_unary_op_test.cc    |   8 +-
 third_party/xla/xla/tests/concatenate_test.cc |  10 +-
 .../tests/constant_reduction_function_test.cc |  10 +-
 third_party/xla/xla/tests/constants_test.cc   |   8 +-
 .../xla/xla/tests/conv_depthwise_common.cc    |  15 +-
 .../xla/xla/tests/conv_depthwise_common.h     |  13 +-
 .../convolution_dimension_numbers_test.cc     |   7 +-
 .../xla/xla/tests/convolution_test_1d.cc      |   6 +-
 .../xla/tests/convolution_variants_test.cc    |   7 +-
 third_party/xla/xla/tests/dynamic_ops_test.cc |  31 +-
 third_party/xla/xla/tests/float8_test.cc      |   6 +-
 third_party/xla/xla/tests/floor_ceil_test.cc  |   7 +-
 third_party/xla/xla/tests/fmax_fmin_test.cc   |   7 +-
 .../xla/xla/tests/get_dimension_size_test.cc  |   8 +-
 third_party/xla/xla/tests/half_test.cc        |   6 +-
 third_party/xla/xla/tests/iota_test.cc        |  14 +-
 third_party/xla/xla/tests/log_test.cc         |   7 +-
 .../xla/tests/multidimensional_slice_test.cc  |   7 +-
 .../xla/tests/nccl_group_execution_test.cc    |   6 +-
 third_party/xla/xla/tests/numerics_test.cc    |  18 +-
 third_party/xla/xla/tests/pred_test.cc        |   6 +-
 .../xla/xla/tests/ptxas_bug_120501638.cc      |  12 +-
 .../xla/xla/tests/reduce_precision_test.cc    |   6 +-
 third_party/xla/xla/tests/reverse_test.cc     |   9 +-
 third_party/xla/xla/tests/rng_test.cc         |  10 +-
 .../xla/xla/tests/runtime_topk_test.cc        |   6 +-
 third_party/xla/xla/tests/sample_text_test.cc |  11 +-
 .../xla/xla/tests/scalar_computations_test.cc |  10 +-
PiperOrigin-RevId: 748738483 --- .../xla/xla/backends/gpu/runtime/BUILD | 22 +- .../gpu/runtime/command_buffer_thunk_test.cc | 11 +- .../backends/gpu/runtime/while_thunk_test.cc | 20 +- third_party/xla/xla/codegen/emitters/ir/BUILD | 3 +- .../xla/codegen/emitters/ir/xla_ops_test.cc | 4 +- third_party/xla/xla/hlo/builder/lib/BUILD | 48 +- .../xla/hlo/builder/lib/arithmetic_test.cc | 6 +- .../xla/hlo/builder/lib/comparators_test.cc | 6 +- .../xla/xla/hlo/builder/lib/math_test.cc | 6 +- .../xla/xla/hlo/builder/lib/matrix_test.cc | 6 +- .../xla/xla/hlo/builder/lib/pooling_test.cc | 6 +- .../xla/xla/hlo/builder/lib/prng_test.cc | 6 +- .../hlo/builder/lib/self_adjoint_eig_test.cc | 13 +- .../xla/xla/hlo/builder/lib/svd_test.cc | 9 +- .../xla/hlo/builder/lib/tridiagonal_test.cc | 6 +- third_party/xla/xla/service/BUILD | 32 +- .../algebraic_simplifier_overflow_test.cc | 7 +- .../xla/service/batchnorm_expander_test.cc | 10 +- .../xla/service/dynamic_update_slice_test.cc | 23 +- third_party/xla/xla/service/gpu/tests/BUILD | 35 +- .../gpu/tests/async_kernel_launch_test.cc | 6 +- .../service/gpu/tests/command_buffer_test.cc | 6 +- .../gpu/tests/dynamic_slice_fusion_test.cc | 10 +- .../element_wise_row_vectorization_test.cc | 7 +- .../xla/service/gpu/tests/in_place_op_test.cc | 5 +- .../service/gpu/tests/nop_custom_call_test.cc | 4 +- .../tests/tensor_float_32_global_var_test.cc | 23 +- .../xla/xla/service/gpu/transforms/BUILD | 17 +- .../block_scaling_rewriter_cudnn_test.cc | 11 +- .../gpu/transforms/sort_rewriter_test.cc | 7 +- .../triton_fusion_numerics_verifier_test.cc | 8 +- .../scatter_determinism_expander_test.cc | 6 +- third_party/xla/xla/tests/BUILD | 423 ++++++++++++------ third_party/xla/xla/tests/axpy_simple_test.cc | 7 +- .../xla/xla/tests/batch_norm_grad_test.cc | 6 +- .../xla/xla/tests/batch_norm_training_test.cc | 6 +- third_party/xla/xla/tests/bfloat16_test.cc | 6 +- .../xla/xla/tests/binop_scaling_test.cc | 6 +- .../xla/xla/tests/bitcast_convert_test.cc | 11 +- third_party/xla/xla/tests/call_test.cc | 8 +- .../xla/tests/check_execution_arity_test.cc | 6 +- .../collective_pipeliner_execution_test.cc | 7 +- .../xla/xla/tests/complex_unary_op_test.cc | 8 +- third_party/xla/xla/tests/concatenate_test.cc | 10 +- .../tests/constant_reduction_function_test.cc | 10 +- third_party/xla/xla/tests/constants_test.cc | 8 +- .../xla/xla/tests/conv_depthwise_common.cc | 15 +- .../xla/xla/tests/conv_depthwise_common.h | 13 +- .../convolution_dimension_numbers_test.cc | 7 +- .../xla/xla/tests/convolution_test_1d.cc | 6 +- .../xla/tests/convolution_variants_test.cc | 7 +- third_party/xla/xla/tests/dynamic_ops_test.cc | 31 +- third_party/xla/xla/tests/float8_test.cc | 6 +- third_party/xla/xla/tests/floor_ceil_test.cc | 7 +- third_party/xla/xla/tests/fmax_fmin_test.cc | 7 +- .../xla/xla/tests/get_dimension_size_test.cc | 8 +- third_party/xla/xla/tests/half_test.cc | 6 +- third_party/xla/xla/tests/iota_test.cc | 14 +- third_party/xla/xla/tests/log_test.cc | 7 +- .../xla/tests/multidimensional_slice_test.cc | 7 +- .../xla/tests/nccl_group_execution_test.cc | 6 +- third_party/xla/xla/tests/numerics_test.cc | 18 +- third_party/xla/xla/tests/pred_test.cc | 6 +- .../xla/xla/tests/ptxas_bug_120501638.cc | 12 +- .../xla/xla/tests/reduce_precision_test.cc | 6 +- third_party/xla/xla/tests/reverse_test.cc | 9 +- third_party/xla/xla/tests/rng_test.cc | 10 +- .../xla/xla/tests/runtime_topk_test.cc | 6 +- third_party/xla/xla/tests/sample_text_test.cc | 11 +- .../xla/xla/tests/scalar_computations_test.cc | 10 +- 
.../xla/xla/tests/select_and_scatter_test.cc | 6 +- third_party/xla/xla/tests/select_test.cc | 6 +- .../xla/xla/tests/set_dimension_size_test.cc | 6 +- third_party/xla/xla/tests/sort_test.cc | 16 +- .../xla/xla/tests/stochastic_convert_test.cc | 10 +- third_party/xla/xla/tests/topk_test.cc | 13 +- third_party/xla/xla/tests/transpose_test.cc | 8 +- third_party/xla/xla/tests/unary_op_test.cc | 6 +- .../xla/xla/tests/vector_ops_reduce_test.cc | 7 +- .../xla/xla/tests/vector_ops_simple_test.cc | 7 +- third_party/xla/xla/tools/BUILD | 5 +- .../xla/xla/tools/hlo_decomposer_test.cc | 6 +- 82 files changed, 797 insertions(+), 445 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 6aebea75c6f506..e88cb686ba6a59 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -360,16 +360,20 @@ xla_test( "gpu_b200", "gpu_amd_any", ], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":command_buffer_cmd", ":command_buffer_thunk", ":dynamic_slice_thunk", + ":gpublas_lt_matmul_thunk", ":memset_thunk", ":sequential_thunk", ":thunk", + "//xla:error_spec", "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:executable", @@ -396,6 +400,8 @@ xla_test( "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:statusor", @@ -1331,22 +1337,30 @@ xla_test( name = "while_thunk_test", srcs = ["while_thunk_test.cc"], backends = ["gpu"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ + ":sequential_thunk", ":thunk", ":while_thunk", "//xla:executable_run_options", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service:platform_util", + "//xla/service/gpu:buffer_allocations", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_test_base", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index e5bf0e10571ec0..bbcb3b5f08565c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -29,12 +29,16 @@ limitations under the License. 
#include #include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/types/span.h" #include "xla/backends/gpu/runtime/command_buffer_cmd.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" +#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/backends/gpu/runtime/memset_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" @@ -63,7 +67,8 @@ limitations under the License. #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/types.h" // IWYU pragma: keep @@ -1420,10 +1425,10 @@ TEST(CommandBufferThunkTest, WhileCmd) { ASSERT_EQ(dst, std::vector(4, 15)); } -class CmdBufferTest : public HloTestBase { +class CmdBufferTest : public HloPjRtInterpreterReferenceMixin { public: DebugOptions GetDebugOptionsForTest() const override { - DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + DebugOptions debug_options = HloPjRtTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_autotune_level(0); debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); debug_options.set_xla_gpu_graph_min_graph_size(1); diff --git a/third_party/xla/xla/backends/gpu/runtime/while_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/while_thunk_test.cc index e74deefb1910e9..65f8b840dcea18 100644 --- a/third_party/xla/xla/backends/gpu/runtime/while_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/while_thunk_test.cc @@ -15,23 +15,33 @@ limitations under the License. #include "xla/backends/gpu/runtime/while_thunk.h" +#include #include #include +#include +#include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/executable_run_options.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla::gpu { namespace { @@ -65,7 +75,7 @@ class IterationLoggerThunk : public Thunk { // Non-known trip count while thunks are difficult to unit test, so we only have // a unit test for the known trip count case. 
-class KnownTripCountWhileThunkTest : public HloTestBase { +class KnownTripCountWhileThunkTest : public HloPjRtTestBase { protected: absl::StatusOr CreateFakeWhileInstruction() { constexpr absl::string_view kDummyModule = R"( diff --git a/third_party/xla/xla/codegen/emitters/ir/BUILD b/third_party/xla/xla/codegen/emitters/ir/BUILD index 11da591197ee32..b396ff3fc6ef19 100644 --- a/third_party/xla/xla/codegen/emitters/ir/BUILD +++ b/third_party/xla/xla/codegen/emitters/ir/BUILD @@ -111,12 +111,13 @@ xla_test( name = "xla_ops_test", srcs = ["xla_ops_test.cc"], backends = ["cpu"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/testlib:filecheck", "//xla/mlir/utils:error_util", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", diff --git a/third_party/xla/xla/codegen/emitters/ir/xla_ops_test.cc b/third_party/xla/xla/codegen/emitters/ir/xla_ops_test.cc index eaa43a82e6d04d..d6c565694107de 100644 --- a/third_party/xla/xla/codegen/emitters/ir/xla_ops_test.cc +++ b/third_party/xla/xla/codegen/emitters/ir/xla_ops_test.cc @@ -41,7 +41,7 @@ limitations under the License. #include "xla/hlo/analysis/indexing_map_serialization.h" #include "xla/hlo/testlib/filecheck.h" #include "xla/mlir/utils/error_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" @@ -67,7 +67,7 @@ absl::StatusOr> ParseMlirModuleString( return std::move(module); } -class XLAOpsTest : public HloTestBase { +class XLAOpsTest : public HloPjRtTestBase { public: mlir::MLIRContext mlir_context_; }; diff --git a/third_party/xla/xla/hlo/builder/lib/BUILD b/third_party/xla/xla/hlo/builder/lib/BUILD index 3c6faf303adec8..219a45484be324 100644 --- a/third_party/xla/xla/hlo/builder/lib/BUILD +++ b/third_party/xla/xla/hlo/builder/lib/BUILD @@ -42,13 +42,15 @@ cc_library( xla_test( name = "arithmetic_test", srcs = ["arithmetic_test.cc"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":arithmetic", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/tests:client_library_test_runner_mixin", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tsl/platform:test", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", @@ -76,6 +78,7 @@ cc_library( xla_test( name = "comparators_test", srcs = ["comparators_test.cc"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":comparators", ":constants", @@ -87,7 +90,8 @@ xla_test( "//xla/hlo/testlib:test", "//xla/service:hlo_proto_cc", "//xla/tests:client_library_test_runner_mixin", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tsl/platform:test", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings:string_view", @@ -234,6 +238,7 @@ xla_test( name = "math_test", timeout = "long", srcs = ["math_test.cc"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":constants", ":math", @@ -248,7 +253,8 @@ xla_test( "//xla/hlo/testlib:test", "//xla/service", "//xla/tests:client_library_test_runner_mixin", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:test_macros_header", 
"//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:test", @@ -286,6 +292,7 @@ cc_library( xla_test( name = "matrix_test", srcs = ["matrix_test.cc"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":constants", ":matrix", @@ -300,7 +307,8 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", "//xla/tests:client_library_test_runner_mixin", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tsl/platform:test", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -334,6 +342,7 @@ cc_library( xla_test( name = "pooling_test", srcs = ["pooling_test.cc"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":pooling", "//xla:error_spec", @@ -341,7 +350,8 @@ xla_test( "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", "//xla/tests:client_library_test_runner_mixin", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tsl/platform:test", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", @@ -370,6 +380,7 @@ cc_library( xla_test( name = "prng_test", srcs = ["prng_test.cc"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":constants", ":prng", @@ -378,7 +389,8 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", "//xla/tests:client_library_test_runner_mixin", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tsl/platform:test", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -601,7 +613,10 @@ xla_test( srcs = ["self_adjoint_eig_test.cc"], real_hardware_only = True, shard_count = 5, - tags = ["optonly"], + tags = [ + "optonly", + "test_migrated_to_hlo_runner_pjrt", + ], deps = [ ":arithmetic", ":constants", @@ -621,7 +636,8 @@ xla_test( "//xla/hlo/testlib:test", "//xla/tests:client_library_test_base", "//xla/tests:client_library_test_runner_mixin", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:test_macros_header", "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", @@ -660,7 +676,10 @@ xla_test( srcs = ["svd_test.cc"], real_hardware_only = True, shard_count = 10, - tags = ["optonly"], + tags = [ + "optonly", + "test_migrated_to_hlo_runner_pjrt", + ], deps = [ ":arithmetic", ":constants", @@ -674,7 +693,8 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/tests:client_library_test_runner_mixin", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:test_macros_header", "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", @@ -708,7 +728,10 @@ xla_test( srcs = ["tridiagonal_test.cc"], real_hardware_only = True, shard_count = 10, - tags = ["optonly"], + tags = [ + "optonly", + "test_migrated_to_hlo_runner_pjrt", + ], deps = [ ":slicing", ":tridiagonal", @@ -719,7 +742,8 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/hlo/testlib:test", "//xla/tests:client_library_test_runner_mixin", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:test_macros_header", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", diff --git a/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc 
b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc index 20689b19cf1c48..575bb05a8f816d 100644 --- a/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc @@ -23,14 +23,16 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" namespace xla { namespace { -class ArithmeticTest : public ClientLibraryTestRunnerMixin { +class ArithmeticTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { public: template void TestArgMin(std::initializer_list> input, diff --git a/third_party/xla/xla/hlo/builder/lib/comparators_test.cc b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc index 51eb6f10c38fb7..974ae4899046b9 100644 --- a/third_party/xla/xla/hlo/builder/lib/comparators_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc @@ -30,7 +30,8 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/hlo.pb.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" @@ -38,7 +39,8 @@ limitations under the License. namespace xla { namespace { -class ComparatorsTest : public ClientLibraryTestRunnerMixin { +class ComparatorsTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { public: ComparatorsTest() : builder_(TestName()) {} XlaBuilder* builder() { return &builder_; } diff --git a/third_party/xla/xla/hlo/builder/lib/math_test.cc b/third_party/xla/xla/hlo/builder/lib/math_test.cc index 7dd27f0911efa9..243f19ff44c78f 100644 --- a/third_party/xla/xla/hlo/builder/lib/math_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/math_test.cc @@ -36,7 +36,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/test.h" @@ -48,7 +49,8 @@ namespace { constexpr ErrorSpec kErrorSpec{0.0001}; -using MathTest = ClientLibraryTestRunnerMixin; +using MathTest = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; // Write TYPED_TESTs within the class definition so that we don't have to litter // "this->" everywhere. diff --git a/third_party/xla/xla/hlo/builder/lib/matrix_test.cc b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc index 3a21e6240d8959..e35f52182aef15 100644 --- a/third_party/xla/xla/hlo/builder/lib/matrix_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc @@ -37,14 +37,16 @@ limitations under the License. 
#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" #include "xla/types.h" namespace xla { namespace { -class MatrixTest : public ClientLibraryTestRunnerMixin { +class MatrixTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: template void TestMatrixDiagonal(); diff --git a/third_party/xla/xla/hlo/builder/lib/pooling_test.cc b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc index f068938c9a20f0..1a32613ae864a5 100644 --- a/third_party/xla/xla/hlo/builder/lib/pooling_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc @@ -26,7 +26,8 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" namespace xla { @@ -69,7 +70,8 @@ std::vector ExpandWithBatchAndFeatureDimensions( return tensor_sizes; } -using PoolingTest = ClientLibraryTestRunnerMixin; +using PoolingTest = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; TEST_F(PoolingTest, MaxPool2D) { XlaBuilder builder(TestName()); diff --git a/third_party/xla/xla/hlo/builder/lib/prng_test.cc b/third_party/xla/xla/hlo/builder/lib/prng_test.cc index 6117d49524aeb4..c7e4e8d4773ae4 100644 --- a/third_party/xla/xla/hlo/builder/lib/prng_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng_test.cc @@ -26,14 +26,16 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" namespace xla { namespace { -class PrngTest : public ClientLibraryTestRunnerMixin { +class PrngTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { public: template { +class SelfAdjointEigTest + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: void SetUp() override { - ClientLibraryTestRunnerMixin::SetUp(); + ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>::SetUp(); batch_3d_4x4_ = Array3D{ { {4, 6, 8, 10}, @@ -280,7 +284,8 @@ Array2D GenerateRandomSymmetricMatrix(int size) { } using EighTestCase = int64_t; -class RandomEighTest : public ClientLibraryTestRunnerMixin, +class RandomEighTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>, public ::testing::WithParamInterface {}; XLA_TEST_P(RandomEighTest, Random) { diff --git a/third_party/xla/xla/hlo/builder/lib/svd_test.cc b/third_party/xla/xla/hlo/builder/lib/svd_test.cc index 28b7c0172e8184..e7ad1dcaec1b4a 100644 --- a/third_party/xla/xla/hlo/builder/lib/svd_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd_test.cc @@ -31,17 +31,20 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" namespace xla { -class SVDTest : public ClientLibraryTestRunnerMixin { +class SVDTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: void SetUp() override { - ClientLibraryTestRunnerMixin::SetUp(); + ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>::SetUp(); batch_3d_4x5_ = Array3D{ { {4, 6, 8, 10, 1}, diff --git a/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc index 76795ba29cb1bf..c6bf2c86bd7b4a 100644 --- a/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc @@ -28,7 +28,8 @@ limitations under the License. #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" @@ -39,7 +40,8 @@ namespace tridiagonal { namespace { class TridiagonalTest - : public ClientLibraryTestRunnerMixin, + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>, public ::testing::WithParamInterface> {}; XLA_TEST_P(TridiagonalTest, SimpleTridiagonalMatMulOk) { diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 18e05363012f9a..ba22d248a001cb 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -737,15 +737,22 @@ xla_test( "cpu", "gpu", ], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ - "//xla:execution_options_util", - "//xla:status_macros", - "//xla/hlo/parser:hlo_parser", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", "//xla/hlo/testlib:test", "//xla/tests:client_library_test_base", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:test_macros_header", + "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", ], ) @@ -2177,13 +2184,14 @@ xla_test( "cpu", "gpu", ], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":scatter_determinism_expander", "//xla:literal", "//xla/hlo/testlib:test", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", + "//xla/tsl/platform:statusor", ], ) @@ -2241,28 +2249,30 @@ xla_test( "cpu", "gpu", ], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":batchnorm_expander", "//xla:error_spec", "//xla:shape_util", - "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:test", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", + "//xla/tsl/platform:statusor", ], ) xla_test( name = 
"algebraic_simplifier_overflow_test", srcs = ["algebraic_simplifier_overflow_test.cc"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ "//xla:error_spec", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", ], diff --git a/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc b/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc index 071f9994b54a08..a6ddb38e41870e 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc @@ -15,16 +15,17 @@ limitations under the License. #include #include -#include #include #include "xla/error_spec.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" namespace xla { namespace { -class AlgebraicSimplifierOverflowTest : public HloTestBase {}; +class AlgebraicSimplifierOverflowTest + : public HloPjRtInterpreterReferenceMixin {}; // Test that the algebraic simplifier does not generate integer overflows // by moving the subtraction to the other side of the comparison diff --git a/third_party/xla/xla/service/batchnorm_expander_test.cc b/third_party/xla/xla/service/batchnorm_expander_test.cc index 658426f867873b..85141237f6164d 100644 --- a/third_party/xla/xla/service/batchnorm_expander_test.cc +++ b/third_party/xla/xla/service/batchnorm_expander_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/service/batchnorm_expander.h" +#include #include -#include #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" @@ -26,14 +26,16 @@ limitations under the License. #include "xla/hlo/testlib/test.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { -class BatchNormExpanderTest : public HloTestBase { +class BatchNormExpanderTest + : public HloPjRtInterpreterReferenceMixin { protected: // BatchNorm should have a dynamic sized divider for mean operations. int64_t CountGetDimensionSize(const HloModule& module) { diff --git a/third_party/xla/xla/service/dynamic_update_slice_test.cc b/third_party/xla/xla/service/dynamic_update_slice_test.cc index 154657307b09a9..fd1ee8815be79f 100644 --- a/third_party/xla/xla/service/dynamic_update_slice_test.cc +++ b/third_party/xla/xla/service/dynamic_update_slice_test.cc @@ -13,18 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "xla/execution_options_util.h" -#include "xla/hlo/parser/hlo_parser.h" +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/types/span.h" +#include "xla/error_spec.h" #include "xla/hlo/testlib/test.h" -#include "xla/status_macros.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tests/test_utils.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { -class DynamicUpdateSliceTest : public HloTestBase {}; +using DynamicUpdateSliceTest = + HloPjRtInterpreterReferenceMixin; XLA_TEST_F(DynamicUpdateSliceTest, ShardedInPlaceDUS) { // A dynamic-update-slice within a while loop. This construction is an easy diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index bfd970cfcd3e14..81b0d586ae9dcd 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -74,7 +74,10 @@ xla_test( srcs = if_gpu_is_configured(["dynamic_slice_fusion_test.cc"]), backends = ["gpu"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - tags = ["notsan"], # TODO(b/345034145): Fix tsan error. + tags = [ + "notsan", + "test_migrated_to_hlo_runner_pjrt", + ], # TODO(b/345034145): Fix tsan error. deps = if_gpu_is_configured( #keep sorted [ @@ -90,7 +93,9 @@ xla_test( ) + [ "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:test", ], ) @@ -98,10 +103,13 @@ xla_test( name = "element_wise_row_vectorization_test", srcs = ["element_wise_row_vectorization_test.cc"], backends = ["gpu"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ "//xla:error_spec", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_interpreter_reference_mixin", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:test", ], ) @@ -122,14 +130,16 @@ xla_test( srcs = ["async_kernel_launch_test.cc"], backends = ["gpu"], # "requires-net:external" tag allows uploading `xprof` results. 
-    tags = if_google(["requires-net:external"]),
+    tags = if_google(["requires-net:external"]) + ["test_migrated_to_hlo_runner_pjrt"],
     deps = [
         "//xla:debug_options_flags",
+        "//xla:error_spec",
         "//xla:literal",
         "//xla:literal_util",
         "//xla:xla_proto_cc",
         "//xla/service:hlo_module_config",
-        "//xla/tests:hlo_test_base",
+        "//xla/tests:hlo_pjrt_interpreter_reference_mixin",
+        "//xla/tests:hlo_pjrt_test_base",
         "//xla/tests:literal_test_util",
         "@com_google_googletest//:gtest_main",
     ],
@@ -139,12 +149,11 @@ xla_test(
     name = "command_buffer_test",
     srcs = ["command_buffer_test.cc"],
     backends = ["gpu"],
+    tags = ["test_migrated_to_hlo_runner_pjrt"],
     deps = [
-        "//xla:debug_options_flags",
         "//xla:literal",
         "//xla:literal_util",
-        "//xla/service:hlo_module_config",
-        "//xla/tests:hlo_test_base",
+        "//xla/tests:hlo_pjrt_test_base",
         "//xla/tests:literal_test_util",
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/strings:string_view",
@@ -676,9 +685,11 @@ xla_test(
     name = "in_place_op_test",
     srcs = ["in_place_op_test.cc"],
     backends = ["gpu"],
+    tags = ["test_migrated_to_hlo_runner_pjrt"],
     deps = [
         "//xla:debug_options_flags",
-        "//xla/tests:hlo_test_base",
+        "//xla/tests:hlo_pjrt_interpreter_reference_mixin",
+        "//xla/tests:hlo_pjrt_test_base",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -722,9 +733,12 @@ xla_test(
     ] + if_oss([
         "gpu_any",
     ]),
+    tags = ["test_migrated_to_hlo_runner_pjrt"],
     deps = [
         "//xla:error_spec",
-        "//xla/tests:hlo_test_base",
+        "//xla/tests:hlo_pjrt_interpreter_reference_mixin",
+        "//xla/tests:hlo_pjrt_test_base",
+        "//xla/tsl/platform:test",
         "@com_google_googletest//:gtest_main",
         "@local_tsl//tsl/platform:tensor_float_32_utils",
     ],
@@ -894,10 +908,11 @@ xla_test(
     name = "nop_custom_call_test",
     srcs = ["nop_custom_call_test.cc"],
     backends = ["gpu"],
+    tags = ["test_migrated_to_hlo_runner_pjrt"],
     deps = [
         "//xla:literal",
         "//xla:literal_util",
-        "//xla/tests:hlo_test_base",
+        "//xla/tests:hlo_pjrt_test_base",
         "//xla/tests:literal_test_util",
         "//xla/tsl/platform:test",
         "@com_google_googletest//:gtest_main",
diff --git a/third_party/xla/xla/service/gpu/tests/async_kernel_launch_test.cc b/third_party/xla/xla/service/gpu/tests/async_kernel_launch_test.cc
index 4acb8ec495a783..f47e560b332f09 100644
--- a/third_party/xla/xla/service/gpu/tests/async_kernel_launch_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/async_kernel_launch_test.cc
@@ -17,17 +17,19 @@ limitations under the License.
 
 #include
 
 #include "xla/debug_options_flags.h"
+#include "xla/error_spec.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/service/hlo_module_config.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/xla.pb.h"
 
 namespace xla::gpu {
 namespace {
 
-class AsyncKernelLaunchTest : public HloTestBase {};
+using AsyncKernelLaunchTest = HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
 
 HloModuleConfig GetModuleConfig() {
   // Allow even small graphs to be launched on the GPU.
diff --git a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc
index 446904b0019deb..1818613d7830a8 100644
--- a/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/command_buffer_test.cc
@@ -20,17 +20,17 @@ limitations under the License.
 
 #include "absl/strings/string_view.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/tsl/platform/statusor.h"
 
 namespace xla::gpu {
 namespace {
 
-class CommandBufferTest : public HloTestBase,
+class CommandBufferTest : public HloPjRtTestBase,
                           public ::testing::WithParamInterface<bool> {
   DebugOptions GetDebugOptionsForTest() const override {
-    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    DebugOptions debug_options = HloPjRtTestBase::GetDebugOptionsForTest();
     debug_options.set_xla_gpu_graph_enable_concurrent_region(GetParam());
     return debug_options;
   }
diff --git a/third_party/xla/xla/service/gpu/tests/dynamic_slice_fusion_test.cc b/third_party/xla/xla/service/gpu/tests/dynamic_slice_fusion_test.cc
index e379e512a57e5f..6238a5cba57141 100644
--- a/third_party/xla/xla/service/gpu/tests/dynamic_slice_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/dynamic_slice_fusion_test.cc
@@ -13,20 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include
-#include
 #include
 
-#include "absl/algorithm/container.h"
 #include "absl/status/status.h"
 #include "xla/error_spec.h"
 #include "xla/ffi/ffi.h"
 #include "xla/ffi/ffi_api.h"
-#include "xla/primitive_util.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/stream.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
+#include "xla/tsl/platform/test.h"
 
 namespace xla {
 namespace gpu {
@@ -38,7 +34,7 @@ static constexpr char kPlatform[] = "CUDA";
 static constexpr char kPlatform[] = "ROCM";
 #endif
 
-class DynamicSliceFusionTest : public HloTestBase {};
+class DynamicSliceFusionTest : public HloPjRtTestBase {};
 
 TEST_F(DynamicSliceFusionTest, GemmSlice) {
   const char* hlo_reference = R"(
diff --git a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization_test.cc b/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization_test.cc
index ffd507f94959c3..6ec496842a0891 100644
--- a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization_test.cc
@@ -11,13 +11,16 @@ limitations under the License.
 ==============================================================================*/
 
 #include "xla/error_spec.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
+#include "xla/tsl/platform/test.h"
 
 namespace xla {
 namespace gpu {
 namespace {
 
-class ElementWiseRowVectorizationTest : public HloTestBase {};
+using ElementWiseRowVectorizationTest =
+    HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
 
 TEST_F(ElementWiseRowVectorizationTest, SimpleAddSmallRowBroadcastingTest) {
   const char* hlo_text = R"(
diff --git a/third_party/xla/xla/service/gpu/tests/in_place_op_test.cc b/third_party/xla/xla/service/gpu/tests/in_place_op_test.cc
index 17af5a49919eaa..00832954536d63 100644
--- a/third_party/xla/xla/service/gpu/tests/in_place_op_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/in_place_op_test.cc
@@ -16,13 +16,14 @@ limitations under the License.
 #include
 
 #include "xla/debug_options_flags.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 
 namespace xla {
 namespace gpu {
 namespace {
 
-class InPlaceOpTest : public HloTestBase {
+class InPlaceOpTest : public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {
   // Don't override any flags.
   DebugOptions GetDebugOptionsForTest() const override {
     return GetDebugOptionsFromFlags();
diff --git a/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc b/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc
index 06df6792eb3e9a..cf8e79aead3b87 100644
--- a/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
 
 #include "xla/literal.h"
 #include "xla/literal_util.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/tsl/platform/test.h"
 
@@ -26,7 +26,7 @@ namespace xla {
 namespace gpu {
 namespace {
 
-class NopCustomCallTest : public HloTestBase {};
+class NopCustomCallTest : public HloPjRtTestBase {};
 
 TEST_F(NopCustomCallTest, RunAllocateBufferAndUpdate) {
   // The test uses a custom call with the AllocateBuffer target (also known as
diff --git a/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc b/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc
index 39d20ea80c9f28..ca9a7ffce30174 100644
--- a/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc
@@ -16,13 +16,19 @@ limitations under the License.
 
 #include
 
 #include "xla/error_spec.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
+#include "xla/tsl/platform/test.h"
 #include "tsl/platform/tensor_float_32_utils.h"
 
 namespace xla {
 namespace gpu {
 namespace {
 
+// The error tolerances are small enough so that the use of TF32 will cause
+// the error to be greater than the tolerances.
+constexpr ErrorSpec kErrorSpec = ErrorSpec{1e-4, 1e-4};
+
 // Test that setting the TensorFloat-32 global variable to false causes
 // TensorFloat-32 not to be used, even when the operand precision is set to the
 // default.
@@ -30,15 +36,12 @@ namespace {
 // NOTE: Unfortunately TF2XLA doesn't set the precision config for all
 // operations based on tensor_float_32_execution_enabled(), so we can not ignore
 // the global variable.
-class TensorFloat32GlobalVarTest : public ::testing::WithParamInterface<bool>,
-                                   public HloTestBase {
+class TensorFloat32GlobalVarTest
+    : public ::testing::WithParamInterface<bool>,
+      public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {
 protected:
  TensorFloat32GlobalVarTest() {
    tsl::enable_tensor_float_32_execution(false);
-
-    // The error tolerances are small enough so that the use of TF32 will cause
-    // the error to be greater than the tolerances.
-    error_spec_ = ErrorSpec{1e-4, 1e-4};
   }
 
   ~TensorFloat32GlobalVarTest() override {
@@ -46,7 +49,7 @@ class TensorFloat32GlobalVarTest : public ::testing::WithParamInterface<bool>,
   }
 
   DebugOptions GetDebugOptionsForTest() const override {
-    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    DebugOptions debug_options = HloPjRtTestBase::GetDebugOptionsForTest();
     const bool enable_triton_gemm = GetParam();
     if (enable_triton_gemm) {
       debug_options.set_xla_gpu_enable_triton_gemm(true);
@@ -68,7 +71,7 @@ ENTRY %dot_computation (x: f32[1024,1024], source: f32[1024,1024]) -> f32[1024,1024]
   ROOT %result = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default, default}
 }
 )";
-  EXPECT_TRUE(RunAndCompare(hlo_text, error_spec_));
+  EXPECT_TRUE(RunAndCompare(hlo_text, kErrorSpec));
 }
 
 TEST_P(TensorFloat32GlobalVarTest, Convolution) {
@@ -81,7 +84,7 @@ ENTRY %conv_computation (x: f32[16,40,40,64], source: f32[3,3,64,64]) -> f32[16,40,40,64]
   ROOT %result = f32[16,40,40,64] convolution(x, y), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, operand_precision={default, default}
 }
 )";
-  EXPECT_TRUE(RunAndCompare(hlo_text, error_spec_));
+  EXPECT_TRUE(RunAndCompare(hlo_text, kErrorSpec));
 }
 
 std::string TestParamToString(const ::testing::TestParamInfo<bool>& info) {
diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD
index f9fabb01c4709b..94cf69298e5da7 100644
--- a/third_party/xla/xla/service/gpu/transforms/BUILD
+++ b/third_party/xla/xla/service/gpu/transforms/BUILD
@@ -518,13 +518,17 @@ xla_test(
     name = "block_scaling_rewriter_cudnn_test",
     srcs = ["block_scaling_rewriter_cudnn_test.cc"],
     backends = ["gpu_b200"],
+    tags = ["test_migrated_to_hlo_runner_pjrt"],
     deps = [
         ":block_scaling_rewriter",
-        "//xla/tests:hlo_test_base",
+        "//xla:error_spec",
+        "//xla/hlo/ir:hlo",
+        "//xla/tests:hlo_pjrt_interpreter_reference_mixin",
+        "//xla/tests:hlo_pjrt_test_base",
         "//xla/tests:xla_internal_test_main",
+        "//xla/tsl/platform:status_matchers",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:status_matchers",
     ],
 )
 
@@ -3044,6 +3048,7 @@ xla_test(
     ],
     tags = [
         "cuda-only",
+        "test_migrated_to_hlo_runner_pjrt",
     ],
     deps = [
         ":sort_rewriter",
@@ -3056,7 +3061,8 @@ xla_test(
         "//xla/service:pattern_matcher",
         "//xla/service/gpu:cublas_cudnn",
         "//xla/service/gpu:gpu_device_info_for_tests",
-        "//xla/tests:hlo_test_base",
+        "//xla/tests:hlo_pjrt_interpreter_reference_mixin",
+        "//xla/tests:hlo_pjrt_test_base",
         "//xla/tests:xla_internal_test_main",  # fixdeps: keep
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/strings",
@@ -3392,6 +3398,7 @@ xla_test(
         "gpu_h100",
         "gpu_b200",
     ],
+    tags = ["test_migrated_to_hlo_runner_pjrt"],
     deps = [
         ":triton_fusion_numerics_verifier",
         "//xla:shape_util",
@@ -3402,13 +3409,13 @@ xla_test(
         "//xla/service/gpu/autotuning:autotuner_compile_util",
         "//xla/service/gpu/autotuning:autotuner_util",
         "//xla/stream_executor:platform",
-        "//xla/tests:hlo_test_base",
+        "//xla/tests:hlo_pjrt_test_base",
         "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/platform:status_matchers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:status_matchers",
     ],
 )
 
diff --git a/third_party/xla/xla/service/gpu/transforms/block_scaling_rewriter_cudnn_test.cc b/third_party/xla/xla/service/gpu/transforms/block_scaling_rewriter_cudnn_test.cc
index 7f4e7446717072..84e5e21c70d2e8 100644
--- a/third_party/xla/xla/service/gpu/transforms/block_scaling_rewriter_cudnn_test.cc
+++ b/third_party/xla/xla/service/gpu/transforms/block_scaling_rewriter_cudnn_test.cc
@@ -13,17 +13,22 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include
 #include
 
 #include "absl/strings/string_view.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/gpu/transforms/block_scaling_rewriter.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/status_matchers.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
+#include "xla/tsl/platform/status_matchers.h"
 
 namespace xla::gpu {
 namespace {
 
 using ::tsl::testing::IsOkAndHolds;
 
-using BlockScalingRewriterCudnnTest = HloTestBase;
+using BlockScalingRewriterCudnnTest =
+    HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
 
 TEST_F(BlockScalingRewriterCudnnTest, Mxfp8) {
   constexpr absl::string_view hlo_string = R"(
diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc
index a7ba46633bf5d5..47637a0c2c28f6 100644
--- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc
@@ -33,7 +33,8 @@ limitations under the License.
 #include "xla/service/gpu/cublas_cudnn.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
 #include "xla/service/pattern_matcher.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tsl/platform/statusor.h"
 #include "xla/xla_data.pb.h"
 
@@ -44,11 +45,11 @@ namespace {
 namespace m = ::xla::match;
 
 class SortRewriterTest
-    : public HloTestBase,
+    : public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>,
      public ::testing::WithParamInterface> {
 public:
  void SetUp() override {
-    HloTestBase::SetUp();
+    HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>::SetUp();
    SortRewriter::SetSortModeForTestingOnly(SortRewriter::Mode::kAlways);
  }
diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
index 6861a4d858e718..1d5c6e277882fe 100644
--- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
+++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
@@ -31,21 +31,21 @@ limitations under the License.
 
 #include "xla/service/gpu/autotuning/autotuner_util.h"
 #include "xla/service/platform_util.h"
 #include "xla/stream_executor/platform.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/tsl/platform/status_matchers.h"
 #include "xla/xla.pb.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/status_matchers.h"
 
 namespace xla::gpu {
 namespace {
 
 class TritonFusionNumericsVerifierTest
-    : public HloTestBase,
+    : public HloPjRtTestBase,
      public ::testing::WithParamInterface {
 public:
  DebugOptions GetDebugOptionsForTest() const override {
-    auto options = HloTestBase::GetDebugOptionsForTest();
+    auto options = HloPjRtTestBase::GetDebugOptionsForTest();
     options.set_xla_gpu_verify_triton_fusion_numerics(true);
     return options;
   }
diff --git a/third_party/xla/xla/service/scatter_determinism_expander_test.cc b/third_party/xla/xla/service/scatter_determinism_expander_test.cc
index b530a8d23b77f3..485d1d72b35e28 100644
--- a/third_party/xla/xla/service/scatter_determinism_expander_test.cc
+++ b/third_party/xla/xla/service/scatter_determinism_expander_test.cc
@@ -21,13 +21,13 @@ limitations under the License.
 
 #include "xla/hlo/testlib/test.h"
 #include "xla/literal.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
+#include "xla/tsl/platform/statusor.h"
 
 namespace xla {
 namespace {
 
-class ScatterDeterminismExpanderTest : public HloTestBase {};
+class ScatterDeterminismExpanderTest : public HloPjRtTestBase {};
 
 TEST_F(ScatterDeterminismExpanderTest,
        DoNotEliminateScatterWithAssociativeCombiner) {
diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD
index 7857f256b1cf19..7e8af90adfc4b5 100644
--- a/third_party/xla/xla/tests/BUILD
+++ b/third_party/xla/xla/tests/BUILD
@@ -605,11 +605,13 @@ xla_test(
     name = "check_execution_arity_test",
     srcs = ["check_execution_arity_test.cc"],
     tags = [
+        "test_migrated_to_hlo_runner_pjrt",
         "test_xla_cpu_no_thunks",
     ],
     deps = [
         ":client_library_test_runner_mixin",
-        ":hlo_test_base",
+        ":hlo_pjrt_interpreter_reference_mixin",
+        ":hlo_pjrt_test_base",
         ":xla_internal_test_main",  # fixdeps: keep
         "//xla:literal",
         "//xla:literal_util",
@@ -675,11 +677,14 @@ xla_test(
 xla_test(
     name = "axpy_simple_test",
     srcs = ["axpy_simple_test.cc"],
-    tags = ["test_xla_cpu_no_thunks"],
+    tags = [
+        "test_migrated_to_hlo_runner_pjrt",
+        "test_xla_cpu_no_thunks",
+    ],
     deps = [
         ":client_library_test_runner_mixin",
-        ":hlo_test_base",
-        ":test_macros_header",
+        ":hlo_pjrt_interpreter_reference_mixin",
+        ":hlo_pjrt_test_base",
         ":xla_internal_test_main",
         "//xla:error_spec",
         "//xla:shape_util",
@@ -750,10 +755,14 @@ xla_test(
 xla_test(
     name = "pred_test",
     srcs = ["pred_test.cc"],
-    tags = ["test_xla_cpu_no_thunks"],
+    tags = [
+        "test_migrated_to_hlo_runner_pjrt",
+        "test_xla_cpu_no_thunks",
+    ],
     deps = [
         ":client_library_test_runner_mixin",
-        ":hlo_test_base",
+        ":hlo_pjrt_interpreter_reference_mixin",
+        ":hlo_pjrt_test_base",
         ":xla_internal_test_main",  # fixdeps: keep
         "//xla/hlo/builder:xla_builder",
         "//xla/hlo/builder/lib:arithmetic",
@@ -765,10 +774,14 @@ xla_test(
 xla_test(
     name = "select_test",
     srcs = ["select_test.cc"],
-    tags = ["test_xla_cpu_no_thunks"],
+    tags = [
+        "test_migrated_to_hlo_runner_pjrt",
+        "test_xla_cpu_no_thunks",
+    ],
     deps = [
         ":client_library_test_runner_mixin",
-        ":hlo_test_base",
+        ":hlo_pjrt_interpreter_reference_mixin",
+        ":hlo_pjrt_test_base",
         ":test_macros_header",
":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", @@ -809,10 +822,14 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", "//xla:literal", @@ -834,11 +851,14 @@ xla_test( "cpu", "gpu", ], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", "//xla:literal", @@ -854,10 +874,14 @@ xla_test( name = "scalar_computations_test", srcs = ["scalar_computations_test.cc"], shard_count = 32, - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep @@ -955,26 +979,29 @@ cc_library( testonly = True, srcs = ["conv_depthwise_common.cc"], hdrs = ["conv_depthwise_common.h"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":client_library_test_base", - ":hlo_test_base", - ":test_macros_header", - "//xla:execution_options_util", - "//xla:status_macros", - "//xla/hlo/builder:xla_computation", "//xla/hlo/testlib:test", - "//xla/hlo/transforms:despecializer", - "//xla/hlo/transforms/simplifiers:float_normalization", + "//xla/tsl/platform:test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_for_library", ], ) xla_test( name = "reduce_precision_test", srcs = ["reduce_precision_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:literal", @@ -1097,6 +1124,7 @@ xla_test( ], deps = [ ":client_library_test_base", + ":hlo_pjrt_test_base", ":hlo_test_base", ":test_macros_header", ":xla_internal_test_main", @@ -1146,6 +1174,7 @@ xla_test( ], deps = [ ":client_library_test_base", + ":hlo_pjrt_test_base", ":hlo_test_base", ":test_macros_header", ":xla_internal_test_main", @@ -1273,6 +1302,7 @@ xla_test( ], deps = [ ":client_library_test_base", + ":hlo_pjrt_test_base", ":hlo_test_base", ":test_macros_header", ":xla_internal_test_main", @@ -1304,10 +1334,14 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", @@ -1325,10 +1359,14 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + 
"test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", @@ -1401,11 +1439,12 @@ xla_test( tags = [ "cuda-only", "optonly", + "test_migrated_to_hlo_runner_pjrt", ], deps = [ - ":client_library_test_base", ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", "//xla:array3d", @@ -1433,11 +1472,13 @@ xla_test( tags = [ "cuda-only", "optonly", + "test_migrated_to_hlo_runner_pjrt", "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:array3d", @@ -1503,10 +1544,12 @@ xla_test( tags = [ "cuda-only", "optonly", + "test_migrated_to_hlo_runner_pjrt", ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:array3d", @@ -1562,10 +1605,14 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - tags = ["cuda-only"], + tags = [ + "cuda-only", + "test_migrated_to_hlo_runner_pjrt", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:array3d", @@ -1590,11 +1637,14 @@ xla_test( "cpu": ["nomsan"], }, shard_count = 30, - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:array3d", "//xla:array4d", @@ -1615,11 +1665,14 @@ xla_test( timeout = "long", srcs = ["convolution_dimension_numbers_test.cc"], shard_count = 20, - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:array4d", "//xla:error_spec", @@ -1701,10 +1754,14 @@ xla_test( name = "bfloat16_test", srcs = ["bfloat16_test.cc"], shard_count = 40, - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", "//xla:array4d", @@ -1719,10 +1776,14 @@ xla_test( xla_test( name = "float8_test", srcs = ["float8_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla/hlo/builder:xla_builder", 
"//xla/hlo/testlib:test", @@ -1739,10 +1800,14 @@ xla_test( "cpu", "gpu", ], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", @@ -1808,11 +1873,14 @@ xla_test( xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", "//xla:array3d", @@ -1828,12 +1896,13 @@ xla_test( srcs = ["dynamic_ops_test.cc"], shard_count = 4, tags = [ + "test_migrated_to_hlo_runner_pjrt", "test_xla_cpu_no_thunks", ] + if_oss(["not_run:arm"]), deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", "//xla:array3d", @@ -1894,10 +1963,14 @@ xla_test( xla_test( name = "vector_ops_reduce_test", srcs = ["vector_ops_reduce_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", @@ -2026,11 +2099,13 @@ xla_test( "no_mac", # b/194731834 "nozapfhahn", "optonly", + "test_migrated_to_hlo_runner_pjrt", "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:array", @@ -2108,28 +2183,35 @@ xla_test( xla_test( name = "sort_test", srcs = ["sort_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", - ":test_macros_header", - ":xla_internal_test_main", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", + ":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", + "//xla/tsl/platform:test", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", ], ) xla_test( name = "topk_test", srcs = ["topk_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", "//xla:error_spec", + "//xla/tsl/platform:test", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", ], @@ -2139,17 +2221,20 @@ xla_test( name = "runtime_topk_test", srcs = ["runtime_topk_test.cc"], backends = ["cpu"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", ":literal_test_util", ":test_macros_header", 
":xla_internal_test_main", "//xla:literal", "//xla:literal_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", ], ) @@ -2175,11 +2260,14 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", "//xla:literal", @@ -2246,10 +2334,14 @@ xla_test( xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", "//xla:array2d", "//xla:array4d", @@ -2312,11 +2404,14 @@ xla_test( xla_test( name = "fmax_fmin_test", srcs = ["fmax_fmin_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", "//xla:literal", @@ -2329,11 +2424,14 @@ xla_test( xla_test( name = "log_test", srcs = ["log_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:array3d", "//xla:error_spec", @@ -2405,19 +2503,22 @@ xla_test( name = "rng_test", srcs = ["rng_test.cc"], backends = ["cpu"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", "//xla:literal", "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:rng_bit_generator_expander", "//xla/hlo/transforms/expanders:rng_expander", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) @@ -2474,11 +2575,15 @@ xla_test( xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", ":client_library_test_runner_utils", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:array4d", "//xla:error_spec", @@ -2499,28 +2604,36 @@ xla_test( name = "stochastic_convert_test", srcs = ["stochastic_convert_test.cc"], backends = ["cpu"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", "//xla:error_spec", "//xla:literal", 
"//xla:literal_util", "//xla:shape_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:array4d", "//xla:error_spec", @@ -2700,8 +2813,9 @@ xla_test( backends = [ "gpu", ], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", ":xla_internal_test_main", "//xla:literal", "//xla/hlo/testlib:verified_hlo_module", @@ -2766,9 +2880,12 @@ xla_test( xla_test( name = "collective_pipeliner_execution_test", srcs = ["collective_pipeliner_execution_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", ":xla_internal_test_main", "//xla:error_spec", "//xla:util", @@ -2777,6 +2894,8 @@ xla_test( "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:collective_pipeliner", + "//xla/service:hlo_verifier", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", @@ -2817,11 +2936,14 @@ xla_test( xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", "//xla:error_spec", "//xla:shape_util", @@ -2853,11 +2975,14 @@ xla_test( xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla/hlo/builder:xla_builder", "//xla/tsl/platform:test", @@ -3315,13 +3440,15 @@ xla_test( "cpu", "gpu", ], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", - ":literal_test_util", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", - "//xla:types", + "//xla:error_spec", "//xla/hlo/testlib:test", ], ) @@ -3372,11 +3499,13 @@ xla_test( }, shard_count = 50, tags = [ + "test_migrated_to_hlo_runner_pjrt", "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:error_spec", @@ -3416,30 +3545,40 @@ xla_test( tags = [ # Disabled in OSS until nvidia publicly releases a fixed ptxas. 
"no_oss", + "test_migrated_to_hlo_runner_pjrt", "test_xla_cpu_no_thunks", ], deps = [ - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:debug_options_flags", + "//xla:error_spec", "//xla/hlo/testlib:test", + "//xla/service:hlo_module_config", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", ], ) xla_test( name = "get_dimension_size_test", srcs = ["get_dimension_size_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep + "//xla:error_spec", "//xla:literal", "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", ], ) @@ -3449,16 +3588,19 @@ xla_test( backend_tags = { "gpu": ["notsan"], # TODO(b/345034145): Fix tsan error. }, - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:literal", "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", ], ) @@ -3521,13 +3663,14 @@ xla_test( xla_test( name = "constant_reduction_function_test", srcs = ["constant_reduction_function_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", - ":literal_test_util", - ":test_macros_header", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":xla_internal_test_main", - "//xla:types", "//xla/hlo/testlib:test", ], ) @@ -3548,18 +3691,23 @@ xla_cc_test( xla_test( name = "numerics_test", srcs = ["numerics_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", + "//xla:error_spec", "//xla:literal_util", "//xla:types", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) @@ -3569,54 +3717,59 @@ xla_test( backend_tags = { "gpu": ["notsan"], }, - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", ":literal_test_util", ":xla_internal_test_main", "//xla:literal", "//xla:literal_util", "//xla:shape_util", "//xla/hlo/testlib:test", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) xla_test( name = "batch_norm_grad_test", srcs = ["batch_norm_grad_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + 
"test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep - "//xla:literal", "//xla:literal_util", - "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", ], ) xla_test( name = "batch_norm_training_test", srcs = ["batch_norm_training_test.cc"], - tags = ["test_xla_cpu_no_thunks"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", + ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", ":xla_internal_test_main", # fixdeps: keep - "//xla:literal", "//xla:literal_util", - "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/tests/axpy_simple_test.cc b/third_party/xla/xla/tests/axpy_simple_test.cc index e463a7af2a388e..df888a85bec082 100644 --- a/third_party/xla/xla/tests/axpy_simple_test.cc +++ b/third_party/xla/xla/tests/axpy_simple_test.cc @@ -20,15 +20,16 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" namespace xla { namespace { -using AxpySimpleTest = ClientLibraryTestRunnerMixin; +using AxpySimpleTest = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; TEST_F(AxpySimpleTest, AxTenValues) { XlaBuilder builder("ax_10"); diff --git a/third_party/xla/xla/tests/batch_norm_grad_test.cc b/third_party/xla/xla/tests/batch_norm_grad_test.cc index 74512febada3c5..22a4319cdd9dc8 100644 --- a/third_party/xla/xla/tests/batch_norm_grad_test.cc +++ b/third_party/xla/xla/tests/batch_norm_grad_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "absl/status/status.h" #include "xla/hlo/testlib/test.h" #include "xla/literal_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { @@ -38,7 +38,7 @@ const char* const kModuleStr = R"( } )"; -class BatchNormGradTest : public HloTestBase {}; +class BatchNormGradTest : public HloPjRtTestBase {}; TEST_F(BatchNormGradTest, CorrectComputation) { TF_ASSERT_OK_AND_ASSIGN(auto module, diff --git a/third_party/xla/xla/tests/batch_norm_training_test.cc b/third_party/xla/xla/tests/batch_norm_training_test.cc index 77386432c733b6..a13ff37a27f53e 100644 --- a/third_party/xla/xla/tests/batch_norm_training_test.cc +++ b/third_party/xla/xla/tests/batch_norm_training_test.cc @@ -18,9 +18,9 @@ limitations under the License. 
#include "absl/status/status.h" #include "xla/hlo/testlib/test.h" #include "xla/literal_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { @@ -37,7 +37,7 @@ ENTRY entry { } )"; -class BatchNormTrainingTest : public HloTestBase {}; +class BatchNormTrainingTest : public HloPjRtTestBase {}; TEST_F(BatchNormTrainingTest, CorrectComputation) { TF_ASSERT_OK_AND_ASSIGN(auto module, diff --git a/third_party/xla/xla/tests/bfloat16_test.cc b/third_party/xla/xla/tests/bfloat16_test.cc index 8e5c878d1f006f..aaf4838d8edb47 100644 --- a/third_party/xla/xla/tests/bfloat16_test.cc +++ b/third_party/xla/xla/tests/bfloat16_test.cc @@ -18,7 +18,8 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/test.h" #include "xla/types.h" @@ -28,7 +29,8 @@ namespace { constexpr ErrorSpec kErrorSpec{0.001, 0.001}; -using Bfloat16Test = ClientLibraryTestRunnerMixin; +using Bfloat16Test = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; TEST_F(Bfloat16Test, ScalarOperation) { XlaBuilder builder(TestName()); diff --git a/third_party/xla/xla/tests/binop_scaling_test.cc b/third_party/xla/xla/tests/binop_scaling_test.cc index 5fbed99aefc1a6..10e136df031ee0 100644 --- a/third_party/xla/xla/tests/binop_scaling_test.cc +++ b/third_party/xla/xla/tests/binop_scaling_test.cc @@ -21,13 +21,15 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" namespace xla { namespace { -using BinopScalingTest = ClientLibraryTestRunnerMixin; +using BinopScalingTest = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 32, 4); diff --git a/third_party/xla/xla/tests/bitcast_convert_test.cc b/third_party/xla/xla/tests/bitcast_convert_test.cc index d102815aa86235..a5f25e51db2ff2 100644 --- a/third_party/xla/xla/tests/bitcast_convert_test.cc +++ b/third_party/xla/xla/tests/bitcast_convert_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/ml_dtypes.h" @@ -33,7 +33,9 @@ limitations under the License. 
namespace xla { namespace { -class BitcastConvertTest : public ClientLibraryTestRunnerMixin { +class BitcastConvertTest + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { public: BitcastConvertTest() { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); @@ -147,7 +149,8 @@ TEST_F(BitcastConvertTest, ConvertReshape) { ComputeAndCompareR0(&builder, 42.0f, {}); } -class BitcastConvertHloTest : public HloTestBase {}; +class BitcastConvertHloTest + : public HloPjRtInterpreterReferenceMixin {}; TEST_F(BitcastConvertHloTest, S32to4S8) { absl::string_view hlo_string = R"( diff --git a/third_party/xla/xla/tests/call_test.cc b/third_party/xla/xla/tests/call_test.cc index 45c2f4cb39c0ea..328a18acea4920 100644 --- a/third_party/xla/xla/tests/call_test.cc +++ b/third_party/xla/xla/tests/call_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include "xla/error_spec.h" @@ -25,8 +24,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" @@ -34,7 +33,8 @@ limitations under the License. namespace xla { namespace { -class CallOpTest : public ClientLibraryTestRunnerMixin { +class CallOpTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: XlaComputation CreateR0F32IdentityComputation() { XlaBuilder builder("Identity"); diff --git a/third_party/xla/xla/tests/check_execution_arity_test.cc b/third_party/xla/xla/tests/check_execution_arity_test.cc index c2fd56d738fc6c..d25fb3c9b06440 100644 --- a/third_party/xla/xla/tests/check_execution_arity_test.cc +++ b/third_party/xla/xla/tests/check_execution_arity_test.cc @@ -23,7 +23,8 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" @@ -34,7 +35,8 @@ namespace { using ::testing::ContainsRegex; class CheckExecutionArityTest - : public ClientLibraryTestRunnerMixin {}; + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); diff --git a/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc b/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc index c76df8f8a30694..3e900f682df03c 100644 --- a/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc +++ b/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -30,13 +31,15 @@ limitations under the License. 
#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/service/collective_pipeliner.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/service/hlo_verifier.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" namespace xla { namespace { -using CollectivePipelinerExecutionTest = HloTestBase; +using CollectivePipelinerExecutionTest = HloPjRtTestBase; absl::StatusOr RunOptimizer( HloModule* module, bool last_run, int64_t level_to_operate_on = 0, diff --git a/third_party/xla/xla/tests/complex_unary_op_test.cc b/third_party/xla/xla/tests/complex_unary_op_test.cc index 661a16fa958155..eaf59acd7e6e4a 100644 --- a/third_party/xla/xla/tests/complex_unary_op_test.cc +++ b/third_party/xla/xla/tests/complex_unary_op_test.cc @@ -25,8 +25,8 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/complex_unary_op_samples.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" @@ -36,7 +36,9 @@ namespace { template constexpr bool dependent_false = false; -class ComplexUnaryOpTest : public ClientLibraryTestRunnerMixin { +class ComplexUnaryOpTest + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: template std::vector get_column(const std::vector>& table) { diff --git a/third_party/xla/xla/tests/concatenate_test.cc b/third_party/xla/xla/tests/concatenate_test.cc index e69018cb77175f..373fc9590e3249 100644 --- a/third_party/xla/xla/tests/concatenate_test.cc +++ b/third_party/xla/xla/tests/concatenate_test.cc @@ -27,16 +27,16 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/literal_test_util.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -using ConcatenateTest = HloTestBase; +using ConcatenateTest = HloPjRtTestBase; TEST_F(ConcatenateTest, TwoR3Axis1) { const std::string hlo_text_module = R"( diff --git a/third_party/xla/xla/tests/constant_reduction_function_test.cc b/third_party/xla/xla/tests/constant_reduction_function_test.cc index 4c2529ca46f33e..744c33ba3b55a4 100644 --- a/third_party/xla/xla/tests/constant_reduction_function_test.cc +++ b/third_party/xla/xla/tests/constant_reduction_function_test.cc @@ -18,20 +18,18 @@ limitations under the License. 
 #include "xla/hlo/pass/hlo_pass_pipeline.h"
 #include "xla/hlo/transforms/simplifiers/hlo_dce.h"
 #include "xla/service/collective_pipeliner.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/service/hlo_verifier.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/util.h"
 
 namespace xla {
 namespace {
 
-using CollectivePipelinerExecutionTest = HloTestBase;
+using CollectivePipelinerExecutionTest = HloPjRtTestBase;
 
 absl::StatusOr<bool> RunOptimizer(
     HloModule* module, bool last_run, int64_t level_to_operate_on = 0,
diff --git a/third_party/xla/xla/tests/complex_unary_op_test.cc b/third_party/xla/xla/tests/complex_unary_op_test.cc
index 661a16fa958155..eaf59acd7e6e4a 100644
--- a/third_party/xla/xla/tests/complex_unary_op_test.cc
+++ b/third_party/xla/xla/tests/complex_unary_op_test.cc
@@ -25,8 +25,8 @@ limitations under the License.
 
 #include "xla/literal_util.h"
 #include "xla/tests/client_library_test_runner_mixin.h"
 #include "xla/tests/complex_unary_op_samples.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tsl/platform/test.h"
 #include "xla/xla_data.pb.h"
 
@@ -36,7 +36,9 @@ namespace {
 template
 constexpr bool dependent_false = false;
 
-class ComplexUnaryOpTest : public ClientLibraryTestRunnerMixin<HloTestBase> {
+class ComplexUnaryOpTest
+    : public ClientLibraryTestRunnerMixin<
+          HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
 protected:
  template
  std::vector get_column(const std::vector>& table) {
diff --git a/third_party/xla/xla/tests/concatenate_test.cc b/third_party/xla/xla/tests/concatenate_test.cc
index e69018cb77175f..373fc9590e3249 100644
--- a/third_party/xla/xla/tests/concatenate_test.cc
+++ b/third_party/xla/xla/tests/concatenate_test.cc
@@ -27,16 +27,16 @@ limitations under the License.
 
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/shape.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tests/literal_test_util.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
+#include "xla/tsl/platform/status.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/tsl/platform/test.h"
 
 namespace xla {
 namespace {
 
-using ConcatenateTest = HloTestBase;
+using ConcatenateTest = HloPjRtTestBase;
 
 TEST_F(ConcatenateTest, TwoR3Axis1) {
   const std::string hlo_text_module = R"(
diff --git a/third_party/xla/xla/tests/constant_reduction_function_test.cc b/third_party/xla/xla/tests/constant_reduction_function_test.cc
index 4c2529ca46f33e..744c33ba3b55a4 100644
--- a/third_party/xla/xla/tests/constant_reduction_function_test.cc
+++ b/third_party/xla/xla/tests/constant_reduction_function_test.cc
@@ -18,20 +18,18 @@ limitations under the License.
 
 #include
 #include
-#include
 
 #include "xla/hlo/testlib/test.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/literal_test_util.h"
-#include "xla/tests/test_macros.h"
-#include "xla/types.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 
 namespace xla {
 namespace {
 
 using std::nullopt;
 
-class ConstantReductionFunctionTest : public HloTestBase {};
+using ConstantReductionFunctionTest =
+    HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
 
 TEST_F(ConstantReductionFunctionTest, Bool) {
   const std::string& hlo_string = R"(
diff --git a/third_party/xla/xla/tests/constants_test.cc b/third_party/xla/xla/tests/constants_test.cc
index 239c134af63052..9dcfac4bb9252b 100644
--- a/third_party/xla/xla/tests/constants_test.cc
+++ b/third_party/xla/xla/tests/constants_test.cc
@@ -31,7 +31,8 @@ limitations under the License.
 
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/tests/client_library_test_runner_mixin.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/tests/test_macros.h"
 #include "xla/tsl/lib/core/status_test_util.h"
@@ -44,7 +45,8 @@ namespace {
 
 constexpr ErrorSpec kErrorSpec{1e-3, 1e-5};
 
-using ConstantsTest = ClientLibraryTestRunnerMixin<HloTestBase>;
+using ConstantsTest = ClientLibraryTestRunnerMixin<
+    HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>;
 
 template
 class ConstantsFloatTest : public ConstantsTest {};
@@ -249,7 +251,7 @@ TEST_F(ConstantsTest, FullLikeScalar) {
   ComputeAndCompareR0<float>(&b, -1, {}, kErrorSpec);
 }
 
-using ConstantsHloTest = HloTestBase;
+using ConstantsHloTest = HloPjRtTestBase;
 
 // TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior.
 TEST_F(ConstantsHloTest, DISABLED_ON_TPU(DISABLED_ON_GPU(BitcastOfConstant))) {
diff --git a/third_party/xla/xla/tests/conv_depthwise_common.cc b/third_party/xla/xla/tests/conv_depthwise_common.cc
index 09cd38576322fa..86ce74d7a6f2eb 100644
--- a/third_party/xla/xla/tests/conv_depthwise_common.cc
+++ b/third_party/xla/xla/tests/conv_depthwise_common.cc
@@ -15,17 +15,14 @@ limitations under the License.
 
 #include "xla/tests/conv_depthwise_common.h"
 
-#include
+#include
 
-#include "xla/execution_options_util.h"
-#include "xla/hlo/builder/xla_computation.h"
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
 #include "xla/hlo/testlib/test.h"
-#include "xla/hlo/transforms/despecializer.h"
-#include "xla/hlo/transforms/simplifiers/float_normalization.h"
-#include "xla/status_macros.h"
-#include "xla/tests/client_library_test_base.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
+#include "xla/tsl/platform/test.h"
 
 namespace xla {
 std::string GetFloatDataType(bool use_bfloat16) {
diff --git a/third_party/xla/xla/tests/conv_depthwise_common.h b/third_party/xla/xla/tests/conv_depthwise_common.h
index 010dde84898815..9ea231e826eb79 100644
--- a/third_party/xla/xla/tests/conv_depthwise_common.h
+++ b/third_party/xla/xla/tests/conv_depthwise_common.h
@@ -16,17 +16,12 @@ limitations under the License.
 
 #ifndef XLA_TESTS_CONV_DEPTHWISE_COMMON_H_
 #define XLA_TESTS_CONV_DEPTHWISE_COMMON_H_
 
-#include
+#include
+#include
+#include
 
-#include "xla/execution_options_util.h"
-#include "xla/hlo/builder:xla_computation.h"
+#include
 #include "xla/hlo/testlib/test.h"
-#include "xla/hlo/transforms/despecializer.h"
-#include "xla/hlo/transforms/simplifiers/float_normalization.h"
-#include "xla/status_macros.h"
-#include "xla/tests/client_library_test_base.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
 
 namespace xla {
 std::string GetFloatDataType(bool use_bfloat16);
diff --git a/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc b/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc
index 2b7e718966e7b5..2106f82f0cd13a 100644
--- a/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc
+++ b/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc
@@ -27,8 +27,8 @@ limitations under the License.
 
 #include "xla/reference_util.h"
 #include "xla/shape_util.h"
 #include "xla/tests/client_library_test_runner_mixin.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/test.h"
 
@@ -59,7 +59,8 @@ absl::StatusOr CreateConvDimensionNumbers(
 }
 
 class ConvolutionDimensionNumbersTest
-    : public ClientLibraryTestRunnerMixin<HloTestBase> {};
+    : public ClientLibraryTestRunnerMixin<
+          HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {};
 
 // Tests the convolution operation with invalid input dimension numbers.
 TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) {
diff --git a/third_party/xla/xla/tests/convolution_test_1d.cc b/third_party/xla/xla/tests/convolution_test_1d.cc
index fe82387ec0924b..39ea0d7ece1b32 100644
--- a/third_party/xla/xla/tests/convolution_test_1d.cc
+++ b/third_party/xla/xla/tests/convolution_test_1d.cc
@@ -29,7 +29,8 @@ limitations under the License.
 
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/tests/client_library_test_runner_mixin.h"
-#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tests/test_macros.h"
 #include "xla/tsl/platform/test.h"
 #include "xla/xla_data.pb.h"
@@ -45,7 +46,8 @@ constexpr ErrorSpec kErrorSpec(1e-2, 1e-3);
 constexpr ErrorSpec kErrorSpec(1e-4, 1e-3);
 #endif
 
-using ConvolutionTest = ClientLibraryTestRunnerMixin<HloTestBase>;
+using ConvolutionTest = ClientLibraryTestRunnerMixin<
+    HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>;
 
 #ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
 using TestTypes = ::testing::Types;
diff --git a/third_party/xla/xla/tests/convolution_variants_test.cc b/third_party/xla/xla/tests/convolution_variants_test.cc
index b7284a297715f3..f001e646a0a08d 100644
--- a/third_party/xla/xla/tests/convolution_variants_test.cc
+++ b/third_party/xla/xla/tests/convolution_variants_test.cc
@@ -34,8 +34,8 @@ limitations under the License.
 
 #include "xla/literal_util.h"
 #include "xla/reference_util.h"
 #include "xla/tests/client_library_test_runner_mixin.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tsl/platform/test.h"
 #include "xla/xla_data.pb.h"
 
@@ -61,7 +61,8 @@ XlaOp ConvWithHighestPrecision(const XlaOp lhs, const XlaOp rhs,
                  /*batch_group_count=*/1, &precision_config);
 }
 
-using ConvolutionVariantsTest = ClientLibraryTestRunnerMixin<HloTestBase>;
+using ConvolutionVariantsTest = ClientLibraryTestRunnerMixin<
+    HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>;
 
 TEST_F(ConvolutionVariantsTest, Minimal) {
   XlaBuilder builder(TestName());
diff --git a/third_party/xla/xla/tests/dynamic_ops_test.cc b/third_party/xla/xla/tests/dynamic_ops_test.cc
index 5ca6bba5aeb48e..29e86ec9fad702 100644
--- a/third_party/xla/xla/tests/dynamic_ops_test.cc
+++ b/third_party/xla/xla/tests/dynamic_ops_test.cc
@@ -44,8 +44,8 @@ limitations under the License.
 
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/stream_executor_memory_allocator.h"
 #include "xla/tests/client_library_test_runner_mixin.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
+#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
+#include "xla/tests/hlo_pjrt_test_base.h"
 #include "xla/tsl/platform/statusor.h"
 #include "xla/tsl/platform/test.h"
 #include "xla/tsl/platform/test_benchmark.h"
@@ -54,7 +54,9 @@ limitations under the License.
 namespace xla {
 namespace {
 
-class DynamicSliceTest : public ClientLibraryTestRunnerMixin<HloTestBase> {
+class DynamicSliceTest
+    : public ClientLibraryTestRunnerMixin<
+          HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
 protected:
  template
  void TestR1() {
@@ -318,7 +320,7 @@ TEST_F(DynamicSliceTest, Int32R3Pred) {
 }
 
 class DynamicUpdateSliceTest
-    : public ClientLibraryTestRunnerMixin<HloTestBase> {
+    : public ClientLibraryTestRunnerMixin<
+          HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
 protected:
  template
  void TestR0() {
@@ -750,9 +753,11 @@ TEST_F(DynamicUpdateSliceTest, R3ContiguousLargerBF16) {
   RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1);
 }
 
+using DynamicOpsTest = HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>;
+
 // This test that buffer assignment does not alias constants with the output of
 // dynamic update slice.
-TEST_F(HloTestBase, AddOfDUS) {
+TEST_F(DynamicOpsTest, AddOfDUS) {
   const char* hlo_string = R"(
   HloModule m
   test {
@@ -773,7 +778,7 @@ TEST_F(HloTestBase, AddOfDUS) {
 // and multiple output fusions of dynamic update slices produce the right
 // results. On some backends (e.g. GPU), this is done inplace.
#ifdef XLA_TEST_BACKEND_GPU -TEST_F(HloTestBase, MultipleOutputFusedDynamicUpdateSlices) { +TEST_F(DynamicOpsTest, MultipleOutputFusedDynamicUpdateSlices) { const char* hlo_string = R"( HloModule MultipleInplaceDus, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } @@ -803,7 +808,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -TEST_F(HloTestBase, +TEST_F(DynamicOpsTest, MultipleOutputFusedDynamicUpdateSlicesWithTransposeBitcastedRoot) { const char* hlo_string = R"( HloModule MultipleInplaceDusWithTransposeBitcastToTheRoot, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } @@ -835,7 +840,8 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -TEST_F(HloTestBase, SingleFusedDynamicUpdateSliceWithTransposeBitcastedRoot) { +TEST_F(DynamicOpsTest, + SingleFusedDynamicUpdateSliceWithTransposeBitcastedRoot) { const char* hlo_string = R"( HloModule SingleInplaceDusWithTransposeBitcastToTheRoot, input_output_alias={ {}: (0, {}) } @@ -862,7 +868,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -TEST_F(HloTestBase, SingleFusedDynamicUpdateSliceWithReshapeBitcastedRoot) { +TEST_F(DynamicOpsTest, SingleFusedDynamicUpdateSliceWithReshapeBitcastedRoot) { const char* hlo_string = R"( HloModule SingleInplaceDusWithReshapeBitcastToTheRoot, input_output_alias={ {}: (0, {}) } @@ -889,7 +895,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -TEST_F(HloTestBase, +TEST_F(DynamicOpsTest, SingleFusedDynamicUpdateSliceWithBitcastedRootAndParameter) { const char* hlo_string = R"( HloModule SingleInplaceDusWithBitcastToTheRootAndFromTheParameter, input_output_alias={ {}: (0, {}) } @@ -919,7 +925,8 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -TEST_F(HloTestBase, SingleFusedDynamicUpdateSliceWithSameDynamicSliceAccess) { +TEST_F(DynamicOpsTest, + SingleFusedDynamicUpdateSliceWithSameDynamicSliceAccess) { const char* hlo_string = R"( HloModule fusion, input_output_alias={ {}: (0, {}) } @@ -945,7 +952,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0, 0})); } -TEST_F(HloTestBase, +TEST_F(DynamicOpsTest, SingleFusedDynamicUpdateSliceWithDynamicSliceAccessSlicesOfSizeOne) { const char* hlo_string = R"( HloModule fusion, input_output_alias={ {}: (0, {}) } diff --git a/third_party/xla/xla/tests/float8_test.cc b/third_party/xla/xla/tests/float8_test.cc index 2b110e521a39af..4d5ac803be56f6 100644 --- a/third_party/xla/xla/tests/float8_test.cc +++ b/third_party/xla/xla/tests/float8_test.cc @@ -17,7 +17,8 @@ limitations under the License. 
#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/test.h" #include "tsl/platform/ml_dtypes.h" @@ -27,7 +28,8 @@ namespace { // Test FP8 floating-point types template -class Float8Test : public ClientLibraryTestRunnerMixin {}; +class Float8Test : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> {}; using DataTypes = ::testing::Types; diff --git a/third_party/xla/xla/tests/floor_ceil_test.cc b/third_party/xla/xla/tests/floor_ceil_test.cc index b55ff38888c3e7..7e5469cb2e8e1e 100644 --- a/third_party/xla/xla/tests/floor_ceil_test.cc +++ b/third_party/xla/xla/tests/floor_ceil_test.cc @@ -20,14 +20,15 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" namespace xla { namespace { -class FloorCeilTest : public ClientLibraryTestRunnerMixin { +class FloorCeilTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { public: enum Function { kFloor, diff --git a/third_party/xla/xla/tests/fmax_fmin_test.cc b/third_party/xla/xla/tests/fmax_fmin_test.cc index 2869afa6aa5b1e..ec10bdeb3fe3f1 100644 --- a/third_party/xla/xla/tests/fmax_fmin_test.cc +++ b/third_party/xla/xla/tests/fmax_fmin_test.cc @@ -20,14 +20,15 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" namespace xla { namespace { -using FmaxSimpleTest = ClientLibraryTestRunnerMixin; +using FmaxSimpleTest = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; TEST_F(FmaxSimpleTest, FmaxTenValues) { SetFastMathDisabled(true); diff --git a/third_party/xla/xla/tests/get_dimension_size_test.cc b/third_party/xla/xla/tests/get_dimension_size_test.cc index 3c815fd989d17b..342d21ee181e7e 100644 --- a/third_party/xla/xla/tests/get_dimension_size_test.cc +++ b/third_party/xla/xla/tests/get_dimension_size_test.cc @@ -16,13 +16,15 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { @@ -33,7 +35,7 @@ void DisableAllHloPasses(HloModule& module) { module.mutable_config().set_debug_options(debug_options); } -class GetDimensionSizeTest : public HloTestBase {}; +using GetDimensionSizeTest = HloPjRtInterpreterReferenceMixin; // Test that the interpreter can correctly compute get_dimension_size. 
TEST_F(GetDimensionSizeTest, CorrectComputation) { diff --git a/third_party/xla/xla/tests/half_test.cc b/third_party/xla/xla/tests/half_test.cc index 639e529d0ffc61..6e145dcf34539b 100644 --- a/third_party/xla/xla/tests/half_test.cc +++ b/third_party/xla/xla/tests/half_test.cc @@ -26,7 +26,8 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/testlib/test.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/test.h" #include "xla/types.h" @@ -42,7 +43,8 @@ using UnaryBuildFuncTy = std::function; constexpr int kNumElements = 4; constexpr ErrorSpec kErrorSpec{0.001, 0.001}; -using HalfTestBase = ClientLibraryTestRunnerMixin; +using HalfTestBase = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; struct UnaryOpTestParam { std::function compute_func; diff --git a/third_party/xla/xla/tests/iota_test.cc b/third_party/xla/xla/tests/iota_test.cc index 34afcad89c4efe..66b77396c182e1 100644 --- a/third_party/xla/xla/tests/iota_test.cc +++ b/third_party/xla/xla/tests/iota_test.cc @@ -26,7 +26,8 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/test.h" #include "xla/types.h" @@ -36,7 +37,7 @@ limitations under the License. namespace xla { namespace { -using IotaTest = HloTestBase; +using IotaTest = HloPjRtInterpreterReferenceMixin; TEST_F(IotaTest, IotaReshapeR1) { const std::string hlo_text = R"( @@ -70,7 +71,8 @@ std::vector GetR1Expected(const int64_t num_elements) { } class IotaR1Test - : public ClientLibraryTestRunnerMixin, + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>, public ::testing::WithParamInterface> {}; XLA_TEST_P(IotaR1Test, DoIt) { @@ -116,7 +118,8 @@ INSTANTIATE_TEST_CASE_P( /*end=*/10001, /*step=*/10))); -class IotaR2Test : public ClientLibraryTestRunnerMixin, +class IotaR2Test : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>, public ::testing::WithParamInterface< std::tuple> {}; @@ -154,7 +157,8 @@ INSTANTIATE_TEST_CASE_P( /*step=*/10), ::testing::Values(0, 1))); -class IotaR3Test : public ClientLibraryTestRunnerMixin, +class IotaR3Test : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>, public ::testing::WithParamInterface< std::tuple> {}; diff --git a/third_party/xla/xla/tests/log_test.cc b/third_party/xla/xla/tests/log_test.cc index e58fc2daff9efc..69250508363136 100644 --- a/third_party/xla/xla/tests/log_test.cc +++ b/third_party/xla/xla/tests/log_test.cc @@ -20,14 +20,15 @@ limitations under the License. 
#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" namespace xla { namespace { -class LogTest : public ClientLibraryTestRunnerMixin {}; +class LogTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> {}; TEST_F(LogTest, LogZeroValues) { XlaBuilder builder(TestName()); diff --git a/third_party/xla/xla/tests/multidimensional_slice_test.cc b/third_party/xla/xla/tests/multidimensional_slice_test.cc index 86361a4c638929..e416630ef8dd85 100644 --- a/third_party/xla/xla/tests/multidimensional_slice_test.cc +++ b/third_party/xla/xla/tests/multidimensional_slice_test.cc @@ -21,14 +21,15 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" namespace xla { namespace { -class SliceTest : public ClientLibraryTestRunnerMixin {}; +class SliceTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> {}; TEST_F(SliceTest, Slice2D) { XlaBuilder builder("slice_2d"); diff --git a/third_party/xla/xla/tests/nccl_group_execution_test.cc b/third_party/xla/xla/tests/nccl_group_execution_test.cc index aef4335c642588..f2920ea449090f 100644 --- a/third_party/xla/xla/tests/nccl_group_execution_test.cc +++ b/third_party/xla/xla/tests/nccl_group_execution_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/service/hlo_module_config.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -34,10 +34,10 @@ namespace { // Tests NCCL group execution. -class NcclGroupExecutionTest : public HloTestBase { +class NcclGroupExecutionTest : public HloPjRtTestBase { public: NcclGroupExecutionTest() { - VLOG(1) << "Running with " << num_devices() << " devices"; + VLOG(1) << "Running with " << test_runner().device_count() << " devices"; } }; diff --git a/third_party/xla/xla/tests/numerics_test.cc b/third_party/xla/xla/tests/numerics_test.cc index b5f82d78c86766..efc8ca97c05dff 100644 --- a/third_party/xla/xla/tests/numerics_test.cc +++ b/third_party/xla/xla/tests/numerics_test.cc @@ -18,21 +18,23 @@ limitations under the License. 
#include #include "absl/status/statusor.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/test.h" #include "xla/literal_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace { -using NumericsTest = HloTestBase; +using NumericsTest = HloPjRtInterpreterReferenceMixin; -XLA_TEST_F(NumericsTest, AbsOfLargeComplexNumber) { +TEST_F(NumericsTest, AbsOfLargeComplexNumber) { const char* hlo = R"( HloModule module @@ -54,7 +56,7 @@ ENTRY entry { EXPECT_TRUE(abs_of_complex_x(1e30)); } -XLA_TEST_F(NumericsTest, PowerOfLargeComplexNumber) { +TEST_F(NumericsTest, PowerOfLargeComplexNumber) { const char* hlo = R"( HloModule module @@ -91,8 +93,8 @@ ENTRY entry { // CPU thunks backend (due to incorrect LLVM IR generated). // This is an HLO module optimized for CPU backend, it may be invalid for other // backends. -XLA_TEST_F(NumericsTest, - DISABLED_ON_GPU(DISABLED_ON_TPU(MultiplySubtractConcatTest))) { +TEST_F(NumericsTest, + DISABLED_ON_GPU(DISABLED_ON_TPU(MultiplySubtractConcatTest))) { const char* test_hlo = R"( HloModule jit_step, is_scheduled=true diff --git a/third_party/xla/xla/tests/pred_test.cc b/third_party/xla/xla/tests/pred_test.cc index 204a3cc538c72e..12fd3b679949c9 100644 --- a/third_party/xla/xla/tests/pred_test.cc +++ b/third_party/xla/xla/tests/pred_test.cc @@ -22,13 +22,15 @@ limitations under the License. #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" namespace xla { namespace { -class PredTest : public ClientLibraryTestRunnerMixin { +class PredTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: void TestCompare(bool lhs, bool rhs, bool expected, std::function + #include "xla/debug_options_flags.h" +#include "xla/error_spec.h" #include "xla/hlo/testlib/test.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/service/hlo_module_config.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class PtxasBugTest : public HloTestBase {}; +using PtxasBugTest = HloPjRtInterpreterReferenceMixin; // Checks for a bug in ptxas, tracked as Google bug 120501638, and nvidia bug // 2459377. We never received an explanation of what exactly was going wrong diff --git a/third_party/xla/xla/tests/reduce_precision_test.cc b/third_party/xla/xla/tests/reduce_precision_test.cc index 35e5179846018d..8d8f2a0158af6a 100644 --- a/third_party/xla/xla/tests/reduce_precision_test.cc +++ b/third_party/xla/xla/tests/reduce_precision_test.cc @@ -23,7 +23,8 @@ limitations under the License. 
#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/test.h" #include "xla/types.h" @@ -447,7 +448,8 @@ static const uint64_t f64_test_values[][4] = { }; class ReducedPrecisionAccuracyTest - : public ClientLibraryTestRunnerMixin, + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>, public ::testing::WithParamInterface { protected: template diff --git a/third_party/xla/xla/tests/reverse_test.cc b/third_party/xla/xla/tests/reverse_test.cc index 469d8dc5d59c88..b37fea698da546 100644 --- a/third_party/xla/xla/tests/reverse_test.cc +++ b/third_party/xla/xla/tests/reverse_test.cc @@ -32,7 +32,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/client_library_test_runner_utils.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" @@ -81,7 +82,8 @@ void PrintTo(const ReverseSpec& spec, std::ostream* os) { *os << spec.ToTestCaseName(); } -class FloatReverseTest : public ClientLibraryTestRunnerMixin, +class FloatReverseTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>, public ::testing::WithParamInterface { public: FloatReverseTest() { set_float_type(GetParam().test_type); } @@ -122,7 +124,8 @@ INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, ::testing::PrintToStringParamName()); // A simple test class which not templated by float precision. -using ReverseTest = ClientLibraryTestRunnerMixin; +using ReverseTest = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; // Tests the reverse operation on a 4D U8 array on dimension 0 and 3. TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { diff --git a/third_party/xla/xla/tests/rng_test.cc b/third_party/xla/xla/tests/rng_test.cc index a9e2896f97b34d..8b1d1263a5eea7 100644 --- a/third_party/xla/xla/tests/rng_test.cc +++ b/third_party/xla/xla/tests/rng_test.cc @@ -24,15 +24,15 @@ limitations under the License. 
#include "xla/hlo/transforms/expanders/rng_expander.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -using RngTest = HloTestBase; +using RngTest = HloPjRtTestBase; void DisableHloPass(HloModule& module, absl::string_view pass_name) { auto debug_options = module.config().debug_options(); @@ -77,7 +77,7 @@ XLA_TEST_F(RngTest, ReturnsErrorWhenExpanderPassDisabled) { ::testing::HasSubstr("Rng should be expanded for CPU")); } -using RngBitGeneratorTest = HloTestBase; +using RngBitGeneratorTest = HloPjRtTestBase; XLA_TEST_F(RngBitGeneratorTest, ReturnsErrorWhenExpanderPassDisabled_Default) { const char* const kModuleStr = R"( diff --git a/third_party/xla/xla/tests/runtime_topk_test.cc b/third_party/xla/xla/tests/runtime_topk_test.cc index b913e3e4d361ac..16d9ffdb18aaef 100644 --- a/third_party/xla/xla/tests/runtime_topk_test.cc +++ b/third_party/xla/xla/tests/runtime_topk_test.cc @@ -20,15 +20,15 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace xla::cpu { namespace { -class TopkTest : public HloTestBase {}; +class TopkTest : public HloPjRtTestBase {}; XLA_TEST_F(TopkTest, CustomCallTarget) { absl::string_view hlo_text_module = R"( diff --git a/third_party/xla/xla/tests/sample_text_test.cc b/third_party/xla/xla/tests/sample_text_test.cc index 3b5dc2692149de..e7f0ca3dbde6e5 100644 --- a/third_party/xla/xla/tests/sample_text_test.cc +++ b/third_party/xla/xla/tests/sample_text_test.cc @@ -18,20 +18,19 @@ limitations under the License. #include #include -#include +#include "xla/error_spec.h" #include "xla/hlo/testlib/test.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tests/test_macros.h" -#include "xla/types.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" namespace xla { namespace { using std::nullopt; -class SampleTextTest : public HloTestBase {}; +class SampleTextTest + : public HloPjRtInterpreterReferenceMixin {}; TEST_F(SampleTextTest, Axpy) { const std::string& hlo_string = R"( diff --git a/third_party/xla/xla/tests/scalar_computations_test.cc b/third_party/xla/xla/tests/scalar_computations_test.cc index d71ea615a9032d..7310446d1e080b 100644 --- a/third_party/xla/xla/tests/scalar_computations_test.cc +++ b/third_party/xla/xla/tests/scalar_computations_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -31,7 +30,8 @@ limitations under the License. 
#include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/statusor.h" @@ -44,7 +44,8 @@ namespace { constexpr ErrorSpec kErrorSpec{0.0001}; class ScalarComputationsTest - : public ClientLibraryTestRunnerMixin { + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: // A template for building and running a binary comparison test. template @@ -280,7 +281,8 @@ void PrintTo(const DivS32Params& p, std::ostream* os) { << p.remainder << "}"; } -class DivS32Test : public ClientLibraryTestRunnerMixin, +class DivS32Test : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>, public ::testing::WithParamInterface {}; XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { diff --git a/third_party/xla/xla/tests/select_and_scatter_test.cc b/third_party/xla/xla/tests/select_and_scatter_test.cc index cab2edef3838ad..b3595a640144c1 100644 --- a/third_party/xla/xla/tests/select_and_scatter_test.cc +++ b/third_party/xla/xla/tests/select_and_scatter_test.cc @@ -34,7 +34,8 @@ limitations under the License. #include "xla/hlo/builder/xla_computation.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" @@ -51,7 +52,8 @@ struct SelectAndScatterTestParam { }; class SelectAndScatterTest - : public ClientLibraryTestRunnerMixin, + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>, public ::testing::WithParamInterface { public: SelectAndScatterTest() : builder_(TestName()) { diff --git a/third_party/xla/xla/tests/select_test.cc b/third_party/xla/xla/tests/select_test.cc index 02339bf6fe1a5a..f575498bc38534 100644 --- a/third_party/xla/xla/tests/select_test.cc +++ b/third_party/xla/xla/tests/select_test.cc @@ -20,7 +20,8 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/test.h" #include "xla/types.h" @@ -30,7 +31,8 @@ namespace { constexpr ErrorSpec kErrorSpec{0.0001}; -using SelectTest = ClientLibraryTestRunnerMixin; +using SelectTest = ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>; TEST_F(SelectTest, SelectScalarF32True) { XlaBuilder builder(TestName()); diff --git a/third_party/xla/xla/tests/set_dimension_size_test.cc b/third_party/xla/xla/tests/set_dimension_size_test.cc index 3674e582802647..c23f1494c7417d 100644 --- a/third_party/xla/xla/tests/set_dimension_size_test.cc +++ b/third_party/xla/xla/tests/set_dimension_size_test.cc @@ -21,9 +21,9 @@ limitations under the License. 
#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { @@ -45,7 +45,7 @@ void DisableAllHloPasses(HloModule& module) { module.mutable_config().set_debug_options(debug_options); } -class SetDimensionSizeTest : public HloTestBase {}; +class SetDimensionSizeTest : public HloPjRtTestBase {}; TEST_F(SetDimensionSizeTest, CorrectComputation) { TF_ASSERT_OK_AND_ASSIGN(auto module, diff --git a/third_party/xla/xla/tests/sort_test.cc b/third_party/xla/xla/tests/sort_test.cc index 974441cbb9fe23..ae00ac94b1c4e7 100644 --- a/third_party/xla/xla/tests/sort_test.cc +++ b/third_party/xla/xla/tests/sort_test.cc @@ -16,22 +16,22 @@ limitations under the License. #include #include -#include #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "xla/error_spec.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class SortTest : public HloTestBase {}; +using SortTest = HloPjRtInterpreterReferenceMixin; -XLA_TEST_F(SortTest, SortDim0) { +TEST_F(SortTest, SortDim0) { absl::string_view hlo_text_module = R"( HloModule sort @@ -50,7 +50,7 @@ XLA_TEST_F(SortTest, SortDim0) { EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); } -XLA_TEST_F(SortTest, SortDim1) { +TEST_F(SortTest, SortDim1) { absl::string_view hlo_text_module = R"( HloModule sort @@ -69,7 +69,7 @@ XLA_TEST_F(SortTest, SortDim1) { EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); } -XLA_TEST_F(SortTest, SortTwiceWithSameComparator) { +TEST_F(SortTest, SortTwiceWithSameComparator) { absl::string_view hlo_text_module = R"( HloModule sort @@ -100,7 +100,7 @@ class SortManyInputsTest : public SortTest, } }; -XLA_TEST_P(SortManyInputsTest, SortManyInputs) { +TEST_P(SortManyInputsTest, SortManyInputs) { int num_inputs = GetParam(); absl::string_view hlo_text_module_template = R"( HloModule sort diff --git a/third_party/xla/xla/tests/stochastic_convert_test.cc b/third_party/xla/xla/tests/stochastic_convert_test.cc index a2c351114c19e0..71e27a0e7bf6f8 100644 --- a/third_party/xla/xla/tests/stochastic_convert_test.cc +++ b/third_party/xla/xla/tests/stochastic_convert_test.cc @@ -22,15 +22,17 @@ limitations under the License. 
#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -using StochasticConvertTest = HloTestBase; +using StochasticConvertTest = HloPjRtInterpreterReferenceMixin; + const char* const kModuleStr = R"( HloModule stochastic-convert diff --git a/third_party/xla/xla/tests/topk_test.cc b/third_party/xla/xla/tests/topk_test.cc index 8ad1dd15cd1ebc..17117d01a412e6 100644 --- a/third_party/xla/xla/tests/topk_test.cc +++ b/third_party/xla/xla/tests/topk_test.cc @@ -16,15 +16,16 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "xla/error_spec.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { -class TopkTest : public HloTestBase {}; +using TopkTest = HloPjRtInterpreterReferenceMixin; -XLA_TEST_F(TopkTest, LargestTopK) { +TEST_F(TopkTest, LargestTopK) { absl::string_view hlo_text_module = R"( HloModule topk @@ -36,7 +37,7 @@ XLA_TEST_F(TopkTest, LargestTopK) { EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5})); } -XLA_TEST_F(TopkTest, SmallestTopK) { +TEST_F(TopkTest, SmallestTopK) { absl::string_view hlo_text_module = R"( HloModule topk @@ -48,7 +49,7 @@ XLA_TEST_F(TopkTest, SmallestTopK) { EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5})); } -XLA_TEST_F(TopkTest, TopKOfTranspose) { +TEST_F(TopkTest, TopKOfTranspose) { // Regression test for b/362565176 absl::string_view hlo_text_module = R"( HloModule topk diff --git a/third_party/xla/xla/tests/transpose_test.cc b/third_party/xla/xla/tests/transpose_test.cc index 28abe480abfca4..bd93fdf41bac7d 100644 --- a/third_party/xla/xla/tests/transpose_test.cc +++ b/third_party/xla/xla/tests/transpose_test.cc @@ -25,7 +25,8 @@ limitations under the License. 
#include "xla/literal_util.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" @@ -36,7 +37,8 @@ namespace { constexpr ErrorSpec kErrorSpec{0.0001}; -class TransposeTest : public ClientLibraryTestRunnerMixin { +class TransposeTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: void TestTransposeConstant(Vector3 sizes, Vector3 transpose_dims) { Array3D aoperand(sizes[0], sizes[1], sizes[2]); @@ -199,7 +201,7 @@ TEST_F(TransposeTest, TransposeConstant210_DegenerateDim) { TestTransposeConstant({20, 30, 1}, {2, 1, 0}); } -using HloTransposeTest = HloTestBase; +using HloTransposeTest = HloPjRtTestBase; // Disable HLO passes to verify the default behavior TEST_F(HloTransposeTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( diff --git a/third_party/xla/xla/tests/unary_op_test.cc b/third_party/xla/xla/tests/unary_op_test.cc index 017e39f08bf478..1e9a9cbb102a05 100644 --- a/third_party/xla/xla/tests/unary_op_test.cc +++ b/third_party/xla/xla/tests/unary_op_test.cc @@ -22,7 +22,8 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" #include "xla/types.h" #include "xla/xla_data.pb.h" @@ -30,7 +31,8 @@ limitations under the License. namespace xla { namespace { -class UnaryOpTest : public ClientLibraryTestRunnerMixin { +class UnaryOpTest : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: template T inf() { diff --git a/third_party/xla/xla/tests/vector_ops_reduce_test.cc b/third_party/xla/xla/tests/vector_ops_reduce_test.cc index 834e8dcc867db6..846660ac69aa95 100644 --- a/third_party/xla/xla/tests/vector_ops_reduce_test.cc +++ b/third_party/xla/xla/tests/vector_ops_reduce_test.cc @@ -22,7 +22,8 @@ limitations under the License. #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" @@ -32,7 +33,9 @@ namespace { constexpr ErrorSpec kErrorSpec{1e-3, 0}; -class VecOpsReduceTest : public ClientLibraryTestRunnerMixin { +class VecOpsReduceTest + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { public: VecOpsReduceTest() : builder_(TestName()) {} diff --git a/third_party/xla/xla/tests/vector_ops_simple_test.cc b/third_party/xla/xla/tests/vector_ops_simple_test.cc index 1b9b78a70c0b81..05f8ca093b17c3 100644 --- a/third_party/xla/xla/tests/vector_ops_simple_test.cc +++ b/third_party/xla/xla/tests/vector_ops_simple_test.cc @@ -29,7 +29,8 @@ limitations under the License. 
#include "xla/hlo/testlib/test_helpers.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" @@ -38,7 +39,9 @@ namespace { constexpr ErrorSpec kErrorSpec{0.0001}; -class VecOpsSimpleTest : public ClientLibraryTestRunnerMixin { +class VecOpsSimpleTest + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { public: VecOpsSimpleTest() { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index c51dd9ca1f40f0..3dca1ce7e2d495 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -1033,15 +1033,16 @@ xla_test( xla_test( name = "hlo_decomposer_test", srcs = ["hlo_decomposer_test.cc"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":hlo_decomposer_lib", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", - "//xla/tests:hlo_test_base", + "//xla/tests:hlo_pjrt_test_base", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/tools/hlo_decomposer_test.cc b/third_party/xla/xla/tools/hlo_decomposer_test.cc index c38aa8faa53599..022f7f339dda36 100644 --- a/third_party/xla/xla/tools/hlo_decomposer_test.cc +++ b/third_party/xla/xla/tools/hlo_decomposer_test.cc @@ -23,13 +23,13 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/filecheck.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { -class HloDecomposerTest : public HloTestBase { +class HloDecomposerTest : public HloPjRtTestBase { protected: std::unique_ptr GetModule() { absl::string_view kHlo = R"( From 9a6e7233746e33935acf2d96a7d28e724f3b912d Mon Sep 17 00:00:00 2001 From: Laura Pak Date: Thu, 17 Apr 2025 10:59:50 -0700 Subject: [PATCH 0935/1324] TF convert_fake_quant_to_qdq pass migrates to TF quant dialect from lite quant dialect PiperOrigin-RevId: 748739866 --- .../common/tf_quantization_lib/BUILD | 14 ++- .../mlir/quantization/tensorflow/BUILD | 31 +++++++ .../passes/tf_convert_fake_quant_to_qdq.cc | 90 +++++++++++++++++++ .../tensorflow/passes/tf_passes.h | 41 +++++++++ .../tests/tf_convert_fake_quant_to_qdq.mlir | 44 +++++++++ 5 files changed, 216 insertions(+), 4 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD index 607faaab78f3b4..2ce3b743dcd766 100644 --- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD @@ -101,10 +101,16 @@ td_library( gentbl_cc_library( name = "tf_quantization_interfaces_inc_gen", compatible_with = 
get_compatible_with_portable(), - tbl_outs = { - "tf_quantization_interface.h.inc": ["-gen-op-interface-decls"], - "tf_quantization_interface.cc.inc": ["-gen-op-interface-defs"], - }, + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "tf_quantization_interface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "tf_quantization_interface.cc.inc", + ), + ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_quantization.td", deps = [ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 37edee990eb314..b057bcbe074659 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -409,6 +409,36 @@ cc_library( alwayslink = True, ) +cc_library( + name = "tf_passes", + srcs = [ + "passes/tf_convert_fake_quant_to_qdq.cc", + ], + hdrs = [ + "passes/tf_passes.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib:tf_quantization_config", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/utils:temp_fake_quant_utils", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], + # Alwayslink is required for registering the MLIR passes. + # TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat. + alwayslink = True, +) + cc_library( name = "quantize_preprocess", srcs = [ @@ -527,6 +557,7 @@ tf_cc_binary( srcs = ["passes/tf_quant_opt.cc"], deps = [ ":passes", + ":tf_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc new file mode 100644 index 00000000000000..d1a1fd04b30440 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc @@ -0,0 +1,90 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include
+#include
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/IR/Quant.h"  // from @llvm-project
+#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project  // IWYU pragma: keep, for applyPatternsGreedily
+#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
+
+namespace mlir {
+namespace quant {
+namespace {
+
+class TFConvertFakeQuantToQdqPass
+    : public PassWrapper<TFConvertFakeQuantToQdqPass,
+                         OperationPass<func::FuncOp>> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TFConvertFakeQuantToQdqPass)
+
+  StringRef getArgument() const final {
+    // This is the argument used to refer to the pass in
+    // the textual format (on the commandline for example).
+    return "tf-quant-convert-fake-quant-to-qdq";
+  }
+
+  StringRef getDescription() const final {
+    // This is a brief description of the pass.
+    return "Convert Fake Quant op to quant.qcast and quant.dcast pairs";
+  }
+
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert();
+    registry.insert();
+    registry.insert();
+  }
+
+  void runOnOperation() override;
+};
+
+static PassRegistration<TFConvertFakeQuantToQdqPass> pass;
+
+void TFConvertFakeQuantToQdqPass::runOnOperation() {
+  MLIRContext* ctx = &getContext();
+  func::FuncOp func = getOperation();
+
+  if (failed(tf_quant::ConvertFakeQuantOps(
+          func, ctx, /*use_fake_quant_num_bits=*/false))) {
+    func.emitError() << "quant-convert-fake-quant-to-qdq pass failed.";
+    signalPassFailure();
+  }
+
+  // For removing dead FakeQuant* ops
+  RewritePatternSet patterns(ctx);
+  if (failed(applyPatternsGreedily(func, std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+CreateTFConvertFakeQuantToQdqPass() {
+  return std::make_unique<TFConvertFakeQuantToQdqPass>();
+}
+
+}  // namespace quant
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h
new file mode 100644
index 00000000000000..63712df2ffcf39
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h
@@ -0,0 +1,41 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_PASSES_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir { +namespace quant { + +// Converts FakeQuant ops to quant.qcast and quant.dcast (QDQ) pairs. +std::unique_ptr> +CreateTFConvertFakeQuantToQdqPass(); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir new file mode 100644 index 00000000000000..2909f73d4bba6b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir @@ -0,0 +1,44 @@ +// RUN: tf-quant-opt %s -tf-quant-convert-fake-quant-to-qdq | FileCheck %s + +func.func @fakeQuantArgs(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min = -0.1 : f32, max = 0.2 : f32, num_bits = 8 + } : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + func.return %0 : tensor<8x8x8x8xf32> +} +// CHECK: func @fakeQuantArgs +// CHECK-NEXT: %[[q:.*]] = "quantization.qcast"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform> +// CHECK-NEXT: %[[dq:.*]] = "quantization.dcast"(%[[q]]) +// CHECK-NEXT: return %[[dq]] + +func.func @doNotHandleNonEightBitFakeQuant(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min = -0.1 : f32, max = 0.2 : f32, num_bits = 16 + } : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + func.return %0 : tensor<8x8x8x8xf32> +} +// CHECK: func @doNotHandleNonEightBitFakeQuant +// CHECK: tf.FakeQuantWithMinMaxArgs +// CHECK-NOT: "quantization.qcast" + +func.func @fakeQuantVars(%arg0: tensor<3xf32>, %arg1: tensor<4x3xf32>) -> (tensor<3xf32>, tensor<4x3xf32>) { + %cst = "tf.Const"() {value = dense<-0.950868546> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<9.951540e-01> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<[-0.5, -0.4, -0.7]> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_2 = "tf.Const"() {value = dense<[0.5, 0.6, 0.3]> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) { + device = "", narrow_range = false, num_bits = 8 : i64 + } : (tensor<3xf32>, tensor, tensor) -> tensor<3xf32> + %1 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg1, %cst_1, %cst_2) { + device = "", narrow_range = true, num_bits = 8 : i64 + } : (tensor<4x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<4x3xf32> + func.return %0, %1 : tensor<3xf32>, tensor<4x3xf32> +} + +// CHECK: %[[q1:.*]] = "quantization.qcast"(%arg0) +// CHECK-SAME: tensor<3x!quant.uniform> +// CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) +// CHECK: %[[q2:.*]] = "quantization.qcast"(%arg1) +// CHECK-SAME: tensor<4x3x!quant.uniform:f32:1, 
{0.003937007874015748,0.0039370079913477263:-25,0.003937007874015748:51}>> +// CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) +// CHECK: return %[[dq1]], %[[dq2]] From 783bfb1d5d1f9780d4a6f8c17580554ed8a787ff Mon Sep 17 00:00:00 2001 From: Chenguang Wang Date: Thu, 17 Apr 2025 11:06:41 -0700 Subject: [PATCH 0936/1324] Fix mlir cast/dyn_cast/isa in tensorflow Use llvm::cast/dyn_cast/isa since alternatives are deprecated in https://github.com/llvm/llvm-project/pull/135556 PiperOrigin-RevId: 748742657 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 156 ++++++++++---------- 1 file changed, 78 insertions(+), 78 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 87f6754b4cc9da..8becf635922672 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -34,21 +34,21 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" //===----------------------------------------------------------------------===// // TFLite dialect string type - uses the TF string type as implementation //===----------------------------------------------------------------------===// -def TFL_Str : Type()">, +def TFL_Str : Type($_self)">, "TFLite string type">, BuildableType<"getType()">; //===----------------------------------------------------------------------===// // TFLite dialect quint8 type - uses the TF quint8 type as implementation //===----------------------------------------------------------------------===// -def TFL_Quint8 : Type()">, +def TFL_Quint8 : Type($_self)">, "TFLite quint8 type">, BuildableType<"getType()">; //===----------------------------------------------------------------------===// // Type that represents control dependencies //===----------------------------------------------------------------------===// -def TFL_Control: Type()">, "control">, +def TFL_Control: Type($_self)">, "control">, BuildableType<"$_builder.getType()">; @@ -151,10 +151,10 @@ def TFL_StatefulTensor : TypeAlias; // Returns true of operand is none type. class TFL_OperandIsNoneType : - CPred<"$_op.getOperand(" # i # ").getType().isa()">; + CPred<"llvm::isa($_op.getOperand(" # i # ").getType())">; class TFL_OperandIsUnrankedPred : - CPred<"$_op.getOperand(" # n # ").getType().isa()">; + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">; // TODO: Some of these could be generalized and/or moved to more general // location. @@ -162,52 +162,52 @@ class TFL_OperandIsUnrankedPred : class TFL_OperandHasRank : PredOpTrait<"operand " # n # " is " # m # "-D", Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() == " # m>]>>; // Returns true if the n-th operand is ranked and has rank dim. class TFL_OperandHasKnownRank : And<[ - CPred<"$_op.getOperand(" # n # ").getType().isa()">, - CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() == " + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">, + CPred<"llvm::cast($_op.getOperand(" # n # ").getType()).getRank() == " # dim>]>; // True if operand n is ranked and has a rank > dim. 
class TFL_OperandIsRankedAndHasDimPred : And<[ - CPred<"$_op.getOperand(" # n # ").getType().isa()">, - CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() > " + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">, + CPred<"llvm::cast($_op.getOperand(" # n # ").getType()).getRank() > " # dim>]>; // Returns true if the n-th operand is ranked and has a dimension length = size // at the rank dim. class TFL_OperandDimEquals : And<[ TFL_OperandIsRankedAndHasDimPred, - CPred<"$_op.getOperand(" # n # ").getType().cast()" + CPred<"llvm::cast($_op.getOperand(" # n # ").getType())" ".getShape()[" # dim # " ] == " # size>]>; // Returns true if the n-th operand is ranked and has a dimension length <= // size at the rank dim. class TFL_OperandDimIsAtMost : And<[ TFL_OperandIsRankedAndHasDimPred, - CPred<"$_op.getOperand(" # n # ").getType().cast()" + CPred<"llvm::cast($_op.getOperand(" # n # ").getType())" ".getShape()[" # dim # " ] <= " # size>]>; // Returns true if the n-th operand has unknown rank or at least rank m. class TFL_OperandHasAtleastRank : PredOpTrait<"operand " # n # " is " # m # "-D", - Or<[CPred<"$_op.getOperand(" # n # ").getType().isa()">, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() >= " # m>]>>; + Or<[CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">, + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() >= " # m>]>>; class TFL_OperandRankEquals1DimOfOperand : PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size", Or<[TFL_OperandIsUnrankedPred, TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # y # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getRank() == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[0]">]>>; + CPred<"!llvm::cast($_op.getOperand(" # y # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getRank() == " + "llvm::cast($_op.getOperand(" # y # + ").getType()).getShape()[0]">]>>; class TFL_Operand0DOr1ElementTensor : PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element", @@ -219,14 +219,14 @@ class TFL_Operand0DOr1ElementTensor : class TFL_OperandsHaveSameDims : Or<[TFL_OperandIsUnrankedPred, TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # x # - ").getType().cast().hasStaticShape()">, - CPred<"!$_op.getOperand(" # y # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getShape()[" # i # "] == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[" # j # "]">]>; + CPred<"!llvm::cast($_op.getOperand(" # x # + ").getType()).hasStaticShape()">, + CPred<"!llvm::cast($_op.getOperand(" # y # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getShape()[" # i # "] == " + "llvm::cast($_op.getOperand(" # y # + ").getType()).getShape()[" # j # "]">]>; class TFL_OperandsHaveSameDimsTrait : PredOpTrait<"dim " # i # " of operand " # x # " equals to dim " # j # @@ -238,14 +238,14 @@ class TFL_OperandsHaveSameDimsTrait : class TFL_NumElementsEqualsDim : Or<[TFL_OperandIsUnrankedPred, TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # x # - ").getType().cast().hasStaticShape()">, - CPred<"!$_op.getOperand(" # y # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getNumElements() == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[" # j # "]">]>; + CPred<"!llvm::cast($_op.getOperand(" # x # + 
").getType()).hasStaticShape()">, + CPred<"!llvm::cast($_op.getOperand(" # y # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getNumElements() == " + "llvm::cast($_op.getOperand(" # y # + ").getType()).getShape()[" # j # "]">]>; class TFL_NumElementsEqualsDimTrait : PredOpTrait<"operand " # x # " has num of elements equals to dim " # j # @@ -255,10 +255,10 @@ class TFL_NumElementsEqualsDimTrait : // Return true if number of elements of x-th operand equals to n. class TFL_NumElements : Or<[TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # x # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getNumElements() == " # n>]>; + CPred<"!llvm::cast($_op.getOperand(" # x # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getNumElements() == " # n>]>; class TFL_NumElementsTrait : PredOpTrait<"operand " # x # " has num of elements equals to " # n, @@ -268,16 +268,16 @@ class TFL_NumElementsTrait : // when used as element types. class TFL_TFTypesWithSameBits : And<[ - Or<[CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getResult(" # i # ")))">, CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isUnsignedInteger(" # num # ")">]>, - Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getOperand(" # j # ")))">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; class TFL_TFOperandTypesWithSameBits : And<[ - Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getOperand(" # i # ")))">, CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>, - Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getOperand(" # j # ")))">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; class TFL_OperandIsNoneOrHasRank : @@ -285,21 +285,21 @@ class TFL_OperandIsNoneOrHasRank : Or<[ TFL_OperandIsNoneType, TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() == " # m>]>>; class TFL_OperandIsNoneOrHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[ TFL_OperandIsNoneType, TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() <= " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() <= " # m>]>>; class TFL_OperandHasRankAtMostPred : Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() <= " # m>]>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() <= " # m>]>; class TFL_OperandHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", @@ -310,54 +310,54 @@ class TFL_OperandHasRankAtMost : class TFL_TransposeOperandHasEffectiveRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[TFL_OperandIsUnrankedPred, - CPred<"GetSqueezedShape($_op.getOperand(" # n # - ")).cast().size() <= " # m>]>>; + CPred<"llvm::cast(GetSqueezedShape($_op.getOperand(" # n # + "))).size() <= " # m>]>>; class TFL_OperandHasRankAtLeast : PredOpTrait<"operand " # n # " is at least " # m # "-D", Or<[TFL_OperandIsUnrankedPred, - 
CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() >= " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() >= " # m>]>>; class TFL_OperandHasRankRange : PredOpTrait<"operand " # n # " has rank range [" # x # ", " # y # "]", Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() " - ">= " # x # " && $_op.getOperand(" # n # ").getType().cast()." + CPred<"llvm::cast($_op.getOperand(" # n # ").getType()).getRank() " + ">= " # x # " && llvm::cast($_op.getOperand(" # n # ").getType())." "getRank() <= " # y>]>>; def TFL_FloatNonNegative : AttrConstraint< - CPred<"$_self.isa() && " - "!$_self.cast().getValue().isNegative()">, + CPred<"llvm::isa($_self) && " + "!llvm::cast($_self).getValue().isNegative()">, "whose value is non-negative">; def TFL_BoolTrue : AttrConstraint< - CPred<"$_self.isa() && $_self.cast().getValue()">, + CPred<"llvm::isa($_self) && llvm::cast($_self).getValue()">, "whose value is true">; def TFL_BoolFalse : AttrConstraint< - CPred<"$_self.isa() && !$_self.cast().getValue()">, + CPred<"llvm::isa($_self) && !llvm::cast($_self).getValue()">, "whose value is false">; class TFL_StringEqualsTo : AttrConstraint< - CPred<"$_self.cast().getValue() == \"" # value # "\"">, + CPred<"llvm::cast($_self).getValue() == \"" # value # "\"">, "whose value equals to '" # value # "'">; // Ensures the array attribute's size is within the given maximum size. class TFL_ArrayMaxCount : AttrConstraint< - CPred<"$_self.isa() && $_self.cast().size() <= " # n>, + CPred<"llvm::isa($_self) && llvm::cast($_self).size() <= " # n>, "whose size is at most " # n>; // Ensures the given integer attribute has the given value. class TFL_IntEqualsTo : AttrConstraint< - CPred<"$_self.isa() && " - "$_self.cast().getInt() == " # n>, + CPred<"llvm::isa($_self) && " + "llvm::cast($_self).getInt() == " # n>, "whose value is " # n>; // Ensures the given LSTMKernelType attribute has the given value. class TFL_LSTMKernelTypeEqualsTo : AttrConstraint< - CPred<"$_self.isa() && " - "$_self.cast().getValue() == " # value>, + CPred<"llvm::isa($_self) && " + "llvm::cast($_self).getValue() == " # value>, "whose value is " # value>; // This is a quantization-aware version of TCresVTEtIsSameAsOp @@ -769,11 +769,11 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [ let hasOptions = 1; DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ - return getResult().getType().cast().getElementType(). + return llvm::cast(getResult().getType()).getElementType(). cast().getWidth() > 32 ? tflite::TensorType_INT64 : tflite::TensorType_INT32; }], [{ - TypeAttr::get(getResult().getType().cast().getElementType()) + TypeAttr::get(llvm::cast(getResult().getType()).getElementType()) }]>; } @@ -801,11 +801,11 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [ let hasOptions = 1; DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ - return getResult().getType().cast().getElementType(). + return llvm::cast(getResult().getType()).getElementType(). cast().getWidth() > 32 ? 
tflite::TensorType_INT64 : tflite::TensorType_INT32; }], [{ - TypeAttr::get(getResult().getType().cast().getElementType()) + TypeAttr::get(llvm::cast(getResult().getType()).getElementType()) }]>; } @@ -3162,7 +3162,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [ let results = (outs TFL_TensorOf<[I32, I64]>:$output); DerivedTypeAttr out_type = DerivedTypeAttr<[{ - return getResult().getType().cast().getElementType(); + return llvm::cast(getResult().getType()).getElementType(); }]>; let hasOptions = 1; @@ -3935,9 +3935,9 @@ def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [ TFL_OperandHasRankAtMost<2, 1>, PredOpTrait<"the first operand should have a rank <= 2, when its rank is 2 and has static shape, the second dim should be <= 4", Or<[TFL_OperandIsUnrankedPred<0>, - CPred<"$_op.getOperand(0).getType().cast().getRank() <= 1">, - CPred<"$_op.getOperand(0).getType().cast().getRank() == 2 && !$_op.getOperand(0).getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(0).getType().cast().getRank() == 2 && $_op.getOperand(0).getType().cast().getShape()[1] <= 4">]>>]> { + CPred<"llvm::cast($_op.getOperand(0).getType()).getRank() <= 1">, + CPred<"llvm::cast($_op.getOperand(0).getType()).getRank() == 2 && !llvm::cast($_op.getOperand(0).getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(0).getType()).getRank() == 2 && llvm::cast($_op.getOperand(0).getType()).getShape()[1] <= 4">]>>]> { let summary = "Converts a sparse representation into a dense tensor."; let description = [{ @@ -4109,11 +4109,11 @@ value of `input` in the unique output `output`. In other words: ); DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{ - return getResult(1).getType().cast().getElementType(). + return llvm::cast(getResult(1).getType()).getElementType(). cast().getWidth() > 32 ? tflite::TensorType_INT64 : tflite::TensorType_INT32; }], [{ - TypeAttr::get(getResult(1).getType().cast().getElementType()) + TypeAttr::get(llvm::cast(getResult(1).getType()).getElementType()) }]>; let hasOptions = 1; From 8606b22676b0098b6bf70b5dcb4d4c476d6bf7ae Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Apr 2025 11:15:52 -0700 Subject: [PATCH 0937/1324] Refactor the interface to support writing model runtime info to proto directly. 
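
For illustration, a minimal usage sketch of the new in-memory overload added
below (hedged: `interpreter` stands for an already-built tflite::Interpreter;
only the two GenerateModelRuntimeInfo() signatures shown in this diff are
assumed):

  #include <string>

  #include "tensorflow/lite/profiling/model_runtime_info.h"
  #include "tensorflow/lite/profiling/proto/model_runtime_info.pb.h"

  // Populates the proto directly instead of serializing it to a file.
  tflite::profiling::ModelRuntimeDetails details;
  if (tflite::profiling::GenerateModelRuntimeInfo(interpreter, details) !=
      kTfLiteOk) {
    // Handle the failure; the file-based overload below logs and returns
    // the same status.
  }
  // The caller now owns the proto and can serialize it as needed.
  const std::string serialized = details.SerializeAsString();

The existing file-writing overload keeps its signature and becomes a thin
wrapper around this proto-filling variant.
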
PiperOrigin-RevId: 748745974
---
 tensorflow/lite/profiling/model_runtime_info.cc | 16 +++++++++++++---
 tensorflow/lite/profiling/model_runtime_info.h  |  6 ++++++
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/tensorflow/lite/profiling/model_runtime_info.cc b/tensorflow/lite/profiling/model_runtime_info.cc
index f12f3fdbfe3b7f..915dbec0fc5539 100644
--- a/tensorflow/lite/profiling/model_runtime_info.cc
+++ b/tensorflow/lite/profiling/model_runtime_info.cc
@@ -162,10 +162,10 @@ TfLiteStatus TfliteNodeToNode(const TfLiteNode& node,
   return kTfLiteOk;
 }
 }  // namespace
-TfLiteStatus GenerateModelRuntimeInfo(const tflite::Interpreter& interpreter,
-                                      absl::string_view output_file_path) {
-  tflite::profiling::ModelRuntimeDetails model_runtime_details;
+TfLiteStatus GenerateModelRuntimeInfo(
+    const tflite::Interpreter& interpreter,
+    ModelRuntimeDetails& model_runtime_details) {
   const size_t num_subgraphs = interpreter.subgraphs_size();
   for (int i = 0; i < num_subgraphs; ++i) {
@@ -224,7 +224,17 @@ TfLiteStatus GenerateModelRuntimeInfo(const tflite::Interpreter& interpreter,
     runtime_subgraph->mutable_execution_plan()->Add(
         subgraph.execution_plan().begin(), subgraph.execution_plan().end());
   }
+  return kTfLiteOk;
+}
+
+TfLiteStatus GenerateModelRuntimeInfo(const tflite::Interpreter& interpreter,
+                                      absl::string_view output_file_path) {
+  ModelRuntimeDetails model_runtime_details;
+  auto status = GenerateModelRuntimeInfo(interpreter, model_runtime_details);
+  if (status != kTfLiteOk) {
+    TFLITE_LOG(ERROR) << "Failed to generate model runtime info: " << status;
+    return status;
+  }
   std::ofstream ofs(std::string(output_file_path),
                     std::ios::out | std::ios::binary);
   if (ofs.good()) {
diff --git a/tensorflow/lite/profiling/model_runtime_info.h b/tensorflow/lite/profiling/model_runtime_info.h
index a88b80d22814fb..70f1dae41e2cc5 100644
--- a/tensorflow/lite/profiling/model_runtime_info.h
+++ b/tensorflow/lite/profiling/model_runtime_info.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "absl/strings/string_view.h"
 #include "tensorflow/lite/core/interpreter.h"
+#include "tensorflow/lite/profiling/proto/model_runtime_info.pb.h"
 
 namespace tflite {
 namespace profiling {
@@ -26,6 +27,11 @@ namespace profiling {
 // the given output file path.
 TfLiteStatus GenerateModelRuntimeInfo(const Interpreter &interpreter,
                                       absl::string_view output_file_path);
+
+// Generates a ModelRuntimeDetails proto for the given interpreter and writes
+// it into the given model_runtime_details proto.
+TfLiteStatus GenerateModelRuntimeInfo(
+    const Interpreter &interpreter, ModelRuntimeDetails &model_runtime_details);
 }  // namespace profiling
 }  // namespace tflite

From 110c6bbdfd429218e694963b3a931fa5c11bc67e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 17 Apr 2025 12:41:41 -0700
Subject: [PATCH 0938/1324] Add a test for gpusolver

PiperOrigin-RevId: 748775329
---
 .../xla/xla/service/gpu/transforms/BUILD      | 18 +++-
 .../gpu/transforms/gpusolver_rewriter_test.cc | 62 +++++++++++++++++++
 2 files changed, 79 insertions(+), 1 deletion(-)
 create mode 100644 third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter_test.cc

diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD
index 94cf69298e5da7..ec16816bdf7dae 100644
--- a/third_party/xla/xla/service/gpu/transforms/BUILD
+++ b/third_party/xla/xla/service/gpu/transforms/BUILD
@@ -2106,7 +2106,6 @@ xla_cc_test(
     ],
 )
 
-# TODO(b/358278858): Currently lacking test coverage.
 cc_library(
     name = "gpusolver_rewriter",
     srcs = if_gpu_is_configured(["gpusolver_rewriter.cc"]),
@@ -2141,6 +2140,23 @@ cc_library(
     ],
 )
 
+xla_test(
+    name = "gpusolver_rewriter_test",
+    srcs = ["gpusolver_rewriter_test.cc"],
+    backends = ["gpu"],
+    tags = ["cuda-only"],
+    deps = [
+        ":gpusolver_rewriter",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
+        "//xla/hlo/testlib:pattern_matcher_gmock",
+        "//xla/service:pattern_matcher",
+        "//xla/stream_executor/cuda:cuda_solver_context",  # fixdeps: keep
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_library(
     name = "horizontal_input_fusion",
     srcs = ["horizontal_input_fusion.cc"],
diff --git a/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter_test.cc
new file mode 100644
index 00000000000000..a7b4800b6734dc
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter_test.cc
@@ -0,0 +1,62 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gpusolver_rewriter.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
+#include "xla/hlo/testlib/pattern_matcher_gmock.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/stream_executor/cuda/cuda_solver_context.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class GpusolverRewriterTest : public HloHardwareIndependentTestBase {
+ public:
+  GpusolverRewriter gpusolver_rewriter_{
+      stream_executor::CudaSolverContext::Create};
+};
+
+TEST_F(GpusolverRewriterTest, CholeskyTest) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule CholeskyTest
+
+    ENTRY entry_computation {
+      input = f32[1,256,256] parameter(0)
+      ROOT decomp = f32[1,256,256] cholesky(input)
+    }
+)")
+                    .value();
+
+  EXPECT_TRUE(gpusolver_rewriter_.Run(module.get()).value());
+
+  const HloInstruction* entry_root =
+      module->entry_computation()->root_instruction();
+  ASSERT_THAT(
+      entry_root,
+      GmockMatch(m::Select(
+          m::Broadcast(
+              m::Compare(m::GetTupleElement(), m::Broadcast(m::Constant()))),
+          m::GetTupleElement(m::CustomCall()), m::Broadcast(m::Constant()))));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla

From ba382d828f7542c90b19922081132526838fb4d5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 17 Apr 2025 13:17:17 -0700
Subject: [PATCH 0939/1324] Change MSA to use absl::HashOf() in computing instruction fingerprints across HLO live range.
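
For context, absl::HashOf() hashes any type that defines an AbslHashValue()
overload, so the old concatenated-string fingerprint collapses into a single
integer. A standalone sketch of the pattern (OpFingerprint is a simplified
stand-in, not the MsaInstructionFingerprint class added below):

  #include <cstdint>
  #include <string>
  #include <utility>

  #include "absl/hash/hash.h"

  struct OpFingerprint {
    std::string opcode;
    int64_t operand_count;

    template <typename H>
    friend H AbslHashValue(H h, const OpFingerprint& fp) {
      // Folds each field into the hash state; equal fingerprints hash to the
      // same value within a process, so a hash-keyed map groups repeats.
      return H::combine(std::move(h), fp.opcode, fp.operand_count);
    }
  };

  uint64_t fingerprint = absl::HashOf(OpFingerprint{"fusion", 3});
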
PiperOrigin-RevId: 748787512
---
 .../xla/service/memory_space_assignment/BUILD |  1 +
 .../memory_space_assignment/algorithm.cc      | 18 ++++----------
 .../memory_space_assignment/algorithm.h       | 24 +++++++++++++++++--
 3 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD
index 43d988f00fc870..c9301f59aff3cb 100644
--- a/third_party/xla/xla/service/memory_space_assignment/BUILD
+++ b/third_party/xla/xla/service/memory_space_assignment/BUILD
@@ -596,6 +596,7 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/functional:any_invocable",
+        "@com_google_absl//absl/hash",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/memory",
diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
index e38b8ee4f3298e..fa8240d5ec3f6d 100644
--- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
@@ -38,6 +38,7 @@ limitations under the License.
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/functional/any_invocable.h"
+#include "absl/hash/hash.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
@@ -457,18 +458,7 @@ MsaAlgorithm::MsaAlgorithm(AllocationSequence* allocations,
             options.cost_analysis->GetInstructionElapsed(*inst);
       if (options_.use_repeated_instance_for_preferred_prefetch_time ||
           options_.memory_bound_loop_optimizer_options.enabled()) {
-        std::string fingerprint;
-        absl::StrAppend(&fingerprint, inst->shape().ToString(), " ",
-                        HloOpcodeString(inst->opcode()), "(");
-        for (int operand_idx = 0; operand_idx < inst->operands().size();
-             ++operand_idx) {
-          if (operand_idx > 0) {
-            absl::StrAppend(&fingerprint, ", ");
-          }
-          absl::StrAppend(&fingerprint,
-                          inst->operand(operand_idx)->shape().ToString());
-        }
-        absl::StrAppend(&fingerprint, ")");
+        uint64_t fingerprint = absl::HashOf(MsaInstructionFingerprint(inst));
         fingerprint_map_[inst] = fingerprint;
         repeated_inst_map_[fingerprint].push_back(inst);
       }
@@ -1137,7 +1127,7 @@ bool AreOperandCandidatesCompatible(int loop_size_candidate,
 }  // namespace
 
 void MsaAlgorithm::IdentifyAndOptimizeMemoryBoundLoops() {
-  absl::flat_hash_map<std::string, int> fingerprint_schedule_map;
+  absl::flat_hash_map<uint64_t, int> fingerprint_schedule_map;
   const auto& instruction_sequence =
       hlo_live_range_.flattened_instruction_sequence().instructions();
   // The minimum and maximum loop sizes that we consider.
@@ -1185,7 +1175,7 @@ void MsaAlgorithm::IdentifyAndOptimizeMemoryBoundLoops() {
             << " fingerprint: "
            << (fingerprint_it == fingerprint_map_.end()
                    ? "none"
-                   : fingerprint_it->second);
+                   : std::to_string(fingerprint_it->second));
   }
   VLOG(3) << "Loop size candidate: " << loop_size_candidate;
   if (loop_size_candidate == -1) {
diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h
index 7d71b245be81d4..a45a368bc79d57 100644
--- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h
+++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h
@@ -281,6 +281,26 @@ class AsynchronousCopyResource {
   std::vector<std::pair<int64_t, int64_t>> delay_changes_;
 };
 
+// Helper class to compute a minimal fingerprint of an HloInstruction and its
+// operand shapes for MSA.
+class MsaInstructionFingerprint {
+ public:
+  explicit MsaInstructionFingerprint(const HloInstruction* instruction)
+      : inst_(instruction) {};
+
+  template <typename H>
+  friend H AbslHashValue(H h, const MsaInstructionFingerprint& fp) {
+    for (const HloInstruction* operand : fp.inst_->operands()) {
+      h = H::combine(std::move(h), operand->shape());
+    }
+    return H::combine(std::move(h), fp.inst_->opcode(),
+                      fp.inst_->operand_count(), fp.inst_->shape());
+  }
+
+ private:
+  const HloInstruction* inst_;
+};
+
 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
 // maximum size.
 //
@@ -1110,10 +1130,10 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap<HloValue> {
   int64_t memory_pressure_ = 0;
   int64_t next_async_copy_id_ = 0;
 
   // Fingerprint cache.
-  absl::flat_hash_map<const HloInstruction*, std::string> fingerprint_map_;
+  absl::flat_hash_map<const HloInstruction*, uint64_t> fingerprint_map_;
 
   // Vector of repeated instructions (that have the same fingerprint) indexed by
   // fingerprint.
-  absl::flat_hash_map<std::string, std::vector<const HloInstruction*>>
+  absl::flat_hash_map<uint64_t, std::vector<const HloInstruction*>>
       repeated_inst_map_;
 
   // Loop-optimized allocations found by MemoryBoundLoopOptimizer. These

From 3168fc534b39f7fe525fd2f0a9df532199c9ce81 Mon Sep 17 00:00:00 2001
From: Francisco Unda
Date: Thu, 17 Apr 2025 13:18:00 -0700
Subject: [PATCH 0940/1324] Remove old genrules in favor of direct dep on jni.

PiperOrigin-RevId: 748787777
---
 tensorflow/java/src/main/native/BUILD | 70 ++++++++++-----------------
 tensorflow/lite/java/jni/BUILD        | 65 ++++++++++---------------
 tensorflow/opensource_only.files      |  2 -
 3 files changed, 51 insertions(+), 86 deletions(-)

diff --git a/tensorflow/java/src/main/native/BUILD b/tensorflow/java/src/main/native/BUILD
index 527288dc5348ad..f7fd51a7e9ccb5 100644
--- a/tensorflow/java/src/main/native/BUILD
+++ b/tensorflow/java/src/main/native/BUILD
@@ -4,29 +4,37 @@
 
 load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_cuda_library")
 
-package(default_visibility = [
-    "//tensorflow/java:__pkg__",
-    "//tensorflow/tools/android/inference_interface:__pkg__",
-])
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = [
+        "//tensorflow/java:__pkg__",
+        # TODO(ashankar): Temporary hack for the Java API and
+        # //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_inference_jni
+        # to co-exist in a single shared library. However, the hope is that
+        # //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_jni can be
+        # removed once the Java API provides feature parity with it.
+        "//tensorflow/tools/android/inference_interface:__pkg__",
+    ],
+    licenses = ["notice"],
+)
 
-licenses(["notice"])  # Apache 2.0
+filegroup(
+    name = "native_srcs",
+    srcs = glob([
+        "*.cc",
+        "*.h",
+    ]),
+    visibility = ["//visibility:public"],
+)
 
 tf_cuda_library(
     name = "native",
-    srcs = glob(["*.cc"]) + select({
-        # The Android toolchain makes "jni.h" available in the include path.
-        # For non-Android toolchains, generate jni.h and jni_md.h.
-        "//tensorflow:android": [],
-        "//conditions:default": [
-            ":jni.h",
-            ":jni_md.h",
-        ],
-    }),
+    srcs = glob(["*.cc"]),
     hdrs = glob(["*.h"]),
     copts = tf_copts(),
-    includes = select({
-        "//tensorflow:android": [],
-        "//conditions:default": ["."],
+    features = select({
+        "//tensorflow:android": ["-layering_check"],
+        "//conditions:default": [],
     }),
     deps = select({
         "//tensorflow:android": [
@@ -38,34 +46,8 @@ tf_cuda_library(
             "//tensorflow/core:all_kernels",
             "//tensorflow/core:direct_session",
             "//tensorflow/core:ops",
+            "@bazel_tools//tools/jdk:jni",
         ],
     }),
     alwayslink = 1,
 )
-
-# Silly rules to make
-# #include <jni.h>
-# in the source headers work
-# (in combination with the "includes" attribute of the tf_cuda_library rule
-# above. Not needed when using the Android toolchain).
-#
-# Inspired from:
-# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD
-# but hopefully there is a simpler alternative to this.
-genrule(
-    name = "copy_jni_h",
-    srcs = ["@bazel_tools//tools/jdk:jni_header"],
-    outs = ["jni.h"],
-    cmd = "cp -f $< $@",
-)
-
-genrule(
-    name = "copy_jni_md_h",
-    srcs = select({
-        "//tensorflow:windows": ["@bazel_tools//tools/jdk:jni_md_header-windows"],
-        "//tensorflow:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
-        "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
-    }),
-    outs = ["jni_md.h"],
-    cmd = "cp -f $< $@",
-)
diff --git a/tensorflow/lite/java/jni/BUILD b/tensorflow/lite/java/jni/BUILD
index 137ca32b0489d5..9638ee99882514 100644
--- a/tensorflow/lite/java/jni/BUILD
+++ b/tensorflow/lite/java/jni/BUILD
@@ -1,48 +1,33 @@
-package(default_visibility = ["//tensorflow/lite:__subpackages__"])
+load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
 
-licenses(["notice"])  # Apache 2.0
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//tensorflow/lite:__subpackages__"],
+    licenses = ["notice"],
+)
 
-# Helper target for exposing JNI headers across multiple platforms.
-cc_library(
+# We need special handling for JNI inclusion for Android. Rather than duplicating this logic
+# for every target that uses JNI, we use a single proxy target that
+# encapsulates it.
+alias(
     name = "jni",
-    hdrs = select({
-        # The Android toolchain makes "jni.h" available in the include path.
-        # For non-Android toolchains, generate jni.h and jni_md.h.
-        "//tensorflow:android": [],
-        "//conditions:default": [
-            ":jni.h",
-            ":jni_md.h",
-        ],
-    }),
-    includes = select({
-        "//tensorflow:android": [],
-        "//conditions:default": ["."],
+    actual = select({
+        # The Android toolchain makes <jni.h> available in the system include
+        # path.
+        # Aliases need to resolve to a single target however, so alias to an
+        # empty library instead.
+        # (Making this target a cc_library with empty deps for the Android case
+        # doesn't work, because go/cpp-features#layering-check requires targets
+        # to _directly_ depend on libraries they include, and cc_library doesn't
+        # have any direct equivalent to java_library's 'export' attribute).
+        "//tensorflow:android": ":empty",
+        # For non-Android toolchains, depend on the JDK JNI headers.
+        "//conditions:default": "@bazel_tools//tools/jdk:jni",
     }),
     visibility = ["//visibility:public"],
 )
 
-# Silly rules to make
-# #include <jni.h>
-# in the source headers work
-# (in combination with the "includes" attribute of the tf_cuda_library rule
-# above. Not needed when using the Android toolchain).
-#
-# Inspired from:
-# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD
-# but hopefully there is a simpler alternative to this.
-genrule(
-    name = "copy_jni_h",
-    srcs = ["@bazel_tools//tools/jdk:jni_header"],
-    outs = ["jni.h"],
-    cmd = "cp -f $< $@",
-)
-
-genrule(
-    name = "copy_jni_md_h",
-    srcs = select({
-        "//tensorflow:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
-        "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
-    }),
-    outs = ["jni_md.h"],
-    cmd = "cp -f $< $@",
+cc_library(
+    name = "empty",
+    compatible_with = get_compatible_with_portable(),
 )
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index 3281c096528b7e..e3aaeeddca20de 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -53,7 +53,6 @@ tf_staging/tensorflow/go/stream_executor/BUILD:
 tf_staging/tensorflow/go/tsl/profiler/protobuf/BUILD:
 tf_staging/tensorflow/go/tsl/protobuf/BUILD:
 tf_staging/tensorflow/java/README.md:
-tf_staging/tensorflow/java/src/main/native/BUILD:
 tf_staging/tensorflow/lite/acceleration/configuration/c/delegate_plugin.h:
 tf_staging/tensorflow/lite/acceleration/configuration/c/gpu_plugin.h:
 tf_staging/tensorflow/lite/acceleration/configuration/c/nnapi_plugin.h:
@@ -112,7 +111,6 @@ tf_staging/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rule
 tf_staging/tensorflow/lite/interpreter.h:
 tf_staging/tensorflow/lite/interpreter_builder.h:
 tf_staging/tensorflow/lite/ios/BUILD:
-tf_staging/tensorflow/lite/java/jni/BUILD:
 tf_staging/tensorflow/lite/kernels/builtin_op_kernels.h:
 tf_staging/tensorflow/lite/kernels/register.h:
 tf_staging/tensorflow/lite/lib_package/BUILD:

From 7ca72402d56908867c8e90f0ed5c418c1f6b5630 Mon Sep 17 00:00:00 2001
From: Daniel Chen
Date: Thu, 17 Apr 2025 14:40:41 -0700
Subject: [PATCH 0941/1324] Make profile data optional and skip rendering if not present.
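
For callers that do have profile data, a hedged sketch of plugging it in
(MapBackedOpMetricGetter is a hypothetical example class, not part of this
patch; the patch itself only defines the abstract OpMetricGetter, whose
GetOpTimePs() returns absl::StatusOr<uint64_t> as the mock below suggests):

  #include <cstdint>
  #include <string>
  #include <utility>

  #include "absl/container/flat_hash_map.h"
  #include "absl/status/status.h"
  #include "absl/status/statusor.h"
  #include "absl/strings/string_view.h"
  #include "xla/hlo/tools/hlo_diff/render/op_metric_getter.h"

  // Serves op times from an in-memory map, e.g. parsed from a profile dump.
  class MapBackedOpMetricGetter : public xla::hlo_diff::OpMetricGetter {
   public:
    explicit MapBackedOpMetricGetter(
        absl::flat_hash_map<std::string, uint64_t> op_time_ps)
        : op_time_ps_(std::move(op_time_ps)) {}

    absl::StatusOr<uint64_t> GetOpTimePs(
        absl::string_view op_name) const override {
      if (auto it = op_time_ps_.find(op_name); it != op_time_ps_.end()) {
        return it->second;
      }
      return absl::NotFoundError("no profile entry for op");
    }

   private:
    absl::flat_hash_map<std::string, uint64_t> op_time_ps_;
  };

Callers without profile data use the new three-argument overload,
RenderHtml(diff_result, diff_summary, html), and the Profile Metrics Diff
section is simply skipped.
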
PiperOrigin-RevId: 748816708
---
 .../xla/hlo/tools/hlo_diff/hlo_diff_main.cc   |  5 +-
 .../xla/xla/hlo/tools/hlo_diff/render/BUILD   | 26 +++++-
 .../render/hlo_gumgraph_html_renderer.cc      | 79 +++++++++++--------
 .../render/hlo_gumgraph_html_renderer.h       | 36 +++++----
 .../render/hlo_gumgraph_html_renderer_test.cc | 71 +++++++++++++++++
 .../tools/hlo_diff/render/op_metric_getter.h  | 40 ++++++++++
 6 files changed, 205 insertions(+), 52 deletions(-)
 create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer_test.cc
 create mode 100644 third_party/xla/xla/hlo/tools/hlo_diff/render/op_metric_getter.h

diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc
index 89c4c734ce8172..837756d514a607 100644
--- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc
@@ -174,10 +174,7 @@ absl::Status RunGumgraphDiff(HloModule& first_module, HloModule& second_module,
   std::string html_output = opts.render_options.html_output;
   if (!html_output.empty()) {
     std::ostringstream html;
-    RenderHtml(
-        diff, diff_summary, nullptr,
-        [](absl::string_view op_name) { return std::nullopt; },
-        [](absl::string_view op_name) { return std::nullopt; }, html);
+    RenderHtml(diff, diff_summary, html);
     TF_RETURN_IF_ERROR(
         tsl::WriteStringToFile(tsl::Env::Default(), html_output, html.str()));
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD
index b451a0a2615c71..1d38111affefe7 100644
--- a/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/BUILD
@@ -65,12 +65,13 @@ cc_library(
     deps = [
         ":graph_url_generator",
         ":hlo_gumgraph_renderer_util",
+        ":op_metric_getter",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/tools/hlo_diff:hlo_diff_result",
         "//xla/hlo/tools/hlo_diff:hlo_diff_summary",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
@@ -78,6 +79,20 @@ cc_library(
     ],
 )
 
+xla_cc_test(
+    name = "hlo_gumgraph_html_renderer_test",
+    srcs = ["hlo_gumgraph_html_renderer_test.cc"],
+    deps = [
+        ":hlo_gumgraph_html_renderer",
+        ":op_metric_getter",
+        "//xla/hlo/tools/hlo_diff:hlo_diff_result",
+        "//xla/hlo/tools/hlo_diff:hlo_diff_summary",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "graph_url_generator",
     hdrs = ["graph_url_generator.h"],
@@ -86,3 +101,12 @@ cc_library(
         "@com_google_absl//absl/strings:string_view",
     ],
 )
+
+cc_library(
+    name = "op_metric_getter",
+    hdrs = ["op_metric_getter.h"],
+    deps = [
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc
index 36b3cf4da50053..0d0c803779bb37 100644
--- a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.cc
@@ -17,7 +17,6 @@
 #include <algorithm>
 #include <cstdint>
 #include <functional>
-#include <optional>
 #include <sstream>
 #include <string>
 #include <utility>
@@ -25,6 +24,7 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
@@ -36,6 +36,7 @@
 #include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
 #include "xla/hlo/tools/hlo_diff/render/graph_url_generator.h"
 #include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_renderer_util.h"
+#include "xla/hlo/tools/hlo_diff/render/op_metric_getter.h"
 
 namespace xla {
 namespace hlo_diff {
@@ -196,7 +197,8 @@ std::string PrintDetails(absl::string_view summary, absl::string_view content) {
 
 // Prints a link to the given url.
 std::string PrintLink(absl::string_view text, absl::string_view url) {
-  return absl::StrFormat("<a href=\"%s\">%s</a>", url, text);
+  return absl::StrFormat(R"html(<a href="%s">%s</a>)html", url,
+                         text);
 }
 
 // Prints a html block with a header.
@@ -495,11 +497,12 @@ std::string PrintUnchangedInstructions(
 
 std::string PrintUnmatchedMetricsDiff(
    const absl::flat_hash_set<const HloInstruction*>& instructions,
-    GetOpMetricFn get_op_metrics, GraphUrlGenerator* url_generator) {
+    const OpMetricGetter& op_metric_getter, GraphUrlGenerator* url_generator) {
   std::vector<std::pair<const HloInstruction*, double>> sorted_metrics_diff;
   for (const HloInstruction* inst : instructions) {
-    if (auto metric = get_op_metrics(inst->name()); metric.has_value()) {
-      sorted_metrics_diff.push_back({inst, static_cast<double>(*metric)});
+    if (auto time_ps = op_metric_getter.GetOpTimePs(inst->name());
+        time_ps.ok()) {
+      sorted_metrics_diff.push_back({inst, static_cast<double>(*time_ps)});
     }
   }
@@ -517,19 +520,23 @@ std::string PrintUnmatchedMetricsDiff(
 
 std::string PrintMatchedMetricsDiff(
    const absl::flat_hash_map<const HloInstruction*, const HloInstruction*>&
        instructions,
-    GetOpMetricFn left_op_metrics, GetOpMetricFn right_op_metrics,
+    const OpMetricGetter& left_op_metric_getter,
+    const OpMetricGetter& right_op_metric_getter,
     GraphUrlGenerator* url_generator) {
   std::vector<std::pair<std::pair<const HloInstruction*, const HloInstruction*>, double>>
      sorted_metrics_diff;
   for (const auto& [left_inst, right_inst] : instructions) {
-    auto left_metric = left_op_metrics(left_inst->name());
-    auto right_metric = right_op_metrics(right_inst->name());
-    if (left_metric.has_value() && right_metric.has_value()) {
-      sorted_metrics_diff.push_back(
-          {{left_inst, right_inst},
-           static_cast<double>(*left_metric - *right_metric)});
+    absl::StatusOr<uint64_t> left_time_ps =
+        left_op_metric_getter.GetOpTimePs(left_inst->name());
+    absl::StatusOr<uint64_t> right_time_ps =
+        right_op_metric_getter.GetOpTimePs(right_inst->name());
+    if (!left_time_ps.ok() || !right_time_ps.ok()) {
+      continue;
     }
+    sorted_metrics_diff.push_back(
+        {{left_inst, right_inst},
+         static_cast<double>(*right_time_ps - *left_time_ps)});
   }
   std::sort(sorted_metrics_diff.begin(), sorted_metrics_diff.end());
   std::vector<std::string> metrics_diff_list(sorted_metrics_diff.size());
@@ -621,8 +628,10 @@ std::string PrintRepetitiveDiffPatterns(
 }  // namespace
 
 void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary,
-                GraphUrlGenerator* url_generator, GetOpMetricFn left_op_metrics,
-                GetOpMetricFn right_op_metrics, std::ostringstream& out) {
+                GraphUrlGenerator* url_generator,
+                OpMetricGetter* left_op_metric_getter,
+                OpMetricGetter* right_op_metric_getter,
+                std::ostringstream& out) {
   const absl::flat_hash_set<HloOpcode> ignored_opcodes(kIgnoredOpcodes.begin(),
                                                        kIgnoredOpcodes.end());
   out << PrintCss() << PrintJavascript();
@@ -654,25 +663,29 @@ void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary,
                             ignored_opcodes, url_generator))));
 
   // Print profile metrics diff
-  out << PrintSectionWithHeader(
-      "Profile Metrics Diff",
-      absl::StrCat(
-          PrintDetails("Left Module Unmatched Instructions",
-                       PrintUnmatchedMetricsDiff(
-                           diff_result.left_module_unmatched_instructions,
-                           left_op_metrics, url_generator)),
-          PrintDetails("Right Module Unmatched Instructions",
-                       PrintUnmatchedMetricsDiff(
-                           diff_result.right_module_unmatched_instructions,
-                           right_op_metrics, url_generator)),
-          PrintDetails("Changed Instructions",
-                       PrintMatchedMetricsDiff(
-                           diff_result.changed_instructions, left_op_metrics,
-                           right_op_metrics, url_generator)),
-          PrintDetails("Unchanged Instructions",
-                       PrintMatchedMetricsDiff(
-                           diff_result.unchanged_instructions, left_op_metrics,
-                           right_op_metrics, url_generator))));
+  if (left_op_metric_getter != nullptr && right_op_metric_getter != nullptr) {
+    out << PrintSectionWithHeader(
+        "Profile Metrics Diff",
+        absl::StrCat(
+            PrintDetails("Left Module Unmatched Instructions",
+                         PrintUnmatchedMetricsDiff(
+                             diff_result.left_module_unmatched_instructions,
+                             *left_op_metric_getter, url_generator)),
+            PrintDetails("Right Module Unmatched Instructions",
+                         PrintUnmatchedMetricsDiff(
+                             diff_result.right_module_unmatched_instructions,
+                             *right_op_metric_getter, url_generator)),
+            PrintDetails(
+                "Changed Instructions",
+                PrintMatchedMetricsDiff(
+                    diff_result.changed_instructions, *left_op_metric_getter,
+                    *right_op_metric_getter, url_generator)),
+            PrintDetails(
+                "Unchanged Instructions",
+                PrintMatchedMetricsDiff(
+                    diff_result.unchanged_instructions, *left_op_metric_getter,
+                    *right_op_metric_getter, url_generator))));
+  }
 
   // Print repetitive computation groups
   out << PrintSectionWithHeader(
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h
index 9076a5b09270e1..121340dbe8e4f4 100644
--- a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h
@@ -17,31 +17,39 @@
 #ifndef XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_HTML_RENDERER_H_
 #define XLA_HLO_TOOLS_HLO_DIFF_RENDER_HLO_GUMGRAPH_HTML_RENDERER_H_
 
-#include <optional>
-#include <ostream>
 #include <sstream>
 
-#include "absl/functional/function_ref.h"
-#include "absl/strings/string_view.h"
 #include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
 #include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
 #include "xla/hlo/tools/hlo_diff/render/graph_url_generator.h"
+#include "xla/hlo/tools/hlo_diff/render/op_metric_getter.h"
 
 namespace xla {
 namespace hlo_diff {
 
-// A function that returns the op metric for the given op name.
-using GetOpMetricFn =
-    absl::FunctionRef<std::optional<double>(absl::string_view)>;
-
 // Renders the diff result in HTML format, and writes the result to the given
-// output stream.
-
-// url_generator can be specified which is used to link an url to each
-// generated diff result.
+// output stream. url_generator can be specified which is used to link an url
+// to each generated diff result.
 void RenderHtml(const DiffResult& diff_result, const DiffSummary& diff_summary,
-                GraphUrlGenerator* url_generator, GetOpMetricFn left_op_metrics,
-                GetOpMetricFn right_op_metrics, std::ostringstream& out);
+                GraphUrlGenerator* url_generator,
+                OpMetricGetter* left_op_metric_getter,
+                OpMetricGetter* right_op_metric_getter,
+                std::ostringstream& out);
+inline void RenderHtml(const DiffResult& diff_result,
+                       const DiffSummary& diff_summary,
+                       GraphUrlGenerator* url_generator,
+                       std::ostringstream& out) {
+  RenderHtml(diff_result, diff_summary, url_generator,
+             /*left_op_metric_getter=*/nullptr,
+             /*right_op_metric_getter=*/nullptr, out);
+}
+inline void RenderHtml(const DiffResult& diff_result,
+                       const DiffSummary& diff_summary,
+                       std::ostringstream& out) {
+  RenderHtml(diff_result, diff_summary, /*url_generator=*/nullptr,
+             /*left_op_metric_getter=*/nullptr,
+             /*right_op_metric_getter=*/nullptr, out);
+}
 
 }  // namespace hlo_diff
 }  // namespace xla
diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer_test.cc b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer_test.cc
new file mode 100644
index 00000000000000..33196073b8d82b
--- /dev/null
+++ b/third_party/xla/xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer_test.cc
@@ -0,0 +1,71 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/hlo/tools/hlo_diff/render/hlo_gumgraph_html_renderer.h"
+
+#include <sstream>
+#include <string>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
+#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"
+#include "xla/hlo/tools/hlo_diff/render/op_metric_getter.h"
+
+namespace xla {
+namespace hlo_diff {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Not;
+
+// A mock OpMetricGetter for testing.
+class MockOpMetricGetter : public OpMetricGetter {
+ public:
+  MOCK_METHOD(absl::StatusOr<uint64_t>, GetOpTimePs, (absl::string_view),
+              (const, override));
+};
+
+TEST(HloGumgraphHtmlRendererTest, RenderHtml) {
+  DiffResult diff_result;
+  DiffSummary diff_summary;
+  std::ostringstream out;
+  RenderHtml(diff_result, diff_summary, out);
+  EXPECT_THAT(out.str(), HasSubstr("